Repository: Oneflow-Inc/oneflow
Branch: master
Commit: 25c8978c1c8b
Files: 4508
Total size: 25.3 MB
Directory structure:
gitextract_pzk3dhhw/
├── .clang-format
├── .clang-tidy
├── .cmake-format.py
├── .devcontainer/
│   ├── Dockerfile
│   └── devcontainer.json
├── .dockerignore
├── .github/
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE/
│   │   ├── blank_issue.yml
│   │   ├── bug_report.md
│   │   ├── documention_issue.yml
│   │   ├── feature_request.yml
│   │   ├── performance_issue.yml
│   │   └── question.yml
│   ├── PULL_REQUEST_TEMPLATE/
│   │   ├── general_template.md
│   │   └── op_template.md
│   ├── actions/
│   │   ├── mac-build/
│   │   │   └── action.yml
│   │   ├── setup/
│   │   │   └── action.yml
│   │   ├── upload_oss/
│   │   │   └── action.yml
│   │   ├── upload_ssh/
│   │   │   └── action.yml
│   │   └── whl/
│   │       └── action.yml
│   ├── scripts/
│   │   ├── requirements.txt
│   │   └── set_initial_variables.py
│   └── workflows/
│       ├── canary.yml
│       ├── community_release.yml
│       ├── on_merge.yml
│       ├── pr.yml
│       ├── priv_release.yml
│       ├── release.yml
│       ├── simple.yml
│       └── test.yml
├── .gitignore
├── .lsan-suppressions
├── .mergify.yml
├── .tsan-suppressions
├── .ubsan-suppressions
├── CMakeLists.txt
├── LICENSE
├── README.md
├── ci/
│   ├── CMakeLists.txt
│   ├── build/
│   │   ├── ensure_img.py
│   │   └── make.sh
│   ├── check/
│   │   ├── clang_tidy_warnings_as_errors_on_diff
│   │   ├── lintutils.py
│   │   ├── run_clang_format.py
│   │   ├── run_clang_tidy.py
│   │   ├── run_cmake_format.py
│   │   ├── run_license_format.py
│   │   └── run_py_format.py
│   ├── clang/
│   │   └── build-llvm.sh
│   ├── conda/
│   │   ├── build-clang.sh
│   │   └── tuna.condarc
│   ├── fixed-dev-requirements.txt
│   ├── manylinux/
│   │   ├── build-gcc7-xla.sh
│   │   ├── build-gcc9.sh
│   │   └── build.sh
│   ├── requirements.txt
│   ├── reset_submodule.sh
│   ├── setup_submodule.py
│   ├── setup_submodule.sh
│   └── test/
│       ├── 1node_benchmark_test.sh
│       ├── 1node_benchmark_test_fp16.sh
│       ├── 1node_custom_op_test.sh
│       ├── 1node_model_eager_test.sh
│       ├── 1node_model_test.sh
│       ├── 1node_op_test.sh
│       ├── 2node_op_test.sh
│       ├── 2node_op_test_multi_client.sh
│       ├── CMakeLists.txt
│       ├── build_docs.sh
│       ├── distributed_run.py
│       ├── doctest.sh
│       ├── excludelist
│       ├── expensive_generic_test_multi_client.sh
│       ├── generic_test.sh
│       ├── generic_test_multi_client.sh
│       ├── ir_tests.sh
│       ├── multi_client_exception_test.sh
│       ├── multi_launch.py
│       ├── parallel_run.py
│       ├── print_stack_from_core.sh
│       ├── print_stack_in_all_dirs.sh
│       ├── resource-spec/
│       │   ├── 1x-gtx-1080.json
│       │   ├── 2x-rtx-2080.json
│       │   └── 4x-rtx-2080ti.json
│       ├── test_mock_function.sh
│       ├── test_mock_script.sh
│       ├── test_resnet50_graph_ddp.sh
│       ├── test_speed_multi_client.sh
│       └── try_install.sh
├── cmake/
│   ├── caches/
│   │   ├── ci/
│   │   │   ├── canary/
│   │   │   │   └── cuda.cmake
│   │   │   ├── cpu-asan-ubsan.cmake
│   │   │   ├── cpu-tsan.cmake
│   │   │   ├── cpu.cmake
│   │   │   ├── cuda-xla.cmake
│   │   │   ├── cuda.cmake
│   │   │   ├── gh-hosted/
│   │   │   │   ├── cpu-clang.cmake
│   │   │   │   └── cpu-gcc.cmake
│   │   │   ├── llvm/
│   │   │   │   └── cuda-75-clang.cmake
│   │   │   ├── profiler/
│   │   │   │   └── cuda.cmake
│   │   │   ├── release/
│   │   │   │   ├── cpu.cmake
│   │   │   │   ├── cu118.cmake
│   │   │   │   └── cuda.cmake
│   │   │   └── serving/
│   │   │       ├── cuda-75.cmake
│   │   │       └── openvino.cmake
│   │   ├── cn/
│   │   │   ├── cpu.cmake
│   │   │   ├── cuda.cmake
│   │   │   └── fast/
│   │   │       ├── cpu-clang.cmake
│   │   │       ├── cpu.cmake
│   │   │       ├── cuda-61-clang.cmake
│   │   │       ├── cuda-61.cmake
│   │   │       ├── cuda-75-clang.cmake
│   │   │       ├── cuda-75.cmake
│   │   │       ├── cuda-86.cmake
│   │   │       ├── mlir-cpu.cmake
│   │   │       ├── mlir-cuda-61.cmake
│   │   │       ├── mlir-cuda-75.cmake
│   │   │       ├── mlir-cuda-80.cmake
│   │   │       └── mlir-cuda-86.cmake
│   │   └── international/
│   │       ├── 
cpu.cmake │ │ └── cuda.cmake │ ├── cuda.cmake │ ├── functional.cmake │ ├── git_version.cmake │ ├── oneflow-config.cmake │ ├── oneflow.cmake │ ├── op_schema.cmake │ ├── platform.cmake │ ├── proto2cpp.cmake │ ├── pybind11.cmake │ ├── python.cmake │ ├── third_party/ │ │ ├── FindBFD.cmake │ │ ├── FindBLAS.cmake │ │ ├── FindCUDNN.cmake │ │ ├── FindUnwind.cmake │ │ ├── absl.cmake │ │ ├── cares.cmake │ │ ├── cocoapi.cmake │ │ ├── cub.cmake │ │ ├── cutlass.cmake │ │ ├── eigen.cmake │ │ ├── flash_attention.cmake │ │ ├── flatbuffers.cmake │ │ ├── glog.cmake │ │ ├── googletest.cmake │ │ ├── grpc.cmake │ │ ├── half.cmake │ │ ├── header_index/ │ │ │ ├── cub_headers.txt │ │ │ ├── grpc_headers.txt │ │ │ ├── libpng_headers.txt │ │ │ └── opencv_headers.txt │ │ ├── hwloc.cmake │ │ ├── json.cmake │ │ ├── libjpeg-turbo.cmake │ │ ├── nccl.cmake │ │ ├── oneDNN.cmake │ │ ├── opencv.cmake │ │ ├── openssl.cmake │ │ ├── patches/ │ │ │ └── tensorflow-logging.patch │ │ ├── protobuf.cmake │ │ ├── re2.cmake │ │ ├── trt_flash_attention.cmake │ │ └── zlib.cmake │ ├── third_party.cmake │ ├── threading.cmake │ └── util.cmake ├── dev-requirements.txt ├── docker/ │ ├── build/ │ │ ├── Dockerfile │ │ ├── build-ubuntu.sh │ │ ├── build.sh │ │ ├── build.ubuntu.dockerfile │ │ ├── launch.sh │ │ └── test.sh │ ├── ci/ │ │ ├── base/ │ │ │ └── Dockerfile │ │ ├── fmt/ │ │ │ ├── Dockerfile │ │ │ └── build.sh │ │ ├── make/ │ │ │ └── Dockerfile │ │ ├── test/ │ │ │ ├── Dockerfile │ │ │ ├── build.sh │ │ │ ├── launch.sh │ │ │ └── requirements.txt │ │ ├── test-v2/ │ │ │ ├── Dockerfile │ │ │ ├── build.sh │ │ │ ├── requirements.txt │ │ │ └── sources.list │ │ └── third_party/ │ │ └── Dockerfile │ └── package/ │ └── manylinux/ │ ├── CentOS-Base.repo │ ├── CentOS7-Base-163.repo │ ├── Dockerfile │ ├── README.md │ ├── build_wheel.py │ └── launch.sh ├── docs/ │ ├── Makefile │ ├── requirements.txt │ └── source/ │ ├── _static/ │ │ └── .gitkeep │ ├── auto_parallel.rst │ ├── autograd.rst │ ├── cn/ │ │ ├── __init__.py │ │ ├── activation.py │ │ └── math_ops.py │ ├── conf.py │ ├── cuda.rst │ ├── distributed.rst │ ├── distributions.rst │ ├── environment_variables.rst │ ├── graph.rst │ ├── hub.rst │ ├── image.rst │ ├── index.rst │ ├── linalg.rst │ ├── nn.functional.rst │ ├── nn.init.rst │ ├── nn.rst │ ├── one_embedding.rst │ ├── oneflow.rst │ ├── optim.rst │ ├── special.rst │ ├── tensor.rst │ ├── tensor_attributes.rst │ ├── troubleshooting.md │ ├── type_info.rst │ ├── utils.data.rst │ ├── utils.global_view.rst │ └── utils.tensor.rst ├── external/ │ ├── CMakeLists.txt │ ├── fmt/ │ │ └── CMakeLists.txt │ ├── kineto/ │ │ └── CMakeLists.txt │ ├── onetbb/ │ │ └── CMakeLists.txt │ └── robin-hood-hashing/ │ └── CMakeLists.txt ├── oneflow/ │ ├── api/ │ │ ├── common/ │ │ │ ├── ir_pass.cpp │ │ │ ├── job_build_and_infer_ctx.h │ │ │ ├── sbp.h │ │ │ └── variable_tensor_mgr.h │ │ ├── cpp/ │ │ │ ├── api.h │ │ │ ├── embedding/ │ │ │ │ ├── embedding.cpp │ │ │ │ └── embedding.h │ │ │ ├── env.cpp │ │ │ ├── env.h │ │ │ ├── env_impl.cpp │ │ │ ├── env_impl.h │ │ │ ├── framework/ │ │ │ │ ├── device.cpp │ │ │ │ ├── device.h │ │ │ │ ├── dtype.cpp │ │ │ │ ├── dtype.h │ │ │ │ ├── graph.cpp │ │ │ │ ├── graph.h │ │ │ │ ├── ivalue.cpp │ │ │ │ ├── ivalue.h │ │ │ │ ├── shape.cpp │ │ │ │ ├── shape.h │ │ │ │ ├── tensor.cpp │ │ │ │ └── tensor.h │ │ │ ├── framework.h │ │ │ ├── nn/ │ │ │ │ └── functional/ │ │ │ │ ├── activation.cpp │ │ │ │ └── activation.h │ │ │ ├── nn.h │ │ │ └── tests/ │ │ │ ├── api_test.cpp │ │ │ ├── api_test.h │ │ │ ├── graph_test.cpp │ │ │ ├── graph_test_model/ │ │ │ │ ├── 
affine_no_parameter/ │ │ │ │ │ └── model.mlir │ │ │ │ └── affine_with_parameter/ │ │ │ │ ├── model.a/ │ │ │ │ │ ├── meta │ │ │ │ │ └── out │ │ │ │ ├── model.b/ │ │ │ │ │ ├── meta │ │ │ │ │ └── out │ │ │ │ └── model.mlir │ │ │ ├── ivalue_test.cpp │ │ │ ├── nn_test.cpp │ │ │ ├── one_embedding_test.cpp │ │ │ └── tensor_test.cpp │ │ └── python/ │ │ ├── autograd/ │ │ │ ├── autograd.cpp │ │ │ ├── autograd_engine.cpp │ │ │ ├── autograd_function.cpp │ │ │ ├── autograd_function_state.cpp │ │ │ ├── autograd_function_state.h │ │ │ ├── autograd_mode.cpp │ │ │ └── function_node.cpp │ │ ├── caster/ │ │ │ ├── autograd_function_state.h │ │ │ ├── common.h │ │ │ ├── maybe.h │ │ │ ├── optional.h │ │ │ ├── size.h │ │ │ ├── tensor.h │ │ │ └── test.cpp │ │ ├── deprecated.cpp │ │ ├── dlpack/ │ │ │ ├── converter.cpp │ │ │ ├── converter.h │ │ │ └── dlpack.h │ │ ├── eager/ │ │ │ └── eager.cpp │ │ ├── env/ │ │ │ ├── env.cpp │ │ │ └── env.h │ │ ├── ep/ │ │ │ └── cuda_matmul_mode.cpp │ │ ├── exception/ │ │ │ ├── exception.cpp │ │ │ └── exception.h │ │ ├── flags.cpp │ │ ├── framework/ │ │ │ ├── autocast.cpp │ │ │ ├── device.cpp │ │ │ ├── doc.cpp │ │ │ ├── dtype.cpp │ │ │ ├── framework.cpp │ │ │ ├── framework.h │ │ │ ├── global_mode.cpp │ │ │ ├── id_state.cpp │ │ │ ├── id_util.cpp │ │ │ ├── instructions_builder.cpp │ │ │ ├── layout.cpp │ │ │ ├── memory_format.cpp │ │ │ ├── memory_format.h │ │ │ ├── nn_graph.cpp │ │ │ ├── one_embedding.cpp │ │ │ ├── op_builder.cpp │ │ │ ├── op_expr.cpp │ │ │ ├── parallel_conf_util.cpp │ │ │ ├── py_kernel_registry.cpp │ │ │ ├── random_generator.cpp │ │ │ ├── scope_util.cpp │ │ │ ├── session_util.cpp │ │ │ ├── shut_down_util.cpp │ │ │ ├── size.cpp │ │ │ ├── size.h │ │ │ ├── stream.cpp │ │ │ ├── tensor.cpp │ │ │ ├── tensor.h │ │ │ ├── tensor_functions.cpp │ │ │ ├── tensor_functions_util.h │ │ │ ├── tensor_tuple.cpp │ │ │ ├── tensortype.cpp │ │ │ ├── tensortype.h │ │ │ ├── thread.cpp │ │ │ ├── thread.h │ │ │ ├── typeinfo.cpp │ │ │ ├── typeinfo.h │ │ │ └── variable_tensor_mgr.cpp │ │ ├── functional/ │ │ │ ├── common.cpp │ │ │ ├── common.h │ │ │ ├── dispatch_stateful_ops.cpp │ │ │ ├── dispatch_stateful_ops.yaml │ │ │ ├── function_def.h │ │ │ ├── indexing.cpp │ │ │ ├── indexing.h │ │ │ ├── python_arg.cpp │ │ │ ├── python_arg.h │ │ │ ├── python_arg_parser.cpp │ │ │ ├── python_arg_parser.h │ │ │ ├── python_return_types.h │ │ │ ├── tensor_api.cpp │ │ │ ├── tensor_api.yaml │ │ │ ├── value_types.cpp │ │ │ └── value_types.h │ │ ├── gil_foreign_lock_helper.cpp │ │ ├── init.cpp │ │ ├── ir.cpp │ │ ├── job_build/ │ │ │ ├── job_build_and_infer.cpp │ │ │ ├── job_build_and_infer.h │ │ │ └── lazy_mode.cpp │ │ ├── multiprocessing/ │ │ │ ├── init.cpp │ │ │ ├── object_ptr.cpp │ │ │ ├── object_ptr.h │ │ │ └── shared_memory.cpp │ │ ├── numpy/ │ │ │ └── init_numpy_c_api.cpp │ │ ├── of_api_registry.cpp │ │ ├── of_api_registry.h │ │ ├── profiler.cpp │ │ ├── registry/ │ │ │ └── registry.cpp │ │ ├── remat/ │ │ │ └── remat.cpp │ │ ├── rpc/ │ │ │ ├── ccl.cpp │ │ │ └── rank_group.cpp │ │ ├── session/ │ │ │ └── session.cpp │ │ ├── stack_getter.cpp │ │ ├── symbol/ │ │ │ ├── job_conf_symbol.cpp │ │ │ ├── op_conf_symbol.cpp │ │ │ ├── placement_symbol.cpp │ │ │ ├── sbp_symbol.cpp │ │ │ └── scope_symbol.cpp │ │ └── utils/ │ │ ├── dataloader.cpp │ │ ├── tensor_utils.cpp │ │ └── tensor_utils.h │ ├── core/ │ │ ├── auto_parallel/ │ │ │ ├── algorithm_util.cpp │ │ │ ├── algorithm_util.h │ │ │ ├── auto_memory.cpp │ │ │ ├── auto_memory.h │ │ │ ├── binary_set.cpp │ │ │ ├── binary_set.h │ │ │ ├── boxing_collector.cpp │ │ │ ├── 
boxing_collector.h │ │ │ ├── sbp_collector.cpp │ │ │ ├── sbp_collector.h │ │ │ ├── sbp_constructor.cpp │ │ │ ├── sbp_constructor.h │ │ │ ├── sbp_edge.cpp │ │ │ ├── sbp_edge.h │ │ │ ├── sbp_graph.cpp │ │ │ ├── sbp_graph.h │ │ │ ├── sbp_node.cpp │ │ │ ├── sbp_node.h │ │ │ ├── sbp_util.cpp │ │ │ └── sbp_util.h │ │ ├── autograd/ │ │ │ ├── autograd_captured_tensor.h │ │ │ ├── autograd_engine.cpp │ │ │ ├── autograd_engine.h │ │ │ ├── autograd_function.cpp │ │ │ ├── autograd_function.h │ │ │ ├── autograd_meta.cpp │ │ │ ├── autograd_meta.h │ │ │ ├── autograd_mode.cpp │ │ │ ├── autograd_mode.h │ │ │ ├── gradient_funcs/ │ │ │ │ ├── activation.cpp │ │ │ │ ├── adaptive_avg_pool.cpp │ │ │ │ ├── adaptive_max_pool.cpp │ │ │ │ ├── add_n.cpp │ │ │ │ ├── affine_grid.cpp │ │ │ │ ├── amp_white_identity.cpp │ │ │ │ ├── as_strided.cpp │ │ │ │ ├── avg_pool.cpp │ │ │ │ ├── batch_gather.cpp │ │ │ │ ├── bias_add.cpp │ │ │ │ ├── binary_cross_entropy.cpp │ │ │ │ ├── binary_cross_entropy_with_logits.cpp │ │ │ │ ├── binary_cross_entropy_with_logits_reduce_mean.cpp │ │ │ │ ├── broadcast_binary_ops.cpp │ │ │ │ ├── broadcast_like.cpp │ │ │ │ ├── cast.cpp │ │ │ │ ├── clip_by_scalar.cpp │ │ │ │ ├── clip_by_scalar_max.cpp │ │ │ │ ├── clip_by_scalar_min.cpp │ │ │ │ ├── combined_margin_loss.cpp │ │ │ │ ├── complex.cpp │ │ │ │ ├── concat.cpp │ │ │ │ ├── conv.cpp │ │ │ │ ├── copy.cpp │ │ │ │ ├── ctc_loss.cpp │ │ │ │ ├── cublas_fused_mlp.cpp │ │ │ │ ├── cum_ops.cpp │ │ │ │ ├── deconv.cpp │ │ │ │ ├── deform_conv.cpp │ │ │ │ ├── depand.cpp │ │ │ │ ├── det.cpp │ │ │ │ ├── diag.cpp │ │ │ │ ├── diagonal.cpp │ │ │ │ ├── dim_gather.cpp │ │ │ │ ├── dim_scatter.cpp │ │ │ │ ├── dot.cpp │ │ │ │ ├── dropout.cpp │ │ │ │ ├── eager_ccl_broadcast.cpp │ │ │ │ ├── elementwise_minimum_maximum.cpp │ │ │ │ ├── embedding.cpp │ │ │ │ ├── expand.cpp │ │ │ │ ├── fake_quantization.cpp │ │ │ │ ├── fft.cpp │ │ │ │ ├── fill.cpp │ │ │ │ ├── flatten.cpp │ │ │ │ ├── flip.cpp │ │ │ │ ├── fold.cpp │ │ │ │ ├── fused_bias_add_dropout.cpp │ │ │ │ ├── fused_bias_add_gelu.cpp │ │ │ │ ├── fused_bias_add_scale_mask_softmax_dropout.cpp │ │ │ │ ├── fused_center.cpp │ │ │ │ ├── fused_cross_interaction.cpp │ │ │ │ ├── fused_dot_feature_interaction.cpp │ │ │ │ ├── fused_fast_gelu_mul.cpp │ │ │ │ ├── fused_get_boundding_boxes_coord.cpp │ │ │ │ ├── fused_get_ciou_diagonal_angle.cpp │ │ │ │ ├── fused_get_ciou_result.cpp │ │ │ │ ├── fused_get_convex_diagonal_squared.cpp │ │ │ │ ├── fused_get_intersection_area.cpp │ │ │ │ ├── fused_get_iou.cpp │ │ │ │ ├── fused_glu.cpp │ │ │ │ ├── fused_gru_cell.cpp │ │ │ │ ├── fused_lstm_cell.cpp │ │ │ │ ├── fused_matmul_bias.cpp │ │ │ │ ├── fused_matmul_bias_add_relu_dropout.cpp │ │ │ │ ├── fused_scale_mask_bias_softmax.cpp │ │ │ │ ├── fused_scale_mask_softmax.cpp │ │ │ │ ├── fused_scale_mask_softmax_dropout.cpp │ │ │ │ ├── fused_scale_tril.cpp │ │ │ │ ├── fused_scale_tril_softmax_mask_scale.cpp │ │ │ │ ├── fused_self_attention.cpp │ │ │ │ ├── fused_weighted_sum.cpp │ │ │ │ ├── gather.cpp │ │ │ │ ├── gather_nd.cpp │ │ │ │ ├── global_cast.cpp │ │ │ │ ├── global_to_global.cpp │ │ │ │ ├── gradient_accumulation.cpp │ │ │ │ ├── graph_feed_and_fetch.cpp │ │ │ │ ├── grid_sample.cpp │ │ │ │ ├── group_norm.cpp │ │ │ │ ├── identity.cpp │ │ │ │ ├── inv.cpp │ │ │ │ ├── kl_div.cpp │ │ │ │ ├── l2_normalize.cpp │ │ │ │ ├── layer_norm.cpp │ │ │ │ ├── lerp.cpp │ │ │ │ ├── linalg_cross.cpp │ │ │ │ ├── log_softmax.cpp │ │ │ │ ├── masked_fill.cpp │ │ │ │ ├── math_binary_op.cpp │ │ │ │ ├── math_unary_op.cpp │ │ │ │ ├── matmul.cpp │ │ │ │ ├── 
matrix_vector_product.cpp │ │ │ │ ├── max_pool.cpp │ │ │ │ ├── max_unpool.cpp │ │ │ │ ├── median.cpp │ │ │ │ ├── mode.cpp │ │ │ │ ├── narrow.cpp │ │ │ │ ├── nll.cpp │ │ │ │ ├── noncontiguous_binary_op.cpp │ │ │ │ ├── normalization.cpp │ │ │ │ ├── normalization_add_relu.cpp │ │ │ │ ├── one_embedding_fused_lookup.cpp │ │ │ │ ├── padding.cpp │ │ │ │ ├── partial_fc_sample.cpp │ │ │ │ ├── reduce_ops.cpp │ │ │ │ ├── reduce_sum_like.cpp │ │ │ │ ├── reshape.cpp │ │ │ │ ├── rms_norm.cpp │ │ │ │ ├── roi_align.cpp │ │ │ │ ├── roll.cpp │ │ │ │ ├── rrelu.cpp │ │ │ │ ├── scalar_add.cpp │ │ │ │ ├── scalar_div.cpp │ │ │ │ ├── scalar_floordiv.cpp │ │ │ │ ├── scalar_fmod.cpp │ │ │ │ ├── scalar_mul.cpp │ │ │ │ ├── scalar_pow.cpp │ │ │ │ ├── scalar_truncdiv.cpp │ │ │ │ ├── scaled_dot_product_attention.cpp │ │ │ │ ├── scatter_nd.cpp │ │ │ │ ├── select_top_n.cpp │ │ │ │ ├── slice.cpp │ │ │ │ ├── smooth_l1_loss.cpp │ │ │ │ ├── softmax.cpp │ │ │ │ ├── softmax_cross_entropy.cpp │ │ │ │ ├── sparse_cross_entropy.cpp │ │ │ │ ├── sparse_softmax_cross_entropy.cpp │ │ │ │ ├── sparse_softmax_cross_entropy_ms.cpp │ │ │ │ ├── split_like.cpp │ │ │ │ ├── squeeze.cpp │ │ │ │ ├── stack.cpp │ │ │ │ ├── tensor_scalar_binary.cpp │ │ │ │ ├── tensor_scatter_nd_update.cpp │ │ │ │ ├── tf_pool.cpp │ │ │ │ ├── to_contiguous.cpp │ │ │ │ ├── transpose.cpp │ │ │ │ ├── tril.cpp │ │ │ │ ├── triu.cpp │ │ │ │ ├── trunc.cpp │ │ │ │ ├── two_stage_reduce.cpp │ │ │ │ ├── unfold.cpp │ │ │ │ ├── unfold_tensor.cpp │ │ │ │ ├── unsqueeze.cpp │ │ │ │ ├── upsample.cpp │ │ │ │ ├── variance.cpp │ │ │ │ ├── vector_matrix_product.cpp │ │ │ │ └── where.cpp │ │ │ └── higher_order_gradient_funcs/ │ │ │ ├── activation.cpp │ │ │ ├── avg_pool.cpp │ │ │ ├── binary_cross_entropy_loss.cpp │ │ │ ├── binary_cross_entropy_with_logits.cpp │ │ │ ├── binary_cross_entropy_with_logits_reduce_mean.cpp │ │ │ ├── conv.cpp │ │ │ ├── div.cpp │ │ │ ├── kl_div_loss.cpp │ │ │ ├── log_softmax.cpp │ │ │ ├── math_unary_op.cpp │ │ │ ├── matmul.cpp │ │ │ ├── max_pool.cpp │ │ │ ├── nll_loss.cpp │ │ │ ├── pow.cpp │ │ │ ├── scalar_pow.cpp │ │ │ ├── slice.cpp │ │ │ ├── smooth_l1_loss.cpp │ │ │ └── softmax.cpp │ │ ├── boxing/ │ │ │ ├── asymmetric_broadcast.cpp │ │ │ ├── boxing_dividor.h │ │ │ ├── boxing_dividor_util.cpp │ │ │ ├── boxing_dividor_util.h │ │ │ ├── boxing_interpreter_status.cpp │ │ │ ├── boxing_interpreter_status.h │ │ │ ├── ccl_boxing_function.cpp │ │ │ ├── cuda_copy_boxing_interpreter.cpp │ │ │ ├── eager_boxing_interpreter.cpp │ │ │ ├── eager_boxing_interpreter.h │ │ │ ├── eager_boxing_interpreter_mgr.cpp │ │ │ ├── eager_boxing_interpreter_mgr.h │ │ │ ├── eager_boxing_logger.cpp │ │ │ ├── eager_boxing_logger.h │ │ │ ├── flatten_hierarchy.cpp │ │ │ ├── generic_symmetric_nd_sbp_boxing.cpp │ │ │ ├── identity_boxing_interpreter.cpp │ │ │ ├── naive_1_to_p_boxing.cpp │ │ │ ├── naive_b_to_1_boxing.cpp │ │ │ ├── naive_b_to_s_boxing.cpp │ │ │ ├── naive_p_to_b_boxing.cpp │ │ │ ├── naive_p_to_s_boxing.cpp │ │ │ ├── naive_s_to_b_boxing.cpp │ │ │ ├── naive_s_to_p_boxing.cpp │ │ │ ├── naive_s_to_s_boxing.cpp │ │ │ ├── nd_sbp_dim_reduce_boxing.cpp │ │ │ ├── one_to_one_boxing.cpp │ │ │ ├── slice_boxing_util.cpp │ │ │ ├── slice_boxing_util.h │ │ │ ├── symmetric_acyclic_nd_sbp_boxing.cpp │ │ │ ├── symmetric_b_to_p_boxing.cpp │ │ │ ├── symmetric_b_to_s_boxing.cpp │ │ │ ├── symmetric_s_to_p_boxing.cpp │ │ │ └── unflatten_hierarchy.cpp │ │ ├── ccl/ │ │ │ ├── ccl.cpp │ │ │ └── ccl.h │ │ ├── comm_network/ │ │ │ ├── comm_network.cpp │ │ │ ├── comm_network.h │ │ │ ├── epoll/ │ │ │ │ ├── 
epoll_comm_network.cpp │ │ │ │ ├── epoll_comm_network.h │ │ │ │ ├── io_event_poller.cpp │ │ │ │ ├── io_event_poller.h │ │ │ │ ├── socket_helper.cpp │ │ │ │ ├── socket_helper.h │ │ │ │ ├── socket_memory_desc.h │ │ │ │ ├── socket_message.h │ │ │ │ ├── socket_read_helper.cpp │ │ │ │ ├── socket_read_helper.h │ │ │ │ ├── socket_write_helper.cpp │ │ │ │ └── socket_write_helper.h │ │ │ └── ibverbs/ │ │ │ ├── ibverbs.proto │ │ │ ├── ibverbs_comm_network.cpp │ │ │ ├── ibverbs_comm_network.h │ │ │ ├── ibverbs_memory_desc.cpp │ │ │ ├── ibverbs_memory_desc.h │ │ │ ├── ibverbs_qp.cpp │ │ │ └── ibverbs_qp.h │ │ ├── common/ │ │ │ ├── array_ref.h │ │ │ ├── auto_registration_factory.h │ │ │ ├── balanced_splitter.cpp │ │ │ ├── balanced_splitter.h │ │ │ ├── balanced_splitter_test.cpp │ │ │ ├── bfloat16.h │ │ │ ├── bfloat16_math.h │ │ │ ├── bfloat16_test.cpp │ │ │ ├── blas.h │ │ │ ├── blocking_counter.cpp │ │ │ ├── blocking_counter.h │ │ │ ├── blocking_then_busy.h │ │ │ ├── buffer.h │ │ │ ├── buffer_manager.h │ │ │ ├── cached_caller.cpp │ │ │ ├── cached_caller.h │ │ │ ├── cblas.h │ │ │ ├── channel.h │ │ │ ├── channel_test.cpp │ │ │ ├── check.cpp │ │ │ ├── check.h │ │ │ ├── check_level.cpp │ │ │ ├── check_level.h │ │ │ ├── constant.h │ │ │ ├── container_util.h │ │ │ ├── container_util_test.cpp │ │ │ ├── cost_util.h │ │ │ ├── cpp_attribute.h │ │ │ ├── data_type.cpp │ │ │ ├── data_type.h │ │ │ ├── data_type.proto │ │ │ ├── data_type_converter.h │ │ │ ├── data_type_converter_test.cpp │ │ │ ├── data_type_converter_test_static.h │ │ │ ├── data_type_seq.h │ │ │ ├── decorator.h │ │ │ ├── decorator_test.cpp │ │ │ ├── device.proto │ │ │ ├── device_type.cpp │ │ │ ├── device_type.h │ │ │ ├── device_type.proto │ │ │ ├── dtype_signature.h │ │ │ ├── dtype_signature.proto │ │ │ ├── eigen_util.h │ │ │ ├── either_ptr.h │ │ │ ├── env_var/ │ │ │ │ ├── bootstrap.h │ │ │ │ ├── debug_mode.h │ │ │ │ ├── eager.h │ │ │ │ ├── env_var.h │ │ │ │ ├── remat.h │ │ │ │ ├── stream.h │ │ │ │ └── vm.h │ │ │ ├── error.cpp │ │ │ ├── error.h │ │ │ ├── error.proto │ │ │ ├── error_util.cpp │ │ │ ├── error_util.h │ │ │ ├── exception.h │ │ │ ├── flat_shape.cpp │ │ │ ├── flat_shape.h │ │ │ ├── foreign_lock_helper.cpp │ │ │ ├── foreign_lock_helper.h │ │ │ ├── function_traits.h │ │ │ ├── hash.h │ │ │ ├── hash_container.h │ │ │ ├── hash_eq_trait_ptr.h │ │ │ ├── high_order_bool.h │ │ │ ├── just.h │ │ │ ├── layout_standardize.h │ │ │ ├── math_util.cpp │ │ │ ├── math_util.h │ │ │ ├── maybe.h │ │ │ ├── maybe_test.cpp │ │ │ ├── mem_util.cpp │ │ │ ├── mem_util.h │ │ │ ├── memory_format.proto │ │ │ ├── meta_util.hpp │ │ │ ├── nd_index.cpp │ │ │ ├── nd_index.h │ │ │ ├── nd_index_offset_helper.h │ │ │ ├── nd_index_offset_helper_test.cpp │ │ │ ├── not_equal_to_previous_adjacent_iterator.h │ │ │ ├── notifier.cpp │ │ │ ├── notifier.h │ │ │ ├── of_unused.h │ │ │ ├── op_args_reserved_size.h │ │ │ ├── op_args_vector.h │ │ │ ├── optional.h │ │ │ ├── optional_test.cpp │ │ │ ├── pcheck.h │ │ │ ├── permutation_iterator.h │ │ │ ├── platform.h │ │ │ ├── preprocessor.h │ │ │ ├── preprocessor_internal.h │ │ │ ├── preprocessor_test.cpp │ │ │ ├── process_state.h │ │ │ ├── protobuf.cpp │ │ │ ├── protobuf.h │ │ │ ├── range.cpp │ │ │ ├── range.h │ │ │ ├── range.proto │ │ │ ├── registry_error.cpp │ │ │ ├── registry_error.h │ │ │ ├── scalar.cpp │ │ │ ├── scalar.h │ │ │ ├── sequential.proto │ │ │ ├── shape.cpp │ │ │ ├── shape.h │ │ │ ├── shape.proto │ │ │ ├── shape_test.cpp │ │ │ ├── shape_vec.h │ │ │ ├── shape_view.cpp │ │ │ ├── shape_view.h │ │ │ ├── shared_or_scalar.h │ │ │ ├── 
single_thread_obj_pool.h │ │ │ ├── single_thread_obj_pool_test.cpp │ │ │ ├── singleton.h │ │ │ ├── sized_buffer_view.h │ │ │ ├── small_vector.h │ │ │ ├── spin_counter.cpp │ │ │ ├── spin_counter.h │ │ │ ├── static_check.h │ │ │ ├── static_global.h │ │ │ ├── steady_vector.h │ │ │ ├── steady_vector_test.cpp │ │ │ ├── str_util.cpp │ │ │ ├── str_util.h │ │ │ ├── stream_type.h │ │ │ ├── stride.cpp │ │ │ ├── stride.h │ │ │ ├── switch_func.h │ │ │ ├── symbol.h │ │ │ ├── symbol_test.cpp │ │ │ ├── tensor_buffer.cpp │ │ │ ├── tensor_buffer.h │ │ │ ├── tensor_desc.cpp │ │ │ ├── tensor_desc.h │ │ │ ├── tensor_meta.cpp │ │ │ ├── tensor_meta.h │ │ │ ├── test_util.h │ │ │ ├── thread_local_guard.h │ │ │ ├── thread_local_guard_test.cpp │ │ │ ├── throw.h │ │ │ ├── to_string.h │ │ │ ├── tuple_hash.h │ │ │ ├── type_traits.h │ │ │ ├── util.cpp │ │ │ ├── util.h │ │ │ ├── wrap_dim_utils.h │ │ │ └── zero_only_zip.h │ │ ├── control/ │ │ │ ├── bootstrap_client.h │ │ │ ├── bootstrap_server.h │ │ │ ├── control.proto │ │ │ ├── ctrl_bootstrap.cpp │ │ │ ├── ctrl_bootstrap.h │ │ │ ├── ctrl_bootstrap.proto │ │ │ ├── ctrl_call.h │ │ │ ├── ctrl_client.cpp │ │ │ ├── ctrl_client.h │ │ │ ├── ctrl_server.cpp │ │ │ ├── ctrl_server.h │ │ │ ├── ctrl_service.cpp │ │ │ ├── ctrl_service.h │ │ │ ├── ctrl_test.cpp │ │ │ ├── ctrl_util.cpp │ │ │ ├── ctrl_util.h │ │ │ ├── global_process_ctx.h │ │ │ ├── host_list_bootstrap_client.cpp │ │ │ ├── host_list_bootstrap_client.h │ │ │ ├── host_list_bootstrap_server.cpp │ │ │ ├── host_list_bootstrap_server.h │ │ │ ├── rank_info_bootstrap_client.cpp │ │ │ ├── rank_info_bootstrap_client.h │ │ │ ├── rank_info_bootstrap_server.cpp │ │ │ ├── rank_info_bootstrap_server.h │ │ │ ├── rpc_client.cpp │ │ │ ├── rpc_client.h │ │ │ ├── rpc_server.cpp │ │ │ ├── rpc_server.h │ │ │ └── worker_process_info.proto │ │ ├── cuda/ │ │ │ ├── atomic.cuh │ │ │ ├── elementwise.cuh │ │ │ ├── layer_norm.cuh │ │ │ ├── rms_norm.cuh │ │ │ ├── softmax.cuh │ │ │ └── unique.cuh │ │ ├── device/ │ │ │ ├── cuda_pseudo_bfloat16.h │ │ │ ├── cuda_pseudo_half.h │ │ │ ├── cuda_util.cpp │ │ │ ├── cuda_util.h │ │ │ ├── cudnn_conv_util.cpp │ │ │ ├── cudnn_conv_util.h │ │ │ ├── cudnn_util.cpp │ │ │ ├── cudnn_util.h │ │ │ ├── device_id.cpp │ │ │ ├── device_id.h │ │ │ ├── ep_based_event_record.h │ │ │ ├── event_record.h │ │ │ ├── nccl_util.cpp │ │ │ └── nccl_util.h │ │ ├── eager/ │ │ │ ├── call_context.cpp │ │ │ ├── call_context.h │ │ │ ├── dev_vm_dep_object_consume_mode.h │ │ │ ├── eager_blob_object.cpp │ │ │ ├── eager_blob_object.h │ │ │ ├── local_dep_object.cpp │ │ │ ├── local_dep_object.h │ │ │ ├── tensor_storage.cpp │ │ │ └── tensor_storage.h │ │ ├── embedding/ │ │ │ ├── cache.cpp │ │ │ ├── cache.h │ │ │ ├── cache_test.cpp │ │ │ ├── cached_key_value_store.cu │ │ │ ├── cached_key_value_store.h │ │ │ ├── embedding_manager.cpp │ │ │ ├── embedding_manager.h │ │ │ ├── full_cache.cu │ │ │ ├── full_cache.h │ │ │ ├── hash_functions.cuh │ │ │ ├── key_value_store.h │ │ │ ├── key_value_store_options.h │ │ │ ├── key_value_store_test.cpp │ │ │ ├── kv_iterator.h │ │ │ ├── lru_cache.cu │ │ │ ├── lru_cache.h │ │ │ ├── mock_key_value_store.cu │ │ │ ├── mock_key_value_store.h │ │ │ ├── persistent_table.cpp │ │ │ ├── persistent_table.h │ │ │ ├── persistent_table_key_value_store.cu │ │ │ ├── persistent_table_key_value_store.h │ │ │ └── posix_file.h │ │ ├── ep/ │ │ │ ├── common/ │ │ │ │ ├── active_device_guard.cpp │ │ │ │ ├── device.cpp │ │ │ │ ├── device_manager_registry.cpp │ │ │ │ ├── onednn.h │ │ │ │ └── primitive/ │ │ │ │ ├── add.cpp │ │ │ │ ├── 
batch_matmul.cpp │ │ │ │ ├── binary_functor.h │ │ │ │ ├── broadcast_elementwise_binary.h │ │ │ │ ├── broadcast_elementwise_unary.h │ │ │ │ ├── broadcast_matmul.h │ │ │ │ ├── broadcast_simplify_dims_test.cpp │ │ │ │ ├── constant_pad.h │ │ │ │ ├── copy_nd.h │ │ │ │ ├── elementwise_unary.h │ │ │ │ ├── matmul.cpp │ │ │ │ ├── permute.h │ │ │ │ ├── permute_impl.h │ │ │ │ ├── permute_test.cpp │ │ │ │ ├── unary_functor.h │ │ │ │ ├── util.h │ │ │ │ └── where.h │ │ │ ├── cpu/ │ │ │ │ ├── cpu_device.cpp │ │ │ │ ├── cpu_device.h │ │ │ │ ├── cpu_device_manager.cpp │ │ │ │ ├── cpu_device_manager.h │ │ │ │ ├── cpu_device_manager_factory.cpp │ │ │ │ ├── cpu_event.cpp │ │ │ │ ├── cpu_event.h │ │ │ │ ├── cpu_random_generator.cpp │ │ │ │ ├── cpu_random_generator.h │ │ │ │ ├── cpu_stream.cpp │ │ │ │ ├── cpu_stream.h │ │ │ │ └── primitive/ │ │ │ │ ├── add.cpp │ │ │ │ ├── binary_functor.h │ │ │ │ ├── broadcast_elementwise_binary.cpp │ │ │ │ ├── broadcast_elementwise_unary.cpp │ │ │ │ ├── broadcast_matmul.cpp │ │ │ │ ├── cast.cpp │ │ │ │ ├── constant_pad.cpp │ │ │ │ ├── copy_nd.cpp │ │ │ │ ├── elementwise_unary.cpp │ │ │ │ ├── fill.cpp │ │ │ │ ├── memcpy.cpp │ │ │ │ ├── memset.cpp │ │ │ │ ├── permute.cpp │ │ │ │ ├── softmax.cpp │ │ │ │ ├── softmax_backward.cpp │ │ │ │ ├── tensor_fill.cpp │ │ │ │ ├── type_seq.h │ │ │ │ ├── unary_functor.h │ │ │ │ └── where.cpp │ │ │ ├── cuda/ │ │ │ │ ├── cuda_device.cpp │ │ │ │ ├── cuda_device.h │ │ │ │ ├── cuda_device_manager.cpp │ │ │ │ ├── cuda_device_manager.h │ │ │ │ ├── cuda_device_manager_factory.cpp │ │ │ │ ├── cuda_event.cpp │ │ │ │ ├── cuda_event.h │ │ │ │ ├── cuda_matmul_mode.cpp │ │ │ │ ├── cuda_matmul_mode.h │ │ │ │ ├── cuda_random_generator.cpp │ │ │ │ ├── cuda_random_generator.h │ │ │ │ ├── cuda_stream.cpp │ │ │ │ ├── cuda_stream.h │ │ │ │ └── primitive/ │ │ │ │ ├── add.cu │ │ │ │ ├── binary_functor.cuh │ │ │ │ ├── broadcast_elementwise_binary.cu │ │ │ │ ├── broadcast_elementwise_binary.cuh │ │ │ │ ├── broadcast_elementwise_binary_activation_grad_0.cu │ │ │ │ ├── broadcast_elementwise_binary_activation_grad_1.cu │ │ │ │ ├── broadcast_elementwise_binary_activation_grad_2.cu │ │ │ │ ├── broadcast_elementwise_binary_bitwise.cu │ │ │ │ ├── broadcast_elementwise_binary_comparision_0.cu │ │ │ │ ├── broadcast_elementwise_binary_comparision_1.cu │ │ │ │ ├── broadcast_elementwise_binary_comparision_complex.cu │ │ │ │ ├── broadcast_elementwise_binary_logical.cu │ │ │ │ ├── broadcast_elementwise_binary_math_0.cu │ │ │ │ ├── broadcast_elementwise_binary_math_1.cu │ │ │ │ ├── broadcast_elementwise_binary_math_2.cu │ │ │ │ ├── broadcast_elementwise_binary_math_complex.cu │ │ │ │ ├── broadcast_elementwise_unary.cu │ │ │ │ ├── broadcast_matmul.cpp │ │ │ │ ├── cast.cu │ │ │ │ ├── constant_pad.cu │ │ │ │ ├── copy_nd.cu │ │ │ │ ├── elementwise_unary.cu │ │ │ │ ├── fill.cu │ │ │ │ ├── math_elementwise_unary_math_grad_0.cu │ │ │ │ ├── math_elementwise_unary_math_grad_1.cu │ │ │ │ ├── math_elementwise_unary_math_grad_2.cu │ │ │ │ ├── math_elementwise_unary_math_grad_3.cu │ │ │ │ ├── math_elementwise_unary_math_grad_complex.cu │ │ │ │ ├── memcpy.cpp │ │ │ │ ├── memset.cpp │ │ │ │ ├── permute.cu │ │ │ │ ├── softmax.cu │ │ │ │ ├── softmax_backward.cu │ │ │ │ ├── tensor_fill.cu │ │ │ │ ├── type_seq.h │ │ │ │ ├── unary_functor.cuh │ │ │ │ └── where.cu │ │ │ ├── include/ │ │ │ │ ├── active_device_guard.h │ │ │ │ ├── allocation_options.h │ │ │ │ ├── device.h │ │ │ │ ├── device_manager.h │ │ │ │ ├── device_manager_factory.h │ │ │ │ ├── device_manager_registry.h │ │ │ │ ├── event.h │ │ │ │ ├── 
primitive/ │ │ │ │ │ ├── add.h │ │ │ │ │ ├── batch_matmul.h │ │ │ │ │ ├── binary_op.h │ │ │ │ │ ├── blas.h │ │ │ │ │ ├── broadcast_elementwise_binary.h │ │ │ │ │ ├── broadcast_elementwise_unary.h │ │ │ │ │ ├── broadcast_matmul.h │ │ │ │ │ ├── cast.h │ │ │ │ │ ├── constant_pad.h │ │ │ │ │ ├── copy_nd.h │ │ │ │ │ ├── elementwise_unary.h │ │ │ │ │ ├── fast_integer_math.h │ │ │ │ │ ├── fill.h │ │ │ │ │ ├── log_softmax.h │ │ │ │ │ ├── log_softmax_backward.h │ │ │ │ │ ├── matmul.h │ │ │ │ │ ├── memcpy.h │ │ │ │ │ ├── memset.h │ │ │ │ │ ├── one_hot.h │ │ │ │ │ ├── permute.h │ │ │ │ │ ├── primitive.h │ │ │ │ │ ├── softmax.h │ │ │ │ │ ├── softmax_backward.h │ │ │ │ │ ├── tensor_fill.h │ │ │ │ │ ├── unary_op.h │ │ │ │ │ └── where.h │ │ │ │ ├── random_generator.h │ │ │ │ └── stream.h │ │ │ └── test/ │ │ │ ├── primitive/ │ │ │ │ ├── add_test.cpp │ │ │ │ ├── batch_matmul_test.cpp │ │ │ │ ├── binary_test.cpp │ │ │ │ ├── broadcast_matmul_test.cpp │ │ │ │ ├── cast_test.cpp │ │ │ │ ├── constant_pad_test.cpp │ │ │ │ ├── copy_nd_test.cpp │ │ │ │ ├── elementwise_unary_test.cpp │ │ │ │ ├── fill_test.cpp │ │ │ │ ├── matmul_test.cpp │ │ │ │ ├── memcpy_test.cpp │ │ │ │ ├── memset_test.cpp │ │ │ │ ├── permute_test.cpp │ │ │ │ ├── primitive_test.h │ │ │ │ ├── softmax_backward_test.cpp │ │ │ │ ├── softmax_test.cpp │ │ │ │ ├── unary_test.cpp │ │ │ │ └── where_test.cpp │ │ │ └── test_util.h │ │ ├── framework/ │ │ │ ├── arg_tuple.cpp │ │ │ ├── arg_tuple.h │ │ │ ├── attr_map.cpp │ │ │ ├── attr_map.h │ │ │ ├── attr_map_test.cpp │ │ │ ├── attr_value.cpp │ │ │ ├── attr_value.h │ │ │ ├── attr_value_accessor.cpp │ │ │ ├── attr_value_accessor.h │ │ │ ├── auto_random_generator.cpp │ │ │ ├── auto_random_generator.h │ │ │ ├── autocast.cpp │ │ │ ├── autocast.h │ │ │ ├── compute_complexity_fn_context.h │ │ │ ├── config_def.cpp │ │ │ ├── config_def.h │ │ │ ├── config_def.proto │ │ │ ├── consistency_check.cpp │ │ │ ├── consistency_check.h │ │ │ ├── device.cpp │ │ │ ├── device.h │ │ │ ├── dtype.cpp │ │ │ ├── dtype.h │ │ │ ├── eager_util.h │ │ │ ├── framework.h │ │ │ ├── get_nd_sbp_signature_list_context.h │ │ │ ├── global_param_grad_sync_mode.cpp │ │ │ ├── global_param_grad_sync_mode.h │ │ │ ├── global_tensor_infer_cache.cpp │ │ │ ├── global_tensor_infer_cache.h │ │ │ ├── id_util.cpp │ │ │ ├── id_util.h │ │ │ ├── infer_nd_sbp_fn_context.h │ │ │ ├── infer_output_blob_time_shape_fn_context.h │ │ │ ├── infer_util.cpp │ │ │ ├── infer_util.h │ │ │ ├── instructions_builder.cpp │ │ │ ├── instructions_builder.h │ │ │ ├── layout.cpp │ │ │ ├── layout.h │ │ │ ├── load_library.cpp │ │ │ ├── load_library.h │ │ │ ├── local_tensor_infer_cache.cpp │ │ │ ├── local_tensor_infer_cache.h │ │ │ ├── multi_client_session_context.cpp │ │ │ ├── multi_client_session_context.h │ │ │ ├── multi_thread.cpp │ │ │ ├── multi_thread.h │ │ │ ├── mutable_attr_map.h │ │ │ ├── nd_sbp.cpp │ │ │ ├── nd_sbp.h │ │ │ ├── nn_graph.cpp │ │ │ ├── nn_graph.h │ │ │ ├── nn_graph_if.h │ │ │ ├── op_builder.cpp │ │ │ ├── op_builder.h │ │ │ ├── op_definition.h │ │ │ ├── op_expr.cpp │ │ │ ├── op_expr.h │ │ │ ├── op_expr_grad_function.cpp │ │ │ ├── op_expr_grad_function.h │ │ │ ├── op_interpreter/ │ │ │ │ ├── dispatch_frame.cpp │ │ │ │ ├── dispatch_frame.h │ │ │ │ ├── eager_global_op_interpreter.cpp │ │ │ │ ├── eager_local_op_interpreter.cpp │ │ │ │ ├── eager_local_op_interpreter.h │ │ │ │ ├── lazy_op_interpreter.cpp │ │ │ │ ├── lazy_op_interpreter.h │ │ │ │ ├── op_interpreter.cpp │ │ │ │ ├── op_interpreter_util.cpp │ │ │ │ └── op_interpreter_util.h │ │ │ ├── op_interpreter.h │ │ │ ├── 
op_kernel.cpp │ │ │ ├── op_kernel.h │ │ │ ├── op_kernel_infer_cache.cpp │ │ │ ├── op_kernel_infer_cache.h │ │ │ ├── ordered_string_list.h │ │ │ ├── parallel_conf_util.cpp │ │ │ ├── parallel_conf_util.h │ │ │ ├── parallel_conf_util_test.cpp │ │ │ ├── placed_nd_sbp.cpp │ │ │ ├── placed_nd_sbp.h │ │ │ ├── placement_sbp_util.cpp │ │ │ ├── placement_sbp_util.h │ │ │ ├── placement_sbp_util_test.cpp │ │ │ ├── placement_utils.cpp │ │ │ ├── placement_utils.h │ │ │ ├── random_generator.cpp │ │ │ ├── random_generator.h │ │ │ ├── rank_group_rpc_util.cpp │ │ │ ├── rank_group_rpc_util.h │ │ │ ├── saved_tensor_hooks.h │ │ │ ├── sbp_context.cpp │ │ │ ├── sbp_context.h │ │ │ ├── sbp_infer_util.cpp │ │ │ ├── sbp_infer_util.h │ │ │ ├── sbp_infer_util_test.cpp │ │ │ ├── scope_util.cpp │ │ │ ├── scope_util.h │ │ │ ├── session_util.cpp │ │ │ ├── session_util.h │ │ │ ├── shut_down_util.cpp │ │ │ ├── shut_down_util.h │ │ │ ├── stream.cpp │ │ │ ├── stream.h │ │ │ ├── stream_allocator_is_pinned.h │ │ │ ├── stream_get_stream_type_name.h │ │ │ ├── stream_guard.cpp │ │ │ ├── stream_guard.h │ │ │ ├── stream_is_comm_net_stream.h │ │ │ ├── stream_mgr.cpp │ │ │ ├── stream_mgr.h │ │ │ ├── stream_need_soft_sync.h │ │ │ ├── stream_on_independent_thread.h │ │ │ ├── stream_set.cpp │ │ │ ├── stream_set.h │ │ │ ├── stream_support_stream_wait.h │ │ │ ├── symbol_storage_util.cpp │ │ │ ├── symbol_storage_util.h │ │ │ ├── sync_symbol_global_tensor_meta.cpp │ │ │ ├── sync_symbol_global_tensor_meta.h │ │ │ ├── sync_symbol_nd_sbp.cpp │ │ │ ├── sync_symbol_nd_sbp.h │ │ │ ├── sync_symbol_parallel_desc.cpp │ │ │ ├── sync_symbol_parallel_desc.h │ │ │ ├── synced_symbol_map.cpp │ │ │ ├── synced_symbol_map.h │ │ │ ├── tensor.cpp │ │ │ ├── tensor.h │ │ │ ├── tensor_arg.cpp │ │ │ ├── tensor_arg.h │ │ │ ├── tensor_global_id.cpp │ │ │ ├── tensor_global_id.h │ │ │ ├── tensor_impl.cpp │ │ │ ├── tensor_impl.h │ │ │ ├── tensor_methods.cpp │ │ │ ├── tensor_methods.h │ │ │ ├── tensor_name_scope.cpp │ │ │ ├── tensor_name_scope.h │ │ │ ├── tensor_rpc_util.cpp │ │ │ ├── tensor_rpc_util.h │ │ │ ├── tensor_storage.cpp │ │ │ ├── tensor_storage.h │ │ │ ├── tensor_tuple.cpp │ │ │ ├── tensor_tuple.h │ │ │ ├── tensor_util.cpp │ │ │ ├── tensor_util.h │ │ │ ├── to_string.cpp │ │ │ ├── to_string.h │ │ │ ├── transport_token.cpp │ │ │ ├── transport_token.h │ │ │ ├── transport_util.cpp │ │ │ ├── transport_util.h │ │ │ ├── user_op_attr.proto │ │ │ ├── user_op_conf.cpp │ │ │ ├── user_op_conf.h │ │ │ ├── user_op_conf.proto │ │ │ ├── user_op_def.cpp │ │ │ ├── user_op_def.h │ │ │ ├── user_op_def.proto │ │ │ ├── user_op_hob.h │ │ │ ├── user_op_kernel_registry.cpp │ │ │ ├── user_op_kernel_registry.h │ │ │ ├── user_op_registry.cpp │ │ │ ├── user_op_registry.h │ │ │ ├── user_op_registry_manager.cpp │ │ │ ├── user_op_registry_manager.h │ │ │ ├── user_op_tensor.h │ │ │ ├── util.h │ │ │ ├── variable_meta_info.proto │ │ │ ├── variable_tensor_mgr.cpp │ │ │ └── variable_tensor_mgr.h │ │ ├── functional/ │ │ │ ├── function_library.h │ │ │ ├── functional.h │ │ │ ├── functional_api.yaml │ │ │ ├── impl/ │ │ │ │ ├── activation_functor.cpp │ │ │ │ ├── array_functor.cpp │ │ │ │ ├── binary_functor.cpp │ │ │ │ ├── binary_functor.h │ │ │ │ ├── binary_grad_functor.cpp │ │ │ │ ├── comm_functor.cpp │ │ │ │ ├── common.cpp │ │ │ │ ├── common.h │ │ │ │ ├── dataset_functor.cpp │ │ │ │ ├── eye_functor.cpp │ │ │ │ ├── fused_attention_functor.cpp │ │ │ │ ├── global_cast.cpp │ │ │ │ ├── gradient_accumulation_functor.cpp │ │ │ │ ├── higher_derivative_functor.cpp │ │ │ │ ├── linalg_functor.cpp │ │ │ │ ├── 
math_functor.cpp │ │ │ │ ├── nn_functor.cpp │ │ │ │ ├── nn_grad_functor.cpp │ │ │ │ ├── quantization.cpp │ │ │ │ ├── random_functor.cpp │ │ │ │ ├── rnn_functor.cpp │ │ │ │ ├── slice_boxing_functor.cpp │ │ │ │ ├── test_functor.cpp │ │ │ │ ├── unary_functor.cpp │ │ │ │ ├── unary_functor.h │ │ │ │ └── util_ops_functor.cpp │ │ │ ├── packed_functor.h │ │ │ ├── sequence_function.h │ │ │ ├── tensor_index.cpp │ │ │ ├── tensor_index.h │ │ │ ├── tensor_processor.cpp │ │ │ └── tensor_processor.h │ │ ├── graph/ │ │ │ ├── boxing/ │ │ │ │ ├── b21_sub_task_graph_builder.cpp │ │ │ │ ├── b21_sub_task_graph_builder.h │ │ │ │ ├── boxing_logger.cpp │ │ │ │ ├── boxing_logger.h │ │ │ │ ├── ccl_sub_task_graph_builder.cpp │ │ │ │ ├── ccl_sub_task_graph_builder.h │ │ │ │ ├── chain_sub_task_graph_builder.cpp │ │ │ │ ├── chain_sub_task_graph_builder.h │ │ │ │ ├── collective_boxing.proto │ │ │ │ ├── collective_boxing_sub_task_graph_builder.cpp │ │ │ │ ├── collective_boxing_sub_task_graph_builder.h │ │ │ │ ├── collective_boxing_util.cpp │ │ │ │ ├── collective_boxing_util.h │ │ │ │ ├── fallback_to_cpu_slice_boxing_sub_task_graph_builder.cpp │ │ │ │ ├── fallback_to_cpu_slice_boxing_sub_task_graph_builder.h │ │ │ │ ├── hierarchical_sub_task_graph_builder.h │ │ │ │ ├── hierarchical_sub_task_graph_builder_impl.cpp │ │ │ │ ├── hierarchical_sub_task_graph_builder_impl.h │ │ │ │ ├── hierarchical_sub_task_graph_builder_util.cpp │ │ │ │ ├── hierarchical_sub_task_graph_builder_util.h │ │ │ │ ├── naive_b2b_sub_task_graph_builder.cpp │ │ │ │ ├── naive_b2b_sub_task_graph_builder.h │ │ │ │ ├── naive_b2p_sub_task_graph_builder.cpp │ │ │ │ ├── naive_b2p_sub_task_graph_builder.h │ │ │ │ ├── one_to_one_sub_task_graph_builder.cpp │ │ │ │ ├── one_to_one_sub_task_graph_builder.h │ │ │ │ ├── slice_boxing_sub_task_graph_builder.cpp │ │ │ │ ├── slice_boxing_sub_task_graph_builder.h │ │ │ │ ├── sub_task_graph_builder.h │ │ │ │ ├── sub_task_graph_builder_context.cpp │ │ │ │ ├── sub_task_graph_builder_context.h │ │ │ │ ├── sub_task_graph_builder_status_util.cpp │ │ │ │ ├── sub_task_graph_builder_status_util.h │ │ │ │ ├── sub_task_graph_builder_util.cpp │ │ │ │ └── sub_task_graph_builder_util.h │ │ │ ├── boxing_identity_task_node.cpp │ │ │ ├── boxing_identity_task_node.h │ │ │ ├── boxing_task_graph.proto │ │ │ ├── boxing_zeros_task_node.cpp │ │ │ ├── boxing_zeros_task_node.h │ │ │ ├── collective_boxing_pack_task_node.cpp │ │ │ ├── collective_boxing_pack_task_node.h │ │ │ ├── collective_boxing_task_node.cpp │ │ │ ├── collective_boxing_task_node.h │ │ │ ├── collective_boxing_unpack_task_node.cpp │ │ │ ├── collective_boxing_unpack_task_node.h │ │ │ ├── compute_task_node.cpp │ │ │ ├── compute_task_node.h │ │ │ ├── copy_task_node.cpp │ │ │ ├── copy_task_node.h │ │ │ ├── exec_graph.cpp │ │ │ ├── exec_graph.h │ │ │ ├── exec_sequence.proto │ │ │ ├── fake_consumed_regst_provider.h │ │ │ ├── graph.h │ │ │ ├── inplace_lbi_graph.cpp │ │ │ ├── inplace_lbi_graph.h │ │ │ ├── inplace_regst_graph.cpp │ │ │ ├── inplace_regst_graph.h │ │ │ ├── nccl_send_recv_boxing_task_node.cpp │ │ │ ├── nccl_send_recv_boxing_task_node.h │ │ │ ├── node.cpp │ │ │ ├── node.h │ │ │ ├── normal_forward_compute_task_node.h │ │ │ ├── op_graph.cpp │ │ │ ├── op_graph.h │ │ │ ├── plan_task_graph.cpp │ │ │ ├── plan_task_graph.h │ │ │ ├── slice_boxing_task_node.cpp │ │ │ ├── slice_boxing_task_node.h │ │ │ ├── straighten_nodes.cpp │ │ │ ├── straighten_nodes.h │ │ │ ├── stream_id.cpp │ │ │ ├── stream_id.h │ │ │ ├── stream_index_generator.cpp │ │ │ ├── stream_index_generator.h │ │ │ ├── 
task_edge.proto │ │ │ ├── task_graph.cpp │ │ │ ├── task_graph.h │ │ │ ├── task_graph_rebuild_ctx.cpp │ │ │ ├── task_graph_rebuild_ctx.h │ │ │ ├── task_id.cpp │ │ │ ├── task_id.h │ │ │ ├── task_id_generator.cpp │ │ │ ├── task_id_generator.h │ │ │ ├── task_node.cpp │ │ │ ├── task_node.h │ │ │ ├── task_stream_id.h │ │ │ ├── task_stream_index_manager.cpp │ │ │ ├── task_stream_index_manager.h │ │ │ ├── task_type_visitor.h │ │ │ ├── transport_task_node.cpp │ │ │ └── transport_task_node.h │ │ ├── graph_impl/ │ │ │ ├── acc_compute_task_node.cpp │ │ │ ├── acc_ctrl_tick_compute_task_node.cpp │ │ │ ├── acc_tick_compute_task_node.cpp │ │ │ ├── callback_notify_compute_task_node.cpp │ │ │ ├── case_compute_task_node.cpp │ │ │ ├── critical_section_wait_compute_task_node.cpp │ │ │ ├── decode_h2d_compute_task_node.cpp │ │ │ ├── device_tick_compute_task_node.cpp │ │ │ ├── distribute_concat_compute_task_node.cpp │ │ │ ├── distribute_split_compute_task_node.cpp │ │ │ ├── dst_subset_tick_compute_task_node.cpp │ │ │ ├── esac_compute_task_node.cpp │ │ │ ├── normal_forward_compute_task_node.cpp │ │ │ ├── pack_compute_task_node.cpp │ │ │ ├── reentrant_lock_compute_task_node.cpp │ │ │ ├── repeat_compute_task_node.cpp │ │ │ ├── source_tick_compute_task_node.cpp │ │ │ ├── src_subset_tick_compute_task_node.cpp │ │ │ ├── ssp_variable_proxy_task_node.cpp │ │ │ ├── tick_compute_task_node.cpp │ │ │ ├── unpack_compute_task_node.cpp │ │ │ └── wait_and_send_ids_compute_task_node.cpp │ │ ├── hardware/ │ │ │ ├── basic_device_descriptor_list.cpp │ │ │ ├── basic_device_descriptor_list.h │ │ │ ├── cuda_device_descriptor.cpp │ │ │ ├── cuda_device_descriptor.h │ │ │ ├── cuda_device_descriptor_class.cpp │ │ │ ├── device_descriptor.h │ │ │ ├── device_descriptor_class.cpp │ │ │ ├── device_descriptor_class.h │ │ │ ├── device_descriptor_list.h │ │ │ ├── net_ib_device_descriptor.cpp │ │ │ ├── net_ib_device_descriptor.h │ │ │ ├── net_ib_device_descriptor_class.cpp │ │ │ ├── net_socket_device_descriptor.cpp │ │ │ ├── net_socket_device_descriptor.h │ │ │ ├── net_socket_device_descriptor_class.cpp │ │ │ ├── node_device_descriptor.cpp │ │ │ ├── node_device_descriptor.h │ │ │ ├── node_device_descriptor_manager.cpp │ │ │ ├── node_device_descriptor_manager.h │ │ │ ├── topology_descriptor.cpp │ │ │ └── topology_descriptor.h │ │ ├── intrusive/ │ │ │ ├── README.md │ │ │ ├── base.h │ │ │ ├── cpp_attribute.h │ │ │ ├── dss.h │ │ │ ├── dss_test.cpp │ │ │ ├── flat_msg.h │ │ │ ├── flat_msg_test.cpp │ │ │ ├── flat_msg_view.h │ │ │ ├── flat_msg_view_test.cpp │ │ │ ├── for_each.h │ │ │ ├── force_standard_layout.h │ │ │ ├── force_standard_layout_test.cpp │ │ │ ├── head_free_list.h │ │ │ ├── head_free_list_test.cpp │ │ │ ├── intrusive.h │ │ │ ├── intrusive_core_test.cpp │ │ │ ├── list.h │ │ │ ├── list_hook.h │ │ │ ├── list_hook_test.cpp │ │ │ ├── list_test.cpp │ │ │ ├── mutexed_list.h │ │ │ ├── object_pool.h │ │ │ ├── object_pool_test.cpp │ │ │ ├── ref.h │ │ │ ├── reflective.h │ │ │ ├── shared_ptr.h │ │ │ ├── skiplist.h │ │ │ ├── skiplist_hook.h │ │ │ ├── skiplist_hook_test.cpp │ │ │ ├── skiplist_test.cpp │ │ │ ├── static_counter.h │ │ │ ├── static_counter_test.cpp │ │ │ ├── struct_traits.h │ │ │ └── struct_traits_test.cpp │ │ ├── ipc/ │ │ │ ├── shared_memory.cpp │ │ │ └── shared_memory.h │ │ ├── job/ │ │ │ ├── blob_lifetime_signature.proto │ │ │ ├── checkpointing_config_def.cpp │ │ │ ├── cluster_instruction.cpp │ │ │ ├── cluster_instruction.h │ │ │ ├── cluster_instruction.proto │ │ │ ├── collective_boxing/ │ │ │ │ ├── coordinator.h │ │ │ │ ├── executor.cpp │ 
│ │ │ ├── executor.h │ │ │ │ ├── executor_backend.h │ │ │ │ ├── executor_backend_manager.cpp │ │ │ │ ├── executor_backend_manager.h │ │ │ │ ├── nccl_executor_backend.cu │ │ │ │ ├── request_store.cpp │ │ │ │ ├── request_store.h │ │ │ │ ├── runtime_request_info.h │ │ │ │ ├── scheduler.cpp │ │ │ │ ├── scheduler.h │ │ │ │ ├── static_group_coordinator.cpp │ │ │ │ └── static_group_coordinator.h │ │ │ ├── compile_mode.cpp │ │ │ ├── compile_mode.h │ │ │ ├── compiler.cpp │ │ │ ├── compiler.h │ │ │ ├── critical_section.proto │ │ │ ├── critical_section_desc.cpp │ │ │ ├── critical_section_desc.h │ │ │ ├── critical_section_instance.h │ │ │ ├── distribute_hirarchy.proto │ │ │ ├── dlnet_conf.proto │ │ │ ├── eager_ccl_comm_manager.cpp │ │ │ ├── eager_ccl_comm_manager.h │ │ │ ├── eager_nccl_comm_manager.cpp │ │ │ ├── eager_nccl_comm_manager.h │ │ │ ├── env.proto │ │ │ ├── env_desc.cpp │ │ │ ├── env_desc.h │ │ │ ├── env_global_objects_scope.cpp │ │ │ ├── env_global_objects_scope.h │ │ │ ├── function_config_def.cpp │ │ │ ├── global_for.cpp │ │ │ ├── global_for.h │ │ │ ├── global_mode.cpp │ │ │ ├── global_mode.h │ │ │ ├── graph_scope_vars.cpp │ │ │ ├── graph_scope_vars.h │ │ │ ├── id_manager.cpp │ │ │ ├── id_manager.h │ │ │ ├── id_manager_test.cpp │ │ │ ├── id_state.h │ │ │ ├── initializer_conf.proto │ │ │ ├── inter_job_mem_sharing_util.cpp │ │ │ ├── inter_job_mem_sharing_util.h │ │ │ ├── inter_user_job_info.proto │ │ │ ├── intra_job_mem_sharing_util.cpp │ │ │ ├── intra_job_mem_sharing_util.h │ │ │ ├── job.proto │ │ │ ├── job_build_and_infer_ctx.cpp │ │ │ ├── job_build_and_infer_ctx.h │ │ │ ├── job_build_and_infer_ctx_mgr.cpp │ │ │ ├── job_build_and_infer_ctx_mgr.h │ │ │ ├── job_builder.cpp │ │ │ ├── job_builder.h │ │ │ ├── job_conf.proto │ │ │ ├── job_desc.cpp │ │ │ ├── job_desc.h │ │ │ ├── job_instance.h │ │ │ ├── job_interpreter.cpp │ │ │ ├── job_interpreter.h │ │ │ ├── job_ir.cpp │ │ │ ├── job_ir.h │ │ │ ├── job_set.proto │ │ │ ├── job_set_compile_ctx.h │ │ │ ├── job_set_compile_ctx.proto │ │ │ ├── lazy_mode.cpp │ │ │ ├── lazy_mode.h │ │ │ ├── learning_rate_schedule_conf.proto │ │ │ ├── local_parallel.proto │ │ │ ├── local_sig_infer_hint.h │ │ │ ├── memory_share_strategy.cpp │ │ │ ├── memory_share_strategy.h │ │ │ ├── module_conf.proto │ │ │ ├── nd_sbp_infer_hint.h │ │ │ ├── nd_sbp_util.cpp │ │ │ ├── nd_sbp_util.h │ │ │ ├── oneflow.cpp │ │ │ ├── oneflow.h │ │ │ ├── parallel_conf_signature.proto │ │ │ ├── parallel_desc.cpp │ │ │ ├── parallel_desc.h │ │ │ ├── parallel_desc_test.cpp │ │ │ ├── parallel_signature.proto │ │ │ ├── pipeline_config_def.cpp │ │ │ ├── placement.proto │ │ │ ├── placement_scope.cpp │ │ │ ├── placement_scope.h │ │ │ ├── plan.proto │ │ │ ├── plan_util.cpp │ │ │ ├── plan_util.h │ │ │ ├── qat_config_def.cpp │ │ │ ├── rank_compiler.cpp │ │ │ ├── rank_compiler.h │ │ │ ├── rank_group.cpp │ │ │ ├── rank_group.h │ │ │ ├── rank_group_scope.cpp │ │ │ ├── rank_group_scope.h │ │ │ ├── rank_group_test.cpp │ │ │ ├── regularizer_conf.proto │ │ │ ├── resource.proto │ │ │ ├── resource_desc.cpp │ │ │ ├── resource_desc.h │ │ │ ├── runtime.cpp │ │ │ ├── runtime.h │ │ │ ├── runtime_buffer_managers_scope.cpp │ │ │ ├── runtime_buffer_managers_scope.h │ │ │ ├── runtime_buffers_scope.cpp │ │ │ ├── runtime_buffers_scope.h │ │ │ ├── runtime_context.cpp │ │ │ ├── runtime_context.h │ │ │ ├── runtime_job_descs.cpp │ │ │ ├── runtime_job_descs.h │ │ │ ├── sbp_infer_hint.h │ │ │ ├── sbp_parallel.cpp │ │ │ ├── sbp_parallel.h │ │ │ ├── sbp_parallel.proto │ │ │ ├── sbp_signature_builder.cpp │ │ │ ├── 
sbp_signature_builder.h │ │ │ ├── scope.cpp │ │ │ ├── scope.h │ │ │ ├── scope.proto │ │ │ ├── session.cpp │ │ │ ├── session.h │ │ │ ├── ssp_config_def.cpp │ │ │ ├── sub_plan.proto │ │ │ ├── task.proto │ │ │ ├── utils/ │ │ │ │ ├── progress_bar.cpp │ │ │ │ └── progress_bar.h │ │ │ ├── version.cpp │ │ │ └── version.h │ │ ├── job_rewriter/ │ │ │ ├── adadelta_optim.cpp │ │ │ ├── adagrad_optm.cpp │ │ │ ├── adam_optm.cpp │ │ │ ├── add_ssp_variable_proxy.cpp │ │ │ ├── auto_learning_rate.cpp │ │ │ ├── auto_mixed_precision.cpp │ │ │ ├── auto_mixed_precision.h │ │ │ ├── auto_mixed_precision_lists.cpp │ │ │ ├── auto_mixed_precision_lists.h │ │ │ ├── auto_parallel.cpp │ │ │ ├── auto_train_step.cpp │ │ │ ├── autograd.cpp │ │ │ ├── autograd.h │ │ │ ├── autotick.cpp │ │ │ ├── autotick.h │ │ │ ├── boxing_with_middle_nodes.cpp │ │ │ ├── boxing_with_middle_nodes.h │ │ │ ├── calculation_pass.cpp │ │ │ ├── calculation_pass.h │ │ │ ├── checkpointing_pass.cpp │ │ │ ├── clip_by_global_norm_job_pass_state.h │ │ │ ├── clone_grad.cpp │ │ │ ├── clone_grad.h │ │ │ ├── cudnn_fused_normalization_add_relu_pass.cpp │ │ │ ├── cutlass_conv_tuning_warmup_pass.cpp │ │ │ ├── delay_variable_op_execution_pass.cpp │ │ │ ├── device_tick_autotick.cpp │ │ │ ├── do_parallel_cast_before_widening_type_cast_pass.cpp │ │ │ ├── dump_blob_parallel_conf_pass.cpp │ │ │ ├── dump_variable_info_pass.cpp │ │ │ ├── dynamic_loss_scale_job_pass_state.h │ │ │ ├── dynamic_loss_scale_schedule_pass.cpp │ │ │ ├── eliminate_dead_nodes_pass.cpp │ │ │ ├── fix_pipeline_stage_id_pass.cpp │ │ │ ├── ftrl_optm.cpp │ │ │ ├── fuse_add_to_output_pass.cpp │ │ │ ├── fuse_bce_reduce_mean_fw_bw_pass.cpp │ │ │ ├── fuse_cast_scale_pass.cpp │ │ │ ├── fuse_consecutive_add_pass.cpp │ │ │ ├── fuse_embedding_interaction_pass.cpp │ │ │ ├── fuse_model_update_cast_pass.cpp │ │ │ ├── fuse_update_ops_pass.cpp │ │ │ ├── generate_optimizer_op_confs.cpp │ │ │ ├── group_boxing_by_dst_parallel.cpp │ │ │ ├── group_boxing_by_dst_parallel.h │ │ │ ├── indexed_slices_optimizer_rewrite_pass.cpp │ │ │ ├── input_autotick.cpp │ │ │ ├── insert_nccl_logical_op_pass.cpp │ │ │ ├── insert_pinned_identity_op_pass.cpp │ │ │ ├── job_completer.cpp │ │ │ ├── job_completer.h │ │ │ ├── job_pass.cpp │ │ │ ├── job_pass.h │ │ │ ├── lamb_optm.cpp │ │ │ ├── lars_optm.cpp │ │ │ ├── logical_chain_pass.cpp │ │ │ ├── momentum_optm.cpp │ │ │ ├── multi_tensor_model_update.cpp │ │ │ ├── nccl_logical_chain_strict_order_pass.cpp │ │ │ ├── nccl_logical_op_fusion_pass.cpp │ │ │ ├── normalization_exponential_average_auto_tick_rewrite_pass.cpp │ │ │ ├── optimizer.cpp │ │ │ ├── optimizer.h │ │ │ ├── optimizer_placement_optimization_pass.cpp │ │ │ ├── pass_util.cpp │ │ │ ├── pass_util.h │ │ │ ├── pipeline_buffer_pass.cpp │ │ │ ├── prune_amp_white_identity_op_pass.cpp │ │ │ ├── prune_cast_to_static_shape_op_pass.cpp │ │ │ ├── prune_depend_op_pass.cpp │ │ │ ├── prune_parallel_cast_op_pass.cpp │ │ │ ├── prune_pinned_identity_op_pass.cpp │ │ │ ├── quantization_aware_training.cpp │ │ │ ├── replace_embedding_ops_pass.cpp │ │ │ ├── rmsprop_optm.cpp │ │ │ ├── sequential_one_embedding_shuffle_ops_pass.cpp │ │ │ ├── sgd_optm.cpp │ │ │ ├── source_user_op_auto_tick.cpp │ │ │ ├── split_sparse_softmax_cross_entropy_op_pass.cpp │ │ │ ├── system_op_fill_job_name_pass.cpp │ │ │ ├── tick_autotick.cpp │ │ │ └── variable_autotick.cpp │ │ ├── kernel/ │ │ │ ├── assign_kernel.cpp │ │ │ ├── blob_access_checker_kernel_observer.cpp │ │ │ ├── blob_access_checker_kernel_observer.h │ │ │ ├── blob_tensor_view.cpp │ │ │ ├── blob_tensor_view.h │ │ │ ├── 
boxing_kernel.cpp │ │ │ ├── boxing_zeros_kernel.cpp │ │ │ ├── broadcast_to_compatible_with_kernel.cpp │ │ │ ├── callback_notify_kernel.cpp │ │ │ ├── case_kernel.cpp │ │ │ ├── case_kernel.h │ │ │ ├── chain_kernel_observer.cpp │ │ │ ├── chain_kernel_observer.h │ │ │ ├── collective_boxing_kernels.cpp │ │ │ ├── collective_boxing_pack_kernel.cpp │ │ │ ├── collective_boxing_unpack_kernel.cpp │ │ │ ├── constant_like_kernel.cpp │ │ │ ├── cpu_check_numerics_kernel_observer.h │ │ │ ├── cpu_numerics_kernel_observer.cpp │ │ │ ├── critical_section_callback_tick_kernel.cpp │ │ │ ├── critical_section_wait_tick_kernel.cpp │ │ │ ├── cuda_check_numerics_kernel_observer.cu │ │ │ ├── cuda_check_numerics_kernel_observer.h │ │ │ ├── cuda_graph_support.h │ │ │ ├── distribute_kernels.cpp │ │ │ ├── dynamic_reshape_kernel.cpp │ │ │ ├── dynamic_reshape_like_kernel.cpp │ │ │ ├── esac_kernel.cpp │ │ │ ├── esac_kernel.h │ │ │ ├── identity_kernel.cpp │ │ │ ├── image_decoder_random_crop_resize_kernel.cpp │ │ │ ├── input_kernel.cpp │ │ │ ├── kernel.cpp │ │ │ ├── kernel.h │ │ │ ├── kernel.proto │ │ │ ├── kernel_context.h │ │ │ ├── kernel_observer.h │ │ │ ├── kernel_registration.cpp │ │ │ ├── kernel_registration.h │ │ │ ├── kernel_util.cpp │ │ │ ├── kernel_util.cuh │ │ │ ├── kernel_util.h │ │ │ ├── learning_rate_schedule_kernel.cpp │ │ │ ├── nccl_send_recv_boxing_kernel.cpp │ │ │ ├── new_kernel_util.h │ │ │ ├── nop_kernel.cpp │ │ │ ├── output_kernel.cpp │ │ │ ├── profiler_kernel_observer.cpp │ │ │ ├── profiler_kernel_observer.h │ │ │ ├── random_generator.cpp │ │ │ ├── random_generator.cu │ │ │ ├── random_generator.h │ │ │ ├── reentrant_lock_kernel.cpp │ │ │ ├── reentrant_lock_kernel.h │ │ │ ├── return_kernel.cpp │ │ │ ├── runtime_blob_shape_infer_helper.cpp │ │ │ ├── runtime_blob_shape_infer_helper.h │ │ │ ├── shape_elem_cnt_kernel.cpp │ │ │ ├── slice_boxing_kernel.cpp │ │ │ ├── sync_check_kernel_observer.cpp │ │ │ ├── sync_check_kernel_observer.h │ │ │ ├── sync_dynamic_resize_kernel.cpp │ │ │ ├── total_loss_instance_num_kernel.cpp │ │ │ ├── user_kernel.cpp │ │ │ ├── user_kernel.h │ │ │ ├── util/ │ │ │ │ ├── cuda_half_util.h │ │ │ │ ├── numeric_limits.cuh │ │ │ │ └── numerics.cuh │ │ │ ├── wait_and_send_ids_kernel.cpp │ │ │ └── wait_and_send_ids_kernel.h │ │ ├── lazy/ │ │ │ ├── actor/ │ │ │ │ ├── acc_actor.cpp │ │ │ │ ├── acc_ctrl_tick_actor.cpp │ │ │ │ ├── acc_tick_actor.cpp │ │ │ │ ├── actor.cpp │ │ │ │ ├── actor.h │ │ │ │ ├── actor_base.cpp │ │ │ │ ├── actor_base.h │ │ │ │ ├── actor_context.cpp │ │ │ │ ├── actor_context.h │ │ │ │ ├── actor_message.cpp │ │ │ │ ├── actor_message.h │ │ │ │ ├── actor_message_bus.cpp │ │ │ │ ├── actor_message_bus.h │ │ │ │ ├── boxing_zeros_actor.cpp │ │ │ │ ├── callback_notify_actor.cpp │ │ │ │ ├── case_actor.cpp │ │ │ │ ├── collective_boxing_actor_context.cpp │ │ │ │ ├── collective_boxing_actor_context.h │ │ │ │ ├── copy_comm_net_actor.cpp │ │ │ │ ├── esac_actor.cpp │ │ │ │ ├── generic_actor_context.cpp │ │ │ │ ├── generic_actor_context.h │ │ │ │ ├── input_wise_actor.cpp │ │ │ │ ├── input_wise_actor.h │ │ │ │ ├── light_actor.cpp │ │ │ │ ├── light_actor.h │ │ │ │ ├── naive_actor.cpp │ │ │ │ ├── naive_actor.h │ │ │ │ ├── pack_actor.cpp │ │ │ │ ├── reentrant_lock_actor.cpp │ │ │ │ ├── register_slot.cpp │ │ │ │ ├── register_slot.h │ │ │ │ ├── repeat_actor.cpp │ │ │ │ ├── sink_actor.cpp │ │ │ │ ├── sink_actor.h │ │ │ │ ├── source_tick_actor.cpp │ │ │ │ ├── ssp_variable_proxy_actor.cpp │ │ │ │ ├── tick_actor.cpp │ │ │ │ ├── unpack_actor.cpp │ │ │ │ └── wait_and_send_ids_actor.cpp │ │ │ └── 
stream_context/ │ │ │ ├── common/ │ │ │ │ └── generic_stream_context.cpp │ │ │ ├── cpu/ │ │ │ │ └── cpu_stream_context.cpp │ │ │ ├── cuda/ │ │ │ │ └── cuda_stream_context.cpp │ │ │ └── include/ │ │ │ ├── generic_stream_context.h │ │ │ └── stream_context.h │ │ ├── memory/ │ │ │ ├── chunk_manager.cpp │ │ │ ├── chunk_manager.h │ │ │ ├── memory_allocator.cpp │ │ │ ├── memory_allocator.h │ │ │ ├── memory_block.proto │ │ │ ├── memory_case.proto │ │ │ ├── memory_case_util.cpp │ │ │ ├── memory_case_util.h │ │ │ ├── memory_zone.cpp │ │ │ └── memory_zone.h │ │ ├── ndarray/ │ │ │ ├── binary_func.h │ │ │ ├── cpu_concat_var_ndarray.h │ │ │ ├── cpu_concat_var_ndarray_test.cpp │ │ │ ├── cpu_ndarray.h │ │ │ ├── cpu_ndarray_builder.h │ │ │ ├── cpu_ndarray_copy.h │ │ │ ├── cpu_slice_var_ndarray.h │ │ │ ├── cpu_slice_var_ndarray_test.cpp │ │ │ ├── cpu_var_ndarray.h │ │ │ ├── cpu_var_ndarray_test.cpp │ │ │ ├── ndarray_apply_binary.h │ │ │ ├── ndarray_apply_binary_core.cpp │ │ │ ├── ndarray_apply_binary_core.cu │ │ │ ├── ndarray_apply_binary_core.h │ │ │ ├── ndarray_apply_broadcast_binary.h │ │ │ ├── ndarray_apply_broadcast_binary_core.cpp │ │ │ ├── ndarray_apply_broadcast_binary_core.cu │ │ │ ├── ndarray_apply_broadcast_binary_core.h │ │ │ ├── ndarray_apply_broadcast_unary.h │ │ │ ├── ndarray_apply_broadcast_unary_core.cpp │ │ │ ├── ndarray_apply_broadcast_unary_core.cu │ │ │ ├── ndarray_apply_broadcast_unary_core.h │ │ │ ├── ndarray_apply_unary.h │ │ │ ├── ndarray_apply_unary_core.cpp │ │ │ ├── ndarray_apply_unary_core.cu │ │ │ ├── ndarray_apply_unary_core.h │ │ │ ├── ndarray_assign_core.cpp │ │ │ ├── ndarray_assign_core.cu │ │ │ ├── ndarray_assign_core.h │ │ │ ├── ndarray_reduce.h │ │ │ ├── ndarray_reduce_impl.cpp │ │ │ ├── ndarray_reduce_impl.cu │ │ │ ├── ndarray_reduce_impl.h │ │ │ ├── ndarray_util.h │ │ │ ├── slice.cpp │ │ │ ├── slice.h │ │ │ ├── slice_test.cpp │ │ │ ├── unary_func.h │ │ │ ├── xpu_binary_func_ndarray.h │ │ │ ├── xpu_broadcast_ndarray.h │ │ │ ├── xpu_ndarray_assign.cu │ │ │ ├── xpu_ndarray_assign.h │ │ │ ├── xpu_ndarray_base.h │ │ │ ├── xpu_reduced_ndarray.h │ │ │ ├── xpu_reshape_ndarray.h │ │ │ ├── xpu_shape.cpp │ │ │ ├── xpu_shape.h │ │ │ ├── xpu_transpose_ndarray.h │ │ │ ├── xpu_unary_func_ndarray.h │ │ │ ├── xpu_util.h │ │ │ ├── xpu_var_ndarray.h │ │ │ └── xpu_var_ndarray_builder.h │ │ ├── operator/ │ │ │ ├── acc_tick_op.cpp │ │ │ ├── acc_tick_op.h │ │ │ ├── arg_modifier_signature.proto │ │ │ ├── assign_op.cpp │ │ │ ├── boxing_identity_op.cpp │ │ │ ├── boxing_op.cpp │ │ │ ├── boxing_op.h │ │ │ ├── boxing_zeros_op.cpp │ │ │ ├── broadcast_to_compatible_with_op.cpp │ │ │ ├── callback_notify_op.cpp │ │ │ ├── callback_notify_op.h │ │ │ ├── case_op.cpp │ │ │ ├── case_op.h │ │ │ ├── collective_boxing_ops.cpp │ │ │ ├── collective_boxing_pack_op.cpp │ │ │ ├── collective_boxing_unpack_op.cpp │ │ │ ├── constant_like_op.cpp │ │ │ ├── copy_comm_net_op.cpp │ │ │ ├── copy_comm_net_op.h │ │ │ ├── critical_section_callback_tick_op.cpp │ │ │ ├── critical_section_wait_tick_op.cpp │ │ │ ├── cwise_op.cpp │ │ │ ├── cwise_op.h │ │ │ ├── decode_random_op.h │ │ │ ├── device_tick_op.cpp │ │ │ ├── device_tick_op.h │ │ │ ├── distribute_add_op.cpp │ │ │ ├── distribute_clone_op.cpp │ │ │ ├── distribute_concat_op.cpp │ │ │ ├── distribute_split_op.cpp │ │ │ ├── dst_subset_tick_op.cpp │ │ │ ├── dynamic_reshape_op.cpp │ │ │ ├── esac_op.cpp │ │ │ ├── esac_op.h │ │ │ ├── identity_op.cpp │ │ │ ├── image_decoder_random_crop_resize_op.cpp │ │ │ ├── input_op.cpp │ │ │ ├── input_op.h │ │ │ ├── interface_blob_conf.proto │ │ │ 
├── interface_op_util.cpp │ │ │ ├── interface_op_util.h │ │ │ ├── learning_rate_schedule_op.cpp │ │ │ ├── nccl_send_recv_boxing_op.cpp │ │ │ ├── nccl_send_recv_boxing_op_util.cpp │ │ │ ├── nccl_send_recv_boxing_op_util.h │ │ │ ├── op_attribute.proto │ │ │ ├── op_conf.proto │ │ │ ├── op_conf_symbol.cpp │ │ │ ├── op_conf_symbol.h │ │ │ ├── op_conf_util.h │ │ │ ├── op_infer_cache.h │ │ │ ├── op_node_signature.proto │ │ │ ├── operator.cpp │ │ │ ├── operator.h │ │ │ ├── operator_util.cpp │ │ │ ├── operator_util.h │ │ │ ├── output_op.cpp │ │ │ ├── output_op.h │ │ │ ├── reduce_sbp_util.cpp │ │ │ ├── reduce_sbp_util.h │ │ │ ├── reentrant_lock_op.cpp │ │ │ ├── reentrant_lock_op.h │ │ │ ├── return_op.cpp │ │ │ ├── return_op.h │ │ │ ├── scalar_op_base.cpp │ │ │ ├── scalar_op_base.h │ │ │ ├── shape_elem_cnt_op.cpp │ │ │ ├── shape_elem_cnt_op.h │ │ │ ├── sink_tick_op.cpp │ │ │ ├── sink_tick_op.h │ │ │ ├── slice_boxing_op.cpp │ │ │ ├── source_tick_op.cpp │ │ │ ├── source_tick_op.h │ │ │ ├── src_subset_tick_op.cpp │ │ │ ├── sync_dynamic_resize_op.cpp │ │ │ ├── tick_op.cpp │ │ │ ├── tick_op.h │ │ │ ├── total_loss_instance_num_op.cpp │ │ │ ├── total_loss_instance_num_op.h │ │ │ ├── user_op.cpp │ │ │ ├── user_op.h │ │ │ ├── variable_op.cpp │ │ │ ├── variable_op.h │ │ │ ├── wait_and_send_ids_op.cpp │ │ │ └── wait_and_send_ids_op.h │ │ ├── persistence/ │ │ │ ├── binary_in_stream.h │ │ │ ├── binary_in_stream_with_local_copy.cpp │ │ │ ├── binary_in_stream_with_local_copy.h │ │ │ ├── binary_in_stream_without_local_copy.cpp │ │ │ ├── binary_in_stream_without_local_copy.h │ │ │ ├── file_system.cpp │ │ │ ├── file_system.h │ │ │ ├── file_system_test.cpp │ │ │ ├── hadoop/ │ │ │ │ ├── hadoop_file_system.cpp │ │ │ │ ├── hadoop_file_system.h │ │ │ │ └── hdfs.h │ │ │ ├── persistent_in_stream.cpp │ │ │ ├── persistent_in_stream.h │ │ │ ├── persistent_out_stream.cpp │ │ │ ├── persistent_out_stream.h │ │ │ ├── posix/ │ │ │ │ ├── posix_file_system.cpp │ │ │ │ └── posix_file_system.h │ │ │ ├── stream_scanner.cpp │ │ │ ├── stream_scanner.h │ │ │ ├── tee_persistent_log_stream.cpp │ │ │ └── tee_persistent_log_stream.h │ │ ├── platform/ │ │ │ ├── include/ │ │ │ │ ├── ibv.h │ │ │ │ ├── pthread_fork.h │ │ │ │ └── wrapper.h │ │ │ └── lib/ │ │ │ ├── ibv_wrapper.cpp │ │ │ ├── pthread_fork.cpp │ │ │ └── wrapper.cpp │ │ ├── profiler/ │ │ │ ├── event.cpp │ │ │ ├── event.h │ │ │ ├── event_recorder.cpp │ │ │ ├── event_recorder.h │ │ │ ├── kernel.cpp │ │ │ ├── kernel.h │ │ │ ├── kineto_shim.cpp │ │ │ ├── kineto_shim.h │ │ │ ├── profile_manager.cpp │ │ │ ├── profile_manager.h │ │ │ ├── profiler.cpp │ │ │ ├── profiler.h │ │ │ └── util.h │ │ ├── record/ │ │ │ ├── coco.proto │ │ │ └── record.proto │ │ ├── register/ │ │ │ ├── blob.cpp │ │ │ ├── blob.h │ │ │ ├── blob_desc.cpp │ │ │ ├── blob_desc.h │ │ │ ├── blob_desc.proto │ │ │ ├── logical_blob_id.proto │ │ │ ├── op_blob_arg.proto │ │ │ ├── op_blob_arg_info.h │ │ │ ├── register.cpp │ │ │ ├── register.h │ │ │ ├── register_desc.cpp │ │ │ ├── register_desc.h │ │ │ ├── register_desc.proto │ │ │ ├── register_manager.cpp │ │ │ ├── register_manager.h │ │ │ ├── runtime_register_desc.cpp │ │ │ ├── runtime_register_desc.h │ │ │ ├── tensor_slice_copier.cpp │ │ │ ├── tensor_slice_copier.h │ │ │ ├── tensor_slice_view.cpp │ │ │ ├── tensor_slice_view.h │ │ │ └── tensor_slice_view.proto │ │ ├── rpc/ │ │ │ ├── include/ │ │ │ │ ├── base.h │ │ │ │ ├── ctrl.h │ │ │ │ ├── global_process_ctx.h │ │ │ │ ├── grpc.h │ │ │ │ ├── local.h │ │ │ │ └── manager.h │ │ │ └── lib/ │ │ │ ├── global_process_ctx.cpp │ │ │ ├── 
grpc.cpp │ │ │ └── local.cpp │ │ ├── summary/ │ │ │ ├── event.proto │ │ │ ├── graph.proto │ │ │ ├── plugin_data.proto │ │ │ ├── projector.proto │ │ │ ├── summary.proto │ │ │ └── tensor.proto │ │ ├── thread/ │ │ │ ├── is_main_thread_test.cpp │ │ │ ├── thread.cpp │ │ │ ├── thread.h │ │ │ ├── thread_global_id.cpp │ │ │ ├── thread_global_id.h │ │ │ ├── thread_manager.cpp │ │ │ ├── thread_manager.h │ │ │ ├── thread_pool.cpp │ │ │ ├── thread_pool.h │ │ │ ├── thread_runtime.h │ │ │ ├── thread_runtime_factory.cpp │ │ │ └── thread_runtime_factory.h │ │ ├── transport/ │ │ │ ├── transport.cpp │ │ │ ├── transport.h │ │ │ └── transport_message.h │ │ └── vm/ │ │ ├── access_blob_arg_cb_instruction_policy.h │ │ ├── allocate_tensor_instruction_policy.cpp │ │ ├── allocate_tensor_instruction_policy.h │ │ ├── allocator.h │ │ ├── barrier_instruction_policy.h │ │ ├── bin_allocator.h │ │ ├── bin_allocator_test.cpp │ │ ├── caching_allocator.h │ │ ├── control_stream_policy.h │ │ ├── critical_section_instruction_policy.cpp │ │ ├── critical_section_instruction_policy.h │ │ ├── critical_section_status_querier.h │ │ ├── critical_section_stream_policy.cpp │ │ ├── critical_section_stream_policy.h │ │ ├── ep_backend_allocator.cpp │ │ ├── ep_backend_allocator.h │ │ ├── ep_backend_host_allocator.cpp │ │ ├── ep_backend_host_allocator.h │ │ ├── ep_d2h_stream_policy.cpp │ │ ├── ep_d2h_stream_policy.h │ │ ├── ep_event.cpp │ │ ├── ep_event.h │ │ ├── ep_optional_event_record_status_querier.cpp │ │ ├── ep_optional_event_record_status_querier.h │ │ ├── ep_record_event_instruction_policy.h │ │ ├── ep_stream_policy.cpp │ │ ├── ep_stream_policy.h │ │ ├── ep_stream_policy_base.cpp │ │ ├── ep_stream_policy_base.h │ │ ├── event_recorded_ep_stream_policy.cpp │ │ ├── event_recorded_ep_stream_policy.h │ │ ├── fuse_instruction_policy.h │ │ ├── global_sync_instruction_policy.h │ │ ├── instruction.cpp │ │ ├── instruction.h │ │ ├── instruction_fuse_type.h │ │ ├── instruction_policy.cpp │ │ ├── instruction_policy.h │ │ ├── instruction_policy_util.h │ │ ├── lazy_job_instruction_policy.h │ │ ├── lazy_job_stream_policy.cpp │ │ ├── lazy_job_stream_policy.h │ │ ├── naive_instruction_status_querier.h │ │ ├── op_call_instruction_policy.cpp │ │ ├── op_call_instruction_policy.h │ │ ├── pinned_ep_stream_policy.cpp │ │ ├── pinned_ep_stream_policy.h │ │ ├── probe.h │ │ ├── ref_cnt_instruction_status_querier.h │ │ ├── release_tensor_instruction_policy.h │ │ ├── remat/ │ │ │ ├── allocator.cpp │ │ │ ├── allocator.h │ │ │ ├── disjoint_set.cpp │ │ │ ├── disjoint_set.h │ │ │ ├── env.cpp │ │ │ ├── env.h │ │ │ ├── util.cpp │ │ │ └── util.h │ │ ├── stream.cpp │ │ ├── stream.h │ │ ├── stream_create_stream_policy.h │ │ ├── stream_get_allocator_stream_type.h │ │ ├── stream_policy.cpp │ │ ├── stream_policy.h │ │ ├── stream_record_event_instruction_policy.cpp │ │ ├── stream_record_event_instruction_policy.h │ │ ├── stream_wait_event_instruction_policy.cpp │ │ ├── stream_wait_event_instruction_policy.h │ │ ├── stream_wait_instruction_policy.cpp │ │ ├── stream_wait_instruction_policy.h │ │ ├── symbol_storage.cpp │ │ ├── symbol_storage.h │ │ ├── sync_access_instruction_policy.cpp │ │ ├── sync_access_instruction_policy.h │ │ ├── sync_vm_mode_guard.h │ │ ├── thread_ctx.cpp │ │ ├── thread_ctx.h │ │ ├── thread_safe_guard.h │ │ ├── touch_tensors_instruction_policy.h │ │ ├── virtual_machine.cpp │ │ ├── virtual_machine.h │ │ ├── virtual_machine_engine.cpp │ │ ├── virtual_machine_engine.h │ │ ├── virtual_machine_scope.cpp │ │ ├── virtual_machine_scope.h │ │ ├── vm_object.cpp │ │ 
├── vm_object.h │ │ ├── vm_sync.h │ │ ├── vm_util.cpp │ │ └── vm_util.h │ ├── extension/ │ │ ├── python/ │ │ │ ├── numpy.cpp │ │ │ ├── numpy.h │ │ │ ├── numpy_internal.h │ │ │ ├── py_compute.cpp │ │ │ ├── py_compute.h │ │ │ ├── py_kernel_caller.cpp │ │ │ ├── py_kernel_caller.h │ │ │ ├── py_kernel_registry.cpp │ │ │ └── py_kernel_registry.h │ │ └── stack/ │ │ ├── foreign_stack_getter.h │ │ ├── python/ │ │ │ ├── custom_eval_frame.c │ │ │ ├── custom_eval_frame.h │ │ │ ├── stack_getter.cpp │ │ │ └── stack_getter.h │ │ └── stacktrace.h │ ├── ir/ │ │ ├── .gitignore │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── include/ │ │ │ ├── CMakeLists.txt │ │ │ ├── OneFlow/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── Conversion/ │ │ │ │ │ ├── NVVMToCubin.h │ │ │ │ │ └── OneFlowToTosa.h │ │ │ │ ├── Extension.h │ │ │ │ ├── OKL/ │ │ │ │ │ ├── Conversion/ │ │ │ │ │ │ ├── Conversion.h │ │ │ │ │ │ └── OKLToLLVM.h │ │ │ │ │ ├── Kernel/ │ │ │ │ │ │ ├── ComputeContext.h │ │ │ │ │ │ ├── InferContext.h │ │ │ │ │ │ ├── InitContext.h │ │ │ │ │ │ ├── JITEngine.h │ │ │ │ │ │ ├── JITOpInfer.h │ │ │ │ │ │ ├── LauncherContext.h │ │ │ │ │ │ ├── LauncherState.h │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ ├── RegContext.h │ │ │ │ │ │ ├── TmpBufferManager.h │ │ │ │ │ │ └── WrapperContext.h │ │ │ │ │ ├── OKLAttributes.h │ │ │ │ │ ├── OKLAttributes.td │ │ │ │ │ ├── OKLBase.td │ │ │ │ │ ├── OKLDialect.h │ │ │ │ │ ├── OKLDialect.td │ │ │ │ │ ├── OKLOps.h │ │ │ │ │ ├── OKLOps.td │ │ │ │ │ ├── OKLTypes.h │ │ │ │ │ ├── OKLTypes.td │ │ │ │ │ └── passes.h │ │ │ │ ├── OKM/ │ │ │ │ │ ├── Conversion/ │ │ │ │ │ │ └── Conversion.h │ │ │ │ │ ├── OKMAttributes.h │ │ │ │ │ ├── OKMAttributes.td │ │ │ │ │ ├── OKMBase.td │ │ │ │ │ ├── OKMDialect.h │ │ │ │ │ ├── OKMDialect.td │ │ │ │ │ ├── OKMOps.h │ │ │ │ │ ├── OKMOps.td │ │ │ │ │ ├── OKMPasses.td │ │ │ │ │ └── passes.h │ │ │ │ ├── OneFlowBase.td │ │ │ │ ├── OneFlowDataTypeConversion.h │ │ │ │ ├── OneFlowDialect.h │ │ │ │ ├── OneFlowDialect.td │ │ │ │ ├── OneFlowEnums.td │ │ │ │ ├── OneFlowInterfaces.td │ │ │ │ ├── OneFlowOpGetGen.td │ │ │ │ ├── OneFlowOpTraits.h │ │ │ │ ├── OneFlowOps.h │ │ │ │ ├── OneFlowOps.td │ │ │ │ ├── OneFlowPDLLPatterns.h │ │ │ │ ├── OneFlowPasses.td │ │ │ │ ├── OneFlowPatternUtils.h │ │ │ │ ├── OneFlowPatterns.td │ │ │ │ ├── OneFlowSupport.h │ │ │ │ ├── OneFlowTypes.h │ │ │ │ ├── OneFlowUserOps.td │ │ │ │ ├── OneFlowUtils.h │ │ │ │ ├── Passes.h │ │ │ │ ├── SBP/ │ │ │ │ │ ├── SBPAttributes.h │ │ │ │ │ ├── SBPBase.td │ │ │ │ │ ├── SBPDialect.h │ │ │ │ │ ├── SBPDialect.td │ │ │ │ │ ├── SBPImporter.h │ │ │ │ │ └── SBPOps.td │ │ │ │ ├── Transform/ │ │ │ │ │ ├── AggregateOps.h │ │ │ │ │ ├── AutoNhwc.h │ │ │ │ │ ├── BufferHostRegister.h │ │ │ │ │ ├── CSEWithAttributesIgnored.h │ │ │ │ │ ├── ConvertInferenceOp.h │ │ │ │ │ ├── EliminateAllocOps.h │ │ │ │ │ ├── FuncOps.h │ │ │ │ │ ├── OneFlow MLIR CodeGen ABI.md │ │ │ │ │ ├── OneFlowMemPool.h │ │ │ │ │ ├── OneFlowStream.h │ │ │ │ │ ├── OutlineAndFuse.h │ │ │ │ │ ├── TraitFolder.h │ │ │ │ │ └── TransposeHelpers.h │ │ │ │ ├── UserOpConversion.h │ │ │ │ └── UserOpReflection.h │ │ │ └── Transform/ │ │ │ ├── CMakeLists.txt │ │ │ ├── TransformDialectExtension.h │ │ │ ├── TransformDialectExtension.td │ │ │ └── TransformStateExtension.h │ │ ├── install-llvm.cmake │ │ ├── lib/ │ │ │ ├── CMakeLists.txt │ │ │ ├── OneFlow/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── Conversion/ │ │ │ │ │ ├── NVVMToCubin.cpp │ │ │ │ │ ├── OneFlowToLinalg.cpp │ │ │ │ │ └── OneFlowToTosa.cpp │ │ │ │ ├── OKL/ │ │ │ │ │ ├── Conversion/ │ │ │ │ │ │ ├── Conversion.cpp │ │ │ │ │ │ 
├── CudaGraphSupport.cpp │ │ │ │ │ │ └── OKLToLLVM.cpp │ │ │ │ │ ├── Kernel/ │ │ │ │ │ │ ├── ComputeContext.cpp │ │ │ │ │ │ ├── InferContext.cpp │ │ │ │ │ │ ├── JITEngine.cpp │ │ │ │ │ │ ├── JITOpInfer.cpp │ │ │ │ │ │ ├── KernelLaunchOp.cpp │ │ │ │ │ │ ├── LauncherContext.cpp │ │ │ │ │ │ ├── LauncherState.cpp │ │ │ │ │ │ ├── RegContext.cpp │ │ │ │ │ │ └── TmpBufferManager.cpp │ │ │ │ │ ├── OKLDialect.cpp │ │ │ │ │ ├── OKLOps.cpp │ │ │ │ │ ├── OKLTypes.cpp │ │ │ │ │ └── README-OriginVersion.md │ │ │ │ ├── OKM/ │ │ │ │ │ ├── Conversion/ │ │ │ │ │ │ └── Conversion.cpp │ │ │ │ │ ├── OKMDialect.cpp │ │ │ │ │ └── passes.cpp │ │ │ │ ├── OneFlowCanonicalizers.cpp │ │ │ │ ├── OneFlowDataTypeConversion.cpp │ │ │ │ ├── OneFlowDialect.cpp │ │ │ │ ├── OneFlowInferReturnTypes.cpp │ │ │ │ ├── OneFlowOpFolders.cpp │ │ │ │ ├── OneFlowOpGetGen.cpp.in │ │ │ │ ├── OneFlowOpTraits.cpp │ │ │ │ ├── OneFlowOps.cpp │ │ │ │ ├── OneFlowRewrites.cpp │ │ │ │ ├── OneFlowSupport.cpp │ │ │ │ ├── OneFlowTypes.cpp │ │ │ │ ├── OneFlowUtils.cpp │ │ │ │ ├── PDLL/ │ │ │ │ │ ├── AllocEliminationPatterns.cpp │ │ │ │ │ ├── AllocEliminationPatterns.pdll │ │ │ │ │ ├── CMakeLists.txt │ │ │ │ │ ├── ForwardOpPatterns.cpp │ │ │ │ │ ├── ForwardOpPatterns.pdll │ │ │ │ │ ├── FuseConv2DBatchNormPattern.cpp │ │ │ │ │ ├── FuseConv2DBatchNormPattern.pdll │ │ │ │ │ ├── FuseOpsWithBackwardImplPattern.cpp │ │ │ │ │ ├── FuseOpsWithBackwardImplPattern.pdll │ │ │ │ │ ├── NormalizationPatterns.cpp │ │ │ │ │ ├── NormalizationPatterns.pdll │ │ │ │ │ └── OneFlowPDLLUtils.pdll │ │ │ │ ├── Passes.cpp │ │ │ │ ├── SBP/ │ │ │ │ │ ├── SBPAttributes.cpp │ │ │ │ │ ├── SBPDialect.cpp │ │ │ │ │ └── SBPImporter.cpp │ │ │ │ ├── Transform/ │ │ │ │ │ ├── AggregateOps.cpp │ │ │ │ │ ├── AutoNHWCOps.cpp │ │ │ │ │ ├── AutoNhwc.cpp │ │ │ │ │ ├── BufferHostRegister.cpp │ │ │ │ │ ├── CSEWithAttributesIgnored.cpp │ │ │ │ │ ├── ConvertInferenceOp.cpp │ │ │ │ │ ├── EliminateAllocOps.cpp │ │ │ │ │ ├── FuncOps.cpp │ │ │ │ │ ├── GroupMatMulOps.cpp │ │ │ │ │ ├── JITPasses.cpp │ │ │ │ │ ├── OneFlowMemPool.cpp │ │ │ │ │ ├── OneFlowStream.cpp │ │ │ │ │ ├── OutlineAndFuse.cpp │ │ │ │ │ └── TraitFolder.cpp │ │ │ │ ├── TransposeHelpers.cpp │ │ │ │ ├── UserOpConversion.cpp │ │ │ │ └── UserOpReflection.cpp │ │ │ └── Transform/ │ │ │ ├── CMakeLists.txt │ │ │ ├── TransformDialectExtension.cpp │ │ │ ├── TransformDialectInterpreter.cpp │ │ │ └── TransformStateExtension.cpp │ │ ├── llvm-in-tree.cmake │ │ ├── oneflow-extension/ │ │ │ ├── CMakeLists.txt │ │ │ ├── README.md │ │ │ ├── include/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── OneFlow/ │ │ │ │ │ ├── CMakeLists.txt │ │ │ │ │ ├── JITOpInfer.h │ │ │ │ │ ├── OneFlowLRJITRegistry.h │ │ │ │ │ └── OneFlowRoundTrip.h │ │ │ │ └── PyAst/ │ │ │ │ ├── Ast.h │ │ │ │ └── AstMlirGen.h │ │ │ ├── ir_pass.cpp │ │ │ ├── lr_jit.cpp │ │ │ ├── mlir_gen.cpp │ │ │ ├── mlir_jit_op.cpp │ │ │ └── mlir_jit_op_kernel.cpp │ │ ├── oneflow-lite/ │ │ │ ├── CMakeLists.txt │ │ │ ├── OneFlowLiteCompileMain.cpp │ │ │ ├── include/ │ │ │ │ └── OneFlow/ │ │ │ │ ├── ConvertToLiteExecutable.h │ │ │ │ ├── FlatbufferUtils.h │ │ │ │ ├── OneFlowLiteUtils.h │ │ │ │ └── Transform/ │ │ │ │ ├── FoldVariable.h │ │ │ │ ├── InferPlacement.h │ │ │ │ ├── InsertTransferOp.h │ │ │ │ ├── Lowering/ │ │ │ │ │ ├── LoweringAscend.h │ │ │ │ │ └── LoweringAscendUtils.h │ │ │ │ ├── LoweringLaunchJob.h │ │ │ │ ├── MemoryPlanning.h │ │ │ │ └── PartitionLaunchJob.h │ │ │ ├── lib/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ └── OneFlow/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── ConvertToLiteExecutable.cpp │ │ │ │ ├── 
FlatbufferUtils.cpp │ │ │ │ ├── OneFlowLiteUtils.cpp │ │ │ │ ├── Transform/ │ │ │ │ │ ├── FoldVariable.cpp │ │ │ │ │ ├── InferPlacement.cpp │ │ │ │ │ ├── InsertTransferOp.cpp │ │ │ │ │ ├── Lowering/ │ │ │ │ │ │ └── LoweringAscend.cpp │ │ │ │ │ ├── LoweringLaunchJob.cpp │ │ │ │ │ ├── MemoryPlanning.cpp │ │ │ │ │ └── PartitionLaunchJob.cpp │ │ │ │ └── cmake/ │ │ │ │ └── FindAscendSdk.cmake │ │ │ └── schemas/ │ │ │ ├── CMakeLists.txt │ │ │ ├── attributes/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── bool.fbs │ │ │ │ ├── f32.fbs │ │ │ │ ├── f32s.fbs │ │ │ │ ├── f64.fbs │ │ │ │ ├── i32.fbs │ │ │ │ ├── i32s.fbs │ │ │ │ ├── i64.fbs │ │ │ │ ├── i64s.fbs │ │ │ │ ├── shape.fbs │ │ │ │ ├── shapes.fbs │ │ │ │ ├── str.fbs │ │ │ │ └── strs.fbs │ │ │ ├── executable.fbs │ │ │ └── install_flatcc.cmake │ │ ├── oneflow-opt/ │ │ │ ├── CMakeLists.txt │ │ │ ├── README.md │ │ │ └── oneflow-opt.cpp │ │ ├── oneflow-runner/ │ │ │ ├── CMakeLists.txt │ │ │ └── oneflow-runner.cpp │ │ ├── oneflow-runtime/ │ │ │ ├── CMakeLists.txt │ │ │ └── lib/ │ │ │ ├── CMakeLists.txt │ │ │ └── Runtime.cpp │ │ ├── oneflow-translate/ │ │ │ ├── CMakeLists.txt │ │ │ ├── README.md │ │ │ ├── include/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ └── OneFlow/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ └── MLIROneFlowTranslation.h │ │ │ ├── lib/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ └── OneFlow/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── Importer.cpp │ │ │ │ └── MLIROneFlowTranslation.cpp │ │ │ └── oneflow-translate.cpp │ │ └── test/ │ │ ├── CMakeLists.txt │ │ ├── Frontend/ │ │ │ ├── lit.local.cfg │ │ │ ├── oneflow_to_iree.mlir │ │ │ └── tosa_to_elf.mlir │ │ ├── GPU/ │ │ │ ├── lit.local.cfg │ │ │ └── nvvm_to_cubin.mlir │ │ ├── OneFlow/ │ │ │ ├── auto_nhwc/ │ │ │ │ ├── lit.local.cfg │ │ │ │ ├── test_nhwc_batchnorm_relu.py │ │ │ │ ├── test_nhwc_bias_add.py │ │ │ │ ├── test_nhwc_conv.py │ │ │ │ ├── test_nhwc_conv2d_maxpool2d.py │ │ │ │ ├── test_nhwc_conv_relu_add.py │ │ │ │ ├── test_nhwc_lenet.py │ │ │ │ ├── test_nhwc_maxpool_2d.py │ │ │ │ ├── test_nhwc_resnet.py │ │ │ │ ├── test_nhwc_transpose_eliminate.py │ │ │ │ └── test_resnet101_benchmark.py │ │ │ ├── conversion/ │ │ │ │ ├── lower_to_tosa.mlir │ │ │ │ ├── lower_to_tosa_signed.mlir │ │ │ │ └── oneflow_to_tosa.mlir │ │ │ ├── cse.mlir │ │ │ ├── cuda_code_gen/ │ │ │ │ ├── gpu_copy_arg.mlir │ │ │ │ ├── lit.local.cfg │ │ │ │ ├── test_append_oneflow_stream.mlir │ │ │ │ ├── test_cast_ops_to_signless.mlir │ │ │ │ ├── test_fold_alloc_to_subview.mlir │ │ │ │ ├── test_fuser_cast_scale.py │ │ │ │ ├── test_gpu_all_reduce.mlir │ │ │ │ ├── test_insert_ofmempool.mlir │ │ │ │ ├── test_matmul.py │ │ │ │ ├── test_mgpu_to_oneflow_stream.mlir │ │ │ │ └── tosa_to_linalg.mlir │ │ │ ├── folding/ │ │ │ │ ├── test_conv_bn.py │ │ │ │ └── test_simple_multiply.py │ │ │ ├── fuse/ │ │ │ │ ├── fuse_forward_ops.mlir │ │ │ │ ├── test_cast_optimal_pass.py │ │ │ │ └── test_fuse_pad_conv.py │ │ │ ├── group_matmul.mlir │ │ │ ├── jit_outline_func.mlir │ │ │ ├── kernel_launch/ │ │ │ │ ├── OKLPass/ │ │ │ │ │ ├── lower_launcher_to_llvm_ptr.mlir │ │ │ │ │ ├── lower_okl_to_llvm_call.mlir │ │ │ │ │ └── tag_cuda_graph_support.mlir │ │ │ │ ├── OKMPass/ │ │ │ │ │ ├── extract_okm_tensor.mlir │ │ │ │ │ ├── okm_to_okl.mlir │ │ │ │ │ ├── opt_okm_memref.mlir │ │ │ │ │ └── wrap_okm_kernel.mlir │ │ │ │ ├── OneFlowPass/ │ │ │ │ │ ├── aggregate_compute_ops.mlir │ │ │ │ │ └── wrap_ops_to_kernel_launch/ │ │ │ │ │ ├── cuda_graph.mlir │ │ │ │ │ ├── lit.local.cfg │ │ │ │ │ └── simple.mlir │ │ │ │ └── test_resnet.py │ │ │ ├── networks/ │ │ │ │ ├── __init__.py │ │ │ │ └── resnet50.py 
│ │ │ ├── oneflow-opt.mlir │ │ │ ├── oneflow-translate.mlir │ │ │ ├── psig/ │ │ │ │ ├── error_parse.mlir │ │ │ │ ├── sbp_parse.mlir │ │ │ │ ├── test_2nd_basic_parse.py │ │ │ │ └── test_basic_parse.py │ │ │ ├── traits.mlir │ │ │ └── with_cuda/ │ │ │ ├── lit.local.cfg │ │ │ ├── test_conv_bn_auto_nhwc.py │ │ │ ├── test_fuse_bias_add_dropout.py │ │ │ ├── test_fuse_bias_add_gelu.py │ │ │ ├── test_fuse_bn_add_relu.py │ │ │ ├── test_fuse_gelu.py │ │ │ ├── test_fuse_scale_tril.py │ │ │ ├── test_fused_matmul_bias.py │ │ │ ├── test_fused_multi_head_attention_inference.py │ │ │ └── test_graph_save_and_load.py │ │ ├── Transform/ │ │ │ ├── lit.local.cfg │ │ │ ├── matmul.mlir │ │ │ ├── softmax.mlir │ │ │ ├── softmax_codegen_spec.mlir │ │ │ ├── softmax_codegen_spec_no_vectorize.mlir │ │ │ └── test_dialect.mlir │ │ ├── lit.cfg.py │ │ └── lit.site.cfg.py.in │ ├── maybe/ │ │ ├── config.h │ │ ├── error.h │ │ ├── error_test.cpp │ │ ├── just.h │ │ ├── just_test.cpp │ │ ├── maybe.h │ │ ├── maybe_test.cpp │ │ ├── optional.h │ │ ├── optional_test.cpp │ │ ├── type_traits.h │ │ ├── type_traits_test.cpp │ │ ├── utility.h │ │ ├── utility_test.cpp │ │ ├── variant.h │ │ └── variant_test.cpp │ └── user/ │ ├── data/ │ │ ├── batch_dataset.h │ │ ├── batch_random_shuffle_dataset.h │ │ ├── coco_data_reader.cpp │ │ ├── coco_data_reader.h │ │ ├── coco_dataset.cpp │ │ ├── coco_dataset.h │ │ ├── coco_parser.cpp │ │ ├── coco_parser.h │ │ ├── data_reader.h │ │ ├── dataset.h │ │ ├── distributed_training_dataset.h │ │ ├── distributed_util.h │ │ ├── gpt_dataset.cpp │ │ ├── gpt_dataset.h │ │ ├── group_batch_dataset.h │ │ ├── ofrecord_data_reader.h │ │ ├── ofrecord_dataset.h │ │ ├── ofrecord_image_classification_data_reader.h │ │ ├── ofrecord_image_classification_dataset.cpp │ │ ├── ofrecord_image_classification_dataset.h │ │ ├── ofrecord_image_classification_parser.h │ │ ├── ofrecord_parser.h │ │ ├── parser.h │ │ └── random_shuffle_dataset.h │ ├── image/ │ │ ├── crop_window.h │ │ ├── image_util.cpp │ │ ├── image_util.h │ │ ├── jpeg_decoder.cpp │ │ ├── jpeg_decoder.h │ │ ├── jpeg_decoder_test.cpp │ │ ├── random_crop_generator.cpp │ │ └── random_crop_generator.h │ ├── kernels/ │ │ ├── acc_kernel.cpp │ │ ├── activation_kernels.cpp │ │ ├── adaptive_avg_pool_cpu_kernel.cpp │ │ ├── adaptive_avg_pool_gpu_kernel.cu │ │ ├── adaptive_max_pool_cpu_kernel.cpp │ │ ├── adaptive_max_pool_gpu_kernel.cu │ │ ├── adaptive_pool_kernel_util.h │ │ ├── add_n_kernel.cpp │ │ ├── affine_grid_kernel.cpp │ │ ├── affine_grid_kernel.cu │ │ ├── affine_grid_kernel.h │ │ ├── arange_kernel.cpp │ │ ├── arange_kernel_util.cpp │ │ ├── arange_kernel_util.cu │ │ ├── arange_kernel_util.h │ │ ├── arg_sort_kernel.cpp │ │ ├── arg_sort_kernel.cu │ │ ├── arg_where_kernel.cpp │ │ ├── arg_where_kernel_util.cpp │ │ ├── arg_where_kernel_util.cu │ │ ├── arg_where_kernel_util.h │ │ ├── argmax_kernel.cpp │ │ ├── argmax_kernel.cu │ │ ├── as_strided_kernel.cpp │ │ ├── as_strided_kernel.cu │ │ ├── assign_if_kernel.cpp │ │ ├── assign_if_kernel.cu │ │ ├── assign_kernel.cpp │ │ ├── avg_pool_kernel.cpp │ │ ├── avg_pool_kernel.cu │ │ ├── avg_pool_kernel_util.cpp │ │ ├── avg_pool_kernel_util.h │ │ ├── batch_gather_kernel.cpp │ │ ├── batch_gather_kernel_util.cpp │ │ ├── batch_gather_kernel_util.cu │ │ ├── batch_gather_kernel_util.h │ │ ├── batch_norm_backward_elemt_kernel.cu │ │ ├── batch_norm_backward_reduce_kernel.cu │ │ ├── batch_norm_elemt_kernel.cu │ │ ├── batch_norm_gather_stats_with_counts_kernel.cu │ │ ├── batch_norm_kernel_utils.h │ │ ├── batch_norm_stats_kernel.cu │ │ ├── 
bernoulli_kernel.cpp │ │ ├── bias_add_kernel.cpp │ │ ├── binary_concat_kernel.cu │ │ ├── binary_cross_entropy_kernel.cpp │ │ ├── binary_cross_entropy_kernel.cu │ │ ├── binary_cross_entropy_with_logits_kernel.cpp │ │ ├── binary_cross_entropy_with_logits_kernel.cu │ │ ├── binary_cross_entropy_with_logits_mean_kernel.cu │ │ ├── binary_cross_entropy_with_logits_mean_kernel_util.h │ │ ├── binary_cross_entropy_with_logits_reduce_mean.cpp │ │ ├── bincount_kernel.cpp │ │ ├── bincount_kernel.cu │ │ ├── broadcast_div_grad_kernel.cpp │ │ ├── broadcast_like_kernel.cpp │ │ ├── cast_kernel.cpp │ │ ├── cast_to_static_shape_kernel.cpp │ │ ├── categorical_ordinal_encode_kernel.cpp │ │ ├── categorical_ordinal_encode_kernel_util.cpp │ │ ├── categorical_ordinal_encode_kernel_util.cu │ │ ├── categorical_ordinal_encode_kernel_util.h │ │ ├── clip_by_value_kernel.cpp │ │ ├── clip_by_value_kernel.cu │ │ ├── clip_by_value_kernel.h │ │ ├── coco_reader_kernel.cpp │ │ ├── collective_communication/ │ │ │ ├── cpu/ │ │ │ │ ├── cpu_all_gather.cpp │ │ │ │ ├── cpu_all_reduce.cpp │ │ │ │ ├── cpu_broadcast.cpp │ │ │ │ ├── cpu_collective_communication_util.h │ │ │ │ ├── cpu_communication_context.cpp │ │ │ │ ├── cpu_communication_context.h │ │ │ │ ├── cpu_recv.cpp │ │ │ │ ├── cpu_reduce.cpp │ │ │ │ ├── cpu_reduce_scatter.cpp │ │ │ │ └── cpu_send.cpp │ │ │ ├── cuda/ │ │ │ │ ├── cuda_all_gather.cpp │ │ │ │ ├── cuda_all_reduce.cpp │ │ │ │ ├── cuda_all_to_all.cpp │ │ │ │ ├── cuda_broadcast.cpp │ │ │ │ ├── cuda_communication_context.cpp │ │ │ │ ├── cuda_communication_context.h │ │ │ │ ├── cuda_recv.cpp │ │ │ │ ├── cuda_reduce.cpp │ │ │ │ ├── cuda_reduce_scatter.cpp │ │ │ │ ├── cuda_send.cpp │ │ │ │ ├── cuda_send_recv_util.cpp │ │ │ │ └── cuda_send_recv_util.h │ │ │ └── include/ │ │ │ ├── all_gather.h │ │ │ ├── all_reduce.h │ │ │ ├── all_to_all.h │ │ │ ├── broadcast.h │ │ │ ├── collective_communication.h │ │ │ ├── communication_context.h │ │ │ ├── recv.h │ │ │ ├── reduce.h │ │ │ ├── reduce_scatter.h │ │ │ └── send.h │ │ ├── combined_margin_loss_kernel.cpp │ │ ├── combined_margin_loss_kernel.cu │ │ ├── communicate_util.cpp │ │ ├── communicate_util.h │ │ ├── complex_kernels.cpp │ │ ├── concat_kernel.cpp │ │ ├── constant_kernel.cpp │ │ ├── conv_cudnn_kernels.cpp │ │ ├── conv_cutlass_kernels.cu │ │ ├── conv_kernels.cpp │ │ ├── convert_memory_format_kernel.cpp │ │ ├── convert_memory_format_util.cpp │ │ ├── convert_memory_format_util.h │ │ ├── copy_data_content_kernel.cpp │ │ ├── copy_hd_kernel.cpp │ │ ├── copy_kernel.cpp │ │ ├── count_not_finite_kernel.cpp │ │ ├── count_not_finite_kernel.cu │ │ ├── ctc_greedy_decoder.cpp │ │ ├── ctc_greedy_decoder.cu │ │ ├── ctc_greedy_decoder.h │ │ ├── ctc_loss_kernel.cpp │ │ ├── ctc_loss_kernel_util.cpp │ │ ├── ctc_loss_kernel_util.cu │ │ ├── ctc_loss_kernel_util.h │ │ ├── cublas_bias_add_relu_matmul_grad_kernel.cu │ │ ├── cublas_fused_matmul_bias_add_grad.cu │ │ ├── cublas_fused_mlp_grad_kernel.cu │ │ ├── cublas_fused_mlp_kernel.cu │ │ ├── cublas_fused_mlp_util.cuh │ │ ├── cufft_plan_cache.h │ │ ├── cum_backward_kernel.cpp │ │ ├── cum_backward_kernel.cu │ │ ├── cum_forward_kernel.cpp │ │ ├── cum_forward_kernel.cu │ │ ├── cutlass_conv_tuner.cpp │ │ ├── cutlass_conv_tuner.h │ │ ├── data_shuffle_kernel.cu │ │ ├── deconv_cpu_kernel.cpp │ │ ├── deconv_cudnn_kernel.cpp │ │ ├── deform_conv_kernel.cpp │ │ ├── deform_conv_kernel.cu │ │ ├── det_kernel.cpp │ │ ├── diag_kernel.cpp │ │ ├── diag_kernel.cu │ │ ├── diag_kernel.h │ │ ├── diagonal_kernel.cpp │ │ ├── diagonal_kernel.cu │ │ ├── dim_gather_kernel_util.cpp 
│ │ ├── dim_gather_kernel_util.cu │ │ ├── dim_gather_kernel_util.h │ │ ├── dim_gather_kernels.cpp │ │ ├── dim_scatter_kernel_util.cpp │ │ ├── dim_scatter_kernel_util.cu │ │ ├── dim_scatter_kernel_util.h │ │ ├── dim_scatter_kernels.cpp │ │ ├── dim_scatter_scalar_kernel_util.cpp │ │ ├── dim_scatter_scalar_kernel_util.cu │ │ ├── dim_scatter_scalar_kernel_util.h │ │ ├── dim_scatter_scalar_kernels.cpp │ │ ├── distributions/ │ │ │ ├── common.h │ │ │ ├── distribution_template_util.cuh │ │ │ ├── exponential_distribution.cpp │ │ │ ├── exponential_distribution.cu │ │ │ ├── exponential_distribution.h │ │ │ ├── exponential_kernel.cpp │ │ │ ├── exponential_kernel.h │ │ │ ├── multinomial_with_replacement_kernel.cpp │ │ │ ├── multinomial_with_replacement_kernel.cu │ │ │ ├── normal_distribution.cpp │ │ │ ├── normal_distribution.cu │ │ │ ├── normal_distribution.h │ │ │ ├── normal_kernel.cpp │ │ │ ├── normal_kernel.h │ │ │ ├── uniform_distribution.cpp │ │ │ ├── uniform_distribution.cu │ │ │ ├── uniform_distribution.h │ │ │ ├── uniform_int_distribution.cpp │ │ │ ├── uniform_int_distribution.cu │ │ │ ├── uniform_int_distribution.h │ │ │ ├── uniform_int_kernel.cpp │ │ │ ├── uniform_int_kernel.h │ │ │ ├── uniform_kernel.cpp │ │ │ └── uniform_kernel.h │ │ ├── dot_kernel.cpp │ │ ├── dropout_kernel.cpp │ │ ├── dropout_kernel.cu │ │ ├── dropout_kernel.h │ │ ├── dynamic_loss_scale_schedule_kernel.cpp │ │ ├── dynamic_loss_scale_schedule_kernel.cu │ │ ├── eager_b_to_s_kernel.cpp │ │ ├── eager_ccl_kernel.cpp │ │ ├── eager_nccl_s2s_kernel.cu │ │ ├── eager_p_to_b_kernel.cpp │ │ ├── eager_p_to_s_kernel.cpp │ │ ├── eager_s_to_b_kernel.cpp │ │ ├── eager_s_to_p_kernel.cpp │ │ ├── eager_s_to_s_kernel.cpp │ │ ├── eager_symmetric_s_to_p_kernel.cpp │ │ ├── elementwise_maximum_minimum_kernel.cpp │ │ ├── elementwise_maximum_minimum_kernel.cu │ │ ├── elementwise_maximum_minimum_kernel.h │ │ ├── elementwise_primitive_kernel.h │ │ ├── embedding_kernel.cpp │ │ ├── embedding_kernel.cu │ │ ├── embedding_kernel_util.cpp │ │ ├── embedding_kernel_util.cu │ │ ├── embedding_kernel_util.h │ │ ├── empty_kernel.cpp │ │ ├── erfinv_kernel.cpp │ │ ├── erfinv_kernel.cu │ │ ├── expand_kernel.cpp │ │ ├── eye_kernel.cpp │ │ ├── eye_kernel_util.cpp │ │ ├── eye_kernel_util.cu │ │ ├── eye_kernel_util.h │ │ ├── fake_quantization_kernel.cpp │ │ ├── fake_quantization_kernel.cu │ │ ├── fft_kernel_util.cpp │ │ ├── fft_kernel_util.cu │ │ ├── fft_kernel_util.h │ │ ├── fft_kernels.cpp │ │ ├── fill_kernel.cpp │ │ ├── fill_kernel.cu │ │ ├── flip_kernel.cpp │ │ ├── flip_kernel.cu │ │ ├── fold_kernel.cpp │ │ ├── fold_kernel_util.cpp │ │ ├── fold_kernel_util.cu │ │ ├── fold_kernel_util.h │ │ ├── frac_kernel.cpp │ │ ├── frac_kernel.cu │ │ ├── fused_attention_kernels.cu │ │ ├── fused_bias_add_kernel.cu │ │ ├── fused_bias_add_scale_mask_softmax_dropout.cu │ │ ├── fused_cast_scale_kernel.cpp │ │ ├── fused_cast_scale_kernel.cu │ │ ├── fused_center_kernel.cu │ │ ├── fused_clip_grad.cu │ │ ├── fused_clip_grad.h │ │ ├── fused_clip_grad_util.h │ │ ├── fused_codegeex_qkv_reshape_kernel.cu │ │ ├── fused_cross_feature_interaction.cu │ │ ├── fused_cross_feature_interaction_grad.cu │ │ ├── fused_dot_feature_interaction_kernel.cu │ │ ├── fused_gelu_mul_kernel.cu │ │ ├── fused_get_bounding_boxes_coord_kernel.cu │ │ ├── fused_get_ciou_diagonal_angle_kernel.cu │ │ ├── fused_get_ciou_result_kernel.cu │ │ ├── fused_get_convex_diagonal_squared_kernel.cu │ │ ├── fused_get_intersection_area_kernel.cu │ │ ├── fused_get_iou_kernel.cu │ │ ├── fused_glu_kernel.cu │ │ ├── 
fused_glu_without_linear_grad_kernel.cu │ │ ├── fused_gru_cell_kernel.cu │ │ ├── fused_lstm_cell_kernel.cu │ │ ├── fused_matmul_bias_add_relu_dropout.cu │ │ ├── fused_matmul_bias_kernel.cu │ │ ├── fused_relu_dropout_grad_kernel.cu │ │ ├── fused_rnn_cell_kernel_util.h │ │ ├── fused_scale_mask_bias_softmax.cu │ │ ├── fused_scale_mask_softmax.cu │ │ ├── fused_scale_mask_softmax_dropout.cu │ │ ├── fused_self_attention_query_mul_key_and_value_kernel.cu │ │ ├── fused_softmax.cuh │ │ ├── fused_tril_scale_softmax_mask_scale_kernel.cu │ │ ├── fused_weighted_sum_kernel.cpp │ │ ├── fused_weighted_sum_kernel.cu │ │ ├── gather_kernel.cpp │ │ ├── gather_kernel_util.cpp │ │ ├── gather_kernel_util.cu │ │ ├── gather_kernel_util.h │ │ ├── generate_random_batch_permutation_indices_kernel.cpp │ │ ├── generate_random_batch_permutation_indices_kernel.cu │ │ ├── gpt_data_loader_kernel.cpp │ │ ├── greater_inplace_kernel.cpp │ │ ├── greater_inplace_kernel_util.cpp │ │ ├── greater_inplace_kernel_util.cu │ │ ├── greater_inplace_kernel_util.h │ │ ├── grid_sample_kernel.cpp │ │ ├── grid_sample_kernel_util.cpp │ │ ├── grid_sample_kernel_util.cu │ │ ├── grid_sample_kernel_util.h │ │ ├── group_conv_kernel.cpp │ │ ├── group_deconv_kernel.cpp │ │ ├── group_norm_kernel.cu │ │ ├── grouped_matmul_bias.cu │ │ ├── groupwise_quantization_kernels.cu │ │ ├── host_scalar_add_by_tensor_kernel.cu │ │ ├── image_batch_align_kernel.cpp │ │ ├── image_decode_kernel.cpp │ │ ├── image_object_preprocess_kernels.cpp │ │ ├── image_preprocess_kernels.cpp │ │ ├── image_preprocess_kernels.cu │ │ ├── image_resize_kernels.cpp │ │ ├── image_target_resize_kernel.cpp │ │ ├── in_top_k_kernel.cpp │ │ ├── in_top_k_kernel_util.cpp │ │ ├── in_top_k_kernel_util.cu │ │ ├── in_top_k_kernel_util.h │ │ ├── index_add_kernel.cpp │ │ ├── index_add_kernel.cu │ │ ├── indexed_slices_reduce_sum_kernel.cpp │ │ ├── indexed_slices_reduce_sum_kernel_util.cpp │ │ ├── indexed_slices_reduce_sum_kernel_util.h │ │ ├── inv_kernels.cpp │ │ ├── inv_kernels.cu │ │ ├── kl_div_kernel.cpp │ │ ├── kl_div_kernel.cu │ │ ├── l1_l2_regularize_gradient_kernel.cpp │ │ ├── l1_l2_regularize_gradient_kernel_util.cpp │ │ ├── l1_l2_regularize_gradient_kernel_util.cu │ │ ├── l1_l2_regularize_gradient_kernel_util.h │ │ ├── l2_normalize_kernel.cpp │ │ ├── l2_normalize_kernel.cu │ │ ├── layer_norm_cpu_kernel.cpp │ │ ├── layer_norm_gpu_kernel.cu │ │ ├── lerp_kernel.cpp │ │ ├── lerp_kernel_util.cpp │ │ ├── lerp_kernel_util.cu │ │ ├── lerp_kernel_util.h │ │ ├── linalg_cross_kernel.cpp │ │ ├── linalg_cross_kernel.cu │ │ ├── log_softmax_kernel.cpp │ │ ├── logical_not_kernel.cpp │ │ ├── loss_kernel_util.h │ │ ├── lu_decomposition_kernel.cu │ │ ├── masked_fill_kernel.cpp │ │ ├── math_binary_broadcast_kernels.cpp │ │ ├── math_binary_elementwise_func.h │ │ ├── math_binary_elementwise_kernel.cpp │ │ ├── math_binary_elementwise_kernel.cu │ │ ├── math_unary_elementwise_func.h │ │ ├── math_unary_elementwise_primitive_kernel.cpp │ │ ├── matmul_kernels.cpp │ │ ├── matrix_vector_product_kernel.cpp │ │ ├── max_pool_kernel.cpp │ │ ├── max_pool_kernel.cu │ │ ├── max_pool_kernel_util.cpp │ │ ├── max_pool_kernel_util.h │ │ ├── max_unpool_kernel.cpp │ │ ├── max_unpool_kernel.cu │ │ ├── max_unpool_kernel_util.cpp │ │ ├── max_unpool_kernel_util.h │ │ ├── median_kernel.cpp │ │ ├── median_kernel.cu │ │ ├── median_with_indices_kernel.cpp │ │ ├── median_with_indices_kernel.cu │ │ ├── min_max_observer_kernel.cpp │ │ ├── min_max_observer_kernel.cu │ │ ├── mode_kernel.cpp │ │ ├── model_update_kernel_util.cpp │ │ ├── 
model_update_kernel_util.cu │ │ ├── model_update_kernel_util.h │ │ ├── model_update_kernels.cpp │ │ ├── moving_average_min_max_observer_kernel.cpp │ │ ├── moving_average_min_max_observer_kernel.cu │ │ ├── multi_reduce_kernel_util.h │ │ ├── multi_reduce_kernels.cpp │ │ ├── multi_reduce_kernels.cu │ │ ├── multi_reduce_kernels.h │ │ ├── multi_tensor_model_update_kernel.cpp │ │ ├── multi_tensor_model_update_kernel_util.cu │ │ ├── multi_tensor_model_update_kernel_util.h │ │ ├── mutable_cast_once_kernel.cpp │ │ ├── narrow_kernel.cpp │ │ ├── nccl_logical_2d_sbp_kernels.cpp │ │ ├── nccl_logical_fusion_kernel.cpp │ │ ├── nccl_logical_kernels.cpp │ │ ├── nccl_logical_send_recv_kernel.cpp │ │ ├── nd_index_slice_kernels.cpp │ │ ├── nd_index_slice_kernels.cu │ │ ├── nd_index_slice_kernels.h │ │ ├── nd_index_slice_util.h │ │ ├── nll_kernel.cpp │ │ ├── nll_kernel_util.cpp │ │ ├── nll_kernel_util.cu │ │ ├── nll_kernel_util.h │ │ ├── nms_kernel.cpp │ │ ├── nms_kernel.cu │ │ ├── noncontiguous_binary_op.cu │ │ ├── nop_kernel.cpp │ │ ├── normalization_kernel.cpp │ │ ├── normalization_kernel.cu │ │ ├── nvtx_range_kernel.cu │ │ ├── ofrecord_decoder_kernels.cpp │ │ ├── ofrecord_image_classification_reader_kernel.cpp │ │ ├── ofrecord_reader_kernel.cpp │ │ ├── one_embedding_data_shuffle.cuh │ │ ├── one_embedding_embedding_gradient_shuffle_p2p_kernel.cu │ │ ├── one_embedding_embedding_shuffle_p2p_kernel.cu │ │ ├── one_embedding_id_shuffle_p2p_kernel.cu │ │ ├── one_embedding_kernels.cu │ │ ├── one_embedding_update_kernels.cu │ │ ├── one_hot_kernel.cpp │ │ ├── one_hot_kernel.cu │ │ ├── ones_like_kernel.cpp │ │ ├── op_kernel_wrapper.h │ │ ├── p2p_comm_kernel.cpp │ │ ├── pack_kernel.cpp │ │ ├── pad_kernel.cpp │ │ ├── partial_fc_sample_kernel.cu │ │ ├── pocketfft_hdronly.h │ │ ├── pocketfftplan.h │ │ ├── prelu_kernel.cpp │ │ ├── prelu_kernel.cu │ │ ├── quantization_kernel.cpp │ │ ├── quantization_kernel.cu │ │ ├── radix_sort.cuh │ │ ├── random_crop_kernel_state.cpp │ │ ├── random_crop_kernel_state.h │ │ ├── random_mask_generator.cpp │ │ ├── random_mask_generator.cu │ │ ├── random_mask_generator.h │ │ ├── random_mask_like_kernel.cpp │ │ ├── random_mask_like_kernel.h │ │ ├── random_seed_util.cpp │ │ ├── random_seed_util.h │ │ ├── randperm_kernel.cpp │ │ ├── randperm_kernel.cu │ │ ├── raw_reader_kernel.cpp │ │ ├── reduce_kernel.cpp │ │ ├── reduce_like_kernels.cpp │ │ ├── reflection_pad_kernels.cpp │ │ ├── reflection_pad_kernels_util.cpp │ │ ├── reflection_pad_kernels_util.cu │ │ ├── reflection_pad_kernels_util.h │ │ ├── repeat_interleave_kernel.cpp │ │ ├── repeat_interleave_kernel.cu │ │ ├── replication_pad_kernels.cpp │ │ ├── replication_pad_kernels_util.cpp │ │ ├── replication_pad_kernels_util.cu │ │ ├── replication_pad_kernels_util.h │ │ ├── rms_norm_gpu_kernel.cu │ │ ├── roc_auc_score_kernel.cpp │ │ ├── roi_align_kernel.cu │ │ ├── roll_kernel.cpp │ │ ├── roll_kernel.cu │ │ ├── roll_kernel_utils.h │ │ ├── rrelu_kernel.cpp │ │ ├── rrelu_kernel.cu │ │ ├── same_padding_kernel.cpp │ │ ├── scalar_bitwise_kernels.cpp │ │ ├── scalar_by_tensor_kernel.cpp │ │ ├── scalar_logical_kernels.cpp │ │ ├── scalar_math_kernels.cpp │ │ ├── scaled_dot_product_attention_grad_kernel.cu │ │ ├── scaled_dot_product_attention_kernel.cu │ │ ├── scaled_dot_product_attention_kernel.h │ │ ├── scaled_dot_product_attention_util.h │ │ ├── search_sorted_kernel.cpp │ │ ├── search_sorted_kernel.cu │ │ ├── search_sorted_kernel_util.h │ │ ├── sigmoid_cross_entropy_kernel.cpp │ │ ├── sigmoid_cross_entropy_kernel.cu │ │ ├── sigmoid_cross_entropy_kernel.h │ │ 
├── skip_layer_norm_kernel.cu │ │ ├── skip_rms_norm_kernel.cu │ │ ├── slice_kernel.cpp │ │ ├── slice_util.cpp │ │ ├── slice_util.cu │ │ ├── slice_util.h │ │ ├── smooth_l1_loss_kernel.cpp │ │ ├── smooth_l1_loss_kernel.cu │ │ ├── softmax_cross_entropy_kernel.cpp │ │ ├── softmax_cross_entropy_kernel.cu │ │ ├── softmax_cross_entropy_kernel.h │ │ ├── softmax_kernel.cpp │ │ ├── sort_kernel.cpp │ │ ├── sort_kernel.cu │ │ ├── sparse_cross_entropy_kernel.cpp │ │ ├── sparse_cross_entropy_kernel_util.cpp │ │ ├── sparse_cross_entropy_kernel_util.cu │ │ ├── sparse_cross_entropy_kernel_util.h │ │ ├── sparse_softmax_cross_entropy_kernel.cpp │ │ ├── sparse_softmax_cross_entropy_kernel.cu │ │ ├── sparse_softmax_cross_entropy_kernel_util.cpp │ │ ├── sparse_softmax_cross_entropy_kernel_util.cu │ │ ├── sparse_softmax_cross_entropy_kernel_util.h │ │ ├── split_like_kernel.cpp │ │ ├── sqrt_square_sum_kernel.cpp │ │ ├── sqrt_square_sum_kernel_util.cpp │ │ ├── sqrt_square_sum_kernel_util.cu │ │ ├── sqrt_square_sum_kernel_util.h │ │ ├── square_sum_kernel.cpp │ │ ├── square_sum_kernel_util.cpp │ │ ├── square_sum_kernel_util.cu │ │ ├── square_sum_kernel_util.h │ │ ├── ssp_variable_proxy_kernel.cpp │ │ ├── stack_kernel.cpp │ │ ├── stateful_opkernel.cpp │ │ ├── stateful_opkernel.h │ │ ├── summary_kernels.cpp │ │ ├── tensor_buffer_kernels.cpp │ │ ├── tensor_constant_kernel.cpp │ │ ├── tf_pool_cpu_kernel.cpp │ │ ├── tf_pool_gpu_kernel.cpp │ │ ├── tf_prelu_kernel.cpp │ │ ├── tf_prelu_kernel.cu │ │ ├── throw_error_kernel.cpp │ │ ├── to_contiguous_kernel.cpp │ │ ├── to_contiguous_kernel.cu │ │ ├── to_contiguous_kernel.h │ │ ├── top_k_kernel.cpp │ │ ├── top_k_kernel.cu │ │ ├── transpose_kernel.cpp │ │ ├── tril_kernel.cpp │ │ ├── tril_kernel.cu │ │ ├── triu_kernel.cpp │ │ ├── triu_kernel.cu │ │ ├── tuple_identity_kernel.cpp │ │ ├── two_stage_reduce_kernel.cpp │ │ ├── two_stage_reduce_kernel_util.cpp │ │ ├── two_stage_reduce_kernel_util.cu │ │ ├── two_stage_reduce_kernel_util.h │ │ ├── unfold_kernel.cpp │ │ ├── unfold_kernel_util.cpp │ │ ├── unfold_kernel_util.cu │ │ ├── unfold_kernel_util.h │ │ ├── unfold_tensor_kernel.cpp │ │ ├── unfold_tensor_kernel.cu │ │ ├── unfold_tensor_kernel_utils.h │ │ ├── unique_kernel.cpp │ │ ├── unique_kernel_util.cpp │ │ ├── unique_kernel_util.cu │ │ ├── unique_kernel_util.h │ │ ├── unique_with_counts_kernel.cpp │ │ ├── unpack_kernel.cpp │ │ ├── unsorted_batch_segment_sum_kernel.cpp │ │ ├── unsorted_segment_sum_kernel.cpp │ │ ├── unsorted_segment_sum_kernel_util.cpp │ │ ├── unsorted_segment_sum_kernel_util.cu │ │ ├── unsorted_segment_sum_kernel_util.h │ │ ├── upsample_bicubic_2d_kernel.cpp │ │ ├── upsample_bicubic_2d_kernel.cu │ │ ├── upsample_bilinear_2d_kernel.cpp │ │ ├── upsample_bilinear_2d_kernel.cu │ │ ├── upsample_kernel.h │ │ ├── upsample_linear_1d_kernel.cpp │ │ ├── upsample_linear_1d_kernel.cu │ │ ├── upsample_nearest_kernel.cpp │ │ ├── upsample_nearest_kernel.cu │ │ ├── upsample_trilinear_3d_kernel.cpp │ │ ├── upsample_trilinear_3d_kernel.cu │ │ ├── util_ops_kernels.cpp │ │ ├── variance_kernel.cpp │ │ ├── variance_kernel_util.cpp │ │ ├── variance_kernel_util.cu │ │ ├── variance_kernel_util.h │ │ ├── vector_matrix_product_kernel.cpp │ │ ├── where_kernel.cpp │ │ ├── where_kernel_util.cpp │ │ ├── where_kernel_util.cu │ │ ├── where_kernel_util.h │ │ └── zero_like_kernel.cpp │ ├── ops/ │ │ ├── acc_ctrl_tick_op.cpp │ │ ├── acc_op.cpp │ │ ├── adaptive_max_pool_op.cpp │ │ ├── adaptive_pool_op.cpp │ │ ├── add_n_op.cpp │ │ ├── affine_grid_op.cpp │ │ ├── amp_white_identity_op.cpp │ │ ├── 
arange_op.cpp │ │ ├── arg_sort_op.cpp │ │ ├── arg_where_op.cpp │ │ ├── argmax_op.cpp │ │ ├── as_strided_op.cpp │ │ ├── assign_op.cpp │ │ ├── avg_pool_op.cpp │ │ ├── batch_gather_op.cpp │ │ ├── batch_norm_backward_elemt_op.cpp │ │ ├── batch_norm_backward_reduce_op.cpp │ │ ├── batch_norm_elemt_op.cpp │ │ ├── batch_norm_gather_stats_with_counts_op.cpp │ │ ├── batch_norm_stats_op.cpp │ │ ├── bernoulli_op.cpp │ │ ├── bias_add_op.cpp │ │ ├── binary_cross_entropy_op.cpp │ │ ├── binary_cross_entropy_with_logits_op.cpp │ │ ├── binary_cross_entropy_with_logits_reduce_mean_op.cpp │ │ ├── bincount_op.cpp │ │ ├── broadcast_div_grad_op.cpp │ │ ├── broadcast_like_op.cpp │ │ ├── buffer_op.cpp │ │ ├── cast_like_op.cpp │ │ ├── cast_op.cpp │ │ ├── cast_to_static_shape_op.cpp │ │ ├── cast_to_tick_op.cpp │ │ ├── categorical_ordinal_encode_op.cpp │ │ ├── celu_op.cpp │ │ ├── clip_by_value_op.cpp │ │ ├── coco_reader_op.cpp │ │ ├── combined_margin_loss_op.cpp │ │ ├── comm_net_device_infer_util.cpp │ │ ├── comm_net_device_infer_util.h │ │ ├── complex_ops.cpp │ │ ├── concat_op.cpp │ │ ├── constant_op.cpp │ │ ├── conv_op.cpp │ │ ├── convert_memory_format_op.cpp │ │ ├── convert_memory_format_op.h │ │ ├── copy_hd_op.cpp │ │ ├── copy_op.cpp │ │ ├── count_not_finite_op.cpp │ │ ├── ctc_loss_op.cpp │ │ ├── cublas_bias_add_relu_matmul_grad_op.cpp │ │ ├── cublas_fused_matmul_bias_add_grad_op.cpp │ │ ├── cublas_fused_mlp_grad_op.cpp │ │ ├── cublas_fused_mlp_op.cpp │ │ ├── cum_ops.cpp │ │ ├── data_shuffle_op.cpp │ │ ├── deconv_op.cpp │ │ ├── deform_conv_op.cpp │ │ ├── depend_op.cpp │ │ ├── det_op.cpp │ │ ├── diag_op.cpp │ │ ├── diagonal_op.cpp │ │ ├── dim_gather_op.cpp │ │ ├── dim_scatter_ops.cpp │ │ ├── distributions/ │ │ │ ├── exponential_op.cpp │ │ │ ├── multinomial_with_replacement_op.cpp │ │ │ ├── normal_op.cpp │ │ │ ├── uniform_int_op.cpp │ │ │ └── uniform_op.cpp │ │ ├── dot_op.cpp │ │ ├── dropout_op.cpp │ │ ├── dynamic_loss_scale_schedule_op.cpp │ │ ├── eager_b_to_s_op.cpp │ │ ├── eager_ccl_ops.cpp │ │ ├── eager_p_to_b_op.cpp │ │ ├── eager_p_to_s_op.cpp │ │ ├── eager_s_to_b_op.cpp │ │ ├── eager_s_to_p_op.cpp │ │ ├── eager_s_to_s_op.cpp │ │ ├── eager_symmetric_s_to_p_op.cpp │ │ ├── elementwise_maximum_minimum_ops.cpp │ │ ├── elu_op.cpp │ │ ├── embedding_op.cpp │ │ ├── empty_op.cpp │ │ ├── erfinv_op.cpp │ │ ├── expand_dims_op.cpp │ │ ├── expand_op.cpp │ │ ├── eye_op.cpp │ │ ├── fake_quantization_op.cpp │ │ ├── fft_ops.cpp │ │ ├── fill_op.cpp │ │ ├── flip_op.cpp │ │ ├── frac_op.cpp │ │ ├── fused_attention_ops.cpp │ │ ├── fused_bias_add_op.cpp │ │ ├── fused_bias_add_scale_mask_softmax_dropout_op.cpp │ │ ├── fused_cast_scale_op.cpp │ │ ├── fused_center_op.cpp │ │ ├── fused_clip_grad_ops.cpp │ │ ├── fused_codegeex_qkv_reshape.cpp │ │ ├── fused_cross_feature_interaction_op.cpp │ │ ├── fused_dot_feature_interaction_op.cpp │ │ ├── fused_get_boundding_boxes_coord_op.cpp │ │ ├── fused_get_ciou_diagonal_angle_op.cpp │ │ ├── fused_get_ciou_result_op.cpp │ │ ├── fused_get_convex_diagonal_squared_op.cpp │ │ ├── fused_get_intersection_area_op.cpp │ │ ├── fused_get_iou_op.cpp │ │ ├── fused_glu_op.cpp │ │ ├── fused_glu_without_linear_grad_op.cpp │ │ ├── fused_gru_cell_op.cpp │ │ ├── fused_linear_with_groupwise_quantized_weight_op.cpp │ │ ├── fused_lstm_cell_op.cpp │ │ ├── fused_matmul_bias_add_relu_dropout_op.cpp │ │ ├── fused_matmul_bias_op.cpp │ │ ├── fused_relu_dropout_grad_op.cpp │ │ ├── fused_scale_mask_bias_softmax_op.cpp │ │ ├── fused_scale_mask_softmax_dropout_op.cpp │ │ ├── fused_scale_mask_softmax_op.cpp │ │ ├── 
fused_scale_tril_softmax_mask_scale_op.cpp │ │ ├── fused_self_attention_query_mul_key_and_value_ops.cpp │ │ ├── fused_weighted_sum_op.cpp │ │ ├── gather_op.cpp │ │ ├── gelu_op.cpp │ │ ├── generate_random_batch_permutation_indices_op.cpp │ │ ├── gpt_data_loader_op.cpp │ │ ├── greater_inplace_op.cpp │ │ ├── grid_sample_op.cpp │ │ ├── group_norm_op.cpp │ │ ├── grouped_matmul_bias_op.cpp │ │ ├── groupwise_dequantize_op.cpp │ │ ├── hardshrink_op.cpp │ │ ├── hardsigmoid_op.cpp │ │ ├── hardswish_op.cpp │ │ ├── hardtanh_op.cpp │ │ ├── hierarchical_parallel_cast_op.cpp │ │ ├── identity_op.cpp │ │ ├── image_batch_align_op.cpp │ │ ├── image_decode_op.cpp │ │ ├── image_object_preprocess_ops.cpp │ │ ├── image_preprocess_ops.cpp │ │ ├── image_resize_ops.cpp │ │ ├── image_target_resize_op.cpp │ │ ├── in_top_k_op.cpp │ │ ├── index_add_op.cpp │ │ ├── indexed_slices_reduce_sum_op.cpp │ │ ├── inv_op.cpp │ │ ├── kl_div_op.cpp │ │ ├── l1_l2_regularize_gradient_op.cpp │ │ ├── l2_normalize_op.cpp │ │ ├── layer_norm_op.cpp │ │ ├── leaky_relu_op.cpp │ │ ├── lerp_op.cpp │ │ ├── linalg_cross_op.cpp │ │ ├── log_softmax_op.cpp │ │ ├── logical_not_op.cpp │ │ ├── loss_op_util.cpp │ │ ├── loss_op_util.h │ │ ├── lu_composition_op.cpp │ │ ├── masked_fill_op.cpp │ │ ├── math_binary_broadcast_ops.cpp │ │ ├── math_binary_broadcast_seq.h │ │ ├── math_binary_elementwise_ops.cpp │ │ ├── math_binary_elementwise_seq.h │ │ ├── math_unary_elementwise_op.cpp │ │ ├── math_unary_elementwise_seq.h │ │ ├── matmul_op.cpp │ │ ├── matrix_vector_product_op.cpp │ │ ├── max_pool_op.cpp │ │ ├── max_unpool_op.cpp │ │ ├── median_op.cpp │ │ ├── median_with_indices_op.cpp │ │ ├── min_max_observer_op.cpp │ │ ├── mish_op.cpp │ │ ├── mode_op.cpp │ │ ├── model_update_ops.cpp │ │ ├── moving_average_min_max_observer_op.cpp │ │ ├── multi_reduce_ops.cpp │ │ ├── multi_tensor_model_update_ops.cpp │ │ ├── mutable_cast_once_op.cpp │ │ ├── narrow_op.cpp │ │ ├── nccl_logical_2d_sbp_ops.cpp │ │ ├── nccl_logical_fusion_op.cpp │ │ ├── nccl_logical_ops.cpp │ │ ├── nccl_logical_util.cpp │ │ ├── nccl_logical_util.h │ │ ├── nd_index_slice_ops.cpp │ │ ├── nll_op.cpp │ │ ├── nms_op.cpp │ │ ├── nn_util.cpp │ │ ├── nn_util.h │ │ ├── noncontiguous_binary_op.cpp │ │ ├── normalization_op.cpp │ │ ├── nvtx_range_op.cpp │ │ ├── ofrecord_decoder_ops.cpp │ │ ├── ofrecord_image_classification_reader_op.cpp │ │ ├── ofrecord_reader_op.cpp │ │ ├── one_embedding_ops.cpp │ │ ├── one_hot_op.cpp │ │ ├── ones_like_op.cpp │ │ ├── p2p_comm_op.cpp │ │ ├── pack_op.cpp │ │ ├── pad_op.cpp │ │ ├── parallel_cast_op.cpp │ │ ├── partial_fc_sample_op.cpp │ │ ├── pinned_identity_op.cpp │ │ ├── prelu_op.cpp │ │ ├── quantization_op.cpp │ │ ├── quick_gelu_op.cpp │ │ ├── randperm_op.cpp │ │ ├── raw_reader_op.cpp │ │ ├── reduce_like_ops.cpp │ │ ├── reduce_ops.cpp │ │ ├── reflection_pad_op.cpp │ │ ├── relu_op.cpp │ │ ├── repeat_interleave_op.cpp │ │ ├── repeat_op.cpp │ │ ├── replication_pad_op.cpp │ │ ├── reshape_like_op.cpp │ │ ├── reshape_op.cpp │ │ ├── reshape_user_op_util.cpp │ │ ├── reshape_user_op_util.h │ │ ├── reshape_user_op_util_test.cpp │ │ ├── rms_norm_op.cpp │ │ ├── roc_auc_score_op.cpp │ │ ├── roi_align_op.cpp │ │ ├── roll_op.cpp │ │ ├── rrelu_op.cpp │ │ ├── same_padding_op.cpp │ │ ├── scalar_bitwise_op.cpp │ │ ├── scalar_by_tensor_op.cpp │ │ ├── scalar_logical_op.cpp │ │ ├── scalar_math_op.cpp │ │ ├── scaled_dot_product_flash_attention_op.cpp │ │ ├── search_sorted_op.cpp │ │ ├── selu_op.cpp │ │ ├── sigmoid_cross_entropy_op.cpp │ │ ├── silu_op.cpp │ │ ├── skip_layer_norm_op.cpp │ │ ├── 
skip_rms_norm_op.cpp │ │ ├── slice_op.cpp │ │ ├── smooth_l1_loss_op.cpp │ │ ├── softmax_cross_entropy_op.cpp │ │ ├── softmax_op.cpp │ │ ├── softplus_op.cpp │ │ ├── softshrink_op.cpp │ │ ├── softsign_op.cpp │ │ ├── sort_op.cpp │ │ ├── sparse_cross_entropy_op.cpp │ │ ├── sparse_softmax_cross_entropy_op.cpp │ │ ├── split_like_op.cpp │ │ ├── sqrt_square_sum_op.cpp │ │ ├── square_relu_op.cpp │ │ ├── square_sum_op.cpp │ │ ├── squeeze_op.cpp │ │ ├── ssp_variable_proxy_op.cpp │ │ ├── stack_op.cpp │ │ ├── stft_op.cpp │ │ ├── summary_ops.cpp │ │ ├── tanh_op.cpp │ │ ├── tensor_buffer_ops.cpp │ │ ├── tensor_constant_op.cpp │ │ ├── tf_pool_op.cpp │ │ ├── tf_prelu_op.cpp │ │ ├── threshold_op.cpp │ │ ├── throw_error_op.cpp │ │ ├── to_contiguous_op.cpp │ │ ├── top_k_op.cpp │ │ ├── transpose_ops.cpp │ │ ├── tril_op.cpp │ │ ├── triu_op.cpp │ │ ├── trunc_op.cpp │ │ ├── tuple_identity_op.cpp │ │ ├── two_stage_reduce_ops.cpp │ │ ├── unfold_fold_op.cpp │ │ ├── unfold_tensor_op.cpp │ │ ├── unique_op.cpp │ │ ├── unique_with_counts_op.cpp │ │ ├── unpack_op.cpp │ │ ├── unsorted_batch_segment_sum_op.cpp │ │ ├── unsorted_segment_sum_op.cpp │ │ ├── upsample_op.cpp │ │ ├── util_ops.cpp │ │ ├── variance_op.cpp │ │ ├── vector_matrix_product_op.cpp │ │ ├── where_op.cpp │ │ └── zero_like_op.cpp │ ├── summary/ │ │ ├── crc32c.h │ │ ├── env_time.h │ │ ├── event_writer_helper.cpp │ │ ├── event_writer_helper.h │ │ ├── events_writer.cpp │ │ ├── events_writer.h │ │ ├── histogram.cpp │ │ ├── histogram.h │ │ ├── plan_to_physical_graph.cpp │ │ ├── plan_to_physical_graph.h │ │ └── summary_converter.h │ └── utils/ │ ├── pool_util.cpp │ └── pool_util.h ├── python/ │ ├── .gitignore │ ├── oneflow/ │ │ ├── _C/ │ │ │ ├── __init__.py │ │ │ └── _nn.py │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── _dynamo/ │ │ │ └── __init__.py │ │ ├── _utils.py │ │ ├── amp/ │ │ │ ├── __init__.py │ │ │ ├── autocast_mode.py │ │ │ └── grad_scaler.py │ │ ├── ao/ │ │ │ └── quantization.py │ │ ├── asyncs/ │ │ │ ├── __init__.py │ │ │ └── thread.py │ │ ├── autograd/ │ │ │ ├── __init__.py │ │ │ ├── autograd.py │ │ │ ├── autograd_function.py │ │ │ ├── autograd_mode.py │ │ │ ├── functional.py │ │ │ ├── graph.py │ │ │ └── profiler.py │ │ ├── autoprof/ │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── util.py │ │ ├── backends/ │ │ │ ├── __init__.py │ │ │ ├── cuda/ │ │ │ │ └── __init__.py │ │ │ ├── cudnn/ │ │ │ │ └── __init__.py │ │ │ └── mps/ │ │ │ └── __init__.py │ │ ├── boxing/ │ │ │ ├── __init__.py │ │ │ └── nccl/ │ │ │ └── __init__.py │ │ ├── comm/ │ │ │ ├── __init__.py │ │ │ └── comm_ops.py │ │ ├── cuda/ │ │ │ ├── __init__.py │ │ │ ├── _utils.py │ │ │ ├── amp/ │ │ │ │ ├── __init__.py │ │ │ │ └── autocast_mode.py │ │ │ ├── random.py │ │ │ └── type_tensor.py │ │ ├── data.py │ │ ├── distributed/ │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ └── launch.py │ │ ├── distributions/ │ │ │ ├── __init__.py │ │ │ ├── categorical.py │ │ │ ├── distribution.py │ │ │ └── utils.py │ │ ├── env.py │ │ ├── experimental/ │ │ │ └── load_mnist.py │ │ ├── fft/ │ │ │ └── __init__.py │ │ ├── framework/ │ │ │ ├── __init__.py │ │ │ ├── args_tree.py │ │ │ ├── attr_util.py │ │ │ ├── balanced_splitter.py │ │ │ ├── c_api_util.py │ │ │ ├── check_point_v2.py │ │ │ ├── config_util.py │ │ │ ├── distribute.py │ │ │ ├── docstr/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ ├── addcdiv.py │ │ │ │ ├── amax.py │ │ │ │ ├── amin.py │ │ │ │ ├── arange.py │ │ │ │ ├── argsort.py │ │ │ │ ├── array_ops.py │ │ │ │ ├── as_tensor.py │ │ │ │ ├── autograd.py │ │ │ │ ├── baddbmm.py │ │ │ │ ├── 
bitwise_ops.py │ │ │ │ ├── bmm.py │ │ │ │ ├── broadcast_like.py │ │ │ │ ├── cast.py │ │ │ │ ├── chunk.py │ │ │ │ ├── clamp.py │ │ │ │ ├── comm.py │ │ │ │ ├── comparison.py │ │ │ │ ├── constant.py │ │ │ │ ├── conv.py │ │ │ │ ├── convolution.py │ │ │ │ ├── ctc_decode.py │ │ │ │ ├── dataset.py │ │ │ │ ├── deconv.py │ │ │ │ ├── depend.py │ │ │ │ ├── distance.py │ │ │ │ ├── dropout.py │ │ │ │ ├── einsum.py │ │ │ │ ├── erfinv.py │ │ │ │ ├── expand.py │ │ │ │ ├── flatten.py │ │ │ │ ├── flip.py │ │ │ │ ├── hann_window.py │ │ │ │ ├── in_top_k.py │ │ │ │ ├── index_add.py │ │ │ │ ├── index_select.py │ │ │ │ ├── inv.py │ │ │ │ ├── is_floating_point.py │ │ │ │ ├── lerp.py │ │ │ │ ├── linalg.py │ │ │ │ ├── logaddexp.py │ │ │ │ ├── logical_ops.py │ │ │ │ ├── loss.py │ │ │ │ ├── masked_fill.py │ │ │ │ ├── math_ops.py │ │ │ │ ├── meshgrid.py │ │ │ │ ├── module.py │ │ │ │ ├── nms.py │ │ │ │ ├── nonzero.py │ │ │ │ ├── norm.py │ │ │ │ ├── normalization.py │ │ │ │ ├── oneflow.py │ │ │ │ ├── onehot.py │ │ │ │ ├── pooling.py │ │ │ │ ├── quantile.py │ │ │ │ ├── random.py │ │ │ │ ├── reduce_ops.py │ │ │ │ ├── repeat.py │ │ │ │ ├── repeat_interleave.py │ │ │ │ ├── roc_auc_score.py │ │ │ │ ├── searchsorted.py │ │ │ │ ├── sort.py │ │ │ │ ├── special_ops.py │ │ │ │ ├── split.py │ │ │ │ ├── swapaxes.py │ │ │ │ ├── swapdims.py │ │ │ │ ├── tensor.py │ │ │ │ ├── tensor_attributes.py │ │ │ │ ├── tensor_ops.py │ │ │ │ ├── tensor_t.py │ │ │ │ ├── tensordot.py │ │ │ │ ├── tile.py │ │ │ │ ├── topk.py │ │ │ │ ├── trigonometric_ops.py │ │ │ │ ├── unbind.py │ │ │ │ ├── util_ops.py │ │ │ │ ├── utils.py │ │ │ │ ├── vision.py │ │ │ │ └── where.py │ │ │ ├── dtype.py │ │ │ ├── env_util.py │ │ │ ├── function_desc.py │ │ │ ├── function_util.py │ │ │ ├── generator.py │ │ │ ├── graph_build_util.py │ │ │ ├── hob.py │ │ │ ├── id_util.py │ │ │ ├── infer_compiler/ │ │ │ │ ├── __init__.py │ │ │ │ ├── import_tools/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── format_utils.py │ │ │ │ │ └── importer.py │ │ │ │ ├── transform/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── builtin_transform.py │ │ │ │ │ ├── custom_transform.py │ │ │ │ │ └── manager.py │ │ │ │ ├── utils/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── args_tree_util.py │ │ │ │ │ ├── cost_util.py │ │ │ │ │ ├── log_utils.py │ │ │ │ │ ├── oneflow_exec_mode.py │ │ │ │ │ ├── param_utils.py │ │ │ │ │ ├── patch_for_compiler.py │ │ │ │ │ └── patch_for_diffusers.py │ │ │ │ ├── with_fx_graph.py │ │ │ │ ├── with_fx_interpreter.py │ │ │ │ ├── with_oneflow_backend.py │ │ │ │ └── with_oneflow_compile.py │ │ │ ├── job_set_util.py │ │ │ ├── model.py │ │ │ ├── multi_client_session.py │ │ │ ├── register_class_method_util.py │ │ │ ├── scope_util.py │ │ │ ├── session_context.py │ │ │ ├── sysconfig.py │ │ │ ├── tensor.py │ │ │ ├── tensor_str.py │ │ │ ├── tensor_str_util.py │ │ │ ├── tensor_tuple_util.py │ │ │ ├── type_tensor.py │ │ │ └── unittest.py │ │ ├── fx/ │ │ │ └── __init__.py │ │ ├── hub.py │ │ ├── ir/ │ │ │ ├── __main__.py │ │ │ ├── ast_gen_transformer.py │ │ │ ├── bisect_transformer.py │ │ │ ├── lr_jit.py │ │ │ ├── math_params_transformer.py │ │ │ └── self_params_transformer.py │ │ ├── jit/ │ │ │ ├── __init__.py │ │ │ └── annotations.py │ │ ├── library.py │ │ ├── linalg.py │ │ ├── mock_torch/ │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ ├── dyn_mock_mod.py │ │ │ ├── mock_importer.py │ │ │ ├── mock_modules.py │ │ │ ├── mock_utils.py │ │ │ └── torch/ │ │ │ └── __init__.py │ │ ├── model.py │ │ ├── multiprocessing/ │ │ │ ├── __init__.py │ │ │ ├── _atfork.py │ │ │ ├── pool.py │ │ │ ├── queue.py │ │ │ ├── 
reductions.py │ │ │ ├── shared_memory/ │ │ │ │ └── __init__.py │ │ │ └── spawn.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── common_types.py │ │ │ ├── functional/ │ │ │ │ ├── __init__.py │ │ │ │ ├── batch_norm.py │ │ │ │ ├── ctc_loss.py │ │ │ │ ├── deform_conv.py │ │ │ │ ├── depend.py │ │ │ │ ├── maxpool.py │ │ │ │ ├── pad.py │ │ │ │ └── softmax.py │ │ │ ├── graph/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cache.py │ │ │ │ ├── graph.py │ │ │ │ ├── graph_block.py │ │ │ │ ├── graph_config.py │ │ │ │ ├── optimizer.py │ │ │ │ ├── proxy.py │ │ │ │ └── util.py │ │ │ ├── image.py │ │ │ ├── init.py │ │ │ ├── modules/ │ │ │ │ ├── __init__.py │ │ │ │ ├── _functions.py │ │ │ │ ├── activation.py │ │ │ │ ├── affine_grid.py │ │ │ │ ├── all_reduce.py │ │ │ │ ├── arange.py │ │ │ │ ├── argsort.py │ │ │ │ ├── argwhere.py │ │ │ │ ├── as_tensor.py │ │ │ │ ├── batchnorm.py │ │ │ │ ├── batchnorm_fused.py │ │ │ │ ├── broadcast_ops.py │ │ │ │ ├── constant.py │ │ │ │ ├── container.py │ │ │ │ ├── conv.py │ │ │ │ ├── dataset.py │ │ │ │ ├── distance.py │ │ │ │ ├── distributed_partial_fc_sample.py │ │ │ │ ├── dropout.py │ │ │ │ ├── einsum.py │ │ │ │ ├── empty.py │ │ │ │ ├── expand.py │ │ │ │ ├── fake_quantization.py │ │ │ │ ├── flatten.py │ │ │ │ ├── fold.py │ │ │ │ ├── fused_mlp.py │ │ │ │ ├── global_cast.py │ │ │ │ ├── grid_sample.py │ │ │ │ ├── instancenorm.py │ │ │ │ ├── interpolate.py │ │ │ │ ├── is_tensor.py │ │ │ │ ├── linear.py │ │ │ │ ├── linspace.py │ │ │ │ ├── logspace.py │ │ │ │ ├── loss.py │ │ │ │ ├── masked_select.py │ │ │ │ ├── math_ops.py │ │ │ │ ├── meshgrid.py │ │ │ │ ├── min_max_observer.py │ │ │ │ ├── module.py │ │ │ │ ├── moving_average_min_max_observer.py │ │ │ │ ├── nms.py │ │ │ │ ├── nonzero.py │ │ │ │ ├── norm.py │ │ │ │ ├── normalization.py │ │ │ │ ├── numel.py │ │ │ │ ├── padding.py │ │ │ │ ├── pixelshuffle.py │ │ │ │ ├── pooling.py │ │ │ │ ├── quantization.py │ │ │ │ ├── reshape.py │ │ │ │ ├── rnn.py │ │ │ │ ├── roll.py │ │ │ │ ├── scatter.py │ │ │ │ ├── slice.py │ │ │ │ ├── sparse.py │ │ │ │ ├── sparse_softmax_cross_entropy.py │ │ │ │ ├── tensor_buffer.py │ │ │ │ ├── tensordot.py │ │ │ │ ├── trigonometric_ops.py │ │ │ │ ├── unique.py │ │ │ │ ├── upsampling.py │ │ │ │ ├── utils.py │ │ │ │ └── where.py │ │ │ ├── optimizer/ │ │ │ │ ├── __init__.py │ │ │ │ ├── adadelta.py │ │ │ │ ├── adagrad.py │ │ │ │ ├── adam.py │ │ │ │ ├── adamw.py │ │ │ │ ├── chained_scheduler.py │ │ │ │ ├── constant_lr.py │ │ │ │ ├── cosine_annealing_lr.py │ │ │ │ ├── cosine_annealing_warm_restarts.py │ │ │ │ ├── cosine_decay_lr.py │ │ │ │ ├── exponential_lr.py │ │ │ │ ├── lamb.py │ │ │ │ ├── lambda_lr.py │ │ │ │ ├── lbfgs.py │ │ │ │ ├── linear_lr.py │ │ │ │ ├── lr_scheduler.py │ │ │ │ ├── multiplicative_lr.py │ │ │ │ ├── multistep_lr.py │ │ │ │ ├── polynomial_lr.py │ │ │ │ ├── reduce_lr_on_plateau.py │ │ │ │ ├── rmsprop.py │ │ │ │ ├── sequential_lr.py │ │ │ │ ├── sgd.py │ │ │ │ ├── step_lr.py │ │ │ │ ├── swa_utils.py │ │ │ │ └── warmup_lr.py │ │ │ ├── parallel/ │ │ │ │ ├── __init__.py │ │ │ │ └── distributed.py │ │ │ ├── parameter.py │ │ │ ├── qat/ │ │ │ │ ├── __init__.py │ │ │ │ └── conv.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── clip_grad.py │ │ │ ├── container.py │ │ │ ├── convert_parameters.py │ │ │ ├── parameters_grouping.py │ │ │ ├── prune.py │ │ │ ├── rnn.py │ │ │ ├── skip_init.py │ │ │ └── weight_norm.py │ │ ├── one_embedding.py │ │ ├── onnx/ │ │ │ ├── __init__.py │ │ │ └── symbolic_helper.py │ │ ├── ops/ │ │ │ ├── __init__.py │ │ │ ├── array_ops.py │ │ │ ├── stateful_ops.py │ │ │ ├── transpose_util.py │ │ │ 
└── util/ │ │ │ ├── __init__.py │ │ │ └── initializer_util.py │ │ ├── optim/ │ │ │ ├── __init__.py │ │ │ ├── lr_scheduler.py │ │ │ ├── optimizer.py │ │ │ └── swa_utils.py │ │ ├── profiler/ │ │ │ ├── __init__.py │ │ │ ├── events.py │ │ │ ├── profiler.py │ │ │ └── util.py │ │ ├── remat/ │ │ │ └── __init__.py │ │ ├── sbp.py │ │ ├── special/ │ │ │ ├── __init__.py │ │ │ └── special_ops.py │ │ ├── support/ │ │ │ ├── __init__.py │ │ │ ├── async_util.py │ │ │ ├── box.py │ │ │ ├── enable_if.py │ │ │ ├── env_var_util.py │ │ │ ├── func_inspect_util.py │ │ │ ├── high_order_bool.py │ │ │ ├── lazy.py │ │ │ ├── pb_util.py │ │ │ ├── scope_stack.py │ │ │ └── traceinfo.py │ │ ├── sysconfig.py │ │ ├── test/ │ │ │ ├── README.md │ │ │ ├── dataloader/ │ │ │ │ ├── data_utils.py │ │ │ │ ├── test_cifar_dataset_multiprocess.py │ │ │ │ ├── test_cifar_dataset_singleprocess.py │ │ │ │ ├── test_fashion_mnist_dataset.py │ │ │ │ ├── test_lenet.py │ │ │ │ ├── test_mnist_dataset.py │ │ │ │ ├── test_numpy_dataset.py │ │ │ │ ├── test_tensor_dataset.py │ │ │ │ └── test_transforms.py │ │ │ ├── exceptions/ │ │ │ │ ├── test_activation.py │ │ │ │ ├── test_add_n_op.py │ │ │ │ ├── test_arg_sort_op.py │ │ │ │ ├── test_array_functor.py │ │ │ │ ├── test_autograd.py │ │ │ │ ├── test_batch_gather_op.py │ │ │ │ ├── test_bias_add_op.py │ │ │ │ ├── test_binary_functor_exception.py │ │ │ │ ├── test_bmm.py │ │ │ │ ├── test_broadcast_ops.py │ │ │ │ ├── test_chunk.py │ │ │ │ ├── test_cosine_similarity.py │ │ │ │ ├── test_deform_conv2d_op.py │ │ │ │ ├── test_device.py │ │ │ │ ├── test_dot.py │ │ │ │ ├── test_error_reported_in_thread.py │ │ │ │ ├── test_gird_sample_op.py │ │ │ │ ├── test_global_branch_error_local_to_global_with_broadcast_sbp_1n2d.py │ │ │ │ ├── test_global_branch_error_local_to_global_with_broadcast_sbp_1n4d.py │ │ │ │ ├── test_global_branch_error_local_to_global_with_split_sbp.py │ │ │ │ ├── test_global_branch_error_with_global_mean.py │ │ │ │ ├── test_hann_window.py │ │ │ │ ├── test_in_top_k.py │ │ │ │ ├── test_inv.py │ │ │ │ ├── test_layernorm.py │ │ │ │ ├── test_linalg.py │ │ │ │ ├── test_local_global_convert_error.py │ │ │ │ ├── test_median.py │ │ │ │ ├── test_mm.py │ │ │ │ ├── test_mode.py │ │ │ │ ├── test_multi_input_with_diff_device_or_placement.py │ │ │ │ ├── test_mv.py │ │ │ │ ├── test_nn_functor.py │ │ │ │ ├── test_optim_add_param_group.py │ │ │ │ ├── test_pad.py │ │ │ │ ├── test_placement.py │ │ │ │ ├── test_randperm_op.py │ │ │ │ ├── test_reduce_like_ops.py │ │ │ │ ├── test_reduce_ops.py │ │ │ │ ├── test_repeat_interleave.py │ │ │ │ ├── test_reshape.py │ │ │ │ ├── test_reshape_like_op.py │ │ │ │ ├── test_roi_align_op.py │ │ │ │ ├── test_save_load.py │ │ │ │ ├── test_saved_tensor_hooks.py │ │ │ │ ├── test_slice_op.py │ │ │ │ ├── test_smooth_l1_loss_op.py │ │ │ │ ├── test_softmax_cross_entropy_op.py │ │ │ │ ├── test_sparse_cross_entropy_op.py │ │ │ │ ├── test_sparse_softmax_cross_entropy_op.py │ │ │ │ ├── test_split_like_op.py │ │ │ │ ├── test_stft_op.py │ │ │ │ ├── test_tensor_index.py │ │ │ │ ├── test_tensordot.py │ │ │ │ ├── test_to_global_error.py │ │ │ │ ├── test_view.py │ │ │ │ └── throw_error.py │ │ │ ├── expensive/ │ │ │ │ ├── README.md │ │ │ │ ├── _internally_replaced_utils.py │ │ │ │ ├── _test_remat.py │ │ │ │ ├── pytorch_alexnet.py │ │ │ │ ├── pytorch_convmixer.py │ │ │ │ ├── pytorch_convnext.py │ │ │ │ ├── pytorch_crossformer.py │ │ │ │ ├── pytorch_densenet.py │ │ │ │ ├── pytorch_efficientnet.py │ │ │ │ ├── pytorch_ghostnet.py │ │ │ │ ├── pytorch_googlenet.py │ │ │ │ ├── pytorch_inception_v3.py │ │ │ │ ├── 
pytorch_levit.py │ │ │ │ ├── pytorch_mnasnet.py │ │ │ │ ├── pytorch_poolformer.py │ │ │ │ ├── pytorch_pvt.py │ │ │ │ ├── pytorch_res2net.py │ │ │ │ ├── pytorch_resmlp.py │ │ │ │ ├── pytorch_resnet.py │ │ │ │ ├── pytorch_rexnet.py │ │ │ │ ├── pytorch_rexnetv1_lite.py │ │ │ │ ├── pytorch_senet.py │ │ │ │ ├── pytorch_shufflenetv2.py │ │ │ │ ├── pytorch_squeezenet.py │ │ │ │ ├── pytorch_swin_transformer.py │ │ │ │ ├── pytorch_uniformer.py │ │ │ │ ├── pytroch_mlp_mixer.py │ │ │ │ ├── resnet50_model.py │ │ │ │ ├── test_compatibility.py │ │ │ │ ├── test_conv3d.py │ │ │ │ ├── test_convtranspose.py │ │ │ │ ├── test_dynamic_allocation_gradient_shuffle.py │ │ │ │ ├── test_einsum.py │ │ │ │ ├── test_global_tensor_offload.py │ │ │ │ ├── test_graph_multi_graph_v2.py │ │ │ │ ├── test_id_shuffle.py │ │ │ │ ├── test_id_shuffle_global.py │ │ │ │ ├── test_layernorm.py │ │ │ │ ├── test_oneembedding.py │ │ │ │ ├── test_oneembedding_padding_idx.py │ │ │ │ ├── test_permute.py │ │ │ │ ├── test_remat.py │ │ │ │ ├── test_resnet50_with_bn.py │ │ │ │ ├── test_resnet50_without_bn.py │ │ │ │ ├── test_rnn.py │ │ │ │ ├── test_rnn_cell.py │ │ │ │ ├── test_rnn_pack_sequence.py │ │ │ │ ├── test_rnn_utils.py │ │ │ │ ├── test_sqrt_square_sum.py │ │ │ │ ├── test_tensor_offload.py │ │ │ │ ├── test_tensor_str.py │ │ │ │ └── test_util.py │ │ │ ├── gen_ops_process.py │ │ │ ├── graph/ │ │ │ │ ├── alexnet_model.py │ │ │ │ ├── ofrecord_data_utils.py │ │ │ │ ├── optimizer_test_util.py │ │ │ │ ├── test_alexnet_auto_parallel.py │ │ │ │ ├── test_alexnet_graph.py │ │ │ │ ├── test_comb1to2d.py │ │ │ │ ├── test_comb2d.py │ │ │ │ ├── test_forward_graph.py │ │ │ │ ├── test_free_tensor_not_in_job.py │ │ │ │ ├── test_fx_fuse.py │ │ │ │ ├── test_fx_replace_ops.py │ │ │ │ ├── test_fx_symbolic_trace_module.py │ │ │ │ ├── test_gbc1to2d.py │ │ │ │ ├── test_gbc2d.py │ │ │ │ ├── test_gbc2to1d.py │ │ │ │ ├── test_gbc2to2d.py │ │ │ │ ├── test_graph.py │ │ │ │ ├── test_graph_activation_checkpoint.py │ │ │ │ ├── test_graph_arange.py │ │ │ │ ├── test_graph_asymmetric_io.py │ │ │ │ ├── test_graph_block.py │ │ │ │ ├── test_graph_buffer_limit.py │ │ │ │ ├── test_graph_clip_grad_norm.py │ │ │ │ ├── test_graph_copy.py │ │ │ │ ├── test_graph_debug.py │ │ │ │ ├── test_graph_depend.py │ │ │ │ ├── test_graph_eye.py │ │ │ │ ├── test_graph_free_eager_tensor.py │ │ │ │ ├── test_graph_grad_acc.py │ │ │ │ ├── test_graph_image_gpu_decoder.py │ │ │ │ ├── test_graph_inplace_add.py │ │ │ │ ├── test_graph_io_check.py │ │ │ │ ├── test_graph_linear.py │ │ │ │ ├── test_graph_linear_train.py │ │ │ │ ├── test_graph_loss.py │ │ │ │ ├── test_graph_lr_scale.py │ │ │ │ ├── test_graph_lr_scheduler.py │ │ │ │ ├── test_graph_lr_with_warmup.py │ │ │ │ ├── test_graph_lrs.py │ │ │ │ ├── test_graph_masked_fill.py │ │ │ │ ├── test_graph_nccl_logical_fusion.py │ │ │ │ ├── test_graph_non_contiguous_tensors.py │ │ │ │ ├── test_graph_normal_inplace.py │ │ │ │ ├── test_graph_ofrecord_reader.py │ │ │ │ ├── test_graph_optim_adadelta.py │ │ │ │ ├── test_graph_optim_adagrad.py │ │ │ │ ├── test_graph_optim_adam.py │ │ │ │ ├── test_graph_optim_adamw.py │ │ │ │ ├── test_graph_optim_ftrl.py │ │ │ │ ├── test_graph_optim_lamb.py │ │ │ │ ├── test_graph_optim_rmsprop.py │ │ │ │ ├── test_graph_optim_sgd.py │ │ │ │ ├── test_graph_optimizer.py │ │ │ │ ├── test_graph_pipeline.py │ │ │ │ ├── test_graph_pipeline_delay.py │ │ │ │ ├── test_graph_random_seed.py │ │ │ │ ├── test_graph_relu.py │ │ │ │ ├── test_graph_reshape_acc.py │ │ │ │ ├── test_graph_reuse_var.py │ │ │ │ ├── test_graph_save_load.py │ │ │ │ ├── 
test_graph_save_load_global_b_s.py │ │ │ │ ├── test_graph_scalar.py │ │ │ │ ├── test_graph_separate_compile.py │ │ │ │ ├── test_graph_session_env_destruct.py │ │ │ │ ├── test_graph_session_env_destruct1.py │ │ │ │ ├── test_graph_sparse_optimizer.py │ │ │ │ ├── test_graph_sparse_softmax_cross_entropy.py │ │ │ │ ├── test_graph_tensor_clone.py │ │ │ │ ├── test_graph_tensor_detach.py │ │ │ │ ├── test_graph_with_global.py │ │ │ │ ├── test_graph_zero.py │ │ │ │ ├── test_input_op_expr.py │ │ │ │ ├── test_long_add_n_pass.py │ │ │ │ ├── test_modify_module_forward.py │ │ │ │ ├── test_multi_client_session.py │ │ │ │ ├── test_multi_graph.py │ │ │ │ ├── test_multi_tensor_adam_update_with_cast.py │ │ │ │ ├── test_multi_tensor_sgd_update_with_cast.py │ │ │ │ ├── test_nccl_logical_send_recv.py │ │ │ │ ├── test_neq_device_process_num.py │ │ │ │ ├── test_oneflow_compiler.py │ │ │ │ ├── test_optimization_conf.py │ │ │ │ ├── test_output_op_expr.py │ │ │ │ ├── test_run_global_graph_by_vm.py │ │ │ │ ├── test_run_graph_by_vm.py │ │ │ │ ├── test_to_global.py │ │ │ │ ├── test_tvm_frontend_dependency_on_graph.py │ │ │ │ ├── test_user_op_expr.py │ │ │ │ ├── test_util.py │ │ │ │ └── test_variable_op_expr.py │ │ │ ├── misc/ │ │ │ │ ├── mock_example.py │ │ │ │ ├── test_autograd_functional.py │ │ │ │ ├── test_distributed_env_vars.py │ │ │ │ ├── test_empty_cache.py │ │ │ │ ├── test_env_cuda.py │ │ │ │ ├── test_manual_seed_api.py │ │ │ │ ├── test_mock_diffusers.py │ │ │ │ ├── test_mock_scope.py │ │ │ │ ├── test_np_dtype_converter.py │ │ │ │ ├── test_placement.py │ │ │ │ └── test_pybind11_caster.py │ │ │ ├── modules/ │ │ │ │ ├── image_test_util.py │ │ │ │ ├── optimizer_test_util.py │ │ │ │ ├── save_load_test_data/ │ │ │ │ │ ├── 3x3_i3o3_conv2d/ │ │ │ │ │ │ ├── pickled_data │ │ │ │ │ │ ├── tensor_3/ │ │ │ │ │ │ │ ├── meta │ │ │ │ │ │ │ └── out │ │ │ │ │ │ └── tensor_4/ │ │ │ │ │ │ ├── meta │ │ │ │ │ │ └── out │ │ │ │ │ └── 3x3_i3o3_conv2d_params/ │ │ │ │ │ ├── pickled_data │ │ │ │ │ ├── tensor_5/ │ │ │ │ │ │ ├── meta │ │ │ │ │ │ └── out │ │ │ │ │ └── tensor_6/ │ │ │ │ │ ├── meta │ │ │ │ │ └── out │ │ │ │ ├── sync_batchnorm_test_util.py │ │ │ │ ├── test_0_dim_tensor.py │ │ │ │ ├── test_TripletMarginLoss.py │ │ │ │ ├── test_abs.py │ │ │ │ ├── test_activation.py │ │ │ │ ├── test_adaptive_max_pool.py │ │ │ │ ├── test_adaptive_pool.py │ │ │ │ ├── test_adaptive_pool_fp16.py │ │ │ │ ├── test_add.py │ │ │ │ ├── test_addcdiv.py │ │ │ │ ├── test_addcmul.py │ │ │ │ ├── test_addmm.py │ │ │ │ ├── test_affine_grid.py │ │ │ │ ├── test_allclose.py │ │ │ │ ├── test_allreduce.py │ │ │ │ ├── test_amax.py │ │ │ │ ├── test_amin.py │ │ │ │ ├── test_arange.py │ │ │ │ ├── test_argmax.py │ │ │ │ ├── test_argmin.py │ │ │ │ ├── test_argsort.py │ │ │ │ ├── test_argwhere.py │ │ │ │ ├── test_as_strided.py │ │ │ │ ├── test_as_tensor.py │ │ │ │ ├── test_asyncs_thread.py │ │ │ │ ├── test_atleast.py │ │ │ │ ├── test_auto_to_global.py │ │ │ │ ├── test_autograd.py │ │ │ │ ├── test_autograd_function.py │ │ │ │ ├── test_autograd_mode.py │ │ │ │ ├── test_avgpool.py │ │ │ │ ├── test_baddbmm.py │ │ │ │ ├── test_batch_gather.py │ │ │ │ ├── test_batchnorm.py │ │ │ │ ├── test_batchnorm_add_relu.py │ │ │ │ ├── test_bernoulli.py │ │ │ │ ├── test_binary_math_ops_dtype.py │ │ │ │ ├── test_bincount.py │ │ │ │ ├── test_bitwise.py │ │ │ │ ├── test_bmm.py │ │ │ │ ├── test_broadcast_like.py │ │ │ │ ├── test_broadcast_ops.py │ │ │ │ ├── test_cast.py │ │ │ │ ├── test_ceil.py │ │ │ │ ├── test_check_meta_consistency.py │ │ │ │ ├── test_checkpointing.py │ │ │ │ ├── test_chunk.py │ 
│ │ │ ├── test_clamp.py │ │ │ │ ├── test_clip_grad.py │ │ │ │ ├── test_clone.py │ │ │ │ ├── test_coco_reader.py │ │ │ │ ├── test_coin_flip.py │ │ │ │ ├── test_comb2to2d.py │ │ │ │ ├── test_combined_margin_loss.py │ │ │ │ ├── test_comm.py │ │ │ │ ├── test_comm_ops.py │ │ │ │ ├── test_concat.py │ │ │ │ ├── test_constant.py │ │ │ │ ├── test_constant_pad.py │ │ │ │ ├── test_contiguous.py │ │ │ │ ├── test_conv1d.py │ │ │ │ ├── test_conv2d.py │ │ │ │ ├── test_copy.py │ │ │ │ ├── test_cosine_similarity.py │ │ │ │ ├── test_ctc_greedy_decoder.py │ │ │ │ ├── test_ctc_loss.py │ │ │ │ ├── test_cublas_fused_mlp.py │ │ │ │ ├── test_cum_ops.py │ │ │ │ ├── test_dataset.py │ │ │ │ ├── test_ddp.py │ │ │ │ ├── test_ddp_multi_outputs.py │ │ │ │ ├── test_deconv2d.py │ │ │ │ ├── test_default_dtype.py │ │ │ │ ├── test_deform_conv2d.py │ │ │ │ ├── test_det.py │ │ │ │ ├── test_diag.py │ │ │ │ ├── test_diagonal.py │ │ │ │ ├── test_div.py │ │ │ │ ├── test_dlpack.py │ │ │ │ ├── test_dot.py │ │ │ │ ├── test_dropout.py │ │ │ │ ├── test_dynamic_allocation_gradient_shuffle_shuffle_global.py │ │ │ │ ├── test_eager_boxing.py │ │ │ │ ├── test_eager_boxing_exhaustive.py │ │ │ │ ├── test_empty.py │ │ │ │ ├── test_eq.py │ │ │ │ ├── test_equal.py │ │ │ │ ├── test_erf.py │ │ │ │ ├── test_erfc.py │ │ │ │ ├── test_erfinv.py │ │ │ │ ├── test_expand.py │ │ │ │ ├── test_expand_stride.py │ │ │ │ ├── test_expm1.py │ │ │ │ ├── test_eye.py │ │ │ │ ├── test_fake_quantization.py │ │ │ │ ├── test_fft.py │ │ │ │ ├── test_flatten.py │ │ │ │ ├── test_flip.py │ │ │ │ ├── test_floor.py │ │ │ │ ├── test_fmod.py │ │ │ │ ├── test_fold.py │ │ │ │ ├── test_fork_sub_process.py │ │ │ │ ├── test_frac.py │ │ │ │ ├── test_from_numpy.py │ │ │ │ ├── test_from_torch.py │ │ │ │ ├── test_functional_docstr.py │ │ │ │ ├── test_functional_scalar_tensor_param.py │ │ │ │ ├── test_fused_attention_ops.py │ │ │ │ ├── test_fused_bias_add_dropout.py │ │ │ │ ├── test_fused_bias_add_gelu.py │ │ │ │ ├── test_fused_bias_add_scale_mask_softmax_dropout.py │ │ │ │ ├── test_fused_center.py │ │ │ │ ├── test_fused_codegeex_qkv_reshape.py │ │ │ │ ├── test_fused_cross_interaction.py │ │ │ │ ├── test_fused_dot_feature_interaction.py │ │ │ │ ├── test_fused_gelu_mul.py │ │ │ │ ├── test_fused_get_boundding_boxes_coord.py │ │ │ │ ├── test_fused_get_ciou_diagonal_angle.py │ │ │ │ ├── test_fused_get_ciou_result.py │ │ │ │ ├── test_fused_get_convex_diagonal_squared.py │ │ │ │ ├── test_fused_get_intersection_area.py │ │ │ │ ├── test_fused_get_iou.py │ │ │ │ ├── test_fused_glu.py │ │ │ │ ├── test_fused_matmul_bias.py │ │ │ │ ├── test_fused_matmul_bias_add_relu_dropout.py │ │ │ │ ├── test_fused_rotary_embedding.py │ │ │ │ ├── test_fused_scale_mask_bias_softmax.py │ │ │ │ ├── test_fused_scale_mask_softmax.py │ │ │ │ ├── test_fused_scale_mask_softmax_dropout.py │ │ │ │ ├── test_fused_scale_tril.py │ │ │ │ ├── test_fused_self_attention.py │ │ │ │ ├── test_fused_tril_softmax_mask_scale.py │ │ │ │ ├── test_fused_weighted_sum.py │ │ │ │ ├── test_gather.py │ │ │ │ ├── test_gather_nd.py │ │ │ │ ├── test_gelu_approximate.py │ │ │ │ ├── test_generator.py │ │ │ │ ├── test_global_0_dim_tensor.py │ │ │ │ ├── test_global_TripletMarginLoss.py │ │ │ │ ├── test_global_abs.py │ │ │ │ ├── test_global_activation.py │ │ │ │ ├── test_global_adaptive_pool.py │ │ │ │ ├── test_global_add.py │ │ │ │ ├── test_global_addcdiv.py │ │ │ │ ├── test_global_addcmul.py │ │ │ │ ├── test_global_addmm.py │ │ │ │ ├── test_global_affine_grid.py │ │ │ │ ├── test_global_argmax.py │ │ │ │ ├── test_global_argmin.py │ │ │ │ ├── 
test_global_argsort.py │ │ │ │ ├── test_global_argwhere.py │ │ │ │ ├── test_global_atleast.py │ │ │ │ ├── test_global_avgpool.py │ │ │ │ ├── test_global_batch_gather.py │ │ │ │ ├── test_global_bincount.py │ │ │ │ ├── test_global_bitwise.py │ │ │ │ ├── test_global_broadcase_like.py │ │ │ │ ├── test_global_broadcast_matmul.py │ │ │ │ ├── test_global_broadcast_ops.py │ │ │ │ ├── test_global_cast.py │ │ │ │ ├── test_global_chunk.py │ │ │ │ ├── test_global_clone.py │ │ │ │ ├── test_global_coin_flip.py │ │ │ │ ├── test_global_concat.py │ │ │ │ ├── test_global_constant.py │ │ │ │ ├── test_global_ctc_loss.py │ │ │ │ ├── test_global_cumprod.py │ │ │ │ ├── test_global_cumsum.py │ │ │ │ ├── test_global_deconv2d.py │ │ │ │ ├── test_global_deform_conv2d.py │ │ │ │ ├── test_global_det.py │ │ │ │ ├── test_global_diag.py │ │ │ │ ├── test_global_diagonal.py │ │ │ │ ├── test_global_div.py │ │ │ │ ├── test_global_dot.py │ │ │ │ ├── test_global_dropout.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase1.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase10.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase11.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase2.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase3.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase4.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase5.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase6.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase7.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase8.py │ │ │ │ ├── test_global_einsum_alphaflod_usecase9.py │ │ │ │ ├── test_global_einsum_attention.py │ │ │ │ ├── test_global_einsum_batch_matmul.py │ │ │ │ ├── test_global_einsum_batch_matmul2.py │ │ │ │ ├── test_global_einsum_batch_matmul3.py │ │ │ │ ├── test_global_einsum_batch_matmul4.py │ │ │ │ ├── test_global_einsum_batch_matrix_vector_multiply.py │ │ │ │ ├── test_global_einsum_batch_permute.py │ │ │ │ ├── test_global_einsum_bilinear_transformation.py │ │ │ │ ├── test_global_einsum_eltwise_mul_sum_row.py │ │ │ │ ├── test_global_einsum_eltwise_mul_then_reduce_sum.py │ │ │ │ ├── test_global_einsum_eltwise_multiply.py │ │ │ │ ├── test_global_einsum_get_diagonal.py │ │ │ │ ├── test_global_einsum_matmul.py │ │ │ │ ├── test_global_einsum_matmul2.py │ │ │ │ ├── test_global_einsum_matrix_column_sum.py │ │ │ │ ├── test_global_einsum_matrix_transpose.py │ │ │ │ ├── test_global_einsum_matrix_vector_multiply.py │ │ │ │ ├── test_global_einsum_reduce_sum.py │ │ │ │ ├── test_global_einsum_tensor_contraction.py │ │ │ │ ├── test_global_einsum_tensor_contraction2.py │ │ │ │ ├── test_global_einsum_vector_inner_product.py │ │ │ │ ├── test_global_einsum_vector_outer_product.py │ │ │ │ ├── test_global_empty.py │ │ │ │ ├── test_global_eq.py │ │ │ │ ├── test_global_erf.py │ │ │ │ ├── test_global_erfc.py │ │ │ │ ├── test_global_expand_op.py │ │ │ │ ├── test_global_expm1.py │ │ │ │ ├── test_global_eye.py │ │ │ │ ├── test_global_fill.py │ │ │ │ ├── test_global_flatten.py │ │ │ │ ├── test_global_flip.py │ │ │ │ ├── test_global_floor.py │ │ │ │ ├── test_global_fmod.py │ │ │ │ ├── test_global_fold.py │ │ │ │ ├── test_global_frac.py │ │ │ │ ├── test_global_full.py │ │ │ │ ├── test_global_full_like.py │ │ │ │ ├── test_global_greater.py │ │ │ │ ├── test_global_greater_equal.py │ │ │ │ ├── test_global_grid_sample.py │ │ │ │ ├── test_global_groupnorm.py │ │ │ │ ├── test_global_gru_cell.py │ │ │ │ ├── test_global_hann_window.py │ │ │ │ ├── test_global_higher_derivative_activation.py │ │ │ │ ├── test_global_higher_derivative_conv.py │ │ │ │ ├── 
test_global_higher_derivative_div.py │ │ │ │ ├── test_global_higher_derivative_loss.py │ │ │ │ ├── test_global_higher_derivative_matmul.py │ │ │ │ ├── test_global_higher_derivative_neg.py │ │ │ │ ├── test_global_higher_derivative_pool.py │ │ │ │ ├── test_global_higher_derivative_pow.py │ │ │ │ ├── test_global_higher_derivative_scalar_pow.py │ │ │ │ ├── test_global_higher_derivative_slice.py │ │ │ │ ├── test_global_higher_derivative_softmax.py │ │ │ │ ├── test_global_inv.py │ │ │ │ ├── test_global_lerp.py │ │ │ │ ├── test_global_linalg_cross.py │ │ │ │ ├── test_global_linear.py │ │ │ │ ├── test_global_linspace.py │ │ │ │ ├── test_global_logspace.py │ │ │ │ ├── test_global_lstm_cell.py │ │ │ │ ├── test_global_masked_fill.py │ │ │ │ ├── test_global_masked_select.py │ │ │ │ ├── test_global_math_op_higher_derivative.py │ │ │ │ ├── test_global_math_ops.py │ │ │ │ ├── test_global_matmul.py │ │ │ │ ├── test_global_max.py │ │ │ │ ├── test_global_maximum_minimum.py │ │ │ │ ├── test_global_maxpool.py │ │ │ │ ├── test_global_maxunpool.py │ │ │ │ ├── test_global_mean.py │ │ │ │ ├── test_global_median.py │ │ │ │ ├── test_global_meshgrid.py │ │ │ │ ├── test_global_min.py │ │ │ │ ├── test_global_min_max_observer.py │ │ │ │ ├── test_global_movedim.py │ │ │ │ ├── test_global_moving_average_max_min_observer.py │ │ │ │ ├── test_global_mul.py │ │ │ │ ├── test_global_mv.py │ │ │ │ ├── test_global_nansum.py │ │ │ │ ├── test_global_narrow.py │ │ │ │ ├── test_global_ne.py │ │ │ │ ├── test_global_negative.py │ │ │ │ ├── test_global_nms.py │ │ │ │ ├── test_global_normal.py │ │ │ │ ├── test_global_normalize.py │ │ │ │ ├── test_global_nozero.py │ │ │ │ ├── test_global_ones_like.py │ │ │ │ ├── test_global_pad.py │ │ │ │ ├── test_global_partical_fc.py │ │ │ │ ├── test_global_permute.py │ │ │ │ ├── test_global_rand.py │ │ │ │ ├── test_global_randint.py │ │ │ │ ├── test_global_randint_like.py │ │ │ │ ├── test_global_randn.py │ │ │ │ ├── test_global_random_op_data.py │ │ │ │ ├── test_global_randperm.py │ │ │ │ ├── test_global_reciprocal.py │ │ │ │ ├── test_global_reflection_pad2d.py │ │ │ │ ├── test_global_repeat.py │ │ │ │ ├── test_global_replication_pad2d.py │ │ │ │ ├── test_global_reshape.py │ │ │ │ ├── test_global_rnn.py │ │ │ │ ├── test_global_rnn_cell.py │ │ │ │ ├── test_global_roi_align.py │ │ │ │ ├── test_global_roll.py │ │ │ │ ├── test_global_round.py │ │ │ │ ├── test_global_scatter_nd.py │ │ │ │ ├── test_global_scatter_ops.py │ │ │ │ ├── test_global_searchsorted.py │ │ │ │ ├── test_global_sign.py │ │ │ │ ├── test_global_slice.py │ │ │ │ ├── test_global_slice_update.py │ │ │ │ ├── test_global_sort.py │ │ │ │ ├── test_global_sparse.py │ │ │ │ ├── test_global_sparse_softmax_cross_entropy.py │ │ │ │ ├── test_global_split.py │ │ │ │ ├── test_global_sqrt_square_sum.py │ │ │ │ ├── test_global_squeeze.py │ │ │ │ ├── test_global_stack.py │ │ │ │ ├── test_global_stateful_kernel_with_cache.py │ │ │ │ ├── test_global_std.py │ │ │ │ ├── test_global_sub.py │ │ │ │ ├── test_global_sum.py │ │ │ │ ├── test_global_tensor_new.py │ │ │ │ ├── test_global_tensor_ops.py │ │ │ │ ├── test_global_tensor_scatter_nd_update.py │ │ │ │ ├── test_global_tensordot.py │ │ │ │ ├── test_global_tile.py │ │ │ │ ├── test_global_transpose.py │ │ │ │ ├── test_global_tril.py │ │ │ │ ├── test_global_triu.py │ │ │ │ ├── test_global_unbind.py │ │ │ │ ├── test_global_unfold.py │ │ │ │ ├── test_global_unfold_tensor.py │ │ │ │ ├── test_global_unique.py │ │ │ │ ├── test_global_unsqueeze.py │ │ │ │ ├── test_global_upsample.py │ │ │ │ ├── test_global_var.py │ │ │ 
│ ├── test_global_vector_matrix_product.py │ │ │ │ ├── test_global_view.py │ │ │ │ ├── test_global_weight_norm.py │ │ │ │ ├── test_global_where.py │ │ │ │ ├── test_global_zeropad2d.py │ │ │ │ ├── test_global_zeros_like.py │ │ │ │ ├── test_glu.py │ │ │ │ ├── test_gpt_data_loader.py │ │ │ │ ├── test_greater.py │ │ │ │ ├── test_greater_equal.py │ │ │ │ ├── test_grid_sample.py │ │ │ │ ├── test_grouped_matmul_bias.py │ │ │ │ ├── test_groupnorm.py │ │ │ │ ├── test_groupwise_quantization.py │ │ │ │ ├── test_gumbel_softmax.py │ │ │ │ ├── test_hann_window.py │ │ │ │ ├── test_higher_derivative_activation.py │ │ │ │ ├── test_higher_derivative_conv.py │ │ │ │ ├── test_higher_derivative_div.py │ │ │ │ ├── test_higher_derivative_loss.py │ │ │ │ ├── test_higher_derivative_matmul.py │ │ │ │ ├── test_higher_derivative_neg.py │ │ │ │ ├── test_higher_derivative_pool.py │ │ │ │ ├── test_higher_derivative_pow.py │ │ │ │ ├── test_higher_derivative_scalar_pow.py │ │ │ │ ├── test_higher_derivative_slice.py │ │ │ │ ├── test_higher_derivative_softmax.py │ │ │ │ ├── test_host_memory_input.py │ │ │ │ ├── test_hsplit.py │ │ │ │ ├── test_hub.py │ │ │ │ ├── test_image_batch_align.py │ │ │ │ ├── test_image_decode.py │ │ │ │ ├── test_image_flip.py │ │ │ │ ├── test_image_normalize.py │ │ │ │ ├── test_image_resize.py │ │ │ │ ├── test_in_top_k.py │ │ │ │ ├── test_index_add.py │ │ │ │ ├── test_index_select.py │ │ │ │ ├── test_info.py │ │ │ │ ├── test_initializer.py │ │ │ │ ├── test_instancenorm.py │ │ │ │ ├── test_interpolate.py │ │ │ │ ├── test_inv.py │ │ │ │ ├── test_isclose.py │ │ │ │ ├── test_jit_script_api.py │ │ │ │ ├── test_layer_norm.py │ │ │ │ ├── test_lerp.py │ │ │ │ ├── test_less.py │ │ │ │ ├── test_less_equal.py │ │ │ │ ├── test_linalg_cross.py │ │ │ │ ├── test_linear.py │ │ │ │ ├── test_linspace.py │ │ │ │ ├── test_log1p.py │ │ │ │ ├── test_logaddexp.py │ │ │ │ ├── test_logical_and.py │ │ │ │ ├── test_logical_not.py │ │ │ │ ├── test_logical_or.py │ │ │ │ ├── test_logical_reduce.py │ │ │ │ ├── test_logical_xor.py │ │ │ │ ├── test_logspace.py │ │ │ │ ├── test_logsumexp.py │ │ │ │ ├── test_loss.py │ │ │ │ ├── test_loss_global.py │ │ │ │ ├── test_lr_scheduler.py │ │ │ │ ├── test_masked_fill.py │ │ │ │ ├── test_masked_select.py │ │ │ │ ├── test_math_op_higher_derivative.py │ │ │ │ ├── test_math_ops.py │ │ │ │ ├── test_matmul.py │ │ │ │ ├── test_max.py │ │ │ │ ├── test_maxpool.py │ │ │ │ ├── test_maxunpool.py │ │ │ │ ├── test_mean.py │ │ │ │ ├── test_median.py │ │ │ │ ├── test_meshgrid.py │ │ │ │ ├── test_min.py │ │ │ │ ├── test_min_max_observer.py │ │ │ │ ├── test_mock.py │ │ │ │ ├── test_mode.py │ │ │ │ ├── test_module.py │ │ │ │ ├── test_module_to.py │ │ │ │ ├── test_module_to_global_or_local.py │ │ │ │ ├── test_module_to_half.py │ │ │ │ ├── test_movedim.py │ │ │ │ ├── test_moving_average_min_max_observer.py │ │ │ │ ├── test_mul.py │ │ │ │ ├── test_multi_tensor_yolov5_weight_update.py │ │ │ │ ├── test_multinomial.py │ │ │ │ ├── test_nansum.py │ │ │ │ ├── test_narrow.py │ │ │ │ ├── test_ne.py │ │ │ │ ├── test_negative.py │ │ │ │ ├── test_nll_loss.py │ │ │ │ ├── test_nms.py │ │ │ │ ├── test_noncontiguous_binary_op.py │ │ │ │ ├── test_nonzero.py │ │ │ │ ├── test_norm.py │ │ │ │ ├── test_normalize.py │ │ │ │ ├── test_ofrecord_reader.py │ │ │ │ ├── test_one_embedding_adagrad.py │ │ │ │ ├── test_one_embedding_adam.py │ │ │ │ ├── test_one_embedding_ftrl.py │ │ │ │ ├── test_one_embedding_sgd.py │ │ │ │ ├── test_one_hot.py │ │ │ │ ├── test_ones_like.py │ │ │ │ ├── test_optim_adadelta.py │ │ │ │ ├── test_optim_adagrad.py │ │ 
│ │ ├── test_optim_adam.py │ │ │ │ ├── test_optim_adamw.py │ │ │ │ ├── test_optim_add_param_group.py │ │ │ │ ├── test_optim_ftrl.py │ │ │ │ ├── test_optim_lamb.py │ │ │ │ ├── test_optim_lbfgs.py │ │ │ │ ├── test_optim_rmsprop.py │ │ │ │ ├── test_optim_sgd.py │ │ │ │ ├── test_pairwise_distance.py │ │ │ │ ├── test_param_group.py │ │ │ │ ├── test_parameters_grouping.py │ │ │ │ ├── test_parital_fc.py │ │ │ │ ├── test_pixel_shuffle.py │ │ │ │ ├── test_prelu.py │ │ │ │ ├── test_prod.py │ │ │ │ ├── test_pruning.py │ │ │ │ ├── test_qat_conv_modules.py │ │ │ │ ├── test_quantile.py │ │ │ │ ├── test_quantization.py │ │ │ │ ├── test_quick_gelu.py │ │ │ │ ├── test_rand.py │ │ │ │ ├── test_randint.py │ │ │ │ ├── test_randint_like.py │ │ │ │ ├── test_randn.py │ │ │ │ ├── test_randn_like.py │ │ │ │ ├── test_random_generator_and_seed.py │ │ │ │ ├── test_randperm.py │ │ │ │ ├── test_reciprocal.py │ │ │ │ ├── test_reduce.py │ │ │ │ ├── test_reduce_sum_like.py │ │ │ │ ├── test_reflection_pad.py │ │ │ │ ├── test_repeat.py │ │ │ │ ├── test_repeat_interleave.py │ │ │ │ ├── test_replication_pad.py │ │ │ │ ├── test_reshape.py │ │ │ │ ├── test_reshape_sbp.py │ │ │ │ ├── test_resnet_load_torch_weight_compatibile.py │ │ │ │ ├── test_rmsnorm.py │ │ │ │ ├── test_roc_auc_score.py │ │ │ │ ├── test_roi_align.py │ │ │ │ ├── test_roll.py │ │ │ │ ├── test_round.py │ │ │ │ ├── test_rrelu.py │ │ │ │ ├── test_save_load.py │ │ │ │ ├── test_saved_tensor_hooks.py │ │ │ │ ├── test_sbp_symbol.py │ │ │ │ ├── test_scatter_nd.py │ │ │ │ ├── test_scatter_ops.py │ │ │ │ ├── test_searchsorted.py │ │ │ │ ├── test_select.py │ │ │ │ ├── test_shutting_down.py │ │ │ │ ├── test_sign.py │ │ │ │ ├── test_single_threaded_vm.py │ │ │ │ ├── test_skip_layer_norm.py │ │ │ │ ├── test_skip_rms_norm.py │ │ │ │ ├── test_slice.py │ │ │ │ ├── test_softmax.py │ │ │ │ ├── test_softplus.py │ │ │ │ ├── test_sort.py │ │ │ │ ├── test_sparse.py │ │ │ │ ├── test_sparse_softmax_cross_entropy.py │ │ │ │ ├── test_special_ops.py │ │ │ │ ├── test_split.py │ │ │ │ ├── test_square_relu.py │ │ │ │ ├── test_squeeze.py │ │ │ │ ├── test_stack.py │ │ │ │ ├── test_stateful_kernel_with_cache.py │ │ │ │ ├── test_stateful_local_opkernel.py │ │ │ │ ├── test_std.py │ │ │ │ ├── test_stft.py │ │ │ │ ├── test_sub.py │ │ │ │ ├── test_sum.py │ │ │ │ ├── test_swapaxes.py │ │ │ │ ├── test_swapdims.py │ │ │ │ ├── test_swautils.py │ │ │ │ ├── test_sync_and_async_allreduce.py │ │ │ │ ├── test_sync_batchnorm.py │ │ │ │ ├── test_t.py │ │ │ │ ├── test_t5_layernorm.py │ │ │ │ ├── test_tensor_buffer.py │ │ │ │ ├── test_tensor_ops.py │ │ │ │ ├── test_tensor_scatter_nd_update.py │ │ │ │ ├── test_tensor_split.py │ │ │ │ ├── test_tensor_to.py │ │ │ │ ├── test_tensordot.py │ │ │ │ ├── test_tile.py │ │ │ │ ├── test_to_torch.py │ │ │ │ ├── test_topk.py │ │ │ │ ├── test_transpose.py │ │ │ │ ├── test_tril.py │ │ │ │ ├── test_triu.py │ │ │ │ ├── test_trunc.py │ │ │ │ ├── test_trunc_divide.py │ │ │ │ ├── test_type_tensor.py │ │ │ │ ├── test_unbind.py │ │ │ │ ├── test_unfold.py │ │ │ │ ├── test_unfold_tensor.py │ │ │ │ ├── test_unique.py │ │ │ │ ├── test_unsqueeze.py │ │ │ │ ├── test_upsample.py │ │ │ │ ├── test_util_ops.py │ │ │ │ ├── test_utils.py │ │ │ │ ├── test_var.py │ │ │ │ ├── test_view.py │ │ │ │ ├── test_vsplit.py │ │ │ │ ├── test_weight_norm.py │ │ │ │ ├── test_where.py │ │ │ │ └── test_zeropad2d.py │ │ │ ├── profiler/ │ │ │ │ ├── test_events.py │ │ │ │ └── test_profile_lenet.py │ │ │ └── tensor/ │ │ │ ├── test_autocast.py │ │ │ ├── test_bfloat16_activation.py │ │ │ ├── test_complex.py │ │ │ ├── 
test_data_ptr.py │ │ │ ├── test_global_tensor.py │ │ │ ├── test_global_tensor_and_ndarray_compatibility.py │ │ │ ├── test_global_tensor_indexing.py │ │ │ ├── test_lazy_tensor_indexing.py │ │ │ ├── test_meta_tensor.py │ │ │ ├── test_new_tensor.py │ │ │ ├── test_parameter.py │ │ │ ├── test_safetensors.py │ │ │ ├── test_tensor_and_ndarray_compatibility.py │ │ │ ├── test_tensor_exponential.py │ │ │ ├── test_tensor_indexing.py │ │ │ ├── test_tensor_indexing2.py │ │ │ ├── test_tensor_is_view.py │ │ │ ├── test_tensor_part_1.py │ │ │ ├── test_tensor_part_2.py │ │ │ ├── test_tensor_part_3.py │ │ │ ├── test_tensor_pin_memory.py │ │ │ └── test_tensor_to_memory_format.py │ │ ├── test_utils/ │ │ │ ├── __init__.py │ │ │ ├── automated_test_util/ │ │ │ │ ├── __init__.py │ │ │ │ ├── generators.py │ │ │ │ ├── global_scope.py │ │ │ │ ├── profiler.py │ │ │ │ ├── torch_flow_dual_object.py │ │ │ │ └── util.py │ │ │ ├── oneflow_pytorch_compatibility/ │ │ │ │ ├── __init__.py │ │ │ │ └── oneflow_pytorch_compatiblity_test.py │ │ │ ├── test_util.py │ │ │ └── throttle.py │ │ ├── unittest/ │ │ │ ├── __init__.py │ │ │ ├── dataset.py │ │ │ ├── env.py │ │ │ └── mlir.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── checkpoint.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── _utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── collate.py │ │ │ │ ├── fetch.py │ │ │ │ ├── pin_memory.py │ │ │ │ ├── signal_handling.py │ │ │ │ └── worker.py │ │ │ ├── dataloader.py │ │ │ ├── dataset.py │ │ │ ├── decorator.py │ │ │ ├── distributed.py │ │ │ └── sampler.py │ │ ├── global_view/ │ │ │ ├── __init__.py │ │ │ ├── global_mode.py │ │ │ ├── global_utils.py │ │ │ ├── to_global.py │ │ │ └── to_local.py │ │ ├── hooks.py │ │ ├── insight/ │ │ │ ├── README.md │ │ │ ├── requirements.txt │ │ │ └── sqlite_to_google_trace_event.py │ │ ├── model_zoo.py │ │ └── tensor/ │ │ ├── __init__.py │ │ └── from_or_to_torch_tensor.py │ └── setup.py └── tools/ ├── check_src.py ├── clean_generated_api.py ├── create_pip_index.py ├── flags_from_git_diff.py ├── functional/ │ ├── generate_dispatch_stateful_ops.py │ ├── generate_functional_api.py │ ├── generate_tensor_api.py │ └── generator.py ├── generate_header_list.py ├── generate_pip_version.py ├── oneflow-tblgen/ │ ├── CMakeLists.txt │ ├── backends.h │ ├── example/ │ │ └── constant.td │ ├── op_schema_emitter.cpp │ ├── op_schema_header.inc │ ├── op_schema_source.inc │ ├── op_schema_types.inc │ └── tablegen.cpp ├── oss_file_exist.py └── package_mirror.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ --- Language: Cpp AccessModifierOffset: -1 AlignAfterOpenBracket: Align AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: false AlignEscapedNewlinesLeft: true AlignOperands: true AlignTrailingComments: true AllowAllParametersOfDeclarationOnNextLine: true AllowShortBlocksOnASingleLine: true AllowShortCaseLabelsOnASingleLine: true AllowShortFunctionsOnASingleLine: All AllowShortIfStatementsOnASingleLine: true AllowShortLoopsOnASingleLine: true AlwaysBreakAfterDefinitionReturnType: None AlwaysBreakAfterReturnType: None AlwaysBreakBeforeMultilineStrings: false AlwaysBreakTemplateDeclarations: true BinPackArguments: true BinPackParameters: true BraceWrapping: AfterClass: true AfterControlStatement: false AfterEnum: false AfterFunction: false AfterNamespace: false AfterObjCDeclaration: false AfterStruct: false AfterUnion: false BeforeCatch: 
false BeforeElse: false IndentBraces: false BreakBeforeBinaryOperators: NonAssignment BreakBeforeBraces: Attach BreakBeforeTernaryOperators: true BreakConstructorInitializersBeforeComma: false BreakAfterJavaFieldAnnotations: false BreakStringLiterals: true ColumnLimit: 100 CommentPragmas: '^ IWYU pragma:' BreakBeforeInheritanceComma: false ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true DisableFormat: false ExperimentalAutoDetectBinPacking: false FixNamespaceComments: true ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] IncludeCategories: - Regex: '^<.*\.h>' Priority: 1 - Regex: '^<.*' Priority: 2 - Regex: '.*' Priority: 3 IncludeIsMainRegex: '([-_](test|unittest))?$' IndentCaseLabels: true IndentWidth: 2 IndentWrappedFunctionNames: false JavaScriptQuotes: Leave JavaScriptWrapImports: true KeepEmptyLinesAtTheStartOfBlocks: false MacroBlockBegin: '' MacroBlockEnd: '' MaxEmptyLinesToKeep: 1 NamespaceIndentation: None ObjCBlockIndentWidth: 2 ObjCSpaceAfterProperty: false ObjCSpaceBeforeProtocolList: false PenaltyBreakBeforeFirstCallParameter: 1 PenaltyBreakComment: 300 PenaltyBreakFirstLessLess: 120 PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 PointerAlignment: Left ReflowComments: true SortIncludes: false SpaceAfterCStyleCast: false SpaceAfterTemplateKeyword: false SpaceBeforeAssignmentOperators: true SpaceBeforeParens: ControlStatements SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 2 SpacesInAngles: false SpacesInContainerLiterals: true SpacesInCStyleCastParentheses: false SpacesInParentheses: false SpacesInSquareBrackets: false Standard: Auto TabWidth: 8 UseTab: Never ... ================================================ FILE: .clang-tidy ================================================ # `maybe-*` checks are only available on OneFlow custom clang-tidy and clangd # `-allow-enabling-analyzer-alpha-checkers` should be passed to clang-tidy for CSA checkers named `clang-analyzer-alpha.*` (or `-allow-enabling-alpha-checkers` for run-clang-tidy.py) # `aggressive-binary-operation-simplification` should be enabled (via `-Xclang -analyzer-config -Xclang aggressive-binary-operation-simplification=true` in clang) # there is some problem in `clang-analyzer-alpha.clone.*`, so do not enable it # `clang-analyzer-alpha.deadcode.*` is just too verbose to enable Checks: >- -*, clang-diagnostic-*, maybe-*, clang-analyzer-core.*, clang-analyzer-cplusplus.*, clang-analyzer-nullability.*, clang-analyzer-deadcode.*, clang-analyzer-security.*, clang-analyzer-optin.cplusplus.*, clang-analyzer-optin.performance.*, clang-analyzer-alpha.core.*, clang-analyzer-alpha.cplusplus.*, clang-analyzer-alpha.security.*, cppcoreguidelines-avoid-goto, cppcoreguidelines-init-variables, cppcoreguidelines-interfaces-global-init, cppcoreguidelines-no-malloc, cppcoreguidelines-prefer-member-initializer, cppcoreguidelines-pro-type-member-init, cppcoreguidelines-pro-type-static-cast-downcast, cppcoreguidelines-slicing, cppcoreguidelines-special-member-functions, performance-unnecessary-value-param, performance-unnecessary-copy-initialization, performance-noexcept-move-constructor, performance-no-automatic-move, performance-move-const-arg, performance-implicit-conversion-in-loop, performance-for-range-copy, google-default-arguments, google-global-names-in-headers, google-explicit-constructor, modernize-use-emplace # TODO: treat all maybe warnings as errors when existing 
warnings are all fixed # `clang-analyzer-cplusplus.NewDelete` cannot model reference counting properly for ObjectMsg WarningsAsErrors: >- maybe-unused, clang-analyzer-nullability.*, clang-analyzer-cplusplus.*, performance-implicit-conversion-in-loop, performance-move-const-arg, performance-no-automatic-move, performance-noexcept-move-constructor, google-default-arguments, google-global-names-in-headers, -clang-analyzer-cplusplus.NewDelete, modernize-use-emplace CheckOptions: # `cppcoreguidelines-special-member-functions` is enabled, refer to https://en.cppreference.com/w/cpp/language/rule_of_three - key: cppcoreguidelines-special-member-functions.AllowSoleDefaultDtor value: True - key: performance-move-const-arg.CheckTriviallyCopyableMove value: False - key: cppcoreguidelines-special-member-functions.AllowMissingMoveFunctionsWhenCopyIsDeleted value: True ================================================ FILE: .cmake-format.py ================================================ # ---------------------------------- # Options affecting listfile parsing # ---------------------------------- with section("parse"): # Specify structure for custom cmake functions additional_commands = { "cc_binary": { "flags": ["ADD_RUNTARGET"], "kwargs": { "DEPS": "*", "INC": { "kwargs": {"INTERFACE": "*", "PRIVATE": "*", "PUBLIC": "*"}, "pargs": 0, }, "LIBDIRS": { "kwargs": {"INTERFACE": "*", "PRIVATE": "*", "PUBLIC": "*"}, "pargs": "*", }, "PKGDEPS": "*", "PROPERTIES": {"kwargs": {"EXPORT_NAME": 1, "OUTPUT_NAME": 1}}, "SRCS": "*", }, "pargs": "1+", }, "cc_library": { "flags": ["STATIC", "SHARED"], "kwargs": { "DEPS": { "kwargs": {"INTERFACE": "*", "PRIVATE": "*", "PUBLIC": "*"}, "pargs": "*", }, "INC": { "kwargs": {"INTERFACE": "*", "PRIVATE": "*", "PUBLIC": "*"}, "pargs": 0, }, "LIBDIRS": { "kwargs": {"INTERFACE": "*", "PRIVATE": "*", "PUBLIC": "*"}, "pargs": "*", }, "PKGDEPS": "*", "PROPERTIES": { "kwargs": { "ARCHIVE_OUTPUT_NAME": 1, "EXPORT_NAME": 1, "INTERFACE_INCLUDE_DIRECTORIES": 1, "LIBRARY_OUTPUT_NAME": 1, "OUTPUT_NAME": 1, "SOVERSION": 1, "SUFFIX": 1, "VERSION": 1, } }, "SRCS": "*", }, "pargs": "1+", }, "cc_test": { "kwargs": { "ARGV": "*", "DEPS": "*", "LABELS": "*", "PKGDEPS": "*", "SRCS": "*", "TEST_DEPS": "*", "WORKING_DIRECTORY": "*", }, "pargs": 1, }, "check_call": { "flags": [ "OUTPUT_QUIET", "ERROR_QUIET", "OUTPUT_STRIP_TRAILING_WHITESPACE", "ERROR_STRIP_TRAILING_WHITESPACE", ], "kwargs": { "COMMAND": "*", "ENCODING": "1", "ERROR_FILE": "1", "ERROR_VARIABLE": "1", "INPUT_FILE": "1", "OUTPUT_FILE": "1", "OUTPUT_VARIABLE": "1", "RESULTS_VARIABLE": "1", "RESULT_VARIABLE": "1", "TIMEOUT": "1", "WORKING_DIRECTORY": "1", }, }, "check_pyoneline": { "kwargs": {"ERROR_VARIABLE": 1, "OUTPUT_VARIABLE": 1}, "pargs": "+", }, "create_debian_binary_packages": { "kwargs": {"DEPS": "*", "OUTPUTS": "*"}, "pargs": [3, "+"], }, "create_debian_depsrepo": {"pargs": [3, "+"]}, "create_debian_packages": { "kwargs": {"DEPS": "*", "OUTPUTS": "*"}, "pargs": [{"flags": ["FORCE_PBUILDER"], "nargs": "+"}], }, "debhelp": {"pargs": ["1+"], "spelling": "DEBHELP"}, "exportvars": { "kwargs": {"VARS": "+"}, "pargs": "1+", "spelling": "EXPORTVARS", }, "format_and_lint": { "kwargs": {"CC": "*", "CMAKE": "*", "JS": "*", "PY": "*", "SHELL": "*"} }, "get_debs": {"pargs": [3, "*"]}, "gresource": {"kwargs": {"DEPENDS": "+", "SRCDIR": 1}, "pargs": 2}, "gtk_doc_add_module": { "kwargs": { "FIXREFOPTS": "*", "IGNOREHEADERS": "*", "LIBRARIES": "*", "LIBRARY_DIRS": "*", "SOURCE": "*", "SUFFIXES": "*", "XML": 1, }, "pargs": 1, }, "importvars": { 
"kwargs": {"VARS": "+"}, "pargs": "1+", "spelling": "IMPORTVARS", }, "join": {"kwargs": {"GLUE": 1}, "pargs": [1, "+"]}, "pkg_find": {"kwargs": {"PKG": "*"}}, "stage_files": { "kwargs": {"FILES": "*", "LIST": 1, "SOURCEDIR": 1, "STAGE": 1} }, "tangent_addtest": { "kwargs": { "COMMAND": "+", "CONFIGURATIONS": "+", "DEPENDS": "+", "LABELS": "+", "NAME": 1, "WORKING_DIRECTORY": 1, } }, "tangent_extract_svg": {"kwargs": {"EXPORT": 1, "OUTPUT": 1, "SRC": 1}}, "tangent_fetchobj": {"kwargs": {"OUTDIR": 1}, "pargs": 2}, "tangent_rmark_render": { "kwargs": {"DEPENDS": 1, "FORMAT": 1, "OUTPUT": 1, "PAGENO": 1, "UUID": 1}, "pargs": 1, }, "tangent_unzip": { "kwargs": {"OUTPUT": "1+", "WORKING_DIRECTORY": 1}, "pargs": "1+", }, "travis_decrypt": {"kwargs": {}, "pargs": [3]}, } # Override configurations per-command where available override_spec = {} # Specify variable tags. vartags = [] # Specify property tags. proptags = [] # ----------------------------- # Options affecting formatting. # ----------------------------- with section("format"): # Disable formatting entirely, making cmake-format a no-op disable = False # How wide to allow formatted cmake files line_width = 100 # How many spaces to tab for indent tab_size = 2 # If true, lines are indented using tab characters (utf-8 0x09) instead of # space characters (utf-8 0x20). In cases where the layout would # require a fractional tab character, the behavior of the fractional # indentation is governed by use_tabchars = False # If is True, then the value of this variable indicates how # fractional indentions are handled during whitespace replacement. If set to # 'use-space', fractional indentation is left as spaces (utf-8 0x20). If set # to `round-up` fractional indentation is replaced with a single tab character # (utf-8 0x09) effectively shifting the column to the next tabstop fractional_tab_policy = "use-space" # If an argument group contains more than this many sub-groups (parg or kwarg # groups) then force it to a vertical layout. max_subgroups_hwrap = 3 # If a positional argument group contains more than this many arguments, then # force it to a vertical layout. max_pargs_hwrap = 6 # If a cmdline positional group consumes more than this many lines without # nesting, then invalidate the layout (and nest) max_rows_cmdline = 3 # If true, separate flow control names from their parentheses with a space separate_ctrl_name_with_space = False # If true, separate function names from parentheses with a space separate_fn_name_with_space = False # If a statement is wrapped to more than one line, than dangle the closing # parenthesis on its own line. dangle_parens = False # If the trailing parenthesis must be 'dangled' on its on line, then align it # to this reference: `prefix`: the start of the statement, `prefix-indent`: # the start of the statement, plus one indentation level, `child`: align to # the column of the arguments dangle_align = "prefix" # If the statement spelling length (including space and parenthesis) is # smaller than this amount, then force reject nested layouts. min_prefix_chars = 4 # If the statement spelling length (including space and parenthesis) is larger # than the tab width by more than this amount, then force reject un-nested # layouts. max_prefix_chars = 10 # If a candidate layout is wrapped horizontally but it exceeds this many # lines, then reject the layout. max_lines_hwrap = 2 # What style line endings to use in the output. 
line_ending = "unix" # Format command names consistently as 'lower' or 'upper' case command_case = "canonical" # Format keywords consistently as 'lower' or 'upper' case keyword_case = "unchanged" # A list of command names which should always be wrapped always_wrap = [] # If true, the argument lists which are known to be sortable will be sorted # lexicographicall enable_sort = True # If true, the parsers may infer whether or not an argument list is sortable # (without annotation). autosort = False # By default, if cmake-format cannot successfully fit everything into the # desired linewidth it will apply the last, most agressive attempt that it # made. If this flag is True, however, cmake-format will print error, exit # with non-zero status code, and write-out nothing require_valid_layout = False # A dictionary mapping layout nodes to a list of wrap decisions. See the # documentation for more information. layout_passes = {} # ------------------------------------------------ # Options affecting comment reflow and formatting. # ------------------------------------------------ with section("markup"): # What character to use for bulleted lists bullet_char = "*" # What character to use as punctuation after numerals in an enumerated list enum_char = "." # If comment markup is enabled, don't reflow the first comment block in each # listfile. Use this to preserve formatting of your copyright/license # statements. first_comment_is_literal = False # If comment markup is enabled, don't reflow any comment block which matches # this (regex) pattern. Default is `None` (disabled). literal_comment_pattern = None # Regular expression to match preformat fences in comments default= # ``r'^\s*([`~]{3}[`~]*)(.*)$'`` fence_pattern = "^\\s*([`~]{3}[`~]*)(.*)$" # Regular expression to match rulers in comments default= # ``r'^\s*[^\w\s]{3}.*[^\w\s]{3}$'`` ruler_pattern = "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$" # If a comment line matches starts with this pattern then it is explicitly a # trailing comment for the preceeding argument. Default is '#<' explicit_trailing_pattern = "#<" # If a comment line starts with at least this many consecutive hash # characters, then don't lstrip() them off. 
This allows for lazy hash rulers # where the first hash char is not separated by space hashruler_min_length = 10 # If true, then insert a space between the first hash char and remaining hash # chars in a hash ruler, and normalize its length to fill the column canonicalize_hashrulers = True # enable comment markup parsing and reflow enable_markup = False # ---------------------------- # Options affecting the linter # ---------------------------- with section("lint"): # a list of lint codes to disable disabled_codes = ["C0113"] # regular expression pattern describing valid function names function_pattern = "[0-9a-z_]+" # regular expression pattern describing valid macro names macro_pattern = "[0-9A-Z_]+" # regular expression pattern describing valid names for variables with global # (cache) scope global_var_pattern = "[A-Z][0-9A-Z_]+" # regular expression pattern describing valid names for variables with global # scope (but internal semantic) internal_var_pattern = "_[A-Z][0-9A-Z_]+" # regular expression pattern describing valid names for variables with local # scope local_var_pattern = "[a-z][a-z0-9_]+" # regular expression pattern describing valid names for privatedirectory # variables private_var_pattern = "_[0-9a-z_]+" # regular expression pattern describing valid names for public directory # variables public_var_pattern = "[A-Z][0-9A-Z_]+" # regular expression pattern describing valid names for function/macro # arguments and loop variables. argument_var_pattern = "[a-z][a-z0-9_]+" # regular expression pattern describing valid names for keywords used in # functions or macros keyword_pattern = "[A-Z][0-9A-Z_]+" # In the heuristic for C0201, how many conditionals to match within a loop in # before considering the loop a parser. max_conditionals_custom_parser = 2 # Require at least this many newlines between statements min_statement_spacing = 1 # Require no more than this many newlines between statements max_statement_spacing = 2 max_returns = 6 max_branches = 12 max_arguments = 5 max_localvars = 15 max_statements = 50 # ------------------------------- # Options affecting file encoding # ------------------------------- with section("encode"): # If true, emit the unicode byte-order mark (BOM) at the start of the file emit_byteorder_mark = False # Specify the encoding of the input file. Defaults to utf-8 input_encoding = "utf-8" # Specify the encoding of the output file. Defaults to utf-8. Note that cmake # only claims to support utf-8 so be careful when using anything else output_encoding = "utf-8" # ------------------------------------- # Miscellaneous configurations options. # ------------------------------------- with section("misc"): # A dictionary containing any per-command configuration overrides. Currently # only `command_case` is supported. per_command = {} ================================================ FILE: .devcontainer/Dockerfile ================================================ # See here for image contents: https://github.com/Oneflow-Inc/docker-images/blob/main/oneflow/Dockerfile # [Choice] llvm12 llvm13 cuda11.1 ARG VARIANT="llvm13" ARG REPO="oneflowinc/devcontainer" FROM ${REPO}:${VARIANT} ================================================ FILE: .devcontainer/devcontainer.json ================================================ // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: // https://github.com/microsoft/vscode-dev-containers/tree/v0.209.6/containers/cpp // workaround for EACCES: permission denied, mkdir '/tmp/vsch..... 
// https://github.com/microsoft/vscode-remote-release/issues/2347 // sudo chmod 777 /tmp/vsch/container-features { "name": "oneflow-devel", "image": "oneflowinc/manylinux2014_x86_64_cuda11.2", "runArgs": [ "--cap-add=SYS_PTRACE", "--privileged", "--shm-size=8g", "--security-opt", "seccomp=unconfined", "--network=host", // "--gpus", // "all", ], "remoteEnv": { "PATH": "${containerEnv:PATH}:/opt/python/cp37-cp37m/bin", "ONEFLOW_CI_PYTHON_EXE": "/opt/python/cp37-cp37m/bin/python3", "ONEFLOW_CI_SRC_DIR": "${containerWorkspaceFolder}", "ONEFLOW_CI_BUILD_DIR": "${containerWorkspaceFolder}/build", "ONEFLOW_CI_CMAKE_INIT_CACHE": "${containerWorkspaceFolder}/cmake/caches/ci/cuda.cmake", "ONEFLOW_CI_BUILD_PARALLEL": "20" }, "initializeCommand": "mkdir -p ${localWorkspaceFolder}/devcontainer-cache/dot/ccache && mkdir -p ${localWorkspaceFolder}/devcontainer-cache/dot/local && mkdir -p ${localWorkspaceFolder}/devcontainer-cache/dot/cache", "mounts": [ "source=${localWorkspaceFolder}/devcontainer-cache/dot/ccache,target=/root/.ccache,type=bind,consistency=cached", "source=${localWorkspaceFolder}/devcontainer-cache/dot/local,target=/root/.local,type=bind,consistency=cached", "source=${localWorkspaceFolder}/devcontainer-cache/dot/cache,target=/root/.cache,type=bind,consistency=cached", "source=/dataset,target=/dataset,type=bind,consistency=cached,readonly", "source=/model_zoo,target=/model_zoo,type=bind,consistency=cached,readonly", ], // Set *default* container specific settings.json values on container create. "settings": { "files.insertFinalNewline": true, "files.trimFinalNewlines": true, "files.trimTrailingWhitespace": true, "files.eol": "\n", "clangd.arguments": [ "-j", "8", "-header-insertion=never" ], }, // Add the IDs of extensions you want installed when the container is created. "extensions": [ "llvm-vs-code-extensions.vscode-clangd", "ms-vscode.cmake-tools", "ms-python.python" ], // Comment out connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 
"remoteUser": "root", } ================================================ FILE: .dockerignore ================================================ **/.git /build /build-* /docs/build /cmake-build-* /third_party /examples/**/oneflow /benchmark/**/oneflow /.vscode /.idea /.clangd /dist /wheelhouse* /.DS_Store /tmp_wheel /manylinux* **/__pycache__ **/*.pyc **/log **/.ipynb_checkpoints **/core.0* **/core.1* **/core.2* **/core.3* **/core.4* **/core.5* **/core.6* **/core.7* **/core.8* **/core.9* /.cache /oneflow-src.zip /distributed-tmp /serving-tmp ================================================ FILE: .github/CODEOWNERS ================================================ *.cu @liujuncheng *.py @BBuf @daquexian /oneflow/core/cuda @liujuncheng /oneflow/core/eager @daquexian /oneflow/core/framework @chengtbf @strint /oneflow/core/functional @hjchen2 /oneflow/core/graph @chengtbf /oneflow/core/ndarray @daquexian /oneflow/core/object_msg @daquexian /oneflow/core/platform @jackalcooper /oneflow/core/ep @liujuncheng /oneflow/core/rpc @jackalcooper /oneflow/core/stream @liujuncheng /oneflow/core/hardware @liujuncheng /oneflow/core/transport @chengtbf /oneflow/core/vm @daquexian /oneflow/xrt @hjchen2 /oneflow/ir @hjchen2 @BBuf @jackalcooper /ci @jackalcooper /python/oneflow/test_utils @daquexian @BBuf /cmake @daquexian @jackalcooper CMakeLists.txt @daquexian @jackalcooper /.github @jackalcooper /tools @jackalcooper /docs @doombeaker ================================================ FILE: .github/ISSUE_TEMPLATE/blank_issue.yml ================================================ name: Blank Issue description: Submit an issue about OneFlow. labels: [Blank Issue] body: - type: textarea id: description attributes: label: Description description: Please describe the issue here. placeholder: Description validations: required: false ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: bug, community assignees: '' --- ## Summary A short description about the bug/issue ## Code to reproduce bug Please post a minimal example to repro the bug. GitHub Gist or repo is highly recommended. ## System Information - What is your OneFlow installation (pip, source, dockerhub): - OS: - OneFlow version (run `python3 -m oneflow --doctor`): - Python version: - CUDA driver version: - GPU models: - Other info: ================================================ FILE: .github/ISSUE_TEMPLATE/documention_issue.yml ================================================ name: Documentation Issue description: Report an issue about OneFlow ducumention or require a documention. title: "[Documention Issue]: " labels: [Documention Issue] body: - type: markdown attributes: value: | Welcome to suggest to OneFlow documention! This template will help us gather the information we need to improve it. - type: textarea id: brief-description attributes: label: Brief Description description: Please describe the problem or the requst for new documention here. placeholder: Description validations: required: true - type: textarea id: alternatives attributes: label: Alternatives description: | Please provide some alternative information here, if any. placeholder: Alternatives validations: required: false - type: markdown attributes: value: | Thanks for your contributing! 
================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.yml ================================================ name: Feature Request description: Request/Propose a new OneFlow feature. title: "[Feature Request]: " labels: [feature-request] body: - type: markdown attributes: value: | We welcome feature proposals/requests! This template will help us gather the information we need to review the proposal/request. - type: textarea id: background attributes: label: Background and motivation description: Please describe the purpose and value of the new feature here. If the feature is linked to a specific problem, please describe it or put the link here. placeholder: Purpose validations: required: true - type: textarea id: api-proposal attributes: label: API Proposal description: | Please provide the specific public API signature diff that you are proposing. If a new API is not required, please provide the current API related to the feature, or note that there is no related public API. placeholder: API declaration (no method bodies) value: | ```py def new_api(value: Tensor) -> Tensor: pass ``` validations: required: true - type: textarea id: api-usage attributes: label: API Usage description: | Please provide code examples that highlight how the proposed API additions are meant to be consumed. This will help suggest whether the API has the right shape to be functional, performant and usable. If there is not a new API in step 2, please skip it. placeholder: API usage validations: required: false - type: textarea id: alternatives attributes: label: Alternatives description: | Please provide some alternative information of the feature, if any. For example, if you request a feature which depends on a specific device, please provide the device information. placeholder: Alternatives validations: required: false - type: textarea id: risks attributes: label: Risks description: | Please mention any risks that to your knowledge the API proposal might entail, such as breaking changes, performance regressions, etc. placeholder: Risks validations: required: false - type: markdown attributes: value: | Thanks for your contributing! ================================================ FILE: .github/ISSUE_TEMPLATE/performance_issue.yml ================================================ name: Performance Issue description: Submit an issue about a performance problem or regression of OneFlow. title: "[Performance Issue]: " labels: [Performance Issue] body: - type: markdown attributes: value: | We welcome issues about OneFlow performance! This template will help us gather the information we need to locate the problem and improve the performance. - type: textarea id: brief-description attributes: label: Brief Description description: Please give a brief description of the performance issue here. placeholder: Description validations: required: true - type: textarea id: device-and-context attributes: label: Device and Context description: | Please describe the device and context you used when you encountered the performance problem/regression. placeholder: Device and Context validations: required: true - type: textarea id: benchmark attributes: label: Benchmark description: | We would appreciate it if you could provide a benchmark comparison of the performance issue. placeholder: Benchmark validations: required: false - type: textarea id: alternatives attributes: label: Alternatives description: | Please provide some alternative information of the performance issue here, if any.
placeholder: Alternatives validations: required: false - type: markdown attributes: value: | Thanks for your contributing! ================================================ FILE: .github/ISSUE_TEMPLATE/question.yml ================================================ name: Question description: Ask a question about OneFlow and discuss with community members. title: "[Question]: " labels: [Question] body: - type: markdown attributes: value: | Welcome to ask questions about OneFlow! This template will help us understand your question. - type: textarea id: description attributes: label: Description description: Please describe your question here. placeholder: Description validations: required: true - type: textarea id: alternatives attributes: label: Alternatives description: | Please provide some alternative information here, if any. placeholder: Alternatives validations: required: false - type: markdown attributes: value: | We are always willing to answer your questions! ================================================ FILE: .github/PULL_REQUEST_TEMPLATE/general_template.md ================================================ ## Summary ## PR Checklist - [ ] The PR title reads smoothly, clearly states what the PR does, and is suitable to be used directly as a changelog entry for the next release - [ ] Code is formatted - [ ] The code compiles locally - [ ] The change has been tested locally - [ ] A type label has been added: (fill in the type label name, e.g. `bug, enhancement, purge, feature, documentation`) - [ ] A component label has been added: (fill in the component label name, e.g. `op, system, eager, build, xla, python, ci, test, tooling`) - [ ] A review has been requested before converting the draft into a formal PR ================================================ FILE: .github/PULL_REQUEST_TEMPLATE/op_template.md ================================================ ## Summary Describe the op's functionality, formulas, etc. If another framework's interface was referenced, list the hyperlink. ## Feature Checklist **Note**: the feature checkboxes are all optional; if one is not checked, just explain why. For example: the op is composed from Python interfaces, so there is no `SetBatchAxisInferFn` op registration; or: the op has no inputs, so there is no `SetInputArgModifyFn`. The checkboxes that come with the template may be left empty but must not be deleted; checkbox items may be added as appropriate. ### Op - [ ] Op SetBatchAxisInferFn - [ ] Op SetGetSbpFn - [ ] Op SetInputArgModifyFn - [ ] Op backward gradient registration ### Kernel - [ ] CPU in:float32 - [ ] CPU in:float64 - [ ] CPU in:int32 - [ ] CPU in:int64 - [ ] CPU in:int8 - [ ] GPU in:float32 - [ ] GPU in:float64 - [ ] GPU in:int32 - [ ] GPU in:int64 - [ ] GPU in:float16 - [ ] GPU in:int8 ### Python Wrapper - [ ] Python API argument checking and exception messages - [ ] Interface docstrings - [ ] Example ### Tests - [ ] Single-node single-device CPU Test Case - [ ] Single-node single-device GPU Test Case - [ ] Single-node multi-device CPU Test Case - [ ] Single-node multi-device GPU Test Case - [ ] Distributed CPU Test Case - [ ] Distributed GPU Test Case ## GPU Effective Bandwidth For ops with a GPU kernel, please measure the effective bandwidth as described in https://github.com/Oneflow-Inc/OneTeam/issues/167 and attach the test report. A sample report is shown below: Theoretical bandwidth: ```text Device to Device Bandwidth, 1 Device(s) PINNED Memory Transfers Transfer Size (Bytes) Bandwidth(MB/s) 33554432 250798.5 ``` Measured bandwidth: ``` PROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: sqrt_2 elapsed(ms): 0.196064 memory_size(Byte): 50331648 bandwidth(GB/s): 239.08 PROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: sqrt_2_grad elapsed(ms): 0.29072 memory_size(Byte): 75497472 bandwidth(GB/s): 241.856 ``` ## PR Checklist - [ ] The PR title reads smoothly, clearly states what the PR does, and is suitable to be used directly as a changelog entry for the next release - [ ] Code is formatted - [ ] The code compiles locally - [ ] The change has been tested locally - [ ] A type label has been added: (fill in the type label name, e.g. `bug, enhancement, purge, feature, documentation`) - [ ] A component label has been added: (fill in the component label name, e.g. `op, system, eager, build, xla, python, ci, test, tooling`) - [ ] A review has been requested before converting the draft into a formal PR ================================================ FILE: .github/actions/mac-build/action.yml ================================================ name: "Build OneFlow on macOS" description: "" runs: using: "composite" steps: - name: Install dependencies run: | brew install nasm shell: bash - name: Set environment variables
run: | set -x cmake_flags="" cmake_flags+=" -DPython3_EXECUTABLE=$(which python3)" cmake_flags+=" -DRPC_BACKEND=LOCAL" cmake_flags+=" -DCMAKE_BUILD_TYPE=Release" cmake_flags+=" -DBUILD_CUDA=OFF" echo "cmake_flags=${cmake_flags}" >> $GITHUB_ENV shell: bash - name: Build (third party) run: | mkdir -p build cd build cmake .. $cmake_flags -DTHIRD_PARTY=ON -DONEFLOW=OFF make -j $(nproc) shell: bash - name: Build (oneflow) run: | mkdir -p build cd build cmake .. $cmake_flags -DTHIRD_PARTY=OFF -DONEFLOW=ON make -j 2 oneflow shell: bash - name: Build (oneflow_internal) run: | mkdir -p build cd build cmake .. $cmake_flags -DTHIRD_PARTY=OFF -DONEFLOW=ON make -j 2 oneflow_internal shell: bash - name: Build (generate_api) run: | mkdir -p build cd build cmake .. $cmake_flags -DTHIRD_PARTY=OFF -DONEFLOW=ON make -j 2 generate_api shell: bash ================================================ FILE: .github/actions/setup/action.yml ================================================ inputs: name: description: 'Placeholder' default: 'Placeholder' runs: using: "composite" steps: - run: | echo $HOSTNAME rm -rf build/third_party bash ci/setup_submodule.sh auth_header="$(git config --local --get http.https://github.com/.extraheader)" git -c "http.extraheader=$auth_header" -c protocol.version=2 submodule update --init --recursive shell: bash ================================================ FILE: .github/actions/upload_oss/action.yml ================================================ inputs: src_path: required: true oss_dst_path: required: true oss_access_key_id: required: true oss_access_key_secret: required: true upload_core: required: false runs: using: "composite" steps: - run: | if [ -z "$OSS_ACCESS_KEY_ID" ] then exit 0 fi if [ ! -f "$HOME/ossutil64" ]; then curl http://gosspublic.alicdn.com/ossutil/1.7.15/ossutil64 -o $HOME/ossutil64 fi chmod 755 $HOME/ossutil64 $HOME/ossutil64 config -e oss-cn-beijing.aliyuncs.com -i ${{ inputs.oss_access_key_id }} -k ${{ inputs.oss_access_key_secret }} -L EN -c $HOME/.ossutilconfig dir_arg="" if [ -d "${{ inputs.src_path }}" ]; then dir_arg="--recursive" fi upload_core_arg="" if [ "${{ inputs.upload_core }}" == "true" ]; then echo "will upload core files" else upload_core_arg+='--exclude "core*"' fi set -x $HOME/ossutil64 cp --disable-ignore-error --update ${dir_arg} ${upload_core_arg} ${{ inputs.src_path }} ${{ inputs.oss_dst_path }} shell: bash env: OSS_ACCESS_KEY_ID: ${{ inputs.oss_access_key_id }} OSS_ACCESS_KEY_SECRET: ${{ inputs.oss_access_key_secret }} ================================================ FILE: .github/actions/upload_ssh/action.yml ================================================ name: "Upload via ssh" description: "" inputs: src_path: required: true description: "" dst_host: required: true description: "" dst_path: required: true description: "" runs: using: "composite" steps: - run: | set -x dir_arg="" if [ -d "${{ inputs.src_path }}" ]; then dir_arg="-r" fi parent_dir=$(dirname ${{ inputs.dst_path }}) ssh -o StrictHostKeyChecking=no ${{ inputs.dst_host }} mkdir -p $parent_dir ssh ${{ inputs.dst_host }} rm -rf ${{ inputs.dst_path }} scp ${dir_arg} ${{ inputs.src_path }} ${{ inputs.dst_host }}:${{ inputs.dst_path }} shell: bash ================================================ FILE: .github/actions/whl/action.yml ================================================ inputs: tmp_dir: description: "tmp dir" required: true cuda_version: description: "cuda_version" default: "10.2" python_version: description: "python_version" default: "3.8" extra_flags: description: 
"flags like --xla" default: "" extra_docker_args: description: "" default: "" runs: using: "composite" steps: - run: | set -x src_dir=${PWD} tmp_dir="${{ inputs.tmp_dir }}" mkdir -p ${tmp_dir} cd ${tmp_dir} docker run --rm -v $PWD:/p -w $PWD:/p busybox rm -rf /p/wheelhouse python3 ${src_dir}/docker/package/manylinux/build_wheel.py \ --cuda_version=${{ inputs.cuda_version }} \ --python_version=${{ inputs.python_version }} \ --use_tuna --use_system_proxy --use_aliyun_mirror \ --wheel_house_dir=${tmp_dir}/wheelhouse \ --oneflow_src_dir=${src_dir} ${{ inputs.extra_flags }} \ --retry=1 \ --extra_docker_args "${extra_docker_args}" shell: bash ================================================ FILE: .github/scripts/requirements.txt ================================================ PyYAML>=5.1 parsec ================================================ FILE: .github/scripts/set_initial_variables.py ================================================ import json def create_one(name=None, allow_fail=None): return { "test_suite": name, "cuda_version": "N/A", "extra_flags": "N/A", "os": ["self-hosted", "linux", "build"], "allow_fail": allow_fail, "python_version": "N/A", } def create_conda(name=None): return create_one(name=name, allow_fail=False) def print_github_action_output(name=None, value=None): print(f"::set-output name={name}::{value}") def print_result(build_matrix=None, test_matrix=None, out=None): check_include(include_key="test_suite", matrix=build_matrix) if test_matrix != {}: check_include(include_key="test_suite", matrix=test_matrix) assert build_matrix assert test_matrix != None root = { "build_matrix": build_matrix, "test_matrix": test_matrix, } for k, v in root.items(): print_github_action_output( name=k, value=json.dumps(v), ) if out: with open(out, "w+") as f: f.write(json.dumps(root, indent=4)) def check_include(include_key=None, matrix: dict = None): assert include_key in matrix in_declare = set(matrix[include_key]) in_include = set() for include_value in matrix["include"]: in_include.add(include_value[include_key]) assert in_declare == in_include, { "in_declare": in_declare, "in_include": in_include, } if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "--labels", type=lambda x: (str(x).replace(" ", "").split(",")), required=True, ) parser.add_argument("--out", type=str, required=False) args = parser.parse_args() if "need-clang-only" in args.labels: print_result( build_matrix={ "test_suite": ["cpu-clang"], "include": [create_conda("cpu-clang")], }, test_matrix={}, out=args.out, ) else: full_build_matrix = { "test_suite": ["cuda", "cpu", "xla", "xla_cpu", "cpu-clang"], "include": [ { "test_suite": "cuda", "cuda_version": 10.2, "extra_flags": "--extra_oneflow_cmake_args=-DCUDA_ARCHITECTURES=61 --extra_oneflow_cmake_args=-DRPC_BACKEND=GRPC,LOCAL --extra_oneflow_cmake_args=-DPIP_INDEX_MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple", "os": ["self-hosted", "linux", "build"], "allow_fail": False, "python_version": "3.6,3.7", }, { "test_suite": "cpu", "cuda_version": 10.2, "extra_flags": "--extra_oneflow_cmake_args=-DBUILD_SHARED_LIBS=OFF --extra_oneflow_cmake_args=-DRPC_BACKEND=LOCAL --cpu", "os": ["self-hosted", "linux", "build"], "allow_fail": False, "python_version": "3.6,3.7", }, { "test_suite": "xla", "cuda_version": 10.1, "extra_flags": "--extra_oneflow_cmake_args=-DCUDA_ARCHITECTURES=61 --extra_oneflow_cmake_args=-DRPC_BACKEND=GRPC,LOCAL --xla --extra_oneflow_cmake_args=-DPIP_INDEX_MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple", "os": 
["self-hosted", "linux", "build"], "allow_fail": True, "python_version": 3.6, }, { "test_suite": "xla_cpu", "cuda_version": 10.1, "extra_flags": "--extra_oneflow_cmake_args=-DRPC_BACKEND=GRPC,LOCAL --xla --cpu --extra_oneflow_cmake_args=-DPIP_INDEX_MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple", "os": ["self-hosted", "linux", "build"], "allow_fail": True, "python_version": 3.6, }, create_conda("cpu-clang"), ], } full_test_matrix = { "test_suite": [ "cuda", "cuda_op", "cuda_new_interface", "cpu_new_interface", "cpu", "xla", "xla_cpu", ], "include": [ { "test_suite": "cuda", "os": ["self-hosted", "linux", "gpu"], "allow_fail": False, "build_env": "build.cuda.env", }, { "test_suite": "cuda_op", "os": ["self-hosted", "linux", "gpu"], "allow_fail": False, "build_env": "build.cuda.env", }, { "test_suite": "cuda_new_interface", "os": ["self-hosted", "linux", "gpu"], "allow_fail": False, "build_env": "build.cuda.env", }, { "test_suite": "cpu", "os": ["self-hosted", "linux", "cpu"], "allow_fail": False, "build_env": "build.cpu.env", }, { "test_suite": "cpu_new_interface", "os": ["self-hosted", "linux", "cpu"], "allow_fail": False, "build_env": "build.cpu.env", }, { "test_suite": "xla", "os": ["self-hosted", "linux", "gpu"], "allow_fail": True, "build_env": "build.xla.env", }, { "test_suite": "xla_cpu", "os": ["self-hosted", "linux", "cpu"], "allow_fail": True, "build_env": "build.xla_cpu.env", }, ], } print_result( build_matrix=full_build_matrix, test_matrix=full_test_matrix, out=args.out, ) ================================================ FILE: .github/workflows/canary.yml ================================================ name: Canary on: push: branches: - master - "canary/*" workflow_dispatch: inputs: oneflow-ref: description: "" default: "master" required: true concurrency: group: canary-${{ github.ref }} cancel-in-progress: false jobs: canary_release: name: Canary Release timeout-minutes: 120 runs-on: [self-hosted, linux, release] if: github.repository == 'Oneflow-Inc/oneflow' strategy: max-parallel: 1 fail-fast: false matrix: entry: ["canary", "profiler"] include: - entry: "canary" cmake-init-cache: "cmake/caches/ci/canary/cuda.cmake" - entry: "profiler" cmake-init-cache: "cmake/caches/ci/profiler/cuda.cmake" env: ONEFLOW_SRC: . 
MANYLINUX_CACHE_DIR: ~/manylinux-cache-dir/canary-cu112 WHEELHOUSE_DIR: manylinux-wheelhouse COMPUTE_PLATFORM: cu118 OSS_BUCKET: oneflow-staging OSS_WHEEL_HOUSE_DIR: ${{ matrix.entry }}/commit/${{ github.sha }} OSS_GITHUB_REF_DIR: ${{ matrix.entry }}/${{ github.ref }} steps: - name: Fix permissions run: | set -x docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf * - name: Remove leftover cuda-installer.log run: | docker run --rm -v /tmp:/host/tmp -w /p busybox rm -f /host/tmp/cuda-installer.log - name: Checkout Oneflow-Inc/oneflow if: ${{ github.event.inputs.oneflow-ref != '' }} uses: actions/checkout@v2 with: ref: ${{ github.event.inputs.oneflow-ref }} - name: Checkout Oneflow-Inc/oneflow if: ${{ github.event.inputs.oneflow-ref == '' }} uses: actions/checkout@v2 - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build manylinux id: build-cuda with: cmake-init-cache: ${{ env.ONEFLOW_SRC }}/${{ matrix.cmake-init-cache }} build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build-gcc9.sh oneflow-src: ${{ env.ONEFLOW_SRC }} oneflow-build-env: manylinux wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} clear-wheelhouse-dir: true self-hosted: true manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} docker-run-use-system-http-proxy: false docker-run-use-lld: true retry-failed-build: true clean-ccache: true compute-platform: ${{ env.COMPUTE_PLATFORM }} python-versions: | 3.8 3.10 - name: Upload wheelhouse uses: ./.github/actions/upload_oss with: src_path: ${{ env.WHEELHOUSE_DIR }} oss_dst_path: oss://${{ env.OSS_BUCKET }}/${{ env.OSS_WHEEL_HOUSE_DIR }}/${{ env.COMPUTE_PLATFORM }} oss_access_key_id: ${{ secrets.OSS_ACCESS_KEY_ID }} oss_access_key_secret: ${{ secrets.OSS_ACCESS_KEY_SECRET }} - name: Update pip index env: OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }} OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }} run: | python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple python3 -m pip install oss2 beautifulsoup4 --user python3 tools/create_pip_index.py -b ${{ env.OSS_BUCKET }} \ --dir_key ${{ env.OSS_WHEEL_HOUSE_DIR }}/${{ env.COMPUTE_PLATFORM }} \ --index_key=${{ env.OSS_WHEEL_HOUSE_DIR }}/${{ env.COMPUTE_PLATFORM }}/index.html \ --index_key=${{ env.OSS_GITHUB_REF_DIR}}/${{ env.COMPUTE_PLATFORM }}/index.html ================================================ FILE: .github/workflows/community_release.yml ================================================ name: Community Release on: push: branches: - "community/*" schedule: # beijing: 6 pm. # utc: 10 am. 
- cron: "0 10 * * sat" workflow_dispatch: inputs: priv_branch: required: false default: "main" concurrency: group: community-release-${{ github.ref }}-${{ inputs.priv_branch }} cancel-in-progress: true jobs: release: name: Release pip permissions: contents: read pull-requests: write uses: ./.github/workflows/release.yml with: is_priv: true branch: ${{ inputs.priv_branch || 'main' }} upload_override_branch: "community" cuda_cmake_cache: cmake/caches/ci/release/cuda_community.cmake secrets: ONEFLOW_PRIV_ORG: ${{ secrets.ONEFLOW_PRIV_ORG }} ONEFLOW_PRIV_GH_TOKEN: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }} ONEFLOW_PRIV_OSS_BUCKET: ${{ secrets.ONEFLOW_PRIV_OSS_BUCKET }} OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }} OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }} ONEFLOW_CI_HTTP_PROXY: ${{ secrets.ONEFLOW_CI_HTTP_PROXY }} ================================================ FILE: .github/workflows/on_merge.yml ================================================ name: Update Benchmark History on: pull_request: types: - closed branches: - master env: OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }} OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }} jobs: if_merged: if: github.event.pull_request.merged == true runs-on: ubuntu-latest steps: - uses: Oneflow-Inc/get-oneflow/update-benchmark-history@ci-test-with-cu118 name: Update benchmark history timeout-minutes: 10 ================================================ FILE: .github/workflows/pr.yml ================================================ name: Check PR on: pull_request: types: [opened, labeled, unlabeled, synchronize] jobs: check_labels: runs-on: ubuntu-22.04 name: Labels if: github.event.pull_request.draft == false && github.base_ref == 'master' steps: - name: Check type labels 'bug, enhancement, purge, feature, documentation' if: (contains(github.event.pull_request.labels.*.name, 'bug') || contains(github.event.pull_request.labels.*.name, 'enhancement') || contains(github.event.pull_request.labels.*.name, 'purge') || contains(github.event.pull_request.labels.*.name, 'feature') || contains(github.event.pull_request.labels.*.name, 'documentation')) == false run: | exit 1 - name: Check component labels 'op, system, eager, build, xla, python, ci, test, tooling, quantization, graph, ir, serving' if: (contains(github.event.pull_request.labels.*.name, 'op') || contains(github.event.pull_request.labels.*.name, 'system') || contains(github.event.pull_request.labels.*.name, 'eager') || contains(github.event.pull_request.labels.*.name, 'build') || contains(github.event.pull_request.labels.*.name, 'xla') || contains(github.event.pull_request.labels.*.name, 'python') || contains(github.event.pull_request.labels.*.name, 'ci') || contains(github.event.pull_request.labels.*.name, 'test') || contains(github.event.pull_request.labels.*.name, 'tooling') || contains(github.event.pull_request.labels.*.name, 'quantization') || contains(github.event.pull_request.labels.*.name, 'graph') || contains(github.event.pull_request.labels.*.name, 'ir') || contains(github.event.pull_request.labels.*.name, 'serving')) == false run: | exit 2 ================================================ FILE: .github/workflows/priv_release.yml ================================================ name: Priv Release on: push: branches: - "pro/*" schedule: # beijing: 12 pm. # utc: 4 am. 
- cron: "0 4 * * sun" workflow_dispatch: inputs: priv_branch: required: false default: "main" concurrency: group: priv-release-${{ github.ref }}-${{ inputs.priv_branch }} cancel-in-progress: true jobs: release: name: Release pip permissions: contents: read pull-requests: write uses: ./.github/workflows/release.yml with: is_priv: true branch: ${{ inputs.priv_branch || 'main' }} cuda_cmake_cache: cmake/caches/ci/release/cuda_pro.cmake secrets: ONEFLOW_PRIV_ORG: ${{ secrets.ONEFLOW_PRIV_ORG }} ONEFLOW_PRIV_GH_TOKEN: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }} ONEFLOW_PRIV_OSS_BUCKET: ${{ secrets.ONEFLOW_PRIV_OSS_BUCKET }} OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }} OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }} ONEFLOW_CI_HTTP_PROXY: ${{ secrets.ONEFLOW_CI_HTTP_PROXY }} ================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: push: branches: - "release/*" schedule: # beijing: 2 am. # utc: 6 pm. - cron: "0 18 * * *" workflow_dispatch: inputs: placeholder: description: "update .github/workflows/release.yml to config your build" required: false workflow_call: inputs: is_priv: required: true type: boolean branch: required: false type: string default: "main" upload_override_branch: required: false type: string cuda_cmake_cache: required: false type: string secrets: ONEFLOW_PRIV_ORG: required: true ONEFLOW_PRIV_GH_TOKEN: required: true ONEFLOW_PRIV_OSS_BUCKET: required: true OSS_ACCESS_KEY_ID: required: true OSS_ACCESS_KEY_SECRET: required: true ONEFLOW_CI_HTTP_PROXY: required: false concurrency: group: release-${{ github.ref }}-${{ inputs.branch }} cancel-in-progress: ${{ github.ref != 'refs/heads/master' }} env: ONEFLOW_SRC: . jobs: generate-build-matrix: name: "Generate build matrix" runs-on: ubuntu-latest env: ONEFLOW_SRC: . 
outputs: matrix: ${{ steps.find-cache.outputs.matrix }} formatted_date: ${{ steps.date.outputs.formatted_date }} steps: - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 if: ${{ !inputs.is_priv }} with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - name: Checkout oneflow uses: actions/checkout@v2 if: ${{ inputs.is_priv }} with: ref: ${{ inputs.branch }} repository: ${{ secrets.ONEFLOW_PRIV_ORG }}/oneflow token: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }} - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/build@ci-test-with-cu118 name: Find build cache id: find-cache timeout-minutes: 5 with: delete-cache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }} runner-labels: | self-hosted linux release oneflow-src: ${{ env.ONEFLOW_SRC }} entries: | cu122 cu121 cu118 cpu - name: Get current date id: date run: echo "formatted_date=$(date +'%Y%m%d')" >> $GITHUB_OUTPUT staging_release: env: MANYLINUX_CACHE_DIR: ~/manylinux-cache-dir/release/${{ matrix.entry }} WHEELHOUSE_DIR: manylinux_wheelhouse OSS_DIR: branch/${{ github.ref_name }}/${{ matrix.entry }}/${{ github.sha }} GITHUB_REF_NAME: ${{ github.ref_name }} GITHUB_SHA: ${{ github.sha }} ONEFLOW_OSS_BUCKET: oneflow-staging https_proxy: ${{ secrets.ONEFLOW_CI_HTTP_PROXY }} needs: [generate-build-matrix] name: Staging Release timeout-minutes: 240 runs-on: [self-hosted, linux, release] if: github.repository == 'Oneflow-Inc/oneflow' || inputs.is_priv strategy: fail-fast: false max-parallel: 6 matrix: ${{ fromJson(needs.generate-build-matrix.outputs.matrix) }} steps: - name: Fix permissions run: | docker run --rm -v $PWD:/p -w /p busybox rm -rf * - name: Install dependencies run: | python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple python3 -m pip install -U setuptools wheel --user python3 -m pip install oss2 --user - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 if: ${{ !inputs.is_priv }} with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - name: Checkout private oneflow uses: actions/checkout@v2 if: ${{ inputs.is_priv }} with: ref: ${{ inputs.branch }} repository: ${{ secrets.ONEFLOW_PRIV_ORG }}/oneflow token: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }} - name: Checkout cutlass_extension uses: actions/checkout@v2 if: ${{ inputs.is_priv }} with: repository: ${{ secrets.ONEFLOW_PRIV_ORG }}/cutlass-extension token: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }} path: cutlass-extension - name: Set Private env if: ${{ inputs.is_priv }} run: | GITHUB_SHA=$(git rev-parse HEAD) echo "OSS_DIR=branch/${{ inputs.upload_override_branch || inputs.branch }}/${{ matrix.entry }}/${GITHUB_SHA}" >> $GITHUB_ENV echo "GITHUB_REF_NAME=${{ inputs.upload_override_branch || inputs.branch }}" >> $GITHUB_ENV echo "GITHUB_SHA=${GITHUB_SHA}" >> $GITHUB_ENV echo "ONEFLOW_OSS_BUCKET=${{ secrets.ONEFLOW_PRIV_OSS_BUCKET }}" >> $GITHUB_ENV - name: Print env if: ${{ inputs.is_priv }} run: | env - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build ${{ matrix.entry }} if: ${{ matrix.entry =='cu118' || startsWith(matrix.entry, 'cu12') }} with: cmake-init-cache: ${{ env.ONEFLOW_SRC }}/${{ inputs.cuda_cmake_cache || 'cmake/caches/ci/release/cu118.cmake' }} build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build-gcc9.sh oneflow-src: ${{ env.ONEFLOW_SRC }} oneflow-build-env: manylinux wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} clear-wheelhouse-dir: true self-hosted: true 
compute-platform: ${{ matrix.entry }} manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} docker-run-use-system-http-proxy: false docker-run-use-lld: false retry-failed-build: true clean-ccache: true nightly: ${{ inputs.is_priv || github.event_name == 'schedule' || github.ref == 'refs/heads/release/add_nightly_date_index'}} nightly-date: ${{ needs.generate-build-matrix.outputs.formatted_date }} use-nvidia-wheels: ${{ matrix.entry !='cu112' }} python-versions: | 3.12 3.11 3.10 3.9 3.8 - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build ${{ matrix.entry }} if: ${{ startsWith(matrix.entry, 'cu') && matrix.entry !='cu118' && !startsWith(matrix.entry, 'cu12') }} with: cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/release/cuda.cmake build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build-gcc9.sh oneflow-src: ${{ env.ONEFLOW_SRC }} oneflow-build-env: manylinux wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} clear-wheelhouse-dir: true self-hosted: true compute-platform: ${{ matrix.entry }} manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} docker-run-use-system-http-proxy: false docker-run-use-lld: false retry-failed-build: true clean-ccache: true nightly: ${{ inputs.is_priv || github.event_name == 'schedule' || github.ref == 'refs/heads/release/add_nightly_date_index'}} nightly-date: ${{ needs.generate-build-matrix.outputs.formatted_date }} use-nvidia-wheels: ${{ matrix.entry !='cu112' }} python-versions: | 3.12 3.11 3.10 3.9 3.8 - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build ${{ matrix.entry }} if: ${{ matrix.entry =='cpu' }} with: cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/release/cpu.cmake build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build.sh oneflow-src: ${{ env.ONEFLOW_SRC }} oneflow-build-env: manylinux wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} clear-wheelhouse-dir: true self-hosted: true compute-platform: ${{ matrix.entry }} manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} docker-run-use-system-http-proxy: false docker-run-use-lld: false retry-failed-build: true clean-ccache: false nightly: ${{ inputs.is_priv || github.event_name == 'schedule' || github.ref == 'refs/heads/release/add_nightly_date_index'}} nightly-date: ${{ needs.generate-build-matrix.outputs.formatted_date }} python-versions: | 3.12 3.11 3.10 3.9 3.8 - name: Upload wheel uses: ./.github/actions/upload_oss with: src_path: ${{ env.WHEELHOUSE_DIR }} oss_dst_path: oss://${{ env.ONEFLOW_OSS_BUCKET }}/${{ env.OSS_DIR }} oss_access_key_id: ${{ secrets.OSS_ACCESS_KEY_ID }} oss_access_key_secret: ${{ secrets.OSS_ACCESS_KEY_SECRET }} - name: Update pip index env: OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }} OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }} run: | python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple python3 -m pip install oss2 beautifulsoup4 --user python3 tools/create_pip_index.py --dir_key ${{ env.OSS_DIR }} -b ${{ env.ONEFLOW_OSS_BUCKET }} \ --index_key=branch/${{ env.GITHUB_REF_NAME }}/${{ matrix.entry }}/index.html \ --index_key=branch/${{ env.GITHUB_REF_NAME }}/date/${{ needs.generate-build-matrix.outputs.formatted_date }}/${{ matrix.entry }}/index.html \ --index_key=${{ env.OSS_DIR }}/index.html \ --index_key=commit/${{ env.GITHUB_SHA }}/${{ matrix.entry }}/index.html - name: Update API docs if: github.ref == 'refs/heads/master' && matrix.entry == 'cpu' && !inputs.is_priv env: READTHEDOCS_TOKEN: ${{ secrets.READTHEDOCS_TOKEN }} run: | curl -X POST -d "branches=master" -d "token=${READTHEDOCS_TOKEN}" 
https://readthedocs.org/api/v2/webhook/oneflow/135376/ ================================================ FILE: .github/workflows/simple.yml ================================================ name: Simple CI on: pull_request: types: [review_requested] branches: - "*" push: branches: - master workflow_dispatch: inputs: placeholder: description: "placeholder, no effect" required: false concurrency: group: simple-ci-${{ github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/master' }} jobs: static_analysis_with_clang: name: Static analysis with clang runs-on: ubuntu-22.04 if: github.ref == 'refs/heads/master' || (github.event.pull_request.draft == false && contains(github.event.pull_request.requested_reviewers.*.login, 'oneflow-ci-bot') && contains(github.event.pull_request.labels.*.name, 'need-simple-ci')) steps: - name: Check out OneFlow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.ref }} repository: ${{github.event.pull_request.head.repo.full_name}} - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y libopenblas-dev nasm python3-pip ninja-build - name: Download OneFlow custom clang-tidy run: | wget https://github.com/Oneflow-Inc/llvm-project/releases/download/maybe-14.0.4/clang-tidy-14.AppImage wget https://raw.githubusercontent.com/oneflow-inc/llvm-project/maybe/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py chmod +x clang-tidy-14.AppImage run-clang-tidy.py - name: Build third party libs and generate files run: | mkdir build cd build cmake .. -C ../cmake/caches/international/cpu.cmake \ -DCMAKE_BUILD_TYPE=Release \ -DBUILD_TESTING=ON cmake --build . -j$(nproc) --target oneflow_deps of_protoobj of_functional_obj of_functional_tensor_obj of_op_schema - name: Run clang-tidy for all translation units # use clang as compiler for correct compiler flags run: | cd build rm CMakeCache.txt cmake .. -C ../cmake/caches/international/cpu.cmake \ -DCMAKE_C_COMPILER=clang-12 \ -DCMAKE_CXX_COMPILER=clang++-12 \ -DCMAKE_BUILD_TYPE=Release \ -DBUILD_TESTING=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON cd .. 
./run-clang-tidy.py -clang-tidy-binary ./clang-tidy-14.AppImage -p build -quiet -allow-enabling-alpha-checkers -extra-arg="-Xclang" -extra-arg="-analyzer-config" -extra-arg="-Xclang" -extra-arg="aggressive-binary-operation-simplification=true" "^(?!$(pwd)/build)" hosted: name: CPU-only if: github.ref == 'refs/heads/master' || (github.event.pull_request.draft == false && contains(github.event.pull_request.requested_reviewers.*.login, 'oneflow-ci-bot') && contains(github.event.pull_request.labels.*.name, 'need-simple-ci')) runs-on: ${{ matrix.os }} env: CFLAGS: "-w" CXXFLAGS: "-w" strategy: fail-fast: true max-parallel: 1 matrix: test_suite: ["mac", "ubuntu"] cmake_generator: ["Ninja", "Unix Makefiles"] cmake_build_type: ["Debug", "Release"] build_shared_libs: ["ON", "OFF"] include: - test_suite: mac os: "macos-10.15" make_concurrency: 2 - test_suite: ubuntu os: "ubuntu-22.04" make_concurrency: 2 exclude: - test_suite: mac cmake_build_type: "Debug" - test_suite: mac cmake_generator: "Ninja" - test_suite: ubuntu cmake_generator: "Ninja" cmake_build_type: "Debug" - test_suite: ubuntu cmake_generator: "Ninja" build_shared_libs: "OFF" - test_suite: ubuntu cmake_build_type: "Debug" build_shared_libs: "OFF" - test_suite: ubuntu cmake_generator: "Unix Makefiles" cmake_build_type: "Release" steps: - name: Set Swap Space uses: pierotofy/set-swap-space@master with: swap-size-gb: 5 - uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} - name: Install dependencies (homebrew) if: matrix.test_suite == 'mac' run: | brew install nasm ninja - name: Install dependencies (apt) if: matrix.test_suite == 'ubuntu' run: | sudo apt install -y libopenblas-dev nasm g++ gcc python3-pip ninja-build - name: Cache pip (Linux) if: startsWith(runner.os, 'Linux') uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ matrix.os }}-pip-${{ hashFiles('**/requirements.txt') }} - name: Cache pip (macOS) if: startsWith(runner.os, 'macOS') uses: actions/cache@v4 with: path: ~/Library/Caches/pip key: ${{ matrix.os }}-pip-${{ hashFiles('**/requirements.txt') }} - name: Install dependencies (pip) run: | python3 -m pip install -r ci/requirements.txt python3 -m pip install -r dev-requirements.txt - name: Set environment variables run: | set -x cmake_flags="" cmake_flags+=" -DBUILD_CUDA=OFF" cmake_flags+=" -DBUILD_TESTING=ON" cmake_flags+=" -G '${{ matrix.cmake_generator }}'" cmake_flags+=" -DCMAKE_BUILD_TYPE=${{ matrix.cmake_build_type }}" cmake_flags+=" -DBUILD_SHARED_LIBS=${{ matrix.build_shared_libs }}" cmake_flags+=" -DCMAKE_MACOSX_RPATH=FALSE" cmake_flags+=" -DCMAKE_BUILD_WITH_INSTALL_RPATH=FALSE" echo "cmake_flags=${cmake_flags}" >> $GITHUB_ENV - name: Build (third party) if: matrix.cmake_generator != 'Ninja' run: | set -x mkdir -p build-third_party mkdir -p third_party_install cd build-third_party cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=ON -DONEFLOW=OFF -DTHIRD_PARTY_DIR=$PWD/../third_party_install cmake --build . -j $(nproc) - name: Build (oneflow) if: matrix.cmake_generator != 'Ninja' run: | mkdir -p build cd build cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=OFF -DONEFLOW=ON -DTHIRD_PARTY_DIR=$PWD/../third_party_install cmake --build . -j ${{ matrix.make_concurrency }} --target oneflow - name: Build (oneflow_internal) if: always() && matrix.cmake_generator != 'Ninja' run: | mkdir -p build cd build cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=OFF -DONEFLOW=ON cmake --build . 
-j ${{ matrix.make_concurrency }} --target oneflow_internal - name: Build (oneflow_py) if: always() && matrix.cmake_generator != 'Ninja' run: | mkdir -p build cd build cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=OFF -DONEFLOW=ON cmake --build . -j ${{ matrix.make_concurrency }} --target oneflow_py - name: Build (oneflow_testexe) if: always() && matrix.cmake_generator != 'Ninja' run: | mkdir -p build cd build cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=OFF -DONEFLOW=ON cmake --build . -j ${{ matrix.make_concurrency }} --target oneflow_testexe - name: Build (ALL) if: always() continue-on-error: ${{ startsWith(runner.os, 'macOS') && matrix.cmake_generator == 'Ninja' && matrix.build_shared_libs == 'ON' }} run: | mkdir -p build cd build cmake .. ${{ env.cmake_flags }} cmake --build . -j ${{ matrix.make_concurrency }} - name: Exe test if: always() continue-on-error: true run: | ulimit -c ulimit -c unlimited ulimit -c mkdir -p build cd build ./bin/oneflow_testexe - name: Op test if: always() continue-on-error: true run: | ulimit -c ulimit -c unlimited ulimit -c source build/source.sh ONEFLOW_TEST_GITHUB_HOSTED=1 ONEFLOW_TEST_CPU_ONLY=1 bash ci/test/1node_op_test.sh - name: "Tar logs" if: always() && contains(github.event.pull_request.labels.*.name, 'need-simple-ci-upload-artifact') continue-on-error: true run: | set -ex if [[ -d "${HOME}/oneflow_temp" ]] then tar -cvf home_oneflow_temp.tar ${HOME}/oneflow_temp fi if [[ -d "${PWD}/test_tmp_dir" ]] then tar -cvf cwd_test_tmp_dir.tar ${PWD}/test_tmp_dir fi - name: Upload logs if: always() && contains(github.event.pull_request.labels.*.name, 'need-simple-ci-upload-artifact') uses: actions/upload-artifact@v4 with: name: logs-${{ matrix.test_suite }}-${{ matrix.cmake_generator }}-${{ matrix.cmake_build_type }}-shared-${{ matrix.build_shared_libs }} path: | home_oneflow_temp.tar cwd_test_tmp_dir.tar conda: name: Build with conda if: github.ref == 'refs/heads/master' || (github.event.pull_request.draft == false && contains(github.event.pull_request.requested_reviewers.*.login, 'oneflow-ci-bot') && contains(github.event.pull_request.labels.*.name, 'need-simple-ci')) runs-on: ubuntu-latest strategy: fail-fast: true max-parallel: 1 matrix: build-type: ["gcc7", "clang10"] steps: - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 - name: Checkout Oneflow-Inc/conda-env uses: actions/checkout@v2 with: repository: Oneflow-Inc/conda-env ref: 30a7f00eb48ee9009d85a848e720823e5054c66b path: conda-env - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build with gcc7 if: ${{ matrix.build-type == 'gcc7'}} with: cmake-init-cache: cmake/caches/ci/gh-hosted/cpu-gcc.cmake oneflow-src: . oneflow-build-env: conda conda-env-file: conda-env/dev/gcc7/environment-v2.yml conda-env-name: oneflow-dev-gcc7-v2 - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build with clang10 if: ${{ matrix.build-type == 'clang10'}} with: cmake-init-cache: cmake/caches/ci/gh-hosted/cpu-clang.cmake oneflow-src: . 
oneflow-build-env: conda conda-env-file: conda-env/dev/clang10/environment-v2.yml conda-env-name: oneflow-dev-clang10-v2 ================================================ FILE: .github/workflows/test.yml ================================================ name: Build and Test CI on: pull_request: types: [opened, review_requested, ready_for_review, synchronize, unlocked] merge_group: types: [checks_requested] concurrency: group: build-and-test-${{ github.ref }} cancel-in-progress: true env: OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }} OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }} ONEFLOW_TIMEOUT_SECONDS: 90 ONEFLOW_THRAED_LOCAL_CACHED_SIZE: 16384 FLOW_VISION_SRC: flow_vision FLOW_VISION_COMMIT: ca8ebc663b58667cf8cd1b6ef0c861522780b7bb LIBAI_SRC: libai LIBAI_COMMIT: 94eb85ff0131e8dfce953a3a916de7a4f897c647 ONEFLOW_FACE_SRC: oneflow_face ONEFLOW_FACE_COMMIT: 110a97e8d5737a1f1856281a7df556a5ac8f06de ONEFLOW_IREE_SRC: oneflow_iree ONEFLOW_IREE_COMMIT: 42fd479de7047e6af1d42c6e62b9b056e0a762aa ONE_FX_SRC: one-fx ONE_FX_COMMIT: da4051c7f1ace7a20b3f54395b580cd102fc99da TEST_WITH_TORCH_IMG_TAG: registry.cn-beijing.aliyuncs.com/oneflow/test-with-pytorch-1.10.0-cuda11.3-cudnn8-runtime:25817b5c0e1dd79bef8fdd43d729b98af381e7d5 MLIR_DOCKER_ARGS: "-e ONEFLOW_MLIR_ENABLE_ROUND_TRIP=1 -e ONEFLOW_MLIR_PREFER_NHWC=0 -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=1" SSH_TANK_HOST: 192.168.1.40 SSH_TANK_PATH: /data/tank jobs: source_info: name: Collect information about PR and source runs-on: ubuntu-22.04 if: github.event.pull_request.draft == false && github.base_ref == 'master' steps: - name: Check out OneFlow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} fetch-depth: 0 - name: Python diff id: py-diff run: | ONEFLOW_TEST_FILES="$(git diff --diff-filter=d --name-only ${{ github.event.pull_request.base.sha }} -- python/oneflow/test/**/test_*.py | { grep -v expensive || true; })" ONEFLOW_TEST_FILES=$(echo "${ONEFLOW_TEST_FILES}" | xargs) if [ -z "$ONEFLOW_TEST_FILES" ]; then echo "no changed python tests" echo "has_changed_python_tests=false" >> $GITHUB_OUTPUT else echo "changed python tests: ${ONEFLOW_TEST_FILES}" echo "has_changed_python_tests=true" >> $GITHUB_OUTPUT fi echo "changed_python_tests=${ONEFLOW_TEST_FILES}" >> $GITHUB_OUTPUT outputs: changed_python_tests: ${{ steps.py-diff.outputs.changed_python_tests }} has_changed_python_tests: ${{ steps.py-diff.outputs.has_changed_python_tests }} mirror_third_party: name: Mirror third party dependencies runs-on: ubuntu-22.04 if: github.event.pull_request.draft == false && github.base_ref == 'master' steps: - uses: actions/checkout@v2 - name: Mirror dependencies to aliyun if: github.event.pull_request.head.repo.full_name == github.repository run: | set -x if [ -z "$OSS_ACCESS_KEY_ID" ] then exit 0 fi python3 -m pip install -U pip "setuptools<=68.2.2" wheel python3 -m pip install 'cryptography<=3.4' oss2 python3 tools/package_mirror.py -i $PWD check_license_and_format: name: License and format runs-on: ubuntu-22.04 if: github.event.pull_request.draft == false steps: - uses: actions/checkout@v2 with: repository: ${{github.event.pull_request.head.repo.full_name}} ref: ${{ github.head_ref }} - name: Check license id: license_check run: | python3 ci/check/run_license_format.py -i oneflow -c python3 ci/check/run_license_format.py -i python -c - name: Add license id: license_fmt if: ${{ failure() }} run: | python3 ci/check/run_license_format.py -i oneflow 
--fix python3 ci/check/run_license_format.py -i python --fix - name: Check C++/CUDA format id: cpp_check run: | sudo apt install libtinfo5 python3 ci/check/run_clang_format.py --clang_format_binary clang-format --source_dir oneflow - name: Run C++/CUDA format id: cpp_fmt if: ${{ failure() }} run: | sudo apt install libtinfo5 python3 ci/check/run_clang_format.py --clang_format_binary clang-format --source_dir oneflow --fix - name: Check Python format id: py_check run: | python3 -m pip install black==19.10b0 click==8.0.0 python3 ci/check/run_py_format.py --source_dir $PWD - name: Run Python Format id: py_fmt if: ${{ failure() }} run: | python3 -m pip install black==19.10b0 --user python3 ci/check/run_py_format.py --source_dir $PWD --fix - name: Check CMake format id: cmake_check run: | python3 -m pip install cmakelang python3 ci/check/run_cmake_format.py --source_dir $PWD - name: Run CMake Format id: cmake_fmt if: ${{ failure() }} run: | python3 -m pip install cmakelang python3 ci/check/run_cmake_format.py --source_dir $PWD --fix - name: Git push id: git_push if: ${{ failure() }} run: | git diff -p > license_and_format.patch cat license_and_format.patch git config --global user.email "ci-bot@oneflow.org" git config --global user.name "oneflow-ci-bot" git add -u git commit -m "auto format by CI" git push - name: Upload patch if: ${{ failure() && steps.git_push.outcome == 'failure' }} uses: actions/upload-artifact@v4 with: name: license_and_format-${{ github.sha }}.patch path: license_and_format.patch - name: Add comment if: ${{ failure() }} uses: actions/github-script@v4 with: script: | github.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, body: 'Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.' }) - name: Please request CI again if: ${{ failure() }} run: | exit 1 - name: Check source code (prevent creating files at wrong places) run: | python3 tools/check_src.py find-build-cache: name: "Find build cache" if: github.event.pull_request.draft == false && github.base_ref == 'master' runs-on: ubuntu-latest env: ONEFLOW_SRC: . outputs: matrix: ${{ steps.find-cache.outputs.matrix }} steps: - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/build@ci-test-with-cu118 name: find cache id: find-cache timeout-minutes: 5 with: delete-cache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }} runner-labels: | self-hosted linux builder oneflow-src: ${{ env.ONEFLOW_SRC }} entries: | cu118 cpu cpu-asan-ubsan cpu-tsan llvm15 build-oneflow: name: "Build OneFlow" if: github.event.pull_request.draft == false && github.base_ref == 'master' runs-on: ${{ matrix.runs-on }} needs: [find-build-cache] timeout-minutes: 80 strategy: fail-fast: true max-parallel: 5 matrix: ${{ fromJson(needs.find-build-cache.outputs.matrix) }} env: ONEFLOW_SRC: . 
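A note on the License and format job above: when a check fails, the job re-runs the same ci/check scripts with --fix and pushes the result. A contributor can apply the same fixes locally before pushing, provided the pinned formatters the workflow installs (black==19.10b0, clang-format, cmakelang) are available. A minimal sketch, assuming it is run from the repository root (the subprocess wrapper is illustrative; the script paths and flags are the ones the workflow uses):

```python
# Illustrative local helper: run the same format fixers that the
# "License and format" CI job applies, from the repository root.
import os
import subprocess

cwd = os.getcwd()
FIX_COMMANDS = [
    ["python3", "ci/check/run_license_format.py", "-i", "oneflow", "--fix"],
    ["python3", "ci/check/run_license_format.py", "-i", "python", "--fix"],
    ["python3", "ci/check/run_clang_format.py",
     "--clang_format_binary", "clang-format", "--source_dir", "oneflow", "--fix"],
    ["python3", "ci/check/run_py_format.py", "--source_dir", cwd, "--fix"],
    ["python3", "ci/check/run_cmake_format.py", "--source_dir", cwd, "--fix"],
]

for cmd in FIX_COMMANDS:
    subprocess.run(cmd, check=True)  # stop at the first fixer that errors out
```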
MANYLINUX_CACHE_DIR: ~/manylinux-cache-dir/${{ matrix.entry }} WHEELHOUSE_DIR: manylinux-wheelhouse steps: - name: Set proxy if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | echo "https_proxy=${{ secrets.ONEFLOW_CI_HTTP_PROXY }}" >> $GITHUB_ENV - name: Fix permissions if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | set -x docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf * - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - uses: Oneflow-Inc/get-oneflow/cache-complete@ci-test-with-cu118 name: Save cache if successful id: save-cache timeout-minutes: 5 with: oneflow-src: ${{ env.ONEFLOW_SRC }} entry: ${{ matrix.entry }} digest-type: build mark-as-completed: ${{ contains(matrix.runs-on, 'self-hosted') && github.event.pull_request.head.repo.full_name == github.repository }} - name: Check digest cache result. If this step failed, usually it is caused by new commits pushed when this CI run is running. if: ${{ fromJSON(steps.save-cache.outputs.cache-hit) != matrix.cache-hit }} run: | echo "::error file=test.yml,line=204,col=10::steps.save-cache.outputs.cache-hit != matrix.cache-hit" exit 1 - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build manylinux ${{ matrix.entry }} id: build-cpu if: ${{ matrix.entry =='cpu' && !matrix.cache-hit }} with: cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/cpu.cmake build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build.sh run-lit: true oneflow-src: ${{ env.ONEFLOW_SRC }} oneflow-build-env: manylinux wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} clear-wheelhouse-dir: true self-hosted: ${{ contains(matrix.runs-on, 'self-hosted') }} cuda-version: none manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} docker-run-use-system-http-proxy: false docker-run-use-lld: true retry-failed-build: true clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }} python-versions: | 3.7 3.8 - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build manylinux ${{ matrix.entry }} id: build-cpu-sanitizers if: ${{ (matrix.entry == 'cpu-asan-ubsan' || matrix.entry == 'cpu-tsan') && !matrix.cache-hit && false }} with: cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/${{ matrix.entry }}.cmake build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build.sh run-lit: false oneflow-src: ${{ env.ONEFLOW_SRC }} oneflow-build-env: manylinux wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} clear-wheelhouse-dir: true self-hosted: ${{ contains(matrix.runs-on, 'self-hosted') }} cuda-version: none manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} docker-run-use-system-http-proxy: false docker-run-use-lld: true retry-failed-build: true clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }} python-versions: | 3.8 - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build manylinux ${{ matrix.entry }} id: build-cuda if: ${{ matrix.entry =='cu118' && !matrix.cache-hit }} with: cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/cuda.cmake build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build-gcc9.sh oneflow-src: ${{ env.ONEFLOW_SRC }} oneflow-build-env: manylinux wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} clear-wheelhouse-dir: true self-hosted: ${{ contains(matrix.runs-on, 'self-hosted') }} cuda-version: "11.8" manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} docker-run-use-system-http-proxy: false docker-run-use-lld: false retry-failed-build: true 
clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }} python-versions: | 3.7 - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118 name: Build ${{ matrix.entry }} if: ${{ matrix.entry == 'llvm15' && !matrix.cache-hit }} with: cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/llvm/cuda-75-clang.cmake build-script: ${{ env.ONEFLOW_SRC }}/ci/clang/build-llvm.sh oneflow-src: ${{ env.ONEFLOW_SRC }} oneflow-build-env: llvm wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} clear-wheelhouse-dir: true self-hosted: true cuda-version: ${{ env.CUDA_VERSION }} manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} docker-run-use-system-http-proxy: false docker-run-use-lld: false retry-failed-build: true clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }} wheel-audit: false python-versions: | 3.8 - name: Remove automerge if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }} uses: actions/github-script@v4 with: script: | github.issues.removeLabel({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, name: 'automerge' }) github.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, body: 'CI failed when running job: Build ${{ matrix.entry }}. PR label automerge has been removed' }) - name: Upload packed liboneflow if: ${{ !fromJson(matrix.cache-hit) && matrix.entry != 'llvm15' && matrix.entry != 'cpu-asan-ubsan' && matrix.entry != 'cpu-tsan' }} uses: Oneflow-Inc/get-oneflow/digest/upload@ci-test-with-cu118 timeout-minutes: 10 with: digest: ${{ steps.save-cache.outputs.build-digest }} entry: ${{ matrix.entry }} ssh-tank-host: ${{ env.SSH_TANK_HOST }} ssh-tank-path: ${{ env.SSH_TANK_PATH }} src-dir: ${{ env.MANYLINUX_CACHE_DIR }}/build/cpack dst-dir: cpack - name: Upload whl if: ${{ !fromJson(matrix.cache-hit) && matrix.entry != 'llvm15' && matrix.entry != 'cpu-asan-ubsan' && matrix.entry != 'cpu-tsan' }} uses: Oneflow-Inc/get-oneflow/digest/upload@ci-test-with-cu118 timeout-minutes: 10 with: digest: ${{ steps.save-cache.outputs.build-digest }} entry: ${{ matrix.entry }} ssh-tank-host: ${{ env.SSH_TANK_HOST }} ssh-tank-path: ${{ env.SSH_TANK_PATH }} src-dir: ${{ env.WHEELHOUSE_DIR }} dst-dir: whl find-test-cache-distributed: name: "Find test cache (distributed)" if: github.event.pull_request.draft == false && github.base_ref == 'master' && contains(github.event.pull_request.labels.*.name, 'need-test-distributed') runs-on: ubuntu-latest needs: [build-oneflow] env: ONEFLOW_SRC: . outputs: matrix: ${{ steps.find-cache.outputs.matrix }} steps: - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/test@ci-test-with-cu118 name: find cache id: find-cache timeout-minutes: 5 with: runner-labels: | self-hosted linux oneflow-src: ${{ env.ONEFLOW_SRC }} include-distributed: true world-size: 2 devices: | cuda tests: | module find-test-cache: name: "Find test cache" if: github.event.pull_request.draft == false && github.base_ref == 'master' runs-on: ubuntu-latest needs: [build-oneflow] env: ONEFLOW_SRC: . 
outputs: matrix: ${{ steps.find-cache.outputs.matrix }} steps: - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/test@ci-test-with-cu118 name: find cache id: find-cache timeout-minutes: 5 with: runner-labels: | self-hosted linux oneflow-src: ${{ env.ONEFLOW_SRC }} devices: | cuda cpu tests: | module misc speed-test test-distributed: name: Distributed test suite needs: [find-test-cache-distributed, test] runs-on: ${{ matrix.runs-on }} timeout-minutes: 120 if: github.event.pull_request.draft == false && github.base_ref == 'master' && contains(github.event.pull_request.labels.*.name, 'need-test-distributed') concurrency: group: distributed-test-${{ matrix.entry }}-rank-${{ matrix.rank }} cancel-in-progress: false strategy: fail-fast: true max-parallel: 2 matrix: ${{ fromJson(needs.find-test-cache-distributed.outputs.matrix) }} env: ONEFLOW_SRC: . TEST_CONTAINER_NAME: "ci-test-distributed" steps: - name: Fix permissions if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | set -x docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf * docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf .pytest_cache - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - name: Checkout Oneflow-Inc/vision if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/vision # please use a commit here ref: ${{ env.FLOW_VISION_COMMIT}} path: ${{ env.FLOW_VISION_SRC}} - name: Checkout Oneflow-Inc/one-fx if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/one-fx # please use a commit here ref: ${{ env.ONE_FX_COMMIT}} path: ${{ env.ONE_FX_SRC}} - name: Checkout Oneflow-Inc/libai if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/libai # please use a commit here ref: ${{ env.LIBAI_COMMIT}} path: ${{ env.LIBAI_SRC}} - name: Checkout Oneflow-Inc/oneflow_iree if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/oneflow_iree # please use a commit here ref: ${{ env.ONEFLOW_IREE_COMMIT}} path: ${{ env.ONEFLOW_IREE_SRC}} - name: Remove container timeout-minutes: 45 if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true - uses: Oneflow-Inc/get-oneflow/cache-complete@ci-test-with-cu118 name: Save cache if successful id: save-cache timeout-minutes: 5 with: oneflow-src: ${{ env.ONEFLOW_SRC }} entry: ${{ matrix.entry }} digest-type: ${{ matrix.digest-type }} mark-as-completed: ${{ contains(matrix.runs-on, 'self-hosted') && github.event.pull_request.head.repo.full_name == github.repository }} - name: Check digest cache result. If this step failed, usually it is caused by new commits pushed when this CI run is running. 
if: ${{ fromJSON(steps.save-cache.outputs.cache-hit) != matrix.cache-hit }} run: | echo "::error file=test.yml,line=204,col=10::steps.save-cache.outputs.cache-hit != matrix.cache-hit" exit 1 - name: Download wheel and packed liboneflow if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: Oneflow-Inc/get-oneflow/digest/download@ci-test-with-cu118 id: download-digest timeout-minutes: 10 with: digest: ${{ steps.save-cache.outputs.build-digest }} entry: ${{ matrix.compute-platform }} ssh-tank-host: ${{ env.SSH_TANK_HOST }} ssh-tank-path: ${{ env.SSH_TANK_PATH }} - name: Get primary node if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: Oneflow-Inc/get-oneflow/master-address@ci-test-with-cu118 id: get-primary-node with: rank: ${{ matrix.rank }} - name: Set environment variables if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | set -x extra_docker_args="" if [ "${{ matrix.device }}" == "cpu" ]; then extra_docker_args+=" --env ONEFLOW_TEST_CPU_ONLY=1" extra_docker_args+=" --env CUDA_VISIBLE_DEVICES=-1" fi echo "EXTRA_DOCKER_ARGS=${extra_docker_args}" >> $GITHUB_ENV echo "ONEFLOW_TEST_CACHE_DIR=$HOME/ci-cache/test_cache" >> $GITHUB_ENV echo "ONEFLOW_TEST_DATASET_DIR=$HOME/dataset" >> $GITHUB_ENV echo "ONEFLOW_WHEEL_PATH=${{ steps.download-digest.outputs.entry-dir }}/whl" >> $GITHUB_ENV echo "ONEFLOW_CPACK_PATH=${{ steps.download-digest.outputs.entry-dir }}/cpack" >> $GITHUB_ENV - name: Set environment variables (distributed) if: ${{ fromJson(matrix.is-distributed) }} run: | set -x EXTRA_DOCKER_ARGS+=" --network host " echo "EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}" >> $GITHUB_ENV - name: Enable ONEFLOW_TEST_VERBOSE if: ${{ contains(github.event.pull_request.labels.*.name, 'need-test-verbose') }} run: | EXTRA_DOCKER_ARGS+=" --env ONEFLOW_TEST_VERBOSE=1" echo "EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}" >> $GITHUB_ENV - name: Start container if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} working-directory: ${{ env.ONEFLOW_SRC }} run: | docker run --gpus=all -d --rm --privileged --shm-size=8g \ --pids-limit 2000 \ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ -v ${ONEFLOW_TEST_DATASET_DIR}:${ONEFLOW_TEST_DATASET_DIR}:ro \ -v ${ONEFLOW_WHEEL_PATH}:${ONEFLOW_WHEEL_PATH}:ro \ -v $HOME/test-container-cache/dot-local:/root/.local \ -v $HOME/test-container-cache/dot-cache:/root/.cache \ -e NODE_RANK=${{ matrix.rank }} \ -e _MASTER_ADDR=${{ steps.get-primary-node.outputs.master-address }} \ -e ONEFLOW_WHEEL_PATH=${ONEFLOW_WHEEL_PATH} \ -e ONEFLOW_CI=1 \ -v $PWD:$PWD \ -w $PWD \ -v ${ONEFLOW_TEST_CACHE_DIR}:${ONEFLOW_TEST_CACHE_DIR} \ -e ONEFLOW_TEST_CACHE_DIR=${ONEFLOW_TEST_CACHE_DIR} \ -e ONEFLOW_TEST_DATASET_DIR=${ONEFLOW_TEST_DATASET_DIR} \ -e ONEFLOW_TIMEOUT_SECONDS=${{ env.ONEFLOW_TIMEOUT_SECONDS }} \ -e ONEFLOW_THRAED_LOCAL_CACHED_SIZE=${{ env.ONEFLOW_THRAED_LOCAL_CACHED_SIZE }} \ ${{ env.MLIR_DOCKER_ARGS }} \ --name ${TEST_CONTAINER_NAME} \ ${{ env.EXTRA_DOCKER_ARGS }} \ ${{ env.TEST_WITH_TORCH_IMG_TAG }} \ sleep 5400 - name: Test container if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} ls docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m pip list - name: Install OneFlow if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | ls ${ONEFLOW_WHEEL_PATH} docker exec ${TEST_CONTAINER_NAME} python3 -m pip config set global.index-url 
https://pypi.tuna.tsinghua.edu.cn/simple docker exec ${TEST_CONTAINER_NAME} python3 -m pip install --find-links=${ONEFLOW_WHEEL_PATH} oneflow - name: Install downstream libs if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.FLOW_VISION_SRC}} docker exec ${TEST_CONTAINER_NAME} python3 -m pip install pybind11 --user docker exec ${TEST_CONTAINER_NAME} python3 -m pip install tensorboardX==2.6 --user docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.LIBAI_SRC}} docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONEFLOW_IREE_SRC}} docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONE_FX_SRC}} - name: Module API test (distributed) timeout-minutes: 90 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' && matrix.device == 'cuda' && fromJson(matrix.is-distributed) }} continue-on-error: false run: | docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/modules ${{ env.TEST_CONTAINER_NAME }} bash ci/test/2node_op_test_multi_client.sh - name: Module API test (distributed, without IB) timeout-minutes: 60 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' && matrix.device == 'cuda' && fromJson(matrix.is-distributed) && contains(github.event.pull_request.labels.*.name, 'need-distributed-without-ib')}} continue-on-error: false run: | docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/modules \ -e ONEFLOW_LIBIBVERBS_PATH=invalid_lib \ -e ONEFLOW_CI_DEVICE_NUMS="4" \ ${{ env.TEST_CONTAINER_NAME }} bash ci/test/2node_op_test_multi_client.sh - name: Print stacks in all core files timeout-minutes: 45 if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/print_stack_in_all_dirs.sh || true - name: Remove automerge if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }} uses: actions/github-script@v4 with: script: | github.issues.removeLabel({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, name: 'automerge' }) github.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, body: 'CI failed when running job: ${{ matrix.entry }}. PR label automerge has been removed' }) - name: Remove container timeout-minutes: 45 if: ${{ always() && contains(matrix.runs-on, 'self-hosted') }} run: | docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf * test: name: Test suite needs: [find-test-cache, source_info] timeout-minutes: 120 runs-on: ${{ matrix.runs-on }} if: github.event.pull_request.draft == false && github.base_ref == 'master' strategy: fail-fast: ${{ !contains(github.event.pull_request.labels.*.name, 'need-all-tests-even-fail') }} max-parallel: 10 matrix: ${{ fromJson(needs.find-test-cache.outputs.matrix) }} env: ONEFLOW_SRC: . 
TEST_CONTAINER_NAME: "pr-${{ github.event.pull_request.number }}-run-id-${{ github.run_id }}-${{ matrix.entry }}-test" TEST_MANYLINUX_CONTAINER_NAME: "pr-${{ github.event.pull_request.number }}-run-id-${{ github.run_id }}-${{ matrix.entry }}-test-manylinux" TEST_WITH_TF_IMG_TAG: registry.cn-beijing.aliyuncs.com/oneflow/test-with-tf-2.3.0:2f831e9354298a11447578e869d983959feb046f TEST_MANYLINUX_IMG_TAG: registry.cn-beijing.aliyuncs.com/oneflow/manylinux2014_x86_64_cuda11.8:6455f9b8154333333e6285fde3747aaac4a92929 METRICS_DIR: metrics steps: - name: Set proxy if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | echo "https_proxy=${{ secrets.ONEFLOW_CI_HTTP_PROXY }}" >> $GITHUB_ENV - name: Fix permissions if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | set -x docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf * docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf .pytest_cache - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - name: Checkout Oneflow-Inc/vision if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/vision # please use a commit here ref: ${{ env.FLOW_VISION_COMMIT}} path: ${{ env.FLOW_VISION_SRC}} - name: Checkout Oneflow-Inc/libai if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/libai # please use a commit here ref: ${{ env.LIBAI_COMMIT}} path: ${{ env.LIBAI_SRC}} - name: Checkout Oneflow-Inc/oneflow_face if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/oneflow_face # please use a commit here ref: ${{ env.ONEFLOW_FACE_COMMIT}} path: ${{ env.ONEFLOW_FACE_SRC}} - name: Checkout Oneflow-Inc/oneflow_iree if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/oneflow_iree # please use a commit here ref: ${{ env.ONEFLOW_IREE_COMMIT}} path: ${{ env.ONEFLOW_IREE_SRC}} - name: Checkout Oneflow-Inc/one-fx if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/one-fx # please use a commit here ref: ${{ env.ONE_FX_COMMIT}} path: ${{ env.ONE_FX_SRC}} - name: Remove container timeout-minutes: 45 if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true - name: Remove manylinux container timeout-minutes: 45 if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | docker rm -f ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} || true - uses: Oneflow-Inc/get-oneflow/cache-complete@ci-test-with-cu118 name: Save cache if successful id: save-cache timeout-minutes: 5 with: oneflow-src: ${{ env.ONEFLOW_SRC }} entry: ${{ matrix.entry }} digest-type: ${{ matrix.digest-type }} mark-as-completed: ${{ contains(matrix.runs-on, 'self-hosted') && github.event.pull_request.head.repo.full_name == github.repository }} - name: Check digest cache result. If this step failed, usually it is caused by new commits pushed when this CI run is running. 
if: ${{ fromJSON(steps.save-cache.outputs.cache-hit) != matrix.cache-hit }} run: | echo "::error file=test.yml,line=204,col=10::steps.save-cache.outputs.cache-hit != matrix.cache-hit" exit 1 - name: Download wheel and packed liboneflow if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} uses: Oneflow-Inc/get-oneflow/digest/download@ci-test-with-cu118 id: download-digest timeout-minutes: 10 with: digest: ${{ steps.save-cache.outputs.build-digest }} entry: ${{ matrix.compute-platform }} ssh-tank-host: ${{ env.SSH_TANK_HOST }} ssh-tank-path: ${{ env.SSH_TANK_PATH }} - name: Download ASAN and UBSAN wheel and packed liboneflow if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && matrix.device == 'cpu' && false }} uses: Oneflow-Inc/get-oneflow/digest/download@ci-test-with-cu118 id: asan-ubsan-download-digest timeout-minutes: 10 with: digest: ${{ steps.save-cache.outputs.build-digest }} entry: cpu-asan-ubsan ssh-tank-host: ${{ env.SSH_TANK_HOST }} ssh-tank-path: ${{ env.SSH_TANK_PATH }} - name: Download TSAN wheel and packed liboneflow if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && matrix.device == 'cpu' && false }} uses: Oneflow-Inc/get-oneflow/digest/download@ci-test-with-cu118 id: tsan-download-digest timeout-minutes: 10 with: digest: ${{ steps.save-cache.outputs.build-digest }} entry: cpu-tsan ssh-tank-host: ${{ env.SSH_TANK_HOST }} ssh-tank-path: ${{ env.SSH_TANK_PATH }} - name: Enable TF container if: ${{ fromJSON(matrix.is-single-client) }} run: | echo "TEST_IMG_TAG=${TEST_WITH_TF_IMG_TAG}" >> $GITHUB_ENV - name: Enable Pytorch container if: ${{ !fromJSON(matrix.is-single-client) }} run: | echo "TEST_IMG_TAG=${TEST_WITH_TORCH_IMG_TAG}" >> $GITHUB_ENV - name: Set environment variables if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | set -x extra_docker_args="" if [ "${{ matrix.device }}" == "cpu" ]; then extra_docker_args+=" --env ONEFLOW_TEST_CPU_ONLY=1" extra_docker_args+=" --env CUDA_VISIBLE_DEVICES=-1" fi echo "EXTRA_DOCKER_ARGS=${extra_docker_args}" >> $GITHUB_ENV echo "ONEFLOW_TEST_CACHE_DIR=$HOME/ci-cache/test_cache" >> $GITHUB_ENV echo "ONEFLOW_TEST_DATASET_DIR=$HOME/dataset" >> $GITHUB_ENV echo "ONEFLOW_WHEEL_PATH=${{ steps.download-digest.outputs.entry-dir }}/whl" >> $GITHUB_ENV echo "ONEFLOW_CPACK_PATH=${{ steps.download-digest.outputs.entry-dir }}/cpack" >> $GITHUB_ENV echo "DOCS_PATH=docs/${{ github.repository }}/pr/${{ github.event.pull_request.number }}" >> $GITHUB_ENV - name: Set environment variables (experimental flags) if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && fromJson(matrix.is-experimental) }} run: | EXTRA_DOCKER_ARGS+=" --env ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1" EXTRA_DOCKER_ARGS+=" --env ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE=1" EXTRA_DOCKER_ARGS+=" --env ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER=1" echo "EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}" >> $GITHUB_ENV - name: Set Thread Limit (CPU) if: ${{ !fromJson(matrix.cache-hit) && matrix.device == 'cpu' }} run: | echo "THREAD_LIMIT=25000" >> $GITHUB_ENV - name: Set Thread Limit (CUDA) if: ${{ !fromJson(matrix.cache-hit) && matrix.device == 'cuda' }} run: | echo "THREAD_LIMIT=20000" >> $GITHUB_ENV - name: Enable ONEFLOW_TEST_VERBOSE if: ${{ contains(github.event.pull_request.labels.*.name, 'need-test-verbose') }} run: | EXTRA_DOCKER_ARGS+=" --env ONEFLOW_TEST_VERBOSE=1" echo "EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}" >> $GITHUB_ENV - name: Pull image 
continue-on-error: true if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | docker pull ${{ env.TEST_IMG_TAG }} - name: Unzip packed liboneflow if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && !fromJson(matrix.is-xla) }} run: | unzip ${{ env.ONEFLOW_CPACK_PATH }}/liboneflow-ci-linux.zip - name: Unzip packed sanitized liboneflow if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && !fromJson(matrix.is-xla) && matrix.device == 'cpu' && false }} run: | unzip ${{ steps.asan-ubsan-download-digest.outputs.entry-dir }}/cpack/liboneflow-ci-linux.zip -d asan-ubsan unzip ${{ steps.tsan-download-digest.outputs.entry-dir }}/cpack/liboneflow-ci-linux.zip -d tsan - name: Start container if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} working-directory: ${{ env.ONEFLOW_SRC }} run: | docker run --gpus=all -d --rm --privileged --shm-size=8g \ --pids-limit ${{ env.THREAD_LIMIT }} \ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ -v ${ONEFLOW_TEST_DATASET_DIR}:${ONEFLOW_TEST_DATASET_DIR}:ro \ -v ${ONEFLOW_WHEEL_PATH}:${ONEFLOW_WHEEL_PATH}:ro \ -v $HOME/test-container-cache/dot-local:/root/.local \ -v $HOME/test-container-cache/dot-cache:/root/.cache \ -e ONEFLOW_WHEEL_PATH=${ONEFLOW_WHEEL_PATH} \ -e ONEFLOW_CI=1 \ -e NVIDIA_TF32_OVERRIDE=0 \ -e NCCL_P2P_DISABLE=1 \ -v $PWD:$PWD \ -w $PWD \ -v ${ONEFLOW_TEST_CACHE_DIR}:${ONEFLOW_TEST_CACHE_DIR} \ -e ONEFLOW_TEST_CACHE_DIR=${ONEFLOW_TEST_CACHE_DIR} \ -e ONEFLOW_TEST_DATASET_DIR=${ONEFLOW_TEST_DATASET_DIR} \ -e ONEFLOW_TIMEOUT_SECONDS=${{ env.ONEFLOW_TIMEOUT_SECONDS }} \ -e ONEFLOW_THRAED_LOCAL_CACHED_SIZE=${{ env.ONEFLOW_THRAED_LOCAL_CACHED_SIZE }} \ ${{ env.MLIR_DOCKER_ARGS }} \ --name ${TEST_CONTAINER_NAME} \ ${{ env.EXTRA_DOCKER_ARGS }} \ ${{ env.TEST_IMG_TAG }} \ sleep 7200 - name: Start manylinux container if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} working-directory: ${{ env.ONEFLOW_SRC }} # For unknown reason we need to disable the requirement from nvidia docker # by -e NVIDIA_DISABLE_REQUIRE=true run: | docker run --gpus=all -d --rm --privileged --shm-size=8g \ --pids-limit ${{ env.THREAD_LIMIT }} \ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ -v ${ONEFLOW_TEST_DATASET_DIR}:${ONEFLOW_TEST_DATASET_DIR}:ro \ -v ${ONEFLOW_WHEEL_PATH}:${ONEFLOW_WHEEL_PATH}:ro \ -v $HOME/test-container-cache/dot-local:/root/.local \ -v $HOME/test-container-cache/dot-cache:/root/.cache \ -e NVIDIA_DISABLE_REQUIRE=true \ -e ONEFLOW_WHEEL_PATH=${ONEFLOW_WHEEL_PATH} \ -e ONEFLOW_CI=1 \ -v $PWD:$PWD \ -w $PWD \ -v ${ONEFLOW_TEST_CACHE_DIR}:${ONEFLOW_TEST_CACHE_DIR} \ -e ONEFLOW_TEST_CACHE_DIR=${ONEFLOW_TEST_CACHE_DIR} \ -e ONEFLOW_TEST_DATASET_DIR=${ONEFLOW_TEST_DATASET_DIR} \ -e ONEFLOW_TIMEOUT_SECONDS=${{ env.ONEFLOW_TIMEOUT_SECONDS }} \ -e ONEFLOW_THRAED_LOCAL_CACHED_SIZE=${{ env.ONEFLOW_THRAED_LOCAL_CACHED_SIZE }} \ ${{ env.MLIR_DOCKER_ARGS }} \ --name ${TEST_MANYLINUX_CONTAINER_NAME} \ ${{ env.EXTRA_DOCKER_ARGS }} \ ${{ env.TEST_MANYLINUX_IMG_TAG }} \ sleep 7200 - name: Exe test if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }} timeout-minutes: 20 run: | docker exec ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} ./liboneflow-ci-linux/bin/oneflow_testexe - name: Exe test (C++ API) if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }} timeout-minutes: 20 run: | docker exec -e ONEFLOW_SERVING_DEBUG=1 ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} 
./liboneflow-ci-linux/bin/oneflow_cpp_api_testexe --gtest_filter=-Api.embedding* - name: Exe test (C++ API with sanitizers) if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' && false }} timeout-minutes: 10 run: | docker exec -e UBSAN_OPTIONS=suppressions=.ubsan-suppressions -e ASAN_OPTIONS=strict_string_checks=1:detect_stack_use_after_return=1 -e LSAN_OPTIONS=suppressions=.lsan-suppressions ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} ./asan-ubsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe --gtest_filter=Api.graph_\* # Run 5 times to avoid false positive because of occasional lack of stack info docker exec -e TSAN_OPTIONS="history_size=7 suppressions=.tsan-suppressions" ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} bash -c "./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe || ./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe || ./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe || ./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe || ./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe" - name: Test container if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} ls docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m pip list - name: Install OneFlow if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | ls ${ONEFLOW_WHEEL_PATH} docker exec ${TEST_CONTAINER_NAME} python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -U --find-links=${ONEFLOW_WHEEL_PATH} oneflow - name: Install downstream libs if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.FLOW_VISION_SRC}} docker exec ${TEST_CONTAINER_NAME} python3 -m pip install pybind11 --user docker exec ${TEST_CONTAINER_NAME} python3 -m pip install tensorboardX==2.6 --user docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.LIBAI_SRC}} docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONEFLOW_FACE_SRC}} docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONEFLOW_IREE_SRC}} docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONE_FX_SRC}} - name: Run OneFlow doctor if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow --doctor - name: Build documentation timeout-minutes: 10 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/build_docs.sh - name: Upload documentation id: upload-docs if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' && github.repository == 'Oneflow-Inc/oneflow' }} continue-on-error: true uses: ./.github/actions/upload_oss with: src_path: build-docs/build/html oss_dst_path: oss://oneflow-staging/${{ env.DOCS_PATH }} oss_access_key_id: ${{ secrets.OSS_ACCESS_KEY_ID }} oss_access_key_secret: ${{ secrets.OSS_ACCESS_KEY_SECRET }} - name: Post docs url if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' && github.repository == 'Oneflow-Inc/oneflow' && steps.upload-docs.outcome == 'success' }} continue-on-error: true uses: actions/github-script@v4 with: script: | github.issues.createComment({ issue_number: context.issue.number, owner: 
context.repo.owner, repo: context.repo.repo, body: "View latest API docs preview at: https://oneflow-staging.oss-cn-beijing.aliyuncs.com/${{ env.DOCS_PATH }}/" }) - name: Doctest timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cuda' }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/doctest.sh - name: Checkout Oneflow-Inc/models if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }} uses: actions/checkout@v2 with: repository: Oneflow-Inc/models ref: d6b2b8260e87541726ed87361171438d258e6a4d path: oneflow-models - name: ResNet50 Graph DDP test id: models-resnet50 timeout-minutes: 20 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }} run: | docker exec -e NCCL_DEBUG=INFO -e ONEFLOW_MODELS_DIR=$PWD/oneflow-models ${{ env.TEST_CONTAINER_NAME }} bash ci/test/test_resnet50_graph_ddp.sh - name: Speed test id: speed timeout-minutes: 20 continue-on-error: ${{ !contains(github.event.pull_request.labels.*.name, 'need-pass-speed-test') }} if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }} run: | docker exec -e ONEFLOW_MODELS_DIR=$PWD/oneflow-models ${{ env.TEST_CONTAINER_NAME }} bash ci/test/test_speed_multi_client.sh - name: Save speed stats if: ${{ always() && !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }} run: | mkdir -p ${{ env.METRICS_DIR }} echo "${{ steps.speed.outputs.stats }}" >> ${{ env.METRICS_DIR }}/speed_stats.txt - name: Upload speed stats if: ${{ always() && !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }} # must succeed if it is a branch of Oneflow-Inc/oneflow continue-on-error: ${{ !(github.repository == 'Oneflow-Inc/oneflow') }} uses: ./.github/actions/upload_oss with: src_path: ${{ env.METRICS_DIR }} oss_dst_path: oss://oneflow-log/${{ github.repository }}/metrics/pr/${{ github.event.pull_request.number }}/${{ github.event.pull_request.head.sha }}/${{github.run_id}} oss_access_key_id: ${{ secrets.OSS_ACCESS_KEY_ID }} oss_access_key_secret: ${{ secrets.OSS_ACCESS_KEY_SECRET }} - name: Post speed stats if: ${{ always() && !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }} continue-on-error: true uses: actions/github-script@v4 with: script: | github.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, body: "
\n Speed stats:\n\n ``` \n${{ steps.speed.outputs.stats }}\n ``` \n\n
".replace(/\\n/g, '\n') }) - name: Run tests in changed files compared to default branch 100 times timeout-minutes: 60 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' && !fromJson(matrix.is-distributed) && steps.py-diff.outputs.has_changed_python_tests }} run: | docker exec -e ONEFLOW_TEST_DIR=diff \ -e ONEFLOW_TEST_FILES="${{needs.source_info.outputs.changed_python_tests}}" \ ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh - name: Expensive tests (models, cases require exclusive access to GPU) timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && (matrix.test-type == 'speed-test' || (matrix.test-type == 'misc' && matrix.device == 'cuda')) && !fromJson(matrix.is-distributed) }} run: | docker exec \ -e ONEFLOW_TEST_TENSOR_SIZE_LIMIT_MB=1024 \ -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/expensive \ ${{ env.TEST_CONTAINER_NAME }} bash ci/test/expensive_generic_test_multi_client.sh - name: Module API test timeout-minutes: 60 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' && !fromJson(matrix.is-distributed) }} run: | docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/modules ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh - name: Graph API test timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }} run: | docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/graph ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 8 $PWD/python/oneflow/test/graph/test_neq_device_process_num.py - name: libai test timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cuda' }} run: | docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.LIBAI_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m unittest -f tests/models/test_bert.py docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.LIBAI_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m unittest -f tests/models/test_gpt.py docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.LIBAI_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m unittest -f tests/models/test_t5.py docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.LIBAI_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m unittest -f tests/models/test_vit.py - name: oneflow_face test timeout-minutes: 30 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cuda' }} run: | docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.ONEFLOW_FACE_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest tests/train/test_train.py - name: oneflow_iree test timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && false }} run: | docker exec -w $PWD/${{ env.ONEFLOW_IREE_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m pytest examples - name: IR tests timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && (matrix.test-type == 'misc' && matrix.device == 'cuda') && !fromJson(matrix.is-distributed) }} run: | docker exec \ -e ONEFLOW_TEST_TENSOR_SIZE_LIMIT_MB=1024 \ ${{ env.TEST_CONTAINER_NAME }} bash ci/test/ir_tests.sh - name: Exception API test timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && 
matrix.test-type == 'misc' && false }} run: docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/multi_client_exception_test.sh - name: Misc test timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }} run: | docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/misc ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh - name: Dataloader API test timeout-minutes: 45 # TODO(luyang): dataset check fails if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && false}} run: | docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/dataloader ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh - name: Tensor API test timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }} run: | docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/tensor ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh - name: Test mocking torch by script timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} bash -x ci/test/test_mock_script.sh - name: Test mocking torch by function timeout-minutes: 45 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} bash -x ci/test/test_mock_function.sh - name: Benchmark Test timeout-minutes: 100 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'benchmark' && matrix.device == 'cuda' }} uses: Oneflow-Inc/get-oneflow/pytest-benchmark@ci-test-with-cu118 with: collect-path: ${{ env.FLOW_VISION_SRC }}/benchmark container-name: ${{ env.TEST_CONTAINER_NAME }} unknown-threshold: 30 error-threshold: 40 - name: Remove automerge if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }} uses: actions/github-script@v4 with: script: | github.issues.removeLabel({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, name: 'automerge' }) github.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, body: 'CI failed when running job: ${{ matrix.entry }}. 
PR label automerge has been removed' }) - name: Print stacks in all core files timeout-minutes: 45 if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') }} run: | docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/print_stack_in_all_dirs.sh || true - name: Query system status timeout-minutes: 45 if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') }} run: | nvidia-smi || true docker ps || true - name: Remove container timeout-minutes: 45 if: ${{ always() && contains(matrix.runs-on, 'self-hosted') }} run: | docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true - name: Remove manylinux container timeout-minutes: 45 if: ${{ always() && contains(matrix.runs-on, 'self-hosted') }} run: | docker rm -f ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} || true - name: Clean workspace timeout-minutes: 45 if: ${{ always() && contains(matrix.runs-on, 'self-hosted') }} run: | docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf * static_analysis_with_clang_on_diff: name: Static analysis with clang on diff runs-on: ubuntu-22.04 if: github.event.pull_request.draft == false && github.base_ref == 'master' steps: - name: Check out OneFlow uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} fetch-depth: 0 - uses: Oneflow-Inc/get-oneflow/cache-complete@ci-test-with-cu118 name: Save cache if successful id: save-cache timeout-minutes: 5 with: oneflow-src: . entry: static_analysis_with_clang_on_diff digest-type: build mark-as-completed: ${{ github.event.pull_request.head.repo.full_name == github.repository }} - name: Install dependencies if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }} run: | sudo apt-get update sudo apt-get install -y libopenblas-dev nasm python3-pip ninja-build ccache - name: Download OneFlow custom clang-tidy if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }} run: | wget https://github.com/Oneflow-Inc/llvm-project/releases/download/maybe-16.0.0/oneflow-clang-tidy-16 wget https://raw.githubusercontent.com/oneflow-inc/llvm-project/maybe/clang-tools-extra/clang-tidy/tool/clang-tidy-diff.py chmod +x oneflow-clang-tidy-16 clang-tidy-diff.py - name: Cache third party dir uses: actions/cache@v4 if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }} with: path: ~/.ccache key: clang-tidy-diff-third-party-ccache-${{ hashFiles('**/CMakeLists.txt') }}-${{ hashFiles('**/*.cmake') }} restore-keys: | clang-tidy-diff-third-party-ccache-${{ hashFiles('**/CMakeLists.txt') }}- clang-tidy-diff-third-party-ccache- - name: Build third party libs and generate files if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }} run: | export CCACHE_COMPRESS=true export CCACHE_MAXSIZE=500M mkdir build cd build cmake .. -C ../cmake/caches/international/cpu.cmake \ -DCMAKE_BUILD_TYPE=Release \ -DBUILD_TESTING=OFF \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache cmake --build . -j$(nproc) --target oneflow_deps of_protoobj of_functional_obj of_functional_tensor_obj of_op_schema - name: Fetch upstream if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) && github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name }} run: | git remote add upstream https://github.com/Oneflow-Inc/oneflow git fetch upstream - name: Run clang-tidy for modified files # use clang as compiler for correct compiler flags if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }} run: | sudo apt install clang-12 lldb-12 lld-12 libfuse2 cd build rm CMakeCache.txt cmake .. 
-C ../cmake/caches/international/cpu.cmake \ -DCMAKE_C_COMPILER=clang-12 \ -DCMAKE_CXX_COMPILER=clang++-12 \ -DCMAKE_BUILD_TYPE=Release \ -DBUILD_TESTING=OFF \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON cd .. git diff -U0 ${{ github.event.pull_request.base.sha }} | ./clang-tidy-diff.py -clang-tidy-binary ./oneflow-clang-tidy-16 -path build -allow-enabling-alpha-checkers -j $(nproc) -p1 -extra-arg="-Xclang" -extra-arg="-analyzer-config" -extra-arg="-Xclang" -extra-arg="aggressive-binary-operation-simplification=true" -warnings-as-errors="$(cat ./ci/check/clang_tidy_warnings_as_errors_on_diff)" - name: Check error message absence in changed files if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) && contains(github.event.pull_request.labels.*.name, 'need-check-error-message') }} run: | git diff -U0 ${{ github.event.pull_request.base.sha }} | ./clang-tidy-diff.py -clang-tidy-binary ./oneflow-clang-tidy-16 -path build -allow-enabling-alpha-checkers -j $(nproc) -p1 -extra-arg="-Xclang" -extra-arg="-analyzer-config" -extra-arg="-Xclang" -extra-arg="aggressive-binary-operation-simplification=true" -checks=-*,maybe-need-error-msg -warnings-as-errors=* -skip-line-filter - name: Remove automerge if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) && failure() && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }} uses: actions/github-script@v4 with: script: | github.issues.removeLabel({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, name: 'automerge' }) github.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, body: 'Static analysis with clang failed. PR label automerge has been removed' }) ================================================ FILE: .gitignore ================================================ /build /build-* /docs/build/ /docs/build-cn/ /docs/source/generated /cmake-build-* /dist /third_party/ /examples/**/oneflow /benchmark/**/oneflow log/ *plan core.* *.pyc *.ipynb /.vscode /.idea /manylinux* wheelhouse/ wheelhouse* .DS_Store /tmp_wheel /oneflow/python/__export_symbols__.py /oneflow/python/compatibility.py /oneflow/python/framework/sysconfig_gen.py /oneflow/python/test/ops/localhost_script_*.sh .clangd compile_commands.json .cache /oneflow-src.zip /oneflow_temp /distributed-tmp /serving-tmp test_tmp_dir unittest-log-* /oneflow/python /oneflow/compatible_single_client_python /benchmarks /oneflow/python/version.py /data-test /tmp /python/oneflow/test/dataloader/data-test/ /target saved_model /devcontainer-cache op_prof.csv *.lock ================================================ FILE: .lsan-suppressions ================================================ leak:CommandT ================================================ FILE: .mergify.yml ================================================ pull_request_rules: - name: automatic update for PR with label “automerge“ conditions: - "#approved-reviews-by>=2" - -conflict # skip conflicts - -draft # skip draft PRs - label="automerge" actions: update: - name: automatic merge conditions: - "#approved-reviews-by>=2" - -conflict # skip conflicts - -draft # skip draft PRs - label="automerge" - "#commits-behind==0" - -closed actions: merge: method: squash ================================================ FILE: .tsan-suppressions ================================================ # These four group of functions are designed to be thread unsafe, # it's user's responsibility to use them correctly. 
race:ThreadUnsafe race:thread_unsafe race:flying_instruction_cnt race:total_erased_instruction_cnt race:ToShape # glog race:google:: # ~basic_string() in DenseElementsAttrToTensor interferes with # ~~AccessBlobArgCbInstructionPolicy(). Perhaps it's a false # positive. race:~basic_string ================================================ FILE: .ubsan-suppressions ================================================ # llvm vptr:Class.cpp ================================================ FILE: CMakeLists.txt ================================================ # Minimum CMake required set(CMAKE_POLICY_DEFAULT_CMP0135 NEW) cmake_minimum_required(VERSION 3.18.0) set(CMAKE_INSTALL_MESSAGE LAZY CACHE STRING "") set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE BOOL "") option(THIRD_PARTY "Build third party" ON) option(ONEFLOW "Build oneflow" ON) if(NOT THIRD_PARTY AND NOT ONEFLOW) message(FATAL_ERROR "at least one of flags THIRD_PARTY and ONEFLOW should be ON") endif() option(USE_CLANG_FORMAT "" OFF) option(USE_CLANG_TIDY "" OFF) option(BUILD_PYTHON "" ON) option(BUILD_CPP_API "Option to build OneFlow C++ API (beta)" OFF) option(BUILD_RDMA "" OFF) option(BUILD_CUDA "" ON) option(BUILD_TESTING "" OFF) option(BUILD_GIT_VERSION "" ON) option(BUILD_PROFILER "" OFF) option(BUILD_FOR_CI "" OFF) option(WITH_COCOAPI "Option to build with COCO API" ON) option(WITH_ZLIB "" ON) option(WITH_ONEDNN "" ON) option(WITH_MLIR "" OFF) option(WITH_MLIR_CUDA_CODEGEN "" OFF) option(OF_SOFTMAX_USE_FAST_MATH "" ON) option(OF_LAYER_NORM_USE_FAST_MATH "" ON) option(TREAT_WARNINGS_AS_ERRORS "" ON) option(MAYBE_NEED_ERROR_MSG_CHECK "" OFF) option(LITE_USE_ASCEND_NPU "" OFF) # Reference: # https://medium.com/@alasher/colored-c-compiler-output-with-ninja-clang-gcc-10bfe7f2b949 option(OF_FORCE_COLORED_DIAGNOSTICS "Always produce ANSI-colored diagnostics (GNU/Clang only)." ON) set(ONEFLOW_CURRENT_VERSION 0.8.1.dev CACHE STRING "") if(BUILD_FOR_CI) set(ONEFLOW_CURRENT_VERSION ci) endif() set(LLVM_PROVIDER "in-tree" CACHE STRING "in-tree, install") if(NOT WITH_MLIR) set(LLVM_PROVIDER "install" CACHE STRING "in-tree will build LLVM's ALL, not what we want when not building MLIR" FORCE) endif(NOT WITH_MLIR) set(RPC_BACKEND "GRPC,LOCAL" CACHE STRING "") set(THIRD_PARTY_MIRROR "" CACHE STRING "") set(PIP_INDEX_MIRROR "" CACHE STRING "") set(CPU_THREADING_RUNTIMES "TBB;OMP" CACHE STRING "") if(APPLE) set(RPC_BACKEND "LOCAL") set(BUILD_CUDA OFF) set(WITH_COCOAPI OFF) set(WITH_ONEDNN OFF) endif() set(CUDNN_STATIC OFF CACHE BOOL "") project(oneflow C CXX) if(NOT CMAKE_BUILD_TYPE) message(STATUS "No build type selected, default to Release") set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type (default Release)" FORCE) endif() if(NOT CMAKE_BUILD_TYPE MATCHES "^(Debug|Release|RelWithDebInfo|MinSizeRel)$") message( FATAL_ERROR "Expected CMAKE_BUILD_TYPE is Debug, Release, RelWithDebInfo or MinSizeRel, got ${CMAKE_BUILD_TYPE}" ) endif() message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") set(COMPILER_VERSION_ERROR_MSG "At least gcc 9, clang 5 or Apple clang 12 is supported. Current version ${CMAKE_CXX_COMPILER_VERSION}." 
) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 9) message(FATAL_ERROR ${COMPILER_VERSION_ERROR_MSG}) endif() elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 5) message(FATAL_ERROR ${COMPILER_VERSION_ERROR_MSG}) endif() elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 12) message(FATAL_ERROR ${COMPILER_VERSION_ERROR_MSG}) endif() else() message(WARNING "Unknown compiler \"${CMAKE_CXX_COMPILER_ID}\".") endif() set(oneflow_cmake_dir ${PROJECT_SOURCE_DIR}/cmake) get_filename_component(real_src_dir "${CMAKE_SOURCE_DIR}" REALPATH) get_filename_component(real_bin_dir "${CMAKE_BINARY_DIR}" REALPATH) if("${real_src_dir}" STREQUAL "${real_bin_dir}") message(FATAL_ERROR "In-source build not allowed") endif() # Modules list(APPEND CMAKE_MODULE_PATH ${oneflow_cmake_dir}/third_party) list(APPEND CMAKE_MODULE_PATH ${oneflow_cmake_dir}) include(threading) include(util) include(proto2cpp) if(NOT DEFINED USE_CXX11_ABI) check_cxx11_abi(CXX11_ABI_AVAILABLE) set(USE_CXX11_ABI ${CXX11_ABI_AVAILABLE}) elseif(USE_CXX11_ABI) check_cxx11_abi(CXX11_ABI_AVAILABLE) if(NOT CXX11_ABI_AVAILABLE) message(FATAL_ERROR "cxx11 abi is not available for current compiler") endif() endif() message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}") if(WITH_MLIR) add_definitions(-DWITH_MLIR) if(WITH_MLIR_CUDA_CODEGEN) add_definitions(-DWITH_MLIR_CUDA_CODEGEN) endif() endif() if(WITH_COCOAPI) add_definitions(-DWITH_COCOAPI) endif() if(USE_CXX11_ABI) add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) else() add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) endif() if(BUILD_PROFILER) add_definitions(-DOF_ENABLE_PROFILER) endif() if(OF_SOFTMAX_USE_FAST_MATH) add_definitions(-DOF_SOFTMAX_USE_FAST_MATH) endif() if(OF_LAYER_NORM_USE_FAST_MATH) add_definitions(-DOF_LAYER_NORM_USE_FAST_MATH) endif() if(OF_FORCE_COLORED_DIAGNOSTICS) add_compile_options( $<$:$<$:-fdiagnostics-color=always>> $<$:$<$:-fcolor-diagnostics>> $<$:$<$:-fcolor-diagnostics>>) endif() if(RPC_BACKEND MATCHES "GRPC") add_definitions(-DRPC_BACKEND_GRPC) message(STATUS "RPC backend enabled: gRPC") set(SUPPORTED_RPC_BACKEND_FOUND 1) endif() if(WITH_ONEDNN) add_definitions(-DWITH_ONEDNN) endif() add_definitions(-DRPC_BACKEND_LOCAL) message(STATUS "RPC backend enabled: local") enable_testing() set(CMAKE_CXX_STANDARD 17) set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(THIRD_PARTY_DIR "${PROJECT_BINARY_DIR}/third_party_install" CACHE PATH "Where to install third party headers and libs") set(ONEFLOW_PYTHON_DIR "${PROJECT_SOURCE_DIR}/python" CACHE PATH "oneflow python src dir") include(platform) if((ENABLE_ASAN OR ENABLE_UBSAN) AND ENABLE_TSAN) message(FATAL_ERROR "Only ASAN and UBSAN can be enabled at the same time.") endif() if(ENABLE_ASAN) add_compile_options(-fsanitize=address -fno-omit-frame-pointer) add_link_options(-fsanitize=address -fno-omit-frame-pointer) endif() if(ENABLE_UBSAN) add_compile_options(-fsanitize=undefined) add_link_options(-fsanitize=undefined) endif() if(ENABLE_TSAN) add_compile_options(-fsanitize=thread) add_link_options(-fsanitize=thread) endif() if(BUILD_PYTHON) set(ONEFLOW_INCLUDE_DIR "${ONEFLOW_PYTHON_DIR}/oneflow/include") endif(BUILD_PYTHON) set(CUTLASS_URL https://github.com/Oneflow-Inc/cutlass/archive/e6f548d80bfdf1167d66adbbbcfc2ee3394f4777.zip) use_mirror(VARIABLE CUTLASS_URL URL ${CUTLASS_URL}) set(CUTLASS_MD5 425f8cf064ff47c81124e55490135f5c) include(cuda) add_subdirectory(external) include(third_party) 
message(STATUS "CMAKE_CXX_COMPILER_VERSION: " ${CMAKE_CXX_COMPILER_VERSION}) add_custom_target(oneflow_deps ALL DEPENDS prepare_oneflow_third_party) # skip oneflow cmake to avoid errors caused by the absences of python-dev, proto src if(ONEFLOW) include(oneflow) endif() add_subdirectory(ci) ================================================ FILE: LICENSE ================================================ Copyright 2020 The OneFlow Authors. All rights reserved. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # OneFlow OneFlow is a deep learning framework designed to be **user-friendly, scalable and efficient**. With OneFlow, it is easy to: - program a model with [**PyTorch-like API**](https://oneflow.readthedocs.io/en/master/) - scale a model to n-dimensional-parallel execution with the [**Global Tensor**](https://docs.oneflow.org/en/master/cookies/global_tensor.html) - accelerate/deploy a model with the [**Graph Compiler**](https://oneflow.readthedocs.io/en/master/graph.html). [![Simple CI](https://github.com/Oneflow-Inc/oneflow/actions/workflows/simple.yml/badge.svg)](https://github.com/Oneflow-Inc/oneflow/actions/workflows/simple.yml) [![Nightly Docker Image](https://github.com/Oneflow-Inc/docker-images/actions/workflows/oneflow-nightly.yml/badge.svg)](https://github.com/Oneflow-Inc/docker-images/actions/workflows/oneflow-nightly.yml) [![Nightly Release](https://github.com/Oneflow-Inc/oneflow/actions/workflows/release.yml/badge.svg)](https://github.com/Oneflow-Inc/oneflow/actions/workflows/release.yml) [![Documentation](https://readthedocs.org/projects/oneflow/badge/?version=master)](https://oneflow.readthedocs.io/en/master/?badge=master) ## Latest News - Version 1.0.0 is out! - [Full changelog](https://github.com/Oneflow-Inc/oneflow/releases/tag/v1.0.0) ## Publication - [OneFlow: Redesign the Distributed Deep Learning Framework from Scratch](https://arxiv.org/abs/2110.15032) ## System Requirements ### General - Linux - Python 3.7, 3.8, 3.9, 3.10, 3.11 ### CUDA - CUDA arch 60 or above - CUDA Toolkit version 10.0 or above - Nvidia driver version 440.33 or above OneFlow will work on a minimum supported driver, and any driver beyond. For more information, please refer to [CUDA compatibility documentation](https://docs.nvidia.com/deploy/cuda-compatibility/index.html). ## Install ### Preinstall docker image ``` docker pull oneflowinc/oneflow:nightly-cuda11.8 ``` ### Pip Install - (**Highly recommended**) Upgrade pip ``` python3 -m pip install --upgrade pip #--user ``` - To install latest stable release of OneFlow with CUDA support: ```bash python3 -m pip install oneflow ``` - To install nightly release of OneFlow with CPU-only support: ```bash python3 -m pip install --pre oneflow -f https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cpu ``` - To install nightly release of OneFlow with CUDA support: ```bash python3 -m pip install --pre oneflow -f https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cu118 ``` If you are in China, you could run this to have pip download packages from domestic mirror of pypi: ``` python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple ``` For more information on this, please refer to [pypi 镜像使用帮助](https://mirror.tuna.tsinghua.edu.cn/help/pypi/) ### Install from Source
Clone Source Code - #### Option 1: Clone source code from GitHub ```bash git clone https://github.com/Oneflow-Inc/oneflow.git ``` - #### Option 2: Download from Aliyun (only available in China) ```bash curl https://oneflow-public.oss-cn-beijing.aliyuncs.com/oneflow-src.zip -o oneflow-src.zip unzip oneflow-src.zip ```
Build OneFlow - Install dependencies ``` apt install -y libopenblas-dev nasm g++ gcc python3-pip cmake autoconf libtool ``` These dependencies are preinstalled in the official conda environment and docker image; you can use the official conda environment [here](https://github.com/Oneflow-Inc/conda-env) or use the docker image by: ```bash docker pull oneflowinc/manylinux2014_x86_64_cuda11.2 ``` - In the root directory of OneFlow source code, run: ``` mkdir build cd build ``` - Configure the project, inside `build` directory: - If you are in China, configure for CPU-only like this: ``` cmake .. -C ../cmake/caches/cn/cpu.cmake ``` configure for CUDA like this: ``` cmake .. -C ../cmake/caches/cn/cuda.cmake -DCMAKE_CUDA_ARCHITECTURES=80 -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda -DCUDNN_ROOT_DIR=/usr/local/cudnn ``` - If you are not in China, configure for CPU-only like this: ``` cmake .. -C ../cmake/caches/international/cpu.cmake ``` configure for CUDA like this: ``` cmake .. -C ../cmake/caches/international/cuda.cmake -DCMAKE_CUDA_ARCHITECTURES=80 -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda -DCUDNN_ROOT_DIR=/usr/local/cudnn ``` Here the `CMAKE_CUDA_ARCHITECTURES` variable specifies the CUDA architecture to build for, while `CUDA_TOOLKIT_ROOT_DIR` and `CUDNN_ROOT_DIR` specify the root paths of the CUDA Toolkit and cuDNN. - Build the project, inside `build` directory, run: ``` make -j$(nproc) ``` - Add oneflow to your PYTHONPATH, inside `build` directory, run: ``` source source.sh ``` Please note that this change is not permanent. - Simple validation ``` python3 -m oneflow --doctor ```
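If `--doctor` runs cleanly, a slightly stronger sanity check is to exercise the PyTorch-like tensor API from the freshly built package. This is a minimal sketch, assuming the `source source.sh` step above has put the build on your `PYTHONPATH`:

```python
# Minimal post-build smoke test (assumes `source source.sh` was run in build/).
import oneflow as flow

print(flow.__version__)           # should report the version that was just built

x = flow.tensor([1.0, 2.0, 3.0])  # PyTorch-like tensor API
print(x * 2)                      # expect a tensor holding [2., 4., 6.]

if flow.cuda.is_available():      # only meaningful for CUDA builds
    print((x.cuda() * 2).cpu())
```

If the import fails (for example with a missing shared library), see the troubleshooting document referenced below.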
### Troubleshooting Please refer to [troubleshooting](docs/source/troubleshooting.md) for common issues you might encounter when compiling and running OneFlow. ## Getting Started - Please refer to [QUICKSTART](https://docs.oneflow.org/en/master/basics/01_quickstart.html) - For the Chinese version, please refer to [快速上手 (Quickstart)](https://docs.oneflow.org/master/basics/01_quickstart.html) ## Documentation - [API Reference](https://oneflow.readthedocs.io/en/master/) - [Usage & Design Docs](http://docs.oneflow.org/) - [System Design](https://docs.oneflow.org/en/v0.4.0/basics_topics/essentials_of_oneflow.html) ## Model Zoo and Benchmark - [Libai (Toolbox for Parallel Training Large-Scale Transformer Models)](https://github.com/Oneflow-Inc/libai) - [BERT-large](https://libai.readthedocs.io/en/latest/tutorials/get_started/quick_run.html) - [GPT](https://libai.readthedocs.io/en/latest/modules/libai.models.html#id5) - [T5](https://libai.readthedocs.io/en/latest/modules/libai.models.html#id4) - [VisionTransformer](https://libai.readthedocs.io/en/latest/modules/libai.models.html#id1) - [SwinTransformer](https://libai.readthedocs.io/en/latest/modules/libai.models.html#id2) - [FlowVision (Toolbox for Computer Vision Datasets, SOTA Models and Utils)](https://github.com/Oneflow-Inc/vision) - [OneFlow-Models (outdated)](https://github.com/Oneflow-Inc/models) - [ResNet-50](https://github.com/Oneflow-Inc/models/tree/main/Vision/classification/image/resnet50) - [Wide&Deep](https://github.com/Oneflow-Inc/models/tree/main/RecommenderSystems/wide_and_deep) - [OneFlow-Benchmark (outdated)](https://github.com/Oneflow-Inc/OneFlow-Benchmark) ## Communication - [GitHub issues](https://github.com/Oneflow-Inc/oneflow/issues): any install, bug, or feature issues. - [www.oneflow.org](http://www.oneflow.org): brand-related information. - ### Chinese (中文) - QQ group: 331883 - WeChat ID (add as a friend to join the discussion group): OneFlowXZS - [Zhihu (知乎)](https://www.zhihu.com/org/oneflow-17) - ### International - [Discord](https://discord.gg/4kpjGA5bZY) - [Twitter](https://twitter.com/OneFlowNews) - [LinkedIn](https://www.linkedin.com/company/oneflow-inc) - [Medium](https://oneflow2020.medium.com) ## The Team OneFlow was originally developed by [OneFlow Inc](http://www.oneflow.org) and [Zhejiang Lab](http://www.zhejianglab.com/).
## License [Apache License 2.0](LICENSE) ================================================ FILE: ci/CMakeLists.txt ================================================ add_subdirectory(test) ================================================ FILE: ci/build/ensure_img.py ================================================ import os import argparse from pathlib import Path import re import json import subprocess def check_and_download(tag, url): img_dir = os.path.join(os.path.expanduser("~"), "imgs") if not os.path.exists(img_dir): os.makedirs(img_dir) returncode = subprocess.run( f"docker image inspect {tag}", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ).returncode if returncode == 0: print("[OK]", tag) else: basename = os.path.basename(url) dst = os.path.join(img_dir, basename) subprocess.check_call(f"wget -c {url} -O {dst}", shell=True) subprocess.check_call(f"docker load -i {dst}", shell=True) base = os.path.basename(dst) base = os.path.splitext(base)[0] base = os.path.splitext(base)[0] keep_tag = f"ofkeep:{base}" subprocess.check_call(f"docker tag {tag} {keep_tag}", shell=True) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--create_index", action="store_true", required=False, default=False ) args = parser.parse_args() imgs = [ { "tag": "nvidia/cuda:10.0-cudnn7-devel-centos7", "url": "https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda10.0-cudnn7-devel-centos7.tar.gz", }, { "tag": "nvidia/cuda:10.1-cudnn7-devel-centos7", "url": "https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda10.1-cudnn7-devel-centos7.tar.gz", }, { "tag": "nvidia/cuda:10.2-cudnn7-devel-centos7", "url": "https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda10.2-cudnn7-devel-centos7.tar.gz", }, { "tag": "nvidia/cuda:11.0-cudnn8-devel-centos7", "url": "https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda11.0-cudnn8-devel-centos7.tar.gz", }, { "tag": "nvidia/cuda:11.1-cudnn8-devel-centos7", "url": "https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda11.1-cudnn8-devel-centos7.tar.gz", }, ] for img in imgs: check_and_download(img["tag"], img["url"]) ================================================ FILE: ci/build/make.sh ================================================ set -ex src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} tmp_dir=${ONEFLOW_CI_TMP_DIR:-"$HOME/ci-tmp"} extra_oneflow_cmake_args=${ONEFLOW_CI_EXTRA_ONEFLOW_CMAKE_ARGS:-""} package_suffix=${ONEFLOW_CI_PACKAGE_SUFFIX:-""} cuda_version=${ONEFLOW_CI_CUDA_VERSION:-"10.2"} python_version_args=${ONEFLOW_CI_PYTHON_VERSION_ARGS:-"--python3.6"} build_wheel_bash_args=${ONEFLOW_CI_BUILD_WHEEL_BASH_ARGS:-"-l"} mkdir -p $tmp_dir docker_tag=${ONEFLOW_CI_DOCKER_TAG:-"oneflow:ci-manylinux2014-cuda10.2"} docker_proxy_build_args="" docker_proxy_build_args+="--build-arg http_proxy=${ONEFLOW_CI_HTTP_PROXY} --build-arg https_proxy=${ONEFLOW_CI_HTTPS_PROXY}" docker_proxy_run_args="" docker_proxy_run_args+="--env http_proxy=${ONEFLOW_CI_HTTP_PROXY} --env https_proxy=${ONEFLOW_CI_HTTPS_PROXY}" docker_it="" if [[ -t 1 ]]; then docker_it="-it" fi # build manylinux image cd $src_dir docker build -f $src_dir/docker/package/manylinux/Dockerfile \ --build-arg from=nvidia/cuda:${cuda_version}-cudnn7-devel-centos7 \ $docker_proxy_build_args -t $docker_tag . 
cd - # build function function build() { set -x docker run --rm \ -v $tmp_dir:/ci-tmp \ -w $tmp_dir:/ci-tmp busybox rm -rf /ci-tmp/wheelhouse docker run \ $docker_proxy_run_args \ --rm $docker_it \ -v $src_dir:/oneflow-src \ -v $tmp_dir:/ci-tmp \ -w /ci-tmp \ "$docker_tag" \ bash ${build_wheel_bash_args} /oneflow-src/docker/package/manylinux/build_wheel.sh \ ${python_version_args} \ --house-dir /ci-tmp/wheelhouse \ --package-name oneflow${package_suffix} \ $extra_oneflow_cmake_args } set +e # reuse cache build # clean cache and retry cached_build_ret=$? set -e if [ $cached_build_ret -ne 0 ] && [[ ! -t 1 ]]; then echo "retry after cleaning build dir" docker run --rm -v $tmp_dir:/ci-tmp busybox sh -c "rm -rf /ci-tmp/*" build fi ================================================ FILE: ci/check/clang_tidy_warnings_as_errors_on_diff ================================================ *,-maybe-glog-fatal,-clang-analyzer-alpha.*,-clang-analyzer-cplusplus.NewDelete,-clang-diagnostic-* ================================================ FILE: ci/check/lintutils.py ================================================ # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import multiprocessing as mp import os from fnmatch import fnmatch from subprocess import Popen def chunk(seq, n): """ divide a sequence into equal sized chunks (the last chunk may be smaller, but won't be empty) """ chunks = [] some = [] for element in seq: if len(some) == n: chunks.append(some) some = [] some.append(element) if len(some) > 0: chunks.append(some) return chunks def dechunk(chunks): "flatten chunks into a single list" seq = [] for chunk in chunks: seq.extend(chunk) return seq def run_parallel(cmds, **kwargs): """ Run each of cmds (with shared **kwargs) using subprocess.Popen then wait for all of them to complete. 
Runs batches of multiprocessing.cpu_count() * 2 from cmds returns a list of tuples containing each process' returncode, stdout, stderr """ complete = [] for cmds_batch in chunk(cmds, mp.cpu_count() * 2): procs_batch = [Popen(cmd, **kwargs) for cmd in cmds_batch] for proc in procs_batch: stdout, stderr = proc.communicate() complete.append((proc.returncode, stdout, stderr)) return complete _source_extensions = """ .h .cc .cpp .cu .cuh """.split() def get_sources(source_dir, exclude_globs=[]): sources = [] for directory, subdirs, basenames in os.walk(source_dir): for path in [os.path.join(directory, basename) for basename in basenames]: # filter out non-source files if os.path.splitext(path)[1] not in _source_extensions: continue path = os.path.abspath(path) # filter out files that match the globs in the globs file if any([fnmatch(path, glob) for glob in exclude_globs]): continue sources.append(path) return sources def stdout_pathcolonline(completed_process, filenames): """ given a completed process which may have reported some files as problematic by printing the path name followed by ':' then a line number, examine stdout and return the set of actually reported file names """ returncode, stdout, stderr = completed_process bfilenames = set() for filename in filenames: bfilenames.add(filename.encode("utf-8") + b":") problem_files = set() for line in stdout.splitlines(): for filename in bfilenames: if line.startswith(filename): problem_files.add(filename.decode("utf-8")) bfilenames.remove(filename) break return problem_files, stdout ================================================ FILE: ci/check/run_clang_format.py ================================================ #!/usr/bin/env python3 # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
import asyncio import argparse import pathlib import multiprocessing import subprocess import os import platform def split_and_print(prefix, text): lines = text.decode().splitlines(keepends=True) prefixed = "" for l in lines: prefixed += f"{prefix} {l.strip()}" if l.strip(): print(prefixed, flush=True) async def handle_stream(stream, cb): while True: line = await stream.readline() if line: cb(line) else: break async def run_command(cmd=None, dry=False, name=None): if dry: print(f"[dry] {cmd}") return 0 process = await asyncio.create_subprocess_shell( cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) l = lambda x: split_and_print(f"[{name}]" if name else "", x) # l = lambda x: x await asyncio.gather( handle_stream(process.stdout, l), handle_stream(process.stderr, l), ) await process.wait() return process.returncode def chunks(lst, n): """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): yield lst[i : i + n] def check_version(bin): try: out = subprocess.check_output(["bash", "-c", f"{bin} --version"]).decode() print(out) return "version 11.0.0" in out except: return False def download(dry=False): if platform.system() != "Linux": raise ValueError("Please install clang format 11.0.0") url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/bin/clang-format/linux-x86/clang-format-11" if os.getenv("CI"): url = "https://github.com/Oneflow-Inc/oneflow-fmt/raw/master/clang-format/linux-x86/clang-format-11" dst_dir = ".cache/bin" dst = f"{dst_dir}/clang-format-11" if dry: if os.path.isfile(dst): return dst else: None else: assert subprocess.call(f"mkdir -p {dst_dir}", shell=True) == 0 assert subprocess.call(f"curl -L {url} -o {dst}", shell=True) == 0 assert subprocess.call(f"chmod +x {dst}", shell=True) == 0 return dst if __name__ == "__main__": parser = argparse.ArgumentParser( description="Runs clang-format on all of the source " "files. 
If --fix is specified enforce format by " "modifying in place, otherwise compare the output " "with the existing file and output any necessary " "changes as a patch in unified diff format" ) parser.add_argument( "--clang_format_binary", required=False, help="Path to the clang-format binary.", default="clang-format", ) parser.add_argument( "--source_dir", required=True, help="Root directory of the source code" ) parser.add_argument( "--fix", default=False, action="store_true", help="If specified, will re-format the source " "code instead of comparing the re-formatted " "output, defaults to %(default)s", ) parser.add_argument( "--quiet", default=False, action="store_true", help="If specified, only print errors", ) args = parser.parse_args() exts = [".h", ".cc", ".cpp", ".cu", ".cuh"] files = filter( lambda p: p.suffix in exts, pathlib.Path(args.source_dir).rglob("*"), ) loop = asyncio.get_event_loop() files = [str(f) for f in files] clang_fmt_args = "-dry-run --Werror" if args.fix: clang_fmt_args = "-i" results = [] if check_version(args.clang_format_binary) == False: downloaded = download(dry=True) if downloaded: assert check_version(downloaded) args.clang_format_binary = downloaded else: args.clang_format_binary = download() assert check_version(args.clang_format_binary) for chunk in chunks(files, multiprocessing.cpu_count() * 2): promises = [ run_command(f"{args.clang_format_binary} {clang_fmt_args} {f}") for f in chunk ] chunk_results = loop.run_until_complete(asyncio.gather(*promises)) results.extend(chunk_results) print(len(results), "files checked") assert len(results) == len(files) for (r, f) in zip(results, files): if r != 0: print("[fail]", f) assert sum(results) == 0 ================================================ FILE: ci/check/run_clang_tidy.py ================================================ #!/usr/bin/env python3 # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
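# Example invocations (assuming an existing CMake build directory named `build`
# that contains a compile_commands.json, and a local `master` branch to diff
# against, since the pipeline below runs `git diff -U0 master`):
#
#   python3 ci/check/run_clang_tidy.py --build_dir build
#   python3 ci/check/run_clang_tidy.py --build_dir build --check-error-msg
#
# On first use the script downloads a pinned clang-tidy-15 AppImage together with
# clang-tidy-diff.py and only diagnoses the lines touched by the diff.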
import asyncio import argparse import subprocess import os from typing import List, Optional from pathlib import Path def split_and_print(prefix, text): lines = text.decode().splitlines(keepends=True) prefixed = "" for l in lines: prefixed += f"{prefix} {l.strip()}" if l.strip(): print(prefixed, flush=True) async def handle_stream(stream, cb): while True: line = await stream.readline() if line: cb(line) else: break async def run_command(cmd=None, dry=False, name=None): if dry: print(f"[dry] {cmd}") return 0 process = await asyncio.create_subprocess_shell( cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) l = lambda x: split_and_print(f"[{name}]" if name else "", x) await asyncio.gather( handle_stream(process.stdout, l), handle_stream(process.stderr, l), ) await process.wait() return process.returncode def download(build_dir, dry=False) -> Optional[List[str]]: urls = [ "https://github.com/Oneflow-Inc/llvm-project/releases/download/update-err-msg-checker/clang-tidy-15.AppImage" if os.getenv("CI") else "https://oneflow-static.oss-cn-beijing.aliyuncs.com/bin/clang-tidy/linux-x86_64/clang-tidy-15.AppImage", "https://raw.githubusercontent.com/oneflow-inc/llvm-project/maybe/clang-tools-extra/clang-tidy/tool/clang-tidy-diff.py", ] dst_dir = f"{build_dir}/cache/bin" dst = [f"{dst_dir}/clang-tidy", f"{dst_dir}/clang-tidy-diff.py"] if dry: if os.path.isfile(dst[0]) and os.path.isfile(dst[1]): return dst else: None else: assert subprocess.call(f"mkdir -p {dst_dir}", shell=True) == 0 for i, _dst in enumerate(dst): assert subprocess.call(f"curl -L {urls[i]} -o {_dst}", shell=True) == 0 assert subprocess.call(f"chmod +x {_dst}", shell=True) == 0 return dst if __name__ == "__main__": parser = argparse.ArgumentParser( description="Runs clang-tidy on all of the source files." ) parser.add_argument( "--build_dir", required=True, ) parser.add_argument( "--check-error-msg", action="store_true", default=False, ) args = parser.parse_args() loop = asyncio.get_event_loop() downloaded = download(args.build_dir, dry=True) if downloaded is None: downloaded = download(args.build_dir) assert downloaded is not None warnings_as_errors = ( (Path(__file__).parent / "clang_tidy_warnings_as_errors_on_diff") .read_text() .strip() ) cmd = f"git diff -U0 master | {downloaded[1]} -clang-tidy-binary {downloaded[0]} -path {args.build_dir} -j $(nproc) -p1 -allow-enabling-alpha-checkers -extra-arg=-Xclang -extra-arg=-analyzer-config -extra-arg=-Xclang -extra-arg=aggressive-binary-operation-simplification=true" if args.check_error_msg: command = f" cd .. && {cmd} -warnings-as-errors='{warnings_as_errors}' && {cmd} -checks=-*,maybe-need-error-msg -warnings-as-errors=* -skip-line-filter" else: command = f"cd .. && {cmd} -warnings-as-errors='{warnings_as_errors}'" ret_code = loop.run_until_complete(run_command(command)) exit(ret_code) ================================================ FILE: ci/check/run_cmake_format.py ================================================ from subprocess import call from argparse import ArgumentParser from glob import glob from pathlib import Path from multiprocessing.pool import ThreadPool from multiprocessing import cpu_count if __name__ == "__main__": parser = ArgumentParser( description="Runs cmake-format on all of the cmake source files." 
) parser.add_argument( "--bin", default="cmake-format", help="Path of cmake-format binary" ) parser.add_argument( "--fix", default=False, action="store_true", help="Format all sources in place" ) parser.add_argument( "--source_dir", default=".", help="Root directory of the source code" ) parser.add_argument( "-j", "--jobs", type=int, default=cpu_count(), help="Specifies the number of jobs (commands) to run simultaneously", ) args = parser.parse_args() patterns = [ "cmake/**/*.cmake", "oneflow/**/*.cmake", "oneflow/**/CMakeLists.txt", "tools/**/*.cmake", "tools/**/CMakeLists.txt", "CMakeLists.txt", ] files = [] for pattern in patterns: files.extend(glob(str(Path(args.source_dir) / pattern), recursive=True)) def gen_cmd(file): cmd = [args.bin, file] cmd.append("-i" if args.fix else "--check") return cmd tp = ThreadPool(args.jobs) res = tp.map_async(call, [gen_cmd(file) for file in files]) tp.close() tp.join() count = sum(map(lambda x: 0 if x == 0 else 1, res.get())) total = len(files) if args.fix: print(f"cmake-format -i done. {total} total") else: print(f"cmake-format --check done. {count} failed / {total} total") exit(0 if count == 0 else 1) ================================================ FILE: ci/check/run_license_format.py ================================================ import argparse import os import glob from multiprocessing import Pool LICENSE_TXT = """Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ CPP_TXT = "/*\n{}*/\n".format(LICENSE_TXT) PY_TXT = '"""\n{}"""\n'.format(LICENSE_TXT) def get_txt(path: str): if path.endswith((".cpp", ".h", ".hpp", ".cu", ".cuh")): return CPP_TXT elif path.endswith((".py")): return PY_TXT else: return None def check_file(path): with open(path, "r", encoding="utf-8") as f: content = f.read() txt = get_txt(path) if ( "import doctest" in content and "raise_on_error=True" not in content and "doctest.DebugRunner" not in content ): return ("please add 'doctest.testmod(raise_on_error=True)'", content) elif content.count("The OneFlow Authors. 
All rights reserved.") > 1: return ("license_duplicated", content) elif content.startswith(txt) or (not content): return ("ok", content) elif content.startswith(txt) == False: return ("license_absent", content) def format_file(path): txt = get_txt(path) with open(path, "r", encoding="utf-8") as r: content = r.read() format_status, content = check_file(path) if format_status == "ok": return True elif format_status == "license_absent": with open(path, "w") as w: new_content = txt + content w.write(new_content) return False else: raise ValueError(f"{format_status} {path}") def do_check(x): format_status, _ = check_file(x) return (x, format_status) def do_format(x): return (x, format_file(x)) def glob_files(path: str = None, excludes=None): files = [] for ext in ("**/*.cpp", "**/*.h", "**/*.hpp", "**/*.cu", "**/*.cuh", "**/*.py"): joined = os.path.join(path, ext) files.extend(glob.glob(joined, recursive=True)) files = [ f for f in files if "version.py" not in f and all([not e in f for e in excludes]) ] print("[files]", len(files)) return files if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-i", "--root_path", type=str, required=True) parser.add_argument( "-v", "--verbose", default=False, action="store_true", required=False ) parser.add_argument("--silent", default=False, action="store_true", required=False) parser.add_argument( "-c", "--check", default=False, action="store_true", required=False ) parser.add_argument( "-f", "--fix", default=False, action="store_true", required=False ) parser.add_argument("--exclude", action="append", default=[]) args = parser.parse_args() files = glob_files(args.root_path, excludes=args.exclude) assert args.check != args.fix with Pool(10) as p: if args.check: any_absence = False for (p, format_status) in p.map(do_check, files): if format_status != "ok": print(f"{format_status}:", p) any_absence = True if any_absence: exit(1) if args.fix: for (p, format_result) in p.map(do_format, files): if format_result == True: if args.verbose: print("license already added:", p) else: if args.silent == False: print("license just added:", p) ================================================ FILE: ci/check/run_py_format.py ================================================ import argparse import sys import platform from subprocess import Popen import os if __name__ == "__main__": major = platform.sys.version_info.major minor = platform.sys.version_info.minor if major == 3 and minor < 6: print("WARNING: python >= 3.6 required, python source format won't run") exit(0) parser = argparse.ArgumentParser( description="Runs py-format on all of the source files." "If --fix is specified enforce format by modifying in place." ) parser.add_argument( "--source_dir", required=True, help="Root directory of the source code" ) parser.add_argument( "--fix", default=False, action="store_true", help="If specified, will re-format the source", ) arguments = parser.parse_args() os.chdir(arguments.source_dir) version_cmd = sys.executable + " -m {} --version | grep {} > /dev/null" BLACK_VER = "19.10b0" if os.system(version_cmd.format("black", BLACK_VER)): print( f"Please install black {BLACK_VER}. For instance, run 'python3 -m pip install black=={BLACK_VER} --user'" ) sys.exit(1) cmd_line = sys.executable + " -m black " + "." 
    if arguments.fix == False:
        cmd_line += " --check"
    if os.system(cmd_line):
        sys.exit(1)


================================================
FILE: ci/clang/build-llvm.sh
================================================
set -ex
export PATH=/usr/lib/llvm-15/bin:/usr/lib64/ccache:/root/.local/bin:$PATH
# clean python dir
cd ${ONEFLOW_CI_SRC_DIR}
${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt
cd python
git config --global --add safe.directory ${ONEFLOW_CI_SRC_DIR}
git clean -nXd -e \!dist -e \!dist/**
git clean -fXd -e \!dist -e \!dist/**
# cmake config
mkdir -p ${ONEFLOW_CI_BUILD_DIR}
cd ${ONEFLOW_CI_BUILD_DIR}
find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt
find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete
if [ ! -f "$ONEFLOW_CI_CMAKE_INIT_CACHE" ]; then
    echo "$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist."
    exit 1
fi
cmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE}
# cmake build
cd ${ONEFLOW_CI_BUILD_DIR}
cmake --build . -j $(nproc)
# build pip
cd ${ONEFLOW_CI_SRC_DIR}
cd python
${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel


================================================
FILE: ci/conda/build-clang.sh
================================================
set -ex
conda activate oneflow-dev-clang10-v2
mkdir -p build
cd build
cmake .. -C ../cmake/caches/cn/fast/cpu-clang.cmake
cmake --build . -j $(nproc)
cd -
cd python
python setup.py bdist_wheel
echo "wheelhouse_dir=$PWD/dist" >> $GITHUB_ENV


================================================
FILE: ci/conda/tuna.condarc
================================================
channels:
  - defaults
show_channel_urls: true
default_channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
custom_channels:
  conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
  msys2: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
  bioconda: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
  menpo: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
  pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
  simpleitk: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud


================================================
FILE: ci/fixed-dev-requirements.txt
================================================
numpy==1.26.4 ; python_version >= "3.12"
numpy==1.22.1 ; python_version >= "3.10" and python_version < "3.12"
numpy==1.21.6 ; python_version >= "3.7" and python_version < "3.10"


================================================
FILE: ci/manylinux/build-gcc7-xla.sh
================================================
source scl_source enable devtoolset-7
set -ex
ONEFLOW_CI_BUILD_PARALLEL=${ONEFLOW_CI_BUILD_PARALLEL:-$(nproc)}
gcc --version
ld --version
# clean python dir
cd ${ONEFLOW_CI_SRC_DIR}
${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt
cd python
git clean -nXd -e \!dist -e \!dist/**
git clean -fXd -e \!dist -e \!dist/**
# cmake config
mkdir -p ${ONEFLOW_CI_BUILD_DIR}
cd ${ONEFLOW_CI_BUILD_DIR}
find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt
find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete
if [ ! -f "$ONEFLOW_CI_CMAKE_INIT_CACHE" ]; then
    echo "$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist."
exit 1 fi export PATH="${PATH}:$(dirname ${ONEFLOW_CI_PYTHON_EXE})" export PYTHON_BIN_PATH=${ONEFLOW_CI_PYTHON_EXE} cmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE} # cmake build cd ${ONEFLOW_CI_BUILD_DIR} cmake --build . --parallel ${ONEFLOW_CI_BUILD_PARALLEL} # build pip cd ${ONEFLOW_CI_SRC_DIR} cd python ${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel ================================================ FILE: ci/manylinux/build-gcc9.sh ================================================ source scl_source enable devtoolset-9 set -ex ONEFLOW_CI_BUILD_PARALLEL=${ONEFLOW_CI_BUILD_PARALLEL:-$(nproc)} gcc --version ld --version # clean python dir cd ${ONEFLOW_CI_SRC_DIR} ${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt ${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user auditwheel setuptools wheel cd python function clean_artifacts { git config --global --add safe.directory ${ONEFLOW_CI_SRC_DIR} git clean -nXd -e \!dist -e \!dist/** git clean -fXd -e \!dist -e \!dist/** } clean_artifacts # cmake config mkdir -p ${ONEFLOW_CI_BUILD_DIR} cd ${ONEFLOW_CI_BUILD_DIR} find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete if [ ! -f "$ONEFLOW_CI_CMAKE_INIT_CACHE" ]; then echo "$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist." exit 1 fi export PATH="${PATH}:$(dirname ${ONEFLOW_CI_PYTHON_EXE})" export PYTHON_BIN_PATH=${ONEFLOW_CI_PYTHON_EXE} cmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE} # cmake build cd ${ONEFLOW_CI_BUILD_DIR} cmake --build . --parallel ${ONEFLOW_CI_BUILD_PARALLEL} if [ ! -z "$ONEFLOW_CI_BUILD_RUN_LIT" ]; then ${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user flowvision==0.1.0 export PATH=$PATH:$(dirname $ONEFLOW_CI_PYTHON_EXE) cmake --build . -t c1 fi # build pip cd ${ONEFLOW_CI_SRC_DIR} cd python ${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel ================================================ FILE: ci/manylinux/build.sh ================================================ set -ex ONEFLOW_CI_BUILD_PARALLEL=${ONEFLOW_CI_BUILD_PARALLEL:-$(nproc)} gcc --version ld --version # clean python dir cd ${ONEFLOW_CI_SRC_DIR} ${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt ${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user auditwheel setuptools wheel cd python function clean_artifacts { git config --global --add safe.directory ${ONEFLOW_CI_SRC_DIR} git clean -nXd -e \!dist -e \!dist/** git clean -fXd -e \!dist -e \!dist/** } clean_artifacts # cmake config mkdir -p ${ONEFLOW_CI_BUILD_DIR} cd ${ONEFLOW_CI_BUILD_DIR} find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete if [ ! -f "$ONEFLOW_CI_CMAKE_INIT_CACHE" ]; then echo "$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist." exit 1 fi cmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE} # cmake build cd ${ONEFLOW_CI_BUILD_DIR} cmake --build . --parallel ${ONEFLOW_CI_BUILD_PARALLEL} if [ ! -z "$ONEFLOW_CI_BUILD_RUN_LIT" ]; then ${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user flowvision==0.1.0 export PATH=$PATH:$(dirname $ONEFLOW_CI_PYTHON_EXE) cmake --build . 
-t c1 fi # build pip cd ${ONEFLOW_CI_SRC_DIR} cd python ${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel ================================================ FILE: ci/requirements.txt ================================================ pycocotools opencv-python==4.3.0.38; sys_platform == 'darwin' opencv-python==4.2.0.34; sys_platform != 'darwin' scipy pillow tensorflow-addons==0.13.0 tensorflow==2.5.0 ================================================ FILE: ci/reset_submodule.sh ================================================ set -x set -e git reset --hard git submodule deinit -f . rm -rf .git/modules/* ================================================ FILE: ci/setup_submodule.py ================================================ import configparser import argparse import os parser = argparse.ArgumentParser() parser.add_argument("-s", "--oneflow_src_local_path", type=str, required=False) parser.add_argument("-r", "--oneflow_src_remote_url", type=str, required=False) args = parser.parse_args() assert ( args.oneflow_src_local_path or args.oneflow_src_remote_url ), "require one of oneflow_src_local_path or oneflow_src_remote_url" config = configparser.ConfigParser() config.read(".gitmodules") for s in config.sections(): path = config[s]["path"] if args.oneflow_src_local_path: src_path = os.path.join(args.oneflow_src_local_path, path) assert os.path.exists("{}/.git".format(src_path)), src_path config[s]["url"] = "file://{}".format(src_path) else: src_path = os.path.join(args.oneflow_src_remote_url, path) config[s]["url"] = src_path with open(".gitmodules", "w") as configfile: config.write(configfile) ================================================ FILE: ci/setup_submodule.sh ================================================ set -x set -e src_dir=${ONEFLOW_CI_SRC_DIR:-"$HOME/oneflow"} python3 ci/setup_submodule.py --oneflow_src_local_path=$src_dir git submodule sync git submodule update --init --recursive ================================================ FILE: ci/test/1node_benchmark_test.sh ================================================ set -xe rm -rf /benchmarks cp -r python/oneflow/compatible/single_client/benchmarks /benchmarks cd /benchmarks python3 cnn_benchmark/of_cnn_benchmarks.py \ --gpu_num_per_node=1 \ --model="vgg16" \ --batch_size_per_device=8 \ --iter_num=5 \ --learning_rate=0.01 \ --optimizer="sgd" \ --loss_print_every_n_iter=1 \ --data_dir="/dataset/imagenet_227/train/32" python3 cnn_benchmark/of_cnn_benchmarks.py \ --gpu_num_per_node=1 \ --model="alexnet" \ --batch_size_per_device=8 \ --iter_num=5 \ --learning_rate=0.01 \ --optimizer="sgd" \ --loss_print_every_n_iter=1 \ --data_dir="/dataset/imagenet_227/train/32" python3 cnn_benchmark/of_cnn_benchmarks.py \ --gpu_num_per_node=1 \ --model="resnet50" \ --batch_size_per_device=8 \ --iter_num=5 \ --gpu_image_decoder=True \ --learning_rate=0.01 \ --optimizer="sgd" \ --loss_print_every_n_iter=1 \ --data_dir="/dataset/imagenet_227/train/32" python3 cnn_benchmark/of_cnn_benchmarks.py \ --gpu_num_per_node=1 \ --model="resnet50" \ --batch_size_per_device=8 \ --iter_num=5 \ --learning_rate=0.01 \ --optimizer="sgd" \ --loss_print_every_n_iter=1 python3 bert_benchmark/run_pretraining.py \ --gpu_num_per_node=1 \ --node_num=1 \ --learning_rate=1e-4 \ --weight_decay_rate=0.01 \ --batch_size_per_device=24 \ --iter_num=5 \ --loss_print_every_n_iter=1 \ --data_dir="/dataset/bert/bert_seq_len_128_repeat1024" \ --data_part_num=1 \ --seq_length=128 \ --max_predictions_per_seq=20 \ --num_hidden_layers=12 \ --num_attention_heads=12 \ --max_position_embeddings=512 
\ --type_vocab_size=2 \ --vocab_size=30522 \ --attention_probs_dropout_prob=0.1 \ --hidden_dropout_prob=0.1 \ --hidden_size_per_head=64 ================================================ FILE: ci/test/1node_benchmark_test_fp16.sh ================================================ set -ex rm -rf /benchmarks cp -r python/oneflow/compatible/single_client/benchmarks /benchmarks cd /benchmarks python3 cnn_benchmark/of_cnn_benchmarks.py \ --gpu_num_per_node=1 \ --model="vgg16" \ --batch_size_per_device=8 \ --iter_num=5 \ --learning_rate=0.01 \ --optimizer="sgd" \ --loss_print_every_n_iter=1 \ --data_dir="/dataset/imagenet_227/train/32" \ --enable_auto_mixed_precision=True python3 cnn_benchmark/of_cnn_benchmarks.py \ --gpu_num_per_node=1 \ --model="alexnet" \ --batch_size_per_device=8 \ --iter_num=5 \ --learning_rate=0.01 \ --optimizer="sgd" \ --loss_print_every_n_iter=1 \ --data_dir="/dataset/imagenet_227/train/32" \ --enable_auto_mixed_precision=True python3 cnn_benchmark/of_cnn_benchmarks.py \ --gpu_num_per_node=1 \ --model="resnet50" \ --batch_size_per_device=8 \ --iter_num=5 \ --learning_rate=0.01 \ --optimizer="sgd" \ --loss_print_every_n_iter=1 \ --data_dir="/dataset/imagenet_227/train/32" \ --enable_auto_mixed_precision=True python3 bert_benchmark/run_pretraining.py \ --gpu_num_per_node=1 \ --node_num=1 \ --learning_rate=1e-4 \ --weight_decay_rate=0.01 \ --batch_size_per_device=24 \ --iter_num=5 \ --loss_print_every_n_iter=1 \ --data_dir="/dataset/bert/bert_seq_len_128_repeat1024" \ --data_part_num=1 \ --seq_length=128 \ --max_predictions_per_seq=20 \ --num_hidden_layers=12 \ --num_attention_heads=12 \ --max_position_embeddings=512 \ --type_vocab_size=2 \ --vocab_size=30522 \ --attention_probs_dropout_prob=0.1 \ --hidden_dropout_prob=0.1 \ --hidden_size_per_head=64 \ --enable_auto_mixed_precision=True ================================================ FILE: ci/test/1node_custom_op_test.sh ================================================ #!/bin/bash set -xe src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} test_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-"./test_tmp_dir"} rm -rf $test_tmp_dir mkdir -p $test_tmp_dir cp -r $src_dir/python/oneflow/compatible/single_client/test/custom_ops $test_tmp_dir cd $test_tmp_dir export ONEFLOW_TEST_DEVICE_NUM=1 python3 -m unittest discover ./custom_ops --failfast --verbose ================================================ FILE: ci/test/1node_model_eager_test.sh ================================================ #!/bin/bash set -xe cp -r python/oneflow/test /test_dir cd /test_dir python3 models/eager_1node_test.py ================================================ FILE: ci/test/1node_model_test.sh ================================================ #!/bin/bash set -xe cp -r python/oneflow/compatible/single_client/test /test_dir cd /test_dir python3 models/1node_test.py ================================================ FILE: ci/test/1node_op_test.sh ================================================ #!/bin/bash set -xe export TF_CPP_MIN_LOG_LEVEL=3 export PYTHONUNBUFFERED=1 src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} test_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-"./test_tmp_dir"} rm -rf $test_tmp_dir mkdir -p $test_tmp_dir cp -r $src_dir/python/oneflow/compatible/single_client/test $test_tmp_dir cd $test_tmp_dir python3 -m oneflow --doctor gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) for CHUNK in 1 do export ONEFLOW_TEST_DEVICE_NUM=${CHUNK} python3 $src_dir/ci/test/parallel_run.py \ --gpu_num="${gpu_num}" \ --dir=test/ops \ --timeout=1 \ --verbose \ --chunk=${CHUNK} done if [ -z 
"$ONEFLOW_TEST_ENABLE_EAGER" ] then export ONEFLOW_TEST_DEVICE_NUM=2 python3 -m unittest discover test/ops --failfast --verbose export ONEFLOW_TEST_DEVICE_NUM=4 python3 -m unittest discover test/ops --failfast --verbose else echo "deadlock unsolved, skipping multi-card eager" fi ================================================ FILE: ci/test/2node_op_test.sh ================================================ #!/bin/bash set -xe export PYTHONUNBUFFERED=1 src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} test_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-"./test_tmp_dir"} rm -rf $test_tmp_dir mkdir -p $test_tmp_dir chmod -R o+w $test_tmp_dir cp -r $src_dir/python/oneflow/compatible/single_client/test $test_tmp_dir cd $test_tmp_dir ONEFLOW_TEST_DEVICE_NUM=1 python3 test/ops/test_assign.py --failfast --verbose ONEFLOW_TEST_DEVICE_NUM=1 python3 test/ops/test_two_node_boxing.py --failfast --verbose for device_num in 1 2 4 do ONEFLOW_TEST_ENABLE_INIT_BY_HOST_LIST=1 ONEFLOW_TEST_DEVICE_NUM=$device_num python3 -m unittest discover test/ops --failfast --verbose # use a invalid ibverbs lib to test if falling back to epoll works ONEFLOW_TEST_ENABLE_INIT_BY_HOST_LIST=1 ONEFLOW_TEST_DEVICE_NUM=$device_num ONEFLOW_LIBIBVERBS_PATH=invalid_lib python3 -m unittest discover test/ops --failfast --verbose done ================================================ FILE: ci/test/2node_op_test_multi_client.sh ================================================ #!/bin/bash set -xeu export PYTHONUNBUFFERED=1 src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} ONEFLOW_CI_DEVICE_NUMS=${ONEFLOW_CI_DEVICE_NUMS:-"1 2 4"} for device_num in ${ONEFLOW_CI_DEVICE_NUMS} do export ONEFLOW_TEST_NODE_NUM=2 export ONEFLOW_TEST_DEVICE_NUM=$device_num time python3 ${src_dir}/ci/test/multi_launch.py \ --files "${ONEFLOW_TEST_DIR}/**/test_*.py" \ -n 4 \ --group_size $device_num \ --device_num 4 \ --verbose \ --auto_cuda_visible_devices \ -m oneflow.distributed.launch \ --nproc_per_node $device_num --nnodes=2 --node_rank=$NODE_RANK --master_addr $_MASTER_ADDR \ -m pytest --max-worker-restart=0 -x --durations=50 --capture=sys -p no:cacheprovider -p no:randomly --ignore=log done ================================================ FILE: ci/test/CMakeLists.txt ================================================ set(PYTHON_EXECUTABLE python3 CACHE STRING "python3 exe to run test, usually is the python3 installation oneflow is linked to") set(ONEFLOW_SRC_DIR ${CMAKE_SOURCE_DIR} CACHE STRING "source dir of oneflow") set(IS_DEV ON CACHE BOOL "") set(CTEST_RESOURCE_SPEC_FILE "${CMAKE_CURRENT_SOURCE_DIR}/resource-spec/2x-rtx-2080.json" CACHE STRING "") # CTEST_OUTPUT_ON_FAILURE=1 CTEST_PARALLEL_LEVEL=20 ninja test file(GLOB_RECURSE PYTHON_TEST_FILES LIST_DIRECTORIES false RELATIVE ${ONEFLOW_SRC_DIR} "${ONEFLOW_SRC_DIR}/python/oneflow/test_*.py") foreach(PYTHON_TEST_FILE ${PYTHON_TEST_FILES}) set(TEST_NAME ${PYTHON_TEST_FILE}) add_test(NAME ${TEST_NAME} COMMAND ${PYTHON_EXECUTABLE} ${ONEFLOW_SRC_DIR}/${PYTHON_TEST_FILE} --failfast --verbose ) set_tests_properties(${TEST_NAME} PROPERTIES ENVIRONMENT "$<$>:ONEFLOW_TEST_CPU_ONLY=1>;$<$:PYTHONPATH=${ONEFLOW_SRC_DIR}/python:$ENV{PYTHONPATH}>" RESOURCE_GROUPS "vram:2000" ) endforeach() ================================================ FILE: ci/test/build_docs.sh ================================================ set -ex src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} test_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-"$PWD/build-docs"} rm -rf $test_tmp_dir cp -r docs ${test_tmp_dir} cd ${test_tmp_dir} make html SPHINXOPTS="-W --keep-going" ================================================ 
FILE: ci/test/distributed_run.py ================================================ from multiprocessing.connection import Listener import os import subprocess import socket import tempfile from contextlib import closing import argparse import uuid import getpass import atexit import pathlib import asyncio import glob from datetime import date from pathlib import Path HARD_CODED_AFFILIATIONS = { "192.168.1.11": ["192.168.1.12",], "192.168.1.12": ["192.168.1.11",], "192.168.1.13": ["192.168.1.11",], "192.168.1.15": ["192.168.1.16",], "192.168.1.16": ["192.168.1.15",], } def is_img_existing(tag): returncode = subprocess.run( "docker image inspect {}".format(tag), shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ).returncode if returncode == 0: print("[OK]", tag) return True else: return False def get_affiliations(host): # TODO(tsai): Implement a HTTP endpoint to retrieve affiliations if host in HARD_CODED_AFFILIATIONS: return HARD_CODED_AFFILIATIONS[host] else: return None def resolve_hostname_hardcoded(host: str): if host.startswith("oneflow"): number = host.split("-")[-1] return f"192.168.1.{number}" else: return host def find_free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("localhost", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] async def spawn_shell(cmd: str = None): p = await asyncio.create_subprocess_shell(cmd,) await p.wait() assert p.returncode == 0, cmd async def spawn_shell_ignoring_failure(cmd: str = None): p = await asyncio.create_subprocess_shell(cmd,) await p.wait() async def build_docker_img(remote_host=None, workspace_dir=None): if remote_host: assert workspace_dir await spawn_shell("rm -f > oneflow-src.zip") await spawn_shell("git archive --format zip HEAD > oneflow-src.zip") await spawn_shell( f"scp oneflow-src.zip {remote_host}:{workspace_dir}/oneflow-src.zip", ) await spawn_shell( f"ssh {remote_host} unzip {workspace_dir}/oneflow-src.zip -d {workspace_dir}/oneflow-src", ) await spawn_shell( f"ssh {remote_host} bash {workspace_dir}/oneflow-src/docker/ci/test/build.sh", ) else: await spawn_shell(f"bash docker/ci/test/build.sh") async def create_remote_workspace_dir( remote_host=None, workspace_dir=None, copy_files=None ): await spawn_shell(f"ssh {remote_host} mkdir -p {workspace_dir}") if copy_files is not None: for path in copy_files: # Reference: https://stackoverflow.com/a/31278462 if os.path.isdir(path) and path[-1] != "/": path += "/" await spawn_shell(f"ssh {remote_host} mkdir -p {workspace_dir}/{path}") await spawn_shell( f"rsync -azPq --omit-dir-times --no-perms --no-group --copy-links --exclude='__pycache__' {path} {remote_host}:{workspace_dir}/{path}" ) print("create_remote_workspace_dir done") def get_docker_cache_args(): return " ".join( [ f"-v {Path.home() / 'test-container-cache/dot-local'}:/root/.local", f"-v {Path.home() / 'test-container-cache/dot-cache'}:/root/.cache", ] ) async def launch_remote_container( remote_host=None, survival_time=None, workspace_dir=None, container_name=None, img_tag=None, oneflow_wheel_path=None, oneflow_python_path=None, cmd=None, node_rank=None, master_addr=None, ): print("launching remote container at", remote_host) assert img_tag multi_client_args = [node_rank, master_addr] multi_client_arg_has_value = [x is not None for x in multi_client_args] assert all(multi_client_arg_has_value) pythonpath_args = None if oneflow_wheel_path: pythonpath_args = "" elif oneflow_python_path: pythonpath_args = f"--env 
PYTHONPATH={workspace_dir}/python" else: raise ValueError("must have oneflow_wheel_path or oneflow_python_path") docker_cmd = f"""docker run --privileged -d --network host --shm-size=8g --rm {get_docker_cache_args()} -v {workspace_dir}:{workspace_dir} -w {workspace_dir} -v /dataset:/dataset -v /model_zoo:/model_zoo --name {container_name} {pythonpath_args} {img_tag} sleep {survival_time} """ await spawn_shell(f"ssh {remote_host} {docker_cmd}") if oneflow_wheel_path: whl_basename = os.path.basename(oneflow_wheel_path) await spawn_shell( f"ssh {remote_host} docker exec {container_name} python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple" ) await spawn_shell( f"ssh {remote_host} docker exec {container_name} python3 -m pip install {workspace_dir}/{whl_basename}" ) await spawn_shell( f"ssh {remote_host} docker exec {container_name} python3 -m oneflow --doctor" ) if cmd: multi_client_docker_args = ( # Use _MASTER_ADDR to avoid name conflict with OneFlow's built-in MASTER_ADDR f"--env NODE_RANK={node_rank} --env _MASTER_ADDR={master_addr}" ) await spawn_shell( f"ssh {remote_host} docker exec {multi_client_docker_args} {container_name} {cmd}" ) def handle_cast(conn=None, cmd=None): received_cmd: str = conn.recv().decode() assert received_cmd.startswith("cast/") received_cmd = received_cmd.replace("cast/", "") assert received_cmd == cmd, (received_cmd, cmd) return conn.recv().decode() def handle_call(conn=None, cmd=None, response=None): received_cmd: str = conn.recv().decode() assert received_cmd.startswith("call/") received_cmd = received_cmd.replace("call/", "") assert received_cmd == cmd, (received_cmd, cmd) msg = conn.recv().decode() conn.send(response.encode()) return msg class DockerAgent: def __init__( self, port=None, authkey=None, this_host=None, remote_hosts=None, container_name=None, timeout=None, workspace_dir=None, img_tag=None, oneflow_wheel_path=None, oneflow_python_path=None, oneflow_test_tmp_dir=None, extra_docker_args: str = None, ) -> None: # info self.this_host = this_host self.remote_hosts = remote_hosts self.container_name = container_name self.timeout = timeout self.common_docker_args = "--privileged --rm --network host --shm-size=8g -v $HOME:$HOME -v /dataset:/dataset -v /model_zoo:/model_zoo" self.workspace_dir = workspace_dir self.img_tag = img_tag self.oneflow_wheel_path = oneflow_wheel_path self.oneflow_python_path = oneflow_python_path self.oneflow_test_tmp_dir = oneflow_test_tmp_dir # impl self.env_proto_txt = None self.bash_tmp_file = None self.bash_proc = None self.remote_docker_proc = {} self.agent_port = port self.agent_authkey = authkey self.extra_docker_args = extra_docker_args def __enter__(self): return self def run_bash_script_async(self, bash_script=None, cmd=None): remote_hosts_str = ",".join(self.remote_hosts) ctrl_port = find_free_port() data_port = find_free_port() exports = f""" export ONEFLOW_TEST_MASTER_PORT={ctrl_port} export ONEFLOW_TEST_DATA_PORT={data_port} export ONEFLOW_TEST_NODE_LIST="{self.this_host},{remote_hosts_str}" export ONEFLOW_WORKER_KEEP_LOG=1 export ONEFLOW_TEST_TMP_DIR="{self.oneflow_test_tmp_dir}" export NCCL_DEBUG=INFO export ONEFLOW_TEST_WORKER_AGENT_PORT={agent_port} export ONEFLOW_TEST_WORKER_AGENT_AUTHKEY={agent_authkey} python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple """ if self.oneflow_wheel_path: exports += f"python3 -m pip install {self.oneflow_wheel_path}" if self.oneflow_python_path: exports += f"export PYTHONPATH={self.oneflow_python_path}:$PYTHONPATH\n" 
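        # The exports assembled above are prepended to the test command, written to
        # a temporary file (under /tmp on the host), and executed inside the
        # container started below; /tmp is bind-mounted read-only as /host/tmp,
        # which is why the generated `docker run` line ends with `bash /host{f_name}`.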
bash_cmd = None if bash_script: assert os.path.exists(bash_script) bash_cmd = f"""set -ex {exports} bash {bash_script} """ elif cmd: bash_cmd = f"""set -ex {exports} {cmd} """ else: raise ValueError("not impl") assert bash_cmd def get_docker_cmd(f, cmd): f_name = f.name f.write(cmd) f.flush() return f"docker run {self.common_docker_args} {self.extra_docker_args} {get_docker_cache_args()} -v /tmp:/host/tmp:ro -v $PWD:$PWD -w $PWD --name {self.container_name} {self.img_tag} bash /host{f_name}" f = tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", delete=True) run_docker_cmd = get_docker_cmd(f, bash_cmd) self.bash_tmp_file = f self.bash_proc = subprocess.Popen(run_docker_cmd, shell=True) def __exit__(self, exc_type, exc_val, exc_tb): pass async def fix_and_sync_libs(oneflow_internal_path=None, remote_hosts=None): tmp_dir = tempfile.TemporaryDirectory() tmp_lib_dir = os.path.join(tmp_dir.name, "libs") os.mkdir(tmp_lib_dir) await spawn_shell( """ldd file | grep "=> /" | awk '{print $3}' | xargs -I '{}' cp -v '{}' destination""".replace( "file", oneflow_internal_path ).replace( "destination", tmp_lib_dir ), ) libs = os.listdir(tmp_lib_dir) assert len(libs) > 0 excludelist_path = os.path.join( pathlib.Path(__file__).parent.absolute(), "excludelist" ) excludelist = open(excludelist_path).read().split("\n") await spawn_shell(f"cp {oneflow_internal_path} {tmp_dir.name}") def handle_lib(lib): if lib in excludelist or "libpython" in lib: print("excluding", lib) return spawn_shell(f"rm {tmp_lib_dir}/{lib}") else: print("keeping", lib) return spawn_shell(f"patchelf --set-rpath '$ORIGIN' {tmp_lib_dir}/{lib}") await asyncio.gather(*(handle_lib(lib) for lib in libs)) tmp_oneflow_internal_path = os.path.join( tmp_dir.name, pathlib.Path(oneflow_internal_path).name ) print("before fixing .so") await spawn_shell(f"ldd {tmp_oneflow_internal_path}") print("fixing .so") await spawn_shell( f"patchelf --set-rpath '$ORIGIN/libs' {tmp_oneflow_internal_path}" ) await asyncio.gather( *[ spawn_shell( f"ssh {remote_host} 'mkdir -p {workspace_dir}/python/oneflow/libs'", ) for remote_host in remote_hosts ] ) async def copy_file(path=None, remote_host=None): relpath = os.path.relpath(path, tmp_dir.name) await spawn_shell( f"scp {path} {remote_host}:{workspace_dir}/python/oneflow/{relpath}", ) files = [ os.path.join(root, name) for root, dirs, files in os.walk(tmp_dir.name, topdown=True) for name in files ] await asyncio.gather( *[ copy_file(path=f, remote_host=remote_host) for remote_host in remote_hosts for f in files ], spawn_shell(f"ldd {tmp_oneflow_internal_path}"), ) async def remove_containers_by_name(remote_hosts=None, container_name=None): rm_cmd = f"docker rm -f {container_name}" assert container_name assert remote_hosts await asyncio.gather( *[ spawn_shell_ignoring_failure(f"ssh {remote_host} {rm_cmd}") for remote_host in remote_hosts ], spawn_shell_ignoring_failure(rm_cmd), ) def get_remote_hosts(args): remote_hosts = None if len(args.remote_host) == 1: remote_hosts = args.remote_host.split(",") elif len(args.remote_host) == 0: affiliations = get_affiliations(this_host) assert ( affiliations ), f"no affiliated node found for {this_host}, you should specify one" remote_host = affiliations[0] remote_host = socket.gethostbyname(remote_host) remote_hosts = [remote_host] else: remote_hosts = args.remote_host return remote_hosts if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--debug", action="store_true", required=False, default=False) parser.add_argument( "--skip_libs", 
action="store_true", required=False, default=False ) parser.add_argument("--bash_script", type=str, required=False) default_this_host = socket.gethostname() parser.add_argument( "--this_host", type=str, required=False, default=default_this_host ) parser.add_argument("--remote_host", action="append", default=[]) parser.add_argument("--oneflow_wheel_path", type=str, required=False, default=None) parser.add_argument( "--oneflow_wheel_python_version", type=str, required=False, default=None ) parser.add_argument("--oneflow_python_path", type=str, required=False, default=None) parser.add_argument("--custom_img_tag", type=str, required=False, default=None) parser.add_argument("--cmd", type=str, required=False, default=None) parser.add_argument( "--oneflow_test_tmp_dir", type=str, required=False, default="distributed-tmp" ) parser.add_argument("--timeout", type=int, required=False, default=1 * 60 * 60) parser.add_argument("--mode", type=str, required=False, default="multi_client") parser.add_argument("--copy_files", action="append", default=[]) args = parser.parse_args() assert args.mode in ["multi_client"] assert bool(args.oneflow_wheel_path) != bool(args.oneflow_python_path) assert bool(args.bash_script) != bool(args.cmd) if args.skip_libs: assert args.debug, "--skip_libs only works with --debug" assert ( args.oneflow_python_path ), "--skip_libs only works with --oneflow_python_path" oneflow_wheel_path = args.oneflow_wheel_path main_node_extra_docker_args = [] if oneflow_wheel_path and os.path.isdir(oneflow_wheel_path): assert os.path.isabs(oneflow_wheel_path) main_node_extra_docker_args.append( f"-v {oneflow_wheel_path}:{oneflow_wheel_path}:ro" ) whl_paths = [ name for name in glob.glob(os.path.join(oneflow_wheel_path, f"*.whl",)) ] if len(whl_paths) == 1: oneflow_wheel_path = whl_paths[0] else: assert args.oneflow_wheel_python_version assert args.oneflow_wheel_python_version in [ "3.6", "3.7", "3.8", "3.9", "3.10", "3.11", ] ver_cat = args.oneflow_wheel_python_version.replace(".", "") found = False for whl_path in whl_paths: if f"cp{ver_cat}" in whl_path: oneflow_wheel_path = whl_path found = True assert found, whl_paths this_host = args.this_host this_host = resolve_hostname_hardcoded(this_host) remote_hosts = get_remote_hosts(args) print(f"this_host: {this_host}, remote_hosts: {remote_hosts}", flush=True) sub_dir = str(uuid.uuid4()) if args.debug: sub_dir = "debug" workspace_dir = os.path.join( os.path.expanduser("~"), "distributed_run_workspace", sub_dir ) print("workspace_dir", workspace_dir) container_name = ( getpass.getuser() + "-distributed-run-main-node-at-" + this_host.replace(".", "-") ) if args.mode == "multi_client": remote_hosts = [this_host] + remote_hosts loop = asyncio.get_event_loop() # add host key to all machines (needed by ssh/scp/rsync) loop.run_until_complete( asyncio.gather( *[ spawn_shell(f"ssh -o StrictHostKeyChecking=no {remote_host} true") for remote_host in remote_hosts ], ), ) loop.run_until_complete( asyncio.gather( *[ create_remote_workspace_dir( remote_host=remote_host, workspace_dir=workspace_dir, copy_files=args.copy_files, ) for remote_host in remote_hosts ], remove_containers_by_name( remote_hosts=remote_hosts, container_name=container_name ), ), ) if args.oneflow_python_path: so_paths = [ name for name in glob.glob( os.path.join( args.oneflow_python_path, f"oneflow/_oneflow_internal.*.so", ) ) ] assert len(so_paths) == 1, so_paths oneflow_internal_path = so_paths[0] oneflow_internal_path = os.path.join( args.oneflow_python_path, oneflow_internal_path ) 
tmp_dir = None print("copying oneflow python dir") loop.run_until_complete( asyncio.gather( *[ spawn_shell( f"rsync -azPq --omit-dir-times --no-perms --no-group --copy-links --include='*.py' --exclude='*.so' --exclude='__pycache__' --exclude='oneflow/include' --include='*/' --exclude='*' {args.oneflow_python_path} {remote_host}:{workspace_dir}" ) for remote_host in remote_hosts ] ) ) if args.skip_libs == False: print("copying .so") loop.run_until_complete( fix_and_sync_libs( oneflow_internal_path=oneflow_internal_path, remote_hosts=remote_hosts, ) ) elif oneflow_wheel_path: loop.run_until_complete( asyncio.gather( *[ spawn_shell( f"rsync -azPq --omit-dir-times --no-perms --no-group {oneflow_wheel_path} {remote_host}:{workspace_dir}" ) for remote_host in remote_hosts ] ) ) default_docker_image = "oneflow-test:$USER" ci_user_docker_image = "oneflow-test:0.2" img_tag = None if args.custom_img_tag == None: if is_img_existing(default_docker_image): img_tag = default_docker_image elif is_img_existing(ci_user_docker_image): img_tag = ci_user_docker_image else: loop.run_until_complete( asyncio.gather( *[ build_docker_img( remote_host=remote_host, workspace_dir=workspace_dir ) for remote_host in remote_hosts ], build_docker_img(workspace_dir=workspace_dir), ) ) img_tag = default_docker_image else: img_tag = args.custom_img_tag assert img_tag agent_port = find_free_port() agent_authkey = str(uuid.uuid4()) def exit_handler(): print( "---------start cleanup, you should ignore errors below and check the errors above---------" ) if args.oneflow_python_path: print("fixing permission of", args.oneflow_python_path) subprocess.call( f"docker run --rm -v {args.oneflow_python_path}:/p -w /p busybox chmod -R o+w .", shell=True, ) loop.run_until_complete( asyncio.gather( *[ spawn_shell_ignoring_failure( f"ssh {remote_host} docker run --rm -v {workspace_dir}:/p -w /p busybox chmod -R 777 .", ) for remote_host in remote_hosts ], ) ) print("copying artifacts") extra_exclude_args = "" for path in args.copy_files: extra_exclude_args += f"--exclude='{path}' " loop.run_until_complete( asyncio.gather( *[ spawn_shell_ignoring_failure( f"rsync -azPq --omit-dir-times --no-perms --no-group --exclude='*.whl' --exclude='python' {extra_exclude_args} {remote_host}:{workspace_dir}/ {args.oneflow_test_tmp_dir}/{remote_host}" ) for remote_host in remote_hosts ] ) ) assert workspace_dir if args.debug == False: print("removing docker workspace_dir:", workspace_dir) loop.run_until_complete( asyncio.gather( *[ spawn_shell_ignoring_failure( f"ssh {remote_host} rm -rf {workspace_dir}", ) for remote_host in remote_hosts ], ) ) print("removing docker container:", container_name) loop.run_until_complete( remove_containers_by_name( remote_hosts=remote_hosts, container_name=container_name ) ) atexit.register(exit_handler) if args.mode == "multi_client": if args.bash_script: args.cmd = f"bash {args.bash_script}" loop.run_until_complete( asyncio.gather( *[ launch_remote_container( remote_host=remote_host, survival_time=args.timeout, workspace_dir=workspace_dir, container_name=container_name, oneflow_wheel_path=oneflow_wheel_path, oneflow_python_path=args.oneflow_python_path, img_tag=img_tag, cmd=args.cmd, node_rank=node_rank, master_addr=this_host, ) for node_rank, remote_host in enumerate(remote_hosts) ], ) ) else: loop.run_until_complete( asyncio.gather( *[ launch_remote_container( remote_host=remote_host, survival_time=args.timeout, workspace_dir=workspace_dir, container_name=container_name, oneflow_wheel_path=oneflow_wheel_path, 
oneflow_python_path=args.oneflow_python_path, img_tag=img_tag, ) for remote_host in remote_hosts ], ) ) ================================================ FILE: ci/test/doctest.sh ================================================ #!/bin/bash set -xe export PYTHONUNBUFFERED=1 src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} test_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-"./test_tmp_dir"} mkdir -p ${test_tmp_dir} cd ${test_tmp_dir} python3 -c 'import oneflow; f=open("oneflow_path.txt", "w"); f.write(oneflow.__path__[0])' gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) python3 $src_dir/ci/test/parallel_run.py \ --gpu_num=${gpu_num} \ --dir=$(cat oneflow_path.txt) \ --timeout=1 \ --verbose \ --chunk=1 \ --doctest ================================================ FILE: ci/test/excludelist ================================================ # This file lists libraries that we will assume to be present on the host system and hence # should NOT be bundled inside AppImages. This is a working document; expect it to change # over time. File format: one filename per line. Each entry should have a justification comment. # See the useful tool at https://abi-laboratory.pro/index.php?view=navigator&symbol=hb_buffer_set_cluster_level#result # to investigate issues with missing symbols. ld-linux.so.2 ld-linux-x86-64.so.2 libanl.so.1 libBrokenLocale.so.1 libcidn.so.1 # libcrypt.so.1 # Not part of glibc anymore as of Fedora 30. See https://github.com/slic3r/Slic3r/issues/4798 and https://pagure.io/fedora-docs/release-notes/c/01d74b33564faa42959c035e1eee286940e9170e?branch=f28 libc.so.6 libdl.so.2 libm.so.6 libmvec.so.1 # libnsl.so.1 # Not part of glibc anymore as of Fedora 28. See https://github.com/RPCS3/rpcs3/issues/5224#issuecomment-434930594 libnss_compat.so.2 # libnss_db.so.2 # Not part of neon-useredition-20190321-0530-amd64.iso libnss_dns.so.2 libnss_files.so.2 libnss_hesiod.so.2 libnss_nisplus.so.2 libnss_nis.so.2 libpthread.so.0 libresolv.so.2 librt.so.1 libthread_db.so.1 libutil.so.1 # These files are all part of the GNU C Library which should never be bundled. # List was generated from a fresh build of glibc 2.25. libstdc++.so.6 # Workaround for: # usr/lib/libstdc++.so.6: version `GLIBCXX_3.4.21' not found libGL.so.1 # The above may be missing on Chrome OS, https://www.reddit.com/r/Crostini/comments/d1lp67/ultimaker_cura_no_longer_running_as_an_appimage/ libEGL.so.1 # Part of the video driver (OpenGL); present on any regular # desktop system, may also be provided by proprietary drivers. # Known to cause issues if it's bundled. 
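# For reference: fix_and_sync_libs in ci/test/distributed_run.py consults this file
# when bundling _oneflow_internal.*.so, enumerating dependencies with roughly
#   ldd <_oneflow_internal.so> | grep "=> /" | awk '{print $3}'
# and then dropping any library whose name appears in this list.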
libGLdispatch.so.0 libGLX.so.0 # reported to be superfluent and conflicting system libraries (graphics driver) # see https://github.com/linuxdeploy/linuxdeploy/issues/89 libOpenGL.so.0 # Qt installed via install-qt.sh apparently links to this library # part of OpenGL like libGL/libEGL, so excluding it should not cause any problems # https://github.com/linuxdeploy/linuxdeploy/issues/152 libdrm.so.2 # Workaround for: # Antergos Linux release 2015.11 (ISO-Rolling) # /usr/lib/libdrm_amdgpu.so.1: error: symbol lookup error: undefined symbol: drmGetNodeTypeFromFd (fatal) # libGL error: unable to load driver: swrast_dri.so # libGL error: failed to load driver: swrast # Unrecognized OpenGL version libglapi.so.0 # Part of mesa # known to cause problems with graphics, see https://github.com/RPCS3/rpcs3/issues/4427#issuecomment-381674910 libgbm.so.1 # Part of mesa # https://github.com/probonopd/linuxdeployqt/issues/390#issuecomment-529036305 libxcb.so.1 # Workaround for: # Fedora 23 # symbol lookup error: /lib64/libxcb-dri3.so.0: undefined symbol: xcb_send_fd # Uncertain if this is required to be bundled for some distributions - if so we need to write a version check script and use LD_PRELOAD to load the system version if it is newer # Fedora 25: # undefined symbol: xcb_send_request_with_fds # https://github.com/AppImage/AppImages/issues/128 libX11.so.6 # Workaround for: # Fedora 23 # symbol lookup error: ./lib/libX11.so.6: undefined symbol: xcb_wait_for_reply64 # Uncertain if this is required to be bundled for some distributions - if so we need to write a version check script and use LD_PRELOAD to load the system version if it is newer libgio-2.0.so.0 # Workaround for: # On Ubuntu, "symbol lookup error: /usr/lib/x86_64-linux-gnu/gtk-2.0/modules/liboverlay-scrollbar.so: undefined symbol: g_settings_new" # libgdk-x11-2.0.so.0 # Missing on openSUSE-Tumbleweed-KDE-Live-x86_64-Snapshot20170601-Media.iso # libgtk-x11-2.0.so.0 # Missing on openSUSE-Tumbleweed-KDE-Live-x86_64-Snapshot20170601-Media.iso libasound.so.2 # Workaround for: # No sound, e.g., in VLC.AppImage (does not find sound cards) # https://github.com/AppImage/pkg2appimage/issues/475 # libgdk_pixbuf-2.0.so.0 # Was: Workaround for: # On Ubuntu, get (inkscape:25621): GdkPixbuf-WARNING **: Error loading XPM image loader: Image type 'xpm' is not supported libfontconfig.so.1 # Workaround for: # Application stalls when loading fonts during application launch; e.g., KiCad on ubuntu-mate libthai.so.0 # Workaround for: # audacity: /tmp/.mount_AudaciUsFbON/usr/lib/libthai.so.0: version `LIBTHAI_0.1.25' not found (required by /usr/lib64/libpango-1.0.so.0) # on openSUSE Tumbleweed # other "low-level" font rendering libraries # should fix https://github.com/probonopd/linuxdeployqt/issues/261#issuecomment-377522251 # and https://github.com/probonopd/linuxdeployqt/issues/157#issuecomment-320755694 libfreetype.so.6 libharfbuzz.so.0 # Note, after discussion we do not exlude this, but we can use a dummy library that just does nothing # libselinux.so.1 # Workaround for: # sed: error while loading shared libraries: libpcre.so.3: cannot open shared object file: No such file or directory # Some distributions, such as Arch Linux, do not come with libselinux.so.1 by default. 
# The solution is to bundle a dummy mock library: # echo "extern int is_selinux_enabled(void){return 0;}" >> selinux-mock.c # gcc -s -shared -o libselinux.so.1 -Wl,-soname,libselinux.so.1 selinux-mock.c # strip libselinux.so.1 # More information: https://github.com/AppImage/AppImages/issues/83 # and https://github.com/AppImage/AppImageKit/issues/775#issuecomment-614954821 # https://gitlab.com/sulinos/devel/libselinux-dummy # The following are assumed to be part of the base system # Removing these has worked e.g., for Krita. Feel free to report if # you think that some of these should go into AppImages and why. libcom_err.so.2 libexpat.so.1 libgcc_s.so.1 libglib-2.0.so.0 libgpg-error.so.0 # libgssapi_krb5.so.2 # Disputed, seemingly needed by Arch Linux since Kerberos is named differently there # libgssapi.so.3 # Seemingly needed when running Ubuntu 14.04 binaries on Fedora 23 # libhcrypto.so.4 # Missing on openSUSE LEAP 42.0 # libheimbase.so.1 # Seemingly needed when running Ubuntu 14.04 binaries on Fedora 23 # libheimntlm.so.0 # Seemingly needed when running Ubuntu 14.04 binaries on Fedora 23 # libhx509.so.5 # Missing on openSUSE LEAP 42.0 libICE.so.6 # libidn.so.11 # Does not come with Solus by default # libk5crypto.so.3 # Runnning AppImage built on Debian 9 or Ubuntu 16.04 on an Archlinux fails otherwise; https://github.com/AppImage/AppImages/issues/301 # libkeyutils.so.1 # Does not come with Void Linux by default; https://github.com/Subsurface-divelog/subsurface/issues/1971#issuecomment-466606834 # libkrb5.so.26 # Disputed, seemingly needed by Arch Linux since Kerberos is named differently there. Missing on openSUSE LEAP 42.0 # libkrb5.so.3 # Disputed, seemingly needed by Arch Linux since Kerberos is named differently there # libkrb5support.so.0 # Disputed, seemingly needed by Arch Linux since Kerberos is named differently there libp11-kit.so.0 # libpcre.so.3 # Missing on Fedora 24, SLED 12 SP1, and openSUSE Leap 42.2 # libroken.so.18 # Mission on openSUSE LEAP 42.0 # libsasl2.so.2 # Seemingly needed when running Ubuntu 14.04 binaries on Fedora 23 libSM.so.6 libusb-1.0.so.0 libuuid.so.1 # libwind.so.0 # Missing on openSUSE LEAP 42.0 # Potentially dangerous libraries libgobject-2.0.so.0 # Workaround for: # Rectangles instead of fonts # https://github.com/AppImage/AppImages/issues/240 libpangoft2-1.0.so.0 libpangocairo-1.0.so.0 libpango-1.0.so.0 # FIXME: # Can get symbol lookup error: /lib64/libpango-1.0.so.0: undefined symbol: g_log_structured_standard # if libcairo is bundled but libpango is not # Workaround for: # e.g., Spotify # relocation error: /lib/x86_64-linux-gnu/libgcrypt.so.20: # symbol gpgrt_lock_lock, version GPG_ERROR_1.0 not defined # in file libgpg-error.so.0 with link time reference libgpg-error.so.0 libjack.so.0 # it must match the ABI of the JACK server which is installed in the base system # rncbc confirmed this # However, this library is missing on Fedora-WS-Live-31-1-9 # which means that we should avoid using JACK altogether if possible # Unsolved issue: # https://github.com/probonopd/linuxdeployqt/issues/35 # Error initializing NSS with a persistent database (sql:/home/me/.pki/nssdb): libsoftokn3.so: cannot open shared object file: No such file or directory # Error initializing NSS without a persistent database: NSS error code: -5925 # nss_error=-5925, os_error=0 # libnss3.so should not be removed from the bundles, as this causes other issues, e.g., # https://github.com/probonopd/linuxdeployqt/issues/35#issuecomment-256213517 # and 
https://github.com/AppImage/AppImages/pull/114 # libnss3.so # The following cannot be excluded, see # https://github.com/AppImage/AppImages/commit/6c7473d8cdaaa2572248dcc53d7f617a577ade6b # http://stackoverflow.com/questions/32644157/forcing-a-binary-to-use-a-specific-newer-version-of-a-shared-library-so # libssl.so.1 # libssl.so.1.0.0 # libcrypto.so.1 # libcrypto.so.1.0.0 # According to https://github.com/RicardoEPRodrigues/3Engine/issues/4#issuecomment-511598362 # libGLEW is not tied to a specific GPU. It's linked against libGL.so.1 # and that one is different depending on the installed driver. # In fact libGLEW is changing its soversion very often, so you should always bundle libGLEW.so.2.0 # libglut.so.3 # to be confirmed libxcb-dri3.so.0 # https://github.com/AppImage/AppImages/issues/348 libxcb-dri2.so.0 # https://github.com/probonopd/linuxdeployqt/issues/331#issuecomment-442276277 # If the next line turns out to cause issues, we will have to remove it again and find another solution libfribidi.so.0 # https://github.com/olive-editor/olive/issues/221 and https://github.com/knapsu/plex-media-player-appimage/issues/14 # Workaround for: # symbol lookup error: /lib/x86_64-linux-gnu/libgnutls.so.30: undefined symbol: __gmpz_limbs_write # https://github.com/ONLYOFFICE/appimage-desktopeditors/issues/3 # Apparently coreutils depends on it, so it should be safe to assume that it comes with every target system libgmp.so.10 ================================================ FILE: ci/test/expensive_generic_test_multi_client.sh ================================================ #!/bin/bash set -xe export PYTHONUNBUFFERED=1 src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} ONEFLOW_TEST_DIR=${ONEFLOW_TEST_DIR:-"$PWD/python/oneflow/test/modules"} cd $ONEFLOW_TEST_DIR if [ -z "$ONEFLOW_TEST_CPU_ONLY" ] then gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) for ((i=0;i 0: files += list(glob.glob(ff, recursive=True)) print("total files:", len(files)) files = sorted( files, key=lambda x: hashlib.md5(os.path.basename(x.encode("ascii"))).hexdigest(), ) if args.shuffle: random.shuffle(files) files_hash = hashlib.md5( "".join([os.path.basename(x) for x in files]).encode() ).hexdigest()[:8] if args.verbose: print( f"::warning file=testFilesHash,line={len(files)},col=0,endColumn=0::shuffle-{args.shuffle}-group_size-{args.group_size}-md5-{files_hash}" ) if args.parallel_num == "master_port": parallel_num = len(args.master_port) master_ports = args.master_port else: parallel_num = int(args.parallel_num) if parallel_num != len(args.master_port): print( "warning", "parallel_num != len(args.master_port)", "will auto generate" ) default_master_port = 29500 master_ports = list( range(default_master_port, default_master_port + parallel_num) ) assert parallel_num > 0 assert len(master_ports) == parallel_num chunk_size = ceil(len(files) / parallel_num) global PARALLEL_NUM PARALLEL_NUM = parallel_num chunks = [files[i : i + chunk_size] for i in range(0, len(files), chunk_size)] # check args assert args.training_script == "oneflow.distributed.launch" # generate commands cmds = [ [sys.executable, "-m", args.training_script, "--master_port", str(master_port)] + args.training_script_args + chunck for (master_port, chunck) in zip(master_ports, chunks) ] loop = asyncio.get_event_loop() processes = launch_multiple( cmds=cmds, auto_cuda_env=args.auto_cuda_visible_devices, group_size=args.group_size, device_num=args.device_num, ) loop.run_until_complete(processes) if __name__ == "__main__": main() 
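The launcher logic above boils down to a deterministic split of the collected test
files across several oneflow.distributed.launch invocations, one master port per
group. A minimal sketch of that scheme, using hypothetical inputs (`files`,
`master_ports`, `extra_args`) instead of the script's argparse plumbing:

import hashlib
import os
import sys
from math import ceil


def build_launch_cmds(files, master_ports, extra_args):
    # Stable ordering: sort by the md5 of each basename so every runner
    # computes the same split regardless of filesystem enumeration order.
    files = sorted(
        files,
        key=lambda p: hashlib.md5(os.path.basename(p).encode("ascii")).hexdigest(),
    )
    parallel_num = len(master_ports)
    chunk_size = ceil(len(files) / parallel_num)
    chunks = [files[i : i + chunk_size] for i in range(0, len(files), chunk_size)]
    # One oneflow.distributed.launch command per chunk, each with its own
    # master port so concurrent groups do not collide.
    return [
        [sys.executable, "-m", "oneflow.distributed.launch", "--master_port", str(port)]
        + extra_args
        + chunk
        for port, chunk in zip(master_ports, chunks)
    ]

For example, build_launch_cmds(["test_a.py", "test_b.py", "test_c.py"], [29500, 29501],
["-m", "pytest", "-x"]) yields two commands, each covering a disjoint subset of the files.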
================================================ FILE: ci/test/parallel_run.py ================================================ import asyncio import os import argparse from subprocess import PIPE, STDOUT import glob import sys import time import socket from contextlib import closing import uuid def gen_cmds(cmd=None, dir=None, doctest=False): if doctest: paths = glob.glob(os.path.join(dir, "**/*.py"), recursive=True) paths = [ p for p in paths if "compatible" not in p and "single_client" not in p and "unittest.py" not in p ] with_doctest = [] for p in paths: with open(p) as f: content = f.read() if "import doctest" in content: with_doctest.append("{} {} -v".format(cmd, p)) print(with_doctest) return with_doctest else: paths = glob.glob(os.path.join(dir, "test_*.py"), recursive=False) return ["{} {} --failfast --verbose".format(cmd, p) for p in paths] def find_free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("localhost", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] def split_and_print(prefix, text): lines = text.splitlines(keepends=True) prefixed = "" for l in lines: prefixed += f"{prefix} {l}" print(prefixed, flush=True) def everyN(l: list, n: int): for i in range(0, len(l), n): yield l[i : i + n] def contains_oom_info(txt: str): return "memory" in txt or "Memory" in txt or "CUDNN" in txt or "ALLOC" in txt def should_retry(txt: str): return contains_oom_info(txt) def print_out(prefix: str = "", content: str = ""): for l in content.split("\n"): print(f"[{prefix}]", l) async def spawn_shell_and_check(cmd: str = None, gpu_id: int = -1, check: bool = False): is_cpu_only = os.getenv("ONEFLOW_TEST_CPU_ONLY") print(f"[gpu={gpu_id}]", cmd) p = await asyncio.create_subprocess_shell( cmd, stdout=PIPE, stderr=STDOUT, env=dict( os.environ, CUDA_VISIBLE_DEVICES=("-1" if is_cpu_only else ",".join([str(gpu_id)])), ONEFLOW_TEST_MASTER_PORT=str(find_free_port()), ONEFLOW_TEST_LOG_DIR=("./unittest-log-" + str(uuid.uuid4())), ), ) (stdout_data, stderr_data) = await p.communicate() decoded = stdout_data.decode() if check or should_retry(decoded) == False: if p.returncode != 0: print_out(prefix=cmd, content=decoded) raise RuntimeError(cmd) return {"returncode": p.returncode, "cmd": cmd, "stdout": decoded} async def run_cmds( cmds, gpu_num=0, timeout=10, chunk=1, verbose=False, per_gpu_process_num=1 ): is_cpu_only = os.getenv("ONEFLOW_TEST_CPU_ONLY") if is_cpu_only: gpu_num = os.cpu_count() fails = [] assert gpu_num > 0 for cmdN in everyN(cmds, per_gpu_process_num * gpu_num): results = await asyncio.gather( *[ spawn_shell_and_check( cmd=cmd, gpu_id=i, check=(per_gpu_process_num == 1) ) for cmd_gpu_num in everyN(cmdN, gpu_num) for (i, cmd) in enumerate(cmd_gpu_num) ], ) for r in list(results): if r["returncode"] != 0: fails.append(r["cmd"]) else: print_out(prefix=r["cmd"], content=r["stdout"]) return fails if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--gpu_num", type=int, required=True, default=0) parser.add_argument("--dir", type=str, required=True, default=".") parser.add_argument("--cmd", type=str, required=False, default=sys.executable) parser.add_argument("--timeout", type=int, required=False, default=2) parser.add_argument("--chunk", type=int, required=True) parser.add_argument("--verbose", action="store_true", required=False, default=False) parser.add_argument("--doctest", action="store_true", required=False, default=False) args = parser.parse_args() cmds = gen_cmds(cmd=args.cmd, 
dir=args.dir, doctest=args.doctest) start = time.time() loop = asyncio.get_event_loop() PER_GPU_PROCESS_NUMS = [12, 8, 2, 1] is_cpu_only = os.getenv("ONEFLOW_TEST_CPU_ONLY") if is_cpu_only: PER_GPU_PROCESS_NUMS = [1] for per_gpu_process_num in PER_GPU_PROCESS_NUMS: print("[per_gpu_process_num]", per_gpu_process_num) cmds = loop.run_until_complete( run_cmds( cmds, gpu_num=args.gpu_num, timeout=args.timeout, chunk=args.chunk, verbose=args.verbose, per_gpu_process_num=per_gpu_process_num, ) ) elapsed = time.time() - start elapsed_time_txt = time.strftime("elapsed: %H:%M:%S", time.gmtime(elapsed)) print(elapsed_time_txt) ================================================ FILE: ci/test/print_stack_from_core.sh ================================================ set -ex if compgen -G "$2/core.*" > /dev/null; then gdb --batch --quiet -ex "thread apply all bt full" -ex "quit" $1 $2/core.* fi ================================================ FILE: ci/test/print_stack_in_all_dirs.sh ================================================ set -ex find . -type f -name "core.*" -exec gdb --batch --quiet -ex "thread apply all bt full" -ex "quit" python3 {} \; ================================================ FILE: ci/test/resource-spec/1x-gtx-1080.json ================================================ { "version": { "major": 1, "minor": 0 }, "local": [ { "vram": [ { "id": "0", "slots": 8117 } ] } ] } ================================================ FILE: ci/test/resource-spec/2x-rtx-2080.json ================================================ { "version": { "major": 1, "minor": 0 }, "local": [ { "vram": [ { "id": "0", "slots": 7982 }, { "id": "1", "slots": 7982 } ] } ] } ================================================ FILE: ci/test/resource-spec/4x-rtx-2080ti.json ================================================ { "version": { "major": 1, "minor": 0 }, "local": [ { "vram": [ { "id": "0", "slots": 11019 }, { "id": "1", "slots": 11019 }, { "id": "2", "slots": 11019 }, { "id": "3", "slots": 11019 } ] } ] } ================================================ FILE: ci/test/test_mock_function.sh ================================================ #!/bin/bash set -e MOCK_UNITTEST=$PWD/python/oneflow/test/misc/test_mock_scope.py python3 $MOCK_UNITTEST --failfast --verbose # testing import * python3 -c " import oneflow import oneflow.nn import oneflow.mock_torch as mock; mock.enable(); from torch.sbp import *; assert(sbp == oneflow.sbp.sbp); from torch import *; assert(randn == oneflow.randn); from torch.nn import *; assert(Graph == oneflow.nn.Graph); mock.disable(); from torch import *; assert(randn != oneflow.randn); from torch.nn import *; assert(Graph != oneflow.nn.Graph); " ================================================ FILE: ci/test/test_mock_script.sh ================================================ #!/bin/bash set -e python_version=$(python3 --version 2>&1 | awk '{print $2}') if [[ "$python_version" < "3.8" ]]; then echo "Python version is less than 3.8." 
exit 0 fi MOCK_TORCH=$PWD/python/oneflow/test/misc/mock_example.py same_or_exit() { if [[ "$(python3 $MOCK_TORCH)" != *"$1"* ]]; then exit 1 fi } # generate pytorch file python3 -c "import torch; torch.save(torch.ones(1), 'test.pt')" eval $(python3 -m oneflow.mock_torch) # test call to python module, default argument is enable same_or_exit "True" # test load pytorch file with mock torch enabled python3 -c """ import torch x = torch.load('test.pt') assert torch.equal(x, torch.ones(1)) import torch.nn assert 'oneflow/nn/__init__.py' in torch.nn.__file__ """ # testing import python3 -c 'import torch; torch.randn(2,3)' python3 -c 'import torch.nn; torch.nn.Graph' python3 -c 'import torch.version; torch.version.__version__' python3 -c 'from torch import *; randn(2,3)' python3 -c 'from torch.nn import *; Graph' python3 -c 'from torch.sbp import *; sbp' python3 -c 'from torch import nn; nn.Graph' python3 -c 'from torch.version import __version__' python3 -c 'import torch; torch.not_exist' 2>&1 >/dev/null | grep -q 'AttributeError' python3 -c 'import torch.not_exist' 2>&1 >/dev/null | grep -q 'ModuleNotFoundError' eval $(python3 -m oneflow.mock_torch disable) same_or_exit "False" eval $(python3 -m oneflow.mock_torch enable) same_or_exit "True" eval $(python3 -m oneflow.mock_torch disable) # recover same_or_exit "False" eval $(oneflow-mock-torch) # test scripts same_or_exit "True" eval $(oneflow-mock-torch disable) same_or_exit "False" eval $(oneflow-mock-torch enable) same_or_exit "True" eval $(oneflow-mock-torch disable) same_or_exit "False" # test load pytorch file with mock torch disabled python3 -c "import oneflow as flow; x = flow.load('test.pt'); assert flow.equal(x, flow.ones(1))" rm test.pt eval $(python3 -m oneflow.mock_torch --lazy --verbose) python3 -c "import torch.not_exist" | grep -q 'dummy object' ================================================ FILE: ci/test/test_resnet50_graph_ddp.sh ================================================ #!/usr/bin/env bash set -ex cd $ONEFLOW_MODELS_DIR ONEFLOW_TEST_DATASET_DIR=${ONEFLOW_TEST_DATASET_DIR:-"/dataset"} OFRECORD_PATH=${ONEFLOW_TEST_DATASET_DIR}/imagenette/ofrecord if [ ! 
-d "${ONEFLOW_TEST_DATASET_DIR}/imagenette/ofrecord/train" ];then mkdir -p ./dataset/ofrecord ln -s ${ONEFLOW_TEST_DATASET_DIR}/imagenette/ofrecord ./dataset/ofrecord/train OFRECORD_PATH=./dataset/ofrecord fi python3 -m oneflow.distributed.launch --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 Vision/classification/image/resnet50/train.py --ofrecord-path $OFRECORD_PATH --ofrecord-part-num 1 --num-devices-per-node 1 --lr 0.004 --momentum 0.875 --num-epochs 1 --train-batch-size 4 --val-batch-size 50 --print-interval 10 --exit-num 1 --ddp python3 -m oneflow.distributed.launch --nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 Vision/classification/image/resnet50/train.py --ofrecord-path $OFRECORD_PATH --ofrecord-part-num 2 --num-devices-per-node 1 --lr 0.004 --momentum 0.875 --num-epochs 1 --train-batch-size 4 --val-batch-size 50 --print-interval 10 --exit-num 1 --use-fp16 --channel-last --scale-grad --graph --fuse-bn-relu --fuse-bn-add-relu --use-gpu-decode ================================================ FILE: ci/test/test_speed_multi_client.sh ================================================ #!/usr/bin/env bash set -uxo pipefail rc=0 # accumulate the score of every test trap 'rc=$(($rc + $?))' ERR cd $ONEFLOW_MODELS_DIR function check_relative_speed { # Default score is 1 SCORE=${2:-1} awk -F'[:(]' -v threshold=$1 -v score=$SCORE 'BEGIN { ret=2 } /Relative speed/{ if ($2 >= threshold) { printf "✔️ "; ret=0 } else { printf "❌ "; ret=score }} {print $0} END { exit ret }' } function check_millisecond_time { # Default score is 1 SCORE=${2:-1} awk -F'[:(]' -v threshold=$1 -v score=$SCORE 'BEGIN { ret=2 } /OneFlow/{ if (substr($2, 2, length($2) - 4) <= threshold) { printf "✔️ "; ret=0 } else { printf "❌ "; ret=score }} { print $0 } END { exit ret }' } function write_to_file_and_print { tee -a result printf "\n" >> result } python3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.05 | check_millisecond_time 129.0 2 | write_to_file_and_print python3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.04 | write_to_file_and_print python3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.01 | write_to_file_and_print python3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.99 | write_to_file_and_print python3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.95 | write_to_file_and_print python3 scripts/swin_dataloader_compare_speed_with_pytorch.py --batch_size 32 --num_workers 1 | write_to_file_and_print python3 scripts/swin_dataloader_compare_speed_with_pytorch.py --batch_size 32 --num_workers 4 | write_to_file_and_print python3 scripts/swin_dataloader_compare_speed_with_pytorch.py --batch_size 32 --num_workers 8 | write_to_file_and_print export OMP_NUM_THREADS=1 python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 --ddp | 
check_relative_speed 1.12 | check_millisecond_time 136.3 2 | write_to_file_and_print python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 1.1 | write_to_file_and_print python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 1.18 | write_to_file_and_print python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 1.18 | write_to_file_and_print python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 1.15 | write_to_file_and_print result="GPU Name: `nvidia-smi --query-gpu=name --format=csv,noheader -i 0` \n\n `cat result`" # escape newline for github actions: https://github.community/t/set-output-truncates-multiline-strings/16852/2 # note that we escape \n and \r to \\n and \\r (i.e. raw string "\n" and "\r") instead of %0A and %0D, # so that they can be correctly handled in javascript code result="${result//'%'/'%25'}" result="${result//$'\n'/'\\n'}" result="${result//$'\r'/'\\r'}" echo "::set-output name=stats::$result" # Only fail when the sum of score >= 2 if (( $rc >= 2 )) then exit 1 else exit 0 fi ================================================ FILE: ci/test/try_install.sh ================================================ #!/bin/bash set -xe src_dir=${ONEFLOW_SRC_DIR:-"$PWD"} wheel_path=${ONEFLOW_WHEEL_PATH:-"$PWD/wheelhouse"} index=${ONEFLOW_PIP_INDEX} pkg_name=${ONEFLOW_PACKAGE_NAME:-"oneflow"} if [ -n "$index" ]; then python3 -m pip install --find-links ${index} ${pkg_name} elif [ -d "$wheel_path" ]; then ls -la $wheel_path export PATH=/root/.local/bin:$PATH python3 -m pip install https://oneflow-static.oss-cn-beijing.aliyuncs.com/pipindex/pipindex-0.1.3-py2.py3-none-any.whl --user pipindex build $wheel_path python3 -m pip install -U --user --extra-index-url file://${wheel_path}/simple ${pkg_name} elif [ -e "$wheel_path" ]; then python3 -m pip install --user "$wheel_path" elif [ -d "$src_dir" ]; then python3 -m pip install -e "$src_dir" --user else echo "wheel not found: $wheel_path, src dir not found: $src_dir, continue anyway..." 
fi ================================================ FILE: cmake/caches/ci/canary/cuda.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(BUILD_RDMA YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "61-real;70-real;75-real;80-real;86-real" CACHE STRING "") set(CUDNN_STATIC OFF CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_CPP_API OFF CACHE BOOL "") set(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") ================================================ FILE: cmake/caches/ci/cpu-asan-ubsan.cmake ================================================ set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(WITH_ONEDNN YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(BUILD_CPP_API ON CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_FOR_CI ON CACHE BOOL "") set(BUILD_SHARED_LIBS ON CACHE BOOL "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(ENABLE_ASAN ON CACHE BOOL "") set(ENABLE_UBSAN OFF CACHE BOOL "") ================================================ FILE: cmake/caches/ci/cpu-tsan.cmake ================================================ set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(WITH_ONEDNN YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(BUILD_CPP_API ON CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_FOR_CI ON CACHE BOOL "") set(BUILD_SHARED_LIBS ON CACHE BOOL "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(ENABLE_TSAN ON CACHE BOOL "") ================================================ FILE: cmake/caches/ci/cpu.cmake ================================================ set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_NPU NO CACHE BOOL "") set(BUILD_MLU NO CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(WITH_ONEDNN YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(BUILD_CPP_API ON CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_FOR_CI ON CACHE BOOL "") set(BUILD_SHARED_LIBS ON CACHE BOOL "") 
set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") ================================================ FILE: cmake/caches/ci/cuda-xla.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(BUILD_RDMA YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "61;75" CACHE STRING "") set(CUDNN_STATIC OFF CACHE BOOL "") set(RPC_BACKEND "LOCAL" CACHE STRING "") set(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING "") ================================================ FILE: cmake/caches/ci/cuda.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(BUILD_RDMA YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "75;86" CACHE STRING "") set(CUDNN_STATIC ON CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_CPP_API ON CACHE BOOL "") set(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING "") set(BUILD_FOR_CI ON CACHE BOOL "") set(CMAKE_CXX_FLAGS "-Wno-unused-but-set-parameter -Wno-unused-variable -Wno-class-memaccess -Wno-cast-function-type -Wno-comment -Wno-reorder" CACHE STRING "") ================================================ FILE: cmake/caches/ci/gh-hosted/cpu-clang.cmake ================================================ set(CMAKE_C_COMPILER "clang" CACHE STRING "") set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") ================================================ FILE: cmake/caches/ci/gh-hosted/cpu-gcc.cmake ================================================ set(BUILD_CUDA NO CACHE BOOL "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") ================================================ FILE: cmake/caches/ci/llvm/cuda-75-clang.cmake ================================================ set(CMAKE_C_COMPILER "clang" CACHE STRING "") set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "") set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT 
"-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(WITH_MLIR YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(CMAKE_CUDA_ARCHITECTURES "75;52-real" CACHE STRING "") set(BUILD_TESTING YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(RPC_BACKEND "LOCAL" CACHE STRING "") set(BUILD_HWLOC NO CACHE BOOL "") ================================================ FILE: cmake/caches/ci/profiler/cuda.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(BUILD_RDMA YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "61-real;70-real;75-real;80-real;86-real" CACHE STRING "") set(CUDNN_STATIC OFF CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_PROFILER ON CACHE BOOL "") set(BUILD_CPP_API OFF CACHE BOOL "") set(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") ================================================ FILE: cmake/caches/ci/release/cpu.cmake ================================================ set(BUILD_CUDA OFF CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CUDNN_STATIC OFF CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_CPP_API OFF CACHE BOOL "") set(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_FLAGS "-Wno-unused-but-set-parameter -Wno-unused-variable -Wno-class-memaccess -Wno-cast-function-type -Wno-comment -Wno-reorder" CACHE STRING "") ================================================ FILE: cmake/caches/ci/release/cu118.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(BUILD_RDMA YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "70-real;80-real;86-real;89-real;90-real" CACHE STRING "") set(CUDNN_STATIC OFF CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_CPP_API OFF CACHE 
BOOL "") set(CUDA_NVCC_THREADS_NUMBER 2 CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_FLAGS "-Wno-unused-but-set-parameter -Wno-unused-variable -Wno-class-memaccess -Wno-cast-function-type -Wno-comment -Wno-reorder" CACHE STRING "") ================================================ FILE: cmake/caches/ci/release/cuda.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(BUILD_RDMA YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CUDNN_STATIC OFF CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") set(BUILD_CPP_API OFF CACHE BOOL "") set(CUDA_NVCC_THREADS_NUMBER 2 CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_FLAGS "-Wno-unused-but-set-parameter -Wno-unused-variable -Wno-class-memaccess -Wno-cast-function-type -Wno-comment -Wno-reorder" CACHE STRING "") ================================================ FILE: cmake/caches/ci/serving/cuda-75.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(BUILD_CPP_API YES CACHE BOOL "") set(WITH_MLIR YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING "") set(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "75" CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING "") ================================================ FILE: cmake/caches/ci/serving/openvino.cmake ================================================ set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_CPP_API ON CACHE BOOL "") set(BUILD_GIT_VERSION NO CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(BUILD_HWLOC NO CACHE BOOL "") set(BUILD_TESTING ON CACHE BOOL "") set(WITH_MLIR YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE Release CACHE STRING "") # set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(BUILD_HWLOC OFF CACHE BOOL "") set(WITH_ONEDNN OFF CACHE BOOL "") 
set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE STRING "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") ================================================ FILE: cmake/caches/cn/cpu.cmake ================================================ set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_NPU NO CACHE BOOL "") set(BUILD_MLU NO CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") ================================================ FILE: cmake/caches/cn/cuda.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") ================================================ FILE: cmake/caches/cn/fast/cpu-clang.cmake ================================================ set(CMAKE_C_COMPILER "clang" CACHE STRING "") set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(BUILD_HWLOC OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/cpu.cmake ================================================ set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(BUILD_HWLOC OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/cuda-61-clang.cmake ================================================ set(CMAKE_C_COMPILER "clang" CACHE STRING "") set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(CMAKE_CUDA_ARCHITECTURES "61" CACHE STRING "") set(BUILD_TESTING YES CACHE BOOL "") 
set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(BUILD_HWLOC OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/cuda-61.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "61" CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(BUILD_HWLOC OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/cuda-75-clang.cmake ================================================ set(CMAKE_C_COMPILER "clang" CACHE STRING "") set(WITH_MLIR YES CACHE BOOL "") set(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL "") set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(CMAKE_CUDA_ARCHITECTURES "75" CACHE STRING "") set(BUILD_TESTING YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(BUILD_HWLOC OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/cuda-75.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "75" CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(BUILD_HWLOC OFF CACHE BOOL "") # uncomment these when 
necessary, otherwise it is for the demonstration purpose # set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING "") # set(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING "") # set(CMAKE_CUDA_HOST_COMPILER clang++ CACHE STRING "") # set(CMAKE_C_COMPILER "clang" CACHE STRING "") # set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "") # set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") # set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") # set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") ================================================ FILE: cmake/caches/cn/fast/cuda-86.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "86" CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(BUILD_HWLOC OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/mlir-cpu.cmake ================================================ set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_GIT_VERSION NO CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(BUILD_HWLOC NO CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(WITH_MLIR YES CACHE BOOL "") set(WITH_MLIR_CUDA_CODEGEN NO CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_HWLOC OFF CACHE BOOL "") set(WITH_ONEDNN OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/mlir-cuda-61.cmake ================================================ set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION NO CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(BUILD_HWLOC NO CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(WITH_MLIR YES CACHE BOOL "") set(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "61-real" CACHE STRING "") set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING "") set(CUDNN_ROOT_DIR 
/usr/local/cudnn CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(CMAKE_C_COMPILER "clang" CACHE STRING "") set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_HWLOC OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/mlir-cuda-75.cmake ================================================ set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION NO CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(BUILD_HWLOC NO CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(WITH_MLIR YES CACHE BOOL "") set(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "75" CACHE STRING "") set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING "") set(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_HWLOC OFF CACHE BOOL "") set(WITH_ONEDNN OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/mlir-cuda-80.cmake ================================================ set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION NO CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(BUILD_HWLOC NO CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(WITH_MLIR YES CACHE BOOL "") set(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "80" CACHE STRING "") set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING "") set(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CPU_THREADING_RUNTIME SEQ CACHE STRING "when using lld with TBB enabled, there will be linkage error") set(BUILD_HWLOC OFF CACHE BOOL "") set(WITH_ONEDNN 
OFF CACHE BOOL "") ================================================ FILE: cmake/caches/cn/fast/mlir-cuda-86.cmake ================================================ set(BUILD_SHARED_LIBS YES CACHE BOOL "") # uncomment only if you know what you are doing # set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION NO CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(BUILD_HWLOC NO CACHE BOOL "") set(BUILD_TESTING OFF CACHE BOOL "") set(WITH_MLIR YES CACHE BOOL "") set(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "86" CACHE STRING "") set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING "") set(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CPU_THREADING_RUNTIME SEQ CACHE STRING "when using lld with TBB enabled, there will be linkage error") set(BUILD_HWLOC OFF CACHE BOOL "") set(WITH_ONEDNN OFF CACHE BOOL "") ================================================ FILE: cmake/caches/international/cpu.cmake ================================================ set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") ================================================ FILE: cmake/caches/international/cuda.cmake ================================================ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") ================================================ FILE: cmake/cuda.cmake ================================================ if(BUILD_CUDA) if(DEFINED CUDA_TOOLKIT_ROOT_DIR) message(WARNING "CUDA_TOOLKIT_ROOT_DIR is deprecated, use CUDAToolkit_ROOT instead") set(CUDAToolkit_ROOT ${CUDA_TOOLKIT_ROOT_DIR}) endif(DEFINED CUDA_TOOLKIT_ROOT_DIR) find_package(CUDAToolkit REQUIRED) message(STATUS "CUDAToolkit_FOUND: ${CUDAToolkit_FOUND}") message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}") message(STATUS "CUDAToolkit_VERSION_MAJOR: ${CUDAToolkit_VERSION_MAJOR}") message(STATUS "CUDAToolkit_VERSION_MINOR: ${CUDAToolkit_VERSION_MINOR}") message(STATUS "CUDAToolkit_VERSION_PATCH: ${CUDAToolkit_VERSION_PATCH}") message(STATUS "CUDAToolkit_BIN_DIR: ${CUDAToolkit_BIN_DIR}") message(STATUS "CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}") message(STATUS "CUDAToolkit_LIBRARY_DIR: ${CUDAToolkit_LIBRARY_DIR}") message(STATUS "CUDAToolkit_LIBRARY_ROOT: ${CUDAToolkit_LIBRARY_ROOT}") message(STATUS "CUDAToolkit_TARGET_DIR: ${CUDAToolkit_TARGET_DIR}") message(STATUS "CUDAToolkit_NVCC_EXECUTABLE: ${CUDAToolkit_NVCC_EXECUTABLE}") if(CUDA_NVCC_GENCODES) message(FATAL_ERROR "CUDA_NVCC_GENCODES is deprecated, use CMAKE_CUDA_ARCHITECTURES instead") endif() add_definitions(-DWITH_CUDA) # NOTE: For some unknown reason, CUDAToolkit_VERSION may become empty when running cmake again set(CUDA_VERSION ${CUDAToolkit_VERSION} CACHE STRING "") if(NOT 
CUDA_VERSION) message(FATAL_ERROR "CUDA_VERSION empty") endif() message(STATUS "CUDA_VERSION: ${CUDA_VERSION}") if(CUDA_VERSION VERSION_GREATER_EQUAL "11.0") set(CUDA_STATIC OFF CACHE BOOL "") else() set(CUDA_STATIC ON CACHE BOOL "") endif() if((NOT CUDA_STATIC) OR BUILD_SHARED_LIBS) set(OF_CUDA_LINK_DYNAMIC_LIBRARY ON) else() set(OF_CUDA_LINK_DYNAMIC_LIBRARY OFF) endif() if(OF_CUDA_LINK_DYNAMIC_LIBRARY) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cublas) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::curand) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cusolver) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cufft) if(CUDA_VERSION VERSION_GREATER_EQUAL "10.1") list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cublasLt) endif() if(CUDA_VERSION VERSION_GREATER_EQUAL "10.2") list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nvjpeg) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nppc) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nppig) endif() else() list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cublas_static) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::curand_static) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cufft_static) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cusolver_static) if(CUDA_VERSION VERSION_GREATER_EQUAL "10.1") list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cublasLt_static) endif() if(CUDA_VERSION VERSION_GREATER_EQUAL "10.2") list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nvjpeg_static) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nppig_static) # Must put nppc_static after nppig_static in CUDA 10.2 list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nppc_static) list(APPEND VENDOR_CUDA_LIBRARIES CUDA::culibos) endif() endif() message(STATUS "VENDOR_CUDA_LIBRARIES: ${VENDOR_CUDA_LIBRARIES}") # add a cache entry if want to use a ccache/sccache wrapped nvcc set(CMAKE_CUDA_COMPILER ${CUDAToolkit_NVCC_EXECUTABLE} CACHE STRING "") message(STATUS "CMAKE_CUDA_COMPILER: ${CMAKE_CUDA_COMPILER}") set(CMAKE_CUDA_STANDARD 17) find_package(CUDNN REQUIRED) # NOTE: if you want to use source PTX with a version different from produced PTX/binary, you should add flags if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if(CUDA_VERSION VERSION_GREATER_EQUAL "10.0") # T4, Quadro RTX xxxx, Txxxx, Geforce RTX 20xx, TITAN RTX list(APPEND CMAKE_CUDA_ARCHITECTURES 75-real) endif() if(CUDA_VERSION VERSION_GREATER_EQUAL "11.0") # A100 list(APPEND CMAKE_CUDA_ARCHITECTURES 80-real) endif() if(CUDA_VERSION VERSION_GREATER_EQUAL "11.1") # GeForce RTX 30xx list(APPEND CMAKE_CUDA_ARCHITECTURES 86-real) endif() if(CUDA_VERSION VERSION_GREATER_EQUAL "11.8") # GeForce RTX 40xx list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) endif() if(CUDA_VERSION VERSION_GREATER_EQUAL "12.0") # H100, H20 list(APPEND CMAKE_CUDA_ARCHITECTURES 90-real) endif() endif() foreach(CUDA_ARCH ${CMAKE_CUDA_ARCHITECTURES}) if(CUDA_ARCH MATCHES "^([0-9]+)\\-real$") list(APPEND CUDA_REAL_ARCHS_LIST ${CMAKE_MATCH_1}) elseif(CUDA_ARCH MATCHES "^([0-9]+)$") list(APPEND CUDA_REAL_ARCHS_LIST ${CMAKE_MATCH_1}) endif() endforeach() enable_language(CUDA) include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}") set(CUDA_SEPARABLE_COMPILATION OFF) if("${CMAKE_CUDA_COMPILER_ID}" STREQUAL "NVIDIA") if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.2") set(CUDA_NVCC_THREADS_NUMBER "4" CACHE STRING "") list(APPEND CUDA_NVCC_FLAGS -t ${CUDA_NVCC_THREADS_NUMBER}) endif() list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=-fno-strict-aliasing") message(STATUS "CUDA_NVCC_FLAGS: " ${CUDA_NVCC_FLAGS}) list(JOIN CUDA_NVCC_FLAGS " " CMAKE_CUDA_FLAGS) endif() endif() 
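When CMAKE_CUDA_ARCHITECTURES is not set by the user, cmake/cuda.cmake above derives a default list of "-real" architectures from the detected CUDA toolkit version. The sketch below restates that mapping in Python purely as a quick reference; the function name default_cuda_architectures is hypothetical, and the version thresholds mirror the branches in the CMake code.

def default_cuda_architectures(cuda_version: tuple) -> list:
    """Reference sketch of the version checks in cmake/cuda.cmake (not used by the build)."""
    archs = []
    if cuda_version >= (10, 0):
        archs.append("75-real")  # T4, Quadro RTX, GeForce RTX 20xx, TITAN RTX
    if cuda_version >= (11, 0):
        archs.append("80-real")  # A100
    if cuda_version >= (11, 1):
        archs.append("86-real")  # GeForce RTX 30xx
    if cuda_version >= (11, 8):
        archs.append("89-real")  # GeForce RTX 40xx
    if cuda_version >= (12, 0):
        archs.append("90-real")  # H100, H20
    return archs

# For CUDA 11.8 this yields ['75-real', '80-real', '86-real', '89-real'].
print(default_cuda_architectures((11, 8)))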
================================================ FILE: cmake/functional.cmake ================================================ function(GENERATE_FUNCTIONAL_API_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR) set(YAML_FILE ${PROJECT_SOURCE_DIR}/oneflow/core/functional/functional_api.yaml) set(GENERATED_API_DIR oneflow/core/functional) list(APPEND SRCS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.cpp) list(APPEND HDRS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.h) if(BUILD_PYTHON) set(GENERATED_PYBIND_DIR oneflow/api/python/functional) list(APPEND PYBIND_SRCS ${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/functional_api.yaml.pybind.cpp) endif(BUILD_PYTHON) if(BUILD_PYTHON) add_custom_command( OUTPUT "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.cpp" "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.h" "${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/functional_api.yaml.pybind.cpp" COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_API_DIR} COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_PYBIND_DIR} COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ARGS ${PROJECT_SOURCE_DIR}/tools/functional/generate_functional_api.py --project_source_dir ${PROJECT_SOURCE_DIR} --export_pybind DEPENDS ${CODEGEN_PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/functional/generate_functional_api.py ${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE} VERBATIM) else() # build_python add_custom_command( OUTPUT "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.cpp" "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.h" COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_API_DIR} COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ARGS ${PROJECT_SOURCE_DIR}/tools/functional/generate_functional_api.py --project_source_dir ${PROJECT_SOURCE_DIR} DEPENDS ${CODEGEN_PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/functional/generate_functional_api.py ${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE} VERBATIM) endif(BUILD_PYTHON) set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) set(${SRCS} ${${SRCS}} PARENT_SCOPE) set(${HDRS} ${${HDRS}} PARENT_SCOPE) if(BUILD_PYTHON) set_source_files_properties(${${PYBIND_SRCS}} PROPERTIES GENERATED TRUE) set(${PYBIND_SRCS} ${${PYBIND_SRCS}} PARENT_SCOPE) endif(BUILD_PYTHON) endfunction() function(GENERATE_FUNCTIONAL_TENSOR_API_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR) set(YAML_FILE ${PROJECT_SOURCE_DIR}/oneflow/api/python/functional/tensor_api.yaml) set(GENERATED_API_DIR oneflow/api/python/functional) set(GENERATED_PYBIND_DIR oneflow/api/python/functional) list(APPEND SRCS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/tensor_api.yaml.cpp) list(APPEND HDRS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/tensor_api.yaml.h) list(APPEND PYBIND_SRCS ${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/tensor_api.yaml.pybind.cpp) add_custom_command( OUTPUT "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/tensor_api.yaml.cpp" "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/tensor_api.yaml.h" "${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/tensor_api.yaml.pybind.cpp" COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_API_DIR} COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_PYBIND_DIR} COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ARGS ${PROJECT_SOURCE_DIR}/tools/functional/generate_tensor_api.py --project_source_dir ${PROJECT_SOURCE_DIR} DEPENDS ${CODEGEN_PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/functional/generate_tensor_api.py 
${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE} VERBATIM) set_source_files_properties(${${SRCS}} ${${HDRS}} ${${PYBIND_SRCS}} PROPERTIES GENERATED TRUE) set(${SRCS} ${${SRCS}} PARENT_SCOPE) set(${HDRS} ${${HDRS}} PARENT_SCOPE) set(${PYBIND_SRCS} ${${PYBIND_SRCS}} PARENT_SCOPE) endfunction() function(GENERATE_FUNCTIONAL_DISPATCH_STATEFUL_OPS_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR) set(YAML_FILE ${PROJECT_SOURCE_DIR}/oneflow/api/python/functional/dispatch_stateful_ops.yaml) set(GENERATED_API_DIR oneflow/api/python/functional) set(GENERATED_PYBIND_DIR oneflow/api/python/functional) list(APPEND SRCS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.cpp) list(APPEND HDRS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.h) list(APPEND PYBIND_SRCS ${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/dispatch_stateful_ops.yaml.pybind.cpp) add_custom_command( OUTPUT "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.cpp" "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.h" "${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/dispatch_stateful_ops.yaml.pybind.cpp" COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_API_DIR} COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_PYBIND_DIR} COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ARGS ${PROJECT_SOURCE_DIR}/tools/functional/generate_dispatch_stateful_ops.py --project_source_dir ${PROJECT_SOURCE_DIR} DEPENDS ${CODEGEN_PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/functional/generate_dispatch_stateful_ops.py ${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE} VERBATIM) set_source_files_properties(${${SRCS}} ${${HDRS}} ${${PYBIND_SRCS}} PROPERTIES GENERATED TRUE) set(${SRCS} ${${SRCS}} PARENT_SCOPE) set(${HDRS} ${${HDRS}} PARENT_SCOPE) set(${PYBIND_SRCS} ${${PYBIND_SRCS}} PARENT_SCOPE) endfunction() ================================================ FILE: cmake/git_version.cmake ================================================ cmake_minimum_required(VERSION 3.5) execute_process( COMMAND git describe --tags --always --dirty=-snapshot WORKING_DIRECTORY ${OF_GIT_VERSION_ROOT} OUTPUT_VARIABLE GIT_REV ERROR_QUIET) if(("${GIT_REV}" STREQUAL "") OR (NOT BUILD_GIT_VERSION)) set(GIT_REV "N/A") else() string(STRIP "${GIT_REV}" GIT_REV) endif() set(VERSION_FILE_CONTENT "namespace oneflow {\n\ \n\ const char* GetOneFlowGitVersion() {\n\ return \"${GIT_REV}\";\n\ }\n\ \n\ }\n") if(EXISTS ${OF_GIT_VERSION_FILE}) file(READ ${OF_GIT_VERSION_FILE} VERSION_FILE_CONTENT_) else() set(VERSION_FILE_CONTENT_ "") endif() if(NOT "${VERSION_FILE_CONTENT}" STREQUAL "${VERSION_FILE_CONTENT_}") file(WRITE ${OF_GIT_VERSION_FILE} "${VERSION_FILE_CONTENT}") endif() ================================================ FILE: cmake/oneflow-config.cmake ================================================ if(DEFINED ENV{ONEFLOW_INSTALL_PREFIX}) set(ONEFLOW_INSTALL_PREFIX $ENV{ONEFLOW_INSTALL_PREFIX}) else() get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) get_filename_component(ONEFLOW_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../" ABSOLUTE) endif() set(ONEFLOW_INCLUDE_DIRS ${ONEFLOW_INSTALL_PREFIX}/include) find_library(ONEFLOW_LIBRARY NAMES oneflow_cpp PATHS ${ONEFLOW_INSTALL_PREFIX}/lib REQUIRED) if(NOT TARGET OneFlow::liboneflow) add_library(OneFlow::liboneflow INTERFACE IMPORTED) set_property(TARGET OneFlow::liboneflow PROPERTY INTERFACE_LINK_LIBRARIES ${ONEFLOW_LIBRARY}) set_property(TARGET OneFlow::liboneflow PROPERTY 
INTERFACE_INCLUDE_DIRECTORIES ${ONEFLOW_INCLUDE_DIRS}) endif() ================================================ FILE: cmake/oneflow.cmake ================================================ include(python) function(oneflow_add_executable) add_executable(${ARGV}) set_compile_options_to_oneflow_target(${ARGV0}) endfunction() function(oneflow_add_library) add_library(${ARGV}) set_compile_options_to_oneflow_target(${ARGV0}) endfunction() # source_group if(WIN32) set(oneflow_platform "windows") list(APPEND oneflow_platform_excludes "linux") else() set(oneflow_platform "linux") list(APPEND oneflow_platform_excludes "windows") endif() file(GLOB_RECURSE oneflow_all_hdr_to_be_expanded "${PROJECT_SOURCE_DIR}/oneflow/core/*.e.h" "${PROJECT_SOURCE_DIR}/oneflow/python/*.e.h") foreach(oneflow_hdr_to_be_expanded ${oneflow_all_hdr_to_be_expanded}) file(RELATIVE_PATH of_ehdr_rel_path ${PROJECT_SOURCE_DIR} ${oneflow_hdr_to_be_expanded}) set(of_e_h_expanded "${PROJECT_BINARY_DIR}/${of_ehdr_rel_path}.expanded.h") if(WIN32) error("Expanding macro in WIN32 is not supported yet") else() add_custom_command( OUTPUT ${of_e_h_expanded} COMMAND ${CMAKE_C_COMPILER} ARGS -E -I"${PROJECT_SOURCE_DIR}" -I"${PROJECT_BINARY_DIR}" -o "${of_e_h_expanded}" "${oneflow_hdr_to_be_expanded}" DEPENDS ${oneflow_hdr_to_be_expanded} COMMENT "Expanding macros in ${oneflow_hdr_to_be_expanded}") list(APPEND oneflow_all_hdr_expanded "${of_e_h_expanded}") endif() set_source_files_properties(${oneflow_all_hdr_expanded} PROPERTIES GENERATED TRUE) endforeach() file( GLOB_RECURSE oneflow_all_src "${PROJECT_SOURCE_DIR}/oneflow/core/*.*" "${PROJECT_SOURCE_DIR}/oneflow/user/*.*" "${PROJECT_SOURCE_DIR}/oneflow/api/*.*" "${PROJECT_SOURCE_DIR}/oneflow/maybe/*.*" "${PROJECT_SOURCE_DIR}/oneflow/extension/*.*") foreach(oneflow_single_file ${oneflow_all_src}) # Verify whether this file is for other platforms set(exclude_this OFF) set(group_this OFF) foreach(oneflow_platform_exclude ${oneflow_platform_excludes}) string(FIND ${oneflow_single_file} ${oneflow_platform_exclude} platform_found) if(NOT ${platform_found} EQUAL -1) # the ${oneflow_single_file} is for other platforms set(exclude_this ON) endif() endforeach() # If this file is for other platforms, just exclude it from current project if(exclude_this) continue() endif() if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|maybe)/.*\\.(h|hpp)$") if((NOT RPC_BACKEND MATCHES "GRPC") AND "${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/core/control/.*") # skip if GRPC not enabled elseif(APPLE AND "${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/core/comm_network/(epoll|ibverbs)/.*") # skip if macOS else() list(APPEND of_all_obj_cc ${oneflow_single_file}) set(group_this ON) endif() endif() if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user)/.*\\.(cuh|cu)$") if(BUILD_CUDA) list(APPEND of_all_obj_cc ${oneflow_single_file}) endif() set(group_this ON) endif() if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user)/.*\\.proto$") list(APPEND of_all_proto ${oneflow_single_file}) #list(APPEND of_all_obj_cc ${oneflow_single_file}) # include the proto file in the project set(group_this ON) endif() if(BUILD_PYTHON) if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/api/python/.*\\.(h|cpp)$") list(APPEND of_pybind_obj_cc ${oneflow_single_file}) set(group_this ON) endif() if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/extension/.*\\.(c|h|cpp)$") list(APPEND of_pyext_obj_cc 
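# Editor's note (illustrative sketch, not part of the original file): a minimal downstream consumer
# of the oneflow-config.cmake package file shown above; the install path and the target name
# "my_app" are assumptions, while the imported target OneFlow::liboneflow comes from the config file
# (alternatively the ONEFLOW_INSTALL_PREFIX environment variable is honored by it):
#   set(oneflow_DIR "/path/to/liboneflow_cpp/share")  # directory containing oneflow-config.cmake
#   find_package(oneflow REQUIRED CONFIG)
#   add_executable(my_app main.cpp)
#   target_link_libraries(my_app PRIVATE OneFlow::liboneflow)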
${oneflow_single_file}) set(group_this ON) endif() endif(BUILD_PYTHON) if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|maybe)/.*\\.cpp$") if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|maybe|thread)/.*_test\\.cpp$") # test file list(APPEND of_all_test_cc ${oneflow_single_file}) elseif(APPLE AND "${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/core/comm_network/(epoll|ibverbs)/.*") # skip if macOS elseif(APPLE AND "${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/core/transport/.*") # skip if macOS elseif((NOT RPC_BACKEND MATCHES "GRPC") AND "${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/core/control.*") # skip if GRPC not enabled else() list(APPEND of_all_obj_cc ${oneflow_single_file}) endif() set(group_this ON) endif() if(group_this) file(RELATIVE_PATH oneflow_relative_file ${PROJECT_SOURCE_DIR}/oneflow/core/ ${oneflow_single_file}) get_filename_component(oneflow_relative_path ${oneflow_relative_file} PATH) string(REPLACE "/" "\\" group_name ${oneflow_relative_path}) source_group("${group_name}" FILES ${oneflow_single_file}) endif() endforeach() # clang format add_custom_target( of_format COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_license_format.py -i ${CMAKE_CURRENT_SOURCE_DIR}/oneflow --fix COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_license_format.py -i ${ONEFLOW_PYTHON_DIR} --fix --exclude="oneflow/include" --exclude="oneflow/core" COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_clang_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/oneflow --fix --quiet COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_py_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/python --fix COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_clang_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/tools/oneflow-tblgen --fix --quiet) # clang tidy set(RUN_CLANG_TIDY_ARGS --build_dir ${CMAKE_BINARY_DIR}) if(MAYBE_NEED_ERROR_MSG_CHECK) list(APPEND RUN_CLANG_TIDY_ARGS --check-error-msg) endif() message(STATUS "RUN_CLANG_TIDY_ARGS: ${RUN_CLANG_TIDY_ARGS}") add_custom_target( of_tidy COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/ci/check/run_clang_tidy.py ${RUN_CLANG_TIDY_ARGS} DEPENDS of_git_version oneflow_deps of_functional_obj of_functional_tensor_obj) # generate version set(OF_GIT_VERSION_DIR ${CMAKE_CURRENT_BINARY_DIR}/of_git_version) set(OF_GIT_VERSION_FILE ${OF_GIT_VERSION_DIR}/version.cpp) set(OF_GIT_VERSION_DUMMY_FILE ${OF_GIT_VERSION_DIR}/_version.cpp) add_custom_target(of_git_version_create_dir COMMAND ${CMAKE_COMMAND} -E make_directory ${OF_GIT_VERSION_DIR}) add_custom_command( OUTPUT ${OF_GIT_VERSION_DUMMY_FILE} COMMAND ${CMAKE_COMMAND} -DOF_GIT_VERSION_FILE=${OF_GIT_VERSION_FILE} -DOF_GIT_VERSION_ROOT=${PROJECT_SOURCE_DIR} -DBUILD_GIT_VERSION=${BUILD_GIT_VERSION} -P ${CMAKE_CURRENT_SOURCE_DIR}/cmake/git_version.cmake DEPENDS of_git_version_create_dir) add_custom_target(of_git_version DEPENDS ${OF_GIT_VERSION_DUMMY_FILE}) set_source_files_properties(${OF_GIT_VERSION_FILE} PROPERTIES GENERATED TRUE) list(APPEND of_all_obj_cc ${OF_GIT_VERSION_FILE}) set(of_proto_python_dir "${PROJECT_BINARY_DIR}/of_proto_python") # proto obj lib add_custom_target(make_pyproto_dir ALL COMMAND ${CMAKE_COMMAND} -E make_directory ${of_proto_python_dir}) foreach(proto_name ${of_all_proto}) file(RELATIVE_PATH proto_rel_name ${PROJECT_SOURCE_DIR} ${proto_name}) list(APPEND of_all_rel_protos 
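# (Editor's note, illustrative: of_format and of_tidy defined above are convenience targets meant to
#  be invoked explicitly, e.g. `cmake --build . --target of_format` to run the license, clang-format
#  and python formatters in place; they are only attached to the main oneflow target when
#  USE_CLANG_FORMAT or USE_CLANG_TIDY is enabled further below.)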
${proto_rel_name}) endforeach() relative_protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROJECT_SOURCE_DIR} ${of_all_rel_protos}) oneflow_add_library(of_protoobj SHARED ${PROTO_SRCS} ${PROTO_HDRS}) add_dependencies(of_protoobj make_pyproto_dir protobuf) target_link_libraries(of_protoobj protobuf_imported) include(functional) generate_functional_api_and_pybind11_cpp(FUNCTIONAL_GENERATED_SRCS FUNCTIONAL_GENERATED_HRCS FUNCTIONAL_PYBIND11_SRCS ${PROJECT_SOURCE_DIR}) oneflow_add_library(of_functional_obj OBJECT ${FUNCTIONAL_GENERATED_SRCS} ${FUNCTIONAL_GENERATED_HRCS}) target_link_libraries(of_functional_obj LLVMSupportWithHeader glog::glog fmt) add_dependencies(of_functional_obj prepare_oneflow_third_party) if(BUILD_PYTHON) generate_functional_tensor_api_and_pybind11_cpp( FUNCTIONAL_TENSOR_GENERATED_SRCS FUNCTIONAL_TENSOR_GENERATED_HRCS FUNCTIONAL_TENSOR_PYBIND11_SRCS ${PROJECT_SOURCE_DIR}) generate_functional_dispatch_stateful_ops_and_pybind11_cpp( FUNCTIONAL_OPS_GENERATED_SRCS FUNCTIONAL_OPS_GENERATED_HRCS FUNCTIONAL_OPS_PYBIND11_SRCS ${PROJECT_SOURCE_DIR}) oneflow_add_library( of_functional_tensor_obj OBJECT ${FUNCTIONAL_TENSOR_GENERATED_SRCS} ${FUNCTIONAL_TENSOR_GENERATED_HRCS} ${FUNCTIONAL_OPS_GENERATED_SRCS} ${FUNCTIONAL_OPS_GENERATED_HRCS}) target_link_libraries(of_functional_tensor_obj LLVMSupportWithHeader glog::glog fmt) add_dependencies(of_functional_tensor_obj prepare_oneflow_third_party) target_include_directories(of_functional_tensor_obj PRIVATE ${Python_INCLUDE_DIRS} ${Python_NumPy_INCLUDE_DIRS}) set(PYBIND11_SRCS ${FUNCTIONAL_PYBIND11_SRCS} ${FUNCTIONAL_TENSOR_PYBIND11_SRCS} ${FUNCTIONAL_OPS_PYBIND11_SRCS}) endif(BUILD_PYTHON) include_directories(${PROJECT_SOURCE_DIR}) # TO FIND: third_party/eigen3/.. include_directories(${PROJECT_BINARY_DIR}) # cc obj lib oneflow_add_library(oneflow SHARED ${of_all_obj_cc}) add_dependencies(oneflow of_protoobj) add_dependencies(oneflow of_functional_obj) add_dependencies(oneflow of_op_schema) add_dependencies(oneflow of_git_version) if(USE_CLANG_FORMAT) add_dependencies(oneflow of_format) endif() if(USE_CLANG_TIDY) add_dependencies(oneflow of_tidy) endif() target_compile_definitions(oneflow PRIVATE GOOGLE_LOGGING) set(ONEFLOW_TOOLS_DIR "${PROJECT_BINARY_DIR}/tools" CACHE STRING "dir to put binary for debugging and development") set(CACHE_LLVM_MONO_REPO_URL_LIST "https://github.com/llvm/llvm-project/archive/c63522e6ba7782c335043893ae7cbd37eca24fe5.zip" "https://github.com/llvm/llvm-project/archive/a0595f8c99a253c65f30a151337e7aadc19ee3a1.zip" "https://github.com/llvm/llvm-project/archive/7eaa84eac3ba935d13f4267d3d533a6c3e1283ed.zip" "https://github.com/llvm/llvm-project/archive/35e60f5de180aea55ed478298f4b40f04dcc57d1.zip" "https://github.com/llvm/llvm-project/archive/6a9bbd9f20dcd700e28738788bb63a160c6c088c.zip" "https://github.com/llvm/llvm-project/archive/32805e60c9de1f82887cd2af30d247dcabd2e1d3.zip" "https://github.com/llvm/llvm-project/archive/6d6268dcbf0f48e43f6f9fe46b3a28c29ba63c7d.zip" "https://github.com/llvm/llvm-project/archive/5c9a84960de2260f149ee15313998593255a78df.zip" "https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-16.0.0-rc4.zip" "https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-15.0.6.zip" "https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-16.0.0.zip" "https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-16.0.3.zip") set(CACHE_LLVM_MONO_REPO_MD5_LIST "f2f17229cf21049663b8ef4f2b6b8062" "6b7c6506d5922de9632c8ff012b2f945" "e0ea669a9f0872d35bffda5ec6c5ac6f" 
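# (Editor's note, illustrative: the two CACHE_LLVM_MONO_REPO_*_LIST lists enumerate previously
#  pinned LLVM archives; if the cached LLVM_MONO_REPO_URL / LLVM_MONO_REPO_MD5 still point at one of
#  these stale values, the cache entries are unset below so that the new pin declared right after
#  takes effect on the next configuration.)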
"241a333828bba1efa35aff4c4fc2ce87" "075fbfdf06cb3f02373ea44971af7b03" "e412dc61159b5e929b0c94e44b11feb2" "1ccc00accc87a1a5d42a275d6e31cd8c" "b64481eaca658a2ff4e3e193440d0f68" "78172b0f67282e28956cd310612091fd" "0c2a3196e656aaab7ca1c2ef21b6091c" "2702b822b71c196a0cc9c8d821c069d7" "334997b4879aba15d9323a732356cf2a") # clean cache for last LLVM version if("${LLVM_MONO_REPO_URL}" IN_LIST CACHE_LLVM_MONO_REPO_URL_LIST OR "${LLVM_MONO_REPO_MD5}" IN_LIST CACHE_LLVM_MONO_REPO_MD5_LIST) unset(LLVM_MONO_REPO_URL CACHE) unset(LLVM_MONO_REPO_MD5 CACHE) endif() set(LLVM_MONO_REPO_URL "https://github.com/llvm/llvm-project/archive/c2ce2a509f74a85a3c0ef4b9d6d79fbacc7e8bdf.zip" CACHE STRING "") use_mirror(VARIABLE LLVM_MONO_REPO_URL URL ${LLVM_MONO_REPO_URL}) set(LLVM_MONO_REPO_MD5 "25489a23c6fa971fcd0d1167a560bf0a" CACHE STRING "") set(ONEFLOW_BUILD_ROOT_DIR "${PROJECT_BINARY_DIR}") add_subdirectory(${PROJECT_SOURCE_DIR}/oneflow/ir) if(WITH_MLIR) set(ONEFLOW_MLIR_LIBS -Wl,--no-as-needed MLIROneFlowExtension -Wl,--as-needed) endif() if("${LLVM_PROVIDER}" STREQUAL "install") get_property(LLVM_INSTALL_DIR GLOBAL PROPERTY LLVM_INSTALL_DIR) check_variable_defined(LLVM_INSTALL_DIR) find_library(LLVMSupportLib LLVMSupport PATHS ${LLVM_INSTALL_DIR}/lib REQUIRED) add_library(LLVMSupportWithHeader UNKNOWN IMPORTED) set_property(TARGET LLVMSupportWithHeader PROPERTY IMPORTED_LOCATION ${LLVMSupportLib}) else() add_library(LLVMSupportWithHeader INTERFACE IMPORTED) target_link_libraries(LLVMSupportWithHeader INTERFACE LLVMSupport) endif() check_variable_defined(LLVM_INCLUDE_DIRS) set_property(TARGET LLVMSupportWithHeader PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${LLVM_INCLUDE_DIRS}) list(APPEND oneflow_third_party_libs LLVMSupportWithHeader) # for stack backtrace find_package(BFD) if(BFD_FOUND) add_definitions(-DBACKWARD_HAS_BFD=1) list(APPEND oneflow_third_party_libs bfd::bfd) endif() find_package(Unwind) if(Unwind_FOUND) add_definitions(-DBACKWARD_HAS_LIBUNWIND=1) list(APPEND oneflow_third_party_libs unwind::unwind) endif() add_definitions(-DONEFLOW_SOURCE_DIR="${PROJECT_SOURCE_DIR}") add_definitions(-DONEFLOW_BINARY_DIR="${PROJECT_BINARY_DIR}") include(op_schema) get_property(EXTERNAL_TARGETS GLOBAL PROPERTY EXTERNAL_TARGETS) if(APPLE) set(of_libs ${ALL_ARCHIVE_BEGIN} oneflow of_op_schema ${ALL_ARCHIVE_END}) target_link_libraries(oneflow of_protoobj of_functional_obj ${oneflow_third_party_libs}) elseif(UNIX) set(of_libs ${ALL_ARCHIVE_BEGIN} oneflow of_op_schema ${ALL_ARCHIVE_END} -ldl -lrt) target_link_libraries( oneflow of_protoobj of_functional_obj ${oneflow_third_party_libs} ${EXTERNAL_TARGETS} -Wl,--no-whole-archive -Wl,--as-needed -ldl -lrt) if(BUILD_CUDA) target_link_libraries(oneflow CUDA::cudart_static) endif() if(WITH_OMP) if(OpenMP_CXX_FOUND) target_link_libraries(oneflow OpenMP::OpenMP_CXX) endif() endif() elseif(WIN32) set(of_libs oneflow of_protoobj of_functional_obj of_op_schema) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow") endif() if(BUILD_CUDA) string(JOIN "," CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST}) set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/core/hardware/cuda_device_descriptor.cpp PROPERTIES COMPILE_FLAGS "-DCUDA_REAL_ARCHS=\"${CUDA_REAL_ARCHS}\"") endif() if(BUILD_NPU) add_definitions(-DWITH_NPU) endif() message(STATUS "BUILD_NPU: ${BUILD_NPU}") if(BUILD_MLU) add_definitions(-DWITH_MLU) endif() message(STATUS "BUILD_MLU: ${BUILD_MLU}") if(BUILD_CUDA AND WITH_CUTLASS) if(CUDA_VERSION VERSION_GREATER_EQUAL "10.1") 
add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) endif() set_property(SOURCE ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_attention_kernels.cu APPEND PROPERTY INCLUDE_DIRECTORIES ${CUTLASS_INSTALL_DIR}/examples/xformers_fmha) set_property(SOURCE ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_glu_kernel.cu APPEND PROPERTY INCLUDE_DIRECTORIES ${CUTLASS_INSTALL_DIR}/examples/45_dual_gemm) if("${CMAKE_CUDA_COMPILER_ID}" STREQUAL "NVIDIA") set_property( SOURCE ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_multi_head_attention_inference_kernel.cu APPEND PROPERTY COMPILE_OPTIONS "--use_fast_math") endif() endif() # oneflow api common if(BUILD_PYTHON OR BUILD_CPP_API) file(GLOB_RECURSE of_api_common_files ${PROJECT_SOURCE_DIR}/oneflow/api/common/*.h ${PROJECT_SOURCE_DIR}/oneflow/api/common/*.cpp) oneflow_add_library(of_api_common OBJECT ${of_api_common_files}) target_link_libraries(of_api_common oneflow) if(WITH_MLIR) target_link_libraries(of_api_common ${ALL_ARCHIVE_BEGIN} ${ONEFLOW_MLIR_LIBS} ${ALL_ARCHIVE_END}) endif() endif() if(BUILD_PYTHON) # py ext lib # This library should be static to make sure all python symbols are included in the final ext shared lib, # so that it is safe to do wheel audits of multiple pythons version in parallel. oneflow_add_library(of_pyext_obj STATIC ${of_pyext_obj_cc}) target_include_directories(of_pyext_obj PRIVATE ${Python_INCLUDE_DIRS} ${Python_NumPy_INCLUDE_DIRS}) target_link_libraries(of_pyext_obj oneflow pybind11::headers) if(BUILD_SHARED_LIBS AND APPLE) target_link_libraries(of_pyext_obj ${Python3_LIBRARIES}) endif() add_dependencies(of_pyext_obj oneflow) pybind11_add_module(oneflow_internal ${PYBIND11_SRCS} ${of_pybind_obj_cc} ${PYBIND_REGISTRY_CC}) set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH "\$ORIGIN/../nvidia/cublas/lib") set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH "\$ORIGIN/../nvidia/cudnn/lib") set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH "\$ORIGIN/../nvidia/nccl/lib") set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH "\$ORIGIN/../nvidia/cusparse/lib") set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH "\$ORIGIN/../nvidia/cufft/lib") set_compile_options_to_oneflow_target(oneflow_internal) set_property(TARGET oneflow_internal PROPERTY CXX_VISIBILITY_PRESET "default") add_dependencies(oneflow_internal of_functional_obj of_functional_tensor_obj of_op_schema) set_target_properties(oneflow_internal PROPERTIES PREFIX "_") set_target_properties(oneflow_internal PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${ONEFLOW_PYTHON_DIR}/oneflow") target_link_libraries( oneflow_internal PRIVATE ${of_libs} of_functional_tensor_obj of_api_common ${oneflow_third_party_libs} of_pyext_obj glog::glog) target_include_directories(oneflow_internal PRIVATE ${Python_INCLUDE_DIRS} ${Python_NumPy_INCLUDE_DIRS}) if(WITH_MLIR) add_dependencies(check-oneflow oneflow_internal) endif(WITH_MLIR) set(gen_pip_args "") if(BUILD_CUDA) list(APPEND gen_pip_args --cuda=${CUDA_VERSION}) endif() add_custom_target( of_pyscript_copy ALL COMMAND ${CMAKE_COMMAND} -E touch "${of_proto_python_dir}/oneflow/core/__init__.py" COMMAND ${CMAKE_COMMAND} -E create_symlink "${of_proto_python_dir}/oneflow/core" "${ONEFLOW_PYTHON_DIR}/oneflow/core" COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/generate_pip_version.py ${gen_pip_args} --src=${PROJECT_SOURCE_DIR} --cmake_project_binary_dir=${PROJECT_BINARY_DIR} --out=${ONEFLOW_PYTHON_DIR}/oneflow/version.py) # source this file to add oneflow in PYTHONPATH 
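# (Editor's note, illustrative usage assuming an in-tree build directory named "build": running
#  `source build/source.sh` prepends ${ONEFLOW_PYTHON_DIR} to PYTHONPATH so that a plain
#  `python3 -c "import oneflow"` resolves to the freshly built in-tree package and the
#  oneflow_internal extension staged next to it.)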
file(WRITE "${PROJECT_BINARY_DIR}/source.sh" "export PYTHONPATH=${ONEFLOW_PYTHON_DIR}:$PYTHONPATH") add_dependencies(of_pyscript_copy of_protoobj) endif(BUILD_PYTHON) if(BUILD_CPP_API) file(GLOB_RECURSE of_cpp_api_files ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.cpp ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.h) list(FILTER of_cpp_api_files EXCLUDE REGEX "oneflow/api/cpp/tests") oneflow_add_library(oneflow_cpp SHARED ${of_cpp_api_files}) set_target_properties(oneflow_cpp PROPERTIES ARCHIVE_OUTPUT_DIRECTORY "${LIBONEFLOW_LIBRARY_DIR}" LIBRARY_OUTPUT_DIRECTORY "${LIBONEFLOW_LIBRARY_DIR}") target_link_libraries(oneflow_cpp PRIVATE ${of_libs} of_api_common ${oneflow_third_party_libs}) endif() file(RELATIVE_PATH PROJECT_BINARY_DIR_RELATIVE ${PROJECT_SOURCE_DIR} ${PROJECT_BINARY_DIR}) function(oneflow_add_test target_name) cmake_parse_arguments(arg "" "TEST_NAME;WORKING_DIRECTORY" "SRCS" ${ARGN}) oneflow_add_executable(${target_name} ${arg_SRCS}) if(BUILD_CUDA) target_link_libraries(${target_name} CUDA::cudart_static) endif() set_target_properties(${target_name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/bin") add_test(NAME ${arg_TEST_NAME} COMMAND ${target_name} WORKING_DIRECTORY ${arg_WORKING_DIRECTORY}) set_tests_properties( ${arg_TEST_NAME} PROPERTIES ENVIRONMENT "HTTP_PROXY='';HTTPS_PROXY='';http_proxy='';https_proxy='';") endfunction() # build test if(BUILD_TESTING) if(of_all_test_cc) oneflow_add_test(oneflow_testexe SRCS ${of_all_test_cc} TEST_NAME oneflow_test) target_link_libraries(oneflow_testexe ${of_libs} ${oneflow_third_party_libs} glog::glog ${oneflow_test_libs}) if(WITH_MLIR) target_link_libraries(oneflow_testexe ${ALL_ARCHIVE_BEGIN} MLIROneFlowExtension ${ALL_ARCHIVE_END}) endif() endif() if(BUILD_CPP_API) file(GLOB_RECURSE cpp_api_test_files ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/tests/*.cpp) oneflow_add_test( oneflow_cpp_api_testexe SRCS ${cpp_api_test_files} TEST_NAME oneflow_cpp_api_test WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) find_package(Threads REQUIRED) target_link_libraries(oneflow_cpp_api_testexe oneflow_cpp ${oneflow_third_party_libs} ${oneflow_test_libs} Threads::Threads) endif() endif() # build include add_custom_target(of_include_copy ALL) if(BUILD_PYTHON) add_dependencies(of_include_copy oneflow_internal of_pyscript_copy) install( DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/oneflow/core DESTINATION ${ONEFLOW_INCLUDE_DIR}/oneflow COMPONENT oneflow_py_include EXCLUDE_FROM_ALL FILES_MATCHING PATTERN *.h PATTERN *.hpp) install( DIRECTORY ${CMAKE_SOURCE_DIR}/oneflow DESTINATION ${ONEFLOW_INCLUDE_DIR} COMPONENT oneflow_py_include EXCLUDE_FROM_ALL FILES_MATCHING REGEX "oneflow/core/common/.+(h|hpp)$" REGEX "oneflow/core/device/.+(h|hpp)$" REGEX "oneflow/core/framework/.+(h|hpp)$" REGEX "oneflow/core/kernel/util/.+(h|hpp)$" REGEX "oneflow/core/persistence/.+(h|hpp)$" REGEX "oneflow/core/ep/include/.+(h|hpp)$" REGEX "oneflow/core/ep/common/.+(h|hpp)$" REGEX "oneflow/core/ep/cpu/.+(h|hpp)$" REGEX "oneflow/core/ep/cuda/.+(h|hpp)$" REGEX "oneflow/core/job/.+(h|hpp)$" REGEX "oneflow/core/intrusive/.+(h|hpp)$" REGEX "oneflow/core/graph/boxing/.+(h|hpp)$" REGEX "oneflow/core/vm/.+(h|hpp)$" REGEX "oneflow/core/.+(proto)$" REGEX "oneflow/user/.+(h|hpp)$" PATTERN "oneflow/core/kernel/chain_kernel_observer.h" PATTERN "oneflow/core/kernel/cuda_graph_support.h" PATTERN "oneflow/core/kernel/new_kernel_util.h" PATTERN "oneflow/core/kernel/kernel.h" PATTERN "oneflow/core/kernel/kernel_context.h" PATTERN "oneflow/core/kernel/kernel_observer.h" PATTERN "oneflow/core/kernel/kernel_util.h" 
PATTERN "oneflow/core/kernel/kernel_util.cuh" PATTERN "oneflow/core/kernel/kernel_registration.h" PATTERN "oneflow/core/common/symbol.h" PATTERN "oneflow/core/register/blob.h" PATTERN "oneflow/core/register/op_blob_arg_info.h" PATTERN "oneflow/core/register/register.h" PATTERN "oneflow/core/register/register_desc.h" PATTERN "oneflow/core/register/register_manager.h" PATTERN "oneflow/core/register/runtime_register_desc.h" PATTERN "oneflow/core/register/tensor_slice_view.h" PATTERN "oneflow/core/ndarray/xpu_util.h" PATTERN "oneflow/core/rpc/include/base.h" PATTERN "oneflow/core/rpc/include/ctrl.h" PATTERN "oneflow/core/rpc/include/global_process_ctx.h" PATTERN "oneflow/core/control/ctrl_client.h" PATTERN "oneflow/core/control/global_process_ctx.h" PATTERN "oneflow/core/autograd/autograd_meta.h" PATTERN "oneflow/core/register/blob_desc.h" PATTERN "oneflow/core/operator/operator.h" PATTERN "oneflow/core/operator/operator_util.h" PATTERN "oneflow/core/operator/op_conf_util.h" PATTERN "oneflow/core/graph/compute_task_node.h" PATTERN "oneflow/core/graph/copy_task_node.h" PATTERN "oneflow/core/graph/exec_graph.h" PATTERN "oneflow/core/graph/graph.h" PATTERN "oneflow/core/graph/node.h" PATTERN "oneflow/core/graph/op_graph.h" PATTERN "oneflow/core/graph/task_graph.h" PATTERN "oneflow/core/graph/task_id.h" PATTERN "oneflow/core/graph/task_id_generator.h" PATTERN "oneflow/core/graph/task_node.h" PATTERN "oneflow/core/graph/task_stream_index_manager.h" PATTERN "oneflow/core/graph/stream_id.h" PATTERN "oneflow/core/graph/stream_index_generator.h" PATTERN "oneflow/core/graph/fake_consumed_regst_provider.h" PATTERN "oneflow/core/graph/transport_task_node.h" PATTERN "oneflow/core/thread/thread.h" PATTERN "oneflow/core/thread/thread_manager.h" PATTERN "oneflow/core/thread/thread_pool.h" PATTERN "oneflow/core/thread/thread_runtime.h" PATTERN "oneflow/core/thread/thread_runtime_factory.h" PATTERN "oneflow/core/profiler/profiler.h" PATTERN "oneflow/extension/stack/foreign_stack_getter.h" PATTERN "oneflow/core/platform/include/pthread_fork.h" PATTERN "oneflow/core/lazy/actor/actor.h" PATTERN "oneflow/core/lazy/actor/actor_base.h" PATTERN "oneflow/core/lazy/actor/actor_context.h" PATTERN "oneflow/core/lazy/actor/actor_message.h" PATTERN "oneflow/core/lazy/actor/actor_message_bus.h" PATTERN "oneflow/core/lazy/actor/register_slot.h" PATTERN "oneflow/core/lazy/stream_context/include/stream_context.h" PATTERN "oneflow/core/memory/memory_allocator.h" PATTERN "oneflow/core/memory/memory_case_util.h" PATTERN "oneflow/core/memory/memory_zone.h" PATTERN "oneflow/user/ops/convert_memory_format.h" PATTERN "oneflow/api" EXCLUDE PATTERN "oneflow/maybe" EXCLUDE PATTERN "oneflow/core/graph_impl" EXCLUDE PATTERN "oneflow/core/job_rewriter" EXCLUDE PATTERN "oneflow/core/hardware" EXCLUDE PATTERN "oneflow/core/stream" EXCLUDE PATTERN "oneflow/core/functional" EXCLUDE PATTERN "oneflow/core/boxing" EXCLUDE PATTERN "oneflow/core/transport" EXCLUDE PATTERN "oneflow/core/comm_network" EXCLUDE PATTERN "oneflow/ir" EXCLUDE) add_custom_target( install_oneflow_py_include COMMAND "${CMAKE_COMMAND}" -DCMAKE_INSTALL_COMPONENT=oneflow_py_include -P "${CMAKE_BINARY_DIR}/cmake_install.cmake" DEPENDS oneflow_internal) add_custom_target(oneflow_py ALL) add_dependencies(oneflow_py of_include_copy install_oneflow_py_include) endif(BUILD_PYTHON) if(BUILD_CPP_API) set(LIBONEFLOW_DIR ${PROJECT_BINARY_DIR}/liboneflow_cpp) install( DIRECTORY oneflow/api/cpp/ COMPONENT oneflow_cpp_all DESTINATION include/oneflow FILES_MATCHING PATTERN "*.h" PATTERN 
"tests" EXCLUDE) set(LIBONEFLOW_THIRD_PARTY_DIRS) checkdirandappendslash(DIR ${PROTOBUF_LIBRARY_DIR} OUTPUT PROTOBUF_LIBRARY_DIR_APPENDED) list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${PROTOBUF_LIBRARY_DIR_APPENDED}) if(BUILD_CUDA) checkdirandappendslash(DIR ${NCCL_LIBRARY_DIR} OUTPUT NCCL_LIBRARY_DIR_APPENDED) list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${NCCL_LIBRARY_DIR_APPENDED}) checkdirandappendslash(DIR ${TRT_FLASH_ATTENTION_LIBRARY_DIR} OUTPUT TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED) list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED}) if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7") checkdirandappendslash(DIR ${FLASH_ATTENTION_LIBRARY_DIR} OUTPUT FLASH_ATTENTION_LIBRARY_DIR_APPENDED) list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${FLASH_ATTENTION_LIBRARY_DIR_APPENDED}) endif() if(WITH_CUTLASS) checkdirandappendslash(DIR ${CUTLASS_LIBRARY_DIR} OUTPUT CUTLASS_LIBRARY_DIR_APPENDED) list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${CUTLASS_LIBRARY_DIR_APPENDED}) endif() endif() install( DIRECTORY ${LIBONEFLOW_THIRD_PARTY_DIRS} COMPONENT oneflow_cpp_all DESTINATION lib FILES_MATCHING PATTERN "*.so*" PATTERN "*.a" EXCLUDE PATTERN "libprotobuf-lite.so*" EXCLUDE PATTERN "libprotoc.so*" EXCLUDE PATTERN "cmake" EXCLUDE PATTERN "pkgconfig" EXCLUDE) install(FILES ${PROJECT_SOURCE_DIR}/cmake/oneflow-config.cmake COMPONENT oneflow_cpp_all DESTINATION share) get_property(MLIR_RELATED_TARGETS GLOBAL PROPERTY MLIR_EXPORTS) get_property(LLVM_RELATED_TARGETS GLOBAL PROPERTY LLVM_EXPORTS) list( REMOVE_ITEM LLVM_RELATED_TARGETS count not FileCheck lli-child-target llvm-jitlink-executor llvm-PerfectShuffle llvm-tblgen mlir-tblgen mlir-pdll obj2yaml oneflow_tblgen yaml-bench yaml2obj) set(LIBONEFLOW_TARGETS) list( APPEND LIBONEFLOW_TARGETS oneflow_cpp oneflow of_protoobj glog ${MLIR_RELATED_TARGETS} ${LLVM_RELATED_TARGETS} ${EXTERNAL_TARGETS}) if(BUILD_TESTING AND BUILD_SHARED_LIBS) list(APPEND LIBONEFLOW_TARGETS gtest_main gtest) endif() if(BUILD_TESTING) list(APPEND LIBONEFLOW_TARGETS oneflow_cpp_api_testexe) list(APPEND LIBONEFLOW_TARGETS oneflow_testexe) endif(BUILD_TESTING) install( TARGETS ${LIBONEFLOW_TARGETS} COMPONENT oneflow_cpp_all LIBRARY DESTINATION lib ARCHIVE DESTINATION lib RUNTIME DESTINATION bin) add_custom_target( install_oneflow_cpp COMMAND "${CMAKE_COMMAND}" -DCMAKE_INSTALL_COMPONENT=oneflow_cpp_all -DCMAKE_INSTALL_PREFIX="${LIBONEFLOW_DIR}" -P "${CMAKE_BINARY_DIR}/cmake_install.cmake" DEPENDS oneflow_cpp) if(BUILD_TESTING) add_dependencies(install_oneflow_cpp oneflow_cpp_api_testexe oneflow_testexe) endif(BUILD_TESTING) add_dependencies(of_include_copy install_oneflow_cpp) string(TOLOWER ${CMAKE_SYSTEM_NAME} CPACK_SYSTEM_NAME) set(CPACK_GENERATOR ZIP) set(CPACK_PACKAGE_DIRECTORY ${PROJECT_BINARY_DIR}/cpack) set(CPACK_PACKAGE_NAME liboneflow) # TODO: by Shenghang, unify python and c++ version genenerating and getting set(CPACK_PACKAGE_VERSION ${ONEFLOW_CURRENT_VERSION}) set(CPACK_INSTALL_CMAKE_PROJECTS ${PROJECT_BINARY_DIR};oneflow;oneflow_cpp_all;/) include(CPack) endif(BUILD_CPP_API) ================================================ FILE: cmake/op_schema.cmake ================================================ get_property(LLVM_INSTALL_DIR GLOBAL PROPERTY LLVM_INSTALL_DIR) set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm) set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm) set(ONEFLOW_OP_GROUPS "ASSIGN" "BINARY" "BROADCAST" "CONV" "CROSS_ENTROPY" "CUDA" "DATASET" "DETECTION" "EAGER" "FUSED" "IDEMPOTENT" "IDENTITY" "IMAGE" "INDICES" "INVOLUTION" "LOSS" "MATH" "MATMUL" "MISC" "NCCL" 
"NORMALIZATION" "OPTIMIZER" "PADDING" "PARALLEL_CAST" "POOL" "QUANTIZATION" "REDUCE" "RESHAPE" "SCALAR" "SOFTMAX" "SUMMARY" "TENSOR_BUFFER" "TEST" "TRIGONOMETRIC" "UNARY" "UPSAMPLE" "ONE_EMBEDDING" "LINEAR_ALGEBRA" "SYSTEM") if(WITH_MLIR) list(APPEND ONEFLOW_OP_GROUPS "MLIR_JIT") endif(WITH_MLIR) foreach(OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS) list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-DGET_ONEFLOW_${OP_GROUP_NAME}_OP_DEFINITIONS") endforeach() list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-DREMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS") set(GENERATED_OP_SCHEMA_DIR oneflow/core/framework) set(GENERATED_IR_INCLUDE_DIR oneflow/ir/include) set(SOURCE_IR_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/oneflow/ir/include) set(ONEFLOW_ODS ${SOURCE_IR_INCLUDE_DIR}/OneFlow/OneFlowOps.td) list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${GENERATED_IR_INCLUDE_DIR}") list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${SOURCE_IR_INCLUDE_DIR}") list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${LLVM_INSTALL_DIR}/include") set(GENERATED_OP_SCHEMA_H "${GENERATED_OP_SCHEMA_DIR}/op_generated.h") set(GENERATED_OP_SCHEMA_CPP "${GENERATED_OP_SCHEMA_DIR}/op_generated.cpp") set(ONEFLOW_TABLE_GEN_EXE ${LLVM_INSTALL_DIR}/bin/oneflow_tblgen) if(LLVM_PROVIDER STREQUAL "in-tree") set(ONEFLOW_TABLE_GEN_TARGET oneflow_tblgen install-oneflow-tblgen install-mlir-headers) elseif(LLVM_PROVIDER STREQUAL "install") set(ONEFLOW_TABLE_GEN_TARGET ${ONEFLOW_TABLE_GEN_EXE}) endif() file(GLOB_RECURSE ODS_FILES LIST_DIRECTORIES false "${SOURCE_IR_INCLUDE_DIR}/*.td") if(NOT ODS_FILES) message(FATAL_ERROR "ODS_FILES not found: ${ODS_FILES}") endif() add_custom_command( OUTPUT ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP} COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_OP_SCHEMA_DIR} COMMAND ${ONEFLOW_TABLE_GEN_EXE} ARGS --gen-op-schema-h ${ONEFLOW_ODS} ${ONEFLOW_SCHEMA_TABLEGEN_FLAGS} -o ${GENERATED_OP_SCHEMA_H} COMMAND ${ONEFLOW_TABLE_GEN_EXE} ARGS --gen-op-schema-cpp ${ONEFLOW_ODS} ${ONEFLOW_SCHEMA_TABLEGEN_FLAGS} --op-include ${GENERATED_OP_SCHEMA_H} -o ${GENERATED_OP_SCHEMA_CPP} DEPENDS ${ONEFLOW_TABLE_GEN_TARGET} ${ODS_FILES} VERBATIM) set_source_files_properties(${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP} PROPERTIES GENERATED TRUE) oneflow_add_library(of_op_schema OBJECT ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP}) target_link_libraries(of_op_schema LLVMSupportWithHeader glog::glog fmt) add_dependencies(of_op_schema prepare_oneflow_third_party) ================================================ FILE: cmake/platform.cmake ================================================ if(WIN32) set(CMAKE_BUILD_TYPE Debug) add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC -D__VERSION__=\"MSVC\") add_definitions( -DWIN32 -DOS_WIN -D_MBCS -DWIN64 -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS -D_ITERATOR_DEBUG_LEVEL=0) add_definitions( /bigobj /nologo /EHsc /GF /FC /MP /Gm-) add_definitions(-DGOOGLE_GLOG_DLL_DECL=) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") foreach( flag_var CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) if(${flag_var} MATCHES "/MD") string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") endif() endforeach() # set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /DEBUG:FASTLINK") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0") else() set(EXTRA_CXX_FLAGS "-Wall -Wno-sign-compare -Wno-unused-function -fPIC") if(APPLE) 
set(EXTRA_CXX_FLAGS "${EXTRA_CXX_FLAGS} -Wno-deprecated-declarations") endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${EXTRA_CXX_FLAGS}") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} ${EXTRA_CXX_FLAGS}") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${EXTRA_CXX_FLAGS}") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} ${EXTRA_CXX_FLAGS}") endif(WIN32) ================================================ FILE: cmake/proto2cpp.cmake ================================================ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) if(NOT ARGN) message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_CPP() called without any proto files") return() endif() set(${SRCS}) set(${HDRS}) foreach(FIL ${ARGN}) set(ABS_FIL ${ROOT_DIR}/${FIL}) get_filename_component(FIL_WE ${FIL} NAME_WE) get_filename_component(FIL_DIR ${ABS_FIL} PATH) file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc") list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h") add_custom_command( OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc" "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h" "${of_proto_python_dir}/${REL_DIR}/${FIL_WE}_pb2.py" COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIR} COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} ARGS --python_out ${of_proto_python_dir} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIR} COMMAND ${CMAKE_COMMAND} ARGS -E touch ${of_proto_python_dir}/${REL_DIR}/__init__.py DEPENDS ${ABS_FIL} protobuf COMMENT "Running Protocol Buffer Compiler on ${FIL}" VERBATIM) endforeach() set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) set(${SRCS} ${${SRCS}} PARENT_SCOPE) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() ================================================ FILE: cmake/pybind11.cmake ================================================ include(FetchContent) set_mirror_url_with_hash(PYBIND11_URL https://github.com/pybind/pybind11/archive/v2.11.1.zip c62d9e05243bd31cdb3bae1bb2f56655) FetchContent_Declare(pybind11 URL ${PYBIND11_URL} URL_HASH MD5=${PYBIND11_URL_HASH}) FetchContent_MakeAvailable(pybind11) ================================================ FILE: cmake/python.cmake ================================================ if(NOT DEFINED Python3_EXECUTABLE) execute_process( COMMAND which python3 RESULT_VARIABLE STATUS OUTPUT_VARIABLE OUTPUT ERROR_QUIET) if(STATUS EQUAL 0) string(STRIP ${OUTPUT} STRIPPED) message(STATUS "Using Python3 from 'which python3': ${STRIPPED}") set(Python3_EXECUTABLE ${STRIPPED}) endif() endif() find_package(Python3 COMPONENTS Interpreter REQUIRED) message(STATUS "Python3 specified. 
Version found: " ${Python3_VERSION}) set(Python_EXECUTABLE ${Python3_EXECUTABLE}) message(STATUS "Using Python executable: " ${Python_EXECUTABLE}) message(STATUS "Installing necessary Python packages...") set(requirements_txt ${PROJECT_SOURCE_DIR}/dev-requirements.txt) set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${requirements_txt}) message(STATUS "PIP_INDEX_MIRROR: ${PIP_INDEX_MIRROR}") if(PIP_INDEX_MIRROR) set(extra_index_arg "-i") endif() function(install_py_dev_deps) execute_process(COMMAND ${ARGV0} -m pip install ${extra_index_arg} ${PIP_INDEX_MIRROR} -r ${requirements_txt} --user RESULT_VARIABLE PIP_INSTALL_STATUS) if(NOT PIP_INSTALL_STATUS EQUAL 0) message(FATAL_ERROR "fail to install pip packages") endif() message(STATUS "Python packages are installed.") endfunction(install_py_dev_deps) install_py_dev_deps(${Python_EXECUTABLE}) find_package(Python3 COMPONENTS Development NumPy) if(Python3_Development_FOUND AND Python3_INCLUDE_DIRS) set(Python_INCLUDE_DIRS ${Python3_INCLUDE_DIRS}) endif() if(Python3_NumPy_FOUND AND Python3_NumPy_INCLUDE_DIRS) set(Python_NumPy_INCLUDE_DIRS ${Python3_NumPy_INCLUDE_DIRS}) endif() if(NOT Python_INCLUDE_DIRS) message(STATUS "Getting python include directory from sysconfig..") execute_process( COMMAND ${Python_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_paths()['include'])" OUTPUT_VARIABLE Python_INCLUDE_DIRS RESULT_VARIABLE ret_code) string(STRIP "${Python_INCLUDE_DIRS}" Python_INCLUDE_DIRS) if((NOT (ret_code EQUAL "0")) OR (NOT IS_DIRECTORY ${Python_INCLUDE_DIRS}) OR (NOT EXISTS ${Python_INCLUDE_DIRS}/Python.h)) set(Python_INCLUDE_DIRS "") endif() endif() if(NOT Python_INCLUDE_DIRS) message(FATAL_ERROR "Cannot find python include directory") endif() message(STATUS "Found python include directory ${Python_INCLUDE_DIRS}") if(NOT Python_NumPy_INCLUDE_DIRS) message(STATUS "Getting numpy include directory by numpy.get_include()..") execute_process(COMMAND ${Python_EXECUTABLE} -c "import numpy; print(numpy.get_include())" OUTPUT_VARIABLE Python_NumPy_INCLUDE_DIRS RESULT_VARIABLE ret_code) string(STRIP "${Python_NumPy_INCLUDE_DIRS}" Python_NumPy_INCLUDE_DIRS) if((NOT ret_code EQUAL 0) OR (NOT IS_DIRECTORY ${Python_NumPy_INCLUDE_DIRS}) OR (NOT EXISTS ${Python_NumPy_INCLUDE_DIRS}/numpy/arrayobject.h)) set(Python_NumPy_INCLUDE_DIRS "") endif() endif() if(NOT Python_NumPy_INCLUDE_DIRS) message(FATAL_ERROR "Cannot find numpy include directory") endif() message(STATUS "Found numpy include directory ${Python_NumPy_INCLUDE_DIRS}") # PYTHON_EXECUTABLE will be used by pybind11 set(PYTHON_EXECUTABLE ${Python_EXECUTABLE}) include(pybind11) set(CODEGEN_PYTHON_EXECUTABLE ${Python_EXECUTABLE} CACHE STRING "Python executable to generate .cpp/.h files") if(NOT "${CODEGEN_PYTHON_EXECUTABLE}" STREQUAL "${Python_EXECUTABLE}") install_py_dev_deps(${CODEGEN_PYTHON_EXECUTABLE}) endif() ================================================ FILE: cmake/third_party/FindBFD.cmake ================================================ # - BFD Library module. #============================================================================= # This module finds libbfd and associated headers. # #=== Variables =============================================================== # This module will set the following variables in your project: # # BFD_FOUND Whether libbfd was successfully found. 
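# (Editor's note, illustrative consumer: a project just calls `find_package(BFD)` and, when
#  BFD_FOUND, links the imported target, e.g. `target_link_libraries(my_tool PRIVATE bfd::bfd)`;
#  this is what cmake/oneflow.cmake above does to enable BFD-based stack traces via
#  BACKWARD_HAS_BFD. The target name "my_tool" is an assumption.)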
# bfd::bfd Cmake target for bfd # #============================================================================= include(FindPackageHandleStandardArgs) set(CMAKE_LIBRARY_PATH /lib /usr/lib /usr/local/lib) set(CMAKE_INCLUDE_PATH /usr/include /usr/local/include) find_path(BFD_INCLUDE_PATH bfd.h PATH /usr/include /usr/local/include) find_library(BFD_LIBRARIES bfd PATH /lib /usr/lib /usr/local/lib) find_package_handle_standard_args(BFD DEFAULT_MSG BFD_LIBRARIES BFD_INCLUDE_PATH) if(BFD_FOUND) if(NOT TARGET bfd::bfd) add_library(bfd::bfd INTERFACE IMPORTED) set_property(TARGET bfd::bfd PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${BFD_INCLUDE_PATH}) set_property(TARGET bfd::bfd PROPERTY INTERFACE_LINK_LIBRARIES ${BFD_LIBRARIES}) set_property(TARGET bfd::bfd PROPERTY IMPORTED_CONFIGURATIONS RELEASE) endif(NOT TARGET bfd::bfd) endif() mark_as_advanced(BFD_INCLUDE_PATH BFD_LIBRARIES) ================================================ FILE: cmake/third_party/FindBLAS.cmake ================================================ #.rst: # FindBLAS # -------- # # Find BLAS library # # This module finds an installed fortran library that implements the # BLAS linear-algebra interface (see http://www.netlib.org/blas/). The # list of libraries searched for is taken from the autoconf macro file, # acx_blas.m4 (distributed at # http://ac-archive.sourceforge.net/ac-archive/acx_blas.html). # # This module sets the following variables: # # :: # # BLAS_FOUND - set to true if a library implementing the BLAS interface # is found # BLAS_LINKER_FLAGS - uncached list of required linker flags (excluding -l # and -L). # BLAS_LIBRARIES - uncached list of libraries (using full path name) to # link against to use BLAS # BLAS95_LIBRARIES - uncached list of libraries (using full path name) # to link against to use BLAS95 interface # BLAS95_FOUND - set to true if a library implementing the BLAS f95 interface # is found # BLA_STATIC if set on this determines what kind of linkage we do (static) # BLA_VENDOR if set checks only the specified vendor, if not set checks # all the possibilities # BLA_F95 if set on tries to find the f95 interfaces for BLAS/LAPACK # # ######### ## List of vendors (BLA_VENDOR) valid in this module # # Goto,OpenBLAS,ATLAS PhiPACK,CXML,DXML,SunPerf,SCSL,SGIMATH,IBMESSL, # Intel10_32 (intel mkl v10 32 bit),Intel10_64lp (intel mkl v10 64 bit, # lp thread model, lp64 model), # Intel10_64lp_seq (intel mkl v10 64 # bit,sequential code, lp64 model), # Intel( older versions of mkl 32 # and 64 bit), ACML,ACML_MP,ACML_GPU,Apple, NAS, Generic C/CXX should be # enabled to use Intel mkl #============================================================================= # Copyright 2007-2009 Kitware, Inc. # # Distributed under the OSI-approved BSD License (the "License"); # see accompanying file Copyright.txt for details. # # This software is distributed WITHOUT ANY WARRANTY; without even the # implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the License for more information. #============================================================================= # (To distribute this file outside of CMake, substitute the full # License text for the above reference.) 
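# Editor's note (illustrative sketch, not part of the original module): a minimal consumer of this
# FindBLAS module; the vendor choice and the target name "my_solver" are assumptions, the variables
# are the ones documented above:
#   set(BLA_VENDOR OpenBLAS)
#   find_package(BLAS REQUIRED)
#   add_library(my_solver solver.cpp)
#   target_link_libraries(my_solver PRIVATE ${BLAS_LIBRARIES})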
set(CMAKE_REQUIRED_QUIET ${BLAS_FIND_QUIETLY}) set(_blas_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES}) # Check the language being used if(NOT (CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED OR CMAKE_Fortran_COMPILER_LOADED)) if(BLAS_FIND_REQUIRED) message(FATAL_ERROR "FindBLAS requires Fortran, C, or C++ to be enabled.") else() message(STATUS "Looking for BLAS... - NOT found (Unsupported languages)") return() endif() endif() macro( Check_Fortran_Libraries LIBRARIES _prefix _name _flags _list _thread) # This macro checks for the existence of the combination of fortran libraries # given by _list. If the combination is found, this macro checks (using the # Check_Fortran_Function_Exists macro) whether can link against that library # combination using the name of a routine given by _name using the linker # flags given by _flags. If the combination of libraries is found and passes # the link test, LIBRARIES is set to the list of complete library paths that # have been found. Otherwise, LIBRARIES is set to FALSE. # N.B. _prefix is the prefix applied to the names of all cached variables that # are generated internally and marked advanced by this macro. set(_libdir ${ARGN}) set(_libraries_work TRUE) set(${LIBRARIES}) set(_combined_name) if(NOT _libdir) if(WIN32) set(_libdir ENV LIB) elseif(APPLE) set(_libdir ENV DYLD_LIBRARY_PATH) else() set(_libdir ENV LD_LIBRARY_PATH) endif() endif() foreach(_library ${_list}) set(_combined_name ${_combined_name}_${_library}) if(_libraries_work) if(BLA_STATIC) if(WIN32) set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES}) endif() if(APPLE) set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES}) else() set(CMAKE_FIND_LIBRARY_SUFFIXES .a ${CMAKE_FIND_LIBRARY_SUFFIXES}) endif() else() if(CMAKE_SYSTEM_NAME STREQUAL "Linux") # for ubuntu's libblas3gf and liblapack3gf packages set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES} .so.3gf) endif() endif() find_library(${_prefix}_${_library}_LIBRARY NAMES ${_library} PATHS ${_libdir}) mark_as_advanced(${_prefix}_${_library}_LIBRARY) set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) endif() endforeach() if(_libraries_work) set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_thread}) set(CMAKE_REQUIRED_LIBRARIES) mark_as_advanced(${_prefix}${_combined_name}_WORKS) set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) endif() endmacro() set(BLAS_LINKER_FLAGS) set(BLAS_LIBRARIES) set(BLAS95_LIBRARIES) if(NOT $ENV{BLA_VENDOR} STREQUAL "") set(BLA_VENDOR $ENV{BLA_VENDOR}) else() if(NOT BLA_VENDOR) set(BLA_VENDOR "All") endif() endif() if(BLA_VENDOR STREQUAL "Goto" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) # gotoblas (http://www.tacc.utexas.edu/tacc-projects/gotoblas2) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "goto2" "") endif() endif() if(BLA_VENDOR STREQUAL "OpenBLAS" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) # OpenBLAS (http://www.openblas.net) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "openblas" "") endif() endif() if(BLA_VENDOR STREQUAL "ATLAS" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) # BLAS in ATLAS library? (http://math-atlas.sourceforge.net/) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS dgemm "" "f77blas;atlas" "") endif() endif() # BLAS in PhiPACK libraries? 
(requires generic BLAS lib, too) if(BLA_VENDOR STREQUAL "PhiPACK" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "sgemm;dgemm;blas" "") endif() endif() # BLAS in Alpha CXML library? if(BLA_VENDOR STREQUAL "CXML" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "cxml" "") endif() endif() # BLAS in Alpha DXML library? (now called CXML, see above) if(BLA_VENDOR STREQUAL "DXML" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "dxml" "") endif() endif() # BLAS in Sun Performance library? if(BLA_VENDOR STREQUAL "SunPerf" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "-xlic_lib=sunperf" "sunperf;sunmath" "") if(BLAS_LIBRARIES) set(BLAS_LINKER_FLAGS "-xlic_lib=sunperf") endif() endif() endif() # BLAS in SCSL library? (SGI/Cray Scientific Library) if(BLA_VENDOR STREQUAL "SCSL" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "scsl" "") endif() endif() # BLAS in SGIMATH library? if(BLA_VENDOR STREQUAL "SGIMATH" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "complib.sgimath" "") endif() endif() # BLAS in IBM ESSL library? (requires generic BLAS lib, too) if(BLA_VENDOR STREQUAL "IBMESSL" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "essl;blas" "") endif() endif() #BLAS in acml library? if(BLA_VENDOR MATCHES "ACML" OR BLA_VENDOR STREQUAL "All") if(((BLA_VENDOR STREQUAL "ACML") AND (NOT BLAS_ACML_LIB_DIRS)) OR ((BLA_VENDOR STREQUAL "ACML_MP") AND (NOT BLAS_ACML_MP_LIB_DIRS)) OR ((BLA_VENDOR STREQUAL "ACML_GPU") AND (NOT BLAS_ACML_GPU_LIB_DIRS))) # try to find acml in "standard" paths if(WIN32) file(GLOB _ACML_ROOT "C:/AMD/acml*/ACML-EULA.txt") else() file(GLOB _ACML_ROOT "/opt/acml*/ACML-EULA.txt") endif() if(WIN32) file(GLOB _ACML_GPU_ROOT "C:/AMD/acml*/GPGPUexamples") else() file(GLOB _ACML_GPU_ROOT "/opt/acml*/GPGPUexamples") endif() list(GET _ACML_ROOT 0 _ACML_ROOT) list(GET _ACML_GPU_ROOT 0 _ACML_GPU_ROOT) if(_ACML_ROOT) get_filename_component(_ACML_ROOT ${_ACML_ROOT} PATH) if(SIZEOF_INTEGER EQUAL 8) set(_ACML_PATH_SUFFIX "_int64") else() set(_ACML_PATH_SUFFIX "") endif() if(CMAKE_Fortran_COMPILER_ID STREQUAL "Intel") set(_ACML_COMPILER32 "ifort32") set(_ACML_COMPILER64 "ifort64") elseif(CMAKE_Fortran_COMPILER_ID STREQUAL "SunPro") set(_ACML_COMPILER32 "sun32") set(_ACML_COMPILER64 "sun64") elseif(CMAKE_Fortran_COMPILER_ID STREQUAL "PGI") set(_ACML_COMPILER32 "pgi32") if(WIN32) set(_ACML_COMPILER64 "win64") else() set(_ACML_COMPILER64 "pgi64") endif() elseif(CMAKE_Fortran_COMPILER_ID STREQUAL "Open64") # 32 bit builds not supported on Open64 but for code simplicity # We'll just use the same directory twice set(_ACML_COMPILER32 "open64_64") set(_ACML_COMPILER64 "open64_64") elseif(CMAKE_Fortran_COMPILER_ID STREQUAL "NAG") set(_ACML_COMPILER32 "nag32") set(_ACML_COMPILER64 "nag64") else() set(_ACML_COMPILER32 "gfortran32") set(_ACML_COMPILER64 "gfortran64") endif() if(BLA_VENDOR STREQUAL "ACML_MP") set(_ACML_MP_LIB_DIRS "${_ACML_ROOT}/${_ACML_COMPILER32}_mp${_ACML_PATH_SUFFIX}/lib" "${_ACML_ROOT}/${_ACML_COMPILER64}_mp${_ACML_PATH_SUFFIX}/lib") else() set(_ACML_LIB_DIRS "${_ACML_ROOT}/${_ACML_COMPILER32}${_ACML_PATH_SUFFIX}/lib" "${_ACML_ROOT}/${_ACML_COMPILER64}${_ACML_PATH_SUFFIX}/lib") endif() 
endif() elseif(BLAS_${BLA_VENDOR}_LIB_DIRS) set(_${BLA_VENDOR}_LIB_DIRS ${BLAS_${BLA_VENDOR}_LIB_DIRS}) endif() if(BLA_VENDOR STREQUAL "ACML_MP") foreach(BLAS_ACML_MP_LIB_DIRS ${_ACML_MP_LIB_DIRS}) Check_Fortran_Libraries( BLAS_LIBRARIES BLAS sgemm "" "acml_mp;acml_mv" "" ${BLAS_ACML_MP_LIB_DIRS}) if(BLAS_LIBRARIES) break() endif() endforeach() elseif(BLA_VENDOR STREQUAL "ACML_GPU") foreach(BLAS_ACML_GPU_LIB_DIRS ${_ACML_GPU_LIB_DIRS}) Check_Fortran_Libraries( BLAS_LIBRARIES BLAS sgemm "" "acml;acml_mv;CALBLAS" "" ${BLAS_ACML_GPU_LIB_DIRS}) if(BLAS_LIBRARIES) break() endif() endforeach() else() foreach(BLAS_ACML_LIB_DIRS ${_ACML_LIB_DIRS}) Check_Fortran_Libraries( BLAS_LIBRARIES BLAS sgemm "" "acml;acml_mv" "" ${BLAS_ACML_LIB_DIRS}) if(BLAS_LIBRARIES) break() endif() endforeach() endif() # Either acml or acml_mp should be in LD_LIBRARY_PATH but not both if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "acml;acml_mv" "") endif() if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "acml_mp;acml_mv" "") endif() if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "acml;acml_mv;CALBLAS" "") endif() endif() # ACML # Apple BLAS library? if(BLA_VENDOR STREQUAL "Apple" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS dgemm "" "Accelerate" "") endif() endif() if(BLA_VENDOR STREQUAL "NAS" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS dgemm "" "vecLib" "") endif() endif() # Generic BLAS library? if(BLA_VENDOR STREQUAL "Generic" OR BLA_VENDOR STREQUAL "All") if(NOT BLAS_LIBRARIES) Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm "" "blas" "") endif() endif() #BLAS in intel mkl 10 library? (em64t 64bit) if(BLA_VENDOR MATCHES "Intel" OR BLA_VENDOR STREQUAL "All") if(NOT WIN32) set(LM "-lm") endif() if(CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED) if(BLAS_FIND_QUIETLY OR NOT BLAS_FIND_REQUIRED) find_package(Threads) else() find_package(Threads REQUIRED) endif() set(BLAS_SEARCH_LIBS "") if(BLA_F95) set(BLAS_mkl_SEARCH_SYMBOL SGEMM) set(_LIBRARIES BLAS95_LIBRARIES) if(WIN32) if(BLA_STATIC) set(BLAS_mkl_DLL_SUFFIX "") else() set(BLAS_mkl_DLL_SUFFIX "_dll") endif() # Find the main file (32-bit or 64-bit) set(BLAS_SEARCH_LIBS_WIN_MAIN "") if(BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN "mkl_blas95${BLAS_mkl_DLL_SUFFIX} mkl_intel_c${BLAS_mkl_DLL_SUFFIX}") endif() if(BLA_VENDOR MATCHES "^Intel10_64lp" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN "mkl_blas95_lp64${BLAS_mkl_DLL_SUFFIX} mkl_intel_lp64${BLAS_mkl_DLL_SUFFIX}") endif() # Add threading/sequential libs set(BLAS_SEARCH_LIBS_WIN_THREAD "") if(BLA_VENDOR MATCHES "_seq$" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD "mkl_sequential${BLAS_mkl_DLL_SUFFIX}") endif() if(NOT BLA_VENDOR MATCHES "_seq$" OR BLA_VENDOR STREQUAL "All") # old version list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD "libguide40 mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") # mkl >= 10.3 list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD "libiomp5md mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") endif() # Cartesian product of the above foreach(MAIN ${BLAS_SEARCH_LIBS_WIN_MAIN}) foreach(THREAD ${BLAS_SEARCH_LIBS_WIN_THREAD}) list(APPEND BLAS_SEARCH_LIBS "${MAIN} ${THREAD} mkl_core${BLAS_mkl_DLL_SUFFIX}") endforeach() endforeach() else() if(BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS 
"mkl_blas95 mkl_intel mkl_intel_thread mkl_core guide") endif() if(BLA_VENDOR STREQUAL "Intel10_64lp" OR BLA_VENDOR STREQUAL "All") # old version list(APPEND BLAS_SEARCH_LIBS "mkl_blas95 mkl_intel_lp64 mkl_intel_thread mkl_core guide") # mkl >= 10.3 if(CMAKE_C_COMPILER MATCHES ".+gcc") list(APPEND BLAS_SEARCH_LIBS "mkl_blas95_lp64 mkl_intel_lp64 mkl_gnu_thread mkl_core gomp") else() list(APPEND BLAS_SEARCH_LIBS "mkl_blas95_lp64 mkl_intel_lp64 mkl_intel_thread mkl_core iomp5") endif() endif() if(BLA_VENDOR STREQUAL "Intel10_64lp_seq" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS "mkl_intel_lp64 mkl_sequential mkl_core") endif() endif() else() set(BLAS_mkl_SEARCH_SYMBOL sgemm) set(_LIBRARIES BLAS_LIBRARIES) if(WIN32) if(BLA_STATIC) set(BLAS_mkl_DLL_SUFFIX "") else() set(BLAS_mkl_DLL_SUFFIX "_dll") endif() # Find the main file (32-bit or 64-bit) set(BLAS_SEARCH_LIBS_WIN_MAIN "") if(BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN "mkl_intel_c${BLAS_mkl_DLL_SUFFIX}") endif() if(BLA_VENDOR MATCHES "^Intel10_64lp" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN "mkl_intel_lp64${BLAS_mkl_DLL_SUFFIX}") endif() # Add threading/sequential libs set(BLAS_SEARCH_LIBS_WIN_THREAD "") if(NOT BLA_VENDOR MATCHES "_seq$" OR BLA_VENDOR STREQUAL "All") # old version list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD "libguide40 mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") # mkl >= 10.3 list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD "libiomp5md mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") endif() if(BLA_VENDOR MATCHES "_seq$" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD "mkl_sequential${BLAS_mkl_DLL_SUFFIX}") endif() # Cartesian product of the above foreach(MAIN ${BLAS_SEARCH_LIBS_WIN_MAIN}) foreach(THREAD ${BLAS_SEARCH_LIBS_WIN_THREAD}) list(APPEND BLAS_SEARCH_LIBS "${MAIN} ${THREAD} mkl_core${BLAS_mkl_DLL_SUFFIX}") endforeach() endforeach() else() if(BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS "mkl_intel mkl_intel_thread mkl_core guide") endif() if(BLA_VENDOR STREQUAL "Intel10_64lp" OR BLA_VENDOR STREQUAL "All") # old version list(APPEND BLAS_SEARCH_LIBS "mkl_intel_lp64 mkl_intel_thread mkl_core guide") # mkl >= 10.3 if(CMAKE_C_COMPILER MATCHES ".+gcc") list(APPEND BLAS_SEARCH_LIBS "mkl_intel_lp64 mkl_gnu_thread mkl_core gomp") else() list(APPEND BLAS_SEARCH_LIBS "mkl_intel_lp64 mkl_intel_thread mkl_core iomp5") endif() endif() if(BLA_VENDOR STREQUAL "Intel10_64lp_seq" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS "mkl_intel_lp64 mkl_sequential mkl_core") endif() #older vesions of intel mkl libs if(BLA_VENDOR STREQUAL "Intel" OR BLA_VENDOR STREQUAL "All") list(APPEND BLAS_SEARCH_LIBS "mkl") list(APPEND BLAS_SEARCH_LIBS "mkl_ia32") list(APPEND BLAS_SEARCH_LIBS "mkl_em64t") endif() endif() endif() foreach(IT ${BLAS_SEARCH_LIBS}) string(REPLACE " " ";" SEARCH_LIBS ${IT}) if(${_LIBRARIES}) else() Check_Fortran_Libraries(${_LIBRARIES} BLAS ${BLAS_mkl_SEARCH_SYMBOL} "" "${SEARCH_LIBS}" "${CMAKE_THREAD_LIBS_INIT};${LM}") endif() endforeach() endif() endif() if(BLA_F95) if(BLAS95_LIBRARIES) set(BLAS95_FOUND TRUE) else() set(BLAS95_FOUND FALSE) endif() if(NOT BLAS_FIND_QUIETLY) if(BLAS95_FOUND) message(STATUS "A library with BLAS95 API found.") else() if(BLAS_FIND_REQUIRED) message( FATAL_ERROR "A required library with BLAS95 API not found. Please specify library location.") else() message(STATUS "A library with BLAS95 API not found. 
Please specify library location.") endif() endif() endif() set(BLAS_FOUND TRUE) set(BLAS_LIBRARIES "${BLAS95_LIBRARIES}") else() if(BLAS_LIBRARIES) set(BLAS_FOUND TRUE) else() set(BLAS_FOUND FALSE) endif() if(NOT BLAS_FIND_QUIETLY) if(BLAS_FOUND) message(STATUS "A library with BLAS API found.") else() if(BLAS_FIND_REQUIRED) message( FATAL_ERROR "A required library with BLAS API not found. Please specify library location." ) else() message(STATUS "A library with BLAS API not found. Please specify library location.") endif() endif() endif() endif() set(CMAKE_FIND_LIBRARY_SUFFIXES ${_blas_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES}) ================================================ FILE: cmake/third_party/FindCUDNN.cmake ================================================ # - Try to find cuDNN # # The following variables are optionally searched for defaults # CUDNN_ROOT_DIR: Base directory where all cuDNN components are found # # The following are set after configuration is done: # CUDNN_FOUND # CUDNN_INCLUDE_DIRS # CUDNN_LIBRARIES # CUDNN_LIBRARY_DIRS include(FindPackageHandleStandardArgs) include(CMakeDependentOption) set(CUDNN_ROOT_DIR "" CACHE PATH "Folder contains NVIDIA cuDNN") if(CUDA_VERSION VERSION_LESS "11.0") set(CUDA_VERSION_VERSION_LESS_11 TRUE) endif() cmake_dependent_option(CUDNN_STATIC "Look for static cuDNN" ON "CUDA_VERSION_VERSION_LESS_11" OFF) if(OF_CUDA_LINK_DYNAMIC_LIBRARY) set(CUDNN_STATIC OFF) endif() if(CUDNN_STATIC) set(__cudnn_libname "libcudnn_static.a") else() set(__cudnn_libname "libcudnn.so") endif() find_path(CUDNN_INCLUDE_DIR cudnn.h HINTS ${CUDNN_ROOT_DIR} ${CUDAToolkit_INCLUDE_DIRS} PATH_SUFFIXES cuda/include include) unset(CUDNN_LIBRARY CACHE) find_library(CUDNN_LIBRARY ${__cudnn_libname} HINTS ${CUDNN_ROOT_DIR} ${CUDAToolkit_LIBRARY_DIR} PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) find_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_INCLUDE_DIR CUDNN_LIBRARY) if(CUDNN_FOUND) # get cuDNN version if(EXISTS "${CUDNN_INCLUDE_DIR}/cudnn_version.h") file(READ ${CUDNN_INCLUDE_DIR}/cudnn_version.h CUDNN_HEADER_CONTENTS) else() file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_HEADER_CONTENTS) endif() string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)" CUDNN_VERSION_MAJOR "${CUDNN_HEADER_CONTENTS}") string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1" CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}") string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)" CUDNN_VERSION_MINOR "${CUDNN_HEADER_CONTENTS}") string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1" CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}") string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)" CUDNN_VERSION_PATCH "${CUDNN_HEADER_CONTENTS}") string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1" CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}") # Assemble cuDNN version if(NOT CUDNN_VERSION_MAJOR) set(CUDNN_VERSION "?") else() set(CUDNN_VERSION "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}") endif() set(CUDNN_INCLUDE_DIRS ${CUDNN_INCLUDE_DIR}) if(NOT CUDNN_STATIC AND CUDNN_VERSION_MAJOR GREATER_EQUAL 9) # skipping: libcudnn_adv_infer.so libcudnn_adv_train.so set(CUDNN_DYNAMIC_NAMES libcudnn_cnn.so libcudnn_ops.so) get_filename_component(CUDNN_LIBRARY_DIRECTORY ${CUDNN_LIBRARY} DIRECTORY) foreach(CUDNN_DYNAMIC_NAME ${CUDNN_DYNAMIC_NAMES}) list(APPEND CUDNN_LIBRARIES ${CUDNN_LIBRARY_DIRECTORY}/${CUDNN_DYNAMIC_NAME}) endforeach() elseif(NOT CUDNN_STATIC AND CUDNN_VERSION_MAJOR GREATER_EQUAL 8) # skipping: libcudnn_adv_infer.so libcudnn_adv_train.so set(CUDNN_DYNAMIC_NAMES 
libcudnn_cnn_infer.so libcudnn_cnn_train.so libcudnn_ops_infer.so libcudnn_ops_train.so) get_filename_component(CUDNN_LIBRARY_DIRECTORY ${CUDNN_LIBRARY} DIRECTORY) foreach(CUDNN_DYNAMIC_NAME ${CUDNN_DYNAMIC_NAMES}) list(APPEND CUDNN_LIBRARIES ${CUDNN_LIBRARY_DIRECTORY}/${CUDNN_DYNAMIC_NAME}) endforeach() else() set(CUDNN_LIBRARIES ${CUDNN_LIBRARY}) endif() message( STATUS "Found cuDNN: v${CUDNN_VERSION} (include: ${CUDNN_INCLUDE_DIR}, library: ${CUDNN_LIBRARIES})" ) mark_as_advanced(CUDNN_ROOT_DIR CUDNN_LIBRARY CUDNN_INCLUDE_DIR) endif() ================================================ FILE: cmake/third_party/FindUnwind.cmake ================================================ # - Try to find libunwind # Once done this will define # # Unwind_FOUND - system has libunwind # unwind::unwind - cmake target for libunwind include(FindPackageHandleStandardArgs) find_path(Unwind_INCLUDE_DIR NAMES unwind.h libunwind.h DOC "unwind include directory") find_library(Unwind_LIBRARY NAMES unwind DOC "unwind library") mark_as_advanced(Unwind_INCLUDE_DIR Unwind_LIBRARY) # Extract version information if(Unwind_LIBRARY) set(_Unwind_VERSION_HEADER ${Unwind_INCLUDE_DIR}/libunwind-common.h) if(EXISTS ${_Unwind_VERSION_HEADER}) file(READ ${_Unwind_VERSION_HEADER} _Unwind_VERSION_CONTENTS) string(REGEX REPLACE ".*#define UNW_VERSION_MAJOR[ \t]+([0-9]+).*" "\\1" Unwind_VERSION_MAJOR "${_Unwind_VERSION_CONTENTS}") string(REGEX REPLACE ".*#define UNW_VERSION_MINOR[ \t]+([0-9]+).*" "\\1" Unwind_VERSION_MINOR "${_Unwind_VERSION_CONTENTS}") string(REGEX REPLACE ".*#define UNW_VERSION_EXTRA[ \t]+([0-9]+).*" "\\1" Unwind_VERSION_PATCH "${_Unwind_VERSION_CONTENTS}") set(Unwind_VERSION ${Unwind_VERSION_MAJOR}.${Unwind_VERSION_MINOR}) if(CMAKE_MATCH_0) # Third version component may be empty set(Unwind_VERSION ${Unwind_VERSION}.${Unwind_VERSION_PATCH}) set(Unwind_VERSION_COMPONENTS 3) else(CMAKE_MATCH_0) set(Unwind_VERSION_COMPONENTS 2) endif(CMAKE_MATCH_0) endif(EXISTS ${_Unwind_VERSION_HEADER}) endif(Unwind_LIBRARY) # handle the QUIETLY and REQUIRED arguments and set Unwind_FOUND to TRUE # if all listed variables are TRUE find_package_handle_standard_args(Unwind REQUIRED_VARS Unwind_INCLUDE_DIR Unwind_LIBRARY VERSION_VAR Unwind_VERSION) if(Unwind_FOUND) if(NOT TARGET unwind::unwind) add_library(unwind::unwind INTERFACE IMPORTED) set_property(TARGET unwind::unwind PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${Unwind_INCLUDE_DIR}) set_property(TARGET unwind::unwind PROPERTY INTERFACE_LINK_LIBRARIES ${Unwind_LIBRARY}) set_property(TARGET unwind::unwind PROPERTY IMPORTED_CONFIGURATIONS RELEASE) endif(NOT TARGET unwind::unwind) endif(Unwind_FOUND) ================================================ FILE: cmake/third_party/absl.cmake ================================================ include(ExternalProject) include(GNUInstallDirs) set(ABSL_PROJECT absl) set(ABSL_TAR_URL https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.2.tar.gz) use_mirror(VARIABLE ABSL_TAR_URL URL ${ABSL_TAR_URL}) set(ABSL_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/absl/src/absl) set(ABSL_INSTALL ${THIRD_PARTY_DIR}/absl) set(ABSL_INCLUDE_DIR ${THIRD_PARTY_DIR}/absl/include CACHE PATH "" FORCE) set(ABSL_LIBRARY_DIR ${THIRD_PARTY_DIR}/absl/${CMAKE_INSTALL_LIBDIR} CACHE PATH "" FORCE) if(WIN32) set(ABSL_BUILD_LIBRARY_DIR ${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR}) set(ABSL_LIBRARY_NAMES absl_spinlock_wait.lib absl_malloc_internal.lib absl_throw_delegate.lib absl_int128.lib absl_strings.lib absl_str_format_internal.lib absl_time.lib absl_bad_optional_access.lib 
absl_base.lib) else() set(ABSL_BUILD_LIBRARY_DIR ${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR}) set(ABSL_LIBRARY_NAMES libabsl_spinlock_wait.a libabsl_malloc_internal.a libabsl_throw_delegate.a libabsl_int128.a libabsl_strings.a libabsl_str_format_internal.a libabsl_time.a libabsl_bad_optional_access.a libabsl_base.a) endif() foreach(LIBRARY_NAME ${ABSL_LIBRARY_NAMES}) list(APPEND ABSL_STATIC_LIBRARIES ${ABSL_LIBRARY_DIR}/${LIBRARY_NAME}) list(APPEND ABSL_BUILD_STATIC_LIBRARIES ${ABSL_BUILD_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() if(THIRD_PARTY) ExternalProject_Add( ${ABSL_PROJECT} PREFIX absl URL ${ABSL_TAR_URL} URL_MD5 52b9786ca6fbc679869fee2b6fef25a5 UPDATE_COMMAND "" BUILD_BYPRODUCTS ${ABSL_STATIC_LIBRARIES} CMAKE_CACHE_ARGS -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_INSTALL_PREFIX:PATH=${ABSL_INSTALL} -DCMAKE_INSTALL_LIBDIR:PATH=${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/cares.cmake ================================================ include(ExternalProject) set(CARES_TAR_URL https://github.com/c-ares/c-ares/releases/download/cares-1_15_0/c-ares-1.15.0.tar.gz) use_mirror(VARIABLE CARES_TAR_URL URL ${CARES_TAR_URL}) set(CARES_URL_HASH d2391da274653f7643270623e822dff7) set(CARES_INSTALL ${THIRD_PARTY_DIR}/cares) set(CARES_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cares/src/cares) if(THIRD_PARTY) ExternalProject_Add( cares PREFIX cares URL ${CARES_TAR_URL} URL_HASH MD5=${CARES_URL_HASH} UPDATE_COMMAND "" CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "") endif() ================================================ FILE: cmake/third_party/cocoapi.cmake ================================================ include(ExternalProject) set(COCOAPI_INCLUDE_DIR ${THIRD_PARTY_DIR}/cocoapi/include) set(COCOAPI_LIBRARY_DIR ${THIRD_PARTY_DIR}/cocoapi/lib) set(COCOAPI_URL https://github.com/Oneflow-Inc/cocoapi/archive/refs/tags/ed842bf.tar.gz) use_mirror(VARIABLE COCOAPI_URL URL ${COCOAPI_URL}) set(COCOAPI_URL_HASH e7e0504231e5614ffaa34f081773f7f1) set(COCOAPI_BASE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cocoapi/src/cocoapi) set(COCOAPI_LIBRARY_NAME libcocoapi_static.a) list(APPEND COCOAPI_STATIC_LIBRARIES ${COCOAPI_LIBRARY_DIR}/${COCOAPI_LIBRARY_NAME}) list(APPEND COCOAPI_BUILD_STATIC_LIBRARIES ${COCOAPI_BASE_DIR}/${COCOAPI_LIBRARY_NAME}) set(COCOAPI_HEADERS "${COCOAPI_BASE_DIR}/common/maskApi.h") if(THIRD_PARTY) ExternalProject_Add( cocoapi PREFIX cocoapi URL ${COCOAPI_URL} URL_HASH MD5=${COCOAPI_URL_HASH} UPDATE_COMMAND "" CONFIGURE_COMMAND "" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${COCOAPI_STATIC_LIBRARIES} BUILD_COMMAND ${CMAKE_C_COMPILER} -fPIC -O3 -c common/maskApi.c -o maskApi.o && ${CMAKE_AR} rcs ${COCOAPI_LIBRARY_NAME} maskApi.o INSTALL_COMMAND "") add_custom_target(cocoapi_create_header_dir COMMAND ${CMAKE_COMMAND} -E make_directory ${COCOAPI_INCLUDE_DIR} DEPENDS cocoapi) add_custom_target(cocoapi_copy_headers_to_destination DEPENDS cocoapi_create_header_dir) foreach(header_file ${COCOAPI_HEADERS}) add_custom_command( TARGET cocoapi_copy_headers_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${COCOAPI_INCLUDE_DIR}) endforeach() add_custom_target(cocoapi_create_library_dir COMMAND ${CMAKE_COMMAND} -E make_directory ${COCOAPI_LIBRARY_DIR} DEPENDS 
cocoapi) add_custom_target( cocoapi_copy_libs_to_destination COMMAND ${CMAKE_COMMAND} -E copy_if_different ${COCOAPI_BUILD_STATIC_LIBRARIES} ${COCOAPI_LIBRARY_DIR} DEPENDS cocoapi_create_library_dir) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/cub.cmake ================================================ include(ExternalProject) set(CUB_INCLUDE_DIR ${THIRD_PARTY_DIR}/cub/include) set(CUB_BUILD_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub/cub) set(CUB_URL https://github.com/NVIDIA/cub/archive/refs/tags/1.11.0.tar.gz) use_mirror(VARIABLE CUB_URL URL ${CUB_URL}) if(THIRD_PARTY) ExternalProject_Add( cub PREFIX cub URL ${CUB_URL} URL_MD5 97196a885598e40592100e1caaf3d5ea CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "") add_copy_headers_target( NAME cub SRC ${CUB_BUILD_INCLUDE} DST ${CUB_INCLUDE_DIR}/cub DEPS cub INDEX_FILE "${oneflow_cmake_dir}/third_party/header_index/cub_headers.txt") endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/cutlass.cmake ================================================ include(ExternalProject) if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") set(WITH_CUTLASS_INIT OFF) else() set(WITH_CUTLASS_INIT ON) endif() set(WITH_CUTLASS ${WITH_CUTLASS_INIT} CACHE BOOL "") if(WITH_CUTLASS) add_definitions(-DWITH_CUTLASS) find_package(Threads) set(CUTLASS_PROJECT cutlass) set(CUTLASS_INSTALL_DIR ${THIRD_PARTY_DIR}/cutlass) set(CUTLASS_INCLUDE_DIR ${CUTLASS_INSTALL_DIR}/include CACHE PATH "" FORCE) set(CUTLASS_LIBRARY_DIR ${CUTLASS_INSTALL_DIR}/lib CACHE PATH "" FORCE) set(CUTLASS_LIBRARIES ${CUTLASS_LIBRARY_DIR}/libcutlass.so) set(CUTLASS_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cutlass/src/cutlass/) foreach(arch ${CUDA_REAL_ARCHS_LIST}) if(arch GREATER_EQUAL 70) list(APPEND CUTLASS_REAL_ARCHS ${arch}) endif() endforeach() if(THIRD_PARTY) ExternalProject_Add( ${CUTLASS_PROJECT} PREFIX cutlass URL ${CUTLASS_URL} URL_MD5 ${CUTLASS_MD5} UPDATE_COMMAND "" BUILD_BYPRODUCTS ${CUTLASS_LIBRARIES} CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} CMAKE_CACHE_ARGS -DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE} -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_INSTALL_PREFIX:PATH=${CUTLASS_INSTALL_DIR} -DCMAKE_INSTALL_LIBDIR:PATH=${CUTLASS_LIBRARY_DIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCUTLASS_LIBRARY_OPERATIONS:STRING=conv2d -DCUTLASS_LIBRARY_KERNELS:STRING=simt_hfprop_*,tensorop_f16_*fprop,tensorop_h*fprop -DCUTLASS_ENABLE_EXAMPLES:BOOL=OFF -DCUTLASS_ENABLE_PROFILER:BOOL=OFF -DCUTLASS_ENABLE_LIBRARY:BOOL=ON -DCUTLASS_NVCC_ARCHS:STRING=${CUTLASS_REAL_ARCHS} -DCUTLASS_ENABLE_TESTS:BOOL=OFF -DCUTLASS_UNITY_BUILD_ENABLED:BOOL=ON -DCUTLASS_LIBRARY_DEBUG_POSTFIX:STRING= -DCUTLASS_NVCC_EMBED_PTX:BOOL=OFF) add_custom_target(cutlass_copy_examples_to_destination DEPENDS cutlass) set(CUTLASS_SOURCE_EXAMPLES_DIR ${CUTLASS_SOURCE_DIR}/examples) set(CUTLASS_INSTALL_EXAMPLES_FILES "45_dual_gemm/test_run.h" "45_dual_gemm/kernel/dual_gemm.h" "45_dual_gemm/device/dual_gemm.h" "45_dual_gemm/dual_gemm_run.h" "45_dual_gemm/thread/left_silu_and_mul.h" "45_dual_gemm/threadblock/dual_mma_multistage.h" "45_dual_gemm/threadblock/dual_epilogue.h" "45_dual_gemm/threadblock/dual_mma_base.h" 
"xformers_fmha/gemm_kernel_utils.h" "xformers_fmha/gemm/find_default_mma.h" "xformers_fmha/gemm/mma_accum_lambda_iterator.h" "xformers_fmha/gemm/custom_mma_multistage.h" "xformers_fmha/gemm/mma_from_smem.h" "xformers_fmha/gemm/custom_mma.h" "xformers_fmha/gemm/custom_mma_base.h" "xformers_fmha/gemm/custom_mma_pipelined.h" "xformers_fmha/epilogue/epilogue_thread_apply_logsumexp.h" "xformers_fmha/epilogue/epilogue_rescale_output.h" "xformers_fmha/epilogue/epilogue_pipelined.h" "xformers_fmha/debug_utils.h" "xformers_fmha/kernel_forward.h" "xformers_fmha/pytorch_utils.h" "xformers_fmha/transform/tile_smem_loader.h" "xformers_fmha/autogen/cutlassB.h" "xformers_fmha/autogen/cutlassF.h" "xformers_fmha/iterators/make_residual_last.h" "xformers_fmha/iterators/predicated_tile_iterator_residual_last.h" "xformers_fmha/iterators/epilogue_predicated_tile_iterator.h" "xformers_fmha/iterators/transpose_warp_iterator.h" "xformers_fmha/iterators/warp_iterator_from_smem.h" "xformers_fmha/iterators/predicated_tile_access_iterator_residual_last.h" "xformers_fmha/kernel_backward.h") foreach(filename ${CUTLASS_INSTALL_EXAMPLES_FILES}) add_custom_command( TARGET cutlass_copy_examples_to_destination COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CUTLASS_SOURCE_EXAMPLES_DIR}/${filename} ${CUTLASS_INSTALL_DIR}/examples/${filename}) endforeach() endif(THIRD_PARTY) endif(WITH_CUTLASS) ================================================ FILE: cmake/third_party/eigen.cmake ================================================ include(ExternalProject) set(EIGEN_INCLUDE_DIR ${THIRD_PARTY_DIR}/eigen/include/eigen3) set(EIGEN_INSTALL_DIR ${THIRD_PARTY_DIR}/eigen) set(EIGEN_URL https://github.com/Oneflow-Inc/eigen-git-mirror/archive/refs/tags/e9e95489a.tar.gz) set(EIGEN_MD5 a23cb70e12d1bf9b09cb28af51bc26ae) use_mirror(VARIABLE EIGEN_URL URL ${EIGEN_URL}) if(BUILD_CUDA) add_definitions(-DEIGEN_USE_GPU) endif() if(THIRD_PARTY) ExternalProject_Add( eigen PREFIX eigen URL ${EIGEN_URL} URL_MD5 ${EIGEN_MD5} UPDATE_COMMAND "" INSTALL_DIR "${EIGEN_INSTALL_DIR}" CMAKE_CACHE_ARGS -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${EIGEN_INSTALL_DIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} -DBUILD_TESTING:BOOL=OFF) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/flash_attention.cmake ================================================ include(ExternalProject) find_package(Threads) # NOTE: A git version of 1.6.5 or later is required if this download method is used. 
find_package(Git QUIET REQUIRED) set(FLASH_ATTENTION_PROJECT flash_attention) set(FLASH_ATTENTION_URL https://oneflow-static.oss-cn-beijing.aliyuncs.com/third_party_mirror/flash-attention-v2-eed2e82b880e06237af3e50ceac4cf6728b15645.zip ) set(FLASH_ATTENTION_INSTALL_DIR ${THIRD_PARTY_DIR}/flash_attention) set(FLASH_ATTENTION_INCLUDE_DIR ${FLASH_ATTENTION_INSTALL_DIR}/include CACHE PATH "" FORCE) set(FLASH_ATTENTION_LIBRARY_DIR ${FLASH_ATTENTION_INSTALL_DIR}/lib CACHE PATH "" FORCE) set(FLASH_ATTENTION_LIBRARIES ${FLASH_ATTENTION_LIBRARY_DIR}/libflash_attention.so) if(THIRD_PARTY) ExternalProject_Add( ${FLASH_ATTENTION_PROJECT} PREFIX flash_attention URL ${FLASH_ATTENTION_URL} URL_HASH MD5=63192a05973f614aff594a8bd11813ce UPDATE_COMMAND "" BUILD_BYPRODUCTS ${FLASH_ATTENTION_LIBRARIES} CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_CUDA_ARCHITECTURES:STRING=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CACHE_ARGS -DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE} -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_INSTALL_PREFIX:PATH=${FLASH_ATTENTION_INSTALL_DIR} -DCMAKE_INSTALL_LIBDIR:PATH=${FLASH_ATTENTION_LIBRARY_DIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/flatbuffers.cmake ================================================ include(ExternalProject) set(FLATBUFFERS_URL https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz) set(FLATBUFFERS_INSTALL_PREFIX ${THIRD_PARTY_DIR}/flatbuffers) set(FLATBUFFERS_INSTALL_INCLUDEDIR include) set(FLATBUFFERS_INSTALL_LIBDIR lib) set(FLATBUFFERS_INSTALL_BINDIR bin) use_mirror(VARIABLE FLATBUFFERS_URL URL ${FLATBUFFERS_URL}) set(FLATBUFFERS_INCLUDE_DIR ${FLATBUFFERS_INSTALL_PREFIX}/${FLATBUFFERS_INSTALL_INCLUDEDIR}) set(FLATBUFFERS_LIBRARY_DIR ${FLATBUFFERS_INSTALL_PREFIX}/${FLATBUFFERS_INSTALL_LIBDIR}) set(FLATBUFFERS_BINARY_DIR ${FLATBUFFERS_INSTALL_PREFIX}/${FLATBUFFERS_INSTALL_BINDIR}) set(FLATC_EXECUTABLE_NAME flatc) set(FLATBUFFERS_FLATC_EXECUTABLE ${FLATBUFFERS_BINARY_DIR}/${FLATC_EXECUTABLE_NAME}) set(FLATBUFFERS_LIBRARY_NAMES libflatbuffers.a) foreach(LIBRARY_NAME ${FLATBUFFERS_LIBRARY_NAMES}) list(APPEND FLATBUFFERS_STATIC_LIBRARIES ${FLATBUFFERS_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() if(THIRD_PARTY) ExternalProject_Add( flatbuffers PREFIX flatbuffers URL ${FLATBUFFERS_URL} URL_MD5 c62ffefb3d4548b127cca14ce047f16c UPDATE_COMMAND bash -c "rm -f BUILD || true" BUILD_IN_SOURCE 1 SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers/src/flatbuffers BUILD_BYPRODUCTS ${FLATBUFFERS_STATIC_LIBRARIES} CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${FLATBUFFERS_INSTALL_PREFIX} -DCMAKE_INSTALL_INCLUDEDIR=${FLATBUFFERS_INSTALL_INCLUDEDIR} -DCMAKE_INSTALL_LIBDIR=${FLATBUFFERS_INSTALL_LIBDIR} -DCMAKE_INSTALL_BINDIR=${FLATBUFFERS_INSTALL_BINDIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DFLATBUFFERS_BUILD_TESTS=OFF) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/glog.cmake ================================================ include(ExternalProject) set_mirror_url_with_hash(glog_URL https://github.com/google/glog/archive/refs/tags/v0.5.0.tar.gz 
2368e3e0a95cce8b5b35a133271b480f) include(FetchContent) FetchContent_Declare(glog URL ${glog_URL} URL_HASH MD5=${glog_URL_HASH}) set(WITH_GFLAGS OFF CACHE BOOL "") set(BUILD_SHARED_LIBS OFF CACHE BOOL "") set(WITH_GTEST OFF CACHE BOOL "") FetchContent_MakeAvailable(glog) # just for tensorflow, DO NOT USE IN OTHER PLACE FetchContent_GetProperties(glog) set(GLOG_INCLUDE_DIR ${glog_BINARY_DIR}) ================================================ FILE: cmake/third_party/googletest.cmake ================================================ include(FetchContent) set_mirror_url_with_hash( googletest_URL https://github.com/google/googletest/archive/release-1.11.0.tar.gz e8a8df240b6938bb6384155d4c37d937) FetchContent_Declare(googletest URL ${googletest_URL} URL_HASH MD5=${googletest_URL_HASH}) FetchContent_MakeAvailable(googletest) ================================================ FILE: cmake/third_party/grpc.cmake ================================================ include(ExternalProject) set(GRPC_INSTALL_DIR ${THIRD_PARTY_DIR}/grpc) set(GRPC_INSTALL_INCLUDE_DIR include) set(GRPC_INSTALL_LIBRARY_DIR lib) set(GRPC_INCLUDE_DIR ${THIRD_PARTY_DIR}/grpc/${GRPC_INSTALL_INCLUDE_DIR}) set(GRPC_LIBRARY_DIR ${THIRD_PARTY_DIR}/grpc/${GRPC_INSTALL_LIBRARY_DIR}) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_TAR_URL https://github.com/grpc/grpc/archive/v1.27.3.tar.gz) use_mirror(VARIABLE GRPC_TAR_URL URL ${GRPC_TAR_URL}) set(GRPC_URL_HASH 0c6c3fc8682d4262dd0e5e6fabe1a7e2) set(GRPC_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/grpc) if(WIN32) set(GRPC_LIBRARY_NAMES grpc++_unsecure.lib grpc_unsecure.lib gpr.lib upb.lib address_sorting.lib cares.lib) elseif(APPLE AND ("${CMAKE_GENERATOR}" STREQUAL "Xcode")) set(GRPC_LIBRARY_NAMES libgrpc++_unsecure.a libgrpc_unsecure.a libgpr.a libupb.a libaddress_sorting.a libcares.a) else() include(GNUInstallDirs) set(GRPC_LIBRARY_NAMES libgrpc++_unsecure.a libgrpc_unsecure.a libgpr.a libupb.a libaddress_sorting.a libcares.a) endif() foreach(LIBRARY_NAME ${GRPC_LIBRARY_NAMES}) list(APPEND GRPC_STATIC_LIBRARIES ${GRPC_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() set(PROTOBUF_CONFIG_DIR ${PROTOBUF_LIBRARY_DIR}/cmake/protobuf) set(ABSL_CONFIG_DIR ${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR}/cmake/absl) if(THIRD_PARTY) include(ProcessorCount) ProcessorCount(PROC_NUM) ExternalProject_Add( grpc PREFIX ${GRPC_SOURCE_DIR} DEPENDS protobuf absl cares openssl zlib URL ${GRPC_TAR_URL} URL_HASH MD5=${GRPC_URL_HASH} UPDATE_COMMAND "" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${GRPC_STATIC_LIBRARIES} BUILD_COMMAND ${CMAKE_COMMAND} --build . -j ${PROC_NUM} --target grpc && ${CMAKE_COMMAND} --build . -j ${PROC_NUM} --target grpc_unsecure && ${CMAKE_COMMAND} --build . 
-j ${PROC_NUM} --target grpc++_unsecure CMAKE_CACHE_ARGS -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_C_FLAGS_DEBUG:STRING=${CMAKE_C_FLAGS_DEBUG} -DCMAKE_C_FLAGS_RELEASE:STRING=${CMAKE_C_FLAGS_RELEASE} -DCMAKE_CXX_STANDARD:STRING=${CMAKE_CXX_STANDARD} -DgRPC_INSTALL:BOOL=ON -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DgRPC_BUILD_TESTS:BOOL=OFF -DgRPC_BUILD_GRPC_CPP_PLUGIN:BOOL=ON -DgRPC_BUILD_GRPC_CSHARP_PLUGIN:BOOL=OFF -DgRPC_BUILD_GRPC_NODE_PLUGIN:BOOL=OFF -DgRPC_BUILD_GRPC_OBJECTIVE_C_PLUGIN:BOOL=OFF -DgRPC_BUILD_GRPC_PHP_PLUGIN:BOOL=OFF -DgRPC_BUILD_GRPC_PYTHON_PLUGIN:BOOL=OFF -DgRPC_BUILD_GRPC_RUBY_PLUGIN:BOOL=OFF -DgRPC_ABSL_PROVIDER:STRING=package -Dabsl_DIR:PATH=${ABSL_CONFIG_DIR} -DgRPC_PROTOBUF_PROVIDER:STRING=package -DgRPC_PROTOBUF_PACKAGE_TYPE:STRING=CONFIG -DProtobuf_ROOT:STRING=${PROTOBUF_INSTALL_DIR} -DProtobuf_DIR:PATH=${PROTOBUF_CONFIG_DIR} -DgRPC_CARES_PROVIDER:STRING=module -DCARES_ROOT_DIR:PATH=${CARES_SOURCE_DIR} -DgRPC_ZLIB_PROVIDER:STRING=package -DZLIB_ROOT:PATH=${ZLIB_INSTALL} -DgRPC_SSL_PROVIDER:STRING=package -DOpenSSL_ROOT:PATH=${OPENSSL_INSTALL} -DCMAKE_INSTALL_PREFIX:STRING=${GRPC_INSTALL_DIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/half.cmake ================================================ include(ExternalProject) set(HALF_INCLUDE_DIR ${THIRD_PARTY_DIR}/half/include) set(HALF_URL https://github.com/Oneflow-Inc/half/archive/refs/tags/v2.1.0-fix-cuda-raise.zip) use_mirror(VARIABLE HALF_URL URL ${HALF_URL}) set(HALF_BASE_DIR ${CMAKE_CURRENT_BINARY_DIR}/half/src/half) set(HALF_URL_HASH 30b0dc289729f9e85ddf6995f2e6968f) set(HALF_HEADERS "${HALF_BASE_DIR}/include/half.hpp") if(THIRD_PARTY) ExternalProject_Add( half PREFIX half URL ${HALF_URL} URL_HASH MD5=${HALF_URL_HASH} UPDATE_COMMAND "" CONFIGURE_COMMAND "" BUILD_COMMAND "" BUILD_IN_SOURCE 1 INSTALL_COMMAND "") add_custom_target(half_create_header_dir COMMAND ${CMAKE_COMMAND} -E make_directory ${HALF_INCLUDE_DIR} DEPENDS half) add_custom_target(half_copy_headers_to_destination DEPENDS half_create_header_dir) foreach(header_file ${HALF_HEADERS}) add_custom_command( TARGET half_copy_headers_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${HALF_INCLUDE_DIR}) endforeach() endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/header_index/cub_headers.txt ================================================ config.cuh cub.cuh util_allocator.cuh util_arch.cuh util_compiler.cuh util_cpp_dialect.cuh util_debug.cuh util_deprecated.cuh util_device.cuh util_macro.cuh util_namespace.cuh util_ptx.cuh util_type.cuh version.cuh agent/agent_histogram.cuh agent/agent_radix_sort_downsweep.cuh agent/agent_radix_sort_histogram.cuh agent/agent_radix_sort_onesweep.cuh agent/agent_radix_sort_upsweep.cuh agent/agent_reduce.cuh agent/agent_reduce_by_key.cuh agent/agent_rle.cuh agent/agent_scan.cuh agent/agent_segment_fixup.cuh agent/agent_select_if.cuh agent/agent_spmv_orig.cuh agent/single_pass_scan_operators.cuh block/block_adjacent_difference.cuh block/block_discontinuity.cuh block/block_exchange.cuh block/block_histogram.cuh block/block_load.cuh 
block/block_radix_rank.cuh block/block_radix_sort.cuh block/block_raking_layout.cuh block/block_reduce.cuh block/block_scan.cuh block/block_shuffle.cuh block/block_store.cuh block/radix_rank_sort_operations.cuh block/specializations/block_histogram_atomic.cuh block/specializations/block_histogram_sort.cuh block/specializations/block_reduce_raking.cuh block/specializations/block_reduce_raking_commutative_only.cuh block/specializations/block_reduce_warp_reductions.cuh block/specializations/block_scan_raking.cuh block/specializations/block_scan_warp_scans.cuh block/specializations/block_scan_warp_scans2.cuh block/specializations/block_scan_warp_scans3.cuh device/device_histogram.cuh device/device_partition.cuh device/device_radix_sort.cuh device/device_reduce.cuh device/device_run_length_encode.cuh device/device_scan.cuh device/device_segmented_radix_sort.cuh device/device_segmented_reduce.cuh device/device_select.cuh device/device_spmv.cuh device/dispatch/dispatch_histogram.cuh device/dispatch/dispatch_radix_sort.cuh device/dispatch/dispatch_reduce.cuh device/dispatch/dispatch_reduce_by_key.cuh device/dispatch/dispatch_rle.cuh device/dispatch/dispatch_scan.cuh device/dispatch/dispatch_select_if.cuh device/dispatch/dispatch_spmv_orig.cuh grid/grid_barrier.cuh grid/grid_even_share.cuh grid/grid_mapping.cuh grid/grid_queue.cuh host/mutex.cuh iterator/arg_index_input_iterator.cuh iterator/cache_modified_input_iterator.cuh iterator/cache_modified_output_iterator.cuh iterator/constant_input_iterator.cuh iterator/counting_input_iterator.cuh iterator/discard_output_iterator.cuh iterator/tex_obj_input_iterator.cuh iterator/tex_ref_input_iterator.cuh iterator/transform_input_iterator.cuh thread/thread_load.cuh thread/thread_operators.cuh thread/thread_reduce.cuh thread/thread_scan.cuh thread/thread_search.cuh thread/thread_store.cuh warp/warp_reduce.cuh warp/warp_scan.cuh warp/specializations/warp_reduce_shfl.cuh warp/specializations/warp_reduce_smem.cuh warp/specializations/warp_scan_shfl.cuh warp/specializations/warp_scan_smem.cuh ================================================ FILE: cmake/third_party/header_index/grpc_headers.txt ================================================ grpc++/alarm.h grpc++/channel.h grpc++/client_context.h grpc++/completion_queue.h grpc++/create_channel.h grpc++/create_channel_posix.h grpc++/grpc++.h grpc++/health_check_service_interface.h grpc++/resource_quota.h grpc++/server.h grpc++/server_builder.h grpc++/server_context.h grpc++/server_posix.h grpc++/ext/health_check_service_server_builder_option.h grpc++/ext/proto_server_reflection_plugin.h grpc++/generic/async_generic_service.h grpc++/generic/generic_stub.h grpc++/impl/call.h grpc++/impl/channel_argument_option.h grpc++/impl/client_unary_call.h grpc++/impl/grpc_library.h grpc++/impl/method_handler_impl.h grpc++/impl/rpc_method.h grpc++/impl/rpc_service_method.h grpc++/impl/serialization_traits.h grpc++/impl/server_builder_option.h grpc++/impl/server_builder_plugin.h grpc++/impl/server_initializer.h grpc++/impl/service_type.h grpc++/impl/sync_cxx11.h grpc++/impl/sync_no_cxx11.h grpc++/impl/codegen/async_stream.h grpc++/impl/codegen/async_unary_call.h grpc++/impl/codegen/byte_buffer.h grpc++/impl/codegen/call.h grpc++/impl/codegen/call_hook.h grpc++/impl/codegen/channel_interface.h grpc++/impl/codegen/client_context.h grpc++/impl/codegen/client_unary_call.h grpc++/impl/codegen/completion_queue.h grpc++/impl/codegen/completion_queue_tag.h grpc++/impl/codegen/config.h grpc++/impl/codegen/config_protobuf.h 
grpc++/impl/codegen/core_codegen.h grpc++/impl/codegen/core_codegen_interface.h grpc++/impl/codegen/create_auth_context.h grpc++/impl/codegen/grpc_library.h grpc++/impl/codegen/metadata_map.h grpc++/impl/codegen/method_handler_impl.h grpc++/impl/codegen/proto_utils.h grpc++/impl/codegen/rpc_method.h grpc++/impl/codegen/rpc_service_method.h grpc++/impl/codegen/serialization_traits.h grpc++/impl/codegen/server_context.h grpc++/impl/codegen/server_interface.h grpc++/impl/codegen/service_type.h grpc++/impl/codegen/slice.h grpc++/impl/codegen/status.h grpc++/impl/codegen/status_code_enum.h grpc++/impl/codegen/string_ref.h grpc++/impl/codegen/stub_options.h grpc++/impl/codegen/sync_stream.h grpc++/impl/codegen/time.h grpc++/impl/codegen/security/auth_context.h grpc++/security/auth_context.h grpc++/security/auth_metadata_processor.h grpc++/security/credentials.h grpc++/security/server_credentials.h grpc++/support/async_stream.h grpc++/support/async_unary_call.h grpc++/support/byte_buffer.h grpc++/support/channel_arguments.h grpc++/support/config.h grpc++/support/error_details.h grpc++/support/slice.h grpc++/support/status.h grpc++/support/status_code_enum.h grpc++/support/string_ref.h grpc++/support/stub_options.h grpc++/support/sync_stream.h grpc++/support/time.h grpc++/test/mock_stream.h grpc++/test/server_context_test_spouse.h grpc/byte_buffer.h grpc/byte_buffer_reader.h grpc/census.h grpc/compression.h grpc/fork.h grpc/grpc.h grpc/grpc_cronet.h grpc/grpc_posix.h grpc/grpc_security.h grpc/grpc_security_constants.h grpc/load_reporting.h grpc/slice.h grpc/slice_buffer.h grpc/status.h grpc/impl/codegen/atm.h grpc/impl/codegen/atm_gcc_atomic.h grpc/impl/codegen/atm_gcc_sync.h grpc/impl/codegen/atm_windows.h grpc/impl/codegen/byte_buffer.h grpc/impl/codegen/byte_buffer_reader.h grpc/impl/codegen/compression_types.h grpc/impl/codegen/connectivity_state.h grpc/impl/codegen/fork.h grpc/impl/codegen/gpr_slice.h grpc/impl/codegen/gpr_types.h grpc/impl/codegen/grpc_types.h grpc/impl/codegen/log.h grpc/impl/codegen/port_platform.h grpc/impl/codegen/propagation_bits.h grpc/impl/codegen/slice.h grpc/impl/codegen/status.h grpc/impl/codegen/sync.h grpc/impl/codegen/sync_custom.h grpc/impl/codegen/sync_generic.h grpc/impl/codegen/sync_posix.h grpc/impl/codegen/sync_windows.h grpc/support/alloc.h grpc/support/atm.h grpc/support/atm_gcc_atomic.h grpc/support/atm_gcc_sync.h grpc/support/atm_windows.h grpc/support/cpu.h grpc/support/log.h grpc/support/log_windows.h grpc/support/port_platform.h grpc/support/string_util.h grpc/support/sync.h grpc/support/sync_custom.h grpc/support/sync_generic.h grpc/support/sync_posix.h grpc/support/sync_windows.h grpc/support/thd_id.h grpc/support/time.h grpc/support/workaround_list.h grpcpp/alarm.h grpcpp/alarm_impl.h grpcpp/channel.h grpcpp/channel_impl.h grpcpp/client_context.h grpcpp/completion_queue.h grpcpp/completion_queue_impl.h grpcpp/create_channel.h grpcpp/create_channel_impl.h grpcpp/create_channel_posix.h grpcpp/create_channel_posix_impl.h grpcpp/grpcpp.h grpcpp/health_check_service_interface.h grpcpp/health_check_service_interface_impl.h grpcpp/opencensus.h grpcpp/opencensus_impl.h grpcpp/resource_quota.h grpcpp/resource_quota_impl.h grpcpp/server.h grpcpp/server_builder.h grpcpp/server_builder_impl.h grpcpp/server_context.h grpcpp/server_impl.h grpcpp/server_posix.h grpcpp/server_posix_impl.h grpcpp/ext/channelz_service_plugin.h grpcpp/ext/channelz_service_plugin_impl.h grpcpp/ext/health_check_service_server_builder_option.h 
grpcpp/ext/proto_server_reflection_plugin.h grpcpp/ext/proto_server_reflection_plugin_impl.h grpcpp/ext/server_load_reporting.h grpcpp/ext/server_load_reporting_impl.h grpcpp/generic/async_generic_service.h grpcpp/generic/generic_stub.h grpcpp/generic/generic_stub_impl.h grpcpp/impl/call.h grpcpp/impl/channel_argument_option.h grpcpp/impl/client_unary_call.h grpcpp/impl/grpc_library.h grpcpp/impl/method_handler_impl.h grpcpp/impl/rpc_method.h grpcpp/impl/rpc_service_method.h grpcpp/impl/serialization_traits.h grpcpp/impl/server_builder_option.h grpcpp/impl/server_builder_option_impl.h grpcpp/impl/server_builder_plugin.h grpcpp/impl/server_initializer.h grpcpp/impl/server_initializer_impl.h grpcpp/impl/service_type.h grpcpp/impl/sync_cxx11.h grpcpp/impl/sync_no_cxx11.h grpcpp/impl/codegen/async_generic_service.h grpcpp/impl/codegen/async_stream.h grpcpp/impl/codegen/async_stream_impl.h grpcpp/impl/codegen/async_unary_call.h grpcpp/impl/codegen/async_unary_call_impl.h grpcpp/impl/codegen/byte_buffer.h grpcpp/impl/codegen/call.h grpcpp/impl/codegen/call_hook.h grpcpp/impl/codegen/call_op_set.h grpcpp/impl/codegen/call_op_set_interface.h grpcpp/impl/codegen/callback_common.h grpcpp/impl/codegen/channel_interface.h grpcpp/impl/codegen/client_callback.h grpcpp/impl/codegen/client_callback_impl.h grpcpp/impl/codegen/client_context.h grpcpp/impl/codegen/client_context_impl.h grpcpp/impl/codegen/client_interceptor.h grpcpp/impl/codegen/client_unary_call.h grpcpp/impl/codegen/completion_queue.h grpcpp/impl/codegen/completion_queue_impl.h grpcpp/impl/codegen/completion_queue_tag.h grpcpp/impl/codegen/config.h grpcpp/impl/codegen/config_protobuf.h grpcpp/impl/codegen/core_codegen.h grpcpp/impl/codegen/core_codegen_interface.h grpcpp/impl/codegen/create_auth_context.h grpcpp/impl/codegen/delegating_channel.h grpcpp/impl/codegen/grpc_library.h grpcpp/impl/codegen/intercepted_channel.h grpcpp/impl/codegen/interceptor.h grpcpp/impl/codegen/interceptor_common.h grpcpp/impl/codegen/message_allocator.h grpcpp/impl/codegen/metadata_map.h grpcpp/impl/codegen/method_handler.h grpcpp/impl/codegen/method_handler_impl.h grpcpp/impl/codegen/proto_buffer_reader.h grpcpp/impl/codegen/proto_buffer_writer.h grpcpp/impl/codegen/proto_utils.h grpcpp/impl/codegen/rpc_method.h grpcpp/impl/codegen/rpc_service_method.h grpcpp/impl/codegen/serialization_traits.h grpcpp/impl/codegen/server_callback.h grpcpp/impl/codegen/server_callback_handlers.h grpcpp/impl/codegen/server_callback_impl.h grpcpp/impl/codegen/server_context.h grpcpp/impl/codegen/server_context_impl.h grpcpp/impl/codegen/server_interceptor.h grpcpp/impl/codegen/server_interface.h grpcpp/impl/codegen/service_type.h grpcpp/impl/codegen/slice.h grpcpp/impl/codegen/status.h grpcpp/impl/codegen/status_code_enum.h grpcpp/impl/codegen/string_ref.h grpcpp/impl/codegen/stub_options.h grpcpp/impl/codegen/sync.h grpcpp/impl/codegen/sync_stream.h grpcpp/impl/codegen/sync_stream_impl.h grpcpp/impl/codegen/time.h grpcpp/impl/codegen/security/auth_context.h grpcpp/security/alts_context.h grpcpp/security/alts_util.h grpcpp/security/auth_context.h grpcpp/security/auth_metadata_processor.h grpcpp/security/auth_metadata_processor_impl.h grpcpp/security/credentials.h grpcpp/security/credentials_impl.h grpcpp/security/cronet_credentials.h grpcpp/security/cronet_credentials_impl.h grpcpp/security/server_credentials.h grpcpp/security/server_credentials_impl.h grpcpp/security/tls_credentials_options.h grpcpp/support/async_stream.h grpcpp/support/async_stream_impl.h 
grpcpp/support/async_unary_call.h grpcpp/support/async_unary_call_impl.h grpcpp/support/byte_buffer.h grpcpp/support/channel_arguments.h grpcpp/support/channel_arguments_impl.h grpcpp/support/client_callback.h grpcpp/support/client_callback_impl.h grpcpp/support/client_interceptor.h grpcpp/support/config.h grpcpp/support/error_details.h grpcpp/support/error_details_impl.h grpcpp/support/interceptor.h grpcpp/support/message_allocator.h grpcpp/support/method_handler.h grpcpp/support/proto_buffer_reader.h grpcpp/support/proto_buffer_writer.h grpcpp/support/server_callback.h grpcpp/support/server_callback_impl.h grpcpp/support/server_interceptor.h grpcpp/support/slice.h grpcpp/support/status.h grpcpp/support/status_code_enum.h grpcpp/support/string_ref.h grpcpp/support/stub_options.h grpcpp/support/sync_stream.h grpcpp/support/sync_stream_impl.h grpcpp/support/time.h grpcpp/support/validate_service_config.h grpcpp/test/default_reactor_test_peer.h grpcpp/test/mock_stream.h grpcpp/test/server_context_test_spouse.h ================================================ FILE: cmake/third_party/header_index/libpng_headers.txt ================================================ png.h pngconf.h pngdebug.h pnginfo.h pnglibconf.h pngpriv.h pngstruct.h ================================================ FILE: cmake/third_party/header_index/opencv_headers.txt ================================================ opencv2/cvconfig.h opencv2/core/cv_cpu_dispatch.h opencv2/core/types_c.h opencv2/core/cvdef.h opencv2/core/core_c.h opencv2/core/cv_cpu_helper.h opencv2/core/hal/interface.h opencv2/imgproc/imgproc_c.h opencv2/imgproc/types_c.h opencv2/imgproc/hal/interface.h opencv2/imgcodecs/ios.h opencv2/imgcodecs/imgcodecs_c.h opencv/cvwimage.h opencv/cxcore.h opencv/highgui.h opencv/cvaux.h opencv/ml.h opencv/cv.h opencv/cxmisc.h opencv2/opencv.hpp opencv2/imgproc.hpp opencv2/opencv_modules.hpp opencv2/imgcodecs.hpp opencv2/core.hpp opencv2/core/directx.hpp opencv2/core/fast_math.hpp opencv2/core/persistence.hpp opencv2/core/traits.hpp opencv2/core/mat.hpp opencv2/core/affine.hpp opencv2/core/cuda_stream_accessor.hpp opencv2/core/wimage.hpp opencv2/core/cvstd.hpp opencv2/core/base.hpp opencv2/core/optim.hpp opencv2/core/vsx_utils.hpp opencv2/core/va_intel.hpp opencv2/core/ocl.hpp opencv2/core/ptr.inl.hpp opencv2/core/saturate.hpp opencv2/core/neon_utils.hpp opencv2/core/cuda.inl.hpp opencv2/core/utility.hpp opencv2/core/opengl.hpp opencv2/core/eigen.hpp opencv2/core/cuda_types.hpp opencv2/core/cuda.hpp opencv2/core/mat.inl.hpp opencv2/core/operations.hpp opencv2/core/cvstd.inl.hpp opencv2/core/ovx.hpp opencv2/core/ippasync.hpp opencv2/core/bufferpool.hpp opencv2/core/matx.hpp opencv2/core/sse_utils.hpp opencv2/core/types.hpp opencv2/core/version.hpp opencv2/core/ocl_genbase.hpp opencv2/core/core.hpp opencv2/core/softfloat.hpp opencv2/core/hal/hal.hpp opencv2/core/hal/intrin_sse.hpp opencv2/core/hal/intrin_neon.hpp opencv2/core/hal/intrin_cpp.hpp opencv2/core/hal/intrin.hpp opencv2/core/hal/intrin_vsx.hpp opencv2/core/cuda/reduce.hpp opencv2/core/cuda/warp_shuffle.hpp opencv2/core/cuda/emulation.hpp opencv2/core/cuda/limits.hpp opencv2/core/cuda/warp_reduce.hpp opencv2/core/cuda/filters.hpp opencv2/core/cuda/vec_distance.hpp opencv2/core/cuda/scan.hpp opencv2/core/cuda/utility.hpp opencv2/core/cuda/type_traits.hpp opencv2/core/cuda/block.hpp opencv2/core/cuda/vec_traits.hpp opencv2/core/cuda/funcattrib.hpp opencv2/core/cuda/datamov_utils.hpp opencv2/core/cuda/vec_math.hpp opencv2/core/cuda/common.hpp 
opencv2/core/cuda/warp.hpp opencv2/core/cuda/color.hpp opencv2/core/cuda/border_interpolate.hpp opencv2/core/cuda/simd_functions.hpp opencv2/core/cuda/dynamic_smem.hpp opencv2/core/cuda/functional.hpp opencv2/core/cuda/saturate_cast.hpp opencv2/core/cuda/transform.hpp opencv2/core/cuda/detail/reduce.hpp opencv2/core/cuda/detail/reduce_key_val.hpp opencv2/core/cuda/detail/vec_distance_detail.hpp opencv2/core/cuda/detail/color_detail.hpp opencv2/core/cuda/detail/transform_detail.hpp opencv2/core/cuda/detail/type_traits_detail.hpp opencv2/core/utils/logger.hpp opencv2/core/utils/trace.hpp opencv2/imgproc/imgproc.hpp opencv2/imgproc/hal/hal.hpp opencv2/imgproc/detail/distortion_model.hpp opencv2/imgcodecs/imgcodecs.hpp opencv/cxcore.hpp opencv/cv.hpp opencv/cxeigen.hpp opencv/cvaux.hpp ================================================ FILE: cmake/third_party/hwloc.cmake ================================================ include(ExternalProject) if(UNIX AND NOT APPLE) set(BUILD_HWLOC_DEFAULT ON) else() set(BUILD_HWLOC_DEFAULT OFF) endif() option(BUILD_HWLOC "" ${BUILD_HWLOC_DEFAULT}) if(BUILD_HWLOC) set(PCIACCESS_INSTALL ${THIRD_PARTY_DIR}/pciaccess) set(PCIACCESS_INCLUDE_DIR ${PCIACCESS_INSTALL}/include) set(PCIACCESS_LIBRARY_DIR ${PCIACCESS_INSTALL}/lib) set(PCIACCESS_LIBRARY_NAMES libpciaccess.a) foreach(LIBRARY_NAME ${PCIACCESS_LIBRARY_NAMES}) list(APPEND PCIACCESS_STATIC_LIBRARIES ${PCIACCESS_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() set(HWLOC_INSTALL ${THIRD_PARTY_DIR}/hwloc) set(HWLOC_INCLUDE_DIR ${HWLOC_INSTALL}/include) set(HWLOC_LIBRARY_DIR ${HWLOC_INSTALL}/lib) set(HWLOC_LIBRARY_NAMES libhwloc.a) foreach(LIBRARY_NAME ${HWLOC_LIBRARY_NAMES}) list(APPEND ONEFLOW_HWLOC_STATIC_LIBRARIES ${HWLOC_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() if(THIRD_PARTY) include(ProcessorCount) ProcessorCount(PROC_NUM) set(XORG_MACROS_INSTALL ${THIRD_PARTY_DIR}/xorg-macros) set(XORG_MACROS_TAR_URL https://gitlab.freedesktop.org/xorg/util/macros/-/archive/util-macros-1.19.1/macros-util-macros-1.19.1.tar.gz ) use_mirror(VARIABLE XORG_MACROS_TAR_URL URL ${XORG_MACROS_TAR_URL}) set(XORG_MACROS_URL_HASH 764fb1647d7ebd1c8c5d707db525832f) set(XORG_MACROS_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/xorg-macros) set(XORG_MACROS_PKG_CONFIG_DIR ${XORG_MACROS_INSTALL}/share/pkgconfig) ExternalProject_Add( xorg-macros PREFIX xorg-macros URL ${XORG_MACROS_TAR_URL} URL_HASH MD5=${XORG_MACROS_URL_HASH} UPDATE_COMMAND "" CONFIGURE_COMMAND ${XORG_MACROS_SOURCE_DIR}/src/xorg-macros/autogen.sh COMMAND ${XORG_MACROS_SOURCE_DIR}/src/xorg-macros/configure --prefix=${XORG_MACROS_INSTALL} BUILD_COMMAND make -j${PROC_NUM} INSTALL_COMMAND make install) set(PCIACCESS_TAR_URL https://gitlab.freedesktop.org/xorg/lib/libpciaccess/-/archive/libpciaccess-0.16/libpciaccess-libpciaccess-0.16.tar.gz ) use_mirror(VARIABLE PCIACCESS_TAR_URL URL ${PCIACCESS_TAR_URL}) set(PCIACCESS_URL_HASH 93554c189796c27dfc72af17a367a0b4) set(PCIACCESS_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/pciaccess) set(PCIACCESS_CFLAGS "-O3 -fPIC") ExternalProject_Add( pciaccess PREFIX pciaccess URL ${PCIACCESS_TAR_URL} URL_HASH MD5=${PCIACCESS_URL_HASH} UPDATE_COMMAND "" PATCH_COMMAND cp ${XORG_MACROS_INSTALL}/share/aclocal/xorg-macros.m4 ${PCIACCESS_SOURCE_DIR}/src/pciaccess/m4 CONFIGURE_COMMAND ${PCIACCESS_SOURCE_DIR}/src/pciaccess/autogen.sh COMMAND ${PCIACCESS_SOURCE_DIR}/src/pciaccess/configure --prefix=${PCIACCESS_INSTALL} --enable-shared=no BUILD_COMMAND make -j${PROC_NUM} CFLAGS=${PCIACCESS_CFLAGS} BUILD_BYPRODUCTS ${PCIACCESS_STATIC_LIBRARIES} INSTALL_COMMAND make 
install DEPENDS xorg-macros) set(HWLOC_TAR_URL https://github.com/open-mpi/hwloc/archive/refs/tags/hwloc-2.4.1.tar.gz) use_mirror(VARIABLE HWLOC_TAR_URL URL ${HWLOC_TAR_URL}) set(HWLOC_URL_HASH ac25fc7c2a665b7914c6c21b782f1c4f) set(HWLOC_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/hwloc) set(HWLOC_CFLAGS "-O3 -fPIC") ExternalProject_Add( hwloc PREFIX hwloc URL ${HWLOC_TAR_URL} URL_HASH MD5=${HWLOC_URL_HASH} UPDATE_COMMAND "" CONFIGURE_COMMAND ${HWLOC_SOURCE_DIR}/src/hwloc/autogen.sh COMMAND ${HWLOC_SOURCE_DIR}/src/hwloc/configure --prefix=${HWLOC_INSTALL} PKG_CONFIG_PATH=${PCIACCESS_INSTALL}/lib/pkgconfig --disable-libxml2 --enable-static --enable-shared=no BUILD_COMMAND make -j${PROC_NUM} CFLAGS=${HWLOC_CFLAGS} BUILD_BYPRODUCTS ${ONEFLOW_HWLOC_STATIC_LIBRARIES} INSTALL_COMMAND make install DEPENDS pciaccess) endif(THIRD_PARTY) endif(BUILD_HWLOC) ================================================ FILE: cmake/third_party/json.cmake ================================================ include(FetchContent) set_mirror_url_with_hash(JSON_URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.zip 49097a7ec390ffaf1cd2e14b734b6c75) set(JSON_Install ON CACHE STRING "" FORCE) FetchContent_Declare(json URL ${JSON_URL} URL_HASH MD5=${JSON_URL_HASH}) FetchContent_MakeAvailable(json) ================================================ FILE: cmake/third_party/libjpeg-turbo.cmake ================================================ include(ExternalProject) set(LIBJPEG_INCLUDE_DIR ${THIRD_PARTY_DIR}/libjpeg-turbo/include) set(LIBJPEG_LIBRARY_DIR ${THIRD_PARTY_DIR}/libjpeg-turbo/lib) set(LIBJPEG_URL https://github.com/libjpeg-turbo/libjpeg-turbo/archive/refs/tags/2.1.3.tar.gz) use_mirror(VARIABLE LIBJPEG_URL URL ${LIBJPEG_URL}) if(WIN32) elseif(APPLE AND ("${CMAKE_GENERATOR}" STREQUAL "Xcode")) set(LIBJPEG_BUILD_SRC_DIR ${CMAKE_CURRENT_BINARY_DIR}/libjpeg-turbo/src/libjpeg-turbo) set(LIBJPEG_BUILD_LIBRARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/libjpeg-turbo/src/libjpeg-turbo/${CMAKE_BUILD_TYPE}) set(LIBJPEG_LIBRARY_NAMES libturbojpeg.a) else() set(LIBJPEG_BUILD_SRC_DIR ${CMAKE_CURRENT_BINARY_DIR}/libjpeg-turbo/src/libjpeg-turbo) set(LIBJPEG_BUILD_LIBRARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/libjpeg-turbo/src/libjpeg-turbo) set(LIBJPEG_LIBRARY_NAMES libturbojpeg.a) endif() foreach(LIBRARY_NAME ${LIBJPEG_LIBRARY_NAMES}) list(APPEND LIBJPEG_STATIC_LIBRARIES ${LIBJPEG_LIBRARY_DIR}/${LIBRARY_NAME}) list(APPEND LIBJPEG_BUILD_STATIC_LIBRARIES ${LIBJPEG_BUILD_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() set(LIBJPEG_HEADERS "${LIBJPEG_BUILD_SRC_DIR}/cderror.h" "${LIBJPEG_BUILD_SRC_DIR}/cdjpeg.h" "${LIBJPEG_BUILD_SRC_DIR}/cmyk.h" "${LIBJPEG_BUILD_SRC_DIR}/jchuff.h" "${LIBJPEG_BUILD_SRC_DIR}/jconfig.h" "${LIBJPEG_BUILD_SRC_DIR}/jdcoefct.h" "${LIBJPEG_BUILD_SRC_DIR}/jdct.h" "${LIBJPEG_BUILD_SRC_DIR}/jdhuff.h" "${LIBJPEG_BUILD_SRC_DIR}/jdmainct.h" "${LIBJPEG_BUILD_SRC_DIR}/jdmaster.h" "${LIBJPEG_BUILD_SRC_DIR}/jdsample.h" "${LIBJPEG_BUILD_SRC_DIR}/jerror.h" "${LIBJPEG_BUILD_SRC_DIR}/jinclude.h" "${LIBJPEG_BUILD_SRC_DIR}/jmemsys.h" "${LIBJPEG_BUILD_SRC_DIR}/jmorecfg.h" "${LIBJPEG_BUILD_SRC_DIR}/jpegcomp.h" "${LIBJPEG_BUILD_SRC_DIR}/jpegint.h" "${LIBJPEG_BUILD_SRC_DIR}/jpeglib.h" "${LIBJPEG_BUILD_SRC_DIR}/jpeg_nbits_table.h" "${LIBJPEG_BUILD_SRC_DIR}/jsimddct.h" "${LIBJPEG_BUILD_SRC_DIR}/jsimd.h" "${LIBJPEG_BUILD_SRC_DIR}/jversion.h" "${LIBJPEG_BUILD_SRC_DIR}/tjutil.h" "${LIBJPEG_BUILD_SRC_DIR}/transupp.h" "${LIBJPEG_BUILD_SRC_DIR}/turbojpeg.h") if(THIRD_PARTY) ExternalProject_Add( libjpeg-turbo PREFIX libjpeg-turbo URL ${LIBJPEG_URL} URL_MD5 
627b980fad0573e08e4c3b80b290fc91 UPDATE_COMMAND "" INSTALL_COMMAND "" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${LIBJPEG_STATIC_LIBRARIES} CMAKE_CACHE_ARGS -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON) # put libjpeg-turbo includes in the directory where they are expected add_custom_target(libjpeg_create_header_dir COMMAND ${CMAKE_COMMAND} -E make_directory ${LIBJPEG_INCLUDE_DIR} DEPENDS libjpeg-turbo) add_custom_target(libjpeg_copy_headers_to_destination DEPENDS libjpeg_create_header_dir) foreach(header_file ${LIBJPEG_HEADERS}) add_custom_command( TARGET libjpeg_copy_headers_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${LIBJPEG_INCLUDE_DIR}) endforeach() # pub libjpeg libs in the directory where they are expected add_custom_target(libjpeg_create_library_dir COMMAND ${CMAKE_COMMAND} -E make_directory ${LIBJPEG_LIBRARY_DIR} DEPENDS libjpeg-turbo) add_custom_target( libjpeg_copy_libs_to_destination COMMAND ${CMAKE_COMMAND} -E copy_if_different ${LIBJPEG_BUILD_STATIC_LIBRARIES} ${LIBJPEG_LIBRARY_DIR} DEPENDS libjpeg_create_library_dir) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/nccl.cmake ================================================ option(NCCL_STATIC "" ON) if(OF_CUDA_LINK_DYNAMIC_LIBRARY) set(NCCL_STATIC OFF) endif() option(USE_SYSTEM_NCCL "" OFF) set(NCCL_ROOT_DIR "" CACHE PATH "Folder contains NVIDIA NCCL") if(WIN32) set(NCCL_LIBRARY_NAME libnccl_static.lib) else() if(NCCL_STATIC) set(NCCL_LIBRARY_NAME libnccl_static.a) else() set(NCCL_LIBRARY_NAME libnccl.so) endif() endif() if(USE_SYSTEM_NCCL) include(FindPackageHandleStandardArgs) find_path(NCCL_INCLUDE_DIR nccl.h HINTS ${NCCL_ROOT_DIR} ${CUDAToolkit_INCLUDE_DIRS} PATH_SUFFIXES cuda/include include) unset(NCCL_LIBRARY CACHE) find_library( NCCL_LIBRARY ${NCCL_LIBRARY_NAME} HINTS ${NCCL_ROOT_DIR} ${CUDAToolkit_LIBRARY_DIR} ${CUDAToolkit_LIBRARY_ROOT} PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY) set(NCCL_LIBRARIES ${NCCL_LIBRARY}) add_custom_target(nccl) else() get_filename_component(CUDATOOLKIT_BIN_ROOT ${CUDAToolkit_BIN_DIR} DIRECTORY) include(ExternalProject) set(NCCL_INSTALL_DIR ${THIRD_PARTY_DIR}/nccl) set(NCCL_INCLUDE_DIR ${NCCL_INSTALL_DIR}/include) set(NCCL_LIBRARY_DIR ${NCCL_INSTALL_DIR}/lib) # Versions 2.13 and above may cause deadlocks if(CUDA_VERSION VERSION_GREATER_EQUAL "11.8") set(NCCL_URL https://github.com/NVIDIA/nccl/archive/refs/tags/v2.15.1-1.tar.gz) set(NCCL_MD5 37b787ff8934cd9374b4612f663c17fa) else() set(NCCL_URL https://github.com/NVIDIA/nccl/archive/refs/tags/v2.12.10-1.tar.gz) set(NCCL_MD5 bdb91f80b78c99831f09ca8bb28a1032) endif() use_mirror(VARIABLE NCCL_URL URL ${NCCL_URL}) list(APPEND NCCL_LIBRARIES ${NCCL_LIBRARY_DIR}/${NCCL_LIBRARY_NAME}) set(NCCL_ARCHS_LIST ${CUDA_REAL_ARCHS_LIST}) # remove redundant archs, https://github.com/NVIDIA/nccl/blob/cb111f764a6d46370f24f75101d6b219bb2dda54/makefiles/common.mk#L28 if("70" IN_LIST NCCL_ARCHS_LIST AND "75" IN_LIST NCCL_ARCHS_LIST) list(REMOVE_ITEM NCCL_ARCHS_LIST "75") endif() if("80" IN_LIST NCCL_ARCHS_LIST AND "86" IN_LIST NCCL_ARCHS_LIST) list(REMOVE_ITEM NCCL_ARCHS_LIST "86") endif() if("80" IN_LIST NCCL_ARCHS_LIST AND "89" IN_LIST NCCL_ARCHS_LIST) 
list(REMOVE_ITEM NCCL_ARCHS_LIST "89") endif() foreach(arch ${NCCL_ARCHS_LIST}) string(APPEND NCCL_GENCODE "-gencode=arch=compute_${arch},code=sm_${arch} ") endforeach() if(THIRD_PARTY) include(ProcessorCount) ProcessorCount(PROC_NUM) ExternalProject_Add( nccl PREFIX nccl URL ${NCCL_URL} URL_MD5 ${NCCL_MD5} UPDATE_COMMAND "" CONFIGURE_COMMAND "" BUILD_IN_SOURCE 1 BUILD_COMMAND make -j${PROC_NUM} src.build CUDA_HOME=${CUDATOOLKIT_BIN_ROOT} NVCC_GENCODE=${NCCL_GENCODE} INSTALL_COMMAND make src.install PREFIX=${NCCL_INSTALL_DIR} BUILD_BYPRODUCTS ${NCCL_LIBRARIES}) endif(THIRD_PARTY) endif() ================================================ FILE: cmake/third_party/oneDNN.cmake ================================================ include(ExternalProject) include(GNUInstallDirs) set(ONEDNN_INSTALL_DIR ${THIRD_PARTY_DIR}/onednn) set(ONEDNN_INCLUDE_DIR ${ONEDNN_INSTALL_DIR}/include) set(ONEDNN_LIBRARY_DIR ${ONEDNN_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}) set(ONEDNN_URL https://github.com/oneapi-src/oneDNN/archive/refs/tags/v2.4.3.tar.gz) use_mirror(VARIABLE ONEDNN_URL URL ${ONEDNN_URL}) if(WIN32) message(FATAL_ERROR "Windows system does not support onednn") else() if(BUILD_CPP_API) set(ONEDNN_BUILD_SHARED_LIBS OFF) else() set(ONEDNN_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) endif() if(ONEDNN_BUILD_SHARED_LIBS) if("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".dylib") set(ONEDNN_LIBRARY_NAMES libdnnl.dylib) elseif("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".so") set(ONEDNN_LIBRARY_NAMES libdnnl.so) set(DNNL_LIBRARY_TYPE SHARED) set(DNNL_LIBRARY_RPATH ON) else() message(FATAL_ERROR "${CMAKE_SHARED_LIBRARY_SUFFIX} not support for onednn") endif() else() set(ONEDNN_LIBRARY_NAMES libdnnl.a) set(DNNL_LIBRARY_TYPE STATIC) set(DNNL_LIBRARY_RPATH OFF) endif() endif() foreach(LIBRARY_NAME ${ONEDNN_LIBRARY_NAMES}) list(APPEND ONEDNN_STATIC_LIBRARIES ${ONEDNN_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() # the order of the following codes can't be changed set(ONEDNN_CPU_RUNTIME SEQ) if(WITH_OMP) set(ONEDNN_CPU_RUNTIME OMP) endif() if(WITH_TBB) set(ONEDNN_CPU_RUNTIME TBB) set(ONEDNN_DEPENDS install-tbb) endif() if(THIRD_PARTY) ExternalProject_Add( onednn PREFIX onednn DEPENDS ${ONEDNN_DEPENDS} URL ${ONEDNN_URL} URL_MD5 c60ea96acbaccec053be7e3fa81c6184 UPDATE_COMMAND "" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${ONEDNN_STATIC_LIBRARIES} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:STRING=${ONEDNN_INSTALL_DIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_C_FLAGS_DEBUG:STRING=${CMAKE_C_FLAGS_DEBUG} -DCMAKE_C_FLAGS_RELEASE:STRING=${CMAKE_C_FLAGS_RELEASE} -DDNNL_IS_MAIN_PROJECT:BOOL=OFF -DDNNL_BUILD_EXAMPLES:BOOL=OFF -DDNNL_BUILD_TESTS:BOOL=OFF -DDNNL_LIBRARY_TYPE:STRING=${DNNL_LIBRARY_TYPE} -DCMAKE_INSTALL_RPATH_USE_LINK_PATH:BOOL=${DNNL_LIBRARY_RPATH} -DCMAKE_INSTALL_RPATH:STRING=${ONETBB_INSTALL_DIR} -DDNNL_CPU_RUNTIME:STRING=${ONEDNN_CPU_RUNTIME} -DTBBROOT:STRING=${ONETBB_INSTALL_DIR} -DTBB_ROOT:STRING=${ONETBB_INSTALL_DIR}/lib/cmake/TBB) endif(THIRD_PARTY) add_library(onednn_imported UNKNOWN IMPORTED) set_property(TARGET onednn_imported PROPERTY IMPORTED_LOCATION "${ONEDNN_STATIC_LIBRARIES}") ================================================ FILE: cmake/third_party/opencv.cmake 
================================================ include(ExternalProject) include(GNUInstallDirs) set(OPENCV_INSTALL_DIR ${THIRD_PARTY_DIR}/opencv) set(OPENCV_INCLUDE_DIR ${OPENCV_INSTALL_DIR}/include) set(LIBPNG_INSTALL_DIR ${THIRD_PARTY_DIR}/libpng) set(LIBPNG_INCLUDE_DIR ${LIBPNG_INSTALL_DIR}/include) set(OPENCV_LIBRARY_DIR ${OPENCV_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}) set(OPENCV_3RDPARTY_LIBRARY_DIR ${OPENCV_INSTALL_DIR}/share/OpenCV/3rdparty/${CMAKE_INSTALL_LIBDIR}) set(OPENCV_SRC_DIR ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/src) set(OPENCV_URL https://github.com/opencv/opencv/archive/83391ac59d270f2148fc99a62ae279b04d37f5d0.tar.gz) use_mirror(VARIABLE OPENCV_URL URL ${OPENCV_URL}) set(OPENCV_LIBRARY_NAMES libopencv_imgcodecs.a libopencv_imgproc.a libopencv_core.a) set(OPENCV_3RDPARTY_LIBRARY_NAMES libIlmImf.a liblibjasper.a liblibpng.a liblibtiff.a liblibwebp.a) foreach(LIBRARY_NAME ${OPENCV_LIBRARY_NAMES}) list(APPEND OPENCV_STATIC_LIBRARIES ${OPENCV_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() foreach(LIBRARY_NAME ${OPENCV_3RDPARTY_LIBRARY_NAMES}) list(APPEND OPENCV_STATIC_LIBRARIES ${OPENCV_3RDPARTY_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() if(THIRD_PARTY) if(CMAKE_C_COMPILER_LAUNCHER STREQUAL "ccache") set(OPENCV_C_COMPILER_LAUNCHER_DEF "-DENABLE_CCACHE:BOOL=ON") else() set(OPENCV_C_COMPILER_LAUNCHER_DEF "-DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}") endif() if(CMAKE_CXX_COMPILER_LAUNCHER STREQUAL "ccache") set(OPENCV_CXX_COMPILER_LAUNCHER_DEF "-DENABLE_CCACHE:BOOL=ON") else() set(OPENCV_CXX_COMPILER_LAUNCHER_DEF "-DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}") endif() ExternalProject_Add( opencv DEPENDS libjpeg_copy_headers_to_destination libjpeg_copy_libs_to_destination PREFIX opencv URL ${OPENCV_URL} URL_MD5 b09dc79dec7766a3550907bcafc8bbf5 UPDATE_COMMAND "" PATCH_COMMAND cmake -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/build BUILD_IN_SOURCE 0 SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/build BUILD_BYPRODUCTS ${OPENCV_STATIC_LIBRARIES} CMAKE_CACHE_ARGS ${OPENCV_C_COMPILER_LAUNCHER_DEF} ${OPENCV_CXX_COMPILER_LAUNCHER_DEF} -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX:STRING=${OPENCV_INSTALL_DIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DWITH_IPP:BOOL=OFF -DWITH_1394:BOOL=OFF -DWITH_AVFOUNDATION:BOOL=OFF -DWITH_CAROTENE:BOOL=OFF -DWITH_CPUFEATURES:BOOL=OFF -DWITH_VTK:BOOL=OFF -DWITH_CUDA:BOOL=OFF -DWITH_CUFFT:BOOL=OFF -DWITH_CUBLAS:BOOL=OFF -DWITH_NVCUVID:BOOL=OFF -DWITH_EIGEN:BOOL=OFF -DWITH_VFW:BOOL=OFF -DWITH_FFMPEG:BOOL=OFF -DWITH_WEBP:BOOL=ON -DBUILD_WEBP:BOOL=ON -DWITH_GSTREAMER:BOOL=OFF -DWITH_GSTREAMER_0_10:BOOL=OFF -DWITH_GTK:BOOL=OFF -DWITH_GTK_2_X:BOOL=OFF -DWITH_WIN32UI:BOOL=OFF -DWITH_PTHREADS_PF:BOOL=OFF -DWITH_DSHOW:BOOL=OFF -DWITH_OPENCL:BOOL=OFF -DWITH_OPENCL_SVM:BOOL=OFF -DWITH_OPENCLAMDFFT:BOOL=OFF -DWITH_OPENCLAMDBLAS:BOOL=OFF -DWITH_DIRECTX:BOOL=OFF -DWITH_MATLAB:BOOL=OFF -DWITH_GPHOTO2:BOOL=OFF -DWITH_LAPACK:BOOL=OFF -DBUILD_SHARED_LIBS:BOOL=OFF -DBUILD_ANDROID_EXAMPLES:BOOL=OFF -DBUILD_DOCS:BOOL=OFF -DBUILD_PACKAGE:BOOL=OFF -DBUILD_PERF_TESTS:BOOL=OFF -DBUILD_TESTS:BOOL=OFF -DBUILD_FAT_JAVA_LIBS:BOOL=OFF -DBUILD_ANDROID_SERVICE:BOOL=OFF -DBUILD_CUDA_STUBS:BOOL=OFF -DENABLE_PYLINT:BOOL=OFF 
-DBUILD_opencv_python3:BOOL=OFF -DBUILD_opencv_python2:BOOL=OFF -DBUILD_opencv_world:BOOL=OFF -DBUILD_opencv_apps:BOOL=OFF -DBUILD_opencv_js:BOOL=OFF -DBUILD_ZLIB:BOOL=OFF -DZLIB_ROOT:PATH=${ZLIB_INSTALL} -DBUILD_TIFF:BOOL=ON -DBUILD_JASPER:BOOL=ON -DWITH_JPEG:BOOL=ON -DBUILD_JPEG:BOOL=OFF -DJPEG_INCLUDE_DIR:STRING=${LIBJPEG_INCLUDE_DIR} -DJPEG_LIBRARY:STRING=${LIBJPEG_STATIC_LIBRARIES} -DBUILD_PNG:BOOL=ON -DBUILD_OPENEXR:BOOL=ON -DBUILD_TBB:BOOL=ON -DBUILD_IPP_IW:BOOL=OFF -DWITH_ITT:BOOL=OFF -DBUILD_opencv_flann:BOOL=OFF -DBUILD_opencv_ml:BOOL=OFF -DBUILD_opencv_objdetect:BOOL=OFF -DBUILD_opencv_photo:BOOL=OFF -DBUILD_opencv_video:BOOL=OFF -DBUILD_opencv_dnn:BOOL=OFF -DBUILD_opencv_shape:BOOL=OFF -DBUILD_opencv_videoio:BOOL=OFF -DBUILD_opencv_highgui:BOOL=OFF -DBUILD_opencv_superres:BOOL=OFF -DBUILD_opencv_features2d:BOOL=OFF -DBUILD_opencv_calib3d:BOOL=OFF -DBUILD_opencv_stitching:BOOL=OFF -DBUILD_opencv_videostab:BOOL=OFF -DBUILD_opencv_imgproc:BOOL=ON -DBUILD_opencv_imgcodecs:BOOL=ON -DENABLE_CXX11:BOOL=ON # -DLIB_SUFFIX:STRING=64 ) if(WITH_ZLIB) add_dependencies(opencv zlib) endif() install( FILES ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pngconf.h ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pngdebug.h ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/png.h ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pnginfo.h ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pnglibconf.h ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pngpriv.h ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pngstruct.h TYPE INCLUDE COMPONENT libpng_headers) add_custom_target( install_libpng_headers COMMAND "${CMAKE_COMMAND}" -DCMAKE_INSTALL_COMPONENT=libpng_headers -DCMAKE_INSTALL_PREFIX="${LIBPNG_INSTALL_DIR}" -DCMAKE_INSTALL_MESSAGE=${CMAKE_INSTALL_MESSAGE} -P "${CMAKE_BINARY_DIR}/cmake_install.cmake" DEPENDS opencv) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/openssl.cmake ================================================ include(ExternalProject) set(OPENSSL_INSTALL ${THIRD_PARTY_DIR}/openssl) set(OPENSSL_INCLUDE_DIR ${THIRD_PARTY_DIR}/openssl/include) set(OPENSSL_LIBRARY_DIR ${THIRD_PARTY_DIR}/openssl/lib) set(OPENSSL_TAR_URL https://github.com/openssl/openssl/archive/OpenSSL_1_1_1g.tar.gz) use_mirror(VARIABLE OPENSSL_TAR_URL URL ${OPENSSL_TAR_URL}) set(OPENSSL_URL_HASH dd32f35dd5d543c571bc9ebb90ebe54e) set(OPENSSL_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/openssl) if(WIN32) set(OPENSSL_BUILD_LIBRARY_DIR ${OPENSSL_INSTALL}/lib) set(OPENSSL_LIBRARY_NAMES ssl.lib crypto.lib) elseif(APPLE AND ("${CMAKE_GENERATOR}" STREQUAL "Xcode")) set(OPENSSL_BUILD_LIBRARY_DIR ${OPENSSL_INSTALL}/lib) set(OPENSSL_LIBRARY_NAMES libssl.a libcrypto.a) else() set(OPENSSL_BUILD_LIBRARY_DIR ${OPENSSL_INSTALL}/lib) set(OPENSSL_LIBRARY_NAMES libssl.a libcrypto.a) endif() foreach(LIBRARY_NAME ${OPENSSL_LIBRARY_NAMES}) list(APPEND OPENSSL_STATIC_LIBRARIES ${OPENSSL_LIBRARY_DIR}/${LIBRARY_NAME}) list(APPEND OPENSSL_BUILD_STATIC_LIBRARIES ${OPENSSL_BUILD_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() if(THIRD_PARTY) include(ProcessorCount) ProcessorCount(PROC_NUM) ExternalProject_Add( openssl PREFIX openssl URL ${OPENSSL_TAR_URL} URL_HASH MD5=${OPENSSL_URL_HASH} UPDATE_COMMAND "" CONFIGURE_COMMAND ${OPENSSL_SOURCE_DIR}/src/openssl/config --prefix=${OPENSSL_INSTALL} BUILD_BYPRODUCTS ${OPENSSL_STATIC_LIBRARIES} BUILD_COMMAND make -j${PROC_NUM} INSTALL_COMMAND make install_sw) endif(THIRD_PARTY) 
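For illustration only (this block is not part of the repository snapshot): a minimal sketch of how a downstream target could consume the static OpenSSL artifacts staged by the openssl.cmake script above. The executable name ssl_demo, its source file, and the Threads::Threads usage are assumptions for the example; the openssl ExternalProject target, OPENSSL_INCLUDE_DIR, and OPENSSL_STATIC_LIBRARIES are defined in the file above (with libssl.a listed before libcrypto.a, which is the order a static link needs).

# Illustrative sketch only -- not part of the repository.
# Assumes find_package(Threads) has been called somewhere in the including project.
add_executable(ssl_demo ssl_demo.cpp)                      # hypothetical target and source
add_dependencies(ssl_demo openssl)                         # make sure the ExternalProject build has run first
target_include_directories(ssl_demo PRIVATE ${OPENSSL_INCLUDE_DIR})
# Archive order matters: libssl.a resolves symbols from libcrypto.a, so keep the list order from above.
target_link_libraries(ssl_demo PRIVATE ${OPENSSL_STATIC_LIBRARIES} ${CMAKE_DL_LIBS} Threads::Threads)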
================================================ FILE: cmake/third_party/patches/tensorflow-logging.patch ================================================ --- ./build/third_party_install/tensorflow/include/tensorflow_inc/tensorflow/stream_executor/platform/logging.h 2021-06-22 16:41:20.000000000 +0800 +++ logging.h 2021-08-16 19:41:43.082449275 +0800 @@ -19,7 +19,7 @@ #include "tensorflow/core/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" -#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) +#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) && !defined(GOOGLE_LOGGING) #define PCHECK(invocation) CHECK(invocation) ================================================ FILE: cmake/third_party/protobuf.cmake ================================================ include(ExternalProject) set(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_DIR}/protobuf) set(PROTOBUF_INSTALL_INCLUDEDIR include) set(PROTOBUF_INSTALL_LIBDIR lib) set(PROTOBUF_INSTALL_BINDIR bin) set(PROTOBUF_INCLUDE_DIR ${PROTOBUF_INSTALL_DIR}/${PROTOBUF_INSTALL_INCLUDEDIR}) set(PROTOBUF_LIBRARY_DIR ${PROTOBUF_INSTALL_DIR}/${PROTOBUF_INSTALL_LIBDIR}) set(PROTOBUF_BINARY_DIR ${PROTOBUF_INSTALL_DIR}/${PROTOBUF_INSTALL_BINDIR}) set(PROTOBUF_SRC_DIR ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL "https://github.com/protocolbuffers/protobuf/archive/v3.9.2.zip") set(PROTOBUF_MD5 cf02c32870a1f78c860039e0f63a6343) use_mirror(VARIABLE PROTOBUF_URL URL ${PROTOBUF_URL}) if(WIN32) set(PROTOBUF_LIBRARY_NAMES libprotobufd.lib) set(PROTOC_EXECUTABLE_NAME protoc.exe) set(PROTOBUF_ADDITIONAL_CMAKE_OPTIONS -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=ON -A x64) else() # NOTE: (houjiang, shenghang), to support xrt, must make libproto built as shared if("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".dylib") set(PROTOBUF_LIBRARY_NAMES libprotobuf.dylib) elseif("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".so") set(PROTOBUF_LIBRARY_NAMES libprotobuf.so) else() message(FATAL_ERROR "${CMAKE_SHARED_LIBRARY_SUFFIX} not support for protobuf") endif() set(PROTOBUF_BUILD_SHARED_LIBS ON) set(PROTOC_EXECUTABLE_NAME protoc) endif() foreach(LIBRARY_NAME ${PROTOBUF_LIBRARY_NAMES}) list(APPEND PROTOBUF_STATIC_LIBRARIES ${PROTOBUF_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() set(PROTOBUF_PROTOC_EXECUTABLE ${PROTOBUF_BINARY_DIR}/${PROTOC_EXECUTABLE_NAME}) if(THIRD_PARTY) ExternalProject_Add( protobuf PREFIX protobuf URL ${PROTOBUF_URL} URL_MD5 ${PROTOBUF_MD5} UPDATE_COMMAND "" BUILD_IN_SOURCE 1 SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf SOURCE_SUBDIR cmake BUILD_BYPRODUCTS ${PROTOBUF_STATIC_LIBRARIES} CMAKE_CACHE_ARGS -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DZLIB_ROOT:PATH=${ZLIB_INSTALL} -Dprotobuf_WITH_ZLIB:BOOL=${WITH_ZLIB} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DBUILD_SHARED_LIBS:BOOL=${PROTOBUF_BUILD_SHARED_LIBS} -Dprotobuf_BUILD_SHARED_LIBS:BOOL=${PROTOBUF_BUILD_SHARED_LIBS} -Dprotobuf_BUILD_TESTS:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${PROTOBUF_INSTALL_DIR} -DCMAKE_INSTALL_INCLUDEDIR:STRING=${PROTOBUF_INSTALL_INCLUDEDIR} -DCMAKE_INSTALL_LIBDIR:STRING=${PROTOBUF_INSTALL_LIBDIR} -DCMAKE_INSTALL_BINDIR:STRING=${PROTOBUF_INSTALL_BINDIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -Dprotobuf_DEBUG_POSTFIX:STRING= 
${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS}) if(WITH_ZLIB) add_dependencies(protobuf zlib) endif() else() add_custom_target(protobuf) endif(THIRD_PARTY) add_library(protobuf_imported UNKNOWN IMPORTED) set_property(TARGET protobuf_imported PROPERTY IMPORTED_LOCATION "${PROTOBUF_STATIC_LIBRARIES}") ================================================ FILE: cmake/third_party/re2.cmake ================================================ include(ExternalProject) set(RE2_PROJECT re2) set(RE2_INSTALL_DIR ${THIRD_PARTY_DIR}/re2) set(RE2_INCLUDE_DIR ${RE2_INSTALL_DIR}/include CACHE PATH "" FORCE) set(RE2_LIBRARY_DIR ${RE2_INSTALL_DIR}/lib CACHE PATH "" FORCE) set(RE2_LIBRARIES ${RE2_LIBRARY_DIR}/libre2.a) set(RE2_URL https://github.com/Oneflow-Inc/re2/archive/refs/tags/e17af7789.tar.gz) use_mirror(VARIABLE RE2_URL URL ${RE2_URL}) if(THIRD_PARTY) ExternalProject_Add( ${RE2_PROJECT} PREFIX re2 URL ${RE2_URL} URL_MD5 3b2e20c1edd1cfe887aeef3b0747eac0 UPDATE_COMMAND "" BUILD_BYPRODUCTS ${RE2_LIBRARIES} CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} CMAKE_CACHE_ARGS -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_INSTALL_PREFIX:PATH=${RE2_INSTALL_DIR} -DCMAKE_INSTALL_LIBDIR:PATH=${RE2_LIBRARY_DIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DRE2_BUILD_TESTING:BOOL=OFF -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/trt_flash_attention.cmake ================================================ include(ExternalProject) find_package(Threads) set(TRT_FLASH_ATTENTION_PROJECT trt_flash_attention) set(TRT_FLASH_ATTENTION_URL https://github.com/Oneflow-Inc/trt_flash_attention/archive/d8b74631eb811c95a0d20f247238db6e91acafe3.zip ) use_mirror(VARIABLE TRT_FLASH_ATTENTION_URL URL ${TRT_FLASH_ATTENTION_URL}) set(TRT_FLASH_ATTENTION_MD5 9e0e822ce1450e11515533fbe32e58a9) set(TRT_FLASH_ATTENTION_INSTALL_DIR ${THIRD_PARTY_DIR}/trt_flash_attention) set(TRT_FLASH_ATTENTION_INCLUDE_DIR ${TRT_FLASH_ATTENTION_INSTALL_DIR}/include CACHE PATH "" FORCE) set(TRT_FLASH_ATTENTION_LIBRARY_DIR ${TRT_FLASH_ATTENTION_INSTALL_DIR}/lib CACHE PATH "" FORCE) set(TRT_FLASH_ATTENTION_LIBRARIES ${TRT_FLASH_ATTENTION_LIBRARY_DIR}/libtrt_flash_attention.so) if(THIRD_PARTY) ExternalProject_Add( ${TRT_FLASH_ATTENTION_PROJECT} PREFIX trt_flash_attention URL ${TRT_FLASH_ATTENTION_URL} URL_MD5 ${TRT_FLASH_ATTENTION_MD5} UPDATE_COMMAND "" BUILD_BYPRODUCTS ${TRT_FLASH_ATTENTION_LIBRARIES} CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} CMAKE_CACHE_ARGS -DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE} -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_INSTALL_PREFIX:PATH=${TRT_FLASH_ATTENTION_INSTALL_DIR} -DCMAKE_INSTALL_LIBDIR:PATH=${TRT_FLASH_ATTENTION_LIBRARY_DIR} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}) endif(THIRD_PARTY) ================================================ FILE: cmake/third_party/zlib.cmake 
================================================ include(ExternalProject) set(ZLIB_INSTALL ${THIRD_PARTY_DIR}/zlib) set(ZLIB_INCLUDE_DIR ${ZLIB_INSTALL}/include) set(ZLIB_LIBRARY_DIR ${ZLIB_INSTALL}/lib) set(ZLIB_URL https://github.com/madler/zlib/archive/v1.2.8.tar.gz) use_mirror(VARIABLE ZLIB_URL URL ${ZLIB_URL}) # only use zlib shared lib to prevent using zlib in the system if(WIN32) set(ZLIB_LIBRARY_NAMES zlibstaticd.lib) else() if("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".dylib") set(ZLIB_LIBRARY_NAMES libz.dylib) elseif("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".so") set(ZLIB_LIBRARY_NAMES libz.so) else() message(FATAL_ERROR "${CMAKE_SHARED_LIBRARY_SUFFIX} not support for zlib") endif() endif() foreach(LIBRARY_NAME ${ZLIB_LIBRARY_NAMES}) list(APPEND ZLIB_STATIC_LIBRARIES ${ZLIB_LIBRARY_DIR}/${LIBRARY_NAME}) endforeach() set(ZLIB_HEADERS "${ZLIB_INSTALL}/include/zconf.h" "${ZLIB_INSTALL}/include/zlib.h") if(THIRD_PARTY) ExternalProject_Add( zlib PREFIX zlib URL ${ZLIB_URL} URL_MD5 1eabf2698dc49f925ce0ffb81397098f UPDATE_COMMAND "" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${ZLIB_STATIC_LIBRARIES} CMAKE_CACHE_ARGS -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DBUILD_SHARED_LIBS:BOOL=${BUILD_SHARED_LIBS} -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}) endif(THIRD_PARTY) add_library(zlib_imported UNKNOWN IMPORTED) set_property(TARGET zlib_imported PROPERTY IMPORTED_LOCATION "${ZLIB_STATIC_LIBRARIES}") ================================================ FILE: cmake/third_party.cmake ================================================ cmake_policy(SET CMP0074 NEW) if(NOT WIN32) find_package(Threads) endif() if(WITH_ZLIB) include(zlib) endif() include(protobuf) include(googletest) include(glog) include(libjpeg-turbo) include(opencv) include(eigen) if(WITH_COCOAPI) include(cocoapi) endif() include(half) include(re2) include(json) if(RPC_BACKEND MATCHES "GRPC") include(absl) include(cares) include(openssl) include(grpc) endif() include(flatbuffers) include(hwloc) if(WITH_ONEDNN) include(oneDNN) endif() set_mirror_url_with_hash(INJA_URL https://github.com/pantor/inja/archive/refs/tags/v3.3.0.zip 611e6b7206d0fb89728a3879f78b4775) if(NOT WIN32) set(BLA_STATIC ON) set(BLA_VENDOR "Intel10_64lp_seq") find_package(BLAS) if(NOT BLAS_FOUND) set(BLA_VENDOR "All") find_package(BLAS) endif() else() set(MKL_LIB_PATH "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_2017/windows/mkl/lib/intel64_win" ) set(BLAS_LIBRARIES ${MKL_LIB_PATH}/mkl_core_dll.lib ${MKL_LIB_PATH}/mkl_sequential_dll.lib ${MKL_LIB_PATH}/mkl_intel_lp64_dll.lib) endif() message(STATUS "Found Blas Lib: " ${BLAS_LIBRARIES}) set(oneflow_test_libs gtest_main) set(oneflow_third_party_libs protobuf_imported ${GRPC_STATIC_LIBRARIES} ${farmhash_STATIC_LIBRARIES} ${BLAS_LIBRARIES} ${OPENCV_STATIC_LIBRARIES} ${COCOAPI_STATIC_LIBRARIES} ${LIBJPEG_STATIC_LIBRARIES} ${ABSL_STATIC_LIBRARIES} ${OPENSSL_STATIC_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT} ${FLATBUFFERS_STATIC_LIBRARIES} nlohmann_json::nlohmann_json) if(WITH_ONEDNN) set(oneflow_third_party_libs ${oneflow_third_party_libs} ${ONEDNN_STATIC_LIBRARIES}) endif() list(APPEND oneflow_third_party_libs ${RE2_LIBRARIES}) 
if(WITH_ZLIB) list(APPEND oneflow_third_party_libs zlib_imported) endif() if(WIN32) # static gflags lib requires "PathMatchSpecA" defined in "ShLwApi.Lib" list(APPEND oneflow_third_party_libs "ShLwApi.Lib") list(APPEND oneflow_third_party_libs "Ws2_32.lib") endif() set(oneflow_third_party_dependencies protobuf eigen half_copy_headers_to_destination re2 opencv install_libpng_headers flatbuffers) if(WITH_ONEDNN) list(APPEND oneflow_third_party_dependencies onednn) endif() if(WITH_ZLIB) list(APPEND oneflow_third_party_dependencies zlib) endif() if(WITH_COCOAPI) list(APPEND oneflow_third_party_dependencies cocoapi_copy_headers_to_destination) list(APPEND oneflow_third_party_dependencies cocoapi_copy_libs_to_destination) endif() if(RPC_BACKEND MATCHES "GRPC") list(APPEND oneflow_third_party_dependencies grpc) endif() list( APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${ZLIB_INCLUDE_DIR} ${PROTOBUF_INCLUDE_DIR} ${GRPC_INCLUDE_DIR} ${GLOG_INCLUDE_DIR} ${LIBJPEG_INCLUDE_DIR} ${OPENCV_INCLUDE_DIR} ${LIBPNG_INCLUDE_DIR} ${EIGEN_INCLUDE_DIR} ${COCOAPI_INCLUDE_DIR} ${HALF_INCLUDE_DIR} ${ABSL_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR} ${FLATBUFFERS_INCLUDE_DIR}) if(WITH_ONEDNN) list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${ONEDNN_INCLUDE_DIR}) endif() list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${RE2_INCLUDE_DIR}) if(BUILD_CUDA) # Always use third_party/cub for Clang CUDA in case of compatibility issues if("${CMAKE_CUDA_COMPILER_ID}" STREQUAL "NVIDIA" AND CUDA_VERSION VERSION_GREATER_EQUAL "11.0") if(CMAKE_CXX_STANDARD LESS 14) add_definitions(-DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT) add_definitions(-DCUB_IGNORE_DEPRECATED_CPP11) endif() if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS "5.0") add_definitions(-DCUB_IGNORE_DEPRECATED_COMPILER) endif() else() include(cub) list(APPEND oneflow_third_party_dependencies cub_copy_headers_to_destination) endif() include(nccl) include(cutlass) include(trt_flash_attention) if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7") include(flash_attention) endif() list(APPEND oneflow_third_party_libs ${NCCL_LIBRARIES}) list(APPEND oneflow_third_party_libs ${CUDNN_LIBRARIES}) list(APPEND oneflow_third_party_libs ${VENDOR_CUDA_LIBRARIES}) list(APPEND oneflow_third_party_dependencies nccl) list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${CUDNN_INCLUDE_DIRS} ${CUB_INCLUDE_DIR} ${NCCL_INCLUDE_DIR}) if(WITH_CUTLASS) list(APPEND oneflow_third_party_dependencies cutlass) list(APPEND oneflow_third_party_dependencies cutlass_copy_examples_to_destination) list(APPEND oneflow_third_party_libs ${CUTLASS_LIBRARIES}) list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${CUTLASS_INCLUDE_DIR}) endif() list(APPEND oneflow_third_party_dependencies trt_flash_attention) list(APPEND oneflow_third_party_libs ${TRT_FLASH_ATTENTION_LIBRARIES}) list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${TRT_FLASH_ATTENTION_INCLUDE_DIR}) if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7") list(APPEND oneflow_third_party_dependencies flash_attention) list(APPEND oneflow_third_party_libs ${FLASH_ATTENTION_LIBRARIES}) list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${FLASH_ATTENTION_INCLUDE_DIR}) endif() endif() if(BUILD_RDMA) if(UNIX) include(CheckIncludeFiles) include(CheckLibraryExists) check_include_files(infiniband/verbs.h HAVE_VERBS_H) if(HAVE_VERBS_H) add_definitions(-DWITH_RDMA) else() message(FATAL_ERROR "RDMA head file not found") endif() else() message(FATAL_ERROR "UNIMPLEMENTED") endif() endif() if(BUILD_HWLOC) list(APPEND oneflow_third_party_dependencies hwloc) list(APPEND oneflow_third_party_libs 
${ONEFLOW_HWLOC_STATIC_LIBRARIES}) list(APPEND oneflow_third_party_libs ${PCIACCESS_STATIC_LIBRARIES}) list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${HWLOC_INCLUDE_DIR}) add_definitions(-DWITH_HWLOC) endif() include_directories(SYSTEM ${ONEFLOW_THIRD_PARTY_INCLUDE_DIRS}) foreach(oneflow_third_party_lib IN LISTS oneflow_third_party_libs) if(NOT "${oneflow_third_party_lib}" MATCHES "^-l.+" AND NOT TARGET ${oneflow_third_party_lib} AND "${oneflow_third_party_lib}" MATCHES "^\/.+" AND NOT "${oneflow_third_party_lib}" MATCHES "^.+\.framework") get_filename_component(IMPORTED_LIB_NAME ${oneflow_third_party_lib} NAME_WE) set(IMPORTED_LIB_NAME "imported::${IMPORTED_LIB_NAME}") message(STATUS "Creating imported lib: ${oneflow_third_party_lib} => ${IMPORTED_LIB_NAME}") add_library(${IMPORTED_LIB_NAME} UNKNOWN IMPORTED) set_property(TARGET ${IMPORTED_LIB_NAME} PROPERTY IMPORTED_LOCATION "${oneflow_third_party_lib}") list(APPEND ONEFLOW_THIRD_PARTY_LIBS_TO_LINK "${IMPORTED_LIB_NAME}") else() list(APPEND ONEFLOW_THIRD_PARTY_LIBS_TO_LINK "${oneflow_third_party_lib}") endif() endforeach() set(oneflow_third_party_libs ${ONEFLOW_THIRD_PARTY_LIBS_TO_LINK}) message(STATUS "oneflow_third_party_libs: ${oneflow_third_party_libs}") add_definitions(-DHALF_ENABLE_CPP11_USER_LITERALS=0) if(THIRD_PARTY) add_custom_target(prepare_oneflow_third_party ALL DEPENDS ${oneflow_third_party_dependencies}) if(BUILD_PYTHON) if(NOT ONEFLOW_INCLUDE_DIR MATCHES "/include$") message( FATAL_ERROR "ONEFLOW_INCLUDE_DIR must end with '/include', current value: ${ONEFLOW_INCLUDE_DIR}") endif() get_filename_component(ONEFLOW_INCLUDE_DIR_PARENT "${ONEFLOW_INCLUDE_DIR}" DIRECTORY) foreach(of_include_src_dir ${ONEFLOW_THIRD_PARTY_INCLUDE_DIRS}) if(of_include_src_dir MATCHES "/include$") # it requires two slashes, but in CMake doc it states only one slash is needed set(of_include_src_dir "${of_include_src_dir}//") endif() install( DIRECTORY ${of_include_src_dir} DESTINATION ${ONEFLOW_INCLUDE_DIR} COMPONENT oneflow_py_include EXCLUDE_FROM_ALL) endforeach() endif(BUILD_PYTHON) else() add_custom_target(prepare_oneflow_third_party ALL) endif() ================================================ FILE: cmake/threading.cmake ================================================ foreach(threading_runtime_item ${CPU_THREADING_RUNTIMES}) if(NOT ${threading_runtime_item} MATCHES "^(TBB|OMP)$") message(FATAL_ERROR "Unsupported cpu threading runtime: ${threading_runtime_item}") endif() if(${threading_runtime_item} STREQUAL "OMP") # Reference: # https://releases.llvm.org/11.0.0/tools/clang/docs/OpenMPSupport.html if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 11) message( FATAL_ERROR "libopenmp is not supported under clang10, please use TBB with '-DCPU_THREADING_RUNTIMES=TBB'." 
) endif() endif() find_package(OpenMP) if(OPENMP_FOUND) set(WITH_${threading_runtime_item} ON) add_definitions(-DWITH_${threading_runtime_item}) endif() else() set(WITH_${threading_runtime_item} ON) add_definitions(-DWITH_${threading_runtime_item}) endif() endforeach() ================================================ FILE: cmake/util.cmake ================================================ function(SHOW_VARIABLES) get_cmake_property(_variableNames VARIABLES) foreach(_variableName ${_variableNames}) message(STATUS "${_variableName}=${${_variableName}}") endforeach() endfunction() macro(write_file_if_different file_path content) if(EXISTS ${file_path}) file(READ ${file_path} current_content) # NOTE: it seems a cmake bug that "content" in this macro is not # treated as a variable if(NOT (current_content STREQUAL ${content})) file(WRITE ${file_path} ${content}) endif() else() file(WRITE ${file_path} ${content}) endif() endmacro() macro(copy_all_files_in_dir source_dir dest_dir target) find_program(rsync rsync) if(rsync) add_custom_command( TARGET ${target} POST_BUILD COMMAND ${rsync} # NOTE: the trailing slash of source_dir is needed. # Reference: https://stackoverflow.com/a/56627246 ARGS -a --omit-dir-times --no-perms --no-owner --no-group --inplace ${source_dir}/ ${dest_dir}) else() add_custom_command(TARGET ${target} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${source_dir} ${dest_dir}) endif() endmacro() set(_COUNTER 0) macro(copy_files file_paths source_dir dest_dir target) find_program(rsync rsync) if(rsync) set(CACHE_FILELIST ${PROJECT_BINARY_DIR}/cached_filename_lists/cache_${_COUNTER}) math(EXPR _COUNTER "${_COUNTER} + 1") file(WRITE ${CACHE_FILELIST} "") foreach(file ${file_paths}) file(RELATIVE_PATH rel_path "${source_dir}" ${file}) file(APPEND ${CACHE_FILELIST} ${rel_path}\n) endforeach() add_custom_command( TARGET ${target} POST_BUILD COMMAND ${rsync} ARGS -a --omit-dir-times --no-perms --no-owner --no-group --inplace --files-from=${CACHE_FILELIST} ${source_dir} ${dest_dir}) else() foreach(file ${file_paths}) file(RELATIVE_PATH rel_path "${source_dir}" ${file}) add_custom_command(TARGET ${target} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different "${file}" "${dest_dir}/${rel_path}") endforeach() endif() endmacro() function(add_copy_headers_target) cmake_parse_arguments(PARSED_ARGS "" "NAME;SRC;DST;INDEX_FILE" "DEPS" ${ARGN}) if(NOT PARSED_ARGS_NAME) message(FATAL_ERROR "name required") endif(NOT PARSED_ARGS_NAME) if(NOT PARSED_ARGS_SRC) message(FATAL_ERROR "src required") endif(NOT PARSED_ARGS_SRC) if(NOT PARSED_ARGS_DST) message(FATAL_ERROR "dst required") endif(NOT PARSED_ARGS_DST) add_custom_target( "${PARSED_ARGS_NAME}_create_header_dir" COMMAND ${CMAKE_COMMAND} -E make_directory "${PARSED_ARGS_DST}" DEPENDS ${PARSED_ARGS_DEPS}) add_custom_target("${PARSED_ARGS_NAME}_copy_headers_to_destination" ALL DEPENDS "${PARSED_ARGS_NAME}_create_header_dir") file(GLOB_RECURSE headers "${PARSED_ARGS_SRC}/*.h") file(GLOB_RECURSE cuda_headers "${PARSED_ARGS_SRC}/*.cuh") file(GLOB_RECURSE hpp_headers "${PARSED_ARGS_SRC}/*.hpp") list(APPEND headers ${cuda_headers}) list(APPEND headers ${hpp_headers}) foreach(header_file ${headers}) file(RELATIVE_PATH relative_file_path ${PARSED_ARGS_SRC} ${header_file}) add_custom_command( TARGET "${PARSED_ARGS_NAME}_copy_headers_to_destination" PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} "${PARSED_ARGS_DST}/${relative_file_path}") endforeach() if(PARSED_ARGS_INDEX_FILE) file(STRINGS ${PARSED_ARGS_INDEX_FILE} 
inventory_headers) endif(PARSED_ARGS_INDEX_FILE) foreach(header_file ${inventory_headers}) add_custom_command( TARGET "${PARSED_ARGS_NAME}_copy_headers_to_destination" PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different "${PARSED_ARGS_SRC}/${header_file}" "${PARSED_ARGS_DST}/${header_file}") endforeach() endfunction() function(use_mirror) set(ALIYUN_URL_PREFIX "https://oneflow-static.oss-cn-beijing.aliyuncs.com/third_party_mirror/https/" CACHE STRING "URL prefix of Aliyun OSS mirror") cmake_parse_arguments(PARSED_ARGS "" "VARIABLE;URL" "" ${ARGN}) if((NOT PARSED_ARGS_VARIABLE) OR (NOT PARSED_ARGS_URL)) message(FATAL_ERROR "VARIABLE or URL required") endif() if(PARSED_ARGS_URL MATCHES "file://") set(${PARSED_ARGS_VARIABLE} ${PARSED_ARGS_URL} PARENT_SCOPE) return() endif() if(DEFINED THIRD_PARTY_MIRROR) if(THIRD_PARTY_MIRROR STREQUAL "aliyun") if(NOT PARSED_ARGS_URL MATCHES "^https://") message(FATAL_ERROR "URL should start with 'https://'") endif() string(REPLACE "https://" ${ALIYUN_URL_PREFIX} MIRRORED_URL ${PARSED_ARGS_URL}) set(${PARSED_ARGS_VARIABLE} ${MIRRORED_URL} PARENT_SCOPE) message(NOTICE "-- fetch ${PARSED_ARGS_VARIABLE} using aliyun mirror ${MIRRORED_URL}") elseif(NOT THIRD_PARTY_MIRROR STREQUAL "") message(FATAL_ERROR "invalid key for third party mirror") endif() endif() endfunction() macro(set_mirror_url variable url) set(${variable} ${url} ${ARGN}) use_mirror(VARIABLE ${variable} URL ${url}) endmacro() macro(set_mirror_url_with_hash variable url hash) set_mirror_url(${variable} ${url} ${ARGN}) set(${variable}_HASH ${hash} ${ARGN}) endmacro() function(check_cxx11_abi OUTPUT_VAR) execute_process( COMMAND ${CMAKE_COMMAND} -E echo "#include <string>\n void test(std::string){}\n int main(){}" OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/temp.cpp) try_compile( COMPILE_SUCCESS ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/temp.cpp COMPILE_DEFINITIONS -D_GLIBCXX_USE_CXX11_ABI=1 COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/temp) if(NOT COMPILE_SUCCESS) message(FATAL_ERROR "Detecting cxx11 availability failed. Please report to OneFlow developers.") endif() execute_process(COMMAND nm ${CMAKE_CURRENT_BINARY_DIR}/temp COMMAND grep -q cxx11 RESULT_VARIABLE RET_CODE) if(RET_CODE EQUAL 0) set(CXX11_ABI_AVAILABLE ON) else() set(CXX11_ABI_AVAILABLE OFF) endif() execute_process(COMMAND rm ${CMAKE_CURRENT_BINARY_DIR}/temp ${CMAKE_CURRENT_BINARY_DIR}/temp.cpp) set(${OUTPUT_VAR} ${CXX11_ABI_AVAILABLE} PARENT_SCOPE) endfunction() include(CheckCXXCompilerFlag) function(target_try_compile_option target flag) # We cannot check for -Wno-foo as this won't throw a warning so we must check for the -Wfoo option directly # http://stackoverflow.com/questions/38785168/cc1plus-unrecognized-command-line-option-warning-on-any-other-warning string(REGEX REPLACE "^-Wno-" "-W" checkedFlag ${flag}) string(REGEX REPLACE "[-=]" "_" varName CXX_FLAG${checkedFlag}) # Avoid double checks.
A compiler will not magically support a flag it did not before if(NOT DEFINED ${varName}_SUPPORTED) check_cxx_compiler_flag(${checkedFlag} ${varName}_SUPPORTED) endif() if(${varName}_SUPPORTED) target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${flag}>) if(BUILD_CUDA) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" AND "${CMAKE_CUDA_COMPILER_ID}" STREQUAL "Clang") target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${flag}>) endif() endif() endif() endfunction() function(target_try_compile_options target) foreach(flag ${ARGN}) target_try_compile_option(${target} ${flag}) endforeach() endfunction() function(target_treat_warnings_as_errors target) if(TREAT_WARNINGS_AS_ERRORS) target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-Werror>) if(BUILD_CUDA) # Only pass flags when cuda compiler is Clang because cmake handles -Xcompiler incorrectly if("${CMAKE_CUDA_COMPILER_ID}" STREQUAL "Clang") target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Werror>) endif() endif() # TODO: remove it while fixing all deprecated call target_try_compile_options(${target} -Wno-error=deprecated-declarations) # disable unused-* for different compile mode (maybe unused in cpu.cmake, but used in cuda.cmake) target_try_compile_options( ${target} -Wno-error=unused-const-variable -Wno-error=unused-variable -Wno-error=unused-local-typedefs -Wno-error=unused-private-field -Wno-error=unused-lambda-capture) # there is some strict-overflow warnings in oneflow/user/kernels/ctc_loss_kernel_util.cpp for unknown reason, disable them for now target_try_compile_options(${target} -Wno-error=strict-overflow) target_try_compile_options(${target} -Wno-error=instantiation-after-specialization) # disable for pointer operations of intrusive linked lists target_try_compile_options(${target} -Wno-error=array-bounds) target_try_compile_options(${target} -Wno-error=comment) # disable visibility warnings related to https://github.com/Oneflow-Inc/oneflow/pull/3676.
target_try_compile_options(${target} -Wno-error=attributes) # disable error about XXX has no out-of-line virtual method definitions; its vtable will be emitted in every translation unit target_try_compile_options(${target} -Wno-error=weak-vtables) endif() endfunction() function(set_compile_options_to_oneflow_target target) target_treat_warnings_as_errors(${target}) target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-Werror=return-type>) target_compile_definitions(${target} PRIVATE ONEFLOW_CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}) # the mangled name between `struct X` and `class X` is different in MSVC ABI, remove it while windows is supported (in MSVC/cl or clang-cl) target_try_compile_options(${target} -Wno-covered-switch-default) set_target_properties(${target} PROPERTIES INSTALL_RPATH "$ORIGIN/../lib") if(BUILD_CUDA) if("${CMAKE_CUDA_COMPILER_ID}" STREQUAL "NVIDIA") target_compile_options( ${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>: -Xcompiler -Werror=return-type; -Wno-deprecated-gpu-targets; -Werror cross-execution-space-call; -Xcudafe --diag_suppress=declared_but_not_referenced; >) elseif("${CMAKE_CUDA_COMPILER_ID}" STREQUAL "Clang") target_compile_options( ${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>: -Werror=return-type; # Suppress warning from cub library -- marking as system header seems not working for .cuh files -Wno-pass-failed; >) else() message(FATAL_ERROR "Unknown CUDA compiler ${CMAKE_CUDA_COMPILER_ID}") endif() # remove THRUST_IGNORE_CUB_VERSION_CHECK if starting using bundled cub target_compile_definitions(${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>: THRUST_IGNORE_CUB_VERSION_CHECK; >) endif() endfunction() function(check_variable_defined variable) if(NOT DEFINED ${variable}) message(FATAL_ERROR "Variable ${variable} is not defined") endif() endfunction() function(checkDirAndAppendSlash) set(singleValues DIR;OUTPUT) set(prefix ARG) cmake_parse_arguments(PARSE_ARGV 0 ${prefix} "${noValues}" "${singleValues}" "${multiValues}") if("${${prefix}_DIR}" STREQUAL "" OR "${${prefix}_DIR}" STREQUAL "/") message(FATAL_ERROR "empty path found: ${${prefix}_DIR}") else() set(${${prefix}_OUTPUT} "${${prefix}_DIR}/" PARENT_SCOPE) endif() endfunction() function(mark_targets_as_system) # TODO(daquexian): update this function once https://gitlab.kitware.com/cmake/cmake/-/merge_requests/7308 # and its following PRs are merged in cmake v3.25.
foreach(target ${ARGV}) get_target_property(include_dir ${target} INTERFACE_INCLUDE_DIRECTORIES) set_target_properties(${target} PROPERTIES INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${include_dir}") endforeach() endfunction() if(NOT BUILD_SHARED_LIBS) if(APPLE) set(ALL_ARCHIVE_BEGIN -Wl,-force_load) set(ALL_ARCHIVE_END) elseif(UNIX) set(ALL_ARCHIVE_BEGIN -Wl,--whole-archive) set(ALL_ARCHIVE_END -Wl,--no-whole-archive) endif() endif() ================================================ FILE: dev-requirements.txt ================================================ black==19.10b0; python_version >= "3.6" click==8.0.0; python_version >= "3.6" # https://github.com/psf/black/issues/2964 numpy>=1.21.6, <2.0 protobuf>=3.9.2, <4.0 wheel tqdm requests jinja2 opencv-python; python_version >= "3.9" and sys_platform != 'darwin' and platform_machine != 'aarch64' opencv-python==4.2.0.34; python_version < '3.9' and sys_platform != 'darwin' and platform_machine != 'aarch64' PyYAML>=5.1 pillow dataclasses; python_version<"3.7" cmakelang==0.6.13 pytest-xdist pytest-repeat rich portalocker typing-extensions>=4.0.0, <5.0 ================================================ FILE: docker/build/Dockerfile ================================================ # warning: never share the container image this dockerfile produces ARG CUDA=10.0 FROM nvidia/cuda:${CUDA}-cudnn7-devel-centos7 RUN yum-config-manager --add-repo https://yum.repos.intel.com/setup/intelproducts.repo && \ rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB RUN yum update -y && yum install -y epel-release RUN yum update -y && yum install -y rdma-core-devel \ nasm \ cmake3 \ make \ git \ centos-release-scl \ intel-mkl-2020.0-088 \ zlib-devel \ curl-devel \ which RUN ln -sf /usr/bin/cmake3 /usr/bin/cmake RUN mkdir -p /tmp/download/cmake-extracted && \ cd /tmp/download && \ curl --location https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0.tar.gz --output cmake.tar.gz && \ tar -xvzf cmake.tar.gz --directory cmake-extracted && \ cd cmake-extracted/* && \ mkdir /cmake-install RUN cd /tmp/download/cmake-extracted/* && \ cmake . -DCMAKE_USE_SYSTEM_CURL=ON -DCMAKE_INSTALL_PREFIX=/cmake-install && \ make -j $(nproc) && \ make install ENV PATH="/cmake-install/bin:${PATH}" ARG USE_PYTHON_3_OR_2=3 RUN if [ "${USE_PYTHON_3_OR_2}" -eq 2 ] ; then yum update -y \ && yum install -y python-devel.x86_64 \ && curl https://bootstrap.pypa.io/get-pip.py --output ./get-pip.py \ && python ./get-pip.py \ && rm get-pip.py \ && pip install numpy==1.12.0 protobuf ; fi COPY dev-requirements.txt /workspace/dev-requirements.txt RUN if [ "${USE_PYTHON_3_OR_2}" -eq 3 ] ; then yum update -y \ && yum install -y rh-python36 python36-devel.x86_64 python36-devel \ && python3 -m ensurepip \ && pip3 install /workspace/dev-requirements.txt; fi WORKDIR /workspace/build COPY cmake /workspace/cmake COPY CMakeLists.txt /workspace/CMakeLists.txt # BUILD DEPENDENCY COPY build/third_party /workspace/build/third_party RUN cmake -DTHIRD_PARTY=ON -DCMAKE_BUILD_TYPE=Release -DRELEASE_VERSION=ON .. && make -j # BUILD ONEFLOW COPY oneflow /workspace/oneflow COPY tools /workspace/tools RUN export LD_LIBRARY_PATH=/opt/intel/lib/intel64_lin:/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH; \ cmake -DTHIRD_PARTY=OFF .. 
&& make -j $(nproc) ; ## BUILD WHEEL WORKDIR /workspace RUN pip${USE_PYTHON_3_OR_2} install wheel COPY setup.py /workspace/setup.py RUN python${USE_PYTHON_3_OR_2} setup.py bdist_wheel RUN pip${USE_PYTHON_3_OR_2} install /workspace/dist/*.whl RUN rm -rf oneflow third_party cmake CMakeLists.txt ================================================ FILE: docker/build/build-ubuntu.sh ================================================ docker build \ --rm \ -t oneflow-build:ubuntu -f docker/build/build.ubuntu.dockerfile . ================================================ FILE: docker/build/build.sh ================================================ docker build \ --rm \ -t oneflow-build -f docker/build/Dockerfile . ================================================ FILE: docker/build/build.ubuntu.dockerfile ================================================ ARG CUDA=10.0 ARG UBUNTU_VERSION=16.04 FROM nvidia/cuda:${CUDA}-cudnn7-devel-ubuntu${UBUNTU_VERSION} USER 0 RUN apt-get update && \ apt-get install -y apt-transport-https && \ apt-get install -y --no-install-recommends \ curl \ nasm \ make \ git \ gcc \ g++ \ libopenblas-dev \ python3-dev # speed up pip install in China ENV TUNA_PIP_INSTALL=" -i https://pypi.tuna.tsinghua.edu.cn/simple" COPY dev-requirements.txt /workspace/dev-requirements.txt RUN curl https://bootstrap.pypa.io/get-pip.py --output ./get-pip.py \ && python3 ./get-pip.py \ && pip3 install $TUNA_INDEX cmake \ && pip3 install $TUNA_INDEX -r /workspace/dev-requirements.txt WORKDIR /workspace/build COPY cmake /workspace/cmake COPY CMakeLists.txt /workspace/CMakeLists.txt # BUILD DEPENDENCY COPY build/third_party /workspace/build/third_party RUN cmake -DTHIRD_PARTY=ON -DONEFLOW=OFF -DCMAKE_BUILD_TYPE=Release .. && make -j$(nproc) # BUILD ONEFLOW COPY oneflow /workspace/oneflow COPY tools /workspace/tools RUN cmake -DTHIRD_PARTY=OFF -DONEFLOW=ON .. && make -j$(nproc) of_pyscript_copy RUN cmake -DTHIRD_PARTY=OFF -DONEFLOW=ON .. 
&& make -j$(nproc) # BUILD WHEEL WORKDIR /workspace COPY setup.py /workspace/setup.py RUN python3 setup.py bdist_wheel RUN pip3 install /workspace/dist/*.whl RUN rm -rf oneflow third_party cmake CMakeLists.txt ================================================ FILE: docker/build/launch.sh ================================================ docker run -it --rm \ -v /dataset:/dataset/ \ oneflow-build ================================================ FILE: docker/build/test.sh ================================================ docker run -it --rm \ -v /dataset:/dataset/ \ oneflow-build \ python3 -c "import oneflow" ================================================ FILE: docker/ci/base/Dockerfile ================================================ # warning: never share the container image this dockerfile produces ARG CUDA=10.0 FROM nvidia/cuda:${CUDA}-cudnn7-devel-centos7 COPY dev-requirements.txt /workspace/dev-requirements.txt RUN yum-config-manager --add-repo https://yum.repos.intel.com/setup/intelproducts.repo && \ rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB && \ yum update -y && yum install -y epel-release && \ yum update -y && yum install -y rdma-core-devel \ nasm \ make \ git \ centos-release-scl \ intel-mkl-2020.0-088 \ zlib-devel \ curl-devel \ which \ rh-python36 python36-devel.x86_64 python36-devel && \ python3 -m ensurepip && \ pip3 install -r /workspace/dev-requirements.txt && \ yum clean all RUN mkdir -p /tmp/download && \ mkdir /cmake-extracted && \ cd /tmp/download && \ curl --location https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-Linux-x86_64.tar.gz --output cmake.tar.gz && \ tar -xvzf cmake.tar.gz --directory /cmake-extracted && \ mv /cmake-extracted/* /cmake-extracted/cmake-install && \ rm -rf /tmp/download ENV PATH="/cmake-extracted/cmake-install/bin:${PATH}" ================================================ FILE: docker/ci/fmt/Dockerfile ================================================ FROM python:3.7 RUN curl https://oneflow-static.oss-cn-beijing.aliyuncs.com/bin/clang-format -o /usr/local/bin/clang-format && chmod +x /usr/local/bin/clang-format RUN apt update && apt install -y libncurses5 ================================================ FILE: docker/ci/fmt/build.sh ================================================ set -ex cd docker/ci/fmt docker build -t oneflow-fmt . ================================================ FILE: docker/ci/make/Dockerfile ================================================ ARG from FROM ${from} WORKDIR /workspace/build # BUILD ONEFLOW COPY oneflow /workspace/oneflow COPY tools /workspace/tools RUN export LD_LIBRARY_PATH=/opt/intel/lib/intel64_lin:/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH; \ cmake -DTHIRD_PARTY=OFF -DONEFLOW=ON .. && make -j $(nproc) ; ## BUILD WHEEL WORKDIR /workspace COPY setup.py /workspace/setup.py RUN python3 setup.py bdist_wheel FROM centos:7 WORKDIR /workspace COPY --from=0 /workspace/dist/*.whl . COPY --from=0 /workspace/build/bin/oneflow_testexe . 
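docker/ci/make/Dockerfile above is a multi-stage build: it takes a base image (the `from` build argument) that already contains the compiled third-party libraries, builds oneflow and the wheel in the first stage, and copies only the wheel and the test executable into a slim centos:7 stage. A rough sketch of how the CI images above could be chained together is shown below; the image tags are invented for illustration, and the real tags and build arguments live in the GitHub workflow files, which are not included in this section.

```bash
# Illustrative chaining of the docker/ci images, run from the repository root.
# Tags are made up here; CI supplies its own tags and CUDA build args.
docker build -f docker/ci/base/Dockerfile -t oneflow-ci-base .
docker build -f docker/ci/third_party/Dockerfile --build-arg from=oneflow-ci-base -t oneflow-ci-third-party .
docker build -f docker/ci/make/Dockerfile --build-arg from=oneflow-ci-third-party -t oneflow-ci-make .
```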
================================================ FILE: docker/ci/test/Dockerfile ================================================ FROM ufoym/deepo RUN apt remove openmpi-common libfabric1 openmpi-bin librdmacm1:amd64 libopenmpi2 libopenmpi2:amd64 -y ENV MOFED_DIR MLNX_OFED_LINUX-4.3-1.0.1.0-ubuntu18.04-x86_64 RUN wget https://oneflow-static.oss-cn-beijing.aliyuncs.com/deps/${MOFED_DIR}.tgz && \ tar -xzvf ${MOFED_DIR}.tgz && \ ${MOFED_DIR}/mlnxofedinstall --user-space-only --without-fw-update --all -q --force && \ cd .. && \ rm -rf ${MOFED_DIR} && \ rm -rf *.tgz RUN apt update && apt install -y --no-install-recommends gdb openssh-server openssh-client RUN echo 'ALL ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers RUN sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config COPY requirements.txt . RUN pip3 install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt ================================================ FILE: docker/ci/test/build.sh ================================================ set -ex test_img_dir="$(dirname "${BASH_SOURCE[0]}")" test_img_dir="$(realpath "${test_img_dir}")" cd $test_img_dir proxy_args="" proxy_args+=" --network=host" proxy_args+=" --build-arg HTTP_PROXY=${HTTP_PROXY}" proxy_args+=" --build-arg HTTPS_PROXY=${HTTPS_PROXY}" proxy_args+=" --build-arg http_proxy=${http_proxy}" proxy_args+=" --build-arg https_proxy=${https_proxy}" img_tag="oneflow-test:0.2" # update me if any of related files are changed if [[ "$(docker images -q ${img_tag} 2> /dev/null)" == "" ]]; then docker build --rm $proxy_args \ -t $img_tag . fi ================================================ FILE: docker/ci/test/launch.sh ================================================ docker run --shm-size=8g --privileged --network=host --rm -it -w $PWD -v $PWD:$PWD -v /dataset:/dataset -v /model_zoo:/model_zoo \ -v $HOME:$HOME \ oneflow-test:0.2 \ bash ================================================ FILE: docker/ci/test/requirements.txt ================================================ sphinx==3.5.4 jinja2<3.1 recommonmark==0.6.0 furo==2021.4.11b34 sphinx-copybutton==0.5.0 # dependencies above must be identical to docs/requirements.txt pycocotools opencv-python==4.2.0.34 scipy pillow tensorflow-addons==0.9.1 https://oneflow-static.oss-cn-beijing.aliyuncs.com/pipindex/pipindex-0.1.3-py2.py3-none-any.whl ================================================ FILE: docker/ci/test-v2/Dockerfile ================================================ FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime COPY sources.list /etc/apt/sources.list RUN apt update && apt install ffmpeg libsm6 libxext6 gdb gcc g++ -y --no-install-recommends COPY requirements.txt . RUN python3 -m pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt ================================================ FILE: docker/ci/test-v2/build.sh ================================================ set -ex test_img_dir="$(dirname "${BASH_SOURCE[0]}")" test_img_dir="$(realpath "${test_img_dir}")" cd $test_img_dir proxy_args="" proxy_args+=" --network=host" proxy_args+=" --build-arg HTTP_PROXY=${HTTP_PROXY}" proxy_args+=" --build-arg HTTPS_PROXY=${HTTPS_PROXY}" proxy_args+=" --build-arg http_proxy=${http_proxy}" proxy_args+=" --build-arg https_proxy=${https_proxy}" img_tag="oneflow-test-v2:0.1" # update me if any of related files are changed if [[ "$(docker images -q ${img_tag} 2> /dev/null)" == "" ]]; then docker build --rm $proxy_args \ -t $img_tag . 
fi ================================================ FILE: docker/ci/test-v2/requirements.txt ================================================ sphinx==3.5.4 jinja2<3.1 recommonmark==0.6.0 furo==2021.4.11b34 sphinx-copybutton==0.5.0 # dependencies above must be identical to docs/requirements.txt pycocotools opencv-python==4.2.0.34 scipy pillow https://oneflow-static.oss-cn-beijing.aliyuncs.com/pipindex/pipindex-0.1.3-py2.py3-none-any.whl ================================================ FILE: docker/ci/test-v2/sources.list ================================================ # The deb-src entries are commented out by default to speed up apt update; uncomment them if needed deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse # Pre-release ("proposed") sources, not recommended to enable # deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-proposed main restricted universe multiverse # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-proposed main restricted universe multiverse ================================================ FILE: docker/ci/third_party/Dockerfile ================================================ ARG from FROM ${from} WORKDIR /workspace/build COPY cmake /workspace/cmake COPY CMakeLists.txt /workspace/CMakeLists.txt # BUILD DEPENDENCY COPY build/third_party /workspace/build/third_party RUN export LD_LIBRARY_PATH=/opt/intel/lib/intel64_lin:/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH; \ cmake -DTHIRD_PARTY=ON -DONEFLOW=OFF -DCMAKE_BUILD_TYPE=Release -DRELEASE_VERSION=ON .. && make -j prepare_oneflow_third_party ================================================ FILE: docker/package/manylinux/CentOS-Base.repo ================================================ # CentOS-Base.repo # # From https://mirror.tuna.tsinghua.edu.cn/help/centos/ # # The mirror system uses the connecting IP address of the client and the # update status of each mirror to pick mirrors that are updated to and # geographically close to the client. You should use this for CentOS updates # unless you are manually picking other mirrors. # # If the mirrorlist= does not work for you, as a fall back you can try the # remarked out baseurl= line instead.
# # [base] name=CentOS-$releasever - Base baseurl=https://mirrors.tuna.tsinghua.edu.cn/centos/$releasever/os/$basearch/ http://mirrors.aliyun.com/centos/$releasever/os/$basearch/ http://mirrors.aliyuncs.com/centos/$releasever/os/$basearch/ #mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=os enabled=1 gpgcheck=1 gpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-7 #released updates [updates] name=CentOS-$releasever - Updates baseurl=https://mirrors.tuna.tsinghua.edu.cn/centos/$releasever/updates/$basearch/ http://mirrors.aliyun.com/centos/$releasever/updates/$basearch/ http://mirrors.aliyuncs.com/centos/$releasever/updates/$basearch/ #mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=updates enabled=1 gpgcheck=1 gpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-7 #additional packages that may be useful [extras] name=CentOS-$releasever - Extras baseurl=https://mirrors.tuna.tsinghua.edu.cn/centos/$releasever/extras/$basearch/ http://mirrors.aliyun.com/centos/$releasever/extras/$basearch/ http://mirrors.aliyuncs.com/centos/$releasever/extras/$basearch/ #mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=extras enabled=1 gpgcheck=1 gpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-7 #additional packages that extend functionality of existing packages [centosplus] name=CentOS-$releasever - Plus baseurl=https://mirrors.tuna.tsinghua.edu.cn/centos/$releasever/centosplus/$basearch/ http://mirrors.aliyun.com/centos/$releasever/centosplus/$basearch/ http://mirrors.aliyuncs.com/centos/$releasever/centosplus/$basearch/ #mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=centosplus gpgcheck=1 enabled=0 gpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-7 ================================================ FILE: docker/package/manylinux/CentOS7-Base-163.repo ================================================ # CentOS-Base.repo # # The mirror system uses the connecting IP address of the client and the # update status of each mirror to pick mirrors that are updated to and # geographically close to the client. You should use this for CentOS updates # unless you are manually picking other mirrors. # # If the mirrorlist= does not work for you, as a fall back you can try the # remarked out baseurl= line instead. 
# # [base] name=CentOS-$releasever - Base - 163.com #mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=os baseurl=http://mirrors.163.com/centos/$releasever/os/$basearch/ gpgcheck=1 gpgkey=http://mirrors.163.com/centos/RPM-GPG-KEY-CentOS-7 #released updates [updates] name=CentOS-$releasever - Updates - 163.com #mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=updates baseurl=http://mirrors.163.com/centos/$releasever/updates/$basearch/ gpgcheck=1 gpgkey=http://mirrors.163.com/centos/RPM-GPG-KEY-CentOS-7 #additional packages that may be useful [extras] name=CentOS-$releasever - Extras - 163.com #mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=extras baseurl=http://mirrors.163.com/centos/$releasever/extras/$basearch/ gpgcheck=1 gpgkey=http://mirrors.163.com/centos/RPM-GPG-KEY-CentOS-7 #additional packages that extend functionality of existing packages [centosplus] name=CentOS-$releasever - Plus - 163.com baseurl=http://mirrors.163.com/centos/$releasever/centosplus/$basearch/ gpgcheck=1 enabled=0 gpgkey=http://mirrors.163.com/centos/RPM-GPG-KEY-CentOS-7 ================================================ FILE: docker/package/manylinux/Dockerfile ================================================ ARG from FROM ${from} ARG use_tuna_yum=0 ARG pip_args="" ARG bazel_url="https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-linux-x86_64" LABEL maintainer="OneFlow Maintainers" # manylinux2014 ENV AUDITWHEEL_ARCH x86_64 ENV AUDITWHEEL_PLAT manylinux2014_$AUDITWHEEL_ARCH ENV LC_ALL en_US.UTF-8 ENV LANG en_US.UTF-8 ENV LANGUAGE en_US.UTF-8 ENV PATH $PATH:/usr/local/bin ENV LD_LIBRARY_PATH /usr/local/lib64:/usr/local/lib ENV PKG_CONFIG_PATH /usr/local/lib/pkgconfig # use tuna mirror COPY docker/package/manylinux/CentOS7-Base-163.repo /tmp/CentOS-Base.repo RUN if [ "${use_tuna_yum}" = "1" ]; then mv /tmp/CentOS-Base.repo /etc/yum.repos.d/ && yum makecache ; fi # to speed up docker img building disable cuda repo # in 10.1, cuda yum repo will update cublas to 10.2 and breaks build RUN yum-config-manager --disable cuda nvidia-ml ARG MANYLINUX_SHA=b634044 RUN yum -y install unzip && curl -L -o manylinux.zip https://github.com/Oneflow-Inc/manylinux/archive/${MANYLINUX_SHA}.zip && unzip manylinux.zip -d tmp && cp -r tmp/*/docker/build_scripts /build_scripts && bash build_scripts/build.sh && rm -r build_scripts tmp manylinux.zip ENV SSL_CERT_FILE=/opt/_internal/certs.pem # manylinux2014 end RUN yum-config-manager --add-repo https://yum.repos.intel.com/oneapi && \ rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB && \ yum update -y && yum install -y epel-release && \ yum -y install centos-release-scl && \ yum install -y intel-oneapi-mkl-devel-2021.2.0 nasm rdma-core-devel devtoolset-7-gcc* rsync gdb RUN /opt/python/cp35-cp35m/bin/pip install $pip_args -U cmake==3.18.4.post1 && ln -s /opt/_internal/cpython-3.5.9/bin/cmake /usr/bin/cmake RUN mkdir -p /tmp && cd /tmp && \ curl -L -o patchelf-src.zip \ https://github.com/Oneflow-Inc/patchelf/archive/64bf5388ef7d45d3697c4aadbd3f5d7d68a22aa3.zip && \ unzip patchelf-src.zip && cd patchelf-* && ./bootstrap.sh && ./configure && make -j`nproc` && \ make install && cd .. 
&& rm -rf patchelf-* RUN curl -L $bazel_url -o /usr/local/bin/bazel \ && chmod +x /usr/local/bin/bazel \ && bazel COPY dev-requirements.txt /tmp/dev-requirements.txt RUN /opt/python/cp36-cp36m/bin/pip install $pip_args -r /tmp/dev-requirements.txt --user \ && /opt/python/cp37-cp37m/bin/pip install $pip_args -r /tmp/dev-requirements.txt --user \ && /opt/python/cp38-cp38/bin/pip install $pip_args -r /tmp/dev-requirements.txt --user \ && rm /tmp/dev-requirements.txt ================================================ FILE: docker/package/manylinux/README.md ================================================
# Building OneFlow wheel packages with docker

### Create the docker image

Run the following in the root of the OneFlow source tree:

```
docker build -f docker/package/manylinux/Dockerfile --build-arg from=nvidia/cuda:10.2-cudnn7-devel-centos7 -t oneflow:manylinux2014-cuda10.2 .
```

### Build the manylinux python wheels

This directory provides a Dockerfile for manylinux2014 (centos7) + cuda10.2 with all the libraries needed to compile oneflow pre-installed. Assuming you have built a docker image from this Dockerfile named oneflow:manylinux2014-cuda10.2, run the following in the oneflow source directory:

```bash
docker run --rm -it -v `pwd`:/oneflow-src -w /oneflow-src oneflow:manylinux2014-cuda10.2
```

If you prefer to operate inside docker:

```bash
docker run --rm -it -v `pwd`:/oneflow-src -w /oneflow-src oneflow:manylinux2014-cuda10.2 bash
```

```bash
/oneflow-src/docker/package/manylinux/build_wheel.sh --python3.6 --wheel-dir /oneflow-src/wheel-test
```

This executes build_wheel.sh inside the docker image to build oneflow manylinux2014 wheels for python 3.5 through python 3.8. The resulting packages are placed in the wheelhouse/ folder under the oneflow source directory.

#### Notes

1. When running `docker run` you may need to add `-e http_proxy=$http_proxy -e https_proxy=$https_proxy` so that the host's proxy is used inside the container, avoiding network failures while compiling the third-party libraries.
2. Whenever `cmake -DTHIRD_PARTY=ON ..` is run, oneflow itself is rebuilt from scratch, so if the third-party libraries have already been built by the docker container and you only want an incremental build of oneflow itself this time, use:
```bash
docker run --rm -it -v `pwd`:/oneflow-src oneflow:manylinux2014-cuda10.2 /oneflow-src/docker/package/manylinux/build_wheel.sh --skip-third-party
```
This passes `--skip-third-party` to build_wheel.sh, skipping the third-party builds.
3. If you only want to build packages for certain python versions, e.g. python3.5, use:
```bash
docker run --rm -it -v `pwd`:/oneflow-src oneflow:manylinux2014-cuda10.2 /oneflow-src/docker/package/manylinux/build_wheel.sh --python3.5
```
The supported flags are `--python3.5`, `--python3.6`, `--python3.7` and `--python3.8`; pass several of them at once to build multiple versions. If no version flag is passed, wheels for all python versions are built.
4. To customize the cmake arguments used when compiling oneflow, write the cmake arguments out directly, e.g.:
```bash
docker run --rm -it -v `pwd`:/oneflow-src oneflow:manylinux2014-cuda10.2 /oneflow-src/docker/package/manylinux/build_wheel.sh -DWITH_XLA=ON
```
================================================ FILE: docker/package/manylinux/build_wheel.py ================================================ import os import subprocess import tempfile from pathlib import Path import getpass import uuid def get_arg_env(env_var_name: str, mode="run"): val = os.getenv(env_var_name) assert val, f"system environment variable {env_var_name} found empty" if mode == "run": return f"--env {env_var_name}={val}" elif mode == "build": return f"--build-arg {env_var_name}={val}" else: raise ValueError(f"{mode} not supported") def get_proxy_build_args(): proxy_build_args = [] if os.getenv("HTTP_PROXY"): for v in ["HTTP_PROXY", "HTTPS_PROXY"]: proxy_build_args.append(get_arg_env(v, mode="build")) if os.getenv("http_proxy"): for v in ["http_proxy", "https_proxy"]: proxy_build_args.append(get_arg_env(v, mode="build")) return " ".join(proxy_build_args) def get_proxy_env_args(): proxy_build_args = [] if os.getenv("HTTP_PROXY"): for v in ["HTTP_PROXY", "HTTPS_PROXY"]: proxy_build_args.append(get_arg_env(v)) if os.getenv("http_proxy"): for v in ["http_proxy", "https_proxy"]: proxy_build_args.append(get_arg_env(v)) return " ".join(proxy_build_args) def build_img( cuda_version, oneflow_src_dir, use_aliyun_mirror, use_tuna, use_system_proxy, img_tag, dry, ): cudnn_version = 7 if str(cuda_version).startswith("11"): cudnn_version = 8 cuda_version_img = cuda_version if cuda_version == "11.2": cuda_version_img = "11.2.2" if cuda_version == "11.1": cuda_version_img = "11.1.1" if cuda_version == "11.0": cuda_version_img = "11.0.3" from_img = f"nvidia/cuda:{cuda_version_img}-cudnn{cudnn_version}-devel-centos7" tuna_build_arg = "" if use_tuna: tuna_build_arg = '--build-arg use_tuna_yum=1 --build-arg pip_args="-i https://mirrors.aliyun.com/pypi/simple"' if use_aliyun_mirror: tuna_build_arg += ' --build-arg bazel_url="https://oneflow-static.oss-cn-beijing.aliyuncs.com/deps/bazel-3.4.1-linux-x86_64"' proxy_build_arg = get_proxy_build_args() if use_system_proxy else "" cmd = f"docker build -f docker/package/manylinux/Dockerfile {proxy_build_arg} {tuna_build_arg} --build-arg from={from_img} -t {img_tag} ."
print(cmd) if dry == False: subprocess.check_call(cmd, cwd=oneflow_src_dir, shell=True) def common_cmake_args(cache_dir=None, extra_oneflow_cmake_args=None): assert cache_dir ret = "" if ( not extra_oneflow_cmake_args or "-DCMAKE_BUILD_TYPE" not in extra_oneflow_cmake_args ): ret += " -DCMAKE_BUILD_TYPE=Release" if not extra_oneflow_cmake_args or "-DBUILD_RDMA" not in extra_oneflow_cmake_args: ret += " -DBUILD_RDMA=ON" third_party_install_dir = os.path.join(cache_dir, "build-third-party-install") ret += f" -DTHIRD_PARTY_DIR={third_party_install_dir}" return ret def get_build_dir_arg(cache_dir, oneflow_src_dir): return "" build_dir_real = os.path.join(cache_dir, "build") build_dir_mount = os.path.join(oneflow_src_dir, "build") return f"-v {build_dir_real}:{build_dir_mount}" def force_rm_dir(dir_to_clean): print("cleaning:", dir_to_clean) assert dir_to_clean clean_cmd = f"docker run --network=host --rm -v {dir_to_clean}:{dir_to_clean} -w {dir_to_clean} busybox rm -rf {dir_to_clean}/*" subprocess.check_call(clean_cmd, shell=True) def create_tmp_bash_and_run(docker_cmd, img, bash_cmd, bash_args, bash_wrap, dry): with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as wrapper_f: with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f: w_name = "/host" + wrapper_f.name f_name = "/host" + f.name bash_cmd = "PATH=/opt/python/cp37-cp37m/bin:$PATH\n" + bash_cmd f.write(bash_cmd) f.flush() wrapped = f""" {bash_wrap} bash {bash_args} {f_name} """ wrapper_f.write(wrapped) wrapper_f.flush() print("=" * 5 + f"bash_cmd: {f_name}" + "=" * 5) print(bash_cmd) print("=" * 5 + f"bash_cmd: {f_name}" + "=" * 5) print("=" * 5 + f"wrapped: {w_name}" + "=" * 5) print(wrapped) print("=" * 5 + f"wrapped: {w_name}" + "=" * 5) docker_cmd = f"{docker_cmd} -v /tmp:/host/tmp {img}" cmd = f"{docker_cmd} bash {bash_args} {w_name}" print(cmd) if dry: print("dry run, skipping") else: subprocess.check_call(cmd, shell=True) def get_common_docker_args( oneflow_src_dir=None, cache_dir=None, current_dir=None, house_dir=None, use_system_proxy=True, inplace=False, ): root = Path(cache_dir) child = Path(current_dir) assert root in child.parents cwd = os.getcwd() pwd_arg = f"-v {cwd}:{cwd}" cache_dir_arg = f"-v {cache_dir}:{cache_dir}" house_dir_arg = "" if house_dir: house_dir_arg = f"-v {house_dir}:{house_dir}" build_dir_arg = get_build_dir_arg(cache_dir, oneflow_src_dir) proxy_env_arg = get_proxy_env_args() if use_system_proxy else "" inplace_attr = "" if inplace == False: inplace_attr = ":ro" cache_dir_args = " ".join( [ f"-v {os.path.join(cache_dir, 'ccache')}:/root/.ccache", f"-v {os.path.join(cache_dir, 'local')}:/root/.local", f"-v {os.path.join(cache_dir, 'cache')}:/root/.cache", ] ) return f"{cache_dir_args} -v {oneflow_src_dir}:{oneflow_src_dir}{inplace_attr} {proxy_env_arg} {pwd_arg} {house_dir_arg} {cache_dir_arg} {build_dir_arg} -w {current_dir} --shm-size=8g" def get_python_dir(inplace=True, oneflow_src_dir=None, cache_dir=None): if inplace: assert oneflow_src_dir return os.path.join(oneflow_src_dir, "python") else: assert cache_dir return os.path.join(cache_dir, "python") def build_third_party( img_tag, oneflow_src_dir, cache_dir, extra_oneflow_cmake_args, extra_docker_args, bash_args, bash_wrap, dry, use_system_proxy, inplace, ): third_party_build_dir = os.path.join(cache_dir, "build-third-party") oneflow_python_dir = get_python_dir( inplace=inplace, oneflow_src_dir=oneflow_src_dir, cache_dir=cache_dir ) if inplace: inplace_arg = "" oneflow_python_dir_cmd = "" else: inplace_arg = 
f"-DONEFLOW_PYTHON_DIR={oneflow_python_dir}" oneflow_python_dir_cmd = f""" rm -rf {oneflow_python_dir} cp -r {oneflow_src_dir}/python {oneflow_python_dir} cd {oneflow_python_dir} git init git clean -nXd git clean -fXd cd - """ cmake_cmd = " ".join( [ "cmake", common_cmake_args( cache_dir=cache_dir, extra_oneflow_cmake_args=extra_oneflow_cmake_args ), "-DTHIRD_PARTY=ON -DONEFLOW=OFF", extra_oneflow_cmake_args, oneflow_src_dir, inplace_arg, ] ) bash_cmd = f"""set -ex export ONEFLOW_PYTHON_DIR={oneflow_python_dir} {oneflow_python_dir_cmd} export PATH="$PATH:$(dirname {get_python_bin('3.6')})" export PYTHON_BIN_PATH={get_python_bin('3.6')} $PYTHON_BIN_PATH -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r {os.path.join(oneflow_src_dir, "ci/fixed-dev-requirements.txt")} $PYTHON_BIN_PATH -c "from __future__ import print_function;import numpy; print(numpy.get_include());" {cmake_cmd} cmake --build . -j `nproc` --target oneflow_deps """ common_docker_args = get_common_docker_args( oneflow_src_dir=oneflow_src_dir, cache_dir=cache_dir, current_dir=third_party_build_dir, use_system_proxy=use_system_proxy, inplace=inplace, ) docker_cmd = ( f"docker run --network=host {extra_docker_args} --rm {common_docker_args}" ) create_tmp_bash_and_run(docker_cmd, img_tag, bash_cmd, bash_args, bash_wrap, dry) def get_python_bin(version): assert version in ["3.5", "3.6", "3.7", "3.8", "3.9"] py_ver = "".join(version.split(".")) py_abi = f"cp{py_ver}-cp{py_ver}" if version in ["3.5", "3.6", "3.7"]: py_abi = f"{py_abi}m" py_root = f"/opt/python/{py_abi}" py_bin = f"{py_root}/bin/python" return py_bin def build_oneflow( img_tag, oneflow_src_dir, cache_dir, extra_oneflow_cmake_args, extra_docker_args, python_version, skip_wheel, package_name, house_dir, bash_args, bash_wrap, dry, use_system_proxy, enter_bash, skip_audit, inplace, ): oneflow_build_dir = os.path.join(cache_dir, "build-oneflow") python_bin = get_python_bin(python_version) oneflow_python_dir = get_python_dir( inplace=inplace, oneflow_src_dir=oneflow_src_dir, cache_dir=cache_dir ) if inplace: inplace_arg = "" else: inplace_arg = f"-DONEFLOW_PYTHON_DIR={oneflow_python_dir}" cmake_cmd = " ".join( [ "cmake", common_cmake_args( cache_dir=cache_dir, extra_oneflow_cmake_args=extra_oneflow_cmake_args ), "-DTHIRD_PARTY=OFF -DONEFLOW=ON", extra_oneflow_cmake_args, "-DCMAKE_EXPORT_COMPILE_COMMANDS=1", f"-DPython3_EXECUTABLE={python_bin}", f"-DCODEGEN_PYTHON_EXECUTABLE={get_python_bin('3.6')}", oneflow_src_dir, inplace_arg, ] ) common_docker_args = get_common_docker_args( oneflow_src_dir=oneflow_src_dir, cache_dir=cache_dir, current_dir=oneflow_build_dir, house_dir=house_dir, use_system_proxy=use_system_proxy, inplace=inplace, ) docker_cmd = ( f"docker run --network=host --rm {common_docker_args} {extra_docker_args}" ) if enter_bash: docker_cmd += " -it" bash_cmd = f"""set -ex export LD_LIBRARY_PATH=/opt/intel/lib/intel64_lin:/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/opt/intel/lib:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH export ONEFLOW_SRC_DIR={oneflow_src_dir} export ONEFLOW_CMAKE_CMD="{cmake_cmd}" {python_bin} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r {os.path.join(oneflow_src_dir, "ci/fixed-dev-requirements.txt")} """ if enter_bash: bash_cmd += "\nbash" else: bash_cmd += f""" cd {oneflow_python_dir} git clean -nXd -e \!oneflow/include -e \!oneflow/include/** git clean -fXd -e \!oneflow/include -e \!oneflow/include/** cd - {cmake_cmd} cmake 
--build . -j `nproc` """ if skip_wheel or enter_bash: pass else: bash_cmd += f""" cd {oneflow_python_dir} {python_bin} setup.py bdist_wheel -d /tmp/tmp_wheel --package_name {package_name} cd - """ if skip_wheel == False: if skip_audit: bash_cmd += f""" cp /tmp/tmp_wheel/*.whl {house_dir} """ else: bash_cmd += f""" auditwheel repair /tmp/tmp_wheel/*.whl --wheel-dir {house_dir} """ return create_tmp_bash_and_run( docker_cmd, img_tag, bash_cmd, bash_args, bash_wrap, dry ) def is_img_existing(tag): returncode = subprocess.run( f"docker image inspect {tag}", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ).returncode if returncode == 0: return True else: return False if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "--custom_img_tag", type=str, required=False, default=None, ) parser.add_argument( "--container_name", type=str, required=False, default=None, ) parser.add_argument( "--cache_dir", type=str, required=False, default=None, ) default_wheel_house_dir = os.path.join(os.getcwd(), "wheelhouse") parser.add_argument( "--wheel_house_dir", type=str, required=False, default=default_wheel_house_dir, ) parser.add_argument("--python_version", type=str, required=True) parser.add_argument( "--cuda_version", type=str, required=False, default="10.2", ) parser.add_argument( "--package_name", type=str, required=False, default="oneflow", ) parser.add_argument( "--extra_oneflow_cmake_args", action="append", nargs="+", default=[] ) parser.add_argument( "--extra_docker_args", type=str, required=False, default="", ) parser.add_argument( "--oneflow_src_dir", type=str, required=False, default=os.getcwd(), ) parser.add_argument( "--skip_third_party", default=False, action="store_true", required=False ) parser.add_argument( "--skip_wheel", default=False, action="store_true", required=False ) parser.add_argument( "--skip_img", default=False, action="store_true", required=False ) parser.add_argument( "--skip_audit", default=False, action="store_true", required=False ) parser.add_argument( "--build_img", default=False, action="store_true", required=False ) parser.add_argument( "--use_tuna", default=False, action="store_true", required=False ) parser.add_argument("--dry", default=False, action="store_true", required=False) parser.add_argument( "--use_system_proxy", default=False, action="store_true", required=False ) parser.add_argument("--mlir", default=False, action="store_true", required=False) parser.add_argument("--gcc4", default=False, action="store_true", required=False) parser.add_argument("--gcc7", default=False, action="store_true", required=False) parser.add_argument("--gcc9", default=False, action="store_true", required=False) parser.add_argument( "--use_aliyun_mirror", default=False, action="store_true", required=False ) parser.add_argument("--cpu", default=False, action="store_true", required=False) parser.add_argument("--bash", default=False, action="store_true", required=False) parser.add_argument("--inplace", default=False, action="store_true", required=False) parser.add_argument( "--shared_lib", default=False, action="store_true", required=False ) parser.add_argument("--retry", default=0, type=int) args = parser.parse_args() if args.skip_img: "Arg skip_img is deprecated. Setting it has no effect. 
If you want to build image, use --build_img" if args.skip_wheel: args.skip_audit = True print("args.extra_oneflow_cmake_args", args.extra_oneflow_cmake_args) assert args.package_name extra_oneflow_cmake_args = " ".join( [" ".join(l) for l in args.extra_oneflow_cmake_args] ) if (not args.gcc4) and (not args.gcc7) and (not args.gcc9): args.gcc7 = True cuda_versions = [] if args.use_aliyun_mirror: extra_oneflow_cmake_args += " -DTHIRD_PARTY_MIRROR=aliyun" if args.shared_lib: extra_oneflow_cmake_args += " -DBUILD_SHARED_LIBS=ON" if args.cpu: extra_oneflow_cmake_args += " -DBUILD_CUDA=OFF" cuda_versions = ["10.2"] else: extra_oneflow_cmake_args += " -DBUILD_CUDA=ON" cuda_versions = args.cuda_version.split(",") cuda_versions = [v.strip() for v in cuda_versions] if args.mlir: extra_oneflow_cmake_args += " -DWITH_MLIR=ON" else: extra_oneflow_cmake_args += " -DWITH_MLIR=Off" for cuda_version in cuda_versions: cache_dir = None def build(): img_tag = None img_prefix = f"oneflow-manylinux2014-cuda{cuda_version}" user = getpass.getuser() versioned_img_tag = f"{img_prefix}:0.1" if cuda_version in ["11.0", "11.1"]: versioned_img_tag = f"{img_prefix}:0.2" enforced_oneflow_cmake_args = "" enforced_oneflow_cmake_args += " -DBUILD_TESTING=ON" if float(cuda_version) >= 11: assert ( "CUDNN_STATIC" not in extra_oneflow_cmake_args ), "CUDNN_STATIC will be set to OFF if cuda_version > 11" enforced_oneflow_cmake_args += " -DCUDNN_STATIC=OFF" extra_docker_args = args.extra_docker_args if not args.container_name: args.container_name = f"manylinux-build-run-by-{getpass.getuser()}" assert args.container_name subprocess.call( f"docker rm -f {args.container_name}", shell=True, ) extra_docker_args += f" --name {args.container_name}" user_img_tag = f"{img_prefix}:{user}" inc_img_tag = f"oneflowinc/{versioned_img_tag}" img_tag = inc_img_tag if args.build_img: img_tag = user_img_tag elif args.custom_img_tag: img_tag = args.custom_img_tag else: if is_img_existing(versioned_img_tag): img_tag = versioned_img_tag elif is_img_existing(inc_img_tag): img_tag = inc_img_tag else: raise ValueError( f"img not found, please run 'docker pull {inc_img_tag}'" ) assert img_tag is not None print("using", img_tag) if args.build_img: build_img( cuda_version, args.oneflow_src_dir, args.use_aliyun_mirror, args.use_tuna, args.use_system_proxy, img_tag, args.dry, ) bash_args = "" bash_wrap = "" if args.gcc4: bash_wrap = "gcc --version" elif args.gcc7: bash_wrap = """ source scl_source enable devtoolset-7 gcc --version """ elif args.gcc9: bash_wrap = """ source scl_source enable devtoolset-9 gcc --version """ else: raise ValueError("either one in gcc4, gcc7, gcc9 must be enabled") global cache_dir if args.cache_dir: cache_dir = args.cache_dir else: cache_dir = os.path.join(os.getcwd(), "manylinux2014-build-cache") sub_dir = cuda_version if args.mlir: sub_dir += "-mlir" if args.gcc4: sub_dir += "-gcc4" if args.gcc7: sub_dir += "-gcc7" if args.gcc9: sub_dir += "-gcc9" if args.cpu: assert len(cuda_versions) == 1 sub_dir += "-cpu" if args.shared_lib: sub_dir += "-shared" cache_dir = os.path.join(cache_dir, sub_dir) if args.build_img: return if args.skip_third_party == False: build_third_party( img_tag, args.oneflow_src_dir, cache_dir, extra_oneflow_cmake_args + enforced_oneflow_cmake_args, extra_docker_args, bash_args, bash_wrap, args.dry, args.use_system_proxy, args.inplace, ) print(cuda_version.split(".")) cuda_version_literal = "".join(cuda_version.split(".")[:2]) assert len(cuda_version_literal) == 3 python_versions = 
args.python_version.split(",") python_versions = [pv.strip() for pv in python_versions] for python_version in python_versions: print("building for python version:", python_version) build_oneflow( img_tag, args.oneflow_src_dir, cache_dir, extra_oneflow_cmake_args + enforced_oneflow_cmake_args, extra_docker_args, python_version, args.skip_wheel, args.package_name, args.wheel_house_dir, bash_args, bash_wrap, args.dry, args.use_system_proxy, args.bash, args.skip_audit, args.inplace, ) try: build() except subprocess.CalledProcessError as e: print("failed: ", e.cmd, e.args) if cache_dir and args.retry > 0: print("clean: ", cache_dir, flush=True) print("start retrying...", flush=True) if args.dry: pass else: force_rm_dir(cache_dir) build() else: exit(1) ================================================ FILE: docker/package/manylinux/launch.sh ================================================ set -ex docker run --rm -it \ -v `pwd`:`pwd` \ -w `pwd` oneflow:rel-manylinux2014-cuda-11.0 bash ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = source BUILDDIR = build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile html_cn: Makefile @CN_DOCS=1 $(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)-cn" $(SPHINXOPTS) $(O) html: Makefile @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) clean: Makefile @rm -rf build build-cn ================================================ FILE: docs/requirements.txt ================================================ sphinx==3.5.4 jinja2<3.1 recommonmark==0.6.0 furo==2021.4.11b34 sphinx-copybutton==0.5.0 # above are dev dependencies --pre --find-links https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cpu oneflow ================================================ FILE: docs/source/_static/.gitkeep ================================================ ================================================ FILE: docs/source/auto_parallel.rst ================================================ Auto Parallelism ==================================================== As the scale of deep-learning models grows larger and larger, distributed training, or parallelism, is needed. Data parallelism and model parallelism has been designed to speed up the training and solve memory issues. In oneflow, SBP signature enables users to configure parallelism policy easily. However, users still need to specify the SBP property for each operator, or most of them. Users might spend a couple of days digging into the detail of parallelism and get a low throughput just because of a slight mistake in the configuration of SBP signature. .. note:: It only works on :doc:`graph` mode. Our strength ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ To get rid of all those configurations for SBP signatures, we developed auto parallelism. Still, configurations of placement are necessary and we have not supported auto placement yet. If you read this paragraph before you rush into any SBP stuff, then congratulation, you do not need to learn SBPs. You can start writing your code as you did under CPU mode. Our auto parallelism would generate a fast strategy customized for your specific models, the size of parameters, and the number of available GPUs. How to use auto parallelism? 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^ You just need to simply enable the configuration settings in the model of :doc:`graph` . Example:: import oneflow as flow class SubclassGraph(flow.nn.Graph): def __init__(self): super().__init__() # MUST be called # auto parallelism configuration self.config.enable_auto_parallel(True) # other configurations about auto parallelism # ...... def build(self): pass .. warning:: If you enable auto parallelism, OneFlow will take care of the SBP configurations of operators except for explicit ``to_global`` functions. Configuration API for auto parallelism ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. currentmodule:: oneflow.nn.graph.graph_config.GraphConfig .. autosummary:: :toctree: generated :nosignatures: enable_auto_parallel enable_auto_parallel_ignore_user_sbp_config set_auto_parallel_computation_cost_ratio set_auto_parallel_wait_time enable_auto_parallel_trunk_algo enable_auto_parallel_sbp_collector enable_auto_memory ================================================ FILE: docs/source/autograd.rst ================================================ oneflow.autograd ==================================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/autograd.html ``oneflow.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions. It requires minimal changes to the existing code - you only need to declare ``Tensor`` s for which gradients should be computed with the ``requires_grad=True`` keyword. As of now, we only support autograd for floating point ``Tensor`` types ( half, float, double and bfloat16). .. currentmodule:: oneflow.autograd .. autosummary:: :toctree: generated :nosignatures: backward grad Locally disabling gradient computation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autosummary:: :toctree: generated :nosignatures: no_grad enable_grad set_grad_enabled inference_mode .. TODO(wyg): uncomment this after aligning accumulate grad .. Default gradient layouts .. ^^^^^^^^^^^^^^^^^^^^^^^^ .. A ``param.grad`` is accumulated by replacing ``.grad`` with a .. new tensor ``.grad + new grad`` during :func:`oneflow.autograd.backward()` or .. :func:`oneflow.Tensor.backward()`. In-place operations on Tensors ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Supporting in-place operations in autograd is a hard matter, and we discourage their use in most cases. Autograd's aggressive buffer freeing and reuse makes it very efficient and there are very few occasions when in-place operations actually lower memory usage by any significant amount. Unless you're operating under heavy memory pressure, you might never need to use them. Tensor autograd functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autosummary:: :nosignatures: oneflow.Tensor.grad oneflow.Tensor.requires_grad oneflow.Tensor.is_leaf oneflow.Tensor.backward oneflow.Tensor.detach oneflow.Tensor.register_hook oneflow.Tensor.retain_grad Function ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: Function .. currentmodule:: oneflow.autograd .. autosummary:: :toctree: generated :nosignatures: Function.forward Function.backward Function.apply Context method mixins ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ When creating a new :class:`Function`, the following methods are available to `ctx`. .. currentmodule:: oneflow._oneflow_internal.autograd.Function .. autosummary:: :toctree: generated :nosignatures: FunctionCtx.mark_non_differentiable FunctionCtx.save_for_backward FunctionCtx.saved_tensors functional ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. 
currentmodule:: oneflow.autograd.functional .. autosummary:: :toctree: generated :nosignatures: vjp jvp jacobian hessian vhp hvp ================================================ FILE: docs/source/cn/__init__.py ================================================ from .math_ops import * from .activation import * ================================================ FILE: docs/source/cn/activation.py ================================================ import oneflow from oneflow.framework.docstr.utils import reset_docstr reset_docstr( oneflow.nn.ReLU, r"""ReLU(inplace=False) ReLU 激活函数,对张量中的每一个元素做 element-wise 运算,公式如下: :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` 参数: inplace: 是否做 in-place 操作。 默认为 ``False`` 形状: - Input: :math:`(N, *)` 其中 `*` 的意思是,可以指定任意维度 - Output: :math:`(N, *)` 输入形状与输出形状一致 示例: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> relu = flow.nn.ReLU() >>> ndarr = np.asarray([1, -2, 3]) >>> x = flow.Tensor(ndarr) >>> relu(x) tensor([1., 0., 3.], dtype=oneflow.float32) """, ) ================================================ FILE: docs/source/cn/math_ops.py ================================================ import oneflow from oneflow.framework.docstr.utils import reset_docstr reset_docstr( oneflow.add, r"""add(input, other) 计算 `input` 和 `other` 的和。支持 element-wise、标量和广播形式的加法。 公式为: .. math:: out = input + other 示例: .. code-block:: python >>> import numpy as np >>> import oneflow as flow # element-wise 加法 >>> x = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.add(x, y).numpy() >>> out.shape (2, 3) # 标量加法 >>> x = 5 >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.add(x, y).numpy() >>> out.shape (2, 3) # 广播加法 >>> x = flow.tensor(np.random.randn(1,1), dtype=flow.float32) >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.add(x, y).numpy() >>> out.shape (2, 3) """, ) ================================================ FILE: docs/source/conf.py ================================================ # -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import os import sys import oneflow sys.path.insert(0, os.path.abspath(".")) CN_DOCS = os.getenv("CN_DOCS") if CN_DOCS: import cn # -- Project information ----------------------------------------------------- project = u"OneFlow" copyright = u"2020, OneFlow" author = u"OneFlow" # The short X.Y version version = u"" # The full version, including alpha/beta/rc tags release = u"" # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", "recommonmark", "sphinx.ext.autosummary", "sphinx_copybutton", ] # build the templated autosummary files autosummary_generate = True numpydoc_show_class_members = False # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = { ".rst": "restructuredtext", ".txt": "markdown", ".md": "markdown", } # The master toctree document. master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = u"en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = "furo" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # # html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. # # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = "OneFlowdoc" # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ ( master_doc, "OneFlow.tex", u"OneFlow API Reference", u"Oneflow Contributors", "manual", ), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [(master_doc, "oneflow", u"OneFlow API Reference", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. 
List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ( master_doc, "OneFlow", u"OneFlow API Reference", author, "OneFlow", "OneFlow API Reference", "Miscellaneous", ), ] # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. epub_title = project # The unique identifier of the text. This can be a ISBN number # or the project homepage. # # epub_identifier = '' # A unique identification for the text. # # epub_uid = '' # A list of files that should not be packed into the epub file. epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- autodoc_default_options = { "undoc-members": True, "exclude-members": "forward, extra_repr, reset_parameters", } def should_skip_member(app, what, name, obj, skip, options): import collections is_deprecated = oneflow.is_deprecated(obj) if is_deprecated: print("skipping deprecated", what, name, obj) magical = name in ["__weakref__", "__doc__", "__module__", "__dict__"] return skip or is_deprecated or magical def setup(app): app.connect("autodoc-skip-member", should_skip_member) ================================================ FILE: docs/source/cuda.rst ================================================ oneflow.cuda =================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/cuda.html. .. currentmodule:: oneflow.cuda .. autosummary:: :toctree: generated :nosignatures: is_available device_count current_device set_device synchronize get_device_properties get_device_capability get_device_name .. note:: The :attr:`current_device` returns local rank as device index. It is different from the 'torch.current_device()' in PyTorch. Random Number Generator ------------------------- .. autosummary:: :toctree: generated :nosignatures: manual_seed_all manual_seed get_rng_state get_rng_state_all set_rng_state set_rng_state_all GPU tensor ----------------------------- .. autosummary:: :toctree: generated :nosignatures: HalfTensor FloatTensor DoubleTensor BoolTensor ByteTensor CharTensor IntTensor LongTensor Memory management ----------------------------- .. autosummary:: :toctree: generated :nosignatures: empty_cache ================================================ FILE: docs/source/distributed.rst ================================================ oneflow.distributed ========================================================= .. note :: Please refer to `OneFlow Distributed Overview `__ for a brief introduction to all features related to distributed training. OneFlow provides two ways to accomplish `Distributed Training`: - The first way is that users are recommended to use OneFlow's global Tensor for distributed training. Global Tensor regards the computing cluster as a supercomputing device, allowing users to write distributed training code just like in a single-machine environment. - OneFlow also provides a DDP(DistributedDataParallel) module aligned with PyTorch. DDP has been well-known and widely used in data parallelism by the majority of PyTorch users. Also see `PyTorch DDP introduction `_. Basic ------------------------------- When you start distributed training in OneFlow, the following functions can be used. .. currentmodule:: oneflow.env .. 
autosummary:: :toctree: generated :nosignatures: get_world_size get_rank get_local_rank get_node_size init_rdma rdma_is_initialized `Global Tensor` -------------------------------------------------------------- Construct `Global Tensor` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ A `Global Tensor` can be created with a ``placement`` and a ``sbp``. The ``placement`` describes the physical devices of the global tensor will be allocated, and the ``sbp`` describes its distribution among these devices. :: >>>import oneflow as flow >>> # Place a global tensor on cuda device of rank(process) 0 and 1 >>> placement = flow.placement(type="cuda", ranks=[0, 1]) >>> # Each rank's local data is a part data as a result of spliting global data on dim 0 >>> sbp = flow.sbp.split(dim=0) >>> # Create a global tensor by randn >>> x = flow.randn(4, 5, placement=placement, sbp=sbp) >>> x.shape oneflow.Size([4, 5]) Convert `Local Tensor` to `Global Tensor` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ With ``Tensor.to_global`` interface, `Local Tensor` can create a `Global Tensor` and use that `Local Tensor` as its local component at the current node. Two `local tensors` with the shape of ``(2,5)`` are created separately on two devices. While after the ``to_global`` method, the `global tensor` with a shape of ``(4,5)`` is obtained. Code running on Node 0 :: import oneflow as flow x = flow.randn(2,5) placement = flow.placement("cuda", [0,1]) sbp = flow.sbp.split(0) x_global = x.to_global(placement=placement, sbp=sbp) x_global.shape Code running on Node 1 :: import oneflow as flow x = flow.randn(2,5) placement = flow.placement("cuda", [0,1]) sbp = flow.sbp.split(0) x_global = x.to_global(placement=placement, sbp=sbp) x_global.shape Redistribute `Global Tensor` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Redistributing a `Global Tensor` means moving its data to another device group (or placement), or changing its data distribution (or SBP) across the group, or both at the same time. The redistributed tensor is still a `Global Tensor`. :: >>> import oneflow as flow >>> x = flow.tensor([1.0, 2.0], placement=flow.placement("cuda", ranks=[0, 1]), sbp=flow.sbp.split(0)) >>> y = x.to_global(placement=flow.placement("cuda", ranks=[2, 3]), sbp=flow.sbp.broadcast) According to the operator's semantics, OneFlow defines a sequence of valid input and output SBP combinations for each built-in operator. So OneFlow could automatically redistribute the `Global Tensor` to satisfy the operator's SBP requirements for its input Tensor. For example, the following code: :: >>> import oneflow as flow >>> x = flow.randn(4, 4, placement=flow.placement("cuda", ranks=[0, 1]), sbp=flow.sbp.split(0)) >>> y = flow.randn(4, 4, placement=flow.placement("cuda", ranks=[0, 1]), sbp=flow.sbp.split(1)) >>> z = x + y When ``x + y`` is executed, since x is split along dimension ``0`` and y is split along dimension ``1``, their local components at each node can not be added directly, then OneFlow will automatically redistribute one of x and y to make them have the same SBP, and complete the add operation successfully. .. note :: - Global Tensor can not be used in combination with DDP currently. - Global Tensor requires all devices to execute at the same pace, otherwise, it may cause multi-process deadlock. Get Local Tensor from Global Tensor ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ With ``Tensor.to_local`` interface, the `Global Tensor` can return its local component at the current node. 
:: y = x.to_local() y.is_local True y tensor([[ 2.9186e-01, -3.9442e-01, 4.7072e-04, -3.2216e-01, 1.7788e-01], [-4.5284e-01, 1.2361e-01, -3.5962e-01, 2.6651e-01, 1.2951e+00]], device='cuda:0', dtype=oneflow.float32) DistributedDataParallel -------------------------------------------------------------- For more information about DistributedDataParallel, see ``nn.parallel.DistributedDataParallel`` The following script shows the process of using ``oneflow.nn.parallel.DistributedDataParallel`` for training data parallel: .. code-block:: import oneflow as flow from oneflow.nn.parallel import DistributedDataParallel as ddp train_x = [ flow.tensor([[1, 2], [2, 3]], dtype=flow.float32), flow.tensor([[4, 6], [3, 1]], dtype=flow.float32), ] train_y = [ flow.tensor([[8], [13]], dtype=flow.float32), flow.tensor([[26], [9]], dtype=flow.float32), ] class Model(flow.nn.Module): def __init__(self): super().__init__() self.lr = 0.01 self.iter_count = 500 self.w = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32)) def forward(self, x): x = flow.matmul(x, self.w) return x m = Model().to("cuda") m = ddp(m) loss = flow.nn.MSELoss(reduction="sum") optimizer = flow.optim.SGD(m.parameters(), m.lr) for i in range(0, m.iter_count): rank = flow.env.get_rank() x = train_x[rank].to("cuda") y = train_y[rank].to("cuda") y_pred = m(x) l = loss(y_pred, y) if (i + 1) % 50 == 0: print(f"{i+1}/{m.iter_count} loss:{l}") optimizer.zero_grad() l.backward() optimizer.step() print(f"\nw:{m.w}") There are only two differences between the data parallelism training code and the stand-alone single-card script: - Use `DistributedDataParallel` to wrap the module object (`m = ddp(m)`) - Use `get_rank` to get the current device number and distribute the data to the device. Then use `launcher` to run the script, leave everything else to OneFlow, which makes distributed training as simple as stand-alone single-card training: :: python3 -m oneflow.distributed.launch --nproc_per_node 2 ./ddp_train.py Communication collectives ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. currentmodule:: oneflow.comm .. autosummary:: :toctree: generated :nosignatures: all_reduce all_gather all_gather_into_tensor all_to_all broadcast barrier gather reduce reduce_scatter reduce_scatter_tensor recv scatter send We also provide PyTorch-compatible APIs for communication collectives, for example, `oneflow.distributed.all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False)`. For more information, see `PyTorch Distributed Communication `_. Note that we currently only support op=ReduceOp.SUM, group=None and async_op=False in these operations. Launching distributed training -------------------------------------------------------------- .. currentmodule:: oneflow.distributed run commands below to see more about usage. :: python3 -m oneflow.distributed.launch -h .. code-block:: usage: launch.py [-h] [--nnodes NNODES] [--node_rank NODE_RANK] [--nproc_per_node NPROC_PER_NODE] [--master_addr MASTER_ADDR] [--master_port MASTER_PORT] [-m] [--no_python] [--redirect_stdout_and_stderr] [--logdir LOGDIR] training_script ... 
OneFlow distributed training launch helper utility that will spawn up multiple distributed processes positional arguments: training_script The full path to the single GPU training program/script to be launched in parallel, followed by all the arguments for the training script training_script_args optional arguments: -h, --help show this help message and exit --nnodes NNODES The number of nodes to use for distributed training --node_rank NODE_RANK The rank of the node for multi-node distributed training --nproc_per_node NPROC_PER_NODE The number of processes to launch on each node, for GPU training, this is recommended to be set to the number of GPUs in your system so that each process can be bound to a single GPU. --master_addr MASTER_ADDR Master node (rank 0)'s address, should be either the IP address or the hostname of node 0, for single node multi-proc training, the --master_addr can simply be 127.0.0.1 --master_port MASTER_PORT Master node (rank 0)'s free port that needs to be used for communication during distributed training -m, --module Changes each process to interpret the launch script as a python module, executing with the same behavior as'python -m'. --no_python Do not prepend the training script with "python" - just exec it directly. Useful when the script is not a Python script. --redirect_stdout_and_stderr write the stdout and stderr to files 'stdout' and 'stderr'. Only available when logdir is set --logdir LOGDIR Relative path to write subprocess logs to. Passing in a relative path will create a directory if needed. Note that successive runs with the same path to write logs to will overwrite existing logs, so be sure to save logs as needed. ================================================ FILE: docs/source/distributions.rst ================================================ oneflow.distributions ================================================== .. contents:: oneflow.distributions :depth: 2 :local: :class: this-will-duplicate-information-and-it-is-still-useful-here :backlinks: top .. currentmodule:: oneflow.distributions .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst Distribution Categorical ================================================ FILE: docs/source/environment_variables.rst ================================================ Environment Variables ================================================ OneFlow has an extensive set of environment variables to tune for specific usage. `ONEFLOW_COMM_NET_IB_HCA `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- When there are multiple IB NIC(which can be checked by ``ibstatus`` on the server), the system uses the first IB NIC for comm_net communication by default. When this environment variable is set, the system will check all IB NIC and find the NIC with the corresponding name. `#5626 `_ Values accepted ^^^^^^^^^^^^^^^ The default value is empty, such as ``mlx5_0:1``、 ``mlx5_1:1``. When the port is 0, the default value is 1, representing the first port. `ONEFLOW_COMM_NET_IB_GID_INDEX `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- For the query of `ibv_query_gid `_, and 0 represents success. It often used with ``ONEFLOW_COMM_NET_IB_HCA``. 
GID means the Global ID. Under a RoCE network, the QP must be built with this value, instead of just using the LID as in an IB network. `#5626 `_

Values accepted
^^^^^^^^^^^^^^^
The default value is ``0``, representing the port index.

`ONEFLOW_COMM_NET_IB_QUEUE_DEPTH `_
------------------------------------------------------------------------------------------
Queue length of jobs in the IB network. This value effectively controls the queue size instead of using IB's default, similar to ``ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE``.

Values accepted
^^^^^^^^^^^^^^^
The default value is ``1024``, received as an ``int64_t``. The system compares it with ``max_qp_wr`` (the maximum number of outstanding WRs on any work queue) and takes the smaller one.

`ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE `_
------------------------------------------------------------------------------------------
The size of each memory block read during communication. The value is used to compute the number of blocks, which are transmitted after encapsulation.

Values accepted
^^^^^^^^^^^^^^^
The default value is ``8388608`` (8M).

`ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC `_
------------------------------------------------------------------------------------------
Applies to streams and marks blocking synchronization in CUDA. `Detailed information `_, `#5612 `_, `#5837 `_

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.

`ONEFLOW_LIBIBVERBS_PATH `_
------------------------------------------------------------------------------------------
Loads the dynamic library with dlopen at runtime so that the symbols of the ibverbs functions can be resolved without linking at compile time, for better compatibility. `#4852 `_. If loading fails, it outputs ``libibverbs not available, ibv_fork_init skipped``; if it succeeds, ``import oneflow`` prints something like ``loaded library: /usr/lib/x86_64-linux-gnu/libibverbs.so.1``.

Values accepted
^^^^^^^^^^^^^^^
The default value is empty, in which case ``libibverbs.so.1`` and ``libibverbs.so`` are loaded.

`ONEFLOW_DEBUG_MODE `_
------------------------------------------------------------------------------------------
Enables ``debug`` mode; ``ONEFLOW_DEBUG`` works as well. If ``debug`` mode is on, more INFO-level logs are produced and different ``prototxt`` and ``dot`` files are written. The automatically inserted boxing information will be printed to the log file under eager global mode.

Values accepted
^^^^^^^^^^^^^^^
The default value is empty, but any string is accepted.

`ONEFLOW_DRY_RUN `_
------------------------------------------------------------------------------------------
Only for test runs; it can generate log files such as ``dot``. It exits once the test run succeeds and does not attempt real training.

Values accepted
^^^^^^^^^^^^^^^
The default value is empty, but any string is accepted.

`ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS `_
------------------------------------------------------------------------------------------
Only used when debugging because performance is affected; it detects which op in the network produces nan or inf. It creates a ``CpuCheckNumericsKernelObserver`` under ``cpu`` and a ``CudaCheckNumericsKernelObserver`` under ``cuda``. `#6052 `_

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.

`ONEFLOW_DEBUG_KERNEL_SYNC_CHECK `_
------------------------------------------------------------------------------------------
Only used when debugging because performance is affected. It creates a ``SyncCheckKernelObserver`` that synchronizes after each kernel, which can be used to debug CUDA errors. `#6052 `_

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.

`ONEFLOW_PROFILER_KERNEL_PROFILE_CUDA_MEMORY_BANDWIDTH `_
------------------------------------------------------------------------------------------
Used when generating profiler files with nsys. The profiler is currently only valid for lazy mode. It estimates the memory bandwidth reached by a kernel from the execution time of the GPU kernel and the size of its input and output memory, which helps find kernels that could be optimized. `Details `_

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``. To use it, the package must be compiled with ``BUILD_PROFILER`` enabled.

`ONEFLOW_PROFILER_KERNEL_PROFILE_KERNEL_FORWARD_RANGE `_
------------------------------------------------------------------------------------------
The same as above; additionally collects the `op name `_.

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``. To use it, the package must be compiled with ``BUILD_PROFILER`` enabled.

`ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER `_
------------------------------------------------------------------------------------------
Disables the blob_access_checker, which exists for correctness assurance; skipping it can avoid its per-kernel overhead in some cases. `#5728 `_

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.

`ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH `_
------------------------------------------------------------------------------------------
Takes effect under ``WITH_CUDA_GRAPHS`` and the default value is ``false``. Turning on CUDA Graphs support uses more memory, so when there is only just enough memory, it will not run. `#5868 `_

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.

`ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR `_
------------------------------------------------------------------------------------------
LightActor is a new type of Actor that only handles NormalForward and similar tasks where every regst_num is 1, or tasks with only one kernel. `#5868 `_. ``export ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1`` (which uses more memory), ``export ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE=1``, ``export ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER=1``, ``export ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR=1`` and ``export ONEFLOW_STREAM_REUSE_CUDA_EVENT=1`` can be used together.

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.

`ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE `_
------------------------------------------------------------------------------------------
Enables the local message queue; ``oneflow.config.thread_enable_local_message_queue(True)`` is no longer used. `#5720 `_.

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.

`ONEFLOW_PERSISTENT_IN_STREAM_BUFFER_SIZE_BYTES `_
------------------------------------------------------------------------------------------
Represents the size of each read from disk. `#5162 `_

Values accepted
^^^^^^^^^^^^^^^
The default value is empty. If an invalid string or a negative number is entered, the value falls back to ``32 * 1024`` (32KB).

`ONEFLOW_DECODER_ENABLE_NVJPEG_HARDWARE_ACCELERATION `_
------------------------------------------------------------------------------------------
``NVJPEG_VER_MAJOR`` needs to be bigger than ``11``. It enables nvjpeg hardware acceleration and warms up the jpeg decoder and the hw_jpeg decoder (the hardware JPEG decoder and the NVIDIA nvJPEG library on NVIDIA A100 GPUs). `#5851 `_.

Values accepted
^^^^^^^^^^^^^^^
Defined and set to ``true``; it is ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.

`ONEFLOW_SERVING_DEBUG `_
------------------------------------------------------------------------------------------
Prints debugging information for OneFlow Serving.

Values accepted
^^^^^^^^^^^^^^^
The default value is ``false``

`ONEFLOW_DISABLE_VIEW `_
------------------------------------------------------------------------------------------
Disables the view mechanism, which means that ops related to view stop running.
Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Whether to disable Middle Node. When it is false, all inter-SBP communication is supported Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Whether to disable NUMA_AWARE memory allocation when the OneEmbedding module allocates video memory. NUMA_AWARE memory allocation means that when allocating pinned host memory, the cpu close to the gpu will be considered (for example, if it is gpu 0 1, memory will be allocated on cpu0) Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Whether to allow CUDA to use TF32 numeric types for computation Values accepted ^^^^^^^^^^^^^^^ The default value is ``true`` `ONEFLOW_FUNCTOR_DISABLE_FUSED_MLP `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Whether to disable the fused_mlp operator implemented by cublasLt in FusedMLPFunctor, if disabled, it will degenerate into a multiple matrix multiplication operation. Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Whether to put the EmbeddingShuffle of the OneEmbedding module on a separate stream for overlapping execution. Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16 `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Whether to allow the EmbeddingGradientShuffle operator of the OneEmbedding module to use the FP16 data type in the AMP case. 
Values accepted ^^^^^^^^^^^^^^^ The default value is ``true`` `ONEFLOW_ONE_EMBEDDING_NOT_FUSE_CAST_TO_UPDATE `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Whether to disable the fusion of cast type conversion and parameter update of OneEmbedding parameters into one operator in the case of AMP Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS_DUMP `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- When the value appears NaN or Inf, save the data Dump. Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_MLIR_ENABLE_IR_PRINTING `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Control whether to print ir when running each pass when debugging Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_MLIR_STDOUT `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Control whether MLIR outputs log information in the console Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_MLIR_DUMP_IR `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Control whether to dump ir files Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_MLIR_ENABLE_ROUND_TRIP `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Control whether Oneflow Job goes into MLIR Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_KERNEL_REDUCE_SUM_USE_MATMUL `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- whether to use matrix multiplication for reduce_sum Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Whether to quantify the shuffle application communication in the case of OneEmbedding multi-card Values accepted ^^^^^^^^^^^^^^^ The default value is ``false`` `ONEFLOW_TENSOR_BUFFER_ALIGNED_SIZE `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Align size when allocating TensorBuffer memory Values 
accepted ^^^^^^^^^^^^^^^ The default value is ``1024`` `ONEFLOW_TENSOR_BUFFER_POOL_THREAD_LOCAL_CACHE_SIZE `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Control the size of ``thread_local_cache`` in TensorBufferPool Values accepted ^^^^^^^^^^^^^^^ The default value is ``64`` `ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Set the maximum size of the gRPC transport message Values accepted ^^^^^^^^^^^^^^^ The default value is ``-1`` `ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_CAPACITY_HINT `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Control the initial capacity of the PersistentTable of OneEmbedding to avoid frequent expansion Values accepted ^^^^^^^^^^^^^^^ OneEmbedding will calculate according to the actual situation, and users can also choose to configure a larger capacity. `ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_NUM_WORKERS `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- The number of threads used for reading and writing the PersistentTable of OneEmbedding Values accepted ^^^^^^^^^^^^^^^ The default value is ``4`` `ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Specify the size of the all zero and all one buffers on the CUDA device. This buffer can be used with matrix multiplication to implement operations such as reduce_sum Values accepted ^^^^^^^^^^^^^^^ The default value is ``1024x1024`` `OMP_NUM_THREADS `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Set the number of threads used by OMP Values accepted ^^^^^^^^^^^^^^^ The default value will be generated by specific `computational logic `_. `SBP_INFER_RULE_TAG `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Specify SBP derivation rules Values accepted ^^^^^^^^^^^^^^^ When the default value is ``1`` , select the SBP that satisfies the producer or the SBP with the smallest cost as much as possible. When the default value is ``2``, select the SBP that matches the most. When the default value is ``3``, select the SBP with the smallest cost. 
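
Most of the variables on this page are plain process environment variables: they are usually exported in the shell before launching the training script, or set in Python via ``os.environ`` before ``import oneflow`` for values that are read during initialization. Below is a minimal sketch; the two variables chosen are just examples from this page and the values are only illustrative:

.. code-block:: python

    import os

    # Set environment variables before importing oneflow; many of them are
    # read when the library initializes. The values below are illustrative.
    os.environ["ONEFLOW_DEBUG_MODE"] = "1"     # emit more INFO-level logs
    os.environ["SBP_INFER_RULE_TAG"] = "3"     # prefer the SBP with the smallest cost

    import oneflow as flow

    # Any OneFlow code from this point on runs with the configuration above.
    x = flow.randn(2, 3)
    print(x.shape)

The equivalent shell form is ``export ONEFLOW_DEBUG_MODE=1`` executed before launching the Python script.
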
`ONEFLOW_TENSOR_BUFFER_GROWTH_FACTOR `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Control the growth factor of TensorBuffer Values accepted ^^^^^^^^^^^^^^^ The default value is ``1.0`` `ONEFLOW_TENSOR_BUFFER_SHRINK_FACTOR `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Controls the shrink factor of TensorBuffer Values accepted ^^^^^^^^^^^^^^^ The default value is ``0.7`` `ONEFLOW_TENSOR_BUFFER_POOL_SIZE_FACTOR `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Controls the size factor of TensorBuffer Values accepted ^^^^^^^^^^^^^^^ The default value is ``2.0`` `AUTO_PARALLEL_TRANSFER_COST `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Control the size of the automatic parallel transfer cost Values accepted ^^^^^^^^^^^^^^^ The default value is ``1.65e8`` `ONEFLOW_DEBUG_PASS `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Pass names and print job before and after a specific pass, such as ``export ONEFLOW_DEBUG_PASS="FuseAddToOutputPass``. Or ALL, print job before and after a specific pass, such as ``export ONEFLOW_DEBUG_PASS="ALL"``. Values accepted ^^^^^^^^^^^^^^^ The default value is ``empty`` `ONEFLOW_PROFILER_HOST_THREAD_NAME_PREFIX `_ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Add a prefix to the name of the named host thread in the profiling context to facilitate sorting in the visualization tool (nsight) Values accepted ^^^^^^^^^^^^^^^ The default value is ``empty`` ================================================ FILE: docs/source/graph.rst ================================================ oneflow.nn.Graph ============================================================ Base class for running neural networks in Static Graph Mode. Currently, there are two main ways to run models in deep learning frameworks, namely dynamic graphs and static graphs , which are also conventionally referred to as :ref:`dynamic graph` and :ref:`static graph` in OneFlow. Both approaches have their advantages and disadvantages, and OneFlow provides support for both approaches, with Eager mode being the default. Generally speaking, dynamic graphs are easier to use and static graphs have more performance advantages. :class:`oneflow.nn.Graph` module is provided by OneFlow to allow users to build static graphs and train models with Eager-like programming conventions. .. contents:: oneflow.nn.Graph :depth: 2 :local: :class: this-will-duplicate-information-and-it-is-still-useful-here :backlinks: top .. 
_dynamic graph: Eager Mode to Static Graph Mode ------------------------------------------------------------ OneFlow runs in Eager mode by default. OneFlow's nn.Graph is programmed in a style very similar to Eager mode, so it is possible to make small changes and get large performance gains. The following script shows the process of building a neural network in eager mode using the interfaces under ``oneflow.nn``: .. code-block:: import oneflow as flow import oneflow.nn as nn class ModuleMyLinear(nn.Module): def __init__(self, in_features, out_features): super().__init__() self.weight = nn.Parameter(flow.randn(in_features, out_features)) self.bias = nn.Parameter(flow.randn(out_features)) def forward(self, input): return flow.matmul(input, self.weight) + self.bias linear_model = ModuleMyLinear(4, 3) An eager ``nn.Module`` can be reused by ``nn.Graph``. The above eager-mode script can be changed to static graph mode by adding just a few lines of code, which consists of the following steps: - Define your customized graph as a subclass of ``nn.Graph`` - At the beginning of ``__init__``, call ``super().__init__()`` to let OneFlow do the necessary initialization of the Graph - Reuse the ``nn.Module`` object from Eager mode in ``__init__`` (``self.model = model``) - Describe the computation in the ``build`` method - Instantiate your graph, then call it. .. code-block:: class GraphMyLinear(nn.Graph): def __init__(self): super().__init__() self.model = linear_model def build(self, input): return self.model(input) graph_mylinear = GraphMyLinear() input = flow.randn(1, 4) out = graph_mylinear(input) print(out) tensor([[-0.3298, -3.7907, 0.1661]], dtype=oneflow.float32) .. _static graph: Static Graph Mode ------------------------------------------------------------ Constructing a Graph ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Base class for training or evaluating a neural network in static graph mode. .. currentmodule:: oneflow.nn.Graph .. autosummary:: :toctree: generated :nosignatures: __init__ build add_optimizer set_grad_scaler Executing a Graph ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Call an nn.Graph instance to run a customized graph. .. currentmodule:: oneflow.nn.Graph .. autosummary:: :toctree: generated :nosignatures: __call__ Config options on a Graph ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Optimization options of an nn.Graph. .. currentmodule:: oneflow.nn.graph.graph_config.GraphConfig .. autosummary:: :toctree: generated :nosignatures: enable_amp enable_zero allow_fuse_model_update_ops allow_fuse_add_to_output allow_fuse_cast_scale set_gradient_accumulation_steps enable_cudnn_conv_heuristic_search_algo enable_straighten_algorithm enable_compress_memory Config options on a GraphModule ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ GraphModule is the graph representation of an nn.Module in an nn.Graph. When an nn.Module is added into an nn.Graph, it is wrapped into a ProxyModule. The ProxyModule has a GraphModule inside it. You can get and set the GraphModule to enable graph optimization on the nn.Module. .. currentmodule:: oneflow.nn.graph.graph_block.GraphModule .. autosummary:: :toctree: generated :nosignatures: set_stage activation_checkpointing Save & Load a Model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. currentmodule:: oneflow.nn.Graph .. autosummary:: :toctree: generated :nosignatures: state_dict load_state_dict Debug a Graph ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ..
autosummary:: :toctree: generated :nosignatures: __repr__ debug name ================================================ FILE: docs/source/hub.rst ================================================ oneflow.hub =================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/hub.html Oneflow Hub is a pre-trained model repository designed to facilitate research reproducibility. Publishing models ----------------- Oneflow Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a github repository by adding a simple ``hubconf.py`` file; ``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function (example: a pre-trained model you want to publish). :: def entrypoint_name(*args, **kwargs): # args & kwargs are optional, for models which take positional/keyword arguments. ... How to implement an entrypoint? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Here is a code snippet that specifies an entrypoint for the ``resnet18`` model if we expand the implementation in ``Oneflow-Inc/vision/hubconf.py``. In most cases, importing the right function in ``hubconf.py`` is sufficient. Here we just want to use the expanded version as an example to show how it works. You can see the full script in the `Oneflow-Inc/vision repo `_ :: dependencies = ['oneflow'] from flowvision.models.resnet import resnet18 as _resnet18 # resnet18 is the name of entrypoint def resnet18(pretrained=False, **kwargs): """ # This docstring shows up in hub.help() Resnet18 model pretrained (bool): kwargs, load pretrained weights into the model """ # Call the model, load pretrained weights model = _resnet18(pretrained=pretrained, **kwargs) return model - The ``dependencies`` variable is a **list** of package names required to **load** the model. Note this might be slightly different from the dependencies required for training a model. - ``args`` and ``kwargs`` are passed along to the real callable function. - The docstring of the function works as a help message. It explains what the model does and what the allowed positional/keyword arguments are. It's highly recommended to add a few examples here. - An entrypoint function can either return a model (``nn.Module``), or auxiliary tools to make the user workflow smoother, e.g. tokenizers. - Callables prefixed with an underscore are considered helper functions which won't show up in :func:`oneflow.hub.list()`. - Pretrained weights can either be stored locally in the github repo, or loadable by :func:`oneflow.hub.load_state_dict_from_url()`. If they are less than 2GB, it's recommended to attach them to a `project release `_ and use the url from the release. In the example above ``flowvision.models.resnet.resnet18`` handles ``pretrained``; alternatively, you can put the following logic in the entrypoint definition. :: if pretrained: # For checkpoint saved in local github repo, e.g. =weights/save.pth dirname = os.path.dirname(__file__) checkpoint = os.path.join(dirname, ) state_dict = oneflow.load(checkpoint) model.load_state_dict(state_dict) # For checkpoint saved elsewhere checkpoint = 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip' model.load_state_dict(oneflow.hub.load_state_dict_from_url(checkpoint, progress=False)) Important Notice ^^^^^^^^^^^^^^^^ - The published models should be at least on a branch/tag. They can't be a random commit.
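Once an entrypoint such as the one above is published on a branch or tag, it can be consumed with :func:`oneflow.hub.load` (described in the next section). Below is a minimal sketch; the repository spec, entrypoint name, and ``pretrained`` argument follow the example above, and the assumption that the entrypoint is available on the repository's default branch is made here purely for illustration: ::

    import oneflow.hub as hub

    # Downloads (or reuses the cached copy of) the repo and calls its resnet18 entrypoint.
    model = hub.load("Oneflow-Inc/vision", "resnet18", pretrained=True)
    model.eval()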
Loading models from Hub ----------------------- OneFlow Hub provides convenient APIs to explore all available models in hub through :func:`oneflow.hub.list()`, show docstring and examples through :func:`oneflow.hub.help()` and load the pre-trained models using :func:`oneflow.hub.load()`. .. automodule:: oneflow.hub .. autofunction:: list .. autofunction:: help .. autofunction:: load .. autofunction:: download_url_to_file .. autofunction:: load_state_dict_from_url Running a loaded model: ^^^^^^^^^^^^^^^^^^^^^^^ Note that ``*args`` and ``**kwargs`` in :func:`oneflow.hub.load()` are used to **instantiate** a model. After you have loaded a model, how can you find out what you can do with the model? A suggested workflow is - ``dir(model)`` to see all available methods of the model. - ``help(model.foo)`` to check what arguments ``model.foo`` takes to run To help users explore without referring to documentation back and forth, we strongly recommend repo owners make function help messages clear and succinct. It's also helpful to include a minimal working example. Where are my downloaded models saved? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The locations are used in the order of - Calling ``hub.set_dir()`` - ``$ONEFLOW_HOME/hub``, if environment variable ``ONEFLOW_HOME`` is set. - ``$XDG_CACHE_HOME/oneflow/hub``, if environment variable ``XDG_CACHE_HOME`` is set. - ``~/.cache/oneflow/hub`` .. autofunction:: get_dir .. autofunction:: set_dir Caching logic ^^^^^^^^^^^^^ By default, we don't clean up files after loading it. Hub uses the cache by default if it already exists in the directory returned by :func:`~oneflow.hub.get_dir()`. Users can force a reload by calling ``hub.load(..., force_reload=True)``. This will delete the existing github folder and downloaded weights, reinitialize a fresh download. This is useful when updates are published to the same branch, users can keep up with the latest release. Known limitations: ^^^^^^^^^^^^^^^^^^ Oneflow hub works by importing the package as if it was installed. There are some side effects introduced by importing in Python. For example, you can see new items in Python caches ``sys.modules`` and ``sys.path_importer_cache`` which is normal Python behavior. This also means that you may have import errors when importing different models from different repos, if the repos have the same sub-package names (typically, a ``model`` subpackage). A workaround for these kinds of import errors is to remove the offending sub-package from the ``sys.modules`` dict; more details can be found in `this github issue `_. A known limitation that is worth mentioning here: users **CANNOT** load two different branches of the same repo in the **same python process**. It's just like installing two packages with the same name in Python, which is not good. Cache might join the party and give you surprises if you actually try that. Of course it's totally fine to load them in separate processes. ================================================ FILE: docs/source/image.rst ================================================ oneflow.nn.image ====================================== Image operations for neural networks -------------------------------------- .. currentmodule:: oneflow.nn.image .. 
autosummary:: :toctree: generated :nosignatures: Resize batch_align decode flip normalize ================================================ FILE: docs/source/index.rst ================================================ OneFlow API Reference =================================== Distributed performance (high efficiency) is the core technical difficulty of deep learning frameworks. OneFlow upholds the core concept and architecture of static compilation and streaming parallelism around performance improvement and heterogeneous distributed scaling, solving the challenge of memory wall at cluster level with world-leading technology. .. toctree:: :maxdepth: 1 troubleshooting .. toctree:: :maxdepth: 1 :caption: OneFlow Python API oneflow nn nn.functional tensor tensor_attributes type_info autograd cuda distributed distributions hub linalg nn.init optim graph auto_parallel image utils.data utils.global_view utils.tensor one_embedding environment_variables special Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/source/linalg.rst ================================================ oneflow.linalg =================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/linalg.html Common linear algebra operations. Matrix Properties ----------------- .. currentmodule:: oneflow.linalg .. autosummary:: :toctree: generated :nosignatures: norm vector_norm matrix_norm diagonal inv cross det ================================================ FILE: docs/source/nn.functional.rst ================================================ oneflow.nn.functional =========================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.functional.html. .. contents:: oneflow.nn.functional :depth: 2 :local: :class: this-will-duplicate-information-and-it-is-still-useful-here :backlinks: top .. currentmodule:: oneflow.nn.functional Convolution functions ------------------------------------------- .. autosummary:: :toctree: generated :nosignatures: conv1d conv2d conv3d conv_transpose1d conv_transpose2d conv_transpose3d fold unfold Normalization functions ----------------------- .. autosummary:: :toctree: generated :nosignatures: batch_norm layer_norm normalize group_norm Pooling functions ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: avg_pool1d avg_pool2d avg_pool3d max_pool1d max_pool2d max_pool3d max_unpool1d max_unpool2d max_unpool3d adaptive_avg_pool1d adaptive_avg_pool2d adaptive_avg_pool3d adaptive_max_pool1d adaptive_max_pool2d adaptive_max_pool3d Non-linear activation functions ------------------------------- .. autosummary:: :toctree: generated :nosignatures: threshold relu hardtanh hardswish relu6 elu selu celu leaky_relu square_relu prelu glu gelu quick_gelu logsigmoid hardshrink softsign softplus softmax softshrink log_softmax gumbel_softmax tanh sigmoid hardsigmoid silu mish Linear functions ---------------- .. autosummary:: :toctree: generated :nosignatures: linear Dropout functions ----------------- .. autosummary:: :toctree: generated :nosignatures: dropout dropout1d dropout2d dropout3d Sparse functions ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: embedding one_hot Distance functions ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: cosine_similarity pairwise_distance Loss functions -------------- .. 
autosummary:: :toctree: generated :nosignatures: sparse_softmax_cross_entropy cross_entropy ctc_loss l1_loss mse_loss smooth_l1_loss triplet_margin_loss binary_cross_entropy binary_cross_entropy_with_logits Vision functions ---------------- .. autosummary:: :toctree: generated :nosignatures: deform_conv2d pad interpolate upsample grid_sample affine_grid Greedy decoder ---------------- .. autosummary:: :toctree: generated :nosignatures: ctc_greedy_decoder ================================================ FILE: docs/source/nn.init.rst ================================================ oneflow.nn.init =============== .. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html .. currentmodule:: oneflow.nn.init .. autofunction:: calculate_gain .. autofunction:: uniform_ .. autofunction:: normal_ .. autofunction:: constant_ .. autofunction:: ones_ .. autofunction:: zeros_ .. autofunction:: xavier_uniform_ .. autofunction:: xavier_normal_ .. autofunction:: kaiming_uniform_ .. autofunction:: kaiming_normal_ .. autofunction:: trunc_normal_ .. autofunction:: orthogonal_ ================================================ FILE: docs/source/nn.rst ================================================ oneflow.nn =================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.html These are the basic building blocks for graphs: .. contents:: oneflow.nn :depth: 2 :local: :class: this-will-duplicate-information-and-it-is-still-useful-here :backlinks: top .. currentmodule:: oneflow.nn .. autosummary:: :toctree: generated :nosignatures: :template: Parameter Containers ---------------------------------- .. currentmodule:: oneflow.nn .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst Module Sequential ModuleList ModuleDict ParameterList ParameterDict nn.Module ---------------------------------- .. currentmodule:: oneflow.nn.Module .. autosummary:: :toctree: generated :nosignatures: add_module apply buffers children cpu cuda double train eval extra_repr float forward load_state_dict modules named_buffers named_children named_modules named_parameters parameters register_buffer register_forward_hook register_forward_pre_hook register_backward_hook register_full_backward_hook register_state_dict_pre_hook register_parameter requires_grad_ state_dict to zero_grad Containers Convolution Layers ---------------------------------- .. currentmodule:: oneflow .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.Conv1d nn.Conv2d nn.Conv3d nn.ConvTranspose1d nn.ConvTranspose2d nn.ConvTranspose3d nn.Unfold nn.Fold Pooling Layers ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.MaxPool1d nn.MaxPool2d nn.MaxPool3d nn.MaxUnpool1d nn.MaxUnpool2d nn.MaxUnpool3d nn.AdaptiveAvgPool1d nn.AdaptiveAvgPool2d nn.AdaptiveAvgPool3d nn.AdaptiveMaxPool1d nn.AdaptiveMaxPool2d nn.AdaptiveMaxPool3d nn.AvgPool1d nn.AvgPool2d nn.AvgPool3d Padding Layers ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.ConstantPad1d nn.ConstantPad2d nn.ConstantPad3d nn.ReflectionPad1d nn.ReflectionPad2d nn.ReplicationPad1d nn.ReplicationPad2d nn.ZeroPad2d Non-linear Activations (weighted sum, nonlinearity) ---------------------------------------------------- .. 
autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.ELU nn.Hardshrink nn.Hardsigmoid nn.Hardswish nn.Hardtanh nn.LeakyReLU nn.LogSigmoid nn.PReLU nn.ReLU nn.ReLU6 nn.SELU nn.CELU nn.GELU nn.QuickGELU nn.SquareReLU nn.SiLU nn.Sigmoid nn.Mish nn.Softplus nn.Softshrink nn.Softsign nn.Tanh nn.Threshold nn.GLU Non-linear Activations (other) ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.Softmax nn.LogSoftmax Normalization Layers ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.BatchNorm1d nn.BatchNorm2d nn.BatchNorm3d nn.SyncBatchNorm nn.FusedBatchNorm1d nn.FusedBatchNorm2d nn.FusedBatchNorm3d nn.GroupNorm nn.InstanceNorm1d nn.InstanceNorm2d nn.InstanceNorm3d nn.LayerNorm nn.RMSLayerNorm nn.RMSNorm Recurrent Layers ---------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.RNN nn.LSTM nn.GRU nn.RNNCell nn.LSTMCell nn.GRUCell Linear Layers ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.Identity nn.Linear Dropout Layers ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.Dropout nn.Dropout1d nn.Dropout2d nn.Dropout3d Sparse Layers ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.Embedding Distance Functions ------------------ .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.CosineSimilarity nn.PairwiseDistance Loss Functions ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.BCELoss nn.BCEWithLogitsLoss nn.CTCLoss nn.CombinedMarginLoss nn.CrossEntropyLoss nn.KLDivLoss nn.L1Loss nn.MSELoss nn.MarginRankingLoss nn.NLLLoss nn.SmoothL1Loss nn.TripletMarginLoss Vision Layers ---------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.PixelShuffle nn.Upsample nn.UpsamplingBilinear2d nn.UpsamplingNearest2d DataParallel Layers (multi-GPU, distributed) -------------------------------------------- .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.parallel.DistributedDataParallel Data loading and preprocessing Layers ---------------------------------------- .. autosummary:: :toctree: generated :nosignatures: nn.COCOReader nn.CoinFlip nn.CropMirrorNormalize nn.OFRecordBytesDecoder nn.OFRecordImageDecoder nn.OFRecordImageDecoderRandomCrop nn.OFRecordRawDecoder nn.OFRecordReader Quantization Aware Training -------------------------------------------- .. autosummary:: :toctree: generated :nosignatures: nn.MinMaxObserver nn.MovingAverageMinMaxObserver nn.FakeQuantization nn.QatConv1d nn.QatConv2d nn.QatConv3d Utilities --------- From the ``oneflow.nn.utils`` module .. currentmodule:: oneflow.nn.utils .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst clip_grad_norm_ clip_grad_value_ weight_norm remove_weight_norm Utility functions in other modules .. currentmodule:: oneflow .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.utils.rnn.PackedSequence nn.utils.rnn.pack_padded_sequence nn.utils.rnn.pad_packed_sequence nn.utils.rnn.pad_sequence nn.utils.rnn.pack_sequence .. 
autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst nn.Flatten Quantized Functions -------------------- Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than floating point precision. .. autosummary:: :toctree: generated :nosignatures: :template: nn.FakeQuantization nn.MinMaxObserver nn.MovingAverageMinMaxObserver nn.Quantization ================================================ FILE: docs/source/one_embedding.rst ================================================ oneflow.one_embedding =================================== Embedding is an important component of recommender systems, and it has also spread to many fields outside recommender systems. Each framework provides basic operators for Embedding, for example, ``flow.nn.Embedding`` in OneFlow: :: import numpy as np import oneflow as flow indices = flow.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=flow.int) embedding = flow.nn.Embedding(10, 3) y = embedding(indices) OneEmbedding is the large-scale Embedding solution that OneFlow provides to solve the problems of large-scale deep recommender systems. OneEmbedding has the following advantages compared to ordinary operators: - With flexible hierarchical storage, OneEmbedding can place the Embedding table on GPU memory, CPU memory or SSD, and allow high-speed devices to be used as caches for low-speed devices to achieve both speed and capacity. - OneEmbedding supports dynamic expansion. .. note :: Please refer to `Large-Scale Embedding Solution: OneEmbedding `__ for a brief introduction to all features related to OneEmbedding. Configure Embedding Table ---------------------------------- OneEmbedding supports the simultaneous creation of multiple Embedding tables. The following code configures three Embedding tables. .. code-block:: import oneflow as flow import oneflow.nn as nn import numpy as np tables = [ flow.one_embedding.make_table_options( flow.one_embedding.make_uniform_initializer(low=-0.1, high=0.1) ), flow.one_embedding.make_table_options( flow.one_embedding.make_uniform_initializer(low=-0.05, high=0.05) ), flow.one_embedding.make_table_options( flow.one_embedding.make_uniform_initializer(low=-0.15, high=0.15) ), ] When configuring the Embedding table, you need to specify the initialization method. The above Embedding tables are initialized with the ``uniform`` method. The result of configuring the Embedding tables is stored in the ``tables`` variable. .. autofunction:: oneflow.one_embedding.make_table_options .. autofunction:: oneflow.one_embedding.make_table Initialization Method ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. currentmodule:: oneflow.one_embedding .. autosummary:: :toctree: generated :nosignatures: make_uniform_initializer make_normal_initializer Configure the Storage Attribute of the Embedding Table -------------------------------------------------------------------- Then run the following code to configure the storage attribute of the Embedding table: .. code-block:: store_options = flow.one_embedding.make_cached_ssd_store_options( cache_budget_mb=8142, persistent_path="/your_path_to_ssd", capacity=40000000, size_factor=1, physical_block_size=4096 ) Storage Method ^^^^^^^^^^^^^^^^^^^^ .. currentmodule:: oneflow.one_embedding .. autosummary:: :toctree: generated :nosignatures: make_device_mem_store_options make_cached_ssd_store_options make_cached_host_mem_store_options ..
note :: Please refer to `Large-Scale Embedding Solution: OneEmbedding `__ for a brief introduction to learn about How to Choose the Proper Storage Configuration Instantiate Embedding -------------------------------------------------------------------- After the above configuration is completed, you can use MultiTableEmbedding to get the instantiated Embedding layer. .. code-block:: embedding_size = 128 embedding = flow.one_embedding.MultiTableEmbedding( name="my_embedding", embedding_dim=embedding_size, dtype=flow.float, key_type=flow.int64, tables=tables, store_options=store_options, ) embedding.to("cuda") .. note :: Please refer to `Large-Scale Embedding Solution: OneEmbedding `__ for a brief introduction to learn about Feature ID and Multi-Table Query. MultiTableEmbedding ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: oneflow.one_embedding.MultiTableEmbedding .. currentmodule:: oneflow.one_embedding.MultiTableEmbedding .. autosummary:: :toctree: generated :nosignatures: forward save_snapshot load_snapshot MultiTableMultiColumnEmbedding ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: oneflow.one_embedding.MultiTableMultiColumnEmbedding .. currentmodule:: oneflow.one_embedding.MultiTableMultiColumnEmbedding .. autosummary:: :toctree: generated :nosignatures: forward save_snapshot load_snapshot Construct Graph for Training -------------------------------------------------------------------- OneEmbedding is only supported in Graph mode. .. code-block:: num_tables = 3 mlp = flow.nn.FusedMLP( in_features=embedding_size * num_tables, hidden_features=[512, 256, 128], out_features=1, skip_final_activation=True, ) mlp.to("cuda") class TrainGraph(flow.nn.Graph): def __init__(self,): super().__init__() self.embedding_lookup = embedding self.mlp = mlp self.add_optimizer( flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0) ) self.add_optimizer( flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0) ) def build(self, ids): embedding = self.embedding_lookup(ids) loss = self.mlp(flow.reshape(embedding, (-1, num_tables * embedding_size))) loss = loss.sum() loss.backward() return loss .. note :: Please refer to `Distributed Training: OneEmbedding `__ for a brief introduction to learn about Graph For Training Persistent Read & Write ----------------------------------------------- .. currentmodule:: oneflow.one_embedding .. autosummary:: :toctree: generated :nosignatures: make_persistent_table_reader make_persistent_table_writer .. automodule:: oneflow.one_embedding :members: Ftrl ================================================ FILE: docs/source/oneflow.rst ================================================ oneflow =================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/torch.html The oneflow package contains data structures for multi-dimensional tensors and defines mathematical operations over these tensors. Additionally, it provides many utilities for efficient serializing of Tensors and arbitrary types, and other useful utilities. It has a CUDA counterpart, that enables you to run your tensor computations on an NVIDIA GPU with compute capability >= 3.0 .. currentmodule:: oneflow Tensor ------------------------------------------- .. autosummary:: :toctree: generated :nosignatures: BoolTensor ByteTensor CharTensor DoubleTensor FloatTensor HalfTensor IntTensor LongTensor .. 
autosummary:: :toctree: generated :nosignatures: is_tensor is_floating_point is_nonzero numel set_printoptions get_default_dtype set_default_dtype set_default_tensor_type .. _tensor-creation-ops: Creation Ops ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. note:: Random sampling creation ops are listed under :ref:`random-sampling` and include: :func:`oneflow.rand` :func:`oneflow.randn` :func:`oneflow.randint` :func:`oneflow.randperm` .. autosummary:: :toctree: generated :nosignatures: tensor as_tensor as_strided from_numpy zeros zeros_like ones ones_like randn_like randint_like masked_fill new_ones arange linspace eye empty empty_like full full_like tensor_scatter_nd_update logspace .. _indexing-slicing-joining: Indexing, Slicing, Joining, Mutating Ops ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated :nosignatures: argwhere atleast_1d atleast_2d atleast_3d cat column_stack concat chunk dstack expand gather gather_nd batch_gather hsplit hstack vsplit vstack index_select index_add masked_select movedim narrow nonzero permute repeat reshape row_stack select scatter scatter_add scatter_nd slice slice_update split squeeze stack swapaxes swapdims t tile transpose unbind unsqueeze where tensor_split .. _random-sampling: Random sampling ------------------------------------------- .. autosummary:: :toctree: generated :nosignatures: seed manual_seed initial_seed get_rng_state set_rng_state bernoulli normal rand randint randn randperm multinomial In-place random sampling ~~~~~~~~~~~~~~~~~~~~~~~~ There are a few more in-place random sampling functions defined on Tensors as well. Click through to refer to their documentation: - :func:`oneflow.Tensor.normal_` - in-place version of :func:`oneflow.normal` - :func:`oneflow.Tensor.uniform_` - numbers sampled from the continuous uniform distribution Serialization ------------------------------------------- .. autosummary:: :toctree: generated :nosignatures: save load Parallelism ------------------------------------------- .. autosummary:: :toctree: generated :nosignatures: set_num_threads Locally disabling gradient computation ------------------------------------------- The context managers :func:`oneflow.no_grad`, :func:`oneflow.enable_grad`, and :func:`oneflow.set_grad_enabled` are helpful for locally disabling and enabling gradient computation. These context managers are thread local, so they won't work if you send work to another thread using the ``threading`` module, etc. Examples:: >>> import oneflow >>> x = oneflow.zeros(1, requires_grad=True) >>> with oneflow.no_grad(): ... y = x * 2 >>> y.requires_grad False >>> with oneflow.set_grad_enabled(False): ... y = x * 2 >>> y.requires_grad False >>> with oneflow.set_grad_enabled(True): ... y = x * 2 >>> y.requires_grad True .. autosummary:: :toctree: generated :nosignatures: no_grad set_grad_enabled enable_grad is_grad_enabled inference_mode Math operations ------------------------------------------- Pointwise Ops ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
autosummary:: :toctree: generated :nosignatures: abs acos acosh arccos arccosh add addcdiv addcmul asin asinh arcsin arcsinh atan atanh arctan arctanh atan2 ceil ceil_ clamp clamp_min clamp_max clip cos cosh div erf erfc erfinv exp expm1 floor floor_ frac frac_ fmod gelu quick_gelu square_relu log log1p log2 log10 logical_and logical_not logical_or logical_xor bitwise_and bitwise_or bitwise_xor bitwise_not mish mul neg negative pow reciprocal round round_ rsqrt selu softmax softplus softsign silu sigmoid sign sin sinh sin_ sqrt square sub tan tanh trunc floor_divide lerp lerp_ quantile Reduction Ops ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated :nosignatures: argmax argmin amax amin any max min mean median mode prod nansum std sum logsumexp var norm all Comparison Ops ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated :nosignatures: argsort eq equal gt isinf isnan le lt ne sort topk ge greater greater_equal maximum minimum not_equal isclose allclose Spectral Ops ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated :nosignatures: hann_window Other Ops ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated :nosignatures: adaptive_avg_pool1d adaptive_avg_pool2d adaptive_avg_pool3d broadcast_like cast cumprod cumsum diag diagonal einsum flatten flip in_top_k meshgrid nms roc_auc_score roll searchsorted tensordot tril repeat_interleave triu cross bincount broadcast_shapes broadcast_tensors broadcast_to unique BLAS and LAPACK Operations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated :nosignatures: addmm bmm baddbmm dot matmul mm mv ================================================ FILE: docs/source/optim.rst ================================================ oneflow.optim =================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/optim.html oneflow.optim is a package implementing various optimization algorithms. Most commonly used methods are already supported, and the interface is general enough, so that more sophisticated ones can be also easily integrated in the future. How to use an optimizer ----------------------- To use :mod:`oneflow.optim` you have to construct an optimizer object, that will hold the current state and will update the parameters based on the computed gradients. Constructing it ^^^^^^^^^^^^^^^ To construct an :class:`Optimizer` you have to give it an iterable containing the parameters (all should be :class:`~oneflow.autograd.Variable` s) to optimize. Then, you can specify optimizer-specific options such as the learning rate, weight decay, etc. .. note:: If you need to move a model to GPU via ``.cuda()``, please do so before constructing optimizers for it. Parameters of a model after ``.cuda()`` will be different objects with those before the call. In general, you should make sure that optimized parameters live in consistent locations when optimizers are constructed and used. Example:: import oneflow import oneflow.nn as nn import oneflow.optim as optim model = nn.Linear(16, 3) optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) Per-parameter options ^^^^^^^^^^^^^^^^^^^^^ :class:`Optimizer` also support specifying per-parameter options. To do this, instead of passing an iterable of :class:`~oneflow.autograd.Variable`, pass in an iterable of :class:`dict`. 
Each of them will define a separate parameter group, and should contain a ``params`` key, containing a list of parameters belonging to it. Other keys should match the keyword arguments accepted by the optimizers, and will be used as optimization options for this group. .. note:: You can still pass options as keyword arguments. They will be used as defaults, in the groups that didn't override them. This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups. For example, this is very useful when one wants to specify per-layer learning rates:: import oneflow.nn as nn import oneflow.optim as optim class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.base = nn.Linear(64, 32) self.classifier = nn.Linear(32, 10) def forward(self, x): out = self.base(x) out = self.classifier(out) return out model = Model() optim.SGD( [ {"params": model.base.parameters()}, {"params": model.classifier.parameters(), "lr": 1e-3}, ], lr=1e-2, momentum=0.9, ) This means that ``model.base``'s parameters will use the default learning rate of ``1e-2``, ``model.classifier``'s parameters will use a learning rate of ``1e-3``, and a momentum of ``0.9`` will be used for all parameters. Taking an optimization step ^^^^^^^^^^^^^^^^^^^^^^^^^^^ All optimizers implement a :func:`~Optimizer.step` method, that updates the parameters. It can be used in two ways: ``optimizer.step()`` ~~~~~~~~~~~~~~~~~~~~ This is a simplified version supported by most optimizers. The function can be called once the gradients are computed using e.g. :func:`~oneflow.autograd.Variable.backward`. Example:: import oneflow import oneflow.nn as nn import oneflow.nn.functional as F import oneflow.optim as optim from oneflow.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, num): self.inputs = oneflow.randn(num, 1) self.targets = oneflow.sin(self.inputs) def __len__(self): return self.inputs.shape[0] def __getitem__(self, index): return self.inputs[index], self.targets[index] class Model(nn.Module): def __init__(self, input_size): super(Model, self).__init__() self.linear1 = nn.Linear(input_size, 64) self.linear2 = nn.Linear(64, input_size) def forward(self, x): out = self.linear1(x) return self.linear2(F.relu(out)) dataset = CustomDataset(10000) dataloader = DataLoader(dataset, batch_size=10) model = Model(1) loss_fn = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr=1e-3) for epoch in range(100): for input, target in dataloader: optimizer.zero_grad() output = model(input) loss = loss_fn(output, target) loss.backward() optimizer.step() .. _optimizer-algorithms: .. currentmodule:: oneflow.optim Base class ---------- .. autoclass:: Optimizer .. autosummary:: :toctree: generated :nosignatures: Optimizer.add_param_group Optimizer.load_state_dict Optimizer.state_dict Optimizer.step Optimizer.zero_grad Algorithms ---------- .. autosummary:: :toctree: generated :nosignatures: Adagrad Adam AdamW LAMB RMSprop SGD LBFGS Adjust Learning Rate -------------------- :mod:`oneflow.optim.lr_scheduler` provides several methods to adjust the learning rate based on the number of epochs. :class:`oneflow.optim.lr_scheduler.ReduceLROnPlateau` allows dynamic learning rate reducing based on some validation measurements. 
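As a minimal sketch of the plateau-based scheduler (reusing ``model`` and ``optim`` from the examples above; the ``mode``/``factor``/``patience`` arguments and the ``train``/``validate`` helpers are assumptions made here for illustration):

Example::

    optimizer = optim.SGD(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

    for epoch in range(20):
        train(...)                # hypothetical training step
        val_loss = validate(...)  # hypothetical validation step returning the monitored metric
        scheduler.step(val_loss)  # ReduceLROnPlateau steps on the validation metric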
Learning rate scheduling should be applied after optimizer's update; e.g., you should write your code this way: Example:: import oneflow import oneflow.nn as nn import oneflow.nn.functional as F import oneflow.optim as optim from oneflow.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, num): self.inputs = oneflow.randn(num, 1) self.targets = oneflow.sin(self.inputs) def __len__(self): return self.inputs.shape[0] def __getitem__(self, index): return self.inputs[index], self.targets[index] class Model(nn.Module): def __init__(self, input_size): super(Model, self).__init__() self.linear1 = nn.Linear(input_size, 64) self.linear2 = nn.Linear(64, input_size) def forward(self, x): out = self.linear1(x) return self.linear2(F.relu(out)) dataset = CustomDataset(10000) dataloader = DataLoader(dataset, batch_size=10) model = Model(1) loss_fn = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr=1e-3) scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) for epoch in range(20): for input, target in dataloader: optimizer.zero_grad() output = model(input) loss = loss_fn(output, target) loss.backward() optimizer.step() scheduler.step() Most learning rate schedulers can be chained (also referred to as chaining schedulers). Example:: import oneflow import oneflow.nn as nn import oneflow.nn.functional as F import oneflow.optim as optim from oneflow.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, num): self.inputs = oneflow.randn(num, 1) self.targets = oneflow.sin(self.inputs) def __len__(self): return self.inputs.shape[0] def __getitem__(self, index): return self.inputs[index], self.targets[index] class Model(nn.Module): def __init__(self, input_size): super(Model, self).__init__() self.linear1 = nn.Linear(input_size, 64) self.linear2 = nn.Linear(64, input_size) def forward(self, x): out = self.linear1(x) return self.linear2(F.relu(out)) dataset = CustomDataset(10000) dataloader = DataLoader(dataset, batch_size=10) model = Model(1) loss_fn = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr=1e-3) scheduler1 = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) scheduler2 = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=0.1) for epoch in range(20): for input, target in dataloader: optimizer.zero_grad() output = model(input) loss = loss_fn(output, target) loss.backward() optimizer.step() scheduler1.step() scheduler2.step() In many places in the documentation, we will use the following template to refer to schedulers algorithms. >>> scheduler = ... >>> for epoch in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() .. warning:: If you use the learning rate scheduler (calling ``scheduler.step()``) before the optimizer's update (calling ``optimizer.step()``), this will skip the first value of the learning rate schedule. Please check if you are calling ``scheduler.step()`` at the wrong time. .. 
autosummary:: :toctree: generated :nosignatures: lr_scheduler.CosineAnnealingLR lr_scheduler.CosineDecayLR lr_scheduler.ExponentialLR lr_scheduler.LambdaLR lr_scheduler.MultiStepLR lr_scheduler.PolynomialLR lr_scheduler.ReduceLROnPlateau lr_scheduler.StepLR lr_scheduler.ConstantLR lr_scheduler.LinearLR lr_scheduler.ChainedScheduler lr_scheduler.SequentialLR lr_scheduler.CosineAnnealingWarmRestarts ================================================ FILE: docs/source/special.rst ================================================ oneflow.special ====================================== The oneflow.special module, modeled after SciPy's special module. ------------------------------------------------------------------- .. currentmodule:: oneflow.special .. autosummary:: :toctree: generated :nosignatures: digamma erf erfc erfinv exp2 expm1 log1p log_softmax logsumexp round softmax zeta ================================================ FILE: docs/source/tensor.rst ================================================ oneflow.Tensor =================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/tensors.html A :class:`oneflow.Tensor` is a multi-dimensional matrix containing elements of a single data type. .. currentmodule:: oneflow Data types ---------- OneFlow defines 8 Tensor types with CPU and GPU variants which are as follows: ======================================= =============================================== =============================== ================================== Data type dtype CPU tensor GPU tensor ======================================= =============================================== =============================== ================================== Boolean ``oneflow.bool`` :class:`oneflow.BoolTensor` :class:`oneflow.cuda.BoolTensor` 8-bit integer (unsigned) ``oneflow.uint8`` :class:`oneflow.ByteTensor` :class:`oneflow.cuda.ByteTensor` 8-bit integer (signed) ``oneflow.int8`` :class:`oneflow.CharTensor` :class:`oneflow.cuda.CharTensor` 64-bit floating point ``oneflow.float64`` or ``oneflow.double`` :class:`oneflow.DoubleTensor` :class:`oneflow.cuda.DoubleTensor` 32-bit floating point ``oneflow.float32`` or ``oneflow.float`` :class:`oneflow.FloatTensor` :class:`oneflow.cuda.FloatTensor` 16-bit floating point ``oneflow.float16`` or ``oneflow.half`` :class:`oneflow.HalfTensor` :class:`oneflow.cuda.HalfTensor` 32-bit integer (signed) ``oneflow.int32`` or ``oneflow.int`` :class:`oneflow.IntTensor` :class:`oneflow.cuda.IntTensor` 64-bit integer (signed) ``oneflow.int64`` or ``oneflow.long`` :class:`oneflow.LongTensor` :class:`oneflow.cuda.LongTensor` ======================================= =============================================== =============================== ================================== Initializing and basic operations --------------------------------- A tensor can be constructed from a Python :class:`list` or sequence using the :func:`oneflow.tensor` constructor: :: >>> import oneflow >>> import numpy as np >>> oneflow.tensor([[1., -1.], [1., -1.]]) tensor([[ 1., -1.], [ 1., -1.]], dtype=oneflow.float32) >>> oneflow.tensor(np.array([[1, 2, 3], [4, 5, 6]])) tensor([[ 1, 2, 3], [ 4, 5, 6]], dtype=oneflow.int64) .. warning:: :func:`oneflow.tensor` always copies :attr:`data`. If you have a Tensor :attr:`data` and just want to change its ``requires_grad`` flag, use :meth:`~oneflow.Tensor.requires_grad_` or :meth:`~oneflow.Tensor.detach` to avoid a copy. If you have a numpy array and want to avoid a copy, use :func:`oneflow.as_tensor`. .. 
A tensor of a specific data type can be constructed by passing a :class:`oneflow.dtype` and/or a :class:`oneflow.device` to a constructor or tensor creation op: :: >>> import oneflow >>> oneflow.zeros([2, 4], dtype=oneflow.int32) tensor([[ 0, 0, 0, 0], [ 0, 0, 0, 0]], dtype=oneflow.int32) >>> cuda0 = oneflow.device('cuda:0') >>> oneflow.ones([2, 4], dtype=oneflow.float64, device=cuda0) tensor([[ 1., 1., 1., 1.], [ 1., 1., 1., 1.]], device='cuda:0', dtype=oneflow.float64) For more information about building tensors, see :ref:`tensor-creation-ops`. The contents of a tensor can be accessed and modified using Python's indexing and slicing notation: :: >>> import oneflow >>> x = oneflow.tensor([[1, 2, 3], [4, 5, 6]]) >>> print(x[1][2]) tensor(6, dtype=oneflow.int64) >>> x[0][1] = 8 >>> print(x) tensor([[1, 8, 3], [4, 5, 6]], dtype=oneflow.int64) Use :meth:`oneflow.Tensor.item` to get a Python number from a tensor containing a single value: :: >>> import oneflow >>> x = oneflow.tensor([[1]]) >>> x tensor([[1]], dtype=oneflow.int64) >>> x.item() 1 >>> x = oneflow.tensor(2.5) >>> x tensor(2.5000, dtype=oneflow.float32) >>> x.item() 2.5 For more information about indexing, see :ref:`indexing-slicing-joining`. A tensor can be created with :attr:`requires_grad=True` so that :mod:`oneflow.autograd` records operations on it for automatic differentiation. :: >>> import oneflow >>> x = oneflow.tensor([[1., -1.], [1., 1.]], requires_grad=True) >>> out = x.pow(2).sum() >>> out.backward() >>> x.grad tensor([[ 2., -2.], [ 2., 2.]], dtype=oneflow.float32) .. note:: For more information on the :class:`oneflow.dtype`, :class:`oneflow.device`, and :class:`oneflow.layout` attributes of a :class:`oneflow.Tensor`, see :ref:`tensor-attributes-doc`. .. note:: Methods which mutate a tensor are marked with an underscore suffix. For example, :func:`oneflow.FloatTensor.add_` computes the addition in-place and returns the modified tensor, while :func:`oneflow.FloatTensor.add` computes the result in a new tensor. .. note:: To change an existing tensor's :class:`oneflow.device` and/or :class:`oneflow.dtype`, consider using the :meth:`~oneflow.Tensor.to` method of the Tensor object. .. warning:: The current implementation of :class:`oneflow.Tensor` introduces memory overhead, thus it might lead to unexpectedly high memory usage in applications with many tiny tensors. If this is your case, consider using one large structure. Tensor class reference ---------------------- .. class:: Tensor() There are a few main ways to create a tensor, depending on your use case. - To create a tensor with pre-existing data, use :func:`oneflow.tensor`. - To create a tensor with a specific size, use ``oneflow.*`` tensor creation ops (see :ref:`tensor-creation-ops`). - To create a tensor with the same size (and similar types) as another tensor, use ``oneflow.*_like`` tensor creation ops (see :ref:`tensor-creation-ops`). .. currentmodule:: oneflow ..
autosummary:: :toctree: generated :nosignatures: Tensor.new_empty Tensor.new_ones Tensor.new_zeros Tensor.new_full Tensor.new_tensor Tensor.is_cuda Tensor.is_global Tensor.device Tensor.grad Tensor.ndim Tensor.abs Tensor.acos Tensor.acosh Tensor.add Tensor.add_ Tensor.addcdiv Tensor.addcdiv_ Tensor.addcmul Tensor.addcmul_ Tensor.addmm Tensor.all Tensor.amin Tensor.amax Tensor.any Tensor.arccos Tensor.arccosh Tensor.arcsin Tensor.arcsinh Tensor.arctan Tensor.arctanh Tensor.argmax Tensor.argmin Tensor.argsort Tensor.argwhere Tensor.asin Tensor.asinh Tensor.atan Tensor.atan2 Tensor.atanh Tensor.backward Tensor.bmm Tensor.bool Tensor.byte Tensor.cast Tensor.ceil Tensor.ceil_ Tensor.chunk Tensor.clamp Tensor.clamp_ Tensor.clip Tensor.clip_ Tensor.clone Tensor.contiguous Tensor.copy_ Tensor.cos Tensor.cosh Tensor.cpu Tensor.cuda Tensor.cumprod Tensor.cumsum Tensor.data Tensor.dot Tensor.detach Tensor.placement Tensor.sbp Tensor.diag Tensor.diagonal Tensor.dim Tensor.div Tensor.div_ Tensor.double Tensor.dtype Tensor.digamma Tensor.element_size Tensor.eq Tensor.equal Tensor.erf Tensor.erfc Tensor.erfinv Tensor.erfinv_ Tensor.exp Tensor.exp2 Tensor.expand Tensor.expand_as Tensor.expm1 Tensor.fill_ Tensor.flatten Tensor.flip Tensor.float Tensor.floor Tensor.floor_ Tensor.floor_divide Tensor.fmod Tensor.gather Tensor.ge Tensor.get_device Tensor.grad_fn Tensor.gt Tensor.gt_ Tensor.half Tensor.in_top_k Tensor.index_select Tensor.index_add Tensor.index_add_ Tensor.int Tensor.is_contiguous Tensor.is_floating_point Tensor.is_lazy Tensor.is_leaf Tensor.isinf Tensor.isnan Tensor.item Tensor.le Tensor.lerp Tensor.lerp_ Tensor.log Tensor.log1p Tensor.log2 Tensor.log10 Tensor.logical_and Tensor.logical_or Tensor.logical_not Tensor.logical_xor Tensor.long Tensor.lt Tensor.masked_fill Tensor.masked_fill_ Tensor.masked_select Tensor.matmul Tensor.mm Tensor.mv Tensor.max Tensor.maximum Tensor.median Tensor.mean Tensor.min Tensor.minimum Tensor.mish Tensor.mode Tensor.mul Tensor.mul_ Tensor.frac Tensor.frac_ Tensor.nansum Tensor.narrow Tensor.ndimension Tensor.ne Tensor.neg Tensor.negative Tensor.nelement Tensor.nonzero Tensor.norm Tensor.normal_ Tensor.numel Tensor.numpy Tensor.offload Tensor.load Tensor.is_offloaded Tensor.permute Tensor.pow Tensor.prod Tensor.quantile Tensor.reciprocal Tensor.register_hook Tensor.relu Tensor.repeat Tensor.repeat_interleave Tensor.requires_grad Tensor.requires_grad_ Tensor.reshape Tensor.reshape_as Tensor.retain_grad Tensor.roll Tensor.round Tensor.round_ Tensor.rsqrt Tensor.selu Tensor.shape Tensor.sigmoid Tensor.sign Tensor.silu Tensor.sin Tensor.sin_ Tensor.sinh Tensor.size Tensor.softmax Tensor.softplus Tensor.softsign Tensor.sort Tensor.split Tensor.sqrt Tensor.square Tensor.squeeze Tensor.squeeze_ Tensor.std Tensor.storage_offset Tensor.stride Tensor.logsumexp Tensor.sum Tensor.swapaxes Tensor.swapdims Tensor.sub Tensor.sub_ Tensor.tan Tensor.tanh Tensor.tile Tensor.to Tensor.local_to_global Tensor.global_to_global Tensor.to_global Tensor.to_local Tensor.to_consistent Tensor.tolist Tensor.topk Tensor.transpose Tensor.tril Tensor.triu Tensor.trunc Tensor.type_as Tensor.type Tensor.t Tensor.T Tensor.unbind Tensor.unfold Tensor.uniform_ Tensor.unsqueeze Tensor.unsqueeze_ Tensor.as_strided Tensor.as_strided_ Tensor.var Tensor.view Tensor.view_as Tensor.where Tensor.zero_ Tensor.nms Tensor.pin_memory Tensor.is_pinned Tensor.inverse Tensor.cross Tensor.scatter Tensor.scatter_ Tensor.scatter_add Tensor.scatter_add_ Tensor.bernoulli Tensor.bernoulli_ Tensor.bincount Tensor.isclose 
Tensor.allclose Tensor.broadcast_to Tensor.unique Tensor.bitwise_and Tensor.bitwise_or Tensor.bitwise_xor Tensor.baddbmm ================================================ FILE: docs/source/tensor_attributes.rst ================================================ .. currentmodule:: oneflow .. _tensor-attributes-doc: Tensor Attributes ============================================================= .. The documentation is referenced from: https://pytorch.org/docs/1.10/tensor_attributes.html. Each local ``oneflow.Tensor`` has a :class:`oneflow.dtype`, :class:`oneflow.device`, and global ``oneflow.Tensor`` has a :class:`oneflow.dtype`, :class:`oneflow.placement`, :class:`oneflow.sbp`. .. contents:: oneflow :depth: 2 :local: :class: this-will-duplicate-information-and-it-is-still-useful-here :backlinks: top .. _dtype-doc: oneflow.dtype ----------------------- .. class:: dtype A :class:`oneflow.dtype` is an object that represents the data type of a :class:`oneflow.Tensor`. Oneflow has eight different data types: ======================================= =============================================== =============================== ================================== Data type dtype CPU tensor GPU tensor ======================================= =============================================== =============================== ================================== Boolean ``oneflow.bool`` :class:`oneflow.BoolTensor` :class:`oneflow.cuda.BoolTensor` 8-bit integer (unsigned) ``oneflow.uint8`` :class:`oneflow.ByteTensor` :class:`oneflow.cuda.ByteTensor` 8-bit integer (signed) ``oneflow.int8`` :class:`oneflow.CharTensor` :class:`oneflow.cuda.CharTensor` 64-bit floating point ``oneflow.float64`` or ``oneflow.double`` :class:`oneflow.DoubleTensor` :class:`oneflow.cuda.DoubleTensor` 32-bit floating point ``oneflow.float32`` or ``oneflow.float`` :class:`oneflow.FloatTensor` :class:`oneflow.cuda.FloatTensor` 16-bit floating point ``oneflow.float16`` or ``oneflow.half`` :class:`oneflow.HalfTensor` :class:`oneflow.cuda.HalfTensor` 32-bit integer (signed) ``oneflow.int32`` or ``oneflow.int`` :class:`oneflow.IntTensor` :class:`oneflow.cuda.IntTensor` 64-bit integer (signed) ``oneflow.int64`` or ``oneflow.long`` :class:`oneflow.LongTensor` :class:`oneflow.cuda.LongTensor` ======================================= =============================================== =============================== ================================== To find out if a :class:`oneflow.dtype` is a floating point data type, the property :attr:`is_floating_point` can be used, which returns ``True`` if the data type is a floating point data type. .. _type-promotion-doc: When the dtypes of inputs to an arithmetic operation (`add`, `sub`, `div`, `mul`) differ, we promote by finding the minimum dtype that satisfies the following rules: * If the type of a scalar operand is of a higher category than tensor operands (where complex > floating > integral > boolean), we promote to a type with sufficient size to hold all scalar operands of that category. * If a zero-dimension tensor operand has a higher category than dimensioned operands, we promote to a type with sufficient size and category to hold all zero-dim tensor operands of that category. * If there are no higher-category zero-dim operands, we promote to a type with sufficient size and category to hold all dimensioned operands. A floating point scalar operand has dtype `oneflow.get_default_dtype()` and an integral non-boolean scalar operand has dtype `oneflow.int64`. 
Unlike numpy, we do not inspect values when determining the minimum `dtypes` of an operand. Quantized and complex types are not yet supported. Promotion Examples:: >>> float_tensor = oneflow.ones(1, dtype=oneflow.float) >>> double_tensor = oneflow.ones(1, dtype=oneflow.double) >>> int_tensor = oneflow.ones(1, dtype=oneflow.int) >>> long_tensor = oneflow.ones(1, dtype=oneflow.long) >>> uint_tensor = oneflow.ones(1, dtype=oneflow.uint8) >>> double_tensor = oneflow.ones(1, dtype=oneflow.double) >>> bool_tensor = oneflow.ones(1, dtype=oneflow.bool) # zero-dim tensors >>> long_zerodim = oneflow.tensor(1, dtype=oneflow.long) >>> int_zerodim = oneflow.tensor(1, dtype=oneflow.int) >>> a,b=oneflow.tensor(5),oneflow.tensor(5) >>> oneflow.add(a, b).dtype oneflow.int64 # 5 is an int64, but does not have higher category than int_tensor so is not considered. >>> (int_tensor + 5).dtype oneflow.int32 >>> (int_tensor + long_zerodim).dtype oneflow.int64 >>> (long_tensor + int_tensor).dtype oneflow.int64 >>> (bool_tensor + long_tensor).dtype oneflow.int64 >>> (bool_tensor + uint_tensor).dtype oneflow.uint8 >>> (float_tensor + double_tensor).dtype oneflow.float64 >>> (bool_tensor + int_tensor).dtype oneflow.int32 # Since long is a different kind than float, result dtype only needs to be large enough # to hold the float. >>> oneflow.add(long_tensor, float_tensor).dtype oneflow.float32 When the output tensor of an arithmetic operation is specified, we allow casting to its `dtype` except that: * An integral output tensor cannot accept a floating point tensor. * A boolean output tensor cannot accept a non-boolean tensor. * A non-complex output tensor cannot accept a complex tensor Casting Examples:: # allowed: >>> float_tensor *= float_tensor >>> float_tensor *= int_tensor >>> float_tensor *= uint_tensor >>> float_tensor *= bool_tensor >>> int_tensor *= uint_tensor # disallowed (RuntimeError: result type can't be cast to the desired output type): >>> float_tensor *= double_tensor >>> int_tensor *= float_tensor >>> int_tensor *= long_tensor >>> uint_tensor *= int_tensor >>> bool_tensor *= int_tensor >>> bool_tensor *= uint_tensor .. _device-doc: oneflow.device ------------------------ .. class:: device A :class:`oneflow.device` is an object representing the device on which a :class:`oneflow.Tensor` is or will be allocated. The :class:`oneflow.device` contains a device type (``'cpu'`` or ``'cuda'``) and optional device ordinal for the device type. If the device ordinal is not present, this object will always represent the current device for the device type, even after :func:`oneflow.cuda.set_device()` is called; e.g., a :class:`oneflow.Tensor` constructed with device ``'cuda'`` is equivalent to ``'cuda:X'`` where X is the result of :func:`oneflow.cuda.current_device()`. A :class:`oneflow.Tensor`'s device can be accessed via the :attr:`Tensor.device` property. A :class:`oneflow.device` can be constructed via a string or via a string and device ordinal Via a string: :: >>> oneflow.device('cuda:0') device(type='cuda', index=0) >>> oneflow.device('cpu') device(type='cpu', index=0) >>> oneflow.device('cuda') # current cuda device device(type='cuda', index=0) Via a string and device ordinal: :: >>> oneflow.device('cuda', 0) device(type='cuda', index=0) >>> oneflow.device('cpu', 0) device(type='cpu', index=0) .. note:: The :class:`oneflow.device` argument in functions can generally be substituted with a string. This allows for fast prototyping of code. 
>>> # Example of a function that takes in a oneflow.device >>> cuda1 = oneflow.device('cuda:1') >>> oneflow.randn((2,3), device=cuda1) >>> # You can substitute the oneflow.device with a string >>> oneflow.randn((2,3), device='cuda:1') .. note:: For legacy reasons, a device can be constructed via a single device ordinal, which is treated as a cuda device. This matches :meth:`Tensor.get_device`, which returns an ordinal for cuda tensors and is not supported for cpu tensors. >>> oneflow.device(1) device(type='cuda', index=1) .. note:: Methods which take a device will generally accept a (properly formatted) string or (legacy) integer device ordinal, i.e. the following are all equivalent: >>> oneflow.randn((2,3), device=oneflow.device('cuda:1')) >>> oneflow.randn((2,3), device='cuda:1') >>> oneflow.randn((2,3), device=1) # legacy oneflow.placement -------------------------------------------------------------- .. autoclass:: oneflow.placement oneflow.placement.all -------------------------------------------------------------- .. autofunction:: oneflow.placement.all oneflow.env.all_device_placement -------------------------------------------------------------- .. autofunction:: oneflow.env.all_device_placement oneflow.sbp.sbp -------------------------------------------------------------- .. autoclass:: oneflow.sbp.sbp ================================================ FILE: docs/source/troubleshooting.md ================================================ # Troubleshooting - 'libunwind.h' not found - You might add the CMake argument `-DWITH_UNWIND=OFF`, or install libunwind on your system. - `CUDNN_STATUS_NOT_INITIALIZED` - You might see error messages like these: ``` I0729 22:37:45.483937439 56788 ev_epoll_linux.c:82] Use of signals is disabled. Epoll enginll not be used E0729 22:37:45.515343 56788 version.cpp:82] Failed to get cuda runtime version: CUDA driver version nsufficient for CUDA runtime version F0729 22:38:31.209002 56788 improver.cpp:535] Check failed: mem_size > 0 (-524288000 vs. 0) ``` ``` F0723 19:05:56.194067 40970 cuda_util.cpp:82] Check failed: error == CUDNN_STATUS_SUCCESS (1 vs. 0) CUDNN_STATUS_NOT_INITIALIZED ``` - Please upgrade your Nvidia Linux x86_64 driver. Version >= 440.33 is recommended. - For more information, please refer to the [CUDA compatibility documentation](https://docs.nvidia.com/deploy/cuda-compatibility/index.html). - Failed to compile `.cu` files - Please refer to the [CUDA System Requirements](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#system-requirements). Make sure your Linux distribution and the libraries shipped with it meet the requirements. - If you are using tools like conda, please make sure the libraries you install don't shadow the proper installation that comes with your Linux distribution or a package manager such as apt-get. - Please build OneFlow with a newer version of CMake. You could download version 3.14 from here: [https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-Linux-x86_64.tar.gz](https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-Linux-x86_64.tar.gz) - How do I know what compilers and flags are used to compile OneFlow? - Run `make clean && make VERBOSE=1` to get the exact compile commands with compiler paths and flags. - How to compile OneFlow with RDMA support? - Add the CMake flag `-DBUILD_RDMA` when compiling OneFlow. - Which version of g++ is CMake using to build OneFlow?
- You should find a line like this in CMake output: ```bash -- CMAKE_CXX_COMPILER_VERSION: [YOUR G++ VERSION NUMBER] ``` - Failed to compile NCCL - Try using fewer threads when compiling OneFlow's third-party dependencies. For instance, use ```bash cmake -DTHIRD_PARTY=ON .. && make ``` instead of ```bash cmake -DTHIRD_PARTY=ON .. && make -j$(nproc) ``` - `"CUDA_VERSION" "VERSION_GREATER_EQUAL" "10.0"` - Please use a newer version of CMake - Make sure cmake is correctly included in `PATH` - CUBLAS not found - Usually it happens when using CUDA 10.1 or newer - You should see an error message from CMake like this: ``` cuda lib not found: /usr/local/miniconda3/envs/dl/lib/libcublas_static.a or /usr/local/cuda/lib64/libcublas_static.a ``` - Make sure `libcublas_static.a` is in one of the two directories. - When running OneFlow in gdb, there is no debug information for code location. - Add the cmake flag `-DCMAKE_BUILD_TYPE=RELWITHDEBINFO` or `-DCMAKE_BUILD_TYPE=DEBUG` and recompile. - `libof_ccobj.a: File truncated` - You might see an error message like this: ``` /usr/bin/ar: libof_ccobj.a: File truncated make[2]: *** [libof_ccobj.a] Error 1 make[2]: *** Deleting file `libof_ccobj.a' make[1]: *** [CMakeFiles/of_ccobj.dir/all] Error 2 make: *** [all] Error 2 ``` - You should upgrade your GNU Binutils. Version 2.33.1 is recommended. If you are using conda, you could install it by running `conda install -c conda-forge binutils` - Failed to compile because C++ 17 is enabled - In some cases, the environment variable `CXXFLAGS` is not empty and contains `--std c++17`. - Check whether it is empty by running `echo $CXXFLAGS` and clear it with `unset CXXFLAGS`. - If you are using conda, to make the changes to environment variables permanent, you can run: ```bash conda env config vars set CXXFLAGS="-fPIC" ``` - cmake outputs error `No CMAKE_ASM_NASM_COMPILER could be found.` - Install `nasm`. For instance, run `sudo yum install nasm` if you are on CentOS. - `No module named 'google.protobuf'` - You might see an error message like this: ``` Scanning dependencies of target generate_api ... from google.protobuf import descriptor as _descriptor ModuleNotFoundError: No module named 'google.protobuf' CMakeFiles/generate_api.dir/build.make:57: recipe for target 'CMakeFiles/generate_api' failed make[2]: *** [CMakeFiles/generate_api] Error 1 ``` - Install development dependencies by running: ``` pip3 install -r dev-requirements.txt ``` - Get gdb warning `ptrace: Operation not permitted.` and gdb command `bt` prints no backtrace - You might get this warning when debugging OneFlow with gdb inside a docker container. Try adding these flags when launching your container: ``` docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined ``` - Please refer to https://stackoverflow.com/questions/19215177/how-to-solve-ptrace-operation-not-permitted-when-trying-to-attach-gdb-to-a-pro - It takes too long to download python packages when running `make` - If you are in China, you could run this to have pip download packages from a domestic PyPI mirror: ``` python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple ``` - For more information on this, please refer to the [PyPI mirror usage guide](https://mirror.tuna.tsinghua.edu.cn/help/pypi/) ================================================ FILE: docs/source/type_info.rst ================================================ .. currentmodule:: oneflow .. _type-info-doc: Type Info ========= .. The documentation is referenced from: https://pytorch.org/docs/1.10/type_info.html.
The numerical properties of a :class:`oneflow.dtype` can be accessed through either the :class:`oneflow.finfo` or the :class:`oneflow.iinfo`. .. contents:: oneflow :depth: 2 :local: :class: this-will-duplicate-information-and-it-is-still-useful-here :backlinks: top oneflow.finfo ------------- .. class:: oneflow.finfo A :class:`oneflow.finfo` is an object that represents the numerical properties of a floating point :class:`oneflow.dtype` (i.e. ``oneflow.float32``, ``oneflow.float64`` and ``oneflow.float16``). This is similar to `numpy.finfo `_. A :class:`oneflow.finfo` provides the following attributes: ================== ======= ========================================================================== Name Type Description ================== ======= ========================================================================== bits int The number of bits occupied by the type. eps float The smallest representable number such that ``1.0 + eps != 1.0``. min float The smallest representable number (typically ``-max``). max float The largest representable number. tiny float The smallest positive normal number. See notes. resolution float The approximate decimal resolution of this type, i.e., ``10**-precision``. ================== ======= ========================================================================== For example: .. code-block:: >>> import oneflow as flow >>> flow.finfo() finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, tiny=1.17549e-38, dtype=oneflow.float32, bits=32) >>> flow.finfo(flow.float) finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, tiny=1.17549e-38, dtype=oneflow.float32, bits=32) >>> flow.finfo(flow.float16).bits 16 >>> flow.finfo(flow.float16).max 65504.0 oneflow.iinfo ------------- .. class:: oneflow.iinfo A :class:`oneflow.iinfo` is an object that represents the numerical properties of an integer :class:`oneflow.dtype` (i.e. ``oneflow.uint8``, ``oneflow.int8``, ``oneflow.int16``, ``oneflow.int32``, and ``oneflow.int64``). This is similar to `numpy.iinfo `_. A :class:`oneflow.iinfo` provides the following attributes: ================== ======= ========================================================================== Name Type Description ================== ======= ========================================================================== bits int The number of bits occupied by the type. min int The smallest representable number. max int The largest representable number. ================== ======= ========================================================================== For example: .. code-block:: >>> import oneflow as flow >>> flow.iinfo(flow.int8) iinfo(min=-128, max=127, dtype=oneflow.int8, bits=8) >>> flow.iinfo(flow.int).max 2147483647 >>> flow.iinfo(flow.int).bits 32 ================================================ FILE: docs/source/utils.data.rst ================================================ oneflow.utils.data =================================== .. The documentation is referenced from: https://pytorch.org/docs/1.10/data.html .. automodule:: oneflow.utils.data At the heart of the Oneflow data loading utility is the :class:`oneflow.utils.data.DataLoader` class. It represents a Python iterable over a dataset, with support for * `map-style and iterable-style datasets `_, * `customizing data loading order `_, * `automatic batching `_, * `single- and multi-process data loading `_, * `automatic memory pinning `_.
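As a quick orientation, the following is a minimal sketch that uses the :class:`~oneflow.utils.data.TensorDataset` and :class:`~oneflow.utils.data.DataLoader` classes documented later on this page; the tensor contents and shapes are purely illustrative::

    import oneflow as flow
    from oneflow.utils.data import TensorDataset, DataLoader

    # wrap two in-memory tensors in a map-style dataset of 10 samples
    inputs = flow.arange(20, dtype=flow.float32).view(10, 2)
    labels = flow.arange(10, dtype=flow.int64)
    dataset = TensorDataset(inputs, labels)

    # iterate over the dataset in shuffled batches of 4 samples
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for batch_inputs, batch_labels in loader:
        print(batch_inputs.shape, batch_labels.shape)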
These options are configured by the constructor arguments of a :class:`~oneflow.utils.data.DataLoader`, which has signature:: DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False) The sections below describe in detail the effects and usages of these options. Dataset Types ------------- The most important argument of the :class:`~oneflow.utils.data.DataLoader` constructor is :attr:`dataset`, which indicates a dataset object to load data from. Oneflow supports two different types of datasets: * `map-style datasets `_, * `iterable-style datasets `_. Map-style datasets ^^^^^^^^^^^^^^^^^^ A map-style dataset is one that implements the :meth:`__getitem__` and :meth:`__len__` protocols, and represents a map from (possibly non-integral) indices/keys to data samples. For example, such a dataset, when accessed with ``dataset[idx]``, could read the ``idx``-th image and its corresponding label from a folder on the disk. See :class:`~oneflow.utils.data.Dataset` for more details. Iterable-style datasets ^^^^^^^^^^^^^^^^^^^^^^^ An iterable-style dataset is an instance of a subclass of :class:`~oneflow.utils.data.IterableDataset` that implements the :meth:`__iter__` protocol, and represents an iterable over data samples. This type of dataset is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data. For example, such a dataset, when ``iter(dataset)`` is called, could return a stream of data read from a database, a remote server, or even logs generated in real time. See :class:`~oneflow.utils.data.IterableDataset` for more details. .. note:: When using an :class:`~oneflow.utils.data.IterableDataset` with `multi-process data loading `_, the same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See the :class:`~oneflow.utils.data.IterableDataset` documentation for how to achieve this. Data Loading Order and :class:`~oneflow.utils.data.Sampler` ----------------------------------------------------------- For `iterable-style datasets `_, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample at each time). The rest of this section concerns the case with `map-style datasets `_. :class:`oneflow.utils.data.Sampler` classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a :class:`~oneflow.utils.data.Sampler` could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD. A sequential or shuffled sampler will be automatically constructed based on the :attr:`shuffle` argument to a :class:`~oneflow.utils.data.DataLoader`. Alternatively, users may use the :attr:`sampler` argument to specify a custom :class:`~oneflow.utils.data.Sampler` object that at each time yields the next index/key to fetch. A custom :class:`~oneflow.utils.data.Sampler` that yields a list of batch indices at a time can be passed as the :attr:`batch_sampler` argument. Automatic batching can also be enabled via the :attr:`batch_size` and :attr:`drop_last` arguments, as the short sketch below illustrates.
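As a hedged illustration of the :attr:`sampler` and :attr:`batch_sampler` arguments described above (the in-memory dataset is only a stand-in), a custom sampler can either be combined with automatic batching, or a sampler that already yields lists of indices can be supplied directly::

    import oneflow as flow
    from oneflow.utils.data import DataLoader, TensorDataset, RandomSampler, BatchSampler

    dataset = TensorDataset(flow.randn(50, 2))

    # a sampler yields one index at a time; batch_size and drop_last
    # control how those indices are grouped into batches
    loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=8, drop_last=True)

    # alternatively, a batch sampler yields a whole list of indices per batch
    batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=8, drop_last=False)
    loader2 = DataLoader(dataset, batch_sampler=batch_sampler)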
See `the next section `_ for more details on this. .. note:: Neither :attr:`sampler` nor :attr:`batch_sampler` is compatible with iterable-style datasets, since such datasets have no notion of a key or an index. Loading Batched and Non-Batched Data ------------------------------------ :class:`~oneflow.utils.data.DataLoader` supports automatically collating individual fetched data samples into batches via arguments :attr:`batch_size`, :attr:`drop_last`, :attr:`batch_sampler`, and :attr:`collate_fn` (which has a default function). Automatic batching (default) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ This is the most common case, and corresponds to fetching a minibatch of data and collating them into batched samples, i.e., containing Tensors with one dimension being the batch dimension (usually the first). When :attr:`batch_size` (default ``1``) is not ``None``, the data loader yields batched samples instead of individual samples. :attr:`batch_size` and :attr:`drop_last` arguments are used to specify how the data loader obtains batches of dataset keys. For map-style datasets, users can alternatively specify :attr:`batch_sampler`, which yields a list of keys at a time. .. note:: The :attr:`batch_size` and :attr:`drop_last` arguments essentially are used to construct a :attr:`batch_sampler` from :attr:`sampler`. For map-style datasets, the :attr:`sampler` is either provided by user or constructed based on the :attr:`shuffle` argument. For iterable-style datasets, the :attr:`sampler` is a dummy infinite one. See `this section `_ on more details on samplers. .. note:: When fetching from `iterable-style datasets `_ with `multi-processing `_, the :attr:`drop_last` argument drops the last non-full batch of each worker's dataset replica. After fetching a list of samples using the indices from sampler, the function passed as the :attr:`collate_fn` argument is used to collate lists of samples into batches. In this case, loading from a map-style dataset is roughly equivalent with:: for indices in batch_sampler: yield collate_fn([dataset[i] for i in indices]) and loading from an iterable-style dataset is roughly equivalent with:: dataset_iter = iter(dataset) for indices in batch_sampler: yield collate_fn([next(dataset_iter) for _ in indices]) A custom :attr:`collate_fn` can be used to customize collation, e.g., padding sequential data to max length of a batch. See `this section `_ on more about :attr:`collate_fn`. Disable automatic batching ^^^^^^^^^^^^^^^^^^^^^^^^^^ In certain cases, users may want to handle batching manually in dataset code, or simply load individual samples. For example, it could be cheaper to directly load batched data (e.g., bulk reads from a database or reading continuous chunks of memory), or the batch size is data dependent, or the program is designed to work on individual samples. Under these scenarios, it's likely better to not use automatic batching (where :attr:`collate_fn` is used to collate the samples), but let the data loader directly return each member of the :attr:`dataset` object. When both :attr:`batch_size` and :attr:`batch_sampler` are ``None`` (default value for :attr:`batch_sampler` is already ``None``), automatic batching is disabled. Each sample obtained from the :attr:`dataset` is processed with the function passed as the :attr:`collate_fn` argument. **When automatic batching is disabled**, the default :attr:`collate_fn` simply converts NumPy arrays into Oneflow Tensors, and keeps everything else untouched. 
In this case, loading from a map-style dataset is roughly equivalent with:: for index in sampler: yield collate_fn(dataset[index]) and loading from an iterable-style dataset is roughly equivalent with:: for data in iter(dataset): yield collate_fn(data) See `this section `_ on more about :attr:`collate_fn`. .. _dataloader-collate_fn: Working with :attr:`collate_fn` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The use of :attr:`collate_fn` is slightly different when automatic batching is enabled or disabled. **When automatic batching is disabled**, :attr:`collate_fn` is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default :attr:`collate_fn` simply converts NumPy arrays in Oneflow tensors. **When automatic batching is enabled**, :attr:`collate_fn` is called with a list of data samples at each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes the behavior of the default :attr:`collate_fn` (:func:`~oneflow.utils.data.default_collate`). For instance, if each data sample consists of a 3-channel image and an integral class label, i.e., each element of the dataset returns a tuple ``(image, class_index)``, the default :attr:`collate_fn` collates a list of such tuples into a single tuple of a batched image tensor and a batched class label Tensor. In particular, the default :attr:`collate_fn` has the following properties: * It always prepends a new dimension as the batch dimension. * It automatically converts NumPy arrays and Python numerical values into Oneflow Tensors. * It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for ``list`` s, ``tuple`` s, ``namedtuple`` s, etc. Users may use customized :attr:`collate_fn` to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types. If you run into a situation where the outputs of :class:`~oneflow.utils.data.DataLoader` have dimensions or type that is different from your expectation, you may want to check your :attr:`collate_fn`. Single- and Multi-process Data Loading -------------------------------------- A :class:`~oneflow.utils.data.DataLoader` uses single-process data loading by default. Within a Python process, the `Global Interpreter Lock (GIL) `_ prevents true fully parallelizing Python code across threads. To avoid blocking computation code with data loading, Oneflow provides an easy switch to perform multi-process data loading by simply setting the argument :attr:`num_workers` to a positive integer. Single-process data loading (default) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ In this mode, data fetching is done in the same process a :class:`~oneflow.utils.data.DataLoader` is initialized. Therefore, data loading may block computing. However, this mode may be preferred when resource(s) used for sharing data among processes (e.g., shared memory, file descriptors) is limited, or when the entire dataset is small and can be loaded entirely in memory. Additionally, single-process loading often shows more readable error traces and thus is useful for debugging. 
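To make the switch between the two loading modes concrete, here is a brief sketch; the in-memory dataset is an illustrative stand-in, and only :attr:`num_workers` differs between the two loaders::

    import oneflow as flow
    from oneflow.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(flow.randn(64, 3), flow.randint(0, 10, (64,)))

    # default: data is fetched in the same process that created the DataLoader
    single_process_loader = DataLoader(dataset, batch_size=16)

    # a positive num_workers turns on multi-process loading, described below
    multi_process_loader = DataLoader(dataset, batch_size=16, num_workers=4)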
Multi-process data loading ^^^^^^^^^^^^^^^^^^^^^^^^^^ Setting the argument :attr:`num_workers` as a positive integer will turn on multi-process data loading with the specified number of loader worker processes. .. warning:: After several iterations, the loader worker processes will consume the same amount of CPU memory as the parent process for all Python objects in the parent process which are accessed from the worker processes. This can be problematic if the Dataset contains a lot of data (e.g., you are loading a very large list of filenames at Dataset construction time) and/or you are using a lot of workers (overall memory usage is ``number of workers * size of parent process``). The simplest workaround is to replace Python objects with non-refcounted representations such as Pandas, Numpy or PyArrow objects. In this mode, each time an iterator of a :class:`~oneflow.utils.data.DataLoader` is created (e.g., when you call ``enumerate(dataloader)``), :attr:`num_workers` worker processes are created. At this point, the :attr:`dataset`, :attr:`collate_fn`, and :attr:`worker_init_fn` are passed to each worker, where they are used to initialize, and fetch data. This means that dataset access together with its internal IO, transforms (including :attr:`collate_fn`) runs in the worker process. For map-style datasets, the main process generates the indices using :attr:`sampler` and sends them to the workers. So any shuffle randomization is done in the main process which guides loading by assigning indices to load. For iterable-style datasets, since each worker process gets a replica of the :attr:`dataset` object, naive multi-process loading will often result in duplicated data. Using :attr:`worker_init_fn`, users may configure each replica independently. (See :class:`~oneflow.utils.data.IterableDataset` documentations for how to achieve this. ) For similar reasons, in multi-process loading, the :attr:`drop_last` argument drops the last non-full batch of each worker's iterable-style dataset replica. Workers are shut down once the end of the iteration is reached, or when the iterator becomes garbage collected. .. warning:: It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing. Instead, we recommend using `automatic memory pinning `_ (i.e., setting :attr:`pin_memory=True`), which enables fast data transfer to CUDA-enabled GPUs. Platform-specific behaviors """"""""""""""""""""""""""" Since workers rely on Python :py:mod:`multiprocessing`, worker launch behavior is different on Windows compared to Unix. * On Unix, :func:`fork()` is the default :py:mod:`multiprocessing` start method. Using :func:`fork`, child workers typically can access the :attr:`dataset` and Python argument functions directly through the cloned address space. * On Windows or MacOS, :func:`spawn()` is the default :py:mod:`multiprocessing` start method. Using :func:`spawn()`, another interpreter is launched which runs your main script, followed by the internal worker function that receives the :attr:`dataset`, :attr:`collate_fn` and other arguments through :py:mod:`pickle` serialization. This separate serialization means that you should take two steps to ensure you are compatible with Windows while using multi-process data loading: - Wrap most of you main script's code within ``if __name__ == '__main__':`` block, to make sure it doesn't run again (most likely generating error) when each worker process is launched. 
You can place your dataset and :class:`~oneflow.utils.data.DataLoader` instance creation logic here, as it doesn't need to be re-executed in workers. - Make sure that any custom :attr:`collate_fn`, :attr:`worker_init_fn` or :attr:`dataset` code is declared as top level definitions, outside of the ``__main__`` check. This ensures that they are available in worker processes. (this is needed since functions are pickled as references only, not ``bytecode``.) .. _data-loading-randomness: Randomness in multi-process data loading """""""""""""""""""""""""""""""""""""""""" By default, each worker will have its Oneflow seed set to ``base_seed + worker_id``, where ``base_seed`` is a long generated by main process using its RNG (thereby, consuming a RNG state mandatorily) or a specified :attr:`generator`. However, seeds for other libraries may be duplicated upon initializing workers, causing each worker to return identical random numbers. In :attr:`worker_init_fn`, you may access the Oneflow seed set for each worker with :func:`oneflow.initial_seed()`, and use it to seed other libraries before data loading. Memory Pinning -------------- Host to GPU copies are much faster when they originate from pinned (page-locked) memory. See `cuda-memory-pinning` for more details on when and how to use pinned memory generally. For data loading, passing :attr:`pin_memory=True` to a :class:`~oneflow.utils.data.DataLoader` will automatically put the fetched data Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled GPUs. The default memory pinning logic only recognizes Tensors and maps and iterables containing Tensors. By default, if the pinning logic sees a batch that is a custom type (which will occur if you have a :attr:`collate_fn` that returns a custom batch type), or if each element of your batch is a custom type, the pinning logic will not recognize them, and it will return that batch (or those elements) without pinning the memory. To enable memory pinning for custom batch or data type(s), define a :meth:`pin_memory` method on your custom type(s). See the example below. Example:: class SimpleCustomBatch: def __init__(self, data): transposed_data = list(zip(*data)) self.inp = oneflow.stack(transposed_data[0], 0) self.tgt = oneflow.stack(transposed_data[1], 0) # custom memory pinning method on custom type def pin_memory(self): self.inp = self.inp.pin_memory() self.tgt = self.tgt.pin_memory() return self def collate_wrapper(batch): return SimpleCustomBatch(batch) inps = oneflow.arange(10 * 5, dtype=oneflow.float32).view(10, 5) tgts = oneflow.arange(10 * 5, dtype=oneflow.float32).view(10, 5) dataset = TensorDataset(inps, tgts) loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, pin_memory=True) for batch_ndx, sample in enumerate(loader): print(sample.inp.is_pinned()) print(sample.tgt.is_pinned()) .. autoclass:: DataLoader .. autoclass:: Dataset .. autoclass:: IterableDataset .. autoclass:: TensorDataset .. autoclass:: ConcatDataset .. autoclass:: Subset .. autofunction:: oneflow.utils.data.random_split .. autoclass:: oneflow.utils.data.Sampler .. autoclass:: oneflow.utils.data.SequentialSampler .. autoclass:: oneflow.utils.data.RandomSampler .. autoclass:: oneflow.utils.data.SubsetRandomSampler .. autoclass:: oneflow.utils.data.BatchSampler .. 
autoclass:: oneflow.utils.data.distributed.DistributedSampler ================================================ FILE: docs/source/utils.global_view.rst ================================================ oneflow.utils.global_view ====================================== Some global view Ops -------------------------------------- .. currentmodule:: oneflow.utils.global_view .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst to_global to_local global_mode current_global_mode ================================================ FILE: docs/source/utils.tensor.rst ================================================ oneflow.utils.tensor ========================================================== Some torch-related Ops are suitable for tensor conversion. ---------------------------------------------------------- .. currentmodule:: oneflow.utils.tensor .. autosummary:: :toctree: generated :nosignatures: from_torch to_torch ================================================ FILE: external/CMakeLists.txt ================================================ set(ONETBB_URL https://github.com/oneapi-src/oneTBB/archive/3db67b5ba2a81bd1288325c5847e09e13c46f4d7.zip) use_mirror(VARIABLE ONETBB_URL URL ${ONETBB_URL}) set(ONETBB_MD5 7545d4084baff17af73da2dae5ab8005) set(ROBIN_HOOD_HASHING_URL https://github.com/martinus/robin-hood-hashing/archive/refs/tags/3.11.5.tar.gz) use_mirror(VARIABLE ROBIN_HOOD_HASHING_URL URL ${ROBIN_HOOD_HASHING_URL}) set(ROBIN_HOOD_HASHING_MD5 a78bd30a7582f25984f8592652836467) set(FMT_URL https://github.com/fmtlib/fmt/archive/fc07217d85e6dcec52878807d6bbd89a9d9156a5.zip) use_mirror(VARIABLE FMT_URL URL ${FMT_URL}) set(FMT_MD5 7d9bb2ececc9ede29cd35bdc42a7e22c) set(KINETO_URL https://github.com/pytorch/kineto/archive/ff8dba20499a660650632952be76450bd70a52a6.zip) use_mirror(VARIABLE KINETO_URL URL ${KINETO_URL}) set(KINETO_MD5 f9b550591b3899fb267270c19484933f) set(EXTERNAL_TARGETS) if(WITH_TBB) # set(WITH_${threading_runtime_item} ON) in threading.cmake add_subdirectory(onetbb) list(APPEND EXTERNAL_TARGETS tbb) endif() add_subdirectory(robin-hood-hashing) list(APPEND EXTERNAL_TARGETS robin_hood) add_subdirectory(fmt) list(APPEND EXTERNAL_TARGETS fmt) add_subdirectory(kineto) list(APPEND EXTERNAL_TARGETS kineto) mark_targets_as_system(${EXTERNAL_TARGETS}) set_property(GLOBAL PROPERTY EXTERNAL_TARGETS ${EXTERNAL_TARGETS}) ================================================ FILE: external/fmt/CMakeLists.txt ================================================ include(FetchContent) set(FMT_INSTALL_DIR ${THIRD_PARTY_DIR}/fmt) FetchContent_Declare(fmt URL ${FMT_URL} URL_HASH MD5=${FMT_MD5}) FetchContent_MakeAvailable(fmt) # Clang doesn't support __float128 when compiling CUDA target_compile_definitions(fmt PUBLIC FMT_USE_FLOAT128=0) install( TARGETS fmt EXPORT oneflow LIBRARY DESTINATION ${FMT_INSTALL_DIR}/lib ARCHIVE DESTINATION ${FMT_INSTALL_DIR}/lib) install(DIRECTORY ${fmt_SOURCE_DIR}/include DESTINATION ${FMT_INSTALL_DIR}) install(DIRECTORY ${fmt_SOURCE_DIR}/include/ DESTINATION ${ONEFLOW_INCLUDE_DIR} COMPONENT oneflow_py_include EXCLUDE_FROM_ALL) ================================================ FILE: external/kineto/CMakeLists.txt ================================================ include(FetchContent) # reference: https://github.com/PaddlePaddle/Paddle/blob/develop/cmake/cupti.cmake set(CUPTI_ROOT "/usr" CACHE PATH "CUPTI ROOT") set(CUDA_SOURCE_DIR ${CUDAToolkit_TARGET_DIR}) find_path( CUPTI_INCLUDE_DIR cupti.h PATHS ${CUPTI_ROOT} ${CUPTI_ROOT}/include $ENV{CUPTI_ROOT} 
$ENV{CUPTI_ROOT}/include ${CUDA_SOURCE_DIR}/extras/CUPTI/include ${CUDA_SOURCE_DIR}/targets/x86_64-linux/include ${CUDA_SOURCE_DIR}/targets/aarch64-linux/include NO_DEFAULT_PATH) set(TARGET_ARCH "x86_64") if(NOT ${CMAKE_SYSTEM_PROCESSOR}) set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR}) endif() list( APPEND CUPTI_CHECK_LIBRARY_DIRS ${CUPTI_ROOT} ${CUPTI_ROOT}/lib64 ${CUPTI_ROOT}/lib ${CUPTI_ROOT}/lib/${TARGET_ARCH}-linux-gnu $ENV{CUPTI_ROOT} $ENV{CUPTI_ROOT}/lib64 $ENV{CUPTI_ROOT}/lib /usr/lib ${CUDA_SOURCE_DIR}/targets/x86_64-linux/lib64 ${CUDA_SOURCE_DIR}/targets/x86_64-linux/lib ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64 ${CUDA_SOURCE_DIR}/extras/CUPTI/lib) find_library( CUDA_cupti_LIBRARY NAMES libcupti.so libcupti.dylib # libcupti_static.a PATHS ${CUPTI_CHECK_LIBRARY_DIRS} ${CUPTI_INCLUDE_DIR} NO_DEFAULT_PATH DOC "Path to cuPTI library.") list(APPEND CUDA_cupti_LIBRARY CUDA::cudart_static) # for undefined symbol: cudaGetDeviceCount∂ FetchContent_Declare( kineto URL ${KINETO_URL} URL_HASH MD5=${KINETO_MD5} SOURCE_SUBDIR libkineto) FetchContent_MakeAvailable(kineto) target_include_directories(kineto PUBLIC $) ================================================ FILE: external/onetbb/CMakeLists.txt ================================================ find_package(Threads REQUIRED) set(ONETBB_INSTALL_DIR ${THIRD_PARTY_DIR}/tbb CACHE PATH " ") include(FetchContent) FetchContent_Declare(tbb URL ${ONETBB_URL} URL_HASH MD5=${ONETBB_MD5}) FetchContent_GetProperties(tbb) set(TBB_EXAMPLES OFF CACHE BOOL "") set(TBB_TEST OFF CACHE BOOL "") set(TBB_ENABLE_IPO OFF CACHE BOOL "") set(BUILD_SHARED_LIBS ON) set(CMAKE_POLICY_DEFAULT_CMP0079 NEW) FetchContent_MakeAvailable(tbb) # workaround compile error in GCC 12 or later # refer to https://github.com/Oneflow-Inc/oneflow/pull/10236 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12) target_compile_options(tbb PRIVATE "-Wno-error=stringop-overflow") endif() set(TBBBIND_LIBRARY_NAME) if(HWLOC_VERSION) if(HWLOC_VERSION VERSION_LESS 2) set(TBBBIND_LIBRARY_NAME tbbbind) elseif(HWLOC_VERSION VERSION_LESS 2.5) set(TBBBIND_LIBRARY_NAME tbbbind_2_0) else() set(TBBBIND_LIBRARY_NAME tbbbind_2_5) endif() endif() add_custom_target( install-tbb DEPENDS tbb tbbmalloc tbbmalloc_proxy ${TBBBIND_LIBRARY_NAME} COMMAND "${CMAKE_COMMAND}" -DCMAKE_INSTALL_PREFIX=${ONETBB_INSTALL_DIR} -P "${tbb_BINARY_DIR}/cmake_install.cmake") ================================================ FILE: external/robin-hood-hashing/CMakeLists.txt ================================================ include(FetchContent) FetchContent_Declare( robin_hood_hashing URL ${ROBIN_HOOD_HASHING_URL} URL_HASH MD5=${ROBIN_HOOD_HASHING_MD5} ) FetchContent_MakeAvailable(robin_hood_hashing) ================================================ FILE: oneflow/api/common/ir_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_MLIR #include "oneflow/ir/include/OneFlow/Extension.h" #include "oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h" #include namespace oneflow { REGISTER_JOB_PASS("IRRoundTripBeforeAD", IRRoundTrip); REGISTER_JOB_PASS("IRRoundTrip", IRRoundTrip); } // namespace oneflow #endif // WITH_MLIR ================================================ FILE: oneflow/api/common/job_build_and_infer_ctx.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_ #define ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_ #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" namespace oneflow { inline Maybe GetCurrentJob() { auto* job_ctx_mgr = Singleton::Get(); CHECK_NOTNULL_OR_RETURN(job_ctx_mgr); auto* job_ctx = JUST(job_ctx_mgr->FindJobBuildAndInferCtx(*JUST(job_ctx_mgr->GetCurrentJobName()))); CHECK_NOTNULL_OR_RETURN(job_ctx); return job_ctx->job(); } } // namespace oneflow #endif // ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_ ================================================ FILE: oneflow/api/common/sbp.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_COMMON_SBP_H_ #define ONEFLOW_API_COMMON_SBP_H_ #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" namespace oneflow { namespace api { // NOTE: The api inferface will print the whole name of sbp. 
inline Maybe ApiSbpToString(Symbol sbp_sym) { std::string sbp_str = "oneflow.sbp."; if (sbp_sym->has_broadcast_parallel()) { sbp_str += "broadcast"; } else if (sbp_sym->has_partial_sum_parallel()) { sbp_str += "partial_sum"; } else if (sbp_sym->has_split_parallel()) { sbp_str += "split(dim=" + std::to_string(sbp_sym->split_parallel().axis()) + ")"; } else { UNIMPLEMENTED_THEN_RETURN(); } return sbp_str; } inline Maybe ApiNdSbpToString(Symbol nd_sbp) { std::string str = "("; for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) { if (i > 0) { str += ", "; } str += *JUST(ApiSbpToString(SymbolOf(nd_sbp->sbp_parallel(i)))); } if (nd_sbp->sbp_parallel_size() == 1) { str += ","; } str += ")"; return str; } } // namespace api } // namespace oneflow #endif // !ONEFLOW_API_COMMON_SBP_H_ ================================================ FILE: oneflow/api/common/variable_tensor_mgr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_ #define ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_ #include "oneflow/core/common/singleton.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/variable_tensor_mgr.h" namespace oneflow { inline Maybe FillVariableTensorMgr( const std::vector& variable_op_names, const std::vector>& variable_tensors) { auto mgr = Singleton::Get(); return mgr->Fill(variable_op_names, variable_tensors); } inline void ResetVariableTensorMgr() { auto mgr = Singleton::Get(); mgr->Reset(); } inline std::tuple, std::vector>> DumpVariableTensorMgr() { auto mgr = Singleton::Get(); return mgr->Dump(); } } // namespace oneflow #endif // ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_ ================================================ FILE: oneflow/api/cpp/api.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_API_H_ #define ONEFLOW_API_CPP_API_H_ #include "env.h" #include "framework.h" #include "nn.h" #endif // !ONEFLOW_API_CPP_API_H_ ================================================ FILE: oneflow/api/cpp/embedding/embedding.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/cpp/embedding/embedding.h" #include "oneflow/core/embedding/embedding_manager.h" namespace oneflow_api { namespace embedding { std::string CreateKeyValueStore(const std::string& key_value_store_options, int64_t local_rank_id, int64_t rank_id, int64_t world_size) { oneflow::embedding::KeyValueStoreOptions options(key_value_store_options); #ifdef WITH_CUDA oneflow::Singleton::Get()->CreateKeyValueStore( options, local_rank_id, rank_id, world_size); return options.Name(); #else UNIMPLEMENTED() << "OneEmbedding Only Support with CUDA"; #endif return ""; } void LoadSnapshot(const std::string& snapshot_name, const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id) { #ifdef WITH_CUDA oneflow::Singleton::Get()->LoadSnapshot( embedding_name, local_rank_id, rank_id, snapshot_name); #else UNIMPLEMENTED() << "OneEmbedding Only Support with CUDA"; #endif } } // namespace embedding } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/embedding/embedding.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_ONE_EMBEDDING_ONE_EMBEDDING_H_ #define ONEFLOW_API_CPP_ONE_EMBEDDING_ONE_EMBEDDING_H_ #include namespace oneflow_api { namespace embedding { // CreateKeyValueStore returns embedding name in the options. std::string CreateKeyValueStore(const std::string& key_value_store_options, int64_t local_rank_id, int64_t rank_id, int64_t world_size); // key_value_store_options is // a serialized json string. void LoadSnapshot(const std::string& snapshot_name, const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id); } // namespace embedding } // namespace oneflow_api #endif // ONEFLOW_API_CPP_ONE_EMBEDDING_ONE_EMBEDDING_H_ ================================================ FILE: oneflow/api/cpp/env.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/api/cpp/env.h" #include "oneflow/api/cpp/env_impl.h" #include "oneflow/core/framework/shut_down_util.h" #include "oneflow/core/thread/thread_global_id.h" namespace oneflow_api { void initialize() { if (of::Singleton::Get() == nullptr) { of::Singleton::New(); } of::SetShuttingDown(false); } void release() { if (of::Singleton::Get() != nullptr) { of::Singleton::Delete(); } of::SetShuttingDown(); } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/env.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_ENV_H_ #define ONEFLOW_API_CPP_ENV_H_ namespace oneflow_api { void initialize(); void release(); } // namespace oneflow_api #endif // !ONEFLOW_API_CPP_ENV_H_ ================================================ FILE: oneflow/api/cpp/env_impl.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include #include #include #include #include #include #include #include "oneflow/api/cpp/env_impl.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/session_util.h" #include "oneflow/core/job/env.pb.h" #include "oneflow/core/job/cluster_instruction.h" #include "oneflow/core/control/ctrl_bootstrap.h" #include "oneflow/core/job/session.h" #include "oneflow/core/rpc/include/base.h" #include "oneflow/core/vm/vm_util.h" namespace oneflow_api { namespace of = oneflow; namespace { // for inltialize inline bool IsEnvInited() { return of::Singleton::Get() != nullptr; } bool HasEnvVar(const std::string& key) { const char* value = getenv(key.c_str()); return value != nullptr; } std::string GetEnvVar(const std::string& key, const std::string& default_value) { const char* value = getenv(key.c_str()); if (value == nullptr) { return default_value; } return std::string(value); } int64_t GetEnvVar(const std::string& key, int64_t default_value) { const char* value = getenv(key.c_str()); if (value == nullptr) { return default_value; } return std::atoll(value); } int32_t FindFreePort(const std::string& addr) { #ifdef __linux__ int sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); CHECK_GE(sock, 0) << "fail to find a free port."; int optval = 1; setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)); std::mt19937 rng; rng.seed(std::random_device()()); std::uniform_int_distribution dist(1, 1000); int count = 0; int num_attempts = 200; do { int port = 5000 + dist(rng); struct sockaddr_in sockaddr {}; memset(&sockaddr, 0, sizeof(sockaddr)); sockaddr.sin_family = AF_INET; sockaddr.sin_port = htons(port); sockaddr.sin_addr.s_addr = inet_addr(addr.c_str()); int error = bind(sock, (struct sockaddr*)&sockaddr, sizeof(sockaddr)); if (error == 0) { return port; } ++count; } while (count < num_attempts); CHECK_NE(count, num_attempts) << "fail to find a free port."; #endif // __linux__ return -1; } void CompleteEnvProto(of::EnvProto& env_proto) { auto bootstrap_conf = env_proto.mutable_ctrl_bootstrap_conf(); auto master_addr = bootstrap_conf->mutable_master_addr(); const std::string addr = GetEnvVar("MASTER_ADDR", "127.0.0.1"); master_addr->set_host(addr); master_addr->set_port(GetEnvVar("MASTER_PORT", FindFreePort(addr))); bootstrap_conf->set_world_size(GetEnvVar("WORLD_SIZE", 1)); bootstrap_conf->set_rank(GetEnvVar("RANK", 0)); auto cpp_logging_conf = env_proto.mutable_cpp_logging_conf(); if (HasEnvVar("GLOG_log_dir")) { cpp_logging_conf->set_log_dir(GetEnvVar("GLOG_log_dir", "")); } if (HasEnvVar("GLOG_logtostderr")) { cpp_logging_conf->set_logtostderr(GetEnvVar("GLOG_logtostderr", -1)); } if (HasEnvVar("GLOG_logbuflevel")) { cpp_logging_conf->set_logbuflevel(GetEnvVar("GLOG_logbuflevel", -1)); } if (HasEnvVar("GLOG_minloglevel")) { cpp_logging_conf->set_minloglevel(GetEnvVar("GLOG_minloglevel", -1)); } } } // namespace OneFlowEnv::OneFlowEnv() { of::EnvProto env_proto; CompleteEnvProto(env_proto); env_ctx_ = std::make_shared(env_proto); of::ConfigProto config_proto; config_proto.mutable_resource()->set_cpu_device_num(1); // useless, will be set in TryInit const int64_t session_id = of::NewSessionId(); config_proto.set_session_id(session_id); CHECK(of::RegsterSessionId(session_id)); session_ctx_ = std::make_shared(env_ctx_); CHECK_JUST(session_ctx_->TryInit(config_proto)); } OneFlowEnv::~OneFlowEnv() { session_ctx_.reset(); 
CHECK(of::ClearSessionId(CHECK_JUST(of::GetDefaultSessionId()))); env_ctx_.reset(); } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/env_impl.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/job/env_global_objects_scope.h" #ifndef ONEFLOW_API_CPP_ENV_IMPL_H_ #define ONEFLOW_API_CPP_ENV_IMPL_H_ namespace oneflow_api { namespace of = oneflow; class OneFlowEnv { public: OF_DISALLOW_COPY(OneFlowEnv); OneFlowEnv(); ~OneFlowEnv(); std::shared_ptr GetSessionCtx() { return session_ctx_; } private: std::shared_ptr env_ctx_; std::shared_ptr session_ctx_; }; } // namespace oneflow_api #endif // ONEFLOW_API_CPP_ENV_IMPL_H_ ================================================ FILE: oneflow/api/cpp/framework/device.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/cpp/framework/device.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/device.h" namespace oneflow_api { namespace of = oneflow; Device::Device(const std::string& type_or_type_with_device_id) : device_(std::make_shared>( of::Device::ParseAndNew(type_or_type_with_device_id).GetOrThrow())) {} Device::Device(const std::string& type, int64_t device_id) : device_( std::make_shared>(of::Device::New(type, device_id).GetOrThrow())) {} const std::string& Device::type() const { return (*device_)->type(); } int64_t Device::device_id() const { return (*device_)->device_id(); } bool Device::operator==(const Device& rhs) const { return *device_ == *rhs.device_; } bool Device::operator!=(const Device& rhs) const { return *device_ != *rhs.device_; } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/framework/device.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_ #define ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_ #include #include namespace oneflow { class Device; template class Symbol; } // namespace oneflow namespace oneflow_api { class Device final { friend class Tensor; friend class Graph; public: explicit Device(const std::string& type_or_type_with_device_id); explicit Device(const std::string& type, int64_t device_id); [[nodiscard]] const std::string& type() const; [[nodiscard]] int64_t device_id() const; [[nodiscard]] bool operator==(const Device& rhs) const; [[nodiscard]] bool operator!=(const Device& rhs) const; private: std::shared_ptr> device_ = nullptr; }; } // namespace oneflow_api #endif // !ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_ ================================================ FILE: oneflow/api/cpp/framework/dtype.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/cpp/framework/dtype.h" #include namespace oneflow_api { namespace { std::map DTypeSize = { {DType::kFloat, sizeof(float)}, {DType::kDouble, sizeof(double)}, {DType::kInt8, sizeof(int8_t)}, {DType::kInt32, sizeof(int32_t)}, {DType::kInt64, sizeof(int64_t)}, {DType::kBool, sizeof(bool)}, }; } int32_t GetDTypeSize(DType dtype) { return DTypeSize[dtype]; } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/framework/dtype.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_ #define ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_ #include namespace oneflow_api { enum class DType { kInvalidDataType = 0, kChar = 1, kFloat = 2, kDouble = 3, kInt8 = 4, kInt32 = 5, kInt64 = 6, kUInt8 = 7, kOFRecord = 8, kFloat16 = 9, kTensorBuffer = 10, kBFloat16 = 11, kBool = 12, kMaxDataType = 13 }; [[nodiscard]] int32_t GetDTypeSize(DType dtype); } // namespace oneflow_api #endif // ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_ ================================================ FILE: oneflow/api/cpp/framework/graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "nlohmann/json.hpp" #include "oneflow/api/common/variable_tensor_mgr.h" #include "oneflow/api/cpp/env_impl.h" #include "oneflow/api/cpp/framework/device.h" #include "oneflow/api/cpp/framework/dtype.h" #include "oneflow/api/cpp/framework/graph.h" #include "oneflow/api/cpp/framework/ivalue.h" #include "oneflow/api/cpp/framework/shape.h" #include "oneflow/api/cpp/framework/tensor.h" #include "oneflow/api/cpp/embedding/embedding.h" #include "oneflow/api/common/job_build_and_infer_ctx.h" #include "oneflow/api/python/job_build/job_build_and_infer.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/util.h" #include "oneflow/core/embedding/posix_file.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/framework/nn_graph.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/job_build_and_infer_ctx.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/job_ir.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/session.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/operator/interface_blob_conf.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/register/logical_blob_id.pb.h" #include "oneflow/core/vm/vm_util.h" namespace oneflow_api { namespace of = oneflow; namespace { class CompileScope { public: CompileScope(const of::JobConfigProto& job_config, const of::Device& device) { of::JobConfigProto mut_job_config = job_config; const std::shared_ptr scope = CHECK_JUST(MakeScope(mut_job_config, device)); CHECK_JUST(of::ThreadLocalScopeStackPush(scope)); CHECK_JUST(of::JobBuildAndInferCtx_Open(mut_job_config.job_name())); CHECK_JUST(CHECK_JUST(of::GetCurInferCtx())->SetJobConf(mut_job_config)); } ~CompileScope() { CHECK_JUST(of::JobBuildAndInferCtx_Close()); CHECK_JUST(of::ThreadLocalScopeStackPop()); } private: of::LazyMode::Guard lazy_mode_enabled_guard{true}; }; std::shared_ptr ConvertToTensorTuple( const std::vector>& tensors) { auto tensor_tuple = std::make_shared(); for (const auto& tensor : tensors) { tensor_tuple->emplace_back(tensor); } return tensor_tuple; } std::string GetDeviceTag(const Device& device) { return device.type(); } template const std::pair, std::vector> Unzip(const of::HashMap& hash_map) { 
std::vector vec1; std::vector vec2; for (const auto& entry : hash_map) { vec1.emplace_back(entry.first); vec2.emplace_back(entry.second); } return std::make_pair(vec1, vec2); } Shape OfShapeToOfApiShape(const of::Shape& of_shape) { std::vector dims(of_shape.dim_vec().begin(), of_shape.dim_vec().end()); return Shape(dims); } #ifdef __linux__ void LoadOneEmbedding(const std::string& model_path, const Device& device) { const std::string one_embedding_info_name("one_embedding_options.json"); const std::string one_embedding_info_save_path( oneflow::JoinPath(model_path, one_embedding_info_name)); if (oneflow::embedding::PosixFile::FileExists(one_embedding_info_save_path)) { std::ifstream one_embedding_info_file(one_embedding_info_save_path); auto one_embedding_json = nlohmann::json::parse(one_embedding_info_file); for (auto& it : one_embedding_json["embedding"]) { const std::string snapshot_path = it["snapshot"]; auto kv_options_json = it["kv_options"]; std::string embedding_name = embedding::CreateKeyValueStore(kv_options_json.dump(), /*local_rank_id=*/0, /*rank_id=*/0, /*world_size=*/1); embedding::LoadSnapshot(snapshot_path, embedding_name, /*local_rank_id=*/0, /*rank_id=*/0); } } } #endif // __linux__ } // namespace class Graph::GraphImpl final { public: explicit GraphImpl(const std::string& model_path, const Device& device = Device("cpu")); GraphImpl(const GraphImpl& graph) = delete; GraphImpl(GraphImpl&& graph) = default; ~GraphImpl(); GraphImpl& operator=(const GraphImpl& graph) = delete; GraphImpl& operator=(GraphImpl&& graph) = default; InputOutputInfos GetInputInfos(); InputOutputInfos GetOutputInfos(); std::vector Forward(const std::vector& inputs); void set_batch_size(int batch_size) { batch_size_ = batch_size; } of::Maybe RegisterJobPass( const std::function& pass_fn); private: of::Maybe CollectInputOutputInfos(); of::Maybe Compile(const std::vector& inputs); of::Maybe> Run(const std::vector& inputs) const; of::Maybe AddOp(of::OperatorConf op_conf); of::Maybe BuildGraph(); of::Maybe LoadCheckpoint(); of::Maybe RegisterTensors(const std::vector& inputs); of::Maybe ApplyJobPasses(const of::Job& job); std::shared_ptr graph_ = nullptr; std::string model_path_; bool is_compiled_ = false; int batch_size_ = 0; Device device_; of::Job job_; InputOutputInfos input_infos_; InputOutputInfos output_infos_; of::HashMap> output_name_to_tensor_; of::HashMap> variable_op_name_to_tensor_; std::shared_ptr output_tensor_tuple_; std::shared_ptr parameter_tensor_tuple_; std::vector> registered_job_passes_; }; Graph::Graph(const std::string& model_path, const Device& device) : graph_(std::make_unique(model_path, device)) {} Graph::~Graph() = default; Graph::Graph(Graph&& graph) noexcept : graph_(std::move(graph.graph_)) {} Graph& Graph::operator=(Graph&& graph) noexcept { if (&graph == this) { return *this; } graph_ = std::move(graph.graph_); return *this; } InputOutputInfos Graph::GetInputInfos() { return graph_->GetInputInfos(); } InputOutputInfos Graph::GetOutputInfos() { return graph_->GetOutputInfos(); } void Graph::RegisterJobPass(const std::function& pass_fn) { CHECK_JUST(graph_->RegisterJobPass(pass_fn)); } IValue Graph::Forward(const IValue& inputs) { std::vector input_tensors; if (inputs.IsNone()) { // do nothing } else if (inputs.IsTensor()) { input_tensors.emplace_back(inputs.ToTensor()); } else if (inputs.IsTensorVector()) { input_tensors = inputs.ToTensorVector(); } else { LOG(WARNING) << "Graph currently only support types: Tensor/vector(Tensor)/None"; } std::vector output_tensors = 
graph_->Forward(input_tensors); if (output_tensors.empty()) { return IValue{}; } else if (output_tensors.size() == 1) { return IValue(output_tensors.at(0)); } else { return IValue(output_tensors); } } void Graph::set_batch_size(int batch_size) { graph_->set_batch_size(batch_size); } Graph Graph::Load(const std::string& model_path, const Device& device) { #ifdef __linux__ LoadOneEmbedding(model_path, device); #endif // __linux__ Graph graph(model_path, device); return graph; } Graph::GraphImpl::GraphImpl(const std::string& model_path, const Device& device) : model_path_(model_path), device_(device) { CHECK_JUST(of::LoadJobFromIR(&job_, model_path + "/model.mlir")); CollectInputOutputInfos(); if (of::ParseBooleanFromEnv("ONEFLOW_SERVING_DEBUG", false)) { LOG(ERROR) << job_.DebugString(); } job_.mutable_job_conf()->mutable_predict_conf(); job_.mutable_job_conf()->set_job_name(job_.mutable_job_conf()->job_name() + of::NewUniqueId()); } InputOutputInfos Graph::GraphImpl::GetInputInfos() { return input_infos_; } InputOutputInfos Graph::GraphImpl::GetOutputInfos() { return output_infos_; } of::Maybe Graph::GraphImpl::CollectInputOutputInfos() { const of::OpGraph op_graph(job_); size_t input_order = 0; size_t output_order = 0; op_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe { const of::OperatorConf& op_conf = node->op().op_conf(); if (op_conf.has_input_conf()) { of::InterfaceBlobConf blob_conf = op_conf.input_conf().blob_conf(); input_infos_[op_conf.name()] = InputOutputAttribute(static_cast(blob_conf.data_type()), OfShapeToOfApiShape(of::Shape(blob_conf.shape())), input_order); input_order += 1; } else if (op_conf.has_output_conf()) { of::InterfaceBlobConf blob_conf = op_conf.output_conf().blob_conf(); output_infos_[op_conf.name()] = InputOutputAttribute(static_cast(blob_conf.data_type()), OfShapeToOfApiShape(of::Shape(blob_conf.shape())), output_order); output_order += 1; } return of::Maybe::Ok(); }); return of::Maybe::Ok(); } of::Maybe Graph::GraphImpl::RegisterJobPass( const std::function& pass_fn) { if (is_compiled_) { return of::Error::RuntimeError() << "job pass should be registered before compile and forward"; } registered_job_passes_.emplace_back(pass_fn); return of::Maybe::Ok(); } of::Maybe Graph::GraphImpl::ApplyJobPasses(const of::Job& job) { auto current_job = std::make_shared(job); for (const auto& pass_fn : registered_job_passes_) { std::string new_serialized_original_job = pass_fn(current_job->SerializeAsString()); of::Job new_job; if (!new_job.ParseFromString(new_serialized_original_job)) { return of::Error::RuntimeError() << "invalid serialized job after pass applied"; } current_job->Swap(&new_job); } return current_job; } std::vector Graph::GraphImpl::Forward(const std::vector& inputs) { if (!is_compiled_) { static std::mutex mtx; std::lock_guard lock(mtx); Compile(inputs).GetOrThrow(); is_compiled_ = true; } return Run(inputs).GetOrThrow(); } of::Maybe Graph::GraphImpl::Compile(const std::vector& inputs) { JUST(BuildGraph()); JUST(RegisterTensors(inputs)); JUST(graph_->CompileAndInitRuntime()); return of::Maybe::Ok(); } of::Maybe> Graph::GraphImpl::Run(const std::vector& inputs) const { const auto input_tensor_tuple = std::make_shared(); for (const auto& tensor : inputs) { input_tensor_tuple->emplace_back(tensor.tensor_); } JUST(of::RunLazyNNGraph(*input_tensor_tuple, *output_tensor_tuple_, graph_)); JUST(of::SoftSyncNNGraphBuffers(*output_tensor_tuple_, graph_)); std::vector outputs; for (const auto& tensor : *output_tensor_tuple_) { 
outputs.emplace_back(Tensor(tensor)); } return outputs; } of::Maybe Graph::GraphImpl::AddOp(of::OperatorConf op_conf) { { const std::shared_ptr scope = JUST(of::GetCurrentScope()); op_conf.set_scope_symbol_id(scope->symbol_id().value_or(0)); } op_conf.set_device_tag(GetDeviceTag(device_)); if (batch_size_ > 0 && op_conf.has_input_conf()) { op_conf.mutable_input_conf()->mutable_blob_conf()->mutable_shape()->mutable_dim()->Set( 0, batch_size_); } auto* ctx = JUST(of::GetCurInferCtx()); JUST(ctx->AddAndInferGlobalOp(op_conf)); return of::Maybe::Ok(); } of::Maybe Graph::GraphImpl::BuildGraph() { CompileScope build_graph_scope(job_.job_conf(), *device_.device_->shared_from_symbol()); { const of::OpGraph op_graph(job_); op_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe { const of::OperatorConf& op_conf = node->op().op_conf(); JUST(AddOp(op_conf)); if (op_conf.has_variable_conf()) { const of::LazyMode::Guard lazy_mode_disabled_guard{false}; const of::VariableOpConf& variable_conf = op_conf.variable_conf(); variable_op_name_to_tensor_[op_conf.name()] = JUST(of::one::functional::Empty( of::Shape(variable_conf.shape()), JUST(of::DType::Get(static_cast(variable_conf.data_type()))), *device_.device_, /*requires_grad=*/false, /*pin_memory=*/false)); } return of::Maybe::Ok(); }); } JUST(LoadCheckpoint()); JUST(of::CurJobBuildAndInferCtx_Complete()); std::shared_ptr complete_job = JUST(of::GetCurrentJob()); int64_t job_id = JUST(of::JobBuildAndInferCtx_GetCurrentJobId()); CHECK(of::Singleton::Get() != nullptr); // apply custom job passes complete_job = JUST(ApplyJobPasses(*complete_job)); graph_ = std::make_shared(job_.job_conf().job_name(), *complete_job, job_id, of::Singleton::Get()->GetSessionCtx()); { const of::OpGraph complete_graph(*complete_job); complete_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe { const of::LazyMode::Guard lazy_mode_disabled_guard{false}; const of::OperatorConf& op_conf = node->op().op_conf(); if (op_conf.has_output_conf()) { of::InterfaceBlobConf blob_conf = op_conf.output_conf().blob_conf(); if (batch_size_ > 0) { const std::string input_lbi_str = op_conf.output_conf().in(); const of::LogicalBlobId input_lbi = of::GenLogicalBlobId(input_lbi_str); int64_t batch_size = node->LogicalBlobDesc4Lbi(input_lbi).shape().At(0); blob_conf.mutable_shape()->set_dim(0, batch_size); } output_name_to_tensor_[op_conf.name()] = JUST(of::one::functional::Empty( of::Shape(blob_conf.shape()), JUST(of::DType::Get(static_cast(blob_conf.data_type()))), *device_.device_, /*requires_grad=*/false, /*pin_memory=*/false)); } return of::Maybe::Ok(); }); } return of::Maybe::Ok(); } of::Maybe Graph::GraphImpl::LoadCheckpoint() { for (const auto& variable_op_name_and_tensor : variable_op_name_to_tensor_) { const auto& variable_op_name = variable_op_name_and_tensor.first; const auto& variable_tensor = variable_op_name_and_tensor.second; const std::string variable_filename = model_path_ + "/" + variable_op_name + "/out"; const std::string buffer = [&]() { std::ifstream variable_file(variable_filename, std::ios::binary); CHECK(variable_file.is_open()); std::stringstream ss; ss << variable_file.rdbuf(); return ss.str(); }(); const auto& callback = [&](of::ep::Stream* stream, const std::shared_ptr& eager_blob_object) { of::AutoMemcpy(stream, eager_blob_object->mut_dptr(), buffer.data(), variable_tensor->shape()->elem_cnt() * of::GetSizeOfDataType(variable_tensor->dtype()->data_type()), eager_blob_object->mem_case(), of::memory::MakeHostMemCase()); }; 
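    // Each variable op's payload is read back as raw bytes from
    // "<model_path>/<variable_op_name>/out" (see variable_filename above) and
    // copied into the eager tensor allocated in BuildGraph(); AutoMemcpy
    // performs the host-to-device transfer when the target memory is not on CPU.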
JUST(of::one::SyncAccessTensorWithTimeOut(variable_tensor, callback, "mut")); } const auto& pair = Unzip(variable_op_name_to_tensor_); JUST(of::FillVariableTensorMgr(pair.first, pair.second)); return of::Maybe::Ok(); } of::Maybe Graph::GraphImpl::RegisterTensors(const std::vector& inputs) { { std::vector input_op_names(inputs.size()); std::vector> input_tensors(inputs.size()); for (const auto& input_info : input_infos_) { size_t index = input_info.second.input_output_index_; input_op_names[index] = input_info.first; input_tensors[index] = inputs.at(index).tensor_; } JUST(graph_->RegisterInputOpNamesAndTensors(input_op_names, input_tensors)); } { const auto& pair = Unzip(output_name_to_tensor_); const std::vector& output_op_names = pair.first; const std::vector>& output_tensors = pair.second; JUST(graph_->RegisterOutputOpNamesAndTensors(output_op_names, output_tensors)); output_tensor_tuple_ = ConvertToTensorTuple(output_tensors); } { const auto& t = of::DumpVariableTensorMgr(); const std::vector& variable_op_names = std::get<0>(t); const std::vector>& variable_tensors = std::get<1>(t); JUST(graph_->RegisterVariableOpNamesAndTensors(variable_op_names, variable_tensors)); parameter_tensor_tuple_ = ConvertToTensorTuple(variable_tensors); } return of::Maybe::Ok(); } Graph::GraphImpl::~GraphImpl() { of::vm::ClusterSync().GetOrThrow(); } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/framework/graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
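The job-pass hook implemented above receives the serialized oneflow Job proto and must return a valid serialized Job; passes have to be registered before the first Forward, otherwise RegisterJobPass returns a runtime error. A hedged sketch (the function name is illustrative):

#include <string>
#include "oneflow/api/cpp/framework/graph.h"

void RegisterNoOpPassSketch(oneflow_api::Graph& graph) {
  graph.RegisterJobPass([](const std::string& serialized_job) -> std::string {
    // A real pass would parse the proto, rewrite it, and re-serialize it;
    // returning the input unchanged is a no-op pass.
    return serialized_job;
  });
}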
*/ #ifndef ONEFLOW_API_CPP_GRAPH_H_ #define ONEFLOW_API_CPP_GRAPH_H_ #include "dtype.h" #include "shape.h" #include "device.h" #include "ivalue.h" #include "tensor.h" #include #include #include #include namespace oneflow { class NNGraph; } // namespace oneflow namespace oneflow_api { struct InputOutputAttribute { InputOutputAttribute(DType datatype, const Shape& input_output_shape, size_t input_output_index) : datatype_(datatype), input_output_shape_(input_output_shape), input_output_index_(input_output_index) {} InputOutputAttribute() : InputOutputAttribute(DType::kInvalidDataType, Shape(), 0) {} DType datatype_; Shape input_output_shape_; size_t input_output_index_; }; using InputOutputInfos = std::unordered_map; class Graph { public: explicit Graph(const std::string& model_path, const Device& device = Device("cpu")); ~Graph(); Graph(const Graph& graph) = delete; Graph(Graph&& graph) noexcept; Graph& operator=(const Graph& graph) = delete; Graph& operator=(Graph&& graph) noexcept; InputOutputInfos GetInputInfos(); InputOutputInfos GetOutputInfos(); IValue Forward(const IValue& inputs); void set_batch_size(int batch_size); void RegisterJobPass(const std::function& pass_fn); static Graph Load(const std::string& model_path, const Device& device = Device("cpu")); private: class GraphImpl; std::unique_ptr graph_; }; } // namespace oneflow_api #endif // ONEFLOW_API_CPP_GRAPH_H_ ================================================ FILE: oneflow/api/cpp/framework/ivalue.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/cpp/framework/ivalue.h" #include namespace oneflow_api { namespace of = oneflow; std::ostream& operator<<(std::ostream& os, const IValue::Tag& tag) { os << static_cast(tag); return os; } int64_t IValue::ToInt() const { CHECK_EQ(tag_, Tag::kInt) << "Current value is not an int."; return payload_.i.v_int; } double IValue::ToDouble() const { CHECK_EQ(tag_, Tag::kDouble) << "Current value is not a double."; return payload_.i.v_double; } bool IValue::ToBool() const { CHECK_EQ(tag_, Tag::kBool) << "Current value is not a bool."; return payload_.i.v_bool; } const Tensor& IValue::ToTensor() const { CHECK_EQ(tag_, Tag::kTensor) << "Current value is not a tensor."; return payload_.v_tensor; } const std::vector& IValue::ToTensorVector() const { CHECK_EQ(tag_, Tag::kTensorVector) << "Current value is not a vector of tensor."; return payload_.v_tensor_vector; } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/framework/ivalue.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
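Putting the Graph API above together, a hedged end-to-end sketch. The model path is hypothetical; it must point at a directory containing model.mlir plus one <variable>/out file per parameter, as in the test fixtures later in this dump.

#include <vector>
#include "oneflow/api/cpp/api.h"

void GraphForwardSketch() {
  oneflow_api::initialize();  // set up the OneFlow environment
  {
    oneflow_api::Device device("cpu");
    oneflow_api::Graph graph =
        oneflow_api::Graph::Load("/path/to/exported_model", device);  // hypothetical path

    // Optionally fix the batch dimension before the first Forward.
    graph.set_batch_size(1);

    std::vector<float> data(3, 1.0f);
    std::vector<oneflow_api::Tensor> inputs;
    inputs.emplace_back(oneflow_api::Tensor::from_buffer(
        data.data(), oneflow_api::Shape({1, 3}), device, oneflow_api::DType::kFloat));

    oneflow_api::IValue result = graph.Forward(inputs);
    if (result.IsTensor()) {
      std::vector<float> out(result.ToTensor().shape().Count(0));
      result.ToTensor().copy_to(out.data());  // synchronous copy back to host
    }
  }
  oneflow_api::release();  // tear down before process exit
}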
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_ #define ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_ #include #include #include #include "tensor.h" namespace oneflow_api { class IValue { public: IValue() : tag_(IValue::Tag::kNone) {} explicit IValue(int value) : tag_(IValue::Tag::kInt) { payload_.i.v_int = value; } explicit IValue(int64_t value) : tag_(IValue::Tag::kInt) { payload_.i.v_int = value; } explicit IValue(double value) : tag_(IValue::Tag::kDouble) { payload_.i.v_double = value; } explicit IValue(bool value) : tag_(IValue::Tag::kBool) { payload_.i.v_bool = value; } IValue(const Tensor& value) : tag_(IValue::Tag::kTensor) { // NOLINT new (&payload_.v_tensor) Tensor(value); } IValue(Tensor&& value) : tag_(IValue::Tag::kTensor) { // NOLINT new (&payload_.v_tensor) Tensor(std::move(value)); } IValue(const std::vector& value) : tag_(IValue::Tag::kTensorVector) { // NOLINT new (&payload_.v_tensor_vector) std::vector(value); } IValue(std::vector&& value) : tag_(IValue::Tag::kTensorVector) { // NOLINT new (&payload_.v_tensor_vector) std::vector(std::move(value)); } IValue(const IValue& value) : tag_(value.tag_) { if (IsTensor()) { new (&payload_.v_tensor) Tensor(value.payload_.v_tensor); } else if (IsTensorVector()) { new (&payload_.v_tensor_vector) std::vector(value.payload_.v_tensor_vector); } else { payload_.i = value.payload_.i; } } IValue(IValue&& value) noexcept : tag_(value.tag_) { MoveFrom(std::move(value)); } IValue& operator=(const IValue& value) { if (&value == this) { return *this; } this->tag_ = value.tag_; *this = IValue(value); return *this; } IValue& operator=(IValue&& value) noexcept { if (&value == this) { return *this; } Destory(); this->tag_ = value.tag_; MoveFrom(std::move(value)); return *this; } ~IValue() { Destory(); } bool IsNone() const { return tag_ == Tag::kNone; } bool IsInt() const { return tag_ == Tag::kInt; } bool IsDouble() const { return tag_ == Tag::kDouble; } bool IsBool() const { return tag_ == Tag::kBool; } bool IsTensor() const { return tag_ == Tag::kTensor; } bool IsTensorVector() const { return tag_ == Tag::kTensorVector; } int64_t ToInt() const; double ToDouble() const; bool ToBool() const; const Tensor& ToTensor() const; const std::vector& ToTensorVector() const; private: enum class Tag { kNone = 0, kInt = 1, kDouble = 2, kBool = 3, kTensor = 4, kTensorVector = 5 }; friend std::ostream& operator<<(std::ostream&, const Tag&); union Payload { // NOLINT union InternalPayload { InternalPayload() : v_int(0) {} int64_t v_int; double v_double; bool v_bool; } i; Tensor v_tensor; std::vector v_tensor_vector; Payload() : i() {} ~Payload() {} }; Payload payload_; Tag tag_; inline void Destory() { if (IsTensor()) { payload_.v_tensor.~Tensor(); } if (IsTensorVector()) { payload_.v_tensor_vector.~vector(); } } inline void MoveFrom(IValue&& value) { if (IsTensor()) { new (&payload_.v_tensor) Tensor(std::move(value.payload_.v_tensor)); } else if (IsTensorVector()) { new (&payload_.v_tensor_vector) std::vector(std::move(value.payload_.v_tensor_vector)); } else { payload_.i = value.payload_.i; } value.ClearToNone(); } inline void ClearToNone() { Destory(); payload_.i.v_int = 0; tag_ = 
Tag::kNone; } }; } // namespace oneflow_api #endif // ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_ ================================================ FILE: oneflow/api/cpp/framework/shape.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/cpp/framework/shape.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape_vec.h" namespace oneflow_api { namespace of = oneflow; namespace { of::DimVector ToOneflowDimVcetor(const std::vector& dim_vec) { return of::DimVector(dim_vec.begin(), dim_vec.end()); } } // namespace Shape::Shape() : shape_(std::make_shared(of::Shape({0}))) {} Shape::Shape(const std::vector& dim_vec) : shape_(std::make_shared(ToOneflowDimVcetor(dim_vec))) {} Shape::Shape(const std::initializer_list& dim_vec) : shape_(std::make_shared(dim_vec)) {} Shape& Shape::operator=(const Shape& shape) { this->shape_.reset(); this->shape_ = shape.shape_; return *this; } bool Shape::operator==(const Shape& rhs) const { return *shape_ == *rhs.shape_; } bool Shape::operator!=(const Shape& rhs) const { return !(*this == rhs); } int64_t Shape::elem_cnt() const { return shape_->elem_cnt(); } int64_t Shape::At(int64_t index) const { return shape_->At(index); } void Shape::Set(int64_t index, int64_t val) { shape_->Set(index, val); } int64_t Shape::NumAxes() const { return shape_->NumAxes(); } int64_t Shape::Count(int64_t begin_axis, int64_t end_axis) const { return shape_->Count(begin_axis, end_axis); } int64_t Shape::Count(int64_t begin_axis) const { return shape_->Count(begin_axis); } std::ostream& operator<<(std::ostream& os, const Shape& shape) { os << shape.shape_->DebugStr(); return os; } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/framework/shape.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
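A small sketch of how IValue wraps the supported payload kinds; it only uses the constructors and accessors declared above and assumes an initialized environment for the tensor cases.

#include <vector>
#include "oneflow/api/cpp/framework/ivalue.h"
#include "oneflow/api/cpp/framework/tensor.h"

void IValueSketch() {
  oneflow_api::IValue none;  // default-constructed -> kNone
  oneflow_api::IValue i(static_cast<int64_t>(42));
  oneflow_api::IValue d(3.5);
  oneflow_api::IValue b(true);

  // The accessors CHECK the tag, so test it first when the kind is not known.
  if (i.IsInt()) { int64_t v = i.ToInt(); (void)v; }

  // Tensors and vectors of tensors are stored in-place via placement new.
  oneflow_api::Tensor t;  // defaults: shape {0}, cpu device, kFloat
  oneflow_api::IValue iv_tensor(t);
  oneflow_api::IValue iv_vec(std::vector<oneflow_api::Tensor>{t, t});
  (void)none; (void)d; (void)b; (void)iv_tensor; (void)iv_vec;
}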
*/ #ifndef ONEFLOW_API_CPP_FRAMEWORK_SHAPE_H_ #define ONEFLOW_API_CPP_FRAMEWORK_SHAPE_H_ #include #include namespace oneflow { class Shape; } namespace oneflow_api { class Shape final { friend class Tensor; public: Shape(); explicit Shape(const std::vector& dim_vec); Shape(const std::initializer_list& dim_vec); ~Shape() = default; Shape& operator=(const Shape& shape); [[nodiscard]] bool operator==(const Shape& rhs) const; [[nodiscard]] bool operator!=(const Shape& rhs) const; void Set(int64_t index, int64_t val); [[nodiscard]] int64_t elem_cnt() const; [[nodiscard]] int64_t At(int64_t index) const; [[nodiscard]] int64_t NumAxes() const; [[nodiscard]] int64_t Count(int64_t begin_axis, int64_t end_axis) const; [[nodiscard]] int64_t Count(int64_t begin_axis) const; private: std::shared_ptr shape_ = nullptr; friend std::ostream& operator<<(std::ostream&, const Shape&); }; } // namespace oneflow_api #endif // ONEFLOW_API_CPP_FRAMEWORK_SHAPE_H_ ================================================ FILE: oneflow/api/cpp/framework/tensor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/cpp/framework/tensor.h" #include "oneflow/api/cpp/framework/device.h" #include "oneflow/api/cpp/framework/dtype.h" #include "oneflow/api/cpp/framework/shape.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/vm/virtual_machine.h" namespace oneflow_api { namespace of = oneflow; namespace functional = of::one::functional; Tensor::Tensor(const Shape& shape, const Device& device, const DType& dtype) { of::LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); tensor_ = functional::Empty(*shape.shape_, of::DType::Get(static_cast(dtype)).GetOrThrow(), *device.device_, /*requires_grad=*/false, /*pin_memory=*/false) .GetPtrOrThrow(); } Tensor::Tensor(const std::shared_ptr& tensor) : tensor_(tensor) {} Tensor::Tensor(const Tensor& tensor) : tensor_(tensor.tensor_) {} Tensor::Tensor(Tensor&& tensor) noexcept : tensor_(std::move(tensor.tensor_)) {} Tensor& Tensor::operator=(const Tensor& tensor) { if (&tensor == this) { return *this; } tensor_ = tensor.tensor_; return *this; } Tensor& Tensor::operator=(Tensor&& tensor) noexcept { if (&tensor == this) { return *this; } tensor_ = std::move(tensor.tensor_); return *this; } Shape Tensor::shape() const { const auto shape_ = tensor_->shape(); return Shape(std::vector(shape_->dim_vec().begin(), shape_->dim_vec().end())); } Device Tensor::device() const { const auto device_ = tensor_->device().GetOrThrow(); return Device(device_->type(), device_->device_id()); } DType Tensor::dtype() const { return static_cast(tensor_->dtype()->data_type()); } void Tensor::zeros_() { std::shared_ptr local_tensor 
= tensor_->AsLocalTensor().GetPtrOrThrow(); of::PhysicalRun([&](of::InstructionsBuilder* builder) -> of::Maybe { JUST(builder->AccessBlobByCallback( local_tensor, [](of::ep::Stream* stream, const std::shared_ptr& eager_blob_object) { of::AutoMemset(stream, eager_blob_object->mut_dptr(), 0, eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case()); }, "mut")); return of::Maybe::Ok(); }).GetOrThrow(); } Tensor Tensor::from_buffer(const void* buffer, const Shape& shape, const Device& device, const DType& dtype) { Tensor tensor(shape, device, dtype); std::shared_ptr local_tensor = tensor.tensor_->AsLocalTensor().GetPtrOrThrow(); of::PhysicalRun([&](of::InstructionsBuilder* builder) -> of::Maybe { return builder->AccessBlobByCallback( local_tensor, [buffer, shape, dtype](of::ep::Stream* stream, const std::shared_ptr& eager_blob_object) { of::AutoMemcpy(stream, eager_blob_object->mut_dptr(), buffer, shape.Count(0) * GetDTypeSize(dtype), eager_blob_object->mem_case(), of::memory::MakeHostMemCase()); }, "mut"); }).GetOrThrow(); return tensor; } template void Tensor::copy_to(T* buffer) const { std::shared_ptr local_tensor = tensor_->AsLocalTensor().GetPtrOrThrow(); const auto shape = this->shape(); const auto& Callback = [buffer, shape]( of::ep::Stream* stream, const std::shared_ptr& eager_blob_object) { of::AutoMemcpy(stream, buffer, eager_blob_object->mut_dptr(), shape.Count(0) * sizeof(T), of::memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; auto btb = std::make_shared(); CHECK_JUST(of::PhysicalRun([&](of::InstructionsBuilder* builder) -> of::Maybe { return builder->SyncAccessBlobByCallback(local_tensor, btb, Callback, "const"); })); TRY(btb->WaitUntilCntEqualZero(of::VirtualMachine::GetPredicatorNoMoreInstructionsFinished())) .GetOrThrow(); } const std::shared_ptr& Tensor::__internal_tensor() const { return tensor_; } #define REGISTER_TENSOR_COPY_TO(cpp_dtype) \ template void Tensor::copy_to(cpp_dtype * buffer) const; REGISTER_TENSOR_COPY_TO(float) REGISTER_TENSOR_COPY_TO(double) REGISTER_TENSOR_COPY_TO(bool) REGISTER_TENSOR_COPY_TO(int8_t) REGISTER_TENSOR_COPY_TO(int32_t) REGISTER_TENSOR_COPY_TO(int64_t) } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/framework/tensor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
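A hedged sketch of the host-to-tensor copy helpers implemented above. It assumes an initialized environment; copy_to is only instantiated for the element types listed in the REGISTER_TENSOR_COPY_TO macro.

#include <vector>
#include "oneflow/api/cpp/framework/tensor.h"
#include "oneflow/api/cpp/framework/shape.h"
#include "oneflow/api/cpp/framework/device.h"
#include "oneflow/api/cpp/framework/dtype.h"

void TensorRoundTripSketch() {
  const oneflow_api::Shape shape({2, 3});
  const oneflow_api::Device device("cpu");

  std::vector<float> host(shape.Count(0), 1.0f);
  oneflow_api::Tensor t = oneflow_api::Tensor::from_buffer(
      host.data(), shape, device, oneflow_api::DType::kFloat);

  t.zeros_();  // in-place memset of the tensor payload to zero

  std::vector<float> back(shape.Count(0));
  t.copy_to(back.data());  // synchronous copy back to host memory
  // back now holds shape.Count(0) zeros.
}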
*/ #ifndef ONEFLOW_API_CPP_FRAMEWORK_TENSOR_H_ #define ONEFLOW_API_CPP_FRAMEWORK_TENSOR_H_ #include #include "device.h" #include "shape.h" #include "dtype.h" namespace oneflow { namespace one { class Tensor; } } // namespace oneflow namespace oneflow_api { class Tensor final { friend class Graph; public: explicit Tensor(const Shape& shape = Shape(), const Device& device = Device("cpu"), const DType& dtype = DType::kFloat); explicit Tensor(const std::shared_ptr& tensor); Tensor(const Tensor& tensor); Tensor(Tensor&& tensor) noexcept; ~Tensor() = default; Tensor& operator=(const Tensor& tensor); Tensor& operator=(Tensor&& tensor) noexcept; [[nodiscard]] Shape shape() const; [[nodiscard]] Device device() const; [[nodiscard]] DType dtype() const; void zeros_(); // You should never call __internal_tensor() directly. [[nodiscard]] const std::shared_ptr& __internal_tensor() const; template void copy_to(T* buffer) const; [[nodiscard]] static Tensor from_buffer(const void* buffer, const Shape& shape, const Device& device, const DType& dtype); private: std::shared_ptr tensor_ = nullptr; }; } // namespace oneflow_api #endif // ONEFLOW_API_CPP_FRAMEWORK_TENSOR_H_ ================================================ FILE: oneflow/api/cpp/framework.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_FRAMEWORK_H_ #define ONEFLOW_API_CPP_FRAMEWORK_H_ #include "framework/device.h" #include "framework/shape.h" #include "framework/dtype.h" #include "framework/tensor.h" #include "framework/ivalue.h" #include "framework/graph.h" #endif // ONEFLOW_API_CPP_FRAMEWORK_H_ ================================================ FILE: oneflow/api/cpp/nn/functional/activation.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/cpp/nn/functional/activation.h" #include "oneflow/core/functional/functional.h" namespace oneflow_api { namespace nn { namespace of = oneflow; namespace functional = of::one::functional; Tensor relu(const Tensor& tensor) { return Tensor(functional::Relu(tensor.__internal_tensor(), false).GetPtrOrThrow()); } } // namespace nn } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/nn/functional/activation.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
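A short sketch of the functional relu wrapper defined above; it assumes an initialized environment and reuses the Tensor helpers from the preceding files.

#include <vector>
#include "oneflow/api/cpp/api.h"

void ReluSketch() {
  std::vector<float> data = {-1.0f, 0.0f, 2.5f};
  auto x = oneflow_api::Tensor::from_buffer(data.data(), oneflow_api::Shape({3}),
                                            oneflow_api::Device("cpu"),
                                            oneflow_api::DType::kFloat);
  oneflow_api::Tensor y = oneflow_api::nn::relu(x);

  std::vector<float> out(3);
  y.copy_to(out.data());  // expected: {0, 0, 2.5}
}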
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_NN_FUNCTIONAL_ACTIVATION_H_ #define ONEFLOW_API_CPP_NN_FUNCTIONAL_ACTIVATION_H_ #include "../../framework.h" namespace oneflow_api { namespace nn { Tensor relu(const Tensor& tensor); } } // namespace oneflow_api #endif // ONEFLOW_API_CPP_NN_FUNCTIONAL_ACTIVATION_H_ ================================================ FILE: oneflow/api/cpp/nn.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_NN_H_ #define ONEFLOW_API_CPP_NN_H_ #include "nn/functional/activation.h" #endif // ONEFLOW_API_CPP_NN_H_ ================================================ FILE: oneflow/api/cpp/tests/api_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/cpp/tests/api_test.h" #include #include #include #ifdef __linux__ #include // readlink #elif defined(__APPLE__) #include // _NSGetExecutablePath #endif namespace oneflow_api { Shape RandomShape() { thread_local static std::mt19937 rng(std::random_device{}()); std::uniform_int_distribution<> dist_ndim(1, 4), dist_dims(16, 64); std::vector dims(dist_ndim(rng), 0); for (auto& x : dims) { x = dist_dims(rng); } return Shape(dims); } template std::vector RandomData(size_t size) { thread_local static std::mt19937 rng(std::random_device{}()); std::uniform_int_distribution<> dist(-100, 100); std::vector data(size); for (auto& x : data) { x = static_cast(dist(rng)); } return data; } #define REGISTER_RANDOM_DATA(cpp_dtype) template std::vector RandomData(size_t size); REGISTER_RANDOM_DATA(float) REGISTER_RANDOM_DATA(double) REGISTER_RANDOM_DATA(int8_t) REGISTER_RANDOM_DATA(int32_t) REGISTER_RANDOM_DATA(int64_t) std::string GetExeDir() { const size_t path_max_size = 4096; // PATH_MAX = 4096 on linux char result[path_max_size]; const auto get_dir_from_path = [](char result[], size_t count) -> std::string { std::string exe_path(result, (count > 0) ? 
count : 0); // string(path).rfind('/') will never be string::npos on linux or macos. return exe_path.substr(0, exe_path.rfind('/')); }; #ifdef __linux__ ssize_t count = readlink("/proc/self/exe", result, path_max_size); return get_dir_from_path(result, count); #elif defined(__APPLE__) uint32_t count = path_max_size; CHECK_EQ(_NSGetExecutablePath(result, &count), 0) << "Fail to get executable file path."; return get_dir_from_path(result, count); #else #error oneflow_api::GetExeDir() has not been supported on windows. #endif } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/tests/api_test.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_CPP_TESTS_API_TEST_H_ #define ONEFLOW_API_CPP_TESTS_API_TEST_H_ #include "oneflow/api/cpp/api.h" namespace oneflow_api { class EnvScope { // NOLINT public: EnvScope() { initialize(); } ~EnvScope() { release(); } }; Shape RandomShape(); template std::vector RandomData(size_t size); std::string GetExeDir(); } // namespace oneflow_api #endif // !ONEFLOW_API_CPP_TESTS_API_TEST_H_ ================================================ FILE: oneflow/api/cpp/tests/graph_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
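The test helpers above are typically combined as follows; a hedged sketch of a gtest case that uses EnvScope to pair initialize()/release() around the API calls (the test name is illustrative).

#include <vector>
#include <gtest/gtest.h>
#include "oneflow/api/cpp/tests/api_test.h"

TEST(Api, sketch_random_tensor_round_trip) {
  oneflow_api::EnvScope scope;  // RAII: initialize() now, release() at scope exit

  const auto shape = oneflow_api::RandomShape();
  const auto data = oneflow_api::RandomData<float>(shape.Count(0));

  auto tensor = oneflow_api::Tensor::from_buffer(data.data(), shape,
                                                 oneflow_api::Device("cpu"),
                                                 oneflow_api::DType::kFloat);
  std::vector<float> out(shape.Count(0));
  tensor.copy_to(out.data());
  ASSERT_EQ(out, data);
}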
*/ #include #include #include #include #include #include #include #include #include #include "oneflow/api/cpp/framework.h" #include "oneflow/api/cpp/framework/dtype.h" #include "oneflow/api/cpp/framework/shape.h" #include "oneflow/api/cpp/tests/api_test.h" namespace oneflow_api { namespace { inline Graph LoadGraph(const Device& device) { Graph graph = Graph::Load("./oneflow/api/cpp/tests/graph_test_model/affine_with_parameter", device); return graph; } inline void Forward(Graph& graph, const Device& device, int expected_batch_dim = 1) { std::vector data(expected_batch_dim * 3); std::fill(data.begin(), data.end(), 1); std::vector inputs; inputs.emplace_back( Tensor::from_buffer(data.data(), Shape({expected_batch_dim, 3}), device, DType::kFloat)); const auto& value = graph.Forward(inputs); ASSERT_TRUE(value.IsTensor()); Tensor output = value.ToTensor(); Shape shape = output.shape(); ASSERT_EQ(shape.At(0), expected_batch_dim); ASSERT_EQ(shape.At(1), 4); std::vector buf(expected_batch_dim * 4); output.copy_to(buf.data()); for (const float& element : buf) { ASSERT_EQ(element, 4); } } } // namespace TEST(Api, graph_cpu_test) { EnvScope scope; Device device("cpu"); Graph graph = LoadGraph(device); Forward(graph, device, 1); } #ifdef WITH_CUDA TEST(Api, graph_gpu_test) { EnvScope scope; Device device("cuda", 0); Graph graph = LoadGraph(device); Forward(graph, device); } TEST(Api, graph_multi_gpu_test) { EnvScope scope; Device device("cuda", 0); Graph graph = LoadGraph(device); Forward(graph, device); Device device1("cuda", 1); Graph graph1 = LoadGraph(device1); Forward(graph1, device1); } #endif TEST(Api, graph_cpu_batching_test) { EnvScope scope; Device device("cpu"); Graph graph = LoadGraph(device); graph.set_batch_size(10); Forward(graph, device, 10); } #ifdef WITH_CUDA TEST(Api, graph_gpu_batching_test) { EnvScope scope; Device device("cuda", 0); Graph graph = LoadGraph(device); graph.set_batch_size(10); Forward(graph, device, 10); } TEST(Api, graph_multi_device_test) { EnvScope scope; Device device("cuda", 0); Graph graph = LoadGraph(device); Forward(graph, device, 1); Device device1("cuda", 1); Graph graph1 = LoadGraph(device1); Forward(graph1, device1, 1); Device device2("cpu"); Graph graph2 = LoadGraph(device2); Forward(graph2, device2, 1); } TEST(Api, graph_unload_test) { { EnvScope scope; Device device("cuda", 0); Graph graph = LoadGraph(device); Forward(graph, device, 1); { Device device1("cuda", 1); Graph graph1 = LoadGraph(device1); Forward(graph1, device1, 1); } Device device2("cpu"); Graph graph2 = LoadGraph(device2); Forward(graph2, device2, 1); } { EnvScope scope; Device device("cpu"); Graph graph = LoadGraph(device); Forward(graph, device, 1); } } #endif TEST(Api, graph_thread_test) { EnvScope scope; Device device("cpu"); std::vector graphs; for (int i = 0; i < 10; i++) { graphs.emplace_back(LoadGraph(device)); } std::vector threads; for (Graph& graph : graphs) { threads.emplace_back(std::thread(std::bind(Forward, std::move(graph), device, 1))); } for (auto& thread : threads) { thread.join(); } } TEST(Api, graph_input_order_test) { EnvScope scope; Device device("cpu"); Graph graph = Graph::Load("./oneflow/api/cpp/tests/graph_test_model/affine_no_parameter", device); std::vector inputs; std::vector x(3); std::fill(x.begin(), x.end(), 1); inputs.emplace_back(Tensor::from_buffer(x.data(), Shape({1, 3}), device, DType::kFloat)); std::vector a(3 * 2); std::fill(a.begin(), a.end(), 1); inputs.emplace_back(Tensor::from_buffer(a.data(), Shape({3, 2}), device, DType::kFloat)); std::vector 
b(2); std::fill(b.begin(), b.end(), 1); inputs.emplace_back(Tensor::from_buffer(b.data(), Shape({2}), device, DType::kFloat)); const auto& value = graph.Forward(inputs); ASSERT_TRUE(value.IsTensor()); Tensor output = value.ToTensor(); Shape shape = output.shape(); ASSERT_EQ(shape.At(0), 1); ASSERT_EQ(shape.At(1), 2); std::array buf{}; output.copy_to(buf.data()); ASSERT_EQ(buf[0], 4); ASSERT_EQ(buf[1], 4); } TEST(Api, graph_input_output_infos_test) { EnvScope scope; Device device("cpu"); Graph graph = LoadGraph(device); auto input_infos = graph.GetInputInfos(); auto output_infos = graph.GetOutputInfos(); ASSERT_EQ(input_infos.size(), 1); ASSERT_EQ(output_infos.size(), 1); auto it = input_infos.begin(); DType dtype = it->second.datatype_; Shape shape = it->second.input_output_shape_; size_t order = it->second.input_output_index_; ASSERT_EQ(dtype, DType::kFloat); ASSERT_EQ(shape.NumAxes(), 2); ASSERT_EQ(shape.At(0), 1); ASSERT_EQ(shape.At(1), 3); ASSERT_EQ(order, 0); it = output_infos.begin(); dtype = it->second.datatype_; shape = it->second.input_output_shape_; order = it->second.input_output_index_; ASSERT_EQ(dtype, DType::kFloat); ASSERT_EQ(shape.NumAxes(), 2); ASSERT_EQ(shape.At(0), 1); ASSERT_EQ(shape.At(1), 4); ASSERT_EQ(order, 0); } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/tests/graph_test_model/affine_no_parameter/model.mlir ================================================ module { oneflow.job @MyGraph_1(%arg0: tensor<1x3xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x2xf32> { %output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_1-input_0", output_lbns = ["_MyGraph_1-input_0/out"], scope_symbol_id = 4611686018427527167 : i64, shape = [1 : si64, 3 : si64]} : (tensor<1x3xf32>) -> tensor<1x3xf32> %output_0 = "oneflow.input"(%arg1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_1-input_1", output_lbns = ["_MyGraph_1-input_1/out"], scope_symbol_id = 4611686018427527167 : i64, shape = [3 : si64, 2 : si64]} : (tensor<3x2xf32>) -> tensor<3x2xf32> %output_1 = "oneflow.input"(%arg2) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_1-input_2", output_lbns = ["_MyGraph_1-input_2/out"], scope_symbol_id = 4611686018427527167 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32> %0 = "oneflow.matmul"(%output, %output_0) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "model-matmul_0", output_lbns = ["model-matmul_0/out_0"], scope_symbol_id = 4611686018427535359 : i64, transpose_a = false, transpose_b = false} : (tensor<1x3xf32>, tensor<3x2xf32>) -> tensor<1x2xf32> %1 = "oneflow.broadcast_add"(%0, %output_1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "model-broadcast_add_1", output_lbns = ["model-broadcast_add_1/z_0"], scope_symbol_id = 4611686018427535359 : i64} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32> %output_2 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_1-output_0", output_lbns = ["_MyGraph_1-output_0/out"], scope_symbol_id = 4611686018427527167 : i64, shape = [1 : si64, 2 : si64]} : (tensor<1x2xf32>) -> 
tensor<1x2xf32> oneflow.return %output_2 : tensor<1x2xf32> } } ================================================ FILE: oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/meta ================================================ shape { dim: 3 dim: 4 } data_type: kFloat ================================================ FILE: oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/meta ================================================ shape { dim: 4 } data_type: kFloat ================================================ FILE: oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.mlir ================================================ module { oneflow.job @MyGraph_0(%arg0: tensor<1x3xf32>) -> tensor<1x4xf32> { %output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_0-input_0", output_lbns = ["_MyGraph_0-input_0/out"], scope_symbol_id = 4611686018427469823 : i64, shape = [1 : si64, 3 : si64]} : (tensor<1x3xf32>) -> tensor<1x3xf32> %output_0 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], parallel = #sbp.parallel<[] -> [#sbp.B]>, op_name = "model.a", output_lbns = ["model.a/out"], scope_symbol_id = 4611686018427482111 : i64, shape = [3 : si64, 4 : si64]} : () -> tensor<3x4xf32> %output_1 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], parallel = #sbp.parallel<[] -> [#sbp.B]>, op_name = "model.b", output_lbns = ["model.b/out"], scope_symbol_id = 4611686018427494399 : i64, shape = [4 : si64]} : () -> tensor<4xf32> %0 = "oneflow.matmul"(%output, %output_0) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "model-matmul_0", output_lbns = ["model-matmul_0/out_0"], scope_symbol_id = 4611686018427486207 : i64, transpose_a = false, transpose_b = false} : (tensor<1x3xf32>, tensor<3x4xf32>) -> tensor<1x4xf32> %1 = "oneflow.broadcast_add"(%0, %output_1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "model-broadcast_add_1", output_lbns = ["model-broadcast_add_1/z_0"], scope_symbol_id = 4611686018427486207 : i64} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> %output_2 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_0-output_0", output_lbns = ["_MyGraph_0-output_0/out"], scope_symbol_id = 4611686018427469823 : i64, shape = [1 : si64, 4 : si64]} : (tensor<1x4xf32>) -> tensor<1x4xf32> oneflow.return %output_2 : tensor<1x4xf32> } } ================================================ FILE: oneflow/api/cpp/tests/ivalue_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
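For reference, the exported-model layout that Graph::Load consumes, reconstructed from LoadJobFromIR and LoadCheckpoint above and from the affine_with_parameter fixture (the raw out payload files are binary and therefore not reproduced in this text dump):

  <model_path>/
    model.mlir                    loaded via of::LoadJobFromIR
    <variable_op>/out             raw tensor bytes per variable op, e.g. model.a/out
    <variable_op>/meta            shape and data_type description, e.g. model.a/meta
                                  (not read by the C++ loader shown here)
    one_embedding_options.json    optional; consulted only on Linux by LoadOneEmbedding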
*/ #include #include #include "oneflow/api/cpp/framework/dtype.h" #include "oneflow/api/cpp/framework/ivalue.h" #include "oneflow/api/cpp/tests/api_test.h" namespace oneflow_api { namespace { std::mt19937 rng(std::random_device{}()); } TEST(Api, ivalue) { std::uniform_real_distribution<> dist(-100, 100); std::uniform_int_distribution<> dist_bool(0, 1); const auto v_int = static_cast(dist(rng)); ASSERT_EQ(IValue(v_int).ToInt(), v_int); const auto v_int64 = static_cast(dist(rng)); ASSERT_EQ(IValue(v_int64).ToInt(), v_int64); const auto v_float = static_cast(dist(rng)); ASSERT_EQ(IValue(v_float).ToDouble(), v_float); const auto v_double = static_cast(dist(rng)); ASSERT_EQ(IValue(v_double).ToDouble(), v_double); const auto v_bool = static_cast(dist_bool(rng)); ASSERT_EQ(IValue(v_bool).ToBool(), v_bool); } TEST(Api, ivalue_tensor) { EnvScope scope; const auto device = Device("cpu"); const auto shape = RandomShape(); const auto dtype = DType::kDouble; const IValue i_tensor(Tensor(shape, device, dtype)); const auto& tensor = i_tensor.ToTensor(); ASSERT_EQ(tensor.shape(), shape); ASSERT_EQ(tensor.device(), device); ASSERT_EQ(tensor.dtype(), dtype); } TEST(Api, ivalue_tensor_vector) { EnvScope scope; const auto device = Device("cpu"); const std::vector v_tensor_vector{Tensor(RandomShape(), device, DType::kDouble), Tensor(RandomShape(), device, DType::kFloat)}; const auto i_tensor = IValue(v_tensor_vector); const auto& tensor_vector = i_tensor.ToTensorVector(); ASSERT_EQ(v_tensor_vector.size(), tensor_vector.size()); for (size_t i = 0; i < tensor_vector.size(); ++i) { ASSERT_EQ(v_tensor_vector[i].device(), tensor_vector[i].device()); ASSERT_EQ(v_tensor_vector[i].shape(), tensor_vector[i].shape()); ASSERT_EQ(v_tensor_vector[i].dtype(), tensor_vector[i].dtype()); } } TEST(Api, ivalue_copy) { EnvScope scope; const auto device = Device("cpu"); const auto shape = RandomShape(); const auto dtype = DType::kDouble; const IValue i_tensor(Tensor(shape, device, dtype)); const auto i_tensor_a = i_tensor; // NOLINT ASSERT_EQ(i_tensor_a.ToTensor().shape(), shape); ASSERT_EQ(i_tensor_a.ToTensor().device(), device); ASSERT_EQ(i_tensor_a.ToTensor().dtype(), dtype); IValue i_tensor_b; i_tensor_b = i_tensor; ASSERT_EQ(i_tensor_b.ToTensor().shape(), shape); ASSERT_EQ(i_tensor_b.ToTensor().device(), device); ASSERT_EQ(i_tensor_b.ToTensor().dtype(), dtype); } TEST(Api, ivalue_move) { EnvScope scope; const auto device = Device("cpu"); const auto shape = RandomShape(); const auto dtype = DType::kDouble; IValue i_tensor_a = IValue(Tensor(shape, device, dtype)); IValue i_tensor_b = IValue(Tensor(shape, device, dtype)); IValue i_tensor_c = std::move(i_tensor_a); ASSERT_EQ(i_tensor_c.ToTensor().shape(), shape); ASSERT_EQ(i_tensor_c.ToTensor().device(), device); ASSERT_EQ(i_tensor_c.ToTensor().dtype(), dtype); IValue i_tensor_d; i_tensor_d = std::move(i_tensor_b); ASSERT_EQ(i_tensor_d.ToTensor().shape(), shape); ASSERT_EQ(i_tensor_d.ToTensor().device(), device); ASSERT_EQ(i_tensor_d.ToTensor().dtype(), dtype); ASSERT_EQ(i_tensor_a.IsNone(), true); ASSERT_EQ(i_tensor_b.IsNone(), true); } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/tests/nn_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include "oneflow/api/cpp/tests/api_test.h" namespace oneflow_api { namespace { std::mt19937 rng(std::random_device{}()); template std::vector Relu(const std::vector& data) { std::vector result(data.begin(), data.end()); T zero = static_cast(0); for (auto& x : result) { if (x < zero) { x = zero; } } return result; } } // namespace void TestRelu() { const auto shape = RandomShape(); const auto data = RandomData(shape.Count(0)); const auto target_data = Relu(data); std::vector result(shape.Count(0)); auto tensor = Tensor::from_buffer(data.data(), shape, Device("cpu"), DType::kFloat); auto result_tensor = nn::relu(tensor); result_tensor.copy_to(result.data()); ASSERT_EQ(result, target_data); } TEST(Api, nn_relu) { EnvScope scope; TestRelu(); } TEST(Api, nn_relu_multithreading) { EnvScope scope; std::vector threads; std::uniform_int_distribution<> dist(8, 32); int n_threads = dist(rng); for (int i = 0; i < n_threads; ++i) { threads.emplace_back(std::thread(TestRelu)); } for (auto& x : threads) { x.join(); } } } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/tests/one_embedding_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/cpp/tests/api_test.h" namespace oneflow_api { #ifdef WITH_CUDA TEST(Api, embedding_test) { EnvScope scope; Device device("cuda"); Graph graph = Graph::Load("/path/to/embedding", device); int64_t batch_size = 10000; int64_t num_features = 39; std::vector data(batch_size * num_features); std::fill(data.begin(), data.end(), 1); std::vector inputs; inputs.emplace_back( Tensor::from_buffer(data.data(), Shape({batch_size, num_features}), device, DType::kInt64)); const auto& value = graph.Forward(inputs); ASSERT_TRUE(value.IsTensor()); Tensor output = value.ToTensor(); Shape shape = output.shape(); ASSERT_EQ(shape.At(0), batch_size); ASSERT_EQ(shape.At(1), 1); std::vector buf(batch_size); output.copy_to(buf.data()); } #endif } // namespace oneflow_api ================================================ FILE: oneflow/api/cpp/tests/tensor_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/cpp/tests/api_test.h" namespace oneflow_api { TEST(Api, device) { EnvScope scope; auto device = Device("cpu"); ASSERT_EQ(device.type(), "cpu"); #ifdef WITH_CUDA device = Device("cuda:0"); ASSERT_EQ(device.type(), "cuda"); ASSERT_EQ(device.device_id(), 0); device = Device("cuda", 1); ASSERT_EQ(device.type(), "cuda"); ASSERT_EQ(device.device_id(), 1); #endif } TEST(Api, tensor) { EnvScope scope; const auto device = Device("cpu"); const auto shape = RandomShape(); const auto dtype = DType::kDouble; Tensor tensor; ASSERT_EQ(tensor.shape(), Shape()); ASSERT_EQ(tensor.device(), Device("cpu")); ASSERT_EQ(tensor.dtype(), DType::kFloat); Tensor tensor_with_all(shape, device, dtype); ASSERT_EQ(tensor_with_all.shape(), shape); ASSERT_EQ(tensor_with_all.device(), device); ASSERT_EQ(tensor_with_all.dtype(), dtype); } TEST(Api, tensor_from_buffer_and_copy_to) { EnvScope scope; const auto shape = RandomShape(); #define TEST_TENSOR_FROM_AND_TO_BLOB(dtype, cpp_dtype) \ std::vector data_##cpp_dtype(shape.Count(0)), new_data_##cpp_dtype(shape.Count(0)); \ for (int i = 0; i < shape.Count(0); ++i) { data_##cpp_dtype[i] = i; } \ auto tensor_##cpp_dtype = \ Tensor::from_buffer(data_##cpp_dtype.data(), shape, Device("cpu"), dtype); \ tensor_##cpp_dtype.copy_to(new_data_##cpp_dtype.data()); \ ASSERT_EQ(new_data_##cpp_dtype, data_##cpp_dtype); TEST_TENSOR_FROM_AND_TO_BLOB(DType::kFloat, float) TEST_TENSOR_FROM_AND_TO_BLOB(DType::kDouble, double) TEST_TENSOR_FROM_AND_TO_BLOB(DType::kInt8, int8_t) TEST_TENSOR_FROM_AND_TO_BLOB(DType::kInt32, int32_t) TEST_TENSOR_FROM_AND_TO_BLOB(DType::kInt64, int64_t) } TEST(Api, tensor_zeros) { EnvScope scope; const auto shape = RandomShape(); std::vector data(shape.Count(0)), target_data(shape.Count(0)); Tensor tensor(shape, Device("cpu"), DType::kFloat); tensor.zeros_(); tensor.copy_to(data.data()); std::fill(target_data.begin(), target_data.end(), 0); ASSERT_EQ(data, target_data); } } // namespace oneflow_api ================================================ FILE: oneflow/api/python/autograd/autograd.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/job_build/job_build_and_infer.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/saved_tensor_hooks.h" #include "oneflow/extension/stack/python/stack_getter.h" namespace oneflow { namespace autograd { namespace { bool IsScalarTensor(const one::Tensor& tensor) { const auto& shape = tensor.shape(); return shape->elem_cnt() == 1; } // Checks and sets default value for initial gradients based on out_grads // If output is the tensor whose size is greater than 1, out_grad's shape must be same as output's. // If output is a scalar tensor, out_grad will also be a scaler or empty(will be initted to // `oneflow.ones([1])`). Maybe CheckAndInitOutGrads(const one::TensorTuple& outputs, const one::TensorTuple& out_grads, bool is_grads_batched) { size_t grad_size = out_grads.empty() ? outputs.size() : out_grads.size(); auto gradients = std::make_shared(grad_size); CHECK_EQ_OR_RETURN(outputs.size(), gradients->size()) << "RuntimeError: got " << outputs.size() << " tensors and " << gradients->size() << " gradients"; for (int i = 0; i < outputs.size(); ++i) { CHECK_OR_RETURN(outputs.at(i)->requires_grad()) << "\nRuntimeError: element " << i << " of tensors does not require grad and does not have a grad_fn"; if (!outputs.at(i)->grad_fn_node()) { CHECK_OR_RETURN(outputs.at(i)->is_leaf()) << "output[" << i << "] doesn't have grad_fn and it is not leaf tensor!\n" << "It is a bug with oneflow, please submit an issue on GitHub: " "https://github.com/Oneflow-Inc/oneflow/issues"; JUST(one::AddAccumulateFunctionNode(outputs.at(i))); } if (out_grads.empty() || !out_grads.at(i)) { CHECK_OR_RETURN(IsScalarTensor(*outputs.at(i))) << "Grad can be implicitly created only for scalar outputs"; gradients->at(i) = JUST(one::functional::OnesLike(outputs.at(i))); } else { if (is_grads_batched) { if (*(outputs.at(i)->shape()) != *JUST(out_grads.at(i)->shape()->Slice(1))) { THROW(RuntimeError) << "If `is_grads_batched=True`, we interpret the first " << "dimension of each grad_output as the batch dimension. 
" << "The sizes of the remaining dimensions are expected to match " << "the shape of corresponding output, but a mismatch " << "was detected: grad_output[" << i << "] has a shape of " << out_grads.at(i)->shape()->ToString() << " and output[" << i << "] has a shape of " << outputs.at(i)->shape()->ToString() << "."; } } else { CHECK_EQ_OR_RETURN(*(outputs.at(i)->shape()), *(out_grads.at(i)->shape())) << "out_grad's shape must be same as output's (" << outputs.at(i)->shape()->ToString() << " vs " << out_grads.at(i)->shape()->ToString() << ")"; } if (JUST(oneflow::VectorAt(outputs, i))->dtype() != JUST(oneflow::VectorAt(out_grads, i))->dtype()) { JUST(oneflow::VectorAt(*gradients, i)) = JUST(one::functional::Cast(out_grads[i], outputs[i]->dtype(), /*pin_memory=*/false)); } else { JUST(oneflow::VectorAt(*gradients, i)) = out_grads[i]; } } } if (LazyMode::is_enabled()) { JUST(MarkOutputGradients(outputs, *gradients)); } return gradients; } } // namespace Maybe Backward(const one::TensorTuple& outputs, const one::TensorTuple& out_grads, bool retain_graph, bool create_graph) { PythonFrameGuard pf; BackwardPassScopeGuard backward_guard; if (create_graph) { retain_graph = true; } std::shared_ptr gradients = JUST(CheckAndInitOutGrads(outputs, out_grads, /*is_grads_batched=*/false)); JUST(one::GetThreadLocalAutogradEngine()->RunBackwardAndSaveGrads4LeafTensorIf( outputs, *gradients, retain_graph, create_graph)); return std::make_shared(0); } Maybe Grad(const one::TensorTuple& outputs, const one::TensorTuple& inputs, const one::TensorTuple& out_grads, bool retain_graph, bool create_graph, bool allow_unused, bool is_grads_batched) { PythonFrameGuard pf; BackwardPassScopeGuard backward_guard; if (create_graph) { retain_graph = true; } if (inputs.empty()) { return Backward(outputs, out_grads, retain_graph, create_graph); } CHECK_OR_RETURN(std::all_of( inputs.begin(), inputs.end(), [](const std::shared_ptr& tensor) { return tensor->requires_grad(); })) << "All input tensors `.requires_grad` should be true"; std::shared_ptr gradients = JUST(CheckAndInitOutGrads(outputs, out_grads, is_grads_batched)); return one::GetThreadLocalAutogradEngine()->RunBackwardAndReturnInputsTensorGradIf( outputs, inputs, *gradients, retain_graph, create_graph, allow_unused); } namespace py = pybind11; class PySavedTensorHook final : public one::SavedTensorHook { public: PySavedTensorHook(const py::function& pack_hook, const py::function& unpack_hook) : pack_hook_(pack_hook), unpack_hook_(unpack_hook) {} void pack(const std::shared_ptr& tensor) { py::gil_scoped_acquire acquire; py::object packed = pack_hook_(tensor); data_ = packed.release().ptr(); } std::shared_ptr unpack() { py::gil_scoped_acquire acquire; py::object obj = py::cast(data_); py::object x = unpack_hook_(obj); std::shared_ptr tensor; try { tensor = py::cast>(x); } catch (const py::cast_error& e) { THROW(RuntimeError) << "unpack_hook should return a Tensor, but got `" << py::str(x.get_type()).cast() << "` instead"; } return tensor; } private: PyObject* data_ = nullptr; py::function pack_hook_; py::function unpack_hook_; }; class PySavedTensorHookCreator final : public one::SavedTensorHookCreator { public: std::unique_ptr new_saved_tensor_hook() const override { if (hooks_.empty()) { return nullptr; } return std::make_unique(hooks_.back().first, hooks_.back().second); } void append_new_hooks(const py::function& pack_hook, const py::function& unpack_hook) { hooks_.emplace_back(pack_hook, unpack_hook); } void pop_hooks() { CHECK_OR_THROW(!hooks_.empty()) << "pop_hooks 
should not be called when there are no hooks"; hooks_.pop_back(); } private: small_vector, 1> hooks_; }; ONEFLOW_API_PYBIND11_MODULE("autograd", m) { m.def("backward", &Backward); m.def("grad", &Grad); m.def_submodule("graph") .def("register_saved_tensors_hook_manager", []() { Singleton::SetAllocated(new PySavedTensorHookCreator()); }) .def("append_new_hooks", [](const py::function& pack_hook, const py::function& unpack_hook) { PySavedTensorHookCreator* creator = dynamic_cast( Singleton::Get()); CHECK_NOTNULL_OR_THROW(creator) << "`register_saved_tensors_hook_manager` should be called " "before calling `append_new_hooks`"; creator->append_new_hooks(pack_hook, unpack_hook); }) .def("pop_hooks", []() { PySavedTensorHookCreator* creator = dynamic_cast(Singleton::Get()); CHECK_NOTNULL_OR_THROW(creator) << "`register_saved_tensors_hook_manager` should be called " "before calling `pop_hooks`"; creator->pop_hooks(); }); } } // namespace autograd } // namespace oneflow ================================================ FILE: oneflow/api/python/autograd/autograd_engine.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/global_param_grad_sync_mode.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>( m, "GlobalParamGradSyncMode") .def(py::init([](bool flag) { return std::make_shared(flag); })); } } // namespace oneflow ================================================ FILE: oneflow/api/python/autograd/autograd_function.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
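// ---- Illustrative usage sketch (hypothetical, not part of the repository sources) ----
// The "backward" and "grad" functions bound in the module block above are thin wrappers over
// the Backward() and Grad() entry points defined just before it. With an empty out_grads
// tuple, CheckAndInitOutGrads() only accepts scalar outputs and seeds the gradient with ones;
// Grad() additionally returns the gradients of the requested inputs instead of accumulating
// them on the leaf tensors. A hypothetical C++ caller (assuming the declarations are visible):
#include <memory>

inline void BackwardOnScalarLoss(const std::shared_ptr<oneflow::one::Tensor>& loss) {
  oneflow::one::TensorTuple outputs;
  outputs.emplace_back(loss);
  oneflow::one::TensorTuple out_grads;   // empty: the initial gradient defaults to ones_like(loss)
  auto result = oneflow::autograd::Backward(outputs, out_grads,
                                            /*retain_graph=*/false, /*create_graph=*/false);
  (void)result;                          // a Maybe<> value; Maybe-returning callers would JUST(...) it
}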
*/ #include #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/core/autograd/autograd_function.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/tensor_tuple.h" namespace py = pybind11; namespace oneflow { namespace { // Transform input to TensorTuple Maybe UnpackTensorTuple(const py::object& input) { one::TensorTuple tp; if (one::PyTensor_Check(input.ptr())) { tp.emplace_back(input.cast>()); } else if (py::isinstance(input)) { auto tuple = input.cast(); tp.resize(tuple.size()); for (int i = 0; i < tuple.size(); ++i) { PyObject* obj = tuple[i].ptr(); if (obj == Py_None) { // do nothing } else if (one::PyTensor_Check(obj)) { tp[i] = one::PyTensor_Unpack(obj); } else { return Error::RuntimeError() << "expected Tensor or None as element " << i << ", but got " << one::functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(obj))); } } } else { return Error::RuntimeError() << "autograd.Function's output only support tensor or list of tensors"; } return tp; } // Return single Tensor when TensorTuple's size is one, otherwise py::tuple py::object PackTensorTuple(const one::TensorTuple& tp) { if (tp.size() == 1) { return py::cast(tp.at(0)); } else { py::tuple out = py::tuple(tp.size()); for (int i = 0; i < tp.size(); ++i) { out[i] = tp.at(i); } return py::cast(out); } } // wrap PyFunction, unpack the inputs from TensorTuple and pack outputs to TensorTuple one::AutogradFunctionBase::FType PackPyFunctionToFType(const py::function& func) { return [func](const std::shared_ptr& ctx, const one::TensorTuple& inputs) { const py::tuple& a = py::cast(inputs); py::object res = func(ctx, *a); return UnpackTensorTuple(res).GetPtrOrThrow(); }; } } // namespace namespace one { ONEFLOW_API_PYBIND11_MODULE("autograd", m) { py::class_>(m, "AutogradFunctionBase") .def(py::init([]() { return std::make_shared(); })) .def_static("apply", [](const std::string& name, const py::function& forward_fn, const py::function& backward_fn, const py::args& input) -> Maybe { const auto& input_tensor_tuple = JUST(UnpackTensorTuple(input)); const std::shared_ptr& res = JUST(AutogradFunctionBase::Apply( name, PackPyFunctionToFType(forward_fn), PackPyFunctionToFType(backward_fn), *input_tensor_tuple)); return PackTensorTuple(*res); }); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/autograd/autograd_function_state.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/api/python/autograd/autograd_function_state.h" #include #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/of_api_registry.h" namespace py = pybind11; namespace oneflow { namespace one { namespace { inline FunctionAutoGradCaptureState* CheckAndGetStateData(PyAutogradFunctionState* state) { if (!state->data.lock()) { PyErr_Format(PyExc_RuntimeError, "Data is deallocated. Please don't hold context outside " "autograd.Function.forward or autograd.Function.backward"); return nullptr; } return state->data.lock().get(); } } // namespace #if PY_VERSION_HEX < 0x03070000 #define PYGETSET_NAME(name) const_cast(name) #else #define PYGETSET_NAME(name) (name) #endif #define PY_XINCREF(p) (({ Py_XINCREF(p); }), (p)) static PyObject* PyAutogradFunctionState_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { PyAutogradFunctionState* self = (PyAutogradFunctionState*)type->tp_alloc(type, 0); if (self != NULL) { self->dynamic_attr_dict = PyDict_New(); if (self->dynamic_attr_dict == NULL) { Py_DECREF(self); return NULL; } } return (PyObject*)self; } static void PyAutogradFunctionState_dealloc(PyAutogradFunctionState* self) { Py_XDECREF(self->dynamic_attr_dict); Py_TYPE(self)->tp_free((PyObject*)self); } // PyMethodDef start static PyObject* PyAutogradFunctionState_save_for_backward(PyObject* self, PyObject* args) { HANDLE_ERRORS auto* _self = (PyAutogradFunctionState*)self; if (!functional::PyTensorSequenceCheck(args)) { return PyErr_Format(PyExc_TypeError, "save_for_backward() only support Tensor or Tensors"); } const std::vector>& tensor_list = functional::PyUnpackTensorSequence(args); for (const auto& tensor : tensor_list) { CheckAndGetStateData(_self)->SaveTensorForBackward(tensor); } Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyAutogradFunctionState_mark_non_differentiable(PyObject* self, PyObject* args) { HANDLE_ERRORS auto* _self = (PyAutogradFunctionState*)self; if (!functional::PyTensorSequenceCheck(args)) { return PyErr_Format(PyExc_TypeError, "save_for_backward() only support Tensor or Tensors"); } const std::vector>& tensor_list = functional::PyUnpackTensorSequence(args); for (const auto& tensor : tensor_list) { CheckAndGetStateData(_self)->MarkNonDifferentiable(tensor); } Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyAutogradFunctionState_is_data_valid(PyObject* self) { auto* _self = (PyAutogradFunctionState*)self; return functional::CastToPyObject(_self->data.lock() != nullptr); } static PyMethodDef PyAutogradFunctionState_methods[] = { {"save_for_backward", (PyCFunction)PyAutogradFunctionState_save_for_backward, METH_VARARGS, NULL}, {"mark_non_differentiable", (PyCFunction)PyAutogradFunctionState_mark_non_differentiable, METH_VARARGS, NULL}, {"_is_data_valid", (PyCFunction)PyAutogradFunctionState_is_data_valid, METH_NOARGS, NULL}, {NULL} /* Sentinel */ }; // PyMethodDef end // PyAutogradFunctionState_getset start static PyObject* PyAutogradFunctionState_saved_tensors(PyObject* self, void*) { auto* _self = (PyAutogradFunctionState*)self; return functional::CastToPyObject>( CheckAndGetStateData(_self)->SavedTensors()); } static PyObject* PyAutogradFunctionState_get_dict(PyObject* self, PyObject* args) { HANDLE_ERRORS auto* _self = (PyAutogradFunctionState*)self; return _self->dynamic_attr_dict; Py_RETURN_NONE; END_HANDLE_ERRORS } static PyGetSetDef PyAutogradFunctionState_properties[] = { {PYGETSET_NAME("saved_tensors"), (getter)PyAutogradFunctionState_saved_tensors, 
NULL, NULL, NULL}, {PYGETSET_NAME("__dict__"), (getter)PyAutogradFunctionState_get_dict, NULL, NULL, NULL}, {NULL} /* Sentinel */ }; // PyAutogradFunctionState_getset end PyObject* PyAutogradFunctionState_getattro(PyObject* self, PyObject* attr) { PyObject* res = NULL; res = PyDict_GetItem(((PyAutogradFunctionState*)self)->dynamic_attr_dict, attr); if (!res) { // Not found attr in dynamic_attr_dict, try to find it in tp_dict res = PyObject_GenericGetAttr(self, attr); if (!res) { return PyErr_Format(PyExc_AttributeError, "attribute %s not found", PyUnicode_AsUTF8(attr)); } } return res; } int PyAutogradFunctionState_setattro(PyObject* self, PyObject* attr, PyObject* value) { auto* _self = (PyAutogradFunctionState*)self; return PyDict_SetItem(_self->dynamic_attr_dict, attr, value); } PyTypeObject PyAutogradFunctionState_Type = { PyVarObject_HEAD_INIT(NULL, 0) "oneflow.autograd.Function.FunctionCtx", /* tp_name */ sizeof(PyAutogradFunctionState), /* tp_basicsize */ 0, /* tp_itemsize */ (destructor)PyAutogradFunctionState_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ NULL, /* tp_getattr */ NULL, /* tp_setattr */ NULL, /* tp_reserved */ NULL, /* tp_repr */ NULL, /* tp_as_number */ NULL, /* tp_as_sequence */ NULL, /* tp_as_mapping */ NULL, /* tp_hash */ NULL, /* tp_call */ NULL, /* tp_str */ PyAutogradFunctionState_getattro, /* tp_getattro */ PyAutogradFunctionState_setattro, /* tp_setattro */ NULL, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ NULL, /* tp_doc */ NULL, /* tp_traverse */ NULL, /* tp_clear */ NULL, /* tp_richcompare */ 0, /* tp_weaklistoffset */ NULL, /* tp_iter */ NULL, /* tp_iternext */ PyAutogradFunctionState_methods, /* tp_methods */ NULL, /* tp_members */ PyAutogradFunctionState_properties, /* tp_getset */ 0, /* tp_base */ NULL, /* tp_dict */ NULL, /* tp_descr_get */ NULL, /* tp_descr_set */ offsetof(PyAutogradFunctionState, dynamic_attr_dict), /* tp_dictoffset */ NULL, /* tp_init */ NULL, /* tp_alloc */ PyAutogradFunctionState_new, /* tp_new */ NULL, /* tp_free */ }; PyObject* PyAutogradFunctionState_NewFromPtr( const std::shared_ptr& data) { if (!data) { Py_RETURN_NONE; } if (data->pyobject()) { return PY_XINCREF((PyObject*)data->pyobject()); } auto* self = (PyAutogradFunctionState*)(PyObject_CallObject( (PyObject*)&PyAutogradFunctionState_Type, NULL)); if (self) { PY_XINCREF(self); self->data = data; CheckAndGetStateData(self)->set_pyobject_ptr( std::unique_ptr(self, [](void* ptr) { Py_DECREF((PyObject*)ptr); })); } return (PyObject*)self; } ONEFLOW_API_PYBIND11_MODULE("autograd.Function", m) { if (PyType_Ready(&PyAutogradFunctionState_Type) < 0) { return; } Py_INCREF(&PyAutogradFunctionState_Type); if (PyModule_AddObject(m.ptr(), "FunctionCtx", (PyObject*)&PyAutogradFunctionState_Type) < 0) { return; } } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/autograd/autograd_function_state.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_AUTOGRAD_AUTOGRAD_FUNCTION_STATE_H_ #define ONEFLOW_API_PYTHON_AUTOGRAD_AUTOGRAD_FUNCTION_STATE_H_ #include #undef _PyGC_FINALIZED #include #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { typedef struct { PyObject_HEAD; PyObject* dynamic_attr_dict; std::weak_ptr data; } PyAutogradFunctionState; extern PyTypeObject PyAutogradFunctionState_Type; inline bool PyAutogradFunctionState_Check(PyObject* state) { return PyObject_TypeCheck(state, &PyAutogradFunctionState_Type); } PyObject* PyAutogradFunctionState_NewFromPtr( const std::shared_ptr& data); } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_AUTOGRAD_AUTOGRAD_FUNCTION_STATE_H_ ================================================ FILE: oneflow/api/python/autograd/autograd_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/autograd/autograd_mode.h" namespace py = pybind11; namespace oneflow { namespace autograd { ONEFLOW_API_PYBIND11_MODULE("autograd", m) { py::class_>(m, "AutoGradMode") .def(py::init([](bool mode) { return std::make_shared(mode); })) .def("__enter__", [](const AutoGradMode& no_grad_obj) {}) .def("__exit__", [](const AutoGradMode& no_grad_obj, const py::object& type, const py::object& value, const py::object& traceback) {}); m.def("is_grad_enabled", &GradMode::is_enabled); m.def("set_grad_enabled", &GradMode::set_enabled); } } // namespace autograd } // namespace oneflow ================================================ FILE: oneflow/api/python/autograd/function_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
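// ---- Illustrative usage sketch (hypothetical, not part of the repository sources) ----
// The AutoGradMode binding above is a scope guard: constructing the object switches gradient
// recording for the current thread, and the Python __enter__/__exit__ pair is intentionally
// empty because the guard's lifetime does the work. A C++ caller can use the guard directly,
// assuming (as the binding suggests) that the constructor sets the mode and the destructor
// restores the previous one:
inline void RunWithoutGradRecording() {
  oneflow::autograd::AutoGradMode no_grad(/*mode=*/false);
  // ... forward ops launched here are not recorded for backward ...
  bool enabled = oneflow::autograd::GradMode::is_enabled();   // false inside this scope
  (void)enabled;
}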
*/ #include #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/autograd/autograd_engine.h" namespace py = pybind11; namespace oneflow { namespace { struct FunctionNodeUtil final { static std::string ToString(const one::FunctionNode& func_node) { std::stringstream ss; ss << "<"; ss << func_node.name(); ss << " at " << &func_node; ss << ">"; return ss.str(); } }; } // namespace ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "FunctionNode") .def("__str__", &FunctionNodeUtil::ToString) .def("__repr__", &FunctionNodeUtil::ToString) .def("_register_hook_dict", []() { TODO(); }) .def_property_readonly( "next_functions", [](const one::FunctionNode& func_node) { return func_node.next_functions(); }) .def_property_readonly("metadata", []() { TODO(); }) .def_property_readonly("requires_grad", []() { TODO(); }) .def("register_hook", &one::FunctionNode::add_post_hook) .def("name", [](const one::FunctionNode& func_node) { return func_node.name(); }) .def_property_readonly( "variable", [](const one::FunctionNode& func_node) { return func_node.Variable(); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/caster/autograd_function_state.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_CASTER_AUTOGRAD_FUNCTION_STATE_H_ #define ONEFLOW_API_PYTHON_CASTER_AUTOGRAD_FUNCTION_STATE_H_ #include #include "oneflow/api/python/caster/common.h" #include "oneflow/api/python/autograd/autograd_function_state.h" namespace py = pybind11; namespace pybind11 { namespace detail { template struct autograd_function_state_type_caster { public: bool load(handle src, bool convert) { using namespace oneflow::one; value_ = nullptr; if (!src) { return false; } if (src.is_none()) { return true; } if (!PyAutogradFunctionState_Check(src.ptr())) { return false; } value_ = ((PyAutogradFunctionState*)src.ptr())->data; return true; } template static handle cast(U&& src, return_value_policy policy, handle parent) { using namespace oneflow::one; return reinterpret_steal( PyAutogradFunctionState_NewFromPtr( std::const_pointer_cast(src))) .release(); } operator std::shared_ptr*() { return &value_; } operator std::shared_ptr&() { return value_; } operator std::shared_ptr&&() && { return std::move(value_); } static constexpr auto name = _("autograd_function_state"); protected: std::shared_ptr value_; }; template<> struct type_caster> : public autograd_function_state_type_caster {}; template<> struct type_caster> : public autograd_function_state_type_caster { }; } // namespace detail } // namespace pybind11 #endif // ONEFLOW_API_PYTHON_CASTER_AUTOGRAD_FUNCTION_STATE_H_ ================================================ FILE: oneflow/api/python/caster/common.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
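// ---- Illustrative usage sketch (hypothetical, not part of the repository sources) ----
// The FunctionNode binding above exposes the backward graph: name() identifies a node and
// next_functions yields its parents. A hypothetical helper that prints the graph in the same
// "<name at address>" form as __repr__, assuming next_functions() returns an iterable of
// shared_ptr<FunctionNode>:
#include <iostream>
#include <string>

inline void PrintBackwardGraph(const oneflow::one::FunctionNode& node, int depth = 0) {
  std::cout << std::string(depth * 2, ' ') << "<" << node.name() << " at " << &node << ">\n";
  for (const auto& parent : node.next_functions()) {
    if (parent) { PrintBackwardGraph(*parent, depth + 1); }
  }
}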
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_CASTER_COMMON_H_ #define ONEFLOW_API_PYTHON_CASTER_COMMON_H_ #include #include namespace pybind11 { namespace detail { // The condition follows the pybind11 source code template using IsSupportedByPybind11WhenInsideSharedPtr = std::is_base_of, type_caster>; #define PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(type, py_name) \ protected: \ std::shared_ptr value; \ \ public: \ static constexpr auto name = py_name; \ template>::value, int> = 0> \ static handle cast(T_* src, return_value_policy policy, handle parent) { \ if (!src) return none().release(); \ if (policy == return_value_policy::take_ownership) { \ auto h = cast(std::move(*src), policy, parent); \ delete src; \ return h; \ } \ return cast(*src, policy, parent); \ } \ operator type*() { return value.get(); } \ operator type&() { return *value; } \ operator type&&()&& { return std::move(*value); } \ template \ using cast_op_type = pybind11::detail::movable_cast_op_type } // namespace detail } // namespace pybind11 #endif // ONEFLOW_API_PYTHON_CASTER_COMMON_H_ ================================================ FILE: oneflow/api/python/caster/maybe.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_CASTER_MAYBE_H_ #define ONEFLOW_API_PYTHON_CASTER_MAYBE_H_ #include #include "oneflow/api/python/caster/common.h" #include "oneflow/core/common/maybe.h" namespace pybind11 { namespace detail { using oneflow::Maybe; namespace impl { template using IsHoldedInsideSharedPtrByMaybe = std::is_same>().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()), std::shared_ptr>; template::value && IsHoldedInsideSharedPtrByMaybe::value, int> = 0> std::shared_ptr GetOrThrowHelper(Maybe x) { return x.GetPtrOrThrow(); } template::value || !IsHoldedInsideSharedPtrByMaybe::value, int> = 0> T GetOrThrowHelper(Maybe x) { return x.GetOrThrow(); } } // namespace impl // Information about pybind11 custom type caster can be found // at oneflow/api/python/caster/optional.h, and also at // https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html template struct maybe_caster { using Value = decltype(impl::GetOrThrowHelper(std::declval())); using value_conv = make_caster; bool load(handle src, bool convert) { if (!src) { return false; } if (src.is_none()) { // Maybe (except Maybe) does not accept `None` from Python. Users can use Optional in // those cases. 
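// (Illustrative note, not from the original file: a binding written as
//    m.def("divide", [](float x, float y) -> Maybe<float> { CHECK_NE_OR_RETURN(y, 0); return x / y; });
//  -- the pattern used in oneflow/api/python/caster/test.cpp below -- surfaces a failed
//  CHECK_*_OR_RETURN as a Python exception, and a Maybe<T> parameter likewise refuses None here
//  at load time; Optional<T> is the type to use when None must be a legal input.)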
return false; } value_conv inner_caster; if (!inner_caster.load(src, convert)) { return false; } value = std::make_shared(cast_op(std::move(inner_caster))); return true; } template static handle cast(T&& src, return_value_policy policy, handle parent) { if (!std::is_lvalue_reference::value) { policy = return_value_policy_override::policy(policy); } return value_conv::cast(impl::GetOrThrowHelper(std::forward(src)), policy, parent); } PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(Maybe, _("Maybe[void]")); }; template<> struct maybe_caster> { template static handle cast(T&& src, return_value_policy policy, handle parent) { if (!src.IsOk()) { oneflow::ThrowError(src.stacked_error()); } return none().inc_ref(); } bool load(handle src, bool convert) { if (src && src.is_none()) { return true; // None is accepted because NoneType (i.e. void) is the value type of // Maybe } return false; } PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(Maybe, _("Maybe[void]")); }; template struct type_caster> : public maybe_caster> {}; } // namespace detail } // namespace pybind11 #endif // ONEFLOW_API_PYTHON_CASTER_MAYBE_H_ ================================================ FILE: oneflow/api/python/caster/optional.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_CASTER_OPTIONAL_H_ #define ONEFLOW_API_PYTHON_CASTER_OPTIONAL_H_ #include #include "oneflow/api/python/caster/common.h" #include "oneflow/core/common/optional.h" namespace pybind11 { namespace detail { using oneflow::Optional; namespace impl { template T& DeferenceIfSharedPtr(std::shared_ptr ptr) { return *ptr; } template T&& DeferenceIfSharedPtr(T&& obj) { return std::forward(obj); } template using IsHoldedInsideSharedPtrByOptional = std::is_same::storage_type, std::shared_ptr>; template::value && IsHoldedInsideSharedPtrByOptional::value, int> = 0> std::shared_ptr GetDataHelper(Optional x) { return CHECK_JUST(x); } template::value || !IsHoldedInsideSharedPtrByOptional::value, int> = 0> T GetDataHelper(Optional x) { return DeferenceIfSharedPtr(CHECK_JUST(x)); } } // namespace impl // Code is copied from pybind11 include/pybind11/stl.h // Comments wrapped by /* */ are copied from // https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html template struct oneflow_optional_caster { using Value = decltype(impl::GetDataHelper(std::declval())); using value_conv = make_caster; /** * Conversion part 1 (Python->C++): convert a PyObject into a Optional * instance or return false upon failure. The second argument * indicates whether implicit conversions should be applied. */ bool load(handle src, bool convert) { if (!src) { return false; } if (src.is_none()) { return true; // default-constructed value is already empty } value_conv inner_caster; if (!inner_caster.load(src, convert)) { return false; } value = cast_op(std::move(inner_caster)); return true; } /** * Conversion part 2 (C++ -> Python): convert an Optional instance into * a Python object. 
The second and third arguments are used to * indicate the return value policy and parent object (for * ``return_value_policy::reference_internal``) and are generally * ignored by implicit casters. */ template static handle cast(T&& src, return_value_policy policy, handle parent) { if (!src) { return none().inc_ref(); } if (!std::is_lvalue_reference::value) { policy = return_value_policy_override::policy(policy); } return value_conv::cast(impl::GetDataHelper(std::forward(src)), policy, parent); } /** * This macro establishes the name 'Optional[T]' in * function signatures and declares a local variable * 'value' of type inty */ PYBIND11_TYPE_CASTER(Type, _("Optional[") + value_conv::name + _("]")); }; template struct type_caster> : public oneflow_optional_caster> {}; } // namespace detail } // namespace pybind11 #endif // ONEFLOW_API_PYTHON_CASTER_OPTIONAL_H_ ================================================ FILE: oneflow/api/python/caster/size.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_CASTER_SIZE_H_ #define ONEFLOW_API_PYTHON_CASTER_SIZE_H_ #include #include #undef _PyGC_FINALIZED #include #include "oneflow/api/python/framework/size.h" #include "oneflow/core/common/shape.h" PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) class shape : public object { public: PYBIND11_OBJECT_CVT(shape, object, oneflow::TensorSize_Check, raw_shape) explicit shape(size_t size = 0) : object(oneflow::TensorSize_New((ssize_t)size), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate tensor size object!"); } size_t size() const { return (size_t)PyTuple_Size(m_ptr); } bool empty() const { return size() == 0; } detail::tuple_accessor operator[](size_t index) const { return {*this, index}; } detail::item_accessor operator[](handle h) const { return object::operator[](h); } detail::tuple_iterator begin() const { return {*this, 0}; } detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; } private: static PyObject* raw_shape(PyObject* op) { if (oneflow::TensorSize_Check(op)) return handle(op).inc_ref().ptr(); return PyObject_CallFunctionObjArgs((PyObject*)&oneflow::TensorSize_Type, op, NULL); } }; PYBIND11_NAMESPACE_BEGIN(detail) template struct shape_type_caster { public: bool load(handle src, bool convert) { value_ = nullptr; if (src && src.is_none()) { return true; } if (!oneflow::TensorSize_Check(src.ptr())) { return false; } value_ = std::make_shared(oneflow::TensorSize_AsShape(src.ptr())); return true; } template static handle cast(U&& src, return_value_policy /*policy*/, handle /*parent*/) { return cast_impl(std::forward(src)); } template static handle cast(U* src, return_value_policy policy, handle parent) { if (!src) { return none().release(); } return cast(*src, policy, parent); } operator T*() { return value_.get(); } operator T&() { return *value_; } operator T&&() && { return std::move(*value_); } operator std::shared_ptr*() { return &value_; } operator std::shared_ptr&() { 
return value_; } operator std::shared_ptr&&() && { return std::move(value_); } static constexpr auto name = _("shape"); template using cast_op_type = pybind11::detail::cast_op_type>; private: static handle cast_impl(const oneflow::Shape& src) { return reinterpret_steal(oneflow::TensorSize_NewFromShape(src)).release(); } static handle cast_impl(const std::shared_ptr& src) { return reinterpret_steal(oneflow::TensorSize_NewFromShape(*src)).release(); } protected: std::shared_ptr value_; }; template<> struct type_caster : public shape_type_caster {}; template<> struct type_caster> : public shape_type_caster {}; template<> struct type_caster> : public shape_type_caster {}; PYBIND11_NAMESPACE_END(detail) PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) #endif // ONEFLOW_API_PYTHON_CASTER_SIZE_H_ ================================================ FILE: oneflow/api/python/caster/tensor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_CASTER_TENSOR_H_ #define ONEFLOW_API_PYTHON_CASTER_TENSOR_H_ #include #include "oneflow/api/python/caster/common.h" #include "oneflow/api/python/framework/tensor.h" namespace pybind11 { namespace detail { template struct tensor_type_caster { public: bool load(handle src, bool convert) { using namespace oneflow::one; value_ = nullptr; if (!src) { return false; } if (src.is_none()) { return true; } if (!PyTensor_Check(src.ptr())) { return false; } value_ = PyTensor_Unpack(src.ptr()); return true; } template static handle cast(U&& src, return_value_policy policy, handle parent) { using namespace oneflow::one; return reinterpret_steal(PyTensor_New(std::const_pointer_cast(src))).release(); } operator std::shared_ptr*() { return &value_; } operator std::shared_ptr&() { return value_; } operator std::shared_ptr&&() && { return std::move(value_); } static constexpr auto name = _("tensor"); template using cast_op_type = pybind11::detail::cast_op_type>; protected: std::shared_ptr value_; }; template struct parameter_type_caster { public: bool load(handle src, bool convert) { using namespace oneflow::one; value_ = nullptr; if (!src) { return false; } if (src.is_none()) { return true; } if (!PyTensor_Check(src.ptr())) { return false; } value_ = PyTensor_Unpack(src.ptr()); return true; } template static handle cast(U&& src, return_value_policy policy, handle parent) { using namespace oneflow::one; return reinterpret_steal(PyParameter_New(std::const_pointer_cast(src))) .release(); } operator std::shared_ptr*() { return &value_; } operator std::shared_ptr&() { return value_; } operator std::shared_ptr&&() && { return std::move(value_); } static constexpr auto name = _("parameter"); template using cast_op_type = pybind11::detail::cast_op_type>; protected: std::shared_ptr value_; }; template<> struct type_caster> : public tensor_type_caster {}; template<> struct type_caster> : public tensor_type_caster {}; template<> struct type_caster> : public parameter_type_caster {}; template<> struct 
type_caster> : public parameter_type_caster {}; } // namespace detail } // namespace pybind11 #endif // ONEFLOW_API_PYTHON_CASTER_TENSOR_H_ ================================================ FILE: oneflow/api/python/caster/test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" namespace py = pybind11; namespace oneflow { class A { public: void inc_x() { x++; } int get_x() { return x; } private: int x = 0; }; std::shared_ptr get_singleton_a() { static std::shared_ptr a = std::make_shared(); return a; } ONEFLOW_API_PYBIND11_MODULE("test_api", m) { py::class_>(m, "A").def("inc_x", &A::inc_x).def("get_x", &A::get_x); m.def("get_singleton_a", []() -> Maybe { return get_singleton_a(); }); m.def("increase_x_of_a_if_not_none", [](const Optional& a) -> Optional { a.map([](const std::shared_ptr& a) -> std::shared_ptr { a->inc_x(); return a; }); return a; }); m.def("increase_if_not_none", [](const Optional& x) -> Optional { return x.map([](int i) { return i + 1; }); }); m.def("divide", [](float x, float y) -> Maybe { CHECK_NE_OR_RETURN(y, 0); return x / y; }); m.def("throw_if_zero", [](int x) -> Maybe { CHECK_NE_OR_RETURN(x, 0); return Maybe::Ok(); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/deprecated.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/framework/dtype.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("deprecated", m) { m.def("GetProtoDtype4OfDtype", [](const Symbol& x) { return static_cast(x->data_type()); }); m.def("GetDTypeByDataType", [](int data_type) { return DType::Get(static_cast(data_type)); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/dlpack/converter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
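// ---- Illustrative usage sketch (hypothetical, not part of the repository sources) ----
// The tensor caster above is what lets ordinary pybind11 bindings take and return
// oneflow::one::Tensor through std::shared_ptr with no manual wrapping: a flow.Tensor passed
// from Python arrives as the shared_ptr, and None arrives as an empty one. A hypothetical
// binding relying on it (the function name is made up for illustration):
#include <memory>

namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("tensor_ndim", [](const std::shared_ptr<one::Tensor>& t) -> int64_t {
    return t ? t->ndim() : -1;   // the caster hands over an empty pointer when Python sent None
  });
}
}  // namespace oneflow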
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/python/dlpack/dlpack.h" #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/tensor_util.h" namespace oneflow { Maybe> ToOneFlowDevice(const DLDevice& ctx) { switch (ctx.device_type) { case DLDeviceType::kDLCPU: return JUST(Device::New("cpu")); #ifdef WITH_CUDA case DLDeviceType::kDLCUDA: return JUST(Device::New("cuda", ctx.device_id)); #endif default: UNIMPLEMENTED_THEN_RETURN() << "Unsupported device type: " << ctx.device_type; } } Maybe ToOneFlowDataType(const DLDataType& dtype) { DataType ofdtype = DataType::kInvalidDataType; CHECK_EQ_OR_RETURN(dtype.lanes, 1) << "OneFlow does not support lanes != 1"; switch (dtype.code) { case DLDataTypeCode::kDLUInt: switch (dtype.bits) { case 8: ofdtype = DataType::kUInt8; break; default: UNIMPLEMENTED_THEN_RETURN() << "Unsupported data type: " << dtype.code << dtype.bits; } break; case DLDataTypeCode::kDLInt: switch (dtype.bits) { case 8: ofdtype = DataType::kInt8; break; case 16: ofdtype = DataType::kInt16; break; case 32: ofdtype = DataType::kInt32; break; case 64: ofdtype = DataType::kInt64; break; default: UNIMPLEMENTED_THEN_RETURN() << "Unsupported data type: " << dtype.code << dtype.bits; } break; case DLDataTypeCode::kDLFloat: switch (dtype.bits) { case 16: ofdtype = DataType::kFloat16; break; case 32: ofdtype = DataType::kFloat; break; case 64: ofdtype = DataType::kDouble; break; default: UNIMPLEMENTED_THEN_RETURN() << "Unsupported data type: " << dtype.code << dtype.bits; } break; case DLDataTypeCode::kDLBfloat: switch (dtype.bits) { case 16: ofdtype = DataType::kBFloat16; break; default: UNIMPLEMENTED_THEN_RETURN() << "Unsupported data type: bfloat" << dtype.bits; } break; case DLDataTypeCode::kDLComplex: UNIMPLEMENTED_THEN_RETURN() << "Unsupported data type: complex" << dtype.bits; break; default: UNIMPLEMENTED_THEN_RETURN() << "Unsupported code " << dtype.code; } CHECK_NE_OR_RETURN(ofdtype, DataType::kInvalidDataType); return ofdtype; } Maybe fromDLPack(const DLManagedTensor* src) { using namespace one; const auto& dl_tensor = src->dl_tensor; Symbol device = JUST(ToOneFlowDevice(dl_tensor.device)); DataType dtype = JUST(ToOneFlowDataType(dl_tensor.dtype)); // Build TensorMeta const Shape shape(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim); Symbol tensor_meta; if (dl_tensor.strides) { const auto stride = Stride(dl_tensor.strides, dl_tensor.strides + dl_tensor.ndim); tensor_meta = SymbolOf(LocalTensorMeta(shape, stride, dtype, MemoryFormat::kContiguous, device)); } else { tensor_meta = SymbolOf(LocalTensorMeta(shape, dtype, MemoryFormat::kContiguous, device)); } // Build TensorBuffer const auto& Free = [src](char* dptr) { if (src->deleter) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) src->deleter(const_cast(src)); } }; size_t array_size_in_bytes = shape.elem_cnt() * GetSizeOfDataType(dtype); auto tensor_data = 
std::make_shared(false, device); tensor_data->set_blob_dptr( std::unique_ptr>(static_cast(dl_tensor.data), Free), array_size_in_bytes); // Build TensorStorage: decrease ndarray reference count before releasing auto tensor_storage = std::make_shared(tensor_data); // Build Tensor auto tensor_impl = std::make_shared(tensor_storage, /*requires_grad=*/false, /*ls_leaf=*/true); // Init blob JUST(tensor_impl->InitEagerBlobObject(tensor_meta, NewLocalDepObject())); const auto& stream = JUST(GetDefaultStreamByDevice(device)); const auto& eager_blob_object = JUST(tensor_impl->eager_blob_object()); JUST(eager_blob_object->init_producer_stream(stream)); eager_blob_object->set_last_used_stream(stream); return std::static_pointer_cast(std::make_shared(tensor_impl)); } Maybe ToDLDevice(Symbol ofdevice) { DLDevice ctx; ctx.device_id = ofdevice->device_id(); switch (ofdevice->enum_type()) { case DeviceType::kCPU: ctx.device_type = DLDeviceType::kDLCPU; break; #ifdef WITH_CUDA case DeviceType::kCUDA: ctx.device_type = DLDeviceType::kDLCUDA; break; #endif default: UNIMPLEMENTED_THEN_RETURN() << "Unsupported device type: " << ofdevice->type(); } return ctx; } Maybe ToDLDataType(DataType ofdtype) { DLDataType dtype; dtype.lanes = 1; dtype.bits = GetSizeOfDataType(ofdtype) * 8; switch (ofdtype) { case DataType::kUInt8: dtype.code = DLDataTypeCode::kDLUInt; break; case DataType::kInt8: dtype.code = DLDataTypeCode::kDLInt; break; case DataType::kInt16: dtype.code = DLDataTypeCode::kDLInt; break; case DataType::kInt32: dtype.code = DLDataTypeCode::kDLInt; break; case DataType::kInt64: dtype.code = DLDataTypeCode::kDLInt; break; case DataType::kFloat16: dtype.code = DLDataTypeCode::kDLFloat; break; case DataType::kFloat: dtype.code = DLDataTypeCode::kDLFloat; break; case DataType::kDouble: dtype.code = DLDataTypeCode::kDLFloat; break; case DataType::kBFloat16: dtype.code = DLDataTypeCode::kDLBfloat; break; default: UNIMPLEMENTED_THEN_RETURN() << "Unsupported data type: " << DataType_Name(ofdtype); } return dtype; } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct ATenDLMTensor { std::shared_ptr handle; DLManagedTensor tensor; }; void deleter(DLManagedTensor* arg) { delete static_cast(arg->manager_ctx); } Maybe toDLPack(const std::shared_ptr& src) { auto shape = *src->shape(); auto strides = *JUST(src->stride()); // create a new tensor with possibly normalized strides // Reference: // https://github.com/pytorch/pytorch/issues/83069 // https://github.com/pytorch/pytorch/issues/82610 for (int i = 0; i < src->ndim(); i++) { if (shape[i] <= 1) { strides[i] = 1; } } ATenDLMTensor* atDLMTensor(new ATenDLMTensor); atDLMTensor->handle = src; atDLMTensor->tensor.manager_ctx = atDLMTensor; atDLMTensor->tensor.deleter = &deleter; JUST(one::SyncAccessTensorWithTimeOut( src, [&](ep::Stream*, const std::shared_ptr& tensor) { atDLMTensor->tensor.dl_tensor.data = tensor->mut_raw_dptr(); }, "const")); auto dldevice = JUST(ToDLDevice(JUST(src->device()))); auto dldtype = JUST(ToDLDataType(src->dtype()->data_type())); atDLMTensor->tensor.dl_tensor.device = *dldevice; atDLMTensor->tensor.dl_tensor.ndim = src->ndim(); atDLMTensor->tensor.dl_tensor.dtype = *dldtype; atDLMTensor->tensor.dl_tensor.shape = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(src->shape()->data()); atDLMTensor->tensor.dl_tensor.strides = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(JUST(src->stride())->data()); atDLMTensor->tensor.dl_tensor.byte_offset = 0; return &(atDLMTensor->tensor); } // This 
function is mostly copied from PyTorch void DLPack_Capsule_Destructor(PyObject* data) { if (likely(!PyCapsule_IsValid(data, "dltensor"))) { // early out, see DLPack spec: if a consuming library sets the capsule // name to something else, they own it and we don't need to do anything return; } HANDLE_ERRORS // Causes overheads for validity checks again, but this case is rare // since consuming libraries should rename the capsule according to spec. // Note that this cannot set a python error (we checked validity above), // so we don't need to handle python error state here. DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); // the dlMTensor has not been consumed, call deleter ourselves. // DLPack spec mentions that deleter may be NULL, but deleter from // `flow.to_dlpack` is never NULL, so no need for an additional check here. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) dlMTensor->deleter(const_cast(dlMTensor)); END_HANDLE_ERRORS_RET() } namespace py = pybind11; ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("to_dlpack", [](const std::shared_ptr& tensor) -> Maybe { DLManagedTensor* dlMTensor = JUST(toDLPack(tensor)); return py::capsule(dlMTensor, "dltensor", DLPack_Capsule_Destructor); }); // from_dlpack is exported in tensor_api.yaml } } // namespace oneflow ================================================ FILE: oneflow/api/python/dlpack/converter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/python/dlpack/dlpack.h" #include "oneflow/core/common/maybe.h" namespace oneflow { namespace one { class Tensor; } Maybe fromDLPack(const DLManagedTensor* src); Maybe toDLPack(const std::shared_ptr& src); } // namespace oneflow ================================================ FILE: oneflow/api/python/dlpack/dlpack.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /*! * Copyright (c) 2017 by Contributors * \file dlpack.h * \brief The common header of DLPack. */ #ifndef DLPACK_DLPACK_H_ #define DLPACK_DLPACK_H_ /** * \brief Compatibility with C++ */ #ifdef __cplusplus #define DLPACK_EXTERN_C extern "C" #else #define DLPACK_EXTERN_C #endif /*! \brief The current version of dlpack */ #define DLPACK_VERSION 70 /*! \brief The current ABI version of dlpack */ #define DLPACK_ABI_VERSION 1 /*! 
\brief DLPACK_DLL prefix for windows */ #ifdef _WIN32 #ifdef DLPACK_EXPORTS #define DLPACK_DLL __declspec(dllexport) #else #define DLPACK_DLL __declspec(dllimport) #endif #else #define DLPACK_DLL #endif #include #include #ifdef __cplusplus extern "C" { #endif /*! * \brief The device type in DLDevice. */ #ifdef __cplusplus typedef enum : int32_t { #else typedef enum { #endif /*! \brief CPU device */ kDLCPU = 1, /*! \brief CUDA GPU device */ kDLCUDA = 2, /*! * \brief Pinned CUDA CPU memory by cudaMallocHost */ kDLCUDAHost = 3, /*! \brief OpenCL devices. */ kDLOpenCL = 4, /*! \brief Vulkan buffer for next generation graphics. */ kDLVulkan = 7, /*! \brief Metal for Apple GPU. */ kDLMetal = 8, /*! \brief Verilog simulator buffer */ kDLVPI = 9, /*! \brief ROCm GPUs for AMD GPUs */ kDLROCM = 10, /*! * \brief Pinned ROCm CPU memory allocated by hipMallocHost */ kDLROCMHost = 11, /*! * \brief Reserved extension device type, * used for quickly test extension device * The semantics can differ depending on the implementation. */ kDLExtDev = 12, /*! * \brief CUDA managed/unified memory allocated by cudaMallocManaged */ kDLCUDAManaged = 13, /*! * \brief Unified shared memory allocated on a oneAPI non-partititioned * device. Call to oneAPI runtime is required to determine the device * type, the USM allocation type and the sycl context it is bound to. * */ kDLOneAPI = 14, /*! \brief GPU support for next generation WebGPU standard. */ kDLWebGPU = 15, /*! \brief Qualcomm Hexagon DSP */ kDLHexagon = 16, } DLDeviceType; /*! * \brief A Device for Tensor and operator. */ typedef struct { /*! \brief The device type used in the device. */ DLDeviceType device_type; /*! * \brief The device index. * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. */ int32_t device_id; } DLDevice; /*! * \brief The type code options DLDataType. */ typedef enum { /*! \brief signed integer */ kDLInt = 0U, /*! \brief unsigned integer */ kDLUInt = 1U, /*! \brief IEEE floating point */ kDLFloat = 2U, /*! * \brief Opaque handle type, reserved for testing purposes. * Frameworks need to agree on the handle data type for the exchange to be well-defined. */ kDLOpaqueHandle = 3U, /*! \brief bfloat16 */ kDLBfloat = 4U, /*! * \brief complex number * (C/C++/Python layout: compact struct per complex number) */ kDLComplex = 5U, } DLDataTypeCode; /*! * \brief The data type the tensor can hold. The data type is assumed to follow the * native endian-ness. An explicit error message should be raised when attempting to * export an array with non-native endianness * * Examples * - float: type_code = 2, bits = 32, lanes=1 * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 * - int8: type_code = 0, bits = 8, lanes=1 * - std::complex: type_code = 5, bits = 64, lanes = 1 */ typedef struct { /*! * \brief Type code of base types. * We keep it uint8_t instead of DLDataTypeCode for minimal memory * footprint, but the value should be one of DLDataTypeCode enum values. * */ uint8_t code; /*! * \brief Number of bits, common choices are 8, 16, 32. */ uint8_t bits; /*! \brief Number of lanes in the type, used for vector types. */ uint16_t lanes; } DLDataType; /*! * \brief Plain C Tensor object, does not manage memory. */ typedef struct { /*! * \brief The data pointer points to the allocated data. This will be CUDA * device pointer or cl_mem handle in OpenCL. It may be opaque on some device * types. This pointer is always aligned to 256 bytes as in CUDA. 
The * `byte_offset` field should be used to point to the beginning of the data. * * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, * TVM, perhaps others) do not adhere to this 256 byte aligment requirement * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed * (after which this note will be updated); at the moment it is recommended * to not rely on the data pointer being correctly aligned. * * For given DLTensor, the size of memory required to store the contents of * data is calculated as follows: * * \code{.c} * static inline size_t GetDataSize(const DLTensor* t) { * size_t size = 1; * for (tvm_index_t i = 0; i < t->ndim; ++i) { * size *= t->shape[i]; * } * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; * return size; * } * \endcode */ void* data; /*! \brief The device of the tensor */ DLDevice device; /*! \brief Number of dimensions */ int32_t ndim; /*! \brief The data type of the pointer*/ DLDataType dtype; /*! \brief The shape of the tensor */ int64_t* shape; /*! * \brief strides of the tensor (in number of elements, not bytes) * can be NULL, indicating tensor is compact and row-majored. */ int64_t* strides; /*! \brief The offset in bytes to the beginning pointer to data */ uint64_t byte_offset; } DLTensor; /*! * \brief C Tensor object, manage memory of DLTensor. This data structure is * intended to facilitate the borrowing of DLTensor by another framework. It is * not meant to transfer the tensor. When the borrowing framework doesn't need * the tensor, it should call the deleter to notify the host that the resource * is no longer needed. */ typedef struct DLManagedTensor { /*! \brief DLTensor which is being memory managed */ DLTensor dl_tensor; /*! \brief the context of the original host framework of DLManagedTensor in * which DLManagedTensor is used in the framework. It can also be NULL. */ void* manager_ctx; /*! \brief Destructor signature void (*)(void*) - this should be called * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL * if there is no way for the caller to provide a reasonable destructor. * The destructors deletes the argument self as well. */ void (*deleter)(struct DLManagedTensor* self); } DLManagedTensor; #ifdef __cplusplus } // DLPACK_EXTERN_C #endif #endif // DLPACK_DLPACK_H_ ================================================ FILE: oneflow/api/python/eager/eager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
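The converter above hands the DLManagedTensor produced by toDLPack to Python inside a PyCapsule named "dltensor"; per the DLPack convention, a consumer renames the capsule once it takes ownership, so the destructor only frees capsules that were never consumed. A minimal round-trip sketch in Python, assuming both entry points surface at the top level as `flow.to_dlpack` (named in the comment above) and `flow.from_dlpack` (generated from tensor_api.yaml):

import oneflow as flow

t = flow.arange(6, dtype=flow.float32).reshape(2, 3)

# Producer side: wrap the tensor's buffer in a "dltensor" capsule (no copy);
# the DLManagedTensor's manager_ctx/deleter keep the storage alive.
capsule = flow.to_dlpack(t)

# Consumer side: import the capsule back into a tensor. After this the capsule
# counts as consumed (renamed per the DLPack convention), so it must not be
# passed to from_dlpack again.
t2 = flow.from_dlpack(capsule)
assert t2.shape == (2, 3)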
*/ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h" ONEFLOW_API_PYBIND11_MODULE("eager", m) { using namespace oneflow; namespace py = pybind11; m.def( "Sync", []() { return vm::CurrentRankSync(); }, py::call_guard()); m.def( "ClusterSync", []() { return vm::ClusterSync(); }, py::call_guard()); py::class_>( m, "DevVmDepObjectConsumeModeGuard"); m.def("SourceOpOnlyResourceDependenceModeGuard", []() { return std::make_shared( one::DevVmDepObjectConsumeMode::NONE); }); } ================================================ FILE: oneflow/api/python/env/env.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/env/env.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/job/graph_scope_vars.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/framework/shut_down_util.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/mem_util.h" #ifdef WITH_CUDA #include #endif // WITH_CUDA namespace py = pybind11; namespace oneflow { #ifdef WITH_CUDA void RegisterCudaDeviceProperties(py::module& m) { py::class_(m, "_CudaDeviceProperties", py::module_local()) .def(py::init<>()) .def_readonly("name", &cudaDeviceProp::name) .def_readonly("major", &cudaDeviceProp::major) .def_readonly("minor", &cudaDeviceProp::minor) .def_readonly("is_multi_gpu_board", &cudaDeviceProp::isMultiGpuBoard) .def_readonly("is_integrated", &cudaDeviceProp::integrated) .def_readonly("multi_processor_count", &cudaDeviceProp::multiProcessorCount) .def_readonly("total_memory", &cudaDeviceProp::totalGlobalMem) .def("__repr__", [](const cudaDeviceProp& prop) { std::ostringstream stream; stream << "_CudaDeviceProperties(name='" << prop.name << "', major=" << prop.major << ", minor=" << prop.minor << ", total_memory=" << prop.totalGlobalMem / (1024 * 1024) << "MB, multi_processor_count=" << prop.multiProcessorCount << ")"; return stream.str(); }); } #endif // WITH_CUDA Maybe SwitchToShuttingDownPhase(EnvGlobalObjectsScope* env, bool is_normal_exit) { JUST(env->init_is_normal_exit(is_normal_exit)); SetShuttingDown(true); if (is_normal_exit) { JUST(vm::ClusterSync()); auto* vm = JUST(SingletonMaybe()); JUST(vm->CloseVMThreads()); } return Maybe::Ok(); } ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("CurrentResource", &CurrentResource); m.def("EnvResource", &EnvResource); py::class_>( m, "EnvContext") .def(py::init()) .def("SwitchToShuttingDownPhase", &SwitchToShuttingDownPhase, py::call_guard()); m.def("CurrentMachineId", &CurrentMachineId); m.def("GetRank", &GetRank); m.def("GetWorldSize", &GetWorldSize); m.def("GetNodeSize", &GetNodeSize); m.def("GetLocalRank", &GetLocalRank); m.def("InitRDMA", &InitRDMA); m.def("RDMAIsInitialized", 
&RDMAIsInitialized); m.def("DestoryRDMA", &DestoryRDMA); m.def("CudaGetDeviceCount", &CudaGetDeviceCount); m.def("EmptyCache", &EmptyCache); #ifdef WITH_CUDA RegisterCudaDeviceProperties(m); m.def("GetCudaDeviceIndex", &GetCudaDeviceIndex); m.def("SetCudaDeviceIndex", &SetCudaDeviceIndex); m.def("CudaSynchronize", &CudaSynchronize); m.def("GetCUDAMemoryUsed", &GetCUDAMemoryUsed); m.def("GetCPUMemoryUsed", &GetCPUMemoryUsed); m.def("CudaMemGetInfo", [](int device) -> std::pair { CudaCurrentDeviceGuard guard(device); size_t device_free = 0; size_t device_total = 0; OF_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); return {device_free, device_total}; }); m.def( "_get_device_properties", [](int device) -> cudaDeviceProp* { return GetDeviceProperties(device); }, py::return_value_policy::reference); #endif // WITH_CUDA m.def("SetFLAGS_alsologtostderr", &SetFLAGS_alsologtostderr); m.def("GetFLAGS_alsologtostderr", &GetFLAGS_alsologtostderr); m.def("SetFLAGS_v", &SetFLAGS_v); m.def("GetFLAGS_v", &GetFLAGS_v); m.def("SetGraphLRVerbose", &SetGraphLRVerbose); m.def("GetGraphLRVerbose", &GetGraphLRVerbose); m.def("SetGraphDebugMaxPyStackDepth", &SetGraphDebugMaxPyStackDepth); m.def("GetGraphDebugMaxPyStackDepth", &GetGraphDebugMaxPyStackDepth); m.def("SetGraphDebugMode", &SetGraphDebugMode); m.def("GetGraphDebugMode", &GetGraphDebugMode); m.def("SetGraphDebugOnlyUserPyStack", &SetGraphDebugOnlyUserPyStack); m.def("GetGraphDebugOnlyUserPyStack", &GetGraphDebugOnlyUserPyStack); m.def("InitPythonPathsToBeKeptAndFilteredForDebugging", &InitPythonPathsToBeKeptAndFilteredForDebugging); } } // namespace oneflow ================================================ FILE: oneflow/api/python/env/env.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
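env.cpp registers the process-topology and device-introspection helpers declared in env.h (GetRank, GetWorldSize, CudaGetDeviceCount, CudaMemGetInfo, and so on). A short sketch of how they are typically reached from user code, assuming the usual oneflow.env and oneflow.cuda wrappers over these internal bindings:

import oneflow as flow

# Process topology of the current launch (a single process is rank 0 of 1).
print(flow.env.get_rank(), flow.env.get_world_size(), flow.env.get_local_rank())

if flow.cuda.is_available():
    # Backed by CudaGetDeviceCount / CudaSynchronize registered above.
    print(flow.cuda.device_count())
    flow.cuda.synchronize()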
*/ #ifndef ONEFLOW_API_PYTHON_ENV_ENV_H_ #define ONEFLOW_API_PYTHON_ENV_ENV_H_ #include #include #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/job/cluster_instruction.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/graph_scope_vars.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/rpc/include/base.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/virtual_machine.h" namespace oneflow { inline Maybe CurrentResource() { CHECK_NOTNULL_OR_RETURN((Singleton::Get())); return PbMessage2TxtString(Singleton::Get()->resource()); } inline Maybe EnvResource() { CHECK_NOTNULL_OR_RETURN((Singleton::Get())); return PbMessage2TxtString(Singleton::Get()->resource()); } inline Maybe CurrentMachineId() { return GlobalProcessCtx::Rank(); } inline Maybe GetRank() { return GlobalProcessCtx::Rank(); } inline Maybe GetWorldSize() { return GlobalProcessCtx::WorldSize(); } inline Maybe GetNodeSize() { return GlobalProcessCtx::NodeSize(); } inline Maybe GetLocalRank() { return GlobalProcessCtx::LocalRank(); } inline Maybe CudaGetDeviceCount() { return Singleton::Get()->GetDeviceCount(DeviceType::kCUDA); } inline Maybe SetFLAGS_alsologtostderr(bool flag) { FLAGS_alsologtostderr = flag; return Maybe::Ok(); } inline Maybe GetFLAGS_alsologtostderr() { return FLAGS_alsologtostderr; } // namespace oneflow inline Maybe SetFLAGS_v(int32_t v_level) { FLAGS_v = v_level; return Maybe::Ok(); } inline Maybe GetFLAGS_v() { return FLAGS_v; } inline Maybe EmptyCache() { JUST(vm::CurrentRankSync()); auto* vm = JUST(SingletonMaybe()); JUST(vm->ShrinkAllMem()); return Maybe::Ok(); } inline Maybe SetGraphLRVerbose(bool verbose) { SetGraphVerboseStepLr(verbose); return Maybe::Ok(); } inline bool GetGraphLRVerbose() { return IsOpenGraphVerboseStepLr(); } } // namespace oneflow #endif // ONEFLOW_API_PYTHON_ENV_ENV_H_ ================================================ FILE: oneflow/api/python/ep/cuda_matmul_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/ep/cuda/cuda_matmul_mode.h" namespace py = pybind11; namespace oneflow { namespace ep { ONEFLOW_API_PYBIND11_MODULE("ep", m) { m.def("is_matmul_allow_tf32", &CudaMatmulMode::is_matmul_allow_tf32); m.def("set_matmul_allow_tf32", &CudaMatmulMode::set_matmul_allow_tf32); m.def("is_matmul_allow_fp16_reduced_precision_reduction", &CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction); m.def("set_matmul_allow_fp16_reduced_precision_reduction", &CudaMatmulMode::set_matmul_allow_fp16_reduced_precision_reduction); } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/api/python/exception/exception.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/exception.h" #include "oneflow/core/common/error.h" #include "oneflow/api/python/of_api_registry.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("exception", m) { m.def("GetThreadLocalLastError", &ThreadLocalError); py::register_exception(m, "Exception"); py::register_exception(m, "RuntimeError", PyExc_RuntimeError); py::register_exception(m, "TypeError", PyExc_TypeError); py::register_exception(m, "IndexError", PyExc_IndexError); py::register_exception(m, "NotImplementedError", PyExc_NotImplementedError); } } // namespace oneflow ================================================ FILE: oneflow/api/python/exception/exception.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
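cuda_matmul_mode.cpp exposes the TF32 and FP16 reduced-precision switches for CUDA matmuls under an "ep" submodule. The getters and setters below are taken directly from that binding; reaching them through oneflow._oneflow_internal is an assumption about the compiled extension's import path, and the torch-style backends property is only a hypothetical wrapper:

import oneflow as flow

ep = flow._oneflow_internal.ep
ep.set_matmul_allow_tf32(True)          # allow TF32 tensor cores for float32 matmul
assert ep.is_matmul_allow_tf32()
ep.set_matmul_allow_fp16_reduced_precision_reduction(False)

# Hypothetical torch-style spelling, if the Python layer mirrors it:
# flow.backends.cuda.matmul.allow_tf32 = True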
*/ #ifndef ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_ #define ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_ #include #undef _PyGC_FINALIZED #include #include "oneflow/core/common/exception.h" namespace py = pybind11; #define HANDLE_ERRORS try { #define END_HANDLE_ERRORS_RETSTMT(retstmt) \ } \ catch (py::error_already_set & e) { \ e.restore(); \ retstmt; \ } \ catch (const oneflow::RuntimeException& e) { \ PyErr_SetString(PyExc_RuntimeError, e.what()); \ retstmt; \ } \ catch (const oneflow::IndexException& e) { \ PyErr_SetString(PyExc_IndexError, e.what()); \ retstmt; \ } \ catch (const oneflow::TypeException& e) { \ PyErr_SetString(PyExc_TypeError, e.what()); \ retstmt; \ } \ catch (const oneflow::NotImplementedException& e) { \ PyErr_SetString(PyExc_NotImplementedError, e.what()); \ retstmt; \ } \ catch (const std::exception& e) { \ PyErr_SetString(PyExc_RuntimeError, e.what()); \ retstmt; \ } #define END_HANDLE_ERRORS END_HANDLE_ERRORS_RETSTMT(return NULL) #define END_HANDLE_ERRORS_RET(retval) END_HANDLE_ERRORS_RETSTMT(return retval) #define END_HANDLE_ERRORS_NORET END_HANDLE_ERRORS_RETSTMT(void) #endif // ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_ ================================================ FILE: oneflow/api/python/flags.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
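exception.cpp maps OneFlow's C++ exceptions onto standard Python exception types (RuntimeException to RuntimeError, IndexException to IndexError, TypeException to TypeError, and so on), and the HANDLE_ERRORS / END_HANDLE_ERRORS macros in exception.h perform the same translation inside hand-written C-API functions. In practice this means failed ops raise ordinary Python exceptions; which concrete type you get depends on the failing check, so the sketch below is illustrative rather than exhaustive:

import oneflow as flow

try:
    # A shape mismatch is expected to fail inside the C++ layer and be
    # re-raised through the translators registered above.
    flow.matmul(flow.ones(2, 3), flow.ones(2, 3))
except RuntimeError as e:
    print("caught:", e)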
*/ #include "oneflow/api/python/of_api_registry.h" #ifdef WITH_CUDA #include #endif namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("flags", m) { m.def("with_cuda", []() { #ifdef WITH_CUDA return true; #else return false; #endif // WITH_CUDA }); m.def("with_npu", []() { #ifdef WITH_NPU return true; #else return false; #endif // WITH_NPU }); m.def("with_mlu", []() { #ifdef WITH_MLU return true; #else return false; #endif // WITH_MLU }); m.def("cuda_version", []() { #ifdef WITH_CUDA return CUDA_VERSION; #else return 0; #endif // WITH_CUDA }); m.def("use_cxx11_abi", []() { #if _GLIBCXX_USE_CXX11_ABI == 1 return true; #else return false; #endif // _GLIBCXX_USE_CXX11_ABI }); m.def("with_mlir", []() { #ifdef WITH_MLIR return true; #else return false; #endif // WITH_MLIR }); m.def("with_mlir_cuda_codegen", []() { #ifdef WITH_MLIR_CUDA_CODEGEN return true; #else return false; #endif // WITH_MLIR_CUDA_CODEGEN }); m.def("with_rdma", []() { #ifdef WITH_RDMA return true; #else return false; #endif // WITH_RDMA }); m.def("has_rpc_backend_grpc", []() { #ifdef RPC_BACKEND_GRPC return true; #else return false; #endif // RPC_BACKEND_GRPC }); m.def("has_rpc_backend_local", []() { #ifdef RPC_BACKEND_LOCAL return true; #else return false; #endif // RPC_BACKEND_LOCAL }); #define STRINGIFY(x) STRINGIFY_(x) #define STRINGIFY_(x) #x m.def("cmake_build_type", []() { #ifdef ONEFLOW_CMAKE_BUILD_TYPE return std::string(STRINGIFY(ONEFLOW_CMAKE_BUILD_TYPE)); #else return std::string("Undefined"); #endif // ONEFLOW_CMAKE_BUILD_TYPE }); #undef STRINGIFY #undef STRINGIFY_ } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/autocast.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/autocast.h" namespace py = pybind11; namespace oneflow { size_t* nested_count() { static thread_local size_t _nested_count = 0; return &_nested_count; } bool is_nested_count_zero() { return (*nested_count()) == 0; } void increase_nested_count() { (*nested_count())++; } void decrease_nested_count() { (*nested_count())--; } class AutoCastMode { public: OF_DISALLOW_COPY_AND_MOVE(AutoCastMode); AutoCastMode(const std::string& device_type, Symbol dtype, bool enabled, bool cache_enabled) : prev_enabled_(autocast::is_enabled()), prev_cache_enabled_(autocast::is_autocast_cache_enabled()), prev_device_type_(autocast::get_autocast_device_type()), prev_dtype_(autocast::get_autocast_dtype()), prev_gpu_dtype_(autocast::get_autocast_gpu_dtype()), prev_cpu_dtype_(autocast::get_autocast_cpu_dtype()) { // update autocast state increase_nested_count(); autocast::set_enabled(enabled); autocast::set_autocast_cache_enabled(cache_enabled); if (device_type == "cpu") { autocast::set_autocast_device_type(kCPU); autocast::set_autocast_dtype(dtype); autocast::set_autocast_cpu_dtype(dtype); } else if (device_type == "cuda") { autocast::set_autocast_device_type(kCUDA); autocast::set_autocast_dtype(dtype); autocast::set_autocast_gpu_dtype(dtype); } else { THROW(RuntimeError) << "User specified autocast device_type must be 'cuda' or 'cpu'"; } } ~AutoCastMode() { decrease_nested_count(); autocast::set_enabled(prev_enabled_); autocast::set_autocast_cache_enabled(prev_cache_enabled_); autocast::set_autocast_device_type(prev_device_type_); autocast::set_autocast_dtype(prev_dtype_); autocast::set_autocast_gpu_dtype(prev_gpu_dtype_); autocast::set_autocast_cpu_dtype(prev_cpu_dtype_); if ((!prev_enabled_ || !prev_cache_enabled_) && is_nested_count_zero()) { autocast::clear_cache(); } } private: bool prev_enabled_; bool prev_cache_enabled_; DeviceType prev_device_type_; Symbol prev_dtype_; Symbol prev_gpu_dtype_; Symbol prev_cpu_dtype_; }; ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "AutoCastMode") .def(py::init([](const std::string& device_type, Symbol dtype, bool enabled, bool cache_enabled) { return std::make_shared(device_type, dtype, enabled, cache_enabled); })); m.def("is_autocast_enabled", autocast::is_enabled); m.def("set_autocast_enabled", autocast::set_enabled); m.def("get_autocast_gpu_dtype", autocast::get_autocast_gpu_dtype); m.def("get_autocast_cpu_dtype", autocast::get_autocast_cpu_dtype); m.def("set_autocast_gpu_dtype", autocast::set_autocast_gpu_dtype); m.def("set_autocast_cpu_dtype", autocast::set_autocast_cpu_dtype); m.def("is_autocast_cache_enabled", autocast::is_autocast_cache_enabled); m.def("set_autocast_cache_enabled", autocast::set_autocast_cache_enabled); m.def("clear_autocast_cache", autocast::clear_cache); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/device.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
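autocast.cpp implements an RAII guard: constructing an AutoCastMode saves the previous autocast state, installs the requested device type and dtype, and the destructor restores everything, clearing the weight-cast cache once the outermost guard exits. The Python context manager is expected to wrap this object much like torch.autocast; the flow.autocast spelling below is an assumption about that wrapper:

import oneflow as flow

m = flow.nn.Linear(4, 4).to("cuda")
x = flow.randn(2, 4, device="cuda")

# __enter__/__exit__ correspond to constructing and destroying AutoCastMode,
# so nesting works and the previous enabled/dtype state is restored on exit.
with flow.autocast("cuda", dtype=flow.float16):
    y = m(x)
print(y.dtype)    # low-precision dtype chosen under autocast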
See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/ep/include/device.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_, std::shared_ptr>>(m, "device") .def(py::init([](const std::string& type_or_type_with_device_id) { return Device::ParseAndNew(type_or_type_with_device_id).GetOrThrow(); })) .def(py::init([](const std::string& type, int64_t index) { return Device::New(type, index).GetOrThrow(); }), py::arg("type"), py::arg("index")) .def(py::init([](const Symbol& other_device) { return other_device; })) .def_property_readonly("type", [](const Symbol& d) { return d->type(); }) .def_property_readonly("index", [](const Symbol& d) { return d->device_id(); }) .def_property_readonly("rematable", [](const Symbol& d) { return d->rematable(); }) .def("__str__", [](const Symbol& d) { return d->ToString(); }) .def("__repr__", [](const Symbol& d) { return d->ToRepr(); }) .def(py::self == py::self) .def(py::hash(py::self)); m.def( "max_alignment_size", []() { return ep::kMaxAlignmentRequirement; }, py::return_value_policy::copy); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/doc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
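device.cpp binds oneflow.device with the same construction forms as the overloads above: a single string ("cpu", "cuda", "cuda:1"), a (type, index) pair, or another device object; instances compare by value and are hashable. A short sketch:

import oneflow as flow

d0 = flow.device("cuda:1")
d1 = flow.device("cuda", 1)
d2 = flow.device(d1)             # copy-construct from an existing device

assert d0 == d1 == d2            # value equality, so devices work as dict keys
print(d0.type, d0.index)         # "cuda" 1
print(repr(flow.device("cpu")))  # formatted by Device::ToRepr()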
*/ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/throw.h" namespace py = pybind11; namespace oneflow { py::object AddFunctionDoc(py::object f, const std::string& doc_string) { static std::vector all_doc_strings; all_doc_strings.emplace_back(doc_string); const char* doc_str = all_doc_strings.back().c_str(); PyObject* obj = f.ptr(); if (PyCFunction_Check(obj)) { auto* f = (PyCFunctionObject*)obj; if (f->m_ml->ml_doc) { THROW(RuntimeError) << "function " << f->m_ml->ml_name << " already has a docstring " << "shows: " << f->m_ml->ml_doc; } f->m_ml->ml_doc = doc_str; } else if (PyFunction_Check(obj)) { auto* f = (PyFunctionObject*)obj; if (f->func_doc != Py_None) { THROW(RuntimeError) << "function " << PyBytes_AsString( PyUnicode_AsEncodedString(f->func_name, "utf-8", "~E~")) << " already has a docstring"; } f->func_doc = PyUnicode_FromString(doc_str); } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) { PyMethodDescrObject* f = (PyMethodDescrObject*)obj; if (f->d_method->ml_doc) { THROW(RuntimeError) << "function " << f->d_method->ml_name << "already has a docstring"; } f->d_method->ml_doc = doc_str; } else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) { PyMethodDescrObject* f = (PyMethodDescrObject*)obj; if (f->d_method->ml_doc) { THROW(RuntimeError) << "function " << f->d_method->ml_name << "already has a docstring"; } f->d_method->ml_doc = doc_str; } else if (py::isinstance(f)) { if (py::hasattr(f, "__doc__")) { auto doc = py::getattr(f, "__doc__"); if (!doc.is(py::none())) { THROW(RuntimeError) << Py_TYPE(obj)->tp_name << " already has a docstring"; } } py::setattr(f, "__doc__", py::reinterpret_steal(PyUnicode_FromString(doc_str))); } else if (Py_TYPE(obj)->tp_name == PyProperty_Type.tp_name) { py::setattr(f, "__doc__", py::reinterpret_steal(PyUnicode_FromString(doc_str))); } else if (PyInstanceMethod_Check(obj)) { auto* f = (PyCFunctionObject*)(PyInstanceMethod_Function(obj)); f->m_ml->ml_doc = doc_str; } else { THROW(RuntimeError) << "function is " << Py_TYPE(obj)->tp_name << ", not a valid function"; } f.inc_ref(); return f; } py::object ReplaceDoc(py::object f, const std::string& doc_string) { static std::vector all_doc_strings; all_doc_strings.emplace_back(doc_string); const char* doc_str = all_doc_strings.back().c_str(); PyObject* obj = f.ptr(); if (PyCFunction_Check(obj)) { auto* f = (PyCFunctionObject*)obj; if (!f->m_ml->ml_doc) { THROW(RuntimeError) << "function " << f->m_ml->ml_name << " has not a docstring yet."; } f->m_ml->ml_doc = doc_str; } else if (PyFunction_Check(obj)) { auto* f = (PyFunctionObject*)obj; if (f->func_doc == Py_None) { THROW(RuntimeError) << "function " << PyBytes_AsString( PyUnicode_AsEncodedString(f->func_name, "utf-8", "~E~")) << " has not a docstring yet."; } Py_DECREF(f->func_doc); f->func_doc = PyUnicode_FromString(doc_str); } else { THROW(RuntimeError) << "function is " << Py_TYPE(obj)->tp_name << ", not a valid function."; } f.inc_ref(); return f; } } // namespace oneflow ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("add_doc", &oneflow::AddFunctionDoc); m.def("reset_doc", &oneflow::ReplaceDoc); } ================================================ FILE: oneflow/api/python/framework/dtype.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
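doc.cpp lets the Python layer attach docstrings to objects that cannot simply have __doc__ assigned (builtin functions, method descriptors, properties); add_doc refuses to overwrite an existing docstring and reset_doc replaces one. A sketch against a plain Python function, assuming the two bindings are reachable on the compiled extension module:

import oneflow as flow

internal = flow._oneflow_internal

def relu_alias(x):
    return flow.relu(x)

# Attach a docstring; calling add_doc again on the same object raises RuntimeError.
internal.add_doc(relu_alias, "Alias of oneflow.relu, added for illustration.")
print(relu_alias.__doc__)

# reset_doc only accepts objects that already carry a docstring.
internal.reset_doc(relu_alias, "Replacement docstring.")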
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/framework/tensortype.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/core/framework/dtype.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_, std::shared_ptr>>(m, "dtype") .def_property_readonly("is_signed", [](const Symbol& d) { return d->is_signed(); }) .def_property_readonly("is_complex", [](const Symbol& d) { return d->is_complex(); }) .def_property_readonly("is_floating_point", [](const Symbol& d) { return d->is_floating_point(); }) .def("__str__", [](const Symbol& d) { return d->name(); }) .def("__repr__", [](const Symbol& d) { return d->name(); }) .def(py::self == py::self) .def(py::hash(py::self)) .def(py::pickle( [](const Symbol& dtype) { // __getstate__ return static_cast(dtype->data_type()); }, [](int t) { // __setstate__ return CHECK_JUST(DType::Get(DataType(t))); })) .def_property_readonly("bytes", [](const Symbol& dtype) { return dtype->bytes(); }) .def("get", [](const int data_type_enum) { return CHECK_JUST(DType::Get(static_cast(data_type_enum))); }); m.attr("bool") = &CHECK_JUST(DType::Get(DataType::kBool)); m.attr("char") = &CHECK_JUST(DType::Get(DataType::kChar)); m.attr("float16") = &CHECK_JUST(DType::Get(DataType::kFloat16)); m.attr("float") = &CHECK_JUST(DType::Get(DataType::kFloat)); m.attr("float32") = &CHECK_JUST(DType::Get(DataType::kFloat)); m.attr("double") = &CHECK_JUST(DType::Get(DataType::kDouble)); m.attr("float64") = &CHECK_JUST(DType::Get(DataType::kDouble)); m.attr("int8") = &CHECK_JUST(DType::Get(DataType::kInt8)); m.attr("int32") = &CHECK_JUST(DType::Get(DataType::kInt32)); m.attr("int64") = &CHECK_JUST(DType::Get(DataType::kInt64)); m.attr("uint8") = &CHECK_JUST(DType::Get(DataType::kUInt8)); m.attr("record") = &CHECK_JUST(DType::Get(DataType::kOFRecord)); m.attr("tensor_buffer") = &CHECK_JUST(DType::Get(DataType::kTensorBuffer)); m.attr("bfloat16") = &CHECK_JUST(DType::Get(DataType::kBFloat16)); m.attr("uint16") = &CHECK_JUST(DType::Get(DataType::kUInt16)); m.attr("uint32") = &CHECK_JUST(DType::Get(DataType::kUInt32)); m.attr("uint64") = &CHECK_JUST(DType::Get(DataType::kUInt64)); m.attr("uint128") = &CHECK_JUST(DType::Get(DataType::kUInt128)); m.attr("int16") = &CHECK_JUST(DType::Get(DataType::kInt16)); m.attr("int128") = &CHECK_JUST(DType::Get(DataType::kInt128)); m.attr("complex32") = &CHECK_JUST(DType::Get(DataType::kComplex32)); m.attr("chalf") = &CHECK_JUST(DType::Get(DataType::kComplex32)); m.attr("complex64") = &CHECK_JUST(DType::Get(DataType::kComplex64)); m.attr("cfloat") = &CHECK_JUST(DType::Get(DataType::kComplex64)); m.attr("complex128") = &CHECK_JUST(DType::Get(DataType::kComplex128)); m.attr("cdouble") = &CHECK_JUST(DType::Get(DataType::kComplex128)); m.attr("char") = &CHECK_JUST(DType::Get(DataType::kChar)); m.attr("short") = &CHECK_JUST(DType::Get(DataType::kInt16)); py::options options; options.disable_function_signatures(); m.def("get_default_dtype", []() { return GetDefaultDType(); }); m.def("set_default_dtype", [](const Symbol& dtype) { SetDefaultDType(dtype).GetOrThrow(); }); 
m.def("set_default_tensor_type", [](const py::object& tensor_type) { if (one::PyTensorType_Check(tensor_type.ptr())) { CHECK_JUST(SetDefaultDType(one::PyTensorType_UnpackDType(tensor_type.ptr()))); } else { throw py::type_error("invalid type object"); } }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/framework.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/api/python/framework/framework.h" #include "oneflow/core/framework/load_library.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("GetSerializedCurrentJob", []() -> Maybe { return py::bytes(*JUST(GetSerializedCurrentJob())); }); m.def("GetFunctionConfigDef", &GetFunctionConfigDef); m.def("GetScopeConfigDef", &GetScopeConfigDef); m.def("LoadLibrary", &LoadLibrary); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/framework.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_ #define ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_ #include #include #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/inter_user_job_info.pb.h" #include "oneflow/core/job/job_instance.h" #include "oneflow/core/job/oneflow.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/framework/config_def.h" #include "oneflow/core/framework/load_library.h" namespace oneflow { inline Maybe GetSerializedCurrentJob() { auto* job_ctx_mgr = Singleton::Get(); CHECK_NOTNULL_OR_RETURN(job_ctx_mgr); auto* job_ctx = JUST(job_ctx_mgr->FindJobBuildAndInferCtx(*JUST(job_ctx_mgr->GetCurrentJobName()))); CHECK_NOTNULL_OR_RETURN(job_ctx); return job_ctx->job().SerializeAsString(); } inline Maybe GetFunctionConfigDef() { std::string ret; google::protobuf::TextFormat::PrintToString(GlobalFunctionConfigDef(), &ret); return ret; } inline Maybe GetScopeConfigDef() { std::string ret; google::protobuf::TextFormat::PrintToString(GlobalScopeConfigDef(), &ret); return ret; } inline Maybe GetSerializedMachineId2DeviceIdListOFRecord( const std::string& parallel_conf_str) { ParallelConf parallel_conf; CHECK_OR_RETURN(TxtString2PbMessage(parallel_conf_str, ¶llel_conf)) << "parallel conf parse failed"; return PbMessage2TxtString(*JUST(ParseMachineAndDeviceIdList(parallel_conf))); } inline Maybe LoadLibraryNow(const std::string& lib_path) { return LoadLibrary(lib_path); } } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_ ================================================ FILE: oneflow/api/python/framework/global_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/global_mode.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("global_view", m) { py::class_>(m, "global_mode") .def(py::init([](const bool enabled) { if (enabled) { THROW(RuntimeError) << "To enable global mode, placement and sbp must be provided."; } return std::make_shared(enabled); })) .def(py::init([](const bool enabled, const Symbol& placement, const std::vector>& sbp) { if (!enabled) { THROW(RuntimeError) << "To disable global mode, placement and sbp must not be provided."; } return std::make_shared(enabled, CHECK_JUST(GetNdSbp(sbp)), placement); }), py::arg("enabled").none(false), py::arg("placement").none(false), py::arg("sbp").none(false)) .def(py::init([](const bool enabled, const Symbol& placement, const Symbol& sbp) { return std::make_shared(enabled, CHECK_JUST(SbpToNdSbp(sbp)), placement); }), py::arg("enabled").none(false), py::arg("placement").none(false), py::arg("sbp").none(false)) .def("__enter__", [](const GlobalMode::Guard& guard_obj) {}) .def("__exit__", [](const GlobalMode::Guard& guard_obj, const py::object& type, const py::object& value, const py::object& traceback) {}); py::class_>(m, "current_global_mode") .def(py::init([]() { return std::make_shared(); })) .def_property_readonly("is_enabled", [](const GlobalMode& gm) { return gm.is_enabled(); }) .def_property_readonly("sbp", [](const GlobalMode& gm) { if (!gm.is_enabled()) { THROW(RuntimeError) << "Current global mode is disabled, there is no sbp."; } const auto& nd_sbp = gm.nd_sbp(); auto tuple = py::tuple(nd_sbp->sbp_parallel_size()); for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) { tuple[i] = SymbolOf(nd_sbp->sbp_parallel(i)); } return tuple; }) .def_property_readonly("placement", [](const GlobalMode& gm) { if (!gm.is_enabled()) { THROW(RuntimeError) << "Current global mode is disabled, there is no placement."; } return gm.parallel_desc(); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/id_state.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
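global_mode.cpp backs the global-view guard: enabling it requires a placement and sbp, disabling it forbids them, and current_global_mode reports the active placement/sbp while the guard is live. The import path below follows docs/source/utils.global_view.rst; the exact Python-side re-exports are an assumption, while the underlying objects are the global_mode / current_global_mode classes bound above:

import oneflow as flow
from oneflow.utils import global_view

placement = flow.placement("cpu", ranks=[0])

# Inside the guard, tensor creation and ops are interpreted globally with the
# given placement and sbp; leaving the guard restores the previous mode.
with global_view.global_mode(True, placement=placement, sbp=flow.sbp.broadcast):
    cur = global_view.current_global_mode()
    print(cur.is_enabled)            # True
    print(cur.placement, cur.sbp)

# Passing enabled=True without placement/sbp (or enabled=False with them)
# raises RuntimeError, matching the checks in the constructors above.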
*/ #include #include #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/job/id_state.h" namespace py = pybind11; ONEFLOW_API_PYBIND11_MODULE("", m) { using namespace oneflow; py::class_(m, "IdState") .def(py::init<>()) .def_readwrite("regst_desc_id_state", &IdState::regst_desc_id_state_) .def_readwrite("mem_block_id_state", &IdState::mem_block_id_state_) .def_readwrite("chunk_id_state", &IdState::chunk_id_state_) .def_readwrite("job_id_state", &IdState::job_id_state_) .def_readwrite("task_index_state", &IdState::task_index_state_) .def_readwrite("stream_index_state", &IdState::stream_index_state_) // support pickle .def(py::pickle( [](const IdState& id_state) { return py::make_tuple(id_state.regst_desc_id_state_, id_state.mem_block_id_state_, id_state.chunk_id_state_, id_state.job_id_state_, id_state.task_index_state_, id_state.stream_index_state_); }, [](const py::tuple& t) { CHECK(t.size() == 6); IdState id_state; id_state.regst_desc_id_state_ = t[0].cast(); id_state.mem_block_id_state_ = t[1].cast(); id_state.chunk_id_state_ = t[2].cast(); id_state.job_id_state_ = t[3].cast(); id_state.task_index_state_ = t[4].cast>(); id_state.stream_index_state_ = t[5].cast>(); return id_state; })); m.def("set_id_state", [](const IdState& id_state) { Singleton::Get()->SetIdState(id_state); }); m.def("get_id_state", []() { return Singleton::Get()->GetIdState(); }); } ================================================ FILE: oneflow/api/python/framework/id_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/id_util.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("UniqueStr", &UniqueStr); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/instructions_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include #include "oneflow/api/python/framework/size.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor.h" namespace py = pybind11; namespace oneflow { namespace { Maybe DeprecatedPhysicalRun(const std::function& Build) { return PhysicalRun([&](InstructionsBuilder* instruction_builder) -> Maybe { Build(instruction_builder); return Maybe::Ok(); }); } } // namespace ONEFLOW_API_PYBIND11_MODULE("deprecated", m) { py::class_>(m, "InstructionsBuilder") .def( "BuildInitialScope", [](const std::shared_ptr& builder, int64_t session_id, const std::string& job_conf_str, const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy, bool is_local) -> Maybe { JobConfigProto job_conf; CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << Error::RuntimeError() << "job conf parse failed"; return builder->BuildInitialScope(session_id, job_conf, device_tag, machine_device_ids, hierarchy, is_local); }, py::arg("session_id").none(false), py::arg("job_conf_str").none(false), py::arg("device_tag").none(false), py::arg("machine_device_ids").none(false), py::arg("hierarchy").none(true), py::arg("is_local").none(false)) .def( "BuildInitialScopeWithPlacement", [](const std::shared_ptr& builder, int64_t session_id, const std::string& job_conf_str, Symbol placement, bool is_local) -> Maybe { JobConfigProto job_conf; CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << Error::RuntimeError() << "job conf parse failed"; return builder->BuildInitialScopeWithPlacement(session_id, job_conf, placement, is_local); }, py::arg("session_id").none(false), py::arg("job_conf_str").none(false), py::arg("placement").none(false), py::arg("is_local").none(false)) .def("BuildScopeWithNewParallelDesc", &InstructionsBuilder::BuildScopeWithNewParallelDesc, py::arg("scope").none(false), py::arg("device_tag").none(false), py::arg("machine_device_ids").none(false), py::arg("hierarchy").none(true)) .def("BuildScopeWithNewParallelConf", [](const std::shared_ptr& builder, const std::shared_ptr& scope, const std::string& parallel_conf_str) -> Maybe { ParallelConf parallel_conf; CHECK_OR_RETURN(TxtString2PbMessage(parallel_conf_str, ¶llel_conf)) << Error::RuntimeError() << "parallel conf parse failed"; return builder->BuildScopeWithNewParallelConf(scope, parallel_conf); }) .def("BuildScopeWithNewIsLocal", &InstructionsBuilder::BuildScopeWithNewIsLocal) .def("BuildScopeWithNewScopeName", &InstructionsBuilder::BuildScopeWithNewScopeName) .def("BuildScopeByProtoStrSetter", &InstructionsBuilder::BuildScopeByProtoStrSetter); m.def("PhysicalRun", &DeprecatedPhysicalRun, py::call_guard()); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/layout.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/framework/tensortype.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/layout.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "layout") .def("__str__", [](Symbol d) { return d->name(); }) .def("__repr__", [](Symbol d) { return d->name(); }) .def(py::self == py::self) .def(py::hash(py::self)) .def(py::pickle( [](Symbol layout) { // __getstate__ return static_cast(layout->layout_type()); }, [](int t) { // __setstate__ return Layout::Get(LayoutType(t)); })) .def("get", [](const int layout_type_enum) { return Layout::Get(static_cast(layout_type_enum)); }); m.attr("strided") = Layout::Get(LayoutType::kStrided); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/memory_format.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/framework/memory_format.h" namespace py = pybind11; namespace oneflow { static PyObject* PyMemoryFormat_repr(PyMemoryFormatObject* self) { auto memory_format = PyMemoryFormat_Unpack((PyObject*)self); if (memory_format == MemoryFormat::kContiguous) { return PyUnicode_FromString("oneflow.contiguous_format"); } else if (memory_format == MemoryFormat::kChannelsLast) { return PyUnicode_FromString("oneflow.channels_last"); } else if (memory_format == MemoryFormat::kPreserve) { return PyUnicode_FromString("oneflow.preserve_format"); } else { THROW(TypeError) << "invalid memory format"; return nullptr; } } PyTypeObject PyMemoryFormat_Type = { PyVarObject_HEAD_INIT(NULL, 0) "oneflow.memory_format", /* tp_name */ sizeof(PyMemoryFormatObject), /* tp_basicsize */ 0, /* tp_itemsize */ NULL, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ NULL, /* tp_getattr */ NULL, /* tp_setattr */ NULL, /* tp_reserved */ (reprfunc)PyMemoryFormat_repr, /* tp_repr */ NULL, /* tp_as_number */ NULL, /* tp_as_sequence */ NULL, /* tp_as_mapping */ NULL, /* tp_hash */ NULL, /* tp_call */ NULL, /* tp_str */ NULL, /* tp_getattro */ NULL, /* tp_setattro */ NULL, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ }; bool PyMemoryFormat_Check(PyObject* self) { return self && self->ob_type == &PyMemoryFormat_Type; } PyObject* PyMemoryFormat_New(MemoryFormat memory_format) { auto* self = (PyMemoryFormatObject*)PyMemoryFormat_Type.tp_alloc(&PyMemoryFormat_Type, 0); self->memory_format = memory_format; return (PyObject*)self; } static PyObject* PyMemoryFormat_contiguous = nullptr; static PyObject* PyMemoryFormat_channels_last = nullptr; static PyObject* PyMemoryFormat_preserve = nullptr; ONEFLOW_API_PYBIND11_MODULE("", m) { if (PyType_Ready(&PyMemoryFormat_Type) < 0) { return; } 
Py_INCREF(&PyMemoryFormat_Type); if (PyModule_AddObject(m.ptr(), "memory_format", (PyObject*)&PyMemoryFormat_Type) < 0) { return; } PyMemoryFormat_contiguous = PyMemoryFormat_New(MemoryFormat::kContiguous); PyMemoryFormat_channels_last = PyMemoryFormat_New(MemoryFormat::kChannelsLast); PyMemoryFormat_preserve = PyMemoryFormat_New(MemoryFormat::kPreserve); if (PyModule_AddObject(m.ptr(), "contiguous_format", PyMemoryFormat_contiguous) < 0) { return; } if (PyModule_AddObject(m.ptr(), "channels_last", PyMemoryFormat_channels_last) < 0) { return; } if (PyModule_AddObject(m.ptr(), "preserve_format", PyMemoryFormat_preserve) < 0) { return; } } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/memory_format.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_FRAMEWORK_MEMORY_FORMAT_H_ #define ONEFLOW_API_PYTHON_FRAMEWORK_MEMORY_FORMAT_H_ #include #undef _PyGC_FINALIZED #include #include "oneflow/core/common/memory_format.pb.h" namespace oneflow { typedef struct PyMemoryFormatObject { PyTypeObject ob_type; MemoryFormat memory_format; } PyMemoryFormatObject; bool PyMemoryFormat_Check(PyObject*); inline MemoryFormat PyMemoryFormat_Unpack(PyObject* self) { return ((PyMemoryFormatObject*)self)->memory_format; } PyObject* PyMemoryFormat_New(MemoryFormat memory_format); } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FRAMEWORK_MEMORY_FORMAT_H_ ================================================ FILE: oneflow/api/python/framework/nn_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
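memory_format.cpp defines a small CPython type for memory formats and installs three singletons (contiguous_format, channels_last, preserve_format) on the module, mirroring torch.memory_format. Assuming the package re-exports them at the top level, they act as opaque enum-like constants:

import oneflow as flow

print(flow.contiguous_format)    # "oneflow.contiguous_format" per PyMemoryFormat_repr
print(flow.channels_last)        # "oneflow.channels_last"
print(flow.preserve_format)      # "oneflow.preserve_format"
print(type(flow.channels_last))  # the oneflow.memory_format type registered above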
*/ #include #include #include #include #include "oneflow/api/python/job_build/job_build_and_infer.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/nn_graph.h" #include "oneflow/core/job/runtime.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/job_ir.h" #include "oneflow/core/job/job_interpreter.h" namespace py = pybind11; namespace oneflow { namespace { Maybe APINNGraphAdditionalVarNames(const std::shared_ptr& graph) { const auto names = *JUST(graph->GetAdditionalVarOpNames()); py::list name_list = py::cast(names); return py::cast(name_list); } Maybe APINNGraphAdditionalVarTensors(const std::shared_ptr& graph) { const auto tensors = *JUST(graph->GetAdditionalVarOpTensors()); py::list tensor_list = py::cast(tensors); return py::cast(tensor_list); } Maybe APINNGraphGetCurrentSerializedJob(const std::shared_ptr& graph) { const auto job = graph->job(); return py::bytes(job.SerializeAsString()); } } // namespace ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) { using namespace oneflow; py::class_>(m, "CNNGraph") .def(py::init([](const std::string& name, const std::string& serialized_job, int64_t job_id, const std::shared_ptr& session_ctx) { Job job; if (!job.ParseFromString(serialized_job)) { PyErr_SetString(PyExc_TypeError, "The second argument is not a valid job"); } return std::make_shared(name, job, job_id, session_ctx); })) .def(py::init([](const std::string& name, const std::string& serialized_plan, int64_t job_id, const std::shared_ptr& session_ctx, bool init_from_plan) { if (!init_from_plan) { PyErr_SetString( PyExc_TypeError, "init_from_plan must be True when init CNNGraph with this bool parameter."); } Plan plan; if (!plan.ParseFromString(serialized_plan)) { PyErr_SetString(PyExc_TypeError, "The second argument is not a valid plan"); } return std::make_shared(name, plan, job_id, session_ctx); })) .def_property_readonly("name", &NNGraph::job_name) .def_property( "job", /*getter*/ [](const NNGraph& nn_graph) { return py::bytes(nn_graph.job().SerializeAsString()); }, /*setter*/ [](NNGraph& nn_graph, const std::string& serialized_job) { Job job; if (!job.ParseFromString(serialized_job)) { PyErr_SetString(PyExc_TypeError, "the value is not a valid job"); } nn_graph.restore_job(job); }) .def_property("job_id", &NNGraph::job_id, [](NNGraph& nn_graph, int64_t job_id) { nn_graph.restore_job_id(job_id); }) .def_property( "plan", /*getter*/ [](const NNGraph& nn_graph) { return py::bytes(nn_graph.plan().SerializeAsString()); }, /*setter*/ [](NNGraph& nn_graph, const std::string& serialized_plan) { Plan plan; if (!plan.ParseFromString(serialized_plan)) { PyErr_SetString(PyExc_TypeError, "the value is not a valid plan"); } nn_graph.restore_plan(plan); }) .def("register_input_op_names_and_tensors", &NNGraph::RegisterInputOpNamesAndTensors) .def("register_output_op_names_and_tensors", &NNGraph::RegisterOutputOpNamesAndTensors) .def("register_variable_op_names_and_tensors", &NNGraph::RegisterVariableOpNamesAndTensors) .def("register_additional_variable_names_and_tensors", &NNGraph::RegisterAdditionalVarOpNamesAndTensorsToBeLoaded) .def_property_readonly("additional_var_names", &APINNGraphAdditionalVarNames) .def_property_readonly("additional_var_tensors", &APINNGraphAdditionalVarTensors) .def("align_states_after_logical_graph_compile", &NNGraph::AlignStatesAfterLogicalGraphCompile) 
.def("complete_graph_for_runtime", &NNGraph::CompleteLogicalGraphForRuntime) .def("build_with_new_input_from_shared_graph", &NNGraph::BuildWithNewInputFromSharedGraph) .def("compile_plan_for_runtime", &NNGraph::CompilePlanForRuntime) .def("init_runtime", &NNGraph::InitRuntime) .def("get_current_job_str", &APINNGraphGetCurrentSerializedJob); m.def("RunLazyNNGraph", &RunLazyNNGraph); m.def("RunLazyNNGraphByVM", &one::InterpretJob); m.def("SoftSyncNNGraphBuffers", &SoftSyncNNGraphBuffers); m.def("AddTensorAsGraphLoss", &AddTensorAsGraphLoss); m.def("MarkVariableGradients", [](const std::vector>& variables, const std::vector>& gradients) { one::TensorTuple variable_tuple(variables.size()); one::TensorTuple gradient_tuple(gradients.size()); for (int i = 0; i < variables.size(); ++i) { variable_tuple[i] = variables[i]; } for (int i = 0; i < gradients.size(); ++i) { gradient_tuple[i] = gradients[i]; } return MarkVariableGradients(variable_tuple, gradient_tuple); }); m.def("ConvertJobToTosaIR", [](const std::string& serialized_job) -> Maybe { Job job; CHECK_OR_RETURN(job.ParseFromString(serialized_job)) << "serialized job conversion failed."; return ConvertJobToTosaIR(&job); }); m.def( "SaveJobToIR", [](const std::string& serialized_job, const std::string& path) -> Maybe { Job job; CHECK_OR_RETURN(job.ParseFromString(serialized_job)) << "serialized job conversion failed."; return SaveJobToIR(&job, path); }); m.def("ConvertJobToIR", [](const std::string& serialized_job) -> Maybe { Job job; CHECK_OR_RETURN(job.ParseFromString(serialized_job)) << "serialized job conversion failed."; return ConvertJobToIR(&job); }); m.def("LoadSerializedJobFromIR", [](const std::string& path) -> Maybe { Job job; JUST(LoadJobFromIR(&job, path)); return py::bytes(job.SerializeAsString()); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/one_embedding.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/core/embedding/persistent_table.h" #include "oneflow/core/embedding/hash_functions.cuh" #include "oneflow/core/framework/dtype.h" namespace py = pybind11; namespace oneflow { class OneEmbeddingHandler final { public: OneEmbeddingHandler(const std::string& key_value_store_option_string, int64_t local_rank_id, int64_t rank_id, int64_t world_size) : local_rank_id_(local_rank_id), rank_id_(rank_id), world_size_(world_size) { embedding::KeyValueStoreOptions key_value_store_options(key_value_store_option_string); embedding_name_ = key_value_store_options.Name(); CreateKeyValueStore(key_value_store_options); } void LoadSnapshot(const std::string& snapshot_name) { #ifdef WITH_CUDA Singleton::Get()->LoadSnapshot(embedding_name_, local_rank_id_, rank_id_, snapshot_name); #else UNIMPLEMENTED() << "Only Support with CUDA"; #endif } void SaveSnapshot(const std::string& snapshot_name) { #ifdef WITH_CUDA Singleton::Get()->SaveSnapshot(embedding_name_, local_rank_id_, rank_id_, snapshot_name); #else UNIMPLEMENTED() << "Only Support with CUDA"; #endif } private: void CreateKeyValueStore(const embedding::KeyValueStoreOptions& key_value_store_options) { #ifdef WITH_CUDA Singleton::Get()->CreateKeyValueStore( key_value_store_options, local_rank_id_, rank_id_, world_size_); #else UNIMPLEMENTED() << "Only Support with CUDA"; #endif } std::string embedding_name_; int64_t local_rank_id_; int64_t rank_id_; int64_t world_size_; }; namespace embedding { class PersistentTableWriter { public: OF_DISALLOW_COPY_AND_MOVE(PersistentTableWriter); PersistentTableWriter() = default; virtual ~PersistentTableWriter() = default; virtual void Write(const py::array& keys, const py::array& values) = 0; virtual void Close() = 0; }; template class PersistentTableWriterImpl : public PersistentTableWriter { public: OF_DISALLOW_COPY_AND_MOVE(PersistentTableWriterImpl); PersistentTableWriterImpl(const std::vector& paths, const std::string& snapshot_name, uint32_t storage_dim, uint64_t target_chunk_size_mb, uint16_t physical_block_size) : closed_(false), snapshot_name_(snapshot_name), storage_dim_(storage_dim) { tables_.resize(paths.size()); for (size_t i = 0; i < paths.size(); ++i) { PersistentTableOptions options; options.path = paths[i]; options.key_size = sizeof(Key); options.value_size = storage_dim * sizeof(Value); options.target_chunk_size_mb = target_chunk_size_mb; options.physical_block_size = physical_block_size; tables_[i] = NewPersistentTable(options); } } ~PersistentTableWriterImpl() override { CloseImpl(); } void Write(const py::array& keys, const py::array& values) override { pybind11::dtype::of().equal(pybind11::dtype::of()); CHECK(!closed_) << "Write on closed table"; CHECK_EQ(keys.ndim(), 1); CHECK_EQ(values.ndim(), 2); CHECK_EQ(keys.shape(0), values.shape(0)); CHECK_EQ(values.shape(1), storage_dim_); CHECK(keys.dtype().equal(py::dtype::of())); CHECK(values.dtype().equal(py::dtype::of())); const size_t n = keys.size(); std::vector> keys_buffers(tables_.size()); std::vector> values_buffers(tables_.size()); for (size_t i = 0; i < n; ++i) { const Key key = *(reinterpret_cast(keys.template data(i))); const uint32_t shard = ShardingHash()(key) % tables_.size(); keys_buffers[shard].push_back(key); const size_t values_offset = values_buffers[shard].size(); values_buffers[shard].resize(values_offset + storage_dim_ * sizeof(Value)); for (size_t j = 0; j < values.shape(1); 
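// NOTE(editor): this inner loop copies one row of `values`, element by
// element, into the byte buffer of the shard selected just above with
//
//   const uint32_t shard = ShardingHash()(key) % tables_.size();
//
// so rows whose keys hash to the same shard end up contiguous in
// keys_buffers[shard] / values_buffers[shard], and each shard is then flushed
// with a single tables_[shard]->Put(...) call after this loop.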
++j) { std::memcpy(values_buffers[shard].data() + values_offset + j * values.itemsize(), values.template data(i, j), values.itemsize()); } } for (size_t shard = 0; shard < tables_.size(); ++shard) { tables_[shard]->Put(keys_buffers[shard].size(), keys_buffers[shard].data(), values_buffers[shard].data()); } } void Close() override { CloseImpl(); } private: void CloseImpl() { if (!closed_) { for (auto& table : tables_) { table->SaveSnapshot(snapshot_name_); table.reset(); } } closed_ = true; } bool closed_; std::string snapshot_name_; std::vector> tables_; uint32_t storage_dim_; }; template std::shared_ptr NewPersistentTableWriter( const std::vector& paths, const std::string& snapshot_name, const Symbol& key_type, const Symbol& value_type, uint32_t storage_dim, uint64_t target_chunk_size_mb, uint16_t physical_block_size) { if (value_type->data_type() == DataType::kFloat) { return std::shared_ptr(new PersistentTableWriterImpl( paths, snapshot_name, storage_dim, target_chunk_size_mb, physical_block_size)); } else { UNIMPLEMENTED(); } } std::shared_ptr NewPersistentTableWriter( const std::vector& paths, const std::string& snapshot_name, const Symbol& key_type, const Symbol& value_type, uint32_t storage_dim, uint64_t target_chunk_size_mb, uint16_t physical_block_size) { if (key_type->data_type() == DataType::kInt32) { return NewPersistentTableWriter(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); } else if (key_type->data_type() == DataType::kUInt32) { return NewPersistentTableWriter(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); } else if (key_type->data_type() == DataType::kInt64) { return NewPersistentTableWriter(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); } else if (key_type->data_type() == DataType::kUInt64) { return NewPersistentTableWriter(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); } else { UNIMPLEMENTED(); return std::shared_ptr(nullptr); } } class PersistentTableReader { public: OF_DISALLOW_COPY_AND_MOVE(PersistentTableReader); PersistentTableReader() = default; virtual ~PersistentTableReader() = default; virtual std::tuple Next() = 0; virtual void Close() = 0; }; template class PersistentTableReaderImpl : public PersistentTableReader { public: constexpr static uint32_t kBatchSize = 65536; OF_DISALLOW_COPY_AND_MOVE(PersistentTableReaderImpl); PersistentTableReaderImpl(const std::vector& paths, const std::string& snapshot_name, uint32_t storage_dim, uint64_t target_chunk_size_mb, uint16_t physical_block_size) : closed_(false), snapshot_name_(snapshot_name), storage_dim_(storage_dim), current_table_(0) { tables_.resize(paths.size()); iterators_.resize(paths.size()); for (size_t i = 0; i < paths.size(); ++i) { PersistentTableOptions options; options.path = paths[i]; options.key_size = sizeof(Key); options.value_size = storage_dim * sizeof(Value); options.target_chunk_size_mb = target_chunk_size_mb; options.physical_block_size = physical_block_size; options.read_only = true; tables_[i] = NewPersistentTable(options); iterators_[i] = std::unique_ptr(tables_[i]->ReadSnapshot(snapshot_name)); } keys_buffer_.resize(kBatchSize); values_buffer_.resize(kBatchSize * storage_dim_); } ~PersistentTableReaderImpl() override { CloseImpl(); } std::tuple Next() override { while (current_table_ < tables_.size()) { uint32_t n_result = 0; iterators_[current_table_]->Next(kBatchSize, 
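// NOTE(editor): Next() drains the snapshot table by table. Each call asks the
// current iterator for at most kBatchSize (65536) entries; a non-empty batch
// is copied into freshly allocated numpy arrays and returned, an empty batch
// advances current_table_, and once every table is exhausted the method throws
// py::stop_iteration(), which pybind11 translates into Python's StopIteration
// so the reader can drive an ordinary for-loop.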
&n_result, keys_buffer_.data(), values_buffer_.data()); if (n_result != 0) { py::array_t keys_arr(py::array::ShapeContainer({n_result})); py::array_t values_arr(py::array::ShapeContainer({n_result, storage_dim_})); std::memcpy(keys_arr.mutable_data(), keys_buffer_.data(), n_result * sizeof(Key)); std::memcpy(values_arr.mutable_data(), values_buffer_.data(), n_result * storage_dim_ * sizeof(Value)); return std::make_tuple(keys_arr, values_arr); } else { current_table_ += 1; continue; } } throw py::stop_iteration(); } void Close() override { CloseImpl(); } private: void CloseImpl() { if (!closed_) { for (auto& table : tables_) { table.reset(); } } closed_ = true; } bool closed_; std::string snapshot_name_; std::vector> tables_; std::vector> iterators_; uint32_t storage_dim_; size_t current_table_; std::vector keys_buffer_; std::vector values_buffer_; }; template std::shared_ptr NewPersistentTableReader( const std::vector& paths, const std::string& snapshot_name, const Symbol& key_type, const Symbol& value_type, uint32_t storage_dim, uint64_t target_chunk_size_mb, uint16_t physical_block_size) { if (value_type->data_type() == DataType::kFloat) { return std::shared_ptr(new PersistentTableReaderImpl( paths, snapshot_name, storage_dim, target_chunk_size_mb, physical_block_size)); } else { UNIMPLEMENTED(); } } std::shared_ptr NewPersistentTableReader( const std::vector& paths, const std::string& snapshot_name, const Symbol& key_type, const Symbol& value_type, uint32_t storage_dim, uint64_t target_chunk_size_mb, uint16_t physical_block_size) { if (key_type->data_type() == DataType::kInt32) { return NewPersistentTableReader(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); } else if (key_type->data_type() == DataType::kUInt32) { return NewPersistentTableReader(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); } else if (key_type->data_type() == DataType::kInt64) { return NewPersistentTableReader(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); } else if (key_type->data_type() == DataType::kUInt64) { return NewPersistentTableReader(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); } else { UNIMPLEMENTED(); return std::shared_ptr(nullptr); } } } // namespace embedding ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "OneEmbeddingHandler") .def(py::init([](const std::string& key_value_store_option_str, const int64_t local_rank_id, const int64_t rank_id, const int64_t world_size) { return std::make_shared(key_value_store_option_str, local_rank_id, rank_id, world_size); })) .def("SaveSnapshot", &OneEmbeddingHandler::SaveSnapshot) .def("LoadSnapshot", &OneEmbeddingHandler::LoadSnapshot); py::class_>( m, "PersistentTableWriter") .def(py::init([](const std::vector& paths, const std::string& snapshot_name, const Symbol& key_type, const Symbol& value_type, uint32_t storage_dim, uint64_t target_chunk_size_mb, uint16_t physical_block_size) { return embedding::NewPersistentTableWriter(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); })) .def("__enter__", [](embedding::PersistentTableWriter* writer) { return writer; }) .def("__exit__", [](embedding::PersistentTableWriter* writer, const py::object& exc_type, const py::object& exc_val, const py::object& exc_tb) { writer->Close(); }) .def("write", &embedding::PersistentTableWriter::Write) .def("close", 
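// NOTE(editor): the __enter__/__exit__ pair above gives the writer Python's
// context-manager protocol: leaving a `with` block calls Close(), which saves
// a snapshot named snapshot_name_ for every shard table exactly once (the
// closed_ flag makes a later Close() or the destructor a no-op). The reader
// bound below gets the same __enter__/__exit__ treatment plus __iter__ and
// __next__ for batch-wise iteration.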
&embedding::PersistentTableWriter::Close); py::class_>( m, "PersistentTableReader") .def(py::init([](const std::vector& paths, const std::string& snapshot_name, const Symbol& key_type, const Symbol& value_type, uint32_t storage_dim, uint64_t target_chunk_size_mb, uint16_t physical_block_size) { return embedding::NewPersistentTableReader(paths, snapshot_name, key_type, value_type, storage_dim, target_chunk_size_mb, physical_block_size); })) .def("__next__", &embedding::PersistentTableReader::Next) .def("__iter__", [](embedding::PersistentTableReader* reader) { return reader; }) .def("__enter__", [](embedding::PersistentTableReader* reader) { return reader; }) .def("__exit__", [](embedding::PersistentTableReader* reader, const py::object& exc_type, const py::object& exc_val, const py::object& exc_tb) { reader->Close(); }) .def("close", &embedding::PersistentTableReader::Close); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/op_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_builder.h" namespace py = pybind11; namespace oneflow { namespace one { ONEFLOW_API_PYBIND11_MODULE("one", m) { py::class_>(m, "OpBuilder") .def(py::init()) .def(py::init()) .def("input", &OpBuilder::MaybeInput) .def("output", &OpBuilder::MaybeOutput) .def("attr", [](const std::shared_ptr& x, const std::string& attr_name, const std::string& attr_val_str) -> Maybe { AttrValue attr_val; if (!TxtString2PbMessage(attr_val_str, &attr_val)) { THROW(RuntimeError) << "attr val parse failed.\n" << attr_val_str; } return x->MaybeAttr(attr_name, attr_val); }) .def("build", &OpBuilder::Build); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/op_expr.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" namespace py = pybind11; namespace oneflow { namespace { template::value>::type* = nullptr> py::class_> PybindExportOpExpr( py::module& m, const char* op_type_name) { return py::class_>(m, op_type_name) .def(py::init([](const std::string& op_name, const std::string& op_conf_str, const std::vector& indexed_ibns, const std::vector& indexed_obns) { ConfT proto_op_conf; if (!TxtString2PbMessage(op_conf_str, &proto_op_conf)) { THROW(RuntimeError) << "op conf parse failed.\n" << op_conf_str; } return OpT::New(op_name, std::move(proto_op_conf), indexed_ibns, indexed_obns) .GetPtrOrThrow(); })); } } // namespace ONEFLOW_API_PYBIND11_MODULE("one", m) { py::class_>(m, "OpExpr") .def_property_readonly("op_type_name", &one::OpExpr::op_type_name) .def_property_readonly("input_size", &one::OpExpr::input_size) .def_property_readonly("output_size", &one::OpExpr::output_size); py::class_>(m, "BuiltinOpExpr") .def_property_readonly("name", &one::BuiltinOpExpr::op_name) .def_property_readonly("indexed_ibns", &one::BuiltinOpExpr::indexed_ibns) .def_property_readonly("indexed_obns", &one::BuiltinOpExpr::indexed_obns); auto py_user_op_class = PybindExportOpExpr(m, "UserOpExpr"); py_user_op_class.def_property_readonly( "op_type_name", [](const one::UserOpExpr& op) { return op.proto().op_type_name(); }); PybindExportOpExpr(m, "VariableOpExpr"); // NOTE(chengcheng): export for Lazy nn.Graph Feed/Fetch EagerTensor to/from LazyTensor. PybindExportOpExpr(m, "FeedInputOpExpr"); PybindExportOpExpr(m, "FeedVariableOpExpr"); PybindExportOpExpr(m, "FetchOutputOpExpr"); PybindExportOpExpr( m, "ImageDecoderRandomCropResizeOpExpr"); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/parallel_conf_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/api/python/framework/size.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/parallel_conf_util.h" namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("GetDeviceTagAndMachineDeviceIdsAndHierarchy", &GetDeviceTagAndMachineDeviceIdsAndHierarchy); m.def("MakeParallelConf", &MakeParallelConf); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/py_kernel_registry.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/framework.h" #include "oneflow/extension/python/py_kernel_registry.h" namespace py = pybind11; ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("RegisterPyKernelCaller", &::oneflow::pyext::RegisterPyKernelCaller); m.def("RegisterPyKernels", [](py::object py_kernels) { ::oneflow::pyext::RegisterPyKernels(py_kernels.ptr()); }); } ================================================ FILE: oneflow/api/python/framework/random_generator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/framework/tensor.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" #endif // WITH_CUDA namespace py = pybind11; namespace oneflow { Maybe CreateGenerator(const std::string& device_str) { auto [device_name, device_index, rematable] = *JUST(ParseDeviceString(device_str)); return one::MakeGenerator(device_name, device_index); } py::tuple GetCudaDefaultGenerators() { #ifdef WITH_CUDA static int device_count = GetCudaDeviceCount(); #else static int device_count = 0; #endif py::tuple default_cuda_generators(device_count); FOR_RANGE(int, device_id, 0, device_count) { const auto& cuda_gen = one::DefaultCUDAGenerator(device_id); default_cuda_generators[device_id] = py::cast(cuda_gen); } return default_cuda_generators; } ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "Generator") .def(py::init([](const std::string& device_tag) { return CreateGenerator(device_tag).GetPtrOrThrow(); })) .def("manual_seed", [](const std::shared_ptr& generator, const py::object& seed) -> std::shared_ptr { int64_t seed_val = (one::functional::PyUnpackLong(seed.ptr())).GetOrThrow(); generator->set_current_seed(seed_val); return generator; }) .def("initial_seed", &one::Generator::current_seed) .def("seed", &one::Generator::seed) .def_property_readonly("device", &one::Generator::device) .def("get_state", &one::Generator::GetState) .def("set_state", &one::Generator::SetState); m.def("manual_seed", [](const py::object& seed) -> Maybe { int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr())); return one::ManualSeed(seed_val); }); m.def("manual_seed", [](const py::object& seed, const std::string& device, int device_index) -> Maybe { int64_t seed_val = 
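// NOTE(editor): the seed is taken as a raw py::object and unpacked with
// one::functional::PyUnpackLong instead of being declared as int64_t, so
// arbitrary Python integers are accepted and failures are reported through the
// Maybe/JUST machinery. CreateGenerator and default_generator likewise accept
// a device string and split it into (device_name, device_index, rematable)
// via ParseDeviceString; the exact string forms accepted are defined by that
// helper rather than by these bindings.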
JUST(one::functional::PyUnpackLong(seed.ptr())); return one::ManualSeed(seed_val, device, device_index); }); m.def("create_generator", &CreateGenerator); m.def("default_generator", [](const std::string& device_str) -> Maybe { auto [device_name, device_index, rematable] = *JUST(ParseDeviceString(device_str)); return one::DefaultGenerator(device_name, device_index); }); m.def("ManualSeedAllCudaGenerator", [](const py::object& seed) -> Maybe { int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr())); return one::ManualSeedAllCudaGenerator(seed_val); }); m.def("default_generators", &GetCudaDefaultGenerators); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/scope_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/scope_util.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("GetCurrentScope", &GetCurrentScope); m.def("MakeInitialScope", [](const std::string& job_conf_str, Symbol placement, bool is_local) -> Maybe { JobConfigProto job_conf; CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << "job conf parse failed"; return MakeInitialScope(job_conf, placement, is_local); }); m.def("InitGlobalScopeStack", &InitThreadLocalScopeStack); m.def("GlobalScopeStackPush", &ThreadLocalScopeStackPush); m.def("GlobalScopeStackPop", &ThreadLocalScopeStackPop); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/session_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/session_util.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("GetDefaultSessionId", []() -> int64_t { return GetDefaultSessionId().GetOrThrow(); }); m.def("RegsterSessionId", &RegsterSessionId); m.def("ClearSessionId", &ClearSessionId); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/shut_down_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/shut_down_util.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("SetShuttingDown", []() { return SetShuttingDown(); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/size.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/framework/size.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/shape.h" namespace py = pybind11; namespace oneflow { using one::functional::PyObjectPtr; static PyObject* TensorSize_repr(TensorSize* self) { std::stringstream ss; int32_t idx = 0; int32_t size = PyTuple_Size((PyObject*)self); ss << "oneflow.Size(["; for (int i = 0; i < size; ++i) { int64_t dim = PyLong_AsLongLong(PyTuple_GET_ITEM(self, i)); ss << dim; if (++idx != size) { ss << ", "; } } ss << "])"; return PyUnicode_FromString(ss.str().c_str()); } static PyObject* TensorSize_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) { PyObjectPtr self(PyTuple_Type.tp_new(type, args, kwargs)); if (self.get()) { for (int i = 0; i < PyTuple_Size(self.get()); ++i) { PyObject* item = PyTuple_GET_ITEM(self.get(), i); if (!PyLong_Check(item)) { return PyErr_Format(PyExc_TypeError, "oneflow.Size() takes an iterable of 'int', but item '%d' is '%s'", i, Py_TYPE(item)->tp_name); } } } return self.release(); } static Py_ssize_t TensorSize_length(TensorSize* self) { return PyTuple_Type.tp_as_sequence->sq_length((PyObject*)self); } static PyObject* TensorSize_concat(TensorSize* self, PyObject* other) { PyObjectPtr result(PyTuple_Type.tp_as_sequence->sq_concat((PyObject*)self, other)); if (!result.get()) { return nullptr; } if (PyTuple_Check(result.get())) { PyObjectPtr args(PyTuple_Pack(1, result.get())); return TensorSize_new(&TensorSize_Type, args.get(), nullptr); } return result.release(); } static PyObject* TensorSize_repeat(TensorSize* self, Py_ssize_t n) { PyObjectPtr result(PyTuple_Type.tp_as_sequence->sq_repeat((PyObject*)self, n)); if (!result.get()) { return nullptr; } if (PyTuple_Check(result.get())) { PyObjectPtr args(PyTuple_Pack(1, result.get())); return TensorSize_new(&TensorSize_Type, args.get(), nullptr); } return result.release(); } static PyObject* TensorSize_item(TensorSize* self, Py_ssize_t i) { return PyTuple_Type.tp_as_sequence->sq_item((PyObject*)self, i); } static int TensorSize_contains(TensorSize* self, PyObject* el) { return 
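// NOTE(editor): oneflow.Size (TensorSize) is a C-level subclass of PyTuple.
// Each sequence/mapping slot first delegates to the corresponding PyTuple_Type
// slot and, whenever the result is itself a tuple (concat, repeat, slicing),
// re-wraps it through TensorSize_new so the result stays an oneflow.Size
// instead of decaying to a plain tuple; length, item access and `in` checks
// can return the tuple-level result unchanged.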
PyTuple_Type.tp_as_sequence->sq_contains((PyObject*)self, el); } static PySequenceMethods TensorSize_as_sequence = { (lenfunc)TensorSize_length, /* sq_length */ (binaryfunc)TensorSize_concat, /* sq_concat */ (ssizeargfunc)TensorSize_repeat, /* sq_repeat */ (ssizeargfunc)TensorSize_item, /* sq_item */ 0, /* sq_slice */ 0, /* sq_ass_item */ 0, /* sq_ass_slice */ (objobjproc)TensorSize_contains, /* sq_contains */ }; static PyObject* TensorSize_subscript(TensorSize* self, PyObject* item) { PyObjectPtr result(PyTuple_Type.tp_as_mapping->mp_subscript((PyObject*)self, item)); if (!result.get()) { return nullptr; } if (PyTuple_Check(result.get())) { PyObjectPtr args(PyTuple_Pack(1, result.get())); return TensorSize_new(&TensorSize_Type, args.get(), nullptr); } return result.release(); }; static PyMappingMethods TensorSize_as_mapping = { (lenfunc)TensorSize_length, /* mp_length */ (binaryfunc)TensorSize_subscript, /* mp_subscript */ 0, /* mp_ass_subscript */ }; static PyObject* TensorSize_numel(PyObject* self, PyObject* args) { int64_t numel = 1; for (int i = 0; i < PyTuple_Size(self); ++i) { numel *= PyLong_AsLongLong(PyTuple_GET_ITEM((TensorSize*)self, i)); } return PyLong_FromLongLong(numel); } static PyMethodDef TensorSize_methods[] = { {"numel", (PyCFunction)TensorSize_numel, METH_NOARGS, NULL}, {NULL}}; PyTypeObject TensorSize_Type = { PyVarObject_HEAD_INIT(NULL, 0) "oneflow.Size", /* tp_name */ sizeof(TensorSize), /* tp_basicsize */ 0, /* tp_itemsize */ NULL, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ NULL, /* tp_getattr */ NULL, /* tp_setattr */ NULL, /* tp_reserved */ (reprfunc)TensorSize_repr, /* tp_repr */ NULL, /* tp_as_number */ &TensorSize_as_sequence, /* tp_as_sequence */ &TensorSize_as_mapping, /* tp_as_mapping */ NULL, /* tp_hash */ NULL, /* tp_call */ NULL, /* tp_str */ NULL, /* tp_getattro */ NULL, /* tp_setattro */ NULL, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ NULL, /* tp_doc */ NULL, /* tp_traverse */ NULL, /* tp_clear */ NULL, /* tp_richcompare */ 0, /* tp_weaklistoffset */ NULL, /* tp_iter */ NULL, /* tp_iternext */ TensorSize_methods, /* tp_methods */ NULL, /* tp_members */ NULL, /* tp_getset */ &PyTuple_Type, /* tp_base */ NULL, /* tp_dict */ NULL, /* tp_descr_get */ NULL, /* tp_descr_set */ 0, /* tp_dictoffset */ NULL, /* tp_init */ NULL, /* tp_alloc */ TensorSize_new, /* tp_new */ NULL, /* tp_free */ }; int TensorSize_Check(PyObject* p) { return p && p->ob_type == &TensorSize_Type; } PyObject* TensorSize_New(Py_ssize_t len) { return TensorSize_Type.tp_alloc(&TensorSize_Type, len); } PyObject* TensorSize_NewFromShape(const Shape& size) { PyObjectPtr self(TensorSize_New(size.NumAxes())); if (self.get()) { for (int i = 0; i < size.NumAxes(); ++i) { PyTuple_SET_ITEM(self.get(), i, PyLong_FromLongLong(size.At(i))); } } return self.release(); } Shape TensorSize_AsShape(PyObject* self) { if (!TensorSize_Check(self)) { PyErr_Format(PyExc_TypeError, "can only convert TensorSize(not \"%s\") to Shape", Py_TYPE(self)->tp_name); return Shape(); } int size = TensorSize_length((TensorSize*)self); DimVector dim_vec(size); for (int i = 0; i < size; ++i) { dim_vec[i] = PyLong_AsLongLong(PyTuple_GET_ITEM((TensorSize*)self, i)); } return Shape(std::move(dim_vec)); } ONEFLOW_API_PYBIND11_MODULE("", m) { if (PyType_Ready(&TensorSize_Type) < 0) { return; } Py_INCREF(&TensorSize_Type); if (PyModule_AddObject(m.ptr(), "Size", (PyObject*)&TensorSize_Type) < 0) { return; } } } // namespace oneflow ================================================ FILE: 
oneflow/api/python/framework/size.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_ #define ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_ #include #include #undef _PyGC_FINALIZED #include #include "oneflow/core/common/shape.h" namespace oneflow { typedef struct { PyTupleObject ob_base; } TensorSize; extern PyTypeObject TensorSize_Type; int TensorSize_Check(PyObject* p); PyObject* TensorSize_New(Py_ssize_t len); PyObject* TensorSize_NewFromShape(const Shape& size); Shape TensorSize_AsShape(PyObject* self); } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_ ================================================ FILE: oneflow/api/python/framework/stream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/framework/thread.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/stream_set.h" #include "oneflow/core/framework/stream_guard.h" namespace py = pybind11; ONEFLOW_API_PYBIND11_MODULE("", m) { using namespace oneflow; py::class_>(m, "StreamSet") .def(py::init([](const AsyncThread& async_thread) { return StreamSet::New(async_thread.thread_uid()).GetPtrOrThrow(); })); py::class_>(m, "StreamGuard") .def(py::init([](const std::shared_ptr& stream_set) { auto stream_converter = std::make_shared(stream_set); return std::make_shared(stream_converter); })); } ================================================ FILE: oneflow/api/python/framework/tensor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/api/python/framework/tensor.h" #include #include #undef _PyGC_FINALIZED #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/framework/size.h" #include "oneflow/api/python/framework/tensortype.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/api/python/functional/functional_api.yaml.pybind.h" #include "oneflow/api/python/functional/tensor_api.yaml.pybind.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/utils/tensor_utils.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/placement_utils.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/tensor_index.h" #include "oneflow/core/kernel/kernel_util.h" namespace py = pybind11; namespace oneflow { namespace one { #define ASSERT(x) (x).GetOrThrow() #define ASSERT_PTR(x) (x).GetPtrOrThrow() #define PY_XINCREF(p) (({ Py_XINCREF(p); }), (p)) #if PY_VERSION_HEX < 0x03070000 #define PYGETSET_NAME(name) const_cast(name) #else #define PYGETSET_NAME(name) (name) #endif PyTypeObject* PyTensorObject_Type = NULL; PyTypeObject* PyParameterObject_Type = NULL; namespace { template struct AllocType {}; #define DEFINE_ALLOC_TYPE(type) \ template<> \ struct AllocType { \ static PyTypeObject** value; \ }; \ PyTypeObject** AllocType::value = &Py##type##Object_Type DEFINE_ALLOC_TYPE(Tensor); DEFINE_ALLOC_TYPE(Parameter); #undef DEFINE_ALLOC_TYPE template PyObject* PyTensor_wrap(const std::shared_ptr& data, PyTensorObject* bind_pyobj) { if (!data) { Py_RETURN_NONE; } PyObject* py_tensor = (PyObject*)data->pyobject(); if (bind_pyobj == nullptr && py_tensor) { // Has been wrapped by python before if (data->owns_pyobj()) { // PyTensor are not alive in python side, so we flip back the ownership to PyTensor data->set_owns_pyobj(false); ((PyTensorObject*)py_tensor)->data = data; // NOTE: Needn't incref here, because the reference count of py_tensor is already increased return py_tensor; } else { // PyTensor is alive, so we directly incref it and return it Py_XINCREF(py_tensor); return py_tensor; } } else { // Has not been wrapped by python before, so we create a new PyTensor and give it the ownership if (bind_pyobj == nullptr) { bind_pyobj = (PyTensorObject*)PyTensorObject_Type->tp_alloc(*AllocType::value, 0); } bind_pyobj->data = data; if (py_tensor) { // If it has bind pyobj, reset the shared_ptr in origin PyTensorObject ((PyTensorObject*)py_tensor)->data.reset(); } bind_pyobj->data->set_pyobject_ptr(std::unique_ptr( bind_pyobj, [](void* ptr) { Py_DECREF((PyObject*)ptr); })); bind_pyobj->data->set_owns_pyobj(false); return (PyObject*)bind_pyobj; } } bool PyTensor_tryResurrect(PyObject* py_tensor) { auto* self = (PyTensorObject*)py_tensor; if (self->data) { // PyTensor holds the ownership, now we flip it back to C++ and resurrect python object // temporarily auto tensor = self->data; self->data.reset(); tensor->set_owns_pyobj(true); Py_XINCREF(py_tensor); return true; } // Otherwise, PyTensor was already not alive in python side return false; } } // namespace static int PyTensorObject_init(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS auto* temp = functional::_legacy_tensor_ctor(NULL, args, kwargs); if 
(PyErr_Occurred()) { throw py::error_already_set(); } PyTensor_wrap(PyTensor_Unpack(temp), (PyTensorObject*)self); return 0; END_HANDLE_ERRORS_RET(-1) } static void PyTensorObject_dealloc(PyObject* self) { if (PyTensor_tryResurrect(self)) { return; } // clear __dict__ PyObject** dict_ptr = _PyObject_GetDictPtr(self); if (dict_ptr) { Py_CLEAR(*dict_ptr); } auto* type = Py_TYPE(self); type->tp_free(self); Py_DECREF(type); } static int PyParameterObject_init(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* data = NULL; int requires_grad = 1; static const char* keywords[3] = {"data", "requires_grad", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|p:__init__", const_cast(keywords), &data, &requires_grad)) { return -1; } if (self) { PyTensor_wrap( ASSERT_PTR(Parameter::MakeTensor(PyTensor_Unpack(data), requires_grad)), (PyTensorObject*)self); } return 0; END_HANDLE_ERRORS_RET(-1) } static Py_ssize_t PyTensorObject_length(PyTensorObject* self) { if (self->data->ndim() == 0) { return 0; } return self->data->dim(0); } static PyObject* PyTensorObject_getitem(PyObject* self, Py_ssize_t item) { HANDLE_ERRORS const auto& p = PyTensor_Unpack(self); return PyTensor_New( ASSERT_PTR(functional::TensorGetItem(p, {functional::detail::IndexItem(item)}))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_subscript(PyObject* self, PyObject* item) { HANDLE_ERRORS const auto& p = PyTensor_Unpack(self); functional::PythonArg arg(item); return PyTensor_New(ASSERT_PTR(functional::TensorGetItem(p, arg.As()))); END_HANDLE_ERRORS } static PySequenceMethods PyTensorObject_as_sequence = { (lenfunc)PyTensorObject_length, NULL, /*sq_concat*/ NULL, /*sq_repeat*/ (ssizeargfunc)PyTensorObject_getitem, /*sq_item*/ }; extern int PyTensorObject_setitem(PyObject*, PyObject*, PyObject*); static PyMappingMethods PyTensorObject_as_mapping = { (lenfunc)PyTensorObject_length, (binaryfunc)PyTensorObject_subscript, (objobjargproc)PyTensorObject_setitem, }; static PyObject* PyTensorObject_storage_offset(PyObject* self, PyObject* unused) { HANDLE_ERRORS return functional::CastToPyObject(PyTensor_Unpack(self)->storage_offset()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_stride(PyObject* self, PyObject* unused) { HANDLE_ERRORS const auto& stride = ASSERT_PTR(PyTensor_Unpack(self)->stride()); PyObject* tup = PyTuple_New(stride->size()); for (int i = 0; i < stride->size(); ++i) { PyTuple_SetItem(tup, i, PyLong_FromUnsignedLong(stride->at(i))); } return tup; END_HANDLE_ERRORS } static PyObject* PyTensorObject_is_contiguous(PyObject* self, PyObject* unused) { HANDLE_ERRORS return functional::CastToPyObject(PyTensor_Unpack(self)->is_contiguous()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_is_view(PyObject* self, PyObject* unused) { HANDLE_ERRORS if (PyTensor_Unpack(self)->is_view()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } END_HANDLE_ERRORS } static PyObject* PyTensorObject_contiguous(PyObject* self, PyObject* unused) { HANDLE_ERRORS return PyTensor_New(PyTensor_Unpack(self)->contiguous()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_contiguous_(PyObject* self, PyObject* unused) { // NOTE: inplace version of contiguous HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(functional::InplaceToContiguous(PyTensor_Unpack(self)))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_pin_memory(PyObject* self, PyObject* unused) { HANDLE_ERRORS return PyTensor_New(PyTensor_Unpack(self)->pin_memory()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_is_pinned(PyObject* self, PyObject* 
unused) { HANDLE_ERRORS return functional::CastToPyObject(CHECK_JUST(PyTensor_Unpack(self)->is_pinned())); END_HANDLE_ERRORS } static PyObject* PyTensorObject_offload(PyObject* self, PyObject* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); CHECK_JUST(t->offload()); Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyTensorObject_load(PyObject* self, PyObject* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); CHECK_JUST(t->load()); Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyTensorObject_is_offloaded(PyObject* self, PyObject* unused) { HANDLE_ERRORS return functional::CastToPyObject(CHECK_JUST(PyTensor_Unpack(self)->is_offloaded())); END_HANDLE_ERRORS } static PyObject* PyTensorObject_is_floating_point(PyObject* self, PyObject* unused) { HANDLE_ERRORS if (PyTensor_Unpack(self)->dtype()->is_floating_point()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } END_HANDLE_ERRORS } static PyObject* PyTensorObject_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS int requires_grad = 1; static const char* keywords[2] = {"requires_grad", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|p:requires_grad_", const_cast(keywords), &requires_grad)) { return NULL; } ASSERT(PyTensor_Unpack(self)->set_requires_grad(requires_grad)); Py_XINCREF(self); return self; END_HANDLE_ERRORS } static PyObject* PyTensorObject_retain_grad(PyObject* self, PyObject* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); CHECK_JUST(t->set_retain_grad(true)); Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyTensorObject_detach(PyObject* self, PyObject* unused) { HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->detach())); END_HANDLE_ERRORS } static PyObject* PyTensorObject_clone(PyObject* self, PyObject* unused) { HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->clone())); END_HANDLE_ERRORS } static PyObject* PyTensorObject_zero_(PyObject* self, PyObject* unused) { HANDLE_ERRORS ASSERT(EagerLocalTensorZeros(PyTensor_Unpack(self))); Py_XINCREF(self); return self; END_HANDLE_ERRORS } std::vector> RawSbpBToP(Symbol nd_sbp) { std::vector> new_nd_sbp; for (const auto& old_sbp : nd_sbp->sbp_parallel()) { SbpParallel new_sbp = old_sbp; if (new_sbp.has_broadcast_parallel()) { new_sbp.mutable_partial_sum_parallel(); } new_nd_sbp.push_back(SymbolOf(new_sbp)); } return new_nd_sbp; } static constexpr auto* SbpBToP = DECORATE(&RawSbpBToP, ThreadLocalCached); static PyObject* PyTensorObject_zero_grad(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS int set_to_none = 0; static const char* keywords[2] = {"set_to_none", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|p:_zero_grad_", const_cast(keywords), &set_to_none)) { return NULL; } const auto& t = PyTensor_Unpack(self); const auto acc_grad = ASSERT_PTR(t->acc_grad()); if (acc_grad) { if (set_to_none) { ASSERT(t->set_acc_grad(NULL)); } else { ASSERT(EagerLocalTensorZeros(acc_grad)); if (acc_grad->is_global() && acc_grad->is_eager()) { const auto local_tensor = ASSERT_PTR(functional::GlobalToLocal(acc_grad, false)); const auto p = ASSERT_PTR(functional::LocalToGlobal( local_tensor, ASSERT(acc_grad->parallel_desc()), SbpBToP(ASSERT(acc_grad->nd_sbp())), *acc_grad->shape(), acc_grad->dtype(), false, false)); ASSERT(acc_grad->set_data(p)); } } } Py_XINCREF(self); return self; END_HANDLE_ERRORS } static PyObject* PyTensorObject_register_hook(PyObject* self, PyObject* hook) { HANDLE_ERRORS const auto& _hook = 
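// NOTE(editor): in _zero_grad_ above, a global accumulated grad is zeroed and
// then rebuilt with SbpBToP, which rewrites every broadcast sbp component to
// partial_sum. For an all-zero tensor both layouts describe the same global
// value (every local shard is zero either way), so the rewrite is safe;
// keeping the zeroed grad in partial_sum form is presumably what subsequent
// in-place gradient accumulation expects.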
py::cast(py::reinterpret_borrow(hook)); ASSERT(RegisterTensorHook(PyTensor_Unpack(self), _hook)); Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyTensorObject__register_post_grad_accumulation_hook(PyObject* self, PyObject* hook) { HANDLE_ERRORS const auto& _hook = py::cast(py::reinterpret_borrow(hook)); ASSERT(RegisterTensorPostGradAccumulationHook(PyTensor_Unpack(self), _hook)); Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyTensorObject_global_id(PyObject* self, PyObject* unused) { HANDLE_ERRORS uint64_t global_id = static_cast(ASSERT(PyTensor_Unpack(self)->transport_token())); return functional::CastToPyObject(global_id); END_HANDLE_ERRORS } static PyObject* PyTensorObject_check_meta_consistency(PyObject* self, PyObject* unused) { HANDLE_ERRORS ASSERT(CheckMetaConsistency(PyTensor_Unpack(self))); Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyTensorObject_data_ptr(PyObject* self, PyObject* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); const std::shared_ptr local_tensor = t->is_local() ? ASSERT_PTR(t->AsLocalTensor()) : ASSERT_PTR(t->cur_rank_phy_tensor()); return functional::CastToPyObject( reinterpret_cast(ASSERT(GetTensorDataPtr(local_tensor)))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_to_numpy(PyObject* self, PyObject* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); DataType data_type = t->dtype()->data_type(); switch (data_type) { #define SWITCH_EAGER_TENSOR_TO_NUMPY(cpp_type, of_type) \ case of_type: return ASSERT(EagerLocalTensorToNumpy(self)); OF_PP_FOR_EACH_TUPLE(SWITCH_EAGER_TENSOR_TO_NUMPY, POD_DATA_TYPE_SEQ INT16_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ) case DataType::kFloat16: return ASSERT(EagerLocalTensorToNumpy(self)); default: { return PyErr_Format(PyExc_RuntimeError, ("Invalid datatype " + DataType_Name(data_type)).data()); } } #undef SWITCH_EAGER_TENSOR_TO_NUMPY END_HANDLE_ERRORS } static PyObject* PyTensorObject_item(PyObject* self, PyObject* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); DataType data_type = t->dtype()->data_type(); switch (data_type) { #define CASE_SCALAR_TENSOR_TO_SCALAR(cpp_type, of_type) \ case of_type: return ASSERT(EagerLocalTensorItem(t)); OF_PP_FOR_EACH_TUPLE(CASE_SCALAR_TENSOR_TO_SCALAR, POD_AND_HALF_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ); default: { return PyErr_Format(PyExc_RuntimeError, ("Invalid datatype " + DataType_Name(data_type)).data()); } } #undef CASE_SCALAR_TENSOR_TO_SCALAR END_HANDLE_ERRORS } static PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS const auto& tensor = PyTensor_Unpack(self); PyObject* tensor_type = NULL; int non_blocking = 0; static const char* keywords[3] = {"dtype", "non_blocking", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|Op:type", const_cast(keywords), &tensor_type, &non_blocking)) { return NULL; } // TODO: support non_blocking=True if (non_blocking == 1) { return PyErr_Format(PyExc_TypeError, "non_blocking=True is not supported yet"); } if (tensor_type == NULL) { tensor_type = PyTensorType_FromDTypeAndDeviceType(tensor->dtype(), ASSERT(tensor->device())->enum_type()); return PyUnicode_FromString(((PyTensorType*)tensor_type)->name); } if (PyTensorMetaClass_CheckExact(tensor_type)) { Optional device = "cpu"; return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, DType::Float(), /*copy=*/false))); } if (PyUnicode_Check(tensor_type)) { tensor_type = PyTensorType_FromString(PyUnicode_AsUTF8(tensor_type)); } if (PyTensorType_Check(tensor_type)) { const auto& dtype = 
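// NOTE(editor): Tensor.type() dispatches on its argument: with no argument it
// returns the current tensor-type name as a string; given the Tensor
// metaclass it converts to the default cpu float tensor; a str is first
// parsed into a PyTensorType; a PyTensorType supplies both dtype and device
// for the conversion (the device part is dropped when it already matches);
// and a plain oneflow dtype converts the dtype only. Anything else raises
// TypeError, as does non_blocking=True, which is not supported yet.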
PyTensorType_UnpackDType(tensor_type); DeviceType device_type = PyTensorType_UnpackDevice(tensor_type); if (device_type == ASSERT(tensor->device())->enum_type()) { return PyTensor_New(ASSERT_PTR(functional::To(tensor, dtype, /*copy=*/false))); } Optional device = ASSERT(DeviceTag4DeviceType(device_type)); return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, dtype, /*copy=*/false))); } else if (functional::PyDTypeCheck(tensor_type)) { return PyTensor_New( ASSERT_PTR(functional::To(tensor, functional::PyUnpackDType(tensor_type), /*copy=*/false))); } return PyErr_Format(PyExc_TypeError, "dtype must be a type, str, or dtype object"); END_HANDLE_ERRORS } namespace { void CopyFromNumpyArray(ep::Stream* stream, const std::shared_ptr& eager_blob_object, const NumPyArrayPtr& array_ptr) { SyncAutoMemcpy(stream, eager_blob_object->mut_dptr(), array_ptr.data(), eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case(), memory::MakeHostMemCase()); } void CopyToNumpyArray(ep::Stream* stream, const std::shared_ptr& eager_blob_object, const NumPyArrayPtr& array_ptr) { SyncAutoMemcpy(stream, array_ptr.data(), eager_blob_object->dptr(), eager_blob_object->ByteSizeOfBlobBody(), memory::MakeHostMemCase(), eager_blob_object->mem_case()); } } // namespace // static PyObject* PyTensorObject__copy_to_numpy(PyObject* self, PyObject* array) { HANDLE_ERRORS ASSERT(CopyBetweenLocalTensorAndNumpy(PyTensor_Unpack(self), array, CopyToNumpyArray, "const", /*block_host_until_done=*/true)); Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyTensorObject__copy_from_numpy(PyObject* self, PyObject* array) { HANDLE_ERRORS auto* copied = PyArray_NewCopy((PyArrayObject*)array, NPY_CORDER); ASSERT(CopyBetweenLocalTensorAndNumpy(PyTensor_Unpack(self), copied, CopyFromNumpyArray, "mut", /*block_host_until_done=*/false)); Py_DECREF(copied); Py_RETURN_NONE; END_HANDLE_ERRORS } static PyObject* PyTensorObject__register_storage_delete_hook(PyObject* self, PyObject* hook) { HANDLE_ERRORS auto _hook = py::cast>(py::reinterpret_borrow(hook)); ASSERT(PyTensor_Unpack(self)->RegisterStorageDeleteHook(_hook)); Py_RETURN_NONE; END_HANDLE_ERRORS } static std::vector concat_method_def(PyMethodDef methods[], PyMethodDef extra_methods[]) { int len1 = 0; int len2 = 0; PyMethodDef* p1 = methods; PyMethodDef* p2 = extra_methods; while ((p1++)->ml_name != NULL) { len1++; } while ((p2++)->ml_name != NULL) { len2++; } std::vector total_methods(len1 + len2 + 1); for (int i = 0; i < len1; i++) total_methods[i] = methods[i]; for (int i = 0; i < len2; i++) total_methods[i + len1] = extra_methods[i]; total_methods[len1 + len2] = {NULL}; return total_methods; } static PyMethodDef PyTensorObject_methods[] = { {"storage_offset", PyTensorObject_storage_offset, METH_NOARGS, NULL}, {"stride", PyTensorObject_stride, METH_NOARGS, NULL}, {"is_contiguous", PyTensorObject_is_contiguous, METH_NOARGS, NULL}, {"is_view", PyTensorObject_is_view, METH_NOARGS, NULL}, {"contiguous", PyTensorObject_contiguous, METH_NOARGS, NULL}, {"contiguous_", PyTensorObject_contiguous_, METH_NOARGS, NULL}, {"pin_memory", PyTensorObject_pin_memory, METH_NOARGS, NULL}, {"is_pinned", PyTensorObject_is_pinned, METH_NOARGS, NULL}, {"offload", PyTensorObject_offload, METH_NOARGS, NULL}, {"load", PyTensorObject_load, METH_NOARGS, NULL}, {"is_offloaded", PyTensorObject_is_offloaded, METH_NOARGS, NULL}, {"is_floating_point", PyTensorObject_is_floating_point, METH_NOARGS, NULL}, {"requires_grad_", (PyCFunction)PyTensorObject_requires_grad_, METH_VARARGS | METH_KEYWORDS, 
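// NOTE(editor): this table holds only the "core" tp_methods. MakeTensorType
// below merges it with the extern PyTensorObject_extra_methods (defined in
// tensor_functions.cpp) through concat_method_def, which walks both
// NULL-terminated arrays and returns one combined, NULL-terminated vector of
// PyMethodDef entries; that vector lives in a function-local static so
// CPython can keep referencing it for the lifetime of the type.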
NULL}, {"retain_grad", PyTensorObject_retain_grad, METH_NOARGS, NULL}, {"detach", PyTensorObject_detach, METH_NOARGS, NULL}, {"clone", PyTensorObject_clone, METH_NOARGS, NULL}, {"zero_", PyTensorObject_zero_, METH_NOARGS, NULL}, {"_zero_grad_", (PyCFunction)PyTensorObject_zero_grad, METH_VARARGS | METH_KEYWORDS, NULL}, {"register_hook", PyTensorObject_register_hook, METH_O, NULL}, {"_register_post_grad_accumulation_hook", PyTensorObject__register_post_grad_accumulation_hook, METH_O, NULL}, {"global_id", PyTensorObject_global_id, METH_NOARGS, NULL}, {"check_meta_consistency", PyTensorObject_check_meta_consistency, METH_NOARGS, NULL}, {"to_numpy", PyTensorObject_to_numpy, METH_NOARGS, NULL}, {"data_ptr", PyTensorObject_data_ptr, METH_NOARGS, NULL}, {"item", PyTensorObject_item, METH_NOARGS, NULL}, {"type", (PyCFunction)PyTensorObject_type, METH_VARARGS | METH_KEYWORDS, NULL}, {"_copy_to_numpy", PyTensorObject__copy_to_numpy, METH_O, NULL}, {"_copy_from_numpy", PyTensorObject__copy_from_numpy, METH_O, NULL}, {"_register_storage_delete_hook", PyTensorObject__register_storage_delete_hook, METH_O, NULL}, {NULL}}; static PyObject* PyTensorObject_ndim(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->ndim()); } static PyObject* PyTensorObject_shape(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->shape()); } static PyObject* PyTensorObject_dtype(PyObject* self, void* unused) { HANDLE_ERRORS const Symbol* dtype = &ASSERT(DType::Get(PyTensor_Unpack(self)->dtype()->data_type())); return functional::CastToPyObject(dtype); END_HANDLE_ERRORS } static PyObject* PyTensorObject_is_cpu(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->is_cpu()); } static PyObject* PyTensorObject_is_cuda(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->is_cuda()); } static PyObject* PyTensorObject_grad(PyObject* self, void* unused) { HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->acc_grad())); END_HANDLE_ERRORS } static int PyTensorObject_set_grad(PyObject* self, PyObject* grad, void* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); if (self == grad) { PyErr_Format(PyExc_RuntimeError, "can't assign Tensor as its own grad"); } if (grad && grad != Py_None) { ASSERT(t->set_acc_grad(ASSERT_PTR(PyTensor_Unpack(grad)->detach()))); } else { ASSERT(t->set_acc_grad(NULL)); } return 0; END_HANDLE_ERRORS_RET(-1) } static PyObject* PyTensorObject_data(PyObject* self, void* unused) { HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->data())); END_HANDLE_ERRORS } static int PyTensorObject_set_data(PyObject* self, PyObject* data, void* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); auto hooks = t->autograd_meta()->hooks(); ASSERT(t->set_data(PyTensor_Unpack(data))); // Re-register hooks for (const auto& hook : hooks) { ASSERT(RegisterTensorHook(t, hook)); } return 0; END_HANDLE_ERRORS_RET(-1) } static PyObject* PyTensorObject_ref_tensor(PyObject* self, void* unused) { HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->ref_tensor())); END_HANDLE_ERRORS } static int PyTensorObject_set_ref_tensor(PyObject* self, PyObject* ref, void* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); if (self == ref) { PyErr_Format(PyExc_RuntimeError, "can't assign Tensor as its own reference"); } if (ref && ref != Py_None) { ASSERT(t->set_ref_tensor(PyTensor_Unpack(ref))); } else { 
ASSERT(t->set_ref_tensor(NULL)); } return 0; END_HANDLE_ERRORS_RET(-1) } static PyObject* PyTensorObject_ref_index(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->ref_index()); } static int PyTensorObject_set_ref_index(PyObject* self, PyObject* index, void* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); CHECK_OR_THROW(PyLong_Check(index)) << Error::RuntimeError() << "Index must be Integer type."; ASSERT(t->set_ref_index(PyLong_AsLong(index))); return 0; END_HANDLE_ERRORS_RET(-1) } static PyObject* PyTensorObject_grad_fn(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->grad_fn_node()); } static PyObject* PyTensorObject_is_leaf(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->is_leaf()); } static PyObject* PyTensorObject_requires_grad(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->requires_grad()); } static int PyTensorObject_set_requires_grad(PyObject* self, PyObject* requires_grad, void* unused) { HANDLE_ERRORS const auto& t = PyTensor_Unpack(self); CHECK_OR_THROW(t->is_leaf()) << Error::RuntimeError() << "You can only change requires_grad flags of leaf tensors."; ASSERT(t->set_requires_grad(requires_grad == Py_True)); return 0; END_HANDLE_ERRORS_RET(-1) } static PyObject* PyTensorObject_is_lazy(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->is_lazy()); } static PyObject* PyTensorObject_is_eager(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->is_eager()); } static PyObject* PyTensorObject_is_global(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->is_global()); } static PyObject* PyTensorObject_is_local(PyObject* self, void* unused) { return functional::CastToPyObject(PyTensor_Unpack(self)->is_local()); } static PyObject* PyTensorObject__tensor_buffer_shapes_and_dtypes(PyObject* self, void* unused) { HANDLE_ERRORS return functional::CastToPyObject(MaybeGetTensorBufferShapesAndDTypes(PyTensor_Unpack(self))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_device(PyObject* self, void* unused) { HANDLE_ERRORS return functional::CastToPyObject(PyTensor_Unpack(self)->device()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_placement(PyObject* self, void* unused) { HANDLE_ERRORS return functional::CastToPyObject(PyTensor_Unpack(self)->parallel_desc()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_sbp(PyObject* self, void* unused) { HANDLE_ERRORS return functional::CastToPyObject(TensorGetPyTupleOfSbp(*PyTensor_Unpack(self))); END_HANDLE_ERRORS } // NOLINTNEXTLINE static PyGetSetDef PyTensorObject_properties[] = { {PYGETSET_NAME("ndim"), (getter)PyTensorObject_ndim, NULL, NULL, NULL}, {PYGETSET_NAME("shape"), (getter)PyTensorObject_shape, NULL, NULL, NULL}, {PYGETSET_NAME("dtype"), (getter)PyTensorObject_dtype, NULL, NULL, NULL}, {PYGETSET_NAME("is_cpu"), (getter)PyTensorObject_is_cpu, NULL, NULL, NULL}, {PYGETSET_NAME("is_cuda"), (getter)PyTensorObject_is_cuda, NULL, NULL, NULL}, {PYGETSET_NAME("grad"), (getter)PyTensorObject_grad, (setter)PyTensorObject_set_grad, NULL, NULL}, {PYGETSET_NAME("data"), (getter)PyTensorObject_data, (setter)PyTensorObject_set_data, NULL, NULL}, {PYGETSET_NAME("_ref_tensor"), (getter)PyTensorObject_ref_tensor, (setter)PyTensorObject_set_ref_tensor, NULL, NULL}, {PYGETSET_NAME("_ref_index"), (getter)PyTensorObject_ref_index, 
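// NOTE(editor): every entry in this PyGetSetDef table wraps its name string in
// PYGETSET_NAME, defined near the top of the file: on CPython older than 3.7,
// where the name field is a plain char*, the macro const_casts the string
// literal; on newer versions it expands to the name unchanged.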
(setter)PyTensorObject_set_ref_index, NULL, NULL}, {PYGETSET_NAME("grad_fn"), (getter)PyTensorObject_grad_fn, NULL, NULL, NULL}, {PYGETSET_NAME("is_leaf"), (getter)PyTensorObject_is_leaf, NULL, NULL, NULL}, {PYGETSET_NAME("requires_grad"), (getter)PyTensorObject_requires_grad, (setter)PyTensorObject_set_requires_grad, NULL, NULL}, {PYGETSET_NAME("is_lazy"), (getter)PyTensorObject_is_lazy, NULL, NULL, NULL}, {PYGETSET_NAME("is_eager"), (getter)PyTensorObject_is_eager, NULL, NULL, NULL}, {PYGETSET_NAME("is_global"), (getter)PyTensorObject_is_global, NULL, NULL, NULL}, {PYGETSET_NAME("is_local"), (getter)PyTensorObject_is_local, NULL, NULL, NULL}, {PYGETSET_NAME("_tensor_buffer_shapes_and_dtypes"), (getter)PyTensorObject__tensor_buffer_shapes_and_dtypes, NULL, NULL, NULL}, {PYGETSET_NAME("device"), (getter)PyTensorObject_device, NULL, NULL, NULL}, {PYGETSET_NAME("placement"), (getter)PyTensorObject_placement, NULL, NULL, NULL}, {PYGETSET_NAME("sbp"), (getter)PyTensorObject_sbp, NULL, NULL, NULL}, {NULL}}; // create a Tensor instance static PyObject* TensorMetaCls_call(PyObject* type, PyObject* args, PyObject* kwargs) { return PyType_Type.tp_call(type, args, kwargs); } static void TensorMetaCls_dealloc(PyObject* type) { PyType_Type.tp_dealloc(type); } static PyHeapTypeObject* MakeTensorMetaclass() { PyObject* name = PyUnicode_FromString("_TensorMeta"); auto* heap_type = (PyHeapTypeObject*)PyType_Type.tp_alloc(&PyType_Type, 0); heap_type->ht_name = name; heap_type->ht_qualname = PY_XINCREF(name); auto* type = &heap_type->ht_type; type->tp_name = "_TensorMeta"; type->tp_base = PY_XINCREF(&PyType_Type); type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; type->tp_call = TensorMetaCls_call; type->tp_dealloc = TensorMetaCls_dealloc; if (PyType_Ready(type) < 0) { return NULL; } PyObject_SetAttrString((PyObject*)type, "__module__", PyUnicode_FromString("oneflow._C")); return heap_type; } extern PyNumberMethods PyTensorObject_as_number; extern PyObject* PyTensorObject_richcompare(PyObject*, PyObject*, int); extern PyMethodDef PyTensorObject_extra_methods[]; static PyHeapTypeObject* TensorMetaclass_Type = MakeTensorMetaclass(); static PyTypeObject* MakeTensorType() { PyObject* name = PyUnicode_FromString("Tensor"); auto* metaclass = &TensorMetaclass_Type->ht_type; auto* heap_type = (PyHeapTypeObject*)metaclass->tp_alloc(metaclass, 0); if (!heap_type) { return NULL; } heap_type->ht_name = name; heap_type->ht_qualname = PY_XINCREF(name); auto* type = &heap_type->ht_type; type->tp_name = "Tensor"; type->tp_basicsize = sizeof(PyTensorObject); type->tp_init = PyTensorObject_init; type->tp_dealloc = PyTensorObject_dealloc; type->tp_getset = PyTensorObject_properties; static std::vector total_methods = concat_method_def(PyTensorObject_methods, PyTensorObject_extra_methods); type->tp_methods = total_methods.data(); type->tp_as_number = &PyTensorObject_as_number; type->tp_as_sequence = &PyTensorObject_as_sequence; type->tp_as_mapping = &PyTensorObject_as_mapping; type->tp_richcompare = PyTensorObject_richcompare; type->tp_hash = (hashfunc)_Py_HashPointer; type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; if (PyType_Ready(type) < 0) { return NULL; } PyObject_SetAttrString((PyObject*)type, "__module__", PyUnicode_FromString("oneflow")); return type; } static PyTypeObject* MakeParameterType() { PyObject* name = PyUnicode_FromString("Parameter"); auto* metaclass = &TensorMetaclass_Type->ht_type; auto* heap_type = 
(PyHeapTypeObject*)metaclass->tp_alloc(metaclass, 0); if (!heap_type) { return NULL; } heap_type->ht_name = name; heap_type->ht_qualname = PY_XINCREF(name); auto* type = &heap_type->ht_type; type->tp_name = "Parameter"; type->tp_basicsize = sizeof(PyTensorObject); type->tp_init = PyParameterObject_init; type->tp_base = PY_XINCREF(PyTensorObject_Type); type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; if (PyType_Ready(type) < 0) { return NULL; } PyObject_SetAttrString((PyObject*)type, "__module__", PyUnicode_FromString("oneflow.nn")); return type; } PyObject* PyTensor_New(const std::shared_ptr<Tensor>& data) { return PyTensor_wrap(data, /*bind_pyobj=*/nullptr); } PyObject* PyParameter_New(const std::shared_ptr<Parameter>& data) { return PyTensor_wrap(data, /*bind_pyobj=*/nullptr); } PyObject* PyParameter_New(const std::shared_ptr<Tensor>& data, bool requires_grad) { if (!data) { Py_RETURN_NONE; } return PyTensor_wrap(ASSERT_PTR(Parameter::MakeTensor(data, requires_grad)), /*bind_pyobj=*/nullptr); } } // namespace one } // namespace oneflow #undef ASSERT #undef ASSERT_PTR using namespace oneflow::one; ONEFLOW_API_PYBIND11_MODULE("", m) { PyTensorObject_Type = MakeTensorType(); PyParameterObject_Type = MakeParameterType(); if (PyTensorObject_Type && PyModule_AddObject(m.ptr(), "Tensor", (PyObject*)PyTensorObject_Type) < 0) { return; } auto nn = m.def_submodule("nn"); if (PyParameterObject_Type && PyModule_AddObject(nn.ptr(), "Parameter", (PyObject*)PyParameterObject_Type) < 0) { return; } } ================================================ FILE: oneflow/api/python/framework/tensor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_ #define ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_ #include <Python.h> #undef _PyGC_FINALIZED #include "oneflow/core/framework/tensor.h" namespace oneflow { namespace one { typedef struct { PyObject_HEAD; std::shared_ptr<Tensor> data; } PyTensorObject; extern PyTypeObject* PyTensorObject_Type; extern PyTypeObject* PyParameterObject_Type; inline bool PyTensorMetaClass_CheckExact(PyObject* obj) { return obj == (PyObject*)PyTensorObject_Type; } inline bool PyTensor_Check(PyObject* op) { return PyObject_TypeCheck(op, PyTensorObject_Type); } inline bool PyTensor_CheckExact(PyObject* op) { return op->ob_type == PyTensorObject_Type || op->ob_type == PyParameterObject_Type; } inline std::shared_ptr<Tensor>& PyTensor_Unpack(PyObject* op) { assert(PyTensor_Check(op)); return ((PyTensorObject*)op)->data; } PyObject* PyTensor_New(const std::shared_ptr<Tensor>& data); PyObject* PyParameter_New(const std::shared_ptr<Parameter>& data); PyObject* PyParameter_New(const std::shared_ptr<Tensor>& data, bool requires_grad); } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_ ================================================ FILE: oneflow/api/python/framework/tensor_functions.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include <Python.h> #undef _PyGC_FINALIZED #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/framework/size.h" #include "oneflow/api/python/framework/tensor_functions_util.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/functional_api.yaml.pybind.h" #include "oneflow/api/python/functional/tensor_api.yaml.pybind.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/wrap_dim_utils.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/api/python/functional/tensor_api.yaml.h" #include "oneflow/extension/python/numpy.h" #include "oneflow/api/python/utils/tensor_utils.h" namespace oneflow { namespace one { #define ASSERT(x) (x).GetOrThrow() #define ASSERT_PTR(x) (x).GetPtrOrThrow() using functional::PyObjectPtr; namespace { PyObject* concat_self(PyObject* self, PyObject* args) { PyObjectPtr self_tuple(PyTuple_Pack(1, self)); PyObject* tuple = PySequence_Concat(self_tuple.get(), args); CHECK_OR_THROW(tuple != NULL); return tuple; } PyObject* ndarray_judgment_and_compatibility(PyObject* self, PyObject* other) { if (PyArray_Check(other)) { const auto& tensor = PyTensor_Unpack(self); CHECK_OR_THROW(tensor->is_cpu()) << Error::RuntimeError() << "Can't convert non-cpu device tensor to numpy"; if (tensor->is_global()) { Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc()); auto ndsbp = ASSERT(tensor->nd_sbp()); std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(), ASSERT(MakeBroadcastSbpParallel())); other = functional::CastToPyObject(MakeGlobalTensorFromData(other, tensor->dtype(), placement, sbp, /*requires_grad=*/false)); } else { other = functional::CastToPyObject(functional::LocalTensorSharedNumpyData(other)); } } return other; } } // namespace #define NB_UNARY_FUNC(func_name, bind_func) \ static PyObject* func_name(PyObject* self) { \ HANDLE_ERRORS \ PyObjectPtr tuple(PyTuple_Pack(1, self)); \ auto* result = bind_func(NULL, tuple.get(), NULL); \ if (PyErr_Occurred()) { throw py::error_already_set(); } \ return result; \ END_HANDLE_ERRORS \ } #define NB_BINARY_FUNC(func_name, bind_func) \ static PyObject* func_name(PyObject* a, PyObject* b) { \ HANDLE_ERRORS \ b = ndarray_judgment_and_compatibility(a, b); \ PyObjectPtr tuple(PyTuple_Pack(2, a, b)); \ auto* result = bind_func(NULL, tuple.get(), NULL); \ if (PyErr_Occurred()) { throw py::error_already_set(); } \ return result; \ END_HANDLE_ERRORS \ } // namespace one NB_UNARY_FUNC(PyTensorObject_nb_absolute, functional::abs); NB_UNARY_FUNC(PyTensorObject_nb_negative, functional::negative); NB_UNARY_FUNC(PyTensorObject_nb_invert, functional::bitwise_not); // TODO: not implemented yet // NB_UNARY_FUNC(PyTensorObject_positive, functional::positive); NB_BINARY_FUNC(PyTensorObject_nb_add, functional::add); NB_BINARY_FUNC(PyTensorObject_nb_sub, functional::sub); NB_BINARY_FUNC(PyTensorObject_nb_mul, functional::mul); NB_BINARY_FUNC(PyTensorObject_nb_fmod,
functional::fmod); NB_BINARY_FUNC(PyTensorObject_nb_div, functional::div); NB_BINARY_FUNC(PyTensorObject_nb_and, functional::logical_and); NB_BINARY_FUNC(PyTensorObject_nb_xor, functional::logical_xor); NB_BINARY_FUNC(PyTensorObject_nb_or, functional::logical_or); NB_BINARY_FUNC(PyTensorObject_nb_floor_div, functional::floor_divide); NB_BINARY_FUNC(PyTensorObject_nb_true_div, functional::div); NB_BINARY_FUNC(PyTensorObject_nb_matrix_multiply, functional::matmul); static PyObject* PyTensorObject_nb_pow(PyObject* a, PyObject* b, PyObject* unused) { HANDLE_ERRORS b = ndarray_judgment_and_compatibility(a, b); PyObjectPtr tuple(PyTuple_Pack(2, a, b)); PyObject* result = functional::pow(NULL, tuple.get(), NULL); if (PyErr_Occurred()) { throw py::error_already_set(); } return result; END_HANDLE_ERRORS } #define NB_INPLACE_BINARY_FUNC(func_name, bind_func) \ static PyObject* func_name(PyObject* a, PyObject* b) { \ HANDLE_ERRORS \ b = ndarray_judgment_and_compatibility(a, b); \ PyObjectPtr tuple(PyTuple_Pack(2, a, b)); \ PyObjectPtr dict(PyDict_New()); \ CHECK_OR_THROW(PyDict_SetItemString(dict.get(), "inplace", Py_True) > -1); \ PyObject* result = bind_func(NULL, tuple.get(), dict.get()); \ if (PyErr_Occurred()) { throw py::error_already_set(); } \ return result; \ END_HANDLE_ERRORS \ } // inplace operators NB_INPLACE_BINARY_FUNC(PyTensorObject_nb_inplace_add, functional::add); NB_INPLACE_BINARY_FUNC(PyTensorObject_nb_inplace_sub, functional::sub); // The interface of inplace mul not mul(*, inplace=True) but mul_ NB_BINARY_FUNC(PyTensorObject_nb_inplace_mul, functional::mul_); NB_BINARY_FUNC(PyTensorObject_nb_inplace_true_div, functional::div_); PyObject* PyTensorObject_nb_inplace_pow(PyObject* a, PyObject* b, PyObject* unused) { HANDLE_ERRORS PyObjectPtr tuple(PyTuple_Pack(2, a, b)); PyObjectPtr dict(PyDict_New()); CHECK_OR_THROW(PyDict_SetItemString(dict.get(), "inplace", Py_True) > -1); auto* result = functional::pow(NULL, tuple.get(), NULL); if (PyErr_Occurred()) { throw py::error_already_set(); } return result; END_HANDLE_ERRORS } PyNumberMethods PyTensorObject_as_number = { PyTensorObject_nb_add, // nb_add PyTensorObject_nb_sub, // nb_subtract PyTensorObject_nb_mul, // nb_multiply PyTensorObject_nb_fmod, // nb_remainder NULL, // nb_divmod PyTensorObject_nb_pow, // nb_power PyTensorObject_nb_negative, // nb_negative NULL, // nb_positive PyTensorObject_nb_absolute, // nb_absolute NULL, // nb_bool PyTensorObject_nb_invert, // nb_invert NULL, // nb_lshift NULL, // nb_rshift PyTensorObject_nb_and, // nb_and PyTensorObject_nb_xor, // nb_xor PyTensorObject_nb_or, // nb_or NULL, // nb_int NULL, // nb_reserved NULL, // nb_float PyTensorObject_nb_inplace_add, // nb_inplace_add PyTensorObject_nb_inplace_sub, // nb_inplace_sub PyTensorObject_nb_inplace_mul, // nb_inplace_mul NULL, // nb_inplace_remainder PyTensorObject_nb_inplace_pow, // nb_inplace_pow NULL, // nb_inplace_lshift NULL, // nb_inplace_rshift NULL, // nb_inplace_and NULL, // nb_inplace_xor NULL, // nb_inplace_or PyTensorObject_nb_floor_div, // nb_floor_div PyTensorObject_nb_true_div, // nb_true_div NULL, // nb_inplace_floor_div PyTensorObject_nb_inplace_true_div, // nb_inplace_true_div NULL, // nb_index PyTensorObject_nb_matrix_multiply, // nb_matrix_multiply NULL, // nb_inplace_matrix_multiply }; // extra methods // functions that accept only one Tensor #define UNARY_METHOD(func_name, bind_func) \ static PyObject* func_name(PyObject* self, PyObject* unused) { \ HANDLE_ERRORS \ return 
PyTensor_New(ASSERT_PTR(bind_func(PyTensor_Unpack(self)))); \ END_HANDLE_ERRORS \ } UNARY_METHOD(PyTensorObject_abs, functional::Abs); UNARY_METHOD(PyTensorObject_digamma, functional::Digamma); UNARY_METHOD(PyTensorObject_exp, functional::Exp); UNARY_METHOD(PyTensorObject_exp2, functional::Exp2); UNARY_METHOD(PyTensorObject_floor, functional::Floor); UNARY_METHOD(PyTensorObject_floor_, functional::Floor_); UNARY_METHOD(PyTensorObject_sign, functional::Sign); UNARY_METHOD(PyTensorObject_gelu, functional::Gelu); UNARY_METHOD(PyTensorObject_mish, functional::Mish); UNARY_METHOD(PyTensorObject_negative, functional::Negative); UNARY_METHOD(PyTensorObject_sigmoid, functional::Sigmoid); UNARY_METHOD(PyTensorObject_silu, functional::Silu); UNARY_METHOD(PyTensorObject_selu, functional::Selu); UNARY_METHOD(PyTensorObject_softsign, functional::SoftSign); UNARY_METHOD(PyTensorObject_log1p, functional::Log1p); UNARY_METHOD(PyTensorObject_log2, functional::Log2); UNARY_METHOD(PyTensorObject_log10, functional::Log10); UNARY_METHOD(PyTensorObject_reciprocal, functional::Reciprocal); UNARY_METHOD(PyTensorObject_ceil, functional::Ceil); UNARY_METHOD(PyTensorObject_ceil_, functional::Ceil_); UNARY_METHOD(PyTensorObject_erf, functional::Erf); UNARY_METHOD(PyTensorObject_erfc, functional::Erfc); UNARY_METHOD(PyTensorObject_erfinv, functional::Erfinv); UNARY_METHOD(PyTensorObject_erfinv_, functional::ErfinvInplace); UNARY_METHOD(PyTensorObject_expm1, functional::Expm1); UNARY_METHOD(PyTensorObject_log, functional::Log); UNARY_METHOD(PyTensorObject_rsqrt, functional::Rsqrt); UNARY_METHOD(PyTensorObject_sqrt, functional::Sqrt); UNARY_METHOD(PyTensorObject_square, functional::Square); UNARY_METHOD(PyTensorObject_round, functional::Round); UNARY_METHOD(PyTensorObject_round_, functional::Round_); UNARY_METHOD(PyTensorObject_t, functional::TransposeAllDimFunction); UNARY_METHOD(PyTensorObject_isnan, functional::IsNan); UNARY_METHOD(PyTensorObject_isinf, functional::IsInf); UNARY_METHOD(PyTensorObject_sin, functional::Sin); UNARY_METHOD(PyTensorObject_sin_, functional::Sin_); UNARY_METHOD(PyTensorObject_asin, functional::Asin); UNARY_METHOD(PyTensorObject_cos, functional::Cos); UNARY_METHOD(PyTensorObject_acos, functional::Acos); UNARY_METHOD(PyTensorObject_tan, functional::Tan); UNARY_METHOD(PyTensorObject_atan, functional::Atan); UNARY_METHOD(PyTensorObject_sinh, functional::Sinh); UNARY_METHOD(PyTensorObject_asinh, functional::Asinh); UNARY_METHOD(PyTensorObject_cosh, functional::Cosh); UNARY_METHOD(PyTensorObject_acosh, functional::Acosh); UNARY_METHOD(PyTensorObject_tanh, functional::Tanh); UNARY_METHOD(PyTensorObject_atanh, functional::Atanh); UNARY_METHOD(PyTensorObject_logical_not, functional::LogicalNot); UNARY_METHOD(PyTensorObject_bitwise_not, functional::BitwiseNot); UNARY_METHOD(PyTensorObject_inv, functional::Inv); UNARY_METHOD(PyTensorObject_trunc, functional::Trunc); // functions that directly pass arguments without parsing #define DIRECT_PASS_FUNC(func_name, bind_func) \ static PyObject* func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \ HANDLE_ERRORS \ PyObjectPtr concat_args(concat_self(self, args)); \ PyObject* result = bind_func(NULL, concat_args.get(), kwargs); \ if (PyErr_Occurred()) { throw py::error_already_set(); } \ return result; \ END_HANDLE_ERRORS \ } DIRECT_PASS_FUNC(PyTensorObject_floor_divide, functional::floor_divide) DIRECT_PASS_FUNC(PyTensorObject_atan2, functional::atan2) DIRECT_PASS_FUNC(PyTensorObject_gt, functional::greater) DIRECT_PASS_FUNC(PyTensorObject_gt_, 
functional::greater_) DIRECT_PASS_FUNC(PyTensorObject_frac, functional::frac) DIRECT_PASS_FUNC(PyTensorObject_frac_, functional::frac_) DIRECT_PASS_FUNC(PyTensorObject_ge, functional::greater_equal) DIRECT_PASS_FUNC(PyTensorObject_div, functional::div) DIRECT_PASS_FUNC(PyTensorObject_div_, functional::div_) DIRECT_PASS_FUNC(PyTensorObject_mul, functional::mul) DIRECT_PASS_FUNC(PyTensorObject_mul_, functional::mul_) DIRECT_PASS_FUNC(PyTensorObject_fmod, functional::fmod) DIRECT_PASS_FUNC(PyTensorObject_logical_and, functional::logical_and) DIRECT_PASS_FUNC(PyTensorObject_logical_or, functional::logical_or) DIRECT_PASS_FUNC(PyTensorObject_logical_xor, functional::logical_xor) DIRECT_PASS_FUNC(PyTensorObject_equal, functional::equal) DIRECT_PASS_FUNC(PyTensorObject_ne, functional::not_equal) DIRECT_PASS_FUNC(PyTensorObject_lt, functional::less) DIRECT_PASS_FUNC(PyTensorObject_le, functional::less_equal) DIRECT_PASS_FUNC(PyTensorObject_bmm, functional::batch_matmul) DIRECT_PASS_FUNC(PyTensorObject_argmax, functional::argmax) DIRECT_PASS_FUNC(PyTensorObject_argmin, functional::argmin) DIRECT_PASS_FUNC(PyTensorObject_amin, functional::amin) DIRECT_PASS_FUNC(PyTensorObject_amax, functional::amax) DIRECT_PASS_FUNC(PyTensorObject_addcmul, functional::addcmul) DIRECT_PASS_FUNC(PyTensorObject_addcmul_, functional::addcmul_) DIRECT_PASS_FUNC(PyTensorObject_addcdiv, functional::addcdiv) DIRECT_PASS_FUNC(PyTensorObject_addcdiv_, functional::addcdiv_) DIRECT_PASS_FUNC(PyTensorObject_flip, functional::flip) DIRECT_PASS_FUNC(PyTensorObject_clip, functional::clip) DIRECT_PASS_FUNC(PyTensorObject_clip_, functional::clip_) DIRECT_PASS_FUNC(PyTensorObject_clamp, functional::clamp) DIRECT_PASS_FUNC(PyTensorObject_clamp_min, functional::clamp_min) DIRECT_PASS_FUNC(PyTensorObject_clamp_max, functional::clamp_max) DIRECT_PASS_FUNC(PyTensorObject_clamp_, functional::clamp_) DIRECT_PASS_FUNC(PyTensorObject_clamp_min_, functional::clamp_min_) DIRECT_PASS_FUNC(PyTensorObject_clamp_max_, functional::clamp_max_) DIRECT_PASS_FUNC(PyTensorObject_flatten, functional::flatten) DIRECT_PASS_FUNC(PyTensorObject_in_top_k, functional::in_top_k) DIRECT_PASS_FUNC(PyTensorObject_index_select, functional::index_select) DIRECT_PASS_FUNC(PyTensorObject_logsumexp, functional::logsumexp) DIRECT_PASS_FUNC(PyTensorObject_maximum, functional::maximum) DIRECT_PASS_FUNC(PyTensorObject_minimum, functional::minimum) DIRECT_PASS_FUNC(PyTensorObject_tril, functional::tril) DIRECT_PASS_FUNC(PyTensorObject_tril_, functional::tril_) DIRECT_PASS_FUNC(PyTensorObject_triu, functional::triu) DIRECT_PASS_FUNC(PyTensorObject_triu_, functional::triu_) DIRECT_PASS_FUNC(PyTensorObject_softmax, functional::softmax) DIRECT_PASS_FUNC(PyTensorObject_log_softmax, functional::log_softmax) DIRECT_PASS_FUNC(PyTensorObject_roll, functional::roll) DIRECT_PASS_FUNC(PyTensorObject_unbind, functional::unbind) DIRECT_PASS_FUNC(PyTensorObject_squeeze, functional::squeeze) DIRECT_PASS_FUNC(PyTensorObject_swapaxes, functional::swapaxes) DIRECT_PASS_FUNC(PyTensorObject_swapdims, functional::swapdims) DIRECT_PASS_FUNC(PyTensorObject_unfold, functional::unfold_tensor) DIRECT_PASS_FUNC(PyTensorObject_unsqueeze, functional::unsqueeze) DIRECT_PASS_FUNC(PyTensorObject_max, functional::max) DIRECT_PASS_FUNC(PyTensorObject_min, functional::min) DIRECT_PASS_FUNC(PyTensorObject_median, functional::median) DIRECT_PASS_FUNC(PyTensorObject_mode, functional::mode) DIRECT_PASS_FUNC(PyTensorObject_pow, functional::pow) DIRECT_PASS_FUNC(PyTensorObject_chunk, functional::chunk) 
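// Note (illustrative sketch, not part of the original file): each DIRECT_PASS_FUNC(func_name,
// bind_func) invocation here expands, per the macro definition above, into a thin wrapper that
// prepends `self` to the positional arguments and forwards them, together with `kwargs`, to the
// generated functional binding. The next invocation, DIRECT_PASS_FUNC(PyTensorObject_split,
// functional::split), therefore defines roughly:
//
//   static PyObject* PyTensorObject_split(PyObject* self, PyObject* args, PyObject* kwargs) {
//     HANDLE_ERRORS
//     PyObjectPtr concat_args(concat_self(self, args));  // builds the tuple (self, *args)
//     PyObject* result = functional::split(NULL, concat_args.get(), kwargs);
//     if (PyErr_Occurred()) { throw py::error_already_set(); }
//     return result;
//     END_HANDLE_ERRORS
//   }
//
// so a Python-level call such as `x.split(2, dim=1)` is parsed entirely by functional::split.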
DIRECT_PASS_FUNC(PyTensorObject_split, functional::split) DIRECT_PASS_FUNC(PyTensorObject_narrow, functional::narrow) DIRECT_PASS_FUNC(PyTensorObject_masked_fill, functional::masked_fill) DIRECT_PASS_FUNC(PyTensorObject_masked_fill_, functional::masked_fill_) DIRECT_PASS_FUNC(PyTensorObject_dot, functional::dot) DIRECT_PASS_FUNC(PyTensorObject_nansum, functional::reduce_nansum) DIRECT_PASS_FUNC(PyTensorObject_sum, functional::reduce_sum) DIRECT_PASS_FUNC(PyTensorObject_bernoulli, functional::bernoulli) DIRECT_PASS_FUNC(PyTensorObject_bernoulli_, functional::bernoulli_) DIRECT_PASS_FUNC(PyTensorObject_bincount, functional::bincount) DIRECT_PASS_FUNC(PyTensorObject_isclose, functional::isclose) DIRECT_PASS_FUNC(PyTensorObject_broadcast_to, functional::broadcast_to) DIRECT_PASS_FUNC(PyTensorObject_lerp, functional::lerp) DIRECT_PASS_FUNC(PyTensorObject_lerp_, functional::lerp_) DIRECT_PASS_FUNC(PyTensorObject_unique, functional::unique) DIRECT_PASS_FUNC(PyTensorObject_topk, functional::topk) DIRECT_PASS_FUNC(PyTensorObject_quantile, functional::quantile) DIRECT_PASS_FUNC(PyTensorObject_bitwise_and, functional::bitwise_and) DIRECT_PASS_FUNC(PyTensorObject_bitwise_or, functional::bitwise_or) DIRECT_PASS_FUNC(PyTensorObject_bitwise_xor, functional::bitwise_xor) DIRECT_PASS_FUNC(PyTensorObject_baddbmm, functional::baddbmm) DIRECT_PASS_FUNC(PyTensorObject_mm, functional::mm) DIRECT_PASS_FUNC(PyTensorObject_sub, functional::sub) DIRECT_PASS_FUNC(PyTensorObject_mv, functional::matrix_vector_product) DIRECT_PASS_FUNC(PyTensorObject_fill_, functional::fill_) DIRECT_PASS_FUNC(PyTensorObject_gather, functional::dim_gather) DIRECT_PASS_FUNC(PyTensorObject_repeat_interleave, functional::repeat_interleave) DIRECT_PASS_FUNC(PyTensorObject_scatter_add, functional::scatter_add) DIRECT_PASS_FUNC(PyTensorObject_logaddexp, functional::logaddexp) // functions that parsing at Python C api layer static PyObject* PyTensorObject_byte(PyObject* self, PyObject* unused) { HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(functional::To(PyTensor_Unpack(self), DType::UInt8(), false))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_dim(PyObject* self, PyObject* unused) { HANDLE_ERRORS return functional::CastToPyObject(PyTensor_Unpack(self)->ndim()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_nelement(PyObject* self, PyObject* unused) { HANDLE_ERRORS return functional::CastToPyObject(PyTensor_Unpack(self)->nelement()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_element_size(PyObject* self, PyObject* unused) { HANDLE_ERRORS return functional::CastToPyObject(PyTensor_Unpack(self)->dtype()->bytes()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_get_device(PyObject* self, PyObject* unused) { HANDLE_ERRORS DeviceType device_type = ASSERT(PyTensor_Unpack(self)->device())->enum_type(); CHECK_OR_THROW(device_type == DeviceType::kCUDA) << "get_device is only available for GPU tensor."; return functional::CastToPyObject(ASSERT(PyTensor_Unpack(self)->device())->device_id()); END_HANDLE_ERRORS } static PyObject* PyTensorObject_size(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* idx_obj = Py_None; static const char* keywords[2] = {"idx", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O:size", const_cast(keywords), &idx_obj)) { return NULL; } auto shape = PyTensor_Unpack(self)->shape(); if (idx_obj == NULL || idx_obj == Py_None) return TensorSize_NewFromShape(*shape); int64_t idx = PyLong_AsLongLong(idx_obj); int64_t ndim = shape->NumAxes(); idx = 
CHECK_JUST(maybe_wrap_dim(idx, ndim)); idx = idx < 0 ? idx + ndim : idx; return PyLong_FromLongLong(shape->At(idx)); END_HANDLE_ERRORS } static PyObject* PyTensorObject_cast(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* dtype = NULL; PyObject* pin_memory = Py_False; static const char* keywords[3] = {"dtype", "pin_memory", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!:cast", const_cast<char**>(keywords), &dtype, &PyBool_Type, &pin_memory)) { return NULL; } CHECK_OR_THROW(functional::PyDTypeCheck(dtype)) << Error::TypeError() << "cast(): argument 'dtype' must be data type, but found " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype))); const auto& result = functional::Cast(PyTensor_Unpack(self), functional::PyUnpackDType(dtype), pin_memory == Py_True); return PyTensor_New(ASSERT_PTR(result)); END_HANDLE_ERRORS } static PyObject* PyTensorObject_diag(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS int32_t diagonal = 0; static const char* keywords[2] = {"diagonal", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i:diag", const_cast<char**>(keywords), &diagonal)) { return NULL; } return PyTensor_New(ASSERT_PTR(functional::Diag(PyTensor_Unpack(self), diagonal))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_diagonal(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS int32_t offset = 0; int32_t dim1 = 0; int32_t dim2 = 1; static const char* keywords[4] = {"offset", "dim1", "dim2", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iii:diagonal", const_cast<char**>(keywords), &offset, &dim1, &dim2)) { return NULL; } return PyTensor_New(ASSERT_PTR(functional::Diagonal(PyTensor_Unpack(self), offset, dim1, dim2))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_matmul(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* other = NULL; static const char* keywords[2] = {"other", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:matmul", const_cast<char**>(keywords), &other)) { return NULL; } PyObjectPtr concat_args(PyTuple_Pack(2, self, other)); PyObject* result = functional::matmul(NULL, concat_args.get(), NULL); if (PyErr_Occurred()) { throw py::error_already_set(); } return result; END_HANDLE_ERRORS } static PyObject* PyTensorObject_reshape(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* shape = PyParseArgs(args, kwargs, "reshape", "shape"); PyObjectPtr _args = PyObjectPtr(PyTuple_Pack(2, self, shape)); PyObject* result = functional::reshape(NULL, _args.get(), NULL); if (PyErr_Occurred()) { throw py::error_already_set(); } return result; END_HANDLE_ERRORS } static PyObject* PyTensorObject_reshape_as(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS auto tensor = PyTensor_Unpack(self); PyObject* other = NULL; static const char* keywords[2] = {"other", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|:reshape_as", const_cast<char**>(keywords), &other)) { return NULL; } return PyTensor_New(ASSERT_PTR(functional::Reshape(tensor, *PyTensor_Unpack(other)->shape()))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_cpu(PyObject* self, PyObject* unused) { HANDLE_ERRORS Optional<std::string> device = "cpu"; return PyTensor_New(ASSERT_PTR(functional::To(PyTensor_Unpack(self), device, NullOpt, false))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_cuda(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* device_obj = Py_None; static const char* keywords[2] = {"device", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs,
"|O:cuda", const_cast<char**>(keywords), &device_obj)) { return NULL; } auto tensor = PyTensor_Unpack(self); if (functional::PyDeviceCheck(device_obj)) { Optional<Symbol<Device>> device = functional::PyUnpackDevice(device_obj); return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, NullOpt, false))); } Optional<std::string> device_str; if (device_obj == Py_None) { device_str = "cuda"; } else if (PyLong_Check(device_obj)) { device_str = "cuda:" + std::to_string(PyLong_AsLongLong(device_obj)); } return PyTensor_New(ASSERT_PTR(functional::To(tensor, device_str, tensor->dtype(), false))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_var(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* dim_obj = Py_None; PyObject* unbiased_obj = Py_True; PyObject* keepdim_obj = Py_False; static const char* keywords[4] = {"dim", "unbiased", "keepdim", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO!O!:var", const_cast<char**>(keywords), &dim_obj, &PyBool_Type, &unbiased_obj, &PyBool_Type, &keepdim_obj)) { return NULL; } bool unbiased = unbiased_obj == Py_True; bool keepdim = keepdim_obj == Py_True; CHECK_OR_THROW(dim_obj == Py_None || PyLong_Check(dim_obj) || functional::PyLongSequenceCheck(dim_obj)) << Error::TypeError() << "var(): argument 'dim' must be int32 list, not " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dim_obj))); auto tensor = PyTensor_Unpack(self); if (dim_obj == Py_None) { return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, NullOpt, unbiased, keepdim))); } std::vector<int32_t> dim; if (PyLong_Check(dim_obj)) { dim.emplace_back(static_cast<int32_t>(PyLong_AsLong(dim_obj))); return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, dim, unbiased, keepdim))); } dim = functional::PyUnpackLongSequence(dim_obj); return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, dim, unbiased, keepdim))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_std(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* dim_obj = Py_None; PyObject* unbiased_obj = Py_True; PyObject* keepdim_obj = Py_False; static const char* keywords[4] = {"dim", "unbiased", "keepdim", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO!O!:std", const_cast<char**>(keywords), &dim_obj, &PyBool_Type, &unbiased_obj, &PyBool_Type, &keepdim_obj)) { return NULL; } bool unbiased = unbiased_obj == Py_True; bool keepdim = keepdim_obj == Py_True; CHECK_OR_THROW(dim_obj == Py_None || PyLong_Check(dim_obj) || functional::PyLongSequenceCheck(dim_obj)) << Error::TypeError() << "std(): argument 'dim' must be int32 list, not " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dim_obj))); auto tensor = PyTensor_Unpack(self); if (dim_obj == Py_None) { return PyTensor_New( ASSERT_PTR(functional::StandardDeviation(tensor, NullOpt, unbiased, keepdim))); } std::vector<int32_t> dim; if (PyLong_Check(dim_obj)) { dim.emplace_back(static_cast<int32_t>(PyLong_AsLong(dim_obj))); return PyTensor_New(ASSERT_PTR(functional::StandardDeviation(tensor, dim, unbiased, keepdim))); } dim = functional::PyUnpackLongSequence(dim_obj); return PyTensor_New(ASSERT_PTR(functional::StandardDeviation(tensor, dim, unbiased, keepdim))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_softplus(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS double beta = 1.0; double threshold = 20.0; static const char* keywords[3] = {"beta", "threshold", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "dd:softplus", const_cast<char**>(keywords), &beta, &threshold)) { return NULL; } return
PyTensor_New(ASSERT_PTR(functional::Softplus(PyTensor_Unpack(self), beta, threshold))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_relu(PyObject* self, PyObject* unused) { HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(functional::Relu(PyTensor_Unpack(self), false))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_relu_(PyObject* self, PyObject* unused) { HANDLE_ERRORS return PyTensor_New(ASSERT_PTR(functional::Relu(PyTensor_Unpack(self), true))); END_HANDLE_ERRORS } #define REDUCE_FUNC(func_name, bind_func, whole_func) \ static PyObject* func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \ HANDLE_ERRORS \ if ((args == NULL || PyTuple_Size(args) == 0) \ && (kwargs == NULL || PyDict_Size(kwargs) == 0)) { \ return PyTensor_New(ASSERT_PTR(whole_func(PyTensor_Unpack(self)))); \ } \ PyObjectPtr concat_args(concat_self(self, args)); \ PyObject* result = bind_func(NULL, concat_args.get(), kwargs); \ if (PyErr_Occurred()) { throw py::error_already_set(); } \ return result; \ END_HANDLE_ERRORS \ } REDUCE_FUNC(PyTensorObject_any, functional::reduce_any, functional::ReduceAnyWhole) REDUCE_FUNC(PyTensorObject_all, functional::reduce_all, functional::ReduceAllWhole) REDUCE_FUNC(PyTensorObject_mean, functional::reduce_mean, functional::ReduceMeanWhole) #define DATATYPE_FUNC(func_name, dtype) \ static PyObject* func_name(PyObject* self, PyObject* unused) { \ HANDLE_ERRORS \ auto tensor = PyTensor_Unpack(self); \ return PyTensor_New(ASSERT_PTR(functional::To(tensor, dtype, false))); \ END_HANDLE_ERRORS \ } DATATYPE_FUNC(PyTensorObject_bool, DType::Bool()); DATATYPE_FUNC(PyTensorObject_int, DType::Int32()); DATATYPE_FUNC(PyTensorObject_long, DType::Int64()); DATATYPE_FUNC(PyTensorObject_half, DType::Float16()); DATATYPE_FUNC(PyTensorObject_float, DType::Float()); DATATYPE_FUNC(PyTensorObject_double, DType::Double()); DATATYPE_FUNC(PyTensorObject_bfloat16, DType::BFloat16()); static PyObject* PyTensorObject_view(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* size = PyParseArgs(args, kwargs, "view", "size"); PyObjectPtr _args = PyObjectPtr(PyTuple_Pack(2, self, size)); PyObject* result = functional::view(NULL, _args.get(), NULL); if (PyErr_Occurred()) { throw py::error_already_set(); } return result; END_HANDLE_ERRORS } static PyObject* PyTensorObject_view_as(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS auto tensor = PyTensor_Unpack(self); PyObject* other = NULL; static const char* keywords[2] = {"other", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|:view_as", const_cast(keywords), &other)) { return NULL; } return PyTensor_New(ASSERT_PTR(functional::View(tensor, *PyTensor_Unpack(other)->shape()))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_permute(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* dims = PyParseArgs(args, kwargs, "permute", "dims"); PyObjectPtr _args = PyObjectPtr(PyTuple_Pack(2, self, dims)); PyObject* result = functional::permute(NULL, _args.get(), NULL); if (PyErr_Occurred()) { throw py::error_already_set(); } return result; END_HANDLE_ERRORS } static PyObject* PyTensorObject_transpose(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS auto tensor = PyTensor_Unpack(self); int dim0 = 0; int dim1 = 0; static const char* keywords[3] = {"dim0", "dim1", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ii:transpose", const_cast(keywords), &dim0, &dim1)) { return NULL; } return PyTensor_New(ASSERT_PTR(functional::Transpose2dim(tensor, dim0, dim1))); 
END_HANDLE_ERRORS } static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS auto tensor = PyTensor_Unpack(self); CHECK_OR_THROW(tensor->is_local()) << Error::RuntimeError() << "input must be a local tensor"; PyObject* placement_obj = Py_None; PyObject* sbp_obj = Py_None; PyObject* check_meta_obj = Py_True; PyObject* copy_obj = Py_False; static const char* keywords[5] = {"placement", "sbp", "check_meta", "copy", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$O!O!:local_to_global", const_cast<char**>(keywords), &placement_obj, &sbp_obj, &PyBool_Type, &check_meta_obj, &PyBool_Type, &copy_obj)) { return NULL; } const bool check_meta = (check_meta_obj == Py_True); const bool copy = (copy_obj == Py_True); CHECK_OR_THROW(placement_obj != Py_None && sbp_obj != Py_None) << Error::InvalidValueError() << "Converting a local tensor to global tensor must have placement and sbp parameters."; CHECK_OR_THROW(functional::PyParallelDescCheck(placement_obj)) << Error::TypeError() << "Invalid parameter placement with type " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj))); std::vector<Symbol<SbpParallel>> sbp; if (functional::PySbpParallelCheck(sbp_obj)) { sbp.emplace_back(functional::PyUnpackSbpParallel(sbp_obj)); } else { CHECK_OR_THROW(functional::PySbpParallelSequenceCheck(sbp_obj)) << Error::TypeError() << "Invalid parameter sbp with type " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(sbp_obj))); sbp = functional::PyUnpackSbpParallelSequence(sbp_obj); } return PyTensor_New(ASSERT_PTR(functional::ToGlobal( tensor, functional::PyUnpackParallelDesc(placement_obj), sbp, {}, check_meta, copy))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS auto tensor = PyTensor_Unpack(self); CHECK_OR_THROW(tensor->is_global()) << Error::RuntimeError() << "input must be a global tensor"; PyObject* placement_obj = Py_None; PyObject* sbp_obj = Py_None; PyObject* grad_sbp_obj = Py_None; Symbol<ParallelDesc> placement; std::vector<Symbol<SbpParallel>> sbp; std::vector<Symbol<SbpParallel>> grad_sbp; PyObject* check_meta_obj = Py_False; PyObject* copy_obj = Py_False; static const char* keywords[6] = {"placement", "sbp", "grad_sbp", "check_meta", "copy", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$OO!O!:global_to_global", const_cast<char**>(keywords), &placement_obj, &sbp_obj, &grad_sbp_obj, &PyBool_Type, &check_meta_obj, &PyBool_Type, &copy_obj)) { return NULL; } const bool check_meta = (check_meta_obj == Py_True); const bool copy = (copy_obj == Py_True); // sbp CHECK_OR_THROW(sbp_obj == Py_None || functional::PySbpParallelCheck(sbp_obj) || functional::PySbpParallelSequenceCheck(sbp_obj)) << Error::TypeError() << "sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp"; if (functional::PySbpParallelCheck(sbp_obj)) { sbp.emplace_back(functional::PyUnpackSbpParallel(sbp_obj)); } else if (functional::PySbpParallelSequenceCheck(sbp_obj)) { sbp = functional::PyUnpackSbpParallelSequence(sbp_obj); } else { for (int32_t i = 0; i < ASSERT(tensor->nd_sbp())->sbp_parallel_size(); i++) sbp.emplace_back(ASSERT(tensor->nd_sbp())->sbp_parallel(i)); } // placement CHECK_OR_THROW(placement_obj == Py_None || functional::PyParallelDescCheck(placement_obj)) << Error::TypeError() << "Invalid parameter placement with type " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj))); if (placement_obj == Py_None) { placement = ASSERT(tensor->parallel_desc()); } else { placement =
functional::PyUnpackParallelDesc(placement_obj); } // grad_sbp CHECK_OR_THROW(grad_sbp_obj == Py_None || functional::PySbpParallelCheck(grad_sbp_obj) || functional::PySbpParallelSequenceCheck(grad_sbp_obj)) << Error::TypeError() << "grad_sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp"; if (functional::PySbpParallelCheck(grad_sbp_obj)) { grad_sbp.emplace_back(functional::PyUnpackSbpParallel(grad_sbp_obj)); } else if (functional::PySbpParallelSequenceCheck(grad_sbp_obj)) { grad_sbp = functional::PyUnpackSbpParallelSequence(grad_sbp_obj); } return PyTensor_New( ASSERT_PTR(functional::ToGlobal(tensor, placement, sbp, grad_sbp, check_meta, copy))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_to_global(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS const auto& tensor = PyTensor_Unpack(self); PyObject* result = NULL; if (tensor->is_global()) result = PyTensorObject_global_to_global(self, args, kwargs); else { result = PyTensorObject_local_to_global(self, args, kwargs); } if (PyErr_Occurred()) { throw py::error_already_set(); } return result; END_HANDLE_ERRORS } static PyObject* PyTensorObject_to_local(PyObject* self, PyObject* unused, PyObject* kwargs) { HANDLE_ERRORS auto tensor = PyTensor_Unpack(self); CHECK_OR_THROW(tensor->is_global()) << Error::RuntimeError() << "Expected global tensor for to_local but got local tensor!"; bool copy = false; static const char* keywords[2] = {"copy", NULL}; if (!PyArg_ParseTupleAndKeywords(unused, kwargs, "|$O!:to_local", const_cast<char**>(keywords), &PyBool_Type, &copy)) { return NULL; }; return PyTensor_New(ASSERT_PTR(functional::GlobalToLocal(tensor, /*copy=*/copy))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_type_as(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS auto self_tensor = PyTensor_Unpack(self); PyObject* other = NULL; static const char* keywords[2] = {"other", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|:type_as", const_cast<char**>(keywords), &other)) { return NULL; } // target is local auto other_tensor = PyTensor_Unpack(other); if (other_tensor->is_local()) { Optional<Symbol<Device>> device = ASSERT(other_tensor->device()); if (self_tensor->is_global()) { self_tensor = ASSERT_PTR(functional::GlobalToLocal(self_tensor, /*copy=*/false)); } return PyTensor_New( ASSERT_PTR(functional::To(self_tensor, device, other_tensor->dtype(), /*copy=*/false))); } // target is global std::shared_ptr<Tensor> value_tensor; value_tensor = ASSERT_PTR(functional::To(self_tensor, other_tensor->dtype(), /*copy=*/false)); Symbol<ParallelDesc> placement = ASSERT(other_tensor->parallel_desc()); std::vector<Symbol<SbpParallel>> sbp; auto ndsbp = ASSERT(other_tensor->nd_sbp()); for (int32_t i = 0; i < ndsbp->sbp_parallel_size(); i++) { sbp.emplace_back(ndsbp->sbp_parallel(i)); } return PyTensor_New( ASSERT_PTR(functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false))); END_HANDLE_ERRORS } static PyObject* PyTensorObject_new(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS auto self_tensor = PyTensor_Unpack(self); if (!kwargs) { if (PyTuple_Size(args) == 1 && PyTensor_Check(PyTuple_GET_ITEM(args, 0))) { // tensor.new(other) auto other_tensor = PyTensor_Unpack(PyTuple_GET_ITEM(args, 0)); CHECK_OR_THROW(!self_tensor->is_global() && !other_tensor->is_global()) << "Tensor.new(Tensor) only support local tensor."; CHECK_OR_THROW(self_tensor->dtype() == other_tensor->dtype()) << "Tensor.new() expect " << self_tensor->dtype()->name() << " dtype tensor, but got " << other_tensor->dtype()->name() << " dtype tensor.";
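// Note: Tensor.new(other) is only supported for local tensors; `other` must match `self` in
// dtype (checked above) and in device type (checked below), and the new tensor is then built
// via functional::TensorWithOtherCtor(other_tensor).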
CHECK_OR_THROW(ASSERT(self_tensor->device())->enum_type() == ASSERT(other_tensor->device())->enum_type()) << "Tensor.new() expect tensor on " << ASSERT(self_tensor->device())->type() << ", but got tensor on " << ASSERT(other_tensor->device())->type() << "."; return PyTensor_New(ASSERT_PTR(functional::TensorWithOtherCtor(other_tensor))); } kwargs = PyDict_New(); } PyObjectPtr dtype_key(PyUnicode_FromString("dtype")); PyObjectPtr dtype_value(functional::CastToPyObject(self_tensor->dtype())); CHECK_OR_THROW(PyDict_Contains(kwargs, dtype_key.get()) < 1); CHECK_OR_THROW(PyDict_SetItemString(kwargs, "dtype", dtype_value.get()) > -1); if (self_tensor->is_global()) { PyObjectPtr placement_key(PyUnicode_FromString("placement")); PyObjectPtr sbp_key(PyUnicode_FromString("sbp")); CHECK_OR_THROW(PyDict_Contains(kwargs, placement_key.get()) < 1); CHECK_OR_THROW(PyDict_Contains(kwargs, sbp_key.get()) < 1); Symbol placement = ASSERT(self_tensor->parallel_desc()); std::vector> sbp; auto ndsbp = ASSERT(self_tensor->nd_sbp()); for (int32_t i = 0; i < ndsbp->sbp_parallel_size(); i++) { sbp.emplace_back(ndsbp->sbp_parallel(i)); } PyObjectPtr placement_value(functional::CastToPyObject(placement)); PyObjectPtr sbp_value(functional::CastToPyObject(sbp)); CHECK_OR_THROW(PyDict_SetItemString(kwargs, "placement", placement_value.get()) > -1); CHECK_OR_THROW(PyDict_SetItemString(kwargs, "sbp", sbp_value.get()) > -1); } else { auto device = ASSERT(self_tensor->device()); PyObjectPtr device_key(PyUnicode_FromString("device")); CHECK_OR_THROW(PyDict_Contains(kwargs, device_key.get()) < 1) << "Some of the keywords were incorrect: device"; PyObjectPtr device_value(functional::CastToPyObject(device)); CHECK_OR_THROW(PyDict_SetItemString(kwargs, "device", device_value.get()) > -1); } return functional::_legacy_tensor_generic_ctor(NULL, args, kwargs); END_HANDLE_ERRORS } int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) { HANDLE_ERRORS CHECK_OR_THROW(functional::PyTensorIndexCheck(item)) << Error::TypeError() << "tensor_setitem(): argument 'index' must be index, not " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(item))); CHECK_OR_THROW(functional::PyScalarCheck(value) || PyTensor_Check(value)) << Error::TypeError() << "tensor_setitem(): argument 'value' must be tensor or scalar, not " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(value))); const auto& index_item = functional::PyUnpackTensorIndex(item); auto tensor = PyTensor_Unpack(self); // NOTE: use masked_fill_(local,global) to avoid D2H in TensorSetItem if index is bool tensor if (functional::PyScalarCheck(value) && index_item.size() == 1 && index_item[0].IsTensor()) { const auto& index_tensor = index_item[0].tensor(); if (index_tensor->shape() == tensor->shape() && (index_tensor->dtype() == DType::Bool() || index_tensor->dtype() == DType::UInt8())) { ASSERT_PTR( functional::MaskedFillInplace(tensor, index_tensor, functional::PyUnpackScalar(value))); return 0; } } std::shared_ptr value_tensor; { if (tensor->is_global()) { Symbol placement = ASSERT(tensor->parallel_desc()); auto ndsbp = ASSERT(tensor->nd_sbp()); std::vector> sbp(ndsbp->sbp_parallel_size(), ASSERT(MakeBroadcastSbpParallel())); if (functional::PyScalarCheck(value)) { Scalar value_scalar = functional::PyUnpackScalar(value); value_tensor = ASSERT_PTR( functional::GlobalConstant(Shape({}), value_scalar, tensor->dtype(), placement, sbp)); } else { value_tensor = PyTensor_Unpack(value); CHECK_OR_THROW(value_tensor->is_global()) << 
Error::RuntimeError() << "tensor_setitem(): value must be a global tensor when self is global"; value_tensor = ASSERT_PTR( functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false)); } } else { if (functional::PyScalarCheck(value)) { // NOTE: initialize value_tensor in eager mode LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled=*/false); Scalar value_scalar = functional::PyUnpackScalar(value); value_tensor = ASSERT_PTR(functional::Constant(Shape({}), value_scalar, tensor->dtype(), ASSERT(tensor->device()))); } else { value_tensor = PyTensor_Unpack(value); CHECK_OR_THROW(value_tensor->is_local()) << Error::RuntimeError() << "tensor_setitem(): value must be a local tensor when self is local"; Optional> device = ASSERT(tensor->device()); value_tensor = ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false)); } } } ASSERT(functional::TensorSetItem(tensor, index_item, value_tensor)); return 0; END_HANDLE_ERRORS_RET(-1) } PyMethodDef PyTensorObject_extra_methods[] = { {"byte", PyTensorObject_byte, METH_NOARGS, NULL}, {"size", (PyCFunction)PyTensorObject_size, METH_VARARGS | METH_KEYWORDS, NULL}, {"argmax", (PyCFunction)PyTensorObject_argmax, METH_VARARGS | METH_KEYWORDS, NULL}, {"argmin", (PyCFunction)PyTensorObject_argmin, METH_VARARGS | METH_KEYWORDS, NULL}, {"amin", (PyCFunction)PyTensorObject_amin, METH_VARARGS | METH_KEYWORDS, NULL}, {"dim", PyTensorObject_dim, METH_NOARGS, NULL}, {"ndimension", PyTensorObject_dim, METH_NOARGS, NULL}, {"nelement", PyTensorObject_nelement, METH_NOARGS, NULL}, {"numel", PyTensorObject_nelement, METH_NOARGS, NULL}, {"element_size", PyTensorObject_element_size, METH_NOARGS, NULL}, {"get_device", PyTensorObject_get_device, METH_NOARGS, NULL}, {"cast", (PyCFunction)PyTensorObject_cast, METH_VARARGS | METH_KEYWORDS, NULL}, {"diag", (PyCFunction)PyTensorObject_diag, METH_VARARGS | METH_KEYWORDS, NULL}, {"diagonal", (PyCFunction)PyTensorObject_diagonal, METH_VARARGS | METH_KEYWORDS, NULL}, {"addcmul", (PyCFunction)PyTensorObject_addcmul, METH_VARARGS | METH_KEYWORDS, NULL}, {"addcmul_", (PyCFunction)PyTensorObject_addcmul_, METH_VARARGS | METH_KEYWORDS, NULL}, {"addcdiv", (PyCFunction)PyTensorObject_addcdiv, METH_VARARGS | METH_KEYWORDS, NULL}, {"addcdiv_", (PyCFunction)PyTensorObject_addcdiv_, METH_VARARGS | METH_KEYWORDS, NULL}, {"matmul", (PyCFunction)PyTensorObject_matmul, METH_VARARGS | METH_KEYWORDS, NULL}, {"bool", PyTensorObject_bool, METH_NOARGS, NULL}, {"int", PyTensorObject_int, METH_NOARGS, NULL}, {"long", PyTensorObject_long, METH_NOARGS, NULL}, {"half", PyTensorObject_half, METH_NOARGS, NULL}, {"float", PyTensorObject_float, METH_NOARGS, NULL}, {"double", PyTensorObject_double, METH_NOARGS, NULL}, {"bfloat16", PyTensorObject_bfloat16, METH_NOARGS, NULL}, {"local_to_global", (PyCFunction)PyTensorObject_local_to_global, METH_VARARGS | METH_KEYWORDS, NULL}, {"global_to_global", (PyCFunction)PyTensorObject_global_to_global, METH_VARARGS | METH_KEYWORDS, NULL}, {"to_local", (PyCFunction)PyTensorObject_to_local, METH_VARARGS | METH_KEYWORDS, NULL}, {"to_global", (PyCFunction)PyTensorObject_to_global, METH_VARARGS | METH_KEYWORDS, NULL}, {"type_as", (PyCFunction)PyTensorObject_type_as, METH_VARARGS | METH_KEYWORDS, NULL}, {"cpu", PyTensorObject_cpu, METH_NOARGS, NULL}, {"cuda", (PyCFunction)PyTensorObject_cuda, METH_VARARGS | METH_KEYWORDS, NULL}, {"var", (PyCFunction)PyTensorObject_var, METH_VARARGS | METH_KEYWORDS, NULL}, {"std", (PyCFunction)PyTensorObject_std, METH_VARARGS | METH_KEYWORDS, NULL}, {"softplus", 
(PyCFunction)PyTensorObject_softplus, METH_VARARGS | METH_KEYWORDS, NULL}, {"relu", PyTensorObject_relu, METH_NOARGS, NULL}, {"relu_", PyTensorObject_relu_, METH_NOARGS, NULL}, {"all", (PyCFunction)PyTensorObject_all, METH_VARARGS | METH_KEYWORDS, NULL}, {"any", (PyCFunction)PyTensorObject_any, METH_VARARGS | METH_KEYWORDS, NULL}, {"sum", (PyCFunction)PyTensorObject_sum, METH_VARARGS | METH_KEYWORDS, NULL}, {"mean", (PyCFunction)PyTensorObject_mean, METH_VARARGS | METH_KEYWORDS, NULL}, {"new", (PyCFunction)PyTensorObject_new, METH_VARARGS | METH_KEYWORDS, NULL}, // macro DIRECT_PASS_FUNC {"floor_divide", (PyCFunction)PyTensorObject_floor_divide, METH_VARARGS | METH_KEYWORDS, NULL}, {"atan2", (PyCFunction)PyTensorObject_atan2, METH_VARARGS | METH_KEYWORDS, NULL}, {"equal", (PyCFunction)PyTensorObject_equal, METH_VARARGS | METH_KEYWORDS, NULL}, {"gt", (PyCFunction)PyTensorObject_gt, METH_VARARGS | METH_KEYWORDS, NULL}, {"gt_", (PyCFunction)PyTensorObject_gt_, METH_VARARGS | METH_KEYWORDS, NULL}, {"frac", (PyCFunction)PyTensorObject_frac, METH_VARARGS | METH_KEYWORDS, NULL}, {"frac_", (PyCFunction)PyTensorObject_frac_, METH_VARARGS | METH_KEYWORDS, NULL}, {"ge", (PyCFunction)PyTensorObject_ge, METH_VARARGS | METH_KEYWORDS, NULL}, {"div", (PyCFunction)PyTensorObject_div, METH_VARARGS | METH_KEYWORDS, NULL}, {"div_", (PyCFunction)PyTensorObject_div_, METH_VARARGS | METH_KEYWORDS, NULL}, {"mul", (PyCFunction)PyTensorObject_mul, METH_VARARGS | METH_KEYWORDS, NULL}, {"mul_", (PyCFunction)PyTensorObject_mul_, METH_VARARGS | METH_KEYWORDS, NULL}, {"fmod", (PyCFunction)PyTensorObject_fmod, METH_VARARGS | METH_KEYWORDS, NULL}, {"logical_and", (PyCFunction)PyTensorObject_logical_and, METH_VARARGS | METH_KEYWORDS, NULL}, {"logical_or", (PyCFunction)PyTensorObject_logical_or, METH_VARARGS | METH_KEYWORDS, NULL}, {"logical_xor", (PyCFunction)PyTensorObject_logical_xor, METH_VARARGS | METH_KEYWORDS, NULL}, {"bmm", (PyCFunction)PyTensorObject_bmm, METH_VARARGS | METH_KEYWORDS, NULL}, {"ne", (PyCFunction)PyTensorObject_ne, METH_VARARGS | METH_KEYWORDS, NULL}, {"lt", (PyCFunction)PyTensorObject_lt, METH_VARARGS | METH_KEYWORDS, NULL}, {"le", (PyCFunction)PyTensorObject_le, METH_VARARGS | METH_KEYWORDS, NULL}, {"flip", (PyCFunction)PyTensorObject_flip, METH_VARARGS | METH_KEYWORDS, NULL}, {"clip", (PyCFunction)PyTensorObject_clip, METH_VARARGS | METH_KEYWORDS, NULL}, {"clip_", (PyCFunction)PyTensorObject_clip_, METH_VARARGS | METH_KEYWORDS, NULL}, {"clamp", (PyCFunction)PyTensorObject_clamp, METH_VARARGS | METH_KEYWORDS, NULL}, {"clamp_min", (PyCFunction)PyTensorObject_clamp_min, METH_VARARGS | METH_KEYWORDS, NULL}, {"clamp_max", (PyCFunction)PyTensorObject_clamp_max, METH_VARARGS | METH_KEYWORDS, NULL}, {"clamp_", (PyCFunction)PyTensorObject_clamp_, METH_VARARGS | METH_KEYWORDS, NULL}, {"clamp_min_", (PyCFunction)PyTensorObject_clamp_min_, METH_VARARGS | METH_KEYWORDS, NULL}, {"clamp_max_", (PyCFunction)PyTensorObject_clamp_max_, METH_VARARGS | METH_KEYWORDS, NULL}, {"flatten", (PyCFunction)PyTensorObject_flatten, METH_VARARGS | METH_KEYWORDS, NULL}, {"in_top_k", (PyCFunction)PyTensorObject_in_top_k, METH_VARARGS | METH_KEYWORDS, NULL}, {"index_select", (PyCFunction)PyTensorObject_index_select, METH_VARARGS | METH_KEYWORDS, NULL}, {"maximum", (PyCFunction)PyTensorObject_maximum, METH_VARARGS | METH_KEYWORDS, NULL}, {"minimum", (PyCFunction)PyTensorObject_minimum, METH_VARARGS | METH_KEYWORDS, NULL}, {"tril", (PyCFunction)PyTensorObject_tril, METH_VARARGS | METH_KEYWORDS, NULL}, {"tril_", 
(PyCFunction)PyTensorObject_tril_, METH_VARARGS | METH_KEYWORDS, NULL}, {"triu", (PyCFunction)PyTensorObject_triu, METH_VARARGS | METH_KEYWORDS, NULL}, {"triu_", (PyCFunction)PyTensorObject_triu_, METH_VARARGS | METH_KEYWORDS, NULL}, {"softmax", (PyCFunction)PyTensorObject_softmax, METH_VARARGS | METH_KEYWORDS, NULL}, {"log_softmax", (PyCFunction)PyTensorObject_log_softmax, METH_VARARGS | METH_KEYWORDS, NULL}, {"roll", (PyCFunction)PyTensorObject_roll, METH_VARARGS | METH_KEYWORDS, NULL}, {"unbind", (PyCFunction)PyTensorObject_unbind, METH_VARARGS | METH_KEYWORDS, NULL}, {"squeeze", (PyCFunction)PyTensorObject_squeeze, METH_VARARGS | METH_KEYWORDS, NULL}, {"swapaxes", (PyCFunction)PyTensorObject_swapaxes, METH_VARARGS | METH_KEYWORDS, NULL}, {"amax", (PyCFunction)PyTensorObject_amax, METH_VARARGS | METH_KEYWORDS, NULL}, {"swapdims", (PyCFunction)PyTensorObject_swapdims, METH_VARARGS | METH_KEYWORDS, NULL}, {"unfold", (PyCFunction)PyTensorObject_unfold, METH_VARARGS | METH_KEYWORDS, NULL}, {"unsqueeze", (PyCFunction)PyTensorObject_unsqueeze, METH_VARARGS | METH_KEYWORDS, NULL}, {"max", (PyCFunction)PyTensorObject_max, METH_VARARGS | METH_KEYWORDS, NULL}, {"min", (PyCFunction)PyTensorObject_min, METH_VARARGS | METH_KEYWORDS, NULL}, {"median", (PyCFunction)PyTensorObject_median, METH_VARARGS | METH_KEYWORDS, NULL}, {"mode", (PyCFunction)PyTensorObject_mode, METH_VARARGS | METH_KEYWORDS, NULL}, {"pow", (PyCFunction)PyTensorObject_pow, METH_VARARGS | METH_KEYWORDS, NULL}, {"chunk", (PyCFunction)PyTensorObject_chunk, METH_VARARGS | METH_KEYWORDS, NULL}, {"split", (PyCFunction)PyTensorObject_split, METH_VARARGS | METH_KEYWORDS, NULL}, {"narrow", (PyCFunction)PyTensorObject_narrow, METH_VARARGS | METH_KEYWORDS, NULL}, {"masked_fill", (PyCFunction)PyTensorObject_masked_fill, METH_VARARGS | METH_KEYWORDS, NULL}, {"masked_fill_", (PyCFunction)PyTensorObject_masked_fill_, METH_VARARGS | METH_KEYWORDS, NULL}, {"dot", (PyCFunction)PyTensorObject_dot, METH_VARARGS | METH_KEYWORDS, NULL}, {"nansum", (PyCFunction)PyTensorObject_nansum, METH_VARARGS | METH_KEYWORDS, NULL}, {"sum", (PyCFunction)PyTensorObject_sum, METH_VARARGS | METH_KEYWORDS, NULL}, {"bernoulli", (PyCFunction)PyTensorObject_bernoulli, METH_VARARGS | METH_KEYWORDS, NULL}, {"bernoulli_", (PyCFunction)PyTensorObject_bernoulli_, METH_VARARGS | METH_KEYWORDS, NULL}, {"bincount", (PyCFunction)PyTensorObject_bincount, METH_VARARGS | METH_KEYWORDS, NULL}, {"isclose", (PyCFunction)PyTensorObject_isclose, METH_VARARGS | METH_KEYWORDS, NULL}, {"broadcast_to", (PyCFunction)PyTensorObject_broadcast_to, METH_VARARGS | METH_KEYWORDS, NULL}, {"lerp", (PyCFunction)PyTensorObject_lerp, METH_VARARGS | METH_KEYWORDS, NULL}, {"lerp_", (PyCFunction)PyTensorObject_lerp_, METH_VARARGS | METH_KEYWORDS, NULL}, {"unique", (PyCFunction)PyTensorObject_unique, METH_VARARGS | METH_KEYWORDS, NULL}, {"topk", (PyCFunction)PyTensorObject_topk, METH_VARARGS | METH_KEYWORDS, NULL}, {"bitwise_and", (PyCFunction)PyTensorObject_bitwise_and, METH_VARARGS | METH_KEYWORDS, NULL}, {"bitwise_or", (PyCFunction)PyTensorObject_bitwise_or, METH_VARARGS | METH_KEYWORDS, NULL}, {"bitwise_xor", (PyCFunction)PyTensorObject_bitwise_xor, METH_VARARGS | METH_KEYWORDS, NULL}, {"baddbmm", (PyCFunction)PyTensorObject_baddbmm, METH_VARARGS | METH_KEYWORDS, NULL}, {"mm", (PyCFunction)PyTensorObject_mm, METH_VARARGS | METH_KEYWORDS, NULL}, {"sub", (PyCFunction)PyTensorObject_sub, METH_VARARGS | METH_KEYWORDS, NULL}, {"mv", (PyCFunction)PyTensorObject_mv, METH_VARARGS | METH_KEYWORDS, NULL}, {"fill_", 
(PyCFunction)PyTensorObject_fill_, METH_VARARGS | METH_KEYWORDS, NULL}, {"gather", (PyCFunction)PyTensorObject_gather, METH_VARARGS | METH_KEYWORDS, NULL}, {"repeat_interleave", (PyCFunction)PyTensorObject_repeat_interleave, METH_VARARGS | METH_KEYWORDS, NULL}, {"scatter_add", (PyCFunction)PyTensorObject_scatter_add, METH_VARARGS | METH_KEYWORDS, NULL}, {"logaddexp", (PyCFunction)PyTensorObject_logaddexp, METH_VARARGS | METH_KEYWORDS, NULL}, // macro UNARY_METHOD {"abs", PyTensorObject_abs, METH_NOARGS, NULL}, {"digamma", PyTensorObject_digamma, METH_NOARGS, NULL}, {"exp", PyTensorObject_exp, METH_NOARGS, NULL}, {"exp2", PyTensorObject_exp2, METH_NOARGS, NULL}, {"floor", PyTensorObject_floor, METH_NOARGS, NULL}, {"floor_", PyTensorObject_floor_, METH_NOARGS, NULL}, {"acos", PyTensorObject_acos, METH_NOARGS, NULL}, {"arccos", PyTensorObject_acos, METH_NOARGS, NULL}, {"acosh", PyTensorObject_acosh, METH_NOARGS, NULL}, {"arccosh", PyTensorObject_acosh, METH_NOARGS, NULL}, {"atanh", PyTensorObject_atanh, METH_NOARGS, NULL}, {"arctanh", PyTensorObject_atanh, METH_NOARGS, NULL}, {"sign", PyTensorObject_sign, METH_NOARGS, NULL}, {"sinh", PyTensorObject_sinh, METH_NOARGS, NULL}, {"tan", PyTensorObject_tan, METH_NOARGS, NULL}, {"gelu", PyTensorObject_gelu, METH_NOARGS, NULL}, {"mish", PyTensorObject_mish, METH_NOARGS, NULL}, {"negative", PyTensorObject_negative, METH_NOARGS, NULL}, {"neg", PyTensorObject_negative, METH_NOARGS, NULL}, {"sigmoid", PyTensorObject_sigmoid, METH_NOARGS, NULL}, {"tanh", PyTensorObject_tanh, METH_NOARGS, NULL}, {"silu", PyTensorObject_silu, METH_NOARGS, NULL}, {"selu", PyTensorObject_selu, METH_NOARGS, NULL}, {"softsign", PyTensorObject_softsign, METH_NOARGS, NULL}, {"log1p", PyTensorObject_log1p, METH_NOARGS, NULL}, {"log2", PyTensorObject_log2, METH_NOARGS, NULL}, {"log10", PyTensorObject_log10, METH_NOARGS, NULL}, {"reciprocal", PyTensorObject_reciprocal, METH_NOARGS, NULL}, {"asin", PyTensorObject_asin, METH_NOARGS, NULL}, {"arcsin", PyTensorObject_asin, METH_NOARGS, NULL}, {"asinh", PyTensorObject_asinh, METH_NOARGS, NULL}, {"arcsinh", PyTensorObject_asinh, METH_NOARGS, NULL}, {"atan", PyTensorObject_atan, METH_NOARGS, NULL}, {"arctan", PyTensorObject_atan, METH_NOARGS, NULL}, {"ceil", PyTensorObject_ceil, METH_NOARGS, NULL}, {"ceil_", PyTensorObject_ceil_, METH_NOARGS, NULL}, {"cos", PyTensorObject_cos, METH_NOARGS, NULL}, {"cosh", PyTensorObject_cosh, METH_NOARGS, NULL}, {"erf", PyTensorObject_erf, METH_NOARGS, NULL}, {"erfc", PyTensorObject_erfc, METH_NOARGS, NULL}, {"erfinv", PyTensorObject_erfinv, METH_NOARGS, NULL}, {"erfinv_", PyTensorObject_erfinv_, METH_NOARGS, NULL}, {"expm1", PyTensorObject_expm1, METH_NOARGS, NULL}, {"log", PyTensorObject_log, METH_NOARGS, NULL}, {"rsqrt", PyTensorObject_rsqrt, METH_NOARGS, NULL}, {"sqrt", PyTensorObject_sqrt, METH_NOARGS, NULL}, {"square", PyTensorObject_square, METH_NOARGS, NULL}, {"round", PyTensorObject_round, METH_NOARGS, NULL}, {"round_", PyTensorObject_round_, METH_NOARGS, NULL}, {"t", PyTensorObject_t, METH_NOARGS, NULL}, {"sin", PyTensorObject_sin, METH_NOARGS, NULL}, {"sin_", PyTensorObject_sin_, METH_NOARGS, NULL}, {"isnan", PyTensorObject_isnan, METH_NOARGS, NULL}, {"inverse", PyTensorObject_inv, METH_NOARGS, NULL}, {"trunc", PyTensorObject_trunc, METH_NOARGS, NULL}, {"isinf", PyTensorObject_isinf, METH_NOARGS, NULL}, {"logical_not", PyTensorObject_logical_not, METH_NOARGS, NULL}, {"floor", PyTensorObject_floor, METH_NOARGS, NULL}, {"floor_", PyTensorObject_floor_, METH_NOARGS, NULL}, {"bitwise_not", 
(PyCFunction)PyTensorObject_bitwise_not, METH_NOARGS, NULL}, {"reshape", (PyCFunction)PyTensorObject_reshape, METH_VARARGS | METH_KEYWORDS, NULL}, {"reshape_as", (PyCFunction)PyTensorObject_reshape_as, METH_VARARGS | METH_KEYWORDS, NULL}, {"view", (PyCFunction)PyTensorObject_view, METH_VARARGS | METH_KEYWORDS, NULL}, {"view_as", (PyCFunction)PyTensorObject_view_as, METH_VARARGS | METH_KEYWORDS, NULL}, {"permute", (PyCFunction)PyTensorObject_permute, METH_VARARGS | METH_KEYWORDS, NULL}, {"transpose", (PyCFunction)PyTensorObject_transpose, METH_VARARGS | METH_KEYWORDS, NULL}, {"logsumexp", (PyCFunction)PyTensorObject_logsumexp, METH_VARARGS | METH_KEYWORDS, NULL}, {"quantile", (PyCFunction)PyTensorObject_quantile, METH_VARARGS | METH_KEYWORDS, NULL}, {NULL}, }; // tp_richcompare PyObject* PyTensorObject_richcompare(PyObject* self, PyObject* other, int op) { PyObjectPtr tuple(PyTuple_Pack(2, self, other)); switch (op) { case Py_LT: return functional::less(NULL, tuple.get(), NULL); case Py_LE: return functional::less_equal(NULL, tuple.get(), NULL); case Py_EQ: { if (self == Py_None || other == Py_None) Py_RETURN_FALSE; return functional::broadcast_equal(NULL, tuple.get(), NULL); } case Py_NE: { if (self == Py_None || other == Py_None) Py_RETURN_TRUE; return functional::not_equal(NULL, tuple.get(), NULL); } case Py_GT: return functional::greater(NULL, tuple.get(), NULL); case Py_GE: return functional::greater_equal(NULL, tuple.get(), NULL); } return NULL; } } // namespace one } // namespace oneflow #undef ASSERT #undef ASSERT_PTR ================================================ FILE: oneflow/api/python/framework/tensor_functions_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #undef _PyGC_FINALIZED #include #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/throw.h" namespace oneflow { namespace one { using functional::PyObjectPtr; std::string PyUnpack_String(PyObject* obj) { CHECK_OR_THROW(PyUnicode_Check(obj)) << "PyUnpack_String(): expect a PyUnicode object"; Py_ssize_t size = -1; const char* data = PyUnicode_AsUTF8AndSize(obj, &size); CHECK_NOTNULL_OR_THROW(data) << "error unpacking string as utf-8"; return std::string(data, (size_t)size); } // For signature like Tensor.reshape(*shape), this function can handle these cases: // 1. parse positional arguments only case, like Tensor.reshape(1, 2) // 2. parse keyword arguments only case, like Tensor.reshape(shape=(1, 2)) // 3. raise Error for multiple arguments case, like Tensor.reshape(1, shape=(1, )) // 4. 
return empty tuple for empty arguments, like Tensor.reshape() PyObject* PyParseArgs(PyObject* args, PyObject* kwargs, const char* func_name, const std::string& param_name) { PyObject* args_obj = NULL; // Tensor.reshape(shape=(1, 2)), get (1, 2) for kwargs["shape"] if (kwargs != NULL) { PyObject* key = nullptr; PyObject* value = nullptr; Py_ssize_t pos = 0; while (PyDict_Next(kwargs, &pos, &key, &value)) { CHECK_OR_THROW(args_obj == NULL) << Error::TypeError() << func_name << "() got multiple values for argument '" << param_name << "' or get invalid argument"; CHECK_EQ_OR_THROW(PyUnpack_String(key), param_name) << Error::TypeError() << func_name << "() got an unexpected keyword argument " << PyUnpack_String(key); args_obj = value; } } if (PyTuple_GET_SIZE(args) != 0) { CHECK_OR_THROW(args_obj == NULL) << Error::TypeError() << func_name << "() got multiple values for argument '" << param_name << "' or get invalid argument"; if (PyTuple_Size(args) == 1 && functional::PyShapeSequenceCheck(args)) { args_obj = PyTuple_GET_ITEM(args, 0); } else { args_obj = args; } }; if (args_obj == NULL) { args_obj = args; } return args_obj; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/tensor_tuple.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor.h" namespace py = pybind11; namespace oneflow { namespace one { namespace { struct TensorTupleUtil final { static std::string ToString(const TensorTuple& tensor_tuple) { std::stringstream ss; int32_t idx = 0; ss << "TensorTuple("; for (const std::shared_ptr& tensor : tensor_tuple) { ss << tensor; if (++idx != tensor_tuple.size() || tensor_tuple.size() == 1) { ss << ", "; } } ss << ")"; return ss.str(); } static void MergeFrom(std::shared_ptr& tensor_tuple, const TensorTuple& other) { for (const auto& tensor : other) { tensor_tuple->emplace_back(tensor); } } static void AppendTensor(std::shared_ptr& tensor_tuple, const std::shared_ptr& tensor) { tensor_tuple->emplace_back(tensor); } }; } // namespace ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "TensorTuple") .def(py::init([]() { return std::make_shared(); })) .def(py::init([](const std::shared_ptr& other) { return other; })) .def(py::init([](const std::vector>& list) { auto tensor_tuple = std::make_shared(); for (const auto& t : list) { tensor_tuple->emplace_back(t); } return tensor_tuple; })) .def("__str__", &TensorTupleUtil::ToString) .def("__repr__", &TensorTupleUtil::ToString) .def("__getitem__", [](const TensorTuple& tensor_tuple, int idx) { return tensor_tuple.at(idx); }) .def("__setitem__", [](std::shared_ptr& tensor_tuple, int idx, const std::shared_ptr& tensor) { tensor_tuple->at(idx) = tensor; }) .def( "__iter__", [](const TensorTuple& tensor_tuple) { return py::make_iterator(tensor_tuple.begin(), tensor_tuple.end()); }, py::keep_alive<0, 1>()) .def("__len__", [](const TensorTuple& tensor_tuple) { return tensor_tuple.size(); }) .def("merge_from", &TensorTupleUtil::MergeFrom) .def("append", &TensorTupleUtil::AppendTensor); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/framework/tensortype.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #undef _PyGC_FINALIZED #include #include "oneflow/api/python/framework/tensor.h" #include "oneflow/api/python/framework/tensortype.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/symbol.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/tensor_api.yaml.pybind.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/api/python/exception/exception.h" namespace oneflow { namespace one { #define ASSERT(x) (x).GetOrThrow() #define ASSERT_PTR(x) (x).GetPtrOrThrow() using functional::PyObjectPtr; static PyTypeObject PyTensorTypeMetaClass{ PyVarObject_HEAD_INIT(NULL, 0) "oneflow.tensortype", // tp_name sizeof(PyTypeObject), // tp_basicsize }; static PyTypeObject PyTensorTypeTemplate{ PyVarObject_HEAD_INIT(&PyTensorTypeMetaClass, 0) NULL, // tp_name sizeof(PyTensorType), // tp_basicsize }; static std::vector tensor_types; static const std::unordered_map, std::string> all_data_types = { {DType::Float(), "FloatTensor"}, {DType::Double(), "DoubleTensor"}, {DType::Int8(), "CharTensor"}, {DType::Int32(), "IntTensor"}, {DType::Int64(), "LongTensor"}, {DType::UInt8(), "ByteTensor"}, {DType::Float16(), "HalfTensor"}, {DType::BFloat16(), "BFloat16Tensor"}, {DType::Bool(), "BoolTensor"}, {DType::Complex32(), "ComplexHalfTensor"}, {DType::Complex64(), "ComplexFloatTensor"}, {DType::Complex128(), "ComplexDoubleTensor"}, {DType::Char(), "CharTensor"}, {DType::Int16(), "ShortTensor"}, }; static const std::string get_dtype_string(PyTensorType* tensortype) { return all_data_types.at(tensortype->dtype); } static std::vector> all_device_types = { {kCPU, "oneflow"}, {kCUDA, "oneflow.cuda"}, }; static PyObject* PyTensorTypeMetaCls_call(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS const auto& dtype = PyTensorType_UnpackDType(self); PyObjectPtr dtype_value(functional::CastToPyObject(dtype)); if (!kwargs) { kwargs = PyDict_New(); } else { const char* dtype_str = "dtype"; PyObjectPtr dtype_key(PyUnicode_FromString(dtype_str)); CHECK_OR_THROW(PyDict_Contains(kwargs, dtype_key.get()) < 1) << "Some of the keywords were incorrect: dtype"; } CHECK_OR_THROW(PyDict_SetItemString(kwargs, "dtype", dtype_value.get()) > -1); Maybe maybe_device = DeviceTag4DeviceType(PyTensorType_UnpackDevice(self)); if (!TRY(maybe_device).IsOk()) { return PyErr_Format(PyExc_ValueError, "invalid device"); } { const char* placement_str = "placement"; PyObjectPtr placement_key(PyUnicode_FromString(placement_str)); if (PyDict_Contains(kwargs, placement_key.get()) == 1) { // If creat global tensor, the device of TensorType will be cover by param placement // Raise a warning to inform users of using oneflow.Tensortype rather than // oneflow.xxx.Tensortype CHECK_OR_THROW(PyTensorType_UnpackDevice(self) == kCPU) << "`" << ((PyTensorType*)self)->name << "` can not creat a global tensor, consider use `oneflow." 
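// This tp_call of the tensortype metaclass is what makes the legacy constructors
// such as oneflow.FloatTensor(...) and oneflow.cuda.IntTensor(...) work: it injects
// the dtype carried by the type into kwargs and, unless a placement is given, the
// matching device, then forwards everything to the legacy tensor constructor. When
// a placement keyword is present the result is a global tensor and only the plain
// (CPU) oneflow.*Tensor types are accepted, which is what the check around this
// message enforces. Illustrative usage, not part of this file, assuming an
// installed oneflow build (and a CUDA build for the flow.cuda variants):
//   >>> import oneflow as flow
//   >>> t = flow.FloatTensor(2, 3)    # float32 tensor on the CPU
//   >>> t.dtype
//   oneflow.float32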
<< get_dtype_string((PyTensorType*)self) << "`"; } else { std::string device = ASSERT(maybe_device); PyObjectPtr device_value(PyUnicode_FromString(device.data())); CHECK_OR_THROW(PyDict_SetItemString(kwargs, "device", device_value.get()) > -1); } } auto* tensor = functional::_legacy_tensor_generic_ctor(NULL, args, kwargs); if (PyErr_Occurred()) { throw py::error_already_set(); } return tensor; END_HANDLE_ERRORS }; PyObject* PyTensorType_FromString(const std::string& tensortype) { auto it = std::find_if( tensor_types.begin(), tensor_types.end(), [tensortype](PyTensorType* type) { return std::string(type->name) == tensortype; }); if (it == tensor_types.end()) { PyErr_Format(PyExc_ValueError, "invalid type: %s", tensortype.data()); throw py::error_already_set(); } return (PyObject*)(*it); } static const char* get_doc(PyTensorType* tensortype) { // all tensortype docs static std::vector tensortype_doc; std::string dtype = tensortype->dtype->name(); std::string doc = ""; if (!TRY(DeviceTag4DeviceType(tensortype->devicetype)).IsOk()) doc = "The tensortype " + std::string(tensortype->name) + " is not available."; else { std::string device = ASSERT(DeviceTag4DeviceType(tensortype->devicetype)); doc = "Creates a Tensor with the dtype of " + dtype + " and the device on " + device + ", it has the same parameters as :func:`oneflow.Tensor`"; } tensortype_doc.emplace_back(doc); return tensortype_doc.back().data(); } static void init_tensortype_metaclass(PyTypeObject* metaclass) { metaclass->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; metaclass->tp_base = &PyType_Type; metaclass->tp_call = PyTensorTypeMetaCls_call; if (PyType_Ready(metaclass) < 0) { return; } } static void init_tensortype(PyTypeObject* type, PyTypeObject& type_template, const char* name, const char* doc) { memcpy(type, &type_template, sizeof(PyTypeObject)); type->tp_name = name; type->tp_doc = doc; type->tp_flags = Py_TPFLAGS_DEFAULT; if (PyType_Ready(type) < 0) { THROW(RuntimeError) << "tensortype initialization failed"; } } static void generalize_tensor_types() { init_tensortype_metaclass(&PyTensorTypeMetaClass); for (const auto& devicetype : all_device_types) { for (const auto& dtype : all_data_types) { PyTensorType* tensortype = new PyTensorType(); // set name std::string name = devicetype.second + "." 
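// generalize_tensor_types() builds one PyTensorType instance for every
// (device, dtype) pair listed in all_device_types x all_data_types, naming it
// "oneflow.<X>Tensor" or "oneflow.cuda.<X>Tensor"; the pybind module block at the
// bottom of this file then registers each instance under the matching submodule.
// Illustrative attribute access, not part of this file, assuming an installed
// oneflow build:
//   >>> import oneflow as flow
//   >>> flow.FloatTensor          # created from the (kCPU, Float) pair
//   >>> flow.cuda.DoubleTensor    # created from the (kCUDA, Double) pair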
+ dtype.second; size_t n = sizeof(tensortype->name); strncpy(tensortype->name, name.c_str(), n - 1); tensortype->name[n - 1] = '\0'; // set type tensortype->dtype = dtype.first; tensortype->devicetype = devicetype.first; tensortype->is_cuda = tensortype->devicetype == DeviceType::kCUDA; tensor_types.push_back(tensortype); const char* doc = get_doc(tensortype); init_tensortype(&tensortype->py_type, PyTensorTypeTemplate, tensortype->name, doc); } } } bool PyTensorType_Check(PyObject* obj) { return PyObject_TypeCheck(obj, &PyTensorTypeMetaClass); } PyObject* PyTensorType_FromDTypeAndDeviceType(Symbol dtype, DeviceType device) { auto it = std::find_if(tensor_types.begin(), tensor_types.end(), [dtype, device](PyTensorType* x) { return (x->dtype == dtype) && (x->devicetype == device); }); if (it == tensor_types.end()) { if (!TRY(DeviceTag4DeviceType(device)).IsOk()) return PyErr_Format(PyExc_ValueError, "unsupported device"); return PyErr_Format(PyExc_ValueError, "unsupported data type (%s) or device (%s)", dtype->name().c_str(), ASSERT(DeviceTag4DeviceType(device)).c_str()); } return (PyObject*)(*it); }; } // namespace one } // namespace oneflow #undef ASSERT using namespace oneflow::one; ONEFLOW_API_PYBIND11_MODULE("_C", m) { static std::string oneflow_prefix = "oneflow."; generalize_tensor_types(); for (PyTensorType* tensortype : tensor_types) { Py_INCREF(tensortype); std::string name = std::string(tensortype->name); size_t idx = name.rfind('.'); std::string type_name = name.substr(idx + 1); name = name.substr(0, idx); std::string module_name = name.size() > oneflow_prefix.size() ? name.substr(oneflow_prefix.size()) : ""; auto module = m; if (!module_name.empty()) { module = m.def_submodule(module_name.data()); } if (tensortype && PyModule_AddObject(module.ptr(), type_name.c_str(), (PyObject*)tensortype) < 0) { return; } } } ================================================ FILE: oneflow/api/python/framework/tensortype.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_ #define ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_ #include #undef _PyGC_FINALIZED #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/device.h" namespace oneflow { namespace one { typedef struct { PyTypeObject py_type; char name[64]; bool is_cuda; Symbol dtype; DeviceType devicetype; } PyTensorType; bool PyTensorType_Check(PyObject*); inline DeviceType PyTensorType_UnpackDevice(PyObject* self) { return ((PyTensorType*)self)->devicetype; } inline Symbol PyTensorType_UnpackDType(PyObject* self) { return ((PyTensorType*)self)->dtype; } PyObject* PyTensorType_FromDTypeAndDeviceType(Symbol, DeviceType); PyObject* PyTensorType_FromString(const std::string&); } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_ ================================================ FILE: oneflow/api/python/framework/thread.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/framework/thread.h" #include "oneflow/core/common/env_var/vm.h" namespace py = pybind11; namespace oneflow { namespace { class UsingThreadUidSet final { public: UsingThreadUidSet() : using_thread_uids_({Stream::kDefaultStreamThreadUid}), thread_limits_(using_thread_uids_.size() + ThreadLocalEnvInteger()) {} ~UsingThreadUidSet() = default; Maybe Get() { std::unique_lock lock(mutex_); CHECK_LT_OR_RETURN(using_thread_uids_.size(), thread_limits_) << "can not create more worker threads. please check your code or increase environment " "variable ONEFLOW_VM_WORKER_THREAD_LIMIT(default value:" << ThreadLocalEnvInteger() << ")"; for (int i = 0; i < using_thread_uids_.size() + 1; ++i) { if (using_thread_uids_.count(i) == 0) { using_thread_uids_.insert(i); return i; } } UNIMPLEMENTED_THEN_RETURN(); } Maybe Put(int64_t thread_uid) { std::unique_lock lock(mutex_); CHECK_NE_OR_RETURN(thread_uid, Stream::kDefaultStreamThreadUid) << "default thread_uid should not be erased. value: " << thread_uid; CHECK_OR_RETURN(using_thread_uids_.erase(thread_uid) > 0) << "no thread_uid found. 
(current: " << thread_uid << ")."; return Maybe::Ok(); } private: std::set using_thread_uids_; size_t thread_limits_; std::mutex mutex_; }; UsingThreadUidSet* MutUsingThreadUidSet() { static UsingThreadUidSet thread_uid_set; return &thread_uid_set; } } // namespace /*static*/ Maybe AsyncThread::New() { return std::shared_ptr(new AsyncThread(JUST(MutUsingThreadUidSet()->Get()))); } AsyncThread::~AsyncThread() { MutUsingThreadUidSet()->Put(thread_uid_).GetOrThrow(); } } // namespace oneflow ONEFLOW_API_PYBIND11_MODULE("", m) { using namespace oneflow; py::class_>(m, "AsyncThread").def(py::init([]() { return AsyncThread::New().GetPtrOrThrow(); })); } ================================================ FILE: oneflow/api/python/framework/thread.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_ #define ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_ #include "oneflow/core/framework/stream.h" #include "oneflow/core/common/util.h" namespace oneflow { class AsyncThread final { public: OF_DISALLOW_COPY_AND_MOVE(AsyncThread); ~AsyncThread(); static Maybe New(); int64_t thread_uid() const { return thread_uid_; } private: AsyncThread(int64_t thread_uid) : thread_uid_(thread_uid) {} int64_t thread_uid_; }; } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_ ================================================ FILE: oneflow/api/python/framework/typeinfo.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/framework/typeinfo.h" namespace oneflow { namespace one { #define ASSERT(x) (x).GetOrThrow() #if PY_VERSION_HEX < 0x03070000 #define PYGETSET_NAME(name) const_cast(name) #else #define PYGETSET_NAME(name) (name) #endif using functional::PyObjectPtr; #define INFO_FLOAT_TYPE_SEQ FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ #define INFO_TYPE_SEQ INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ INFO_FLOAT_TYPE_SEQ template struct is_floating_point_with_half : public std::false_type {}; #define DEFINE_IS_FLOATING_POINT_WITH_HALF(cpp_type, of_datatype) \ template<> \ struct is_floating_point_with_half : public std::true_type {}; OF_PP_FOR_EACH_TUPLE(DEFINE_IS_FLOATING_POINT_WITH_HALF, INFO_FLOAT_TYPE_SEQ); #undef DEFINE_IS_FLOATING_POINT_WITH_HALF template typename std::enable_if::value, PyObject*>::type PyGetVal(T value) { return PyFloat_FromDouble(value); } template typename std::enable_if::value, PyObject*>::type PyGetVal(T value) { return PyLong_FromLong(value); } PyObject* PyGetMaxVal(DataType datatype) { #define GET_MAX_VAL(cpp_type, of_datatype) \ case of_datatype: return PyGetVal(std::numeric_limits>::max()); switch (datatype) { OF_PP_FOR_EACH_TUPLE(GET_MAX_VAL, INFO_TYPE_SEQ); default: return NULL; #undef GET_MAX_VAL } } PyObject* PyGetMinVal(DataType datatype) { #define GET_MIN_VAL(cpp_type, of_datatype) \ case of_datatype: return PyGetVal(std::numeric_limits>::lowest()); switch (datatype) { OF_PP_FOR_EACH_TUPLE(GET_MIN_VAL, INFO_TYPE_SEQ); default: return NULL; #undef GET_MIN_VAL } } #define GET_FLOAT_RESOLUTION(cpp_type, of_datatype) \ case of_datatype: \ return PyFloat_FromDouble( \ std::pow(10, -std::numeric_limits>::digits10)); #define GET_FLOAT_EPS(cpp_type, of_datatype) \ case of_datatype: \ return PyFloat_FromDouble(std::numeric_limits>::epsilon()); #define GET_FLOAT_TINY(cpp_type, of_datatype) \ case of_datatype: \ return PyFloat_FromDouble(std::numeric_limits>::min()); PyTypeObject PyIInfoType = { PyVarObject_HEAD_INIT(NULL, 0) "oneflow.iinfo", // tp_name sizeof(PyDTypeInfo), // tp_basicsize }; PyTypeObject PyFInfoType = { PyVarObject_HEAD_INIT(NULL, 0) "oneflow.finfo", // tp_name sizeof(PyDTypeInfo), // tp_basicsize }; static PyObject* PyIInfo_new(PyTypeObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* dtype_obj = NULL; static const char* keywords[2] = {"type", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:iinfo", const_cast(keywords), &dtype_obj)) { return NULL; } CHECK_OR_THROW(functional::PyDTypeCheck(dtype_obj)) << Error::TypeError() << "iinfo(): argument 'type' must be oneflow.dtype, but found " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype_obj))); auto* self = (PyDTypeInfo*)PyIInfoType.tp_alloc(&PyIInfoType, 0); if (!self) { throw py::error_already_set(); } self->dtype = functional::PyUnpackDType(dtype_obj); CHECK_OR_THROW(!self->dtype->is_floating_point() && !self->dtype->is_complex()) << Error::TypeError() << "oneflow.iinfo() requires an integer input type. 
Use oneflow.finfo to handle '" << self->dtype->name() << "' "; return (PyObject*)self; END_HANDLE_ERRORS } static PyObject* PyFInfo_new(PyTypeObject* self, PyObject* args, PyObject* kwargs) { HANDLE_ERRORS PyObject* dtype_obj = functional::CastToPyObject(DType::Float()); static const char* keywords[2] = {"type", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O:finfo", const_cast(keywords), &dtype_obj)) { return NULL; } CHECK_OR_THROW(functional::PyDTypeCheck(dtype_obj)) << Error::TypeError() << "finfo(): argument 'type' must be oneflow.dtype, but found " << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype_obj))); auto* self = (PyDTypeInfo*)PyFInfoType.tp_alloc(&PyFInfoType, 0); if (!self) { throw py::error_already_set(); } self->dtype = functional::PyUnpackDType(dtype_obj); CHECK_OR_THROW(self->dtype->is_floating_point() && !self->dtype->is_complex()) << Error::TypeError() << "oneflow.finfo() requires a float input type. Use oneflow.iinfo to handle '" << self->dtype->name() << "' "; return (PyObject*)self; END_HANDLE_ERRORS } static PyObject* PyDInfo_bits(PyObject* self, void*) { HANDLE_ERRORS size_t bits = ASSERT(((PyDTypeInfo*)self)->dtype->bytes()) * 8; return PyLong_FromSize_t(bits); END_HANDLE_ERRORS } static PyObject* PyDInfo_min(PyObject* self, void*) { HANDLE_ERRORS DataType datatype = PyDTypeInfo_UnpackDataType(self); PyObject* result = PyGetMinVal(datatype); if (!result) { THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << " not supported by " << self->ob_type->tp_name; } return result; END_HANDLE_ERRORS } static PyObject* PyDInfo_max(PyObject* self, void*) { HANDLE_ERRORS DataType datatype = PyDTypeInfo_UnpackDataType(self); PyObject* result = PyGetMaxVal(datatype); if (!result) { THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << " not supported by " << self->ob_type->tp_name; } return result; END_HANDLE_ERRORS } static PyObject* PyFInfo_resolution(PyObject* self, void*) { HANDLE_ERRORS DataType datatype = PyDTypeInfo_UnpackDataType(self); switch (datatype) { OF_PP_FOR_EACH_TUPLE(GET_FLOAT_RESOLUTION, INFO_FLOAT_TYPE_SEQ); default: THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << " not supported by oneflow.finfo"; return NULL; } END_HANDLE_ERRORS } static PyObject* PyFInfo_eps(PyObject* self, void*) { HANDLE_ERRORS DataType datatype = PyDTypeInfo_UnpackDataType(self); switch (datatype) { OF_PP_FOR_EACH_TUPLE(GET_FLOAT_EPS, INFO_FLOAT_TYPE_SEQ); default: THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << " not supported by oneflow.finfo"; return NULL; } END_HANDLE_ERRORS } static PyObject* PyFInfo_tiny(PyObject* self, void*) { HANDLE_ERRORS DataType datatype = PyDTypeInfo_UnpackDataType(self); switch (datatype) { OF_PP_FOR_EACH_TUPLE(GET_FLOAT_TINY, INFO_FLOAT_TYPE_SEQ); default: THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << " not supported by oneflow.finfo"; return NULL; } END_HANDLE_ERRORS } static PyObject* PyDInfo_dtype(PyObject* self, void*) { HANDLE_ERRORS std::string name = ((PyDTypeInfo*)self)->dtype->name(); name = name.erase(0, name.find('.') + 1); return PyUnicode_FromString(name.data()); END_HANDLE_ERRORS } static PyObject* PyIInfo_str(PyObject* self) { HANDLE_ERRORS std::ostringstream oss; oss << "iinfo(min=" << PyLong_AS_LONG(PyDInfo_min((PyObject*)self, NULL)) << ", "; oss << "max=" << PyLong_AS_LONG(PyDInfo_max((PyObject*)self, NULL)) << ", "; oss << "dtype=" << PyDTypeInfo_UnpackDType(self)->name() << ", "; oss << "bits=" << 
PyLong_AS_LONG(PyDInfo_bits((PyObject*)self, NULL)) << ")"; return PyUnicode_FromString(oss.str().data()); END_HANDLE_ERRORS } static PyObject* PyFInfo_str(PyObject* self) { HANDLE_ERRORS std::ostringstream oss; oss << "finfo(resolution=" << PyFloat_AS_DOUBLE(PyFInfo_resolution((PyObject*)self, NULL)) << ", "; oss << "min=" << PyFloat_AS_DOUBLE(PyDInfo_min((PyObject*)self, NULL)) << ", "; oss << "max=" << PyFloat_AS_DOUBLE(PyDInfo_max((PyObject*)self, NULL)) << ", "; oss << "eps=" << PyFloat_AS_DOUBLE(PyFInfo_eps((PyObject*)self, NULL)) << ", "; oss << "tiny=" << PyFloat_AS_DOUBLE(PyFInfo_tiny((PyObject*)self, NULL)) << ", "; oss << "dtype=" << PyDTypeInfo_UnpackDType(self)->name() << ", "; oss << "bits=" << PyLong_AS_LONG(PyDInfo_bits((PyObject*)self, NULL)) << ")"; return PyUnicode_FromString(oss.str().data()); END_HANDLE_ERRORS } static struct PyGetSetDef PyIInfo_properties[] = { {PYGETSET_NAME("bits"), (getter)PyDInfo_bits, nullptr, nullptr, nullptr}, {PYGETSET_NAME("max"), (getter)PyDInfo_max, nullptr, nullptr, nullptr}, {PYGETSET_NAME("min"), (getter)PyDInfo_min, nullptr, nullptr, nullptr}, {PYGETSET_NAME("dtype"), (getter)PyDInfo_dtype, nullptr, nullptr, nullptr}, {nullptr}, }; static struct PyGetSetDef PyFInfo_properties[] = { {PYGETSET_NAME("bits"), (getter)PyDInfo_bits, nullptr, nullptr, nullptr}, {PYGETSET_NAME("max"), (getter)PyDInfo_max, nullptr, nullptr, nullptr}, {PYGETSET_NAME("min"), (getter)PyDInfo_min, nullptr, nullptr, nullptr}, {PYGETSET_NAME("resolution"), (getter)PyFInfo_resolution, nullptr, nullptr, nullptr}, {PYGETSET_NAME("eps"), (getter)PyFInfo_eps, nullptr, nullptr, nullptr}, {PYGETSET_NAME("tiny"), (getter)PyFInfo_tiny, nullptr, nullptr, nullptr}, {PYGETSET_NAME("dtype"), (getter)PyDInfo_dtype, nullptr, nullptr, nullptr}, {nullptr}, }; static void init_info_type() { PyIInfoType.tp_flags = Py_TPFLAGS_DEFAULT; PyIInfoType.tp_str = (reprfunc)PyIInfo_str; PyIInfoType.tp_repr = (reprfunc)PyIInfo_str; PyIInfoType.tp_new = (newfunc)PyIInfo_new; PyIInfoType.tp_getset = PyIInfo_properties; if (PyType_Ready(&PyIInfoType) < 0) { return; } PyFInfoType.tp_flags = Py_TPFLAGS_DEFAULT; PyFInfoType.tp_str = (reprfunc)PyFInfo_str; PyFInfoType.tp_repr = (reprfunc)PyFInfo_str; PyFInfoType.tp_new = (newfunc)PyFInfo_new; PyFInfoType.tp_getset = PyFInfo_properties; if (PyType_Ready(&PyFInfoType) < 0) { return; } } ONEFLOW_API_PYBIND11_MODULE("_C", m) { init_info_type(); if (PyModule_AddObject(m.ptr(), "iinfo", (PyObject*)&PyIInfoType) < 0) return; if (PyModule_AddObject(m.ptr(), "finfo", (PyObject*)&PyFInfoType) < 0) return; } } // namespace one } // namespace oneflow #undef ASSERT #undef GET_FLOAT_RESOLUTION #undef GET_FLOAT_EPS #undef GET_FLOAT_TINY #undef INFO_FLOAT_TYPE_SEQ #undef INFO_TYPE_SEQ #undef PYGETSET_NAME ================================================ FILE: oneflow/api/python/framework/typeinfo.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_ #define ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_ #include #undef _PyGC_FINALIZED #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/dtype.h" namespace oneflow { namespace one { typedef struct { PyObject_HEAD; Symbol dtype; } PyDTypeInfo; extern PyTypeObject PyIInfoType; extern PyTypeObject PyFInfoType; inline bool PyIInfo_Check(PyObject* obj) { return PyObject_TypeCheck(obj, &PyIInfoType); } inline bool PyFInfo_Check(PyObject* obj) { return PyObject_TypeCheck(obj, &PyFInfoType); } inline bool PyDTypeInfo_Check(PyObject* obj) { return PyIInfo_Check(obj) || PyFInfo_Check(obj); } inline Symbol PyDTypeInfo_UnpackDType(PyObject* obj) { assert(PyDTypeInfo_Check(obj)); return ((PyDTypeInfo*)obj)->dtype; } inline DataType PyDTypeInfo_UnpackDataType(PyObject* obj) { assert(PyDTypeInfo_Check(obj)); return ((PyDTypeInfo*)obj)->dtype->data_type(); } } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_ ================================================ FILE: oneflow/api/python/framework/variable_tensor_mgr.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include "oneflow/api/common/variable_tensor_mgr.h" #include "oneflow/api/python/of_api_registry.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("FillVariableTensorMgr", &FillVariableTensorMgr); m.def("DumpVariableTensorMgr", &DumpVariableTensorMgr); m.def("ResetVariableTensorMgr", &ResetVariableTensorMgr); } } // namespace oneflow ================================================ FILE: oneflow/api/python/functional/common.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/api/python/functional/common.h" #include #include #include #include "oneflow/api/python/framework/memory_format.h" #include "oneflow/api/python/functional/indexing.h" #include "oneflow/extension/python/numpy.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/functional/tensor_index.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/framework/tensor_util.h" namespace oneflow { namespace one { namespace functional { namespace detail { namespace { template Maybe GetItemInPyScalarTensor(PyObject* obj) { return GetItemInScalarTensor(PyTensor_Unpack(obj)); } } // namespace template::value, int>::type = 0> bool isinstance_fast(PyObject* obj) { static auto type = py::detail::get_type_handle(typeid(T), false); if (!type) { return false; } const auto result = PyObject_IsInstance(obj, type.ptr()); if (result == -1) { throw py::error_already_set(); } return result != 0; } template::value && !py::detail::is_shared_ptr::value, int>::type = 0> const T& cast_fast(PyObject* obj) { auto vh = reinterpret_cast(obj)->get_value_and_holder(); auto*& vptr = vh.value_ptr(); if (!vptr) { throw py::cast_error("Unable to cast from object to T& since lazy allocation is not allowed " "for fast cast, please use pybind11::cast instead"); } return *reinterpret_cast(&vptr); } template::value && py::detail::is_shared_ptr::value, int>::type = 0> const T& cast_fast(PyObject* obj) { auto vh = reinterpret_cast(obj)->get_value_and_holder(); if (!vh.holder_constructed()) { throw py::cast_error("Unable to cast from non-held to held instance (T& to Holder)"); } return vh.template holder(); } } // namespace detail bool PySequenceCheck(PyObject* obj, const std::function& item_check) { bool is_tuple = PyTuple_Check(obj); if (!is_tuple && !PyList_Check(obj)) { return false; } size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); if (size == 0) { return true; } PyObject* item = is_tuple ? 
PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0); return item_check(item); } bool PyLongSequenceCheck(PyObject* obj) { return PySequenceCheck( obj, [](PyObject* item) { return PyLong_Check(item) || PyIntegerScalarTensorCheck(item); }); } bool PyFloatSequenceCheck(PyObject* obj) { return PySequenceCheck(obj, [](PyObject* item) { return PyFloat_Check(item) || PyLong_Check(item) || PyFloatScalarTensorCheck(item) || PyIntegerScalarTensorCheck(item); }); } bool PyStringCheck(PyObject* obj) { return PyBytes_Check(obj) || PyUnicode_Check(obj); } bool PyStringSequenceCheck(PyObject* obj) { return PySequenceCheck(obj, [](PyObject* item) { return PyStringCheck(item); }); } std::string PyStringAsString(PyObject* obj) { PyObject* bytes = PyUnicode_AsEncodedString(obj, "utf-8", "~E~"); std::string str = PyBytes_AS_STRING(bytes); Py_XDECREF(bytes); return str; } std::string PyObjectToReprStr(PyObject* obj) { PyObject* repr_obj = PyObject_Repr(obj); std::string str = PyStringAsString(repr_obj); Py_XDECREF(repr_obj); return str; } // Tensor list bool PyTensorSequenceCheck(PyObject* obj) { return PySequenceCheck(obj, [](PyObject* item) { return PyTensor_Check(item); }); } std::vector> PyUnpackTensorSequence(PyObject* obj) { return PyUnpackSequence>( obj, [](PyObject* item) { return PyTensor_Unpack(item); }); } // TensorTuple bool PyTensorTupleCheck(PyObject* obj) { return detail::isinstance_fast(obj); } std::shared_ptr PyUnpackTensorTuple(PyObject* obj) { return detail::cast_fast>(obj); } // Scalar bool PyScalarCheck(PyObject* obj) { return PyLong_Check(obj) || PyFloat_Check(obj) || PyComplex_Check(obj); } Scalar PyUnpackScalar(PyObject* obj) { if (PyBool_Check(obj)) { return obj == Py_True; } else if (PyLong_Check(obj)) { return static_cast(PyLong_AsLongLong(obj)); } else if (PyFloat_Check(obj)) { return PyFloat_AsDouble(obj); } else if (PyComplex_Check(obj)) { Py_complex value = PyComplex_AsCComplex(obj); return std::complex{value.real, value.imag}; } else if (PyArray_IsScalar(obj, Bool)) { return obj == Py_True; } else if (PyArray_IsScalar(obj, Floating)) { return PyFloat_AsDouble(obj); } else if (PyArray_IsScalar(obj, Complex64) || PyArray_IsScalar(obj, Complex128)) { Py_complex value = PyComplex_AsCComplex(obj); return std::complex{value.real, value.imag}; } THROW(RuntimeError) << "The object is not scalar, but is " << Py_TYPE(obj)->tp_name; return 0; } // Scalar Tensor bool PyScalarTensorCheck(PyObject* obj) { if (!LazyMode::is_enabled() && PyTensor_Check(obj)) { const auto& tensor = PyTensor_Unpack(obj); return tensor->shape()->size() == 0 && IsTriviallyCopyableDataType(tensor->dtype()->data_type()); } return false; } Scalar PyUnpackScalarTensor(PyObject* obj) { if (PyBoolScalarTensorCheck(obj)) { return PyUnpackBoolScalarTensor(obj); } else if (PyIntegerScalarTensorCheck(obj)) { return PyUnpackIntegerScalarTensor_AsLongLong(obj); } else if (PyFloatScalarTensorCheck(obj)) { return PyUnpackFloatScalarTensor_AsDouble(obj); } else if (PyComplexScalarTensorCheck(obj)) { return PyUnpackComplexScalarTensor_AsCComplex(obj); } THROW(RuntimeError) << "The object is not scalar tensor, but is " << Py_TYPE(obj)->tp_name << "with data type: " << DataType_Name(PyTensor_Unpack(obj)->dtype()->data_type()); return 0; } #define SWITCH_SCALAR_TENSOR_TO_SCALAR(cpp_type, of_type) \ case of_type: \ return detail::GetItemInPyScalarTensor(obj).GetOrThrow(); #define SCALAR_TENSOR_UNPACK_FUNC_IMPL(func_name, return_type, type_seq) \ return_type func_name(PyObject* obj) { \ const auto& tensor = PyTensor_Unpack(obj); \ 
DataType data_type = tensor->dtype()->data_type(); \ switch (data_type) { \ OF_PP_FOR_EACH_TUPLE(SWITCH_SCALAR_TENSOR_TO_SCALAR, type_seq) \ default: { \ throw py::cast_error("Cannot get ##cpp##type from scalar tensor with data type: " \ + DataType_Name(data_type)); \ } \ } \ } SCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackBoolScalarTensor, bool, BOOL_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ); SCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackIntegerScalarTensor_AsLongLong, long long, INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ); SCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackFloatScalarTensor_AsDouble, double, FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ); SCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackComplexScalarTensor_AsCComplex, std::complex, COMPLEX_DATA_TYPE_SEQ FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ); #undef SWITCH_SCALAR_TENSOR_TO_SCALAR #undef SCALAR_TENSOR_UNPACK_FUNC_IMPL // DType bool PyDTypeCheck(PyObject* obj) { return detail::isinstance_fast>(obj); } Symbol PyUnpackDType(PyObject* obj) { return *detail::cast_fast*>(obj); } // Layout bool PyLayoutCheck(PyObject* obj) { return detail::isinstance_fast>(obj); } Symbol PyUnpackLayout(PyObject* obj) { return *detail::cast_fast*>(obj); } // Memory Format bool PyMemoryFormatCheck(PyObject* obj) { return PyMemoryFormat_Check(obj); } MemoryFormat PyUnpackMemoryFormat(PyObject* obj) { return PyMemoryFormat_Unpack(obj); } // DType list bool PyDTypeSequenceCheck(PyObject* obj) { return PySequenceCheck(obj, [](PyObject* item) { return PyDTypeCheck(item); }); } std::vector> PyUnpackDTypeSequence(PyObject* obj) { return PyUnpackSequence>(obj, [](PyObject* item) { return PyUnpackDType(item); }); } // Shape bool PyShapeCheck(PyObject* obj) { return PyLongSequenceCheck(obj); } Shape PyUnpackShape(PyObject* obj) { bool is_tuple = PyTuple_Check(obj); CHECK_OR_THROW(is_tuple || PyList_Check(obj)) << "The object is not list or tuple, but is " << Py_TYPE(obj)->tp_name; size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); DimVector values(size); for (int i = 0; i < size; ++i) { PyObject* item = is_tuple ? 
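// PyUnpackShape converts a Python list or tuple of integers into a Shape by reading
// each item with PyLong_AsLongLong; PyShapeSequenceCheck below then recognizes a
// list of such shapes (for example the out_shapes argument of some dispatch
// functors).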
PyTuple_GET_ITEM(obj, i) : PyList_GET_ITEM(obj, i); values[i] = PyLong_AsLongLong(item); } return Shape(values); } // Shape list bool PyShapeSequenceCheck(PyObject* obj) { return PySequenceCheck(obj, [](PyObject* item) { return PyLongSequenceCheck(item); }); } std::vector PyUnpackShapeSequence(PyObject* obj) { return PyUnpackSequence(obj, [](PyObject* item) -> Shape { return PyUnpackShape(item); }); } // Generator bool PyGeneratorCheck(PyObject* obj) { return detail::isinstance_fast(obj); } std::shared_ptr PyUnpackGenerator(PyObject* obj) { return detail::cast_fast>(obj); } // Device bool PyDeviceCheck(PyObject* obj) { return detail::isinstance_fast>(obj); } Symbol PyUnpackDevice(PyObject* obj) { return *detail::cast_fast>>(obj); } // Placement bool PyParallelDescCheck(PyObject* obj) { return detail::isinstance_fast>(obj); } Symbol PyUnpackParallelDesc(PyObject* obj) { return *detail::cast_fast>>(obj); } // SBP bool PySbpParallelCheck(PyObject* obj) { return detail::isinstance_fast>(obj); } Symbol PyUnpackSbpParallel(PyObject* obj) { return *detail::cast_fast>>(obj); } // SBP list bool PySbpParallelSequenceCheck(PyObject* obj) { return PySequenceCheck(obj, [](PyObject* item) { return PySbpParallelCheck(item); }); } std::vector> PyUnpackSbpParallelSequence(PyObject* obj) { return PyUnpackSequence>( obj, [](PyObject* item) { return PyUnpackSbpParallel(item); }); } // Tensor index bool PyTensorIndexCheck(PyObject* obj) { return PySlice_Check(obj) || PyLong_Check(obj) || obj == Py_Ellipsis || obj == Py_None || PyTensor_Check(obj) || PySequence_Check(obj) || PyUnicode_Check(obj) || numpy::PyArrayCheckLongScalar(obj); } TensorIndex PyUnpackTensorIndex(PyObject* obj) { TensorIndex tensor_index; // Obvious single-entry cases. if (PySlice_Check(obj) // NOLINT || PyLong_Check(obj) // NOLINT || obj == Py_Ellipsis // NOLINT || obj == Py_None // NOLINT || PyTensor_Check(obj) // NOLINT || !PySequence_Check(obj) // NOLINT || numpy::PyArrayCheckLongScalar(obj) // NOLINT || PyUnicode_Check(obj)) { tensor_index.emplace_back(detail::UnpackIndexItem(obj)); return tensor_index; } PyObject* tup = NULL; Py_ssize_t n = 0; if (PyTuple_Check(obj)) { tup = PySequence_Tuple(obj); n = PySequence_Size(tup); } else { // The follow comments are from numpy: // https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/mapping.c#L266 /* * At this point, we're left with a non-tuple, non-array, sequence: * typically, a list. We use some somewhat-arbitrary heuristics from here * onwards to decided whether to treat that list as a single index, or a * list of indices. */ n = PySequence_Size(obj); // Negative size indicates a Python error in the PySequence_Size call. if (n < 0) { PyErr_Clear(); tensor_index.emplace_back(detail::UnpackIndexItem(obj)); return tensor_index; } // The follow comments are from numpy: // https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/mapping.c#L280 /* * Backwards compatibility only takes effect for short sequences - otherwise * we treat it like any other scalar. * * Sequences < NPY_MAXDIMS with any slice objects * or newaxis, Ellipsis or other arrays or sequences * embedded, are considered equivalent to an indexing * tuple. (`a[[[1,2], [3,4]]] == a[[1,2], [3,4]]`) */ if (n >= /*NPY_MAXDIMS=*/32) { tensor_index.emplace_back(detail::UnpackIndexItem(obj)); return tensor_index; } // Check whether we should unpack the index like a tuple. 
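// A short sequence is committed to tuple-style unpacking as soon as any item is
// itself a sequence, a slice, a tensor, Ellipsis or None; otherwise the whole list
// is kept as a single (advanced) index, following the numpy heuristic referenced
// above. Illustrative behaviour, not part of this file, assuming an installed
// oneflow build:
//   >>> import oneflow as flow
//   >>> a = flow.arange(12).reshape(3, 4)
//   >>> a[[0, 2]].shape      # a plain integer list stays one advanced index -> (2, 4)
//   >>> a[0, 2]              # a tuple indexes one dimension per item -> scalar tensor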
bool commit_to_unpack = false; for (Py_ssize_t i = 0; i < n; ++i) { PyObject* item = PySequence_GetItem(obj, i); if (commit_to_unpack) { CHECK_OR_THROW(item) << "Sequence index is required."; } else { if (!item) { PyErr_Clear(); break; } if (PySequence_Check(item) // NOLINT || PySlice_Check(item) // NOLINT || PyTensor_Check(item) // NOLINT || item == Py_Ellipsis || item == Py_None) { commit_to_unpack = true; } } Py_DECREF(item); } if (commit_to_unpack) { tup = PySequence_Tuple(obj); } else { tensor_index.emplace_back(detail::UnpackIndexItem(obj)); return tensor_index; } } tensor_index.resize(n); for (Py_ssize_t i = 0; i < n; ++i) { PyObject* item = PySequence_GetItem(tup, i); tensor_index[i] = detail::UnpackIndexItem(item); Py_DECREF(item); } Py_DECREF(tup); return tensor_index; } // OpExpr bool PyOpExprCheck(PyObject* obj) { return detail::isinstance_fast(obj); } std::shared_ptr PyUnpackOpExpr(PyObject* obj) { return detail::cast_fast>(obj); } // int64_t Maybe PyUnpackLong(PyObject* py_obj) { int overflow = -1; long long val = PyLong_AsLongLongAndOverflow(py_obj, &overflow); if (val == -1 && PyErr_Occurred()) { return Error::RuntimeError() << "Python exception occurs"; } if (overflow != 0) { return Error::RuntimeError() << "Overflow when unpacking long"; } return (int64_t)val; } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/functional/common.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_ #define ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_ #include #include #include #include #include "oneflow/api/python/framework/tensor.h" #include "oneflow/api/python/caster/maybe.h" #include "oneflow/api/python/caster/optional.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/layout.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/tensor_index.h" #include "oneflow/core/common/foreign_lock_helper.h" namespace py = pybind11; namespace oneflow { namespace one { namespace functional { struct PyObjectPtrDeleter { inline void operator()(PyObject* obj) { CHECK_JUST(Singleton::Get()->WithScopedAcquire([&]() -> Maybe { if (obj) { Py_DECREF(obj); } obj = NULL; return Maybe::Ok(); })); } }; using PyObjectPtr = std::unique_ptr; #define INTEGER_AND_BOOL_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t) \ OF_PP_MAKE_TUPLE_SEQ(int64_t) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t) \ OF_PP_MAKE_TUPLE_SEQ(bool) #define FLOATING_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(float) \ OF_PP_MAKE_TUPLE_SEQ(double) bool PySequenceCheck(PyObject* obj); bool PySequenceCheck(PyObject* obj, const std::function& item_check); template inline std::vector PyUnpackSequence(PyObject* obj, UnpackItemFunc unpack_item) { bool is_tuple = PyTuple_Check(obj); CHECK_OR_THROW(is_tuple || PyList_Check(obj)) << "The object is not list or tuple, but is " << Py_TYPE(obj)->tp_name; size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); std::vector values(size); for (int i = 0; i < size; ++i) { PyObject* item = is_tuple ? 
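// PyUnpackSequence is the generic helper behind the typed unpackers declared below
// (PyUnpackLongSequence, PyUnpackFloatSequence, PyUnpackTensorSequence, ...): it
// walks a list or tuple once and converts every item with the supplied callback,
// and the integer/float variants also accept 0-dim scalar tensors as items.
// Illustrative call that ends up in this code path, not part of this file, assuming
// an installed oneflow build:
//   >>> import oneflow as flow
//   >>> flow.ones(2, 3, 4).permute([2, 0, 1]).shape   # the list is unpacked item by item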
PyTuple_GET_ITEM(obj, i) : PyList_GET_ITEM(obj, i); values[i] = unpack_item(item); } return values; } // Scalar Tensor bool PyScalarTensorCheck(PyObject* obj); Scalar PyUnpackScalarTensor(PyObject* obj); #define DefinePyTypeScalarTensorCheck(type, type_check_func) \ inline bool Py##type##ScalarTensorCheck(PyObject* obj) { \ return PyScalarTensorCheck(obj) \ && type_check_func(PyTensor_Unpack(obj)->dtype()->data_type()); \ } DefinePyTypeScalarTensorCheck(Bool, IsBoolDataType); // PyBoolScalarTensorCheck DefinePyTypeScalarTensorCheck(Integer, IsIntegralDataType); // PyIntegerScalarTensorCheck DefinePyTypeScalarTensorCheck(Float, IsFloatingDataType); // PyFloatScalarTensorCheck DefinePyTypeScalarTensorCheck(Complex, IsComplexDataType); // PyComplexScalarTensorCheck #undef DefinePyTypeScalarTensorCheck bool PyUnpackBoolScalarTensor(PyObject* obj); long long PyUnpackIntegerScalarTensor_AsLongLong(PyObject* obj); double PyUnpackFloatScalarTensor_AsDouble(PyObject* obj); std::complex PyUnpackComplexScalarTensor_AsCComplex(PyObject* obj); // Integer/Float list bool PyLongSequenceCheck(PyObject* obj); bool PyFloatSequenceCheck(PyObject* obj); template inline std::vector PyUnpackLongSequence(PyObject* obj) { return PyUnpackSequence(obj, [](PyObject* item) -> T { if (PyIntegerScalarTensorCheck(item)) { return static_cast(PyUnpackIntegerScalarTensor_AsLongLong(item)); } return static_cast(PyLong_AsLongLong(item)); }); } template inline std::vector PyUnpackFloatSequence(PyObject* obj) { return PyUnpackSequence(obj, [](PyObject* item) -> T { if (PyFloatScalarTensorCheck(item)) { return static_cast(PyUnpackFloatScalarTensor_AsDouble(item)); } return static_cast(PyFloat_AsDouble(item)); }); } // String bool PyStringCheck(PyObject* obj); bool PyStringSequenceCheck(PyObject* obj); std::string PyStringAsString(PyObject* obj); std::string PyObjectToReprStr(PyObject* obj); // Scalar bool PyScalarCheck(PyObject* obj); Scalar PyUnpackScalar(PyObject* obj); // Tensor list bool PyTensorSequenceCheck(PyObject* obj); std::vector> PyUnpackTensorSequence(PyObject* obj); // TensorTuple bool PyTensorTupleCheck(PyObject* obj); std::shared_ptr PyUnpackTensorTuple(PyObject* obj); // DType bool PyDTypeCheck(PyObject* obj); Symbol PyUnpackDType(PyObject* obj); // Layout bool PyLayoutCheck(PyObject* obj); Symbol PyUnpackLayout(PyObject* obj); // Memory Format bool PyMemoryFormatCheck(PyObject* obj); MemoryFormat PyUnpackMemoryFormat(PyObject* obj); // DType list bool PyDTypeSequenceCheck(PyObject* obj); std::vector> PyUnpackDTypeSequence(PyObject* obj); // Shape bool PyShapeCheck(PyObject* obj); Shape PyUnpackShape(PyObject* obj); // Shape list bool PyShapeSequenceCheck(PyObject* obj); std::vector PyUnpackShapeSequence(PyObject* obj); // Generator bool PyGeneratorCheck(PyObject* obj); std::shared_ptr PyUnpackGenerator(PyObject* obj); // Device bool PyDeviceCheck(PyObject* obj); Symbol PyUnpackDevice(PyObject* obj); // Placement bool PyParallelDescCheck(PyObject* obj); Symbol PyUnpackParallelDesc(PyObject* obj); // SBP bool PySbpParallelCheck(PyObject* obj); Symbol PyUnpackSbpParallel(PyObject* obj); // SBP list bool PySbpParallelSequenceCheck(PyObject* obj); std::vector> PyUnpackSbpParallelSequence(PyObject* obj); // Tensor index bool PyTensorIndexCheck(PyObject* obj); TensorIndex PyUnpackTensorIndex(PyObject* obj); // OpExpr bool PyOpExprCheck(PyObject* obj); std::shared_ptr PyUnpackOpExpr(PyObject* obj); template inline PyObject* CastToPyObject(T&& t) { return py::cast(t).inc_ref().ptr(); } template<> inline PyObject* 
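// The specializations of CastToPyObject that follow handle Maybe-wrapped results:
// a Tensor becomes a PyTensor object, a TensorTuple becomes a Python tuple of
// tensors, and a void result returns None, while any error carried by the Maybe is
// rethrown as a Python exception via GetOrThrow()/GetPtrOrThrow().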
CastToPyObject>(Maybe&& t) { return PyTensor_New(t.GetPtrOrThrow()); } template<> inline PyObject* CastToPyObject>(Maybe&& t) { const auto& tensor_tuple = t.GetPtrOrThrow(); py::tuple tup(tensor_tuple->size()); for (int i = 0; i < tensor_tuple->size(); ++i) { tup[i] = py::cast(tensor_tuple->at(i)); } return py::cast(tup).inc_ref().ptr(); } template<> inline PyObject* CastToPyObject>(Maybe&& t) { t.GetOrThrow(); Py_RETURN_NONE; } // int64_t Maybe PyUnpackLong(PyObject* py_obj); } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_ ================================================ FILE: oneflow/api/python/functional/dispatch_stateful_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/op_interpreter/lazy_op_interpreter.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" namespace oneflow { namespace one { namespace functional { namespace impl { ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor( "DispatchFeedInput", [](const std::shared_ptr& op, const std::shared_ptr& input) -> Maybe { const auto& origin_input = JUST(OpInterpUtil::Dispatch(*op, {input})); // Unpack input when do grad acc return GradAccTryInsertUnpackAfterInput(origin_input); }); m.add_functor( "DispatchFetchOutput", [](const std::shared_ptr& op, const std::shared_ptr& input) -> Maybe { // Pack output when do grad acc const auto& pack_input = JUST(GradAccTryInsertPackBeforeOutput(input)); return OpInterpUtil::Dispatch(*op, {pack_input}); }); m.add_functor("DispatchFeedVariable", [](const std::shared_ptr& op, const std::shared_ptr& input, const Scalar& l2) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("l2"); attrs.SetAllAttrs(l2.As()); const auto& origin_var = JUST(OpInterpUtil::Dispatch(*op, {input}, attrs)); // Repeat variable when do grad acc return GradAccTryInsertRepeatAfterVar(origin_var); }); m.add_functor( "DispatchOfrecordReader", [](const std::shared_ptr& op, const std::string& data_dir, int32_t data_part_num, const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size, int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed, const Optional>& device) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "data_dir", "data_part_num", "part_name_prefix", "part_name_suffix_length", "batch_size", "shuffle_buffer_size", "random_shuffle", "shuffle_after_epoch", "seed"); attrs.SetAllAttrs(data_dir, data_part_num, part_name_prefix, part_name_suffix_length, batch_size, shuffle_buffer_size, 
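// Each functor in this library follows the same pattern: fill a
// THREAD_CACHED_MUTABLE_ATTR_MAP with the op attributes and dispatch a prebuilt
// stateful OpExpr through OpInterpUtil::Dispatch. Readers and decoders such as
// OfrecordReader and CoinFlip come in two overloads, one taking a device for local
// execution and one taking a placement plus an sbp list for global execution.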
random_shuffle, shuffle_after_epoch, seed); return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, JUST(device))); }); m.add_functor( "DispatchOfrecordReader", [](const std::shared_ptr& op, const std::string& data_dir, int32_t data_part_num, const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size, int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed, const Symbol& placement, const std::vector>& sbp_tuple) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "data_dir", "data_part_num", "part_name_prefix", "part_name_suffix_length", "batch_size", "shuffle_buffer_size", "random_shuffle", "shuffle_after_epoch", "seed", "nd_sbp"); attrs.SetAllAttrs(data_dir, data_part_num, part_name_prefix, part_name_suffix_length, batch_size, shuffle_buffer_size, random_shuffle, shuffle_after_epoch, seed, *JUST(GetNdSbpStrList(sbp_tuple))); auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, placement, nd_sbp)); }); m.add_functor("DispatchOfrecordRawDecoder", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::string& name, const Shape& shape, const Symbol& data_type, bool dim1_varying_length, bool truncate) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("name", "shape", "data_type", "dim1_varying_length", "truncate"); attrs.SetAllAttrs(name, shape, data_type->data_type(), dim1_varying_length, truncate); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor( "DispatchCoinFlip", [](const std::shared_ptr& op, int64_t batch_size, Scalar probability, int64_t seed, bool has_seed, const Optional>& device) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("probability", "batch_size", "seed", "has_seed"); attrs.SetAllAttrs(probability.As(), batch_size, seed, has_seed); return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, JUST(device))); }); m.add_functor("DispatchCoinFlip", [](const std::shared_ptr& op, int64_t batch_size, Scalar probability, int64_t seed, bool has_seed, const Symbol& placement, const std::vector>& sbp_tuple) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("probability", "batch_size", "seed", "has_seed", "nd_sbp"); attrs.SetAllAttrs(probability.As(), batch_size, seed, has_seed, *JUST(GetNdSbpStrList(sbp_tuple))); auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); return OpInterpUtil::Dispatch( *op, {}, OpExprInterpContext(attrs, placement, nd_sbp)); }); m.add_functor( "DispatchDistributedPariticalFCSample", [](const std::shared_ptr& op, const std::shared_ptr& weight, const std::shared_ptr& label, const int64_t& num_sample) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_sample"); attrs.SetAllAttrs(num_sample); return OpInterpUtil::Dispatch(*op, {weight, label}, attrs); }); m.add_functor( "DispatchCropMirrorNormalizeFromUint8", [](const std::shared_ptr& op, const TensorTuple& input, int64_t crop_h, int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector& mean, const std::vector& std, const Symbol& output_dtype, const std::string& output_layout, const std::string& color_space) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("color_space", "output_layout", "mean", "std", "crop_h", "crop_w", "crop_pos_x", "crop_pos_y", "output_dtype"); attrs.SetAllAttrs(color_space, output_layout, mean, std, crop_h, crop_w, crop_pos_x, crop_pos_y, output_dtype->data_type()); return OpInterpUtil::Dispatch(*op, input, attrs); }); m.add_functor( "DispatchCropMirrorNormalizeFromTensorBuffer", 
[](const std::shared_ptr& op, const TensorTuple& input, int64_t crop_h, int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector& mean, const std::vector& std, const Symbol& output_dtype, const std::string& output_layout, const std::string& color_space) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("color_space", "output_layout", "mean", "std", "crop_h", "crop_w", "crop_pos_x", "crop_pos_y", "output_dtype"); attrs.SetAllAttrs(color_space, output_layout, mean, std, crop_h, crop_w, crop_pos_x, crop_pos_y, output_dtype->data_type()); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor( "DispatchOfrecordImageDecoderRandomCrop", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::string& name, const std::string& color_space, const std::vector& random_area, const std::vector& random_aspect_ratio, int32_t num_attempts, int64_t seed, bool has_seed) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("name", "color_space", "num_attempts", "seed", "has_seed", "random_area", "random_aspect_ratio"); attrs.SetAllAttrs(name, color_space, num_attempts, seed, has_seed, random_area, random_aspect_ratio); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor("DispatchOfrecordImageDecoder", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::string& name, const std::string& color_space) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("name", "color_space"); attrs.SetAllAttrs(name, color_space); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor("DispatchImageDecoderRandomCropResize", [](const std::shared_ptr& op, const std::shared_ptr& input, int64_t target_width, int64_t target_height, int64_t seed, int64_t num_workers, int64_t max_num_pixels, float random_area_min, float random_area_max, float random_aspect_ratio_min, float random_aspect_ratio_max, int64_t warmup_size, int64_t num_attempts) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "target_width", "target_height", "seed", "num_workers", "max_num_pixels", "random_area_min", "random_area_max", "random_aspect_ratio_min", "random_aspect_ratio_max", "warmup_size", "num_attempts"); attrs.SetAllAttrs(target_width, target_height, seed, num_workers, max_num_pixels, random_area_min, random_area_max, random_aspect_ratio_min, random_aspect_ratio_max, warmup_size, num_attempts); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor( "DispatchTensorBufferToListOfTensorsV2", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::vector& out_shapes, const std::vector>& out_dtypes, bool dynamic_out) -> Maybe { auto out_data_types = std::vector(); for (auto it = out_dtypes.begin(); it != out_dtypes.end(); it++) { out_data_types.emplace_back((*it)->data_type()); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("out_shapes", "dynamic_out", "out_dtypes"); attrs.SetAllAttrs(out_shapes, dynamic_out, out_data_types); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor("DispatchImageResizeKeepAspectRatio", [](const std::shared_ptr& op, const std::shared_ptr& input, int32_t target_size, int32_t min_size, int32_t max_size, bool resize_longer, const std::string& interpolation_type) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "target_size", "min_size", "max_size", "resize_longer", "interpolation_type"); attrs.SetAllAttrs(target_size, min_size, max_size, resize_longer, interpolation_type); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor("DispatchImageResizeToFixed", 
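// NOTE: dtype-valued attributes are stored as the raw DataType enum, which is why functors that
// receive a DType symbol pass data_type->data_type() (or output_dtype->data_type() above) into
// SetAllAttrs rather than the symbol itself.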
[](const std::shared_ptr& op, const std::shared_ptr& input, int64_t target_width, int64_t target_height, int64_t channels, const Symbol& data_type, const std::string& interpolation_type) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("target_width", "target_height", "channels", "data_type", "interpolation_type"); attrs.SetAllAttrs(target_width, target_height, channels, data_type->data_type(), interpolation_type); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor( "DispatchImageDecode", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::string& color_space, const Symbol& data_type) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("color_space", "data_type"); attrs.SetAllAttrs(color_space, data_type->data_type()); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor("DispatchImageNormalize", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::vector& mean, const std::vector& std) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("std", "mean"); attrs.SetAllAttrs(std, mean); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor("DispatchCOCOReader", [](const std::shared_ptr& op, const std::string& image_dir, const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch, int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations, bool stride_partition, int64_t session_id, const Optional>& device) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "session_id", "annotation_file", "image_dir", "batch_size", "shuffle_after_epoch", "random_seed", "group_by_ratio", "remove_images_without_annotations", "stride_partition"); attrs.SetAllAttrs(session_id, annotation_file, image_dir, batch_size, shuffle_after_epoch, random_seed, group_by_ratio, remove_images_without_annotations, stride_partition); return OpInterpUtil::Dispatch( *op, {}, OpExprInterpContext(attrs, JUST(device))); }); m.add_functor("DispatchCOCOReader", [](const std::shared_ptr& op, const std::string& image_dir, const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch, int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations, bool stride_partition, int64_t session_id, const Symbol& placement, const std::vector>& sbp_tuple) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "session_id", "annotation_file", "image_dir", "batch_size", "shuffle_after_epoch", "random_seed", "group_by_ratio", "remove_images_without_annotations", "stride_partition", "nd_sbp"); attrs.SetAllAttrs(session_id, annotation_file, image_dir, batch_size, shuffle_after_epoch, random_seed, group_by_ratio, remove_images_without_annotations, stride_partition, *JUST(GetNdSbpStrList(sbp_tuple))); auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); return OpInterpUtil::Dispatch( *op, {}, OpExprInterpContext(attrs, placement, nd_sbp)); }); m.add_functor( "DispatchImageBatchAlign", [](const std::shared_ptr& op, const std::shared_ptr& input, int32_t alignment, const Shape& shape, const Symbol& data_type, bool dynamic_out) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "data_type", "alignment", "dynamic_out"); attrs.SetAllAttrs(shape, data_type->data_type(), alignment, dynamic_out); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor("DispatchOfrecordBytesDecoder", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::string& name) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("name"); attrs.SetAllAttrs(name); return 
OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor( "DispatchMegatronGptMmapDataLoader", [](const std::shared_ptr& op, const std::string& data_file_prefix, int64_t seq_length, int64_t label_length, int64_t num_samples, int64_t batch_size, const Symbol& dtype, const std::vector& split_sizes, int64_t split_index, bool shuffle, int64_t random_seed, const Optional>& device) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "data_file_prefix", "seq_length", "label_length", "num_samples", "batch_size", "dtype", "split_sizes", "split_index", "shuffle", "random_seed"); attrs.SetAllAttrs(data_file_prefix, seq_length, label_length, num_samples, batch_size, dtype->data_type(), split_sizes, split_index, shuffle, random_seed); return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, JUST(device))); }); m.add_functor( "DispatchMegatronGptMmapDataLoader", [](const std::shared_ptr& op, const std::string& data_file_prefix, int64_t seq_length, int64_t label_length, int64_t num_samples, int64_t batch_size, const Symbol& dtype, const std::vector& split_sizes, int64_t split_index, bool shuffle, int64_t random_seed, const Symbol& placement, const std::vector>& sbp_tuple) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "data_file_prefix", "seq_length", "label_length", "num_samples", "batch_size", "dtype", "split_sizes", "split_index", "shuffle", "random_seed"); attrs.SetAllAttrs(data_file_prefix, seq_length, label_length, num_samples, batch_size, dtype->data_type(), split_sizes, split_index, shuffle, random_seed); auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, placement, nd_sbp)); }); m.add_functor("DispatchRmspropUpdate", [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, double scale, float l1, float l2, bool centered, float epsilon, float decay_rate, float weight_decay) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2", "centered", "epsilon", "decay_rate", "weight_decay"); attrs.SetAllAttrs(learning_rate, scale, l1, l2, centered, epsilon, decay_rate, weight_decay); JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); m.add_functor( "DispatchAdamUpdate", [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, float bias_correction1, float bias_correction2, double scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "learning_rate_val", "bias_correction1_val", "bias_correction2_val", "scale", "l1", "l2", "beta1", "beta2", "epsilon", "weight_decay", "amsgrad", "do_bias_correction"); attrs.SetAllAttrs(learning_rate, bias_correction1, bias_correction2, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction); JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); m.add_functor("DispatchAdagradUpdate", [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, double scale, float l1, float l2, float lr_decay, float weight_decay, float epsilon, int32_t train_step) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2", "lr_decay", "weight_decay", "epsilon", "train_step_val"); attrs.SetAllAttrs(learning_rate, scale, l1, l2, lr_decay, weight_decay, epsilon, train_step); JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); m.add_functor( 
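// NOTE: the optimizer dispatchers in this file ("Dispatch*Update") take the parameter and
// optimizer-state tensors as a TensorTuple and produce no output (bound as "Void" in
// dispatch_stateful_ops.yaml); the underlying update ops modify their inputs in place, and
// per-step hyper-parameters such as learning_rate are passed as op attributes on every call.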
"DispatchMomentumUpdate", [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, double scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2", "beta", "dampening", "nesterov", "maximize", "weight_decay"); attrs.SetAllAttrs(learning_rate, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay); JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); m.add_functor( "DispatchSgdUpdate", [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, double scale, float l1, float l2, float weight_decay) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2", "weight_decay"); attrs.SetAllAttrs(learning_rate, scale, l1, l2, weight_decay); JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); m.add_functor("DispatchLambUpdate", [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, float bias_correction1, float bias_correction2, double scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool do_bias_correction) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "learning_rate_val", "bias_correction1_val", "bias_correction2_val", "scale", "l1", "l2", "beta1", "beta2", "epsilon", "weight_decay", "do_bias_correction"); attrs.SetAllAttrs(learning_rate, bias_correction1, bias_correction2, scale, l1, l2, beta1, beta2, epsilon, weight_decay, do_bias_correction); JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); m.add_functor("DispatchFtrlUpdate", [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, double scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2", "lr_power", "lambda1", "lambda2", "beta", "weight_decay"); attrs.SetAllAttrs(learning_rate, scale, l1, l2, lr_power, lambda1, lambda2, beta, weight_decay); JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); m.add_functor( "DispatchAdadeltaUpdate", [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, double scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2", "rho", "epsilon", "maximize", "weight_decay"); attrs.SetAllAttrs(learning_rate, scale, l1, l2, rho, epsilon, maximize, weight_decay); JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); m.add_functor("DispatchEagerCclAllReduce", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::string& parallel_conf) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("parallel_conf"); attrs.SetAllAttrs(parallel_conf); return OpInterpUtil::Dispatch(*op, {input}, attrs); }); m.add_functor( "DispatchRawReader", [](const std::shared_ptr& op, const std::vector& files, const Shape& shape, const Symbol& data_type, const int64_t batch_size, const bool random_shuffle, const int64_t shuffle_block_size, int64_t random_seed, const Optional>& device) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("files", "shape", "data_type", "batch_size", "random_shuffle", "shuffle_block_size", "seed", "nd_sbp"); attrs.SetAllAttrs(files, shape, data_type->data_type(), 
batch_size, random_shuffle, shuffle_block_size, random_seed, std::vector()); return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, JUST(device))); }); m.add_functor("DispatchRawReader", [](const std::shared_ptr& op, const std::vector& files, const Shape& shape, const Symbol& data_type, const int64_t batch_size, const bool random_shuffle, const int64_t shuffle_block_size, int64_t random_seed, const Symbol& placement, const std::vector>& sbp_tuple) -> Maybe { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "files", "shape", "data_type", "batch_size", "random_shuffle", "shuffle_block_size", "seed", "nd_sbp"); attrs.SetAllAttrs(files, shape, data_type->data_type(), batch_size, random_shuffle, shuffle_block_size, random_seed, *JUST(GetNdSbpStrList(sbp_tuple))); auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); return OpInterpUtil::Dispatch( *op, {}, OpExprInterpContext(attrs, placement, nd_sbp)); }); } } // namespace impl } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/functional/dispatch_stateful_ops.yaml ================================================ # Copyright 2020 The OneFlow Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # The following data types are allowed, # { # "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool", # "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList", # "BoolList", "DataType", "Shape", "Generator", "TensorIndex", "Device", "Placement", # "Sbp", "SbpList" # } - name: "dispatch_feed_input" signature: "Tensor (OpExpr op, Tensor input) => DispatchFeedInput" bind_python: True - name: "dispatch_feed_variable" signature: "Tensor (OpExpr op, Tensor input, Scalar l2) => DispatchFeedVariable" bind_python: True - name: "dispatch_fetch_output" signature: "Tensor (OpExpr op, Tensor input) => DispatchFetchOutput" bind_python: True - name: "dispatch_ofrecord_reader" signature: [ "Tensor (OpExpr op, String data_dir, Int32 data_part_num, String part_name_prefix=\"part-\", Int32 part_name_suffix_length=-1, Int32 batch_size, Int32 shuffle_buffer_size=1024, Bool random_shuffle=False, Bool shuffle_after_epoch=False, Int64 seed=-1, Device device=None) => DispatchOfrecordReader", "Tensor (OpExpr op, String data_dir, Int32 data_part_num, String part_name_prefix=\"part-\", Int32 part_name_suffix_length=-1, Int32 batch_size, Int32 shuffle_buffer_size=1024, Bool random_shuffle=False, Bool shuffle_after_epoch=False, Int64 seed=-1, Placement placement, SbpList sbp) => DispatchOfrecordReader", ] bind_python: True - name: "dispatch_ofrecord_raw_decoder" signature: "Tensor (OpExpr op, Tensor input, String name, Shape shape, DataType data_type, Bool dim1_varying_length=False, Bool truncate=False) => DispatchOfrecordRawDecoder" bind_python: True - name: "dispatch_coin_flip" signature: [ "Tensor (OpExpr op, Int64 batch_size, Scalar probability=0.5, Int64 seed=-1, Bool has_seed=False, Device device=None) => 
DispatchCoinFlip", "Tensor (OpExpr op, Int64 batch_size, Scalar probability=0.5, Int64 seed=-1, Bool has_seed=False, Placement placement, SbpList sbp) => DispatchCoinFlip", ] bind_python: True - name: "dispatch_distributed_partial_fc_sample" signature: "TensorTuple (OpExpr op, Tensor weight, Tensor label, Int64 num_sample) => DispatchDistributedPariticalFCSample" bind_python: True - name: "dispatch_crop_mirror_normalize_from_uint8" signature: "Tensor (OpExpr op, TensorTuple input, Int64 crop_h=0, Int64 crop_w=0, Float crop_pos_x=0.5, Float crop_pos_y=0.5, FloatList mean, FloatList std, DataType output_dtype=kFloat, String output_layout=\"NCHW\", String color_space=\"BGR\") => DispatchCropMirrorNormalizeFromUint8" bind_python: True - name: "dispatch_crop_mirror_normalize_from_tensorbuffer" signature: "Tensor (OpExpr op, TensorTuple input, Int64 crop_h=0, Int64 crop_w=0, Float crop_pos_x=0.5, Float crop_pos_y=0.5, FloatList mean, FloatList std, DataType output_dtype=kFloat, String output_layout=\"NCHW\", String color_space=\"BGR\") => DispatchCropMirrorNormalizeFromTensorBuffer" bind_python: True - name: "dispatch_ofrecord_image_decoder_random_crop" signature: "Tensor (OpExpr op, Tensor input, String name, String color_space=\"BGR\", FloatList random_area, FloatList random_aspect_ratio, Int32 num_attempts=10, Int64 seed=-1, Bool has_seed=False) => DispatchOfrecordImageDecoderRandomCrop" bind_python: True - name: "dispatch_ofrecord_image_decoder" signature: "Tensor (OpExpr op, Tensor input, String name, String color_space=\"BGR\") => DispatchOfrecordImageDecoder" bind_python: True - name: "dispatch_image_decoder_random_crop_resize" signature: "Tensor (OpExpr op, Tensor input, Int64 target_width, Int64 target_height, Int64 seed, Int64 num_workers=3, Int64 max_num_pixels=67108864, Float random_area_min=0.08f, Float random_area_max=1.0f, Float random_aspect_ratio_min=0.75f, Float random_aspect_ratio_max=1.333333f, Int64 warmup_size=6400, Int64 num_attempts=10) => DispatchImageDecoderRandomCropResize" bind_python: True - name: "dispatch_tensor_buffer_to_list_of_tensors_v2" signature: "TensorTuple (OpExpr op, Tensor input, ShapeList out_shapes, DataTypeList out_dtypes, Bool dynamic_out) => DispatchTensorBufferToListOfTensorsV2" bind_python: True - name: "dispatch_image_resize_keep_aspect_ratio" signature: "TensorTuple (OpExpr op, Tensor input, Int32 target_size, Int32 min_size=0, Int32 max_size=0, Bool resize_longer=False, String interpolation_type=\"bilinear\") => DispatchImageResizeKeepAspectRatio" bind_python: True - name: "dispatch_image_resize_to_fixed" signature: "TensorTuple (OpExpr op, Tensor input, Int64 target_width=0, Int64 target_height=0, Int64 channels=3, DataType data_type=kUInt8, String interpolation_type=\"bilinear\") => DispatchImageResizeToFixed" bind_python: True - name: "dispatch_image_decode" signature: "Tensor (OpExpr op, Tensor input, String color_space=\"BGR\", DataType data_type=kUInt8) => DispatchImageDecode" bind_python: True - name: "dispatch_image_normalize" signature: "Tensor (OpExpr op, Tensor input, FloatList mean, FloatList std) => DispatchImageNormalize" bind_python: True - name: "dispatch_coco_reader" signature: [ "TensorTuple (OpExpr op, String image_dir, String annotation_file, Int64 batch_size, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool group_by_ratio=True, Bool remove_images_without_annotations=True, Bool stride_partition=False, Int64 session_id, Device device=None) => DispatchCOCOReader", "TensorTuple (OpExpr op, String image_dir, String 
annotation_file, Int64 batch_size, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool group_by_ratio=True, Bool remove_images_without_annotations=True, Bool stride_partition=False, Int64 session_id, Placement placement, SbpList sbp) => DispatchCOCOReader", ] bind_python: True - name: "dispatch_image_batch_align" signature: "Tensor (OpExpr op, Tensor input, Int32 alignment, Shape shape, DataType data_type, Bool dynamic_out) => DispatchImageBatchAlign" bind_python: True - name: "dispatch_ofrecord_bytes_decoder" signature: "Tensor (OpExpr op, Tensor input, String name) => DispatchOfrecordBytesDecoder" bind_python: True - name: "dispatch_megatron_gpt_mmap_data_loader" signature: [ "Tensor (OpExpr op, String data_file_prefix, Int64 seq_length, Int64 label_length=1, Int64 num_samples, Int64 batch_size, DataType dtype, Int64List split_sizes, Int64 split_index, Bool shuffle, Int64 random_seed, Device device=None) => DispatchMegatronGptMmapDataLoader", "Tensor (OpExpr op, String data_file_prefix, Int64 seq_length, Int64 label_length=1, Int64 num_samples, Int64 batch_size, DataType dtype, Int64List split_sizes, Int64 split_index, Bool shuffle, Int64 random_seed, Placement placement, SbpList sbp) => DispatchMegatronGptMmapDataLoader", ] bind_python: True - name: "dispatch_rmsprop_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Bool centered=False, Float epsilon=1e-8, Float decay_rate=0.99, Float weight_decay=0.0) => DispatchRmspropUpdate" bind_python: True - name: "dispatch_adam_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool amsgrad=False, Bool do_bias_correction=True) => DispatchAdamUpdate" bind_python: True - name: "dispatch_adagrad_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_decay=0, Float weight_decay=0, Float epsilon=1e-10, Int32 train_step_val=0) => DispatchAdagradUpdate" bind_python: True - name: "dispatch_momentum_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float dampening=0.0, Bool nesterov=False, Bool maximize=False, Float weight_decay=0) => DispatchMomentumUpdate" bind_python: True - name: "dispatch_sgd_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float weight_decay=0) => DispatchSgdUpdate" bind_python: True - name: "dispatch_lamb_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool do_bias_correction=True) => DispatchLambUpdate" bind_python: True - name: "dispatch_ftrl_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_power, Float lambda1, Float lambda2, Float beta, Float weight_decay=0) => DispatchFtrlUpdate" bind_python: True - name: "dispatch_adadelta_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float rho, Float epsilon, Bool maximize, Float weight_decay=0) => DispatchAdadeltaUpdate" 
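# How to read these entries (illustrative note, not an entry in this file): "name" is the
# Python-visible function (exposed as oneflow._C.<name> when bind_python is True, assuming the
# standard functional code generation); the signature gives the return type, the typed parameters
# with their defaults, and, after "=>", the C++ functor registered in dispatch_stateful_ops.cpp
# that the call is dispatched to.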
bind_python: True - name: "dispatch_eager_ccl_all_reduce" signature: "Tensor (OpExpr op, Tensor input, String parallel_conf) => DispatchEagerCclAllReduce" bind_python: True - name: "dispatch_raw_reader" signature: [ "Tensor (OpExpr op, StringList files, Shape shape, DataType data_type, Int64 batch_size, Bool random_shuffle, Int64 shuffle_block_size, Int64 random_seed=-1, Device device=None) => DispatchRawReader", "Tensor (OpExpr op, StringList files, Shape shape, DataType data_type, Int64 batch_size, Bool random_shuffle, Int64 shuffle_block_size, Int64 random_seed=-1, Placement placement, SbpList sbp) => DispatchRawReader", ] bind_python: True ================================================ FILE: oneflow/api/python/functional/function_def.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_ #define ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_ #include #include #include #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/api/python/functional/value_types.h" namespace oneflow { namespace one { namespace functional { struct ReturnDef { explicit ReturnDef(const ValueType& t) : type(t) {} ValueType type; }; struct ArgumentDef { ArgumentDef(const std::string& arg_name, const ValueType& arg_type, int arg_size, bool arg_keyword_only, bool arg_optional) : name(arg_name), type(arg_type), size(arg_size), keyword_only(arg_keyword_only), optional(arg_optional), has_default_value(false) {} template ArgumentDef(const std::string& arg_name, const T& arg_val, int arg_size, bool arg_keyword_only, bool arg_optional) : name(arg_name), type(ValueTypeOf()), size(arg_size), keyword_only(arg_keyword_only), optional(arg_optional), has_default_value(true) { default_value = std::make_shared>(arg_val); } std::string name; ValueType type; int size; bool keyword_only; bool optional; bool has_default_value; std::shared_ptr default_value; }; struct FunctionDef { std::string name; ReturnDef return_def; std::vector argument_def; }; } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_ ================================================ FILE: oneflow/api/python/functional/indexing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/api/python/functional/indexing.h" #include #include #include "oneflow/api/python/functional/common.h" #include "oneflow/extension/python/numpy.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/functional/functional.h" #include "oneflow/api/python/functional/tensor_api.yaml.h" #include "oneflow/core/common/foreign_lock_helper.h" namespace oneflow { namespace one { namespace functional { namespace detail { void PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step) { PySliceObject* obj = (PySliceObject*)object; if (obj->step == Py_None) { *step = 1; } else { CHECK_OR_THROW(_PyEval_SliceIndex(obj->step, step)) << "Invalid slice " << PyObjectToReprStr(object); CHECK_NE_OR_THROW(*step, 0) << "slice step cannot be zero."; if (*step < -PY_SSIZE_T_MAX) *step = -PY_SSIZE_T_MAX; } if (obj->start == Py_None) { *start = *step < 0 ? PY_SSIZE_T_MAX : 0; } else { CHECK_OR_THROW(_PyEval_SliceIndex(obj->start, start)) << "Invalid slice " << PyObjectToReprStr(object); } if (obj->stop == Py_None) { *stop = *step < 0 ? PY_SSIZE_T_MIN : PY_SSIZE_T_MAX; } else { CHECK_OR_THROW(_PyEval_SliceIndex(obj->stop, stop)) << "Invalid slice " << PyObjectToReprStr(object); } } DataType InferScalarType(PyObject* object) { if (PyBool_Check(object)) { return DataType::kBool; } else if (PyLong_Check(object)) { return DataType::kInt64; } else if (PyArray_Check(object)) { return numpy::GetOFDataTypeFromNpArray(reinterpret_cast(object)).GetOrThrow(); } else if (PyArray_CheckScalar(object)) { return numpy::NumpyTypeToOFDataType(PyArray_DescrFromScalar(object)->type_num).GetOrThrow(); } else if (PySequence_Check(object)) { int64_t length = PySequence_Length(object); if (length == 0) { return DataType::kInt64; } DataType scalar_type = DataType::kInvalidDataType; for (int64_t i = 0; i < length; ++i) { PyObjectPtr item(PySequence_GetItem(object, i)); const auto& item_scalar_type = InferScalarType(item.get()); if (scalar_type != DataType::kInvalidDataType) { CHECK_EQ_OR_THROW(scalar_type, item_scalar_type) << "Different scalar types are not allowed."; } else { scalar_type = item_scalar_type; } } return scalar_type; } THROW(TypeError) << "Can't infer scalar type of " << Py_TYPE(object)->tp_name; return DataType::kInvalidDataType; } void ParseScalar(PyObject* object, char* data, const DataType& dtype) { if (dtype == DataType::kInt64) { CHECK_OR_THROW(PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object)) << "Expected a long value."; *(reinterpret_cast(data)) = PyLong_AsLongLong(object); } else if (dtype == DataType::kInt32) { CHECK_OR_THROW(PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object)) << "Expected a long value."; *(reinterpret_cast(data)) = PyLong_AsLongLong(object); } else if (dtype == DataType::kUInt8 || dtype == DataType::kBool) { CHECK_OR_THROW(PyBool_Check(object) || PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object)) << "Expected a boolean or long value."; if (PyBool_Check(object) || numpy::PyArrayCheckBoolScalar(object)) { *(reinterpret_cast(data)) = (object == Py_True); } else { int64_t value = PyLong_AsLongLong(object); CHECK_OR_THROW(value >= 0 && value <= 255) << "Out of range 0-255."; *(reinterpret_cast(data)) = static_cast(value); } } else { THROW(TypeError) << "Can't parse scalar with data type " << dtype; } } void RecursiveParseAndAssign(PyObject* object, char* data, const int& ndims, const int& dim, const 
ShapeView& shape, const DimVector& strides, const DataType& dtype) { if (dim == ndims) { return ParseScalar(object, data, dtype); } auto seq = PyObjectPtr(PySequence_Fast(object, "Expected a sequence.")); int64_t size = PySequence_Fast_GET_SIZE(seq.get()); CHECK_EQ_OR_THROW(size, shape.At(dim)) << "Sequence size is " << size << " at dimension " << dim << ", but expected " << shape.At(dim); for (int64_t i = 0; i < size; ++i) { PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i); RecursiveParseAndAssign(item, data, ndims, dim + 1, shape, strides, dtype); data += strides.at(dim) * GetSizeOfDataType(dtype); } } void ParseArrayToTensor(PyObject* object, const std::shared_ptr& eager_blob_object) { const DataType dtype = eager_blob_object->data_type(); const int ndims = eager_blob_object->shape().NumAxes(); DimVector strides(ndims); int64_t size = 1; for (int i = ndims - 1; i >= 0; --i) { strides[i] = size; size *= eager_blob_object->shape().At(i); } RecursiveParseAndAssign(object, eager_blob_object->mut_dptr(), ndims, 0, eager_blob_object->shape(), strides, dtype); } Shape InferArraySizes(PyObject* object) { DimVector sizes; PyObject* seq = object; PyObjectPtr handle; while (PySequence_Check(seq)) { int64_t length = PySequence_Length(seq); sizes.emplace_back(length); CHECK_LE_OR_THROW(sizes.size(), /*MAX_DIMS=*/128) << "Too many dimensions " << Py_TYPE(seq)->tp_name; if (length == 0) break; handle = PyObjectPtr(PySequence_GetItem(seq, 0)); seq = handle.get(); } return Shape(sizes); } Maybe ConvertToIndexingTensor(PyObject* object) { // NOTE: converting index data to a tensor is always done in eager mode LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); const DataType dtype = InferScalarType(object); const auto& device = JUST(Device::New("cpu")); // index dtypes must be integral or boolean if (!(IsIntegralDataType(dtype) || (IsBoolDataType(dtype)))) { return Error::IndexError() << "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis " "(`None`) and integer or boolean arrays are valid indices"; } // In advanced indexing, the index may be a numpy array object, which needs special handling. if (PyArray_Check(object)) { return TensorWithData(object, NullOpt, device, /*requires_grad=*/false, /*pin_memory=*/false); } const auto& sizes = InferArraySizes(object); const auto& tensor = JUST(functional::Empty(sizes, CHECK_JUST(DType::Get(dtype)), device, /*requires_grad=*/false, /*pin_memory=*/false)); // Prevent the Python object from being released until the callback completes. 
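// The tensor is filled asynchronously: Py_INCREF plus the shared_ptr-wrapped PyObjectPtr below
// keep the Python object alive until AccessBlobByCallback has run, and the callback re-acquires
// the Python GIL through the foreign lock helper before reading the object, since it may execute
// on a non-Python thread.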
Py_INCREF(object); auto handle = std::shared_ptr(PyObjectPtr(object)); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->AccessBlobByCallback( JUST(tensor->AsLocalTensor()), [handle](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { CHECK_JUST(Singleton::Get()->WithScopedAcquire([&]() -> Maybe { ParseArrayToTensor(handle.get(), eager_blob_object); return Maybe::Ok(); })); }, "mut"); })); return tensor; } IndexItem UnpackIndexItem(PyObject* object) { if (object == Py_Ellipsis) { return IndexItem(EllipsisIndex{}); } else if (PySlice_Check(object)) { Py_ssize_t start, end, step; PySliceUnpack(object, &start, &end, &step); return IndexItem(start, end, step); } else if (PyLong_Check(object) && object != Py_False && object != Py_True) { return IndexItem(static_cast(PyLong_AsLongLong(object))); } else if (numpy::PyArrayCheckLongScalar(object)) { return IndexItem(static_cast(PyLong_AsLongLong(object))); } else if (object == Py_False || object == Py_True) { return IndexItem(object == Py_True); } else if (object == Py_None) { return IndexItem(NoneIndex{}); } else if (PyTensor_Check(object)) { return IndexItem(PyTensor_Unpack(object)); } else if (PySequence_Check(object)) { return IndexItem(ConvertToIndexingTensor(object).GetPtrOrThrow()); } THROW(IndexError) << "Invalid index " << Py_TYPE(object)->tp_name; return IndexItem(); } } // namespace detail } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/functional/indexing.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_ #define ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_ #include #undef _PyGC_FINALIZED #include "oneflow/api/python/functional/common.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/functional/tensor_index.h" namespace oneflow { namespace one { namespace functional { namespace detail { void PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step); Maybe ConvertToIndexingTensor(PyObject* object); IndexItem UnpackIndexItem(PyObject* object); } // namespace detail } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_ ================================================ FILE: oneflow/api/python/functional/python_arg.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/api/python/framework/tensor.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/indexing.h" #include "oneflow/api/python/framework/memory_format.h" #include "oneflow/extension/python/numpy.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/layout.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/tensor_index.h" namespace py = pybind11; namespace oneflow { namespace one { namespace functional { #define INSTANCE_OBJECT_AS_INTEGER(T) \ template<> \ T PythonArg::ObjectAs() const { \ if (PyIntegerScalarTensorCheck(object_)) { \ return static_cast(PyUnpackIntegerScalarTensor_AsLongLong(object_)); \ } \ return static_cast(PyLong_AsLongLong(object_)); \ } \ template<> \ std::vector PythonArg::ObjectAs>() const { \ if (size_ > 0 && PyLong_Check(object_)) { \ return std::vector(size_, static_cast(PyLong_AsLongLong(object_))); \ } \ return PyUnpackLongSequence(object_); \ } \ template<> \ std::shared_ptr> PythonArg::ObjectAs>>() const { \ return std::make_shared>(ObjectAs>()); \ } OF_PP_FOR_EACH_TUPLE(INSTANCE_OBJECT_AS_INTEGER, INTEGER_AND_BOOL_TYPE_SEQ) #undef INSTANCE_OBJECT_AS_INTEGER #define INSTANCE_OBJECT_AS_FLOAT(T) \ template<> \ T PythonArg::ObjectAs() const { \ if (PyFloatScalarTensorCheck(object_)) { \ return static_cast(PyUnpackFloatScalarTensor_AsDouble(object_)); \ } \ return static_cast(PyFloat_AsDouble(object_)); \ } \ template<> \ std::vector PythonArg::ObjectAs>() const { \ if (size_ > 0 && PyFloat_Check(object_)) { \ return std::vector(size_, static_cast(PyFloat_AsDouble(object_))); \ } \ return PyUnpackFloatSequence(object_); \ } \ template<> \ std::shared_ptr> PythonArg::ObjectAs>>() const { \ return std::make_shared>(ObjectAs>()); \ } OF_PP_FOR_EACH_TUPLE(INSTANCE_OBJECT_AS_FLOAT, FLOATING_TYPE_SEQ) #undef INSTANCE_OBJECT_AS_FLOAT #define INSTANCE_OBJECT_AS_SHARED_PTR(T) \ template<> \ std::shared_ptr PythonArg::ObjectAs>() const { \ return std::make_shared(ObjectAs()); \ } template<> std::string PythonArg::ObjectAs() const { return PyStringAsString(object_); } INSTANCE_OBJECT_AS_SHARED_PTR(std::string) template<> Scalar PythonArg::ObjectAs() const { if (PyScalarTensorCheck(object_)) { return PyUnpackScalarTensor(object_); } return PyUnpackScalar(object_); } INSTANCE_OBJECT_AS_SHARED_PTR(Scalar) template<> std::shared_ptr PythonArg::ObjectAs>() const { return PyTensor_Unpack(object_); } template<> one::TensorTuple PythonArg::ObjectAs() const { if (PyTensorTupleCheck(object_)) { return *PyUnpackTensorTuple(object_); } const auto& v = PyUnpackTensorSequence(object_); one::TensorTuple values(v.size()); for (int i = 0; i < v.size(); ++i) { values[i] = v.at(i); } return values; } INSTANCE_OBJECT_AS_SHARED_PTR(one::TensorTuple) template<> Symbol PythonArg::ObjectAs>() const { return PyUnpackDType(object_); } template<> Symbol PythonArg::ObjectAs>() const { return PyUnpackLayout(object_); } template<> Symbol PythonArg::ObjectAs>() const { return PyUnpackMemoryFormat(object_); } template<> std::vector> PythonArg::ObjectAs>>() const { return PyUnpackDTypeSequence(object_); } 
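// NOTE: INSTANCE_OBJECT_AS_SHARED_PTR(T) only generates the shared_ptr-returning ObjectAs
// specialization by copying the by-value result into a shared_ptr, which is why each invocation
// directly follows the matching by-value ObjectAs specialization.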
INSTANCE_OBJECT_AS_SHARED_PTR(std::vector>) template<> Shape PythonArg::ObjectAs() const { return PyUnpackShape(object_); } INSTANCE_OBJECT_AS_SHARED_PTR(Shape) template<> std::vector PythonArg::ObjectAs>() const { return PyUnpackShapeSequence(object_); } INSTANCE_OBJECT_AS_SHARED_PTR(std::vector) template<> std::shared_ptr PythonArg::ObjectAs>() const { return PyUnpackGenerator(object_); } template<> Symbol PythonArg::ObjectAs>() const { if (PyStringCheck(object_)) { std::string device_str = PyStringAsString(object_); return Device::ParseAndNew(device_str).GetOrThrow(); } return PyUnpackDevice(object_); } template<> Symbol PythonArg::ObjectAs>() const { return PyUnpackParallelDesc(object_); } template<> Symbol PythonArg::ObjectAs>() const { return PyUnpackSbpParallel(object_); } template<> std::vector> PythonArg::ObjectAs>>() const { if (PySbpParallelCheck(object_)) { return std::vector>(1, PyUnpackSbpParallel(object_)); } return PyUnpackSbpParallelSequence(object_); } INSTANCE_OBJECT_AS_SHARED_PTR(std::vector>) template<> TensorIndex PythonArg::ObjectAs() const { return PyUnpackTensorIndex(object_); } INSTANCE_OBJECT_AS_SHARED_PTR(TensorIndex) template<> std::shared_ptr PythonArg::ObjectAs>() const { return PyUnpackOpExpr(object_); } template<> PyObject* PythonArg::ObjectAs() const { return object_; } template<> std::vector PythonArg::ObjectAs>() const { return PyUnpackSequence( object_, [](PyObject* item) -> std::string { return PyStringAsString(item); }); } INSTANCE_OBJECT_AS_SHARED_PTR(std::vector) template<> MemoryFormat PythonArg::ObjectAs() const { return PyMemoryFormat_Unpack(object_); } #undef INSTANCE_OBJECT_AS_SHARED_PTR bool PythonArg::TypeCheck(ValueType type) const { if (tag_ == HAS_DEFAULT) { return default_val_->value_type() == type; } switch (type) { case kINT32: case kINT16: case kCHAR: case kUINT32: case kINT64: case kUINT64: case kBOOL: return PyLong_Check(object_) || numpy::PyArrayCheckLongScalar(object_) || PyIntegerScalarTensorCheck(object_) || PyBoolScalarTensorCheck(object_); case kINT32_LIST: case kUINT32_LIST: case kINT64_LIST: case kUINT64_LIST: case kBOOL_LIST: return PyLongSequenceCheck(object_) || (size_ > 0 && PyLong_Check(object_)); case kFLOAT: case kDOUBLE: return PyFloat_Check(object_) || PyLong_Check(object_) || numpy::PyArrayCheckFloatScalar(object_) || numpy::PyArrayCheckLongScalar(object_) || PyFloatScalarTensorCheck(object_) || PyIntegerScalarTensorCheck(object_); case kFLOAT_LIST: case kDOUBLE_LIST: return PyFloatSequenceCheck(object_) || (size_ > 0 && (PyFloat_Check(object_) || PyLong_Check(object_))); case kSTRING: return PyStringCheck(object_); case kSTRING_LIST: return PyStringSequenceCheck(object_); case kSCALAR: return PyScalarCheck(object_) || numpy::PyArrayCheckLongScalar(object_) || numpy::PyArrayCheckFloatScalar(object_) || PyScalarTensorCheck(object_); case kTENSOR: case kTENSOR_REF: return PyTensor_Check(object_); case kTENSOR_TUPLE: return PyTensorTupleCheck(object_) || PyTensorSequenceCheck(object_); case kDTYPE: return PyDTypeCheck(object_); case kLAYOUT: return PyLayoutCheck(object_); case kMEMORY_FORMAT: return PyMemoryFormat_Check(object_); case kSHAPE: return PyLongSequenceCheck(object_); case kGENERATOR: case kGENERATOR_REF: return PyGeneratorCheck(object_); case kTENSOR_INDEX: return PyTensorIndexCheck(object_); case kDEVICE: return PyStringCheck(object_) || PyDeviceCheck(object_); case kPARALLEL_DESC: return PyParallelDescCheck(object_); case kSBP_PARALLEL: return PySbpParallelCheck(object_); case kSBP_PARALLEL_LIST: return 
PySbpParallelSequenceCheck(object_) || PySbpParallelCheck(object_); case kOPEXPR_REF: return PyOpExprCheck(object_); case kPY_OBJECT: return nullptr != object_; case kDTYPE_LIST: return PyDTypeSequenceCheck(object_); case kSHAPE_LIST: return PyShapeSequenceCheck(object_); case kCOMPLEX_FLOAT: case kCOMPLEX_DOUBLE: return PyComplex_Check(object_) || PyFloat_Check(object_) || PyLong_Check(object_) || numpy::PyArrayCheckComplexScalar(object_) || numpy::PyArrayCheckFloatScalar(object_) || numpy::PyArrayCheckLongScalar(object_) || PyComplexScalarTensorCheck(object_) || PyFloatScalarTensorCheck(object_) || PyIntegerScalarTensorCheck(object_); default: { THROW(RuntimeError) << "Can not check type " << ValueTypeName(type); } } return false; } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/functional/python_arg.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_ #define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_ #include #include #undef _PyGC_FINALIZED #include "oneflow/core/common/throw.h" #include "oneflow/api/python/functional/value_types.h" #include "oneflow/core/common/maybe.h" namespace py = pybind11; namespace oneflow { namespace one { namespace functional { namespace detail { struct DefaultVal { virtual ValueType value_type() const = 0; virtual const void* Ptr() const = 0; }; template struct TypedDefaultVal final : public DefaultVal { T content; explicit TypedDefaultVal(const T& v) : content(v) {} ValueType value_type() const override { return ValueTypeOf(); } const void* Ptr() const override { return &content; } }; template struct optional_traits { using type = void; }; template struct optional_traits> { using type = decltype(std::declval>().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()); }; } // namespace detail class PythonArg { public: PythonArg() = default; PythonArg(PyObject* object, int size = 0) : object_(object), default_val_(), size_(size), tag_(HAS_OBJECT) {} PythonArg(const detail::DefaultVal* value, int size = 0) : object_(nullptr), default_val_(value), size_(size), tag_(HAS_DEFAULT) {} template::value, int>::type = 0> T As() const { if (tag_ == HAS_DEFAULT) { CHECK_EQ_OR_THROW(ValueTypeOf(), default_val_->value_type()) << "Could not convert default value from type " << default_val_->value_type() << " to type " << ValueTypeOf(); return *reinterpret_cast(default_val_->Ptr()); } CHECK_EQ_OR_THROW(tag_, HAS_OBJECT); return ObjectAs>(); } template::value, int>::type = 0> T As() const { if (tag_ == HAS_DEFAULT) { CHECK_EQ_OR_THROW(ValueTypeOf(), default_val_->value_type()) << "Could not convert default value from type " << default_val_->value_type() << " to type " << ValueTypeOf(); return *reinterpret_cast(default_val_->Ptr()); } CHECK_EQ_OR_THROW(tag_, HAS_OBJECT); if (object_ == Py_None) { return T(); } return ObjectAs::type>(); } bool TypeCheck(ValueType 
type) const; private: template T ObjectAs() const; PyObject* object_; const detail::DefaultVal* default_val_; size_t size_; enum { HAS_OBJECT, HAS_DEFAULT, HAS_NONE } tag_; }; } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_ ================================================ FILE: oneflow/api/python/functional/python_arg_parser.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/python/functional/python_arg_parser.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/python_arg.h" namespace oneflow { namespace one { namespace functional { void FunctionSchema::ReportKwargsError(PyObject* kwargs, size_t nargs) const { PyObject *key = nullptr, *value = nullptr; Py_ssize_t pos = 0; while (PyDict_Next(kwargs, &pos, &key, &value)) { if (!PyStringCheck(key)) { THROW(TypeError) << def_->name << "(): keywords must be strings"; } int64_t index = -1; const std::string string_key = PyStringAsString(key); for (int i = 0; i < def_->argument_def.size(); ++i) { const auto& arg = def_->argument_def[i]; if (arg.name == string_key) { index = i; break; } } if (index < 0) { THROW(TypeError) << def_->name << "(): got an unexpected keyword argument '" << string_key << "'"; } if (index < nargs) { THROW(TypeError) << def_->name << "(): got multiple values for argument '" << string_key << "'"; } } THROW(TypeError) << def_->name << "(): kwargs unknown error"; } // The argument parsing refers to the implementation of Pytorch. bool FunctionSchema::Parse(PyObject* args, PyObject* kwargs, PythonArg* parsed_args, bool raise_exception) const { bool treat_args_as_list = false; size_t nargs = args ? PyTuple_Size(args) : 0; size_t remaining_kwargs = kwargs ? 
PyDict_Size(kwargs) : 0; if (max_pos_nargs_ == 1) { const auto& type = def_->argument_def[0].type; treat_args_as_list = IsIntegralListType(type) || type == kSHAPE || type == kTENSOR_TUPLE; } if (nargs > max_pos_nargs_ && !treat_args_as_list) { if (raise_exception) { THROW(TypeError) << def_->name << "(): takes " << max_pos_nargs_ << " positional arguments but " << nargs << " were given"; } return false; } int arg_pos = 0; for (int i = 0; i < def_->argument_def.size(); ++i) { const auto& param = def_->argument_def[i]; PyObject* obj = NULL; if (args && arg_pos < nargs) { if (param.keyword_only) { if (raise_exception) { THROW(TypeError) << def_->name << "(): argument '" << param.name << "' is keyword only"; } return false; } obj = PyTuple_GET_ITEM(args, arg_pos); } else if (kwargs) { obj = PyDict_GetItemString(kwargs, param.name.c_str()); if (obj) { --remaining_kwargs; } } if (obj) { if (arg_pos == 0 && treat_args_as_list && !param.keyword_only && (PyLong_Check(obj) || PyTensor_Check(obj))) { obj = args; arg_pos = nargs; } else { ++arg_pos; } PythonArg arg(obj, param.size); if ((obj == Py_None && param.optional) || arg.TypeCheck(param.type)) { parsed_args[i] = arg; } else { if (raise_exception) { THROW(TypeError) << def_->name << "(): argument '" << param.name << "' must be " << ValueTypeName(param.type) << ", not " << PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(obj))); } return false; } } else { if (!param.has_default_value) { if (raise_exception) { THROW(TypeError) << def_->name << "(): missing required argument " << param.name; } return false; } parsed_args[i] = param.default_value.get(); } } if (remaining_kwargs > 0) { if (raise_exception) { ReportKwargsError(kwargs, nargs); } return false; } return true; } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/functional/python_arg_parser.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_ #define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_ #include #undef _PyGC_FINALIZED #include "oneflow/api/python/functional/function_def.h" #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace one { namespace functional { template class ParsedArgs { public: ParsedArgs() = default; const PythonArg& operator[](size_t idx) const { return data[idx]; } PythonArg& operator[](size_t idx) { return data[idx]; } public: PythonArg data[N]; }; class FunctionSchema { public: FunctionSchema() = default; FunctionSchema(const std::string& signature, const FunctionDef* def, size_t max_pos_nargs) : signature_(signature), def_(def), max_pos_nargs_(max_pos_nargs) {} const std::string& signature() const { return signature_; } bool Parse(PyObject* args, PyObject* kwargs, PythonArg* parsed_args, bool raise_exception) const; private: void ReportKwargsError(PyObject* kwargs, size_t nargs) const; std::string signature_; const FunctionDef* def_; size_t max_pos_nargs_; }; template class PythonArgParser { public: static_assert(sizeof...(SchemaT) >= 1, "requires 1 template argument at least."); static constexpr size_t kSchemaSize = sizeof...(SchemaT); static constexpr size_t N = std::max({SchemaT::max_args...}); template using schema_t = typename std::tuple_element>::type; PythonArgParser(const std::string& name) : name_(name) { Init(std::make_index_sequence{}); } int Parse(PyObject* args, PyObject* kwargs, ParsedArgs* parsed_args) const { bool raise_exception = (kSchemaSize == 1); for (int i = 0; i < kSchemaSize; ++i) { if (schema_[i].Parse(args, kwargs, parsed_args->data, raise_exception)) { return i; } } ReportInvalidArgsError(args, kwargs); return -1; } private: template void Init(std::index_sequence) { ((schema_[I] = FunctionSchema(schema_t::signature, &schema_t::function_def, schema_t::max_pos_args)), ...); } void ReportInvalidArgsError(PyObject* args, PyObject* kwargs) const { std::ostringstream ss; ss << name_ << "(): received an invalid combination of arguments. The valid signatures are:"; for (int i = 0; i < kSchemaSize; ++i) { ss << "\n\t*" << i << ": " << schema_[i].signature(); } THROW(TypeError) << ss.str(); } private: std::string name_; FunctionSchema schema_[kSchemaSize]; }; } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_ ================================================ FILE: oneflow/api/python/functional/python_return_types.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ // This code is referenced from: // https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/structseq.cpp #ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_RETURN_TYPES_H_ #define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_RETURN_TYPES_H_ #include #undef _PyGC_FINALIZED #include #include #include #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/functional/common.h" namespace oneflow { namespace one { namespace functional { inline PyObject* toTuple(PyStructSequence* obj) { #if PY_MAJOR_VERSION == 2 ROF_RUNTIME_ERROR() << "Oneflow do not support python 2"; #else Py_INCREF(obj); return (PyObject*)obj; #endif } PyObject* returned_structseq_repr(PyStructSequence* obj) { HANDLE_ERRORS PyTypeObject* tp = Py_TYPE(obj); PyObject* tuple = toTuple(obj); if (tuple == nullptr) { return nullptr; } std::stringstream ss; ss << tp->tp_name << "(\n"; Py_ssize_t num_elements = Py_SIZE(obj); for (Py_ssize_t i = 0; i < num_elements; i++) { const char* cname = tp->tp_members[i].name; if (cname == nullptr) { PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %zd name is nullptr" " for type %.500s", i, tp->tp_name); Py_DECREF(tuple); return nullptr; } PyObject* val = PyTuple_GetItem(tuple, i); if (val == nullptr) { Py_DECREF(tuple); return nullptr; } auto repr = PyObject_Repr(val); if (repr == nullptr) { Py_DECREF(tuple); return nullptr; } const char* crepr = PyUnicode_AsUTF8(repr); Py_DECREF(repr); if (crepr == nullptr) { Py_DECREF(tuple); return nullptr; } ss << cname << '=' << crepr; if (i < num_elements - 1) { ss << ",\n"; } } ss << ")"; Py_DECREF(tuple); return PyUnicode_FromString(ss.str().c_str()); END_HANDLE_ERRORS } } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_RETURN_TYPES_H_ ================================================ FILE: oneflow/api/python/functional/tensor_api.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #undef _PyGC_FINALIZED #include #include "oneflow/api/python/utils/tensor_utils.h" #include "oneflow/api/python/dlpack/converter.h" #include "oneflow/api/python/framework/size.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/tensor_api.yaml.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/common/foreign_lock_helper.h" namespace oneflow { namespace one { namespace functional { namespace impl { class TensorWithDataFunctor { public: Maybe operator()(PyObject* data, const Optional>& dtype, const Optional>& device, const bool requires_grad, const bool pin_memory) const { // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. // even if in nn.Graph build (module forward function), if you create a flow.Tensor, // its a eager tensor by Run functional::Empty() in LazyMode::Grad(false) LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST( functional::GlobalTensorWithData(data, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), requires_grad)); } if (PyTensor_Check(data)) { // Throw warnings like pytorch. auto ret = PyErr_WarnEx( PyExc_UserWarning, "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " "or sourceTensor.clone().detach().requires_grad_(True), rather than " "oneflow.tensor(sourceTensor).", 1); if (ret != 0) { return Error::RuntimeError(); } const auto& other = PyTensor_Unpack(data); return MakeTensorFromOtherTensor(other, dtype, device, requires_grad, pin_memory); } else { // Make tensor from python sequence or numpy array. return MakeLocalTensorFromData(data, dtype, device, requires_grad, pin_memory); } } }; class GlobalTensorWithDataFunctor { public: Maybe operator()(PyObject* data, const Optional>& dtype, const Symbol& placement, const std::vector>& sbp_tuple, const bool requires_grad) const { // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); JUST(CheckDeviceIdsIsValid(placement)); if (PyTensor_Check(data)) { // Throw warnings like pytorch. auto ret = PyErr_WarnEx( PyExc_UserWarning, "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " "or sourceTensor.clone().detach().requires_grad_(True), rather than " "oneflow.tensor(sourceTensor).", 1); if (ret != 0) { return Error::RuntimeError(); } const auto& other = PyTensor_Unpack(data); return MakeTensorFromOtherTensor(other, dtype, placement, sbp_tuple, requires_grad); } // Make global tensor from python sequence or numpy array. 
return MakeGlobalTensorFromData(data, dtype, placement, sbp_tuple, requires_grad); } }; class TensorEmptyGenericCtorFunctor { public: Maybe operator()(const Symbol& dtype, const Optional>& device) const { Shape shape(DimVector{0}); return TensorWithShapeGenericCtor(shape, dtype, device); } }; class GlobalTensorEmptyGenericCtorFunctor { public: Maybe operator()(const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { Shape shape(DimVector{0}); JUST(CheckDeviceIdsIsValid(placement)); return GlobalTensorWithShapeGenericCtor(shape, dtype, placement, sbp_tuple); } }; class TensorWithOtherGenericCtorFunctor { public: Maybe operator()(const std::shared_ptr& other, const Optional>& dtype) const { // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); bool is_pinned = false; if (other->is_local()) { is_pinned = JUST(CHECK_JUST(other->AsLocalTensor())->is_pinned()); } return To(JUST(MakeTensorFromOtherTensor(other, is_pinned)), dtype, false); } }; class TensorWithDataGenericCtorFunctor { public: Maybe operator()(PyObject* data, const Symbol& dtype, const Optional>& device) const { // Treat the single long as shape. if (PyLong_Check(data)) { int64_t size = PyLong_AsLongLong(data); Shape shape(DimVector{size}); return TensorWithShapeGenericCtor(shape, dtype, device); } if (TensorSize_Check(data)) { return TensorWithShapeGenericCtor(TensorSize_AsShape(data), dtype, device); } // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); if (PyTensor_Check(data)) { const auto& other = PyTensor_Unpack(data); const bool pin_memory = other->is_local() ? JUST(JUST(other->AsLocalTensor())->is_pinned()) : false; return MakeTensorFromOtherTensor(other, dtype, device, /*requires_grad=*/false, /*pin_memory=*/pin_memory); } // Make tensor from python sequence or numpy array. return MakeLocalTensorFromData(data, dtype, device, /*requires_grad=*/false, /*pin_memory=*/false); } }; class GlobalTensorWithDataGenericCtorFunctor { public: Maybe operator()(PyObject* data, const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { JUST(CheckDeviceIdsIsValid(placement)); // Treat the single long as shape. if (PyLong_Check(data)) { int64_t size = PyLong_AsLongLong(data); Shape shape(DimVector{size}); return GlobalTensorWithShapeGenericCtor(shape, dtype, placement, sbp_tuple); } if (TensorSize_Check(data)) { return GlobalTensorWithShapeGenericCtor(TensorSize_AsShape(data), dtype, placement, sbp_tuple); } // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); if (PyTensor_Check(data)) { const auto& other = PyTensor_Unpack(data); return MakeTensorFromOtherTensor(other, dtype, placement, sbp_tuple, /*requires_grad=*/false); } // Make global tensor from python sequence or numpy array. return MakeGlobalTensorFromData(data, dtype, placement, sbp_tuple, /*requires_grad=*/false); } }; class TensorWithShapeGenericCtorFunctor { public: Maybe operator()(const Shape& shape, const Symbol& dtype, const Optional>& device) const { // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. 
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); Symbol device_; if (device) { device_ = JUST(device); } else { device_ = JUST(Device::New("cpu")); } return functional::Empty(shape, dtype, device_, /*requires_grad=*/false, /*pin_memory=*/false); } }; class GlobalTensorWithShapeGenericCtorFunctor { public: Maybe operator()(const Shape& shape, const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); JUST(CheckDeviceIdsIsValid(placement)); return functional::GlobalEmpty(shape, dtype, placement, sbp_tuple); } }; class AssignLocalTensorFunctor { public: AssignLocalTensorFunctor() { op_ = CHECK_JUST(one::OpBuilder("copy").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& x) const { // JUST(CheckInplaceValid(y)); // align check to torch CHECK_OR_RETURN(y->is_local() && x->is_local()) << "Both x and y must be local tensor."; std::shared_ptr src = x; if (y->dtype() != src->dtype()) { src = JUST(To(src, y->dtype(), false)); } auto device = JUST(y->device()); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("device", "pin_memory"); attrs.SetAllAttrs(device, false); TensorTuple outputs{y}; return OpInterpUtil::Dispatch(*op_, {x}, &outputs, attrs); } private: std::shared_ptr op_; }; static std::vector get_shape_or_stride_from_numpy(size_t ndim, npy_intp* values) { auto result = std::vector(ndim); for (size_t i = 0; i < ndim; ++i) { result[i] = static_cast(values[i]); } return result; } class LocalTensorSharedDlPackDataFunctor { public: LocalTensorSharedDlPackDataFunctor() {} Maybe operator()(PyObject* obj) const { DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(obj, "dltensor"); CHECK_NOTNULL_OR_RETURN(dlMTensor) << "from_dlpack received an invalid capsule. " "Note that DLTensor capsules can be consumed only once, " "so you might have already constructed a tensor from it once."; // `tensor` steals the ownership of the underlying storage. It also passes a // destructor function that will be called when the underlying storage goes // out of scope. When the destructor is called, the dlMTensor is destructed // too. auto tensor = fromDLPack(dlMTensor); // Make sure this capsule will never be used again. PyCapsule_SetName(obj, "used_dltensor"); return tensor; } }; class LocalTensorSharedNumpyDataFunctor { public: LocalTensorSharedNumpyDataFunctor() {} Maybe operator()(PyObject* obj) const { if (!PyArray_Check(obj)) { return Error::TypeError() << "expected np.ndarray, but got " << Py_TYPE(obj)->tp_name; } auto* array = reinterpret_cast(obj); const size_t ndim = PyArray_NDIM(array); std::vector sizes = get_shape_or_stride_from_numpy(ndim, PyArray_DIMS(array)); std::vector strides = get_shape_or_stride_from_numpy(ndim, PyArray_STRIDES(array)); // NumPy strides use bytes. OneFlow strides use element counts. // These checks are consistent with pytorch(v1.10.0): // https://github.com/pytorch/pytorch/blob/v1.10.0/torch/csrc/utils/tensor_numpy.cpp#L171 const auto element_size_in_bytes = PyArray_ITEMSIZE(array); for (auto& stride : strides) { if (stride % element_size_in_bytes != 0) { return Error::InvalidValueError() << "given numpy array strides not a multiple of the element byte size. 
" << "Copy the numpy array to reallocate the memory."; } stride /= element_size_in_bytes; } for (size_t i = 0; i < ndim; ++i) { if (strides[i] < 0) { return Error::InvalidValueError() << "At least one stride in the given numpy array is negative, " << "and tensors with negative strides are not currently supported. " << "(You can probably work around this by making a copy of your array " << " with array.copy().) "; } } void* data_ptr = PyArray_DATA(array); if (!PyArray_EquivByteorders(PyArray_DESCR(array)->byteorder, NPY_NATIVE)) { return Error::InvalidValueError() << "given numpy array has byte order different from the native byte order. " << "Conversion between byte orders is currently not supported."; } Py_INCREF(obj); // Build TensorMeta const auto shape = Shape(DimVector(sizes.begin(), sizes.end())); const auto stride = Stride(strides.begin(), strides.end()); DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(array)); Symbol device = JUST(Device::New("cpu")); auto tensor_meta = SymbolOf(LocalTensorMeta(shape, stride, data_type, MemoryFormat::kContiguous, device)); // Build TensorBuffer const auto& Free = [array](char* dptr) { CHECK_JUST(Singleton::Get()->WithScopedAcquire([&]() -> Maybe { Py_DECREF(array); return Maybe::Ok(); })); }; const auto array_size_in_bytes = PyArray_NBYTES(array); auto tensor_data = std::make_shared(false, device); tensor_data->set_blob_dptr( std::unique_ptr>(static_cast(data_ptr), Free), array_size_in_bytes); // Build TensorStorage: decrease ndarray reference count before releasing auto tensor_storage = std::make_shared(tensor_data); // Build Tensor auto tensor_impl = std::make_shared(tensor_storage, /*requires_grad=*/false, /*ls_leaf=*/true); // Init blob JUST(tensor_impl->InitEagerBlobObject(tensor_meta, NewLocalDepObject())); const auto& stream = JUST(GetDefaultStreamByDevice(device)); const auto& eager_blob_object = JUST(tensor_impl->eager_blob_object()); JUST(eager_blob_object->init_producer_stream(stream)); eager_blob_object->set_last_used_stream(stream); std::shared_ptr out(new LocalTensor(tensor_impl)); return out; } }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("TensorWithData"); m.add_functor("GlobalTensorWithData"); m.add_functor("TensorEmptyGenericCtor"); m.add_functor("GlobalTensorEmptyGenericCtor"); m.add_functor("TensorWithOtherGenericCtor"); m.add_functor("TensorWithDataGenericCtor"); m.add_functor("GlobalTensorWithDataGenericCtor"); m.add_functor("TensorWithShapeGenericCtor"); m.add_functor("GlobalTensorWithShapeGenericCtor"); m.add_functor("AssignLocalTensor"); m.add_functor("LocalTensorSharedNumpyData"); m.add_functor("TensorEmptyCtor", [](const Optional>& device) -> Maybe { return TensorEmptyGenericCtor(GetDefaultDType(), device); }); m.add_functor("GlobalTensorEmptyCtor", [](const Symbol& placement, const std::vector>& sbp_tuple) -> Maybe { return GlobalTensorEmptyGenericCtor(GetDefaultDType(), placement, sbp_tuple); }); m.add_functor("TensorWithOtherCtor", [](const std::shared_ptr& other) -> Maybe { return TensorWithOtherGenericCtor(other, NullOpt); }); m.add_functor("TensorWithDataCtor", [](PyObject* data, const Optional>& device) -> Maybe { return TensorWithDataGenericCtor(data, GetDefaultDType(), device); }); m.add_functor("GlobalTensorWithDataCtor", [](PyObject* data, const Symbol& placement, const std::vector>& sbp_tuple) -> Maybe { return GlobalTensorWithDataGenericCtor(data, GetDefaultDType(), placement, sbp_tuple); }); m.add_functor("TensorWithShapeCtor", [](const Shape& shape, const Optional>& device) 
-> Maybe { return TensorWithShapeGenericCtor(shape, GetDefaultDType(), device); }); m.add_functor("GlobalTensorWithShapeCtor", [](const Shape& shape, const Symbol& placement, const std::vector>& sbp_tuple) -> Maybe { return GlobalTensorWithShapeGenericCtor(shape, GetDefaultDType(), placement, sbp_tuple); }); m.add_functor("LocalTensorSharedDlPackData"); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/functional/tensor_api.yaml ================================================ # Copyright 2020 The OneFlow Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - name: "tensor" signature: [ "Tensor (PyObject* data, *, DataType dtype=None, Device device=None, Bool requires_grad=False, Bool pin_memory=False) => TensorWithData", "Tensor (PyObject* data, *, DataType dtype=None, Placement placement, SbpList sbp, Bool requires_grad=False) => GlobalTensorWithData", ] bind_python: True - name: "_legacy_tensor_generic_ctor" signature: [ "Tensor (*, DataType dtype, Device device=None) => TensorEmptyGenericCtor", "Tensor (*, DataType dtype, Placement placement, SbpList sbp) => GlobalTensorEmptyGenericCtor", "Tensor (Tensor other, *, DataType dtype=None) => TensorWithOtherGenericCtor", "Tensor (PyObject* data, *, DataType dtype, Device device=None) => TensorWithDataGenericCtor", "Tensor (PyObject* data, *, DataType dtype, Placement placement, SbpList sbp) => GlobalTensorWithDataGenericCtor", "Tensor (Shape size, *, DataType dtype, Device device=None) => TensorWithShapeGenericCtor", "Tensor (Shape size, *, DataType dtype, Placement placement, SbpList sbp) => GlobalTensorWithShapeGenericCtor", ] bind_python: True - name: "_legacy_tensor_ctor" signature: [ "Tensor (*, Device device=None) => TensorEmptyCtor", "Tensor (*, Placement placement, SbpList sbp) => GlobalTensorEmptyCtor", "Tensor (Tensor other) => TensorWithOtherCtor", "Tensor (PyObject* data, *, Device device=None) => TensorWithDataCtor", "Tensor (PyObject* data, *, Placement placement, SbpList sbp) => GlobalTensorWithDataCtor", "Tensor (Shape size, *, Device device=None) => TensorWithShapeCtor", "Tensor (Shape size, *, Placement placement, SbpList sbp) => GlobalTensorWithShapeCtor", ] bind_python: True - name: "assign_local_tensor" signature: "Void (Tensor ref, Tensor value) => AssignLocalTensor" bind_python: True - name: "from_numpy" signature: "Tensor (PyObject* obj) => LocalTensorSharedNumpyData" bind_python: True - name: "from_dlpack" signature: "Tensor (PyObject* obj) => LocalTensorSharedDlPackData" bind_python: True ================================================ FILE: oneflow/api/python/functional/value_types.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
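The yaml above is the binding table: each entry maps a Python-visible name ("tensor", "from_numpy", "from_dlpack", ...) onto the functor overloads implemented in tensor_api.cpp. A short sketch of two behaviors that are visible from Python, assuming a CPU tensor:

    import numpy as np
    import oneflow as flow

    # from_numpy (LocalTensorSharedNumpyData) shares the ndarray buffer instead of
    # copying it; the functor Py_INCREFs the array and releases it only when the
    # tensor storage is freed.
    a = np.arange(6, dtype=np.float32)
    t = flow.from_numpy(a)
    t[0] = 100.0
    print(a[0])          # 100.0 -- the tensor and the array alias the same memory

    # tensor (TensorWithData) copies, and copy-constructing from an existing tensor
    # emits the UserWarning raised via PyErr_WarnEx in TensorWithDataFunctor.
    u = flow.tensor(t)   # warns: prefer sourceTensor.clone().detach()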
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/functional/value_types.h"
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/hash_container.h"

namespace oneflow {
namespace one {
namespace functional {

HashMap<ValueType, std::string>* GetValueTypeNameMap() {
  static HashMap<ValueType, std::string> value_type_name_map = {
      {kVOID, "void"},           {kINT32, "int32"},           {kUINT32, "unsigned int32"},
      {kINT64, "int64"},         {kUINT64, "unsigned int64"}, {kFLOAT, "float"},
      {kDOUBLE, "double"},       {kBOOL, "bool"},             {kSTRING, "string"},
      {kINT32_LIST, "int32 list"},          {kUINT32_LIST, "unsigned int32 list"},
      {kINT64_LIST, "int64 list"},          {kUINT64_LIST, "unsigned int64 list"},
      {kFLOAT_LIST, "float list"},          {kDOUBLE_LIST, "double list"},
      {kBOOL_LIST, "bool list"},            {kSTRING_LIST, "string list"},
      {kVOID_MAYBE, "maybe void"},          {kBOOL_MAYBE, "maybe bool"},
      {kSCALAR, "scalar"},                  {kTENSOR, "tensor"},
      {kTENSOR_REF, "tensor"},              {kTENSOR_MAYBE, "maybe tensor"},
      {kTENSOR_TUPLE, "tensor tuple"},      {kTENSOR_TUPLE_REF, "tensor tuple"},
      {kTENSOR_TUPLE_MAYBE, "maybe tensor tuple"},
      {kATTR, "attr"},                      {kATTR_REF, "attr"},
      {kDTYPE, "data type"},                {kDTYPE_LIST, "data type list"},
      {kSHAPE, "shape"},                    {kSHAPE_LIST, "shape list"},
      {kGENERATOR, "generator"},            {kGENERATOR_REF, "generator"},
      {kGENERATOR_MAYBE, "maybe generator"},
      {kTENSOR_INDEX, "index"},             {kDEVICE, "device"},
      {kPARALLEL_DESC, "placement"},        {kSBP_PARALLEL, "sbp"},
      {kSBP_PARALLEL_LIST, "sbp list"},     {kOPEXPR, "opexpr"},
      {kOPEXPR_REF, "opexpr"},              {kPY_OBJECT, "python object"},
      {kLAYOUT, "layout"},                  {kMEMORY_FORMAT, "memory format"},
      {kCOMPLEX_FLOAT, "complex float"},    {kCOMPLEX_DOUBLE, "complex double"},
      {kCHAR, "char"},                      {kINT16, "int16"}};
  return &value_type_name_map;
}

const std::string& ValueTypeName(ValueType type) {
  const auto* type_name_map = GetValueTypeNameMap();
  const auto& it = type_name_map->find(type);
  CHECK_OR_THROW(it != type_name_map->end()) << "Value type " << type << " has no type name.";
  return it->second;
}

bool IsIntegralType(ValueType type) { return type >= kINT32 && type < kINTEGRAL_MASK; }
bool IsIntegralListType(ValueType type) {
  return type >= kINT32_LIST && type < kINTEGRAL_LIST_MASK;
}
bool IsFloatingType(ValueType type) { return type >= kFLOAT && type < kFLOATING_MASK; }
bool IsFloatingListType(ValueType type) {
  return type >= kFLOAT_LIST && type < kFLOATING_LIST_MASK;
}

}  // namespace functional
}  // namespace one
}  // namespace oneflow


================================================
FILE: oneflow/api/python/functional/value_types.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_ #define ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_ #include #include #include #undef _PyGC_FINALIZED #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/memory_format.pb.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/layout.h" namespace oneflow { class Scalar; class Shape; template class Symbol; class Device; class ParallelDesc; class SbpParallel; namespace one { class Tensor; class TensorTuple; class Generator; class OpExpr; namespace functional { class TensorIndex; } // namespace functional } // namespace one namespace one { namespace functional { enum ValueType : int { kINVALID = 0, kVOID, // Integral kINT32, kINT64, kUINT32, kUINT64, kINTEGRAL_MASK = 10, // Floating kFLOAT, kDOUBLE, kFLOATING_MASK = 15, kBOOL, kSTRING, // Integral list kINT32_LIST = 50, kUINT32_LIST, kINT64_LIST, kUINT64_LIST, kINTEGRAL_LIST_MASK = 60, // Floating list kFLOAT_LIST, kDOUBLE_LIST, kFLOATING_LIST_MASK = 65, kBOOL_LIST, kSTRING_LIST, kVOID_MAYBE = 100, kBOOL_MAYBE, kSCALAR = 200, kTENSOR, kTENSOR_REF, kTENSOR_MAYBE, kTENSOR_TUPLE, kTENSOR_TUPLE_REF, kTENSOR_TUPLE_MAYBE, kATTR, kATTR_REF, kDTYPE, kSHAPE, kLAYOUT, kSHAPE_MAYBE, kGENERATOR, kGENERATOR_REF, kGENERATOR_MAYBE, kTENSOR_INDEX, kDEVICE, kPARALLEL_DESC, kSBP_PARALLEL, kSBP_PARALLEL_LIST, kSHAPE_LIST, kDTYPE_LIST, kMEMORY_FORMAT, kOPEXPR = 390, kOPEXPR_REF, kPY_OBJECT = 400, // Complex kCOMPLEX_FLOAT, kCOMPLEX_DOUBLE, kCHAR, kINT16 }; #define VALUE_TYPE_OF_IMPL(cpp_type, value_type) \ template::value, int>::type = 0> \ inline ValueType ValueTypeOf() { \ return value_type; \ } \ template>::value, int>::type = 0> \ inline ValueType ValueTypeOf() { \ return value_type; \ } VALUE_TYPE_OF_IMPL(void, kVOID); VALUE_TYPE_OF_IMPL(int32_t, kINT32); VALUE_TYPE_OF_IMPL(int16_t, kINT16); VALUE_TYPE_OF_IMPL(char, kCHAR); VALUE_TYPE_OF_IMPL(uint32_t, kUINT32); VALUE_TYPE_OF_IMPL(int64_t, kINT64); VALUE_TYPE_OF_IMPL(uint64_t, kUINT64); VALUE_TYPE_OF_IMPL(float, kFLOAT); VALUE_TYPE_OF_IMPL(double, kDOUBLE); VALUE_TYPE_OF_IMPL(bool, kBOOL); VALUE_TYPE_OF_IMPL(std::string, kSTRING); VALUE_TYPE_OF_IMPL(std::vector, kINT32_LIST); VALUE_TYPE_OF_IMPL(std::vector, kUINT32_LIST); VALUE_TYPE_OF_IMPL(std::vector, kINT64_LIST); VALUE_TYPE_OF_IMPL(std::vector, kUINT64_LIST); VALUE_TYPE_OF_IMPL(std::vector, kFLOAT_LIST); VALUE_TYPE_OF_IMPL(std::vector, kDOUBLE_LIST); VALUE_TYPE_OF_IMPL(std::vector, kBOOL_LIST); VALUE_TYPE_OF_IMPL(std::vector, kSTRING_LIST); VALUE_TYPE_OF_IMPL(Maybe, kVOID_MAYBE); VALUE_TYPE_OF_IMPL(Maybe, kBOOL_MAYBE); VALUE_TYPE_OF_IMPL(Scalar, kSCALAR); VALUE_TYPE_OF_IMPL(one::Tensor, kTENSOR); VALUE_TYPE_OF_IMPL(std::shared_ptr, kTENSOR_REF); VALUE_TYPE_OF_IMPL(Maybe, kTENSOR_MAYBE); VALUE_TYPE_OF_IMPL(one::TensorTuple, kTENSOR_TUPLE); VALUE_TYPE_OF_IMPL(std::shared_ptr, kTENSOR_TUPLE_REF); VALUE_TYPE_OF_IMPL(Maybe, kTENSOR_TUPLE_MAYBE); VALUE_TYPE_OF_IMPL(Symbol, kDTYPE); VALUE_TYPE_OF_IMPL(Symbol, kLAYOUT); VALUE_TYPE_OF_IMPL(std::vector>, kDTYPE_LIST); VALUE_TYPE_OF_IMPL(Shape, kSHAPE); VALUE_TYPE_OF_IMPL(Maybe, kSHAPE_MAYBE); VALUE_TYPE_OF_IMPL(std::vector, kSHAPE_LIST); VALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR); VALUE_TYPE_OF_IMPL(std::shared_ptr, kGENERATOR_REF); VALUE_TYPE_OF_IMPL(Maybe, kGENERATOR_MAYBE); VALUE_TYPE_OF_IMPL(TensorIndex, kTENSOR_INDEX); VALUE_TYPE_OF_IMPL(Symbol, kDEVICE); VALUE_TYPE_OF_IMPL(Symbol, kPARALLEL_DESC); VALUE_TYPE_OF_IMPL(Symbol, 
kSBP_PARALLEL); VALUE_TYPE_OF_IMPL(std::vector>, kSBP_PARALLEL_LIST); VALUE_TYPE_OF_IMPL(MemoryFormat, kMEMORY_FORMAT); VALUE_TYPE_OF_IMPL(one::OpExpr, kOPEXPR); VALUE_TYPE_OF_IMPL(std::shared_ptr, kOPEXPR_REF); VALUE_TYPE_OF_IMPL(PyObject*, kPY_OBJECT); VALUE_TYPE_OF_IMPL(const PyObject*, kPY_OBJECT); VALUE_TYPE_OF_IMPL(std::complex, kCOMPLEX_FLOAT); VALUE_TYPE_OF_IMPL(std::complex, kCOMPLEX_DOUBLE); #undef VALUE_TYPE_OF_IMPL const std::string& ValueTypeName(ValueType type); bool IsIntegralType(ValueType type); bool IsIntegralListType(ValueType type); bool IsFloatingType(ValueType type); bool IsFloatingListType(ValueType type); } // namespace functional } // namespace one } // namespace oneflow namespace std { template<> struct hash { std::size_t operator()(oneflow::one::functional::ValueType v) const noexcept { return v; } }; } // namespace std #endif // ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_ ================================================ FILE: oneflow/api/python/gil_foreign_lock_helper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/foreign_lock_helper.h" #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/singleton.h" namespace py = pybind11; namespace oneflow { class GILForeignLockHelper final : public ForeignLockHelper { Maybe WithScopedRelease(const std::function()>& Callback) const override { if (PyGILState_Check()) { py::gil_scoped_release release; JUST(Callback()); } else { JUST(Callback()); } return Maybe::Ok(); } Maybe WithScopedAcquire(const std::function()>& Callback) const override { if (!PyGILState_Check()) { py::gil_scoped_acquire acquire; JUST(Callback()); } else { JUST(Callback()); } return Maybe::Ok(); } }; ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("RegisterGILForeignLockHelper", []() { Singleton::Delete(); Singleton::SetAllocated(new GILForeignLockHelper()); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/init.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/job/cluster_instruction.h" namespace py = pybind11; PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MAKE_OPAQUE(std::unordered_map>>); namespace oneflow { namespace { using IntList = std::vector; using Int2IntListMap = std::unordered_map>; bool Int2IntListMapContaining(const Int2IntListMap& bigger, const Int2IntListMap& smaller) { for (const auto& pair : smaller) { if (bigger.find(pair.first) == bigger.end()) { return false; } const auto& bigger_device_ids = bigger.find(pair.first)->second; std::vector::iterator ret; for (int64_t device_id : *pair.second) { ret = std::find(bigger_device_ids->begin(), bigger_device_ids->end(), device_id); if (ret == bigger_device_ids->end()) { return false; } } } return true; } } // namespace PYBIND11_MODULE(_oneflow_internal, m) { using IntList = std::vector; using Int2IntListMap = std::unordered_map>; py::module_ oneflow_api_util = m.def_submodule("util"); py::class_>(oneflow_api_util, "IntList") .def(py::init<>()) .def("__len__", [](const std::shared_ptr& v) { return v->size(); }) .def( "items", [](std::shared_ptr& v) { return py::make_iterator(v->begin(), v->end()); }, py::keep_alive<0, 1>()) .def("__getitem__", (IntList::reference & (IntList::*)(IntList::size_type pos)) & IntList::at) .def( "__iter__", [](std::shared_ptr& v) { return py::make_iterator(v->begin(), v->end()); }, py::keep_alive<0, 1>()) .def("__eq__", [](std::shared_ptr& lhs, std::shared_ptr& rhs) { return *lhs == *rhs; }); py::class_>(oneflow_api_util, "Int2IntListMap") .def(py::init<>()) .def("__len__", [](const std::shared_ptr& v) { return v->size(); }) .def( "items", [](std::shared_ptr& v) { return py::make_iterator(v->begin(), v->end()); }, py::keep_alive<0, 1>()) .def("__getitem__", (Int2IntListMap::mapped_type & (Int2IntListMap::*)(const Int2IntListMap::key_type& pos)) & Int2IntListMap::operator[]) .def( "__iter__", [](std::shared_ptr& v) { return py::make_iterator(v->begin(), v->end()); }, py::keep_alive<0, 1>()) .def("__eq__", [](std::shared_ptr& lhs, std::shared_ptr& rhs) { return Int2IntListMapContaining(*lhs, *rhs) && Int2IntListMapContaining(*rhs, *lhs); }); ::oneflow::OneflowModuleRegistry().ImportAll(m); } } // namespace oneflow ================================================ FILE: oneflow/api/python/ir.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
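PYBIND11_MODULE(_oneflow_internal, m) above is the root of the Python extension: it binds the opaque IntList / Int2IntListMap helpers into a util submodule and then replays every registered API through OneflowModuleRegistry().ImportAll(m). A minimal sketch of poking at those bindings, assuming a normal oneflow install:

    import oneflow

    util = oneflow._oneflow_internal.util

    a = util.IntList()
    b = util.IntList()
    print(len(a))    # 0
    print(a == b)    # True: __eq__ compares the underlying std::vector<int64_t>
    print(list(a))   # [] -- __iter__ is bound with py::make_iterator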
*/ #include "oneflow/core/common/singleton.h" #include "oneflow/ir/oneflow-extension/include/PyAst/Ast.h" #include #include #include #include #include #include #include #include #include #ifdef WITH_MLIR #include "oneflow/ir/include/OneFlow/Extension.h" #include "oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h" #include "oneflow/ir/oneflow-extension/include/OneFlow/OneFlowLRJITRegistry.h" #include "oneflow/api/python/of_api_registry.h" #include #include #include namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("ir", m) { m.def("load_jit_shared_lib", [](const std::string& lib_path) { MutSharedLibPaths()->insert(lib_path); }); // TODO: this may be move to a common place for create global singleton. m.def("create_global_lr_jit", []() { Singleton::New(); }); m.def("compile_and_register_lr_jit", [](const std::string& function_id, std::shared_ptr& func, bool is_dump) { Singleton::Get()->Register(function_id, *func.get(), is_dump); }); // look up and execute the registered function for python api m.def("get_lr", [](const std::string& function_id, float base_lr, float step) { auto engine = Singleton::Get()->LookUp(function_id); return engine(base_lr, step); }); pybind11::class_>(m, "smt"); pybind11::class_>(m, "expr"); pybind11::class_>( m, "FunctionDef"); m.def("FunctionDef_", &pyast::FunctionDef::FunctionDef_); pybind11::class_>(m, "Return"); m.def("Return_", &pyast::Return::Return_); pybind11::class_>(m, "Assign"); m.def("Assign_", &pyast::Assign::Assign_); pybind11::class_>(m, "If"); m.def("If_", &pyast::If::If_); pybind11::class_>(m, "Raise"); m.def("Raise_", &pyast::Raise::Raise_); pybind11::class_>(m, "Assert"); m.def("Assert_", &pyast::Assert::Assert_); pybind11::class_>(m, "Expr"); m.def("Expr_", &pyast::Expr::Expr_); pybind11::class_>(m, "BoolOp"); m.def("BoolOp_", &pyast::BoolOp::BoolOp_); pybind11::class_>(m, "BinOp"); m.def("BinOp_", &pyast::BinOp::BinOp_); pybind11::class_>(m, "Lambda"); m.def("Lambda_", &pyast::Lambda::Lambda_); pybind11::class_>(m, "Compare"); m.def("Compare_", &pyast::Compare::Compare_); pybind11::class_>(m, "Call"); m.def("Call_", &pyast::Call::Call_); pybind11::class_>(m, "Num"); m.def("Num_", &pyast::Num::Num_); pybind11::class_>(m, "Constant"); m.def("Constant_", &pyast::Constant::Constant_); pybind11::class_>(m, "Attribute"); m.def("Attribute_", &pyast::Attribute::Attribute_); pybind11::class_>(m, "Name"); m.def("Name_", &pyast::Name::Name_); pybind11::class_>(m, "arguments"); m.def("arguments_", &pyast::arguments::arguments_); pybind11::class_>(m, "arg"); m.def("arg_", &pyast::arg::arg_); } } // namespace oneflow #endif // WITH_MLIR ================================================ FILE: oneflow/api/python/job_build/job_build_and_infer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/job_build/job_build_and_infer.h" namespace py = pybind11; namespace oneflow { Maybe MarkVariableGradients(const one::TensorTuple& variables, const one::TensorTuple& gradients) { CHECK_OR_RETURN(LazyMode::is_enabled()); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(variables.size(), gradients.size()); // NOLINT(maybe-need-error-msg) HashMap variable_grad_lbns; for (int i = 0; i < variables.size(); ++i) { const std::string& variable_lbn = one::TensorNameScope::Global()->Lookup(variables[i]); CHECK_OR_RETURN(!variable_lbn.empty()) << "variable which index is " << i << " expected to have a tensor name"; const std::string& gradient_lbn = one::TensorNameScope::Global()->Lookup(gradients[i]); CHECK_OR_RETURN(!gradient_lbn.empty()) << "gradient which index is " << i << " expected to have a tensor name"; variable_grad_lbns.emplace(variable_lbn, gradient_lbn); } return JUST(GetCurInferCtx())->MarkVariableGradientBlobNames(variable_grad_lbns); } Maybe MarkOutputGradients(const one::TensorTuple& outputs, const one::TensorTuple& gradients) { CHECK_OR_RETURN(LazyMode::is_enabled()); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), gradients.size()); // NOLINT(maybe-need-error-msg) HashMap output_gradient_lbns; for (int i = 0; i < outputs.size(); ++i) { const std::string& output_lbn = one::TensorNameScope::Global()->Lookup(outputs[i]); CHECK_OR_RETURN(!output_lbn.empty()) << "output which index is " << i << " expected to have a tensor name"; const std::string& gradient_lbn = one::TensorNameScope::Global()->Lookup(gradients[i]); CHECK_OR_RETURN(!gradient_lbn.empty()) << "gradient which index is " << i << " expected to have a tensor name"; output_gradient_lbns.emplace(output_lbn, gradient_lbn); } return JUST(GetCurInferCtx())->MarkOutputGradientBlobNames(output_gradient_lbns); } ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("JobBuildAndInferCtx_Open", &JobBuildAndInferCtx_Open); m.def("JobBuildAndInferCtx_GetCurrentJobName", &JobBuildAndInferCtx_GetCurrentJobName); m.def("JobBuildAndInferCtx_GetCurrentJobId", &JobBuildAndInferCtx_GetCurrentJobId); m.def("JobBuildAndInferCtx_Close", &JobBuildAndInferCtx_Close); m.def("CurJobBuildAndInferCtx_SetJobConf", &CurJobBuildAndInferCtx_SetJobConf); m.def("CurJobBuildAndInferCtx_Complete", &CurJobBuildAndInferCtx_Complete, py::call_guard()); } } // namespace oneflow ================================================ FILE: oneflow/api/python/job_build/job_build_and_infer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
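The functions bound here (JobBuildAndInferCtx_Open/Close, CurJobBuildAndInferCtx_SetJobConf/Complete, and the gradient-marking helpers) are the C++ half of nn.Graph tracing: Python opens a context, runs build() under lazy mode, marks losses and gradients, and completes the job. A rough usage sketch at the public nn.Graph level:

    import oneflow as flow

    class LinearGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.linear = flow.nn.Linear(4, 4)

        def build(self, x):
            return self.linear(x)

    g = LinearGraph()
    # The first call traces build(): roughly JobBuildAndInferCtx_Open, op adding
    # under LazyMode, then CurJobBuildAndInferCtx_Complete and compilation.
    y = g(flow.randn(2, 4))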
*/ #ifndef ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_ #define ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_ #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/job/job_build_and_infer_ctx.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/record/record.pb.h" namespace oneflow { inline Maybe JobBuildAndInferCtx_Open(const std::string& job_name) { auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr()); return mgr->OpenJobBuildAndInferCtx(job_name); } inline Maybe JobBuildAndInferCtx_GetCurrentJobName() { auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr()); return mgr->GetCurrentJobName(); } inline Maybe JobBuildAndInferCtx_GetCurrentJobId() { return JUST(GetCurInferCtx())->job_id(); } inline Maybe JobBuildAndInferCtx_Close() { auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr()); JUST(mgr->CloseCurrentJobBuildAndInferCtx()); return Maybe::Ok(); } inline Maybe CurJobBuildAndInferCtx_SetJobConf(const std::string& job_conf_str) { JobConfigProto job_conf; CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << "job conf parse failed"; return JUST(GetCurInferCtx())->SetJobConf(job_conf); } inline Maybe CurJobBuildAndInferCtx_Complete() { return JUST(GetCurInferCtx())->Complete(); } inline Maybe AddTensorAsGraphLoss(const std::shared_ptr& t) { CHECK_OR_RETURN(t->is_lazy()); CHECK_OR_RETURN(LazyMode::is_enabled()); const std::string& loss_lbn = one::TensorNameScope::Global()->Lookup(t); CHECK_OR_RETURN("" != loss_lbn); return JUST(GetCurInferCtx())->AddLossLogicalBlobName(loss_lbn); } Maybe MarkVariableGradients(const one::TensorTuple& variables, const one::TensorTuple& gradients); Maybe MarkOutputGradients(const one::TensorTuple& outputs, const one::TensorTuple& gradients); } // namespace oneflow #endif // ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_ ================================================ FILE: oneflow/api/python/job_build/lazy_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/job/lazy_mode.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("lazy_mode", m) { py::class_>(m, "guard") .def(py::init( [](const bool is_enabled) { return std::make_shared(is_enabled); })) .def("__enter__", [](const LazyMode::Guard& guard_obj) {}) .def("__exit__", [](const LazyMode::Guard& guard_obj, const py::object& type, const py::object& value, const py::object& traceback) {}); m.def("is_enabled", []() { return LazyMode::is_enabled(); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/multiprocessing/init.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/multiprocessing/object_ptr.h" #include "oneflow/core/ep/cpu/cpu_device_manager.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include #include #if defined(__linux__) #include #include #endif #define SYSASSERT(rv, ...) \ if ((rv) < 0) { throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); } namespace oneflow { namespace multiprocessing { namespace py = pybind11; void multiprocessing_init() { auto multiprocessing_module = OFObjectPtr(PyImport_ImportModule("oneflow.multiprocessing")); if (!multiprocessing_module) { throw std::runtime_error("multiprocessing init error >> multiprocessing_module init fail!"); } auto module = py::handle(multiprocessing_module).cast(); module.def("_prctl_pr_set_pdeathsig", [](int signal) { #if defined(__linux__) auto rv = prctl(PR_SET_PDEATHSIG, signal); SYSASSERT(rv, "prctl"); #endif }); // Py_RETURN_TRUE; } void set_num_threads(int num) { int64_t cpu_logic_core = std::thread::hardware_concurrency(); if (num <= 0) { py::print("Warning : ", num, " less than 1 will be set to 1."); num = 1; } else if (num >= cpu_logic_core) { py::print("Warning : ", num, " is greater than the number of logical cores and will be set to the maximum number " "of logical cores ", cpu_logic_core); num = cpu_logic_core; } auto cpu_device = std::static_pointer_cast( Singleton::Get()->GetDevice(DeviceType::kCPU, 0)); cpu_device->SetNumThreads(num); } ONEFLOW_API_PYBIND11_MODULE("", m) { py::options options; options.disable_function_signatures(); m.def("_multiprocessing_init", &multiprocessing_init); m.def("_set_num_threads", &set_num_threads); options.disable_function_signatures(); } } // namespace multiprocessing } // namespace oneflow ================================================ FILE: oneflow/api/python/multiprocessing/object_ptr.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
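lazy_mode.guard is an RAII object: LazyMode::Guard flips the mode when the Python object is constructed and restores it when the object is destroyed, while __enter__/__exit__ are deliberate no-ops so it can still be used as a context manager. A small sketch, assuming the default eager context:

    import oneflow

    lazy = oneflow._oneflow_internal.lazy_mode

    print(lazy.is_enabled())       # False in ordinary eager code
    with lazy.guard(True):
        print(lazy.is_enabled())   # True while the guard object is alive
    print(lazy.is_enabled())       # False again once the guard is released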
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/python/multiprocessing/object_ptr.h" template<> void OFPointer::free() { if (ptr) Py_DECREF(ptr); } template class OFPointer; ================================================ FILE: oneflow/api/python/multiprocessing/object_ptr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include "oneflow/api/python/of_api_registry.h" // reference: pytorch/torch/csrc/utils/object_ptr.h // https://github.com/pytorch/pytorch/blob/d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa/torch/csrc/utils/object_ptr.h template class OFPointer { public: OFPointer() : ptr(nullptr){}; explicit OFPointer(T* ptr) noexcept : ptr(ptr){}; OFPointer(OFPointer&& p) noexcept { free(); ptr = p.ptr; p.ptr = nullptr; }; ~OFPointer() { free(); }; T* get() { return ptr; } const T* get() const { return ptr; } T* release() { T* tmp = ptr; ptr = nullptr; return tmp; } operator T*() { return ptr; } OFPointer& operator=(T* new_ptr) noexcept { free(); ptr = new_ptr; return *this; } OFPointer& operator=(OFPointer&& p) noexcept { free(); ptr = p.ptr; p.ptr = nullptr; return *this; } T* operator->() { return ptr; } explicit operator bool() const { return ptr != nullptr; } private: void free(); T* ptr = nullptr; }; /** * An RAII-style, owning pointer to a PyObject. You must protect * destruction of this object with the GIL. * * WARNING: Think twice before putting this as a field in a C++ * struct. This class does NOT take out the GIL on destruction, * so if you will need to ensure that the destructor of your struct * is either (a) always invoked when the GIL is taken or (b) takes * out the GIL itself. Easiest way to avoid this problem is to * not use THPPointer in this situation. */ using OFObjectPtr = OFPointer; ================================================ FILE: oneflow/api/python/multiprocessing/shared_memory.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/ipc/shared_memory.h" namespace oneflow { namespace py = pybind11; ONEFLOW_API_PYBIND11_MODULE("multiprocessing", m) { py::class_>(m, "SharedMemory") .def(py::init([](const std::string& name, bool create, size_t size) { if (create) { return ipc::SharedMemory::Open(size, create).GetPtrOrThrow(); } return ipc::SharedMemory::Open(name, create).GetPtrOrThrow(); }), py::arg("name") = "", py::arg("create") = false, py::arg("size") = 0) .def("close", &ipc::SharedMemory::Close) .def("unlink", &ipc::SharedMemory::Unlink) .def_property_readonly("buf", [](ipc::SharedMemory* shm) { return py::memoryview::from_memory(shm->mut_buf(), shm->size()); }) .def_property_readonly("name", &ipc::SharedMemory::name) .def_property_readonly("size", &ipc::SharedMemory::size); m.def("unlink_all_shared_memory", []() { return ipc::SharedMemoryManager::get().UnlinkAllShms(); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/numpy/init_numpy_c_api.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/maybe.h" #include "oneflow/extension/python/numpy.h" namespace py = pybind11; ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("InitNumpyCAPI", []() { return oneflow::numpy::InitNumpyCAPI(); }); } ================================================ FILE: oneflow/api/python/of_api_registry.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/python/of_api_registry.h" namespace oneflow { namespace { // If different APIs are registered under the same path, the BuildModuleFuntion of which will be // saved in the corresponding vector. 
using SubModuleMap = std::map>>; SubModuleMap* GetSubModuleMap() { static SubModuleMap sub_module_map; return &sub_module_map; } } // namespace void OneflowModuleRegistry::Register(std::string module_path, std::function BuildModule) { (*GetSubModuleMap())[module_path].emplace_back(BuildModule); } void OneflowModuleRegistry::ImportAll(pybind11::module& m) { for (const auto& pair : (*GetSubModuleMap())) { for (const auto& BuildModule : pair.second) { BuildSubModule(pair.first, m, BuildModule); } } } void OneflowModuleRegistry::BuildSubModule( const std::string& module_path, pybind11::module& m, const std::function& BuildModule) { if (module_path.empty()) { BuildModule(m); return; } size_t dot_pos = module_path.find("."); if (dot_pos == std::string::npos) { pybind11::module sub_module = m.def_submodule(module_path.data()); BuildModule(sub_module); } else { const std::string& sub_module_name = module_path.substr(0, dot_pos); pybind11::module sub_module = m.def_submodule(sub_module_name.data()); BuildSubModule(module_path.substr(dot_pos + 1), sub_module, BuildModule); } } } // namespace oneflow ================================================ FILE: oneflow/api/python/of_api_registry.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_ #define ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_ #include #include #include #include #include "oneflow/api/python/caster/maybe.h" #include "oneflow/api/python/caster/optional.h" #include "oneflow/api/python/caster/size.h" #include "oneflow/api/python/caster/tensor.h" #include "oneflow/api/python/caster/autograd_function_state.h" #include "oneflow/core/common/preprocessor.h" namespace oneflow { class OneflowModuleRegistry { public: OneflowModuleRegistry() = default; ~OneflowModuleRegistry() = default; void Register(std::string module_path, std::function BuildModule); void ImportAll(pybind11::module& m); private: void BuildSubModule(const std::string& module_path, pybind11::module& m, const std::function& BuildModule); }; } // namespace oneflow #define ONEFLOW_API_PYBIND11_MODULE(module_path, m) \ static void OF_PP_CAT(OneflowApiPythonModule, __LINE__)(pybind11::module&); \ namespace { \ struct OfApiRegistryInit { \ OfApiRegistryInit() { \ ::oneflow::OneflowModuleRegistry().Register(module_path, \ &OF_PP_CAT(OneflowApiPythonModule, __LINE__)); \ } \ }; \ OfApiRegistryInit of_api_registry_init; \ } \ static void OF_PP_CAT(OneflowApiPythonModule, __LINE__)(pybind11::module & m) #endif // ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_ ================================================ FILE: oneflow/api/python/profiler.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
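of_api_registry.h is the glue used by every file in this directory: ONEFLOW_API_PYBIND11_MODULE(path, m) stashes a builder function keyed by a dotted module path, and ImportAll (called from init.cpp earlier) replays them, creating nested submodules as needed. From Python this simply shows up as attributes of _oneflow_internal, for example:

    import oneflow

    internal = oneflow._oneflow_internal
    # Registered with ONEFLOW_API_PYBIND11_MODULE("multiprocessing", m).
    print(internal.multiprocessing.SharedMemory)
    # Registered with ONEFLOW_API_PYBIND11_MODULE("profiler", m).
    print(internal.profiler.RangePush)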
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/profiler/profiler.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("profiler", m) { m.def("RangePush", [](const std::string& str) { OF_PROFILER_RANGE_PUSH(str); }); m.def("RangePop", []() { OF_PROFILER_RANGE_POP(); }); m.def("ProfilerStart", []() { profiler::ProfilerStart(); }); m.def("ProfilerStop", []() { profiler::ProfilerStop(); }); m.def("EnableProfiler", &profiler::EnableProfiler); m.def("DisableProfilerAndReturnResult", &profiler::DisableProfilerAndReturnResult); m.def("StartRecord", &profiler::StartRecord); m.def("EndRecord", &profiler::EndRecord); } } // namespace oneflow ================================================ FILE: oneflow/api/python/registry/registry.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/registry_error.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("CheckAndClearRegistryFlag", &CheckAndClearRegistryFlag); } } // namespace oneflow ================================================ FILE: oneflow/api/python/remat/remat.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
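The profiler module wraps named range markers and the record/replay profiler entry points; the oneflow.profiler Python package layers on top of these bindings. A minimal sketch using the internal bindings directly:

    import oneflow as flow

    prof = flow._oneflow_internal.profiler

    prof.RangePush("forward")            # open a named range
    y = flow.randn(4, 4).softmax(dim=1)
    prof.RangePop()                      # close the most recently opened range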
*/ #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/vm/remat/allocator.h" #include "oneflow/core/vm/remat/env.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/eager/tensor_storage.h" namespace py = pybind11; namespace oneflow { namespace { Maybe rematable_storage(const std::shared_ptr& tensor) { auto ret = std::dynamic_pointer_cast( JUST(tensor->eager_blob_object())->tensor_storage()); CHECK_NOTNULL_OR_RETURN(ret); return ret; } } // namespace ONEFLOW_API_PYBIND11_MODULE("remat", m) { m.def("is_in_memory", [](const std::shared_ptr& tensor) -> Maybe { return JUST(rematable_storage(tensor))->is_in_memory(); }); m.def("allocated_memory", [](const std::string& device_str) -> Maybe { auto device = JUST(Device::ParseAndNew(device_str)); return Singleton::Get() ->CreateOrGetAllocator(device->enum_type(), device->device_id()) ->allocated_memory(); }); m.def("display", [](const std::string& device_str) -> Maybe { auto device = JUST(Device::ParseAndNew(device_str)); Singleton::Get() ->CreateOrGetAllocator(device->enum_type(), device->device_id()) ->DisplayAllPieces(); return Maybe::Ok(); }); m.def("remat", [](const std::shared_ptr& t) -> Maybe { // TODO: an instruction JUST(rematable_storage(t))->Remat(); return Maybe::Ok(); }); m.def("evict", [](const std::shared_ptr& t) -> Maybe { // TODO: an instruction JUST(rematable_storage(t))->Evict(false); return Maybe::Ok(); }); m.def("is_evictable", [](const std::shared_ptr& t) -> Maybe { return JUST(rematable_storage(t))->is_evictable(); }); m.def("disable_eviction", [](const std::shared_ptr& t) -> Maybe { JUST(rematable_storage(t))->set_eviction_disabled(true); return Maybe::Ok(); }); m.def("clear_compute_op", [](const std::shared_ptr& t) -> Maybe { JUST(rematable_storage(t))->clear_compute_op(); return Maybe::Ok(); }); m.def("clear_stats", []() { Singleton::Get()->clear_stats(); }); m.def("forced_eviction_num", []() { return Singleton::Get()->forced_eviction_num(); }); m.def("eager_eviction_num", []() { return Singleton::Get()->eager_eviction_num(); }); m.def("recomputation_num", []() { return Singleton::Get()->recomputation_num(); }); m.def("set_budget_in_bytes", [](size_t budget_in_bytes) { Singleton::Get()->set_budget_in_bytes(budget_in_bytes); }); m.def("budget_in_bytes", []() { return Singleton::Get()->budget_in_bytes(); }); m.def("set_small_pieces_optimization", [](bool enabled) { return Singleton::Get()->set_small_pieces_optimization(enabled); }); m.def("is_small_pieces_optimization_enabled", []() { return Singleton::Get()->is_small_pieces_optimization_enabled(); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/rpc/ccl.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
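The remat module exposes knobs and counters for rematerialization-based memory management; most entries are thin getters or setters on remat::Env or the remat allocator. A small sketch that only touches the budget and the statistics counters, assuming rematerialization is available in the build:

    import oneflow

    remat = oneflow._oneflow_internal.remat

    remat.set_budget_in_bytes(2 * 1024 * 1024 * 1024)   # 2 GiB budget
    print(remat.budget_in_bytes())
    remat.clear_stats()
    print(remat.forced_eviction_num(),
          remat.eager_eviction_num(),
          remat.recomputation_num())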
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/ccl/ccl.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/job/rank_group.h" namespace py = pybind11; namespace oneflow { namespace { Maybe CpuBroadcast(py::bytes* in, int64_t root) { const auto& rank_group = JUST(RankGroup::DefaultRankGroup()); const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(DeviceType::kCPU, rank_group)); Py_ssize_t length; char* buffer; if (GlobalProcessCtx::Rank() == root) { CHECK_NOTNULL_OR_RETURN(in); PyBytes_AsStringAndSize(in->ptr(), &buffer, &length); } const auto& meta_transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta)); JUST(ccl::CpuBroadcast(&length, &length, sizeof(length), root, parallel_desc, meta_transport_token)); const auto& data_transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); if (GlobalProcessCtx::Rank() == root) { JUST(ccl::CpuBroadcast(buffer, buffer, length, root, parallel_desc, // NOLINT data_transport_token)); // NOLINT return *in; } else { // https://github.com/pybind/pybind11/issues/1236#issuecomment-527730864 PyBytesObject* bytesObject = static_cast(PyObject_Malloc(offsetof(PyBytesObject, ob_sval) + length + 1)); PyObject_INIT_VAR(bytesObject, &PyBytes_Type, length); bytesObject->ob_shash = -1; bytesObject->ob_sval[length] = '\0'; buffer = bytesObject->ob_sval; JUST(ccl::CpuBroadcast(nullptr, buffer, length, root, parallel_desc, data_transport_token)); return py::reinterpret_steal(reinterpret_cast(bytesObject)); } } } // namespace ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("cpu_broadcast", [](py::bytes in, int64_t root) -> Maybe { return CpuBroadcast(&in, root); }); m.def("cpu_broadcast", [](const py::none& in, int64_t root) -> Maybe { return CpuBroadcast(nullptr, root); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/rpc/rank_group.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/rank_group_rpc_util.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/job/rank_group_scope.h" #include "oneflow/core/common/symbol.h" namespace py = pybind11; namespace oneflow { namespace { Maybe CheckCurrentRankGroupConsistency() { const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); const auto& ctx = JUST(CheckTransportToken(rank_group)); JUST(ctx->WaitDone()); return Maybe::Ok(); } } // namespace ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("check_current_rank_group_consistency", &CheckCurrentRankGroupConsistency); } } // namespace oneflow ================================================ FILE: oneflow/api/python/session/session.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/job/session.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/framework/multi_client_session_context.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { using namespace oneflow; py::class_>( m, "SessionContext") .def(py::init&>()) .def("try_init", [](MultiClientSessionContext& session, const std::string& config_proto_str) { return session.TryInit(config_proto_str).GetOrThrow(); }) .def("update_resource", [](MultiClientSessionContext& session, const std::string& reso_proto_str) { return session.UpdateResource(reso_proto_str).GetOrThrow(); }); m.def("NewSessionId", &NewSessionId); py::class_(m, "LogicalConfigProtoContext") .def(py::init()); } } // namespace oneflow ================================================ FILE: oneflow/api/python/stack_getter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "pybind11/pybind11.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/singleton.h" #include "oneflow/extension/stack/foreign_stack_getter.h" #include "oneflow/extension/stack/python/stack_getter.h" #include "oneflow/extension/stack/stacktrace.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("RegisterStackGetter", &RegisterPyStackGetter); m.def("GetCurrentStack", []() { auto* stack_getter = Singleton::Get(); return stack_getter->GetFormattedStack(stack_getter->GetCurrentFrame()); }); m.def("RegisterSignalHandler", []() { if (ParseBooleanFromEnv("ONEFLOW_ENABLE_SIGNAL_HANDLER", true)) { Singleton::New(); } }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/symbol/job_conf_symbol.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/core/common/throw.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_conf.pb.h" namespace py = pybind11; namespace oneflow { Maybe CreateJobConfSymbol(int64_t symbol_id, const std::string& serialized_symbol_conf) { JobConfigProto symbol_pb; if (!TxtString2PbMessage(serialized_symbol_conf, &symbol_pb)) { THROW(RuntimeError) << "job conf parse failed.\n" << serialized_symbol_conf; } return JobDesc::New(symbol_id, symbol_pb); } ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "JobConfSymbol") .def(py::init([](int64_t symbol_id, const std::string& serialized_symbol_conf) { return CreateJobConfSymbol(symbol_id, serialized_symbol_conf).GetPtrOrThrow(); })) .def_property_readonly("symbol_id", [](const JobDesc& x) { if (!x.symbol_id().has_value()) { THROW(RuntimeError) << "symbol_id not initialized"; } return CHECK_JUST(x.symbol_id()); }) .def_property_readonly("data", [](const JobDesc& job_conf_sym) -> std::string { return PbMessage2TxtString(job_conf_sym.job_conf()); }); } } // namespace oneflow ================================================ FILE: oneflow/api/python/symbol/op_conf_symbol.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/common/throw.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/operator/op_conf_symbol.h" #include "oneflow/core/common/maybe.h" namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "OpConfSymbol") .def_property_readonly("symbol_id", [](const OperatorConfSymbol& x) { if (!x.symbol_id().has_value()) { THROW(RuntimeError) << "symbol_id not initialized"; } return CHECK_JUST(x.symbol_id()); }) .def_property_readonly("data", &OperatorConfSymbol::data); } } // namespace oneflow ================================================ FILE: oneflow/api/python/symbol/placement_symbol.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include "oneflow/core/common/maybe.h" #include "oneflow/extension/python/numpy.h" #include "oneflow/api/python/framework/size.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/parallel_conf_util.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace py = pybind11; namespace oneflow { namespace { int64_t GetDeviceCount(const std::string& device_name) { return Singleton::Get()->GetDeviceCount(device_name); } struct PlacementSymbolExportUtil { static Maybe CheckDeviceTag(const std::string& type) { if (!TRY(DeviceType4DeviceTag(type)).IsOk()) { return Error::RuntimeError() << "Expected one of " << PrintAvailableDevices() << " device type at start of device string: " << type; } return Maybe::Ok(); } static Maybe CreateParallelDesc( const std::string& type, const std::vector& formated_machine_device_ids, const std::shared_ptr& hierarchy_shape) { JUST(CheckDeviceTag(type)); auto parallel_conf = JUST(MakeParallelConf(type, formated_machine_device_ids, hierarchy_shape)); std::shared_ptr parallel_desc; JUST(PhysicalRun([¶llel_desc, ¶llel_conf](InstructionsBuilder* builder) -> Maybe { parallel_desc = JUST(builder->GetParallelDescSymbol(*parallel_conf)); return Maybe::Ok(); })); return parallel_desc; } static Maybe CreateParallelDesc(const std::string& proto_str) { ParallelConf parallel_conf; CHECK_OR_RETURN(TxtString2PbMessage(proto_str, ¶llel_conf)) << " Get ParallelConf Pb from string failed."; std::shared_ptr parallel_desc; JUST(PhysicalRun([¶llel_desc, ¶llel_conf](InstructionsBuilder* builder) -> Maybe { parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf)); return Maybe::Ok(); })); return parallel_desc; } static Maybe> ParseAndFormatRanks(const py::dict& device_ids) { std::vector> machine_device_id_vec; for (const auto& pair : device_ids) { CHECK_OR_RETURN(py::isinstance(pair.first)) << "The key (node id) of placement device_ids must be int64."; int64_t machine_id = pair.first.cast(); if (py::isinstance(pair.second)) { machine_device_id_vec.emplace_back(machine_id, pair.second.cast()); } else { CHECK_OR_RETURN(py::isinstance(pair.second)) << "Value of device_ids dict must be int, list or range"; for (const auto& device_id : pair.second) { CHECK_OR_RETURN(py::isinstance(device_id)) << "Value of device_ids dict must be int, list or range of int."; machine_device_id_vec.emplace_back(machine_id, device_id.cast()); } } } auto formated_machine_device_ids = std::make_shared>(); for (const auto& pair : machine_device_id_vec) { const std::string& device_name = std::to_string(pair.first) + ":" + std::to_string(pair.second); formated_machine_device_ids->emplace_back(device_name); } return formated_machine_device_ids; } static Maybe GetRanksShape(PyArrayObject* ranks) { auto* shape = PyArray_SHAPE(ranks); return std::make_shared(DimVector(shape, shape + PyArray_NDIM(ranks))); } // Parse and format ranks to string "machine_id:local_rank" static Maybe> ParseAndFormatRanks(PyArrayObject* ranks) { size_t size = PyArray_SIZE(ranks); CHECK_EQ_OR_RETURN(PyArray_TYPE(ranks), NPY_INT64) << Error::RuntimeError() << "placement ranks shoule be an array of long int"; int64_t* rank_data = 
static_cast(PyArray_DATA(ranks)); std::vector> machine_device_id_vec; for (int i = 0; i < size; ++i) { int64_t rank = rank_data[i]; int64_t machine_id = GlobalProcessCtx::NodeId(rank); int64_t device_id = GlobalProcessCtx::LocalRank(rank); machine_device_id_vec.emplace_back(machine_id, device_id); } auto formated_machine_device_ids = std::make_shared>(); for (const auto& pair : machine_device_id_vec) { auto device_name = std::to_string(pair.first) + ":" + std::to_string(pair.second); formated_machine_device_ids->emplace_back(device_name); } return formated_machine_device_ids; } static Maybe> CreateParallelDescSymbol( const std::string& type, const py::dict& device_ids, const std::shared_ptr& hierarchy) { const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(device_ids)); return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, hierarchy))); } // create Symbol object through given device_type and ranks parameters static Maybe> CreateParallelDescSymbol(const std::string& type, const py::object& ranks) { auto* obj = reinterpret_cast(PyArray_FromAny( ranks.ptr(), nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr)); if (!obj) { return Error::RuntimeError() << "placement ranks shoule be an array of long int"; } const auto& shape = JUST(GetRanksShape(obj)); const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(obj)); return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, shape))); } static Maybe> CreateParallelDescSymbol(const std::string& proto_str) { return SymbolOf(*JUST(CreateParallelDesc(proto_str))); } static Maybe> AllDevicePlacement(const std::string& type) { static thread_local HashMap> device_tag2placement; CHECK_NOTNULL((Singleton::Get())); JUST(CheckDeviceTag(type)); auto it = device_tag2placement.find(type); if (it == device_tag2placement.end()) { int64_t node_size = GlobalProcessCtx::NodeSize(); int64_t device_num = GlobalProcessCtx::NumOfProcessPerNode(); if (type != "cpu") { const int64_t device_count = GetDeviceCount(type); CHECK_NE_OR_RETURN(device_count, 0) << Error::RuntimeError() << "Can\'t construct placement with \"" << type << "\" type because there is no device!"; device_num = std::min(device_num, device_count); } std::vector machine_device_ids; for (int64_t node_id = 0; node_id < node_size; ++node_id) { std::string device_name = std::to_string(node_id) + ":0-" + std::to_string(device_num - 1); machine_device_ids.emplace_back(device_name); } Symbol placement = SymbolOf(*JUST(CreateParallelDesc(type, machine_device_ids, std::shared_ptr()))); it = device_tag2placement.emplace(type, placement).first; } return it->second; } static Maybe GetPlacementRanks(const Symbol& placement) { py::list ranks; for (int64_t machine_id : placement->sorted_machine_ids()) { int64_t node_id = GlobalProcessCtx::NodeId(machine_id); for (int64_t device_id : placement->sorted_dev_phy_ids(machine_id)) { ranks.append(py::cast(node_id * GlobalProcessCtx::NumOfProcessPerNode() + device_id)); } } auto array_ranks = py::cast(ranks); array_ranks.resize(placement->hierarchy()->dim_vec()); return array_ranks; } }; } // namespace ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_, std::shared_ptr>>(m, "placement", py::dynamic_attr()) .def(py::init([](const std::string& device_type, const py::dict& device_ids, const std::shared_ptr& hierarchy) { PyErr_WarnEx( PyExc_UserWarning, "The way to construct placement is deprecated, and it will be removed in next " "versions. 
Please use oneflow.placement(type=str, ranks=int array) instead", 1); return PlacementSymbolExportUtil::CreateParallelDescSymbol(device_type, device_ids, hierarchy) .GetOrThrow(); }), py::arg("device_type"), py::arg("device_ids"), py::arg("hierarchy")) .def(py::init([](const std::string& device_type, const py::dict& device_ids, const py::tuple& hierarchy) { PyErr_WarnEx( PyExc_UserWarning, "The way to construct placement is deprecated, and it will be removed in next " "versions. Please use oneflow.placement(type=str, ranks=int array) instead", 1); DimVector shape_dims{}; for (const auto& dim : hierarchy) { shape_dims.emplace_back(dim.cast()); } return PlacementSymbolExportUtil::CreateParallelDescSymbol( device_type, device_ids, std::make_shared(shape_dims)) .GetOrThrow(); }), py::arg("device_type"), py::arg("device_ids"), py::arg("hierarchy") = py::tuple()) .def(py::init([](const std::string& type, const py::object& ranks) { return PlacementSymbolExportUtil::CreateParallelDescSymbol(type, ranks).GetOrThrow(); }), py::arg("type"), py::arg("ranks")) .def(py::init([](const std::string& proto_str) { return PlacementSymbolExportUtil::CreateParallelDescSymbol(proto_str).GetOrThrow(); }), py::arg("proto_str")) .def_property_readonly( "device_type", [](Symbol p) { PyErr_WarnEx( PyExc_UserWarning, "The property .device_type of placement is deprecated, please use .type instead", 1); return p->device_tag(); }) .def_property_readonly("type", [](Symbol p) { return p->device_tag(); }) .def_property_readonly("hierarchy", [](Symbol p) { PyErr_WarnEx(PyExc_UserWarning, "The property .hierarchy of placement is deprecated, " "please use .ranks.shape instead", 1); return p->hierarchy(); }) .def_property_readonly("ranks", &PlacementSymbolExportUtil::GetPlacementRanks) .def("__str__", PlacementToString) .def("__repr__", PlacementToString) .def(py::self == py::self) .def(py::hash(py::self)) .def_static("all", &PlacementSymbolExportUtil::AllDevicePlacement); } } // namespace oneflow ================================================ FILE: oneflow/api/python/symbol/sbp_symbol.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/common/sbp.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/constant.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/framework/nd_sbp.h" namespace py = pybind11; namespace oneflow { namespace { Maybe>> MakeSplitSbpParallelList(int max_split_axis) { std::shared_ptr>> ret = std::make_shared>>(max_split_axis); for (int i = 0; i < max_split_axis; ++i) { ret->at(i) = JUST(MakeSplitSbpParallel(i)); } return ret; } Maybe> GetSplitSbpParallel(int axis) { CHECK_GE_OR_RETURN(axis, 0) << Error::RuntimeError() << "Split axis must not be negative, but got " << axis << "!"; CHECK_LT_OR_RETURN(axis, kMaxSplitAxis) << Error::RuntimeError() << "Expected split axis to be less than the supported maximum axis (" << kMaxSplitAxis << "), but got " << axis << "!"; static std::vector> split_sbp_sym_list = *JUST(MakeSplitSbpParallelList(kMaxSplitAxis)); return split_sbp_sym_list.at(axis); } Maybe> GetBroadcastSbpParallel() { static Symbol broadcast_sbp = JUST(MakeBroadcastSbpParallel()); return broadcast_sbp; } Maybe> GetPartialSumSbpParallel() { static Symbol partial_sum_sbp = JUST(MakePartialSumSbpParallel()); return partial_sum_sbp; } Maybe> SbpGetState(const Symbol& sbp) { if (sbp->has_broadcast_parallel()) { return std::make_shared>("B", -1); } else if (sbp->has_partial_sum_parallel()) { return std::make_shared>("P", -1); } else if (sbp->has_split_parallel()) { return std::make_shared>("S", sbp->split_parallel().axis()); } else { return Error::RuntimeError() << "Invalid sbp signature: " << sbp->DebugString(); } } Maybe> GetSbpFromState(const std::pair& state) { if (state.first == "B") { return GetBroadcastSbpParallel(); } else if (state.first == "P") { return GetPartialSumSbpParallel(); } else if (state.first == "S") { return GetSplitSbpParallel(state.second); } else { return Error::RuntimeError() << "Invalid sbp signature state: (" << state.first << ", " << state.second << ");"; } } } // namespace ONEFLOW_API_PYBIND11_MODULE("sbp", m) { m.attr("max_split_axis") = kMaxSplitAxis; py::class_, std::shared_ptr>>(m, "sbp", py::dynamic_attr()) .def("__str__", &api::ApiSbpToString) .def("__repr__", &api::ApiSbpToString) .def(py::self == py::self) .def(py::hash(py::self)) .def("_ToAttrStr", [](const Symbol& sbp_sym) { return SbpParallelToString(*sbp_sym); }) .def(py::pickle( [](const Symbol& sbp) { // __getstate__ return SbpGetState(sbp).GetOrThrow(); }, [](const std::pair& state) { // __setstate__ return GetSbpFromState(state).GetOrThrow(); })); m.def("split", GetSplitSbpParallel, py::arg("axis")); m.def("broadcast", &GetBroadcastSbpParallel); m.def("partial_sum", &GetPartialSumSbpParallel); } } // namespace oneflow ================================================ FILE: oneflow/api/python/symbol/scope_symbol.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/common/throw.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/scope.h" namespace py = pybind11; namespace oneflow { Maybe CreateScopeSymbol(int64_t symbol_id, const std::string& symbol_conf_str) { ScopeProto symbol_pb; if (!TxtString2PbMessage(symbol_conf_str, &symbol_pb)) { THROW(RuntimeError) << "symbol conf parse failed.\n" << symbol_conf_str; } return Scope::New(symbol_id, symbol_pb); } ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "ScopeSymbol") .def(py::init([](int64_t symbol_id, const std::string& symbol_conf_str) { return CreateScopeSymbol(symbol_id, symbol_conf_str).GetPtrOrThrow(); })) .def_property_readonly("symbol_id", [](const Scope& x) { if (!x.symbol_id().has_value()) { THROW(RuntimeError) << "symbol_id not initialized"; } return CHECK_JUST(x.symbol_id()); }) .def_property_readonly("_proto_str", [](const Scope& x) { return PbMessage2TxtString(x.scope_proto()); }) .def("auto_increment_id", &Scope::auto_increment_id) .def_property_readonly("session_id", &Scope::session_id) .def_property_readonly("job_desc_symbol", &Scope::job_desc_symbol) .def_property_readonly( "device_parallel_desc_symbol", [](const Scope& x) { return x.device_parallel_desc_symbol().shared_from_symbol(); }) .def_property_readonly("parent_scope_symbol", &Scope::parent_scope_symbol) .def("MakeChildScopeProto", &Scope::MakeChildScopeProto); } } // namespace oneflow ================================================ FILE: oneflow/api/python/utils/dataloader.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef _WIN32 #include #include #include #include #include #include #include #include "oneflow/api/python/of_api_registry.h" #include namespace oneflow { namespace py = pybind11; // reference: pytorch/torch/csrc/DataLoader.cpp // https://github.com/pytorch/pytorch/blob/d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa/torch/csrc/DataLoader.cpp // Critical signal handlers should be registered on worker processes before // doing work. // The handler will raise default handler so that the kill information will be // retrieved from main process. // Python handle is _set_worker_signal_handlers(). #define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \ static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) { \ auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \ (void)_w; \ struct sigaction sa {}; \ sa.sa_handler = SIG_DFL; \ sa.sa_flags = 0; \ if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, nullptr) != 0) { \ _exit(EXIT_FAILURE); \ } else { \ raise(SIGNAL); \ } \ } // signal(2) is really not portable. So use sigaction. 
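// For reference, the flags used in the sigaction setup below mean:
//   SA_RESTART   - restart interruptible syscalls instead of failing with EINTR
//   SA_SIGINFO   - deliver a siginfo_t to the three-argument sa_sigaction handler
//   SA_NOCLDSTOP - report children via SIGCHLD only when they terminate, not when they stop
//   SA_NODEFER   - do not block the signal while its own handler is running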
// http://man7.org/linux/man-pages/man2/signal.2.html static inline void setSignalHandler(int signal, void (*handler)(int, siginfo_t*, void*), struct sigaction* old_sa_ptr) { struct sigaction sa {}; sa.sa_sigaction = handler; sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER; if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) { std::ostringstream oss; oss << "An error occurred while setting handler for " << strsignal(signal) << "."; throw std::runtime_error(oss.str()); } } SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. " "This might be caused by insufficient shared memory (shm).\n"); SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n"); SIGNAL_HANDLER(SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point exception encountered in worker.\n"); // When an error happened in DataLoader methods and Python starts to exit, the // error trace will keep the loader alive, and Python may kill the children // processes first before deleting the loader object. Then the cleaning up // methods in DataLoader.__del__ are not yet called, and SIGCHILD will print an // error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main // loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we // exit with nonzero code, the loader SIGCHLD handler may report RuntimeError // again, and then it defeats the whole purpose. static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) { if (info->si_pid == getppid()) { _exit(EXIT_SUCCESS); } struct sigaction sa {}; sa.sa_handler = SIG_DFL; sa.sa_flags = 0; if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, nullptr) != 0) { _exit(EXIT_FAILURE); } else { raise(SIGTERM); } } static void set_worker_signal_handlers() { setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr); setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr); setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr); setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static std::map> worker_pids = {}; static void error_if_any_worker_fails() { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int error; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::set* pid_set; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) pid_t worker_pid; siginfo_t infop; // Only check the pids we care about for (auto& w : worker_pids) { pid_set = &(w.second); for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) { worker_pid = *pid_it; // Use waitid rather than waitpid so that we can set NOWAIT, and that Python // and other handlers can get whatever info they want about the child. infop.si_pid = 0; error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT); // ignore errors and case with no waitable child if (error < 0 || infop.si_pid == 0) continue; if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) { // exit with error std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") exited " << "unexpectedly with exit code " << infop.si_status << ". " << "Details are lost due to multiprocessing. Rerunning with " << "num_workers=0 may give better error trace."; // This is necessary. Otherwise, the runtime error will kill the other // workers, and trigger this again. 
pid_set->clear(); throw std::runtime_error(oss.str()); } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") is killed " << "by signal: " << strsignal(infop.si_status) << ". "; if (infop.si_status == SIGBUS) { oss << "It is possible that dataloader's workers are out of shared memory. " << "Please try to raise your shared memory limit."; } // This is necessary. Otherwise, the runtime error will kill the other // workers, and trigger this again. pid_set->clear(); throw std::runtime_error(oss.str()); } } } } inline int64_t utils_unpackLong(PyObject* obj) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int overflow; long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); if (value == -1 && PyErr_Occurred()) { throw py::value_error(); } if (overflow != 0) { throw std::runtime_error("Overflow when unpacking long"); } return (int64_t)value; } // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple // of pids we are interested in. static void set_worker_pids(py::args py_args) { PyObject* args = py_args.ptr(); if (PyTuple_GET_SIZE(args) != 2) { throw py::type_error("_set_worker_pids expects exactly 2 arguments."); } int64_t key = utils_unpackLong(PyTuple_GET_ITEM(args, 0)); if (worker_pids.find(key) != worker_pids.end()) { throw py::value_error( "_set_worker_pids should be called only once for each _BaseDataLoaderIter."); } PyObject* child_pids = PyTuple_GET_ITEM(args, 1); if (!PyTuple_Check(child_pids)) { py::print("_set_worker_pids expects a tuple for child_pids, but got: ", Py_TYPE(child_pids)->tp_name); throw py::type_error("_set_worker_pids expects a tuple for child_pids"); } std::set pids_set = {}; auto size = PyTuple_GET_SIZE(child_pids); for (int idx = 0; idx < size; idx++) { PyObject* obj = PyTuple_GET_ITEM(child_pids, idx); pids_set.insert(static_cast(utils_unpackLong(obj))); } worker_pids[key] = pids_set; } static void remove_worker_pids(py::args py_args) { PyObject* args = py_args.ptr(); int64_t key = utils_unpackLong(PyTuple_GET_ITEM(args, 0)); auto it = worker_pids.find(key); if (it == worker_pids.end()) { py::print("Cannot find worker information for _BaseDataLoaderIter with id :", key); throw py::value_error("Cannot find worker information for _BaseDataLoaderIter"); } worker_pids.erase(it); } #undef SIGNAL_HANDLER #else // dummy implementations for windows static PyObject* set_worker_signal_handlers(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } static PyObject* set_worker_pids(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } static PyObject* remove_worker_pids(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } static PyObject* error_if_any_worker_fails(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } #endif ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("_set_worker_signal_handlers", &set_worker_signal_handlers); m.def("_set_worker_pids", &set_worker_pids); m.def("_remove_worker_pids", &remove_worker_pids); m.def("_error_if_any_worker_fails", &error_if_any_worker_fails); } } // namespace oneflow ================================================ FILE: oneflow/api/python/utils/tensor_utils.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/api/python/utils/tensor_utils.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/job/global_mode.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/extension/python/numpy.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/functional/impl/common.h" namespace py = pybind11; namespace oneflow { namespace one { Maybe EagerLocalTensorZeros(const std::shared_ptr& t) { JUST(functional::CheckInplaceValid(t)); std::shared_ptr local_tensor; if (t->is_local()) { local_tensor = JUST(t->AsLocalTensor()); } else { local_tensor = JUST(t->cur_rank_phy_tensor()); } CHECK_OR_RETURN(local_tensor->is_eager()) << "eager tensors supported only"; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { JUST(builder->AccessBlobByCallback( local_tensor, [](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { AutoMemset(stream, eager_blob_object->mut_dptr(), 0, eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case()); }, "mut")); return Maybe::Ok(); })); return Maybe::Ok(); } namespace { void CopyFromNumpyArray(ep::Stream* stream, const std::shared_ptr& eager_blob_object, const NumPyArrayPtr& array_ptr) { SyncAutoMemcpy(stream, eager_blob_object->mut_dptr(), array_ptr.data(), eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case(), memory::MakeHostMemCase()); } } // namespace Maybe CopyLocalTensorFromUntypedArray(const std::shared_ptr& tensor, PyObject* array) { return CopyBetweenLocalTensorAndNumpy(tensor, array, CopyFromNumpyArray, "mut", /*block_host_until_done=*/false); } Maybe, std::vector>>> MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr& t) { const auto& tensor = JUST(t->AsLocalTensor()); if (tensor->dtype() != DType::TensorBuffer()) { return Error::RuntimeError() << "tensor buffer supported only"; } CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only"; std::vector shapes; std::vector> dtypes; auto btb = std::make_shared(); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SyncAccessBlobByCallback( tensor, btb, [](ep::Stream* stream, const std::shared_ptr&) {}, "const"); })); JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); const auto& eager_blob_object = JUST(tensor->eager_blob_object()); const Shape& blob_shape = eager_blob_object->shape(); const auto* tensor_buffer_ptr = eager_blob_object->dptr(); for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) { const TensorBuffer* tensor_buffer = tensor_buffer_ptr + i; shapes.emplace_back(tensor_buffer->shape()); dtypes.emplace_back(DType::Get(tensor_buffer->data_type()).GetOrThrow()); } return std::make_tuple(shapes, dtypes); } Maybe RegisterTensorHook(const std::shared_ptr& self, const AutogradMeta::Hook& hook) { 
CHECK_OR_RETURN(self->requires_grad()) << "cannot register a hook on a tensor that doesn't require gradient"; if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); } self->mut_autograd_meta()->add_hook(hook); return Maybe::Ok(); } Maybe RegisterTensorPostGradAccumulationHook(const std::shared_ptr& self, const AutogradMeta::Hook& hook) { if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); } self->mut_autograd_meta()->add_post_grad_accumulation_hook(hook); return Maybe::Ok(); } Maybe TensorGetPyTupleOfSbp(const Tensor& tensor) { const auto& nd_sbp = JUST(tensor.nd_sbp()); const auto& tuple = std::make_shared(nd_sbp->sbp_parallel_size()); for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) { (*tuple)[i] = SymbolOf(nd_sbp->sbp_parallel(i)); } return tuple; } Maybe MakeLocalTensorFromData(PyObject* data, const Optional>& dtype, const Optional>& device, const bool requires_grad, const bool pin_memory) { bool is_bfloat16_dtype = dtype ? JUST(dtype)->data_type() == DataType::kBFloat16 : false; bool is_cuda_device = device ? JUST(device)->enum_type() == DeviceType::kCUDA : false; if (is_bfloat16_dtype && is_cuda_device) { #if CUDA_VERSION < 11000 return Error::RuntimeError() << "Cannot create a bfloat16 tensor on gpu under cuda version: 11000"; #endif // CUDA_VERSION >= 11000 } PyArray_Descr* np_dtype = dtype.has_value() && !is_bfloat16_dtype ? PyArray_DescrFromType(JUST(numpy::OFDataTypeToNumpyType(JUST(dtype)->data_type()))) : nullptr; // NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the // array with NPY_ARRAY_DEFAULT flag is C-style contiguous. // NPY_ARRAY_FORCECAST is needed otherwise there will a segfault. // // Even though PyArray_FromAny can cast the input array to the desired dtype // if `dtype` argument is set, it fails to handle the following case: // >> x = [flow.tensor([1, 2])] * 3 <-- x is a list of flow.Tensor // >> y = flow.tensor(x, dtype=flow.float32) <-- returns nullptr // However, the following case without `dtype` argument works well: // >> x = [flow.tensor([1, 2])] * 3 // >> y = flow.tensor(x) // So we cast the input array to the desired dtype manually. PyArrayObject* _array = reinterpret_cast( PyArray_FromAny(data, nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST, nullptr)); if (!_array) { return Error::RuntimeError() << "Can not convert input data to a new numpy array."; } // PyArray_FromArray steals a reference to np_dtype object, so no need to decref it. 
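// The temporary _array created above still holds its own reference, so it is released with
// Py_DECREF right after the conversion below; the converted `array` is released later, once its
// contents have been copied into the tensor.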
PyObject* array = PyArray_FromArray( _array, np_dtype, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST); Py_DECREF(_array); auto* np_arr = reinterpret_cast(array); const npy_intp* dims_ptr = PyArray_SHAPE(np_arr); const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr))); DataType np_data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr)); Symbol device_; if (device) { device_ = JUST(device); } else { device_ = JUST(Device::New("cpu")); } std::shared_ptr tensor = JUST(functional::Empty(shape, JUST(DType::Get(np_data_type)), device_, /*requires_grad=*/false, /*pin_memory=*/pin_memory)); if (device_->enum_type() != DeviceType::kMeta) { JUST(CopyLocalTensorFromUntypedArray(tensor, array)); } Py_DECREF(array); if (dtype && JUST(dtype)->data_type() != np_data_type) { tensor = JUST(functional::To(tensor, JUST(dtype), false)); } else if (!dtype && !PyArray_Check(data) && tensor->dtype()->is_floating_point() && GetDefaultDType() != tensor->dtype()) { // If it not assign dtype and created from PySequence, cast tensor to default floating dtype tensor = JUST(functional::To(tensor, JUST(DType::Get(DataType::kFloat)), false)); } JUST(tensor->set_requires_grad(requires_grad)); return tensor; } namespace { Maybe> GetAllBroadcastNdSbp(size_t ndim) { NdSbp broadcast_nd_sbp; for (size_t i = 0; i < ndim; ++i) { broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); } return SymbolOf(broadcast_nd_sbp); } auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal); } // namespace Maybe MakeGlobalTensorFromData(PyObject* data, const Optional>& dtype, Symbol placement, const std::vector>& sbp_tuple, const bool requires_grad) { PyObject* array = NULL; if (PyArray_Check(data)) { // Only NPY_CORDER is supported, and returns a new C-style contiguous array. array = PyArray_NewCopy((PyArrayObject*)data, NPY_CORDER); } else { // NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the // array with NPY_ARRAY_DEFAULT flag is C-style contiguous. array = PyArray_FromAny(data, nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr); if (!array) { return Error::RuntimeError() << "Can not convert input data to a numpy array."; } } auto* np_arr = reinterpret_cast(array); const npy_intp* dims_ptr = PyArray_SHAPE(np_arr); const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr))); DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr)); if (placement->parallel_num() > 1) { const void* buf_ptr = PyArray_DATA(np_arr); size_t array_size = PyArray_SIZE(np_arr); CHECK_EQ_OR_RETURN(array_size, shape.elem_cnt()); size_t byte_size = array_size * GetSizeOfDataType(data_type); JUST(DataConsistencyCheck(buf_ptr, byte_size, placement)); } Symbol device = JUST(Device::New(placement->device_tag())); std::shared_ptr local_tensor; { GlobalMode::Guard guard(/* disable global mode */ false); local_tensor = JUST(functional::Empty(shape, JUST(DType::Get(data_type)), device, /*requires_grad=*/false, /*pin_memory=*/false)); } if (device->enum_type() != DeviceType::kMeta) { JUST(CopyLocalTensorFromUntypedArray(local_tensor, array)); } Py_DECREF(array); // Cast to float if data is double sequence, rather than numpy array. 
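// For example (illustrative): a nested Python list such as [[1.0, 2.0]] is parsed as float64 but
// is cast to float32 below, whereas an explicit float64 numpy array keeps its dtype unless a
// dtype argument is given.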
Symbol dtype_; if (dtype) { dtype_ = JUST(dtype); } else if (!dtype && data_type == DataType::kDouble && !PyArray_Check(data)) { dtype_ = DType::Float(); } if (dtype_) { local_tensor = JUST(functional::Cast(local_tensor, dtype_, /*pin_memory=*/false)); } size_t sbp_dims = sbp_tuple.size(); Symbol broadcast_nd_sbp = JUST(CachedGetAllBroadcastNdSbp(sbp_dims)); std::shared_ptr broadcast_tensor = JUST( functional::LocalToGlobal(local_tensor, placement, *JUST(GetSbpList(broadcast_nd_sbp)), shape, local_tensor->dtype(), /* sync_data */ true, /*copy=*/false)); std::vector> grad_sbp_tuple; auto global_tensor = JUST(functional::ToGlobal(broadcast_tensor, placement, sbp_tuple, grad_sbp_tuple, /* check_meta */ false, /*copy=*/false)); JUST(global_tensor->set_requires_grad(requires_grad)); return global_tensor; } Maybe MakeTensorFromOtherTensor(const std::shared_ptr& other, const bool pin_memory) { if (other->is_local()) { const Symbol& device = JUST(other->device()); return functional::Copy(other, device->type(), device->device_id(), pin_memory); } else { const Symbol& nd_sbp = JUST(other->nd_sbp()); const std::vector>& sbp_tuple = *JUST(GetSbpList(nd_sbp)); std::vector> grad_sbp_tuple; // TODO:(zhaoluyang) global case support pin_memory return functional::ToGlobal(other, JUST(other->parallel_desc()), sbp_tuple, grad_sbp_tuple, /* check_meta */ false, /*copy=*/false); } } Maybe MakeTensorFromOtherTensor(const std::shared_ptr& other, const Optional>& dtype, const Optional>& device, const bool requires_grad, const bool pin_memory) { std::shared_ptr tensor; Symbol device_; if (device) { device_ = JUST(device); } if (other->is_local()) { if (!device) { device_ = JUST(other->device()); } tensor = JUST(functional::Copy(other, device_->type(), device_->device_id(), pin_memory && !dtype.has_value())); } else { tensor = JUST(functional::GlobalToLocal(other, /*copy=*/false)); if (!device) { device_ = JUST(Device::New("cpu")); } tensor = JUST(functional::Copy(tensor, device_->type(), device_->device_id(), pin_memory && !dtype.has_value())); } if (dtype) { const Symbol& dtype_ = JUST(dtype); if (tensor->dtype() != dtype_) { tensor = JUST(functional::Cast(tensor, dtype_, pin_memory)); } } JUST(tensor->set_requires_grad(requires_grad)); return tensor; } Maybe MakeTensorFromOtherTensor(const std::shared_ptr& other, const Optional>& dtype, const Symbol& placement, const std::vector>& sbp_tuple, const bool requires_grad) { std::vector> grad_sbp_tuple; bool check_meta = other->is_global() ? false : true; std::shared_ptr tensor = JUST(functional::ToGlobal( other, placement, sbp_tuple, grad_sbp_tuple, check_meta, /*copy=*/false)); if (dtype) { const Symbol& dtype_ = JUST(dtype); if (tensor->dtype() != dtype_) { tensor = JUST(functional::Cast(tensor, dtype_, /*pin_memory=*/false)); } } JUST(tensor->set_requires_grad(requires_grad)); return tensor; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/api/python/utils/tensor_utils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_ #define ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_ #include #undef _PyGC_FINALIZED #include #include #include #include #include "oneflow/api/python/framework/tensor.h" #include "oneflow/extension/python/numpy.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/blocking_then_busy.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/common/foreign_lock_helper.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/profiler/profiler.h" namespace py = pybind11; namespace pybind11 { // reference: https://github.com/pybind/pybind11/issues/1776 template<> struct format_descriptor { static pybind11::dtype dtype() { handle ptr = detail::npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); return reinterpret_borrow(ptr); } static std::string format() { // following: https://docs.python.org/3/library/struct.html#format-characters return "e"; } static constexpr auto name() { return detail::_("float16"); } }; } // namespace pybind11 namespace oneflow { namespace one { Maybe EagerLocalTensorZeros(const std::shared_ptr& t); inline Maybe GetTensorDataPtr(const std::shared_ptr& tensor) { void* data_ptr = nullptr; const auto& Callback = [&](ep::Stream*, const std::shared_ptr& eager_blob_object) { data_ptr = eager_blob_object->mut_raw_dptr(); }; auto btb = std::make_shared(); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SyncAccessBlobByCallback(tensor, btb, Callback, "const"); })); JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return data_ptr; } template inline static Maybe EagerLocalTensorToNumpy(PyObject* py_tensor) { const auto& t = PyTensor_Unpack(py_tensor); std::shared_ptr tensor = JUST(t->AsLocalTensor()); CHECK_OR_RETURN(JUST(tensor->device()) == JUST(Device::New("cpu"))); CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only."; // set base object attr py::handle handle = py::handle(py_tensor); const size_t ndim = tensor->ndim(); const auto shape = numpy::OFShapeToNumpyShape(tensor->shape()->dim_vec()); // NumPy strides use bytes. OneFlow strides use element counts. 
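// Example (illustrative): a contiguous (2, 3) float32 tensor has OneFlow strides (3, 1), which
// the conversion below turns into NumPy strides (12, 4).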
const auto stride = numpy::OFStrideToNumpyStride(*JUST(tensor->stride()), tensor->dtype()->data_type()); void* data_ptr = JUST(GetTensorDataPtr(tensor)); return py::array(py::buffer_info(data_ptr, sizeof(T), py::format_descriptor::format(), ndim, shape, stride), handle) .release() .ptr(); } template struct TensorTypeToPyType final { typedef T type; }; template<> struct TensorTypeToPyType final { typedef float type; }; template<> struct TensorTypeToPyType final { typedef float type; }; template inline static Maybe EagerLocalTensorItem(const std::shared_ptr& tensor) { // OF_PROFILER_RANGE_GUARD("EagerLocalTensorItem"); T value = JUST(GetItemInScalarTensor(tensor)); return functional::CastToPyObject(static_cast::type>(value)); } inline Maybe CopyBetweenLocalTensorAndNumpy( const std::shared_ptr& t, PyObject* array, void (*Copy)(ep::Stream*, const std::shared_ptr&, const NumPyArrayPtr&), const std::string& modifier, bool block_host_until_done) { auto tensor = JUST(t->AsLocalTensor()); CHECK_OR_RETURN(tensor->is_contiguous()) << "contiguous tensors supported only."; CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only."; if (block_host_until_done) { NumPyArrayPtr array_ptr(array); const auto& Callback = [array_ptr, Copy]( ep::Stream* stream, const std::shared_ptr& eager_blob_object) { Copy(stream, eager_blob_object, array_ptr); }; auto btb = std::make_shared(); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SyncAccessBlobByCallback(tensor, btb, Callback, modifier); })); JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); } else { Py_INCREF(array); NumPyArrayPtr array_ptr(array, [array]() { // release array in main thread to eliminate the time-consuming gil request CHECK_JUST(SingletonMaybe())->add_main_thread_pending_task([array]() { Py_DECREF(array); }); }); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->AccessBlobByCallback( tensor, [array_ptr, Copy](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { Copy(stream, eager_blob_object, array_ptr); }, modifier); })); } return Maybe::Ok(); } Maybe, std::vector>>> MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr& t); Maybe RegisterTensorHook(const std::shared_ptr& self, const AutogradMeta::Hook& hook); Maybe RegisterTensorPostGradAccumulationHook(const std::shared_ptr& self, const AutogradMeta::Hook& hook); Maybe TensorGetPyTupleOfSbp(const Tensor& tensor); Maybe MakeLocalTensorFromData(PyObject* data, const Optional>& dtype, const Optional>& device, const bool requires_grad, const bool pin_memory); Maybe MakeGlobalTensorFromData(PyObject* data, const Optional>& dtype, Symbol placement, const std::vector>& sbp_tuple, const bool requires_grad); Maybe MakeTensorFromOtherTensor(const std::shared_ptr& other, const bool pin_memory); Maybe MakeTensorFromOtherTensor(const std::shared_ptr& other, const Optional>& dtype, const Optional>& device, const bool requires_grad, const bool pin_memory); Maybe MakeTensorFromOtherTensor(const std::shared_ptr& other, const Optional>& dtype, const Symbol& placement, const std::vector>& sbp_tuple, const bool requires_grad); } // namespace one } // namespace oneflow #endif // ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_ ================================================ FILE: oneflow/core/auto_parallel/algorithm_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/auto_parallel/algorithm_util.h" namespace oneflow { namespace auto_parallel { // Inverse function of order // The reason why we need the inverse_order, a.k.a id2order, instead of id2value is to eliminate // equality. For example, we have v[0] < v[1] = v[2] < v[3] We do not know v[1] is before or after // v[2] with comp(v[1], v[2]). But if we transfer it to order order[0] < order[1] < order[2] < // order[3] We know the strict order. void InverseOrder(const std::vector& order, std::vector& inverse_order) { inverse_order.resize(order.size()); for (int32_t i = 0; i < order.size(); i++) { inverse_order[order[i]] = i; } } } // namespace auto_parallel // Ceil quotient define a division process, denoted by (/), // which give us the maximum part of an integer division. // For example, // 16 (/) 4 = 4, 17 (/) 4 = 5 // 5 (/) 2 = 3, 6 (/) 2 = 3 // 1 (/) 3 = 1, 2 (/) 7 = 1 // 17 divide by 4 give us 5, 4, 4, 4 // The normal quotient would take the smaller one 4, // but the ceil quotient would take the larger one 5. int64_t CeilQuotient(int64_t dividend, int64_t divisor) { return (dividend + divisor - 1) / divisor; } } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/algorithm_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_ #define ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_ #include #include #include #include namespace oneflow { namespace auto_parallel { // this function is to remove the i-th element from a vector in Constant time. // the vector should not care about ordering. // Be more careful about this function. Make sure that the traveling order of // the vector goes from back to front. template void RemoveFrom(std::vector& v, int32_t i) { v[i] = v.back(); v.pop_back(); } template void CheckAndRemoveFrom(std::vector& v, T& t) { for (int32_t i = v.size() - 1; i >= 0; i--) { if (v[i] == t) { RemoveFrom(v, i); break; } } } // Inverse function, which transfer a vector to an unordered_map. template void InverseFunction(const std::vector& v, std::unordered_map& inverse_map) { inverse_map.clear(); for (int32_t i = 0; i < v.size(); i++) { inverse_map[v[i]] = i; } } // When you want to sort something but you can not move any elements, use order. 
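// Worked example (illustrative): for v = {5.0, 9.0, 3.0} sorted ascending, the resulting order is
// {2, 0, 1} (order[k] is the index of the k-th smallest element) and the inverse order is
// {1, 2, 0} (inverse_order[k] is the rank of v[k]).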
// Decide the order of sorting in a list v, we have // v[order[i]] < v[order[j]] for all i void DecideOrder(const T& v, std::vector& order, const Compare& comp) { // Initialize order order.resize(v.size()); for (int32_t i = 0; i < v.size(); i++) { order[i] = i; } // sort std::sort(order.begin(), order.end(), [&](int32_t i, int32_t j) { return comp(v[i], v[j]); }); } // Inverse function of order // The reason why we need the inverse_order, a.k.a id2order, instead of id2value is to eliminate // equality. For example, we have v[0] < v[1] = v[2] < v[3] We do not know v[1] is before or after // v[2] with comp(v[1], v[2]). But if we transfer it to order order[0] < order[1] < order[2] < // order[3] We know the strict order. void InverseOrder(const std::vector& order, std::vector& inverse_order); } // namespace auto_parallel // Ceil quotient define a division process, denoted by (/), // which give us the maximum part of an integer division. // For example, // 16 (/) 4 = 4, 17 (/) 4 = 5 // 5 (/) 2 = 3, 6 (/) 2 = 3 // 17 divide by 4 give us 5, 4, 4, 4 // The normal quotient would take the smaller one 4, // but the ceil quotient would take the larger one 5. int64_t CeilQuotient(int64_t dividend, int64_t divisor); static const double kFloatDeviationMinus = 0.9999999; static const double kFloatDeviationPlus = 1.0000001; } // namespace oneflow #endif // ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_ ================================================ FILE: oneflow/core/auto_parallel/auto_memory.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/auto_parallel/sbp_constructor.h" #include "oneflow/core/common/hash_container.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/graph/straighten_nodes.h" #include "oneflow/core/register/logical_blob_id.pb.h" namespace oneflow { namespace auto_parallel { namespace { class TopoStruct { public: SbpNode* sbp_node = nullptr; const OpNode* op_node = nullptr; // Memory increment = (memory of out registers) - (memory of in registers) int64_t memory_increment = -1; int32_t exceed_time = -1; bool is_reusable = false; int32_t counter = 0; int32_t min_layer = -1; // The maximum min_layer among out_topo_structs int32_t max_layer = -1; // TODO: remove tributary layer // This node should be finished before tributary layer int32_t tributary_layer = -1; HashSet in_topo_structs; HashSet out_topo_structs; explicit TopoStruct(SbpNode* sbp_node_); explicit TopoStruct(const OpNode* op_node_); // Compute the minimum layer of this node int32_t ComputeMinLayer(); // Compute the maximum layer of this node void ComputeMaxLayer(int32_t max_min_layer); // Compute the tributary layer int32_t ComputeTributaryLayer(int32_t max_min_layer); // Decide whether all the produced registers are reusable void ComputeIsReusable(); // Exceed time = time of cpu - time of gpu void ComputeExceedTime(); // deciding parameter // kTributaryLayerAscend = 0, // small tributary layers go first // kDistanceToOverlapAscend = 1, // small minimum distance to overlap go first // kLayerAscend = 2, // first in first out // kMemoryIncrementAscend = 3, // small memory increment go first // kExceedTimeAscend = 4, // small exceed time go first // kTributaryLayerDescend = 100, // large tributary layers go first // kDistanceToOverlapDescend = 101, // long distance to overlap go first // kLayerDescend = 102, // last in first out // kMemoryIncrementDescend = 103, // large memory increment go first // kExceedTimeDescend = 104, // large exceed time go first int64_t GetDecidingParameter(StraightenOrder so) const; }; static StraightenAlgorithmTag sat; static std::vector decide_parameters; // Order in the waiting sets struct comp { bool operator()(const TopoStruct* a, const TopoStruct* b) const { for (auto decide_parameter : decide_parameters) { auto decide_parameter_a = a->GetDecidingParameter(decide_parameter); auto decide_parameter_b = b->GetDecidingParameter(decide_parameter); if (decide_parameter_a != decide_parameter_b) { return decide_parameter_a < decide_parameter_b; } } return a->op_node->op().op_name() < b->op_node->op().op_name(); } }; bool IsProducedRegisterReusable(const Operator& op) { // The repeat, acc, pack and unpack operators have non-reusable registers // and a -1 register num at this moment. if (op.op_conf().has_user_conf()) { const auto& op_type_name = op.op_conf().user_conf().op_type_name(); // We record the frequency in swin-transformer on the right hand side // and adjust the position accordingly. if (op_type_name == "repeat" // 213 || op_type_name == "acc" // 173 || op_type_name == "unpack" // 2 || op_type_name == "pack" // 1 ) { return false; } } // NOTE: Please refer to oneflow/core/graph_impl/normal_forward_compute_task_node.cpp // NormalForwardCompTaskNode::ProduceOutRegstByNameAndBlockNum // for detail. // We can not use <= 0 here since RegstNum4Op returns a number with type size_t. 
// -1 is actually 18446744073709551615 here. return RegstNum4Op(op) == -1; } TopoStruct::TopoStruct(SbpNode* sbp_node_) : sbp_node(sbp_node_), op_node(sbp_node_->GetOperatorNode()) { ComputeIsReusable(); ComputeExceedTime(); } TopoStruct::TopoStruct(const OpNode* op_node_) : op_node(op_node_) { ComputeIsReusable(); ComputeExceedTime(); } // deciding parameter // kTributaryLayerAscend = 0, // small tributary layers go first // kDistanceToOverlapAscend = 1, // small minimum distance to overlap go first // kLayerAscend = 2, // first in first out // kMemoryIncrementAscend = 3, // small memory increment go first // kExceedTimeAscend = 4, // small exceed time go first // kTributaryLayerDescend = 100, // large tributary layers go first // kDistanceToOverlapDescend = 101, // long distance to overlap go first // kLayerDescend = 102, // last in first out // kMemoryIncrementDescend = 103, // large memory increment go first // kExceedTimeDescend = 104, // large exceed time go first int64_t TopoStruct::GetDecidingParameter(StraightenOrder so) const { int64_t sign = 1; if (so >= kDiff4AscendDescend) { so = StraightenOrder(int(so) - kDiff4AscendDescend); sign = -1; } switch (so) { case StraightenOrder::kTributaryLayerAscend: return sign * tributary_layer; case StraightenOrder::kDistanceToOverlapAscend: return 0; case StraightenOrder::kLayerAscend: return sign * min_layer; case StraightenOrder::kMemoryIncrementAscend: return sign * memory_increment; case StraightenOrder::kExceedTimeAscend: return sign * exceed_time; default: return 0; } } // Exceed time = time of cpu - time of gpu void TopoStruct::ComputeExceedTime() { if (ShortGpuTime(op_node->op().op_conf())) { exceed_time = 1; } else { exceed_time = 0; } } // Compute the minimum layer of this node int32_t TopoStruct::ComputeMinLayer() { if (min_layer >= 0) { return min_layer; } for (auto& in_topo_struct : in_topo_structs) { min_layer = std::max(min_layer, in_topo_struct->ComputeMinLayer()); } return ++min_layer; } // Compute the maximum layer of this node void TopoStruct::ComputeMaxLayer(int32_t max_min_layer) { // Execute those optimizer as soon as possible to release the register of weight_diff if (out_topo_structs.empty()) { max_layer = min_layer; return; } max_layer = max_min_layer; for (auto& out_topo_struct : out_topo_structs) { if (max_layer > out_topo_struct->min_layer) { max_layer = out_topo_struct->min_layer; } } --max_layer; } // Compute the tributary layer int32_t TopoStruct::ComputeTributaryLayer(int32_t max_min_layer) { if (tributary_layer >= 0) { return tributary_layer; } tributary_layer = max_min_layer; for (auto& out_topo_struct : out_topo_structs) { if (tributary_layer > out_topo_struct->ComputeTributaryLayer(max_min_layer)) { tributary_layer = out_topo_struct->tributary_layer; } } return --tributary_layer; } void TopoStruct::ComputeIsReusable() { is_reusable = IsProducedRegisterReusable(op_node->op()); } // Compute the memory increment for all the topological structures void ComputeAllMemoryIncrement(std::vector& topo_structs, HashMap& lbi2id, std::vector>& id2consumer_topo_structs, std::vector& id2blob_size) { // Compute the memory increment for produced blobs for (auto& topo_struct : topo_structs) { topo_struct->memory_increment = 0; const auto& curr_operator = topo_struct->op_node->op(); if (topo_struct->is_reusable) { for (const auto& obn : curr_operator.output_bns()) { const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn); auto it = lbi2id.find(lbi); if (it == lbi2id.end()) { // There exist some blobs that do not have any 
consumer // Such as: op name: // model.cls_head.loss_func.lm_loss-sparse_softmax_cross_entropy_ms-231-split_softmax_reduce_max_global_stage // blob name: mask_0 const BlobDesc& logical_blob_desc = topo_struct->op_node->LogicalBlobDesc4Lbi(lbi); lbi2id[lbi] = id2blob_size.size(); id2blob_size.push_back(TotalByteSize4BlobDesc(logical_blob_desc)); // There are some inconsistency between id2blob_size and id2consumer_topo_structs // We would deal with that at the end to avoid division by 0 topo_struct->memory_increment += id2blob_size.back(); } else { topo_struct->memory_increment += id2blob_size[it->second]; } } } } // Subtract the consumed memory for (int32_t index = 0; index < id2consumer_topo_structs.size(); index++) { int64_t memory_decrease = id2blob_size[index] / id2consumer_topo_structs[index].size(); for (auto& consumer_topo_struct : id2consumer_topo_structs[index]) { consumer_topo_struct->memory_increment -= memory_decrease; } } // Add empty vectors for all those blobs without consumers id2consumer_topo_structs.resize(id2blob_size.size()); } void UpdateSat(const std::vector& topo_structs, StraightenAlgorithmTag* sat) { *sat = GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph(); if (*sat == StraightenAlgorithmTag::kOverlap4CpuGpu) { // If not cpu nodes, then the overlap strategy between cpu and gpu might consume large memory bool exist_cpu_nodes = false; for (const auto& topo_struct : topo_structs) { // Found a cpu node if (topo_struct->exceed_time == 1) { exist_cpu_nodes = true; break; } } if (!exist_cpu_nodes) { // Switch to the compress memory strategy, the default one // Since the overlap strategy for transfer might not be working on 1n1d. *sat = StraightenAlgorithmTag::kCompressMemory; } } } void InitInOutTopoStructs(std::vector* topo_structs) { // Generate the map from operator names to topological structure HashMap op_name2topo_structs; for (auto& topo_struct : *topo_structs) { op_name2topo_structs[topo_struct->op_node->op().op_name()] = topo_struct; } // Traverse the topological structures for (auto& this_topo_struct : *topo_structs) { auto& node = this_topo_struct->op_node; // Initialize input nodes for edges with data node->ForEachNodeOnInEdge([&](OpNode* in) { // Since we might be looking at a sub-graph of the operator graph. // We need to check if the op_node exists in the sub-graph. 
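// Producers that are not part of the sub-graph are simply skipped below, so no edge is
// recorded for them.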
auto it = op_name2topo_structs.find(in->op().op_name()); if (it != op_name2topo_structs.end()) { this_topo_struct->in_topo_structs.insert(it->second); it->second->out_topo_structs.insert(this_topo_struct); } }); // Initialize input nodes for control edges for (const auto& ctrl_in_op_name : node->op().op_conf().ctrl_in_op_name()) { auto it = op_name2topo_structs.find(ctrl_in_op_name); if (it != op_name2topo_structs.end()) { auto& ctrl_in_topo_struct = it->second; this_topo_struct->in_topo_structs.insert(ctrl_in_topo_struct); // Initialize output nodes for this control edge simultaneously ctrl_in_topo_struct->out_topo_structs.insert(this_topo_struct); } } } } void ComputeLayer(std::vector* topo_structs) { int32_t max_min_layer = -1; // Compute the minimum layer for the whole graph for (auto& topo_struct : *topo_structs) { if (max_min_layer < topo_struct->ComputeMinLayer()) { max_min_layer = topo_struct->min_layer; } } max_min_layer++; // Compute the maximum layer for the whole graph for (auto& topo_struct : *topo_structs) { topo_struct->ComputeMaxLayer(max_min_layer); } // Compute the tributary layer for (auto& topo_struct : *topo_structs) { topo_struct->ComputeTributaryLayer(max_min_layer); } } void InitAllParameters(std::vector* topo_structs, HashMap* lbi2id, std::vector>* id2consumer_topo_structs, std::vector* id2blob_size) { // Construct the map from a lbi to its id, consumers, blob size for (auto& topo_struct : *topo_structs) { const auto& consumer = topo_struct->op_node->op(); for (const auto& ibn : consumer.input_bns()) { const LogicalBlobId& lbi = consumer.BnInOp2Lbi(ibn); auto it = lbi2id->find(lbi); if (it == lbi2id->end()) { (*lbi2id)[lbi] = id2blob_size->size(); const BlobDesc& logical_blob_desc = topo_struct->op_node->LogicalBlobDesc4Lbi(lbi); id2blob_size->push_back(TotalByteSize4BlobDesc(logical_blob_desc)); id2consumer_topo_structs->push_back({topo_struct}); } else { id2consumer_topo_structs->at(it->second).push_back(topo_struct); } } } // Construct all the data edges and control edges InitInOutTopoStructs(topo_structs); // Compute the layers ComputeLayer(topo_structs); // Compute the memory increment for all the topological structures ComputeAllMemoryIncrement(*topo_structs, *lbi2id, *id2consumer_topo_structs, *id2blob_size); // Update sat, since sat might be changed in previous jobs UpdateSat(*topo_structs, &sat); // Decide which node should run first InitDecideParameters(sat, &decide_parameters); VLOG(3) << "Straightening order in sbp graph: "; for (int32_t decide_parameter : decide_parameters) { VLOG(3) << decide_parameter; } } void StraightenOpNodes(HashMap& op_node2topo_struct, std::vector* topo_structs, HashMap* lbi2id, std::vector>* id2consumer_topo_structs, std::vector* id2blob_size, std::vector* ordered_topo_structs) { InitAllParameters(topo_structs, lbi2id, id2consumer_topo_structs, id2blob_size); std::set waiting_list; // Wait in the list auto wait = [&](TopoStruct* topo_struct) { waiting_list.insert(topo_struct); }; // Initialization for (auto& topo_struct : *topo_structs) { topo_struct->counter = topo_struct->in_topo_structs.size(); if (topo_struct->counter == 0) { wait(topo_struct); } } // Finish execution auto finish_execution = [&](TopoStruct* topo_struct) { for (auto& out : topo_struct->out_topo_structs) { out->counter--; if (out->counter == 0) { wait(out); } } }; // Execute the first node in the waiting list // Make sure to check that waiting list is not empty before execution auto execute = [&]() { auto first_topo_struct = *waiting_list.begin(); // Set the 
order of execution for sbp nodes ordered_topo_structs->push_back(first_topo_struct); waiting_list.erase(waiting_list.begin()); finish_execution(first_topo_struct); }; // straightening while (!waiting_list.empty()) { execute(); } } } // anonymous namespace // Use two function void InitMemory(const OpGraph& op_graph, SbpGraph* sbp_graph, bool nccl_use_compute_stream) { // Generate topological data structure for each sbp node HashMap op_node2topo_struct; std::vector topo_structs; std::vector ordered_topo_structs; // Traverse all the nodes in the sbp graph for (const auto& sbp_node : sbp_graph->GetNodeList()) { auto* op_node = sbp_node->GetOperatorNode(); CHECK(op_node != nullptr) << "No proxy node allow at this status. InitMemory() should be run before sbp collector!"; op_node2topo_struct.insert({op_node, TopoStruct(sbp_node)}); topo_structs.push_back(&op_node2topo_struct.at(op_node)); } // Construct the map from a lbi to its id, consumers, blob size HashMap lbi2id; std::vector> id2consumer_topo_structs; std::vector id2blob_size; StraightenOpNodes(op_node2topo_struct, &topo_structs, &lbi2id, &id2consumer_topo_structs, &id2blob_size, &ordered_topo_structs); // Mark the memory support, which contains two part: // All the non-reusable memory and those blobs which is a part of the maximum reusable memory int64_t max_reusable_memory = 0; int64_t curr_reusable_memory = 0; std::vector id2count(id2blob_size.size(), -1); // Blobs born, increase count and memory auto GenerateBlobs = [&](TopoStruct* topo_struct) { const auto& curr_operator = topo_struct->op_node->op(); if (topo_struct->is_reusable) { for (const auto& obn : curr_operator.output_bns()) { const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn); int32_t index = lbi2id.at(lbi); // Reusable blobs born curr_reusable_memory += id2blob_size[index]; id2count[index] = id2consumer_topo_structs[index].size(); } } }; // Blobs die, decrease count and memory auto KillBlobs = [&](TopoStruct* topo_struct) { const auto& curr_operator = topo_struct->op_node->op(); // Those reusable blobs who do not have a consumer would die immediately // For example: // register_num: 1, op_name: // "model.cls_head.loss_func.lm_loss-sparse_softmax_cross_entropy_ms-231-split_softmax_reduce_max_device_stage", // blob_name: "mask_0", shape { dim: 2048 dim: 21248 }, // data_type: kBool, time_shape { dim: 1 dim: 1 }, enable_reuse_mem: true, // alloc_before_actor: 369, free_after_actor: 369 if (topo_struct->is_reusable) { for (const auto& obn : curr_operator.output_bns()) { const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn); int32_t index = lbi2id.at(lbi); // Do not have consumer if (id2count[index] == 0) { // Reusable blobs die curr_reusable_memory -= id2blob_size[index]; } } } // Reduce the counter and kill the blobs if count to 0 for (const auto& ibn : curr_operator.input_bns()) { const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(ibn); int32_t index = lbi2id.at(lbi); if (id2count[index] > 0) { --id2count[index]; if (id2count[index] == 0) { // Reusable blobs die curr_reusable_memory -= id2blob_size[index]; } } } }; // Calculate the maximum reusable memory and mark those fixed memory for (auto& topo_struct : ordered_topo_structs) { // Blobs born, increase count and memory GenerateBlobs(topo_struct); // Record the maximum memory if (curr_reusable_memory > max_reusable_memory) { max_reusable_memory = curr_reusable_memory; } // Blobs die, decrease count and memory KillBlobs(topo_struct); } // Make sure that every blob dies CHECK_EQ(curr_reusable_memory, 0) << " Have not 
kill all the reusable blobs!"; // Mark those reusable memory which constitute the maximum reusable memory for (auto& topo_struct : ordered_topo_structs) { // Blobs born, increase count and memory GenerateBlobs(topo_struct); // Mark the first found support if (curr_reusable_memory == max_reusable_memory) { // Mark the temporary memory created by this operator if (topo_struct->is_reusable) { const auto& curr_operator = topo_struct->op_node->op(); for (const auto& obn : curr_operator.output_bns()) { const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn); int32_t index = lbi2id.at(lbi); // We would use id2count != 0 to record the lbi support // Those obn with no consumers have id2count[index] == 0, now it would be set to 1 id2count[index] = 1; } } // The other lbi in the support would have a non-zero id2count // No further process needed break; } // Blobs die, decrease count and memory KillBlobs(topo_struct); } // Initialize memory for each sbp node for (auto& topo_struct : topo_structs) { topo_struct->sbp_node->InitializeMemory(topo_struct->is_reusable, lbi2id, id2count, nccl_use_compute_stream); } } // Straighten a subset of the op graph void StraightenSubGraph(const std::vector& sub_graph, std::vector* ordered_op_nodes) { // Generate topological data structure for each op node HashMap op_node2topo_struct; std::vector topo_structs; std::vector ordered_topo_structs; // Traverse all the nodes in the sub graph for (const auto& node : sub_graph) { op_node2topo_struct.insert({node, TopoStruct(node)}); topo_structs.push_back(&op_node2topo_struct.at(node)); } // Construct the map from a lbi to its id, consumers, blob size HashMap lbi2id; std::vector> id2consumer_topo_structs; std::vector id2blob_size; StraightenOpNodes(op_node2topo_struct, &topo_structs, &lbi2id, &id2consumer_topo_structs, &id2blob_size, &ordered_topo_structs); for (auto& ordered_topo_struct : ordered_topo_structs) { ordered_op_nodes->push_back(ordered_topo_struct->op_node); } } // Straighten the whole op graph void StraightenOpGraph(const OpGraph& op_graph, std::vector* ordered_op_nodes) { std::vector sub_graph; // Traverse and store all the nodes in the op graph op_graph.ForEachNode([&](OpNode* node) { sub_graph.push_back(node); }); StraightenSubGraph(sub_graph, ordered_op_nodes); } } // namespace auto_parallel } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/auto_memory.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_AUTO_MEMORY_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_AUTO_MEMORY_H_

#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/graph/op_graph.h"

namespace oneflow {
namespace auto_parallel {

void InitMemory(const OpGraph& op_graph, SbpGraph* sbp_graph, bool nccl_use_compute_stream);

// Straighten a subset of the op graph
void StraightenSubGraph(const std::vector<const OpNode*>& sub_graph,
                        std::vector<const OpNode*>* ordered_op_nodes);

// Straighten the whole op graph
void StraightenOpGraph(const OpGraph& op_graph, std::vector<const OpNode*>* ordered_op_nodes);

}  // namespace auto_parallel
}  // namespace oneflow

#endif  // ONEFLOW_CORE_AUTO_PARALLEL_AUTO_MEMORY_H_

================================================
FILE: oneflow/core/auto_parallel/binary_set.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/auto_parallel/binary_set.h"

namespace oneflow {
namespace auto_parallel {

namespace {

// A static function for initialization of log_2 mapping
std::unordered_map<BinarySetEntryType, int32_t> InitLog2() {
  std::unordered_map<BinarySetEntryType, int32_t> log_2;
  for (int32_t i = 0; i < 8 * sizeof(BinarySetEntryType); i++) {
    log_2[static_cast<BinarySetEntryType>(1 << i)] = i;
  }
  return log_2;
}

// Initialization of log_2 mapping
// Take log2 of an integer value: 2^n -> n.
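// For example, log_2[1] == 0, log_2[2] == 1, and log_2[8] == 3.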
const std::unordered_map log_2 = InitLog2(); } // namespace // Constructor BinarySet::BinarySet(int32_t size_of_set) : size_of_set_(size_of_set) { int32_t k = (size_of_set - 1) / bit_entry_type_ + 1; binary_set_values_.resize(k, 0); } // Initialization if needed void BinarySet::Initialize(int32_t size_of_set) { size_of_set_ = size_of_set; int32_t k = (size_of_set - 1) / bit_entry_type_ + 1; binary_set_values_.resize(k, 0); } // Clear all the elements in the set void BinarySet::Clear() { binary_set_values_.assign(binary_set_values_.size(), 0); } // Check if i-th element in this subset bool BinarySet::CheckExistence(int32_t i) const { int32_t k = i / bit_entry_type_; int32_t j = i % bit_entry_type_; return bool((binary_set_values_[k] >> j) & 1); } // Add i-th element into this subset void BinarySet::AddEntry(int32_t i) { int32_t k = i / bit_entry_type_; int32_t j = i % bit_entry_type_; binary_set_values_[k] |= (1 << j); } // Take i-th element out from this subset void BinarySet::DeleteEntry(int32_t i) { int32_t k = i / bit_entry_type_; int32_t j = i % bit_entry_type_; binary_set_values_[k] &= ~(1 << j); } // Get the union with another subset and store it into u void BinarySet::UnionTo(const BinarySet& bs, BinarySet& u) { for (int32_t k = 0; k < binary_set_values_.size(); k++) { u.binary_set_values_[k] = binary_set_values_[k] | bs.binary_set_values_[k]; } } // If this binary set intersects another one bool BinarySet::IfIntersect(const BinarySet& bs) const { int32_t min_bs_size = std::min(binary_set_values_.size(), bs.binary_set_values_.size()); for (int32_t k = 0; k < min_bs_size; k++) { if (binary_set_values_[k] & bs.binary_set_values_[k]) { return true; } } return false; } // Get the intersection with another subset and store it into i void BinarySet::IntersectionTo(const BinarySet& bs, BinarySet& i) const { int32_t min_bs_size = std::min(binary_set_values_.size(), bs.binary_set_values_.size()); if (min_bs_size > i.binary_set_values_.size()) { i.binary_set_values_.resize(min_bs_size, 0); } for (int32_t k = 0; k < binary_set_values_.size(); k++) { i.binary_set_values_[k] = binary_set_values_[k] & bs.binary_set_values_[k]; } } // Count number of elements in this subset int32_t BinarySet::Total() const { int32_t t = 0; for (int32_t k = 0; k < binary_set_values_.size(); k++) { BinarySetEntryType bsv = binary_set_values_[k]; bsv = (bsv & 0x5555555555555555) + ((bsv >> 1) & 0x5555555555555555); bsv = (bsv & 0x3333333333333333) + ((bsv >> 2) & 0x3333333333333333); bsv = (bsv & 0x0F0F0F0F0F0F0F0F) + ((bsv >> 4) & 0x0F0F0F0F0F0F0F0F); bsv = (bsv & 0x00FF00FF00FF00FF) + ((bsv >> 8) & 0x00FF00FF00FF00FF); bsv = (bsv & 0x0000FFFF0000FFFF) + ((bsv >> 16) & 0x0000FFFF0000FFFF); // bsv = (bsv & 0x00000000FFFFFFFF) + ((bsv >> 32) & 0x00000000FFFFFFFF); t += int32_t(bsv); } return t; } // Output all the elements in the subset void BinarySet::Output(std::vector& out) const { out.clear(); for (int32_t i = 0; i < size_of_set_; i++) { if (CheckExistence(i)) { out.emplace_back(i); } } } // Output all the elements in the subset void BinarySet::QuickOutput(std::vector& out) const { out.clear(); for (int32_t i = 0; i < binary_set_values_.size(); i++) { BinarySetEntryType x = binary_set_values_[i]; BinarySetEntryType y = 0; while (x) { y = x; x &= x - 1; out.emplace_back(i * BinarySet::bit_entry_type_ + log_2.find(y - x)->second); } } } // Add elements of input into this subset void BinarySet::AddEntries(std::vector& in) { for (int32_t i : in) { AddEntry(i); } } // If two binary sets are equal to each other bool 
BinarySet::operator==(const BinarySet& rhs) const { if (size_of_set_ != rhs.size_of_set_) { return false; } for (int32_t i = 0; i < binary_set_values_.size(); i++) { if (binary_set_values_[i] != rhs.binary_set_values_[i]) { return false; } } return true; } } // namespace auto_parallel } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/binary_set.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_ #define ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_ #include #include #include #include "oneflow/core/common/hash.h" namespace oneflow { namespace auto_parallel { // log_2_ index only support 32-bit int. Don't know why. // Don't have any other bugs for unsigned int. using BinarySetEntryType = unsigned int; class BinarySet { public: BinarySet() {} explicit BinarySet(int32_t size_of_set); // Initialization void Initialize(int32_t size_of_set); // Clear all the elements in the set void Clear(); // Check if i-th element in this subset bool CheckExistence(int32_t i) const; // Add i-th element into this subset void AddEntry(int32_t i); // Take i-th element out from this subset void DeleteEntry(int32_t i); // Get the union with another subset and store it into u void UnionTo(const BinarySet& bs, BinarySet& u); // If this binary set intersects another one bool IfIntersect(const BinarySet& bs) const; // Get the intersection with another subset and store it into i void IntersectionTo(const BinarySet& bs, BinarySet& i) const; // Count number of elements in this subset int32_t Total() const; // Output all the elements in the subset void Output(std::vector& out) const; // Output all the elements in the subset void QuickOutput(std::vector& out) const; // Add elements of input into this subset void AddEntries(std::vector& in); // If two binary sets are equal to each other bool operator==(const BinarySet& rhs) const; inline int32_t GetSizeOfSet() const { return size_of_set_; }; private: friend struct BinarySetHasher; // binary_set_values_ contains a vector of 64-bit or 32-bit int. // Each bit means whether an entry is in the set std::vector binary_set_values_; int32_t size_of_set_ = -1; // total bits of the entry type in vector binary_set_values_. static constexpr int32_t bit_entry_type_ = 8 * sizeof(BinarySetEntryType); }; struct BinarySetHasher { std::size_t operator()(const BinarySet& bs) const { using std::hash; using std::size_t; size_t h = 0; for (int i = 0; i < bs.binary_set_values_.size(); i++) { h = HashCombine(h, hash()(bs.binary_set_values_[i])); } return h; }; }; } // namespace auto_parallel } // namespace oneflow #endif // ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_ ================================================ FILE: oneflow/core/auto_parallel/boxing_collector.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/auto_parallel/algorithm_util.h" #include "oneflow/core/auto_parallel/boxing_collector.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace { static bool disable_middle_node = false; void DfsSetNdSbp(const std::vector& id2sbp_parallel, int32_t depth, int32_t max_depth, NdSbp& nd_sbp, std::vector& nd_sbp_lists, std::unordered_map& nd_sbp_universe) { if (depth == max_depth) { nd_sbp_universe[nd_sbp] = nd_sbp_lists.size(); nd_sbp_lists.push_back(nd_sbp); } else { for (const auto& sbp_parallel : id2sbp_parallel) { *nd_sbp.mutable_sbp_parallel(depth) = sbp_parallel; DfsSetNdSbp(id2sbp_parallel, depth + 1, max_depth, nd_sbp, nd_sbp_lists, nd_sbp_universe); } } } // Let a nd sbp be consistent with the given hierarchy number Maybe SetNdSbpDim(const NdSbp& nd_sbp, int32_t hierarchy_num) { // Do not need to change if (nd_sbp.sbp_parallel_size() == hierarchy_num) { return nd_sbp; } // (S0, S0) -> S0 if (hierarchy_num == 1) { CHECK_OR_RETURN(Is1dSbp(nd_sbp)) << NdSbpToString(nd_sbp) << " can not be converted to a 1d sbp!"; NdSbp new_sbp; new_sbp.add_sbp_parallel(); *new_sbp.mutable_sbp_parallel(0) = nd_sbp.sbp_parallel(0); return new_sbp; } // S0 -> (S0, S0) CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), 1) << "Illegal nd sbp transform."; NdSbp new_sbp; for (int32_t i = 0; i < hierarchy_num; i++) { new_sbp.add_sbp_parallel(); *new_sbp.mutable_sbp_parallel(i) = nd_sbp.sbp_parallel(0); } return new_sbp; } int32_t TotalNumSplit(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) { int32_t total_num_split = 1; for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); i++) { if (nd_sbp.sbp_parallel(i).has_split_parallel()) { total_num_split *= parallel_desc.hierarchy()->At(i); } } return total_num_split; } // Dealing with 1D sbp to 1D sbp // Specifically, S -> P. 
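// Since a direct 1D S -> P transfer is not supported, a broadcast middle node is inserted below,
// turning the transfer into S -> B -> P (hence the TODO about the large cost).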
Maybe AskSbpCombinationFor1DSbp(const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, std::vector& middle_sbps, int32_t* diag_node_pos) { if (sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) { // Support [4]: P <--> [2, 2]: (P, P) // Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P) if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num() && sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) { return Maybe::Ok(); } if (!sbp_producer.sbp_parallel(0).has_broadcast_parallel()) { // S -> B -> P (Large cost!) // TODO: Please implement S -> P directly. // We do not support [3]: P <--> [2, 2]: (P, P) as well. int32_t hierarchy_size = 0; if (producer_parallel_desc.hierarchy()->elem_cnt() < consumer_parallel_desc.hierarchy()->elem_cnt()) { // The diagonal node uses the parallel description from producer // (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P) *diag_node_pos = 1; hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes(); } else { // The diagonal node uses the parallel description from consumer // S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P) *diag_node_pos = 0; hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes(); } NdSbp broadcast_nd; for (int32_t i = 0; i < hierarchy_size; i++) { broadcast_nd.add_sbp_parallel(); broadcast_nd.mutable_sbp_parallel(i)->mutable_broadcast_parallel(); } middle_sbps.emplace_back(broadcast_nd); } } return Maybe::Ok(); } } // namespace // A constructor with init, designed for pre-stored boxing collector BoxingCollector::BoxingCollector(int32_t max_axis) { CHECK_JUST(Init(max_axis)); } // Construct a boxing collector with given maximum number of axis Maybe BoxingCollector::Init(int32_t max_axis) { // Update environment parameter disable_middle_node = ParseBooleanFromEnv("ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK", false); // Not allowed two-step boxing and disable checking for debugging if (disable_middle_node) { return Maybe::Ok(); } // Set up at least two split for op graph. // For a negative example: Resnet50 only have B, P, S(0) CollectUniverse(max_axis); GenerateNdSbpList(2); GenerateMap1d2nd(); // Get copy cost in lazy mode LazyMode::Guard enable_lazy_mode(true); JUST(GenerateCombination4SamePlacement(3)); JUST(GenerateCombination4DiffHierarchy(this, this)); JUST(GenerateCombination4DiffPlacement(this, this)); init_type_ = int32_t(enable_general_basic_communication || Singleton::Get()->nccl_use_compute_stream()); return Maybe::Ok(); } // Customized initialization with given blob and parallel description Maybe BoxingCollector::Init(const BlobDesc& logical_blob_desc, const ParallelDesc& parallel_desc) { CollectUniverse(logical_blob_desc.shape().NumAxes()); GenerateNdSbpList(parallel_desc.hierarchy()->NumAxes()); // Filter out unsuitable middle nodes before computing minimum cost. 
JUST(FilterNdSbpList4LogicalShape(logical_blob_desc, *parallel_desc.hierarchy())); GenerateMap1d2nd(); // Get copy cost in lazy mode LazyMode::Guard enable_lazy_mode(true); JUST(GenerateCombination4SamePlacement(5, logical_blob_desc, parallel_desc)); init_type_ = int32_t(enable_general_basic_communication || Singleton::Get()->nccl_use_compute_stream()); return Maybe::Ok(); } // Collect Sbp Parallel void BoxingCollector::CollectUniverse(const SbpParallel& sbp) { if (sbp_parallel_universe_.find(sbp) == sbp_parallel_universe_.end()) { int32_t curr_size = sbp_parallel_universe_.size(); sbp_parallel_universe_[sbp] = curr_size; id2sbp_parallel_.push_back(sbp); } } // Find corresponding id for Nd sbp int32_t BoxingCollector::FindId4NdSbp(const NdSbp& nd_sbp) { // Directly search on the nd_sbp_list if (nd_sbp.sbp_parallel_size() == hierarchy_num_) { const auto& it_nd_sbp = nd_sbp_universe_.find(nd_sbp); if (it_nd_sbp != nd_sbp_universe_.end()) { return it_nd_sbp->second; } else { return -1; } } // Find the diagonal node if it could be converted to a 1D sbp if (Is1dSbp(nd_sbp)) { const auto& it_nd_sbp = sbp_parallel_universe_.find(nd_sbp.sbp_parallel(0)); if (it_nd_sbp != sbp_parallel_universe_.end()) { return id_1d_2_nd_[it_nd_sbp->second]; } } // Can not be converted to a 1D sbp or not found in the 1D sbp list return -1; } // Set default Sbp list void BoxingCollector::CollectUniverse(int32_t max_axis) { SbpParallel sbp; sbp.mutable_broadcast_parallel(); CollectUniverse(sbp); for (int32_t axis = 0; axis < max_axis; axis++) { sbp.mutable_split_parallel()->set_axis(axis); CollectUniverse(sbp); } sbp.mutable_partial_sum_parallel(); CollectUniverse(sbp); } // Generate nd sbp list void BoxingCollector::GenerateNdSbpList(int32_t hierarchy_num) { // 1D sbp does not support S->P. But it seems that we do not need to deal with it for now. // And we do not have 3D sbp or higher dimension. hierarchy_num_ = hierarchy_num; // Generate possible nd_sbp lists NdSbp nd_sbp; for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num; dim_sbp++) { nd_sbp.add_sbp_parallel(); } DfsSetNdSbp(id2sbp_parallel_, 0, hierarchy_num, nd_sbp, nd_sbp_lists_, nd_sbp_universe_); } // Generate the map from 1d sbp to 2d sbp void BoxingCollector::GenerateMap1d2nd() { // Number of 1d sbp int32_t m = id2sbp_parallel_.size(); // Generate the id Map from 1d sbp to nd sbp NdSbp nd_sbp; for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) { nd_sbp.add_sbp_parallel(); } id_1d_2_nd_.clear(); id_1d_2_nd_.resize(m, -1); for (int32_t id_1d = 0; id_1d < m; id_1d++) { for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) { *nd_sbp.mutable_sbp_parallel(dim_sbp) = id2sbp_parallel_[id_1d]; } // NOTE: The 2d sbp might be filtered out already. 
const auto& it_ = nd_sbp_universe_.find(nd_sbp); if (it_ != nd_sbp_universe_.end()) { id_1d_2_nd_[id_1d] = it_->second; } } } // Generate the transfer rule for different combinations with the same hierarchy Maybe BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middle_node_num) { // other parameters // NOTE: The performance of this function are all the same with different hierarchy int32_t world_size = GlobalProcessCtx::WorldSize(); Shape hierarchy44({4 * world_size, 4 * world_size}); int32_t virtual_range_size = hierarchy44.elem_cnt(); std::shared_ptr virtual_hierarchy = std::make_shared(hierarchy44); auto parallel_desc = JUST(ParallelDesc::New( "cpu", {"0:0-" + std::to_string(hierarchy44.elem_cnt() - 1)}, virtual_hierarchy)); BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size, virtual_range_size, virtual_range_size, virtual_range_size}, DataType::kInt8, MemoryFormat::kContiguous, /*is_dynamic=*/false); JUST(GenerateCombination4SamePlacement(max_middle_node_num, blob_desc, *parallel_desc)); return Maybe::Ok(); } // Generate the transfer rule for different combinations with the same hierarchy Maybe BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middle_node_num, const BlobDesc& blob_desc, const ParallelDesc& parallel_desc) { // Store the origin transfer cost information int32_t n = nd_sbp_lists_.size(); minimum_copy_cost_.clear(); minimum_copy_cost_.resize(n); middle_nodes_.clear(); middle_nodes_.resize(n); for (int32_t i = 0; i < n; i++) { minimum_copy_cost_[i].resize(n); middle_nodes_[i].resize(n); for (int32_t j = 0; j < n; j++) { minimum_copy_cost_[i][j] = JUST(ComputeLazyCopyCostBetweenNdSbp( nd_sbp_lists_[i], nd_sbp_lists_[j], blob_desc, parallel_desc, parallel_desc, /*requires_same_sbp=*/false)); } } auto NotMiddleNode = [&](int32_t i, int32_t j, int32_t k, int32_t middle_node_num_ik) -> bool { // Not allow i -> i -> j or i -> j -> j. if (k == j || k == i) { return true; } // We add middle nodes one by one // Thus, we allow multiple nodes from i to k but we only accept 1 step from k to j. // i -> ? -> k -> j if (middle_nodes_[k][j].size() > 0) { return true; } // To avoid multiple counting and bugs, the number of middle nodes between i and k // must be exactly middle_node_num_ik, which is (middle_node_num - 1) if (middle_node_num_ik) { if (middle_nodes_[i][k].size() == 0 || middle_nodes_[i][k][0].size() != middle_node_num_ik) { return true; } } else { if (middle_nodes_[i][k].size() > 0) { return true; } } return false; }; for (int32_t middle_node_num = 1; middle_node_num <= max_middle_node_num; middle_node_num++) { int32_t middle_node_num_ik = middle_node_num - 1; for (int32_t i = 0; i < n; i++) { for (int32_t j = 0; j < n; j++) { if (minimum_copy_cost_[i][j] < GetValidMaxCopyCost()) { continue; } // Compute the smallest transfer cost // k is the middle node, i -> k -> j for (int32_t k = 0; k < n; k++) { if (NotMiddleNode(i, j, k, middle_node_num_ik)) { continue; } double curr_copy_cost = minimum_copy_cost_[i][k] + minimum_copy_cost_[k][j]; if (curr_copy_cost < minimum_copy_cost_[i][j]) { minimum_copy_cost_[i][j] = curr_copy_cost; } } // If the minimum copy cost remains infinity, adding one middle node does not make it. if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) { continue; } // Find those middle nodes for (int32_t k = 0; k < n; k++) { if (NotMiddleNode(i, j, k, middle_node_num_ik)) { continue; } // Now we start to judge if the edge have a minimum cost // It needs to be "<=" since we have 0 cost. 
// Using "<" would give no middle nodes from (B, B) to any other nd sbp. if (minimum_copy_cost_[i][k] + minimum_copy_cost_[k][j] <= minimum_copy_cost_[i][j] * 1.0000001) { // i -> ? -> k if (middle_nodes_[i][k].size() > 0) { // We have multiple choices going from i to k for (const auto& middle_node_ik : middle_nodes_[i][k]) { middle_nodes_[i][j].push_back(middle_node_ik); middle_nodes_[i][j][middle_nodes_[i][j].size() - 1].push_back(k); } } else { // We only need one middle node k to reach j from i middle_nodes_[i][j].push_back({k}); } } } CHECK_OR_RETURN(middle_nodes_[i][j].size() > 0) << "No middle nodes given from " << NdSbpToString(nd_sbp_lists_[i]) << " to " << NdSbpToString(nd_sbp_lists_[j]) << " in boxing collector"; } } } return Maybe::Ok(); } // Generate the transfer rule for different combinations with different hierarchies on the same // placement Maybe BoxingCollector::GenerateCombination4DiffHierarchy( BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) { // Store the boxing collector pointer // Search the path that contains one of the diagonal sbp int32_t n = nd_sbp_lists_.size(); diag_node_diff_hierarchy_.clear(); diag_node_diff_hierarchy_.resize(n); for (int32_t i = 0; i < n; i++) { diag_node_diff_hierarchy_[i].resize(n); for (int32_t j = 0; j < n; j++) { JUST(Generate1Combination4DiffHierarchy(i, j, boxing_collector_producer, boxing_collector_consumer, diag_node_diff_hierarchy_[i][j])); } } return Maybe::Ok(); } // Generate the transfer rule for different combinations with different placements Maybe BoxingCollector::GenerateCombination4DiffPlacement( BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) { // Virtual parallel and blob description int32_t world_size = GlobalProcessCtx::WorldSize(); int32_t virtual_range_size = 4 * world_size * (4 * world_size + 1); BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size, virtual_range_size, virtual_range_size, virtual_range_size}, DataType::kInt8, MemoryFormat::kContiguous, /*is_dynamic=*/false); // Virtual placements before transfer Shape in_hierarchy44({4 * world_size + 1, 4 * world_size}); std::shared_ptr in_hierarchy = std::make_shared(in_hierarchy44); auto in_parallel_desc = JUST(ParallelDesc::New( "cpu", {"0:0-" + std::to_string(in_hierarchy44.elem_cnt() - 1)}, in_hierarchy)); // Virtual placements after transfer Shape out_hierarchy44({4 * world_size, 4 * world_size}); std::shared_ptr out_hierarchy = std::make_shared(out_hierarchy44); auto out_parallel_desc = JUST(ParallelDesc::New( "cpu", {"0:0-" + std::to_string(out_hierarchy44.elem_cnt() - 1)}, out_hierarchy)); JUST(GenerateCombination4DiffPlacement(boxing_collector_producer, boxing_collector_consumer, blob_desc, *in_parallel_desc, *out_parallel_desc)); return Maybe::Ok(); } // The cost for transferring a 1D sbp between different placements Maybe BoxingCollector::ComputeCostFor1DSbpDiffPlacement( const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, std::vector>& cost_4_diff_placement) { // Number of 1d sbp int32_t m = id2sbp_parallel_.size(); // Compute the cost while transferring a 1D sbp between different placements cost_4_diff_placement.clear(); cost_4_diff_placement.resize(m); for (int32_t id_1d_producer = 0; id_1d_producer < m; id_1d_producer++) { cost_4_diff_placement[id_1d_producer].resize(m, GetMaxVal()); int32_t diag_producer = id_1d_2_nd_[id_1d_producer]; if (diag_producer < 0) { continue; } for (int32_t 
id_1d_consumer = 0; id_1d_consumer < m; id_1d_consumer++) { int32_t diag_consumer = id_1d_2_nd_[id_1d_consumer]; if (diag_consumer < 0) { continue; } cost_4_diff_placement[id_1d_producer][id_1d_consumer] = JUST(ComputeLazyCopyCostBetweenNdSbp( nd_sbp_lists_[diag_producer], nd_sbp_lists_[diag_consumer], blob_desc, in_parallel_desc, out_parallel_desc, false)); } } return Maybe::Ok(); } // Generate the transfer rule for different combinations with different placements Maybe BoxingCollector::GenerateCombination4DiffPlacement( BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer, const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc) { // The cost for transferring a 1D sbp between different placements std::vector> cost_4_diff_placement; // Compute the cost while transferring a 1D sbp between different placements JUST(ComputeCostFor1DSbpDiffPlacement(blob_desc, in_parallel_desc, out_parallel_desc, cost_4_diff_placement)); // Search the path that contains two of the diagonal sbp int32_t n = nd_sbp_lists_.size(); diag_node_diff_placement_.clear(); diag_node_diff_placement_.resize(n); for (int32_t i = 0; i < n; i++) { diag_node_diff_placement_[i].resize(n); for (int32_t j = 0; j < n; j++) { JUST(Generate1Combination4DiffPlacement(i, j, boxing_collector_producer, boxing_collector_consumer, cost_4_diff_placement, diag_node_diff_placement_[i][j])); } } return Maybe::Ok(); } // Print the cost and middle nodes void BoxingCollector::PrintBoxingTables() { if (GlobalProcessCtx::Rank() == 0) { std::cout << "===================minimum copy cost==================" << std::endl; // other parameters // To be noted that the performance of this function are all the same with different hierarchy Shape hierarchy44({4, 4}); std::shared_ptr in_hierarchy = std::make_shared(hierarchy44); double logical_blob_size = 1024.0; int32_t n = nd_sbp_lists_.size(); // Print the origin copy cost table std::cout << "Cost\t"; for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << "\t"; } std::cout << std::endl; for (int32_t i = 0; i < n; i++) { std::cout << NdSbpToString(nd_sbp_lists_[i]) << "\t"; for (int32_t j = 0; j < n; j++) { if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) { std::cout << "X\t"; } else { std::cout << minimum_copy_cost_[i][j] << "\t"; } } std::cout << std::endl; } std::cout << std::endl; std::cout << "Original Copy Cost" << std::endl; std::cout << "logical blob size: " << logical_blob_size << std::endl; std::cout << "hierarchy: " << *in_hierarchy << std::endl; std::cout << "============================middle nodes===========================" << std::endl; // Print the middle nodes std::cout << "Middle Sbp\t"; for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << "\t"; } std::cout << std::endl; for (int32_t i = 0; i < n; i++) { std::cout << NdSbpToString(nd_sbp_lists_[i]) << "\t"; for (int32_t j = 0; j < n; j++) { if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) { std::cout << "X"; } else if (middle_nodes_[i][j].size() > 0) { for (int32_t k = 0; k < middle_nodes_[i][j].size(); k++) { std::cout << NdSbpToString(nd_sbp_lists_[middle_nodes_[i][j][k][0]]); for (int32_t l = 1; l < middle_nodes_[i][j][k].size(); l++) { std::cout << "->" << NdSbpToString(nd_sbp_lists_[middle_nodes_[i][j][k][l]]); } std::cout << "; "; } } std::cout << "\t"; } std::cout << std::endl; } std::cout << std::endl; std::cout << "Minimum Copy Cost after second search" << std::endl; std::cout << 
"logical blob size: " << logical_blob_size << std::endl; std::cout << "hierarchy: " << *in_hierarchy << std::endl; std::cout << "====================middle nodes for different placement====================" << std::endl; std::cout << "Middle nodes for different placement\t"; for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << "\t"; } std::cout << std::endl; for (int32_t i = 0; i < n; i++) { std::cout << NdSbpToString(nd_sbp_lists_[i]) << "\t"; for (int32_t j = 0; j < n; j++) { if (diag_node_diff_placement_[i][j].size() > 0) { for (int32_t k = 0; k < diag_node_diff_placement_[i][j].size(); k++) { std::cout << "[" << NdSbpToString(nd_sbp_lists_[diag_node_diff_placement_[i][j][k][0]]) << ", " << NdSbpToString(nd_sbp_lists_[diag_node_diff_placement_[i][j][k][1]]) << "]; "; } } std::cout << "\t"; } std::cout << std::endl; } std::cout << "====================middle nodes for different hierarchy====================" << std::endl; std::cout << "Middle nodes for different hierarchy\t"; for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << "\t"; } std::cout << std::endl; for (int32_t i = 0; i < n; i++) { std::cout << NdSbpToString(nd_sbp_lists_[i]) << "\t"; for (int32_t j = 0; j < n; j++) { if (diag_node_diff_hierarchy_[i][j].size() > 0) { for (int32_t k = 0; k < diag_node_diff_hierarchy_[i][j].size(); k++) { std::cout << NdSbpToString(nd_sbp_lists_[diag_node_diff_hierarchy_[i][j][k][0]]) << "; "; } } std::cout << "\t"; } std::cout << std::endl; } std::cout << "================================================" << std::endl; } } // Ask if the boxing algorithm accepts the current sbp combination Maybe BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool is_customized, std::vector& middle_sbps, int32_t* diag_node_pos, bool compute_cost) { middle_sbps.clear(); // Not allowed two-step boxing and disable checking for debugging if (disable_middle_node) { return Maybe::Ok(); } if (producer_parallel_desc == consumer_parallel_desc && sbp_producer == sbp_consumer) { return Maybe::Ok(); } // Dealing with 1D sbp to 1D sbp if (Is1dSbp(sbp_producer) && Is1dSbp(sbp_consumer)) { JUST(AskSbpCombinationFor1DSbp(sbp_producer, sbp_consumer, producer_parallel_desc, consumer_parallel_desc, middle_sbps, diag_node_pos)); // No middle nodes for the other 1d-sbp combinations return Maybe::Ok(); } #if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU) // Use a general basic communication if no P in the consumer if (((Singleton::Get()->nccl_use_compute_stream() && producer_parallel_desc == consumer_parallel_desc) || enable_general_basic_communication) && (!NdSbpHasPartialParallel(sbp_consumer)) && producer_parallel_desc.device_type() == consumer_parallel_desc.device_type() && producer_parallel_desc.device_type() != DeviceType::kCPU) { if (NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) { // (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer // Directly applying general basic communication would have O(n^2) time complexity for P->B // Using two-step transfer would reduce it to a linear cost JUST(AskSbpCombination4GeneralBasicCommunication( sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, middle_sbps, diag_node_pos)); } // Otherwise, one-step transfer return Maybe::Ok(); } #endif // WITH_CUDA || WITH_NPU || defined(WITH_MLU) 
if (JUST(ComputeLazyCopyCostBetweenNdSbp(sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, /*requires_same_sbp=*/false)) < GetValidMaxCopyCost()) { return Maybe::Ok(); } else { int32_t require_init_type = int32_t(enable_general_basic_communication || Singleton::Get()->nccl_use_compute_stream()); if (init_type_ != require_init_type) { // We assemble the boxing table from S(0) to S(5). // Those splitting in higher axes are considered in the customized boxing. constexpr int32_t kRegularMaxSplitAxes = 6; JUST(Init(kRegularMaxSplitAxes)); } } // Middle nodes algorithm supports transfer for different machines or devices or hierarchies if (producer_parallel_desc != consumer_parallel_desc) { JUST(AskSbpCombination4DiffPlacement(sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, is_customized, middle_sbps, diag_node_pos, compute_cost)); return Maybe::Ok(); } // Transfer for the same machines, devices and hierarchy. if (sbp_producer == sbp_consumer) { return Maybe::Ok(); } const auto& parallel_hierarchy = producer_parallel_desc.hierarchy(); *diag_node_pos = 0; // Dealing with nD sbp, n>2 if (parallel_hierarchy->NumAxes() > 2) { CHECK_OR_RETURN(compute_cost) << "Boxing does not support a hierarchy with dimension greater than 2"; return Maybe::Ok(); } // Ask for sbp combination with the same 2-D hierarchy and placement JUST(AskSbpCombination4Same2DPlacement(sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, is_customized, middle_sbps, diag_node_pos, compute_cost)); return Maybe::Ok(); } // Ask for sbp combination with the same 2-D hierarchy and placement Maybe BoxingCollector::AskSbpCombination4Same2DPlacement( const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool is_customized, std::vector& middle_sbps, int32_t* diag_node_pos, bool compute_cost) { CHECK_OR_RETURN(producer_parallel_desc == consumer_parallel_desc) << "Producer and consumer have different placements, Please use AskSbpCombination directly"; middle_sbps.clear(); // Find the 2D sbp id int32_t i = FindId4NdSbp(sbp_producer); int32_t j = FindId4NdSbp(sbp_consumer); // Dealing with 2D sbp if (i >= 0 && j >= 0) { // Such combination can not be support with limited middle nodes if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) { CHECK_OR_RETURN(compute_cost) << "Boxing does not support " << NdSbpToString(sbp_producer) << " -> " << NdSbpToString(sbp_consumer) << " for 2D sbp"; return Maybe::Ok(); } // Current design can deal with such combination. 
Do not need to insert middle nodes if (middle_nodes_[i][j].size() == 0) { return Maybe::Ok(); } // Find a list of middle nodes with minimum storage int32_t min_k = -1; double min_cost = GetValidMaxCopyCost(); for (int32_t k = 0; k < middle_nodes_[i][j].size(); k++) { double curr_cost = 0.0; for (int32_t middle_sbp_id : middle_nodes_[i][j][k]) { Shape logical_shape = logical_blob_desc.shape(); // Storage4NdSbp would modify logical_shape2 as well curr_cost += Storage4NdSbp(nd_sbp_lists_[middle_sbp_id], logical_shape, *producer_parallel_desc.hierarchy()); if (curr_cost > GetValidMaxCopyCost()) { break; } } // store k if renew minimum cost if (curr_cost < min_cost) { min_k = k; min_cost = curr_cost; } } // If we found a list of middle nodes with current boxing collector int32_t producer_hierarchy_num = producer_parallel_desc.hierarchy()->NumAxes(); if (min_k >= 0) { for (int32_t middle_sbp_id : middle_nodes_[i][j][min_k]) { middle_sbps.emplace_back( *JUST(SetNdSbpDim(nd_sbp_lists_[middle_sbp_id], producer_hierarchy_num))); } return Maybe::Ok(); } } // // If we can not found a list of middle nodes even after customized boxing collector if (is_customized) { CHECK_OR_RETURN(compute_cost) << "Boxing does not support " << NdSbpToString(sbp_producer) << " -> " << NdSbpToString(sbp_consumer) << " for Shape: " << logical_blob_desc.shape(); return Maybe::Ok(); } // Customized boxing collector and try the algorithm again BoxingCollector customized_boxing_collector; JUST(customized_boxing_collector.Init(logical_blob_desc, producer_parallel_desc)); JUST(customized_boxing_collector.AskSbpCombination4Same2DPlacement( sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, /*is_customized=*/true, middle_sbps, diag_node_pos, compute_cost)); return Maybe::Ok(); } // Ask for sbp combination with different hierarchies and placements Maybe BoxingCollector::AskSbpCombination4DiffPlacement( const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool is_customized, std::vector& middle_sbps, int32_t* diag_node_pos, bool compute_cost) { middle_sbps.clear(); // Find the 2D sbp id int32_t i = FindId4NdSbp(sbp_producer); int32_t j = FindId4NdSbp(sbp_consumer); // Different placements: [2, 3] vs 5, or [3, 2] vs [2, 2], or cpu vs cuda // Different hierarchies: [2, 3] vs 5, or [4, 3] vs [6, 2] bool same_placement = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); // Dealing with 2D sbp if (i >= 0 && j >= 0) { // Pure copy between machines and devices if (i == j && (*producer_parallel_desc.hierarchy() == *consumer_parallel_desc.hierarchy())) { return Maybe::Ok(); } if (same_placement) { // Different hierarchies CHECK_OR_RETURN(diag_node_diff_hierarchy_.size() > 0) << "Have not initialized the combination table for different hierarchies yet! " "Please run JUST(GenerateCombination4DiffHierarchy(this, this)); " "before Asking sbp combination for different parallel description."; if (JUST(Ask1Combination4DiffPlacement( sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, is_customized, middle_sbps, diag_node_pos, compute_cost, this, this, diag_node_diff_hierarchy_[i][j]))) { return Maybe::Ok(); } } else { // Different placements CHECK_OR_RETURN(diag_node_diff_placement_.size() > 0) << "Have not initialized the combination table for different hierarchies yet! 
" "Please run JUST(GenerateCombination4DiffPlacement(this, this)); " "before Asking sbp combination for different parallel description."; if (JUST(Ask1Combination4DiffPlacement( sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, is_customized, middle_sbps, diag_node_pos, compute_cost, this, this, diag_node_diff_placement_[i][j]))) { return Maybe::Ok(); } } } // Customized boxing collector and try the algorithm again if (is_customized) { CHECK_OR_RETURN(compute_cost) << "Boxing does not support " << NdSbpToString(sbp_producer) << "[hierarchy: " << *producer_parallel_desc.hierarchy() << "] -> " << NdSbpToString(sbp_consumer) << "[hierarchy: " << *consumer_parallel_desc.hierarchy() << "] for blob shape: " << logical_blob_desc.shape(); return Maybe::Ok(); } // Customize boxing collector for producer BoxingCollector customized_boxing_collector_producer; JUST(customized_boxing_collector_producer.Init(logical_blob_desc, producer_parallel_desc)); // Customize boxing collector for consumer BoxingCollector customized_boxing_collector_consumer; JUST(customized_boxing_collector_consumer.Init(logical_blob_desc, consumer_parallel_desc)); std::vector> diag_nodes; // Generate the combination table for different hierarchies or placements if (same_placement) { JUST(customized_boxing_collector_producer.Generate1Combination4DiffHierarchy( customized_boxing_collector_producer.FindId4NdSbp(sbp_producer), customized_boxing_collector_consumer.FindId4NdSbp(sbp_consumer), &customized_boxing_collector_producer, &customized_boxing_collector_consumer, diag_nodes)); } else { // Compute the cost while transferring a 1D sbp between different placements std::vector> cost_4_diff_placement; JUST(ComputeCostFor1DSbpDiffPlacement(logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, cost_4_diff_placement)); JUST(customized_boxing_collector_producer.Generate1Combination4DiffPlacement( customized_boxing_collector_producer.FindId4NdSbp(sbp_producer), customized_boxing_collector_consumer.FindId4NdSbp(sbp_consumer), &customized_boxing_collector_producer, &customized_boxing_collector_consumer, cost_4_diff_placement, diag_nodes)); } JUST(customized_boxing_collector_producer.Ask1Combination4DiffPlacement( sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, /*is_customized=*/true, middle_sbps, diag_node_pos, compute_cost, &customized_boxing_collector_producer, &customized_boxing_collector_consumer, diag_nodes)); return Maybe::Ok(); } // Generate the transfer rule for one combination with different hierarchies on the same // placement. id_producer -> id_consumer. Maybe BoxingCollector::Generate1Combination4DiffHierarchy( int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer, std::vector>& diag_nodes) { // Number of 1d sbp int32_t m = id2sbp_parallel_.size(); // Search the path that contains one of the diagonal sbp // minimum number of node int32_t min_path_length = 100; // minimum cost double min_cost = GetValidMaxCopyCost(); for (int32_t id_1d = 0; id_1d < m; id_1d++) { // We do not support [2, 3]: (S0, S1) -> [6]: S0 for a tensor with shape (14, 21) // Thus, the diagonal node should suit both the hierarchies. 
int32_t diag_producer = boxing_collector_producer->id_1d_2_nd_[id_1d]; if (diag_producer < 0) { continue; } int32_t diag_consumer = boxing_collector_consumer->id_1d_2_nd_[id_1d]; if (diag_consumer < 0) { continue; } // Find the path with minimum number of nodes int32_t path_length = 0; // Transfer from id_producer to id_2d if (boxing_collector_producer->middle_nodes_[id_producer][diag_producer].size() > 0) { path_length += boxing_collector_producer->middle_nodes_[id_producer][diag_producer][0].size() + 1; } else if (id_producer != diag_producer) { path_length++; } // Transfer from id_2d to id_consumer if (boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer].size() > 0) { path_length += boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer][0].size() + 1; } else if (diag_consumer != id_consumer) { path_length++; } // Pick the path with minimum copy cost if (path_length <= min_path_length) { double curr_cost = boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer] + boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer]; min_path_length = path_length; // Find a candidate with small cost if (curr_cost < min_cost * kFloatDeviationPlus) { // Find a smaller cost, clear the previous path. if (curr_cost < min_cost * kFloatDeviationMinus) { min_cost = curr_cost; diag_nodes.clear(); } // Add the current diagonal node // Asymmetry happens here. We can only store one side of the diagonal node. // We do not store diag_consumer diag_nodes.push_back({diag_producer, diag_consumer}); } } } return Maybe::Ok(); } // Ask for one combination with different hierarchies and placements Maybe BoxingCollector::Ask1Combination4DiffPlacement( const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool is_customized, std::vector& middle_sbps, int32_t* diag_node_pos, bool compute_cost, BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer, const std::vector>& diag_nodes) { // Pick the path with minimum storage for the diagonal node int32_t id_producer = boxing_collector_producer->FindId4NdSbp(sbp_producer); if (id_producer < 0) { CHECK_OR_RETURN(compute_cost) << "Source data with shape " << logical_blob_desc.shape() << " has an invalid sbp " << NdSbpToString(sbp_producer); return false; } int32_t id_consumer = boxing_collector_consumer->FindId4NdSbp(sbp_consumer); if (id_consumer < 0) { CHECK_OR_RETURN(compute_cost) << "Target data with shape " << logical_blob_desc.shape() << " has an invalid sbp " << NdSbpToString(sbp_consumer); return false; } middle_sbps.clear(); // NOTE: For simplicity, We do not dig into those storage cost for the other middle nodes at // this moment. double min_cost = GetValidMaxCopyCost(); int32_t producer_hierarchy_num_axes = producer_parallel_desc.hierarchy()->NumAxes(); int32_t consumer_hierarchy_num_axes = consumer_parallel_desc.hierarchy()->NumAxes(); int32_t min_diag_producer = -1, min_diag_consumer = -1; for (const auto& diag_pair : diag_nodes) { Shape logical_shape = logical_blob_desc.shape(); // We do not check whether such shape is valid under two side of the sbp list in the // middle nodes algorithm. Thus, we need to check them here. 
double curr_cost = Storage4NdSbp(*JUST(SetNdSbpDim(boxing_collector_producer->nd_sbp_lists_[diag_pair[0]], producer_hierarchy_num_axes)), logical_shape, *producer_parallel_desc.hierarchy()); // Check the shape for both producer and consumer. logical_shape = logical_blob_desc.shape(); curr_cost += Storage4NdSbp(*JUST(SetNdSbpDim(boxing_collector_consumer->nd_sbp_lists_[diag_pair[1]], consumer_hierarchy_num_axes)), logical_shape, *consumer_parallel_desc.hierarchy()); if (curr_cost < min_cost) { min_cost = curr_cost; min_diag_producer = diag_pair[0]; min_diag_consumer = diag_pair[1]; } } // Different placements: [2, 3] vs 5, or [3, 2] vs [2, 2], or cpu vs cuda // Different hierarchies: [2, 3] vs 5, or [4, 3] vs [6, 2] bool diff_placement = !producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); // If we found a diagonal middle node with current boxing collector if (min_diag_producer >= 0) { std::vector middle_sbps_buffer; // Find the middle nodes between the producer and the diagonal node if (id_producer != min_diag_producer) { JUST(boxing_collector_producer->AskSbpCombination( sbp_producer, boxing_collector_producer->nd_sbp_lists_[min_diag_producer], logical_blob_desc, producer_parallel_desc, producer_parallel_desc, /*is_customized=*/false, middle_sbps_buffer, diag_node_pos, compute_cost)); // Add the path into middle_sbps for (auto& middle_sbp : middle_sbps_buffer) { middle_sbps.emplace_back(*JUST(SetNdSbpDim(middle_sbp, producer_hierarchy_num_axes))); } // If different placement, // or the same placement but with 2D hierarchies // For example: Oneflow supports [6]: (S0) -> [3, 2]: (S0, S1) // but does not support [2, 3]: (S0, S0) -> [3, 2]: (S0, S1) if (diff_placement || producer_hierarchy_num_axes > 1) { middle_sbps.emplace_back( *JUST(SetNdSbpDim(boxing_collector_producer->nd_sbp_lists_[min_diag_producer], producer_hierarchy_num_axes))); } } // If we do not have middle nodes on the consumer side *diag_node_pos = middle_sbps.size(); // Find the middle nodes between the diagonal node and the consumer if (id_consumer != min_diag_consumer) { JUST(boxing_collector_consumer->AskSbpCombination( boxing_collector_consumer->nd_sbp_lists_[min_diag_consumer], sbp_consumer, logical_blob_desc, consumer_parallel_desc, consumer_parallel_desc, /*is_customized=*/false, middle_sbps_buffer, diag_node_pos, compute_cost)); // Set the diagonal node position and stop using it as buffer *diag_node_pos = middle_sbps.size(); // If different placement if (diff_placement || consumer_hierarchy_num_axes > 1) { middle_sbps.emplace_back( *JUST(SetNdSbpDim(boxing_collector_consumer->nd_sbp_lists_[min_diag_consumer], consumer_hierarchy_num_axes))); } // Add the path into middle_sbps for (auto& middle_sbp : middle_sbps_buffer) { middle_sbps.emplace_back(*JUST(SetNdSbpDim(middle_sbp, consumer_hierarchy_num_axes))); } } return true; } return false; } // Generate the transfer rule for one combination with different placements // id_producer -> id_consumer. 
Maybe BoxingCollector::Generate1Combination4DiffPlacement( int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer, const std::vector>& cost_4_diff_placement, std::vector>& diag_nodes) { // Number of 1d sbp int32_t m = id2sbp_parallel_.size(); // minimum number of node int32_t min_path_length = 100; // minimum cost double min_cost = GetValidMaxCopyCost(); // Search the path that contains two of the diagonal sbp // From the producer to the first diagonal node for (int32_t id_1d_producer = 0; id_1d_producer < m; id_1d_producer++) { // We do not support [2, 3]: (S0, S1) -> [6]: S0 for a tensor with shape (14, 21) // Thus, the diagonal node should suit both the hierarchies. int32_t diag_producer = boxing_collector_producer->id_1d_2_nd_[id_1d_producer]; if (diag_producer < 0 || boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer] > GetValidMaxCopyCost()) { continue; } // Find the path with minimum number of nodes int32_t path_length = 0; // Transfer from id_producer to diag_producer if (boxing_collector_producer->middle_nodes_[id_producer][diag_producer].size() > 0) { path_length += boxing_collector_producer->middle_nodes_[id_producer][diag_producer][0].size() + 1; } else if (id_producer != diag_producer) { path_length++; } // pruning if (path_length > min_path_length) { continue; } // From the second diagonal node to the consumer for (int32_t id_1d_consumer = 0; id_1d_consumer < m; id_1d_consumer++) { int32_t diag_consumer = boxing_collector_consumer->id_1d_2_nd_[id_1d_consumer]; // The diagonal sbp is not supported or no paths exist from the diagonal sbp to the // consumer or between the two diagonal sbps. if (diag_consumer < 0 || boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer] > GetValidMaxCopyCost() || cost_4_diff_placement[id_1d_producer][id_1d_consumer] > GetValidMaxCopyCost()) { continue; } // Transfer from diag_consumer to id_consumer int32_t curr_path_length = path_length; if (boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer].size() > 0) { curr_path_length += boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer][0].size() + 1; } else if (diag_consumer != id_consumer) { curr_path_length++; } // Pick the path with minimum copy cost if (curr_path_length <= min_path_length) { double curr_cost = boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer] + cost_4_diff_placement[id_1d_producer][id_1d_consumer] + boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer]; min_path_length = curr_path_length; // Find a candidate with small cost if (curr_cost < min_cost * 1.0000001) { // Find a smaller cost, clear the previous path. if (curr_cost < min_cost * 0.9999999) { min_cost = curr_cost; diag_nodes.clear(); } // Add the current diagonal node // Asymmetry happens here. We can only store one side of the diagonal node. 
// We do not store diag_consumer diag_nodes.push_back({diag_producer, diag_consumer}); } } } } return Maybe::Ok(); } // Filter nd sbp from nd_sbp_lists_ with given logical shape Maybe BoxingCollector::FilterNdSbpList4LogicalShape(const BlobDesc& logical_blob_desc, const Shape& parallel_hierarchy) { for (int32_t middle_sbp_id = nd_sbp_lists_.size() - 1; middle_sbp_id >= 0; middle_sbp_id--) { Shape logical_shape = logical_blob_desc.shape(); if (JUST(FilterNdSbpByLogicalShape(nd_sbp_lists_[middle_sbp_id], logical_shape, parallel_hierarchy))) { // Change the value before erasing // This might be true: nd_sbp_lists_.size() - 1 == middle_sbp_id nd_sbp_universe_[nd_sbp_lists_[nd_sbp_lists_.size() - 1]] = middle_sbp_id; nd_sbp_universe_.erase(nd_sbp_lists_[middle_sbp_id]); nd_sbp_lists_[middle_sbp_id] = nd_sbp_lists_[nd_sbp_lists_.size() - 1]; nd_sbp_lists_.pop_back(); } } return Maybe::Ok(); } // Ask for sbp combination for general basic communication Maybe BoxingCollector::AskSbpCombination4GeneralBasicCommunication( const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, std::vector& middle_sbps, int32_t* diag_node_pos) { // (P, X) -> (B, X) || (X , P) -> (X, B), X is any SBP // One step transfer, at most 50% reduction in the transfer cost, do not use middle nodes if (producer_parallel_desc == consumer_parallel_desc && producer_parallel_desc.hierarchy()->NumAxes() == 2 && (sbp_producer.sbp_parallel(0) == sbp_consumer.sbp_parallel(0) || sbp_producer.sbp_parallel(1) == sbp_consumer.sbp_parallel(1))) { return Maybe::Ok(); } // Not enough gain in transfer cost, do not use middle nodes int32_t partial_ratio4producer = PartialRatio4Producer(sbp_producer, producer_parallel_desc); int32_t broadcast_ratio4consumer = BroadcastRatio4Consumer(sbp_consumer, consumer_parallel_desc); if (2 * (partial_ratio4producer + broadcast_ratio4consumer) >= partial_ratio4producer * broadcast_ratio4consumer) { return Maybe::Ok(); } bool close2producer = true; if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()) { // Get close to the one with more splits close2producer = TotalNumSplit(sbp_producer, producer_parallel_desc) > TotalNumSplit(sbp_consumer, consumer_parallel_desc); } else { // Get close to the one with more machines close2producer = producer_parallel_desc.parallel_num() > consumer_parallel_desc.parallel_num(); } // Get the contiguous sbp if (close2producer) { JUST(AskCloseAllSplitSbp(sbp_producer, producer_parallel_desc, logical_blob_desc, middle_sbps)); *diag_node_pos = 1; } else { JUST(AskCloseAllSplitSbp(sbp_consumer, consumer_parallel_desc, logical_blob_desc, middle_sbps)); *diag_node_pos = 0; } return Maybe::Ok(); } // Ask for a all-split sbp which is close to the original one Maybe BoxingCollector::AskCloseAllSplitSbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const BlobDesc& logical_blob_desc, std::vector& middle_sbps) { Shape remain_shape = logical_blob_desc.shape(); Shape rest_split_shape = logical_blob_desc.shape(); int32_t dim_shape = remain_shape.NumAxes(); // Initialize the remains and splitting // logical_blob_desc.shape() == remain_shape .* rest_split_shape; for (int32_t i = 0; i < dim_shape; i++) { rest_split_shape.Set(i, 1); } for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) { const auto& sbp = nd_sbp.sbp_parallel(sbp_id); if (sbp.has_split_parallel()) { int32_t axis = sbp.split_parallel().axis(); int32_t 
split_num = parallel_desc.hierarchy()->At(sbp_id); remain_shape.Set(axis, remain_shape.At(axis) / split_num); rest_split_shape.Set(axis, rest_split_shape.At(axis) * split_num); } } // Get the contiguous sbp NdSbp new_sbp = nd_sbp; for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) { const auto& sbp = nd_sbp.sbp_parallel(sbp_id); int32_t split_num = parallel_desc.hierarchy()->At(sbp_id); if (sbp.has_split_parallel()) { int32_t axis = sbp.split_parallel().axis(); // split shape is the total splitting number starting from sbp_id to the end rest_split_shape.Set(axis, rest_split_shape.At(axis) / split_num); } else { // change P or B to S(axis) int32_t axis = -1; // 4096 is large enough, we might not have that much devices int32_t min_split_num = 4096; // We need to pick a suitable axis for (int32_t i = 0; i < remain_shape.NumAxes(); i++) { if (remain_shape.At(i) % split_num == 0) { if (rest_split_shape.At(i) < min_split_num) { // Pick the axis with smallest splitting number among the rest of the sbp min_split_num = rest_split_shape.At(i); axis = i; } } } // P, B -> S(axis) if (axis >= 0) { new_sbp.mutable_sbp_parallel(sbp_id)->mutable_split_parallel()->set_axis(axis); remain_shape.Set(axis, remain_shape.At(axis) / split_num); } else { // Can not find a suitable contiguous sbp return Maybe::Ok(); } } } // Add the new sbp into the middle node lists middle_sbps.emplace_back(new_sbp); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/boxing_collector.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_ #define ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_ #include "oneflow/core/common/hash_container.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/framework/sbp_infer_util.h" namespace oneflow { class BoxingCollector final { public: BoxingCollector() = default; ~BoxingCollector() = default; // A constructor with init, designed for non-customized boxing collector BoxingCollector(int32_t max_axis); // Set default Sbp list void CollectUniverse(int32_t max_axis); // Construct a boxing collector with given maximum number of axis Maybe Init(int32_t max_axis); // Init with given blob description Maybe Init(const BlobDesc& logical_blob_desc, const ParallelDesc& parallel_desc); // Generate nd sbp list void GenerateNdSbpList(int32_t hierarchy_num); // Generate the map from 1d sbp to 2d sbp void GenerateMap1d2nd(); // Generate the transfer rule for different combinations with the same hierarchy Maybe GenerateCombination4SamePlacement(int32_t max_middle_node_num); Maybe GenerateCombination4SamePlacement(int32_t max_middle_node_num, const BlobDesc& blob_desc, const ParallelDesc& parallel_desc); // Generate the transfer rule for different combinations with different hierarchies // on the same placement Maybe GenerateCombination4DiffHierarchy(BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer); // Generate the transfer rule for different combinations with different placements Maybe GenerateCombination4DiffPlacement(BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer); Maybe GenerateCombination4DiffPlacement(BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer, const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc); // Print the cost and middle nodes void PrintBoxingTables(); // Ask if the boxing algorithm accepts the current sbp combination // If is_customized is true and we can not find a middle node list with // reasonable cost, error occurs. // If compute_cost is true, then no error occur even if no suitable middle nodes paths found. // For different placements, we would return a diagonal node. // Before this diagonal node (< *diag_node_pos), we use the parallel description of the producer. // After this diagonal node (>= *diag_node_pos), we use the parallel description of the consumer. 
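  // A hedged usage sketch (assumed caller code, not taken from the sources), showing how the
  // output parameters are typically consumed; every identifier below appears in the declaration
  // that follows:
  //
  //   std::vector<NdSbp> middle_sbps;
  //   int32_t diag_node_pos = 0;
  //   JUST(boxing_collector.AskSbpCombination(sbp_producer, sbp_consumer, logical_blob_desc,
  //                                           producer_parallel_desc, consumer_parallel_desc,
  //                                           /*is_customized=*/false, middle_sbps,
  //                                           &diag_node_pos, /*compute_cost=*/true));
  //   // middle_sbps[0, diag_node_pos) follow the producer's parallel description,
  //   // middle_sbps[diag_node_pos, end) follow the consumer's.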
Maybe AskSbpCombination(const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool is_customized, std::vector& middle_sbps, int32_t* diag_node_pos, bool compute_cost); // Filter nd sbp from nd_sbp_lists_ with given logical shape Maybe FilterNdSbpList4LogicalShape(const BlobDesc& logical_blob_desc, const Shape& parallel_hierarchy); private: // Collect Sbp Parallel void CollectUniverse(const SbpParallel& sbp); // Find corresponding id for Nd sbp int32_t FindId4NdSbp(const NdSbp& nd_sbp); // Ask for sbp combination with the same 2-D hierarchy and placement Maybe AskSbpCombination4Same2DPlacement(const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool is_customized, std::vector& middle_sbps, int32_t* diag_node_pos, bool compute_cost); // Ask for sbp combination with different hierarchies on the same placement Maybe AskSbpCombination4DiffPlacement(const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool is_customized, std::vector& middle_sbps, int32_t* diag_node_pos, bool compute_cost); // Generate the transfer rule for one combination with different hierarchies on the same // placement. id_producer -> id_consumer. Maybe Generate1Combination4DiffHierarchy(int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer, std::vector>& diag_nodes); // The cost for transferring a 1D sbp between different placements Maybe ComputeCostFor1DSbpDiffPlacement( const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, std::vector>& cost_4_diff_placement); // Generate the transfer rule for one combination with different placements // id_producer -> id_consumer. Maybe Generate1Combination4DiffPlacement( int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer, const std::vector>& cost_4_diff_placement, std::vector>& diag_nodes); // Ask for one combination with different hierarchies and placements Maybe Ask1Combination4DiffPlacement(const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool is_customized, std::vector& middle_sbps, int32_t* diag_node_pos, bool compute_cost, BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer, const std::vector>& diag_nodes); // Ask for sbp combination for general basic communication Maybe AskSbpCombination4GeneralBasicCommunication( const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, std::vector& middle_sbps, int32_t* diag_node_pos); // Ask for a all-split sbp which is closed to the original one Maybe AskCloseAllSplitSbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const BlobDesc& logical_blob_desc, std::vector& middle_sbps); // Stores all the possible SbpParallel. 
  HashMap<SbpParallel, int32_t> sbp_parallel_universe_;
  // Relationship between id and Sbp Parallel
  std::vector<SbpParallel> id2sbp_parallel_;
  // minimum cost
  // minimum_copy_cost[producer][consumer]
  std::vector<std::vector<double>> minimum_copy_cost_;
  // middle nodes
  // middle_nodes_[producer][consumer][different choices] is a vector of middle nodes
  // middle_nodes_[producer][consumer][different choices].size() is the minimum number of middle
  // nodes that needs to be inserted
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> middle_nodes_;
  // Stores all the possible NdSbp.
  std::unordered_map<NdSbp, int32_t> nd_sbp_universe_;
  // Relationship between id and Nd Sbp
  std::vector<NdSbp> nd_sbp_lists_;
  // The diagonal middle node for different placements
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_placement_;
  // The diagonal middle node for different hierarchies in the same placement
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_hierarchy_;
  // Id Map from 1d sbp to 2d sbp
  // For example: B -> (B, B), S0 -> (S0, S0)
  std::vector<int32_t> id_1d_2_nd_;
  // The sbp size in the combination table
  int32_t hierarchy_num_;
  // How the boxing collector is initialized
  int32_t init_type_ = -1;
  // Enable general basic communication or not
  const bool enable_general_basic_communication =
      ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false);
};  // class BoxingCollector

}  // namespace oneflow

#endif  // ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_


================================================
FILE: oneflow/core/auto_parallel/sbp_collector.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include #include "oneflow/core/auto_parallel/sbp_collector.h" #include "oneflow/core/auto_parallel/binary_set.h" #include "oneflow/core/auto_parallel/sbp_util.h" #include "oneflow/core/auto_parallel/sbp_constructor.h" namespace oneflow { namespace auto_parallel { namespace { // Whether the given binary set intersects all the sbp sets of the consumers bool IfIntersectAll( const HashMap, BinarySet>& consumer_bn2sbp_set, const BinarySet& bs) { for (const auto& sbp_set_group : consumer_bn2sbp_set) { if (!bs.IfIntersect(sbp_set_group.second)) { return false; } } return true; } // Find unique sbp sets void FindUniqueSbpSets( const HashMap, BinarySet>& consumer_bn2sbp_set, const std::unordered_set& all_sbp_set, std::vector& accumulator, BinarySet& unique_sbps) { std::vector sbp_ids; // count the number of sbp for (const auto& sbp_set_group : consumer_bn2sbp_set) { sbp_set_group.second.QuickOutput(sbp_ids); for (int32_t sbp_id : sbp_ids) { accumulator[sbp_id]++; } } // find unique sbp and clear the accumulator for (const auto& sbp_id : all_sbp_set) { if (accumulator[sbp_id] == 1) { unique_sbps.AddEntry(sbp_id); } accumulator[sbp_id] = 0; } } // Find unique sbp groups void FindUniqueSbpGroups( const HashMap, BinarySet>& consumer_bn2sbp_set, const std::unordered_set& all_sbp_set, std::vector& accumulator, BinarySet& bs_buffer, std::vector& unique_sbp_groups) { // find the unique sbp sets BinarySet unique_sbps(accumulator.size()); FindUniqueSbpSets(consumer_bn2sbp_set, all_sbp_set, accumulator, unique_sbps); // A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0} // {S1, S2, S3} show up only once, a parallel candidate should not contain two of them for (const auto& sbp_set_group : consumer_bn2sbp_set) { unique_sbps.IntersectionTo(sbp_set_group.second, bs_buffer); // Find those unique sbp groups with more than two sbp // For example {B, S1, S2} is an impossible proxy candidate, // since {S1, S2} is only contained by A but not contained by C and D. // A could be either S1 or S2. The tensor do not need to be transferred to both S1 and S2. if (bs_buffer.Total() >= 2) { unique_sbp_groups.push_back(bs_buffer); } } bs_buffer.Clear(); } // If not contains two sbp from a same unique group bool No2SbpFromSameUniqueGroup(const BinarySet& bs, const std::vector& unique_sbp_groups) { BinarySet intersection(bs.GetSizeOfSet()); for (const auto& unique_sbp_group : unique_sbp_groups) { bs.IntersectionTo(unique_sbp_group, intersection); // For example {B, S1, S2} is an impossible proxy candidate, // since {S1, S2} is only contained by A but not contained by C and D. // A could be either S1 or S2. The tensor do not need to be transferred to both S1 and S2. if (intersection.Total() >= 2) { return false; } } return true; } } // namespace // Default constructor for SbpCollector // Don't allow any special case for broadcast! SbpCollector::SbpCollector() { // initialize Sbp Parallel Universe with broadcast. 
// NdSbp sbp_broadcast; // sbp_broadcast.mutable_broadcast_parallel(); // nd_sbp_universe_[sbp_broadcast] = 0; // id2nd_sbp_.push_back(sbp_broadcast); } // Collect all the possible Sbp Parallel from a NdSbpSignature void SbpCollector::CollectUniverse(const NdSbpSignature& nd_sbp_sig) { for (auto& bn_sbp_pair : nd_sbp_sig.bn_in_op2nd_sbp()) { if (nd_sbp_universe_.find(bn_sbp_pair.second) == nd_sbp_universe_.end()) { int32_t curr_size = nd_sbp_universe_.size(); nd_sbp_universe_[bn_sbp_pair.second] = curr_size; id2nd_sbp_.push_back(bn_sbp_pair.second); } } } // Collect all the possible Sbp Parallel from a SbpNode void SbpCollector::CollectUniverse(const SbpNode* sbp_node) { for (auto& nd_sbp_sig : sbp_node->sbp_sig_list_) { CollectUniverse(nd_sbp_sig); } } // Collect all the possible Sbp Parallel from a SbpGraph void SbpCollector::CollectUniverse(const SbpGraph& sbp_graph) { for (auto* sbp_node : sbp_graph.node_list_) { CollectUniverse(sbp_node); } accumulator_.resize(nd_sbp_universe_.size(), 0); bs_buffer_.Initialize(nd_sbp_universe_.size()); } // TODO: Auto Placement! // It only collect the same sbp with the same parallel description // In this moment their hierarchy is the same! // Initialize copy cost from producer to proxy of producer void SbpCollector::InitializeCopyCostFromNode2Proxy(const SbpNode* sbp_proxy, const LogicalBlobId& lbi) const { // the only edge from producer to proxy of producer SbpEdge* sbp_edge = sbp_proxy->edges_in_[0]; SbpNode* sbp_node_producer = sbp_edge->start_node_; sbp_edge->cost_.resize(sbp_node_producer->sbp_sig_list_.size()); int32_t consumer_sbp_size = sbp_proxy->parallel_candidates_.size(); // look through sbp signature in producer for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->sbp_sig_list_.size(); sbp_id_producer++) { sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0); } // Assemble copy cost from producer to proxy of producer OpNode* producer = sbp_node_producer->op_node_; // get parallel description. Number of devices. const ParallelDesc& producer_parallel_desc = producer->parallel_desc(); // Need to be careful, the logical blob description should be independent to current // NdSbp. Use producer or op_node? const BlobDesc& logical_blob_desc = producer->LogicalBlobDesc4Lbi(lbi); const std::string& obn = *CHECK_JUST(producer->op().obn4lbi(lbi)); // A buffer to store the sbp parallel id std::vector sbp_parallel_ids; // look through sbp signature in producer for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->sbp_sig_list_.size(); sbp_id_producer++) { // get sbp parallel for a logical blob in producer const auto& producer_sbp_bn_in_op2sbp_parallel = sbp_node_producer->sbp_sig_list_[sbp_id_producer].bn_in_op2nd_sbp(); const NdSbp& sbp_producer = producer_sbp_bn_in_op2sbp_parallel.at(obn); // look through sbp parallel set in consumer for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) { const BinarySet& sbp_parallel_set = sbp_proxy->parallel_candidates_[sbp_id_consumer]; sbp_parallel_set.QuickOutput(sbp_parallel_ids); // look through all sbp parallels in a sbp parallel set for (int32_t sbp_parallel_id : sbp_parallel_ids) { // get sbp parallel for a logical blob in consumer const NdSbp& sbp_consumer = id2nd_sbp_[sbp_parallel_id]; // compute copy cost for a specific logical blob // Use the parallel description of producer as those for consumer for now. 
sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] += CHECK_JUST(ComputeCopyCostWithMiddleNodes(sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, producer_parallel_desc, /*is_same=*/false)); } } } } // Initialize copy cost from proxy of producer to consumers void SbpCollector::InitializeCopyCostFromProxy2Consumer( SbpNode* sbp_proxy, const HashMap, BinarySet>& consumer_bn2sbp_set, const HashMap& op_name2sbp_node) const { // Connect sbp proxy and consumers for (const auto& consumer_bn_group : consumer_bn2sbp_set) { // consumer in cost model SbpNode* sbp_node_consumer = op_name2sbp_node.find(consumer_bn_group.first.first)->second; // input blob name of logical blob in consumer const std::string& ibn = consumer_bn_group.first.second; // check is_mutable in consumer OpNode* consumer = sbp_node_consumer->op_node_; CHECK(!RequireSameSbp(consumer, ibn)) << "Create a proxy for an unsuitable consumer!\n"; // Connect sbp proxy and consumer sbp_proxy->PointTo(sbp_node_consumer); // the sbp edge connecting proxy and consumer SbpEdge* sbp_edge = sbp_node_consumer->FindEdgeWithNode(sbp_proxy); sbp_edge->cost_.resize(sbp_proxy->parallel_candidates_.size()); int32_t consumer_sbp_size = sbp_node_consumer->sbp_sig_list_.size(); // look through sbp parallel set in proxy for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_proxy->parallel_candidates_.size(); sbp_id_producer++) { // initialization for copy cost sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0); // get sbp parallel set for a logical blob in proxy BinarySet& parallel_candidate = sbp_proxy->parallel_candidates_[sbp_id_producer]; // look through sbp signatures in consumers for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) { // get sbp parallel for a logical blob in consumer const auto& consumer_sbp_bn_in_op2sbp_parallel = sbp_node_consumer->sbp_sig_list_[sbp_id_consumer].bn_in_op2nd_sbp(); const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn); if ((!parallel_candidate.CheckExistence(nd_sbp_universe_.find(sbp_consumer)->second))) { sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] = GetMaxVal(); } } } } } // Export list of possible combination of Sbp Parallels void SbpCollector::ProxySbpCandidate(const OpGraph& op_graph, const HashMap& op_name2sbp_node, SbpGraph& sbp_graph) { // If needed, we can output the mapping from operator name to its proxy. // HashMap>& // op_name2lbi2sbp_proxy; // mapping from a logical blob id to index HashMap lbi2index; // mapping from the index to producer, consumer and corresponding input blob name, possible sbp // sets std::vector index2producer; std::vector> index2sbp_set; // mapping from consumers and input blob names to an unordered_set of SBP Parallel. std::vector, BinarySet>> index2consumer_bn2sbp_set; for (auto* consumer_sbp_node : sbp_graph.node_list_) { auto* node = consumer_sbp_node->op_node_; OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case(); // If not support boxing, just skip it. if (IsClassRegistered(op_type_case)) { return; } for (const std::string& ibn : node->op().input_bns()) { // Skip those blobs who enforce same SBP. if (RequireSameSbp(node, ibn)) { // Enforcing same SBP. Can not collect sbp from this blob. 
continue; } const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn); const OpNode& producer = node->ProducerOpNode4Lbi(lbi); // not building proxy for fixed operators if (op_name2sbp_node.find(producer.op().op_name()) == op_name2sbp_node.end()) { return; } // decide the index of a logical blob description const auto& iterator_lbi = lbi2index.find(lbi); int32_t index = 0; if (iterator_lbi == lbi2index.end()) { index = lbi2index.size(); lbi2index[lbi] = index; // map from lbi to the producer index2producer.push_back(&producer); // Initialize consumer_bns and the sbp sets index2consumer_bn2sbp_set.resize(index + 1); index2sbp_set.resize(index + 1); } else { index = iterator_lbi->second; } // a set to store the id of all possible SBP Parallel for a downstream op // should filter out repeated SBP Parallel by pre-storing them into an unordered_set BinarySet& nd_sbp_ids = index2consumer_bn2sbp_set[index][{node->op().op_name(), ibn}]; nd_sbp_ids.Initialize(nd_sbp_universe_.size()); // The union sbp set of all the consumers std::unordered_set& union_nd_sbp_ids = index2sbp_set[index]; for (auto& sbp_sig : consumer_sbp_node->sbp_sig_list_) { const auto& map = sbp_sig.bn_in_op2nd_sbp(); const auto& iter = map.find(ibn); CHECK(iter != map.end()) << "blob_name " << ibn << " not found in sbp signature"; const NdSbp& consumer_sbp = iter->second; // filter out repeated SBP int32_t sbp_universe_id = nd_sbp_universe_.find(consumer_sbp)->second; nd_sbp_ids.AddEntry(sbp_universe_id); union_nd_sbp_ids.insert(sbp_universe_id); } } }; // A set of binary set with broadcast only // std::unordered_set parallel_candidates_initializer; // BinarySet one_broadcast(nd_sbp_universe_.size()); // one_broadcast.AddEntry(0); // parallel_candidates_initializer.insert(std::move(one_broadcast)); // Decide if we should insert a proxy for each logical blob for (auto& lbi_index : lbi2index) { int32_t index = lbi_index.second; // Only insert proxy for those blobs with multiple downstream consumers. if (index2consumer_bn2sbp_set[index].size() < 2) { continue; } // Maximum number of possible sbp in the proxy int32_t max_num_sbp_proxy = std::min(max_num_sbp_proxy_, index2consumer_bn2sbp_set[index].size()); // producer in cost model const std::string& producer_name = index2producer[index]->op().op_name(); SbpNode* sbp_node_producer = op_name2sbp_node.find(producer_name)->second; const LogicalBlobId& lbi = lbi_index.first; // store all the binary sets of SBP Parallel into an unordered_set. 
// std::vector parallel_candidates; // generate sbp proxy SbpNode* sbp_proxy = sbp_graph.GenerateNode(); // A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0} // {S1, S2, S3} show up only once, a parallel candidate should not contain two of them std::vector unique_sbp_groups; FindUniqueSbpGroups(index2consumer_bn2sbp_set[index], index2sbp_set[index], accumulator_, bs_buffer_, unique_sbp_groups); // Depth first search to collect Sbp Parallel information for the whole sbp set DfsSbpSet(0, max_num_sbp_proxy, index2sbp_set[index], index2sbp_set[index].begin(), index2consumer_bn2sbp_set[index], unique_sbp_groups, sbp_proxy->parallel_candidates_); // Initialize computation cost sbp_proxy->cost_.resize(sbp_proxy->parallel_candidates_.size(), 0); // Transfer a logical blob from producer to a sbp proxy of this blob sbp_node_producer->PointTo(sbp_proxy); // Compute copy cost between producer and proxy InitializeCopyCostFromNode2Proxy(sbp_proxy, lbi); // Build connection and compute copy cost between proxy and consumers InitializeCopyCostFromProxy2Consumer(sbp_proxy, index2consumer_bn2sbp_set[index], op_name2sbp_node); // Unloading for (const auto& consumer_bn_group : index2consumer_bn2sbp_set[index]) { // consumer in cost model SbpNode* sbp_node_consumer = op_name2sbp_node.find(consumer_bn_group.first.first)->second; // the sbp edge connecting producer and consumer SbpEdge* edge_found = sbp_node_consumer->FindEdgeWithNode(sbp_node_producer); // unload logical blob from sbp edges edge_found->UnloadLbi(lbi); // Do not clip this edge. Save it for wait time. // clip this edge if it no longer carries any blob // We don't clip edges before since we have transfer cost // Now we clip edges, which makes the topology simpler if (edge_found->EmptyLbi() && edge_found->wait_time_ <= 0.0 && edge_found->wait_time_ > -0.5) { sbp_graph.ClipEdge(edge_found); } } } } // Depth first search to collect Sbp Parallel information for different logical blob ids void SbpCollector::DfsSbpSet( int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, const std::unordered_set::iterator& start_it, const HashMap, BinarySet>& consumer_bn2sbp_set, const std::vector& unique_sbp_groups, std::vector& parallel_candidates) { if (depth > 0) { if (IfIntersectAll(consumer_bn2sbp_set, bs_buffer_) && No2SbpFromSameUniqueGroup(bs_buffer_, unique_sbp_groups)) { // store the binary set into an unordered_set parallel_candidates.push_back(bs_buffer_); } } if (depth >= max_depth) { return; } // go through the rest of the sbp parallel std::unordered_set::iterator curr_it = start_it; while (curr_it != sbp_sets.end()) { // Take the value out int32_t nd_sbp_num = *curr_it; // Then move to the next pointer ++curr_it; if (accumulator_[nd_sbp_num] == 0) { bs_buffer_.AddEntry(nd_sbp_num); ++accumulator_[nd_sbp_num]; DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, consumer_bn2sbp_set, unique_sbp_groups, parallel_candidates); bs_buffer_.DeleteEntry(nd_sbp_num); --accumulator_[nd_sbp_num]; } } } } // namespace auto_parallel } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/sbp_collector.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef SBP_COLLECTOR_ #define SBP_COLLECTOR_ #include #include #include #include #include #include "oneflow/core/auto_parallel/sbp_graph.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job/local_sig_infer_hint.h" #include "oneflow/core/job/job_builder.h" // #include "sbp_constructor.h" #define DEBUG_COLLECTOR_ namespace oneflow { namespace auto_parallel { class SbpCollector { public: SbpCollector(); ~SbpCollector() {} // Collect all the possible Sbp Parallel from a SbpGraph void CollectUniverse(const SbpGraph& sbp_graph); // Export list of possible combination of Sbp Parallels void ProxySbpCandidate(const OpGraph& op_graph, const HashMap& op_name2sbp_node, SbpGraph& sbp_graph); private: // Stores all the possible NdSbp. std::unordered_map nd_sbp_universe_; // Relationship between id and Sbp Parallel std::vector id2nd_sbp_; // Calculate number of downstream sbp std::vector accumulator_; // A binary set buffer to indicate sets of downstream sbp BinarySet bs_buffer_; // Collect all the possible Sbp Parallel from a NdSbpSignature void CollectUniverse(const NdSbpSignature& nd_sbp_sig); // Collect all the possible Sbp Parallel from a SbpNode void CollectUniverse(const SbpNode* sbp_node); // Initialize copy cost from producer to proxy of producer void InitializeCopyCostFromNode2Proxy(const SbpNode* sbp_proxy, const LogicalBlobId& lbi) const; // Initialize copy cost from proxy of producer to consumers void InitializeCopyCostFromProxy2Consumer( SbpNode* sbp_proxy, const HashMap, BinarySet>& consumer_bn2sbp_set, const HashMap& op_name2sbp_node) const; // Maximum number of possible sbp in the proxy const unsigned long max_num_sbp_proxy_ = 3; // Depth first search to collect Sbp Parallel information for the whole sbp set void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set& sbp_sets, const std::unordered_set::iterator& sbp_set_it, const HashMap, BinarySet>& consumer_bn2sbp_set, const std::vector& unique_sbp_groups, std::vector& parallel_candidates); }; // class SbpCollector } // namespace auto_parallel } // namespace oneflow #endif // SBP_COLLECTOR_ ================================================ FILE: oneflow/core/auto_parallel/sbp_constructor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/auto_parallel/sbp_constructor.h" #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/auto_parallel/sbp_node.h" #include "oneflow/core/auto_parallel/sbp_util.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/auto_parallel/sbp_collector.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace auto_parallel { namespace { // AMS, a.k.a. Applied Mathematics & Statistics, is a department of the Stony Brook University. // It contains 5 tracks: Computational & Applied Mathematics, Computational Biology, // Operation Research, Quantitative Finance, Statistics. AutoMemoryStrategy ams; // kMemoryRatio increase by this rate at each time. static const double kMemoryIncreaseRatio = 2.0; // The ceil of kMemoryRatio. static const double kMaxMemoryRatio = 22.0; // The floor of kMemoryRatio static const double kMinMemoryRatio = 0.1; // If the current memory > available memory * kImpossibleRatio, // then it is impossible to reduce the memory to an acceptable size static const double kImpossibleRatio = 1.4; // Pick from 5 fixed types of memory ratio. double UpdateMemoryRatio() { switch (ams) { case kAdaptiveAutoMemory: case kDisableAutoMemory: return 0.0; case kSlightAutoMemory: return 0.4; case kModerateAutoMemory: return 4.3; default: return 11.0; // case kHeavyAutoMemory } } } // namespace double kMemoryRatio; Maybe SbpConstructor::Init(const OpGraph& op_graph, Job* job /*Maybe not use*/) { JUST(InitSbpGraph(op_graph, *job)); return Maybe::Ok(); } Maybe SbpConstructor::InitSbpGraph(const OpGraph& op_graph, const Job& job) { // Update nccl_use_compute_stream nccl_use_compute_stream_ = Singleton::Get()->nccl_use_compute_stream(); ams = job.job_conf().enable_auto_memory(); kMemoryRatio = UpdateMemoryRatio(); // TODO: process local node JUST(GenerateNodeAndEdge(op_graph, job)); JUST(FillSbpSignatureForOpNode(op_graph, job)); JUST(InitComputationCost(op_graph)); if (enable_trunk_algo_) { JUST(ApplyTrunkAlgo()); } // Load logical blobs on all sbp edges. LoadLbi2SbpEdge(op_graph); // InitMemory() should be run before the sbp collector and after the ApplyTrunkAlgo() and // LoadLbi2SbpEdge(op_graph). InitAvailableMemory(); InitMemory(op_graph, &sbp_graph_, nccl_use_compute_stream_); if (use_sbp_collector_) { // Use sbp collector to create sbp proxy for nodes with multiple downstream operators. SbpCollector sbp_collector; sbp_collector.CollectUniverse(sbp_graph_); // TODO: Init memory cost for proxy sbp_collector.ProxySbpCandidate(op_graph, op_name2sbp_node_, sbp_graph_); } JUST(InitCopyAndMemoryCost(op_graph)); // We need to store the original cost and memory after the initialization (InitComputationCost(), // InitMemory(), InitCopyAndMemoryCost()) and before the usage of them (InitWeightedCost()) sbp_graph_.StoreOriginMemory(); InitWeightedCost(); // TODO: Set all the sbp signature id to be 0 for initialization. // Could revert it back to // sbp_graph_.RandomSbpSignature(use_sbp_collector_); // after settling down the synchronization of sbp strategy. 
sbp_graph_.SetDefaultSbpSig(); double ori_cost = sbp_graph_.ComputeCost(); LOG(INFO) << "Initial cost: " << ori_cost; // If we do not prune those parallel cast ops, steal the initial strategy from user setting and // semi-auto parallelism if (!job.job_conf().enable_auto_parallel_ignore_user_sbp_config()) { JUST(StealSbpSignatureFromOpNode(op_graph, job)); ori_cost = sbp_graph_.ComputeCost(); LOG(INFO) << "OpGraph cost: " << ori_cost; } return Maybe::Ok(); } Maybe SbpConstructor::FindBestSbpSignature() { double ori_cost = sbp_graph_.ComputeCost(); LOG(INFO) << "Initial cost: " << ori_cost; int elimination_num = sbp_graph_.NodeAndEdgeEliminations(); LOG(INFO) << "Elimination number: " << elimination_num; if (ori_cost > GetValidMaxCopyCost()) { JUST(sbp_graph_.Find1Strategy4Greedy()); ori_cost = sbp_graph_.ComputeCost(); LOG(INFO) << "Greedy cost: " << ori_cost; } int32_t step = 1; while (true) { sbp_graph_.GreedyStrategy(/*nbh_num=*/4); double curr_memory = sbp_graph_.GetMemory(); double total_weighted_cost = sbp_graph_.ComputeWeightedCost(); LOG(INFO) << "The " << step << "-th try, memory ratio: " << kMemoryRatio << ", memory: " << curr_memory << ", total cost: " << total_weighted_cost << ", time cost: " << (total_weighted_cost - kMemoryRatio * curr_memory); if (ams != AutoMemoryStrategy::kAdaptiveAutoMemory) { break; } if (curr_memory < available_memory_ || kMemoryRatio >= kMaxMemoryRatio) { break; } if (curr_memory > available_memory_ * kImpossibleRatio) { kMemoryRatio = kMaxMemoryRatio; } else { kMemoryRatio = std::max(std::min(kMaxMemoryRatio, kMemoryRatio * kMemoryIncreaseRatio), kMinMemoryRatio); } step++; sbp_graph_.ReComputeWeightedCost(); } sbp_graph_.FinalizeSbp(); double final_cost = sbp_graph_.ComputeCost(); LOG(INFO) << "Final cost: " << final_cost; // TODO: Restart searching with another original random strategy CHECK_LT_OR_RETURN(final_cost, GetValidMaxCopyCost()) << "Failed! 
Auto parallel can't find a strategy with reasonable cost!"; return Maybe::Ok(); } Maybe SbpConstructor::DumpNdSbpSignatureForJob(const OpGraph& op_graph, Job* job) { for (auto& op_conf : *job->mutable_net()->mutable_op()) { const OpNode* node = op_graph.OpNode4OpName(op_conf.name()); SbpNode* sbp_node = op_name2sbp_node_[node->op().op_name()]; const NdSbpSignature& nd_sbp_sig = sbp_node->FinalSbpSignature(); // Update NdSbpSignature (*job->mutable_job_parallel_view_conf() ->mutable_op_name2nd_sbp_signature_conf())[node->op().op_name()] .CopyFrom(nd_sbp_sig); // If we have 1D SbpSignature Conf if (node->parallel_desc().hierarchy()->NumAxes() == 1) { // Update SbpSignature SbpSignature sbp_signature; NdSbpSignatureToSbpSignature(nd_sbp_sig, &sbp_signature); (*job->mutable_job_parallel_view_conf() ->mutable_op_name2sbp_signature_conf())[node->op().op_name()] .CopyFrom(sbp_signature); } JUST(node->op().GetDumpNdSbpSignatureForOpConfFn()(nd_sbp_sig, &op_conf)); } return Maybe::Ok(); } Maybe SbpConstructor::GenerateNodeAndEdge(const OpGraph& op_graph, const Job& job) { JobParallelViewConf job_parallel_view_conf(job.job_parallel_view_conf()); // Collect op_node std::vector op_node_list; op_graph.ForEachNode([&](OpNode* op_node) { // TODO: support local op bool is_local_conf = false; { const auto& op_name2is_local = job_parallel_view_conf.op_name2is_local_parallel_view(); const auto& iter = op_name2is_local.find(op_node->op().op_name()); if (iter != op_name2is_local.end()) { is_local_conf = iter->second; } } CHECK(is_local_conf == false) << "Haven't deal with local operators."; op_node_list.push_back(op_node); }); // Decide the order to visit the op std::vector order; auto CompareOpName = [&](OpNode* a, OpNode* b) { return a->op().op_name().compare(b->op().op_name()) > 0; }; auto_parallel::DecideOrder(op_node_list, order, CompareOpName); std::vector output_order; // Create sbp nodes for (int32_t i = 0; i < op_node_list.size(); i++) { OpNode* op_node = op_node_list[order[i]]; // Generate sbp node in cost model and link it with corresponding op node SbpNode* sbp_node = sbp_graph_.GenerateNode(); // Mapping from sbp_node to op_node sbp_node->op_node_ = op_node; // TODO: SetOpNode() op_name2sbp_node_[op_node->op().op_name()] = sbp_node; } // Create sbp edges for (int32_t i = 0; i < op_node_list.size(); i++) { OpNode* op_node = op_node_list[order[i]]; // Get corresponding sbp node SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()]; std::vector output_node_list; for (const auto* op_edge : op_node->out_edges()) { output_node_list.push_back(op_edge->dst_node()); } auto_parallel::DecideOrder(output_node_list, output_order, CompareOpName); for (int32_t j : output_order) { const auto& end_node_name = output_node_list[j]->op().op_name(); // Generate sbp edge in cost model sbp_node->PointTo(op_name2sbp_node_[end_node_name]); } } return Maybe::Ok(); } Maybe SbpConstructor::FillSbpSignatureForOpNode(const OpGraph& op_graph, const Job& job) { // TODO: use user sbp signature in JobParallelViewConf // const JobParallelViewConf& job_parallel_view_conf(job.job_parallel_view_conf()); JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe { HashMap ibn2blob_desc; auto FindShape4Blobs = [&](const PbRpf& bns) -> Maybe { for (const std::string& ibn : bns) { const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn); const BlobDesc* logical_blob_desc = &op_node->LogicalBlobDesc4Lbi(lbi); ibn2blob_desc.emplace(ibn, logical_blob_desc); } return Maybe::Ok(); }; 
JUST(FindShape4Blobs(op_node->op().input_bns())); JUST(FindShape4Blobs(op_node->op().output_bns())); // Get logical blob description auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe { const auto& it = ibn2blob_desc.find(ibn); if (it == ibn2blob_desc.end()) { return Error::InvalidValueError() << "Cannot find corresponding blob description for input_blob_name : " + ibn + " in " + op_node->op().op_name(); } return *(it->second); }; // Get all valid sbp_signatures SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()]; JUST(op_node->op().GetValidNdSbpSignatureList(LogicalBlobDesc4Ibn, op_node->parallel_desc(), &sbp_node->sbp_sig_list_, /*check_output=*/true)); sbp_node->InitializeSbp(); return Maybe::Ok(); })); return Maybe::Ok(); } Maybe SbpConstructor::StealSbpSignatureFromOpNode(const OpGraph& op_graph, const Job& job) { // Steal some strategy from original op graph for (auto* sbp_node : sbp_graph_.node_list_) { // sbp_collectors do not have op_node if (sbp_node->op_node_) { for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) { if (*JUST(sbp_node->op_node_->op().nd_sbp_signature()) == sbp_node->sbp_sig_list_[sbp_id]) { sbp_node->final_sbp_sig_id_ = sbp_id; break; } } } } return Maybe::Ok(); } Maybe SbpConstructor::InitComputationCost(const OpGraph& op_graph) { // Compute computation cost for sbp nodes JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe { // get corresponding sbp node producer SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()]; // get parallel description. Number of devices. const ParallelDesc& parallel_desc = op_node->parallel_desc(); CHECK_EQ_OR_RETURN(sbp_node->cost_.size(), sbp_node->sbp_sig_list_.size()); auto LogicalBlobDesc4Bn = [&](const std::string& bn) -> const BlobDesc& { const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(bn); return op_node->LogicalBlobDesc4Lbi(lbi); }; for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) { double comp_cost = JUST(op_node->op().GetComputeComplexity( &sbp_node->sbp_sig_list_[sbp_id], LogicalBlobDesc4Bn, parallel_desc)); if (comp_cost > GetValidMaxCopyCost()) { sbp_node->cost_[sbp_id] = comp_cost; } else { sbp_node->cost_[sbp_id] = cost_ratio_ * comp_cost * JUST(op_node->op().GetInputOutputFastestTimeShape())->elem_cnt(); } } return Maybe::Ok(); })); return Maybe::Ok(); } // Init copy cost and memory for edges Maybe SbpConstructor::InitCopyAndMemoryCost(const OpGraph& op_graph) { bool nccl_not_use_compute_stream = !nccl_use_compute_stream_; // Compute copy cost for sbp edges op_graph.ForEachNode([&](OpNode* op_node) { // get corresponding sbp node consumer SbpNode* sbp_node_consumer = op_name2sbp_node_[op_node->op().op_name()]; // Initialize copy cost between two nodes for (auto* sbp_edge : sbp_node_consumer->edges_in_) { // producer sbp node const auto* sbp_node_producer = sbp_edge->start_node_; // skip it if proxy if (!sbp_node_producer->op_node_) { continue; } sbp_edge->cost_.resize(sbp_node_producer->sbp_sig_list_.size()); if (nccl_not_use_compute_stream) { sbp_edge->memory_.resize(sbp_node_producer->sbp_sig_list_.size()); } int32_t consumer_sbp_size = sbp_node_consumer->sbp_sig_list_.size(); // look through sbp signature in producer for (int32_t i = 0; i < sbp_node_producer->sbp_sig_list_.size(); ++i) { sbp_edge->cost_[i].resize(consumer_sbp_size, 0); if (nccl_not_use_compute_stream) { sbp_edge->memory_[i].resize(consumer_sbp_size, 0); } } } // Find all those cases with wait time // Do not skip edges carrying no lbi 
sbp_node_consumer->InitCopyAndMemoryCost(use_sbp_collector_, nccl_not_use_compute_stream); }); return Maybe::Ok(); } Maybe SbpConstructor::ApplyTrunkAlgo() { // TODO: Remove this auto OpNode2MutableOpCtrlDeps = JUST(GetMutableOpCtrlDeps(*op_graph_)); // Compute layer number for each node int32_t max_min_layer = sbp_graph_.ComputeLayer(op_name2sbp_node_, *OpNode2MutableOpCtrlDeps); // Accumulate cost on the trunk after initializing computation cost sbp_graph_.FindTrunk(max_min_layer, op_name2sbp_node_); return Maybe::Ok(); } // Load logical blob ids onto sbp edges void SbpConstructor::LoadLbi2SbpEdge(const OpGraph& op_graph) { // Load logical blobs onto sbp edges for (auto* sbp_node_consumer : sbp_graph_.node_list_) { auto* op_node = sbp_node_consumer->op_node_; // Loading logical blobs between two nodes // look through input blobs for (const std::string& ibn : op_node->op().input_bns()) { // Each input blob has one source op node. OpNode* producer = op_node->MutSrcNode4Ibn(ibn); // producer sbp node const auto* sbp_node_producer = op_name2sbp_node_[producer->op().op_name()]; // TODO: recode this auto* edge_found = sbp_node_consumer->FindEdgeWithNode(sbp_node_producer); CHECK(edge_found != NULL) << "SbpEdge not found while loading!" << std::endl; // Add copy cost for each blob const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn); edge_found->LoadLbi(lbi); } }; } Maybe SbpConstructor::CheckSbpAgreement(const Job& job) { Job new_job; new_job.CopyFrom(job); OpGraph op_graph(new_job); // Compare sbp in job JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe { const std::string& op_name = op_node->op().op_name(); const NdSbpSignature& auto_parallel_sbp = NdSbpSignature(job.job_parallel_view_conf().op_name2nd_sbp_signature_conf().at(op_name)); const NdSbpSignature& new_sbp = op_node->nd_sbp_signature(); CHECK_EQ_OR_RETURN(auto_parallel_sbp.bn_in_op2nd_sbp_size(), new_sbp.bn_in_op2nd_sbp_size()); for (const auto& iter : auto_parallel_sbp.bn_in_op2nd_sbp()) { const NdSbp& new_sbp_parallel = new_sbp.bn_in_op2nd_sbp().at(iter.first); const NdSbp& auto_parallel_sbp = iter.second; // According error message, we can find op_type in op_conf.proto with type_id and locate // the error op type. 
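  // Illustrative shape of the message built below (names and ids are placeholders only):
  //   Op: `some_op_name`(type_id: 42) changed sbp from (S(0), B)(AutoParallel)
  //   to (B, B)(OpGraph) with blob_name: `out_0`.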
const std::string& error_mgs = "Op: `" + op_name + "`(type_id: " + std::to_string(op_node->op().op_conf().op_type_case()) + ") changed sbp from " + NdSbpToString(auto_parallel_sbp) + "(AutoParallel) to " + NdSbpToString(new_sbp_parallel) + "(OpGraph) with blob_name: `" + iter.first + "`."; CHECK_OR_RETURN(new_sbp_parallel == auto_parallel_sbp) << error_mgs; } return Maybe::Ok(); })); return Maybe::Ok(); } // TODO: delete this, this is for variable op only Maybe>> SbpConstructor::GetMutableOpCtrlDeps( const OpGraph& op_graph) { auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool { for (const std::string& bn : op.input_bns()) { if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; } } return false; }; const auto& IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); HashMap> op_node2ctrl_in_op_names; JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe::Ok(); } if (op_node->out_edges().size() <= 1) { return Maybe::Ok(); } const Operator& variable_op = op_node->op(); const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn()); const OpNode* mutable_consumer = nullptr; std::vector naive_consumers; naive_consumers.reserve(op_node->out_edges().size()); for (OpEdge* edge : op_node->out_edges()) { const auto& op_conf = edge->dst_node()->op().op_conf(); if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) { CHECK_OR_RETURN(mutable_consumer == nullptr); mutable_consumer = edge->dst_node(); } else { naive_consumers.emplace_back(&op_conf); } } if (mutable_consumer == nullptr) { return Maybe::Ok(); } for (const auto* fw_bw_op : naive_consumers) { op_node2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name()); } return Maybe::Ok(); })); // Filter ctrl edges if all ctrl_in_op_names are reachable HashMap> filter_op_ctrl_deps; for (const auto& pair : op_node2ctrl_in_op_names) { const OpNode* op_node = pair.first; for (const auto& fw_bw_op_name : pair.second) { if (!IsReachable(fw_bw_op_name, op_node->op().op_name())) { filter_op_ctrl_deps[op_node].insert(fw_bw_op_name); } } } return filter_op_ctrl_deps; } void SbpConstructor::InitAvailableMemory() { size_t free = 0; size_t total = 0; #ifdef WITH_CUDA CudaCurrentDeviceGuard guard(GlobalProcessCtx::Rank()); OF_CUDA_CHECK(cudaMemGetInfo(&free, &total)); #else free = 1e13; // 10T = 10,000G total = 1e13; // 10T = 10,000G LOG(INFO) << "We do not use CUDA in CPU mode, auto memory is unnecessary since all the SBPs are " "Broadcast."; #endif // The estimated memory differs from the lower bound of the peak memory by the first ratio. // The first ratio varies from -3% to 3.2% if not enabling nccl_use_compute_stream. // It varies from 0.00313% to 0.5% if enabling nccl_use_compute_stream. double first_ratio = 1.0; if (nccl_use_compute_stream_) { first_ratio = 1.01; } else { first_ratio = 1.04; } // The lower bound of the peak memory differs from the allocated memory by the second ratio. // The second ratio varies from 0 to 2.65% if not using pipeline parallelism. // It varies from 0 to 5.23% if using pipeline parallelism. double second_ratio = 1.06; // The occupied memory at this moment would be around 1114MB to 1240MB. // When it gets to the training process, the occupied memory might drop by 162MB. // But the key is that we start to allocate memory before the training process. // Thus, this 161MB should not be added to the free memory. 
// We still use "available memory = free / ratio" instead of "free / ratio + 161MB". available_memory_ = int64_t(free / (first_ratio * second_ratio)); LOG(INFO) << "Free memory: " << free << ", total memory: " << total << ", available memory: " << available_memory_; } void SbpConstructor::InitWeightedCost() { for (auto& sbp_node : sbp_graph_.node_list_) { sbp_node->ComputeWeightedCost(); for (auto& sbp_edge : sbp_node->edges_in_) { sbp_edge->ComputeWeightedCost(); } } } // Print the graph with SBP in order void SbpConstructor::PrintSBPGraphDebugInfo() { // sbp constructor information std::cout << "cost_ratio_:" << cost_ratio_ << std::endl; std::cout << "wait_time_:" << sbp_graph_.wait_time_ << std::endl; std::cout << "use_sbp_collector_" << use_sbp_collector_ << std::endl; std::cout << "Total auto parallel guessed memory: " << sbp_graph_.GetMemory() << std::endl; std::cout << "Final memory ratio: " << kMemoryRatio << std::endl; // test debug std::cout << "Get Into Print Op Graph" << std::endl; // Collect op_node std::vector node_list; for (const auto& op_name_sbp_node : op_name2sbp_node_) { auto* op_node_ = op_name_sbp_node.second->op_node_; if (op_node_) { node_list.push_back(op_node_); } } // test debug std::cout << "Deciding order" << std::endl; // Decide the order to visit the op std::vector order; auto_parallel::DecideOrder(node_list, order, [&](OpNode* a, OpNode* b) { return a->op().op_name().compare(b->op().op_name()) > 0; }); std::vector str_order; // test debug std::cout << "Finish deciding order" << std::endl; for (int32_t i = 0; i < node_list.size(); i++) { OpNode* op_node = node_list[order[i]]; std::cout << op_node->op().op_name() << " (^_^):" << std::endl; // get corresponding sbp node const auto& it = op_name2sbp_node_.find(op_node->op().op_name()); // Print debug information for sbp graph CHECK(it != op_name2sbp_node_.end()); const SbpNode* sbp_node = it->second; std::cout << "Computation Cost: " << sbp_node->weighted_cost_[sbp_node->final_sbp_sig_id_]; std::cout << ", Min Layer: " << sbp_node->min_layer_ << ", Max Layer: " << sbp_node->max_layer_ << ", Tributary Layer: " << sbp_node->tributary_layer_ << ", in trunk: " << sbp_node->on_trunk_ << ", Remain Cost: " << sbp_node->acc_trunk_cost_ << std::endl; // Sort before printing const auto& op_input_bns = op_node->op().input_bns(); auto CompareString = [](const std::string& a, const std::string& b) { return a.compare(b) > 0; }; auto_parallel::DecideOrder(op_input_bns, str_order, CompareString); const NdSbpSignature& sbp_signature = sbp_node->FinalSbpSignature(); // Print out SBP information for input operator for (int32_t j : str_order) { const auto& ibn = op_input_bns[j]; const auto& producer_node = op_node->SrcNode4Ibn(ibn); std::cout << "Pre Op:" << producer_node.op().op_name() << ": " << ibn; const auto& this_sbp_parallel = sbp_signature.bn_in_op2nd_sbp().at(ibn); std::cout << ", " << NdSbpToString(this_sbp_parallel); if (RequireSameSbp(op_node, ibn)) { std::cout << ", require same SBP"; } std::cout << ", " << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(ibn)).shape(); std::cout << std::endl; } // Sort before printing const auto& op_output_bns = op_node->op().output_bns(); auto_parallel::DecideOrder(op_output_bns, str_order, CompareString); // Print out SBP information for output blobs for (int32_t j : str_order) { const auto& obn = op_output_bns[j]; std::cout << "Out Op:" << obn; const auto& this_sbp_parallel = sbp_signature.bn_in_op2nd_sbp().at(obn); std::cout << ", " << NdSbpToString(this_sbp_parallel); 
std::cout << ", " << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(obn)).shape(); std::cout << std::endl; } std::cout << std::endl; } } } // namespace auto_parallel } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/sbp_constructor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_ #define ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/auto_parallel/sbp_graph.h" #include "oneflow/core/job/global_for.h" namespace oneflow { class OpGraph; class Job; namespace auto_parallel { // A constructor which will assemble the sbp_graph with the information from oneflow. // SbpGraph contains the algorithms for elimination and search which is mainly for the strategy // itself. Constructor mainly deal with the assemblage of each node, edge and the cost computation, // activation of functions. class SbpConstructor final { public: OF_DISALLOW_COPY_AND_MOVE(SbpConstructor); SbpConstructor() = delete; SbpConstructor(const OpGraph& op_graph, Job* job) : cost_ratio_(job->job_conf().auto_parallel_computation_cost_ratio()), enable_trunk_algo_(job->job_conf().enable_auto_parallel_trunk_algo()), use_sbp_collector_(!Singleton::Get() ->resource() .disable_group_boxing_by_dst_parallel() && job->job_conf().enable_auto_parallel_sbp_collector()), op_graph_(&op_graph) { sbp_graph_.SetWaitTime(job->job_conf().auto_parallel_wait_time()); CHECK_JUST(Init(op_graph, job)); } ~SbpConstructor() = default; Maybe Init(const OpGraph& op_graph, Job* job); Maybe FindBestSbpSignature(); Maybe DumpNdSbpSignatureForJob(const OpGraph& op_graph, Job* job); // Re-build OpGraph and check all sbp is same between op_graph and job Maybe CheckSbpAgreement(const Job& job); // Print the graph with SBP in order void PrintSBPGraphDebugInfo(); private: Maybe InitSbpGraph(const OpGraph& op_graph, const Job& job); Maybe GenerateNodeAndEdge(const OpGraph& op_graph, const Job& job); Maybe FillSbpSignatureForOpNode(const OpGraph& op_graph, const Job& job); Maybe StealSbpSignatureFromOpNode(const OpGraph& op_graph, const Job& job); Maybe InitComputationCost(const OpGraph& op_graph); Maybe InitCopyAndMemoryCost(const OpGraph& op_graph); Maybe ApplyTrunkAlgo(); Maybe>> GetMutableOpCtrlDeps(const OpGraph& op_graph); void InitAvailableMemory(); void InitWeightedCost(); // Load logical blob ids onto sbp edges void LoadLbi2SbpEdge(const OpGraph& op_graph); double cost_ratio_; bool enable_trunk_algo_; bool use_sbp_collector_; SbpGraph sbp_graph_; const OpGraph* op_graph_; HashMap op_name2sbp_node_; bool nccl_use_compute_stream_; int64_t available_memory_; }; } // namespace auto_parallel } // namespace oneflow #endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_ ================================================ FILE: oneflow/core/auto_parallel/sbp_edge.cpp ================================================ /* Copyright 
2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/auto_parallel/sbp_edge.h" #include "oneflow/core/auto_parallel/sbp_node.h" #include "oneflow/core/auto_parallel/sbp_graph.h" #include "oneflow/core/auto_parallel/sbp_util.h" #include "oneflow/core/graph/op_graph.h" namespace oneflow { namespace auto_parallel { extern double kMemoryRatio; // function in cpp. Should be put in one file due to use of template // Otherwise we will need to declare specific template at the end of cpp file. SbpEdge::SbpEdge(SbpNode* start_node, SbpNode* mid_node, SbpNode* end_node, SbpEdge* first_edge, SbpEdge* second_edge) : start_node_(start_node), mid_node_(mid_node), end_node_(end_node) { // The first edge must between start_node and mid_node, but it could be // start_node -> mid_node or mid_node -> start node // Same for the second edge. edge_list_.emplace_back(first_edge); edge_list_.emplace_back(second_edge); }; // Deconstructor SbpEdge::~SbpEdge() { if (mid_node_ != nullptr) { delete mid_node_; } for (auto& this_edge : edge_list_) { delete this_edge; } } void SbpEdge::SummarizeCost() { // If any sub data structure is in the memory support, // then this edge is in the memory support if (mid_node_ && mid_node_->in_memory_support_) { in_memory_support_ = true; } else { in_memory_support_ = std::any_of(edge_list_.begin(), edge_list_.end(), [](SbpEdge* sbp_edge) { return sbp_edge->in_memory_support_; }); } // We would need to compute the memory for this elimination int32_t start_node_sbp_size = start_node_->weighted_cost_.size(); if (in_memory_support_) { memory_.resize(start_node_sbp_size); } weighted_cost_.resize(start_node_sbp_size); // Copy cost and memory cost if (mid_node_) { // Buffer int64_t memory_cost = 0; int64_t min_memory_cost = 0; int32_t min_sbp_mid = 0; double weighted_cost = 0.0; double min_weighted_cost = 0.0; // Node elimination mid_node_sbp_sig_.resize(start_node_sbp_size); int32_t end_node_sbp_size = end_node_->weighted_cost_.size(); int32_t mid_node_sbp_size = mid_node_->weighted_cost_.size(); for (int32_t sbp_start = 0; sbp_start < start_node_sbp_size; sbp_start++) { if (in_memory_support_) { memory_[sbp_start].resize(end_node_sbp_size); } weighted_cost_[sbp_start].resize(end_node_sbp_size); mid_node_sbp_sig_[sbp_start].resize(end_node_sbp_size); for (int32_t sbp_end = 0; sbp_end < end_node_sbp_size; sbp_end++) { for (int32_t sbp_mid = 0; sbp_mid < mid_node_sbp_size; sbp_mid++) { // Add middle node cost memory_cost = mid_node_->GetMemory(sbp_mid); weighted_cost = mid_node_->weighted_cost_[sbp_mid]; // Add first edge cost if (edge_list_[0]->end_node_ == mid_node_) { int32_t edge_sbp_start = start_node_->GetComponentSbpId(sbp_start, edge_list_[0]->start_node_); memory_cost += edge_list_[0]->GetMemory(edge_sbp_start, sbp_mid); weighted_cost += edge_list_[0]->weighted_cost_[edge_sbp_start][sbp_mid]; } 
else { int32_t edge_sbp_start = start_node_->GetComponentSbpId(sbp_start, edge_list_[0]->end_node_); memory_cost += edge_list_[0]->GetMemory(sbp_mid, edge_sbp_start); weighted_cost += edge_list_[0]->weighted_cost_[sbp_mid][edge_sbp_start]; } // Add second edge cost if (edge_list_[1]->start_node_ == mid_node_) { int32_t edge_sbp_end = end_node_->GetComponentSbpId(sbp_end, edge_list_[1]->end_node_); memory_cost += edge_list_[1]->GetMemory(sbp_mid, edge_sbp_end); weighted_cost += edge_list_[1]->weighted_cost_[sbp_mid][edge_sbp_end]; } else { int32_t edge_sbp_end = end_node_->GetComponentSbpId(sbp_end, edge_list_[1]->start_node_); memory_cost += edge_list_[1]->GetMemory(edge_sbp_end, sbp_mid); weighted_cost += edge_list_[1]->weighted_cost_[edge_sbp_end][sbp_mid]; } // Compare and look for the minimum cost if (sbp_mid == 0 || weighted_cost < min_weighted_cost) { min_sbp_mid = sbp_mid; min_memory_cost = memory_cost; min_weighted_cost = weighted_cost; } } // Store the results of the dynamic programming for minimizing the weighted sum if (in_memory_support_) { memory_[sbp_start][sbp_end] = min_memory_cost; } weighted_cost_[sbp_start][sbp_end] = min_weighted_cost; mid_node_sbp_sig_[sbp_start][sbp_end] = min_sbp_mid; } } } else { // Edge elimination int32_t end_node_sbp_size = end_node_->weighted_cost_.size(); for (int32_t sbp_start = 0; sbp_start < weighted_cost_.size(); sbp_start++) { if (in_memory_support_) { memory_[sbp_start].resize(end_node_sbp_size); } weighted_cost_[sbp_start].resize(end_node_sbp_size); for (int32_t sbp_end = 0; sbp_end < end_node_sbp_size; sbp_end++) { int64_t memory_cost = 0; double weighted_cost = 0.0; for (int32_t edge_num = 0; edge_num < edge_list_.size(); edge_num++) { // For normal edge elimination, instead of recomputation with different memory ratio // Either (start_node_ == edge_list_[edge_num]->start_node_ // and end_node_ == edge_list_[edge_num]->end_node_) is true // Or (start_node_ == edge_list_[edge_num]->end_node_ and // end_node_ == edge_list_[edge_num]->start_node_) is true. // At this moment, start_node_->component2merged_sig_id2component_sig_id_ is not // initialized. As a result, if start_node_ != edge_list_[edge_num]->start_node_, // IsComponent() would return false immediately. 
if (start_node_->IsComponent(edge_list_[edge_num]->start_node_)) { int32_t edge_sbp_start = start_node_->GetComponentSbpId(sbp_start, edge_list_[edge_num]->start_node_); int32_t edge_sbp_end = end_node_->GetComponentSbpId(sbp_end, edge_list_[edge_num]->end_node_); memory_cost += edge_list_[edge_num]->GetMemory(edge_sbp_start, edge_sbp_end); weighted_cost += edge_list_[edge_num]->weighted_cost_[edge_sbp_start][edge_sbp_end]; } else { // At this moment // start_node_->IsComponent(edge_list_[edge_num]->end_node_) // end_node_->IsComponent(edge_list_[edge_num]->start_node_) int32_t edge_sbp_start = start_node_->GetComponentSbpId(sbp_start, edge_list_[edge_num]->end_node_); int32_t edge_sbp_end = end_node_->GetComponentSbpId(sbp_end, edge_list_[edge_num]->start_node_); memory_cost += edge_list_[edge_num]->GetMemory(edge_sbp_end, edge_sbp_start); weighted_cost += edge_list_[edge_num]->weighted_cost_[edge_sbp_end][edge_sbp_start]; } } if (in_memory_support_) { memory_[sbp_start][sbp_end] = memory_cost; } weighted_cost_[sbp_start][sbp_end] = weighted_cost; } } } } void SbpEdge::DuplicateCost( bool merged_node_is_start_node, bool duplicating_first_node, const std::vector>& merged_sig_id2half_sig_id) { const int32_t num_sig = merged_sig_id2half_sig_id.size(); std::vector> copy_cost; std::vector> temp_mid_node_sbp_sig; std::vector> temp_memory; std::vector> weighted_cost; if (merged_node_is_start_node) { if (edge_list_.empty()) { copy_cost.resize(num_sig); } if (mid_node_) { temp_mid_node_sbp_sig.resize(num_sig); } weighted_cost.resize(num_sig); if (in_memory_support_) { temp_memory.resize(num_sig); } for (int32_t i = 0; i < num_sig; i++) { const int32_t sig_idx = duplicating_first_node ? merged_sig_id2half_sig_id[i].first : merged_sig_id2half_sig_id[i].second; if (edge_list_.empty()) { copy_cost[i] = cost_[sig_idx]; } weighted_cost[i] = weighted_cost_[sig_idx]; if (mid_node_) { temp_mid_node_sbp_sig[i] = mid_node_sbp_sig_[sig_idx]; } if (in_memory_support_) { temp_memory[i] = memory_[sig_idx]; } } } else { const int32_t num_start_sig = weighted_cost_.size(); if (edge_list_.empty()) { copy_cost.resize(num_start_sig); } weighted_cost.resize(num_start_sig); if (mid_node_) { temp_mid_node_sbp_sig.resize(num_start_sig); } if (in_memory_support_) { temp_memory.resize(num_start_sig); } for (int32_t i = 0; i < num_start_sig; i++) { if (edge_list_.empty()) { copy_cost[i].resize(num_sig); } weighted_cost[i].resize(num_sig); if (mid_node_) { temp_mid_node_sbp_sig[i].resize(num_sig); } if (in_memory_support_) { temp_memory[i].resize(num_sig); } for (int32_t j = 0; j < num_sig; j++) { const int32_t sig_idx = duplicating_first_node ? 
merged_sig_id2half_sig_id[j].first : merged_sig_id2half_sig_id[j].second; if (edge_list_.empty()) { copy_cost[i][j] = cost_[i][sig_idx]; } weighted_cost[i][j] = weighted_cost_[i][sig_idx]; if (mid_node_) { temp_mid_node_sbp_sig[i][j] = mid_node_sbp_sig_[i][sig_idx]; } if (in_memory_support_) { temp_memory[i][j] = memory_[i][sig_idx]; } } } } if (edge_list_.empty()) { cost_ = copy_cost; } weighted_cost_ = weighted_cost; if (mid_node_) { mid_node_sbp_sig_ = temp_mid_node_sbp_sig; } if (in_memory_support_) { memory_ = temp_memory; } } // Compute the weighted sum of the time and memory cost void SbpEdge::ComputeWeightedCost() { if (edge_list_.empty()) { // If this edge does not contain any sub edges, it should have original cost weighted_cost_ = cost_; if (in_memory_support_) { for (int32_t i = 0; i < memory_.size(); i++) { auto& memory_i = memory_[i]; auto& weighted_cost_i = weighted_cost_[i]; for (int32_t j = 0; j < memory_[i].size(); j++) { weighted_cost_i[j] += kMemoryRatio * memory_i[j]; } } } } else { // Compute the weighted cost for sub components for (auto& sbp_edge : edge_list_) { sbp_edge->ComputeWeightedCost(); } if (mid_node_) { mid_node_->ComputeWeightedCost(); } // Generate relationship if two vertices are merged nodes // For example, we have 4 nodes: A, B, C, D // and two edges: 1: A->B, 2: A->B // We merge the two edges 1 and 2 into 3: A->B. // Then we merge A and C into E and merge B and D into F. // Now the edge 3: E->F has two sub edges: 1: A->B, 2:A->B, // which tell us that the sub edges might have different vertices from the current edge. start_node_->GenerateComponentRelationship(); end_node_->GenerateComponentRelationship(); // Re-compute the weighted cost SummarizeCost(); } } void SbpEdge::FinalizeSbp() { // Finalize Sbp for mid_node_ if (mid_node_) { mid_node_->final_sbp_sig_id_ = mid_node_sbp_sig_[start_node_->final_sbp_sig_id_][end_node_->final_sbp_sig_id_]; mid_node_->FinalizeSbp(); } for (const auto& this_edge : edge_list_) { this_edge->FinalizeSbp(); } } double SbpEdge::GreedyStrategy() { // Sbp combination of the minimum cost int32_t min_sbp_start = start_node_->final_sbp_sig_id_, min_sbp_end = end_node_->final_sbp_sig_id_; // An unordered_map to evaluate cost between two edge nodes and other nodes. std::unordered_map node_list_id2nbh_id = {{start_node_->node_list_id_, 0}, {end_node_->node_list_id_, 1}}; // pre-compute and store the current cost between end_node_ and outside. std::vector end_node_out_cost(end_node_->weighted_cost_.size()); for (int32_t sbp_end = 0; sbp_end < weighted_cost_[0].size(); sbp_end++) { end_node_->final_sbp_sig_id_ = sbp_end; end_node_out_cost[sbp_end] = end_node_->EvalOutNbhCost(node_list_id2nbh_id); } // pre-compute and store the current cost between start_node_ and outside. 
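  // With both out-of-neighborhood cost vectors precomputed, the double loop below simply evaluates
  //   curr_cost = start_node_out_cost[i] + end_node_out_cost[j] + weighted_cost_[i][j]
  // for every signature pair (i, j), keeps the smallest one, and returns
  // min_cost - original_cost (<= 0), i.e. the change in total cost.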
std::vector start_node_out_cost(start_node_->weighted_cost_.size()); for (int32_t sbp_start = 0; sbp_start < weighted_cost_.size(); sbp_start++) { start_node_->final_sbp_sig_id_ = sbp_start; start_node_out_cost[sbp_start] = start_node_->EvalOutNbhCost(node_list_id2nbh_id); } // Current Cost, Minimum Cost, Cost with original sbp double curr_cost = 0.0; double min_cost = start_node_out_cost[min_sbp_start] + end_node_out_cost[min_sbp_end] + weighted_cost_[min_sbp_start][min_sbp_end]; double original_cost = min_cost; for (int32_t sbp_start = 0; sbp_start < weighted_cost_.size(); sbp_start++) { for (int32_t sbp_end = 0; sbp_end < weighted_cost_[0].size(); sbp_end++) { // compute Current Cost for Neighborhood of edge end_node_->final_sbp_sig_id_ = sbp_end; curr_cost = start_node_out_cost[sbp_start] + end_node_out_cost[sbp_end] + weighted_cost_[sbp_start][sbp_end]; // Find the minimum current cost if (curr_cost < min_cost) { min_cost = curr_cost; min_sbp_start = sbp_start; min_sbp_end = sbp_end; } } } start_node_->final_sbp_sig_id_ = min_sbp_start; end_node_->final_sbp_sig_id_ = min_sbp_end; return min_cost - original_cost; } // Get the minimum element in Cost double SbpEdge::GetMinWeightedCost() { // used the stored value if pre-computed. if (kMemoryRatio == memory_ratio4min_weighted_cost_ && min_weighted_cost_ >= 0) { return min_weighted_cost_; } // Check the size of Cost CHECK(weighted_cost_.size() > 0) << "Cost not initialized!" << std::endl; // Compute the min_cost for corresponding memory ratio min_weighted_cost_ = GetWeightedCost(); for (int32_t i = 0; i < weighted_cost_.size(); i++) { for (int32_t j = 0; j < weighted_cost_[i].size(); j++) { min_weighted_cost_ = std::min(min_weighted_cost_, GetWeightedCost(i, j)); } } // Store current the memory ratio memory_ratio4min_weighted_cost_ = kMemoryRatio; return min_weighted_cost_; } // Assemble copy cost void SbpEdge::InitCopyAndMemoryCost(const std::string& ibn, bool use_sbp_collector, bool nccl_not_use_compute_stream) { std::vector consumer_nd_sbp_sig2memory; if (nccl_not_use_compute_stream) { in_memory_support_ = true; // Compute and store the memory for consumer const auto& consumer_operator = end_node_->op_node_->op(); const auto& end_sbp_sig_list = end_node_->sbp_sig_list_; consumer_nd_sbp_sig2memory.resize(end_sbp_sig_list.size(), 0); const auto& lbi = consumer_operator.BnInOp2Lbi(ibn); const auto& consumer_hierarchy = *CHECK_JUST(consumer_operator.GetParallelDesc4BnInOp(ibn))->hierarchy(); const auto& logical_blob_desc = start_node_->op_node_->LogicalBlobDesc4Lbi(lbi); HashMap consumer_nd_sbp2memory; for (int32_t sbp_sig_id = 0; sbp_sig_id < end_sbp_sig_list.size(); sbp_sig_id++) { const NdSbp& nd_sbp = end_sbp_sig_list[sbp_sig_id].bn_in_op2nd_sbp().at(ibn); auto it = consumer_nd_sbp2memory.find(nd_sbp); if (it == consumer_nd_sbp2memory.end()) { // This compute the memory at rank 0, the largest one. // We could be faster if we just compute the average memory. it = consumer_nd_sbp2memory .insert({nd_sbp, MaxByteSize4BlobDescSbp(logical_blob_desc, nd_sbp, consumer_hierarchy)}) .first; } consumer_nd_sbp_sig2memory[sbp_sig_id] += it->second; } } // In this part, we assemble the cost from nodes to nodes. 
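  // Roughly, for each producer signature p and consumer signature c of this blob, the loops below do
  //   cost_[p][c] += elem_cnt(producer time shape) * copy_cost(p -> c)  if copy_cost(p -> c) < GetValidMaxCopyCost(),
  //   cost_[p][c]  = copy_cost(p -> c)                                  otherwise (a prohibitive value),
  // using ComputeCopyCostWithMiddleNodes() for the per-blob copy cost.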
if (start_node_->op_node_ && end_node_->op_node_) { OpNode* consumer = end_node_->op_node_; // Add copy cost for each blob const LogicalBlobId& lbi = consumer->op().BnInOp2Lbi(ibn); // Check whether lbi is transferred by this edge if (use_sbp_collector && !SearchLbi(lbi)) { return; } OpNode* producer = start_node_->op_node_; const std::string& producer_lbn = *CHECK_JUST(producer->op().obn4lbi(lbi)); const ParallelDesc& producer_parallel_desc = *CHECK_JUST(producer->op().GetParallelDesc4BnInOp(producer_lbn)); const ParallelDesc& consumer_parallel_desc = *CHECK_JUST(consumer->op().GetParallelDesc4BnInOp(ibn)); // Need to be careful, the logical blob description should be independent to current // SbpParallel. Use producer or op_node? const BlobDesc& logical_blob_desc = producer->LogicalBlobDesc4Lbi(lbi); const std::string& obn = *CHECK_JUST(producer->op().obn4lbi(lbi)); // If we are deciding whether we need the wait time, then make require_same_sbp true. // B->S cause cudaEventSynchronize in current implementation. bool require_same_sbp = RequireSameSbp(consumer, ibn); int32_t consumer_sbp_size = end_node_->sbp_sig_list_.size(); LazyMode::Guard enable_lazy_mode(true); // look through sbp signature in producer for (int32_t sbp_id_producer = 0; sbp_id_producer < start_node_->sbp_sig_list_.size(); sbp_id_producer++) { // get sbp parallel for a logical blob in producer const auto& producer_sbp_bn_in_op2sbp_parallel = start_node_->sbp_sig_list_[sbp_id_producer].bn_in_op2nd_sbp(); const NdSbp& sbp_producer = producer_sbp_bn_in_op2sbp_parallel.at(obn); auto& cost4sbp_id_producer = cost_[sbp_id_producer]; // look through sbp signature in consumer for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) { // get sbp parallel for a logical blob in consumer const auto& consumer_sbp_bn_in_op2sbp_parallel = end_node_->sbp_sig_list_[sbp_id_consumer].bn_in_op2nd_sbp(); const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn); // compute copy cost for a specific logical blob double curr_edge_cost = CHECK_JUST(ComputeCopyCostWithMiddleNodes( sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, require_same_sbp)); if (curr_edge_cost < GetValidMaxCopyCost()) { cost4sbp_id_producer[sbp_id_consumer] += CHECK_JUST(producer->op().GetOpTimeShape())->elem_cnt() * curr_edge_cost; } else { cost4sbp_id_producer[sbp_id_consumer] = curr_edge_cost; } // If enabling nccl_use_compute_stream and transfer occurs, // the current code would create a non-reusable register to receive data. 
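  // When nccl_not_use_compute_stream is set and this signature pair actually transfers data
  // (curr_edge_cost > 0), the consumer-side blob size precomputed above is charged to the
  // memory_ table of this edge for that pair.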
if (nccl_not_use_compute_stream && curr_edge_cost > 0) { memory_[sbp_id_producer][sbp_id_consumer] += consumer_nd_sbp_sig2memory[sbp_id_consumer]; } } } } } // Assemble memory cost void SbpEdge::InitializeMemory(const HashMap& lbi2id, const std::vector& id2count, const std::vector& producer_nd_sbp_sig2memory) { const auto& consumer_operator = end_node_->op_node_->op(); const auto& end_sbp_sig_list = end_node_->sbp_sig_list_; std::vector consumer_nd_sbp_sig2memory(end_sbp_sig_list.size(), 0); // Compute and store the memory for consumer for (const auto& ibn : consumer_operator.input_bns()) { // Match the ibn to find the hierarchy const auto& lbi = consumer_operator.BnInOp2Lbi(ibn); if (SearchLbi(lbi) && id2count.at(lbi2id.at(lbi)) > 0) { const auto& consumer_hierarchy = *CHECK_JUST(consumer_operator.GetParallelDesc4BnInOp(ibn))->hierarchy(); const auto& logical_blob_desc = start_node_->op_node_->LogicalBlobDesc4Lbi(lbi); HashMap consumer_nd_sbp2memory; for (int32_t sbp_sig_id = 0; sbp_sig_id < end_sbp_sig_list.size(); sbp_sig_id++) { const NdSbp& nd_sbp = end_sbp_sig_list[sbp_sig_id].bn_in_op2nd_sbp().at(ibn); auto it = consumer_nd_sbp2memory.find(nd_sbp); if (it == consumer_nd_sbp2memory.end()) { // This compute the memory at rank 0, the largest one. // We could be faster if we just compute the average memory. it = consumer_nd_sbp2memory .insert({nd_sbp, MaxByteSize4BlobDescSbp(logical_blob_desc, nd_sbp, consumer_hierarchy)}) .first; } consumer_nd_sbp_sig2memory[sbp_sig_id] += it->second; } } } // Avoid negative value for memory // For example, B -> S might reduce memory but we still consider 0 memory increment instead of // negative memory increment. if (*std::max_element(consumer_nd_sbp_sig2memory.begin(), consumer_nd_sbp_sig2memory.end()) > *std::min_element(producer_nd_sbp_sig2memory.begin(), producer_nd_sbp_sig2memory.end())) { in_memory_support_ = true; memory_.resize(producer_nd_sbp_sig2memory.size()); int32_t consumer_sbp_sig_size = consumer_nd_sbp_sig2memory.size(); for (int32_t i = 0; i < producer_nd_sbp_sig2memory.size(); i++) { auto& memory_i = memory_[i]; memory_i.resize(consumer_sbp_sig_size, 0); for (int32_t j = 0; j < consumer_sbp_sig_size; j++) { int64_t memory_difference = consumer_nd_sbp_sig2memory[j] - producer_nd_sbp_sig2memory[i]; // Only accept positive memory change if (memory_difference > 0) { memory_i[j] = memory_difference; } } } } } // Set the cut ratio double SbpEdge::GetCutRatio() const { int32_t num = 0; for (int32_t i = 0; i < weighted_cost_.size(); i++) { for (int32_t j = 0; j < weighted_cost_[i].size(); j++) { if (weighted_cost_[i][j] < GetValidMaxCopyCost()) { num++; } } } return double(num) / double(weighted_cost_.size() * weighted_cost_[0].size()); } // find the cut ratio // (#c>GetValidMaxCopyCost() in Cost)/(#c in Cost) double SbpEdge::FindCutRatio(int32_t threshold) const { double cut_ratio = GetCutRatio(); // lift the cut ratio to 1 to filter out some improper couples to avoid unlimited merging double n = weighted_cost_.size(); double m = weighted_cost_[0].size(); double num = cut_ratio * n * m; cut_ratio += 0.16 * (n + m) / double(threshold); if (num <= n * 2 || num <= m * 2 || (num <= threshold && cut_ratio < 0.51)) { return cut_ratio; } else { return 1.0; } } // load a logical blob void SbpEdge::LoadLbi(const LogicalBlobId& lbi) { carry_lbis_.insert(lbi); } // check the existence of a logical blob bool SbpEdge::SearchLbi(const LogicalBlobId& lbi) const { return carry_lbis_.find(lbi) != carry_lbis_.end(); } // unload a logical blob void 
SbpEdge::UnloadLbi(const LogicalBlobId& lbi) {
  if (carry_lbis_.erase(lbi) == 0) { std::cout << "Unload an empty lbi!" << std::endl; }
}

// Not carrying any blob
bool SbpEdge::EmptyLbi() const { return carry_lbis_.empty(); }

}  // namespace auto_parallel
}  // namespace oneflow


================================================
FILE: oneflow/core/auto_parallel/sbp_edge.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_

#include
#include
#include
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/graph/op_graph.h"

namespace oneflow {
namespace auto_parallel {

// An edge structure to deal with the SBP strategy.
// Please see SbpGraph for the whole algorithm and introduction.
class SbpEdge final {
  /* There are 3 types of edges:
   * 1. start_node_ -> end_node_
   *      Nothing special
   * 2. Multiple start_node_ -> end_node_
   *      edge_list_ will store all the edges which go from start_node_ to end_node_
   * 3. start_node_ -> mid_node_ -> end_node_
   *      It will pass by a middle node.
   */
 public:
  // Constructor for type 1 & 2
  SbpEdge(SbpNode* start_node, SbpNode* end_node) : start_node_(start_node), end_node_(end_node) {
    mid_node_ = nullptr;
  }
  // Constructor for type 3
  SbpEdge(SbpNode* start_node, SbpNode* mid_node, SbpNode* end_node, SbpEdge* first_edge,
          SbpEdge* second_edge);

  // Destructor
  ~SbpEdge();

  OF_DISALLOW_COPY_AND_MOVE(SbpEdge);

  bool operator==(const SbpEdge& other) { return this == &other; }

  // Update copy cost for type 2 and 3
  void SummarizeCost();
  // Duplicate Cost. Designed for merging two nodes.
  void DuplicateCost(bool merged_node_is_start_node, bool duplicating_first_node,
                     const std::vector<std::pair<int32_t, int32_t>>& merged_sig_id2half_sig_id);
  // Compute the weighted sum of the time and memory cost
  void ComputeWeightedCost();
  // Determine Final SbpSignature for attachment of this edge
  void FinalizeSbp();
  // Use Greedy Strategy to pick the sbp signature with minimum cost for this
  // edge. You should have an initial strategy before running this. And the
  // graph should be fully eliminated.
double GreedyStrategy(); // load a logical blob void LoadLbi(const LogicalBlobId& lbi); // check the existence of a logical blob bool SearchLbi(const LogicalBlobId& lbi) const; // unload a logical blob void UnloadLbi(const LogicalBlobId& lbi); // Not carrying any blob bool EmptyLbi() const; // Get the minimum element in Cost double GetMinWeightedCost(); // Assemble copy and partial cost void InitCopyAndMemoryCost(const std::string& ibn, bool use_sbp_collector, bool nccl_not_use_compute_stream); // Assemble memory cost void InitializeMemory(const HashMap& lbi2id, const std::vector& id2count, const std::vector& producer_nd_sbp_sig2memory); // find the cut ratio // (#c>GetValidMaxCopyCost() in Cost)/(#c in Cost) // But we would lift the cut ratio to 1 to filter out some improper couples double FindCutRatio(int32_t threshold) const; // Get the cut ratio double GetCutRatio() const; // Constant getter SbpNode* GetEndNode() const { return end_node_; } int64_t GetMemory(int32_t i, int32_t j) const { return in_memory_support_ ? memory_[i][j] : 0; } // Get the current memory with the current sbp signature index int64_t GetMemory() const { return GetMemory(start_node_->final_sbp_sig_id_, end_node_->final_sbp_sig_id_); } double GetWeightedCost(int32_t i, int32_t j) const { return weighted_cost_[i][j]; } // Get the current weighted cost with the current sbp signature index double GetWeightedCost() const { return GetWeightedCost(start_node_->final_sbp_sig_id_, end_node_->final_sbp_sig_id_); } private: friend class SbpNode; friend class SbpGraph; friend class SbpCollector; friend class SbpConstructor; // The edge point from start_node_ to end_node_ // It will have a middle node if and only if type 3 SbpNode *start_node_, *mid_node_, *end_node_; // Cost[sbp_i][sbp_j] is the total cost from start_node_ with sbp_i to end_node_ // with sbp_j std::vector> cost_; // SbpSignature for mid_node_ with corresponding Cost if type 3, empty otherwise std::vector> mid_node_sbp_sig_; // Contained edge list: // empty if type 1, // Parallel edges if type 2, // succeed edges if type 3 // the edge list might have reverse direction: // example 1: type 3 edge_list_ contain two edges: // mid_node_ -> start_node_, mid_node_ -> end_node_; // example 2: type 2 edge_list_ contain three edges: // start_node_ -> end_node_, end_node_ -> start_node_, start_node_ -> end_node_; std::vector edge_list_; // Time waiting for other gpus. pthread_cond_wait double wait_time_ = -1.0; // a set of ids of logical blobs carried/transferred on this sbp edge std::unordered_set carry_lbis_; // Minimum and maximum cost would not be changed by eliminations, which will generate new edges. // Also would not be changed by node merging, which will only perform cost copy for the expanding // dimensions. // Minimum cost in the 2D array Cost. // Would be initialized after GetMinWeightedCost(); // Only used in the final graph. // Such pre-store and access process save a lot time. // Gpt2 has 1178 storing and 14053 taking. // Bert has 1464 storing and 17633 taking. 
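  // The cached minimum below is only valid for the kMemoryRatio it was computed with;
  // memory_ratio4min_weighted_cost_ records that ratio so that GetMinWeightedCost() can
  // recompute the minimum whenever the global memory ratio changes.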
double min_weighted_cost_ = -1.0; // If consider memory, each GetMinWeightedCost would have a memory_ratio_search // Use the stored value for the same memory_ratio_search double memory_ratio4min_weighted_cost_ = -1.0; // The produced blob belongs to the support of the total memory bool in_memory_support_ = false; // The consumed memory for different sbp strategies std::vector> memory_; // The weighted sum of time cost and memory cost std::vector> weighted_cost_; }; } // namespace auto_parallel } // namespace oneflow #endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_ ================================================ FILE: oneflow/core/auto_parallel/sbp_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/auto_parallel/binary_set.h" #include "oneflow/core/auto_parallel/sbp_graph.h" #include "oneflow/core/auto_parallel/sbp_edge.h" #include "oneflow/core/auto_parallel/sbp_node.h" #include "oneflow/core/auto_parallel/algorithm_util.h" namespace oneflow { namespace auto_parallel { // function in cpp. Should be put in one file due to use of template // Otherwise we will need to declare specific template at the end of cpp file. namespace { static const int32_t kMinNodeInGraphForMerging = 4; } // anonymous namespace // Generate a node SbpNode* SbpGraph::GenerateNode() { SbpNode* this_node = new SbpNode(); node_list_.emplace_back(this_node); this_node->node_list_id_ = node_list_.size() - 1; return this_node; } void SbpGraph::RemoveFromNodeList(SbpNode* this_node) { if (this_node->node_list_id_ < 0) { return; } node_list_.back()->node_list_id_ = this_node->node_list_id_; RemoveFrom(node_list_, this_node->node_list_id_); this_node->node_list_id_ = -1; } SbpGraph::~SbpGraph() { for (auto this_node : node_list_) { delete this_node; } node_list_.clear(); } void SbpGraph::RandomSbpSignature(bool use_sbp_collector) const { for (const auto& this_node : node_list_) { if (this_node->sbp_sig_list_.size() > 0) { this_node->final_sbp_sig_id_ = rand() % this_node->sbp_sig_list_.size(); } else { // It must be a proxy when this_node->sbp_sig_list_.size() == 0 this_node->final_sbp_sig_id_ = rand() % this_node->parallel_candidates_.size(); } } }; void SbpGraph::SetDefaultSbpSig() const { for (const auto& this_node : node_list_) { this_node->final_sbp_sig_id_ = 0; } }; void SbpGraph::StoreOriginMemory() { // We do not need to store the origin cost and memory for edges // Because the origin cost and memory is the current cost and memory for a bare edge. // For nodes, we need to do so because child elimination would attach the child cost and memory to // the current cost and memory. 
for (auto& this_node : node_list_) { this_node->origin_cost_ = this_node->cost_; this_node->origin_memory_ = this_node->memory_; } } double SbpGraph::ComputeCost() const { // Overall cost under current strategy double graph_cost_ = 0; for (const auto& this_node : node_list_) { int32_t this_id = this_node->final_sbp_sig_id_; graph_cost_ += this_node->weighted_cost_[this_id]; for (const auto& edge_out : this_node->edges_out_) { graph_cost_ += edge_out->weighted_cost_[this_id][edge_out->end_node_->final_sbp_sig_id_]; } } return graph_cost_; } double SbpGraph::ComputeWeightedCost() const { // Overall cost under current strategy double graph_cost_ = 0; for (const auto& this_node : node_list_) { int32_t this_id = this_node->final_sbp_sig_id_; graph_cost_ += this_node->weighted_cost_[this_id]; for (const auto& edge_out : this_node->edges_out_) { graph_cost_ += edge_out->weighted_cost_[this_id][edge_out->end_node_->final_sbp_sig_id_]; } } return graph_cost_; } // Re-compute weighted cost void SbpGraph::ReComputeWeightedCost() { for (const auto& this_node : node_list_) { this_node->ComputeWeightedCost(); for (const auto& edge_out : this_node->edges_out_) { edge_out->ComputeWeightedCost(); } } } int64_t SbpGraph::GetMemory() const { // Overall memory under current strategy int64_t total_memory = 0; for (const auto& this_node : node_list_) { total_memory += this_node->GetMemory(); for (const auto& edge_out : this_node->edges_out_) { total_memory += edge_out->GetMemory(); } } return total_memory; } int32_t SbpGraph::NodeElimination(SbpNode* this_node) { if (this_node->edges_in_.size() + this_node->edges_out_.size() == 2) { std::vector two_nodes; for (const auto& one_edge : this_node->edges_in_) two_nodes.emplace_back(one_edge->start_node_); for (const auto& one_edge : this_node->edges_out_) two_nodes.emplace_back(one_edge->end_node_); // If a node is pointing to itself, could happen when shrink from a circle if (two_nodes[0] == two_nodes[1]) { int32_t elimination_number = 0; if (this_node->edges_out_.empty()) { elimination_number += EdgeElimination(two_nodes[0]); } else { elimination_number += EdgeElimination(this_node); } elimination_number += ChildElimination(this_node); return elimination_number; } std::vector two_edges(this_node->edges_in_); two_edges.insert(two_edges.end(), this_node->edges_out_.begin(), this_node->edges_out_.end()); int32_t edges_in_size = this_node->edges_in_.size(); SbpEdge* e = new SbpEdge(two_nodes[0], this_node, two_nodes[1], two_edges[0], two_edges[1]); e->SummarizeCost(); // check and remove the edge_in with new edge in graph for (int32_t i = 0; i < edges_in_size; i++) { CheckAndRemoveFrom(two_nodes[i]->edges_out_, two_edges[i]); } // check and remove the edge_out with new edge in graph for (int32_t i = edges_in_size; i < 2; i++) { CheckAndRemoveFrom(two_nodes[i]->edges_in_, two_edges[i]); } // Let e take control of edge_list_ completely by disconnecting MidNode e->mid_node_->edges_out_.clear(); e->mid_node_->edges_in_.clear(); // Insert new compound edge into graph two_nodes[0]->edges_out_.emplace_back(e); two_nodes[1]->edges_in_.emplace_back(e); // eliminate the node from graph by swapping with the last element and // popping RemoveFromNodeList(this_node); // successfully eliminate this node return 1; } // can not eliminate this node return 0; } int32_t SbpGraph::NodeAndEdgeEliminations() { // Total elimination number int32_t total_elimination_num = 0; int32_t elimination_num = 1; // repeat these kinds of elimination until stuck while (elimination_num > 0) { 
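    // One simplification round: sweep node, edge and child eliminations over all nodes; if the
    // round removes nothing and more than two nodes remain, merge one pair of nodes
    // (PickAndMerge) and retry edge elimination. The outer loop stops once a round changes nothing.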
elimination_num = 0; for (int32_t i = node_list_.size() - 1; i >= 0; i--) { elimination_num += NodeElimination(node_list_[i]); } for (int32_t i = node_list_.size() - 1; i >= 0; i--) { elimination_num += EdgeElimination(node_list_[i]); } for (int32_t i = node_list_.size() - 1; i >= 0; i--) { elimination_num += ChildElimination(node_list_[i]); } if (elimination_num == 0 && node_list_.size() > 2) { elimination_num += PickAndMerge(); for (int32_t i = node_list_.size() - 1; i >= 0; i--) { elimination_num += EdgeElimination(node_list_[i]); } } total_elimination_num += elimination_num; } return total_elimination_num; } int32_t SbpGraph::EdgeElimination(SbpNode* this_node) const { // Remove all edges with (start_node -> end_node) from edges_in_ of end_node auto RemoveFromEdgesIn = [](SbpNode* start_node, SbpNode* end_node) -> void { for (int32_t i = end_node->edges_in_.size() - 1; i >= 0; i--) { if (start_node == end_node->edges_in_[i]->start_node_) { RemoveFrom(end_node->edges_in_, i); } } }; auto LookForParallelEdge = [](SbpEdge*& e, SbpNode* start_node, SbpNode* end_node, bool if_reverse, int32_t stop_sign) -> int32_t { // elimination edges with specific start node and end node in // start_node->edges_out_ from index stop sign to the end. // start_node->edges_out_[stop_sign] not included and need special treatment // after this process. int32_t elimination_num = 0; for (int32_t j = start_node->edges_out_.size() - 1; j > stop_sign; j--) { if (end_node == start_node->edges_out_[j]->end_node_) { if (!e) { if (if_reverse) { e = new SbpEdge(end_node, start_node); } else { e = new SbpEdge(start_node, end_node); } } // edge elimination e->edge_list_.emplace_back(start_node->edges_out_[j]); elimination_num++; RemoveFrom(start_node->edges_out_, j); } } return elimination_num; }; int32_t elimination_num = 0; for (int32_t i = 0; i < this_node->edges_out_.size(); i++) { SbpEdge* e = nullptr; // Find and delete Parallel Edges from edges_out_ elimination_num += LookForParallelEdge(e, this_node, this_node->edges_out_[i]->end_node_, /*if_reverse=*/false, i); elimination_num += LookForParallelEdge(e, this_node->edges_out_[i]->end_node_, this_node, /*if_reverse=*/true, /*stop_sign=*/-1); if (e) { // Delete Parallel Edges from edges_in_ RemoveFromEdgesIn(this_node, e->end_node_); RemoveFromEdgesIn(e->end_node_, this_node); // Add the compound edge e->edge_list_.emplace_back(this_node->edges_out_[i]); this_node->edges_out_[i] = e; e->SummarizeCost(); e->end_node_->edges_in_.emplace_back(e); } } return elimination_num; } int32_t SbpGraph::ChildElimination(SbpNode* this_node) { if (this_node->EliminateItselfAsChild()) { // eliminate this node from global node list RemoveFromNodeList(this_node); // successfully eliminate this node return 1; } else { // can not eliminate this node return 0; } } // Merge two nodes int32_t SbpGraph::NodeMerging(SbpNode* first, SbpNode* second) { SbpNode* new_node = new SbpNode(first, second); // Adjust node_list_ RemoveFromNodeList(first); RemoveFromNodeList(second); new_node->node_list_id_ = node_list_.size(); node_list_.emplace_back(new_node); return 1; } void SbpGraph::FinalizeSbp() const { for (const auto& this_node : node_list_) { this_node->FinalizeSbp(); } } double SbpGraph::GreedyStrategy(bool for_node) const { // Overall, this function should be replaced by GreedyStrategy(nbh_num); // Total Cost Reduce & Cost Reduce for one loop double total_cost_reduction = 0, cost_reduction = 0; for (int32_t step = node_list_.size(); step >= 0; step--) { cost_reduction = 0; for (SbpNode* 
this_node : node_list_) { // Use GreedyStrategy on Nodes if there is one node left for this // connected component. Otherwise, Use GreedyStrategy on Edges. if (for_node || this_node->edges_in_.size() + this_node->edges_out_.size() == 0) { cost_reduction += this_node->GreedyStrategy(); } else { // GreedyStrategy on Edges. for (SbpEdge* this_edge : this_node->edges_out_) { double second_rdc = this_edge->GreedyStrategy(); cost_reduction += second_rdc; } } } if (cost_reduction == 0) { break; } total_cost_reduction += cost_reduction; } return total_cost_reduction; } double SbpGraph::GreedyStrategy(int32_t nbh_num) const { // nbh_num is the maximum number of neighborhood to adjust sbp strategy in each step // Total Cost Reduce & Cost Reduce for one loop double total_cost_reduction = 0, cost_reduction = 0; // A global buffer to store part of the one ring neighborhood. std::vector nbh_id2node_list_id; // Not accept a number lower than 1 if (nbh_num < 1) { nbh_num = 1; } nbh_id2node_list_id.resize(nbh_num); std::vector original_sbp_sig_id(nbh_num); // store all the node_list_id whose corresponding nodes will be visited // We can use unordered_map to do this but vector is faster std::vector pre_visit_node_list(node_list_.size() + 1); for (int32_t nbh_id = 0; nbh_id < node_list_.size(); nbh_id++) { pre_visit_node_list[nbh_id] = nbh_id; } int32_t head = 0, tail = node_list_.size(); // whether a node_list_id is in pre_visit_node_list std::vector pre_visit_tags(node_list_.size(), true); int32_t step = 0; // 1 ring neighborhood buffer std::vector nbh_1ring(nbh_num); // 2 ring neighborhood buffer std::vector nbh_2ring; std::vector node_tags(node_list_.size(), false); std::vector nbh_1ring_buffer; while (head != tail && step < node_list_.size()) { auto* this_node = node_list_[pre_visit_node_list[head]]; if (nbh_num <= 1) { // Greedy strategy on nodes, here we use nbh_1ring to store the nbh_id2node_list_id // information for reutilization nbh_1ring[0] = this_node->node_list_id_; // store the original sbp signature of the 1-ring neighborhood for comparison original_sbp_sig_id[0] = this_node->final_sbp_sig_id_; cost_reduction = NbhGreedyStrategy(nbh_1ring); } else { // Use GreedyStrategy on the one ring neighborhood of this node. this_node->OneRingNeighborhood(nbh_1ring); // store the original sbp signature of the 1-ring neighborhood for comparison original_sbp_sig_id.resize(nbh_1ring.size()); for (int32_t nbh_id = 0; nbh_id < nbh_1ring.size(); nbh_id++) { original_sbp_sig_id[nbh_id] = node_list_[nbh_1ring[nbh_id]]->final_sbp_sig_id_; } if (nbh_1ring.size() <= nbh_num) { cost_reduction = NbhGreedyStrategy(nbh_1ring); } else { // Use GreedyStrategy on part of the one ring neighborhood. // Loop through the neighborhood. Each loop should contain the centroid. 
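        // When the 1-ring neighborhood holds more than nbh_num nodes, it is adjusted in
        // overlapping chunks: slot 0 of nbh_id2node_list_id is set once to nbh_1ring[0]
        // (the centroid) and kept there, while the remaining nbh_num - 1 slots slide over
        // the rest of the ring, so every chunk passed to NbhGreedyStrategy() contains the centroid.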
// Initialize part of the one ring neighborhood int32_t nbh_1ring_id = nbh_1ring.size() - nbh_num; for (int32_t nbh_id = 1; nbh_id < nbh_num; ++nbh_id) { nbh_id2node_list_id[nbh_id] = nbh_1ring[++nbh_1ring_id]; } // loop through the one ring neighborhood cost_reduction = 0; int32_t nbh_id = 0; for (nbh_1ring_id = 0; nbh_1ring_id < nbh_1ring.size(); ++nbh_1ring_id) { nbh_id2node_list_id[nbh_id] = nbh_1ring[nbh_1ring_id]; cost_reduction += NbhGreedyStrategy(nbh_id2node_list_id); // nbh_id for the next step if (++nbh_id >= nbh_num) { nbh_id = 1; } } } } // change of strategies if (cost_reduction != 0) { // Add neighborhood into pre-visited node list for each node with changing strategy for (int32_t nbh_id = 0; nbh_id < nbh_1ring.size(); nbh_id++) { // If changes occur if (original_sbp_sig_id[nbh_id] != node_list_[nbh_1ring[nbh_id]]->final_sbp_sig_id_) { // schedule to visit the neighborhood of that changing node node_list_[nbh_1ring[nbh_id]]->NRingNeighborhood(2, nbh_2ring, nbh_1ring_buffer, node_list_, node_tags); for (int32_t nbh_node_list_id : nbh_2ring) { // Put them into the pre-visited node list if (!pre_visit_tags[nbh_node_list_id]) { pre_visit_node_list[tail] = nbh_node_list_id; pre_visit_tags[nbh_node_list_id] = true; tail++; if (tail == pre_visit_node_list.size()) { tail = 0; } } } } } } // Finish visiting pre_visit_tags[pre_visit_node_list[head]] = false; head++; if (head == pre_visit_node_list.size()) { head = 0; step++; } total_cost_reduction += cost_reduction; } return total_cost_reduction; } void SbpGraph::DfsAddNbhCost(std::vector& nbh_id2node_list_id, std::unordered_map& node_list_id2nbh_id, std::vector& order2nbh_id, std::vector& nbh_id2order, std::vector& order2acc_min_in_nbh_cost, std::vector>& out_nbh_costs, std::vector>& nbh_id2order2sbp_id, std::vector& min_sbp_sig_id, double& min_cost, int32_t order, double curr_cost) const { // We have finished visiting the neighborhood if (order >= nbh_id2node_list_id.size()) { // relative difference > 1e-12 if (curr_cost < min_cost * kFloatDeviationMinus) { min_cost = curr_cost; for (int32_t nbh_id = 0; nbh_id < nbh_id2node_list_id.size(); nbh_id++) { min_sbp_sig_id[nbh_id] = node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_; } } return; } // Pruning, remove all those branch with large cost if (curr_cost + order2acc_min_in_nbh_cost[order] >= min_cost) { return; } // Deep first search in the next order int32_t nbh_id = order2nbh_id[order]; SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]]; for (int32_t sbp_id : nbh_id2order2sbp_id[nbh_id]) { sbp_node->final_sbp_sig_id_ = sbp_id; DfsAddNbhCost(nbh_id2node_list_id, node_list_id2nbh_id, order2nbh_id, nbh_id2order, order2acc_min_in_nbh_cost, out_nbh_costs, nbh_id2order2sbp_id, min_sbp_sig_id, min_cost, order + 1, curr_cost + out_nbh_costs[nbh_id][sbp_id] + sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order)); } } bool SbpGraph::DfsFindReasonableCost(std::vector& nbh_id2node_list_id, std::unordered_map& node_list_id2nbh_id, std::vector& nbh_id2order, int32_t nbh_id) const { // We found such a strategy if (nbh_id == nbh_id2order.size()) { return true; } SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]]; // Start from B. 
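  // Candidate signatures are tried from the back of the list (broadcast-like signatures tend to
  // sit at the end, hence "Start from B"); the first candidate whose node cost plus the cost
  // towards already-fixed neighbors stays below GetValidMaxCopyCost() is kept and the search
  // recurses to the next node, backtracking when no candidate works.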
for (int32_t sbp_id = sbp_node->weighted_cost_.size() - 1; sbp_id >= 0; sbp_id--) { sbp_node->final_sbp_sig_id_ = sbp_id; // If the cost for this node is reasonable, then go to the next one if (sbp_node->weighted_cost_[sbp_id] + sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order) < GetValidMaxCopyCost()) { if (DfsFindReasonableCost(nbh_id2node_list_id, node_list_id2nbh_id, nbh_id2order, nbh_id + 1)) { // If we found one strategy, then exist the Dfs. return true; } } } // Can not find a reasonable strategy with the setting for previous nodes. // Go back and change the previous node. return false; } // Find one strategy with finite cost for adjustment Maybe SbpGraph::Find1Strategy4Greedy() const { std::vector nbh_id2node_list_id; std::vector not_visited(node_list_.size(), true); std::vector nbh_1ring; int32_t head = 0; int32_t tail = 0; std::vector node_cut_ratios(node_list_.size()); // Initialize cut ratio for all the nodes for (int32_t node_list_id = 0; node_list_id < node_list_.size(); node_list_id++) { node_cut_ratios[node_list_id] = node_list_[node_list_id]->GetCutRatio(); } // If have not visited all the nodes while (tail < node_list_.size()) { // Find the node with the minimum cut ratio int32_t node_with_min_cut_ratio = -1; double min_cut_ratio = 2.0; for (int32_t node_list_id = 0; node_list_id < node_list_.size(); node_list_id++) { if (not_visited[node_list_id]) { double curr_cut_ratio = node_cut_ratios[node_list_id]; if (curr_cut_ratio < min_cut_ratio) { min_cut_ratio = curr_cut_ratio; node_with_min_cut_ratio = node_list_id; } } } // put this node into the open set nbh_id2node_list_id.push_back(node_with_min_cut_ratio); not_visited[node_with_min_cut_ratio] = false; tail++; // BFS while (head < tail) { // look for the neighborhood of the head int32_t node_list_id = nbh_id2node_list_id[head]; node_list_[node_list_id]->OneRingNeighborhood(nbh_1ring); // sort std::sort(nbh_1ring.begin(), nbh_1ring.end(), [&](int32_t i, int32_t j) { return node_cut_ratios[i] < node_cut_ratios[j]; }); for (int32_t curr_id : nbh_1ring) { if (not_visited[curr_id]) { nbh_id2node_list_id.push_back(curr_id); tail++; not_visited[curr_id] = false; } } head++; } } // mapping from the node_list_id to the id in the nbh_id2node_list_id std::unordered_map node_list_id2nbh_id; InverseFunction(nbh_id2node_list_id, node_list_id2nbh_id); // Initial an ordinary order std::vector nbh_id2order(nbh_id2node_list_id.size()); for (int32_t nbh_id = 0; nbh_id < nbh_id2node_list_id.size(); nbh_id++) { nbh_id2order[nbh_id] = nbh_id; } // Combining deep first search and pruning based on cut ratio CHECK(DfsFindReasonableCost(nbh_id2node_list_id, node_list_id2nbh_id, nbh_id2order, /*nbh_id=*/0)) << "Can't find a reasonable strategy!"; return Maybe::Ok(); } // Use brute force to search for a strategy with minimum cost for a neighborhood double SbpGraph::NbhGreedyStrategy(std::vector& nbh_id2node_list_id) const { // number of nodes in the neighborhood int32_t num_nbh = nbh_id2node_list_id.size(); // mapping from the node_list_id to the id in the nbh_id2node_list_id std::unordered_map node_list_id2nbh_id; InverseFunction(nbh_id2node_list_id, node_list_id2nbh_id); // a sbp signature id set minimizing the overall cost, store the original one as default std::vector min_sbp_sig_id(num_nbh); for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) { min_sbp_sig_id[nbh_id] = node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_; } // pre-compute and store the cost between neighborhood and outside nodes under different sbp for // each 
node within the neighborhood std::vector> out_nbh_costs(num_nbh); for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) { SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]]; out_nbh_costs[nbh_id].resize(sbp_node->weighted_cost_.size()); for (int32_t sbp_id = sbp_node->weighted_cost_.size() - 1; sbp_id >= 0; sbp_id--) { sbp_node->final_sbp_sig_id_ = sbp_id; out_nbh_costs[nbh_id][sbp_id] = sbp_node->EvalOutNbhCost(node_list_id2nbh_id); } } // pre-compute and store the order of the out_nbh_costs std::vector> nbh_id2order2sbp_id(num_nbh); auto CompareDoubleLess = [](double a, double b) { return a < b; }; for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) { DecideOrder(out_nbh_costs[nbh_id], nbh_id2order2sbp_id[nbh_id], CompareDoubleLess); } // Decide the order to go through the neighborhood. // Should visit those nodes with a larger difference in the out cost first. std::vector out_nbh_cost_diff(num_nbh); for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) { out_nbh_cost_diff[nbh_id] = *std::max_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end()) - *std::min_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end()); } std::vector order2nbh_id; DecideOrder(out_nbh_cost_diff, order2nbh_id, [](double a, double b) { return a > b; }); // Find the inverse map of order std::vector nbh_id2order; InverseOrder(order2nbh_id, nbh_id2order); // Current Cost, Minimum Cost, Cost with original sbp double original_cost = 0; // Recover original sbp for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) { node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_ = min_sbp_sig_id[nbh_id]; } // Compute cost with original sbp for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) { SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]]; original_cost += out_nbh_costs[nbh_id][min_sbp_sig_id[nbh_id]]; original_cost += sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order); } double min_cost = original_cost; // Accumulate minimum cost from the current node to the end of the neighborhood node list. // The accumulated cost include the current node. std::vector order2acc_min_in_nbh_cost(num_nbh); order2acc_min_in_nbh_cost[num_nbh - 1] = *std::min_element(out_nbh_costs[order2nbh_id[num_nbh - 1]].begin(), out_nbh_costs[order2nbh_id[num_nbh - 1]].end()); for (int32_t order = num_nbh - 2; order >= 0; order--) { int32_t nbh_id = order2nbh_id[order]; order2acc_min_in_nbh_cost[order] = order2acc_min_in_nbh_cost[order + 1] + *std::min_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end()) + node_list_[nbh_id2node_list_id[nbh_id]]->EvalMinInNbhCost(node_list_id2nbh_id, nbh_id2order); } // Use brute force (DFS) to adjust for the best strategy in the neighborhood. DfsAddNbhCost(nbh_id2node_list_id, node_list_id2nbh_id, order2nbh_id, nbh_id2order, order2acc_min_in_nbh_cost, out_nbh_costs, nbh_id2order2sbp_id, min_sbp_sig_id, min_cost, /*order=*/0, /*curr_cost=*/0); // Use the sbp strategy with minimum cost for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) { node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_ = min_sbp_sig_id[nbh_id]; } if (min_cost < original_cost) { // Directly return (min_cost - original_cost) might have floating point error up to 3e-16 // For example, original_cost: 2.22507e+06, min_cost: 2.22507e+06, // diff: -4.65661e-10, relative diff:2.09279e-16 // Therefore, we use a threshold to filter out such fake true detection to // avoid unlimited search. 
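// NOTE (annotation, not part of the original source): kFloatDeviationMinus acts as the relative
// tolerance here; judging from the comment in DfsAddNbhCost ("relative difference > 1e-12") it is
// presumably a constant just below 1.0, on the order of 1 - 1e-12. Worked example with made-up
// numbers: original_cost = 2.22507e+06 and min_cost = original_cost - 4.65661e-10. Then
// original_cost * kFloatDeviationMinus ~= original_cost - 2.2e-06, which is still smaller than
// min_cost, so the ~5e-10 "improvement" is treated as floating point noise and 0.0 is returned
// rather than a negative cost reduction.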
if (original_cost * kFloatDeviationMinus > min_cost) { return min_cost - original_cost; } } return 0.0; } // Select and Merge two nodes int32_t SbpGraph::PickAndMerge() { if (node_list_.size() < kMinNodeInGraphForMerging) { return 0; } // Pick the one with the smallest cut ratio double min_cut_ratio = 1.0; double curr_cut_ratio = 0.0; SbpEdge* merging_edge = nullptr; for (int32_t i = 0; i < node_list_.size(); i++) { for (SbpEdge* edge_in : node_list_[i]->edges_in_) { curr_cut_ratio = edge_in->FindCutRatio(threshold_); if (curr_cut_ratio < min_cut_ratio) { min_cut_ratio = curr_cut_ratio; merging_edge = edge_in; } } } if (merging_edge != nullptr) { // Merge two nodes on the edge with the minimum cut ratio return NodeMerging(merging_edge->start_node_, merging_edge->end_node_); } else { // Pick the couple with the largest similar neighborhood std::vector node_binary_sets(node_list_.size()); for (int32_t i = 0; i < node_list_.size(); i++) { // Transfer edge to binary set node_binary_sets[i].Initialize(node_list_.size()); node_binary_sets[i].AddEntry(i); for (const SbpEdge* edge_in : node_list_[i]->edges_in_) { node_binary_sets[i].AddEntry(edge_in->start_node_->node_list_id_); } for (const SbpEdge* edge_out : node_list_[i]->edges_out_) { node_binary_sets[i].AddEntry(edge_out->start_node_->node_list_id_); } } // Find two nodes with largest common subset // buffer of binary set BinarySet buffer_binary_set(node_list_.size()); // Number of common edges int32_t max_comm_edge_num = 0, curr_comm_edge_num = 0; int32_t min_node_pair[2]; // Number of Sbp Signature in merged node int32_t min_sbp_num = 0, curr_sbp_num = 0; for (int32_t i = 0; i < node_list_.size(); i++) { for (int32_t j = i + 1; j < node_list_.size(); j++) { curr_sbp_num = node_list_[i]->weighted_cost_.size() * node_list_[j]->weighted_cost_.size(); if (curr_sbp_num <= threshold_) { node_binary_sets[i].IntersectionTo(node_binary_sets[j], buffer_binary_set); curr_comm_edge_num = buffer_binary_set.Total(); if (curr_comm_edge_num > max_comm_edge_num || (curr_comm_edge_num == max_comm_edge_num && curr_sbp_num < min_sbp_num)) { min_node_pair[0] = i; min_node_pair[1] = j; max_comm_edge_num = curr_comm_edge_num; min_sbp_num = curr_sbp_num; } } } } if (max_comm_edge_num > 0) { return NodeMerging(node_list_[min_node_pair[0]], node_list_[min_node_pair[1]]); } else { return 0; } } } // Clip an edge, remove it from graph void SbpGraph::ClipEdge(SbpEdge* this_edge) const { CheckAndRemoveFrom(this_edge->end_node_->edges_in_, this_edge); CheckAndRemoveFrom(this_edge->start_node_->edges_out_, this_edge); delete this_edge; } // Compute the minimum and maximum layer of each node in the graph int32_t SbpGraph::ComputeLayer( HashMap& op_name2sbp_node, const HashMap>& op_node2mutable_op_ctrl_deps) const { // Compute minimum layer for (SbpNode* this_node : node_list_) { this_node->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps); } // Find the largest minimum layer int32_t max_min_layer = -1; for (SbpNode* this_node : node_list_) { if (max_min_layer < this_node->min_layer_) { max_min_layer = this_node->min_layer_; } } // Compute maximum layer for (SbpNode* this_node : node_list_) { this_node->SpreadMaxLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps); } for (SbpNode* this_node : node_list_) { this_node->LiftMaxLayer(max_min_layer); } return max_min_layer; } // TODO: Remove the tributary layer here. 
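// NOTE (annotation, not part of the original source): a short summary of FindTrunk below.
// Walking layers from max_min_layer backwards, per-layer minimal costs are accumulated until the
// sum exceeds 0.5 * wait_time_; every node whose min_layer_ is at or beyond that layer seeds the
// trunk, and SpreadTrunk pulls in tightly coupled producers. Tributary (off-trunk) work that can
// overlap with the trunk is then used to shrink the wait time charged to trunk edges.
// Hypothetical numbers: with wait_time_ = 100 and an accumulated tributary cost of 60 at some
// layer, the trunk ops at that layer get their wait time lowered to 100 - 60 = 40.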
// Find the trunk of the sbp graph, then reduce the wait time for tributaries void SbpGraph::FindTrunk(int32_t max_min_layer, HashMap& op_name2sbp_node) const { // Summarize cost for each layer, on the trunk or tributaries std::vector trunk_cost(max_min_layer + 1, 0); for (SbpNode* this_node : node_list_) { trunk_cost[this_node->min_layer_] += this_node->GetMinCost(); } // Decide trunks double acc_cost = 0; // All the nodes with MinLayer>=trunk_end_id would be considered as trunks int32_t trunk_end_id = max_min_layer; for (int32_t layer_id = max_min_layer; layer_id >= 0; layer_id--) { acc_cost += trunk_cost[layer_id]; if (acc_cost > 0.5 * wait_time_) { trunk_end_id = layer_id; break; } } // Find out all the nodes on the trunk. for (SbpNode* this_node : node_list_) { if (this_node->min_layer_ >= trunk_end_id) { this_node->SpreadTrunk(op_name2sbp_node); } } // Compute maximum layer for tributaries // Clear counter and initialize tributary layer for each sbp node for (SbpNode* this_node : node_list_) { this_node->counter_ = 0; this_node->DropTributaryLayer(max_min_layer); } // Count the number of consumers and downstream nodes for (SbpNode* this_node : node_list_) { this_node->RaiseConsumerNum(op_name2sbp_node); } // Compute maximum layer for tributaries for (SbpNode* this_node : node_list_) { this_node->SpreadTributaryLayer(op_name2sbp_node); } // Summarize cost for each layer on the trunk, store it to avoid subtraction of large values. trunk_cost.assign(max_min_layer + 1, 0); // tributary cost start from each min layer std::vector tributary_cost(max_min_layer + 1, 0); // tributary cost would be outdated after Max Layer (before Max Layer + 1) std::vector outdated_tributary_cost(max_min_layer + 1, 0); // number of operators in the trunk std::vector> trunk_ops(max_min_layer + 1); for (SbpNode* this_node : node_list_) { if (this_node->on_trunk_) { trunk_cost[this_node->min_layer_] += this_node->GetMinCost(); trunk_ops[this_node->min_layer_].emplace_back(this_node); } else { double curr_min_cost = this_node->GetMinCost(); tributary_cost[this_node->min_layer_] += curr_min_cost; outdated_tributary_cost[this_node->tributary_layer_] += curr_min_cost; } } // Accumulate the cost from the consumer to the end, not including itself std::vector acc_trunk_cost(max_min_layer + 1, 0); for (int32_t layer_id = max_min_layer; layer_id > 0; layer_id--) { acc_trunk_cost[layer_id - 1] = acc_trunk_cost[layer_id] + trunk_cost[layer_id]; } // Clear counter for each sbp node for (SbpNode* this_node : node_list_) { this_node->counter_ = 0; } // Count the number of consumers and downstream nodes for (SbpNode* this_node : node_list_) { this_node->RaiseConsumerNum(op_name2sbp_node); } // Reduce the wait time for tributaries for (SbpNode* this_node : node_list_) { this_node->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time_); } // Reduce the wait time for trunk from the end to the begin double acc_tributary_cost = outdated_tributary_cost[max_min_layer]; double used_tributary_cost = 0.0; double curr_wait_time = 0.0; for (int32_t layer_id = max_min_layer - 1; layer_id >= 0; layer_id--) { // Can not move it backward since we need to do this at the 0th layer. // At some moment, the cost haven't been used would disappear. if (tributary_cost[layer_id + 1] > used_tributary_cost) { acc_tributary_cost -= tributary_cost[layer_id + 1] - used_tributary_cost; used_tributary_cost = 0.0; if (acc_tributary_cost < 0.0) { // should not happen besides floating point error std::cout << "Caution! 
Current accumulated tributary cost is: " << acc_tributary_cost << std::endl; acc_tributary_cost = 0.0; } } else { used_tributary_cost -= tributary_cost[layer_id + 1]; } // accumulate tributary cost at this layer acc_tributary_cost += outdated_tributary_cost[layer_id]; // If we have more cost in tributaries, we reduce the wait time // This code maintains ( acc_tributary_cost + used_tributary_cost ) if (acc_tributary_cost > 0.0) { if (acc_tributary_cost > wait_time_) { curr_wait_time = 0.0; acc_tributary_cost -= wait_time_; used_tributary_cost += wait_time_; } else { curr_wait_time = wait_time_ - acc_tributary_cost; used_tributary_cost += acc_tributary_cost; acc_tributary_cost = 0.0; } // Reduce the wait time in the trunk for (SbpNode* this_node : trunk_ops[layer_id]) { this_node->SetTrunkWaitTime(curr_wait_time); } } } } // Set wait time void SbpGraph::SetWaitTime(double wait_time) { wait_time_ = wait_time; } } // namespace auto_parallel } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/sbp_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_ #define ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_ #include #include #include "oneflow/core/auto_parallel/binary_set.h" #include "oneflow/core/auto_parallel/sbp_node.h" #include "oneflow/core/auto_parallel/sbp_edge.h" #include "oneflow/core/auto_parallel/algorithm_util.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace auto_parallel { // A graph structure to deal with the SBP strategy. // It contains a lot of eliminations to shrink the topography structure of the original graph. // Furthermore, it contains some adjustment tricks for search a good strategy in the shrunk graph. class SbpGraph final { public: // Constructor SbpGraph() = default; // Deconstructor ~SbpGraph(); OF_DISALLOW_COPY_AND_MOVE(SbpGraph); bool operator==(const SbpGraph& other) { return this == &other; } // Randomly assign a SbpSignature strategy void RandomSbpSignature(bool use_sbp_collector) const; // assign 0 to a SbpSignature strategy to avoid randomness void SetDefaultSbpSig() const; void StoreOriginMemory(); // Compute Cost for current strategy double ComputeCost() const; double ComputeWeightedCost() const; // Re-compute weighted cost void ReComputeWeightedCost(); // Generate a node SbpNode* GenerateNode(); // Merge all parallel edges & Check and eliminate all nodes with only one // degree-in and one degree-out int32_t NodeAndEdgeEliminations(); // Finalize Sbp Cost for the whole graph void FinalizeSbp() const; // Use Greedy Strategy to decide Sbp for Nodes in node_list_. Should be used // after we have a initial strategy. // Set for_node to be true will only use GreedyStrategy on Nodes. double GreedyStrategy(bool for_node) const; // Use greedy strategy on the one ring neighborhood with the maximum number of points nbh_num. 
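// NOTE (annotation, not part of the original source): a hypothetical call sequence, inferred from
// the comments in this header rather than taken from a caller in this repository: build the graph,
// run NodeAndEdgeEliminations(), pick an initial assignment (RandomSbpSignature() or
// SetDefaultSbpSig()), let GreedyStrategy(nbh_num) repeatedly re-optimize small neighborhoods via
// NbhGreedyStrategy until no further cost reduction is found, and finally call FinalizeSbp() to
// push the chosen signatures back to the original nodes. For example:
//   graph.SetDefaultSbpSig();
//   graph.GreedyStrategy(/*nbh_num=*/4);
//   graph.FinalizeSbp();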
double GreedyStrategy(int32_t nbh_num = 4) const; // Find one strategy with finite cost for adjustment Maybe Find1Strategy4Greedy() const; // Use brute force to search for a strategy with minimum cost for a neighborhood double NbhGreedyStrategy(std::vector& nbh_id2node_list_id) const; // Set threshold_ for SbpNode Merging void SetThreshold(int32_t threshold) { threshold_ = threshold; } // Clip an edge, remove it from graph // Clipping an edge will also delete the nodes and edges contained in this edge. Though not // suffering from any compiling and runtime bugs, clipping an edge on a shrunk graph is not // recommended. We should carefully think about it before any clipping. void ClipEdge(SbpEdge* this_edge) const; // Compute the minimum and maximum layer of each node in the graph int32_t ComputeLayer( HashMap& op_name2sbp_node, const HashMap>& op_node2mutable_op_ctrl_deps) const; // Find the trunk of the sbp graph, then reduce the wait time for tributaries void FindTrunk(int32_t max_min_layer, HashMap& op_name2sbp_node) const; // Set wait time void SetWaitTime(double wait_time); // Constant getter std::vector& GetNodeList() { return node_list_; } int64_t GetMemory() const; private: friend class SbpCollector; friend class SbpConstructor; // All the nodes std::vector node_list_; // Limitation: Merged node should not have a number of Sbp Signature greater // than threshold. int32_t threshold_ = 100; // Wait time for copy cost, which occurs before communication between devices. double wait_time_ = 16500.0; // Remove a node from the node list void RemoveFromNodeList(SbpNode* this_node); // Check and eliminate one node with only one degree-in and one degree-out int32_t NodeElimination(SbpNode* this_node); // Merge all parallel edges with given start_node_ and end_node_ int32_t EdgeElimination(SbpNode* this_node) const; // Check and eliminate one child node int32_t ChildElimination(SbpNode* this_node); // Merge two nodes int32_t NodeMerging(SbpNode* first, SbpNode* second); // Select two nodes and merge them int32_t PickAndMerge(); void DfsAddNbhCost(std::vector& nbh_id2node_list_id, std::unordered_map& node_list_id2nbh_id, std::vector& order2nbh_id, std::vector& nbh_id2order, std::vector& order2acc_min_in_nbh_cost, std::vector>& out_nbh_costs, std::vector>& nbh_id2order2sbp_id, std::vector& min_sbp_sig_id, double& min_cost, int32_t order, double curr_cost) const; bool DfsFindReasonableCost(std::vector& nbh_id2node_list_id, std::unordered_map& node_list_id2nbh_id, std::vector& nbh_id2order, int32_t nbh_id) const; }; } // namespace auto_parallel } // namespace oneflow #endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_ ================================================ FILE: oneflow/core/auto_parallel/sbp_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include #include #include "oneflow/core/auto_parallel/binary_set.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/auto_parallel/algorithm_util.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/auto_parallel/sbp_node.h" #include "oneflow/core/auto_parallel/sbp_edge.h" #include "oneflow/core/auto_parallel/sbp_graph.h" #include "oneflow/core/register/logical_blob_id.pb.h" namespace oneflow { namespace auto_parallel { // In dynamic programming, we can not minimize a vector (copy cost, memory cost) // Instead, we minimize the weighted sum of the vector, copy cost + kMemoryRatio * memory cost extern double kMemoryRatio; // function in cpp. Should be put in one file due to use of template // Otherwise we will need to declare specific template at the end of cpp file. SbpNode::SbpNode(SbpNode* first, SbpNode* second) { half_node_.resize(2); half_node_[0] = first; half_node_[1] = second; // Get the edge between first and second // NOTE: It must zero or one edge between them SbpEdge* common_edge = nullptr; for (int32_t k = 0; k < first->edges_in_.size(); k++) { if (first->edges_in_[k]->start_node_ == second) { // CHECK_ISNULL(edge); common_edge = first->edges_in_[k]; } } for (int32_t k = 0; k < first->edges_out_.size(); k++) { if (first->edges_out_[k]->end_node_ == second) { common_edge = first->edges_out_[k]; } } // Find all available merged-SbpSignature(edge's cost less than threshold). if (common_edge) { in_memory_support_ = first->in_memory_support_ || second->in_memory_support_ || common_edge->in_memory_support_; // If there is no one case can choose, we will blow up for (int32_t i = 0; i < first->weighted_cost_.size(); i++) { for (int32_t j = 0; j < second->weighted_cost_.size(); j++) { const double edge_weighted_cost = common_edge->start_node_ == first ? common_edge->weighted_cost_[i][j] : common_edge->weighted_cost_[j][i]; if (edge_weighted_cost < GetValidMaxCopyCost()) { merged_sig_id2half_sig_id_.emplace_back(std::make_pair(i, j)); if (in_memory_support_) { memory_.push_back((common_edge->start_node_ == first ? common_edge->GetMemory(i, j) : common_edge->GetMemory(j, i)) + first->GetMemory(i) + second->GetMemory(j)); } weighted_cost_.emplace_back(edge_weighted_cost + first->weighted_cost_[i] + second->weighted_cost_[j]); } } } CHECK(merged_sig_id2half_sig_id_.size() > 0) << "0 size for merge two half nodes with common edge!"; } else { in_memory_support_ = first->in_memory_support_ || second->in_memory_support_; for (int32_t i = 0; i < first->weighted_cost_.size(); i++) { for (int32_t j = 0; j < second->weighted_cost_.size(); j++) { merged_sig_id2half_sig_id_.emplace_back(std::make_pair(i, j)); if (in_memory_support_) { memory_.push_back(first->GetMemory(i) + second->GetMemory(j)); } weighted_cost_.emplace_back(first->weighted_cost_[i] + second->weighted_cost_[j]); } } } // Initialize default sbp choice // If the original sbp pair does not go through, then use 0 as default. 
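// NOTE (annotation, not part of the original source): merged_sig_id2half_sig_id_ enumerates the
// surviving (first, second) signature pairs in row-major order over the two halves' lists.
// Hypothetical example: if first has 2 signatures, second has 3, and the pair (0, 2) was pruned
// because its edge cost exceeded GetValidMaxCopyCost(), the merged node keeps
// (0,0) (0,1) (1,0) (1,1) (1,2), so merged id 2 maps back to the half ids (1, 0). The loop right
// below then searches for the pair matching the halves' current final_sbp_sig_id_, so the
// pre-merge strategy is preserved whenever it survived the pruning.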
final_sbp_sig_id_ = 0; // Track the original strategy for (int32_t sig_id = 0; sig_id < merged_sig_id2half_sig_id_.size(); sig_id++) { if (merged_sig_id2half_sig_id_[sig_id].first == first->final_sbp_sig_id_ && merged_sig_id2half_sig_id_[sig_id].second == second->final_sbp_sig_id_) { final_sbp_sig_id_ = sig_id; } } // Merge edges_in_ edges_in_.reserve(first->edges_in_.size() + second->edges_in_.size()); edges_in_.insert(edges_in_.end(), first->edges_in_.begin(), first->edges_in_.end()); edges_in_.insert(edges_in_.end(), second->edges_in_.begin(), second->edges_in_.end()); // Merge edges_out_ edges_out_.reserve(first->edges_out_.size() + second->edges_out_.size()); edges_out_.insert(edges_out_.end(), first->edges_out_.begin(), first->edges_out_.end()); edges_out_.insert(edges_out_.end(), second->edges_out_.begin(), second->edges_out_.end()); // Merge SbpEdge Cost for (SbpEdge*& this_edge : first->edges_in_) { this_edge->DuplicateCost(false, true, merged_sig_id2half_sig_id_); this_edge->end_node_ = this; } for (SbpEdge*& this_edge : first->edges_out_) { this_edge->DuplicateCost(true, true, merged_sig_id2half_sig_id_); this_edge->start_node_ = this; } for (SbpEdge*& this_edge : second->edges_in_) { this_edge->DuplicateCost(false, false, merged_sig_id2half_sig_id_); this_edge->end_node_ = this; } for (SbpEdge*& this_edge : second->edges_out_) { this_edge->DuplicateCost(true, false, merged_sig_id2half_sig_id_); this_edge->start_node_ = this; } // Remove edges from original nodes first->edges_in_.clear(); first->edges_out_.clear(); second->edges_in_.clear(); second->edges_out_.clear(); // Move edges between two nodes to each half node for (int32_t k = edges_out_.size() - 1; k >= 0; k--) { if (edges_out_[k]->end_node_ == this) { // Remove this edge from edges_out_ and edges_in_ and put it inside the node CheckAndRemoveFrom(edges_in_, edges_out_[k]); first->edges_out_.emplace_back(edges_out_[k]); second->edges_in_.emplace_back(edges_out_[k]); RemoveFrom(edges_out_, k); } } } SbpNode::~SbpNode() { for (auto& edge_out : edges_out_) { delete edge_out; } for (auto& child_node : children_) { if (child_node->edges_in_.size()) { delete child_node->edges_in_[0]; } delete child_node; } for (auto& half_node : half_node_) { delete half_node; } } void SbpNode::InitializeSbp() { global_sbp_sig_size_ = sbp_sig_list_.size(); cost_.resize(sbp_sig_list_.size()); }; // Let one node point to another void SbpNode::StartPointToEnd(SbpNode* start_node, SbpNode* end_node) { // generate the edge between them SbpEdge* e = new SbpEdge(start_node, end_node); start_node->edges_out_.emplace_back(e); end_node->edges_in_.emplace_back(e); }; void SbpNode::PointFrom(SbpNode* start_node) { StartPointToEnd(start_node, this); }; void SbpNode::PointTo(SbpNode* end_node) { StartPointToEnd(this, end_node); }; void SbpNode::SummarizeCost() { if (children_.size() == child_node_sbp_sig_.size()) { return; } int32_t previous_children_size = child_node_sbp_sig_.size(); child_node_sbp_sig_.resize(children_.size()); in_memory_support_ = in_memory_support_ || std::any_of(children_.begin() + previous_children_size, children_.end(), [](SbpNode* sbp_node) { return sbp_node->in_memory_support_; }); if (in_memory_support_) { memory_.resize(weighted_cost_.size(), 0); } // Buffer int64_t min_memory_cost = 0, memory_cost = 0; double min_weighted_sum = 0.0, weighted_sum = 0.0; int32_t min_sbp_child = 0; // Only deal with new children_ for (int32_t child = previous_children_size; child < children_.size(); child++) { 
child_node_sbp_sig_[child].resize(weighted_cost_.size()); for (int32_t sbp_this = 0; sbp_this < weighted_cost_.size(); sbp_this++) { SbpNode* child_node = children_[child]; for (int32_t sbp_child = 0; sbp_child < child_node->weighted_cost_.size(); sbp_child++) { if (child_node->edges_in_.size()) { // edge in graph: father -> child memory_cost = child_node->edges_in_[0]->GetMemory(sbp_this, sbp_child) + child_node->GetMemory(sbp_child); weighted_sum = child_node->edges_in_[0]->weighted_cost_[sbp_this][sbp_child] + child_node->weighted_cost_[sbp_child]; } else { // edge in graph: child -> father memory_cost = child_node->edges_out_[0]->GetMemory(sbp_child, sbp_this) + child_node->GetMemory(sbp_child); weighted_sum = child_node->edges_out_[0]->weighted_cost_[sbp_child][sbp_this] + child_node->weighted_cost_[sbp_child]; } // update min_cost with fixed SbpSignature for this node and child node if (sbp_child == 0 || weighted_sum < min_weighted_sum) { min_memory_cost = memory_cost; min_weighted_sum = weighted_sum; min_sbp_child = sbp_child; } } child_node_sbp_sig_[child][sbp_this] = min_sbp_child; // Add the cost for child node to this node if (in_memory_support_) { memory_[sbp_this] += min_memory_cost; } weighted_cost_[sbp_this] += min_weighted_sum; } } } bool SbpNode::EliminateItselfAsChild() { if (edges_in_.size() + edges_out_.size() == 1) { if (edges_in_.size()) { // edge in graph: father -> this_node SbpNode* father = edges_in_[0]->start_node_; father->children_.emplace_back(this); CheckAndRemoveFrom(father->edges_out_, edges_in_[0]); father->SummarizeCost(); } else { // edge in graph: this_node -> father SbpNode* father = edges_out_[0]->end_node_; father->children_.emplace_back(this); CheckAndRemoveFrom(father->edges_in_, edges_out_[0]); father->SummarizeCost(); } // successfully eliminate this node return true; } // can not eliminate this node return false; } // Compute the weighted sum of the time and memory cost void SbpNode::ComputeWeightedCost() { if (half_node_.empty()) { // If this node is not generated from merging, it should have original cost // weighted_cost_ = cost_; weighted_cost_ = origin_cost_; memory_ = origin_memory_; if (in_memory_support_) { for (int32_t sbp_id = 0; sbp_id < origin_memory_.size(); sbp_id++) { weighted_cost_[sbp_id] += kMemoryRatio * origin_memory_[sbp_id]; } } } else { half_node_[0]->ComputeWeightedCost(); half_node_[1]->ComputeWeightedCost(); // The edge between two half nodes SbpEdge* edge_found = nullptr; if (!half_node_[0]->edges_in_.empty()) { edge_found = half_node_[0]->edges_in_[0]; } else if (!half_node_[0]->edges_out_.empty()) { edge_found = half_node_[0]->edges_out_[0]; } if (edge_found != nullptr) { edge_found->ComputeWeightedCost(); } // Compute the weighted cost form half nodes for (int32_t merged_sig_id = 0; merged_sig_id < merged_sig_id2half_sig_id_.size(); merged_sig_id++) { const auto& pair = merged_sig_id2half_sig_id_[merged_sig_id]; if (in_memory_support_) { memory_[merged_sig_id] = half_node_[0]->GetMemory(pair.first) + half_node_[1]->GetMemory(pair.second); } weighted_cost_[merged_sig_id] = half_node_[0]->weighted_cost_[pair.first] + half_node_[1]->weighted_cost_[pair.second]; if (edge_found != nullptr) { // The dimension of weighted cost has been expand for the found edge. // Both the dimension of weighted_cost_ is merged_sig_id2half_sig_id_.size(). // The start node and end node is changed to this for the found edge. 
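// NOTE (annotation, not part of the original source): restating the three comments above: during
// merging, DuplicateCost() re-indexed the internal edge's cost matrix by merged signature ids on
// both axes and re-pointed its start/end at the merged node, so the diagonal entry
// weighted_cost_[merged_sig_id][merged_sig_id] read below already corresponds to the single
// (first, second) pair encoded by merged_sig_id; no translation through
// merged_sig_id2half_sig_id_ is needed here.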
if (in_memory_support_) { memory_[merged_sig_id] += edge_found->GetMemory(merged_sig_id, merged_sig_id); } weighted_cost_[merged_sig_id] += edge_found->weighted_cost_[merged_sig_id][merged_sig_id]; } } } // Compute the weighted cost for children for (auto& child_node : children_) { child_node->ComputeWeightedCost(); for (auto& in_edge : child_node->edges_in_) { in_edge->ComputeWeightedCost(); } for (auto* out_edge : child_node->edges_out_) { out_edge->ComputeWeightedCost(); } } // Compute the weighted cost from children child_node_sbp_sig_.clear(); SummarizeCost(); } // Generate the relationship between this merged node and its components void SbpNode::GenerateComponentRelationship() { // Do nothing if not merged node or already generated if (half_node_.empty() || !component2merged_sig_id2component_sig_id_.empty()) { return; } // Add the map for two half nodes auto& first_merged2component_id = component2merged_sig_id2component_sig_id_[half_node_[0]]; auto& second_merged2component_id = component2merged_sig_id2component_sig_id_[half_node_[1]]; int32_t total_sbp_num = weighted_cost_.size(); first_merged2component_id.resize(total_sbp_num); second_merged2component_id.resize(total_sbp_num); for (int32_t i = 0; i < total_sbp_num; i++) { first_merged2component_id[i] = merged_sig_id2half_sig_id_[i].first; second_merged2component_id[i] = merged_sig_id2half_sig_id_[i].second; } // Add the map for the half of the half nodes for (int32_t i = 0; i < 2; i++) { half_node_[i]->GenerateComponentRelationship(); auto& merged2half_id = component2merged_sig_id2component_sig_id_[half_node_[i]]; for (auto& pair : half_node_[i]->component2merged_sig_id2component_sig_id_) { auto& merged2component_id = component2merged_sig_id2component_sig_id_[pair.first]; merged2component_id.resize(total_sbp_num); auto& half2component_id = pair.second; for (int32_t merged_id = 0; merged_id < total_sbp_num; merged_id++) { merged2component_id[merged_id] = half2component_id[merged2half_id[merged_id]]; } } } } void SbpNode::FinalizeSbp() { if (!half_node_.empty()) { // Finalize Sbp of merged nodes half_node_[0]->final_sbp_sig_id_ = merged_sig_id2half_sig_id_[final_sbp_sig_id_].first; half_node_[1]->final_sbp_sig_id_ = merged_sig_id2half_sig_id_[final_sbp_sig_id_].second; } // Finalize Sbp of children_ for (int32_t i = 0; i < children_.size(); i++) { children_[i]->final_sbp_sig_id_ = child_node_sbp_sig_[i][this->final_sbp_sig_id_]; } // Finalize Sbp of half_node_ Attachment if (!half_node_.empty()) { half_node_[0]->FinalizeSbp(); half_node_[1]->FinalizeSbp(); } // Finalize Sbp of edges in edges_out_ for (const auto& edge_out : edges_out_) { edge_out->FinalizeSbp(); } // Finalize Sbp again in case of the node on the other side is not finalized // yet. This may happen when Two side of an edge merged into two larger nodes // and this edge is just a sub edge. 
for (const auto& edge_in : edges_in_) { edge_in->FinalizeSbp(); } // Finalize Sbp of children_ Attachment for (int32_t i = 0; i < children_.size(); i++) { children_[i]->FinalizeSbp(); for (const auto& edge_in : children_[i]->edges_in_) { edge_in->FinalizeSbp(); } } } double SbpNode::GreedyStrategy() { // Current Cost, Minimum Cost, Cost with original sbp double curr_cost = 0; double original_cost = EvalNbhCost(); double min_cost = original_cost; int32_t min_sbp = final_sbp_sig_id_; for (int32_t sbp = 0; sbp < weighted_cost_.size(); sbp++) { final_sbp_sig_id_ = sbp; curr_cost = EvalNbhCost(); if (curr_cost < min_cost) { min_cost = curr_cost; min_sbp = sbp; } } final_sbp_sig_id_ = min_sbp; return min_cost - original_cost; } double SbpNode::EvalNbhCost() const { // Current Cost, Minimum Cost, Cost with original sbp double curr_cost = GetWeightedCost(); for (SbpEdge* this_edge : edges_in_) { curr_cost += this_edge->GetWeightedCost(); } for (SbpEdge* this_edge : edges_out_) { curr_cost += this_edge->GetWeightedCost(); } return curr_cost; } double SbpNode::EvalOutNbhCost( const std::unordered_map& node_list_id2nbh_id) const { // check if this node is in the node list CHECK(node_list_id_ >= 0) << "Compute out cost for a node out of the node list" << std::endl; // Cost with original sbp double curr_cost = GetWeightedCost(); for (SbpEdge* this_edge : edges_in_) { // if the start node is not in the neighborhood if (node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_) == node_list_id2nbh_id.end()) { curr_cost += this_edge->GetWeightedCost(); } } for (SbpEdge* this_edge : edges_out_) { // if the end node is not in the neighborhood if (node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_) == node_list_id2nbh_id.end()) { curr_cost += this_edge->GetWeightedCost(); } } return curr_cost; } // Compute the cost between this node and adjacent nodes with a lower order double SbpNode::EvalInNbhCost(const std::unordered_map& node_list_id2nbh_id, const std::vector& nbh_id2order) const { // check if this node is in the node list CHECK(node_list_id_ >= 0) << "Compute in cost for a node out of the node list"; // check if the node is in the neighborhood const auto& this_it = node_list_id2nbh_id.find(node_list_id_); CHECK(this_it != node_list_id2nbh_id.end()) << "Compute in cost for a node out of the neighborhood"; // Compute the minimum cost between this node and adjacent nodes with a lower order int32_t order = nbh_id2order[this_it->second]; double curr_cost = 0; for (SbpEdge* this_edge : edges_in_) { const auto& it = node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_); // if the start node is in the neighborhood if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] < order) { curr_cost += this_edge->GetWeightedCost(); // End this function and return infinity. 
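// NOTE (annotation, not part of the original source): the early return below reports an
// effectively infinite cost (GetMaxVal(), presumably the largest representable value of the
// return type) as soon as the partial sum exceeds GetValidMaxCopyCost(). Callers such as
// DfsFindReasonableCost compare against GetValidMaxCopyCost() and DfsAddNbhCost compares against
// the current minimum, so either one discards this combination of signatures immediately instead
// of finishing the summation.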
if (curr_cost > GetValidMaxCopyCost()) { return GetMaxVal(); } } } for (SbpEdge* this_edge : edges_out_) { const auto& it = node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_); // if the end node is in the neighborhood if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] < order) { curr_cost += this_edge->GetWeightedCost(); if (curr_cost > GetValidMaxCopyCost()) { return GetMaxVal(); } } } return curr_cost; } double SbpNode::EvalMinInNbhCost(const std::unordered_map& node_list_id2nbh_id, const std::vector& nbh_id2order) const { // check if this node is in the node list CHECK(node_list_id_ >= 0) << "Compute out cost for a node out of the node list" << std::endl; // check if the node is in the neighborhood const auto& this_it = node_list_id2nbh_id.find(node_list_id_); CHECK(this_it != node_list_id2nbh_id.end()) << "Compute out cost for a node out of the neighborhood" << std::endl; // Compute the minimum cost between this node and adjacent nodes with a higher order int32_t order = nbh_id2order[this_it->second]; double curr_cost = 0; for (SbpEdge* this_edge : edges_in_) { const auto& it = node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_); // if the start node is in the neighborhood if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] > order) { curr_cost += this_edge->GetMinWeightedCost(); } } for (SbpEdge* this_edge : edges_out_) { const auto& it = node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_); // if the end node is in the neighborhood if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] > order) { curr_cost += this_edge->GetMinWeightedCost(); } } return curr_cost; } void SbpNode::OneRingNeighborhood(std::vector& nbh_1ring) const { nbh_1ring.resize(edges_in_.size() + edges_out_.size() + 1); int32_t nbh_id = 0; nbh_1ring[nbh_id] = node_list_id_; for (SbpEdge* this_edge : edges_in_) { nbh_id++; nbh_1ring[nbh_id] = this_edge->start_node_->node_list_id_; } for (SbpEdge* this_edge : edges_out_) { nbh_id++; nbh_1ring[nbh_id] = this_edge->end_node_->node_list_id_; } } // Get the n ring neighborhood of this node // Pre-allocate buffer, which will be faster. 
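// NOTE (annotation, not part of the original source): NRingNeighborhood below is breadth-first
// expansion applied n times on top of OneRingNeighborhood. Hypothetical example on a path graph
// a - b - c - d with n = 2, starting from a: after round 1 nbh_n_ring = {a, b}; round 2 expands
// b and adds c, giving {a, b, c}; d stays out. node_tags serves as the visited marker and is
// reset to false before returning, so the same buffer can be reused by the next call.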
void SbpNode::NRingNeighborhood(int32_t n, std::vector& nbh_n_ring, std::vector& nbh_1ring, const std::vector& node_list, std::vector& node_tags) const { // Initialize 0 ring if (n <= 0) { n = 0; } nbh_n_ring.resize(1); nbh_n_ring[0] = node_list_id_; node_tags[node_list_id_] = true; int32_t l = 0; // do ring expansion for n times for (int32_t i = 0; i < n; i++) { for (int32_t r = nbh_n_ring.size(); l < r; l++) { node_list[nbh_n_ring[l]]->OneRingNeighborhood(nbh_1ring); for (auto nbh_id : nbh_1ring) { if (!node_tags[nbh_id]) { nbh_n_ring.push_back(nbh_id); node_tags[nbh_id] = true; } } } } // Recover false for buffer for (auto nbh_id : nbh_n_ring) { node_tags[nbh_id] = false; } } // Get or compute the minimum layer of this node int32_t SbpNode::GetMinLayer( const HashMap& op_name2sbp_node, const HashMap>& op_node2mutable_op_ctrl_deps) { if (min_layer_ >= 0) { return min_layer_; } if (!op_node_) { return min_layer_; } for (SbpEdge* this_edge : edges_in_) { int32_t producer_min_layer = this_edge->start_node_->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps); if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; } } for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) { const auto& it = op_name2sbp_node.find(ctrl_in_op_name); if (it != op_name2sbp_node.end()) { int32_t producer_min_layer = it->second->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps); if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; } } } if (op_node2mutable_op_ctrl_deps.find(op_node_) != op_node2mutable_op_ctrl_deps.end()) { for (const auto& ctrl_in_op_name : op_node2mutable_op_ctrl_deps.at(op_node_)) { const auto& it = op_name2sbp_node.find(ctrl_in_op_name); if (it != op_name2sbp_node.end()) { int32_t producer_min_layer = it->second->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps); if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; } } } } return ++min_layer_; } // Spread the minimum layer to compute the maximum layer of producers void SbpNode::SpreadMaxLayer( const HashMap& op_name2sbp_node, const HashMap>& op_node2mutable_op_ctrl_deps) { if (min_layer_ <= 0) { return; } int32_t producer_max_lay = min_layer_ - 1; for (SbpEdge* this_edge : edges_in_) { this_edge->start_node_->DropMaxLayer(producer_max_lay); } for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) { const auto& it = op_name2sbp_node.find(ctrl_in_op_name); if (it != op_name2sbp_node.end()) { it->second->DropMaxLayer(producer_max_lay); } } if (op_node2mutable_op_ctrl_deps.find(op_node_) != op_node2mutable_op_ctrl_deps.end()) { for (const auto& ctrl_in_op_name : op_node2mutable_op_ctrl_deps.at(op_node_)) { const auto& it = op_name2sbp_node.find(ctrl_in_op_name); if (it != op_name2sbp_node.end()) { it->second->DropMaxLayer(producer_max_lay); } } } } // Drop down the maximum layer with the minimum layer form consumer void SbpNode::DropMaxLayer(int32_t upper_bound) { if (upper_bound < max_layer_ || max_layer_ < 0) { max_layer_ = upper_bound; } } // Set max_layer_ = min_layer_ if this node does not have any consumer // This is the end of the whole graph // We could also set it to be the maximum of the min_layer_ in the graph. (It should be the same.) 
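// NOTE (annotation, not part of the original source): a tiny worked example of the layer
// computation. For a diamond graph src -> {x, y} -> sink, GetMinLayer yields src: 0, x: 1, y: 1,
// sink: 2. SpreadMaxLayer caps each producer's max_layer_ at consumer.min_layer_ - 1, and the
// LiftMaxLayer overloads below raise nodes without consumers, so max_layer_ also ends up as
// src: 0, x: 1, y: 1, sink: 2. If y had no consumer at all, LiftMaxLayer(max_min_layer) would
// lift its max_layer_ to 2, i.e. y gains one layer of slack over its min_layer_ of 1.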
void SbpNode::LiftMaxLayer() { if (max_layer_ < min_layer_) { max_layer_ = min_layer_; } } // Set max_layer_ = upper_bound if this node does not have any consumer void SbpNode::LiftMaxLayer(int32_t upper_bound) { if (max_layer_ < min_layer_) { max_layer_ = upper_bound; } } // Get the minimum element in Cost double SbpNode::GetMinCost() const { // Check the size of Cost // Can not use weighted cost here since this function is used for find trunk. // We have not initialize weighted cost at this moment CHECK(cost_.size() > 0) << "Cost not initialized!" << std::endl; // Compute the min_comp_cost return *std::min_element(cost_.begin(), cost_.end()); } // Set the cut ratio double SbpNode::GetCutRatio() const { double curr_cut_ratio = 1.0; for (auto* this_edge : edges_in_) { curr_cut_ratio *= this_edge->GetCutRatio(); } for (auto* this_edge : edges_out_) { curr_cut_ratio *= this_edge->GetCutRatio(); } return curr_cut_ratio; } // Judge if this node is on the trunk // If so, judge it for its producer/upstream nodes void SbpNode::SpreadTrunk(const HashMap& op_name2sbp_node) { // Skip it if this node is already judged. if (on_trunk_) { return; } // Skip sbp proxy. This is before we have proxy. if (min_layer_ < 0) { return; } on_trunk_ = true; // If I am in the trunk, then all the children with (min_layer_ >= my layer id - 1) would be // considered as in the trunk for (SbpEdge* this_edge : edges_in_) { if (this_edge->start_node_->min_layer_ >= min_layer_ - 1) { this_edge->start_node_->SpreadTrunk(op_name2sbp_node); } } for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) { const auto& it = op_name2sbp_node.find(ctrl_in_op_name); if (it != op_name2sbp_node.end() && it->second->min_layer_ >= min_layer_ - 1) { it->second->SpreadTrunk(op_name2sbp_node); } } } // Count consumers and any downstream nodes defined by control edges void SbpNode::RaiseConsumerNum(const HashMap& op_name2sbp_node) { // Should clear it before running. // skip the proxy nodes and the sources if (min_layer_ <= 0) { return; } for (SbpEdge* this_edge : edges_in_) { this_edge->start_node_->counter_++; } for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) { const auto& it = op_name2sbp_node.find(ctrl_in_op_name); if (it != op_name2sbp_node.end()) { it->second->counter_++; } } } // Compute the minimal available wait time for producers or upstream nodes void SbpNode::SpreadAvailWaitTime(const std::vector& trunk_cost, const std::vector& acc_trunk_cost, const HashMap& op_name2sbp_node, double wait_time) { // skip the proxy nodes and the sources if (min_layer_ <= 0) { return; } // Have not finished spreading for consumers or downstream nodes or already visited. if (counter_) { return; } if (on_trunk_) { // Nodes on the trunk does not have any accumulate cost acc_trunk_cost_ = 0; } else { if (acc_trunk_cost_ < 0) { // Do not have any consumer or downstream node acc_trunk_cost_ = acc_trunk_cost[min_layer_ - 1]; } else { // Add the trunk cost at this layer acc_trunk_cost_ += trunk_cost[min_layer_]; } } // Reduce the wait time for edges_in_, put the rest of the trunk cost in the producers for (SbpEdge* this_edge : edges_in_) { CHECK(this_edge->wait_time_ < 0) << "Double assign values into wait_time_ of this edge!" 
<< std::endl; SbpNode* producer = this_edge->start_node_; // Accumulate the cost from the start node to this node double curr_trunk_cost = acc_trunk_cost_ + acc_trunk_cost[producer->min_layer_] - acc_trunk_cost[min_layer_ - 1]; if (curr_trunk_cost >= wait_time) { // Remain cost in the trunk is able to cover all the wait time this_edge->wait_time_ = 0.0; curr_trunk_cost -= wait_time; } else { // Remain cost in the trunk can only cover partial wait time this_edge->wait_time_ = wait_time - curr_trunk_cost; curr_trunk_cost = 0.0; } // Reducing non-matching edges // For example: // (1) P->S0->S0->S0->B // (2) p->B->B->B->B // We would use (2) when the tensor is relatively tiny. // Do not inherit trunk cost for nodes on the trunk if (!producer->on_trunk_) { // Inherit the minimal of the trunk cost from consumers producer->DropAvailWaitTime(curr_trunk_cost); } producer->counter_--; producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time); } // Put the rest the trunk cost in the upstream nodes. for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) { const auto& it = op_name2sbp_node.find(ctrl_in_op_name); if (it != op_name2sbp_node.end()) { SbpNode* producer = it->second; // Do not inherit trunk cost for nodes on the trunk if (!producer->on_trunk_) { // Accumulate the cost from the start node to this node double curr_trunk_cost = acc_trunk_cost_ + acc_trunk_cost[producer->min_layer_] - acc_trunk_cost[min_layer_ - 1]; // Inherit the minimal of the trunk cost from consumers producer->DropAvailWaitTime(curr_trunk_cost); } producer->counter_--; producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time); } } // Set counter_ to be -1, do not visit it again. counter_--; } // Drop down the available wait time with the minimum cost from downstream void SbpNode::DropAvailWaitTime(double curr_trunk_cost) { if (acc_trunk_cost_ < 0.0 || acc_trunk_cost_ > curr_trunk_cost) { acc_trunk_cost_ = curr_trunk_cost; } } // Assemble copy cost and partial memory cost for all the incoming edges void SbpNode::InitCopyAndMemoryCost(bool use_sbp_collector, bool nccl_not_use_compute_stream) { for (SbpEdge* this_edge : edges_in_) { const auto* sbp_node_producer = this_edge->start_node_; OpNode* producer = sbp_node_producer->op_node_; // skip it if proxy if (use_sbp_collector && !producer) { continue; } // look through input blobs for (const std::string& ibn : op_node_->op().input_bns()) { if (producer->op().op_name() == op_node_->SrcNode4Ibn(ibn).op().op_name()) { this_edge->InitCopyAndMemoryCost(ibn, use_sbp_collector, nccl_not_use_compute_stream); } } // Add Wait time for (auto& cost_row : this_edge->cost_) { for (auto& cost_value : cost_row) { // If transferring between devices, we need to add wait time. 
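// NOTE (annotation, not part of the original source): this_edge->wait_time_ models the fixed
// latency paid before a cross-device transfer starts; it is seeded from SbpGraph::wait_time_
// (16500.0 by default in sbp_graph.h) and may have been reduced by the trunk analysis in
// FindTrunk / SpreadAvailWaitTime. Adding it only when cost_value > 0.0 keeps signatures that
// need no boxing at a copy cost of exactly 0, so "no transfer" stays strictly cheaper than even
// a tiny transfer.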
if (cost_value > 0.0) { cost_value += this_edge->wait_time_; } } } } } // Assemble memory cost void SbpNode::InitializeMemory(bool is_reusable, const HashMap& lbi2id, const std::vector& id2count, bool nccl_use_compute_stream) { const auto& curr_operator = op_node_->op(); // An edge should not be initialized twice // During each initialization, we are computing sum(memory of consumer) - sum(memory of producer) // This is why we need to pre-store memory of producer HashMap> sbp_edge2nd_sbp_sig2memory; for (const auto& obn : curr_operator.output_bns()) { const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn); // Fixed memory or in the support of the reusable memory if (!is_reusable || id2count.at(lbi2id.at(lbi)) > 0) { // If not in support, memory_ would be empty. in_memory_support_ = true; memory_.resize(sbp_sig_list_.size(), 0); const auto& logical_blob_desc = op_node_->LogicalBlobDesc4Lbi(lbi); const auto& hierarchy = *CHECK_JUST(curr_operator.GetParallelDesc4BnInOp(obn))->hierarchy(); // There are some operators with a fixed sbp for some blobs, such as conv. // {in: S0, kernel: B, out: S0} // {in: B, kernel: B, out: B} // The blob kernel have the same sbp for different signatures. // We pre-store the results for the same sbp while accessing the same blobs. HashMap nd_sbp2memory; SbpEdge* edge_contain_lbi = nullptr; for (const auto& edge_out : edges_out_) { if (edge_out->SearchLbi(lbi)) { edge_contain_lbi = edge_out; } } // There exist some lbi which does not have a consumer // At this moment edge_contain_lbi == nullptr auto& nd_sbp_sig2memory = sbp_edge2nd_sbp_sig2memory[edge_contain_lbi]; nd_sbp_sig2memory.resize(sbp_sig_list_.size(), 0); for (int32_t sbp_sig_id = 0; sbp_sig_id < sbp_sig_list_.size(); sbp_sig_id++) { const NdSbp& nd_sbp = sbp_sig_list_[sbp_sig_id].bn_in_op2nd_sbp().at(obn); auto it = nd_sbp2memory.find(nd_sbp); if (it == nd_sbp2memory.end()) { // This compute the memory at rank 0, the largest one. // We could be faster if we just compute the average memory. it = nd_sbp2memory .insert({nd_sbp, MaxByteSize4BlobDescSbp(logical_blob_desc, nd_sbp, hierarchy)}) .first; } memory_[sbp_sig_id] += it->second; nd_sbp_sig2memory[sbp_sig_id] += it->second; } } } // Even after the correction in the memory of edges, the relative error still have 0.73%. if (nccl_use_compute_stream && in_memory_support_ && is_reusable) { for (const auto& pair : sbp_edge2nd_sbp_sig2memory) { // Init memory for each out-going edge pair.first->InitializeMemory(lbi2id, id2count, pair.second); } } } // Reduce and set the wait time for op in the trunk void SbpNode::SetTrunkWaitTime(double trunk_wait_time) { // only reduce the wait time for operators in the trunk if (on_trunk_) { // Reduce the wait time for edges_out_ for (SbpEdge* edge_out : edges_out_) { if (edge_out->wait_time_ < 0.0 || edge_out->wait_time_ > trunk_wait_time) { edge_out->wait_time_ = trunk_wait_time; } } // Might reduce it for edges_in_ } } // Drop down the maximum layer with the minimum layer form consumer void SbpNode::DropTributaryLayer(int32_t upper_bound) { if (upper_bound < tributary_layer_ || tributary_layer_ < 0) { tributary_layer_ = upper_bound; } } // Compute maximum layer for tributaries void SbpNode::SpreadTributaryLayer(const HashMap& op_name2sbp_node) { if (counter_ || min_layer_ <= 0) { return; } int32_t producer_max_lay = 0; if (on_trunk_) { producer_max_lay = min_layer_ - 1; } else { // On a tributary, the operator could be run later. 
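// NOTE (annotation, not part of the original source): counter_ works as a reverse topological
// gate in the loop below: RaiseConsumerNum() pre-loads every producer with its number of
// consumers (including control-edge consumers), each consumer decrements it once it has decided
// its own tributary layer, and a producer only recurses once its counter_ reaches zero, i.e.
// once all of its consumers have reported in.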
producer_max_lay = tributary_layer_; // producer_max_lay = tributary_layer_ - 1; } for (SbpEdge* this_edge : edges_in_) { this_edge->start_node_->DropTributaryLayer(producer_max_lay); if (--this_edge->start_node_->counter_ == 0) { this_edge->start_node_->SpreadTributaryLayer(op_name2sbp_node); } } for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) { const auto& it = op_name2sbp_node.find(ctrl_in_op_name); if (it != op_name2sbp_node.end()) { it->second->DropTributaryLayer(producer_max_lay); if (--it->second->counter_ == 0) { it->second->SpreadTributaryLayer(op_name2sbp_node); } } } counter_--; } SbpEdge* SbpNode::FindEdgeWithNode(const SbpNode* other_node) const { for (auto* sbp_edge : edges_in_) { if (sbp_edge->start_node_ == other_node) { return sbp_edge; } } for (auto* sbp_edge : edges_out_) { if (sbp_edge->end_node_ == other_node) { return sbp_edge; } } return nullptr; }; // Decide to use this SbpSignature const NdSbpSignature& SbpNode::FinalSbpSignature() const { CHECK(!sbp_sig_list_.empty()) << "Asking for sbp signature for an empty node"; return sbp_sig_list_[final_sbp_sig_id_]; }; int32_t SbpNode::GetComponentSbpId(int32_t merged_id, SbpNode* component_node) const { if (this == component_node) { return merged_id; } CHECK(!component2merged_sig_id2component_sig_id_.empty()) << "Check the component before initialization!" << std::endl; return component2merged_sig_id2component_sig_id_.at(component_node).at(merged_id); } // Judge if sbp_node is a port of the current node bool SbpNode::IsComponent(SbpNode* sbp_node) const { if (this == sbp_node) { return true; } // If IsComponent() is call before we initialize component2merged_sig_id2component_sig_id_, // we would also return false. // Please do not call GenerateComponentRelationship() at here. // Please see SbpEdge::SummarizeCost() for more details. return component2merged_sig_id2component_sig_id_.find(sbp_node) != component2merged_sig_id2component_sig_id_.end(); } } // namespace auto_parallel } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/sbp_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_ #define ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_ #include #include #include #include #include "oneflow/core/auto_parallel/binary_set.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/auto_parallel/algorithm_util.h" #include "oneflow/core/job/sbp_parallel.pb.h" namespace oneflow { namespace auto_parallel { class SbpEdge; // A node structure to deal with the SBP strategy. // Please see SbpGraph for the whole algorithm and introduction. 
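// NOTE (annotation, not part of the original source): over the lifetime of the search an SbpNode
// can play three roles: an ordinary operator node; a "child" folded into its only neighbor by
// EliminateItselfAsChild() / SummarizeCost(); or a compound node built by the
// SbpNode(first, second) constructor, whose signatures are surviving pairs of the halves'
// signatures. FinalizeSbp() later unwinds children and half nodes so every original operator
// ends up with a concrete signature.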
class SbpNode final { public: // default constructor SbpNode() : final_sbp_sig_id_(0) {} // This constructor is to merge two node into one SbpNode(SbpNode* first, SbpNode* second); ~SbpNode(); OF_DISALLOW_COPY_AND_MOVE(SbpNode); bool operator==(const SbpNode& other) { return this == &other; } // another node point to this node void PointFrom(SbpNode* start_node); // this node point to another node void PointTo(SbpNode* end_node); SbpEdge* FindEdgeWithNode(const SbpNode* other_node) const; // Check and eliminate one child node. // Only used by SbpGraph since it need to remove it from the NodeList after this. bool EliminateItselfAsChild(); // Initialize SbpSignature from Signature Objects void InitializeSbp(); // Decide to use this SbpSignature const NdSbpSignature& FinalSbpSignature() const; // Recompute Computation Cost after adding child nodes in it void SummarizeCost(); // Compute the weighted sum of the time and memory cost void ComputeWeightedCost(); // Generate the relationship between this merged node and its components void GenerateComponentRelationship(); // Determine Final SbpSignature for attachment of this node void FinalizeSbp(); // Use Greedy Strategy to pick the sbp signature with minimum cost for this // node You should have an initial strategy before running this double GreedyStrategy(); // Evaluate summery of cost between neighborhood and outside nodes double EvalOutNbhCost(const std::unordered_map& node_list_id2nbh_id) const; // Evaluate summery of cost within neighborhood // We only accumulate the edge cost with a lower order. double EvalInNbhCost(const std::unordered_map& node_list_id2nbh_id, const std::vector& nbh_id2order) const; // Evaluate summery of cost within neighborhood // We only accumulate the minimum edge cost with a higher order. double EvalMinInNbhCost(const std::unordered_map& node_list_id2nbh_id, const std::vector& nbh_id2order) const; // Get the one ring neighborhood of this node, which is itself and all the adjacent nodes. void OneRingNeighborhood(std::vector& nbh_1ring) const; // Get the n ring neighborhood of this node // Pre-allocate buffer, which will be faster. 
void NRingNeighborhood(int32_t n, std::vector& nbh_n_ring, std::vector& nbh_1ring, const std::vector& node_list, std::vector& node_tags) const; // Get or compute the minimum layer of this node int32_t GetMinLayer( const HashMap& op_name2sbp_node, const HashMap>& op_node2mutable_op_ctrl_deps); // Spread the minimum layer to compute the maximum layer of producers void SpreadMaxLayer( const HashMap& op_name2sbp_node, const HashMap>& op_node2mutable_op_ctrl_deps); // Set max_layer_ = min_layer_ if this node does not have any consumer void LiftMaxLayer(); // Set max_layer_ = upper_bound if this node does not have any consumer void LiftMaxLayer(int32_t upper_bound); // Compute maximum layer for tributaries void SpreadTributaryLayer(const HashMap& op_name2sbp_node); // Drop down the tributary layer void DropTributaryLayer(int32_t upper_bound); // Get the minimum element in Cost double GetMinCost() const; // get the cut ratio double GetCutRatio() const; // Judge if this node is on the trunk // If so, judge it for its producer/upstream nodes void SpreadTrunk(const HashMap& op_name2sbp_node); // Count consumers and any downstream nodes defined by control edges // for producers or upstream nodes void RaiseConsumerNum(const HashMap& op_name2sbp_node); // Compute the minimal available wait time for producers or upstream nodes void SpreadAvailWaitTime(const std::vector& trunk_cost, const std::vector& acc_trunk_cost, const HashMap& op_name2sbp_node, double wait_time); // Reduce and set the wait time for op in the trunk void SetTrunkWaitTime(double trunk_wait_time); // Assemble copy cost and partial memory cost for all the incoming edges void InitCopyAndMemoryCost(bool use_sbp_collector, bool nccl_not_use_compute_stream); // Assemble memory cost void InitializeMemory(bool is_reusable, const HashMap& lbi2id, const std::vector& id2count, bool nccl_use_compute_stream); // Constant getter int32_t GetMinLayer() const { return min_layer_; } int32_t GetTributaryLayer() const { return tributary_layer_; } OpNode* GetOperatorNode() const { return op_node_; } const std::vector& GetEdgesIn() const { return edges_in_; } const std::vector& GetEdgesOut() const { return edges_out_; } int64_t GetMemory(int32_t i) const { return in_memory_support_ ? 
memory_[i] : 0; } // Get the current memory with the current sbp signature index int64_t GetMemory() const { return GetMemory(final_sbp_sig_id_); } double GetWeightedCost(int32_t i) const { return weighted_cost_[i]; } // Get the current weighted cost with the current sbp signature index double GetWeightedCost() const { return GetWeightedCost(final_sbp_sig_id_); } int32_t GetComponentSbpId(int32_t merged_id, SbpNode* component_node) const; // Judge if sbp_node is a port of the current node bool IsComponent(SbpNode* sbp_node) const; // Setter void SetInMemorySupport(bool in_memory_support) { in_memory_support_ = in_memory_support; } private: friend class SbpEdge; friend class SbpGraph; friend class SbpCollector; friend class SbpConstructor; // compound edge in std::vector edges_in_; // compound edge out std::vector edges_out_; // Location in node_list of SbpGraph int32_t node_list_id_ = -1; // Global SbpSignature List Size int32_t global_sbp_sig_size_ = -1; // Decide to use SbpSignature with this id int32_t final_sbp_sig_id_; // Available SbpSignature object for this node std::vector sbp_sig_list_; // Cost[sbp] is Computation Cost when using sbp_sig_list_[sbp] std::vector cost_; std::vector origin_cost_; // Child node list std::vector children_; // SbpSignature for each child node when using specific SbpSignature for this // node Its dimension is Number of Child Nodes * Number of Available // SbpSignatures for this node std::vector> child_node_sbp_sig_; // Merge two nodes into this compound node std::vector half_node_; // We should delete those merged-signatures which has very large cost for speed up // New sbp_sig_list_ index map to each half_node_'s sig_index std::vector> merged_sig_id2half_sig_id_; std::vector parallel_candidates_; OpNode* op_node_ = nullptr; // We divide the sbp graph into multiple layers. // min_layer_ is the minimum layer number to run this op as soon as possible. // max_layer_ is the maximum layer number without slowing down the whole process of the graph. // producer.max_layer_ < this_node.min_layer_ <= this_node.max_layer_ < consumer.min_layer_ int32_t min_layer_ = -1, max_layer_ = -1; // Maximum layer in tributaries int32_t tributary_layer_ = -1; // Whether we are on the trunk bool on_trunk_ = false; // A counter_ buffer for topological traversal or something else int32_t counter_ = 0; // Accumulate trunk cost from consumer to the end double acc_trunk_cost_ = -1.0; // The produced blob belongs to the support of the total memory bool in_memory_support_ = false; // The consumed memory for different sbp strategies std::vector memory_; std::vector origin_memory_; // The weighted sum of time cost and memory cost // More specifically, weighted cost = time cost + kMemoryRatio * memory; // We do not add any weight for the time cost since we need to judge if a cost is less than // GetValidMaxCopyCost(). std::vector weighted_cost_; // Relationship between a merged node and its components HashMap> component2merged_sig_id2component_sig_id_; // Let one node point to another void StartPointToEnd(SbpNode* start_node, SbpNode* end_node); // Evaluate summery of cost in 1-ring neighborhood. 
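// NOTE (annotation, not part of the original source): EvalNbhCost() below simply sums
// GetWeightedCost() over this node and all of its incident edges. The weighted cost itself, per
// the member comment above, is time cost + kMemoryRatio * memory. Hypothetical numbers: with
// kMemoryRatio = 0.01, cost_ = {10, 12} and memory_ = {500, 100}, weighted_cost_ = {15, 13},
// so the second signature wins even though its pure time cost is higher.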
double EvalNbhCost() const; // Drop down the maximum layer with the minimum layer from consumer void DropMaxLayer(int32_t upper_bound); // Drop down the available wait time with the minimum cost from downstream void DropAvailWaitTime(double curr_trunk_cost); }; // class SbpNode } // namespace auto_parallel } // namespace oneflow #endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_ ================================================ FILE: oneflow/core/auto_parallel/sbp_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/auto_parallel/sbp_util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h" namespace oneflow { namespace auto_parallel { // Judge whether we need the same SBP for both producer and consumer bool RequireSameSbp(const OpNode* consumer, const std::string& ibn) { // is mutable const auto& input_blob_modifier_ = consumer->op().InputBlobModifier4Ibn(ibn); if (input_blob_modifier_.has_is_mutable() && input_blob_modifier_.is_mutable()) { return true; } // kOFRecord or kTensorBuffer don't accept boxing const LogicalBlobId& lbi = consumer->op().BnInOp2Lbi(ibn); const OpNode& producer = consumer->ProducerOpNode4Lbi(lbi); const BlobDesc& logical_blob_desc = producer.LogicalBlobDesc4Lbi(lbi); return (logical_blob_desc.data_type() == DataType::kOFRecord || logical_blob_desc.data_type() == DataType::kTensorBuffer); } } // namespace auto_parallel } // namespace oneflow ================================================ FILE: oneflow/core/auto_parallel/sbp_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_ #define ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_ #include "oneflow/core/graph/op_graph.h" namespace oneflow { namespace auto_parallel { // Judge whether we need the same SBP for both producer and consumer bool RequireSameSbp(const OpNode* consumer, const std::string& ibn); } // namespace auto_parallel } // namespace oneflow #endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_ ================================================ FILE: oneflow/core/autograd/autograd_captured_tensor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_ #define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_ #include "oneflow/core/framework/tensor.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { class AutogradCapturedTensor final : public ProxyTensor { public: static Maybe MakeTensor(const std::shared_ptr& tensor) { if (tensor->requires_grad()) { CHECK_NOTNULL_OR_RETURN(tensor->grad_fn_node().get()) << Error::RuntimeError() << "a grad function node is expected for the captured tensor " "which requires_grad is True."; } std::shared_ptr captured_tensor( new AutogradCapturedTensor(JUST(tensor->detach()))); captured_tensor->set_autograd_meta(tensor->mut_autograd_meta()); captured_tensor->grad_fn_node_ = tensor->mut_grad_fn_node(); return captured_tensor; } std::shared_ptr grad_fn_node() const override { return grad_fn_node_.lock(); } void set_grad_fn_node(const std::shared_ptr& grad_fn_node) override { PRINT_BUG_PROMPT_AND_ABORT(); } std::shared_ptr mut_grad_fn_node() override { return grad_fn_node_.lock(); } std::shared_ptr contiguous() const override { const auto& tensor = std::const_pointer_cast(shared_from_this()); if (tensor_->is_contiguous()) { return tensor; } return CHECK_JUST(functional::ToContiguous(tensor)); } private: explicit AutogradCapturedTensor(const std::shared_ptr& tensor) : ProxyTensor(tensor) {} private: std::weak_ptr grad_fn_node_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_ ================================================ FILE: oneflow/core/autograd/autograd_engine.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include "fmt/core.h" #include "fmt/format.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/error.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_arg.h" #include "oneflow/core/framework/tensor_methods.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/global_param_grad_sync_mode.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" namespace oneflow { namespace one { namespace { void GatherFunctionNodes(FunctionNode* node, std::stack>& stack) { for (auto& prev_node : node->next_functions()) { auto prev_node_fun = std::get<0>(prev_node); if (prev_node_fun) { if (prev_node_fun.use_count() == 1) { stack.push(prev_node_fun); } } } } /* NOTE: * Stack overflows when releasing a very deep computation graph without * a custom deleter. * * For example, here is a very deep computation graph: * Tensor -> FunctionNode -> Tensor -> FunctionNode -> ... -> Tensor -> FunctionNode * When releasing the first Tensor, it will trigger the recursive deletion and stack overflow. * * So we must set a custom deleter and release them iteratively. */ void FunctionNodeDeleter(FunctionNode* node) { std::stack> stack; node->ReleaseData(); GatherFunctionNodes(node, stack); delete node; while (!stack.empty()) { auto now_node = std::move(stack.top()); stack.pop(); now_node->ReleaseData(); GatherFunctionNodes(now_node.get(), stack); } } bool IsReadyToRun(const std::vector>& out_meta_datas) { return std::any_of(out_meta_datas.begin(), out_meta_datas.end(), [](const std::shared_ptr& meta_data) { return !meta_data->current_grad()->Empty(); }); } Maybe CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) { autograd::AutoGradMode mode(autograd_mode); auto current_grad = JUST(autograd_meta->current_grad_value()); if (!current_grad) { return Maybe::Ok(); } if (autograd_meta->acc_grad()) { JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1.0, /*inplace=*/true)); } else { // NOTE: acc_grad can not share data with current_grad, because accumulate acc_grad // with inplace operation and it maybe change current_grad to get wrong result. // See more details in https://github.com/Oneflow-Inc/oneflow/issues/8248 if (!LazyMode::is_enabled()) { current_grad = JUST(functional::Identity(current_grad)); } JUST(autograd_meta->set_acc_grad(current_grad)); } for (const auto& hook : autograd_meta->post_grad_accumulation_hooks()) { auto new_grad = hook(autograd_meta->acc_grad()); if (new_grad) { JUST(autograd_meta->set_acc_grad(new_grad)); } } return Maybe::Ok(); } Maybe RawTouchGlobalTensor(const std::shared_ptr& tensor) { // Do nothing. 
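// Intentionally a no-op: the actual check is supplied by the CheckGlobalTensorMeta decorator
// applied to this function just below (see TouchGlobalTensor).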
return Maybe::Ok(); } static constexpr auto* TouchGlobalTensor = DECORATE(&RawTouchGlobalTensor, CheckGlobalTensorMeta); Maybe CheckGlobalTensorsMeta(const TensorTuple& tensor_tuple) { for (const auto& tensor : tensor_tuple) { if (tensor->is_global() && tensor->is_eager()) { JUST(TouchGlobalTensor(tensor)); } } return Maybe::Ok(); } std::string GetDebugGraphFileName(const std::string& mode, const std::string& suffix) { return fmt::format("autograd_{}_rank{}_suffix_graph.dot", mode, GlobalProcessCtx::Rank(), suffix); } } // namespace Maybe AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph) { JUST(CheckGlobalTensorsMeta(outputs)); JUST(CheckGlobalTensorsMeta(out_grads)); DisableCheckGlobalTensorMetaScope disable_meta_check; return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph); } Maybe AutogradEngine::RunBackwardAndReturnInputsTensorGradIf( const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph, bool allow_unused) { JUST(CheckGlobalTensorsMeta(outputs)); JUST(CheckGlobalTensorsMeta(inputs)); JUST(CheckGlobalTensorsMeta(out_grads)); DisableCheckGlobalTensorMetaScope disable_meta_check; return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph, create_graph, allow_unused); } Maybe FunctionNode::AccGrad4RetainGradTensor(bool create_graph) { for (const std::shared_ptr& out : output_meta_data_) { if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), create_graph)); } } return Maybe::Ok(); } Maybe FunctionNode::AccGrad4LeafTensor(bool create_graph) { for (auto i = 0; i < output_meta_data_.size(); i++) { auto& out = output_meta_data_[i]; if (out->is_leaf() && out->requires_grad()) { JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/create_graph)); // control acc_grad to do boxing conditionally const auto& acc_grad = out->acc_grad(); if (!LazyMode::is_enabled() && GlobalGradSyncMode::is_enabled() && acc_grad->is_global() && acc_grad->is_eager()) { auto& tensor_info = output_tensor_infos_[i]; const auto& placement = JUST(tensor_info.placement()); const auto& nd_sbp = JUST(tensor_info.sbp()); JUST(out->set_acc_grad( JUST(functional::ToGlobal(acc_grad, placement, *JUST(GetSbpList(nd_sbp)), GetNoneSbpList(), /* check_meta */ false, /*copy=*/false)))); } } } return Maybe::Ok(); } void FunctionNode::ReleaseOutTensorArgs() { for (const std::shared_ptr& meta_data : output_meta_data_) { meta_data->current_grad()->Release(); } } Maybe FunctionNode::Apply(bool create_graph) { CHECK_NOTNULL_OR_RETURN(backward_fn_) << "This FunctionNode with name `" << name() << "` has been released.\n" << "Maybe you try to backward through the node a second time. 
Specify retain_graph=True when " "calling .backward() or autograd.grad() the first time."; if (!IsReadyToRun(output_meta_data_)) { return false; } TensorTuple input_grads(input_meta_data_.size()); TensorTuple output_grads(output_meta_data_.size()); for (int i = 0; i < output_meta_data_.size(); ++i) { if (output_meta_data_[i]->current_grad()->Empty()) { // Only initialize out_grads for those requires_grad outputs if (output_meta_data_[i]->requires_grad()) { output_grads[i] = JUST(output_tensor_infos_[i].zeros()); } } else { JUST(oneflow::VectorAt(output_grads, i)) = JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad_value()); } } JUST(backward_fn_->body(output_grads, &input_grads, create_graph)); for (const auto& hook : hooks_) { auto new_input_grads = hook(input_grads, output_grads); if (new_input_grads.has_value()) { auto new_input_grads_value = *JUST(new_input_grads); CHECK_EQ_OR_RETURN(new_input_grads_value.size(), input_grads.size()) << "The number of input grads returned by hook is not correct, expected " << input_grads.size() << ", but got " << new_input_grads_value.size() << "."; for (int i = 0; i < input_grads.size(); ++i) { input_grads[i] = new_input_grads_value[i]; } } } for (int i = 0; i < input_meta_data_.size(); ++i) { if (JUST(VectorAt(input_grads, i))) { CHECK_NOTNULL_OR_RETURN(input_meta_data_[i]) << name_ << " calculate grad for tensor which requires_grad is False. Please submit an issue in " "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as " "possible"; JUST(input_meta_data_[i]->current_grad()->PushPartialTensor(JUST(VectorAt(input_grads, i)))); } else { CHECK_OR_RETURN(!input_meta_data_[i]) << name() << "'s input[" << i << "] need calculate grad but got nullptr. Please submit an issue in " "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as " "possible;"; } } return true; } void GraphFunctionNode::ReleaseData() { if (backward_fn_ && backward_fn_->status()) { backward_fn_.reset(); } } /*static*/ std::shared_ptr GraphFunctionNode::New( const std::string& name, const std::shared_ptr& backward_fn, const TensorTuple& inputs, const TensorTuple& outputs) { auto node = std::shared_ptr( new GraphFunctionNode(name, backward_fn, inputs, outputs), FunctionNodeDeleter); return node; } GraphFunctionNode::GraphFunctionNode(const std::string& name, const std::shared_ptr& backward_fn, const TensorTuple& inputs, const TensorTuple& outputs) : FunctionNode(name, backward_fn) { input_meta_data_.resize(inputs.size()); next_functions_.reserve(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { if (inputs.at(i)->requires_grad()) { input_meta_data_.at(i) = inputs.at(i)->mut_autograd_meta(); next_functions_.emplace_back(inputs.at(i)->mut_grad_fn_node(), 0); } } output_meta_data_.resize(outputs.size()); output_tensor_infos_.reserve(outputs.size()); for (int i = 0; i < outputs.size(); ++i) { const auto& autograd_meta = NewAutogradMeta(outputs.at(i)->requires_grad(), outputs.at(i)->is_leaf()); outputs.at(i)->set_autograd_meta(autograd_meta); output_meta_data_.at(i) = outputs.at(i)->mut_autograd_meta(); output_tensor_infos_.emplace_back(*outputs.at(i)); } backward_fn_ = backward_fn; } GraphTask::GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph) : retain_graph_(retain_graph), create_graph_(create_graph) { roots_.reserve(outputs.size()); for (const auto& out_tensor : outputs) { FunctionNode* node = out_tensor->mut_grad_fn_node().get(); roots_.emplace_back(node); } } Maybe 
GraphTask::WriteGraphToDotFile(const std::string& file_name) const { auto ExecInfoToDotString = [](const ExecInfo& exec_info) -> std::string { std::stringstream ss; ss << "ExecInfo{\\l"; ss << "\tdependencies: " << exec_info.dependencies << "\\l"; ss << "\tneed_execute: " << exec_info.need_execute << "\\l"; if (exec_info.capture_indices) { ss << "\tcapture_indices: ["; for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) { ss << out_idx_and_capture_idx.second << ", "; } ss << "]\\l"; } ss << "}\\l"; return ss.str(); }; auto log_stream = TeePersistentLogStream::Create(file_name); std::vector lines; lines.emplace_back("digraph AutogradTaskGraph {"); lines.emplace_back("\tmargin=\"1.5\";"); lines.emplace_back("\tnode [shape=box];"); for (auto iter = grad_fn2exec_info_.begin(); iter != grad_fn2exec_info_.end(); ++iter) { const FunctionNode* node = iter->first; const ExecInfo& exec_info = iter->second; // write label attribute std::string node_color = "black"; if (exec_info.dependencies == 0 && exec_info.need_execute) { // start node node_color = "red"; } else if (exec_info.need_execute && exec_info.capture_indices) { // end node node_color = "green"; } lines.emplace_back(fmt::format( "\t\"{}\" [label=\"{}\\l{}\\l{}\", color={}];", static_cast(node), node->name(), static_cast(node), ExecInfoToDotString(exec_info), node_color)); // write edge for (const auto& next_fn : node->next_functions()) { lines.emplace_back(fmt::format("\t\"{}\" -> \"{}\";", static_cast(node), static_cast(std::get<0>(next_fn).get()))); } } lines.emplace_back("}"); log_stream << fmt::format("{}", fmt::join(lines, "\n")); log_stream->Flush(); return Maybe::Ok(); } // Computes the number of dependencies for each FunctionNode Maybe GraphTask::ComputeDependencies() { HashSet seen; std::stack stack; for (FunctionNode* node : roots_) { stack.push(node); grad_fn2exec_info_[node].need_execute = true; } while (!stack.empty()) { FunctionNode* node = stack.top(); stack.pop(); if (/*bool has_seen=*/!seen.insert(node).second) { continue; } for (const auto& next_grad_fn : node->next_functions()) { FunctionNode* next_node = std::get<0>(next_grad_fn).get(); ExecInfo& exec_info = grad_fn2exec_info_[next_node]; exec_info.dependencies += 1; exec_info.need_execute = true; if (seen.find(next_node) == seen.end()) { stack.push(next_node); } } } return Maybe::Ok(); } // Computes the number of dependencies for each FunctionNode and prunes useless FunctionNode // according to input tensors Maybe GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs, bool allow_unused) { struct NodeFrame { explicit NodeFrame(FunctionNode* node) : node_(node), next_function_idx_(0) {} FunctionNode* node_; size_t next_function_idx_; FunctionNode* GetNextFunction() { if (next_function_idx_ < node_->next_functions().size()) { next_function_idx_ += 1; return std::get<0>(node_->next_functions().at(next_function_idx_ - 1)).get(); } else { return nullptr; } } }; // initialize all variable to capture grad for input tensors captured_grads_ = std::make_shared(inputs.size()); for (int idx = 0; idx < inputs.size(); idx++) { const auto& input = inputs[idx]; if (allow_unused && !input->mut_grad_fn_node().get()) { continue; } CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get()) << Error::RuntimeError() << "One of the differentiated Tensors appears to not have been used in the graph. 
Set " "allow_unused=True if this is the desired behavior."; ExecInfo& exec_info = grad_fn2exec_info_[input->mut_grad_fn_node().get()]; exec_info.need_execute = true; if (!exec_info.capture_indices) { exec_info.capture_indices = std::make_unique>>(); } exec_info.capture_indices->emplace_back(std::make_pair(input->get_grad_fn_output_index(), idx)); } HashSet seen; std::stack stack; // Note: dfs to determine each FunctionNode should execute or not. for (const auto& root : roots_) { stack.push(NodeFrame(root)); } while (!stack.empty()) { NodeFrame& frame = stack.top(); if (/*bool has_seen=*/seen.find(frame.node_) != seen.end()) { stack.pop(); continue; } if (FunctionNode* node = frame.GetNextFunction()) { grad_fn2exec_info_[node].dependencies += 1; if (seen.find(node) == seen.end()) { stack.push(NodeFrame(node)); continue; // recurse } } else { for (auto& fn : frame.node_->next_functions()) { grad_fn2exec_info_[frame.node_].need_execute |= grad_fn2exec_info_[std::get<0>(fn).get()].need_execute; } seen.insert(frame.node_); stack.pop(); } } return Maybe::Ok(); } Maybe GraphTask::Apply(bool save_grad_for_leaf) { std::queue queue; for (FunctionNode* node : roots_) { if (grad_fn2exec_info_[node].dependencies == 0) { queue.push(node); } } while (!queue.empty()) { FunctionNode* node = queue.front(); queue.pop(); auto& exec_info = grad_fn2exec_info_[node]; if (!exec_info.need_execute) { node->ReleaseOutTensorArgs(); continue; } BackwardPassScopeGuard backward_guard(node->scope()); if (/*bool not_ready_to_apply=*/!(JUST(node->Apply(create_graph_)))) { continue; } if (exec_info.capture_indices) { CHECK_NOTNULL_OR_RETURN(captured_grads_.get()) << "captured grads in GraphTask is nullptr"; for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) { JUST(VectorAt(*captured_grads_, out_idx_and_capture_idx.second)) = JUST(JUST(VectorAt(node->output_meta_data_, out_idx_and_capture_idx.first)) ->current_grad_value()); } } if (save_grad_for_leaf) { JUST(node->AccGrad4LeafTensor(create_graph_)); } JUST(node->AccGrad4RetainGradTensor(create_graph_)); node->ReleaseOutTensorArgs(); if (!retain_graph_) { node->ReleaseData(); } for (const auto& next_grad_fn : node->next_functions()) { FunctionNode* next_node = std::get<0>(next_grad_fn).get(); int32_t& dependencies = grad_fn2exec_info_[next_node].dependencies; dependencies -= 1; if (dependencies == 0) { queue.push(next_node); } } } return Maybe::Ok(); } Maybe GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph) { for (int i = 0; i < outputs.size(); ++i) { JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i))); } GraphTask graph_task(outputs, retain_graph, create_graph); JUST(graph_task.ComputeDependencies()); if (IsInDebugMode()) { JUST( graph_task.WriteGraphToDotFile(GetDebugGraphFileName("backward", std::to_string(clock())))); } JUST(graph_task.Apply(/*save_grad_for_leaf=*/true)); return Maybe::Ok(); } Maybe GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad( const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph, bool allow_unused) { for (int i = 0; i < outputs.size(); ++i) { JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i))); } GraphTask graph_task(outputs, retain_graph, create_graph); JUST(graph_task.ComputeDependenciesAndPruneNode(inputs, allow_unused)); if (IsInDebugMode()) { 
JUST(graph_task.WriteGraphToDotFile(GetDebugGraphFileName("grad", std::to_string(clock())))); } JUST(graph_task.Apply(/*save_grad_for_leaf=*/false)); return graph_task.GetCapturedGrads(); } Maybe GraphAutogradEngine::AddNode( const std::string& name, const std::shared_ptr& backward_fn, const TensorTuple& inputs, TensorTuple* outputs) { OF_PROFILER_RANGE_PUSH("AddAccumulateFunctionNode"); // Firstly push function_node of tensor in stack which is leaf and requires_grad for (const std::shared_ptr& in_tensor : inputs) { if (in_tensor->is_leaf() && in_tensor->requires_grad()) { if (!in_tensor->grad_fn_node()) { JUST(AddAccumulateFunctionNode(in_tensor)); } } } OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("set_grad_fn_node"); std::shared_ptr func_node = GraphFunctionNode::New(name, backward_fn, inputs, *outputs); for (int i = 0; i < outputs->size(); ++i) { const std::shared_ptr& out_tensor = JUST(VectorAt(*outputs, i)); out_tensor->set_grad_fn_node(func_node); out_tensor->set_grad_fn_output_index(i); } if (LazyMode::is_enabled()) { func_node->set_scope(JUST(GetCurrentScope())); } OF_PROFILER_RANGE_POP(); return func_node; } AutogradEngine* GetThreadLocalAutogradEngine() { thread_local static GraphAutogradEngine autograd_engine; return &autograd_engine; } Maybe AddAccumulateFunctionNode(const std::shared_ptr& tensor) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { return Maybe::Ok(); }; backward_fn->status = []() { return false; }; tensor->set_grad_fn_node(GraphFunctionNode::New("accumulategrad", backward_fn, /*inputs=*/TensorTuple{}, /*outputs*/ TensorTuple{tensor})); tensor->mut_grad_fn_node()->set_variable(tensor); tensor->set_grad_fn_output_index(0); if (LazyMode::is_enabled()) { tensor->mut_grad_fn_node()->set_scope(JUST(GetTensorScope(tensor))); } return Maybe::Ok(); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/autograd_engine.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_ #define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_ #include #include #include #include #include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace one { class Tensor; class TensorTuple; using CaptureStatus = bool; struct BackwardFunction { std::function(const TensorTuple&, TensorTuple*, bool)> body; std::function status; }; // Calculates one backward op class FunctionNode { public: virtual ~FunctionNode() = default; Maybe Apply(bool create_graph); Maybe AccGrad4LeafTensor(bool create_graph); Maybe AccGrad4RetainGradTensor(bool create_graph); void ReleaseOutTensorArgs(); // Releases the eventual c++ std::function for backward if retain_graph=False to avoid calling // `Apply` in second time virtual void ReleaseData() = 0; const std::vector, int>>& next_functions() const { return next_functions_; } const std::string& name() const { return name_; } const std::shared_ptr& scope() const { return scope_; } void set_scope(const std::shared_ptr& scope) { scope_ = scope; } void set_variable(const std::weak_ptr& variable) { variable_ = variable; } const Maybe Variable() const { if (!variable_.lock()) { THROW(RuntimeError) << "The tensor has already been deleted!"; } return variable_.lock(); } using Hook = std::function>>(const TensorTuple&, const TensorTuple&)>; void add_post_hook(const Hook& hook) { hooks_.push_back(hook); } protected: friend class GraphTask; explicit FunctionNode(const std::string& name, const std::shared_ptr& backward_fn) : name_(name), backward_fn_(backward_fn), scope_(nullptr) {} const std::string name_; std::vector, int>> next_functions_; std::vector> input_meta_data_; std::vector> output_meta_data_; std::vector output_tensor_infos_; // Actual backward function builds in `AutogradInterpreter` to calculate one backward op std::shared_ptr backward_fn_; std::weak_ptr variable_; // The execution scope std::shared_ptr scope_; std::vector hooks_; }; class AutogradEngine { public: virtual ~AutogradEngine() = default; Maybe RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph); Maybe RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph, bool allow_unused); virtual void ClearEngine() = 0; // Builds FunctionNode, binding to all `outputs_` tensors and saving in AutogradEngine virtual Maybe AddNode(const std::string& name, const std::shared_ptr& backward_fn, const TensorTuple& inputs, TensorTuple* outputs) = 0; protected: AutogradEngine() = default; private: virtual Maybe RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph) = 0; virtual Maybe RunBackwardAndReturnInputsTensorGrad( const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph, bool allow_unused) = 0; }; // Graph Autograd Node and Engine class GraphFunctionNode final : public FunctionNode { public: OF_DISALLOW_COPY_AND_MOVE(GraphFunctionNode); static std::shared_ptr New( const std::string& name, const std::shared_ptr& backward_fn, const TensorTuple& inputs, const TensorTuple& outputs); GraphFunctionNode() = delete; ~GraphFunctionNode() override = default; void ReleaseData() 
override; private: GraphFunctionNode(const std::string& name, const std::shared_ptr& backward_fn, const TensorTuple& inputs, const TensorTuple& outputs); }; class GraphTask final { public: OF_DISALLOW_COPY_AND_MOVE(GraphTask); GraphTask() = delete; GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph); Maybe ComputeDependencies(); Maybe ComputeDependenciesAndPruneNode(const TensorTuple& inputs, bool allow_unused); Maybe Apply(bool save_grad_for_leaf); std::shared_ptr GetCapturedGrads() const { return captured_grads_; } Maybe WriteGraphToDotFile(const std::string& file_name) const; private: class ExecInfo { public: ExecInfo() = default; int32_t dependencies = 0; bool need_execute = false; // Used in autograd.grad interface, to record which grad of tensor will be captured. // The pair means: std::unique_ptr>> capture_indices; }; bool retain_graph_; bool create_graph_; std::vector roots_; HashMap grad_fn2exec_info_; std::shared_ptr captured_grads_; }; class GraphAutogradEngine final : public AutogradEngine { public: OF_DISALLOW_COPY_AND_MOVE(GraphAutogradEngine); GraphAutogradEngine() = default; ~GraphAutogradEngine() override = default; void ClearEngine() override{}; Maybe AddNode(const std::string& name, const std::shared_ptr& backward_fn, const TensorTuple& inputs, TensorTuple* outputs) override; private: Maybe RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph) override; Maybe RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads, bool retain_graph, bool create_graph, bool allow_unused) override; }; AutogradEngine* GetThreadLocalAutogradEngine(); Maybe AddAccumulateFunctionNode(const std::shared_ptr& tensor); } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_ ================================================ FILE: oneflow/core/autograd/autograd_function.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/autograd/autograd_function.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { /*static*/ Maybe AutogradFunctionBase::Apply(const std::string& name, const FType& forward_fn, const FType& backward_fn, const TensorTuple& inputs) { std::shared_ptr outputs = std::make_shared(); const auto& op = JUST(FunctionOpExpr::New(name, forward_fn, backward_fn)); JUST(OpInterpUtil::Dispatch(*op, inputs, outputs.get(), {})); const HashSet& non_differentiable_tensors = op->state()->NonDifferentiableTensors(); for (const auto& tensor : *outputs) { if (non_differentiable_tensors.find(tensor.get()) != non_differentiable_tensors.end()) { JUST(tensor->set_requires_grad(false)); } } return outputs; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/autograd_function.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_ #define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace one { class TensorTuple; class FunctionAutoGradCaptureState; class FunctionOpExpr; class AutogradFunctionBase { public: using FType = std::function( const std::shared_ptr&, const TensorTuple&)>; AutogradFunctionBase() = default; virtual ~AutogradFunctionBase() = default; static Maybe Apply(const std::string& name, const FType& forward_fn, const FType& backward_fn, const TensorTuple& inputs); }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_ ================================================ FILE: oneflow/core/autograd/autograd_meta.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/tensor_arg.h" #include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) { if (tensor.is_global()) { parallel_desc_ = CHECK_JUST(tensor.parallel_desc()); nd_sbp_ = CHECK_JUST(tensor.nd_sbp()); } else { device_ = CHECK_JUST(tensor.device()); } } Maybe>&> GetSbpTuple(Symbol nd_sbp) { static thread_local HashMap, std::vector>> map; auto iter = map.find(nd_sbp); if (iter == map.end()) { std::vector> sbp_tuple; sbp_tuple.reserve(nd_sbp->sbp_parallel().size()); for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) { sbp_tuple.push_back(SymbolOf(sbp_parallel)); } iter = map.emplace(nd_sbp, sbp_tuple).first; } return iter->second; } Maybe TensorInfo::zeros() const { if (device_.has_value()) { const auto& device = JUST(device_); return functional::Constant(*shape_.get(), 0, dtype_, device); } else { const auto& parallel_desc = JUST(parallel_desc_); const auto& nd_sbp = JUST(nd_sbp_); const auto& sbp_tuple = JUST(GetSbpTuple(nd_sbp)); return functional::GlobalConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple); } } AutogradMeta::AutogradMeta(bool requires_grad, bool is_leaf) : is_leaf_(is_leaf), requires_grad_(requires_grad), retain_grad_(false), current_grad_(new TensorArg) {} Maybe AutogradMeta::set_acc_grad(const std::shared_ptr& grad) { // NOTE(daquexian): update here if we support remat on global tensors if (grad && acc_grad_ != nullptr && acc_grad_->is_eager() && acc_grad_->is_local()) { // set old acc_grad evictable if (auto rematable_storage = std::dynamic_pointer_cast( JUST(acc_grad_->eager_blob_object())->tensor_storage())) { rematable_storage->set_eviction_disabled(false); } } if (const auto& static_zeros_tensor = std::dynamic_pointer_cast(grad)) { acc_grad_ = JUST(static_zeros_tensor->AsLocalTensor()); } else { acc_grad_ = grad; } if (acc_grad_ != nullptr && acc_grad_->is_eager() && acc_grad_->is_local()) { // set new acc_grad non-evictable if (auto rematable_storage = std::dynamic_pointer_cast( JUST(acc_grad_->eager_blob_object())->tensor_storage())) { rematable_storage->set_eviction_disabled(true); } } return Maybe::Ok(); } Maybe AutogradMeta::current_grad_value() const { std::shared_ptr res = JUST(current_grad_->GetAccTensor()); for (const auto& hook : hooks_) { const auto& new_tensor = hook(res); if (new_tensor) { res = new_tensor; } } return res; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/autograd_meta.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_ #define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_ #include #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/optional.h" namespace oneflow { class Shape; class Device; class ParallelDesc; class NdSbp; namespace one { class Tensor; class TensorArg; class LocalTensor; class AutogradMeta final { public: AutogradMeta() = delete; AutogradMeta(bool requires_grad, bool is_leaf); // Getters const std::shared_ptr& acc_grad() const { return acc_grad_; } const std::shared_ptr& current_grad() const { return current_grad_; } // get current grad processed by hooks Maybe current_grad_value() const; bool requires_grad() const { return requires_grad_; } bool is_leaf() const { return is_leaf_; } bool retain_grad() const { return retain_grad_; } using Hook = std::function(const std::shared_ptr&)>; const std::vector& hooks() const { return hooks_; } const std::vector& post_grad_accumulation_hooks() const { return post_grad_accumulation_hooks_; } // Setters Maybe set_acc_grad(const std::shared_ptr& grad); std::shared_ptr mut_acc_grad() { return acc_grad_; } void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; } void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; } void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; } void add_hook(const Hook& hook) { hooks_.emplace_back(hook); } void add_post_grad_accumulation_hook(const Hook& hook) { post_grad_accumulation_hooks_.emplace_back(hook); } private: bool is_leaf_; // Only meaningful on leaf Tensors (must be false otherwise) bool requires_grad_; // Only meaningful on non_leaf Tensors (must be false otherwise) bool retain_grad_; std::shared_ptr acc_grad_; std::shared_ptr current_grad_; std::vector hooks_; std::vector post_grad_accumulation_hooks_; }; inline std::shared_ptr NewAutogradMeta(bool requires_grad, bool is_leaf) { return std::shared_ptr(new AutogradMeta(requires_grad, is_leaf)); } class TensorInfo final { public: TensorInfo() = delete; explicit TensorInfo(const Tensor& tensor); Maybe zeros() const; Optional> placement() const { return parallel_desc_; } Optional> sbp() const { return nd_sbp_; } private: std::shared_ptr shape_; Symbol dtype_; Optional> device_; // for local tensor Optional> parallel_desc_; // for global tensor Optional> nd_sbp_; // for global tensor }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_ ================================================ FILE: oneflow/core/autograd/autograd_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/autograd/autograd_mode.h" namespace oneflow { namespace autograd { namespace { bool* GetThreadLocalGradMode() { static thread_local bool g_grad_mode = true; return &g_grad_mode; } } // namespace bool GradMode::is_enabled() { return *GetThreadLocalGradMode(); } void GradMode::set_enabled(bool enabled) { *GetThreadLocalGradMode() = enabled; } } // namespace autograd } // namespace oneflow ================================================ FILE: oneflow/core/autograd/autograd_mode.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_ #define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_ namespace oneflow { namespace autograd { struct GradMode { static bool is_enabled(); static void set_enabled(bool enabled); }; class AutoGradMode { public: AutoGradMode(bool enabled) : prev_mode_(GradMode::is_enabled()) { GradMode::set_enabled(enabled); } ~AutoGradMode() { GradMode::set_enabled(prev_mode_); } bool prev_mode() const { return prev_mode_; } private: bool prev_mode_; }; class NoGradGuard : public AutoGradMode { public: NoGradGuard() : AutoGradMode(false){}; }; } // namespace autograd } // namespace oneflow #endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_ ================================================ FILE: oneflow/core/autograd/gradient_funcs/activation.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct BaseActivationCaptureState : public AutoGradCaptureState { bool requires_grad; }; class BaseActivation : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(BaseActivationCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } return Maybe::Ok(); } }; class Silu : public BaseActivation { public: Maybe Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::SiluGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; class Mish : public BaseActivation { public: Maybe Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::MishGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; class Selu : public BaseActivation { public: Maybe Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::SeluGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; class Softsign : public BaseActivation { public: Maybe Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::SoftSignGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; class GeLU : public BaseActivation { public: Maybe Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::GeluGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; class FastGeLU : public BaseActivation { public: Maybe Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::FastGeluGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; struct QuickGeluCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class QuickGeLU : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) 
override { return Maybe::Ok(); } Maybe Capture(QuickGeluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const QuickGeluCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::QuickGeluGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; struct SquareReLUCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class SquareReLU : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(SquareReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const SquareReLUCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::SquareReLUGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; class HardSigmoid : public BaseActivation { public: Maybe Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::HardSigmoidGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; struct HardShrinkCaptureState : public AutoGradCaptureState { bool requires_grad = true; double lambd = 0.5; }; class HardShrink : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(HardShrinkCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->lambd = JUST(composed_attrs.GetAttr("lambd")); ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0))); return Maybe::Ok(); } Maybe Apply(const HardShrinkCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0)); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::HardShrinkGrad(y, 
JUST(oneflow::VectorAt(out_grads, 0)), ctx->lambd)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; class HardSwish : public BaseActivation { public: Maybe Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::HardSwishGrad(out_grads.at(0), x)); } return Maybe::Ok(); } }; // ===== Activation with parms ==== struct ReLUCaptureState : public AutoGradCaptureState { bool requires_grad; }; class ReLU : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(ReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); } return Maybe::Ok(); } Maybe Apply(const ReLUCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& y = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y)); } return Maybe::Ok(); } }; // ===== Activation with parms ==== struct LeakyReluCaptureState : public AutoGradCaptureState { bool requires_grad; float alpha; }; class LeakyRelu : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(LeakyReluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const LeakyReluCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct SoftplusCaptureState : public AutoGradCaptureState { bool requires_grad = true; double beta = 1.0; double threshold = 20.0; }; class Softplus : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SoftplusCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ComposedAttrMap composed_attrs(attrs, base_attrs_); 
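// Read "beta" and "threshold" from the composed attrs (runtime attrs layered over the op's
// default attrs) and save the input tensor for SoftplusGrad in Apply().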
ctx->beta = JUST(composed_attrs.GetAttr("beta")); ctx->threshold = JUST(composed_attrs.GetAttr("threshold")); ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0))); return Maybe::Ok(); } Maybe Apply(const SoftplusCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0)); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::SoftplusGrad( x, JUST(oneflow::VectorAt(out_grads, 0)), ctx->beta, ctx->threshold)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct HardTanhCaptureState : public AutoGradCaptureState { bool requires_grad; double min_val; double max_val; }; class HardTanh : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(HardTanhCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->min_val = JUST(composed_attrs.GetAttr("min_val")); ctx->max_val = JUST(composed_attrs.GetAttr("max_val")); ctx->SaveTensorForBackward(outputs.at(0)); return Maybe::Ok(); } Maybe Apply(const HardTanhCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& y = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct EluCaptureState : public AutoGradCaptureState { bool requires_grad; double alpha; }; class Elu : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(EluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const EluCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct CeluCaptureState : public AutoGradCaptureState { bool requires_grad = true; double alpha = 1.0; }; class Celu : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = 
dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(CeluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); ctx->SaveTensorForBackward(outputs.at(0)); return Maybe::Ok(); } Maybe Apply(const CeluCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& y = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::CeluGrad(y, out_grads.at(0), ctx->alpha)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct SoftShrinkCaptureState : public AutoGradCaptureState { bool requires_grad = true; double alpha = 0.5; }; class SoftShrink : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SoftShrinkCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0))); return Maybe::Ok(); } Maybe Apply(const SoftShrinkCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0)); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::SoftShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->alpha)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct PReLUCaptureState : public AutoGradCaptureState { bool input_requires_grad; bool alpha_requires_grad; }; class PReLU : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(PReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input ctx->alpha_requires_grad = inputs.at(1)->requires_grad(); // alpha ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(inputs.at(1)); return Maybe::Ok(); } Maybe Apply(const PReLUCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& dy = out_grads.at(0); const auto& x = ctx->SavedTensors().at(0); const auto& alpha = ctx->SavedTensors().at(1); in_grads->resize(2); if (ctx->input_requires_grad || ctx->alpha_requires_grad) { const auto& grads = 
JUST(functional::PReluGrad(dy, x, alpha)); if (ctx->input_requires_grad) { in_grads->at(0) = grads->at(0); } if (ctx->alpha_requires_grad) { in_grads->at(1) = grads->at(1); } } return Maybe::Ok(); } private: std::shared_ptr grad_op_; }; struct ThresholdCaptureState : public AutoGradCaptureState { bool requires_grad = true; double threshold = 0.0; }; class Threshold : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ThresholdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->threshold = JUST(composed_attrs.GetAttr("threshold_val")); ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0))); return Maybe::Ok(); } Maybe Apply(const ThresholdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0)); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::ThresholdGrad(x, JUST(oneflow::VectorAt(out_grads, 0)), ctx->threshold)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct FracCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class Frac : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(FracCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } return Maybe::Ok(); } Maybe Apply(const FracCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("frac", Frac); REGISTER_OP_EXPR_GRAD_FUNCTION("silu", Silu); REGISTER_OP_EXPR_GRAD_FUNCTION("mish", Mish); REGISTER_OP_EXPR_GRAD_FUNCTION("selu", Selu); REGISTER_OP_EXPR_GRAD_FUNCTION("softsign", Softsign); REGISTER_OP_EXPR_GRAD_FUNCTION("relu", ReLU); REGISTER_OP_EXPR_GRAD_FUNCTION("gelu", GeLU); REGISTER_OP_EXPR_GRAD_FUNCTION("hardsigmoid", HardSigmoid); REGISTER_OP_EXPR_GRAD_FUNCTION("hardshrink", HardShrink); REGISTER_OP_EXPR_GRAD_FUNCTION("hardswish", HardSwish); REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu", LeakyRelu); REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh", HardTanh); REGISTER_OP_EXPR_GRAD_FUNCTION("elu", Elu); REGISTER_OP_EXPR_GRAD_FUNCTION("celu", Celu); REGISTER_OP_EXPR_GRAD_FUNCTION("prelu", PReLU); REGISTER_OP_EXPR_GRAD_FUNCTION("threshold", Threshold); REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus); REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink); REGISTER_OP_EXPR_GRAD_FUNCTION("fast_gelu", 
FastGeLU); REGISTER_OP_EXPR_GRAD_FUNCTION("quick_gelu", QuickGeLU); REGISTER_OP_EXPR_GRAD_FUNCTION("square_relu", SquareReLU); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/adaptive_avg_pool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct AdaptivePoolCaptureState : public AutoGradCaptureState { std::string data_format; bool requires_grad; }; class AdaptivePoolNdGrad : public OpExprGradFunction { public: using OpExprGradFunction::Init; Maybe Init(const OpExpr& op, std::string mode, const int& ndims); Maybe Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const AdaptivePoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; std::string mode_; int32_t ndims_; }; Maybe AdaptivePoolNdGrad::Init(const OpExpr& op, std::string mode, const int& ndims) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); mode_ = mode; ndims_ = ndims; return Maybe::Ok(); } Maybe AdaptivePoolNdGrad::Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); ctx->data_format = JUST(attrs.GetAttr("data_format")); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe AdaptivePoolNdGrad::Apply(const AdaptivePoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); in_grads->at(0) = JUST(functional::AdaptivePoolNdGrad(x, out_grads.at(0), mode_, ndims_, ctx->data_format)); return Maybe::Ok(); } class AdaptiveAvgPool1dGrad final : public AdaptivePoolNdGrad { public: Maybe Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 1); } }; class AdaptiveAvgPool2dGrad final : public AdaptivePoolNdGrad { public: Maybe Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 2); } }; class AdaptiveAvgPool3dGrad final : public AdaptivePoolNdGrad { public: Maybe Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 3); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool1d", AdaptiveAvgPool1dGrad); 
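// The 2d and 3d registrations below reuse the same AdaptivePoolNdGrad logic; only the
// ndims argument handed to Init differs, and the grad kernel is expected to spread each
// output cell's gradient uniformly over the input window it averaged. As a sketch of the
// pattern (hypothetical, not part of this file), a 4d variant would look like:
//
//   class AdaptiveAvgPool4dGrad final : public AdaptivePoolNdGrad {
//    public:
//     Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 4); }
//   };
//   REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool4d", AdaptiveAvgPool4dGrad);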
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool2d", AdaptiveAvgPool2dGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool3d", AdaptiveAvgPool3dGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/adaptive_max_pool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct AdaptiveMaxPoolCaptureState : public AutoGradCaptureState { std::string data_format; bool requires_grad = false; }; class AdaptiveMaxPoolNdGrad : public OpExprGradFunction { public: using OpExprGradFunction::Init; Maybe Init(const OpExpr& op, const int& ndims); Maybe Capture(AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: int32_t ndims_ = 0; }; Maybe AdaptiveMaxPoolNdGrad::Init(const OpExpr& op, const int& ndims) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) ndims_ = ndims; return Maybe::Ok(); } Maybe AdaptiveMaxPoolNdGrad::Capture(AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->data_format = JUST(attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(outputs.at(1)); return Maybe::Ok(); } Maybe AdaptiveMaxPoolNdGrad::Apply(const AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); const std::shared_ptr& index = ctx->SavedTensors().at(1); in_grads->resize(1); in_grads->at(0) = JUST(functional::AdaptiveMaxPoolNdGrad(x, out_grads.at(0), index, ndims_, ctx->data_format)); return Maybe::Ok(); } class AdaptiveMaxPool1dGrad final : public AdaptiveMaxPoolNdGrad { public: Maybe Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 1); } }; class AdaptiveMaxPool2dGrad final : public AdaptiveMaxPoolNdGrad { public: Maybe Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 2); } }; class AdaptiveMaxPool3dGrad final : public AdaptiveMaxPoolNdGrad { public: Maybe Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 3); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool1d", AdaptiveMaxPool1dGrad); 
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool2d", AdaptiveMaxPool2dGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool3d", AdaptiveMaxPool3dGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/add_n.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { struct AddNCaptureState : public AutoGradCaptureState { int32_t input_num; std::vector requires_grad; }; class AddN : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(AddNCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->input_num = inputs.size(); ctx->requires_grad.resize(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); } return Maybe::Ok(); } Maybe Apply(const AddNCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(ctx->input_num); for (int i = 0; i < ctx->input_num; ++i) { if (ctx->requires_grad.at(i)) { in_grads->at(i) = out_grads.at(0); } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("add_n", AddN); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/affine_grid.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct AffineGridInterpState : public AutoGradCaptureState { Shape size; bool align_corners = false; bool requires_grad = false; }; class AffineGrid : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(AffineGridInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); // theta if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->size = JUST(composed_attrs.GetAttr("size")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); return Maybe::Ok(); } Maybe Apply(const AffineGridInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); in_grads->at(0) = JUST(functional::AffineGridGrad(out_grads.at(0), ctx->size, ctx->align_corners)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("affine_grid", AffineGrid); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/amp_white_identity.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { enum class AmpIdentityType { kWhite = 0, kBlack, }; struct AmpIdentityCaptureState : public AutoGradCaptureState {}; template class AmpIdentityGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(AmpIdentityCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { return Maybe::Ok(); } Maybe Apply(const AmpIdentityCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(1); if (type == AmpIdentityType::kWhite) { (*in_grads)[0] = JUST(functional::AmpWhiteIdentity(out_grads[0])); } else if (type == AmpIdentityType::kBlack) { (*in_grads)[0] = JUST(functional::AmpBlackIdentity(out_grads[0])); } else { (*in_grads)[0] = out_grads[0]; } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("amp_white_identity", AmpIdentityGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("amp_black_identity", AmpIdentityGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/as_strided.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct AsStridedCaptureState : public AutoGradCaptureState { std::vector size; std::vector stride; int64_t storage_offset = 0; bool requires_grad = false; }; class AsStrided : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(AsStridedCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const AsStridedCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe AsStrided::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe AsStrided::Capture(AsStridedCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->size = JUST(composed_attrs.GetAttr>("size")); ctx->stride = JUST(composed_attrs.GetAttr>("stride")); ctx->storage_offset = JUST(composed_attrs.GetAttr("storage_offset")); return Maybe::Ok(); } Maybe AsStrided::Apply(const AsStridedCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& input = ctx->SavedTensors().at(0); std::vector size = ctx->size; std::vector stride = ctx->stride; int64_t storage_offset = ctx->storage_offset; in_grads->at(0) = JUST(functional::AsStridedGrad(out_grads.at(0), input, size, stride, storage_offset)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("as_strided", AsStrided); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/avg_pool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { namespace { struct AvgPoolCaptureState : public AutoGradCaptureState { bool requires_grad = false; size_t input_index = 0; std::string data_format; std::vector padding; std::vector kernel_size; std::vector stride; bool ceil_mode = false; bool count_include_pad = false; int32_t divisor_override = 0; }; class AvgPoolNdGrad : public OpExprGradFunction { public: virtual ~AvgPoolNdGrad() = default; Maybe Init(const OpExpr& op) override; Maybe Capture(AvgPoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const AvgPoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe AvgPoolNdGrad::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe AvgPoolNdGrad::Capture(AvgPoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->padding = JUST(composed_attrs.GetAttr>("padding")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->stride = JUST(composed_attrs.GetAttr>("stride")); ctx->ceil_mode = JUST(composed_attrs.GetAttr("ceil_mode")); ctx->count_include_pad = JUST(composed_attrs.GetAttr("count_include_pad")); ctx->divisor_override = JUST(composed_attrs.GetAttr("divisor_override")); return Maybe::Ok(); } Maybe AvgPoolNdGrad::Apply(const AvgPoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) int32_t ndims = ctx->kernel_size.size(); const auto& input = ctx->SavedTensors().at(ctx->input_index); in_grads->resize(1); (*in_grads)[0] = JUST(functional::AvgPoolNdGrad( input, out_grads[0], ndims, ctx->data_format, ctx->padding, ctx->kernel_size, ctx->stride, ctx->ceil_mode, ctx->count_include_pad, ctx->divisor_override)); return Maybe::Ok(); } } // namespace REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_1d", AvgPoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_2d", AvgPoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_3d", AvgPoolNdGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/batch_gather.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct BatchGatherCaptureState : public AutoGradCaptureState { int64_t num_segments; bool requires_grad; }; class BatchGather : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe BatchGather::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe BatchGather::Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& in_shape = inputs.at(0)->shape(); const auto& indices_shape = inputs.at(1)->shape(); ctx->num_segments = in_shape->At(indices_shape->NumAxes() - 1); ctx->SaveTensorForBackward(inputs.at(1)); return Maybe::Ok(); } Maybe BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(2); if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& indices = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::UnsortedBatchSegmentSum(out_grads.at(0), indices, ctx->num_segments)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("batch_gather", BatchGather); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/bias_add.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct BiasAddCaptureState : public AutoGradCaptureState { bool input_requires_grad; bool bias_requires_grad; int32_t axis; }; class BiasAdd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(BiasAddCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->bias_requires_grad = inputs.at(1)->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); return Maybe::Ok(); } Maybe Apply(const BiasAddCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const int64_t num_axes = out_grads.at(0)->shape()->NumAxes(); in_grads->resize(2); if (ctx->bias_requires_grad) { std::vector reduce_axes_vec; reduce_axes_vec.reserve(num_axes); for (int i = 0; i < num_axes; ++i) { if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); } } if (ctx->bias_requires_grad) { in_grads->at(1) = JUST(functional::ReduceSum(out_grads.at(0), reduce_axes_vec, false, NullOpt)); } } if (ctx->input_requires_grad) { in_grads->at(0) = out_grads.at(0); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("bias_add", BiasAdd); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/binary_cross_entropy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct BinaryCrossEntropyCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool target_requires_grad = false; bool has_weight = false; }; class BinaryCrossEntropy : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BinaryCrossEntropyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe BinaryCrossEntropy::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 3); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs[0]->requires_grad(); ctx->target_requires_grad = inputs[1]->requires_grad(); ctx->has_weight = inputs.size() == 3; ctx->SaveTensorForBackward(inputs[0]); // input ctx->SaveTensorForBackward(inputs[1]); // target if (ctx->has_weight) { ctx->SaveTensorForBackward(inputs[2]); // weight } return Maybe::Ok(); } Maybe BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2 + ctx->has_weight); // NOLINT(maybe-need-error-msg) in_grads->resize(2 + ctx->has_weight); const auto& dy = out_grads[0]; const auto& input = ctx->SavedTensors()[0]; const auto& target = ctx->SavedTensors()[1]; const auto& weight = ctx->has_weight ? Optional(ctx->SavedTensors()[2]) : NullOpt; if (ctx->input_requires_grad) { (*in_grads)[0] = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight)); } if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::BinaryCrossEntropyLossTargetGrad(dy, input, target, weight)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy", BinaryCrossEntropy); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool target_requires_grad = false; bool has_weight = false; bool has_pos_weight = false; }; class BinaryCrossEntropyWithLogits : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe BinaryCrossEntropyWithLogits::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe BinaryCrossEntropyWithLogits::Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 4); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs[0]->requires_grad(); ctx->target_requires_grad = inputs[1]->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->has_pos_weight = JUST(composed_attrs.GetAttr("has_pos_weight")); ctx->has_weight = inputs.size() == 4 || (inputs.size() == 3 && !ctx->has_pos_weight); ctx->SaveTensorForBackward(inputs[0]); // input ctx->SaveTensorForBackward(inputs[1]); // target if (inputs.size() == 3) { ctx->SaveTensorForBackward(inputs[2]); // weight or pos_weight } if (inputs.size() == 4) { ctx->SaveTensorForBackward(inputs[2]); // weight ctx->SaveTensorForBackward(inputs[3]); // pos_weight } return Maybe::Ok(); } Maybe BinaryCrossEntropyWithLogits::Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2 + ctx->has_weight + ctx->has_pos_weight); // NOLINT(maybe-need-error-msg) const auto& dy = out_grads[0]; const auto& input = ctx->SavedTensors()[0]; const auto& target = ctx->SavedTensors()[1]; in_grads->resize(ctx->SavedTensors().size()); size_t pos_weight_index = ctx->has_weight ? 3 : 2; auto weight = ctx->has_weight ? Optional(ctx->SavedTensors()[2]) : NullOpt; auto pos_weight = ctx->has_pos_weight ? Optional(ctx->SavedTensors()[pos_weight_index]) : NullOpt; if (ctx->input_requires_grad) { (*in_grads)[0] = JUST( functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight)); } if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::BinaryCrossEntropyWithLogitsLossTargetGrad( dy, input, target, weight, pos_weight)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits", BinaryCrossEntropyWithLogits); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool target_requires_grad = false; }; class BinaryCrossEntropyWithLogitsReduceMean : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe BinaryCrossEntropyWithLogitsReduceMean::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe BinaryCrossEntropyWithLogitsReduceMean::Capture( BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); ctx->target_requires_grad = JUST(VectorAt(inputs, 1))->requires_grad(); ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // target return Maybe::Ok(); } Maybe BinaryCrossEntropyWithLogitsReduceMean::Apply( const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& dy = JUST(VectorAt(out_grads, 0)); const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0)); const auto& target = JUST(VectorAt(ctx->SavedTensors(), 1)); in_grads->resize(2); if (ctx->input_requires_grad) { (*in_grads)[0] = JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target)); } if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad(dy, input, target)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_reduce_mean", BinaryCrossEntropyWithLogitsReduceMean); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct BroadcastBinaryCaptureState : public AutoGradCaptureState { int x_index = -1; int y_index = -1; int z_index = -1; bool x_requires_grad = false; bool y_requires_grad = false; bool broadcast_x = false; bool broadcast_y = false; }; class BroadcastBinaryGrad : public OpExprGradFunction { public: BroadcastBinaryGrad() = default; virtual ~BroadcastBinaryGrad() = default; virtual Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->y_requires_grad = inputs.at(1)->requires_grad(); ctx->broadcast_x = (*inputs.at(0)->shape() != *outputs.at(0)->shape()); ctx->broadcast_y = (*inputs.at(1)->shape() != *outputs.at(0)->shape()); return SaveTensorForBackward(ctx, inputs, outputs); } protected: virtual Maybe SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs) const = 0; }; class BroadcastAdd : public BroadcastBinaryGrad { public: Maybe Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { if (ctx->broadcast_x) { const auto& x = ctx->SavedTensors().at(ctx->x_index); in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x)); } else { in_grads->at(0) = out_grads.at(0); } } if (ctx->y_requires_grad) { if (ctx->broadcast_y) { const auto& y = ctx->SavedTensors().at(ctx->y_index); in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), y)); } else { in_grads->at(1) = out_grads.at(0); } } return Maybe::Ok(); } protected: Maybe SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs) const override { if (ctx->x_requires_grad && ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); } if (ctx->y_requires_grad && ctx->broadcast_y) { ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_add", BroadcastAdd); class BroadcastSub : public BroadcastBinaryGrad { public: Maybe Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { if (ctx->broadcast_x) { const auto& x = ctx->SavedTensors().at(ctx->x_index); in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x)); } else { in_grads->at(0) = out_grads.at(0); } } if (ctx->y_requires_grad) { const auto& grad = JUST(functional::ScalarMul(out_grads.at(0), Scalar(-1.f), false)); if (ctx->broadcast_y) { const auto& y = ctx->SavedTensors().at(ctx->y_index); in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(grad, y)); } else { in_grads->at(1) = grad; } } return Maybe::Ok(); } protected: Maybe SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs) const override { if 
(ctx->x_requires_grad && ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); } if (ctx->y_requires_grad && ctx->broadcast_y) { ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_sub", BroadcastSub); class BroadcastMul : public BroadcastBinaryGrad { public: Maybe Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { const auto& y = ctx->SavedTensors().at(ctx->y_index); const auto& x_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(y)))); if (ctx->broadcast_x) { const auto& x = ctx->SavedTensors().at(ctx->x_index); in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x)); } else { in_grads->at(0) = x_grad; } } if (ctx->y_requires_grad) { const auto& x = ctx->SavedTensors().at(ctx->x_index); const auto& y_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(x)))); if (ctx->broadcast_y) { const auto& y = ctx->SavedTensors().at(ctx->y_index); in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y)); } else { in_grads->at(1) = y_grad; } } return Maybe::Ok(); } protected: Maybe SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs) const override { if (ctx->x_requires_grad) { ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); } } if (ctx->y_requires_grad) { if (ctx->x_index == -1 /*x has not been saved*/) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); } if (ctx->broadcast_y && ctx->y_index == -1 /*y has not been saved*/) { ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_mul", BroadcastMul); class BroadcastDiv : public BroadcastBinaryGrad { public: Maybe Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { const auto& y = ctx->SavedTensors().at(ctx->y_index); // const auto& x_grad = JUST(functional::Div(out_grads.at(0), y)); const auto& x_grad = JUST(functional::Div(out_grads.at(0), JUST(functional::Conj(y)))); if (ctx->broadcast_x) { const auto& x = ctx->SavedTensors().at(ctx->x_index); in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x)); } else { in_grads->at(0) = x_grad; } } if (ctx->y_requires_grad) { const auto& y = ctx->SavedTensors().at(ctx->y_index); const auto& z = ctx->SavedTensors().at(ctx->z_index); in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), z, y)); } return Maybe::Ok(); } protected: Maybe SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs) const override { if (ctx->x_requires_grad) { ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); } } if (ctx->y_requires_grad) { if (ctx->y_index == -1 /*y has not been saved*/) { ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); } ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div", BroadcastDiv); class BroadcastPow : public BroadcastBinaryGrad { public: Maybe Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const 
auto& x = ctx->SavedTensors().at(ctx->x_index); const auto& y = ctx->SavedTensors().at(ctx->y_index); in_grads->resize(2); if (ctx->x_requires_grad) { (*in_grads)[0] = JUST(functional::BroadcastPowXGrad(x, y, out_grads[0])); } if (ctx->y_requires_grad) { (*in_grads)[1] = JUST(functional::BroadcastPowYGrad(x, y, out_grads[0])); } return Maybe::Ok(); } protected: Maybe SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs) const override { ctx->x_index = ctx->SaveTensorForBackward(inputs[0]); ctx->y_index = ctx->SaveTensorForBackward(inputs[1]); return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_pow", BroadcastPow); class BroadcastMinMax : public BroadcastBinaryGrad { public: Maybe Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& out_shape = *(out_grads.at(0)->shape()); in_grads->resize(2); if (ctx->x_requires_grad || ctx->y_requires_grad) { const auto& x = ctx->SavedTensors().at(ctx->x_index); const auto& y = ctx->SavedTensors().at(ctx->y_index); auto broad_x_ = x; auto broad_y_ = y; if (ctx->broadcast_x) { const auto& x_shape = *(x->shape()); const Shape& left_extended_x_shape = CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes()); if (left_extended_x_shape == out_shape) { broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0)))); } else { const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape); const std::vector x_axis = std::vector{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}; broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis)); } } if (ctx->broadcast_y) { const auto& y_shape = *(y->shape()); const Shape& left_extended_y_shape = CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes()); if (left_extended_y_shape == out_shape) { broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0)))); } else { const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape); const std::vector y_axis = std::vector{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}; broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis)); } } const auto& broad_grads = JUST(elementwise_grad_functor_(out_grads.at(0), broad_x_, broad_y_)); if (ctx->x_requires_grad) { if (ctx->broadcast_x) { in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(0), x)); } else { in_grads->at(0) = broad_grads->at(0); } } if (ctx->y_requires_grad) { if (ctx->broadcast_y) { in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(1), y)); } else { in_grads->at(1) = broad_grads->at(1); } } } return Maybe::Ok(); } protected: Maybe SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs) const override { if (ctx->x_requires_grad || ctx->y_requires_grad) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } std::function(const std::shared_ptr&, const std::shared_ptr&, const std::shared_ptr&)> elementwise_grad_functor_; }; class BroadcastMinimum : public BroadcastMinMax { public: Maybe Init(const OpExpr& op) override { JUST(BroadcastMinMax::Init(op)); elementwise_grad_functor_ = functional::ElementwiseMinGrad; return Maybe::Ok(); } }; class BroadcastMaximum : public BroadcastMinMax { public: Maybe Init(const OpExpr& op) override { 
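// Same wiring as BroadcastMinimum above; only the element-wise grad functor differs.
// The shared BroadcastMinMax::Apply broadcasts x and y up to the output shape, applies
// this functor to (dy, x, y), and then reduces each result back to the original input
// shape whenever that input was broadcast.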
JUST(BroadcastMinMax::Init(op)); elementwise_grad_functor_ = functional::ElementwiseMaxGrad; return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_minimum", BroadcastMinimum); REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_maximum", BroadcastMaximum); class BroadcastFMod : public BroadcastBinaryGrad { public: Maybe Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& out_shape = *(JUST(VectorAt(out_grads, 0))->shape()); in_grads->resize(2); if (ctx->x_requires_grad || ctx->y_requires_grad) { const auto& x = JUST(VectorAt(ctx->SavedTensors(), ctx->x_index)); const auto& y = JUST(VectorAt(ctx->SavedTensors(), ctx->y_index)); auto broad_x_ = x; auto broad_y_ = y; if (ctx->broadcast_x) { const auto& x_shape = *(x->shape()); const Shape& left_extended_x_shape = CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes()); if (left_extended_x_shape == out_shape) { broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0)))); } else { const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape); const std::vector x_axis = std::vector{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}; broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis)); } } if (ctx->broadcast_y) { const auto& y_shape = *(y->shape()); const Shape& left_extended_y_shape = CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes()); if (left_extended_y_shape == out_shape) { broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0)))); } else { const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape); const std::vector y_axis = std::vector{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}; broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis)); } } if (ctx->x_requires_grad) { if (ctx->broadcast_x) { JUST(VectorAt(*in_grads, 0)) = JUST(functional::BroadcastReduceSumLike(JUST(VectorAt(out_grads, 0)), x)); } else { JUST(VectorAt(*in_grads, 0)) = JUST(VectorAt(out_grads, 0)); } } if (ctx->y_requires_grad) { auto result = JUST(functional::TruncDiv(broad_x_, broad_y_)); result = JUST(functional::Mul(JUST(VectorAt(out_grads, 0)), result)); JUST(functional::ScalarMul(result, Scalar(-1.f), true)); if (ctx->broadcast_y) { in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(result, y)); } else { in_grads->at(1) = result; } } } return Maybe::Ok(); } protected: Maybe SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs) const override { if (ctx->x_requires_grad && ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); } if (ctx->y_requires_grad) { ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); ctx->y_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/broadcast_like.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct BroadCastLikeCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  size_t input_index;
  std::vector<int32_t> broadcast_axes;
};

class BroadCastLike : public OpExprGradFunction<BroadCastLikeCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> BroadCastLike::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> BroadCastLike::Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,
                                   const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->broadcast_axes = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("broadcast_axes"));
  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> BroadCastLike::Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,
                                 TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& x = ctx->SavedTensors().at(ctx->input_index);
  in_grads->resize(2);
  in_grads->at(0) = JUST(functional::ReduceSumLike(out_grads.at(0), x, ctx->broadcast_axes));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_like", BroadCastLike);

}  // namespace one
}  // namespace oneflow


================================================
FILE: oneflow/core/autograd/gradient_funcs/cast.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/symbol.h"

namespace oneflow {
namespace one {

struct CastCaptureState : public AutoGradCaptureState {
  Symbol<DType> in_dtype;
  Symbol<DType> out_dtype;
};

class Cast : public OpExprGradFunction<CastCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(CastCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override {
    ctx->in_dtype = inputs.at(0)->dtype();
    ctx->out_dtype = outputs.at(0)->dtype();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const CastCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(1);
    if (!IsComplexDataType(ctx->in_dtype->data_type())
        && IsComplexDataType(ctx->out_dtype->data_type())) {
      (*in_grads)[0] = JUST(functional::Real(out_grads[0]));
    } else {
      (*in_grads)[0] = JUST(functional::Cast(out_grads[0], ctx->in_dtype, /*pin_memory=*/false));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cast", Cast);

}  // namespace one
}  // namespace oneflow


================================================
FILE: oneflow/core/autograd/gradient_funcs/clip_by_scalar.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ClipByScalarCaptureState : public AutoGradCaptureState { bool requires_grad; Scalar min; Scalar max; }; class ClipByScalar : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ClipByScalarCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); ComposedAttrMap composed_attrs(attrs, base_attrs_); if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) { ctx->min = Scalar(JUST(composed_attrs.GetAttr("floating_min"))); ctx->max = Scalar(JUST(composed_attrs.GetAttr("floating_max"))); } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) { ctx->min = Scalar(JUST(composed_attrs.GetAttr("integral_min"))); ctx->max = Scalar(JUST(composed_attrs.GetAttr("integral_max"))); } else { UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type."; } return Maybe::Ok(); } Maybe Apply(const ClipByScalarCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, ctx->min, ctx->max)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar", ClipByScalar); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ClipByScalarMaxCaptureState : public AutoGradCaptureState { bool requires_grad; Scalar max; }; class ClipByScalarMax : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ClipByScalarMaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); ComposedAttrMap composed_attrs(attrs, base_attrs_); if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) { ctx->max = Scalar(JUST(composed_attrs.GetAttr("floating_max"))); } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) { ctx->max = Scalar(JUST(composed_attrs.GetAttr("integral_max"))); } else { UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type."; } return Maybe::Ok(); } Maybe Apply(const ClipByScalarMaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, /*min=*/NullOpt, ctx->max)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar_max", ClipByScalarMax); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/clip_by_scalar_min.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ClipByScalarMinCaptureState : public AutoGradCaptureState { bool requires_grad; Scalar min; }; class ClipByScalarMin : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ClipByScalarMinCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); ComposedAttrMap composed_attrs(attrs, base_attrs_); if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) { ctx->min = Scalar(JUST(composed_attrs.GetAttr("floating_min"))); } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) { ctx->min = Scalar(JUST(composed_attrs.GetAttr("integral_min"))); } else { UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type."; } return Maybe::Ok(); } Maybe Apply(const ClipByScalarMinCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, ctx->min, /*max=*/NullOpt)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar_min", ClipByScalarMin); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/combined_margin_loss.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct CombinedMarginLossCaptureState : public AutoGradCaptureState { float m1; float m2; float m3; int64_t depth; size_t label_index; size_t theta_index; bool requires_grad; }; class CombinedMarginLoss : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(CombinedMarginLossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); // x if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->label_index = ctx->SaveTensorForBackward(inputs.at(1)); // label ctx->theta_index = ctx->SaveTensorForBackward(outputs.at(1)); // theta ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->m1 = JUST(composed_attrs.GetAttr("m1")); ctx->m2 = JUST(composed_attrs.GetAttr("m2")); ctx->m3 = JUST(composed_attrs.GetAttr("m3")); ctx->depth = JUST(composed_attrs.GetAttr("depth")); return Maybe::Ok(); } Maybe Apply(const CombinedMarginLossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->requires_grad) { const auto& label = ctx->SavedTensors().at(ctx->label_index); const auto& theta = ctx->SavedTensors().at(ctx->theta_index); in_grads->at(0) = JUST(functional::CombinedMarginLossGrad( out_grads.at(0), label, theta, ctx->m1, ctx->m2, ctx->m3, ctx->depth)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("combined_margin_loss", CombinedMarginLoss); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/complex.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct BaseComplexCaptureState : public AutoGradCaptureState { bool requires_grad; }; // TODO(lml): redesign these Apply method to support high order autograd. 
class RealGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(BaseComplexCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const BaseComplexCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); if (ctx->requires_grad) { const auto& results = JUST(functional::RealGrad(out_grads.at(0))); in_grads->at(0) = results; } return Maybe::Ok(); } }; class ImagGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(BaseComplexCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const BaseComplexCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); if (ctx->requires_grad) { const auto& results = JUST(functional::ImagGrad(out_grads.at(0))); in_grads->at(0) = results; } return Maybe::Ok(); } }; class ConjPhysicalGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(BaseComplexCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const BaseComplexCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); if (ctx->requires_grad) { const auto& results = JUST(functional::ConjPhysical(out_grads.at(0))); in_grads->at(0) = results; } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("real", RealGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("imag", ImagGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("conj_physical", ConjPhysicalGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/concat.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ConcatCaptureState : public AutoGradCaptureState { std::vector requires_grad; int64_t axis; int64_t input_num; }; class Concat : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(ConcatCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const ConcatCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Concat::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Concat::Capture(ConcatCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad.resize(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); for (const auto& input : inputs) { ctx->SaveTensorForBackward(input); } ctx->input_num = inputs.size(); return Maybe::Ok(); } Maybe Concat::Apply(const ConcatCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(ctx->input_num); TensorTuple like(ctx->input_num); for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); } if (ctx->input_num == 1) { in_grads->at(0) = out_grads.at(0); } else { const auto& results = JUST(functional::SplitLike(out_grads.at(0), like, ctx->axis)); CHECK_EQ_OR_RETURN(results->size(), ctx->input_num) << Error::RuntimeError() << "The size of results (" << results->size() << ") must match the size of inputs (" << ctx->input_num << ")"; for (int i = 0; i < ctx->input_num; ++i) if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("cat", Concat); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/conv.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ConvolutionNdCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool weight_requires_grad = false; bool has_bias = false; bool bias_requires_grad = false; size_t input_index; size_t weight_index; std::string data_format; std::vector padding_before; std::vector kernel_size; std::vector strides; std::vector dilation_rate; int32_t groups; }; class ConvolutionNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe ConvolutionNd::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_OR_RETURN(inputs.size() == 2 || inputs.size() == 3); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->weight_requires_grad = inputs.at(1)->requires_grad(); if (inputs.size() == 3) { ctx->has_bias = true; ctx->bias_requires_grad = inputs.at(2)->requires_grad(); } if (!ctx->input_requires_grad && !ctx->weight_requires_grad && !ctx->bias_requires_grad) { return Maybe::Ok(); } if (ctx->input_requires_grad) { ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1)); // weight } ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); // input ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->padding_before = JUST(composed_attrs.GetAttr>("padding_before")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->strides = JUST(composed_attrs.GetAttr>("strides")); ctx->dilation_rate = JUST(composed_attrs.GetAttr>("dilation_rate")); ctx->groups = JUST(composed_attrs.GetAttr("groups")); return Maybe::Ok(); } Maybe ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (ctx->has_bias) { in_grads->resize(3); } else { in_grads->resize(2); } size_t num_spatial_dims = ctx->kernel_size.size(); if (ctx->input_requires_grad) { const auto& weight = ctx->SavedTensors().at(ctx->weight_index); const auto& input = ctx->SavedTensors().at(ctx->input_index); in_grads->at(0) = JUST(functional::ConvDataGrad( out_grads.at(0), weight, input, num_spatial_dims, ctx->kernel_size, ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } if (ctx->weight_requires_grad) { const auto& input = ctx->SavedTensors().at(ctx->input_index); in_grads->at(1) = JUST(functional::ConvFilterGrad( out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } if (ctx->bias_requires_grad) { std::vector dim; for (int i = 0; i < 
out_grads.at(0)->shape()->NumAxes(); ++i) { if ((ctx->data_format == "channels_first" && i == 1) || (ctx->data_format == "channels_last" && i == out_grads.at(0)->shape()->NumAxes() - 1)) { continue; } dim.push_back(i); } in_grads->at(2) = JUST(functional::ReduceSum(out_grads.at(0), dim, false, NullOpt)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvolutionNd); REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvolutionNd); REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvolutionNd); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/copy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct CopyCaptureState : public AutoGradCaptureState { std::string device_type; int64_t device_id; }; class Copy : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(CopyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) if (inputs[0]->is_global()) { ctx->device_type = JUST(inputs[0]->parallel_desc())->device_tag(); ctx->device_id = 0; // global tensor only has one local device } else { ctx->device_type = JUST(inputs[0]->device())->type(); ctx->device_id = JUST(inputs[0]->device())->device_id(); } return Maybe::Ok(); } Maybe Apply(const CopyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(1); (*in_grads)[0] = JUST( functional::Copy(out_grads[0], ctx->device_type, ctx->device_id, /*pin_memory=*/false)); return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("copy", Copy); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/ctc_loss.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct CTCLossCaptureState : public AutoGradCaptureState { int64_t max_target_length; int32_t blank; bool zero_infinity; bool requires_grad; }; class CTCLoss : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const CTCLossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; std::shared_ptr grad_op_; }; Maybe CTCLoss::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe CTCLoss::Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->max_target_length = JUST(composed_attrs.GetAttr("max_target_length")); ctx->blank = JUST(composed_attrs.GetAttr("blank")); ctx->zero_infinity = JUST(composed_attrs.GetAttr("zero_infinity")); CHECK_EQ_OR_RETURN(inputs.size(), 4); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->SaveTensorForBackward(outputs.at(0)); // loss ctx->SaveTensorForBackward(outputs.at(1)); // alpha ctx->SaveTensorForBackward(inputs.at(0)); // log_probs ctx->SaveTensorForBackward(inputs.at(1)); // targets ctx->SaveTensorForBackward(inputs.at(2)); // input_lengths ctx->SaveTensorForBackward(inputs.at(3)); // target_lengths return Maybe::Ok(); } Maybe CTCLoss::Apply(const CTCLossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) const auto& grad_out = out_grads.at(0); const auto& loss = ctx->SavedTensors().at(0); const auto& alpha = ctx->SavedTensors().at(1); const auto& log_probs = ctx->SavedTensors().at(2); const auto& targets = ctx->SavedTensors().at(3); const auto& input_lengths = ctx->SavedTensors().at(4); const auto& target_lengths = ctx->SavedTensors().at(5); in_grads->resize(4); in_grads->at(0) = JUST(functional::CtcLossGrad(grad_out, log_probs, targets, input_lengths, target_lengths, loss, alpha, ctx->blank, ctx->zero_infinity, ctx->max_target_length)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("ctc_loss", CTCLoss); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/cublas_fused_mlp.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #if CUDA_VERSION >= 11060 namespace oneflow { namespace one { struct CublasFusedMLPCaptureState : public AutoGradCaptureState { int32_t weight_num = 0; bool skip_final_activation = false; bool x_requires_grad = false; std::vector weights_requires_grad; std::vector biases_requires_grad; }; class CublasFusedMLP : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(CublasFusedMLPCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const CublasFusedMLPCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; protected: AttrMap base_attrs_; }; Maybe CublasFusedMLP::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe CublasFusedMLP::Capture(CublasFusedMLPCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_OR_RETURN(inputs.size() % 2 == 1) << Error::RuntimeError() << "Both weight and bias should be passed together"; int32_t weight_num = (inputs.size() - 1) / 2; ctx->weight_num = weight_num; ctx->x_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); ctx->weights_requires_grad.resize(weight_num); ctx->biases_requires_grad.resize(weight_num); for (int32_t i = 0; i < weight_num; i++) { ctx->weights_requires_grad.at(i) = inputs.at(i + 1)->requires_grad(); // NOLINT ctx->biases_requires_grad.at(i) = inputs.at(i + 1 + weight_num)->requires_grad(); // NOLINT } ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // x. idx_sum:1 for (int32_t i = 0; i < weight_num; i++) { ctx->SaveTensorForBackward(JUST(VectorAt(inputs, i + 1))); // weights. idx_sum:1+w } ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); // final layers output. idx_sum:2+w for (int32_t i = 0; i < weight_num; i++) { ctx->SaveTensorForBackward( JUST(VectorAt(outputs, i + 1))); // cublas aux. need minus 1. idx_sum:2+2w } for (int32_t i = 0; i < weight_num; i++) { ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num))); // hidden. 
} ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->skip_final_activation = JUST(composed_attrs.GetAttr("skip_final_activation")); return Maybe::Ok(); } Maybe CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { int32_t weight_num = ctx->weight_num; in_grads->resize(1 + 2 * weight_num); std::shared_ptr last_bias_dy = JUST(VectorAt(out_grads, 0)); if (!ctx->skip_final_activation) { // step1: use dy and final output to get last layer's relu grad. last_bias_dy = JUST(functional::ReluGrad(JUST(VectorAt(out_grads, 0)), JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num)))); } TensorTuple hiddens(weight_num); TensorTuple weights(weight_num); TensorTuple cublas_auxs(weight_num); TensorTuple dgrad(weight_num); std::shared_ptr x = JUST(VectorAt(ctx->SavedTensors(), 0)); for (int32_t i = 0; i < weight_num; ++i) { weights[i] = JUST(VectorAt(ctx->SavedTensors(), 1 + i)); } for (int32_t i = 0; i < weight_num; ++i) { cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num)); } for (int32_t i = 0; i < weight_num; ++i) { hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num)); } std::shared_ptr cublas_dy = last_bias_dy; // Use Fully Fused MLP Backward. if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", false)) { const std::vector alpha_list(weight_num - 1, 1.0); const auto& fused_mlp_grad = JUST(functional::FusedMLPGrad(cublas_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), weights, cublas_auxs, hiddens, alpha_list)); if (ctx->x_requires_grad) { // dx: JUST(VectorAt(*in_grads, 0)) = fused_mlp_grad->at(0); } for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) { if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx)))) { // dbias JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx + 1)) = fused_mlp_grad->at(1 + hidden_layer_idx); // NOLINT } // dw if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) { JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = fused_mlp_grad->at(1 + weight_num + hidden_layer_idx); } } } else { // step2: use reduce_sum to get last layer's bias grad. std::vector reduce_axes_vec{0}; if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) { JUST(VectorAt(*in_grads, 2 * weight_num)) = JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false, NullOpt)); } for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) { // If it is final layer, we use out_grads[0] as dy. if (hidden_layer_idx != weight_num - 1) { cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1)); } /* Here we use cublas to compute bias + relu + matmul grad. Then use Matmul to compute weight grad. */ const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad( cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)), JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/1.0)); // dgrad dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0); // NOLINT if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) { // dbias JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) = matmul_relu_bias_bgrad->at(1); // NOLINT } // dw if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) { JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul( cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0)); } } // For the first layer, we need to use 2 matmul to get grads. 
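// With last_dy as selected below, dx = matmul(last_dy, W0) and dW0 = matmul(last_dy^T, x),
// which correspond to the two MatMul calls at the end of this branch.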
std::shared_ptr last_dy; if (weight_num != 1) { last_dy = JUST(VectorAt(dgrad, 1)); } else { last_dy = last_bias_dy; } if (ctx->x_requires_grad) { // dx: JUST(VectorAt(*in_grads, 0)) = JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0)); } if (JUST(VectorAt(ctx->weights_requires_grad, 0))) { // dw: JUST(VectorAt(*in_grads, 1)) = JUST( functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0)); } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("cublas_fused_mlp", CublasFusedMLP); } // namespace one } // namespace oneflow #endif // CUDA_VERSION >= 11060 ================================================ FILE: oneflow/core/autograd/gradient_funcs/cum_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct CumCaptureState : public AutoGradCaptureState { bool requires_grad = false; int32_t dim = 0; }; template class CumGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } protected: AttrMap base_attrs_; }; class CumsumGrad : public CumGrad { public: Maybe Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dim = JUST(composed_attrs.GetAttr("dim")); return Maybe::Ok(); } Maybe Apply(const CumCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { std::vector flip_dim(1, ctx->dim); (*in_grads)[0] = JUST( functional::Flip(JUST(functional::Cumsum(JUST(functional::Flip(out_grads[0], flip_dim)), ctx->dim, out_grads[0]->dtype())), flip_dim)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("cumsum", CumsumGrad); class CumProdGrad : public CumGrad { public: Maybe Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dim = JUST(composed_attrs.GetAttr("dim")); ctx->SaveTensorForBackward(outputs.at(0)); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const CumCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) 
const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::CumprodGrad(out_grads.at(0), ctx->SavedTensors().at(0), ctx->SavedTensors().at(1), ctx->dim)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("cumprod", CumProdGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/deconv.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct DeConvolutionNdCaptureState : public AutoGradCaptureState { bool weight_requires_grad = false; bool activation_requires_grad = false; size_t ndims; std::string data_format; std::vector padding_before; std::vector kernel_size; std::vector strides; std::vector dilation_rate; int32_t groups; }; class DeConvolutionNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DeConvolutionNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DeConvolutionNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe DeConvolutionNd::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe DeConvolutionNd::Capture(DeConvolutionNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->activation_requires_grad = inputs.at(0)->requires_grad(); ctx->weight_requires_grad = inputs.at(1)->requires_grad(); if (ctx->activation_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); // weight } if (ctx->weight_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); // x } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->padding_before = JUST(composed_attrs.GetAttr>("padding_before")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->strides = JUST(composed_attrs.GetAttr>("strides")); ctx->dilation_rate = JUST(composed_attrs.GetAttr>("dilation_rate")); ctx->groups = JUST(composed_attrs.GetAttr("groups")); ctx->ndims = ctx->kernel_size.size(); return Maybe::Ok(); } Maybe DeConvolutionNd::Apply(const DeConvolutionNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(2); if (ctx->activation_requires_grad) { const auto& x = ctx->SavedTensors().at(1); std::vector start, stop, step; for (int i = 0; i < 
x->shape()->NumAxes(); i++) { start.emplace_back(0); stop.emplace_back(x->shape()->At(i)); step.emplace_back(1); } const auto& weight = ctx->SavedTensors().at(0); if (ctx->ndims == 1) { std::shared_ptr result = JUST(functional::Conv1d( out_grads.at(0), weight, Optional(), ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false)); in_grads->at(0) = result; } else if (ctx->ndims == 2) { std::shared_ptr result = JUST(functional::Conv2d( out_grads.at(0), weight, Optional(), ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false)); in_grads->at(0) = result; } else if (ctx->ndims == 3) { std::shared_ptr result = JUST(functional::Conv3d( out_grads.at(0), weight, Optional(), ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false)); in_grads->at(0) = result; } else { UNIMPLEMENTED_THEN_RETURN() << "Invalid ndim " << ctx->ndims << " for conv functor"; } } if (ctx->weight_requires_grad) { int idx = ctx->activation_requires_grad; const auto& x = ctx->SavedTensors().at(idx); in_grads->at(1) = JUST(functional::ConvFilterGrad( x, out_grads.at(0), ctx->ndims, ctx->kernel_size, ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("deconv1d", DeConvolutionNd); REGISTER_OP_EXPR_GRAD_FUNCTION("deconv2d", DeConvolutionNd); REGISTER_OP_EXPR_GRAD_FUNCTION("deconv3d", DeConvolutionNd); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/deform_conv.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct DeformConvNdCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool offset_requires_grad = false; bool weight_requires_grad = false; bool mask_requires_grad = false; bool bias_requires_grad = false; int32_t stride_h = 0; int32_t stride_w = 0; int32_t pad_h = 0; int32_t pad_w = 0; int32_t dilation_h = 0; int32_t dilation_w = 0; int32_t groups = 0; int32_t offset_groups = 0; bool use_mask = false; }; class DeformConvNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DeformConvNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DeformConvNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe DeformConvNd::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe DeformConvNd::Capture(DeformConvNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->weight_requires_grad = inputs.at(1)->requires_grad(); ctx->offset_requires_grad = inputs.at(2)->requires_grad(); ctx->mask_requires_grad = inputs.at(3)->requires_grad(); ctx->SaveTensorForBackward(inputs.at(0)); // input ctx->SaveTensorForBackward(inputs.at(1)); // weight ctx->SaveTensorForBackward(inputs.at(2)); // offset ctx->SaveTensorForBackward(inputs.at(3)); // mask ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->use_mask = JUST(composed_attrs.GetAttr("use_mask")); ctx->stride_h = JUST(composed_attrs.GetAttr("stride_h")); ctx->stride_w = JUST(composed_attrs.GetAttr("stride_w")); ctx->pad_h = JUST(composed_attrs.GetAttr("pad_h")); ctx->pad_w = JUST(composed_attrs.GetAttr("pad_w")); ctx->dilation_h = JUST(composed_attrs.GetAttr("dilation_h")); ctx->dilation_w = JUST(composed_attrs.GetAttr("dilation_w")); ctx->groups = JUST(composed_attrs.GetAttr("groups")); ctx->offset_groups = JUST(composed_attrs.GetAttr("offset_groups")); return Maybe::Ok(); } Maybe DeformConvNd::Apply(const DeformConvNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(5); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& input = ctx->SavedTensors().at(0); const auto& weight = ctx->SavedTensors().at(1); const auto& offset = ctx->SavedTensors().at(2); const auto& mask = ctx->SavedTensors().at(3); const auto& output_grad = out_grads.at(0); if (ctx->input_requires_grad || ctx->offset_requires_grad || ctx->mask_requires_grad) { std::shared_ptr grads_tuple; if (ctx->use_mask) { grads_tuple = JUST(functional::DeformConv2dInputGrad( output_grad, input, weight, offset, mask, ctx->stride_h, ctx->stride_w, ctx->pad_h, ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups, ctx->use_mask)); } else { grads_tuple = JUST(functional::DeformConv2dInputGrad( output_grad, input, weight, offset, NullOpt, ctx->stride_h, ctx->stride_w, ctx->pad_h, ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups, ctx->use_mask)); } if (ctx->input_requires_grad) { in_grads->at(0) = grads_tuple->at(0); // input_grad } if 
(ctx->offset_requires_grad) { in_grads->at(2) = grads_tuple->at(1); // offset_grad } if (ctx->use_mask && ctx->mask_requires_grad) { in_grads->at(3) = grads_tuple->at(2); // mask_grad } } if (ctx->weight_requires_grad) { // weight_grad in_grads->at(1) = JUST(functional::DeformConv2dParamGrad( output_grad, input, weight, offset, mask, ctx->stride_h, ctx->stride_w, ctx->pad_h, ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups, ctx->use_mask)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("deform_conv2d", DeformConvNd); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/depand.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct DependCaptureState : public AutoGradCaptureState { bool in_requires_grad = false; bool depend_tensor_requires_grad = false; Shape depend_tensor_shape; Symbol depend_tensor_dtype; Maybe> depend_tensor_device; }; class Depend : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(DependCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->in_requires_grad = inputs.at(0)->requires_grad(); ctx->depend_tensor_requires_grad = inputs.at(1)->requires_grad(); if (ctx->depend_tensor_requires_grad) { ctx->depend_tensor_shape = *(inputs.at(1)->shape()); ctx->depend_tensor_dtype = inputs.at(1)->dtype(); ctx->depend_tensor_device = inputs.at(1)->device(); } return Maybe::Ok(); } Maybe Apply(const DependCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->in_requires_grad) { in_grads->at(0) = out_grads.at(0); } if (ctx->depend_tensor_requires_grad) { in_grads->at(1) = JUST(functional::Constant(ctx->depend_tensor_shape, Scalar(0), ctx->depend_tensor_dtype, JUST(ctx->depend_tensor_device))); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("depend", Depend); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/det.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/functional/functional_api.yaml.h" namespace oneflow { namespace one { struct DetCaptureState : public AutoGradCaptureState { bool requires_grad = false; size_t input_index = 0; size_t output_index = 0; }; class Det : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(DetCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); if (ctx->requires_grad) { ctx->input_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); ctx->output_index = ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); } return Maybe::Ok(); } Maybe Apply(const DetCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (ctx->requires_grad) { const auto& output = JUST(VectorAt(ctx->SavedTensors(), ctx->output_index)); const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index)); const auto& dy = JUST(VectorAt(out_grads, 0)); const auto& dy_unsqueeze = JUST(functional::UnsqueezeMultiple(dy, {-2, -1}, dy->ndim() + 2)); const auto& output_unsqueeze = JUST(functional::UnsqueezeMultiple(output, {-2, -1}, output->ndim() + 2)); JUST(VectorAt(*in_grads, 0)) = JUST(functional::Transpose2dim( JUST(functional::Mul( dy_unsqueeze, JUST(functional::Mul(JUST(functional::Inv(input)), output_unsqueeze)))), -2, -1)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("det", Det); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/diag.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct DiagCaptureState : public AutoGradCaptureState { bool requires_grad; int32_t diagonal; }; class Diag : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DiagCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DiagCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Diag::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Diag::Capture(DiagCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->diagonal = JUST(composed_attrs.GetAttr("diagonal")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Diag::Apply(const DiagCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::DiagGrad(out_grads.at(0), x, ctx->diagonal)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("diag", Diag); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/diagonal.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct DiagonalInterpState : public AutoGradCaptureState { bool requires_grad = false; int32_t offset = 0; }; class Diagonal : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DiagonalInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DiagonalInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Diagonal::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Diagonal::Capture(DiagonalInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->offset = JUST(composed_attrs.GetAttr("offset")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Diagonal::Apply(const DiagonalInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::DiagonalGrad(out_grads.at(0), x, ctx->offset)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("diagonal", Diagonal); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/dim_gather.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct DimGatherCaptureState : public AutoGradCaptureState { int32_t dim; bool requires_grad; }; class DimGather : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DimGatherCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DimGatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; std::shared_ptr bw_dim_gather_op_; }; Maybe DimGather::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe DimGather::Capture(DimGatherCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(1)); ctx->SaveTensorForBackward(inputs.at(0)); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dim = JUST(composed_attrs.GetAttr("dim")); return Maybe::Ok(); } Maybe DimGather::Apply(const DimGatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& index = ctx->SavedTensors().at(0); const std::shared_ptr& like = ctx->SavedTensors().at(1); in_grads->at(0) = JUST(functional::DimScatterAddLike(like, ctx->dim, index, out_grads.at(0))); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("dim_gather", DimGather); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/dim_scatter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
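
The DimGather backward above relies on the standard adjoint relationship between gather and scatter-add. Informally, with y_j = x_{index_j} along dimension d (notation mine, not from the file):

    \frac{\partial L}{\partial x_i} \;=\; \sum_{j\,:\,\mathrm{index}_j = i} \frac{\partial L}{\partial y_j}

so the input gradient is a scatter-add of out_grads into a zero tensor shaped like the input, which is what DimScatterAddLike does with the saved "like" tensor.
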
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct DimScatterCaptureState : public AutoGradCaptureState { int32_t dim; bool input_requires_grad; bool src_requires_grad; }; enum class ScatterType { kUpdate, kAdd, kMultiply }; template class DimScatter : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; template Maybe DimScatter::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } template Maybe DimScatter::Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->src_requires_grad = inputs.at(2)->requires_grad(); if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(1)); // index saved if (T == ScatterType::kMultiply) { ctx->SaveTensorForBackward(inputs.at(2)); // src saved } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dim = JUST(composed_attrs.GetAttr("dim")); return Maybe::Ok(); } template Maybe DimScatter::Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(3); const std::shared_ptr& index = ctx->SavedTensors().at(0); if (ctx->src_requires_grad) { in_grads->at(2) = JUST(functional::DimGather(out_grads.at(0), ctx->dim, index, false)); } if (ctx->input_requires_grad) { if (T == ScatterType::kAdd) { in_grads->at(0) = out_grads.at(0); } if (T == ScatterType::kUpdate) { in_grads->at(0) = JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index, 0.0f, /*inplace*/ false)); } if (T == ScatterType::kMultiply) { const std::shared_ptr& src = ctx->SavedTensors().at(1); in_grads->at(0) = JUST(functional::DimScatterMul(out_grads.at(0), ctx->dim, index, src, /*inplace*/ false)); } } return Maybe::Ok(); } class DimScatterUpdateScalar : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe DimScatterUpdateScalar::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe DimScatterUpdateScalar::Capture(DimScatterCaptureState* ctx, const TensorTuple& 
inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs.at(0)->requires_grad(); if (!ctx->input_requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(1)); // index saved ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dim = JUST(composed_attrs.GetAttr("dim")); return Maybe::Ok(); } Maybe DimScatterUpdateScalar::Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->input_requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& index = ctx->SavedTensors().at(0); in_grads->resize(2); in_grads->at(0) = JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index, 0.0f, /*inplace*/ false)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update", DimScatter); REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_add", DimScatter); REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_mul", DimScatter); REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update_scalar", DimScatterUpdateScalar); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/dot.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
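
A summary of how the three scatter variants above differentiate (my paraphrase of the code, same notation as the gather note). The src gradient is always a gather of dy along dim at index; the input gradient depends on the variant:

    \text{update:}\quad \partial L/\partial \mathrm{input} = dy \odot (1 - m) \qquad\text{(scattered positions zeroed)}
    \text{add:}\quad\;\;\; \partial L/\partial \mathrm{input} = dy
    \text{mul:}\quad\;\;\; \partial L/\partial \mathrm{input} = dy \odot \mathrm{scatter}(\mathrm{src}) \qquad\text{(scattered positions multiplied by src)}

where m is the 0/1 mask of positions written by the scatter; DimScatterUpdateScalar with value 0 implements the first case and DimScatterMul the last.
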
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct DotCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool y_requires_grad = false;
  size_t x_offset = 0;
  size_t y_offset = 0;
};

class DotGrad : public OpExprGradFunction<DotCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(DotCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    if (ctx->x_requires_grad) { ctx->x_offset = ctx->SaveTensorForBackward(inputs.at(1)); }
    ctx->y_requires_grad = inputs.at(1)->requires_grad();
    if (ctx->y_requires_grad) { ctx->y_offset = ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const DotCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(2);
    if (ctx->x_requires_grad) {
      const auto& x = ctx->SavedTensors().at(ctx->x_offset);
      const auto& results = JUST(functional::Mul(x, out_grads.at(0)));
      in_grads->at(0) = results;
    }
    if (ctx->y_requires_grad) {
      const auto& y = ctx->SavedTensors().at(ctx->y_offset);
      const auto& results = JUST(functional::Mul(y, out_grads.at(0)));
      in_grads->at(1) = results;
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("dot", DotGrad);

}  // namespace one
}  // namespace oneflow


================================================
FILE: oneflow/core/autograd/gradient_funcs/dropout.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
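
DotGrad above is the product rule applied to an inner product; as a consistency check (standard result, not from the file):

    s = \sum_i x_i y_i
    \quad\Longrightarrow\quad
    \frac{\partial L}{\partial x_i} = \bar{s}\, y_i, \qquad
    \frac{\partial L}{\partial y_i} = \bar{s}\, x_i

which is why Capture saves the opposite operand for each side and Apply multiplies it by out_grads.at(0).
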
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct DropoutCaptureState : public AutoGradCaptureState { bool requires_grad = true; bool has_addend = false; float rate = 0.0; }; class Dropout : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DropoutCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DropoutCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Dropout::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Dropout::Capture(DropoutCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->rate = JUST(composed_attrs.GetAttr("rate")); if (inputs.size() == 1) { ctx->has_addend = false; } else if (inputs.size() == 2) { ctx->has_addend = true; } else { UNIMPLEMENTED(); } ctx->SaveTensorForBackward(outputs.at(1)); // output mask return Maybe::Ok(); } Maybe Dropout::Apply(const DropoutCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 2); // Output has y and mask. float scale = 0.0f; // When dropout rate = 1.0, we set scale as zero. if (ctx->rate < 1.0f) { scale = 1.0f / (1.0f - ctx->rate); } const std::shared_ptr& mask = ctx->SavedTensors().at(0); if (ctx->has_addend) { in_grads->resize(2); in_grads->at(0) = JUST(functional::DropoutGrad(out_grads.at(0), mask, scale)); in_grads->at(1) = out_grads.at(0); return Maybe::Ok(); } else { in_grads->resize(1); in_grads->at(0) = JUST(functional::DropoutGrad(out_grads.at(0), mask, scale)); return Maybe::Ok(); } } REGISTER_OP_EXPR_GRAD_FUNCTION("dropout", Dropout); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/eager_ccl_broadcast.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
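
The scale handling in Dropout::Apply matches inverted dropout; a brief sketch of the convention (standard formulation, p is the drop rate):

    y = \frac{x \odot m}{1 - p}, \quad m_i \in \{0, 1\}
    \quad\Longrightarrow\quad
    \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \odot \frac{m}{1 - p}

with the scale forced to 0 when p = 1 so the backward pass avoids the division by zero (every element was dropped anyway).
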
*/ #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" namespace oneflow { namespace one { namespace { Maybe EagerCclReduce(Symbol parallel_desc, int64_t root) { return one::OpBuilder("eager_ccl_reduce", *JUST(UniqueStr("eager_ccl_reduce"))) .Input("in") .Output("out") .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Attr("root", root) .Build(); } Maybe FindOrCreatEagerCclReduceOpExpr(Symbol parallel_desc, int64_t root) { thread_local HashMap, int64_t>, std::shared_ptr> parallel_desc_and_root_device2eager_nccl_reduce; const auto& key = std::make_pair(parallel_desc, root); auto iter = parallel_desc_and_root_device2eager_nccl_reduce.find(key); if (iter == parallel_desc_and_root_device2eager_nccl_reduce.end()) { std::shared_ptr op_expr = JUST(EagerCclReduce(parallel_desc, root)); iter = parallel_desc_and_root_device2eager_nccl_reduce.emplace(key, op_expr).first; } return iter->second; } } // namespace struct EagerCclBroadcastCaptureState : public AutoGradCaptureState { // NOLINT Symbol parallel_desc; int64_t root; }; class EagerCclBroadcast : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); return Maybe::Ok(); } Maybe Capture(EagerCclBroadcastCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const OpExprInterpContext& interp_ctx) const override { ctx->root = JUST(interp_ctx.attrs.GetAttr("root")); ctx->parallel_desc = JUST(interp_ctx.parallel_desc); return Maybe::Ok(); } Maybe Apply(const EagerCclBroadcastCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& grad_op = JUST(FindOrCreatEagerCclReduceOpExpr(ctx->parallel_desc, ctx->root)); in_grads->resize(1); in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op, {out_grads.at(0)})); return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("eager_ccl_broadcast", EagerCclBroadcast); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ElementwiseXimumCaptureState : public AutoGradCaptureState { bool x_requires_grad; bool y_requires_grad; }; class ElementwiseXimumOp : public OpExprGradFunction { public: Maybe Capture(ElementwiseXimumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->y_requires_grad = inputs.at(1)->requires_grad(); ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(inputs.at(1)); return Maybe::Ok(); } Maybe Apply(const ElementwiseXimumCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe::Ok(); } in_grads->resize(2); const std::shared_ptr& x = ctx->SavedTensors().at(0); const std::shared_ptr& y = ctx->SavedTensors().at(1); if (ctx->x_requires_grad || ctx->y_requires_grad) { const auto& grads = JUST(grad_functor(out_grads.at(0), x, y)); if (ctx->x_requires_grad) { in_grads->at(0) = grads->at(0); } if (ctx->y_requires_grad) { in_grads->at(1) = grads->at(1); } } return Maybe::Ok(); } protected: std::function(const std::shared_ptr&, const std::shared_ptr&, const std::shared_ptr&)> grad_functor; }; class ElementwiseMinimum : public ElementwiseXimumOp { public: Maybe Init(const OpExpr& op) override { grad_functor = functional::ElementwiseMinGrad; return Maybe::Ok(); } }; class ElementwiseMaximum : public ElementwiseXimumOp { public: Maybe Init(const OpExpr& op) override { grad_functor = functional::ElementwiseMaxGrad; return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("elementwise_minimum", ElementwiseMinimum); REGISTER_OP_EXPR_GRAD_FUNCTION("elementwise_maximum", ElementwiseMaximum); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/embedding.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
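
ElementwiseMinGrad / ElementwiseMaxGrad used above route the gradient to whichever operand attains the extremum; in subgradient form (standard convention, with ties resolved to one side by the kernel):

    z = \max(x, y): \quad
    \frac{\partial z}{\partial x} = \mathbb{1}[x \ge y], \qquad
    \frac{\partial z}{\partial y} = \mathbb{1}[x < y]

and symmetrically for the minimum with the inequalities flipped.
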
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct EmbeddingCaptureState : public AutoGradCaptureState { int64_t padding_idx = -1; bool scale_grad_by_freq = false; bool requires_grad = false; }; class Embedding : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(EmbeddingCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const EmbeddingCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Embedding::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "Forward op must be not null"; base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Embedding::Capture(EmbeddingCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0))); ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1))); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->padding_idx = JUST(composed_attrs.GetAttr("padding_idx")); ctx->scale_grad_by_freq = JUST(composed_attrs.GetAttr("scale_grad_by_freq")); return Maybe::Ok(); } Maybe Embedding::Apply(const EmbeddingCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) if (!ctx->requires_grad) { return Maybe::Ok(); } in_grads->resize(ctx->SavedTensors().size()); const auto& weight = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0)); const auto& indices = JUST(oneflow::VectorAt(ctx->SavedTensors(), 1)); int64_t padding_idx = ctx->padding_idx; bool scale_grad_by_freq = ctx->scale_grad_by_freq; JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::EmbeddingGrad( JUST(oneflow::VectorAt(out_grads, 0)), weight, indices, padding_idx, scale_grad_by_freq)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("embedding", Embedding); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/expand.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
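
EmbeddingGrad above reduces to a row-wise scatter-add into the weight gradient. Informally (my notation), with output rows y_j = W[\mathrm{idx}_j, :]:

    \frac{\partial L}{\partial W[r, :]} \;=\; \sum_{j\,:\,\mathrm{idx}_j = r} \frac{\partial L}{\partial y_j}

with the padding_idx row zeroed and, when scale_grad_by_freq is set, each row additionally divided by the number of occurrences of r in the index tensor.
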
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ExpandCaptureState : public AutoGradCaptureState { bool requires_grad; int32_t lpad; bool keep_dims; std::vector reduce_dims; }; class Expand : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(ExpandCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe Expand::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Expand::Capture(ExpandCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs[0]->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } const Shape& in_shape = *inputs[0]->shape(); const Shape& expand_shape = *outputs[0]->shape(); ctx->lpad = expand_shape.size() - in_shape.size(); ctx->keep_dims = (in_shape.size() > 0); ctx->reduce_dims.reserve(expand_shape.size()); if (ctx->keep_dims) { for (size_t i = 0; i < expand_shape.size(); ++i) { const auto& t_dim = expand_shape[i]; const auto& dim = i < ctx->lpad ? 1 : in_shape[i - ctx->lpad]; if (dim != t_dim) { ctx->reduce_dims.push_back(i); } } } else { for (int32_t axis = 0; axis < expand_shape.size(); ++axis) { ctx->reduce_dims.push_back(axis); } } return Maybe::Ok(); } Maybe Expand::Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); in_grads->at(0) = out_grads[0]; if (ctx->reduce_dims.size() > 0) { in_grads->at(0) = JUST(functional::ReduceSum(in_grads->at(0), ctx->reduce_dims, ctx->keep_dims, NullOpt)); } if (ctx->lpad > 0 && ctx->keep_dims) { in_grads->at(0) = JUST(functional::Flatten(in_grads->at(0), 0, ctx->lpad)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("expand", Expand); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fake_quantization.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
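
Expand::Apply undoes broadcasting by summing over the expanded axes:

    \frac{\partial L}{\partial x} \;=\; \sum_{\text{expanded axes}} \frac{\partial L}{\partial y}

A small worked example of the capture logic (shapes illustrative only): expanding x of shape (3, 1) to (2, 3, 4) gives lpad = 1, keep_dims = true and reduce_dims = {0, 2}; ReduceSum with keepdims then yields a gradient of shape (1, 3, 1), and Flatten(grad, 0, lpad) merges the padded leading axis away, recovering the original (3, 1).
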
*/ #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { struct FakeQuantizationCaptureState : public AutoGradCaptureState { bool requires_grad; }; class FakeQuantization : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FakeQuantizationCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 3); ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const FakeQuantizationCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(3); if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fake_quantization", FakeQuantization); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fft.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" namespace oneflow { namespace one { struct FftR2CCaptureState : public AutoGradCaptureState { bool requires_grad = false; bool onesided = false; std::vector dims; DimVector input_shape_vec; int32_t norm_mode = 0; }; class FftR2C : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FftR2CCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`"; ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->onesided = JUST(attrs.GetAttr("onesided")); ctx->dims = JUST(attrs.GetAttr>("dims")); ctx->norm_mode = JUST(attrs.GetAttr("norm_mode")); ctx->input_shape_vec = JUST(oneflow::VectorAt(inputs, 0))->shape()->dim_vec(); return Maybe::Ok(); } Maybe Apply(const FftR2CCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: assert `out_grads.size() == 1`"; if (!ctx->requires_grad) { return Maybe::Ok(); } in_grads->resize(1); if (!ctx->onesided) { auto complex_grad = JUST(functional::FftC2C(JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode, /*forward=*/false, /*normalized=*/false)); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Real(complex_grad)); } else { std::vector fft_dims = ctx->dims; std::vector fft_shapes(fft_dims.size(), 0); FOR_RANGE(size_t, i, 0, 
fft_dims.size()) { fft_shapes[i] = ctx->input_shape_vec[fft_dims[i]]; } // fill the last dim bool must_copy = false; auto x_sizes = JUST(oneflow::VectorAt(out_grads, 0))->shape()->dim_vec(); std::vector pad_amount(x_sizes.size() * 2, 0); int64_t last_dim = ctx->dims.back(); if (x_sizes[last_dim] < ctx->input_shape_vec[last_dim]) { must_copy = true; auto pad_idx = pad_amount.size() - 2 * last_dim - 1; pad_amount[pad_idx] = ctx->input_shape_vec[last_dim] - x_sizes[last_dim]; } auto complex_full_grad = must_copy ? JUST(functional::ConstantPad(JUST(oneflow::VectorAt(out_grads, 0)), pad_amount, 0)) : JUST(oneflow::VectorAt(out_grads, 0)); complex_full_grad = JUST(functional::FftC2C(complex_full_grad, NullOpt, ctx->dims, ctx->norm_mode, /*forward=*/false, /*normalized=*/false)); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Real(complex_full_grad)); } return Maybe::Ok(); } }; struct FftC2CCaptureState : public AutoGradCaptureState { bool requires_grad = false; bool forward = false; std::vector dims; int32_t norm_mode = 0; }; class FftC2C : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FftC2CCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`"; ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->forward = JUST(attrs.GetAttr("forward")); ctx->dims = JUST(attrs.GetAttr>("dims")); ctx->norm_mode = JUST(attrs.GetAttr("norm_mode")); return Maybe::Ok(); } Maybe Apply(const FftC2CCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: assert `out_grads.size() == 1`"; if (!ctx->requires_grad) { return Maybe::Ok(); } in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::FftC2C( JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode, /*forward=*/!(ctx->forward), /*normalized=*/false)); return Maybe::Ok(); } }; struct FftC2RCaptureState : public AutoGradCaptureState { bool requires_grad = false; std::vector dims; int32_t norm_mode = 0; int64_t last_dim_size = 1; DimVector input_shape_vec; }; class FftC2R : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FftC2RCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`"; ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->dims = JUST(attrs.GetAttr>("dims")); ctx->norm_mode = JUST(attrs.GetAttr("norm_mode")); ctx->last_dim_size = JUST(attrs.GetAttr("last_dim_size")); ctx->input_shape_vec = JUST(oneflow::VectorAt(inputs, 0))->shape()->dim_vec(); return Maybe::Ok(); } Maybe Apply(const FftC2RCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: out_grads.size() == 1"; if (!ctx->requires_grad) { return Maybe::Ok(); } in_grads->resize(1); // NOTE: set `forward` True to prevent conjugating result auto complex_grad = JUST(functional::FftR2C( JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode, /*onesided=*/true, /*forward=*/true, /*normalized=*/false)); // no 
need conj Shape input_shape(ctx->input_shape_vec); int64_t last_dim = ctx->dims.back(); auto double_length = JUST(oneflow::VectorAt(out_grads, 0))->dim(last_dim) - complex_grad->dim(last_dim); auto in_grad = complex_grad; // Mul by 2, and slice if (double_length > 0) { in_grad = JUST(functional::Narrow(complex_grad, last_dim, 1, double_length)); // will change shape of in_grad in_grad = JUST(functional::ScalarMul(in_grad, 2, /*inplace=*/true)); } std::vector slice_st(input_shape.size(), 0); std::vector slice_end(input_shape.begin(), input_shape.end()); std::vector slice_step(input_shape.size(), 1); auto sliced_tensor = JUST(functional::Slice(complex_grad, slice_st, slice_end, slice_step, false)); JUST(oneflow::VectorAt(*in_grads, 0)) = sliced_tensor; return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fft_r2c", FftR2C); REGISTER_OP_EXPR_GRAD_FUNCTION("fft_c2c", FftC2C); REGISTER_OP_EXPR_GRAD_FUNCTION("fft_c2r", FftC2R); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fill.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/just.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" namespace oneflow { namespace one { struct FillCaptureState : public AutoGradCaptureState { bool in_requires_grad = false; bool value_requires_grad = false; }; class Fill : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FillCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Fill::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Fill::Capture(FillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->in_requires_grad = inputs[0]->requires_grad(); return Maybe::Ok(); } Maybe Fill::Apply(const FillCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads.size() must be equal to 1."; in_grads->resize(1); if (ctx->in_requires_grad) { (*in_grads)[0] = JUST(functional::Fill(out_grads[0], 0)); } return Maybe::Ok(); } class FillTensor : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FillCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* 
in_grads) const override; private: AttrMap base_attrs_; }; Maybe FillTensor::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FillTensor::Capture(FillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->in_requires_grad = inputs[0]->requires_grad(); ctx->value_requires_grad = inputs[1]->requires_grad(); return Maybe::Ok(); } Maybe FillTensor::Apply(const FillCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads.size() must be equal to 1."; in_grads->resize(2); if (ctx->value_requires_grad) { int32_t num_axes = out_grads[0]->shape()->NumAxes(); std::vector axes_vec(num_axes); std::iota(axes_vec.begin(), axes_vec.end(), 0); (*in_grads)[1] = JUST(functional::ReduceSum(out_grads[0], axes_vec, /*keepdims=*/false, NullOpt)); } if (ctx->in_requires_grad) { (*in_grads)[0] = JUST(functional::Fill(out_grads[0], 0)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fill_", Fill); REGISTER_OP_EXPR_GRAD_FUNCTION("fill_tensor_", FillTensor); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/flatten.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FlattenCaptureState : public AutoGradCaptureState { bool requires_grad; }; class Flatten : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FlattenCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FlattenCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe Flatten::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); return Maybe::Ok(); } Maybe Flatten::Capture(FlattenCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Flatten::Apply(const FlattenCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& like = ctx->SavedTensors().at(0); in_grads->resize(1); in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("flatten", Flatten); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/flip.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FlipCaptureState : public AutoGradCaptureState { bool requires_grad; std::vector dims; }; class Flip : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FlipCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FlipCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Flip::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Flip::Capture(FlipCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dims = JUST(composed_attrs.GetAttr>("dims")); return Maybe::Ok(); } Maybe Flip::Apply(const FlipCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); if (ctx->requires_grad) { (*in_grads)[0] = JUST(functional::Flip(out_grads[0], ctx->dims)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("flip", Flip); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fold.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
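
Flip is its own adjoint, which is why Flip::Apply simply flips the incoming gradient over the same dims (standard result):

    y = \mathrm{flip}_d(x)
    \quad\Longrightarrow\quad
    \frac{\partial L}{\partial x} = \mathrm{flip}_d\!\left(\frac{\partial L}{\partial y}\right)
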
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FoldInterpState : public AutoGradCaptureState { bool requires_grad = true; std::string data_format = "channels_first"; std::vector kernel_size; std::vector dilation_rate; std::vector padding; std::vector strides; }; class Fold : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FoldInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FoldInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Fold::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Fold::Capture(FoldInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->dilation_rate = JUST(composed_attrs.GetAttr>("dilation_rate")); ctx->padding = JUST(composed_attrs.GetAttr>("padding")); ctx->strides = JUST(composed_attrs.GetAttr>("strides")); return Maybe::Ok(); } Maybe Fold::Apply(const FoldInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); in_grads->at(0) = JUST(functional::Unfold(out_grads.at(0), ctx->kernel_size, ctx->dilation_rate, ctx->padding, ctx->strides, ctx->data_format)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fold", Fold); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_bias_add_dropout.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
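
Fold and Unfold are adjoint linear maps (fold sums overlapping patches back into the image), so the backward of fold is unfold with the same kernel_size / dilation_rate / padding / strides, as used above:

    \langle \mathrm{fold}(P),\, X \rangle = \langle P,\, \mathrm{unfold}(X) \rangle
    \quad\Longrightarrow\quad
    \frac{\partial L}{\partial P} = \mathrm{unfold}\!\left(\frac{\partial L}{\partial Y}\right)
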
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedBiasAddDropoutInterpState : public AutoGradCaptureState { bool input_requires_grad = true; bool bias_requires_grad = true; int32_t axis = 1; float scale = 1.0; }; class FusedBiasAddDropout : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedBiasAddDropoutInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedBiasAddDropoutInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe FusedBiasAddDropout::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedBiasAddDropout::Capture(FusedBiasAddDropoutInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 3); ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input ctx->bias_requires_grad = inputs.at(1)->requires_grad(); // bias if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->scale = JUST(composed_attrs.GetAttr("scale")); ctx->axis = JUST(composed_attrs.GetAttr("axis")); ctx->SaveTensorForBackward(inputs.at(2)); return Maybe::Ok(); } Maybe FusedBiasAddDropout::Apply(const FusedBiasAddDropoutInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe::Ok(); } // mask have no grad(reqiures_grad=False), but still take a place in in_grads in_grads->resize(3); const std::shared_ptr& mask = ctx->SavedTensors().at(0); const std::shared_ptr& dropout_grad = JUST(functional::DropoutGrad(out_grads.at(0), mask, ctx->scale)); if (ctx->input_requires_grad) { in_grads->at(0) = dropout_grad; } const int64_t num_axes = out_grads.at(0)->shape()->NumAxes(); if (ctx->bias_requires_grad) { std::vector reduce_axes_vec; reduce_axes_vec.reserve(num_axes); for (int i = 0; i < num_axes; ++i) { if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); } } in_grads->at(1) = JUST(functional::ReduceSum(dropout_grad, reduce_axes_vec, false, NullOpt)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_bias_add_mask_scale", FusedBiasAddDropout); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_bias_add_gelu.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
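
The bias gradient above is the dropout gradient reduced over every axis except the bias axis; schematically (my notation, b broadcast along `axis`):

    y = \mathrm{dropout}(x + b)
    \quad\Longrightarrow\quad
    \frac{\partial L}{\partial x} = \mathrm{DropoutGrad}(dy, m, s), \qquad
    \frac{\partial L}{\partial b} = \sum_{i \neq \mathrm{axis}} \frac{\partial L}{\partial x}

where the sum runs over all tensor axes other than `axis`, matching the reduce_axes_vec construction.
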
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedBiasAddGeluInterpState : public AutoGradCaptureState { bool input_requires_grad = true; bool bias_requires_grad = true; int32_t axis = 1; }; class FusedBiasAddGelu : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(FusedBiasAddGeluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->bias_requires_grad = inputs.at(1)->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); if (ctx->input_requires_grad || ctx->bias_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } Maybe Apply(const FusedBiasAddGeluInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); const int64_t num_axes = out_grads.at(0)->shape()->NumAxes(); in_grads->resize(2); const auto& a = ctx->SavedTensors().at(0); const auto& b = ctx->SavedTensors().at(1); const std::shared_ptr& fused_bias_add_gelu_grad = JUST(functional::FusedBiasAddGeluGrad(a, b, out_grads.at(0), ctx->axis)); if (ctx->bias_requires_grad) { std::vector reduce_axes_vec; reduce_axes_vec.reserve(num_axes); for (int i = 0; i < num_axes; ++i) { if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); } } in_grads->at(1) = JUST(functional::ReduceSum(fused_bias_add_gelu_grad, reduce_axes_vec, false, NullOpt)); } if (ctx->input_requires_grad) { in_grads->at(0) = fused_bias_add_gelu_grad; } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_bias_add_gelu", FusedBiasAddGelu); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_bias_add_scale_mask_softmax_dropout.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
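
FusedBiasAddGeluGrad differentiates GELU at x + b; for reference, the exact GELU derivative (standard formula — the fused kernel may use the tanh approximation instead):

    \mathrm{GELU}(t) = t\,\Phi(t)
    \quad\Longrightarrow\quad
    \mathrm{GELU}'(t) = \Phi(t) + t\,\varphi(t)

with \Phi and \varphi the standard normal CDF and PDF; as in the dropout variant, the bias gradient is the elementwise gradient reduced over all axes except `axis`.
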
*/ #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { struct FusedBiasAddScaleMaskSoftmaxDropoutCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool bias_requires_grad = false; bool bias_broadcast = false; int softmax_y_index = -1; int bias_index = -1; int mask_index = -1; int dropout_mask_index = -1; float scale = 1.0; float dropout_scale = 1.0; }; class FusedBiasAddScaleMaskSoftmaxDropoutGradFunction : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(FusedBiasAddScaleMaskSoftmaxDropoutCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(outputs.size(), 2); // (y, softmax_y) CHECK_EQ_OR_RETURN(inputs.size(), 4); // (x, bias, mask, dropout_mask) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->bias_requires_grad = inputs.at(1)->requires_grad(); if (!ctx->x_requires_grad && !ctx->bias_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->scale = JUST(composed_attrs.GetAttr("scale_value")); ctx->dropout_scale = JUST(composed_attrs.GetAttr("dropout_scale_value")); if (ctx->x_requires_grad) { ctx->mask_index = ctx->SaveTensorForBackward(inputs.at(2)); // mask ctx->dropout_mask_index = ctx->SaveTensorForBackward(inputs.at(3)); // dropout_mask ctx->softmax_y_index = ctx->SaveTensorForBackward(outputs.at(1)); // softmax_y } if (ctx->bias_requires_grad) { ctx->bias_broadcast = (inputs.at(0)->shape() != inputs.at(1)->shape()); if (ctx->bias_broadcast) { ctx->bias_index = ctx->SaveTensorForBackward(inputs.at(1)); // bias } } return Maybe::Ok(); } Maybe Apply(const FusedBiasAddScaleMaskSoftmaxDropoutCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->x_requires_grad && !ctx->bias_requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 2); // (dy, d_softmax_y) in_grads->resize(4); // (x, bias, mask, dropout_mask) const auto& saved_tensors = ctx->SavedTensors(); const auto& dy = out_grads.at(0); CHECK_GE_OR_RETURN(saved_tensors.size(), 3); // (mask, dropout_mask, softmax_y, [bias]) if (ctx->x_requires_grad || ctx->bias_requires_grad) { const auto& mask = saved_tensors.at(ctx->mask_index); const auto& dropout_mask = saved_tensors.at(ctx->dropout_mask_index); const auto& softmax_y = saved_tensors.at(ctx->softmax_y_index); in_grads->at(0) = JUST(functional::FusedScaleMaskSoftmaxDropoutGrad( softmax_y, dy, mask, dropout_mask, ctx->scale, ctx->dropout_scale)); } if (ctx->bias_requires_grad) { if (ctx->bias_broadcast) { const auto& bias = saved_tensors.at(ctx->bias_index); in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(in_grads->at(0), bias)); } else { in_grads->at(1) = in_grads->at(0); } } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_bias_add_scale_mask_softmax_dropout", FusedBiasAddScaleMaskSoftmaxDropoutGradFunction); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_center.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { const int32_t INPUT_LEN = 8; struct FusedCenterCaptureState : public AutoGradCaptureState { std::vector requires_grad; }; class FusedCenterGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FusedCenterCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN); CHECK_EQ_OR_RETURN(outputs.size(), 1); for (int i = 0; i < INPUT_LEN; i++) { ctx->requires_grad.push_back(inputs.at(i)->requires_grad()); } for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); } return Maybe::Ok(); } Maybe Apply(const FusedCenterCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& rho2_diff = out_grads.at(0); const auto& b1_x1 = ctx->SavedTensors().at(0); const auto& b1_x2 = ctx->SavedTensors().at(1); const auto& b2_x1 = ctx->SavedTensors().at(2); const auto& b2_x2 = ctx->SavedTensors().at(3); const auto& b1_y1 = ctx->SavedTensors().at(4); const auto& b1_y2 = ctx->SavedTensors().at(5); const auto& b2_y1 = ctx->SavedTensors().at(6); const auto& b2_y2 = ctx->SavedTensors().at(7); in_grads->resize(INPUT_LEN); auto result = JUST(functional::FusedCenterGrad(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, rho2_diff)); CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN); for (int i = 0; i < INPUT_LEN; i++) { if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_center_dist", FusedCenterGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_cross_interaction.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
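The (x, weight, x0, bias) signature and the vector/matrix interaction_mode of the next op line up with the cross layer of Deep & Cross Networks. The kernel itself is not shown in this file, so the following is only the textbook form, for orientation:

\[ \text{vector mode (DCN-v1):}\quad y = x_0\,(x^{\top}w) + b + x, \qquad \text{matrix mode (DCN-v2):}\quad y = x_0 \odot (Wx + b) + x. \]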
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct FusedCrossFeatureInteractionInterpState : public AutoGradCaptureState { bool x_requires_grad = true; bool weight_requires_grad = true; bool x0_requires_grad = true; bool bias_requires_grad = true; size_t x_idx = 0; size_t bias_idx = 0; size_t weight_idx = 0; size_t x0_idx = 0; size_t matmul_result_idx = 0; std::string interaction_mode; }; class FusedCrossFeatureInteraction : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be None. "; base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(FusedCrossFeatureInteractionInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 4) << "Input size should be equal to 4. "; ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->interaction_mode = JUST(composed_attrs.GetAttr("interaction_mode")); ctx->x_requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); ctx->weight_requires_grad = JUST(oneflow::VectorAt(inputs, 1))->requires_grad(); ctx->x_requires_grad = JUST(oneflow::VectorAt(inputs, 2))->requires_grad(); ctx->weight_requires_grad = JUST(oneflow::VectorAt(inputs, 3))->requires_grad(); ctx->x_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0))); ctx->weight_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1))); ctx->x0_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 2))); if (ctx->interaction_mode == "matrix") { ctx->bias_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 3))); } ctx->matmul_result_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 1))); return Maybe::Ok(); } Maybe Apply(const FusedCrossFeatureInteractionInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 2) << "Out grads size should be equal to 2. "; std::shared_ptr grads; in_grads->resize(4); if (ctx->interaction_mode == "vector") { grads = JUST(functional::FusedCrossFeatureInteractionV1Grad( JUST(oneflow::VectorAt(out_grads, 0)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->weight_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x0_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->matmul_result_idx)))); } else if (ctx->interaction_mode == "matrix") { grads = JUST(functional::FusedCrossFeatureInteractionV2Grad( JUST(oneflow::VectorAt(out_grads, 0)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->weight_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->bias_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x0_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->matmul_result_idx)))); } else { UNIMPLEMENTED_THEN_RETURN() << "Interaction mode only support `vector` and `matrix`. 
"; } if (ctx->x_requires_grad) { JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(oneflow::VectorAt(*grads, 0)); } if (ctx->weight_requires_grad) { JUST(oneflow::VectorAt(*in_grads, 1)) = JUST(oneflow::VectorAt(*grads, 1)); } if (ctx->x0_requires_grad) { JUST(oneflow::VectorAt(*in_grads, 2)) = JUST(oneflow::VectorAt(*grads, 2)); } if (ctx->bias_requires_grad) { JUST(oneflow::VectorAt(*in_grads, 3)) = JUST(oneflow::VectorAt(*grads, 3)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_cross_feature_interaction", FusedCrossFeatureInteraction); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_dot_feature_interaction.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct FusedDotFeatureInteractionCaptureState : public AutoGradCaptureState { bool need_grad_op = false; std::vector features_requires_grad; std::vector feature_dims; int32_t output_concat_grad_dim = 0; bool self_interaction = false; bool has_output_concat = false; bool has_output_concat_grad = false; std::string pooling; }; class FusedDotFeatureInteraction : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedDotFeatureInteractionCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedDotFeatureInteractionCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe FusedDotFeatureInteraction::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); return Maybe::Ok(); } Maybe FusedDotFeatureInteraction::Capture(FusedDotFeatureInteractionCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->has_output_concat = JUST(attrs.GetAttr("has_output_concat")); int32_t num_features = 0; if (ctx->has_output_concat) { num_features = inputs.size() - 1; const auto& output_concat = JUST(oneflow::VectorAt(inputs, num_features)); ctx->has_output_concat_grad = output_concat->requires_grad(); ctx->output_concat_grad_dim = output_concat->shape()->At(1); } else { num_features = inputs.size(); } if (ctx->has_output_concat_grad) { ctx->need_grad_op = true; } ctx->features_requires_grad.resize(num_features); ctx->feature_dims.resize(num_features); for (int32_t i = 0; i < num_features; ++i) { const auto& feature = JUST(oneflow::VectorAt(inputs, i)); ctx->features_requires_grad[i] = feature->requires_grad(); 
ctx->feature_dims[i] = feature->shape()->At(1); if (feature->requires_grad()) { ctx->need_grad_op = true; } ctx->SaveTensorForBackward(feature); } ctx->pooling = JUST(attrs.GetAttr("pooling")); if (!ctx->need_grad_op) { return Maybe::Ok(); } ctx->self_interaction = JUST(attrs.GetAttr("self_interaction")); return Maybe::Ok(); } Maybe FusedDotFeatureInteraction::Apply(const FusedDotFeatureInteractionCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->need_grad_op) { return Maybe::Ok(); } int32_t num_features = ctx->features_requires_grad.size(); in_grads->resize(num_features + 1); TensorTuple features(num_features); for (int i = 0; i < num_features; ++i) { features[i] = JUST(oneflow::VectorAt(ctx->SavedTensors(), i)); } std::shared_ptr grads; grads = JUST(functional::FusedDotFeatureInteractionGrad( JUST(oneflow::VectorAt(out_grads, 0)), features, ctx->has_output_concat, ctx->self_interaction, ctx->output_concat_grad_dim, ctx->pooling)); for (int32_t i = 0; i < num_features; ++i) { if (JUST(oneflow::VectorAt(ctx->features_requires_grad, i))) { JUST(oneflow::VectorAt(*in_grads, i)) = JUST(oneflow::VectorAt(*grads, i)); } } if (ctx->has_output_concat_grad) { JUST(oneflow::VectorAt(*in_grads, num_features)) = JUST(oneflow::VectorAt(*grads, num_features)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_dot_feature_interaction", FusedDotFeatureInteraction); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_fast_gelu_mul.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
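"Fast" GeLU conventionally means the tanh approximation below (standard formula, not read out of this file); the op name suggests the forward multiplies it elementwise with the second input, which would be why the backward returns both an in_diff and a multiplier_diff:

\[ \mathrm{gelu}_{\mathrm{fast}}(x) \approx \tfrac{1}{2}\,x\Big(1 + \tanh\big(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\big)\Big). \]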
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedFastGeluMulGradCaptureState : public AutoGradCaptureState { bool requires_grad = true; }; class FusedFastGeluMulGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FusedFastGeluMulGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // (in, multiplier) CHECK_EQ_OR_RETURN(outputs.size(), 1); // (out,) ctx->requires_grad = inputs.at(0)->requires_grad() || inputs.at(1)->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); // in ctx->SaveTensorForBackward(inputs.at(1)); // multiplier } return Maybe::Ok(); } Maybe Apply(const FusedFastGeluMulGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& out_diff = out_grads.at(0); const auto& saved_tensors = ctx->SavedTensors(); CHECK_EQ_OR_RETURN(saved_tensors.size(), 2); const auto& in = saved_tensors.at(0); const auto& multiplier = saved_tensors.at(1); in_grads->resize(2); // (in_diff, multiplier_diff) auto result = JUST(functional::FusedFastGeluMulGrad(out_diff, in, multiplier)); CHECK_EQ_OR_RETURN(result->size(), 2); in_grads->at(0) = result->at(0); in_grads->at(1) = result->at(1); return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_fast_gelu_mul", FusedFastGeluMulGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_get_boundding_boxes_coord.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { const int32_t INPUT_LEN = 8; struct FusedGetBounddingBoxesCoordGradCaptureState : public AutoGradCaptureState { std::vector requires_grad; }; class FusedGetBounddingBoxesCoordGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FusedGetBounddingBoxesCoordGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN); CHECK_EQ_OR_RETURN(outputs.size(), INPUT_LEN); for (int i = 0; i < INPUT_LEN; i++) { ctx->requires_grad.push_back(inputs.at(i)->requires_grad()); } return Maybe::Ok(); } Maybe Apply(const FusedGetBounddingBoxesCoordGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), INPUT_LEN); const auto& b1_x1_diff = out_grads.at(0); const auto& b1_x2_diff = out_grads.at(1); const auto& b1_y1_diff = out_grads.at(2); const auto& b1_y2_diff = out_grads.at(3); const auto& b2_x1_diff = out_grads.at(4); const auto& b2_x2_diff = out_grads.at(5); const auto& b2_y1_diff = out_grads.at(6); const auto& b2_y2_diff = out_grads.at(7); in_grads->resize(8); auto result = JUST(functional::FusedGetBounddingBoxesCoordGrad( b1_x1_diff, b1_x2_diff, b1_y1_diff, b1_y2_diff, b2_x1_diff, b2_x2_diff, b2_y1_diff, b2_y2_diff)); CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN); for (int i = 0; i < result->size(); i++) { if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_boundding_boxes_coord", FusedGetBounddingBoxesCoordGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_get_ciou_diagonal_angle.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
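The (w1, h1, w2, h2) -> v signature of the next op matches the aspect-ratio consistency term of the CIoU loss; the standard definition (assumed here from the signature, with the eps attribute presumably guarding the divisions) is

\[ v = \frac{4}{\pi^{2}}\Big(\arctan\frac{w_2}{h_2} - \arctan\frac{w_1}{h_1}\Big)^{2}. \]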
*/ #include #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { const int32_t INPUT_LEN = 4; struct FusedCiouAngleCaptureState : public AutoGradCaptureState { std::vector requires_grad; float eps = 1e-8; }; class FusedGetCiouDiagonalAngleGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FusedCiouAngleCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN); CHECK_EQ_OR_RETURN(outputs.size(), 1); for (int i = 0; i < INPUT_LEN; i++) { ctx->requires_grad.push_back(inputs.at(i)->requires_grad()); } for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); } ComposedAttrMap composed_attrs(attrs); ctx->eps = JUST(composed_attrs.GetAttr("eps")); return Maybe::Ok(); } Maybe Apply(const FusedCiouAngleCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& v_diff = out_grads.at(0); const auto& w1 = ctx->SavedTensors().at(0); const auto& h1 = ctx->SavedTensors().at(1); const auto& w2 = ctx->SavedTensors().at(2); const auto& h2 = ctx->SavedTensors().at(3); auto result = JUST(functional::FusedGetCiouDiagonalAngleGrad(w1, h1, w2, h2, v_diff, ctx->eps)); CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN); in_grads->resize(INPUT_LEN); for (int i = 0; i < INPUT_LEN; i++) { if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_ciou_diagonal_angle", FusedGetCiouDiagonalAngleGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_get_ciou_result.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
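The next op consumes (v, iou, rho2, c2) and returns the result together with alpha, which Capture saves for the backward pass. The standard CIoU combination these names point to (inferred from the signature, not from the kernel) is

\[ \alpha = \frac{v}{(1 - \mathrm{IoU}) + v}, \qquad \mathrm{CIoU} = \mathrm{IoU} - \frac{\rho^{2}}{c^{2}} - \alpha v, \]

where alpha is conventionally treated as a constant during backpropagation, which would explain why the backward only needs dy, alpha, rho2 and c2.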
*/ #include #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedGetCiouResultGradCaptureState : public AutoGradCaptureState { bool v_requires_grad = false; bool iou_requires_grad = false; bool rho2_requires_grad = false; bool c2_requires_grad = false; }; class FusedGetCiouResultGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FusedGetCiouResultGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 4); CHECK_EQ_OR_RETURN(outputs.size(), 2); ctx->v_requires_grad = inputs.at(0)->requires_grad(); ctx->iou_requires_grad = inputs.at(1)->requires_grad(); ctx->rho2_requires_grad = inputs.at(2)->requires_grad(); ctx->c2_requires_grad = inputs.at(3)->requires_grad(); if (ctx->v_requires_grad && ctx->iou_requires_grad && ctx->rho2_requires_grad && ctx->c2_requires_grad) { ctx->SaveTensorForBackward(outputs.at(1)); // alpha ctx->SaveTensorForBackward(inputs.at(2)); // rho2 ctx->SaveTensorForBackward(inputs.at(3)); // c2 } return Maybe::Ok(); } Maybe Apply(const FusedGetCiouResultGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 2); const auto& dy = out_grads.at(0); const auto& saved_tensors = ctx->SavedTensors(); CHECK_EQ_OR_RETURN(saved_tensors.size(), 3); const auto& alpha = saved_tensors.at(0); const auto& rho2 = saved_tensors.at(1); const auto& c2 = saved_tensors.at(2); in_grads->resize(4); auto result = JUST(functional::FusedGetCiouResultGrad(dy, alpha, rho2, c2)); CHECK_EQ_OR_RETURN(result->size(), 4); if (ctx->v_requires_grad && ctx->iou_requires_grad && ctx->rho2_requires_grad && ctx->c2_requires_grad) { in_grads->at(0) = result->at(0); in_grads->at(1) = result->at(1); in_grads->at(2) = result->at(2); in_grads->at(3) = result->at(3); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_ciou_result", FusedGetCiouResultGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_get_convex_diagonal_squared.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
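The eight corner inputs (b1_x1 ... b2_y2) of the next op fit the squared diagonal of the smallest box enclosing both boxes, i.e. the c^2 term consumed above. The usual definition (assumed, not read from the kernel) is

\[ c^{2} = \big(\max(b_{1,x_2}, b_{2,x_2}) - \min(b_{1,x_1}, b_{2,x_1})\big)^{2} + \big(\max(b_{1,y_2}, b_{2,y_2}) - \min(b_{1,y_1}, b_{2,y_1})\big)^{2} + \varepsilon. \]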
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { const int32_t INPUT_LEN = 8; struct FusedGetConvexDiagonalSquaredCaptureState : public AutoGradCaptureState { std::vector requires_grad; float eps = 1e-8; }; class FusedGetConvexDiagonalSquaredGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(FusedGetConvexDiagonalSquaredCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN); CHECK_EQ_OR_RETURN(outputs.size(), 1); for (int i = 0; i < INPUT_LEN; i++) { ctx->requires_grad.push_back(inputs.at(i)->requires_grad()); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->eps = JUST(composed_attrs.GetAttr("eps")); for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); } return Maybe::Ok(); } Maybe Apply(const FusedGetConvexDiagonalSquaredCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& c2_diff = out_grads.at(0); const auto& b1_x1 = ctx->SavedTensors().at(0); const auto& b1_x2 = ctx->SavedTensors().at(1); const auto& b2_x1 = ctx->SavedTensors().at(2); const auto& b2_x2 = ctx->SavedTensors().at(3); const auto& b1_y1 = ctx->SavedTensors().at(4); const auto& b1_y2 = ctx->SavedTensors().at(5); const auto& b2_y1 = ctx->SavedTensors().at(6); const auto& b2_y2 = ctx->SavedTensors().at(7); in_grads->resize(INPUT_LEN); auto result = JUST(functional::FusedGetConvexDiagonalSquaredGrad( c2_diff, b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, ctx->eps)); CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN); for (int i = 0; i < INPUT_LEN; i++) { if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); } } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_convex_diagonal_squared", FusedGetConvexDiagonalSquaredGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_get_intersection_area.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { const int32_t INPUT_LEN = 8; struct FusedGetIntersectionAreaCaptureState : public AutoGradCaptureState { std::vector requires_grad; }; class FusedGetIntersectionAreaGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FusedGetIntersectionAreaCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN); CHECK_EQ_OR_RETURN(outputs.size(), 1); for (int i = 0; i < INPUT_LEN; i++) { ctx->requires_grad.push_back(inputs.at(i)->requires_grad()); } for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); } return Maybe::Ok(); } Maybe Apply(const FusedGetIntersectionAreaCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& rho2_diff = out_grads.at(0); const auto& b1_x1 = ctx->SavedTensors().at(0); const auto& b1_x2 = ctx->SavedTensors().at(1); const auto& b2_x1 = ctx->SavedTensors().at(2); const auto& b2_x2 = ctx->SavedTensors().at(3); const auto& b1_y1 = ctx->SavedTensors().at(4); const auto& b1_y2 = ctx->SavedTensors().at(5); const auto& b2_y1 = ctx->SavedTensors().at(6); const auto& b2_y2 = ctx->SavedTensors().at(7); in_grads->resize(INPUT_LEN); auto result = JUST(functional::FusedGetIntersectionAreaGrad(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, rho2_diff)); CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN); for (int i = 0; i < INPUT_LEN; i++) { if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_intersection_area", FusedGetIntersectionAreaGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_get_iou.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedGetIouGradCaptureState : public AutoGradCaptureState { bool requires_grad = true; float eps = 1e-8; }; class FusedGetIouGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(FusedGetIouGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 5); CHECK_EQ_OR_RETURN(outputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad() && inputs.at(1)->requires_grad() && inputs.at(4)->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->eps = JUST(composed_attrs.GetAttr("eps")); if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); // w1 ctx->SaveTensorForBackward(inputs.at(1)); // h1 ctx->SaveTensorForBackward(inputs.at(2)); // w2 ctx->SaveTensorForBackward(inputs.at(3)); // h2 ctx->SaveTensorForBackward(inputs.at(4)); // inter } return Maybe::Ok(); } Maybe Apply(const FusedGetIouGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& diou = out_grads.at(0); const auto& saved_tensors = ctx->SavedTensors(); CHECK_EQ_OR_RETURN(saved_tensors.size(), 5); const auto& w1 = saved_tensors.at(0); const auto& h1 = saved_tensors.at(1); const auto& w2 = saved_tensors.at(2); const auto& h2 = saved_tensors.at(3); const auto& inter = saved_tensors.at(4); in_grads->resize(5); auto result = JUST(functional::FusedGetIouGrad(diou, w1, h1, w2, h2, inter, ctx->eps)); CHECK_EQ_OR_RETURN(result->size(), 3); if (ctx->requires_grad) { in_grads->at(0) = result->at(0); in_grads->at(1) = result->at(1); in_grads->at(4) = result->at(2); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_iou", FusedGetIouGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_glu.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
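fused_glu below distinguishes a packed layout (a single weight/bias pair producing matmul_wx, which holds both halves) from a split layout (separate w, b, v, c producing matmul_wx and matmul_vx). In the split notation used by the backward code, the gated linear unit is, up to the convention of which branch carries the activation (the kernel itself is not shown here),

\[ y = \mathrm{act}(x W^{\top} + b) \odot (x V^{\top} + c). \]

The weight and bias gradients in Apply then follow mechanically: BroadcastMatmulGradB(d_matmul_wx, x) is the gradient of x W^T with respect to W, and the bias gradients reduce d_matmul_wx / d_matmul_vx over every leading axis.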
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedGluGradCaptureState : public AutoGradCaptureState { bool is_split_mode = false; bool has_bias = false; std::string activation = "none"; bool w_requires_grad = false; bool v_requires_grad = false; bool b_requires_grad = false; bool c_requires_grad = false; }; class FusedGluGrad : public OpExprGradFunction { Maybe Init(const OpExpr& op) override; Maybe Capture(FusedGluGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedGluGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe FusedGluGrad::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedGluGrad::Capture(FusedGluGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // check input size const size_t in_size = inputs.size(); CHECK_OR_RETURN(in_size == 2 || in_size == 3 || in_size == 5) << "FusedGluGrad::Capture(): input tensor size must be 2 or 3 or 5"; // check the input pattern: ctx->has_bias = JUST(attrs.GetAttr("has_bias")); ctx->is_split_mode = JUST(attrs.GetAttr("is_split")); // check whether input tensors need grad ctx->w_requires_grad = inputs[1]->requires_grad(); if (ctx->has_bias) { ctx->b_requires_grad = inputs[2]->requires_grad(); if (ctx->is_split_mode) { ctx->v_requires_grad = inputs[3]->requires_grad(); ctx->c_requires_grad = inputs[4]->requires_grad(); } } else { if (ctx->is_split_mode) { ctx->v_requires_grad = inputs[2]->requires_grad(); } } // save tensors for backward ctx->SaveTensorForBackward(inputs[0]); // x ctx->SaveTensorForBackward(outputs[1]); // matmul_wx if (ctx->is_split_mode) { ctx->SaveTensorForBackward(outputs[2]); // matmul_vx } // save activation type ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->activation = JUST(composed_attrs.GetAttr("activation")); return Maybe::Ok(); } Maybe FusedGluGrad::Apply(const FusedGluGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { // obtain saved tensors from forward process const auto& x = ctx->SavedTensors()[0]; const auto& matmul_wx = ctx->SavedTensors()[1]; // obtain gradient dy const auto& dy = out_grads[0]; if (ctx->is_split_mode) { // obtain saved optional tensor from forward process const auto& matmul_vx = ctx->SavedTensors()[2]; if (ctx->w_requires_grad or ctx->b_requires_grad or ctx->v_requires_grad or ctx->c_requires_grad) { // calculate the intermediate gradient using fused kernel const auto& middle_results = JUST(functional::FusedGluWithoutLinearGrad(dy, matmul_wx, matmul_vx, ctx->activation)); const auto& d_matmul_wx = (*middle_results)[0]; const auto& d_matmul_vx = (*middle_results)[1]; // calculate the final gradient result of w (if necessary) if (ctx->w_requires_grad) { (*in_grads)[1] = JUST(functional::BroadcastMatmulGradB(d_matmul_wx, x, 1.0)); } // calculate the final gradient result of b (if necessary) if (ctx->b_requires_grad) { const int64_t num_axes = d_matmul_wx->shape()->NumAxes(); std::vector reduce_axes_vec; reduce_axes_vec.reserve(num_axes - 1); for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); } (*in_grads)[2] = 
JUST(functional::ReduceSum(d_matmul_wx, reduce_axes_vec, false, NullOpt)); } // calculate the final gradient result of v (if necessary) if (ctx->v_requires_grad) { if (ctx->has_bias) { (*in_grads)[3] = JUST(functional::BroadcastMatmulGradB(d_matmul_vx, x, 1.0)); } else { (*in_grads)[2] = JUST(functional::BroadcastMatmulGradB(d_matmul_vx, x, 1.0)); } } // calculate the final gradient result of c (if necessary) if (ctx->c_requires_grad) { const int64_t num_axes = d_matmul_vx->shape()->NumAxes(); std::vector reduce_axes_vec; reduce_axes_vec.reserve(num_axes - 1); for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); } (*in_grads)[4] = JUST(functional::ReduceSum(d_matmul_vx, reduce_axes_vec, false, NullOpt)); } } } else { if (ctx->w_requires_grad or ctx->b_requires_grad) { // calculate the intermediate gradient using fused kernel const auto& middle_results = JUST(functional::FusedGluWithoutLinearGrad(dy, matmul_wx, nullptr, ctx->activation)); const auto& d_matmul_wx = (*middle_results)[0]; // calculate the final gradient result of w (if necessary) if (ctx->w_requires_grad) { (*in_grads)[1] = JUST(functional::BroadcastMatmulGradB(d_matmul_wx, x, 1.0)); } // calculate the final gradient result of b (if necessary) if (ctx->b_requires_grad) { const int64_t num_axes = d_matmul_wx->shape()->NumAxes(); std::vector reduce_axes_vec; reduce_axes_vec.reserve(num_axes - 1); for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); } (*in_grads)[2] = JUST(functional::ReduceSum(d_matmul_wx, reduce_axes_vec, false, NullOpt)); } } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_glu", FusedGluGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_gru_cell.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedGruCellGradCaptureState : public AutoGradCaptureState { bool has_bias = true; bool hx_needs_grad = true; }; class FusedGruCellGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "FusedGruCellGrad::Init forward op expr is null."; return Maybe::Ok(); } Maybe Capture(FusedGruCellGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { const size_t in_size = inputs.size(); CHECK_OR_RETURN(in_size == 3 || in_size == 5) << "FusedGruCellGrad::Capture(): input tensor size must be 3 or 5"; ctx->has_bias = in_size == 5; ctx->hx_needs_grad = inputs[2]->requires_grad(); ctx->SaveTensorForBackward(outputs[1]); // workspace return Maybe::Ok(); } Maybe Apply(const FusedGruCellGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& workspace = ctx->SavedTensors()[0]; // workspace const auto& grad_hy = out_grads[0]; const auto& results = JUST(functional::FusedGruCellGrad(grad_hy, workspace, ctx->has_bias, ctx->hx_needs_grad)); if (ctx->has_bias) { in_grads->resize(5); } else { in_grads->resize(3); } (*in_grads)[0] = (*results)[0]; (*in_grads)[1] = (*results)[1]; if (ctx->hx_needs_grad) { (*in_grads)[2] = (*results)[2]; } if (ctx->has_bias) { if (ctx->hx_needs_grad) { (*in_grads)[3] = (*results)[3]; (*in_grads)[4] = (*results)[4]; } else { (*in_grads)[3] = (*results)[2]; (*in_grads)[4] = (*results)[3]; } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_gru_cell", FusedGruCellGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_lstm_cell.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedLstmCellGradCaptureState : public AutoGradCaptureState { bool has_bias = true; bool need_grad_cx = true; }; class FusedLstmCellGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "FusedLstmCellGrad::Init forward op expr is null."; return Maybe::Ok(); } Maybe Capture(FusedLstmCellGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { const size_t in_size = inputs.size(); CHECK_OR_RETURN(in_size == 3 || in_size == 5) << "FusedLstmCellGrad::Capture(): input tensor size must be 3 or 5"; ctx->has_bias = in_size == 5; ctx->need_grad_cx = inputs[2]->requires_grad(); ctx->SaveTensorForBackward(inputs[2]); // cx ctx->SaveTensorForBackward(outputs[1]); // cy ctx->SaveTensorForBackward(outputs[2]); // workspace return Maybe::Ok(); } Maybe Apply(const FusedLstmCellGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& cx = ctx->SavedTensors()[0]; // cx const auto& cy = ctx->SavedTensors()[1]; // cy const auto& workspace = ctx->SavedTensors()[2]; // workspace const auto& grad_hy = out_grads[0]; const auto& grad_cy = out_grads[1]; const auto& results = JUST(functional::FusedLstmCellGrad(grad_hy, grad_cy, cx, cy, workspace, ctx->need_grad_cx, ctx->has_bias)); if (ctx->has_bias) { in_grads->resize(5); } else { in_grads->resize(3); } (*in_grads)[0] = (*results)[0]; (*in_grads)[1] = (*results)[0]; if (ctx->need_grad_cx) { (*in_grads)[2] = (*results)[1]; } if (ctx->has_bias) { if (ctx->need_grad_cx) { (*in_grads)[3] = (*results)[2]; (*in_grads)[4] = (*results)[2]; } else { (*in_grads)[3] = (*results)[1]; (*in_grads)[4] = (*results)[1]; } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_lstm_cell", FusedLstmCellGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_matmul_bias.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct FusedMatmulBiasCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool weight_requires_grad = false; bool bias_requires_grad = false; }; class FusedMatmulBias : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedMatmulBiasCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedMatmulBiasCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; protected: AttrMap base_attrs_; }; Maybe FusedMatmulBias::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedMatmulBias::Capture(FusedMatmulBiasCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_GE_OR_RETURN(inputs.size(), 3) << "x, weight, and bias, [add_to_output] should all be included"; ctx->x_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); ctx->weight_requires_grad = JUST(VectorAt(inputs, 1))->requires_grad(); ctx->bias_requires_grad = JUST(VectorAt(inputs, 2))->requires_grad(); ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); return Maybe::Ok(); } Maybe FusedMatmulBias::Apply(const FusedMatmulBiasCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "FusedMatmulBias more than one output"; const auto& x = ctx->SavedTensors().at(0); const auto& weight = ctx->SavedTensors().at(1); if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), weight, false, false, 1.0)); } if (ctx->weight_requires_grad) { in_grads->at(1) = JUST(functional::BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), x, 1.0)); } if (ctx->bias_requires_grad) { const int64_t num_axes = out_grads.at(0)->shape()->NumAxes(); std::vector reduce_axes_vec; reduce_axes_vec.reserve(num_axes - 1); for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); } in_grads->at(2) = JUST(functional::ReduceSum(JUST(VectorAt(out_grads, 0)), reduce_axes_vec, false, NullOpt)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_matmul_bias", FusedMatmulBias); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_matmul_bias_add_relu_dropout.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #if CUDA_VERSION >= 11060 namespace oneflow { namespace one { struct FusedMatmulBiasAddReluDropoutCaptureState : public AutoGradCaptureState { int32_t weight_num = 0; bool skip_final_activation = false; bool x_requires_grad = false; std::vector weights_requires_grad; std::vector biases_requires_grad; std::vector dropout_rate_list; }; class FusedMatmulBiasAddReluDropout : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedMatmulBiasAddReluDropoutCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedMatmulBiasAddReluDropoutCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; protected: AttrMap base_attrs_; }; Maybe FusedMatmulBiasAddReluDropout::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedMatmulBiasAddReluDropout::Capture(FusedMatmulBiasAddReluDropoutCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_OR_RETURN(inputs.size() % 2 == 1) << "Both weight and bias should be passed together. "; int32_t weight_num = (inputs.size() - 1) / 2; ctx->weight_num = weight_num; ctx->x_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); ctx->weights_requires_grad.resize(weight_num); ctx->biases_requires_grad.resize(weight_num); for (int32_t i = 0; i < weight_num; i++) { ctx->weights_requires_grad.at(i) = inputs.at(i + 1)->requires_grad(); // NOLINT ctx->biases_requires_grad.at(i) = inputs.at(i + 1 + weight_num)->requires_grad(); // NOLINT } ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // x. idx_sum:1 for (int32_t i = 0; i < weight_num; i++) { ctx->SaveTensorForBackward(JUST(VectorAt(inputs, i + 1))); // weights. idx_sum:1+w } ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); // final layers output. idx_sum:2+w for (int32_t i = 0; i < weight_num; i++) { ctx->SaveTensorForBackward( JUST(VectorAt(outputs, i + 1))); // cublas aux. need minus 1. idx_sum:2+2w } for (int32_t i = 0; i < weight_num; i++) { ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num))); // hidden. 
} ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->skip_final_activation = JUST(composed_attrs.GetAttr("skip_final_activation")); ctx->dropout_rate_list = JUST(composed_attrs.GetAttr>("dropout_rate_list")); return Maybe::Ok(); } Maybe FusedMatmulBiasAddReluDropout::Apply( const FusedMatmulBiasAddReluDropoutCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { int32_t weight_num = ctx->weight_num; in_grads->resize(1 + 2 * weight_num); TensorTuple hiddens(weight_num); TensorTuple weights(weight_num); TensorTuple cublas_auxs(weight_num); TensorTuple dgrad(weight_num); std::shared_ptr x = JUST(VectorAt(ctx->SavedTensors(), 0)); std::shared_ptr out = JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num)); for (int32_t i = 0; i < weight_num; ++i) { weights[i] = JUST(VectorAt(ctx->SavedTensors(), 1 + i)); } for (int32_t i = 0; i < weight_num; ++i) { cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num)); } for (int32_t i = 0; i < weight_num; ++i) { hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num)); } float rate = ctx->dropout_rate_list.at(weight_num - 1); float scale = 0.0f; if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); } /* step1: use dy and mask to get last layer's dropout + relu grad. Because curand_uniform distribution is (0.0, 1.0], so the value after relu will be write into mask too. And DropoutGrad use this mask to generate grad, it will generate dropout and relu grad simultaneously. */ std::shared_ptr last_bias_dy = JUST(VectorAt(out_grads, 0)); if (!ctx->skip_final_activation || rate != 0.0f) { last_bias_dy = JUST(functional::FusedReluDropoutGrad(JUST(VectorAt(out_grads, 0)), cublas_auxs[weight_num - 1], scale)); } if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", false)) { std::vector alpha_list(weight_num - 1, 1.0); for (int i = 0; i < weight_num - 1; i++) { rate = ctx->dropout_rate_list.at(i); scale = 1.0; if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); } alpha_list.at(i) = scale; } const auto& fused_mlp_grad = JUST(functional::FusedMLPGrad(last_bias_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), weights, cublas_auxs, hiddens, alpha_list)); if (ctx->x_requires_grad) { // dx: JUST(VectorAt(*in_grads, 0)) = fused_mlp_grad->at(0); } for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) { if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx)))) { // dbias JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx + 1)) = fused_mlp_grad->at(1 + hidden_layer_idx); // NOLINT } // dw if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) { JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = fused_mlp_grad->at(1 + weight_num + hidden_layer_idx); } } } else { // step2: use reduce_sum to get last layer's bias grad. std::vector reduce_axes_vec{0}; if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) { JUST(VectorAt(*in_grads, 2 * weight_num)) = JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false, NullOpt)); } std::shared_ptr cublas_dy = last_bias_dy; for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) { // If it is final layer, we use out_grads[0] as dy. if (hidden_layer_idx != weight_num - 1) { cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1)); } rate = ctx->dropout_rate_list.at(hidden_layer_idx - 1); scale = 1.0; if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); } /* Here we use cublas to compute bias + relu + matmul grad. Then use Matmul to compute weight grad. 
*/ const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad( cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)), JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/scale)); // dgrad dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0); // NOLINT if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) { // dbias JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) = matmul_relu_bias_bgrad->at(1); // NOLINT } // dw if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) { JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul( cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0)); } } // For the first layer, we need to use 2 matmul to get grads. std::shared_ptr last_dy; if (weight_num != 1) { last_dy = JUST(VectorAt(dgrad, 1)); } else { last_dy = last_bias_dy; } if (ctx->x_requires_grad) { // dx: JUST(VectorAt(*in_grads, 0)) = JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0)); } if (JUST(VectorAt(ctx->weights_requires_grad, 0))) { // dw: JUST(VectorAt(*in_grads, 1)) = JUST( functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0)); } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_matmul_bias_add_relu_dropout", FusedMatmulBiasAddReluDropout); } // namespace one } // namespace oneflow #endif // CUDA_VERSION >= 11060 ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_scale_mask_bias_softmax.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" namespace oneflow { namespace one { struct FusedScaleMaskBiasSoftmaxCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool bias_requires_grad = false; int32_t input_size = 3; float scale = 1.0; }; class FusedScaleMaskBiasSoftmax : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedScaleMaskBiasSoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedScaleMaskBiasSoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe FusedScaleMaskBiasSoftmax::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedScaleMaskBiasSoftmax::Capture(FusedScaleMaskBiasSoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->input_requires_grad = inputs.at(0)->requires_grad(); if (inputs.size() == 3) ctx->bias_requires_grad = inputs.at(2)->requires_grad(); if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe::Ok(); } ctx->scale = JUST(composed_attrs.GetAttr("scale")); ctx->SaveTensorForBackward(outputs.at(0)); return Maybe::Ok(); } Maybe FusedScaleMaskBiasSoftmax::Apply(const FusedScaleMaskBiasSoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // dy if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe::Ok(); } in_grads->resize(ctx->input_size); const std::shared_ptr& y = ctx->SavedTensors().at(0); const std::shared_ptr& input_grad = JUST(functional::FusedScaleMaskBiasSoftmaxGrad(y, out_grads.at(0), ctx->scale)); if (ctx->input_requires_grad) in_grads->at(0) = input_grad; if (ctx->bias_requires_grad) { int batch_dim = (y->shape()->NumAxes() == 5) ? 1 : 0; in_grads->at(2) = JUST(functional::ScalarMul( 1 / ctx->scale, JUST(functional::ReduceSum(input_grad, {batch_dim}, true, NullOpt)))); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_scale_mask_bias_softmax", FusedScaleMaskBiasSoftmax); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_scale_mask_softmax.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedScaleMaskSoftmaxInterState : public AutoGradCaptureState { bool input_requires_grad = false; float scale = 1.0; }; class FusedScaleMaskSoftmax : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedScaleMaskSoftmaxInterState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedScaleMaskSoftmaxInterState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe FusedScaleMaskSoftmax::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedScaleMaskSoftmax::Capture(FusedScaleMaskSoftmaxInterState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 2); // input, mask ctx->input_requires_grad = inputs.at(0)->requires_grad(); if (!ctx->input_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->scale = JUST(composed_attrs.GetAttr("scale_value")); ctx->SaveTensorForBackward(inputs.at(1)); // save mask ctx->SaveTensorForBackward(outputs.at(0)); // save y, ie. softmax result return Maybe::Ok(); } Maybe FusedScaleMaskSoftmax::Apply(const FusedScaleMaskSoftmaxInterState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->input_requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // dy in_grads->resize(2); // input, mask const std::shared_ptr& mask = ctx->SavedTensors().at(0); const std::shared_ptr& y = ctx->SavedTensors().at(1); const std::shared_ptr& fused_scale_mask_softmax_grad = JUST(functional::FusedScaleMaskSoftmaxGrad(y, out_grads.at(0), mask, ctx->scale)); in_grads->at(0) = fused_scale_mask_softmax_grad; return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_scale_mask_softmax", FusedScaleMaskSoftmax); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_scale_mask_softmax_dropout.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedScaleMaskSoftmaxDropoutInterState : public AutoGradCaptureState { bool input_requires_grad = true; float scale = 1.0; float dropout_scale = 1.0; }; class FusedScaleMaskSoftmaxDropout : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedScaleMaskSoftmaxDropoutInterState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedScaleMaskSoftmaxDropoutInterState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe FusedScaleMaskSoftmaxDropout::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedScaleMaskSoftmaxDropout::Capture(FusedScaleMaskSoftmaxDropoutInterState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 3); // input, mask, dropout_mask ctx->input_requires_grad = inputs.at(0)->requires_grad(); if (!ctx->input_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->scale = JUST(composed_attrs.GetAttr("scale_value")); ctx->dropout_scale = JUST(composed_attrs.GetAttr("dropout_scale_value")); ctx->SaveTensorForBackward(inputs.at(1)); // mask ctx->SaveTensorForBackward(inputs.at(2)); // dropout_mask ctx->SaveTensorForBackward(outputs.at(1)); // softmax_y return Maybe::Ok(); } Maybe FusedScaleMaskSoftmaxDropout::Apply(const FusedScaleMaskSoftmaxDropoutInterState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 2); // dy, d_softmax_y if (!ctx->input_requires_grad) { return Maybe::Ok(); } in_grads->resize(3); // input, mask, dropout_mask const std::shared_ptr& mask = ctx->SavedTensors().at(0); const std::shared_ptr& dropout_mask = ctx->SavedTensors().at(1); const std::shared_ptr& softmax_y = ctx->SavedTensors().at(2); const std::shared_ptr& input_grad = JUST(functional::FusedScaleMaskSoftmaxDropoutGrad( softmax_y, out_grads.at(0), mask, dropout_mask, ctx->scale, ctx->dropout_scale)); in_grads->at(0) = input_grad; return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_scale_mask_softmax_dropout", FusedScaleMaskSoftmaxDropout); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_scale_tril.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedScaleTrilState : public AutoGradCaptureState { bool requires_grad; int64_t diagonal; double floating_scale_value; int64_t integer_scale_value; bool is_floating_scale_value; }; class FusedScaleTril : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedScaleTrilState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedScaleTrilState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe FusedScaleTril::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedScaleTril::Capture(FusedScaleTrilState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->diagonal = JUST(composed_attrs.GetAttr("diagonal")); ctx->floating_scale_value = JUST(composed_attrs.GetAttr("floating_scale_value")); ctx->integer_scale_value = JUST(composed_attrs.GetAttr("integer_scale_value")); ctx->is_floating_scale_value = JUST(composed_attrs.GetAttr("is_floating_scale_value")); return Maybe::Ok(); } Maybe FusedScaleTril::Apply(const FusedScaleTrilState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); Scalar scale; if (ctx->is_floating_scale_value) { scale = ctx->floating_scale_value; } else { scale = ctx->integer_scale_value; } (*in_grads)[0] = JUST(functional::FusedScaleTril(out_grads[0], ctx->diagonal, 0, scale)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_scale_tril", FusedScaleTril); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_scale_tril_softmax_mask_scale.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedScaleTrilSoftmaxMaskScaleInterpState : public AutoGradCaptureState { bool input_requires_grad = true; int64_t diagonal = 0; float tril_scale_value = 0.0; float mask_scale_value = 1.0; }; class FusedScaleTrilSoftmaxMaskScale : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(FusedScaleTrilSoftmaxMaskScaleInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const FusedScaleTrilSoftmaxMaskScaleInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe FusedScaleTrilSoftmaxMaskScale::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe FusedScaleTrilSoftmaxMaskScale::Capture(FusedScaleTrilSoftmaxMaskScaleInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 2); ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input if (!ctx->input_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->diagonal = JUST(composed_attrs.GetAttr("diagonal")); ctx->tril_scale_value = JUST(composed_attrs.GetAttr("tril_scale_value")); ctx->mask_scale_value = JUST(composed_attrs.GetAttr("mask_scale_value")); ctx->SaveTensorForBackward(inputs.at(1)); // Save Mask ctx->SaveTensorForBackward(outputs.at(1)); // Save softmax_y return Maybe::Ok(); } Maybe FusedScaleTrilSoftmaxMaskScale::Apply( const FusedScaleTrilSoftmaxMaskScaleInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 2); // Cause output has y and softmax_y if (!ctx->input_requires_grad) { return Maybe::Ok(); } // mask have no grad(reqiures_grad=False), but still take a place in in_grads in_grads->resize(2); const std::shared_ptr& mask = ctx->SavedTensors().at(0); const std::shared_ptr& softmax_y = ctx->SavedTensors().at(1); const std::shared_ptr& input_grad = JUST(functional::FusedScaleTrilSoftmaxMaskScaleGrad(softmax_y, out_grads.at(0), mask, ctx->diagonal, ctx->tril_scale_value, ctx->mask_scale_value)); if (ctx->input_requires_grad) { in_grads->at(0) = input_grad; } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("fused_tril_scale_softmax_mask_scale", FusedScaleTrilSoftmaxMaskScale); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_self_attention.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct FusedSelfAttentionInterpState : public AutoGradCaptureState { bool input_requires_grad = false; float alpha = 1.0; }; class FusedSelfAttention : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(FusedSelfAttentionInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); ctx->input_requires_grad = inputs.at(0)->requires_grad(); if (!ctx->input_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const FusedSelfAttentionInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->input_requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 2); in_grads->resize(1); const auto& hidden_states = ctx->SavedTensors().at(0); const std::shared_ptr& fused_self_attention_grad = JUST(functional::FusedSelfAttentionGrad(out_grads.at(0), out_grads.at(1), hidden_states, ctx->alpha)); in_grads->at(0) = fused_self_attention_grad; return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_self_attention_query_mul_key_and_value", FusedSelfAttention); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/fused_weighted_sum.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { struct FusedWeightedSumCaptureState : public AutoGradCaptureState { std::vector requires_grad; std::vector weights; float alpha{}; }; class FusedWeightedSum : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(FusedWeightedSumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->requires_grad.resize(inputs.size()); ctx->weights = JUST(attrs.GetAttr>("weights")); ctx->alpha = JUST(attrs.GetAttr("alpha")); CHECK_EQ_OR_RETURN(ctx->weights.size(), inputs.size()); for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs[i]->requires_grad(); } return Maybe::Ok(); } Maybe Apply(const FusedWeightedSumCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(ctx->requires_grad.size()); for (int i = 0; i < ctx->requires_grad.size(); ++i) { if (ctx->requires_grad[i]) { (*in_grads)[i] = JUST(functional::ScalarMul(out_grads[0], ctx->weights[i] * ctx->alpha, false)); } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("fused_weighted_sum", FusedWeightedSum); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/gather.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct GatherCaptureState : public AutoGradCaptureState { int64_t axis; bool requires_grad; }; class Gather : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(GatherCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const GatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Gather::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Gather::Capture(GatherCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(inputs.at(1)); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); return Maybe::Ok(); } Maybe Gather::Apply(const GatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& x = ctx->SavedTensors().at(0); const auto& indices = ctx->SavedTensors().at(1); in_grads->at(0) = JUST(functional::UnsortedSegmentSumLike(out_grads.at(0), indices, x, ctx->axis)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("gather", Gather); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/gather_nd.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct GatherNdCaptureState : public AutoGradCaptureState { bool requires_grad; }; class GatherNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(GatherNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); // params ctx->SaveTensorForBackward(inputs.at(1)); // indices } return Maybe::Ok(); } Maybe Apply(const GatherNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->requires_grad) { const auto& params = ctx->SavedTensors().at(0); const auto& indices = ctx->SavedTensors().at(1); in_grads->at(0) = JUST(functional::ScatterNdLike(params, out_grads.at(0), indices)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("gather_nd", GatherNd); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/global_cast.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct CastGlobalCaptureState : public AutoGradCaptureState { Symbol parallel_desc; Symbol nd_sbp; std::shared_ptr shape; Symbol dtype; }; class LocalToGlobal : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) const std::string& op_name = fw_op_expr->op_name(); grad_op_ = JUST(one::GlobalToLocalOpExpr::New(GradientOpName(op_name))); return Maybe::Ok(); } Maybe Capture(CastGlobalCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const OpExprInterpContext& interp_ctx) const override { ctx->parallel_desc = JUST(interp_ctx.parallel_desc); ctx->nd_sbp = JUST(GetDualNdSbp(JUST(interp_ctx.nd_sbp))); return Maybe::Ok(); } Maybe Apply(const CastGlobalCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) std::shared_ptr out_grad = out_grads.at(0); CHECK_OR_RETURN(out_grad->is_global()) << Error::RuntimeError() << "Expected global tensor for local_to_global but got local tensor"; { Symbol nd_sbp_constraint = ctx->nd_sbp; Symbol parallel_desc_constraint = ctx->parallel_desc; out_grad = JUST(functional::ToGlobal(out_grad, parallel_desc_constraint, *JUST(GetSbpList(nd_sbp_constraint)), GetNoneSbpList(), /* check_meta */ false, /*copy=*/false)); } in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grad})); return Maybe::Ok(); } private: std::shared_ptr grad_op_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("local_to_global", LocalToGlobal); class GlobalToLocal : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) const std::string& op_name = fw_op_expr->op_name(); grad_op_ = JUST(one::LocalToGlobalOpExpr::New(GradientOpName(op_name))); return Maybe::Ok(); } Maybe Capture(CastGlobalCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { const auto& input = inputs.at(0); CHECK_OR_RETURN(input->is_global()) << Error::RuntimeError() << "Expected global tensor for global_to_local but got local tensor"; ctx->parallel_desc = JUST(input->parallel_desc()); ctx->nd_sbp = JUST(input->nd_sbp()); ctx->shape = input->shape(); ctx->dtype = input->dtype(); return Maybe::Ok(); } Maybe Apply(const CastGlobalCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& dual_nd_sbp = JUST(GetDualNdSbp(ctx->nd_sbp)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype", "sync_data"); attrs.SetAllAttrs(*ctx->shape, ctx->dtype->data_type(), true); in_grads->at(0) = JUST(OpInterpUtil::Dispatch( *grad_op_, {out_grads.at(0)}, OpExprInterpContext(attrs, ctx->parallel_desc, dual_nd_sbp))); return Maybe::Ok(); } private: std::shared_ptr grad_op_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("global_to_local", GlobalToLocal); } // namespace one } // namespace oneflow 
================================================ FILE: oneflow/core/autograd/gradient_funcs/global_to_global.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/optional.h" namespace oneflow { namespace one { struct GlobalToGlobalState : public AutoGradCaptureState { Symbol parallel_desc; Symbol nd_sbp; }; class GlobalToGlobalGradFunction : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) grad_nd_sbp_ = fw_op_expr->grad_nd_sbp(); return Maybe::Ok(); } Maybe Capture(GlobalToGlobalState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const OpExprInterpContext& interp_ctx) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->parallel_desc = JUST(inputs.at(0)->parallel_desc()); ctx->nd_sbp = JUST(inputs.at(0)->nd_sbp()); return Maybe::Ok(); } Maybe Apply(const GlobalToGlobalState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& out_grad = out_grads.at(0); CHECK_OR_RETURN(out_grad->is_global()) << Error::RuntimeError() << "Expected global tensor for global_to_global but got local tensor"; in_grads->resize(1); const auto& grad_nd_sbp = grad_nd_sbp_.value_or(JUST(out_grad->nd_sbp())); const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp)); if (LazyMode::is_enabled()) { (*in_grads)[0] = JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list, {}, /* check_meta */ false, /*copy=*/false)); } else { const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp)); (*in_grads)[0] = JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list, *grad_grad_sbp_list, /* check_meta */ false, /*copy=*/false)); } return Maybe::Ok(); } private: Optional> grad_nd_sbp_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("global_to_global", GlobalToGlobalGradFunction); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/gradient_accumulation.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct GradAccRepeatCaptureState : public AutoGradCaptureState { int32_t repeat_num = 1; }; class GradAccRepeat : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(GradAccRepeatCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const GradAccRepeatCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe GradAccRepeat::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe GradAccRepeat::Capture(GradAccRepeatCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->repeat_num = JUST(composed_attrs.GetAttr("repeat_num")); return Maybe::Ok(); } Maybe GradAccRepeat::Apply(const GradAccRepeatCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); (*in_grads)[0] = JUST(functional::GradAccCollect(out_grads[0], ctx->repeat_num)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("repeat", GradAccRepeat); struct GradAccCollectCaptureState : public AutoGradCaptureState { int32_t max_acc_num = 1; }; class GradAccCollect : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(GradAccCollectCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const GradAccCollectCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe GradAccCollect::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe GradAccCollect::Capture(GradAccCollectCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->max_acc_num = JUST(composed_attrs.GetAttr("max_acc_num")); return Maybe::Ok(); } Maybe GradAccCollect::Apply(const GradAccCollectCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); (*in_grads)[0] = JUST(functional::GradAccRepeat(out_grads[0], ctx->max_acc_num)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("acc", GradAccCollect); struct GradAccPackCaptureState : public AutoGradCaptureState { int32_t pack_num = 1; }; class GradAccPack : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; 
Maybe Capture(GradAccPackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const GradAccPackCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe GradAccPack::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe GradAccPack::Capture(GradAccPackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->pack_num = JUST(composed_attrs.GetAttr("pack_num")); return Maybe::Ok(); } Maybe GradAccPack::Apply(const GradAccPackCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); (*in_grads)[0] = JUST(functional::GradAccUnpack(out_grads[0], ctx->pack_num)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("pack", GradAccPack); struct GradAccUnpackCaptureState : public AutoGradCaptureState { int32_t unpack_num = 1; }; class GradAccUnpack : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(GradAccUnpackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const GradAccUnpackCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe GradAccUnpack::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe GradAccUnpack::Capture(GradAccUnpackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->unpack_num = JUST(composed_attrs.GetAttr("unpack_num")); return Maybe::Ok(); } Maybe GradAccUnpack::Apply(const GradAccUnpackCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); (*in_grads)[0] = JUST(functional::GradAccPack(out_grads[0], ctx->unpack_num)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("unpack", GradAccUnpack); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/graph_feed_and_fetch.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace one { struct GraphFeedAndFetchCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class GraphFeedAndFetch : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(GraphFeedAndFetchCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const GraphFeedAndFetchCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("graph_feed_and_fetch", GraphFeedAndFetch); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/grid_sample.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct GridSampleInterpState : public AutoGradCaptureState { std::string interpolation_mode = ""; std::string padding_mode = ""; bool align_corners = false; size_t input_index = -1; size_t grid_index = -1; bool input_requires_grad = false; bool grid_requires_grad = false; bool requires_grad = false; }; class GridSample : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(GridSampleInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->grid_requires_grad = inputs.at(1)->requires_grad(); ctx->requires_grad = ctx->input_requires_grad || ctx->grid_requires_grad; if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); // input ctx->grid_index = ctx->SaveTensorForBackward(inputs.at(1)); // grid ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->interpolation_mode = JUST(composed_attrs.GetAttr("interpolation_mode")); ctx->padding_mode = JUST(composed_attrs.GetAttr("padding_mode")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); return Maybe::Ok(); } Maybe Apply(const GridSampleInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& input = ctx->SavedTensors().at(ctx->input_index); const auto& grid = ctx->SavedTensors().at(ctx->grid_index); const auto& results = JUST(functional::GridSampleGrad(out_grads.at(0), input, grid, ctx->interpolation_mode, ctx->padding_mode, ctx->align_corners)); in_grads->resize(2); if (ctx->input_requires_grad) { in_grads->at(0) = results->at(0); } if (ctx->grid_requires_grad) { in_grads->at(1) = results->at(1); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("grid_sample", GridSample); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/group_norm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct GroupNormCaptureState : public AutoGradCaptureState { double epsilon = 1e-5; bool x_requires_grad = true; bool gamma_requires_grad = true; bool beta_requires_grad = true; bool affine = true; int32_t num_groups = 1; size_t x_index = 0; size_t mean_index = 1; size_t inv_variance_index = 2; size_t gamma_index = 3; std::string data_format; std::string activation; }; class GroupNorm : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(GroupNormCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const GroupNormCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; std::string op_name_; }; Maybe GroupNorm::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); op_name_ = fw_op_expr->op_name(); return Maybe::Ok(); } Maybe GroupNorm::Capture(GroupNormCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->affine = JUST(composed_attrs.GetAttr("affine")); ctx->epsilon = JUST(composed_attrs.GetAttr("epsilon")); ctx->num_groups = JUST(composed_attrs.GetAttr("num_groups")); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->activation = JUST(composed_attrs.GetAttr("activation")); if (ctx->affine) { CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) ctx->gamma_requires_grad = inputs.at(1)->requires_grad(); ctx->beta_requires_grad = inputs.at(2)->requires_grad(); } else { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) } CHECK_EQ_OR_RETURN(outputs.size(), 3); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); if (ctx->x_requires_grad || ctx->affine) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); ctx->mean_index = ctx->SaveTensorForBackward(outputs.at(1)); ctx->inv_variance_index = ctx->SaveTensorForBackward(outputs.at(2)); if (ctx->affine) { ctx->gamma_index = ctx->SaveTensorForBackward(inputs.at(1)); // save gamma. } } return Maybe::Ok(); } Maybe GroupNorm::Apply(const GroupNormCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(ctx->data_format, "channels_first"); CHECK_EQ_OR_RETURN(ctx->activation, "none"); const auto& saved_tensors = ctx->SavedTensors(); if (ctx->affine) { in_grads->resize(3); } else { in_grads->resize(1); } const auto& dy = out_grads.at(0); const auto& x = saved_tensors.at(ctx->x_index); const auto& mean = saved_tensors.at(ctx->mean_index); const auto& inv_variance = saved_tensors.at(ctx->inv_variance_index); if (ctx->affine && (ctx->gamma_requires_grad || ctx->beta_requires_grad)) { const auto& results = JUST(functional::GroupNormParamGrad(dy, x, mean, inv_variance)); if (ctx->gamma_requires_grad) { in_grads->at(1) = results->at(0); } // For gamma. if (ctx->beta_requires_grad) { in_grads->at(2) = results->at(1); } // For beta. 
} if (ctx->x_requires_grad) { if (ctx->affine) { std::shared_ptr gamma = saved_tensors.at(ctx->gamma_index); in_grads->at(0) = JUST(functional::GroupNormGrad(dy, x, mean, inv_variance, gamma, ctx->num_groups, ctx->epsilon)); } else { in_grads->at(0) = JUST(functional::GroupNormGrad(dy, x, mean, inv_variance, NullOpt, ctx->num_groups, ctx->epsilon)); } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("group_norm", GroupNorm); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/identity.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace one { struct IdentityCaptureState : public AutoGradCaptureState { bool requires_grad; }; class Identity : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(IdentityCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const IdentityCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { if (LazyMode::is_enabled()) { // requires an intermediate node to avoid redundant memory copy or commnet // communication in lazy mode in_grads->at(0) = JUST(functional::Identity(out_grads.at(0))); } else { in_grads->at(0) = out_grads.at(0); } } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("identity", Identity); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/inv.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct InvCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class Inv : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(InvCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); } return Maybe::Ok(); } Maybe Apply(const InvCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (ctx->requires_grad) { const auto& output = JUST(VectorAt(ctx->SavedTensors(), 0)); const auto& dy = JUST(VectorAt(out_grads, 0)); JUST(VectorAt(*in_grads, 0)) = JUST(functional::Negative(JUST(functional::MatMul( output, JUST(functional::MatMul(dy, output, false, true, 1.0)), true, false, 1.0)))); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("inv", Inv); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/kl_div.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct KLDivLossCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool target_requires_grad = false; bool log_target = false; }; class KLDivLoss : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(KLDivLossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const KLDivLossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe KLDivLoss::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe KLDivLoss::Capture(KLDivLossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs[0]->requires_grad(); ctx->target_requires_grad = inputs[1]->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->log_target = JUST(composed_attrs.GetAttr("log_target")); ctx->SaveTensorForBackward(inputs[0]); // input ctx->SaveTensorForBackward(inputs[1]); // target return Maybe::Ok(); } Maybe KLDivLoss::Apply(const KLDivLossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2); // NOLINT(maybe-need-error-msg) const auto& dy = out_grads[0]; const auto& input = ctx->SavedTensors()[0]; const auto& target = ctx->SavedTensors()[1]; in_grads->resize(2); if (ctx->input_requires_grad) { (*in_grads)[0] = JUST(functional::KLDivLossGrad(dy, input, target, ctx->log_target)); } if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::KLDivLossTargetGrad(dy, input, target, ctx->log_target)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("kl_div_loss", KLDivLoss); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/l2_normalize.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct L2NormalizeCaptureState : public AutoGradCaptureState { int64_t axis; float epsilon; bool requires_grad; }; class L2Normalize : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(L2NormalizeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const L2NormalizeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe L2Normalize::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe L2Normalize::Capture(L2NormalizeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(outputs.at(0)); // y ctx->SaveTensorForBackward(outputs.at(1)); // square_x_sum ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); ctx->epsilon = JUST(composed_attrs.GetAttr("epsilon")); return Maybe::Ok(); } Maybe L2Normalize::Apply(const L2NormalizeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } in_grads->resize(1); CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) const auto& y = ctx->SavedTensors().at(0); const auto& square_x_sum = ctx->SavedTensors().at(1); in_grads->at(0) = JUST(functional::L2NormalizeGrad(out_grads.at(0), y, square_x_sum, ctx->axis, ctx->epsilon)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("l2_normalize", L2Normalize); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/layer_norm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { DEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false); namespace one { struct LayerNormCaptureState : public AutoGradCaptureState { bool center = true; bool scale = true; int64_t begin_norm_axis = 1; int64_t begin_params_axis = 1; double epsilon = 1e-5; bool x_requires_grad = true; bool has_affine = true; size_t gamma_index = 0; size_t x_index = 1; size_t mean_index = 2; size_t inv_variance_index = 3; }; // y, mean, inv_variance = // layer_norm(x, [gamma], [beta], center=False, scale=False, begin_norm_axis=1, // begin_params_axis=-1, epsilon=1e-5) class LayerNorm : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(LayerNormCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const LayerNormCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; std::string op_name_; }; Maybe LayerNorm::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); op_name_ = fw_op_expr->op_name(); return Maybe::Ok(); } Maybe LayerNorm::Capture(LayerNormCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->center = JUST(composed_attrs.GetAttr("center")); ctx->scale = JUST(composed_attrs.GetAttr("scale")); ctx->begin_norm_axis = JUST(composed_attrs.GetAttr("begin_norm_axis")); ctx->begin_params_axis = JUST(composed_attrs.GetAttr("begin_params_axis")); ctx->epsilon = JUST(composed_attrs.GetAttr("epsilon")); CHECK_EQ_OR_RETURN(inputs.size(), ctx->center + ctx->scale + 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 3); // NOLINT(maybe-need-error-msg) bool has_gamma_diff = ctx->scale && inputs.at(1)->requires_grad(); bool has_beta_diff = ctx->center && inputs.at(2)->requires_grad(); ctx->has_affine = has_gamma_diff && has_beta_diff; ctx->x_requires_grad = inputs.at(0)->requires_grad(); if (ctx->x_requires_grad || ctx->has_affine) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); ctx->mean_index = ctx->SaveTensorForBackward(outputs.at(1)); ctx->inv_variance_index = ctx->SaveTensorForBackward(outputs.at(2)); if (ctx->x_requires_grad && ctx->scale) { ctx->gamma_index = ctx->SaveTensorForBackward(inputs.at(1)); // save gamma. 
} } return Maybe::Ok(); } Maybe LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { const auto& saved_tensors = ctx->SavedTensors(); in_grads->resize(ctx->center + ctx->scale + 1); std::shared_ptr dy = out_grads.at(0); int64_t begin_params_axis = ctx->begin_params_axis; if (begin_params_axis < 0) { begin_params_axis += dy->shape()->NumAxes(); } int64_t begin_norm_axis = ctx->begin_norm_axis; if (begin_norm_axis < 0) { begin_norm_axis += dy->shape()->NumAxes(); } std::shared_ptr x = saved_tensors.at(ctx->x_index); std::shared_ptr mean = saved_tensors.at(ctx->mean_index); std::shared_ptr inv_variance = saved_tensors.at(ctx->inv_variance_index); if (EnvBool()) { // just for npu CHECK(ctx->has_affine) << "LayerNorm::Apply must has_affine for NPU GPT2 test"; if (ctx->x_requires_grad) { if (ctx->scale) { std::shared_ptr gamma = saved_tensors.at(ctx->gamma_index); *in_grads = *JUST(functional::FuseLayerNormGrad( dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon)); } else { UNIMPLEMENTED(); } } } else { if (ctx->has_affine) { // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, // Int64 begin_params_axis) const auto& results = JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis)); in_grads->at(1) = results->at(0); // For gamma. in_grads->at(2) = results->at(1); // For beta. } if (ctx->x_requires_grad) { if (ctx->scale) { std::shared_ptr gamma = saved_tensors.at(ctx->gamma_index); in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma, begin_norm_axis, ctx->epsilon)); } else { in_grads->at(0) = JUST( functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon)); } } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("layer_norm", LayerNorm); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/lerp.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
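// Editor's sketch (not repository code) for the layer_norm backward above. With
// x_hat = (x - mean) * inv_variance, the parameter gradients reduced over the leading axes are
// dgamma = sum(dy * x_hat) and dbeta = sum(dy), which is what LayerNormParamGrad is expected to
// return; the x-gradient is the standard layer-norm backward handled by LayerNorm(Affine)Grad.
// A toy single-row example with plain doubles, under those assumptions:
#include <cmath>
#include <cstdio>
int main() {
  const int n = 4;
  double x[n] = {1.0, 2.0, 3.0, 4.0}, dy[n] = {0.1, -0.2, 0.3, 0.4};
  double mean = 0.0, var = 0.0;
  for (double v : x) mean += v / n;
  for (double v : x) var += (v - mean) * (v - mean) / n;
  const double inv_std = 1.0 / std::sqrt(var + 1e-5);  // inv_variance with the default epsilon
  double dgamma = 0.0, dbeta = 0.0;
  for (int i = 0; i < n; ++i) {
    const double x_hat = (x[i] - mean) * inv_std;
    dgamma += dy[i] * x_hat;  // contribution to the gamma gradient
    dbeta += dy[i];           // contribution to the beta gradient
  }
  std::printf("dgamma=%f dbeta=%f\n", dgamma, dbeta);
  return 0;
}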
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { const int32_t INPUT_LEN = 3; struct LerpCaptureState : public AutoGradCaptureState { std::vector requires_grad; }; struct ScalarLerpCaptureState : public AutoGradCaptureState { std::vector requires_grad; Scalar operand; }; class LerpGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(LerpCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN); CHECK_EQ_OR_RETURN(outputs.size(), 1); for (int i = 0; i < INPUT_LEN; i++) { ctx->requires_grad.push_back(inputs.at(i)->requires_grad()); ctx->SaveTensorForBackward(inputs.at(i)); } return Maybe::Ok(); } Maybe Apply(const LerpCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& out_diff = out_grads.at(0); const auto& start = ctx->SavedTensors().at(0); const auto& end = ctx->SavedTensors().at(1); const auto& weight = ctx->SavedTensors().at(2); auto result = JUST(functional::LerpGrad(start, end, weight, out_diff)); CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN); in_grads->resize(INPUT_LEN); for (int i = 0; i < INPUT_LEN; i++) { if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); } } return Maybe::Ok(); } }; class ScalarLerpGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(ScalarLerpCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN - 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); for (int i = 0; i < INPUT_LEN - 1; i++) { ctx->requires_grad.push_back(inputs.at(i)->requires_grad()); ctx->SaveTensorForBackward(inputs.at(i)); } ComposedAttrMap composed_attrs(attrs, base_attrs_); bool has_float_operand = JUST(composed_attrs.GetAttr("has_float_operand")); if (has_float_operand) { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("float_operand"))); } else { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("int_operand"))); } return Maybe::Ok(); } Maybe Apply(const ScalarLerpCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); const auto& out_diff = out_grads.at(0); const auto& start = ctx->SavedTensors().at(0); const auto& end = ctx->SavedTensors().at(1); auto result = JUST(functional::ScalarLerpGrad(start, end, out_diff, ctx->operand)); CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN - 1); in_grads->resize(INPUT_LEN - 1); for (int i = 0; i < INPUT_LEN - 1; i++) { if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); } } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("lerp", LerpGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_lerp", ScalarLerpGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/linalg_cross.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/just.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional_api.yaml.h" namespace oneflow { namespace one { struct LinalgCrossCaptureState : public AutoGradCaptureState { int64_t dim = -1; bool input_requires_grad = false; bool other_requires_grad = false; }; class LinalgCross : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(LinalgCrossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const LinalgCrossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe LinalgCross::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe LinalgCross::Capture(LinalgCrossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->other_requires_grad = inputs.at(1)->requires_grad(); if (ctx->input_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); } if (ctx->other_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dim = JUST(composed_attrs.GetAttr("dim")); return Maybe::Ok(); } Maybe LinalgCross::Apply(const LinalgCrossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(ctx->SavedTensors().size()); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) if (ctx->input_requires_grad) { in_grads->at(0) = JUST(functional::LinalgCross(ctx->SavedTensors().at(0), out_grads.at(0), ctx->dim)); } if (ctx->other_requires_grad) { in_grads->at(1) = JUST(functional::LinalgCross( out_grads.at(0), ctx->SavedTensors().at(ctx->input_requires_grad ? 1 : 0), ctx->dim)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("linalg_cross", LinalgCross); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/log_softmax.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
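// Editor's note on the linalg_cross backward above: for out = input x other (vector cross product
// along `dim`), the gradients are grad_input = other x dy and grad_other = dy x input, which is
// why Capture saves the *other* operand when input needs a gradient and vice versa, and why Apply
// can express both gradients as further LinalgCross calls. The `input_requires_grad ? 1 : 0` index
// simply accounts for how many tensors were actually saved.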
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" namespace oneflow { namespace one { struct LogSoftmaxCaptureState : public AutoGradCaptureState { bool requires_grad; }; class LogSoftmax : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(LogSoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const LogSoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; std::shared_ptr grad_op_; }; Maybe LogSoftmax::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe LogSoftmax::Capture(LogSoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); ctx->SaveTensorForBackward(outputs.at(0)); return Maybe::Ok(); } Maybe LogSoftmax::Apply(const LogSoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& dy = out_grads.at(0); const auto& y = ctx->SavedTensors().at(0); in_grads->resize(1); in_grads->at(0) = JUST(functional::LogSoftmaxGrad(dy, y)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("log_softmax", LogSoftmax); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/masked_fill.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct MaskedFillCaptureState : public AutoGradCaptureState { bool requires_grad = true; }; class MaskedFill : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(MaskedFillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(inputs.at(1)); return Maybe::Ok(); } Maybe Apply(const MaskedFillCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); const std::shared_ptr& mask = ctx->SavedTensors().at(1); std::shared_ptr zero_out = JUST(functional::ZerosLike(x)); in_grads->resize(2); in_grads->at(0) = JUST(functional::Where(mask, zero_out, out_grads.at(0))); return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("masked_fill", MaskedFill); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/math_binary_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/user/ops/math_binary_elementwise_seq.h" namespace oneflow { namespace one { struct BinaryMathCaptureState : public AutoGradCaptureState { bool x_requires_grad; bool y_requires_grad; }; typedef Maybe (*BinaryBwFunc)(const std::shared_ptr&, const std::shared_ptr&, const std::shared_ptr&); template class BinaryMathOp : public OpExprGradFunction { Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(BinaryMathCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->y_requires_grad = inputs.at(1)->requires_grad(); ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(inputs.at(1)); return Maybe::Ok(); } Maybe Apply(const BinaryMathCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe::Ok(); } in_grads->resize(2); const std::shared_ptr& x = ctx->SavedTensors().at(0); const std::shared_ptr& y = ctx->SavedTensors().at(1); if (ctx->x_requires_grad) { in_grads->at(0) = JUST(BwXFunc(x, y, out_grads.at(0))); } if (ctx->y_requires_grad) { in_grads->at(1) = JUST(BwYFunc(x, y, out_grads.at(0))); } return Maybe::Ok(); } }; #define INSTANTIAT_AND_REGISTER_BINARY_MATHOP_CLASS(op_type_name, op_cls) \ class op_cls##Cls final \ : public BinaryMathOp {}; \ REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_BINARY_MATHOP_CLASS, MATH_BINARY_ELEMENTWISE_FUNC_SEQ); #undef INSTANTIAT_AND_REGISTER_BINARY_MATHOP_CLASS } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/math_unary_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/user/ops/math_unary_elementwise_seq.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct UnaryMathCaptureState : public AutoGradCaptureState { bool x_requires_grad; }; typedef Maybe (*UnaryBwFunc)(const std::shared_ptr&, const std::shared_ptr&); template class UnaryMathBwdWithDyXOp : public OpExprGradFunction { Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->x_requires_grad) { return Maybe::Ok(); } const auto& x = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(BwFunc(x, out_grads.at(0))); return Maybe::Ok(); } protected: std::shared_ptr grad_op_; }; template class UnaryMathBwdWithDyYOp : public OpExprGradFunction { Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->SaveTensorForBackward(outputs.at(0)); return Maybe::Ok(); } Maybe Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->x_requires_grad) { return Maybe::Ok(); } const auto& y = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(BwFunc(y, out_grads.at(0))); return Maybe::Ok(); } protected: std::shared_ptr grad_op_; }; class UnaryMathBwdWithFillZeroOp : public OpExprGradFunction { Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->x_requires_grad) { return Maybe::Ok(); } in_grads->at(0) = JUST(functional::ZerosLike(out_grads[0])); return Maybe::Ok(); } protected: std::shared_ptr grad_op_; }; #define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS(op_type_name, op_cls) \ class op_cls##Cls final : public UnaryMathBwdWithDyXOp {}; \ REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS, MATH_UNARY_ELEMENTWISE_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ); #undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS #define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS(op_type_name, op_cls) \ class op_cls##Cls final : public UnaryMathBwdWithDyYOp {}; \ REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS, MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_DY_Y_SEQ); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS, OF_PP_MAKE_TUPLE_SEQ("tanh", Tanh)); #undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS #define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS(op_type_name, op_cls) \ class op_cls##Cls final : public UnaryMathBwdWithDyYOp {}; \ 
REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, UnaryMathBwdWithFillZeroOp); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS, MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_FILL_SEQ); #undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS class NegativeOp : public OpExprGradFunction { Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->x_requires_grad) { return Maybe::Ok(); } in_grads->at(0) = JUST(functional::Negative(out_grads[0])); return Maybe::Ok(); } protected: std::shared_ptr grad_op_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("negative", NegativeOp); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/matmul.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct MatmulCaptureState : public AutoGradCaptureState { bool transpose_a; bool transpose_b; double alpha; bool requires_grad_a; bool requires_grad_b; size_t a_index; size_t b_index; }; class Matmul : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(MatmulCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; protected: AttrMap base_attrs_; }; Maybe Matmul::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Matmul::Capture(MatmulCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad_a = inputs.at(0)->requires_grad(); ctx->requires_grad_b = inputs.at(1)->requires_grad(); if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->transpose_a = JUST(composed_attrs.GetAttr("transpose_a")); ctx->transpose_b = JUST(composed_attrs.GetAttr("transpose_b")); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); if (ctx->requires_grad_a) { ctx->b_index = ctx->SaveTensorForBackward(inputs.at(1)); // input b } if (ctx->requires_grad_b) { ctx->a_index = 
ctx->SaveTensorForBackward(inputs.at(0)); // input a } return Maybe::Ok(); } Maybe Matmul::Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->requires_grad_a) { const auto& input_b = ctx->SavedTensors().at(ctx->b_index); if (ctx->transpose_a) { in_grads->at(0) = JUST(functional::MatMul(input_b, out_grads.at(0), ctx->transpose_b, true, ctx->alpha)); } else { in_grads->at(0) = JUST( functional::MatMul(out_grads.at(0), input_b, false, !(ctx->transpose_b), ctx->alpha)); } } if (ctx->requires_grad_b) { const auto& input_a = ctx->SavedTensors().at(ctx->a_index); if (ctx->transpose_b) { in_grads->at(1) = JUST(functional::MatMul(out_grads.at(0), input_a, true, ctx->transpose_a, ctx->alpha)); } else { in_grads->at(1) = JUST( functional::MatMul(input_a, out_grads.at(0), !(ctx->transpose_a), false, ctx->alpha)); } } return Maybe::Ok(); } struct BroadcastMatmulCaptureState : public AutoGradCaptureState { bool transpose_a = false; bool transpose_b = false; double alpha = 1.0; bool requires_grad_a = true; bool requires_grad_b = true; size_t a_index = 0; size_t b_index = 1; bool broadcast_a = false; bool broadcast_b = false; int64_t b_num_axes = 0; }; class BroadcastMatmul : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BroadcastMatmulCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BroadcastMatmulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; protected: AttrMap base_attrs_; }; Maybe BroadcastMatmul::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. "; base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe BroadcastMatmul::Capture(BroadcastMatmulCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad(); ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad(); if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } const auto a_shape = JUST(VectorAt(inputs, 0))->shape(); const auto b_shape = JUST(VectorAt(inputs, 1))->shape(); const int64_t a_num_axes = a_shape->NumAxes(); const int64_t b_num_axes = b_shape->NumAxes(); const size_t num_max_batch_dims = std::max(a_num_axes, b_num_axes) - 2; auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) { const int64_t num_batch_dims = num_dims - 2; const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims; return [num_padding_dims, shape_dim](size_t index) { return index < num_padding_dims ? 
1 : shape_dim.At(index - num_padding_dims); }; }; auto GetABatchDim = MakeGetBatchDim(a_num_axes, *a_shape); auto GetBBatchDim = MakeGetBatchDim(b_num_axes, *b_shape); bool broadcast_a = false; bool broadcast_b = false; for (int32_t i = 0; i < num_max_batch_dims; i++) { if (GetABatchDim(i) < GetBBatchDim(i) || a_num_axes < b_num_axes) { broadcast_a = true; break; } } for (int32_t i = 0; i < num_max_batch_dims; i++) { if (GetBBatchDim(i) < GetABatchDim(i) || b_num_axes < a_num_axes) { broadcast_b = true; break; } } if (b_num_axes == 2 && !ctx->transpose_a) { // In this case, we can directly use `broadcast_matmul_grad_b` OP to generate Grad instead of // broadcast_matmul+reduce_sum_like. broadcast_b = false; } ctx->broadcast_a = broadcast_a; ctx->broadcast_b = broadcast_b; ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->transpose_a = JUST(composed_attrs.GetAttr("transpose_a")); ctx->transpose_b = JUST(composed_attrs.GetAttr("transpose_b")); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); if (ctx->requires_grad_a) { ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // input b if (broadcast_a) { ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input a } } if (ctx->requires_grad_b) { ctx->b_num_axes = JUST(VectorAt(inputs, 1))->shape()->NumAxes(); ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input a if (broadcast_b) { ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // input b } } return Maybe::Ok(); } Maybe BroadcastMatmul::Apply(const BroadcastMatmulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Out grad size should be equal to 1. "; in_grads->resize(2); const auto out_shape = JUST(VectorAt(out_grads, 0))->shape(); const int64_t out_num_axes = out_shape->NumAxes(); const size_t num_max_batch_dims = out_num_axes - 2; auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) { const int64_t num_batch_dims = num_dims - 2; const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims; return [num_padding_dims, shape_dim](size_t index) { return index < num_padding_dims ? 
1 : shape_dim.At(index - num_padding_dims); }; }; auto GetOutBatchDim = MakeGetBatchDim(out_num_axes, *out_shape); if (ctx->requires_grad_a) { std::shared_ptr broadcast_grad_a; const auto& input_b = ctx->SavedTensors().at(ctx->b_index); if (ctx->transpose_a) { broadcast_grad_a = JUST(functional::MatMul(input_b, JUST(VectorAt(out_grads, 0)), ctx->transpose_b, true, ctx->alpha)); } else { broadcast_grad_a = JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), input_b, false, !(ctx->transpose_b), ctx->alpha)); } if (ctx->broadcast_a) { const auto& input_a = JUST(VectorAt(ctx->SavedTensors(), ctx->a_index)); const auto a_shape = input_a->shape(); const int64_t a_num_axes = a_shape->NumAxes(); std::vector a_reduce_vec; auto GetABatchDim = MakeGetBatchDim(a_num_axes, *a_shape); const int64_t a_out_num_dim_differ = out_num_axes - a_num_axes; for (int32_t i = 0; i < out_num_axes - 2; i++) { if (GetOutBatchDim(i) > GetABatchDim(i) || (GetOutBatchDim(i) == 1 && i < a_out_num_dim_differ)) { a_reduce_vec.push_back(i); } } JUST(VectorAt(*in_grads, 0)) = JUST(functional::ReduceSumLike(broadcast_grad_a, input_a, a_reduce_vec)); } else { JUST(VectorAt(*in_grads, 0)) = broadcast_grad_a; } } if (ctx->requires_grad_b) { const auto& input_a = ctx->SavedTensors().at(ctx->a_index); if (ctx->b_num_axes == 2 && !ctx->transpose_a) { if (ctx->transpose_b) { JUST(VectorAt(*in_grads, 1)) = JUST( functional::BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), input_a, ctx->alpha)); } else { JUST(VectorAt(*in_grads, 1)) = JUST( functional::BroadcastMatmulGradB(input_a, JUST(VectorAt(out_grads, 0)), ctx->alpha)); } } else { std::shared_ptr broadcast_grad_b; if (ctx->transpose_b) { broadcast_grad_b = JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), input_a, true, ctx->transpose_a, ctx->alpha)); } else { broadcast_grad_b = JUST(functional::MatMul(input_a, JUST(VectorAt(out_grads, 0)), !ctx->transpose_a, false, ctx->alpha)); } if (ctx->broadcast_b) { const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index)); const auto b_shape = input_b->shape(); std::vector b_reduce_vec; auto GetBBatchDim = MakeGetBatchDim(ctx->b_num_axes, *b_shape); const int64_t b_out_num_dim_differ = out_num_axes - ctx->b_num_axes; for (int32_t i = 0; i < out_num_axes - 2; i++) { if (GetOutBatchDim(i) > GetBBatchDim(i) || (GetOutBatchDim(i) == 1 && i < b_out_num_dim_differ)) { b_reduce_vec.push_back(i); } } JUST(VectorAt(*in_grads, 1)) = JUST(functional::ReduceSumLike(broadcast_grad_b, input_b, b_reduce_vec)); } else { JUST(VectorAt(*in_grads, 1)) = broadcast_grad_b; } } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("matmul", Matmul); REGISTER_OP_EXPR_GRAD_FUNCTION("batch_matmul", Matmul); REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_matmul", BroadcastMatmul); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/matrix_vector_product.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
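// Editor's sketch (not repository code) for the matmul backward above. For C = A * B with no
// transposes, the two products formed by Matmul::Apply reduce to dA = dY * B^T and dB = A^T * dY
// (each scaled by alpha); the transpose_a / transpose_b branches only move the transposes around,
// and BroadcastMatmul additionally reduces the result back to the operand shape when batch
// dimensions were broadcast. A tiny 2x2 plain-double example of the non-transposed case:
#include <cstdio>
int main() {
  double A[2][2] = {{1, 2}, {3, 4}}, B[2][2] = {{5, 6}, {7, 8}}, dY[2][2] = {{1, 0}, {0, 1}};
  double dA[2][2] = {}, dB[2][2] = {};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k) {
        dA[i][j] += dY[i][k] * B[j][k];  // dA = dY * B^T
        dB[i][j] += A[k][i] * dY[k][j];  // dB = A^T * dY
      }
  std::printf("dA = [[%g, %g], [%g, %g]]\n", dA[0][0], dA[0][1], dA[1][0], dA[1][1]);
  std::printf("dB = [[%g, %g], [%g, %g]]\n", dB[0][0], dB[0][1], dB[1][0], dB[1][1]);
  return 0;
}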
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct MatrixVectorProductCaptureState : public AutoGradCaptureState { bool requires_grad_a = false; bool requires_grad_b = false; size_t a_index = 0; size_t b_index = 1; }; class MatrixVectorProduct : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(MatrixVectorProductCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const MatrixVectorProductCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; protected: AttrMap base_attrs_; }; Maybe MatrixVectorProduct::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. "; base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe MatrixVectorProduct::Capture(MatrixVectorProductCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad(); ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad(); if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); if (ctx->requires_grad_a) { ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // input b } if (ctx->requires_grad_b) { ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input a } return Maybe::Ok(); } Maybe MatrixVectorProduct::Apply(const MatrixVectorProductCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Out grad size should be equal to 1. "; in_grads->resize(2); if (ctx->requires_grad_a) { const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index)); JUST(VectorAt(*in_grads, 0)) = JUST(functional::MatrixVectorProductGradA(JUST(VectorAt(out_grads, 0)), input_b)); } if (ctx->requires_grad_b) { const auto& input_a = JUST(VectorAt(ctx->SavedTensors(), ctx->a_index)); JUST(VectorAt(*in_grads, 1)) = JUST(functional::MatrixVectorProductGradB(JUST(VectorAt(out_grads, 0)), input_a)); if (input_a->dtype()->is_complex()) { JUST(VectorAt(*in_grads, 1)) = JUST(functional::Conj(JUST(VectorAt(*in_grads, 1)))); } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("matrix_vector_product", MatrixVectorProduct); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/max_pool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
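// Editor's note on the matrix_vector_product backward above (hedged): for y = A * v with A of
// shape (m, n) and v of shape (n,), the gradients are dA = outer(dy, v), produced here by
// MatrixVectorProductGradA(dy, v), and dv = A^T * dy, produced by MatrixVectorProductGradB(dy, A).
// The trailing Conj is the complex-valued adjustment, where the vector gradient involves the
// conjugate transpose rather than the plain transpose.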
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { namespace { struct MaxPoolCaptureState : public AutoGradCaptureState { bool requires_grad = false; size_t input_index = 0; size_t indice_index = 0; std::string data_format; std::vector padding; std::vector kernel_size; std::vector stride; std::vector dilation; bool return_indices = false; bool ceil_mode = false; }; class MaxPoolNdGrad : public OpExprGradFunction { public: virtual ~MaxPoolNdGrad() = default; using OpExprGradFunction::Init; Maybe Init(const OpExpr& op) override; Maybe Capture(MaxPoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const MaxPoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe MaxPoolNdGrad::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe MaxPoolNdGrad::Capture(MaxPoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); ctx->indice_index = ctx->SaveTensorForBackward(outputs.at(1)); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->padding = JUST(composed_attrs.GetAttr>("padding")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->stride = JUST(composed_attrs.GetAttr>("stride")); ctx->dilation = JUST(composed_attrs.GetAttr>("dilation")); ctx->return_indices = JUST(composed_attrs.GetAttr("return_indices")); ctx->ceil_mode = JUST(composed_attrs.GetAttr("ceil_mode")); return Maybe::Ok(); } Maybe MaxPoolNdGrad::Apply(const MaxPoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_LE_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) int32_t ndims = ctx->kernel_size.size(); const auto& input = ctx->SavedTensors().at(ctx->input_index); const auto& indice = ctx->SavedTensors().at(ctx->indice_index); in_grads->resize(1); (*in_grads)[0] = JUST(functional::MaxPoolNdGrad( input, indice, out_grads[0], ndims, ctx->data_format, ctx->padding, ctx->kernel_size, ctx->stride, ctx->dilation, ctx->return_indices, ctx->ceil_mode)); return Maybe::Ok(); } } // namespace REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_1d", MaxPoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_2d", MaxPoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_3d", MaxPoolNdGrad); } // namespace one } // namespace oneflow 
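// Editor's note on the max_pool backward above: the forward's second output records, for every
// pooled element, the index of the input element that won the max; MaxPoolNdGrad is expected to
// scatter each dy value back to that recorded position and leave every other input position at
// zero. The original input is saved mainly so the gradient tensor can be built with the right
// shape and layout.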
================================================ FILE: oneflow/core/autograd/gradient_funcs/max_unpool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { namespace { struct MaxUnpoolCaptureState : public AutoGradCaptureState { bool requires_grad = false; size_t input_index = 0; size_t indices_index = 0; }; using FuncType = decltype(functional::MaxUnpool1dGrad); template class MaxUnpoolNdGrad : public OpExprGradFunction { public: virtual ~MaxUnpoolNdGrad() = default; using OpExprGradFunction::Init; Maybe Init(const OpExpr& op) override; Maybe Capture(MaxUnpoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const MaxUnpoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; template Maybe MaxUnpoolNdGrad::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } template Maybe MaxUnpoolNdGrad::Capture(MaxUnpoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); ctx->indices_index = ctx->SaveTensorForBackward(inputs.at(1)); return Maybe::Ok(); } template Maybe MaxUnpoolNdGrad::Apply(const MaxUnpoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_LE_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) const auto& input = ctx->SavedTensors().at(ctx->input_index); const auto& indices = ctx->SavedTensors().at(ctx->indices_index); in_grads->resize(2); (*in_grads)[0] = JUST(F(input, indices, out_grads[0])); return Maybe::Ok(); } } // namespace REGISTER_OP_EXPR_GRAD_FUNCTION("max_unpool_1d", MaxUnpoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("max_unpool_2d", MaxUnpoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("max_unpool_3d", MaxUnpoolNdGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/median.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct MedianCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class Median : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(MedianCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); } return Maybe::Ok(); } Maybe Apply(const MedianCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (ctx->requires_grad) { const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0)); const auto& output = JUST(VectorAt(ctx->SavedTensors(), 1)); const auto& dy = JUST(VectorAt(out_grads, 0)); std::vector axis(input->ndim()); std::iota(axis.begin(), axis.end(), 0); const auto cast_like = JUST(functional::SequenceFunction()>( [&]() { return functional::BroadcastLike(output, input, axis); }) .then(std::bind(functional::BroadcastEqual, input, std::placeholders::_1)) .then(std::bind(functional::CastLike, std::placeholders::_1, input)) .call()); const auto bcast_like_div = JUST(functional::SequenceFunction()>( [&]() { return functional::ReduceSum(cast_like, axis, false, NullOpt); }) .then(std::bind(functional::Div, dy, std::placeholders::_1)) .then(std::bind(functional::BroadcastLike, std::placeholders::_1, input, axis)) .call()); in_grads->resize(1); JUST(VectorAt(*in_grads, 0)) = JUST(functional::Mul(bcast_like_div, cast_like)); } return Maybe::Ok(); } }; struct MedianWithIndicesCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class MedianWithIndices : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(MedianWithIndicesCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 1))); } return Maybe::Ok(); } Maybe Apply(const MedianWithIndicesCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (ctx->requires_grad) { in_grads->resize(1); const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0)); const auto& indices = JUST(functional::Unsqueeze(JUST(VectorAt(ctx->SavedTensors(), 1)), -1)); const auto& dout = JUST(functional::Unsqueeze(JUST(VectorAt(out_grads, 0)), -1)); JUST(VectorAt(*in_grads, 0)) = JUST(functional::DimScatterUpdate( JUST(functional::Constant(*(input->shape()), Scalar(0), *dout->dtype(), JUST(dout->device()))), -1, indices, dout, /*inplace*/ false)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("median", Median); REGISTER_OP_EXPR_GRAD_FUNCTION("median_with_indices", MedianWithIndices); } // namespace one } // namespace oneflow 
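// Editor's note on the median backwards above: for the reduce-all `median` op, the chain
// BroadcastLike -> BroadcastEqual -> CastLike builds a {0, 1} mask of every element equal to the
// median, ReduceSum counts those ties, and dy is divided by that count before being re-masked, so
// the incoming gradient is split evenly across all tied elements. `median_with_indices` instead
// scatters dy into a zero tensor at the recorded index along the last dimension via
// DimScatterUpdate.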
================================================ FILE: oneflow/core/autograd/gradient_funcs/mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct ModeCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class Mode : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(ModeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 1))); } return Maybe::Ok(); } Maybe Apply(const ModeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (ctx->requires_grad) { in_grads->resize(1); const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0)); const auto& indices = JUST(functional::Unsqueeze(JUST(VectorAt(ctx->SavedTensors(), 1)), -1)); const auto& dout = JUST(functional::Unsqueeze(JUST(VectorAt(out_grads, 0)), -1)); JUST(VectorAt(*in_grads, 0)) = JUST(functional::DimScatterUpdate( JUST(functional::Constant(*(input->shape()), Scalar(0), *dout->dtype(), JUST(dout->device()))), -1, indices, dout, /*inplace*/ false)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("mode", Mode); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/narrow.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/framework/nd_sbp.h" namespace oneflow { namespace one { struct NarrowCaptureState : public AutoGradCaptureState { bool requires_grad; Shape shape; int64_t dim; int64_t start; int64_t length; }; class Narrow : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(NarrowCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dim = JUST(composed_attrs.GetAttr("dim")); ctx->start = JUST(composed_attrs.GetAttr("start")); ctx->length = JUST(composed_attrs.GetAttr("length")); if (LazyMode::is_enabled()) { ctx->SaveTensorForBackward(inputs.at(0)); } else { ctx->shape = *(inputs.at(0)->shape()); } return Maybe::Ok(); } Maybe Apply(const NarrowCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& dy = out_grads.at(0); if (ctx->requires_grad) { std::shared_ptr like; if (LazyMode::is_enabled()) { like = ctx->SavedTensors().at(0); } else if (dy->is_local()) { like = JUST(functional::Empty(ctx->shape, dy->dtype(), JUST(dy->device()), ctx->requires_grad, /*pin_memory=*/false)); } else { like = JUST( functional::GlobalEmpty(ctx->shape, dy->dtype(), JUST(dy->parallel_desc()), *JUST(private_details::RawGetSbpList(JUST(dy->nd_sbp()))))); } in_grads->resize(1); in_grads->at(0) = JUST(functional::NarrowGrad(dy, like, ctx->dim, ctx->start, ctx->length)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("narrow", Narrow); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/nll.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct NLLCaptureState : public AutoGradCaptureState { bool requires_grad = false; int64_t ignore_index = -100; }; class NLLGradFunction : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe NLLGradFunction::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe NLLGradFunction::Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { auto input = JUST(VectorAt(inputs, 0)); ctx->requires_grad = input->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->ignore_index = JUST(composed_attrs.GetAttr("ignore_index")); ctx->SaveTensorForBackward(input); // input ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // target if (inputs.size() == 3) { ctx->SaveTensorForBackward(inputs[2]); // weight } return Maybe::Ok(); } Maybe NLLGradFunction::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_GE_OR_RETURN(ctx->SavedTensors().size(), 2) << Error::RuntimeError() << "The number of saved tensors is expected to be greater than or equal to 2, but got " << ctx->SavedTensors().size(); const auto& out_grad = out_grads[0]; const auto& input = ctx->SavedTensors()[0]; const auto& target = ctx->SavedTensors()[1]; in_grads->resize(ctx->SavedTensors().size()); if (ctx->SavedTensors().size() == 2) { JUST(VectorAt(*in_grads, 0)) = JUST(functional::NLLGrad(out_grad, input, target, NullOpt, ctx->ignore_index)); } else { // has weight auto weight = JUST(VectorAt(ctx->SavedTensors(), 2)); JUST(VectorAt(*in_grads, 0)) = JUST(functional::NLLGrad(out_grad, input, target, weight, ctx->ignore_index)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("nll", NLLGradFunction); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/noncontiguous_binary_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================
FILE: oneflow/core/autograd/gradient_funcs/noncontiguous_binary_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <string>  // NOTE: the bracketed header name was lost in extraction; <string> is assumed here for std::string
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"

namespace oneflow {
namespace one {

struct NonContiguousBinaryOpCaptureState : public AutoGradCaptureState {
  bool lhs_requires_grad = false;
  bool rhs_requires_grad = false;
  std::string op = "add";
  bool inplace = false;
};

class NonContiguousBinaryOp : public OpExprGradFunction<NonContiguousBinaryOpCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(NonContiguousBinaryOpCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const NonContiguousBinaryOpCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> NonContiguousBinaryOp::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> NonContiguousBinaryOp::Capture(NonContiguousBinaryOpCaptureState* ctx,
                                           const TensorTuple& inputs, const TensorTuple& outputs,
                                           const AttrMap& attrs) const {
  ctx->lhs_requires_grad = inputs.at(0)->requires_grad();
  ctx->rhs_requires_grad = inputs.at(1)->requires_grad();
  if (!ctx->lhs_requires_grad && !ctx->rhs_requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->inplace = JUST(composed_attrs.GetAttr<bool>("inplace"));
  ctx->op = JUST(composed_attrs.GetAttr<std::string>("op"));
  if (ctx->inplace && ctx->rhs_requires_grad) {
    CHECK_OR_RETURN(ctx->op == "add" || ctx->op == "sub")
        << "when inplace and rhs requires grad, op should be add/sub";
  }
  ctx->SaveTensorForBackward(inputs.at(0));
  ctx->SaveTensorForBackward(inputs.at(1));
  return Maybe<void>::Ok();
}

Maybe<void> NonContiguousBinaryOp::Apply(const NonContiguousBinaryOpCaptureState* ctx,
                                         const TensorTuple& out_grads,
                                         TensorTuple* in_grads) const {
  if (!ctx->lhs_requires_grad && !ctx->rhs_requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(2);
  auto lhs = ctx->SavedTensors().at(0);
  auto rhs = ctx->SavedTensors().at(1);
  auto ret = JUST(functional::NonContiguousBinaryOpGrad(out_grads.at(0), lhs, rhs, ctx->op, false));
  if (ctx->lhs_requires_grad) in_grads->at(0) = ret->at(0);
  if (ctx->rhs_requires_grad) in_grads->at(1) = ret->at(1);
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("noncontiguous_binary_op", NonContiguousBinaryOp);

}  // namespace one
}  // namespace oneflow
================================================
FILE: oneflow/core/autograd/gradient_funcs/normalization.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct NormalizationGradCaptureState : public AutoGradCaptureState {
  int32_t axis;
  float epsilon;
  bool track_running_stats;
  bool is_training;
  bool x_requires_grad;
  bool gamma_requires_grad;
  bool beta_requires_grad;
};

// training:
//   y, mean, inv_variance = normalization(x, moving_mean, moving_variance, gamma, beta,
//                                         axis=1, epsilon=0.01, momentum=0.9)
//   y, mean, inv_variance = normalization(x, gamma, beta, axis=1, epsilon=0.01, momentum=0.9)
// inference:
//   y = normalization(x, moving_mean, moving_variance, gamma, beta, axis=1, epsilon=0.01,
//                     momentum=0.9)
class NormalizationGrad : public OpExprGradFunction<NormalizationGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(NormalizationGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // input_size may be 3 or 5, as inputs may be
    //   (x, gamma, beta) or (x, moving_mean, moving_variance, gamma, beta)
    //   ref to track_running_stats false/true
    // output_size may be 1 or 3, as outputs may be
    //   (x, ) or (x, mean, inv_variance)
    //   ref to is_training false/true
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    std::shared_ptr<Tensor> gamma, beta;
    if (inputs.size() == 3) {
      gamma = inputs.at(1);
      beta = inputs.at(2);
      ctx->track_running_stats = false;
    } else {
      CHECK_EQ_OR_RETURN(inputs.size(), 5);  // NOLINT(maybe-need-error-msg)
      gamma = inputs.at(3);
      beta = inputs.at(4);
      ctx->track_running_stats = true;
    }
    ctx->gamma_requires_grad = gamma->requires_grad();
    ctx->beta_requires_grad = beta->requires_grad();
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->axis = JUST(composed_attrs.GetAttr<int32_t>("axis"));
    ctx->epsilon = JUST(composed_attrs.GetAttr<float>("epsilon"));
    ctx->is_training = JUST(composed_attrs.GetAttr<bool>("training"));
    ctx->SaveTensorForBackward(inputs.at(0));  // x
    ctx->SaveTensorForBackward(gamma);         // gamma
    if (ctx->is_training || !ctx->track_running_stats) {
      ctx->SaveTensorForBackward(outputs.at(1));  // mean
      ctx->SaveTensorForBackward(outputs.at(2));  // inv_variance
    } else {
      ctx->SaveTensorForBackward(inputs.at(1));  // moving_mean
      ctx->SaveTensorForBackward(inputs.at(2));  // moving_variance
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const NormalizationGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& x = ctx->SavedTensors().at(0);      // x
    const auto& gamma = ctx->SavedTensors().at(1);  // gamma
    const auto& y_grad = out_grads.at(0);
    std::shared_ptr<Tensor> mean, inv_variance;
    if (ctx->is_training || !ctx->track_running_stats) {
      mean = ctx->SavedTensors().at(2);          // mean
      inv_variance = ctx->SavedTensors().at(3);  // inv_variance
    } else {
      const auto& moving_mean = ctx->SavedTensors().at(2);      // moving_mean
      const auto& moving_variance = ctx->SavedTensors().at(3);  // moving_variance
      const auto& add_eps = JUST(
          functional::ScalarAdd(moving_variance, ctx->epsilon, /*alpha=*/1, /*inplace=*/false));
      mean = moving_mean;
      inv_variance = JUST(functional::Rsqrt(add_eps));
    }
    const auto& results = JUST(functional::NormalizationGrad(y_grad, x, mean, inv_variance, gamma,
                                                             ctx->epsilon, ctx->axis));
    CHECK_EQ_OR_RETURN(results->size(), 3)
        << Error::RuntimeError() << "The number of results is expected to be 3, but got "
        << results->size();
    if (ctx->track_running_stats) {
      // The normalization op has 5 inputs which are x, moving_mean, moving_variance, gamma and
      // beta.
      in_grads->resize(5);
      if (ctx->gamma_requires_grad) {
        in_grads->at(3) = results->at(1);  // gamma_diff
      }
      if (ctx->beta_requires_grad) {
        in_grads->at(4) = results->at(2);  // beta_diff
      }
    } else {
      // The normalization op has 3 inputs which are x, gamma and beta.
      in_grads->resize(3);
      if (ctx->gamma_requires_grad) {
        in_grads->at(1) = results->at(1);  // gamma_diff
      }
      if (ctx->beta_requires_grad) {
        in_grads->at(2) = results->at(2);  // beta_diff
      }
    }
    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }
    if (ctx->is_training) {
      in_grads->at(0) = results->at(0);
      return Maybe<void>::Ok();
    }
    Shape shape;
    for (int i = 0; i < x->shape()->NumAxes(); ++i) {
      if (i != ctx->axis) {
        shape.emplace_back(1);
      } else {
        shape.emplace_back(x->shape()->At(ctx->axis));
      }
    }
    const auto& reshaped_gamma = JUST(functional::Reshape(gamma, shape));
    const auto& reshaped_inv_variance = JUST(functional::Reshape(inv_variance, shape));
    std::shared_ptr<Tensor> y_grad_fp32 = y_grad;
    bool is_fp16 = y_grad->dtype()->data_type() == DataType::kFloat16;
    if (is_fp16) {
      y_grad_fp32 = JUST(functional::Cast(y_grad, DType::Float(), /*pin_memory=*/false));
    }
    const auto& dy_mul_gamma = JUST(functional::Mul(reshaped_gamma, y_grad_fp32));
    const auto& dy_mul_inv_var = JUST(functional::Mul(dy_mul_gamma, reshaped_inv_variance));
    if (is_fp16) {
      (*in_grads)[0] =
          JUST(functional::Cast(dy_mul_inv_var, DType::Float16(), /*pin_memory=*/false));
    } else {
      (*in_grads)[0] = dy_mul_inv_var;
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("normalization", NormalizationGrad);

}  // namespace one
}  // namespace oneflow
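A brief, hedged illustration of the inference-mode branch above (not part of the OneFlow sources): when "training" is false and running statistics are tracked, the saved statistics are constants with respect to x, so the input gradient reduces to a per-channel rescaling dx = dy * gamma * rsqrt(moving_variance + epsilon), which is exactly what the reshape-and-multiply sequence computes (with an optional float16 round trip). A standalone sketch of that arithmetic:

// Illustrative only: scalar stand-in for one channel of the eval-mode batch-norm backward.
#include <cmath>
#include <cstdio>

int main() {
  const float epsilon = 0.01f;         // matches the "epsilon" attr captured above
  const float gamma = 2.0f;            // per-channel scale
  const float moving_variance = 4.0f;  // running variance (a constant in eval mode)
  const float dy = 1.5f;               // incoming gradient for one element of this channel
  const float inv_variance = 1.0f / std::sqrt(moving_variance + epsilon);  // Rsqrt(var + eps)
  std::printf("dx = %g\n", dy * gamma * inv_variance);  // dy scaled per channel
  return 0;
}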
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct NormalizationAddReluGradCaptureState : public AutoGradCaptureState { int32_t axis = 1; float epsilon = 1e-5; bool track_running_stats = true; bool is_training = true; bool has_addend = false; bool x_requires_grad = true; bool addend_requires_grad = true; bool gamma_requires_grad = true; bool beta_requires_grad = true; }; // training: // y, mean, inv_variance = normalization_add_relu(x, Optional(add_end), moving_mean, // moving_variance, gamma, beta, axis=1, epsilon=0.01, momentum=0.9) y, mean, inv_variance = // normalization_add_relu(x, Optional(add_end), gamma, beta, axis=1, epsilon=0.01, momentum=0.9) // inference: // y = normalization_add_relu(x, Optional(add_end), moving_mean, moving_variance, gamma, beta, // axis=1, epsilon=0.01, momentum=0.9) class NormalizationAddReluGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(NormalizationAddReluGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // input_size may be 3/4/5/6, as inputs may be // (x, gamma, beta) or (x, moving_mean, moving_variance, gamma, beta) // (x, addend, gamma, beta) or (x, addend, moving_mean, moving_variance, gamma, beta) // ref to track_running_stats false/true // output_size may be 2 or 4, as outputs may be // (x, reserve_space) or (x, reserve_space, mean, inv_variance) // ref to is_training false/true ctx->x_requires_grad = inputs.at(0)->requires_grad(); std::shared_ptr add_end, gamma, beta; if (inputs.size() == 3 || inputs.size() == 5) { add_end = nullptr; if (inputs.size() == 3) { gamma = inputs.at(1); beta = inputs.at(2); ctx->track_running_stats = false; } else { gamma = inputs.at(3); beta = inputs.at(4); ctx->track_running_stats = true; } ctx->has_addend = false; } else if (inputs.size() == 4 || inputs.size() == 6) { add_end = inputs.at(1); if (inputs.size() == 4) { gamma = inputs.at(2); beta = inputs.at(3); ctx->track_running_stats = false; } else { gamma = inputs.at(4); beta = inputs.at(5); ctx->track_running_stats = true; } ctx->has_addend = true; ctx->addend_requires_grad = inputs.at(1)->requires_grad(); } ctx->gamma_requires_grad = gamma->requires_grad(); ctx->beta_requires_grad = beta->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); ctx->epsilon = JUST(composed_attrs.GetAttr("epsilon")); ctx->is_training = JUST(composed_attrs.GetAttr("training")); ctx->SaveTensorForBackward(inputs.at(0)); // x 0 ctx->SaveTensorForBackward(gamma); // gamma 1 ctx->SaveTensorForBackward(beta); // beta 2 if (ctx->is_training || !ctx->track_running_stats) { ctx->SaveTensorForBackward(outputs.at(2)); // mean 3 ctx->SaveTensorForBackward(outputs.at(3)); // inv_variance 4 } else { if (inputs.size() == 5) { // without add_end ctx->SaveTensorForBackward(inputs.at(1)); // moving_mean 3 ctx->SaveTensorForBackward(inputs.at(2)); // moving_variance 4 } else { CHECK_EQ_OR_RETURN(inputs.size(), 6); // NOLINT(maybe-need-error-msg) // with add_end ctx->SaveTensorForBackward(inputs.at(2)); // moving_mean 3 
ctx->SaveTensorForBackward(inputs.at(3)); // moving_variance 4 } } ctx->SaveTensorForBackward(outputs.at(0)); // y 5 ctx->SaveTensorForBackward(outputs.at(1)); // reserve space 6 return Maybe::Ok(); } Maybe Apply(const NormalizationAddReluGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& x = ctx->SavedTensors().at(0); // x const auto& gamma = ctx->SavedTensors().at(1); // gamma const auto& beta = ctx->SavedTensors().at(2); // beta const auto& y_grad = out_grads.at(0); std::shared_ptr mean, inv_variance; if (ctx->is_training || !ctx->track_running_stats) { mean = ctx->SavedTensors().at(3); // mean inv_variance = ctx->SavedTensors().at(4); // inv_variance } else { const auto& moving_mean = ctx->SavedTensors().at(3); // moving_mean const auto& moving_variance = ctx->SavedTensors().at(4); // moving_variance const auto& add_eps = JUST( functional::ScalarAdd(moving_variance, ctx->epsilon, /*alpha=*/1, /*inplace=*/false)); mean = moving_mean; inv_variance = JUST(functional::Rsqrt(add_eps)); } const auto& y = ctx->SavedTensors().at(5); const auto& reserve_space = ctx->SavedTensors().at(6); const auto& results = JUST(functional::NormalizationAddReluGrad( x, y_grad, mean, inv_variance, gamma, beta, reserve_space, y, ctx->axis, ctx->epsilon, ctx->has_addend)); CHECK_EQ_OR_RETURN(results->size(), (ctx->has_addend ? 4 : 3)) << Error::RuntimeError() << "The number of results is expected to be " << (ctx->has_addend ? 4 : 3) << ", but got " << results->size(); // here output includes "gamma_diff" "beta_diff" "dx" "addend_diff" if (ctx->track_running_stats) { // The normalization op has 5 inputs which are x, moving_mean, moving_variance, gamma and // beta. or 6 inputs: x, add_end, moving_mean, moving_variance, gamma and beta. if (ctx->has_addend) { in_grads->resize(6); if (ctx->gamma_requires_grad) { in_grads->at(4) = results->at(1); // gamma_diff; } if (ctx->beta_requires_grad) { in_grads->at(5) = results->at(2); // beta_diff } if (ctx->addend_requires_grad) { in_grads->at(1) = results->at(3); // add_end_diff } } else { in_grads->resize(5); if (ctx->gamma_requires_grad) { in_grads->at(3) = results->at(1); // gamma_diff; } if (ctx->beta_requires_grad) { in_grads->at(4) = results->at(2); // beta_diff } } } else { // The normalization op has 3 inputs which are x, addend, gamma and beta. // or has 4 inputs which are x, addend, gamma and beta. if (ctx->has_addend) { in_grads->resize(4); if (ctx->addend_requires_grad) { in_grads->at(1) = results->at(3); // addend_diff } if (ctx->gamma_requires_grad) { in_grads->at(2) = results->at(1); // gamma_diff; } if (ctx->beta_requires_grad) { in_grads->at(3) = results->at(2); // beta_diff } } else { in_grads->resize(3); if (ctx->gamma_requires_grad) { in_grads->at(1) = results->at(1); // gamma_diff; } if (ctx->beta_requires_grad) { in_grads->at(2) = results->at(2); // beta_diff } } } if (!ctx->x_requires_grad) { return Maybe::Ok(); } if (ctx->is_training) { in_grads->at(0) = results->at(0); return Maybe::Ok(); } // todo(zzk): add eval mode. return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("normalization_add_relu", NormalizationAddReluGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/one_embedding_fused_lookup.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
================================================
FILE: oneflow/core/autograd/gradient_funcs/one_embedding_fused_lookup.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"

namespace oneflow {
namespace one {

struct OneEmbeddingFusedLookupCaptureState : public AutoGradCaptureState {
  bool requires_grad{};
  std::string embedding_name{};
  int64_t line_size{};
  int64_t embedding_size{};
  int shadow_index{};
  int ids_index{};
  int input_num{};
};

class OneEmbeddingFusedLookup : public OpExprGradFunction<OneEmbeddingFusedLookupCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(OneEmbeddingFusedLookupCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_GE_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();            // shadow
    ctx->shadow_index = ctx->SaveTensorForBackward(inputs.at(0));  // shadow
    ctx->ids_index = ctx->SaveTensorForBackward(inputs.at(1));     // id
    ctx->embedding_name = JUST(attrs.GetAttr<std::string>("embedding_name"));
    ctx->line_size = JUST(attrs.GetAttr<int64_t>("line_size"));
    ctx->embedding_size = JUST(attrs.GetAttr<int64_t>("embedding_size"));
    ctx->input_num = inputs.size();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const OneEmbeddingFusedLookupCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(ctx->input_num);
    const auto& saved_tensors = ctx->SavedTensors();
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    if (ctx->requires_grad) {
      JUST(functional::OneEmbeddingFusedLookupGrad(
          saved_tensors.at(ctx->ids_index), JUST(VectorAt(out_grads, 0)), ctx->embedding_name,
          ctx->line_size, ctx->embedding_size));
      (*in_grads)[0] = JUST(functional::ZerosLike(saved_tensors.at(ctx->shadow_index)));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("one_embedding_fused_lookup", OneEmbeddingFusedLookup);

}  // namespace one
}  // namespace oneflow
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct PadNdCaptureState : public AutoGradCaptureState { bool requires_grad = false; std::vector paddings{}; }; class PadNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(PadNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->paddings = JUST(composed_attrs.GetAttr>("padding")); return Maybe::Ok(); } private: AttrMap base_attrs_; }; class ReflectionPadNd : public PadNd { public: Maybe Apply(const PadNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { (*in_grads)[0] = JUST(functional::PadGrad(JUST(VectorAt(out_grads, 0)), ctx->paddings, "reflect", 0)); } return Maybe::Ok(); } }; class ReplicationPadNd : public PadNd { public: Maybe Apply(const PadNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { (*in_grads)[0] = JUST(functional::PadGrad(JUST(VectorAt(out_grads, 0)), ctx->paddings, "replicate", 0)); } return Maybe::Ok(); } }; struct ConstantPadNdCaptureState : public AutoGradCaptureState { bool requires_grad; std::vector paddings; }; class ConstantPadNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ConstantPadNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& input_0 = JUST(VectorAt(inputs, 0)); ctx->requires_grad = input_0->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->paddings = JUST(composed_attrs.GetAttr>("padding")); for (int i = 0; i < ctx->paddings.size(); i++) { ctx->paddings[i] = -ctx->paddings[i]; } return Maybe::Ok(); } Maybe Apply(const ConstantPadNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { (*in_grads)[0] = JUST(functional::Pad(JUST(VectorAt(out_grads, 0)), ctx->paddings, "constant", Scalar(0))); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("pad", ConstantPadNd); REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad1d", 
ReflectionPadNd); REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad2d", ReflectionPadNd); REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad1d", ReplicationPadNd); REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad2d", ReplicationPadNd); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/partial_fc_sample.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct PartialFCSampleState : public AutoGradCaptureState { bool requires_grad = false; int32_t index_sampled_label = -1; int32_t index_weight = -1; }; class PartialFCSample : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(PartialFCSampleState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe PartialFCSample::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe PartialFCSample::Capture(PartialFCSampleState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->index_sampled_label = ctx->SaveTensorForBackward(outputs.at(1)); // sampled_label ctx->index_weight = ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe PartialFCSample::Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 3); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& diff_sampled_weight = out_grads.at(2); // diff of sampled_weight const auto& sampled_tensor = ctx->SavedTensors().at(ctx->index_sampled_label); const auto& weight = ctx->SavedTensors().at(ctx->index_weight); const auto& out_tensors_of_op0 = JUST( functional::DistributedPariticalFCSampleDisableBoxing(diff_sampled_weight, sampled_tensor)); const auto& out_tensors_of_op1 = JUST(functional::UnsortedSegmentSumLike( out_tensors_of_op0->at(0), out_tensors_of_op0->at(1), weight, 0)); in_grads->at(0) = out_tensors_of_op1; return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("distributed_partial_fc_sample", PartialFCSample); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/reduce_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
================================================
FILE: oneflow/core/autograd/gradient_funcs/reduce_ops.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"

namespace oneflow {
namespace one {

struct ReduceSumCaptureState : public AutoGradCaptureState {
  std::vector<int32_t> axis;
};

class ReduceSum : public OpExprGradFunction<ReduceSumCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(ReduceSumCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const ReduceSumCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> ReduceSum::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> ReduceSum::Capture(ReduceSumCaptureState* ctx, const TensorTuple& inputs,
                               const TensorTuple& outputs, const AttrMap& attrs) const {
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> ReduceSum::Apply(const ReduceSumCaptureState* ctx, const TensorTuple& out_grads,
                             TensorTuple* in_grads) const {
  const auto& input = ctx->SavedTensors().at(0);
  const auto& dy = out_grads.at(0);
  in_grads->resize(1);
  in_grads->at(0) = JUST(functional::BroadcastLike(dy, input, ctx->axis));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum", ReduceSum);
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_nansum", ReduceSum);

struct ReduceProdOpInterpState : public AutoGradCaptureState {
  std::vector<int32_t> axis;
  bool requires_grad;
};

class ReduceProdOp : public OpExprGradFunction<ReduceProdOpInterpState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(ReduceProdOpInterpState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const ReduceProdOpInterpState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> ReduceProdOp::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> ReduceProdOp::Capture(ReduceProdOpInterpState* ctx, const TensorTuple& inputs,
                                  const TensorTuple& outputs, const AttrMap& attrs) const {
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
  ctx->requires_grad = inputs.at(0)->requires_grad();
  ctx->SaveTensorForBackward(inputs.at(0));
  ctx->SaveTensorForBackward(outputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> ReduceProdOp::Apply(const ReduceProdOpInterpState* ctx, const TensorTuple& out_grads,
                                TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  const auto& input = ctx->SavedTensors().at(0);
  const auto& output = ctx->SavedTensors().at(1);
  const auto& dy = out_grads.at(0);
  in_grads->resize(1);
  in_grads->at(0) = JUST(
      functional::SequenceFunction<Maybe<Tensor>()>([&]() { return functional::Mul(dy, output); })
          .then(std::bind(functional::BroadcastLike, std::placeholders::_1, input, ctx->axis))
          .then(std::bind(functional::Div, std::placeholders::_1, input))
          .call());
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_prod", ReduceProdOp);

struct ReduceMaxOrMinCaptureState : public AutoGradCaptureState {
  std::vector<int32_t> axis;
  bool keepdims;
};

class ReduceMaxOrMin : public OpExprGradFunction<ReduceMaxOrMinCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(ReduceMaxOrMinCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const ReduceMaxOrMinCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> ReduceMaxOrMin::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> ReduceMaxOrMin::Capture(ReduceMaxOrMinCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs, const AttrMap& attrs) const {
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
  ctx->keepdims = JUST(composed_attrs.GetAttr<bool>("keepdims"));
  ctx->SaveTensorForBackward(inputs.at(0));
  ctx->SaveTensorForBackward(outputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> ReduceMaxOrMin::Apply(const ReduceMaxOrMinCaptureState* ctx,
                                  const TensorTuple& out_grads, TensorTuple* in_grads) const {
  const auto& input = ctx->SavedTensors().at(0);
  const auto& output = ctx->SavedTensors().at(1);
  const auto& dy = out_grads.at(0);
  const auto cast_like =
      JUST(functional::SequenceFunction<Maybe<Tensor>()>(
               [&]() { return functional::BroadcastLike(output, input, ctx->axis); })
               .then(std::bind(functional::BroadcastEqual, input, std::placeholders::_1))
               .then(std::bind(functional::CastLike, std::placeholders::_1, input))
               .call());
  const auto& bcast_like_div =
      JUST(functional::SequenceFunction<Maybe<Tensor>()>([&]() {
             return functional::ReduceSum(cast_like, ctx->axis, ctx->keepdims, NullOpt);
           })
               .then(std::bind(functional::Div, dy, std::placeholders::_1))
               .then(std::bind(functional::BroadcastLike, std::placeholders::_1, input, ctx->axis))
               .call());
  in_grads->resize(1);
  in_grads->at(0) = JUST(functional::Mul(bcast_like_div, cast_like));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min", ReduceMaxOrMin);
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max", ReduceMaxOrMin);

}  // namespace one
}  // namespace oneflow
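A hedged, self-contained sketch of the reduce_max / reduce_min backward above (not part of the OneFlow sources): the saved output is broadcast back to the input shape, elements equal to the extremum form a mask, the incoming gradient is divided by the number of tied extrema along the reduced axes, and the result is scattered through the mask. For input {1, 3, 3} reduced over its only axis with dy = 0.6, the gradient is {0, 0.3, 0.3}:

// Illustrative only: 1-D reduce_max backward with ties, mirroring the SequenceFunction chain above.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> input = {1.f, 3.f, 3.f};
  const float reduced_max = 3.f;  // forward output saved by Capture
  const float dy = 0.6f;          // incoming gradient of the reduced value
  std::vector<float> mask(input.size()), grad(input.size());
  float count = 0.f;
  for (size_t i = 0; i < input.size(); ++i) {
    mask[i] = (input[i] == reduced_max) ? 1.f : 0.f;  // BroadcastEqual + CastLike
    count += mask[i];                                 // ReduceSum of the mask
  }
  for (size_t i = 0; i < input.size(); ++i) {
    grad[i] = mask[i] * (dy / count);  // Div + BroadcastLike + Mul
    std::printf("grad[%zu] = %g\n", i, grad[i]);
  }
  return 0;
}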
================================================
FILE: oneflow/core/autograd/gradient_funcs/reduce_sum_like.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"

namespace oneflow {
namespace one {

struct ReduceSumLikeCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  std::vector<int32_t> axis;
};

class ReduceSumLike : public OpExprGradFunction<ReduceSumLikeCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(ReduceSumLikeCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const ReduceSumLikeCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> ReduceSumLike::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> ReduceSumLike::Capture(ReduceSumLikeCaptureState* ctx, const TensorTuple& inputs,
                                   const TensorTuple& outputs, const AttrMap& attrs) const {
  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  ctx->requires_grad = inputs.at(0)->requires_grad();
  CHECK_OR_RETURN(!inputs.at(1)->requires_grad())
      << Error::RuntimeError() << "like tensor does not require grad";
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> ReduceSumLike::Apply(const ReduceSumLikeCaptureState* ctx,
                                 const TensorTuple& out_grads, TensorTuple* in_grads) const {
  const auto& x = ctx->SavedTensors().at(0);
  in_grads->resize(2);
  in_grads->at(0) = JUST(functional::BroadcastLike(out_grads.at(0), x, ctx->axis));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum_like", ReduceSumLike);

}  // namespace one
}  // namespace oneflow
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ReshapeCaptureState : public AutoGradCaptureState { DimVector input_shape_vec; }; class ReshapeGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec(); return Maybe::Ok(); } Maybe Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(1); Shape shape(ctx->input_shape_vec); in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape)); return Maybe::Ok(); } }; class ReshapeLikeGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(!inputs.at(1)->requires_grad()) << "ReshapeLikeOp's input[1] need not requires_grad."; ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec(); return Maybe::Ok(); } Maybe Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); Shape shape(ctx->input_shape_vec); in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape)); return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("reshape", ReshapeGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("reshape_like", ReshapeLikeGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/rms_norm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct RMSNormCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool weight_requires_grad = false; int x_index = -1; int inv_rms_index = -1; int weight_index = -1; }; class RMSNormGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(RMSNormCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const RMSNormCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe RMSNormGrad::Capture(RMSNormCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // (x, [weight]) CHECK_GE_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_LE_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) // (y, inv_rms) CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg) // save x ctx->x_requires_grad = inputs[0]->requires_grad(); ctx->x_index = ctx->SaveTensorForBackward(inputs[0]); // save weight ctx->weight_requires_grad = false; if (inputs.size() > 1) { ctx->weight_requires_grad = inputs[1]->requires_grad(); ctx->weight_index = ctx->SaveTensorForBackward(inputs[1]); } // save inv_rms if (ctx->x_requires_grad || ctx->weight_requires_grad) { ctx->inv_rms_index = ctx->SaveTensorForBackward(outputs[1]); } return Maybe::Ok(); } Maybe RMSNormGrad::Apply(const RMSNormCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { // (x, inv_rms) or (x, weight, inv_rms) const auto& saved_tensors = ctx->SavedTensors(); CHECK_GE_OR_RETURN(saved_tensors.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_LE_OR_RETURN(saved_tensors.size(), 3); // NOLINT(maybe-need-error-msg) // (dy, inv_rms_diff) CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) const auto& dy = out_grads[0]; const auto& x = saved_tensors.at(ctx->x_index); const auto& inv_rms = saved_tensors.at(ctx->inv_rms_index); // (x_grad, weight_grad) in_grads->resize(2); if (ctx->x_requires_grad) { if (saved_tensors.size() == 3) { const auto& weight = saved_tensors.at(ctx->weight_index); in_grads->at(0) = JUST(functional::RMSNormGrad(dy, x, inv_rms, weight, /*param_grad*/ false)); } else { in_grads->at(0) = JUST(functional::RMSNormGrad(dy, x, inv_rms, /*weight*/ NullOpt, /*param_grad*/ false)); } } if (ctx->weight_requires_grad) { in_grads->at(1) = JUST(functional::RMSNormGrad(dy, x, inv_rms, /*weight*/ NullOpt, /*param_grad*/ true)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("rms_norm", RMSNormGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/roi_align.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct RoiAlignCaptureState : public AutoGradCaptureState { float spatial_scale = 1.0; int32_t pooled_h = 0; int32_t pooled_w = 0; int32_t sampling_ratio = -1; bool aligned = false; bool requires_grad = false; }; class RoiAlign : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(RoiAlignCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(inputs.at(1)); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->spatial_scale = JUST(composed_attrs.GetAttr("spatial_scale")); ctx->pooled_h = JUST(composed_attrs.GetAttr("pooled_h")); ctx->pooled_w = JUST(composed_attrs.GetAttr("pooled_w")); ctx->sampling_ratio = JUST(composed_attrs.GetAttr("sampling_ratio")); ctx->aligned = JUST(composed_attrs.GetAttr("aligned")); return Maybe::Ok(); } Maybe Apply(const RoiAlignCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& x_like = ctx->SavedTensors().at(0); const auto& rois = ctx->SavedTensors().at(1); in_grads->at(0) = JUST( functional::RoiAlignGrad(out_grads.at(0), x_like, rois, ctx->spatial_scale, ctx->pooled_h, ctx->pooled_w, ctx->sampling_ratio, ctx->aligned)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("roi_align", RoiAlign); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/roll.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct RollCaptureState : public AutoGradCaptureState { std::vector shifts; std::vector dims; bool requires_grad = false; }; class Roll : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(RollCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const RollCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Roll::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Roll::Capture(RollCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->shifts = JUST(composed_attrs.GetAttr>("shifts")); ctx->dims = JUST(composed_attrs.GetAttr>("dims")); return Maybe::Ok(); } Maybe Roll::Apply(const RollCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) std::vector new_shifts; new_shifts.resize(ctx->shifts.size()); for (int i = 0; i < new_shifts.size(); ++i) { new_shifts[i] = -ctx->shifts[i]; } in_grads->at(0) = JUST(functional::Roll(out_grads.at(0), new_shifts, ctx->dims)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("roll", Roll); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/rrelu.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct RReluCaptureState : public AutoGradCaptureState { bool requires_grad = true; float lower = 1.0 / 8; float upper = 1.0 / 3; bool training = false; int x_index = -1; int noise_data_index = -1; }; class RRelu : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(RReluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const RReluCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe RRelu::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe RRelu::Capture(RReluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->lower = JUST(composed_attrs.GetAttr("lower")); ctx->upper = JUST(composed_attrs.GetAttr("upper")); ctx->training = JUST(composed_attrs.GetAttr("training")); ctx->x_index = ctx->SaveTensorForBackward(inputs[0]); ctx->noise_data_index = ctx->SaveTensorForBackward(outputs[1]); // output noise data return Maybe::Ok(); } Maybe RRelu::Apply(const RReluCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& saved_tensors = ctx->SavedTensors(); if (!ctx->training) { float scale = (ctx->lower + ctx->upper) / 2; const auto& x = saved_tensors.at(ctx->x_index); in_grads->at(0) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), scale)); return Maybe::Ok(); } else { const auto& noise_data = saved_tensors.at(ctx->noise_data_index); in_grads->at(0) = JUST(functional::Mul(out_grads.at(0), noise_data)); return Maybe::Ok(); } } REGISTER_OP_EXPR_GRAD_FUNCTION("rrelu", RRelu); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/scalar_add.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { struct ScalarAddCaptureState : public AutoGradCaptureState { bool requires_grad; }; class ScalarAdd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(ScalarAddCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const ScalarAddCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_add", ScalarAdd); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/scalar_div.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct ScalarDivCaptureState : public AutoGradCaptureState { bool requires_grad = true; Scalar operand; }; class ScalarDiv : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ScalarDivCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); bool has_float_operand = JUST(composed_attrs.GetAttr("has_float_operand")); if (has_float_operand) { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("float_operand"))); } else { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("int_operand"))); } return Maybe::Ok(); } Maybe Apply(const ScalarDivCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { JUST(VectorAt(*in_grads, 0)) = JUST(functional::ScalarDiv(JUST(VectorAt(out_grads, 0)), ctx->operand)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_div", ScalarDiv); } // namespace one } // namespace oneflow ================================================ FILE: 
================================================
FILE: oneflow/core/autograd/gradient_funcs/scalar_floordiv.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"

namespace oneflow {
namespace one {

struct ScalarFloorDivCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
};

class ScalarFloorDiv : public OpExprGradFunction<ScalarFloorDivCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(ScalarFloorDivCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ScalarFloorDivCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      JUST(VectorAt(*in_grads, 0)) = JUST(functional::ZerosLike(JUST(VectorAt(out_grads, 0))));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_floordiv", ScalarFloorDiv);

}  // namespace one
}  // namespace oneflow
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ScalarFModGradCaptureState : public AutoGradCaptureState { bool requires_grad; }; class ScalarFModGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(ScalarFModGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const ScalarFModGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_fmod", ScalarFModGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/scalar_mul.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ScalarMulCaptureState : public AutoGradCaptureState { bool requires_grad; Scalar operand; }; class ScalarMul : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ScalarMulCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); bool has_float_operand = JUST(composed_attrs.GetAttr("has_float_operand")); if (has_float_operand) { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("float_operand"))); } else { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("int_operand"))); } return Maybe::Ok(); } Maybe Apply(const ScalarMulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::ScalarMul(out_grads.at(0), ctx->operand, false)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_mul", ScalarMul); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/scalar_pow.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ScalarPowCaptureState : public AutoGradCaptureState { bool requires_grad; Scalar operand; }; class ScalarPow : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(ScalarPowCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); bool has_float_operand = JUST(composed_attrs.GetAttr("has_float_operand")); if (has_float_operand) { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("float_operand"))); } else { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("int_operand"))); } ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& x = ctx->SavedTensors().at(0); in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_pow", ScalarPow); class ScalarReversePow : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(ScalarPowCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs[0]->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); bool has_float_operand = JUST(composed_attrs.GetAttr("has_float_operand")); if (has_float_operand) { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("float_operand"))); } else { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("int_operand"))); } ctx->SaveTensorForBackward(inputs[0]); return Maybe::Ok(); } Maybe Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& x = ctx->SavedTensors()[0]; in_grads->resize(1); if (ctx->requires_grad) { (*in_grads)[0] = JUST(functional::ScalarReversePowGrad(x, out_grads[0], ctx->operand)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_reverse_pow", ScalarReversePow); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/scalar_truncdiv.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct ScalarTruncDivCaptureState : public AutoGradCaptureState { bool requires_grad = true; }; class ScalarTruncDiv : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(ScalarTruncDivCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } Maybe Apply(const ScalarTruncDivCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { JUST(VectorAt(*in_grads, 0)) = JUST(functional::ZerosLike(JUST(VectorAt(out_grads, 0)))); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_truncdiv", ScalarTruncDiv); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/scaled_dot_product_attention.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" #if CUDA_VERSION >= 11070 namespace oneflow { namespace one { struct ScaledDotProductFlashAttentionCaptureState : public AutoGradCaptureState { bool query_requires_grad = true; bool key_requires_grad = true; bool value_requires_grad = true; size_t query_idx = 0; size_t key_idx = 0; size_t value_idx = 0; size_t out_idx = 0; size_t softmax_lse_idx = 0; size_t rng_state_idx = 0; float p_dropout = .0f; float softmax_scale = .0f; bool is_causal = false; }; class ScaledDotProductFlashAttention : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be None. 
"; base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ScaledDotProductFlashAttentionCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 3) << "Input size should be equal to 3. "; ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->p_dropout = JUST(composed_attrs.GetAttr("p_dropout")); ctx->softmax_scale = JUST(composed_attrs.GetAttr("softmax_scale")); ctx->is_causal = JUST(composed_attrs.GetAttr("is_causal")); ctx->query_requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); ctx->key_requires_grad = JUST(oneflow::VectorAt(inputs, 1))->requires_grad(); ctx->value_requires_grad = JUST(oneflow::VectorAt(inputs, 2))->requires_grad(); ctx->query_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0))); ctx->key_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1))); ctx->value_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 2))); ctx->out_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0))); ctx->softmax_lse_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 1))); ctx->rng_state_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 2))); return Maybe::Ok(); } Maybe Apply(const ScaledDotProductFlashAttentionCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 3) << "Out grads size should be equal to 3. "; std::shared_ptr grads; in_grads->resize(3); grads = JUST(functional::ScaledDotProductFlashAttentionGrad( JUST(oneflow::VectorAt(out_grads, 0)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->query_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->key_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->value_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->out_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->softmax_lse_idx)), JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->rng_state_idx)), ctx->p_dropout, ctx->is_causal, ctx->softmax_scale)); if (ctx->query_requires_grad) { JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(oneflow::VectorAt(*grads, 0)); } if (ctx->key_requires_grad) { JUST(oneflow::VectorAt(*in_grads, 1)) = JUST(oneflow::VectorAt(*grads, 1)); } if (ctx->value_requires_grad) { JUST(oneflow::VectorAt(*in_grads, 2)) = JUST(oneflow::VectorAt(*grads, 2)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("scaled_dot_product_flash_attention", ScaledDotProductFlashAttention); } // namespace one } // namespace oneflow #endif // CUDA_VERSION >= 11070 ================================================ FILE: oneflow/core/autograd/gradient_funcs/scatter_nd.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct ScatterNdCaptureState : public AutoGradCaptureState { bool requires_grad; }; class ScatterNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(ScatterNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(1)->requires_grad(); if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); // indices } return Maybe::Ok(); } Maybe Apply(const ScatterNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->requires_grad) { const auto& indices = ctx->SavedTensors().at(0); in_grads->at(1) = JUST(functional::GatherNd(out_grads.at(0), indices)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("scatter_nd", ScatterNd); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/select_top_n.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" namespace oneflow { namespace one { struct SelectTopNCaptureState : public AutoGradCaptureState { TensorTuple inputs; std::vector requires_grad; int32_t top_n = 0; }; class SelectTopN : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(SelectTopNCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->inputs = inputs; ctx->top_n = JUST(attrs.GetAttr("top_n")); ctx->requires_grad.resize(inputs.size()); for (int i = 0; i < ctx->requires_grad.size(); ++i) { ctx->requires_grad.at(i) = inputs.at(i)->requires_grad(); } return Maybe::Ok(); } Maybe Apply(const SelectTopNCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(ctx->top_n, out_grads.size()); // NOLINT(maybe-need-error-msg) for (int i = 0; i < ctx->top_n; ++i) { if (!ctx->requires_grad.at(i)) { continue; } in_grads->at(i) = out_grads.at(i); } for (int i = ctx->top_n; i < ctx->inputs.size(); ++i) { if (!ctx->requires_grad.at(i)) { continue; } const auto& tensor = ctx->inputs.at(i); in_grads->at(i) = JUST(StaticZerosTensor::MakeTensor(tensor->shape(), tensor->dtype()->data_type(), tensor->memory_format(), JUST(tensor->device()))); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("select_top_n", SelectTopN); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/slice.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SliceCaptureState : public AutoGradCaptureState { Shape like_shape; std::vector start; std::vector stop; std::vector step; }; class Slice : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SliceCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->start = JUST(composed_attrs.GetAttr>("start")); ctx->stop = JUST(composed_attrs.GetAttr>("stop")); ctx->step = JUST(composed_attrs.GetAttr>("step")); ctx->like_shape = *(inputs[0]->shape()); return Maybe::Ok(); } Maybe Apply(const SliceCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(1); (*in_grads)[0] = JUST( functional::SliceGrad(out_grads[0], ctx->like_shape, ctx->start, ctx->stop, ctx->step)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct SliceUpdateCaptureState : public AutoGradCaptureState { bool requires_grad_ref = false; bool requires_grad_value = false; std::vector start; std::vector stop; std::vector step; Shape value_shape; // used to calculate ref gradient Symbol value_sbp; }; class SliceUpdate : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SliceUpdateCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad_ref = inputs[0]->requires_grad(); ctx->requires_grad_value = inputs[1]->requires_grad(); if (!ctx->requires_grad_ref && !ctx->requires_grad_value) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->start = JUST(composed_attrs.GetAttr>("start")); ctx->stop = JUST(composed_attrs.GetAttr>("stop")); ctx->step = JUST(composed_attrs.GetAttr>("step")); if (ctx->requires_grad_ref) { ctx->value_shape = *(inputs[1]->shape()); if (inputs[1]->is_global()) { ctx->value_sbp = JUST(inputs[1]->nd_sbp()); } } return Maybe::Ok(); } Maybe Apply(const SliceUpdateCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->requires_grad_ref) { std::shared_ptr zeros; if (out_grads[0]->is_local()) { zeros = JUST(functional::Constant(ctx->value_shape, 0, out_grads[0]->dtype(), JUST(out_grads[0]->device()))); } else { const auto& parallel_desc = JUST(out_grads[0]->parallel_desc()); zeros = JUST(functional::GlobalConstant(ctx->value_shape, 0, out_grads[0]->dtype(), parallel_desc, *JUST(GetSbpList(ctx->value_sbp)))); } (*in_grads)[0] = 
JUST(functional::SliceUpdate(out_grads[0], zeros, ctx->start, ctx->stop, ctx->step, /*inplace=*/false)); } if (ctx->requires_grad_value) { (*in_grads)[1] = JUST(functional::Slice(out_grads[0], ctx->start, ctx->stop, ctx->step, /*enable_view_slice=*/false)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("slice_update", SliceUpdate); REGISTER_OP_EXPR_GRAD_FUNCTION("slice", Slice); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/smooth_l1_loss.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SmoothL1LossCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool target_requires_grad = false; float beta = 0.0; }; class SmoothL1Loss : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SmoothL1LossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input ctx->target_requires_grad = inputs.at(1)->requires_grad(); // target ctx->SaveTensorForBackward(inputs.at(0)); // input ctx->SaveTensorForBackward(inputs.at(1)); // target ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->beta = JUST(composed_attrs.GetAttr("beta")); return Maybe::Ok(); } Maybe Apply(const SmoothL1LossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2); // NOLINT(maybe-need-error-msg) in_grads->resize(2); const auto& input = ctx->SavedTensors().at(0); const auto& target = ctx->SavedTensors().at(1); const auto& grad = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta)); if (ctx->input_requires_grad) { (*in_grads)[0] = grad; } if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::Negative(grad)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("smooth_l1_loss", SmoothL1Loss); // todo: name } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/softmax.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SoftmaxCaptureState : public AutoGradCaptureState { bool requires_grad; }; class Softmax : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(SoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const SoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe Softmax::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe Softmax::Capture(SoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) return Maybe::Ok(); ctx->SaveTensorForBackward(outputs.at(0)); return Maybe::Ok(); } Maybe Softmax::Apply(const SoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) return Maybe::Ok(); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& dy = out_grads.at(0); const auto& y = ctx->SavedTensors().at(0); in_grads->resize(1); in_grads->at(0) = JUST(functional::SoftmaxGrad(dy, y)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("softmax", Softmax); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/softmax_cross_entropy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
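// Illustrative note (not part of the original sources): softmax_cross_entropy produces two
// outputs, the loss and the softmax probability. Only the prediction needs a gradient;
// conceptually it is the loss gradient broadcast over the class dimension times
// (prob - label), which functional::SoftmaxCrossEntropyGrad computes from the saved label
// and prob tensors. The label slot of in_grads is deliberately left empty.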
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SoftmaxCrossEntropyGradState : public AutoGradCaptureState { bool requires_grad = false; }; class SoftmaxCrossEntropy : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(SoftmaxCrossEntropyGradState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const SoftmaxCrossEntropyGradState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe SoftmaxCrossEntropy::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe SoftmaxCrossEntropy::Capture(SoftmaxCrossEntropyGradState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->SaveTensorForBackward(inputs.at(1)); // label ctx->SaveTensorForBackward(outputs.at(1)); // prob return Maybe::Ok(); } Maybe SoftmaxCrossEntropy::Apply(const SoftmaxCrossEntropyGradState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) const auto& dy = out_grads.at(0); const auto& label = ctx->SavedTensors().at(0); const auto& prob = ctx->SavedTensors().at(1); in_grads->resize(2); // prediction, label (*in_grads)[0] = JUST(functional::SoftmaxCrossEntropyGrad(dy, label, prob)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("softmax_cross_entropy", SoftmaxCrossEntropy); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/sparse_cross_entropy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SparseCrossEntropyCaptureState : public AutoGradCaptureState { bool requires_grad = false; int64_t depth = -1; size_t prediction_index = -1; size_t label_index = -1; }; template class SparseCrossEntropy : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SparseCrossEntropyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->depth = JUST(composed_attrs.GetAttr("depth")); ctx->prediction_index = ctx->SaveTensorForBackward(inputs.at(0)); // prediction ctx->label_index = ctx->SaveTensorForBackward(inputs.at(1)); // label return Maybe::Ok(); } Maybe Apply(const SparseCrossEntropyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& prediction = ctx->SavedTensors().at(ctx->prediction_index); const auto& label = ctx->SavedTensors().at(ctx->label_index); in_grads->resize(2); if (is_distributed) { in_grads->at(0) = JUST( functional::SparseCrossEntropyMsGrad(prediction, label, out_grads.at(0), ctx->depth)); } else { in_grads->at(0) = JUST(functional::SparseCrossEntropyGrad(prediction, label, out_grads.at(0), ctx->depth)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_cross_entropy_ms", SparseCrossEntropy); REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_cross_entropy", SparseCrossEntropy); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SparseSoftmaxCrossEntropyCaptureState : public AutoGradCaptureState { int64_t depth; }; class SparseSoftmaxCrossEntropy : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe SparseSoftmaxCrossEntropy::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->depth = JUST(composed_attrs.GetAttr("depth")); CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); // prob ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // label return Maybe::Ok(); } Maybe SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) const auto& dy = JUST(VectorAt(out_grads, 1)); const auto& prob = JUST(VectorAt(ctx->SavedTensors(), 0)); const auto& label = JUST(VectorAt(ctx->SavedTensors(), 1)); // SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not // require gradient. in_grads->resize(2); JUST(VectorAt(*in_grads, 0)) = JUST(functional::SparseSoftmaxCrossEntropyGrad(dy, prob, label, ctx->depth)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_softmax_cross_entropy", SparseSoftmaxCrossEntropy); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy_ms.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SparseSoftmaxCrossEntropyMsCaptureState : public AutoGradCaptureState { int64_t depth = 0; }; class SparseSoftmaxCrossEntropyMs : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(SparseSoftmaxCrossEntropyMsCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const SparseSoftmaxCrossEntropyMsCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe SparseSoftmaxCrossEntropyMs::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe SparseSoftmaxCrossEntropyMs::Capture(SparseSoftmaxCrossEntropyMsCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->depth = JUST(composed_attrs.GetAttr("depth")); CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); // prob ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // label return Maybe::Ok(); } Maybe SparseSoftmaxCrossEntropyMs::Apply(const SparseSoftmaxCrossEntropyMsCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg) const auto& dy = JUST(VectorAt(out_grads, 1)); const auto& prob = JUST(VectorAt(ctx->SavedTensors(), 0)); const auto& label = JUST(VectorAt(ctx->SavedTensors(), 1)); // SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not // require gradient. in_grads->resize(2); JUST(VectorAt(*in_grads, 0)) = JUST(functional::SparseSoftmaxCrossEntropyMsGrad(dy, prob, label, ctx->depth)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_softmax_cross_entropy_ms", SparseSoftmaxCrossEntropyMs); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/split_like.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SplitLikeCaptureState : public AutoGradCaptureState { int64_t axis; bool requires_grad; }; class SplitLike : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(SplitLikeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const SplitLikeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe SplitLike::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe SplitLike::Capture(SplitLikeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), outputs.size() + 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); for (int i = 0; i < outputs.size(); ++i) { ctx->SaveTensorForBackward(outputs.at(i)); } return Maybe::Ok(); } Maybe SplitLike::Apply(const SplitLikeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(out_grads.size() + 1); if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& saved_tensors = ctx->SavedTensors(); TensorTuple inputs; inputs.reserve(out_grads.size()); for (int i = 0; i < out_grads.size(); ++i) { const auto& out_grad_i = out_grads.at(i); if (out_grad_i.get()) { inputs.emplace_back(out_grad_i); } else { const auto& zero_grad = JUST(functional::ZerosLike(saved_tensors.at(i))); inputs.emplace_back(zero_grad); } } in_grads->at(0) = JUST(functional::Concat(inputs, ctx->axis)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("split_like", SplitLike); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/squeeze.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct SqueezeCaptureState : public AutoGradCaptureState { bool requires_grad; }; class Squeeze : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(SqueezeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const SqueezeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Squeeze::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Squeeze::Capture(SqueezeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Squeeze::Apply(const SqueezeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& like = ctx->SavedTensors().at(0); in_grads->resize(1); in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("squeeze", Squeeze); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/stack.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct StackCaptureState : public AutoGradCaptureState { std::vector requires_grad; int64_t axis = 1; int64_t input_num = 2; }; class Stack : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(StackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const StackCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Stack::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Stack::Capture(StackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad.resize(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); for (const auto& input : inputs) { ctx->SaveTensorForBackward(input); } ctx->input_num = inputs.size(); return Maybe::Ok(); } Maybe Stack::Apply(const StackCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(ctx->input_num); TensorTuple like(ctx->input_num); for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); } const auto& results = JUST(functional::StackGrad(out_grads.at(0), like, ctx->axis)); CHECK_EQ_OR_RETURN(results->size(), ctx->input_num) << Error::RuntimeError() << "The number of results (" << results->size() << ") must match the number of inputs (" << ctx->input_num << ")"; for (int i = 0; i < ctx->input_num; ++i) { if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); } } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("stack", Stack); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct TensorScalarCaptureState : public AutoGradCaptureState { bool x_requires_grad; bool scalar_requires_grad; }; class TensorScalarAddOrSub : public OpExprGradFunction { public: TensorScalarAddOrSub() = default; virtual ~TensorScalarAddOrSub() = default; Maybe Init(const OpExpr& op) override; Maybe Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; }; Maybe TensorScalarAddOrSub::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe TensorScalarAddOrSub::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->scalar_requires_grad = inputs.at(1)->requires_grad(); return Maybe::Ok(); } class TensorScalarAdd : public TensorScalarAddOrSub { public: Maybe Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::Identity(out_grads.at(0))); } if (ctx->scalar_requires_grad) { int32_t num_axes = out_grads.at(0)->shape()->NumAxes(); std::vector axes_vec(num_axes); std::iota(axes_vec.begin(), axes_vec.end(), 0); in_grads->at(1) = JUST(functional::ReduceSum(out_grads.at(0), axes_vec, false, NullOpt)); } return Maybe::Ok(); } }; class TensorScalarSub : public TensorScalarAddOrSub { public: Maybe Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::Identity(out_grads.at(0))); } if (ctx->scalar_requires_grad) { int32_t num_axes = out_grads.at(0)->shape()->NumAxes(); std::vector axes_vec(num_axes); std::iota(axes_vec.begin(), axes_vec.end(), 0); const auto& reduce_sum = JUST(functional::ReduceSum(out_grads.at(0), axes_vec, /*keepdims=*/false, NullOpt)); in_grads->at(1) = JUST(functional::ScalarMul(reduce_sum, /*other=*/1.0, false)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_add_by_tensor", TensorScalarAdd); REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_sub_by_tensor", TensorScalarSub); class TensorScalarMul : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe TensorScalarMul::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe TensorScalarMul::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->scalar_requires_grad = inputs.at(1)->requires_grad(); if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); } if (ctx->scalar_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } return Maybe::Ok(); } Maybe TensorScalarMul::Apply(const 
TensorScalarCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(2); if (ctx->x_requires_grad) { const auto& scalar = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::Mul(out_grads.at(0), scalar)); } if (ctx->scalar_requires_grad) { const auto& x = ctx->SavedTensors().at(ctx->x_requires_grad); const auto& y = JUST(functional::Mul(out_grads.at(0), x)); int32_t num_axes = out_grads.at(0)->shape()->NumAxes(); std::vector axes_vec(num_axes); std::iota(axes_vec.begin(), axes_vec.end(), 0); in_grads->at(1) = JUST(functional::ReduceSum(y, axes_vec, /*keepdims=*/false, NullOpt)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_mul_by_tensor", TensorScalarMul); class TensorScalarDiv : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: std::shared_ptr tensor_scalar_div_op_; std::shared_ptr broadcast_div_grad_op_; }; Maybe TensorScalarDiv::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe TensorScalarDiv::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->scalar_requires_grad = inputs.at(1)->requires_grad(); if (ctx->x_requires_grad || ctx->scalar_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); } if (ctx->scalar_requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); } return Maybe::Ok(); } Maybe TensorScalarDiv::Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(2); if (ctx->x_requires_grad) { const auto& scalar = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::Div(out_grads.at(0), scalar)); } if (ctx->scalar_requires_grad) { const auto& scalar = ctx->SavedTensors().at(0); const auto& y = ctx->SavedTensors().at(1); in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), y, scalar)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_div_by_tensor", TensorScalarDiv); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/tensor_scatter_nd_update.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct TensorScatterNdUpdateCaptureState : public AutoGradCaptureState { bool tensor_requires_grad = false; bool update_requires_grad = false; }; class TensorScatterNdUpdate : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(TensorScatterNdUpdateCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->tensor_requires_grad = inputs.at(0)->requires_grad(); ctx->update_requires_grad = inputs.at(2)->requires_grad(); if (ctx->update_requires_grad || ctx->tensor_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); // indices } if (ctx->tensor_requires_grad) { ctx->SaveTensorForBackward(inputs.at(2)); // update: only use meta information } return Maybe::Ok(); } Maybe Apply(const TensorScatterNdUpdateCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(3); if (ctx->update_requires_grad) { const auto& indices = ctx->SavedTensors().at(0); in_grads->at(2) = JUST(functional::GatherNd(out_grads.at(0), indices)); } if (ctx->tensor_requires_grad) { const auto& indices = ctx->SavedTensors().at(0); const auto& update = ctx->SavedTensors().at(1); const auto& temp = JUST(functional::ZerosLike(update)); in_grads->at(0) = JUST( functional::TensorScatterNdUpdate(out_grads.at(0), indices, temp, /*inplace=*/false)); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("tensor_scatter_nd_update", TensorScatterNdUpdate); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/tf_pool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { namespace { struct TFPoolCaptureState : public AutoGradCaptureState { bool requires_grad = false; size_t input_index = 0; size_t output_index = 0; std::string data_format; std::string padding; std::vector padding_before; std::vector padding_after; std::vector pool_size; std::vector strides; bool ceil_mode = false; }; class TFPoolNdGrad : public OpExprGradFunction { public: virtual ~TFPoolNdGrad() = default; using OpExprGradFunction::Init; Maybe Init(const OpExpr& op, const std::string& mode); Maybe Capture(TFPoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TFPoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: std::string mode_; AttrMap base_attrs_; }; Maybe TFPoolNdGrad::Init(const OpExpr& op, const std::string& mode) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); mode_ = mode; return Maybe::Ok(); } Maybe TFPoolNdGrad::Capture(TFPoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); ctx->output_index = ctx->SaveTensorForBackward(outputs.at(0)); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->padding = JUST(composed_attrs.GetAttr("padding")); ctx->padding_before = JUST(composed_attrs.GetAttr>("padding_before")); ctx->padding_after = JUST(composed_attrs.GetAttr>("padding_after")); ctx->pool_size = JUST(composed_attrs.GetAttr>("pool_size")); ctx->strides = JUST(composed_attrs.GetAttr>("strides")); ctx->ceil_mode = JUST(composed_attrs.GetAttr("ceil_mode")); return Maybe::Ok(); } Maybe TFPoolNdGrad::Apply(const TFPoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) int32_t ndims = ctx->pool_size.size(); const auto& input = ctx->SavedTensors().at(ctx->input_index); const auto& output = ctx->SavedTensors().at(ctx->output_index); in_grads->resize(1); (*in_grads)[0] = JUST(functional::TFPoolNdGrad( input, output, out_grads[0], mode_, ndims, ctx->data_format, ctx->padding, ctx->padding_before, ctx->padding_after, ctx->pool_size, ctx->strides, ctx->ceil_mode)); return Maybe::Ok(); } } // namespace class TFMaxPoolNdGrad final : public TFPoolNdGrad { public: Maybe Init(const OpExpr& op) override { return TFPoolNdGrad::Init(op, "tf_max"); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_1d", TFMaxPoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_2d", TFMaxPoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_3d", TFMaxPoolNdGrad); class TFAvgPoolNdGrad final : public TFPoolNdGrad { public: Maybe Init(const OpExpr& op) override { return TFPoolNdGrad::Init(op, "tf_avg"); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_1d", TFAvgPoolNdGrad); 
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_2d", TFAvgPoolNdGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_3d", TFAvgPoolNdGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/to_contiguous.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" namespace oneflow { namespace one { struct ToContiguousCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class ToContiguous : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(ToContiguousCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs[0]->requires_grad(); return Maybe::Ok(); } Maybe Apply(const ToContiguousCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { (*in_grads)[0] = out_grads[0]; } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("to_contiguous", ToContiguous); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/transpose.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct TransposeCaptureState : public AutoGradCaptureState { std::vector perm; bool requires_grad; }; class Transpose : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(TransposeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Transpose::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Transpose::Capture(TransposeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->perm = JUST(composed_attrs.GetAttr>("perm")); return Maybe::Ok(); } Maybe Transpose::Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) std::vector grad_perm; grad_perm.resize(ctx->perm.size()); FOR_RANGE(int32_t, i, 0, ctx->perm.size()) { grad_perm.at(ctx->perm.at(i)) = i; } in_grads->at(0) = JUST(functional::Transpose(out_grads.at(0), grad_perm)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("transpose", Transpose); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/tril.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct TrilCaptureState : public AutoGradCaptureState { bool requires_grad = false; int64_t diagonal = 0; }; class Tril : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(TrilCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TrilCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Tril::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Tril::Capture(TrilCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->diagonal = JUST(composed_attrs.GetAttr("diagonal")); return Maybe::Ok(); } Maybe Tril::Apply(const TrilCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::Tril(out_grads.at(0), ctx->diagonal)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("tril", Tril); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/triu.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct TriuCaptureState : public AutoGradCaptureState { bool requires_grad; int64_t diagonal; }; class Triu : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(TriuCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Triu::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Triu::Capture(TriuCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->diagonal = JUST(composed_attrs.GetAttr("diagonal")); return Maybe::Ok(); } Maybe Triu::Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::Triu(out_grads.at(0), ctx->diagonal)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("triu", Triu); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/trunc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct TruncCaptureState : public AutoGradCaptureState { bool requires_grad = false; }; class Trunc : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(TruncCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TruncCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe Trunc::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Trunc::Capture(TruncCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } return Maybe::Ok(); } Maybe Trunc::Apply(const TruncCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("trunc", Trunc); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/two_stage_reduce.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { enum class ReduceMode : int32_t { kMin = 0, kMax = 1, }; struct ReduceDeviceCaptureState : public AutoGradCaptureState { std::vector axis; bool requires_grad = false; size_t mask_index = -1; size_t count_index = -1; }; template class ReduceDevice : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ReduceDeviceCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr>("axis")); ctx->mask_index = ctx->SaveTensorForBackward(outputs.at(1)); // mask ctx->count_index = ctx->SaveTensorForBackward(outputs.at(2)); // count return Maybe::Ok(); } Maybe Apply(const ReduceDeviceCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 3); // NOLINT(maybe-need-error-msg) const auto& mask = ctx->SavedTensors().at(ctx->mask_index); const auto& count = ctx->SavedTensors().at(ctx->count_index); in_grads->resize(1); if (mode == ReduceMode::kMin) { in_grads->at(0) = JUST(functional::ReduceMinDeviceStageGrad(out_grads.at(0), mask, count, ctx->axis)); } else { in_grads->at(0) = JUST(functional::ReduceMaxDeviceStageGrad(out_grads.at(0), mask, count, ctx->axis)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min_device_stage", ReduceDevice); REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max_device_stage", ReduceDevice); struct ReduceGlobalCaptureState : public AutoGradCaptureState { std::vector axis; bool requires_grad = false; bool keepdims = false; size_t mask_index = -1; size_t device_count_index = -1; }; template class ReduceGlobal : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ReduceGlobalCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr>("axis")); ctx->keepdims = JUST(composed_attrs.GetAttr("keepdims")); ctx->mask_index = ctx->SaveTensorForBackward(outputs.at(1)); // mask ctx->device_count_index = ctx->SaveTensorForBackward(inputs.at(1)); // device_count return Maybe::Ok(); } Maybe Apply(const ReduceGlobalCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 2); // 
NOLINT(maybe-need-error-msg) const auto& mask = ctx->SavedTensors().at(ctx->mask_index); const auto& device_count = ctx->SavedTensors().at(ctx->device_count_index); in_grads->resize(2); if (mode == ReduceMode::kMin) { in_grads->at(0) = JUST(functional::ReduceMinGlobalStageGrad( out_grads.at(0), mask, device_count, ctx->axis, ctx->keepdims)); } else { in_grads->at(0) = JUST(functional::ReduceMaxGlobalStageGrad( out_grads.at(0), mask, device_count, ctx->axis, ctx->keepdims)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min_global_stage", ReduceGlobal); REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max_global_stage", ReduceGlobal); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/unfold.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct UnfoldInterpState : public AutoGradCaptureState { bool requires_grad = true; std::string data_format = "channels_first"; std::vector output_size; std::vector kernel_size; std::vector dilation_rate; std::vector padding; std::vector strides; }; class Unfold : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(UnfoldInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const UnfoldInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Unfold::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Unfold::Capture(UnfoldInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); std::vector out_shape(2); const std::shared_ptr& x = inputs.at(0); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->dilation_rate = JUST(composed_attrs.GetAttr>("dilation_rate")); ctx->padding = JUST(composed_attrs.GetAttr>("padding")); ctx->strides = JUST(composed_attrs.GetAttr>("strides")); // Only support 4-d Tensor Input. 
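// For the supported 4-D (N, C, H, W) input, dims 2 and 3 give the spatial extent saved
// as output_size below; the backward pass hands it to functional::Fold, the adjoint of
// Unfold, which accumulates overlapping patch gradients back into the original layout.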
for (int i = 0; i < 2; i++) { out_shape.at(i) = (x->shape()->At(i + 2)); } ctx->output_size = out_shape; return Maybe::Ok(); } Maybe Unfold::Apply(const UnfoldInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); in_grads->at(0) = JUST(functional::Fold(out_grads.at(0), ctx->output_size, ctx->kernel_size, ctx->dilation_rate, ctx->padding, ctx->strides, ctx->data_format)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("unfold", Unfold); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/unfold_tensor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct UnfoldTensorCaptureState : public AutoGradCaptureState { int32_t dimension = -1; int32_t size = -1; int32_t step = -1; bool requires_grad = false; }; class UnfoldTensor : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(UnfoldTensorCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const UnfoldTensorCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; std::shared_ptr grad_op_; }; Maybe UnfoldTensor::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe UnfoldTensor::Capture(UnfoldTensorCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->dimension = JUST(composed_attrs.GetAttr("dimension")); ctx->size = JUST(composed_attrs.GetAttr("size")); ctx->step = JUST(composed_attrs.GetAttr("step")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe UnfoldTensor::Apply(const UnfoldTensorCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& in = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::UnfoldTensorGrad(out_grads.at(0), in, ctx->dimension, ctx->size, ctx->step)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("unfold_tensor", UnfoldTensor); } // namespace one } // namespace oneflow ================================================ FILE: 
oneflow/core/autograd/gradient_funcs/unsqueeze.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace one { struct UnsqueezeCaptureState : public AutoGradCaptureState { bool requires_grad; Shape shape; }; class Unsqueeze : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Unsqueeze::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Unsqueeze::Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } if (LazyMode::is_enabled()) { ctx->SaveTensorForBackward(inputs.at(0)); } else { ctx->shape = *(inputs.at(0)->shape()); } return Maybe::Ok(); } Maybe Unsqueeze::Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); if (LazyMode::is_enabled()) { const auto& like = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like)); } else { in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), ctx->shape)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("expand_dims", Unsqueeze); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/upsample.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct UpsampleCaptureState : public AutoGradCaptureState { bool requires_grad = false; double height_scale = 0.0; double width_scale = 0.0; float align_corners; std::string data_format; std::string interpolation; }; class Upsample : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; std::shared_ptr grad_op_; }; Maybe Upsample::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Upsample::Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->interpolation = JUST(composed_attrs.GetAttr("interpolation")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Upsample::Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleGrad( JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, ctx->align_corners, ctx->data_format, ctx->interpolation)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample); struct UpsampleNearest2DCaptureState : public AutoGradCaptureState { bool requires_grad = false; double height_scale = 0.0; double width_scale = 0.0; std::vector output_size; std::string data_format; }; class UpsampleNearest2D : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UpsampleNearest2DCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); if (composed_attrs.Has("output_size")) { ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const 
UpsampleNearest2DCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad( JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_2d", UpsampleNearest2D); struct UpsampleBilinear2DCaptureState : public AutoGradCaptureState { bool requires_grad = false; double height_scale = 0.0; double width_scale = 0.0; bool align_corners; std::vector output_size; std::string data_format; }; class UpsampleBilinear2D : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UpsampleBilinear2DCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); if (composed_attrs.Has("output_size")) { ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const UpsampleBilinear2DCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBilinear2DGrad( JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, ctx->align_corners, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bilinear_2d", UpsampleBilinear2D); struct UpsampleLinear1DCaptureState : public AutoGradCaptureState { bool requires_grad = false; double scale_factor = 0.0; bool align_corners; std::vector output_size; std::string data_format; }; class UpsampleLinear1D : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UpsampleLinear1DCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->scale_factor = JUST(composed_attrs.GetAttr("scale_factor")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); if (composed_attrs.Has("output_size")) { ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); } ctx->data_format = 
JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const UpsampleLinear1DCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleLinear1DGrad( JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->scale_factor, ctx->align_corners, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_linear_1d", UpsampleLinear1D); struct UpsampleNearest1DCaptureState : public AutoGradCaptureState { bool requires_grad = false; double scale_factor = 0.0; std::vector output_size; std::string data_format; }; class UpsampleNearest1D : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UpsampleNearest1DCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->scale_factor = JUST(composed_attrs.GetAttr("scale_factor")); if (composed_attrs.Has("output_size")) { ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const UpsampleNearest1DCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST( functional::UpsampleNearest1DGrad(JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->scale_factor, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_1d", UpsampleNearest1D); struct UpsampleBicubic2DCaptureState : public AutoGradCaptureState { bool requires_grad = false; double height_scale = 0.0; double width_scale = 0.0; bool align_corners; std::vector output_size; std::string data_format; }; class UpsampleBicubic2D : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UpsampleBicubic2DCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); if (composed_attrs.Has("output_size")) { ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); } ctx->data_format = 
JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const UpsampleBicubic2DCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBicubic2DGrad( JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, ctx->align_corners, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bicubic_2d", UpsampleBicubic2D); struct UpsampleNearest3DCaptureState : public AutoGradCaptureState { bool requires_grad = false; double depth_scale = 0.0; double height_scale = 0.0; double width_scale = 0.0; std::vector output_size; std::string data_format; }; class UpsampleNearest3D : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UpsampleNearest3DCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->depth_scale = JUST(composed_attrs.GetAttr("depth_scale")); ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); if (composed_attrs.Has("output_size")) { ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const UpsampleNearest3DCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad( JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale, ctx->width_scale, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_3d", UpsampleNearest3D); struct UpsampleTrilinear3DCaptureState : public AutoGradCaptureState { bool requires_grad = false; double depth_scale = 0.0; double height_scale = 0.0; double width_scale = 0.0; bool align_corners; std::vector output_size; std::string data_format; }; class UpsampleTrilinear3D : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->depth_scale = JUST(composed_attrs.GetAttr("depth_scale")); 
ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); if (composed_attrs.Has("output_size")) { ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleTrilinear3DGrad( JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale, ctx->width_scale, ctx->align_corners, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_trilinear_3d", UpsampleTrilinear3D); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/variance.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { struct VarianceState : public AutoGradCaptureState { VarianceState() : requires_grad(false), unbiased(true), keepdim(false), axis({}){}; bool requires_grad; bool unbiased; bool keepdim; std::vector axis; }; class Variance : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(VarianceState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const VarianceState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe Variance::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Variance::Capture(VarianceState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->keepdim = JUST(composed_attrs.GetAttr("keepdim")); ctx->unbiased = JUST(composed_attrs.GetAttr("unbiased")); ctx->axis = JUST(composed_attrs.GetAttr>("dim")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Variance::Apply(const VarianceState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { // TODO(): replace it using kernel const std::shared_ptr& x = ctx->SavedTensors().at(0); DataType data_type = x->dtype()->data_type(); CHECK_NE_OR_RETURN(data_type, DataType::kBFloat16) << Error::RuntimeError() << "Variance op not support backward for bfloat16 yet!"; size_t correction = ctx->unbiased ? 
1 : 0; size_t elem_cnt = 1; CHECK_OR_RETURN(ctx->axis.size() > 0) << Error::RuntimeError() << "The size of the axis must greater than 0, but got " << ctx->axis.size(); for (const auto& item : ctx->axis) { elem_cnt *= x->shape()->At(item); } std::shared_ptr out_grad = out_grads.at(0); if (ctx->keepdim == false) { // for broadcast mul const std::shared_ptr& out_grad_shape = out_grad->shape(); DimVector unsqueeze_vector(out_grad_shape->dim_vec()); for (int i = 0; i < ctx->axis.size(); i++) { unsqueeze_vector.insert(unsqueeze_vector.begin() + ctx->axis.at(i), 1); } Shape unsqueeze_shape(unsqueeze_vector); CHECK_EQ_OR_RETURN(unsqueeze_shape.elem_cnt(), out_grad_shape->elem_cnt()) << Error::RuntimeError() << "tensor size mismatch, expected tensor to have the same number of elements, but got " << unsqueeze_shape.elem_cnt() << " and " << out_grad_shape->elem_cnt() << " elements respectively"; out_grad = JUST(functional::Reshape(out_grad, unsqueeze_shape)); } in_grads->resize(1); in_grads->at(0) = JUST(functional::Mul( out_grad, JUST(functional::ScalarMul( Scalar(2.0 / (elem_cnt - correction)), JUST(functional::Sub(x, JUST(functional::ReduceMean(x, ctx->axis, /*keepdim=*/true)), /*alpha=*/1.0, /*inplace=*/false)))))); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("var", Variance); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/vector_matrix_product.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct VectorMatrixProductCaptureState : public AutoGradCaptureState { bool requires_grad_a = false; bool requires_grad_b = false; size_t a_index = 0; size_t b_index = 1; }; class VectorMatrixProduct : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(VectorMatrixProductCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const VectorMatrixProductCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; protected: AttrMap base_attrs_; }; Maybe VectorMatrixProduct::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. 
"; base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe VectorMatrixProduct::Capture(VectorMatrixProductCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad(); ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad(); if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); if (ctx->requires_grad_a) { ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // input b } if (ctx->requires_grad_b) { ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input a } return Maybe::Ok(); } Maybe VectorMatrixProduct::Apply(const VectorMatrixProductCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Out grad size should be equal to 1. "; in_grads->resize(2); if (ctx->requires_grad_a) { const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index)); JUST(VectorAt(*in_grads, 0)) = JUST(functional::VectorMatrixProductGradA(JUST(VectorAt(out_grads, 0)), input_b)); } if (ctx->requires_grad_b) { const auto& input_a = JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->a_index)); JUST(VectorAt(*in_grads, 1)) = JUST(functional::VectorMatrixProductGradB(JUST(VectorAt(out_grads, 0)), input_a)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("vector_matrix_product", VectorMatrixProduct); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/gradient_funcs/where.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/impl/common.h" namespace oneflow { namespace one { struct WhereCaptureState : public AutoGradCaptureState { bool requires_grad_x = false; bool requires_grad_y = false; DimVector x_reduce_dims = {}; DimVector y_reduce_dims = {}; DimVector x_squeeze_dims = {}; DimVector y_squeeze_dims = {}; }; class Where : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(WhereCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe Where::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe Where::Capture(WhereCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // cond, x, y CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) ctx->requires_grad_x = inputs.at(1)->requires_grad(); ctx->requires_grad_y = inputs.at(2)->requires_grad(); if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); // condition CHECK_EQ_OR_RETURN(outputs.size(), 1); const Shape& out_shape = *outputs.at(0)->shape(); auto GetReduceDims = [&](DimVector& reduce_dim_vec, DimVector& squeeze_dim_vec, const std::shared_ptr& tensor) -> Maybe { reduce_dim_vec.clear(); squeeze_dim_vec.clear(); const Shape& shape = *tensor->shape(); if (functional::IsScalarTensor(tensor)) { reduce_dim_vec.resize(out_shape.size()); squeeze_dim_vec.resize(out_shape.size()); std::iota(reduce_dim_vec.begin(), reduce_dim_vec.end(), 0); std::iota(squeeze_dim_vec.begin(), squeeze_dim_vec.end(), 0); } else if (shape != out_shape) { CHECK_GE_OR_RETURN(out_shape.size(), shape.size()); // NOLINT(maybe-need-error-msg) size_t ddiff = out_shape.size() - shape.size(); for (int i = 0; i < out_shape.size(); ++i) { if (i < ddiff) { reduce_dim_vec.push_back(i); squeeze_dim_vec.push_back(i); } else if (out_shape[i] != shape[i - ddiff]) { reduce_dim_vec.push_back(i); } } } return Maybe::Ok(); }; JUST(GetReduceDims(ctx->x_reduce_dims, ctx->x_squeeze_dims, inputs.at(1))); JUST(GetReduceDims(ctx->y_reduce_dims, ctx->y_squeeze_dims, inputs.at(2))); return Maybe::Ok(); } Maybe Where::Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) const auto& out_grad = out_grads.at(0); CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 1); // NOLINT(maybe-need-error-msg) const auto& condition = ctx->SavedTensors().at(0); std::shared_ptr zero; if (out_grad->is_local()) { zero = JUST( functional::Constant(Shape({}), Scalar(0), out_grad->dtype(), JUST(out_grad->device()))); } else { const size_t sbp_ndim = JUST(out_grad->nd_sbp())->sbp_parallel_size(); std::vector> nd_sbp_vec; nd_sbp_vec.reserve(sbp_ndim); for (int i = 0; i < sbp_ndim; ++i) { SbpParallel sbp; sbp.mutable_broadcast_parallel(); nd_sbp_vec.push_back(SymbolOf(sbp)); } const auto& parallel_desc = JUST(out_grad->parallel_desc()); zero = JUST(functional::GlobalConstant(Shape({}), Scalar(0), out_grad->dtype(), parallel_desc, nd_sbp_vec)); } in_grads->resize(3); // cond, x, y if (ctx->requires_grad_x) { auto x_grad = JUST(functional::Where(condition, out_grad, zero)); if 
Maybe<void> Where::Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads,
                         TensorTuple* in_grads) const {
  if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& out_grad = out_grads.at(0);
  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& condition = ctx->SavedTensors().at(0);

  std::shared_ptr<Tensor> zero;
  if (out_grad->is_local()) {
    zero = JUST(
        functional::Constant(Shape({}), Scalar(0), out_grad->dtype(), JUST(out_grad->device())));
  } else {
    const size_t sbp_ndim = JUST(out_grad->nd_sbp())->sbp_parallel_size();
    std::vector<Symbol<SbpParallel>> nd_sbp_vec;
    nd_sbp_vec.reserve(sbp_ndim);
    for (int i = 0; i < sbp_ndim; ++i) {
      SbpParallel sbp;
      sbp.mutable_broadcast_parallel();
      nd_sbp_vec.push_back(SymbolOf(sbp));
    }
    const auto& parallel_desc = JUST(out_grad->parallel_desc());
    zero = JUST(functional::GlobalConstant(Shape({}), Scalar(0), out_grad->dtype(), parallel_desc,
                                           nd_sbp_vec));
  }

  in_grads->resize(3);  // cond, x, y
  if (ctx->requires_grad_x) {
    auto x_grad = JUST(functional::Where(condition, out_grad, zero));
    if (!ctx->x_reduce_dims.empty()) {
      x_grad = JUST(functional::ReduceSum(
          x_grad, std::vector<int32_t>{ctx->x_reduce_dims.begin(), ctx->x_reduce_dims.end()},
          /*keepdims=*/true, NullOpt));
    }
    if (!ctx->x_squeeze_dims.empty()) {
      x_grad = JUST(functional::Squeeze(
          x_grad, std::vector<int32_t>{ctx->x_squeeze_dims.begin(), ctx->x_squeeze_dims.end()}));
    }
    in_grads->at(1) = x_grad;
  }
  if (ctx->requires_grad_y) {
    auto y_grad = JUST(functional::Where(condition, zero, out_grad));
    if (!ctx->y_reduce_dims.empty()) {
      y_grad = JUST(functional::ReduceSum(
          y_grad, std::vector<int32_t>{ctx->y_reduce_dims.begin(), ctx->y_reduce_dims.end()},
          /*keepdims=*/true, NullOpt));
    }
    if (!ctx->y_squeeze_dims.empty()) {
      y_grad = JUST(functional::Squeeze(
          y_grad, std::vector<int32_t>{ctx->y_squeeze_dims.begin(), ctx->y_squeeze_dims.end()}));
    }
    in_grads->at(2) = y_grad;
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("where", Where);

}  // namespace one
}  // namespace oneflow

================================================
FILE: oneflow/core/autograd/higher_order_gradient_funcs/activation.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"

namespace oneflow {
namespace one {

struct BaseActivationGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool grad_requires_grad = false;
};

typedef Maybe<Tensor> (*NoParamActivationBwFunc)(const std::shared_ptr<Tensor>&,
                                                 const std::shared_ptr<Tensor>&);

template<NoParamActivationBwFunc BwFunc, NoParamActivationBwFunc BwBwFunc>
class NoParamActivationGradGrad : public OpExprGradFunction<BaseActivationGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(BaseActivationGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // dy, x
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(1)->requires_grad();
    ctx->grad_requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }

    ctx->SaveTensorForBackward(inputs.at(1));
    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const BaseActivationGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    const auto& x = ctx->SavedTensors().at(0);
    if (ctx->x_requires_grad) {
      const auto& grad = ctx->SavedTensors().at(1);
      in_grads->at(1) = JUST(functional::Mul(out_grads.at(0), JUST(BwBwFunc(x, grad))));
    }
    if (ctx->grad_requires_grad) { in_grads->at(0) = JUST(BwFunc(out_grads.at(0), x)); }
    return Maybe<void>::Ok();
  }
};

#define INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(op_type_name,
op_cls) \ class op_cls##GradGradCls final \ : public NoParamActivationGradGrad { \ }; \ REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls); // first order backward param: (dy, x) INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("mish_grad", Mish) INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("gelu_grad", Gelu) INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("silu_grad", Silu) INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("selu_grad", Selu) INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("softsign_grad", SoftSign) INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("hardsigmoid_grad", HardSigmoid) INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("hardswish_grad", HardSwish) #undef INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS struct HardShrinkGradGradCaptureState : public AutoGradCaptureState { bool y_requires_grad = false; bool grad_requires_grad = false; double lambd = 0.5; }; class HardShrinkGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(HardShrinkGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // y, dy CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->y_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->lambd = JUST(composed_attrs.GetAttr("lambd")); if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } return Maybe::Ok(); } Maybe Apply(const HardShrinkGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); } if (ctx->grad_requires_grad) { const auto& y = ctx->SavedTensors().at(0); in_grads->at(1) = JUST(functional::HardShrinkGrad(y, out_grads.at(0), ctx->lambd)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct SoftShrinkGradGradCaptureState : public AutoGradCaptureState { bool y_requires_grad = false; bool grad_requires_grad = false; double alpha = 0.5; }; class SoftShrinkGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SoftShrinkGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // y, dy CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->y_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } return Maybe::Ok(); } Maybe Apply(const SoftShrinkGradGradCaptureState* ctx, 
const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); } if (ctx->grad_requires_grad) { const auto& y = ctx->SavedTensors().at(0); in_grads->at(1) = JUST(functional::SoftShrinkGrad(y, out_grads.at(0), ctx->alpha)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct ReluGradGradCaptureState : public AutoGradCaptureState { bool y_requires_grad = false; bool grad_requires_grad = false; }; class ReluGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(ReluGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // dy, y CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->y_requires_grad = inputs.at(1)->requires_grad(); ctx->grad_requires_grad = inputs.at(0)->requires_grad(); if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } Maybe Apply(const ReluGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->y_requires_grad) { in_grads->at(1) = JUST(functional::ZerosLike(out_grads.at(0))); } if (ctx->grad_requires_grad) { const auto& y = ctx->SavedTensors().at(0); in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y)); } return Maybe::Ok(); } }; struct LeakyReluGradGradCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool grad_requires_grad = false; float alpha = 0.01; }; class LeakyReluGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(LeakyReluGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // x, dy CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } return Maybe::Ok(); } Maybe Apply(const LeakyReluGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); } if (ctx->grad_requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(1) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct SoftplusGradGradCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool grad_requires_grad = false; double beta = 1.0; double threshold = 20.0; }; class SoftplusGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = 
MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SoftplusGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // x, dy CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->beta = JUST(composed_attrs.GetAttr("beta")); ctx->threshold = JUST(composed_attrs.GetAttr("threshold")); ctx->SaveTensorForBackward(inputs.at(0)); if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } Maybe Apply(const SoftplusGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); const auto& x = ctx->SavedTensors().at(0); if (ctx->x_requires_grad) { const auto& grad = ctx->SavedTensors().at(1); in_grads->at(0) = JUST(functional::Mul( out_grads.at(0), JUST(functional::SoftplusGradGrad(x, grad, ctx->beta, ctx->threshold)))); } if (ctx->grad_requires_grad) { in_grads->at(1) = JUST(functional::SoftplusGrad(x, out_grads.at(0), ctx->beta, ctx->threshold)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct HardTanhGradGradCaptureState : public AutoGradCaptureState { bool y_requires_grad = false; bool grad_requires_grad = false; double min_val = -1.0; double max_val = 1.0; }; class HardTanhGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(HardTanhGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // y, dy CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->y_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->min_val = JUST(composed_attrs.GetAttr("min_val")); ctx->max_val = JUST(composed_attrs.GetAttr("max_val")); if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } return Maybe::Ok(); } Maybe Apply(const HardTanhGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); } if (ctx->grad_requires_grad) { const auto& y = ctx->SavedTensors().at(0); in_grads->at(1) = JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct EluGradGradCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool grad_requires_grad = false; double alpha = 1.0; }; class EluGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(EluGradGradCaptureState* ctx, const TensorTuple& 
inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // x, dy CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe::Ok(); } ctx->SaveTensorForBackward(inputs.at(0)); if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } Maybe Apply(const EluGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); const auto& x = ctx->SavedTensors().at(0); if (ctx->x_requires_grad) { const auto& grad = ctx->SavedTensors().at(1); in_grads->at(0) = JUST( functional::Mul(out_grads.at(0), JUST(functional::EluGradGrad(x, grad, ctx->alpha)))); } if (ctx->grad_requires_grad) { in_grads->at(1) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; class CeluGradGrad : public EluGradGrad { public: Maybe Apply(const EluGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); const auto& y = ctx->SavedTensors().at(0); if (ctx->x_requires_grad) { const auto& grad = ctx->SavedTensors().at(1); in_grads->at(0) = JUST( functional::CeluGradGrad(y, JUST(functional::Mul(out_grads.at(0), (grad))), ctx->alpha)); } if (ctx->grad_requires_grad) { in_grads->at(1) = JUST(functional::CeluGrad(y, out_grads.at(0), ctx->alpha)); } return Maybe::Ok(); } }; struct PReluGradGradCaptureState : public AutoGradCaptureState { bool grad_requires_grad = false; bool input_requires_grad = false; bool alpha_requires_grad = false; size_t grad_index = 0; size_t input_index = 1; size_t alpha_index = 2; }; class PReluGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(PReluGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // dy, x, alpha CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) ctx->grad_requires_grad = inputs.at(0)->requires_grad(); // grad ctx->input_requires_grad = inputs.at(1)->requires_grad(); // input ctx->alpha_requires_grad = inputs.at(2)->requires_grad(); // alpha ctx->input_index = ctx->SaveTensorForBackward(inputs.at(1)); ctx->alpha_index = ctx->SaveTensorForBackward(inputs.at(2)); ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } Maybe Apply(const PReluGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(3); const auto& input = ctx->SavedTensors().at(ctx->input_index); const auto& alpha = ctx->SavedTensors().at(ctx->alpha_index); const auto& grad = ctx->SavedTensors().at(ctx->grad_index); const auto& grad_for_input = out_grads.at(0); const auto& grad_for_alpha = out_grads.at(1); const auto& condition = JUST(functional::ScalarLogicalLess(input, Scalar(0.0))); const auto& zero_grad = JUST(functional::ZerosLike(alpha)); // alpha can broadcast to input if (ctx->grad_requires_grad) { auto input_mul_grad = JUST(functional::Mul(alpha, grad_for_input)); auto alpha_mul_grad = JUST(functional::Mul(input, grad_for_alpha)); auto result = JUST(functional::Add(input_mul_grad, alpha_mul_grad, /*alpha=*/Scalar(1.0), /*inplace*/ false)); in_grads->at(0) = 
JUST(functional::Where(condition, result, grad_for_input)); } if (ctx->input_requires_grad) { auto result = JUST(functional::Mul(grad, grad_for_alpha)); in_grads->at(1) = JUST(functional::Where(condition, result, zero_grad)); } if (ctx->alpha_requires_grad) { auto result = JUST(functional::Mul(grad, grad_for_input)); in_grads->at(2) = JUST(functional::Where(condition, result, zero_grad)); } return Maybe::Ok(); } private: std::shared_ptr grad_op_; }; struct ThresholdGradGradCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool grad_requires_grad = false; double threshold = 0.0; }; class ThresholdGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(ThresholdGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // x, dy CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->threshold = JUST(composed_attrs.GetAttr("threshold_val")); if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } return Maybe::Ok(); } Maybe Apply(const ThresholdGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); } if (ctx->grad_requires_grad) { const auto& x = ctx->SavedTensors().at(0); in_grads->at(1) = JUST(functional::ThresholdGrad(x, out_grads.at(0), ctx->threshold)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("relu_grad", ReluGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("elu_grad", EluGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("celu_grad", CeluGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("prelu_grad", PReluGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("hardshrink_grad", HardShrinkGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink_grad", SoftShrinkGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu_grad", LeakyReluGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh_grad", HardTanhGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("threshold_grad", ThresholdGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("softplus_grad", SoftplusGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/avg_pool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct AdaptiveAvgPoolNDGradGradCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool grad_requires_grad = false; std::vector pool_output_size; std::string data_format; }; template class AdaptiveAvgPoolNdNdGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(AdaptiveAvgPoolNDGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // dy, x CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->grad_requires_grad = inputs[0]->requires_grad(); ctx->input_requires_grad = inputs[1]->requires_grad(); if (ctx->grad_requires_grad) { const auto& grad_shape = *inputs[0]->shape(); if (ndims == 1) { ctx->pool_output_size = {grad_shape[grad_shape.size() - 1]}; } else if (ndims == 2) { ctx->pool_output_size = {grad_shape[grad_shape.size() - 2], grad_shape[grad_shape.size() - 1]}; } else if (ndims == 3) { ctx->pool_output_size = {grad_shape[grad_shape.size() - 3], grad_shape[grad_shape.size() - 2], grad_shape[grad_shape.size() - 1]}; } else { UNIMPLEMENTED_THEN_RETURN(); } } return Maybe::Ok(); } Maybe Apply(const AdaptiveAvgPoolNDGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->grad_requires_grad) { if (ndims == 1) { (*in_grads)[0] = JUST( functional::AdaptiveAvgPool1D(out_grads[0], ctx->pool_output_size, ctx->data_format)); } else if (ndims == 2) { (*in_grads)[0] = JUST( functional::AdaptiveAvgPool2D(out_grads[0], ctx->pool_output_size, ctx->data_format)); } else if (ndims == 3) { (*in_grads)[0] = JUST( functional::AdaptiveAvgPool3D(out_grads[0], ctx->pool_output_size, ctx->data_format)); } else { UNIMPLEMENTED_THEN_RETURN(); } } if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; struct AvgPoolGradGradCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool grad_requires_grad = false; std::string data_format; std::vector padding; std::vector kernel_size; std::vector stride; bool ceil_mode = false; bool count_include_pad = false; int32_t divisor_override = 0; }; class AvgPoolNdGradGrad : public OpExprGradFunction { public: virtual ~AvgPoolNdGradGrad() = default; Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(AvgPoolGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // dy, x CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // 
NOLINT(maybe-need-error-msg) ctx->grad_requires_grad = inputs[0]->requires_grad(); ctx->input_requires_grad = inputs[1]->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->padding = JUST(composed_attrs.GetAttr>("padding")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->stride = JUST(composed_attrs.GetAttr>("stride")); ctx->ceil_mode = JUST(composed_attrs.GetAttr("ceil_mode")); ctx->count_include_pad = JUST(composed_attrs.GetAttr("count_include_pad")); ctx->divisor_override = JUST(composed_attrs.GetAttr("divisor_override")); return Maybe::Ok(); } Maybe Apply(const AvgPoolGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(2); if (ctx->grad_requires_grad) { int32_t ndims = ctx->kernel_size.size(); const auto pool_op = (ndims == 1 ? functional::AvgPool1D : (ndims == 2 ? functional::AvgPool2D : (ndims == 3 ? functional::AvgPool3D : nullptr))); CHECK_NOTNULL_OR_RETURN(pool_op); // NOLINT(maybe-need-error-msg) (*in_grads)[0] = JUST(pool_op(out_grads[0], ctx->kernel_size, ctx->stride, ctx->padding, ctx->ceil_mode, ctx->count_include_pad, ctx->divisor_override, ctx->data_format)); } if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_1d_grad", AvgPoolNdGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_2d_grad", AvgPoolNdGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_3d_grad", AvgPoolNdGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool1d_grad", AdaptiveAvgPoolNdNdGradGrad<1>); REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool2d_grad", AdaptiveAvgPoolNdNdGradGrad<2>); REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool3d_grad", AdaptiveAvgPoolNdNdGradGrad<3>); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_loss.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct BinaryCrossEntropyGradGradCaptureState : public AutoGradCaptureState { bool grad_requires_grad = false; bool input_requires_grad = false; bool target_requires_grad = false; bool has_weight = false; }; class BinaryCrossEntropyGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe BinaryCrossEntropyGradGrad::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe BinaryCrossEntropyGradGrad::Capture(BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // dy, input, target[, weight] CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 4); // NOLINT(maybe-need-error-msg) ctx->grad_requires_grad = inputs[0]->requires_grad(); ctx->input_requires_grad = inputs[1]->requires_grad(); ctx->target_requires_grad = inputs[2]->requires_grad(); ctx->has_weight = inputs.size() == 4; ctx->SaveTensorForBackward(inputs[0]); // grad ctx->SaveTensorForBackward(inputs[1]); // input ctx->SaveTensorForBackward(inputs[2]); // target if (ctx->has_weight) { ctx->SaveTensorForBackward(inputs[3]); // weight } return Maybe::Ok(); } Maybe BinaryCrossEntropyGradGrad::Apply(const BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 3 + ctx->has_weight); // NOLINT(maybe-need-error-msg) in_grads->resize(3 + ctx->has_weight); const auto& grad = ctx->SavedTensors()[0]; const auto& input = ctx->SavedTensors()[1]; const auto& target = ctx->SavedTensors()[2]; // dx = grad * [-target/input + (1-target)/(1-input)] // grad_for_grad = out_grad * [-target/input + (1-target)/(1-input)] // grad_for_input = out_grad * grad * [target/(input*input) + (1-target)/((1-input)*(1-input))] // = out_grad * grad * [(input*input-2*input*target+target)/(input*(1-input))^2] // grad_for_target = out_grad * grad * [1/(input*(1-input))] if (ctx->grad_requires_grad) { const auto& weight = ctx->has_weight ? Optional(ctx->SavedTensors()[3]) : NullOpt; (*in_grads)[0] = JUST(functional::BinaryCrossEntropyLossGrad(out_grads[0], input, target, weight)); } if (ctx->input_requires_grad) { auto one_sub_input = JUST(functional::ScalarSub(1, input, /*alpha=*/1)); auto input_mul_target = JUST(functional::Mul(input, target)); auto numerator = JUST(functional::sequence_function(functional::Square) .then(std::bind(functional::Sub, std::placeholders::_1, input_mul_target, /*alpha=*/2, /*inplace=*/false)) .then([&target](const std::shared_ptr& in) { return functional::Add(in, target, /*alpha=*/1, /*inplace=*/false); }) .call(input)); auto res = JUST(functional::sequence_function(functional::Mul) .then(functional::Square) .then(std::bind(functional::Div, numerator, std::placeholders::_1)) .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0])) .then(std::bind(functional::Mul, std::placeholders::_1, grad)) .call(input, one_sub_input)); (*in_grads)[1] = ctx->has_weight ? 
JUST(functional::Mul(ctx->SavedTensors()[3], res)) : res; } if (ctx->target_requires_grad) { auto input_sub_one = JUST(functional::ScalarAdd(-1, input, /*alpha=*/1)); auto res = JUST(functional::sequence_function(functional::Mul) .then(std::bind(functional::LogGrad, std::placeholders::_1, out_grads[0])) .then(std::bind(functional::Mul, std::placeholders::_1, grad)) .call(input, input_sub_one)); (*in_grads)[2] = ctx->has_weight ? JUST(functional::Mul(ctx->SavedTensors()[3], res)) : res; } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_grad", BinaryCrossEntropyGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_with_logits.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct BinaryCrossEntropyWithLogitsGradGradCaptureState : public AutoGradCaptureState { bool grad_requires_grad = false; bool input_requires_grad = false; bool target_requires_grad = false; bool has_weight = false; bool has_pos_weight = false; }; class BinaryCrossEntropyWithLogitsGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe BinaryCrossEntropyWithLogitsGradGrad::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe BinaryCrossEntropyWithLogitsGradGrad::Capture( BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // dy, input, target[, weight][, pos_weight] CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 5); // NOLINT(maybe-need-error-msg) ctx->grad_requires_grad = inputs[0]->requires_grad(); ctx->input_requires_grad = inputs[1]->requires_grad(); ctx->target_requires_grad = inputs[2]->requires_grad(); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->has_pos_weight = JUST(composed_attrs.GetAttr("has_pos_weight")); ctx->has_weight = inputs.size() == 5 || (inputs.size() == 4 && !ctx->has_pos_weight); ctx->SaveTensorForBackward(inputs[0]); // grad ctx->SaveTensorForBackward(inputs[1]); // input ctx->SaveTensorForBackward(inputs[2]); // target if (inputs.size() == 4) { ctx->SaveTensorForBackward(inputs[3]); // weight or pos_weight } if (inputs.size() == 5) { 
ctx->SaveTensorForBackward(inputs[3]); // weight ctx->SaveTensorForBackward(inputs[4]); // pos_weight } return Maybe::Ok(); } Maybe BinaryCrossEntropyWithLogitsGradGrad::Apply( const BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 3 + ctx->has_weight + ctx->has_pos_weight); // NOLINT(maybe-need-error-msg) in_grads->resize(3 + ctx->has_weight + ctx->has_pos_weight); const auto& grad = ctx->SavedTensors()[0]; const auto& input = ctx->SavedTensors()[1]; const auto& target = ctx->SavedTensors()[2]; const size_t pos_weight_index = ctx->has_weight ? 4 : 3; const auto& weight = ctx->has_weight ? Optional(ctx->SavedTensors()[3]) : NullOpt; const auto& pos_weight = ctx->has_pos_weight ? Optional(ctx->SavedTensors()[pos_weight_index]) : NullOpt; // dx = grad * weight * (-target*(1-input.sigmoid())*pos_weight + input.sigmoid()*(1-target)) // grad_for_input = out_grad * grad * weight * sig * (1-sig) * [pos_weight * target + 1 - target] // grad_for_target = -out_grad * grad * weight * [pos_weight + sig - pos_weight * sig] if (ctx->grad_requires_grad) { (*in_grads)[0] = JUST(functional::BinaryCrossEntropyWithLogitsLossGrad( out_grads[0], input, target, weight, pos_weight)); } if (ctx->input_requires_grad) { auto res = JUST(functional::sequence_function(functional::Sigmoid) .then(std::bind(functional::SigmoidGrad, std::placeholders::_1, grad)) .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0])) .call(input)); if (ctx->has_pos_weight) { res = JUST(functional::sequence_function(functional::Mul) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(1, input, /*alpha=*/Scalar(1)); }) .then(std::bind(functional::Sub, std::placeholders::_1, target, /*alpha=*/1, /*inplace=*/false)) .then(std::bind(functional::Mul, std::placeholders::_1, res)) .call(JUST(pos_weight), target)); } if (ctx->has_weight) { res = JUST(functional::Mul(res, JUST(weight))); } (*in_grads)[1] = res; } if (ctx->target_requires_grad) { auto res = JUST(functional::sequence_function(functional::Mul) .then(functional::Negative) .call(out_grads[0], grad)); if (ctx->has_pos_weight) { auto sig = JUST(functional::Sigmoid(input)); auto one_sub_sig = JUST(functional::ScalarSub(1, sig, /*alpha=*/1)); res = JUST(functional::sequence_function(functional::Mul) .then([&sig](const std::shared_ptr& input) { return functional::Add(input, sig, /*alpha=*/Scalar(1), /*inplace=*/false); }) .then(std::bind(functional::Mul, std::placeholders::_1, res)) .call(one_sub_sig, JUST(pos_weight))); } if (ctx->has_weight) { res = JUST(functional::Mul(res, JUST(weight))); } (*in_grads)[2] = res; } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_grad", BinaryCrossEntropyWithLogitsGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState : public AutoGradCaptureState { bool grad_requires_grad = false; bool input_requires_grad = false; bool target_requires_grad = false; size_t grad_index = 0; size_t input_index = 0; size_t target_index = 0; }; class BinaryCrossEntropyWithLogitsReduceMeanGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Capture( BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // dy, input, target CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) ctx->grad_requires_grad = inputs[0]->requires_grad(); ctx->input_requires_grad = inputs[1]->requires_grad(); ctx->target_requires_grad = inputs[2]->requires_grad(); if (ctx->input_requires_grad || ctx->target_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs[0]); // grad } if (ctx->input_requires_grad || ctx->grad_requires_grad) { ctx->input_index = ctx->SaveTensorForBackward(inputs[1]); // input } if (ctx->grad_requires_grad) { ctx->target_index = ctx->SaveTensorForBackward(inputs[2]); // target } return Maybe::Ok(); } Maybe BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Apply( const BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(3); // dx = grad * weight * (input.sigmoid() - target) // grad_for_input = out_grad * grad * weight * sig * (1-sig) // grad_for_target = -out_grad * grad * weight if (ctx->grad_requires_grad) { const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index)); const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index)); (*in_grads)[0] = JUST( functional::sequence_function(functional::Sigmoid) .then(std::bind(functional::Sub, std::placeholders::_1, target, /*alpha=*/1, /*inplace=*/false)) .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0])) .then(std::bind(functional::ReduceMean, std::placeholders::_1, std::vector{}, /*keepdim=*/false)) .call(input)); } if (ctx->input_requires_grad) { const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index)); const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index)); const auto& mean_grad = JUST(functional::ScalarMul(1.0 / 
out_grads[0]->nelement(), grad)); (*in_grads)[1] = JUST(functional::sequence_function(functional::Sigmoid) .then(std::bind(functional::SigmoidGrad, std::placeholders::_1, out_grads[0])) .then(std::bind(functional::Mul, std::placeholders::_1, mean_grad)) .call(input)); } if (ctx->target_requires_grad) { const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index)); const auto& mean_grad = JUST(functional::ScalarMul(1.0 / out_grads[0]->nelement(), grad)); (*in_grads)[2] = JUST(functional::sequence_function(functional::Mul) .then(functional::Negative) .call(out_grads[0], mean_grad)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_reduce_mean_grad", BinaryCrossEntropyWithLogitsReduceMeanGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/conv.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct ConvDataGradGradCaptureState : public AutoGradCaptureState { bool w_requires_grad = false; bool grad_requires_grad = false; size_t w_index = 0; size_t grad_index = 0; std::string data_format; std::vector padding_before; std::vector kernel_size; std::vector strides; std::vector dilation_rate; int32_t groups = 0; }; class ConvDataGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(ConvDataGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const ConvDataGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe ConvDataGradGrad::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe ConvDataGradGrad::Capture(ConvDataGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // input: dy, w, x_like, [add to output] // output: dx CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->w_requires_grad = inputs.at(1)->requires_grad(); ctx->grad_requires_grad = inputs.at(0)->requires_grad(); if (ctx->grad_requires_grad) { ctx->w_index = ctx->SaveTensorForBackward(inputs.at(1)); } if (ctx->w_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->data_format = 
JUST(composed_attrs.GetAttr("data_format")); ctx->padding_before = JUST(composed_attrs.GetAttr>("padding_before")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->strides = JUST(composed_attrs.GetAttr>("strides")); ctx->dilation_rate = JUST(composed_attrs.GetAttr>("dilation_rate")); ctx->groups = JUST(composed_attrs.GetAttr("groups")); return Maybe::Ok(); } Maybe ConvDataGradGrad::Apply(const ConvDataGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(3); size_t num_spatial_dims = ctx->kernel_size.size(); // first order forward: ConvND // x * w = y ( * => convolution) // first order backward: // x_grad = y_grad * w.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad // w_grad = x * y_grad (x.shape * y.shape -> w.shape) call ConvFilterGrad // second order forward (first order backward): ConvDataGrad // y_grad * w.rot180 = x_grad // second order forward: // w_grad_grad = out_grads_x * y_grad (x.shape * y.shape -> w.shape) call ConvFilterGrad // grad_for_y_grad = out_grads_x * w (x.shape * w.shape -> y.shape) call ConvND // w_grad_grad if (ctx->w_requires_grad) { const auto& grad = ctx->SavedTensors().at(ctx->grad_index); in_grads->at(1) = JUST(functional::ConvFilterGrad( grad, out_grads.at(0), num_spatial_dims, ctx->kernel_size, ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } // grad_for_y_grad if (ctx->grad_requires_grad) { const auto& w = ctx->SavedTensors().at(ctx->w_index); const int32_t ndims = ctx->kernel_size.size(); const auto conv_op = (ndims == 1 ? functional::Conv1d : (ndims == 2 ? functional::Conv2d : (ndims == 3 ? functional::Conv3d : nullptr))); CHECK_NOTNULL_OR_RETURN(conv_op); // NOLINT(maybe-need-error-msg) in_grads->at(0) = JUST(conv_op(out_grads.at(0), w, Optional(), ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } return Maybe::Ok(); } struct ConvFilterGradGradCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool grad_requires_grad = false; size_t x_index = 0; size_t grad_index = 0; std::string data_format; std::vector padding_before; std::vector kernel_size; std::vector strides; std::vector dilation_rate; int32_t groups = 0; }; class ConvFilterGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(ConvFilterGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const ConvFilterGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe ConvFilterGradGrad::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe ConvFilterGradGrad::Capture(ConvFilterGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // input: dy, x // output: dw CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(1)->requires_grad(); ctx->grad_requires_grad = inputs.at(0)->requires_grad(); ctx->x_index = ctx->SaveTensorForBackward(inputs.at(1)); if (ctx->x_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); } ComposedAttrMap composed_attrs(attrs, 
base_attrs_); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->padding_before = JUST(composed_attrs.GetAttr>("padding_before")); ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->strides = JUST(composed_attrs.GetAttr>("strides")); ctx->dilation_rate = JUST(composed_attrs.GetAttr>("dilation_rate")); ctx->groups = JUST(composed_attrs.GetAttr("groups")); return Maybe::Ok(); } Maybe ConvFilterGradGrad::Apply(const ConvFilterGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(2); size_t num_spatial_dims = ctx->kernel_size.size(); // first order forward: ConvND // x * w = y ( * => convolution) // first order backward: // x_grad = y_grad * w.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad // w_grad = x * y_grad (x.shape * y.shape -> w.shape) call ConvFilterGrad // second order forward (first order backward): ConvFilterGrad // x * y_grad = w_grad // second order backward: // x_grad_grad = out_grads_w * y_grad.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad // grad_for_y_grad = x * out_grads_w (x.shape * w.shape -> y.shape) call ConvND // x_grad_grad if (ctx->x_requires_grad) { const auto& grad = ctx->SavedTensors().at(ctx->grad_index); const auto& x = ctx->SavedTensors().at(ctx->x_index); in_grads->at(1) = JUST(functional::ConvDataGrad( grad, out_grads.at(0), JUST(x->detach()), num_spatial_dims, ctx->kernel_size, ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } // grad_for_y_grad if (ctx->grad_requires_grad) { const auto& x = ctx->SavedTensors().at(ctx->x_index); const int32_t ndims = ctx->kernel_size.size(); const auto conv_op = (ndims == 1 ? functional::Conv1d : (ndims == 2 ? functional::Conv2d : (ndims == 3 ? functional::Conv3d : nullptr))); CHECK_NOTNULL_OR_RETURN(conv_op); // NOLINT(maybe-need-error-msg) in_grads->at(0) = JUST(conv_op(x, out_grads.at(0), Optional(), ctx->strides, ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("conv_data_grad", ConvDataGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("conv_filter_grad", ConvFilterGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/div.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include <functional>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"

namespace oneflow {
namespace one {

struct DivGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;
  bool z_requires_grad = false;
  bool grad_requires_grad = false;
  size_t y_index = 0;
  size_t z_index = 1;
  size_t grad_index = 2;
};

class DivGradGrad : public OpExprGradFunction<DivGradGradCaptureState> {
  // div_grad = -x/(y*y)*dz = -z/y*dz
  // div_grad_y = out_grad * z*dz/(y*y)
  // div_grad_z = out_grad * -dz/y
  // div_grad_dz = out_grad * -z/y
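  // Worked out (illustrative): the first-order backward computes f(dz, z, y) = -z*dz/y,
  // the gradient of x/y with respect to the divisor y. Differentiating f and multiplying
  // by out_grad gives the three entries produced in Apply below:
  //   df/d(dz) = -z/y       ->  in_grads[0] = -(out_grad * z) / y
  //   df/d(z)  = -dz/y      ->  in_grads[1] = -(out_grad * dz) / y
  //   df/d(y)  = z*dz/y^2   ->  in_grads[2] = out_grad * reduce_sum_like(z*dz, y) / y^2
  // where reduce_sum_like collapses the dims along which y was broadcast.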
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(DivGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // dz, z, y
    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->grad_requires_grad = inputs.at(0)->requires_grad();
    ctx->z_requires_grad = inputs.at(1)->requires_grad();
    ctx->y_requires_grad = inputs.at(2)->requires_grad();

    ctx->y_index = ctx->SaveTensorForBackward(inputs.at(2));
    if (ctx->y_requires_grad || ctx->grad_requires_grad) {
      ctx->z_index = ctx->SaveTensorForBackward(inputs.at(1));
    }
    if (ctx->y_requires_grad || ctx->z_requires_grad) {
      ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0));
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const DivGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(3);
    const auto& y = ctx->SavedTensors().at(ctx->y_index);

    if (ctx->grad_requires_grad) {
      const auto& z = ctx->SavedTensors().at(ctx->z_index);
      in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)
                                 .then(functional::Negative)
                                 .then(std::bind(functional::Div, std::placeholders::_1, y))
                                 .call(out_grads.at(0), z));
    }
    if (ctx->z_requires_grad) {
      const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
      in_grads->at(1) = JUST(functional::sequence_function(functional::Mul)
                                 .then(functional::Negative)
                                 .then(std::bind(functional::Div, std::placeholders::_1, y))
                                 .call(out_grads.at(0), grad));
    }
    if (ctx->y_requires_grad) {
      const auto& z = ctx->SavedTensors().at(ctx->z_index);
      const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
      in_grads->at(2) = JUST(
          functional::sequence_function(functional::Mul)
              .then(std::bind(functional::BroadcastReduceSumLike, std::placeholders::_1, y))
              .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
              .then(std::bind(functional::Div, std::placeholders::_1, JUST(functional::Square(y))))
              .call(z, grad));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div_grad", DivGradGrad);

}  // namespace one
}  // namespace oneflow

================================================
FILE: oneflow/core/autograd/higher_order_gradient_funcs/kl_div_loss.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"

namespace oneflow {
namespace one {

struct KLDivLossGradGradCaptureState : public AutoGradCaptureState {
  bool grad_requires_grad = false;
  bool input_requires_grad = false;
  bool target_requires_grad = false;
  bool log_target = false;
  size_t input_index = 0;
  size_t target_index = 0;
};

class KLDivLossGradGrad : public OpExprGradFunction<KLDivLossGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(KLDivLossGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const KLDivLossGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> KLDivLossGradGrad::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> KLDivLossGradGrad::Capture(KLDivLossGradGradCaptureState* ctx,
                                       const TensorTuple& inputs, const TensorTuple& outputs,
                                       const AttrMap& attrs) const {
  // grad, input, target
  CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)
  ctx->grad_requires_grad = inputs[0]->requires_grad();
  ctx->input_requires_grad = inputs[1]->requires_grad();
  ctx->target_requires_grad = inputs[2]->requires_grad();

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->log_target = JUST(composed_attrs.GetAttr<bool>("log_target"));
  ctx->input_index = ctx->SaveTensorForBackward(inputs[1]);   // input
  ctx->target_index = ctx->SaveTensorForBackward(inputs[2]);  // target
  return Maybe<void>::Ok();
}

Maybe<void> KLDivLossGradGrad::Apply(const KLDivLossGradGradCaptureState* ctx,
                                     const TensorTuple& out_grads, TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(3);

  if (ctx->grad_requires_grad) {
    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
    const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));
    (*in_grads)[0] = JUST(functional::KLDivLossGrad(out_grads[0], input, target, ctx->log_target));
  }
  if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
  if (ctx->target_requires_grad) { (*in_grads)[2] = JUST(functional::ZerosLike(out_grads[0])); }
  //// In pytorch 1.13 the higher derivative grad is fixed, which will cause difference here
  // if (ctx->target_requires_grad) {
  //   if (ctx->log_target) (*in_grads)[2] =
  //     JUST(functional::Mul(JUST(functional::Negative(JUST(functional::Exp(target)))),
  //     out_grads[0]));
  //   else (*in_grads)[2] = JUST(functional::Negative(out_grads[0]));
  // }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("kl_div_loss_grad", KLDivLossGradGrad);

}  // namespace one
}  // namespace oneflow

================================================
FILE: oneflow/core/autograd/higher_order_gradient_funcs/log_softmax.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"

namespace oneflow {
namespace one {

struct LogSoftmaxGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;
  bool dy_requires_grad = false;
};

class LogSoftmaxGradGrad : public OpExprGradFunction<LogSoftmaxGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(LogSoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const LogSoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
};

Maybe<void> LogSoftmaxGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }

Maybe<void> LogSoftmaxGradGrad::Capture(LogSoftmaxGradGradCaptureState* ctx,
                                        const TensorTuple& inputs, const TensorTuple& outputs,
                                        const AttrMap& attrs) const {
  // y, dy
  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  ctx->y_requires_grad = inputs[0]->requires_grad();
  ctx->dy_requires_grad = inputs[1]->requires_grad();

  ctx->SaveTensorForBackward(inputs[0]);
  if (ctx->y_requires_grad) { ctx->SaveTensorForBackward(inputs[1]); }
  return Maybe<void>::Ok();
}

Maybe<void> LogSoftmaxGradGrad::Apply(const LogSoftmaxGradGradCaptureState* ctx,
                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {
  in_grads->resize(2);
  const auto& y = ctx->SavedTensors()[0];
  const std::vector<int32_t> reduce_axis{static_cast<int32_t>(y->ndim() - 1)};

  if (ctx->y_requires_grad) {
    const auto& dy = ctx->SavedTensors()[1];
    in_grads->at(0) =
        JUST(functional::sequence_function(functional::ReduceSum)
                 .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
                 .then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Exp(y))))
                 .then(functional::Negative)
                 .call(dy, reduce_axis, true, NullOpt));
  }
  if (ctx->dy_requires_grad) {
    in_grads->at(1) =
        JUST(functional::sequence_function(functional::Exp)
                 .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
                 .then(std::bind(functional::ReduceSum, std::placeholders::_1, reduce_axis,
                                 /*keepdim=*/true, NullOpt))
                 .then(std::bind(functional::Sub, out_grads[0], std::placeholders::_1,
                                 /*alpha=*/1, /*inplace=*/false))
                 .call(y));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("log_softmax_grad", LogSoftmaxGradGrad);

}  // namespace one
}  // namespace oneflow


================================================
FILE: oneflow/core/autograd/higher_order_gradient_funcs/math_unary_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct UnaryMathGradGradState : public AutoGradCaptureState { bool input_requires_grad = false; bool grad_requires_grad = false; }; typedef Maybe (*UnaryBwFunc)(const std::shared_ptr&, const std::shared_ptr&); template class UnaryMathGradGrad : public OpExprGradFunction { Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UnaryMathGradGradState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs[0]->requires_grad(); ctx->grad_requires_grad = inputs[1]->requires_grad(); ctx->SaveTensorForBackward(inputs[0]); if (ctx->input_requires_grad) { ctx->SaveTensorForBackward(inputs[1]); } return Maybe::Ok(); } Maybe Apply(const UnaryMathGradGradState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); const auto& input = ctx->SavedTensors()[0]; if (ctx->input_requires_grad) { const auto& grad = ctx->SavedTensors()[1]; (*in_grads)[0] = JUST(functional::Mul(out_grads[0], JUST(BwBwFunc(input, grad)))); } if (ctx->grad_requires_grad) { (*in_grads)[1] = JUST(BwFunc(input, out_grads[0])); } return Maybe::Ok(); } }; template class UnaryMathGradGradWithZeroDDX : public OpExprGradFunction { Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(UnaryMathGradGradState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->input_requires_grad = inputs[0]->requires_grad(); ctx->grad_requires_grad = inputs[1]->requires_grad(); ctx->SaveTensorForBackward(inputs[0]); return Maybe::Ok(); } Maybe Apply(const UnaryMathGradGradState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); const auto& input = ctx->SavedTensors()[0]; if (ctx->input_requires_grad) { (*in_grads)[0] = JUST(functional::ZerosLike(input)); } if (ctx->grad_requires_grad) { (*in_grads)[1] = JUST(BwFunc(input, out_grads[0])); } return Maybe::Ok(); } }; // TODO: Lgamma, first order backward unimplemented #define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_X_FUNC_SEQ \ OF_PP_MAKE_TUPLE_SEQ("sin_grad", Sin) \ OF_PP_MAKE_TUPLE_SEQ("cos_grad", Cos) \ OF_PP_MAKE_TUPLE_SEQ("tan_grad", Tan) \ OF_PP_MAKE_TUPLE_SEQ("sinh_grad", Sinh) \ OF_PP_MAKE_TUPLE_SEQ("cosh_grad", Cosh) \ OF_PP_MAKE_TUPLE_SEQ("asin_grad", Asin) \ OF_PP_MAKE_TUPLE_SEQ("acos_grad", 
Acos) \ OF_PP_MAKE_TUPLE_SEQ("atan_grad", Atan) \ OF_PP_MAKE_TUPLE_SEQ("asinh_grad", Asinh) \ OF_PP_MAKE_TUPLE_SEQ("acosh_grad", Acosh) \ OF_PP_MAKE_TUPLE_SEQ("atanh_grad", Atanh) \ OF_PP_MAKE_TUPLE_SEQ("erf_grad", Erf) \ OF_PP_MAKE_TUPLE_SEQ("erfc_grad", Erfc) \ OF_PP_MAKE_TUPLE_SEQ("exp_grad", Exp) \ OF_PP_MAKE_TUPLE_SEQ("exp2_grad", Exp2) \ OF_PP_MAKE_TUPLE_SEQ("expm1_grad", Expm1) \ OF_PP_MAKE_TUPLE_SEQ("log_grad", Log) \ OF_PP_MAKE_TUPLE_SEQ("log_sigmoid_grad", LogSigmoid) \ OF_PP_MAKE_TUPLE_SEQ("log2_grad", Log2) \ OF_PP_MAKE_TUPLE_SEQ("log1p_grad", Log1p) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal_grad", Reciprocal) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan_grad", ReciprocalNoNan) \ OF_PP_MAKE_TUPLE_SEQ("rsqrt_grad", Rsqrt) \ OF_PP_MAKE_TUPLE_SEQ("sqrt_grad", Sqrt) \ OF_PP_MAKE_TUPLE_SEQ("square_grad", Square) #define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_Y_FUNC_SEQ \ OF_PP_MAKE_TUPLE_SEQ("sigmoid_grad", Sigmoid) \ OF_PP_MAKE_TUPLE_SEQ("tanh_grad", Tanh) #define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_ZERO_DDX_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ("abs_grad", Abs) #define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS(op_type_name, op_cls) \ class op_cls##GradGradCls final \ : public UnaryMathGradGrad {}; \ REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS, MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_X_FUNC_SEQ); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS, MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_Y_FUNC_SEQ); #define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_ZERO_DDX_CLASS(op_type_name, op_cls) \ class op_cls##GradGradCls final \ : public UnaryMathGradGradWithZeroDDX {}; \ REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_ZERO_DDX_CLASS, MATH_UNARY_ELEMENTWISE_GRAD_GRAD_ZERO_DDX_FUNC_SEQ); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/matmul.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct BroadcastMatmulGradBGradCaptureState : public AutoGradCaptureState { bool a_requires_grad = false; bool b_requires_grad = false; size_t a_index = 0; size_t b_index = 1; double alpha = 1.0; }; class BroadcastMatmulGradBGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); }; Maybe Capture(BroadcastMatmulGradBGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->a_requires_grad = inputs.at(0)->requires_grad(); ctx->b_requires_grad = inputs.at(1)->requires_grad(); if (ctx->a_requires_grad) { ctx->b_index = ctx->SaveTensorForBackward(inputs.at(1)); } if (ctx->b_requires_grad) { ctx->a_index = ctx->SaveTensorForBackward(inputs.at(0)); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); return Maybe::Ok(); } Maybe Apply(const BroadcastMatmulGradBGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); // for matmul: input_a[dims..., m, k] * input_b[k, n] -> [dims..., m, n] // if forward: BroadcastMatmulGradB(input_a, JUST(VectorAt(out_grads, 0)), ctx->alpha)) // then: a.shape = [dims..., m, k], b.shape = [dims..., m, n], grad.shape = [k, n] // if forward: BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), input_a, ctx->alpha)) // then: a.shape = [dims..., m, n], b.shape = [dims..., m, k], grad.shape = [n, k] if (ctx->a_requires_grad) { const auto& b = ctx->SavedTensors()[ctx->b_index]; in_grads->at(0) = JUST(functional::MatMul(b, out_grads.at(0), false, true, ctx->alpha)); } if (ctx->b_requires_grad) { const auto& a = ctx->SavedTensors()[ctx->a_index]; in_grads->at(1) = JUST(functional::MatMul(a, out_grads.at(0), false, false, ctx->alpha)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_matmul_grad_b", BroadcastMatmulGradBGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/max_pool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct MaxPoolGradGradCaptureState : public AutoGradCaptureState { bool grad_requires_grad = false; bool input_requires_grad = false; }; template class MaxPoolNdGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(MaxPoolGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // dy, x, indice CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->grad_requires_grad = inputs[0]->requires_grad(); ctx->input_requires_grad = inputs[1]->requires_grad(); if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs[2]); } return Maybe::Ok(); } Maybe Apply(const MaxPoolGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(3); if (ctx->grad_requires_grad) { const auto& indices = JUST(VectorAt(ctx->SavedTensors(), 0)); (*in_grads)[0] = JUST(functional::MaxPoolNdGradGrad(out_grads[0], indices, ndims)); } if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_1d_grad", MaxPoolNdGradGrad<1>); REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_2d_grad", MaxPoolNdGradGrad<2>); REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_3d_grad", MaxPoolNdGradGrad<3>); // REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool1d_grad", MaxPoolNdGradGrad<1>); // REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool2d_grad", MaxPoolNdGradGrad<2>); // REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool3d_grad", MaxPoolNdGradGrad<3>); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/nll_loss.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct NLLCaptureState : public AutoGradCaptureState { bool input_requires_grad = false; bool grad_requires_grad = false; bool has_weight = false; int64_t ignore_index = -100; }; class NLLLossGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; Maybe NLLLossGradGrad::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe NLLLossGradGrad::Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // dy, input, target[, weight] CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 4); // NOLINT(maybe-need-error-msg) ctx->grad_requires_grad = inputs[0]->requires_grad(); ctx->input_requires_grad = inputs[1]->requires_grad(); ctx->has_weight = inputs.size() == 4; if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs[2]); if (ctx->has_weight) { ctx->SaveTensorForBackward(inputs[3]); } // weight ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->ignore_index = JUST(composed_attrs.GetAttr("ignore_index")); } return Maybe::Ok(); } Maybe NLLLossGradGrad::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(3 + ctx->has_weight); if (ctx->grad_requires_grad) { const auto& target = JUST(VectorAt(ctx->SavedTensors(), 0)); if (ctx->has_weight) { auto weight = JUST(VectorAt(ctx->SavedTensors(), 1)); (*in_grads)[0] = JUST(functional::NLLLoss(out_grads[0], target, weight, ctx->ignore_index, "none")); } else { (*in_grads)[0] = JUST(functional::NLLLoss(out_grads[0], target, NullOpt, ctx->ignore_index, "none")); } } if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("nll_grad", NLLLossGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/pow.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct PowXGradGradCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool y_requires_grad = false; bool dz_requires_grad = false; size_t x_index = 0; size_t y_index = 1; size_t dz_index = 2; }; class PowXGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(PowXGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // x, y, dz CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->y_requires_grad = inputs.at(1)->requires_grad(); ctx->dz_requires_grad = inputs.at(2)->requires_grad(); ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); if (ctx->x_requires_grad || ctx->y_requires_grad) { ctx->dz_index = ctx->SaveTensorForBackward(inputs.at(2)); } return Maybe::Ok(); } Maybe Apply(const PowXGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(3); const auto& x = ctx->SavedTensors().at(ctx->x_index); const auto& y = ctx->SavedTensors().at(ctx->y_index); // dx = y * x^(y-1) * dz // grad_for_x = out_grads * dz * y * [x^(y-1)]' // grad_for_y = out_grads * dz * [x^(y-1) * (1 + y * ln(x))] // grad_for_dz = out_grads * y * x^(y-1) if (ctx->x_requires_grad || ctx->y_requires_grad) { const auto& dz = ctx->SavedTensors().at(ctx->dz_index); const auto& y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false)); if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::sequence_function(functional::PowXGrad) .then(std::bind(functional::Mul, std::placeholders::_1, y)) .then(std::bind(functional::Mul, std::placeholders::_1, dz)) .call(x, y_sub_one, out_grads.at(0))); } if (ctx->y_requires_grad) { in_grads->at(1) = JUST(functional::sequence_function(functional::Log) .then(std::bind(functional::Mul, std::placeholders::_1, y)) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(1, input, /*alpha=*/1); }) .then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Pow(x, y_sub_one)))) .then(std::bind(functional::Mul, std::placeholders::_1, dz)) .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0))) .call(x)); } } if (ctx->dz_requires_grad) { in_grads->at(2) = JUST(functional::PowXGrad(x, y, out_grads.at(0))); } return Maybe::Ok(); } }; struct PowYGradGradCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool y_requires_grad = false; bool dz_requires_grad = false; size_t x_index = 0; size_t y_index = 1; size_t dz_index = 2; size_t dy_index = 3; }; class PowYGradGrad : public OpExprGradFunction { public: // dy = x^y*ln(x)*dz = z*ln(x)*dz Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } Maybe Capture(PowYGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // x, y, dz CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // 
NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->y_requires_grad = inputs.at(1)->requires_grad(); ctx->dz_requires_grad = inputs.at(2)->requires_grad(); ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); if (ctx->x_requires_grad || ctx->y_requires_grad) { ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1)); } if (ctx->x_requires_grad) { ctx->dz_index = ctx->SaveTensorForBackward(inputs.at(2)); } if (ctx->y_requires_grad) { ctx->dy_index = ctx->SaveTensorForBackward(outputs.at(0)); } return Maybe::Ok(); } Maybe Apply(const PowYGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(3); const auto& x = ctx->SavedTensors().at(ctx->x_index); // dy = x^y * ln(x) * dz = z * ln(x) * dz // grad_for_x = out_grads * dz * [x^(y-1) * (1 + y * ln(x))] // grad_for_y = out_grads * dy' = out_grads * dy * ln(x) // grad_for_dz = out_grads * x^y * ln(x) if (ctx->x_requires_grad) { const auto& y = ctx->SavedTensors().at(ctx->y_index); const auto& dz = ctx->SavedTensors().at(ctx->dz_index); const auto& y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false)); in_grads->at(0) = JUST(functional::sequence_function(functional::Log) .then(std::bind(functional::Mul, std::placeholders::_1, y)) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(1, input, /*alpha=*/1); }) .then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Pow(x, y_sub_one)))) .then(std::bind(functional::Mul, std::placeholders::_1, dz)) .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0))) .call(x)); } if (ctx->y_requires_grad) { const auto& dy = ctx->SavedTensors().at(ctx->dy_index); in_grads->at(1) = JUST(functional::sequence_function(functional::Log) .then(std::bind(functional::Mul, std::placeholders::_1, dy)) .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0))) .call(x)); } if (ctx->dz_requires_grad) { const auto& y = ctx->SavedTensors().at(ctx->y_index); in_grads->at(2) = JUST(functional::PowYGrad(x, y, out_grads.at(0))); } return Maybe::Ok(); } }; REGISTER_OP_EXPR_GRAD_FUNCTION("pow_x_grad", PowXGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("pow_y_grad", PowYGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/scalar_pow.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct ScalarPowGradGradCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; bool grad_requires_grad = false; Scalar operand; }; class ScalarPowGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(ScalarPowGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); if (!(ctx->x_requires_grad || ctx->grad_requires_grad)) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); bool has_float_operand = JUST(composed_attrs.GetAttr("has_float_operand")); if (has_float_operand) { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("float_operand"))); } else { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("int_operand"))); } ctx->SaveTensorForBackward(inputs.at(0)); if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } Maybe Apply(const ScalarPowGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& x = ctx->SavedTensors().at(0); in_grads->resize(2); // z = x^a, dx = a * x^(a-1) * dz // grad_for_x = out_grad * a * dz * [x^(a-1)]' // grad_for_dz = out_grad * [x^a]' if (ctx->x_requires_grad) { const auto& grad = ctx->SavedTensors().at(1); const auto operand_sub_one = ctx->operand - Scalar(1); in_grads->at(0) = JUST( functional::sequence_function(functional::Mul) .then(std::bind(functional::ScalarPowGrad, x, std::placeholders::_1, operand_sub_one)) .then([&ctx](const std::shared_ptr& input) { return functional::ScalarMul(ctx->operand, input); }) .call(grad, out_grads.at(0))); } if (ctx->grad_requires_grad) { in_grads->at(1) = JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; class ScalarReversePowGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); } Maybe Capture(ScalarPowGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->grad_requires_grad = inputs.at(1)->requires_grad(); if (!(ctx->x_requires_grad || ctx->grad_requires_grad)) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); bool has_float_operand = JUST(composed_attrs.GetAttr("has_float_operand")); if (has_float_operand) { ctx->operand = 
Scalar(JUST(composed_attrs.GetAttr("float_operand"))); } else { ctx->operand = Scalar(JUST(composed_attrs.GetAttr("int_operand"))); } ctx->SaveTensorForBackward(inputs.at(0)); if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); } return Maybe::Ok(); } Maybe Apply(const ScalarPowGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& x = ctx->SavedTensors().at(0); in_grads->resize(2); // z = a^x, dx = a^x * ln(a) * dz // grad_for_x = out_grad * dz * a^x * ln(a) * ln(a) // grad_for_dz = out_grad * [a^x]' if (ctx->x_requires_grad) { const auto& dx = ctx->SavedTensors().at(1); const auto log_operand = std::log(ctx->operand.As()); in_grads->at(0) = JUST(functional::sequence_function(functional::Mul) .then([&log_operand](const std::shared_ptr& input) { return functional::ScalarMul(log_operand, input); }) .call(dx, out_grads.at(0))); } if (ctx->grad_requires_grad) { in_grads->at(1) = JUST(functional::ScalarReversePowGrad(x, out_grads.at(0), ctx->operand)); } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_pow_grad", ScalarPowGradGrad); REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_reverse_pow_grad", ScalarReversePowGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/slice.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct SliceGradGradCaptureState : public AutoGradCaptureState { std::vector start; std::vector stop; std::vector step; }; class SliceGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SliceGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->start = JUST(composed_attrs.GetAttr>("start")); ctx->stop = JUST(composed_attrs.GetAttr>("stop")); ctx->step = JUST(composed_attrs.GetAttr>("step")); return Maybe::Ok(); } Maybe Apply(const SliceGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(1); in_grads->at(0) = JUST(functional::Slice(out_grads.at(0), ctx->start, ctx->stop, ctx->step, /*enable_view_slice=*/false)); return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("slice_grad", SliceGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/smooth_l1_loss.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct SmoothL1LossGradGradCaptureState : public AutoGradCaptureState { bool grad_requires_grad = false; bool input_requires_grad = false; bool target_requires_grad = false; size_t grad_index = 0; size_t input_index = 0; size_t target_index = 0; float beta = 0.0; }; class SmoothL1LossGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } Maybe Capture(SmoothL1LossGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { // grad, input, target CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg) ctx->grad_requires_grad = inputs[0]->requires_grad(); ctx->input_requires_grad = inputs[1]->requires_grad(); ctx->target_requires_grad = inputs[2]->requires_grad(); if (ctx->input_requires_grad || ctx->target_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs[0]); } ctx->input_index = ctx->SaveTensorForBackward(inputs[1]); ctx->target_index = ctx->SaveTensorForBackward(inputs[2]); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->beta = JUST(composed_attrs.GetAttr("beta")); return Maybe::Ok(); } Maybe Apply(const SmoothL1LossGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(3); const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index)); const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index)); if (ctx->grad_requires_grad) { (*in_grads)[0] = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta)); } if (ctx->input_requires_grad || ctx->target_requires_grad) { const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index)); auto condition = JUST(functional::sequence_function(functional::Sub) .then(functional::Abs) .then([&ctx](const std::shared_ptr& input) { return functional::ScalarLogicalLess(input, ctx->beta); }) .call(input, target, /*alpha=*/1, /*inplace=*/false)); auto out = JUST(functional::sequence_function(functional::Mul) .then(std::bind(functional::Mul, std::placeholders::_1, condition)) .then([&ctx](const std::shared_ptr& input) { double inv_beta = ctx->beta == 0.0 ? 0.0 : 1.0 / ctx->beta; return functional::ScalarMul(inv_beta, input); }) .call(out_grads[0], grad)); if (ctx->input_requires_grad) { (*in_grads)[1] = out; } if (ctx->target_requires_grad) { (*in_grads)[2] = JUST(functional::Negative(out)); } } return Maybe::Ok(); } private: AttrMap base_attrs_; }; REGISTER_OP_EXPR_GRAD_FUNCTION("smooth_l1_loss_grad", SmoothL1LossGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/autograd/higher_order_gradient_funcs/softmax.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { struct SoftmaxGradGradCaptureState : public AutoGradCaptureState { bool y_requires_grad = false; bool dy_requires_grad = false; }; class SoftmaxGradGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(SoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const SoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; Maybe SoftmaxGradGrad::Init(const OpExpr& op) { return Maybe::Ok(); } Maybe SoftmaxGradGrad::Capture(SoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { // y, dy CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->y_requires_grad = inputs[0]->requires_grad(); ctx->dy_requires_grad = inputs[1]->requires_grad(); ctx->SaveTensorForBackward(inputs[0]); if (ctx->y_requires_grad) ctx->SaveTensorForBackward(inputs[1]); return Maybe::Ok(); } Maybe SoftmaxGradGrad::Apply(const SoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(2); const auto& y = ctx->SavedTensors()[0]; if (ctx->y_requires_grad) { const auto& dy = ctx->SavedTensors()[1]; const std::vector reduce_axis{static_cast(y->ndim() - 1)}; const auto& a = JUST(functional::sequence_function(functional::Mul) .then(std::bind(functional::ReduceSum, std::placeholders::_1, reduce_axis, /*keepdim=*/true, NullOpt)) .then(std::bind(functional::Mul, std::placeholders::_1, dy)) .call(y, out_grads[0])); const auto& b = JUST(functional::sequence_function(functional::Mul) .then(std::bind(functional::ReduceSum, std::placeholders::_1, reduce_axis, /*keepdim=*/true, NullOpt)) .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0])) .call(y, dy)); in_grads->at(0) = JUST(functional::sequence_function(functional::Mul) .then(std::bind(functional::Sub, std::placeholders::_1, a, /*alpha=*/1, /*inplace=*/false)) .then(std::bind(functional::Sub, std::placeholders::_1, b, /*alpha=*/1, /*inplace=*/false)) .call(out_grads[0], dy)); } if (ctx->dy_requires_grad) { in_grads->at(1) = JUST(functional::SoftmaxGrad(out_grads[0], y)); } return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("softmax_grad", SoftmaxGradGrad); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/boxing/asymmetric_broadcast.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/container_util.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" namespace oneflow { namespace { Maybe RawCheckAsymmetricBroadcast(Symbol in, Symbol out, const Shape& logical_shape) { // NOLINTBEGIN(maybe-need-error-msg) CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp())); CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp())); CHECK_OR_RETURN(out->placement()->Bigger(*in->placement()) || in->placement()->Bigger(*out->placement())); CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU || in->placement()->device_type() == DeviceType::kCUDA); // NOLINTEND(maybe-need-error-msg) return Maybe::Ok(); } static constexpr auto* CheckAsymmetricBroadcast = DECORATE(&RawCheckAsymmetricBroadcast, ThreadLocalCachedCopiable); Maybe CalBroadcastRoot(Symbol src_parallel_desc, Symbol dst_parallel_desc) { int64_t machine_id = -1; int64_t device_id = -1; for (int64_t mach_id : src_parallel_desc->sorted_machine_ids()) { bool machine_and_device_id_inited = false; for (int64_t dev_id : src_parallel_desc->sorted_dev_phy_ids(mach_id)) { if (dst_parallel_desc->Containing(mach_id, dev_id)) { machine_id = mach_id; device_id = dev_id; machine_and_device_id_inited = true; break; } } if (machine_and_device_id_inited) { break; } } // Always true, if check failed, there is a bug in oneflow needed to be resolved. CHECK_OR_RETURN(machine_id != -1 && device_id != -1) << Error::RuntimeError() << "Calculate the intersection of placements " "failed during execution of asymmetric broadcast," << ", placement_a: " << *JUST(PlacementToString(src_parallel_desc)) << ", placement_b: " << *JUST(PlacementToString(dst_parallel_desc)) << "! 
Please submit an issue in `https://github.com/Oneflow-Inc/oneflow/issues` " "and we will fix it as soon as possible"; return machine_id; } static constexpr auto* CachedGetBroadcastRoot = DECORATE(&CalBroadcastRoot, ThreadLocalCached); Maybe EagerCclBroadcast(Symbol parallel_desc, int64_t root, const Shape& shape) { return one::OpBuilder("eager_ccl_broadcast", *JUST(UniqueStr("eager_ccl_broadcast"))) .Input("in") .Output("out") .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Attr>("shape_list", {shape}) .Attr("root", root) .Build(); } static constexpr auto* CachedEagerCclBroadcast = DECORATE(&EagerCclBroadcast, ThreadLocalCachedCopiable); } // namespace Maybe AsymmetricBroadcast(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& in_placement = in->placement(); const auto& out_placement = out->placement(); const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in_placement) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in_placement)) << ")"; std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); if (out->placement()->Bigger(*in->placement())) { const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_placement)); if (out_parallel_id->has_value()) { const auto& broadcast_group = JUST(GetBroadcastGroup(in_placement, out_placement)); Symbol broadcast_placement_cur_rank = JUST(MapAt(*broadcast_group, GlobalProcessCtx::Rank())); int64_t root = JUST(CachedGetBroadcastRoot(in_placement, broadcast_placement_cur_rank)); std::shared_ptr op_expr = JUST(CachedEagerCclBroadcast(broadcast_placement_cur_rank, root, *tensor->shape())); local_tensor = JUST(one::OpInterpUtil::Dispatch(*op_expr, {local_tensor})); } } return one::functional::LocalToGlobal(local_tensor, out_placement, *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false); } COMMAND(RegisterBoxingFunction("asymmetric-broadcast", CheckAsymmetricBroadcast, &AsymmetricBroadcast)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/boxing_dividor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_ #define ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" namespace oneflow { class PlacedNdSbp; class BoxingDividor final { public: BoxingDividor(const BoxingDividor&) = delete; BoxingDividor(BoxingDividor&&) = delete; ~BoxingDividor() = default; using FunctionT = std::function>(Symbol in, Symbol out)>; BoxingDividor(const std::string& name, const FunctionT& function) : name_(name), function_(function) {} const std::string& name() const { return name_; } Maybe> operator()(Symbol in, Symbol out) const { return function_(in, out); } private: std::string name_; FunctionT function_; }; } // namespace oneflow #endif // ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_ ================================================ FILE: oneflow/core/boxing/boxing_dividor_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/boxing/boxing_dividor_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { namespace { Maybe RawReplaceInDeviceType(DeviceType device_type) { return std::make_shared( "ReplaceInDeviceType", [device_type](Symbol in, Symbol out) -> Maybe> { const auto& new_placement = JUST(ReplaceDeviceType(in->placement(), device_type)); return PlacedNdSbp::New(in->nd_sbp(), new_placement); }); } Maybe RawReplaceOutDeviceType(DeviceType device_type) { return std::make_shared( "ReplaceOutDeviceType", [device_type](Symbol in, Symbol out) -> Maybe> { const auto& new_placement = JUST(ReplaceDeviceType(out->placement(), device_type)); return PlacedNdSbp::New(out->nd_sbp(), new_placement); }); } } // namespace decltype(ReplaceInDeviceType) ReplaceInDeviceType = DECORATE(&RawReplaceInDeviceType, ThreadLocalCached); decltype(ReplaceOutDeviceType) ReplaceOutDeviceType = DECORATE(&RawReplaceOutDeviceType, ThreadLocalCached); namespace { Maybe> RawFlattenHierarchy(Symbol placed_nd_sbp) { CHECK_GE_OR_RETURN(placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0) << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!"; const auto& first_sbp_parallel = placed_nd_sbp->nd_sbp()->sbp_parallel(0); for (const auto& sbp_parallel : placed_nd_sbp->nd_sbp()->sbp_parallel()) { CHECK_OR_RETURN(sbp_parallel == first_sbp_parallel) << Error::RuntimeError() << "Expected all sbps to be on the same in sbp list during flatten sbps list, but find at " "least two sbps, " << SbpToString(first_sbp_parallel) << " and " << SbpToString(sbp_parallel) << "!"; } std::vector> vec{SymbolOf(first_sbp_parallel)}; const auto& flattened_nd_sbp = JUST(GetNdSbp(vec)); ParallelConf flattened_parallel_conf(placed_nd_sbp->placement()->parallel_conf()); flattened_parallel_conf.clear_hierarchy(); const auto& flattened_placement = 
SymbolOf(ParallelDesc(flattened_parallel_conf)); return JUST(PlacedNdSbp::New(flattened_nd_sbp, flattened_placement)); } static constexpr auto* FlattenHierarchy = DECORATE(&RawFlattenHierarchy, ThreadLocalCached); Maybe RawFlattenInHierarchy() { return std::make_shared( "FlattenInHierarchy", [](Symbol in, Symbol out) -> Maybe> { return FlattenHierarchy(in); }); } Maybe> RawUnflattenHierarchy(Symbol in_placed_nd_sbp, Symbol out_placed_nd_sbp) { CHECK_GE_OR_RETURN(in_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0) << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!"; CHECK_GE_OR_RETURN(out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0) << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!"; const auto& in_sbp_parallel = in_placed_nd_sbp->nd_sbp()->sbp_parallel(0); NdSbp unflattened_nd_sbp; for (int64_t i = 0; i < out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(); ++i) { unflattened_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(in_sbp_parallel); } return JUST(PlacedNdSbp::New(SymbolOf(unflattened_nd_sbp), out_placed_nd_sbp->placement())); } static constexpr auto* UnflattenHierarchy = DECORATE(&RawUnflattenHierarchy, ThreadLocalCached); Maybe RawUnflattenInHierarchy() { return std::make_shared( "UnflattenInHierarchy", [](Symbol in, Symbol out) -> Maybe> { return UnflattenHierarchy(in, out); }); } Maybe RawUnflattenOutHierarchy() { return std::make_shared( "UnflattenOutHierarchy", [](Symbol in, Symbol out) -> Maybe> { return UnflattenHierarchy(out, in); }); } } // namespace decltype(FlattenInHierarchy) FlattenInHierarchy = DECORATE(&RawFlattenInHierarchy, ThreadLocalCached); decltype(UnflattenInHierarchy) UnflattenInHierarchy = DECORATE(&RawUnflattenInHierarchy, ThreadLocalCached); decltype(UnflattenOutHierarchy) UnflattenOutHierarchy = DECORATE(&RawUnflattenOutHierarchy, ThreadLocalCached); namespace { Maybe> GetAllPartialSumNdSbp(int64_t ndim) { NdSbp partial_sum_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { partial_sum_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel(); } return SymbolOf(partial_sum_nd_sbp); } auto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocalCached); Maybe> RawReplaceNdSbpWithPartialSum(Symbol placed_nd_sbp) { Symbol partial_sum_nd_sbp = JUST(CachedGetAllPartialSumNdSbp(placed_nd_sbp->nd_sbp()->sbp_parallel_size())); return JUST(PlacedNdSbp::New(partial_sum_nd_sbp, placed_nd_sbp->placement())); } static constexpr auto* ReplaceNdSbpWithPartialSum = DECORATE(&RawReplaceNdSbpWithPartialSum, ThreadLocalCached); Maybe RawOutPlacementAndPartialSum() { return std::make_shared( "OutPlacementAndPartialSum", [](Symbol in, Symbol out) -> Maybe> { return ReplaceNdSbpWithPartialSum(out); }); } } // namespace decltype(OutPlacementAndPartialSum) OutPlacementAndPartialSum = DECORATE(&RawOutPlacementAndPartialSum, ThreadLocalCached); namespace { Maybe> GetAllBroadcastNdSbp(int64_t ndim) { NdSbp broadcast_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); } return SymbolOf(broadcast_nd_sbp); } auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocalCached); Maybe> RawReplaceNdSbpWithBroadcast(Symbol placed_nd_sbp) { Symbol broadcast_nd_sbp = JUST(CachedGetAllBroadcastNdSbp(placed_nd_sbp->nd_sbp()->sbp_parallel_size())); return JUST(PlacedNdSbp::New(broadcast_nd_sbp, placed_nd_sbp->placement())); } static constexpr auto* ReplaceNdSbpWithBroadcast = DECORATE(&RawReplaceNdSbpWithBroadcast, ThreadLocalCached); Maybe 
RawInPlacementAndBroadcast() { return std::make_shared( "InPlacementAndBroadcast", [](Symbol in, Symbol out) -> Maybe> { return ReplaceNdSbpWithBroadcast(in); }); } Maybe RawOutPlacementAndBroadcast() { return std::make_shared( "OutPlacementAndBroadcast", [](Symbol in, Symbol out) -> Maybe> { return ReplaceNdSbpWithBroadcast(out); }); } } // namespace decltype(InPlacementAndBroadcast) InPlacementAndBroadcast = DECORATE(&RawInPlacementAndBroadcast, ThreadLocalCached); decltype(OutPlacementAndBroadcast) OutPlacementAndBroadcast = DECORATE(&RawOutPlacementAndBroadcast, ThreadLocalCached); namespace { Maybe> GetSplitNdSbp(int64_t axis) { NdSbp split_nd_sbp; split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis); return SymbolOf(split_nd_sbp); } auto* CachedGetSplitNdSbp = DECORATE(&GetSplitNdSbp, ThreadLocalCached); Maybe RawInPlacementAndSplit(int64_t axis) { return std::make_shared( "InPlacementAndSplit", [=](Symbol in, Symbol out) -> Maybe> { Symbol split_nd_sbp = JUST(CachedGetSplitNdSbp(axis)); return PlacedNdSbp::New(split_nd_sbp, in->placement()); }); } Maybe RawOutPlacementAndSplit(int64_t axis) { return std::make_shared( "OutPlacementAndSplit", [=](Symbol in, Symbol out) -> Maybe> { Symbol split_nd_sbp = JUST(CachedGetSplitNdSbp(axis)); return PlacedNdSbp::New(split_nd_sbp, out->placement()); }); } } // namespace decltype(InPlacementAndSplit) InPlacementAndSplit = DECORATE(&RawInPlacementAndSplit, ThreadLocalCached); decltype(OutPlacementAndSplit) OutPlacementAndSplit = DECORATE(&RawOutPlacementAndSplit, ThreadLocalCached); namespace { Maybe> GetFisrtDeviceOfPlacement(Symbol placement) { ParallelConf parallel_conf; int64_t machine_id = JUST(placement->MachineId4ParallelId(0)); int64_t device_id = JUST(placement->DeviceId4ParallelId(0)); parallel_conf.set_device_tag(placement->device_tag()); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":" + std::to_string(device_id)); for (int64_t i = 0; i < placement->hierarchy()->NumAxes(); ++i) { parallel_conf.mutable_hierarchy()->add_dim(1); } std::shared_ptr parallel_desc; JUST(PhysicalRun([¶llel_desc, ¶llel_conf](InstructionsBuilder* builder) -> Maybe { parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf)); return Maybe::Ok(); })); return SymbolOf(*parallel_desc); } Maybe RawInFirstDeviceAndAllBroadcast() { return std::make_shared( "InFirstDeviceAndAllBroadcast", [](Symbol in, Symbol out) -> Maybe> { return PlacedNdSbp::New(JUST(CachedGetAllBroadcastNdSbp(in->nd_sbp()->sbp_parallel_size())), JUST(GetFisrtDeviceOfPlacement(in->placement()))); }); } Maybe RawOutFirstDeviceAndAllBroadcast() { return std::make_shared( "OutFirstDeviceAndAllBroadcast", [](Symbol in, Symbol out) -> Maybe> { return PlacedNdSbp::New( JUST(CachedGetAllBroadcastNdSbp(out->nd_sbp()->sbp_parallel_size())), JUST(GetFisrtDeviceOfPlacement(out->placement()))); }); } } // namespace decltype(InFirstDeviceAndAllBroadcast) InFirstDeviceAndAllBroadcast = DECORATE(&RawInFirstDeviceAndAllBroadcast, ThreadLocalCached); decltype(OutFirstDeviceAndAllBroadcast) OutFirstDeviceAndAllBroadcast = DECORATE(&RawOutFirstDeviceAndAllBroadcast, ThreadLocalCached); namespace { Maybe> RawPlacementAndRepeatFirstSbp(Symbol placed_nd_sbp) { const auto& first_sbp_parallel = placed_nd_sbp->nd_sbp()->sbp_parallel(0); NdSbp out_nd_sbp; for (int64_t i = 0; i < placed_nd_sbp->nd_sbp()->sbp_parallel_size(); ++i) { out_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(first_sbp_parallel); } return 
JUST(PlacedNdSbp::New(SymbolOf(out_nd_sbp), placed_nd_sbp->placement()));
}

static constexpr auto* PlacementAndRepeatFirstSbp =
    DECORATE(&RawPlacementAndRepeatFirstSbp, ThreadLocalCached);

Maybe<BoxingDividor> RawInPlacementAndRepeatFirstSbp() {
  return std::make_shared<BoxingDividor>(
      "InPlacementAndRepeatFirstSbp",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return PlacementAndRepeatFirstSbp(in);
      });
}

}  // namespace

decltype(InPlacementAndRepeatFirstSbp) InPlacementAndRepeatFirstSbp =
    DECORATE(&RawInPlacementAndRepeatFirstSbp, ThreadLocalCached);

}  // namespace oneflow


================================================
FILE: oneflow/core/boxing/boxing_dividor_util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_
#define ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_

#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/boxing/boxing_dividor.h"

namespace oneflow {

extern Maybe<BoxingDividor> (*ReplaceInDeviceType)(DeviceType device_type);
extern Maybe<BoxingDividor> (*ReplaceOutDeviceType)(DeviceType device_type);
extern Maybe<BoxingDividor> (*FlattenInHierarchy)();
extern Maybe<BoxingDividor> (*UnflattenInHierarchy)();
extern Maybe<BoxingDividor> (*UnflattenOutHierarchy)();
extern Maybe<BoxingDividor> (*OutPlacementAndPartialSum)();
extern Maybe<BoxingDividor> (*InPlacementAndBroadcast)();
extern Maybe<BoxingDividor> (*OutPlacementAndBroadcast)();
extern Maybe<BoxingDividor> (*InPlacementAndSplit)(int64_t axis);
extern Maybe<BoxingDividor> (*OutPlacementAndSplit)(int64_t axis);
extern Maybe<BoxingDividor> (*InFirstDeviceAndAllBroadcast)();
extern Maybe<BoxingDividor> (*OutFirstDeviceAndAllBroadcast)();
extern Maybe<BoxingDividor> (*InPlacementAndRepeatFirstSbp)();

}  // namespace oneflow

#endif  // ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_


================================================
FILE: oneflow/core/boxing/boxing_interpreter_status.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/common/decorator.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/boxing/boxing_interpreter_status.h" namespace oneflow { namespace { Maybe RawMakeBoxingInterpreterStatus(const std::string& boxing_name, const Shape& logical_shape, Symbol in, Symbol out) { std::vector sorted_boxing_names{boxing_name}; BoxingInterpreterStatus status(SymbolOf(sorted_boxing_names), logical_shape, in, out); return status; } Maybe RawMakeComposedBoxingInterpreterStatus( const std::shared_ptr& lhs_status, const std::shared_ptr& rhs_status) { CHECK_OR_RETURN(lhs_status->dst_placed_nd_sbp() == rhs_status->src_placed_nd_sbp()) // always true << Error::RuntimeError() << "Intermediate placed_nd_sbp must be equal when compose boxing interpreter status" << ". lhs_status.dst_nd_sbp: " << NdSbpToString(lhs_status->dst_placed_nd_sbp()->nd_sbp()) << ", rhs_status.dst_nd_sbp: " << NdSbpToString(rhs_status->src_placed_nd_sbp()->nd_sbp()) << ", lhs_status.dst_placement: " << *JUST(PlacementToString(lhs_status->dst_placed_nd_sbp()->placement())) << ", rhs_status.dst_placement: " << *JUST(PlacementToString(rhs_status->src_placed_nd_sbp()->placement())); CHECK_OR_RETURN(lhs_status->logical_shape() == rhs_status->logical_shape()) // always true << Error::RuntimeError() << "Logical_shape must be equal when compose boxing interpreter status" << ". lhs_status.logical_shape: " << (lhs_status->logical_shape().ToString()) << ". rhs_status.logical_shape: " << (rhs_status->logical_shape().ToString()); std::vector sorted_boxing_names(*lhs_status->sorted_boxing_names()); sorted_boxing_names.insert(sorted_boxing_names.end(), rhs_status->sorted_boxing_names()->begin(), rhs_status->sorted_boxing_names()->end()); std::vector> mid_placed_nd_sbp(*lhs_status->mid_placed_nd_sbp()); mid_placed_nd_sbp.emplace_back(lhs_status->dst_placed_nd_sbp()); mid_placed_nd_sbp.insert(mid_placed_nd_sbp.end(), rhs_status->mid_placed_nd_sbp()->begin(), rhs_status->mid_placed_nd_sbp()->end()); BoxingInterpreterStatus status(sorted_boxing_names, lhs_status->logical_shape(), lhs_status->src_placed_nd_sbp(), SymbolOf(mid_placed_nd_sbp), rhs_status->dst_placed_nd_sbp()); return status; } } // namespace decltype(MakeBoxingInterpreterStatus) MakeBoxingInterpreterStatus = DECORATE(&RawMakeBoxingInterpreterStatus, ThreadLocalCachedCopiable); decltype(MakeComposedBoxingInterpreterStatus) MakeComposedBoxingInterpreterStatus = DECORATE(&RawMakeComposedBoxingInterpreterStatus, ThreadLocalCachedCopiable); namespace { Maybe RawGetNdSbpRouting(Symbol src_placed_nd_sbp, Symbol>> mid_placed_nd_sbp, Symbol dst_placed_nd_sbp) { std::ostringstream ss; ss << NdSbpToString(src_placed_nd_sbp->nd_sbp()); for (const auto& placed_nd_sbp : *mid_placed_nd_sbp) { ss << " -> " << NdSbpToString(placed_nd_sbp->nd_sbp()); } ss << " -> " << NdSbpToString(dst_placed_nd_sbp->nd_sbp()); return ss.str(); } Maybe RawGetPlacementRouting( Symbol src_placed_nd_sbp, Symbol>> mid_placed_nd_sbp, Symbol dst_placed_nd_sbp) { std::ostringstream ss; ss << *JUST(PlacementToString(src_placed_nd_sbp->placement())); for (const auto& placed_nd_sbp : *mid_placed_nd_sbp) { ss << " -> " << *JUST(PlacementToString(placed_nd_sbp->placement())); } ss << " -> " << *JUST(PlacementToString(dst_placed_nd_sbp->placement())); return ss.str(); } Maybe RawGetBoxingDesc(Symbol> sorted_boxing_names) { CHECK_OR_RETURN(!sorted_boxing_names->empty()) // always true << Error::RuntimeError() << 
"boxing_names of eager boxing status can't be empty!"; std::ostringstream ss; ss << sorted_boxing_names->at(0); for (size_t i = 1; i < sorted_boxing_names->size(); ++i) { ss << " -> " << sorted_boxing_names->at(i); } return ss.str(); } static constexpr auto* GetNdSbpRouting = DECORATE(&RawGetNdSbpRouting, ThreadLocalCached); static constexpr auto* GetPlacementRouting = DECORATE(&RawGetPlacementRouting, ThreadLocalCached); static constexpr auto* GetBoxingDesc = DECORATE(&RawGetBoxingDesc, ThreadLocalCached); } // namespace const std::string& BoxingInterpreterStatus::boxing_routing() const { return *CHECK_JUST(GetBoxingDesc(sorted_boxing_names_)); } const std::string& BoxingInterpreterStatus::nd_sbp_routing() const { return *CHECK_JUST(GetNdSbpRouting(src_placed_nd_sbp_, mid_placed_nd_sbp_, dst_placed_nd_sbp_)); } const std::string& BoxingInterpreterStatus::placement_routing() const { return *CHECK_JUST( GetPlacementRouting(src_placed_nd_sbp_, mid_placed_nd_sbp_, dst_placed_nd_sbp_)); } } // namespace oneflow ================================================ FILE: oneflow/core/boxing/boxing_interpreter_status.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_ #define ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/common/shape.h" namespace oneflow { class BoxingInterpreterStatus; extern Maybe (*MakeBoxingInterpreterStatus)(const std::string& boxing_name, const Shape& logical_shape, Symbol in, Symbol out); extern Maybe (*MakeComposedBoxingInterpreterStatus)( const std::shared_ptr& lhs_status, const std::shared_ptr& rhs_status); class BoxingInterpreterStatus final { public: BoxingInterpreterStatus(Symbol> sorted_boxing_names, const Shape& logical_shape, Symbol src_placed_nd_sbp, Symbol>> mid_placed_nd_sbp, Symbol dst_placed_nd_sbp) : sorted_boxing_names_(sorted_boxing_names), logical_shape_(logical_shape), src_placed_nd_sbp_(src_placed_nd_sbp), mid_placed_nd_sbp_(mid_placed_nd_sbp), dst_placed_nd_sbp_(dst_placed_nd_sbp) {} BoxingInterpreterStatus(Symbol> sorted_boxing_names, const Shape& logical_shape, Symbol src_placed_nd_sbp, Symbol dst_placed_nd_sbp) : BoxingInterpreterStatus(sorted_boxing_names, logical_shape, src_placed_nd_sbp, SymbolOf(std::vector>()), dst_placed_nd_sbp) {} ~BoxingInterpreterStatus() = default; bool operator==(const BoxingInterpreterStatus& other) const { return this->sorted_boxing_names_ == other.sorted_boxing_names_ && this->src_placed_nd_sbp_ == other.src_placed_nd_sbp_ && this->mid_placed_nd_sbp_ == other.mid_placed_nd_sbp_ && this->dst_placed_nd_sbp_ == other.dst_placed_nd_sbp_; } // Getters Symbol> sorted_boxing_names() const { return sorted_boxing_names_; } const Shape& logical_shape() const { return logical_shape_; } Symbol src_placed_nd_sbp() const { return src_placed_nd_sbp_; } Symbol 
dst_placed_nd_sbp() const { return dst_placed_nd_sbp_; } Symbol>> mid_placed_nd_sbp() const { return mid_placed_nd_sbp_; } const std::string& boxing_routing() const; const std::string& nd_sbp_routing() const; const std::string& placement_routing() const; private: Symbol> sorted_boxing_names_; const Shape logical_shape_; Symbol src_placed_nd_sbp_; Symbol>> mid_placed_nd_sbp_; Symbol dst_placed_nd_sbp_; }; } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::BoxingInterpreterStatus& status) const { using namespace oneflow; size_t ret = 0; for (const auto& boxing_name : *status.sorted_boxing_names()) { AddHash(&ret, boxing_name); } AddHash(&ret, *status.src_placed_nd_sbp()); for (const auto& mid_placed_nd_sbp : *status.mid_placed_nd_sbp()) { AddHash(&ret, *mid_placed_nd_sbp); } AddHash(&ret, *status.dst_placed_nd_sbp()); return ret; } }; } // namespace std #endif // ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_ ================================================ FILE: oneflow/core/boxing/ccl_boxing_function.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/user_op_registry_manager.h" namespace oneflow { namespace { class EagerBoxingKernelRegContext final : public user_op::KernelRegContext { public: explicit EagerBoxingKernelRegContext(DeviceType device_type) : device_type_(device_type) {} ~EagerBoxingKernelRegContext() = default; DeviceType device_type() const override { return device_type_; } const ParallelContext& parallel_ctx() const override { PRINT_BUG_PROMPT_AND_ABORT(); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { PRINT_BUG_PROMPT_AND_ABORT(); } const std::vector>& inputs() const override { PRINT_BUG_PROMPT_AND_ABORT(); } const std::vector>& outputs() const override { PRINT_BUG_PROMPT_AND_ABORT(); } const user_op::UserOpConfWrapper& user_op_conf() const override { PRINT_BUG_PROMPT_AND_ABORT(); } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { PRINT_BUG_PROMPT_AND_ABORT(); } private: DeviceType device_type_; }; Maybe RawCheckCclKernelRegistered(const std::string& op_type_name, DeviceType device_type) { EagerBoxingKernelRegContext reg_ctx(device_type); return user_op::UserOpRegistryMgr::Get().IsOpKernelRegistered(op_type_name, reg_ctx); } static constexpr auto* CheckCclKernelRegistered = DECORATE(&RawCheckCclKernelRegistered, ThreadLocalCachedCopiable); Maybe RawCheckCclP2B(Symbol in, Symbol out, const Shape& logical_shape) { // NOLINTBEGIN(maybe-need-error-msg) CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); 
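  // (Added explanatory comments; not part of the original source.)
  // Each "ccl-*" boxing in this file pairs a checker (RawCheckCcl*) with a boxing
  // function (Ccl*); the two are registered together at the bottom of the file,
  // e.g. COMMAND(RegisterBoxingFunction("ccl-p-to-b", CheckCclP2B, &CclP2B)).
  // The remaining checks of this P->B checker require: a 1-D output nd_sbp, an
  // all-partial-sum input, an all-broadcast output, identical input/output
  // placements, and a registered "eager_ccl_all_reduce" kernel for the
  // placement's device type; only then is this boxing (a global all-reduce) used.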
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
  CHECK_OR_RETURN(in->placement() == out->placement());
  CHECK_OR_RETURN(                                                        // NOLINT
      JUST(CheckCclKernelRegistered("eager_ccl_all_reduce",               // NOLINT
                                    in->placement()->device_type())));    // NOLINT
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckCclP2B = DECORATE(&RawCheckCclP2B, ThreadLocalCachedCopiable);

Maybe<void> RawCheckCclP2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));
  CHECK_OR_RETURN(NdSbpIsAllSplit(*out->nd_sbp(), 0));
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);
  CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  CHECK_OR_RETURN(                                                        // NOLINT
      JUST(CheckCclKernelRegistered("eager_ccl_reduce_scatter",           // NOLINT
                                    in->placement()->device_type())));    // NOLINT
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckCclP2S = DECORATE(&RawCheckCclP2S, ThreadLocalCachedCopiable);

Maybe<void> RawCheckCclS2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllSplit(*in->nd_sbp(), 0));
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);
  CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  CHECK_OR_RETURN(                                                        // NOLINT
      JUST(CheckCclKernelRegistered("eager_ccl_all_gather",               // NOLINT
                                    in->placement()->device_type())));    // NOLINT
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckCclS2B = DECORATE(&RawCheckCclS2B, ThreadLocalCachedCopiable);

Maybe<void> RawCheckCclS2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(in->nd_sbp()->sbp_parallel(0).has_split_parallel());
  CHECK_OR_RETURN(out->nd_sbp()->sbp_parallel(0).has_split_parallel());
  CHECK_NE_OR_RETURN(in->nd_sbp()->sbp_parallel(0).split_parallel().axis(),
                     out->nd_sbp()->sbp_parallel(0).split_parallel().axis());
  int64_t in_split_axis = in->nd_sbp()->sbp_parallel(0).split_parallel().axis();
  int64_t out_split_axis = out->nd_sbp()->sbp_parallel(0).split_parallel().axis();
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), in_split_axis);
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), out_split_axis);
  CHECK_OR_RETURN(logical_shape.At(in_split_axis) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(logical_shape.At(out_split_axis) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
                  || in->placement()->device_type() == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckCclS2S = DECORATE(&RawCheckCclS2S, ThreadLocalCachedCopiable);

}  // namespace

Maybe<one::Tensor> CclP2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::GlobalAllReduce(tensor));
}

Maybe<one::Tensor> CclP2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::GlobalReduceScatter(tensor, "sum"));
}

Maybe<one::Tensor> CclS2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::GlobalAllGather(tensor));
}

Maybe<one::Tensor> CclS2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::GlobalS2S(tensor, *JUST(GetSbpList(out->nd_sbp()))));
}

COMMAND(RegisterBoxingFunction("ccl-p-to-b", CheckCclP2B, &CclP2B));
COMMAND(RegisterBoxingFunction("ccl-p-to-s", CheckCclP2S, &CclP2S));
COMMAND(RegisterBoxingFunction("ccl-s-to-b", CheckCclS2B, &CclS2B));
COMMAND(RegisterBoxingFunction("ccl-s-to-s", CheckCclS2S, &CclS2S));

}  // namespace oneflow


================================================
FILE: oneflow/core/boxing/cuda_copy_boxing_interpreter.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/parallel_desc.h"

namespace oneflow {

namespace {

Maybe<bool> IgnoringDeviceTypeEqual(Symbol<ParallelDesc> lhs, Symbol<ParallelDesc> rhs) {
  return lhs == JUST(ReplaceDeviceType(rhs, lhs->device_type()));
}

}  // namespace

// NOLINTBEGIN(maybe-need-error-msg)
Maybe<void> CheckCopyH2D(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                         const Shape& logical_shape) {
  bool equal = JUST(IgnoringDeviceTypeEqual(in->placement(), out->placement()));
  CHECK_OR_RETURN(equal);
  CHECK_EQ_OR_RETURN(in->placement()->device_type(), DeviceType::kCPU);
  CHECK_NE_OR_RETURN(out->placement()->device_type(), DeviceType::kCPU);
  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());
  return Maybe<void>::Ok();
}

Maybe<void> CheckCopyD2H(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                         const Shape& logical_shape) {
  bool equal = JUST(IgnoringDeviceTypeEqual(in->placement(), out->placement()));
  CHECK_OR_RETURN(equal);
  CHECK_NE_OR_RETURN(in->placement()->device_type(), DeviceType::kCPU);
  CHECK_EQ_OR_RETURN(out->placement()->device_type(), DeviceType::kCPU);
  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());
  return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)

Maybe<one::Tensor> CopyBoxingFunction(const std::shared_ptr<one::Tensor>& tensor,
                                      Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  const std::shared_ptr<one::Tensor>& local_tensor = JUST(tensor->cur_rank_phy_tensor());
  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
                                             *tensor->shape(), tensor->dtype(),
                                             /* sync_data */ false, /*copy=*/false));
}

COMMAND(RegisterBoxingFunction("copy-h2d", &CheckCopyH2D, &CopyBoxingFunction));
COMMAND(RegisterBoxingFunction("copy-d2h", &CheckCopyD2H, &CopyBoxingFunction));

}  // namespace oneflow


================================================
FILE: oneflow/core/boxing/eager_boxing_interpreter.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/registry_error.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/framework/nd_sbp.h" namespace oneflow { namespace { Maybe CheckEagerBoxingDataType(DataType val) { CHECK_OR_RETURN(val != DataType::kTensorBuffer && val != DataType::kOFRecord) << Error::RuntimeError() << "invalid boxing data type " << ToString(val); return Maybe::Ok(); } } // namespace Maybe EagerBoxingInterpreter::Interpret(const std::shared_ptr& input, Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc) const { JUST(CheckEagerBoxingDataType(input->dtype()->data_type())); DisableCheckGlobalTensorMetaScope disable_meta_check; const auto& tensor = JUST(InterpretImpl(input, in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc)); const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_nd_sbp == out_nd_sbp) << Error::RuntimeError() << "The sbp of output tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the output sbp (" << NdSbpToString(out_nd_sbp) << ")"; CHECK_OR_RETURN(tensor_placement == out_parallel_desc) << Error::RuntimeError() << "The placement of output tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the output placement (" << *JUST(PlacementToString(out_parallel_desc)) << ")"; return tensor; } namespace { HashMap* MutName2BoxingChecker() { static HashMap map; return ↦ } HashMap* MutName2BoxingFunction() { static HashMap map; return ↦ } Maybe RawGetBoxingFunction(const std::string& method_name, Symbol in, Symbol out, const Shape& logical_shape) { const auto& Checker = JUST_MSG(MapAt(*MutName2BoxingChecker(), method_name), std::stringstream() << "boxing checker not found. checker_name: " << method_name); JUST(Checker(in, out, logical_shape)); return JUST_MSG(MapAt(*MutName2BoxingFunction(), method_name), std::stringstream() << "boxing function not found. function_name: " << method_name); } } // namespace Maybe GetBoxingFunction(const std::string& method_name, Symbol in, Symbol out, const Shape& logical_shape) { return DECORATE(&RawGetBoxingFunction, ThreadLocalCachedCopiable)(method_name, in, out, logical_shape); } void RegisterBoxingFunction(const std::string& method_name, const BoxingCheckerT& Checker, const BoxingFunctionT& BoxingFunction) { CatchRegistryError([&]() -> Maybe { CHECK_OR_RETURN(MutName2BoxingChecker()->emplace(method_name, Checker).second) << Error::RuntimeError() << "register boxing checker failed: " << method_name; CHECK_OR_RETURN(MutName2BoxingFunction()->emplace(method_name, BoxingFunction).second) << Error::RuntimeError() << "register boxing function failed: " << method_name; return Maybe::Ok(); }); } Maybe AtomicBoxingExpr::Check(Symbol in, Symbol out, const Shape& logical_shape) const { const auto& Checker = JUST_MSG(MapAt(*MutName2BoxingChecker(), boxing_name_), std::stringstream() << "boxing checker not found. 
checker_name: " << boxing_name_); JUST(Checker(in, out, logical_shape)); return MakeBoxingInterpreterStatus(boxing_name_, logical_shape, in, out); } Maybe AtomicBoxingExpr::GetBoxingFunction(Symbol in, Symbol out, const Shape& logical_shape) const { return DECORATE(&RawGetBoxingFunction, ThreadLocalCachedCopiable)(boxing_name_, in, out, logical_shape); } Maybe DivideAndConquerBoxingExpr::Check(Symbol in, Symbol out, const Shape& logical_shape) const { const auto& middle = JUST((*boxing_dividor_)(in, out)); const auto& lhs_status = JUST(lhs_conquer_->Check(in, middle, logical_shape)); const auto& rhs_status = JUST(rhs_conquer_->Check(middle, out, logical_shape)); return MakeComposedBoxingInterpreterStatus(lhs_status, rhs_status); } Maybe DivideAndConquerBoxingExpr::GetBoxingFunction( Symbol in, Symbol out, const Shape& logical_shape) const { const auto& middle = JUST((*boxing_dividor_)(in, out)); const auto& lhs_boxing_func = JUST(lhs_conquer_->GetBoxingFunction(in, middle, logical_shape)); const auto& rhs_boxing_func = JUST(rhs_conquer_->GetBoxingFunction(middle, out, logical_shape)); BoxingFunctionT boxing_function = [lhs_boxing_func, rhs_boxing_func, middle, in, out, &logical_shape]( const std::shared_ptr& tensor, Symbol arg_in, Symbol arg_out) -> Maybe { // Always true, if check failed, there is a bug in oneflow needed to be resolved. CHECK_OR_RETURN(in == arg_in) << Error::RuntimeError() << "The placement (" << *JUST(PlacementToString(arg_in->placement())) << ") and sbp (" << NdSbpToString(in->nd_sbp()) << ") of input tensor must match the placement (" << *JUST(PlacementToString(in->placement())) << ") and sbp (" << NdSbpToString(arg_in->nd_sbp()) << ") used for get this boxing function! Please submit an issue " "in `https://github.com/Oneflow-Inc/oneflow/issues` " "and we will fix it as soon as possible"; CHECK_OR_RETURN(logical_shape == *tensor->shape()) << Error::RuntimeError() << "The logical_shape " << tensor->shape()->ToString() << " of input tensor must match the logical_shape " << logical_shape.ToString() << " used for get this boxing function! Please submit an issue in " "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it " "as soon as possible"; CHECK_OR_RETURN(out == arg_out) << Error::RuntimeError() << "The placement (" << *JUST(PlacementToString(arg_out->placement())) << ") and sbp (" << NdSbpToString(arg_out->nd_sbp()) << ") of output tensor must match the placement (" << *JUST(PlacementToString(out->placement())) << ") and sbp (" << NdSbpToString(out->nd_sbp()) << ") used for get this boxing function! 
Please submit " "an issue in `https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it " "as soon as possible"; const auto& middle_tensor = JUST((*lhs_boxing_func)(tensor, in, middle)); return JUST((*rhs_boxing_func)(middle_tensor, middle, out)); }; return boxing_function; } Maybe OrBoxingExpr::Check(Symbol in, Symbol out, const Shape& logical_shape) const { const auto& lhs_status = TRY(lhs_boxing_->Check(in, out, logical_shape)); if (lhs_status.IsOk()) { return lhs_status; } return rhs_boxing_->Check(in, out, logical_shape); } Maybe OrBoxingExpr::GetBoxingFunction(Symbol in, Symbol out, const Shape& logical_shape) const { if (lhs_boxing_->Check(in, out, logical_shape).IsOk()) { return lhs_boxing_->GetBoxingFunction(in, out, logical_shape); } JUST(rhs_boxing_->Check(in, out, logical_shape)); return rhs_boxing_->GetBoxingFunction(in, out, logical_shape); } Maybe BoxingExpr(const std::string& boxing_name) { JUST(MapAt(*MutName2BoxingChecker(), boxing_name)); auto boxing_expr = std::make_unique(boxing_name); return std::shared_ptr(std::move(boxing_expr)); } Maybe BoxingExpr(const std::shared_ptr& boxing_dividor, const std::string& lhs_conquer, const std::string& rhs_conquer) { return BoxingExpr(boxing_dividor, JUST(BoxingExpr(lhs_conquer)), JUST(BoxingExpr(rhs_conquer))); } Maybe BoxingExpr(const std::shared_ptr& boxing_dividor, const std::shared_ptr& lhs_conquer, const std::string& rhs_conquer) { return BoxingExpr(boxing_dividor, lhs_conquer, JUST(BoxingExpr(rhs_conquer))); } Maybe BoxingExpr(const std::shared_ptr& boxing_dividor, const std::string& lhs_conquer, const std::shared_ptr& rhs_conquer) { return BoxingExpr(boxing_dividor, JUST(BoxingExpr(lhs_conquer)), rhs_conquer); } Maybe BoxingExpr(const std::shared_ptr& boxing_dividor, const std::shared_ptr& lhs_conquer, const std::shared_ptr& rhs_conquer) { auto divide_and_conquer = std::make_unique(boxing_dividor, lhs_conquer, rhs_conquer); return std::shared_ptr(std::move(divide_and_conquer)); } std::shared_ptr operator|(const std::shared_ptr& lhs_boxing, const std::shared_ptr& rhs_boxing) { auto or_boxing = std::make_unique(lhs_boxing, rhs_boxing); return std::shared_ptr(std::move(or_boxing)); } Maybe OptionalBoxing(const std::string& boxing_mame) { return JUST(BoxingExpr(boxing_mame)) | JUST(BoxingExpr("identity")); } } // namespace oneflow ================================================ FILE: oneflow/core/boxing/eager_boxing_interpreter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_H_ #define ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_H_ #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/boxing/boxing_dividor.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/boxing/boxing_interpreter_status.h" namespace oneflow { class EagerBoxingInterpreter { public: OF_DISALLOW_COPY_AND_MOVE(EagerBoxingInterpreter); EagerBoxingInterpreter() = default; virtual ~EagerBoxingInterpreter() = default; Maybe Interpret(const std::shared_ptr& input, Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc) const; virtual Maybe boxing_interpreter_status() const = 0; protected: virtual Maybe InterpretImpl(const std::shared_ptr& input, Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc) const = 0; }; using BoxingCheckerT = std::function(Symbol in, Symbol out, const Shape& logical_shape)>; using BoxingFunctionT = std::function( const std::shared_ptr& input, Symbol in, Symbol out)>; Maybe GetBoxingFunction(const std::string& method_name, Symbol in, Symbol out, const Shape& logical_shape); void RegisterBoxingFunction(const std::string& method_name, const BoxingCheckerT& Check, const BoxingFunctionT& BoxingFunction); inline void RegisterBoxingFunction( const std::string& method_name, const std::pair& CheckAndBoxing) { RegisterBoxingFunction(method_name, CheckAndBoxing.first, CheckAndBoxing.second); } class NaiveEagerBoxingInterpreter : public EagerBoxingInterpreter { public: explicit NaiveEagerBoxingInterpreter( const std::shared_ptr& boxing_function, const std::shared_ptr& boxing_interpreter_status) : boxing_function_(boxing_function), boxing_interpreter_status_(boxing_interpreter_status) {} NaiveEagerBoxingInterpreter(const NaiveEagerBoxingInterpreter&) = delete; NaiveEagerBoxingInterpreter(NaiveEagerBoxingInterpreter&&) = delete; ~NaiveEagerBoxingInterpreter() override = default; Maybe boxing_interpreter_status() const override { return boxing_interpreter_status_; } private: Maybe InterpretImpl(const std::shared_ptr& input, Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc) const override { const auto& in_placed_nd_sbp = JUST(PlacedNdSbp::New(in_nd_sbp, in_parallel_desc)); const auto& out_placed_nd_sbp = JUST(PlacedNdSbp::New(out_nd_sbp, out_parallel_desc)); return JUST((*boxing_function_)(input, in_placed_nd_sbp, out_placed_nd_sbp)); } const std::shared_ptr boxing_function_; const std::shared_ptr boxing_interpreter_status_; }; class BoxingExprIf { public: BoxingExprIf(const BoxingExprIf&) = default; BoxingExprIf(BoxingExprIf&&) = default; virtual ~BoxingExprIf() = default; virtual Maybe Check(Symbol in, Symbol out, const Shape& logical_shape) const = 0; virtual Maybe GetBoxingFunction(Symbol in, Symbol out, const Shape& logical_shape) const = 0; protected: BoxingExprIf() = default; }; class AtomicBoxingExpr final : public BoxingExprIf { public: AtomicBoxingExpr(const AtomicBoxingExpr&) = delete; AtomicBoxingExpr(AtomicBoxingExpr&&) = delete; ~AtomicBoxingExpr() override = default; explicit AtomicBoxingExpr(const std::string& boxing_name) : BoxingExprIf(), boxing_name_(boxing_name) {} Maybe Check(Symbol in, Symbol out, const Shape& logical_shape) const override; Maybe GetBoxingFunction(Symbol in, Symbol out, const Shape& 
logical_shape) const override; private: const std::string boxing_name_; }; class DivideAndConquerBoxingExpr final : public BoxingExprIf { public: DivideAndConquerBoxingExpr(const DivideAndConquerBoxingExpr&) = delete; DivideAndConquerBoxingExpr(DivideAndConquerBoxingExpr&&) = delete; ~DivideAndConquerBoxingExpr() override = default; explicit DivideAndConquerBoxingExpr(const std::shared_ptr& boxing_dividor, const std::shared_ptr& lhs_conquer, const std::shared_ptr& rhs_conquer) : BoxingExprIf(), boxing_dividor_(boxing_dividor), lhs_conquer_(lhs_conquer), rhs_conquer_(rhs_conquer) {} Maybe Check(Symbol in, Symbol out, const Shape& logical_shape) const override; Maybe GetBoxingFunction(Symbol in, Symbol out, const Shape& logical_shape) const override; private: const std::shared_ptr boxing_dividor_; const std::shared_ptr lhs_conquer_; const std::shared_ptr rhs_conquer_; }; class OrBoxingExpr final : public BoxingExprIf { public: OrBoxingExpr(const OrBoxingExpr&) = delete; OrBoxingExpr(OrBoxingExpr&&) = delete; ~OrBoxingExpr() override = default; explicit OrBoxingExpr(const std::shared_ptr& lhs_boxing, const std::shared_ptr& rhs_boxing) : BoxingExprIf(), lhs_boxing_(lhs_boxing), rhs_boxing_(rhs_boxing) {} Maybe Check(Symbol in, Symbol out, const Shape& logical_shape) const override; Maybe GetBoxingFunction(Symbol in, Symbol out, const Shape& logical_shape) const override; private: const std::shared_ptr lhs_boxing_; const std::shared_ptr rhs_boxing_; }; Maybe BoxingExpr(const std::string& boxing_name); Maybe BoxingExpr(const std::shared_ptr& boxing_dividor, const std::string& lhs_conquer, const std::string& rhs_conquer); Maybe BoxingExpr(const std::shared_ptr& boxing_dividor, const std::shared_ptr& lhs_conquer, const std::string& rhs_conquer); Maybe BoxingExpr(const std::shared_ptr& boxing_dividor, const std::string& lhs_conquer, const std::shared_ptr& rhs_conquer); Maybe BoxingExpr(const std::shared_ptr& boxing_dividor, const std::shared_ptr& lhs_conquer, const std::shared_ptr& rhs_conquer); std::shared_ptr operator|(const std::shared_ptr& lhs_boxing, const std::shared_ptr& rhs_boxing); Maybe OptionalBoxing(const std::string& boxing_mame); } // namespace oneflow #endif // ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_H_ ================================================ FILE: oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/common/constant.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/boxing/boxing_dividor_util.h" namespace oneflow { namespace { Maybe IgnoringDeviceTypeEqual(Symbol lhs, Symbol rhs) { if (lhs == rhs) { return true; } return lhs == JUST(ReplaceDeviceType(rhs, lhs->device_type())); } namespace { Maybe OptionalCudaCopy(const std::shared_ptr& core_boxing_expr) { return JUST(BoxingExpr(JUST(ReplaceInDeviceType(DeviceType::kCUDA)), JUST(OptionalBoxing("copy-h2d")), JUST(BoxingExpr(JUST(ReplaceOutDeviceType(DeviceType::kCUDA)), core_boxing_expr, JUST(OptionalBoxing("copy-d2h")))))); } Maybe OptionalCpuCopy(const std::shared_ptr& core_boxing_expr) { return JUST(BoxingExpr(JUST(ReplaceInDeviceType(DeviceType::kCPU)), JUST(OptionalBoxing("copy-d2h")), JUST(BoxingExpr(JUST(ReplaceOutDeviceType(DeviceType::kCPU)), core_boxing_expr, JUST(OptionalBoxing("copy-h2d")))))); } Maybe SymmetricOneDimSxToBBoxingExpr() { return JUST(BoxingExpr(JUST(InPlacementAndSplit(0)), JUST(OptionalBoxing("ccl-s-to-s")), JUST(BoxingExpr("ccl-s-to-b")))); } Maybe SymmetricOneDimPToSxBoxingExpr() { return JUST(BoxingExpr(JUST(OutPlacementAndSplit(0)), JUST(BoxingExpr("ccl-p-to-s")), JUST(OptionalBoxing("ccl-s-to-s")))); } Maybe SymmetricCyclicNDimToNDimBoxingExpr() { return JUST(BoxingExpr(JUST(InPlacementAndRepeatFirstSbp()), JUST(BoxingExpr("symmetric-acyclic-nd-sbp-to-nd-sbp")), JUST(BoxingExpr("symmetric-acyclic-nd-sbp-to-nd-sbp")))) | JUST(BoxingExpr(JUST(InPlacementAndBroadcast()), JUST(BoxingExpr("symmetric-acyclic-nd-sbp-to-nd-sbp")), JUST(BoxingExpr("symmetric-acyclic-nd-sbp-to-nd-sbp")))); } Maybe SymmetricNDimToNDimBoxingExpr() { return JUST(BoxingExpr("symmetric-acyclic-nd-sbp-to-nd-sbp")) | JUST(SymmetricCyclicNDimToNDimBoxingExpr()); } Maybe SymmetricOneDimToNDimBoxingExpr() { return JUST(BoxingExpr(JUST(UnflattenInHierarchy()), JUST(BoxingExpr("unflatten-hierarchy")), JUST(SymmetricNDimToNDimBoxingExpr()) | JUST(BoxingExpr("identity")))); } Maybe SymmetricNDimToOneDimBoxingExpr() { return JUST(BoxingExpr(JUST(UnflattenOutHierarchy()), JUST(SymmetricNDimToNDimBoxingExpr()) | JUST(BoxingExpr("identity")), JUST(BoxingExpr("flatten-hierarchy")))); } Maybe NToOneBoxingExpr() { return JUST(BoxingExpr(JUST(InPlacementAndBroadcast()), JUST(BoxingExpr("identity")) | JUST(BoxingExpr("ccl-p-to-b")) | JUST(SymmetricOneDimSxToBBoxingExpr()) | JUST(BoxingExpr("naive-p-to-b")) | JUST(BoxingExpr("naive-s-to-b")) | JUST(SymmetricNDimToNDimBoxingExpr()) | JUST(BoxingExpr("generic-symmetric-nd-sbp-to-nd-sbp")), JUST(BoxingExpr("naive-b-to-1")))); } Maybe OneToNBoxingExpr() { return JUST(BoxingExpr(JUST(OutPlacementAndPartialSum()), JUST(BoxingExpr("naive-1-to-p")), JUST(BoxingExpr("identity")) | JUST(BoxingExpr("ccl-p-to-b")) | JUST(SymmetricOneDimPToSxBoxingExpr()) | JUST(BoxingExpr("naive-p-to-b")) | JUST(BoxingExpr("naive-p-to-s")) | JUST(SymmetricNDimToNDimBoxingExpr()) | JUST(BoxingExpr("generic-symmetric-nd-sbp-to-nd-sbp")))); } Maybe SymmetricOneDimXToBBoxingExpr() { return JUST(BoxingExpr("ccl-p-to-b")) | JUST(BoxingExpr(JUST(InPlacementAndSplit(0)), JUST(BoxingExpr("identity")) | JUST(BoxingExpr("ccl-s-to-s")), JUST(BoxingExpr("ccl-s-to-b")))); } Maybe ASymmetricOneDimXToBBoxingExpr() { return JUST(BoxingExpr(JUST(InPlacementAndBroadcast()), JUST(BoxingExpr("identity")) | JUST(SymmetricOneDimXToBBoxingExpr()), JUST(BoxingExpr("asymmetric-broadcast")))); } Maybe 
GenericBoxingExpr() { // in_placement contain out_placement or out_placement contain in_placement const auto& boxing_expr_with_inclusive_placement = JUST(BoxingExpr(JUST(OutPlacementAndBroadcast()), JUST(ASymmetricOneDimXToBBoxingExpr()), JUST(BoxingExpr("identity")) | JUST(BoxingExpr("symmetric-b-to-p")) | JUST(BoxingExpr("symmetric-b-to-s")))); // in_placement and out_placement have no containment relationship // n to 1 const auto& lhs_boxing = JUST(NToOneBoxingExpr()); // 1 to 1 -> 1 to n const auto& rhs_boxing = JUST(BoxingExpr(JUST(OutFirstDeviceAndAllBroadcast()), JUST(OptionalBoxing("naive-1-to-1")), JUST(OneToNBoxingExpr()))); return boxing_expr_with_inclusive_placement | JUST(BoxingExpr(JUST(InFirstDeviceAndAllBroadcast()), lhs_boxing, rhs_boxing)); } Maybe RawMainBoxingExpr() { // clang-format off const auto& core = JUST(BoxingExpr("identity")) | JUST(BoxingExpr("copy-h2d")) | JUST(BoxingExpr("copy-d2h")) | JUST(BoxingExpr("ccl-p-to-b")) | JUST(BoxingExpr("ccl-s-to-s")) | JUST(SymmetricOneDimSxToBBoxingExpr()) | JUST(SymmetricOneDimPToSxBoxingExpr()) | JUST(BoxingExpr("symmetric-b-to-p")) | JUST(BoxingExpr("symmetric-b-to-s")) | JUST(BoxingExpr("symmetric-s-to-p")) | JUST(SymmetricOneDimXToBBoxingExpr()) | JUST(ASymmetricOneDimXToBBoxingExpr()) | JUST(BoxingExpr("naive-1-to-1")) | JUST(OneToNBoxingExpr()) | JUST(NToOneBoxingExpr()) | JUST(BoxingExpr("naive-s-to-s")) | JUST(BoxingExpr("naive-s-to-b")) | JUST(BoxingExpr("naive-b-to-s")) | JUST(BoxingExpr("naive-p-to-b")) | JUST(BoxingExpr("naive-p-to-s")) | JUST(BoxingExpr("naive-s-to-p")) | JUST(BoxingExpr("nd-sbp-dim-reduce")) | JUST(SymmetricNDimToNDimBoxingExpr()) | JUST(BoxingExpr("generic-symmetric-nd-sbp-to-nd-sbp")) | JUST(SymmetricOneDimToNDimBoxingExpr()) | JUST(SymmetricNDimToOneDimBoxingExpr()) | JUST(GenericBoxingExpr()); // clang-format on return core | JUST(OptionalCudaCopy(core)) | JUST(OptionalCpuCopy(core)); } } // namespace static constexpr auto* MainBoxingExpr = DECORATE(&RawMainBoxingExpr, ThreadLocalCached); Maybe GetBoxingInterpreter(Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc, const Shape& logical_shape) { const auto& in = JUST(PlacedNdSbp::New(in_nd_sbp, in_parallel_desc)); const auto& out = JUST(PlacedNdSbp::New(out_nd_sbp, out_parallel_desc)); const auto& main_boxing_expr = JUST(MainBoxingExpr()); const auto& status = TRY(main_boxing_expr->Check(in, out, logical_shape)); if (status.IsOk()) { const auto& boxing_func = JUST(main_boxing_expr->GetBoxingFunction(in, out, logical_shape)); return std::shared_ptr( new NaiveEagerBoxingInterpreter(boxing_func, JUST(status))); } UNIMPLEMENTED_THEN_RETURN() << Error::RuntimeError() << "global-to-global not supported" << ". 
from_nd_sbp: " << NdSbpToString(in_nd_sbp) << ", to_nd_sbp: " << NdSbpToString(out_nd_sbp) << ", from_placement: " << *JUST(PlacementToString(in_parallel_desc)) << ", to_placement: " << *JUST(PlacementToString(out_parallel_desc)); } static constexpr auto* CachedGetBoxingInterpreter = DECORATE(&GetBoxingInterpreter, ThreadLocalCachedCopiable); } // namespace Maybe EagerBoxingInterpreterManager::GetEagerBoxingInterpreter( Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc, const Shape& logical_shape) const { return JUST(CachedGetBoxingInterpreter(in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, logical_shape)); } COMMAND( Singleton::SetAllocated(new EagerBoxingInterpreterManager())); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/eager_boxing_interpreter_mgr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_MGR_H_ #define ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_MGR_H_ #include "oneflow/core/boxing/eager_boxing_interpreter.h" namespace oneflow { class EagerBoxingInterpreterManager final { public: OF_DISALLOW_COPY_AND_MOVE(EagerBoxingInterpreterManager); EagerBoxingInterpreterManager() = default; virtual ~EagerBoxingInterpreterManager() = default; Maybe GetEagerBoxingInterpreter(Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc, const Shape& logical_shape) const; }; template struct DisableRecusiveBoxingCall { static_assert(is_maybe::value, "returned value type must be Maybe."); template static RetT Call(Args... arg) { static thread_local bool disable_boxing = false; CHECK_OR_RETURN(!disable_boxing); disable_boxing = true; RetT ret = func(arg...); disable_boxing = false; return ret; } }; } // namespace oneflow #endif // ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_MGR_H_ ================================================ FILE: oneflow/core/boxing/eager_boxing_logger.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/boxing/eager_boxing_logger.h" #include "oneflow/core/boxing/boxing_interpreter_status.h" namespace oneflow { namespace { class NullEagerBoxingLogger final : public EagerBoxingLogger { public: OF_DISALLOW_COPY_AND_MOVE(NullEagerBoxingLogger); NullEagerBoxingLogger() = default; ~NullEagerBoxingLogger() override = default; void Log(const BoxingInterpreterStatus& status, const std::string& prefix) const override {} }; class NaiveEagerBoxingLogger final : public EagerBoxingLogger { public: OF_DISALLOW_COPY_AND_MOVE(NaiveEagerBoxingLogger); NaiveEagerBoxingLogger() = default; ~NaiveEagerBoxingLogger() override = default; void Log(const BoxingInterpreterStatus& status, const std::string& prefix) const override { LOG(INFO) << prefix << "Boxing route: " << (status.boxing_routing()); LOG(INFO) << prefix << "Logical shape: " << (status.logical_shape().ToString()); LOG(INFO) << prefix << "Altered state of sbp: " << (status.nd_sbp_routing()); LOG(INFO) << prefix << "Altered state of placement: " << (status.placement_routing()); } }; const EagerBoxingLogger* CreateEagerBoxingLogger() { if (IsInDebugMode()) { return new NaiveEagerBoxingLogger(); } else { return new NullEagerBoxingLogger(); } } } // namespace COMMAND(Singleton::SetAllocated(CreateEagerBoxingLogger())); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/eager_boxing_logger.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_BOXING_EAGER_BOXING_LOGGER_H_ #define ONEFLOW_CORE_BOXING_EAGER_BOXING_LOGGER_H_ #include "oneflow/core/common/util.h" namespace oneflow { class BoxingInterpreterStatus; class EagerBoxingLogger { public: OF_DISALLOW_COPY_AND_MOVE(EagerBoxingLogger); EagerBoxingLogger() = default; virtual ~EagerBoxingLogger() = default; virtual void Log(const BoxingInterpreterStatus& status, const std::string& prefix) const = 0; }; } // namespace oneflow #endif // ONEFLOW_CORE_BOXING_EAGER_BOXING_LOGGER_H_ ================================================ FILE: oneflow/core/boxing/flatten_hierarchy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckFlattenHierarchy(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_GT_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); for (int i = 0; i < in->nd_sbp()->sbp_parallel_size(); ++i) { const auto& sbp_parallel = in->nd_sbp()->sbp_parallel(i); CHECK_OR_RETURN(sbp_parallel == out->nd_sbp()->sbp_parallel(0)) << "nd_sbp axis: " << i; } CHECK_EQ_OR_RETURN(in->placement()->device_type(), out->placement()->device_type()); CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), out->placement()->parallel_num()); ParallelConf flattened_parallel_conf(in->placement()->parallel_conf()); flattened_parallel_conf.clear_hierarchy(); const auto& flatten_placement = SymbolOf(ParallelDesc(flattened_parallel_conf)); CHECK_OR_RETURN(flatten_placement == out->placement()) << "The output placement is not a hierarch-flattened version of the input placement"; for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num(); ++in_parallel_id) { const auto& in_physical_shape = JUST(GetPhysicalShape(logical_shape, *in->nd_sbp(), *in->placement(), in_parallel_id)); const auto& out_physical_shape = JUST(GetPhysicalShape(logical_shape, *out->nd_sbp(), *out->placement(), in_parallel_id)); CHECK_EQ_OR_RETURN(*in_physical_shape, *out_physical_shape); } return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) } // namespace static constexpr auto* CheckFlattenHierarchy = DECORATE(&RawCheckFlattenHierarchy, ThreadLocalCachedCopiable); Maybe FlattenHierarchy(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor()); const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/true)); } COMMAND(RegisterBoxingFunction("flatten-hierarchy", CheckFlattenHierarchy, &FlattenHierarchy)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/generic_symmetric_nd_sbp_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/boxing/eager_boxing_logger.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/stride.h" namespace oneflow { namespace { bool RawIsAllBroadcastNdSbpAfterDim(Symbol nd_sbp, int dim) { for (int i = dim; i < nd_sbp->sbp_parallel_size(); ++i) { if (!nd_sbp->sbp_parallel(i).has_broadcast_parallel()) { return false; } } return true; } static constexpr auto* IsAllBroadcastNdSbpAfterDim = DECORATE(&RawIsAllBroadcastNdSbpAfterDim, ThreadLocalCached); Maybe> GetBroadcastSbp() { SbpParallel broadcast_sbp; broadcast_sbp.mutable_broadcast_parallel(); return SymbolOf(broadcast_sbp); } auto* CachedGetBroadcastSbp = DECORATE(&GetBroadcastSbp, ThreadLocalCached); // NOLINTBEGIN(maybe-need-error-msg) Maybe CalcLogicalShape4Axis(const Shape& logical_shape, int axis, Symbol parallel_desc, Symbol nd_sbp) { CHECK_LT_OR_RETURN(axis, nd_sbp->sbp_parallel_size()); // Always true std::shared_ptr sub_logical_shape = std::make_shared(logical_shape); const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc)); int64_t parallel_id = JUST(*opt_parallel_id); const auto& hierarchy_shape = *parallel_desc->hierarchy(); Stride hierarchy_stride(hierarchy_shape); FOR_RANGE(int64_t, i, 0, axis) { const auto& sbp_parallel = nd_sbp->sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { int64_t index = CalcIndex4Axis(parallel_id, hierarchy_stride, i); int64_t dim = hierarchy_shape.At(i); const int64_t split_axis = sbp_parallel.split_parallel().axis(); if (sub_logical_shape->At(split_axis) > 0) { CHECK_GE_OR_RETURN(sub_logical_shape->At(split_axis), dim) << Error::RuntimeError() << "The size of tensor (" << sub_logical_shape->At(split_axis) << ") at split dimension (" << i << ") should be greater than or equal to parallle num (" << dim << ")"; const BalancedSplitter bs(sub_logical_shape->At(split_axis), dim); sub_logical_shape->Set(split_axis, bs.At(index).size()); } } } return sub_logical_shape; } static constexpr auto* GetLogicalShape4Axis = DECORATE(&CalcLogicalShape4Axis, ThreadLocalCachedCopiable); Maybe CalcTheFirstDiffAxisBetweenTwoNdSbp(Symbol in_nd_sbp, Symbol out_nd_sbp) { CHECK_EQ_OR_RETURN(in_nd_sbp->sbp_parallel_size(), out_nd_sbp->sbp_parallel_size()); // Always true int dim = 0; for (; dim < in_nd_sbp->sbp_parallel_size(); ++dim) { if (in_nd_sbp->sbp_parallel(dim) != out_nd_sbp->sbp_parallel(dim)) { break; } } return dim; } Maybe Apply1DBoxing(const std::shared_ptr& input, Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc) { const auto& boxing_interpreter = JUST(Singleton::Get()->GetEagerBoxingInterpreter( in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *input->shape())); Singleton::Get()->Log( *JUST(boxing_interpreter->boxing_interpreter_status()), /* prefix */ "\t\tInternal boxing of generic-symmetric-nd-sbp-to-nd-sbp, "); return JUST(boxing_interpreter->Interpret(input, in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc)); } Maybe RawCheckGenericSymmetricNdSbpBoxing(Symbol in, Symbol out, const Shape& logical_shape) { 
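  // (Added explanatory comments; not part of the original source.)
  // This checker admits the generic symmetric n-d boxing only when the input and
  // output share the same placement, the nd_sbp actually changes, both nd_sbp
  // have the same number of axes, and that number is greater than 1.
  // GenericSymmetricNdSbpBoxing below then rewrites the sbp axis by axis: the
  // differing axes are first relaxed to broadcast and afterwards re-split to the
  // target sbp, each per-axis conversion being delegated to a 1-D boxing via
  // Apply1DBoxing.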
CHECK_OR_RETURN(in->placement() == out->placement()); CHECK_OR_RETURN(in->nd_sbp() != out->nd_sbp()); CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), out->nd_sbp()->sbp_parallel_size()); CHECK_GT_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckGenericSymmetricNdSbpBoxing = DECORATE(&RawCheckGenericSymmetricNdSbpBoxing, ThreadLocalCachedCopiable); } // namespace Maybe GenericSymmetricNdSbpBoxing(const std::shared_ptr& input, Symbol in, Symbol out) { const auto& in_parallel_desc = in->placement(); const auto& out_nd_sbp = out->nd_sbp(); const auto& out_parallel_desc = out->placement(); std::shared_ptr output; const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc)); if (out_parallel_id->has_value()) { output = input; int first_diff_sbp_dim = JUST(CalcTheFirstDiffAxisBetweenTwoNdSbp(in->nd_sbp(), out_nd_sbp)); Symbol broadcast_sbp = JUST(CachedGetBroadcastSbp()); const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(in_parallel_desc)); int64_t parallel_id = JUST(*opt_parallel_id); const auto& hierarchy_shape = *in_parallel_desc->hierarchy(); Stride hierarchy_stride(hierarchy_shape); const auto& logical_shape = input->shape(); // Convert input to broadcast tensor step by step // e.g. // If in_nd_sbp is (S(0), B, S(0)), (S(0), S(0), S(1)) // Altered state of sbp is (S(0), B, S(0)) -> (S(0), B, B) for (int64_t i = out_nd_sbp->sbp_parallel_size() - 1; i >= first_diff_sbp_dim; --i) { const auto& nd_sbp = JUST(output->nd_sbp()); const auto& sbp_parallel = nd_sbp->sbp_parallel(i); if (sbp_parallel.has_broadcast_parallel()) { continue; } const auto& one_dim_nd_sbp = JUST(SbpToNdSbp(sbp_parallel)); const auto& sub_logical_shape = *JUST(GetLogicalShape4Axis(*logical_shape, i, in_parallel_desc, nd_sbp)); std::shared_ptr local_tensor = JUST(output->cur_rank_phy_tensor()); const auto& sub_parallel_desc = JUST(CalcSubParallelDesc4Axis(in_parallel_desc, i)); int64_t index = CalcIndex4Axis(parallel_id, hierarchy_stride, i); const auto& physical_shape = JUST(GetPhysicalShape(sub_logical_shape, *one_dim_nd_sbp, *sub_parallel_desc, index)); CHECK_EQ_OR_RETURN(*physical_shape, *local_tensor->shape()) << Error::RuntimeError() << "Invalid input tensor, size of local tensor (" << local_tensor->shape()->ToString() << ") does not match global tensor (" << logical_shape->ToString() << ")!"; std::shared_ptr sub_global_tensor = JUST(one::functional::LocalToGlobal( local_tensor, sub_parallel_desc, *JUST(GetSbpList(one_dim_nd_sbp)), sub_logical_shape, local_tensor->dtype(), /* sync_data */ false, /*copy=*/false)); sub_global_tensor = JUST(Apply1DBoxing(sub_global_tensor, one_dim_nd_sbp, JUST(SbpToNdSbp(broadcast_sbp)), sub_parallel_desc, sub_parallel_desc)); local_tensor = JUST(sub_global_tensor->cur_rank_phy_tensor()); const auto& new_nd_sbp = JUST(SetSbpAtAxis(*nd_sbp, *broadcast_sbp, i)); output = JUST(one::functional::LocalToGlobal( local_tensor, in_parallel_desc, *JUST(GetSbpList(new_nd_sbp)), *logical_shape, local_tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } CHECK_OR_RETURN(IsAllBroadcastNdSbpAfterDim(JUST(output->nd_sbp()), first_diff_sbp_dim)) << Error::RuntimeError() << "Compute generic-symmetric-nd-sbp-to-nd-sbp failed. Please submit an issue in " "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as " "possible"; // Convert broadcast tensor to output with out_nd_sbp data step by step // e.g. 
// If out_nd_sbp is (S(0), S(0), S(1)) // Altered state of sbp is (S(0), B, B) -> (S(0), S(0), B) -> (S(0), S(0), S(1)) std::shared_ptr sub_logical_shape = JUST(GetLogicalShape4Axis( *logical_shape, first_diff_sbp_dim, in_parallel_desc, JUST(output->nd_sbp()))); for (int64_t i = first_diff_sbp_dim; i < out_nd_sbp->sbp_parallel_size(); ++i) { const auto& sbp_parallel = out_nd_sbp->sbp_parallel(i); if (sbp_parallel.has_broadcast_parallel()) { continue; } const auto& nd_sbp = JUST(output->nd_sbp()); const auto& sub_parallel_desc = JUST(CalcSubParallelDesc4Axis(in_parallel_desc, i)); std::shared_ptr local_tensor = JUST(output->cur_rank_phy_tensor()); std::shared_ptr sub_global_tensor = JUST(one::functional::LocalToGlobal( local_tensor, sub_parallel_desc, *JUST(GetSbpList(JUST(SbpToNdSbp(broadcast_sbp)))), *sub_logical_shape, local_tensor->dtype(), /* sync_data */ false, /*copy=*/false)); const auto& one_dim_nd_sbp = JUST(SbpToNdSbp(sbp_parallel)); sub_global_tensor = JUST(Apply1DBoxing(sub_global_tensor, JUST(SbpToNdSbp(broadcast_sbp)), one_dim_nd_sbp, sub_parallel_desc, sub_parallel_desc)); local_tensor = JUST(sub_global_tensor->cur_rank_phy_tensor()); int64_t index = CalcIndex4Axis(parallel_id, hierarchy_stride, i); const auto& physical_shape = JUST(GetPhysicalShape(*sub_logical_shape, *one_dim_nd_sbp, *sub_parallel_desc, index)); CHECK_EQ_OR_RETURN(*physical_shape, *local_tensor->shape()) << Error::RuntimeError() << "Compute generic-symmetric-nd-sbp-to-nd-sbp failed. Please submit an issue in " "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as " "possible"; const auto& new_nd_sbp = JUST(SetSbpAtAxis(*nd_sbp, sbp_parallel, i)); output = JUST(one::functional::LocalToGlobal( local_tensor, in_parallel_desc, *JUST(GetSbpList(new_nd_sbp)), *logical_shape, local_tensor->dtype(), /* sync_data */ false, /*copy=*/false)); // physical_shape of this axis is logical shape of next axis sub_logical_shape = physical_shape; } } else { one::GlobalTensorMeta tensor_meta(*input->shape(), input->dtype()->data_type(), input->memory_format(), out_nd_sbp, out_parallel_desc); const auto& tensor_impl = JUST(one::EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false)); output = std::make_shared(tensor_impl); } return output; } COMMAND(RegisterBoxingFunction("generic-symmetric-nd-sbp-to-nd-sbp", CheckGenericSymmetricNdSbpBoxing, &GenericSymmetricNdSbpBoxing)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/identity_boxing_interpreter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
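// [Illustrative sketch, not part of the original sources] The generic symmetric boxer above
// repeatedly maps the flat parallel_id of the current rank to a per-axis index with
// CalcIndex4Axis(parallel_id, hierarchy_stride, axis), where the stride is built from the
// placement hierarchy. The helper below is a hypothetical re-implementation of that
// decomposition (assuming row-major hierarchy strides), shown only to make the index
// arithmetic concrete; it is not the OneFlow implementation.
#include <cstdint>
#include <vector>
int64_t IndexAtAxis(int64_t parallel_id, const std::vector<int64_t>& hierarchy, int axis) {
  int64_t stride = 1;
  for (size_t i = axis + 1; i < hierarchy.size(); ++i) { stride *= hierarchy[i]; }
  return (parallel_id / stride) % hierarchy[axis];
}
// Example: hierarchy (2, 4) has strides (4, 1); rank 6 sits at index 1 along axis 0 and at
// index 2 along axis 1, which is the coordinate used to pick its sub-placement and slice.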
*/ #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/nd_sbp.h" namespace oneflow { namespace { // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckIdentity(Symbol in, Symbol out, const Shape& logical_shape) { if (in->placement()->parallel_num() == 1) { CHECK_OR_RETURN(in->placement()->EqualsIgnoringHierarchy(*out->placement())); return Maybe::Ok(); } CHECK_OR_RETURN(in->placement() == out->placement()); CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp()); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) } // namespace Maybe GetIdentity(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; // reset sbp if parallel_num == 1 and reset transport_token const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor()); const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/true)); } COMMAND(RegisterBoxingFunction("identity", DECORATE(&RawCheckIdentity, ThreadLocalCachedCopiable), &GetIdentity)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/naive_1_to_p_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" namespace oneflow { namespace { bool NdSbpIsAllPartialSum(Symbol nd_sbp) { for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) { if (!sbp_parallel.has_partial_sum_parallel()) { return false; } } return true; } // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckNaive1ToP(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), 1); CHECK_OR_RETURN(NdSbpIsAllPartialSum(out->nd_sbp())); CHECK_OR_RETURN(out->placement()->Bigger(*in->placement())); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckNaive1ToP = DECORATE(&RawCheckNaive1ToP, ThreadLocalCachedCopiable); } // namespace Maybe Naive1ToP(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; int64_t root = JUST(tensor_placement->MachineId4ParallelId(0)); std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement())); if (root == GlobalProcessCtx::Rank() || !out_parallel_id->has_value()) { // do nothing } else { const std::string& device_type = tensor_placement->device_tag(); local_tensor = JUST(one::functional::Constant(*tensor->shape(), 0, tensor->dtype(), JUST(Device::New(device_type)))); } return JUST(one::functional::LocalToGlobal( local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/true)); } COMMAND(RegisterBoxingFunction("naive-1-to-p", CheckNaive1ToP, &Naive1ToP)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/naive_b_to_1_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" namespace oneflow { namespace { Maybe RawCheckNaiveBTo1(Symbol in, Symbol out, const Shape& logical_shape) { // NOLINTBEGIN(maybe-need-error-msg) CHECK_EQ_OR_RETURN(out->placement()->parallel_num(), 1); CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp())); CHECK_OR_RETURN(in->placement()->Bigger(*out->placement())); // NOLINTEND(maybe-need-error-msg) return Maybe::Ok(); } static constexpr auto* CheckNaiveBTo1 = DECORATE(&RawCheckNaiveBTo1, ThreadLocalCachedCopiable); } // namespace Maybe NaiveBTo1(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); return JUST(one::functional::LocalToGlobal( local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/true)); } COMMAND(RegisterBoxingFunction("naive-b-to-1", CheckNaiveBTo1, &NaiveBTo1)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/naive_b_to_s_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/boxing/slice_boxing_util.h" namespace oneflow { namespace { bool RawIsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached); bool RawIsBroadcastSbp(Symbol sbp_parallel) { return sbp_parallel->has_broadcast_parallel(); } static constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckNaiveBToS(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsBroadcastSbp(in->nd_sbp()->sbp_parallel(0))); CHECK_OR_RETURN(IsSplitSbp(out->nd_sbp()->sbp_parallel(0))); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckNaiveBToS = DECORATE(&RawCheckNaiveBToS, ThreadLocalCachedCopiable); } // namespace Maybe NaiveBToS(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); { const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement)); const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement())); if (in_parallel_id->has_value() || out_parallel_id->has_value()) { local_tensor = JUST(one::functional::EagerBToS( local_tensor, tensor_placement, out->placement(), *sbp_list, *tensor->shape())); } } return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } static constexpr auto* NaiveBToSWithAutoConvert = EAGER_SLICE_BOXING_WARPPER(&NaiveBToS, EagerSliceBoxingType::kNaiveBToS); COMMAND(RegisterBoxingFunction("naive-b-to-s", CheckNaiveBToS, NaiveBToSWithAutoConvert)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/naive_p_to_b_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/boxing/slice_boxing_util.h" namespace oneflow { namespace { bool RawIsPartialSumSbp(Symbol sbp_parallel) { return sbp_parallel->has_partial_sum_parallel(); } static constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached); bool RawIsBroadcastSbp(Symbol sbp_parallel) { return sbp_parallel->has_broadcast_parallel(); } static constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckNaivePToB(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsPartialSumSbp(in->nd_sbp()->sbp_parallel(0))); CHECK_OR_RETURN(IsBroadcastSbp(out->nd_sbp()->sbp_parallel(0))); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckNaivePToB = DECORATE(&RawCheckNaivePToB, ThreadLocalCachedCopiable); } // namespace Maybe NaivePToB(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); { const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement)); const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement())); if (in_parallel_id->has_value() || out_parallel_id->has_value()) { local_tensor = JUST(one::functional::EagerPToB(local_tensor, tensor_placement, out->placement(), *tensor->shape())); } } const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } static constexpr auto* NaivePToBWithAutoConvert = EAGER_SLICE_BOXING_WARPPER(&NaivePToB, EagerSliceBoxingType::kNaivePToB); COMMAND(RegisterBoxingFunction("naive-p-to-b", CheckNaivePToB, NaivePToBWithAutoConvert)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/naive_p_to_s_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/boxing/slice_boxing_util.h" namespace oneflow { namespace { bool RawIsPartialSumSbp(Symbol sbp_parallel) { return sbp_parallel->has_partial_sum_parallel(); } static constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached); bool RawIsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckNaivePToS(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsPartialSumSbp(in->nd_sbp()->sbp_parallel(0))); CHECK_OR_RETURN(IsSplitSbp(out->nd_sbp()->sbp_parallel(0))); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckNaivePToS = DECORATE(&RawCheckNaivePToS, ThreadLocalCachedCopiable); } // namespace Maybe NaivePToS(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); { const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement)); const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement())); if (in_parallel_id->has_value() || out_parallel_id->has_value()) { local_tensor = JUST(one::functional::EagerPToS( local_tensor, tensor_placement, out->placement(), *sbp_list, *tensor->shape())); } } return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ true, /*copy=*/false)); } static constexpr auto* NaivePToSWithAutoConvert = EAGER_SLICE_BOXING_WARPPER(&NaivePToS, EagerSliceBoxingType::kNaivePToS); COMMAND(RegisterBoxingFunction("naive-p-to-s", CheckNaivePToS, NaivePToSWithAutoConvert)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/naive_s_to_b_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/boxing/slice_boxing_util.h" namespace oneflow { namespace { bool RawIsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached); bool RawIsBroadcastSbp(Symbol sbp_parallel) { return sbp_parallel->has_broadcast_parallel(); } static constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckNaiveSToB(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0))); CHECK_OR_RETURN(IsBroadcastSbp(out->nd_sbp()->sbp_parallel(0))); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckNaiveSToB = DECORATE(&RawCheckNaiveSToB, ThreadLocalCachedCopiable); } // namespace Maybe NaiveSToB(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); { const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement)); const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement())); if (in_parallel_id->has_value() || out_parallel_id->has_value()) { local_tensor = JUST(one::functional::EagerSToB(local_tensor, tensor_placement, out->placement(), *JUST(GetSbpList(tensor_nd_sbp)), *tensor->shape())); } } const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } static constexpr auto* NaiveSToBWithAutoConvert = EAGER_SLICE_BOXING_WARPPER(&NaiveSToB, EagerSliceBoxingType::kNaiveSToB); COMMAND(RegisterBoxingFunction("naive-s-to-b", CheckNaiveSToB, NaiveSToBWithAutoConvert)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/naive_s_to_p_boxing.cpp ================================================ /* Copyright 2020 The 
OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/boxing/slice_boxing_util.h" namespace oneflow { namespace { bool RawIsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached); bool RawIsPartialSumSbp(Symbol sbp_parallel) { return sbp_parallel->has_partial_sum_parallel(); } static constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckNaiveSToP(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0))); CHECK_OR_RETURN(IsPartialSumSbp(out->nd_sbp()->sbp_parallel(0))); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckNaiveSToP = DECORATE(&RawCheckNaiveSToP, ThreadLocalCachedCopiable); } // namespace Maybe NaiveSToP(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); { const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement)); const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement())); if (in_parallel_id->has_value() || out_parallel_id->has_value()) { local_tensor = JUST(one::functional::EagerSToP(local_tensor, tensor_placement, out->placement(), *JUST(GetSbpList(tensor_nd_sbp)), *tensor->shape())); } } const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } static constexpr auto* NaiveSToPWithAutoConvert = EAGER_SLICE_BOXING_WARPPER(&NaiveSToP, EagerSliceBoxingType::kNaiveSToP); COMMAND(RegisterBoxingFunction("naive-s-to-p", CheckNaiveSToP, NaiveSToPWithAutoConvert)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/naive_s_to_s_boxing.cpp 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/boxing/slice_boxing_util.h" namespace oneflow { namespace { bool RawIsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckNaiveSToS(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0))); CHECK_OR_RETURN(IsSplitSbp(out->nd_sbp()->sbp_parallel(0))); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckNaiveSToS = DECORATE(&RawCheckNaiveSToS, ThreadLocalCachedCopiable); } // namespace Maybe NaiveSToS(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; const auto& in_sbp_list = JUST(GetSbpList(tensor_nd_sbp)); const auto& out_sbp_list = JUST(GetSbpList(out->nd_sbp())); std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); { const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement)); const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement())); if (in_parallel_id->has_value() || out_parallel_id->has_value()) { local_tensor = JUST(one::functional::EagerNaiveSToS(local_tensor, tensor_placement, out->placement(), *in_sbp_list, *out_sbp_list, *tensor->shape())); } } return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *out_sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } static constexpr auto* NaiveSToSWithAutoConvert = EAGER_SLICE_BOXING_WARPPER(&NaiveSToS, EagerSliceBoxingType::kNaiveSToS); COMMAND(RegisterBoxingFunction("naive-s-to-s", CheckNaiveSToS, NaiveSToSWithAutoConvert)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/nd_sbp_dim_reduce_boxing.cpp ================================================ /* Copyright 2020 The OneFlow 
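// [Summary note assembled from the files above, not part of the original sources] The six
// "naive" 1-D boxers share one structure: the check function requires 1-D sbp on both sides
// and matching device tags, the boxer pulls the current rank's physical tensor, calls the
// matching eager primitive only when this rank participates in either placement, and then
// re-wraps the result with LocalToGlobal:
//   naive-b-to-s -> one::functional::EagerBToS
//   naive-p-to-b -> one::functional::EagerPToB
//   naive-p-to-s -> one::functional::EagerPToS
//   naive-s-to-b -> one::functional::EagerSToB
//   naive-s-to-p -> one::functional::EagerSToP
//   naive-s-to-s -> one::functional::EagerNaiveSToS
// Each of them is registered through EAGER_SLICE_BOXING_WARPPER, which adds the device
// fallback implemented in slice_boxing_util.h further below.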
Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/boxing/eager_boxing_logger.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/sbp_infer_util.h" namespace oneflow { namespace { Maybe, Symbol>> RawInOutPlacedNdSbpDimReduce( Symbol in, Symbol out, const Shape& logical_shape) { // reduce hierarchy ParallelDesc reduced_in_placement = *in->placement(); ParallelDesc reduced_out_placement = *out->placement(); NdSbp reduced_in_nd_sbp; NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(*in->placement(), *out->placement(), *in->nd_sbp(), *out->nd_sbp(), &reduced_in_placement, &reduced_out_placement, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_shape); return std::make_tuple( JUST(PlacedNdSbp::New(SymbolOf(reduced_in_nd_sbp), SymbolOf(reduced_in_placement))), JUST(PlacedNdSbp::New(SymbolOf(reduced_out_nd_sbp), SymbolOf(reduced_out_placement)))); } constexpr auto* InOutPlacedNdSbpDimReduce = DECORATE(&RawInOutPlacedNdSbpDimReduce, ThreadLocalCachedCopiable); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckParallelDimReduce(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_OR_RETURN(in->nd_sbp()->sbp_parallel_size() > 1 || out->nd_sbp()->sbp_parallel_size() > 1); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); Symbol reduced_in; Symbol reduced_out; std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, logical_shape)); for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num(); ++in_parallel_id) { const auto& in_physical_shape = JUST(GetPhysicalShape(logical_shape, *in->nd_sbp(), *in->placement(), in_parallel_id)); const auto& reduce_in_physical_shape = JUST(GetPhysicalShape( logical_shape, *reduced_in->nd_sbp(), *reduced_in->placement(), in_parallel_id)); CHECK_EQ_OR_RETURN(*in_physical_shape, *reduce_in_physical_shape); } for (int64_t out_parallel_id = 0; out_parallel_id < out->placement()->parallel_num(); ++out_parallel_id) { const auto& out_physical_shape = JUST(GetPhysicalShape(logical_shape, *out->nd_sbp(), *out->placement(), out_parallel_id)); const auto& reduce_out_physical_shape = JUST(GetPhysicalShape( logical_shape, *reduced_out->nd_sbp(), *reduced_out->placement(), out_parallel_id)); CHECK_EQ_OR_RETURN(*out_physical_shape, *reduce_out_physical_shape); } if (reduced_in->nd_sbp()->sbp_parallel_size() == 1 && reduced_out->nd_sbp()->sbp_parallel_size() == 1) { return Maybe::Ok(); } if ((reduced_in->placement() != in->placement() || reduced_out->placement() != out->placement()) && reduced_in->placement() == reduced_out->placement()) { return Maybe::Ok(); } return Error::CheckFailedError(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckParallelDimReduce = DECORATE(&RawCheckParallelDimReduce, 
ThreadLocalCachedCopiable); } // namespace Maybe ParallelDimReduce(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; Symbol reduced_in; Symbol reduced_out; std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, *tensor->shape())); const std::shared_ptr& local_tensor = JUST(tensor->cur_rank_phy_tensor()); std::shared_ptr reduced_in_tensor = JUST(one::functional::LocalToGlobal( local_tensor, reduced_in->placement(), *JUST(GetSbpList(reduced_in->nd_sbp())), *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false)); const auto& boxing_interpreter = JUST(Singleton::Get()->GetEagerBoxingInterpreter( reduced_in->nd_sbp(), reduced_out->nd_sbp(), reduced_in->placement(), reduced_out->placement(), *tensor->shape())); Singleton::Get()->Log( *JUST(boxing_interpreter->boxing_interpreter_status()), /* prefix */ "\t\tInternal boxing of nd-sbp-dim-reduce, "); std::shared_ptr reduced_out_tensor = JUST( boxing_interpreter->Interpret(reduced_in_tensor, reduced_in->nd_sbp(), reduced_out->nd_sbp(), reduced_in->placement(), reduced_out->placement())); const std::shared_ptr& reduced_out_local_tensor = JUST(reduced_out_tensor->cur_rank_phy_tensor()); return JUST(one::functional::LocalToGlobal( reduced_out_local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } COMMAND(RegisterBoxingFunction("nd-sbp-dim-reduce", CheckParallelDimReduce, &ParallelDimReduce)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/one_to_one_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
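// [Illustrative note, not part of the original sources] "nd-sbp-dim-reduce" above first asks
// InOutParallelDimReduce to collapse hierarchy axes whose sbp can be merged. For example, a
// placement with hierarchy (2, 2) and nd_sbp (B, B) behaves exactly like a flat placement
// with hierarchy (4,) and sbp B, so the transform can be delegated to a simpler (often 1-D)
// boxing interpreter on the reduced placements; the result is then re-wrapped with the
// caller's original output placement and nd_sbp.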
*/ #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" #include "oneflow/user/kernels/communicate_util.h" namespace oneflow { namespace { // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckNaiveOneToOne(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), 1); CHECK_EQ_OR_RETURN(out->placement()->parallel_num(), 1); CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag()); CHECK_OR_RETURN(in->placement() != out->placement()); CHECK_OR_RETURN(IsSendAndRecvRegistered(in->placement()->device_type())); // NOLINT return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckNaiveOneToOne = DECORATE(&RawCheckNaiveOneToOne, ThreadLocalCachedCopiable); } // namespace Maybe NaiveOneToOne(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); int64_t src = JUST(tensor_placement->MachineId4ParallelId(0)); int64_t dst = JUST(out->placement()->MachineId4ParallelId(0)); bool copy = true; if (src != dst) { copy = false; if (GlobalProcessCtx::Rank() == src) { JUST(one::functional::Send(local_tensor, dst, /* send_meta */ false)); } if (GlobalProcessCtx::Rank() == dst) { local_tensor = JUST(one::functional::Recv(src, *tensor->shape(), tensor->dtype(), JUST(local_tensor->device()), NullOpt)); } } return JUST(one::functional::LocalToGlobal( local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/copy)); } COMMAND(RegisterBoxingFunction("naive-1-to-1", CheckNaiveOneToOne, &NaiveOneToOne)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/slice_boxing_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/boxing/slice_boxing_util.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/boxing/eager_boxing_logger.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/user/kernels/communicate_util.h" namespace oneflow { namespace private_details { Maybe PreprocessInputTensor4SliceBoxing(const std::shared_ptr& tensor, const std::string& log_prefix) { const auto& tensor_placement = JUST(tensor->parallel_desc()); if (IsSendAndRecvRegistered(tensor_placement->device_type())) { return tensor; } const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); Symbol new_placement = JUST(ReplaceDeviceType(tensor_placement, DeviceType::kCPU)); const auto& boxing_interpreter = JUST(Singleton::Get()->GetEagerBoxingInterpreter( tensor_nd_sbp, tensor_nd_sbp, tensor_placement, new_placement, *tensor->shape())); Singleton::Get()->Log( *JUST(boxing_interpreter->boxing_interpreter_status()), log_prefix); return JUST(boxing_interpreter->Interpret(tensor, tensor_nd_sbp, tensor_nd_sbp, tensor_placement, new_placement)); } Maybe PostprocessOutputTensor4SliceBoxing(const std::shared_ptr& tensor, Symbol placed_nd_sbp, const std::string& log_prefix) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_nd_sbp == placed_nd_sbp->nd_sbp()) << Error::RuntimeError() << "Compute slice boxing failed. Please submit an issue in " "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as " "possible"; CHECK_OR_RETURN(tensor_placement->EqualsIgnoringDeviceType(*placed_nd_sbp->placement())) << Error::RuntimeError() << "Compute slice boxing failed. Please submit an issue in " "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as " "possible"; if (JUST(tensor->parallel_desc()) == placed_nd_sbp->placement()) { return tensor; } const auto& boxing_interpreter = JUST(Singleton::Get()->GetEagerBoxingInterpreter( placed_nd_sbp->nd_sbp(), placed_nd_sbp->nd_sbp(), JUST(tensor->parallel_desc()), placed_nd_sbp->placement(), *tensor->shape())); Singleton::Get()->Log( *JUST(boxing_interpreter->boxing_interpreter_status()), log_prefix); return JUST(boxing_interpreter->Interpret(tensor, placed_nd_sbp->nd_sbp(), placed_nd_sbp->nd_sbp(), JUST(tensor->parallel_desc()), placed_nd_sbp->placement())); } const std::string& LogPrefix4EagerSliceBoxingType(EagerSliceBoxingType boxing_type) { static thread_local const HashMap boxing_type2log_prefix = { {EagerSliceBoxingType::kNaiveBToS, "\t\tInternal boxing of naive-b-to-s, "}, {EagerSliceBoxingType::kNaivePToB, "\t\tInternal boxing of naive-p-to-b, "}, {EagerSliceBoxingType::kNaivePToS, "\t\tInternal boxing of naive-p-to-s, "}, {EagerSliceBoxingType::kNaiveSToB, "\t\tInternal boxing of naive-s-to-b, "}, {EagerSliceBoxingType::kNaiveSToP, "\t\tInternal boxing of naive-s-to-p, "}, {EagerSliceBoxingType::kNaiveSToS, "\t\tInternal boxing of naive-s-to-s, "}}; return CHECK_JUST(MapAt(boxing_type2log_prefix, boxing_type)); } } // namespace private_details } // namespace oneflow ================================================ FILE: oneflow/core/boxing/slice_boxing_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_ #define ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_ #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { enum class EagerSliceBoxingType : unsigned int; namespace private_details { // Copy to cpu if device of input tensor is not cpu or cuda, otherwise return self Maybe PreprocessInputTensor4SliceBoxing(const std::shared_ptr& tensor, const std::string& log_prefix); // Copy to corresponding device if device of output tensor is not same with that of placed_nd_sbp, // otherwise return self Maybe PostprocessOutputTensor4SliceBoxing(const std::shared_ptr& tensor, Symbol placed_nd_sbp, const std::string& log_prefix); const std::string& LogPrefix4EagerSliceBoxingType(EagerSliceBoxingType boxing_type); } // namespace private_details enum class EagerSliceBoxingType : unsigned int { kNaiveBToS = 0, kNaivePToB = 1, kNaivePToS = 2, kNaiveSToB = 3, kNaiveSToP = 4, kNaiveSToS = 5 }; template struct EagerSliceBoxingAutoConvert { template (*func)(const std::shared_ptr&, Symbol, Symbol)> static Maybe Call(const std::shared_ptr& tensor, Symbol in, Symbol out) { std::shared_ptr processed_in_tensor = JUST(private_details::PreprocessInputTensor4SliceBoxing( tensor, private_details::LogPrefix4EagerSliceBoxingType(boxing_type))); const auto& new_in = JUST(PlacedNdSbp::New(in->nd_sbp(), JUST(processed_in_tensor->parallel_desc()))); Symbol new_out_placement = JUST(ReplaceDeviceType( out->placement(), JUST(processed_in_tensor->parallel_desc())->device_type())); const auto& new_out = JUST(PlacedNdSbp::New(out->nd_sbp(), new_out_placement)); std::shared_ptr out_tensor = JUST(func(processed_in_tensor, new_in, new_out)); return JUST(private_details::PostprocessOutputTensor4SliceBoxing( out_tensor, out, private_details::LogPrefix4EagerSliceBoxingType(boxing_type))); } }; #define EAGER_SLICE_BOXING_WARPPER(fn_ptr, boxing_type) \ (&EagerSliceBoxingAutoConvert::Call) } // namespace oneflow #endif // ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_ ================================================ FILE: oneflow/core/boxing/symmetric_acyclic_nd_sbp_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
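// [Usage sketch grounded in slice_boxing_util.h above, not part of the original sources]
// EAGER_SLICE_BOXING_WARPPER wraps a plain boxing function with a device fallback: the input
// is copied to CPU when its device has no send/recv kernels registered
// (PreprocessInputTensor4SliceBoxing), the wrapped boxer runs on the adjusted placements, and
// the output is copied back to the placement the caller asked for
// (PostprocessOutputTensor4SliceBoxing). A registration then looks like the ones above, e.g.
static constexpr auto* MySToSWithAutoConvert =
    EAGER_SLICE_BOXING_WARPPER(&NaiveSToS, EagerSliceBoxingType::kNaiveSToS);
COMMAND(RegisterBoxingFunction("naive-s-to-s", CheckNaiveSToS, MySToSWithAutoConvert));
// where MySToSWithAutoConvert is a hypothetical name; the real registration appears in
// naive_s_to_s_boxing.cpp above.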
*/ #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/boxing/eager_boxing_logger.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace { Maybe ReinterpterGlobalTensor(const std::shared_ptr& tensor, const Shape& shape, Symbol parallel_desc, Symbol nd_sbp) { const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc)); std::shared_ptr pyhsical_shape = JUST(GetPhysicalShape(shape, *nd_sbp, *parallel_desc, JUST(*parallel_id))); std::shared_ptr x = JUST(tensor->cur_rank_phy_tensor()); if (*x->shape() != *pyhsical_shape) { x = JUST(one::functional::Reshape(x, *pyhsical_shape)); } return JUST(one::functional::LocalToGlobal(x, parallel_desc, *JUST(GetSbpList(nd_sbp)), shape, tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } Maybe Apply1DBoxing(const std::shared_ptr& input, Symbol in_nd_sbp, Symbol out_nd_sbp, Symbol in_parallel_desc, Symbol out_parallel_desc) { const auto& boxing_interpreter = JUST(Singleton::Get()->GetEagerBoxingInterpreter( in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *input->shape())); Singleton::Get()->Log( *JUST(boxing_interpreter->boxing_interpreter_status()), /* prefix */ "\t\tInternal boxing of symmetric-acyclic-nd-sbp-to-nd-sbp, "); return JUST(boxing_interpreter->Interpret(input, in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc)); } // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckSymmetricAcyclicNdSbpBoxing(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_OR_RETURN(in->placement() == out->placement()); CHECK_OR_RETURN(in->nd_sbp() != out->nd_sbp()); CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), out->nd_sbp()->sbp_parallel_size()); CHECK_GT_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); JUST(CheckIsNdSbpBoxingAcyclicWithDecompose(in, out, logical_shape)); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckSymmetricAcyclicNdSbpBoxing = DECORATE(&RawCheckSymmetricAcyclicNdSbpBoxing, ThreadLocalCachedCopiable); } // namespace Maybe SymmetricAcyclicNdSbpBoxing(const std::shared_ptr& input, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(input->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(input->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; const auto& out_nd_sbp = out->nd_sbp(); const auto& out_parallel_desc = out->placement(); std::shared_ptr output; const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc)); if (out_parallel_id->has_value()) { const auto& tensor_meta = JUST(input->global_tensor_meta()); const auto& naive_transformations = JUST(DecomposeIntoNaiveTransformations(tensor_meta, out_nd_sbp)); std::shared_ptr tensor = input; for (const auto& 
naive_transformation : *naive_transformations) { const auto& sub_tensor_meta = naive_transformation.global_tensor_meta; tensor = JUST(ReinterpterGlobalTensor(tensor, sub_tensor_meta->shape(), sub_tensor_meta->parallel_desc(), sub_tensor_meta->nd_sbp())); tensor = JUST(Apply1DBoxing(tensor, sub_tensor_meta->nd_sbp(), naive_transformation.dst_nd_sbp, sub_tensor_meta->parallel_desc(), sub_tensor_meta->parallel_desc())); } output = JUST(ReinterpterGlobalTensor(tensor, *input->shape(), out_parallel_desc, out_nd_sbp)); } else { one::GlobalTensorMeta tensor_meta(*input->shape(), input->dtype()->data_type(), input->memory_format(), out_nd_sbp, out_parallel_desc); const auto& tensor_impl = JUST(one::EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false)); output = std::make_shared(tensor_impl); } return output; } COMMAND(RegisterBoxingFunction("symmetric-acyclic-nd-sbp-to-nd-sbp", CheckSymmetricAcyclicNdSbpBoxing, &SymmetricAcyclicNdSbpBoxing)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/symmetric_b_to_p_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace { Maybe RawCheckSymmetricBToP(Symbol in, Symbol out, const Shape& logical_shape) { // NOLINTBEGIN(maybe-need-error-msg) CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp())); CHECK_OR_RETURN(NdSbpIsAllPartialSum(*out->nd_sbp())); CHECK_OR_RETURN(in->placement() == out->placement()); // NOLINTEND(maybe-need-error-msg) return Maybe::Ok(); } static constexpr auto* CheckSymmetricBToP = DECORATE(&RawCheckSymmetricBToP, ThreadLocalCachedCopiable); } // namespace Maybe SymmetricBToP(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; int64_t root = JUST(tensor_placement->MachineId4ParallelId(0)); std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); if (root == GlobalProcessCtx::Rank()) { // do nothing } 
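// [High-level note on symmetric-acyclic-nd-sbp-to-nd-sbp above, not part of the original
// sources] DecomposeIntoNaiveTransformations splits one N-D sbp change into a chain of steps
// that each change the sbp along a single hierarchy axis: the local piece is first reshaped
// to that step's physical shape (ReinterpterGlobalTensor), then a 1-D boxing interpreter runs
// on the sub-placement of that axis. The acyclic check performed by
// CheckIsNdSbpBoxingAcyclicWithDecompose rejects (in, out) pairs whose per-axis changes
// cannot be ordered into such an independent chain.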
else { local_tensor = JUST(one::functional::ZerosLike(local_tensor)); } return JUST(one::functional::LocalToGlobal( local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/true)); } COMMAND(RegisterBoxingFunction("symmetric-b-to-p", CheckSymmetricBToP, &SymmetricBToP)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/symmetric_b_to_s_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/user_op_registry_manager.h" namespace oneflow { namespace { bool RawIsBroadcastSbp(Symbol sbp_parallel) { return sbp_parallel->has_broadcast_parallel(); } static constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached); bool RawIsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckSymmetricB2S(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsBroadcastSbp(SymbolOf(in->nd_sbp()->sbp_parallel(0)))); CHECK_OR_RETURN(IsSplitSbp(SymbolOf(out->nd_sbp()->sbp_parallel(0)))); CHECK_OR_RETURN(in->placement() == out->placement()); // NOLINT CHECK_OR_RETURN(in->placement()->device_type() != DeviceType::kInvalidDevice // NOLINT && in->placement()->device_type() != kMeta // NOLINT && in->placement()->device_type() != DeviceType::kMockDevice); // NOLINT return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckSymmetricB2S = DECORATE(&RawCheckSymmetricB2S, ThreadLocalCachedCopiable); } // namespace Maybe SymmetricB2S(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; const auto& local_shape = *tensor->shape(); std::shared_ptr local_tensor = JUST(tensor->cur_rank_phy_tensor()); const auto& parallel_id = 
JUST(GetParallelId4CurrentProcessCtx(tensor_placement)); if (parallel_id->has_value()) { const TensorSliceView& in_slice = GetTensorSliceView4ParallelId( *tensor_placement->hierarchy(), *tensor_nd_sbp, local_shape, JUST(*parallel_id)); CHECK(!in_slice.IsEmpty()); const TensorSliceView& out_slice = GetTensorSliceView4ParallelId( *tensor_placement->hierarchy(), *out->nd_sbp(), local_shape, JUST(*parallel_id)); CHECK(!out_slice.IsEmpty()); const TensorSliceView& intersection = out_slice.Intersect(in_slice); CHECK(!intersection.IsEmpty()); const std::vector& range_vec = intersection.range_vec(); std::vector start; std::vector stop; std::vector step(range_vec.size(), 1); for (const auto& range : range_vec) { start.emplace_back(range.begin()); stop.emplace_back(range.end()); } local_tensor = JUST(one::functional::Slice(local_tensor, start, stop, step, /*enable_view_slice=*/false)); } return JUST(one::functional::LocalToGlobal( local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false)); } COMMAND(RegisterBoxingFunction("symmetric-b-to-s", CheckSymmetricB2S, &SymmetricB2S)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/symmetric_s_to_p_boxing.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" namespace oneflow { namespace { bool RawIsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached); bool RawIsPartialSumSbp(Symbol sbp_parallel) { return sbp_parallel->has_partial_sum_parallel(); } static constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached); Maybe EagerSymmetricSToP(Symbol parallel_desc, Symbol src_sbp, const Shape& logical_shape) { return one::OpBuilder("eager_symmetric_s_to_p", *JUST(UniqueStr("eager_symmetric_s_to_p"))) .Input("in") .Output("out") .Attr("in_split_axis", src_sbp->split_parallel().axis()) .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Build(); } static constexpr auto* CachedEagerSymmetricSToPOpExpr = DECORATE(&EagerSymmetricSToP, ThreadLocalCachedCopiable); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckSymmetricSToP(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0))); CHECK_OR_RETURN(IsPartialSumSbp(out->nd_sbp()->sbp_parallel(0))); CHECK_OR_RETURN(in->placement() == out->placement()); return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) static constexpr auto* CheckSymmetricSToP = DECORATE(&RawCheckSymmetricSToP, ThreadLocalCachedCopiable); } // namespace Maybe SymmetricSToP(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; std::shared_ptr op_expr = JUST(CachedEagerSymmetricSToPOpExpr( tensor_placement, SymbolOf(tensor_nd_sbp->sbp_parallel(0)), *tensor->shape())); return JUST(one::OpInterpUtil::Dispatch(*op_expr, {tensor})); } COMMAND(RegisterBoxingFunction("symmetric-s-to-p", CheckSymmetricSToP, &SymmetricSToP)); } // namespace oneflow ================================================ FILE: oneflow/core/boxing/unflatten_hierarchy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/boxing/eager_boxing_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckUnflattenHierarchy(Symbol in, Symbol out, const Shape& logical_shape) { CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); CHECK_GT_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); for (int i = 0; i < out->nd_sbp()->sbp_parallel_size(); ++i) { const auto& sbp_parallel = out->nd_sbp()->sbp_parallel(i); CHECK_OR_RETURN(sbp_parallel == out->nd_sbp()->sbp_parallel(0)) << "nd_sbp axis: " << i; } CHECK_EQ_OR_RETURN(in->placement()->device_type(), out->placement()->device_type()); CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), out->placement()->parallel_num()); ParallelConf unflattened_parallel_conf(in->placement()->parallel_conf()); unflattened_parallel_conf.mutable_hierarchy()->CopyFrom( out->placement()->parallel_conf().hierarchy()); const auto& unflatten_placement = SymbolOf(ParallelDesc(unflattened_parallel_conf)); CHECK_OR_RETURN(unflatten_placement == out->placement()) << "The output placement is not a hierarch-unflattened version of the input placement"; for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num(); ++in_parallel_id) { const auto& in_physical_shape = JUST(GetPhysicalShape(logical_shape, *in->nd_sbp(), *in->placement(), in_parallel_id)); const auto& out_physical_shape = JUST(GetPhysicalShape(logical_shape, *out->nd_sbp(), *out->placement(), in_parallel_id)); CHECK_EQ_OR_RETURN(*in_physical_shape, *out_physical_shape); } return Maybe::Ok(); } // NOLINTEND(maybe-need-error-msg) } // namespace static constexpr auto* CheckUnflattenHierarchy = DECORATE(&RawCheckUnflattenHierarchy, ThreadLocalCachedCopiable); Maybe UnflattenHierarchy(const std::shared_ptr& tensor, Symbol in, Symbol out) { const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()) << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp) << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")"; const auto& tensor_placement = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(tensor_placement == in->placement()) << Error::RuntimeError() << "The placement of input tensor (" << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement (" << *JUST(PlacementToString(in->placement())) << ")"; const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor()); const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list, *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/true)); } COMMAND(RegisterBoxingFunction("unflatten-hierarchy", CheckUnflattenHierarchy, &UnflattenHierarchy)); } // namespace oneflow ================================================ FILE: oneflow/core/ccl/ccl.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ccl/ccl.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type_seq.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/common/constant.h" namespace oneflow { namespace ccl { namespace { Maybe InitBroadcastRankHeap(std::vector* ranks, const ParallelDesc& parallel_desc, int64_t root) { CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), parallel_desc.sorted_machine_ids().size()); ranks->resize(parallel_desc.parallel_num()); int64_t root_index = -1; for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { int64_t machine_id = JUST(parallel_desc.MachineId4ParallelId(parallel_id)); if (machine_id == root) { root_index = parallel_id; } (*ranks)[parallel_id] = machine_id; } CHECK_NE_OR_RETURN(root_index, -1); std::swap((*ranks)[0], (*ranks)[root_index]); return Maybe::Ok(); } } // namespace Maybe CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root, Symbol parallel_desc, const TransportToken& transport_token) { static thread_local std::vector rank_heap{}; JUST(InitBroadcastRankHeap(&rank_heap, *parallel_desc, root)); auto Send = [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = (root == GlobalProcessCtx::Rank() ? 
const_cast(in) : out); *size = buffer_size; *Cb = [] {}; return Maybe::Ok(); }; auto Recv = [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = out; *size = buffer_size; *Cb = [] {}; return Maybe::Ok(); }; { NaiveAsyncTransportCtx transport_ctx(transport_token, Send, Recv); JUST(TransportUtil::ReceiveDataFromParentInHeap(rank_heap, transport_token, &transport_ctx)); JUST_MSG(transport_ctx.WaitDone(), kAsymmetricCodeErrorMsg); } { NaiveAsyncTransportCtx transport_ctx(transport_token, Send, Recv); JUST(TransportUtil::SendDataToChildrenInHeap(rank_heap, transport_token, &transport_ctx)); if (GlobalProcessCtx::Rank() == root && out != in) { std::memcpy(out, in, buffer_size); } JUST_MSG(transport_ctx.WaitDone(), kAsymmetricCodeErrorMsg); } return Maybe::Ok(); } Maybe CpuSend(const void* in, size_t buffer_size, int64_t dst) { TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); NaiveAsyncTransportCtx transport_ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = const_cast(in); *size = buffer_size; *Cb = [] {}; return Maybe::Ok(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { UNIMPLEMENTED_THEN_RETURN(); }); JUST(TransportUtil::SendDataToRank(dst, transport_token, &transport_ctx)); JUST(transport_ctx.WaitDone()); return Maybe::Ok(); } Maybe CpuRecv(void* out, size_t buffer_size, int64_t src) { TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); NaiveAsyncTransportCtx transport_ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { UNIMPLEMENTED_THEN_RETURN(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = out; *size = buffer_size; *Cb = [] {}; return Maybe::Ok(); }); JUST(TransportUtil::ReceiveDataFromRank(src, transport_token, &transport_ctx)); JUST(transport_ctx.WaitDone()); return Maybe::Ok(); } } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/core/ccl/ccl.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CCL_CCL_H_ #define ONEFLOW_CORE_CCL_CCL_H_ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { class ParallelDesc; class TransportToken; // collective communication library namespace ccl { Maybe CpuSend(const void* in, size_t buffer_size, int64_t dst); Maybe CpuRecv(void* out, size_t buffer_size, int64_t src); Maybe CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root, Symbol parallel_desc, const TransportToken& transport_token); } // namespace ccl } // namespace oneflow #endif // ONEFLOW_CORE_CCL_CCL_H_ ================================================ FILE: oneflow/core/comm_network/comm_network.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/job/global_for.h" namespace oneflow { CommNet::~CommNet() { ready_cbs_.Close(); ready_cb_poller_.join(); } void* CommNet::NewActorReadId() { return new ActorReadContext; } void CommNet::DeleteActorReadId(void* actor_read_id) { auto actor_read_ctx = static_cast(actor_read_id); CHECK(actor_read_ctx->waiting_list.empty()); delete actor_read_ctx; } void CommNet::Read(void* actor_read_id, int64_t src_machine_id, void* src_token, void* dst_token) { auto actor_read_ctx = static_cast(actor_read_id); ReadContext* read_ctx = new ReadContext; read_ctx->actor_read_ctx = actor_read_ctx; auto do_read = [this, read_ctx, src_machine_id, src_token, dst_token]() { DoRead(read_ctx, src_machine_id, src_token, dst_token); }; AddWorkToStream(actor_read_id, do_read, true); } void CommNet::AddReadCallBack(void* actor_read_id, std::function callback) { AddWorkToStream(actor_read_id, callback, false); } void CommNet::ReadDone(void* read_id) { ReadContext* read_ctx = static_cast(read_id); ActorReadContext* actor_read_ctx = read_ctx->actor_read_ctx; CommNetItem item; std::unique_lock lck(actor_read_ctx->waiting_list_mtx); CHECK(!actor_read_ctx->waiting_list.empty()); CHECK(actor_read_ctx->waiting_list.front().callback == nullptr); actor_read_ctx->waiting_list.pop_front(); while (true) { if (actor_read_ctx->waiting_list.empty()) { break; } item = actor_read_ctx->waiting_list.front(); actor_read_ctx->waiting_list.pop_front(); CHECK(item.callback); ready_cbs_.Send(item.callback); if (item.is_read) { break; } } delete read_ctx; } void CommNet::AddWorkToStream(void* actor_read_id, const std::function& cb, bool is_read) { auto actor_read_ctx = static_cast(actor_read_id); std::unique_lock lck(actor_read_ctx->waiting_list_mtx); if (actor_read_ctx->waiting_list.empty()) { ready_cbs_.Send(cb); } else { CommNetItem work_item(is_read, cb); 
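// A read is already in flight on this actor stream, so queue the callback;
// ReadDone() drains the waiting list and forwards it to ready_cbs_ once the
// preceding read completes.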
actor_read_ctx->waiting_list.emplace_back(work_item); } if (is_read) { CommNetItem empty_cb; actor_read_ctx->waiting_list.emplace_back(empty_cb); } } CommNet::CommNet() { int64_t this_machine_id = GlobalProcessCtx::Rank(); for (int64_t i : Singleton::Get()->process_ranks()) { if (i == this_machine_id) { continue; } peer_machine_id_.insert(i); } ready_cb_poller_ = std::thread([this]() { std::function cb; while (ready_cbs_.Receive(&cb) == kChannelStatusSuccess) { cb(); } }); } } // namespace oneflow ================================================ FILE: oneflow/core/comm_network/comm_network.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_ #define ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_ #ifndef DEPRECATED #define DEPRECATED __attribute__((deprecated)) #endif #include "oneflow/core/lazy/actor/actor_message.h" #include "oneflow/core/common/platform.h" #include "oneflow/core/common/channel.h" namespace oneflow { struct CommNetItem { bool is_read; std::function callback; CommNetItem() : CommNetItem(false, nullptr) {} CommNetItem(bool read, const std::function& cb) : is_read(read), callback(cb) {} }; class CommNet { public: OF_DISALLOW_COPY_AND_MOVE(CommNet); virtual ~CommNet(); // "RegisterMemory" will return a Token, after "RegisterMemoryDone", // we can use this token to use the "Read" virtual void* RegisterMemory(void* ptr, size_t byte_size) = 0; virtual void UnRegisterMemory(void* token) = 0; // Stream void* NewActorReadId(); void DeleteActorReadId(void* actor_read_id); void Read(void* actor_read_id, int64_t src_machine_id, void* src_token, void* dst_token); void AddReadCallBack(void* actor_read_id, std::function callback); void ReadDone(void* read_id); virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) = 0; protected: CommNet(); virtual void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) = 0; const HashSet& peer_machine_id() { return peer_machine_id_; } Channel> ready_cbs_; private: friend class Singleton; void AddWorkToStream(void* actor_read_id, const std::function& cb, bool is_read); struct ActorReadContext; struct ReadContext { ActorReadContext* actor_read_ctx; }; struct ActorReadContext { std::mutex waiting_list_mtx; std::list waiting_list; }; HashSet peer_machine_id_; std::thread ready_cb_poller_; }; template class CommNetIf : public CommNet { public: OF_DISALLOW_COPY_AND_MOVE(CommNetIf); CommNetIf() : CommNet() {} virtual ~CommNetIf() {} void* RegisterMemory(void* ptr, size_t byte_size) override { std::unique_lock lck(mem_descs_mtx_); MemDescType* mem_desc = NewMemDesc(ptr, byte_size); CHECK(mem_descs_.insert(mem_desc).second); return mem_desc; } void UnRegisterMemory(void* token) override { std::unique_lock lck(mem_descs_mtx_); MemDescType* mem_desc = static_cast(token); delete mem_desc; CHECK_EQ(mem_descs_.erase(mem_desc), 1); } protected: virtual MemDescType* NewMemDesc(void* ptr, size_t byte_size) = 0; 
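// Memory descriptors registered with this CommNet instance; access is guarded
// by mem_descs_mtx_ in RegisterMemory/UnRegisterMemory.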
const HashSet& mem_descs() { return mem_descs_; } private: std::mutex mem_descs_mtx_; HashSet mem_descs_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_ ================================================ FILE: oneflow/core/comm_network/epoll/epoll_comm_network.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef __linux__ #include "oneflow/core/comm_network/epoll/epoll_comm_network.h" #include "glog/logging.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/job/global_for.h" #include namespace oneflow { namespace { static const int32_t kInvlidPort = 0; sockaddr_in GetSockAddr(const std::string& addr, uint16_t port) { sockaddr_in sa; sa.sin_family = AF_INET; sa.sin_port = htons(port); PCHECK(inet_pton(AF_INET, addr.c_str(), &(sa.sin_addr)) == 1) << "addr: " << addr << ", port: " << port; return sa; } int SockListen(int listen_sockfd, int32_t* listen_port, int32_t total_machine_num) { // System designated available port if listen_port == kInvlidPort, otherwise, the configured port // is used. 
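// Bind to 0.0.0.0 with SO_REUSEADDR; when listen_port is 0 the kernel picks a
// free port, which is then read back with getsockname below.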
sockaddr_in sa = GetSockAddr("0.0.0.0", *listen_port); int reuse = 1; int ret_setopt = setsockopt(listen_sockfd, SOL_SOCKET, SO_REUSEADDR, (const void*)&reuse, sizeof(int)); CHECK_EQ(ret_setopt, 0); int bind_result = bind(listen_sockfd, reinterpret_cast(&sa), sizeof(sa)); { sockaddr_in bound_sock; socklen_t bound_sock_size = sizeof(bound_sock); getsockname(listen_sockfd, reinterpret_cast(&bound_sock), &bound_sock_size); if (*listen_port != kInvlidPort) { CHECK_EQ(*listen_port, static_cast(ntohs(bound_sock.sin_port))); } else { *listen_port = static_cast(ntohs(bound_sock.sin_port)); } } if (bind_result == 0) { PCHECK(listen(listen_sockfd, total_machine_num) == 0); LOG(INFO) << "CommNet:Epoll listening on " << "0.0.0.0:" + std::to_string(*listen_port); } else { PCHECK(errno == EACCES || errno == EADDRINUSE) << "SockListen errno: " << errno; } return bind_result; } std::string GenPortKey(int64_t machine_id) { return "EpollPort/" + std::to_string(machine_id); } void PushPort(int64_t machine_id, uint16_t port) { Singleton::Get()->PushKV(GenPortKey(machine_id), std::to_string(port)); } void ClearPort(int64_t machine_id) { Singleton::Get()->ClearKV(GenPortKey(machine_id)); } uint16_t PullPort(int64_t machine_id) { uint16_t port = 0; Singleton::Get()->PullKV( GenPortKey(machine_id), [&](const std::string& v) { port = oneflow_cast(v); }); return port; } } // namespace EpollCommNet::~EpollCommNet() { for (size_t i = 0; i < pollers_.size(); ++i) { VLOG(1) << "CommNet Thread " << i << " finish"; pollers_[i]->Stop(); } OF_ENV_BARRIER(); for (IOEventPoller* poller : pollers_) { delete poller; } for (auto& pair : sockfd2helper_) { delete pair.second; } } void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) { SocketMsg msg; msg.msg_type = SocketMsgType::kActor; msg.actor_msg = actor_msg; if (actor_msg.IsDataRegstMsgToConsumer()) { msg.actor_msg.set_comm_net_token(actor_msg.regst()->comm_net_token()); } GetSocketHelper(dst_machine_id)->AsyncWrite(msg); } void EpollCommNet::SendTransportMsg(int64_t dst_machine_id, const TransportMsg& transport_msg) { SocketMsg msg; msg.msg_type = SocketMsgType::kTransport; msg.transport_msg = transport_msg; SendSocketMsg(dst_machine_id, msg); } void EpollCommNet::SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg) { GetSocketHelper(dst_machine_id)->AsyncWrite(msg); } SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) { SocketMemDesc* mem_desc = new SocketMemDesc; mem_desc->mem_ptr = ptr; mem_desc->byte_size = byte_size; return mem_desc; } EpollCommNet::EpollCommNet() : CommNetIf() { pollers_.resize(Singleton::Get()->CommNetWorkerNum(), nullptr); for (size_t i = 0; i < pollers_.size(); ++i) { pollers_[i] = new IOEventPoller; } InitSockets(); for (IOEventPoller* poller : pollers_) { poller->Start(); } } void EpollCommNet::InitSockets() { int64_t this_machine_id = GlobalProcessCtx::Rank(); auto this_machine = Singleton::Get()->machine(this_machine_id); int64_t total_machine_num = Singleton::Get()->process_ranks().size(); machine_id2sockfd_.assign(total_machine_num, -1); sockfd2helper_.clear(); size_t poller_idx = 0; auto NewSocketHelper = [&](int sockfd) { IOEventPoller* poller = pollers_[poller_idx]; poller_idx = (poller_idx + 1) % pollers_.size(); return new SocketHelper(sockfd, poller); }; // listen int listen_sockfd = socket(AF_INET, SOCK_STREAM, 0); int32_t this_listen_port = kInvlidPort; { if (this_machine.data_port_agent() != -1) { this_listen_port = this_machine.data_port_agent(); } else if 
(Singleton::Get()->data_port() != -1) { this_listen_port = Singleton::Get()->data_port(); } } CHECK_EQ(SockListen(listen_sockfd, &this_listen_port, total_machine_num), 0); CHECK_NE(this_listen_port, 0); PushPort(this_machine_id, this_listen_port); int32_t src_machine_count = 0; // connect for (int64_t peer_id : peer_machine_id()) { if (peer_id < this_machine_id) { ++src_machine_count; continue; } uint16_t peer_port = PullPort(peer_id); auto peer_machine = Singleton::Get()->machine(peer_id); sockaddr_in peer_sockaddr = GetSockAddr(peer_machine.addr(), peer_port); int sockfd = socket(AF_INET, SOCK_STREAM, 0); const int val = 1; PCHECK(setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char*)&val, sizeof(int)) == 0); PCHECK(connect(sockfd, reinterpret_cast(&peer_sockaddr), sizeof(peer_sockaddr)) == 0); ssize_t n = write(sockfd, &this_machine_id, sizeof(int64_t)); PCHECK(n == sizeof(int64_t)); CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second); machine_id2sockfd_[peer_id] = sockfd; } // accept HashSet processed_ranks; FOR_RANGE(int32_t, idx, 0, src_machine_count) { sockaddr_in peer_sockaddr; socklen_t len = sizeof(peer_sockaddr); int sockfd = accept(listen_sockfd, reinterpret_cast(&peer_sockaddr), &len); PCHECK(sockfd != -1); int64_t peer_rank; ssize_t n = read(sockfd, &peer_rank, sizeof(int64_t)); PCHECK(n == sizeof(int64_t)); CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second); CHECK(processed_ranks.emplace(peer_rank).second); machine_id2sockfd_[peer_rank] = sockfd; } PCHECK(close(listen_sockfd) == 0); ClearPort(this_machine_id); // useful log FOR_RANGE(int64_t, machine_id, 0, total_machine_num) { VLOG(2) << "machine " << machine_id << " sockfd " << machine_id2sockfd_[machine_id]; } } SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id) { int sockfd = machine_id2sockfd_.at(machine_id); return sockfd2helper_.at(sockfd); } void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) { SocketMsg msg; msg.msg_type = SocketMsgType::kRequestWrite; msg.request_write_msg.src_token = src_token; msg.request_write_msg.dst_machine_id = GlobalProcessCtx::Rank(); msg.request_write_msg.dst_token = dst_token; msg.request_write_msg.read_id = read_id; GetSocketHelper(src_machine_id)->AsyncWrite(msg); } } // namespace oneflow #endif // __linux__ ================================================ FILE: oneflow/core/comm_network/epoll/epoll_comm_network.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_ #define ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_ #ifdef __linux__ #include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/comm_network/epoll/socket_helper.h" #include "oneflow/core/comm_network/epoll/socket_memory_desc.h" namespace oneflow { class EpollCommNet final : public CommNetIf { public: OF_DISALLOW_COPY_AND_MOVE(EpollCommNet); ~EpollCommNet(); void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override; void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg); void SendTransportMsg(int64_t dst_machine_id, const TransportMsg& msg); private: SocketMemDesc* NewMemDesc(void* ptr, size_t byte_size) override; friend class Singleton; EpollCommNet(); void InitSockets(); SocketHelper* GetSocketHelper(int64_t machine_id); void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override; std::vector pollers_; std::vector machine_id2sockfd_; HashMap sockfd2helper_; }; } // namespace oneflow #endif // __linux__ #endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_ ================================================ FILE: oneflow/core/comm_network/epoll/io_event_poller.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef __linux__ #include "oneflow/core/comm_network/epoll/io_event_poller.h" #include namespace oneflow { const int IOEventPoller::max_event_num_ = 32; IOEventPoller::IOEventPoller() { epfd_ = epoll_create1(0); ep_events_ = new epoll_event[max_event_num_]; io_handlers_.clear(); break_epoll_loop_fd_ = eventfd(0, 0); PCHECK(break_epoll_loop_fd_ != -1); AddFdWithOnlyReadHandler(break_epoll_loop_fd_, []() { VLOG(1) << "Break Epoll Loop"; }); } IOEventPoller::~IOEventPoller() { for (IOHandler* handler : io_handlers_) { PCHECK(close(handler->fd) == 0); delete handler; } delete[] ep_events_; PCHECK(close(epfd_) == 0); } void IOEventPoller::AddFd(int fd, std::function read_handler, std::function write_handler) { AddFd(fd, &read_handler, &write_handler); } void IOEventPoller::AddFdWithOnlyReadHandler(int fd, std::function read_handler) { AddFd(fd, &read_handler, nullptr); } void IOEventPoller::Start() { thread_ = std::thread(&IOEventPoller::EpollLoop, this); } void IOEventPoller::Stop() { uint64_t break_epoll_loop_event = 1; PCHECK(write(break_epoll_loop_fd_, &break_epoll_loop_event, 8) == 8); thread_.join(); } void IOEventPoller::AddFd(int fd, std::function* read_handler, std::function* write_handler) { // Set Fd NONBLOCK int opt = fcntl(fd, F_GETFL); PCHECK(opt != -1); PCHECK(fcntl(fd, F_SETFL, opt | O_NONBLOCK) == 0); // Set CLOEXEC opt = fcntl(fd, F_GETFD); PCHECK(opt != -1); PCHECK(fcntl(fd, F_SETFD, opt | FD_CLOEXEC) == 0); // New IOHandler on Heap IOHandler* io_handler = new IOHandler; if (read_handler) { io_handler->read_handler = *read_handler; } if (write_handler) { io_handler->write_handler = *write_handler; } io_handler->fd = fd; io_handlers_.push_front(io_handler); // Add Fd to Epoll epoll_event ep_event; ep_event.events = EPOLLET; if (read_handler) { ep_event.events |= EPOLLIN; } if (write_handler) { ep_event.events |= EPOLLOUT; } ep_event.data.ptr = io_handler; PCHECK(epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &ep_event) == 0); } void IOEventPoller::EpollLoop() { while (true) { int event_num = epoll_wait(epfd_, ep_events_, max_event_num_, -1); if (event_num == -1) { PCHECK(errno == EINTR); continue; } const epoll_event* cur_event = ep_events_; for (int event_idx = 0; event_idx < event_num; ++event_idx, ++cur_event) { auto io_handler = static_cast(cur_event->data.ptr); PCHECK(!(cur_event->events & EPOLLERR)) << "fd: " << io_handler->fd; if (io_handler->fd == break_epoll_loop_fd_) { return; } if (cur_event->events & EPOLLIN) { if (cur_event->events & EPOLLRDHUP) { LOG(FATAL) << "fd " << io_handler->fd << " closed by peer"; } else { io_handler->read_handler(); } } if (cur_event->events & EPOLLOUT) { io_handler->write_handler(); } } } } } // namespace oneflow #endif // __linux__ ================================================ FILE: oneflow/core/comm_network/epoll/io_event_poller.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_ #define ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_ #include "oneflow/core/comm_network/epoll/socket_message.h" #ifdef OF_PLATFORM_POSIX namespace oneflow { class IOEventPoller final { public: OF_DISALLOW_COPY_AND_MOVE(IOEventPoller); IOEventPoller(); ~IOEventPoller(); void AddFd(int fd, std::function read_handler, std::function write_handler); void AddFdWithOnlyReadHandler(int fd, std::function read_handler); void Start(); void Stop(); private: struct IOHandler { IOHandler() { read_handler = []() { UNIMPLEMENTED(); }; write_handler = []() { UNIMPLEMENTED(); }; fd = -1; } std::function read_handler; std::function write_handler; int fd; }; void AddFd(int fd, std::function* read_handler, std::function* write_handler); void EpollLoop(); static const int max_event_num_; int epfd_; epoll_event* ep_events_; std::forward_list io_handlers_; int break_epoll_loop_fd_; std::thread thread_; }; } // namespace oneflow #endif // OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_ ================================================ FILE: oneflow/core/comm_network/epoll/socket_helper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef __linux__ #include "oneflow/core/comm_network/epoll/socket_helper.h" namespace oneflow { SocketHelper::SocketHelper(int sockfd, IOEventPoller* poller) { read_helper_ = new SocketReadHelper(sockfd); write_helper_ = new SocketWriteHelper(sockfd, poller); poller->AddFd( sockfd, [this]() { read_helper_->NotifyMeSocketReadable(); }, [this]() { write_helper_->NotifyMeSocketWriteable(); }); } SocketHelper::~SocketHelper() { delete read_helper_; delete write_helper_; } void SocketHelper::AsyncWrite(const SocketMsg& msg) { write_helper_->AsyncWrite(msg); } } // namespace oneflow #endif // __linux__ ================================================ FILE: oneflow/core/comm_network/epoll/socket_helper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_ #define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_ #include "oneflow/core/comm_network/epoll/io_event_poller.h" #include "oneflow/core/comm_network/epoll/socket_read_helper.h" #include "oneflow/core/comm_network/epoll/socket_write_helper.h" #ifdef OF_PLATFORM_POSIX namespace oneflow { class SocketHelper final { public: OF_DISALLOW_COPY_AND_MOVE(SocketHelper); SocketHelper() = delete; ~SocketHelper(); SocketHelper(int sockfd, IOEventPoller* poller); void AsyncWrite(const SocketMsg& msg); private: SocketReadHelper* read_helper_; SocketWriteHelper* write_helper_; }; } // namespace oneflow #endif // OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_ ================================================ FILE: oneflow/core/comm_network/epoll/socket_memory_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_ #define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_ #include "oneflow/core/comm_network/epoll/socket_memory_desc.h" #ifdef OF_PLATFORM_POSIX namespace oneflow { struct SocketMemDesc { void* mem_ptr; size_t byte_size; }; } // namespace oneflow #endif // OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_ ================================================ FILE: oneflow/core/comm_network/epoll/socket_message.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_ #define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_ #include "oneflow/core/common/platform.h" #include "oneflow/core/common/util.h" #include "oneflow/core/comm_network/comm_network.h" #ifdef OF_PLATFORM_POSIX #include #include #include #include #include #include #include #include #include "oneflow/core/lazy/actor/actor_message.h" #include "oneflow/core/transport/transport_message.h" namespace oneflow { #define SOCKET_MSG_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(RequestWrite, request_write) \ OF_PP_MAKE_TUPLE_SEQ(RequestRead, request_read) \ OF_PP_MAKE_TUPLE_SEQ(Actor, actor) \ OF_PP_MAKE_TUPLE_SEQ(Transport, transport) enum class SocketMsgType { #define MAKE_ENTRY(x, y) k##x, OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ) #undef MAKE_ENTRY }; struct RequestWriteMsg { void* src_token; int64_t dst_machine_id; void* dst_token; void* read_id; }; struct RequestReadMsg { void* src_token; void* dst_token; void* read_id; }; struct SocketMsg { SocketMsgType msg_type; union { #define MAKE_ENTRY(x, y) x##Msg y##_msg; OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ) #undef MAKE_ENTRY }; }; using CallBackList = std::list>; } // namespace oneflow #endif // OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_ ================================================ FILE: oneflow/core/comm_network/epoll/socket_read_helper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef __linux__ #include "oneflow/core/comm_network/epoll/socket_read_helper.h" #include "oneflow/core/lazy/actor/actor_message_bus.h" #include "oneflow/core/comm_network/epoll/epoll_comm_network.h" #include "oneflow/core/transport/transport.h" #include namespace oneflow { SocketReadHelper::~SocketReadHelper() { // do nothing } SocketReadHelper::SocketReadHelper(int sockfd) { sockfd_ = sockfd; SwitchToMsgHeadReadHandle(); } void SocketReadHelper::NotifyMeSocketReadable() { ReadUntilSocketNotReadable(); } void SocketReadHelper::SwitchToMsgHeadReadHandle() { cur_read_handle_ = &SocketReadHelper::MsgHeadReadHandle; read_ptr_ = reinterpret_cast(&cur_msg_); read_size_ = sizeof(cur_msg_); } void SocketReadHelper::ReadUntilSocketNotReadable() { while ((this->*cur_read_handle_)()) {} } bool SocketReadHelper::MsgHeadReadHandle() { return DoCurRead(&SocketReadHelper::SetStatusWhenMsgHeadDone); } bool SocketReadHelper::MsgBodyReadHandle() { return DoCurRead(&SocketReadHelper::SetStatusWhenMsgBodyDone); } bool SocketReadHelper::DoCurRead(void (SocketReadHelper::*set_cur_read_done)()) { ssize_t n = read(sockfd_, read_ptr_, read_size_); const int val = 1; PCHECK(setsockopt(sockfd_, IPPROTO_TCP, TCP_QUICKACK, (char*)&val, sizeof(int)) == 0); if (n == read_size_) { (this->*set_cur_read_done)(); return true; } else if (n >= 0) { read_ptr_ += n; read_size_ -= n; return true; } else { CHECK_EQ(n, -1); PCHECK(errno == EAGAIN || errno == EWOULDBLOCK); return false; } } void SocketReadHelper::SetStatusWhenMsgHeadDone() { switch (cur_msg_.msg_type) { #define MAKE_ENTRY(x, y) \ case SocketMsgType::k##x: SetStatusWhen##x##MsgHeadDone(); break; OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ); #undef MAKE_ENTRY default: UNIMPLEMENTED(); } } void SocketReadHelper::SetStatusWhenMsgBodyDone() { if (cur_msg_.msg_type == SocketMsgType::kRequestRead) { Singleton::Get()->ReadDone(cur_msg_.request_read_msg.read_id); } SwitchToMsgHeadReadHandle(); } void SocketReadHelper::SetStatusWhenRequestWriteMsgHeadDone() { SocketMsg msg_to_send; msg_to_send.msg_type = SocketMsgType::kRequestRead; msg_to_send.request_read_msg.src_token = cur_msg_.request_write_msg.src_token; msg_to_send.request_read_msg.dst_token = cur_msg_.request_write_msg.dst_token; msg_to_send.request_read_msg.read_id = cur_msg_.request_write_msg.read_id; Singleton::Get()->SendSocketMsg(cur_msg_.request_write_msg.dst_machine_id, msg_to_send); SwitchToMsgHeadReadHandle(); } void SocketReadHelper::SetStatusWhenRequestReadMsgHeadDone() { auto mem_desc = static_cast(cur_msg_.request_read_msg.dst_token); read_ptr_ = reinterpret_cast(mem_desc->mem_ptr); read_size_ = mem_desc->byte_size; cur_read_handle_ = &SocketReadHelper::MsgBodyReadHandle; } void SocketReadHelper::SetStatusWhenActorMsgHeadDone() { Singleton::Get()->SendMsgWithoutCommNet(cur_msg_.actor_msg); SwitchToMsgHeadReadHandle(); } void SocketReadHelper::SetStatusWhenTransportMsgHeadDone() { Singleton::Get()->EnqueueTransportMsg(cur_msg_.transport_msg); SwitchToMsgHeadReadHandle(); } } // namespace oneflow #endif // __linux__ ================================================ FILE: oneflow/core/comm_network/epoll/socket_read_helper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_ #define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_ #include "oneflow/core/comm_network/epoll/socket_message.h" #ifdef OF_PLATFORM_POSIX namespace oneflow { class SocketReadHelper final { public: OF_DISALLOW_COPY_AND_MOVE(SocketReadHelper); SocketReadHelper() = delete; ~SocketReadHelper(); SocketReadHelper(int sockfd); void NotifyMeSocketReadable(); private: void SwitchToMsgHeadReadHandle(); void ReadUntilSocketNotReadable(); bool MsgHeadReadHandle(); bool MsgBodyReadHandle(); bool DoCurRead(void (SocketReadHelper::*set_cur_read_done)()); void SetStatusWhenMsgHeadDone(); void SetStatusWhenMsgBodyDone(); #define MAKE_ENTRY(x, y) void SetStatusWhen##x##MsgHeadDone(); OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ); #undef MAKE_ENTRY int sockfd_; SocketMsg cur_msg_; bool (SocketReadHelper::*cur_read_handle_)(); char* read_ptr_; size_t read_size_; }; } // namespace oneflow #endif // OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_ ================================================ FILE: oneflow/core/comm_network/epoll/socket_write_helper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef __linux__ #include "oneflow/core/comm_network/epoll/socket_write_helper.h" #include "oneflow/core/comm_network/epoll/socket_memory_desc.h" #include namespace oneflow { SocketWriteHelper::~SocketWriteHelper() { delete cur_msg_queue_; cur_msg_queue_ = nullptr; { std::unique_lock lck(pending_msg_queue_mtx_); delete pending_msg_queue_; pending_msg_queue_ = nullptr; } } SocketWriteHelper::SocketWriteHelper(int sockfd, IOEventPoller* poller) { sockfd_ = sockfd; queue_not_empty_fd_ = eventfd(0, 0); PCHECK(queue_not_empty_fd_ != -1); poller->AddFdWithOnlyReadHandler(queue_not_empty_fd_, std::bind(&SocketWriteHelper::ProcessQueueNotEmptyEvent, this)); cur_msg_queue_ = new std::queue; pending_msg_queue_ = new std::queue; cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle; write_ptr_ = nullptr; write_size_ = 0; } void SocketWriteHelper::AsyncWrite(const SocketMsg& msg) { pending_msg_queue_mtx_.lock(); bool need_send_event = pending_msg_queue_->empty(); pending_msg_queue_->push(msg); pending_msg_queue_mtx_.unlock(); if (need_send_event) { SendQueueNotEmptyEvent(); } } void SocketWriteHelper::NotifyMeSocketWriteable() { WriteUntilMsgQueueEmptyOrSocketNotWriteable(); } void SocketWriteHelper::SendQueueNotEmptyEvent() { uint64_t event_num = 1; PCHECK(write(queue_not_empty_fd_, &event_num, 8) == 8); } void SocketWriteHelper::ProcessQueueNotEmptyEvent() { uint64_t event_num = 0; PCHECK(read(queue_not_empty_fd_, &event_num, 8) == 8); WriteUntilMsgQueueEmptyOrSocketNotWriteable(); } void SocketWriteHelper::WriteUntilMsgQueueEmptyOrSocketNotWriteable() { while ((this->*cur_write_handle_)()) {} } bool SocketWriteHelper::InitMsgWriteHandle() { if (cur_msg_queue_->empty()) { { std::unique_lock lck(pending_msg_queue_mtx_); std::swap(cur_msg_queue_, pending_msg_queue_); } if (cur_msg_queue_->empty()) { return false; } } cur_msg_ = cur_msg_queue_->front(); cur_msg_queue_->pop(); write_ptr_ = reinterpret_cast(&cur_msg_); write_size_ = sizeof(cur_msg_); cur_write_handle_ = &SocketWriteHelper::MsgHeadWriteHandle; return true; } bool SocketWriteHelper::MsgHeadWriteHandle() { return DoCurWrite(&SocketWriteHelper::SetStatusWhenMsgHeadDone); } bool SocketWriteHelper::MsgBodyWriteHandle() { return DoCurWrite(&SocketWriteHelper::SetStatusWhenMsgBodyDone); } bool SocketWriteHelper::DoCurWrite(void (SocketWriteHelper::*set_cur_write_done)()) { ssize_t n = write(sockfd_, write_ptr_, write_size_); if (n == write_size_) { (this->*set_cur_write_done)(); return true; } else if (n >= 0) { write_ptr_ += n; write_size_ -= n; return true; } else { CHECK_EQ(n, -1); PCHECK(errno == EAGAIN || errno == EWOULDBLOCK); return false; } } void SocketWriteHelper::SetStatusWhenMsgHeadDone() { switch (cur_msg_.msg_type) { #define MAKE_ENTRY(x, y) \ case SocketMsgType::k##x: return SetStatusWhen##x##MsgHeadDone(); OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ); #undef MAKE_ENTRY default: UNIMPLEMENTED(); } } void SocketWriteHelper::SetStatusWhenMsgBodyDone() { cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle; } void SocketWriteHelper::SetStatusWhenRequestWriteMsgHeadDone() { cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle; } void SocketWriteHelper::SetStatusWhenRequestReadMsgHeadDone() { const void* src_token = cur_msg_.request_read_msg.src_token; auto src_mem_desc = static_cast(src_token); write_ptr_ = reinterpret_cast(src_mem_desc->mem_ptr); write_size_ = src_mem_desc->byte_size; cur_write_handle_ = &SocketWriteHelper::MsgBodyWriteHandle; } void SocketWriteHelper::SetStatusWhenActorMsgHeadDone() { 
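// Actor messages carry no extra body payload, so switch straight back to
// fetching the next message from the queue.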
cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle; } void SocketWriteHelper::SetStatusWhenTransportMsgHeadDone() { cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle; } } // namespace oneflow #endif // __linux__ ================================================ FILE: oneflow/core/comm_network/epoll/socket_write_helper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_ #define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_ #include "oneflow/core/comm_network/epoll/io_event_poller.h" #include "oneflow/core/comm_network/epoll/socket_message.h" #ifdef OF_PLATFORM_POSIX namespace oneflow { class SocketWriteHelper final { public: OF_DISALLOW_COPY_AND_MOVE(SocketWriteHelper); SocketWriteHelper() = delete; ~SocketWriteHelper(); SocketWriteHelper(int sockfd, IOEventPoller* poller); void AsyncWrite(const SocketMsg& msg); void NotifyMeSocketWriteable(); private: void SendQueueNotEmptyEvent(); void ProcessQueueNotEmptyEvent(); void WriteUntilMsgQueueEmptyOrSocketNotWriteable(); bool InitMsgWriteHandle(); bool MsgHeadWriteHandle(); bool MsgBodyWriteHandle(); bool DoCurWrite(void (SocketWriteHelper::*set_cur_write_done)()); void SetStatusWhenMsgHeadDone(); void SetStatusWhenMsgBodyDone(); #define MAKE_ENTRY(x, y) void SetStatusWhen##x##MsgHeadDone(); OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ); #undef MAKE_ENTRY int sockfd_; int queue_not_empty_fd_; std::queue* cur_msg_queue_; std::mutex pending_msg_queue_mtx_; std::queue* pending_msg_queue_; SocketMsg cur_msg_; bool (SocketWriteHelper::*cur_write_handle_)(); const char* write_ptr_; size_t write_size_; }; } // namespace oneflow #endif // OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_ ================================================ FILE: oneflow/core/comm_network/ibverbs/ibverbs.proto ================================================ syntax = "proto2"; package oneflow; message IBVerbsConnectionInfo { required uint32 lid = 1; required uint32 qp_num = 2; required uint64 subnet_prefix = 3; required uint64 interface_id = 4; required uint32 port_num = 5; required int32 mtu = 6; } ================================================ FILE: oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/platform/include/ibv.h" #include "oneflow/core/lazy/actor/actor_message_bus.h" #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) namespace oneflow { namespace { std::string GenTokensMsgKey(int64_t machine_id) { return "IBVerbsTokensMsg/" + std::to_string(machine_id); } std::string GenConnInfoKey(int64_t src_machine_id, int64_t dst_machine_id) { return "IBVerbsConnInfo/" + std::to_string(src_machine_id) + "/" + std::to_string(dst_machine_id); } void IBVForkInit() { if (ibv::IsAvailable()) { if (ibv::wrapper.ibv_fork_init() != 0) { std::cerr << "ibv_fork_init failed\n"; } } else { std::cerr << "libibverbs not available, ibv_fork_init skipped\n"; } } void ParseUserDevicePort(std::string* device_name, int* port) { std::string user_device_port = GetStringFromEnv("ONEFLOW_COMM_NET_IB_HCA", ""); if (user_device_port.empty()) { *device_name = ""; *port = 0; return; } else { const std::string::size_type pos = user_device_port.find(':', 0); if (pos == std::string::npos) { *device_name = user_device_port; *port = 0; return; } else { *device_name = user_device_port.substr(0, pos); *port = std::strtol(user_device_port.data() + pos + 1, nullptr, 10); return; } } } } // namespace IBVerbsCommNet::~IBVerbsCommNet() { while (poll_exit_flag_.test_and_set() == true) {} poll_thread_.join(); for (IBVerbsQP* qp : qp_vec_) { if (qp) { delete qp; } } PCHECK(ibv::wrapper.ibv_destroy_cq(cq_) == 0); PCHECK(ibv::wrapper.ibv_dealloc_pd(pd_) == 0); CHECK_EQ(ibv::wrapper.ibv_close_device(context_), 0) << "Error, failed to close the IB device " << ibv::wrapper.ibv_get_device_name(context_->device); } void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) { IBVerbsActorMsgWrapper msg_wrapper; msg_wrapper.msg = msg; if (msg.IsDataRegstMsgToConsumer()) { auto* mem_desc = reinterpret_cast(msg.regst()->comm_net_token()); CHECK(mem_desc != nullptr); msg_wrapper.rma_desc.mem_ptr = reinterpret_cast(mem_desc->mem_ptr()); msg_wrapper.rma_desc.mem_size = mem_desc->mem_size(); msg_wrapper.rma_desc.mr_rkey = mem_desc->mr()->rkey; } qp_vec_.at(dst_machine_id)->PostSendRequest(msg_wrapper); } void IBVerbsCommNet::RecvActorMsg(const IBVerbsActorMsgWrapper& msg_wrapper) { ActorMsg new_msg = msg_wrapper.msg; if (msg_wrapper.msg.IsDataRegstMsgToConsumer()) { std::lock_guard lock(remote_regst2rma_desc_mutex_); auto& desc = remote_regst2rma_desc_[std::make_pair( msg_wrapper.msg.src_actor_id(), reinterpret_cast(msg_wrapper.msg.regst()))]; if (!desc) { desc.reset(new IBVerbsCommNetRMADesc); } *desc = msg_wrapper.rma_desc; new_msg.set_comm_net_token(desc.get()); } Singleton::Get()->SendMsgWithoutCommNet(new_msg); } IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) { int num_device; ibv_device** device_list = ibv::wrapper.ibv_get_device_list(&num_device); CHECK_GT(num_device, 0) << "No IB device found"; PCHECK(device_list); std::string user_device; int user_port; ParseUserDevicePort(&user_device, &user_port); ibv_device* device = nullptr; if (user_device.empty()) { device = device_list[0]; } else { for (int i = 0; i < num_device; ++i) { if (device_list[i]->name == user_device) { device = device_list[i]; break; } } CHECK(device != nullptr) << "No IB device match " << user_device; } context_ = 
ibv::wrapper.ibv_open_device(device); CHECK(context_ != NULL) << "Error, failed to open the IB device " << ibv::wrapper.ibv_get_device_name(device); ibv::wrapper.ibv_free_device_list(device_list); pd_ = ibv::wrapper.ibv_alloc_pd(context_); CHECK(pd_) << "Error, ibv_alloc_pd() allocates a Protection Domain (PD) failed"; ibv_device_attr device_attr{}; PCHECK(ibv::wrapper.ibv_query_device(context_, &device_attr) == 0); cq_ = ibv::wrapper.ibv_create_cq(context_, device_attr.max_cqe, nullptr, nullptr, 0); PCHECK(cq_); ibv_port_attr port_attr{}; const uint8_t port = user_port == 0 ? 1 : user_port; PCHECK(ibv::wrapper.ibv_query_port_wrap(context_, port, &port_attr) == 0); ibv_gid gid{}; const int64_t gid_index = ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_GID_INDEX", 0); PCHECK(ibv::wrapper.ibv_query_gid(context_, port, gid_index, &gid) == 0); VLOG(1) << "Using IB device " << device->name << " port " << static_cast(port) << " gid index " << gid_index; int64_t this_machine_id = GlobalProcessCtx::Rank(); qp_vec_.assign(Singleton::Get()->process_ranks().size(), nullptr); for (int64_t peer_id : peer_machine_id()) { IBVerbsQP* cur_qp = new IBVerbsQP(context_, pd_, port_attr, port, cq_, cq_); qp_vec_.at(peer_id) = cur_qp; IBVerbsConnectionInfo conn_info; conn_info.set_lid(port_attr.lid); conn_info.set_qp_num(cur_qp->qp_num()); conn_info.set_subnet_prefix(gid.global.subnet_prefix); conn_info.set_interface_id(gid.global.interface_id); conn_info.set_port_num(port); conn_info.set_mtu(static_cast(port_attr.active_mtu)); Singleton::Get()->PushKV(GenConnInfoKey(this_machine_id, peer_id), conn_info); } for (int64_t peer_id : peer_machine_id()) { IBVerbsConnectionInfo conn_info; Singleton::Get()->PullKV(GenConnInfoKey(peer_id, this_machine_id), &conn_info); if (conn_info.lid() == 0) { VLOG(2) << "Connecting to peer " << peer_id << " port " << conn_info.port_num() << " qpn " << conn_info.qp_num() << " gid index " << gid_index << " spn " << conn_info.subnet_prefix() << " iid " << conn_info.interface_id() << " mtu " << conn_info.mtu(); } else { VLOG(2) << "Connecting to peer " << peer_id << " port " << conn_info.port_num() << " qpn " << conn_info.qp_num() << " lid " << conn_info.interface_id() << " mtu " << conn_info.mtu(); } qp_vec_.at(peer_id)->Connect(conn_info); VLOG(1) << "Connected to peer " << peer_id; } OF_ENV_BARRIER(); for (int64_t peer_id : peer_machine_id()) { qp_vec_.at(peer_id)->PostAllRecvRequest(); Singleton::Get()->ClearKV(GenConnInfoKey(this_machine_id, peer_id)); } OF_ENV_BARRIER(); poll_thread_ = std::thread(&IBVerbsCommNet::PollCQ, this); OF_ENV_BARRIER(); } void IBVerbsCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) { qp_vec_.at(src_machine_id) ->PostReadRequest(*reinterpret_cast(src_token), *static_cast(dst_token), read_id); } void IBVerbsCommNet::PollCQ() { std::vector wc_vec(max_poll_wc_num_); while (poll_exit_flag_.test_and_set() == false) { poll_exit_flag_.clear(); int32_t found_wc_num = ibv_poll_cq(cq_, max_poll_wc_num_, wc_vec.data()); CHECK_GE(found_wc_num, 0); FOR_RANGE(int32_t, i, 0, found_wc_num) { const ibv_wc& wc = wc_vec.at(i); CHECK_EQ(wc.status, IBV_WC_SUCCESS) << wc.opcode; WorkRequestId* wr_id = reinterpret_cast(wc.wr_id); IBVerbsQP* qp = wr_id->qp; switch (wc.opcode) { case IBV_WC_RDMA_READ: { qp->ReadDone(wr_id); break; } case IBV_WC_SEND: { qp->SendDone(wr_id); break; } case IBV_WC_RECV: { qp->RecvDone(wr_id); break; } default: UNIMPLEMENTED(); } } } } const int32_t IBVerbsCommNet::max_poll_wc_num_ = 32; COMMAND(IBVForkInit()); } // 
namespace oneflow #endif // WITH_RDMA && OF_PLATFORM_POSIX ================================================ FILE: oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_ #define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_ #include "oneflow/core/common/platform.h" #include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h" #include "oneflow/core/comm_network/ibverbs/ibverbs_qp.h" #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) #include #include namespace oneflow { class IBVerbsCommNet final : public CommNetIf { public: OF_DISALLOW_COPY_AND_MOVE(IBVerbsCommNet); ~IBVerbsCommNet(); void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override; void RecvActorMsg(const IBVerbsActorMsgWrapper& msg_wrapper); private: friend class Singleton; IBVerbsCommNet(); IBVerbsMemDesc* NewMemDesc(void* ptr, size_t byte_size) override { return new IBVerbsMemDesc(pd_, ptr, byte_size); } void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override; void PollCQ(); static const int32_t max_poll_wc_num_; ibv_context* context_; ibv_pd* pd_; ibv_cq* cq_; std::vector qp_vec_; std::atomic_flag poll_exit_flag_; std::thread poll_thread_; HashMap, std::shared_ptr> remote_regst2rma_desc_; std::mutex remote_regst2rma_desc_mutex_; }; } // namespace oneflow #endif // WITH_RDMA && OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_ ================================================ FILE: oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/platform/include/ibv.h" #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) namespace oneflow { IBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size) : mem_ptr_(mem_ptr), mem_size_(byte_size) { mr_ = ibv::wrapper.ibv_reg_mr_wrap( pd, mem_ptr, byte_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ); PCHECK(mr_); } IBVerbsMemDesc::~IBVerbsMemDesc() { PCHECK(ibv::wrapper.ibv_dereg_mr(mr_) == 0); } } // namespace oneflow #endif // WITH_RDMA && OF_PLATFORM_POSIX ================================================ FILE: oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_ #define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_ #include "oneflow/core/common/platform.h" #include "oneflow/core/common/util.h" #include "oneflow/core/comm_network/ibverbs/ibverbs.pb.h" #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) #include namespace oneflow { class IBVerbsMemDesc final { public: OF_DISALLOW_COPY_AND_MOVE(IBVerbsMemDesc); IBVerbsMemDesc() = delete; IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size); ~IBVerbsMemDesc(); void* mem_ptr() const { return mem_ptr_; } size_t mem_size() const { return mem_size_; } const ibv_mr* mr() const { return mr_; } private: ibv_mr* mr_; void* mem_ptr_; uint64_t mem_size_; }; } // namespace oneflow #endif // WITH_RDMA && OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_ ================================================ FILE: oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/comm_network/ibverbs/ibverbs_qp.h" #include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/lazy/actor/actor_message_bus.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/platform/include/ibv.h" #include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h" #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) namespace oneflow { namespace { constexpr uint32_t kDefaultQueueDepth = 1024; constexpr uint64_t kDefaultMemBlockSize = 8388608; // 8M } // namespace IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& port_attr, uint8_t port_num, ibv_cq* send_cq, ibv_cq* recv_cq) { // ctx_, pd_ ctx_ = ctx; pd_ = pd; port_num_ = port_num; // qp_ ibv_device_attr device_attr{}; PCHECK(ibv::wrapper.ibv_query_device(ctx, &device_attr) == 0); const int64_t user_queue_depth = ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_QUEUE_DEPTH", kDefaultQueueDepth); const uint32_t queue_depth = std::min(device_attr.max_qp_wr, user_queue_depth); ibv_qp_init_attr qp_init_attr{}; qp_init_attr.qp_context = nullptr; qp_init_attr.send_cq = send_cq; qp_init_attr.recv_cq = recv_cq; qp_init_attr.srq = nullptr; qp_init_attr.cap.max_send_wr = queue_depth; qp_init_attr.cap.max_recv_wr = queue_depth; qp_init_attr.cap.max_send_sge = 1; qp_init_attr.cap.max_recv_sge = 1; qp_init_attr.cap.max_inline_data = 0; qp_init_attr.qp_type = IBV_QPT_RC; qp_init_attr.sq_sig_all = 1; qp_ = ibv::wrapper.ibv_create_qp(pd, &qp_init_attr); PCHECK(qp_); // recv_msg_buf_ recv_msg_buf_.assign(queue_depth, nullptr); FOR_RANGE(size_t, i, 0, recv_msg_buf_.size()) { recv_msg_buf_.at(i) = new ActorMsgMR(pd_); } // send_msg_buf_ CHECK(send_msg_buf_.empty()); num_outstanding_send_wr_ = 0; max_outstanding_send_wr_ = queue_depth; read_block_size_ = ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE", kDefaultMemBlockSize); mtu_ = static_cast(port_attr.active_mtu); } IBVerbsQP::~IBVerbsQP() { PCHECK(ibv::wrapper.ibv_destroy_qp(qp_) == 0); while (send_msg_buf_.empty() == false) { delete send_msg_buf_.front(); send_msg_buf_.pop(); } for (ActorMsgMR* msg_mr : recv_msg_buf_) { delete msg_mr; } } void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) { ibv_qp_attr qp_attr{}; // IBV_QPS_INIT memset(&qp_attr, 0, sizeof(ibv_qp_attr)); qp_attr.qp_state = IBV_QPS_INIT; // TODO(liujuncheng): Make pkey_index configurable qp_attr.pkey_index = 0; qp_attr.port_num = port_num_; qp_attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; PCHECK(ibv::wrapper.ibv_modify_qp( qp_, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) == 0); // IBV_QPS_RTR memset(&qp_attr, 0, sizeof(ibv_qp_attr)); qp_attr.qp_state = IBV_QPS_RTR; // TODO(liujuncheng): Make sl configurable; qp_attr.ah_attr.sl = 0; qp_attr.ah_attr.src_path_bits = 0; if (peer_info.lid() == 0) { qp_attr.ah_attr.is_global = 1; qp_attr.ah_attr.grh.dgid.global.subnet_prefix = peer_info.subnet_prefix(); qp_attr.ah_attr.grh.dgid.global.interface_id = peer_info.interface_id(); qp_attr.ah_attr.grh.flow_label = 0; const int64_t gid_index = ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_GID_INDEX", 0); qp_attr.ah_attr.grh.sgid_index = gid_index; qp_attr.ah_attr.grh.hop_limit = 255; // TODO(liujuncheng): Make traffic_class configurable; qp_attr.ah_attr.grh.traffic_class = 0; } else { qp_attr.ah_attr.is_global = 0; qp_attr.ah_attr.dlid = peer_info.lid(); } qp_attr.ah_attr.port_num = peer_info.port_num(); qp_attr.path_mtu = 
static_cast(std::min(peer_info.mtu(), mtu_)); qp_attr.dest_qp_num = peer_info.qp_num(); qp_attr.rq_psn = 0; qp_attr.max_dest_rd_atomic = 1; qp_attr.min_rnr_timer = 12; PCHECK(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) == 0); // IBV_QPS_RTS memset(&qp_attr, 0, sizeof(ibv_qp_attr)); qp_attr.qp_state = IBV_QPS_RTS; qp_attr.sq_psn = 0; qp_attr.max_rd_atomic = 1; qp_attr.retry_cnt = 7; qp_attr.rnr_retry = 7; qp_attr.timeout = 14; PCHECK(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr, IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_TIMEOUT) == 0); } void IBVerbsQP::PostAllRecvRequest() { for (ActorMsgMR* msg_mr : recv_msg_buf_) { PostRecvRequest(msg_mr); } } void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem, const IBVerbsMemDesc& local_mem, void* read_id) { CHECK_EQ(remote_mem.mem_size, local_mem.mem_size()); WorkRequestId* wr_id = NewWorkRequestId(); const size_t block_num = RoundUp(remote_mem.mem_size, read_block_size_) / read_block_size_; wr_id->outstanding_sge_cnt = static_cast(block_num); wr_id->read_id = read_id; FOR_RANGE(size_t, i, 0, block_num) { ibv_send_wr wr{}; ibv_sge sge{}; sge.addr = reinterpret_cast(local_mem.mem_ptr()) + i * read_block_size_; sge.length = std::min(read_block_size_, local_mem.mem_size() - i * read_block_size_); sge.lkey = local_mem.mr()->lkey; wr.wr_id = reinterpret_cast(wr_id); wr.next = nullptr; wr.sg_list = &sge; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_READ; wr.send_flags = 0; wr.imm_data = 0; wr.wr.rdma.remote_addr = remote_mem.mem_ptr + i * read_block_size_; wr.wr.rdma.rkey = remote_mem.mr_rkey; EnqueuePostSendReadWR(wr, sge); } } void IBVerbsQP::PostSendRequest(const IBVerbsActorMsgWrapper& msg_wrapper) { ActorMsgMR* msg_mr = GetOneSendMsgMRFromBuf(); msg_mr->set_msg(msg_wrapper); WorkRequestId* wr_id = NewWorkRequestId(); wr_id->msg_mr = msg_mr; ibv_send_wr wr{}; ibv_sge sge{}; sge.addr = reinterpret_cast(msg_mr->mem_desc().mem_ptr()); sge.length = msg_mr->mem_desc().mem_size(); sge.lkey = msg_mr->mem_desc().mr()->lkey; wr.wr_id = reinterpret_cast(wr_id); wr.next = nullptr; wr.sg_list = &sge; wr.num_sge = 1; wr.opcode = IBV_WR_SEND; wr.send_flags = 0; wr.imm_data = 0; memset(&(wr.wr), 0, sizeof(wr.wr)); EnqueuePostSendReadWR(wr, sge); } void IBVerbsQP::EnqueuePostSendReadWR(ibv_send_wr wr, ibv_sge sge) { std::unique_lock pending_send_wr_lock_(pending_send_wr_mutex_); if (num_outstanding_send_wr_ < max_outstanding_send_wr_) { num_outstanding_send_wr_++; ibv_send_wr* bad_wr = nullptr; PCHECK(ibv_post_send(qp_, &wr, &bad_wr) == 0); } else { std::pair ibv_send_wr_sge = std::make_pair(wr, sge); pending_send_wr_queue_.push(ibv_send_wr_sge); } } void IBVerbsQP::ReadDone(WorkRequestId* wr_id) { CHECK_GE(wr_id->outstanding_sge_cnt, 1); wr_id->outstanding_sge_cnt -= 1; if (wr_id->outstanding_sge_cnt == 0) { Singleton::Get()->ReadDone(wr_id->read_id); DeleteWorkRequestId(wr_id); } PostPendingSendWR(); } void IBVerbsQP::SendDone(WorkRequestId* wr_id) { { std::unique_lock lck(send_msg_buf_mtx_); send_msg_buf_.push(wr_id->msg_mr); } DeleteWorkRequestId(wr_id); PostPendingSendWR(); } void IBVerbsQP::RecvDone(WorkRequestId* wr_id) { auto* ibv_comm_net = dynamic_cast(Singleton::Get()); CHECK(ibv_comm_net != nullptr); ibv_comm_net->RecvActorMsg(wr_id->msg_mr->msg()); PostRecvRequest(wr_id->msg_mr); DeleteWorkRequestId(wr_id); } void IBVerbsQP::PostPendingSendWR() { 
std::unique_lock pending_send_wr_lock_(pending_send_wr_mutex_); if (pending_send_wr_queue_.empty() == false) { std::pair ibv_send_wr_sge = std::move(pending_send_wr_queue_.front()); ibv_send_wr wr = ibv_send_wr_sge.first; wr.sg_list = &ibv_send_wr_sge.second; pending_send_wr_queue_.pop(); ibv_send_wr* bad_wr = nullptr; PCHECK(ibv_post_send(qp_, &wr, &bad_wr) == 0); } else { if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; } } } void IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) { WorkRequestId* wr_id = NewWorkRequestId(); wr_id->msg_mr = msg_mr; ibv_recv_wr wr{}; ibv_sge sge{}; sge.addr = reinterpret_cast(msg_mr->mem_desc().mem_ptr()); sge.length = msg_mr->mem_desc().mem_size(); sge.lkey = msg_mr->mem_desc().mr()->lkey; wr.wr_id = reinterpret_cast(wr_id); wr.next = nullptr; wr.sg_list = &sge; wr.num_sge = 1; ibv_recv_wr* bad_wr = nullptr; PCHECK(ibv_post_recv(qp_, &wr, &bad_wr) == 0); } ActorMsgMR* IBVerbsQP::GetOneSendMsgMRFromBuf() { std::unique_lock lck(send_msg_buf_mtx_); if (send_msg_buf_.empty()) { send_msg_buf_.push(new ActorMsgMR(pd_)); } ActorMsgMR* msg_mr = send_msg_buf_.front(); send_msg_buf_.pop(); return msg_mr; } WorkRequestId* IBVerbsQP::NewWorkRequestId() { WorkRequestId* wr_id = new WorkRequestId; wr_id->qp = this; wr_id->outstanding_sge_cnt = 0; wr_id->read_id = nullptr; wr_id->msg_mr = nullptr; return wr_id; } void IBVerbsQP::DeleteWorkRequestId(WorkRequestId* wr_id) { CHECK_EQ(wr_id->qp, this); delete wr_id; } } // namespace oneflow #endif // WITH_RDMA && OF_PLATFORM_POSIX ================================================ FILE: oneflow/core/comm_network/ibverbs/ibverbs_qp.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_ #define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_ #include "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h" #include "oneflow/core/lazy/actor/actor_message.h" #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) namespace oneflow { struct IBVerbsCommNetRMADesc { uint64_t mem_ptr; uint64_t mem_size; uint32_t mr_rkey; }; struct IBVerbsActorMsgWrapper final { ActorMsg msg; IBVerbsCommNetRMADesc rma_desc; }; class ActorMsgMR final { public: OF_DISALLOW_COPY_AND_MOVE(ActorMsgMR); ActorMsgMR() = delete; ActorMsgMR(ibv_pd* pd) { mem_desc_.reset(new IBVerbsMemDesc(pd, &msg_, sizeof(msg_))); } ~ActorMsgMR() { mem_desc_.reset(); } const IBVerbsActorMsgWrapper& msg() const { return msg_; } void set_msg(const IBVerbsActorMsgWrapper& val) { msg_ = val; } const IBVerbsMemDesc& mem_desc() const { return *mem_desc_; } private: IBVerbsActorMsgWrapper msg_; std::unique_ptr mem_desc_; }; class IBVerbsQP; struct WorkRequestId { IBVerbsQP* qp; int32_t outstanding_sge_cnt; void* read_id; ActorMsgMR* msg_mr; }; struct IBVerbsCommNetRMADesc; class IBVerbsQP final { public: OF_DISALLOW_COPY_AND_MOVE(IBVerbsQP); IBVerbsQP() = delete; IBVerbsQP(ibv_context*, ibv_pd*, const struct ibv_port_attr&, uint8_t port_num, ibv_cq* send_cq, ibv_cq* recv_cq); ~IBVerbsQP(); uint32_t qp_num() const { return qp_->qp_num; } void Connect(const IBVerbsConnectionInfo& peer_info); void PostAllRecvRequest(); void PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem, const IBVerbsMemDesc& local_mem, void* read_id); void PostSendRequest(const IBVerbsActorMsgWrapper& msg_wrapper); void ReadDone(WorkRequestId*); void SendDone(WorkRequestId*); void RecvDone(WorkRequestId*); private: void EnqueuePostSendReadWR(ibv_send_wr wr, ibv_sge sge); void PostPendingSendWR(); WorkRequestId* NewWorkRequestId(); void DeleteWorkRequestId(WorkRequestId* wr_id); ActorMsgMR* GetOneSendMsgMRFromBuf(); void PostRecvRequest(ActorMsgMR*); ibv_context* ctx_; ibv_pd* pd_; uint8_t port_num_; ibv_qp* qp_; std::vector recv_msg_buf_; std::mutex send_msg_buf_mtx_; std::queue send_msg_buf_; std::mutex pending_send_wr_mutex_; uint32_t num_outstanding_send_wr_; uint32_t max_outstanding_send_wr_; std::queue> pending_send_wr_queue_; size_t read_block_size_; int32_t mtu_; }; } // namespace oneflow #endif // WITH_RDMA && OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_ ================================================ FILE: oneflow/core/common/array_ref.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_COMMON_ARRAY_REF_H_
#define ONEFLOW_CORE_COMMON_ARRAY_REF_H_

#include "llvm/ADT/ArrayRef.h"

namespace oneflow {

template<typename T>
using ArrayRef = llvm::ArrayRef<T>;

template<typename T>
using MutableArrayRef = llvm::MutableArrayRef<T>;

}  // namespace oneflow

#endif

================================================
FILE: oneflow/core/common/auto_registration_factory.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
*/
#ifndef ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_
#define ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_

#include "oneflow/core/common/util.h"

namespace oneflow {

template<typename Key, typename Base, typename... Args>
struct AutoRegistrationFactory {
 public:
  using Creator = std::function<Base*(Args&&...)>;

  template<typename Derived>
  struct RawRegisterType {
    RawRegisterType(Key k) {
      CHECK((AutoRegistrationFactory<Key, Base, Args...>::Get()
                 .mutable_creators()
                 ->emplace(k, [](Args&&...) { return new Derived; })
                 .second))
          << k;
    }
  };

  struct CreatorRegisterType {
    CreatorRegisterType(Key k, Creator v) {
      CHECK((AutoRegistrationFactory<Key, Base, Args...>::Get()
                 .mutable_creators()
                 ->emplace(k, v)
                 .second))
          << k;
    }
  };

  Base* New(Key k, Args&&... args) const {
    auto creators_it = creators().find(k);
    CHECK(creators_it != creators().end())
        << "Unregistered: key: " << k << " Base type name:" << typeid(Base).name()
        << " Key type name" << typeid(Key).name();
    return creators_it->second(std::forward<Args>(args)...);
  }

  bool IsClassRegistered(Key k, Args&&... args) const {
    return creators().find(k) != creators().end();
  }

  static AutoRegistrationFactory<Key, Base, Args...>& Get() {
    static AutoRegistrationFactory<Key, Base, Args...> obj;
    return obj;
  }

 private:
  std::unique_ptr<HashMap<Key, Creator>> creators_;

  bool has_creators() const { return creators_.get() != nullptr; }

  const HashMap<Key, Creator>& creators() const {
    CHECK(has_creators()) << "Unregistered key type: " << typeid(Key).name()
                          << "Base type name:" << typeid(Base).name();
    return *creators_.get();
  }

  HashMap<Key, Creator>* mutable_creators() {
    if (!creators_) { creators_.reset(new HashMap<Key, Creator>); }
    return creators_.get();
  }
};

#define REGISTER_VAR_NAME OF_PP_CAT(g_registry_var, __COUNTER__)

#define REGISTER_CLASS(Key, k, Base, Derived) \
  static AutoRegistrationFactory<Key, Base>::RawRegisterType<Derived> REGISTER_VAR_NAME(k)

#define REGISTER_CLASS_WITH_ARGS(Key, k, Base, Derived, ...)                       \
  static AutoRegistrationFactory<Key, Base, __VA_ARGS__>::RawRegisterType<Derived> \
      REGISTER_VAR_NAME(k)

#define REGISTER_CLASS_CREATOR(Key, k, Base, f, ...)                            \
  static AutoRegistrationFactory<Key, Base, ##__VA_ARGS__>::CreatorRegisterType \
      REGISTER_VAR_NAME(k, f)

template<typename Key, typename Base, typename... Args>
inline Base* NewObj(Key k, Args&&... args) {
  return AutoRegistrationFactory<Key, Base, Args...>::Get().New(k, std::forward<Args>(args)...);
}

template<typename Key, typename Base, typename... Args>
inline std::unique_ptr<Base> NewObjUniquePtr(Key k, Args&&... args) {
  return std::unique_ptr<Base>(
      AutoRegistrationFactory<Key, Base, Args...>::Get().New(k, std::forward<Args>(args)...));
}
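// Usage sketch (editorial note, not part of the OneFlow sources; `Pet`, `Dog`, and the "dog"
// key are hypothetical names): a derived class registers itself under a key at
// static-initialization time, and callers later instantiate it by key without naming the
// concrete type.
//
//   struct Pet {
//     virtual ~Pet() = default;
//     virtual std::string Sound() const = 0;
//   };
//   struct Dog final : public Pet {
//     std::string Sound() const override { return "woof"; }
//   };
//   REGISTER_CLASS(std::string, std::string("dog"), Pet, Dog);
//
//   // Later, typically in another translation unit:
//   std::unique_ptr<Pet> pet = NewObjUniquePtr<std::string, Pet>(std::string("dog"));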
template<typename Key, typename Base, typename... Args>
inline bool IsClassRegistered(Key k, Args&&... args) {
  return AutoRegistrationFactory<Key, Base, Args...>::Get().IsClassRegistered(
      k, std::forward<Args>(args)...);
}

}  // namespace oneflow

#endif  // ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_

================================================
FILE: oneflow/core/common/balanced_splitter.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
*/
#include "oneflow/core/common/balanced_splitter.h"

namespace oneflow {

BalancedSplitter::BalancedSplitter(int64_t total_num, int64_t split_num) {
  base_part_size_ = total_num / split_num;
  base_begin_idx_ = total_num % split_num;
  split_num_ = split_num;
  CHECK_EQ(this->total_num(), total_num);
}

int64_t BalancedSplitter::total_num() const { return At(split_num_ - 1).end(); }

Range BalancedSplitter::At(int64_t idx) const {
  CHECK_LT(idx, split_num_);
  int64_t left_bound = -1;
  int64_t right_bound = -1;
  if (idx < base_begin_idx_) {
    left_bound = (base_part_size_ + 1) * idx;
    right_bound = left_bound + (base_part_size_ + 1);
  } else {
    left_bound =
        (base_part_size_ + 1) * base_begin_idx_ + base_part_size_ * (idx - base_begin_idx_);
    right_bound = left_bound + base_part_size_;
  }
  return Range(left_bound, right_bound);
}

Range BalancedSplitter::At(int64_t first_idx, int64_t last_idx) const {
  CHECK_LE(first_idx, last_idx);
  CHECK_LT(last_idx, split_num_);
  Range first_range = At(first_idx);
  Range last_range = At(last_idx);
  return Range(first_range.begin(), last_range.end());
}

int64_t BalancedSplitter::GetRangeIndexForVal(int64_t value) const {
  CHECK_GE(value, 0);
  CHECK_LT(value, total_num());
  int64_t base_size = (base_part_size_ + 1) * base_begin_idx_;
  if (value < base_size) {
    return value / (base_part_size_ + 1);
  } else {
    return base_begin_idx_ + (value - base_size) / base_part_size_;
  }
}

}  // namespace oneflow

================================================
FILE: oneflow/core/common/balanced_splitter.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
*/ #ifndef ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_ #define ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_ #include #include "oneflow/core/common/range.h" #include "oneflow/core/common/util.h" namespace oneflow { // For example // BalancedSplitter splitter(20, 6) // the result of splitter.At // 0 [0, 4) // 1 [4, 8) // 2 [8, 11) // 3 [11, 14) // 4 [14, 17) // 5 [17, 20) class BalancedSplitter final { public: // OF_DISALLOW_COPY_AND_MOVE(BalancedSplitter); BalancedSplitter() = delete; ~BalancedSplitter() = default; BalancedSplitter(int64_t total_num, int64_t split_num); Range At(int64_t idx) const; Range At(int64_t first_idx, int64_t last_idx) const; // Get the range index number of a value. int64_t GetRangeIndexForVal(int64_t value) const; int64_t total_num() const; private: int64_t base_part_size_; int64_t base_begin_idx_; int64_t split_num_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_ ================================================ FILE: oneflow/core/common/balanced_splitter_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/common/balanced_splitter.h" namespace oneflow { TEST(BalancedSplitter, split_20_to_6_part) { BalancedSplitter splitter(20, 6); ASSERT_TRUE(splitter.At(0) == Range(0, 4)); ASSERT_TRUE(splitter.At(1) == Range(4, 8)); ASSERT_TRUE(splitter.At(2) == Range(8, 11)); ASSERT_TRUE(splitter.At(3) == Range(11, 14)); ASSERT_TRUE(splitter.At(4) == Range(14, 17)); ASSERT_TRUE(splitter.At(5) == Range(17, 20)); } TEST(BalancedSplitter, split_2_to_3_part) { BalancedSplitter splitter(2, 3); ASSERT_TRUE(splitter.At(0) == Range(0, 1)); ASSERT_TRUE(splitter.At(1) == Range(1, 2)); ASSERT_TRUE(splitter.At(2) == Range(2, 2)); } TEST(BalancedSplitter, GetRangeIndexForVal) { const size_t total_num = 937; const size_t split_num = 11; BalancedSplitter bs(total_num, split_num); ASSERT_TRUE(bs.total_num() == total_num); for (size_t i = 0; i < split_num; ++i) { Range range = bs.At(i); for (size_t value = range.begin(); value < range.end(); ++value) { ASSERT_TRUE(bs.GetRangeIndexForVal(value) == i); } } } } // namespace oneflow ================================================ FILE: oneflow/core/common/bfloat16.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_BFLOAT16_H_ #define ONEFLOW_CORE_COMMON_BFLOAT16_H_ #include #include #include #include namespace oneflow { #if defined(__CUDACC__) #define OF_DEVICE_FUNCTION __device__ __host__ __forceinline__ #else #define OF_DEVICE_FUNCTION inline #endif struct alignas(2) bfloat16 { uint16_t x; bfloat16() = default; bfloat16(const bfloat16& o) = default; bfloat16& operator=(const bfloat16& o) = default; bfloat16(bfloat16&& o) = default; bfloat16& operator=(bfloat16&& o) = default; ~bfloat16() = default; struct from_bits_t {}; static constexpr inline from_bits_t from_bits() { return from_bits_t(); } constexpr inline bfloat16(unsigned short bits, from_bits_t) : x(bits){}; // reference: pytorch/c10/util/BFloat16.h // https://github.com/pytorch/pytorch/blob/release/1.12/c10/util/BFloat16.h bfloat16(float value) { if (std::isnan(value)) { x = 0x7FC0; } else { union { uint32_t U32; float F32; }; F32 = value; uint32_t rounding_bias = ((U32 >> 16) & 1) + 0x7FFFU; x = static_cast((U32 + rounding_bias) >> 16); } } inline operator float() const { float res = 0; uint32_t tmp = x; tmp <<= 16; std::memcpy(&res, &tmp, sizeof(tmp)); return res; } inline bool operator==(const bfloat16& other) const { return x == other.x; } inline explicit operator bool() const { return (x & 0x7fff) != 0; } inline explicit operator int8_t() const { return static_cast(static_cast(*this)); } inline explicit operator uint8_t() const { return static_cast(static_cast(*this)); } inline explicit operator int16_t() const { return static_cast(static_cast(*this)); } inline explicit operator uint16_t() const { return static_cast(static_cast(*this)); } inline explicit operator int32_t() const { return static_cast(static_cast(*this)); } inline explicit operator uint32_t() const { return static_cast(static_cast(*this)); } inline explicit operator int64_t() const { return static_cast(static_cast(*this)); } inline explicit operator uint64_t() const { return static_cast(static_cast(*this)); } inline explicit operator double() const { return static_cast(static_cast(*this)); } }; // Arithmetic inline bfloat16 operator+(const bfloat16& a, const bfloat16& b) { return static_cast(a) + static_cast(b); } inline bfloat16 operator-(const bfloat16& a, const bfloat16& b) { return static_cast(a) - static_cast(b); } inline bfloat16 operator*(const bfloat16& a, const bfloat16& b) { return static_cast(a) * static_cast(b); } inline bfloat16 operator/(const bfloat16& a, const bfloat16& b) { return static_cast(a) / static_cast(b); } inline bfloat16 operator-(const bfloat16& a) { bfloat16 output; output.x = a.x ^ 0x8000U; return output; } inline bfloat16& operator+=(bfloat16& a, const bfloat16& b) { a = a + b; return a; } inline bfloat16& operator-=(bfloat16& a, const bfloat16& b) { a = a - b; return a; } inline bfloat16& operator*=(bfloat16& a, const bfloat16& b) { a = a * b; return a; } inline bfloat16& operator/=(bfloat16& a, const bfloat16& b) { a = a / b; return a; } inline bfloat16& operator|(bfloat16& a, const bfloat16& b) { a.x = a.x | b.x; return a; } inline bfloat16& operator^(bfloat16& a, const bfloat16& b) { a.x = a.x ^ b.x; return a; } inline bfloat16& operator&(bfloat16& a, const bfloat16& b) { a.x = a.x & b.x; return a; } // Arithmetic with floats inline float operator+(bfloat16 a, float b) { return static_cast(a) + b; } inline float operator-(bfloat16 a, float b) { return static_cast(a) - b; } inline float operator*(bfloat16 a, float b) { return static_cast(a) * b; } inline float operator/(bfloat16 a, float b) { return 
static_cast(a) / b; } inline float operator+(float a, bfloat16 b) { return a + static_cast(b); } inline float operator-(float a, bfloat16 b) { return a - static_cast(b); } inline float operator*(float a, bfloat16 b) { return a * static_cast(b); } inline float operator/(float a, bfloat16 b) { return a / static_cast(b); } inline float& operator+=(float& a, const bfloat16& b) { return a += static_cast(b); } inline float& operator-=(float& a, const bfloat16& b) { return a -= static_cast(b); } inline float& operator*=(float& a, const bfloat16& b) { return a *= static_cast(b); } inline float& operator/=(float& a, const bfloat16& b) { return a /= static_cast(b); } // Arithmetic with doubles inline double operator+(bfloat16 a, double b) { return static_cast(a) + b; } inline double operator-(bfloat16 a, double b) { return static_cast(a) - b; } inline double operator*(bfloat16 a, double b) { return static_cast(a) * b; } inline double operator/(bfloat16 a, double b) { return static_cast(a) / b; } inline double operator+(double a, bfloat16 b) { return a + static_cast(b); } inline double operator-(double a, bfloat16 b) { return a - static_cast(b); } inline double operator*(double a, bfloat16 b) { return a * static_cast(b); } inline double operator/(double a, bfloat16 b) { return a / static_cast(b); } // Arithmetic with int32_t inline bfloat16 operator+(bfloat16 a, int32_t b) { return a + static_cast(b); } inline bfloat16 operator-(bfloat16 a, int32_t b) { return a - static_cast(b); } inline bfloat16 operator*(bfloat16 a, int32_t b) { return a * static_cast(b); } inline bfloat16 operator/(bfloat16 a, int32_t b) { return a / static_cast(b); } inline bfloat16 operator+(int32_t a, bfloat16 b) { return static_cast(a) + b; } inline bfloat16 operator-(int32_t a, bfloat16 b) { return static_cast(a) - b; } inline bfloat16 operator*(int32_t a, bfloat16 b) { return static_cast(a) * b; } inline bfloat16 operator/(int32_t a, bfloat16 b) { return static_cast(a) / b; } // Arithmetic with int64_t inline bfloat16 operator+(bfloat16 a, int64_t b) { return a + static_cast(b); } inline bfloat16 operator-(bfloat16 a, int64_t b) { return a - static_cast(b); } inline bfloat16 operator*(bfloat16 a, int64_t b) { return a * static_cast(b); } inline bfloat16 operator/(bfloat16 a, int64_t b) { return a / static_cast(b); } inline bfloat16 operator+(int64_t a, bfloat16 b) { return static_cast(a) + b; } inline bfloat16 operator-(int64_t a, bfloat16 b) { return static_cast(a) - b; } inline bfloat16 operator*(int64_t a, bfloat16 b) { return static_cast(a) * b; } inline bfloat16 operator/(int64_t a, bfloat16 b) { return static_cast(a) / b; } // Comparison operators inline bool operator>(bfloat16& lhs, bfloat16& rhs) { return static_cast(lhs) > static_cast(rhs); } inline bool operator>=(bfloat16& lhs, bfloat16& rhs) { return static_cast(lhs) >= static_cast(rhs); } inline bool operator<(bfloat16& lhs, bfloat16& rhs) { return static_cast(lhs) < static_cast(rhs); } inline bool operator<=(bfloat16& lhs, bfloat16& rhs) { return static_cast(lhs) <= static_cast(rhs); } inline bool operator==(bfloat16& lhs, bfloat16& rhs) { return static_cast(lhs) == static_cast(rhs); } inline bool operator!=(bfloat16& lhs, bfloat16& rhs) { return static_cast(lhs) != static_cast(rhs); } } // namespace oneflow namespace std { inline bool isnan(const oneflow::bfloat16& value) { return (value.x & 0x7FFFU) > 0x07F80U; } inline bool isinf(const oneflow::bfloat16& value) { return value.x == 0x07F80U; } inline bool isfinite(const oneflow::bfloat16& value) { return 
!isinf(value) && !isnan(value); }

template<>
class numeric_limits<oneflow::bfloat16> {
 public:
  static constexpr bool is_signed = true;
  static constexpr bool is_specialized = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = true;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = true;
  static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
  static constexpr auto has_denorm_loss = numeric_limits<float>::has_denorm_loss;
  static constexpr auto round_style = numeric_limits<float>::round_style;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 8;
  static constexpr int digits10 = 2;
  static constexpr int max_digits10 = 4;
  static constexpr int radix = 2;
  static constexpr int min_exponent = -125;
  static constexpr int min_exponent10 = -37;
  static constexpr int max_exponent = 128;
  static constexpr int max_exponent10 = 38;
  static constexpr auto traps = numeric_limits<float>::traps;
  static constexpr auto tinyness_before = numeric_limits<float>::tinyness_before;

  static constexpr oneflow::bfloat16 min() {
    return oneflow::bfloat16(0x0080U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 lowest() {
    return oneflow::bfloat16(0xFF7FU, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 max() {
    return oneflow::bfloat16(0x7F7FU, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 epsilon() {
    return oneflow::bfloat16(0x3C00U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 round_error() {
    return oneflow::bfloat16(0x3F00U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 infinity() {
    return oneflow::bfloat16(0x7F80U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 quiet_NaN() {
    return oneflow::bfloat16(0x7FC0U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 signaling_NaN() {
    return oneflow::bfloat16(0x7F80U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 denorm_min() {
    return oneflow::bfloat16(0x0001U, oneflow::bfloat16::from_bits());
  }
};

}  // namespace std

#endif  // ONEFLOW_CORE_COMMON_BFLOAT16_H_
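// Worked example (editorial note, not part of the OneFlow sources): bfloat16 keeps float32's
// 8-bit exponent but only 7 mantissa bits, so roughly three significant decimal digits survive
// a round trip through the round-to-nearest-even conversion defined above.
//
//   oneflow::bfloat16 b(3.14159f);
//   float f = static_cast<float>(b);  // f == 3.140625f, the nearest representable bfloat16
//
//   // Bit-level construction is also possible; 0x3F80 is the upper half of float 1.0f:
//   oneflow::bfloat16 one(0x3F80, oneflow::bfloat16::from_bits());
//   float g = static_cast<float>(one);  // g == 1.0f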
================================================
FILE: oneflow/core/common/bfloat16_math.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_
#define ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_

#include "oneflow/core/common/bfloat16.h"

namespace std {

// reference: pytorch/c10/util/BFloat16-math.h
// https://github.com/pytorch/pytorch/blob/release/1.12/c10/util/BFloat16-math.h
inline oneflow::bfloat16 acos(oneflow::bfloat16 a) { return std::acos(static_cast<float>(a)); }
inline oneflow::bfloat16 asin(oneflow::bfloat16 a) { return std::asin(static_cast<float>(a)); }
inline oneflow::bfloat16 atan(oneflow::bfloat16 a) { return std::atan(static_cast<float>(a)); }
inline oneflow::bfloat16 erf(oneflow::bfloat16 a) { return std::erf(static_cast<float>(a)); }
inline oneflow::bfloat16 erfc(oneflow::bfloat16 a) { return std::erfc(static_cast<float>(a)); }
inline oneflow::bfloat16 exp(oneflow::bfloat16 a) { return std::exp(static_cast<float>(a)); }
inline oneflow::bfloat16 expm1(oneflow::bfloat16 a) { return std::expm1(static_cast<float>(a)); }
inline oneflow::bfloat16 log(oneflow::bfloat16 a) { return std::log(static_cast<float>(a)); }
inline oneflow::bfloat16 log10(oneflow::bfloat16 a) { return std::log10(static_cast<float>(a)); }
inline oneflow::bfloat16 log1p(oneflow::bfloat16 a) { return std::log1p(static_cast<float>(a)); }
inline oneflow::bfloat16 log2(oneflow::bfloat16 a) { return std::log2(static_cast<float>(a)); }
inline oneflow::bfloat16 ceil(oneflow::bfloat16 a) { return std::ceil(static_cast<float>(a)); }
inline oneflow::bfloat16 cos(oneflow::bfloat16 a) { return std::cos(static_cast<float>(a)); }
inline oneflow::bfloat16 floor(oneflow::bfloat16 a) { return std::floor(static_cast<float>(a)); }
inline oneflow::bfloat16 nearbyint(oneflow::bfloat16 a) {
  return std::nearbyint(static_cast<float>(a));
}
inline oneflow::bfloat16 sin(oneflow::bfloat16 a) { return std::sin(static_cast<float>(a)); }
inline oneflow::bfloat16 tan(oneflow::bfloat16 a) { return std::tan(static_cast<float>(a)); }
inline oneflow::bfloat16 sinh(oneflow::bfloat16 a) { return std::sinh(static_cast<float>(a)); }
inline oneflow::bfloat16 cosh(oneflow::bfloat16 a) { return std::cosh(static_cast<float>(a)); }
inline oneflow::bfloat16 tanh(oneflow::bfloat16 a) { return std::tanh(static_cast<float>(a)); }
inline oneflow::bfloat16 trunc(oneflow::bfloat16 a) { return std::trunc(static_cast<float>(a)); }
inline oneflow::bfloat16 lgamma(oneflow::bfloat16 a) { return std::lgamma(static_cast<float>(a)); }
inline oneflow::bfloat16 sqrt(oneflow::bfloat16 a) { return std::sqrt(static_cast<float>(a)); }
inline oneflow::bfloat16 rsqrt(oneflow::bfloat16 a) {
  return 1.0 / std::sqrt(static_cast<float>(a));
}
inline oneflow::bfloat16 abs(oneflow::bfloat16 a) { return std::abs(static_cast<float>(a)); }
inline oneflow::bfloat16 pow(oneflow::bfloat16 a, double b) {
  return std::pow(static_cast<float>(a), b);
}
inline oneflow::bfloat16 pow(oneflow::bfloat16 a, oneflow::bfloat16 b) {
  return std::pow(static_cast<float>(a), static_cast<float>(b));
}
inline oneflow::bfloat16 fmod(oneflow::bfloat16 a, oneflow::bfloat16 b) {
  return std::fmod(static_cast<float>(a), static_cast<float>(b));
}

}  // namespace std

#endif  // ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_

================================================
FILE: oneflow/core/common/bfloat16_test.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/common/bfloat16.h" #include "oneflow/core/common/bfloat16_math.h" namespace oneflow { namespace test { float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) { // reference: pytorch/c10/test/util/bfloat16_test.cpp // https://github.com/pytorch/pytorch/blob/release/1.12/c10/test/util/bfloat16_test.cpp uint32_t bytes = 0; bytes |= sign; bytes <<= 8; bytes |= exponent; bytes <<= 23; bytes |= fraction; float res = NAN; std::memcpy(&res, &bytes, sizeof(res)); return res; } TEST(BFLOAT16MATH, Add) { // 6.25 float input = float_from_bytes(0, 0, 0x40C80000U); // 7.25 float expected = float_from_bytes(0, 0, 0x40E80000U); bfloat16 b(input); b = b + 1; float res = static_cast(b); EXPECT_EQ(res, expected); } TEST(BFLOAT16MATH, Sub) { // 7.25 float input = float_from_bytes(0, 0, 0x40E80000U); // 6.25 float expected = float_from_bytes(0, 0, 0x40C80000U); bfloat16 b(input); b = b - 1; float res = static_cast(b); EXPECT_EQ(res, expected); } TEST(BFLOAT16MATH, Mul) { // 3.125 float input = float_from_bytes(0, 0, 0x40480000U); // 6.25 float expected = float_from_bytes(0, 0, 0x40C80000U); bfloat16 b(input); b = b * 2; float res = static_cast(b); EXPECT_EQ(res, expected); } TEST(BFLOAT16MATH, Div) { // 6.25 float input = float_from_bytes(0, 0, 0x40C80000U); // 3.125 float expected = float_from_bytes(0, 0, 0x40480000U); bfloat16 b(input); b = b / 2; float res = static_cast(b); EXPECT_EQ(res, expected); } TEST(BFLOAT16MATH, Log2) { // 16 float input = float_from_bytes(0, 0, 0x41800000U); // 4 float expected = float_from_bytes(0, 0, 0x40800000U); bfloat16 b(input); b = std::log2(b); float res = static_cast(b); EXPECT_EQ(res, expected); } TEST(BFLOAT16MATH, Log10) { // 100 float input = float_from_bytes(0, 0, 0x42C80000U); // 2 float expected = float_from_bytes(0, 0, 0x40000000U); bfloat16 b(input); b = std::log10(b); float res = static_cast(b); EXPECT_EQ(res, expected); } TEST(BFLOAT16MATH, Sqrt) { // 25 float input = float_from_bytes(0, 0, 0x41C80000U); // 5 float expected = float_from_bytes(0, 0, 0x40A00000U); bfloat16 b(input); b = std::sqrt(b); float res = static_cast(b); EXPECT_EQ(res, expected); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/common/blas.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_BLAS_H_ #define ONEFLOW_CORE_COMMON_BLAS_H_ #include #include #include #include "oneflow/core/common/cblas.h" #include "oneflow/core/common/preprocessor.h" namespace oneflow { #define BLAS_NAME_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dot) \ OF_PP_MAKE_TUPLE_SEQ(swap) \ OF_PP_MAKE_TUPLE_SEQ(copy) \ OF_PP_MAKE_TUPLE_SEQ(axpy) \ OF_PP_MAKE_TUPLE_SEQ(scal) \ OF_PP_MAKE_TUPLE_SEQ(gemv) \ OF_PP_MAKE_TUPLE_SEQ(gemm) \ OF_PP_MAKE_TUPLE_SEQ(gemmBatched) \ OF_PP_MAKE_TUPLE_SEQ(gemmStridedBatched) \ OF_PP_MAKE_TUPLE_SEQ(getrfBatched) \ OF_PP_MAKE_TUPLE_SEQ(getriBatched) #define CBLAS_TEMPLATE(name) \ template \ auto cblas_##name(Args&&... args) \ ->typename std::enable_if::value, \ decltype(cblas_##s##name(std::forward(args)...))>::type { \ return cblas_##s##name(std::forward(args)...); \ } \ template \ auto cblas_##name(Args&&... args) \ ->typename std::enable_if::value, \ decltype(cblas_##d##name(std::forward(args)...))>::type { \ return cblas_##d##name(std::forward(args)...); \ } \ template \ auto cblas_##name(Args&&... args) \ ->typename std::enable_if>::value, \ decltype(cblas_##c##name(std::forward(args)...))>::type { \ return cblas_##c##name(std::forward(args)...); \ } \ template \ auto cblas_##name(Args&&... args) \ ->typename std::enable_if>::value, \ decltype(cblas_##z##name(std::forward(args)...))>::type { \ return cblas_##z##name(std::forward(args)...); \ } OF_PP_FOR_EACH_TUPLE(CBLAS_TEMPLATE, BLAS_NAME_SEQ); #undef CBLAS_TEMPLATE #undef BLAS_NAME_SEQ } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_BLAS_H_ ================================================ FILE: oneflow/core/common/blocking_counter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/foreign_lock_helper.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { int64_t BlockingCounter::Increase() { std::unique_lock lck(mtx_); CHECK_GT(cnt_val_, 0); cnt_val_ += 1; return cnt_val_; } int64_t BlockingCounter::Decrease() { std::unique_lock lck(mtx_); cnt_val_ -= 1; if (cnt_val_ == 0) { cond_.notify_all(); } return cnt_val_; } Maybe BlockingCounter::WaitUntilCntEqualZero(size_t timeout_seconds) { return Singleton::Get()->WithScopedRelease([&, this]() -> Maybe { std::chrono::duration seconds(timeout_seconds); std::unique_lock lck(mtx_); CHECK_OR_RETURN(cond_.wait_for(lck, seconds, [this]() { return cnt_val_ == 0; })) << Error::TimeoutError(); return Maybe::Ok(); }); } void BlockingCounter::WaitForeverUntilCntEqualZero() { CHECK_JUST(WaitUntilCntEqualZero([]() -> Maybe { return false; })); } Maybe BlockingCounter::WaitUntilCntEqualZero( const std::function()>& StopWaitingAfterTimeout) { while (true) { auto status = TRY(WaitUntilCntEqualZero(EnvInteger())); if (status.IsOk()) { return status; } if (!status.error()->has_timeout_error()) { return status; } if (JUST(StopWaitingAfterTimeout())) { return status; } } UNIMPLEMENTED_THEN_RETURN(); } } // namespace oneflow ================================================ FILE: oneflow/core/common/blocking_counter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_ #define ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" namespace oneflow { class BlockingCounter final { public: OF_DISALLOW_COPY_AND_MOVE(BlockingCounter); BlockingCounter() = delete; ~BlockingCounter() = default; BlockingCounter(int64_t cnt_val) { cnt_val_ = cnt_val; } int64_t Increase(); int64_t Decrease(); void WaitForeverUntilCntEqualZero(); Maybe WaitUntilCntEqualZero(size_t timeout_seconds); Maybe WaitUntilCntEqualZero(const std::function()>& StopWaitingAfterTimeout); private: std::mutex mtx_; std::condition_variable cond_; int64_t cnt_val_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_ ================================================ FILE: oneflow/core/common/blocking_then_busy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
================================================
FILE: oneflow/core/common/blocking_then_busy.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_
#define ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/notifier.h"
#include "oneflow/core/common/spin_counter.h"

namespace oneflow {

class BlockingThenBusy final {
 public:
  BlockingThenBusy(const BlockingThenBusy&) = delete;
  BlockingThenBusy(BlockingThenBusy&&) = delete;

  constexpr static int kCnt = 1;
  BlockingThenBusy() : notifier_(), spin_counter_(kCnt) {}

  Notifier* mut_notifier() { return &notifier_; }
  SpinCounter* mut_spin_counter() { return &spin_counter_; }

  void Reset() { mut_spin_counter()->Reset(kCnt); }

  Maybe<void> WaitUntilCntEqualZero(const std::function<Maybe<bool>()>& StopAfterTimeout) {
    JUST(notifier_.TimedWaitAndClearNotifiedCnt(StopAfterTimeout));
    JUST(spin_counter_.WaitUntilCntEqualZero());
    return Maybe<void>::Ok();
  }

 private:
  Notifier notifier_;
  SpinCounter spin_counter_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_

================================================
FILE: oneflow/core/common/buffer.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BUFFER_H_
#define ONEFLOW_CORE_COMMON_BUFFER_H_

#include "oneflow/core/common/util.h"

namespace oneflow {

enum BufferStatus { kBufferStatusSuccess = 0, kBufferStatusErrorClosed, kBufferStatusEmpty };

template<typename T>
class Buffer final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Buffer);
  Buffer(size_t max_len) : max_len_(max_len), is_closed_(false) {}
  ~Buffer() = default;

  template<typename U>
  BufferStatus Push(U&& item);
  BufferStatus Pull(T* item);
  BufferStatus TryReceive(T* item);
  void Close();

 private:
  std::queue<T> queue_;
  mutable std::mutex mutex_;
  size_t max_len_;
  bool is_closed_;
  std::condition_variable cond_;
};

template<typename T>
template<typename U>
BufferStatus Buffer<T>::Push(U&& item) {
  std::unique_lock<std::mutex> lock(mutex_);
  cond_.wait(lock, [this]() { return queue_.size() < max_len_ || is_closed_; });
  if (is_closed_) { return kBufferStatusErrorClosed; }
  queue_.push(std::forward<U>(item));
  cond_.notify_one();
  return kBufferStatusSuccess;
}

template<typename T>
BufferStatus Buffer<T>::Pull(T* item) {
  std::unique_lock<std::mutex> lock(mutex_);
  cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });
  if (queue_.empty()) { return kBufferStatusErrorClosed; }
  *item = std::move(queue_.front());
  queue_.pop();
  if (queue_.size() < max_len_) { cond_.notify_all(); }
  return kBufferStatusSuccess;
}

template<typename T>
BufferStatus Buffer<T>::TryReceive(T* item) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (queue_.empty()) { return is_closed_ ? kBufferStatusErrorClosed : kBufferStatusEmpty; }
  *item = std::move(queue_.front());
  queue_.pop();
  if (queue_.size() < max_len_) { cond_.notify_all(); }
  return kBufferStatusSuccess;
}

template<typename T>
void Buffer<T>::Close() {
  std::unique_lock<std::mutex> lock(mutex_);
  is_closed_ = true;
  cond_.notify_all();
}

}  // namespace oneflow

#endif  // ONEFLOW_CORE_COMMON_BUFFER_H_
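Buffer<T> is a bounded blocking queue: Push blocks while the buffer is full, Pull blocks while it is empty, and Close wakes every waiter with kBufferStatusErrorClosed once the queue drains. A small producer/consumer sketch (illustrative only, not part of the header):

  #include <thread>
  #include "oneflow/core/common/buffer.h"

  void BufferExample() {
    oneflow::Buffer<int> buffer(/*max_len=*/8);
    std::thread producer([&buffer]() {
      for (int i = 0; i < 100; ++i) { buffer.Push(i); }  // blocks whenever 8 items are queued
      buffer.Close();
    });
    std::thread consumer([&buffer]() {
      int item = 0;
      while (buffer.Pull(&item) == oneflow::kBufferStatusSuccess) { /* use item */ }
    });
    producer.join();
    consumer.join();
  }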
std::string& op_name) { return kOutputBufferNamePrefix + job_name + "-" + op_name; } inline std::string GetSourceTickBufferName(const std::string& job_name) { return kSourceTickBufferNamePrefix + job_name; } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_ ================================================ FILE: oneflow/core/common/cached_caller.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/common/cached_caller.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" namespace oneflow { bool IsThreadLocalCacheEnabled() { if (Singleton::Get() == nullptr) { return true; } return Singleton::Get()->enable_thread_local_cache(); } } // namespace oneflow ================================================ FILE: oneflow/core/common/cached_caller.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_CACHED_CALLER_H_ #define ONEFLOW_CORE_COMMON_CACHED_CALLER_H_ #include #include #include #include "oneflow/core/common/function_traits.h" #include "oneflow/core/common/hash_eq_trait_ptr.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/tuple_hash.h" // gcc 11 falsely reports error: // ‘void operator delete(void*, std::size_t)’ called on unallocated object ‘cache’ // However, `DeleteAndClear` is only called after `cache` is allocated in // if (cache == nullptr) block. // The reason not to use #pragma GCC diagnostic push/pop is that gcc reports // the error on the caller of `ThreadLocalCachedCall`. // TODO: replace ThreadLocalCachedCall with ThreadLocalCached decorator? 
#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 11 #pragma GCC diagnostic ignored "-Wfree-nonheap-object" #endif namespace oneflow { template void DeleteAndClear(T** ptr, size_t obj_cnt) { static const size_t kThreshold = 4096; if (obj_cnt <= kThreshold) { delete ptr; } else { std::thread([](T* ptr) { delete ptr; }, *ptr); } *ptr = nullptr; } bool IsThreadLocalCacheEnabled(); template< typename F, typename Ret = typename function_traits::return_type, typename RawArg = typename std::tuple_element<0, typename function_traits::args_type>::type, typename Arg = typename std::remove_const::type>::type> Ret ThreadLocalCachedCall(size_t max_size, F f, const Arg& arg) { if (IsThreadLocalCacheEnabled() == false) { return f(arg); } using HashMap = std::unordered_map, Ret>; using KeyStorage = std::list>; static thread_local HashMap* cache = nullptr; static thread_local KeyStorage* key_storage = nullptr; if (cache != nullptr && cache->size() >= max_size) { DeleteAndClear(&cache, cache->size()); DeleteAndClear(&key_storage, cache->size()); } if (cache == nullptr) { cache = new HashMap(); key_storage = new KeyStorage(); } size_t hash_value = std::hash()(arg); { HashEqTraitPtr ptr_wrapper(&arg, hash_value); const auto& iter = cache->find(ptr_wrapper); if (iter != cache->end()) { return iter->second; } } Arg* new_arg = new Arg(arg); key_storage->emplace_back(new_arg); HashEqTraitPtr ptr_wrapper(new_arg, hash_value); return cache->emplace(ptr_wrapper, f(*new_arg)).first->second; } template< typename F, typename Ret = typename function_traits::return_type, typename RawArg = typename std::tuple_element<0, typename function_traits::args_type>::type, typename Arg = typename std::remove_const::type>::type> std::function WithResultCached(F f) { auto cache = std::make_shared>(); return [cache, f](const Arg& arg) -> Ret { const auto& iter = cache->find(arg); if (iter != cache->end()) { return iter->second; } return cache->emplace(arg, f(arg)).first->second; }; } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_CACHED_CALLER_H_ ================================================ FILE: oneflow/core/common/cblas.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_CBLAS_H_ #define ONEFLOW_CORE_COMMON_CBLAS_H_ #include /* * Enumerated and derived types */ #define CBLAS_INDEX size_t /* this may vary between platforms */ enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 }; enum CBLAS_TRANSPOSE { CblasNoTrans = 111, CblasTrans = 112, CblasConjTrans = 113 }; enum CBLAS_UPLO { CblasUpper = 121, CblasLower = 122 }; enum CBLAS_DIAG { CblasNonUnit = 131, CblasUnit = 132 }; enum CBLAS_SIDE { CblasLeft = 141, CblasRight = 142 }; #ifdef __cplusplus extern "C" { #endif /* * =========================================================================== * Prototypes for level 1 BLAS functions (complex are recast as routines) * =========================================================================== */ float cblas_sdsdot(const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY); double cblas_dsdot(const int N, const float* X, const int incX, const float* Y, const int incY); float cblas_sdot(const int N, const float* X, const int incX, const float* Y, const int incY); double cblas_ddot(const int N, const double* X, const int incX, const double* Y, const int incY); /* * Functions having prefixes Z and C only */ void cblas_cdotu_sub(const int N, const void* X, const int incX, const void* Y, const int incY, void* dotu); void cblas_cdotc_sub(const int N, const void* X, const int incX, const void* Y, const int incY, void* dotc); void cblas_zdotu_sub(const int N, const void* X, const int incX, const void* Y, const int incY, void* dotu); void cblas_zdotc_sub(const int N, const void* X, const int incX, const void* Y, const int incY, void* dotc); /* * Functions having prefixes S D SC DZ */ float cblas_snrm2(const int N, const float* X, const int incX); float cblas_sasum(const int N, const float* X, const int incX); double cblas_dnrm2(const int N, const double* X, const int incX); double cblas_dasum(const int N, const double* X, const int incX); float cblas_scnrm2(const int N, const void* X, const int incX); float cblas_scasum(const int N, const void* X, const int incX); double cblas_dznrm2(const int N, const void* X, const int incX); double cblas_dzasum(const int N, const void* X, const int incX); /* * Functions having standard 4 prefixes (S D C Z) */ CBLAS_INDEX cblas_isamax(const int N, const float* X, const int incX); CBLAS_INDEX cblas_idamax(const int N, const double* X, const int incX); CBLAS_INDEX cblas_icamax(const int N, const void* X, const int incX); CBLAS_INDEX cblas_izamax(const int N, const void* X, const int incX); /* * =========================================================================== * Prototypes for level 1 BLAS routines * =========================================================================== */ /* * Routines with standard 4 prefixes (s, d, c, z) */ void cblas_sswap(const int N, float* X, const int incX, float* Y, const int incY); void cblas_scopy(const int N, const float* X, const int incX, float* Y, const int incY); void cblas_saxpy(const int N, const float alpha, const float* X, const int incX, float* Y, const int incY); void cblas_dswap(const int N, double* X, const int incX, double* Y, const int incY); void cblas_dcopy(const int N, const double* X, const int incX, double* Y, const int incY); void cblas_daxpy(const int N, const double alpha, const double* X, const int incX, double* Y, const int incY); void cblas_cswap(const int N, void* X, const int incX, void* Y, const int incY); void cblas_ccopy(const int N, const void* X, const int incX, void* Y, const int 
incY); void cblas_caxpy(const int N, const void* alpha, const void* X, const int incX, void* Y, const int incY); void cblas_zswap(const int N, void* X, const int incX, void* Y, const int incY); void cblas_zcopy(const int N, const void* X, const int incX, void* Y, const int incY); void cblas_zaxpy(const int N, const void* alpha, const void* X, const int incX, void* Y, const int incY); /* * Routines with S and D prefix only */ void cblas_srotg(float* a, float* b, float* c, float* s); void cblas_srotmg(float* d1, float* d2, float* b1, const float b2, float* P); void cblas_srot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s); void cblas_srotm(const int N, float* X, const int incX, float* Y, const int incY, const float* P); void cblas_drotg(double* a, double* b, double* c, double* s); void cblas_drotmg(double* d1, double* d2, double* b1, const double b2, double* P); void cblas_drot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s); void cblas_drotm(const int N, double* X, const int incX, double* Y, const int incY, const double* P); /* * Routines with S D C Z CS and ZD prefixes */ void cblas_sscal(const int N, const float alpha, float* X, const int incX); void cblas_dscal(const int N, const double alpha, double* X, const int incX); void cblas_cscal(const int N, const void* alpha, void* X, const int incX); void cblas_zscal(const int N, const void* alpha, void* X, const int incX); void cblas_csscal(const int N, const float alpha, void* X, const int incX); void cblas_zdscal(const int N, const double alpha, void* X, const int incX); /* * =========================================================================== * Prototypes for level 2 BLAS * =========================================================================== */ /* * Routines with standard 4 prefixes (S, D, C, Z) */ void cblas_sgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY); void cblas_sgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const int KL, const int KU, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY); void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const float* A, const int lda, float* X, const int incX); void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const int K, const float* A, const int lda, float* X, const int incX); void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const float* Ap, float* X, const int incX); void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const float* A, const int lda, float* X, const int incX); void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const int K, const float* A, const int lda, float* X, const int incX); void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE 
TransA, const enum CBLAS_DIAG Diag, const int N, const float* Ap, float* X, const int incX); void cblas_dgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY); void cblas_dgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const int KL, const int KU, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY); void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const double* A, const int lda, double* X, const int incX); void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const int K, const double* A, const int lda, double* X, const int incX); void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const double* Ap, double* X, const int incX); void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const double* A, const int lda, double* X, const int incX); void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const int K, const double* A, const int lda, double* X, const int incX); void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const double* Ap, double* X, const int incX); void cblas_cgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const void* alpha, const void* A, const int lda, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_cgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const int KL, const int KU, const void* alpha, const void* A, const int lda, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const void* A, const int lda, void* X, const int incX); void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const int K, const void* A, const int lda, void* X, const int incX); void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const void* Ap, void* X, const int incX); void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const void* A, const int lda, void* X, const int incX); void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const int K, const void* A, const int lda, void* X, const int incX); void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, 
const void* Ap, void* X, const int incX); void cblas_zgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const void* alpha, const void* A, const int lda, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_zgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const int KL, const int KU, const void* alpha, const void* A, const int lda, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const void* A, const int lda, void* X, const int incX); void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const int K, const void* A, const int lda, void* X, const int incX); void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const void* Ap, void* X, const int incX); void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const void* A, const int lda, void* X, const int incX); void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const int K, const void* A, const int lda, void* X, const int incX); void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N, const void* Ap, void* X, const int incX); /* * Routines with S and D prefixes only */ void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY); void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY); void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const float alpha, const float* Ap, const float* X, const int incX, const float beta, float* Y, const int incY); void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A, const int lda); void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const float alpha, const float* X, const int incX, float* A, const int lda); void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const float alpha, const float* X, const int incX, float* Ap); void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A, const int lda); void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A); void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int 
incY); void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY); void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const double alpha, const double* Ap, const double* X, const int incX, const double beta, double* Y, const int incY); void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A, const int lda); void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const double alpha, const double* X, const int incX, double* A, const int lda); void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const double alpha, const double* X, const int incX, double* Ap); void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A, const int lda); void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A); /* * Routines with C and Z prefixes only */ void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const void* alpha, const void* A, const int lda, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K, const void* alpha, const void* A, const int lda, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const void* alpha, const void* Ap, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha, const void* X, const int incX, const void* Y, const int incY, void* A, const int lda); void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha, const void* X, const int incX, const void* Y, const int incY, void* A, const int lda); void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const float alpha, const void* X, const int incX, void* A, const int lda); void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const float alpha, const void* X, const int incX, void* A); void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const void* alpha, const void* X, const int incX, const void* Y, const int incY, void* A, const int lda); void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const void* alpha, const void* X, const int incX, const void* Y, const int incY, void* Ap); void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const void* alpha, const void* A, const int lda, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K, const void* alpha, const void* A, const int lda, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_zhpmv(const enum CBLAS_ORDER order, const enum 
CBLAS_UPLO Uplo, const int N, const void* alpha, const void* Ap, const void* X, const int incX, const void* beta, void* Y, const int incY); void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha, const void* X, const int incX, const void* Y, const int incY, void* A, const int lda); void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha, const void* X, const int incX, const void* Y, const int incY, void* A, const int lda); void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const double alpha, const void* X, const int incX, void* A, const int lda); void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const double alpha, const void* X, const int incX, void* A); void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const void* alpha, const void* X, const int incX, const void* Y, const int incY, void* A, const int lda); void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const void* alpha, const void* X, const int incX, const void* Y, const int incY, void* Ap); /* * =========================================================================== * Prototypes for level 3 BLAS * =========================================================================== */ /* * Routines with standard 4 prefixes (S, D, C, Z) */ void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc); void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const int M, const int N, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc); void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha, const float* A, const int lda, const float beta, float* C, const int ldc); void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc); void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int M, const int N, const float alpha, const float* A, const int lda, float* B, const int ldb); void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int M, const int N, const float alpha, const float* A, const int lda, float* B, const int ldb); void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc); void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const int M, const int N, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc); void 
cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha, const double* A, const int lda, const double beta, double* C, const int ldc); void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc); void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int M, const int N, const double alpha, const double* A, const int lda, double* B, const int ldb); void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int M, const int N, const double alpha, const double* A, const int lda, double* B, const int ldb); void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const void* beta, void* C, const int ldc); void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const void* beta, void* C, const int ldc); void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha, const void* A, const int lda, const void* beta, void* C, const int ldc); void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const void* beta, void* C, const int ldc); void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha, const void* A, const int lda, void* B, const int ldb); void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha, const void* A, const int lda, void* B, const int ldb); void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const void* beta, void* C, const int ldc); void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const void* beta, void* C, const int ldc); void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha, const void* A, const int lda, const void* beta, void* C, const int ldc); void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha, const void* A, const int lda, const void* B, const int ldb, 
const void* beta, void* C, const int ldc); void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha, const void* A, const int lda, void* B, const int ldb); void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha, const void* A, const int lda, void* B, const int ldb); /* * Routines with prefixes C and Z only */ void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const void* beta, void* C, const int ldc); void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha, const void* A, const int lda, const float beta, void* C, const int ldc); void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const float beta, void* C, const int ldc); void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const void* beta, void* C, const int ldc); void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha, const void* A, const int lda, const double beta, void* C, const int ldc); void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc); void cblas_xerbla(int p, const char* rout, const char* form, ...); #ifdef __cplusplus } #endif #endif ================================================ FILE: oneflow/core/common/channel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_COMMON_CHANNEL_H_
#define ONEFLOW_CORE_COMMON_CHANNEL_H_

#include "oneflow/core/common/util.h"

namespace oneflow {

enum ChannelStatus { kChannelStatusSuccess = 0, kChannelStatusErrorClosed };

template<typename T>
class Channel final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Channel);
  Channel() : is_closed_(false) {}
  ~Channel() = default;

  template<typename U>
  ChannelStatus Send(U&& item);
  ChannelStatus Receive(T* item);
  ChannelStatus ReceiveMany(std::queue<T>* items);
  void Close();

 private:
  std::queue<T> queue_;
  std::mutex mutex_;
  bool is_closed_;
  std::condition_variable cond_;
};

template<typename T>
template<typename U>
ChannelStatus Channel<T>::Send(U&& item) {
  bool notify;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    if (is_closed_) { return kChannelStatusErrorClosed; }
    notify = queue_.empty();
    queue_.push(std::forward<U>(item));
  }
  if (notify) { cond_.notify_one(); }
  return kChannelStatusSuccess;
}

template<typename T>
ChannelStatus Channel<T>::Receive(T* item) {
  std::unique_lock<std::mutex> lock(mutex_);
  cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });
  if (queue_.empty()) { return kChannelStatusErrorClosed; }
  *item = std::move(queue_.front());
  queue_.pop();
  return kChannelStatusSuccess;
}

template<typename T>
ChannelStatus Channel<T>::ReceiveMany(std::queue<T>* items) {
  std::unique_lock<std::mutex> lock(mutex_);
  cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });
  if (queue_.empty()) { return kChannelStatusErrorClosed; }
  while (!queue_.empty()) {
    items->push(std::move(queue_.front()));
    queue_.pop();
  }
  return kChannelStatusSuccess;
}

template<typename T>
void Channel<T>::Close() {
  std::unique_lock<std::mutex> lock(mutex_);
  is_closed_ = true;
  cond_.notify_all();
}

}  // namespace oneflow

#endif  // ONEFLOW_CORE_COMMON_CHANNEL_H_
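Unlike Buffer, Channel is unbounded and Send never blocks on capacity; ReceiveMany drains everything queued at wake-up time in one call. A short sketch of the batch path (illustrative only; the test file below exercises Send/Receive, this one shows ReceiveMany):

  #include <queue>
  #include "oneflow/core/common/channel.h"

  void ChannelBatchExample() {
    oneflow::Channel<int> ch;
    ch.Send(1);
    ch.Send(2);
    std::queue<int> batch;
    if (ch.ReceiveMany(&batch) == oneflow::kChannelStatusSuccess) {
      // batch now holds every item that was queued (1 and 2 here), in FIFO order.
    }
    ch.Close();
  }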
*/ #include "gtest/gtest.h" #include "oneflow/core/common/channel.h" #include "oneflow/core/common/range.h" namespace oneflow { void CallFromSenderThread(Channel* channel, Range range) { for (int i = range.begin(); i < range.end(); ++i) { if (channel->Send(i) != kChannelStatusSuccess) { break; } } } void CallFromReceiverThread(std::vector* visit, Channel* channel) { int num = -1; int* num_ptr = # while (channel->Receive(num_ptr) == kChannelStatusSuccess) { ++visit->at(*num_ptr); } } TEST(Channel, 30sender40receiver) { Channel channel; std::vector senders; std::vector receivers; int sender_num = 30; int receiver_num = 40; int range_num = 200; std::vector> visits; for (int i = 0; i < receiver_num; ++i) { std::vector visit_i; for (int j = 0; j < range_num; j++) { visit_i.emplace_back(0); } visits.emplace_back(visit_i); } for (int i = 0; i < sender_num; ++i) { senders.emplace_back(CallFromSenderThread, &channel, Range(0, range_num)); } for (int i = 0; i < receiver_num; ++i) { receivers.emplace_back(CallFromReceiverThread, &visits[i], &channel); } for (std::thread& this_thread : senders) { this_thread.join(); } channel.Close(); for (std::thread& this_thread : receivers) { this_thread.join(); } for (int i = 0; i < range_num; ++i) { int visit_count = 0; for (int j = 0; j < receiver_num; j++) { visit_count += visits[j][i]; } ASSERT_EQ(visit_count, sender_num); } } } // namespace oneflow ================================================ FILE: oneflow/core/common/check.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/throw.h" namespace oneflow { void GLOGCHECK(bool value) { CHECK_OR_THROW(value); } void GLOGLOGFATAL(const char* error_msg) { LOG(FATAL) << error_msg; } } // namespace oneflow ================================================ FILE: oneflow/core/common/check.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // The functions in this header file are used to replace `CHECK` and `LOG(FATAL)` macros of glog // in those header files included by oneflow/core/common/throw.h, so those header files // do not need to include , and we can undef CHECK series macro of // glog in oneflow/core/common/throw.h and use another impl instead with less modification. 
namespace oneflow { void GLOGCHECK(bool); void GLOGLOGFATAL(const char*); } // namespace oneflow ================================================ FILE: oneflow/core/common/check_level.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/env_var/debug_mode.h" namespace oneflow { bool IsEnvEnabled(int32_t check_level) { static const int env_check_level = ParseIntegerFromEnv("ONEFLOW_CHECK_LEVEL", -1); static const bool env_debug_mode = IsInDebugMode(); return env_debug_mode || env_check_level >= check_level; } } // namespace oneflow ================================================ FILE: oneflow/core/common/check_level.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_ #define ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_ namespace oneflow { bool IsEnvEnabled(int32_t check_level); } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_ ================================================ FILE: oneflow/core/common/constant.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_CONSTANT_H_ #define ONEFLOW_CORE_COMMON_CONSTANT_H_ #include namespace oneflow { static const int64_t kInvalidSessionId = -1; static const std::string kNoPassTag = ""; static const std::string kMainOp = "main_op"; static const int64_t kMaxSplitAxis = 6; constexpr size_t kMaxNumDims = 8; static const std::string kAsymmetricCodeErrorMsg = "Maybe executing different code in different ranks, please check if the code is branched and " "operates on the global tensor."; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_CONSTANT_H_ ================================================ FILE: oneflow/core/common/container_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_ #define ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_ #include #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/type_traits.h" #include "oneflow/core/common/maybe.h" namespace oneflow { template scalar_or_const_ref_t MapAt(const MapT& map, const KeyT& key, const U& default_val) { const auto& iter = map.find(key); if (iter == map.end()) { return default_val; } return iter->second; } template Maybe> MapAt(const MapT& map, const KeyT& key) { const auto& iter = map.find(key); if constexpr (printable()) { CHECK_OR_RETURN(iter != map.end()) << "Key \"" << key << "\" not found"; } else { CHECK_OR_RETURN(iter != map.end()) << "MapAt failed, but the key is not printable. Please implement operator<< if you want to " "see the key in this error message."; } return iter->second; } template Maybe MapAt(MapT& map, const KeyT& key) { const auto& iter = map.find(key); if constexpr (printable()) { CHECK_OR_RETURN(iter != map.end()) << "Key \"" << key << "\" not found"; } else { CHECK_OR_RETURN(iter != map.end()) << "MapAt failed, but the key is not printable. 
Please implement operator<< if you want to " "see the key in this error message."; } return iter->second; } template Maybe> VectorAt(const VecT& vec, typename VecT::size_type index) { CHECK_LT_OR_RETURN(index, vec.size()); return vec[index]; } template Maybe VectorAt(VecT& vec, typename VecT::size_type index) { static_assert(!std::is_same::value, "VectorAt(vector&, size_t) is not supported."); CHECK_LT_OR_RETURN(index, vec.size()); return vec[index]; } template<> inline Maybe VectorAt(const std::vector& vec, typename std::vector::size_type index) { CHECK_LT_OR_RETURN(index, vec.size()); // convert vector bool proxy to bool return static_cast(vec[index]); } template std::string Join(const T& con, const std::string& delimiter) { std::ostringstream os; auto b = begin(con), e = end(con); if (b != e) { std::copy(b, prev(e), std::ostream_iterator(os, delimiter)); b = prev(e); } if (b != e) { os << *b; } return os.str(); } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_ ================================================ FILE: oneflow/core/common/container_util_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace test { TEST(VectorAt, write_int_vector) { std::vector vec = {1, 2, 3, 4, 5}; EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)), 2); EXPECT_EQ(CHECK_JUST(VectorAt(vec, 3)), 4); CHECK_JUST(VectorAt(vec, 1)) = 6; EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)), 6); CHECK_JUST(VectorAt(vec, 3)) = 8; EXPECT_EQ(CHECK_JUST(VectorAt(vec, 3)), 8); EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)), 1); EXPECT_EQ(CHECK_JUST(VectorAt(vec, 2)), 3); EXPECT_EQ(CHECK_JUST(VectorAt(vec, 4)), 5); } namespace { class A { public: explicit A(int a) : a(a) {} int a; }; } // namespace TEST(VectorAt, write_custom_class_vector) { std::vector vec = {A(1), A(2)}; EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)).a, 1); EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)).a, 2); CHECK_JUST(VectorAt(vec, 0)) = A(3); EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)).a, 3); CHECK_JUST(VectorAt(vec, 1)) = A(4); EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)).a, 4); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/common/cost_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_TIME_UTIL_H_ #define ONEFLOW_CORE_COMMON_TIME_UTIL_H_ #include #include #include #include "nlohmann/json.hpp" #include "oneflow/core/common/util.h" #include "oneflow/core/common/mem_util.h" #include "oneflow/core/job/utils/progress_bar.h" namespace oneflow { template struct Duration { static const std::string& Repr() { static const std::string repr = ""; return repr; } }; #define DEFINE_DURATION_TRAIT(time_type) \ template<> \ struct Duration { \ static const std::string& Repr() { \ static const std::string repr = #time_type; \ return repr; \ } \ }; DEFINE_DURATION_TRAIT(nanoseconds) DEFINE_DURATION_TRAIT(microseconds) DEFINE_DURATION_TRAIT(milliseconds) DEFINE_DURATION_TRAIT(seconds) DEFINE_DURATION_TRAIT(minutes) DEFINE_DURATION_TRAIT(hours) #undef DEFINE_DURATION_TRAIT template class CostCounter final { public: OF_DISALLOW_COPY_AND_MOVE(CostCounter); explicit CostCounter(bool with_log = true, bool with_mem = false) : with_log_(with_log), with_mem_(with_mem) {} ~CostCounter() = default; void Count(const std::string& log_prefix = "", int v_log_level = 0, bool log_progress = false); private: using Clock = std::conditional_t; Clock::time_point start_{Clock::now()}; bool with_log_{false}; bool with_mem_{false}; }; template void CostCounter::Count(const std::string& log_prefix, int v_log_level, bool log_progress) { if (log_progress) { CHECK_JUST(LogProgress(log_prefix)); } const auto end = Clock::now(); if (FLAGS_minloglevel <= 0 && VLOG_IS_ON(v_log_level) && with_log_ && v_log_level >= 0) { // only do time/mem count and log when glog level is INFO and VLOG level is matched. auto dur = std::chrono::duration_cast(end - start_).count(); nlohmann::json json_log; json_log["loc"] = log_prefix; json_log["time_cost"] = std::to_string(dur) + " " + Duration::Repr(); if (with_mem_) { #ifdef __linux__ double vm = 0, rss = 0; ProcessMemUsage(&vm, &rss); json_log["mem_rss"] = std::to_string(rss) + " MB"; #endif // __linux__ } if (v_log_level == 0) { LOG(INFO) << "[count log]" << json_log.dump(); } else { VLOG(v_log_level) << "[count log]" << json_log.dump(); } } start_ = end; return; } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_TIME_UTIL_H_ ================================================ FILE: oneflow/core/common/cpp_attribute.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_ #define ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_ #define likely GOOGLE_PREDICT_TRUE #define unlikely GOOGLE_PREDICT_FALSE #endif // ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_ ================================================ FILE: oneflow/core/common/data_type.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/tensor_buffer.h" namespace oneflow { bool IsBoolDataType(DataType data_type) { switch (data_type) { #define BOOL_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE(BOOL_CASE, BOOL_DATA_TYPE_SEQ) default: return false; } #undef BOOL_CASE } bool IsIntegralDataType(DataType data_type) { switch (data_type) { #define INTEGRAL_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE(INTEGRAL_CASE, INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ) default: return false; } #undef INTEGRAL_CASE } bool IsFloatingDataType(DataType data_type) { switch (data_type) { #define FLOATING_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE(FLOATING_CASE, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ) default: return false; } #undef FLOATING_CASE } bool IsHalfDataType(DataType data_type) { switch (data_type) { #define HALF_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE(HALF_CASE, FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ) default: return false; } #undef HALF_CASE } bool IsComplexDataType(DataType data_type) { switch (data_type) { #define COMPLEX_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE(COMPLEX_CASE, COMPLEX_DATA_TYPE_SEQ) default: return false; } #undef COMPLEX_CASE } bool IsTriviallyCopyableDataType(DataType data_type) { switch (data_type) { #define TRIVIALLY_COPY_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE(TRIVIALLY_COPY_CASE, TRIVIALLY_COPY_DATA_TYPE_SEQ INT16_DATA_TYPE_SEQ) default: return false; } #undef TRIVIALLY_COPY_CASE } bool IsIndexDataType(DataType data_type) { switch (data_type) { #define INDEX_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE(INDEX_CASE, INDEX_DATA_TYPE_SEQ) default: return false; } #undef INDEX_CASE } bool IsSupportRequireGradDataType(DataType data_type) { switch (data_type) { #define REQUIRE_GRAD_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE( REQUIRE_GRAD_CASE, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ) default: return false; } #undef REQUIRE_GRAD_CASE } bool NotSupportBoxingDataType(DataType data_type) { switch (data_type) { #define NO_BOXING_CASE(type_cpp, type_proto) \ case type_proto: return true; OF_PP_FOR_EACH_TUPLE(NO_BOXING_CASE, NO_BOXING_DATA_TYPE_SEQ) default: return false; } #undef NO_BOXING_CASE } size_t GetSizeOfDataType(DataType data_type) { switch (data_type) { // 8-bit case kChar: return 1; case kInt8: return 1; case kUInt8: return 1; case kBool: return 1; // 16-bit case kInt16: return 2; case kUInt16: return 2; case kFloat16: return 2; case kBFloat16: return 2; // 32-bit case kInt32: return 4; case kUInt32: return 4; case kFloat: return 4; case kComplex32: return 4; // 64-bit case kInt64: return 8; case kUInt64: return 8; case kDouble: return 8; case kComplex64: return 8; // 128-bit case kInt128: return 16; case kUInt128: return 16; case kComplex128: return 16; // non pod case kOFRecord: 
return sizeof(OFRecord); case kTensorBuffer: return sizeof(TensorBuffer); default: LOG(FATAL) << "invalid data_type: " << DataType_Name(data_type); } } namespace { void CheckDataType() { static_assert(sizeof(int8_t) == sizeof(char), "sizeof(int8_t) != sizeof(char)"); static_assert(sizeof(int16_t) == sizeof(short), "sizeof(int16_t) != sizeof(short)"); static_assert(sizeof(int32_t) == sizeof(int), "sizeof(int32_t) != sizeof(int)"); static_assert(sizeof(int64_t) == sizeof(long long), "sizeof(int64_t) != sizeof(long long)"); #if defined(WITH_CUDA) #define CHECK_DEVICE_FP16(get_val) \ do { \ float16 host_fp16 = get_val(); \ half device_fp16 = get_val(); \ CHECK_EQ(*(uint16_t*)&host_fp16, *(uint16_t*)&device_fp16); \ } while (0) CHECK_DEVICE_FP16(GetZeroVal); CHECK_DEVICE_FP16(GetOneVal); CHECK_DEVICE_FP16(GetMaxVal); CHECK_DEVICE_FP16(GetMinVal); #undef CHECK_DEVICE_FP16 #endif #define CHECK_MAX_VAL(T, limit_value) CHECK_EQ(GetMaxVal(), std::numeric_limits::max()); OF_PP_FOR_EACH_TUPLE(CHECK_MAX_VAL, MAX_VAL_SEQ); #undef CHECK_MAX_VAL #define CHECK_MIN_VAL(T, limit_value) CHECK_EQ(GetMinVal(), std::numeric_limits::lowest()); OF_PP_FOR_EACH_TUPLE(CHECK_MIN_VAL, MIN_VAL_SEQ); #undef CHECK_MIN_VAL } COMMAND(CheckDataType()); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/common/data_type.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_H_ #define ONEFLOW_CORE_COMMON_DATA_TYPE_H_ #include #include #if defined(WITH_CUDA) #include #include #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #endif #include "oneflow/core/common/bfloat16.h" #include "oneflow/core/common/bfloat16_math.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/data_type_seq.h" #include "oneflow/core/record/record.pb.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/device_type.h" #include namespace std { // Extend numeric_limits for the C++ standard library. 
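A small sketch exercising the runtime predicates defined in data_type.cpp; the expected values follow directly from the switch tables above (DataTypeQueries is an illustrative name).

#include "oneflow/core/common/data_type.h"

namespace oneflow {

inline void DataTypeQueries() {
  size_t half_bytes = GetSizeOfDataType(DataType::kFloat16);        // 2
  size_t cplx_bytes = GetSizeOfDataType(DataType::kComplex128);     // 16
  bool bf16_is_fp = IsFloatingDataType(DataType::kBFloat16);        // true: bfloat16 is in the floating seq
  bool int_grad = IsSupportRequireGradDataType(DataType::kInt32);   // false: only floating/half/complex types
  (void)half_bytes; (void)cplx_bytes; (void)bf16_is_fp; (void)int_grad;
}

}  // namespace oneflow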
#ifdef WITH_CUDA template<> struct numeric_limits { static constexpr int digits = std::numeric_limits::digits; static constexpr half_float::half lowest() { return std::numeric_limits::lowest(); } static constexpr half_float::half max() { return std::numeric_limits::max(); } }; #endif // WITH_CUDA } // namespace std namespace oneflow { namespace detail { template struct IsFloat16Helper : std::false_type {}; template struct IsFloatingHelper : std::false_type {}; template struct IsIntegralHelper : std::false_type {}; template struct IsUnsignedIntegralHelper : std::false_type {}; #ifdef WITH_CUDA template struct IsCudaComplexHelper : std::false_type {}; #endif // WITH_CUDA } // namespace detail using float16 = half_float::half; #define DEFINE_SPEC(Trait, Type, Value) \ template<> \ struct Trait : std::integral_constant {}; // Type Trait: IsFloat16 DEFINE_SPEC(detail::IsFloat16Helper, float16, true) #ifdef WITH_CUDA DEFINE_SPEC(detail::IsFloat16Helper, half, true) #endif // WITH_CUDA template struct IsFloat16 : std::integral_constant::type>::value)> {}; // Type Trait: IsCudaComplex #ifdef WITH_CUDA DEFINE_SPEC(detail::IsCudaComplexHelper, cuComplex, true) DEFINE_SPEC(detail::IsCudaComplexHelper, cuDoubleComplex, true) template struct IsCudaComplex : std::integral_constant< bool, (detail::IsCudaComplexHelper::type>::value)> {}; #endif // WITH_CUDA // Type Trait: IsFloating #define SPECIALIZE_TRUE_FLOATING(type_cpp, type_proto) \ DEFINE_SPEC(detail::IsFloatingHelper, type_cpp, true) OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_FLOATING, FLOATING_DATA_TYPE_SEQ); #undef SPECIALIZE_TRUE_FLOATING DEFINE_SPEC(detail::IsFloatingHelper, float16, true) #ifdef WITH_CUDA DEFINE_SPEC(detail::IsFloatingHelper, half, true) #endif // WITH_CUDA template struct IsFloating : std::integral_constant::type>::value)> { }; // Type Trait: IsIntegral #define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \ DEFINE_SPEC(detail::IsIntegralHelper, type_cpp, true) OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, INT_DATA_TYPE_SEQ); #undef SPECIALIZE_TRUE_INTEGRAL template struct IsIntegral : std::integral_constant::type>::value)> { }; // Type Trait: IsUnsignedIntegral #define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \ DEFINE_SPEC(detail::IsUnsignedIntegralHelper, type_cpp, true) OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, UNSIGNED_INT_DATA_TYPE_SEQ); #undef SPECIALIZE_TRUE_INTEGRAL template struct IsUnsignedIntegral : std::integral_constant< bool, (detail::IsUnsignedIntegralHelper::type>::value)> {}; #undef DEFINE_SPEC // Type Trait: GetDataType template struct GetDataType; template<> struct GetDataType : std::integral_constant {}; #define SPECIALIZE_GET_DATA_TYPE(type_cpp, type_proto) \ template<> \ struct GetDataType : std::integral_constant {}; \ inline type_cpp GetTypeByDataType(std::integral_constant) { return {}; } OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_DATA_TYPE, ALL_DATA_TYPE_SEQ UNSIGNED_INT32_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ UNSIGNED_INT64_DATA_TYPE_SEQ INT16_DATA_TYPE_SEQ); #undef SPECIALIZE_GET_DATA_TYPE template struct GetDataType::value>::type> : std::integral_constant {}; #ifdef WITH_CUDA template<> struct GetDataType : std::integral_constant {}; template<> struct GetDataType : std::integral_constant {}; #endif // WITH_CUDA #if CUDA_VERSION >= 11000 template<> struct GetDataType : std::integral_constant {}; #endif template using DataTypeToType = decltype(GetTypeByDataType(std::integral_constant{})); #if defined(__CUDACC__) #define OF_DEVICE_FUNC __device__ __host__ 
__forceinline__ #else #define OF_DEVICE_FUNC inline #endif #ifdef WITH_CUDA template::value || IsCudaComplex::value)>::type* = nullptr> OF_DEVICE_FUNC T GetZeroVal() { return static_cast(0); } template::value || IsCudaComplex::value)>::type* = nullptr> OF_DEVICE_FUNC T GetOneVal() { return static_cast(1); } #else template::value>::type* = nullptr> OF_DEVICE_FUNC T GetZeroVal() { return static_cast(0); } template::value>::type* = nullptr> OF_DEVICE_FUNC T GetOneVal() { return static_cast(1); } #endif // WITH_CUDA template::value>::type* = nullptr> OF_DEVICE_FUNC T GetMinVal(); template::value>::type* = nullptr> OF_DEVICE_FUNC T GetMaxVal(); #ifdef __APPLE__ #define APPLE_MAX_VAL_SEQ OF_PP_MAKE_TUPLE_SEQ(unsigned long, ULONG_MAX) #else #define APPLE_MAX_VAL_SEQ #endif #define MAX_VAL_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int8_t, INT8_MAX) \ OF_PP_MAKE_TUPLE_SEQ(int16_t, INT16_MAX) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, INT32_MAX) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, INT64_MAX) \ OF_PP_MAKE_TUPLE_SEQ(uint8_t, UINT8_MAX) \ OF_PP_MAKE_TUPLE_SEQ(uint16_t, UINT16_MAX) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, UINT32_MAX) \ APPLE_MAX_VAL_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, UINT64_MAX) \ OF_PP_MAKE_TUPLE_SEQ(float, FLT_MAX) \ OF_PP_MAKE_TUPLE_SEQ(double, DBL_MAX) \ OF_PP_MAKE_TUPLE_SEQ(bool, true) #ifdef __APPLE__ #define APPLE_MIN_VAL_SEQ OF_PP_MAKE_TUPLE_SEQ(unsigned long, 0) #else #define APPLE_MIN_VAL_SEQ #endif #define MIN_VAL_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int8_t, INT8_MIN) \ OF_PP_MAKE_TUPLE_SEQ(int16_t, INT16_MIN) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, INT32_MIN) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, INT64_MIN) \ OF_PP_MAKE_TUPLE_SEQ(uint8_t, 0) \ OF_PP_MAKE_TUPLE_SEQ(uint16_t, 0) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, 0) \ APPLE_MIN_VAL_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, 0) \ OF_PP_MAKE_TUPLE_SEQ(float, -FLT_MAX) \ OF_PP_MAKE_TUPLE_SEQ(double, -DBL_MAX) \ OF_PP_MAKE_TUPLE_SEQ(bool, false) #define SPECIALIZE_MAX_VAL(T, limit_value) \ template<> \ OF_DEVICE_FUNC T GetMaxVal() { \ return limit_value; \ } OF_PP_FOR_EACH_TUPLE(SPECIALIZE_MAX_VAL, MAX_VAL_SEQ); #undef SPECIALIZE_MAX_VAL #define SPECIALIZE_MIN_VAL(T, limit_value) \ template<> \ OF_DEVICE_FUNC T GetMinVal() { \ return limit_value; \ } OF_PP_FOR_EACH_TUPLE(SPECIALIZE_MIN_VAL, MIN_VAL_SEQ); #undef SPECIALIZE_MIN_VAL template const T* GetZeroPtr() { static const T ret = GetZeroVal(); return &ret; } template const T* GetOnePtr() { static const T ret = GetOneVal(); return &ret; } template::value>::type* = nullptr> OF_DEVICE_FUNC T GetZeroVal() { uint16_t ret = 0x0; // Decimal: 0; Binary: 0 00000 0000000000 return *(T*)&ret; } #ifdef WITH_CUDA template::value>::type* = nullptr> OF_DEVICE_FUNC T GetZeroVal() { return make_cuFloatComplex((float)0.0, (float)0.0); } template::value>::type* = nullptr> OF_DEVICE_FUNC T GetZeroVal() { return make_cuDoubleComplex((double)0.0, (double)0.0); } #endif // WITH_CUDA template::value>::type* = nullptr> OF_DEVICE_FUNC T GetOneVal() { uint16_t ret = 0x3c00; // Decimal: 15360; Binary: 0 01111 0000000000 return *(T*)&ret; } #ifdef WITH_CUDA template::value>::type* = nullptr> OF_DEVICE_FUNC T GetOneVal() { return make_cuFloatComplex((float)1.0, (float)1.0); } template::value>::type* = nullptr> OF_DEVICE_FUNC T GetOneVal() { return make_cuDoubleComplex((double)1.0, (double)1.0); } #endif // WITH_CUDA template::value>::type* = nullptr> OF_DEVICE_FUNC T GetMaxVal() { uint16_t ret = 0x7bff; // Decimal: 31743; Binary: 0 11110 1111111111 return *(T*)&ret; } template::value>::type* = nullptr> OF_DEVICE_FUNC T GetMinVal() { uint16_t ret = 0xfbff; // Decimal: 64511; 
Binary: 1 11110 1111111111 return *(T*)&ret; } #if CUDA_VERSION >= 11000 template<> OF_DEVICE_FUNC nv_bfloat16 GetMinVal() { uint16_t ret = 0xff7f; return *(nv_bfloat16*)&ret; } #endif // CUDA_VERSION >= 11000 template struct DevDType { typedef T type; }; #if defined(WITH_CUDA) template<> struct DevDType { static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)"); typedef half type; }; #if CUDA_VERSION >= 11000 template<> struct DevDType { static_assert(sizeof(bfloat16) == sizeof(nv_bfloat16), "sizeof(bfloat16) != sizeof(nv_bfloat16)"); typedef nv_bfloat16 type; }; #endif // CUDA_VERSION >= 11000 #endif // defined(WITH_CUDA) // Func bool IsBoolDataType(DataType data_type); bool IsIntegralDataType(DataType data_type); bool IsFloatingDataType(DataType data_type); bool IsHalfDataType(DataType data_type); bool IsSupportRequireGradDataType(DataType data_type); bool IsComplexDataType(DataType data_type); bool IsTriviallyCopyableDataType(DataType data_type); bool IsIndexDataType(DataType data_type); bool NotSupportBoxingDataType(DataType data_type); size_t GetSizeOfDataType(DataType data_type); inline bool operator==(const OptInt64& lhs, const OptInt64& rhs) { return (lhs.has_value() && rhs.has_value() && lhs.value() == rhs.value()) || (!lhs.has_value() && !rhs.has_value()); } template void CheckDataType(DataType data_type) { LOG_IF(FATAL, (std::is_same::value == false && std::is_same::value == false && data_type != DataType::kChar && data_type != GetDataType::value)) << data_type << " " << GetDataType::value; } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_DATA_TYPE_H_ ================================================ FILE: oneflow/core/common/data_type.proto ================================================ syntax = "proto2"; package oneflow; enum DataType { kInvalidDataType = 0; kChar = 1; kFloat = 2; kDouble = 3; kInt8 = 4; kInt32 = 5; kInt64 = 6; kUInt8 = 7; kOFRecord = 8; kFloat16 = 9; kTensorBuffer = 10; kBFloat16 = 11; kBool = 12; kUInt16 = 13; kUInt32 = 14; kUInt64 = 15; kUInt128 = 16; kInt16 = 17; kInt128 = 18; kComplex32 = 19; kComplex64 = 20; kComplex128 = 21; } message OptInt64 { optional int64 value = 1 [ default = -1 ]; } ================================================ FILE: oneflow/core/common/data_type_converter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
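A sketch of the compile-time mappings declared in data_type.h; the angle brackets are reconstructed here because the dump strips template argument lists, so treat the exact spellings as assumptions (HalfMax is an illustrative name).

#include <type_traits>
#include "oneflow/core/common/data_type.h"

namespace oneflow {

// C++ type -> DataType enum, and DataType enum -> C++ type.
static_assert(GetDataType<float>::value == DataType::kFloat, "");
static_assert(std::is_same<DataTypeToType<DataType::kDouble>, double>::value, "");

inline float HalfMax() {
  // GetMaxVal<float16>() returns the bit pattern 0x7bff, i.e. 65504, the largest finite half value.
  return static_cast<float>(GetMaxVal<float16>());
}

}  // namespace oneflow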
*/ #ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_ #define ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_ #ifdef WITH_CUDA #include #endif #include #include #include #include "oneflow/core/common/data_type.h" namespace oneflow { template struct IsFloatingOrHalf { static const bool value = IsFloating::value || IsFloat16::value; }; template struct IsArithmeticOrHalf { static const bool value = std::is_arithmetic::value || IsFloat16::value; }; template struct NeedsClamp { static const bool from_fp = IsFloatingOrHalf::value; static const bool to_fp = IsFloatingOrHalf::value; static const bool from_fp16 = IsFloat16::value; static const bool to_fp16 = IsFloat16::value; static const bool from_unsigned = std::is_unsigned::value; static const bool to_unsigned = std::is_unsigned::value; static const bool value = // to smaller type of same kind (fp, int) (from_fp == to_fp && sizeof(To) < sizeof(From)) || // fp32 has range in excess of (u)int64 (from_fp && !to_fp) || // converting to unsigned requires clamping negatives to zero (!from_unsigned && to_unsigned) || // zero-extending signed unsigned integers requires more bits (from_unsigned && !to_unsigned && sizeof(To) <= sizeof(From)) || // float16 (to_fp16 && sizeof(To) <= sizeof(From)); }; template struct NeedsClamp { static const bool value = false; }; template struct ClampHelper {}; // floating-point and signed integer -> floating-point and signed integer template struct ClampHelper< T, U, std::enable_if_t< NeedsClamp::value && std::is_signed::value && std::is_signed::value, void>> { OF_DEVICE_FUNC static const T Call(U value) { return value <= GetMinVal() ? GetMinVal() : value >= GetMaxVal() ? GetMaxVal() : static_cast(value); } }; // floating-point -> unsigned types template struct ClampHelper::value && std::is_signed::value && IsFloatingOrHalf::value && std::is_unsigned::value, void>> { OF_DEVICE_FUNC static const T Call(U value) { return value <= GetMinVal() ? GetMinVal() : value >= GetMaxVal() ? GetMaxVal() : static_cast(value); } }; // signed integer types -> unsigned types template struct ClampHelper::value && std::is_signed::value && std::is_integral::value && std::is_unsigned::value, void>> { OF_DEVICE_FUNC static const T Call(U value) { return value <= 0 ? 0 : static_cast>(value) >= GetMaxVal() ? GetMaxVal() : static_cast(value); } }; // unsigned types -> any types template struct ClampHelper::value && std::is_unsigned::value, void>> { OF_DEVICE_FUNC static const T Call(U value) { return value >= GetMaxVal() ? GetMaxVal() : static_cast(value); } }; // not clamp template struct ClampHelper::value, void>> { OF_DEVICE_FUNC static const T Call(U value) { return value; } }; OF_DEVICE_FUNC const int32_t Clamp(uint32_t value) { return value & 0x80000000u ? 0x7fffffff : value; } OF_DEVICE_FUNC const uint32_t Clamp(int32_t value) { return value < 0 ? 0u : value; } OF_DEVICE_FUNC const int32_t Clamp(int64_t value) { return value < static_cast(GetMinVal()) ? GetMinVal() : value > static_cast(GetMaxVal()) ? GetMaxVal() : static_cast(value); } template<> struct ClampHelper { OF_DEVICE_FUNC static const int32_t Call(uint64_t value) { return value > static_cast(GetMaxVal()) ? GetMaxVal() : static_cast(value); } }; template<> struct ClampHelper { OF_DEVICE_FUNC static const uint32_t Call(int64_t value) { return value < 0 ? 0 : value > static_cast(GetMaxVal()) ? GetMaxVal() : static_cast(value); } }; template<> struct ClampHelper { OF_DEVICE_FUNC static const uint32_t Call(uint64_t value) { return value > static_cast(GetMaxVal()) ? 
GetMaxVal() : static_cast(value); } }; template struct ClampHelper { OF_DEVICE_FUNC static const bool Call(T value) { return static_cast(value); } }; template struct ClampHelper { inline static const float16 Call(T value) { return static_cast(ClampHelper::Call(value) < GetMinVal() ? GetMinVal() : ClampHelper::Call(value) > GetMaxVal() ? GetMaxVal() : ClampHelper::Call(value)); } }; template struct ClampHelper { inline static const T Call(float16 value) { return ClampHelper::Call(static_cast(value)); } }; inline const float16 Clamp(float16 value) { return value; } template OF_DEVICE_FUNC const T Clamp(U value) { return ClampHelper::Call(value); } namespace { #ifdef __CUDA_ARCH__ inline __device__ int cuda_round_helper(float f, int) { return __float2int_rn(f); } inline __device__ unsigned cuda_round_helper(float f, unsigned) { return __float2uint_rn(f); } inline __device__ long long cuda_round_helper(float f, long long) { return __float2ll_rd(f + 0.5f); } inline __device__ unsigned long long cuda_round_helper(float f, unsigned long long) { return __float2ull_rd(f + 0.5f); } inline __device__ long cuda_round_helper(float f, long) { return sizeof(long) == sizeof(int) ? __float2int_rn(f) : __float2ll_rd(f + 0.5f); } inline __device__ unsigned long cuda_round_helper(float f, unsigned long) { return sizeof(unsigned long) == sizeof(unsigned int) ? __float2uint_rn(f) : __float2ull_rd(f + 0.5f); } inline __device__ int cuda_round_helper(double f, int) { return __double2int_rn(f); } inline __device__ unsigned cuda_round_helper(double f, unsigned) { return __double2uint_rn(f); } inline __device__ long long cuda_round_helper(double f, long long) { return __double2ll_rd(f + 0.5f); } inline __device__ unsigned long long cuda_round_helper(double f, unsigned long long) { return __double2ull_rd(f + 0.5f); } inline __device__ long cuda_round_helper(double f, long) { return sizeof(long) == sizeof(int) ? __double2int_rn(f) : __double2ll_rd(f + 0.5f); } inline __device__ unsigned long cuda_round_helper(double f, unsigned long) { return sizeof(unsigned long) == sizeof(unsigned int) ? 
__double2uint_rn(f) : __double2ull_rd(f + 0.5f); } #endif template::value, bool InIsFp = IsFloatingOrHalf::value> struct ConverterBase; template struct Converter : ConverterBase { static_assert(IsArithmeticOrHalf::value && IsArithmeticOrHalf::value, "Default ConverterBase can only be used with arithmetic types."); }; // Converts between two FP types template struct ConverterBase { OF_DEVICE_FUNC static const Out Convert(In value) { return value; } OF_DEVICE_FUNC static const Out ConvertNorm(In value) { return value; } OF_DEVICE_FUNC static const Out ConvertSat(In value) { return value; } OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return value; } }; // Converts integral to FP type template struct ConverterBase { OF_DEVICE_FUNC static const Out Convert(In value) { return value; } OF_DEVICE_FUNC static const Out ConvertSat(In value) { return value; } OF_DEVICE_FUNC static const Out ConvertNorm(In value) { return value * (Out(1) / (GetMaxVal())); } OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return value * (Out(1) / (GetMaxVal())); } }; // Converts integral to float16 template struct ConverterBase { OF_DEVICE_FUNC static const float16 Convert(In value) { auto out = ConverterBase::Convert(value); return static_cast(out); } OF_DEVICE_FUNC static const float16 ConvertSat(In value) { auto out = ConverterBase::ConvertSat(value); return static_cast(out); } OF_DEVICE_FUNC static const float16 ConvertNorm(In value) { auto out = ConverterBase::ConvertNorm(value); return static_cast(out); } OF_DEVICE_FUNC static const float16 ConvertSatNorm(In value) { auto out = ConverterBase::ConvertSatNorm(value); return static_cast(out); } }; // Converts FP to integral type template struct ConverterBase { OF_DEVICE_FUNC static const Out Convert(In value) { #ifdef __CUDA_ARCH__ return Clamp(cuda_round_helper(value, Out())); #else return Clamp(std::round(value)); #endif } OF_DEVICE_FUNC static const Out ConvertSat(In value) { #ifdef __CUDA_ARCH__ return Clamp(cuda_round_helper(value, Out())); #else return Clamp(std::round(value)); #endif } OF_DEVICE_FUNC static const Out ConvertNorm(In value) { #ifdef __CUDA_ARCH__ return Clamp(cuda_round_helper(value * GetMaxVal(), Out())); #else return std::round(value * GetMaxVal()); #endif } OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { #ifdef __CUDA_ARCH__ return std::is_signed::value ? 
Clamp(cuda_round_helper(value * GetMaxVal(), Out())) : cuda_round_helper(GetMaxVal() * __saturatef(value), Out()); #else return Clamp(std::round(value * GetMaxVal())); #endif } }; // Converts signed to signed, unsigned to unsigned or unsigned to signed template::value, bool IsInSigned = std::is_signed::value> struct ConvertIntInt { OF_DEVICE_FUNC static const Out Convert(In value) { return value; } OF_DEVICE_FUNC static const Out ConvertNorm(In value) { return Converter::Convert(value * (1.0f * GetMaxVal() / GetMaxVal())); } OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Clamp(value); } OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return ConvertNorm(value); } }; // Converts signed to unsigned integer template struct ConvertIntInt { OF_DEVICE_FUNC static const Out Convert(In value) { return value; } OF_DEVICE_FUNC static const Out ConvertNorm(In value) { return Converter::Convert(value * (1.0f * GetMaxVal() / GetMaxVal())); } OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Clamp(value); } OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { #ifdef __CUDA_ARCH__ return cuda_round_helper(__saturatef(value * (1.0f / GetMaxVal())) * GetMaxVal()); #else return value < 0 ? 0 : ConvertNorm(value); } #endif }; // Converts between integral types template struct ConverterBase : ConvertIntInt { static_assert(IsArithmeticOrHalf::value && IsArithmeticOrHalf::value, "Default ConverterBase can only be used with arithmetic types."); }; // Pass-through conversion template struct Converter { static OF_DEVICE_FUNC const T Convert(T value) { return value; } static OF_DEVICE_FUNC const T ConvertSat(T value) { return value; } static OF_DEVICE_FUNC const T ConvertNorm(T value) { return value; } static OF_DEVICE_FUNC const T ConvertSatNorm(T value) { return value; } }; template using converter_t = Converter, std::remove_cv_t>>; } // namespace template OF_DEVICE_FUNC const Out Convert(In value) { return converter_t::Convert(value); } template OF_DEVICE_FUNC const Out ConvertNorm(In value) { return converter_t::ConvertNorm(value); } template OF_DEVICE_FUNC const Out ConvertSat(In value) { return converter_t::ConvertSat(value); } template OF_DEVICE_FUNC const Out ConvertSatNorm(In value) { return converter_t::ConvertSatNorm(value); } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_ ================================================ FILE: oneflow/core/common/data_type_converter_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "util.h" #include "oneflow/core/common/data_type_converter.h" #include "oneflow/core/common/data_type_converter_test_static.h" #ifdef __CUDA_ARCH__ #include #else #include #endif namespace oneflow { namespace { // cpp17 std::clamp possible implementation template constexpr const T& clamp(const T& v, const T& lo, const T& hi) { return (v < lo) ? lo : (hi < v) ? 
hi : v; } } // namespace TEST(ClampTest, Clamp) { ASSERT_TRUE(Clamp(0) == 0); ASSERT_TRUE(Clamp(255) == 255); ASSERT_TRUE(Clamp(100) == 100); ASSERT_TRUE(Clamp(100.3) == 100); ASSERT_TRUE(Clamp(256) == 255); ASSERT_TRUE(Clamp(-4) == 0); ASSERT_TRUE(Clamp(-4.0f) == 0); ASSERT_TRUE(Clamp(1e+20f) == 255); ASSERT_TRUE(Clamp(-1e+20f) == 0); ASSERT_TRUE(Clamp(1e+200) == 255); ASSERT_TRUE(Clamp(-1e+200) == 0); ASSERT_TRUE(Clamp(-4) == -4); ASSERT_TRUE(Clamp(-4.2) == -4); ASSERT_TRUE(Clamp(4.2) == 4); ASSERT_TRUE(Clamp(127) == 127); ASSERT_TRUE(Clamp(128) == 127); ASSERT_TRUE(Clamp(256) == 127); ASSERT_TRUE(Clamp(-128) == -128); ASSERT_TRUE(Clamp(-256) == -128); ASSERT_TRUE(Clamp(1e+20f) == 127); ASSERT_TRUE(Clamp(-1e+20f) == -128); ASSERT_TRUE(Clamp(1e+200) == 127); ASSERT_TRUE(Clamp(-1e+200) == -128); ASSERT_TRUE(Clamp(0) == 0); ASSERT_TRUE(Clamp(0xffff) == 0xffff); ASSERT_TRUE(Clamp(100) == 100); ASSERT_TRUE(Clamp(100.3) == 100); ASSERT_TRUE(Clamp(0x10000) == 0xffff); ASSERT_TRUE(Clamp(-4) == 0); ASSERT_TRUE(Clamp(-4.0f) == 0); ASSERT_TRUE(Clamp(1e+20f) == 0xffff); ASSERT_TRUE(Clamp(-1e+20f) == 0); ASSERT_TRUE(Clamp(1e+200) == 0xffff); ASSERT_TRUE(Clamp(-1e+200) == 0); ASSERT_TRUE(Clamp(-4) == -4); ASSERT_TRUE(Clamp(-4.2) == -4); ASSERT_TRUE(Clamp(4.2) == 4); ASSERT_TRUE(Clamp(0x7fff) == 0x7fff); ASSERT_TRUE(Clamp(0x8000) == 0x7fff); ASSERT_TRUE(Clamp(0x10000) == 0x7fff); ASSERT_TRUE(Clamp(-0x8000) == -0x8000); ASSERT_TRUE(Clamp(-0x10000) == -0x8000); ASSERT_TRUE(Clamp(1e+20f) == 0x7fff); ASSERT_TRUE(Clamp(-1e+20f) == -0x8000); ASSERT_TRUE(Clamp(1e+200) == 0x7fff); ASSERT_TRUE(Clamp(-1e+200) == -0x8000); ASSERT_TRUE(Clamp(0) == 0); ASSERT_TRUE(Clamp(0xffffffffLL) == 0xffffffffLL); ASSERT_TRUE(Clamp(100) == 100); ASSERT_TRUE(Clamp(100.3) == 100); ASSERT_TRUE(Clamp(0x100000000LL) == 0xffffffffLL); ASSERT_TRUE(Clamp(-4) == 0); ASSERT_TRUE(Clamp(-4.0f) == 0); ASSERT_TRUE(Clamp(1e+20f) == 0xffffffffu); ASSERT_TRUE(Clamp(-1.0e+20f) == 0); ASSERT_TRUE(Clamp(1e+200) == 0xffffffffu); ASSERT_TRUE(Clamp(-1.0e+200) == 0); ASSERT_TRUE(Clamp(-4) == -4); ASSERT_TRUE(Clamp(-4LL) == -4); ASSERT_TRUE(Clamp(-4.2) == -4); ASSERT_TRUE(Clamp(4.2) == 4); ASSERT_TRUE(Clamp(0x7fffffff) == 0x7fffffff); ASSERT_TRUE(Clamp(0x80000000L) == 0x7fffffff); ASSERT_TRUE(Clamp(0x100000000L) == 0x7fffffff); ASSERT_TRUE(Clamp(-0x80000000LL) == -0x7fffffff - 1); ASSERT_TRUE(Clamp(-0x100000000LL) == -0x7fffffff - 1); ASSERT_TRUE(Clamp(1.0e+20f) == 0x7fffffff); ASSERT_TRUE(Clamp(-1.0e+20f) == -0x80000000L); ASSERT_TRUE(Clamp(1.0e+200) == 0x7fffffff); ASSERT_TRUE(Clamp(-1.0e+200) == -0x80000000L); ASSERT_TRUE(Clamp(1.0e+200) == 0x7fffffffffffffffLL); ASSERT_TRUE(Clamp(-1.0e+200) == -0x7fffffffffffffffLL - 1); ASSERT_TRUE(Clamp(1.0e+200) == 0xffffffffffffffffULL); ASSERT_TRUE(Clamp(-1.0e+200) == 0); } TEST(ConvertSat, float2int) { FOR_RANGE(int32_t, exp, -10, 100) { FOR_RANGE(float, sig, -256, 257) { float f = ldexpf(sig, exp); float integral; float fract = modff(f, &integral); if (fract == 0.5f || fract == -0.5f) continue; double rounded = roundf(f); int64_t clamped = clamp(rounded, -128, 127); ASSERT_EQ(ConvertSat(f), clamped) << " with f = " << f; clamped = clamp(rounded, 0, 255); ASSERT_EQ(ConvertSat(f), clamped) << " with f = " << f; clamped = clamp(rounded, -0x8000, 0x7fff); ASSERT_EQ(ConvertSat(f), clamped) << " with f = " << f; clamped = clamp(rounded, 0, 0xffff); ASSERT_EQ(ConvertSat(f), clamped) << " with f = " << f; clamped = clamp(rounded, int32_t(~0x7fffffff), 0x7fffffff); ASSERT_EQ(ConvertSat(f), clamped) << " with f = " << 
f; clamped = clamp(rounded, 0, 0xffffffffu); ASSERT_EQ(ConvertSat(f), clamped) << " with f = " << f; } } } TEST(ConvertNorm, int2int) { EXPECT_EQ((ConvertNorm(0)), 0); EXPECT_EQ((ConvertNorm(127)), 255); } TEST(ConvertNorm, float2int) { EXPECT_EQ(ConvertNorm(0.0f), 0); EXPECT_EQ(ConvertNorm(0.499f), 127); EXPECT_EQ(ConvertNorm(1.0f), 255); EXPECT_EQ(ConvertNorm(1.0f), 127); EXPECT_EQ(ConvertNorm(0.499f), 63); EXPECT_EQ(ConvertNorm(-1.0f), -127); EXPECT_EQ(ConvertNorm(0.0f), 0); EXPECT_EQ(ConvertNorm(1.0f), 0xffff); EXPECT_EQ(ConvertNorm(1.0f), 0x7fff); EXPECT_EQ(ConvertNorm(-1.0f), -0x7fff); } TEST(ConvertSatNorm, float2int) { EXPECT_EQ(ConvertSatNorm(2.0f), 255); EXPECT_EQ(ConvertSatNorm(0.499f), 127); EXPECT_EQ(ConvertSatNorm(-2.0f), 0); EXPECT_EQ(ConvertSatNorm(2.0f), 127); EXPECT_EQ(ConvertSatNorm(0.499f), 63); EXPECT_EQ(ConvertSatNorm(-2.0f), -128); EXPECT_EQ(ConvertSatNorm(0.4f / 255), 0); EXPECT_EQ(ConvertSatNorm(0.6f / 255), 1); EXPECT_EQ(ConvertSatNorm(2.0f), 0x7fff); EXPECT_EQ(ConvertSatNorm(-2.0f), -0x8000); } TEST(ConvertNorm, int2float) { EXPECT_EQ((ConvertNorm(255)), 1.0f); EXPECT_NEAR((ConvertNorm(127)), 1.0f * 127 / 255, 1e-7f); EXPECT_EQ((ConvertNorm(127)), 1.0f); EXPECT_NEAR((ConvertNorm(64)), 1.0f * 64 / 127, 1e-7f); } TEST(Clamp1, int64_2_float16) { int64_t big_num = 0x0FFFFFFFFFFFFFFF; EXPECT_EQ(static_cast(Clamp(big_num)), Clamp(Clamp(big_num))); EXPECT_EQ(65504.0f, Clamp(big_num)); EXPECT_EQ(-65504.0f, Clamp(-big_num)); } TEST(Clamp2, float16_2_int64) { float16 fp16 = static_cast(65504.0f); EXPECT_EQ(65504, Clamp(fp16)); EXPECT_EQ(-65504, Clamp(-fp16)); } } // namespace oneflow ================================================ FILE: oneflow/core/common/data_type_converter_test_static.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
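The four conversion flavors exercised by the tests above can be summarized in one sketch; the output-type-first template argument order is reconstructed from the test calls, since the dump strips the argument lists (ConversionFlavors is an illustrative name).

#include <cstdint>
#include "oneflow/core/common/data_type_converter.h"

namespace oneflow {

inline void ConversionFlavors() {
  uint8_t a = ConvertSat<uint8_t>(300.0f);       // 255: saturates instead of wrapping
  uint8_t b = ConvertNorm<uint8_t>(1.0f);        // 255: normalized floats map [0, 1] onto the full integer range
  float   c = ConvertNorm<float>(uint8_t{255});  // 1.0f: the inverse normalization
  int8_t  d = ConvertSatNorm<int8_t>(-2.0f);     // -128: normalize, then clamp out-of-range values
  (void)a; (void)b; (void)c; (void)d;
}

}  // namespace oneflow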
*/ #ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_ #define ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_ #include "oneflow/core/common/data_type_converter.h" namespace oneflow { namespace { // fp to int static_assert(NeedsClamp::value, "Float range exceeds all ints up to 64b"); static_assert(NeedsClamp::value, "Float range exceeds all ints up to 64b"); static_assert(NeedsClamp::value, "Float range exceeds all ints up to 64b"); static_assert(NeedsClamp::value, "Float range exceeds all ints up to 64b"); static_assert(NeedsClamp::value, "Float range exceeds all ints up to 64b"); static_assert(NeedsClamp::value, "Float range exceeds all ints up to 64b"); static_assert(NeedsClamp::value, "Float range exceeds all ints up to 64b"); static_assert(NeedsClamp::value, "Float range exceeds all ints up to 64b"); // same size, different signedness static_assert(NeedsClamp::value, "Signed <-> unsigned requires clamp"); static_assert(NeedsClamp::value, "Signed <-> unsigned requires clamp"); static_assert(NeedsClamp::value, "Signed <-> unsigned requires clamp"); static_assert(NeedsClamp::value, "Signed <-> unsigned requires clamp"); static_assert(NeedsClamp::value, "Signed <-> unsigned requires clamp"); static_assert(NeedsClamp::value, "Signed <-> unsigned requires clamp"); static_assert(NeedsClamp::value, "Signed <-> unsigned requires clamp"); static_assert(NeedsClamp::value, "Signed <-> unsigned requires clamp"); // larger, but unsigned static_assert(NeedsClamp::value, "Need to clamp negatives to 0"); static_assert(NeedsClamp::value, "Need to clamp negatives to 0"); static_assert(NeedsClamp::value, "Need to clamp negatives to 0"); static_assert(NeedsClamp::value, "Need to clamp negatives to 0"); static_assert(NeedsClamp::value, "Need to clamp negatives to 0"); static_assert(NeedsClamp::value, "Need to clamp negatives to 0"); static_assert(!NeedsClamp::value, "Clamping not required"); static_assert(!NeedsClamp::value, "Clamping not required"); static_assert(!NeedsClamp::value, "Clamping not required"); static_assert(!NeedsClamp::value, "Clamping not required"); static_assert(!NeedsClamp::value, "Clamping not required"); static_assert(!NeedsClamp::value, "Clamping not required"); } // namespace } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_ ================================================ FILE: oneflow/core/common/data_type_seq.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_ #define ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_ #include #include "oneflow/core/common/preprocessor.h" // SEQ #define BOOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool) #define FLOATING_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define SIGNED_INT_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define INT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int16_t, DataType::kInt16) #define UNSIGNED_INT_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) #define UNSIGNED_INT32_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) #define UNSIGNED_INT64_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) #define INT_DATA_TYPE_SEQ SIGNED_INT_DATA_TYPE_SEQ #define CHAR_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar) #define COMPLEX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(std::complex, DataType::kComplex64) \ OF_PP_MAKE_TUPLE_SEQ(std::complex, DataType::kComplex128) #define ARITHMETIC_DATA_TYPE_SEQ \ FLOATING_DATA_TYPE_SEQ \ INT_DATA_TYPE_SEQ #define POD_DATA_TYPE_SEQ \ ARITHMETIC_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ #define POD_AND_HALF_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ #define TRIVIALLY_COPY_DATA_TYPE_SEQ POD_AND_HALF_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ #define PB_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord) #define ALL_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ PB_DATA_TYPE_SEQ #define INDEX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define FLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) #define BFLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16) #if defined(WITH_CUDA) #define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) #if CUDA_VERSION >= 11000 #define NV_BFLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16) #endif // CUDA_VERSION >= 11000 #endif // defined(WITH_CUDA) #define IMAGE_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) \ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) #define NO_BOXING_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord) \ OF_PP_MAKE_TUPLE_SEQ(TensorBuffer, DataType::kTensorBuffer) #endif // ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_ ================================================ FILE: oneflow/core/common/decorator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
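The tuple seqs above are intended to be expanded with OF_PP_FOR_EACH_TUPLE, as data_type.cpp does; here is a sketch of that pattern (ArithmeticTypeName and NAME_CASE are illustrative names, not part of the repository).

#include <string>
#include "oneflow/core/common/data_type_seq.h"
#include "oneflow/core/common/data_type.pb.h"

namespace oneflow {

inline std::string ArithmeticTypeName(DataType dt) {
  switch (dt) {
// Each (type_cpp, type_proto) tuple in the seq expands into one case label.
#define NAME_CASE(type_cpp, type_proto) \
  case type_proto: return #type_cpp;
    OF_PP_FOR_EACH_TUPLE(NAME_CASE, ARITHMETIC_DATA_TYPE_SEQ)
#undef NAME_CASE
    default: return "non-arithmetic";
  }
}

}  // namespace oneflow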
*/ #ifndef ONEFLOW_CORE_COMMON_DECORATOR_H_ #define ONEFLOW_CORE_COMMON_DECORATOR_H_ #include #include #include "tuple_hash.h" #include "static_check.h" #include "oneflow/core/common/env_var/env_var.h" #include "oneflow/core/common/cpp_attribute.h" namespace oneflow { template class Decorator> struct WithDecorator final { template struct Decorate; template struct Decorate final { template static T Call(Args... args) { return Decorator::template Call(args...); } }; }; #define DECORATE(fn_ptr, decorator) \ (&WithDecorator::Decorate::Call) template struct ThreadLocalCopiable; template struct ThreadLocalCopiable { template static RetT Call() { static thread_local RetT value = func(); return value; } }; template struct ThreadLocalCopiable { template static RetT Call(Arg0 arg0) { using KeyT = typename std::decay::type; using MappedT = typename std::decay::type; static thread_local std::unordered_map map; auto iter = map.find(arg0); if (iter == map.end()) { iter = map.emplace(arg0, func(arg0)).first; } return iter->second; } private: static_assert(!IsOutArg::value, ""); static_assert(!StaticAny::value, ""); }; template struct ThreadLocalCopiable { template static RetT Call(Arg0 arg0, Arg1 arg1) { using KeyT0 = typename std::decay::type; using KeyT1 = typename std::decay::type; using MappedT = typename std::decay::type; static thread_local std::unordered_map> map; auto* last_map = &map[arg0]; auto iter = last_map->find(arg1); if (iter == last_map->end()) { iter = last_map->emplace(arg1, func(arg0, arg1)).first; } return iter->second; } private: static_assert(!StaticAny::value, ""); }; template struct ThreadLocalCopiable { template static RetT Call(Arg0 arg0, Arg1 arg1, Arg2 arg2) { using KeyT0 = typename std::decay::type; using KeyT1 = typename std::decay::type; using KeyT2 = typename std::decay::type; using MappedT = typename std::decay::type; static thread_local std::unordered_map< KeyT0, std::unordered_map>> map; auto* last_map = &map[arg0][arg1]; auto iter = last_map->find(arg2); if (iter == last_map->end()) { iter = last_map->emplace(arg2, func(arg0, arg1, arg2)).first; } return iter->second; } private: static_assert(!StaticAny::value, ""); }; template struct ThreadLocalCopiable { template static RetT Call(Arg0 arg0, Arg1 arg1, Arg2 arg2, Arg3 arg3, Args... args) { using KeyT0 = typename std::decay::type; using KeyT1 = typename std::decay::type; using KeyT2 = typename std::decay::type; using KeyT3 = typename std::decay::type; using KeyT = std::tuple::type...>; using MappedT = typename std::decay::type; static thread_local std::unordered_map map; const auto& key = KeyT(arg0, arg1, arg2, arg3, args...); auto iter = map.find(key); if (iter == map.end()) { iter = map.emplace(key, func(arg0, arg1, arg2, arg3, args...)).first; } return iter->second; } private: static_assert(!StaticAny::value, ""); }; // for scalar type key. 
template struct ThreadLocal : public ThreadLocalCopiable { private: static_assert(StaticAll::value, ""); }; template struct ThreadLocalCachedCopiable; template struct ThreadLocalCachedCopiable { template static RetT Call() { static thread_local RetT value = func(); return value; } }; template struct ThreadLocalCachedCopiable { template static RetT Call(Arg0 arg0) { using KeyT = typename std::decay::type; using MappedT = typename std::decay::type; static thread_local std::unordered_map map; auto iter = map.find(arg0); if (iter == map.end()) { if (unlikely(map.size() >= ThreadLocalEnvInteger())) { map.clear(); } iter = map.emplace(arg0, func(arg0)).first; } return iter->second; } private: static_assert(!IsOutArg::value, ""); static_assert(!StaticAny::value, ""); }; template struct ThreadLocalCachedCopiable { template static RetT Call(Arg0 arg0, Args... args) { using KeyT0 = typename std::decay::type; using KeyT = std::tuple::type...>; using MappedT = typename std::decay::type; static thread_local std::unordered_map map; const auto& key = KeyT(arg0, args...); auto iter = map.find(key); if (iter == map.end()) { if (unlikely(map.size() >= ThreadLocalEnvInteger())) { map.clear(); } iter = map.emplace(key, func(arg0, args...)).first; } return iter->second; } private: static_assert(!StaticAny::value, ""); }; // for scalar type key. template struct ThreadLocalCached : public ThreadLocalCachedCopiable { private: static_assert(StaticAll::value, ""); }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_DECORATOR_H_ ================================================ FILE: oneflow/core/common/decorator_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
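An illustrative use of DECORATE with the cached variant defined above (the plain ThreadLocal flavor is exercised in decorator_test.cpp below); Render and CachedRenderOnce are hypothetical names.

#include <string>
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/util.h"

namespace oneflow {

inline Maybe<std::string> Render(int64_t id) { return std::to_string(id); }

inline std::string CachedRenderOnce() {
  // Results are memoized per thread; unlike ThreadLocal, ThreadLocalCached clears its map once it
  // grows past ONEFLOW_THRAED_LOCAL_CACHED_SIZE entries.
  static auto* CachedRender = DECORATE(&Render, ThreadLocalCached);
  return CHECK_JUST(CachedRender(42));
}

}  // namespace oneflow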
*/ #include "gtest/gtest.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace test { Maybe Inc(int x) { return x + 1; } Maybe IncByConstRef(const int& x) { return x + 1; } TEST(ThreadLocal, scalar) { auto* CachedInc = DECORATE(&Inc, ThreadLocal); int x = CHECK_JUST(CachedInc(0)); ASSERT_EQ(x, 1); } TEST(ThreadLocal, const_ref) { auto* CachedIncByConstRef = DECORATE(&IncByConstRef, ThreadLocal); int x = CHECK_JUST(CachedIncByConstRef(0)); ASSERT_EQ(x, 1); } namespace { struct Foo { static Maybe New(int x) { return std::shared_ptr(new Foo{x}); } int x; }; } // namespace TEST(ThreadLocal, _class) { auto* CachedFooNew = DECORATE(&Foo::New, ThreadLocal); const auto& foo = CHECK_JUST(CachedFooNew(10)); const auto& bar = CHECK_JUST(CachedFooNew(10)); ASSERT_EQ(foo->x, 10); ASSERT_TRUE(foo == bar); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/common/device.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/device_type.proto"; message DeviceProto { required DeviceType device_type = 1; required int64 device_id = 2; optional bool rematable = 3 [default = false]; } ================================================ FILE: oneflow/core/common/device_type.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/device_type.h" #include #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { std::vector GetAllAvailableDeviceTypeNames() { const auto& device_types = ep::DeviceManagerRegistry::GetRegisteredDeviceTypes(); std::vector device_type_names; device_type_names.reserve(device_types.size()); for (const auto& device_type : device_types) { device_type_names.emplace_back( ep::DeviceManagerRegistry::GetDeviceTypeNameByDeviceType(device_type)); } return device_type_names; } std::string PrintAvailableDevices() { const auto& device_type_names = GetAllAvailableDeviceTypeNames(); return fmt::format("{}", fmt::join(device_type_names, ", ")); } std::string PrintGeneratorAvailableDevices() { auto device_type_names = GetAllAvailableDeviceTypeNames(); device_type_names.emplace_back("auto"); // "auto" is a fake device type for random generator. return fmt::format("{}", fmt::join(device_type_names, ", ")); } } // namespace oneflow ================================================ FILE: oneflow/core/common/device_type.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_ #define ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_ #include "oneflow/core/common/device_type.pb.h" namespace std { template<> struct hash final { size_t operator()(oneflow::DeviceType device_type) const { return static_cast(device_type); } }; } // namespace std namespace oneflow { std::string PrintAvailableDevices(); std::string PrintGeneratorAvailableDevices(); #if defined(WITH_CUDA) #define DEVICE_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU) \ OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA) #else #define DEVICE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU) #endif } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_ ================================================ FILE: oneflow/core/common/device_type.proto ================================================ syntax = "proto2"; package oneflow; enum DeviceType { kInvalidDevice = 0; kCPU = 1; kCUDA = 2; kMockDevice = 3; // pseudo device for test. kMeta = 4; kMLU = 5; // Cambricon MLU kNPU = 6; // Ascend NPU kXPU = 7; // KunLunXin } ================================================ FILE: oneflow/core/common/dtype_signature.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_ #define ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_ #include "oneflow/core/common/dtype_signature.pb.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { inline bool operator==(const DTypeSignature& lhs, const DTypeSignature& rhs) { return PbMd().Equals(lhs, rhs); } } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::DTypeSignature& dtype_signature) { std::string serialized; dtype_signature.SerializeToString(&serialized); return std::hash()(serialized); } }; } // namespace std #endif // ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_ ================================================ FILE: oneflow/core/common/dtype_signature.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/data_type.proto"; message DTypeSignature { map name2dtype = 1; } ================================================ FILE: oneflow/core/common/eigen_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
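A brief sketch using the std::hash specialization and the device-name helpers above (LogDeviceInfo is an illustrative name).

#include <unordered_set>
#include "glog/logging.h"
#include "oneflow/core/common/device_type.h"

namespace oneflow {

inline void LogDeviceInfo() {
  // The std::hash specialization above lets the DeviceType proto enum act as a hash key.
  std::unordered_set<DeviceType> preferred{DeviceType::kCPU, DeviceType::kCUDA};
  LOG(INFO) << "devices registered in this build: " << PrintAvailableDevices();
  LOG(INFO) << "generator devices (includes the pseudo \"auto\" device): "
            << PrintGeneratorAvailableDevices();
  (void)preferred;
}

}  // namespace oneflow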
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_ #define ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_ #include "Eigen/Core" #include "Eigen/Dense" namespace oneflow { template using EigenMatrixMap = Eigen::Map>; template using EigenArrayMap = Eigen::Map>; template using ConstEigenMatrixMap = Eigen::Map>; template using ConstEigenArrayMap = Eigen::Map>; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_ ================================================ FILE: oneflow/core/common/either_ptr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_EITHER_PTR_H_ #define ONEFLOW_CORE_COMMON_EITHER_PTR_H_ #include #include "oneflow/core/common/throw.h" namespace oneflow { template class EitherPtr final { public: static_assert(!std::is_same::value, "X should not be Y"); using XPtr = std::shared_ptr; using YPtr = std::shared_ptr; // WARNING: we should assume that the structure of shared_ptr and shared_ptr is same, // and obviously at most time the assumption holds static_assert(sizeof(XPtr) == sizeof(YPtr), "unsupported shared_ptr implementation"); EitherPtr() : type_(UnionType::value), x_ptr_(nullptr) {} EitherPtr(const XPtr& ptr) : type_(UnionType::value), x_ptr_(ptr) {} EitherPtr(const YPtr& ptr) : type_(UnionType::value) { new (&x_ptr_) YPtr(ptr); } EitherPtr(XPtr&& ptr) : type_(UnionType::value), x_ptr_(std::move(ptr)) {} EitherPtr(YPtr&& ptr) : type_(UnionType::value) { new (&x_ptr_) YPtr(std::move(ptr)); } EitherPtr(const EitherPtr& either_ptr) : type_(either_ptr.type_), x_ptr_(either_ptr.x_ptr_) {} EitherPtr(EitherPtr&& either_ptr) : type_(either_ptr.type_), x_ptr_(std::move(either_ptr.x_ptr_)) {} // the destructor of X or Y will be called properly because it will be stored in the deleter of // shared_ptr while constructed ~EitherPtr() = default; EitherPtr& operator=(const EitherPtr& either_ptr) { x_ptr_ = either_ptr.x_ptr_; type_ = either_ptr.type_; return *this; } EitherPtr& operator=(EitherPtr&& either_ptr) { x_ptr_ = std::move(either_ptr.x_ptr_); type_ = either_ptr.type_; return *this; } template bool Has() const { return type_ == UnionType::value; } template const std::shared_ptr& Get() const { return Get(tag{}); } private: template struct UnionType; template struct UnionType::value>::type> { static constexpr int8_t value = 0; }; template struct UnionType::value>::type> { static constexpr int8_t value = 1; }; template struct tag {}; const XPtr& Get(tag) const { CHECK(Has()); return x_ptr_; } const YPtr& Get(tag) const { CHECK(Has()); const auto* __attribute__((__may_alias__)) ptr = reinterpret_cast(&x_ptr_); 
return *ptr; } int8_t type_; std::shared_ptr x_ptr_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_EITHER_PTR_H_ ================================================ FILE: oneflow/core/common/env_var/bootstrap.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_ #define ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_ #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { DEFINE_ENV_INTEGER(ONEFLOW_RPC_BOOTSTRAP_SERVER_SLEEP_SECONDS, 20); DEFINE_ENV_INTEGER(ONEFLOW_RPC_BOOTSTRAP_SERVER_MAX_RETRY_TIMES, 3); DEFINE_ENV_INTEGER(ONEFLOW_RPC_CLIENT_SLEEP_SECONDS, 5); DEFINE_ENV_INTEGER(ONEFLOW_RPC_CLIENT_MAX_RETRY_TIMES, 6); } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_ ================================================ FILE: oneflow/core/common/env_var/debug_mode.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_ #define ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_ #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { DEFINE_ENV_BOOL(ONEFLOW_DEBUG_MODE, false); DEFINE_ENV_BOOL(ONEFLOW_DEBUG, false); inline bool IsInDebugMode() { return EnvBool() || EnvBool(); } DEFINE_ENV_BOOL(ENABLE_ACTOR_DEBUG_LOG, false); inline bool EnableActorDebugLog() { return EnvBool(); } DEFINE_ENV_BOOL(ENABLE_LOGICAL_CHAIN, true); inline bool EnableLogicalChain() { return EnvBool(); } DEFINE_ENV_BOOL(ENABLE_NCCL_LOGICAL_FUSION, true); inline bool EnableNcclLogicalFusion() { return EnvBool(); } inline bool IsPythonStackGetterEnabledByDebugBuild() { if (std::getenv("ONEFLOW_DEBUG_MODE") == nullptr && std::getenv("ONEFLOW_DEBUG") == nullptr && std::getenv("ONEFLOW_PYTHON_STACK_GETTER") == nullptr) { return std::string(OF_PP_STRINGIZE(ONEFLOW_CMAKE_BUILD_TYPE)) == "Debug"; } return false; } inline bool IsPythonStackGetterEnabled() { if (IsPythonStackGetterEnabledByDebugBuild()) { return true; } return ParseBooleanFromEnv("ONEFLOW_PYTHON_STACK_GETTER", IsInDebugMode()); } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_ ================================================ FILE: oneflow/core/common/env_var/eager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
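A usage sketch for EitherPtr; its template arguments are reconstructed because the dump strips them, and HoldEither is an illustrative name.

#include <memory>
#include <string>
#include "oneflow/core/common/either_ptr.h"

namespace oneflow {

inline void HoldEither() {
  // Stores either a shared_ptr<int> or a shared_ptr<std::string> in one shared_ptr-sized slot.
  EitherPtr<int, std::string> holder(std::make_shared<std::string>("checkpoint"));
  if (holder.Has<std::string>()) {
    const std::shared_ptr<std::string>& name = holder.Get<std::string>();  // Get<int>() here would CHECK-fail
    (void)name;
  }
}

}  // namespace oneflow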
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_ #define ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_ #include "oneflow/core/common/env_var/env_var.h" #ifdef WITH_CUDA #include #endif namespace oneflow { // NOTE: use env variable 'ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE' indicate whether the // use infer cache in naive local op interpret. DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE, true); // NOTE: use env variable 'ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE' indicate the size of // infer cache in op interpret. DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE, 128 * 1024); DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_EAGER_NCCL_USE_COMPUTE_STREAM, false); inline bool EagerNcclUseComputeStream() { #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 static bool eager_nccl_use_compute_stream = ThreadLocalEnvBool(); return eager_nccl_use_compute_stream; #else return false; #endif } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_ ================================================ FILE: oneflow/core/common/env_var/env_var.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_ #define ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_ #include "oneflow/core/common/util.h" namespace oneflow { template bool EnvBool(); #define DEFINE_ENV_BOOL(env_var, default_value) \ struct env_var {}; \ template<> \ inline bool EnvBool() { \ return ParseBooleanFromEnv(OF_PP_STRINGIZE(env_var), default_value); \ } template int64_t EnvInteger(); #define DEFINE_ENV_INTEGER(env_var, default_value) \ struct env_var {}; \ template<> \ inline int64_t EnvInteger() { \ return ParseIntegerFromEnv(OF_PP_STRINGIZE(env_var), default_value); \ } DEFINE_ENV_INTEGER(ONEFLOW_TIMEOUT_SECONDS, 7200); DEFINE_ENV_INTEGER(ONEFLOW_CHECK_TIMEOUT_SLEEP_SECONDS, EnvInteger()); DEFINE_ENV_INTEGER(ONEFLOW_VM_BLOCKING_DEBUG_INSTRUCTIONS_DISPLAY_LIMIT, 100); DEFINE_ENV_INTEGER(ONEFLOW_DELETE_OUTDATED_SHM_NAMES_INTERVAL, 1000); template bool ThreadLocalEnvBool(); #define DEFINE_THREAD_LOCAL_ENV_BOOL(env_var, default_value) \ struct env_var {}; \ template<> \ inline bool ThreadLocalEnvBool() { \ thread_local bool value = ParseBooleanFromEnv(OF_PP_STRINGIZE(env_var), default_value); \ return value; \ } template int64_t ThreadLocalEnvInteger(); #define DEFINE_THREAD_LOCAL_ENV_INTEGER(env_var, default_value) \ struct env_var {}; \ template<> \ inline int64_t ThreadLocalEnvInteger() { \ thread_local int64_t value = ParseIntegerFromEnv(OF_PP_STRINGIZE(env_var), default_value); \ return value; \ } DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_THRAED_LOCAL_CACHED_SIZE, 128 * 1024); template const std::string& ThreadLocalEnvString(); #define DEFINE_THREAD_LOCAL_ENV_STRING(env_var, default_value) \ struct env_var {}; \ template<> \ inline const std::string& ThreadLocalEnvString() { \ thread_local std::string value = GetStringFromEnv(OF_PP_STRINGIZE(env_var), default_value); \ return value; \ } DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE, false); // Default compilation mode during graph compilation. There 2 modes to choose: // "naive", master rank compile the full plan. // "rank_per_process", multi process(rank) run seperation compile. DEFINE_THREAD_LOCAL_ENV_STRING(ONEFLOW_LAZY_COMPILE_MODE, "naive"); // Default number of threads during graph compilation. DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_LAZY_COMPILE_RPC_THREAD_NUM, 16); } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_ ================================================ FILE: oneflow/core/common/env_var/remat.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
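// ----------------------------------------------------------------------------
// A minimal sketch of how the macros in oneflow/core/common/env_var/env_var.h
// above are consumed: each DEFINE_ENV_* invocation declares a tag struct whose
// name doubles as the environment variable name, the plain accessors re-parse
// the variable on each call, and the THREAD_LOCAL variants cache the parsed
// value per thread. The two variable names below are made up for illustration.
#include "oneflow/core/common/env_var/env_var.h"

namespace oneflow {

// Hypothetical knob: read as a boolean from $MY_FEATURE_ENABLED, default false.
DEFINE_ENV_BOOL(MY_FEATURE_ENABLED, false);
// Hypothetical knob: parsed once per thread, then served from a thread_local.
DEFINE_THREAD_LOCAL_ENV_INTEGER(MY_FEATURE_BUFFER_SIZE, 4 * 1024);

inline void EnvVarDemo() {
  if (EnvBool<MY_FEATURE_ENABLED>()) {
    int64_t buffer_size = ThreadLocalEnvInteger<MY_FEATURE_BUFFER_SIZE>();
    (void)buffer_size;
  }
}

}  // namespace oneflow
// ----------------------------------------------------------------------------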
*/ #pragma once #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { DEFINE_ENV_BOOL(ONEFLOW_REMAT_DISPLAY_IN_FIRST_TIME, false); DEFINE_ENV_BOOL(ONEFLOW_REMAT_RECORD_MEM_FRAG_RATE, true); DEFINE_ENV_INTEGER(ONEFLOW_REMAT_GROUP_NUM, 1); DEFINE_ENV_BOOL(ONEFLOW_REMAT_NEIGHBOR, true); DEFINE_ENV_BOOL(ONEFLOW_REMAT_HEURISTIC_DTE, false); DEFINE_ENV_BOOL(ONEFLOW_REMAT_HEURISTIC_DTR, false); DEFINE_ENV_BOOL(ONEFLOW_REMAT_LOG, false); } // namespace oneflow ================================================ FILE: oneflow/core/common/env_var/stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_ #define ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_ #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_DEVICE_STREAM_MAX_SIZE, 16); DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_STREAM_ENABLE_H2D_STREAM, false); } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_ ================================================ FILE: oneflow/core/common/env_var/vm.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_ #define ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_ #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_COMPUTE_ON_WORKER_THREAD, true); DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_ENABLE_STREAM_WAIT, true); DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_VM_PENDING_HANDLE_WINDOW_SIZE, 10) DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_ENABLE_SCHEDULE_YIELD, true) DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_VM_WORKER_THREAD_LIMIT, 16); DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_MULTI_THREAD, true); } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_ ================================================ FILE: oneflow/core/common/error.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include #include "fmt/core.h" #include "fmt/color.h" #include "fmt/ostream.h" #include "oneflow/core/common/error.h" #include "oneflow/core/common/exception.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/error_util.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/extension/stack/foreign_stack_getter.h" #include "oneflow/extension/stack/stacktrace.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { StackedError::StackedError() : stack_frame_(), error_proto_(new ErrorProto()) {} namespace { void LogError(const Error& error) { // gdb break point LOG(ERROR) << error->msg(); } std::shared_ptr* MutThreadLocalError() { thread_local std::shared_ptr error; return &error; } } // namespace Error&& Error::AddStackFrame(Symbol error_stack_frame) { stacked_error_->add_stack_frame(error_stack_frame); return std::move(*this); } Error&& Error::GetStackTrace(int64_t depth, int64_t skip_n_firsts) { backward::StackTrace st; backward::SnippetFactory snippets; backward::TraceResolver resolver; st.load_here(depth); st.skip_n_firsts(skip_n_firsts); resolver.load_stacktrace(st); for (int i = 0; i < st.size(); i++) { const auto& trace = resolver.resolve(st[i]); if (!backward::Printer::is_oneflow_file(trace.object_filename)) { continue; } // without debug info if (!trace.source.filename.size()) { stacked_error_->add_stack_frame( SymbolOf(ErrorStackFrame(trace.object_filename, -1, trace.object_function))); } // with debug info if (trace.source.filename.size()) { const backward::ResolvedTrace::SourceLoc& source_loc = trace.source; backward::SnippetFactory::lines_t lines = snippets.get_snippet(source_loc.filename, source_loc.line, static_cast(1)); std::string code_text = lines[0].second; const auto pos = code_text.find_first_not_of(" \t"); code_text = code_text.substr(pos, code_text.size() - pos); stacked_error_->add_stack_frame(SymbolOf( ErrorStackFrame(source_loc.filename, source_loc.line, source_loc.function, code_text))); } for (size_t inliner_idx = 0; inliner_idx < trace.inliners.size(); ++inliner_idx) { const backward::ResolvedTrace::SourceLoc& source_loc = trace.inliners[inliner_idx]; backward::SnippetFactory::lines_t lines = snippets.get_snippet(source_loc.filename, source_loc.line, static_cast(1)); std::string code_text = lines[0].second; const auto pos = code_text.find_first_not_of(" \t"); code_text = code_text.substr(pos, code_text.size() - pos); stacked_error_->add_stack_frame(SymbolOf( ErrorStackFrame(source_loc.filename, source_loc.line, source_loc.function, code_text))); } } return std::move(*this); } void Error::Merge(const Error& other) { auto* error_proto = stacked_error_->mut_error_proto(); error_proto->MergeFrom(*other.stacked_error_->error_proto()); } Error::operator std::string() const { return stacked_error_->DebugString(); } Error Error::Ok() { return std::make_shared(); } Error Error::ProtoParseFailedError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_proto_parse_failed_error(); return error; } Error Error::JobSetEmptyError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_job_set_empty_error(); return error; } Error Error::DeviceTagNotFoundError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_device_tag_not_found_error(); return error; } Error Error::InvalidValueError() { auto error = std::make_shared(); 
error->mut_error_proto()->mutable_invalid_value_error(); return error; } Error Error::IndexError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_index_error(); return error; } Error Error::TypeError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_type_error(); return error; } Error Error::TimeoutError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_timeout_error(); return error; } Error Error::JobNameExistError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_job_name_exist_error(); return error; } Error Error::JobNameEmptyError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_job_name_empty_error(); return error; } Error Error::JobNameNotEqualError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_job_name_not_equal_error(); return error; } Error Error::NoJobBuildAndInferCtxError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_no_job_build_and_infer_ctx_error(); return error; } Error Error::JobConfFrozenError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_job_conf_frozen_error(); return error; } Error Error::JobConfNotSetError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_job_conf_not_set_error(); return error; } Error Error::JobConfRepeatedSetError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_job_conf_repeated_set_error(); return error; } Error Error::JobTypeNotSetError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_job_type_not_set_error(); return error; } Error Error::LogicalBlobNameNotExistError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_logical_blob_name_not_exist_error(); return error; } Error Error::LogicalBlobNameExistError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_logical_blob_name_exist_error(); return error; } Error Error::LogicalBlobNameInvalidError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_logical_blob_name_invalid_error(); return error; } Error Error::OpNameExistError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_op_name_exist_error(); return error; } Error Error::OpConfDeviceTagNoSetError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_op_conf_device_tag_no_set_error(); return error; } Error Error::PlacementError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_placement_error(); return error; } Error Error::BlobSplitAxisInferError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_blob_split_axis_infer_error(); return error; } Error Error::UnknownJobBuildAndInferError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_unknown_job_build_and_infer_error(); return error; } Error Error::CheckFailedError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_check_failed_error(); return error; } Error Error::ValueNotFoundError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_value_not_found_error(); return error; } Error Error::TodoError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_todo_error(); return error; } Error Error::UnimplementedError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_unimplemented_error(); return error; } Error Error::RuntimeError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_runtime_error(); return 
error; } Error Error::OutOfMemoryError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_out_of_memory_error(); return error; } Error Error::BoxingNotSupportedError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_boxing_not_supported_error(); return error; } Error Error::OpKernelNotFoundError(const std::vector& error_msgs) { auto error = std::make_shared(); auto* op_kernel_not_found_error = error->mut_error_proto()->mutable_op_kernel_not_found_error(); for (const auto& msg : error_msgs) { op_kernel_not_found_error->add_op_kernels_not_found_debug_str(msg); } return error; } Error Error::MultipleOpKernelsMatchedError(const std::vector& error_msgs) { auto error = std::make_shared(); auto* multiple_op_kernels_matched_error = error->mut_error_proto()->mutable_multiple_op_kernels_matched_error(); for (const auto& msg : error_msgs) { multiple_op_kernels_matched_error->add_matched_op_kernels_debug_str(msg); } return error; } Error Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc, uint64_t available, const std::string& device_tag) { auto error = std::make_shared(); auto* memory_zone_out_of_memory_error = error->mut_error_proto()->mutable_memory_zone_out_of_memory_error(); memory_zone_out_of_memory_error->add_machine_id(std::to_string(machine_id)); memory_zone_out_of_memory_error->add_mem_zone_id(std::to_string(mem_zone_id)); memory_zone_out_of_memory_error->add_device_tag(device_tag); memory_zone_out_of_memory_error->add_available(std::to_string(available) + " bytes"); memory_zone_out_of_memory_error->add_required(std::to_string(calc) + " bytes"); return error; } Error Error::LossBlobNotFoundError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_loss_blob_not_found_error(); return error; } Error Error::RwMutexedObjectNotFoundError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_rw_mutexed_object_not_found_error(); return error; } Error Error::GradientFunctionNotFoundError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_gradient_function_not_found_error(); return error; } Error Error::SymbolIdUninitializedError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_symbol_id_uninitialized_error(); return error; } Error Error::CompileOptionWrongError() { auto error = std::make_shared(); error->mut_error_proto()->mutable_compile_option_wrong_error(); return error; } Error Error::InputDeviceNotMatchError() { auto error = std::make_shared(); auto* input_device_not_match_error = error->mut_error_proto()->mutable_input_device_not_match_error(); input_device_not_match_error->add_info( std::string("Input tensors are at different devices, please try to use tensor.to or " "module.to to correct it.")); return error; } std::string GetStackedErrorString(const std::shared_ptr& error) { const auto& maybe_error = TRY(FormatErrorStr(error)); const auto& error_str = maybe_error.GetDataAndStackedError(error->DebugString()); CHECK_NE(error->error_proto()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); return error_str.first; } std::string GetErrorString(const std::shared_ptr& error) { std::string error_str; if (IsInDebugMode()) { error_str = GetStackedErrorString(error); } else { error_str = error->error_proto()->msg(); } if (error_str.empty()) { error_str = ""; } return error_str; } void ThrowError(const std::shared_ptr& error) { std::string error_str; fmt::format_to(std::back_inserter(error_str), "{}: {}\n", fmt::styled("Error", 
fmt::emphasis::bold | fmt::fg(fmt::color::red)), GetErrorString(error)); // Append foreign stack trace (e.g. Python stack trace) when it is available. if (ForeignFrameThreadLocalGuard::Current().has_value()) { auto frame = *CHECK_JUST(ForeignFrameThreadLocalGuard::Current()); if (!IsMainThread()) { if (auto* stack_getter = Singleton::Get()) { fmt::format_to(std::back_inserter(error_str), fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange), "Related Python stack trace:"); if (IsPythonStackGetterEnabledByDebugBuild()) { fmt::format_to( std::back_inserter(error_str), " (You are seeing this stack trace because you compiled OneFlow with " "CMAKE_BUILD_TYPE=Debug. If you want to see it even with other CMAKE_BUILD_TYPEs, " "you can set ONEFLOW_DEBUG or ONEFLOW_PYTHON_STACK_GETTER to 1)"); } fmt::format_to(std::back_inserter(error_str), "\n{}", stack_getter->GetFormattedStack(frame)); } else { fmt::format_to( std::back_inserter(error_str), "You can set {} or {} to 1 to get the Python stack of the error.", fmt::styled("ONEFLOW_DEBUG", fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange)), fmt::styled("ONEFLOW_PYTHON_STACK_GETTER", fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange))); } } } *MutThreadLocalError() = error; if ((*error)->has_runtime_error()) { throw RuntimeException(error_str); } if ((*error)->has_type_error()) { throw TypeException(error_str); } if ((*error)->has_index_error()) { throw IndexException(error_str); } if ((*error)->has_unimplemented_error()) { throw NotImplementedException(error_str); } throw Exception(GetStackedErrorString(error)); } const std::shared_ptr& ThreadLocalError() { return *MutThreadLocalError(); } const char* kOfBugIssueUploadPrompt = "This is a oneflow bug, please submit an issue at " "'https://github.com/Oneflow-Inc/oneflow/issues' including " "the log information of the error, the " "minimum reproduction code, and the system information."; } // namespace oneflow ================================================ FILE: oneflow/core/common/error.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
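// ----------------------------------------------------------------------------
// A hedged sketch of the error-to-exception mapping implemented by
// ThrowError() in oneflow/core/common/error.cpp above: a runtime_error in the
// ErrorProto becomes RuntimeException, type_error becomes TypeException, and
// so on, falling back to the generic Exception. The function name below is
// hypothetical; the calls themselves are taken from the surrounding code.
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/exception.h"

namespace oneflow {

inline void ThrowErrorDemo() {
  Error error = Error::RuntimeError();
  error << "tensor shapes do not match";  // appended via operator<< (kMergeMessage)
  try {
    ThrowError(error.stacked_error());    // [[noreturn]]; picks RuntimeException here
  } catch (const RuntimeException& e) {
    (void)e.what();                       // formatted (and possibly colorized) message
  }
}

}  // namespace oneflow
// ----------------------------------------------------------------------------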
*/ #ifndef ONEFLOW_CORE_COMMON_ERROR_H_ #define ONEFLOW_CORE_COMMON_ERROR_H_ #include #include #include #include #include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/check.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/common/hash.h" namespace { std::string RemoveProjectPathPrefix(const std::string& filename) { #if defined(ONEFLOW_SOURCE_DIR) && defined(ONEFLOW_BINARY_DIR) std::string project_path = ONEFLOW_SOURCE_DIR; std::string project_build_path = ONEFLOW_BINARY_DIR; if (filename.rfind(project_build_path, 0) == 0) { return std::filesystem::relative(filename, project_build_path); } else if (filename.rfind(project_path, 0) == 0) { return std::filesystem::relative(filename, project_path); } else { return filename; } #else return filename; #endif } } // namespace namespace oneflow { class ErrorStackFrame final { public: ErrorStackFrame(const ErrorStackFrame&) = default; ErrorStackFrame(const std::string& file, int64_t line, const std::string& function) : file_(RemoveProjectPathPrefix(file)), line_(line), function_(function), code_text_() {} ErrorStackFrame(const std::string& file, int64_t line, const std::string& function, const std::string& code_text) : file_(RemoveProjectPathPrefix(file)), line_(line), function_(function), code_text_(code_text) {} bool operator==(const ErrorStackFrame& other) const { return this->file_ == other.file_ && this->line_ == other.line_ && this->function_ == other.function_ && this->code_text_ == other.code_text_; } const std::string& file() const { return file_; } int64_t line() const { return line_; } const std::string& function() const { return function_; } const std::string& code_text() const { return code_text_; } std::string DebugString() const { return file_ + ":" + std::to_string(line_) + " " + function_ + "\n\t" + code_text_ + "\n"; } private: std::string file_; int64_t line_; std::string function_; std::string code_text_; }; } // namespace oneflow namespace std { template<> struct hash<::oneflow::ErrorStackFrame> final { size_t operator()(const ::oneflow::ErrorStackFrame& frame) const { using namespace oneflow; return Hash(frame.file(), frame.line(), frame.function(), frame.code_text()); } }; } // namespace std namespace oneflow { class StackedError final { public: StackedError(); StackedError(const StackedError&) = default; constexpr static int kStackReservedSize = 16; using FrameVector = small_vector, kStackReservedSize>; const ErrorProto* operator->() const { return error_proto().get(); } ErrorProto* operator->() { return mut_error_proto(); } // Getters const FrameVector& stack_frame() const { return stack_frame_; } const std::shared_ptr& error_proto() const { return error_proto_; } std::string DebugString() const { std::string str; for (const auto& frame : stack_frame()) { str += frame->DebugString() + "\n"; } str += error_proto()->DebugString(); return str; } // Setters void add_stack_frame(Symbol error_frame) { stack_frame_.push_back(error_frame); } ErrorProto* mut_error_proto() { return const_cast(error_proto_.get()); } private: FrameVector stack_frame_; std::shared_ptr error_proto_; }; std::string GetErrorString(const std::shared_ptr& error); class Error final { public: Error(const std::shared_ptr& stacked_error) : stacked_error_(stacked_error), msg_collecting_mode_(kMergeMessage) {} Error(const Error&) = default; ~Error() = default; std::shared_ptr stacked_error() const { return stacked_error_; } const ErrorProto* operator->() const { return 
stacked_error_->error_proto().get(); } ErrorProto* operator->() { return stacked_error_->mut_error_proto(); } operator std::string() const; void Assign(const Error& other) { stacked_error_ = other.stacked_error_; } void Merge(const Error& other); Error&& AddStackFrame(Symbol error_stack_frame); Error&& GetStackTrace(int64_t depth = 32, int64_t skip_n_firsts = 2); static Error Ok(); static Error ProtoParseFailedError(); static Error JobSetEmptyError(); static Error DeviceTagNotFoundError(); static Error InvalidValueError(); static Error IndexError(); static Error TypeError(); static Error TimeoutError(); static Error JobNameExistError(); static Error JobNameEmptyError(); static Error JobNameNotEqualError(); static Error NoJobBuildAndInferCtxError(); static Error JobConfFrozenError(); static Error JobConfNotSetError(); static Error JobConfRepeatedSetError(); static Error JobTypeNotSetError(); static Error LogicalBlobNameNotExistError(); static Error LogicalBlobNameExistError(); static Error LogicalBlobNameInvalidError(); static Error OpNameExistError(); static Error OpConfDeviceTagNoSetError(); static Error PlacementError(); static Error BlobSplitAxisInferError(); static Error UnknownJobBuildAndInferError(); static Error CheckFailedError(); static Error ValueNotFoundError(); static Error TodoError(); static Error UnimplementedError(); static Error RuntimeError(); static Error OutOfMemoryError(); static Error BoxingNotSupportedError(); static Error MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc, uint64_t available, const std::string& device_type); static Error OpKernelNotFoundError(const std::vector& error_msgs); static Error MultipleOpKernelsMatchedError(const std::vector& error_msgs); static Error LossBlobNotFoundError(); static Error RwMutexedObjectNotFoundError(); // gradient static Error GradientFunctionNotFoundError(); // symbol static Error SymbolIdUninitializedError(); static Error CompileOptionWrongError(); static Error InputDeviceNotMatchError(); enum MsgCollectingMode { kInvalidMsgCollectingMode = 0, kMergeMessage, kOverrideThenMergeMessage, }; MsgCollectingMode msg_collecting_mode() const { return msg_collecting_mode_; } void set_msg_collecting_mode(MsgCollectingMode val) { msg_collecting_mode_ = val; } private: std::shared_ptr stacked_error_; MsgCollectingMode msg_collecting_mode_; }; [[noreturn]] void ThrowError(const std::shared_ptr& error); const std::shared_ptr& ThreadLocalError(); inline Error& operator<<(Error& error, Error::MsgCollectingMode mode) { error.set_msg_collecting_mode(mode); return error; } template Error& operator<<(Error& error, const T& x) { std::ostringstream ss; ss << x; if (error.msg_collecting_mode() == Error::kMergeMessage) { error->set_msg(error->msg() + ss.str()); } else if (error.msg_collecting_mode() == Error::kOverrideThenMergeMessage) { error->set_msg(ss.str()); error.set_msg_collecting_mode(Error::kMergeMessage); } else { GLOGLOGFATAL("UNIMPLEMENTED"); } return error; } // r-value reference is used to supporting expressions like `Error() << "invalid value"` template Error&& operator<<(Error&& error, const T& x) { error << x; return std::move(error); } template<> inline Error&& operator<<(Error&& error, const std::stringstream& x) { error << x.str(); return std::move(error); } template<> inline Error&& operator<<(Error&& error, const std::ostream& x) { error << x.rdbuf(); return std::move(error); } template<> inline Error&& operator<<(Error&& error, const Error& other) { error.Merge(other); return 
std::move(error); } // handle CHECK_OR_THROW(expr) << ... << std::endl; inline Error&& operator<<(Error&& error, std::ostream& (*os)(std::ostream&)) { error << os; return std::move(error); } extern const char* kOfBugIssueUploadPrompt; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ERROR_H_ ================================================ FILE: oneflow/core/common/error.proto ================================================ syntax = "proto2"; package oneflow; message FieldValue { required string field = 1; required string value = 2; } enum OpcodeType { kInvalidCompareType = 0; kEq = 1; kNe = 2; kGt = 3; kGe = 4; kLt = 5; kLe = 6; } message OneFieldAssertError { required OpcodeType compare_type = 1; required FieldValue left = 2; required string right_value = 3; } message TwoFieldAssertError { required OpcodeType compare_type = 1; required FieldValue left = 2; required FieldValue right = 3; } message ConfigAssertFailedError { oneof oprand_type { OneFieldAssertError one_field_assert_error = 1; TwoFieldAssertError two_field_assert_error = 2; } } message ConfigResourceUnavailableError { required FieldValue field_value = 1; } message JobSetEmptyError { } message DeviceTagNotFoundError { } message JobNameExistError { } message JobNameEmptyError { } message JobNameNotEqualError { } message NoJobBuildAndInferCtxError { } message JobConfFrozenError { } message JobConfNotSetError { } message JobConfRepeatedSetError { } message JobTypeNotSetError { } message LogicalBlobNameNotExistError { } message LogicalBlobNameExistError { } message LogicalBlobNameInvalidError { } message OpNameExistError { } message OpConfDeviceTagNoSetError { } message PlacementError { } message BlobSplitAxisInferError { } message UnknownJobBuildAndInferError { } message ProtoParseFailedError { } message CheckFailedError { } message TodoError { } message UnimplementedError { } message RuntimeError { } message OutOfMemoryError { } message BoxingNotSupportedError { } message GradientFunctionNotFoundError { } message OpKernelNotFoundError { repeated string op_kernels_not_found_debug_str = 1; } message MultipleOpKernelsMatchedError { repeated string matched_op_kernels_debug_str = 1; } message MemoryZoneOutOfMemoryError { repeated string machine_id = 1; repeated string mem_zone_id = 2; repeated string device_tag = 3; repeated string required = 4; repeated string available = 5; } message LossBlobNotFoundError { } message RwMutexedObjectNotFoundError { } message UnknownError { } message CompileOptionWrongError { } message InputDeviceNotMatchError { repeated string info = 1; } message SymbolIdUninitializedError {} message InvalidValueError {} message IndexError {} message TypeError {} message TimeoutError {} message ValueNotFoundError {} message ErrorProto { optional string msg = 1 [default = ""]; optional string frame_msg = 2 [default = ""]; oneof error_type { ConfigAssertFailedError config_assert_failed_error = 12; ConfigResourceUnavailableError config_resource_unavailable_error = 13; ProtoParseFailedError proto_parse_failed_error = 15; CheckFailedError check_failed_error = 16; TodoError todo_error = 17; UnimplementedError unimplemented_error = 18; BoxingNotSupportedError boxing_not_supported_error = 19; GradientFunctionNotFoundError gradient_function_not_found_error = 20; OpKernelNotFoundError op_kernel_not_found_error = 21; MultipleOpKernelsMatchedError multiple_op_kernels_matched_error = 22; MemoryZoneOutOfMemoryError memory_zone_out_of_memory_error = 23; LossBlobNotFoundError loss_blob_not_found_error = 24; JobSetEmptyError 
job_set_empty_error = 25; DeviceTagNotFoundError device_tag_not_found_error = 26; InvalidValueError invalid_value_error = 27; IndexError index_error = 28; TypeError type_error = 29; RuntimeError runtime_error = 30; OutOfMemoryError out_of_memory_error = 32; TimeoutError timeout_error = 40; ValueNotFoundError value_not_found_error = 31; JobNameExistError job_name_exist_error = 100; JobNameEmptyError job_name_empty_error = 101; JobNameNotEqualError job_name_not_equal_error = 102; NoJobBuildAndInferCtxError no_job_build_and_infer_ctx_error = 200; JobConfFrozenError job_conf_frozen_error = 300; JobConfNotSetError job_conf_not_set_error = 301; JobConfRepeatedSetError job_conf_repeated_set_error = 302; JobTypeNotSetError job_type_not_set_error = 303; LogicalBlobNameNotExistError logical_blob_name_not_exist_error = 400; LogicalBlobNameExistError logical_blob_name_exist_error = 401; LogicalBlobNameInvalidError logical_blob_name_invalid_error = 402; OpNameExistError op_name_exist_error = 450; OpConfDeviceTagNoSetError op_conf_device_tag_no_set_error = 460; PlacementError placement_error= 470; BlobSplitAxisInferError blob_split_axis_infer_error = 480; UnknownJobBuildAndInferError unknown_job_build_and_infer_error = 500; RwMutexedObjectNotFoundError rw_mutexed_object_not_found_error = 600; SymbolIdUninitializedError symbol_id_uninitialized_error = 700; UnknownError unknown_error = 900; CompileOptionWrongError compile_option_wrong_error = 950; InputDeviceNotMatchError input_device_not_match_error = 1000; } } ================================================ FILE: oneflow/core/common/error_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
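// ----------------------------------------------------------------------------
// A hedged sketch of how an Error from oneflow/core/common/error.h above is
// annotated with a stack frame: StackedError keeps a small_vector of
// Symbol<ErrorStackFrame> next to the ErrorProto, and AddStackFrame() is how
// call sites (for example the JUST macro later in this section) attach their
// location. The helper name below is hypothetical.
#include "oneflow/core/common/error.h"

namespace oneflow {

inline Error AnnotatedCheckFailed() {
  return Error::CheckFailedError().AddStackFrame(
             SymbolOf(ErrorStackFrame(__FILE__, __LINE__, __FUNCTION__)))
         << "expected a non-empty job name";
}

}  // namespace oneflow
// ----------------------------------------------------------------------------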
*/ #include #include "oneflow/core/common/error_util.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/graph_scope_vars.h" namespace oneflow { namespace { std::string StripSpace(std::string str) { if (str.size() == 0) { return ""; } size_t pos = str.find_first_not_of(" "); if (pos != std::string::npos) { str.erase(0, pos); } pos = str.find_last_not_of(" "); if (pos != std::string::npos) { str.erase(pos + 1); } return str; } bool IsLetterNumberOrUnderline(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_'); } Maybe ShortenMsg(std::string str) { // 150 characters is the threshold const int num_character_threshold = 150; const int num_displayed_character = 50; if (str.size() == 0) { return str; } // strip space when JUST( xx ); str = StripSpace(str); if (str.size() < num_character_threshold) { return str; } // left part whose number of characters is just over 50 int left_index = num_displayed_character; bool pre_condition = IsLetterNumberOrUnderline(str.at(left_index)); for (; left_index < str.size(); left_index++) { bool cur_condition = IsLetterNumberOrUnderline(str.at(left_index)); if ((pre_condition && !cur_condition) || (!pre_condition && cur_condition)) { break; } } // right part whose number of characters is just over 50 int right_index = str.size() - num_displayed_character; pre_condition = IsLetterNumberOrUnderline(str.at(right_index)); for (; right_index >= 0; right_index--) { bool cur_condition = IsLetterNumberOrUnderline(str.at(right_index)); if ((pre_condition && !cur_condition) || (!pre_condition && cur_condition)) { right_index++; break; } } // a long word of more than 150 if (right_index - left_index < 50) { return str; } std::stringstream ss; CHECK_OR_RETURN(left_index >= 0); CHECK_OR_RETURN(left_index < str.size()); ss << str.substr(0, left_index); ss << " ... "; CHECK_OR_RETURN(right_index >= 0); CHECK_OR_RETURN(right_index < str.size()); ss << str.substr(right_index); return ss.str(); } // file info in stack frame std::string FormatFileOfStackFrame(const std::string& file) { std::stringstream ss; ss << "\n File \"" << file << "\", "; return ss.str(); } // line info in stack frame std::string FormatLineOfStackFrame(const int64_t& line) { std::stringstream ss; if (line >= 0) { ss << "line " << line << ","; } else { ss << "line ,"; } return ss.str(); } // function info in stack frame std::string FormatFunctionOfStackFrame(const std::string& function) { std::stringstream ss; ss << " in " << function; return ss.str(); } // msg in stack frame Maybe FormatMsgOfStackFrame(std::string error_msg, bool is_last_stack_frame) { const bool debug_mode = GetGraphDebugMode(); // only shorten the message if it is not the last stack frame AND not in debug mode if (!is_last_stack_frame && !debug_mode) { error_msg = *JUST(ShortenMsg(error_msg)); } // error_msg of last stack frame come from "<<" if (is_last_stack_frame) { error_msg = StripSpace(error_msg); } std::stringstream ss; if (!error_msg.empty()) { ss << "\n " << error_msg; } return ss.str(); } // the msg in error type instance. 
Maybe FormatMsgOfErrorType(const std::shared_ptr& error) { const auto& error_proto = error->error_proto(); CHECK_NE_OR_RETURN(error_proto->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET) << Error::RuntimeError() << "Parse error failed, unknown error type"; std::stringstream ss; const google::protobuf::Descriptor* error_des = error_proto->GetDescriptor(); const google::protobuf::OneofDescriptor* oneof_field_des = error_des->FindOneofByName("error_type"); const google::protobuf::Reflection* error_ref = error_proto->GetReflection(); const google::protobuf::FieldDescriptor* field_des = error_ref->GetOneofFieldDescriptor(*error_proto, oneof_field_des); CHECK_OR_RETURN(field_des != nullptr); ss << "Error Type: " << field_des->full_name(); return ss.str(); } } // namespace Maybe FormatErrorStr(const std::shared_ptr& error) { std::stringstream ss; ss << error->error_proto()->msg(); ss << error->error_proto()->frame_msg(); // Get msg from stack frame of error proto for (auto iter = error->stack_frame().rbegin(); iter < error->stack_frame().rend(); iter++) { auto stack_frame = *iter; ss << FormatFileOfStackFrame(stack_frame->file()) << FormatLineOfStackFrame(stack_frame->line()) << FormatFunctionOfStackFrame(stack_frame->function()) << *JUST(FormatMsgOfStackFrame(stack_frame->code_text(), iter == error->stack_frame().rend() - 1)); } // Get msg from error type of error proto std::string msg_of_error_type = *JUST(FormatMsgOfErrorType(error)); if (msg_of_error_type.size() != 0) { ss << "\n" << msg_of_error_type; } return ss.str(); } } // namespace oneflow ================================================ FILE: oneflow/core/common/error_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ERROR_UTIL_H #define ONEFLOW_CORE_COMMON_ERROR_UTIL_H #include #include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/maybe.h" namespace oneflow { Maybe FormatErrorStr(const std::shared_ptr& error); } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ERROR_UTIL_H ================================================ FILE: oneflow/core/common/exception.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
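// ----------------------------------------------------------------------------
// A hedged sketch of calling FormatErrorStr() from error_util.cpp above
// directly. It walks the recorded stack frames from the outermost call inward,
// prints each as a Python-style "File ..., line ..., in ..." entry, shortens
// long intermediate messages to roughly 50 leading and 50 trailing characters,
// and appends the "Error Type:" line derived from the oneof in error.proto.
// The wrapper name is hypothetical; the fallback mirrors GetStackedErrorString.
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/error_util.h"
#include "oneflow/core/common/just.h"

namespace oneflow {

inline std::string FormatDemo(const std::shared_ptr<StackedError>& stacked_error) {
  const auto& maybe_str = TRY(FormatErrorStr(stacked_error));
  // Yields the formatted text, or the raw DebugString() if formatting failed.
  return maybe_str.GetDataAndStackedError(stacked_error->DebugString()).first;
}

}  // namespace oneflow
// ----------------------------------------------------------------------------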
*/ #ifndef ONEFLOW_CORE_COMMON_EXCEPTION_H_ #define ONEFLOW_CORE_COMMON_EXCEPTION_H_ #include #include namespace oneflow { class Exception : public std::exception { public: explicit Exception(const std::string& what) : what_(what) {} virtual ~Exception() = default; const char* what() const noexcept override { return what_.c_str(); } private: std::string what_; }; class RuntimeException : public Exception { public: using Exception::Exception; }; class TypeException : public Exception { public: using Exception::Exception; }; class IndexException : public Exception { public: using Exception::Exception; }; class NotImplementedException : public Exception { public: using Exception::Exception; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_EXCEPTION_H_ ================================================ FILE: oneflow/core/common/flat_shape.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/flat_shape.h" #include "oneflow/core/common/shape.h" namespace oneflow { /*static*/ Maybe FlatShape::New(const Shape& shape) { const auto& flat_shape = std::make_shared(); JUST(flat_shape->Init(shape)); return flat_shape; } Maybe FlatShape::Init(const Shape& shape) { CHECK_LE_OR_RETURN(shape.NumAxes(), SHAPE_MAX_AXIS_SIZE); this->clear_dim(); for (int i = 0; i < shape.NumAxes(); ++i) { *this->mutable_dim()->Add() = shape.At(i); } return Maybe::Ok(); } Maybe FlatShape::Check(const Shape& shape) const { CHECK_EQ_OR_RETURN(this->dim_size(), shape.NumAxes()) << Error::RuntimeError() << "Expected same shape on each rank, but found at least two shapes, " << JUST(ToShape())->ToString() << " and " << shape.ToString() << "!"; for (int i = 0; i < this->dim_size(); ++i) { CHECK_EQ_OR_RETURN(this->dim(i), shape.At(i)); } return Maybe::Ok(); } Maybe FlatShape::Check(const FlatShape& flat_shape) const { CHECK_EQ_OR_RETURN(this->dim_size(), flat_shape.NumAxes()) << Error::RuntimeError() << "Expected input of each rank must have the same size, but got at least two size, " << JUST(ToShape())->ToString() << " and " << JUST(flat_shape.ToShape())->ToString(); for (int i = 0; i < this->dim_size(); ++i) { CHECK_EQ_OR_RETURN(this->dim(i), flat_shape.At(i)) << Error::RuntimeError() << "Expected input of each rank must have the same size, but got at least two size, " << JUST(ToShape())->ToString() << " and " << JUST(flat_shape.ToShape())->ToString(); } return Maybe::Ok(); } Maybe FlatShape::ToShape() const { const auto& shape = std::make_shared(); JUST(ToShape(shape.get())); return shape; } Maybe FlatShape::ToShape(Shape* shape) const { DimVector dim_vec; for (int i = 0; i < this->dim_size(); ++i) { dim_vec.emplace_back(this->dim(i)); } *shape = Shape(dim_vec); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/common/flat_shape.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
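// ----------------------------------------------------------------------------
// A minimal sketch of the pattern used by exception.h above: every concrete
// exception category simply reuses the base (const std::string&) constructor
// via inheriting constructors, so a new category would be added the same way.
// The class name below is hypothetical, not part of the real hierarchy.
#include "oneflow/core/common/exception.h"

namespace oneflow {

class DeviceMismatchException : public Exception {
 public:
  using Exception::Exception;  // keeps the (const std::string& what) constructor
};

}  // namespace oneflow
// ----------------------------------------------------------------------------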
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_ #define ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_ #include #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape_vec.h" namespace oneflow { class Shape; // clang-format off FLAT_MSG_BEGIN(FlatShape); public: // Methods static Maybe New(const Shape& shape); Maybe Init(const Shape& shape); Maybe Check(const Shape& shape) const; Maybe Check(const FlatShape& flat_shape) const; Maybe ToShape() const; Maybe ToShape(Shape* shape) const; int64_t At(int i) const { return dim(i); } int64_t NumAxes() const { return dim_size(); } // Fields FLAT_MSG_DEFINE_REPEATED(int64_t, dim, SHAPE_MAX_AXIS_SIZE); FLAT_MSG_END(FlatShape); // clang-format on } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_ ================================================ FILE: oneflow/core/common/foreign_lock_helper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/foreign_lock_helper.h" #include "oneflow/core/common/singleton.h" namespace oneflow { class NoForeignLockHelper final : public ForeignLockHelper { Maybe WithScopedRelease(const std::function()>& Callback) const override { return Callback(); } Maybe WithScopedAcquire(const std::function()>& Callback) const override { return Callback(); } }; static int __register_no_foreign_lock_helper __attribute__((unused)) = []() { Singleton::SetAllocated(new NoForeignLockHelper()); return 0; }(); } // namespace oneflow ================================================ FILE: oneflow/core/common/foreign_lock_helper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
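// ----------------------------------------------------------------------------
// A hedged sketch of FlatShape (flat_shape.h / flat_shape.cpp above): a
// fixed-capacity, flat-message mirror of Shape that is cheap to ship between
// ranks, with Check() as the cross-rank consistency test. The demo function
// name is hypothetical; the unwrapping of Maybe via JUST follows the usage in
// the surrounding sources.
#include "oneflow/core/common/flat_shape.h"
#include "oneflow/core/common/shape.h"

namespace oneflow {

inline Maybe<void> FlatShapeDemo() {
  Shape local_shape({2, 3, 4});
  const auto& flat = JUST(FlatShape::New(local_shape));
  JUST(flat->Check(local_shape));  // identical dims on this rank -> Ok
  CHECK_EQ_OR_RETURN(flat->NumAxes(), 3);
  CHECK_EQ_OR_RETURN(flat->At(1), 3);
  return Maybe<void>::Ok();
}

}  // namespace oneflow
// ----------------------------------------------------------------------------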
*/ #ifndef ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H #define ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H #include #include "oneflow/core/common/maybe.h" namespace oneflow { class ForeignLockHelper { public: virtual ~ForeignLockHelper() = default; virtual Maybe WithScopedRelease(const std::function()>&) const = 0; virtual Maybe WithScopedAcquire(const std::function()>&) const = 0; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H ================================================ FILE: oneflow/core/common/function_traits.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_ #define ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_ #include namespace oneflow { template using void_t = void; template struct function_traits; template struct function_traits { using func_type = Ret(Args...); using return_type = Ret; using args_type = std::tuple; template using arg_type = typename std::tuple_element::type; static constexpr size_t nargs = sizeof...(Args); }; template struct function_traits { using func_type = Ret(Args...); using return_type = Ret; using args_type = std::tuple; template using arg_type = typename std::tuple_element::type; static constexpr size_t nargs = sizeof...(Args); }; template struct function_traits { using func_type = Ret(Args...); using return_type = Ret; using args_type = std::tuple; template using arg_type = typename std::tuple_element::type; static constexpr size_t nargs = sizeof...(Args); }; template struct function_traits { using func_type = Ret(Args...); using return_type = Ret; using args_type = std::tuple; template using arg_type = typename std::tuple_element::type; static constexpr size_t nargs = sizeof...(Args); }; template struct function_traits> : public function_traits {}; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_ ================================================ FILE: oneflow/core/common/hash.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_HASH_H_ #define ONEFLOW_CORE_COMMON_HASH_H_ #include #include namespace oneflow { inline size_t HashCombine(size_t lhs, size_t rhs) { return lhs ^ (rhs + 0x9e3779b9 + (lhs << 6U) + (lhs >> 2U)); } inline void HashCombine(size_t* seed, size_t hash) { *seed = HashCombine(*seed, hash); } template inline void AddHash(size_t* seed, const T&... 
v) { (HashCombine(seed, std::hash()(v)), ...); } template inline size_t Hash(const T& v1, const Ts&... vn) { size_t seed = std::hash()(v1); AddHash(&seed, vn...); return seed; } } // namespace oneflow namespace std { template struct hash> { std::size_t operator()(const std::pair& p) const { return oneflow::Hash(p.first, p.second); } }; template struct hash> { std::size_t operator()(const std::vector& vec) const { std::size_t hash_value = vec.size(); for (const auto& elem : vec) { oneflow::AddHash(&hash_value, elem); } return hash_value; } }; template struct hash> { size_t operator()(const std::complex& c) const { return oneflow::Hash(c.real(), c.imag()); } }; } // namespace std #endif // ONEFLOW_CORE_COMMON_HASH_H_ ================================================ FILE: oneflow/core/common/hash_container.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_HASH_CONTAINER_ #define ONEFLOW_CORE_COMMON_HASH_CONTAINER_ #include #include namespace oneflow { template> using HashMap = std::unordered_map; template> using HashSet = std::unordered_set; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_HASH_CONTAINER_ ================================================ FILE: oneflow/core/common/hash_eq_trait_ptr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_ #define ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_ namespace oneflow { template class HashEqTraitPtr final { public: HashEqTraitPtr(const HashEqTraitPtr&) = default; HashEqTraitPtr(T* ptr, size_t hash_value) : ptr_(ptr), hash_value_(hash_value) {} ~HashEqTraitPtr() = default; T* ptr() const { return ptr_; } size_t hash_value() const { return hash_value_; } bool operator==(const HashEqTraitPtr& rhs) const { return *ptr_ == *rhs.ptr_; } private: T* ptr_; size_t hash_value_; }; } // namespace oneflow namespace std { template struct hash> final { size_t operator()(const oneflow::HashEqTraitPtr& ptr) const { return ptr.hash_value(); } }; } // namespace std #endif // ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_ ================================================ FILE: oneflow/core/common/high_order_bool.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
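// ----------------------------------------------------------------------------
// A minimal sketch of Hash()/HashCombine() from hash.h above, which fold
// several std::hash values into one seed, and of the std::hash specializations
// that make std::pair and std::vector usable as keys for the HashMap/HashSet
// aliases from hash_container.h. The names in the demo are made up.
#include <string>
#include <utility>
#include "oneflow/core/common/hash.h"
#include "oneflow/core/common/hash_container.h"

namespace {

inline void HashDemo() {
  // Variadic combine over heterogeneous fields.
  size_t h = oneflow::Hash(std::string("conv2d"), int64_t(3), true);
  (void)h;
  // std::pair keys work out of the box thanks to the hash specialization above.
  oneflow::HashMap<std::pair<int, int>, std::string> edge2name;
  edge2name[{0, 1}] = "producer->consumer";
}

}  // namespace
// ----------------------------------------------------------------------------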
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_ #define ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_ #include #include #include #include #include #include "oneflow/core/common/function_traits.h" #include "oneflow/core/common/type_traits.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace hob { template struct BaseExpr { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wnon-virtual-dtor" // NOTE: Performance will be degraded if the destructor is virtual. // So please do NOT implement custom destructor in any child classes of BaseExpr, // and every fields of child classes should be of POD type. ~BaseExpr() = default; #pragma GCC diagnostic pop ALWAYS_INLINE virtual scalar_or_const_ref_t get(const Context&) const = 0; virtual std::string DebugStr(const Context&, bool display_result = true) const = 0; // NOLINT operator bool() = delete; }; template struct Expr : public BaseExpr { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wnon-virtual-dtor" ~Expr() = default; #pragma GCC diagnostic pop }; template struct Literal final : public Expr> { Literal(const ValueT& val) : Literal(ToString(val), val) {} // NOLINT Literal(const std::string& debug_str, const ValueT& val) : val_(val), debug_str_(debug_str) {} ALWAYS_INLINE scalar_or_const_ref_t get(const Context&) const override { return val_; } std::string DebugStr(const Context&, bool display_result) const override { return debug_str_; } private: ValueT val_; std::string debug_str_; }; template using LiteralBool = Literal; template::template arg_type<0>>, typename ValueT = std::decay_t::return_type>> struct Custom final : public Expr> { explicit Custom(Fn fn) : Custom("", fn) {} Custom(std::string debug_str, Fn fn) : fn_(std::move(fn)), debug_str_(std::move(debug_str)) {} ALWAYS_INLINE scalar_or_const_ref_t get(const Context& context) const override { return fn_(context); } std::string DebugStr(const Context&, bool display_result) const override { return debug_str_; } private: Fn fn_; std::string debug_str_; }; template ALWAYS_INLINE inline Custom make_custom(Fn fn) { return Custom(std::forward(fn)); } template ALWAYS_INLINE inline Custom make_custom(const std::string& debug_str, Fn fn) { return Custom(debug_str, std::forward(fn)); } template using BoolExpr = Expr; template struct NotBoolFunctor final : public BoolExpr> { explicit NotBoolFunctor(const E& expr) : expr_(expr) {} ALWAYS_INLINE bool get(const Context& context) const override { return !expr_.get(context); } std::string DebugStr(const Context& ctx, bool display_result) const override { std::ostringstream string_stream; string_stream << "(" << "not " << expr_.DebugStr(ctx, display_result) << ")"; return string_stream.str(); } private: const E expr_; }; template NotBoolFunctor operator!(BoolExpr const& lhs) { return NotBoolFunctor(*static_cast(&lhs)); } #define DEFINE_BINARY_FUNCTOR(name, op) \ template \ struct name##BoolFunctor final : public BoolExpr> { \ name##BoolFunctor(const E1& lhs, const E2& rhs) : lhs_(lhs), rhs_(rhs) {} \ \ ALWAYS_INLINE bool get(const Context& context) const override; \ \ std::string DebugStr(const Context& ctx, bool display_result) const 
override; \ \ private: \ const E1 lhs_; \ const E2 rhs_; \ }; \ \ template \ name##BoolFunctor operator op(Expr const& lhs, \ Expr const& rhs) { \ return name##BoolFunctor(*static_cast(&lhs), \ *static_cast(&rhs)); \ } \ \ template \ name##BoolFunctor> operator op( \ Expr const& lhs, ValueT const& rhs) { \ return name##BoolFunctor>( \ *static_cast(&lhs), Literal(rhs)); \ } DEFINE_BINARY_FUNCTOR(Equal, ==) DEFINE_BINARY_FUNCTOR(And, &&) DEFINE_BINARY_FUNCTOR(Or, ||) DEFINE_BINARY_FUNCTOR(Greater, >) DEFINE_BINARY_FUNCTOR(Less, <) DEFINE_BINARY_FUNCTOR(EqualOrGreater, >=) DEFINE_BINARY_FUNCTOR(EqualOrLess, <=) #undef DEFINE_BINARY_FUNCTOR #define DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(name, op) \ template \ ALWAYS_INLINE inline bool name##BoolFunctor::get(const Context& context) \ const { \ return lhs_.get(context) op rhs_.get(context); \ } \ template \ std::string name##BoolFunctor::DebugStr(const Context& ctx, \ bool display_result) const { \ std::string l_str = lhs_.DebugStr(ctx, display_result); \ std::string r_str = rhs_.DebugStr(ctx, display_result); \ std::ostringstream string_stream; \ string_stream << "(" << l_str << " " << OF_PP_STRINGIZE(op) << " " << r_str << ")"; \ return string_stream.str(); \ } DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Equal, ==) DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Greater, >) DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Less, <) DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(EqualOrGreater, >=) DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(EqualOrLess, <=) #undef DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS template ALWAYS_INLINE inline bool AndBoolFunctor::get(const Context& context) const { bool lhs_result = lhs_.get(context); if (!lhs_result) { return false; } return rhs_.get(context); } template std::string AndBoolFunctor::DebugStr(const Context& ctx, bool display_result) const { std::string l_str = lhs_.DebugStr(ctx, display_result); display_result = display_result && lhs_.get(ctx); std::string r_str = rhs_.DebugStr(ctx, display_result); std::ostringstream string_stream; string_stream << "(" << l_str << " and " << r_str << ")"; return string_stream.str(); } template ALWAYS_INLINE inline bool OrBoolFunctor::get(const Context& context) const { bool lhs_result = lhs_.get(context); if (lhs_result) { return true; } return rhs_.get(context); } template std::string OrBoolFunctor::DebugStr(const Context& ctx, bool display_result) const { std::string l_str = lhs_.DebugStr(ctx, display_result); display_result = display_result && (!lhs_.get(ctx)); std::string r_str = rhs_.DebugStr(ctx, display_result); std::ostringstream string_stream; string_stream << "(" << l_str << " or " << r_str << ")"; return string_stream.str(); } template EqualBoolFunctor> operator==( Expr const& lhs, const char* rhs) { return EqualBoolFunctor>( *static_cast(&lhs), Literal(rhs)); } } // namespace hob } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_ ================================================ FILE: oneflow/core/common/just.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
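// Usage sketch (illustrative, not part of the original sources): composing a small
// predicate with the hob expression templates from high_order_bool.h above.
// `KernelCtx`, its fields, and the lambdas are hypothetical stand-ins for the kernel
// registration contexts this machinery is normally evaluated against.
#include <iostream>
#include <string>

struct KernelCtx {
  std::string device_tag;
  int input_size;
};

inline void HobSketch() {
  using namespace oneflow;
  // Comparing an expression with a string literal uses the const char* overload of operator==.
  auto on_cuda = hob::make_custom("device_tag", [](const KernelCtx& ctx) {
                   return ctx.device_tag;
                 }) == "cuda";
  auto big_input = hob::make_custom("input_size", [](const KernelCtx& ctx) {
                     return ctx.input_size;
                   }) > 4;
  // operator&& short-circuits: big_input is only evaluated when on_cuda holds.
  auto pred = on_cuda && big_input;
  KernelCtx ctx{"cuda", 8};
  std::cout << pred.DebugStr(ctx) << " -> " << pred.get(ctx) << std::endl;
}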
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_JUST_H_ #define ONEFLOW_CORE_COMMON_JUST_H_ #include #include #include "oneflow/core/common/error.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/preprocessor.h" namespace oneflow { template class Maybe; template class Optional; Maybe FormatErrorStr(const std::shared_ptr&); namespace { std::string GetFormatedSerializedError(const std::shared_ptr&); } namespace private_details { inline std::shared_ptr&& JustErrorAddStackFrame( std::shared_ptr&& err, Symbol error_stack_frame) { err->add_stack_frame(error_stack_frame); return std::move(err); } template Error&& AddFrameMessage(Error&& error, const T& x) { std::ostringstream ss; ss << x; error->set_frame_msg(error->frame_msg() + ss.str()); return std::move(error); } template<> inline Error&& AddFrameMessage(Error&& error, const std::stringstream& x) { AddFrameMessage(std::move(error), x.str()); return std::move(error); } template<> inline Error&& AddFrameMessage(Error&& error, const std::ostream& x) { AddFrameMessage(std::move(error), x.rdbuf()); return std::move(error); } template Error&& JustErrorAddFrameMessage(Error&& err, T&&... msg) { (AddFrameMessage(std::move(err), std::forward(msg)), ...); return std::move(err); } template bool JustIsOk(const Maybe& val) { return val.IsOk(); } template bool JustIsOk(const Optional& val) { return val.has_value(); } template std::shared_ptr JustGetError(const Maybe& val) { return val.stacked_error(); } template std::shared_ptr JustGetError(const Optional&) { return Error::ValueNotFoundError().stacked_error(); } template typename std::remove_const::type>::type&& RemoveRValConst( T&& v) noexcept { static_assert(std::is_rvalue_reference::value, "rvalue is expected here"); return const_cast::type>::type&&>(v); } } // namespace private_details } // namespace oneflow #define __JustStackCheckWrapper__(...) __VA_ARGS__ #define TRY(...) __JustStackCheckWrapper__(__VA_ARGS__) #if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__) #define JUST(...) \ ::oneflow::private_details::RemoveRValConst(({ \ auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ return ::oneflow::private_details::JustErrorAddStackFrame( \ ::oneflow::private_details::JustGetError(_just_value_to_check_), \ [](const char* function) { \ thread_local static auto frame = ::oneflow::SymbolOf( \ ::oneflow::ErrorStackFrame(__FILE__, __LINE__, function, #__VA_ARGS__)); \ return frame; \ }(__FUNCTION__)); \ } \ std::forward(_just_value_to_check_); \ })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() #define CHECK_JUST(...) \ ([&](const char* _just_closure_func_name_) { \ auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ thread_local static auto frame = ::oneflow::SymbolOf( \ ::oneflow::ErrorStackFrame(__FILE__, __LINE__, _just_closure_func_name_, #__VA_ARGS__)); \ THROW(RuntimeError) << ::oneflow::GetErrorString( \ ::oneflow::private_details::JustErrorAddStackFrame( \ ::oneflow::private_details::JustGetError(_just_value_to_check_), frame)); \ } \ return std::forward(_just_value_to_check_); \ })(__FUNCTION__) \ .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() #define JUST_MSG(value, ...) 
\ ::oneflow::private_details::RemoveRValConst(({ \ auto&& _just_value_to_check_ = (value); \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ return ::oneflow::private_details::JustErrorAddFrameMessage( \ ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \ .AddStackFrame([](const char* function) { \ thread_local static auto frame = ::oneflow::SymbolOf( \ ::oneflow::ErrorStackFrame(__FILE__, __LINE__, function, #value)); \ return frame; \ }(__FUNCTION__)), \ "\nError message from " __FILE__, ":", __LINE__, "\n\t", #value, ": ", __VA_ARGS__, \ "\n"); \ } \ std::forward(_just_value_to_check_); \ })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() #define CHECK_JUST_MSG(value, ...) \ ([&](const char* _just_closure_func_name_) { \ auto&& _just_value_to_check_ = (value); \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ thread_local static auto frame = ::oneflow::SymbolOf( \ ::oneflow::ErrorStackFrame(__FILE__, __LINE__, _just_closure_func_name_, #value)); \ THROW(RuntimeError) << ::oneflow::GetErrorString( \ ::oneflow::private_details::JustErrorAddFrameMessage( \ ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \ .AddStackFrame(frame), \ "\nError message from " __FILE__, ":", __LINE__, "\n\t", #value, ": ", __VA_ARGS__, \ "\n") \ .stacked_error()); \ } \ return std::forward(_just_value_to_check_); \ })(__FUNCTION__) \ .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() #define JUST_OPT(...) \ ::oneflow::private_details::RemoveRValConst(({ \ auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ if (!_just_value_to_check_.has_value()) { return NullOpt; } \ std::forward(_just_value_to_check_); \ })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() #else #error statement expression is no supported, please implement try-catch version of JUST #endif #endif // ONEFLOW_CORE_COMMON_JUST_H_ ================================================ FILE: oneflow/core/common/layout_standardize.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_ #define ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_ namespace oneflow { template class LayoutStandardize final { public: void __Init__(const T& val) { new (&data_[0]) T(val); } void __Delete__() { Mutable()->~T(); } const T& Get() const { return *reinterpret_cast(&data_[0]); } T* Mutable() { return reinterpret_cast(&data_[0]); } private: union { char data_[sizeof(T)]; int64_t align_; }; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_ ================================================ FILE: oneflow/core/common/math_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
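// Usage sketch (illustrative, not part of the original sources): how the JUST and
// CHECK_JUST macros from just.h above propagate errors through Maybe-returning
// functions. `ParsePositive` and `Run` are hypothetical helpers; a sketch assuming
// oneflow/core/common/maybe.h (later in this directory) is also included.
#include <iostream>

namespace sketch {

oneflow::Maybe<int> ParsePositive(int x) {
  if (x <= 0) { return oneflow::Error::InvalidValueError() << "expected a positive value, got " << x; }
  return x;
}

// Inside a Maybe-returning function, JUST(...) unwraps the value on success and
// early-returns the error, with an extra stack frame appended, on failure.
oneflow::Maybe<int> Run(int x) {
  int v = JUST(ParsePositive(x));
  return v * 2;
}

inline void JustSketch() {
  // At an API boundary CHECK_JUST throws a RuntimeError instead of returning a Maybe.
  std::cout << CHECK_JUST(Run(21)) << std::endl;  // prints 42
}

}  // namespace sketch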
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "glog/logging.h" #include "oneflow/core/common/math_util.h" namespace oneflow { int64_t Gcd(int64_t m, int64_t n) { if (m < n) { std::swap(m, n); } if (n == 0) { return m; } CHECK_GT(m, 0); CHECK_GT(n, 0); return Gcd(n, m % n); } int64_t Lcm(int64_t m, int64_t n) { return m * n / Gcd(m, n); } } // namespace oneflow ================================================ FILE: oneflow/core/common/math_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_MATH_UTIL_H_ #define ONEFLOW_CORE_COMMON_MATH_UTIL_H_ #include #include "data_type.h" #include "oneflow/core/common/util.h" namespace oneflow { /* * math constants */ template constexpr T pi = static_cast(3.141592653589793238462643383279502); int64_t Gcd(int64_t m, int64_t n); int64_t Lcm(int64_t m, int64_t n); template OF_DEVICE_FUNC T DeviceMin(T a, T b) { #if defined(__CUDA_ARCH__) return a < b ? a : b; #else return std::min(a, b); #endif } template OF_DEVICE_FUNC T DeviceMax(T a, T b) { #if defined(__CUDA_ARCH__) return a > b ? a : b; #else return std::max(a, b); #endif } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_MATH_UTIL_H_ ================================================ FILE: oneflow/core/common/maybe.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
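// Usage sketch (illustrative, not part of the original sources): the small math
// helpers from math_util.h above. The concrete numbers are only illustrative.
#include <cassert>

inline void MathUtilSketch() {
  using namespace oneflow;
  assert(Gcd(12, 18) == 6);
  assert(Lcm(4, 6) == 12);
  // Gcd swaps its arguments when needed and treats 0 as the identity.
  assert(Gcd(0, 7) == 7);
  // DeviceMin/DeviceMax compile for both host and CUDA device code; on the host
  // they simply forward to std::min/std::max.
  assert(DeviceMin(3, 5) == 3);
  assert(DeviceMax(3, 5) == 5);
  // pi<T> is a variable template holding the constant at the requested precision.
  static_assert(pi<float> > 3.14f, "pi is defined in math_util.h");
}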
*/ #ifndef ONEFLOW_CORE_COMMON_MAYBE_H_ #define ONEFLOW_CORE_COMMON_MAYBE_H_ #include "oneflow/core/common/throw.h" #include #include "oneflow/core/common/type_traits.h" #include "oneflow/core/common/either_ptr.h" #include "oneflow/core/common/shared_or_scalar.h" #include "oneflow/core/common/error.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/just.h" namespace oneflow { template struct is_maybe { static const bool value = false; }; template struct is_maybe> { static const bool value = true; }; template class Maybe::value || IsScalarType::value) && !std::is_reference::value>::type> final { public: Maybe(const T& data) : data_or_error_(std::make_shared(data)) {} Maybe(T&& data) : data_or_error_(std::make_shared(std::move(data))) {} Maybe(const Error& error) : data_or_error_(error.stacked_error()) {} Maybe(const std::shared_ptr& data) : data_or_error_(data) {} Maybe(std::shared_ptr&& data) : data_or_error_(std::move(data)) {} Maybe(const std::shared_ptr& error) : data_or_error_(error) {} Maybe(const Maybe&) = default; Maybe(Maybe&& other) : data_or_error_(std::move(other.data_or_error_)) {} ~Maybe() = default; bool IsOk() const { return data_or_error_.template Has(); } std::shared_ptr Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const { return data_or_error_.template Get(); } std::shared_ptr stacked_error() const { return data_or_error_.template Get(); } std::shared_ptr error() const { return stacked_error()->error_proto(); } std::string GetSerializedError() const { CHECK(!IsOk()); return GetFormatedSerializedError(this->stacked_error()); } template Type GetDataAndSerializedStackedError(std::string* error_str, const Type& default_for_error) const { static_assert(std::is_same::value, "error type for argument 1"); if (IsOk()) { *error_str = StackedError().DebugString(); return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); } else { *error_str = this->stacked_error()->DebugString(); return default_for_error; } } template std::pair> GetDataAndStackedError( const Type& default_for_error) const { if (IsOk()) { return std::make_pair(*Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(), std::shared_ptr()); } else { return std::make_pair(default_for_error, stacked_error()); } } std::pair, std::shared_ptr> GetDataPtrAndStackedError() const { if (IsOk()) { return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(), std::shared_ptr()); } else { return std::make_pair(std::shared_ptr(), stacked_error()); } } template Type GetOrThrow() const { if (!IsOk()) { ThrowError(stacked_error()); } return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); } std::shared_ptr GetPtrOrThrow() const { if (!IsOk()) { ThrowError(stacked_error()); } return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); } private: EitherPtr data_or_error_; }; template class Maybe::value>::type> final { public: Maybe(const Error& error) : error_or_scalar_(error.stacked_error()) { CheckError(); } Maybe(const std::shared_ptr& error) : error_or_scalar_(error) { CheckError(); } Maybe(const Maybe&) = default; Maybe(Maybe&&) = default; ~Maybe() = default; static Maybe Ok() { return Maybe(); } bool IsOk() const { return error_or_scalar_.IsScalar(); } void Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {} std::shared_ptr stacked_error() const { return error_or_scalar_.shared_ptr(); } std::shared_ptr error() const { return stacked_error()->error_proto(); } std::string GetSerializedError() const { CHECK(!IsOk()); return 
GetFormatedSerializedError(this->stacked_error()); } void GetDataAndSerializedStackedError(std::string* error_str) const { if (IsOk()) { *error_str = StackedError().DebugString(); } else { *error_str = this->stacked_error()->DebugString(); } } std::shared_ptr GetDataAndStackedError() const { if (IsOk()) { return std::shared_ptr(); } else { return stacked_error(); } } void GetOrThrow() const { if (!IsOk()) { ThrowError(stacked_error()); } return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); } private: Maybe() : error_or_scalar_(nullptr) {} void CheckError() const { CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); } SharedOrScalar error_or_scalar_; }; inline const std::shared_ptr& UninitializedValueError() { static thread_local const auto& error = (Error::InvalidValueError() << "uninitialized value").stacked_error(); return error; } template class Maybe::value>::type> final { public: Maybe(T data) : error_or_scalar_(data) {} Maybe(const Error& error) : error_or_scalar_(error.stacked_error()) { CheckError(); } Maybe(const std::shared_ptr& error) : error_or_scalar_(error) { CheckError(); } Maybe() : error_or_scalar_(UninitializedValueError()) {} Maybe(const Maybe&) = default; Maybe(Maybe&&) = default; ~Maybe() = default; void operator=(const Maybe& rhs) { error_or_scalar_ = rhs.error_or_scalar_; } bool IsOk() const { return error_or_scalar_.IsScalar(); } T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const { return error_or_scalar_.scalar_value(); } std::shared_ptr stacked_error() const { return error_or_scalar_.shared_ptr(); } std::shared_ptr error() const { return stacked_error()->error_proto(); } std::string GetSerializedError() const { CHECK(!IsOk()); return GetFormatedSerializedError(this->stacked_error()); } T GetDataAndSerializedStackedError(std::string* error_str, const T& default_for_error) const { if (IsOk()) { *error_str = StackedError().DebugString(); return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); } else { *error_str = this->stacked_error()->DebugString(); return default_for_error; } } std::pair> GetDataAndStackedError( const T& default_for_error) const { if (IsOk()) { return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(), std::shared_ptr()); } else { return std::make_pair(default_for_error, stacked_error()); } } T GetOrThrow() const { if (!IsOk()) { ThrowError(stacked_error()); } return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); } private: void CheckError() const { CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); } SharedOrScalar error_or_scalar_; }; template class Maybe::value || IsScalarType::value) && std::is_reference::value>::type> final { using ValueT = typename std::remove_reference::type; using PtrT = ValueT*; public: Maybe(T data) : maybe_ptr_(&data) {} Maybe(const Error& error) : maybe_ptr_(error) {} Maybe(const std::shared_ptr& error) : maybe_ptr_(error) {} Maybe(const Maybe&) = default; Maybe(Maybe&&) = default; ~Maybe() = default; bool IsOk() const { return maybe_ptr_.IsOk(); } T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const { return *maybe_ptr_.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); } std::shared_ptr stacked_error() const { return maybe_ptr_.stacked_error(); } std::shared_ptr error() const { return stacked_error()->error_proto(); } std::string GetSerializedError() const { CHECK(!IsOk()); return maybe_ptr_.GetSerializedError(); } T GetDataAndSerializedStackedError(std::string* error_str) const { return 
*maybe_ptr_.GetDataAndSerializedStackedError(error_str, static_cast(nullptr)); } T GetOrThrow() const { if (!IsOk()) { ThrowError(stacked_error()); } return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); } private: Maybe maybe_ptr_; }; namespace { std::string GetFormatedSerializedError(const std::shared_ptr& stacked_error) { // return error msg got from formatted function or debugstring. const auto& maybe_error = TRY(FormatErrorStr(stacked_error)); const auto& error_str = maybe_error.GetDataAndStackedError(stacked_error->DebugString()); return error_str.first; } } // namespace } // namespace oneflow #define CHECK_OK(...) \ for (auto&& maybe = __JustStackCheckWrapper__(__VA_ARGS__); \ GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!maybe.IsOk());) \ LOG(FATAL) << OF_PP_STRINGIZE(__VA_ARGS__) << " is not OK:\n" << maybe.GetSerializedError() #define OF_RETURN_IF_ERROR(...) \ for (auto&& maybe_##__LINE__ = __JustStackCheckWrapper__(__VA_ARGS__); \ !maybe_##__LINE__.IsOk();) \ return Error(maybe_##__LINE__.stacked_error()).AddStackFrame([](const char* function) { \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) #define OF_TODO() \ return Error::TodoError().AddStackFrame([](const char* function) { \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) #define OF_UNIMPLEMENTED() \ return Error::UnimplementedError().AddStackFrame([](const char* function) { \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) #define OF_RUNTIME_ERROR() \ return Error::RuntimeError().AddStackFrame([](const char* function) { \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) \ << "RuntimeError " \ ": " #define RETURN_ERROR_WITH_BUG_PROMPT() OF_RUNTIME_ERROR() << kOfBugIssueUploadPrompt #define OF_LOG_ONCE(x) \ { \ static bool warned = false; \ if (!warned) { \ warned = true; \ x; \ } \ } #define OF_COMPLIE_OPTION_ERROR() \ return Error::CompileOptionWrongError().AddStackFrame([](const char* function) { \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) \ << "Compile option wrong: " #define CHECK_OR_RETURN_INTERNAL(expr, error_msg) \ if (!(expr)) \ return Error::CheckFailedError().AddStackFrame([](const char* function) { \ thread_local static auto frame = \ SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function, error_msg)); \ return frame; \ }(__FUNCTION__)) #define CHECK_OR_RETURN_ERROR(expr) \ if (!(expr)) \ return Error::CheckFailedError().AddStackFrame([](const char* function) { \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) // NOTE: Please contact @daquexian if you need to modify these CHECK_(XX_)OR_RETURN macros. There // are some static analyzers depending on the internal implementation of them. 
#define CHECK_OR_RETURN(expr) \ CHECK_OR_RETURN_INTERNAL(expr, OF_PP_STRINGIZE(CHECK_OR_RETURN(expr))) \ << "Check failed: (" << OF_PP_STRINGIZE(expr) << ") " << Error::kOverrideThenMergeMessage #define CHECK_EQ_OR_RETURN(lhs, rhs) \ CHECK_OR_RETURN_INTERNAL((lhs) == (rhs), OF_PP_STRINGIZE(CHECK_EQ_OR_RETURN(lhs, rhs))) \ << "Check failed: (" << (lhs) << " == " << (rhs) << ") " << Error::kOverrideThenMergeMessage #define CHECK_GE_OR_RETURN(lhs, rhs) \ CHECK_OR_RETURN_INTERNAL((lhs) >= (rhs), OF_PP_STRINGIZE(CHECK_GE_OR_RETURN(lhs, rhs))) \ << "Check failed: (" << (lhs) << " >= " << (rhs) << ") " << Error::kOverrideThenMergeMessage #define CHECK_GT_OR_RETURN(lhs, rhs) \ CHECK_OR_RETURN_INTERNAL((lhs) > (rhs), OF_PP_STRINGIZE(CHECK_GT_OR_RETURN(lhs, rhs))) \ << "Check failed: (" << (lhs) << " > " << (rhs) << ") " << Error::kOverrideThenMergeMessage #define CHECK_LE_OR_RETURN(lhs, rhs) \ CHECK_OR_RETURN_INTERNAL((lhs) <= (rhs), OF_PP_STRINGIZE(CHECK_LE_OR_RETURN(lhs, rhs))) \ << "Check failed: (" << (lhs) << " <= " << (rhs) << ") " << Error::kOverrideThenMergeMessage #define CHECK_LT_OR_RETURN(lhs, rhs) \ CHECK_OR_RETURN_INTERNAL((lhs) < (rhs), OF_PP_STRINGIZE(CHECK_LT_OR_RETURN(lhs, rhs))) \ << "Check failed: (" << (lhs) << " < " << (rhs) << ") " << Error::kOverrideThenMergeMessage #define CHECK_NE_OR_RETURN(lhs, rhs) \ CHECK_OR_RETURN_INTERNAL((lhs) != (rhs), OF_PP_STRINGIZE(CHECK_NE_OR_RETURN(lhs, rhs))) \ << "Check failed: (" << (lhs) << " != " << (rhs) << ") " << Error::kOverrideThenMergeMessage #define CHECK_STREQ_OR_RETURN(lhs, rhs) CHECK_EQ_OR_RETURN(std::string(lhs), std::string(rhs)) #define CHECK_STRNE_OR_RETURN(lhs, rhs) CHECK_NE_OR_RETURN(std::string(lhs), std::string(rhs)) #define CHECK_NOTNULL_OR_RETURN(ptr) CHECK_OR_RETURN(ptr != nullptr) #define CHECK_ISNULL_OR_RETURN(ptr) CHECK_OR_RETURN(ptr == nullptr) #define TODO_THEN_RETURN() OF_TODO() #define UNIMPLEMENTED_THEN_RETURN() OF_UNIMPLEMENTED() #endif // ONEFLOW_CORE_COMMON_MAYBE_H_ ================================================ FILE: oneflow/core/common/maybe_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
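// Usage sketch (illustrative, not part of the original sources): the CHECK_*_OR_RETURN
// macros above inside Maybe<void> functions, plus OF_RETURN_IF_ERROR for forwarding a
// failed Maybe. `ValidateShape`, `ValidateAll`, and their parameters are hypothetical;
// the unit tests in the next file exercise the same macros.
#include <cstdint>
#include <vector>

namespace sketch {

oneflow::Maybe<void> ValidateShape(const std::vector<int64_t>* shape, size_t expected_axes) {
  CHECK_NOTNULL_OR_RETURN(shape) << "shape must not be null";
  CHECK_EQ_OR_RETURN(shape->size(), expected_axes) << "unexpected number of axes";
  CHECK_GT_OR_RETURN(shape->at(0), 0);
  return oneflow::Maybe<void>::Ok();
}

oneflow::Maybe<void> ValidateAll(const std::vector<int64_t>& shape) {
  // OF_RETURN_IF_ERROR propagates a failure (with an extra stack frame) and
  // otherwise falls through to the next statement.
  OF_RETURN_IF_ERROR(ValidateShape(&shape, 2));
  return oneflow::Maybe<void>::Ok();
}

}  // namespace sketch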
*/ #include "oneflow/core/common/maybe.h" #include "gtest/gtest.h" #include #include #include "oneflow/core/common/exception.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace test { TEST(Maybe, JUST_MSG) { auto f = [](int x) -> Maybe { if (x > 10) { return Error::InvalidValueError() << "input value " << x; } return 233; }; auto g = [](int x) { return x * x - 5 * x + 3; }; auto h = [&](int x) -> Maybe { auto y = g(x); return JUST_MSG(f(y), "input value g(", x, ")"); }; auto i = [&](float x) -> Maybe { int y = x; return JUST_MSG(h(y), std::stringstream() << "input value int(" << x << ")"); }; auto data = CHECK_JUST(i(1)); ASSERT_EQ(data, 233); auto err = i(10.123).stacked_error(); ASSERT_EQ(err->error_proto()->msg(), R"(input value 53)"); ASSERT_GE(err->stack_frame().size(), 2); ASSERT_EQ(err->stack_frame().at(0)->code_text(), "f(y)"); ASSERT_EQ(err->stack_frame().at(1)->code_text(), "h(y)"); try { CHECK_JUST(i(10.234)); } catch (const RuntimeException& e) { EXPECT_TRUE(std::string(e.what()).find(R"(input value 53)") != std::string::npos); } } TEST(Maybe, CHECK_OR_RETURN) { auto f = [](int x) -> Maybe { CHECK_OR_RETURN(x > 10); return 233; }; auto i = [&](float x) -> Maybe { return JUST(f(x)); }; auto data = CHECK_JUST(i(20)); ASSERT_EQ(data, 233); auto err = i(1).stacked_error(); ASSERT_GE(err->stack_frame().size(), 2); ASSERT_EQ(err->stack_frame().at(0)->code_text(), "CHECK_OR_RETURN(x > 10)"); ASSERT_EQ(err->stack_frame().at(1)->code_text(), "f(x)"); } TEST(Maybe, CHECK_OK) { auto f = [](int x) -> Maybe { if (x > 10) { return Error::InvalidValueError() << "input value " << x; } return 233; }; auto g = [&](int x) -> Maybe { auto y = JUST(f(x)); return f(y); }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto) ASSERT_EXIT(CHECK_OK(g(11)), testing::KilledBySignal(SIGABRT), R"(g\(11\) is not OK)"); } TEST(Maybe, Noncopyable) { Maybe> a{std::make_unique(1)}; } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/common/mem_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/mem_util.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/virtual_machine.h" #include #include namespace oneflow { namespace { struct ProcStat { std::string pid, comm, state, ppid, pgrp, session, tty_nr; std::string tpgid, flags, minflt, cminflt, majflt, cmajflt; std::string utime, stime, cutime, cstime, priority, nice; std::string num_threads, itrealvalue, starttime; unsigned long vsize = 0; long rss = 0; }; Maybe CPUSynchronize() { if (Singleton::Get() != nullptr) { return vm::CurrentRankSync(); } return Maybe::Ok(); } } // namespace // Reference: https://stackoverflow.com/questions/669438/how-to-get-memory-usage-at-runtime-using-c void ProcessMemUsage(double* vm_usage, double* resident_set) { *vm_usage = 0.0; *resident_set = 0.0; #ifdef __linux__ // 'file' stat seems to give the most reliable results std::ifstream stat_stream("/proc/self/stat", std::ios_base::in); ProcStat proc_stat; stat_stream >> proc_stat.pid >> proc_stat.comm >> proc_stat.state >> proc_stat.ppid >> proc_stat.pgrp >> proc_stat.session >> proc_stat.tty_nr >> proc_stat.tpgid >> proc_stat.flags >> proc_stat.minflt >> proc_stat.cminflt >> proc_stat.majflt >> proc_stat.cmajflt >> proc_stat.utime >> proc_stat.stime >> proc_stat.cutime >> proc_stat.cstime >> proc_stat.priority >> proc_stat.nice >> proc_stat.num_threads >> proc_stat.itrealvalue >> proc_stat.starttime >> proc_stat.vsize >> proc_stat.rss; // don't care about the rest stat_stream.close(); long page_size_kb = sysconf(_SC_PAGE_SIZE); // in case x86-64 is configured to use 2MB pages // return with MB *vm_usage = proc_stat.vsize >> 20; // return with MB *resident_set = (proc_stat.rss * page_size_kb) >> 20; #endif // __linux__ } Maybe GetCPUMemoryUsed() { JUST(CPUSynchronize()); double vm_ = 0, rss_ = 0; ProcessMemUsage(&vm_, &rss_); return rss_; } std::string FormatMemSize(uint64_t size) { std::ostringstream os; os.precision(1); os << std::fixed; if (size <= 1024UL) { os << size << " Bytes"; } else if (size <= 1048576UL) { os << ((float)size / 1024.0) << " KB"; } else if (size <= 1073741824UL) { os << ((float)size / 1048576.0) << " MB"; } else { os << ((float)size / 1073741824.0) << " GB"; } return os.str(); } } // namespace oneflow ================================================ FILE: oneflow/core/common/mem_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_MEM_UTIL_H_ #define ONEFLOW_CORE_COMMON_MEM_UTIL_H_ #include #include #include #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" namespace oneflow { void ProcessMemUsage(double* vm_usage, double* resident_set); std::string FormatMemSize(uint64_t size); Maybe GetCPUMemoryUsed(); } // namespace oneflow #define LOG_MEM(...) \ double vm_ = 0, rss_ = 0; \ ProcessMemUsage(&vm_, &rss_); \ VLOG(1) << "File " __FILE__ << ", Line " << __LINE__ << ", Func " << __FUNCTION__ \ << ", Mem size RSS " << rss_ << "MB." 
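// Usage sketch (illustrative, not part of the original sources): querying process
// memory with the helpers above. On Linux the values come from /proc/self/stat;
// on other platforms ProcessMemUsage leaves them at 0.0.
#include <cstdint>
#include <iostream>

inline void MemUtilSketch() {
  double vm = 0.0, rss = 0.0;
  oneflow::ProcessMemUsage(&vm, &rss);
  std::cout << "virtual: " << vm << " MB, resident: " << rss << " MB" << std::endl;
  // FormatMemSize picks a human readable unit with one decimal place.
  std::cout << oneflow::FormatMemSize(uint64_t{3} << 20) << std::endl;  // "3.0 MB"
}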
#endif // ONEFLOW_CORE_COMMON_MEM_UTIL_H_ ================================================ FILE: oneflow/core/common/memory_format.proto ================================================ syntax = "proto2"; package oneflow; enum MemoryFormat { kContiguous = 0; kChannelsLast = 1; kPreserve = 2; kMemoryFormatCount = 3; }; ================================================ FILE: oneflow/core/common/meta_util.hpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_META_UTIL_HPP_ #define ONEFLOW_CORE_COMMON_META_UTIL_HPP_ #include #include namespace oneflow { template void for_each(const std::tuple& t, Func&& f, std::index_sequence) { (std::forward(f)(std::get(t)), ...); } template void for_each_i(const std::tuple& t, Func&& f, std::index_sequence) { (std::forward(f)(std::get(t), std::integral_constant{}), ...); } template using remove_const_reference_t = std::remove_const_t>; template auto make_tuple_from_sequence(std::index_sequence) { return std::make_tuple(Is...); } template constexpr auto make_tuple_from_sequence() { return make_tuple_from_sequence(std::make_index_sequence{}); } namespace detail { template void tuple_switch(const std::size_t i, Tuple&& t, F&& f, std::index_sequence) { (void)std::initializer_list{ (i == Is && ((void)std::forward(f)(std::integral_constant{}), 0))...}; } } // namespace detail template inline void tuple_switch(const std::size_t i, Tuple&& t, F&& f) { constexpr auto N = std::tuple_size>::value; detail::tuple_switch(i, std::forward(t), std::forward(f), std::make_index_sequence{}); } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_META_UTIL_HPP_ ================================================ FILE: oneflow/core/common/nd_index.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/nd_index.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { NdIndex::NdIndex(const std::initializer_list& dim_vec) : dim_vec_(dim_vec) {} NdIndex::NdIndex(const DimVector& dim_vec) : dim_vec_(dim_vec) {} NdIndex& NdIndex::operator=(const NdIndex& shape) { dim_vec_ = shape.dim_vec_; return *this; } bool NdIndex::operator==(const NdIndex& rhs) const { return dim_vec_ == rhs.dim_vec_; } } // namespace oneflow ================================================ FILE: oneflow/core/common/nd_index.h ================================================ /* Copyright 2020 The OneFlow Authors. 
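// Usage sketch (illustrative, not part of the original sources): the tuple helpers
// from meta_util.hpp above. The tuple contents and lambdas are hypothetical; the call
// shapes (an explicit index_sequence for for_each, a runtime index for tuple_switch)
// are inferred from the declarations above, so treat this as an assumption.
#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>

inline void MetaUtilSketch() {
  auto t = std::make_tuple(1, 2.5, "three");
  // Visit every element; the index_sequence supplies the compile-time indices.
  oneflow::for_each(t, [](const auto& elem) { std::cout << elem << " "; },
                    std::make_index_sequence<3>{});
  std::cout << std::endl;
  // Dispatch on a runtime index: the callback receives the matching index as a
  // std::integral_constant, so std::get can still be resolved at compile time.
  std::size_t runtime_i = 1;
  oneflow::tuple_switch(runtime_i, t, [&](auto idx) {
    std::cout << std::get<decltype(idx)::value>(t) << std::endl;  // prints 2.5
  });
}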
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ND_INDEX_H_ #define ONEFLOW_CORE_COMMON_ND_INDEX_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape.h" namespace oneflow { class NdIndex final { public: NdIndex() = default; explicit NdIndex(const DimVector& dim_vec); NdIndex(const std::initializer_list& dim_vec); ~NdIndex() = default; NdIndex& operator=(const NdIndex& other); bool operator==(const NdIndex& rhs) const; bool operator!=(const NdIndex& rhs) const { return !(*this == rhs); } const DimVector& dim_vec() const { return dim_vec_; } int64_t At(int64_t index) const { return dim_vec_.at(index); } int64_t NumAxes() const { return dim_vec_.size(); } private: DimVector dim_vec_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ND_INDEX_H_ ================================================ FILE: oneflow/core/common/nd_index_offset_helper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_ #define ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_ #include "oneflow/core/common/data_type.h" #include namespace oneflow { template class NdIndexOffsetHelper { public: OF_DEVICE_FUNC NdIndexOffsetHelper() = default; template OF_DEVICE_FUNC explicit NdIndexOffsetHelper(T d0, Ts... 
dims) { constexpr int n = 1 + sizeof...(dims); static_assert(n <= N, ""); T dims_arr[n] = {d0, static_cast(dims)...}; InitStrides(dims_arr, n); } OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims) { InitStrides(dims, N); } template OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims) { T dims_arr[N]; for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; } InitStrides(dims_arr, N); } OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims, int n) { InitStrides(dims, n); } template OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims, int n) { T dims_arr[N]; for (int i = 0; i < N; ++i) { if (i < n) { dims_arr[i] = dims[i]; } } InitStrides(dims_arr, n); } virtual ~NdIndexOffsetHelper() = default; OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const { T offset = 0; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N; ++i) { offset += index[i] * stride_[i]; } return offset; } OF_DEVICE_FUNC T NdIndexToOffset(const T* index, int n) const { assert(n <= N); T offset = 0; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N; ++i) { if (i < n) { offset += index[i] * stride_[i]; } } return offset; } template OF_DEVICE_FUNC T NdIndexToOffset(T d0, Ts... others) const { constexpr int n = 1 + sizeof...(others); static_assert(n <= N, ""); T index[n] = {d0, others...}; T offset = 0; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < n - 1; ++i) { offset += index[i] * stride_[i]; } if (n == N) { offset += index[n - 1]; } else { offset += index[n - 1] * stride_[n - 1]; } return offset; } OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const { T remaining = offset; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N - 1; ++i) { const T idx = remaining / stride_[i]; index[i] = idx; remaining = remaining - idx * stride_[i]; } index[N - 1] = remaining; } OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const { assert(n <= N); T remaining = offset; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N; ++i) { if (i < n) { const T idx = remaining / stride_[i]; index[i] = idx; remaining = remaining - idx * stride_[i]; } } } template OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T& d0, Ts&... 
others) const { constexpr int n = 1 + sizeof...(others); static_assert(n <= N, ""); T* index[n] = {&d0, &others...}; T remaining = offset; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < n - 1; ++i) { const T idx = remaining / stride_[i]; *index[i] = idx; remaining = remaining - idx * stride_[i]; } if (n == N) { *index[n - 1] = remaining; } else { *index[n - 1] = remaining / stride_[n - 1]; } } OF_DEVICE_FUNC constexpr int Size() const { return N; } protected: OF_DEVICE_FUNC void InitStrides(const T* dims, const int n) { for (int i = n - 1; i < N; ++i) { stride_[i] = 1; } for (int i = n - 2; i >= 0; --i) { stride_[i] = dims[i + 1] * stride_[i + 1]; } } T stride_[N]; }; template class NdIndexStrideOffsetHelper : public NdIndexOffsetHelper { public: OF_DEVICE_FUNC NdIndexStrideOffsetHelper() = default; OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const T* strides) { for (int i = 0; i < N; ++i) { stride_[i] = strides[i]; } } template OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const U* strides) { for (int i = 0; i < N; ++i) { stride_[i] = static_cast(strides[i]); } } OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const T* strides, int n) { for (int i = 0; i < N; ++i) { if (i < n) { stride_[i] = strides[i]; } } } template OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const U* strides, int n) { for (int i = 0; i < N; ++i) { if (i < n) { stride_[i] = static_cast(strides[i]); } } } private: using NdIndexOffsetHelper::stride_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_ ================================================ FILE: oneflow/core/common/nd_index_offset_helper_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
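// Usage sketch (illustrative, not part of the original sources): mapping between flat
// offsets and N-d indices with NdIndexOffsetHelper for a hypothetical 3-D shape
// (2, 3, 4). The unit tests in the next file cover the same API more exhaustively.
#include <cassert>
#include <cstdint>

inline void NdIndexOffsetHelperSketch() {
  // Strides become (12, 4, 1) for dims (2, 3, 4).
  oneflow::NdIndexOffsetHelper<int64_t, 3> helper(2, 3, 4);
  // (1, 2, 3) -> 1*12 + 2*4 + 3 = 23
  assert(helper.NdIndexToOffset(1, 2, 3) == 23);
  int64_t index[3] = {0, 0, 0};
  helper.OffsetToNdIndex(23, index);
  assert(index[0] == 1 && index[1] == 2 && index[2] == 3);
}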
*/ // include sstream first to avoid some compiling error // caused by the following trick // reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899 #include #include "gtest/gtest.h" #define private public #define protected public #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace test { template void test_3d() { const T d0_max = 3; const T d1_max = 4; const T d2_max = 5; const NdIndexOffsetHelper helper(d0_max, d1_max, d2_max); for (T d0 = 0; d0 < d0_max; ++d0) { const T offset0 = d0 * d1_max * d2_max; { std::vector expected0({d0}); { std::vector dims(1); helper.OffsetToNdIndex(offset0, dims.data(), 1); ASSERT_EQ(expected0, dims); } { std::vector dims(1); helper.OffsetToNdIndex(offset0, dims.at(0)); ASSERT_EQ(expected0, dims); } ASSERT_EQ(offset0, helper.NdIndexToOffset(expected0.data(), 1)); ASSERT_EQ(offset0, helper.NdIndexToOffset(expected0.at(0))); } for (T d1 = 0; d1 < d1_max; ++d1) { const T offset1 = offset0 + d1 * d2_max; { std::vector expected1({d0, d1}); { std::vector dims(2); helper.OffsetToNdIndex(offset1, dims.data(), 2); ASSERT_EQ(expected1, dims); } { std::vector dims(2); helper.OffsetToNdIndex(offset1, dims.at(0), dims.at(1)); ASSERT_EQ(expected1, dims); } ASSERT_EQ(offset1, helper.NdIndexToOffset(expected1.data(), 2)); ASSERT_EQ(offset1, helper.NdIndexToOffset(expected1.at(0), expected1.at(1))); } for (T d2 = 0; d2 < d2_max; ++d2) { const T offset2 = offset1 + d2; { std::vector expected2({d0, d1, d2}); { std::vector dims(3); helper.OffsetToNdIndex(offset2, dims.data(), 3); ASSERT_EQ(expected2, dims); } { std::vector dims(3); helper.OffsetToNdIndex(offset2, dims.at(0), dims.at(1), dims.at(2)); ASSERT_EQ(expected2, dims); } if (ndims == 3) { std::vector dims(3); helper.OffsetToNdIndex(offset2, dims.data()); ASSERT_EQ(expected2, dims); ASSERT_EQ(offset2, helper.NdIndexToOffset(expected2.data())); } ASSERT_EQ(offset2, helper.NdIndexToOffset(expected2.data(), 3)); ASSERT_EQ(offset2, helper.NdIndexToOffset(expected2.at(0), expected2.at(1), expected2.at(2))); } } } } } TEST(NdIndexOffsetHelper, static_3d) { test_3d(); test_3d(); } TEST(NdIndexOffsetHelper, dynamic_3d) { test_3d(); test_3d(); test_3d(); test_3d(); } template void test_constructor() { const T d0 = 3; const T d1 = 4; const T d2 = 5; // static { std::vector dims({d0, d1, d2}); const NdIndexOffsetHelper helper1(d0, d1, d2); const NdIndexOffsetHelper helper2(dims.data()); const NdIndexOffsetHelper helper3(dims.data(), dims.size()); std::vector stride({d1 * d2, d2, 1}); for (int i = 0; i < 3; ++i) { ASSERT_EQ(helper1.stride_[i], stride[i]); ASSERT_EQ(helper2.stride_[i], stride[i]); ASSERT_EQ(helper3.stride_[i], stride[i]); } } // dynamic { std::vector dims({d0, d1, d2}); const NdIndexOffsetHelper helper1(d0, d1, d2); const NdIndexOffsetHelper helper2(dims.data(), dims.size()); std::vector stride({d1 * d2, d2, 1, 1, 1, 1}); for (int i = 0; i < 6; ++i) { ASSERT_EQ(helper1.stride_[i], stride[i]); ASSERT_EQ(helper2.stride_[i], stride[i]); } } } TEST(NdIndexOffsetHelper, constructor) { test_constructor(); test_constructor(); } template void test_stride_constructor() { const T d1 = 5; const T d2 = 6; const U u1 = 5; const U u2 = 6; std::vector strides({d1 * d2, d2, 1}); std::vector strides_u({u1 * u2, u2, 1}); const NdIndexStrideOffsetHelper helper1(strides.data()); const NdIndexStrideOffsetHelper helper2(strides.data(), strides.size()); const NdIndexStrideOffsetHelper helper3(strides_u.data()); const NdIndexStrideOffsetHelper helper4(strides_u.data(), strides_u.size()); for (int i = 
0; i < 3; i++) { ASSERT_EQ(helper1.stride_[i], strides[i]); ASSERT_EQ(helper2.stride_[i], strides[i]); ASSERT_EQ(helper3.stride_[i], strides_u[i]); ASSERT_EQ(helper4.stride_[i], strides_u[i]); } } TEST(NdIndexStrideOffsetHelper, constructor) { test_stride_constructor(); test_stride_constructor(); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/common/not_equal_to_previous_adjacent_iterator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_NOT_EQUAL_TO_PREVIOUS_ADJACENT_ITERATOR_H_ #define ONEFLOW_CORE_COMMON_NOT_EQUAL_TO_PREVIOUS_ADJACENT_ITERATOR_H_ #include namespace oneflow { #define ITER_DEVICE_FUNC __host__ __device__ __forceinline__ template class NotEqualToPreviousAdjacentIterator { public: typedef NotEqualToPreviousAdjacentIterator self_type; typedef OffsetT difference_type; typedef ValueType value_type; typedef ValueType* pointer; typedef ValueType reference; typedef std::random_access_iterator_tag iterator_category; private: const UnderlyingT* underlying; OffsetT offset; public: ITER_DEVICE_FUNC NotEqualToPreviousAdjacentIterator(const UnderlyingT* underlying, OffsetT offset) : underlying(underlying), offset(offset) {} ITER_DEVICE_FUNC self_type operator++(int) { self_type ret = *this; offset++; return ret; } ITER_DEVICE_FUNC self_type operator++() { offset++; return *this; } ITER_DEVICE_FUNC reference operator*() const { return offset == 0 ? 0 : (underlying[offset] == underlying[offset - 1] ? 0 : 1); } template ITER_DEVICE_FUNC self_type operator+(Distance n) const { self_type ret(underlying, offset + n); return ret; } template ITER_DEVICE_FUNC self_type& operator+=(Distance n) { offset += n; return *this; } template ITER_DEVICE_FUNC self_type operator-(Distance n) const { self_type ret(underlying, offset - n); return ret; } template ITER_DEVICE_FUNC self_type& operator-=(Distance n) { offset -= n; return *this; } ITER_DEVICE_FUNC difference_type operator-(self_type other) const { return offset - other.offset; } template ITER_DEVICE_FUNC reference operator[](Distance n) const { return *(*this + n); } ITER_DEVICE_FUNC pointer operator->() { return nullptr; } ITER_DEVICE_FUNC bool operator==(const self_type& rhs) { return (offset == rhs.offset) && ((underlying == rhs.underlying)); } ITER_DEVICE_FUNC bool operator!=(const self_type& rhs) { return offset != rhs.offset || underlying != rhs.underlying; } friend std::ostream& operator<<(std::ostream& os, const self_type& itr) { return os; } }; #undef ITER_DEVICE_FUNC } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_NOT_EQUAL_TO_PREVIOUS_ADJACENT_ITERATOR_H_ ================================================ FILE: oneflow/core/common/notifier.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/notifier.h" #include "oneflow/core/common/foreign_lock_helper.h" #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { NotifierStatus Notifier::Notify() { bool notify = false; { std::unique_lock lock(mutex_); if (is_closed_) { return kNotifierStatusErrorClosed; } notify = (notified_cnt_ == 0); ++notified_cnt_; } if (notify) { cond_.notify_one(); } return kNotifierStatusSuccess; } NotifierStatus Notifier::WaitAndClearNotifiedCnt() { std::unique_lock lock(mutex_); cond_.wait(lock, [this]() { return notified_cnt_ > 0 || is_closed_; }); if (notified_cnt_ == 0) { return kNotifierStatusErrorClosed; } notified_cnt_ = 0; return kNotifierStatusSuccess; } Maybe Notifier::TimedWaitAndClearNotifiedCnt(size_t timeout_seconds) { return Singleton::Get()->WithScopedRelease([&, this]() -> Maybe { std::chrono::duration seconds(timeout_seconds); std::unique_lock lock(mutex_); CHECK_OR_RETURN(cond_.wait_for(lock, seconds, [this]() { return notified_cnt_ > 0 || is_closed_; })) << Error::TimeoutError(); CHECK_GT_OR_RETURN(notified_cnt_, 0) << "notifier closed."; notified_cnt_ = 0; return Maybe::Ok(); }); } Maybe Notifier::TimedWaitAndClearNotifiedCnt( const std::function()>& StopWaitingAfterTimeout) { while (true) { auto status = TRY(TimedWaitAndClearNotifiedCnt(EnvInteger())); if (status.IsOk()) { return status; } if (!status.error()->has_timeout_error()) { return status; } if (JUST(StopWaitingAfterTimeout())) { return status; } } UNIMPLEMENTED_THEN_RETURN(); } void Notifier::Close() { std::unique_lock lock(mutex_); is_closed_ = true; cond_.notify_all(); } } // namespace oneflow ================================================ FILE: oneflow/core/common/notifier.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
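// Usage sketch (illustrative, not part of the original sources): a single
// producer/consumer pair built on the Notifier implemented above (its declaration
// follows in notifier.h). The thread bodies and the log message are hypothetical.
#include <iostream>
#include <thread>

inline void NotifierSketch() {
  oneflow::Notifier notifier;
  std::thread consumer([&] {
    // Blocks until Notify() has been called at least once, then clears the count;
    // after Close() it returns kNotifierStatusErrorClosed and the loop exits.
    while (notifier.WaitAndClearNotifiedCnt() == oneflow::kNotifierStatusSuccess) {
      std::cout << "got work" << std::endl;
    }
  });
  notifier.Notify();
  notifier.Close();
  consumer.join();
}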
*/ #ifndef ONEFLOW_CORE_COMMON_NOTIFIER_H_ #define ONEFLOW_CORE_COMMON_NOTIFIER_H_ #include "oneflow/core/common/util.h" namespace oneflow { enum NotifierStatus { kNotifierStatusSuccess = 0, kNotifierStatusErrorClosed }; class Notifier final { public: OF_DISALLOW_COPY_AND_MOVE(Notifier); Notifier() : notified_cnt_(0), is_closed_(false) {} ~Notifier() = default; NotifierStatus Notify(); NotifierStatus WaitAndClearNotifiedCnt(); void Close(); Maybe TimedWaitAndClearNotifiedCnt(size_t timeout_seconds); Maybe TimedWaitAndClearNotifiedCnt( const std::function()>& StopWaitingAfterTimeout); private: size_t notified_cnt_; std::mutex mutex_; bool is_closed_; std::condition_variable cond_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_NOTIFIER_H_ ================================================ FILE: oneflow/core/common/of_unused.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_OF_UNUSED_H_ #define ONEFLOW_CORE_COMMON_OF_UNUSED_H_ namespace oneflow { #define OF_UNUSED(x) (void)(x) } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_OF_UNUSED_H_ ================================================ FILE: oneflow/core/common/op_args_reserved_size.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_ #define ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_ namespace oneflow { constexpr static int kOpArgsReservedSize = 4; } #endif // ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_ ================================================ FILE: oneflow/core/common/op_args_vector.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_ #define ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_ #include "oneflow/core/common/small_vector.h" #include "oneflow/core/common/op_args_reserved_size.h" namespace oneflow { template using OpArgsVector = small_vector; } #endif // ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_ ================================================ FILE: oneflow/core/common/optional.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_OPTIONAL_H_ #define ONEFLOW_CORE_COMMON_OPTIONAL_H_ #include #include #include #include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/type_traits.h" #include "oneflow/core/common/just.h" namespace oneflow { struct InPlaceConstructType { explicit InPlaceConstructType() = default; }; constexpr InPlaceConstructType InPlaceConstruct{}; struct NullOptType { explicit constexpr NullOptType(int) {} }; constexpr NullOptType NullOpt{0}; namespace internal { template class OptionalBase; template class OptionalBase::value>::type> { public: using value_type = T; using storage_type = T; OptionalBase() : init_(false), value_() {} ~OptionalBase() = default; explicit OptionalBase(const T& value) : init_(true), value_(value) {} explicit OptionalBase(T&& value) : init_(true), value_(std::move(value)) {} OptionalBase(const OptionalBase& base) : init_(base.init_), value_(base.value_) {} OptionalBase(OptionalBase&& base) noexcept : init_(base.init_), value_(std::move(base.value_)) {} OptionalBase& operator=(const T& value) { value_ = value; init_ = true; return *this; } OptionalBase& operator=(T&& value) { value_ = std::move(value); init_ = true; return *this; } OptionalBase& operator=(const OptionalBase& rhs) { value_ = rhs.value_; init_ = rhs.init_; return *this; } OptionalBase& operator=(OptionalBase&& rhs) noexcept { value_ = std::move(rhs.value_); init_ = rhs.init_; return *this; } T value() const& { return value_; } // `T value() &&` goes here T& value() & { return value_; } bool has_value() const { return init_; } T value_or(const T& other) const { if (has_value()) { return value(); } else { return other; } } void reset() { init_ = false; } private: bool init_; T value_; }; template class OptionalBase::value>::type> { public: using value_type = typename std::remove_reference::type; using storage_type = value_type*; static_assert(std::is_lvalue_reference::value, "rvalue reference is not supported here"); OptionalBase() : value_(nullptr){}; ~OptionalBase() = default; explicit OptionalBase(T value) : value_(&value) {} OptionalBase(const OptionalBase& base) : value_(base.value_) {} OptionalBase(OptionalBase&& base) noexcept : value_(base.value_) {} OptionalBase& operator=(T value) { value_ = &value; return *this; } OptionalBase& operator=(const OptionalBase& rhs) { value_ = rhs.value_; return *this; } OptionalBase& operator=(OptionalBase&& rhs) noexcept { value_ = std::move(rhs.value_); return *this; } const value_type& value() const { return *value_; } T 
================================================
FILE: oneflow/core/common/optional.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_OPTIONAL_H_
#define ONEFLOW_CORE_COMMON_OPTIONAL_H_

#include <memory>
#include <type_traits>
#include <utility>
#include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/just.h"

namespace oneflow {

struct InPlaceConstructType {
  explicit InPlaceConstructType() = default;
};
constexpr InPlaceConstructType InPlaceConstruct{};

struct NullOptType {
  explicit constexpr NullOptType(int) {}
};
constexpr NullOptType NullOpt{0};

namespace internal {

template<typename T, typename = void>
class OptionalBase;

template<typename T>
class OptionalBase<T, typename std::enable_if<IsScalarType<T>::value>::type> {
 public:
  using value_type = T;
  using storage_type = T;

  OptionalBase() : init_(false), value_() {}
  ~OptionalBase() = default;
  explicit OptionalBase(const T& value) : init_(true), value_(value) {}
  explicit OptionalBase(T&& value) : init_(true), value_(std::move(value)) {}
  OptionalBase(const OptionalBase& base) : init_(base.init_), value_(base.value_) {}
  OptionalBase(OptionalBase&& base) noexcept : init_(base.init_), value_(std::move(base.value_)) {}

  OptionalBase& operator=(const T& value) { value_ = value; init_ = true; return *this; }
  OptionalBase& operator=(T&& value) { value_ = std::move(value); init_ = true; return *this; }
  OptionalBase& operator=(const OptionalBase& rhs) { value_ = rhs.value_; init_ = rhs.init_; return *this; }
  OptionalBase& operator=(OptionalBase&& rhs) noexcept {
    value_ = std::move(rhs.value_);
    init_ = rhs.init_;
    return *this;
  }

  T value() const& { return value_; }  // `T value() &&` goes here
  T& value() & { return value_; }

  bool has_value() const { return init_; }

  T value_or(const T& other) const {
    if (has_value()) { return value(); } else { return other; }
  }

  void reset() { init_ = false; }

 private:
  bool init_;
  T value_;
};

template<typename T>
class OptionalBase<T, typename std::enable_if<std::is_reference<T>::value>::type> {
 public:
  using value_type = typename std::remove_reference<T>::type;
  using storage_type = value_type*;
  static_assert(std::is_lvalue_reference<T>::value, "rvalue reference is not supported here");

  OptionalBase() : value_(nullptr){};
  ~OptionalBase() = default;
  explicit OptionalBase(T value) : value_(&value) {}
  OptionalBase(const OptionalBase& base) : value_(base.value_) {}
  OptionalBase(OptionalBase&& base) noexcept : value_(base.value_) {}

  OptionalBase& operator=(T value) { value_ = &value; return *this; }
  OptionalBase& operator=(const OptionalBase& rhs) { value_ = rhs.value_; return *this; }
  OptionalBase& operator=(OptionalBase&& rhs) noexcept { value_ = std::move(rhs.value_); return *this; }

  const value_type& value() const { return *value_; }
  T value() { return *value_; }

  bool has_value() const { return value_; }

  const value_type& value_or(const value_type& other) const {
    if (has_value()) { return value(); } else { return other; }
  }

  void reset() { value_ = nullptr; }

 private:
  storage_type value_;
};

template<typename T>
class OptionalBase<
    T, typename std::enable_if<!IsScalarType<T>::value && !std::is_reference<T>::value>::type> {
 public:
  using value_type = T;
  using storage_type = std::shared_ptr<T>;

  OptionalBase() : value_(nullptr){};
  ~OptionalBase() = default;

  template<typename... Args>
  explicit OptionalBase(InPlaceConstructType, Args&&... args)
      : value_(std::make_shared<T>(std::forward<Args>(args)...)) {}

  explicit OptionalBase(const T& value) : value_(std::make_shared<T>(value)) {}
  explicit OptionalBase(T&& value) : value_(std::make_shared<T>(std::move(value))) {}
  explicit OptionalBase(const storage_type& value) : value_(value) {}
  explicit OptionalBase(storage_type&& value) : value_(std::move(value)) {}
  OptionalBase(const OptionalBase&) = default;
  OptionalBase(OptionalBase&&) noexcept = default;

  OptionalBase& operator=(const T& value) {
    if (value_) {
      *value_ = value;
    } else {
      value_ = std::make_shared<T>(value);
    }
    return *this;
  }
  OptionalBase& operator=(T&& value) {
    if (value_) {
      *value_ = std::move(value);
    } else {
      value_ = std::make_shared<T>(std::move(value));
    }
    return *this;
  }
  OptionalBase& operator=(const storage_type& value) { value_ = value; return *this; }
  OptionalBase& operator=(storage_type&& value) { value_ = std::move(value); return *this; }
  OptionalBase& operator=(const OptionalBase& rhs) { value_ = rhs.value_; return *this; }
  OptionalBase& operator=(OptionalBase&& rhs) noexcept { value_ = std::move(rhs.value_); return *this; }

  const storage_type& value() const& { return value_; }
  storage_type& value() & { return value_; }
  storage_type&& value() && { return std::move(value_); }

  bool has_value() const { return bool(value_); }

  const storage_type& value_or(const storage_type& other) const& {
    if (has_value()) { return value_; } else { return other; }
  }
  storage_type value_or(const storage_type& other) && {
    if (has_value()) { return std::move(value_); } else { return other; }
  }
  storage_type value_or(storage_type&& other) const& {
    if (has_value()) { return value_; } else { return std::move(other); }
  }
  storage_type value_or(storage_type&& other) && {
    if (has_value()) { return std::move(value_); } else { return std::move(other); }
  }

  // we introduce a dependent name `U` to delay the instantiation,
  // so only the default parameter of `U` is allowed
  template<typename U = T>
  typename std::enable_if<std::is_same<U, T>::value, const U&>::type value_or(
      const value_type& other) const& {
    static_assert(std::is_same<U, T>::value, "expected default U");
    if (has_value()) { return *value_; } else { return other; }
  }
  template<typename U = T>
  typename std::enable_if<std::is_same<U, T>::value, U>::type value_or(
      const value_type& other) && {
    static_assert(std::is_same<U, T>::value, "expected default U");
    if (has_value()) { return std::move(*value_); } else { return other; }
  }
  template<typename U = T>
  typename std::enable_if<std::is_same<U, T>::value, U>::type value_or(value_type&& other) const& {
    static_assert(std::is_same<U, T>::value, "expected default U");
    if (has_value()) { return *value_; } else { return std::move(other); }
  }
  template<typename U = T>
  typename std::enable_if<std::is_same<U, T>::value, U>::type value_or(value_type&& other) && {
    static_assert(std::is_same<U, T>::value, "expected default U");
    if (has_value()) { return std::move(*value_); } else { return std::move(other); }
  }

  void reset() { value_.reset(); }

 private:
  storage_type value_;
};

template<typename T>
struct IsOptional : std::false_type {};

template<typename T>
struct IsOptional<Optional<T>> : std::true_type {};

struct monadic_operations {
  template<typename T, typename F>
  static auto map(T&& opt, F&& f)
      -> Optional<decltype(std::forward<F>(f)(std::forward<T>(opt).value()))> {
    if (opt.has_value()) { return std::forward<F>(f)(std::forward<T>(opt).value()); }
    return NullOpt;
  }

  template<typename T, typename F,
           typename U = decltype(std::declval<F&&>()(std::declval<T&&>().value()))>
  static auto bind(T&& opt, F&& f) -> std::enable_if_t<IsOptional<U>::value, U> {
    if (opt.has_value()) { return std::forward<F>(f)(std::forward<T>(opt).value()); }
    return NullOpt;
  }

  template<typename T, typename F,
           std::enable_if_t<std::is_same<decltype(std::declval<F&&>()()), void>::value, int> = 0>
  static auto or_else(T&& opt, F&& f) -> std::decay_t<T> {
    if (!opt.has_value()) {
      std::forward<F>(f)();
      return NullOpt;
    }
    return std::forward<T>(opt);
  }

  template<typename T, typename F,
           std::enable_if_t<
               std::is_same<decltype(std::declval<F&&>()()), std::decay_t<T>>::value, int> = 0>
  static auto or_else(T&& opt, F&& f) -> std::decay_t<T> {
    if (!opt.has_value()) { return std::forward<F>(f)(); }
    return std::forward<T>(opt);
  }
};

}  // namespace internal

template<typename T>
class Optional final : private internal::OptionalBase<T> {
 private:
  using base = internal::OptionalBase<T>;
  using move_value_type = decltype(std::declval<base>().value());

 public:
  using value_type = typename base::value_type;
  using storage_type = typename base::storage_type;

  explicit Optional() = default;
  ~Optional() = default;

  Optional(NullOptType)  // NOLINT(google-explicit-constructor)
      : base() {}

  template<
      typename Arg1, typename... ArgN,
      typename std::enable_if<!(sizeof...(ArgN) == 0
                                && std::is_same<Optional, typename std::decay<Arg1>::type>::value),
                              int>::type = 0>
  Optional(Arg1&& v1, ArgN&&... vn)  // NOLINT(google-explicit-constructor)
      : base(std::forward<Arg1>(v1), std::forward<ArgN>(vn)...) {}

  Optional(const Optional&) = default;
  Optional(Optional&&) noexcept = default;

  template<typename U,
           typename std::enable_if<
               !std::is_same<Optional, typename std::decay<U>::type>::value, int>::type = 0>
  Optional& operator=(U&& val) {
    return static_cast<Optional&>(static_cast<base&>(*this) = std::forward<U>(val));
  }
  Optional& operator=(const Optional& rhs) = default;
  Optional& operator=(Optional&& rhs) noexcept = default;

  template<typename U>
  decltype(auto) value_or(U&& other) const& {
    return base::value_or(std::forward<U>(other));
  }
  template<typename U>
  decltype(auto) value_or(U&& other) && {
    return std::move(*this).base::value_or(std::forward<U>(other));
  }

  bool has_value() const { return base::has_value(); }
  explicit operator bool() const { return has_value(); }

  // generate a temporary object to allow `const auto& x = optval().value()` where `optval()` is a
  // function call which returns a temporary Optional
  auto Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() && -> std::conditional_t<
      std::is_rvalue_reference<move_value_type>::value, std::remove_reference_t<move_value_type>,
      move_value_type> {
    return std::move(*this).base::value();
  }

  friend internal::monadic_operations;

  template<typename F>
  auto map(F&& f) const& {
    return internal::monadic_operations::map(*this, std::forward<F>(f));
  }
  template<typename F>
  auto map(F&& f) && {
    return internal::monadic_operations::map(std::move(*this), std::forward<F>(f));
  }
  template<typename F>
  auto bind(F&& f) const& {
    return internal::monadic_operations::bind(*this, std::forward<F>(f));
  }
  template<typename F>
  auto bind(F&& f) && {
    return internal::monadic_operations::bind(std::move(*this), std::forward<F>(f));
  }
  template<typename F>
  auto or_else(F&& f) const& {
    return internal::monadic_operations::or_else(*this, std::forward<F>(f));
  }
  template<typename F>
  auto or_else(F&& f) && {
    return internal::monadic_operations::or_else(std::move(*this), std::forward<F>(f));
  }

  bool operator==(const Optional& other) const {
    if (has_value()) {
      if (other.has_value()) {
        return base::value() == other.base::value();
      } else {
        return false;
      }
    } else {
      return !other.has_value();
    }
  }
  bool operator!=(const Optional& other) const { return !operator==(other); }

  void reset() { base::reset(); }
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_COMMON_OPTIONAL_H_
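A minimal sketch of the monadic interface declared above, not part of the repository; ParsePort and PortOrDefault are hypothetical helpers, and the point is only that bind/value_or replace explicit has_value() branching:

// Illustrative sketch (hypothetical functions, not repository code).
#include <cstdlib>
#include <string>
#include "oneflow/core/common/optional.h"

oneflow::Optional<int> ParsePort(const std::string& text) {  // hypothetical parser
  if (text.empty()) { return oneflow::NullOpt; }
  return std::atoi(text.c_str());
}

int PortOrDefault(const std::string& text) {
  return ParsePort(text)
      .bind([](int port) -> oneflow::Optional<int> {
        if (port > 0 && port < 65536) { return port; }  // keep only valid ports
        return oneflow::NullOpt;
      })
      .value_or(8080);  // fall back to a default when parsing or validation failed
}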
================================================
FILE: oneflow/core/common/optional_test.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
for the specific language governing permissions and limitations under the License.
*/
#include <gtest/gtest.h>
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/exception.h"

namespace oneflow {
namespace test {

TEST(Optional, copy_constructor) {
  Optional<int64_t> a(0);
  std::vector<Optional<int64_t>> vec;
  vec.emplace_back(a);
  ASSERT_TRUE(vec[0].has_value());
  int64_t val = CHECK_JUST(vec[0]);
  ASSERT_EQ(val, 0);
}

TEST(Optional, move_constructor) {
  Optional<int64_t> a(0);
  std::map<int, Optional<int64_t>> map;
  map.emplace(0, a);
  ASSERT_TRUE(map.at(0).has_value());
  int64_t val = CHECK_JUST(map.at(0));
  ASSERT_EQ(val, 0);
}

TEST(Optional, JUST) {
  Optional<int64_t> a(233), b;
  ASSERT_EQ(a.value_or(0), 233);
  ASSERT_EQ(b.value_or(1), 1);

  auto f = [](const Optional<int64_t>& v) -> Maybe<int64_t> { return JUST(v); };
  ASSERT_EQ(CHECK_JUST(f(a)), 233);
  ASSERT_EQ(f(b).error()->msg(), "");

  auto g = [](const Optional<int64_t>& v) -> Optional<int64_t> { return JUST_OPT(v); };
  ASSERT_EQ(CHECK_JUST(g(a)), 233);

  a = 234;
  ASSERT_EQ(CHECK_JUST(a), 234);
  b = a;
  ASSERT_EQ(CHECK_JUST(b), 234);
  b.reset();
  ASSERT_EQ(b.value_or(1), 1);

  Optional<int64_t> c(233);
  ASSERT_EQ(CHECK_JUST(c), 233);
}

TEST(Optional, reference) {
  int x = 1, z = 0;
  Optional<int&> a(x), b;
  x = 2;
  ASSERT_EQ(CHECK_JUST(a), 2);
  ASSERT_EQ(b.value_or(z), 0);
  CHECK_JUST(a) = 3;
  ASSERT_EQ(x, 3);
  Optional<int&> c(x);
  ASSERT_EQ(CHECK_JUST(c), 3);
}

TEST(Optional, non_scalar) {
  Optional<std::vector<int>> a(InPlaceConstruct, 10), b;
  CHECK_JUST(a)->at(1) = 1;
  ASSERT_EQ(CHECK_JUST(a)->size(), 10);
  ASSERT_EQ(CHECK_JUST(a)->at(1), 1);

  auto x = std::make_shared<std::vector<int>>(1);
  ASSERT_EQ(b.value_or(x), x);
  ASSERT_EQ(b.value_or(std::vector<int>{1, 2, 3}), (std::vector<int>{1, 2, 3}));
  ASSERT_EQ(b.value_or(*x), *x);
  ASSERT_EQ(a.value_or(*x), *CHECK_JUST(a));
  ASSERT_EQ(Optional<std::vector<int>>().value_or(*x), *x);
  ASSERT_EQ(Optional<std::vector<int>>().value_or(std::vector<int>{1, 2, 3}),
            (std::vector<int>{1, 2, 3}));

  Optional<std::vector<int>> c(std::vector<int>{1, 2, 3});
  ASSERT_EQ(CHECK_JUST(c)->at(1), 2);
}

TEST(Optional, optional_just_error_throw) {
  ASSERT_THROW(  // NOLINT(cppcoreguidelines-avoid-goto)
      {
        ([]() -> Maybe<int64_t> {
          Optional<int64_t> a;
          return JUST(a);
        })()
            .GetOrThrow();
      },
      Exception);
}

TEST(Optional, monadic_operations) {
  Optional<int> a(1), b, c(2);
  ASSERT_EQ(a.map([](int x) { return x + 1; }), c);
  ASSERT_EQ(b.map([](int x) { return x + 1; }), b);
  ASSERT_EQ(a.map([](int x) { return std::string(x + 1, 'a'); }).map([](const auto& x) {
    return (int)x->size();
  }),
            c);
  ASSERT_EQ(a.bind([](int x) -> Optional<float> {
               if (x < 10) {
                 return x * 1.1;
               } else {
                 return NullOpt;
               }
             })
                .map([](float x) { return x - 1; })
                .map([](float x) { return std::abs(x - 0.1) < 0.001; }),
            Optional<bool>(true));

  int x = 0;
  b.or_else([&] { x++; }).or_else([&] { x *= 2; });
  ASSERT_EQ(x, 2);
  ASSERT_EQ(b.or_else([] { return Optional<int>(3); }).map([](int x) { return x - 1; }), c);
}

}  // namespace test
}  // namespace oneflow
================================================
FILE: oneflow/core/common/pcheck.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_PCHECK_H_
#define ONEFLOW_CORE_COMMON_PCHECK_H_

#include <cstring>
#include "oneflow/core/common/maybe.h"

namespace oneflow {

#define PCHECK_OR_RETURN(expr)                                             \
  for (int __err = (expr), *__cond = nullptr; __cond == nullptr; ++__cond) \
  CHECK_EQ_OR_RETURN(__err, 0) << strerror(errno) << " "

}  // namespace oneflow

#endif  // ONEFLOW_CORE_COMMON_PCHECK_H_
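A minimal usage sketch for PCHECK_OR_RETURN, not part of the repository: the macro expects a POSIX-style expression that yields 0 on success and sets errno on failure, and appends strerror(errno) plus anything streamed after the macro to the returned error. MakeDir is a hypothetical helper.

// Illustrative sketch (hypothetical function, assumes a POSIX environment for mkdir).
#include <sys/stat.h>
#include <cerrno>
#include <string>
#include "oneflow/core/common/pcheck.h"

oneflow::Maybe<void> MakeDir(const std::string& path) {
  // Returns an Error Maybe carrying strerror(errno) and the streamed message if mkdir fails.
  PCHECK_OR_RETURN(mkdir(path.c_str(), 0755)) << "failed to create " << path;
  return oneflow::Maybe<void>::Ok();
}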
================================================
FILE: oneflow/core/common/permutation_iterator.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_PERMUTATION_ITERATOR_H_
#define ONEFLOW_CORE_COMMON_PERMUTATION_ITERATOR_H_

#include <iterator>

namespace oneflow {

#define ITER_DEVICE_FUNC __host__ __device__ __forceinline__

template<typename T, typename DataIter, typename IndexIter, typename OffsetT>
class PermutationIterator {
 public:
  using iterator_category = std::random_access_iterator_tag;
  using self_type = PermutationIterator;
  using difference_type = OffsetT;
  using value_type = T;
  using pointer = T*;
  using reference = T&;

  ITER_DEVICE_FUNC PermutationIterator(DataIter data_iter, IndexIter index_iter)
      : data_iter_(data_iter), index_iter_(index_iter) {}

  // const methods
  ITER_DEVICE_FUNC bool operator==(const PermutationIterator& rhs) const {
    return index_iter_ == rhs.index_iter_ && data_iter_ == rhs.data_iter_;
  }
  ITER_DEVICE_FUNC bool operator!=(const PermutationIterator& rhs) const { return !(*this == rhs); }
  template<typename Int>
  ITER_DEVICE_FUNC PermutationIterator operator+(Int n) const {
    return PermutationIterator(data_iter_, index_iter_ + n);
  }
  template<typename Int>
  ITER_DEVICE_FUNC PermutationIterator operator-(Int n) const {
    return PermutationIterator(data_iter_, index_iter_ - n);
  }
  ITER_DEVICE_FUNC difference_type operator-(PermutationIterator other) const {
    return index_iter_ - other.index_iter_;
  }
  ITER_DEVICE_FUNC pointer operator->() const { return &data_iter_[*index_iter_]; }
  ITER_DEVICE_FUNC reference operator*() const { return data_iter_[*index_iter_]; }
  template<typename Int>
  ITER_DEVICE_FUNC reference operator[](Int n) const { return data_iter_[index_iter_[n]]; }

  // mutable methods
  ITER_DEVICE_FUNC PermutationIterator operator++(int) {
    PermutationIterator ret = *this;
    index_iter_++;
    return ret;
  }
  ITER_DEVICE_FUNC PermutationIterator operator++() { index_iter_++; return *this; }
  ITER_DEVICE_FUNC PermutationIterator operator--(int) {
    PermutationIterator ret = *this;
    index_iter_--;
    return ret;
  }
  ITER_DEVICE_FUNC PermutationIterator operator--() { index_iter_--; return *this; }
  template<typename Int>
  ITER_DEVICE_FUNC PermutationIterator& operator+=(Int n) { index_iter_ += n; return *this; }
  template<typename Int>
  ITER_DEVICE_FUNC PermutationIterator& operator-=(Int n) { index_iter_ -= n; return *this; }
  ITER_DEVICE_FUNC pointer operator->() { return &data_iter_[*index_iter_]; }
  ITER_DEVICE_FUNC reference operator*() { return data_iter_[*index_iter_]; }
  template<typename Int>
  ITER_DEVICE_FUNC reference operator[](Int n) { return data_iter_[index_iter_[n]]; }

 private:
  DataIter data_iter_;
  IndexIter index_iter_;
};

#undef ITER_DEVICE_FUNC

}  // namespace oneflow

#endif  // ONEFLOW_CORE_COMMON_PERMUTATION_ITERATOR_H_


================================================
FILE: oneflow/core/common/platform.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_PLATFORM_H_
#define ONEFLOW_CORE_COMMON_PLATFORM_H_

// Set one OF_PLATFORM_* macro and set OF_IS_MOBILE_PLATFORM if the platform is for
// mobile.
#if !defined(OF_PLATFORM_POSIX) && !defined(OF_PLATFORM_GOOGLE)                    \
    && !defined(OF_PLATFORM_POSIX_ANDROID) && !defined(OF_PLATFORM_GOOGLE_ANDROID) \
    && !defined(OF_PLATFORM_WINDOWS)

// Choose which platform we are on.
#if defined(ANDROID) || defined(__ANDROID__)
#define OF_PLATFORM_POSIX_ANDROID
#define OF_IS_MOBILE_PLATFORM

#elif defined(__APPLE__)
#define OF_PLATFORM_POSIX
#include "TargetConditionals.h"
#if OF_TARGET_IPHONE_SIMULATOR
#define OF_IS_MOBILE_PLATFORM
#elif OF_TARGET_OS_IPHONE
#define OF_IS_MOBILE_PLATFORM
#endif

#elif defined(_WIN32)
#define OF_PLATFORM_WINDOWS

#elif defined(__arm__)
#define OF_PLATFORM_POSIX

// Require an outside macro to tell us if we're building for Raspberry Pi.
#if !defined(RASPBERRY_PI)
#define OF_IS_MOBILE_PLATFORM
#endif  // !defined(RASPBERRY_PI)

#else
// If no platform specified, use:
#define OF_PLATFORM_POSIX

#endif
#endif

// Look for both gcc/clang and Visual Studio macros indicating we're compiling
// for an x86 device.
#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64)
#define OF_PLATFORM_IS_X86
#endif

#endif  // ONEFLOW_CORE_COMMON_PLATFORM_H_


================================================
FILE: oneflow/core/common/preprocessor.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
for the specific language governing permissions and limitations under the License.
*/ #ifndef ONEFLOW_CORE_COMMON_PREPROCESSOR_H_ #define ONEFLOW_CORE_COMMON_PREPROCESSOR_H_ #include "oneflow/core/common/preprocessor_internal.h" // basic #define OF_PP_CAT(a, b) OF_PP_INTERNAL_CAT(a, b) #define OF_PP_STRINGIZE(...) OF_PP_INTERNAL_STRINGIZE(__VA_ARGS__) #define OF_PP_PAIR_FIRST(pair) OF_PP_INTERNAL_PAIR_FIRST(pair) #define OF_PP_PAIR_SECOND(pair) OF_PP_INTERNAL_PAIR_SECOND(pair) #define OF_PP_PAIR_THIRD(pair) OF_PP_INTERNAL_PAIR_THIRD(pair) #define OF_PP_TUPLE_SIZE(t) OF_PP_INTERNAL_TUPLE_SIZE(t) #define OF_PP_TUPLE_ELEM(n, t) OF_PP_INTERNAL_TUPLE_ELEM(n, t) #define OF_PP_MAKE_TUPLE_SEQ(...) OF_PP_INTERNAL_MAKE_TUPLE_SEQ(__VA_ARGS__) #define OF_PP_FOR_EACH_TUPLE(macro, seq) OF_PP_INTERNAL_FOR_EACH_TUPLE(macro, seq) #define OF_PP_OUTTER_FOR_EACH_TUPLE(macro, seq) OF_PP_INTERNAL_OUTTER_FOR_EACH_TUPLE(macro, seq) #define OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, ...) \ OF_PP_INTERNAL_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, __VA_ARGS__) // advanced #define OF_PP_VARIADIC_SIZE(...) OF_PP_INTERNAL_VARIADIC_SIZE(__VA_ARGS__) #define OF_PP_SEQ_SIZE(seq) OF_PP_INTERNAL_SEQ_SIZE(seq) #define OF_PP_ATOMIC_TO_TUPLE(x) (x) #define OF_PP_FOR_EACH_ATOMIC(macro, seq) \ OF_PP_FOR_EACH_TUPLE(macro, OF_PP_SEQ_MAP(OF_PP_ATOMIC_TO_TUPLE, seq)) #define OF_PP_SEQ_PRODUCT(seq0, ...) OF_PP_INTERNAL_SEQ_PRODUCT(seq0, __VA_ARGS__) #define OF_PP_SEQ_MAP(macro, seq) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_SEQ_MAP_DO_EACH, (macro), seq) #define OF_PP_I_SEQ_MAP_DO_EACH(macro, elem) (macro(elem)) #define OF_PP_JOIN(glue, ...) OF_PP_INTERNAL_JOIN(glue, __VA_ARGS__) #define OF_PP_TUPLE_PUSH_FRONT(t, x) OF_PP_INTERNAL_TUPLE_PUSH_FRONT(t, x) #define OF_PP_FORCE(...) OF_PP_TUPLE2VARADIC(OF_PP_CAT((__VA_ARGS__), )) #endif // ONEFLOW_CORE_COMMON_PREPROCESSOR_H_ ================================================ FILE: oneflow/core/common/preprocessor_internal.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_PREPROCESSOR_INTERNAL_H_ #define ONEFLOW_CORE_COMMON_PREPROCESSOR_INTERNAL_H_ // Base #define OF_PP_TUPLE2VARADIC(t) OF_PP_TUPLE2VARADIC_I(t) #define OF_PP_TUPLE2VARADIC_I(t) OF_PP_TUPLE2VARADIC_II t #define OF_PP_TUPLE2VARADIC_II(...) __VA_ARGS__ #define OF_PP_INTERNAL_STRINGIZE(...) OF_PP_INTERNAL_STRINGIZE_I(__VA_ARGS__) #define OF_PP_INTERNAL_STRINGIZE_I(...) #__VA_ARGS__ #define OF_PP_INTERNAL_CAT(a, b) OF_PP_INTERNAL_CAT_I(a, b) #define OF_PP_INTERNAL_CAT_I(a, b) a##b #define OF_PP_INTERNAL_JOIN(glue, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_JOIN_, OF_PP_INTERNAL_VARIADIC_SIZE(__VA_ARGS__))( \ glue, __VA_ARGS__), ) #define OF_PP_INTERNAL_JOIN_0(glue) #define OF_PP_INTERNAL_JOIN_1(glue, x) x #define OF_PP_INTERNAL_JOIN_2(glue, x, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_1(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_3(glue, x, ...) 
\ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_2(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_4(glue, x, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_3(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_5(glue, x, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_4(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_6(glue, x, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_5(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_7(glue, x, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_6(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_8(glue, x, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_7(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_9(glue, x, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_8(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_10(glue, x, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_9(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_11(glue, x, ...) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \ OF_PP_INTERNAL_JOIN_10(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_12(glue, x, ...) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \ OF_PP_INTERNAL_JOIN_11(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_13(glue, x, ...) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \ OF_PP_INTERNAL_JOIN_12(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_14(glue, x, ...) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \ OF_PP_INTERNAL_JOIN_13(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_JOIN_15(glue, x, ...) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \ OF_PP_INTERNAL_JOIN_14(glue, __VA_ARGS__)), ) #define OF_PP_INTERNAL_SEQ_HEAD(seq) OF_PP_INTERNAL_PAIR_FIRST(OF_PP_INTERNAL_SEQ_TO_PAIR(seq)) #define OF_PP_INTERNAL_SEQ_TAIL(seq) OF_PP_INTERNAL_PAIR_SECOND(OF_PP_INTERNAL_SEQ_TO_PAIR(seq)) #define OF_PP_INTERNAL_SEQ_TO_PAIR(seq) (OF_PP_INTERNAL_SEQ_TO_PAIR_ seq) #define OF_PP_INTERNAL_SEQ_TO_PAIR_(x) x, OF_PP_INTERNAL_NIL #define OF_PP_INTERNAL_NIL #define OF_PP_INTERNAL_PAIR_FIRST(t) OF_PP_INTERNAL_PAIR_FIRST_I(t) #define OF_PP_INTERNAL_PAIR_FIRST_I(t) OF_PP_INTERNAL_FIRST_ARG t #define OF_PP_INTERNAL_PAIR_SECOND(t) OF_PP_INTERNAL_PAIR_SECOND_I(t) #define OF_PP_INTERNAL_PAIR_SECOND_I(t) OF_PP_INTERNAL_SECOND_ARG t #define OF_PP_INTERNAL_PAIR_THIRD(t) OF_PP_INTERNAL_PAIR_THIRD_I(t) #define OF_PP_INTERNAL_PAIR_THIRD_I(t) OF_PP_INTERNAL_THIRD_ARG t #define OF_PP_INTERNAL_FIRST_ARG(x, ...) x #define OF_PP_INTERNAL_SECOND_ARG(x, y, ...) y #define OF_PP_INTERNAL_THIRD_ARG(x, y, z, ...) z #define OF_PP_INTERNAL_MAKE_TUPLE(...) (__VA_ARGS__) #define OF_PP_INTERNAL_MAKE_TUPLE_SEQ(...) 
(OF_PP_INTERNAL_MAKE_TUPLE(__VA_ARGS__)) // Tuple #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT(tuple, x) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_TUPLE_PUSH_FRONT_, OF_PP_INTERNAL_TUPLE_SIZE(tuple)) \ (tuple, x) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_0(tuple, x) (x) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_1(tuple, x) (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple)) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_2(tuple, x) \ (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple)) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_3(tuple, x) \ (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(2, tuple)) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_4(tuple, x) \ (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple)) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_5(tuple, x) \ (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(4, tuple)) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_6(tuple, x) \ (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(4, tuple), OF_PP_INTERNAL_TUPLE_ELEM(5, tuple)) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_7(tuple, x) \ (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(4, tuple), OF_PP_INTERNAL_TUPLE_ELEM(5, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(6, tuple)) #define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_8(tuple, x) \ (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(4, tuple), OF_PP_INTERNAL_TUPLE_ELEM(5, tuple), \ OF_PP_INTERNAL_TUPLE_ELEM(6, tuple), OF_PP_INTERNAL_TUPLE_ELEM(7, tuple)) #define OF_PP_INTERNAL_TUPLE_ELEM(n, t) OF_PP_INTERNAL_TUPLE_ELEM_I(n, t) #define OF_PP_INTERNAL_TUPLE_ELEM_I(n, t) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_ARG_, n) t, ) #define OF_PP_INTERNAL_ARG_0(a0, ...) a0 #define OF_PP_INTERNAL_ARG_1(a0, a1, ...) a1 #define OF_PP_INTERNAL_ARG_2(a0, a1, a2, ...) a2 #define OF_PP_INTERNAL_ARG_3(a0, a1, a2, a3, ...) a3 #define OF_PP_INTERNAL_ARG_4(a0, a1, a2, a3, a4, ...) a4 #define OF_PP_INTERNAL_ARG_5(a0, a1, a2, a3, a4, a5, ...) a5 #define OF_PP_INTERNAL_ARG_6(a0, a1, a2, a3, a4, a5, a6, ...) a6 #define OF_PP_INTERNAL_ARG_7(a0, a1, a2, a3, a4, a5, a6, a7, ...) a7 #define OF_PP_INTERNAL_TUPLE_SIZE(tuple) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_TUPLE_SIZE_, OF_PP_INTERNAL_IS_TUPLE_EMPTY(tuple)) \ (tuple) #define OF_PP_INTERNAL_TUPLE_SIZE_1(t) 0 #define OF_PP_INTERNAL_TUPLE_SIZE_0(t) OF_PP_INTERNAL_TUPLE_SIZE_0_I(t) #define OF_PP_INTERNAL_TUPLE_SIZE_0_I(t) OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_VARIADIC_SIZE t, ) #define OF_PP_INTERNAL_VARIADIC_SIZE(...) 
\ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_VARIADIC_SIZE_I( \ __VA_ARGS__, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, \ 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, \ 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ), ) #define OF_PP_INTERNAL_VARIADIC_SIZE_I( \ e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, e10, e11, e12, e13, e14, e15, e16, e17, e18, e19, e20, \ e21, e22, e23, e24, e25, e26, e27, e28, e29, e30, e31, e32, e33, e34, e35, e36, e37, e38, e39, \ e40, e41, e42, e43, e44, e45, e46, e47, e48, e49, e50, e51, e52, e53, e54, e55, e56, e57, e58, \ e59, e60, e61, e62, e63, size, ...) \ size #define OF_PP_INTERNAL_IS_TUPLE_EMPTY(t) OF_PP_INTERNAL_IS_TUPLE_EMPTY_I(t) #define OF_PP_INTERNAL_IS_TUPLE_EMPTY_I(t) OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_IS_VARIADIC_EMPTY t, ) #define OF_PP_INTERNAL_IS_VARIADIC_EMPTY(...) \ OF_PP_INTERNAL_IS_VARIADIC_EMPTY_(/* test if there is just one argument, \ eventually an empty one */ \ OF_PP_INTERNAL_VARIADIC_HAS_COMMA( \ __VA_ARGS__), /* test if \ _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_ \ together with the \ argument adds a comma \ */ \ OF_PP_INTERNAL_VARIADIC_HAS_COMMA( \ _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_ \ __VA_ARGS__), /* test if the \ argument together \ with a \ parenthesis adds \ a comma \ */ \ OF_PP_INTERNAL_VARIADIC_HAS_COMMA(__VA_ARGS__( \ /*empty*/)), /* test if placing it \ between \ _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_ \ and the \ parenthesis adds a \ comma */ \ OF_PP_INTERNAL_VARIADIC_HAS_COMMA( \ _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_ __VA_ARGS__( \ /*empty*/))) #define OF_PP_INTERNAL_IS_VARIADIC_EMPTY_(e0, e1, e2, e3) \ OF_PP_INTERNAL_VARIADIC_HAS_COMMA( \ OF_PP_INTERNAL_CAT5(OF_PP_INTERNAL_IS_EMPTY_CASE_, e0, e1, e2, e3)) #define OF_PP_INTERNAL_VARIADIC_HAS_COMMA(...) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_VARIADIC_HAS_COMMA_I( \ __VA_ARGS__, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0), ) #define OF_PP_INTERNAL_VARIADIC_HAS_COMMA_I( \ e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, e10, e11, e12, e13, e14, e15, e16, e17, e18, e19, e20, \ e21, e22, e23, e24, e25, e26, e27, e28, e29, e30, e31, e32, e33, e34, e35, e36, e37, e38, e39, \ e40, e41, e42, e43, e44, e45, e46, e47, e48, e49, e50, e51, e52, e53, e54, e55, e56, e57, e58, \ e59, e60, e61, e62, e63, has_comma, ...) \ has_comma #define _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_(...) , #define OF_PP_INTERNAL_CAT5(e0, e1, e2, e3, e4) e0##e1##e2##e3##e4 #define OF_PP_INTERNAL_IS_EMPTY_CASE_0001 , // Seq Product #define OF_PP_INTERNAL_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, seq0, ...) \ OF_PP_INTERNAL_SEQ_FOR_EACH_TUPLE(macro, _, OF_PP_INTERNAL_SEQ_PRODUCT(seq0, __VA_ARGS__)) #define OF_PP_INTERNAL_SEQ_PRODUCT(seq0, ...) \ OF_PP_INTERNAL_CAT( \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_SEQ_PRODUCT_, \ OF_PP_INTERNAL_VARIADIC_SIZE(seq0, __VA_ARGS__))(seq0, __VA_ARGS__), ) #define OF_PP_INTERNAL_SEQ_PRODUCT_0() #define OF_PP_INTERNAL_SEQ_PRODUCT_1(seq0) OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ((()), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_2(seq0, ...) \ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_1(__VA_ARGS__), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_3(seq0, ...) \ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_2(__VA_ARGS__), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_4(seq0, ...) 
\ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_3(__VA_ARGS__), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_5(seq0, ...) \ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_4(__VA_ARGS__), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_6(seq0, ...) \ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_5(__VA_ARGS__), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_7(seq0, ...) \ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_6(__VA_ARGS__), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_8(seq0, ...) \ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_7(__VA_ARGS__), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_9(seq0, ...) \ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_8(__VA_ARGS__), seq0) #define OF_PP_INTERNAL_SEQ_PRODUCT_10(seq0, ...) \ OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_9(__VA_ARGS__), seq0) // Seq ForEach #define OF_PP_INTERNAL_OUTTER_FOR_EACH_TUPLE(macro, seq) \ OF_PP_INTERNAL_OUTTER_SEQ_FOR_EACH_TUPLE(macro, _, seq) #define OF_PP_INTERNAL_FOR_EACH_TUPLE(macro, seq) OF_PP_INTERNAL_SEQ_FOR_EACH_TUPLE(macro, _, seq) #define OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(tuple_seq, atomic_seq) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH(OF_PP_INTERNAL_D1_APPLY_ATOMIC_WITH_DATA, \ OF_PP_INTERNAL_TUPLE_X_ATOMIC_SEQ, atomic_seq, tuple_seq) #define OF_PP_INTERNAL_TUPLE_X_ATOMIC_SEQ(atomic_seq, tuple) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH(OF_PP_INTERNAL_D2_APPLY_ATOMIC_WITH_DATA, \ OF_PP_INTERNAL_MAKE_SEQ_TUPLE_PUSH_FRONT, tuple, atomic_seq) #define OF_PP_INTERNAL_D1_APPLY_ATOMIC_WITH_DATA(m, d, x) m(d, x) #define OF_PP_INTERNAL_D2_APPLY_ATOMIC_WITH_DATA(m, d, x) m(d, x) #define OF_PP_INTERNAL_MAKE_SEQ_TUPLE_PUSH_FRONT(tuple, x) \ (OF_PP_INTERNAL_TUPLE_PUSH_FRONT(tuple, x)) // Seq Size #define OF_PP_INTERNAL_SEQ_SIZE(seq) OF_PP_INTERNAL_SEQ_SIZE_I(seq) #define OF_PP_INTERNAL_SEQ_SIZE_I(seq) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_SEQ_SIZE_, OF_PP_INTERNAL_SEQ_SIZE_0 seq) #define OF_PP_INTERNAL_OUTTER_SEQ_FOR_EACH_TUPLE OF_PP_INTERNAL_D0_SEQ_FOR_EACH_TUPLE #define OF_PP_INTERNAL_SEQ_FOR_EACH_TUPLE OF_PP_INTERNAL_D1_SEQ_FOR_EACH_TUPLE #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_TUPLE(m, d, seq) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH(OF_PP_INTERNAL_D0_APPLY_TUPLE, m, d, seq) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_TUPLE(m, d, seq) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH(OF_PP_INTERNAL_APPLY_TUPLE, m, d, seq) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_TUPLE(m, d, seq) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH(OF_PP_INTERNAL_APPLY_TUPLE, m, d, seq) #define OF_PP_INTERNAL_SEQ_FOR_EACH_ATOMIC OF_PP_INTERNAL_D1_SEQ_FOR_EACH_ATOMIC #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_ATOMIC(m, d, seq) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH(OF_PP_INTERNAL_APPLY_ATOMIC, m, d, seq) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_ATOMIC(m, d, seq) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH(OF_PP_INTERNAL_APPLY_ATOMIC, m, d, seq) #define OF_PP_INTERNAL_D0_APPLY_TUPLE(m, d, t) OF_PP_INTERNAL_D0_APPLY_TUPLE_I(m, d, t) #define OF_PP_INTERNAL_D0_APPLY_TUPLE_I(m, d, t) m t #define OF_PP_INTERNAL_APPLY_TUPLE(m, d, t) OF_PP_INTERNAL_APPLY_TUPLE_I(m, d, t) #define OF_PP_INTERNAL_APPLY_TUPLE_I(m, d, t) m t #define OF_PP_INTERNAL_APPLY_ATOMIC(m, d, x) m(x) #define OF_PP_INTERNAL_APPLY_ATOMIC_WITH_DATA(m, d, x) m(d, x) #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH(apply, m, d, seq) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_D0_SEQ_FOR_EACH_, OF_PP_INTERNAL_SEQ_SIZE(seq)) \ (apply, m, d, seq) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH(apply, m, d, seq) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_D1_SEQ_FOR_EACH_, 
OF_PP_INTERNAL_SEQ_SIZE(seq)) \ (apply, m, d, seq) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH(apply, m, d, seq) \ OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_D2_SEQ_FOR_EACH_, OF_PP_INTERNAL_SEQ_SIZE(seq)) \ (apply, m, d, seq) #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_0(apply, m, d, seq) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_0(apply, m, d, seq) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_0(apply, m, d, seq) // php code to generate iterator macro // clang-format off /* #define OF_PP_INTERNAL_SEQ_SIZE_(_) OF_PP_INTERNAL_SEQ_SIZE_ #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_ #define OF_PP_INTERNAL_D_SEQ_FOR_EACH_(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D_SEQ_FOR_EACH_(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) */ // clang-format on // do not edit iterator macro directly, it's generated by the above php code. #define OF_PP_INTERNAL_SEQ_SIZE_0(_) OF_PP_INTERNAL_SEQ_SIZE_1 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_0 0 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_1(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_0(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_1(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_0(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_1(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_0(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_1(_) OF_PP_INTERNAL_SEQ_SIZE_2 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_1 1 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_2(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_1(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_2(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_1(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_2(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_1(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_2(_) OF_PP_INTERNAL_SEQ_SIZE_3 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_2 2 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_3(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_2(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_3(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_2(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_3(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_2(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_3(_) OF_PP_INTERNAL_SEQ_SIZE_4 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_3 3 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_4(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_3(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_4(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_3(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_4(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_3(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_4(_) 
OF_PP_INTERNAL_SEQ_SIZE_5 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_4 4 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_5(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_4(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_5(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_4(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_5(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_4(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_5(_) OF_PP_INTERNAL_SEQ_SIZE_6 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_5 5 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_6(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_5(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_6(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_5(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_6(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_5(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_6(_) OF_PP_INTERNAL_SEQ_SIZE_7 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_6 6 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_7(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_6(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_7(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_6(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_7(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_6(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_7(_) OF_PP_INTERNAL_SEQ_SIZE_8 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_7 7 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_8(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_7(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_8(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_7(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_8(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_7(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_8(_) OF_PP_INTERNAL_SEQ_SIZE_9 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_8 8 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_9(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_8(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_9(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_8(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_9(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_8(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_9(_) OF_PP_INTERNAL_SEQ_SIZE_10 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_9 9 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_10(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_9(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_10(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_9(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_10(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_9(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_10(_) OF_PP_INTERNAL_SEQ_SIZE_11 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_10 10 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_11(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_10(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_11(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_10(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_11(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_10(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_11(_) OF_PP_INTERNAL_SEQ_SIZE_12 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_11 11 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_12(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_11(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_12(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_11(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_12(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_11(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_12(_) OF_PP_INTERNAL_SEQ_SIZE_13 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_12 12 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_13(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_12(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_13(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_12(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_13(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_12(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_13(_) OF_PP_INTERNAL_SEQ_SIZE_14 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_13 13 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_14(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_13(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_14(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_13(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_14(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_13(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_14(_) OF_PP_INTERNAL_SEQ_SIZE_15 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_14 14 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_15(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_14(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_15(apply, m, d, seq) \ apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_14(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_15(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_14(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_15(_) OF_PP_INTERNAL_SEQ_SIZE_16 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_15 15 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_16(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_15(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_16(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_15(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_16(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_15(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_16(_) OF_PP_INTERNAL_SEQ_SIZE_17 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_16 16 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_17(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_16(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_17(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_16(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_17(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_16(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_17(_) OF_PP_INTERNAL_SEQ_SIZE_18 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_17 17 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_18(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_17(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_18(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_17(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_18(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_17(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_18(_) OF_PP_INTERNAL_SEQ_SIZE_19 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_18 18 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_19(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_18(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_19(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_18(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_19(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_18(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_19(_) OF_PP_INTERNAL_SEQ_SIZE_20 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_19 19 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_20(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_19(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_20(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_19(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_20(apply, m, d, seq) \ 
apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_19(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_20(_) OF_PP_INTERNAL_SEQ_SIZE_21 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_20 20 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_21(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_20(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_21(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_20(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_21(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_20(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_21(_) OF_PP_INTERNAL_SEQ_SIZE_22 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_21 21 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_22(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_21(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_22(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_21(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_22(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_21(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_22(_) OF_PP_INTERNAL_SEQ_SIZE_23 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_22 22 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_23(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_22(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_23(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_22(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_23(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_22(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_23(_) OF_PP_INTERNAL_SEQ_SIZE_24 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_23 23 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_24(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_23(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_24(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_23(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_24(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_23(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_24(_) OF_PP_INTERNAL_SEQ_SIZE_25 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_24 24 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_25(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_24(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_25(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_24(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_25(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_24(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_25(_) 
OF_PP_INTERNAL_SEQ_SIZE_26 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_25 25 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_26(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_25(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_26(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_25(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_26(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_25(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_26(_) OF_PP_INTERNAL_SEQ_SIZE_27 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_26 26 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_27(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_26(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_27(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_26(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_27(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_26(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_27(_) OF_PP_INTERNAL_SEQ_SIZE_28 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_27 27 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_28(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_27(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_28(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_27(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_28(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_27(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_28(_) OF_PP_INTERNAL_SEQ_SIZE_29 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_28 28 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_29(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_28(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_29(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_28(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_29(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_28(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_29(_) OF_PP_INTERNAL_SEQ_SIZE_30 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_29 29 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_30(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_29(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_30(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_29(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_30(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_29(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_30(_) OF_PP_INTERNAL_SEQ_SIZE_31 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_30 30 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_31(apply, m, d, seq) \ apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_30(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_31(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_30(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_31(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_30(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_31(_) OF_PP_INTERNAL_SEQ_SIZE_32 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_31 31 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_32(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_31(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_32(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_31(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_32(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_31(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_32(_) OF_PP_INTERNAL_SEQ_SIZE_33 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_32 32 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_33(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_32(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_33(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_32(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_33(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_32(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_33(_) OF_PP_INTERNAL_SEQ_SIZE_34 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_33 33 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_34(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_33(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_34(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_33(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_34(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_33(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_34(_) OF_PP_INTERNAL_SEQ_SIZE_35 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_34 34 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_35(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_34(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_35(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_34(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_35(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_34(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_35(_) OF_PP_INTERNAL_SEQ_SIZE_36 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_35 35 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_36(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_35(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_36(apply, m, d, seq) \ 
apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_35(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_36(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_35(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_36(_) OF_PP_INTERNAL_SEQ_SIZE_37 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_36 36 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_37(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_36(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_37(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_36(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_37(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_36(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_37(_) OF_PP_INTERNAL_SEQ_SIZE_38 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_37 37 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_38(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_37(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_38(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_37(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_38(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_37(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_38(_) OF_PP_INTERNAL_SEQ_SIZE_39 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_38 38 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_39(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_38(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_39(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_38(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_39(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_38(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_39(_) OF_PP_INTERNAL_SEQ_SIZE_40 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_39 39 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_40(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_39(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_40(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_39(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_40(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_39(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_40(_) OF_PP_INTERNAL_SEQ_SIZE_41 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_40 40 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_41(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_40(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_41(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_40(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_41(apply, m, 
d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_40(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_41(_) OF_PP_INTERNAL_SEQ_SIZE_42 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_41 41 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_42(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_41(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_42(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_41(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_42(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_41(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_42(_) OF_PP_INTERNAL_SEQ_SIZE_43 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_42 42 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_43(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_42(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_43(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_42(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_43(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_42(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_43(_) OF_PP_INTERNAL_SEQ_SIZE_44 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_43 43 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_44(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_43(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_44(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_43(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_44(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_43(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_44(_) OF_PP_INTERNAL_SEQ_SIZE_45 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_44 44 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_45(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_44(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_45(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_44(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_45(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_44(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_45(_) OF_PP_INTERNAL_SEQ_SIZE_46 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_45 45 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_46(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_45(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_46(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_45(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_46(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_45(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_46(_) 
OF_PP_INTERNAL_SEQ_SIZE_47 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_46 46 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_47(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_46(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_47(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_46(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_47(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_46(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_47(_) OF_PP_INTERNAL_SEQ_SIZE_48 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_47 47 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_48(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_47(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_48(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_47(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_48(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_47(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_48(_) OF_PP_INTERNAL_SEQ_SIZE_49 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_48 48 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_49(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_48(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_49(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_48(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_49(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_48(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_49(_) OF_PP_INTERNAL_SEQ_SIZE_50 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_49 49 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_50(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_49(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_50(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_49(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_50(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_49(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_50(_) OF_PP_INTERNAL_SEQ_SIZE_51 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_50 50 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_51(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_50(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_51(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_50(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_51(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_50(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_51(_) OF_PP_INTERNAL_SEQ_SIZE_52 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_51 51 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_52(apply, m, d, seq) \ apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_51(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_52(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_51(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_52(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_51(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_52(_) OF_PP_INTERNAL_SEQ_SIZE_53 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_52 52 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_53(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_52(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_53(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_52(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_53(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_52(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_53(_) OF_PP_INTERNAL_SEQ_SIZE_54 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_53 53 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_54(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_53(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_54(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_53(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_54(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_53(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_54(_) OF_PP_INTERNAL_SEQ_SIZE_55 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_54 54 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_55(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_54(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_55(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_54(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_55(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_54(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_55(_) OF_PP_INTERNAL_SEQ_SIZE_56 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_55 55 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_56(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_55(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_56(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_55(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_56(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_55(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_56(_) OF_PP_INTERNAL_SEQ_SIZE_57 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_56 56 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_57(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_56(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_57(apply, m, d, seq) \ 
apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_56(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_57(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_56(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_57(_) OF_PP_INTERNAL_SEQ_SIZE_58 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_57 57 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_58(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_57(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_58(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_57(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_58(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_57(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_58(_) OF_PP_INTERNAL_SEQ_SIZE_59 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_58 58 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_59(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_58(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_59(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_58(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_59(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_58(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_59(_) OF_PP_INTERNAL_SEQ_SIZE_60 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_59 59 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_60(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_59(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_60(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_59(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_60(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_59(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_60(_) OF_PP_INTERNAL_SEQ_SIZE_61 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_60 60 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_61(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_60(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_61(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_60(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_61(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_60(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_61(_) OF_PP_INTERNAL_SEQ_SIZE_62 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_61 61 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_62(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_61(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_62(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_61(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_62(apply, m, 
d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_61(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_62(_) OF_PP_INTERNAL_SEQ_SIZE_63 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_62 62 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_63(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_62(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_63(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_62(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_63(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_62(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_63(_) OF_PP_INTERNAL_SEQ_SIZE_64 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_63 63 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_64(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_63(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_64(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_63(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_64(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_63(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_64(_) OF_PP_INTERNAL_SEQ_SIZE_65 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_64 64 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_65(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_64(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_65(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_64(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_65(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_64(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_65(_) OF_PP_INTERNAL_SEQ_SIZE_66 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_65 65 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_66(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_65(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_66(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_65(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_66(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_65(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_66(_) OF_PP_INTERNAL_SEQ_SIZE_67 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_66 66 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_67(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_66(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_67(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_66(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_67(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_66(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_67(_) 
OF_PP_INTERNAL_SEQ_SIZE_68 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_67 67 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_68(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_67(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_68(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_67(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_68(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_67(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_68(_) OF_PP_INTERNAL_SEQ_SIZE_69 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_68 68 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_69(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_68(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_69(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_68(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_69(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_68(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_69(_) OF_PP_INTERNAL_SEQ_SIZE_70 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_69 69 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_70(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_69(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_70(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_69(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_70(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_69(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_70(_) OF_PP_INTERNAL_SEQ_SIZE_71 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_70 70 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_71(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_70(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_71(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_70(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_71(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_70(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_71(_) OF_PP_INTERNAL_SEQ_SIZE_72 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_71 71 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_72(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_71(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_72(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_71(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_72(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_71(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_72(_) OF_PP_INTERNAL_SEQ_SIZE_73 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_72 72 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_73(apply, m, d, seq) \ apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_72(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_73(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_72(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_73(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_72(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_73(_) OF_PP_INTERNAL_SEQ_SIZE_74 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_73 73 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_74(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_73(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_74(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_73(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_74(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_73(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_74(_) OF_PP_INTERNAL_SEQ_SIZE_75 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_74 74 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_75(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_74(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_75(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_74(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_75(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_74(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_75(_) OF_PP_INTERNAL_SEQ_SIZE_76 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_75 75 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_76(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_75(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_76(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_75(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_76(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_75(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_76(_) OF_PP_INTERNAL_SEQ_SIZE_77 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_76 76 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_77(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_76(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_77(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_76(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_77(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_76(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_77(_) OF_PP_INTERNAL_SEQ_SIZE_78 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_77 77 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_78(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_77(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_78(apply, m, d, seq) \ 
apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_77(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_78(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_77(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_78(_) OF_PP_INTERNAL_SEQ_SIZE_79 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_78 78 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_79(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_78(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_79(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_78(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_79(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_78(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_79(_) OF_PP_INTERNAL_SEQ_SIZE_80 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_79 79 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_80(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_79(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_80(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_79(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_80(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_79(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_80(_) OF_PP_INTERNAL_SEQ_SIZE_81 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_80 80 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_81(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_80(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_81(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_80(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_81(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_80(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_81(_) OF_PP_INTERNAL_SEQ_SIZE_82 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_81 81 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_82(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_81(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_82(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_81(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_82(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_81(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_82(_) OF_PP_INTERNAL_SEQ_SIZE_83 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_82 82 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_83(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_82(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_83(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_82(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_83(apply, m, 
d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_82(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_83(_) OF_PP_INTERNAL_SEQ_SIZE_84 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_83 83 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_84(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_83(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_84(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_83(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_84(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_83(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_84(_) OF_PP_INTERNAL_SEQ_SIZE_85 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_84 84 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_85(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_84(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_85(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_84(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_85(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_84(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_85(_) OF_PP_INTERNAL_SEQ_SIZE_86 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_85 85 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_86(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_85(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_86(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_85(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_86(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_85(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_86(_) OF_PP_INTERNAL_SEQ_SIZE_87 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_86 86 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_87(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_86(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_87(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_86(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_87(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_86(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_87(_) OF_PP_INTERNAL_SEQ_SIZE_88 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_87 87 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_88(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_87(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_88(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_87(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_88(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_87(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_88(_) 
OF_PP_INTERNAL_SEQ_SIZE_89 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_88 88 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_89(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_88(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_89(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_88(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_89(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_88(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_89(_) OF_PP_INTERNAL_SEQ_SIZE_90 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_89 89 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_90(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_89(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_90(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_89(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_90(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_89(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_90(_) OF_PP_INTERNAL_SEQ_SIZE_91 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_90 90 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_91(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_90(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_91(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_90(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_91(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_90(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_91(_) OF_PP_INTERNAL_SEQ_SIZE_92 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_91 91 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_92(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_91(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_92(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_91(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_92(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_91(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_92(_) OF_PP_INTERNAL_SEQ_SIZE_93 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_92 92 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_93(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_92(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_93(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_92(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_93(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_92(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_93(_) OF_PP_INTERNAL_SEQ_SIZE_94 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_93 93 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_94(apply, m, d, seq) \ apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_93(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_94(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_93(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_94(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_93(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_94(_) OF_PP_INTERNAL_SEQ_SIZE_95 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_94 94 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_95(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_94(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_95(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_94(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_95(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_94(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_95(_) OF_PP_INTERNAL_SEQ_SIZE_96 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_95 95 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_96(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_95(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_96(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_95(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_96(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_95(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_96(_) OF_PP_INTERNAL_SEQ_SIZE_97 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_96 96 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_97(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_96(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_97(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_96(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_97(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_96(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_97(_) OF_PP_INTERNAL_SEQ_SIZE_98 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_97 97 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_98(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_97(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_98(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_97(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_98(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_97(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_98(_) OF_PP_INTERNAL_SEQ_SIZE_99 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_98 98 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_99(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_98(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_99(apply, m, d, seq) \ 
apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_98(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_99(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_98(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_99(_) OF_PP_INTERNAL_SEQ_SIZE_100 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_99 99 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_100(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_99(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_100(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_99(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_100(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_99(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_100(_) OF_PP_INTERNAL_SEQ_SIZE_101 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_100 100 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_101(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_100(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_101(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_100(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_101(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_100(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_101(_) OF_PP_INTERNAL_SEQ_SIZE_102 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_101 101 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_102(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_101(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_102(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_101(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_102(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_101(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_102(_) OF_PP_INTERNAL_SEQ_SIZE_103 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_102 102 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_103(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_102(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_103(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_102(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_103(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_102(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_103(_) OF_PP_INTERNAL_SEQ_SIZE_104 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_103 103 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_104(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_103(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_104(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_103(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_104(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_103(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_104(_) OF_PP_INTERNAL_SEQ_SIZE_105 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_104 104 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_105(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_104(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_105(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_104(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_105(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_104(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_105(_) OF_PP_INTERNAL_SEQ_SIZE_106 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_105 105 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_106(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_105(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_106(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_105(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_106(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_105(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_106(_) OF_PP_INTERNAL_SEQ_SIZE_107 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_106 106 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_107(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_106(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_107(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_106(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_107(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_106(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_107(_) OF_PP_INTERNAL_SEQ_SIZE_108 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_107 107 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_108(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_107(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_108(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_107(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_108(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_107(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_108(_) OF_PP_INTERNAL_SEQ_SIZE_109 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_108 108 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_109(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_108(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_109(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_108(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_109(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_108(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_109(_) OF_PP_INTERNAL_SEQ_SIZE_110 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_109 109 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_110(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_109(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_110(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_109(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_110(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_109(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_110(_) OF_PP_INTERNAL_SEQ_SIZE_111 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_110 110 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_111(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_110(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_111(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_110(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_111(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_110(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_111(_) OF_PP_INTERNAL_SEQ_SIZE_112 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_111 111 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_112(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_111(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_112(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_111(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_112(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_111(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_112(_) OF_PP_INTERNAL_SEQ_SIZE_113 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_112 112 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_113(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_112(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_113(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_112(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_113(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_112(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_113(_) OF_PP_INTERNAL_SEQ_SIZE_114 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_113 113 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_114(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_113(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_114(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_113(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_114(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_113(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_114(_) 
OF_PP_INTERNAL_SEQ_SIZE_115 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_114 114 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_115(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_114(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_115(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_114(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_115(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_114(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_115(_) OF_PP_INTERNAL_SEQ_SIZE_116 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_115 115 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_116(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_115(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_116(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_115(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_116(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_115(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_116(_) OF_PP_INTERNAL_SEQ_SIZE_117 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_116 116 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_117(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_116(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_117(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_116(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_117(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_116(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_117(_) OF_PP_INTERNAL_SEQ_SIZE_118 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_117 117 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_118(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_117(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_118(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_117(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_118(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_117(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_118(_) OF_PP_INTERNAL_SEQ_SIZE_119 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_118 118 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_119(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_118(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_119(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_118(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_119(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_118(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_119(_) OF_PP_INTERNAL_SEQ_SIZE_120 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_119 119 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_120(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_119(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_120(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_119(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_120(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_119(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_120(_) OF_PP_INTERNAL_SEQ_SIZE_121 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_120 120 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_121(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_120(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_121(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_120(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_121(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_120(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_121(_) OF_PP_INTERNAL_SEQ_SIZE_122 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_121 121 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_122(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_121(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_122(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_121(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_122(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_121(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_122(_) OF_PP_INTERNAL_SEQ_SIZE_123 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_122 122 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_123(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_122(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_123(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_122(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_123(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_122(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_123(_) OF_PP_INTERNAL_SEQ_SIZE_124 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_123 123 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_124(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_123(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_124(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_123(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_124(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_123(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_124(_) OF_PP_INTERNAL_SEQ_SIZE_125 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_124 124 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_125(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_124(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_125(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_124(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_125(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_124(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_125(_) OF_PP_INTERNAL_SEQ_SIZE_126 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_125 125 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_126(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_125(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_126(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_125(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_126(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_125(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_126(_) OF_PP_INTERNAL_SEQ_SIZE_127 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_126 126 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_127(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_126(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_127(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_126(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_127(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_126(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_127(_) OF_PP_INTERNAL_SEQ_SIZE_128 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_127 127 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_128(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_127(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_128(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_127(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_128(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_127(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_128(_) OF_PP_INTERNAL_SEQ_SIZE_129 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_128 128 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_129(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_128(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_129(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_128(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_129(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_128(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_129(_) OF_PP_INTERNAL_SEQ_SIZE_130 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_129 129 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_130(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_129(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_130(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_129(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_130(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_129(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_130(_) OF_PP_INTERNAL_SEQ_SIZE_131 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_130 130 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_131(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_130(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_131(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_130(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_131(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_130(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_131(_) OF_PP_INTERNAL_SEQ_SIZE_132 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_131 131 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_132(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_131(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_132(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_131(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_132(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_131(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_132(_) OF_PP_INTERNAL_SEQ_SIZE_133 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_132 132 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_133(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_132(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_133(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_132(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_133(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_132(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_133(_) OF_PP_INTERNAL_SEQ_SIZE_134 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_133 133 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_134(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_133(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_134(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_133(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_134(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_133(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_134(_) OF_PP_INTERNAL_SEQ_SIZE_135 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_134 134 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_135(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_134(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_135(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_134(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_135(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_134(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_135(_) OF_PP_INTERNAL_SEQ_SIZE_136 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_135 135 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_136(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_135(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_136(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_135(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_136(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_135(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_136(_) OF_PP_INTERNAL_SEQ_SIZE_137 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_136 136 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_137(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_136(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_137(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_136(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_137(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_136(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_137(_) OF_PP_INTERNAL_SEQ_SIZE_138 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_137 137 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_138(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_137(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_138(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_137(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_138(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_137(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_138(_) OF_PP_INTERNAL_SEQ_SIZE_139 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_138 138 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_139(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_138(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_139(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_138(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_139(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_138(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_139(_) OF_PP_INTERNAL_SEQ_SIZE_140 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_139 139 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_140(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_139(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_140(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_139(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_140(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_139(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_140(_) OF_PP_INTERNAL_SEQ_SIZE_141 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_140 140 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_141(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_140(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_141(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_140(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_141(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_140(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_141(_) OF_PP_INTERNAL_SEQ_SIZE_142 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_141 141 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_142(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_141(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_142(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_141(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_142(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_141(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_142(_) OF_PP_INTERNAL_SEQ_SIZE_143 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_142 142 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_143(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_142(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_143(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_142(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_143(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_142(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_143(_) OF_PP_INTERNAL_SEQ_SIZE_144 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_143 143 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_144(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_143(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_144(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_143(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_144(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_143(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_144(_) OF_PP_INTERNAL_SEQ_SIZE_145 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_144 144 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_145(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_144(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_145(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_144(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_145(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_144(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_145(_) OF_PP_INTERNAL_SEQ_SIZE_146 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_145 145 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_146(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_145(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_146(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_145(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_146(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_145(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_146(_) OF_PP_INTERNAL_SEQ_SIZE_147 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_146 146 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_147(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_146(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_147(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_146(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_147(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_146(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_147(_) OF_PP_INTERNAL_SEQ_SIZE_148 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_147 147 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_148(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_147(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_148(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_147(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_148(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_147(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_148(_) OF_PP_INTERNAL_SEQ_SIZE_149 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_148 148 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_149(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_148(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_149(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_148(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_149(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_148(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_149(_) OF_PP_INTERNAL_SEQ_SIZE_150 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_149 149 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_150(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_149(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_150(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_149(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_150(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_149(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_150(_) 
OF_PP_INTERNAL_SEQ_SIZE_151 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_150 150 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_151(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_150(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_151(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_150(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_151(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_150(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_151(_) OF_PP_INTERNAL_SEQ_SIZE_152 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_151 151 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_152(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_151(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_152(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_151(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_152(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_151(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_152(_) OF_PP_INTERNAL_SEQ_SIZE_153 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_152 152 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_153(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_152(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_153(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_152(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_153(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_152(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_153(_) OF_PP_INTERNAL_SEQ_SIZE_154 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_153 153 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_154(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_153(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_154(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_153(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_154(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_153(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_154(_) OF_PP_INTERNAL_SEQ_SIZE_155 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_154 154 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_155(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_154(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_155(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_154(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_155(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_154(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_155(_) OF_PP_INTERNAL_SEQ_SIZE_156 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_155 155 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_156(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_155(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_156(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_155(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_156(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_155(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_156(_) OF_PP_INTERNAL_SEQ_SIZE_157 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_156 156 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_157(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_156(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_157(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_156(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_157(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_156(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_157(_) OF_PP_INTERNAL_SEQ_SIZE_158 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_157 157 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_158(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_157(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_158(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_157(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_158(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_157(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_158(_) OF_PP_INTERNAL_SEQ_SIZE_159 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_158 158 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_159(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_158(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_159(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_158(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_159(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_158(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_159(_) OF_PP_INTERNAL_SEQ_SIZE_160 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_159 159 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_160(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_159(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_160(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_159(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_160(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_159(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_160(_) OF_PP_INTERNAL_SEQ_SIZE_161 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_160 160 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_161(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_160(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_161(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_160(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_161(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_160(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_161(_) OF_PP_INTERNAL_SEQ_SIZE_162 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_161 161 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_162(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_161(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_162(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_161(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_162(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_161(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_162(_) OF_PP_INTERNAL_SEQ_SIZE_163 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_162 162 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_163(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_162(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_163(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_162(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_163(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_162(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_163(_) OF_PP_INTERNAL_SEQ_SIZE_164 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_163 163 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_164(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_163(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_164(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_163(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_164(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_163(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_164(_) OF_PP_INTERNAL_SEQ_SIZE_165 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_164 164 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_165(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_164(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_165(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_164(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_165(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_164(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_165(_) OF_PP_INTERNAL_SEQ_SIZE_166 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_165 165 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_166(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_165(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_166(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_165(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_166(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_165(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_166(_) OF_PP_INTERNAL_SEQ_SIZE_167 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_166 166 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_167(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_166(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_167(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_166(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_167(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_166(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_167(_) OF_PP_INTERNAL_SEQ_SIZE_168 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_167 167 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_168(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_167(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_168(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_167(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_168(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_167(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_168(_) OF_PP_INTERNAL_SEQ_SIZE_169 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_168 168 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_169(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_168(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_169(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_168(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_169(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_168(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_169(_) OF_PP_INTERNAL_SEQ_SIZE_170 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_169 169 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_170(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_169(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_170(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_169(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_170(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_169(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_170(_) OF_PP_INTERNAL_SEQ_SIZE_171 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_170 170 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_171(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_170(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_171(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_170(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_171(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_170(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_171(_) OF_PP_INTERNAL_SEQ_SIZE_172 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_171 171 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_172(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_171(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_172(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_171(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_172(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_171(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_172(_) OF_PP_INTERNAL_SEQ_SIZE_173 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_172 172 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_173(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_172(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_173(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_172(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_173(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_172(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_173(_) OF_PP_INTERNAL_SEQ_SIZE_174 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_173 173 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_174(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_173(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_174(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_173(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_174(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_173(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_174(_) OF_PP_INTERNAL_SEQ_SIZE_175 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_174 174 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_175(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_174(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_175(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_174(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_175(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_174(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_175(_) OF_PP_INTERNAL_SEQ_SIZE_176 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_175 175 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_176(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_175(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_176(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_175(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_176(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_175(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_176(_) OF_PP_INTERNAL_SEQ_SIZE_177 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_176 176 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_177(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_176(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_177(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_176(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_177(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_176(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_177(_) OF_PP_INTERNAL_SEQ_SIZE_178 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_177 177 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_178(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_177(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_178(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_177(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_178(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_177(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_178(_) OF_PP_INTERNAL_SEQ_SIZE_179 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_178 178 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_179(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_178(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_179(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_178(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_179(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_178(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_179(_) OF_PP_INTERNAL_SEQ_SIZE_180 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_179 179 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_180(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_179(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_180(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_179(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_180(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_179(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_180(_) OF_PP_INTERNAL_SEQ_SIZE_181 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_180 180 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_181(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_180(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_181(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_180(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_181(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_180(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_181(_) OF_PP_INTERNAL_SEQ_SIZE_182 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_181 181 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_182(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_181(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_182(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_181(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_182(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_181(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_182(_) OF_PP_INTERNAL_SEQ_SIZE_183 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_182 182 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_183(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_182(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_183(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_182(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_183(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_182(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_183(_) OF_PP_INTERNAL_SEQ_SIZE_184 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_183 183 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_184(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_183(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_184(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_183(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_184(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_183(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_184(_) OF_PP_INTERNAL_SEQ_SIZE_185 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_184 184 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_185(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_184(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_185(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_184(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_185(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_184(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_185(_) OF_PP_INTERNAL_SEQ_SIZE_186 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_185 185 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_186(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_185(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_186(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_185(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_186(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_185(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_186(_) 
OF_PP_INTERNAL_SEQ_SIZE_187 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_186 186 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_187(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_186(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_187(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_186(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_187(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_186(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_187(_) OF_PP_INTERNAL_SEQ_SIZE_188 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_187 187 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_188(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_187(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_188(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_187(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_188(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_187(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_188(_) OF_PP_INTERNAL_SEQ_SIZE_189 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_188 188 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_189(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_188(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_189(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_188(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_189(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_188(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_189(_) OF_PP_INTERNAL_SEQ_SIZE_190 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_189 189 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_190(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_189(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_190(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_189(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_190(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_189(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_190(_) OF_PP_INTERNAL_SEQ_SIZE_191 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_190 190 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_191(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_190(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_191(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_190(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_191(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_190(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_191(_) OF_PP_INTERNAL_SEQ_SIZE_192 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_191 191 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_192(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_191(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_192(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_191(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_192(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_191(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_192(_) OF_PP_INTERNAL_SEQ_SIZE_193 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_192 192 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_193(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_192(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_193(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_192(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_193(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_192(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_193(_) OF_PP_INTERNAL_SEQ_SIZE_194 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_193 193 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_194(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_193(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_194(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_193(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_194(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_193(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_194(_) OF_PP_INTERNAL_SEQ_SIZE_195 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_194 194 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_195(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_194(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_195(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_194(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_195(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_194(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_195(_) OF_PP_INTERNAL_SEQ_SIZE_196 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_195 195 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_196(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_195(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_196(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_195(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_196(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_195(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_196(_) OF_PP_INTERNAL_SEQ_SIZE_197 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_196 196 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_197(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_196(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_197(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_196(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_197(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_196(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_197(_) OF_PP_INTERNAL_SEQ_SIZE_198 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_197 197 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_198(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_197(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_198(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_197(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_198(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_197(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_198(_) OF_PP_INTERNAL_SEQ_SIZE_199 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_198 198 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_199(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_198(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_199(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_198(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_199(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_198(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_199(_) OF_PP_INTERNAL_SEQ_SIZE_200 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_199 199 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_200(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_199(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_200(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_199(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_200(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_199(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_200(_) OF_PP_INTERNAL_SEQ_SIZE_201 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_200 200 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_201(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_200(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_201(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_200(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_201(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_200(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_201(_) OF_PP_INTERNAL_SEQ_SIZE_202 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_201 201 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_202(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_201(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_202(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_201(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_202(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_201(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_202(_) OF_PP_INTERNAL_SEQ_SIZE_203 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_202 202 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_203(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_202(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_203(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_202(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_203(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_202(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_203(_) OF_PP_INTERNAL_SEQ_SIZE_204 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_203 203 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_204(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_203(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_204(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_203(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_204(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_203(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_204(_) OF_PP_INTERNAL_SEQ_SIZE_205 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_204 204 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_205(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_204(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_205(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_204(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_205(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_204(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_205(_) OF_PP_INTERNAL_SEQ_SIZE_206 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_205 205 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_206(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_205(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_206(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_205(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_206(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_205(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_206(_) OF_PP_INTERNAL_SEQ_SIZE_207 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_206 206 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_207(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_206(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_207(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_206(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_207(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_206(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_207(_) OF_PP_INTERNAL_SEQ_SIZE_208 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_207 207 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_208(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_207(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_208(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_207(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_208(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_207(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_208(_) OF_PP_INTERNAL_SEQ_SIZE_209 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_208 208 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_209(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_208(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_209(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_208(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_209(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_208(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_209(_) OF_PP_INTERNAL_SEQ_SIZE_210 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_209 209 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_210(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_209(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_210(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_209(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_210(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_209(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_210(_) OF_PP_INTERNAL_SEQ_SIZE_211 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_210 210 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_211(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_210(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_211(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_210(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_211(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_210(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_211(_) OF_PP_INTERNAL_SEQ_SIZE_212 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_211 211 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_212(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_211(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_212(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_211(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_212(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_211(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_212(_) OF_PP_INTERNAL_SEQ_SIZE_213 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_212 212 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_213(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_212(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_213(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_212(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_213(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_212(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_213(_) OF_PP_INTERNAL_SEQ_SIZE_214 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_213 213 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_214(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_213(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_214(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_213(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_214(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_213(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_214(_) OF_PP_INTERNAL_SEQ_SIZE_215 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_214 214 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_215(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_214(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_215(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_214(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_215(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_214(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_215(_) OF_PP_INTERNAL_SEQ_SIZE_216 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_215 215 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_216(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_215(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_216(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_215(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_216(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_215(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_216(_) OF_PP_INTERNAL_SEQ_SIZE_217 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_216 216 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_217(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_216(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_217(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_216(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_217(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_216(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_217(_) OF_PP_INTERNAL_SEQ_SIZE_218 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_217 217 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_218(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_217(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_218(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_217(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_218(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_217(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_218(_) OF_PP_INTERNAL_SEQ_SIZE_219 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_218 218 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_219(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_218(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_219(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_218(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_219(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_218(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_219(_) OF_PP_INTERNAL_SEQ_SIZE_220 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_219 219 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_220(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_219(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_220(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_219(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_220(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_219(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_220(_) OF_PP_INTERNAL_SEQ_SIZE_221 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_220 220 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_221(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_220(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_221(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_220(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_221(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_220(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_221(_) OF_PP_INTERNAL_SEQ_SIZE_222 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_221 221 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_222(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_221(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_222(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_221(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_222(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_221(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_222(_) 
OF_PP_INTERNAL_SEQ_SIZE_223 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_222 222 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_223(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_222(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_223(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_222(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_223(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_222(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_223(_) OF_PP_INTERNAL_SEQ_SIZE_224 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_223 223 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_224(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_223(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_224(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_223(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_224(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_223(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_224(_) OF_PP_INTERNAL_SEQ_SIZE_225 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_224 224 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_225(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_224(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_225(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_224(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_225(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_224(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_225(_) OF_PP_INTERNAL_SEQ_SIZE_226 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_225 225 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_226(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_225(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_226(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_225(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_226(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_225(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_226(_) OF_PP_INTERNAL_SEQ_SIZE_227 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_226 226 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_227(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_226(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_227(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_226(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_227(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_226(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_227(_) OF_PP_INTERNAL_SEQ_SIZE_228 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_227 227 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_228(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_227(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_228(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_227(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_228(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_227(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_228(_) OF_PP_INTERNAL_SEQ_SIZE_229 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_228 228 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_229(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_228(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_229(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_228(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_229(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_228(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_229(_) OF_PP_INTERNAL_SEQ_SIZE_230 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_229 229 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_230(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_229(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_230(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_229(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_230(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_229(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_230(_) OF_PP_INTERNAL_SEQ_SIZE_231 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_230 230 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_231(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_230(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_231(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_230(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_231(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_230(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_231(_) OF_PP_INTERNAL_SEQ_SIZE_232 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_231 231 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_232(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_231(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_232(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_231(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_232(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_231(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_232(_) OF_PP_INTERNAL_SEQ_SIZE_233 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_232 232 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_233(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_232(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_233(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_232(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_233(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_232(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_233(_) OF_PP_INTERNAL_SEQ_SIZE_234 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_233 233 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_234(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_233(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_234(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_233(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_234(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_233(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_234(_) OF_PP_INTERNAL_SEQ_SIZE_235 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_234 234 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_235(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_234(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_235(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_234(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_235(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_234(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_235(_) OF_PP_INTERNAL_SEQ_SIZE_236 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_235 235 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_236(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_235(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_236(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_235(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_236(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_235(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_236(_) OF_PP_INTERNAL_SEQ_SIZE_237 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_236 236 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_237(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_236(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_237(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_236(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_237(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_236(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_237(_) OF_PP_INTERNAL_SEQ_SIZE_238 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_237 237 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_238(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_237(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_238(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_237(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_238(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_237(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_238(_) OF_PP_INTERNAL_SEQ_SIZE_239 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_238 238 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_239(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_238(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_239(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_238(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_239(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_238(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_239(_) OF_PP_INTERNAL_SEQ_SIZE_240 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_239 239 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_240(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_239(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_240(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_239(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_240(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_239(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_240(_) OF_PP_INTERNAL_SEQ_SIZE_241 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_240 240 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_241(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_240(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_241(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_240(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_241(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_240(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_241(_) OF_PP_INTERNAL_SEQ_SIZE_242 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_241 241 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_242(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_241(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_242(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_241(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_242(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_241(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_242(_) OF_PP_INTERNAL_SEQ_SIZE_243 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_242 242 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_243(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_242(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_243(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_242(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_243(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_242(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_243(_) OF_PP_INTERNAL_SEQ_SIZE_244 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_243 243 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_244(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_243(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_244(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_243(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_244(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_243(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_244(_) OF_PP_INTERNAL_SEQ_SIZE_245 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_244 244 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_245(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_244(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_245(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_244(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_245(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_244(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_245(_) OF_PP_INTERNAL_SEQ_SIZE_246 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_245 245 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_246(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_245(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_246(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_245(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_246(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_245(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_246(_) OF_PP_INTERNAL_SEQ_SIZE_247 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_246 246 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_247(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_246(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_247(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_246(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_247(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_246(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_247(_) OF_PP_INTERNAL_SEQ_SIZE_248 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_247 247 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_248(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_247(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_248(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_247(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_248(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_247(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_248(_) OF_PP_INTERNAL_SEQ_SIZE_249 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_248 248 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_249(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_248(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_249(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_248(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_249(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_248(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_249(_) OF_PP_INTERNAL_SEQ_SIZE_250 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_249 249 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_250(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_249(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_250(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_249(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_250(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_249(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_250(_) OF_PP_INTERNAL_SEQ_SIZE_251 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_250 250 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_251(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_250(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_251(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_250(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_251(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_250(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_251(_) OF_PP_INTERNAL_SEQ_SIZE_252 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_251 251 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_252(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_251(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_252(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_251(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_252(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_251(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_252(_) OF_PP_INTERNAL_SEQ_SIZE_253 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_252 252 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_253(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_252(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_253(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_252(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_253(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_252(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_253(_) OF_PP_INTERNAL_SEQ_SIZE_254 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_253 253 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_254(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_253(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_254(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_253(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_254(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_253(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_254(_) OF_PP_INTERNAL_SEQ_SIZE_255 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_254 254 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_255(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_254(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_255(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_254(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_255(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_254(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_255(_) OF_PP_INTERNAL_SEQ_SIZE_256 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_255 255 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_256(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_255(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_256(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_255(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_256(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_255(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_256(_) OF_PP_INTERNAL_SEQ_SIZE_257 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_256 256 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_257(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_256(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_257(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_256(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_257(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_256(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_257(_) OF_PP_INTERNAL_SEQ_SIZE_258 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_257 257 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_258(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_257(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_258(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_257(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_258(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_257(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_258(_) 
OF_PP_INTERNAL_SEQ_SIZE_259 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_258 258 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_259(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_258(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_259(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_258(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_259(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_258(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_259(_) OF_PP_INTERNAL_SEQ_SIZE_260 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_259 259 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_260(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_259(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_260(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_259(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_260(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_259(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_260(_) OF_PP_INTERNAL_SEQ_SIZE_261 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_260 260 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_261(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_260(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_261(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_260(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_261(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_260(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_261(_) OF_PP_INTERNAL_SEQ_SIZE_262 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_261 261 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_262(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_261(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_262(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_261(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_262(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_261(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_262(_) OF_PP_INTERNAL_SEQ_SIZE_263 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_262 262 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_263(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_262(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_263(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_262(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_263(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_262(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_263(_) OF_PP_INTERNAL_SEQ_SIZE_264 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_263 263 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_264(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_263(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_264(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_263(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_264(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_263(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_264(_) OF_PP_INTERNAL_SEQ_SIZE_265 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_264 264 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_265(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_264(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_265(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_264(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_265(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_264(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_265(_) OF_PP_INTERNAL_SEQ_SIZE_266 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_265 265 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_266(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_265(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_266(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_265(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_266(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_265(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_266(_) OF_PP_INTERNAL_SEQ_SIZE_267 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_266 266 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_267(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_266(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_267(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_266(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_267(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_266(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_267(_) OF_PP_INTERNAL_SEQ_SIZE_268 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_267 267 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_268(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_267(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_268(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_267(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_268(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_267(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_268(_) OF_PP_INTERNAL_SEQ_SIZE_269 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_268 268 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_269(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_268(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_269(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_268(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_269(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_268(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_269(_) OF_PP_INTERNAL_SEQ_SIZE_270 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_269 269 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_270(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_269(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_270(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_269(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_270(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_269(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_270(_) OF_PP_INTERNAL_SEQ_SIZE_271 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_270 270 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_271(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_270(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_271(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_270(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_271(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_270(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_271(_) OF_PP_INTERNAL_SEQ_SIZE_272 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_271 271 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_272(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_271(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_272(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_271(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_272(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_271(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_272(_) OF_PP_INTERNAL_SEQ_SIZE_273 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_272 272 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_273(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_272(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_273(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_272(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_273(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_272(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_273(_) OF_PP_INTERNAL_SEQ_SIZE_274 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_273 273 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_274(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_273(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_274(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_273(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_274(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_273(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_274(_) OF_PP_INTERNAL_SEQ_SIZE_275 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_274 274 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_275(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_274(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_275(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_274(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_275(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_274(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_275(_) OF_PP_INTERNAL_SEQ_SIZE_276 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_275 275 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_276(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_275(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_276(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_275(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_276(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_275(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_276(_) OF_PP_INTERNAL_SEQ_SIZE_277 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_276 276 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_277(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_276(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_277(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_276(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_277(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_276(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_277(_) OF_PP_INTERNAL_SEQ_SIZE_278 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_277 277 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_278(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_277(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_278(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_277(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_278(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_277(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_278(_) OF_PP_INTERNAL_SEQ_SIZE_279 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_278 278 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_279(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_278(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_279(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_278(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_279(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_278(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_279(_) OF_PP_INTERNAL_SEQ_SIZE_280 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_279 279 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_280(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_279(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_280(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_279(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_280(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_279(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_280(_) OF_PP_INTERNAL_SEQ_SIZE_281 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_280 280 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_281(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_280(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_281(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_280(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_281(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_280(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_281(_) OF_PP_INTERNAL_SEQ_SIZE_282 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_281 281 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_282(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_281(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_282(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_281(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_282(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_281(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_282(_) OF_PP_INTERNAL_SEQ_SIZE_283 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_282 282 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_283(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_282(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_283(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_282(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_283(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_282(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_283(_) OF_PP_INTERNAL_SEQ_SIZE_284 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_283 283 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_284(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_283(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_284(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_283(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_284(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_283(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_284(_) OF_PP_INTERNAL_SEQ_SIZE_285 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_284 284 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_285(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_284(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_285(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_284(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_285(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_284(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_285(_) OF_PP_INTERNAL_SEQ_SIZE_286 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_285 285 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_286(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_285(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_286(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_285(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_286(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_285(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_286(_) OF_PP_INTERNAL_SEQ_SIZE_287 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_286 286 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_287(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_286(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_287(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_286(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_287(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_286(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_287(_) OF_PP_INTERNAL_SEQ_SIZE_288 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_287 287 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_288(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_287(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_288(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_287(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_288(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_287(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_288(_) OF_PP_INTERNAL_SEQ_SIZE_289 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_288 288 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_289(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_288(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_289(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_288(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_289(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_288(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_289(_) OF_PP_INTERNAL_SEQ_SIZE_290 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_289 289 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_290(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_289(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_290(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_289(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_290(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_289(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_290(_) OF_PP_INTERNAL_SEQ_SIZE_291 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_290 290 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_291(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_290(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_291(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_290(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_291(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_290(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_291(_) OF_PP_INTERNAL_SEQ_SIZE_292 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_291 291 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_292(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_291(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_292(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_291(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_292(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_291(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_292(_) OF_PP_INTERNAL_SEQ_SIZE_293 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_292 292 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_293(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_292(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_293(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_292(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_293(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_292(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_293(_) OF_PP_INTERNAL_SEQ_SIZE_294 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_293 293 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_294(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_293(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_294(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_293(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_294(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_293(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_294(_) 
OF_PP_INTERNAL_SEQ_SIZE_295 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_294 294 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_295(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_294(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_295(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_294(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_295(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_294(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_295(_) OF_PP_INTERNAL_SEQ_SIZE_296 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_295 295 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_296(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_295(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_296(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_295(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_296(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_295(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_296(_) OF_PP_INTERNAL_SEQ_SIZE_297 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_296 296 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_297(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_296(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_297(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_296(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_297(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_296(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_297(_) OF_PP_INTERNAL_SEQ_SIZE_298 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_297 297 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_298(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_297(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_298(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_297(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_298(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_297(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_298(_) OF_PP_INTERNAL_SEQ_SIZE_299 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_298 298 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_299(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_298(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_299(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_298(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_299(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_298(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_299(_) OF_PP_INTERNAL_SEQ_SIZE_300 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_299 299 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_300(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_299(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_300(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_299(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_300(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_299(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_300(_) OF_PP_INTERNAL_SEQ_SIZE_301 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_300 300 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_301(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_300(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_301(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_300(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_301(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_300(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_301(_) OF_PP_INTERNAL_SEQ_SIZE_302 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_301 301 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_302(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_301(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_302(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_301(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_302(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_301(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_302(_) OF_PP_INTERNAL_SEQ_SIZE_303 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_302 302 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_303(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_302(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_303(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_302(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_303(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_302(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_303(_) OF_PP_INTERNAL_SEQ_SIZE_304 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_303 303 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_304(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_303(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_304(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_303(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_304(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_303(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_304(_) OF_PP_INTERNAL_SEQ_SIZE_305 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_304 304 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_305(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_304(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_305(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_304(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_305(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_304(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_305(_) OF_PP_INTERNAL_SEQ_SIZE_306 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_305 305 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_306(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_305(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_306(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_305(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_306(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_305(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_306(_) OF_PP_INTERNAL_SEQ_SIZE_307 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_306 306 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_307(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_306(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_307(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_306(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_307(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_306(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_307(_) OF_PP_INTERNAL_SEQ_SIZE_308 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_307 307 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_308(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_307(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_308(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_307(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_308(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_307(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_308(_) OF_PP_INTERNAL_SEQ_SIZE_309 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_308 308 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_309(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_308(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_309(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_308(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_309(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_308(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_309(_) OF_PP_INTERNAL_SEQ_SIZE_310 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_309 309 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_310(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_309(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_310(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_309(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_310(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_309(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_310(_) OF_PP_INTERNAL_SEQ_SIZE_311 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_310 310 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_311(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_310(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_311(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_310(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_311(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_310(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_311(_) OF_PP_INTERNAL_SEQ_SIZE_312 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_311 311 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_312(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_311(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_312(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_311(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_312(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_311(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_312(_) OF_PP_INTERNAL_SEQ_SIZE_313 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_312 312 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_313(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_312(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_313(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_312(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_313(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_312(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_313(_) OF_PP_INTERNAL_SEQ_SIZE_314 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_313 313 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_314(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_313(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_314(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_313(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_314(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_313(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_314(_) OF_PP_INTERNAL_SEQ_SIZE_315 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_314 314 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_315(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_314(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_315(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_314(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_315(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_314(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_315(_) OF_PP_INTERNAL_SEQ_SIZE_316 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_315 315 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_316(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_315(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_316(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_315(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_316(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_315(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_316(_) OF_PP_INTERNAL_SEQ_SIZE_317 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_316 316 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_317(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_316(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_317(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_316(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_317(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_316(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_317(_) OF_PP_INTERNAL_SEQ_SIZE_318 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_317 317 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_318(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_317(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_318(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_317(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_318(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_317(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_318(_) OF_PP_INTERNAL_SEQ_SIZE_319 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_318 318 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_319(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_318(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_319(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_318(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_319(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_318(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_319(_) OF_PP_INTERNAL_SEQ_SIZE_320 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_319 319 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_320(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_319(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_320(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_319(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_320(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_319(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_320(_) OF_PP_INTERNAL_SEQ_SIZE_321 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_320 320 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_321(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_320(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_321(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_320(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_321(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_320(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_321(_) OF_PP_INTERNAL_SEQ_SIZE_322 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_321 321 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_322(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_321(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_322(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_321(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_322(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_321(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_322(_) OF_PP_INTERNAL_SEQ_SIZE_323 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_322 322 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_323(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_322(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_323(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_322(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_323(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_322(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_323(_) OF_PP_INTERNAL_SEQ_SIZE_324 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_323 323 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_324(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_323(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_324(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_323(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_324(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_323(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_324(_) OF_PP_INTERNAL_SEQ_SIZE_325 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_324 324 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_325(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_324(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_325(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_324(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_325(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_324(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_325(_) OF_PP_INTERNAL_SEQ_SIZE_326 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_325 325 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_326(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_325(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_326(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_325(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_326(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_325(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_326(_) OF_PP_INTERNAL_SEQ_SIZE_327 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_326 326 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_327(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_326(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_327(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_326(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_327(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_326(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_327(_) OF_PP_INTERNAL_SEQ_SIZE_328 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_327 327 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_328(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_327(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_328(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_327(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_328(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_327(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_328(_) OF_PP_INTERNAL_SEQ_SIZE_329 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_328 328 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_329(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_328(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_329(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_328(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_329(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_328(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_329(_) OF_PP_INTERNAL_SEQ_SIZE_330 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_329 329 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_330(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_329(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_330(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_329(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_330(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_329(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_330(_) 
OF_PP_INTERNAL_SEQ_SIZE_331 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_330 330 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_331(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_330(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_331(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_330(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_331(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_330(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_331(_) OF_PP_INTERNAL_SEQ_SIZE_332 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_331 331 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_332(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_331(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_332(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_331(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_332(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_331(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_332(_) OF_PP_INTERNAL_SEQ_SIZE_333 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_332 332 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_333(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_332(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_333(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_332(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_333(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_332(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_333(_) OF_PP_INTERNAL_SEQ_SIZE_334 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_333 333 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_334(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_333(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_334(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_333(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_334(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_333(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_334(_) OF_PP_INTERNAL_SEQ_SIZE_335 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_334 334 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_335(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_334(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_335(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_334(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_335(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_334(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_335(_) OF_PP_INTERNAL_SEQ_SIZE_336 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_335 335 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_336(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_335(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_336(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_335(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_336(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_335(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_336(_) OF_PP_INTERNAL_SEQ_SIZE_337 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_336 336 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_337(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_336(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_337(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_336(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_337(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_336(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_337(_) OF_PP_INTERNAL_SEQ_SIZE_338 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_337 337 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_338(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_337(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_338(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_337(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_338(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_337(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_338(_) OF_PP_INTERNAL_SEQ_SIZE_339 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_338 338 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_339(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_338(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_339(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_338(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_339(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_338(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_339(_) OF_PP_INTERNAL_SEQ_SIZE_340 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_339 339 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_340(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_339(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_340(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_339(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_340(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_339(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_340(_) OF_PP_INTERNAL_SEQ_SIZE_341 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_340 340 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_341(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_340(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_341(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_340(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_341(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_340(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_341(_) OF_PP_INTERNAL_SEQ_SIZE_342 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_341 341 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_342(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_341(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_342(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_341(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_342(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_341(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_342(_) OF_PP_INTERNAL_SEQ_SIZE_343 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_342 342 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_343(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_342(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_343(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_342(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_343(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_342(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_343(_) OF_PP_INTERNAL_SEQ_SIZE_344 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_343 343 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_344(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_343(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_344(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_343(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_344(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_343(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_344(_) OF_PP_INTERNAL_SEQ_SIZE_345 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_344 344 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_345(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_344(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_345(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_344(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_345(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_344(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_345(_) OF_PP_INTERNAL_SEQ_SIZE_346 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_345 345 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_346(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_345(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_346(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_345(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_346(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_345(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_346(_) OF_PP_INTERNAL_SEQ_SIZE_347 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_346 346 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_347(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_346(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_347(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_346(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_347(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_346(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_347(_) OF_PP_INTERNAL_SEQ_SIZE_348 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_347 347 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_348(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_347(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_348(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_347(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_348(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_347(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_348(_) OF_PP_INTERNAL_SEQ_SIZE_349 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_348 348 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_349(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_348(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_349(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_348(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_349(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_348(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_349(_) OF_PP_INTERNAL_SEQ_SIZE_350 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_349 349 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_350(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_349(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_350(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_349(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_350(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_349(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_350(_) OF_PP_INTERNAL_SEQ_SIZE_351 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_350 350 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_351(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_350(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_351(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_350(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_351(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_350(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_351(_) OF_PP_INTERNAL_SEQ_SIZE_352 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_351 351 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_352(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_351(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_352(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_351(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_352(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_351(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_352(_) OF_PP_INTERNAL_SEQ_SIZE_353 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_352 352 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_353(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_352(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_353(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_352(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_353(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_352(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_353(_) OF_PP_INTERNAL_SEQ_SIZE_354 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_353 353 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_354(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_353(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_354(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_353(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_354(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_353(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_354(_) OF_PP_INTERNAL_SEQ_SIZE_355 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_354 354 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_355(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_354(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_355(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_354(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_355(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_354(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_355(_) OF_PP_INTERNAL_SEQ_SIZE_356 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_355 355 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_356(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_355(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_356(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_355(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_356(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_355(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_356(_) OF_PP_INTERNAL_SEQ_SIZE_357 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_356 356 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_357(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_356(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_357(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_356(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_357(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_356(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_357(_) OF_PP_INTERNAL_SEQ_SIZE_358 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_357 357 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_358(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_357(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_358(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_357(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_358(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_357(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_358(_) OF_PP_INTERNAL_SEQ_SIZE_359 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_358 358 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_359(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_358(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_359(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_358(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_359(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_358(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_359(_) OF_PP_INTERNAL_SEQ_SIZE_360 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_359 359 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_360(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_359(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_360(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_359(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_360(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_359(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_360(_) OF_PP_INTERNAL_SEQ_SIZE_361 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_360 360 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_361(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_360(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_361(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_360(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_361(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_360(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_361(_) OF_PP_INTERNAL_SEQ_SIZE_362 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_361 361 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_362(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_361(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_362(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_361(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_362(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_361(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_362(_) OF_PP_INTERNAL_SEQ_SIZE_363 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_362 362 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_363(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_362(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_363(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_362(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_363(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_362(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_363(_) OF_PP_INTERNAL_SEQ_SIZE_364 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_363 363 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_364(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_363(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_364(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_363(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_364(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_363(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_364(_) OF_PP_INTERNAL_SEQ_SIZE_365 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_364 364 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_365(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_364(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_365(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_364(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_365(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_364(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_365(_) OF_PP_INTERNAL_SEQ_SIZE_366 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_365 365 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_366(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_365(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_366(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_365(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_366(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_365(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_366(_) 
OF_PP_INTERNAL_SEQ_SIZE_367 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_366 366 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_367(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_366(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_367(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_366(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_367(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_366(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_367(_) OF_PP_INTERNAL_SEQ_SIZE_368 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_367 367 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_368(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_367(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_368(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_367(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_368(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_367(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_368(_) OF_PP_INTERNAL_SEQ_SIZE_369 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_368 368 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_369(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_368(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_369(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_368(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_369(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_368(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_369(_) OF_PP_INTERNAL_SEQ_SIZE_370 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_369 369 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_370(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_369(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_370(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_369(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_370(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_369(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_370(_) OF_PP_INTERNAL_SEQ_SIZE_371 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_370 370 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_371(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_370(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_371(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_370(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_371(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_370(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_371(_) OF_PP_INTERNAL_SEQ_SIZE_372 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_371 371 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_372(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_371(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_372(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_371(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_372(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_371(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_372(_) OF_PP_INTERNAL_SEQ_SIZE_373 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_372 372 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_373(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_372(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_373(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_372(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_373(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_372(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_373(_) OF_PP_INTERNAL_SEQ_SIZE_374 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_373 373 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_374(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_373(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_374(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_373(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_374(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_373(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_374(_) OF_PP_INTERNAL_SEQ_SIZE_375 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_374 374 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_375(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_374(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_375(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_374(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_375(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_374(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_375(_) OF_PP_INTERNAL_SEQ_SIZE_376 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_375 375 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_376(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_375(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_376(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_375(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_376(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_375(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_376(_) OF_PP_INTERNAL_SEQ_SIZE_377 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_376 376 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_377(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_376(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_377(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_376(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_377(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_376(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_377(_) OF_PP_INTERNAL_SEQ_SIZE_378 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_377 377 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_378(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_377(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_378(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_377(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_378(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_377(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_378(_) OF_PP_INTERNAL_SEQ_SIZE_379 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_378 378 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_379(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_378(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_379(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_378(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_379(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_378(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_379(_) OF_PP_INTERNAL_SEQ_SIZE_380 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_379 379 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_380(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_379(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_380(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_379(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_380(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_379(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_380(_) OF_PP_INTERNAL_SEQ_SIZE_381 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_380 380 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_381(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_380(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_381(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_380(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_381(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_380(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_381(_) OF_PP_INTERNAL_SEQ_SIZE_382 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_381 381 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_382(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_381(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_382(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_381(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_382(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_381(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_382(_) OF_PP_INTERNAL_SEQ_SIZE_383 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_382 382 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_383(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_382(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_383(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_382(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_383(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_382(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_383(_) OF_PP_INTERNAL_SEQ_SIZE_384 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_383 383 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_384(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_383(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_384(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_383(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_384(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_383(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_384(_) OF_PP_INTERNAL_SEQ_SIZE_385 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_384 384 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_385(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_384(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_385(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_384(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_385(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_384(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_385(_) OF_PP_INTERNAL_SEQ_SIZE_386 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_385 385 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_386(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_385(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_386(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_385(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_386(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_385(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_386(_) OF_PP_INTERNAL_SEQ_SIZE_387 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_386 386 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_387(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_386(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_387(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_386(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_387(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_386(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_387(_) OF_PP_INTERNAL_SEQ_SIZE_388 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_387 387 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_388(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_387(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_388(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_387(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_388(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_387(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_388(_) OF_PP_INTERNAL_SEQ_SIZE_389 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_388 388 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_389(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_388(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_389(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_388(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_389(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_388(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_389(_) OF_PP_INTERNAL_SEQ_SIZE_390 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_389 389 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_390(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_389(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_390(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_389(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_390(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_389(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_390(_) OF_PP_INTERNAL_SEQ_SIZE_391 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_390 390 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_391(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_390(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_391(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_390(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_391(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_390(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_391(_) OF_PP_INTERNAL_SEQ_SIZE_392 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_391 391 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_392(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_391(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_392(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_391(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_392(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_391(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_392(_) OF_PP_INTERNAL_SEQ_SIZE_393 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_392 392 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_393(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_392(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_393(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_392(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_393(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_392(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_393(_) OF_PP_INTERNAL_SEQ_SIZE_394 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_393 393 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_394(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_393(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_394(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_393(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_394(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_393(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_394(_) OF_PP_INTERNAL_SEQ_SIZE_395 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_394 394 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_395(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_394(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_395(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_394(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_395(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_394(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_395(_) OF_PP_INTERNAL_SEQ_SIZE_396 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_395 395 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_396(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_395(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_396(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_395(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_396(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_395(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_396(_) OF_PP_INTERNAL_SEQ_SIZE_397 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_396 396 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_397(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_396(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_397(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_396(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_397(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_396(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_397(_) OF_PP_INTERNAL_SEQ_SIZE_398 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_397 397 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_398(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_397(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_398(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_397(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_398(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_397(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_398(_) OF_PP_INTERNAL_SEQ_SIZE_399 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_398 398 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_399(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_398(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_399(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_398(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_399(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_398(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_399(_) OF_PP_INTERNAL_SEQ_SIZE_400 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_399 399 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_400(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_399(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_400(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_399(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_400(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_399(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_400(_) OF_PP_INTERNAL_SEQ_SIZE_401 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_400 400 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_401(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_400(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_401(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_400(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_401(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_400(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_401(_) OF_PP_INTERNAL_SEQ_SIZE_402 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_401 401 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_402(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_401(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_402(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_401(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_402(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_401(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_402(_) 
OF_PP_INTERNAL_SEQ_SIZE_403 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_402 402 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_403(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_402(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_403(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_402(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_403(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_402(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_403(_) OF_PP_INTERNAL_SEQ_SIZE_404 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_403 403 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_404(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_403(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_404(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_403(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_404(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_403(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_404(_) OF_PP_INTERNAL_SEQ_SIZE_405 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_404 404 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_405(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_404(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_405(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_404(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_405(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_404(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_405(_) OF_PP_INTERNAL_SEQ_SIZE_406 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_405 405 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_406(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_405(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_406(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_405(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_406(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_405(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_406(_) OF_PP_INTERNAL_SEQ_SIZE_407 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_406 406 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_407(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_406(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_407(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_406(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_407(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_406(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_407(_) OF_PP_INTERNAL_SEQ_SIZE_408 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_407 407 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_408(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_407(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_408(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_407(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_408(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_407(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_408(_) OF_PP_INTERNAL_SEQ_SIZE_409 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_408 408 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_409(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_408(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_409(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_408(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_409(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_408(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_409(_) OF_PP_INTERNAL_SEQ_SIZE_410 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_409 409 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_410(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_409(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_410(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_409(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_410(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_409(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_410(_) OF_PP_INTERNAL_SEQ_SIZE_411 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_410 410 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_411(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_410(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_411(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_410(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_411(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_410(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_411(_) OF_PP_INTERNAL_SEQ_SIZE_412 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_411 411 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_412(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_411(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_412(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_411(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_412(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_411(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_412(_) OF_PP_INTERNAL_SEQ_SIZE_413 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_412 412 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_413(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_412(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_413(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_412(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_413(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_412(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_413(_) OF_PP_INTERNAL_SEQ_SIZE_414 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_413 413 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_414(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_413(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_414(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_413(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_414(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_413(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_414(_) OF_PP_INTERNAL_SEQ_SIZE_415 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_414 414 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_415(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_414(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_415(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_414(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_415(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_414(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_415(_) OF_PP_INTERNAL_SEQ_SIZE_416 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_415 415 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_416(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_415(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_416(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_415(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_416(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_415(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_416(_) OF_PP_INTERNAL_SEQ_SIZE_417 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_416 416 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_417(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_416(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_417(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_416(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_417(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_416(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_417(_) OF_PP_INTERNAL_SEQ_SIZE_418 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_417 417 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_418(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_417(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_418(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_417(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_418(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_417(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_418(_) OF_PP_INTERNAL_SEQ_SIZE_419 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_418 418 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_419(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_418(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_419(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_418(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_419(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_418(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_419(_) OF_PP_INTERNAL_SEQ_SIZE_420 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_419 419 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_420(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_419(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_420(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_419(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_420(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_419(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_420(_) OF_PP_INTERNAL_SEQ_SIZE_421 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_420 420 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_421(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_420(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_421(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_420(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_421(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_420(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_421(_) OF_PP_INTERNAL_SEQ_SIZE_422 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_421 421 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_422(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_421(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_422(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_421(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_422(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_421(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_422(_) OF_PP_INTERNAL_SEQ_SIZE_423 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_422 422 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_423(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_422(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_423(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_422(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_423(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_422(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_423(_) OF_PP_INTERNAL_SEQ_SIZE_424 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_423 423 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_424(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_423(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_424(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_423(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_424(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_423(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_424(_) OF_PP_INTERNAL_SEQ_SIZE_425 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_424 424 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_425(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_424(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_425(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_424(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_425(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_424(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_425(_) OF_PP_INTERNAL_SEQ_SIZE_426 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_425 425 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_426(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_425(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_426(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_425(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_426(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_425(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_426(_) OF_PP_INTERNAL_SEQ_SIZE_427 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_426 426 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_427(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_426(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_427(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_426(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_427(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_426(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_427(_) OF_PP_INTERNAL_SEQ_SIZE_428 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_427 427 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_428(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_427(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_428(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_427(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_428(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_427(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_428(_) OF_PP_INTERNAL_SEQ_SIZE_429 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_428 428 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_429(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_428(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_429(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_428(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_429(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_428(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_429(_) OF_PP_INTERNAL_SEQ_SIZE_430 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_429 429 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_430(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_429(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_430(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_429(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_430(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_429(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_430(_) OF_PP_INTERNAL_SEQ_SIZE_431 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_430 430 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_431(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_430(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_431(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_430(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_431(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_430(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_431(_) OF_PP_INTERNAL_SEQ_SIZE_432 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_431 431 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_432(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_431(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_432(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_431(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_432(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_431(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_432(_) OF_PP_INTERNAL_SEQ_SIZE_433 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_432 432 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_433(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_432(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_433(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_432(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_433(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_432(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_433(_) OF_PP_INTERNAL_SEQ_SIZE_434 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_433 433 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_434(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_433(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_434(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_433(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_434(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_433(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_434(_) OF_PP_INTERNAL_SEQ_SIZE_435 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_434 434 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_435(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_434(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_435(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_434(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_435(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_434(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_435(_) OF_PP_INTERNAL_SEQ_SIZE_436 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_435 435 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_436(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_435(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_436(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_435(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_436(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_435(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_436(_) OF_PP_INTERNAL_SEQ_SIZE_437 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_436 436 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_437(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_436(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_437(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_436(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_437(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_436(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_437(_) OF_PP_INTERNAL_SEQ_SIZE_438 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_437 437 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_438(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_437(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_438(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_437(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_438(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_437(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_438(_) 
OF_PP_INTERNAL_SEQ_SIZE_439 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_438 438 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_439(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_438(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_439(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_438(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_439(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_438(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_439(_) OF_PP_INTERNAL_SEQ_SIZE_440 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_439 439 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_440(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_439(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_440(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_439(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_440(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_439(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_440(_) OF_PP_INTERNAL_SEQ_SIZE_441 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_440 440 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_441(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_440(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_441(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_440(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_441(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_440(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_441(_) OF_PP_INTERNAL_SEQ_SIZE_442 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_441 441 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_442(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_441(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_442(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_441(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_442(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_441(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_442(_) OF_PP_INTERNAL_SEQ_SIZE_443 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_442 442 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_443(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_442(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_443(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_442(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_443(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_442(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_443(_) OF_PP_INTERNAL_SEQ_SIZE_444 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_443 443 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_444(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_443(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_444(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_443(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_444(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_443(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_444(_) OF_PP_INTERNAL_SEQ_SIZE_445 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_444 444 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_445(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_444(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_445(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_444(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_445(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_444(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_445(_) OF_PP_INTERNAL_SEQ_SIZE_446 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_445 445 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_446(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_445(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_446(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_445(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_446(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_445(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_446(_) OF_PP_INTERNAL_SEQ_SIZE_447 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_446 446 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_447(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_446(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_447(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_446(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_447(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_446(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_447(_) OF_PP_INTERNAL_SEQ_SIZE_448 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_447 447 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_448(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_447(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_448(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_447(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_448(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_447(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_448(_) OF_PP_INTERNAL_SEQ_SIZE_449 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_448 448 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_449(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_448(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_449(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_448(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_449(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_448(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_449(_) OF_PP_INTERNAL_SEQ_SIZE_450 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_449 449 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_450(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_449(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_450(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_449(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_450(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_449(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_450(_) OF_PP_INTERNAL_SEQ_SIZE_451 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_450 450 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_451(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_450(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_451(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_450(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_451(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_450(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_451(_) OF_PP_INTERNAL_SEQ_SIZE_452 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_451 451 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_452(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_451(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_452(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_451(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_452(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_451(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_452(_) OF_PP_INTERNAL_SEQ_SIZE_453 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_452 452 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_453(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_452(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_453(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_452(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_453(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_452(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_453(_) OF_PP_INTERNAL_SEQ_SIZE_454 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_453 453 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_454(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_453(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_454(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_453(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_454(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_453(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_454(_) OF_PP_INTERNAL_SEQ_SIZE_455 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_454 454 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_455(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_454(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_455(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_454(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_455(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_454(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_455(_) OF_PP_INTERNAL_SEQ_SIZE_456 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_455 455 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_456(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_455(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_456(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_455(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_456(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_455(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_456(_) OF_PP_INTERNAL_SEQ_SIZE_457 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_456 456 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_457(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_456(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_457(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_456(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_457(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_456(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_457(_) OF_PP_INTERNAL_SEQ_SIZE_458 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_457 457 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_458(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_457(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_458(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_457(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_458(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_457(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_458(_) OF_PP_INTERNAL_SEQ_SIZE_459 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_458 458 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_459(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_458(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_459(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_458(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_459(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_458(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_459(_) OF_PP_INTERNAL_SEQ_SIZE_460 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_459 459 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_460(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_459(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_460(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_459(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_460(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_459(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_460(_) OF_PP_INTERNAL_SEQ_SIZE_461 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_460 460 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_461(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_460(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_461(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_460(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_461(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_460(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_461(_) OF_PP_INTERNAL_SEQ_SIZE_462 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_461 461 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_462(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_461(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_462(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_461(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_462(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_461(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_462(_) OF_PP_INTERNAL_SEQ_SIZE_463 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_462 462 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_463(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_462(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_463(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_462(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_463(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_462(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_463(_) OF_PP_INTERNAL_SEQ_SIZE_464 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_463 463 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_464(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_463(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_464(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_463(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_464(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_463(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_464(_) OF_PP_INTERNAL_SEQ_SIZE_465 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_464 464 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_465(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_464(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_465(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_464(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_465(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_464(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_465(_) OF_PP_INTERNAL_SEQ_SIZE_466 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_465 465 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_466(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_465(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_466(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_465(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_466(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_465(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_466(_) OF_PP_INTERNAL_SEQ_SIZE_467 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_466 466 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_467(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_466(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_467(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_466(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_467(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_466(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_467(_) OF_PP_INTERNAL_SEQ_SIZE_468 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_467 467 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_468(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_467(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_468(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_467(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_468(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_467(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_468(_) OF_PP_INTERNAL_SEQ_SIZE_469 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_468 468 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_469(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_468(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_469(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_468(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_469(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_468(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_469(_) OF_PP_INTERNAL_SEQ_SIZE_470 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_469 469 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_470(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_469(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_470(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_469(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_470(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_469(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_470(_) OF_PP_INTERNAL_SEQ_SIZE_471 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_470 470 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_471(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_470(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_471(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_470(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_471(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_470(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_471(_) OF_PP_INTERNAL_SEQ_SIZE_472 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_471 471 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_472(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_471(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_472(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_471(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_472(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_471(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_472(_) OF_PP_INTERNAL_SEQ_SIZE_473 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_472 472 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_473(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_472(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_473(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_472(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_473(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_472(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_473(_) OF_PP_INTERNAL_SEQ_SIZE_474 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_473 473 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_474(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_473(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_474(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_473(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_474(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_473(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_474(_) 
OF_PP_INTERNAL_SEQ_SIZE_475 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_474 474 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_475(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_474(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_475(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_474(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_475(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_474(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_475(_) OF_PP_INTERNAL_SEQ_SIZE_476 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_475 475 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_476(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_475(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_476(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_475(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_476(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_475(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_476(_) OF_PP_INTERNAL_SEQ_SIZE_477 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_476 476 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_477(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_476(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_477(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_476(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_477(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_476(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_477(_) OF_PP_INTERNAL_SEQ_SIZE_478 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_477 477 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_478(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_477(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_478(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_477(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_478(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_477(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_478(_) OF_PP_INTERNAL_SEQ_SIZE_479 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_478 478 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_479(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_478(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_479(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_478(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_479(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_478(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_479(_) OF_PP_INTERNAL_SEQ_SIZE_480 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_479 479 #define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_480(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_479(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_480(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_479(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_480(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_479(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_480(_) OF_PP_INTERNAL_SEQ_SIZE_481 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_480 480 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_481(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_480(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_481(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_480(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_481(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_480(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_481(_) OF_PP_INTERNAL_SEQ_SIZE_482 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_481 481 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_482(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_481(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_482(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_481(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_482(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_481(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_482(_) OF_PP_INTERNAL_SEQ_SIZE_483 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_482 482 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_483(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_482(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_483(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_482(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_483(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_482(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_483(_) OF_PP_INTERNAL_SEQ_SIZE_484 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_483 483 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_484(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_483(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_484(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_483(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_484(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_483(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_484(_) OF_PP_INTERNAL_SEQ_SIZE_485 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_484 484 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_485(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_484(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_485(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_484(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_485(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_484(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_485(_) OF_PP_INTERNAL_SEQ_SIZE_486 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_485 485 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_486(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_485(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_486(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_485(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_486(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_485(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_486(_) OF_PP_INTERNAL_SEQ_SIZE_487 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_486 486 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_487(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_486(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_487(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_486(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_487(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_486(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_487(_) OF_PP_INTERNAL_SEQ_SIZE_488 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_487 487 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_488(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_487(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_488(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_487(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_488(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_487(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_488(_) OF_PP_INTERNAL_SEQ_SIZE_489 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_488 488 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_489(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_488(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_489(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_488(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_489(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_488(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_489(_) OF_PP_INTERNAL_SEQ_SIZE_490 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_489 489 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_490(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_489(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_490(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_489(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_490(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_489(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_490(_) OF_PP_INTERNAL_SEQ_SIZE_491 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_490 490 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_491(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_490(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_491(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_490(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_491(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_490(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_491(_) OF_PP_INTERNAL_SEQ_SIZE_492 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_491 491 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_492(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_491(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_492(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_491(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_492(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_491(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_492(_) OF_PP_INTERNAL_SEQ_SIZE_493 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_492 492 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_493(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_492(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_493(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_492(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_493(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_492(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_493(_) OF_PP_INTERNAL_SEQ_SIZE_494 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_493 493 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_494(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_493(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_494(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_493(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_494(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_493(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_494(_) OF_PP_INTERNAL_SEQ_SIZE_495 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_494 494 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_495(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_494(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_495(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_494(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_495(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_494(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_495(_) OF_PP_INTERNAL_SEQ_SIZE_496 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_495 495 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_496(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_495(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_496(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_495(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_496(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_495(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_496(_) OF_PP_INTERNAL_SEQ_SIZE_497 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_496 496 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_497(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_496(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_497(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_496(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_497(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_496(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_497(_) OF_PP_INTERNAL_SEQ_SIZE_498 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_497 497 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_498(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_497(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_498(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_497(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_498(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_497(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_498(_) OF_PP_INTERNAL_SEQ_SIZE_499 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_498 498 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_499(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_498(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_499(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_498(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_499(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_498(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_499(_) OF_PP_INTERNAL_SEQ_SIZE_500 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_499 499 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_500(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_499(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_500(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_499(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_500(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_499(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_500(_) OF_PP_INTERNAL_SEQ_SIZE_501 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_500 500 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_501(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_500(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_501(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_500(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_501(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_500(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_501(_) OF_PP_INTERNAL_SEQ_SIZE_502 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_501 501 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_502(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_501(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_502(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_501(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_502(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_501(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_502(_) OF_PP_INTERNAL_SEQ_SIZE_503 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_502 502 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_503(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_502(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_503(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_502(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_503(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_502(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_503(_) OF_PP_INTERNAL_SEQ_SIZE_504 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_503 503 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_504(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_503(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_504(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_503(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_504(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_503(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_504(_) OF_PP_INTERNAL_SEQ_SIZE_505 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_504 504 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_505(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_504(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_505(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_504(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_505(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ 
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_504(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_505(_) OF_PP_INTERNAL_SEQ_SIZE_506 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_505 505 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_506(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_505(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_506(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_505(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_506(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_505(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_506(_) OF_PP_INTERNAL_SEQ_SIZE_507 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_506 506 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_507(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_506(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_507(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_506(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_507(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_506(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_507(_) OF_PP_INTERNAL_SEQ_SIZE_508 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_507 507 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_508(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_507(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_508(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_507(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_508(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_507(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_508(_) OF_PP_INTERNAL_SEQ_SIZE_509 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_508 508 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_509(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_508(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_509(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_508(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_509(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_508(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_509(_) OF_PP_INTERNAL_SEQ_SIZE_510 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_509 509 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_510(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_509(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_510(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_509(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_510(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_509(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_510(_) 
OF_PP_INTERNAL_SEQ_SIZE_511 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_510 510 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_511(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_510(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_511(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_510(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_511(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_510(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_SEQ_SIZE_511(_) OF_PP_INTERNAL_SEQ_SIZE_512 #define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_511 511 #define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_512(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D0_SEQ_FOR_EACH_511(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_512(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D1_SEQ_FOR_EACH_511(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_512(apply, m, d, seq) \ apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq)) \ OF_PP_INTERNAL_D2_SEQ_FOR_EACH_511(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq)) #endif // ONEFLOW_CORE_COMMON_PREPROCESSOR_INTERNAL_H_ ================================================ FILE: oneflow/core/common/preprocessor_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/core/common/data_type.h" namespace oneflow { TEST(PP_SEQ, internal_seq_size) { #define SEQ (1)(2)(3) ASSERT_EQ(OF_PP_SEQ_SIZE(SEQ), 3); #undef SEQ } TEST(PP_SEQ, internal_big_seq_size) { #define SEQ \ (0)(1)(2)(3)(4)(5)(6)(7)(8)(9)(10)(11)(12)(13)(14)(15)(16)(17)(18)(19)(20)(21)(22)(23)(24)(25)( \ 26)(27)(28)(29)(30)(31)(32)(33)(34)(35)(36)(37)(38)(39)(40)(41)(42)(43)(44)(45)(46)(47)(48)( \ 49)(50)(51)(52)(53)(54)(55)(56)(57)(58)(59)(60)(61)(62)(63) ASSERT_EQ(OF_PP_SEQ_SIZE(SEQ), 64); #undef SEQ } TEST(PP_SEQ, internal_for_each) { #define SEQ (1)(2)(3)(4) #define MAKE_PAIR(x) {x, x}, std::unordered_map identity = {OF_PP_INTERNAL_SEQ_FOR_EACH_ATOMIC(MAKE_PAIR, _, SEQ)}; #undef MAKE_PAIR #undef SEQ for (int i = 1; i <= 4; ++i) { ASSERT_EQ(i, identity[i]); } } TEST(PP_TUPLE, internal_is_tuple_empty) { ASSERT_EQ(OF_PP_INTERNAL_IS_TUPLE_EMPTY(()), 1); ASSERT_EQ(OF_PP_INTERNAL_IS_TUPLE_EMPTY((1)), 0); ASSERT_EQ(OF_PP_INTERNAL_IS_TUPLE_EMPTY((1, 2)), 0); } TEST(PP_TUPLE, internal_tuple_size) { ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE(()), 0); ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1)), 1); ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1, 2)), 2); ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1, 2, 3)), 3); ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1, 2, 3, 4)), 4); ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1, 2, 3, 4, 5)), 5); } TEST(PP_SEQ, internal_seq_product) { #define SEQ (0)(1) std::string expanded(OF_PP_STRINGIZE(OF_PP_INTERNAL_SEQ_PRODUCT(SEQ, SEQ))); #undef SEQ ASSERT_TRUE((expanded == "((0, 0)) ((1, 0)) ((0, 1)) ((1, 1))") || (expanded == "((0, 0)) ((1, 0)) ((0, 1)) ((1, 1))")); } TEST(PP_SEQ, internal_different_seq_product) { #define SEQ1 (0)(1) #define SEQ2 (a)(b) std::string expanded(OF_PP_STRINGIZE(OF_PP_INTERNAL_SEQ_PRODUCT(SEQ1, SEQ2))); #undef SEQ1 #undef SEQ2 ASSERT_TRUE((expanded == "((0, a)) ((1, a)) ((0, b)) ((1, b))") || (expanded == "((0, a)) ((1, a)) ((0, b)) ((1, b))")); } TEST(PP_SEQ, internal_seq_product_for_each) { #define SEQ (0)(1) #define MAKE_ENTRY(x, y) {OF_PP_STRINGIZE(OF_PP_CAT(x, y)), x || y}, std::unordered_map or_table = { OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, OF_PP_INTERNAL_SEQ_PRODUCT(SEQ, SEQ))}; #undef MAKE_ENTRY #undef SEQ ASSERT_EQ(or_table["00"], false); ASSERT_EQ(or_table["01"], true); ASSERT_EQ(or_table["10"], true); ASSERT_EQ(or_table["11"], true); } TEST(PP, stringize) { ASSERT_EQ(OF_PP_STRINGIZE(foo), "foo"); ASSERT_EQ(OF_PP_STRINGIZE(bar), "bar"); } TEST(PP, concate) { ASSERT_EQ(OF_PP_CAT(OF_PP_, STRINGIZE)(foo), "foo"); ASSERT_EQ(OF_PP_CAT(OF_PP_, STRINGIZE)(bar), "bar"); } TEST(PP_SEQ, make_tuple_seq) { ASSERT_EQ(OF_PP_STRINGIZE(OF_PP_MAKE_TUPLE_SEQ(1, 2)), "((1, 2))"); } TEST(PP_SEQ, for_each_tuple) { #define SEQ ((1, 1))((2, 2))((3, 3))((4, 4)) #define MAKE_ENTRY(x, y) {x, y}, std::unordered_map identity = {OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SEQ)}; #undef MAKE_ENTRY #undef SEQ for (int i = 1; i <= 4; ++i) { ASSERT_EQ(i, identity[i]); } } TEST(PP_SEQ, outter_for_each_tuple) { #define SEQ ((1, 1))((2, 2))((3, 3))((4, 4)) #define MAKE_ENTRY(x, y) {x, y}, std::unordered_map identity = {OF_PP_OUTTER_FOR_EACH_TUPLE(MAKE_ENTRY, SEQ)}; #undef MAKE_ENTRY #undef SEQ for (int i = 1; i <= 4; ++i) { ASSERT_EQ(i, identity[i]); } } TEST(PP_SEQ, nested_for_each_tuple) { #define SEQ ((0))((1))((2))((3)) #define MAKE_INNER(x) x, #define MAKE_OUTTER(x) {OF_PP_FOR_EACH_TUPLE(MAKE_INNER, SEQ)}, std::vector> table = {OF_PP_OUTTER_FOR_EACH_TUPLE(MAKE_OUTTER, SEQ)}; #undef MAKE_OUTTER #undef MAKE_INNER #undef SEQ ASSERT_EQ(table.size(), 4); for (int i = 0; i < 4; ++i) 
{ ASSERT_EQ(table[i].size(), 4); for (int j = 0; j < 4; ++j) { ASSERT_EQ(j, table[i][j]); } } } TEST(PP_SEQ, seq_product_for_each) { #define SEQ (0)(1) #define MAKE_ENTRY(x, y) {OF_PP_STRINGIZE(OF_PP_CAT(x, y)), x || y}, std::unordered_map or_table = { OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, SEQ, SEQ)}; #undef MAKE_ENTRY #undef SEQ ASSERT_EQ(or_table["00"], false); ASSERT_EQ(or_table["01"], true); ASSERT_EQ(or_table["10"], true); ASSERT_EQ(or_table["11"], true); } } // namespace oneflow ================================================ FILE: oneflow/core/common/process_state.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_PROCESS_STATE_H_ #define ONEFLOW_CORE_COMMON_PROCESS_STATE_H_ #if defined(_MSC_VER) #include #include #include #pragma comment(lib, "Ws2_32.lib") #else #include #endif #include #include namespace oneflow { inline std::string GetCwd() { size_t len = 128; std::unique_ptr a(new char[len]); for (;;) { char* p = getcwd(a.get(), len); if (p != NULL) { return p; } else if (errno == ERANGE) { len += len; a.reset(new char[len]); } else { return NULL; } } } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_PROCESS_STATE_H_ ================================================ FILE: oneflow/core/common/protobuf.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/shape.pb.h" #include "oneflow/core/common/sequential.pb.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/register/blob_desc.pb.h" #include #include #include namespace oneflow { // parse protobuf message from .prototxt file bool TryParseProtoFromTextFile(const std::string& file_path, PbMessage* proto) { std::ifstream in_stream(file_path.c_str(), std::ifstream::in); google::protobuf::io::IstreamInputStream input(&in_stream); return google::protobuf::TextFormat::Parse(&input, proto); } void ParseProtoFromTextFile(const std::string& file_path, PbMessage* proto) { CHECK(TryParseProtoFromTextFile(file_path, proto)); } // parse protobuf message from .pb file bool TryParseProtoFromPbFile(const std::string& file_path, PbMessage* proto) { std::ifstream in_stream(file_path.c_str(), std::ifstream::in | std::ifstream::binary); return proto->ParseFromIstream(&in_stream); } void ParseProtoFromPbFile(const std::string& file_path, PbMessage* proto) { CHECK(TryParseProtoFromPbFile(file_path, proto)); } void PrintProtoToTextFile(const PbMessage& proto, const std::string& file_path) { std::ofstream out_stream(file_path.c_str(), std::ofstream::out | std::ofstream::trunc); google::protobuf::io::OstreamOutputStream output(&out_stream); CHECK(google::protobuf::TextFormat::Print(proto, &output)); } std::string PbMessage2TxtString(const PbMessage& proto) { std::string str; PbMessage2TxtString(proto, &str); return str; } void PbMessage2TxtString(const PbMessage& proto, std::string* str) { google::protobuf::TextFormat::PrintToString(proto, str); } bool TxtString2PbMessage(const std::string& proto_str, PbMessage* msg) { return google::protobuf::TextFormat::ParseFromString(proto_str, msg); } bool FieldDefinedInPbMessage(const PbMessage& msg, const std::string& field_name) { PROTOBUF_GET_FIELDDESC(msg, field_name); return fd != nullptr; } #define DEFINE_GET_VAL_FROM_PBMESSAGE(cpp_type, pb_type_name) \ template<> \ cpp_type GetValFromPbMessage(const PbMessage& msg, const std::string& field_name) { \ PROTOBUF_REFLECTION(msg, field_name); \ return r->Get##pb_type_name(msg, fd); \ } OF_PP_FOR_EACH_TUPLE(DEFINE_GET_VAL_FROM_PBMESSAGE, PROTOBUF_BASIC_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(const PbMessage&, Message)) #define DEFINE_SET_VAL_IN_PBMESSAGE(cpp_type, pb_type_name) \ template<> \ void SetValInPbMessage(PbMessage* msg, const std::string& field_name, const cpp_type& val) { \ PROTOBUF_REFLECTION((*msg), field_name); \ r->Set##pb_type_name(msg, fd, val); \ } OF_PP_FOR_EACH_TUPLE(DEFINE_SET_VAL_IN_PBMESSAGE, PROTOBUF_BASIC_DATA_TYPE_SEQ) const PbMessage& GetMessageInPbMessage(const PbMessage& msg, const std::string& field_name) { PROTOBUF_REFLECTION(msg, field_name); return r->GetMessage(msg, fd); } PbMessage* MutableMessageInPbMessage(PbMessage* msg, const std::string& field_name) { PROTOBUF_REFLECTION((*msg), field_name); return r->MutableMessage(msg, fd); } const PbMessage& GetMessageInPbMessage(const PbMessage& msg, int field_index) { const auto* d = const_cast(msg.GetDescriptor()); const auto* fd = const_cast(d->FindFieldByNumber(field_index)); CHECK_NOTNULL(fd); const auto* r = const_cast(msg.GetReflection()); return r->GetMessage(msg, fd); } PbMessage* MutableMessageInPbMessage(PbMessage* msg, int field_index) { const auto* d = const_cast(msg->GetDescriptor()); const auto* fd = const_cast(d->FindFieldByNumber(field_index)); CHECK_NOTNULL(fd); const auto* r = const_cast(msg->GetReflection()); return r->MutableMessage(msg, fd); 
} #define DECLARE_GETTER_FUNC_HEADER(type) \ template<> \ type GetValFromPbMessage(const PbMessage& msg, const std::string& field_name) #define DECLARE_SETTER_FUNC_HEADER(type) \ template<> \ void SetValInPbMessage(PbMessage * msg, const std::string& field_name, const type& val) #define DEFINE_MESSAGE_VAL_GETTER_AND_SETTER(message_type) \ DECLARE_GETTER_FUNC_HEADER(message_type) { \ PROTOBUF_REFLECTION(msg, field_name); \ return *dynamic_cast(&r->GetMessage(msg, fd)); \ } \ DECLARE_SETTER_FUNC_HEADER(message_type) { \ PROTOBUF_REFLECTION((*msg), field_name); \ r->MutableMessage(msg, fd)->CopyFrom(val); \ } DEFINE_MESSAGE_VAL_GETTER_AND_SETTER(ShapeProto); DEFINE_MESSAGE_VAL_GETTER_AND_SETTER(Int64ListProto); #define DEFINE_ENUM_VAL_GETTER_AND_SETTER(enum_type) \ DECLARE_GETTER_FUNC_HEADER(enum_type) { \ PROTOBUF_REFLECTION(msg, field_name); \ return static_cast(r->GetEnumValue(msg, fd)); \ } \ DECLARE_SETTER_FUNC_HEADER(enum_type) { \ PROTOBUF_REFLECTION((*msg), field_name); \ r->SetEnumValue(msg, fd, val); \ } DEFINE_ENUM_VAL_GETTER_AND_SETTER(DataType); #define DEFINE_VECTOR_VAL_GETTER_AND_SETTER(vec_type, vec_type_name) \ DECLARE_GETTER_FUNC_HEADER(vec_type) { \ PROTOBUF_REFLECTION(msg, field_name); \ int32_t field_size = r->FieldSize(msg, fd); \ vec_type retval(field_size); \ for (int i = 0; i < field_size; ++i) { retval[i] = r->Get##vec_type_name(msg, fd, i); } \ return retval; \ } \ DECLARE_SETTER_FUNC_HEADER(vec_type) { \ PROTOBUF_REFLECTION((*msg), field_name); \ for (int i = 0; i < val.size(); ++i) { r->Set##vec_type_name(msg, fd, i, val[i]); } \ } #define MAKE_REPEATED_TUPLE_SEQ(type, type_name) \ OF_PP_MAKE_TUPLE_SEQ(std::vector, Repeated##type_name) #define PROTOBUF_BASIC_REPEATED_DATA_TYPE_SEQ \ MAKE_REPEATED_TUPLE_SEQ(std::string, String) \ MAKE_REPEATED_TUPLE_SEQ(int32_t, Int32) \ MAKE_REPEATED_TUPLE_SEQ(uint32_t, UInt32) \ MAKE_REPEATED_TUPLE_SEQ(int64_t, Int64) \ MAKE_REPEATED_TUPLE_SEQ(uint64_t, UInt64) \ MAKE_REPEATED_TUPLE_SEQ(float, Float) \ MAKE_REPEATED_TUPLE_SEQ(double, Double) \ MAKE_REPEATED_TUPLE_SEQ(int16_t, EnumValue) \ MAKE_REPEATED_TUPLE_SEQ(bool, Bool) OF_PP_FOR_EACH_TUPLE(DEFINE_VECTOR_VAL_GETTER_AND_SETTER, PROTOBUF_BASIC_REPEATED_DATA_TYPE_SEQ); #define DEFINE_ADD_VAL_IN_PBRF(cpp_type, pb_type_name) \ template<> \ void AddValInPbRf(PbMessage* msg, const std::string& field_name, const cpp_type& val) { \ PROTOBUF_REFLECTION((*msg), field_name); \ r->Add##pb_type_name(msg, fd, val); \ } OF_PP_FOR_EACH_TUPLE(DEFINE_ADD_VAL_IN_PBRF, PROTOBUF_BASIC_DATA_TYPE_SEQ) std::pair GetFieldNameAndIndex4StrVal(const std::string& fd_name_with_idx) { std::string field_name; int32_t idx = 0; CHECK_GE(idx, 0); GetPrefixAndIndex(fd_name_with_idx, &field_name, &idx); return std::make_pair(field_name, idx); } std::string GetStrValInPbFdOrPbRpf(const PbMessage& msg, const std::string& fd_name_may_have_idx) { const PbFd* fd = msg.GetDescriptor()->FindFieldByName(fd_name_may_have_idx); if (fd) { return GetValFromPbMessage(msg, fd_name_may_have_idx); } else { const std::pair prefix_idx = GetFieldNameAndIndex4StrVal(fd_name_may_have_idx); return GetPbRpfFromPbMessage(msg, prefix_idx.first).Get(prefix_idx.second); } } bool HasStrFieldInPbFdOrPbRpf(const PbMessage& msg, const std::string& fd_name_may_have_idx) { const PbFd* fd = msg.GetDescriptor()->FindFieldByName(fd_name_may_have_idx); if (fd != nullptr) { return true; } std::string field_name; int32_t index = 0; return TryGetPrefixAndIndex(fd_name_may_have_idx, &field_name, &index); } std::string 
ReplaceStrValInPbFdOrPbRpf(PbMessage* msg, const std::string& fd_name_may_have_idx, const std::string& new_val) { const PbFd* fd = msg->GetDescriptor()->FindFieldByName(fd_name_may_have_idx); std::string old_val; if (fd) { old_val = GetValFromPbMessage(*msg, fd_name_may_have_idx); SetValInPbMessage(msg, fd_name_may_have_idx, new_val); } else { const std::pair prefix_idx = GetFieldNameAndIndex4StrVal(fd_name_may_have_idx); old_val = GetPbRpfFromPbMessage(*msg, prefix_idx.first).Get(prefix_idx.second); PbRpf* rpf = MutPbRpfFromPbMessage(msg, prefix_idx.first); *rpf->Mutable(prefix_idx.second) = new_val; } return old_val; } PersistentOutStream& operator<<(PersistentOutStream& out_stream, const PbMessage& msg) { std::string msg_bin; msg.SerializeToString(&msg_bin); int64_t msg_size = msg_bin.size(); CHECK_GT(msg_size, 0); out_stream << msg_size << msg_bin; return out_stream; } bool operator==(const BlobDescProto& lhs, const BlobDescProto& rhs) { return PbMd().Equivalent(lhs, rhs); } } // namespace oneflow ================================================ FILE: oneflow/core/common/protobuf.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_PROTOBUF_H_ #define ONEFLOW_CORE_COMMON_PROTOBUF_H_ #ifdef _MSC_VER #include #endif #include #include #include #include #include "oneflow/core/common/util.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/register/logical_blob_id.pb.h" #include "oneflow/core/register/op_blob_arg.pb.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/scope.pb.h" #include "oneflow/core/persistence/persistent_out_stream.h" namespace oneflow { using PbMessage = google::protobuf::Message; template using PbRf = google::protobuf::RepeatedField; template using PbRpf = google::protobuf::RepeatedPtrField; template using PbMapPair = google::protobuf::MapPair; template using PbMap = google::protobuf::Map; using PbFd = google::protobuf::FieldDescriptor; using PbMd = google::protobuf::util::MessageDifferencer; #define PROTOBUF_BASIC_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(std::string, String) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, Int32) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, UInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, Int64) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, UInt64) \ OF_PP_MAKE_TUPLE_SEQ(float, Float) \ OF_PP_MAKE_TUPLE_SEQ(double, Double) \ OF_PP_MAKE_TUPLE_SEQ(int16_t, EnumValue) \ OF_PP_MAKE_TUPLE_SEQ(bool, Bool) #define PROTOBUF_GET_FIELDDESC(msg, field_name) \ auto d = const_cast(msg.GetDescriptor()); \ auto fd = const_cast(d->FindFieldByName(field_name)); #define PROTOBUF_REFLECTION(msg, field_name) \ PROTOBUF_GET_FIELDDESC(msg, field_name) \ CHECK_NOTNULL(fd); \ auto r = const_cast(msg.GetReflection()); // Prototxt <-> File bool TryParseProtoFromTextFile(const std::string& file_path, PbMessage* proto); void ParseProtoFromTextFile(const std::string& file_path, PbMessage* 
proto); bool TryParseProtoFromPbFile(const std::string& file_path, PbMessage* proto); void ParseProtoFromPbFile(const std::string& file_path, PbMessage* proto); void PrintProtoToTextFile(const PbMessage& proto, const std::string& file_path); std::string PbMessage2TxtString(const PbMessage& proto); void PbMessage2TxtString(const PbMessage& proto, std::string* str); bool TxtString2PbMessage(const std::string& proto_str, PbMessage* proto); // Does PbMessage have the field_name bool FieldDefinedInPbMessage(const PbMessage&, const std::string& field_name); // Get From PbMessage template T GetValFromPbMessage(const PbMessage&, const std::string& field_name); template const PbRf& GetPbRfFromPbMessage(const PbMessage& msg, const std::string& field_name) { PROTOBUF_REFLECTION(msg, field_name); return r->GetRepeatedField(msg, fd); } template const PbRpf& GetPbRpfFromPbMessage(const PbMessage& msg, const std::string& field_name) { PROTOBUF_REFLECTION(msg, field_name); return r->GetRepeatedPtrField(msg, fd); } template PbRpf* MutPbRpfFromPbMessage(PbMessage* msg, const std::string& field_name) { PROTOBUF_REFLECTION((*msg), field_name); return r->MutableRepeatedPtrField(msg, fd); } // Set In PbMessage template void SetValInPbMessage(PbMessage* msg, const std::string& field_name, const T& val); const PbMessage& GetMessageInPbMessage(const PbMessage& msg, int field_index); const PbMessage& GetMessageInPbMessage(const PbMessage& msg, const std::string& field_name); PbMessage* MutableMessageInPbMessage(PbMessage*, const std::string& field_name); PbMessage* MutableMessageInPbMessage(PbMessage*, int field_index); // Get/Replace str val maybe repeated; field_name with index is like "name_0" std::pair GetFieldNameAndIndex4StrVal(const std::string& fd_name_with_idx); std::string GetStrValInPbFdOrPbRpf(const PbMessage& msg, const std::string& fd_name_may_have_idx); bool HasStrFieldInPbFdOrPbRpf(const PbMessage& msg, const std::string& fd_name_may_have_idx); // return old value std::string ReplaceStrValInPbFdOrPbRpf(PbMessage* msg, const std::string& fd_name_may_have_idx, const std::string& new_val); // Add In PbMessage RepeatedField template void AddValInPbRf(PbMessage*, const std::string& field_name, const T& val); // PbRf <-> std::vector template inline std::vector PbRf2StdVec(const PbRf& rf) { return std::vector(rf.begin(), rf.end()); } template inline PbRf StdVec2PbRf(const std::vector& vec) { return PbRf(vec.begin(), vec.end()); } // PbRpf <-> std::vector template inline std::vector PbRpf2StdVec(const PbRpf& rpf) { return std::vector(rpf.begin(), rpf.end()); } template inline PbRpf StdVec2PbRpf(const std::vector& vec) { using RetType = PbRpf; return RetType(vec.begin(), vec.end()); } // ProtoMap <-> HashMap template HashMap PbMap2HashMap(const google::protobuf::Map& pb_map) { return HashMap(pb_map.begin(), pb_map.end()); } template google::protobuf::Map HashMap2PbMap(const HashMap& hash_map) { using RetType = google::protobuf::Map; return RetType(hash_map.begin(), hash_map.end()); } // If value exists in RepeatedField template bool IsInRepeatedField(const PbRf& repeated_field, const T& value) { return std::find(repeated_field.cbegin(), repeated_field.cend(), value) != repeated_field.cend(); } // LBI compare operator inline bool operator<(const LogicalBlobId& lhs, const LogicalBlobId& rhs) { if (lhs.op_name() != rhs.op_name()) { return lhs.op_name() < rhs.op_name(); } if (lhs.blob_name() != rhs.blob_name()) { return lhs.blob_name() < rhs.blob_name(); } return false; } inline bool operator==(const 
LogicalBlobId& lhs, const LogicalBlobId& rhs) { return lhs.op_name() == rhs.op_name() && lhs.blob_name() == rhs.blob_name(); } inline bool operator!=(const LogicalBlobId& lhs, const LogicalBlobId& rhs) { return !(lhs == rhs); } inline bool operator==(const OpBlobArg& lhs, const OpBlobArg& rhs) { return PbMd().Equals(lhs, rhs); } inline bool operator!=(const OpBlobArg& lhs, const OpBlobArg& rhs) { return !(lhs == rhs); } class BlobDescProto; bool operator==(const BlobDescProto& lhs, const BlobDescProto& rhs); inline bool operator!=(const BlobDescProto& lhs, const BlobDescProto& rhs) { return !(lhs == rhs); } inline bool operator==(const JobConfigProto& lhs, const JobConfigProto& rhs) { return PbMd().Equals(lhs, rhs); } inline bool operator==(const ScopeProto& lhs, const ScopeProto& rhs) { return PbMd().Equals(lhs, rhs); } // Persistent PersistentOutStream& operator<<(PersistentOutStream&, const PbMessage&); template struct SerializedHashPb { size_t operator()(const T& pb) const { std::string serialized_string; pb.SerializeToString(&serialized_string); return std::hash()(serialized_string); } }; } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::DataType data_type) const { return std::hash()(data_type); } }; template<> struct hash { size_t operator()(const oneflow::LogicalBlobId& lbi) const { using namespace oneflow; return Hash(lbi.op_name(), lbi.blob_name()); } }; template<> struct hash { size_t operator()(const oneflow::OpBlobArg& oba) const { using namespace oneflow; return Hash(oba.op_name(), oba.bn_in_op()); } }; template<> struct hash { size_t operator()(const oneflow::SbpParallel& sbp_parallel) const { using namespace oneflow; size_t ret = 0; if (sbp_parallel.has_broadcast_parallel()) { AddHash(&ret, std::string("B")); } else if (sbp_parallel.has_partial_sum_parallel()) { AddHash(&ret, std::string("P")); } else if (sbp_parallel.has_split_parallel()) { AddHash(&ret, std::string("S")); AddHash(&ret, sbp_parallel.split_parallel().axis()); } else { UNIMPLEMENTED(); } return ret; } }; template<> struct hash { size_t operator()(const oneflow::NdSbp& nd_sbp) const { const auto& sbp_hash = std::hash(); size_t hash = 0; for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { oneflow::HashCombine(&hash, sbp_hash(nd_sbp.sbp_parallel(i))); } return hash; } }; template<> struct hash { size_t operator()(const oneflow::JobConfigProto& job_conf) const { return oneflow::SerializedHashPb()(job_conf); } }; template<> struct hash { size_t operator()(const oneflow::ScopeProto& scope) const { return oneflow::SerializedHashPb()(scope); } }; } // namespace std #endif // ONEFLOW_CORE_COMMON_PROTOBUF_H_ ================================================ FILE: oneflow/core/common/range.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/range.h" namespace oneflow { Range::Range(const RangeProto& range_proto) { begin_ = range_proto.begin(); end_ = range_proto.end(); } void Range::ToProto(RangeProto* ret) const { ret->set_begin(begin_); ret->set_end(end_); } Maybe Range::ForEachSubRange( int64_t sub_range_size, const std::function(const Range&)>& DoEachRange) const { CHECK_EQ_OR_RETURN(size() % sub_range_size, 0); int64_t start = begin(); for (; start < end(); start += sub_range_size) { JUST(DoEachRange(Range(start, start + sub_range_size))); } CHECK_EQ_OR_RETURN(start, end()); return Maybe::Ok(); } Range FindIntersectant(const Range& lhs, const Range& rhs) { if (lhs.end() > rhs.begin() && rhs.end() > lhs.begin()) { int64_t left = lhs.begin() > rhs.begin() ? lhs.begin() : rhs.begin(); int64_t right = lhs.end() < rhs.end() ? lhs.end() : rhs.end(); return Range(left, right); } else { return Range(0, 0); } } } // namespace oneflow ================================================ FILE: oneflow/core/common/range.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_RANGE_H_ #define ONEFLOW_CORE_COMMON_RANGE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/range.pb.h" namespace oneflow { class Range final { public: // OF_DISALLOW_COPY_AND_MOVE(Range); Range() : Range(0, 0) {} ~Range() = default; Range(int64_t begin, int64_t end) : begin_(begin), end_(end) {} explicit Range(const RangeProto& range_proto); bool operator==(const Range& rhs) const { return begin_ == rhs.begin_ && end_ == rhs.end_; } bool operator!=(const Range& rhs) const { return !(*this == rhs); } int64_t begin() const { return begin_; } int64_t end() const { return end_; } int64_t& mut_begin() { return begin_; } int64_t& mut_end() { return end_; } int64_t size() const { return end_ - begin_; } Maybe ForEachSubRange(int64_t sub_range_size, const std::function(const Range&)>& DoEachRange) const; void ToProto(RangeProto* ret) const; private: int64_t begin_; int64_t end_; }; Range FindIntersectant(const Range& lhs, const Range& rhs); } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::Range& range) const { return oneflow::HashCombine(range.begin(), range.end()); } }; } // namespace std #endif // ONEFLOW_CORE_COMMON_RANGE_H_ ================================================ FILE: oneflow/core/common/range.proto ================================================ syntax = "proto2"; package oneflow; message RangeProto { required int64 begin = 1; required int64 end = 2; } ================================================ FILE: oneflow/core/common/registry_error.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */

#include "oneflow/core/common/registry_error.h"

namespace oneflow {

namespace {

std::shared_ptr* MutRegistryError() {
  static std::shared_ptr registry_error;
  return &registry_error;
}

}  // namespace

Maybe<void> CheckAndClearRegistryFlag() {
  if (!*MutRegistryError()) { return Maybe<void>::Ok(); }
  std::shared_ptr registry_error_old = *MutRegistryError();
  *MutRegistryError() = nullptr;
  return registry_error_old;
}

void CatchRegistryError(const std::function<Maybe<void>()>& handler) {
  const auto& maybe_error = TRY(handler());
  if (!maybe_error.IsOk()) {
    if (!*MutRegistryError()) { *MutRegistryError() = maybe_error.stacked_error(); }
  }
}

}  // namespace oneflow

================================================ FILE: oneflow/core/common/registry_error.h ================================================

/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */

#ifndef ONEFLOW_CORE_COMMON_REGISTRY_ERROR_H
#define ONEFLOW_CORE_COMMON_REGISTRY_ERROR_H

#include <functional>
#include "oneflow/core/common/maybe.h"

namespace oneflow {

// Note: there is a time interval between catching an error and reporting it;
// any error that occurs in this interval can't be displayed.
Maybe<void> CheckAndClearRegistryFlag();
void CatchRegistryError(const std::function<Maybe<void>()>&);

}  // namespace oneflow

#endif  // ONEFLOW_CORE_COMMON_REGISTRY_ERROR_H

================================================ FILE: oneflow/core/common/scalar.cpp ================================================

/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include #include "oneflow/core/common/scalar.h" namespace oneflow { #define DEFINE_SCALAR_BINARY_OP(op) \ Scalar& Scalar::operator op##=(const Scalar& other) { \ if (IsComplex() || other.IsComplex()) { \ std::complex val = \ Value>() op other.Value>(); \ *this = val; \ } \ if (IsFloatingPoint() || other.IsFloatingPoint()) { \ double val = As() op other.As(); \ *this = val; \ } else { \ int64_t val = As() op other.As(); \ *this = val; \ } \ return *this; \ } \ Scalar Scalar::operator op(const Scalar& other) const { \ if (IsComplex() || other.IsComplex()) { \ std::complex val = \ Value>() op other.Value>(); \ return Scalar(val); \ } \ if (IsFloatingPoint() || other.IsFloatingPoint()) { \ double val = As() op other.As(); \ return Scalar(val); \ } \ int64_t val = As() op other.As(); \ return Scalar(val); \ } DEFINE_SCALAR_BINARY_OP(+); DEFINE_SCALAR_BINARY_OP(-); DEFINE_SCALAR_BINARY_OP(*); DEFINE_SCALAR_BINARY_OP(/); // NOLINT #undef DEFINE_SCALAR_BINARY_OP } // namespace oneflow ================================================ FILE: oneflow/core/common/scalar.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_SCALAR_H_ #define ONEFLOW_CORE_COMMON_SCALAR_H_ #include #include #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/maybe.h" namespace oneflow { class Scalar { public: Scalar() : Scalar(int32_t(0)) {} template, T>::value || std::is_same, T>::value, int>::type = 0> Scalar(const T& value) : value_{.c = {value.real(), value.imag()}}, active_tag_(HAS_C) {} template::value, int>::type = 0> OF_DEVICE_FUNC Scalar(const T& value) : value_{.b = value}, active_tag_(HAS_B) {} template::value && std::is_signed::value, int>::type = 0> OF_DEVICE_FUNC Scalar(const T& value) : value_{.s = value}, active_tag_(HAS_S) {} template::value && std::is_unsigned::value && !std::is_same::value, int>::type = 0> OF_DEVICE_FUNC Scalar(const T& value) : value_{.u = value}, active_tag_(HAS_U) {} template::value, int>::type = 0> OF_DEVICE_FUNC Scalar(const T& value) : value_{.d = value}, active_tag_(HAS_D) {} template::value, int>::type = 0> OF_DEVICE_FUNC Scalar& operator=(const T& value) { *this = Scalar(value); return *this; } OF_DEVICE_FUNC Scalar& operator=(const Scalar& other) { value_ = other.value_; active_tag_ = other.active_tag_; return *this; } template::value, int>::type = 0> OF_DEVICE_FUNC T As() const { switch (active_tag_) { case HAS_B: return static_cast(value_.b); case HAS_S: return static_cast(value_.s); case HAS_U: return static_cast(value_.u); case HAS_D: return static_cast(value_.d); default: assert(false); return 0; } } template::value, int>::type = 0> OF_DEVICE_FUNC T Value() const { return As(); } template, T>::value || std::is_same, T>::value, int>::type = 0> T Value() const { if (!IsComplex()) { return T(As(), 0.0); } return T(value_.c.real, value_.c.imag); } bool IsBool() const { return active_tag_ == HAS_B; } bool IsIntegral() const { return active_tag_ == HAS_S || 
active_tag_ == HAS_U; } bool IsFloatingPoint() const { return active_tag_ == HAS_D; } bool IsSigned() const { return active_tag_ == HAS_S || active_tag_ == HAS_D; } bool IsUnsigned() const { return active_tag_ == HAS_U; } bool IsComplex() const { return active_tag_ == HAS_C; } Scalar operator+(const Scalar& other) const; Scalar operator-(const Scalar& other) const; Scalar operator*(const Scalar& other) const; Scalar operator/(const Scalar& other) const; Scalar& operator+=(const Scalar& other); Scalar& operator-=(const Scalar& other); Scalar& operator*=(const Scalar& other); Scalar& operator/=(const Scalar& other); private: union Value { bool b; int64_t s; uint64_t u; double d; struct { double real; double imag; } c; } value_; enum { HAS_B, HAS_S, HAS_U, HAS_D, HAS_C, HAS_NONE } active_tag_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_SCALAR_H_ ================================================ FILE: oneflow/core/common/sequential.proto ================================================ syntax = "proto2"; package oneflow; message Int64ListProto { repeated int64 dim = 1; } ================================================ FILE: oneflow/core/common/shape.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
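Scalar above stores a single numeric value in a tagged union (bool, signed, unsigned, double, or complex) and promotes operands inside the arithmetic operators defined in scalar.cpp. A minimal sketch with arbitrary values:

    oneflow::Scalar a(3);               // held as a signed integer
    oneflow::Scalar b(0.5);             // held as a double
    oneflow::Scalar c = a * b;          // mixed int/float arithmetic promotes to double
    double  d = c.As<double>();         // 1.5
    int64_t i = c.As<int64_t>();        // truncated to 1
    bool fp   = c.IsFloatingPoint();    // true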
*/ #include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { template int64_t ConstShapeMixIn::elem_cnt() const { return std::accumulate(tp()->begin(), tp()->end(), int64_t(1), std::multiplies<>()); } template int64_t ConstShapeMixIn::At(int64_t index) const { CHECK_GE(index, 0); CHECK_LT(index, tp()->NumAxes()) << " Shape: " << tp()->DebugStr() << " visit index: " << index << " > num_axes: " << tp()->NumAxes(); return (*tp())[index]; } template int64_t ConstShapeMixIn::Count(int64_t begin_axis, int64_t end_axis) const { CHECK(0 <= begin_axis && begin_axis <= end_axis && end_axis <= tp()->NumAxes()) << begin_axis << " " << end_axis; int64_t cnt = 1; for (int64_t i = begin_axis; i < end_axis; ++i) { cnt *= At(i); } return cnt; } template int64_t ConstShapeMixIn::Count(int64_t begin_axis) const { return Count(begin_axis, tp()->NumAxes()); } template bool ConstShapeMixIn::Containing(ShapeView small_shape) const { if (tp()->NumAxes() < small_shape.NumAxes()) { return false; } FOR_RANGE(int, i, 0, small_shape.NumAxes()) { if (tp()->At(i) != small_shape.At(i)) { return false; } } return true; } template bool ConstShapeMixIn::MatchBeforeLastDim(ShapeView next_shape) const { if (tp()->NumAxes() != next_shape.NumAxes()) { return false; } for (int64_t i = 0; i < tp()->NumAxes() - 1; ++i) { if (next_shape.At(i) != tp()->At(i)) { return false; } } return true; } template std::string ConstShapeMixIn::ToString() const { std::stringstream ss; int32_t idx = 0; ss << "("; for (int64_t dim : *tp()) { ss << dim; if (++idx != tp()->size() || tp()->size() == 1) { ss << ","; } } ss << ")"; return ss.str(); } template std::string ConstShapeMixIn::DebugStr() const { return ToString(); } template void ConstShapeMixIn::ToProto(ShapeProto* ret) const { *(ret->mutable_dim()) = PbRf(tp()->begin(), tp()->end()); } template bool ConstShapeMixIn::operator==(const T& rhs) const { if (this->NumAxes() != rhs.NumAxes()) { return false; } FOR_RANGE(int, i, 0, this->NumAxes()) { if (this->At(i) != rhs.At(i)) { return false; } } return true; } template struct ConstShapeMixIn; template struct MutShapeMixIn; template struct ConstShapeMixIn; template struct ConstShapeMixIn; template struct MutShapeMixIn; Shape CreateReducedShape(ShapeView shape, const AxisVector& axis_vec) { // For 0-dim Tensor if (axis_vec.empty()) { return Shape({}); } DimVector dim_vec; shape.ToDimVector(&dim_vec); for (int64_t axis : axis_vec) { dim_vec.at(ShiftNegativeAxis(axis, shape.NumAxes())) = 1; } return Shape(std::move(dim_vec)); } Shape CreateLeftExtendedShape(ShapeView shape, int ndims_left_extend_to) { CHECK_GE(ndims_left_extend_to, shape.NumAxes()); DimVector dim_vec(ndims_left_extend_to); const size_t left_ones_num = ndims_left_extend_to - shape.NumAxes(); int i = 0; for (; i < left_ones_num; ++i) { dim_vec.at(i) = 1LL; } for (; i < ndims_left_extend_to; ++i) { dim_vec.at(i) = shape.At(i - left_ones_num); } return Shape(std::move(dim_vec)); } Shape ExpandDimIf0D(const Shape& shape) { if (shape.NumAxes() == 0) { return {1}; } return shape; } Shape ExpandDimIf0D(ShapeView shape) { if (shape.NumAxes() == 0) { return {1}; } return Shape(shape); } Shape CreateReducedShapeOrOnesShape(ShapeView shape, const AxisVector& axis_vec) { if (axis_vec.empty()) { return Shape::Ones(shape.NumAxes()); } return CreateReducedShape(shape, axis_vec); } int64_t ShiftNegativeAxis(int64_t axis, const int64_t num_axes) { if (axis < 0) { axis += num_axes; } CHECK_GE(axis, 0); 
CHECK_LT(axis, num_axes); return axis; } Shape::Shape(const DimVector& dim_vec) : DimVector(dim_vec), is_initialized_(true) {} Shape::Shape(DimVector&& dim_vec) : DimVector(std::move(dim_vec)), is_initialized_(true) {} Shape::Shape(const ShapeProto& shape_proto) : DimVector(shape_proto.dim().begin(), shape_proto.dim().end()), is_initialized_(true) {} Shape::Shape(ShapeView shape_view) : DimVector(shape_view.begin(), shape_view.end()), is_initialized_(true) {} Shape& Shape::CheckNumAxesIdenticalAndAssign(ShapeView shape_view) { CHECK_EQ(NumAxes(), shape_view.NumAxes()); std::copy(shape_view.ptr(), shape_view.ptr() + shape_view.NumAxes(), data()); return *this; } Shape& Shape::LeftOnesExtendedAssign(ShapeView shape_view) { CHECK_GE(NumAxes(), shape_view.NumAxes()); size_t left_ones_size = NumAxes() - shape_view.NumAxes(); FOR_RANGE(int, i, 0, left_ones_size) { (*this)[i] = 1LL; } std::copy(shape_view.ptr(), shape_view.ptr() + shape_view.NumAxes(), data() + left_ones_size); return *this; } std::ostream& operator<<(std::ostream& out, const Shape& shape) { out << shape.DebugStr(); return out; } AxisVector Shape::ShiftNegativeAxisVec(const AxisVector& axis_vec) const { const int64_t num_axes = this->NumAxes(); AxisVector ret = axis_vec; for (int64_t i = 0; i < axis_vec.size(); i++) { ret.at(i) = ShiftNegativeAxis(axis_vec.at(i), num_axes); } return ret; } Shape Shape::RemoveOnes(const AxisVector& axis_vec) const { DimVector dim_vec; const AxisVector& axis_vec_shifted = ShiftNegativeAxisVec(axis_vec); for (int64_t i = 0; i < this->dim_vec().size(); i++) { if (std::find(axis_vec_shifted.begin(), axis_vec_shifted.end(), i) == axis_vec_shifted.end()) { dim_vec.emplace_back(this->dim_vec().at(i)); } else { CHECK_EQ(this->dim_vec().at(i), 1); } } return Shape(dim_vec); } Shape Shape::Ones(const int64_t num_axes) { DimVector dim_vec(num_axes); std::fill(dim_vec.begin(), dim_vec.end(), 1); return Shape(dim_vec); } AxisVector Shape::Axes4BroadcastTo(ShapeView broadcast_shape) const { AxisVector broadcast_axis_vec; CHECK_EQ(broadcast_shape.NumAxes(), NumAxes()); for (int64_t i = 0; i < NumAxes(); i++) { if (this->dim_vec().at(i) != broadcast_shape[i] && this->dim_vec().at(i) == 1) { broadcast_axis_vec.emplace_back(i); } else { CHECK_EQ(this->dim_vec().at(i), broadcast_shape[i]); } } CHECK(!broadcast_axis_vec.empty()); return broadcast_axis_vec; } Maybe Shape::Slice(int64_t start_dim, int64_t end_dim) const { CHECK_OR_RETURN(start_dim >= 0 && end_dim >= start_dim); int64_t ndims = this->NumAxes(); if (start_dim > ndims) { start_dim = ndims; } if (end_dim > ndims) { end_dim = ndims; } std::shared_ptr shape = std::make_shared(); shape->assign(this->begin() + start_dim, this->begin() + end_dim); return shape; } Maybe Shape::Slice(int64_t start_dim) const { return Slice(start_dim, NumAxes()); } bool Shape::operator==(const Shape& rhs) const { if (is_initialized_ != rhs.is_initialized_) { return false; } if (is_initialized_ == false) { return true; } if (this->NumAxes() != rhs.NumAxes()) { return false; } FOR_RANGE(int, i, 0, this->NumAxes()) { if (this->At(i) != rhs.At(i)) { return false; } } return true; } } // namespace oneflow ================================================ FILE: oneflow/core/common/shape.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
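The free functions above manipulate shapes without owning them; a short sketch of the reduce/extend helpers, with example dimensions chosen arbitrarily:

    oneflow::Shape shape({8, 1, 3});
    // Negative axes are shifted first, then the named axes are set to 1: (8, 1, 3) -> (8, 1, 1)
    oneflow::Shape reduced = oneflow::CreateReducedShape(shape, {1, -1});
    // Prepend ones until the rank reaches 5: (8, 1, 3) -> (1, 1, 8, 1, 3)
    oneflow::Shape extended = oneflow::CreateLeftExtendedShape(shape, 5);
    // A 0-dim shape is promoted to (1,)
    oneflow::Shape promoted = oneflow::ExpandDimIf0D(oneflow::Shape({}));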
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_SHAPE_H_ #define ONEFLOW_CORE_COMMON_SHAPE_H_ #include "oneflow/core/common/shape.pb.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/common/optional.h" namespace oneflow { class ShapeView; class MutShapeView; class ShapeProto; namespace cfg { class ShapeProto; } // namespace cfg /** * NOTE: * * There are two widely used shape-related classes: Shape and ShapeView. * The differences are: * 1. Shape owns the data, and ShapeView does not. * 2. ShapeView is very lightweight, whose size is only 16 bytes (two int64_t). * So it should be passed by value. * * When adding new functions accepting a shape as a parameter, please follow * the rules: * 1. If your function doesn't modify the shape, prefer * ShapeView. Shape can be implicitly converted to ShapeView so the method * with ShapeView parameter can accept both Shape and ShapeView actually. * 2. If your function modify the shape but doesn't affect * its rank, prefer MutShapeView. The reason is the same with rule 1. * 3. Use Shape otherwise. * * When adding new member methods of Shape or ShapeView, please follow * the rules: * 1. If the method is shared between Shape and ShapeView (like `NumAxes()`) * please add it to ConstShapeMixIn. * 2. If the method is shared between Shape and MutShapeView (like `Set()`) * please add it to MutShapeMixIn. * 3. Otherwise, add it to a concrete class (Shape, ShapeView or MutShapeView). * */ template struct ConstShapeMixIn { using DimType = int64_t; int64_t NumAxes() const { return tp()->size(); } int64_t elem_cnt() const; int64_t At(int64_t index) const; int64_t Count(int64_t begin_axis, int64_t end_axis) const; int64_t Count(int64_t begin_axis) const; bool Containing(ShapeView small_shape) const; bool MatchBeforeLastDim(ShapeView next_shape) const; std::string ToString() const; std::string DebugStr() const; void ToProto(ShapeProto* ret) const; template void SerializeWithTextFormat(StreamT& out_stream) const { for (int64_t dim : *this) { out_stream << std::to_string(dim) << ' '; } } bool operator==(const T& rhs) const; protected: // tp means "this pointer" T* tp() { return static_cast(this); } const T* tp() const { return static_cast(this); } }; template struct MutShapeMixIn : public ConstShapeMixIn { void Set(int64_t index, int64_t val) { CHECK_GE(index, 0); CHECK_LT(index, this->tp()->NumAxes()) << " Shape: " << this->tp()->DebugStr() << " visit index: " << index << " > num_axes: " << this->tp()->NumAxes(); (*this->tp())[index] = val; } }; class Shape final : public DimVector, public MutShapeMixIn { public: // OF_DISALLOW_COPY_AND_MOVE(Shape); using DimVector::DimVector; Shape() : is_initialized_(false) {} explicit Shape(const DimVector& dim_vec); explicit Shape(DimVector&& dim_vec); explicit Shape(const ShapeProto& shape_proto); // explicit constructor from ShapeView explicit Shape(ShapeView shape_view); ~Shape() = default; using DimVector::operator==; #define OVERRIDE_ADD_DATA_FUNC(func) \ template \ void func(Args... 
args) { \ DimVector::func(std::forward(args)...); \ is_initialized_ = true; \ } OVERRIDE_ADD_DATA_FUNC(assign) OVERRIDE_ADD_DATA_FUNC(push_back) OVERRIDE_ADD_DATA_FUNC(emplace_back) OVERRIDE_ADD_DATA_FUNC(append) OVERRIDE_ADD_DATA_FUNC(insert) OVERRIDE_ADD_DATA_FUNC(resize) #undef OVERRIDE_ADD_DATA_FUNC Shape& CheckNumAxesIdenticalAndAssign(ShapeView shape_view); Shape& LeftOnesExtendedAssign(ShapeView shape_view); // Getters and Setters bool is_initialized() const { return is_initialized_; } const DimVector& dim_vec() const { return *this; } DimVector& dim_vec() { return *this; } int64_t NumAxes() const { CHECK(is_initialized()); return ConstShapeMixIn::NumAxes(); } AxisVector ShiftNegativeAxisVec(const AxisVector& axis_vec) const; Shape RemoveOnes(const AxisVector& axis_vec) const; static Shape Ones(const int64_t num_axes); AxisVector Axes4BroadcastTo(ShapeView broadcast_dim_vec) const; Maybe Slice(int64_t start_dim, int64_t end_dim) const; Maybe Slice(int64_t start_dim) const; bool operator==(const Shape& rhs) const; private: // Set default value here because some constructors are inherited from DimVector // TODO(daquexian): remove this field and make it initializied by construction bool is_initialized_ = true; }; int64_t ShiftNegativeAxis(int64_t axis, const int64_t num_axes); Shape CreateReducedShape(ShapeView shape, const AxisVector& axis_vec); Shape CreateLeftExtendedShape(ShapeView shape, int ndims_extend_to); Shape ExpandDimIf0D(const Shape& shape); Shape ExpandDimIf0D(ShapeView shape); Shape CreateReducedShapeOrOnesShape(ShapeView shape, const AxisVector& axis_vec); std::ostream& operator<<(std::ostream& out, const Shape& shape); } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::Shape& shape) const { if (!shape.is_initialized()) { return 0; } size_t ret = shape.NumAxes(); FOR_RANGE(int, i, 0, shape.NumAxes()) { oneflow::AddHash(&ret, shape.At(i)); } return ret; } }; } // namespace std #endif // ONEFLOW_CORE_COMMON_SHAPE_H_ ================================================ FILE: oneflow/core/common/shape.proto ================================================ syntax = "proto2"; package oneflow; // NOTE: shape.proto can be replaced with sequential.proto // for compatibility reasons, it will not be modified here. message ShapeProto { repeated int64 dim = 1; } ================================================ FILE: oneflow/core/common/shape_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
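Following the NOTE in shape.h above, read-only shape parameters are usually taken as ShapeView by value and rank-preserving mutation as MutShapeView; both accept a Shape through the implicit conversions declared in shape_view.h further below. A hypothetical sketch of that convention; the function names are illustrative only:

    // Read-only access: accepts Shape or ShapeView, passed by value (two words).
    int64_t InnerCount(oneflow::ShapeView shape) { return shape.Count(1); }

    // Mutates dimensions without changing the rank.
    void FillWithOnes(oneflow::MutShapeView shape) {
      for (int64_t i = 0; i < shape.NumAxes(); ++i) { shape.Set(i, 1); }
    }

    void Demo() {
      oneflow::Shape shape({4, 2, 3});
      int64_t inner = InnerCount(shape);  // 6
      FillWithOnes(shape);                // shape becomes (1, 1, 1)
    }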
*/ #include "oneflow/core/common/shape.h" #include "gtest/gtest.h" #include #include namespace oneflow { namespace test { TEST(Shape, constructor_0) { Shape a; ASSERT_EQ(a.is_initialized(), false); } TEST(Shape, function_test_1) { Shape shape({4096, 16, 197, 197}); ASSERT_EQ(shape.is_initialized(), true); ASSERT_EQ(shape.NumAxes(), 4); ASSERT_EQ(shape.elem_cnt(), 2543386624); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/common/shape_vec.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_SHAPE_VEC_H_ #define ONEFLOW_CORE_COMMON_SHAPE_VEC_H_ #include "oneflow/core/common/small_vector.h" namespace oneflow { #define SHAPE_MAX_AXIS_SIZE 20 typedef small_vector DimVector; typedef small_vector AxisVector; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_SHAPE_VEC_H_ ================================================ FILE: oneflow/core/common/shape_view.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape.pb.h" #include "oneflow/core/common/shape_view.h" namespace oneflow { void ShapeView::ToDimVector(DimVector* dim_vec) const { dim_vec->resize(this->size()); dim_vec->assign(this->data(), this->data() + this->size()); } void ShapeView::ToShape(Shape* shape) const { DimVector dim_vec; this->ToDimVector(&dim_vec); *shape = Shape(dim_vec); } std::ostream& operator<<(std::ostream& out, ShapeView shape) { out << shape.ToString(); return out; } void MutShapeView::set_shape(ShapeView shape) { if (shape.ptr() == mut_ptr() && shape.NumAxes() == NumAxes()) { return; } CHECK_EQ(NumAxes(), shape.NumAxes()); std::copy(shape.ptr(), shape.ptr() + shape.NumAxes(), mut_ptr()); } } // namespace oneflow ================================================ FILE: oneflow/core/common/shape_view.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_REGISTER_SHAPE_VIEW_H_ #define ONEFLOW_CORE_REGISTER_SHAPE_VIEW_H_ #include "oneflow/core/common/array_ref.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape.h" namespace oneflow { class ShapeProto; class Shape; class ShapeView : public ArrayRef, public ConstShapeMixIn { public: ShapeView() = default; // NOLINTNEXTLINE ShapeView(const ShapeProto& shape_proto) : ArrayRef(shape_proto.dim().data(), shape_proto.dim_size()){}; // NOLINTNEXTLINE ShapeView(const Shape& shape) : ArrayRef(shape.dim_vec().data(), shape.dim_vec().size()){}; using ArrayRef::ArrayRef; const DimType* ptr() const { return this->data(); } void ToDimVector(DimVector* dim_vec) const; void ToShape(Shape* shape) const; }; std::ostream& operator<<(std::ostream& out, ShapeView shape); class MutShapeView final : public MutableArrayRef, public MutShapeMixIn { public: using MutableArrayRef::MutableArrayRef; // NOLINTNEXTLINE MutShapeView(Shape& shape) : MutableArrayRef(shape.dim_vec().data(), shape.dim_vec().size()){}; int64_t* mut_ptr() const { return this->data(); } void set_shape(ShapeView shape); }; } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_SHAPE_VIEW_H_ ================================================ FILE: oneflow/core/common/shared_or_scalar.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_SHARED_OR_SCALAR_H_ #define ONEFLOW_CORE_COMMON_SHARED_OR_SCALAR_H_ #include #include "oneflow/core/common/throw.h" #include "oneflow/core/common/type_traits.h" #include "oneflow/core/common/preprocessor.h" namespace oneflow { template class SharedOrScalar final { public: static_assert(IsScalarType::value, "ScalarT should be scalar type."); using Shared = std::shared_ptr; SharedOrScalar(const ScalarT& scalar_value) : is_scalar_(true), scalar_value_(scalar_value) {} SharedOrScalar(const std::shared_ptr& shared_ptr) : is_scalar_(false) { new (&shared_mem_) Shared(shared_ptr); } SharedOrScalar(std::shared_ptr&& shared_ptr) : is_scalar_(false) { new (&shared_mem_) Shared(std::move(shared_ptr)); } SharedOrScalar(const SharedOrScalar& rhs) : is_scalar_(rhs.is_scalar_) { if (rhs.is_scalar_) { scalar_value_ = rhs.scalar_value_; } else { new (&shared_mem_) Shared(rhs.GetShared()); } } SharedOrScalar(SharedOrScalar&& rhs) : is_scalar_(rhs.is_scalar_) { if (rhs.is_scalar_) { scalar_value_ = rhs.scalar_value_; } else { new (&shared_mem_) Shared(std::move(*rhs.MutableShared())); } } SharedOrScalar& operator=(const SharedOrScalar& rhs) { if (rhs.is_scalar_) { scalar_value_ = rhs.scalar_value_; } else { if (is_scalar_) { scalar_value_.~ScalarT(); new (&shared_mem_) Shared(rhs.GetShared()); } else { *MutableShared() = rhs.GetShared(); } } is_scalar_ = rhs.is_scalar_; return *this; } SharedOrScalar& operator=(SharedOrScalar&& rhs) { if (rhs.is_scalar_) { scalar_value_ = rhs.scalar_value_; } else { if (is_scalar_) { scalar_value_.~ScalarT(); new (&shared_mem_) Shared(std::move(*rhs.MutableShared())); } else { *MutableShared() = std::move(*rhs.MutableShared()); } } is_scalar_ = rhs.is_scalar_; return *this; } ~SharedOrScalar() { if (is_scalar_) { scalar_value_.~ScalarT(); } else { GetShared().~Shared(); } } bool IsScalar() const { return is_scalar_; } const ScalarT& scalar_value() const { CHECK(is_scalar_); return scalar_value_; } const std::shared_ptr& shared_ptr() const { CHECK(!is_scalar_); return GetShared(); } const ScalarT& operator*() const { return scalar_value(); } private: bool is_scalar_; union { ScalarT scalar_value_; // to avoid error(a non-POD class definition is not allowed inside of a statement expression) // in nvcc while using with JUST macro (this type is used in Maybe) alignas(Shared) char shared_mem_[sizeof(Shared)]; }; const Shared& GetShared() const { const auto* __attribute__((__may_alias__)) shared = reinterpret_cast(&shared_mem_); return *shared; } Shared* MutableShared() { auto* __attribute__((__may_alias__)) shared = reinterpret_cast(&shared_mem_); return shared; } }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_SHARED_OR_SCALAR_H_ ================================================ FILE: oneflow/core/common/single_thread_obj_pool.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
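SharedOrScalar above is a small tagged union holding either a plain scalar or a std::shared_ptr to one, avoiding an allocation in the common scalar case. A brief sketch, assuming the element type is int64_t:

    oneflow::SharedOrScalar<int64_t> scalar_case(42);
    oneflow::SharedOrScalar<int64_t> shared_case(std::make_shared<int64_t>(7));
    if (scalar_case.IsScalar()) { int64_t v = *scalar_case; }          // operator* yields the scalar
    if (!shared_case.IsScalar()) { int64_t w = *shared_case.shared_ptr(); }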
*/ #ifndef ONEFLOW_CORE_COMMON_SINGLE_THREAD_OBJ_POOL_H_ #define ONEFLOW_CORE_COMMON_SINGLE_THREAD_OBJ_POOL_H_ #include #include #include #include #include "oneflow/core/common/throw.h" #include "oneflow/core/common/cpp_attribute.h" namespace oneflow { namespace obj_pool { enum ReuseStrategy { kEnableReconstruct, kDisableReconstruct, }; // object pool for single thread. template class SingleThreadObjPool : public std::enable_shared_from_this> { public: SingleThreadObjPool() : pool_(), invalid_thread_id_(), owner_thread_id_(invalid_thread_id_) { pool_.reserve(kInitPoolCap); } ~SingleThreadObjPool() { if (reuse_strategy != kEnableReconstruct) { for (T* ptr : pool_) { delete ptr; } } } template std::shared_ptr make_shared(Args&&... args) { auto* ptr = New(std::forward(args)...); std::weak_ptr pool(this->shared_from_this()); return std::shared_ptr(ptr, [pool](T* ptr) { TryPut(pool.lock(), ptr); }); } private: static constexpr int kInitPoolCap = 1024; template T* New(Args&&... args) { if (likely(pool_.size())) { auto* ptr = Get(); if (reuse_strategy == kEnableReconstruct) { new (ptr) T(std::forward(args)...); } return ptr; } return new T(std::forward(args)...); } static void TryPut(const std::shared_ptr& pool, T* object) { if (likely(static_cast(pool))) { pool->Put(object); } else { object->~T(); } } T* Get() { CheckOrSetSingleThreadFlag(); auto* ptr = pool_[pool_.size() - 1]; pool_.pop_back(); return ptr; } void Put(T* obj) { CheckOrSetSingleThreadFlag(); pool_.push_back(obj); if (reuse_strategy == kEnableReconstruct) { obj->~T(); } } // Try to detect being wrongly used by multi threads, because SingleThreadObjPool does not // guarantee thread safety. This function also is not thread safe, but it's not a big problem. In // the most cases, bugs will be successfully detected even thread unsafe behaviors happen. void CheckOrSetSingleThreadFlag() { if (unlikely(owner_thread_id_ == invalid_thread_id_)) { owner_thread_id_ = std::this_thread::get_id(); } else { CHECK(likely(owner_thread_id_ == std::this_thread::get_id())); } } std::vector pool_; std::thread::id invalid_thread_id_; std::thread::id owner_thread_id_; }; } // namespace obj_pool } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_SINGLE_THREAD_OBJ_POOL_H_ ================================================ FILE: oneflow/core/common/single_thread_obj_pool_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "gtest/gtest.h" #include "oneflow/core/common/single_thread_obj_pool.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace obj_pool { namespace test { TEST(SingleThreadObjPool, naive) { auto pool = std::make_shared>(); auto* ptr = pool->make_shared().get(); ASSERT_EQ(ptr, pool->make_shared().get()); } struct Int { // NOLINT Int() : x(0) {} explicit Int(int val) : x(val) {} ~Int() { x = 0; } int x; }; TEST(SingleThreadObjPool, enable_reconstruct) { auto pool = std::make_shared>(); (void)pool->make_shared(333); ASSERT_EQ(0, pool->make_shared()->x); } TEST(SingleThreadObjPool, disable_reconstruct) { auto pool = std::make_shared>(); int value = pool->make_shared(333)->x; ASSERT_EQ(value, pool->make_shared()->x); } } // namespace test } // namespace obj_pool } // namespace oneflow ================================================ FILE: oneflow/core/common/singleton.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_SINGLETON_H_ #define ONEFLOW_CORE_COMMON_SINGLETON_H_ #include "oneflow/core/common/throw.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/constant.h" namespace oneflow { template class Singleton final { public: static T* Get() { return *GetPPtr(); } static void SetAllocated(T* val) { *GetPPtr() = val; } template static T* New(Args&&... args) { CHECK(Get() == nullptr); VLOG(3) << "NewGlobal " << typeid(T).name(); T* ptr = new T(std::forward(args)...); *GetPPtr() = ptr; return ptr; } static void Delete() { if (Get() != nullptr) { VLOG(3) << "DeleteGlobal " << typeid(T).name(); delete Get(); *GetPPtr() = nullptr; } } private: static T** GetPPtr() { CheckKind(); static T* ptr = nullptr; return &ptr; } static void CheckKind() { if (!std::is_same::value) { CHECK(Singleton::Get() == nullptr) << typeid(Singleton).name() << " are disable for avoiding misuse"; } } }; template Maybe SingletonMaybe() { CHECK_NOTNULL_OR_RETURN((Singleton::Get())) << " typeid: " << typeid(T).name(); return Singleton::Get(); } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_SINGLETON_H_ ================================================ FILE: oneflow/core/common/sized_buffer_view.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
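Singleton above provides one explicitly managed global instance per type; unlike a Meyers singleton it must be created and destroyed by hand. A sketch with a placeholder type:

    struct FooGlobal { int value = 0; };             // hypothetical type, for illustration only
    oneflow::Singleton<FooGlobal>::New();            // allocate and register the instance
    oneflow::Singleton<FooGlobal>::Get()->value = 42;
    oneflow::Singleton<FooGlobal>::Delete();         // must be torn down explicitly
    // SingletonMaybe<FooGlobal>() reports an error instead of returning a null
    // pointer when the instance has not been created yet.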
*/ #ifndef ONEFLOW_COMMON_SIZED_BUFFER_VIEW_H_ #define ONEFLOW_COMMON_SIZED_BUFFER_VIEW_H_ namespace oneflow { struct SizedBufferView { size_t capacity; // allocated memory size for `data' field size_t size; // valid data size char data[0]; }; } // namespace oneflow #endif // ONEFLOW_COMMON_SIZED_BUFFER_VIEW_H_ ================================================ FILE: oneflow/core/common/small_vector.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_SMALL_VECTOR_H_ #define ONEFLOW_CORE_COMMON_SMALL_VECTOR_H_ #include "llvm/ADT/SmallVector.h" #include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/common/check.h" namespace oneflow { template class small_vector : public llvm::SmallVector { using Base = llvm::SmallVector; public: constexpr static size_t kInitialSize = N; // https://stackoverflow.com/questions/27954940/a-using-statement-compiles-with-g-fails-compilation-with-clang using Base::Base; typename Base::reference at(typename Base::size_type idx) { GLOGCHECK(idx < Base::size()); return (*this)[idx]; } typename Base::const_reference at(typename Base::size_type idx) const { GLOGCHECK(idx < Base::size()); return (*this)[idx]; } typename Base::reference operator[](typename Base::size_type idx) { return this->data()[idx]; } typename Base::const_reference operator[](typename Base::size_type idx) const { return this->data()[idx]; } typename Base::const_iterator cbegin() const { return (typename Base::const_iterator)this->BeginX; } typename Base::const_iterator cend() const { return (typename Base::const_iterator)(this->BeginX) + Base::size(); } typename Base::const_iterator cbegin() { return (typename Base::const_iterator)this->BeginX; } typename Base::const_iterator cend() { return (typename Base::const_iterator)(this->BeginX) + Base::size(); } }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_SMALL_VECTOR_H_ ================================================ FILE: oneflow/core/common/spin_counter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
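small_vector above wraps llvm::SmallVector with bounds-checked at() and const-correct iteration helpers; DimVector and AxisVector in shape_vec.h are aliases of it. A tiny sketch, assuming the usual <element type, inline capacity> template parameters:

    oneflow::small_vector<int64_t, 4> dims{2, 3, 4};
    int64_t d0 = dims.at(0);        // GLOGCHECK-guarded access
    dims.push_back(5);              // stays inline up to the reserved capacity (4 here)
    // dims.at(10);                 // would trip the bounds check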
*/ #include #include "oneflow/core/common/spin_counter.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/foreign_lock_helper.h" namespace oneflow { Maybe SpinCounter::WaitUntilCntEqualZero() const { return Singleton::Get()->WithScopedRelease([&]() -> Maybe { while (cnt_val_ > 0) {} return Maybe::Ok(); }); } } // namespace oneflow ================================================ FILE: oneflow/core/common/spin_counter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_SPIN_COUNTER_H_ #define ONEFLOW_CORE_COMMON_SPIN_COUNTER_H_ #include #include "oneflow/core/common/maybe.h" namespace oneflow { class SpinCounter final { public: SpinCounter() = delete; SpinCounter(const SpinCounter&) = delete; SpinCounter(SpinCounter&&) = delete; ~SpinCounter() = default; explicit SpinCounter(int64_t cnt_val) : cnt_val_(cnt_val) {} int64_t Decrease() { return --cnt_val_; } void Reset(int64_t cnt_val) { cnt_val_ = cnt_val; } Maybe WaitUntilCntEqualZero() const; private: std::atomic cnt_val_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_SPIN_COUNTER_H_ ================================================ FILE: oneflow/core/common/static_check.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_STATIC_CHECK_H_ #define ONEFLOW_CORE_COMMON_STATIC_CHECK_H_ #include "type_traits.h" namespace oneflow { namespace private_details { template class Predicator> struct StaticReduce { template struct All; template struct All { static_assert(std::is_same::value, ""); static constexpr bool value = true; }; template struct All { static constexpr bool value = Predicator::value && All::value; }; template struct Any; template struct Any { static_assert(std::is_same::value, ""); static constexpr bool value = false; }; template struct Any { static constexpr bool value = Predicator::value || Any::value; }; }; } // namespace private_details template class Predicator, typename... Args> struct StaticAll { static constexpr bool value = private_details::StaticReduce::template All::value; }; template class Predicator, typename... 
Args> struct StaticAny { static constexpr bool value = private_details::StaticReduce::template Any::value; }; template struct IsOutArg { static constexpr bool value = (std::is_reference::value && !std::is_const::type>::value) || (std::is_pointer::value && !std::is_const::type>::value); }; template struct IsDecayedScalarType { static constexpr bool value = IsScalarType::type>::value; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_STATIC_CHECK_H_ ================================================ FILE: oneflow/core/common/static_global.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_STATIC_GLOBAL_H_ #define ONEFLOW_CORE_COMMON_STATIC_GLOBAL_H_ #include #include "oneflow/core/common/decorator.h" namespace oneflow { template struct StaticGlobalCopiable; template struct StaticGlobalCopiable { template static RetT Call() { static RetT value = func(); return value; } }; template struct StaticGlobalCopiable { template static RetT Call(Arg0 arg0) { using KeyT = typename std::decay::type; using MappedT = typename std::decay::type; static std::mutex mutex; static std::unordered_map map; { std::unique_lock lock(mutex); auto iter = map.find(arg0); if (iter != map.end()) { return iter->second; } } auto obj = func(arg0); { std::unique_lock lock(mutex); return map.emplace(arg0, std::move(obj)).first->second; } } private: static_assert(!IsOutArg::value, ""); static_assert(!StaticAny::value, ""); }; template struct StaticGlobalCopiable { template static RetT Call(Arg0 arg0, Arg1 arg1, Args... args) { using KeyT0 = typename std::decay::type; using KeyT1 = typename std::decay::type; using KeyT = std::tuple::type...>; using MappedT = typename std::decay::type; static std::mutex mutex; static std::unordered_map map; const auto& key = KeyT(arg0, arg1, args...); { std::unique_lock lock(mutex); auto iter = map.find(key); if (iter != map.end()) { return iter->second; } } auto obj = func(arg0, arg1, args...); { std::unique_lock lock(mutex); return map.emplace(key, std::move(obj)).first->second; } } private: static_assert(!StaticAny::value, ""); }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_STATIC_GLOBAL_H_ ================================================ FILE: oneflow/core/common/steady_vector.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
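StaticAll and StaticAny above fold a unary type predicate over a parameter pack at compile time; static_global.h uses them to reject out-parameters in cached global constructors. Two illustrative static_asserts:

    static_assert(oneflow::StaticAll<oneflow::IsDecayedScalarType, int, double>::value,
                  "every argument decays to a scalar type");
    static_assert(!oneflow::StaticAny<oneflow::IsOutArg, const int&, double>::value,
                  "no argument is a mutable reference or pointer");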
*/ #ifndef ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_ #define ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_ #include #include #include #include #include "oneflow/core/common/throw.h" namespace oneflow { template class SteadyVector { public: SteadyVector() : size_(0) {} ~SteadyVector() = default; using value_type = const T; using size_type = size_t; // thread safe. size_t size() const { return size_.load(std::memory_order_acquire); } // thread safe. const T& at(size_t index) const { CHECK_GE(index, 0); CHECK_LT(index, size_); return (*this)[index]; } // thread safe. const T& operator[](size_t index) const { int gran = 0; size_t start = 0; GetGranularityAndStart(index, &gran, &start); return granularity2data_[gran].get()[index - start]; } // `index` should be <= size() void SetOrAdd(size_t index, T value) { std::unique_lock lock(mutex_); size_t size = size_.load(std::memory_order_relaxed); CHECK_LE(index, size) << "index out of range"; if (index == size) { int granularity = GetGranularity(size); if (size + 1 == (1 << granularity)) { CHECK_LT(granularity, N); granularity2data_[granularity].reset(new T[1 << granularity]); } *Mutable(index) = std::move(value); size_.fetch_add(1, std::memory_order_release); } else { *Mutable(index) = std::move(value); } } void push_back(const T& elem) { SetOrAdd(size_, elem); } private: T* Mutable(size_t index) { int gran = 0; size_t start = 0; GetGranularityAndStart(index, &gran, &start); return &granularity2data_[gran].get()[index - start]; } static void GetGranularityAndStart(size_t index, int* gran, size_t* start) { *gran = GetGranularity(index); *start = (1 << *gran) - 1; } #ifdef __GNUC__ #define LOG2(x) ((unsigned)(8 * sizeof(unsigned long long) - __builtin_clzll((x)) - 1)) #else #define LOG2(x) std::log2(x) #endif static int GetGranularity(size_t index) { return LOG2(index + 1); } #undef LOG2 std::atomic size_; std::mutex mutex_; std::array, N> granularity2data_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_ ================================================ FILE: oneflow/core/common/steady_vector_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/common/steady_vector.h" namespace oneflow { namespace test { void TestSteadyVector(int granularity) { CHECK_GT(granularity, 0); SteadyVector vec; ASSERT_EQ(vec.size(), 0); for (int i = 0; i < (1 << granularity); ++i) { vec.push_back(i); ASSERT_EQ(vec.at(i), i); ASSERT_EQ(vec.size(), i + 1); } } TEST(SteadyVector, simple) { TestSteadyVector(6); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/common/str_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
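SteadyVector above is an append-only vector whose reads take no lock and whose elements are never relocated (storage grows in power-of-two granules), so indices and references remain valid while other threads append. A short sketch, assuming the granularity template parameter has a default:

    oneflow::SteadyVector<int> vec;
    vec.push_back(10);          // appends under the internal mutex
    vec.SetOrAdd(1, 20);        // the index may be at most size()
    int first = vec[0];         // lock-free read of a stable element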
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/common/str_util.h" namespace oneflow { namespace internal { std::string JoinPathImpl(std::initializer_list paths) { std::string result; for (const std::string& path : paths) { if (path.empty()) continue; if (result.empty()) { result = path; continue; } if (result[result.size() - 1] == '/') { if (IsAbsolutePath(path)) { result.append(path.substr(1)); } else { result.append(path); } } else { if (IsAbsolutePath(path)) { result.append(path); } else { result += ("/" + path); } } } return result; } std::string GetHashKeyImpl(std::initializer_list integers) { std::string result = ""; for (int integer : integers) { result += std::to_string(integer) + ","; } return result; } } // namespace internal const char* StrToToken(const char* text, const std::string& delims, std::string* token) { token->clear(); while (*text != '\0' && delims.find(*text) != std::string::npos) { text++; } while (*text != '\0' && delims.find(*text) == std::string::npos) { token->push_back(*text++); } return text; } void Split(const std::string& text, const std::string& delims, std::function Func) { size_t token_start = 0; if (text.empty()) { return; } for (size_t i = 0; i < text.size() + 1; ++i) { if ((i == text.size()) || (delims.find(text[i]) != std::string::npos)) { Func(text.substr(token_start, i - token_start)); token_start = i + 1; } } } std::string Dirname(const std::string& path) { size_t found = path.rfind('/'); if (found == std::string::npos) { return ""; } if (found == 0) { return "/"; } return path.substr(0, found); } std::string Basename(const std::string& path) { size_t found = path.rfind('/'); if (found == std::string::npos) { return path; } return path.substr(found + 1); } std::string CleanPath(const std::string& unclean_path) { std::string path = unclean_path; const char* src = path.c_str(); std::string::iterator dst = path.begin(); // Check for absolute path and determine initial backtrack limit. const bool is_absolute_path = *src == '/'; if (is_absolute_path) { *dst++ = *src++; while (*src == '/') ++src; } std::string::const_iterator backtrack_limit = dst; // Process all parts while (*src) { bool parsed = false; if (src[0] == '.') { // 1dot ".", check for END or SEP. if (src[1] == '/' || !src[1]) { if (*++src) { ++src; } parsed = true; } else if (src[1] == '.' && (src[2] == '/' || !src[2])) { // 2dot END or SEP (".." | "../"). src += 2; if (dst != backtrack_limit) { // We can backtrack the previous part for (--dst; dst != backtrack_limit && dst[-1] != '/'; --dst) { // Empty. } } else if (!is_absolute_path) { // Failed to backtrack and we can't skip it either. Rewind and copy. src -= 2; *dst++ = *src++; *dst++ = *src++; if (*src) { *dst++ = *src; } // We can never backtrack over a copied "../" part so set new limit. backtrack_limit = dst; } if (*src) { ++src; } parsed = true; } } // If not parsed, copy entire part until the next SEP or EOS. if (!parsed) { while (*src && *src != '/') { *dst++ = *src++; } if (*src) { *dst++ = *src++; } } // Skip consecutive SEP occurrences while (*src == '/') { ++src; } } // Calculate and check the length of the cleaned path. 
std::string::difference_type path_length = dst - path.begin(); if (path_length != 0) { // Remove trailing '/' except if it is root path ("/" ==> path_length := 1) if (path_length > 1 && path[path_length - 1] == '/') { --path_length; } path.resize(path_length); } else { // The cleaned path is empty; assign "." as per the spec. path.assign(1, '.'); } return path; } void GetPrefixAndIndex(const std::string& prefix_and_idx, std::string* prefix, int32_t* index) { const size_t underline_pos = prefix_and_idx.rfind('_'); CHECK_NE(underline_pos, std::string::npos); CHECK_GT(underline_pos, 0); CHECK_LT(underline_pos, prefix_and_idx.size() - 1); *prefix = prefix_and_idx.substr(0, underline_pos); *index = oneflow_cast(prefix_and_idx.substr(underline_pos + 1)); CHECK_GE(*index, 0); } bool TryGetPrefixAndIndex(const std::string& prefix_and_idx, std::string* prefix, int32_t* index) { const size_t underline_pos = prefix_and_idx.rfind('_'); if (underline_pos == std::string::npos) { return false; } if (underline_pos == 0) { return false; } if (underline_pos == prefix_and_idx.size() - 1) { return false; } *prefix = prefix_and_idx.substr(0, underline_pos); std::string index_str = prefix_and_idx.substr(underline_pos + 1); if (IsStrInt(index_str) == false) { return false; } *index = oneflow_cast(index_str); return *index >= 0; } std::string ToLower(const std::string& cap) { std::string small; std::transform(cap.begin(), cap.end(), small.begin(), [](unsigned char c) { return std::tolower(c); }); return small; } // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c std::string GenAlphaNumericString(size_t len) { static thread_local const std::string alphanum("0123456789" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz"); std::string tmp_s; tmp_s.reserve(len); std::random_device rd{}; std::mt19937 mt(rd()); std::uniform_int_distribution<> dist(0, 1024); for (int i = 0; i < len; ++i) { tmp_s += alphanum.at(dist(mt) % alphanum.size()); } return tmp_s; } } // namespace oneflow ================================================ FILE: oneflow/core/common/str_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
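The string helpers above are pure string manipulation (no filesystem calls); a few expected results, following the implementations above:

    std::string cleaned = oneflow::CleanPath("/foo//./bar/../baz/");   // "/foo/baz"
    std::string dir     = oneflow::Dirname("/foo/baz");                // "/foo"
    std::string base    = oneflow::Basename("/foo/baz");               // "baz"

    std::string prefix;
    int32_t index = 0;
    if (oneflow::TryGetPrefixAndIndex("out_3", &prefix, &index)) {
      // prefix == "out", index == 3
    }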
*/ #ifndef ONEFLOW_CORE_COMMON_STR_UTIL_H_ #define ONEFLOW_CORE_COMMON_STR_UTIL_H_ #include #include #include "oneflow/core/common/util.h" namespace oneflow { inline bool IsStrInt(const std::string& s) { if (s.empty() || (!isdigit(s[0]) && (s[0] != '-'))) { return false; } char* end_ptr = nullptr; strtoll(s.c_str(), &end_ptr, 0); return (*end_ptr == 0); } inline std::string StrCat(const std::string& prefix, int64_t id) { return prefix + std::to_string(id); } inline void StringReplace(std::string* str, char old_ch, char new_ch) { for (size_t i = 0; i < str->size(); ++i) { if (str->at(i) == old_ch) { str->at(i) = new_ch; } } } const char* StrToToken(const char* text, const std::string& delims, std::string* token); void Split(const std::string& text, const std::string& delims, std::function Func); template void SplitAndParseAs(const std::string& text, const std::string& delims, std::function Func) { Split(text, delims, [&Func](std::string&& s) { Func(oneflow_cast(s)); }); } // Return true if path is absolute. inline bool IsAbsolutePath(const std::string& path) { return !path.empty() && path[0] == '/'; } void GetPrefixAndIndex(const std::string& prefix_and_idx, std::string* prefix, int32_t* index); bool TryGetPrefixAndIndex(const std::string& prefix_and_idx, std::string* prefix, int32_t* index); namespace internal { std::string JoinPathImpl(std::initializer_list paths); std::string GetHashKeyImpl(std::initializer_list integers); } // namespace internal // Join multiple paths together, without introducing unnecessary path // separators. // For example: // // Arguments | JoinPath // ---------------------------+---------- // '/foo', 'bar' | /foo/bar // '/foo/', 'bar' | /foo/bar // '/foo', '/bar' | /foo/bar // // Usage: // string path = JoinPath("/mydir", filename); // string path = JoinPath(FLAGS_test_srcdir, filename); // string path = JoinPath("/full", "path", "to", "filename); template std::string JoinPath(const T&... args) { return internal::JoinPathImpl({args...}); } // Returns the part of the path before the final "/". If there is a single // leading "/" in the path, the result will be the leading "/". If there is // no "/" in the path, the result is the empty prefix of the input. std::string Dirname(const std::string& path); // Returns the part of the path after the final "/". If there is no // "/" in the path, the result is the same as the input. std::string Basename(const std::string& path); // Collapse duplicate "/"s, resolve ".." and "." path elements, remove // trailing "/". // // NOTE: This respects relative vs. absolute paths, but does not // invoke any system calls (getcwd(2)) in order to resolve relative // paths with respect to the actual working directory. That is, this is purely // string manipulation, completely independent of process state. std::string CleanPath(const std::string& path); template std::string GetHashKey(const T&... args) { return internal::GetHashKeyImpl({args...}); } std::string ToLower(const std::string& cap); std::string GenAlphaNumericString(size_t len); template const std::string& ReturnEmptyStr(const CallbackT& Callback) { Callback(); static std::string empty{}; return empty; } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_STR_UTIL_H_ ================================================ FILE: oneflow/core/common/stream_type.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_STREAM_TYPE_H_ #define ONEFLOW_CORE_COMMON_STREAM_TYPE_H_ #include #include #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/throw.h" namespace oneflow { enum class StreamType { kInvalid = 0, kCompute, kHost2Device, kDevice2Host, kCcl, kBarrier, kCriticalSection, kLazyJobLauncher, kPinnedCompute }; template struct StreamTypeVisitor { template static auto Visit(StreamType stream_type, Args&&... args) { switch (stream_type) { case StreamType::kInvalid: LOG(FATAL) << "invalid stream type"; case StreamType::kCompute: return DerivedT::VisitCompute(std::forward(args)...); case StreamType::kHost2Device: return DerivedT::VisitHost2Device(std::forward(args)...); case StreamType::kDevice2Host: return DerivedT::VisitDevice2Host(std::forward(args)...); case StreamType::kCcl: return DerivedT::VisitCcl(std::forward(args)...); case StreamType::kBarrier: return DerivedT::VisitBarrier(std::forward(args)...); case StreamType::kCriticalSection: return DerivedT::VisitCriticalSection(std::forward(args)...); case StreamType::kLazyJobLauncher: return DerivedT::VisitLazyJobLauncher(std::forward(args)...); case StreamType::kPinnedCompute: return DerivedT::VisitPinnedCompute(std::forward(args)...); } LOG(FATAL) << "invalid stream type"; } }; } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::StreamType& stream_type) const { return static_cast(stream_type); } }; } // namespace std #endif // ONEFLOW_CORE_COMMON_STREAM_TYPE_H_ ================================================ FILE: oneflow/core/common/stride.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/stride.h" #include "oneflow/core/common/constant.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { Stride::Stride(const ShapeView& shape) { const int64_t ndim = shape.NumAxes(); resize(ndim); if (ndim > 0 && shape.elem_cnt() > 0) { std::exclusive_scan(shape.rbegin(), shape.rend(), rbegin(), (int64_t)1, std::multiplies<>{}); } else if (ndim > 0 && shape.elem_cnt() == 0) { // 0-size shape small_vector tmp_shape(ndim); for (int64_t i = 0; i < ndim; ++i) { tmp_shape[i] = shape.At(i) > 0 ? 
shape.At(i) : 1; } std::exclusive_scan(tmp_shape.rbegin(), tmp_shape.rend(), rbegin(), (int64_t)1, std::multiplies<>{}); } } Stride::Stride(const Shape& shape) { if (shape.is_initialized()) { ShapeView shape_view(shape); new (this) Stride(shape_view); } } Stride::Stride(const std::shared_ptr& shape) : Stride(*shape) {} Stride::Stride(const Int64ListProto& stride_proto) : DimVector(stride_proto.dim().begin(), stride_proto.dim().end()) {} Stride& Stride::CheckNumAxesIdenticalAndAssign(const Stride& stride) { CHECK_EQ(size(), stride.size()); assign(stride); return *this; } std::string Stride::ToString() const { std::stringstream ss; int32_t idx = 0; ss << "("; for (int64_t dim : *this) { ss << dim; if (++idx != this->size() || this->size() == 1) { ss << ","; } } ss << ")"; return ss.str(); } void Stride::ToProto(Int64ListProto* ret) const { *(ret->mutable_dim()) = PbRf(begin(), end()); } std::ostream& operator<<(std::ostream& out, const Stride& stride) { out << stride.ToString(); return out; } } // namespace oneflow ================================================ FILE: oneflow/core/common/stride.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_STRIDE_H_ #define ONEFLOW_CORE_FRAMEWORK_STRIDE_H_ #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/sequential.pb.h" #include "oneflow/core/common/util.h" namespace oneflow { class Int64ListProto; class Stride final : public DimVector { public: Stride() = default; using DimVector::DimVector; explicit Stride(const ShapeView& shape); explicit Stride(const Shape& shape); explicit Stride(const std::shared_ptr& shape); explicit Stride(const Int64ListProto& stride_proto); Stride& CheckNumAxesIdenticalAndAssign(const Stride& stride); ~Stride() = default; std::string ToString() const; void ToProto(Int64ListProto*) const; }; std::ostream& operator<<(std::ostream& out, const Stride& stride); } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::Stride& stride) const { size_t ret = stride.size(); FOR_RANGE(int, i, 0, stride.size()) { oneflow::AddHash(&ret, stride.at(i)); } return ret; } }; } // namespace std #endif // ONEFLOW_CORE_FRAMEWORK_STRIDE_H_ ================================================ FILE: oneflow/core/common/switch_func.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_SWITCH_FUNC_H_ #define ONEFLOW_CORE_COMMON_SWITCH_FUNC_H_ #include "oneflow/core/common/preprocessor.h" #include #include template auto SwitchCase(Args&&... args) { return std::make_tuple(std::forward(args)...); } #define DEFINE_STATIC_SWITCH_FUNC(return_type, func_name, make_switch_entry, ctrv_seq, ...) \ DEFINE_STATIC_SWITCH_FUNC_FROM_TUPLE(return_type, func_name, make_switch_entry, \ OF_PP_CAT((ctrv_seq, ##__VA_ARGS__), )) #define DEFINE_STATIC_SWITCH_FUNC_FROM_TUPLE(return_type, func_name, make_switch_entry, \ ctrv_seq_tuple) \ template \ static return_type Switch##func_name( \ const OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE(ctrv_seq_tuple) & switch_tuple, \ Args && ... args) { \ static const std::map> \ case_handlers{OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE(make_switch_entry, func_name, \ Args, ctrv_seq_tuple)}; \ return case_handlers.at(switch_tuple)(std::forward(args)...); \ } // CTRV: Compile-time Token and Runtime Value pair, // CTRV example: (float, DataType::kFloat) // TYPED_CTRV_SEQ example: (DataType, ((float, DataType::kFloat))) #define MAKE_DATA_TYPE_CTRV_SEQ(data_type_seq) MAKE_TYPED_CTRV_SEQ(DataType, data_type_seq) #define MAKE_DEVICE_TYPE_CTRV_SEQ(device_type_seq) \ MAKE_TYPED_CTRV_SEQ(DeviceType, \ OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, device_type_seq)) #define MAKE_NDIM_CTRV_SEQ(ndim_seq) \ MAKE_TYPED_CTRV_SEQ(int32_t, OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, ndim_seq)) #define MAKE_STRINGIZED_DATA_TYPE_CTRV(data_type_pair) \ (OF_PP_PAIR_FIRST(data_type_pair), OF_PP_STRINGIZE(OF_PP_PAIR_FIRST(data_type_pair))) #define MAKE_STRINGIZED_DATA_TYPE_CTRV_SEQ(data_type_seq) \ (std::string, OF_PP_SEQ_MAP(MAKE_STRINGIZED_DATA_TYPE_CTRV, data_type_seq)) #define MAKE_TYPED_CTRV_SEQ(runtime_value_type, ctrv_pair_seq) (runtime_value_type, ctrv_pair_seq) // internal preprocessor macros #define OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR(switch_case, func_args_type, func) \ {switch_case, \ [](func_args_type&&... 
args) { return func(std::forward(args)...); }}, #define OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ(x) OF_PP_MAKE_TUPLE_SEQ(x, x) #define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_1(make_template_func, func_name, func_args_type, \ switch_case_pair0) \ OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR( \ SwitchCase(OF_PP_PAIR_SECOND(switch_case_pair0)), func_args_type, \ make_template_func(func_name, OF_PP_PAIR_FIRST(switch_case_pair0))) #define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_2(make_template_func, func_name, func_args_type, \ switch_case_pair0, switch_case_pair1) \ OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR( \ SwitchCase(OF_PP_PAIR_SECOND(switch_case_pair0), OF_PP_PAIR_SECOND(switch_case_pair1)), \ func_args_type, \ make_template_func(func_name, OF_PP_PAIR_FIRST(switch_case_pair0), \ OF_PP_PAIR_FIRST(switch_case_pair1))) #define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_3(make_template_func, func_name, func_args_type, \ switch_case_pair0, switch_case_pair1, switch_case_pair2) \ OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR( \ SwitchCase(OF_PP_PAIR_SECOND(switch_case_pair0), OF_PP_PAIR_SECOND(switch_case_pair1), \ OF_PP_PAIR_SECOND(switch_case_pair2)), \ func_args_type, \ make_template_func(func_name, OF_PP_PAIR_FIRST(switch_case_pair0), \ OF_PP_PAIR_FIRST(switch_case_pair1), \ OF_PP_PAIR_FIRST(switch_case_pair2))) #define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_4(make_template_func, func_name, func_args_type, \ switch_case_pair0, switch_case_pair1, switch_case_pair2, \ switch_case_pair3) \ OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR( \ SwitchCase(OF_PP_PAIR_SECOND(switch_case_pair0), OF_PP_PAIR_SECOND(switch_case_pair1), \ OF_PP_PAIR_SECOND(switch_case_pair2), OF_PP_PAIR_SECOND(switch_case_pair3)), \ func_args_type, \ make_template_func(func_name, OF_PP_PAIR_FIRST(switch_case_pair0), \ OF_PP_PAIR_FIRST(switch_case_pair1), OF_PP_PAIR_FIRST(switch_case_pair2), \ OF_PP_PAIR_FIRST(switch_case_pair3))) #define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE(make_switch_entry, func_name, args_type, t) \ OF_PP_FORCE(OF_PP_CAT(OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_, OF_PP_TUPLE_SIZE(t))( \ make_switch_entry, func_name, args_type, t)) #define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_1(make_switch_entry, func_name, args_type, \ ctrv_seq_tuple) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_1, (make_switch_entry), \ (func_name), (args_type), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(0, ctrv_seq_tuple))) #define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_2(make_switch_entry, func_name, args_type, \ ctrv_seq_tuple) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_2, (make_switch_entry), \ (func_name), (args_type), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(0, ctrv_seq_tuple)), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(1, ctrv_seq_tuple))) #define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_3(make_switch_entry, func_name, args_type, \ ctrv_seq_tuple) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_3, (make_switch_entry), \ (func_name), (args_type), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(0, ctrv_seq_tuple)), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(1, ctrv_seq_tuple)), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(2, ctrv_seq_tuple))) #define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_4(make_switch_entry, func_name, args_type, \ ctrv_seq_tuple) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_4, (make_switch_entry), \ (func_name), (args_type), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(0, ctrv_seq_tuple)), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(1, ctrv_seq_tuple)), \ OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(2, ctrv_seq_tuple)), \ 
OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(3, ctrv_seq_tuple))) #define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE(t) \ OF_PP_FORCE(OF_PP_CAT(OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_, OF_PP_TUPLE_SIZE(t))(t)) #define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_1(t) \ std::tuple #define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_2(t) \ std::tuple #define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_3(t) \ std::tuple #define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_4(t) \ std::tuple #endif // ONEFLOW_CORE_COMMON_SWITCH_FUNC_H_ ================================================ FILE: oneflow/core/common/symbol.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_SYMBOL_H_ #define ONEFLOW_CORE_COMMON_SYMBOL_H_ #include #include #include #include #include "oneflow/core/common/type_traits.h" #include "oneflow/core/common/check.h" #include "oneflow/core/common/hash_eq_trait_ptr.h" namespace oneflow { template struct SymbolUtil; template class Symbol final { public: Symbol() : ptr_(nullptr) {} Symbol(const T& obj) : ptr_(GetOrCreatePtr(obj)) {} Symbol(const Symbol& rhs) = default; Symbol(Symbol&& rhs) = default; ~Symbol() = default; explicit operator bool() const { return ptr_ != nullptr; } const T* operator->() const { return ptr_; } const T& operator*() const { return *ptr_; } bool operator==(const Symbol& rhs) const { return ptr_ == rhs.ptr_; } bool operator!=(const Symbol& rhs) const { return !(*this == rhs); } size_t hash_value() const { return std::hash()(ptr_); } Symbol& operator=(const Symbol& other) { ptr_ = other.ptr_; return *this; } void reset() { ptr_ = nullptr; } void reset(const T& obj) { ptr_ = GetOrCreatePtr(obj); } const std::shared_ptr& shared_from_symbol() const; private: template friend struct SymbolUtil; static const T* GetOrCreatePtr(const T& obj); const T* ptr_; }; template struct IsScalarType> final { static const bool value = true; }; template struct SymbolUtil final { using SymbolMap = std::unordered_map, std::shared_ptr>; static SymbolMap* GlobalSymbolMap() { static SymbolMap symbol_map; return &symbol_map; } static std::mutex* GlobalSymbolMapMutex() { static std::mutex mutex; return &mutex; } static SymbolMap* ThreadLocalSymbolMap() { static thread_local SymbolMap thread_local_symbol_map; return &thread_local_symbol_map; } static std::unordered_set* ThreadLocalSymbolPtrSet() { static thread_local std::unordered_set thread_local_symbol_ptr_set; return &thread_local_symbol_ptr_set; } template static const std::shared_ptr& LocalThreadGetOr(const T& obj) { auto* thread_local_symbol_map = ThreadLocalSymbolMap(); size_t hash_value = std::hash()(obj); HashEqTraitPtr obj_ptr_wraper(&obj, hash_value); const auto& local_iter = thread_local_symbol_map->find(obj_ptr_wraper); if (local_iter != thread_local_symbol_map->end()) { return local_iter->second; } const auto& iter = GetIter4ObjectAndHashValue(obj, hash_value); (*thread_local_symbol_map)[iter->first] = iter->second; 
GLOGCHECK(ThreadLocalSymbolPtrSet()->emplace(iter->second.get()).second); return iter->second; } static typename SymbolMap::iterator FindGlobalSymbol(const T& obj, size_t hash_value) { HashEqTraitPtr new_obj_ptr_wraper(&obj, hash_value); auto* symbol_map = GlobalSymbolMap(); std::unique_lock lock(*GlobalSymbolMapMutex()); const auto& iter = symbol_map->find(new_obj_ptr_wraper); GLOGCHECK(iter != symbol_map->end()); return iter; } static const std::shared_ptr& SharedFromObject(const T& obj) { return LocalThreadGetOr(obj); } static typename SymbolMap::iterator CreateGlobalSymbol(const T& obj, size_t hash_value) { std::shared_ptr ptr(new T(obj)); HashEqTraitPtr new_obj_ptr_wraper(ptr.get(), hash_value); std::unique_lock lock(*GlobalSymbolMapMutex()); return GlobalSymbolMap()->emplace(new_obj_ptr_wraper, ptr).first; } static const std::shared_ptr& GetOrCreatePtr(const T& obj) { return LocalThreadGetOr(obj); } }; template const std::shared_ptr& Symbol::shared_from_symbol() const { if (this->ptr_ == nullptr) { static auto* none = new std::shared_ptr(); return *none; } return SymbolUtil::SharedFromObject(*this->ptr_); } template const T* Symbol::GetOrCreatePtr(const T& obj) { return SymbolUtil::GetOrCreatePtr(obj).get(); } template Symbol SymbolOf(const T& obj) { return Symbol(obj); } } // namespace oneflow namespace std { template struct hash> final { size_t operator()(const oneflow::Symbol& symbol) const { return symbol.hash_value(); } }; } // namespace std #endif // ONEFLOW_CORE_COMMON_SYMBOL_H_ ================================================ FILE: oneflow/core/common/symbol_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace test { namespace detail { class SymObject { public: SymObject(const std::string& name) : name_(name) {} const std::string& name() const { return name_; } bool operator==(const SymObject& other) const { return name_ == other.name_; } private: std::string name_; }; } // namespace detail TEST(Symbol, shared_from_symbol) { Symbol symbol(detail::SymObject("SymbolObjectFoo")); ASSERT_TRUE(symbol.shared_from_symbol().get() == SymbolOf(detail::SymObject("SymbolObjectFoo")).shared_from_symbol().get()); } } // namespace test } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::test::detail::SymObject& sym_object) const { return std::hash()(sym_object.name()); } }; } // namespace std ================================================ FILE: oneflow/core/common/tensor_buffer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/memory/memory_allocator.h" namespace oneflow { namespace detail { static constexpr double kDefaultGrowthFactor = 1.0f; static constexpr double kDefaultShrinkFactor = 0.7f; static constexpr size_t kDefaultTensorBufferAlignedSize = 1024; size_t GetTensorBufferAlignedSize(size_t origin_size, double factor) { static size_t aligned_size = ParseIntegerFromEnv("ONEFLOW_TENSOR_BUFFER_ALIGNED_SIZE", kDefaultTensorBufferAlignedSize); return RoundUp(static_cast(origin_size * factor), aligned_size); } size_t GetTensorBufferGrowthSize(size_t origin_size) { static double factor = ParseFloatFromEnv("ONEFLOW_TENSOR_BUFFER_GROWTH_FACTOR", kDefaultGrowthFactor); return GetTensorBufferAlignedSize(origin_size, factor); } size_t GetTensorBufferShrinkSize(size_t origin_size) { static double factor = ParseFloatFromEnv("ONEFLOW_TENSOR_BUFFER_SHRINK_FACTOR", kDefaultShrinkFactor); return GetTensorBufferAlignedSize(origin_size, factor); } void CheckTensorBufferDataType(DataType val) { CHECK(val != DataType::kTensorBuffer && val != DataType::kOFRecord) << "TensorBuffer only support POD as internal data type."; } void TensorBufferImpl::Reset(const Shape& shape, DataType dtype) { int64_t elem_cnt = shape.elem_cnt(); if (dtype == DataType::kInvalidDataType || elem_cnt == 0) { return; } CheckTensorBufferDataType(dtype); if (shape == shape_ && dtype == data_type_) { return; } shape_ = shape; data_type_ = dtype; size_t new_buffer_size = elem_cnt * GetSizeOfDataType(dtype); Reserve(new_buffer_size); } void TensorBufferImpl::Reset(const Shape& shape) { Reset(shape, data_type_); } void TensorBufferImpl::Reset(DataType dtype) { CheckTensorBufferDataType(dtype); if (dtype == DataType::kInvalidDataType) { Reset(); } else { Reset(shape_, dtype); } } void TensorBufferImpl::Reset() { shape_ = Shape(); data_type_ = DataType::kInvalidDataType; DeallocateBuffer(); } void TensorBufferImpl::AllocateBuffer(size_t size) { CHECK(buffer_ == nullptr); buffer_ = MemoryAllocatorImpl::AllocateUnPinnedHostMem(size); buffer_size_ = size; } void TensorBufferImpl::DeallocateBuffer() { if (buffer_) { MemoryAllocatorImpl::DeallocateUnPinnedHostMem(buffer_); } buffer_ = nullptr; buffer_size_ = 0; } void TensorBufferImpl::Reserve(size_t new_size) { if (new_size > buffer_size_) { size_t growth_size = std::max(new_size, GetTensorBufferGrowthSize(new_size)); DeallocateBuffer(); AllocateBuffer(growth_size); } else { size_t shrink_size = GetTensorBufferShrinkSize(buffer_size_); if (new_size <= shrink_size) { DeallocateBuffer(); AllocateBuffer(shrink_size); } } } void TensorBufferImpl::CopyFrom(const TensorBufferImpl* src) { if (src == this) { return; } Reset(src->shape(), src->data_type()); memcpy(buffer_, src->buffer(), buffer_size_); } void TensorBufferImpl::Swap(TensorBufferImpl* other) { std::swap(buffer_, other->buffer_); std::swap(buffer_size_, other->buffer_size_); std::swap(shape_, other->shape_); std::swap(data_type_, other->data_type_); } } // namespace detail TensorBuffer::~TensorBuffer() { if (auto* pool = TensorBufferPool::TryGet()) { pool->Deallocate(&impl_); } } 
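// Illustrative usage sketch (a hypothetical call site; the shapes, dtype and variable names are
// only examples). It shows how the TensorBuffer interface declared in tensor_buffer.h is
// typically exercised: construct with a shape and dtype, access typed data, and Reset() to
// reuse or regrow the backing storage. When TensorBufferPool::New() has been called during
// setup, the destructor above returns the buffer to the pool instead of freeing it.
//
//   TensorBuffer buf(Shape({2, 3}), DataType::kFloat);  // storage for 2 * 3 floats
//   float* data = buf.mut_data<float>();                // typed access; verifies the dtype
//   for (int64_t i = 0; i < buf.elem_cnt(); ++i) { data[i] = 0.f; }
//   buf.Reset(Shape({4, 3}), DataType::kFloat);         // reshape, reusing storage when possible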
TensorBuffer::TensorBuffer(const Shape& shape, DataType dtype) { Allocate(shape, dtype); } TensorBuffer& TensorBuffer::operator=(TensorBuffer&& other) noexcept { impl_ = std::move(other.impl_); return *this; } void TensorBuffer::Allocate(const Shape& shape, DataType dtype) { CHECK(!is_allocated()); if (auto* pool = TensorBufferPool::TryGet()) { pool->Allocate(&impl_, shape, dtype); } else { impl_.reset(new detail::TensorBufferImpl(shape, dtype)); } } void TensorBuffer::Reset(const Shape& shape, DataType dtype) { if (is_allocated()) { impl_->Reset(shape, dtype); } else { Allocate(shape, dtype); } } void TensorBuffer::Reset(const Shape& shape) { CHECK(is_allocated()) << "TensorBuffer is not allocated"; impl_->Reset(shape); } void TensorBuffer::Reset(DataType dtype) { CHECK(is_allocated()) << "TensorBuffer is not allocated"; impl_->Reset(dtype); } void TensorBuffer::Reset() { if (impl_) { impl_->Reset(); } } const Shape& TensorBuffer::shape() const { CHECK(is_allocated()) << "TensorBuffer is not allocated"; return impl_->shape(); } DataType TensorBuffer::data_type() const { CHECK(is_allocated()) << "TensorBuffer is not allocated"; return impl_->data_type(); } void* TensorBuffer::raw_data() { CHECK(is_allocated()) << "TensorBuffer is not allocated"; return impl_->buffer(); } const void* TensorBuffer::raw_data() const { CHECK(is_allocated()) << "TensorBuffer is not allocated"; return const_cast(impl_.get())->buffer(); } void TensorBuffer::CopyFrom(const TensorBuffer& src) { CHECK(src.is_allocated()) << "TensorBuffer src is not allocated"; if (!is_allocated()) { Allocate(src.shape(), src.data_type()); } impl_->CopyFrom(src.impl_.get()); } void TensorBuffer::Swap(TensorBuffer& other) { std::swap(impl_, other.impl_); } namespace { constexpr size_t kDefaultPoolSizeBase = 64; constexpr double kDefaultPoolSizeFactor = 2.0; constexpr size_t kDefaultThreadLocalCacheSize = 64; size_t GetTensorBufferPoolSize(size_t base = kDefaultPoolSizeBase) { static double factor = ParseFloatFromEnv("ONEFLOW_TENSOR_BUFFER_POOL_SIZE_FACTOR", kDefaultPoolSizeFactor); return static_cast(std::ceil(base * factor)); } size_t GetTensorBufferPoolThreadLocalCacheSize() { static size_t cache_size = ParseIntegerFromEnv( "ONEFLOW_TENSOR_BUFFER_POOL_THREAD_LOCAL_CACHE_SIZE", kDefaultThreadLocalCacheSize); return cache_size; } } // namespace TensorBufferPool::TensorBufferPool() : thread_local_cache_size_(GetTensorBufferPoolThreadLocalCacheSize()), pool_size_(GetTensorBufferPoolSize()) { auto& thread_local_cache = ThreadLocalCache(); thread_local_cache.reserve(thread_local_cache_size_); global_free_list_.reserve(pool_size_); } void TensorBufferPool::Allocate(ItemT* item, const Shape& shape, DataType dtype) { CHECK(!(*item)) << "TensorBuffer is already allocated"; auto& thread_local_cache = ThreadLocalCache(); if (thread_local_cache.empty() && thread_local_cache_size_ > 0) { std::unique_lock lck(mtx_); if (!global_free_list_.empty()) { // fetch half of thread_local_cache_size of tensor buffers from global free list size_t fetches = thread_local_cache_size_ / 2; auto begin = global_free_list_.size() >= fetches ? 
(global_free_list_.end() - fetches) : global_free_list_.begin(); for (auto it = begin; it < global_free_list_.end(); ++it) { thread_local_cache.push_back(std::move(*it)); } global_free_list_.erase(begin, global_free_list_.end()); } } if (thread_local_cache.empty()) { item->reset(new detail::TensorBufferImpl(shape, dtype)); } else { *item = std::move(thread_local_cache.back()); thread_local_cache.pop_back(); (*item)->Reset(shape, dtype); } } void TensorBufferPool::Deallocate(ItemT* item) { if (!(*item)) { return; } auto& thread_local_cache = ThreadLocalCache(); if (thread_local_cache.size() < thread_local_cache_size_) { thread_local_cache.push_back(std::move(*item)); } else { size_t releases = thread_local_cache.size() / 2; { std::unique_lock lck(mtx_); if (global_free_list_.size() < pool_size_) { global_free_list_.push_back(std::move(*item)); // release half of tensor buffers in thread local cache back to global free list while (global_free_list_.size() < pool_size_ && releases > 0) { global_free_list_.push_back(std::move(thread_local_cache.back())); thread_local_cache.pop_back(); releases--; } } } // global free list is also full, release half of thread local cache thread_local_cache.resize(thread_local_cache.size() - releases); } if (*item) { item->reset(); } } void TensorBufferPool::IncreasePoolSizeByBase(size_t base) { std::unique_lock lck(mtx_); pool_size_ += GetTensorBufferPoolSize(base); if (pool_size_ > global_free_list_.capacity()) { global_free_list_.reserve(pool_size_); } if (pool_size_ < global_free_list_.size()) { global_free_list_.resize(pool_size_); } } void TensorBufferPool::DecreasePoolSizeByBase(size_t base) { std::unique_lock lck(mtx_); size_t dec = GetTensorBufferPoolSize(base); CHECK_GE(pool_size_, dec) << "pool_size " << pool_size_ << " decreased by " << dec << " would be negative"; pool_size_ -= dec; if (pool_size_ > global_free_list_.capacity()) { global_free_list_.reserve(pool_size_); } if (pool_size_ < global_free_list_.size()) { global_free_list_.resize(pool_size_); } } } // namespace oneflow ================================================ FILE: oneflow/core/common/tensor_buffer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_TENSOR_BUFFER_H_ #define ONEFLOW_CORE_COMMON_TENSOR_BUFFER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/data_type.h" namespace oneflow { namespace detail { class TensorBufferImpl final { public: TensorBufferImpl() : shape_(Shape()), data_type_(DataType::kInvalidDataType), buffer_(nullptr), buffer_size_(0) {} TensorBufferImpl(const Shape& shape, DataType dtype) : shape_(Shape()), data_type_(DataType::kInvalidDataType), buffer_(nullptr), buffer_size_(0) { Reset(shape, dtype); } ~TensorBufferImpl() { DeallocateBuffer(); } OF_DISALLOW_COPY_AND_MOVE(TensorBufferImpl); void Reset(const Shape& shape, DataType dtype); void Reset(const Shape& shape); void Reset(DataType dtype); void Reset(); void CopyFrom(const TensorBufferImpl* src); void Swap(TensorBufferImpl* other); const Shape& shape() const { return shape_; } DataType data_type() const { return data_type_; } void* buffer() { return buffer_; } const void* buffer() const { return buffer_; } size_t buffer_size() const { return buffer_size_; } private: void AllocateBuffer(size_t size); void DeallocateBuffer(); void Reserve(size_t new_size); Shape shape_; DataType data_type_; void* buffer_; size_t buffer_size_; }; } // namespace detail class TensorBuffer final { public: TensorBuffer() = default; ~TensorBuffer(); TensorBuffer(const Shape& shape, DataType dtype); TensorBuffer(const TensorBuffer&) = delete; TensorBuffer& operator=(const TensorBuffer&) = delete; TensorBuffer(TensorBuffer&& other) noexcept : impl_(std::move(other.impl_)) {} TensorBuffer& operator=(TensorBuffer&& other) noexcept; bool is_allocated() const { return bool(impl_); } const Shape& shape() const; ShapeView shape_view() const { return shape(); } DataType data_type() const; int64_t elem_cnt() const { return shape().elem_cnt(); } size_t nbytes() const { return elem_cnt() * GetSizeOfDataType(data_type()); } void Reset(const Shape& shape, DataType dtype); void Reset(const Shape& shape); void Reset(DataType dtype); void Reset(); // backward compatible interface and will be deprecated in future void Resize(const Shape& shape, DataType dtype) { Reset(shape, dtype); } void CopyFrom(const TensorBuffer& src); void Swap(TensorBuffer& other); template T* mut_data() { if (raw_data() == nullptr) { return nullptr; } CheckDataType(data_type()); return static_cast(raw_data()); } template const T* data() const { if (raw_data() == nullptr) { return nullptr; } CheckDataType(data_type()); return static_cast(raw_data()); } private: friend class TensorBufferPool; void Allocate(const Shape& shape, DataType dtype); void* raw_data(); const void* raw_data() const; std::unique_ptr impl_; }; #define BUFFER_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(TensorBuffer, DataType::kTensorBuffer) template<> struct GetDataType : std::integral_constant {}; inline TensorBuffer GetTypeByDataType(std::integral_constant) { return {}; } class TensorBufferPool final { public: using ItemT = std::unique_ptr; using ListT = std::vector; static TensorBufferPool* Get() { auto& ptr = GetPtr(); CHECK(ptr) << "TensorBufferPool has not been created"; return ptr.get(); } static TensorBufferPool* TryGet() { auto& ptr = GetPtr(); return ptr.get(); } static void New() { auto& ptr = GetPtr(); CHECK(!ptr) << "TensorBufferPool is already New"; ptr.reset(new TensorBufferPool()); } static void Delete() { auto& ptr = GetPtr(); if (ptr) { ptr.reset(); } } ~TensorBufferPool() = default; 
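// Design note (a summary of the behavior implemented in tensor_buffer.cpp): each thread keeps a
// small thread-local cache of TensorBufferImpl objects in front of a mutex-protected global
// free list. Allocate() refills an empty cache with up to thread_local_cache_size_ / 2 entries
// from the global list before falling back to a fresh allocation; Deallocate() returns buffers
// to the cache and, when the cache is full, spills roughly half of it to the global list (or
// simply shrinks the cache if the global list is also full). Both capacities are derived from
// environment variables (defaults are in the .cpp):
//
//   ONEFLOW_TENSOR_BUFFER_POOL_THREAD_LOCAL_CACHE_SIZE   (per-thread cache capacity)
//   ONEFLOW_TENSOR_BUFFER_POOL_SIZE_FACTOR                (scales the global pool size)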
OF_DISALLOW_COPY_AND_MOVE(TensorBufferPool); void Allocate(ItemT* item, const Shape& shape, DataType dtype); void Deallocate(ItemT* item); void IncreasePoolSizeByBase(size_t base); void DecreasePoolSizeByBase(size_t base); private: static std::unique_ptr& GetPtr() { static std::unique_ptr ptr; return ptr; } static ListT& ThreadLocalCache() { thread_local ListT thread_local_cache; return thread_local_cache; } TensorBufferPool(); size_t thread_local_cache_size_; size_t pool_size_; ListT global_free_list_; std::mutex mtx_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_TENSOR_BUFFER_H_ ================================================ FILE: oneflow/core/common/tensor_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/register/blob_desc.pb.h" namespace oneflow { namespace user_op { TensorDesc& TensorDesc::operator=(const TensorDesc& rhs) { this->set_shape(rhs.shape()); this->set_stride(rhs.stride()); this->set_data_type(rhs.data_type()); this->set_is_dynamic(rhs.is_dynamic()); this->set_memory_format(rhs.memory_format()); return *this; } bool TensorDesc::operator==(const TensorDesc& rhs) const { return (this->shape() == rhs.shape()) && (this->stride() == rhs.stride()) && (this->data_type() == rhs.data_type()) && (this->is_dynamic() == rhs.is_dynamic()) && (this->memory_format() == rhs.memory_format()); } NaiveTensorDesc::NaiveTensorDesc(const NaiveTensorDesc& rhs) { *this = rhs; } NaiveTensorDesc::NaiveTensorDesc(const BlobDescProto& proto) { *this = proto; } NaiveTensorDesc& NaiveTensorDesc::operator=(const BlobDescProto& proto) { data_type_ = proto.data_type(); shape_ = Shape(proto.shape()); stride_ = Stride(proto.stride()); is_dynamic_ = proto.is_dynamic(); memory_format_ = proto.memory_format(); return *this; } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/common/tensor_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_COMMON_TENSOR_DESC_H_ #define ONEFLOW_CORE_COMMON_TENSOR_DESC_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/memory_format.pb.h" namespace oneflow { class BlobDescProto; namespace user_op { class TensorDesc { public: virtual ~TensorDesc() = default; TensorDesc& operator=(const TensorDesc& rhs); bool operator==(const TensorDesc&) const; virtual const Shape& shape() const = 0; virtual void set_shape(const Shape& shape) = 0; virtual const Stride& stride() const = 0; virtual void set_stride(const Stride& stride) = 0; virtual DataType data_type() const = 0; virtual void set_data_type(DataType data_type) = 0; virtual bool is_dynamic() const = 0; virtual void set_is_dynamic(bool is_dynamic) = 0; virtual MemoryFormat memory_format() const = 0; virtual void set_memory_format(MemoryFormat memory_format) = 0; protected: TensorDesc() = default; }; class NaiveTensorDesc final : public TensorDesc { public: NaiveTensorDesc() = default; ~NaiveTensorDesc() override = default; NaiveTensorDesc(const NaiveTensorDesc&); NaiveTensorDesc(const BlobDescProto&); NaiveTensorDesc& operator=(const BlobDescProto&); const Shape& shape() const override { return shape_; } void set_shape(const Shape& shape) override { shape_ = shape; } const Stride& stride() const override { return stride_; } void set_stride(const Stride& stride) override { stride_ = stride; } DataType data_type() const override { return data_type_; } void set_data_type(DataType data_type) override { data_type_ = data_type; } bool is_dynamic() const override { return is_dynamic_; } void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; } MemoryFormat memory_format() const override { return memory_format_; } void set_memory_format(MemoryFormat memory_format) override { memory_format_ = memory_format; } private: Shape shape_; Stride stride_; DataType data_type_; bool is_dynamic_; MemoryFormat memory_format_; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_TENSOR_DESC_H_ ================================================ FILE: oneflow/core/common/tensor_meta.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/common/shape_view.h" namespace oneflow { namespace one { MutTensorMeta::MutTensorMeta() : TensorMeta(kInvalidDataType, MemoryFormat::kContiguous), shape_(std::make_shared()), stride_(std::make_shared()) {} MutTensorMeta::MutTensorMeta(const std::shared_ptr& shape, DataType dtype, MemoryFormat memory_format) : TensorMeta(dtype, memory_format), shape_(std::make_shared(*shape)), stride_(std::make_shared(*shape)) {} MutTensorMeta::MutTensorMeta(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype, MemoryFormat memory_format) : TensorMeta(dtype, memory_format), shape_(std::make_shared(*shape)), stride_(std::make_shared(*stride)) {} MutTensorMeta::MutTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format) : TensorMeta(dtype, memory_format), shape_(std::make_shared(shape)), stride_(std::make_shared(shape)) {} MutTensorMeta::MutTensorMeta(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format) : TensorMeta(dtype, memory_format), shape_(std::make_shared(shape)), stride_(std::make_shared(stride)) {} bool MutTensorMeta::operator==(const MutTensorMeta& other) const { // It's correct to ignore is_dynamic_ field. return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype() && this->memory_format() == other.memory_format() && this->stride() == other.stride(); } size_t MutTensorMeta::CalcHashValue() const { // It's correct to ignore is_dynamic_ field. return Hash(*shape_ptr(), dtype(), memory_format(), stride()); } ConstTensorMeta::ConstTensorMeta() : TensorMeta(kInvalidDataType, MemoryFormat::kContiguous), shape_(SymbolOf(Shape())), stride_(SymbolOf(Stride())) {} ConstTensorMeta::ConstTensorMeta(Symbol shape, DataType dtype, MemoryFormat memory_format) : TensorMeta(dtype, memory_format), shape_(shape), stride_(SymbolOf(Stride(*shape))) {} ConstTensorMeta::ConstTensorMeta(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format) : TensorMeta(dtype, memory_format), shape_(shape), stride_(stride) {} bool ConstTensorMeta::operator==(const ConstTensorMeta& other) const { // It's correct to ignore is_dynamic_ field. return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype() && this->memory_format() == other.memory_format() && this->stride() == other.stride(); } size_t ConstTensorMeta::CalcHashValue() const { // It's correct to ignore is_dynamic_ field. 
return Hash(*shape_ptr(), dtype(), memory_format(), stride()); } LocalTensorMeta::LocalTensorMeta() : ConstTensorMeta(SymbolOf(Shape()), SymbolOf(Stride()), DataType::kInvalidDataType, MemoryFormat::kContiguous), device_(Symbol()) {} LocalTensorMeta::LocalTensorMeta(Symbol shape, DataType dtype, MemoryFormat memory_format, Symbol device) : ConstTensorMeta(shape, SymbolOf(Stride(*shape)), dtype, memory_format), device_(device) {} LocalTensorMeta::LocalTensorMeta(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format, Symbol device) : ConstTensorMeta(shape, stride, dtype, memory_format), device_(device) {} LocalTensorMeta::LocalTensorMeta(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format, Symbol device, const bool is_view) : ConstTensorMeta(shape, stride, dtype, memory_format), device_(device), is_view_(is_view) {} bool LocalTensorMeta::operator==(const LocalTensorMeta& other) const { // It's correct to ignore is_dynamic_ field. return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype() && this->memory_format() == other.memory_format() && this->device() == other.device() && this->stride() == other.stride(); } size_t LocalTensorMeta::CalcHashValue() const { // It's correct to ignore is_dynamic_ field. return Hash(*shape_ptr(), dtype(), memory_format(), device(), stride()); } MutLocalTensorMeta::MutLocalTensorMeta() : MutTensorMeta(std::make_shared(), std::make_shared(), kInvalidDataType, MemoryFormat::kContiguous), device_(Symbol()) {} MutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr& shape, DataType dtype, MemoryFormat memory_format, Symbol device) : MutTensorMeta(shape, std::make_shared(*shape), dtype, memory_format), device_(device) {} MutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype, MemoryFormat memory_format, Symbol device) : MutTensorMeta(shape, stride, dtype, memory_format), device_(device) {} MutLocalTensorMeta::MutLocalTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format, Symbol device) : MutTensorMeta(shape, Stride(shape), dtype, memory_format), device_(device) {} MutLocalTensorMeta::MutLocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format, Symbol device) : MutTensorMeta(shape, stride, dtype, memory_format), device_(device) {} bool MutLocalTensorMeta::operator==(const MutLocalTensorMeta& other) const { // It's correct to ignore is_dynamic_ field. return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype() && this->memory_format() == other.memory_format() && *this->device() == *other.device() && this->stride() == other.stride(); } size_t MutLocalTensorMeta::CalcHashValue() const { // It's correct to ignore is_dynamic_ field. return Hash(*shape_ptr(), dtype(), memory_format(), *device(), stride()); } bool GlobalTensorMeta::operator==(const GlobalTensorMeta& other) const { // It's correct to ignore is_dynamic_ field. 
return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype() && this->memory_format() == other.memory_format() && this->nd_sbp() == other.nd_sbp() && this->parallel_desc() == other.parallel_desc(); } size_t GlobalTensorMeta::CalcHashValue() const { return Hash(*shape_ptr(), dtype(), memory_format(), nd_sbp(), parallel_desc()); } bool IsContiguous(const Shape& shape, const Stride& stride) { if (!shape.is_initialized()) { return true; } return IsContiguous(ShapeView(shape), stride); } bool IsContiguous(const ShapeView& shape_view, const Stride& stride) { if (shape_view.NumAxes() < 1 || shape_view.elem_cnt() <= 1) { return true; } int64_t dim = shape_view.NumAxes(); int64_t expected_stride = 1; bool contig_if_nonempty = true; for (int64_t i = dim - 1; i >= 0; --i) { // Contiguous by default when any dim is equal to zero // https://stackoverflow.com/questions/31681324/identify-contiguous-segments-of-a-non-contiguous-numpy-array if (shape_view.At(i) == 0) { return true; } if (contig_if_nonempty && shape_view.At(i) != 1) { if (stride.at(i) != expected_stride) { contig_if_nonempty = false; } expected_stride *= shape_view.At(i); } } return contig_if_nonempty; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/common/tensor_meta.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_COMMON_TENSOR_META_H_ #define ONEFLOW_COMMON_TENSOR_META_H_ #include #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/common/symbol.h" namespace oneflow { class NdSbp; class Shape; class Stride; class Device; class ParallelDesc; namespace one { bool IsContiguous(const Shape& shape, const Stride& stride); bool IsContiguous(const ShapeView& shape_view, const Stride& stride); class TensorMeta : public user_op::TensorDesc { public: TensorMeta(DataType dtype, MemoryFormat memory_format) : data_type_(dtype), is_dynamic_(false), memory_format_(memory_format) {} TensorMeta(const TensorMeta& other) = default; TensorMeta(TensorMeta&&) = default; virtual ~TensorMeta() = default; virtual const std::shared_ptr& shape_ptr() const = 0; virtual const std::shared_ptr& stride_ptr() const = 0; virtual bool is_contiguous() const = 0; DataType dtype() const { return data_type_; } DataType data_type() const override { return data_type_; } bool is_dynamic() const override { return is_dynamic_; } MemoryFormat memory_format() const override { return memory_format_; } virtual void set_shape(const Shape& shape) override { PRINT_BUG_PROMPT_AND_ABORT(); } virtual void set_stride(const Stride& stride) override { PRINT_BUG_PROMPT_AND_ABORT(); } virtual void set_data_type(DataType data_type) override { PRINT_BUG_PROMPT_AND_ABORT(); } virtual void set_is_dynamic(bool is_dynamic) override { PRINT_BUG_PROMPT_AND_ABORT(); } virtual void set_memory_format(MemoryFormat memory_format) override { PRINT_BUG_PROMPT_AND_ABORT(); } protected: DataType data_type_; bool is_dynamic_; MemoryFormat memory_format_; }; class MutTensorMeta : public TensorMeta { public: // uninitialized MutTensorMeta. MutTensorMeta(); MutTensorMeta(const MutTensorMeta& other) : TensorMeta(other), shape_(std::make_shared(*other.shape_)), stride_(std::make_shared(*other.stride_)) {} MutTensorMeta(const std::shared_ptr& shape, DataType dtype, MemoryFormat memory_format); MutTensorMeta(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype, MemoryFormat memory_format); MutTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format); MutTensorMeta(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format); virtual ~MutTensorMeta() = default; const std::shared_ptr& shape_ptr() const override { return shape_; } const std::shared_ptr& stride_ptr() const override { return stride_; } const Shape& shape() const override { return *shape_; } const Stride& stride() const override { return *stride_; } bool is_contiguous() const override { return IsContiguous(*shape_, *stride_); } void set_shape(const Shape& shape) override { *const_cast(shape_.get()) = shape; } void set_stride(const Stride& stride) override { *const_cast(stride_.get()) = stride; } void set_data_type(DataType data_type) override { data_type_ = data_type; } void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; } void set_memory_format(MemoryFormat memory_format) override { memory_format_ = memory_format; } bool operator==(const MutTensorMeta& other) const; size_t CalcHashValue() const; MutTensorMeta& operator=(const MutTensorMeta& other) { this->data_type_ = other.data_type_; this->is_dynamic_ = other.is_dynamic_; this->memory_format_ = other.memory_format_; this->shape_ = std::make_shared(*other.shape_); this->stride_ = std::make_shared(*other.stride_); return *this; } protected: std::shared_ptr shape_; std::shared_ptr stride_; }; class ConstTensorMeta : public TensorMeta { public: // 
uninitialized ConstTensorMeta. ConstTensorMeta(); ConstTensorMeta(const ConstTensorMeta&) = default; ConstTensorMeta(Symbol shape, DataType dtype, MemoryFormat memory_format); ConstTensorMeta(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format); ConstTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format) : ConstTensorMeta(SymbolOf(shape), dtype, memory_format) {} ConstTensorMeta(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format) : ConstTensorMeta(SymbolOf(shape), SymbolOf(stride), dtype, memory_format) {} virtual ~ConstTensorMeta() = default; const std::shared_ptr& shape_ptr() const override { return shape_.shared_from_symbol(); } const std::shared_ptr& stride_ptr() const override { return stride_.shared_from_symbol(); } const Shape& shape() const override { return *shape_; } const Stride& stride() const override { return *stride_; } bool is_contiguous() const override { return IsContiguous(*shape_, *stride_); } bool operator==(const ConstTensorMeta& other) const; size_t CalcHashValue() const; ConstTensorMeta& operator=(const ConstTensorMeta& other) { this->data_type_ = other.data_type_; this->is_dynamic_ = other.is_dynamic_; this->memory_format_ = other.memory_format_; this->shape_ = other.shape_; this->stride_ = other.stride_; return *this; } protected: Symbol shape_; Symbol stride_; }; class LocalTensorMeta : public ConstTensorMeta { public: // uninitialized LocalTensorMeta. LocalTensorMeta(); LocalTensorMeta(const LocalTensorMeta&) = default; LocalTensorMeta(Symbol shape, DataType dtype, MemoryFormat memory_format, Symbol device); LocalTensorMeta(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format, Symbol device); LocalTensorMeta(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format, Symbol device, bool is_view); LocalTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format, Symbol device) : LocalTensorMeta(SymbolOf(shape), dtype, memory_format, device) {} LocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format, Symbol device) : LocalTensorMeta(SymbolOf(shape), SymbolOf(stride), dtype, memory_format, device) {} LocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format, Symbol device, const bool is_view) : LocalTensorMeta(SymbolOf(shape), SymbolOf(stride), dtype, memory_format, device, is_view) {} virtual ~LocalTensorMeta() = default; const Symbol& device() const { return device_; } bool is_view() const { return is_view_; } bool operator==(const LocalTensorMeta& other) const; size_t CalcHashValue() const; LocalTensorMeta& operator=(const LocalTensorMeta& other) = default; private: Symbol device_; bool is_view_ = false; }; class MutLocalTensorMeta : public MutTensorMeta { public: // uninitialized MutLocalTensorMeta. 
MutLocalTensorMeta(); MutLocalTensorMeta(const MutLocalTensorMeta&) = default; MutLocalTensorMeta(const std::shared_ptr& shape, DataType dtype, MemoryFormat memory_format, Symbol device); MutLocalTensorMeta(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype, MemoryFormat memory_format, Symbol device); MutLocalTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format, Symbol device); MutLocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format, Symbol device); virtual ~MutLocalTensorMeta() = default; const Symbol& device() const { return device_; } Symbol* mut_device() { return &device_; } bool operator==(const MutLocalTensorMeta& other) const; size_t CalcHashValue() const; MutLocalTensorMeta& operator=(const MutLocalTensorMeta& other) = default; private: Symbol device_; }; class GlobalTensorMeta : public ConstTensorMeta { public: GlobalTensorMeta(Symbol shape, DataType dtype, MemoryFormat memory_format, Symbol nd_sbp, Symbol parallel_desc) : ConstTensorMeta(shape, dtype, memory_format), nd_sbp_(nd_sbp), parallel_desc_(parallel_desc) {} GlobalTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format, Symbol nd_sbp, Symbol parallel_desc) : GlobalTensorMeta(SymbolOf(shape), dtype, memory_format, nd_sbp, parallel_desc) {} GlobalTensorMeta(const GlobalTensorMeta&) = default; GlobalTensorMeta(GlobalTensorMeta&&) = default; virtual ~GlobalTensorMeta() = default; bool operator==(const GlobalTensorMeta& other) const; Symbol nd_sbp() const { return nd_sbp_; } Symbol parallel_desc() const { return parallel_desc_; } size_t CalcHashValue() const; private: Symbol nd_sbp_; Symbol parallel_desc_; }; } // namespace one } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::one::LocalTensorMeta& local_tensor_meta) const { return local_tensor_meta.CalcHashValue(); } }; template<> struct hash final { size_t operator()(const oneflow::one::GlobalTensorMeta& global_tensor_meta) const { return global_tensor_meta.CalcHashValue(); } }; } // namespace std #endif // ONEFLOW_COMMON_TENSOR_META_H_ ================================================ FILE: oneflow/core/common/test_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_TEST_UTIL_H_ #define ONEFLOW_CORE_COMMON_TEST_UTIL_H_ #ifndef final #define final #endif #ifndef private #define private public #endif #include #include #endif // ONEFLOW_CORE_COMMON_TEST_UTIL_H_ ================================================ FILE: oneflow/core/common/thread_local_guard.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_ #define ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_ #include #include "oneflow/core/common/optional.h" namespace oneflow { template class ThreadLocalGuard { public: ThreadLocalGuard() { old_value_ = *MutThreadLocalValue(); *MutThreadLocalValue() = Optional(); } explicit ThreadLocalGuard(const T& value) { old_value_ = *MutThreadLocalValue(); *MutThreadLocalValue() = Optional(value); } ~ThreadLocalGuard() { *MutThreadLocalValue() = old_value_; } static const Optional& Current() { return *MutThreadLocalValue(); } private: static Optional* MutThreadLocalValue() { static thread_local Optional value{}; return &value; } Optional old_value_; }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_ ================================================ FILE: oneflow/core/common/thread_local_guard_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/util.h" #include "oneflow/core/common/thread_local_guard.h" namespace oneflow { namespace test { template void Assert(const T& value0, const T& value1) { ASSERT_FALSE(ThreadLocalGuard::Current().has_value()); { ThreadLocalGuard guard(value0); ASSERT_TRUE(ThreadLocalGuard::Current().has_value()); } { ThreadLocalGuard guard(value0); ASSERT_TRUE(ThreadLocalGuard::Current().has_value()); T value = CHECK_JUST(ThreadLocalGuard::Current()); ASSERT_EQ(value, value0); } { ThreadLocalGuard guard(value1); ASSERT_TRUE(ThreadLocalGuard::Current().has_value()); const auto& value = CHECK_JUST(ThreadLocalGuard::Current()); ASSERT_EQ(value, value1); } { ThreadLocalGuard guard(value0); ASSERT_TRUE(ThreadLocalGuard::Current().has_value()); { const auto& value = CHECK_JUST(ThreadLocalGuard::Current()); ASSERT_EQ(value, value0); } { ThreadLocalGuard nested_guard(value1); ASSERT_TRUE(ThreadLocalGuard::Current().has_value()); const auto& value = CHECK_JUST(ThreadLocalGuard::Current()); ASSERT_EQ(value, value1); } { const auto& value = CHECK_JUST(ThreadLocalGuard::Current()); ASSERT_EQ(value, value0); } } } TEST(ThreadLocalGuard, bool) { Assert(true, false); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/common/throw.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_THROW_H_ #define ONEFLOW_CORE_COMMON_THROW_H_ #include #include "oneflow/core/common/error.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/to_string.h" namespace oneflow { namespace details { struct Throw final { [[noreturn]] void operator=(Error&& error) { ThrowError(error.stacked_error()); } }; } // namespace details } // namespace oneflow #define PRINT_BUG_PROMPT_AND_ABORT() LOG(FATAL) << kOfBugIssueUploadPrompt // use CHECK_XX_OR_THROW instead of glog CHECK to get more information of stack when check failed #undef CHECK #undef CHECK_LT #undef CHECK_LE #undef CHECK_EQ #undef CHECK_NE #undef CHECK_GT #undef CHECK_GE #define CHECK CHECK_OR_THROW #define CHECK_LT CHECK_LT_OR_THROW #define CHECK_LE CHECK_LE_OR_THROW #define CHECK_EQ CHECK_EQ_OR_THROW #define CHECK_NE CHECK_NE_OR_THROW #define CHECK_GT CHECK_GT_OR_THROW #define CHECK_GE CHECK_GE_OR_THROW #define THROW(err_type) \ ::oneflow::details::Throw() = \ ::oneflow::Error::err_type().AddStackFrame([](const char* function) { \ thread_local static auto frame = \ ::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) // use __FILE__ __LINE__ etc. macros to get last frame, so this macro can show // the file name and line where CHECK_OR_THROW located even if these is no debug info #define CHECK_OR_THROW_INTERNAL(expr, error_msg) \ if (!(expr)) \ ::oneflow::details::Throw() = \ ::oneflow::Error::CheckFailedError() \ .AddStackFrame([](const char* function) { \ thread_local static auto frame = ::oneflow::SymbolOf( \ ::oneflow::ErrorStackFrame(__FILE__, __LINE__, function, error_msg)); \ return frame; \ }(__FUNCTION__)) \ .GetStackTrace() #define CHECK_OR_THROW(expr) \ CHECK_OR_THROW_INTERNAL(expr, OF_PP_STRINGIZE(CHECK_OR_THROW(expr))) \ << "Check failed: (" << OF_PP_STRINGIZE(expr) << ") " #define CHECK_EQ_OR_THROW(lhs, rhs) \ CHECK_OR_THROW_INTERNAL((lhs) == (rhs), OF_PP_STRINGIZE(CHECK_EQ_OR_THROW(lhs, rhs))) \ << "Check failed: " \ << "(" << ::oneflow::ToStringIfApplicable(lhs) \ << " == " << ::oneflow::ToStringIfApplicable(rhs) << "): " #define CHECK_GE_OR_THROW(lhs, rhs) \ CHECK_OR_THROW_INTERNAL((lhs) >= (rhs), OF_PP_STRINGIZE(CHECK_GE_OR_THROW(lhs, rhs))) \ << "Check failed: " \ << "(" << ::oneflow::ToStringIfApplicable(lhs) \ << " >= " << ::oneflow::ToStringIfApplicable(rhs) << "): " #define CHECK_GT_OR_THROW(lhs, rhs) \ CHECK_OR_THROW_INTERNAL((lhs) > (rhs), OF_PP_STRINGIZE(CHECK_GT_OR_THROW(lhs, rhs))) \ << "Check failed: " \ << "(" << ::oneflow::ToStringIfApplicable(lhs) << " > " \ << ::oneflow::ToStringIfApplicable(rhs) << "): " #define CHECK_LE_OR_THROW(lhs, rhs) \ CHECK_OR_THROW_INTERNAL((lhs) <= (rhs), OF_PP_STRINGIZE(CHECK_LE_OR_THROW(lhs, rhs))) \ << "Check failed: " \ << "(" << ::oneflow::ToStringIfApplicable(lhs) \ << " <= " << ::oneflow::ToStringIfApplicable(rhs) << "): " #define CHECK_LT_OR_THROW(lhs, rhs) \ CHECK_OR_THROW_INTERNAL((lhs) < (rhs), OF_PP_STRINGIZE(CHECK_LT_OR_THROW(lhs, rhs))) \ << "Check failed: " \ << "(" << ::oneflow::ToStringIfApplicable(lhs) << " < " \ << ::oneflow::ToStringIfApplicable(rhs) << "): " 
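// Illustrative call sites (the expressions and messages below are made-up examples). A failing
// CHECK_*_OR_THROW macro builds a CheckFailedError that records the stringized expression and
// the __FILE__/__LINE__/__FUNCTION__ of the call site, appends any text streamed after the
// macro, and then throws via details::Throw:
//
//   CHECK_GT_OR_THROW(shape.NumAxes(), 0) << "expected a non-scalar tensor";
//   CHECK_EQ_OR_THROW(lhs.elem_cnt(), rhs.elem_cnt()) << "element count mismatch";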
#define CHECK_NE_OR_THROW(lhs, rhs) \ CHECK_OR_THROW_INTERNAL((lhs) != (rhs), OF_PP_STRINGIZE(CHECK_NE_OR_THROW(lhs, rhs))) \ << "Check failed: " \ << "(" << ::oneflow::ToStringIfApplicable(lhs) \ << " != " << ::oneflow::ToStringIfApplicable(rhs) << "): " #define CHECK_STREQ_OR_THROW(lhs, rhs) CHECK_EQ_OR_THROW(std::string(lhs), std::string(rhs)) #define CHECK_STRNE_OR_THROW(lhs, rhs) CHECK_NE_OR_THROW(std::string(lhs), std::string(rhs)) #define CHECK_NOTNULL_OR_THROW(ptr) CHECK_OR_THROW(ptr != nullptr) #define CHECK_ISNULL_OR_THROW(ptr) CHECK_OR_THROW(ptr == nullptr) #define TODO_THEN_THROW() \ ::oneflow::details::Throw() = \ ::oneflow::Error::TodoError().AddStackFrame([](const char* function) { \ thread_local static auto frame = \ ::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) #define UNIMPLEMENTED_THEN_THROW() \ ::oneflow::details::Throw() = \ ::oneflow::Error::UnimplementedError().AddStackFrame([](const char* function) { \ thread_local static auto frame = \ ::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) #endif // ONEFLOW_CORE_COMMON_THROW_H_ ================================================ FILE: oneflow/core/common/to_string.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_TO_STRING_H_ #define ONEFLOW_CORE_COMMON_TO_STRING_H_ #include #include "oneflow/core/common/type_traits.h" namespace oneflow { template inline std::string ToString(const T& value) { return std::to_string(value); } template inline std::string ToStringIfApplicable(const T& value) { if constexpr (printable()) { std::stringstream ss; ss << value; return ss.str(); } else { return ""; } } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_TO_STRING_H_ ================================================ FILE: oneflow/core/common/tuple_hash.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
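A short sketch of how the CHECK_*_OR_THROW and THROW macros above are used; ValidateSplitArgs and its parameters are made-up names for illustration only:

#include <cstdint>

#include "oneflow/core/common/throw.h"

namespace oneflow {

// On failure a CHECK_*_OR_THROW macro records __FILE__/__LINE__/function in an
// ErrorStackFrame and throws; text streamed after the macro is appended to the
// message, and the operands are rendered with ToStringIfApplicable (to_string.h).
void ValidateSplitArgs(int64_t axis, int64_t num_axes, int64_t parallel_num) {
  CHECK_GE_OR_THROW(axis, 0) << "axis must be non-negative";
  CHECK_LT_OR_THROW(axis, num_axes) << "axis is out of range";
  CHECK_GT_OR_THROW(parallel_num, 0);
  if (parallel_num > num_axes) { THROW(RuntimeError) << "more shards than axes requested"; }
}

}  // namespace oneflow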
*/ #ifndef ONEFLOW_CORE_COMMON_TUPLE_HASH_H_ #define ONEFLOW_CORE_COMMON_TUPLE_HASH_H_ #include #include #include "oneflow/core/common/util.h" namespace std { template struct hash> final { size_t operator()(const std::tuple& val) const { return do_hash(val, std::index_sequence_for{}); } private: template size_t do_hash(const std::tuple& val, std::index_sequence) const { return oneflow::Hash(std::get(val)...); } }; } // namespace std #endif // ONEFLOW_CORE_COMMON_TUPLE_HASH_H_ ================================================ FILE: oneflow/core/common/type_traits.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_TYPE_TRAITS_H_ #define ONEFLOW_CORE_COMMON_TYPE_TRAITS_H_ #include #if defined(WITH_CUDA) #include #include #endif #include "oneflow/core/common/bfloat16.h" #include #include namespace std { #if __GNUG__ && __GNUC__ < 5 && !__clang__ // copied from // https://llvm.org/doxygen/type__traits_8h_source.html namespace detail { /// Internal utility to detect trivial copy construction. template union copy_construction_triviality_helper { T t; copy_construction_triviality_helper() = default; copy_construction_triviality_helper(const copy_construction_triviality_helper&) = default; ~copy_construction_triviality_helper() = default; }; /// Internal utility to detect trivial move construction. template union move_construction_triviality_helper { T t; move_construction_triviality_helper() = default; move_construction_triviality_helper(move_construction_triviality_helper&&) = default; ~move_construction_triviality_helper() = default; }; template union trivial_helper { T t; }; } // end namespace detail // is_trivially_copyable // An implementation of `std::is_trivially_copyable` since STL version // is not equally supported by all compilers, especially GCC 4.8. // Uniform implementation of this trait is important for ABI compatibility // as it has an impact on SmallVector's ABI (among others). 
template class is_trivially_copyable { // copy constructors static constexpr bool has_trivial_copy_constructor = std::is_copy_constructible>::value; static constexpr bool has_deleted_copy_constructor = !std::is_copy_constructible::value; // move constructors static constexpr bool has_trivial_move_constructor = std::is_move_constructible>::value; static constexpr bool has_deleted_move_constructor = !std::is_move_constructible::value; // copy assign static constexpr bool has_trivial_copy_assign = is_copy_assignable>::value; static constexpr bool has_deleted_copy_assign = !is_copy_assignable::value; // move assign static constexpr bool has_trivial_move_assign = is_move_assignable>::value; static constexpr bool has_deleted_move_assign = !is_move_assignable::value; // destructor static constexpr bool has_trivial_destructor = std::is_destructible>::value; public: static constexpr bool value = has_trivial_destructor && (has_deleted_move_assign || has_trivial_move_assign) && (has_deleted_move_constructor || has_trivial_move_constructor) && (has_deleted_copy_assign || has_trivial_copy_assign) && (has_deleted_copy_constructor || has_trivial_copy_constructor); #ifdef HAVE_STD_IS_TRIVIALLY_COPYABLE static_assert( value == std::is_trivially_copyable::value, "inconsistent behavior between llvm:: and std:: implementation of is_trivially_copyable"); #endif }; template class is_trivially_copyable : public true_type {}; #endif } // namespace std namespace oneflow { // Type Trait: IsScalarType template struct IsScalarType final { static const bool value = std::is_scalar::value; }; template struct IsScalarType< T, typename std::enable_if< std::is_same::type>::value || std::is_same::type>::value #ifdef WITH_CUDA || std::is_same::type>::value #endif // WITH_CUDA || std::is_same, typename std::remove_cv::type>::value || std::is_same, typename std::remove_cv::type>::value>::type> final { static const bool value = true; }; namespace detail { template using remove_cvref_t = typename std::remove_cv::type>::type; template struct ScalarOrConstRef; template struct ScalarOrConstRef::value>::type> { using type = T; }; template struct ScalarOrConstRef::value>::type> { using type = const T&; }; template constexpr auto printable(int) -> decltype(std::declval() << std::declval(), bool()) { return true; } template constexpr bool printable(...) { return false; } } // namespace detail template using scalar_or_const_ref_t = typename detail::ScalarOrConstRef::type; template constexpr bool printable() { return detail::printable(0); } } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_TYPE_TRAITS_H_ ================================================ FILE: oneflow/core/common/util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
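A small sketch exercising the tuple hash specialization and the type traits above; DeviceKey and the cache map are invented for the example:

#include <cstdint>
#include <string>
#include <tuple>
#include <unordered_map>

#include "oneflow/core/common/tuple_hash.h"
#include "oneflow/core/common/type_traits.h"

namespace oneflow {

// IsScalarType<T> (type_traits.h) is true for arithmetic types plus the
// half/bfloat16 types; printable<T>() detects whether operator<< accepts T.
static_assert(IsScalarType<float>::value, "float is a scalar type");
static_assert(printable<int>(), "int can be streamed to std::ostream");

// With the std::hash specialization from tuple_hash.h in scope, a std::tuple
// works directly as an unordered_map key.
using DeviceKey = std::tuple<int64_t /*device_id*/, std::string /*device_type*/>;

void Example() {
  std::unordered_map<DeviceKey, int64_t> cache;
  cache[std::make_tuple(int64_t{0}, std::string("cuda"))] = 42;
}

}  // namespace oneflow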
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" #include #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/platform.h" #include #include #ifdef __linux__ #include #include #endif namespace oneflow { #define DEFINE_ONEFLOW_STR2INT_CAST(dst_type, cast_func) \ template<> \ dst_type oneflow_cast(const std::string& s) { \ char* end_ptr = nullptr; \ dst_type ret = cast_func(s.c_str(), &end_ptr, 0); \ CHECK_EQ(*end_ptr, '\0'); \ return ret; \ } DEFINE_ONEFLOW_STR2INT_CAST(long, strtol); DEFINE_ONEFLOW_STR2INT_CAST(unsigned long, strtoul); DEFINE_ONEFLOW_STR2INT_CAST(long long, strtoll); DEFINE_ONEFLOW_STR2INT_CAST(unsigned long long, strtoull); DEFINE_ONEFLOW_STR2INT_CAST(signed char, strtol); DEFINE_ONEFLOW_STR2INT_CAST(short, strtol); DEFINE_ONEFLOW_STR2INT_CAST(int, strtol); DEFINE_ONEFLOW_STR2INT_CAST(unsigned char, strtoul); DEFINE_ONEFLOW_STR2INT_CAST(unsigned short, strtoul); DEFINE_ONEFLOW_STR2INT_CAST(unsigned int, strtoul); template<> float oneflow_cast(const std::string& s) { char* end_ptr = nullptr; float ret = strtof(s.c_str(), &end_ptr); CHECK_EQ(*end_ptr, '\0'); return ret; } template<> double oneflow_cast(const std::string& s) { char* end_ptr = nullptr; double ret = strtod(s.c_str(), &end_ptr); CHECK_EQ(*end_ptr, '\0'); return ret; } #ifdef OF_PLATFORM_POSIX // COMMAND(feenableexcept(FE_ALL_EXCEPT & ~FE_INEXACT & ~FE_UNDERFLOW)); #endif // If the interrupt during object malloc is changed to exit, the exit function indicates a normal // exit, triggering the object destructor function and then triggering object free. Since there is a // lock in malloc, if malloc and free obtain the same lock, it can cause a deadlock, which prevents // the process from exiting. After calling abort, the OS forces the program to exit, // relying on the OS to do resource cleanup, which can avoid the deadlock issue. // Process inability to exit can be more troublesome than potential resource leaks. If we find that // abort causes unreleased resources later, we can use exit in a local scope rather than globally. 
// Reference: https://github.com/Oneflow-Inc/OneTeam/issues/1954 void AbortSignalHandler(int signal) { std::abort(); } COMMAND(std::signal(SIGINT, AbortSignalHandler)); size_t GetAvailableCpuMemSize() { #if defined(__linux__) std::ifstream mem_info("/proc/meminfo"); CHECK(mem_info.good()) << "can't open file: /proc/meminfo"; std::string line; while (std::getline(mem_info, line).good()) { std::string token; const char* p = line.c_str(); p = StrToToken(p, " ", &token); if (token != "MemAvailable:") { continue; } CHECK_NE(*p, '\0'); p = StrToToken(p, " ", &token); size_t mem_available = oneflow_cast(token); CHECK_NE(*p, '\0'); p = StrToToken(p, " ", &token); CHECK_EQ(token, "kB"); return mem_available * 1024; } return sysconf(_SC_PAGESIZE) * sysconf(_SC_AVPHYS_PAGES); #elif defined(__APPLE__) // macOS will eagerly make use of all memory so there is no point querying it return std::numeric_limits::max(); #else UNIMPLEMENTED(); return 0; #endif } bool IsKernelSafeInt32(int64_t n) { return n <= GetMaxVal() / 2; } namespace { bool CaseInsensitiveStringEquals(const std::string& lhs, const std::string& rhs) { return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), rhs.begin(), [](char a, char b) { return std::tolower(a) == std::tolower(b); }); } bool StringToBool(const std::string& str) { return CaseInsensitiveStringEquals(str, "1") || CaseInsensitiveStringEquals(str, "true") || CaseInsensitiveStringEquals(str, "yes") || CaseInsensitiveStringEquals(str, "on") || CaseInsensitiveStringEquals(str, "y"); } bool StringToInteger(const std::string& str, int64_t* value) { char* end; int64_t v = std::strtoll(str.data(), &end, 10); if (end == str.data()) { return false; } else { *value = v; return true; } } bool StringToFloat(const std::string& str, double* value) { char* end = nullptr; double v = std::strtof(str.data(), &end); if (end == str.data()) { return false; } else { *value = v; return true; } } } // namespace bool ParseBooleanFromEnv(const std::string& env_var, bool default_value) { const char* env_p = std::getenv(env_var.c_str()); if (env_p == nullptr) { return default_value; } else { return StringToBool(env_p); } } int64_t ParseIntegerFromEnv(const std::string& env_var, int64_t default_value) { const char* env_p = std::getenv(env_var.c_str()); if (env_p == nullptr) { return default_value; } int64_t value; if (StringToInteger(env_p, &value)) { return value; } else { return default_value; } } double ParseFloatFromEnv(const std::string& env_var, double default_value) { const char* env_p = std::getenv(env_var.c_str()); if (env_p == nullptr) { return default_value; } double value = default_value; StringToFloat(env_p, &value); return value; } std::string GetStringFromEnv(const std::string& env_var, const std::string& default_value) { const char* env_p = std::getenv(env_var.c_str()); if (env_p == nullptr) { return default_value; } else { return env_p; } } } // namespace oneflow ================================================ FILE: oneflow/core/common/util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
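A sketch of the environment helpers implemented above; apart from ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE (which appears elsewhere in this tree), the variable names are placeholders:

#include <cstdint>
#include <string>

#include "oneflow/core/common/util.h"

namespace oneflow {

void Example() {
  // "1", "true", "yes", "on" and "y" (case-insensitive) all parse as true;
  // anything else, or an unset variable, yields the default.
  bool debug = ParseBooleanFromEnv("EXAMPLE_DEBUG_MODE", /*default_value=*/false);

  // Falls back to the default when the variable is unset or not a number.
  int64_t max_msg = ParseIntegerFromEnv("ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE", -1);

  // oneflow_cast<T> converts a string and CHECK-fails on trailing garbage.
  int workers = oneflow_cast<int>(GetStringFromEnv("EXAMPLE_WORKER_NUM", "4"));

  (void)debug;
  (void)max_msg;
  (void)workers;
}

}  // namespace oneflow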
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_UTIL_H_ #define ONEFLOW_CORE_COMMON_UTIL_H_ #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/throw.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/meta_util.hpp" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/hash.h" #include "oneflow/core/common/cpp_attribute.h" #include "fmt/format.h" #include "fmt/ranges.h" #define CHECK_ISNULL(e) CHECK((e) == nullptr) namespace fmt { template struct formatter> : formatter { template auto format(const std::complex& c, FormatContext& ctx) { return formatter::format(fmt::format("({}+{}j)", c.real(), c.imag()), ctx); } }; } // namespace fmt template std::ostream& operator<<(std::ostream& os, const std::vector& v) { os << fmt::format("{}", v); return os; } namespace oneflow { #define OF_DISALLOW_COPY(ClassName) \ ClassName(const ClassName&) = delete; \ ClassName& operator=(const ClassName&) = delete #define OF_DISALLOW_MOVE(ClassName) \ ClassName(ClassName&&) = delete; \ ClassName& operator=(ClassName&&) = delete #define OF_DISALLOW_COPY_AND_MOVE(ClassName) \ OF_DISALLOW_COPY(ClassName); \ OF_DISALLOW_MOVE(ClassName) #define UNIMPLEMENTED() LOG(FATAL) << "UNIMPLEMENTED" #define TODO() LOG(FATAL) << "TODO" #define OF_COMMA , #define DEFINE_STATIC_VAR(type, name) \ static type* name() { \ static type var; \ return &var; \ } #define COMMAND(...) \ namespace { \ struct OF_PP_CAT(CommandT, __LINE__) { \ OF_PP_CAT(CommandT, __LINE__)() { __VA_ARGS__; } \ }; \ OF_PP_CAT(CommandT, __LINE__) OF_PP_CAT(g_command_var, __LINE__); \ } template bool operator==(const std::weak_ptr& lhs, const std::weak_ptr& rhs) { return lhs.lock().get() == rhs.lock().get(); } template void SortAndRemoveDuplication(std::vector* vec) { std::sort(vec->begin(), vec->end()); auto unique_it = std::unique(vec->begin(), vec->end()); vec->erase(unique_it, vec->end()); } inline std::string NewUniqueId() { static std::atomic counter(0); return std::to_string(counter.fetch_add(1, std::memory_order_relaxed)); } template void EraseIf(HashMap* hash_map, std::function::iterator)> cond) { for (auto it = hash_map->begin(); it != hash_map->end();) { if (cond(it)) { hash_map->erase(it++); } else { ++it; } } } template typename std::enable_if::value, std::ostream&>::type operator<<( std::ostream& out_stream, const T& x) { out_stream << static_cast(x); return out_stream; } template OutType oneflow_cast(const InType&); inline uint32_t NewRandomSeed() { static std::mt19937 gen{std::random_device{}()}; return gen(); } #define DIM_SEQ \ OF_PP_MAKE_TUPLE_SEQ(1) \ OF_PP_MAKE_TUPLE_SEQ(2) \ OF_PP_MAKE_TUPLE_SEQ(3) \ OF_PP_MAKE_TUPLE_SEQ(4) OF_PP_MAKE_TUPLE_SEQ(5) OF_PP_MAKE_TUPLE_SEQ(6) OF_PP_MAKE_TUPLE_SEQ(7) #define BOOL_SEQ (true)(false) #define FOR_RANGE(type, i, begin, end) for (type i = (begin), __end = (end); i < __end; ++i) #define FOR_EACH(it, container) for (auto it = container.begin(); it != container.end(); ++it) inline double GetCurTime() { return std::chrono::high_resolution_clock::now().time_since_epoch().count(); } const size_t kHostAlignSize = 64; const size_t kCudaAlignSize = 512; const size_t kCudaMemAllocAlignSize = 512; const int32_t kBlobBodyAlignSize = 512; const int32_t kBlobHeaderAlignSize = 64; inline size_t RoundUp(size_t n, size_t val) 
{ return (n + val - 1) / val * val; } inline size_t GetCudaAlignedSize(size_t size) { return RoundUp(size, kCudaAlignSize); } size_t GetAvailableCpuMemSize(); template void Erase(T& container, const std::function& NeedErase, const std::function& EraseElementHandler) { auto iter = container.begin(); auto erase_from = container.end(); while (iter != erase_from) { if (NeedErase(*iter)) { --erase_from; if (iter == erase_from) { break; } std::swap(*iter, *erase_from); } else { ++iter; } } for (; iter != container.end(); ++iter) { EraseElementHandler(*iter); } if (erase_from != container.end()) { container.erase(erase_from, container.end()); } } template void Erase(T& container, const std::function& NeedErase) { Erase(container, NeedErase, [](const typename T::value_type&) {}); } #if defined(__GNUC__) #define ALWAYS_INLINE __attribute__((always_inline)) #elif defined(__CUDACC__) #define ALWAYS_INLINE __forceinline__ #else #define ALWAYS_INLINE inline #endif bool IsKernelSafeInt32(int64_t n); class RoundModeGuard final { public: RoundModeGuard(int mode) { saved_mode_ = std::fegetround(); CHECK_EQ(std::fesetround(mode), 0); } ~RoundModeGuard() { std::fesetround(saved_mode_); } private: int saved_mode_; }; bool ParseBooleanFromEnv(const std::string& env_var, bool default_value); int64_t ParseIntegerFromEnv(const std::string& env_var, int64_t default_value); double ParseFloatFromEnv(const std::string& env_var, double default_value); std::string GetStringFromEnv(const std::string& env_var, const std::string& default_value); #define OF_PREDICT_TRUE likely #define OF_PREDICT_FALSE unlikely } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_UTIL_H_ ================================================ FILE: oneflow/core/common/wrap_dim_utils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
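A brief sketch of a few of the utilities declared in util.h above, using made-up data:

#include <cstdint>
#include <vector>

#include "oneflow/core/common/util.h"

namespace oneflow {

void Example() {
  std::vector<int64_t> ids{3, 1, 3, 2, 1};
  SortAndRemoveDuplication(&ids);  // ids becomes {1, 2, 3}

  int64_t sum = 0;
  FOR_RANGE(size_t, i, 0, ids.size()) { sum += ids[i]; }  // sum == 6

  // RoundUp/GetCudaAlignedSize pad a byte count to the 512-byte CUDA alignment.
  size_t aligned = GetCudaAlignedSize(1000);  // 1024

  (void)sum;
  (void)aligned;
}

}  // namespace oneflow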
*/ #include #include "oneflow/core/common/maybe.h" namespace oneflow { // align with pytorch: `c10/core/WrapDimMinimal.h` static inline Maybe maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar = true) { if (dim_post_expr <= 0) { if (!wrap_scalar) { return Error::RuntimeError() << "dimension specified as " << dim << " but tensor has no dimensions"; } dim_post_expr = 1; // this will make range [-1, 0] } int64_t min = -dim_post_expr; int64_t max = dim_post_expr - 1; if (dim < min || dim > max) { return Error::IndexError() << "Dimension out of range (expected to be in range of [" << min << ", " << max << "], but got " << dim << ")"; } if (dim < 0) dim += dim_post_expr; return dim; } // align with pytorch: `aten/src/ATen/WrapDimUtilsMulti.h` constexpr size_t dim_bitset_size = 64; static inline Maybe> dim_list_to_bitset( const std::vector& dims, int64_t ndims) { CHECK_LE_OR_RETURN(ndims, (int64_t)dim_bitset_size) << Error::RuntimeError() << "Only tensors with up to " << dim_bitset_size << " dims are supported"; std::bitset seen; for (int32_t i = 0; i < dims.size(); i++) { size_t dim = JUST(maybe_wrap_dim(dims[i], ndims)); CHECK_OR_RETURN_ERROR(!seen[dim]) << Error::RuntimeError() << "The dim " << dim << " appears multiple times in the list of dims"; seen[dim] = true; } return seen; } } // namespace oneflow ================================================ FILE: oneflow/core/common/zero_only_zip.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_ZERO_ONLY_ZIP_H_ #define ONEFLOW_CORE_COMMON_ZERO_ONLY_ZIP_H_ #include #include "oneflow/core/common/sized_buffer_view.h" namespace oneflow { struct ZeroOnlyZipUtil final { void ZipToSizedBuffer(const char* data, size_t size, SizedBufferView* sized_buffer); void UnzipToExpectedSize(const SizedBufferView& size_buffer, char* data, size_t expected_size); }; } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ZERO_ONLY_ZIP_H_ ================================================ FILE: oneflow/core/control/bootstrap_client.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
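A sketch of the dim-wrapping helpers from wrap_dim_utils.h above, inside a hypothetical Maybe-returning function:

#include <cstdint>

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/wrap_dim_utils.h"

namespace oneflow {

// For a tensor with 4 axes, dims in [-4, 3] are accepted and negative dims are
// shifted by ndims, mirroring PyTorch's wrap_dim behaviour noted above.
Maybe<void> Example() {
  int64_t last = JUST(maybe_wrap_dim(-1, /*dim_post_expr=*/4));  // 3
  int64_t same = JUST(maybe_wrap_dim(2, 4));                     // 2
  // maybe_wrap_dim(4, 4) would instead return an IndexError.
  (void)last;
  (void)same;
  return Maybe<void>::Ok();
}

}  // namespace oneflow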
*/ #ifndef ONEFLOW_CORE_CONTROL_BOOTSTRAP_CLIENT_H_ #define ONEFLOW_CORE_CONTROL_BOOTSTRAP_CLIENT_H_ #include "oneflow/core/control/rpc_client.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { class BootstrapClient : public RpcClient { public: OF_DISALLOW_COPY_AND_MOVE(BootstrapClient); virtual ~BootstrapClient() override = default; protected: friend class Singleton; BootstrapClient() = default; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_BOOTSTRAP_CLIENT_H_ ================================================ FILE: oneflow/core/control/bootstrap_server.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_BOOTSTRAP_SERVER_H_ #define ONEFLOW_CORE_CONTROL_BOOTSTRAP_SERVER_H_ #include "oneflow/core/control/rpc_server.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { class BootstrapServer : public RpcServer { public: OF_DISALLOW_COPY_AND_MOVE(BootstrapServer); BootstrapServer() = default; virtual ~BootstrapServer() override = default; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_BOOTSTRAP_SERVER_H_ ================================================ FILE: oneflow/core/control/control.proto ================================================ syntax = "proto2"; package oneflow; message LoadServerRequest { required string addr = 1; optional int64 rank = 2 [default = -1]; } message LoadServerResponse { } message BarrierRequest { required string name = 1; required int32 num = 2; } message BarrierResponse { } enum TryLockResult { kLocked = 0; kDone = 1; kDoing = 2; } message TryLockRequest { required string name = 1; } message TryLockResponse { required TryLockResult result = 1; } message NotifyDoneRequest { required string name = 1; } message NotifyDoneResponse { } message WaitUntilDoneRequest { required string name = 1; } message WaitUntilDoneResponse { } message PushKVRequest { required string key = 1; required bytes val = 2; } message PushKVResponse { } message ClearKVRequest { required string key = 1; } message ClearKVResponse { } message PullKVRequest { required string key = 1; } message PullKVResponse { required bytes val = 1; } message ClearRequest { } message ClearResponse { } message IncreaseCountRequest { required string key = 1; required int32 val = 2; } message IncreaseCountResponse { required int32 val = 1; } message EraseCountRequest { required string key = 1; } message EraseCountResponse { } ================================================ FILE: oneflow/core/control/ctrl_bootstrap.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/control/ctrl_bootstrap.h" #include "oneflow/core/control/worker_process_info.pb.h" #include "oneflow/core/control/host_list_bootstrap_server.h" #include "oneflow/core/control/host_list_bootstrap_client.h" #include "oneflow/core/control/rank_info_bootstrap_server.h" #include "oneflow/core/control/rank_info_bootstrap_client.h" namespace oneflow { Maybe CtrlBootstrap::InitProcessCtx(int64_t port, ProcessCtx* ret_process_ctx) { std::vector worker_process_info_list; worker_process_info_list.reserve(world_size()); if (rank() == 0) { WorkerProcessInfo worker_process_info; { worker_process_info.set_rank(rank()); worker_process_info.set_port(port); JUST(SetCurrentHostByMaster(&worker_process_info)); } worker_process_info_list.emplace_back(worker_process_info); for (int64_t world_rank = 1; world_rank < world_size(); ++world_rank) { std::string key = std::string("GetWorkerProcessInfo") + std::to_string(world_rank); WorkerProcessInfo cur_work_process_info; mut_bootstrap_client()->PullMasterKV(key, &cur_work_process_info); CHECK_EQ_OR_RETURN(world_rank, worker_process_info_list.size()); CHECK_EQ_OR_RETURN(world_rank, cur_work_process_info.rank()); worker_process_info_list.emplace_back(cur_work_process_info); } } else { std::string key = std::string("GetWorkerProcessInfo") + std::to_string(rank()); WorkerProcessInfo cur_work_process_info; { cur_work_process_info.set_rank(rank()); cur_work_process_info.set_port(port); JUST(SetCurrentHostByWorker(&cur_work_process_info)); } mut_bootstrap_client()->PushMasterKV(key, cur_work_process_info); } mut_bootstrap_client()->Barrier(__FILE__ ":" OF_PP_STRINGIZE(__LINE__)); if (rank() == 0) { ret_process_ctx->set_rank(rank()); ret_process_ctx->mutable_ctrl_addr()->Clear(); for (const auto& worker_process_info : worker_process_info_list) { Address* addr = ret_process_ctx->mutable_ctrl_addr()->Add(); if (worker_process_info.has_host()) { addr->set_host(worker_process_info.host()); } addr->set_port(worker_process_info.port()); JUST(SetHostByMaster(addr, worker_process_info.rank())); } JUST(SetNodeSize(ret_process_ctx)); mut_bootstrap_client()->PushMasterKV("BroadcastProcessCtx", *ret_process_ctx); } else { mut_bootstrap_client()->PullMasterKV("BroadcastProcessCtx", ret_process_ctx); ret_process_ctx->set_rank(rank()); } mut_bootstrap_client()->Barrier(__FILE__ ":" OF_PP_STRINGIZE(__LINE__)); VLOG(2) << "\n" << ret_process_ctx->DebugString(); return Maybe::Ok(); } HostListCtrlBootstrap::HostListCtrlBootstrap(const EnvDesc& env_desc) : CtrlBootstrap() { bootstrap_server_.reset(new HostListBootstrapServer(env_desc)); bootstrap_client_.reset(new HostListBootstrapClient(env_desc)); bootstrap_client_->Barrier(__FILE__ ":" OF_PP_STRINGIZE(__LINE__)); host_ = bootstrap_server_->this_machine_addr(); rank_ = env_desc.GetMachineId(host_); world_size_ = env_desc.TotalMachineNum(); } HostListCtrlBootstrap::~HostListCtrlBootstrap() { bootstrap_client_.reset(); bootstrap_server_.reset(); } Maybe HostListCtrlBootstrap::SetHostByMaster(Address* addr, int64_t world_rank) const { return Maybe::Ok(); } Maybe HostListCtrlBootstrap::SetCurrentHostByMaster( 
WorkerProcessInfo* worker_process_info) const { worker_process_info->set_host(host()); return Maybe::Ok(); } Maybe HostListCtrlBootstrap::SetCurrentHostByWorker( WorkerProcessInfo* worker_process_info) const { worker_process_info->set_host(host()); return Maybe::Ok(); } Maybe HostListCtrlBootstrap::SetNodeSize(ProcessCtx* process_ctx) const { process_ctx->set_node_size(world_size()); return Maybe::Ok(); } BootstrapServer* HostListCtrlBootstrap::mut_bootstrap_server() { return bootstrap_server_.get(); } BootstrapClient* HostListCtrlBootstrap::mut_bootstrap_client() { return bootstrap_client_.get(); } RankInfoCtrlBootstrap::RankInfoCtrlBootstrap(const BootstrapConf& bootstrap_conf) : CtrlBootstrap(), bootstrap_conf_(bootstrap_conf) { bootstrap_server_.reset(new RankInfoBootstrapServer(bootstrap_conf)); bootstrap_client_.reset(new RankInfoBootstrapClient(bootstrap_conf)); bootstrap_client_->Barrier(__FILE__ ":" OF_PP_STRINGIZE(__LINE__)); master_host_ = bootstrap_conf.master_addr().host(); rank_ = bootstrap_conf.rank(); world_size_ = bootstrap_conf.world_size(); } RankInfoCtrlBootstrap::~RankInfoCtrlBootstrap() { bootstrap_client_.reset(); bootstrap_server_.reset(); } Maybe RankInfoCtrlBootstrap::SetHostByMaster(Address* addr, int64_t world_rank) const { if (addr->has_host()) { return Maybe::Ok(); } const auto& rank2host = JUST(bootstrap_server_->rank2host()); CHECK_EQ_OR_RETURN(rank2host.size(), world_size()); CHECK_GE_OR_RETURN(world_rank, 0); CHECK_LT_OR_RETURN(world_rank, rank2host.size()); addr->set_host(rank2host.at(world_rank)); return Maybe::Ok(); } Maybe RankInfoCtrlBootstrap::SetCurrentHostByMaster( WorkerProcessInfo* worker_process_info) const { CHECK_EQ_OR_RETURN(rank(), 0); if (bootstrap_conf_.has_host()) { worker_process_info->set_host(bootstrap_conf_.host()); } else { worker_process_info->set_host(master_host_); } return Maybe::Ok(); } Maybe RankInfoCtrlBootstrap::SetCurrentHostByWorker( WorkerProcessInfo* worker_process_info) const { CHECK_NE_OR_RETURN(rank(), 0); if (bootstrap_conf_.has_host()) { worker_process_info->set_host(bootstrap_conf_.host()); } return Maybe::Ok(); } Maybe RankInfoCtrlBootstrap::SetNodeSize(ProcessCtx* process_ctx) const { if (bootstrap_conf_.has_node_size()) { CHECK_EQ_OR_RETURN(world_size() % bootstrap_conf_.node_size(), 0); process_ctx->set_node_size(bootstrap_conf_.node_size()); return Maybe::Ok(); } const auto& rank2host = JUST(bootstrap_server_->rank2host()); std::set no_duplicated_host; for (const auto& host : rank2host) { no_duplicated_host.insert(host); } CHECK_EQ_OR_RETURN(world_size() % no_duplicated_host.size(), 0); process_ctx->set_node_size(no_duplicated_host.size()); return Maybe::Ok(); } BootstrapServer* RankInfoCtrlBootstrap::mut_bootstrap_server() { return bootstrap_server_.get(); } BootstrapClient* RankInfoCtrlBootstrap::mut_bootstrap_client() { return bootstrap_client_.get(); } } // namespace oneflow ================================================ FILE: oneflow/core/control/ctrl_bootstrap.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
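A standalone illustration (not repository code) of the fallback path in RankInfoCtrlBootstrap::SetNodeSize above; DeriveNodeSize is a made-up name:

#include <cstdint>
#include <set>
#include <string>
#include <vector>

// When BootstrapConf carries no node_size, the node count is taken to be the
// number of distinct hosts reported by the ranks, and world_size must divide
// evenly across them.
int64_t DeriveNodeSize(const std::vector<std::string>& rank2host) {
  if (rank2host.empty()) { return 0; }
  std::set<std::string> distinct_hosts(rank2host.begin(), rank2host.end());
  const int64_t world_size = rank2host.size();
  const int64_t node_size = distinct_hosts.size();
  // The repository code uses CHECK_EQ_OR_RETURN and propagates a Maybe<void>;
  // this sketch just signals the violation with -1.
  return (world_size % node_size == 0) ? node_size : -1;
}

// Example: four ranks spread over two hosts yields node_size == 2:
//   DeriveNodeSize({"10.0.0.1", "10.0.0.1", "10.0.0.2", "10.0.0.2"}) == 2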
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_CTRL_BOOTSTRAP_H_ #define ONEFLOW_CORE_CONTROL_CTRL_BOOTSTRAP_H_ #include "oneflow/core/control/ctrl_bootstrap.pb.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/common/maybe.h" namespace oneflow { class ProcessCtx; class WorkerProcessInfo; class BootstrapServer; class BootstrapClient; class CtrlBootstrap { public: virtual ~CtrlBootstrap() {} Maybe InitProcessCtx(int64_t port, ProcessCtx* ret_process_ctx); protected: virtual int64_t rank() const = 0; virtual int64_t world_size() const = 0; virtual Maybe SetHostByMaster(Address*, int64_t world_rank) const = 0; virtual Maybe SetCurrentHostByMaster(WorkerProcessInfo*) const = 0; virtual Maybe SetCurrentHostByWorker(WorkerProcessInfo*) const = 0; virtual Maybe SetNodeSize(ProcessCtx* process_ctx) const = 0; virtual BootstrapServer* mut_bootstrap_server() = 0; virtual BootstrapClient* mut_bootstrap_client() = 0; CtrlBootstrap() = default; }; class HostListBootstrapServer; class HostListBootstrapClient; class HostListCtrlBootstrap final : public CtrlBootstrap { public: explicit HostListCtrlBootstrap(const EnvDesc& env_desc); ~HostListCtrlBootstrap() override; private: int64_t rank() const override { return rank_; } int64_t world_size() const override { return world_size_; } std::string host() const { return host_; } Maybe SetHostByMaster(Address*, int64_t world_rank) const override; Maybe SetCurrentHostByMaster(WorkerProcessInfo*) const override; Maybe SetCurrentHostByWorker(WorkerProcessInfo*) const override; Maybe SetNodeSize(ProcessCtx* process_ctx) const override; BootstrapServer* mut_bootstrap_server() override; BootstrapClient* mut_bootstrap_client() override; // Uses shared_ptr and forward declaration to avoid `#include ...` std::shared_ptr bootstrap_server_; std::shared_ptr bootstrap_client_; std::string host_; int64_t rank_; int64_t world_size_; }; class RankInfoBootstrapServer; class RankInfoBootstrapClient; class RankInfoCtrlBootstrap final : public CtrlBootstrap { public: explicit RankInfoCtrlBootstrap(const BootstrapConf& bootstrap_conf); ~RankInfoCtrlBootstrap() override; private: int64_t rank() const override { return rank_; } int64_t world_size() const override { return world_size_; } Maybe SetHostByMaster(Address*, int64_t world_rank) const override; Maybe SetCurrentHostByMaster(WorkerProcessInfo*) const override; Maybe SetCurrentHostByWorker(WorkerProcessInfo*) const override; Maybe SetNodeSize(ProcessCtx* process_ctx) const override; BootstrapServer* mut_bootstrap_server() override; BootstrapClient* mut_bootstrap_client() override; // Uses shared_ptr and forward declaration to avoid `#include ...` std::shared_ptr bootstrap_server_; std::shared_ptr bootstrap_client_; std::string master_host_; BootstrapConf bootstrap_conf_; int64_t rank_; int64_t world_size_; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_CTRL_BOOTSTRAP_H_ ================================================ FILE: oneflow/core/control/ctrl_bootstrap.proto ================================================ syntax = "proto2"; package oneflow; message Address { required string host = 1; required int32 port = 2; } message ProcessCtx { repeated Address ctrl_addr = 1; required int64 rank = 2; required int64 node_size = 3; } message BootstrapConf { required Address master_addr = 1; required int64 rank = 2; required int64 world_size = 3; optional string host = 4; optional int32 ctrl_port = 5 [default = -1]; optional 
int64 node_size = 6 [default = -1]; } ================================================ FILE: oneflow/core/control/ctrl_call.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_CTRL_CALL_H_ #define ONEFLOW_CORE_CONTROL_CTRL_CALL_H_ #include "oneflow/core/control/ctrl_service.h" namespace oneflow { class CtrlCallIf { public: OF_DISALLOW_COPY_AND_MOVE(CtrlCallIf); virtual ~CtrlCallIf() = default; virtual void Process() = 0; virtual void SendResponse() = 0; protected: CtrlCallIf() = default; private: }; template class CtrlCall final : public CtrlCallIf { public: OF_DISALLOW_COPY_AND_MOVE(CtrlCall); CtrlCall() : status_(Status::kBeforeHandleRequest), responder_(&server_ctx_) {} ~CtrlCall() = default; static constexpr const size_t value = (size_t)ctrl_method; const CtrlRequest& request() const { return request_; } CtrlRequest* mut_request() { return &request_; } CtrlResponse* mut_response() { return &response_; } grpc::ServerContext* mut_server_ctx() { return &server_ctx_; } const grpc::ServerContext& server_ctx() const { return server_ctx_; } grpc::ServerAsyncResponseWriter>* mut_responder() { return &responder_; } void set_request_handler(std::function val) { request_handler_ = val; } void Process() override { switch (status_) { case Status::kBeforeHandleRequest: { request_handler_(); return; } case Status::kBeforeDelete: { delete this; return; } } } void SendResponse() override { responder_.Finish(response_, grpc::Status::OK, this); status_ = Status::kBeforeDelete; } private: enum class Status { kBeforeHandleRequest, kBeforeDelete }; Status status_; CtrlRequest request_; CtrlResponse response_; grpc::ServerContext server_ctx_; grpc::ServerAsyncResponseWriter> responder_; std::function request_handler_; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_CTRL_CALL_H_ ================================================ FILE: oneflow/core/control/ctrl_client.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
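A sketch of filling the BootstrapConf message defined in ctrl_bootstrap.proto above, assuming the standard accessors protoc generates for proto2 messages; the address, port, and sizes are placeholders:

#include <cstdint>

#include "oneflow/core/control/ctrl_bootstrap.pb.h"

namespace oneflow {

BootstrapConf MakeExampleBootstrapConf(int64_t rank) {
  BootstrapConf conf;
  conf.mutable_master_addr()->set_host("192.168.1.100");  // placeholder address
  conf.mutable_master_addr()->set_port(49155);            // placeholder port
  conf.set_rank(rank);
  conf.set_world_size(8);
  // Optional fields such as host, ctrl_port and node_size keep their defaults
  // unless explicitly set; RankInfoCtrlBootstrap then derives node_size itself.
  return conf;
}

}  // namespace oneflow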
*/ #include "oneflow/core/control/ctrl_client.h" namespace oneflow { namespace { #define GRPC_CHECK(x) CHECK_EQ(x.error_code(), grpc::StatusCode::OK) } // namespace GrpcCtrlClient::~GrpcCtrlClient() { StopHeartbeat(); } GrpcCtrlClient::GrpcCtrlClient(const ProcessCtx& process_ctx) : process_ctx_(process_ctx) { rpc_client_.ReserveStubsOfSize(process_ctx.ctrl_addr_size()); for (int64_t i = 0; i < process_ctx.ctrl_addr_size(); ++i) { const Address& address = process_ctx.ctrl_addr(i); auto new_stub = CtrlService::NewStub(address.host() + ":" + std::to_string(address.port())); rpc_client_.AddStub(std::move(new_stub)); rpc_client_.LoadServer(address.host(), rpc_client_.GetStubAt(i)); } need_heartbeat_thread_stop_ = false; heartbeat_thread_ = std::thread([this]() { std::mt19937 gen(NewRandomSeed()); std::uniform_int_distribution sleep_second_dis(7, 13); LoadServerRequest request; LoadServerResponse response; while (true) { const auto wait_duration = std::chrono::seconds(sleep_second_dis(gen)); { std::unique_lock lck(need_heartbeat_thread_stop_mtx_); const bool stopped = need_heartbeat_thread_stop_cv_.wait_for( lck, wait_duration, [&]() { return need_heartbeat_thread_stop_; }); if (stopped) { break; } } for (size_t i = 0; i < rpc_client_.GetStubSize(); ++i) { grpc::ClientContext client_ctx; request.set_addr(this->process_ctx().ctrl_addr(i).host()); GRPC_CHECK(rpc_client_.GetStubAt(i)->CallMethod( &client_ctx, request, &response)) << "Machine " << i << " lost"; } } }); } void GrpcCtrlClient::Barrier(const std::string& barrier_name) { rpc_client_.Barrier(barrier_name); } void GrpcCtrlClient::Barrier(const std::string& barrier_name, int32_t barrier_num) { rpc_client_.Barrier(barrier_name, barrier_num); } TryLockResult GrpcCtrlClient::TryLock(const std::string& name) { return rpc_client_.TryLock(name); } void GrpcCtrlClient::NotifyDone(const std::string& name) { rpc_client_.NotifyDone(name); } void GrpcCtrlClient::WaitUntilDone(const std::string& name) { rpc_client_.WaitUntilDone(name); } void GrpcCtrlClient::PushKV(const std::string& k, const std::string& v) { rpc_client_.PushKV(k, v); } void GrpcCtrlClient::PushKV(const std::string& k, const PbMessage& msg) { rpc_client_.PushKV(k, msg); } void GrpcCtrlClient::PushKV(const std::string& k, std::function VSetter) { rpc_client_.PushKV(k, VSetter); } void GrpcCtrlClient::PushMasterKV(const std::string& k, const PbMessage& msg) { rpc_client_.PushMasterKV(k, msg); } void GrpcCtrlClient::ClearKV(const std::string& k) { rpc_client_.ClearKV(k); } void GrpcCtrlClient::ClearMasterKV(const std::string& k) { rpc_client_.ClearMasterKV(k); } void GrpcCtrlClient::PullKV(const std::string& k, std::string* v) { rpc_client_.PullKV(k, v); } void GrpcCtrlClient::PullKV(const std::string& k, PbMessage* msg) { rpc_client_.PullKV(k, msg); } void GrpcCtrlClient::PullKV(const std::string& k, std::function VGetter) { rpc_client_.PullKV(k, VGetter); } void GrpcCtrlClient::PullMasterKV(const std::string& k, PbMessage* msg) { rpc_client_.PullMasterKV(k, msg); } void GrpcCtrlClient::Clear() { rpc_client_.Clear(); } int32_t GrpcCtrlClient::IncreaseCount(const std::string& k, int32_t v) { return rpc_client_.IncreaseCount(k, v); } void GrpcCtrlClient::EraseCount(const std::string& k) { rpc_client_.EraseCount(k); } void GrpcCtrlClient::StopHeartbeat() { bool already_stopped = false; { std::unique_lock lck(need_heartbeat_thread_stop_mtx_); already_stopped = need_heartbeat_thread_stop_; need_heartbeat_thread_stop_ = true; need_heartbeat_thread_stop_cv_.notify_all(); } if 
(!already_stopped) { heartbeat_thread_.join(); } } } // namespace oneflow ================================================ FILE: oneflow/core/control/ctrl_client.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_ #define ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_ #include "oneflow/core/rpc/include/ctrl.h" #endif // ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_ ================================================ FILE: oneflow/core/control/ctrl_server.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/control/ctrl_server.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" #include "oneflow/core/job/env_desc.h" #include "grpc/grpc_posix.h" namespace oneflow { CtrlServer::CtrlServer(int ctrl_port) : RpcServer(), port_(ctrl_port) { Init(); grpc::ServerBuilder server_builder; server_builder.SetMaxMessageSize(INT_MAX); int bound_port = 0; server_builder.AddListeningPort("0.0.0.0:" + std::to_string(port_), grpc::InsecureServerCredentials(), &bound_port); grpc_service_.reset(new CtrlService::AsyncService); server_builder.RegisterService(grpc_service_.get()); cq_ = server_builder.AddCompletionQueue(); grpc_server_ = server_builder.BuildAndStart(); if (port() != 0) { CHECK_EQ(port(), bound_port) << "Port " << port() << " is unavailable"; } else { port_ = bound_port; CHECK_NE(port(), 0); } LOG(INFO) << "CtrlServer listening on " << "0.0.0.0:" + std::to_string(port()); loop_thread_ = std::thread(&CtrlServer::HandleRpcs, this); } CtrlServer::CtrlServer() : CtrlServer(0) {} void CtrlServer::OnLoadServer(CtrlCall* call) { call->SendResponse(); EnqueueRequest(); } } // namespace oneflow ================================================ FILE: oneflow/core/control/ctrl_server.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
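A sketch of the key/value-plus-barrier pattern that GrpcCtrlClient above exposes (and that CtrlBootstrap::InitProcessCtx builds on); the key names and payload are illustrative:

#include <cstdint>
#include <string>

#include "oneflow/core/control/ctrl_client.h"

namespace oneflow {

void Example(GrpcCtrlClient* client, int64_t rank) {
  if (rank == 0) {
    // Publish a small payload under a key chosen for this sketch.
    client->PushKV("example_config", std::string("payload-from-rank-0"));
  }
  // Every rank blocks here until all ranks reach the same named barrier.
  client->Barrier("example_barrier");
  if (rank != 0) {
    std::string payload;
    client->PullKV("example_config", &payload);
  }
}

}  // namespace oneflow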
*/ #ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_ #define ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_ #ifdef RPC_BACKEND_GRPC #include "oneflow/core/control/rpc_server.h" namespace oneflow { class CtrlServer final : public RpcServer { public: OF_DISALLOW_COPY_AND_MOVE(CtrlServer); ~CtrlServer() override {} CtrlServer(); // port may be configured in bootstrap_conf CtrlServer(int ctrl_port); int64_t port() const { return port_; } private: void OnLoadServer(CtrlCall* call) override; int port_; }; } // namespace oneflow #endif // RPC_BACKEND_GRPC #endif // ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_ ================================================ FILE: oneflow/core/control/ctrl_service.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/control/ctrl_service.h" namespace oneflow { namespace { template const grpc::internal::RpcMethod BuildOneRpcMethod(std::shared_ptr channel) { return grpc::internal::RpcMethod(GetMethodName(static_cast(method_index)), grpc::internal::RpcMethod::NORMAL_RPC, channel); } template std::array BuildRpcMethods( std::index_sequence, std::shared_ptr channel) { return {BuildOneRpcMethod(channel)...}; } constexpr int64_t kDefaultGrpcMaxMessageByteSize = -1; } // namespace CtrlService::Stub::Stub(std::shared_ptr channel) : rpcmethods_(BuildRpcMethods(std::make_index_sequence{}, channel)), channel_(channel) {} std::unique_ptr CtrlService::NewStub(const std::string& addr) { grpc::ChannelArguments ch_args; int64_t max_msg_byte_size = ParseIntegerFromEnv("ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE", kDefaultGrpcMaxMessageByteSize); ch_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, max_msg_byte_size); return std::make_unique( grpc::CreateCustomChannel(addr, grpc::InsecureChannelCredentials(), ch_args)); } CtrlService::AsyncService::AsyncService() { for (int32_t i = 0; i < kCtrlMethodNum; ++i) { AddMethod(new grpc::internal::RpcServiceMethod(GetMethodName(static_cast(i)), grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); grpc::Service::MarkMethodAsync(i); } } } // namespace oneflow ================================================ FILE: oneflow/core/control/ctrl_service.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
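A sketch of obtaining a control-service stub via CtrlService::NewStub as defined above; the address is a placeholder:

#include <memory>
#include <string>

#include "oneflow/core/control/ctrl_service.h"

namespace oneflow {

void Example() {
  // NewStub builds an insecure channel whose maximum message length comes from
  // ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE (passed to GRPC_ARG_MAX_MESSAGE_LENGTH,
  // -1 when the variable is unset).
  std::unique_ptr<CtrlService::Stub> stub = CtrlService::NewStub("127.0.0.1:49155");
  // Requests are then issued through stub->CallMethod<...>(), as the heartbeat
  // loop in GrpcCtrlClient does.
  (void)stub;
}

}  // namespace oneflow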
*/ #ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_ #define ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_ #include #include #include #include #include #include #include #include #include #include #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/util.h" #include "oneflow/core/control/control.pb.h" #include "oneflow/core/rpc/include/base.h" namespace oneflow { class CtrlService final { public: class Stub final { public: Stub(std::shared_ptr channel); template grpc::Status CallMethod(grpc::ClientContext* context, const CtrlRequest& request, CtrlResponse* response) { return grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethods_.at(static_cast(ctrl_method)), context, request, response); } private: std::array rpcmethods_; std::shared_ptr channel_; }; static std::unique_ptr NewStub(const std::string& addr); class AsyncService final : public grpc::Service { public: AsyncService(); ~AsyncService() = default; using grpc::Service::RequestAsyncUnary; }; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_ ================================================ FILE: oneflow/core/control/ctrl_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/job/env.pb.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/ctrl_server.h" #include "oneflow/core/control/ctrl_bootstrap.h" #include "oneflow/core/control/ctrl_util.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #ifdef OF_PLATFORM_POSIX #include #include #include #include #include namespace oneflow { namespace { EnvProto GetEnvProto(int port) { EnvProto ret; auto* machine0 = ret.add_machine(); machine0->set_id(0); machine0->set_addr("127.0.0.1"); ret.set_ctrl_port(port); return ret; } Resource GetResource() { Resource ret; ret.set_machine_num(1); ret.set_cpu_device_num(1); ret.set_comm_net_worker_num(1); return ret; } } // namespace #ifdef RPC_BACKEND_GRPC TEST(CtrlServer, new_delete) { int port = CtrlUtil().FindAvailablePort(); if (port == -1) { return; } EnvProto env_proto = GetEnvProto(port); Singleton::New(env_proto); Singleton::New(); Singleton::New(); CHECK_JUST( HostListCtrlBootstrap(*Singleton::Get()) .InitProcessCtx(Singleton::Get()->port(), Singleton::Get())); auto* client = new GrpcCtrlClient(*Singleton::Get()); Singleton::SetAllocated(client); Singleton::New(GetResource(), GlobalProcessCtx::NumOfProcessPerNode()); Singleton::New(GetResource(), GlobalProcessCtx::NumOfProcessPerNode()); // do test // OF_ENV_BARRIER(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); } #endif // RPC_BACKEND_GRPC } // namespace oneflow #endif // OF_PLATFORM_POSIX ================================================ FILE: oneflow/core/control/ctrl_util.cpp ================================================ /* Copyright 2020 
The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/platform.h" #ifdef OF_PLATFORM_POSIX #include #include #include #include #include #endif // OF_PLATFORM_POSIX #include "oneflow/core/control/ctrl_util.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" namespace oneflow { #ifdef OF_PLATFORM_POSIX namespace { sockaddr_in GetSockAddr(const std::string& addr, uint16_t port) { sockaddr_in sa; sa.sin_family = AF_INET; sa.sin_port = htons(port); PCHECK(inet_pton(AF_INET, addr.c_str(), &(sa.sin_addr)) == 1); return sa; } } // namespace int CtrlUtil::FindAvailablePort() const { int sock = socket(AF_INET, SOCK_STREAM, 0); for (uint16_t port = 10000; port < GetMaxVal(); ++port) { sockaddr_in sa = GetSockAddr("0.0.0.0", port); int bind_result = bind(sock, reinterpret_cast(&sa), sizeof(sa)); if (bind_result == 0) { shutdown(sock, SHUT_RDWR); close(sock); return port; } } return -1; } #else int CtrlUtil::FindAvailablePort() const { UNIMPLEMENTED(); } #endif // OF_PLATFORM_POSIX } // namespace oneflow ================================================ FILE: oneflow/core/control/ctrl_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_CTR_TEST_H_ #define ONEFLOW_CORE_CONTROL_CTR_TEST_H_ namespace oneflow { class CtrlUtil { public: CtrlUtil() = default; ~CtrlUtil() = default; int FindAvailablePort() const; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_CTR_TEST_H_ ================================================ FILE: oneflow/core/control/global_process_ctx.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CONTROL_GLOBAL_PROCESS_CTX_H_ #define ONEFLOW_CORE_CONTROL_GLOBAL_PROCESS_CTX_H_ #include "oneflow/core/rpc/include/global_process_ctx.h" #endif // ONEFLOW_CORE_CONTROL_GLOBAL_PROCESS_CTX_H_ ================================================ FILE: oneflow/core/control/host_list_bootstrap_client.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/control/host_list_bootstrap_client.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { HostListBootstrapClient::HostListBootstrapClient(const EnvDesc& env_desc) { stubs_.reserve(env_desc.TotalMachineNum()); int32_t port = -1; std::string addr = ""; for (int64_t i = 0; i < env_desc.TotalMachineNum(); ++i) { const Machine& mchn = env_desc.machine(i); port = (mchn.ctrl_port_agent() != -1) ? (mchn.ctrl_port_agent()) : env_desc.ctrl_port(); addr = mchn.addr() + ":" + std::to_string(port); stubs_.emplace_back(CtrlService::NewStub(addr)); LoadServer(mchn.addr(), stubs_[i].get()); } } } // namespace oneflow ================================================ FILE: oneflow/core/control/host_list_bootstrap_client.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_CLIENT_H_ #define ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_CLIENT_H_ #include "oneflow/core/control/bootstrap_client.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { class HostListBootstrapClient final : public BootstrapClient { public: OF_DISALLOW_COPY_AND_MOVE(HostListBootstrapClient); ~HostListBootstrapClient() override = default; HostListBootstrapClient(const EnvDesc& env_desc); }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_CLIENT_H_ ================================================ FILE: oneflow/core/control/host_list_bootstrap_server.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/control/host_list_bootstrap_server.h" #include "grpc/grpc_posix.h" namespace oneflow { HostListBootstrapServer::HostListBootstrapServer(const EnvDesc& env_desc) : BootstrapServer(), is_first_connect_(true), this_machine_addr_("") { Init(); int port = env_desc.ctrl_port(); grpc::ServerBuilder server_builder; server_builder.SetMaxMessageSize(INT_MAX); int bound_port = 0; server_builder.AddListeningPort("0.0.0.0:" + std::to_string(port), grpc::InsecureServerCredentials(), &bound_port); grpc_service_.reset(new CtrlService::AsyncService); server_builder.RegisterService(grpc_service_.get()); cq_ = server_builder.AddCompletionQueue(); grpc_server_ = server_builder.BuildAndStart(); CHECK_EQ(port, bound_port) << "Port " << port << " is unavailable"; LOG(INFO) << "HostListBootstrapServer listening on " << "0.0.0.0:" + std::to_string(port); loop_thread_ = std::thread(&HostListBootstrapServer::HandleRpcs, this); } void HostListBootstrapServer::OnLoadServer(CtrlCall* call) { if (this->is_first_connect_) { this->this_machine_addr_ = call->request().addr(); this->is_first_connect_ = false; } else { CHECK_EQ(call->request().addr(), this->this_machine_addr_); } call->SendResponse(); EnqueueRequest(); } } // namespace oneflow ================================================ FILE: oneflow/core/control/host_list_bootstrap_server.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_SERVER_H_ #define ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_SERVER_H_ #include "oneflow/core/control/bootstrap_server.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { class HostListBootstrapServer final : public BootstrapServer { public: OF_DISALLOW_COPY_AND_MOVE(HostListBootstrapServer); ~HostListBootstrapServer() override = default; HostListBootstrapServer(const EnvDesc& env_desc); const std::string& this_machine_addr() { return this_machine_addr_; } private: void OnLoadServer(CtrlCall* call) override; bool is_first_connect_; std::string this_machine_addr_; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_SERVER_H_ ================================================ FILE: oneflow/core/control/rank_info_bootstrap_client.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/control/rank_info_bootstrap_client.h" namespace oneflow { RankInfoBootstrapClient::RankInfoBootstrapClient(const BootstrapConf& bootstrap_conf) { stubs_.reserve(bootstrap_conf.world_size()); const auto& master_addr = bootstrap_conf.master_addr(); const std::string& host = master_addr.host() + ":" + std::to_string(master_addr.port()); stubs_.emplace_back(CtrlService::NewStub(host)); LoadServerRequest request; request.set_addr(master_addr.host()); request.set_rank(bootstrap_conf.rank()); LoadServer(request, stubs_[0].get()); } } // namespace oneflow ================================================ FILE: oneflow/core/control/rank_info_bootstrap_client.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_CLIENT_H_ #define ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_CLIENT_H_ #include "oneflow/core/control/bootstrap_client.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { class RankInfoBootstrapClient final : public BootstrapClient { public: OF_DISALLOW_COPY_AND_MOVE(RankInfoBootstrapClient); ~RankInfoBootstrapClient() override = default; RankInfoBootstrapClient(const BootstrapConf& bootstrap_conf); }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_CLIENT_H_ ================================================ FILE: oneflow/core/control/rank_info_bootstrap_server.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include "grpc/grpc_posix.h" #include "oneflow/core/common/env_var/bootstrap.h" #include "oneflow/core/control/rank_info_bootstrap_server.h" namespace oneflow { namespace { std::string GetHostFromUri(const std::string& uri) { size_t first_delimiter_pos = uri.find(":"); CHECK_NE(first_delimiter_pos, std::string::npos); const std::string& protocol_family = uri.substr(0, first_delimiter_pos); CHECK_EQ(protocol_family, "ipv4"); size_t second_delimiter_pos = uri.rfind(":"); return uri.substr(first_delimiter_pos + 1, second_delimiter_pos - first_delimiter_pos - 1); } int64_t rpc_bootstrap_server_sleep_seconds() { static const int64_t rpc_bootstrap_server_sleep_seconds = EnvInteger(); return rpc_bootstrap_server_sleep_seconds; } int64_t rpc_bootstrap_server_max_retry_times() { static const int64_t rpc_bootstrap_server_max_retry_times = EnvInteger(); return rpc_bootstrap_server_max_retry_times; } } // namespace RankInfoBootstrapServer::RankInfoBootstrapServer(const BootstrapConf& bootstrap_conf) : BootstrapServer(), port_(0), world_size_(bootstrap_conf.world_size()) { Init(); const int64_t rank = bootstrap_conf.rank(); int p = (rank == 0 ? bootstrap_conf.master_addr().port() : 0); grpc::ServerBuilder server_builder; server_builder.SetMaxMessageSize(INT_MAX); server_builder.AddListeningPort("0.0.0.0:" + std::to_string(p), grpc::InsecureServerCredentials(), &port_); grpc_service_.reset(new CtrlService::AsyncService); server_builder.RegisterService(grpc_service_.get()); cq_ = server_builder.AddCompletionQueue(); grpc_server_ = server_builder.BuildAndStart(); if (rank == 0) { CHECK_EQ(p, port()) << "Port " << p << " is unavailable"; } LOG(INFO) << "RankInfoBootstrapServer listening on " << "0.0.0.0:" + std::to_string(port()); loop_thread_ = std::thread(&RankInfoBootstrapServer::HandleRpcs, this); if (rank == 0) { rank2host_ = std::make_shared>(world_size_, ""); // NOTE: use check_thread_ to check RankInfoBootstrapServer status on rank 0 // if size of ready ranks == total ranks(world_size), means status is ok. // otherwise, it indicates that other ranks' server have not been created successfully! check_thread_ = std::thread(&RankInfoBootstrapServer::CheckServerStatus, this); } } void RankInfoBootstrapServer::CheckServerStatus() { bool status_ok = false; int64_t skip_warning_times = 1; int64_t retry_idx = 0; // lambda function to get valid rank num of rank2host_ auto GetValidRank2HostSize = [](const std::shared_ptr>& rank2host) { int64_t valid_size = 0; for (int64_t i = 0; i < rank2host->size(); ++i) { if (rank2host->at(i) == "") { continue; } valid_size += 1; } return valid_size; }; for (; retry_idx < rpc_bootstrap_server_max_retry_times(); ++retry_idx) { std::this_thread::sleep_for(std::chrono::seconds(rpc_bootstrap_server_sleep_seconds())); int64_t valid_size = 0; { std::lock_guard lock(lock_); valid_size = GetValidRank2HostSize(rank2host_); } CHECK(valid_size <= world_size_); if (valid_size == world_size_) { status_ok = true; break; } else { if (retry_idx >= skip_warning_times) { LOG(WARNING) << "BootstrapServer not ready, rpc server on some rank have not been created " "successfully. 
Failed at " << retry_idx + 1 << " times, total ranks(world_size): " << world_size_ << ", ready ranks: " << valid_size; } } } if (!status_ok) { LOG(FATAL) << "CheckServerStatus() failed, rpc server on some rank are not ready, please check " "whether the processes on all ranks are " "created successfully."; } } Maybe&> RankInfoBootstrapServer::rank2host() const { CHECK_NOTNULL(rank2host_.get()); return *rank2host_; } void RankInfoBootstrapServer::OnLoadServer(CtrlCall* call) { int64_t rank = call->request().rank(); CHECK_GE(rank, 0); CHECK_LT(rank, world_size_); if (!rank2host_) { rank2host_ = std::make_shared>(world_size_); } std::lock_guard lock(lock_); rank2host_->at(rank) = GetHostFromUri(call->server_ctx().peer()); call->SendResponse(); EnqueueRequest(); } } // namespace oneflow ================================================ FILE: oneflow/core/control/rank_info_bootstrap_server.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_SERVER_H_ #define ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_SERVER_H_ #include "oneflow/core/control/bootstrap_server.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/common/maybe.h" namespace oneflow { class RankInfoBootstrapServer final : public BootstrapServer { public: OF_DISALLOW_COPY_AND_MOVE(RankInfoBootstrapServer); ~RankInfoBootstrapServer() override { if (check_thread_.joinable()) { check_thread_.join(); } } RankInfoBootstrapServer(const BootstrapConf& bootstrap_conf); int64_t port() const { return port_; } Maybe&> rank2host() const; private: void OnLoadServer(CtrlCall* call) override; void CheckServerStatus(); int port_; const int64_t world_size_; std::mutex lock_; std::thread check_thread_; // use std::shared_ptr as std::optional std::shared_ptr> rank2host_; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_SERVER_H_ ================================================ FILE: oneflow/core/control/rpc_client.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/control/rpc_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/common/env_var/bootstrap.h" namespace oneflow { namespace { int64_t rpc_client_max_retry_times() { static const int64_t rpc_client_max_retry_times = EnvInteger(); return rpc_client_max_retry_times; } int64_t rpc_client_sleep_seconds() { static const int64_t rpc_client_sleep_seconds = EnvInteger(); return rpc_client_sleep_seconds; } #define GRPC_CHECK(x) CHECK_EQ(x.error_code(), grpc::StatusCode::OK) template class ClientCall final { public: OF_DISALLOW_COPY_AND_MOVE(ClientCall); ClientCall() = default; ~ClientCall() = default; CtrlRequest* mut_request() { return &request_; } const CtrlResponse& response() const { return response_; } void operator()(CtrlService::Stub* stub) { grpc::ClientContext client_ctx; GRPC_CHECK(stub->CallMethod(&client_ctx, request_, &response_)); } private: CtrlRequest request_; CtrlResponse response_; }; } // namespace void RpcClient::Barrier(const std::string& barrier_name) { Barrier(barrier_name, Singleton::Get()->TotalMachineNum()); } void RpcClient::Barrier(const std::string& barrier_name, int32_t barrier_num) { ClientCall call; call.mut_request()->set_name(barrier_name); call.mut_request()->set_num(barrier_num); call(GetMasterStub()); } TryLockResult RpcClient::TryLock(const std::string& name) { { std::unique_lock lck(done_names_mtx_); if (done_names_.find(name) != done_names_.end()) { return TryLockResult::kDone; } } ClientCall call; call.mut_request()->set_name(name); call(GetResponsibleStub(name)); if (call.response().result() == TryLockResult::kDone) { std::unique_lock lck(done_names_mtx_); done_names_.insert(name); } return call.response().result(); } void RpcClient::NotifyDone(const std::string& name) { ClientCall call; call.mut_request()->set_name(name); call(GetResponsibleStub(name)); } void RpcClient::WaitUntilDone(const std::string& name) { ClientCall call; call.mut_request()->set_name(name); call(GetResponsibleStub(name)); } void RpcClient::PushKV(const std::string& k, std::function VSetter) { ClientCall call; call.mut_request()->set_key(k); VSetter(call.mut_request()->mutable_val()); call(GetResponsibleStub(k)); } void RpcClient::PushMasterKV(const std::string& k, std::function VSetter) { ClientCall call; call.mut_request()->set_key(k); VSetter(call.mut_request()->mutable_val()); call(GetMasterStub()); } void RpcClient::PushKV(const std::string& k, const std::string& v) { PushKV(k, [&](std::string* o) { *o = v; }); } void RpcClient::PushKV(const std::string& k, const PbMessage& msg) { PushKV(k, [&](std::string* o) { msg.SerializeToString(o); }); } void RpcClient::PushMasterKV(const std::string& k, const PbMessage& msg) { PushMasterKV(k, [&](std::string* o) { msg.SerializeToString(o); }); } void RpcClient::ClearKV(const std::string& k) { ClientCall call; call.mut_request()->set_key(k); call(GetResponsibleStub(k)); } void RpcClient::ClearMasterKV(const std::string& k) { ClientCall call; call.mut_request()->set_key(k); call(GetMasterStub()); } void RpcClient::PullKV(const std::string& k, std::function VGetter) { ClientCall call; call.mut_request()->set_key(k); call(GetResponsibleStub(k)); VGetter(call.response().val()); } void RpcClient::PullMasterKV(const std::string& k, std::function VGetter) { ClientCall call; call.mut_request()->set_key(k); call(GetMasterStub()); VGetter(call.response().val()); } void RpcClient::PullKV(const std::string& k, std::string* v) { PullKV(k, [&](const 
std::string& i) { *v = i; }); } void RpcClient::PullKV(const std::string& k, PbMessage* msg) { PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); }); } void RpcClient::PullMasterKV(const std::string& k, PbMessage* msg) { PullMasterKV(k, [&](const std::string& i) { msg->ParseFromString(i); }); } void RpcClient::Clear() { ClientCall call; call(GetThisStub()); std::unique_lock lck(done_names_mtx_); done_names_.clear(); } int32_t RpcClient::IncreaseCount(const std::string& k, int32_t v) { ClientCall call; call.mut_request()->set_key(k); call.mut_request()->set_val(v); call(GetResponsibleStub(k)); return call.response().val(); } void RpcClient::EraseCount(const std::string& k) { ClientCall call; call.mut_request()->set_key(k); call(GetResponsibleStub(k)); } void RpcClient::LoadServer(const std::string& server_addr, CtrlService::Stub* stub) { LoadServerRequest request; request.set_addr(server_addr); return LoadServer(request, stub); } void RpcClient::LoadServer(const LoadServerRequest& request, CtrlService::Stub* stub) { int32_t retry_idx = 0; int32_t skip_warning_times = 3; for (; retry_idx < rpc_client_max_retry_times(); ++retry_idx) { grpc::ClientContext client_ctx; LoadServerResponse response; grpc::Status st = stub->CallMethod(&client_ctx, request, &response); if (st.error_code() == grpc::StatusCode::OK) { VLOG(3) << "LoadServer " << request.addr() << " Successful at " << retry_idx + 1 << " times"; break; } else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) { if (retry_idx >= skip_warning_times) { LOG(WARNING) << "LoadServer " << request.addr() << " Failed at " << retry_idx + 1 << " times" << " error_code: " << st.error_code() << " error_message: " << st.error_message(); } std::this_thread::sleep_for(std::chrono::seconds(rpc_client_sleep_seconds())); continue; } else { LOG(FATAL) << st.error_message(); } } CHECK_LT(retry_idx, rpc_client_max_retry_times()); } CtrlService::Stub* RpcClient::GetThisStub() { return stubs_[GlobalProcessCtx::Rank()].get(); } CtrlService::Stub* RpcClient::GetResponsibleStub(const std::string& key) { int64_t machine_id = (std::hash{}(key)) % Singleton::Get()->TotalMachineNum(); return stubs_[machine_id].get(); } } // namespace oneflow ================================================ FILE: oneflow/core/control/rpc_client.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_ #define ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_ #include "oneflow/core/lazy/actor/actor_message.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/control/ctrl_service.h" #include "oneflow/core/job/global_for.h" namespace oneflow { class RpcClient { public: OF_DISALLOW_COPY_AND_MOVE(RpcClient); RpcClient() = default; virtual ~RpcClient() = default; void Barrier(const std::string& barrier_name); void Barrier(const std::string& barrier_name, int32_t barrier_num); TryLockResult TryLock(const std::string& name); void NotifyDone(const std::string& name); void WaitUntilDone(const std::string& name); void PushKV(const std::string& k, std::function VSetter); void PushKV(const std::string& k, const std::string& v); void PushKV(const std::string& k, const PbMessage& msg); void PushMasterKV(const std::string& k, const PbMessage& msg); template typename std::enable_if::value>::type PushKVT(const std::string& k, T v) { PushKV(k, std::to_string(v)); } void ClearKV(const std::string& k); void ClearMasterKV(const std::string& k); void PullKV(const std::string& k, std::function VGetter); void PullKV(const std::string& k, std::string* v); void PullKV(const std::string& k, PbMessage* msg); void PullMasterKV(const std::string& k, PbMessage* msg); template typename std::enable_if::value>::type PullKVT(const std::string& k, T* v) { std::string v_str; PullKV(k, &v_str); *v = oneflow_cast(v_str); } void Clear(); int32_t IncreaseCount(const std::string& k, int32_t v); int32_t IncreaseCount(const std::string& k) { return IncreaseCount(k, 1); } void EraseCount(const std::string& k); void LoadServer(const std::string& server_addr, CtrlService::Stub* stub); void LoadServer(const LoadServerRequest& request, CtrlService::Stub* stub); void PushMasterKV(const std::string& k, std::function VSetter); void PullMasterKV(const std::string& k, std::function VGetter); CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); } CtrlService::Stub* GetThisStub(); CtrlService::Stub* GetResponsibleStub(const std::string& key); CtrlService::Stub* GetStubAt(int64_t i) { return stubs_[i].get(); }; size_t GetStubSize() { return stubs_.size(); }; void ReserveStubsOfSize(int64_t n) { stubs_.reserve(n); }; void AddStub(std::unique_ptr s) { stubs_.emplace_back(std::move(s)); }; std::vector> stubs_; std::mutex done_names_mtx_; HashSet done_names_; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_ ================================================ FILE: oneflow/core/control/rpc_server.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/control/rpc_server.h" #include "oneflow/core/job/env_desc.h" #include "grpc/grpc_posix.h" namespace oneflow { RpcServer::~RpcServer() { // NOTE(chengcheng): This enqueues a special event (with a null tag) that causes // the completion queue to be shut down on the polling thread. 
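  // The alarm below is created with a deadline of "now" and a nullptr tag, so it fires
  // immediately and is the only completion-queue event that ever carries a null tag.
  // HandleRpcs() treats that null tag as the shutdown signal: it calls grpc_server_->Shutdown()
  // and cq_->Shutdown(), then keeps draining the queue so every outstanding CtrlCall tag comes
  // back with ok == false and is deleted before cq_->Next() finally returns false and the
  // polling thread exits, which makes the join() below safe.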
grpc::Alarm alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr); loop_thread_.join(); } void RpcServer::HandleRpcs() { EnqueueRequests(); void* tag = nullptr; bool ok = false; // NOTE(chengcheng): The is_shutdown bool flag make sure that 'ok = false' occurs ONLY after // cq_->Shutdown() for security check. bool is_shutdown = false; // NOTE(chengcheng): The final end is that cq_->Next() get false and cq_ is empty with no item. while (cq_->Next(&tag, &ok)) { auto call = static_cast(tag); if (!ok) { // NOTE(chengcheng): After call grpc_server_->Shutdown() and cq_->Shutdown(), // there will trigger some cancel tag items on each RPC. And cq_->Next() can get these tag // with ok = false. Then delete the tag with CtrlCallIf pointer for recovery. CHECK(is_shutdown); CHECK(call); delete call; continue; } if (call) { call->Process(); } else { // NOTE(chengcheng): A null `call` indicates that this is the shutdown alarm. CHECK(!is_shutdown); is_shutdown = true; grpc_server_->Shutdown(); cq_->Shutdown(); // NOTE(chengcheng): You CANNOT use code 'break;' in this block because that // there still be items in the cq_. // 'break;' } } } void RpcServer::Init() { Add([this](CtrlCall* call) { OnLoadServer(call); }); Add([this](CtrlCall* call) { const std::string& barrier_name = call->request().name(); int32_t barrier_num = call->request().num(); auto barrier_call_it = barrier_calls_.find(barrier_name); if (barrier_call_it == barrier_calls_.end()) { barrier_call_it = barrier_calls_ .emplace(barrier_name, std::make_pair(std::list{}, barrier_num)) .first; } CHECK_EQ(barrier_num, barrier_call_it->second.second) << barrier_name; barrier_call_it->second.first.emplace_back(call); if (barrier_call_it->second.first.size() == barrier_call_it->second.second) { for (CtrlCallIf* pending_call : barrier_call_it->second.first) { pending_call->SendResponse(); } barrier_calls_.erase(barrier_call_it); } EnqueueRequest(); }); Add([this](CtrlCall* call) { const std::string& lock_name = call->request().name(); auto name2lock_status_it = name2lock_status_.find(lock_name); if (name2lock_status_it == name2lock_status_.end()) { call->mut_response()->set_result(TryLockResult::kLocked); auto waiting_until_done_calls = new std::list; CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second); } else { if (name2lock_status_it->second) { call->mut_response()->set_result(TryLockResult::kDoing); } else { call->mut_response()->set_result(TryLockResult::kDone); } } call->SendResponse(); EnqueueRequest(); }); Add([this](CtrlCall* call) { const std::string& lock_name = call->request().name(); auto name2lock_status_it = name2lock_status_.find(lock_name); auto waiting_calls = static_cast*>(name2lock_status_it->second); for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); } delete waiting_calls; name2lock_status_it->second = nullptr; call->SendResponse(); EnqueueRequest(); }); Add([this](CtrlCall* call) { const std::string& lock_name = call->request().name(); void* lock_status = name2lock_status_.at(lock_name); if (lock_status) { auto waiting_calls = static_cast*>(lock_status); waiting_calls->emplace_back(call); } else { call->SendResponse(); } EnqueueRequest(); }); Add([this](CtrlCall* call) { const std::string& k = call->request().key(); const std::string& v = call->request().val(); CHECK(kv_.emplace(k, v).second); auto pending_kv_calls_it = pending_kv_calls_.find(k); if (pending_kv_calls_it != pending_kv_calls_.end()) { for (auto pending_call : pending_kv_calls_it->second) { 
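        // These are PullKV calls that arrived before the value for this key was pushed; now
        // that the key/value pair exists, answer each parked call with the freshly pushed value.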
pending_call->mut_response()->set_val(v); pending_call->SendResponse(); } pending_kv_calls_.erase(pending_kv_calls_it); } call->SendResponse(); EnqueueRequest(); }); Add([this](CtrlCall* call) { const std::string& k = call->request().key(); CHECK_EQ(kv_.erase(k), 1); CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end()); call->SendResponse(); EnqueueRequest(); }); Add([this](CtrlCall* call) { const std::string& k = call->request().key(); auto kv_it = kv_.find(k); if (kv_it != kv_.end()) { call->mut_response()->set_val(kv_it->second); call->SendResponse(); } else { pending_kv_calls_[k].emplace_back(call); } EnqueueRequest(); }); Add([this](CtrlCall* call) { name2lock_status_.clear(); kv_.clear(); CHECK(pending_kv_calls_.empty()) << "size(): " << pending_kv_calls_.size() << ", begin()->key: " << pending_kv_calls_.begin()->first; call->SendResponse(); EnqueueRequest(); }); Add([this](CtrlCall* call) { int32_t& count = count_[call->request().key()]; count += call->request().val(); call->mut_response()->set_val(count); call->SendResponse(); EnqueueRequest(); }); Add([this](CtrlCall* call) { CHECK_EQ(count_.erase(call->request().key()), 1); call->SendResponse(); EnqueueRequest(); }); } } // namespace oneflow ================================================ FILE: oneflow/core/control/rpc_server.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CONTROL_RPC_SERVER_H_ #define ONEFLOW_CORE_CONTROL_RPC_SERVER_H_ #include #include #include "oneflow/core/control/ctrl_call.h" #include "oneflow/core/common/function_traits.h" namespace oneflow { namespace { template static std::tuple*)>...> GetHandlerTuple( std::index_sequence) { return {}; } } // namespace class RpcServer { public: OF_DISALLOW_COPY_AND_MOVE(RpcServer); virtual ~RpcServer(); protected: RpcServer() {} void HandleRpcs(); void Init(); void EnqueueRequests() { for_each_i(handlers_, helper{this}, std::make_index_sequence{}); } template void EnqueueRequest() { constexpr const size_t I = (size_t)kMethod; auto handler = std::get(handlers_); auto call = new CtrlCall<(CtrlMethod)I>(); call->set_request_handler(std::bind(handler, call)); grpc_service_->RequestAsyncUnary(I, call->mut_server_ctx(), call->mut_request(), call->mut_responder(), cq_.get(), cq_.get(), call); } template void Add(F f) { using args_type = typename function_traits::args_type; using arg_type = typename std::remove_pointer::type>::type; std::get(handlers_) = std::move(f); } virtual void OnLoadServer(CtrlCall* call) = 0; struct helper { helper(RpcServer* s) : s_(s) {} template void operator()(const T& t, V) { s_->EnqueueRequest<(CtrlMethod)V::value>(); } RpcServer* s_; }; using HandlerTuple = decltype(GetHandlerTuple(std::make_index_sequence{})); HandlerTuple handlers_; std::unique_ptr grpc_service_; std::unique_ptr cq_; std::unique_ptr grpc_server_; std::thread loop_thread_; // Barrier HashMap, int32_t>> barrier_calls_; // TryLock, NotifyDone, WaitUntilDone HashMap name2lock_status_; // PushKV, ClearKV, PullKV HashMap kv_; HashMap*>> pending_kv_calls_; // IncreaseCount, EraseCount HashMap count_; }; } // namespace oneflow #endif // ONEFLOW_CORE_CONTROL_RPC_SERVER_H_ ================================================ FILE: oneflow/core/control/worker_process_info.proto ================================================ syntax = "proto2"; package oneflow; message WorkerProcessInfo { required int64 rank = 1; required int64 port = 2; optional string host = 3; } ================================================ FILE: oneflow/core/cuda/atomic.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CUDA_ATOMIC_H_ #define ONEFLOW_CORE_CUDA_ATOMIC_H_ #if defined(__CUDACC__) #include #include #include #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 namespace oneflow { namespace cuda { namespace atomic { namespace internal { template struct CastCASImpl { __device__ __forceinline__ T operator()(T* address, T compare, T val, bool* success) const { static_assert(sizeof(T) == sizeof(U), ""); U assumed = *(reinterpret_cast(&compare)); U ret = atomicCAS(reinterpret_cast(address), assumed, *(reinterpret_cast(&val))); *success = (ret == assumed); return *(reinterpret_cast(&ret)); } }; #if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__)) template struct CastCASImpl { __device__ __forceinline__ T operator()(T* address, T compare, T val, bool* success) const { static_assert(sizeof(T) == sizeof(unsigned short int), ""); size_t offset = reinterpret_cast(address) & 0x2; unsigned int* address_as_ui = reinterpret_cast(reinterpret_cast(address) - offset); unsigned int old = *address_as_ui; unsigned int assumed = *(reinterpret_cast(&compare)); unsigned int newval = *(reinterpret_cast(&val)); assumed = offset ? (old & 0xffff) | (assumed << 16) : (old & 0xffff0000) | assumed; newval = offset ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; unsigned int ret = atomicCAS(address_as_ui, assumed, newval); *success = (ret == assumed); ret = offset ? (ret >> 16) : (ret & 0xffff); return *(reinterpret_cast(&ret)); } }; #endif // __CUDA_ARCH__ template __device__ __forceinline__ typename std::enable_if::type CASImpl(T* address, T compare, T val, bool* success) { return CastCASImpl()(address, compare, val, success); } template __device__ __forceinline__ typename std::enable_if::type CASImpl(T* address, T compare, T val, bool* success) { return CastCASImpl()(address, compare, val, success); } template __device__ __forceinline__ typename std::enable_if::type CASImpl(T* address, T compare, T val, bool* success) { return CastCASImpl()(address, compare, val, success); } __device__ __forceinline__ int CASImpl(int* address, int compare, int val, bool* success) { int ret = atomicCAS(address, compare, val); *success = (ret == compare); return ret; } __device__ __forceinline__ unsigned int CASImpl(unsigned int* address, unsigned int compare, unsigned int val, bool* success) { unsigned int ret = atomicCAS(address, compare, val); *success = (ret == compare); return ret; } __device__ __forceinline__ unsigned long long int CASImpl(unsigned long long int* address, unsigned long long int compare, unsigned long long int val, bool* success) { unsigned long long int ret = atomicCAS(address, compare, val); *success = (ret == compare); return ret; } #if __CUDA_ARCH__ >= 700 __device__ __forceinline__ unsigned short int CASImpl(unsigned short int* address, unsigned short int compare, unsigned short int val, bool* success) { unsigned short int ret = atomicCAS(address, compare, val); *success = (ret == compare); return ret; } #endif // __CUDA_ARCH__ >= 700 template struct AddOp { __device__ __forceinline__ T operator()(T a, T b) { return a + b; } }; template class BinaryOp> __device__ __forceinline__ T AtomicCASBinaryImpl(T* address, T val) { T old = *address; T assumed; bool success = false; do { assumed = old; old = CASImpl(address, assumed, BinaryOp()(old, val), &success); } while (!success); return old; } template __device__ __forceinline__ T AddImpl(T* address, T val) { return AtomicCASBinaryImpl(address, val); } __device__ __forceinline__ int 
AddImpl(int* address, int val) { return atomicAdd(address, val); } __device__ __forceinline__ unsigned int AddImpl(unsigned int* address, unsigned int val) { return atomicAdd(address, val); } __device__ __forceinline__ unsigned long long int AddImpl(unsigned long long int* address, unsigned long long int val) { return atomicAdd(address, val); } __device__ __forceinline__ uint64_t AddImpl(uint64_t* address, uint64_t val) { static_assert(sizeof(uint64_t) == sizeof(unsigned long long int), ""); return static_cast(atomicAdd(reinterpret_cast(address), static_cast(val))); } __device__ __forceinline__ float AddImpl(float* address, float val) { return atomicAdd(address, val); } #if __CUDA_ARCH__ >= 600 __device__ __forceinline__ double AddImpl(double* address, double val) { return atomicAdd(address, val); } __device__ __forceinline__ half2 AddImpl(half2* address, half2 val) { return atomicAdd(address, val); } #endif // __CUDA_ARCH__ >= 600 #if __CUDA_ARCH__ >= 700 __device__ __forceinline__ half AddImpl(half* address, half val) { return atomicAdd(address, val); } #endif // __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 800 __device__ __forceinline__ nv_bfloat16 AddImpl(nv_bfloat16* address, nv_bfloat16 val) { return atomicAdd(address, val); } __device__ __forceinline__ nv_bfloat162 AddImpl(nv_bfloat162* address, nv_bfloat162 val) { return atomicAdd(address, val); } #endif // __CUDA_ARCH__ >= 800 #if __CUDA_ARCH__ < 530 __device__ __forceinline__ half2 AddImpl(half2* address, half2 val) { __trap(); return val; } #endif // __CUDA_ARCH__ < 530 } // namespace internal template __device__ __forceinline__ typename std::enable_if::value, T>::type Cast(U v) { return static_cast(v); } template __device__ __forceinline__ typename std::enable_if::value, T>::type Cast(U v) { return v; } template __device__ __forceinline__ T CAS(T* address, U compare, V val) { bool success = false; return internal::CASImpl(address, Cast(compare), Cast(val), &success); } template __device__ __forceinline__ T Add(T* address, U val) { return internal::AddImpl(address, Cast(val)); } __device__ __forceinline__ float Mul(int32_t* address, const int32_t val) { int32_t old = *address, assumed; do { assumed = old; old = atomicCAS(address, assumed, val * assumed); } while (assumed != old); return old; } __device__ __forceinline__ float Mul(uint32_t* address, const uint32_t val) { uint32_t old = *address, assumed; do { assumed = old; old = atomicCAS(address, assumed, val * assumed); } while (assumed != old); return old; } __device__ __forceinline__ float Mul(uint64_t* address, const uint64_t val) { static_assert(sizeof(uint64_t) == sizeof(unsigned long long int), ""); unsigned long long int old = *reinterpret_cast(address), assumed; do { assumed = old; old = atomicCAS(reinterpret_cast(address), assumed, static_cast(val) * assumed); } while (assumed != old); return old; } __device__ __forceinline__ float Mul(float* address, const float val) { int32_t* address_as_int = reinterpret_cast(address); int32_t old = *address_as_int, assumed; do { assumed = old; old = atomicCAS(address_as_int, assumed, __float_as_int(val * __int_as_float(assumed))); } while (assumed != old); return __int_as_float(old); } __device__ __forceinline__ float Mul(double* address, const double val) { unsigned long long int* address_as_ull = reinterpret_cast(address); unsigned long long int old = *address_as_ull, assumed; do { assumed = old; old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val * __longlong_as_double(assumed))); } while (assumed != old); return 
__longlong_as_double(old); } __device__ __forceinline__ float Max(float* address, const float val) { int* address_as_i = (int*)address; int old = *address_as_i; int assumed = 0; do { assumed = old; old = atomicCAS(address_as_i, assumed, __float_as_int(fmaxf(val, __int_as_float(assumed)))); } while (assumed != old); return __int_as_float(old); } __device__ __forceinline__ double Max(double* address, const double val) { unsigned long long int* address_as_i = (unsigned long long int*)address; unsigned long long int old = *address_as_i; unsigned long long int assumed = 0; do { assumed = old; old = atomicCAS(address_as_i, assumed, __double_as_longlong(fmax(val, __longlong_as_double(assumed)))); } while (assumed != old); return __longlong_as_double(old); } // FastAdd is referenced from // https://github.com/pytorch/pytorch/blob/396c3b1d88d7624938a2bb0b287f2a19f1e89bb4/aten/src/ATen/native/cuda/KernelUtils.cuh#L29 #if defined(__CUDACC__) template::value>::type* = nullptr> __device__ __forceinline__ void FastSpecializedAtomicAdd(T* base, size_t offset, const size_t length, T value) { #if ((defined(CUDA_VERSION) && (CUDA_VERSION < 10000)) \ || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) cuda::atomic::Add(reinterpret_cast(base) + offset, static_cast(value)); #else // Accounts for the chance base falls on an odd 16 bit alignment (ie, not 32 bit aligned) __half* target_addr = reinterpret_cast<__half*>(base + offset); bool low_byte = (reinterpret_cast(target_addr) % sizeof(__half2) == 0); if (low_byte && offset < (length - 1)) { __half2 value2; value2.x = value; value2.y = __float2half_rz(0); cuda::atomic::Add(reinterpret_cast<__half2*>(target_addr), value2); } else if (!low_byte && offset > 0) { __half2 value2; value2.x = __float2half_rz(0); value2.y = value; cuda::atomic::Add(reinterpret_cast<__half2*>(target_addr - 1), value2); } else { cuda::atomic::Add(reinterpret_cast<__half*>(base) + offset, static_cast<__half>(value)); } #endif } template::value>::type* = nullptr> __device__ __forceinline__ void FastSpecializedAtomicAdd(T* base, size_t offset, const size_t length, T value) { cuda::atomic::Add(base + offset, value); } template __device__ __forceinline__ void FastAdd(T* base, size_t offset, const size_t length, T value) { FastSpecializedAtomicAdd(base, offset, length, value); } #endif } // namespace atomic } // namespace cuda } // namespace oneflow #endif // defined(__CUDACC__) #endif // ONEFLOW_CORE_CUDA_ATOMIC_H_ ================================================ FILE: oneflow/core/cuda/elementwise.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CUDA_ELEMENTWISE_H_ #define ONEFLOW_CORE_CUDA_ELEMENTWISE_H_ #include #include #include #include namespace oneflow { namespace cuda { namespace elementwise { constexpr int kBlockSize = 256; constexpr int kNumWaves = 32; inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) { int dev; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } int tpm; { cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); if (err != cudaSuccess) { return err; } } *num_blocks = std::max(1, std::min((n + kBlockSize - 1) / kBlockSize, sm_count * tpm / kBlockSize * kNumWaves)); return cudaSuccess; } template struct GetPackType { using type = typename std::aligned_storage::type; }; template using PackType = typename GetPackType::type; template union Pack { static_assert(sizeof(PackType) == sizeof(T) * pack_size, ""); __device__ Pack() { // do nothing } PackType storage; T elem[pack_size]; }; template struct alignas(sizeof(T) * pack_size) Packed { __device__ Packed() { // do nothing } union { T elem[pack_size]; }; }; constexpr int kMaxPackBytes = 128 / 8; constexpr int kMaxPackSize = 8; constexpr int Min(int a, int b) { return a < b ? a : b; } template constexpr int PackSize() { return Min(kMaxPackBytes / sizeof(T), kMaxPackSize); } template constexpr int PackSize() { return Min(PackSize(), PackSize()); } template class HasApply2 { typedef char one; struct two { char x[2]; }; template static one test(decltype(&C::Apply2)); template static two test(...); public: enum { value = sizeof(test(0)) == sizeof(char) }; }; template __device__ typename std::enable_if::value == true && pack_size % 2 == 0, Packed>::type ApplyPack(const FunctorT& functor, const Packed... in) { Packed ret; #pragma unroll for (int j = 0; j < pack_size; j += 2) { functor.Apply2(ret.elem + j, (in.elem + j)...); } return ret; } template __device__ typename std::enable_if::value == false || pack_size % 2 != 0, Packed>::type ApplyPack(const FunctorT& functor, const Packed... in) { Packed ret; #pragma unroll for (int j = 0; j < pack_size; ++j) { ret.elem[j] = functor((in.elem[j])...); } return ret; } template __global__ void __launch_bounds__(kBlockSize) ApplyGeneric(FactoryT factory, int64_t n_pack, Packed* pack_r, const Packed*... pack_in, int64_t n_tail, R* tail_r, const IN*... tail_in) { auto functor = factory(); const int global_tid = blockIdx.x * kBlockSize + threadIdx.x; for (int64_t i = global_tid; i < n_pack; i += blockDim.x * gridDim.x) { pack_r[i] = ApplyPack(functor, (pack_in[i])...); } if (global_tid < n_tail) { tail_r[global_tid] = functor((tail_in[global_tid])...); } } template struct SimpleFactory { explicit SimpleFactory(FunctorT functor) : tpl(functor) {} __device__ FunctorT operator()() const { return tpl; } private: FunctorT tpl; }; template bool IsAlignedForPack() { return true; } template bool IsAlignedForPack(const T* ptr, const Args*... others) { return reinterpret_cast(ptr) % sizeof(Pack) == 0 && IsAlignedForPack(others...); } template cudaError_t LaunchKernel(FactoryT factory, int64_t n, R* r, const IN*... 
in, cudaStream_t stream) { const int64_t n_pack = n / pack_size; const int64_t tail_offset = n_pack * pack_size; const int64_t n_tail = n - tail_offset; int num_blocks; { cudaError_t err = GetNumBlocks(n_pack, &num_blocks); if (err != cudaSuccess) { return err; } } ApplyGeneric<<>>( factory, n_pack, reinterpret_cast*>(r), (reinterpret_cast*>(in))..., n_tail, r + tail_offset, (in + tail_offset)...); return cudaPeekAtLastError(); } template struct GenericLauncher { static cudaError_t Launch(FactoryT factory, int64_t n, R* r, const IN*... in, cudaStream_t stream) { constexpr int max_pack_size = PackSize(); if (IsAlignedForPack(r, in...)) { return LaunchKernel(factory, n, r, in..., stream); } else { return LaunchKernel<1, FactoryT, R, IN...>(factory, n, r, in..., stream); } } }; template inline cudaError_t UnaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, cudaStream_t stream) { return GenericLauncher::Launch(factory, n, r, a, stream); } template inline cudaError_t Unary(FunctorT functor, int64_t n, R* r, const A* a, cudaStream_t stream) { return UnaryWithFactory(SimpleFactory(functor), n, r, a, stream); } template inline cudaError_t BinaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, const B* b, cudaStream_t stream) { return GenericLauncher::Launch(factory, n, r, a, b, stream); } template inline cudaError_t Binary(FunctorT functor, int64_t n, R* r, const A* a, const B* b, cudaStream_t stream) { return BinaryWithFactory(SimpleFactory(functor), n, r, a, b, stream); } template inline cudaError_t TernaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, const B* b, const C* c, cudaStream_t stream) { return GenericLauncher::Launch(factory, n, r, a, b, c, stream); } template inline cudaError_t Ternary(FunctorT functor, int64_t n, R* r, const A* a, const B* b, const C* c, cudaStream_t stream) { return TernaryWithFactory(SimpleFactory(functor), n, r, a, b, c, stream); } } // namespace elementwise } // namespace cuda } // namespace oneflow #endif // ONEFLOW_CORE_CUDA_ELEMENTWISE_H_ ================================================ FILE: oneflow/core/cuda/layer_norm.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CUDA_LAYER_NORM_H_ #define ONEFLOW_CORE_CUDA_LAYER_NORM_H_ #include #include #include namespace oneflow { namespace cuda { namespace layer_norm { constexpr int kWarpSize = 32; template struct SumOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; } }; template struct MaxOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); } }; template class ReductionOp, typename T, int thread_group_width = kWarpSize> __inline__ __device__ T WarpAllReduce(T val) { for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width)); } return val; } template class ReductionOp, typename T, int block_size> __inline__ __device__ T BlockAllReduce(T val) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ T result_broadcast; T result = BlockReduce(temp_storage).Reduce(val, ReductionOp()); if (threadIdx.x == 0) { result_broadcast = result; } __syncthreads(); return result_broadcast; } template __inline__ __device__ T Div(T a, T b); template<> __inline__ __device__ float Div(float a, float b) { #ifdef OF_LAYER_NORM_USE_FAST_MATH return __fdividef(a, b); #else return a / b; #endif } template<> __inline__ __device__ double Div(double a, double b) { return a / b; } template __inline__ __device__ T Rsqrt(T x); template<> __inline__ __device__ float Rsqrt(float x) { #ifdef OF_LAYER_NORM_USE_FAST_MATH return __frsqrt_rn(x); #else return rsqrt(x); #endif } template<> __inline__ __device__ double Rsqrt(double x) { return rsqrt(x); } template inline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size, int64_t max_blocks, int64_t waves, int* num_blocks) { int dev; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } int max_active_blocks; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func, block_size, dynamic_smem_size); } *num_blocks = std::max(1, std::min(max_blocks, sm_count * max_active_blocks * waves)); return cudaSuccess; } template struct DefaultComputeType { using type = T; }; template<> struct DefaultComputeType { using type = float; }; #if CUDA_VERSION >= 11000 template<> struct DefaultComputeType { using type = float; }; #endif // CUDA_VERSION >= 11000 template class HasCanPackAs { typedef char one; struct two { char x[2]; }; template static one test(decltype(&C::CanPackAs)); template static two test(...); public: enum { value = sizeof(test(0)) == sizeof(char) }; }; template typename std::enable_if::value == true, bool>::type CanPackAs(T t, size_t pack_size) { return t.CanPackAs(pack_size); } template typename std::enable_if::value == false, bool>::type CanPackAs(T t, size_t pack_size) { return true; } template struct GetPackType { using type = typename std::aligned_storage::type; }; template using PackType = typename GetPackType::type; template union Pack { static_assert(sizeof(PackType) == sizeof(T) * N, ""); __device__ Pack() { // do nothing } PackType storage; T elem[N]; }; template struct DirectLoad { using LoadType = DST; DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { Pack pack; const int64_t offset = (row * row_size + col) 
/ N; pack.storage = *(reinterpret_cast*>(src) + offset); #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); } } const SRC* src; int64_t row_size; }; template struct DirectStore { DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { Pack pack; const int64_t offset = (row * row_size + col) / N; #pragma unroll for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast(src[i]); } *(reinterpret_cast*>(dst) + offset) = pack.storage; } DST* dst; int64_t row_size; }; template inline __device__ void WelfordCombine(T val, T* mean, T* m2, T* count) { // Use Welford Online algorithem to compute mean and variance // For more details you can refer to: // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm *count += 1; T delta1 = val - *mean; *mean += Div(delta1, *count); T delta2 = val - *mean; *m2 += delta1 * delta2; } template inline __device__ void WelfordCombine(T b_mean, T b_m2, T b_count, T* mean, T* m2, T* count) { if (b_count == 0) { return; } T new_count = *count + b_count; T nb_over_n = Div(b_count, new_count); T delta = b_mean - *mean; *mean += delta * nb_over_n; *m2 += b_m2 + delta * delta * (*count) * nb_over_n; *count = new_count; } template __inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count, T* mean, T* m2, T* count) { *mean = thread_mean; *m2 = thread_m2; *count = thread_count; for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width); T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width); T b_count = __shfl_down_sync(0xffffffff, *count, mask, thread_group_width); WelfordCombine(b_mean, b_m2, b_count, mean, m2, count); } } template __inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean, T* m2, T* count) { WelfordWarpReduce(thread_mean, thread_m2, thread_count, mean, m2, count); *mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width); *m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width); *count = __shfl_sync(0xffffffff, *count, 0, thread_group_width); } template __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T thread_count, T* result_mean, T* result_m2, T* result_count) { __shared__ T mean_shared[kWarpSize]; __shared__ T m2_shared[kWarpSize]; __shared__ T count_shared[kWarpSize]; __shared__ T mean_result_broadcast; __shared__ T m2_result_broadcast; __shared__ T count_result_broadcast; const int lid = threadIdx.x % kWarpSize; const int wid = threadIdx.x / kWarpSize; T warp_mean = 0; T warp_m2 = 0; T warp_count = 0; WelfordWarpReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count); __syncthreads(); if (lid == 0) { mean_shared[wid] = warp_mean; m2_shared[wid] = warp_m2; count_shared[wid] = warp_count; } __syncthreads(); if (wid == 0) { if (threadIdx.x < blockDim.x / kWarpSize) { warp_mean = mean_shared[lid]; warp_m2 = m2_shared[lid]; warp_count = count_shared[lid]; } else { warp_mean = static_cast(0); warp_m2 = static_cast(0); warp_count = static_cast(0); } __syncwarp(); T block_mean = 0; T block_m2 = 0; T block_count = 0; WelfordWarpReduce(warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count); if (lid == 0) { mean_result_broadcast = block_mean; m2_result_broadcast = block_m2; count_result_broadcast = block_count; } } __syncthreads(); *result_mean = mean_result_broadcast; 
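  // m2 and count are read from the same shared-memory broadcast slots as mean, so every
  // thread in the block leaves with an identical Welford triple (mean, m2, count).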
*result_m2 = m2_result_broadcast; *result_count = count_result_broadcast; } template __global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { using LoadType = typename LOAD::LoadType; static_assert(max_cols_per_thread % pack_size == 0, ""); static_assert(min_cols_per_thread % pack_size == 0, ""); static_assert(thread_group_width <= kWarpSize, ""); static_assert(kWarpSize % thread_group_width == 0, ""); constexpr int max_num_packs = max_cols_per_thread / pack_size; constexpr int min_num_packs = min_cols_per_thread / pack_size; assert(cols <= max_cols_per_thread * thread_group_width); ComputeType buf[rows_per_access][max_cols_per_thread]; const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; const int64_t num_global_thread_group = gridDim.x * blockDim.y; const int64_t lane_id = threadIdx.x; const int64_t step = num_global_thread_group * rows_per_access; for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { ComputeType thread_mean[rows_per_access]; ComputeType thread_m2[rows_per_access]; ComputeType thread_count[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { thread_mean[row_id] = 0; thread_m2[row_id] = 0; thread_count[row_id] = 0; ComputeType* row_buf = buf[row_id]; #pragma unroll for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) { const int col = (pack_id * thread_group_width + lane_id) * pack_size; const int pack_offset = pack_id * pack_size; LoadType pack[pack_size]; load.template load(pack, row + row_id, col); #pragma unroll for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = static_cast(pack[i]); WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id, thread_count + row_id); } } for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) { const int col = (pack_id * thread_group_width + lane_id) * pack_size; const int pack_offset = pack_id * pack_size; if (!padding || col < cols) { LoadType pack[pack_size]; load.template load(pack, row + row_id, col); #pragma unroll for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = static_cast(pack[i]); WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id, thread_count + row_id); } } else { #pragma unroll for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = 0; } } } } ComputeType warp_mean[rows_per_access]; ComputeType warp_m2[rows_per_access]; ComputeType warp_count[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { int global_row_id = row + row_id; ComputeType* row_buf = buf[row_id]; WelfordWarpAllReduce( thread_mean[row_id], thread_m2[row_id], thread_count[row_id], warp_mean + row_id, warp_m2 + row_id, warp_count + row_id); ComputeType row_mean = warp_mean[row_id]; ComputeType row_variance = max(Div(warp_m2[row_id], warp_count[row_id]), static_cast(0.0)); ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); if (lane_id == 0) { mean[global_row_id] = row_mean; inv_variance[global_row_id] = row_inv_var; } #pragma unroll for (int i = 0; i < max_cols_per_thread; ++i) { row_buf[i] = (row_buf[i] - row_mean) * row_inv_var; } #pragma unroll for (int i = 0; i < min_num_packs; ++i) { const int col = (i * thread_group_width + lane_id) * pack_size; store.template store(row_buf + i * pack_size, global_row_id, col); } #pragma unroll for (int i = min_num_packs; i < max_num_packs; ++i) { 
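// Tail packs beyond min_num_packs may fall outside the row for some lanes, so
// the store below is guarded by the same `!padding || col < cols` check that
// was applied when the values were loaded.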
const int col = (i * thread_group_width + lane_id) * pack_size; if (!padding || col < cols) { store.template store(row_buf + i * pack_size, global_row_id, col); } } } } } template inline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { constexpr int block_size = 128; constexpr int waves = 32; static_assert(block_size % thread_group_width == 0, ""); constexpr int thread_groups_per_block = block_size / thread_group_width; dim3 block_dim(thread_group_width, thread_groups_per_block); const int64_t num_blocks = (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; int grid_dim_x; { cudaError_t err = GetNumBlocks( LayerNormWarpImpl, block_size, 0, num_blocks, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } LayerNormWarpImpl <<>>(load, store, rows, cols, epsilon, mean, inv_variance); return cudaPeekAtLastError(); } template inline cudaError_t DispatchLayerNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { if (cols == max_cols_per_thread * thread_group_width) { // when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass // max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param. return LaunchLayerNormWarpImpl( stream, load, store, rows, cols, epsilon, mean, inv_variance); } else { return LaunchLayerNormWarpImpl( stream, load, store, rows, cols, epsilon, mean, inv_variance); } } template typename std::enable_if::type DispatchLayerNormWarpImplCols( cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchLayerNormWarpImplPadding( \ stream, load, store, rows, cols, epsilon, mean, inv_variance); \ } else { \ return DispatchLayerNormWarpImplPadding( \ stream, load, store, rows, cols, epsilon, mean, inv_variance); \ } \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(max_col, min_col) \ else if (cols <= (max_col)*kWarpSize) { \ return DispatchLayerNormWarpImplPadding(stream, load, store, rows, cols, \ epsilon, mean, inv_variance); \ } DEFINE_ONE_ELIF(2, 1) DEFINE_ONE_ELIF(4, 2) DEFINE_ONE_ELIF(8, 4) DEFINE_ONE_ELIF(12, 8) DEFINE_ONE_ELIF(16, 12) DEFINE_ONE_ELIF(20, 16) DEFINE_ONE_ELIF(24, 20) DEFINE_ONE_ELIF(28, 24) DEFINE_ONE_ELIF(32, 28) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template typename std::enable_if::type DispatchLayerNormWarpImplCols( cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchLayerNormWarpImplPadding( \ stream, load, store, rows, cols, epsilon, mean, inv_variance); \ } else { \ return DispatchLayerNormWarpImplPadding( \ stream, load, store, rows, cols, epsilon, mean, inv_variance); \ } \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef 
DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(max_col, min_col) \ else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \ return DispatchLayerNormWarpImplPadding(stream, load, store, rows, cols, \ epsilon, mean, inv_variance); \ } DEFINE_ONE_ELIF(4, 2) DEFINE_ONE_ELIF(8, 4) DEFINE_ONE_ELIF(12, 8) DEFINE_ONE_ELIF(16, 12) DEFINE_ONE_ELIF(20, 16) DEFINE_ONE_ELIF(24, 20) DEFINE_ONE_ELIF(28, 24) DEFINE_ONE_ELIF(32, 28) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template struct DispatchLayerNormWarpImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { return DispatchLayerNormWarpImplCols( stream, load, store, rows, cols, epsilon, mean, inv_variance); } else { return DispatchLayerNormWarpImplCols( stream, load, store, rows, cols, epsilon, mean, inv_variance); } } }; template inline cudaError_t DispatchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { return DispatchLayerNormWarpImplPackSize()( stream, load, store, rows, cols, epsilon, mean, inv_variance); } template __global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { using LoadType = typename LOAD::LoadType; extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; auto* buf = reinterpret_cast(shared_buf); const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = static_cast(cols) / pack_size; for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { ComputeType thread_mean = 0; ComputeType thread_m2 = 0; ComputeType thread_count = 0; for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { LoadType pack[pack_size]; load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { buf[i * num_packs + pack_id] = pack[i]; WelfordCombine(static_cast(pack[i]), &thread_mean, &thread_m2, &thread_count); } } ComputeType row_mean = 0; ComputeType row_m2 = 0; ComputeType row_count = 0; WelfordBlockAllReduce(thread_mean, thread_m2, thread_count, &row_mean, &row_m2, &row_count); ComputeType row_variance = max(Div(row_m2, row_count), static_cast(0.0)); ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); if (threadIdx.x == 0) { mean[row] = row_mean; inv_variance[row] = row_inv_var; } for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; #pragma unroll for (int i = 0; i < pack_size; ++i) { pack[i] = (static_cast(buf[i * num_packs + pack_id]) - row_mean) * row_inv_var; } store.template store(pack, row, pack_id * pack_size); } } } template inline cudaError_t LaunchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, int smem, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { constexpr int waves = 32; int grid_dim_x; { cudaError_t err = GetNumBlocks(LayerNormBlockSMemImpl, block_size, smem, rows, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } LayerNormBlockSMemImpl <<>>(load, store, rows, cols, epsilon, mean, inv_variance); return cudaPeekAtLastError(); } template cudaError_t MaximizeDynamicSharedMemorySize(Func func, const int max_smem_size) { 
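// Query the kernel's statically allocated shared memory, then opt the kernel
// in to the largest dynamic shared-memory size the device allows, minus that
// static usage and a small reserved amount (reserved_smem below).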
cudaFuncAttributes attr{}; cudaError_t err = cudaFuncGetAttributes(&attr, func); if (err != cudaSuccess) { return err; } constexpr int reserved_smem = 1024; // 1K return cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, max_smem_size - attr.sharedSizeBytes - reserved_smem); } template inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize( cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance, bool* success) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; constexpr int block_size_conf_3 = 512; constexpr int block_size_conf_4 = 1024; int dev = 0; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count = 0; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } static const bool max_smem_configed = [=]() { int max_smem_size = 0; cudaError_t err = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); if (err != cudaSuccess) { return false; } err = MaximizeDynamicSharedMemorySize( LayerNormBlockSMemImpl, max_smem_size); if (err != cudaSuccess) { return false; } err = MaximizeDynamicSharedMemorySize( LayerNormBlockSMemImpl, max_smem_size); if (err != cudaSuccess) { return false; } err = MaximizeDynamicSharedMemorySize( LayerNormBlockSMemImpl, max_smem_size); if (err != cudaSuccess) { return false; } err = MaximizeDynamicSharedMemorySize( LayerNormBlockSMemImpl, max_smem_size); if (err != cudaSuccess) { return false; } return true; }(); const size_t smem = cols * sizeof(typename LOAD::LoadType); int max_active_blocks_conf_1; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_1, LayerNormBlockSMemImpl, block_size_conf_1, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_1 <= 0) { *success = false; return cudaSuccess; } int max_active_blocks_conf_4; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_4, LayerNormBlockSMemImpl, block_size_conf_4, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_4 == max_active_blocks_conf_1 || (max_active_blocks_conf_4 > 0 && rows <= sm_count)) { *success = true; return LaunchLayerNormBlockSMemImpl( stream, load, store, smem, rows, cols, epsilon, mean, inv_variance); } int max_active_blocks_conf_3; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_3, LayerNormBlockSMemImpl, block_size_conf_3, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_3 == max_active_blocks_conf_1 || (max_active_blocks_conf_3 > 0 && rows <= sm_count)) { *success = true; return LaunchLayerNormBlockSMemImpl( stream, load, store, smem, rows, cols, epsilon, mean, inv_variance); } int max_active_blocks_conf_2; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_2, LayerNormBlockSMemImpl, block_size_conf_2, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_2 == max_active_blocks_conf_1 || (max_active_blocks_conf_2 > 0 && rows <= sm_count)) { *success = true; return LaunchLayerNormBlockSMemImpl( stream, load, store, smem, rows, cols, epsilon, mean, inv_variance); } *success = true; return LaunchLayerNormBlockSMemImpl( stream, load, store, smem, rows, cols, epsilon, mean, inv_variance); } template struct 
TryDispatchLayerNormBlockSMemImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance, bool* success) { if (cols % 4 == 0 && CanPackAs(load, 4) && CanPackAs(store, 4)) { return TryDispatchLayerNormBlockSMemImplBlockSize( stream, load, store, rows, cols, epsilon, mean, inv_variance, success); } else if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { return TryDispatchLayerNormBlockSMemImplBlockSize( stream, load, store, rows, cols, epsilon, mean, inv_variance, success); } else { return TryDispatchLayerNormBlockSMemImplBlockSize( stream, load, store, rows, cols, epsilon, mean, inv_variance, success); } } }; template inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance, bool* success) { return TryDispatchLayerNormBlockSMemImplPackSize()( stream, load, store, rows, cols, epsilon, mean, inv_variance, success); } template __global__ void __launch_bounds__(1024) LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { using LoadType = typename LOAD::LoadType; const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = static_cast(cols) / pack_size; for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { ComputeType thread_mean = 0; ComputeType thread_m2 = 0; ComputeType thread_count = 0; for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { LoadType pack[pack_size]; load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { WelfordCombine(static_cast(pack[i]), &thread_mean, &thread_m2, &thread_count); } } ComputeType row_mean = 0; ComputeType row_m2 = 0; ComputeType row_count = 0; WelfordBlockAllReduce(thread_mean, thread_m2, thread_count, &row_mean, &row_m2, &row_count); ComputeType row_variance = max(Div(row_m2, row_count), static_cast(0.0)); ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); if (threadIdx.x == 0) { mean[row] = row_mean; inv_variance[row] = row_inv_var; } for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { LoadType pack[pack_size]; ComputeType dst_pack[pack_size]; const int pack_offset = pack_id * pack_size; load.template load(pack, row, pack_offset); #pragma unroll for (int i = 0; i < pack_size; ++i) { dst_pack[i] = (static_cast(pack[i]) - row_mean) * row_inv_var; } store.template store(dst_pack, row, pack_offset); } } } template inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { constexpr int block_size = 1024; constexpr int waves = 32; int grid_dim_x; { cudaError_t err = GetNumBlocks(LayerNormBlockUncachedImpl, block_size, 0, rows, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } LayerNormBlockUncachedImpl <<>>(load, store, rows, cols, epsilon, mean, inv_variance); return cudaPeekAtLastError(); } template struct DispatchLayerNormBlockUncachedImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { if (cols % 4 == 0 && CanPackAs(load, 4) && 
CanPackAs(store, 4)) { return LaunchLayerNormBlockUncachedImpl( stream, load, store, rows, cols, epsilon, mean, inv_variance); } else if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { return LaunchLayerNormBlockUncachedImpl( stream, load, store, rows, cols, epsilon, mean, inv_variance); } else { return LaunchLayerNormBlockUncachedImpl( stream, load, store, rows, cols, epsilon, mean, inv_variance); } } }; template inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { return DispatchLayerNormBlockUncachedImplPackSize()( stream, load, store, rows, cols, epsilon, mean, inv_variance); } template inline typename std::enable_if::value, cudaError_t>::type DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { if (cols <= 1024) { return DispatchLayerNormWarpImpl(stream, load, store, rows, cols, epsilon, mean, inv_variance); } else { bool dispatch_smem_impl_success; { cudaError_t err = TryDispatchLayerNormBlockSMemImpl( stream, load, store, rows, cols, epsilon, mean, inv_variance, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchLayerNormBlockUncachedImpl( stream, load, store, rows, cols, epsilon, mean, inv_variance); } return cudaSuccess; } } template inline typename std::enable_if::value, cudaError_t>::type DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, const double epsilon, ComputeType* mean, ComputeType* inv_variance) { return DispatchLayerNormBlockUncachedImpl( stream, load, store, rows, cols, epsilon, mean, inv_variance); } /* LayerNormGrad dx: normalized = (x - mean) * inv_var sum_stats1 = sum(scaled_dy) sum_stats2 = sum(scaled_dy * normalized) dx = cols * dy - sum_stats1 - normalized * sum_stats2 dx *= inv_var / cols */ template __global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { using LoadTypeX = typename LOAD_X::LoadType; using LoadTypeDy = typename LOAD_SCALED_DY::LoadType; static_assert(max_cols_per_thread % pack_size == 0, ""); static_assert(min_cols_per_thread % pack_size == 0, ""); constexpr int max_num_packs = max_cols_per_thread / pack_size; constexpr int min_num_packs = min_cols_per_thread / pack_size; assert(cols <= max_cols_per_thread * thread_group_width); static_assert(thread_group_width <= kWarpSize, ""); static_assert(kWarpSize % thread_group_width == 0, ""); ComputeType normalized_buf[rows_per_access][max_cols_per_thread]; ComputeType dy_buf[rows_per_access][max_cols_per_thread]; const ComputeType one_over_cols = static_cast(1.0) / static_cast(cols); const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; const int64_t num_global_thread_group = gridDim.x * blockDim.y; const int lane_id = threadIdx.x; const int64_t step = num_global_thread_group * rows_per_access; for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { ComputeType sum_stats1[rows_per_access]; ComputeType sum_stats2[rows_per_access]; ComputeType inv_variance_buf[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { const int global_row_id = row + row_id; ComputeType 
mean_val = mean[global_row_id]; inv_variance_buf[row_id] = inv_variance[global_row_id]; sum_stats1[row_id] = 0; sum_stats2[row_id] = 0; ComputeType* row_normalized_buf = normalized_buf[row_id]; ComputeType* row_dy_buf = dy_buf[row_id]; #pragma unroll for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) { const int col = (pack_id * thread_group_width + lane_id) * pack_size; const int pack_offset = pack_id * pack_size; LoadTypeX pack_x[pack_size]; LoadTypeDy pack_dy[pack_size]; load_x.template load(pack_x, global_row_id, col); load_scaled_dy.template load(pack_dy, global_row_id, col); #pragma unroll for (int i = 0; i < pack_size; ++i) { const int col_id = pack_offset + i; // row_normalized_buf store x row_normalized_buf[col_id] = (static_cast(pack_x[i]) - mean_val) * inv_variance_buf[row_id]; row_dy_buf[col_id] = static_cast(pack_dy[i]); sum_stats1[row_id] += row_dy_buf[col_id]; sum_stats2[row_id] += row_dy_buf[col_id] * row_normalized_buf[col_id]; } } #pragma unroll for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) { const int col = (pack_id * thread_group_width + lane_id) * pack_size; const int pack_offset = pack_id * pack_size; if (col < cols) { LoadTypeX pack_x[pack_size]; LoadTypeDy pack_dy[pack_size]; load_x.template load(pack_x, global_row_id, col); load_scaled_dy.template load(pack_dy, global_row_id, col); #pragma unroll for (int i = 0; i < pack_size; ++i) { const int col_id = pack_offset + i; // row_normalized_buf store x row_normalized_buf[col_id] = (static_cast(pack_x[i]) - mean_val) * inv_variance_buf[row_id]; row_dy_buf[col_id] = static_cast(pack_dy[i]); sum_stats1[row_id] += row_dy_buf[col_id]; sum_stats2[row_id] += row_dy_buf[col_id] * row_normalized_buf[col_id]; } } } } ComputeType warp_sum_stats1[rows_per_access]; ComputeType warp_sum_stats2[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { warp_sum_stats1[row_id] = WarpAllReduce(sum_stats1[row_id]); warp_sum_stats2[row_id] = WarpAllReduce(sum_stats2[row_id]); } #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { const int global_row_id = row + row_id; ComputeType* row_normalized_buf = normalized_buf[row_id]; ComputeType* row_dy_buf = dy_buf[row_id]; const ComputeType inv_variance_over_cols = inv_variance_buf[row_id] * one_over_cols; #pragma unroll for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) { const int col = (pack_id * thread_group_width + lane_id) * pack_size; const int pack_offset = pack_id * pack_size; for (int i = 0; i < pack_size; ++i) { const int col_id = pack_offset + i; row_dy_buf[col_id] = (cols * row_dy_buf[col_id] - warp_sum_stats1[row_id] - row_normalized_buf[col_id] * warp_sum_stats2[row_id]) * inv_variance_over_cols; } store.template store(row_dy_buf + pack_offset, global_row_id, col); } #pragma unroll for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) { const int col = (pack_id * thread_group_width + lane_id) * pack_size; if (col < cols) { const int pack_offset = pack_id * pack_size; for (int i = 0; i < pack_size; ++i) { const int col_id = pack_offset + i; row_dy_buf[col_id] = (cols * row_dy_buf[col_id] - warp_sum_stats1[row_id] - row_normalized_buf[col_id] * warp_sum_stats2[row_id]) * inv_variance_over_cols; } store.template store(row_dy_buf + pack_offset, global_row_id, col); } } } } } template inline cudaError_t LaunchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const 
int64_t rows, const int64_t cols) { constexpr int block_size = 128; constexpr int waves = 32; static_assert(block_size % thread_group_width == 0, ""); constexpr int thread_groups_per_block = block_size / thread_group_width; dim3 block_dim(thread_group_width, thread_groups_per_block); const int64_t num_blocks = (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; int grid_dim_x; { cudaError_t err = GetNumBlocks(LayerNormGradWarpImpl, block_size, 0, num_blocks, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } LayerNormGradWarpImpl <<>>(load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); return cudaPeekAtLastError(); } template inline cudaError_t DispatchLayerNormGradWarpImplPadding(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { if (cols == max_cols_per_thread * thread_group_width) { // when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass // max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param. return LaunchLayerNormGradWarpImpl(stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } else { return LaunchLayerNormGradWarpImpl(stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } } template typename std::enable_if::type DispatchLayerNormGradWarpImplCols( cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchLayerNormGradWarpImplPadding( \ stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \ } else { \ return DispatchLayerNormGradWarpImplPadding( \ stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \ } \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(max_col, min_col) \ else if (cols <= (max_col)*kWarpSize) { \ return DispatchLayerNormGradWarpImplPadding( \ stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \ } DEFINE_ONE_ELIF(2, 1) DEFINE_ONE_ELIF(4, 2) DEFINE_ONE_ELIF(8, 4) DEFINE_ONE_ELIF(12, 8) DEFINE_ONE_ELIF(16, 12) DEFINE_ONE_ELIF(20, 16) DEFINE_ONE_ELIF(24, 20) DEFINE_ONE_ELIF(28, 24) DEFINE_ONE_ELIF(32, 28) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template struct DispatchLayerNormGradWarpImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { return DispatchLayerNormGradWarpImplCols( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } }; template inline cudaError_t DispatchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { return DispatchLayerNormGradWarpImplPackSize()( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } template __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t 
rows, const int64_t cols) { using LoadTypeX = typename LOAD_X::LoadType; using LoadTypeDy = typename LOAD_SCALED_DY::LoadType; extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[]; auto* normalized_buf = reinterpret_cast(grad_shared_buf); auto* dy_buf = reinterpret_cast(normalized_buf + cols); const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = static_cast(cols) / pack_size; const ComputeType one_over_cols = static_cast(1.0) / static_cast(cols); for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { ComputeType sum_stats1 = 0; ComputeType sum_stats2 = 0; const ComputeType mean_val = mean[row]; const ComputeType inv_variance_val = inv_variance[row]; const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols; for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { LoadTypeX x_pack[pack_size]; LoadTypeDy dy_pack[pack_size]; load_x.template load(x_pack, row, pack_id * pack_size); load_scaled_dy.template load(dy_pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { const int buf_offset = i * num_packs + pack_id; ComputeType normalized = (static_cast(x_pack[i]) - mean_val) * inv_variance_val; normalized_buf[buf_offset] = static_cast(normalized); dy_buf[buf_offset] = dy_pack[i]; sum_stats1 += static_cast(dy_pack[i]); sum_stats2 += static_cast(dy_pack[i]) * normalized; } } const ComputeType row_sum_stats1 = BlockAllReduce(sum_stats1); const ComputeType row_sum_stats2 = BlockAllReduce(sum_stats2); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; #pragma unroll for (int i = 0; i < pack_size; ++i) { const int buf_offset = i * num_packs + pack_id; pack[i] = (cols * static_cast(dy_buf[buf_offset]) - row_sum_stats1 - static_cast(normalized_buf[buf_offset]) * row_sum_stats2) * inv_variance_over_cols; } store.template store(pack, row, pack_id * pack_size); } } } template inline cudaError_t LaunchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, int smem, const int64_t rows, const int64_t cols) { constexpr int waves = 32; int grid_dim_x; { cudaError_t err = GetNumBlocks(LayerNormGradBlockSMemImpl, block_size, smem, rows, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } LayerNormGradBlockSMemImpl <<>>(load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); return cudaPeekAtLastError(); } template inline cudaError_t TryDispatchLayerNormGradBlockSMemImplBlockSize( cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols, bool* success) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; constexpr int block_size_conf_3 = 512; constexpr int block_size_conf_4 = 1024; int dev = 0; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count = 0; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } static const bool max_smem_configed = [=]() { int max_smem_size = 0; cudaError_t err = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); if (err != cudaSuccess) { return false; } err = MaximizeDynamicSharedMemorySize( LayerNormGradBlockSMemImpl, max_smem_size); if (err != cudaSuccess) { return false; } err = 
MaximizeDynamicSharedMemorySize( LayerNormGradBlockSMemImpl, max_smem_size); if (err != cudaSuccess) { return false; } err = MaximizeDynamicSharedMemorySize( LayerNormGradBlockSMemImpl, max_smem_size); if (err != cudaSuccess) { return false; } err = MaximizeDynamicSharedMemorySize( LayerNormGradBlockSMemImpl, max_smem_size); if (err != cudaSuccess) { return false; } return true; }(); using LoadTypeX = typename LOAD_X::LoadType; using LoadTypeDy = typename LOAD_SCALED_DY::LoadType; const size_t smem = cols * (sizeof(LoadTypeX) + sizeof(LoadTypeDy)); int max_active_blocks_conf_1; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_1, LayerNormGradBlockSMemImpl, block_size_conf_1, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_1 <= 0) { *success = false; return cudaSuccess; } int max_active_blocks_conf_4; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_4, LayerNormGradBlockSMemImpl, block_size_conf_4, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_4 == max_active_blocks_conf_1 || (max_active_blocks_conf_4 > 0 && rows <= sm_count)) { *success = true; return LaunchLayerNormGradBlockSMemImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols); } int max_active_blocks_conf_3; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_3, LayerNormGradBlockSMemImpl, block_size_conf_3, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_3 == max_active_blocks_conf_1 || (max_active_blocks_conf_3 > 0 && rows <= sm_count)) { *success = true; return LaunchLayerNormGradBlockSMemImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols); } int max_active_blocks_conf_2; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_2, LayerNormGradBlockSMemImpl, block_size_conf_2, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_2 == max_active_blocks_conf_1 || (max_active_blocks_conf_2 > 0 && rows <= sm_count)) { *success = true; return LaunchLayerNormGradBlockSMemImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols); } *success = true; return LaunchLayerNormGradBlockSMemImpl(stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols); } template struct TryDispatchLayerNormGradBlockSMemImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols, bool* success) { if (cols % 2 == 0 && CanPackAs(load_x, 2) && CanPackAs(load_scaled_dy, 2) && CanPackAs(store, 2)) { return TryDispatchLayerNormGradBlockSMemImplBlockSize( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success); } else { return TryDispatchLayerNormGradBlockSMemImplBlockSize( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success); } } }; template inline cudaError_t TryDispatchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols, bool* success) { return TryDispatchLayerNormGradBlockSMemImplPackSize()( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success); } template __global__ void LayerNormGradBlockUncachedImpl(LOAD_X load_x, 
LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { using LoadTypeX = typename LOAD_X::LoadType; using LoadTypeDy = typename LOAD_SCALED_DY::LoadType; const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = static_cast(cols) / pack_size; const ComputeType one_over_cols = static_cast(1.0) / static_cast(cols); for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { const ComputeType mean_val = mean[row]; const ComputeType inv_variance_val = inv_variance[row]; const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols; ComputeType sum_stats1 = 0; ComputeType sum_stats2 = 0; for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { const int pack_offset = pack_id * pack_size; LoadTypeX x_pack[pack_size]; LoadTypeDy dy_pack[pack_size]; load_x.template load(x_pack, row, pack_offset); load_scaled_dy.template load(dy_pack, row, pack_offset); #pragma unroll for (int i = 0; i < pack_size; ++i) { sum_stats1 += static_cast(dy_pack[i]); sum_stats2 += static_cast(dy_pack[i]) * (static_cast(x_pack[i]) - mean_val) * inv_variance_val; } } const ComputeType row_sum_stats1 = BlockAllReduce(sum_stats1); const ComputeType row_sum_stats2 = BlockAllReduce(sum_stats2); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { const int pack_offset = pack_id * pack_size; LoadTypeX x_pack[pack_size]; LoadTypeDy dy_pack[pack_size]; ComputeType dx_pack[pack_size]; load_x.template load(x_pack, row, pack_offset); load_scaled_dy.template load(dy_pack, row, pack_offset); #pragma unroll for (int i = 0; i < pack_size; ++i) { dx_pack[i] = (cols * static_cast(dy_pack[i]) - row_sum_stats1 - (static_cast(x_pack[i]) - mean_val) * inv_variance_val * row_sum_stats2) * inv_variance_over_cols; } store.template store(dx_pack, row, pack_offset); } } } template inline cudaError_t LaunchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { constexpr int waves = 32; int grid_dim_x; { cudaError_t err = GetNumBlocks(LayerNormGradBlockUncachedImpl, block_size, 0, rows, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } LayerNormGradBlockUncachedImpl <<>>(load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); return cudaPeekAtLastError(); } template inline cudaError_t TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize( cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { int max_active_blocks = 0; constexpr int block_size_conf_1 = 1024; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, LayerNormGradBlockUncachedImpl, block_size_conf_1, 0); if (max_active_blocks > 0) { return LaunchLayerNormGradBlockUncachedImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } } constexpr int block_size_conf_2 = 512; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, LayerNormGradBlockUncachedImpl, block_size_conf_2, 0); if (max_active_blocks > 0) { return LaunchLayerNormGradBlockUncachedImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } } constexpr int block_size_conf_3 = 256; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, 
LayerNormGradBlockUncachedImpl, block_size_conf_2, 0); if (max_active_blocks > 0) { return LaunchLayerNormGradBlockUncachedImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } } constexpr int block_size_conf_4 = 128; return LaunchLayerNormGradBlockUncachedImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } template struct DispatchLayerNormGradBlockUncachedImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { if (cols % 2 == 0 && CanPackAs(load_x, 2) && CanPackAs(load_scaled_dy, 2) && CanPackAs(store, 2) && cols > kWarpSize) { return TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } else { return TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } } }; template inline cudaError_t DispatchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { return DispatchLayerNormGradBlockUncachedImplPackSize()( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } template inline typename std::enable_if::value, cudaError_t>::type DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { if (cols <= 1024) { return DispatchLayerNormGradWarpImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } else { bool dispatch_smem_impl_success; { cudaError_t err = TryDispatchLayerNormGradBlockSMemImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchLayerNormGradBlockUncachedImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } return cudaSuccess; } } template inline typename std::enable_if::value, cudaError_t>::type DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store, const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows, const int64_t cols) { return DispatchLayerNormGradBlockUncachedImpl( stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); } } // namespace layer_norm } // namespace cuda } // namespace oneflow #endif // ONEFLOW_CORE_CUDA_LAYER_NORM_H_ ================================================ FILE: oneflow/core/cuda/rms_norm.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CUDA_RMS_NORM_H_ #define ONEFLOW_CORE_CUDA_RMS_NORM_H_ #include "oneflow/core/cuda/layer_norm.cuh" namespace oneflow { namespace cuda { namespace rms_norm { constexpr int kWarpSize = 32; template __inline__ __device__ T WarpReduceSum(T val) { for (int mask = 16; mask > 0; mask /= 2) { val += __shfl_down_sync(0xffffffff, val, mask); } return val; } template __global__ void RmsNormWarpImpl(LOAD load, STORE store, const int nrow, const int ncol, const double eps, ComputeType* inv_rms) { static_assert(max_cols_per_thread % pack_size == 0, ""); static_assert(min_cols_per_thread % pack_size == 0, ""); static_assert(thread_group_width <= kWarpSize, ""); static_assert(kWarpSize % thread_group_width == 0, ""); constexpr int max_packs = max_cols_per_thread / pack_size; constexpr int min_packs = min_cols_per_thread / pack_size; assert(ncol <= max_cols_per_thread * thread_group_width); ComputeType buf[rows_per_access][max_cols_per_thread]; const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; const int num_global_thread_groups = gridDim.x * blockDim.y; for (int row_i = global_thread_group_id; row_i < nrow; row_i += num_global_thread_groups) { ComputeType thread_square_sum[rows_per_access]; #pragma unroll for (int row_j = 0; row_j < rows_per_access; ++row_j) { thread_square_sum[row_j] = 0; ComputeType* row_buf = buf[row_j]; const int row = row_i * rows_per_access + row_j; #pragma unroll for (int pack_i = 0; pack_i < min_packs; ++pack_i) { const int pack_offset = pack_i * pack_size; const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size; load.template load(row_buf + pack_offset, row, col); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { thread_square_sum[row_j] += row_buf[pack_offset + pack_j] * row_buf[pack_offset + pack_j]; } } #pragma unroll for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) { const int pack_offset = pack_i * pack_size; const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size; if (!padding || col < ncol) { load.template load(row_buf + pack_offset, row, col); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { thread_square_sum[row_j] += row_buf[pack_offset + pack_j] * row_buf[pack_offset + pack_j]; } } else { #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { row_buf[pack_i * pack_size + pack_j] = 0; } } } } ComputeType warp_square_sum[rows_per_access]; #pragma unroll for (int row_j = 0; row_j < rows_per_access; ++row_j) { const int row = row_i * rows_per_access + row_j; ComputeType* row_buf = buf[row_j]; warp_square_sum[row_j] = layer_norm::WarpAllReduce( thread_square_sum[row_j]); ComputeType row_square_mean = layer_norm::Div(warp_square_sum[row_j], static_cast(ncol)); ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast(eps)); if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; } #pragma unroll for (int col = 0; col < max_cols_per_thread; ++col) { row_buf[col] *= row_inv_rms; } #pragma unroll for (int pack_i = 0; pack_i < min_packs; ++pack_i) { const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size; store.template store(row_buf + pack_i * pack_size, row, col); } #pragma unroll for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) { const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size; if (!padding || col < ncol) { store.template store(row_buf + pack_i * pack_size, row, col); } } } } } template cudaError_t LaunchRmsNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t 
nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { constexpr int block_size = 128; constexpr int waves = 32; static_assert(block_size % thread_group_width == 0, ""); constexpr int thread_groups_per_block = block_size / thread_group_width; const int64_t num_blocks = (nrow / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; int grid_dim_x; { cudaError_t err = layer_norm::GetNumBlocks( RmsNormWarpImpl, block_size, 0, num_blocks, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } dim3 block_dim(thread_group_width, thread_groups_per_block); RmsNormWarpImpl <<>>(load, store, static_cast(nrow), static_cast(ncol), eps, inv_rms); return cudaPeekAtLastError(); } template cudaError_t DispatchLaunchRmsNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { if (ncol == max_cols_per_thread * thread_group_width) { // when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass // max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param. return LaunchRmsNormWarpImpl( stream, load, store, nrow, ncol, eps, inv_rms); } else { return LaunchRmsNormWarpImpl( stream, load, store, nrow, ncol, eps, inv_rms); } } template typename std::enable_if::type DispatchLaunchRmsNormWarpImplCols( cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { if (ncol <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (ncol <= (thread_group_width)*pack_size) { \ if (nrow % 2 == 0) { \ return DispatchLaunchRmsNormWarpImplPadding( \ stream, load, store, nrow, ncol, eps, inv_rms); \ } else { \ return DispatchLaunchRmsNormWarpImplPadding( \ stream, load, store, nrow, ncol, eps, inv_rms); \ } \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(max_col, min_col) \ else if (ncol <= (max_col)*kWarpSize) { \ return DispatchLaunchRmsNormWarpImplPadding(stream, load, store, nrow, \ ncol, eps, inv_rms); \ } DEFINE_ONE_ELIF(2, 1) DEFINE_ONE_ELIF(4, 2) DEFINE_ONE_ELIF(8, 4) DEFINE_ONE_ELIF(12, 8) DEFINE_ONE_ELIF(16, 12) DEFINE_ONE_ELIF(20, 16) DEFINE_ONE_ELIF(24, 20) DEFINE_ONE_ELIF(28, 24) DEFINE_ONE_ELIF(32, 28) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template typename std::enable_if::type DispatchLaunchRmsNormWarpImplCols( cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { if (ncol <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (ncol <= (thread_group_width)*pack_size) { \ if (nrow % 2 == 0) { \ return DispatchLaunchRmsNormWarpImplPadding( \ stream, load, store, nrow, ncol, eps, inv_rms); \ } else { \ return DispatchLaunchRmsNormWarpImplPadding( \ stream, load, store, nrow, ncol, eps, inv_rms); \ } \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(max_col, min_col) \ else if ((ncol <= (max_col)*kWarpSize) && (ncol > (min_col)*kWarpSize)) { \ return DispatchLaunchRmsNormWarpImplPadding(stream, load, store, nrow, \ ncol, eps, inv_rms); \ } DEFINE_ONE_ELIF(4, 2) DEFINE_ONE_ELIF(8, 4) DEFINE_ONE_ELIF(12, 8) DEFINE_ONE_ELIF(16, 12) DEFINE_ONE_ELIF(20, 16) DEFINE_ONE_ELIF(24, 20) DEFINE_ONE_ELIF(28, 24) DEFINE_ONE_ELIF(32, 28) #undef DEFINE_ONE_ELIF else { return 
cudaErrorInvalidValue; } } template cudaError_t DispatchLaunchRmsNormWarpImplPackSize(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { if (ncol % 2 == 0 && layer_norm::CanPackAs(load, 2) && layer_norm::CanPackAs(store, 2)) { return DispatchLaunchRmsNormWarpImplCols(stream, load, store, nrow, ncol, eps, inv_rms); } else { return DispatchLaunchRmsNormWarpImplCols(stream, load, store, nrow, ncol, eps, inv_rms); } } template cudaError_t DispatchLaunchRmsNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { return DispatchLaunchRmsNormWarpImplPackSize(stream, load, store, nrow, ncol, eps, inv_rms); } template __global__ void RmsNormBlockSMemImpl(LOAD load, STORE store, const int nrow, const int ncol, const double eps, ComputeType* inv_rms) { extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; auto* buf = reinterpret_cast(shared_buf); assert(ncol % pack_size == 0); const int num_packs = ncol / pack_size; for (int row = blockIdx.x; row < nrow; row += gridDim.x) { ComputeType thread_square_sum = 0; for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) { ComputeType pack[pack_size]; const int col = pack_i * pack_size; load.template load(pack, row, col); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { buf[pack_i * pack_size + pack_j] = pack[pack_j]; thread_square_sum += pack[pack_j] * pack[pack_j]; } } ComputeType row_square_sum = layer_norm::BlockAllReduce(thread_square_sum); ComputeType row_square_mean = layer_norm::Div(row_square_sum, static_cast(ncol)); ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast(eps)); if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; } for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) { ComputeType pack[pack_size]; #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { pack[pack_j] = buf[pack_i * pack_size + pack_j] * row_inv_rms; } const int col = pack_i * pack_size; store.template store(pack, row, col); } } } template cudaError_t LaunchRmsNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, size_t smem_size, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { constexpr int waves = 32; int grid_dim_x; { cudaError_t err = layer_norm::GetNumBlocks( RmsNormBlockSMemImpl, block_size, smem_size, nrow, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } RmsNormBlockSMemImpl <<>>(load, store, nrow, ncol, eps, inv_rms); return cudaPeekAtLastError(); } template cudaError_t TryDispatchLaunchRmsNormBlockSMemImplBlockSize(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms, bool* success) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; constexpr int block_size_conf_3 = 512; constexpr int block_size_conf_4 = 1024; const size_t smem_size = ncol * sizeof(ComputeType); int max_active_blocks = 0; int num_blocks = 0; #define SELECT_BLOCK_SIZE_CONF(block_size_conf) \ { \ cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ &num_blocks, RmsNormBlockSMemImpl, \ block_size_conf, smem_size); \ if (err != cudaSuccess) { return err; } \ if (max_active_blocks == 0) { \ if (num_blocks <= max_active_blocks) { \ *success = false; \ return cudaSuccess; \ } \ max_active_blocks = num_blocks; \ } else { \ if (num_blocks == max_active_blocks) { \ *success = 
true; \ return LaunchRmsNormBlockSMemImpl( \ stream, load, store, smem_size, nrow, ncol, eps, inv_rms); \ } \ } \ } SELECT_BLOCK_SIZE_CONF(block_size_conf_1) SELECT_BLOCK_SIZE_CONF(block_size_conf_4) SELECT_BLOCK_SIZE_CONF(block_size_conf_3) SELECT_BLOCK_SIZE_CONF(block_size_conf_2) #undef SELECT_BLOCK_SIZE_CONF *success = true; return LaunchRmsNormBlockSMemImpl( stream, load, store, smem_size, nrow, ncol, eps, inv_rms); } template cudaError_t TryDispatchLaunchRmsNormBlockSMemImplPackSize(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms, bool* success) { if (ncol % 4 == 0 && layer_norm::CanPackAs(load, 4) && layer_norm::CanPackAs(store, 4)) { return TryDispatchLaunchRmsNormBlockSMemImplBlockSize( stream, load, store, nrow, ncol, eps, inv_rms, success); } else if (ncol % 2 == 0 && layer_norm::CanPackAs(load, 2) && layer_norm::CanPackAs(store, 2)) { return TryDispatchLaunchRmsNormBlockSMemImplBlockSize( stream, load, store, nrow, ncol, eps, inv_rms, success); } else { return TryDispatchLaunchRmsNormBlockSMemImplBlockSize( stream, load, store, nrow, ncol, eps, inv_rms, success); } } template cudaError_t TryDispatchLaunchRmsNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms, bool* success) { return TryDispatchLaunchRmsNormBlockSMemImplPackSize(stream, load, store, nrow, ncol, eps, inv_rms, success); } template __global__ void RmsNormBlockUncachedImpl(LOAD load, STORE store, const int nrow, const int ncol, const double eps, ComputeType* inv_rms) { assert(ncol % pack_size == 0); const int num_packs = ncol / pack_size; for (int row = blockIdx.x; row < nrow; row += gridDim.x) { ComputeType thread_square_sum = 0; for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) { ComputeType pack[pack_size]; const int col = pack_i * pack_size; load.template load(pack, row, col); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { thread_square_sum += pack[pack_j] * pack[pack_j]; } } ComputeType row_square_sum = layer_norm::BlockAllReduce(thread_square_sum); ComputeType row_square_mean = layer_norm::Div(row_square_sum, static_cast(ncol)); ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast(eps)); if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; } for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) { ComputeType pack[pack_size]; const int col = pack_i * pack_size; load.template load(pack, row, col); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { pack[pack_j] = pack[pack_j] * row_inv_rms; } store.template store(pack, row, col); } } } template cudaError_t LaunchRmsNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { constexpr int block_size = 1024; constexpr int waves = 32; int grid_dim_x; { cudaError_t err = layer_norm::GetNumBlocks( RmsNormBlockUncachedImpl, block_size, 0, nrow, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } RmsNormBlockUncachedImpl <<>>(load, store, nrow, ncol, eps, inv_rms); return cudaPeekAtLastError(); } template cudaError_t DispatchLaunchRmsNormBlockUncachedImplPackSize(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { if (ncol % 4 == 0 && layer_norm::CanPackAs(load, 4) && layer_norm::CanPackAs(store, 4)) { return 
LaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms); } else if (ncol % 2 == 0 && layer_norm::CanPackAs(load, 2) && layer_norm::CanPackAs(store, 2)) { return LaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms); } else { return LaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms); } } template cudaError_t DispatchLaunchRmsNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { return DispatchLaunchRmsNormBlockUncachedImplPackSize(stream, load, store, nrow, ncol, eps, inv_rms); } template typename std::enable_if::value, cudaError_t>::type LaunchRmsNorm( cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { if (ncol <= 1024) { return DispatchLaunchRmsNormWarpImpl(stream, load, store, nrow, ncol, eps, inv_rms); } else { bool dispatch_smem_impl_success = false; { cudaError_t err = TryDispatchLaunchRmsNormBlockSMemImpl(stream, load, store, nrow, ncol, eps, inv_rms, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchLaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms); } return cudaSuccess; } } template typename std::enable_if::value, cudaError_t>::type LaunchRmsNorm( cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol, const double eps, ComputeType* inv_rms) { return DispatchLaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms); } template __global__ void RmsNormGradWarpImpl(const int nrow, const int ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { static_assert(max_cols_per_thread % pack_size == 0, ""); static_assert(min_cols_per_thread % pack_size == 0, ""); static_assert(thread_group_width <= kWarpSize, ""); static_assert(kWarpSize % thread_group_width == 0, ""); assert(ncol <= max_cols_per_thread * thread_group_width); constexpr int max_packs = max_cols_per_thread / pack_size; constexpr int min_packs = min_cols_per_thread / pack_size; ComputeType normalized_buf[rows_per_access][max_cols_per_thread]; ComputeType dy_buf[rows_per_access][max_cols_per_thread]; const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; const int num_global_thread_group = gridDim.x * blockDim.y; for (int row_i = global_thread_group_id; row_i < nrow; row_i += num_global_thread_group) { ComputeType sum_stats[rows_per_access]; ComputeType inv_rms_buf[rows_per_access]; #pragma unroll for (int row_j = 0; row_j < rows_per_access; ++row_j) { const int global_row = row_i * rows_per_access + row_j; sum_stats[row_j] = 0; inv_rms_buf[row_j] = inv_rms[global_row]; ComputeType* row_normalized_buf = normalized_buf[row_j]; ComputeType* row_dy_buf = dy_buf[row_j]; #pragma unroll for (int pack_i = 0; pack_i < min_packs; ++pack_i) { const int pack_offset = pack_i * pack_size; const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size; load_x.template load(row_normalized_buf + pack_offset, global_row, global_col); load_dy.template load(row_dy_buf + pack_offset, global_row, global_col); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { const int col = pack_offset + pack_j; row_normalized_buf[col] = row_normalized_buf[col] * inv_rms_buf[row_j]; sum_stats[row_j] += row_dy_buf[col] * row_normalized_buf[col]; } } #pragma unroll for (int pack_i = min_packs; pack_i < 
max_packs; ++pack_i) { const int pack_offset = pack_i * pack_size; const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size; if (global_col < ncol) { load_x.template load(row_normalized_buf + pack_offset, global_row, global_col); load_dy.template load(row_dy_buf + pack_offset, global_row, global_col); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { const int col = pack_offset + pack_j; row_normalized_buf[col] = row_normalized_buf[col] * inv_rms_buf[row_j]; sum_stats[row_j] += row_dy_buf[col] * row_normalized_buf[col]; } } } } ComputeType warp_sum_stats[rows_per_access]; #pragma unroll for (int row_j = 0; row_j < rows_per_access; ++row_j) { warp_sum_stats[row_j] = layer_norm::WarpAllReduce( sum_stats[row_j]); } #pragma unroll for (int row_j = 0; row_j < rows_per_access; ++row_j) { const int global_row = row_i * rows_per_access + row_j; ComputeType* row_normalized_buf = normalized_buf[row_j]; ComputeType* row_dy_buf = dy_buf[row_j]; #pragma unroll for (int pack_i = 0; pack_i < min_packs; ++pack_i) { const int pack_offset = pack_i * pack_size; const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size; for (int pack_j = 0; pack_j < pack_size; ++pack_j) { const int col = pack_offset + pack_j; const ComputeType norm_val = layer_norm::Div(row_normalized_buf[col], static_cast(ncol)); row_dy_buf[col] = (row_dy_buf[col] - norm_val * warp_sum_stats[row_j]) * inv_rms_buf[row_j]; } store.template store(row_dy_buf + pack_offset, global_row, global_col); } #pragma unroll for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) { const int pack_offset = pack_i * pack_size; const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size; if (global_col < ncol) { for (int pack_j = 0; pack_j < pack_size; ++pack_j) { const int col = pack_offset + pack_j; const ComputeType norm_val = layer_norm::Div(row_normalized_buf[col], static_cast(ncol)); row_dy_buf[col] = (row_dy_buf[col] - norm_val * warp_sum_stats[row_j]) * inv_rms_buf[row_j]; } store.template store(row_dy_buf + pack_offset, global_row, global_col); } } } } } template cudaError_t LaunchRmsNormGradWarpImpl(cudaStream_t stream, const int nrow, const int ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { constexpr int block_size = 128; constexpr int waves = 32; static_assert(block_size % thread_group_width == 0, ""); constexpr int thread_groups_per_block = block_size / thread_group_width; const int64_t num_blocks = (nrow / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; int grid_dim_x; { cudaError_t err = layer_norm::GetNumBlocks( RmsNormGradWarpImpl, block_size, 0, num_blocks, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } dim3 block_dim(thread_group_width, thread_groups_per_block); RmsNormGradWarpImpl <<>>(nrow, ncol, load_x, load_dy, store, inv_rms); return cudaPeekAtLastError(); } template typename std::enable_if::type DispatchLaunchRmsNormGradWarpImplCols( cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { if (ncol <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (ncol <= (thread_group_width)*pack_size) { \ if (nrow % 2 == 0) { \ return LaunchRmsNormGradWarpImpl(stream, nrow, ncol, load_x, \ load_dy, store, inv_rms); \ } else { \ return LaunchRmsNormGradWarpImpl(stream, nrow, ncol, load_x, \ load_dy, store, inv_rms); \ } \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) 
DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(max_col, min_col) \ else if (ncol <= (max_col)*kWarpSize) { \ return LaunchRmsNormGradWarpImpl(stream, nrow, ncol, load_x, load_dy, \ store, inv_rms); \ } DEFINE_ONE_ELIF(2, 1) DEFINE_ONE_ELIF(4, 2) DEFINE_ONE_ELIF(8, 4) DEFINE_ONE_ELIF(12, 8) DEFINE_ONE_ELIF(16, 12) DEFINE_ONE_ELIF(20, 16) DEFINE_ONE_ELIF(24, 20) DEFINE_ONE_ELIF(28, 24) DEFINE_ONE_ELIF(32, 28) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template cudaError_t DispatchLaunchRmsNormGradWarpImplPackSize(cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { return DispatchLaunchRmsNormGradWarpImplCols( stream, nrow, ncol, load_x, load_dy, store, inv_rms); } template __global__ void RmsNormGradBlockSMemImpl(const int nrow, const int ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { extern __shared__ __align__(sizeof(double)) unsigned char dyn_smem[]; // dynamic shared memory for caching x and dy auto* normalized_buf = reinterpret_cast(dyn_smem); auto* dy_buf = normalized_buf + ncol; assert(ncol % pack_size == 0); const int num_packs = ncol / pack_size; for (int row = blockIdx.x; row < nrow; row += gridDim.x) { ComputeType sum_stats = 0; const ComputeType inv_rms_val = inv_rms[row]; for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) { ComputeType x_pack[pack_size]; ComputeType dy_pack[pack_size]; const int pack_offset = pack_i * pack_size; load_x.template load(x_pack, row, pack_offset); load_dy.template load(dy_pack, row, pack_offset); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { const int col = pack_offset + pack_j; normalized_buf[col] = x_pack[pack_j] * inv_rms_val; dy_buf[col] = dy_pack[pack_j]; sum_stats += dy_buf[col] * normalized_buf[col]; } } const ComputeType row_sum_stats = layer_norm::BlockAllReduce(sum_stats); for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) { ComputeType pack[pack_size]; const int pack_offset = pack_i * pack_size; #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { const int col = pack_offset + pack_j; const ComputeType norm_val = layer_norm::Div(normalized_buf[col], static_cast(ncol)); pack[pack_j] = (dy_buf[col] - norm_val * row_sum_stats) * inv_rms_val; } store.template store(pack, row, pack_offset); } } } template cudaError_t LaunchRmsNormGradBlockSMemImpl(cudaStream_t stream, const int64_t nrow, const int64_t ncol, const size_t smem_size, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { constexpr int waves = 32; int grid_dim_x; { cudaError_t err = layer_norm::GetNumBlocks( RmsNormGradBlockSMemImpl, block_size, smem_size, nrow, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } RmsNormGradBlockSMemImpl <<>>( static_cast(nrow), static_cast(ncol), load_x, load_dy, store, inv_rms); return cudaPeekAtLastError(); } template cudaError_t TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize( cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms, bool* success) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; constexpr int block_size_conf_3 = 512; constexpr int block_size_conf_4 = 1024; const size_t smem_size = ncol * sizeof(ComputeType) * 2; // ncol * 2 for caching x and dy both int max_active_blocks = 0; int num_blocks = 0; #define 
SELECT_BLOCK_SIZE_CONF(block_size_conf) \ { \ cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ &num_blocks, \ RmsNormGradBlockSMemImpl, \ block_size_conf, smem_size); \ if (err != cudaSuccess) { return err; } \ if (max_active_blocks == 0) { \ if (num_blocks <= max_active_blocks) { \ *success = false; \ return cudaSuccess; \ } \ max_active_blocks = num_blocks; \ } else { \ if (num_blocks == max_active_blocks) { \ *success = true; \ return LaunchRmsNormGradBlockSMemImpl(stream, nrow, ncol, smem_size, \ load_x, load_dy, store, inv_rms); \ } \ } \ } SELECT_BLOCK_SIZE_CONF(block_size_conf_1) SELECT_BLOCK_SIZE_CONF(block_size_conf_4) SELECT_BLOCK_SIZE_CONF(block_size_conf_3) SELECT_BLOCK_SIZE_CONF(block_size_conf_2) #undef SELECT_BLOCK_SIZE_CONF *success = true; return LaunchRmsNormGradBlockSMemImpl(stream, nrow, ncol, smem_size, load_x, load_dy, store, inv_rms); } template cudaError_t TryDispatchLaunchRmsNormGradBlockSMemImplPackSize( cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms, bool* success) { if (ncol % 2 == 0 && layer_norm::CanPackAs(load_x, 2) && layer_norm::CanPackAs(load_dy, 2) && layer_norm::CanPackAs(store, 2)) { return TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize(stream, nrow, ncol, load_x, load_dy, store, inv_rms, success); } else { return TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize(stream, nrow, ncol, load_x, load_dy, store, inv_rms, success); } } template __global__ void RmsNormGradBlockUncachedImpl(const int nrow, const int ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { assert(ncol % pack_size == 0); const int num_packs = ncol / pack_size; for (int row = blockIdx.x; row < nrow; row += gridDim.x) { const ComputeType inv_rms_val = inv_rms[row]; ComputeType sum_stats = 0; for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) { ComputeType x_pack[pack_size]; ComputeType dy_pack[pack_size]; const int pack_offset = pack_i * pack_size; load_x.template load(x_pack, row, pack_offset); load_dy.template load(dy_pack, row, pack_offset); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { sum_stats += dy_pack[pack_j] * x_pack[pack_j] * inv_rms_val; } } const ComputeType row_sum_stats = layer_norm::BlockAllReduce(sum_stats); for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) { ComputeType x_pack[pack_size]; ComputeType dy_pack[pack_size]; const int pack_offset = pack_i * pack_size; load_x.template load(x_pack, row, pack_offset); load_dy.template load(dy_pack, row, pack_offset); #pragma unroll for (int pack_j = 0; pack_j < pack_size; ++pack_j) { const ComputeType norm_val = layer_norm::Div(x_pack[pack_j] * inv_rms_val, static_cast(ncol)); dy_pack[pack_j] = (dy_pack[pack_j] - norm_val * row_sum_stats) * inv_rms_val; } store.template store(dy_pack, row, pack_offset); } } } template cudaError_t LaunchRmsNormGradBlockUncachedImpl(cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { constexpr int waves = 32; int grid_dim_x; { cudaError_t err = layer_norm::GetNumBlocks( RmsNormGradBlockUncachedImpl, block_size, 0, nrow, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } RmsNormGradBlockUncachedImpl <<>>(nrow, ncol, load_x, load_dy, store, inv_rms); return cudaPeekAtLastError(); } template cudaError_t DispatchLaunchRmsNormGradBlockUncachedImplBlockSize(cudaStream_t stream, const int64_t nrow, const 
int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; constexpr int block_size_conf_3 = 512; constexpr int block_size_conf_4 = 1024; int max_active_blocks = 0; #define SELECT_BLOCK_SIZE_CONF(block_size_conf) \ { \ cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ &max_active_blocks, \ RmsNormGradBlockUncachedImpl, \ block_size_conf, 0); \ if (err != cudaSuccess) { return err; } \ if (max_active_blocks > 0) { \ return LaunchRmsNormGradBlockUncachedImpl(stream, nrow, ncol, load_x, \ load_dy, store, inv_rms); \ } \ } SELECT_BLOCK_SIZE_CONF(block_size_conf_4) SELECT_BLOCK_SIZE_CONF(block_size_conf_3) SELECT_BLOCK_SIZE_CONF(block_size_conf_2) SELECT_BLOCK_SIZE_CONF(block_size_conf_1) #undef SELECT_BLOCK_SIZE_CONF return cudaErrorInvalidValue; } template cudaError_t DispatchLaunchRmsNormGradBlockUncachedImplPackSize(cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { if (ncol % 2 == 0 && layer_norm::CanPackAs(load_x, 2) && layer_norm::CanPackAs(load_dy, 2) && layer_norm::CanPackAs(store, 2) && ncol > kWarpSize) { return DispatchLaunchRmsNormGradBlockUncachedImplBlockSize(stream, nrow, ncol, load_x, load_dy, store, inv_rms); } else { return DispatchLaunchRmsNormGradBlockUncachedImplBlockSize(stream, nrow, ncol, load_x, load_dy, store, inv_rms); } } template typename std::enable_if::value, cudaError_t>::type LaunchRmsNormGrad(cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { if (ncol <= 1024) { return DispatchLaunchRmsNormGradWarpImplPackSize(stream, nrow, ncol, load_x, load_dy, store, inv_rms); } else { bool dispatch_smem_impl_success = false; { cudaError_t err = TryDispatchLaunchRmsNormGradBlockSMemImplPackSize( stream, nrow, ncol, load_x, load_dy, store, inv_rms, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchLaunchRmsNormGradBlockUncachedImplPackSize(stream, nrow, ncol, load_x, load_dy, store, inv_rms); } return cudaSuccess; } } template typename std::enable_if::value, cudaError_t>::type LaunchRmsNormGrad(cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) { return DispatchLaunchRmsNormGradBlockUncachedImplPackSize(stream, nrow, ncol, load_x, load_dy, store, inv_rms); } template __global__ void RmsNormParamGrad(int nrow, int ncol, const T* __restrict__ dy, const T* __restrict__ x, const ComputeType* __restrict__ inv_rms, T* __restrict__ b_weight_grad) { __shared__ ComputeType dweight[kWarpSize][kWarpSize + 1]; ComputeType dweight_sum[nproc_per_thread]; #pragma unroll for (int i = 0; i < nproc_per_thread; ++i) { dweight_sum[i] = 0; } const int col = blockIdx.x * blockDim.x + threadIdx.x; if (col < ncol) { // a wave for one traverse (when nrow > warp_size * grad_dim_y) for (int j = blockIdx.y * kWarpSize + threadIdx.y; j < nrow; j += kWarpSize * gridDim.y) { #pragma unroll for (int i = 0; i < nproc_per_thread; ++i) { int row = j + i * blockDim.y; if (row < nrow) { int offset = row * ncol + col; const ComputeType dy_val = static_cast(dy[offset]); const ComputeType x_val = static_cast(x[offset]); const ComputeType inv_rms_val = inv_rms[row]; // collect dx from waves dweight_sum[i] += dy_val * x_val * inv_rms_val; } } } } // broadcast sum to the 
nproc_per_thread number rows // each warp process the nproc_per_thread number rows of smem #pragma unroll for (int i = 0; i < nproc_per_thread; ++i) { dweight[i * blockDim.y + threadIdx.y][threadIdx.x] = dweight_sum[i]; } __syncthreads(); // transpose access for leveraging warp to reduce rows in a block #pragma unroll for (int i = 0; i < nproc_per_thread; ++i) { // the first col of block threads is for storing the reduced sum of rows, // and each first col thread is writing the nproc_per_thread number cols of output const int row_in_block = threadIdx.y + i * blockDim.y; const int col = blockIdx.x * blockDim.x + row_in_block; if (col < ncol) { // each warp process a col in which reduce sum all rows ComputeType dweight_val = dweight[threadIdx.x][row_in_block]; ComputeType global_dweight = WarpReduceSum(dweight_val); if (threadIdx.x == 0) { const int offset = blockIdx.y * ncol + col; b_weight_grad[offset] = global_dweight; } } } } template cudaError_t GetGrid2Dim(const int64_t nrow, const int64_t ncol, int block_dim_x, int block_dim_y, int* grid_dim_x, int* grid_dim_y) { const int tile_size = block_dim_x; if (nproc_per_thread * block_dim_y != tile_size) { return cudaErrorInvalidValue; } *grid_dim_x = (ncol + tile_size - 1) / tile_size; const int num_blocks_y = (nrow + tile_size - 1) / tile_size; using ComputeType = typename layer_norm::DefaultComputeType::type; cudaError_t err = layer_norm::GetNumBlocks(RmsNormParamGrad, block_dim_x * block_dim_y, /*dynamic_smem_size*/ 0, num_blocks_y, /*waves*/ 1, grid_dim_y); if (err != cudaSuccess) { return err; } return cudaSuccess; } } // namespace rms_norm } // namespace cuda } // namespace oneflow #endif // ONEFLOW_CORE_CUDA_RMS_NORM_H_ ================================================ FILE: oneflow/core/cuda/softmax.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_CUDA_SOFTMAX_H_ #define ONEFLOW_CORE_CUDA_SOFTMAX_H_ #include #include #include #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 namespace oneflow { namespace cuda { namespace softmax { constexpr int kWarpSize = 32; template struct SumOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; } }; template struct MaxOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); } }; template class ReductionOp, typename T, int thread_group_width = kWarpSize> __inline__ __device__ T WarpAllReduce(T val) { for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask)); } return val; } template class ReductionOp, typename T, int block_size> __inline__ __device__ T BlockAllReduce(T val) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ T result_broadcast; T result = BlockReduce(temp_storage).Reduce(val, ReductionOp()); if (threadIdx.x == 0) { result_broadcast = result; } __syncthreads(); return result_broadcast; } template __inline__ __device__ T Inf(); template<> __inline__ __device__ float Inf() { return CUDART_INF_F; } template<> __inline__ __device__ double Inf() { return CUDART_INF; } template __inline__ __device__ T Exp(T x); template<> __inline__ __device__ float Exp(float x) { #ifdef OF_SOFTMAX_USE_FAST_MATH return __expf(x); #else return exp(x); #endif } template<> __inline__ __device__ double Exp(double x) { return exp(x); } template __inline__ __device__ T Div(T a, T b); template<> __inline__ __device__ float Div(float a, float b) { #ifdef OF_SOFTMAX_USE_FAST_MATH return __fdividef(a, b); #else return a / b; #endif } template<> __inline__ __device__ double Div(double a, double b) { return a / b; } template __inline__ __device__ T Log(T x); template<> __inline__ __device__ float Log(float x) { #ifdef OF_SOFTMAX_USE_FAST_MATH return __logf(x); #else return log(x); #endif } template<> __inline__ __device__ double Log(double x) { return log(x); } inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves, int* num_blocks) { int dev; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } int tpm; { cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); if (err != cudaSuccess) { return err; } } *num_blocks = std::max(1, std::min(max_blocks, sm_count * tpm / block_size * waves)); return cudaSuccess; } template struct DefaultComputeType { using type = T; }; template<> struct DefaultComputeType { using type = float; }; #if CUDA_VERSION >= 11000 template<> struct DefaultComputeType { using type = float; }; #endif // CUDA_VERSION >= 11000 template struct GetPackType { using type = typename std::aligned_storage::type; }; template using PackType = typename GetPackType::type; template union Pack { static_assert(sizeof(PackType) == sizeof(T) * N, ""); __device__ Pack() { // do nothing } PackType storage; T elem[N]; }; template struct DirectLoad { DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { Pack pack; const int64_t offset = (row * row_size + col) / N; pack.storage = *(reinterpret_cast*>(src) + offset); #pragma unroll for 
(int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); } } const SRC* src; int64_t row_size; }; template struct DirectStore { DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { Pack pack; const int64_t offset = (row * row_size + col) / N; #pragma unroll for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast(src[i]); } *(reinterpret_cast*>(dst) + offset) = pack.storage; } DST* dst; int64_t row_size; }; enum class Algorithm { kSoftmax = 0, kLogSoftmax = 1, }; template __global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) { static_assert(cols_per_thread % pack_size == 0, ""); static_assert(thread_group_width <= kWarpSize, ""); static_assert(kWarpSize % thread_group_width == 0, ""); constexpr int num_packs = cols_per_thread / pack_size; assert(cols <= cols_per_thread * thread_group_width); ComputeType buf[rows_per_access][cols_per_thread]; const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; const int num_global_thread_group = gridDim.x * blockDim.y; const int lane_id = threadIdx.x; const int64_t step = num_global_thread_group * rows_per_access; for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { ComputeType thread_max[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { thread_max[row_id] = -Inf(); ComputeType* row_buf = buf[row_id]; #pragma unroll for (int pack_id = 0; pack_id < num_packs; ++pack_id) { const int pack_offset = pack_id * pack_size; const int col = (pack_id * thread_group_width + lane_id) * pack_size; if (!padding || col < cols) { load.template load(row_buf + pack_offset, row + row_id, col); #pragma unroll for (int i = 0; i < pack_size; ++i) { thread_max[row_id] = max(thread_max[row_id], row_buf[pack_offset + i]); } } else { #pragma unroll for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = -Inf(); } } } } ComputeType warp_max[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { warp_max[row_id] = WarpAllReduce(thread_max[row_id]); } ComputeType thread_sum[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { thread_sum[row_id] = 0; ComputeType* row_buf = buf[row_id]; #pragma unroll for (int i = 0; i < cols_per_thread; ++i) { if (algorithm == Algorithm::kSoftmax) { row_buf[i] = Exp(row_buf[i] - warp_max[row_id]); thread_sum[row_id] += row_buf[i]; } else if (algorithm == Algorithm::kLogSoftmax) { row_buf[i] -= warp_max[row_id]; thread_sum[row_id] += Exp(row_buf[i]); } else { __trap(); } } } ComputeType warp_sum[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { warp_sum[row_id] = WarpAllReduce(thread_sum[row_id]); } #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { ComputeType* row_buf = buf[row_id]; #pragma unroll for (int i = 0; i < cols_per_thread; ++i) { if (algorithm == Algorithm::kSoftmax) { row_buf[i] = Div(row_buf[i], warp_sum[row_id]); } else if (algorithm == Algorithm::kLogSoftmax) { row_buf[i] -= Log(warp_sum[row_id]); } else { __trap(); } } #pragma unroll for (int i = 0; i < num_packs; ++i) { const int col = (i * thread_group_width + lane_id) * pack_size; if (!padding || col < cols) { store.template store(row_buf + i * pack_size, row + row_id, col); } } } } } template inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, 
const int64_t cols) { constexpr int block_size = 128; constexpr int waves = 32; static_assert(block_size % thread_group_width == 0, ""); constexpr int thread_groups_per_block = block_size / thread_group_width; dim3 block_dim(thread_group_width, thread_groups_per_block); const int64_t num_blocks = (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; int grid_dim_x; { cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } SoftmaxWarpImpl <<>>(load, store, rows, cols); return cudaPeekAtLastError(); } template inline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols == cols_per_thread * thread_group_width) { return LaunchSoftmaxWarpImpl( stream, load, store, rows, cols); } else { return LaunchSoftmaxWarpImpl( stream, load, store, rows, cols); } } template typename std::enable_if::type DispatchSoftmaxWarpImplCols( cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchSoftmaxWarpImplPadding(stream, load, store, \ rows, cols); \ } else { \ return DispatchSoftmaxWarpImplPadding(stream, load, store, \ rows, cols); \ } \ } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(col) \ else if (cols <= (col)*kWarpSize) { \ return DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); \ } DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(5) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(7) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(9) DEFINE_ONE_ELIF(10) DEFINE_ONE_ELIF(11) DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(13) DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(15) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(17) DEFINE_ONE_ELIF(18) DEFINE_ONE_ELIF(19) DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(21) DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(23) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(25) DEFINE_ONE_ELIF(26) DEFINE_ONE_ELIF(27) DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(29) DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(31) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template typename std::enable_if::type DispatchSoftmaxWarpImplCols( cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchSoftmaxWarpImplPadding(stream, load, store, \ rows, cols); \ } else { \ return DispatchSoftmaxWarpImplPadding(stream, load, store, \ rows, cols); \ } \ } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(col) \ else if (cols <= (col)*kWarpSize) { \ return DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(10) DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(18) DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(26) DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template struct 
DispatchSoftmaxWarpImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0) { return DispatchSoftmaxWarpImplCols(stream, load, store, rows, cols); } else { return DispatchSoftmaxWarpImplCols(stream, load, store, rows, cols); } } }; template inline cudaError_t DispatchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxWarpImplPackSize()(stream, load, store, rows, cols); } template __global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) { extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; auto* buf = reinterpret_cast(shared_buf); const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = cols / pack_size; for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { ComputeType thread_max = -Inf(); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { buf[i * num_packs + pack_id] = pack[i]; thread_max = max(thread_max, pack[i]); } } const ComputeType row_max = BlockAllReduce(thread_max); ComputeType thread_sum = 0; for (int col = tid; col < cols; col += block_size) { if (algorithm == Algorithm::kSoftmax) { const ComputeType exp_x = Exp(buf[col] - row_max); buf[col] = exp_x; thread_sum += exp_x; } else { const ComputeType x = buf[col] - row_max; buf[col] = x; thread_sum += Exp(x); } } const ComputeType row_sum = BlockAllReduce(thread_sum); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; #pragma unroll for (int i = 0; i < pack_size; ++i) { if (algorithm == Algorithm::kSoftmax) { pack[i] = Div(buf[i * num_packs + pack_id], row_sum); } else if (algorithm == Algorithm::kLogSoftmax) { pack[i] = buf[i * num_packs + pack_id] - Log(row_sum); } else { __trap(); } } store.template store(pack, row, pack_id * pack_size); } } } template inline cudaError_t LaunchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, int smem, const int64_t rows, const int64_t cols) { constexpr int waves = 32; int grid_dim_x; { cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } SoftmaxBlockSMemImpl <<>>(load, store, rows, cols); return cudaPeekAtLastError(); } template inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, bool* success) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; constexpr int block_size_conf_3 = 512; constexpr int block_size_conf_4 = 1024; const size_t smem = cols * sizeof(ComputeType); int max_active_blocks_conf_1; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_1, SoftmaxBlockSMemImpl, block_size_conf_1, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_1 <= 0) { *success = false; return cudaSuccess; } int max_active_blocks_conf_4; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_4, SoftmaxBlockSMemImpl, block_size_conf_4, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_4 == max_active_blocks_conf_1) { *success = true; return LaunchSoftmaxBlockSMemImpl(stream, load, store, smem, rows, cols); } int max_active_blocks_conf_3; { 
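/* Block-size selection note: block_size_conf_1 (128 threads) is the occupancy
   baseline. If even 128 threads cannot keep a single block resident per SM with
   `smem` bytes of dynamic shared memory, *success is set to false and the caller
   falls back to the uncached kernel. Otherwise the largest configuration among
   1024 (probed above), 512 and 256 (probed below) whose max active blocks per SM
   matches the 128-thread baseline is launched, with 128 as the final fallback. */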
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_3, SoftmaxBlockSMemImpl, block_size_conf_3, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_3 == max_active_blocks_conf_1) { *success = true; return LaunchSoftmaxBlockSMemImpl(stream, load, store, smem, rows, cols); } int max_active_blocks_conf_2; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_2, SoftmaxBlockSMemImpl, block_size_conf_2, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_2 == max_active_blocks_conf_1) { *success = true; return LaunchSoftmaxBlockSMemImpl(stream, load, store, smem, rows, cols); } *success = true; return LaunchSoftmaxBlockSMemImpl(stream, load, store, smem, rows, cols); } template struct TryDispatchSoftmaxBlockSMemImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, bool* success) { if (cols % 2 == 0) { return TryDispatchSoftmaxBlockSMemImplBlockSize( stream, load, store, rows, cols, success); } else { return TryDispatchSoftmaxBlockSMemImplBlockSize( stream, load, store, rows, cols, success); } } }; template inline cudaError_t TryDispatchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, bool* success) { return TryDispatchSoftmaxBlockSMemImplPackSize()( stream, load, store, rows, cols, success); } template __global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) { const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = cols / pack_size; for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { ComputeType thread_max = -Inf(); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { thread_max = max(thread_max, pack[i]); } } const ComputeType row_max = BlockAllReduce(thread_max); ComputeType thread_sum = 0; for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { thread_sum += Exp(pack[i] - row_max); } } const ComputeType row_sum = BlockAllReduce(thread_sum); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { if (algorithm == Algorithm::kSoftmax) { pack[i] = Div(Exp(pack[i] - row_max), row_sum); } else if (algorithm == Algorithm::kLogSoftmax) { pack[i] = (pack[i] - row_max) - Log(row_sum); } else { __trap(); } } store.template store(pack, row, pack_id * pack_size); } } } template inline cudaError_t LaunchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size = 1024; constexpr int waves = 32; int grid_dim_x; { cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } SoftmaxBlockUncachedImpl <<>>(load, store, rows, cols); return cudaPeekAtLastError(); } template struct DispatchSoftmaxBlockUncachedImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0) { return LaunchSoftmaxBlockUncachedImpl( stream, load, store, rows, 
cols); } else { return LaunchSoftmaxBlockUncachedImpl( stream, load, store, rows, cols); } } }; template inline cudaError_t DispatchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxBlockUncachedImplPackSize()( stream, load, store, rows, cols); } template inline typename std::enable_if::value, cudaError_t>::type DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols < 1024) { return DispatchSoftmaxWarpImpl( stream, load, store, rows, cols); } else { bool dispatch_smem_impl_success; { cudaError_t err = TryDispatchSoftmaxBlockSMemImpl( stream, load, store, rows, cols, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchSoftmaxBlockUncachedImpl( stream, load, store, rows, cols); } return cudaSuccess; } } template inline typename std::enable_if::value, cudaError_t>::type DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxBlockUncachedImpl( stream, load, store, rows, cols); } template inline typename std::enable_if::value, cudaError_t>::type DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 1024) { return DispatchSoftmaxWarpImpl( stream, load, store, rows, cols); } else { bool dispatch_smem_impl_success; { cudaError_t err = TryDispatchSoftmaxBlockSMemImpl( stream, load, store, rows, cols, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchSoftmaxBlockUncachedImpl( stream, load, store, rows, cols); } return cudaSuccess; } } template inline typename std::enable_if::value, cudaError_t>::type DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxBlockUncachedImpl( stream, load, store, rows, cols); } template __global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { static_assert(cols_per_thread % pack_size == 0, ""); constexpr int pack_per_thread = cols_per_thread / pack_size; assert(cols <= cols_per_thread * thread_group_width); static_assert(thread_group_width <= kWarpSize, ""); static_assert(kWarpSize % thread_group_width == 0, ""); ComputeType y_buf[rows_per_access][cols_per_thread]; ComputeType dy_buf[rows_per_access][cols_per_thread]; const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; const int num_global_thread_group = gridDim.x * blockDim.y; const int lane_id = threadIdx.x; const int64_t step = num_global_thread_group * rows_per_access; for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { ComputeType thread_sum[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { thread_sum[row_id] = 0; ComputeType* row_y_buf = y_buf[row_id]; ComputeType* row_dy_buf = dy_buf[row_id]; #pragma unroll for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) { const int pack_offset = pack_id * pack_size; const int col = (pack_id * thread_group_width + lane_id) * pack_size; if (!padding || col < cols) { load_y.template load(row_y_buf + pack_offset, row + row_id, col); load_dy.template load(row_dy_buf + pack_offset, row + row_id, col); #pragma unroll for (int i = 0; i < pack_size; ++i) { if (algorithm == Algorithm::kSoftmax) { thread_sum[row_id] += 
row_y_buf[pack_offset + i] * row_dy_buf[pack_offset + i]; } else if (algorithm == Algorithm::kLogSoftmax) { thread_sum[row_id] += row_dy_buf[pack_offset + i]; } else { __trap(); } } } } } ComputeType warp_sum[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { warp_sum[row_id] = WarpAllReduce(thread_sum[row_id]); } #pragma unroll for (int row_id = 0; row_id < rows_per_access; ++row_id) { ComputeType* row_y_buf = y_buf[row_id]; ComputeType* row_dy_buf = dy_buf[row_id]; #pragma unroll for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) { const int pack_offset = pack_id * pack_size; const int col = (pack_id * thread_group_width + lane_id) * pack_size; if (!padding || col < cols) { for (int i = 0; i < pack_size; ++i) { if (algorithm == Algorithm::kSoftmax) { row_dy_buf[pack_offset + i] = (row_dy_buf[pack_offset + i] - warp_sum[row_id]) * row_y_buf[pack_offset + i]; } else if (algorithm == Algorithm::kLogSoftmax) { row_dy_buf[pack_offset + i] -= Exp(row_y_buf[pack_offset + i]) * warp_sum[row_id]; } else { __trap(); } } store.template store(row_dy_buf + pack_offset, row + row_id, col); } } } } } template inline cudaError_t LaunchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size = 128; constexpr int waves = 32; static_assert(block_size % thread_group_width == 0, ""); constexpr int thread_groups_per_block = block_size / thread_group_width; dim3 block_dim(thread_group_width, thread_groups_per_block); const int64_t num_blocks = (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; int grid_dim_x; { cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } SoftmaxGradWarpImpl <<>>(load_y, load_dy, store, rows, cols); return cudaPeekAtLastError(); } template inline cudaError_t DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols == cols_per_thread * thread_group_width) { return LaunchSoftmaxGradWarpImpl(stream, load_y, load_dy, store, rows, cols); } else { return LaunchSoftmaxGradWarpImpl(stream, load_y, load_dy, store, rows, cols); } } template typename std::enable_if::type DispatchSoftmaxGradWarpImplCols( cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchSoftmaxGradWarpImplPadding( \ stream, load_y, load_dy, store, rows, cols); \ } else { \ return DispatchSoftmaxGradWarpImplPadding( \ stream, load_y, load_dy, store, rows, cols); \ } \ } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(col) \ else if (cols <= (col)*kWarpSize) { \ return DispatchSoftmaxGradWarpImplPadding(stream, load_y, load_dy, \ store, rows, cols); \ } DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(5) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(7) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(9) DEFINE_ONE_ELIF(10) DEFINE_ONE_ELIF(11) DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(13) DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(15) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(17) DEFINE_ONE_ELIF(18) DEFINE_ONE_ELIF(19) DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(21) 
DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(23) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(25) DEFINE_ONE_ELIF(26) DEFINE_ONE_ELIF(27) DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(29) DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(31) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template typename std::enable_if::type DispatchSoftmaxGradWarpImplCols( cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchSoftmaxGradWarpImplPadding( \ stream, load_y, load_dy, store, rows, cols); \ } else { \ return DispatchSoftmaxGradWarpImplPadding( \ stream, load_y, load_dy, store, rows, cols); \ } \ } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(col) \ else if (cols <= (col)*kWarpSize) { \ return DispatchSoftmaxGradWarpImplPadding(stream, load_y, load_dy, \ store, rows, cols); \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(10) DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(18) DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(26) DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template struct DispatchSoftmaxGradWarpImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0) { return DispatchSoftmaxGradWarpImplCols( stream, load_y, load_dy, store, rows, cols); } else { return DispatchSoftmaxGradWarpImplCols( stream, load_y, load_dy, store, rows, cols); } } }; template inline cudaError_t DispatchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxGradWarpImplPackSize()( stream, load_y, load_dy, store, rows, cols); } template __global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[]; auto* y_buf = reinterpret_cast(grad_shared_buf); auto* dy_buf = y_buf + cols; const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = cols / pack_size; for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { ComputeType thread_sum = 0; for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType y_pack[pack_size]; ComputeType dy_pack[pack_size]; load_y.template load(y_pack, row, pack_id * pack_size); load_dy.template load(dy_pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { y_buf[i * num_packs + pack_id] = y_pack[i]; dy_buf[i * num_packs + pack_id] = dy_pack[i]; if (algorithm == Algorithm::kSoftmax) { thread_sum += y_pack[i] * dy_pack[i]; } else if (algorithm == Algorithm::kLogSoftmax) { thread_sum += dy_pack[i]; } else { __trap(); } } } const ComputeType row_sum = BlockAllReduce(thread_sum); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; #pragma unroll for (int i = 0; i < pack_size; ++i) { if (algorithm == Algorithm::kSoftmax) { pack[i] = (dy_buf[i * num_packs + pack_id] - row_sum) * y_buf[i * num_packs + pack_id]; } else if 
(algorithm == Algorithm::kLogSoftmax) { pack[i] = dy_buf[i * num_packs + pack_id] - Exp(y_buf[i * num_packs + pack_id]) * row_sum; } else { __trap(); } } store.template store(pack, row, pack_id * pack_size); } } } template inline cudaError_t LaunchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, int smem, const int64_t rows, const int64_t cols) { constexpr int waves = 32; int grid_dim_x; { cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } SoftmaxGradBlockSMemImpl <<>>(load_y, load_dy, store, rows, cols); return cudaPeekAtLastError(); } template inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols, bool* success) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; constexpr int block_size_conf_3 = 512; constexpr int block_size_conf_4 = 1024; const size_t smem = cols * sizeof(ComputeType) * 2; int max_active_blocks_conf_1; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_1, SoftmaxGradBlockSMemImpl, block_size_conf_1, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_1 <= 0) { *success = false; return cudaSuccess; } int max_active_blocks_conf_4; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_4, SoftmaxGradBlockSMemImpl, block_size_conf_4, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_4 == max_active_blocks_conf_1) { *success = true; return LaunchSoftmaxGradBlockSMemImpl(stream, load_y, load_dy, store, smem, rows, cols); } int max_active_blocks_conf_3; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_3, SoftmaxGradBlockSMemImpl, block_size_conf_3, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_3 == max_active_blocks_conf_1) { *success = true; return LaunchSoftmaxGradBlockSMemImpl(stream, load_y, load_dy, store, smem, rows, cols); } int max_active_blocks_conf_2; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_2, SoftmaxGradBlockSMemImpl, block_size_conf_2, smem); if (err != cudaSuccess) { return err; } } if (max_active_blocks_conf_2 == max_active_blocks_conf_1) { *success = true; return LaunchSoftmaxGradBlockSMemImpl(stream, load_y, load_dy, store, smem, rows, cols); } *success = true; return LaunchSoftmaxGradBlockSMemImpl(stream, load_y, load_dy, store, smem, rows, cols); } template struct TryDispatchSoftmaxGradBlockSMemImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols, bool* success) { if (cols % 2 == 0) { return TryDispatchSoftmaxGradBlockSMemImplBlockSize(stream, load_y, load_dy, store, rows, cols, success); } else { return TryDispatchSoftmaxGradBlockSMemImplBlockSize(stream, load_y, load_dy, store, rows, cols, success); } } }; template inline cudaError_t TryDispatchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols, bool* success) { return TryDispatchSoftmaxGradBlockSMemImplPackSize()(stream, load_y, load_dy, store, rows, cols, success); } template __global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { const int tid = threadIdx.x; assert(cols % 
pack_size == 0); const int num_packs = cols / pack_size; for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { ComputeType thread_sum = 0; for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType y_pack[pack_size]; ComputeType dy_pack[pack_size]; load_y.template load(y_pack, row, pack_id * pack_size); load_dy.template load(dy_pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { if (algorithm == Algorithm::kSoftmax) { thread_sum += y_pack[i] * dy_pack[i]; } else if (algorithm == Algorithm::kLogSoftmax) { thread_sum += dy_pack[i]; } else { __trap(); } } } const ComputeType row_sum = BlockAllReduce(thread_sum); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType y_pack[pack_size]; ComputeType dy_pack[pack_size]; load_y.template load(y_pack, row, pack_id * pack_size); load_dy.template load(dy_pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { if (algorithm == Algorithm::kSoftmax) { dy_pack[i] = (dy_pack[i] - row_sum) * y_pack[i]; } else if (algorithm == Algorithm::kLogSoftmax) { dy_pack[i] -= Exp(y_pack[i]) * row_sum; } else { __trap(); } } store.template store(dy_pack, row, pack_id * pack_size); } } } template inline cudaError_t LaunchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size = 1024; constexpr int waves = 32; int grid_dim_x; { cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } } SoftmaxGradBlockUncachedImpl <<>>(load_y, load_dy, store, rows, cols); return cudaPeekAtLastError(); } template struct DispatchSoftmaxGradBlockUncachedImplPackSize { cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0 && cols > kWarpSize) { return LaunchSoftmaxGradBlockUncachedImpl( stream, load_y, load_dy, store, rows, cols); } else { return LaunchSoftmaxGradBlockUncachedImpl( stream, load_y, load_dy, store, rows, cols); } } }; template inline cudaError_t DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxGradBlockUncachedImplPackSize()(stream, load_y, load_dy, store, rows, cols); } template inline typename std::enable_if::value, cudaError_t>::type DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 1024) { return DispatchSoftmaxGradWarpImpl( stream, load_y, load_dy, store, rows, cols); } else { bool dispatch_smem_impl_success; { cudaError_t err = TryDispatchSoftmaxGradBlockSMemImpl( stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchSoftmaxGradBlockUncachedImpl(stream, load_y, load_dy, store, rows, cols); } return cudaSuccess; } } template inline typename std::enable_if::value, cudaError_t>::type DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxGradBlockUncachedImpl(stream, load_y, load_dy, store, rows, cols); } template inline typename std::enable_if::value, cudaError_t>::type DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) 
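/* The overload that follows is the generic-ComputeType path of
   DispatchLogSoftmaxGrad: rows with cols <= 1024 use the warp-per-row kernel;
   otherwise the shared-memory-cached block kernel is tried first, and the
   uncached block kernel is used only if the occupancy probe rejects the cached
   one. A minimal usage sketch, assuming the template parameters are
   <LOAD_Y, LOAD_DY, STORE, ComputeType> in that order and that y, dy, dx, rows,
   cols and stream are caller-provided (illustrative only, not a definitive recipe):

     oneflow::cuda::softmax::DirectLoad<half, float> load_y(y, cols);
     oneflow::cuda::softmax::DirectLoad<half, float> load_dy(dy, cols);
     oneflow::cuda::softmax::DirectStore<float, half> store(dx, cols);
     cudaError_t err = oneflow::cuda::softmax::DispatchLogSoftmaxGrad<
         decltype(load_y), decltype(load_dy), decltype(store), float>(
         stream, load_y, load_dy, store, rows, cols);
*/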
{ if (cols <= 1024) { return DispatchSoftmaxGradWarpImpl( stream, load_y, load_dy, store, rows, cols); } else { bool dispatch_smem_impl_success; { cudaError_t err = TryDispatchSoftmaxGradBlockSMemImpl( stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success); if (err != cudaSuccess) { return err; } } if (!dispatch_smem_impl_success) { return DispatchSoftmaxGradBlockUncachedImpl(stream, load_y, load_dy, store, rows, cols); } return cudaSuccess; } } template inline typename std::enable_if::value, cudaError_t>::type DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { return DispatchSoftmaxGradBlockUncachedImpl(stream, load_y, load_dy, store, rows, cols); } } // namespace softmax } // namespace cuda } // namespace oneflow #endif // ONEFLOW_CORE_CUDA_SOFTMAX_H_ ================================================ FILE: oneflow/core/cuda/unique.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_CUDA_UNIQUE_H_ #define ONEFLOW_CORE_CUDA_UNIQUE_H_ #include #include #include "oneflow/core/common/permutation_iterator.h" #include "oneflow/core/common/not_equal_to_previous_adjacent_iterator.h" namespace oneflow { namespace cuda { namespace unique { using Flag = uint32_t; static constexpr Flag kDefault = 0x0; static constexpr Flag kInputSorted = 0x1; static constexpr Flag kOutputInverseIndices = 0x1 << 1; static constexpr Flag kOutputCounts = 0x1 << 2; namespace { constexpr size_t kCudaAlignSize = 512; __device__ __host__ __forceinline__ size_t GetCudaAlignedSize(size_t size) { return (size + kCudaAlignSize - 1) / kCudaAlignSize * kCudaAlignSize; } template __device__ __host__ __forceinline__ T* PtrOffset(void* ptr, size_t offset) { return reinterpret_cast(reinterpret_cast(ptr) + offset); } __device__ __host__ __forceinline__ size_t max(size_t a, size_t b) { return a > b ? 
a : b; } template cudaError_t DoUnique(size_t n, const Key* sorted_in, Key* unique, Index* num_unique, void* workspace, size_t* workspace_size, cudaStream_t stream) { size_t ws = *workspace_size; cudaError_t err = cub::DeviceSelect::Unique( workspace, ws, sorted_in, unique, num_unique, n, stream); if (err != cudaSuccess) { return err; } if (*workspace_size == 0) { *workspace_size = ws; } return cudaSuccess; } template cudaError_t DoUniqueWithCounts(size_t n, const Key* sorted_in, Key* unique, Index* num_unique, Index* counts, void* workspace, size_t* workspace_size, cudaStream_t stream) { size_t ws = *workspace_size; cudaError_t err = cub::DeviceRunLengthEncode::Encode( workspace, ws, sorted_in, unique, counts, num_unique, n, stream); if (err != cudaSuccess) { return err; } if (*workspace_size == 0) { *workspace_size = ws; } return cudaSuccess; } template cudaError_t DispatchOutputCounts(Flag flag, size_t n, const Key* sorted_in, Key* unique, Index* num_unique, Index* counts, void* workspace, size_t* workspace_size, cudaStream_t stream) { size_t ws = *workspace_size; if ((flag & kOutputCounts) != 0) { cudaError_t err = DoUniqueWithCounts(n, sorted_in, unique, num_unique, counts, workspace, &ws, stream); if (err != cudaSuccess) { return err; } } else { cudaError_t err = DoUnique(n, sorted_in, unique, num_unique, workspace, &ws, stream); if (err != cudaSuccess) { return err; } } if (*workspace_size == 0) { *workspace_size = ws; } return cudaSuccess; } template cudaError_t DoGenInverseIndices(size_t n, const Key* sorted_in, InverseIndicesIter inverse_indices_iter, void* workspace, size_t* workspace_size, cudaStream_t stream) { size_t ws = *workspace_size; NotEqualToPreviousAdjacentIterator unique_counting_iter(sorted_in, 0); cudaError_t err = cub::DeviceScan::InclusiveSum( workspace, ws, unique_counting_iter, inverse_indices_iter, n, stream); if (err != cudaSuccess) { return err; } if (*workspace_size == 0) { *workspace_size = ws; } return cudaSuccess; } template cudaError_t DispatchOutputInverseIndices(Flag flag, size_t n, const Key* sorted_in, Key* unique, Index* num_unique, InverseIndicesIter inverse_indices_iter, Index* counts, void* workspace, size_t* workspace_size, cudaStream_t stream) { size_t dispatch_with_counts_ws = *workspace_size; size_t do_gen_inverse_indices_ws = *workspace_size; { cudaError_t err = DispatchOutputCounts(flag, n, sorted_in, unique, num_unique, counts, workspace, &dispatch_with_counts_ws, stream); if (err != cudaSuccess) { return err; } } if ((flag & kOutputInverseIndices) != 0) { cudaError_t err = DoGenInverseIndices( n, sorted_in, inverse_indices_iter, workspace, &do_gen_inverse_indices_ws, stream); if (err != cudaSuccess) { return err; } } if (*workspace_size == 0) { *workspace_size = max(dispatch_with_counts_ws, do_gen_inverse_indices_ws); } return cudaSuccess; } template __global__ void IotaKernel(size_t n, T* out) { for (T i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < n; i += step) { out[i] = i; } } template cudaError_t DoSort(size_t n, const Key* in, Key* sorted, Index* sorted_indices, void* workspace, size_t* workspace_size, cudaStream_t stream) { Index* indices; const size_t indices_size = GetCudaAlignedSize(n * sizeof(Index)); void* sort_workspace; size_t sort_ws; if (*workspace_size == 0) { indices = nullptr; sort_workspace = nullptr; sort_ws = 0; } else { if (*workspace_size <= indices_size) { return cudaErrorInvalidValue; } indices = PtrOffset(workspace, 0); sort_workspace = PtrOffset(workspace, indices_size); sort_ws 
= *workspace_size - indices_size; } if (*workspace_size != 0) { const int block_size = 1024; const int num_blocks = static_cast((n + block_size - 1) / block_size); IotaKernel<<>>(n, indices); } cudaError_t err = cub::DeviceRadixSort::SortPairs( sort_workspace, sort_ws, in, sorted, indices, sorted_indices, n, 0, sizeof(Key) * 8, stream); if (err != cudaSuccess) { return err; } if (*workspace_size == 0) { *workspace_size = indices_size + sort_ws; } return cudaSuccess; } template cudaError_t DispatchInputSorted(Flag flag, size_t n, const Key* in, Key* unique, Index* num_unique, Index* inverse_indices, Index* counts, void* workspace, size_t* workspace_size, cudaStream_t stream) { if ((flag & kInputSorted) != 0) { return DispatchOutputInverseIndices(flag, n, in, unique, num_unique, inverse_indices, counts, workspace, workspace_size, stream); } else { const size_t sorted_in_size = GetCudaAlignedSize(n * sizeof(Key)); const size_t sorted_indices_size = GetCudaAlignedSize(n * sizeof(Index)); const size_t sort_buffer_size = sorted_in_size + sorted_indices_size; Key* sorted_in; Index* sorted_indices; size_t do_sort_ws; void* do_sort_workspace; size_t do_inverse_indices_ws; void* do_inverse_indices_workspace; if (*workspace_size == 0) { sorted_in = nullptr; sorted_indices = nullptr; do_sort_ws = 0; do_sort_workspace = nullptr; do_inverse_indices_ws = 0; do_inverse_indices_workspace = nullptr; } else { if (*workspace_size <= sort_buffer_size) { return cudaErrorInvalidValue; } sorted_in = PtrOffset(workspace, 0); sorted_indices = PtrOffset(workspace, sorted_in_size); do_sort_ws = *workspace_size - sort_buffer_size; do_sort_workspace = PtrOffset(workspace, sort_buffer_size); do_inverse_indices_ws = do_sort_ws; do_inverse_indices_workspace = do_sort_workspace; } { cudaError_t err = DoSort(n, in, sorted_in, sorted_indices, do_sort_workspace, &do_sort_ws, stream); if (err != cudaSuccess) { return err; } } PermutationIterator inverse_indices_iter(inverse_indices, sorted_indices); { cudaError_t err = DispatchOutputInverseIndices( flag, n, sorted_in, unique, num_unique, inverse_indices_iter, counts, do_inverse_indices_workspace, &do_inverse_indices_ws, stream); if (err != cudaSuccess) { return err; } } if (*workspace_size == 0) { *workspace_size = sort_buffer_size + max(do_sort_ws, do_inverse_indices_ws); } return cudaSuccess; } } } // namespace template cudaError_t Launch(Flag flag, size_t n, const Key* in, Key* unique, Index* num_unique, Index* inverse_indices, Index* counts, void* workspace, size_t workspace_size, cudaStream_t stream) { if (workspace_size == 0) { return cudaErrorInvalidValue; } return DispatchInputSorted(flag, n, in, unique, num_unique, inverse_indices, counts, workspace, &workspace_size, stream); } template cudaError_t GetWorkspaceSize(Flag flag, size_t n, size_t* workspace_size) { *workspace_size = 0; return DispatchInputSorted(flag, n, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, workspace_size, 0); } } // namespace unique } // namespace cuda } // namespace oneflow #endif // ONEFLOW_CORE_CUDA_UNIQUE_H_ ================================================ FILE: oneflow/core/device/cuda_pseudo_bfloat16.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_BFLOAT16_H_ #define ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_BFLOAT16_H_ #ifdef WITH_CUDA #include #include #if CUDA_VERSION >= 11000 #include #endif #if CUDA_VERSION >= 11000 && CUDA_VERSION <= 12010 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(op) \ __device__ __forceinline__ __nv_bfloat16 operator op(const __nv_bfloat16& lh, \ const __nv_bfloat16& rh) { \ return __float2bfloat16(__bfloat162float(lh) op __bfloat162float(rh)); \ } DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(+) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(-) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(*) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(/) #undef DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR #define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(func) \ __device__ __forceinline__ __nv_bfloat16 __h##func(const __nv_bfloat16 a, \ const __nv_bfloat16 b) { \ return __float2bfloat16(__f##func##_rn(__bfloat162float(a), __bfloat162float(b))); \ } DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(add) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(div) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(mul) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(sub) #undef DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC #define DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(func) \ __device__ __forceinline__ __nv_bfloat162 __h##func##2(const __nv_bfloat162 a, \ const __nv_bfloat162 b) { \ __nv_bfloat162 ret; \ ret.x = __h##func(a.x, b.x); \ ret.y = __h##func(a.y, b.y); \ return ret; \ } DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(add) DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(div) DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(mul) DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(sub) #undef DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC #define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(op) \ __device__ __forceinline__ __nv_bfloat16& operator op(__nv_bfloat16& lh, \ const __nv_bfloat16& rh) { \ float lhv = __bfloat162float(lh); \ lhv op __bfloat162float(rh); \ lh = __float2bfloat16(lhv); \ return lh; \ } DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(+=) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(-=) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(*=) DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(/=) #undef DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR __device__ __forceinline__ __nv_bfloat16& operator++(__nv_bfloat16& h) { h = __float2bfloat16(__bfloat162float(h) + 1); return h; } __device__ __forceinline__ __nv_bfloat16& operator--(__nv_bfloat16& h) { h = __float2bfloat16(__bfloat162float(h) - 1); return h; } __device__ __forceinline__ __nv_bfloat16 operator++(__nv_bfloat16& h, int) { __nv_bfloat16 ret = h; h = __float2bfloat16(__bfloat162float(h) + 1); return ret; } __device__ __forceinline__ __nv_bfloat16 operator--(__nv_bfloat16& h, int) { __nv_bfloat16 ret = h; h = 
__float2bfloat16(__bfloat162float(h) - 1);
  return ret;
}

__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& h) { return h; }

__device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16& h) {
  return __float2bfloat16(-__bfloat162float(h));
}

__device__ __forceinline__ __nv_bfloat16 __hneg(const __nv_bfloat16 a) { return -a; }

#define DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(op)                                \
  __device__ __forceinline__ bool operator op(const __nv_bfloat16& lh, const __nv_bfloat16& rh) { \
    return __bfloat162float(lh) op __bfloat162float(rh);                                          \
  }

DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(==)
DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(!=)
DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(>)
DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(<)
DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(>=)
DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(<=)
#undef DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR

__device__ __forceinline__ bool __heq(const __nv_bfloat16 a, const __nv_bfloat16 b) { return a == b; }
__device__ __forceinline__ bool __hge(const __nv_bfloat16 a, const __nv_bfloat16 b) { return a >= b; }
__device__ __forceinline__ bool __hgt(const __nv_bfloat16 a, const __nv_bfloat16 b) { return a > b; }
__device__ __forceinline__ bool __hle(const __nv_bfloat16 a, const __nv_bfloat16 b) { return a <= b; }
__device__ __forceinline__ bool __hlt(const __nv_bfloat16 a, const __nv_bfloat16 b) { return a < b; }
__device__ __forceinline__ bool __hne(const __nv_bfloat16 a, const __nv_bfloat16 b) { return a != b; }

__device__ __forceinline__ __nv_bfloat16 __hmax(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  return a > b ? a : b;
}

__device__ __forceinline__ __nv_bfloat16 __hmin(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  return a < b ? a : b;
}

#define DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(func)                         \
  __device__ __forceinline__ __nv_bfloat16 h##func(const __nv_bfloat16 h) { \
    return __float2bfloat16(func##f(__bfloat162float(h)));                  \
  }

DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(cos)
DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(exp)
DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(exp10)
DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(exp2)
DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(log)
DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(log10)
DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(log2)

__device__ __forceinline__ __nv_bfloat16 hrcp(const __nv_bfloat16 h) {
  return __float2bfloat16(1.0f / __bfloat162float(h));
}

DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(rsqrt)
DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(sin)
DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(sqrt)
#undef DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC

#endif  // CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

#endif  // WITH_CUDA

#endif  // ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_BFLOAT16_H_

================================================
FILE: oneflow/core/device/cuda_pseudo_half.h
================================================

/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_HALF_H_ #define ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_HALF_H_ #ifdef WITH_CUDA #include #include #include #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 #define DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(op) \ __device__ __forceinline__ __half operator op(const __half& lh, const __half& rh) { \ return __float2half(__half2float(lh) op __half2float(rh)); \ } DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(+) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(-) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(*) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(/) #undef DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR #define DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(func) \ __device__ __forceinline__ __half __h##func(const __half a, const __half b) { \ return __float2half(__f##func##_rn(__half2float(a), __half2float(b))); \ } DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(add) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(div) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(mul) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(sub) #undef DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC #define DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(func) \ __device__ __forceinline__ __half2 __h##func##2(const __half2 a, const __half2 b) { \ __half2 ret; \ ret.x = __h##func(a.x, b.x); \ ret.y = __h##func(a.y, b.y); \ return ret; \ } DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(add) DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(div) DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(mul) DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(sub) #undef DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC #define DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(op) \ __device__ __forceinline__ __half& operator op(__half& lh, const __half& rh) { \ float lhv = __half2float(lh); \ lhv op __half2float(rh); \ lh = __float2half(lhv); \ return lh; \ } DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(+=) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(-=) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(*=) DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(/=) #undef DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR __device__ __forceinline__ __half& operator++(__half& h) { h = __float2half(__half2float(h) + 1); return h; } __device__ __forceinline__ __half& operator--(__half& h) { h = __float2half(__half2float(h) - 1); return h; } __device__ __forceinline__ __half operator++(__half& h, int) { __half ret = h; h = __float2half(__half2float(h) + 1); return ret; } __device__ __forceinline__ __half operator--(__half& h, int) { __half ret = h; h = __float2half(__half2float(h) - 1); return ret; } __device__ __forceinline__ __half operator+(const __half& h) { return h; } __device__ __forceinline__ __half operator-(const __half& h) { return __float2half(-__half2float(h)); } __device__ __forceinline__ __half __hneg(const __half a) { return -a; } #define DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(op) \ __device__ __forceinline__ bool operator op(const __half& lh, const __half& rh) { \ return __half2float(lh) op __half2float(rh); \ } DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(==) DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(!=) DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(>) DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(<) DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(>=) DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(<=) #undef 
DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR

__device__ __forceinline__ bool __heq(const __half a, const __half b) { return a == b; }
__device__ __forceinline__ bool __hge(const __half a, const __half b) { return a >= b; }
__device__ __forceinline__ bool __hgt(const __half a, const __half b) { return a > b; }
__device__ __forceinline__ bool __hle(const __half a, const __half b) { return a <= b; }
__device__ __forceinline__ bool __hlt(const __half a, const __half b) { return a < b; }
__device__ __forceinline__ bool __hne(const __half a, const __half b) { return a != b; }

__device__ __forceinline__ __half __hmax(const __half a, const __half b) { return a > b ? a : b; }
__device__ __forceinline__ __half __hmin(const __half a, const __half b) { return a < b ? a : b; }

#define DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(func)               \
  __device__ __forceinline__ __half h##func(const __half h) { \
    return __float2half(func##f(__half2float(h)));            \
  }

DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(cos)
DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(exp)
DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(exp10)
DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(exp2)
DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(log)
DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(log10)
DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(log2)

__device__ __forceinline__ __half hrcp(const __half h) {
  return __float2half(1.0f / __half2float(h));
}

DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(rsqrt)
DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(sin)
DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(sqrt)
#undef DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC

#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530

#endif  // WITH_CUDA

#endif  // ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_HALF_H_

================================================
FILE: oneflow/core/device/cuda_util.cpp
================================================

/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/hardware/node_device_descriptor_manager.h" #include "oneflow/core/hardware/cuda_device_descriptor.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/platform/include/pthread_fork.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/vm/vm_util.h" #ifdef WITH_CUDA #include #endif // WITH_CUDA namespace oneflow { #ifdef WITH_CUDA const char* CublasGetErrorString(cublasStatus_t error) { switch (error) { case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; #if CUDA_VERSION >= 6000 case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; #endif #if CUDA_VERSION >= 6050 case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; #endif default: return "Unknown cublas status"; } } const char* CurandGetErrorString(curandStatus_t error) { switch (error) { case CURAND_STATUS_SUCCESS: return "CURAND_STATUS_SUCCESS"; case CURAND_STATUS_VERSION_MISMATCH: return "CURAND_STATUS_VERSION_MISMATCH"; case CURAND_STATUS_NOT_INITIALIZED: return "CURAND_STATUS_NOT_INITIALIZED"; case CURAND_STATUS_ALLOCATION_FAILED: return "CURAND_STATUS_ALLOCATION_FAILED"; case CURAND_STATUS_TYPE_ERROR: return "CURAND_STATUS_TYPE_ERROR"; case CURAND_STATUS_OUT_OF_RANGE: return "CURAND_STATUS_OUT_OF_RANGE"; case CURAND_STATUS_LENGTH_NOT_MULTIPLE: return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; case CURAND_STATUS_LAUNCH_FAILURE: return "CURAND_STATUS_LAUNCH_FAILURE"; case CURAND_STATUS_PREEXISTING_FAILURE: return "CURAND_STATUS_PREEXISTING_FAILURE"; case CURAND_STATUS_INITIALIZATION_FAILED: return "CURAND_STATUS_INITIALIZATION_FAILED"; case CURAND_STATUS_ARCH_MISMATCH: return "CURAND_STATUS_ARCH_MISMATCH"; case CURAND_STATUS_INTERNAL_ERROR: return "CURAND_STATUS_INTERNAL_ERROR"; default: return "Unknown curand status"; } } const char* CuFFTGetErrorString(cufftResult_t error) { switch (error) { case CUFFT_SUCCESS: return "CUFFT_SUCCESS"; case CUFFT_INVALID_PLAN: return "CUFFT_INVALID_PLAN"; case CUFFT_ALLOC_FAILED: return "CUFFT_ALLOC_FAILED"; case CUFFT_INVALID_TYPE: return "CUFFT_INVALID_TYPE"; case CUFFT_INVALID_VALUE: return "CUFFT_INVALID_VALUE"; case CUFFT_INTERNAL_ERROR: return "CUFFT_INTERNAL_ERROR"; case CUFFT_EXEC_FAILED: return "CUFFT_EXEC_FAILED"; case CUFFT_SETUP_FAILED: return "CUFFT_SETUP_FAILED"; case CUFFT_INVALID_SIZE: return "CUFFT_INVALID_SIZE"; case CUFFT_UNALIGNED_DATA: return "CUFFT_UNALIGNED_DATA"; case CUFFT_INCOMPLETE_PARAMETER_LIST: return "CUFFT_INCOMPLETE_PARAMETER_LIST"; case CUFFT_INVALID_DEVICE: return "CUFFT_INVALID_DEVICE"; case CUFFT_PARSE_ERROR: return "CUFFT_PARSE_ERROR"; case CUFFT_NO_WORKSPACE: return "CUFFT_NO_WORKSPACE"; case CUFFT_NOT_IMPLEMENTED: return "CUFFT_NOT_IMPLEMENTED"; case 
CUFFT_NOT_SUPPORTED: return "CUFFT_NOT_SUPPORTED"; default: return "Unknown cufft status"; } } #if CUDA_VERSION >= 11000 const char* CusovlerGetErrorString(cusolverStatus_t error) { switch (error) { case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCESS"; case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED"; case CUSOLVER_STATUS_ALLOC_FAILED: return "CUSOLVER_STATUS_ALLOC_FAILED"; case CUSOLVER_STATUS_INVALID_VALUE: return "CUSOLVER_STATUS_INVALID_VALUE"; case CUSOLVER_STATUS_ARCH_MISMATCH: return "CUSOLVER_STATUS_ARCH_MISMATCH"; case CUSOLVER_STATUS_EXECUTION_FAILED: return "CUSOLVER_STATUS_EXECUTION_FAILED"; case CUSOLVER_STATUS_INTERNAL_ERROR: return "CUSOLVER_STATUS_INTERNAL_ERROR"; case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; default: return "Unknown cusolver status"; } } #endif #if CUDA_VERSION >= 10020 const char* NvjpegGetErrorString(nvjpegStatus_t error) { switch (error) { case NVJPEG_STATUS_SUCCESS: return "NVJPEG_STATUS_SUCCESS"; case NVJPEG_STATUS_NOT_INITIALIZED: return "NVJPEG_STATUS_NOT_INITIALIZED"; case NVJPEG_STATUS_INVALID_PARAMETER: return "NVJPEG_STATUS_INVALID_PARAMETER"; case NVJPEG_STATUS_BAD_JPEG: return "NVJPEG_STATUS_BAD_JPEG"; case NVJPEG_STATUS_JPEG_NOT_SUPPORTED: return "NVJPEG_STATUS_JPEG_NOT_SUPPORTED"; case NVJPEG_STATUS_ALLOCATOR_FAILURE: return "NVJPEG_STATUS_ALLOCATOR_FAILURE"; case NVJPEG_STATUS_EXECUTION_FAILED: return "NVJPEG_STATUS_EXECUTION_FAILED"; case NVJPEG_STATUS_ARCH_MISMATCH: return "NVJPEG_STATUS_ARCH_MISMATCH"; case NVJPEG_STATUS_INTERNAL_ERROR: return "NVJPEG_STATUS_INTERNAL_ERROR"; case NVJPEG_STATUS_IMPLEMENTATION_NOT_SUPPORTED: return "NVJPEG_STATUS_IMPLEMENTATION_NOT_SUPPORTED"; default: return "Unknown nvjpeg status"; } } #endif size_t GetAvailableGpuMemSize(int dev_id) { cudaDeviceProp prop{}; cudaGetDeviceProperties(&prop, dev_id); return prop.totalGlobalMem; } namespace { std::function GetCudaMallocHostFn(int32_t dev) { auto default_fn = [](void** ptr, size_t size) { return cudaMallocHost(ptr, size); }; auto manager = Singleton::Get(); if (manager == nullptr) { return default_fn; } auto node_desc = manager->GetLocalNodeDeviceDescriptor(); auto cuda_device = std::dynamic_pointer_cast( node_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev)); if (!cuda_device) { return default_fn; } auto saved_affinity = node_desc->Topology()->GetMemoryAffinity(); if (!saved_affinity) { return default_fn; } auto device_affinity = node_desc->Topology()->GetMemoryAffinityByPCIBusID(cuda_device->PCIBusID()); if (!device_affinity) { return default_fn; } return [device_affinity, saved_affinity, node_desc, default_fn](void** ptr, size_t size) { node_desc->Topology()->SetMemoryAffinity(device_affinity); cudaError_t err = default_fn(ptr, size); node_desc->Topology()->SetMemoryAffinity(saved_affinity); return err; }; } } // namespace cudaError_t NumaAwareCudaMallocHost(int32_t dev, void** ptr, size_t size) { auto fn = GetCudaMallocHostFn(dev); return fn(ptr, size); } CudaCurrentDeviceGuard::CudaCurrentDeviceGuard(int32_t dev_id) { CHECK(!pthread_fork::IsForkedSubProcess()) << pthread_fork::kOfCudaNotSupportInForkedSubProcess; OF_CUDA_CHECK(cudaGetDevice(&saved_dev_id_)); OF_CUDA_CHECK(cudaSetDevice(dev_id)); } CudaCurrentDeviceGuard::CudaCurrentDeviceGuard() { OF_CUDA_CHECK(cudaGetDevice(&saved_dev_id_)); } CudaCurrentDeviceGuard::~CudaCurrentDeviceGuard() { OF_CUDA_CHECK(cudaSetDevice(saved_dev_id_)); } CublasMathModeGuard::CublasMathModeGuard(cublasHandle_t 
handle, cublasMath_t new_mode) : CublasMathModeGuard(handle) { SetMathMode(new_mode); } CublasMathModeGuard::CublasMathModeGuard(cublasHandle_t handle) : handle_(handle) { OF_CUBLAS_CHECK(cublasGetMathMode(handle_, &saved_mode_)); new_mode_ = saved_mode_; } CublasMathModeGuard::~CublasMathModeGuard() { if (new_mode_ != saved_mode_) { OF_CUBLAS_CHECK(cublasSetMathMode(handle_, saved_mode_)); } } void CublasMathModeGuard::SetMathMode(cublasMath_t new_mode) { new_mode_ = new_mode; if (new_mode_ != saved_mode_) { OF_CUBLAS_CHECK(cublasSetMathMode(handle_, new_mode_)); } } void CudaSynchronize(int device_id) { CudaCurrentDeviceGuard dev_guard(device_id); OF_CUDA_CHECK(cudaDeviceSynchronize()); } void SetCudaDeviceIndex(int device_id) { OF_CUDA_CHECK(cudaSetDevice(device_id)); } int GetCudaDeviceIndex() { return GlobalProcessCtx::LocalRank(); } int GetCudaDeviceCount() { /* static */ int cuda_device_count = 0; OF_CUDA_CHECK(cudaGetDeviceCount(&cuda_device_count)); return cuda_device_count; } // NOTE(lixiang): Get the memory of the current device. Maybe GetCUDAMemoryUsed() { JUST(vm::CurrentRankSync()); int deviceCount = 0; cudaError_t error_id = cudaGetDeviceCount(&deviceCount); if (error_id != cudaSuccess) { return Error::RuntimeError() << "Error: GetCUDAMemoryUsed fails :" << cudaGetErrorString(error_id); } CHECK_OR_RETURN(deviceCount > 0) << "GPU device does not exist"; size_t gpu_total_size; size_t gpu_free_size; cudaError_t cuda_status = cudaMemGetInfo(&gpu_free_size, &gpu_total_size); CHECK_OR_RETURN(cudaSuccess == cuda_status) << "Error: GetCUDAMemoryUsed fails :" << cudaGetErrorString(cuda_status); double total_memory = double(gpu_total_size) / (1024.0 * 1024.0); double free_memory = double(gpu_free_size) / (1024.0 * 1024.0); return (total_memory - free_memory); } static std::once_flag prop_init_flag; static std::vector device_props; void InitDevicePropVectorSize() { int device_count = GetCudaDeviceCount(); device_props.resize(device_count); } void InitDeviceProperties(int device_id) { std::call_once(prop_init_flag, InitDevicePropVectorSize); cudaDeviceProp prop{}; OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id)); device_props[device_id] = prop; } cudaDeviceProp* GetDeviceProperties(int device_id) { InitCudaContextOnce(device_id); return &device_props[device_id]; } void InitCudaContextOnce(int device_id) { static int device_count = GetCudaDeviceCount(); static std::vector init_flags = std::vector(device_count); if (LazyMode::is_enabled()) { return; } if (device_id == -1) { device_id = GetCudaDeviceIndex(); } std::call_once(init_flags[device_id], [&]() { OF_CUDA_CHECK(cudaSetDevice(device_id)); OF_CUDA_CHECK(cudaDeviceSynchronize()); InitDeviceProperties(device_id); }); } cudaError_t CudaDriverGetPrimaryCtxActive(int dev, int* active) { #if CUDA_VERSION >= 11030 CUdevice cu_device{}; { CUresult (*fnCuDeviceGet)(CUdevice*, int) = nullptr; cudaError_t err = cudaGetDriverEntryPoint("cuDeviceGet", (void**)&fnCuDeviceGet, cudaEnableDefault); if (err != cudaSuccess) { return err; } CUresult result = fnCuDeviceGet(&cu_device, dev); if (result == CUDA_SUCCESS) { // do nothing } else if (result == CUresult::CUDA_ERROR_INVALID_DEVICE) { return cudaErrorInvalidDevice; } else { return cudaErrorUnknown; } } { CUresult (*fnCuDevicePrimaryCtxGetState)(CUdevice, unsigned int*, int*) = nullptr; cudaError_t err = cudaGetDriverEntryPoint( "cuDevicePrimaryCtxGetState", (void**)&fnCuDevicePrimaryCtxGetState, cudaEnableDefault); if (err != cudaSuccess) { return err; } unsigned int flags{}; CUresult 
result = fnCuDevicePrimaryCtxGetState(cu_device, &flags, active); if (result == CUDA_SUCCESS) { return cudaSuccess; } else { return cudaErrorUnknown; } } #else return cudaErrorNotSupported; #endif // CUDA_VERSION < 11030 } #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/core/device/cuda_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_DEVICE_CUDA_UTIL_H_ #define ONEFLOW_CORE_DEVICE_CUDA_UTIL_H_ #include "oneflow/core/common/data_type.h" #ifdef WITH_CUDA #include #if CUDA_VERSION >= 11000 #include #endif #include #if CUDA_VERSION >= 10010 #include #endif #include #include #include #include #include #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #include "oneflow/core/device/cuda_pseudo_half.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #if CUDA_VERSION >= 10020 #include #endif namespace oneflow { const char* CublasGetErrorString(cublasStatus_t error); const char* CurandGetErrorString(curandStatus_t error); const char* CuFFTGetErrorString(cufftResult_t error); #if CUDA_VERSION >= 11000 const char* CusovlerGetErrorString(cusolverStatus_t error); #endif #if CUDA_VERSION >= 10020 const char* NvjpegGetErrorString(nvjpegStatus_t error); #endif #define OF_CUDA_CHECK(condition) \ for (cudaError_t _of_cuda_check_status = (condition); _of_cuda_check_status != cudaSuccess;) \ LOG(FATAL) << "Check failed: " #condition " : " << cudaGetErrorString(_of_cuda_check_status) \ << " (" << _of_cuda_check_status << ") " #define OF_CUDNN_CHECK(condition) \ for (cudnnStatus_t _of_cudnn_check_status = (condition); \ _of_cudnn_check_status != CUDNN_STATUS_SUCCESS;) \ LOG(FATAL) << "Check failed: " #condition " : " << cudnnGetErrorString(_of_cudnn_check_status) \ << " (" << _of_cudnn_check_status << ") " #define OF_CUBLAS_CHECK(condition) \ for (cublasStatus_t _of_cublas_check_status = (condition); \ _of_cublas_check_status != CUBLAS_STATUS_SUCCESS;) \ LOG(FATAL) << "Check failed: " #condition " : " << CublasGetErrorString(_of_cublas_check_status) \ << " (" << _of_cublas_check_status << ") " #define OF_CUFFT_CHECK(condition) \ for (cufftResult_t _of_cufft_check_status = (condition); \ _of_cufft_check_status != CUFFT_SUCCESS;) \ LOG(FATAL) << "Check failed: " #condition " : " << CuFFTGetErrorString(_of_cufft_check_status) \ << " (" << _of_cufft_check_status << ") " #if CUDA_VERSION >= 11000 #define OF_CUSOLVER_CHECK(condition) \ for (cusolverStatus_t _of_cusolver_check_status = (condition); \ _of_cusolver_check_status != CUSOLVER_STATUS_SUCCESS;) \ LOG(FATAL) << "Check failed: " #condition " : " \ << CusovlerGetErrorString(_of_cusolver_check_status) << " (" \ << _of_cusolver_check_status << ") "; #endif #define OF_CURAND_CHECK(condition) \ for (curandStatus_t _of_curand_check_status = (condition); \ _of_curand_check_status != CURAND_STATUS_SUCCESS;) \ LOG(FATAL) << "Check failed: " #condition " : " << 
CurandGetErrorString(_of_curand_check_status) \ << " (" << _of_curand_check_status << ") " #define OF_NCCL_CHECK(condition) \ for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \ LOG(FATAL) << "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) \ << " (" << _of_nccl_check_status << "). " \ << "To see more detail, please run OneFlow with system variable NCCL_DEBUG=INFO" #define OF_NCCL_CHECK_OR_RETURN(condition) \ for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \ return Error::CheckFailedError().AddStackFrame([](const char* function) { \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \ return frame; \ }(__FUNCTION__)) \ << "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) << " (" \ << _of_nccl_check_status << ") " #if CUDA_VERSION >= 10020 #define OF_NVJPEG_CHECK(condition) \ for (nvjpegStatus_t _of_nvjpeg_check_status = (condition); \ _of_nvjpeg_check_status != NVJPEG_STATUS_SUCCESS;) \ LOG(FATAL) << "Check failed: " #condition " : " << NvjpegGetErrorString(_of_nvjpeg_check_status) \ << " (" << _of_nvjpeg_check_status << ") " #endif // CUDA: grid stride looping #define CUDA_1D_KERNEL_LOOP(i, n) \ for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < (n); \ i += step) #define CUDA_1D_KERNEL_LOOP_T(type, i, n) \ for (type i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < (n); \ i += step) const int32_t kCudaThreadsNumPerBlock = 512; const int32_t kCudaMaxBlocksNum = 8192; const int32_t kCudaWarpSize = 32; // 48KB, max byte size of shared memroy per thread block // TODO: limit of shared memory should be different for different arch const int32_t kCudaMaxSharedMemoryByteSize = 48 << 10; inline int64_t BlocksNum4ThreadsNum(const int64_t n) { CHECK_GT(n, 0); return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock, static_cast(kCudaMaxBlocksNum)); } #define RUN_CUDA_KERNEL(func, stream, elem_cnt, ...) \ stream->As()->LaunchKernel(func, elem_cnt, 1, __VA_ARGS__) size_t GetAvailableGpuMemSize(int dev_id); cudaError_t NumaAwareCudaMallocHost(int32_t dev, void** ptr, size_t size); class CudaCurrentDeviceGuard final { public: OF_DISALLOW_COPY_AND_MOVE(CudaCurrentDeviceGuard); explicit CudaCurrentDeviceGuard(int32_t dev_id); CudaCurrentDeviceGuard(); ~CudaCurrentDeviceGuard(); private: int32_t saved_dev_id_ = -1; }; class CublasMathModeGuard final { public: OF_DISALLOW_COPY_AND_MOVE(CublasMathModeGuard); CublasMathModeGuard(cublasHandle_t handle, cublasMath_t new_mode); explicit CublasMathModeGuard(cublasHandle_t handle); ~CublasMathModeGuard(); void SetMathMode(cublasMath_t new_mode); private: cublasHandle_t handle_{}; cublasMath_t saved_mode_{}; cublasMath_t new_mode_{}; }; int GetCudaDeviceIndex(); int GetCudaDeviceCount(); Maybe GetCUDAMemoryUsed(); cudaDeviceProp* GetDeviceProperties(int device_id); void SetCudaDeviceIndex(int device_id); void CudaSynchronize(int device_id); void InitCudaContextOnce(int device_id); cudaError_t CudaDriverGetPrimaryCtxActive(int dev, int* active); } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_DEVICE_CUDA_UTIL_H_ ================================================ FILE: oneflow/core/device/cudnn_conv_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/device/cudnn_conv_util.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/cached_caller.h" #include "oneflow/core/operator/operator_util.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/framework/op_kernel.h" namespace oneflow { namespace { template algo_t GetDefaultAlgo(); template<> cudnnConvolutionFwdAlgo_t GetDefaultAlgo() { return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; } template<> cudnnConvolutionBwdDataAlgo_t GetDefaultAlgo() { return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } template<> cudnnConvolutionBwdFilterAlgo_t GetDefaultAlgo() { return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } size_t ByteSize4Tensor(const int* dims, int ndim, cudnnDataType_t data_type) { size_t byte_size = GetCudnnDataTypeByteSize(data_type); FOR_RANGE(int, i, 0, ndim) { byte_size *= dims[i]; } return byte_size; } template void SetAlgo4Perf(const CudnnConvArgs& args, CudnnConvResource* res, perf_t* algo_perf, algo_t algo) { algo_perf->algo = algo; if (args.params.data_type == CUDNN_DATA_HALF) { algo_perf->mathType = CUDNN_TENSOR_OP_MATH; } else { algo_perf->mathType = CUDNN_DEFAULT_MATH; } OF_CUDNN_CHECK(GetCudnnConvWorkspaceSize(args, res, algo_perf->algo, &(algo_perf->memory))); algo_perf->status = CUDNN_STATUS_SUCCESS; } template perf_t GetBestAlgorithm(const CudnnConvArgs& args, CudnnConvResource* res, const std::vector& perf_vec) { using algo_t = decltype(std::declval().algo); if (perf_vec.size() == 0) { LOG(WARNING) << "There is no result with " << (args.heuristic ? "heuristic searching way." : "exhaustive searching way.") << " (max_workspace_size=" << args.params.max_ws_size << ")" << " Use default algo(" << GetDefaultAlgo() << ") instead."; perf_t perf; SetAlgo4Perf(args, res, &perf, GetDefaultAlgo()); return perf; } int found_algo_idx = -1; FOR_RANGE(size_t, i, 0, perf_vec.size()) { // Note: Shouldn't all returned results be successful? CHECK_EQ(perf_vec[i].status, CUDNN_STATUS_SUCCESS); if (perf_vec[i].memory > args.params.max_ws_size) { continue; } if (args.deterministic && perf_vec[i].determinism == CUDNN_NON_DETERMINISTIC) { continue; } found_algo_idx = i; break; } if (found_algo_idx == -1) { LOG(WARNING) << "Cannot find any algorithm meets requirements (max_workspace_size=" << args.params.max_ws_size << ", determinism=" << args.deterministic << ") using " << (args.heuristic ? "heuristic searching way." 
: "exhaustive searching way.") << " Using default algo(" << GetDefaultAlgo() << ") instead."; perf_t algo_perf; SetAlgo4Perf(args, res, &algo_perf, GetDefaultAlgo()); return algo_perf; } if (found_algo_idx != 0) { LOG(WARNING) << "Currently available alogrithm (algo=" << perf_vec[found_algo_idx].algo << ", require memory=" << perf_vec[found_algo_idx].memory << ", idx=" << found_algo_idx << ") meeting requirments (max_workspace_size=" << args.params.max_ws_size << ", determinism=" << args.deterministic << ") is not fastest. Fastest algorithm (" << perf_vec[0].algo << ") requires memory " << perf_vec[0].memory; } #if CUDNN_VERSION < 7500 // google [blacklist fft algorithms for strided dgrad] if (std::is_same::value) { int stride_dim = args.params.x_ndim - 2; bool blacklist = std::any_of(std::begin(args.params.stride), std::begin(args.params.stride) + stride_dim, [](int n) { return n != 1; }); if (blacklist && (static_cast(perf_vec[found_algo_idx].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING || static_cast(perf_vec[found_algo_idx].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) { perf_t algo_perf; SetAlgo4Perf(args, res, &algo_perf, GetDefaultAlgo()); return algo_perf; } } #endif return perf_vec.at(found_algo_idx); } template perf_t CudnnConvAlgoGetOrInfer(const CudnnConvParams& params, const std::function& InferFn, CudnnConvAlgoCache::Store* store, std::mutex* mutex) { const size_t cache_size = Singleton::Get()->thread_local_cache_max_size(); auto InferWithCache = [&](const CudnnConvParams& p) -> perf_t { CudnnConvParams params_without_ws = p; params_without_ws.max_ws_size = 0; std::unique_lock lock(*mutex); const auto& key_it = store->find(params_without_ws); if (key_it != store->cend()) { const auto& perf_it = std::find_if( key_it->second.cbegin(), key_it->second.cend(), [&](const std::pair& pair) { // There might be a case that only memory size pair.second.memory was required for the // best algorithm even though a workspace pair.first supplied return pair.second.memory <= p.max_ws_size /* for memory safety */ && pair.first >= p.max_ws_size /* a case with larger workspace infered before */; }); if (perf_it != key_it->second.cend()) { return perf_it->second; } } perf_t perf = InferFn(p); (*store)[params_without_ws].emplace_back(std::make_pair(p.max_ws_size, perf)); return perf; }; return ThreadLocalCachedCall(cache_size, InferWithCache, params); } } // namespace template<> cudnnConvolutionFwdAlgoPerf_t CudnnConvAlgoCache::Remember( const CudnnConvParams& params, const std::function& InferFn) { return CudnnConvAlgoGetOrInfer(params, InferFn, &fwd_algo_store_, &fwd_algo_store_mutex_); } template<> cudnnConvolutionBwdDataAlgoPerf_t CudnnConvAlgoCache::Remember( const CudnnConvParams& params, const std::function& InferFn) { return CudnnConvAlgoGetOrInfer( params, InferFn, &bwd_data_algo_store_, &bwd_data_algo_store_mutex_); } template<> cudnnConvolutionBwdFilterAlgoPerf_t CudnnConvAlgoCache::Remember( const CudnnConvParams& params, const std::function& InferFn) { return CudnnConvAlgoGetOrInfer( params, InferFn, &bwd_filter_algo_store_, &bwd_filter_algo_cache_mutex_); } CudnnConvDesc::~CudnnConvDesc() { OF_CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(val_)); } CudnnConvDesc::CudnnConvDesc(const DataType compute_type, const DataType data_type, const ShapeView& in_blob_shape, const user_op::InferContext& ctx) { int32_t opkernel_dim = in_blob_shape.NumAxes() - 2; OF_CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&val_)); const auto& padding_before = ctx.Attr>("padding_before"); const auto& 
strides = ctx.Attr>("strides"); const auto& dilation_rate = ctx.Attr>("dilation_rate"); if (opkernel_dim == 2) { OF_CUDNN_CHECK(cudnnSetConvolution2dDescriptor( val_, padding_before.at(0), padding_before.at(1), strides.at(0), strides.at(1), dilation_rate.at(0), dilation_rate.at(1), CUDNN_CROSS_CORRELATION, GetCudnnDataType(compute_type))); } else if (opkernel_dim == 1) { OF_CUDNN_CHECK(cudnnSetConvolution2dDescriptor(val_, padding_before.at(0), 0, strides.at(0), 1, dilation_rate.at(0), 1, CUDNN_CROSS_CORRELATION, GetCudnnDataType(compute_type))); } else { OF_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor( val_, opkernel_dim, padding_before.data(), strides.data(), dilation_rate.data(), CUDNN_CROSS_CORRELATION, GetCudnnDataType(compute_type))); } const int32_t groups = ctx.Attr("groups"); if (groups != 1) { OF_CUDNN_CHECK(cudnnSetConvolutionGroupCount(val_, groups)); } bool use_tensor_op_math; if (GetCudnnDataType(data_type) == CUDNN_DATA_HALF) { use_tensor_op_math = true; #if CUDNN_VERSION >= 8100 } else if (GetCudnnDataType(data_type) == CUDNN_DATA_BFLOAT16) { use_tensor_op_math = true; #endif } else { use_tensor_op_math = false; } if (use_tensor_op_math) { OF_CUDNN_CHECK(cudnnSetConvolutionMathType(val_, CUDNN_TENSOR_OP_MATH)); } } CudnnConvDesc::CudnnConvDesc(const DataType compute_type, const DataType data_type, const ShapeView& in_blob_shape, const user_op::KernelComputeContext& ctx) { int32_t opkernel_dim = in_blob_shape.NumAxes() - 2; OF_CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&val_)); const auto& padding_before = ctx.Attr>("padding_before"); const auto& strides = ctx.Attr>("strides"); const auto& dilation_rate = ctx.Attr>("dilation_rate"); if (opkernel_dim == 2) { OF_CUDNN_CHECK(cudnnSetConvolution2dDescriptor( val_, padding_before.at(0), padding_before.at(1), strides.at(0), strides.at(1), dilation_rate.at(0), dilation_rate.at(1), CUDNN_CROSS_CORRELATION, GetCudnnDataType(compute_type))); } else if (opkernel_dim == 1) { OF_CUDNN_CHECK(cudnnSetConvolution2dDescriptor(val_, padding_before.at(0), 0, strides.at(0), 1, dilation_rate.at(0), 1, CUDNN_CROSS_CORRELATION, GetCudnnDataType(compute_type))); } else { OF_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor( val_, opkernel_dim, padding_before.data(), strides.data(), dilation_rate.data(), CUDNN_CROSS_CORRELATION, GetCudnnDataType(compute_type))); } const int32_t groups = ctx.Attr("groups"); if (groups != 1) { OF_CUDNN_CHECK(cudnnSetConvolutionGroupCount(val_, groups)); } bool use_tensor_op_math; if (GetCudnnDataType(data_type) == CUDNN_DATA_HALF) { use_tensor_op_math = true; #if CUDNN_VERSION >= 8100 } else if (GetCudnnDataType(data_type) == CUDNN_DATA_BFLOAT16) { use_tensor_op_math = true; #endif } else { use_tensor_op_math = false; } if (use_tensor_op_math) { OF_CUDNN_CHECK(cudnnSetConvolutionMathType(val_, CUDNN_TENSOR_OP_MATH)); } } CudnnConvArgs::CudnnConvArgs(const user_op::InferContext& ctx, DataType x_data_type, const ShapeView& x_shape, DataType w_data_type, const ShapeView& w_shape, DataType y_data_type, const ShapeView& y_shape, const std::string& data_format, size_t max_workspace_size, bool heuristic_search, bool use_deterministic_algo_only, bool enable_pseudo_half) : xdesc(x_data_type, x_shape, data_format), ydesc(y_data_type, y_shape, data_format), wdesc(w_data_type, w_shape, data_format), cdesc(GetConvDescDataType(x_data_type, enable_pseudo_half), x_data_type, x_shape, ctx), heuristic(heuristic_search), deterministic(use_deterministic_algo_only) { std::memset(¶ms, 0, sizeof(CudnnConvParams)); 
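// NOTE: params is zero-initialized before being populated because CudnnConvParams is used as a
// plain-old-data cache key: operator==(CudnnConvParams, CudnnConvParams) compares raw bytes with
// memcmp, and std::hash<oneflow::CudnnConvParams> hashes the byte representation, so padding
// bytes must be in a deterministic state. The cudnnGet*NdDescriptor calls below then read the
// tensor/filter/convolution descriptors back into params, so identical convolution
// configurations map to the same entry in CudnnConvAlgoCache.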
OF_CUDNN_CHECK(cudnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims, ¶ms.x_data_type, ¶ms.x_ndim, params.x_dims, params.x_strides)); OF_CUDNN_CHECK(cudnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims, ¶ms.y_data_type, ¶ms.y_ndim, params.y_dims, params.y_strides)); OF_CUDNN_CHECK(cudnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims, ¶ms.w_data_type, ¶ms.w_format, ¶ms.w_ndim, params.w_dims)); cudnnConvolutionMode_t mode; int conv_dim_size = 0; OF_CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(cdesc.Get(), CudnnConvParams::kConvMaxDims, &conv_dim_size, params.padding, params.stride, params.dilation, &mode, ¶ms.data_type)); CHECK_EQ(params.x_data_type, params.w_data_type); CHECK_EQ(params.x_ndim, params.w_ndim); CHECK_EQ(conv_dim_size + 2, params.x_ndim); OF_CUDNN_CHECK(cudnnGetConvolutionGroupCount(cdesc.Get(), ¶ms.groups)); params.max_ws_size = max_workspace_size; } CudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType x_data_type, const ShapeView& x_shape, DataType w_data_type, const ShapeView& w_shape, DataType y_data_type, const ShapeView& y_shape, const std::string& data_format, size_t max_workspace_size, bool heuristic_search, bool use_deterministic_algo_only, bool enable_pseudo_half) : xdesc(x_data_type, x_shape, data_format), ydesc(y_data_type, y_shape, data_format), wdesc(w_data_type, w_shape, data_format), cdesc(GetConvDescDataType(x_data_type, enable_pseudo_half), x_data_type, x_shape, ctx), heuristic(heuristic_search), deterministic(use_deterministic_algo_only) { std::memset(¶ms, 0, sizeof(CudnnConvParams)); OF_CUDNN_CHECK(cudnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims, ¶ms.x_data_type, ¶ms.x_ndim, params.x_dims, params.x_strides)); OF_CUDNN_CHECK(cudnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims, ¶ms.y_data_type, ¶ms.y_ndim, params.y_dims, params.y_strides)); OF_CUDNN_CHECK(cudnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims, ¶ms.w_data_type, ¶ms.w_format, ¶ms.w_ndim, params.w_dims)); cudnnConvolutionMode_t mode; int conv_dim_size = 0; OF_CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(cdesc.Get(), CudnnConvParams::kConvMaxDims, &conv_dim_size, params.padding, params.stride, params.dilation, &mode, ¶ms.data_type)); CHECK_EQ(params.x_data_type, params.w_data_type); CHECK_EQ(params.x_ndim, params.w_ndim); CHECK_EQ(conv_dim_size + 2, params.x_ndim); OF_CUDNN_CHECK(cudnnGetConvolutionGroupCount(cdesc.Get(), ¶ms.groups)); params.max_ws_size = max_workspace_size; } ManagedCudnnConvResource::ManagedCudnnConvResource(const CudnnConvArgs& args) : handle_(nullptr), x_dptr_(nullptr), w_dptr_(nullptr), y_dptr_(nullptr), ws_dptr_(nullptr) { x_byte_size_ = ByteSize4Tensor(args.params.x_dims, args.params.x_ndim, args.params.x_data_type); w_byte_size_ = ByteSize4Tensor(args.params.w_dims, args.params.w_ndim, args.params.w_data_type); y_byte_size_ = ByteSize4Tensor(args.params.y_dims, args.params.y_ndim, args.params.y_data_type); ws_byte_size_ = args.params.max_ws_size; } ManagedCudnnConvResource::~ManagedCudnnConvResource() { if (handle_ != nullptr) { Singleton::Get()->Put(handle_); handle_ = nullptr; } if (x_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(x_dptr_)); } if (w_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(w_dptr_)); } if (y_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(y_dptr_)); } if (ws_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(ws_dptr_)); } } cudnnHandle_t ManagedCudnnConvResource::cudnn_handle() { if (handle_ == nullptr) { handle_ = 
Singleton::Get()->Get(); } return handle_; } void* ManagedCudnnConvResource::x_mut_dptr() { if (x_dptr_ == nullptr) { OF_CUDA_CHECK(cudaMalloc(&x_dptr_, x_byte_size_)); } return x_dptr_; } void* ManagedCudnnConvResource::w_mut_dptr() { if (w_dptr_ == nullptr) { OF_CUDA_CHECK(cudaMalloc(&w_dptr_, w_byte_size_)); } return w_dptr_; } void* ManagedCudnnConvResource::y_mut_dptr() { if (y_dptr_ == nullptr) { OF_CUDA_CHECK(cudaMalloc(&y_dptr_, y_byte_size_)); } return y_dptr_; } const void* ManagedCudnnConvResource::x_const_dptr() const { return const_cast(this)->x_mut_dptr(); } const void* ManagedCudnnConvResource::w_const_dptr() const { return const_cast(this)->w_mut_dptr(); } const void* ManagedCudnnConvResource::y_const_dptr() const { return const_cast(this)->y_mut_dptr(); } void* ManagedCudnnConvResource::ws_dptr() { if (ws_dptr_ == nullptr) { OF_CUDA_CHECK(cudaMalloc(&ws_dptr_, ws_byte_size_)); } return ws_dptr_; } bool operator==(const CudnnConvParams& a, const CudnnConvParams& b) { auto ptr1 = reinterpret_cast(&a); auto ptr2 = reinterpret_cast(&b); return memcmp(ptr1, ptr2, sizeof(CudnnConvParams)) == 0; } DataType GetConvDescDataType(DataType data_type, bool pseudo_half) { if (data_type == DataType::kFloat16 && pseudo_half) { return DataType::kFloat; } else if (data_type == DataType::kBFloat16) { return DataType::kFloat; } return data_type; } cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res, cudnnConvolutionFwdAlgo_t algo, size_t* sz) { return cudnnGetConvolutionForwardWorkspaceSize(res->cudnn_handle(), args.xdesc.Get(), args.wdesc.Get(), args.cdesc.Get(), args.ydesc.Get(), algo, sz); } cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res, cudnnConvolutionBwdDataAlgo_t algo, size_t* sz) { return cudnnGetConvolutionBackwardDataWorkspaceSize(res->cudnn_handle(), args.wdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.xdesc.Get(), algo, sz); } cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res, cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz) { return cudnnGetConvolutionBackwardFilterWorkspaceSize(res->cudnn_handle(), args.xdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.wdesc.Get(), algo, sz); } template<> struct CudnnConvAlgorithmSearch { using perf_t = cudnnConvolutionFwdAlgoPerf_t; static int GetAlgoMaxCount(CudnnConvResource* res) { int max_algo_cnt = 0; OF_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt)); return max_algo_cnt; } static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res, std::vector* perf_vec) { int found_algo_cnt = 0; perf_vec->resize(GetAlgoMaxCount(res)); OF_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7( res->cudnn_handle(), args.xdesc.Get(), args.wdesc.Get(), args.cdesc.Get(), args.ydesc.Get(), perf_vec->size(), &found_algo_cnt, perf_vec->data())); // vector::resize does not affect the first found_algo_cnt elements. 
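// NOTE: the vector is first sized to cudnnGetConvolutionForwardAlgorithmMaxCount() so that
// cudnnGetConvolutionForwardAlgorithm_v7 can report every candidate, then shrunk to the
// found_algo_cnt entries that were actually written. cuDNN returns these candidates ordered by
// expected performance, which GetBestAlgorithm relies on when it scans for the first entry that
// fits within max_ws_size and satisfies the determinism requirement.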
perf_vec->resize(found_algo_cnt); } static void ExhaustiveSearch(const CudnnConvArgs& args, CudnnConvResource* res, std::vector* perf_vec) { int found_algo_cnt = 0; perf_vec->resize(GetAlgoMaxCount(res)); OF_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx( res->cudnn_handle(), args.xdesc.Get(), res->x_const_dptr(), args.wdesc.Get(), res->w_const_dptr(), args.cdesc.Get(), args.ydesc.Get(), res->y_mut_dptr(), perf_vec->size(), &found_algo_cnt, perf_vec->data(), res->ws_dptr(), args.params.max_ws_size)); // vector::resize does not affect the first found_algo_cnt elements. perf_vec->resize(found_algo_cnt); } }; template<> struct CudnnConvAlgorithmSearch { using perf_t = cudnnConvolutionBwdDataAlgoPerf_t; static int GetAlgoMaxCount(CudnnConvResource* res) { int max_algo_cnt = 0; OF_CUDNN_CHECK( cudnnGetConvolutionBackwardDataAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt)); return max_algo_cnt; } static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res, std::vector* perf_vec) { int found_algo_cnt = 0; perf_vec->resize(GetAlgoMaxCount(res)); OF_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7( res->cudnn_handle(), args.wdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.xdesc.Get(), perf_vec->size(), &found_algo_cnt, perf_vec->data())); // vector::resize does not affect the first found_algo_cnt elements. perf_vec->resize(found_algo_cnt); } static void ExhaustiveSearch(const CudnnConvArgs& args, CudnnConvResource* res, std::vector* perf_vec) { int found_algo_cnt = 0; perf_vec->resize(GetAlgoMaxCount(res)); OF_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx( res->cudnn_handle(), args.wdesc.Get(), res->w_const_dptr(), args.ydesc.Get(), res->y_const_dptr(), args.cdesc.Get(), args.xdesc.Get(), res->x_mut_dptr(), perf_vec->size(), &found_algo_cnt, perf_vec->data(), res->ws_dptr(), args.params.max_ws_size)); // vector::resize does not affect the first found_algo_cnt elements. perf_vec->resize(found_algo_cnt); } }; template<> struct CudnnConvAlgorithmSearch { using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t; static int GetAlgoMaxCount(CudnnConvResource* res) { int max_algo_cnt = 0; OF_CUDNN_CHECK( cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt)); return max_algo_cnt; } static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res, std::vector* perf_vec) { int found_algo_cnt = 0; perf_vec->resize(GetAlgoMaxCount(res)); OF_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7( res->cudnn_handle(), args.xdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.wdesc.Get(), perf_vec->size(), &found_algo_cnt, perf_vec->data())); // vector::resize does not affect the first found_algo_cnt elements. perf_vec->resize(found_algo_cnt); } static void ExhaustiveSearch(const CudnnConvArgs& args, CudnnConvResource* res, std::vector* perf_vec) { int found_algo_cnt = 0; perf_vec->resize(GetAlgoMaxCount(res)); OF_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx( res->cudnn_handle(), args.xdesc.Get(), res->x_const_dptr(), args.ydesc.Get(), res->y_const_dptr(), args.cdesc.Get(), args.wdesc.Get(), res->w_mut_dptr(), perf_vec->size(), &found_algo_cnt, perf_vec->data(), res->ws_dptr(), args.params.max_ws_size)); // vector::resize does not affect the first found_algo_cnt elements. 
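// NOTE: unlike the heuristic path, cudnnFindConvolutionBackwardFilterAlgorithmEx actually times
// each algorithm, so it needs real device buffers for x/y/w plus a workspace; when no kernel
// buffers are supplied, ManagedCudnnConvResource lazily cudaMalloc's them from the descriptor
// sizes. A minimal usage sketch (illustrative only; the variable names are placeholders, and
// callers normally go through the cached entry points defined below rather than these search
// structs directly):
//   CudnnConvArgs args(ctx, x_dtype, x_shape, w_dtype, w_shape, y_dtype, y_shape, data_format,
//                      max_ws_size, /*heuristic_search=*/false,
//                      /*use_deterministic_algo_only=*/false, /*enable_pseudo_half=*/true);
//   auto perf = FindCudnnConvAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>(&args);
//   // perf.algo is the chosen algorithm and perf.memory the workspace size it requires.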
perf_vec->resize(found_algo_cnt); } }; template perf_t FindCudnnConvAlgorithm(CudnnConvArgs* args) { ManagedCudnnConvResource res(*args); return FindCudnnConvAlgorithmWithResource(args, &res); } template perf_t FindCudnnConvAlgorithmWithResource(CudnnConvArgs* args, CudnnConvResource* res) { auto Infer = [args, res](const CudnnConvParams& params) { std::vector perf_vec; if (args->heuristic) { CudnnConvAlgorithmSearch::HeuristicSearch(*args, res, &perf_vec); } else { CudnnConvAlgorithmSearch::ExhaustiveSearch(*args, res, &perf_vec); } return GetBestAlgorithm(*args, res, perf_vec); }; return Singleton::Get()->Remember(args->params, Infer); } template perf_t GetCudnnConvAlgorithmPerference(CudnnConvArgs* args, algo_t algo) { ManagedCudnnConvResource res(*args); return GetCudnnConvAlgorithmPerferenceWithResource(args, &res, algo); } template perf_t GetCudnnConvAlgorithmPerferenceWithResource(CudnnConvArgs* args, CudnnConvResource* res, algo_t algo) { perf_t perf; SetAlgo4Perf(*args, res, &perf, algo); return perf; } #define EXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(perf_t) \ template perf_t FindCudnnConvAlgorithm(CudnnConvArgs*); \ template perf_t FindCudnnConvAlgorithmWithResource(CudnnConvArgs*, CudnnConvResource*); \ template perf_t GetCudnnConvAlgorithmPerference(CudnnConvArgs*, \ decltype(std::declval().algo)); \ template perf_t GetCudnnConvAlgorithmPerferenceWithResource( \ CudnnConvArgs*, CudnnConvResource*, decltype(std::declval().algo)); EXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(cudnnConvolutionFwdAlgoPerf_t) EXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(cudnnConvolutionBwdDataAlgoPerf_t) EXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(cudnnConvolutionBwdFilterAlgoPerf_t) } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/device/cudnn_conv_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_DEVICE_CUDNN_CONV_UTIL_H_ #define ONEFLOW_CORE_DEVICE_CUDNN_CONV_UTIL_H_ #ifdef WITH_CUDA #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { namespace user_op { class KernelComputeContext; class InferContext; } // namespace user_op class CudnnConvDesc final { public: OF_DISALLOW_COPY_AND_MOVE(CudnnConvDesc); CudnnConvDesc() = delete; ~CudnnConvDesc(); CudnnConvDesc(const DataType compute_type, const DataType data_type, const ShapeView& in_blob_shape, const user_op::InferContext& ctx); CudnnConvDesc(const DataType compute_type, const DataType data_type, const ShapeView& in_blob_shape, const user_op::KernelComputeContext& ctx); const cudnnConvolutionDescriptor_t& Get() const { return val_; } private: cudnnConvolutionDescriptor_t val_; }; struct CudnnConvParams { static constexpr size_t kTensorMaxDims = 5; static constexpr size_t kConvMaxDims = 3; cudnnDataType_t x_data_type; cudnnDataType_t w_data_type; cudnnDataType_t y_data_type; cudnnDataType_t data_type; cudnnTensorFormat_t w_format; int x_ndim; int w_ndim; int y_ndim; int x_dims[kTensorMaxDims]; int x_strides[kTensorMaxDims]; int y_dims[kTensorMaxDims]; int y_strides[kTensorMaxDims]; int w_dims[kTensorMaxDims]; int padding[kConvMaxDims]; int stride[kConvMaxDims]; int dilation[kConvMaxDims]; size_t max_ws_size; int groups; }; struct CudnnConvArgs final { CudnnConvParams params; CudnnTensorDesc xdesc; CudnnTensorDesc ydesc; CudnnFilterDesc wdesc; CudnnConvDesc cdesc; bool heuristic; bool deterministic; OF_DISALLOW_COPY_AND_MOVE(CudnnConvArgs); CudnnConvArgs(const user_op::InferContext& ctx, DataType x_data_type, const ShapeView& x_shape, DataType w_data_type, const ShapeView& w_shape, DataType y_data_type, const ShapeView& y_shape, const std::string& data_format, size_t max_workspace_size, bool heuristic_search, bool use_deterministic_algo_only, bool enable_pseudo_half); CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType x_data_type, const ShapeView& x_shape, DataType w_data_type, const ShapeView& w_shape, DataType y_data_type, const ShapeView& y_shape, const std::string& data_format, size_t max_workspace_size, bool heuristic_search, bool use_deterministic_algo_only, bool enable_pseudo_half); }; class CudnnConvResource { public: CudnnConvResource() = default; virtual ~CudnnConvResource() = default; virtual cudnnHandle_t cudnn_handle() = 0; virtual void* w_mut_dptr() = 0; virtual void* x_mut_dptr() = 0; virtual void* y_mut_dptr() = 0; virtual const void* w_const_dptr() const = 0; virtual const void* x_const_dptr() const = 0; virtual const void* y_const_dptr() const = 0; virtual void* ws_dptr() = 0; }; class AllocatedCudnnConvResource final : public CudnnConvResource { public: AllocatedCudnnConvResource(cudnnHandle_t handle, void* x_dptr, void* w_dptr, void* y_dptr, void* ws_dptr) : handle_(handle), x_dptr_(x_dptr), w_dptr_(w_dptr), y_dptr_(y_dptr), ws_dptr_(ws_dptr) {} ~AllocatedCudnnConvResource() = default; cudnnHandle_t cudnn_handle() override { return handle_; } const void* x_const_dptr() const override { return x_dptr_; } const void* w_const_dptr() const override { return w_dptr_; } const void* y_const_dptr() const override { return y_dptr_; } void* x_mut_dptr() override { return x_dptr_; } void* w_mut_dptr() override { return w_dptr_; } void* y_mut_dptr() override { return y_dptr_; } void* ws_dptr() override { return ws_dptr_; } private: cudnnHandle_t handle_; void* x_dptr_; void* w_dptr_; void* y_dptr_; void* ws_dptr_; }; class 
ManagedCudnnConvResource final : public CudnnConvResource {
 public:
  ManagedCudnnConvResource(const CudnnConvArgs& args);
  ~ManagedCudnnConvResource() override;
  cudnnHandle_t cudnn_handle() override;
  void* x_mut_dptr() override;
  void* w_mut_dptr() override;
  void* y_mut_dptr() override;
  const void* x_const_dptr() const override;
  const void* w_const_dptr() const override;
  const void* y_const_dptr() const override;
  void* ws_dptr() override;

 private:
  cudnnHandle_t handle_;
  void* x_dptr_;
  void* w_dptr_;
  void* y_dptr_;
  void* ws_dptr_;
  size_t x_byte_size_;
  size_t w_byte_size_;
  size_t y_byte_size_;
  size_t ws_byte_size_;
};

bool operator==(const CudnnConvParams& a, const CudnnConvParams& b);

DataType GetConvDescDataType(DataType data_type, bool pseudo_half);

template<typename perf_t>
struct CudnnConvAlgorithmSearch;

cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,
                                        cudnnConvolutionFwdAlgo_t algo, size_t* sz);
cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,
                                        cudnnConvolutionBwdDataAlgo_t algo, size_t* sz);
cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,
                                        cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz);

template<typename perf_t>
perf_t FindCudnnConvAlgorithm(CudnnConvArgs* args);

template<typename perf_t>
perf_t FindCudnnConvAlgorithmWithResource(CudnnConvArgs* args, CudnnConvResource* res);

template<typename perf_t, typename algo_t>
perf_t GetCudnnConvAlgorithmPerference(CudnnConvArgs* args, algo_t algo);

template<typename perf_t, typename algo_t>
perf_t GetCudnnConvAlgorithmPerferenceWithResource(CudnnConvArgs* args, CudnnConvResource* res,
                                                   algo_t algo);

}  // namespace oneflow

namespace std {

// Hashing machinery for Params
// see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
template<>
struct hash<oneflow::CudnnConvParams> final {
  // Params must be a POD because we read out its memory
  // contents as char* when hashing
  static_assert(std::is_pod<oneflow::CudnnConvParams>::value, "CudnnConvParams is not POD");

  size_t operator()(const oneflow::CudnnConvParams& params) const {
    const auto* ptr = reinterpret_cast<const char*>(&params);
    uint32_t value = 0x811C9DC5;
    for (int i = 0; i < (int)sizeof(oneflow::CudnnConvParams); ++i) {
      value ^= ptr[i];
      value *= 0x01000193;
    }
    return (size_t)value;
  }
};

}  // namespace std

namespace oneflow {

class CudnnConvAlgoCache final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudnnConvAlgoCache);
  CudnnConvAlgoCache() = default;
  ~CudnnConvAlgoCache() = default;

  template<typename perf_t>
  using WorkspaceSizeAndPerfT = std::pair<size_t, perf_t>;
  template<typename perf_t>
  using Store = HashMap<CudnnConvParams, std::vector<WorkspaceSizeAndPerfT<perf_t>>>;

  template<typename perf_t>
  perf_t Remember(const CudnnConvParams& params,
                  const std::function<perf_t(const CudnnConvParams& param)>& InferFn);

 private:
  Store<cudnnConvolutionFwdAlgoPerf_t> fwd_algo_store_;
  std::mutex fwd_algo_store_mutex_;
  Store<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algo_store_;
  std::mutex bwd_data_algo_store_mutex_;
  Store<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algo_store_;
  std::mutex bwd_filter_algo_cache_mutex_;
};

}  // namespace oneflow

#endif  // WITH_CUDA

#endif  // ONEFLOW_CORE_DEVICE_CUDNN_CONV_UTIL_H_

================================================
FILE: oneflow/core/device/cudnn_util.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cudnn_util.h" namespace oneflow { #ifdef WITH_CUDA cudnnDataType_t GetCudnnDataType(DataType val) { #define MAKE_ENTRY(type_cpp, type_cudnn) \ if (val == GetDataType::value) { return type_cudnn; } OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CUDNN_DATA_TYPE_SEQ); #undef MAKE_ENTRY #if CUDNN_VERSION >= 8100 if (val == kBFloat16) { return CUDNN_DATA_BFLOAT16; } #endif UNIMPLEMENTED(); } CudnnTensorDesc::CudnnTensorDesc() { OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&val_)); } CudnnTensorDesc::~CudnnTensorDesc() { OF_CUDNN_CHECK(cudnnDestroyTensorDescriptor(val_)); } CudnnTensorDesc::CudnnTensorDesc(cudnnTensorFormat_t format, DataType data_type, int n, int c, int h, int w) { OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&val_)); OF_CUDNN_CHECK(cudnnSetTensor4dDescriptor(val_, format, GetCudnnDataType(data_type), n, c, h, w)); } CudnnTensorDesc::CudnnTensorDesc(DataType data_type, int dims, const int* dim, const int* stride) { OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&val_)); OF_CUDNN_CHECK(cudnnSetTensorNdDescriptor(val_, GetCudnnDataType(data_type), dims, dim, stride)); } CudnnTensorDesc::CudnnTensorDesc(DataType data_type, const ShapeView& shape, const std::string& data_format) { OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&val_)); cudnnTensorFormat_t cudnn_data_format; if (data_format == "channels_first") { cudnn_data_format = CUDNN_TENSOR_NCHW; } else if (data_format == "channels_last") { cudnn_data_format = CUDNN_TENSOR_NHWC; } else { UNIMPLEMENTED(); } if (shape.NumAxes() == 3) { int data_num = static_cast(shape.At(0)); int channels = data_format == "channels_first" ? static_cast(shape.At(1)) : static_cast(shape.At(2)); int kernel_h = data_format == "channels_first" ? static_cast(shape.At(2)) : static_cast(shape.At(1)); int kernel_w = 1; OF_CUDNN_CHECK(cudnnSetTensor4dDescriptor(val_, cudnn_data_format, GetCudnnDataType(data_type), data_num, channels, kernel_h, kernel_w)); } else if (shape.NumAxes() == 4) { int data_num = static_cast(shape.At(0)); int channels = data_format == "channels_first" ? static_cast(shape.At(1)) : static_cast(shape.At(3)); int kernel_h = data_format == "channels_first" ? static_cast(shape.At(2)) : static_cast(shape.At(1)); int kernel_w = data_format == "channels_first" ? 
static_cast(shape.At(3)) : static_cast(shape.At(2)); OF_CUDNN_CHECK(cudnnSetTensor4dDescriptor(val_, cudnn_data_format, GetCudnnDataType(data_type), data_num, channels, kernel_h, kernel_w)); } else { std::vector tensor_dim({shape.ptr(), shape.ptr() + shape.NumAxes()}); std::vector stride_of_tensor(shape.NumAxes(), 1); for (int32_t i = shape.NumAxes() - 2; i >= 0; --i) { stride_of_tensor[i] = stride_of_tensor[i + 1] * shape.At(i + 1); } OF_CUDNN_CHECK(cudnnSetTensorNdDescriptor(val_, GetCudnnDataType(data_type), shape.NumAxes(), tensor_dim.data(), stride_of_tensor.data())); } } CudnnFilterDesc::~CudnnFilterDesc() { OF_CUDNN_CHECK(cudnnDestroyFilterDescriptor(val_)); } CudnnFilterDesc::CudnnFilterDesc(DataType data_type, const ShapeView& shape, const std::string& data_format) { OF_CUDNN_CHECK(cudnnCreateFilterDescriptor(&val_)); cudnnTensorFormat_t cudnn_data_format; if (data_format == "channels_first") { cudnn_data_format = CUDNN_TENSOR_NCHW; } else if (data_format == "channels_last") { cudnn_data_format = CUDNN_TENSOR_NHWC; } else { UNIMPLEMENTED(); } if (shape.NumAxes() == 3) { int filters = static_cast(shape.At(0)); int c = data_format == "channels_first" ? static_cast(shape.At(1)) : static_cast(shape.At(2)); int kernel_h = data_format == "channels_first" ? static_cast(shape.At(2)) : static_cast(shape.At(1)); int kernel_w = 1; OF_CUDNN_CHECK(cudnnSetFilter4dDescriptor(val_, GetCudnnDataType(data_type), cudnn_data_format, filters, c, kernel_h, kernel_w)); } else if (shape.NumAxes() == 4) { int filters = static_cast(shape.At(0)); int kernel_h = data_format == "channels_first" ? static_cast(shape.At(2)) : static_cast(shape.At(1)); int kernel_w = data_format == "channels_first" ? static_cast(shape.At(3)) : static_cast(shape.At(2)); int c = data_format == "channels_first" ? 
static_cast(shape.At(1)) : static_cast(shape.At(3)); OF_CUDNN_CHECK(cudnnSetFilter4dDescriptor(val_, GetCudnnDataType(data_type), cudnn_data_format, filters, c, kernel_h, kernel_w)); } else { std::vector dims({shape.ptr(), shape.ptr() + shape.NumAxes()}); OF_CUDNN_CHECK(cudnnSetFilterNdDescriptor(val_, GetCudnnDataType(data_type), cudnn_data_format, dims.size(), dims.data())); } } CudnnActivationDesc::CudnnActivationDesc(cudnnActivationMode_t mode, cudnnNanPropagation_t relu_nan_opt, double coef) { OF_CUDNN_CHECK(cudnnCreateActivationDescriptor(&val_)); OF_CUDNN_CHECK(cudnnSetActivationDescriptor(val_, mode, relu_nan_opt, coef)); } CudnnActivationDesc::~CudnnActivationDesc() { OF_CUDNN_CHECK(cudnnDestroyActivationDescriptor(val_)); } size_t GetCudnnDataTypeByteSize(cudnnDataType_t data_type) { size_t byte_size = 0; switch (data_type) { case CUDNN_DATA_FLOAT: case CUDNN_DATA_INT32: case CUDNN_DATA_INT8x4: case CUDNN_DATA_UINT8x4: { byte_size = 4; break; } case CUDNN_DATA_DOUBLE: { byte_size = 8; break; } case CUDNN_DATA_HALF: { byte_size = 2; break; } case CUDNN_DATA_INT8: case CUDNN_DATA_UINT8: { byte_size = 1; break; } #if CUDNN_VERSION > 7200 case CUDNN_DATA_INT8x32: { byte_size = 32; break; } #endif #if CUDNN_VERSION >= 8100 case CUDNN_DATA_BFLOAT16: { byte_size = 2; break; } #endif default: { UNIMPLEMENTED(); } } return byte_size; } CudnnHandlePool::~CudnnHandlePool() { for (auto& pair : handle_list_map_) { int64_t device_id = pair.first; auto& handle_list = pair.second; CudaCurrentDeviceGuard guard(device_id); while (!handle_list.empty()) { cudnnHandle_t handle = handle_list.back(); handle_list.pop_back(); OF_CUDNN_CHECK(cudnnDestroy(handle)); } } handle_list_map_.clear(); } cudnnHandle_t CudnnHandlePool::Get() { int device_id; OF_CUDA_CHECK(cudaGetDevice(&device_id)); { std::unique_lock lock(mutex_); std::vector& handle_list = handle_list_map_[device_id]; if (!handle_list.empty()) { cudnnHandle_t handle = handle_list.back(); handle_list.pop_back(); return handle; } } cudnnHandle_t handle; OF_CUDNN_CHECK(cudnnCreate(&handle)); return handle; } void CudnnHandlePool::Put(cudnnHandle_t handle) { int device_id; OF_CUDA_CHECK(cudaGetDevice(&device_id)); std::unique_lock lock(mutex_); std::vector& handle_list = handle_list_map_[device_id]; handle_list.push_back(handle); } #endif // WITH_CUDA template const void* CudnnSPOnePtr() { static const float fval = 1.0f; static const double dval = 1.0; const void* ret = std::is_same::value ? static_cast(&dval) : static_cast(&fval); return ret; } template const void* CudnnSPZeroPtr() { static const float fval = 0.0f; static const double dval = 0.0; const void* ret = std::is_same::value ? 
static_cast(&dval) : static_cast(&fval); return ret; } template const void* CudnnSPOnePtr(); template const void* CudnnSPOnePtr(); template const void* CudnnSPOnePtr(); template const void* CudnnSPZeroPtr(); template const void* CudnnSPZeroPtr(); template const void* CudnnSPZeroPtr(); const void* CudnnSPOnePtr(const DataType dtype) { if (dtype == kDouble) { return CudnnSPOnePtr(); } else if (dtype == kFloat) { return CudnnSPOnePtr(); } else if (dtype == kFloat16) { return CudnnSPOnePtr(); } else if (dtype == kBFloat16) { // NOTE(guoran): kBFloat16 use float OnePtr return CudnnSPOnePtr(); } else { UNIMPLEMENTED(); } } const void* CudnnSPZeroPtr(const DataType dtype) { if (dtype == kDouble) { return CudnnSPZeroPtr(); } else if (dtype == kFloat) { return CudnnSPZeroPtr(); } else if (dtype == kFloat16) { return CudnnSPZeroPtr(); } else if (dtype == kBFloat16) { // NOTE(guoran): kBFloat16 use float ZeroPtr return CudnnSPZeroPtr(); } else { UNIMPLEMENTED(); } } } // namespace oneflow ================================================ FILE: oneflow/core/device/cudnn_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_DEVICE_CUDNN_UTIL_H_ #define ONEFLOW_CORE_DEVICE_CUDNN_UTIL_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/shape_view.h" #ifdef WITH_CUDA #include "cudnn.h" namespace oneflow { #define CUDNN_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(float, CUDNN_DATA_FLOAT) \ OF_PP_MAKE_TUPLE_SEQ(float16, CUDNN_DATA_HALF) \ OF_PP_MAKE_TUPLE_SEQ(double, CUDNN_DATA_DOUBLE) \ OF_PP_MAKE_TUPLE_SEQ(int8_t, CUDNN_DATA_INT8) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, CUDNN_DATA_INT32) cudnnDataType_t GetCudnnDataType(DataType); template struct CudnnDataType; #define SPECIALIZE_CUDNN_DATA_TYPE(type_cpp, type_cudnn) \ template<> \ struct CudnnDataType : std::integral_constant {}; OF_PP_FOR_EACH_TUPLE(SPECIALIZE_CUDNN_DATA_TYPE, CUDNN_DATA_TYPE_SEQ); #undef SPECIALIZE_CUDNN_DATA_TYPE class CudnnTensorDesc final { public: OF_DISALLOW_COPY_AND_MOVE(CudnnTensorDesc); CudnnTensorDesc(); ~CudnnTensorDesc(); CudnnTensorDesc(cudnnTensorFormat_t, DataType, int n, int c, int h, int w); CudnnTensorDesc(DataType data_type, int dims, const int* dim, const int* stride); CudnnTensorDesc(DataType data_type, const ShapeView& shape, const std::string& data_format); const cudnnTensorDescriptor_t& Get() const { return val_; } private: cudnnTensorDescriptor_t val_; }; class CudnnFilterDesc final { public: OF_DISALLOW_COPY_AND_MOVE(CudnnFilterDesc); CudnnFilterDesc() = delete; ~CudnnFilterDesc(); CudnnFilterDesc(DataType data_type, const ShapeView& shape, const std::string& data_format); const cudnnFilterDescriptor_t& Get() const { return val_; } private: cudnnFilterDescriptor_t val_; }; class CudnnActivationDesc final { public: OF_DISALLOW_COPY_AND_MOVE(CudnnActivationDesc); CudnnActivationDesc() = delete; ~CudnnActivationDesc(); CudnnActivationDesc(cudnnActivationMode_t mode, cudnnNanPropagation_t relu_nan_opt, double 
coef); const cudnnActivationDescriptor_t& Get() const { return val_; } private: cudnnActivationDescriptor_t val_; }; size_t GetCudnnDataTypeByteSize(cudnnDataType_t data_type); // SP for scaling parameter template const void* CudnnSPOnePtr(); template const void* CudnnSPZeroPtr(); const void* CudnnSPOnePtr(const DataType dtype); const void* CudnnSPZeroPtr(const DataType dtype); class CudnnHandlePool { public: CudnnHandlePool() = default; ~CudnnHandlePool(); cudnnHandle_t Get(); void Put(cudnnHandle_t handle); private: std::mutex mutex_; HashMap> handle_list_map_; }; } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_DEVICE_CUDNN_UTIL_H_ ================================================ FILE: oneflow/core/device/device_id.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/device_id.h" namespace oneflow { namespace { constexpr size_t kInt32Bits = sizeof(int32_t) * CHAR_BIT; constexpr size_t kDeviceIndexShift = 0; constexpr size_t kDeviceTypeShift = kDeviceIndexShift + DeviceId::kDeviceIndexBits; constexpr size_t kRankShift = kDeviceTypeShift + DeviceId::kDeviceTypeBits; static_assert(kRankShift + DeviceId::kRankBits < kInt32Bits, ""); } // namespace int64_t EncodeDeviceIdToInt64(const DeviceId& device_id) { int64_t id = static_cast(device_id.device_index()); id |= static_cast(device_id.device_type()) << kDeviceTypeShift; id |= static_cast(device_id.rank()) << kRankShift; return id; } } // namespace oneflow ================================================ FILE: oneflow/core/device/device_id.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
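CudnnHandlePool above hands out cudnnHandle_t objects keyed by the currently active CUDA device, so Get() and Put() must run with the intended device set. A minimal borrow/return sketch follows; it is not repository code, and BorrowedCudnnHandle is an editorial convenience type, not an existing OneFlow class.

#ifdef WITH_CUDA
// Sketch: RAII wrapper that borrows a handle from a pool and returns it on scope exit.
class BorrowedCudnnHandle {
 public:
  explicit BorrowedCudnnHandle(oneflow::CudnnHandlePool* pool)
      : pool_(pool), handle_(pool->Get()) {}
  ~BorrowedCudnnHandle() { pool_->Put(handle_); }  // return the handle for reuse
  cudnnHandle_t get() const { return handle_; }

 private:
  oneflow::CudnnHandlePool* pool_;
  cudnnHandle_t handle_;
};
#endif  // WITH_CUDA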
*/ #ifndef ONEFLOW_CORE_DEVICE_DEVICE_ID_H_ #define ONEFLOW_CORE_DEVICE_DEVICE_ID_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/device_type.h" namespace oneflow { // DeviceId encoding (bits) // | reserved | node_index | device_type | device_index | // | --- 1 ---- | ----------- 19 ----------- | ---- 5 ---- | ----- 7 ----- | // | DeviceId | // | ------------------------------- 32 ---------------------------------- | class DeviceId { public: using rank_t = uint32_t; using device_type_t = uint32_t; using device_index_t = uint32_t; constexpr static size_t kRankBits = 16; constexpr static size_t kDeviceTypeBits = 5; constexpr static size_t kDeviceIndexBits = 7; constexpr static rank_t kMaxRank = (rank_t{1} << kRankBits) - rank_t{1}; constexpr static device_type_t kMaxDeviceTypeVal = (device_type_t{1} << kDeviceTypeBits) - device_type_t{1}; constexpr static device_index_t kMaxDeviceIndex = (device_index_t{1} << kDeviceIndexBits) - device_index_t{1}; DeviceId(rank_t rank, DeviceType device_type, device_index_t device_index) : rank_(rank), device_type_(static_cast(device_type)), device_index_(device_index) { CHECK_LE(rank_, kMaxRank); CHECK_LE(device_type_, kMaxDeviceTypeVal); CHECK_LE(device_index_, kMaxDeviceIndex); } rank_t rank() const { return rank_; } DeviceType device_type() const { return static_cast(device_type_); } device_index_t device_index() const { return device_index_; } bool operator==(const DeviceId& rhs) const { return rank_ == rhs.rank_ && device_type_ == rhs.device_type_ && device_index_ == rhs.device_index_; } bool operator!=(const DeviceId& rhs) const { return !(*this == rhs); } size_t hash() const { size_t hash = std::hash{}(rank_); HashCombine(&hash, std::hash{}(device_type_)); HashCombine(&hash, std::hash{}(device_index_)); return hash; } private: rank_t rank_; device_type_t device_type_; device_index_t device_index_; }; int64_t EncodeDeviceIdToInt64(const DeviceId& device_id); } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::DeviceId& device_id) const { return device_id.hash(); } }; } // namespace std #endif // ONEFLOW_CORE_DEVICE_DEVICE_ID_H_ ================================================ FILE: oneflow/core/device/ep_based_event_record.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
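EncodeDeviceIdToInt64 above packs the three DeviceId fields into one integer using the shift constants from device_id.cpp (device_index at bit 0, device_type at bit 7, rank at bit 12). A small illustrative sketch, not repository code; DeviceIdEncodingExample and the concrete rank/index values are placeholders.

// Sketch: id = device_index | (device_type << 7) | (rank << 12).
void DeviceIdEncodingExample() {
  oneflow::DeviceId device_id(/*rank=*/2, oneflow::DeviceType::kCUDA, /*device_index=*/3);
  int64_t encoded = oneflow::EncodeDeviceIdToInt64(device_id);
  // low 7 bits: device index 3; next 5 bits: numeric value of DeviceType::kCUDA;
  // next 16 bits: rank 2.
  (void)encoded;
}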
*/ #ifndef ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_ #define ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_ #include "oneflow/core/device/event_record.h" #include "oneflow/core/ep/include/active_device_guard.h" namespace oneflow { class EpBasedEventRecord : public EventRecord { public: OF_DISALLOW_COPY_AND_MOVE(EpBasedEventRecord); EpBasedEventRecord(ep::Event* event, ep::Device* device) : event_(event), device_(device) {} ~EpBasedEventRecord() { ep::ActiveDeviceGuard guard(device_); device_->DestroyEvent(event_); }; static std::shared_ptr MakeEventRecord(ep::Stream* stream) { ep::Device* device = stream->device(); ep::ActiveDeviceGuard guard(device); ep::Event* event = device->CreateEvent(); stream->RecordEvent(event); return std::make_shared(event, device); } bool QueryDone() const override { ep::ActiveDeviceGuard guard(device_); bool done = CHECK_JUST(event_->QueryDone()); return done; } private: ep::Event* event_; ep::Device* device_; }; } // namespace oneflow #endif // ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_ ================================================ FILE: oneflow/core/device/event_record.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_DEVICE_EVENT_RECORD_H_ #define ONEFLOW_CORE_DEVICE_EVENT_RECORD_H_ #include #include #include "oneflow/core/common/util.h" namespace oneflow { class EventRecord { public: EventRecord(const EventRecord&) = delete; EventRecord(EventRecord&&) = delete; EventRecord& operator=(const EventRecord&) = delete; EventRecord& operator=(EventRecord&&) = delete; virtual ~EventRecord() = default; virtual bool QueryDone() const = 0; EventRecord() = default; }; class NaiveEventRecord final : public EventRecord { public: NaiveEventRecord(const NaiveEventRecord&) = delete; NaiveEventRecord(NaiveEventRecord&&) = delete; NaiveEventRecord& operator=(const NaiveEventRecord&) = delete; NaiveEventRecord& operator=(NaiveEventRecord&&) = delete; NaiveEventRecord() = default; ~NaiveEventRecord() override = default; bool QueryDone() const override { return true; } }; class SharedEventRecord final : public EventRecord { public: SharedEventRecord(const SharedEventRecord&) = delete; SharedEventRecord(SharedEventRecord&&) = delete; SharedEventRecord& operator=(const SharedEventRecord&) = delete; SharedEventRecord& operator=(SharedEventRecord&&) = delete; SharedEventRecord() : EventRecord(), inited_(false) {} ~SharedEventRecord() override = default; bool QueryDone() const override { return inited_ && event_record_->QueryDone(); } void Init(const std::shared_ptr& event_record) { // No lock needed. This function will be called only one time. // In most cases, errors will be successfully detected by CHECK // even though run in different threads. 
CHECK(!inited_); event_record_ = event_record; inited_ = true; } void TryInit(const std::shared_ptr& event_record) { if (!inited_) { Init(event_record); } } private: std::atomic inited_; std::shared_ptr event_record_; }; } // namespace oneflow #endif // ONEFLOW_CORE_DEVICE_EVENT_RECORD_H_ ================================================ FILE: oneflow/core/device/nccl_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/nccl_util.h" namespace oneflow { #ifdef WITH_CUDA std::string NcclUniqueIdToString(const ncclUniqueId& unique_id) { return std::string(unique_id.internal, NCCL_UNIQUE_ID_BYTES); } void NcclUniqueIdFromString(const std::string& str, ncclUniqueId* unique_id) { CHECK_EQ(str.size(), NCCL_UNIQUE_ID_BYTES); memcpy(unique_id->internal, str.data(), NCCL_UNIQUE_ID_BYTES); } #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/core/device/nccl_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
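SharedEventRecord above is an init-once wrapper: a consumer may hold it before the producer has recorded anything, and QueryDone() simply reports false until TryInit installs a real record. A minimal sketch of that pattern, not repository code; EventRecordExample is an illustrative name and `stream` is assumed to be an ep::Stream owned by the caller.

#include <memory>

void EventRecordExample(oneflow::ep::Stream* stream) {
  auto shared = std::make_shared<oneflow::SharedEventRecord>();
  // Producer side: record an event on the stream and publish it exactly once.
  shared->TryInit(oneflow::EpBasedEventRecord::MakeEventRecord(stream));
  // Consumer side: poll until the recorded event has completed.
  while (!shared->QueryDone()) { /* do other work or yield */ }
}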
*/ #ifndef ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_ #define ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_ #include "oneflow/core/register/blob.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/util.h" #include "oneflow/core/device/cuda_util.h" #ifdef WITH_CUDA #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #endif // WITH_CUDA namespace oneflow { #ifdef WITH_CUDA inline ncclDataType_t GetNcclDataType(const DataType& dt) { switch (dt) { #define NCCL_DATA_TYPE_CASE(dtype) \ case DataType::k##dtype: return ncclDataType_t::nccl##dtype NCCL_DATA_TYPE_CASE(Char); NCCL_DATA_TYPE_CASE(Float); NCCL_DATA_TYPE_CASE(Double); NCCL_DATA_TYPE_CASE(Int8); NCCL_DATA_TYPE_CASE(Int32); NCCL_DATA_TYPE_CASE(Int64); NCCL_DATA_TYPE_CASE(Float16); case DataType::kBool: return ncclDataType_t::ncclUint8; #if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= 21003 case DataType::kBFloat16: return ncclBfloat16; #endif case DataType::kUInt8: return ncclUint8; case DataType::kUInt32: return ncclUint32; case DataType::kUInt64: return ncclUint64; default: UNIMPLEMENTED(); } return ncclDataType_t::ncclFloat; } std::string NcclUniqueIdToString(const ncclUniqueId& unique_id); void NcclUniqueIdFromString(const std::string& str, ncclUniqueId* unique_id); #define HAS_NCCL_SEND_RECV NCCL_VERSION_CODE > 2700 #endif // WITH_CUDA } // namespace oneflow #endif // ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_ ================================================ FILE: oneflow/core/eager/call_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/eager/call_context.h" #include "oneflow/core/eager/tensor_storage.h" namespace oneflow { namespace eager { namespace { vm::WeakEagerBlobObjectList shared_to_weak(const vm::EagerBlobObjectList& shared_list) { vm::WeakEagerBlobObjectList ret; ret.reserve(shared_list.size()); for (const auto& shared : shared_list) { ret.emplace_back(shared); } return ret; } } // namespace DtrCallContext::DtrCallContext(const CallContext& call_ctx) : composed_attrs_(call_ctx.composed_attrs()), inputs_(call_ctx.inputs()), outputs_(shared_to_weak(call_ctx.outputs())), global_tensor_infer_result_(call_ctx.global_tensor_infer_result()), op_interp_ctx_(call_ctx.op_interp_ctx()), tmp_tensor_(call_ctx.tmp_tensor()) { for (const auto& x : call_ctx.outputs()) { ebo_infos_.push_back(EBOInfo{std::make_shared(x->mem_case()), x->tensor_meta(), x->mut_tensor_meta(), x->data_type(), x->memory_format()}); } } CallContext::CallContext(const DtrCallContext& dtr_call_ctx) : composed_attrs_(dtr_call_ctx.composed_attrs_), inputs_(dtr_call_ctx.inputs_), global_tensor_infer_result_(dtr_call_ctx.global_tensor_infer_result_), op_interp_ctx_(dtr_call_ctx.op_interp_ctx_), tmp_tensor_(dtr_call_ctx.tmp_tensor_) { for (int i = 0; i < dtr_call_ctx.outputs_.size(); ++i) { const auto& weak = dtr_call_ctx.outputs_[i]; if (weak.expired()) { LOG(INFO) << "index: " << i << " is expired"; outputs_.push_back(std::make_shared( dtr_call_ctx.ebo_infos_[i].mem_case, dtr_call_ctx.ebo_infos_[i].local_tensor_meta, dtr_call_ctx.ebo_infos_[i].dynamic_local_tensor_meta, dtr_call_ctx.ebo_infos_[i].data_type, dtr_call_ctx.ebo_infos_[i].memory_format, std::make_shared( true, dtr_call_ctx.ebo_infos_[i].local_tensor_meta->device()))); } else { outputs_.push_back(weak.lock()); } } } CallContext DtrCallContext::ToCallContext() const { return CallContext(*this); } } // namespace eager } // namespace oneflow ================================================ FILE: oneflow/core/eager/call_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
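DtrCallContext above keeps op outputs as weak references plus enough saved metadata (EBOInfo) to rebuild expired blobs, so that caching a call for rematerialization does not keep dead tensors alive. The sketch below shows only the re-lock-or-rebuild decision; ResolveOutput and rebuild_from_ebo_info are illustrative stand-ins, not repository functions.

#include <functional>
#include <memory>

// Sketch: reuse a still-live output, otherwise rebuild an empty blob from saved metadata.
std::shared_ptr<oneflow::vm::EagerBlobObject> ResolveOutput(
    const std::weak_ptr<oneflow::vm::EagerBlobObject>& weak,
    const std::function<std::shared_ptr<oneflow::vm::EagerBlobObject>()>& rebuild_from_ebo_info) {
  if (auto live = weak.lock()) { return live; }  // tensor still alive, reuse it
  return rebuild_from_ebo_info();                // expired: rebuild from EBOInfo-style metadata
}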
*/ #ifndef ONEFLOW_CORE_EAGER_CALL_CONTEXT_H_ #define ONEFLOW_CORE_EAGER_CALL_CONTEXT_H_ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/small_vector.h" namespace oneflow { namespace one { class StatefulLocalOpKernel; class GlobalTensorInferResult; } // namespace one namespace eager { class TmpTensor final : public user_op::Tensor { public: explicit TmpTensor(const std::shared_ptr& mem_case) : mem_case_(mem_case), tmp_buffer_size_(0), tmp_buffer_ptr_(nullptr) {} ~TmpTensor() = default; TmpTensor(const TmpTensor& other) : mem_case_(other.mem_case_), tmp_buffer_size_(other.tmp_buffer_size_), tmp_buffer_ptr_(other.tmp_buffer_ptr_) { CHECK_ISNULL(tmp_buffer_ptr_); } TmpTensor(TmpTensor&&) = delete; TmpTensor& operator=(const TmpTensor& other) = delete; TmpTensor& operator=(TmpTensor&&) = delete; ShapeView shape_view() const override { return ShapeView(&tmp_buffer_size_, 1); } MutShapeView mut_shape_view() override { return MutShapeView(&tmp_buffer_size_, 1); } const Stride& stride() const override { UNIMPLEMENTED() << "TmpTensor::stride() is not implemented."; } DataType data_type() const override { return DataType::kChar; } MemoryFormat memory_format() const override { return MemoryFormat::kContiguous; } const MemoryCase& mem_case() const override { return *mem_case_; } const void* raw_dptr() const override { return tmp_buffer_ptr_; } void* mut_raw_dptr() override { return tmp_buffer_ptr_; } int64_t tmp_buffer_size() const { return tmp_buffer_size_; } void set_tmp_buffer_size(int64_t val) { tmp_buffer_size_ = val; } char* mut_tmp_buffer_ptr() { return tmp_buffer_ptr_; } void set_tmp_buffer_ptr(char* ptr) { tmp_buffer_ptr_ = ptr; } private: std::shared_ptr mem_case_; int64_t tmp_buffer_size_; char* tmp_buffer_ptr_; }; class DtrCallContext; class CallContext { public: CallContext(ComposedAttrMap composed_attrs, vm::EagerBlobObjectList inputs, vm::EagerBlobObjectList outputs, const std::shared_ptr& global_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const std::shared_ptr& mem_case) : composed_attrs_(std::move(composed_attrs)), inputs_(std::move(inputs)), outputs_(std::move(outputs)), global_tensor_infer_result_(global_tensor_infer_result), op_interp_ctx_(op_interp_ctx), tmp_tensor_(mem_case) {} explicit CallContext(const DtrCallContext&); ~CallContext() = default; const ComposedAttrMap& composed_attrs() const { return composed_attrs_; } const vm::EagerBlobObjectList& inputs() const { return inputs_; } const vm::EagerBlobObjectList& outputs() const { return outputs_; } vm::EagerBlobObjectList& mut_inputs() { return inputs_; } vm::EagerBlobObjectList& mut_outputs() { return outputs_; } const std::shared_ptr& global_tensor_infer_result() const { return global_tensor_infer_result_; } const one::OpExprInterpContext& op_interp_ctx() const { return op_interp_ctx_; } TmpTensor* mut_tmp_tensor() { return &tmp_tensor_; } const TmpTensor& tmp_tensor() const { return tmp_tensor_; } private: const ComposedAttrMap composed_attrs_; vm::EagerBlobObjectList inputs_; vm::EagerBlobObjectList outputs_; const std::shared_ptr global_tensor_infer_result_; const one::OpExprInterpContext op_interp_ctx_; TmpTensor tmp_tensor_; }; class DtrCallContext { public: explicit DtrCallContext(const CallContext& call_ctx); CallContext ToCallContext() const; vm::EagerBlobObjectList& mut_inputs() { 
return inputs_; } vm::WeakEagerBlobObjectList& mut_outputs() { return outputs_; } friend class CallContext; private: struct EBOInfo { const std::shared_ptr mem_case; const Symbol local_tensor_meta; const std::shared_ptr dynamic_local_tensor_meta; const DataType data_type; const MemoryFormat memory_format; }; using EBOInfoList = small_vector; const ComposedAttrMap composed_attrs_; vm::EagerBlobObjectList inputs_; vm::WeakEagerBlobObjectList outputs_; EBOInfoList ebo_infos_; const std::shared_ptr global_tensor_infer_result_; const one::OpExprInterpContext op_interp_ctx_; TmpTensor tmp_tensor_; }; } // namespace eager } // namespace oneflow #endif // ONEFLOW_CORE_EAGER_CALL_CONTEXT_H_ ================================================ FILE: oneflow/core/eager/dev_vm_dep_object_consume_mode.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EAGER_DEV_VM_DEP_OBJECT_CONSUME_MODE_H_ #define ONEFLOW_CORE_EAGER_DEV_VM_DEP_OBJECT_CONSUME_MODE_H_ namespace oneflow { namespace one { enum class DevVmDepObjectConsumeMode { NONE, MUTABLE, }; inline DevVmDepObjectConsumeMode* CurrentDevVmDepObjectConsumeMode() { static thread_local DevVmDepObjectConsumeMode mode_ = DevVmDepObjectConsumeMode::MUTABLE; return &mode_; } class DevVmDepObjectConsumeModeGuard { public: DevVmDepObjectConsumeModeGuard(DevVmDepObjectConsumeMode mode) : prev_mode_(*CurrentDevVmDepObjectConsumeMode()) { *CurrentDevVmDepObjectConsumeMode() = mode; } ~DevVmDepObjectConsumeModeGuard() { *CurrentDevVmDepObjectConsumeMode() = prev_mode_; } // NOLINT private: DevVmDepObjectConsumeMode prev_mode_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_EAGER_DEV_VM_DEP_OBJECT_CONSUME_MODE_H_ ================================================ FILE: oneflow/core/eager/eager_blob_object.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
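DevVmDepObjectConsumeModeGuard above swaps a thread-local consume mode and restores the previous one on destruction. A minimal usage sketch, not repository code; BuildWithoutDepConsume is an illustrative name.

void BuildWithoutDepConsume() {
  using oneflow::one::DevVmDepObjectConsumeMode;
  // Inside this scope, ops built on this thread consume device dependence objects in NONE mode.
  oneflow::one::DevVmDepObjectConsumeModeGuard guard(DevVmDepObjectConsumeMode::NONE);
  // ... construct and dispatch instructions here ...
}  // guard destructor restores the previous mode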
*/ #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/vm/allocator.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/shut_down_util.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/common/tensor_meta.h" namespace oneflow { namespace vm { EagerBlobObject::EagerBlobObject( const std::shared_ptr& mem_case, const Symbol& static_local_tensor_meta, const std::shared_ptr& dynamic_local_tensor_meta, DataType data_type, MemoryFormat memory_format, const std::shared_ptr& tensor_storage, const intrusive::shared_ptr& dep_object) : is_dynamic_(false), mem_case_(mem_case), data_type_(data_type), memory_format_(memory_format), storage_offset_(0), tensor_storage_(tensor_storage), compute_local_dep_object_(dep_object), static_local_tensor_meta_(static_local_tensor_meta), dynamic_local_tensor_meta_(dynamic_local_tensor_meta) { CHECK(static_cast(tensor_storage)); } // user_op::TensorDesc overrides const Shape& EagerBlobObject::shape() const { if (dynamic_local_tensor_meta_) { return dynamic_local_tensor_meta_->shape(); } else { return static_local_tensor_meta_->shape(); } } const Stride& EagerBlobObject::stride() const { if (dynamic_local_tensor_meta_) { return dynamic_local_tensor_meta_->stride(); } else { return static_local_tensor_meta_->stride(); } } void EagerBlobObject::set_shape(const Shape& shape) { CHECK(dynamic_local_tensor_meta_); std::const_pointer_cast(dynamic_local_tensor_meta_)->set_shape(shape); } void EagerBlobObject::set_stride(const Stride& stride) { CHECK(dynamic_local_tensor_meta_); std::const_pointer_cast(dynamic_local_tensor_meta_)->set_stride(stride); } MutShapeView EagerBlobObject::mut_shape_view() { CHECK(dynamic_local_tensor_meta_); return *const_cast(dynamic_local_tensor_meta_->shape_ptr().get()); } std::shared_ptr EagerBlobObject::shape_ptr() const { if (dynamic_local_tensor_meta_) { return dynamic_local_tensor_meta_->shape_ptr(); } else { return static_local_tensor_meta_->shape_ptr(); } } std::shared_ptr EagerBlobObject::stride_ptr() const { if (dynamic_local_tensor_meta_) { return dynamic_local_tensor_meta_->stride_ptr(); } else { return static_local_tensor_meta_->stride_ptr(); } } int64_t EagerBlobObject::storage_offset() const { return storage_offset_; } void EagerBlobObject::set_storage_offset(const int64_t offset) { storage_offset_ = offset; } Maybe EagerBlobObject::TryAllocateBlobBodyMemory(vm::Allocator* allocator) { size_t required_body_bytes = AlignedByteSizeOfBlobBody(); if (required_body_bytes == 0) { CHECK_ISNULL_OR_RETURN(tensor_storage_->blob_dptr()); } else if (tensor_storage_->blob_dptr() != nullptr) { CHECK_GE_OR_RETURN(tensor_storage_->blob_bytes(), ByteSizeOfBlobBody()) << "This blob has been allocated memory, but less than needed space."; } else { char* dptr = nullptr; JUST(allocator->Allocate(&dptr, required_body_bytes)); // reset tensor_storage_; const auto& Free = [allocator, required_body_bytes](char* dptr) { if (IsShuttingDown()) { return; } allocator->Deallocate(dptr, required_body_bytes); }; tensor_storage_->set_blob_dptr(std::unique_ptr>(dptr, Free), required_body_bytes); InitNonPODTypeEagerBlobObjectIfNeed(tensor_storage_->non_pod_allocator(), this); return true; } return false; } const void* EagerBlobObject::raw_dptr() const { char* ptr = tensor_storage_->blob_dptr(); if (tensor_storage_->blob_bytes() > 0) { CHECK_NOTNULL(ptr); } return ptr + storage_offset_ * GetSizeOfDataType(data_type_); } Maybe 
EagerBlobObject::DeallocateBlobDataPtr() { tensor_storage_->Release(); return Maybe::Ok(); } void EagerBlobObject::RegisterStorageDeleteHook(const std::function& hook) { tensor_storage_->RegisterStorageDeleteHook(hook); } const Optional>& EagerBlobObject::producer_stream() const { return tensor_storage_->producer_stream(); } Maybe EagerBlobObject::init_producer_stream(Symbol<::oneflow::Stream> producer_stream) { return tensor_storage_->init_producer_stream(producer_stream); } const Optional>& EagerBlobObject::last_used_stream() const { return tensor_storage_->last_used_stream(); } void EagerBlobObject::set_last_used_stream(Symbol<::oneflow::Stream> last_used_stream) { tensor_storage_->set_last_used_stream(last_used_stream); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/eager/eager_blob_object.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_ #define ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/tensor_methods.h" #include "oneflow/core/framework/user_op_tensor.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/register/blob.h" namespace oneflow { namespace one { class LocalTensorMeta; class MutLocalTensorMeta; } // namespace one namespace vm { class Allocator; class EagerBlobObject final : public user_op::Tensor, public user_op::TensorDesc, public std::enable_shared_from_this { public: EagerBlobObject(const EagerBlobObject&) = delete; EagerBlobObject(EagerBlobObject&&) = delete; EagerBlobObject(const std::shared_ptr& mem_case, const Symbol& static_local_tensor_meta, const std::shared_ptr& dynamic_local_tensor_meta, DataType data_type, MemoryFormat memory_format, const std::shared_ptr& tensor_storage) : EagerBlobObject(mem_case, static_local_tensor_meta, dynamic_local_tensor_meta, data_type, memory_format, tensor_storage, intrusive::shared_ptr()) {} EagerBlobObject(const std::shared_ptr& mem_case, const Symbol& static_local_tensor_meta, const std::shared_ptr& dynamic_local_tensor_meta, DataType data_type, MemoryFormat memory_format, const std::shared_ptr& tensor_storage, const intrusive::shared_ptr& dep_object); ~EagerBlobObject() { tensor_storage_.reset(); } const std::shared_ptr& mut_tensor_meta() { return dynamic_local_tensor_meta_; } // Getters const Symbol& tensor_meta() const { return static_local_tensor_meta_; } // user_op::TensorDesc overrides const Shape& shape() const override; const Stride& stride() const override; DataType data_type() const override { return data_type_; } bool is_dynamic() const override { return is_dynamic_; } 
MemoryFormat memory_format() const override { return memory_format_; } void set_shape(const Shape& shape) override; void set_stride(const Stride& stride) override; void set_data_type(DataType data_type) override { data_type_ = data_type; } void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; } void set_memory_format(MemoryFormat memory_format) override { memory_format_ = memory_format; } // user_op::Tensor overrides ShapeView shape_view() const override { return shape(); } MutShapeView mut_shape_view() override; const MemoryCase& mem_case() const override { return *mem_case_; } const void* raw_dptr() const override; void* mut_raw_dptr() override { return const_cast(raw_dptr()); } int64_t storage_offset() const; void set_storage_offset(const int64_t offset); // Returns true if allocate successfully. Maybe TryAllocateBlobBodyMemory(vm::Allocator* allocator); Maybe DeallocateBlobDataPtr(); void RegisterStorageDeleteHook(const std::function& hook); Maybe compute_local_dep_object() const { CHECK_NOTNULL_OR_RETURN(compute_local_dep_object_.get()); return compute_local_dep_object_.get(); } std::shared_ptr& tensor_storage() { return tensor_storage_; } const Optional>& producer_stream() const; Maybe init_producer_stream(Symbol<::oneflow::Stream> producer_stream); const Optional>& last_used_stream() const; void set_last_used_stream(Symbol<::oneflow::Stream> last_used_stream); std::shared_ptr shape_ptr() const; std::shared_ptr stride_ptr() const; size_t ByteSizeOfBlobBody() const { const size_t elem_cnt = shape().elem_cnt(); if (elem_cnt == 0) { return 0; } size_t max_offset = 0; for (size_t i = 0; i < shape().NumAxes(); ++i) { max_offset += (shape().at(i) - 1) * stride().at(i); } size_t capacity = max_offset + 1; // TODO(liujuncheng): remove this capacity = std::max(capacity, elem_cnt); return capacity * GetSizeOfDataType(data_type_); } size_t AlignedByteSizeOfBlobBody() const { return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize); } size_t ByteSizeOfBlobHeader() const { return shape().NumAxes() * sizeof(int64_t); } size_t AlignedByteSizeOfBlobHeader() const { return RoundUp(ByteSizeOfBlobHeader(), kBlobHeaderAlignSize); } const char* header_ptr() const { return reinterpret_cast(shape().dim_vec().data()); } char* mut_header_ptr() { return reinterpret_cast(const_cast(shape().dim_vec().data())); } void set_input_of_view_op(std::shared_ptr input) { input_of_view_op_ = std::move(input); } private: bool is_dynamic_; std::shared_ptr mem_case_; DataType data_type_; MemoryFormat memory_format_; int64_t storage_offset_; std::shared_ptr tensor_storage_; intrusive::shared_ptr compute_local_dep_object_; Symbol static_local_tensor_meta_; std::shared_ptr dynamic_local_tensor_meta_; // for rematerialization (i.e. Coop/DTR) std::shared_ptr input_of_view_op_; }; using EagerBlobObjectList = small_vector, kOpArgsReservedSize>; using WeakEagerBlobObjectList = small_vector>; using EagerBlobObjectListPtr = std::shared_ptr; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_ ================================================ FILE: oneflow/core/eager/local_dep_object.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/static_global.h" namespace oneflow { intrusive::shared_ptr NewLocalDepObject() { return intrusive::make_shared(); } } // namespace oneflow ================================================ FILE: oneflow/core/eager/local_dep_object.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_LOCAL_DEP_OBJECT_H_ #define ONEFLOW_CORE_FRAMEWORK_LOCAL_DEP_OBJECT_H_ #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/vm/vm_object.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/framework/device.h" namespace oneflow { // LocalDepObject helps VirtualMachineEngine building instruction edges using LocalDepObject = vm::Dependence; using DependenceVector = small_vector; intrusive::shared_ptr NewLocalDepObject(); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_LOCAL_DEP_OBJECT_H_ ================================================ FILE: oneflow/core/eager/tensor_storage.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
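EagerBlobObject::ByteSizeOfBlobBody() in eager_blob_object.h above sizes storage by the highest addressable element rather than by elem_cnt, which matters for non-contiguous (e.g. transposed) views. The standalone sketch below restates that computation; BlobBodyBytes is an editorial re-statement, not repository code.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

size_t BlobBodyBytes(const std::vector<int64_t>& dims, const std::vector<int64_t>& strides,
                     size_t bytes_per_elem) {
  int64_t elem_cnt = 1;
  for (int64_t d : dims) { elem_cnt *= d; }
  if (elem_cnt == 0) { return 0; }
  // Highest reachable offset under the given strides.
  int64_t max_offset = 0;
  for (size_t i = 0; i < dims.size(); ++i) { max_offset += (dims[i] - 1) * strides[i]; }
  int64_t capacity = std::max<int64_t>(max_offset + 1, elem_cnt);
  return static_cast<size_t>(capacity) * bytes_per_elem;
}
// e.g. a float tensor of shape {2, 3} with strides {1, 2} gives
// max_offset = (2-1)*1 + (3-1)*2 = 5, capacity = 6, i.e. 24 bytes.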
*/ #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/common/env_var/remat.h" #include "oneflow/core/vm/op_call_instruction_policy.h" #include "oneflow/core/vm/remat/disjoint_set.h" #include "oneflow/core/vm/remat/env.h" #include "oneflow/core/vm/remat/util.h" #include "oneflow/core/vm/virtual_machine.h" namespace oneflow { namespace vm { namespace { int64_t unique_id() { static size_t id = 0; return id++; } } // namespace TensorStorage::TensorStorage(bool is_allocated_in_vm, Symbol device) : blob_bytes_(0), device_(device), non_pod_allocator_(std::make_unique()), producer_stream_(NullOpt), last_used_stream_(NullOpt), is_allocated_in_vm_(is_allocated_in_vm) {} Symbol TensorStorage::device() const { return device_; } TensorStorage::~TensorStorage() { for (const auto& hook : storage_delete_hooks_) { hook(); } } void TensorStorage::_Release() { non_pod_allocator_.reset(); blob_dptr_.reset(); } void TensorStorage::Release() { return _Release(); } Maybe TensorStorage::init_producer_stream(Symbol<::oneflow::Stream> producer_stream) { CHECK_OR_RETURN(!producer_stream_.has_value()); producer_stream_ = producer_stream; return Maybe::Ok(); } RematableTensorStorage::RematableTensorStorage(Symbol device) : TensorStorage(true, device), node(std::make_shared(0)), id_(unique_id()), num_pinned_(0), last_access_time_(0), compute_time_(0) { VLOG(1) << "create rematable storage " << id_; } RematableTensorStorage::~RematableTensorStorage() { // We must call _Release before destruction or the release will be // called in base class's destructor and causes segfault. // Time order: // 1. ~RematableTensorStorage destructs its members // 2. ~TensorStorage, Allocator::Deallocate, which uses RematableTensorStorage members _Release(); if (compute_op_) { Singleton::Get()->remove_compute_op(compute_op_.get()); } VLOG(1) << "delete storage " << id_; } void RematableTensorStorage::LogEviction(bool eager_eviction) const { Singleton::Get()->add_eviction_num(eager_eviction); VLOG(1) << "evict storage " << id_ << ", compute op type: " << compute_op_type_name() << ", eager_eviction: " << eager_eviction; } void RematableTensorStorage::Remat() { if (is_in_memory()) { return; } auto stream = CHECK_JUST(GetDefaultStreamByDevice(device_)); auto* vm_stream = CHECK_JUST(Singleton::Get()->GetVmStream(stream)); auto op = compute_op(); CHECK_JUST(Recompute(&op, vm_stream)); } void RematableTensorStorage::Evict(bool eager_eviction) { CHECK(!is_eviction_disabled()); LogEviction(eager_eviction); return _Release(); } void RematableTensorStorage::Release() { CHECK(device_->rematable()); if (is_eviction_disabled()) { return; } return Evict(true); } std::vector random_ops{"uniform", "uniform_int", "normal", "randperm"}; bool RematableTensorStorage::is_evictable() const { return compute_op_ != nullptr && std::find(random_ops.begin(), random_ops.end(), compute_op_type_name()) == random_ops.end() && !eviction_disabled_; } OpCallInstructionPolicy RematableTensorStorage::compute_op() const { CHECK_NOTNULL(compute_op_); return OpCallInstructionPolicy(*compute_op_); } std::shared_ptr RematableTensorStorage::dtr_compute_op() const { return compute_op_; } void RematableTensorStorage::Pin() { ++num_pinned_; VLOG(3) << "pin storage " << id_ << ", num_pinned: " << num_pinned_; } void RematableTensorStorage::Unpin() { CHECK_GT(num_pinned_, 0); --num_pinned_; VLOG(3) << "unpin storage " << id_ << ", num_pinned: " << num_pinned_; } void RematableTensorStorage::clear_compute_op() { if (compute_op_ == nullptr) { return; } VLOG(1) << 
"clear_compute_op: " << id_; Singleton::Get()->remove_compute_op(compute_op_.get()); compute_op_ = nullptr; compute_time_ = -1; } void RematableTensorStorage::set_compute_op( const std::shared_ptr& compute_op, double compute_time) { CHECK_ISNULL(compute_op_); compute_op_ = compute_op; VLOG(1) << "set_compute_op: " << id_ << ", compute op: " << compute_op.get(); Singleton::Get()->ops.push_back(CHECK_NOTNULL(compute_op_.get())); compute_time_ = compute_time; } std::string RematableTensorStorage::compute_op_type_name() const { if (is_eviction_disabled()) { return "eviction_disabled"; } if (compute_op_) { return compute_op_->opkernel().op_type_name(); } return "None"; } void RematableTensorStorage::Access() { last_access_time_ = Singleton::Get()->time_now(); } Maybe RematableTensorStorage::cost(size_t override_size) const { CHECK_OR_RETURN(!is_eviction_disabled()); const double time_since_last_access = Singleton::Get()->time_now() - last_access_time_; size_t size = 1; if (EnvBool() || EnvBool()) { size = override_size == 0 ? blob_bytes_ : override_size; } return (EnvBool() ? approx_neighbor_cost() : compute_time_) / time_since_last_access / static_cast(size); } double RematableTensorStorage::approx_neighbor_cost() const { const auto cal_cost = [](const auto& eager_blob_objects) { double all_cost = 0; for (int i = 0; i < eager_blob_objects.size(); ++i) { const auto& tmp = eager_blob_objects[i]; if (auto storage = std::dynamic_pointer_cast(tmp->tensor_storage()); !storage->is_in_memory()) { double tmp_cost = remat::DisjointSet::find_father(storage->node)->compute_time(); if (tmp_cost < storage->compute_time()) { tmp_cost = storage->compute_time(); } all_cost += tmp_cost; } } return all_cost; }; const auto compute_op = this->compute_op(); return cal_cost(compute_op.inputs()) + cal_cost(compute_op.outputs()) + compute_time_; } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/eager/tensor_storage.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EAGER_TENSOR_STORAGE_H_ #define ONEFLOW_CORE_EAGER_TENSOR_STORAGE_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/framework/stream.h" namespace oneflow { namespace remat { class DisjNode; } namespace vm { class OpCallInstructionPolicy; class DtrOpCallInstructionPolicy; class TensorStorage { public: explicit TensorStorage(bool is_allocated_in_vm, Symbol device); OF_DISALLOW_COPY_AND_MOVE(TensorStorage); virtual ~TensorStorage(); bool is_allocated_in_vm() const { return is_allocated_in_vm_; } size_t blob_bytes() const { return blob_bytes_; } char* blob_dptr() { return blob_dptr_.get(); } MemoryAllocator* non_pod_allocator() { return non_pod_allocator_.get(); } void set_blob_dptr(std::unique_ptr>&& blob_dptr, size_t bytes) { blob_dptr_ = std::move(blob_dptr); blob_bytes_ = bytes; is_initialized_ = true; } const Optional>& producer_stream() const { return producer_stream_; } Maybe init_producer_stream(Symbol<::oneflow::Stream> producer_stream); const Optional>& last_used_stream() const { return last_used_stream_; } void set_last_used_stream(Symbol<::oneflow::Stream> last_used_stream) { last_used_stream_ = last_used_stream; } void _Release(); virtual void Release(); void RegisterStorageDeleteHook(const std::function& hook) { storage_delete_hooks_.emplace_back(hook); } Symbol device() const; protected: std::unique_ptr> blob_dptr_; size_t blob_bytes_; bool is_initialized_ = false; Symbol device_; private: std::unique_ptr non_pod_allocator_; Optional> producer_stream_; Optional> last_used_stream_; std::vector> storage_delete_hooks_; bool is_allocated_in_vm_; }; class RematableTensorStorage final : public TensorStorage { public: explicit RematableTensorStorage(Symbol device); OF_DISALLOW_COPY_AND_MOVE(RematableTensorStorage); ~RematableTensorStorage() override; void set_compute_op(const std::shared_ptr& compute_op, double compute_time); void clear_compute_op(); OpCallInstructionPolicy compute_op() const; std::shared_ptr dtr_compute_op() const; void Release() override; void Remat(); void Evict(bool eager_eviction); void Pin(); void Unpin(); void Access(); bool is_in_memory() const { return blob_bytes_ == 0 || blob_dptr_ != nullptr; } bool is_pinned() const { return num_pinned() > 0; } int32_t num_pinned() const { return num_pinned_; } bool is_evictable() const; void set_eviction_disabled(bool disabled) { eviction_disabled_ = disabled; } bool is_eviction_disabled() const { return eviction_disabled_; } int64_t id() const { return id_; } Maybe cost(size_t override_size) const; double approx_neighbor_cost() const; std::string compute_op_type_name() const; bool is_initialized() const { return is_initialized_; } void set_initialized() { is_initialized_ = true; } bool is_needed_by_backward() const { return is_needed_by_backward_; } void set_needed_by_backward() { is_needed_by_backward_ = true; } double compute_time() const { return compute_time_; } std::shared_ptr node; private: int64_t id_{}; size_t num_pinned_{}; bool eviction_disabled_ = false; double last_access_time_{}; double compute_time_{}; std::shared_ptr compute_op_; bool is_needed_by_backward_ = false; void LogEviction(bool eager_eviction) const; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_EAGER_TENSOR_STORAGE_H_ ================================================ FILE: oneflow/core/embedding/cache.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
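RematableTensorStorage::cost() above scores tensors for eviction as recompute time divided by idle time divided by size, so cheap-to-recompute, long-idle, large tensors are evicted first; under certain environment toggles (their exact names are elided in this dump) the numerator switches to approx_neighbor_cost(). A minimal sketch of that heuristic, not repository code; RematCostSketch and the example numbers are illustrative.

double RematCostSketch(double recompute_time_seconds, double seconds_since_last_access,
                       size_t bytes) {
  // Lower score == better eviction candidate.
  return recompute_time_seconds / seconds_since_last_access / static_cast<double>(bytes);
}
// e.g. a tensor taking 2 ms to recompute, idle for 0.5 s and occupying 4 MiB scores
// 0.002 / 0.5 / 4194304 ~= 9.5e-10; doubling its size halves the score.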
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/embedding/cache.h" #include "oneflow/core/embedding/full_cache.h" #include "oneflow/core/embedding/lru_cache.h" namespace oneflow { namespace embedding { std::unique_ptr NewCache(const CacheOptions& options) { #ifdef WITH_CUDA CHECK_GT(options.key_size, 0); CHECK_GT(options.value_size, 0); CHECK_GT(options.capacity, 0); if (options.policy == CacheOptions::Policy::kLRU) { return NewLruCache(options); } else if (options.policy == CacheOptions::Policy::kFull) { return NewFullCache(options); } else { UNIMPLEMENTED(); return nullptr; } #else UNIMPLEMENTED(); return nullptr; #endif // WITH_CUDA } } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/cache.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EMBEDDING_CACHE_H_ #define ONEFLOW_CORE_EMBEDDING_CACHE_H_ #include "oneflow/core/embedding/kv_iterator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/common/data_type.h" namespace oneflow { namespace embedding { struct CacheOptions { enum class Policy { kLRU, kFull, }; enum class MemoryKind { kDevice, kHost, }; Policy policy = Policy::kLRU; MemoryKind value_memory_kind = MemoryKind::kDevice; uint64_t capacity{}; uint32_t key_size{}; uint32_t value_size{}; DataType value_type{}; float load_factor = 0.75; }; class Cache { public: OF_DISALLOW_COPY_AND_MOVE(Cache); Cache() = default; virtual ~Cache() = default; virtual uint32_t KeySize() const = 0; virtual uint32_t ValueSize() const = 0; virtual DataType ValueType() const = 0; virtual uint32_t MaxQueryLength() const = 0; virtual void ReserveQueryLength(uint32_t query_length) = 0; virtual uint64_t Capacity() const = 0; virtual uint64_t DumpCapacity() const { return Capacity(); } virtual CacheOptions::Policy Policy() const = 0; virtual void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) = 0; virtual void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) = 0; virtual void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint8_t* mask) { UNIMPLEMENTED(); } virtual void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) = 0; virtual void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update, const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) { UNIMPLEMENTED(); } virtual void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index, uint32_t* n_dumped, void* keys, void* values) = 0; virtual void ClearDirtyFlags() = 0; virtual void Clear() = 0; }; std::unique_ptr NewCache(const CacheOptions& options); } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_CACHE_H_ ================================================ FILE: oneflow/core/embedding/cache_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
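// ---------------------------------------------------------------------------
// Usage sketch for the Cache interface declared above, mirroring the option
// values used by cache_test.cpp further below. NewCache() dispatches on
// CacheOptions::policy (kLRU or kFull) and is only implemented in CUDA builds;
// this is a hedged sketch, not officially documented usage.
#include <memory>
#include "oneflow/core/embedding/cache.h"

void BuildLruCacheExample() {
  using namespace oneflow::embedding;
  CacheOptions options{};
  options.policy = CacheOptions::Policy::kLRU;
  options.key_size = 8;                      // bytes per key, e.g. int64_t keys
  options.value_size = 128 * sizeof(float);  // bytes per cached line
  options.capacity = 65536;                  // number of cached lines
  options.value_memory_kind = CacheOptions::MemoryKind::kDevice;
  std::unique_ptr<Cache> cache = NewCache(options);
  // The test reserves the maximum batch size up front before issuing queries.
  cache->ReserveQueryLength(65536);
  // Test()/Get()/Put()/Dump() then operate on device pointers via an
  // ep::Stream, as exercised by cache_test.cpp.
}
// ---------------------------------------------------------------------------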
*/ #include "oneflow/core/embedding/cache.h" #include "oneflow/core/device/cuda_util.h" #include #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace embedding { namespace { #ifdef WITH_CUDA bool HasCudaDevice() { int device_count = 0; if (cudaGetDeviceCount(&device_count) != cudaSuccess) { return false; } if (device_count <= 0) { return false; } return true; } void TestCache(Cache* cache, uint32_t line_size) { std::unique_ptr device_manager_registry( new ep::DeviceManagerRegistry()); auto device = device_manager_registry->GetDevice(DeviceType::kCUDA, 0); ep::Stream* stream = device->CreateStream(); std::unordered_set in_cache; const size_t n_iter = 32; const uint32_t n_keys = 1024; int64_t* d_keys; int64_t* keys; uint32_t* d_n_missing; uint32_t* n_missing; int64_t* d_missing_keys; int64_t* missing_keys; uint32_t* d_missing_indices; uint32_t* missing_indices; float* d_values; float* values; float* d_evicted_values; float* evicted_values; uint32_t* d_n_evicted; uint32_t* n_evicted; int64_t* d_evicted_keys; int64_t* evicted_keys; uint8_t* mask; const size_t keys_size = n_keys * sizeof(int64_t); OF_CUDA_CHECK(cudaMalloc(&d_keys, keys_size)); OF_CUDA_CHECK(cudaMallocHost(&keys, keys_size)); OF_CUDA_CHECK(cudaMalloc(&d_n_missing, sizeof(uint32_t))); OF_CUDA_CHECK(cudaMallocHost(&n_missing, sizeof(uint32_t))); OF_CUDA_CHECK(cudaMalloc(&d_missing_keys, keys_size)); OF_CUDA_CHECK(cudaMallocHost(&missing_keys, keys_size)); const size_t indices_size = n_keys * sizeof(uint32_t); OF_CUDA_CHECK(cudaMalloc(&d_missing_indices, indices_size)); OF_CUDA_CHECK(cudaMallocHost(&missing_indices, indices_size)); const size_t values_size = n_keys * line_size * sizeof(float); OF_CUDA_CHECK(cudaMalloc(&d_values, values_size)); OF_CUDA_CHECK(cudaMallocHost(&values, values_size)); OF_CUDA_CHECK(cudaMalloc(&d_evicted_values, values_size)); OF_CUDA_CHECK(cudaMallocHost(&evicted_values, values_size)); OF_CUDA_CHECK(cudaMalloc(&d_n_evicted, sizeof(uint32_t))); OF_CUDA_CHECK(cudaMallocHost(&n_evicted, sizeof(uint32_t))); OF_CUDA_CHECK(cudaMalloc(&d_evicted_keys, keys_size)); OF_CUDA_CHECK(cudaMallocHost(&evicted_keys, keys_size)); OF_CUDA_CHECK(cudaMalloc(&mask, n_keys)); std::vector random_keys(n_keys * 32); std::iota(random_keys.begin(), random_keys.end(), 1); std::random_device rd; std::mt19937 g(rd()); for (size_t iter = 0; iter < n_iter; ++iter) { std::shuffle(random_keys.begin(), random_keys.end(), g); std::copy(random_keys.begin(), random_keys.begin() + n_keys, keys); uint32_t expect_n_missing = 0; std::unordered_set expect_missing_keys_set; std::unordered_set expect_missing_indices_set; std::unordered_set keys_set; for (size_t i = 0; i < n_keys; ++i) { keys_set.emplace(keys[i]); if (in_cache.count(keys[i]) == 0) { expect_missing_keys_set.emplace(keys[i]); expect_missing_indices_set.emplace(i); expect_n_missing += 1; } } // test OF_CUDA_CHECK(cudaMemcpy(d_keys, keys, keys_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); cache->Test(stream, n_keys, d_keys, d_n_missing, d_missing_keys, d_missing_indices); OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaMemcpy(n_missing, d_n_missing, sizeof(uint32_t), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(missing_keys, d_missing_keys, keys_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(missing_indices, d_missing_indices, indices_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); ASSERT_EQ(*n_missing, expect_n_missing); std::unordered_set test_missing_keys_set; std::unordered_set 
test_missing_indices_set; for (size_t i = 0; i < *n_missing; ++i) { test_missing_keys_set.emplace(missing_keys[i]); test_missing_indices_set.emplace(missing_indices[i]); ASSERT_EQ(keys[missing_indices[i]], missing_keys[i]); } ASSERT_EQ(test_missing_keys_set, expect_missing_keys_set); ASSERT_EQ(test_missing_indices_set, expect_missing_indices_set); // get OF_CUDA_CHECK(cudaDeviceSynchronize()); if (cache->Policy() == CacheOptions::Policy::kFull) { cache->Get(stream, n_keys, d_keys, d_values, mask); } cache->Get(stream, n_keys, d_keys, d_values, d_n_missing, d_missing_keys, d_missing_indices); OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaMemcpy(n_missing, d_n_missing, sizeof(uint32_t), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(missing_keys, d_missing_keys, keys_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(missing_indices, d_missing_indices, indices_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(values, d_values, values_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); ASSERT_EQ(*n_missing, expect_n_missing); std::unordered_set get_missing_keys_set; std::unordered_set get_missing_indices_set; for (size_t i = 0; i < *n_missing; ++i) { get_missing_keys_set.emplace(missing_keys[i]); get_missing_indices_set.emplace(missing_indices[i]); ASSERT_EQ(keys[missing_indices[i]], missing_keys[i]); } ASSERT_EQ(get_missing_keys_set, expect_missing_keys_set); ASSERT_EQ(get_missing_indices_set, expect_missing_indices_set); for (size_t i = 0; i < n_keys; ++i) { if (get_missing_keys_set.count(keys[i]) == 0) { for (size_t j = 0; j < line_size; ++j) { ASSERT_EQ(values[i * line_size + j], static_cast(keys[i] * line_size + j)) << "iter " << iter << " i " << i << " j " << j; } } } // put for (size_t i = 0; i < n_keys; ++i) { for (size_t j = 0; j < line_size; ++j) { values[i * line_size + j] = static_cast(keys[i] * line_size + j); } } OF_CUDA_CHECK(cudaMemcpy(d_values, values, values_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); cache->Put(stream, n_keys, d_keys, d_values, d_n_evicted, d_evicted_keys, d_evicted_values); OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaMemcpy(n_evicted, d_n_evicted, sizeof(uint32_t), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(evicted_keys, d_evicted_keys, keys_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(evicted_values, d_evicted_values, values_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); for (size_t i = 0; i < *n_evicted; ++i) { ASSERT_TRUE(in_cache.count(evicted_keys[i]) > 0 || keys_set.count(evicted_keys[i]) > 0); for (size_t j = 0; j < line_size; ++j) { ASSERT_EQ(evicted_values[i * line_size + j], static_cast(evicted_keys[i] * line_size + j)); } } for (size_t i = 0; i < n_keys; ++i) { in_cache.emplace(keys[i]); } for (size_t i = 0; i < *n_evicted; ++i) { in_cache.erase(evicted_keys[i]); } } const uint64_t dump_capacity = cache->DumpCapacity(); for (size_t start_key_index = 0; start_key_index < dump_capacity; start_key_index += n_keys) { cache->Dump(stream, start_key_index, std::min(start_key_index + n_keys, dump_capacity), d_n_evicted, d_evicted_keys, d_evicted_values); OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaMemcpy(n_evicted, d_n_evicted, sizeof(uint32_t), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(evicted_keys, d_evicted_keys, keys_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(evicted_values, d_evicted_values, values_size, cudaMemcpyDefault)); for (size_t i = 0; i < *n_evicted; ++i) { ASSERT_TRUE(in_cache.count(evicted_keys[i]) > 0); 
in_cache.erase(evicted_keys[i]); for (size_t j = 0; j < line_size; ++j) { ASSERT_EQ(evicted_values[i * line_size + j], static_cast(evicted_keys[i] * line_size + j)); } } } CHECK_EQ(in_cache.size(), 0); OF_CUDA_CHECK(cudaFree(d_keys)); OF_CUDA_CHECK(cudaFreeHost(keys)); OF_CUDA_CHECK(cudaFree(d_n_missing)); OF_CUDA_CHECK(cudaFreeHost(n_missing)); OF_CUDA_CHECK(cudaFree(d_missing_keys)); OF_CUDA_CHECK(cudaFreeHost(missing_keys)); OF_CUDA_CHECK(cudaFree(d_missing_indices)); OF_CUDA_CHECK(cudaFreeHost(missing_indices)); OF_CUDA_CHECK(cudaFree(d_values)); OF_CUDA_CHECK(cudaFreeHost(values)); OF_CUDA_CHECK(cudaFree(d_evicted_values)); OF_CUDA_CHECK(cudaFreeHost(evicted_values)); OF_CUDA_CHECK(cudaFree(d_n_evicted)); OF_CUDA_CHECK(cudaFreeHost(n_evicted)); OF_CUDA_CHECK(cudaFree(d_evicted_keys)); OF_CUDA_CHECK(cudaFreeHost(evicted_keys)); OF_CUDA_CHECK(cudaFree(mask)); device->DestroyStream(stream); } TEST(Cache, FullCache) { if (!HasCudaDevice()) { return; } CacheOptions options{}; options.policy = CacheOptions::Policy::kFull; const uint32_t line_size = 128; options.value_size = 512; options.capacity = 65536; options.key_size = 8; options.value_memory_kind = CacheOptions::MemoryKind::kDevice; std::unique_ptr cache(NewCache(options)); cache->ReserveQueryLength(65536); TestCache(cache.get(), line_size); } TEST(Cache, LruCache) { if (!HasCudaDevice()) { return; } CacheOptions options{}; options.policy = CacheOptions::Policy::kLRU; const uint32_t line_size = 128; options.value_size = 512; options.capacity = 65536; options.key_size = 8; options.value_memory_kind = CacheOptions::MemoryKind::kDevice; std::unique_ptr cache(NewCache(options)); cache->ReserveQueryLength(65536); TestCache(cache.get(), line_size); } #endif // WITH_CUDA } // namespace } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/cached_key_value_store.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/embedding/cached_key_value_store.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace embedding { namespace { template __global__ void PostStoreGetKernel(uint32_t num_cache_missing, uint32_t num_store_missing, uint32_t num_elems_per_value, const uint32_t* cache_missing_indices, const uint32_t* store_missing_indices, const Elem* store_values, Elem* values, uint32_t* missing_indices) { const uint32_t num_cache_missing_elem = num_cache_missing * num_elems_per_value; CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_cache_missing_elem) { const uint32_t value_index = i / num_elems_per_value; const uint32_t elem_index = i - value_index * num_elems_per_value; values[cache_missing_indices[value_index] * num_elems_per_value + elem_index] = store_values[i]; } CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_store_missing) { missing_indices[i] = cache_missing_indices[store_missing_indices[i]]; } } template class CacheKeyValueStoreImpl : public KeyValueStore { public: OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl); CacheKeyValueStoreImpl(std::unique_ptr&& store, std::unique_ptr&& cache) : store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); CHECK_EQ(store_->KeySize(), cache_->KeySize()); CHECK_EQ(store_->ValueSize(), cache_->ValueSize()); OF_CUDA_CHECK(cudaMalloc(&num_buffer_, sizeof(uint32_t))); OF_CUDA_CHECK(cudaMallocHost(&host_num_buffer_, sizeof(uint32_t))); num_elems_per_value_ = store_->ValueSize() / sizeof(Elem); } ~CacheKeyValueStoreImpl() { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFree(num_buffer_)); OF_CUDA_CHECK(cudaFreeHost(host_num_buffer_)); if (max_query_length_ != 0) { OF_CUDA_CHECK(cudaFree(keys_buffer_)); OF_CUDA_CHECK(cudaFree(values_buffer_)); OF_CUDA_CHECK(cudaFree(indices_buffer0_)); OF_CUDA_CHECK(cudaFree(indices_buffer1_)); } cache_.reset(); store_.reset(); } uint32_t KeySize() const override { return store_->KeySize(); } uint32_t ValueSize() const override { return store_->ValueSize(); } uint32_t MaxQueryLength() const override { return max_query_length_; } void ReserveQueryLength(uint32_t query_length) override { CudaCurrentDeviceGuard guard(device_index_); if (query_length <= max_query_length_) { return; } if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); } if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); } if (max_query_length_ != 0) { OF_CUDA_CHECK(cudaFree(keys_buffer_)); OF_CUDA_CHECK(cudaFree(values_buffer_)); OF_CUDA_CHECK(cudaFree(indices_buffer0_)); OF_CUDA_CHECK(cudaFree(indices_buffer1_)); } OF_CUDA_CHECK(cudaMalloc(&keys_buffer_, query_length * store_->KeySize())); OF_CUDA_CHECK(cudaMalloc(&values_buffer_, query_length * store_->ValueSize())); OF_CUDA_CHECK(cudaMalloc(&indices_buffer0_, query_length * sizeof(uint32_t))); OF_CUDA_CHECK(cudaMalloc(&indices_buffer1_, query_length * sizeof(uint32_t))); max_query_length_ = query_length; } void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) override; void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values, uint8_t* mask) override; void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override; void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update, const 
float* lr, float scale) override; bool IsFusionSupported() override { return cache_->Policy() == CacheOptions::Policy::kFull && cache_->ValueType() == DataType::kFloat; } bool SnapshotExists(const std::string& name) override; void LoadSnapshot(const std::string& name) override; void SaveSnapshot(const std::string& name) override; void LoadSnapshot(const std::string& name, const std::function& Hook) override; private: void SyncCacheToStore(); std::unique_ptr store_; std::unique_ptr cache_; uint32_t* num_buffer_{}; uint32_t* host_num_buffer_{}; Key* keys_buffer_{}; Elem* values_buffer_{}; uint32_t* indices_buffer0_{}; uint32_t* indices_buffer1_{}; int device_index_{}; uint32_t max_query_length_; uint32_t num_elems_per_value_{}; std::recursive_mutex mutex_; bool synced_; }; template void CacheKeyValueStoreImpl::Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) { std::lock_guard lock(mutex_); auto cuda_stream = stream->As(); if (cache_->Policy() == CacheOptions::Policy::kFull) { cache_->Get(stream, num_keys, keys, values, n_missing, keys_buffer_, missing_indices); return; } else { cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_); } OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(cuda_stream->Sync()); const uint32_t num_cache_missing = *host_num_buffer_; if (num_cache_missing == 0) { OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As()->cuda_stream())); return; } store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_); OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(cuda_stream->Sync()); const uint32_t num_store_missing = *host_num_buffer_; RUN_CUDA_KERNEL((PostStoreGetKernel), stream, num_cache_missing * num_elems_per_value_, num_cache_missing, num_store_missing, num_elems_per_value_, indices_buffer0_, indices_buffer1_, values_buffer_, static_cast(values), missing_indices); } template void CacheKeyValueStoreImpl::Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values, uint8_t* mask) { std::lock_guard lock(mutex_); if (cache_->Policy() == CacheOptions::Policy::kFull) { cache_->Get(stream, num_keys, keys, values, mask); return; } else { UNIMPLEMENTED(); } } template void CacheKeyValueStoreImpl::Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) { std::lock_guard lock(mutex_); synced_ = false; auto cuda_stream = stream->As(); if (cache_->Policy() != CacheOptions::Policy::kFull) { OF_CUDA_CHECK(cudaMemsetAsync(num_buffer_, 0, sizeof(uint32_t), cuda_stream->cuda_stream())); } cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_); if (cache_->Policy() == CacheOptions::Policy::kFull) { return; } OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(cuda_stream->Sync()); store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_); } template void CacheKeyValueStoreImpl::FusedHalfUpdatePut(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values, const void* update, const float* lr, float scale) { std::lock_guard lock(mutex_); if (cache_->Policy() != CacheOptions::Policy::kFull) { OF_CUDA_CHECK(cudaMemsetAsync(num_buffer_, 0, sizeof(uint32_t), 
stream->As()->cuda_stream())); } if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) { UNIMPLEMENTED(); } synced_ = false; cache_->FusedHalfUpdatePut(stream, num_keys, keys, values, update, lr, scale, num_buffer_, keys_buffer_, values_buffer_); } template bool CacheKeyValueStoreImpl::SnapshotExists(const std::string& name) { return store_->SnapshotExists(name); } template void CacheKeyValueStoreImpl::LoadSnapshot(const std::string& name) { LoadSnapshot(name, nullptr); } template void CacheKeyValueStoreImpl::LoadSnapshot( const std::string& name, const std::function& Hook) { CudaCurrentDeviceGuard guard(device_index_); std::lock_guard lock(mutex_); CHECK_GT(max_query_length_, 0); cache_->Clear(); auto device = Singleton::Get()->GetDevice(DeviceType::kCUDA, device_index_); CHECK(device); auto* stream = device->CreateStream(); store_->LoadSnapshot(name, [&](KVIterator* iter) { if (cache_->Policy() == CacheOptions::Policy::kFull) { auto* cuda_stream = stream->As(); while (true) { iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_); OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(stream->Sync()); if (*host_num_buffer_ == 0) { return; } cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr, nullptr); } } if (Hook) { iter->Reset(); Hook(iter); } }); device->DestroyStream(stream); } template void CacheKeyValueStoreImpl::SaveSnapshot(const std::string& name) { CudaCurrentDeviceGuard guard(device_index_); std::lock_guard lock(mutex_); SyncCacheToStore(); store_->SaveSnapshot(name); } template void CacheKeyValueStoreImpl::SyncCacheToStore() { if (synced_) { return; } CudaCurrentDeviceGuard guard(device_index_); auto device = Singleton::Get()->GetDevice(DeviceType::kCUDA, device_index_); CHECK(device); auto* stream = device->CreateStream(); auto* cuda_stream = stream->As(); const uint64_t dump_capacity = cache_->DumpCapacity(); CHECK_GT(max_query_length_, 0); for (uint64_t start_key_index = 0; start_key_index < dump_capacity; start_key_index += max_query_length_) { cache_->Dump(stream, start_key_index, std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_, keys_buffer_, values_buffer_); OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(stream->Sync()); if (*host_num_buffer_ == 0) { continue; } store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_); CHECK_JUST(stream->Sync()); } cache_->ClearDirtyFlags(); device->DestroyStream(stream); synced_ = true; } template std::unique_ptr DispatchElemType(std::unique_ptr&& store, std::unique_ptr&& cache) { const uint32_t value_size = store->ValueSize(); if (value_size % sizeof(uint4) == 0) { return std::unique_ptr( new CacheKeyValueStoreImpl(std::move(store), std::move(cache))); } else if (value_size % sizeof(uint64_t) == 0) { return std::unique_ptr( new CacheKeyValueStoreImpl(std::move(store), std::move(cache))); } else if (value_size % sizeof(uint32_t) == 0) { return std::unique_ptr( new CacheKeyValueStoreImpl(std::move(store), std::move(cache))); } else if (value_size % sizeof(uint16_t) == 0) { return std::unique_ptr( new CacheKeyValueStoreImpl(std::move(store), std::move(cache))); } else { return std::unique_ptr( new CacheKeyValueStoreImpl(std::move(store), std::move(cache))); } } std::unique_ptr 
DispatchKeyType(std::unique_ptr&& store, std::unique_ptr&& cache) { const uint32_t key_size = store->KeySize(); if (key_size == 4) { return DispatchElemType(std::move(store), std::move(cache)); } else if (key_size == 8) { return DispatchElemType(std::move(store), std::move(cache)); } else { UNIMPLEMENTED(); return nullptr; } } } // namespace std::unique_ptr NewCachedKeyValueStore(std::unique_ptr&& store, std::unique_ptr&& cache) { return DispatchKeyType(std::move(store), std::move(cache)); } } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/cached_key_value_store.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EMBEDDING_CACHED_KEY_VALUE_STORE_H_ #define ONEFLOW_CORE_EMBEDDING_CACHED_KEY_VALUE_STORE_H_ #include "oneflow/core/embedding/key_value_store.h" #include "oneflow/core/embedding/cache.h" namespace oneflow { namespace embedding { std::unique_ptr NewCachedKeyValueStore(std::unique_ptr&& store, std::unique_ptr&& cache); } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_CACHED_KEY_VALUE_STORE_H_ ================================================ FILE: oneflow/core/embedding/embedding_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
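// ---------------------------------------------------------------------------
// CPU-side sketch of the lookup flow that CacheKeyValueStoreImpl::Get above
// implements for the non-kFull (LRU) path: query the cache first, look the
// cache misses up in the backing store, scatter the recovered values back to
// their original positions (PostStoreGetKernel does this on the GPU), and
// report only the keys missing from both levels. Plain std:: containers stand
// in for the GPU cache/store; this is an illustration, not OneFlow code.
#include <cstdint>
#include <unordered_map>
#include <vector>

struct TwoLevelLookupResult {
  std::vector<float> values;               // one value per queried key (0 if missing)
  std::vector<uint32_t> missing_indices;   // positions missing from cache *and* store
};

TwoLevelLookupResult TwoLevelGet(const std::unordered_map<int64_t, float>& cache,
                                 const std::unordered_map<int64_t, float>& store,
                                 const std::vector<int64_t>& keys) {
  TwoLevelLookupResult result;
  result.values.assign(keys.size(), 0.0f);
  std::vector<uint32_t> cache_missing_indices;  // analogous to indices_buffer0_
  for (uint32_t i = 0; i < keys.size(); ++i) {
    auto it = cache.find(keys[i]);
    if (it != cache.end()) {
      result.values[i] = it->second;
    } else {
      cache_missing_indices.push_back(i);
    }
  }
  for (uint32_t miss_rank = 0; miss_rank < cache_missing_indices.size(); ++miss_rank) {
    const uint32_t original_index = cache_missing_indices[miss_rank];
    auto it = store.find(keys[original_index]);
    if (it != store.end()) {
      // Scatter the store hit back to the caller's key position.
      result.values[original_index] = it->second;
    } else {
      // Store misses are reported in terms of the caller's key positions.
      result.missing_indices.push_back(original_index);
    }
  }
  return result;
}
// ---------------------------------------------------------------------------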
*/ #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/core/embedding/persistent_table_key_value_store.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/embedding/cached_key_value_store.h" namespace oneflow { namespace embedding { #ifdef WITH_CUDA constexpr size_t kDefaultMaxQueryLength = 131072; constexpr int64_t kRingBufferSize = 8; struct IdStatistics { IdStatistics() : final_num_unique(0), iter(-1) {} uint32_t final_num_unique; std::vector num_unique_matrix; int64_t iter; }; #if CUDA_VERSION >= 11020 class DynamicTmpBufferAllocator final : public TmpBufferAllocator { public: OF_DISALLOW_COPY_AND_MOVE(DynamicTmpBufferAllocator); DynamicTmpBufferAllocator(cudaStream_t stream, cudaMemPool_t pool) : stream_(stream), mem_pool_(pool) {} ~DynamicTmpBufferAllocator() override = default; void Allocate(void** ptr, size_t size) override { OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, GetCudaAlignedSize(size), mem_pool_, stream_)); } void Free(void* ptr) override { OF_CUDA_CHECK(cudaFreeAsync(ptr, stream_)); } private: cudaStream_t stream_{}; cudaMemPool_t mem_pool_{}; }; class DynamicAllocationEmbeddingState final : public EmbeddingState { public: OF_DISALLOW_COPY_AND_MOVE(DynamicAllocationEmbeddingState); DynamicAllocationEmbeddingState() : lookup_values_(nullptr), lookup_values_size_(0), has_lookup_values_(false), lookup_embeddings_(nullptr), lookup_embeddings_size_(0), has_lookup_embeddings_(false), updated_values_(nullptr), iter_(-1) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); id_statistics_vec_.resize(kRingBufferSize); cudaMemPoolProps poolProps = {}; poolProps.allocType = cudaMemAllocationTypePinned; poolProps.handleTypes = cudaMemHandleTypePosixFileDescriptor; poolProps.location.type = cudaMemLocationTypeDevice; poolProps.location.id = device_index_; cudaMemPoolCreate(&mem_pool_, &poolProps); uint64_t threshold = UINT64_MAX; cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolAttrReleaseThreshold, &threshold); } ~DynamicAllocationEmbeddingState() { CudaCurrentDeviceGuard guard(device_index_); if (has_lookup_values_) { OF_CUDA_CHECK(cudaFree(lookup_values_)); } if (has_lookup_embeddings_) { OF_CUDA_CHECK(cudaFree(lookup_embeddings_)); } OF_CUDA_CHECK(cudaMemPoolDestroy(mem_pool_)); } std::unique_ptr NewTmpBufferAllocator( user_op::KernelComputeContext* ctx) override { return std::make_unique( ctx->stream()->As()->cuda_stream(), mem_pool_); } void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override { iter_ = iter; cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex("unique_values", 0); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); uint32_t num_unique = this->GetIdNumUnique(iter); size_t lookup_values_size = GetCudaAlignedSize(num_unique * line_size * GetSizeOfDataType(unique_values->data_type())); if (!has_lookup_values_ || lookup_values_size_ < lookup_values_size) { if (has_lookup_values_) { OF_CUDA_CHECK(cudaFreeAsync(lookup_values_, cuda_stream)); } OF_CUDA_CHECK( cudaMallocFromPoolAsync(&lookup_values_, lookup_values_size, mem_pool_, cuda_stream)); has_lookup_values_ = true; lookup_values_size_ = lookup_values_size; if (ctx->has_output("embeddings", 0)) { user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); const size_t lookup_embeddings_size = GetCudaAlignedSize( num_unique * embedding_size * GetSizeOfDataType(embeddings->data_type())); if 
(!has_lookup_embeddings_ || lookup_embeddings_size_ < lookup_values_size) { if (has_lookup_embeddings_) { OF_CUDA_CHECK(cudaFreeAsync(lookup_embeddings_, cuda_stream)); } OF_CUDA_CHECK(cudaMallocFromPoolAsync(&lookup_embeddings_, lookup_embeddings_size, mem_pool_, cuda_stream)); has_lookup_embeddings_ = true; lookup_embeddings_size_ = lookup_embeddings_size; } } else { lookup_embeddings_ = nullptr; } } } void* LookupUniqueValues(int64_t iter) override { CHECK_EQ(iter_, iter); CHECK(has_lookup_values_); return lookup_values_; } void* LookupEmbeddings(int64_t iter) override { CHECK_EQ(iter_, iter); CHECK(has_lookup_embeddings_); return lookup_embeddings_; } void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } const void* EmbeddingGatherIn(int64_t iter) override { if (has_lookup_embeddings_) { return lookup_embeddings_; } else { CHECK(has_lookup_values_); return lookup_values_; } } void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override { if (has_lookup_embeddings_) { return lookup_embeddings_; } else { CHECK(has_lookup_values_); return lookup_values_; } } void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override { const user_op::Tensor* updated_unique_embeddings = ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0); const int64_t line_size = ctx->Attr("line_size"); uint32_t num_unique = this->GetIdNumUnique(iter); size_t update_values_size = GetCudaAlignedSize( num_unique * line_size * GetSizeOfDataType(updated_unique_embeddings->data_type())); OF_CUDA_CHECK(cudaMallocFromPoolAsync(&updated_values_, update_values_size, mem_pool_, ctx->stream()->As()->cuda_stream())); } const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) override { CHECK_EQ(iter_, iter); CHECK(has_lookup_values_); return lookup_values_; } void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) override { CHECK_EQ(iter_, iter); return updated_values_; } void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } const void* EmbeddingPutUniqueEmbeddings(int64_t iter) override { CHECK_EQ(iter_, iter); return updated_values_; } void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { OF_CUDA_CHECK( cudaFreeAsync(updated_values_, ctx->stream()->As()->cuda_stream())); } void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) override { CHECK_EQ(iter_, iter); CHECK(has_lookup_values_); return lookup_values_; } void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { // do nothing } void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override { std::unique_lock lock(mutex_); int64_t index = iter % kRingBufferSize; id_statistics_vec_.at(index).final_num_unique = final_num_unique; id_statistics_vec_.at(index).iter = iter; } void SetIdNumUniqueMatrix(const std::vector& 
num_unique_matrix, int64_t iter) override { std::unique_lock lock(mutex_); int64_t index = iter % kRingBufferSize; id_statistics_vec_.at(index).num_unique_matrix = num_unique_matrix; id_statistics_vec_.at(index).iter = iter; } uint32_t GetIdNumUnique(int64_t iter) override { std::unique_lock lock(mutex_); int64_t index = iter % kRingBufferSize; const IdStatistics& statistics = id_statistics_vec_.at(index); CHECK_EQ(statistics.iter, iter) << "saved iter: " << statistics.iter << " current iter: " << iter; return statistics.final_num_unique; } const std::vector& GetIdNumUniqueMatrix(int64_t iter) override { std::unique_lock lock(mutex_); int64_t index = iter % kRingBufferSize; const IdStatistics& statistics = id_statistics_vec_.at(index); CHECK_EQ(statistics.iter, iter) << "saved iter: " << statistics.iter << " current iter: " << iter; return statistics.num_unique_matrix; } private: void* lookup_values_; size_t lookup_values_size_; bool has_lookup_values_; void* lookup_embeddings_; size_t lookup_embeddings_size_; bool has_lookup_embeddings_; void* updated_values_; int64_t iter_; std::vector id_statistics_vec_; int device_index_{}; cudaMemPool_t mem_pool_{}; std::mutex mutex_; }; #endif class StaticTmpBufferAllocator final : public TmpBufferAllocator { public: OF_DISALLOW_COPY_AND_MOVE(StaticTmpBufferAllocator); StaticTmpBufferAllocator(void* ptr, size_t size) : ptr_(ptr), offset_(0), size_(size) {} ~StaticTmpBufferAllocator() override = default; void Allocate(void** ptr, size_t size) override { CHECK(ptr_ != nullptr); CHECK_GE(offset_, 0); size_t aligned_size = GetCudaAlignedSize(size); CHECK_LE(offset_ + aligned_size, size_); *ptr = reinterpret_cast(ptr_) + offset_; offset_ += aligned_size; } void Free(void* ptr) override { // do nothing } private: void* ptr_; int64_t offset_; size_t size_; }; class StaticAllocationEmbeddingState final : public EmbeddingState { public: OF_DISALLOW_COPY_AND_MOVE(StaticAllocationEmbeddingState); StaticAllocationEmbeddingState() : lookup_unique_values_(nullptr), lookup_embeddings_(nullptr), has_lookup_embeddings_(false), embedding_shuffle_cur_rank_embeddings_(nullptr), embedding_update_unique_embeddings_(nullptr), embedding_update_updated_unique_embeddings_(nullptr), embedding_put_unique_embeddings_(nullptr), embedding_fused_update_put_unique_embeddings_(nullptr) { id_statistics_vec_.resize(kRingBufferSize); } ~StaticAllocationEmbeddingState() override = default; std::unique_ptr NewTmpBufferAllocator( user_op::KernelComputeContext* ctx) override { user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); return std::make_unique(tmp_buffer->mut_dptr(), tmp_buffer->shape_view().elem_cnt()); } void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override { user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex("unique_values", 0); lookup_unique_values_ = unique_values->mut_dptr(); if (ctx->has_output("embeddings", 0)) { user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); has_lookup_embeddings_ = true; lookup_embeddings_ = embeddings->mut_dptr(); } } void* LookupUniqueValues(int64_t iter) override { return lookup_unique_values_; } void* LookupEmbeddings(int64_t iter) override { CHECK(has_lookup_embeddings_); return lookup_embeddings_; } void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { lookup_unique_values_ = nullptr; lookup_embeddings_ = nullptr; has_lookup_embeddings_ = false; } void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) 
override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); embedding_gather_in_ = in->dptr(); } const void* EmbeddingGatherIn(int64_t iter) override { return embedding_gather_in_; } void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { embedding_gather_in_ = nullptr; } void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override { const user_op::Tensor* cur_rank_embeddings = ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0); embedding_shuffle_cur_rank_embeddings_ = cur_rank_embeddings->dptr(); } const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override { return embedding_shuffle_cur_rank_embeddings_; } void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { embedding_shuffle_cur_rank_embeddings_ = nullptr; } void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override { const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0); user_op::Tensor* updated_unique_embeddings = ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0); embedding_update_unique_embeddings_ = unique_embeddings->dptr(); embedding_update_updated_unique_embeddings_ = updated_unique_embeddings->mut_dptr(); } const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) override { return embedding_update_unique_embeddings_; } void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) override { return embedding_update_updated_unique_embeddings_; } void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { embedding_update_unique_embeddings_ = nullptr; embedding_update_updated_unique_embeddings_ = nullptr; } void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) override { const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0); embedding_put_unique_embeddings_ = unique_embeddings->dptr(); } const void* EmbeddingPutUniqueEmbeddings(int64_t iter) override { return embedding_put_unique_embeddings_; } void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { embedding_put_unique_embeddings_ = nullptr; } void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) override { const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0); embedding_fused_update_put_unique_embeddings_ = unique_embeddings->dptr(); } const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) override { return embedding_fused_update_put_unique_embeddings_; } void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override { embedding_fused_update_put_unique_embeddings_ = nullptr; } void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override { std::unique_lock lock(mutex_); int64_t index = iter % kRingBufferSize; id_statistics_vec_.at(index).final_num_unique = final_num_unique; id_statistics_vec_.at(index).iter = iter; } void SetIdNumUniqueMatrix(const std::vector& num_unique_matrix, int64_t iter) override { std::unique_lock lock(mutex_); int64_t index = iter % kRingBufferSize; id_statistics_vec_.at(index).num_unique_matrix = num_unique_matrix; id_statistics_vec_.at(index).iter = iter; } uint32_t GetIdNumUnique(int64_t iter) override { std::unique_lock lock(mutex_); int64_t index = iter % kRingBufferSize; const IdStatistics& statistics = id_statistics_vec_.at(index); CHECK_EQ(statistics.iter, iter) << "saved iter: " << statistics.iter << " current 
iter: " << iter; return statistics.final_num_unique; } const std::vector& GetIdNumUniqueMatrix(int64_t iter) override { std::unique_lock lock(mutex_); int64_t index = iter % kRingBufferSize; const IdStatistics& statistics = id_statistics_vec_.at(index); CHECK_EQ(statistics.iter, iter) << "saved iter: " << statistics.iter << " current iter: " << iter; return statistics.num_unique_matrix; } void* lookup_unique_values_; void* lookup_embeddings_; bool has_lookup_embeddings_; const void* embedding_gather_in_; const void* embedding_shuffle_cur_rank_embeddings_; const void* embedding_update_unique_embeddings_; void* embedding_update_updated_unique_embeddings_; const void* embedding_put_unique_embeddings_; const void* embedding_fused_update_put_unique_embeddings_; std::vector id_statistics_vec_; std::mutex mutex_; }; EmbeddingState* EmbeddingManager::GetEmbeddingState(const std::string& embedding_name, int64_t rank_id) { std::pair map_key = std::make_pair(embedding_name, rank_id); std::unique_lock lock(mutex_); auto it = embedding_state_map_.find(map_key); // for id shuffle test, not need to create table if (it == embedding_state_map_.end()) { LOG(INFO) << "create embedding state: " << embedding_name << "-" << rank_id; if (UseDynamicMemoryAllocation()) { #if CUDA_VERSION >= 11020 it = embedding_state_map_.emplace(map_key, std::make_unique()) .first; #else UNIMPLEMENTED(); #endif } else { it = embedding_state_map_.emplace(map_key, std::make_unique()) .first; } } return it->second.get(); } KeyValueStore* EmbeddingManager::GetKeyValueStore(const std::string& embedding_name, int64_t rank_id) { std::pair map_key = std::make_pair(embedding_name, rank_id); std::unique_lock lock(mutex_); auto it = key_value_store_map_.find(map_key); CHECK(it != key_value_store_map_.end()) << "Can not find embedding: " << embedding_name << "-" << rank_id; return it->second.get(); } void EmbeddingManager::CreateKeyValueStore(const KeyValueStoreOptions& key_value_store_options, int64_t local_rank_id, int64_t rank_id, int64_t world_size) { CudaCurrentDeviceGuard guard(local_rank_id); const std::string& name = key_value_store_options.Name(); const uint32_t line_size = key_value_store_options.LineSize(); std::pair map_key = std::make_pair(name, rank_id); std::unique_lock lock(mutex_); std::unique_ptr store; PersistentTableKeyValueStoreOptions options{}; const std::vector& persistent_table_paths = key_value_store_options.PersistentTablePaths(); CHECK_EQ(persistent_table_paths.size(), world_size); options.table_options.path = persistent_table_paths.at(rank_id); options.table_options.value_size = line_size * key_value_store_options.ValueTypeSize(); options.table_options.key_size = key_value_store_options.KeyTypeSize(); options.table_options.physical_block_size = key_value_store_options.PersistentTablePhysicalBlockSize(); options.table_options.target_chunk_size_mb = 4 * 1024; options.table_options.capacity_hint = key_value_store_options.PersistentTableCapacityHint(); store = NewPersistentTableKeyValueStore(options); const std::vector& cache_options = key_value_store_options.GetCachesOptions(); for (int i = cache_options.size() - 1; i >= 0; --i) { std::unique_ptr cache = NewCache(cache_options.at(i)); store = NewCachedKeyValueStore(std::move(store), std::move(cache)); } store->ReserveQueryLength(kDefaultMaxQueryLength); CHECK(key_value_store_map_.emplace(map_key, std::move(store)).second) << "Can't create an embedding with same name of an existing embedding, the name: " << name; if (UseDynamicMemoryAllocation()) { #if CUDA_VERSION 
>= 11020 CHECK(embedding_state_map_.emplace(map_key, std::make_unique()) .second) << "Can't create an embedding state with same name of an existing embedding, the name: " << name; #else UNIMPLEMENTED(); #endif } else { CHECK(embedding_state_map_.emplace(map_key, std::make_unique()) .second) << "Can't create an embedding state with same name of an existing embedding, the name: " << name; } } void EmbeddingManager::SaveSnapshot(const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id, const std::string& snapshot_name) { CudaCurrentDeviceGuard guard(local_rank_id); std::pair map_key = std::make_pair(embedding_name, rank_id); std::unique_lock lock(mutex_); auto it = key_value_store_map_.find(map_key); CHECK(it != key_value_store_map_.end()) << "Can not find embedding: " << embedding_name << "-" << rank_id; it->second->SaveSnapshot(snapshot_name); } void EmbeddingManager::LoadSnapshot(const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id, const std::string& snapshot_name) { CudaCurrentDeviceGuard guard(local_rank_id); std::pair map_key = std::make_pair(embedding_name, rank_id); auto it = key_value_store_map_.find(map_key); CHECK(it != key_value_store_map_.end()) << "Can not find embedding: " << embedding_name << "-" << rank_id; if (it->second->SnapshotExists(snapshot_name)) { it->second->LoadSnapshot(snapshot_name); } else { LOG(ERROR) << "Here Exists Embedding name is: " << embedding_name << "-" << rank_id << " but no corresponding snapshot. "; } } #endif // WITH_CUDA } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/embedding_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EMBEDDING_EMBEDDING_MANAGER_H_ #define ONEFLOW_CORE_EMBEDDING_EMBEDDING_MANAGER_H_ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/embedding/key_value_store.h" #include "oneflow/core/embedding/key_value_store_options.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace embedding { inline bool UseDynamicMemoryAllocation() { static bool use_dynamic_memory_allocation = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION", false); #if CUDA_VERSION >= 11020 return use_dynamic_memory_allocation; #else if (use_dynamic_memory_allocation) { LOG(WARNING) << "Dynamic memory allocation only support when cuda_version greater equal than 11.2. 
"; } return false; #endif } inline bool UseEmbeddingShuffleP2PKernel(DataType embedding_dtype, DataType idx_dtype) { static bool use_embedding_shuffle_p2p_env = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_USE_P2P", false); static bool add_id_shuffle_copy_out_env = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT", true); static bool enable_quantized_comm = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false); if (use_embedding_shuffle_p2p_env) { if (embedding_dtype != DataType::kFloat16 || idx_dtype != DataType::kUInt32) { // p2p kernel only registered kFloat16 and kUint32. return false; } if (!add_id_shuffle_copy_out_env) { // when not enable id shuffle copy out, the ptrs change every iter. return false; } if (enable_quantized_comm) { // p2p kernel not support quantize comm. return false; } if (UseDynamicMemoryAllocation()) { // p2p kernel not support dynamic memory allocation. return false; } } #if CUDA_VERSION >= 11030 return use_embedding_shuffle_p2p_env; #else if (use_embedding_shuffle_p2p_env) { LOG(WARNING) << "embedding shuffle p2p kernel only support when cuda_version greater equal than 11.3. "; } return false; #endif } inline bool UseEmbeddingGradientShuffleP2PKernel(DataType embedding_dtype, DataType idx_dtype) { static bool use_embedding_gradient_shuffle_p2p_env = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_EMBEDDING_GRADIENT_SHUFFLE_USE_P2P", false); static bool add_id_shuffle_copy_out_env = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT", true); static bool enable_quantized_comm = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false); if (use_embedding_gradient_shuffle_p2p_env) { if (embedding_dtype != DataType::kFloat16 || idx_dtype != DataType::kUInt32) { // p2p kernel only registered kFloat16 and kUint32. return false; } if (!add_id_shuffle_copy_out_env) { // when not enable id shuffle copy out, the ptrs change every iter. return false; } if (enable_quantized_comm) { // p2p kernel not support quantize comm. return false; } if (UseDynamicMemoryAllocation()) { // p2p kernel not support dynamic memory allocation. return false; } } #if CUDA_VERSION >= 11030 return use_embedding_gradient_shuffle_p2p_env; #else if (use_embedding_gradient_shuffle_p2p_env) { LOG(WARNING) << "embedding gradient shuffle p2p kernel only support when cuda_version greater " "equal than 11.3. 
"; } return false; #endif } #ifdef WITH_CUDA class TmpBufferAllocator { public: TmpBufferAllocator() = default; virtual ~TmpBufferAllocator() = default; virtual void Allocate(void** ptr, size_t size) = 0; virtual void Free(void* ptr) = 0; }; class EmbeddingState { public: EmbeddingState() = default; virtual ~EmbeddingState() = default; virtual std::unique_ptr NewTmpBufferAllocator( user_op::KernelComputeContext* ctx) = 0; virtual void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual void* LookupUniqueValues(int64_t iter) = 0; virtual void* LookupEmbeddings(int64_t iter) = 0; virtual void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual const void* EmbeddingGatherIn(int64_t iter) = 0; virtual void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) = 0; virtual void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) = 0; virtual void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) = 0; virtual void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual const void* EmbeddingPutUniqueEmbeddings(int64_t iter) = 0; virtual void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) = 0; virtual void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0; virtual void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) = 0; virtual void SetIdNumUniqueMatrix(const std::vector& num_unique_matrix, int64_t iter) = 0; virtual uint32_t GetIdNumUnique(int64_t iter) = 0; virtual const std::vector& GetIdNumUniqueMatrix(int64_t iter) = 0; }; class EmbeddingManager final { public: EmbeddingManager() = default; ~EmbeddingManager() = default; void SaveSnapshot(const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id, const std::string& snapshot_name); void LoadSnapshot(const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id, const std::string& snapshot_name); KeyValueStore* GetKeyValueStore(const std::string& embedding_name, int64_t rank_id); EmbeddingState* GetEmbeddingState(const std::string& embedding_name, int64_t rank_id); void CreateKeyValueStore(const KeyValueStoreOptions& options, int64_t local_rank_id, int64_t rank_id, int64_t world_size); private: HashMap, std::unique_ptr> key_value_store_map_; HashMap, std::unique_ptr> embedding_state_map_; std::mutex mutex_; }; #endif // WITH_CUDA } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_EMBEDDING_MANAGER_H_ ================================================ FILE: oneflow/core/embedding/full_cache.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/embedding/full_cache.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/embedding/hash_functions.cuh" #include "oneflow/core/cuda/atomic.cuh" namespace oneflow { namespace embedding { using Key32 = unsigned int; using Key64 = unsigned long long int; using Key128 = ulonglong2; namespace { template __device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, bool* entry_dirty_flag, Index* table_size, Key key, Index* out) { Key key_hi = (key | 0x1); Key key_lo = (key & 0x1); Index index_plus_one = 0; Key old_entry_key = cuda::atomic::CAS(entry_key, static_cast(0), key_hi); while (index_plus_one == 0) { if (old_entry_key == static_cast(0)) { Index index = cuda::atomic::Add(table_size, static_cast(1)); index_plus_one = index + 1; *entry_index = ((index_plus_one << 1U) | key_lo); *out = index_plus_one; if (dump_dirty_only) { bool entry_flag_val = *entry_dirty_flag; if (!entry_flag_val) { *entry_dirty_flag = true; } } return true; } else if (old_entry_key == key_hi) { const Index entry_index_val = *entry_index; if (entry_index_val == 0) { // do nothing } else if ((entry_index_val & 0x1) == key_lo) { *out = (entry_index_val >> 1U); if (dump_dirty_only) { bool entry_flag_val = *entry_dirty_flag; if (!entry_flag_val) { *entry_dirty_flag = true; } } return true; } else { return false; } } else { return false; } } return false; } template __device__ bool GetOrInsertOne(const size_t capacity, Key* table_keys, Index* table_indices, bool* table_dirty_flags, Index* table_size, Key key, size_t hash, Index* out) { const size_t start_idx = hash % capacity; for (size_t count = 0; count < capacity; ++count) { const size_t idx = (start_idx + count) % capacity; Key* entry_key = table_keys + idx; Index* entry_index = table_indices + idx; bool* entry_dirty_flag = dump_dirty_only ? 
table_dirty_flags + idx : nullptr; if (TryGetOrInsert(entry_key, entry_index, entry_dirty_flag, table_size, key, out)) { return true; } } return false; } template __device__ bool GetOne(const size_t capacity, Key* table_keys, Index* table_indices, Key key, size_t hash, Index* out) { const size_t start_idx = hash % capacity; for (size_t count = 0; count < capacity; ++count) { const size_t idx = (start_idx + count) % capacity; Key entry_key = table_keys[idx]; Key entry_index = table_indices[idx]; Key key_hi = (key | 0x1); Key key_lo = (key & 0x1); if (entry_key == 0) { break; } if (entry_key == key_hi) { if ((entry_index & 0x1) == key_lo) { *out = (entry_index >> 1U); return true; } } } *out = 0; return false; } template __global__ void OrdinalEncodeKernel(uint64_t capacity, Key* table_keys, Index* table_indices, bool* table_dirty_flags, Index* table_size, uint32_t num_keys, const Key* keys, Index* context) { CUDA_1D_KERNEL_LOOP(i, num_keys) { Key key = keys[i]; uint64_t hash = FullCacheHash()(key); bool success = GetOrInsertOne( capacity, table_keys, table_indices, table_dirty_flags, table_size, key, hash, context + i); assert(success); } } template __global__ void OrdinalEncodeLookupKernel(uint64_t capacity, Key* table_keys, Index* table_indices, uint32_t num_keys, const Key* keys, Index* context) { CUDA_1D_KERNEL_LOOP(i, num_keys) { Key key = keys[i]; uint64_t hash = FullCacheHash()(key); GetOne(capacity, table_keys, table_indices, key, hash, context + i); } } template __global__ void OrdinalEncodeDumpKernel(const Key* table_keys, const Index* table_indices, const bool* table_dirty_flags, uint64_t start_key_index, uint64_t end_key_index, uint32_t* n_dumped, Key* keys, Index* context) { CUDA_1D_KERNEL_LOOP(i, (end_key_index - start_key_index)) { Key entry_key = table_keys[i + start_key_index]; Index entry_index = table_indices[i + start_key_index]; bool dump_flag = (entry_index != 0); if (dump_dirty_only) { bool entry_dirty_flag = table_dirty_flags[i + start_key_index]; dump_flag = (dump_flag && entry_dirty_flag); } if (dump_flag) { uint32_t index = cuda::atomic::Add(n_dumped, static_cast(1)); keys[index] = ((entry_key ^ 0x1) | (entry_index & 0x1)); context[index] = (entry_index >> 1U); } } } template __global__ void LookupKernel(uint32_t value_length, const Elem* cache_values, uint32_t values_elem_cnt, const Key* keys, const Index* context, Elem* values, uint32_t* n_missing, Key* missing_keys, uint32_t* missing_indices) { CUDA_1D_KERNEL_LOOP(i, values_elem_cnt) { const uint64_t key_id = i / value_length; const uint64_t ctx = context[key_id]; const uint64_t row_id = ctx - 1; const uint64_t col_id = i - key_id * value_length; if (ctx == 0) { const Key missing_key = keys[key_id]; if (col_id == 0) { const uint32_t old_n_missing = cuda::atomic::Add(n_missing, static_cast(1)); missing_keys[old_n_missing] = missing_key; missing_indices[old_n_missing] = key_id; } continue; } if (return_value) { values[i] = cache_values[row_id * value_length + col_id]; } } } template __global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_values, uint32_t values_elem_cnt, const Key* keys, const Index* context, Elem* values, uint32_t* n_missing, Key* missing_keys, uint32_t* missing_indices, const size_t capacity, Key* table_keys, Index* table_indices) { constexpr uint32_t warp_size = 32; constexpr uint32_t n_warp_per_block = block_size / warp_size; const uint32_t warp_id = threadIdx.x / warp_size; const uint32_t lane_id = threadIdx.x % warp_size; const uint32_t global_warp_id = blockIdx.x * 
n_warp_per_block + warp_id; const uint32_t global_n_warp = gridDim.x * n_warp_per_block; const uint32_t n_keys = values_elem_cnt / value_length; __shared__ Key batch_keys[n_warp_per_block][warp_size]; __shared__ Index batch_row_ids[n_warp_per_block][warp_size]; __shared__ Key batch_missing_keys[n_warp_per_block][warp_size]; __shared__ uint32_t batch_missing_indices[n_warp_per_block][warp_size]; __shared__ uint32_t batch_n_missing[n_warp_per_block]; for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys; batch_start += global_n_warp * warp_size) { const uint32_t batch_n_key = min(n_keys - batch_start, warp_size); if (lane_id == 0) { batch_n_missing[warp_id] = 0; } __syncwarp(); const uint32_t key_offset = batch_start + lane_id; if (key_offset < n_keys) { const Key key = keys[batch_start + lane_id]; const uint64_t hash = FullCacheHash()(key); Index row; GetOne(capacity, table_keys, table_indices, key, hash, &row); batch_row_ids[warp_id][lane_id] = row; if (row == 0) { const uint32_t batch_missing_idx = atomicAdd(batch_n_missing + warp_id, 1); batch_missing_keys[warp_id][batch_missing_idx] = key; batch_missing_indices[warp_id][batch_missing_idx] = key_offset; } } __syncwarp(); const uint32_t batch_n_missing_t = batch_n_missing[warp_id]; if (lane_id == 0) { const uint32_t old_n_missing = cuda::atomic::Add(n_missing, static_cast(batch_n_missing_t)); batch_n_missing[warp_id] = old_n_missing; } __syncwarp(); if (lane_id < batch_n_missing_t) { missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id]; missing_indices[batch_n_missing[warp_id] + lane_id] = batch_missing_indices[warp_id][lane_id]; } for (int i = 0; i < batch_n_key; ++i) { const Key key = batch_keys[warp_id][i]; const int64_t row = batch_row_ids[warp_id][i]; if (row == 0) { continue; } for (int col = lane_id; col < value_length; col += warp_size) { values[(batch_start + i) * value_length + col] = cache_values[(row - 1) * value_length + col]; } } __syncwarp(); } } template struct alignas(sizeof(T) * pack_size) Pack { T elem[pack_size]; }; template __global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __restrict__ cache_values, uint32_t values_elem_cnt, const Key* __restrict__ keys, const Index* __restrict__ context, Elem* __restrict__ values, uint8_t* __restrict__ mask, const size_t capacity, Key* __restrict__ table_keys, Index* __restrict__ table_indices) { const uint32_t packed_cols = value_length / pack_size; auto* packed_values = reinterpret_cast*>(values); const auto* packed_cache_values = reinterpret_cast*>(cache_values); constexpr uint32_t warp_size = 32; constexpr uint32_t n_warp_per_block = block_size / warp_size; const uint32_t warp_id = threadIdx.x / warp_size; const uint32_t lane_id = threadIdx.x % warp_size; const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id; const uint32_t global_n_warp = gridDim.x * n_warp_per_block; const uint32_t n_keys = values_elem_cnt / value_length; __shared__ Key batch_keys[n_warp_per_block][warp_size]; __shared__ Index batch_row_ids[n_warp_per_block][warp_size]; for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys; batch_start += global_n_warp * warp_size) { const uint32_t batch_n_key = min(n_keys - batch_start, warp_size); const uint32_t key_offset = batch_start + lane_id; if (key_offset < n_keys) { const Key key = keys[batch_start + lane_id]; const uint64_t hash = FullCacheHash()(key); Index row; GetOne(capacity, table_keys, table_indices, key, hash, &row); 
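// GetOne leaves row == 0 when the key is absent from the table; that result feeds the per-lane
// row-id buffer and the per-key hit mask written below.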
batch_row_ids[warp_id][lane_id] = row; mask[key_offset] = row > 0; } __syncwarp(); for (int i = 0; i < batch_n_key; ++i) { const Key key = batch_keys[warp_id][i]; const int64_t row = batch_row_ids[warp_id][i]; if (row == 0) { continue; } #pragma unroll 4 for (int col = lane_id; col < packed_cols; col += warp_size) { packed_values[(batch_start + i) * packed_cols + col] = packed_cache_values[(row - 1) * packed_cols + col]; } } __syncwarp(); } } template __global__ void UpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt, const Index* context, const Elem* values) { const int packed_values_elem_cnt = values_elem_cnt / pack_size; const uint32_t packed_elem_cnt = value_length / pack_size; auto* packed_cache_values = reinterpret_cast*>(cache_values); auto* packed_values = reinterpret_cast*>(values); CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) { const uint64_t key_id = i / packed_elem_cnt; const uint64_t ctx = context[key_id]; if (ctx == 0) { continue; } const uint64_t row_id = ctx - 1; const uint64_t col_id = i - key_id * packed_elem_cnt; packed_cache_values[row_id * packed_elem_cnt + col_id] = packed_values[i]; } } template __global__ typename std::enable_if::value, void>::type FusedHalfUpdateKernel(uint32_t value_length, Elem* __restrict__ cache_values, uint32_t values_elem_cnt, const Index* __restrict__ context, const Elem* __restrict__ values, const half* __restrict__ update, const float* __restrict__ lr, float scale) { const int packed_values_elem_cnt = values_elem_cnt / pack_size; const uint32_t packed_elem_cnt = value_length / pack_size; auto* packed_cache_values = reinterpret_cast*>(cache_values); auto* packed_values = reinterpret_cast*>(values); auto* packed_update = reinterpret_cast*>(update); const float alpha = -*lr * scale; CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) { const uint64_t key_id = i / packed_elem_cnt; const uint64_t ctx = context[key_id]; if (ctx == 0) { continue; } const uint64_t row_id = ctx - 1; const uint64_t col_id = i - key_id * packed_elem_cnt; Pack m = packed_values[i]; Pack u = packed_update[i]; for (size_t j = 0; j < pack_size; ++j) { m.elem[j] += static_cast(u.elem[j]) * alpha; } packed_cache_values[row_id * packed_elem_cnt + col_id] = m; } } template __global__ typename std::enable_if::value, void>::type FusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt, const Index* context, const Elem* values, const half* update, const float* lr, float scale) { __trap(); } template __global__ void DumpValueKernel(uint32_t value_length, const uint32_t* n_dumped, const Index* context, const Elem* cache_values, Elem* values) { CUDA_1D_KERNEL_LOOP(i, *n_dumped * value_length) { const uint64_t key_id = i / value_length; const uint64_t ctx = context[key_id]; const uint64_t row_id = ctx - 1; const uint64_t col_id = i - key_id * value_length; values[i] = cache_values[row_id * value_length + col_id]; } } template class OrdinalEncoder { public: OF_DISALLOW_COPY_AND_MOVE(OrdinalEncoder); explicit OrdinalEncoder(uint64_t capacity, float load_factor, bool if_dump_dirty) : capacity_(capacity), table_capacity_(capacity / load_factor), if_dump_dirty_(if_dump_dirty) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); OF_CUDA_CHECK(cudaMalloc(&table_size_, sizeof(Index))); OF_CUDA_CHECK(cudaMallocHost(&table_size_host_, sizeof(Index))); OF_CUDA_CHECK(cudaMalloc(&table_keys_, table_capacity_ * sizeof(Key))); OF_CUDA_CHECK(cudaMalloc(&table_indices_, table_capacity_ * sizeof(Index))); if (if_dump_dirty_) { 
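// The dirty-flag array exists only in dump-dirty-only mode; entries touched by the encoder are
// flagged so that DumpDirtyOnly() can skip rows unchanged since the last ClearDirtyFlags().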
OF_CUDA_CHECK(cudaMalloc(&table_dirty_flags_, table_capacity_ * sizeof(bool))); } Clear(); } ~OrdinalEncoder() { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFree(table_size_)); OF_CUDA_CHECK(cudaFreeHost(table_size_host_)); OF_CUDA_CHECK(cudaFree(table_keys_)); OF_CUDA_CHECK(cudaFree(table_indices_)); if (if_dump_dirty_) { OF_CUDA_CHECK(cudaFree(table_dirty_flags_)); } } template void Encode(ep::Stream* stream, uint32_t num_keys, const Key* keys, Index* context) { if (insert) { RUN_CUDA_KERNEL((OrdinalEncodeKernel), stream, num_keys, table_capacity_, table_keys_, table_indices_, table_dirty_flags_, table_size_, num_keys, keys, context); } else { RUN_CUDA_KERNEL((OrdinalEncodeLookupKernel), stream, num_keys, table_capacity_, table_keys_, table_indices_, num_keys, keys, context); } } void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index, uint32_t* n_dumped, Key* keys, Index* context) { OF_CUDA_CHECK(cudaMemsetAsync(n_dumped, 0, sizeof(uint32_t), stream->As()->cuda_stream())); RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel), stream, end_key_index - start_key_index, table_keys_, table_indices_, table_dirty_flags_, start_key_index, end_key_index, n_dumped, keys, context); } void DumpDirtyOnly(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index, uint32_t* n_dumped, Key* keys, Index* context) { OF_CUDA_CHECK(cudaMemsetAsync(n_dumped, 0, sizeof(uint32_t), stream->As()->cuda_stream())); RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel), stream, end_key_index - start_key_index, table_keys_, table_indices_, table_dirty_flags_, start_key_index, end_key_index, n_dumped, keys, context); } void ClearDirtyFlags() { if (if_dump_dirty_) { OF_CUDA_CHECK(cudaMemset(table_dirty_flags_, 0, table_capacity_ * sizeof(bool))); } } void Clear() { OF_CUDA_CHECK(cudaMemset(table_size_, 0, sizeof(Index))); OF_CUDA_CHECK(cudaMemset(table_keys_, 0, table_capacity_ * sizeof(Key))); OF_CUDA_CHECK(cudaMemset(table_indices_, 0, table_capacity_ * sizeof(Index))); if (if_dump_dirty_) { OF_CUDA_CHECK(cudaMemset(table_dirty_flags_, 0, table_capacity_ * sizeof(bool))); } } uint64_t TableCapacity() const { return table_capacity_; } Key* table_keys() const { return table_keys_; } Index* table_indices() const { return table_indices_; } private: int device_index_{}; Key* table_keys_; Index* table_indices_; bool* table_dirty_flags_; uint64_t capacity_; uint64_t table_capacity_; bool if_dump_dirty_; Index* table_size_{}; Index* table_size_host_{}; }; template class CacheImpl : public Cache { public: OF_DISALLOW_COPY_AND_MOVE(CacheImpl); explicit CacheImpl(const CacheOptions& options) : if_dump_dirty_(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DUMP_DIRTY_ONLY", false)), encoder_(options.capacity, options.load_factor, if_dump_dirty_), device_index_(-1), options_(options), max_query_length_(0) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); const uint64_t values_size = options.capacity * options.value_size; if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) { OF_CUDA_CHECK(cudaMalloc(&values_, values_size)); } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) { if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) { OF_CUDA_CHECK(cudaMallocHost(&values_, values_size)); } else { OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&values_), values_size)); } } else { UNIMPLEMENTED(); } num_elem_per_value_ = options_.value_size / sizeof(Elem); } ~CacheImpl() { CudaCurrentDeviceGuard guard(device_index_); if 
(options_.value_memory_kind == CacheOptions::MemoryKind::kDevice) { OF_CUDA_CHECK(cudaFree(values_)); } else if (options_.value_memory_kind == CacheOptions::MemoryKind::kHost) { OF_CUDA_CHECK(cudaFreeHost(values_)); } else { UNIMPLEMENTED(); } if (max_query_length_ > 0) { OF_CUDA_CHECK(cudaFree(encoding_buffer_)); } } uint64_t Capacity() const override { return options_.capacity; } uint64_t DumpCapacity() const override { return encoder_.TableCapacity(); } uint32_t KeySize() const override { return options_.key_size; } uint32_t ValueSize() const override { return options_.value_size; } DataType ValueType() const override { return options_.value_type; } uint32_t MaxQueryLength() const override { return max_query_length_; } void ReserveQueryLength(uint32_t query_length) override { CudaCurrentDeviceGuard guard(device_index_); if (query_length <= max_query_length_) { return; } if (max_query_length_ > 0) { OF_CUDA_CHECK(cudaFree(encoding_buffer_)); } OF_CUDA_CHECK(cudaMalloc(&encoding_buffer_, query_length * sizeof(uint64_t))); max_query_length_ = query_length; } CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kFull; } void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) override; void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) override; void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint8_t* mask) override; void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override; void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update, const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override; void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index, uint32_t* n_dumped, void* keys, void* values) override; void ClearDirtyFlags() override; void Clear() override; private: bool if_dump_dirty_; OrdinalEncoder encoder_; int device_index_; uint32_t num_elem_per_value_{}; Elem* values_; Index* encoding_buffer_{}; CacheOptions options_; uint32_t max_query_length_; }; template void CacheImpl::Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) { OF_CUDA_CHECK( cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As()->cuda_stream())); if (n_keys == 0) { return; } CHECK_LE(n_keys, max_query_length_); if (if_dump_dirty_) { encoder_.template Encode(stream, n_keys, static_cast(keys), encoding_buffer_); } else { encoder_.template Encode(stream, n_keys, static_cast(keys), encoding_buffer_); } const uint32_t values_elem_cnt = n_keys * num_elem_per_value_; RUN_CUDA_KERNEL((LookupKernel), stream, values_elem_cnt, num_elem_per_value_, values_, values_elem_cnt, static_cast(keys), encoding_buffer_, nullptr, n_missing, static_cast(missing_keys), missing_indices); } template void CacheImpl::Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) { OF_CUDA_CHECK( cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As()->cuda_stream())); if (n_keys == 0) { return; } CHECK_LE(n_keys, max_query_length_); constexpr uint32_t block_size = 128; uint32_t grid_size = (n_keys + block_size - 1) / block_size; const uint32_t 
values_elem_cnt = n_keys * num_elem_per_value_; EncodeLookupKernel <<As()->cuda_stream()>>>( num_elem_per_value_, values_, values_elem_cnt, static_cast(keys), encoding_buffer_, static_cast(values), n_missing, static_cast(missing_keys), missing_indices, encoder_.TableCapacity(), encoder_.table_keys(), encoder_.table_indices()); } template void CacheImpl::Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint8_t* mask) { if (n_keys == 0) { return; } CHECK_LE(n_keys, max_query_length_); constexpr uint32_t block_size = 128; uint32_t grid_size = (n_keys + block_size - 1) / block_size; const uint32_t values_elem_cnt = n_keys * num_elem_per_value_; EncodeLookupMaskKernel <<As()->cuda_stream()>>>( num_elem_per_value_, values_, values_elem_cnt, static_cast(keys), encoding_buffer_, static_cast(values), mask, encoder_.TableCapacity(), encoder_.table_keys(), encoder_.table_indices()); } template void CacheImpl::Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) { if (n_keys == 0) { return; } CHECK_LE(n_keys, max_query_length_); if (if_dump_dirty_) { encoder_.template Encode(stream, n_keys, static_cast(keys), encoding_buffer_); } else { encoder_.template Encode(stream, n_keys, static_cast(keys), encoding_buffer_); } const uint32_t values_elem_cnt = n_keys * num_elem_per_value_; RUN_CUDA_KERNEL((UpdateKernel), stream, values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_, static_cast(values)); } template void CacheImpl::FusedHalfUpdatePut( ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update, const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) { if (!std::is_same::value) { UNIMPLEMENTED(); } if (n_keys == 0) { return; } CHECK_LE(n_keys, max_query_length_); if (if_dump_dirty_) { encoder_.template Encode(stream, n_keys, static_cast(keys), encoding_buffer_); } else { encoder_.template Encode(stream, n_keys, static_cast(keys), encoding_buffer_); } const uint32_t values_elem_cnt = n_keys * num_elem_per_value_; RUN_CUDA_KERNEL((FusedHalfUpdateKernel), stream, values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_, static_cast(values), static_cast(update), lr, scale); } template void CacheImpl::Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index, uint32_t* n_dumped, void* keys, void* values) { if (if_dump_dirty_) { encoder_.DumpDirtyOnly(stream, start_key_index, end_key_index, n_dumped, static_cast(keys), encoding_buffer_); } else { encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast(keys), encoding_buffer_); } RUN_CUDA_KERNEL((DumpValueKernel), stream, num_elem_per_value_ * (end_key_index - start_key_index), num_elem_per_value_, n_dumped, encoding_buffer_, values_, static_cast(values)); } template void CacheImpl::ClearDirtyFlags() { encoder_.ClearDirtyFlags(); } template void CacheImpl::Clear() { encoder_.Clear(); } template std::unique_ptr DispatchValueType(const CacheOptions& options) { if (options.value_type == DataType::kFloat) { const size_t value_elem_cnt = options.value_size / sizeof(float); const size_t half_warp = 16; if (value_elem_cnt % 4 == 0 && value_elem_cnt / 4 > half_warp) { return std::unique_ptr(new CacheImpl(options)); } else if (value_elem_cnt % 2 == 0 && value_elem_cnt / 2 > half_warp) { return std::unique_ptr(new CacheImpl(options)); } else { return std::unique_ptr(new 
CacheImpl(options)); } } else if (options.value_size % sizeof(ulonglong2) == 0) { return std::unique_ptr(new CacheImpl(options)); } else if (options.value_size % sizeof(uint64_t) == 0) { return std::unique_ptr(new CacheImpl(options)); } else if (options.value_size % sizeof(uint32_t) == 0) { return std::unique_ptr(new CacheImpl(options)); } else if (options.value_size % sizeof(uint16_t) == 0) { return std::unique_ptr(new CacheImpl(options)); } else { return std::unique_ptr(new CacheImpl(options)); } } template std::unique_ptr DispatchKeyType(const CacheOptions& options) { if (options.key_size == sizeof(Key32)) { return DispatchValueType(options); } else if (options.key_size == sizeof(Key64)) { return DispatchValueType(options); } else { UNIMPLEMENTED(); return nullptr; } } std::unique_ptr DispatchIndexType(const CacheOptions& options) { const int64_t table_capacity = static_cast(options.capacity) / options.load_factor; if (table_capacity >= (1ULL << 31ULL)) { return DispatchKeyType(options); } else { return DispatchKeyType(options); } } } // namespace std::unique_ptr NewFullCache(const CacheOptions& options) { return DispatchIndexType(options); } } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/full_cache.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EMBEDDING_FULL_CACHE_H_ #define ONEFLOW_CORE_EMBEDDING_FULL_CACHE_H_ #include "oneflow/core/embedding/cache.h" #include "oneflow/core/common/data_type.h" namespace oneflow { namespace embedding { #ifdef WITH_CUDA std::unique_ptr NewFullCache(const CacheOptions& options); #endif // WITH_CUDA } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_FULL_CACHE_H_ ================================================ FILE: oneflow/core/embedding/hash_functions.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_H_
#define ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_H_

#include
#include "oneflow/core/common/data_type.h"

namespace oneflow {
namespace embedding {
namespace {

// From https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h
static const uint64_t PRIME64_1 = 0x9E3779B185EBCA87ULL;  // 0b1001111000110111011110011011000110000101111010111100101010000111
static const uint64_t PRIME64_2 = 0xC2B2AE3D27D4EB4FULL;  // 0b1100001010110010101011100011110100100111110101001110101101001111
static const uint64_t PRIME64_3 = 0x165667B19E3779F9ULL;  // 0b0001011001010110011001111011000110011110001101110111100111111001
static const uint64_t PRIME64_4 = 0x85EBCA77C2B2AE63ULL;  // 0b1000010111101011110010100111011111000010101100101010111001100011
static const uint64_t PRIME64_5 = 0x27D4EB2F165667C5ULL;  // 0b0010011111010100111010110010111100010110010101100110011111000101

#define XXH_rotl64(x, r) (((x) << (r)) | ((x) >> (64 - (r))))

OF_DEVICE_FUNC uint64_t XXH64_round(uint64_t acc, uint64_t input) {
  acc += input * PRIME64_2;
  acc = XXH_rotl64(acc, 31);
  acc *= PRIME64_1;
  return acc;
}

OF_DEVICE_FUNC uint64_t xxh64_uint64(uint64_t v, uint64_t seed) {
  uint64_t acc = seed + PRIME64_5;
  acc += sizeof(uint64_t);
  acc = acc ^ XXH64_round(0, v);
  acc = XXH_rotl64(acc, 27) * PRIME64_1;
  acc = acc + PRIME64_4;
  acc ^= (acc >> 33);
  acc = acc * PRIME64_2;
  acc = acc ^ (acc >> 29);
  acc = acc * PRIME64_3;
  acc = acc ^ (acc >> 32);
  return acc;
}

static const size_t kShardingHashSeed = 1;
static const size_t kLocalUniqueHashSeed = 2;
static const size_t kGlobalUniqueHashSeed = 3;
static const size_t kFullCacheHashSeed = 4;
static const size_t kLruCacheHashSeed = 5;

}  // namespace

struct ShardingHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kShardingHashSeed); }
  OF_DEVICE_FUNC size_t operator()(uint32_t v) { return xxh64_uint64(v, kShardingHashSeed); }
  OF_DEVICE_FUNC size_t operator()(int32_t v) {
    return xxh64_uint64(static_cast<uint64_t>(v), kShardingHashSeed);
  }
  OF_DEVICE_FUNC size_t operator()(int64_t v) {
    return xxh64_uint64(static_cast<uint64_t>(v), kShardingHashSeed);
  }
};

struct LocalUniqueHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLocalUniqueHashSeed); }
};

struct GlobalUniqueHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kGlobalUniqueHashSeed); }
};

struct FullCacheHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kFullCacheHashSeed); }
};

struct LruCacheHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLruCacheHashSeed); }
};

}  // namespace embedding
}  // namespace oneflow

#endif  // ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_H_


================================================
FILE: oneflow/core/embedding/key_value_store.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EMBEDDING_KEY_VALUE_STORE_H_
#define ONEFLOW_CORE_EMBEDDING_KEY_VALUE_STORE_H_

#include "oneflow/core/embedding/kv_iterator.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/ep/include/stream.h"

namespace oneflow {
namespace embedding {

class KeyValueStore {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KeyValueStore);
  KeyValueStore() = default;
  virtual ~KeyValueStore() = default;

  virtual uint32_t KeySize() const = 0;
  virtual uint32_t ValueSize() const = 0;
  virtual uint32_t MaxQueryLength() const = 0;
  virtual void ReserveQueryLength(uint32_t query_length) = 0;

  virtual void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
                   uint32_t* n_missing, uint32_t* missing_indices) = 0;
  virtual void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
                   uint8_t* mask) {
    UNIMPLEMENTED();
  }
  virtual void Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
                   const void* values) = 0;
  virtual void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys,
                                  const void* values, const void* update, const float* lr,
                                  float scale) {
    UNIMPLEMENTED();
  }
  virtual bool IsFusionSupported() { return false; }
  virtual bool SnapshotExists(const std::string& name) = 0;
  virtual void LoadSnapshot(const std::string& name) = 0;
  virtual void LoadSnapshot(const std::string& name,
                            const std::function<void(KVIterator* iter)>& Hook) = 0;
  virtual void SaveSnapshot(const std::string& name) = 0;
};

}  // namespace embedding
}  // namespace oneflow

#endif  // ONEFLOW_CORE_EMBEDDING_KEY_VALUE_STORE_H_


================================================
FILE: oneflow/core/embedding/key_value_store_options.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_EMBEDDING_KEY_VALUE_STORE_OPTIONS_H_ #define ONEFLOW_EMBEDDING_KEY_VALUE_STORE_OPTIONS_H_ #include "nlohmann/json.hpp" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/embedding/cache.h" namespace oneflow { namespace embedding { namespace { void ParseCacheOptions(const nlohmann::json& cache_obj, CacheOptions* cache_options) { CHECK_GT(cache_options->key_size, 0); CHECK_GT(cache_options->value_size, 0); CHECK(cache_obj.contains("policy")); CHECK(cache_obj["policy"].is_string()); std::string policy = cache_obj["policy"].get(); if (policy == "lru") { cache_options->policy = CacheOptions::Policy::kLRU; } else if (policy == "full") { cache_options->policy = CacheOptions::Policy::kFull; } else { UNIMPLEMENTED() << "Unsupported cache policy"; } int64_t capacity = 0; if (cache_obj.contains("capacity")) { CHECK(cache_obj["capacity"].is_number()); capacity = cache_obj["capacity"].get(); } if (cache_obj.contains("cache_memory_budget_mb")) { CHECK(cache_obj["cache_memory_budget_mb"].is_number()); int64_t cache_memory_budget_mb = cache_obj["cache_memory_budget_mb"].get(); if (cache_memory_budget_mb > 0) { CHECK_EQ(capacity, 0) << "when set capacity, must not set cache_memory_budget_mb"; capacity = cache_memory_budget_mb * 1024 * 1024 / cache_options->value_size; } } CHECK_GT(capacity, 0) << "capacity or cache_memory_budget_mb must be set"; // add an extra_capacity to avoid crash by uneven partition. const int64_t extra_capacity = capacity * 0.05; cache_options->capacity = capacity + (extra_capacity > 4096 ? extra_capacity : 4096); CHECK(cache_obj.contains("value_memory_kind")); CHECK(cache_obj["value_memory_kind"].is_string()); std::string value_memory_kind = cache_obj["value_memory_kind"].get(); if (value_memory_kind == "device") { cache_options->value_memory_kind = CacheOptions::MemoryKind::kDevice; } else if (value_memory_kind == "host") { cache_options->value_memory_kind = CacheOptions::MemoryKind::kHost; } else { UNIMPLEMENTED() << "Unsupported cache value_memory_kind"; } } } // namespace class KeyValueStoreOptions final { public: OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreOptions); explicit KeyValueStoreOptions(const std::string& json_serialized) { auto json_object = nlohmann::json::parse(json_serialized); CHECK(json_object.contains("key_type_size")); CHECK(json_object["key_type_size"].is_number()); key_type_size_ = json_object["key_type_size"].get(); CHECK(json_object.contains("value_type_size")); CHECK(json_object["value_type_size"].is_number()); std::string value_type_name = json_object["value_type"]; if (value_type_name == "oneflow.float" || value_type_name == "oneflow.float32") { value_type_ = DataType::kFloat; } else { UNIMPLEMENTED(); } value_type_size_ = json_object["value_type_size"].get(); CHECK(json_object.contains("parallel_num")); CHECK(json_object["parallel_num"].is_number()); const int64_t parallel_num = json_object["parallel_num"].get(); CHECK(json_object.contains("name")); CHECK(json_object["name"].is_string()); name_ = json_object["name"].get(); CHECK(json_object.contains("storage_dim")); CHECK(json_object["storage_dim"].is_number()); line_size_ = json_object["storage_dim"].get(); CHECK(json_object.contains("kv_store")); auto kv_store = json_object["kv_store"]; auto caches = kv_store["caches"]; if (caches != nlohmann::detail::value_t::null && caches.size() > 0) { CHECK(caches.is_array()); cache_options_.resize(caches.size()); for (int i = 0; i < caches.size(); ++i) { cache_options_.at(i).key_size = key_type_size_; cache_options_.at(i).value_size 
= value_type_size_ * line_size_; cache_options_.at(i).value_type = value_type_; ParseCacheOptions(caches.at(i), &cache_options_.at(i)); } } CHECK(kv_store.contains("persistent_table")); auto persistent_table = kv_store["persistent_table"]; CHECK(persistent_table.contains("path")); auto path = persistent_table["path"]; CHECK(path.is_array() || path.is_string()); if (path.is_array()) { CHECK_EQ(path.size(), parallel_num); for (int i = 0; i < path.size(); ++i) { CHECK(path.at(i).is_string()); persistent_table_paths_.push_back(path.at(i).get()); } } else { std::string root_path = path.get(); const std::string& num_rank = std::to_string(parallel_num); const int64_t rank_id_suffix_length = num_rank.size(); for (int i = 0; i < parallel_num; ++i) { const std::string& rank_id = std::to_string(i); const std::string rank_i_path = root_path + "/" + std::string(rank_id_suffix_length - rank_id.size(), '0') + rank_id + "-" + num_rank; persistent_table_paths_.push_back(rank_i_path); } } CHECK(persistent_table.contains("physical_block_size")); CHECK(persistent_table["physical_block_size"].is_number()); persistent_table_physical_block_size_ = persistent_table["physical_block_size"].get(); if (persistent_table.contains("capacity_hint")) { CHECK(persistent_table["capacity_hint"].is_number()); persistent_table_capacity_hint_ = persistent_table["capacity_hint"].get(); } else { persistent_table_capacity_hint_ = 0; } } ~KeyValueStoreOptions() = default; int64_t KeyTypeSize() const { return key_type_size_; } int64_t ValueTypeSize() const { return value_type_size_; } DataType ValueType() const { return value_type_; } const std::string& Name() const { return name_; } int64_t LineSize() const { return line_size_; } const std::vector& GetCachesOptions() const { return cache_options_; } const std::vector& PersistentTablePaths() const { return persistent_table_paths_; } int64_t PersistentTablePhysicalBlockSize() const { return persistent_table_physical_block_size_; } int64_t PersistentTableCapacityHint() const { return persistent_table_capacity_hint_; } bool IsFullCache() const { if (cache_options_.size() > 0 && cache_options_.at(0).policy == CacheOptions::Policy::kFull) { return true; } return false; } private: int64_t key_type_size_; int64_t value_type_size_; DataType value_type_; std::string name_; int64_t line_size_; std::vector persistent_table_paths_; int64_t persistent_table_physical_block_size_; int64_t persistent_table_capacity_hint_; std::vector cache_options_; }; } // namespace embedding } // namespace oneflow #endif // ONEFLOW_EMBEDDING_KEY_VALUE_STORE_OPTIONS_H_ ================================================ FILE: oneflow/core/embedding/key_value_store_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/embedding/persistent_table_key_value_store.h" #include "oneflow/core/embedding/cached_key_value_store.h" #include "oneflow/core/embedding/mock_key_value_store.h" #include "oneflow/core/embedding/cache.h" #include "oneflow/core/device/cuda_util.h" #include #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/embedding/posix_file.h" namespace oneflow { namespace embedding { namespace { #ifdef WITH_CUDA std::string CreateTempDirectory() { const char* tmp_env = getenv("TMPDIR"); const char* tmp_dir = tmp_env == nullptr ? "/tmp" : tmp_env; std::string tpl = std::string(tmp_dir) + "/test_kv_XXXXXX"; char* path = mkdtemp(const_cast(tpl.c_str())); PCHECK(path != nullptr); return std::string(path); } bool HasCudaDevice() { int device_count = 0; if (cudaGetDeviceCount(&device_count) != cudaSuccess) { return false; } if (device_count <= 0) { return false; } return true; } void TestKeyValueStore(KeyValueStore* store, size_t num_embeddings, size_t test_embeddings, size_t embedding_vec_size) { auto device = Singleton::Get()->GetDevice(DeviceType::kCUDA, 0); ep::Stream* stream = device->CreateStream(); store->SaveSnapshot("init"); uint64_t* keys = nullptr; float* values = nullptr; float* values1 = nullptr; uint64_t* keys_host = nullptr; float* values_host = nullptr; uint64_t* context = nullptr; uint32_t* n_missing = nullptr; uint32_t* host_n_missing = nullptr; uint64_t* missing_keys = nullptr; uint32_t* missing_indices = nullptr; size_t keys_size = sizeof(uint64_t) * num_embeddings; size_t values_size = sizeof(float) * embedding_vec_size * num_embeddings; size_t context_size = sizeof(uint64_t) * num_embeddings; const size_t batch_size = 128; OF_CUDA_CHECK(cudaMalloc(&keys, keys_size)); OF_CUDA_CHECK(cudaMalloc(&values, values_size)); OF_CUDA_CHECK(cudaMalloc(&values1, values_size)); OF_CUDA_CHECK(cudaMalloc(&context, context_size)); OF_CUDA_CHECK(cudaMallocHost(&keys_host, keys_size)); OF_CUDA_CHECK(cudaMallocHost(&values_host, values_size)); OF_CUDA_CHECK(cudaMallocHost(&host_n_missing, sizeof(uint32_t))); OF_CUDA_CHECK(cudaMalloc(&missing_keys, batch_size * sizeof(uint64_t))); OF_CUDA_CHECK(cudaMalloc(&missing_indices, batch_size * sizeof(uint32_t))); OF_CUDA_CHECK(cudaMalloc(&n_missing, sizeof(uint32_t))); for (size_t i = 0; i < num_embeddings; ++i) { uint64_t key = i + 1; keys_host[i] = key; for (size_t j = 0; j < embedding_vec_size; j++) { values_host[i * embedding_vec_size + j] = key; } } OF_CUDA_CHECK(cudaMemcpy(keys, keys_host, keys_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaMemcpy(values, values_host, values_size, cudaMemcpyDefault)); store->Put(stream, 0, keys, values); OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaGetLastError()); for (size_t offset = 0; offset < test_embeddings; offset += batch_size) { const size_t num_keys = std::min(batch_size, test_embeddings - offset); store->Get(stream, num_keys, keys + offset, values1 + offset * embedding_vec_size, n_missing, missing_indices); OF_CUDA_CHECK(cudaMemcpy(host_n_missing, n_missing, sizeof(uint32_t), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); ASSERT_EQ(*host_n_missing, num_keys); store->Put(stream, num_keys, keys + offset, values + offset * embedding_vec_size); } OF_CUDA_CHECK(cudaDeviceSynchronize()); store->SaveSnapshot("final"); OF_CUDA_CHECK(cudaMemset(values_host, 0, values_size)); OF_CUDA_CHECK(cudaMemset(values, 0, values_size)); for (size_t offset = 0; offset < test_embeddings; offset += batch_size) { const size_t num_keys = std::min(batch_size, 
test_embeddings - offset); store->Get(stream, num_keys, keys + offset, values + offset * embedding_vec_size, n_missing, missing_indices); OF_CUDA_CHECK(cudaMemcpy(host_n_missing, n_missing, sizeof(uint32_t), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); ASSERT_EQ(*host_n_missing, 0); } OF_CUDA_CHECK(cudaMemcpy(values_host, values, values_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); for (size_t i = 0; i < test_embeddings; ++i) { uint64_t key = keys_host[i]; for (size_t j = 0; j < embedding_vec_size; j++) { ASSERT_EQ(values_host[i * embedding_vec_size + j], key); } } store->LoadSnapshot("init"); for (size_t offset = 0; offset < test_embeddings; offset += batch_size) { const size_t num_keys = std::min(batch_size, test_embeddings - offset); store->Get(stream, num_keys, keys + offset, values1 + offset * embedding_vec_size, n_missing, missing_indices); OF_CUDA_CHECK(cudaMemcpy(host_n_missing, n_missing, sizeof(uint32_t), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); ASSERT_EQ(*host_n_missing, num_keys); } store->LoadSnapshot("final"); OF_CUDA_CHECK(cudaMemset(values_host, 0, values_size)); OF_CUDA_CHECK(cudaMemset(values, 0, values_size)); for (size_t offset = 0; offset < test_embeddings; offset += batch_size) { const size_t num_keys = std::min(batch_size, test_embeddings - offset); store->Get(stream, num_keys, keys + offset, values + offset * embedding_vec_size, n_missing, missing_indices); OF_CUDA_CHECK(cudaMemcpy(host_n_missing, n_missing, sizeof(uint32_t), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); ASSERT_EQ(*host_n_missing, 0); } OF_CUDA_CHECK(cudaMemcpy(values_host, values, values_size, cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); for (size_t i = 0; i < test_embeddings; ++i) { uint64_t key = keys_host[i]; for (size_t j = 0; j < embedding_vec_size; j++) { ASSERT_EQ(values_host[i * embedding_vec_size + j], key); } } OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaGetLastError()); OF_CUDA_CHECK(cudaFree(keys)); OF_CUDA_CHECK(cudaFree(values)); OF_CUDA_CHECK(cudaFree(values1)); OF_CUDA_CHECK(cudaFreeHost(keys_host)); OF_CUDA_CHECK(cudaFreeHost(values_host)); OF_CUDA_CHECK(cudaFreeHost(host_n_missing)); OF_CUDA_CHECK(cudaFree(n_missing)); OF_CUDA_CHECK(cudaFree(missing_keys)); OF_CUDA_CHECK(cudaFree(missing_indices)); CHECK_JUST(stream->Sync()); device->DestroyStream(stream); } TEST(PersistentTableKeyValueStore, PersistentTableKeyValueStore) { if (!HasCudaDevice()) { return; } Singleton::New(); PersistentTableKeyValueStoreOptions options{}; uint32_t value_length = 128; std::string path = CreateTempDirectory(); options.table_options.path = path; options.table_options.value_size = value_length * sizeof(float); options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64); options.table_options.physical_block_size = 512; std::unique_ptr store = NewPersistentTableKeyValueStore(options); store->ReserveQueryLength(128); TestKeyValueStore(store.get(), 1024, 1024, value_length); store.reset(); PosixFile::RecursiveDelete(path); Singleton::Delete(); } TEST(CachedKeyValueStore, LRU) { if (!HasCudaDevice()) { return; } Singleton::New(); PersistentTableKeyValueStoreOptions store_options{}; std::string path = CreateTempDirectory(); store_options.table_options.path = path; uint32_t value_length = 128; store_options.table_options.value_size = value_length * sizeof(float); store_options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64); store_options.table_options.physical_block_size = 512; 
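// Wrap the persistent-table store with a device-memory LRU cache and run the same
// read/write/snapshot checks through the combined store.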
std::unique_ptr store = NewPersistentTableKeyValueStore(store_options); CacheOptions cache_options{}; cache_options.policy = CacheOptions::Policy::kLRU; cache_options.value_memory_kind = CacheOptions::MemoryKind::kDevice; cache_options.value_size = 512; cache_options.capacity = 512; cache_options.key_size = 8; std::unique_ptr cache = NewCache(cache_options); std::unique_ptr cached_store = NewCachedKeyValueStore(std::move(store), std::move(cache)); cached_store->ReserveQueryLength(128); TestKeyValueStore(cached_store.get(), 1024, 1024, value_length); cached_store.reset(); PosixFile::RecursiveDelete(path); Singleton::Delete(); } TEST(CachedKeyValueStore, Full) { if (!HasCudaDevice()) { return; } Singleton::New(); PersistentTableKeyValueStoreOptions store_options{}; std::string path = CreateTempDirectory(); store_options.table_options.path = path; uint32_t value_length = 128; store_options.table_options.value_size = value_length * sizeof(float); store_options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64); store_options.table_options.physical_block_size = 512; std::unique_ptr store = NewPersistentTableKeyValueStore(store_options); CacheOptions cache_options{}; cache_options.policy = CacheOptions::Policy::kFull; cache_options.value_memory_kind = CacheOptions::MemoryKind::kHost; cache_options.value_size = 512; cache_options.capacity = 1024 * 2; cache_options.key_size = 8; std::unique_ptr cache = NewCache(cache_options); std::unique_ptr cached_store = NewCachedKeyValueStore(std::move(store), std::move(cache)); cached_store->ReserveQueryLength(128); TestKeyValueStore(cached_store.get(), 1024, 1024, value_length); cached_store.reset(); PosixFile::RecursiveDelete(path); Singleton::Delete(); } TEST(MockKeyValueStore, Mock) { if (!HasCudaDevice()) { return; } Singleton::New(); MockKeyValueStoreOptions store_options{}; std::string path = CreateTempDirectory(); uint32_t value_length = 128; store_options.value_size = value_length * sizeof(float); store_options.key_size = GetSizeOfDataType(DataType::kUInt64); std::unique_ptr store = NewMockKeyValueStore(store_options); store->ReserveQueryLength(128); TestKeyValueStore(store.get(), 1024, 1024, value_length); store.reset(); PosixFile::RecursiveDelete(path); Singleton::Delete(); } #endif // WITH_CUDA } // namespace } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/kv_iterator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EMBEDDING_KV_ITERATOR_H_ #define ONEFLOW_CORE_EMBEDDING_KV_ITERATOR_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { namespace embedding { class KVIterator { public: OF_DISALLOW_COPY_AND_MOVE(KVIterator); KVIterator() = default; virtual ~KVIterator() = default; virtual void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys, void* values) = 0; virtual void Reset() = 0; }; } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_KV_ITERATOR_H_ ================================================ FILE: oneflow/core/embedding/lru_cache.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Inspired by https://github.com/NVIDIA-Merlin/HugeCTR/blob/master/gpu_cache/src/nv_gpu_cache.cu #include "oneflow/core/embedding/lru_cache.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/embedding/hash_functions.cuh" #include #include #if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700)) \ && !(defined(__clang__) && defined(__CUDA__)) #include #endif namespace oneflow { namespace embedding { namespace { constexpr int kWarpSize = 32; constexpr int kNumWarpPerBlock = 4; constexpr int kBlockSize = kNumWarpPerBlock * kWarpSize; constexpr uint32_t kFullMask = 0xFFFFFFFFU; ep::CudaLaunchConfig GetLaunchConfig(uint32_t n_keys) { return ep::CudaLaunchConfig((n_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock, kWarpSize * kNumWarpPerBlock, 0); } struct ThreadContext { __device__ ThreadContext() { const uint32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; global_warp_id = global_thread_id / kWarpSize; warp_id_in_block = global_warp_id % kNumWarpPerBlock; // NOLINT num_warps = gridDim.x * kNumWarpPerBlock; // NOLINT lane_id = global_thread_id % kWarpSize; } uint32_t global_warp_id; uint32_t warp_id_in_block; uint32_t num_warps; uint32_t lane_id; }; class WarpMutexAtomicImpl { public: OF_DISALLOW_COPY_AND_MOVE(WarpMutexAtomicImpl); __device__ WarpMutexAtomicImpl() : flag_(0) {} __device__ ~WarpMutexAtomicImpl() = default; __device__ void Lock(const ThreadContext& thread_ctx) { if (thread_ctx.lane_id == 0) { while (atomicCAS(&flag_, 0, 1) != 0) ; } __threadfence(); __syncwarp(); } __device__ void Unlock(const ThreadContext& thread_ctx) { __syncwarp(); __threadfence(); if (thread_ctx.lane_id == 0) { atomicExch(&flag_, 0); } } private: int32_t flag_; }; #if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700)) \ && !(defined(__clang__) && defined(__CUDA__)) class WarpMutexSemaphoreImpl { public: OF_DISALLOW_COPY_AND_MOVE(WarpMutexSemaphoreImpl); __device__ WarpMutexSemaphoreImpl() : semaphore_(1) {} __device__ ~WarpMutexSemaphoreImpl() = default; __device__ void Lock(const ThreadContext& thread_ctx) { if (thread_ctx.lane_id == 0) { semaphore_.acquire(); } __syncwarp(); } __device__ void Unlock(const ThreadContext& thread_ctx) { __syncwarp(); 
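// Only lane 0 releases the per-set semaphore; the preceding __syncwarp() guarantees every lane
// in the warp has finished its critical-section work before the lock is dropped.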
if (thread_ctx.lane_id == 0) { semaphore_.release(); } } private: cuda::binary_semaphore semaphore_; }; #endif template struct LruCacheContext { Key* keys; Elem* lines; uint8_t* ages; void* mutex; uint64_t n_set; uint32_t line_size; CacheOptions::MemoryKind value_memory_kind; }; __global__ void InitCacheSetMutex(uint32_t n_set, void* mutex) { #if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) && defined(__CUDA__)) using WarpMutex = WarpMutexSemaphoreImpl; #else using WarpMutex = WarpMutexAtomicImpl; #endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) && // defined(__CUDA__)) const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n_set) { new (reinterpret_cast(mutex) + idx) WarpMutex; } } template void ClearLruCacheContext(LruCacheContext* ctx) { OF_CUDA_CHECK(cudaMemset(ctx->keys, 0, ctx->n_set * kWarpSize * sizeof(Key))); OF_CUDA_CHECK(cudaMemset(ctx->ages, 0, ctx->n_set * kWarpSize * sizeof(uint8_t))); InitCacheSetMutex<<<(ctx->n_set - 1 + 256) / 256, 256>>>(ctx->n_set, ctx->mutex); } template void InitLruCacheContext(const CacheOptions& options, LruCacheContext* ctx) { const size_t keys_size_per_set = kWarpSize * sizeof(Key); const uint32_t line_size = options.value_size / sizeof(Elem); const size_t lines_size_per_set = kWarpSize * line_size * sizeof(Elem); const size_t ages_size_per_set = kWarpSize * sizeof(uint8_t); int device = 0; OF_CUDA_CHECK(cudaGetDevice(&device)); int major = 0; OF_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); size_t mutex_size_per_set = 0; #if CUDA_VERSION >= 11000 && !(defined(__clang__) && defined(__CUDA__)) if (major >= 7) { #if !defined(__CUDA_ARCH__) mutex_size_per_set = sizeof(WarpMutexSemaphoreImpl); #else UNIMPLEMENTED(); #endif } else { mutex_size_per_set = sizeof(WarpMutexAtomicImpl); } #else mutex_size_per_set = sizeof(WarpMutexAtomicImpl); #endif // CUDA_VERSION >= 11000 && !(defined(__clang__) && defined(__CUDA__)) const size_t n_set = (options.capacity - 1 + kWarpSize) / kWarpSize; CHECK_GT(n_set, 0); ctx->n_set = n_set; ctx->line_size = line_size; const size_t keys_size = n_set * keys_size_per_set; OF_CUDA_CHECK(cudaMalloc(&(ctx->keys), keys_size)); const size_t lines_size = n_set * lines_size_per_set; if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) { OF_CUDA_CHECK(cudaMalloc(&(ctx->lines), lines_size)); } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) { if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) { OF_CUDA_CHECK(cudaMallocHost(&(ctx->lines), lines_size)); } else { OF_CUDA_CHECK( NumaAwareCudaMallocHost(device, reinterpret_cast(&ctx->lines), lines_size)); } } else { UNIMPLEMENTED(); } ctx->value_memory_kind = options.value_memory_kind; const size_t ages_size = n_set * ages_size_per_set; OF_CUDA_CHECK(cudaMalloc(&(ctx->ages), ages_size)); const size_t mutex_size = n_set * mutex_size_per_set; OF_CUDA_CHECK(cudaMalloc(&(ctx->mutex), mutex_size)); ClearLruCacheContext(ctx); } template void DestroyLruCacheContext(LruCacheContext* ctx) { OF_CUDA_CHECK(cudaFree(ctx->keys)); if (ctx->value_memory_kind == CacheOptions::MemoryKind::kDevice) { OF_CUDA_CHECK(cudaFree(ctx->lines)); } else if (ctx->value_memory_kind == CacheOptions::MemoryKind::kHost) { OF_CUDA_CHECK(cudaFreeHost(ctx->lines)); } else { UNIMPLEMENTED(); } OF_CUDA_CHECK(cudaFree(ctx->ages)); OF_CUDA_CHECK(cudaFree(ctx->mutex)); } template struct SetContext { #if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 
700 && !(defined(__clang__) && defined(__CUDA__)) using WarpMutex = WarpMutexSemaphoreImpl; #else using WarpMutex = WarpMutexAtomicImpl; #endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) && // defined(__CUDA__)) __device__ SetContext(const LruCacheContext& ctx, uint32_t set_id) : keys(ctx.keys + set_id * kWarpSize), mutex(reinterpret_cast(ctx.mutex) + set_id), ages(ctx.ages + set_id * kWarpSize), lines(ctx.lines + static_cast(set_id) * kWarpSize * ctx.line_size) {} __device__ int Lookup(const ThreadContext& thread_ctx, Key key) { const Key lane_key = keys[thread_ctx.lane_id]; const int lane_age = ages[thread_ctx.lane_id]; const bool lane_hit = (lane_key == key && lane_age != 0); const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit); if (hit_mask != 0) { return __ffs(static_cast(hit_mask)) - 1; } else { return -1; } } __device__ void Read(const LruCacheContext& cache_ctx, const ThreadContext& thread_ctx, int way, Elem* line) { const Elem* from_line = lines + way * cache_ctx.line_size; for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) { line[i] = from_line[i]; } } __device__ int InsertWithoutEvicting(const LruCacheContext& cache_ctx, const ThreadContext& thread_ctx, Key key) { int insert_way = -1; const Key lane_key = keys[thread_ctx.lane_id]; int lane_age = ages[thread_ctx.lane_id]; const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0); if (hit_mask != 0) { insert_way = __ffs(static_cast(hit_mask)) - 1; const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way); if (lane_age > insert_way_age) { lane_age -= 1; } else if (thread_ctx.lane_id == insert_way) { lane_age = kWarpSize; } __syncwarp(); } if (insert_way == -1) { const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0); if (valid_mask != kFullMask) { insert_way = __popc(static_cast(valid_mask)); if (lane_age > 0) { lane_age -= 1; } else if (thread_ctx.lane_id == insert_way) { lane_age = kWarpSize; keys[insert_way] = key; } __syncwarp(); } } if (insert_way != -1) { ages[thread_ctx.lane_id] = lane_age; } return insert_way; } __device__ void Evict(const LruCacheContext& cache_ctx, const ThreadContext& thread_ctx, Key key, int* way, Key* evicted_key) { const Key lane_key = keys[thread_ctx.lane_id]; int lane_age = ages[thread_ctx.lane_id]; const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1; *evicted_key = __shfl_sync(kFullMask, lane_key, insert_way); if (thread_ctx.lane_id == insert_way) { keys[insert_way] = key; lane_age = kWarpSize; } else if (lane_age > 1) { lane_age -= 1; } __syncwarp(); ages[thread_ctx.lane_id] = lane_age; *way = insert_way; } __device__ void Write(const LruCacheContext& cache_ctx, const ThreadContext& thread_ctx, int way, const Elem* line) { Elem* to_line = lines + way * cache_ctx.line_size; for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) { to_line[i] = line[i]; } } __device__ void Lock(const ThreadContext& thread_ctx) { mutex->Lock(thread_ctx); } __device__ void Unlock(const ThreadContext& thread_ctx) { mutex->Unlock(thread_ctx); } Key* keys; Elem* lines; uint8_t* ages; WarpMutex* mutex; }; template __global__ void GetKernel(LruCacheContext cache_ctx, uint32_t num_keys, const Key* keys, Elem* values, uint32_t* n_missing_keys, Key* missing_keys, uint32_t* missing_indices) { ThreadContext thread_ctx{}; __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize]; __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize]; for (uint32_t batch_offset = 
thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys; batch_offset += thread_ctx.num_warps * kWarpSize) { const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset); if (thread_ctx.lane_id < n_batch_keys) { const Key key = keys[batch_offset + thread_ctx.lane_id]; const size_t hash = LruCacheHash()(key); const uint32_t set_id = hash % cache_ctx.n_set; block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key; block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id; } __syncwarp(); uint32_t n_warp_missing = 0; Key warp_missing_key = 0; uint32_t warp_missing_index = 0; for (uint32_t i = 0; i < n_batch_keys; ++i) { const uint32_t key_idx = batch_offset + i; const Key key = block_keys[thread_ctx.warp_id_in_block][i]; const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i]; SetContext set_ctx(cache_ctx, set_id); const int way = set_ctx.Lookup(thread_ctx, key); if (way < 0) { if (thread_ctx.lane_id == n_warp_missing) { warp_missing_key = key; warp_missing_index = key_idx; } __syncwarp(); n_warp_missing += 1; } else if (!test_only) { set_ctx.Read(cache_ctx, thread_ctx, way, values + key_idx * cache_ctx.line_size); } } if (n_warp_missing > 0) { uint32_t base_missing_idx = 0; if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing_keys, n_warp_missing); } __syncwarp(); base_missing_idx = __shfl_sync(kFullMask, base_missing_idx, 0); if (thread_ctx.lane_id < n_warp_missing) { missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key; missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index; } __syncwarp(); } __syncwarp(); } } template __global__ void PutWithoutEvictingKernel(LruCacheContext cache_ctx, uint32_t num_keys, const Key* keys, const Elem* values, uint32_t* n_missing, Key* missing_keys, uint32_t* missing_indices) { ThreadContext thread_ctx{}; __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize]; __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize]; for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys; batch_offset += thread_ctx.num_warps * kWarpSize) { const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset); if (thread_ctx.lane_id < n_batch_keys) { const Key key = keys[batch_offset + thread_ctx.lane_id]; const size_t hash = LruCacheHash()(key); const uint32_t set_id = hash % cache_ctx.n_set; block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key; block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id; } __syncwarp(); uint32_t n_warp_missing = 0; Key warp_missing_key = 0; uint32_t warp_missing_index = 0; for (uint32_t i = 0; i < n_batch_keys; ++i) { const uint32_t key_idx = batch_offset + i; const Key key = block_keys[thread_ctx.warp_id_in_block][i]; const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i]; SetContext set_ctx(cache_ctx, set_id); set_ctx.Lock(thread_ctx); Key evicted_key = 0; const int insert_way = set_ctx.InsertWithoutEvicting(cache_ctx, thread_ctx, key); if (insert_way >= 0) { set_ctx.Write(cache_ctx, thread_ctx, insert_way, values + cache_ctx.line_size * key_idx); } else { if (thread_ctx.lane_id == n_warp_missing) { warp_missing_key = key; warp_missing_index = key_idx; } __syncwarp(); n_warp_missing += 1; } set_ctx.Unlock(thread_ctx); } if (n_warp_missing > 0) { uint32_t base_missing_idx = 0; if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing, n_warp_missing); } __syncwarp(); base_missing_idx = __shfl_sync(kFullMask, base_missing_idx, 0); if 
(thread_ctx.lane_id < n_warp_missing) { missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key; missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index; } __syncwarp(); } } } template __global__ void EvictKernel(LruCacheContext cache_ctx, const Key* keys, const uint32_t* indices, const Elem* values, const uint32_t* n_evict, Key* evicted_keys, Elem* evicted_values) { ThreadContext thread_ctx{}; uint32_t num_evict = *n_evict; __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize]; __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize]; for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_evict; batch_offset += thread_ctx.num_warps * kWarpSize) { const uint32_t n_batch_keys = min(kWarpSize, num_evict - batch_offset); if (thread_ctx.lane_id < n_batch_keys) { const Key key = keys[batch_offset + thread_ctx.lane_id]; const size_t hash = LruCacheHash()(key); const uint32_t set_id = hash % cache_ctx.n_set; block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key; block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id; } __syncwarp(); for (uint32_t i = 0; i < n_batch_keys; ++i) { const uint32_t key_idx = batch_offset + i; const Key key = block_keys[thread_ctx.warp_id_in_block][i]; const uint32_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i]; SetContext set_ctx(cache_ctx, set_id); set_ctx.Lock(thread_ctx); int evicted_way = -1; Key evicted_key = 0; set_ctx.Evict(cache_ctx, thread_ctx, key, &evicted_way, &evicted_key); if (thread_ctx.lane_id == 0) { evicted_keys[key_idx] = evicted_key; } __syncwarp(); set_ctx.Read(cache_ctx, thread_ctx, evicted_way, evicted_values + cache_ctx.line_size * key_idx); set_ctx.Write(cache_ctx, thread_ctx, evicted_way, values + cache_ctx.line_size * indices[key_idx]); set_ctx.Unlock(thread_ctx); } } } template __global__ void DumpKernel(LruCacheContext cache_ctx, size_t start_key_index, size_t end_key_index, uint32_t* n_dumped, Key* keys, Elem* values) { ThreadContext thread_ctx{}; __shared__ Key warp_keys[kNumWarpPerBlock][kWarpSize]; __shared__ uint8_t warp_ages[kNumWarpPerBlock][kWarpSize]; for (uint32_t warp_start_key_index = start_key_index + thread_ctx.global_warp_id * kWarpSize; warp_start_key_index < end_key_index; warp_start_key_index += thread_ctx.num_warps * kWarpSize) { Key lane_key = 0; uint8_t lane_age = 0; if (warp_start_key_index + thread_ctx.lane_id < end_key_index) { lane_key = cache_ctx.keys[warp_start_key_index + thread_ctx.lane_id]; lane_age = cache_ctx.ages[warp_start_key_index + thread_ctx.lane_id]; } __syncwarp(); warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key; warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age; const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0)); if (key_count == 0) { continue; } uint32_t offset = 0; if (thread_ctx.lane_id == 0) { offset = atomicAdd(n_dumped, key_count); } offset = __shfl_sync(kFullMask, offset, 0); __syncwarp(); for (uint32_t i = 0; i < kWarpSize; ++i) { const Key key = warp_keys[thread_ctx.warp_id_in_block][i]; const Key age = warp_ages[thread_ctx.warp_id_in_block][i]; if (age == 0) { continue; } if (thread_ctx.lane_id == 0) { keys[offset] = key; } __syncwarp(); for (uint32_t j = thread_ctx.lane_id; j < cache_ctx.line_size; j += kWarpSize) { values[offset * cache_ctx.line_size + j] = cache_ctx .lines[static_cast(warp_start_key_index + i) * cache_ctx.line_size + j]; } __syncwarp(); offset += 1; } } } template class LruCache : public Cache { 
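  // Set-associative LRU cache backed by the kernels above. The table holds ctx_.n_set sets of
  // kWarpSize ways; a key is hashed with LruCacheHash and mapped to one set, and a full warp
  // cooperatively probes that set's ways. Per-way `ages` encode recency (0 = empty, 1 = least
  // recently used, kWarpSize = most recently used) and a per-set WarpMutex serializes writers.
  // Test() and Get() record keys that miss; Put() first fills free ways
  // (PutWithoutEvictingKernel) and then evicts the least recently used way for the keys that
  // still did not fit (EvictKernel), returning the evicted key/value pairs; Dump() scans a
  // range of slots and copies out every valid entry.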
public: OF_DISALLOW_COPY_AND_MOVE(LruCache); explicit LruCache(const CacheOptions& options) : device_index_{}, max_query_length_(0), query_indices_buffer_(nullptr), query_keys_buffer_(nullptr), value_type_(options.value_type) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); InitLruCacheContext(options, &ctx_); } ~LruCache() override { CudaCurrentDeviceGuard guard(device_index_); if (max_query_length_ != 0) { OF_CUDA_CHECK(cudaFree(query_indices_buffer_)); OF_CUDA_CHECK(cudaFree(query_keys_buffer_)); } DestroyLruCacheContext(&ctx_); } uint32_t KeySize() const override { return sizeof(Key); } uint32_t ValueSize() const override { return sizeof(Elem) * ctx_.line_size; } DataType ValueType() const override { return value_type_; } uint64_t Capacity() const override { return ctx_.n_set * kWarpSize; } uint32_t MaxQueryLength() const override { return max_query_length_; } void ReserveQueryLength(uint32_t query_length) override { CudaCurrentDeviceGuard guard(device_index_); if (query_length < max_query_length_) { return; } if (max_query_length_ != 0) { OF_CUDA_CHECK(cudaFree(query_indices_buffer_)); OF_CUDA_CHECK(cudaFree(query_keys_buffer_)); } OF_CUDA_CHECK(cudaMalloc(&query_indices_buffer_, query_length * sizeof(uint32_t))); OF_CUDA_CHECK(cudaMalloc(&query_keys_buffer_, query_length * sizeof(Key))); max_query_length_ = query_length; } CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kLRU; } void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) override { CHECK_LE(n_keys, max_query_length_); auto cuda_stream = stream->As(); OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream())); if (n_keys == 0) { return; } cuda_stream->LaunchKernel(GetKernel, GetLaunchConfig(n_keys), ctx_, n_keys, static_cast(keys), nullptr, n_missing, static_cast(missing_keys), missing_indices); } using Cache::Get; void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) override { CHECK_LE(n_keys, max_query_length_); auto cuda_stream = stream->As(); OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream())); if (n_keys == 0) { return; } cuda_stream->LaunchKernel(GetKernel, GetLaunchConfig(n_keys), ctx_, n_keys, static_cast(keys), static_cast(values), n_missing, static_cast(missing_keys), missing_indices); } void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override { CHECK_LE(n_keys, max_query_length_); auto cuda_stream = stream->As(); OF_CUDA_CHECK(cudaMemsetAsync(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream())); if (n_keys == 0) { return; } cuda_stream->LaunchKernel(PutWithoutEvictingKernel, GetLaunchConfig(n_keys), ctx_, n_keys, static_cast(keys), static_cast(values), n_evicted, query_keys_buffer_, query_indices_buffer_); cuda_stream->LaunchKernel(EvictKernel, GetLaunchConfig(n_keys), ctx_, query_keys_buffer_, query_indices_buffer_, static_cast(values), n_evicted, static_cast(evicted_keys), static_cast(evicted_values)); } void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index, uint32_t* n_dumped, void* keys, void* values) override { auto cuda_stream = stream->As(); OF_CUDA_CHECK(cudaMemsetAsync(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream())); const uint64_t max_dump_keys = end_key_index - start_key_index; 
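    // The grid size below is derived from max_dump_keys; DumpKernel itself iterates with a
    // grid-stride loop over warps, each warp scanning kWarpSize consecutive slots of
    // [start_key_index, end_key_index) and appending valid entries through an atomicAdd on
    // n_dumped, so the exact grid size only affects occupancy, not correctness.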
cuda_stream->LaunchKernel( DumpKernel, ep::CudaLaunchConfig((max_dump_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock, kBlockSize, 0), ctx_, start_key_index, end_key_index, n_dumped, static_cast(keys), static_cast(values)); } void ClearDirtyFlags() override { // do nothing. return; } void Clear() override { ClearLruCacheContext(&ctx_); } private: int device_index_; uint32_t max_query_length_; LruCacheContext ctx_; uint32_t* query_indices_buffer_; Key* query_keys_buffer_; DataType value_type_; }; template std::unique_ptr DispatchValueType(const CacheOptions& options) { if (options.value_size % sizeof(ulonglong2) == 0) { return std::unique_ptr(new LruCache(options)); } else if (options.value_size % sizeof(uint64_t) == 0) { return std::unique_ptr(new LruCache(options)); } else if (options.value_size % sizeof(uint32_t) == 0) { return std::unique_ptr(new LruCache(options)); } else if (options.value_size % sizeof(uint16_t) == 0) { return std::unique_ptr(new LruCache(options)); } else { return std::unique_ptr(new LruCache(options)); } } std::unique_ptr DispatchKeyType(const CacheOptions& options) { if (options.key_size == sizeof(uint32_t)) { return DispatchValueType(options); } else if (options.key_size == sizeof(uint64_t)) { return DispatchValueType(options); } else { UNIMPLEMENTED(); return nullptr; } } } // namespace std::unique_ptr NewLruCache(const CacheOptions& options) { return DispatchKeyType(options); } } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/lru_cache.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EMBEDDING_LRU_CACHE_H_ #define ONEFLOW_CORE_EMBEDDING_LRU_CACHE_H_ #include "oneflow/core/embedding/cache.h" namespace oneflow { namespace embedding { std::unique_ptr NewLruCache(const CacheOptions& options); } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_LRU_CACHE_H_ ================================================ FILE: oneflow/core/embedding/mock_key_value_store.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/embedding/mock_key_value_store.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace embedding { namespace { template class IteratorImpl : public KVIterator { public: OF_DISALLOW_COPY_AND_MOVE(IteratorImpl); IteratorImpl(HashMap* store, uint32_t key_size, uint32_t value_size, uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer, uint32_t* host_num_buffer) : store_(store), pos_(store->begin()), key_size_(key_size), value_size_(value_size), max_query_length_(max_query_length), host_keys_buffer_(host_keys_buffer), host_values_buffer_(host_values_buffer), host_num_buffer_(host_num_buffer) {} ~IteratorImpl() override = default; void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys, void* values) override { CHECK_LE(n_request, max_query_length_); auto cuda_stream = stream->As(); CHECK_JUST(cuda_stream->Sync()); *host_num_buffer_ = 0; while (*host_num_buffer_ < n_request && pos_ != store_->end()) { reinterpret_cast(host_keys_buffer_)[*host_num_buffer_] = pos_->first; std::memcpy(reinterpret_cast(host_values_buffer_) + *host_num_buffer_ * value_size_, pos_->second.data(), value_size_); } OF_CUDA_CHECK(cudaMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); const uint32_t num_keys = *host_num_buffer_; if (num_keys != 0) { OF_CUDA_CHECK(cudaMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_, cudaMemcpyDefault, cuda_stream->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(values, host_values_buffer_, num_keys * value_size_, cudaMemcpyDefault, cuda_stream->cuda_stream())); } } void Reset() override { pos_ = store_->begin(); } private: HashMap* store_; typename HashMap::iterator pos_; uint32_t key_size_; uint32_t value_size_; uint32_t max_query_length_; void* host_keys_buffer_; void* host_values_buffer_; uint32_t* host_num_buffer_; }; template class KeyValueStoreImpl : public KeyValueStore { public: OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl); explicit KeyValueStoreImpl(const MockKeyValueStoreOptions& options) : device_index_(-1), max_query_length_(0) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); key_size_ = options.key_size; value_size_ = options.value_size; OF_CUDA_CHECK(NumaAwareCudaMallocHost( device_index_, reinterpret_cast(&host_query_keys_), key_size_ * max_query_length_)); OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&host_query_values_), value_size_ * max_query_length_)); OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&host_n_missing_), sizeof(uint32_t))); OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&host_missing_indices_), sizeof(uint32_t) * max_query_length_)); } ~KeyValueStoreImpl() { CudaCurrentDeviceGuard guard(device_index_); if (max_query_length_ != 0) { OF_CUDA_CHECK(cudaFreeHost(host_query_keys_)); OF_CUDA_CHECK(cudaFreeHost(host_query_values_)); OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_)); } OF_CUDA_CHECK(cudaFreeHost(host_n_missing_)); } uint32_t KeySize() const override { return key_size_; } uint32_t ValueSize() const override { return value_size_; } uint32_t MaxQueryLength() const override { return max_query_length_; } void ReserveQueryLength(uint32_t query_length) override { CudaCurrentDeviceGuard guard(device_index_); if (query_length <= max_query_length_) { return; } if (max_query_length_ != 0) { OF_CUDA_CHECK(cudaFreeHost(host_query_keys_)); OF_CUDA_CHECK(cudaFreeHost(host_query_values_)); 
OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_)); } OF_CUDA_CHECK(NumaAwareCudaMallocHost( device_index_, reinterpret_cast(&host_query_keys_), key_size_ * query_length)); OF_CUDA_CHECK(NumaAwareCudaMallocHost( device_index_, reinterpret_cast(&host_query_values_), value_size_ * query_length)); OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&host_missing_indices_), sizeof(uint32_t) * query_length)); max_query_length_ = query_length; } using KeyValueStore::Get; void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) override; void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override; bool SnapshotExists(const std::string& name) override; void LoadSnapshot(const std::string& name) override; void LoadSnapshot(const std::string& name, const std::function& Hook) override; void SaveSnapshot(const std::string& name) override; private: int device_index_; uint32_t max_query_length_; uint32_t key_size_; uint32_t value_size_; Key* host_query_keys_{}; uint8_t* host_query_values_{}; uint32_t* host_n_missing_{}; uint32_t* host_missing_indices_{}; HashMap store_; HashMap> snapshots_; std::mutex mutex_; }; template void KeyValueStoreImpl::Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) { std::lock_guard lock(mutex_); auto cuda_stream = stream->As(); CHECK_LE(num_keys, max_query_length_); if (num_keys == 0) { OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As()->cuda_stream())); return; } OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(cuda_stream->Sync()); *host_n_missing_ = 0; for (uint32_t i = 0; i < num_keys; ++i) { auto it = store_.find(host_query_keys_[i]); if (it != store_.end()) { std::memcpy(host_query_values_ + i * value_size_, it->second.data(), value_size_); } else { host_missing_indices_[*host_n_missing_] = i; *host_n_missing_ += 1; } } OF_CUDA_CHECK(cudaMemcpyAsync(values, host_query_values_, num_keys * value_size_, cudaMemcpyDefault, cuda_stream->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(missing_indices, host_missing_indices_, (*host_n_missing_) * sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); } template void KeyValueStoreImpl::Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) { std::lock_guard lock(mutex_); auto cuda_stream = stream->As(); CHECK_LE(num_keys, max_query_length_); if (num_keys == 0) { return; } OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault, cuda_stream->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(host_query_values_, values, value_size_ * num_keys, cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(cuda_stream->Sync()); for (uint32_t i = 0; i < num_keys; ++i) { store_[host_query_keys_[i]] = std::string( reinterpret_cast(host_query_values_) + i * value_size_, value_size_); } } template bool KeyValueStoreImpl::SnapshotExists(const std::string& name) { return snapshots_.find(name) != snapshots_.end(); } template void KeyValueStoreImpl::LoadSnapshot(const std::string& name) { CudaCurrentDeviceGuard guard(device_index_); LoadSnapshot(name, nullptr); } template void KeyValueStoreImpl::LoadSnapshot(const std::string& name, const 
std::function& Hook) { CudaCurrentDeviceGuard guard(device_index_); store_ = snapshots_[name]; if (Hook) { IteratorImpl iterator(&store_, KeySize(), ValueSize(), max_query_length_, host_query_keys_, host_query_values_, host_n_missing_); Hook(&iterator); } } template void KeyValueStoreImpl::SaveSnapshot(const std::string& name) { CudaCurrentDeviceGuard guard(device_index_); snapshots_[name] = store_; } } // namespace std::unique_ptr NewMockKeyValueStore(const MockKeyValueStoreOptions& options) { if (options.key_size == sizeof(uint64_t)) { return std::unique_ptr(new KeyValueStoreImpl(options)); } else if (options.key_size == sizeof(uint32_t)) { return std::unique_ptr(new KeyValueStoreImpl(options)); } else { UNIMPLEMENTED(); return nullptr; } } } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/mock_key_value_store.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EMBEDDING_MOCK_KEY_VALUE_STORE_H_ #define ONEFLOW_CORE_EMBEDDING_MOCK_KEY_VALUE_STORE_H_ #include "oneflow/core/embedding/key_value_store.h" namespace oneflow { namespace embedding { #ifdef WITH_CUDA struct MockKeyValueStoreOptions { uint32_t key_size = 0; uint32_t value_size = 0; }; std::unique_ptr NewMockKeyValueStore(const MockKeyValueStoreOptions& options); #endif // WITH_CUDA } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_MOCK_KEY_VALUE_STORE_H_ ================================================ FILE: oneflow/core/embedding/persistent_table.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/embedding/persistent_table.h" #include "oneflow/core/common/util.h" #include "oneflow/core/embedding/hash_functions.cuh" #ifdef __linux__ #include "oneflow/core/common/channel.h" #include "oneflow/core/embedding/posix_file.h" #include "oneflow/core/common/blocking_counter.h" #include #include #include #include #include #include #include #endif // __linux__ namespace oneflow { namespace embedding { #ifdef __linux__ namespace { constexpr uint32_t kDefaultNumWorkerThreads = 4; constexpr uint32_t kRingQueueDepth = 128; constexpr uint32_t kRingSubmitBatch = 32; constexpr uint32_t kAioQueueDepth = 128; constexpr uint32_t kChunkNameSuffixLength = 12; constexpr char const* kKeyFileNamePrefix = "key-"; constexpr char const* kIndexFileNamePrefix = "index-"; constexpr char const* kValueFileNamePrefix = "value-"; constexpr char const* kLockFileName = "LOCK"; constexpr char const* kKeySizeFileName = "KEY_SIZE"; constexpr char const* kValueSizeFileName = "VALUE_SIZE"; constexpr char const* kPhysicalBlockSizeFileName = "PHYSICAL_BLOCK_SIZE"; constexpr char const* kNumLogicalBlocksPerChunkFileName = "NUM_LOGICAL_BLOCKS_PER_CHUNK"; constexpr char const* kKeysDirName = "keys"; constexpr char const* kValuesDirName = "values"; constexpr char const* kSnapshotsDirName = "snapshots"; constexpr char const* kSnapshotListFileName = "LIST"; constexpr size_t kParallelForStride = 256; template T* BytesOffset(T* ptr, size_t bytes) { return reinterpret_cast( const_cast((reinterpret_cast(ptr) + bytes))); } void MemcpyOffset(void* dst, size_t dst_off, const void* src, size_t src_off, size_t n) { std::memcpy(BytesOffset(dst, dst_off), BytesOffset(src, src_off), n); } void InitOrCheckMetaValue(const std::string& pathname, int64_t expected, bool init) { bool exists = PosixFile::FileExists(pathname); if (init) { CHECK(!exists) << pathname; std::ofstream ofs(pathname); ofs << expected << std::endl; } else { CHECK(exists); std::ifstream ifs(pathname); int64_t value = 0; ifs >> value; if (value != expected) { LOG(FATAL) << "Check failed: " << pathname; } } } std::string GetChunkName(uint64_t chunk_id) { const std::string chunk_name_wo_leading_zero = std::to_string(chunk_id); CHECK_LE(chunk_name_wo_leading_zero.size(), kChunkNameSuffixLength); return std::string(kChunkNameSuffixLength - chunk_name_wo_leading_zero.size(), '0') + chunk_name_wo_leading_zero; } uint64_t GetChunkId(const std::string& chunk_name) { size_t pos = 0; const uint64_t chunk_id = std::stoull(chunk_name, &pos, 10); CHECK_EQ(pos, kChunkNameSuffixLength); return chunk_id; } uint64_t GetChunkId(const std::string& filename, const std::string& prefix) { CHECK_EQ(filename.compare(0, prefix.size(), prefix), 0); return GetChunkId(filename.substr(prefix.size())); } void ListChunkFiles(const std::string& base, const std::string& prefix, std::unordered_map* chunks) { DIR* dir = opendir(base.c_str()); PCHECK(dir != nullptr); struct dirent* ent = nullptr; while ((ent = readdir(dir)) != nullptr) { if (strlen(ent->d_name) != prefix.size() + kChunkNameSuffixLength) { continue; } if (strncmp(ent->d_name, prefix.c_str(), prefix.size()) != 0) { continue; } const uint64_t chunk_id = GetChunkId(ent->d_name + prefix.size()); CHECK(chunks->emplace(chunk_id, PosixFile::JoinPath(base, ent->d_name)).second); } PCHECK(closedir(dir) == 0); } uint32_t GetLogicalBlockSize(uint32_t physical_block_size, uint32_t value_size) { return physical_block_size >= value_size ? 
physical_block_size : RoundUp(value_size, physical_block_size); } class AlignedBuffer final { public: OF_DISALLOW_COPY_AND_MOVE(AlignedBuffer); explicit AlignedBuffer(size_t alignment) : alignment_(alignment), size_(0) {} ~AlignedBuffer() = default; void Resize(size_t new_size) { if (new_size > size_) { ptr_.reset(static_cast(aligned_alloc(alignment_, new_size))); size_ = new_size; } } void* ptr() { return ptr_.get(); } private: size_t alignment_; size_t size_; std::unique_ptr ptr_; }; template class ChunkIteratorImpl : public PersistentTable::Iterator { public: OF_DISALLOW_COPY_AND_MOVE(ChunkIteratorImpl); ChunkIteratorImpl(uint32_t value_size, uint32_t logical_block_size, uint32_t num_values_per_block, uint64_t num_values_per_chunk, uint64_t chunk_id, uint64_t n, const Key* chunk_keys, const uint64_t* chunk_indices, const void* chunk_values) : pos_(0), value_size_(value_size), logical_block_size_(logical_block_size), num_values_per_block_(num_values_per_block), num_values_per_chunk_(num_values_per_chunk), n_(n), chunk_keys_(chunk_keys), chunk_indices_(chunk_indices), chunk_values_(chunk_values), chunk_index_offset_(chunk_id * num_values_per_chunk_) {} ~ChunkIteratorImpl() override = default; void Next(uint32_t num_keys, uint32_t* return_keys, void* keys, void* values) override { uint32_t count = 0; while (count < num_keys && pos_ != n_) { const uint64_t index_in_chunk = chunk_indices_[pos_] - chunk_index_offset_; static_cast(keys)[count] = chunk_keys_[index_in_chunk]; const uint64_t block_in_chunk = index_in_chunk / num_values_per_block_; const uint32_t index_in_block = index_in_chunk - block_in_chunk * num_values_per_block_; const uint32_t value_offset = block_in_chunk * logical_block_size_ + index_in_block * value_size_; std::memcpy(static_cast(values) + count * value_size_, static_cast(chunk_values_) + value_offset, value_size_); count++; pos_++; } *return_keys = count; } void Reset() override { pos_ = 0; } private: uint64_t pos_; uint32_t value_size_; uint32_t logical_block_size_; uint32_t num_values_per_block_; uint64_t num_values_per_chunk_; uint64_t n_; const Key* chunk_keys_; const uint64_t* chunk_indices_; const void* chunk_values_; uint64_t chunk_index_offset_; }; class AioEngine final { public: OF_DISALLOW_COPY_AND_MOVE(AioEngine); AioEngine() : ctx_{}, num_readings_(0) { PCHECK(syscall(__NR_io_setup, kAioQueueDepth, &ctx_) >= 0); cbs_.resize(kAioQueueDepth); cbs_ptr_.resize(kAioQueueDepth); for (uint32_t i = 0; i < kAioQueueDepth; ++i) { cbs_ptr_[i] = &cbs_[i]; } events_.resize(kAioQueueDepth); } ~AioEngine() { WaitUntilDone(); PCHECK(syscall(__NR_io_destroy, ctx_) >= 0); } void AsyncPread(int fd, void* buf, size_t count, off_t offset) { if (num_readings_ == kAioQueueDepth) { WaitUntilDone(); } struct iocb* cb = &cbs_.at(num_readings_); cb->aio_fildes = fd; cb->aio_lio_opcode = IOCB_CMD_PREAD; cb->aio_reqprio = 0; cb->aio_buf = reinterpret_cast(buf); cb->aio_nbytes = count; cb->aio_offset = offset; const long nr = 1; PCHECK(syscall(__NR_io_submit, ctx_, nr, &cbs_ptr_.at(num_readings_)) >= 0); num_readings_ += 1; } void WaitUntilDone() { if (num_readings_ != 0) { PCHECK(syscall(__NR_io_getevents, ctx_, num_readings_, num_readings_, events_.data(), nullptr) >= 0); for (long i = 0; i < num_readings_; ++i) { CHECK_GT(events_.at(i).res, 0); } num_readings_ = 0; } } private: aio_context_t ctx_; long num_readings_; std::vector cbs_; std::vector cbs_ptr_; std::vector events_; }; constexpr size_t kCacheLineSize = 64; template using IoTask = std::function; template using ForRange = 
std::function; template class Worker final { public: OF_DISALLOW_COPY_AND_MOVE(Worker); Worker() { thread_ = std::thread(&Worker::PullTask, this); } ~Worker() { Shutdown(); thread_.join(); } void Schedule(IoTask task) { tasks_.Send(std::move(task)); } void Shutdown() { tasks_.Close(); } private: void PullTask() { while (true) { IoTask task; const ChannelStatus status = tasks_.Receive(&task); if (status == ChannelStatus::kChannelStatusErrorClosed) { break; } CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess); task(&engine_); } } Channel> tasks_; Engine engine_; std::thread thread_; }; template class SnapshotIteratorImpl; template class PersistentTableImpl : public PersistentTable { public: OF_DISALLOW_COPY_AND_MOVE(PersistentTableImpl); explicit PersistentTableImpl(const PersistentTableOptions& options); ~PersistentTableImpl() override; uint32_t KeySize() const override { return key_size_; } uint32_t ValueSize() const override { return value_size_; } uint32_t LogicalBlockSize() const override; void GetBlocks(uint32_t num_keys, const void* keys, void* blocks, uint32_t* offsets) override; void Get(uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) override; void PutBlocks(uint32_t num_keys, const void* keys, const void* blocks) override; void Put(uint32_t num_keys, const void* keys, const void* values) override; bool SnapshotExists(const std::string& name) override; void LoadSnapshot(const std::string& name) override; void LoadSnapshot(const std::string& name, const std::function& Hook) override; void SaveSnapshot(const std::string& name) override; Iterator* ReadSnapshot(const std::string& name) override; private: friend class SnapshotIteratorImpl; std::string KeyFilePath(uint64_t chunk_id) const; std::string ValueFilePath(uint64_t chunk_id) const; std::string IndexFilePath(const std::string& name, uint64_t chunk_id) const; std::string SnapshotDirPath(const std::string& name) const; std::string SnapshotListFilePath(const std::string& name) const; void LoadSnapshotImpl(const std::string& name); void SaveSnapshotImpl(const std::string& name); void ParallelFor(size_t total, const ForRange& for_range); std::string root_dir_; std::string keys_dir_; std::string values_dir_; std::string snapshots_dir_; uint32_t key_size_; uint32_t value_size_; uint64_t num_logical_blocks_per_chunk_; uint64_t num_values_per_chunk_; uint32_t num_values_per_block_; uint32_t physical_block_size_; uint32_t logical_block_size_; std::vector>> workers_; std::vector offsets_buffer_; AlignedBuffer blocks_buffer_; std::recursive_mutex mutex_; uint64_t physical_table_size_; robin_hood::unordered_flat_map row_id_mapping_; std::vector value_files_; PosixFile writable_key_file_; uint64_t writable_key_file_chunk_id_; PosixFileLockGuard lock_; bool read_only_; }; template PersistentTableImpl::PersistentTableImpl(const PersistentTableOptions& options) : root_dir_(options.path), key_size_(options.key_size), value_size_(options.value_size), physical_block_size_(options.physical_block_size), logical_block_size_(GetLogicalBlockSize(options.physical_block_size, value_size_)), blocks_buffer_(options.physical_block_size), writable_key_file_chunk_id_(-1), read_only_(options.read_only) { const uint64_t capacity_hint = ParseIntegerFromEnv( "ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_CAPACITY_HINT", options.capacity_hint); if (capacity_hint > 0) { row_id_mapping_.reserve(capacity_hint); } PosixFile::RecursiveCreateDirectory(options.path, 0755); const std::string lock_filename = 
PosixFile::JoinPath(options.path, kLockFileName); const bool init = !PosixFile::FileExists(lock_filename); if (read_only_) { CHECK(!init) << "The table must be initialized in read only mode"; } else { lock_ = PosixFileLockGuard(PosixFile(lock_filename, O_CREAT | O_RDWR, 0644)); } const uint64_t target_chunk_size = options.target_chunk_size_mb * 1024 * 1024; CHECK_GE(target_chunk_size, logical_block_size_); num_logical_blocks_per_chunk_ = target_chunk_size / logical_block_size_, num_values_per_block_ = logical_block_size_ / value_size_; num_values_per_chunk_ = num_values_per_block_ * num_logical_blocks_per_chunk_; InitOrCheckMetaValue(PosixFile::JoinPath(options.path, kKeySizeFileName), key_size_, init); InitOrCheckMetaValue(PosixFile::JoinPath(options.path, kValueSizeFileName), value_size_, init); InitOrCheckMetaValue(PosixFile::JoinPath(options.path, kPhysicalBlockSizeFileName), options.physical_block_size, init); InitOrCheckMetaValue(PosixFile::JoinPath(options.path, kNumLogicalBlocksPerChunkFileName), num_logical_blocks_per_chunk_, init); keys_dir_ = PosixFile::JoinPath(options.path, kKeysDirName); values_dir_ = PosixFile::JoinPath(options.path, kValuesDirName); snapshots_dir_ = PosixFile::JoinPath(options.path, kSnapshotsDirName); if (init) { PosixFile::RecursiveCreateDirectory(keys_dir_, 0755); PosixFile::RecursiveCreateDirectory(values_dir_, 0755); } const uint32_t num_workers = ParseIntegerFromEnv( "ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_NUM_WORKERS", kDefaultNumWorkerThreads); workers_.resize(num_workers); for (uint32_t tid = 0; tid < workers_.size(); ++tid) { workers_.at(tid).reset(new Worker); } std::unordered_map chunks; ListChunkFiles(values_dir_, kValueFileNamePrefix, &chunks); for (auto& chunk : chunks) { if (value_files_.size() <= chunk.first) { value_files_.resize(chunk.first + 1); } CHECK_EQ(value_files_.at(chunk.first).fd(), -1); const int flags = read_only_ ? 
(O_RDONLY | O_DIRECT) : (O_RDWR | O_DIRECT); PosixFile value_file(chunk.second, flags, 0644); value_files_.at(chunk.first) = std::move(value_file); } if (!value_files_.empty()) { physical_table_size_ = ((value_files_.size() - 1) * num_logical_blocks_per_chunk_ + value_files_.back().Size() / logical_block_size_) * num_values_per_block_; } else { physical_table_size_ = 0; } } template PersistentTableImpl::~PersistentTableImpl() { for (uint32_t tid = 0; tid < workers_.size(); ++tid) { workers_.at(tid)->Shutdown(); } } template uint32_t PersistentTableImpl::LogicalBlockSize() const { return logical_block_size_; } template void PersistentTableImpl::GetBlocks(uint32_t num_keys, const void* keys, void* blocks, uint32_t* offsets) { std::lock_guard lock(mutex_); ParallelFor(num_keys, [&](Engine* engine, size_t start, size_t end) { for (uint64_t i = start; i < end; ++i) { const Key key = static_cast(keys)[i]; auto it = row_id_mapping_.find(key); if (it == row_id_mapping_.end()) { offsets[i] = logical_block_size_; } else { const uint64_t id = it->second; const uint64_t block_id = id / num_values_per_block_; const uint32_t id_in_block = id - block_id * num_values_per_block_; const uint32_t offset_in_block = id_in_block * value_size_; const uint64_t chunk_id = block_id / num_logical_blocks_per_chunk_; const uint64_t block_in_chunk = block_id - chunk_id * num_logical_blocks_per_chunk_; const uint64_t block_offset = block_in_chunk * logical_block_size_; PosixFile& file = value_files_.at(chunk_id); offsets[i] = offset_in_block; engine->AsyncPread(file.fd(), BytesOffset(blocks, i * logical_block_size_), logical_block_size_, block_offset); } } }); } template void PersistentTableImpl::Get(uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) { std::lock_guard lock(mutex_); offsets_buffer_.resize(num_keys); void* blocks_ptr = nullptr; if (value_size_ == logical_block_size_ && reinterpret_cast(values) % physical_block_size_ == 0) { blocks_ptr = values; } else { blocks_buffer_.Resize(num_keys * logical_block_size_); blocks_ptr = blocks_buffer_.ptr(); } GetBlocks(num_keys, keys, blocks_ptr, offsets_buffer_.data()); uint32_t missing_count = 0; for (uint32_t i = 0; i < num_keys; ++i) { if (offsets_buffer_.at(i) == logical_block_size_) { missing_indices[missing_count] = i; missing_count += 1; } else { if (value_size_ != logical_block_size_) { MemcpyOffset(values, i * value_size_, blocks_ptr, (i * logical_block_size_) + offsets_buffer_[i], value_size_); } } } *n_missing = missing_count; } template void PersistentTableImpl::PutBlocks(uint32_t num_keys, const void* keys, const void* blocks) { CHECK(!read_only_); std::lock_guard lock(mutex_); const uint32_t num_blocks = RoundUp(num_keys, num_values_per_block_) / num_values_per_block_; const uint32_t num_padded_keys = num_blocks * num_values_per_block_; const uint64_t start_index = physical_table_size_; physical_table_size_ += num_padded_keys; CHECK_EQ(start_index % num_values_per_block_, 0); const uint64_t start_block_id = start_index / num_values_per_block_; uint64_t written_blocks = 0; const uint64_t block_keys_size = num_values_per_block_ * sizeof(Key); BlockingCounter bc(1); workers_.at(0)->Schedule([&](Engine*) { while (written_blocks < num_blocks) { const uint64_t batch_start_block_id = start_block_id + written_blocks; const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_; if (batch_chunk_id == value_files_.size()) { value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | 
O_RDWR | O_DIRECT, 0644); } else { CHECK_LE(batch_chunk_id, value_files_.size()); } if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) { writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644); } PosixFile& value_file = value_files_.at(batch_chunk_id); const uint64_t block_id_in_chunk = batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_; const uint64_t blocks_to_write = std::min(num_blocks - written_blocks, (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id); const uint64_t values_bytes = blocks_to_write * logical_block_size_; const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_; CHECK_LE(value_file.Size(), values_offset_in_file); value_file.Truncate(values_offset_in_file + values_bytes); PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_), values_bytes, values_offset_in_file) == values_bytes); const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size; writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size); const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_, blocks_to_write * num_values_per_block_) * sizeof(Key); PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size), keys_bytes, keys_offset_in_file) == keys_bytes); written_blocks += blocks_to_write; } bc.Decrease(); }); for (uint64_t i = 0; i < num_keys; ++i) { row_id_mapping_[static_cast(keys)[i]] = start_index + i; } bc.WaitForeverUntilCntEqualZero(); } template void PersistentTableImpl::Put(uint32_t num_keys, const void* keys, const void* values) { CHECK(!read_only_); std::lock_guard lock(mutex_); const void* blocks_ptr = nullptr; if (value_size_ == logical_block_size_ && reinterpret_cast(values) % physical_block_size_ == 0) { blocks_ptr = values; } else { const uint32_t num_blocks = RoundUp(num_keys, num_values_per_block_); blocks_buffer_.Resize(num_blocks * logical_block_size_); for (uint32_t i = 0; i < num_keys; i += num_values_per_block_) { const uint32_t block_id = i / num_values_per_block_; const uint32_t copy_size = (num_keys - i) < num_values_per_block_ ? 
(num_keys - i) * value_size_ : logical_block_size_; MemcpyOffset(blocks_buffer_.ptr(), block_id * logical_block_size_, values, i * value_size_, copy_size); } blocks_ptr = blocks_buffer_.ptr(); } PutBlocks(num_keys, keys, blocks_ptr); } template std::string PersistentTableImpl::KeyFilePath(uint64_t chunk_id) const { return PosixFile::JoinPath(keys_dir_, kKeyFileNamePrefix + GetChunkName(chunk_id)); } template std::string PersistentTableImpl::ValueFilePath(uint64_t chunk_id) const { return PosixFile::JoinPath(values_dir_, kValueFileNamePrefix + GetChunkName(chunk_id)); } template std::string PersistentTableImpl::IndexFilePath(const std::string& name, uint64_t chunk_id) const { return PosixFile::JoinPath(SnapshotDirPath(name), kIndexFileNamePrefix + GetChunkName(chunk_id)); } template std::string PersistentTableImpl::SnapshotDirPath(const std::string& name) const { return PosixFile::JoinPath(snapshots_dir_, name); } template std::string PersistentTableImpl::SnapshotListFilePath(const std::string& name) const { return PosixFile::JoinPath(SnapshotDirPath(name), kSnapshotListFileName); } template void PersistentTableImpl::LoadSnapshotImpl(const std::string& name) { std::lock_guard lock(mutex_); const std::string snapshot_base = SnapshotDirPath(name); const std::string snapshot_list = SnapshotListFilePath(name); row_id_mapping_.clear(); std::ifstream list_if(snapshot_list); std::string index_filename; while (std::getline(list_if, index_filename)) { const uint64_t chunk_id = GetChunkId(index_filename, kIndexFileNamePrefix); PosixFile index_file(PosixFile::JoinPath(snapshot_base, index_filename), O_RDONLY, 0644); const size_t index_file_size = index_file.Size(); CHECK_EQ(index_file_size % sizeof(uint64_t), 0); if (index_file_size == 0) { return; } const size_t n_entries = index_file_size / sizeof(uint64_t); PosixMappedFile mapped_index(std::move(index_file), index_file_size, PROT_READ); PosixFile key_file(KeyFilePath(chunk_id), O_RDONLY, 0644); PosixMappedFile mapped_key(std::move(key_file), key_file.Size(), PROT_READ); const uint64_t* indices = static_cast(mapped_index.ptr()); const Key* keys = static_cast(mapped_key.ptr()); const uint64_t chunk_start_index = chunk_id * num_values_per_chunk_; row_id_mapping_.reserve(row_id_mapping_.size() + n_entries); for (size_t i = 0; i < n_entries; ++i) { CHECK(row_id_mapping_.emplace(keys[indices[i] - chunk_start_index], indices[i]).second); } } } template void PersistentTableImpl::SaveSnapshotImpl(const std::string& name) { CHECK(!read_only_); std::lock_guard lock(mutex_); PosixFile::RecursiveCreateDirectory(SnapshotDirPath(name), 0755); std::ofstream list_ofs(SnapshotListFilePath(name)); if (row_id_mapping_.empty()) { return; } std::vector index_files(value_files_.size()); std::vector counters(value_files_.size()); const uint64_t max_index_file_size = num_values_per_chunk_ * sizeof(uint64_t); for (const auto& pair : row_id_mapping_) { const uint64_t chunk_id = pair.second / num_values_per_chunk_; CHECK(chunk_id < value_files_.size()); if (index_files[chunk_id].ptr() == nullptr) { PosixFile snapshot_file(IndexFilePath(name, chunk_id), O_CREAT | O_RDWR, 0644); snapshot_file.Truncate(max_index_file_size); index_files[chunk_id] = PosixMappedFile(std::move(snapshot_file), max_index_file_size, PROT_READ | PROT_WRITE); } uint64_t* indices = static_cast(index_files[chunk_id].ptr()); uint64_t& count = counters[chunk_id]; CHECK_LT(count, num_values_per_chunk_); indices[count] = pair.second; count += 1; } for (size_t i = 0; i < value_files_.size(); ++i) { const 
uint64_t count = counters[i]; if (count > 0) { index_files[i].file().Truncate(count * sizeof(uint64_t)); list_ofs << kIndexFileNamePrefix + GetChunkName(i) << std::endl; } else { CHECK(index_files[i].ptr() == nullptr); } } } template bool PersistentTableImpl::SnapshotExists(const std::string& name) { std::lock_guard lock(mutex_); return PosixFile::FileExists(SnapshotListFilePath(name)); } template void PersistentTableImpl::LoadSnapshot(const std::string& name) { LoadSnapshotImpl(name); } template void PersistentTableImpl::LoadSnapshot( const std::string& name, const std::function& Hook) { std::lock_guard lock(mutex_); int mmap_flags = MAP_SHARED; if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_SNAPSHOT_LOAD_MAP_POPULATE", true)) { mmap_flags |= MAP_POPULATE; } const std::string snapshot_base = SnapshotDirPath(name); const std::string snapshot_list = SnapshotListFilePath(name); row_id_mapping_.clear(); std::ifstream list_if(snapshot_list); std::string index_filename; while (std::getline(list_if, index_filename)) { const uint64_t chunk_id = GetChunkId(index_filename, kIndexFileNamePrefix); PosixFile index_file(PosixFile::JoinPath(snapshot_base, index_filename), O_RDONLY, 0644); const size_t index_file_size = index_file.Size(); CHECK_EQ(index_file_size % sizeof(uint64_t), 0); if (index_file_size == 0) { return; } const size_t n_entries = index_file_size / sizeof(uint64_t); PosixMappedFile mapped_index(std::move(index_file), index_file_size, PROT_READ, mmap_flags); PosixFile key_file(KeyFilePath(chunk_id), O_RDONLY, 0644); PosixMappedFile mapped_key(std::move(key_file), key_file.Size(), PROT_READ, mmap_flags); const uint64_t* indices = static_cast(mapped_index.ptr()); const Key* keys = static_cast(mapped_key.ptr()); const uint64_t chunk_start_index = chunk_id * num_values_per_chunk_; row_id_mapping_.reserve(row_id_mapping_.size() + n_entries); for (size_t i = 0; i < n_entries; ++i) { CHECK(row_id_mapping_.emplace(keys[indices[i] - chunk_start_index], indices[i]).second); } if (Hook) { PosixFile value_file(ValueFilePath(chunk_id), O_RDONLY, 0644); PosixMappedFile mapped_value(std::move(value_file), value_file.Size(), PROT_READ, mmap_flags); ChunkIteratorImpl chunk_iterator(value_size_, logical_block_size_, num_values_per_block_, num_values_per_chunk_, chunk_id, n_entries, keys, indices, mapped_value.ptr()); Hook(&chunk_iterator); } } } template void PersistentTableImpl::SaveSnapshot(const std::string& name) { SaveSnapshotImpl(name); } template PersistentTable::Iterator* PersistentTableImpl::ReadSnapshot(const std::string& name) { return new SnapshotIteratorImpl(this, name, value_size_, logical_block_size_, num_values_per_block_, num_values_per_chunk_); } template void PersistentTableImpl::ParallelFor(size_t total, const ForRange& for_range) { BlockingCounter bc(workers_.size()); std::atomic counter(0); for (size_t i = 0; i < workers_.size(); ++i) { workers_.at(i)->Schedule([&](Engine* engine) { while (true) { const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed); if (start >= total) { break; } const size_t next_start = start + kParallelForStride; const size_t end = std::min(next_start, total); for_range(engine, start, end); } engine->WaitUntilDone(); bc.Decrease(); }); } bc.WaitForeverUntilCntEqualZero(); } template class SnapshotIteratorImpl : public PersistentTable::Iterator { public: OF_DISALLOW_COPY_AND_MOVE(SnapshotIteratorImpl); SnapshotIteratorImpl(PersistentTableImpl* table, const std::string& snapshot_name, uint32_t value_size, uint32_t 
logical_block_size, uint32_t num_values_per_block, uint64_t num_values_per_chunk) : table_(table), snapshot_name_(snapshot_name), value_size_(value_size), logical_block_size_(logical_block_size), num_values_per_block_(num_values_per_block), num_values_per_chunk_(num_values_per_chunk), current_chunk_(0) { const std::string snapshot_list = table_->SnapshotListFilePath(snapshot_name); std::ifstream list_if(snapshot_list); std::string index_filename; while (std::getline(list_if, index_filename)) { indices_names_.push_back(index_filename); } } ~SnapshotIteratorImpl() override = default; void Next(uint32_t num_keys, uint32_t* return_keys, void* keys, void* values) override { *return_keys = 0; while (current_chunk_ < indices_names_.size()) { if (!chunk_iterator_) { const std::string snapshot_base = table_->SnapshotDirPath(snapshot_name_); const uint64_t chunk_id = GetChunkId(indices_names_[current_chunk_], kIndexFileNamePrefix); PosixFile index_file(PosixFile::JoinPath(snapshot_base, indices_names_[current_chunk_]), O_RDONLY, 0644); const size_t index_file_size = index_file.Size(); CHECK_EQ(index_file_size % sizeof(uint64_t), 0); if (index_file_size == 0) { current_chunk_ += 1; continue; } const size_t n_entries = index_file_size / sizeof(uint64_t); indices_file_.reset(new PosixMappedFile(std::move(index_file), index_file_size, PROT_READ)); PosixFile key_file(table_->KeyFilePath(chunk_id), O_RDONLY, 0644); keys_file_.reset(new PosixMappedFile(std::move(key_file), key_file.Size(), PROT_READ)); PosixFile value_file(table_->ValueFilePath(chunk_id), O_RDONLY, 0644); values_file_.reset( new PosixMappedFile(std::move(value_file), value_file.Size(), PROT_READ)); chunk_iterator_.reset(new ChunkIteratorImpl( value_size_, logical_block_size_, num_values_per_block_, num_values_per_chunk_, chunk_id, n_entries, static_cast(keys_file_->ptr()), static_cast(indices_file_->ptr()), values_file_->ptr())); } chunk_iterator_->Next(num_keys, return_keys, keys, values); if (*return_keys == 0) { chunk_iterator_.reset(); keys_file_.reset(); values_file_.reset(); indices_file_.reset(); current_chunk_ += 1; continue; } else { return; } } } void Reset() override { UNIMPLEMENTED(); } private: PersistentTableImpl* table_; std::string snapshot_name_; uint32_t value_size_; uint32_t logical_block_size_; uint32_t num_values_per_block_; uint64_t num_values_per_chunk_; size_t current_chunk_; std::vector indices_names_; std::unique_ptr keys_file_; std::unique_ptr values_file_; std::unique_ptr indices_file_; std::unique_ptr> chunk_iterator_; }; template std::unique_ptr DispatchKeyType(const PersistentTableOptions& options) { if (options.key_size == 4) { return std::unique_ptr(new PersistentTableImpl(options)); } else if (options.key_size == 8) { return std::unique_ptr(new PersistentTableImpl(options)); } else { UNIMPLEMENTED(); return nullptr; } } std::unique_ptr DispatchEngine(const PersistentTableOptions& options) { return DispatchKeyType(options); } } // namespace #endif // __linux__ std::unique_ptr NewPersistentTable(const PersistentTableOptions& options) { #ifdef __linux__ CHECK(!options.path.empty()); CHECK_GT(options.value_size, 0); CHECK_GT(options.target_chunk_size_mb, 0); CHECK_GT(options.physical_block_size, 0); CHECK_GT(options.key_size, 0); return DispatchEngine(options); #else UNIMPLEMENTED(); return nullptr; #endif // __linux__ } } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/persistent_table.h 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_H_ #define ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace embedding { struct PersistentTableOptions { std::string path; uint32_t key_size = 0; uint32_t value_size = 0; uint64_t target_chunk_size_mb = 4 * 1024; uint16_t physical_block_size = 4096; uint64_t capacity_hint = 0; bool read_only = false; }; class PersistentTable { public: OF_DISALLOW_COPY_AND_MOVE(PersistentTable); PersistentTable() = default; virtual ~PersistentTable() = default; class Iterator { public: OF_DISALLOW_COPY_AND_MOVE(Iterator); Iterator() = default; virtual ~Iterator() = default; virtual void Next(uint32_t n_request, uint32_t* n_result, void* keys, void* values) = 0; virtual void Reset() = 0; }; virtual uint32_t KeySize() const = 0; virtual uint32_t ValueSize() const = 0; virtual uint32_t LogicalBlockSize() const = 0; virtual void GetBlocks(uint32_t num_keys, const void* keys, void* blocks, uint32_t* offsets) = 0; virtual void Get(uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) = 0; virtual void PutBlocks(uint32_t num_keys, const void* keys, const void* blocks) = 0; virtual void Put(uint32_t num_keys, const void* keys, const void* values) = 0; virtual bool SnapshotExists(const std::string& name) = 0; virtual void LoadSnapshot(const std::string& name) = 0; virtual void LoadSnapshot(const std::string& name, const std::function& Hook) = 0; virtual void SaveSnapshot(const std::string& name) = 0; virtual Iterator* ReadSnapshot(const std::string& name) = 0; }; std::unique_ptr NewPersistentTable(const PersistentTableOptions& options); } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_H_ ================================================ FILE: oneflow/core/embedding/persistent_table_key_value_store.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/embedding/persistent_table_key_value_store.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/embedding/persistent_table.h" #include #include #include #include #include namespace oneflow { namespace embedding { namespace { class IteratorImpl : public KVIterator { public: OF_DISALLOW_COPY_AND_MOVE(IteratorImpl); IteratorImpl(PersistentTable::Iterator* base_iter, uint32_t key_size, uint32_t value_size, uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer, uint32_t* host_num_buffer) : base_iter_(base_iter), key_size_(key_size), value_size_(value_size), max_query_length_(max_query_length), host_keys_buffer_(host_keys_buffer), host_values_buffer_(host_values_buffer), host_num_buffer_(host_num_buffer) {} ~IteratorImpl() override = default; void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys, void* values) override { CHECK_LE(n_request, max_query_length_); auto cuda_stream = stream->As(); CHECK_JUST(cuda_stream->Sync()); base_iter_->Next(n_request, host_num_buffer_, host_keys_buffer_, host_values_buffer_); OF_CUDA_CHECK(cudaMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); const uint32_t num_keys = *host_num_buffer_; if (num_keys != 0) { OF_CUDA_CHECK(cudaMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_, cudaMemcpyDefault, cuda_stream->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(values, host_values_buffer_, num_keys * value_size_, cudaMemcpyDefault, cuda_stream->cuda_stream())); } } void Reset() override { base_iter_->Reset(); } private: PersistentTable::Iterator* base_iter_; uint32_t key_size_; uint32_t value_size_; uint32_t max_query_length_; void* host_keys_buffer_; void* host_values_buffer_; uint32_t* host_num_buffer_; }; template class KeyValueStoreImpl : public KeyValueStore { public: OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl); explicit KeyValueStoreImpl(const PersistentTableKeyValueStoreOptions& options) : device_index_(-1), max_query_length_(0) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); key_size_ = options.table_options.key_size; value_size_ = options.table_options.value_size; table_ = NewPersistentTable(options.table_options); OF_CUDA_CHECK(NumaAwareCudaMallocHost( device_index_, reinterpret_cast(&host_query_keys_), key_size_ * max_query_length_)); OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&host_query_values_), value_size_ * max_query_length_)); OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&host_n_missing_), sizeof(uint32_t))); OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&host_missing_indices_), sizeof(uint32_t) * max_query_length_)); } ~KeyValueStoreImpl() { CudaCurrentDeviceGuard guard(device_index_); if (max_query_length_ != 0) { OF_CUDA_CHECK(cudaFreeHost(host_query_keys_)); OF_CUDA_CHECK(cudaFreeHost(host_query_values_)); OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_)); } OF_CUDA_CHECK(cudaFreeHost(host_n_missing_)); } uint32_t KeySize() const override { return key_size_; } uint32_t ValueSize() const override { return value_size_; } uint32_t MaxQueryLength() const override { return max_query_length_; } void ReserveQueryLength(uint32_t query_length) override { CudaCurrentDeviceGuard guard(device_index_); if (query_length <= max_query_length_) { return; } if (max_query_length_ != 0) { OF_CUDA_CHECK(cudaFreeHost(host_query_keys_)); OF_CUDA_CHECK(cudaFreeHost(host_query_values_)); OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_)); } 
OF_CUDA_CHECK(NumaAwareCudaMallocHost( device_index_, reinterpret_cast(&host_query_keys_), key_size_ * query_length)); OF_CUDA_CHECK(NumaAwareCudaMallocHost( device_index_, reinterpret_cast(&host_query_values_), value_size_ * query_length)); OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast(&host_missing_indices_), sizeof(uint32_t) * query_length)); max_query_length_ = query_length; } using KeyValueStore::Get; void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) override; void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override; bool SnapshotExists(const std::string& name) override; void LoadSnapshot(const std::string& name) override; void LoadSnapshot(const std::string& name, const std::function& Hook) override; void SaveSnapshot(const std::string& name) override; private: int device_index_; uint32_t max_query_length_; uint32_t key_size_; uint32_t value_size_; Key* host_query_keys_{}; uint8_t* host_query_values_{}; uint32_t* host_n_missing_{}; uint32_t* host_missing_indices_{}; std::mutex mutex_; std::unique_ptr table_; }; template void KeyValueStoreImpl::Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing, uint32_t* missing_indices) { std::lock_guard lock(mutex_); auto cuda_stream = stream->As(); CHECK_LE(num_keys, max_query_length_); if (num_keys == 0) { OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As()->cuda_stream())); return; } OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(cuda_stream->Sync()); table_->Get(num_keys, host_query_keys_, host_query_values_, host_n_missing_, host_missing_indices_); OF_CUDA_CHECK(cudaMemcpyAsync(values, host_query_values_, num_keys * value_size_, cudaMemcpyDefault, cuda_stream->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(missing_indices, host_missing_indices_, (*host_n_missing_) * sizeof(uint32_t), cudaMemcpyDefault, cuda_stream->cuda_stream())); } template void KeyValueStoreImpl::Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) { std::lock_guard lock(mutex_); auto cuda_stream = stream->As(); CHECK_LE(num_keys, max_query_length_); if (num_keys == 0) { return; } OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault, cuda_stream->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(host_query_values_, values, value_size_ * num_keys, cudaMemcpyDefault, cuda_stream->cuda_stream())); CHECK_JUST(cuda_stream->Sync()); table_->Put(num_keys, host_query_keys_, host_query_values_); } template bool KeyValueStoreImpl::SnapshotExists(const std::string& name) { return table_->SnapshotExists(name); } template void KeyValueStoreImpl::LoadSnapshot(const std::string& name) { CudaCurrentDeviceGuard guard(device_index_); LoadSnapshot(name, nullptr); } template void KeyValueStoreImpl::LoadSnapshot(const std::string& name, const std::function& Hook) { CudaCurrentDeviceGuard guard(device_index_); if (Hook) { table_->LoadSnapshot(name, [&](PersistentTable::Iterator* chunk_iterator) { IteratorImpl iterator(chunk_iterator, KeySize(), ValueSize(), max_query_length_, host_query_keys_, host_query_values_, host_n_missing_); Hook(&iterator); }); } else { table_->LoadSnapshot(name); } } template void 
KeyValueStoreImpl::SaveSnapshot(const std::string& name) { CudaCurrentDeviceGuard guard(device_index_); table_->SaveSnapshot(name); } } // namespace std::unique_ptr NewPersistentTableKeyValueStore( const PersistentTableKeyValueStoreOptions& options) { if (options.table_options.key_size == sizeof(uint64_t)) { return std::unique_ptr(new KeyValueStoreImpl(options)); } else if (options.table_options.key_size == sizeof(uint32_t)) { return std::unique_ptr(new KeyValueStoreImpl(options)); } else { UNIMPLEMENTED(); return nullptr; } } } // namespace embedding } // namespace oneflow ================================================ FILE: oneflow/core/embedding/persistent_table_key_value_store.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_KEY_VALUE_STORE_H_ #define ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_KEY_VALUE_STORE_H_ #include "oneflow/core/embedding/key_value_store.h" #include "oneflow/core/embedding/persistent_table.h" namespace oneflow { namespace embedding { #ifdef WITH_CUDA struct PersistentTableKeyValueStoreOptions { PersistentTableOptions table_options{}; }; std::unique_ptr NewPersistentTableKeyValueStore( const PersistentTableKeyValueStoreOptions& options); #endif // WITH_CUDA } // namespace embedding } // namespace oneflow #endif // ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_KEY_VALUE_STORE_H_ ================================================ FILE: oneflow/core/embedding/posix_file.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
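A minimal usage sketch of the persistent-table key-value store defined above. This is illustrative only: the option values, the device buffers (`keys`, `values`, `n_missing`, `missing_indices`), `num_keys`, `stream`, and the snapshot names are placeholders, not code from the repository.

using namespace oneflow::embedding;

PersistentTableKeyValueStoreOptions opts{};
opts.table_options.key_size = sizeof(uint64_t);        // selects the 64-bit-key implementation
opts.table_options.value_size = 128 * sizeof(float);   // hypothetical embedding width

std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(opts);
store->ReserveQueryLength(65536);  // (re)allocates the pinned host staging buffers

// keys/values/n_missing/missing_indices are device buffers sized for num_keys entries.
store->Get(stream, num_keys, keys, values, n_missing, missing_indices);
store->Put(stream, num_keys, keys, values);

if (store->SnapshotExists("snapshot-0")) { store->LoadSnapshot("snapshot-0"); }
store->SaveSnapshot("snapshot-1");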
*/ #ifndef ONEFLOW_CORE_EMBEDDING_POSIX_FILE_H_ #define ONEFLOW_CORE_EMBEDDING_POSIX_FILE_H_ #ifdef __linux__ #include #include #include #include #include #include #include #include #include #include namespace oneflow { namespace embedding { class PosixFile final { public: PosixFile() : fd_(-1), size_(0) {} PosixFile(const std::string& pathname, int flags, mode_t mode) : PosixFile(pathname.c_str(), flags, mode) {} PosixFile(const char* pathname, int flags, mode_t mode) : PosixFile() { fd_ = open(pathname, flags, mode); PCHECK(fd_ != -1); struct stat sb {}; PCHECK(fstat(fd_, &sb) == 0); size_ = sb.st_size; } PosixFile(PosixFile&& other) noexcept : PosixFile() { *this = std::move(other); } PosixFile(const PosixFile&) = delete; ~PosixFile() { Close(); } PosixFile& operator=(PosixFile&& other) noexcept { this->Close(); fd_ = other.fd_; other.fd_ = -1; size_ = other.size_; other.size_ = 0; return *this; } PosixFile& operator=(const PosixFile&) = delete; int fd() { return fd_; } bool IsOpen() { return fd_ != -1; } void Close() { if (IsOpen()) { PCHECK(close(fd_) == 0); fd_ = -1; } } size_t Size() { return size_; } void Truncate(size_t new_size) { CHECK(IsOpen()); if (new_size == size_) { return; } PCHECK(ftruncate(fd_, new_size) == 0); size_ = new_size; } static bool FileExists(const std::string& pathname) { return access(pathname.c_str(), F_OK) == 0; } static std::string JoinPath(const std::string& a, const std::string& b) { return a + "/" + b; } static void RecursiveCreateDirectory(const std::string& pathname, mode_t mode) { while (true) { struct stat sb {}; if (stat(pathname.c_str(), &sb) == 0) { CHECK(S_ISDIR(sb.st_mode)) << "Could not create directory: '" << pathname << "' already exists and is not a directory."; return; } else { PCHECK(errno == ENOENT) << "Could not create directory '" << pathname << "'."; if (lstat(pathname.c_str(), &sb) == 0) { LOG(FATAL) << "Could not create directory: '" << pathname << "' is a broken link."; } else { PCHECK(errno == ENOENT) << "Could not create directory '" << pathname << "'."; } std::vector dirname_input(pathname.size() + 1); std::memcpy(dirname_input.data(), pathname.c_str(), pathname.size() + 1); const std::string parent = dirname(dirname_input.data()); RecursiveCreateDirectory(parent, mode); if (mkdir(pathname.c_str(), mode) == 0) { return; } else { PCHECK(errno == EEXIST) << "Could not create directory '" << pathname << "'."; } } } } static void RecursiveDelete(const std::string& pathname) { struct stat sb {}; if (stat(pathname.c_str(), &sb) == 0) { if (S_ISDIR(sb.st_mode)) { DIR* dir = opendir(pathname.c_str()); PCHECK(dir != nullptr); struct dirent* ent = nullptr; while ((ent = readdir(dir)) != nullptr) { if (strcmp(ent->d_name, ".") == 0 || strcmp(ent->d_name, "..") == 0) { continue; } RecursiveDelete(pathname + "/" + ent->d_name); } PCHECK(closedir(dir) == 0); PCHECK(rmdir(pathname.c_str()) == 0); } else { PCHECK(unlink(pathname.c_str()) == 0); } } else { PCHECK(errno == ENOENT); } } private: int fd_; size_t size_; }; class PosixMappedFile final { public: PosixMappedFile() : file_(), ptr_(nullptr) {} PosixMappedFile(PosixFile&& file, size_t size, int prot, int flags) : file_(std::move(file)), ptr_(nullptr) { CHECK_NE(file_.fd(), -1); void* ptr = mmap(nullptr, size, prot, flags, file_.fd(), 0); PCHECK(ptr != MAP_FAILED); ptr_ = ptr; } PosixMappedFile(PosixFile&& file, size_t size, int prot) : PosixMappedFile(std::move(file), size, prot, MAP_SHARED) {} PosixMappedFile(PosixMappedFile&& other) noexcept : PosixMappedFile() { *this = 
std::move(other); } PosixMappedFile(const PosixMappedFile&) = delete; ~PosixMappedFile() { Unmap(); } PosixMappedFile& operator=(PosixMappedFile&& other) noexcept { Unmap(); this->file_ = std::move(other.file_); this->ptr_ = other.ptr_; other.ptr_ = nullptr; return *this; } PosixMappedFile& operator=(const PosixMappedFile&) = delete; void* ptr() { return ptr_; } PosixFile& file() { return file_; } private: void Unmap() { if (ptr_ != nullptr) { PCHECK(munmap(ptr_, file_.Size()) == 0); } } PosixFile file_; void* ptr_; }; class PosixFileLockGuard final { public: OF_DISALLOW_COPY(PosixFileLockGuard); explicit PosixFileLockGuard() : file_() {} explicit PosixFileLockGuard(PosixFile&& file) : file_(std::move(file)) { CHECK_NE(file_.fd(), -1); Lock(); } PosixFileLockGuard(PosixFileLockGuard&& other) noexcept { *this = std::move(other); } PosixFileLockGuard& operator=(PosixFileLockGuard&& other) noexcept { Unlock(); file_ = std::move(other.file_); return *this; } ~PosixFileLockGuard() { Unlock(); } private: void Lock() { if (file_.fd() != -1) { struct flock f {}; f.l_type = F_WRLCK; f.l_whence = SEEK_SET; f.l_start = 0; f.l_len = 0; PCHECK(fcntl(file_.fd(), F_SETLK, &f) == 0); } } void Unlock() { if (file_.fd() != -1) { struct flock f {}; f.l_type = F_UNLCK; f.l_whence = SEEK_SET; f.l_start = 0; f.l_len = 0; PCHECK(fcntl(file_.fd(), F_SETLK, &f) == 0); } } PosixFile file_; }; } // namespace embedding } // namespace oneflow #endif // __linux__ #endif // ONEFLOW_CORE_EMBEDDING_POSIX_FILE_H_ ================================================ FILE: oneflow/core/ep/common/active_device_guard.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/active_device_guard.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace ep { ActiveDeviceGuard::ActiveDeviceGuard(Device* device) : device_manager_(device->device_manager()) { saved_active_device_ = device_manager_->GetActiveDeviceIndex(); device->SetAsActiveDevice(); } ActiveDeviceGuard::~ActiveDeviceGuard() { device_manager_->SetActiveDeviceByIndex(saved_active_device_); } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/common/device.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
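A short, illustrative walkthrough of the POSIX helpers defined above (Linux only). The paths, flags, and sizes are hypothetical; `PCHECK` aborts on any failing syscall, so no explicit error handling appears.

using oneflow::embedding::PosixFile;
using oneflow::embedding::PosixMappedFile;
using oneflow::embedding::PosixFileLockGuard;

const std::string dir = PosixFile::JoinPath("/tmp", "oneflow_table");
PosixFile::RecursiveCreateDirectory(dir, 0755);

// Acquire an advisory write lock on a marker file, then map a data file read/write.
PosixFileLockGuard lock(PosixFile(PosixFile::JoinPath(dir, "LOCK"), O_CREAT | O_RDWR, 0644));
PosixFile file(PosixFile::JoinPath(dir, "values.bin"), O_CREAT | O_RDWR, 0644);
file.Truncate(4096);  // grow the file before mapping it
PosixMappedFile mapped(std::move(file), 4096, PROT_READ | PROT_WRITE);
std::memset(mapped.ptr(), 0, 4096);  // writes go through the shared mapping

// Cleaning up the whole directory tree afterwards would use PosixFile::RecursiveDelete(dir).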
*/ #include "oneflow/core/ep/include/device.h" namespace oneflow { namespace ep { Event* Device::CreateEvent() { Event* event = nullptr; this->CreateEvents(&event, 1); return event; } void Device::DestroyEvent(Event* event) { this->DestroyEvents(&event, 1); } bool Device::IsStreamOrderedMemoryAllocationSupported() const { return false; } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/common/device_manager_registry.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/ep/include/device_manager.h" namespace oneflow { namespace ep { class DeviceManagerRegistry::Impl { public: OF_DISALLOW_COPY_AND_MOVE(Impl); explicit Impl(DeviceManagerRegistry* registry) : registry_(registry) { managers_.resize(DeviceType_ARRAYSIZE); } ~Impl() = default; DeviceManager* GetDeviceManagerOrNull(DeviceType device_type) { std::lock_guard lock(mutex_); if (!managers_.at(device_type)) { std::lock_guard factories_lock(*factories_mutex()); auto& factory = factories()->at(device_type); if (factory) { managers_.at(device_type) = factory->NewDeviceManager(registry_); } else { return nullptr; } } return managers_.at(device_type).get(); } DeviceManager* GetDeviceManager(DeviceType device_type) { return CHECK_NOTNULL(GetDeviceManagerOrNull(device_type)); } std::shared_ptr GetDevice(DeviceType device_type, size_t device_index) { return GetDeviceManager(device_type)->GetDevice(device_index); } size_t GetDeviceCount(DeviceType device_type) { DeviceManager* manager = GetDeviceManagerOrNull(device_type); if (manager == nullptr) { return 0; } else { return manager->GetDeviceCount(); } } size_t GetDeviceCount(const std::string& device_type_name) { return GetDeviceCount(GetDeviceTypeByDeviceTypeName(device_type_name)); } static void DumpVersionInfo() { std::lock_guard factories_lock(*factories_mutex()); for (auto& factory : *factories()) { if (factory) { factory->DumpVersionInfo(); } } } static std::string GetDeviceTypeNameByDeviceType(DeviceType device_type) { static thread_local std::vector device_type2device_type_name(DeviceType_ARRAYSIZE); { const std::string& name = device_type2device_type_name.at(device_type); if (!name.empty()) { return name; } } std::lock_guard factories_lock(*factories_mutex()); if (factories()->size() <= device_type) { return ""; } auto& factory = factories()->at(device_type); if (!factory) { return ""; } else { std::string name = factory->device_type_name(); device_type2device_type_name.at(device_type) = name; return name; } } static DeviceType GetDeviceTypeByDeviceTypeName(const std::string& device_type_name) { static thread_local HashMap device_type_name2device_type; { auto it = device_type_name2device_type.find(device_type_name); if (it != device_type_name2device_type.end()) { return it->second; } } std::lock_guard 
factories_lock(*factories_mutex()); auto it = device_type_name2device_type_map()->find(device_type_name); if (it == device_type_name2device_type_map()->end()) { return DeviceType::kInvalidDevice; } else { device_type_name2device_type[device_type_name] = it->second; return it->second; } } static void RegisterDeviceManagerFactory(std::unique_ptr&& factory) { CHECK(factory); const DeviceType device_type = factory->device_type(); std::lock_guard lock(*factories_mutex()); factories()->resize(DeviceType_ARRAYSIZE); CHECK(!factories()->at(device_type)); const std::string device_type_name = factory->device_type_name(); CHECK(!device_type_name.empty()); CHECK(device_type_name2device_type_map()->emplace(device_type_name, device_type).second); factories()->at(device_type) = std::move(factory); } static std::set GetRegisteredDeviceTypes() { std::lock_guard lock(*factories_mutex()); std::set types; for (auto& factory : *factories()) { if (factory) { types.insert(factory->device_type()); } } return types; } static bool IsDeviceTypeRegistered(DeviceType device_type) { std::lock_guard lock(*factories_mutex()); return factories()->at(device_type).operator bool(); } private: static HashMap* device_type_name2device_type_map() { static HashMap device_type_name2device_type; return &device_type_name2device_type; } static std::vector>* factories() { static std::vector> factories_vec; return &factories_vec; } static std::mutex* factories_mutex() { static std::mutex mutex; return &mutex; } std::mutex mutex_; std::vector> managers_; DeviceManagerRegistry* registry_; }; DeviceManagerRegistry::DeviceManagerRegistry() { impl_.reset(new Impl(this)); } DeviceManagerRegistry::~DeviceManagerRegistry() = default; DeviceManager* DeviceManagerRegistry::GetDeviceManager(DeviceType device_type) { return impl_->GetDeviceManager(device_type); } DeviceManager* DeviceManagerRegistry::GetDeviceManagerOrNull(DeviceType device_type) { return impl_->GetDeviceManagerOrNull(device_type); } std::shared_ptr DeviceManagerRegistry::GetDevice(DeviceType device_type, size_t device_index) { return impl_->GetDevice(device_type, device_index); } size_t DeviceManagerRegistry::GetDeviceCount(DeviceType device_type) { return impl_->GetDeviceCount(device_type); } size_t DeviceManagerRegistry::GetDeviceCount(const std::string& device_type_name) { return impl_->GetDeviceCount(device_type_name); } /*static*/ void DeviceManagerRegistry::RegisterDeviceManagerFactory( std::unique_ptr&& factory) { Impl::RegisterDeviceManagerFactory(std::move(factory)); } /*static*/ void DeviceManagerRegistry::DumpVersionInfo() { Impl::DumpVersionInfo(); } /*static*/ std::string DeviceManagerRegistry::GetDeviceTypeNameByDeviceType( DeviceType device_type) { return Impl::GetDeviceTypeNameByDeviceType(device_type); } /*static*/ DeviceType DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName( const std::string& device_type_name) { return Impl::GetDeviceTypeByDeviceTypeName(device_type_name); } /*static*/ std::set DeviceManagerRegistry::GetRegisteredDeviceTypes() { return Impl::GetRegisteredDeviceTypes(); } /*static*/ bool DeviceManagerRegistry::IsDeviceTypeRegistered(DeviceType device_type) { return Impl::IsDeviceTypeRegistered(device_type); } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/common/onednn.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
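An illustrative sketch of how the registry above is typically queried. It assumes a CUDA device-manager factory has been registered elsewhere via RegisterDeviceManagerFactory; the variable names are placeholders.

using oneflow::ep::DeviceManagerRegistry;

DeviceManagerRegistry registry;
size_t n_cuda = registry.GetDeviceCount(DeviceType::kCUDA);  // 0 when the type is unregistered
if (n_cuda > 0) {
  auto device = registry.GetDevice(DeviceType::kCUDA, /*device_index=*/0);
  oneflow::ep::Event* event = device->CreateEvent();  // see ep/common/device.cpp above
  device->DestroyEvent(event);
}

// Static helpers translate between a DeviceType and its registered name.
std::string name = DeviceManagerRegistry::GetDeviceTypeNameByDeviceType(DeviceType::kCUDA);
DeviceType type = DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(name);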
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_COMMON_ONEDNN_H_ #define ONEFLOW_CORE_EP_COMMON_ONEDNN_H_ #ifdef WITH_ONEDNN #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { DEFINE_ENV_BOOL(ONEFLOW_ENABLE_ONEDNN_OPTS, true); namespace ep { namespace primitive { inline bool OneDnnIsEnabled() { return EnvBool(); } } // namespace primitive } // namespace ep } // namespace oneflow #endif // WITH_ONEDNN #endif // ONEFLOW_CORE_EP_COMMON_ONEDNN_H_ ================================================ FILE: oneflow/core/ep/common/primitive/add.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/add.h" namespace oneflow { namespace ep { namespace primitive { void Add::Launch(Stream* stream, const void* src0, const void* src1, void* dst, size_t count) { const void* srcs[2]; srcs[0] = src0; srcs[1] = src1; Launch(stream, srcs, 2, dst, count); } } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/common/primitive/batch_matmul.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
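An illustrative use of the Add primitive defined above. Creating it through NewPrimitive<AddFactory>(...) follows the factory pattern visible elsewhere in this section, but the exact factory signature is an assumption; the pointers, count, and stream are placeholders.

using namespace oneflow::ep::primitive;

std::unique_ptr<Add> add = NewPrimitive<AddFactory>(DeviceType::kCUDA, DataType::kFloat);
if (add) {
  // Two-operand convenience overload, forwarded to the (srcs, arity) entry point shown above.
  add->Launch(stream, /*src0=*/a, /*src1=*/b, /*dst=*/c, /*count=*/n);

  // Equivalent call through the general entry point.
  const void* srcs[2] = {a, b};
  add->Launch(stream, srcs, 2, c, n);
}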
*/ #include "oneflow/core/ep/include/primitive/batch_matmul.h" #include "oneflow/core/ep/include/primitive/broadcast_matmul.h" namespace oneflow { namespace ep { namespace primitive { namespace { class BatchMatmulImpl : public BatchMatmul { public: OF_DISALLOW_COPY_AND_MOVE(BatchMatmulImpl); BatchMatmulImpl(BlasTransposeType transpose_a, BlasTransposeType transpose_b, std::unique_ptr&& broadcast_matmul) : transpose_a_(transpose_a), transpose_b_(transpose_b), broadcast_matmul_(std::move(broadcast_matmul)) {} ~BatchMatmulImpl() override = default; void Launch(Stream* stream, size_t batch_size, size_t m, size_t n, size_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c) override { int64_t a_dims[3]; int64_t b_dims[3]; int64_t c_dims[3]; a_dims[0] = batch_size; b_dims[0] = batch_size; c_dims[0] = batch_size; if (transpose_a_ == BlasTransposeType::N) { a_dims[1] = m; a_dims[2] = k; } else if (transpose_a_ == BlasTransposeType::T) { a_dims[1] = k; a_dims[2] = m; } else { UNIMPLEMENTED(); } if (transpose_b_ == BlasTransposeType::N) { b_dims[1] = k; b_dims[2] = n; } else if (transpose_b_ == BlasTransposeType::T) { b_dims[1] = n; b_dims[2] = k; } else { UNIMPLEMENTED(); } c_dims[1] = m; c_dims[2] = n; broadcast_matmul_->Launch(stream, alpha, 3, a_dims, a, 3, b_dims, b, beta, 3, c_dims, c); } private: BlasTransposeType transpose_a_; BlasTransposeType transpose_b_; std::unique_ptr broadcast_matmul_; }; template class BatchMatmulFactoryImpl : public BatchMatmulFactory { public: OF_DISALLOW_COPY_AND_MOVE(BatchMatmulFactoryImpl); BatchMatmulFactoryImpl() = default; ~BatchMatmulFactoryImpl() override = default; std::unique_ptr New(DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b) override { auto broadcast_matmul = NewPrimitive(device_type, data_type, transpose_a, transpose_b, 3); if (!broadcast_matmul) { return nullptr; } return std::make_unique(transpose_a, transpose_b, std::move(broadcast_matmul)); } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BatchMatmulFactory, BatchMatmulFactoryImpl); #ifdef WITH_CUDA REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BatchMatmulFactory, BatchMatmulFactoryImpl); #endif // WITH_CUDA } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/common/primitive/binary_functor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_ #define ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_ #include "oneflow/core/ep/include/primitive/binary_op.h" #include "oneflow/core/ep/common/primitive/unary_functor.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/scalar.h" #include namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { template struct BinaryFunctor; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 + src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 - src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 * src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const { return src0 && src1; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 / src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 > src1 ? src0 : src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 < src1 ? src0 : src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 & src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 | src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 ^ src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 == src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 != src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 < src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 <= src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 > src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 >= src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar 
attr0, Scalar attr1) : atol(attr0.Value()), rtol(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { bool close = src0 == src1; close |= (std::isnan(src0) and std::isnan(src1)); if (atol == 0 and rtol == 0) return close; Src allowed_error = static_cast(atol) + abs(static_cast(rtol) * src1); Src actual_error = abs(src0 - src1); close |= (std::isfinite(actual_error) and (actual_error <= allowed_error)); return close; } float atol, rtol; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : atol(attr0.Value()), rtol(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { bool close = src0 == src1; if (atol == 0 and rtol == 0) return close; Src allowed_error = static_cast(atol) + abs(static_cast(rtol) * src1); Src actual_error = abs(src0 - src1); close |= (std::isfinite(actual_error) and (actual_error <= allowed_error)); return close; } float atol, rtol; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 && src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 || src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0) != static_cast(src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 % src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return src0 / src1; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 / src1); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { Src trunc_mod = src0 % src1; return (trunc_mod != static_cast(0)) && ((src1 < static_cast(0)) != (trunc_mod < static_cast(0))) ? 
trunc_mod + src1 : trunc_mod; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC uint8_t operator()(uint8_t src0, uint8_t src1) const { return src0 % src1; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC uint32_t operator()(uint32_t src0, uint32_t src1) const { return src0 % src1; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC uint64_t operator()(uint64_t src0, uint64_t src1) const { return src0 % src1; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return scalar_operand * (pow(src0, scalar_operand - static_cast(1))) * src1; } Src scalar_operand; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return (pow(scalar_operand, src0)) * log(scalar_operand) * src1; } Src scalar_operand; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast(dy); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return (x > static_cast(0)) ? static_cast(dy) : static_cast(dy * alpha * (exp(x))); } const Src alpha; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : inv_alpha(1.0f / attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return static_cast((y > static_cast(0)) ? dy : dy * static_cast(y * inv_alpha + static_cast(1))); } const Src inv_alpha; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { if (x <= static_cast(-3)) { return static_cast(0); } else if (x >= static_cast(3)) { return static_cast(dy); } else { return static_cast(((x / static_cast(3)) + static_cast(0.5)) * dy); } } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast((x <= static_cast(-3) || x >= static_cast(3)) ? static_cast(0) : dy / static_cast(6)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return static_cast(y == static_cast(0) ? 0 : dy); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : min_val(attr0.Value()), max_val(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return static_cast((y == min_val || y == max_val) ? static_cast(0) : dy); } const Src min_val; const Src max_val; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast((x > static_cast(0)) ? 
dy : dy * alpha); } const Src alpha; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { Src sp = log(static_cast(1) + exp(x)); Src grad_sp = static_cast(1) - exp(-sp); Src tsp = (exp(sp) - exp(-sp)) / (exp(sp) + exp(-sp)); Src grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; return static_cast(dy * (x * grad_tsp + tsp)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return static_cast((y <= static_cast(0.0)) ? static_cast(0.0) : dy); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast((x <= static_cast(0.0)) ? static_cast(0.0) : dy); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast((x > static_cast(0)) ? scale * dy : dy * scale * alpha * (exp(x))); } const Src scale = 1.0507009873554804934193349852946; const Src alpha = 1.6732632423543772848170429916717; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { Src sig = static_cast(1) / (static_cast(1) + exp(-x)); return static_cast(dy * (sig * (static_cast(1) + x * (static_cast(1) - sig)))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { Src val = (static_cast(1) + abs(x)); return static_cast(dy / (val * val)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : beta(attr0.Value()), threshold(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { Src z = exp(x * beta); return static_cast((x * beta) > threshold ? dy : dy * z / (z + static_cast(1.0))); } const Src beta; const Src threshold; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return static_cast(y == static_cast(0) ? 0 : dy); } const Src alpha; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : threshold(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast((x <= threshold) ? 
0 : dy); } const Src threshold; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { const Src zero = static_cast(0.0); if (x == zero) { return zero; } else if (x < zero) { return -dy; } else { return dy; } } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * -rsqrt(static_cast(1.0) - x * x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * rsqrt(x * x - static_cast(1.0)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * rsqrt(static_cast(1.0) - x * x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * rsqrt(static_cast(1.0) + x * x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { const Src one = static_cast(1.0); return dy * (one / (one + x * x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { const Src one = static_cast(1.0); return dy * (one / (one - x * x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (-sin(x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * sinh(x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * static_cast(2.0) * rsqrt(static_cast(M_PI)) * exp(-x * x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * -static_cast(2.0) * rsqrt(static_cast(M_PI)) * exp(-x * x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * exp(x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * exp2(x) * log(static_cast(2.0)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * exp(x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast(1.0) / x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast(1.0) / (x * log(static_cast(2.0)))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast(1.0) / (x * log(static_cast(10.0)))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} 
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast(1.0) / (x + static_cast(1.0))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast(1.0) / (exp(x) + static_cast(1.0))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (-static_cast(1.0) / (x * x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { if (abs(x) <= static_cast(0.0)) { return static_cast(0.0); } return dy * (-static_cast(1.0) / (x * x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast(-1.0) / (static_cast(2.0) * sqrt(x * x * x))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return dy * (y * (1.0 - y)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { using UnaryOp = oneflow::ep::primitive::UnaryOp; using UnaryFunctor = oneflow::ep::primitive::UnaryFunctor; auto uf = UnaryFunctor(0, 0); Src y = uf(x); return dy * (y * (static_cast(1.0) - y)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * cos(x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * cosh(x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * static_cast(0.5) / sqrt(x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * static_cast(2.0) * x; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { const Src cos_val = cos(x); return dy * (static_cast(1.0) / (cos_val * cos_val)); } }; } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_ ================================================ FILE: oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
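The template headers throughout the functor file above lost their angle-bracket arguments during extraction. As a hedged reconstruction of their likely shape (inferred from how the functors dispatch on BinaryOp; the exact parameter names and ordering are assumptions), the primary template and the kAdd specialization would read roughly:

template<DeviceType device, BinaryOp binary_op, typename Src, typename Dst>
struct BinaryFunctor;

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kAdd, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 + src1); }
};

Functors that carry state (for example the kIsClose atol/rtol pair, or the alpha/threshold attributes of the backward functors) follow the same pattern and read their constants out of the two Scalar constructor arguments.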
*/ #ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY #define ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/binary_op.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/ep/common/primitive/util.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { constexpr size_t kMaxNumDims = 8; inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t num_src1_dims, const int64_t* src1_dims) { if (num_src0_dims != num_src1_dims) { return false; } for (size_t i = 0; i < num_src1_dims; ++i) { if (src0_dims[i] != src1_dims[i]) { return false; } } return true; } #define BINARY_MATH_OP_SEQ_0 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax) #define BINARY_MATH_OP_SEQ_1 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFmod) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorDiv) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTruncDiv) #define BINARY_MATH_OP_SEQ_2 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorMod) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kScalarBasePowerGrad) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kScalarExpPowerGrad) #define BINARY_COMPLEX_MATH_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) #define BINARY_MATH_OP_SEQ \ BINARY_MATH_OP_SEQ_0 \ BINARY_MATH_OP_SEQ_1 \ BINARY_MATH_OP_SEQ_2 #define BINARY_COMPARISION_OP_SEQ_0 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual) #define BINARY_COMPARISION_OP_SEQ_1 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIsCloseEqualNan) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIsClose) #define BINARY_COMPARISION_OP_SEQ \ BINARY_COMPARISION_OP_SEQ_0 \ BINARY_COMPARISION_OP_SEQ_1 #define BINARY_COMPLEX_COMPARISION_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual) #define BINARY_LOGICAL_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor) #define BINARY_BITWISE_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kBitwiseAnd) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kBitwiseOr) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kBitwiseXor) #define BINARY_MATH_FLOATING_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kZeta) #define BINARY_ACTIVATION_BACKWARD_OP_SEQ_0 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIdentityBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEluBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCeluBackwardWithDyY) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGeluBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kHardswishBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kHardsigmoidBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kHardshrinkBackwardWithDyY) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kHardtanhBackwardWithDyY) #define BINARY_ACTIVATION_BACKWARD_OP_SEQ_1 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLeakyReluBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMishBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReluBackwardWithDyY) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReluBackwardWithDyX) \ 
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSeluBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSiluBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftsignBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftplusBackwardWithDyX) #define BINARY_ACTIVATION_BACKWARD_OP_SEQ_2 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftshrinkBackwardWithDyY) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanhBackwardWithDyY) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kThresholdBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFastGeluBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kQuickGeluBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSquareReLUBackwardWithDyX) #define BINARY_ACTIVATION_BACKWARD_OP_SEQ \ BINARY_ACTIVATION_BACKWARD_OP_SEQ_0 \ BINARY_ACTIVATION_BACKWARD_OP_SEQ_1 \ BINARY_ACTIVATION_BACKWARD_OP_SEQ_2 #define BINARY_MATH_BACKWARD_OP_SEQ_0 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAbsBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAcosBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAcoshBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAsinBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAsinhBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanhBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCosBackwardWithDyX) #define BINARY_MATH_BACKWARD_OP_SEQ_1 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCoshBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfcBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExp2BackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpm1BackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLgammaBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDigammaBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog2BackwardWithDyX) #define BINARY_MATH_BACKWARD_OP_SEQ_2 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog10BackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog1pBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogSigmoidBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReciprocalBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReciprocalNoNanBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kRsqrtBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSinBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSigmoidBackwardWithDyY) #define BINARY_MATH_BACKWARD_OP_SEQ_3 \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSigmoidBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSinhBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSqrtBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSquareBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanBackwardWithDyX) #define BINARY_MATH_BACKWARD_OP_SEQ_COMPLEX OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSqrtBackwardWithDyX) #define BINARY_MATH_BACKWARD_OP_SEQ \ BINARY_MATH_BACKWARD_OP_SEQ_0 \ BINARY_MATH_BACKWARD_OP_SEQ_1 \ BINARY_MATH_BACKWARD_OP_SEQ_2 \ BINARY_MATH_BACKWARD_OP_SEQ_3 } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY ================================================ FILE: oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
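A hedged illustration of how the OF_PP_MAKE_TUPLE_SEQ lists above are normally consumed: a per-op macro is applied to every tuple in a sequence, typically to instantiate or register one functor per BinaryOp. The INSTANTIATE_BINARY_FUNCTOR name and the float/float instantiation are purely illustrative, and the BinaryFunctor signature used here is the reconstruction sketched earlier.

#define INSTANTIATE_BINARY_FUNCTOR(op) \
  template struct BinaryFunctor<DeviceType::kCPU, op, float, float>;

OF_PP_FOR_EACH_TUPLE(INSTANTIATE_BINARY_FUNCTOR, BINARY_MATH_OP_SEQ)

#undef INSTANTIATE_BINARY_FUNCTOR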
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY #define ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY #include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h" #include "oneflow/core/ep/include/primitive/fast_integer_math.h" #include "oneflow/core/ep/common/primitive/util.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_unary { constexpr size_t kMaxNumDims = 8; template class IndexToOffsetWithStrideCalculator { public: IndexToOffsetWithStrideCalculator() {} OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const T* strides) { InitStrides(strides, N); } template OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const U* strides) { T strides_arr[N]; for (int i = 0; i < N; ++i) { strides_arr[i] = strides[i]; } InitStrides(strides_arr, N); } OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const T* strides, int n) { InitStrides(strides, n); } template OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const U* strides, int n) { T strides_arr[N]; for (int i = 0; i < N; ++i) { if (i < n) { strides_arr[i] = strides[i]; } } InitStrides(strides_arr, n); } ~IndexToOffsetWithStrideCalculator() = default; OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const { T offset = 0; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N - 1; ++i) { offset += index[i] * stride_[i]; } offset += index[N - 1]; return offset; } OF_DEVICE_FUNC T NdIndexToOffset(const T* index, int n) const { assert(n <= N); T offset = 0; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N; ++i) { if (i < n) { offset += index[i] * stride_[i]; } } return offset; } OF_DEVICE_FUNC constexpr int Size() const { return N; } private: OF_DEVICE_FUNC void InitStrides(const T* strides, const int n) { for (int i = n; i < N; ++i) { stride_[i] = 1; } for (int i = n - 1; i >= 0; --i) { stride_[i] = strides[i]; } } T stride_[N]; }; template class OffsetToIndexWithStrideCalculator { public: OffsetToIndexWithStrideCalculator() {} OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const T* dims) { InitFastIntegerMath(dims, N); } template OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const U* dims) { T dims_arr[N]; for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; } InitFastIntegerMath(dims_arr, N); } OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const T* dims, int n) { InitFastIntegerMath(dims, n); } template OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const U* dims, int n) { T dims_arr[N]; for (int i = 0; i < N; ++i) { if (i < n) { dims_arr[i] = dims[i]; } } InitFastIntegerMath(dims_arr, n); } ~OffsetToIndexWithStrideCalculator() = default; OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const { T remaining = offset; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N - 1; ++i) { const T idx = math_helper_[i].divides(remaining); index[i] = idx; remaining = remaining - math_helper_[i].mul(idx); } index[N - 1] = remaining; } OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const { assert(n <= N); T remaining = offset; #ifdef __CUDA_ARCH__ 
#pragma unroll #endif for (int i = 0; i < N; ++i) { if (i == n - 1) { break; } if (i < n - 1) { const T idx = math_helper_[i].divides(remaining); index[i] = idx; remaining = remaining - math_helper_[i].mul(idx); } } index[n - 1] = remaining; } OF_DEVICE_FUNC T divides(T remaining, int64_t i) const { return math_helper_[i].divides(remaining); } OF_DEVICE_FUNC T mul(T idx, int64_t i) const { return math_helper_[i].mul(idx); } OF_DEVICE_FUNC constexpr int Size() const { return N; } private: OF_DEVICE_FUNC void InitFastIntegerMath(const T* dims, const int n) { T stride_arr[N]; for (int i = n - 1; i < N; ++i) { stride_arr[i] = 1; math_helper_[i] = FastIntegerMath(1); } for (int i = n - 2; i >= 0; --i) { stride_arr[i] = dims[i + 1] * stride_arr[i + 1]; math_helper_[i] = FastIntegerMath(stride_arr[i]); } } FastIntegerMath math_helper_[N]; }; #define UNARY_IDENTITY_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIdentity) #define BROADCAST_ELEMENTWISE_CAST_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCast) } // namespace broadcast_elementwise_unary } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY ================================================ FILE: oneflow/core/ep/common/primitive/broadcast_matmul.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
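A small round-trip through the two index/offset calculators defined above, assuming the stripped template parameter list is <typename T, int N>. The dims {2, 3, 4} and their contiguous strides {12, 4, 1} are hypothetical; the results follow directly from the definitions.

using namespace oneflow::ep::primitive::broadcast_elementwise_unary;

const int64_t dims[3] = {2, 3, 4};
const int64_t strides[3] = {12, 4, 1};
OffsetToIndexWithStrideCalculator<int64_t, 3> offset_to_index(dims, 3);
IndexToOffsetWithStrideCalculator<int64_t, 3> index_to_offset(strides, 3);

int64_t index[3] = {0, 0, 0};
offset_to_index.OffsetToNdIndex(/*offset=*/17, index, 3);    // index becomes {1, 1, 1}
int64_t offset = index_to_offset.NdIndexToOffset(index, 3);  // offset is 17 again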
*/ #ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_BROADCAST_MATMUL_H_ #define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_BROADCAST_MATMUL_H_ #include "oneflow/core/ep/include/primitive/broadcast_matmul.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/dtype.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_matmul { inline void Simplify(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims, const int64_t* b_dims, size_t num_c_dims, const int64_t* c_dims, BlasTransposeType transpose_a, BlasTransposeType transpose_b, int64_t* m, int64_t* n, int64_t* k, int64_t* num_batch_dims, int64_t* broadcast_batch_dims, int64_t* a_batch_dims, int64_t* b_batch_dims, int64_t* c_batch_dims) { CHECK_GE(num_a_dims, 2); CHECK_GE(num_b_dims, 2); CHECK_GE(num_c_dims, 2); if (transpose_a == BlasTransposeType::N) { *m = a_dims[num_a_dims - 2]; *k = a_dims[num_a_dims - 1]; } else if (transpose_a == BlasTransposeType::T) { *m = a_dims[num_a_dims - 1]; *k = a_dims[num_a_dims - 2]; } else { UNIMPLEMENTED(); } CHECK_GT(*m, 0); CHECK_GT(*k, 0); if (transpose_b == BlasTransposeType::N) { CHECK_EQ(b_dims[num_b_dims - 2], *k); *n = b_dims[num_b_dims - 1]; } else if (transpose_b == BlasTransposeType::T) { CHECK_EQ(b_dims[num_b_dims - 1], *k); *n = b_dims[num_b_dims - 2]; } else { UNIMPLEMENTED(); } CHECK_GT(*n, 0); CHECK_EQ(c_dims[num_c_dims - 2], *m); CHECK_EQ(c_dims[num_c_dims - 1], *n); const size_t num_max_batch_dims = std::max(std::max(num_a_dims, num_b_dims), num_c_dims) - 2; auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const int64_t* dims) { const int64_t num_batch_dims = num_dims - 2; const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims; return [num_padding_dims, dims](size_t index) { return index < num_padding_dims ? 
1 : dims[index - num_padding_dims]; }; }; auto GetABatchDim = MakeGetBatchDim(num_a_dims, a_dims); auto GetBBatchDim = MakeGetBatchDim(num_b_dims, b_dims); auto GetCBatchDim = MakeGetBatchDim(num_c_dims, c_dims); *num_batch_dims = 0; bool prev_broadcast_a = false; bool prev_broadcast_b = false; bool prev_broadcast_c = false; for (int64_t i = 0; i < num_max_batch_dims; ++i) { const int64_t a_dim = GetABatchDim(i); const int64_t b_dim = GetBBatchDim(i); const int64_t c_dim = GetCBatchDim(i); const int64_t broadcast_dim = std::max(std::max(a_dim, b_dim), c_dim); CHECK_GT(broadcast_dim, 0); const bool broadcast_a = (a_dim == 1); const bool broadcast_b = (b_dim == 1); const bool broadcast_c = (c_dim == 1); CHECK((a_dim == broadcast_dim) || broadcast_a); CHECK((b_dim == broadcast_dim) || broadcast_b); CHECK((c_dim == broadcast_dim) || broadcast_c); if (broadcast_dim == 1) { continue; } else if (*num_batch_dims != 0 && (prev_broadcast_a == broadcast_a && prev_broadcast_b == broadcast_b && prev_broadcast_c == broadcast_c)) { a_batch_dims[*num_batch_dims - 1] *= a_dim; b_batch_dims[*num_batch_dims - 1] *= b_dim; c_batch_dims[*num_batch_dims - 1] *= c_dim; broadcast_batch_dims[*num_batch_dims - 1] *= broadcast_dim; } else { a_batch_dims[*num_batch_dims] = a_dim; b_batch_dims[*num_batch_dims] = b_dim; c_batch_dims[*num_batch_dims] = c_dim; broadcast_batch_dims[*num_batch_dims] = broadcast_dim; *num_batch_dims += 1; prev_broadcast_a = broadcast_a; prev_broadcast_b = broadcast_b; prev_broadcast_c = broadcast_c; } } if (*num_batch_dims >= 1 && a_batch_dims[*num_batch_dims - 1] != 1 && b_batch_dims[*num_batch_dims - 1] == 1 && c_batch_dims[*num_batch_dims - 1] != 1 && transpose_a == BlasTransposeType::N) { *m *= a_batch_dims[*num_batch_dims - 1]; *num_batch_dims -= 1; } } template void ForEachMatmul(DataType data_type, size_t m, size_t n, size_t k, Scalar beta, size_t num_batch_dims, const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims, const int64_t* b_batch_dims, const int64_t* c_batch_dims, const void* a, const void* b, void* c, Func func) { if (num_batch_dims == 0) { func(a, b, c, beta); return; } const size_t size_of_data_type = GetSizeOfDataType(data_type); const size_t stride_a = m * k * size_of_data_type; const size_t stride_b = k * n * size_of_data_type; const size_t stride_c = m * n * size_of_data_type; int64_t broadcast_batch_count = 1; for (int64_t i = 0; i < num_batch_dims; ++i) { broadcast_batch_count *= broadcast_batch_dims[i]; } NdIndexOffsetHelper broadcast_index_helper(broadcast_batch_dims, num_batch_dims); NdIndexOffsetHelper a_index_helper(a_batch_dims, num_batch_dims); NdIndexOffsetHelper b_index_helper(b_batch_dims, num_batch_dims); NdIndexOffsetHelper c_index_helper(c_batch_dims, num_batch_dims); int64_t a_batch_index[max_num_dims]{}; int64_t b_batch_index[max_num_dims]{}; int64_t c_batch_index[max_num_dims]{}; int64_t broadcast_batch_index[max_num_dims]{}; bool init_c = true; for (int64_t broadcast_batch_id = 0; broadcast_batch_id < broadcast_batch_count; ++broadcast_batch_id) { broadcast_index_helper.OffsetToNdIndex(broadcast_batch_id, broadcast_batch_index); for (int64_t i = 0; i < num_batch_dims; ++i) { if (a_batch_dims[i] == 1) { a_batch_index[i] = 0; } else { a_batch_index[i] = broadcast_batch_index[i]; } if (b_batch_dims[i] == 1) { b_batch_index[i] = 0; } else { b_batch_index[i] = broadcast_batch_index[i]; } if (c_batch_dims[i] == 1) { c_batch_index[i] = 0; if (broadcast_batch_index[i] != 0) { init_c = false; } } else { c_batch_index[i] = 
broadcast_batch_index[i]; } } const int64_t a_batch_id = a_index_helper.NdIndexToOffset(a_batch_index); const int64_t b_batch_id = b_index_helper.NdIndexToOffset(b_batch_index); const int64_t c_batch_id = c_index_helper.NdIndexToOffset(c_batch_index); const void* a_ptr = static_cast(a) + a_batch_id * stride_a; const void* b_ptr = static_cast(b) + b_batch_id * stride_b; void* c_ptr = static_cast(c) + c_batch_id * stride_c; const Scalar batch_beta = init_c ? beta : Scalar(1); func(a_ptr, b_ptr, c_ptr, batch_beta); } } namespace internal { namespace { void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b, int64_t num_batch_dims, const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims, const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m, int64_t n, int64_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c); template class BroadcastMatmulImpl : public BroadcastMatmul { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulImpl); BroadcastMatmulImpl(DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b) : data_type_(data_type), transpose_a_(transpose_a), transpose_b_(transpose_b) {} ~BroadcastMatmulImpl() override = default; void Launch(Stream* stream, Scalar alpha, size_t num_a_dims, const int64_t* a_dims, const void* a, size_t num_b_dims, const int64_t* b_dims, const void* b, Scalar beta, size_t num_c_dims, const int64_t* c_dims, void* c) override { CHECK_LE(num_a_dims, max_num_dims); CHECK_LE(num_b_dims, max_num_dims); CHECK_LE(num_c_dims, max_num_dims); int64_t m = 0; int64_t n = 0; int64_t k = 0; int64_t num_batch_dims = 0; int64_t broadcast_batch_dims[max_num_dims]{}; int64_t a_batch_dims[max_num_dims]{}; int64_t b_batch_dims[max_num_dims]{}; int64_t c_batch_dims[max_num_dims]{}; Simplify(num_a_dims, a_dims, num_b_dims, b_dims, num_c_dims, c_dims, transpose_a_, transpose_b_, &m, &n, &k, &num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims); LaunchBroadcastMatmul(stream, data_type_, transpose_a_, transpose_b_, num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims, m, n, k, alpha, a, b, beta, c); } private: DataType data_type_; BlasTransposeType transpose_a_; BlasTransposeType transpose_b_; }; } // namespace } // namespace internal } // namespace broadcast_matmul } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_BROADCAST_MATMUL_H_ ================================================ FILE: oneflow/core/ep/common/primitive/broadcast_simplify_dims_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
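The Simplify routine above canonicalizes a broadcast matmul: it derives m, n and k from the two trailing dimensions, pads and merges compatible batch dimensions, and, when the last remaining batch dimension broadcasts only on b while a is not transposed, folds that batch dimension into m so the whole call degenerates to a single GEMM. A minimal sketch of that folding, with dimensions chosen purely for illustration (not taken from the original source):

// Illustrative trace of broadcast_matmul::Simplify.
int64_t a_dims[3] = {2, 4, 5};  // batched lhs
int64_t b_dims[2] = {5, 6};     // unbatched rhs, broadcast over the batch dim
int64_t c_dims[3] = {2, 4, 6};
int64_t m = 0, n = 0, k = 0, num_batch_dims = 0;
int64_t broadcast_batch[3]{}, a_batch[3]{}, b_batch[3]{}, c_batch[3]{};
broadcast_matmul::Simplify(3, a_dims, 2, b_dims, 3, c_dims, BlasTransposeType::N,
                           BlasTransposeType::N, &m, &n, &k, &num_batch_dims,
                           broadcast_batch, a_batch, b_batch, c_batch);
// The batch dimension of size 2 broadcasts only on b, so it is folded into m:
// m == 8, n == 6, k == 5, num_batch_dims == 0, i.e. one plain 8x5 * 5x6 matmul.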
*/ #include "oneflow/core/ep/common/primitive/util.h" #include namespace oneflow { namespace ep { namespace primitive { namespace { template void TestSimplifyBroadcastDims(size_t num_src0_dims, const int64_t* src0_dims, size_t num_src1_dims, const int64_t* src1_dims, size_t expected_num_dims, const int64_t* expected_src0_dims, const int64_t* expected_src1_dims, const int64_t* expected_dst_dims) { size_t simplified_num_dims = 0; int64_t simplified_src0_dims[max_num_dims]{}; int64_t simplified_src1_dims[max_num_dims]{}; int64_t simplified_dst_dims[max_num_dims]{}; SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, &simplified_num_dims, simplified_src0_dims, simplified_src1_dims, simplified_dst_dims); ASSERT_EQ(simplified_num_dims, expected_num_dims); for (size_t i = 0; i < simplified_num_dims; ++i) { ASSERT_EQ(simplified_src0_dims[i], expected_src0_dims[i]); ASSERT_EQ(simplified_src1_dims[i], expected_src1_dims[i]); ASSERT_EQ(simplified_dst_dims[i], expected_dst_dims[i]); } } TEST(Broadcast, SimplifyBroadcastDims) { constexpr size_t max_num_dims = 8; const size_t num_src0_dims_1 = 4; const size_t num_src1_dims_1 = 5; int64_t src0_dims_1[max_num_dims]{2, 5, 10, 5}; int64_t src1_dims_1[max_num_dims]{5, 1, 5, 10, 1}; const size_t simplified_num_dims_1 = 4; int64_t simplified_src0_dims_1[max_num_dims]{1, 2, 50, 5}; int64_t simplified_src1_dims_1[max_num_dims]{5, 1, 50, 1}; int64_t simplified_dst_dims_1[max_num_dims]{5, 2, 50, 5}; TestSimplifyBroadcastDims( num_src0_dims_1, src0_dims_1, num_src1_dims_1, src1_dims_1, simplified_num_dims_1, simplified_src0_dims_1, simplified_src1_dims_1, simplified_dst_dims_1); const size_t num_src0_dims_2 = 4; const size_t num_src1_dims_2 = 1; int64_t src0_dims_2[max_num_dims]{10, 5, 1, 5}; int64_t src1_dims_2[max_num_dims]{5}; const size_t simplified_num_dims_2 = 2; int64_t simplified_src0_dims_2[max_num_dims]{50, 5}; int64_t simplified_src1_dims_2[max_num_dims]{1, 5}; int64_t simplified_dst_dims_2[max_num_dims]{50, 5}; TestSimplifyBroadcastDims( num_src0_dims_2, src0_dims_2, num_src1_dims_2, src1_dims_2, simplified_num_dims_2, simplified_src0_dims_2, simplified_src1_dims_2, simplified_dst_dims_2); const size_t num_src0_dims_3 = 4; const size_t num_src1_dims_3 = 1; int64_t src0_dims_3[max_num_dims]{2, 5, 10, 5}; int64_t src1_dims_3[max_num_dims]{1}; const size_t simplified_num_dims_3 = 1; int64_t simplified_src0_dims_3[max_num_dims]{500}; int64_t simplified_src1_dims_3[max_num_dims]{1}; int64_t simplified_dst_dims_3[max_num_dims]{500}; TestSimplifyBroadcastDims( num_src0_dims_3, src0_dims_3, num_src1_dims_3, src1_dims_3, simplified_num_dims_3, simplified_src0_dims_3, simplified_src1_dims_3, simplified_dst_dims_3); } } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/common/primitive/constant_pad.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_CONSTANT_PAD_H_ #define ONEFLOW_CORE_PRIMITIVE_COMMON_CONSTANT_PAD_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/fast_integer_math.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace ep { namespace primitive { namespace { constexpr int32_t kMaxNumDims = 8; constexpr int32_t Min(int32_t a, int32_t b) { return a < b ? a : b; } constexpr int32_t kMaxPackBytes = 128 / 8; template constexpr int32_t GetMaxPackSize() { return Min(kMaxPackBytes / sizeof(T), 8); } template struct GetPackType { using type = typename std::aligned_storage::type; }; template using PackType = typename GetPackType::type; template union Pack { static_assert(sizeof(PackType) == sizeof(T) * pack_size, ""); explicit OF_DEVICE_FUNC Pack(T value) { #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < pack_size; i++) { elem[i] = value; } } T elem[pack_size]; PackType storage; }; template T GetValue(Scalar value) { return value.Value(); } template class OffsetToIndexCalculator { public: OffsetToIndexCalculator() {} template OF_DEVICE_FUNC explicit OffsetToIndexCalculator(T d0, Ts... dims) { constexpr int n = 1 + sizeof...(dims); static_assert(n <= N, ""); T dims_arr[n] = {d0, static_cast(dims)...}; InitFastIntegerMath(dims_arr, n); } OF_DEVICE_FUNC explicit OffsetToIndexCalculator(const T* dims) { InitFastIntegerMath(dims, N); } template OF_DEVICE_FUNC explicit OffsetToIndexCalculator(const U* dims) { T dims_arr[N]; for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; } InitFastIntegerMath(dims_arr, N); } OF_DEVICE_FUNC explicit OffsetToIndexCalculator(const T* dims, int n) { InitFastIntegerMath(dims, n); } template OF_DEVICE_FUNC explicit OffsetToIndexCalculator(const U* dims, int n) { T dims_arr[N]; for (int i = 0; i < N; ++i) { if (i < n) { dims_arr[i] = dims[i]; } } InitFastIntegerMath(dims_arr, n); } ~OffsetToIndexCalculator() = default; OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const { T remaining = offset; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N - 1; ++i) { const T idx = math_helper_[i].divides(remaining); index[i] = idx; remaining = remaining - math_helper_[i].mul(idx); } index[N - 1] = remaining; } OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const { assert(n <= N); T remaining = offset; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < N; ++i) { if (i == n - 1) { break; } if (i < n - 1) { const T idx = math_helper_[i].divides(remaining); index[i] = idx; remaining = remaining - math_helper_[i].mul(idx); } } index[n - 1] = remaining; } template OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T& d0, Ts&... 
others) const { constexpr int n = 1 + sizeof...(others); static_assert(n <= N, ""); T* index[n] = {&d0, &others...}; T remaining = offset; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int i = 0; i < n - 1; ++i) { const T idx = math_helper_[i].divides(remaining); *index[i] = idx; remaining = remaining - math_helper_[i].mul(idx); } if (n == N) { *index[n - 1] = remaining; } else { *index[n - 1] = math_helper_[n - 1].divides(remaining); } } OF_DEVICE_FUNC constexpr int Size() const { return N; } private: OF_DEVICE_FUNC void InitFastIntegerMath(const T* dims, const int n) { T stride_arr[N]; for (int i = n - 1; i < N; ++i) { stride_arr[i] = 1; math_helper_[i] = FastIntegerMath(1); } for (int i = n - 2; i >= 0; --i) { stride_arr[i] = dims[i + 1] * stride_arr[i + 1]; math_helper_[i] = FastIntegerMath(stride_arr[i]); } } FastIntegerMath math_helper_[N]; }; template struct ConstantPadParams { NdIndexOffsetHelper src_index_helper; OffsetToIndexCalculator dst_index_helper; IndexType valid_start[num_dims]; IndexType valid_end[num_dims]; IndexType elem_cnt{}; const void* src{}; void* dst{}; }; template size_t GetLaunchPackSize(size_t num_dims, void* dst, const int64_t* dst_dims, const void* src, const int64_t* src_dims, const int64_t* padding_before, const int64_t* padding_after) { static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, ""); const int64_t last_dst_dim_size = dst_dims[num_dims - 1]; const int64_t last_src_dim_size = src_dims[num_dims - 1]; const int64_t last_padding_before_size = padding_before[num_dims - 1]; const int64_t last_padding_after_size = padding_after[num_dims - 1]; auto src_ptr = reinterpret_cast(src); auto dst_ptr = reinterpret_cast(dst); for (size_t size = max_pack_size; size > 1; size /= 2) { if (last_dst_dim_size % size == 0 && last_src_dim_size % size == 0 && last_padding_before_size % size == 0 && last_padding_after_size % size == 0 && src_ptr % size == 0 && dst_ptr % size == 0) { return size; } } return 1; } void SimplifyPadDims(size_t num_dims, const int64_t* src_dims, const int64_t* padding_before, const int64_t* padding_after, size_t* simplified_num_dims, int64_t* simplified_dst_dims, int64_t* simplified_src_dims, int64_t* simplified_padding_before, int64_t* simplified_padding_after) { CHECK_NE(num_dims, 0); size_t valid_num_dims = 0; FOR_RANGE(size_t, i, 0, num_dims) { const int64_t dst_dim = src_dims[i] + padding_before[i] + padding_after[i]; if ((i != 0) && (padding_before[i] == 0 && padding_after[i] == 0)) { simplified_dst_dims[valid_num_dims - 1] *= dst_dim; simplified_src_dims[valid_num_dims - 1] *= src_dims[i]; simplified_padding_before[valid_num_dims - 1] *= src_dims[i]; simplified_padding_after[valid_num_dims - 1] *= src_dims[i]; } else { simplified_dst_dims[valid_num_dims] = dst_dim; simplified_src_dims[valid_num_dims] = src_dims[i]; simplified_padding_before[valid_num_dims] = padding_before[i]; simplified_padding_after[valid_num_dims] = padding_after[i]; valid_num_dims += 1; } } *simplified_num_dims = valid_num_dims; } } // namespace } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_PRIMITIVE_COMMON_CONSTANT_PAD_H_ ================================================ FILE: oneflow/core/ep/common/primitive/copy_nd.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
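SimplifyPadDims above folds every trailing dimension that carries no padding into its predecessor, scaling the surviving padding amounts by the folded extent, so the pad kernel only iterates over dimensions that actually have a border. A short worked trace with example values (not taken from the original source):

// src_dims       = {4, 8, 16}
// padding_before = {0, 1, 0}, padding_after = {0, 1, 0}
// dim 0: kept                    -> dst 4,   src 4,   pad 0 / 0
// dim 1: padded, kept            -> dst 10,  src 8,   pad 1 / 1
// dim 2: unpadded, merged into 1 -> dst 160, src 128, pad 16 / 16
// Result: simplified_num_dims == 2; the pack-size helper can then try a
// vector pack (up to kMaxPackBytes) against the last dim, the paddings,
// and the src/dst pointer alignment.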
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_COPY_ND_H_ #define ONEFLOW_CORE_PRIMITIVE_COMMON_COPY_ND_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace ep { namespace primitive { namespace { template struct CopyNdKernelParams { NdIndexOffsetHelper src_index_helper; NdIndexOffsetHelper dst_index_helper; NdIndexOffsetHelper copy_index_helper; IndexType dst_pos[num_dims]; IndexType src_pos[num_dims]; IndexType count{}; const void* src{}; void* dst{}; }; template size_t GetMovementSize(size_t elem_size, size_t num_dims, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) { static_assert(max_movement_size > 0 && (max_movement_size & (max_movement_size - 1)) == 0, ""); CHECK_GT(elem_size, 0); CHECK_EQ((elem_size & (elem_size - 1)), 0); CHECK_EQ(max_movement_size % elem_size, 0); const int64_t last_dst_dim_size = dst_dims[num_dims - 1] * elem_size; const int64_t last_dst_pos = dst_pos[num_dims - 1] * elem_size; const int64_t last_src_dim_size = src_dims[num_dims - 1] * elem_size; const int64_t last_src_pos = src_pos[num_dims - 1] * elem_size; const int64_t last_extent = extent[num_dims - 1] * elem_size; auto src_ptr = reinterpret_cast(src); auto dst_ptr = reinterpret_cast(dst); for (size_t size = max_movement_size; size > elem_size; size /= 2) { if (last_dst_dim_size % size == 0 && last_dst_pos % size == 0 && last_src_dim_size % size == 0 && last_src_pos % size == 0 && last_extent % size == 0 && src_ptr % size == 0 && dst_ptr % size == 0) { return size; } } return elem_size; } void SimplifyCopyNdDims(size_t num_dims, const int64_t* dst_dims, const int64_t* dst_pos, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent, size_t* simplified_num_dims, int64_t* simplified_dst_dims, int64_t* simplified_dst_pos, int64_t* simplified_src_dims, int64_t* simplified_src_pos, int64_t* simplified_extent) { CHECK_NE(num_dims, 0); size_t valid_num_dims = 0; FOR_RANGE(size_t, i, 0, num_dims) { if ((i != 0) && (dst_dims[i] == src_dims[i]) && (dst_dims[i] == extent[i]) && (src_pos[i] == 0) && (dst_pos[i] == 0)) { simplified_dst_dims[valid_num_dims - 1] *= extent[i]; simplified_dst_pos[valid_num_dims - 1] *= extent[i]; simplified_src_dims[valid_num_dims - 1] *= extent[i]; simplified_src_pos[valid_num_dims - 1] *= extent[i]; simplified_extent[valid_num_dims - 1] *= extent[i]; } else { simplified_dst_dims[valid_num_dims] = dst_dims[i]; simplified_dst_pos[valid_num_dims] = dst_pos[i]; simplified_src_dims[valid_num_dims] = src_dims[i]; simplified_src_pos[valid_num_dims] = src_pos[i]; simplified_extent[valid_num_dims] = extent[i]; valid_num_dims += 1; } } *simplified_num_dims = valid_num_dims; } constexpr size_t kMaxMovementSize = 16; constexpr size_t kMaxNumDims = 8; template void LaunchKernel(Stream* stream, CopyNdKernelParams params); template void LaunchKernel(Stream* stream, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent, size_t 
count) { CopyNdKernelParams params; params.dst_index_helper = NdIndexOffsetHelper(dst_dims); params.src_index_helper = NdIndexOffsetHelper(src_dims); params.copy_index_helper = NdIndexOffsetHelper(extent); for (size_t i = 0; i < num_dims; ++i) { params.dst_pos[i] = dst_pos[i]; params.src_pos[i] = src_pos[i]; } params.src = src; params.dst = dst; params.count = static_cast(count); LaunchKernel(stream, params); } template void DispatchIndexType(Stream* stream, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) { size_t count = 1; for (size_t i = 0; i < num_dims; ++i) { count *= extent[i]; } if (count < GetMaxVal()) { LaunchKernel(stream, dst, dst_dims, dst_pos, src, src_dims, src_pos, extent, count); } else { LaunchKernel(stream, dst, dst_dims, dst_pos, src, src_dims, src_pos, extent, count); } } template void DispatchMovementSize(Stream* stream, size_t movement_size, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) { void (*func)(Stream* /*stream*/, void* /*dst*/, const int64_t* /*dst_dims*/, const int64_t* /*dst_pos*/, const void* /*src*/, const int64_t* /*src_dims*/, const int64_t* /*src_pos*/, const int64_t* /*extent*/) = nullptr; if (movement_size == 1) { func = DispatchIndexType; } else if (movement_size == 2) { func = DispatchIndexType; } else if (movement_size == 4) { func = DispatchIndexType; } else if (movement_size == 8) { func = DispatchIndexType; } else if (movement_size == 16) { func = DispatchIndexType; } else { UNIMPLEMENTED(); } func(stream, dst, dst_dims, dst_pos, src, src_dims, src_pos, extent); } void LaunchWithSimplified(Stream* stream, size_t movement_size, size_t num_dims, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) { void (*func)(Stream* /*stream*/, size_t /*movement_size*/, void* /*dst*/, const int64_t* /*dst_dims*/, const int64_t* /*dst_pos*/, const void* /*src*/, const int64_t* /*src_dims*/, const int64_t* /*src_pos*/, const int64_t* /*extent*/) = nullptr; if (num_dims == 1) { func = DispatchMovementSize<1>; } else if (num_dims == 2) { func = DispatchMovementSize<2>; } else if (num_dims == 3) { func = DispatchMovementSize<3>; } else if (num_dims == 4) { func = DispatchMovementSize<4>; } else if (num_dims == 5) { func = DispatchMovementSize<5>; } else if (num_dims == 6) { func = DispatchMovementSize<6>; } else if (num_dims == 7) { func = DispatchMovementSize<7>; } else if (num_dims == 8) { func = DispatchMovementSize<8>; } else { UNIMPLEMENTED(); } func(stream, movement_size, dst, dst_dims, dst_pos, src, src_dims, src_pos, extent); } template void SimplifyCopyNd(size_t num_dims, const int64_t* dst_dims, const int64_t* dst_pos, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent, size_t* simplified_num_dims, int64_t* simplified_dst_dims, int64_t* simplified_dst_pos, int64_t* simplified_src_dims, int64_t* simplified_src_pos, int64_t* simplified_extent, size_t elem_size, void* dst, const void* src, size_t* movement_size) { SimplifyCopyNdDims(num_dims, dst_dims, dst_pos, src_dims, src_pos, extent, simplified_num_dims, simplified_dst_dims, simplified_dst_pos, simplified_src_dims, simplified_src_pos, simplified_extent); *movement_size = GetMovementSize( elem_size, *simplified_num_dims, dst, simplified_dst_dims, simplified_dst_pos, src, simplified_src_dims, 
simplified_src_pos, simplified_extent); size_t movement_elem_num = *movement_size / elem_size; simplified_dst_dims[*simplified_num_dims - 1] /= movement_elem_num; simplified_dst_pos[*simplified_num_dims - 1] /= movement_elem_num; simplified_src_dims[*simplified_num_dims - 1] /= movement_elem_num; simplified_src_pos[*simplified_num_dims - 1] /= movement_elem_num; simplified_extent[*simplified_num_dims - 1] /= movement_elem_num; } void SimplifyThenLaunch(Stream* stream, DataType data_type, size_t num_dims, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) { CHECK_GT(num_dims, 0) << "num_dims must greater than 0"; CHECK_LE(num_dims, kMaxNumDims); size_t simplified_num_dims = 0; int64_t simplified_dst_dims[kMaxNumDims]; int64_t simplified_dst_pos[kMaxNumDims]; int64_t simplified_src_dims[kMaxNumDims]; int64_t simplified_src_pos[kMaxNumDims]; int64_t simplified_extent[kMaxNumDims]; size_t movement_size; SimplifyCopyNd(num_dims, dst_dims, dst_pos, src_dims, src_pos, extent, &simplified_num_dims, simplified_dst_dims, simplified_dst_pos, simplified_src_dims, simplified_src_pos, simplified_extent, GetSizeOfDataType(data_type), dst, src, &movement_size); LaunchWithSimplified(stream, movement_size, simplified_num_dims, dst, simplified_dst_dims, simplified_dst_pos, src, simplified_src_dims, simplified_src_pos, simplified_extent); } } // namespace } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_PRIMITIVE_COMMON_COPY_ND_H_ ================================================ FILE: oneflow/core/ep/common/primitive/elementwise_unary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
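The copy_nd path above works in three steps: SimplifyCopyNdDims merges dimensions that are copied contiguously, GetMovementSize picks the widest power-of-two move that divides every last-dimension quantity and both base pointers, and SimplifyCopyNd then rescales the last dimension so the kernel counts moves rather than elements. A worked trace with example values (not taken from the original source):

// float sub-region copy (elem_size = 4):
//   dst_dims = {8, 128}, src_dims = {8, 256}, extent = {8, 128}
//   dst_pos  = {0, 0},   src_pos  = {0, 64}
// Dim 1 cannot be merged (dst and src widths differ), so the copy stays 2-D.
// With 16-byte aligned pointers, every last-dim quantity in bytes
// (512, 0, 1024, 256, 512) divides 16, so movement_size = 16 and the last
// dimension shrinks by 16 / 4 = 4:
//   dst_dims = {8, 32}, src_dims = {8, 64}, src_pos = {0, 16}, extent = {8, 32}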
*/ #ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_ELEMENTWISE_UNARY_H_ #define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_ELEMENTWISE_UNARY_H_ #include "oneflow/core/ep/include/primitive/elementwise_unary.h" namespace oneflow { namespace ep { namespace primitive { #define UNARY_MATH_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRelu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIdentity) #define UNARY_FLOATING_MATH_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kElu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCelu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kGelu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSwish) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSigmoid) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardShrink) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardTanh) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLeakyRelu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kMish) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSelu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSilu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftShrink) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftSign) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftPlus) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTanh) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kThreshold) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAcos) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAcosh) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAsin) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAsinh) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAtan) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAtanh) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCeil) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCos) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCosh) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kDigamma) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTrigamma) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErf) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErfc) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExp) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExp2) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExpm1) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFloor) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLgamma) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog2) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog10) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog1p) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogSigmoid) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNegative) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReciprocal) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReciprocalNoNan) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRint) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRound) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRsqrt) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSigmoid) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSign) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSin) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSinh) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSqrt) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSign) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSquare) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTan) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTrunc) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNotEqualZero) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNanAssign) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFastGelu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSquareReLU) #define UNARY_COMPLEX_C2C_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kConj) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSqrt) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNegative) #define UNARY_COMPLEX_C2R_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReal) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kImag) #define UNARY_COMPLEX_R2C_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRealGrad) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kImagGrad) #define UNARY_INT_MATH_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs) #define UNARY_LOGICAL_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogicalNot) #define UNARY_BITWISE_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kBitwiseNot) #define UNARY_UTILS_OP_SEQ \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsInf) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsNan) \ 
  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsFinite)

}  // namespace primitive
}  // namespace ep
}  // namespace oneflow

#endif  // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_ELEMENTWISE_UNARY_H_

================================================
FILE: oneflow/core/ep/common/primitive/matmul.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/matmul.h"
#include "oneflow/core/ep/include/primitive/batch_matmul.h"

namespace oneflow {

namespace ep {
namespace primitive {

namespace {

class MatmulImpl : public Matmul {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MatmulImpl);
  explicit MatmulImpl(std::unique_ptr<BatchMatmul>&& batch_matmul)
      : batch_matmul_(std::move(batch_matmul)) {}
  ~MatmulImpl() override = default;

  void Launch(Stream* stream, size_t m, size_t n, size_t k, Scalar alpha, const void* a,
              const void* b, Scalar beta, void* c) override {
    batch_matmul_->Launch(stream, 1, m, n, k, alpha, a, b, beta, c);
  }

 private:
  std::unique_ptr<BatchMatmul> batch_matmul_;
};

template<DeviceType device_type>
class MatmulFactoryImpl : public MatmulFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MatmulFactoryImpl);
  MatmulFactoryImpl() = default;
  ~MatmulFactoryImpl() override = default;

  std::unique_ptr<Matmul> New(DataType data_type, BlasTransposeType transpose_a,
                              BlasTransposeType transpose_b) override {
    auto batch_matmul =
        NewPrimitive<BatchMatmulFactory>(device_type, data_type, transpose_a, transpose_b);
    if (!batch_matmul) { return nullptr; }
    return std::make_unique<MatmulImpl>(std::move(batch_matmul));
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, MatmulFactory, MatmulFactoryImpl<DeviceType::kCPU>);

#ifdef WITH_CUDA
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MatmulFactory, MatmulFactoryImpl<DeviceType::kCUDA>);
#endif  // WITH_CUDA

}  // namespace

}  // namespace primitive
}  // namespace ep
}  // namespace oneflow

================================================
FILE: oneflow/core/ep/common/primitive/permute.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
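MatmulImpl above is deliberately thin: a plain matmul is launched as a batch matmul with batch count 1, so each backend only has to register a BatchMatmul primitive. A hypothetical caller-side sketch, assuming the generic NewPrimitive<...> helper from the ep primitive headers and an already-acquired Stream* named stream:

// Sketch only; stream, a, b, c and the problem size m, n, k are assumed to exist.
auto matmul = NewPrimitive<MatmulFactory>(DeviceType::kCPU, DataType::kFloat,
                                          BlasTransposeType::N, BlasTransposeType::N);
if (matmul) {
  // c[m, n] = alpha * a[m, k] * b[k, n] + beta * c[m, n]
  matmul->Launch(stream, m, n, k, /*alpha=*/Scalar(1.0), a, b, /*beta=*/Scalar(0.0), c);
}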
*/ #ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_H_ #define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace ep { namespace primitive { namespace permute { template size_t GetMovementSize(size_t elem_size, size_t num_dims, const int64_t* src_dims, const void* src, const int* permutation, void* dst) { static_assert(max_movement_size > 0 && (max_movement_size & (max_movement_size - 1)) == 0, ""); CHECK_GT(elem_size, 0); CHECK_EQ((elem_size & (elem_size - 1)), 0); CHECK_EQ(max_movement_size % elem_size, 0); if (permutation[num_dims - 1] == num_dims - 1) { const int64_t last_dim_size = src_dims[num_dims - 1] * elem_size; auto src_ptr = reinterpret_cast(src); auto dst_ptr = reinterpret_cast(dst); for (size_t size = max_movement_size; size > elem_size; size /= 2) { if (last_dim_size % size == 0 && src_ptr % size == 0 && dst_ptr % size == 0) { return size; } } } return elem_size; } template void SimplifyPermutation(size_t num_dims, const int64_t* src_dims, const int* permutation, size_t* simplified_num_dims, int64_t* simplified_src_dims, int* simplified_permutation) { CHECK_NE(num_dims, 0); int64_t coalesced_dims[max_num_dims]; size_t start_permutation_index = 0; while (start_permutation_index < num_dims) { const size_t start_dim_index = permutation[start_permutation_index]; coalesced_dims[start_dim_index] = src_dims[start_dim_index]; size_t end_permutation_index = start_permutation_index + 1; while (end_permutation_index < num_dims && permutation[end_permutation_index] == permutation[end_permutation_index - 1] + 1) { const size_t end_dim_index = permutation[end_permutation_index]; coalesced_dims[start_dim_index] *= src_dims[end_dim_index]; coalesced_dims[end_dim_index] = 1; end_permutation_index += 1; } start_permutation_index = end_permutation_index; } size_t valid_num_dims = 0; int mapping[max_num_dims]; for (size_t i = 0; i < num_dims; ++i) { const int src_dim = coalesced_dims[i]; if (src_dim == 1) { mapping[i] = -1; } else { mapping[i] = valid_num_dims; simplified_src_dims[valid_num_dims] = src_dim; valid_num_dims += 1; } } if (valid_num_dims == 0) { *simplified_num_dims = 1; simplified_src_dims[0] = 1; simplified_permutation[0] = 0; } else { *simplified_num_dims = valid_num_dims; size_t permutation_index = 0; for (size_t i = 0; i < num_dims; ++i) { const int mapped = mapping[permutation[i]]; if (mapped >= 0) { simplified_permutation[permutation_index] = mapped; permutation_index += 1; } } } } template void SimplifyPermutation(size_t num_dims, const int64_t* src_dims, const int* permutation, size_t* simplified_num_dims, int64_t* simplified_src_dims, int* simplified_permutation, size_t elem_size, const void* src, void* dst, size_t* movement_size) { const size_t pre_simplified_movement_size = GetMovementSize(elem_size, num_dims, src_dims, src, permutation, dst); int64_t tmp_dims[max_num_dims]; for (size_t i = 0; i < num_dims; ++i) { tmp_dims[i] = src_dims[i]; } tmp_dims[num_dims - 1] /= (pre_simplified_movement_size / elem_size); SimplifyPermutation(num_dims, tmp_dims, permutation, simplified_num_dims, simplified_src_dims, simplified_permutation); *movement_size = GetMovementSize(pre_simplified_movement_size, *simplified_num_dims, simplified_src_dims, src, simplified_permutation, dst); simplified_src_dims[*simplified_num_dims - 1] /= (*movement_size / pre_simplified_movement_size); } } // namespace permute } // namespace primitive } // namespace ep } // 
namespace oneflow #endif // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_H_ ================================================ FILE: oneflow/core/ep/common/primitive/permute_impl.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_IMPL_H_ #define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_IMPL_H_ #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/common/primitive/permute.h" namespace oneflow { namespace ep { namespace primitive { namespace permute { namespace internal { namespace { template struct PermuteKernelParams { NdIndexOffsetHelper src_index_helper; NdIndexOffsetHelper dst_index_helper; int permutation[num_dims]{}; IndexType count{}; const void* src{}; void* dst{}; }; constexpr size_t kMaxMovementSize = 16; constexpr size_t kMaxNumDims = 8; template PermuteKernelParams MakePermuteParams(const int64_t* src_dims, const void* src, const int* permutation, void* dst, size_t count) { PermuteKernelParams params; params.src_index_helper = NdIndexOffsetHelper(src_dims); int64_t dst_dims[num_dims]; for (size_t i = 0; i < num_dims; ++i) { dst_dims[i] = src_dims[permutation[i]]; } params.dst_index_helper = NdIndexOffsetHelper(dst_dims); for (size_t i = 0; i < num_dims; ++i) { params.permutation[i] = permutation[i]; } params.src = src; params.dst = dst; params.count = static_cast(count); return params; } template void LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation, void* dst, size_t count); template void DispatchIndexType(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation, void* dst) { size_t count = 1; for (size_t i = 0; i < num_dims; ++i) { count *= src_dims[i]; } if (count < GetMaxVal()) { LaunchKernel(stream, src_dims, src, permutation, dst, count); } else { LaunchKernel(stream, src_dims, src, permutation, dst, count); } } template void DispatchMovementSize(Stream* stream, size_t movement_size, const int64_t* src_dims, const void* src, const int* permutation, void* dst) { void (*func)(Stream* /*stream*/, const int64_t* /*src_dims*/, const void* /*src*/, const int* /*permutation*/, void* /*dst*/) = nullptr; if (movement_size == 1) { func = DispatchIndexType; } else if (movement_size == 2) { func = DispatchIndexType; } else if (movement_size == 4) { func = DispatchIndexType; } else if (movement_size == 8) { func = DispatchIndexType; } else if (movement_size == 16) { func = DispatchIndexType; } else { UNIMPLEMENTED(); } func(stream, src_dims, src, permutation, dst); } void LaunchWithSimplified(Stream* stream, size_t movement_size, size_t num_dims, const int64_t* src_dims, const void* src, const int* permutation, void* dst) { void (*func)(Stream* /*stream*/, size_t /*movement_size*/, const int64_t* /*src_dims*/, const void* /*src*/, const int* /*permutation*/, void* /*dst*/) = nullptr; if (num_dims == 1) { func = DispatchMovementSize<1>; } else if (num_dims == 2) { func = 
DispatchMovementSize<2>; } else if (num_dims == 3) { func = DispatchMovementSize<3>; } else if (num_dims == 4) { func = DispatchMovementSize<4>; } else if (num_dims == 5) { func = DispatchMovementSize<5>; } else if (num_dims == 6) { func = DispatchMovementSize<6>; } else if (num_dims == 7) { func = DispatchMovementSize<7>; } else if (num_dims == 8) { func = DispatchMovementSize<8>; } else { UNIMPLEMENTED(); } func(stream, movement_size, src_dims, src, permutation, dst); } void SimplifyThenLaunch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims, const void* src, const int* permutation, void* dst) { CHECK_LE(num_dims, kMaxNumDims); CHECK_GT(num_dims, 0); size_t simplified_num_dims = 0; int64_t simplified_src_dims[kMaxNumDims]; int simplified_permutation[kMaxNumDims]; size_t movement_size = 0; SimplifyPermutation( num_dims, src_dims, permutation, &simplified_num_dims, simplified_src_dims, simplified_permutation, GetSizeOfDataType(data_type), src, dst, &movement_size); LaunchWithSimplified(stream, movement_size, simplified_num_dims, simplified_src_dims, src, simplified_permutation, dst); } } // namespace } // namespace internal } // namespace permute } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_IMPL_H_ ================================================ FILE: oneflow/core/ep/common/primitive/permute_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
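The permute implementation above dispatches twice, first on the simplified rank and then on the movement size, and SimplifyPermutation keeps wide moves possible whenever the innermost dimension stays innermost. A short trace with example values (not taken from the original source):

// src_dims = {32, 64, 128} (float), permutation = {1, 0, 2}.
// The last dimension stays last and 128 * 4 bytes divides 16, so with 16-byte
// aligned pointers the movement size becomes 16 (a float4-style pack) and the
// simplified problem is dims {32, 64, 32} with permutation {1, 0, 2}: the
// kernel moves 16-byte packs instead of single floats.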
*/ #include "oneflow/core/ep/common/primitive/permute.h" #include namespace oneflow { namespace ep { namespace primitive { namespace permute { namespace { template void TestSimplifyPermutation(size_t num_dims, const int64_t* src_dims, const int* permutation, size_t expected_num_dims, const int64_t* expected_src_dims, const int* expected_permutation) { size_t simplified_num_dims = 0; int64_t simplified_src_dims[max_num_dims]{}; int simplified_permutation[max_num_dims]{}; SimplifyPermutation(num_dims, src_dims, permutation, &simplified_num_dims, simplified_src_dims, simplified_permutation); ASSERT_EQ(simplified_num_dims, expected_num_dims); for (size_t i = 0; i < simplified_num_dims; ++i) { ASSERT_EQ(simplified_src_dims[i], expected_src_dims[i]); ASSERT_EQ(simplified_permutation[i], expected_permutation[i]); } } TEST(Permute, SimplifyPermutation) { constexpr size_t max_num_dims = 8; const size_t num_dims_1 = 5; int64_t src_dims_1[max_num_dims]{1, 2, 2, 1, 2}; int permutation_1[max_num_dims]{0, 1, 3, 4, 2}; const size_t simplified_num_dims_1 = 3; int64_t simplified_src_dims_1[max_num_dims]{2, 2, 2}; int simplified_permutation_1[max_num_dims]{0, 2, 1}; TestSimplifyPermutation(num_dims_1, src_dims_1, permutation_1, simplified_num_dims_1, simplified_src_dims_1, simplified_permutation_1); const size_t num_dims_2 = 4; int64_t src_dims_2[max_num_dims]{5, 6, 7, 8}; int permutation_2[max_num_dims]{2, 3, 0, 1}; const size_t simplified_num_dims_2 = 2; int64_t simplified_src_dims_2[max_num_dims]{5 * 6, 7 * 8}; int simplified_permutation_2[max_num_dims]{1, 0}; TestSimplifyPermutation(num_dims_2, src_dims_2, permutation_2, simplified_num_dims_2, simplified_src_dims_2, simplified_permutation_2); const size_t num_dims_3 = 4; int64_t src_dims_3[max_num_dims]{5, 6, 7, 8}; int permutation_3[max_num_dims]{0, 1, 2, 3}; const size_t simplified_num_dims_3 = 1; int64_t simplified_src_dims_3[max_num_dims]{5 * 6 * 7 * 8}; int simplified_permutation_3[max_num_dims]{0}; TestSimplifyPermutation(num_dims_3, src_dims_3, permutation_3, simplified_num_dims_3, simplified_src_dims_3, simplified_permutation_3); } } // namespace } // namespace permute } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/common/primitive/unary_functor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UNARY_FUNCTOR_H_ #define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UNARY_FUNCTOR_H_ #include "oneflow/core/ep/include/primitive/unary_op.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { template struct UnaryFunctor; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast( (src > static_cast(0.0)) ? src : alpha * (exp(src) - static_cast(1))); } const Src alpha; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value()), inv_alpha(1.0f / attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast( (src > static_cast(0.0)) ? src : alpha * (exp(src * inv_alpha) - static_cast(1))); } const Src alpha; const Src inv_alpha; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { if (src <= static_cast(-3)) { return static_cast(0); } else if (src >= static_cast(3)) { return static_cast(src); } else { return static_cast((src * (src + static_cast(3))) / static_cast(6)); } } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { if (src <= static_cast(-3)) { return static_cast(0); } else if (src >= static_cast(3)) { return static_cast(1); } else { return static_cast(src / static_cast(6) + static_cast(0.5)); } } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : lambd(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return (src <= lambd && src >= -lambd) ? static_cast(0) : static_cast(src); } const Src lambd; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : min_val(attr0.Value()), max_val(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src) const { if (src <= min_val) { return static_cast(min_val); } else if (src >= max_val) { return static_cast(max_val); } else { return static_cast(src); } } const Src min_val; const Src max_val; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast((src > static_cast(0.0)) ? 
src : alpha * src); } const Src alpha; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { Src soft_plus_val = log(static_cast(1) + exp(src)); Src exp_val = exp(soft_plus_val); Src neg_exp_val = exp(-soft_plus_val); Src tanh_val = (exp_val - neg_exp_val) / (exp_val + neg_exp_val); return static_cast(src * tanh_val); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { const Src zero_val = static_cast(0.0); if (src <= zero_val) { return static_cast(zero_val); } else { return static_cast(src); } } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src / (static_cast(1) + exp(-src))); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast((src > static_cast(0.0)) ? src * scale : scale * alpha * (exp(src) - static_cast(1))); } const Src scale = 1.0507009873554804934193349852946; const Src alpha = 1.6732632423543772848170429916717; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src / (static_cast(1) + abs(src))); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : beta(attr0.Value()), threshold(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast( (src * beta) > threshold ? src : log(static_cast(1.0) + exp(src * beta)) / beta); } const Src beta; const Src threshold; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src) const { if (src <= alpha && src >= -alpha) { return static_cast(0); } else if (src > alpha) { return static_cast(src - alpha); } else { return static_cast(src + alpha); } } const Src alpha; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : threshold(attr0.Value()), value(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast((src <= threshold) ? 
value : src); } const Src threshold; const Src value; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(!src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(Src src) const { return false; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(Src src) const { return false; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(Src src) const { return true; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1); OF_DEVICE_FUNC Dst operator()(Src src) const; }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(abs(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC uint8_t operator()(uint8_t src) const { return src; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(exp(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(exp2(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(acos(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(acosh(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(asin(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(asinh(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(atan(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(atanh(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(ceil(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(cos(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(cosh(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(erf(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(erfc(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { 
return static_cast(expm1(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(floor(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(lgamma(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(log(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(log2(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(log10(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(log1p(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(-log(static_cast(1.0) + exp(-src))); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(-src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(static_cast(1.0) / src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { if (abs(src) <= static_cast(0.0)) { return static_cast(0.0); } return static_cast(static_cast(1.0) / src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(rint(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(nearbyint(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(rsqrt(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(static_cast(1.0) / (static_cast(1.0) + exp(-src))); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { const Src zero = static_cast(0.0); if (src > zero) { return static_cast(1.0); } else if (src < zero) { return static_cast(-1.0); } else { return static_cast(0.0); } } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(sin(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(sinh(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(sqrt(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst 
operator()(Src src) const { return static_cast(src * src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(tan(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src != static_cast(0.0)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return std::isnan(src) ? static_cast(0.0) : src; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(~src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(bool src) const { return static_cast(!src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src.real(), -src.imag()}; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src.real()); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src.imag()); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src, 0.0}; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{0.0, src}; } }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UNARY_FUNCTOR_H_ ================================================ FILE: oneflow/core/ep/common/primitive/util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
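Each functor above is constructed from up to two Scalar attributes and applied element-wise by the elementwise-unary primitives. A minimal host-side sketch, assuming the template parameter order <DeviceType, UnaryOp, Dst, Src> (the order is an assumption here, since the template heads are not shown):

// Hypothetical CPU-side use; parameter order is assumed as noted above.
UnaryFunctor<DeviceType::kCPU, UnaryOp::kLeakyRelu, float, float> leaky_relu(
    /*attr0=*/Scalar(0.1), /*attr1=*/Scalar(0.0));
float y_pos = leaky_relu(2.0f);   // 2.0f: positive inputs pass through
float y_neg = leaky_relu(-2.0f);  // -0.2f: negative inputs are scaled by alpha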
*/ #ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_ #define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/util.h" #include "oneflow/core/ep/include/primitive/unary_op.h" namespace oneflow { namespace ep { namespace primitive { inline size_t GetElementCount(size_t num_dims, const int64_t* dims) { size_t count = 1; for (size_t i = 0; i < num_dims; ++i) { count *= dims[i]; } return count; } template bool IsPackSizeSupported(const size_t pack_size, size_t num_dims, const int64_t* dims, const void* ptr) { return (dims[num_dims - 1] % pack_size == 0) && (reinterpret_cast(ptr) % (pack_size * sizeof(T)) == 0); } inline void CheckInplace(size_t num_dims, const int64_t* src_dims_or_strides, const void* src, const int64_t* dst_dims_or_strides, const void* dst) { if (src == dst) { for (int64_t i = 0; i < num_dims; ++i) { CHECK_EQ(src_dims_or_strides[i], dst_dims_or_strides[i]); } } } template inline void SimplifyBroadcastDims(size_t num_src_dims, const int64_t* src_dims, const int64_t* src_strides, size_t num_dst_dims, const int64_t* dst_dims, const int64_t* dst_strides, size_t* simplified_num_dims, int64_t* simplified_src_dims, int64_t* simplified_src_strides, int64_t* simplified_dst_dims, int64_t* simplified_dst_strides) { *simplified_num_dims = 0; std::pair sorted_dst_strides[max_num_dims]; int64_t new_dst_dims[max_num_dims]; int64_t new_src_dims[max_num_dims]; int64_t new_dst_strides[max_num_dims]; int64_t new_src_strides[max_num_dims]; for (size_t i = 0; i < num_dst_dims; i++) { sorted_dst_strides[i] = {dst_strides[i], i}; } std::sort(sorted_dst_strides, sorted_dst_strides + num_dst_dims, [](auto pair1, auto pair2) { return pair1.first > pair2.first; }); const int64_t num_src_padding_dims = num_dst_dims - num_src_dims; // dimension completion int64_t expanded_src_dims[max_num_dims]; int64_t expanded_src_strides[max_num_dims]; for (int64_t i = num_dst_dims - 1; i >= 0; i--) { expanded_src_dims[i] = i < num_src_padding_dims ? 1 : src_dims[i - num_src_padding_dims]; expanded_src_strides[i] = i < num_src_padding_dims ? 
0 : src_strides[i - num_src_padding_dims]; } // dimension permutation for (int64_t i = num_dst_dims - 1; i >= 0; i--) { size_t idx = sorted_dst_strides[i].second; new_dst_dims[i] = dst_dims[idx]; new_dst_strides[i] = dst_strides[idx]; new_src_dims[i] = expanded_src_dims[idx]; new_src_strides[i] = expanded_src_strides[idx]; } // dimension merge bool prev_broadcast_src = false; for (int64_t i = 0; i < num_dst_dims; ++i) { const bool broadcast_src = (new_src_dims[i] == 1); if (new_dst_dims[i] == 1) { continue; } else if (*simplified_num_dims != 0 && prev_broadcast_src == broadcast_src && (new_src_strides[i - 1] == new_src_strides[i] * new_src_dims[i]) && (new_dst_strides[i - 1] == new_dst_strides[i] * new_dst_dims[i])) { simplified_src_dims[*simplified_num_dims - 1] *= new_src_dims[i]; simplified_dst_dims[*simplified_num_dims - 1] *= new_dst_dims[i]; simplified_src_strides[*simplified_num_dims - 1] = new_src_strides[i]; simplified_dst_strides[*simplified_num_dims - 1] = new_dst_strides[i]; } else { simplified_src_dims[*simplified_num_dims] = new_src_dims[i]; simplified_dst_dims[*simplified_num_dims] = new_dst_dims[i]; simplified_src_strides[*simplified_num_dims] = new_src_strides[i]; simplified_dst_strides[*simplified_num_dims] = new_dst_strides[i]; *simplified_num_dims += 1; prev_broadcast_src = broadcast_src; } } if (*simplified_num_dims == 0) { simplified_src_dims[0] = 1; simplified_dst_dims[0] = 1; simplified_src_strides[0] = 1; simplified_dst_strides[0] = 1; *simplified_num_dims = 1; } } inline void SimplifyBroadcastDims(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims, const int64_t* b_dims, size_t num_c_dims, const int64_t* c_dims, size_t* simplified_num_dims, int64_t* simplified_broadcast_dims, int64_t* simplified_a_dims, int64_t* simplified_b_dims, int64_t* simplified_c_dims) { const size_t num_max_dims = std::max(num_a_dims, num_b_dims); auto MakeGetDim = [num_max_dims](size_t num_dims, const int64_t* dims) { const int64_t num_padding_dims = num_max_dims - num_dims; return [num_padding_dims, dims](size_t index) { return index < num_padding_dims ? 
1 : dims[index - num_padding_dims]; }; }; auto GetADim = MakeGetDim(num_a_dims, a_dims); auto GetBDim = MakeGetDim(num_b_dims, b_dims); auto GetCDim = MakeGetDim(num_c_dims, c_dims); *simplified_num_dims = 0; bool prev_broadcast_a = false; bool prev_broadcast_b = false; bool prev_broadcast_c = false; for (int64_t i = 0; i < num_max_dims; ++i) { const int64_t a_dim = GetADim(i); const int64_t b_dim = GetBDim(i); const int64_t c_dim = GetCDim(i); const int64_t broadcast_dim = std::max(std::max(a_dim, b_dim), c_dim); CHECK_GT(broadcast_dim, 0); const bool broadcast_a = (a_dim == 1); const bool broadcast_b = (b_dim == 1); const bool broadcast_c = (c_dim == 1); CHECK((a_dim == broadcast_dim) || broadcast_a); CHECK((b_dim == broadcast_dim) || broadcast_b); CHECK((c_dim == broadcast_dim) || broadcast_c); if (broadcast_dim == 1) { continue; } else if (*simplified_num_dims != 0 && (prev_broadcast_a == broadcast_a && prev_broadcast_b == broadcast_b && prev_broadcast_c == broadcast_c)) { simplified_a_dims[*simplified_num_dims - 1] *= a_dim; simplified_b_dims[*simplified_num_dims - 1] *= b_dim; simplified_c_dims[*simplified_num_dims - 1] *= c_dim; simplified_broadcast_dims[*simplified_num_dims - 1] *= broadcast_dim; } else { simplified_a_dims[*simplified_num_dims] = a_dim; simplified_b_dims[*simplified_num_dims] = b_dim; simplified_c_dims[*simplified_num_dims] = c_dim; simplified_broadcast_dims[*simplified_num_dims] = broadcast_dim; *simplified_num_dims += 1; prev_broadcast_a = broadcast_a; prev_broadcast_b = broadcast_b; prev_broadcast_c = broadcast_c; } } if (*simplified_num_dims == 0) { simplified_a_dims[0] = 1; simplified_b_dims[0] = 1; simplified_c_dims[0] = 1; *simplified_num_dims = 1; } } template inline void SimplifyBroadcastDims(size_t num_src0_dims, const int64_t* src0_dims, size_t num_src1_dims, const int64_t* src1_dims, size_t* simplified_num_dims, int64_t* simplified_src0_dims, int64_t* simplified_src1_dims, int64_t* simplified_dst_dims) { size_t src0_count = GetElementCount(num_src0_dims, src0_dims); size_t src1_count = GetElementCount(num_src1_dims, src1_dims); if (src0_count == 1 || src1_count == 1) { *simplified_num_dims = 1; simplified_src0_dims[0] = src0_count; simplified_src1_dims[0] = src1_count; simplified_dst_dims[0] = std::max(src0_count, src1_count); return; } int64_t dst_dims[max_num_dims]; int64_t broadcast_dims[max_num_dims]; const size_t num_dst_dims = std::max(num_src0_dims, num_src1_dims); for (int64_t i = 0; i < num_dst_dims; ++i) { const int64_t num_src0_padding_dims = num_dst_dims - num_src0_dims; const int64_t num_src1_padding_dims = num_dst_dims - num_src1_dims; size_t src0_dim = i < num_src0_padding_dims ? 1 : src0_dims[i - num_src0_padding_dims]; size_t src1_dim = i < num_src1_padding_dims ? 
1 : src1_dims[i - num_src1_padding_dims]; dst_dims[i] = std::max(src0_dim, src1_dim); } SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, num_dst_dims, dst_dims, simplified_num_dims, broadcast_dims, simplified_src0_dims, simplified_src1_dims, simplified_dst_dims); for (int64_t i = 0; i < *simplified_num_dims; ++i) { CHECK_EQ(broadcast_dims[i], simplified_dst_dims[i]); } } template inline bool InferPermutable(size_t simplified_num_dims, const int64_t* simplified_src_strides, const int64_t* simplified_dst_strides, const int64_t* simplified_src_dims, const int64_t* simplified_dst_dims, int* permutation_list, int64_t* permutation_src_dims, UnaryOp unary_op) { if (unary_op != UnaryOp::kIdentity) { return false; } // all dims of src & dst should be the same for (size_t i = 0; i < simplified_num_dims; i++) { if (simplified_src_dims[i] != simplified_dst_dims[i]) { return false; } } // only simplified_src_strides need to be sorted, simplified_dst_strides has been sorted in // SimplifyBroadcastDims std::pair sorted_src_strides[max_num_dims]; for (size_t i = 0; i < simplified_num_dims; i++) { sorted_src_strides[i] = {simplified_src_strides[i], i}; } std::sort(sorted_src_strides, sorted_src_strides + simplified_num_dims, [](auto pair1, auto pair2) { return pair1.first > pair2.first; }); // src & dst has to be filled with numbers without strides if (sorted_src_strides[simplified_num_dims - 1].first != 1) { return false; } for (size_t i = simplified_num_dims - 1; i > 0; i--) { if (sorted_src_strides[i - 1].first != sorted_src_strides[i].first * simplified_src_dims[sorted_src_strides[i].second]) { return false; } } if (simplified_dst_strides[simplified_num_dims - 1] != 1) { return false; } for (size_t i = simplified_num_dims - 1; i > 0; i--) { if (simplified_dst_strides[i - 1] != simplified_dst_strides[i] * simplified_dst_dims[i]) { return false; } } for (size_t j = 0; j < simplified_num_dims; j++) { permutation_list[j] = sorted_src_strides[j].second; permutation_src_dims[j] = simplified_src_dims[sorted_src_strides[j].second]; } return true; } template std::unique_ptr NewPrimitiveFromHandlers( const std::map()>>& handlers, const D& key) { const auto iter = handlers.find(key); if (iter != handlers.end()) { return iter->second(); } return nullptr; } } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_ ================================================ FILE: oneflow/core/ep/common/primitive/where.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
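// A usage sketch for the two-operand SimplifyBroadcastDims overload above,
// assuming its stripped template parameter is the maximum supported rank.
// Shapes [1, 2, 8] and [4, 2, 8] collapse to [1, 16] and [4, 16], with merged
// output shape [4, 16]. Hypothetical example, not taken from an OneFlow test:
inline void SimplifyBroadcastDimsExample() {
  const int64_t a_dims[3] = {1, 2, 8};
  const int64_t b_dims[3] = {4, 2, 8};
  size_t simplified_num_dims = 0;
  int64_t a[8] = {0};
  int64_t b[8] = {0};
  int64_t c[8] = {0};
  SimplifyBroadcastDims<8>(3, a_dims, 3, b_dims, &simplified_num_dims, a, b, c);
  // Expected: simplified_num_dims == 2, a == {1, 16}, b == {4, 16}, c == {4, 16}.
}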
*/ #ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_WHERE_H_ #define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_WHERE_H_ #include "oneflow/core/ep/include/primitive/where.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ep/common/primitive/util.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace ep { namespace primitive { namespace { constexpr size_t kMaxNumDims = 8; template struct WhereElemwiseFunctor { OF_DEVICE_FUNC WhereElemwiseFunctor() {} OF_DEVICE_FUNC R operator()(Cond cond, X x, Y y) const { return cond ? x : y; } }; template using WhereFunctor = WhereElemwiseFunctor; template struct BroadcastElementwiseWhereParams { NdIndexOffsetHelper cond_index_helper; NdIndexOffsetHelper x_index_helper; NdIndexOffsetHelper y_index_helper; NdIndexOffsetHelper z_index_helper; IndexType cond_index_mask[NDIM]; IndexType x_index_mask[NDIM]; IndexType y_index_mask[NDIM]; IndexType elem_cnt{}; const void* cond{}; const void* x{}; const void* y{}; void* z{}; }; template struct alignas(sizeof(T) * pack_size) Packed { OF_DEVICE_FUNC Packed() { // do nothing } union { T elem[pack_size]; }; }; inline bool IsDimsEquals(size_t ndim, const int64_t* a_dims, const int64_t* b_dims) { for (size_t i = 0; i < ndim; ++i) { if (a_dims[i] != b_dims[i]) { return false; } } return true; } // Calculate compact broadcast dimensions // For example: // [1, 2, 8] and [4, 2, 8] can be compacted to [1, 16] and [4, 16] // [4, 1, 8] and [8] -> [4, 8] and [1, 8] // [1, 1, 8] and [4, 2, 8] -> [1, 8] and [8, 8] // after compacting, cond, x, y will have the same number of dims, // z_dims is the broadcast dims of compacted cond, x, y dims. inline void GetCompactBroadcastDims(const size_t num_cond_ndims, const int64_t* cond_dims, const size_t num_x_dims, const int64_t* x_dims, const size_t num_y_dims, const int64_t* y_dims, size_t* compact_num_dims, int64_t* compact_cond_dims, int64_t* compact_x_dims, int64_t* compact_y_dims, int64_t* compact_z_dims) { size_t max_num_dims = std::max(std::max(num_x_dims, num_y_dims), num_cond_ndims); CHECK_LE(max_num_dims, kMaxNumDims); auto MakeGetDimSize = [max_num_dims](size_t ndim, const int64_t* dims) { size_t lpad = max_num_dims - ndim; return [lpad, dims](int dim) -> int64_t { return dim < lpad ? 
1 : dims[dim - lpad]; }; }; auto GetCondDimSize = MakeGetDimSize(num_cond_ndims, cond_dims); auto GetXDimSize = MakeGetDimSize(num_x_dims, x_dims); auto GetYDimSize = MakeGetDimSize(num_y_dims, y_dims); size_t& num_dims = *compact_num_dims; num_dims = 0; bool cond_pred_dim_broadcast = false; bool x_pred_dim_broadcast = false; bool y_pred_dim_broadcast = false; for (int i = 0; i < max_num_dims; ++i) { int64_t cond_dim_size = GetCondDimSize(i); int64_t x_dim_size = GetXDimSize(i); int64_t y_dim_size = GetYDimSize(i); int64_t dim_size = std::max(std::max(x_dim_size, y_dim_size), cond_dim_size); if (dim_size == 1) { continue; } bool cond_broadcast = (cond_dim_size == 1); bool x_broadcast = (x_dim_size == 1); bool y_broadcast = (y_dim_size == 1); if (*compact_num_dims > 0 && cond_broadcast == cond_pred_dim_broadcast && x_broadcast == x_pred_dim_broadcast && y_broadcast == y_pred_dim_broadcast) { compact_cond_dims[num_dims - 1] *= cond_dim_size; compact_x_dims[num_dims - 1] *= x_dim_size; compact_y_dims[num_dims - 1] *= y_dim_size; compact_z_dims[num_dims - 1] *= dim_size; } else { compact_cond_dims[num_dims] = cond_dim_size; compact_x_dims[num_dims] = x_dim_size; compact_y_dims[num_dims] = y_dim_size; compact_z_dims[num_dims] = dim_size; num_dims += 1; cond_pred_dim_broadcast = cond_broadcast; x_pred_dim_broadcast = x_broadcast; y_pred_dim_broadcast = y_broadcast; } } } template void LaunchKernel(Stream* stream, const int64_t* cond_dims, const int64_t* x_dims, const int64_t* y_dims, const int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z); template void LaunchScalarKernel(Stream* stream, const CondT* cond, const T* x, const T* y, T* z); template void LaunchByDispatchIndexType(Stream* stream, int64_t* cond_dims, int64_t* x_dims, int64_t* y_dims, int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z) { const size_t elem_cnt = GetElementCount(ndim, z_dims); if (elem_cnt < GetMaxVal()) { return LaunchKernel( stream, cond_dims, x_dims, y_dims, z_dims, cond, x, y, z); } else { return LaunchKernel( stream, cond_dims, x_dims, y_dims, z_dims, cond, x, y, z); } } template size_t GetPackSize(const int64_t* cond_dims, const int64_t* x_dims, const int64_t* y_dims, const int64_t* z_dims, const CondT* cond, const T* x, const T* y, const T* z) { static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, ""); CHECK_GT(z_dims[ndim - 1], 1); for (size_t pack_size = max_pack_size; pack_size >= 2; pack_size /= 2) { if (!IsPackSizeSupported(pack_size, ndim, z_dims, z)) { continue; } if (x_dims[ndim - 1] != 1 && !IsPackSizeSupported(pack_size, ndim, x_dims, x)) { continue; } if (y_dims[ndim - 1] != 1 && !IsPackSizeSupported(pack_size, ndim, y_dims, y)) { continue; } if (cond_dims[ndim - 1] != 1 && !IsPackSizeSupported(pack_size, ndim, cond_dims, cond)) { continue; } return pack_size; } return 1; } template void LaunchByDispatchPackSize(Stream* stream, int64_t* cond_dims, int64_t* x_dims, int64_t* y_dims, int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z) { static_assert(ndim > 0, ""); constexpr size_t kMaxPackSize = 4; size_t pack_size = GetPackSize(cond_dims, x_dims, y_dims, z_dims, cond, x, y, z); size_t cond_pack_size = 1; size_t x_pack_size = 1; size_t y_pack_size = 1; if (pack_size > 1) { if (cond_dims[ndim - 1] != 1) { cond_dims[ndim - 1] /= pack_size; cond_pack_size = pack_size; } if (x_dims[ndim - 1] != 1) { x_dims[ndim - 1] /= pack_size; x_pack_size = pack_size; } if (y_dims[ndim - 1] != 1) { y_dims[ndim - 1] /= pack_size; y_pack_size = pack_size; 
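// Note on the packing above: an operand whose innermost extent is 1 is being
// broadcast along the packed axis, so its dims stay untouched and it keeps a
// pack size of 1; only operands that actually advance along that axis (and the
// output z, handled just below) have their last dim divided by pack_size and
// are then read or written as Packed values of that width.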
} z_dims[ndim - 1] /= pack_size; } #define IF(cp, xp, yp) \ if (cond_pack_size == cp && x_pack_size == xp && y_pack_size == yp) { \ LaunchByDispatchIndexType(stream, cond_dims, x_dims, y_dims, \ z_dims, cond, x, y, z); \ } #define ELIF(cp, xp, yp) else IF(cp, xp, yp) #define ELSE \ else { \ UNIMPLEMENTED(); \ } if (pack_size == 1) { IF(1, 1, 1) ELSE } else if (pack_size == 2) { IF(2, 2, 2) ELIF(1, 2, 2) ELIF(1, 2, 1) ELIF(1, 1, 2) ELIF(2, 1, 2) ELIF(2, 1, 1) ELIF(2, 2, 1) ELSE } else if (pack_size == 4) { IF(4, 4, 4) ELIF(1, 4, 4) ELIF(1, 4, 1) ELIF(1, 1, 4) ELIF(4, 1, 4) ELIF(4, 1, 1) ELIF(4, 4, 1) ELSE } ELSE #undef IF #undef ELIF #undef ELSE } template void LaunchByDispatchNDim(Stream* stream, size_t ndim, int64_t* cond_dims, int64_t* x_dims, int64_t* y_dims, int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z) { #define ELIF(n) \ else if (ndim == n) { \ LaunchByDispatchPackSize(stream, cond_dims, x_dims, y_dims, z_dims, cond, x, y, \ z); \ } #define ELSE \ else { \ UNIMPLEMENTED(); \ } if (ndim == 0) { LaunchScalarKernel(stream, cond, x, y, z); } ELIF(1) ELIF(2) ELIF(3) ELIF(4) ELSE #undef IF #undef ELIF #undef ELSE } template class Prim> std::unique_ptr NewWhere(DataType cond_type, DataType data_type, size_t max_num_dims) { if (max_num_dims > kMaxNumDims) { return nullptr; } const size_t data_type_size = GetSizeOfDataType(data_type); #define IF(ctype, dtype_size) \ if (cond_type == ctype && data_type_size == dtype_size) { \ using T = typename std::aligned_storage::type; \ using CondT = DataTypeToType; \ return std::unique_ptr(new Prim()); \ } #define ELIF(ctype, dtype_size) else IF(ctype, dtype_size) #define ELSE \ else { \ return nullptr; \ } IF(DataType::kBool, 1) ELIF(DataType::kBool, 2) ELIF(DataType::kBool, 4) ELIF(DataType::kBool, 8) ELIF(DataType::kInt32, 1) ELIF(DataType::kInt32, 2) ELIF(DataType::kInt32, 4) ELIF(DataType::kInt32, 8) ELIF(DataType::kInt64, 1) ELIF(DataType::kInt64, 2) ELIF(DataType::kInt64, 4) ELIF(DataType::kInt64, 8) ELSE #undef IF #undef ELIF #undef ELSE } } // namespace } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_WHERE_H_ ================================================ FILE: oneflow/core/ep/cpu/cpu_device.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
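// The NewWhere factory above dispatches on the byte size of the data type rather
// than on the type itself: where() only selects and copies values, so every type
// of a given size and alignment can share one instantiation, with
// std::aligned_storage providing the opaque element type. A minimal sketch of
// that trick (OpaqueElem is illustrative, not a name used in this file; needs
// <type_traits>):
template<size_t size>
using OpaqueElem = typename std::aligned_storage<size, size>::type;
static_assert(sizeof(OpaqueElem<4>) == 4, "all 4-byte payloads share one kernel");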
*/ #include "oneflow/core/common/mem_util.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include "oneflow/core/ep/cpu/cpu_event.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace ep { void CpuDevice::SetAsActiveDevice() {} Stream* CpuDevice::CreateStream() { return new CpuStream(this); } void CpuDevice::DestroyStream(Stream* stream) { delete stream; } void CpuDevice::CreateEvents(Event** events, size_t count) { for (size_t i = 0; i < count; ++i) { events[i] = new CpuEvent(); } } void CpuDevice::DestroyEvents(Event** events, size_t count) { for (size_t i = 0; i < count; ++i) { delete events[i]; } } Maybe CpuDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) { if (options.HasPinnedDevice()) { auto device = this->device_manager()->registry()->GetDevice(options.GetPinnedDeviceType(), // NOLINT options.GetPinnedDeviceIndex()); // NOLINT CHECK_OR_RETURN(device); JUST(device->AllocPinned(options, ptr, size)); } else { *ptr = aligned_alloc(kMaxAlignmentRequirement, RoundUp(size, kMaxAlignmentRequirement)); if (*ptr == nullptr) { return Error::RuntimeError() << "CPU can't allocate memory. Tried to allocate " << FormatMemSize(size); } } memset(*ptr, 0, size); return Maybe::Ok(); } void CpuDevice::Free(const AllocationOptions& options, void* ptr) { if (options.HasPinnedDevice()) { auto device = this->device_manager()->registry()->GetDevice(options.GetPinnedDeviceType(), // NOLINT options.GetPinnedDeviceIndex()); // NOLINT CHECK(device); return device->FreePinned(options, ptr); } else { free(ptr); // NOLINT } } Maybe CpuDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) { AllocationOptions new_options = options; new_options.ClearPinnedDevice(); return Alloc(new_options, ptr, size); } void CpuDevice::FreePinned(const AllocationOptions& options, void* ptr) { AllocationOptions new_options = options; new_options.ClearPinnedDevice(); return Free(new_options, ptr); } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/cpu_device.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_CPU_CPU_DEVICE_H_ #define ONEFLOW_CORE_EP_CPU_CPU_DEVICE_H_ #include "oneflow/core/ep/include/device.h" namespace oneflow { namespace ep { class CpuDevice : public Device { public: OF_DISALLOW_COPY_AND_MOVE(CpuDevice); explicit CpuDevice(DeviceManager* device_manager) : device_manager_(device_manager), num_threads_(1) {} ~CpuDevice() override = default; void SetAsActiveDevice() override; void Reset() override {} void SetNumThreads(size_t num_threads) { num_threads_ = num_threads; } size_t GetNumThreads() { return num_threads_; } DeviceType device_type() const override { return DeviceType::kCPU; } size_t device_index() const override { return 0; } DeviceManager* device_manager() const override { return device_manager_; } Stream* CreateStream() override; void DestroyStream(Stream* stream) override; void CreateEvents(Event** events, size_t count) override; void DestroyEvents(Event** events, size_t count) override; Maybe Alloc(const AllocationOptions& options, void** ptr, size_t size) override; void Free(const AllocationOptions& options, void* ptr) override; Maybe AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override; void FreePinned(const AllocationOptions& options, void* ptr) override; private: DeviceManager* device_manager_; size_t num_threads_; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_CPU_CPU_DEVICE_H_ ================================================ FILE: oneflow/core/ep/cpu/cpu_device_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cpu/cpu_device_manager.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include "oneflow/core/ep/cpu/cpu_random_generator.h" namespace oneflow { namespace ep { CpuDeviceManager::CpuDeviceManager(DeviceManagerRegistry* registry) : device_num_threads_(1), registry_(registry) {} CpuDeviceManager::~CpuDeviceManager() = default; DeviceManagerRegistry* CpuDeviceManager::registry() const { return registry_; } std::shared_ptr CpuDeviceManager::GetDevice(size_t device_index) { std::lock_guard lock(device_mutex_); if (!device_) { device_.reset(new CpuDevice(this)); } device_->SetNumThreads(device_num_threads_); return device_; } size_t CpuDeviceManager::GetDeviceCount(size_t /*primary_device_index*/) { return 1; } size_t CpuDeviceManager::GetDeviceCount() { return 1; } size_t CpuDeviceManager::GetActiveDeviceIndex() { return 0; } void CpuDeviceManager::SetActiveDeviceByIndex(size_t device_index) {} void CpuDeviceManager::SetDeviceNumThreads(size_t num_threads) { device_num_threads_ = num_threads; } std::shared_ptr CpuDeviceManager::CreateRandomGenerator(uint64_t seed, size_t device_index) { return std::make_shared(seed, device_index); } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/cpu_device_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
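// Two points worth noting about the CPU device code above: the CPU backend is
// modeled as exactly one device (GetDeviceCount() returns 1 and device_index()
// is always 0), and CpuDeviceManager::GetDevice lazily creates that single
// CpuDevice under a mutex, refreshing its thread count on every call so that a
// later SetDeviceNumThreads() affects streams created afterwards.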
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CPU_CPU_DEVICE_MANAGER_H_ #define ONEFLOW_CORE_EP_CPU_CPU_DEVICE_MANAGER_H_ #include "oneflow/core/ep/include/device_manager.h" namespace oneflow { namespace ep { class CpuDevice; class CpuDeviceManager : public DeviceManager { public: OF_DISALLOW_COPY_AND_MOVE(CpuDeviceManager); explicit CpuDeviceManager(DeviceManagerRegistry* registry); ~CpuDeviceManager() override; DeviceManagerRegistry* registry() const override; std::shared_ptr GetDevice(size_t device_index) override; size_t GetDeviceCount(size_t primary_device_index) override; size_t GetDeviceCount() override; size_t GetActiveDeviceIndex() override; void SetActiveDeviceByIndex(size_t device_index) override; void SetDeviceNumThreads(size_t num_threads); std::shared_ptr CreateRandomGenerator(uint64_t seed, size_t device_index) override; private: size_t device_num_threads_; std::mutex device_mutex_; std::shared_ptr device_; DeviceManagerRegistry* registry_; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_CPU_CPU_DEVICE_MANAGER_H_ ================================================ FILE: oneflow/core/ep/cpu/cpu_device_manager_factory.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/include/device_manager_factory.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/ep/cpu/cpu_device_manager.h" namespace oneflow { namespace ep { namespace { class CpuDeviceManagerFactory : public DeviceManagerFactory { public: OF_DISALLOW_COPY_AND_MOVE(CpuDeviceManagerFactory); CpuDeviceManagerFactory() = default; ~CpuDeviceManagerFactory() override = default; std::unique_ptr NewDeviceManager(DeviceManagerRegistry* registry) override { return std::make_unique(registry); } DeviceType device_type() const override { return DeviceType::kCPU; } std::string device_type_name() const override { return "cpu"; } }; COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory( std::make_unique())) } // namespace namespace { class MockDeviceManagerFactory : public DeviceManagerFactory { public: OF_DISALLOW_COPY_AND_MOVE(MockDeviceManagerFactory); MockDeviceManagerFactory() = default; ~MockDeviceManagerFactory() override = default; std::unique_ptr NewDeviceManager(DeviceManagerRegistry* registry) override { return std::make_unique(registry); } DeviceType device_type() const override { return DeviceType::kMockDevice; } std::string device_type_name() const override { return "mock"; } }; COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory( std::make_unique())) } // namespace namespace { class MetaDeviceManagerFactory : public DeviceManagerFactory { public: OF_DISALLOW_COPY_AND_MOVE(MetaDeviceManagerFactory); MetaDeviceManagerFactory() = default; ~MetaDeviceManagerFactory() override = default; std::unique_ptr NewDeviceManager(DeviceManagerRegistry* registry) override { return std::make_unique(registry); } DeviceType device_type() const override { return DeviceType::kMeta; } std::string device_type_name() const override { return "meta"; } }; COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory( std::make_unique())) } // namespace } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/cpu_event.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cpu/cpu_event.h" namespace oneflow { namespace ep { Maybe CpuEvent::QueryDone() { return Maybe(true); } Maybe CpuEvent::Sync() { return Maybe::Ok(); } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/cpu_event.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CPU_CPU_EVENT_H_ #define ONEFLOW_CORE_EP_CPU_CPU_EVENT_H_ #include "oneflow/core/ep/include/event.h" namespace oneflow { namespace ep { class CpuEvent : public Event { public: OF_DISALLOW_COPY_AND_MOVE(CpuEvent); CpuEvent() = default; ~CpuEvent() override = default; Maybe QueryDone() override; Maybe Sync() override; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_CPU_CPU_EVENT_H_ ================================================ FILE: oneflow/core/ep/cpu/cpu_random_generator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cpu/cpu_random_generator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/str_util.h" namespace oneflow { namespace ep { struct CPUGeneratorState { static constexpr int64_t state_size = std::mt19937::state_size; // 624 int64_t states[state_size] = {}; int64_t seed = 0; }; constexpr int64_t CPUGeneratorState::state_size; void CPUGenerator::set_current_seed(uint64_t seed) { seed_ = seed; engine_.seed(seed_); torch_engine_ = pytorch_mt19937_engine(seed); } size_t CPUGenerator::GetStateSize() const { return sizeof(CPUGeneratorState); } void CPUGenerator::GetState(size_t state_size, void* state) const { CHECK_EQ_OR_THROW(state_size, GetStateSize()) << "state size of cpu generator should be equal to " << GetStateSize(); CPUGeneratorState local_state; std::stringstream ss; ss << engine_; std::vector splits; Split(ss.str(), " ", [&](std::string&& s) { splits.emplace_back(s); }); // The last element in `splits` indicates state size, not state. if (splits.size() != CPUGeneratorState::state_size + 1) { return THROW(RuntimeError) << "std::mt19937 state size should be " << CPUGeneratorState::state_size << ", but got " << splits.size() - 1; } for (int i = 0; i < CPUGeneratorState::state_size; ++i) { local_state.states[i] = std::atoll(splits[i].data()); } local_state.seed = current_seed(); memcpy(state, &local_state, sizeof(CPUGeneratorState)); } void CPUGenerator::SetState(size_t state_size, const void* state) { CHECK_EQ_OR_THROW(state_size, GetStateSize()) << "state size of cpu generator should be equal to " << GetStateSize(); const CPUGeneratorState* local_state = static_cast(state); seed_ = local_state->seed; std::stringstream ss; for (int i = 0; i < CPUGeneratorState::state_size; ++i) { ss << local_state->states[i] << " "; } ss << CPUGeneratorState::state_size; ss >> engine_; } template<> std::string GetRandomGeneratorDeviceTypeName() { return "cpu"; } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/cpu_random_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
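// CPUGenerator::GetState/SetState above round-trip the std::mt19937 state
// through its stream operators: with libstdc++, the textual form is the 624
// state words followed by one trailing position counter, which is why the code
// checks for state_size + 1 space-separated tokens. The underlying mechanism in
// isolation:
#include <random>
#include <sstream>
#include <string>
inline std::string SaveEngine(const std::mt19937& engine) {
  std::stringstream ss;
  ss << engine;  // space-separated state words (plus position, on libstdc++)
  return ss.str();
}
inline void LoadEngine(std::mt19937& engine, const std::string& text) {
  std::stringstream ss(text);
  ss >> engine;  // restores the exact generator state
}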
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CPU_RANDOM_GENERATOR_H_ #define ONEFLOW_CORE_EP_CPU_RANDOM_GENERATOR_H_ #include #include #include #include #include #include "oneflow/core/common/device_type.h" #include "oneflow/core/ep/include/random_generator.h" namespace oneflow { namespace ep { // NOTE(Liang Depeng): The following implementation of mt19937 is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/MT19937RNGEngine.h // in order to make distribution related cpu kernels to have the same output as // pytorch when setting the same seed. constexpr int MERSENNE_STATE_N = 624; constexpr int MERSENNE_STATE_M = 397; constexpr uint32_t MATRIX_A = 0x9908b0df; constexpr uint32_t UMASK = 0x80000000; constexpr uint32_t LMASK = 0x7fffffff; struct pytorch_mt19937_data_pod { uint64_t seed_; int left_; bool seeded_; uint32_t next_; std::array state_; }; class pytorch_mt19937_engine { public: inline explicit pytorch_mt19937_engine(uint64_t seed = 5489) { init_with_uint32(seed); } inline pytorch_mt19937_data_pod data() const { return data_; } inline void set_data(pytorch_mt19937_data_pod data) { data_ = data; } inline uint64_t seed() const { return data_.seed_; } inline bool is_valid() { if ((data_.seeded_ == true) && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N) && (data_.next_ <= MERSENNE_STATE_N)) { return true; } return false; } inline uint32_t operator()() { uint32_t y; if (--(data_.left_) == 0) { next_state(); } y = *(data_.state_.data() + data_.next_++); y ^= (y >> 11); y ^= (y << 7) & 0x9d2c5680; y ^= (y << 15) & 0xefc60000; y ^= (y >> 18); return y; } private: pytorch_mt19937_data_pod data_; inline void init_with_uint32(uint64_t seed) { data_.seed_ = seed; data_.seeded_ = true; data_.state_[0] = seed & 0xffffffff; for (int j = 1; j < MERSENNE_STATE_N; ++j) { data_.state_[j] = (1812433253 * (data_.state_[j - 1] ^ (data_.state_[j - 1] >> 30)) + j); } data_.left_ = 1; data_.next_ = 0; } inline uint32_t mix_bits(uint32_t u, uint32_t v) { return (u & UMASK) | (v & LMASK); } inline uint32_t twist(uint32_t u, uint32_t v) { return (mix_bits(u, v) >> 1) ^ (v & 1 ? 
MATRIX_A : 0); } inline void next_state() { uint32_t* p = data_.state_.data(); data_.left_ = MERSENNE_STATE_N; data_.next_ = 0; for (int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) { *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]); } for (int j = MERSENNE_STATE_M; --j; p++) { *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]); } *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]); } }; class CPUGenerator : public RandomGenerator { public: explicit CPUGenerator(uint64_t seed, int device_index) : RandomGenerator(), seed_(seed), engine_(seed), torch_engine_(seed) {} virtual ~CPUGenerator() = default; uint64_t current_seed() const override { return seed_; } void set_current_seed(uint64_t seed) override; std::mt19937& engine() { return engine_; } pytorch_mt19937_engine& torch_engine() { return torch_engine_; } std::string device_type_name() const override { return "cpu"; } int64_t device_index() const override { return 0; } size_t GetStateSize() const override; void GetState(size_t state_size, void* state) const override; void SetState(size_t state_size, const void* state) override; public: mutable std::mutex mutex_; uint64_t seed_; std::mt19937 engine_; // TODO(Liang Depeng): needed to implement the get_state/set_state of pytorch_mt_19937_engine // refer to // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/CPUGenerator.cpp#L206 pytorch_mt19937_engine torch_engine_; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_CPU_RANDOM_GENERATOR_H_ ================================================ FILE: oneflow/core/ep/cpu/cpu_stream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/thread/thread_runtime_factory.h" namespace oneflow { namespace ep { DeviceType CpuStream::device_type() const { return DeviceType::kCPU; } CpuDevice* CpuStream::device() const { return device_; } Maybe CpuStream::Sync() { return Maybe::Ok(); } void CpuStream::RecordEvent(Event* /*event*/) {} Maybe CpuStream::InitThreadRuntime() { const auto thread_runtime_type = GetStringFromEnv("OF_THREADING_RUNTIME", [] { if (thread::IsTbbEnabled()) { return "TBB"; } if (thread::IsOmpEnabled()) { return "OMP"; } return "SEQ"; }()); thread_runtime_ = JUST(thread::RuntimeFactory::Create(thread_runtime_type)); return Maybe::Ok(); } #ifdef WITH_ONEDNN const std::unique_ptr& CpuStream::onednn_executor() const { return onednn_executor_; } #endif } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/cpu_stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
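// CpuStream::InitThreadRuntime above selects the intra-op threading backend from
// the OF_THREADING_RUNTIME environment variable ("TBB", "OMP", or "SEQ"); when
// the variable is unset, the default is the first backend that was compiled in,
// checked in that same order. For example (illustrative shell usage, script name
// hypothetical):
//   OF_THREADING_RUNTIME=SEQ python3 train.py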
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CPU_CPU_STREAM_H_ #define ONEFLOW_CORE_EP_CPU_CPU_STREAM_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include "oneflow/core/thread/thread_runtime_factory.h" #ifdef WITH_ONEDNN #include #endif namespace oneflow { namespace ep { class CpuNumThreadsGuard { public: OF_DISALLOW_COPY_AND_MOVE(CpuNumThreadsGuard); #if WITH_TBB explicit CpuNumThreadsGuard(size_t num_threads) : global_thread_limit(tbb::global_control::max_allowed_parallelism, num_threads) {} ~CpuNumThreadsGuard() {} #elif WITH_OMP explicit CpuNumThreadsGuard(size_t num_threads) : set_num_threads_(num_threads) { saved_num_threads_ = omp_get_max_threads(); omp_set_num_threads(set_num_threads_); } ~CpuNumThreadsGuard() { omp_set_num_threads(saved_num_threads_); } #endif private: #if WITH_TBB tbb::global_control global_thread_limit; #elif WITH_OMP size_t set_num_threads_; size_t saved_num_threads_; #endif }; #ifdef WITH_ONEDNN class OneDnnExecutor; #endif class CpuStream : public Stream { public: OF_DISALLOW_COPY_AND_MOVE(CpuStream); explicit CpuStream(CpuDevice* device) : device_(device) { CHECK_JUST(InitThreadRuntime()); #ifdef WITH_ONEDNN onednn_executor_ = std::make_unique(this); #endif } ~CpuStream() override = default; DeviceType device_type() const override; CpuDevice* device() const override; Maybe Sync() override; void RecordEvent(Event* event) override; template void ParallelFor(int64_t begin, int64_t end, const F& func) { ParallelFor(begin, end, func, kParallelForDefaultGrain); } template void ParallelFor(int64_t begin, int64_t end, const F& func, size_t grain_size) { thread_runtime_->ParallelFor(begin, end, func, device()->GetNumThreads(), grain_size); } #ifdef WITH_ONEDNN const std::unique_ptr& onednn_executor() const; #endif private: CpuDevice* device_; static constexpr size_t kParallelForDefaultGrain = 32768; std::shared_ptr thread_runtime_; Maybe InitThreadRuntime(); #ifdef WITH_ONEDNN std::unique_ptr onednn_executor_; #endif }; #ifdef WITH_ONEDNN class OneDnnExecutor { public: OF_DISALLOW_COPY_AND_MOVE(OneDnnExecutor); OneDnnExecutor() = delete; explicit OneDnnExecutor(CpuStream* cpu_stream) : cpu_stream_(cpu_stream) { engine_.reset(new dnnl::engine(dnnl::engine::kind::cpu, 0)); stream_.reset(new dnnl::stream(*engine_)); } ~OneDnnExecutor() = default; template void Launch(const F& f) { CpuNumThreadsGuard guard(cpu_stream_->device()->GetNumThreads()); f(engine_.get(), stream_.get()); stream_->wait(); } private: CpuStream* cpu_stream_ = nullptr; std::unique_ptr engine_; std::unique_ptr stream_; }; #endif } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_CPU_CPU_STREAM_H_ ================================================ FILE: oneflow/core/ep/cpu/primitive/add.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
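// CpuStream::ParallelFor above hands the half-open range [begin, end) to the
// selected thread runtime, using grain_size (default 32768) as a work-granularity
// hint and the device's configured thread count as the parallelism limit. A
// hypothetical kernel body using it might look like:
inline void ScaleInPlace(oneflow::ep::CpuStream* stream, float* data, int64_t n, float alpha) {
  stream->ParallelFor(0, n, [=](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) { data[i] *= alpha; }
  });
}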
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/common/primitive/util.h" #include "oneflow/core/ep/common/onednn.h" namespace oneflow { namespace ep { namespace primitive { namespace { template void AddCpu(const T* const* srcs, T* dst, size_t count) { for (size_t i = 0; i < count; ++i) { T sum = T(0); for (size_t a = 0; a < arity; ++a) { sum += srcs[a][i]; } dst[i] = sum; } } template void AddCpu(const T* const* srcs, size_t arity, T* dst, size_t count) { for (size_t i = 0; i < count; ++i) { T sum = T(0); for (size_t a = 0; a < arity; ++a) { sum += srcs[a][i]; } dst[i] = sum; } } template class AddDefaultImpl : public Add { public: OF_DISALLOW_COPY_AND_MOVE(AddDefaultImpl); AddDefaultImpl() = default; ~AddDefaultImpl() override = default; using Add::Launch; void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst, size_t count) override { #define ONE_IF(a) \ if (arity == a) { \ AddCpu(reinterpret_cast(srcs), reinterpret_cast(dst), count); \ } #define ONE_ELIF(a) else ONE_IF(a) #define ONE_ELSE \ else { \ AddCpu(reinterpret_cast(srcs), arity, reinterpret_cast(dst), count); \ } ONE_IF(0) ONE_ELIF(1) ONE_ELIF(2) ONE_ELIF(3) ONE_ELIF(4) ONE_ELIF(5) ONE_ELIF(6) ONE_ELIF(7) ONE_ELIF(8) ONE_ELSE } }; #ifdef WITH_ONEDNN class AddOneDnnImpl : public Add { public: OF_DISALLOW_COPY_AND_MOVE(AddOneDnnImpl); explicit AddOneDnnImpl(dnnl::memory::data_type type) : type_onednn_(type){}; ~AddOneDnnImpl() override = default; using Add::Launch; void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst, size_t count) override { if (arity < 2) { // TODO: arity 0 and 1 UNIMPLEMENTED() << "Addn only supports summation of 2 or more tensors"; } else if (arity == 2) { if (srcs[1] == dst && srcs[0] != dst) { LOG(FATAL) << "Only the first parameter can be operated inplace"; } } else { for (int i = 2; i < arity; i++) { if (srcs[i] == dst) { LOG(FATAL) << "Only the first parameter can be operated inplace"; } } } stream->As()->onednn_executor()->Launch( [&](dnnl::engine* onednn_engine, dnnl::stream* onednn_stream) { dnnl::memory::dims src_dims = {static_cast(count)}; std::vector src_md; std::vector src_mem; src_md.reserve(arity); src_mem.reserve(arity); for (int i = 0; i < arity; i++) { auto md = dnnl::memory::desc(src_dims, type_onednn_, dnnl::memory::format_tag::x); auto mem = dnnl::memory(md, *onednn_engine, (void*)(srcs)[i]); src_md.emplace_back(md); src_mem.emplace_back(mem); } std::vector scales(arity, 1.0); auto sum_pd = dnnl::sum::primitive_desc(scales, src_md, *onednn_engine); auto sum_prim = dnnl::sum(sum_pd); auto dst_mem = dnnl::memory(sum_pd.dst_desc(), *onednn_engine, dst); std::unordered_map sum_args{{DNNL_ARG_DST, dst_mem}}; for (int i = 0; i < arity; ++i) { sum_args.insert({DNNL_ARG_MULTIPLE_SRC + i, src_mem[i]}); } sum_prim.execute(*onednn_stream, sum_args); }); } private: dnnl::memory::data_type type_onednn_; }; #endif template std::unique_ptr NewAdd() { return std::unique_ptr(new AddDefaultImpl()); } #ifdef WITH_ONEDNN template std::unique_ptr 
NewOneDnnAdd() { return std::unique_ptr(new AddOneDnnImpl(type_onednn)); } #endif #define CPU_PRIMITIVE_ADD_ONEDNN_TYPE_SEQ \ CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ \ CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ \ CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ \ CPU_PRIMITIVE_ONEDNN_FLOAT_TYPE_SEQ \ CPU_PRIMITIVE_ONEDNN_FLOAT16_TYPE_SEQ \ CPU_PRIMITIVE_ONEDNN_BFLOAT16_TYPE_SEQ #define CPU_PRIMITIVE_ADD_DEFAULT_TYPE_SEQ \ CPU_PRIMITIVE_BOOL_TYPE_SEQ \ CPU_PRIMITIVE_CHAR_TYPE_SEQ \ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \ CPU_PRIMITIVE_INT64_TYPE_SEQ class AddFactoryImpl : public AddFactory { public: OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl); AddFactoryImpl() = default; ~AddFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd}, static const std::map()>> new_add_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)}; #undef MAKE_NEW_ADD_ENTRY #ifdef WITH_ONEDNN #define MAKE_NEW_ONEDNN_ADD_ENTRY(type_onednn, type_proto) {type_proto, NewOneDnnAdd}, static const std::map()>> new_add_onednn_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ONEDNN_ADD_ENTRY, CPU_PRIMITIVE_ADD_ONEDNN_TYPE_SEQ)}; #undef MAKE_NEW_ONEDNN_ADD_ENTRY if (OneDnnIsEnabled()) { auto add_primitive = NewPrimitiveFromHandlers(new_add_onednn_handle, data_type); if (add_primitive) { return add_primitive; } } #endif return NewPrimitiveFromHandlers(new_add_handle, data_type); } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, AddFactory, AddFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/binary_functor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
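// AddFactoryImpl::New above resolves a primitive through a static map from
// DataType to a maker function (via NewPrimitiveFromHandlers); when oneDNN
// support is compiled in and enabled at runtime, the oneDNN table is consulted
// first and the plain C++ AddDefaultImpl table serves as the fallback. The
// lookup itself reduces to finding the key and invoking the stored maker,
// returning nullptr for unsupported data types, so callers must check the
// result.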
*/ #include "oneflow/core/ep/common/primitive/binary_functor.h" #include "oneflow/core/ep/cpu/primitive/unary_functor.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return std::pow(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const { return static_cast(std::pow(static_cast(src0), static_cast(src1))); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::fmod(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src0, double src1) const { return std::fmod(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const { return static_cast(std::fmod(static_cast(src0), static_cast(src1))); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const { return std::fmod(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::floor(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src0, double src1) const { return std::floor(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const { return static_cast(std::floor(static_cast(src0) / static_cast(src1))); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const { return std::floor(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::trunc(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src0, double src1) const { return std::trunc(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const { return static_cast(std::trunc(static_cast(src0) / static_cast(src1))); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const { return std::trunc(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src0, float src1) const { float trunc_mod = std::fmod(src0, src1); return (trunc_mod != static_cast(0)) && ((src1 < static_cast(0)) != (trunc_mod < static_cast(0))) ? 
trunc_mod + src1 : trunc_mod; } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src0, double src1) const { double trunc_mod = std::fmod(src0, src1); return (trunc_mod != static_cast(0)) && ((src1 < static_cast(0)) != (trunc_mod < static_cast(0))) ? trunc_mod + src1 : trunc_mod; } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value()) {} OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const { return static_cast( scalar_operand * (std::pow(static_cast(src0), scalar_operand - static_cast(1))) * static_cast(src1)); } float scalar_operand; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC Dst operator()(int src0, int src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC Dst operator()(int8_t src0, int8_t src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC Dst operator()(uint8_t src0, uint8_t src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC Dst operator()(int src0, int src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value()) {} OF_DEVICE_FUNC Dst operator()(float16 src0, float16 src1) const { return static_cast(std::pow(scalar_operand, static_cast(src0)) * std::log(scalar_operand) * static_cast(src1)); } float scalar_operand; }; template struct BinaryFunctor, Dst> { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(std::complex dy, std::complex x) const { return dy * static_cast>(0.5) / std::conj(std::sqrt(x)); } }; template struct BinaryFunctor, Dst> { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(std::complex dy, std::complex x) const { return dy * static_cast>(0.5) / std::conj(std::sqrt(x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast( 0.5 * (1.0 + std::erf(inv_sqrt2 * x) + x * coef * std::exp(-0.5 * x * x)) * dy); } Src inv_sqrt2 = 
std::sqrt(0.5); Src coef = std::sqrt(2.0 / std::acos(-1.0)); }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu const Src one = static_cast(1); const Src half = static_cast(0.5); const Src pow3 = x * x * x; const Src tanh_out = std::tanh(alpha * (x + beta * pow3)); const Src dtanh = alpha * (half * x + beta * static_cast(1.5) * pow3); return dy * (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out)); } private: static constexpr Src alpha = static_cast(0.7978845608028654); static constexpr Src beta = static_cast(0.044714998453855515); }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { const Src one = static_cast(1.0); const Src sigmoid = one / (one + exp(-x * alpha)); return dy * (sigmoid + alpha * x * (sigmoid * (one - sigmoid))); } private: static constexpr Src alpha = static_cast(1.702); }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast((x > static_cast(0.0)) ? static_cast(2.0) * x * dy : static_cast(0.0)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return static_cast(dy * (static_cast(1.0) - y * y)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * -(static_cast(1.0) / sqrt(static_cast(1.0) - x * x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy / sqrt(x * x - static_cast(1.0)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast(1.0) / sqrt(static_cast(1.0) - x * x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast(1.0) / sqrt(static_cast(1.0) + x * x)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * static_cast(2.0) * (static_cast(1.0) / sqrt(static_cast(M_PI))) * exp(-x * x); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * static_cast(-2.0) * (static_cast(1.0) / sqrt(static_cast(M_PI))) * exp(-x * x); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float dy, float x) const { ep::primitive::UnaryFunctor trigamma_functor(0, 0); float trigamma_result = trigamma_functor(x); return trigamma_result * dy; } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double dy, double x) const { ep::primitive::UnaryFunctor trigamma_functor(0, 0); double trigamma_result = trigamma_functor(x); return trigamma_result * dy; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src 
x) const { ep::primitive::UnaryFunctor digamma_functor(0, 0); Dst digamma_result = digamma_functor(x); return digamma_result * dy; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src x, Src q) const { // ref // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L235-L309 const Src MACHEP = Src{1.11022302462515654042E-16}; constexpr Src zero = Src{0.0}; constexpr Src half = Src{0.5}; constexpr Src one = Src{1.0}; static const Src A[] = { 12.0, -720.0, 30240.0, -1209600.0, 47900160.0, -1.8924375803183791606e9, /*1.307674368e12/691*/ 7.47242496e10, -2.950130727918164224e12, /*1.067062284288e16/3617*/ 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ }; int i = 0; Src a, b, k, s, t, w; if (x == one) { return std::numeric_limits::infinity(); } if (x < one) { return std::numeric_limits::quiet_NaN(); } if (q <= zero) { if (q == floor(q)) { return std::numeric_limits::infinity(); } if (x != floor(x)) { return std::numeric_limits::quiet_NaN(); } } s = pow(q, -x); a = q; i = 0; b = zero; while ((i < 9) || (a <= Src{9.0})) { i += 1; a += one; b = pow(a, -x); s += b; if ((-MACHEP * s < b) && (b < MACHEP * s)) { return static_cast(s); } }; w = a; s += b * w / (x - one); s -= half * b; a = one; k = zero; for (int i = 0; i < 12; i++) { a *= x + k; b /= w; t = a * b / A[i]; s = s + t; t = fabs(t / s); if (t < MACHEP) { return static_cast(s); } k += one; a *= x + k; b /= w; k += one; } return static_cast(s); } }; #define SPECIALIZATION_CPU_BINARY_FUNCTOR(op, type) \ template<> \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {} \ \ BinaryFunctor int_functor; \ OF_DEVICE_FUNC type operator()(type src0, type src1) const { \ return static_cast(int_functor(static_cast(src0), static_cast(src1))); \ } \ }; SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, bool); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, bool); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, bool); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, bool); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, bool); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, char); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, char); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, char); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, char); SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, char); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/common//primitive/constant_pad.h" #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/cpu/primitive/binary_functor.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include "oneflow/core/ep/common/primitive/util.h" #include "oneflow/core/ep/common/onednn.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { namespace { template T GetValue(Scalar value) { return value.Value(); } template<> float16 GetValue(Scalar value) { return static_cast(GetValue(value)); } template<> bfloat16 GetValue(Scalar value) { return static_cast(GetValue(value)); } template struct BinaryLhsScalarFunctor { BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1) : scalar(scalar), functor(attr0, attr1) {} Dst operator()(Src src) const { return functor(scalar, src); } const Src scalar; BinaryFunctor functor; }; template struct BinaryRhsScalarFunctor { BinaryRhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1) : scalar(scalar), functor(attr0, attr1) {} Dst operator()(Src src) const { return functor(src, scalar); } const Src scalar; BinaryFunctor functor; }; template void LaunchElementwise(CpuStream* cpu_stream, size_t simplified_num_dims, const int64_t* simplified_src0_dims, const Src* src0, const int64_t* simplified_src1_dims, const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) { const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims); auto functor = BinaryFunctor(attr0, attr1); cpu_stream->ParallelFor(0, elem_cnt, [functor, src0, src1, dst](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { dst[i] = functor(src0[i], src1[i]); } }); } template void LaunchBinaryLhsScalar(CpuStream* cpu_stream, Src src0_value, size_t src1_elem_cnt, const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) { auto functor = BinaryLhsScalarFunctor(src0_value, attr0, attr1); cpu_stream->ParallelFor(0, src1_elem_cnt, [functor, src1, dst](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { dst[i] = functor(src1[i]); } }); } template void LaunchBinaryRhsScalar(CpuStream* cpu_stream, Src src1_value, size_t src0_elem_cnt, const Src* src0, Dst* dst, Scalar attr0, Scalar attr1) { auto functor = BinaryRhsScalarFunctor(src1_value, attr0, attr1); cpu_stream->ParallelFor(0, src0_elem_cnt, [functor, src0, dst](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { dst[i] = functor(src0[i]); } }); } template void LaunchRowWithMatrix(CpuStream* cpu_stream, const int64_t* simplified_src0_dims, const Src* src0, const int64_t* simplified_src1_dims, const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) { int64_t rows = simplified_src1_dims[0]; int64_t cols = simplified_src0_dims[1]; auto functor = BinaryFunctor(attr0, attr1); cpu_stream->ParallelFor( 0, rows, 
[functor, src0, src1, dst, cols](int64_t begin, int64_t end) { for (int64_t row_idx = begin; row_idx < end; row_idx++) { const Src* src1_row = src1 + row_idx * cols; Dst* dst_row = dst + row_idx * cols; for (int64_t col_idx = 0; col_idx < cols; col_idx++) { dst_row[col_idx] = functor(src0[col_idx], src1_row[col_idx]); } } }, 1); } template void LaunchMatrixWithRow(CpuStream* cpu_stream, const int64_t* simplified_src0_dims, const Src* src0, const int64_t* simplified_src1_dims, const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) { int64_t rows = simplified_src0_dims[0]; int64_t cols = simplified_src1_dims[1]; auto functor = BinaryFunctor(attr0, attr1); cpu_stream->ParallelFor( 0, rows, [functor, src0, src1, dst, cols](int64_t begin, int64_t end) { for (int64_t row_idx = begin; row_idx < end; row_idx++) { const Src* src0_row = src0 + row_idx * cols; Dst* dst_row = dst + row_idx * cols; for (int64_t col_idx = 0; col_idx < cols; col_idx++) { dst_row[col_idx] = functor(src0_row[col_idx], src1[col_idx]); } } }, 1); } template void LaunchColWithMatrix(CpuStream* cpu_stream, const int64_t* simplified_src0_dims, const Src* src0, const int64_t* simplified_src1_dims, const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) { int64_t rows = simplified_src0_dims[0]; int64_t cols = simplified_src1_dims[1]; auto functor = BinaryFunctor(attr0, attr1); cpu_stream->ParallelFor( 0, rows, [functor, src0, src1, dst, cols](int64_t begin, int64_t end) { for (int64_t row_idx = begin; row_idx < end; row_idx++) { const Src* src1_row = src1 + row_idx * cols; Dst* dst_row = dst + row_idx * cols; for (int64_t col_idx = 0; col_idx < cols; col_idx++) { dst_row[col_idx] = functor(src0[row_idx], src1_row[col_idx]); } } }, 1); } template void LaunchMatrixWithCol(CpuStream* cpu_stream, const int64_t* simplified_src0_dims, const Src* src0, const int64_t* simplified_src1_dims, const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) { int64_t rows = simplified_src1_dims[0]; int64_t cols = simplified_src0_dims[1]; auto functor = BinaryFunctor(attr0, attr1); cpu_stream->ParallelFor( 0, rows, [functor, src0, src1, dst, cols](int64_t begin, int64_t end) { for (int64_t row_idx = begin; row_idx < end; row_idx++) { const Src* src0_row = src0 + row_idx * cols; Dst* dst_row = dst + row_idx * cols; for (int64_t col_idx = 0; col_idx < cols; col_idx++) { dst_row[col_idx] = functor(src0_row[col_idx], src1[row_idx]); } } }, 1); } template void LaunchGeneral(CpuStream* cpu_stream, size_t simplified_num_dims, const int64_t* simplified_src0_dims, const Src* src0, const int64_t* simplified_src1_dims, const Src* src1, const int64_t* simplified_dst_dims, Dst* dst, int64_t dst_elem_cnt, Scalar attr0, Scalar attr1) { auto functor = BinaryFunctor(attr0, attr1); cpu_stream->ParallelFor( 0, dst_elem_cnt, [functor, src0, src1, dst, simplified_num_dims, simplified_src0_dims, simplified_src1_dims, simplified_dst_dims](int64_t begin, int64_t end) { auto src0_index_helper = NdIndexOffsetHelper(simplified_src0_dims, simplified_num_dims); auto src1_index_helper = NdIndexOffsetHelper(simplified_src1_dims, simplified_num_dims); auto dst_index_helper = OffsetToIndexCalculator( simplified_dst_dims, simplified_num_dims); IndexType src0_index[kMaxNumDims]; IndexType src1_index[kMaxNumDims]; IndexType dst_index[kMaxNumDims]; for (IndexType offset = begin; offset < end; offset++) { dst_index_helper.OffsetToNdIndex(offset, dst_index, simplified_num_dims); for (int i = 0; i < kMaxNumDims; i++) { if (i < simplified_num_dims) { src0_index[i] = (simplified_src0_dims[i] != 
1) ? dst_index[i] : 0; src1_index[i] = (simplified_src1_dims[i] != 1) ? dst_index[i] : 0; } else { src0_index[i] = 0; src1_index[i] = 0; } } const IndexType src0_offset = src0_index_helper.NdIndexToOffset(src0_index, simplified_num_dims); const IndexType src1_offset = src1_index_helper.NdIndexToOffset(src1_index, simplified_num_dims); dst[offset] = functor(src0[src0_offset], src1[src1_offset]); } }); } template void LaunchGeneralDispatchIndexType(CpuStream* cpu_stream, size_t simplified_num_dims, const int64_t* simplified_src0_dims, const Src* src0, const int64_t* simplified_src1_dims, const Src* src1, const int64_t* simplified_dst_dims, Dst* dst, Scalar attr0, Scalar attr1) { const int64_t dst_elem_cnt = GetElementCount(simplified_num_dims, simplified_dst_dims); if (dst_elem_cnt < (GetMaxVal() / 2)) { LaunchGeneral( cpu_stream, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, simplified_dst_dims, dst, dst_elem_cnt, attr0, attr1); } else { LaunchGeneral( cpu_stream, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, simplified_dst_dims, dst, dst_elem_cnt, attr0, attr1); } } template void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const Src* src0, size_t num_src1_dims, const int64_t* src1_dims, const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) { auto* cpu_stream = stream->As(); size_t simplified_num_dims = 0; int64_t simplified_src0_dims[kMaxNumDims]; int64_t simplified_src1_dims[kMaxNumDims]; int64_t simplified_dst_dims[kMaxNumDims]; SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, &simplified_num_dims, simplified_src0_dims, simplified_src1_dims, simplified_dst_dims); CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_dst_dims, dst); CheckInplace(simplified_num_dims, simplified_src1_dims, src1, simplified_dst_dims, dst); if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims, simplified_src1_dims)) { LaunchElementwise(cpu_stream, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, dst, attr0, attr1); } else { if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) { LaunchBinaryLhsScalar(cpu_stream, *src0, simplified_src1_dims[0], src1, dst, attr0, attr1); } else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) { LaunchBinaryRhsScalar(cpu_stream, *src1, simplified_src0_dims[0], src0, dst, attr0, attr1); } else if (simplified_num_dims == 2 && simplified_src0_dims[0] == 1 && simplified_src0_dims[1] == simplified_src1_dims[1]) { LaunchRowWithMatrix(cpu_stream, simplified_src0_dims, src0, simplified_src1_dims, src1, dst, attr0, attr1); } else if (simplified_num_dims == 2 && simplified_src1_dims[0] == 1 && simplified_src0_dims[1] == simplified_src1_dims[1]) { LaunchMatrixWithRow(cpu_stream, simplified_src0_dims, src0, simplified_src1_dims, src1, dst, attr0, attr1); } else if (simplified_num_dims == 2 && simplified_src0_dims[1] == 1 && simplified_src0_dims[0] == simplified_src1_dims[0]) { LaunchColWithMatrix(cpu_stream, simplified_src0_dims, src0, simplified_src1_dims, src1, dst, attr0, attr1); } else if (simplified_num_dims == 2 && simplified_src1_dims[1] == 1 && simplified_src0_dims[0] == simplified_src1_dims[0]) { LaunchMatrixWithCol(cpu_stream, simplified_src0_dims, src0, simplified_src1_dims, src1, dst, attr0, attr1); } else { LaunchGeneralDispatchIndexType( cpu_stream, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, simplified_dst_dims, dst, attr0, 
attr1); } } } template class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl); BroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {} ~BroadcastElementwiseBinaryImpl() override = default; void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims, const void* src1_ptr, void* dst_ptr) override { auto* cpu_stream = stream->As(); const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims); Dst* dst = reinterpret_cast(dst_ptr); const Src* src1 = reinterpret_cast(src1_ptr); LaunchBinaryLhsScalar(cpu_stream, GetValue(src0), elem_cnt, src1, dst, attr0, attr1); } void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0_ptr, Scalar src1, void* dst_ptr) override { auto* cpu_stream = stream->As(); const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims); Dst* dst = reinterpret_cast(dst_ptr); const Src* src0 = reinterpret_cast(src0_ptr); LaunchBinaryRhsScalar(cpu_stream, GetValue(src1), elem_cnt, src0, dst, attr0, attr1); } void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, size_t num_src1_dims, const int64_t* src1_dims, const void* src1, void* dst) override { DispatchLaunch( stream, num_src0_dims, src0_dims, reinterpret_cast(src0), num_src1_dims, src1_dims, reinterpret_cast(src1), reinterpret_cast(dst), attr0, attr1); } private: Scalar attr0, attr1; }; template std::unique_ptr NewBroadcastElementwiseBinary(Scalar attr0, Scalar attr1) { return std::unique_ptr( new BroadcastElementwiseBinaryImpl(attr0, attr1)); } #define NDARRAY_BINARY_TYPE_SEQ \ CPU_PRIMITIVE_BOOL_TYPE_SEQ \ CPU_PRIMITIVE_INT8_TYPE_SEQ \ CPU_PRIMITIVE_UINT8_TYPE_SEQ \ CPU_PRIMITIVE_INT32_TYPE_SEQ \ CPU_PRIMITIVE_INT64_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT_TYPE_SEQ \ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT16_TYPE_SEQ #ifdef WITH_ONEDNN uint32_t OnednnFormatTagMap[kMaxNumDims] = {dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd, dnnl_abcde, dnnl_abcdef, dnnl_abcdefg, dnnl_abcdefgh}; inline void OneDnnBroadcastDims(dnnl::memory::dims* src0, size_t num_src0_dims, const int64_t* src0_dims, dnnl::memory::dims* src1, size_t num_src1_dims, const int64_t* src1_dims, dnnl::memory::dims& dst) { const int64_t num_dims = dst.size(); const int64_t num_src0_padding_dims = num_dims - num_src0_dims; const int64_t num_src1_padding_dims = num_dims - num_src1_dims; for (int64_t i = 0; i < num_dims; i++) { int64_t src0_dim = i < num_src0_padding_dims ? 1 : src0_dims[i - num_src0_padding_dims]; int64_t src1_dim = i < num_src1_padding_dims ? 
1 : src1_dims[i - num_src1_padding_dims]; CHECK((src0_dim == src1_dim || src0_dim == 1 || src1_dim == 1)); (*src0)[i] = src0_dim; (*src1)[i] = src1_dim; dst[i] = std::max(src0_dim, src1_dim); } } template class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary { public: OF_DISALLOW_COPY_AND_MOVE(OneDnnBroadcastElementwiseBinaryImpl); OneDnnBroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {} ~OneDnnBroadcastElementwiseBinaryImpl() override = default; void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims, const void* src1, void* dst) override { T scalar_val = GetValue(src0); const int64_t src0_dims = 1; Launch(stream, 1, &src0_dims, &scalar_val, num_src1_dims, src1_dims, src1, dst); } void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, Scalar src1, void* dst) override { T scalar_val = GetValue(src1); const int64_t src1_dims = 1; Launch(stream, num_src0_dims, src0_dims, src0, 1, &src1_dims, &scalar_val, dst); } void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, size_t num_src1_dims, const int64_t* src1_dims, const void* src1, void* dst) override { stream->As()->onednn_executor()->Launch([&](dnnl::engine* onednn_engine, dnnl::stream* onednn_stream) { // onednn do not optimize for 3d tensor in our experiments, so expand it // to 4d if needed. // Note that only onednn "internal" dims will be affected, the shape // of oneflow tensor (including the output tensor) will remain unchanged. size_t num_dims = std::max(std::max(num_src0_dims, num_src1_dims), static_cast(4)); dnnl::memory::dims src_0_dims(num_dims); dnnl::memory::dims src_1_dims(num_dims); dnnl::memory::dims dst_dims(num_dims); const void* onednn_src0 = nullptr; const void* onednn_src1 = nullptr; // OneDNN inplace operations only support src_0 if (src1 == dst) { onednn_src0 = src1; onednn_src1 = src0; OneDnnBroadcastDims(&src_0_dims, num_src1_dims, src1_dims, &src_1_dims, num_src0_dims, src0_dims, dst_dims); } else { onednn_src0 = src0; onednn_src1 = src1; OneDnnBroadcastDims(&src_0_dims, num_src0_dims, src0_dims, &src_1_dims, num_src1_dims, src1_dims, dst_dims); } CheckInplace(num_dims, src_0_dims.data(), onednn_src0, dst_dims.data(), dst); CheckInplace(num_dims, src_1_dims.data(), onednn_src1, dst_dims.data(), dst); auto src_0_md = dnnl::memory::desc( src_0_dims, src_onednn, static_cast(OnednnFormatTagMap[num_dims - 1])); auto src_1_md = dnnl::memory::desc( src_1_dims, src_onednn, static_cast(OnednnFormatTagMap[num_dims - 1])); auto dst_md = dnnl::memory::desc( dst_dims, dst_onednn, static_cast(OnednnFormatTagMap[num_dims - 1])); auto src_0_mem = dnnl::memory(src_0_md, *onednn_engine, (void*)onednn_src0); auto src_1_mem = dnnl::memory(src_1_md, *onednn_engine, (void*)onednn_src1); auto dst_mem = dnnl::memory(dst_md, *onednn_engine, dst); auto binary_d = dnnl::binary::desc(algorithm, src_0_md, src_1_md, dst_md); auto binary_pd = dnnl::binary::primitive_desc(binary_d, *onednn_engine); auto binary_prim = dnnl::binary(binary_pd); binary_prim.execute( *onednn_stream, {{DNNL_ARG_SRC_0, src_0_mem}, {DNNL_ARG_SRC_1, src_1_mem}, {DNNL_ARG_DST, dst_mem}}); }); } private: Scalar attr0, attr1; }; #define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool) \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float) // OneDNN binary op does not support s32 // CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ 
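// The BINARY_ONEDNN_* macros below pair each BinaryOp that oneDNN can handle
// with its dnnl::algorithm counterpart. The factory at the bottom of this file
// consults the resulting table first when oneDNN is enabled; any op/dtype
// combination without an entry falls back to the generic BinaryFunctor-based
// implementation registered in new_broadcast_elementwise_binary_handle.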
#define CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \ CPU_PRIMITIVE_INT8_TYPE_SEQ \ CPU_PRIMITIVE_UINT8_TYPE_SEQ \ CPU_PRIMITIVE_INT32_TYPE_SEQ \ CPU_PRIMITIVE_INT64_TYPE_SEQ #define BINARY_ONEDNN_ADD OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd, dnnl::algorithm::binary_add) #define BINARY_ONEDNN_SUB OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub, dnnl::algorithm::binary_sub) #define BINARY_ONEDNN_MUL OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul, dnnl::algorithm::binary_mul) #define BINARY_ONEDNN_DIV OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv, dnnl::algorithm::binary_div) #define BINARY_ONEDNN_MAX OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax, dnnl::algorithm::binary_max) #define BINARY_ONEDNN_MIN OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin, dnnl::algorithm::binary_min) #define BINARY_ONEDNN_EQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual, dnnl::algorithm::binary_eq) #define BINARY_ONEDNN_NE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual, dnnl::algorithm::binary_ne) #define BINARY_ONEDNN_LT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan, dnnl::algorithm::binary_lt) #define BINARY_ONEDNN_LE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual, dnnl::algorithm::binary_le) #define BINARY_ONEDNN_GT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan, dnnl::algorithm::binary_gt) #define BINARY_ONEDNN_GE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual, dnnl::algorithm::binary_ge) #define BINARY_MATH_OP_ONEDNN_PAIR \ BINARY_ONEDNN_ADD \ BINARY_ONEDNN_SUB \ BINARY_ONEDNN_MUL \ BINARY_ONEDNN_DIV \ BINARY_ONEDNN_MAX \ BINARY_ONEDNN_MIN #define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR \ BINARY_ONEDNN_EQ \ BINARY_ONEDNN_NE \ BINARY_ONEDNN_LT \ BINARY_ONEDNN_LE \ BINARY_ONEDNN_GT \ BINARY_ONEDNN_GE #define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_UNIMPLEMENTED \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd, AND) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR) template std::unique_ptr NewOneDnnBroadcastElementwiseBinary(Scalar attr0, Scalar attr1) { return std::unique_ptr( new OneDnnBroadcastElementwiseBinaryImpl(attr0, attr1)); } #define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op_pair, data_type_pair) \ {std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(data_type_pair), \ OF_PP_PAIR_SECOND(data_type_pair)), \ NewOneDnnBroadcastElementwiseBinary< \ OF_PP_PAIR_THIRD(data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair), \ OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>}, #define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \ binary_op_pair, src_data_type_pair, dst_data_type_pair) \ {std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(src_data_type_pair), \ OF_PP_PAIR_SECOND(dst_data_type_pair)), \ NewOneDnnBroadcastElementwiseBinary< \ OF_PP_PAIR_THIRD(src_data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair), \ OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>}, #endif // WITH_ONEDNN class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl); BroadcastElementwiseBinaryFactoryImpl() = default; ~BroadcastElementwiseBinaryFactoryImpl() override = default; std::unique_ptr New(BinaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims) override { return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar()); } std::unique_ptr New(BinaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0) override { return New(op, 
src_type, dst_type, max_num_dims, attr0, Scalar()); } std::unique_ptr New(BinaryOp binary_op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0, Scalar attr1) override { if (max_num_dims > kMaxNumDims) { return nullptr; } #define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \ OF_PP_PAIR_SECOND(data_type_pair)), \ NewBroadcastElementwiseBinary}, #define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \ binary_op, src_data_type_pair, dst_data_type_pair) \ {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair), \ OF_PP_PAIR_SECOND(dst_data_type_pair)), \ NewBroadcastElementwiseBinary}, #define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \ {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \ OF_PP_PAIR_SECOND(data_type_pair)), \ NewBroadcastElementwiseBinary}, static const std::map< std::tuple, std::function(Scalar, Scalar)>> new_broadcast_elementwise_binary_handle{ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_SEQ, NDARRAY_BINARY_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_COMPLEX_MATH_OP_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_BITWISE_OP_SEQ, CPU_PRIMITIVE_INT_TYPE_SEQ CPU_PRIMITIVE_BOOL_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_FLOATING_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, BINARY_LOGICAL_OP_SEQ BINARY_COMPARISION_OP_SEQ, NDARRAY_BINARY_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, BINARY_COMPLEX_COMPARISION_OP_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY, BINARY_ACTIVATION_BACKWARD_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ_COMPLEX, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)}; #undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY #undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY #ifdef WITH_ONEDNN static const std::map< std::tuple, std::function(Scalar, Scalar)>> new_broadcast_elementwise_binary_onednn_handle{ // For oneDNN binary op OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_ONEDNN_PAIR, CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ) // For OneDnn comparasion binary op OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR, CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ, CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ)}; #undef MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY #undef MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY if (OneDnnIsEnabled()) { const auto iter = new_broadcast_elementwise_binary_onednn_handle.find( std::make_tuple(binary_op, src_type, dst_type)); if (iter != 
new_broadcast_elementwise_binary_onednn_handle.end()) { return iter->second(attr0, attr1); } } #endif const auto iter = new_broadcast_elementwise_binary_handle.find( std::make_tuple(binary_op, src_type, dst_type)); if (iter != new_broadcast_elementwise_binary_handle.end()) { return iter->second(attr0, attr1); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BroadcastElementwiseBinaryFactory, BroadcastElementwiseBinaryFactoryImpl); } // namespace } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/broadcast_elementwise_unary.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cpu/primitive/unary_functor.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_unary { namespace { #define CPU_PRIMITIVE_CAST_REAL_TYPE_SEQ \ CPU_PRIMITIVE_INT16_TYPE_SEQ \ CPU_PRIMITIVE_NATIVE_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \ CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ bool IsContiguous(size_t num_dims, const int64_t* dims, const int64_t* strides) { for (int i = num_dims - 1; i >= 0; i--) { if ((i == num_dims - 1 && strides[i] != 1) || (i != num_dims - 1 && strides[i] != dims[i + 1] * strides[i + 1])) { return false; } } return true; } template void LaunchScalarFill(CpuStream* stream, Dst* dst, const Src* src, size_t count, size_t stride, Scalar attr0, Scalar attr1) { auto functor = UnaryFunctor(attr0, attr1); Dst scalar_value = functor(*src); stream->ParallelFor(0, count, [dst, stride, scalar_value](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { dst[i * stride] = scalar_value; } }); } template void LaunchTensorFill(CpuStream* stream, Dst* dst, const Src* src, size_t count, size_t dst_stride, size_t src_stride, Scalar attr0, Scalar attr1) { auto functor = UnaryFunctor(attr0, attr1); stream->ParallelFor(0, count, [functor, src, dst, src_stride, dst_stride](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { dst[i * dst_stride] = functor(src[i * src_stride]); } }); } template void LaunchGeneral(CpuStream* stream, Dst* dst, const Src* src, size_t num_dims, const int64_t* dst_dims, const int64_t* src_dims, const int64_t* dst_stride, const int64_t* src_stride, Scalar attr0, Scalar attr1) { bool contiguous_output = IsContiguous(num_dims, dst_dims, dst_stride); const int64_t elem_cnt = GetElementCount(num_dims, dst_dims); auto functor = UnaryFunctor(attr0, attr1); stream->ParallelFor( 0, elem_cnt, [functor, src, dst, num_dims, src_dims, dst_dims, src_stride, dst_stride, 
contiguous_output]( int64_t begin, int64_t end) { auto src_index_to_offset_helper = IndexToOffsetWithStrideCalculator(src_stride, num_dims); auto dst_offset_to_index_helper = OffsetToIndexWithStrideCalculator(dst_dims, num_dims); auto dst_index_to_offset_helper = IndexToOffsetWithStrideCalculator(dst_stride, num_dims); int64_t src_index[kMaxNumDims]; int64_t dst_index[kMaxNumDims]; for (int64_t offset = begin; offset < end; offset++) { dst_offset_to_index_helper.OffsetToNdIndex(offset, dst_index, num_dims); for (int i = 0; i < kMaxNumDims; i++) { if (i < num_dims) { src_index[i] = (src_dims[i] != 1) ? dst_index[i] : 0; } else { src_index[i] = 0; } } const int64_t src_offset = src_index_to_offset_helper.NdIndexToOffset(src_index, num_dims); if (!contiguous_output) { const int64_t dst_offset = dst_index_to_offset_helper.NdIndexToOffset(dst_index, num_dims); dst[dst_offset] = functor(src[src_offset]); } else { dst[offset] = functor(src[src_offset]); } } }); } template class BroadcastElementwiseUnaryImpl : public BroadcastElementwiseUnary { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryImpl); BroadcastElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {} ~BroadcastElementwiseUnaryImpl() override = default; void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src, size_t num_dst_dims, const int64_t* dst_dims, void* dst) override { CHECK_GT(num_src_dims, 0) << "num_src_dims must greater than 0"; CHECK_GT(num_dst_dims, 0) << "num_dst_dims must greater than 0"; int64_t src_strides[kMaxNumDims]; int64_t dst_strides[kMaxNumDims]; // init stride for (int i = num_src_dims - 1; i < kMaxNumDims; ++i) { src_strides[i] = 1; } for (int i = num_src_dims - 2; i >= 0; --i) { src_strides[i] = src_dims[i + 1] * src_strides[i + 1]; } for (int i = num_dst_dims - 1; i < kMaxNumDims; ++i) { dst_strides[i] = 1; } for (int i = num_dst_dims - 2; i >= 0; --i) { dst_strides[i] = dst_dims[i + 1] * dst_strides[i + 1]; } Launch(stream, num_src_dims, src_dims, src_strides, src, num_dst_dims, dst_dims, dst_strides, dst); } void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const int64_t* src_strides, const void* src_ptr, size_t num_dst_dims, const int64_t* dst_dims, const int64_t* dst_strides, void* dst_ptr) override { CHECK_GT(num_src_dims, 0) << "num_src_dims must greater than 0"; CHECK_GT(num_dst_dims, 0) << "num_dst_dims must greater than 0"; auto* cpu_stream = stream->As(); Dst* dst = reinterpret_cast(dst_ptr); const Src* src = reinterpret_cast(src_ptr); size_t simplified_num_dims = 0; int permutation_list[kMaxNumDims]; int64_t permutation_src_dims[kMaxNumDims]; int64_t simplified_src_dims[kMaxNumDims]; int64_t simplified_dst_dims[kMaxNumDims]; int64_t simplified_src_strides[kMaxNumDims]; int64_t simplified_dst_strides[kMaxNumDims]; SimplifyBroadcastDims(num_src_dims, src_dims, src_strides, num_dst_dims, dst_dims, dst_strides, &simplified_num_dims, simplified_src_dims, simplified_src_strides, simplified_dst_dims, simplified_dst_strides); bool permutable = InferPermutable( simplified_num_dims, simplified_src_strides, simplified_dst_strides, simplified_src_dims, simplified_dst_dims, permutation_list, permutation_src_dims, unary_op); std::unique_ptr permute = NewPrimitive(DeviceType::kCPU, simplified_num_dims); CheckInplace(simplified_num_dims, simplified_src_dims, src, simplified_dst_dims, dst); CheckInplace(simplified_num_dims, simplified_src_strides, src, simplified_dst_strides, dst); if (simplified_num_dims == 1 && 
simplified_src_dims[0] == 1) { const int64_t elem_cnt = simplified_dst_dims[0]; const int64_t dst_stride = simplified_dst_strides[0]; LaunchScalarFill(cpu_stream, dst, src, elem_cnt, dst_stride, attr0, attr1); } else if (simplified_num_dims == 1) { const int64_t elem_cnt = simplified_src_dims[0]; const int64_t src_stride = simplified_src_strides[0]; const int64_t dst_stride = simplified_dst_strides[0]; LaunchTensorFill(cpu_stream, dst, src, elem_cnt, dst_stride, src_stride, attr0, attr1); } else if (permutable && src_type == dst_type && permute) { permute->Launch(stream, dst_type, simplified_num_dims, permutation_src_dims, src_ptr, permutation_list, dst_ptr); } else { // fall back to normal cases LaunchGeneral( cpu_stream, dst, src, simplified_num_dims, simplified_dst_dims, simplified_src_dims, simplified_dst_strides, simplified_src_strides, attr0, attr1); } } protected: Scalar attr0, attr1; }; template std::unique_ptr NewBroadcastElementwiseUnary(Scalar attr0, Scalar attr1) { return std::unique_ptr( new BroadcastElementwiseUnaryImpl(attr0, attr1)); } class BroadcastElementwiseUnaryFactoryImpl : public BroadcastElementwiseUnaryFactory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactoryImpl); BroadcastElementwiseUnaryFactoryImpl() = default; ~BroadcastElementwiseUnaryFactoryImpl() override = default; std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims) override { return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar()); } std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0) override { return New(op, src_type, dst_type, max_num_dims, attr0, Scalar()); } std::unique_ptr New(UnaryOp unary_op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0, Scalar attr1) override { if (max_num_dims > kMaxNumDims) { return nullptr; } #define MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \ {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \ NewBroadcastElementwiseUnary}, #define MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, src_dtype_pair, dst_dtype_pair) \ {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_dtype_pair), \ OF_PP_PAIR_SECOND(dst_dtype_pair)), \ NewBroadcastElementwiseUnary< \ unary_op, OF_PP_PAIR_FIRST(src_dtype_pair), OF_PP_PAIR_SECOND(src_dtype_pair), \ OF_PP_PAIR_FIRST(dst_dtype_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)>}, static const std::map, std::function(Scalar, Scalar)>> new_broadcast_elementwise_unary_handle{ // For All Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY, UNARY_IDENTITY_SEQ, CPU_PRIMITIVE_ALL_TYPE_SEQ) // For Cast OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ, CPU_PRIMITIVE_CAST_REAL_TYPE_SEQ, CPU_PRIMITIVE_CAST_REAL_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ, CPU_PRIMITIVE_CAST_REAL_TYPE_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)}; #undef MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY #undef MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY const auto iter = new_broadcast_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_type)); if (iter != 
new_broadcast_elementwise_unary_handle.end()) { return iter->second(attr0, attr1); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BroadcastElementwiseUnaryFactory, BroadcastElementwiseUnaryFactoryImpl); } // namespace } // namespace broadcast_elementwise_unary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/broadcast_matmul.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/broadcast_matmul.h" #include "oneflow/core/ep/common/primitive/broadcast_matmul.h" #include "oneflow/core/common/blas.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_matmul { namespace internal { namespace { constexpr size_t kMaxNumDims = 8; CBLAS_TRANSPOSE GetCblasTranspose(BlasTransposeType transpose_type, DataType data_type) { if (transpose_type == BlasTransposeType::N) { return CblasNoTrans; } else if (transpose_type == BlasTransposeType::T) { return DType(data_type).is_complex() ? CblasConjTrans : CblasTrans; } else { UNIMPLEMENTED(); return CblasNoTrans; } } template>::value || std::is_same>::value)>::type* = nullptr> void CblasMatmul(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, int m, int n, int k, T alpha, const T* a, const T* b, T beta, T* c) { int lda = 0; if (trans_a == CblasNoTrans) { lda = k; } else if (trans_a == CblasTrans || trans_a == CblasConjTrans) { lda = m; } else { UNIMPLEMENTED(); } int ldb = 0; if (trans_b == CblasNoTrans) { ldb = n; } else if (trans_b == CblasTrans || trans_b == CblasConjTrans) { ldb = k; } else { UNIMPLEMENTED(); } const int ldc = n; cblas_gemm(CblasRowMajor, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } template>::value || std::is_same>::value>::type* = nullptr> void CblasMatmul(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, int m, int n, int k, T alpha, const T* a, const T* b, T beta, T* c) { int lda = 0; if (trans_a == CblasNoTrans) { lda = k; } else if (trans_a == CblasTrans || trans_a == CblasConjTrans) { lda = m; } else { UNIMPLEMENTED(); } int ldb = 0; if (trans_b == CblasNoTrans) { ldb = n; } else if (trans_b == CblasTrans || trans_b == CblasConjTrans) { ldb = k; } else { UNIMPLEMENTED(); } const int ldc = n; cblas_gemm(CblasRowMajor, trans_a, trans_b, m, n, k, reinterpret_cast(&alpha), reinterpret_cast(a), lda, reinterpret_cast(b), ldb, reinterpret_cast(&beta), reinterpret_cast(c), ldc); } template void LaunchCblasBroadcastMatmul(Stream* /*stream*/, DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b, int64_t num_batch_dims, const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims, const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m, int64_t n, int64_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c) { const CBLAS_TRANSPOSE 
cblas_trans_a = GetCblasTranspose(transpose_a, data_type); const CBLAS_TRANSPOSE cblas_trans_b = GetCblasTranspose(transpose_b, data_type); const T alpha_value = alpha.Value(); auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) { const T beta_value = batch_beta.Value(); CblasMatmul(cblas_trans_a, cblas_trans_b, m, n, k, alpha_value, static_cast(batch_a), static_cast(batch_b), beta_value, static_cast(batch_c)); }; ForEachMatmul(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func); } void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b, int64_t num_batch_dims, const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims, const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m, int64_t n, int64_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c) { if (data_type == DataType::kFloat) { LaunchCblasBroadcastMatmul(stream, data_type, transpose_a, transpose_b, num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims, m, n, k, alpha, a, b, beta, c); } else if (data_type == DataType::kDouble) { LaunchCblasBroadcastMatmul(stream, data_type, transpose_a, transpose_b, num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims, m, n, k, alpha, a, b, beta, c); } else if (data_type == DataType::kComplex64) { LaunchCblasBroadcastMatmul>( stream, data_type, transpose_a, transpose_b, num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims, m, n, k, alpha, a, b, beta, c); } else if (data_type == DataType::kComplex128) { LaunchCblasBroadcastMatmul>( stream, data_type, transpose_a, transpose_b, num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims, m, n, k, alpha, a, b, beta, c); } else { UNIMPLEMENTED(); } } class BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl); BroadcastMatmulFactoryImpl() = default; ~BroadcastMatmulFactoryImpl() override = default; std::unique_ptr New(DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b, size_t max_num_dims) override { if (max_num_dims > kMaxNumDims) { return nullptr; } if (data_type == DataType::kFloat || data_type == DataType::kDouble || data_type == DataType::kComplex64 || data_type == DataType::kComplex128) { return std::make_unique>(data_type, transpose_a, transpose_b); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl); } // namespace } // namespace internal } // namespace broadcast_matmul } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/cast.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" namespace oneflow { namespace ep { namespace primitive { namespace { template struct CpuCastFunctor { static void Call(const From* from, To* to, size_t count) { for (size_t i = 0; i < count; ++i) { to[i] = static_cast(from[i]); } } }; template struct CpuCastFunctor::value)>::type> { static void Call(const bfloat16* from, To* to, size_t count) { for (size_t i = 0; i < count; ++i) { to[i] = static_cast(static_cast(from[i])); } } }; template struct CpuCastFunctor::value)>::type> { static void Call(const From* from, bfloat16* to, size_t count) { for (size_t i = 0; i < count; ++i) { to[i] = bfloat16(static_cast(from[i])); } } }; template class CastImpl : public Cast { public: OF_DISALLOW_COPY_AND_MOVE(CastImpl); CastImpl() = default; ~CastImpl() override = default; void Launch(Stream* stream, const void* from, void* to, size_t count) override { CpuCastFunctor::Call(reinterpret_cast(from), reinterpret_cast(to), count); } }; template std::unique_ptr NewCast() { return std::unique_ptr(new CastImpl()); } #define CPU_PRIMITIVE_CAST_TYPE_SEQ \ CPU_PRIMITIVE_BOOL_TYPE_SEQ \ CPU_PRIMITIVE_CHAR_TYPE_SEQ \ CPU_PRIMITIVE_INT8_TYPE_SEQ \ CPU_PRIMITIVE_UINT8_TYPE_SEQ \ CPU_PRIMITIVE_INT32_TYPE_SEQ \ CPU_PRIMITIVE_UINT32_TYPE_SEQ \ CPU_PRIMITIVE_INT64_TYPE_SEQ \ CPU_PRIMITIVE_UINT64_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT_TYPE_SEQ \ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \ CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ class CastFactoryImpl : public CastFactory { public: OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl); CastFactoryImpl() = default; ~CastFactoryImpl() override = default; std::unique_ptr New(DataType from, DataType to) override { #define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \ {std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \ NewCast}, static const std::map, std::function()>> new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_CAST_ENTRY, CPU_PRIMITIVE_CAST_TYPE_SEQ, CPU_PRIMITIVE_CAST_TYPE_SEQ)}; #undef MAKE_NEW_CAST_ENTRY const auto it = new_cast_handle.find(std::make_pair(from, to)); if (it != new_cast_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, CastFactory, CastFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/constant_pad.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/include/primitive/constant_pad.h" #include "oneflow/core/ep/common/primitive/constant_pad.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" namespace oneflow { namespace ep { namespace primitive { namespace { template void ConstantPadKernel(ConstantPadParams params, StorageType packed_pad_val) { const StorageType* src = reinterpret_cast(params.src); StorageType* dst = reinterpret_cast(params.dst); IndexType src_index[num_dims]; IndexType dst_index[num_dims]; for (IndexType linear_index = 0; linear_index < params.elem_cnt; ++linear_index) { params.dst_index_helper.OffsetToNdIndex(linear_index, dst_index); bool if_pad = false; for (int i = 0; i < num_dims; i++) { if (dst_index[i] >= params.valid_start[i] && dst_index[i] < params.valid_end[i]) { src_index[i] = dst_index[i] - params.valid_start[i]; } else { if_pad = true; break; } } StorageType dst_val = packed_pad_val; if (!if_pad) { const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index); dst_val = src[src_offset]; } dst[linear_index] = dst_val; } } template<> float16 GetValue(Scalar value) { return static_cast(GetValue(value)); } template<> bfloat16 GetValue(Scalar value) { return static_cast(GetValue(value)); } template void LaunchKernel(ConstantPadParams params, StorageType packed_pad_val) { ConstantPadKernel(params, packed_pad_val); } template void LaunchKernel(void* dst, const int64_t* dst_dims, const void* src, const int64_t* src_dims, const int64_t* padding_before, const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) { ConstantPadParams params; params.dst_index_helper = OffsetToIndexCalculator(dst_dims); params.src_index_helper = NdIndexOffsetHelper(src_dims); params.dst = dst; params.src = src; for (int i = 0; i < num_dims; i++) { params.valid_start[i] = padding_before[i]; params.valid_end[i] = dst_dims[i] - padding_after[i]; } params.elem_cnt = elem_cnt; LaunchKernel(params, packed_pad_val); } template void DispatchIndexType(void* dst, const int64_t* dst_dims, const void* src, const int64_t* src_dims, const int64_t* padding_before, const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) { if (elem_cnt < GetMaxVal()) { LaunchKernel(dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val, elem_cnt); } else { LaunchKernel(dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val, elem_cnt); } } template void DispatchPackSize(void* dst, int64_t* dst_dims, const void* src, int64_t* src_dims, int64_t* padding_before, int64_t* padding_after, T pad_val) { constexpr int32_t max_packsize = GetMaxPackSize(); size_t launch_pack_size = GetLaunchPackSize(num_dims, dst, dst_dims, src, src_dims, padding_before, padding_after); dst_dims[num_dims - 1] /= launch_pack_size; src_dims[num_dims - 1] /= launch_pack_size; padding_before[num_dims - 1] /= launch_pack_size; padding_after[num_dims - 1] /= launch_pack_size; size_t elem_cnt = 1; for (int i = 0; i < num_dims; i++) { elem_cnt *= dst_dims[i]; } if (launch_pack_size == 1) { Pack packed_pad_val(pad_val); DispatchIndexType>(dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else if (launch_pack_size == 2) { Pack packed_pad_val(pad_val); DispatchIndexType>(dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else if (launch_pack_size == 4) { Pack packed_pad_val(pad_val); DispatchIndexType>(dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, 
elem_cnt); } else if (launch_pack_size == 8) { Pack packed_pad_val(pad_val); DispatchIndexType>(dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else if (launch_pack_size == 16) { Pack packed_pad_val(pad_val); DispatchIndexType>(dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else { UNIMPLEMENTED(); } } template void LaunchWithSimplified(size_t num_dims, void* dst, int64_t* dst_dims, const void* src, int64_t* src_dims, int64_t* padding_before, int64_t* padding_after, T pad_val) { void (*func)(void* /*dst*/, int64_t* /*dst_dims*/, const void* /*src*/, int64_t* /*src_dims*/, int64_t* /*padding_before*/, int64_t* /*padding_after*/, T) = nullptr; if (num_dims == 1) { func = DispatchPackSize<1, T>; } else if (num_dims == 2) { func = DispatchPackSize<2, T>; } else if (num_dims == 3) { func = DispatchPackSize<3, T>; } else if (num_dims == 4) { func = DispatchPackSize<4, T>; } else if (num_dims == 5) { func = DispatchPackSize<5, T>; } else if (num_dims == 6) { func = DispatchPackSize<6, T>; } else if (num_dims == 7) { func = DispatchPackSize<7, T>; } else if (num_dims == 8) { func = DispatchPackSize<8, T>; } else { UNIMPLEMENTED(); } func(dst, dst_dims, src, src_dims, padding_before, padding_after, pad_val); } template void SimplifyThenLaunch(size_t num_dims, const int64_t* src_dims, const void* src, const int64_t* padding_before, const int64_t* padding_after, T pad_val, void* dst) { CHECK_GT(num_dims, 0) << "num_dims must greater than 0"; CHECK_LE(num_dims, kMaxNumDims); int64_t simplified_dst_dims[kMaxNumDims]; int64_t simplified_src_dims[kMaxNumDims]; int64_t simplified_padding_before[kMaxNumDims]; int64_t simplified_padding_after[kMaxNumDims]; size_t simplified_num_dims = 1; SimplifyPadDims(num_dims, src_dims, padding_before, padding_after, &simplified_num_dims, simplified_dst_dims, simplified_src_dims, simplified_padding_before, simplified_padding_after); LaunchWithSimplified(simplified_num_dims, dst, simplified_dst_dims, src, simplified_src_dims, simplified_padding_before, simplified_padding_after, pad_val); } template class ConstantPadImpl : public ConstantPad { public: OF_DISALLOW_COPY_AND_MOVE(ConstantPadImpl); ConstantPadImpl() = default; ~ConstantPadImpl() override = default; void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src, const int64_t* padding_before, const int64_t* padding_after, Scalar pad_val, void* dst) override { SimplifyThenLaunch(num_dims, src_dims, src, padding_before, padding_after, GetValue(pad_val), dst); } }; template std::unique_ptr NewConstantPad() { return std::unique_ptr(new ConstantPadImpl()); } class ConstantPadFactoryImpl : public ConstantPadFactory { public: OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactoryImpl); ConstantPadFactoryImpl() = default; ~ConstantPadFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad}, static const std::map()>> new_constant_pad_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_CONSTANT_PAD_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)}; #undef MAKE_NEW_CONSTANT_PAD_ENTRY const auto it = new_constant_pad_handle.find(data_type); if (it != new_constant_pad_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, ConstantPadFactory, ConstantPadFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow 
================================================ FILE: oneflow/core/ep/cpu/primitive/copy_nd.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/common/primitive/copy_nd.h" namespace oneflow { namespace ep { namespace primitive { namespace { template void CopyNdKernel(CopyNdKernelParams params) { using T = typename std::aligned_storage::type; const T* src = reinterpret_cast(params.src); T* dst = reinterpret_cast(params.dst); for (IndexType i = 0; i < params.count; ++i) { IndexType copy_index[num_dims]; IndexType src_index[num_dims]; IndexType dst_index[num_dims]; params.copy_index_helper.OffsetToNdIndex(i, copy_index); for (size_t j = 0; j < num_dims; ++j) { src_index[j] = params.src_pos[j] + copy_index[j]; dst_index[j] = params.dst_pos[j] + copy_index[j]; } const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index); const IndexType dst_offset = params.dst_index_helper.NdIndexToOffset(dst_index); dst[dst_offset] = src[src_offset]; } } template void LaunchKernel(Stream* stream, CopyNdKernelParams params) { CopyNdKernel(params); } class CopyNdImpl : public CopyNd { public: OF_DISALLOW_COPY_AND_MOVE(CopyNdImpl); CopyNdImpl() = default; ~CopyNdImpl() = default; void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) const override { SimplifyThenLaunch(stream, data_type, num_dims, dst, dst_dims, dst_pos, src, src_dims, src_pos, extent); } }; class CopyNdFactoryImpl : public CopyNdFactory { public: OF_DISALLOW_COPY_AND_MOVE(CopyNdFactoryImpl); CopyNdFactoryImpl() = default; ~CopyNdFactoryImpl() override = default; std::unique_ptr New(size_t max_num_dims) override { if (max_num_dims <= kMaxNumDims) { return std::unique_ptr(new CopyNdImpl()); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, CopyNdFactory, CopyNdFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/elementwise_unary.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/common/primitive/elementwise_unary.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/ep/cpu/primitive/unary_functor.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" namespace oneflow { namespace ep { namespace primitive { namespace { template class ElementwiseUnaryImpl : public ElementwiseUnary { public: OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl); ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {} ~ElementwiseUnaryImpl() override = default; void Launch(Stream* stream, const void* src_ptr, void* dst_ptr, size_t count) override { CpuStream* cpu_stream = stream->As(); Dst* dst = reinterpret_cast(dst_ptr); const Src* src = reinterpret_cast(src_ptr); auto functor = UnaryFunctor(attr0, attr1); cpu_stream->ParallelFor(0, count, [functor, src, dst](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { dst[i] = functor(src[i]); } }); } protected: Scalar attr0, attr1; }; template std::unique_ptr NewElementwiseUnary(Scalar attr0, Scalar attr1) { return std::unique_ptr( new ElementwiseUnaryImpl(attr0, attr1)); } class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory { public: OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl); ElementwiseUnaryFactoryImpl() = default; ~ElementwiseUnaryFactoryImpl() override = default; std::unique_ptr New(UnaryOp unary_op, DataType src_type, DataType dst_dtype) override { return New(unary_op, src_type, dst_dtype, Scalar(), Scalar()); } std::unique_ptr New(UnaryOp unary_op, DataType src_type, DataType dst_dtype, Scalar attr0) override { return New(unary_op, src_type, dst_dtype, attr0, Scalar()); } std::unique_ptr New(UnaryOp unary_op, DataType src_type, DataType dst_dtype, Scalar attr0, Scalar attr1) override { #define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \ {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \ NewElementwiseUnary}, #define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair) \ {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \ NewElementwiseUnary}, static const std::map, std::function(Scalar, Scalar)>> new_elementwise_unary_handle{ // For All Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_MATH_OP_SEQ, CPU_PRIMITIVE_NATIVE_TYPE_SEQ) // For Float Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_FLOATING_MATH_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ) // For Complex Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2C_OP_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2R_OP_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_R2C_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ) // For Int Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_INT_MATH_OP_SEQ, CPU_PRIMITIVE_INT_TYPE_SEQ) // For Bitwise OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_BITWISE_OP_SEQ, CPU_PRIMITIVE_INT_TYPE_SEQ CPU_PRIMITIVE_BOOL_TYPE_SEQ) // For Utils OP 
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ) // For Logical OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_LOGICAL_OP_SEQ, CPU_PRIMITIVE_NATIVE_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)}; #undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY #undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY const auto it = new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype)); if (it != new_elementwise_unary_handle.end()) { return it->second(attr0, attr1); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, ElementwiseUnaryFactory, ElementwiseUnaryFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/fill.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { namespace { template T GetValue(Scalar value) { return value.Value(); } template<> float16 GetValue(Scalar value) { return static_cast(GetValue(value)); } template<> bfloat16 GetValue(Scalar value) { return static_cast(GetValue(value)); } template class FillImpl : public Fill { public: OF_DISALLOW_COPY_AND_MOVE(FillImpl); FillImpl() = default; ~FillImpl() override = default; void Launch(Stream* stream, void* dst, Scalar value, size_t count) override { std::fill_n(reinterpret_cast(dst), count, GetValue(value)); } }; template std::unique_ptr NewFill() { return std::unique_ptr(new FillImpl()); } class FillFactoryImpl : public FillFactory { public: OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl); FillFactoryImpl() = default; ~FillFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill}, static const std::map()>> new_fill_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ CPU_PRIMITIVE_INT16_TYPE_SEQ)}; #undef MAKE_NEW_ADD_ENTRY const auto it = new_fill_handle.find(data_type); if (it != new_fill_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, FillFactory, FillFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/memcpy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/memcpy.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace {

class MemcpyImpl : public Memcpy {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemcpyImpl);
  MemcpyImpl() = default;
  ~MemcpyImpl() = default;

  void Launch(Stream* stream, void* dst, const void* src, size_t count) {
    if (dst == src) { return; }
    std::memcpy(dst, src, count);
  }
};

class MemcpyFactoryImpl : public MemcpyFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl);
  MemcpyFactoryImpl() = default;
  ~MemcpyFactoryImpl() override = default;

  std::unique_ptr<Memcpy> New(MemcpyKind kind) override { return std::make_unique<MemcpyImpl>(); }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, MemcpyFactory, MemcpyFactoryImpl);

}  // namespace
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow

================================================ FILE: oneflow/core/ep/cpu/primitive/memset.cpp ================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/memset.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace {

class MemsetImpl : public Memset {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemsetImpl);
  MemsetImpl() = default;
  ~MemsetImpl() = default;

  void Launch(Stream* stream, void* ptr, int value, size_t count) {
    std::memset(ptr, value, count);
  }
};

class MemsetFactoryImpl : public MemsetFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl);
  MemsetFactoryImpl() = default;
  ~MemsetFactoryImpl() override = default;

  std::unique_ptr<Memset> New() override { return std::make_unique<MemsetImpl>(); }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, MemsetFactory, MemsetFactoryImpl);

}  // namespace
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow

================================================ FILE: oneflow/core/ep/cpu/primitive/permute.cpp ================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/common/primitive/permute_impl.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include "oneflow/core/ep/common/onednn.h" namespace oneflow { namespace ep { namespace primitive { namespace permute { namespace internal { namespace { template void PermuteKernel(PermuteKernelParams params) { using T = typename std::aligned_storage::type; const T* src = reinterpret_cast(params.src); T* dst = reinterpret_cast(params.dst); for (IndexType i = 0; i < params.count; ++i) { IndexType src_index[num_dims]; IndexType dst_index[num_dims]; params.dst_index_helper.OffsetToNdIndex(i, dst_index); for (size_t dim = 0; dim < num_dims; ++dim) { src_index[params.permutation[dim]] = dst_index[dim]; } IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index); dst[i] = src[src_offset]; } } template void LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation, void* dst, size_t count) { PermuteKernelParams params = MakePermuteParams(src_dims, src, permutation, dst, count); PermuteKernel(params); } class PermuteImpl : public Permute { public: OF_DISALLOW_COPY_AND_MOVE(PermuteImpl); PermuteImpl() = default; ~PermuteImpl() override = default; using Permute::Launch; void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims, const void* src, const int* permutation, void* dst) override { SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst); } }; #ifdef WITH_ONEDNN constexpr size_t kMaxOneDnnMovementSize = 4; constexpr size_t kMaxOneDnnMapSize = 5; uint32_t OnednnDatatypeTagMap[kMaxOneDnnMapSize] = {0, dnnl_u8, dnnl_f16, 0, dnnl_s32}; class OneDnnPermuteImpl : public Permute { public: OF_DISALLOW_COPY_AND_MOVE(OneDnnPermuteImpl); OneDnnPermuteImpl() = default; ~OneDnnPermuteImpl() override = default; using Permute::Launch; void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims, const void* src, const int* permutation, void* dst) override { CHECK_LE(num_dims, kMaxNumDims); CHECK_GT(num_dims, 0); stream->As()->onednn_executor()->Launch([&](dnnl::engine* onednn_engine, dnnl::stream* onednn_stream) { size_t onednn_num_dims = num_dims; dnnl::memory::dims onednn_dims(kMaxNumDims + 1, 0); dnnl::memory::dims onednn_permute(kMaxNumDims + 1, 0); dnnl::memory::dims src_stride(kMaxNumDims + 1, 0); dnnl::memory::dims dst_stride(kMaxNumDims + 1, 0); for (int64_t dim = onednn_num_dims - 1; dim >= 0; dim--) { onednn_dims[dim] = src_dims[dim]; onednn_permute[dim] = permutation[dim]; } size_t movement_size = GetSizeOfDataType(data_type); if (movement_size > kMaxOneDnnMovementSize) { onednn_dims[onednn_num_dims] = movement_size / kMaxOneDnnMovementSize; onednn_permute[onednn_num_dims] = onednn_num_dims; onednn_num_dims = onednn_num_dims + 1; movement_size = kMaxOneDnnMovementSize; } onednn_dims.resize(onednn_num_dims); src_stride[onednn_num_dims - 1] = 1; dst_stride[onednn_permute[onednn_num_dims - 1]] = 1; for (int64_t i = onednn_num_dims - 2; i >= 0; i--) { src_stride[i] = src_stride[i + 1] * onednn_dims[i + 1]; dst_stride[onednn_permute[i]] = dst_stride[onednn_permute[i + 1]] * onednn_dims[onednn_permute[i + 1]]; } dnnl::memory::data_type onednn_data_type = static_cast(OnednnDatatypeTagMap[movement_size]); // The reorder primitive requires the source and destination tensors to have the same shape. // Implicit broadcasting is not supported. 
auto src_mem_desc = dnnl::memory::desc(onednn_dims, onednn_data_type, src_stride); auto dst_mem_desc = dnnl::memory::desc(onednn_dims, onednn_data_type, dst_stride); auto src_mem = dnnl::memory(src_mem_desc, *onednn_engine, const_cast(src)); auto dst_mem = dnnl::memory(dst_mem_desc, *onednn_engine, dst); auto reorder_primitive_desc = dnnl::reorder::primitive_desc(*onednn_engine, src_mem_desc, *onednn_engine, dst_mem_desc); auto reorder_primitive = dnnl::reorder(reorder_primitive_desc); reorder_primitive.execute(*onednn_stream, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}}); }); } }; #endif // WITH_ONEDNN class PermuteFactoryImpl : public PermuteFactory { public: OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl); PermuteFactoryImpl() = default; ~PermuteFactoryImpl() override = default; std::unique_ptr New(size_t max_num_dims) override { if (max_num_dims <= kMaxNumDims) { #ifdef WITH_ONEDNN if (OneDnnIsEnabled()) { return std::unique_ptr(new OneDnnPermuteImpl()); } #endif return std::unique_ptr(new PermuteImpl()); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, PermuteFactory, PermuteFactoryImpl); } // namespace } // namespace internal } // namespace permute } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/softmax.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/include/primitive/softmax.h" #include "oneflow/core/ep/include/primitive/log_softmax.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include "oneflow/core/ep/common/primitive/util.h" #include "oneflow/core/ep/common/onednn.h" namespace oneflow { namespace ep { namespace primitive { namespace { enum class Algorithm { kSoftmax, kLogSoftmax, }; template void SoftmaxCpu(size_t rows, size_t cols, const T* x, T* y) { for (size_t i = 0; i < rows; ++i) { size_t row_offset = i * cols; const T* row_x = x + row_offset; T* row_y = y + row_offset; const T row_max = *std::max_element(row_x, row_x + cols); T row_sum = 0; for (size_t j = 0; j < cols; ++j) { if (algorithm == Algorithm::kSoftmax) { T exp_x = std::exp(row_x[j] - row_max); row_sum += exp_x; row_y[j] = exp_x; } else if (algorithm == Algorithm::kLogSoftmax) { row_y[j] = row_x[j] - row_max; row_sum += std::exp(row_y[j]); } else { UNIMPLEMENTED(); } } for (size_t j = 0; j < cols; ++j) { if (algorithm == Algorithm::kSoftmax) { row_y[j] /= row_sum; } else if (algorithm == Algorithm::kLogSoftmax) { row_y[j] -= std::log(row_sum); } else { UNIMPLEMENTED(); } } } } template class SoftmaxImpl : public SoftmaxBase { public: OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl); SoftmaxImpl() = default; ~SoftmaxImpl() override = default; void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override { SoftmaxCpu(rows, cols, reinterpret_cast(x), reinterpret_cast(y)); } }; #ifdef WITH_ONEDNN template void SoftmaxOneDnn(Stream* stream, size_t rows, size_t cols, const void* x, void* y) { stream->As()->onednn_executor()->Launch( [&](dnnl::engine* onednn_engine, dnnl::stream* onednn_stream) { dnnl::memory::dims src_dims = {static_cast(rows), static_cast(cols)}; auto src_md = dnnl::memory::desc(src_dims, data_type, dnnl::memory::format_tag::nc); auto src_mem = dnnl::memory(src_md, *onednn_engine, const_cast(x)); auto dst_mem = dnnl::memory(src_md, *onednn_engine, y); auto softmax_d = typename OneDnnSoftmax::desc(dnnl::prop_kind::forward, src_md, 1); auto softmax_pd = typename OneDnnSoftmax::primitive_desc(softmax_d, *onednn_engine); auto softmax_prim = OneDnnSoftmax(softmax_pd); softmax_prim.execute(*onednn_stream, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}}); }); } template class OneDnnSoftmaxImpl; #define CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(oneflow_algorithm, onednn_algorithm) \ template \ class OneDnnSoftmaxImpl : public SoftmaxBase { \ public: \ OF_DISALLOW_COPY_AND_MOVE(OneDnnSoftmaxImpl); \ OneDnnSoftmaxImpl() = default; \ ~OneDnnSoftmaxImpl() override = default; \ \ using OneDnnClass = onednn_algorithm; \ void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override { \ SoftmaxOneDnn(stream, rows, cols, x, y); \ } \ } CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(Algorithm::kSoftmax, dnnl::softmax_forward); CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(Algorithm::kLogSoftmax, dnnl::logsoftmax_forward); #undef CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL template std::unique_ptr NewOneDnnSoftmax() { return std::unique_ptr(new OneDnnSoftmaxImpl()); } #endif // WITH_ONEDNN template std::unique_ptr NewSoftmax() { return std::unique_ptr(new SoftmaxImpl()); } template class GenericSoftmaxFactoryImpl : public FactoryBase { public: OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl); GenericSoftmaxFactoryImpl() = default; ~GenericSoftmaxFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define 
MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \ {type_proto, NewSoftmax}, static const std::map()>> new_softmax_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)}; #undef MAKE_NEW_SOFTMAX_ENTRY #ifdef WITH_ONEDNN if (OneDnnIsEnabled() && data_type == DataType::kFloat) { static std::function()> onednn_softmax = NewOneDnnSoftmax; return onednn_softmax(); } #endif return NewPrimitiveFromHandlers(new_softmax_handle, data_type); } }; using SoftmaxFactoryImpl = GenericSoftmaxFactoryImpl; using LogSoftmaxFactoryImpl = GenericSoftmaxFactoryImpl; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, SoftmaxFactory, SoftmaxFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, LogSoftmaxFactory, LogSoftmaxFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/softmax_backward.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/softmax_backward.h" #include "oneflow/core/ep/include/primitive/log_softmax_backward.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include "oneflow/core/ep/common/onednn.h" #include "oneflow/core/ep/common/primitive/util.h" namespace oneflow { namespace ep { namespace primitive { namespace { enum class Algorithm { kSoftmax, kLogSoftmax, }; template void SoftmaxBackwardCpu(size_t rows, size_t cols, const T* y, const T* dy, T* dx) { for (size_t i = 0; i < rows; ++i) { size_t row_offset = i * cols; const T* row_y = y + row_offset; const T* row_dy = dy + row_offset; T* row_dx = dx + row_offset; T row_sum = 0; for (size_t j = 0; j < cols; ++j) { if (algorithm == Algorithm::kSoftmax) { row_sum += row_y[j] * row_dy[j]; } else if (algorithm == Algorithm::kLogSoftmax) { row_sum += row_dy[j]; } else { UNIMPLEMENTED(); } } for (size_t j = 0; j < cols; ++j) { if (algorithm == Algorithm::kSoftmax) { row_dx[j] = (row_dy[j] - row_sum) * row_y[j]; } else if (algorithm == Algorithm::kLogSoftmax) { row_dx[j] = row_dy[j] - std::exp(row_y[j]) * row_sum; } else { UNIMPLEMENTED(); } } } } template class SoftmaxBackwardImpl : public SoftmaxBackwardBase { public: OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardImpl); SoftmaxBackwardImpl() = default; ~SoftmaxBackwardImpl() override = default; void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy, void* dx) override { SoftmaxBackwardCpu(rows, cols, reinterpret_cast(y), reinterpret_cast(dy), reinterpret_cast(dx)); } }; #ifdef WITH_ONEDNN template void SoftmaxBackwardOneDnn(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy, void* dx) { stream->As()->onednn_executor()->Launch([&](dnnl::engine* onednn_engine, dnnl::stream* onednn_stream) { dnnl::memory::dims src_dims = {static_cast(rows), static_cast(cols)}; // Input and output parameters of the same 
data type auto same_md = dnnl::memory::desc(src_dims, data_type, dnnl::memory::format_tag::nc); // Backward memory auto dst_mem = dnnl::memory(same_md, *onednn_engine, const_cast(y)); auto diff_dst_mem = dnnl::memory(same_md, *onednn_engine, const_cast(dy)); // Forward primitive description auto forward_desc = typename OneDnnSoftmaxForward::desc(dnnl::prop_kind::forward, same_md, 1); auto forward_prim_desc = typename OneDnnSoftmaxForward::primitive_desc(forward_desc, *onednn_engine); // Backward primitive description auto diff_src_mem = dnnl::memory(same_md, *onednn_engine, dx); auto backward_desc = typename OneDnnSoftmaxBackward::desc(same_md, same_md, 1); auto backward_prim_desc = typename OneDnnSoftmaxBackward::primitive_desc( backward_desc, *onednn_engine, forward_prim_desc); auto backward_prim = OneDnnSoftmaxBackward(backward_prim_desc); backward_prim.execute(*onednn_stream, {{DNNL_ARG_DIFF_DST, diff_dst_mem}, {DNNL_ARG_DST, dst_mem}, {DNNL_ARG_DIFF_SRC, diff_src_mem}}); }); } template class OneDnnSoftmaxBackwardImpl; #define CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(oneflow_algorithm, onednn_backward_algorithm, \ onednn_forward_algorithm) \ template \ class OneDnnSoftmaxBackwardImpl \ : public SoftmaxBackwardBase { \ public: \ OF_DISALLOW_COPY_AND_MOVE(OneDnnSoftmaxBackwardImpl); \ OneDnnSoftmaxBackwardImpl() = default; \ ~OneDnnSoftmaxBackwardImpl() override = default; \ \ void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy, \ void* dx) override { \ SoftmaxBackwardOneDnn( \ stream, rows, cols, y, dy, dx); \ } \ } CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(Algorithm::kSoftmax, dnnl::softmax_backward, dnnl::softmax_forward); CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(Algorithm::kLogSoftmax, dnnl::logsoftmax_backward, dnnl::logsoftmax_forward); #undef CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL template std::unique_ptr NewOneDnnSoftmaxBackward() { return std::unique_ptr( new OneDnnSoftmaxBackwardImpl()); } #endif // WITH_ONEDNN template std::unique_ptr NewSoftmaxBackward() { return std::unique_ptr( new SoftmaxBackwardImpl()); } template class GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase { public: OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl); GenericSoftmaxBackwardFactoryImpl() = default; ~GenericSoftmaxBackwardFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_SOFTMAX_BACKWARD_ENTRY(type_cpp, type_proto) \ {type_proto, NewSoftmaxBackward}, static const std::map()>> new_softmax_backward_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_BACKWARD_ENTRY, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)}; #undef MAKE_NEW_SOFTMAX_BACKWARD_ENTRY #ifdef WITH_ONEDNN if (OneDnnIsEnabled() && data_type == DataType::kFloat) { static std::function()> onednn_f32_softmax_backward = NewOneDnnSoftmaxBackward; return onednn_f32_softmax_backward(); } #endif return NewPrimitiveFromHandlers(new_softmax_backward_handle, data_type); } }; using SoftmaxBackwardFactoryImpl = GenericSoftmaxBackwardFactoryImpl; using LogSoftmaxBackwardFactoryImpl = GenericSoftmaxBackwardFactoryImpl; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, SoftmaxBackwardFactory, SoftmaxBackwardFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, LogSoftmaxBackwardFactory, LogSoftmaxBackwardFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/tensor_fill.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/tensor_fill.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" namespace oneflow { namespace ep { namespace primitive { namespace { template class TensorFillImpl : public TensorFill { public: OF_DISALLOW_COPY_AND_MOVE(TensorFillImpl); TensorFillImpl() = default; ~TensorFillImpl() override = default; void Launch(Stream* stream, const void* src, void* dst, size_t count) override { const T* value = reinterpret_cast(src); std::fill_n(reinterpret_cast(dst), count, value[0]); } }; template std::unique_ptr NewTensorFill() { return std::unique_ptr(new TensorFillImpl()); } class TensorFillFactoryImpl : public TensorFillFactory { public: OF_DISALLOW_COPY_AND_MOVE(TensorFillFactoryImpl); TensorFillFactoryImpl() = default; ~TensorFillFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewTensorFill}, static const std::map()>> new_fill_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)}; #undef MAKE_NEW_ADD_ENTRY const auto it = new_fill_handle.find(data_type); if (it != new_fill_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, TensorFillFactory, TensorFillFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/type_seq.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_CPU_PRIMITIVE_TYPE_SEQ_H_ #define ONEFLOW_CORE_EP_CPU_PRIMITIVE_TYPE_SEQ_H_ #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/data_type.h" #include #ifdef WITH_ONEDNN #include "oneapi/dnnl/dnnl.hpp" #endif #define CPU_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool) #define CPU_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar) #define CPU_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) #define CPU_PRIMITIVE_INT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int16_t, DataType::kInt16) #define CPU_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) #define CPU_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define CPU_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) #define CPU_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define CPU_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) #define CPU_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) #define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) #define CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16) #define CPU_PRIMITIVE_COMPLEX64_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(std::complex, DataType::kComplex64) #define CPU_PRIMITIVE_COMPLEX128_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(std::complex, DataType::kComplex128) #define CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool) #define CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8) #define CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kUInt8) #define CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s32, DataType::kInt32) #define CPU_PRIMITIVE_ONEDNN_FLOAT_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat) #define CPU_PRIMITIVE_ONEDNN_FLOAT16_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f16, DataType::kFloat16) #define CPU_PRIMITIVE_ONEDNN_BFLOAT16_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::bf16, DataType::kBFloat16) #define CPU_PRIMITIVE_NATIVE_TYPE_SEQ \ CPU_PRIMITIVE_BOOL_TYPE_SEQ \ CPU_PRIMITIVE_CHAR_TYPE_SEQ \ CPU_PRIMITIVE_INT8_TYPE_SEQ \ CPU_PRIMITIVE_UINT8_TYPE_SEQ \ CPU_PRIMITIVE_INT32_TYPE_SEQ \ CPU_PRIMITIVE_UINT32_TYPE_SEQ \ CPU_PRIMITIVE_INT64_TYPE_SEQ \ CPU_PRIMITIVE_UINT64_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT_TYPE_SEQ \ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ #define CPU_PRIMITIVE_ALL_TYPE_SEQ \ CPU_PRIMITIVE_NATIVE_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \ CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ \ CPU_PRIMITIVE_COMPLEX_TYPE_SEQ #define CPU_PRIMITIVE_COMPLEX_TYPE_SEQ \ CPU_PRIMITIVE_COMPLEX64_TYPE_SEQ \ CPU_PRIMITIVE_COMPLEX128_TYPE_SEQ #define CPU_PRIMITIVE_FLOATING_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT_TYPE_SEQ \ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ #define CPU_PRIMITIVE_INT_TYPE_SEQ \ CPU_PRIMITIVE_INT8_TYPE_SEQ \ CPU_PRIMITIVE_UINT8_TYPE_SEQ \ CPU_PRIMITIVE_INT32_TYPE_SEQ \ CPU_PRIMITIVE_INT64_TYPE_SEQ #define UTIL_OPS_DATA_TYPE_SEQ \ CPU_PRIMITIVE_INT8_TYPE_SEQ \ CPU_PRIMITIVE_UINT8_TYPE_SEQ \ CPU_PRIMITIVE_INT32_TYPE_SEQ \ CPU_PRIMITIVE_INT64_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT_TYPE_SEQ \ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ #endif // ONEFLOW_CORE_EP_CPU_PRIMITIVE_TYPE_SEQ_H_ 
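The *_TYPE_SEQ macros in this header are what the CPU factories earlier in the dump expand to build their DataType-to-constructor tables. The sketch below shows that consumption pattern with the template arguments written out (the flattened factory code above dropped them); Fill and NewFill stand in for whichever primitive is being registered.

// Sketch of the registration pattern used by fill.cpp / tensor_fill.cpp above.
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},

static const std::map<DataType, std::function<std::unique_ptr<Fill>()>> new_fill_handle{
    OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)};

#undef MAKE_NEW_FILL_ENTRY

// Each (cpp_type, DataType) tuple in the sequence, e.g. (float, DataType::kFloat), expands
// the entry macro once, so the map holds one {DataType, factory-function} pair per supported
// element type; the factory's New(data_type) is then a simple map lookup that returns nullptr
// for unsupported types.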
================================================ FILE: oneflow/core/ep/cpu/primitive/unary_functor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/common/primitive/unary_functor.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/common/math_util.h" namespace oneflow { namespace ep { namespace primitive { template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(0.5) * src * (static_cast(1.0) + std::erf(inv_sqrt2 * src)); } Src inv_sqrt2 = std::sqrt(0.5); }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu const Src half = static_cast(0.5); const Src one = static_cast(1); const Src tanh_in = alpha * (src + beta * src * src * src); return half * src * (one + std::tanh(tanh_in)); } private: // constant ref to: // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/transform/fusion/fast_gelu.py static constexpr Src alpha = static_cast(0.7978845608028654); static constexpr Src beta = static_cast(0.044714998453855515); }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { const Src sigmoid = static_cast(static_cast(1.0) / (static_cast(1.0) + exp(-src * alpha))); return src * sigmoid; } private: static constexpr Src alpha = static_cast(1.702); }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast((src > static_cast(0.0)) ? 
src * src : 0); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return std::tanh(src); } }; template<> struct UnaryFunctor { UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(float src) const { return std::isinf(src); } }; template<> struct UnaryFunctor { UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(double src) const { return std::isinf(src); } }; template<> struct UnaryFunctor { UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(float src) const { return std::isnan(src); } }; template<> struct UnaryFunctor { UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(double src) const { return std::isnan(src); } }; template struct UnaryFunctor { UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(Src src) const { return std::isfinite(src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(std::trunc(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(static_cast(1.0) / static_cast(std::sqrt(src))); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src) const { // references // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L434-L487 const auto& calc_digamma = [](float x) { std::function compute; compute = [&](float x) { static float PSI_10 = 2.25175258906672110764f; if (x == 0) { // As per C++ standard for gamma related functions and SciPy, // If the argument is ±0, ±∞ is returned return std::copysign(INFINITY, -x); } bool x_is_integer = x == truncf(x); if (x < 0) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, // If the argument is a negative integer, NaN is returned return std::numeric_limits::quiet_NaN(); } // Extracts the fractional part of x as r, since tan(pi * r) is more numerically // accurate than tan(pi * x). While these operations are mathematically equivalent // since both x and r are in radians and tan() has a periodicity of pi, in practice // the computation of pi * x is a source of error (when |x| > 1). 
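// Reflection step: for negative non-integer x, apply the digamma reflection identity
// digamma(x) = digamma(1 - x) - pi / tan(pi * x); tan is evaluated on the fractional part r
// of x (extracted just below), which is equivalent because tan has period pi.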
double q, r; r = std::modf(x, &q); float pi_over_tan_pi_x = (float)(pi / tan(pi * r)); return compute(1 - x) - pi_over_tan_pi_x; } // Push x to be >= 10 float result = 0; while (x < 10) { result -= 1 / x; x += 1; } if (x == 10) { return result + PSI_10; } // Compute asymptotic digamma static const float A[] = { 8.33333333333333333333E-2f, -2.10927960927960927961E-2f, 7.57575757575757575758E-3f, -4.16666666666666666667E-3f, 3.96825396825396825397E-3f, -8.33333333333333333333E-3f, 8.33333333333333333333E-2f, }; float y = 0; if (x < 1.0e17f) { float z = 1 / (x * x); float polevl_result = 0; for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; } y = z * polevl_result; } return result + logf(x) - (0.5f / x) - y; }; return compute(x); }; return calc_digamma(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src) const { // references // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L376-L428 const auto& calc_digamma = [](double x) { std::function compute; compute = [&](double x) { static double PSI_10 = 2.25175258906672110764; if (x == 0) { // As per C++ standard for gamma related functions and SciPy, // If the argument is ±0, ±∞ is returned return std::copysign(INFINITY, -x); } bool x_is_integer = x == trunc(x); if (x < 0) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, // If the argument is a negative integer, NaN is returned return std::numeric_limits::quiet_NaN(); } // Extracts the fractional part of x as r, since tan(pi * r) is more numerically // accurate than tan(pi * x). While these operations are mathematically equivalent // since both x and r are in radians and tan() has a periodicity of pi, in practice // the computation of pi * x is a source of error (when |x| > 1). double q, r; r = std::modf(x, &q); return compute(1 - x) - pi / tan(pi * r); } // Push x to be >= 10 double result = 0; while (x < 10) { result -= 1 / x; x += 1; } if (x == 10) { return result + PSI_10; } // Compute asymptotic digamma static const double A[] = { 8.33333333333333333333E-2, -2.10927960927960927961E-2, 7.57575757575757575758E-3, -4.16666666666666666667E-3, 3.96825396825396825397E-3, -8.33333333333333333333E-3, 8.33333333333333333333E-2, }; double y = 0; if (x < 1.0e17) { double z = 1.0 / (x * x); // y = z * polevl(z, A, 6); double polevl_result = 0; for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; } y = z * polevl_result; } return result + log(x) - (0.5 / x) - y; }; return compute(x); }; return calc_digamma(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double x) const { // references // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L336-L352 double sign = +1; double result = 0; if (x < 0.5) { sign = -1; const double sin_pi_x = sin(pi * x); result -= (pi * pi) / (sin_pi_x * sin_pi_x); x = 1 - x; } for (int i = 0; i < 6; ++i) { result += 1 / (x * x); x += 1; } const double ixx = 1 / (x * x); result += (1 + 1 / (2 * x) + ixx * (1. / 6 - ixx * (1. / 30 - ixx * (1. 
/ 42)))) / x; return sign * result; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float x) const { // references // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L354-L370 float sign = +1; float result = 0; if (x < 0.5f) { sign = -1; const float sin_pi_x = sinf(pi * x); result -= (pi * pi) / (sin_pi_x * sin_pi_x); x = 1 - x; } for (int i = 0; i < 6; ++i) { result += 1 / (x * x); x += 1; } const float ixx = 1 / (x * x); result += (1 + 1 / (2 * x) + ixx * (1.f / 6 - ixx * (1.f / 30 - ixx * (1.f / 42)))) / x; return sign * result; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src) const { return std::abs(src); } }; #define SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(op) \ template<> \ struct UnaryFunctor { \ OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ \ UnaryFunctor float_functor; \ OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src) const { \ return bfloat16(float_functor(static_cast(src))); \ } \ }; SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcos); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcosh); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsin); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsinh); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtan); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtanh); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCeil); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCos); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCosh); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErf); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErfc); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp2); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExpm1); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFloor); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLgamma); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog2); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog1p); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLogSigmoid); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRint); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRound); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRsqrt); 
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSigmoid); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSin); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSinh); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSqrt); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquare); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTan); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquareReLU); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma); SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma); template<> struct UnaryFunctor { UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(bfloat16 src) const { return std::isinf(src); } }; template<> struct UnaryFunctor { UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(bfloat16 src) const { return std::isnan(src); } }; // avoid warning: narrowing conversion template<> struct UnaryFunctor, double> { UnaryFunctor(Scalar attr0, Scalar attr1) {} std::complex operator()(double src) const { return std::complex{static_cast(src), 0.0f}; } }; template<> struct UnaryFunctor, double> { UnaryFunctor(Scalar attr0, Scalar attr1) {} std::complex operator()(double src) const { return std::complex{0.0f, static_cast(src)}; } }; } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cpu/primitive/where.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/where.h" #include "oneflow/core/ep/common/primitive/where.h" #include "oneflow/core/ep/cpu/cpu_stream.h" namespace oneflow { namespace ep { namespace primitive { namespace { template void BroadcastElementwiseWhereKernel(CpuStream* cpu_stream, const BroadcastElementwiseWhereParams& params) { constexpr size_t _pack_size = (x_pack_size > y_pack_size) ? x_pack_size : y_pack_size; constexpr size_t pack_size = (cond_pack_size > _pack_size) ? 
cond_pack_size : _pack_size; static_assert(cond_pack_size == pack_size || cond_pack_size == 1, ""); static_assert(x_pack_size == pack_size || x_pack_size == 1, ""); static_assert(y_pack_size == pack_size || y_pack_size == 1, ""); const auto* cond_pack = reinterpret_cast*>(params.cond); const auto* x_pack = reinterpret_cast*>(params.x); const auto* y_pack = reinterpret_cast*>(params.y); auto* z_pack = reinterpret_cast*>(params.z); WhereFunctor where_fn{}; cpu_stream->ParallelFor(0, params.elem_cnt, [&](int64_t begin, int64_t end) { IndexT cond_index[ndim]; IndexT x_index[ndim]; IndexT y_index[ndim]; IndexT z_index[ndim]; for (IndexT offset = begin; offset < end; offset++) { params.z_index_helper.OffsetToNdIndex(offset, z_index); for (size_t i = 0; i < ndim; ++i) { cond_index[i] = params.cond_index_mask[i] * z_index[i]; x_index[i] = params.x_index_mask[i] * z_index[i]; y_index[i] = params.y_index_mask[i] * z_index[i]; } const IndexT cond_offset = params.cond_index_helper.NdIndexToOffset(cond_index); const IndexT x_offset = params.x_index_helper.NdIndexToOffset(x_index); const IndexT y_offset = params.y_index_helper.NdIndexToOffset(y_index); for (size_t j = 0; j < pack_size; ++j) { const CondT cond_val = (cond_pack_size == pack_size) ? cond_pack[cond_offset].elem[j] : cond_pack[cond_offset].elem[0]; const T x_val = (x_pack_size == pack_size) ? x_pack[x_offset].elem[j] : x_pack[x_offset].elem[0]; const T y_val = (y_pack_size == pack_size) ? y_pack[y_offset].elem[j] : y_pack[y_offset].elem[0]; z_pack[offset].elem[j] = where_fn(static_cast(cond_val), x_val, y_val); } } }); } template void ScalarWhereKernel(const CondT* cond, const T* x, const T* y, T* z) { WhereFunctor where_fn{}; *z = where_fn(*cond, *x, *y); } template void LaunchKernel(Stream* stream, const int64_t* cond_dims, const int64_t* x_dims, const int64_t* y_dims, const int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z) { static_assert(ndim > 0, ""); BroadcastElementwiseWhereParams params; params.cond_index_helper = NdIndexOffsetHelper(cond_dims); params.x_index_helper = NdIndexOffsetHelper(x_dims); params.y_index_helper = NdIndexOffsetHelper(y_dims); params.z_index_helper = NdIndexOffsetHelper(z_dims); for (size_t i = 0; i < ndim; ++i) { params.cond_index_mask[i] = (cond_dims[i] == 1) ? 0 : 1; params.x_index_mask[i] = (x_dims[i] == 1) ? 0 : 1; params.y_index_mask[i] = (y_dims[i] == 1) ? 
0 : 1; } params.elem_cnt = static_cast(GetElementCount(ndim, z_dims)); params.cond = cond; params.x = x; params.y = y; params.z = z; auto* cpu_stream = stream->As(); BroadcastElementwiseWhereKernel( cpu_stream, params); } template void LaunchScalarKernel(Stream* stream, const CondT* cond, const T* x, const T* y, T* z) { ScalarWhereKernel(cond, x, y, z); } template class WhereImpl : public Where { public: OF_DISALLOW_COPY_AND_MOVE(WhereImpl); explicit WhereImpl() = default; ~WhereImpl() override = default; void Launch(Stream* stream, size_t num_cond_dims, const int64_t* cond_dims, const void* cond, size_t num_x_dims, const int64_t* x_dims, const void* x, size_t num_y_dims, const int64_t* y_dims, const void* y, void* z) override { size_t compact_num_dims = 0; int64_t compact_cond_dims[kMaxNumDims] = {}; int64_t compact_x_dims[kMaxNumDims] = {}; int64_t compact_y_dims[kMaxNumDims] = {}; int64_t compact_z_dims[kMaxNumDims] = {}; GetCompactBroadcastDims(num_cond_dims, cond_dims, num_x_dims, x_dims, num_y_dims, y_dims, &compact_num_dims, compact_cond_dims, compact_x_dims, compact_y_dims, compact_z_dims); LaunchByDispatchNDim(stream, compact_num_dims, compact_cond_dims, compact_x_dims, compact_y_dims, compact_z_dims, static_cast(cond), static_cast(x), static_cast(y), static_cast(z)); } }; class WhereFactoryImpl : public WhereFactory { public: OF_DISALLOW_COPY_AND_MOVE(WhereFactoryImpl); WhereFactoryImpl() = default; ~WhereFactoryImpl() override = default; std::unique_ptr New(DataType cond_type, DataType data_type, size_t max_num_dims) override { return NewWhere(cond_type, data_type, max_num_dims); } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, WhereFactory, WhereFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/cuda_device.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/mem_util.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/core/ep/cuda/cuda_event.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #ifdef WITH_CUDA #include #include #if CUDA_VERSION >= 11000 #include #endif namespace oneflow { namespace ep { namespace { constexpr size_t kDefaultConstBufElementCount = 1024 * 1024; template void CreateConstBuffer(void** buf, T value, size_t n) { OF_CUDA_CHECK(cudaMalloc(buf, n * sizeof(T))); std::vector host(n, value); OF_CUDA_CHECK(cudaMemcpy(*buf, host.data(), n * sizeof(T), cudaMemcpyDefault)); } } // namespace CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager) : device_index_(device_index), event_flags_{}, properties_{}, device_manager_(device_manager), const_buf_elem_cnt_(0), const_zeros_buffer_(nullptr), const_ones_buffer_fp32_(nullptr), const_ones_buffer_fp16_(nullptr), const_ones_buffer_bf16_(nullptr) { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaGetDeviceProperties(&properties_, device_index_)); { const char* env_name = "ONEFLOW_EP_CUDA_DEVICE_FLAGS"; if (std::getenv(env_name) != nullptr) { const unsigned int flags = ParseIntegerFromEnv(env_name, 0); OF_CUDA_CHECK(cudaSetDeviceFlags(flags)); } } event_flags_ = cudaEventDisableTiming; if (ParseBooleanFromEnv("ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC", false)) { event_flags_ |= cudaEventBlockingSync; } const_buf_elem_cnt_ = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT", kDefaultConstBufElementCount); if (const_buf_elem_cnt_ > 0) { CreateConstBuffer(&const_zeros_buffer_, static_cast(0), const_buf_elem_cnt_); CreateConstBuffer(&const_ones_buffer_fp32_, static_cast(1.0), const_buf_elem_cnt_); CreateConstBuffer(&const_ones_buffer_fp16_, static_cast(1.0), const_buf_elem_cnt_); #if CUDA_VERSION >= 11000 CreateConstBuffer(&const_ones_buffer_bf16_, static_cast(1.0), const_buf_elem_cnt_); #endif // CUDA_VERSION >= 11000 } #if CUDA_VERSION >= 11020 if (ParseBooleanFromEnv("ONEFLOW_EP_CUDA_ENABLE_STREAM_ORDERED_MEMORY_ALLOCATOR", false)) { int memory_pools_supported = 0; cudaError_t err = cudaDeviceGetAttribute(&memory_pools_supported, cudaDevAttrMemoryPoolsSupported, device_index_); if (err == cudaSuccess && memory_pools_supported) { cudaMemPoolProps mem_pool_props = {}; mem_pool_props.allocType = cudaMemAllocationTypePinned; mem_pool_props.handleTypes = cudaMemHandleTypePosixFileDescriptor; mem_pool_props.location.type = cudaMemLocationTypeDevice; mem_pool_props.location.id = device_index_; OF_CUDA_CHECK(cudaMemPoolCreate(&mem_pool_, &mem_pool_props)); uint64_t threshold = UINT64_MAX; OF_CUDA_CHECK( cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolAttrReleaseThreshold, &threshold)); int disabled = 0; OF_CUDA_CHECK( cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolReuseFollowEventDependencies, &disabled)); OF_CUDA_CHECK( cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolReuseAllowOpportunistic, &disabled)); OF_CUDA_CHECK( cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolReuseAllowInternalDependencies, &disabled)); } if (err != cudaSuccess) { (void)cudaGetLastError(); } } #endif // CUDA_VERSION >= 11020 } CudaDevice::~CudaDevice() { CudaCurrentDeviceGuard guard(device_index_); for (auto* event : events_) { delete event; } OF_CUDA_CHECK(cudaFree(const_zeros_buffer_)); OF_CUDA_CHECK(cudaFree(const_ones_buffer_fp32_)); OF_CUDA_CHECK(cudaFree(const_ones_buffer_fp16_)); OF_CUDA_CHECK(cudaFree(const_ones_buffer_bf16_)); #if CUDA_VERSION >= 11020 if (mem_pool_) { OF_CUDA_CHECK(cudaMemPoolDestroy(mem_pool_)); } #endif 
// CUDA_VERSION >= 11020 } void CudaDevice::SetAsActiveDevice() { OF_CUDA_CHECK(cudaSetDevice(device_index_)); } void CudaDevice::Reset() { SetAsActiveDevice(); OF_CUDA_CHECK(cudaDeviceReset()); } Stream* CudaDevice::CreateStream() { CudaCurrentDeviceGuard guard(device_index_); return new CudaStream(this); } void CudaDevice::DestroyStream(Stream* stream) { CudaCurrentDeviceGuard guard(device_index_); delete stream; } void CudaDevice::CreateEvents(Event** events, size_t count) { size_t copied = 0; { std::lock_guard lock(events_mutex_); copied = std::min(count, events_.size()); size_t offset = events_.size() - copied; std::copy(events_.begin() + offset, events_.end(), events); events_.resize(offset); } if (copied != count) { CudaCurrentDeviceGuard guard(device_index_); for (size_t i = copied; i < count; ++i) { events[i] = new CudaEvent(event_flags_); } } } void CudaDevice::DestroyEvents(Event** events, size_t count) { std::lock_guard lock(events_mutex_); events_.insert(events_.end(), events, events + count); } Maybe CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) { CudaCurrentDeviceGuard guard(device_index_); CHECK(!options.HasPinnedDevice()); cudaError_t err = cudaMalloc(ptr, size); if (err != cudaSuccess) { if (err == cudaErrorMemoryAllocation) { // NOTE:return out of memory error, so vm will try to shrink memory and rerun return Error::OutOfMemoryError() << "CUDA " << cudaGetErrorString(err) << ". Tried to allocate " << FormatMemSize(size); } return Error::RuntimeError() << cudaGetErrorString(err); } else { return Maybe::Ok(); } } void CudaDevice::Free(const AllocationOptions& attr, void* ptr) { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFree(ptr)); } Maybe CudaDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) { CudaCurrentDeviceGuard guard(device_index_); cudaError_t err = NumaAwareCudaMallocHost(device_index_, ptr, size); if (err != cudaSuccess) { return Error::RuntimeError() << cudaGetErrorString(err); } else { return Maybe::Ok(); } } void CudaDevice::FreePinned(const AllocationOptions& options, void* ptr) { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFreeHost(ptr)); } bool CudaDevice::IsStreamOrderedMemoryAllocationSupported() const { #if CUDA_VERSION >= 11020 return mem_pool_ != nullptr; #else return false; #endif // CUDA_VERSION >= 11020 } #if CUDA_VERSION >= 11020 cudaMemPool_t CudaDevice::mem_pool() { return mem_pool_; } #endif // CUDA_VERSION >= 11020 const cudaDeviceProp& CudaDevice::properties() const { return properties_; } const void* CudaDevice::GetConstZeros(DataType data_type, size_t n) const { if (GetSizeOfDataType(data_type) * n <= GetSizeOfDataType(DataType::kFloat) * const_buf_elem_cnt_) { return const_zeros_buffer_; } else { return nullptr; } } const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const { if (n <= const_buf_elem_cnt_) { if (data_type == DataType::kFloat) { return const_ones_buffer_fp32_; } else if (data_type == DataType::kFloat16) { return const_ones_buffer_fp16_; } else if (data_type == DataType::kBFloat16) { return const_ones_buffer_bf16_; } else { return nullptr; } } else { return nullptr; } } } // namespace ep } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/ep/cuda/cuda_device.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_ #define ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_ #include "oneflow/core/ep/include/device.h" #include "oneflow/core/common/data_type.h" #ifdef WITH_CUDA #include namespace oneflow { namespace ep { class CudaDevice : public Device { public: OF_DISALLOW_COPY_AND_MOVE(CudaDevice); explicit CudaDevice(int device_index, DeviceManager* device_manager); ~CudaDevice() override; void SetAsActiveDevice() override; void Reset() override; DeviceType device_type() const override { return DeviceType::kCUDA; } size_t device_index() const override { return device_index_; } DeviceManager* device_manager() const override { return device_manager_; } Stream* CreateStream() override; void DestroyStream(Stream* stream) override; void CreateEvents(Event** events, size_t count) override; void DestroyEvents(Event** events, size_t count) override; Maybe Alloc(const AllocationOptions& options, void** ptr, size_t size) override; void Free(const AllocationOptions& options, void* ptr) override; Maybe AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override; void FreePinned(const AllocationOptions& options, void* ptr) override; bool IsStreamOrderedMemoryAllocationSupported() const override; #if CUDA_VERSION >= 11020 cudaMemPool_t mem_pool(); #endif // CUDA_VERSION >= 11020 const cudaDeviceProp& properties() const; const void* GetConstZeros(DataType data_type, size_t n) const; const void* GetConstOnes(DataType data_type, size_t n) const; private: int device_index_; std::mutex events_mutex_; std::vector events_; unsigned int event_flags_; cudaDeviceProp properties_; DeviceManager* device_manager_; int64_t const_buf_elem_cnt_; void* const_zeros_buffer_; void* const_ones_buffer_fp32_; void* const_ones_buffer_fp16_; void* const_ones_buffer_bf16_; #if CUDA_VERSION >= 11020 cudaMemPool_t mem_pool_{}; #endif // CUDA_VERSION >= 11020 }; } // namespace ep } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_ ================================================ FILE: oneflow/core/ep/cuda/cuda_device_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/cuda/cuda_device_manager.h" #include "oneflow/core/ep/cuda/cuda_random_generator.h" #include "oneflow/core/device/cuda_util.h" #ifdef WITH_CUDA namespace oneflow { namespace ep { CudaDeviceManager::CudaDeviceManager(DeviceManagerRegistry* registry) : registry_(registry) {} CudaDeviceManager::~CudaDeviceManager() = default; DeviceManagerRegistry* CudaDeviceManager::registry() const { return registry_; } std::shared_ptr CudaDeviceManager::GetDevice(size_t device_index) { std::lock_guard lock(devices_mutex_); if (device_index < devices_.size() && devices_.at(device_index)) { return devices_.at(device_index); } auto device = std::make_shared(device_index, this); if (device_index >= devices_.size()) { devices_.resize(device_index + 1); } devices_.at(device_index) = device; return device; } size_t CudaDeviceManager::GetDeviceCount(size_t primary_device_index) { CudaCurrentDeviceGuard guard(primary_device_index); return this->GetDeviceCount(); } size_t CudaDeviceManager::GetDeviceCount() { int count = 0; cudaError_t err = cudaGetDeviceCount(&count); if (err == cudaErrorNoDevice || err == cudaErrorInsufficientDriver) { return 0; } OF_CUDA_CHECK(err); return count; } size_t CudaDeviceManager::GetActiveDeviceIndex() { int device = 0; OF_CUDA_CHECK(cudaGetDevice(&device)); return static_cast(device); } void CudaDeviceManager::SetActiveDeviceByIndex(size_t device_index) { OF_CUDA_CHECK(cudaSetDevice(static_cast(device_index))); } std::shared_ptr CudaDeviceManager::CreateRandomGenerator(uint64_t seed, size_t device_index) { return std::make_shared(seed, device_index); } } // namespace ep } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/ep/cuda/cuda_device_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_MANAGER_H_ #define ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_MANAGER_H_ #include "oneflow/core/ep/include/device_manager.h" #ifdef WITH_CUDA namespace oneflow { namespace ep { class CudaDevice; class CudaDeviceManager : public DeviceManager { public: OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManager); CudaDeviceManager(DeviceManagerRegistry* registry); ~CudaDeviceManager() override; DeviceManagerRegistry* registry() const override; std::shared_ptr GetDevice(size_t device_index) override; size_t GetDeviceCount(size_t primary_device_index) override; size_t GetDeviceCount() override; size_t GetActiveDeviceIndex() override; void SetActiveDeviceByIndex(size_t device_index) override; bool IsStreamWaitEventSupported() const override { return true; } std::shared_ptr CreateRandomGenerator(uint64_t seed, size_t device_index) override; private: std::mutex devices_mutex_; std::vector> devices_; DeviceManagerRegistry* registry_; }; } // namespace ep } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_MANAGER_H_ ================================================ FILE: oneflow/core/ep/cuda/cuda_device_manager_factory.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/device_manager_factory.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/ep/cuda/cuda_device_manager.h" #ifdef WITH_CUDA #include #include #include namespace oneflow { namespace ep { namespace { std::string GetCudaVersionString(int version) { return std::to_string(version / 1000) + "." + std::to_string((version % 1000) / 10); } bool GetCudnnVersion(libraryPropertyType type, int* version) { cudnnStatus_t status = cudnnGetProperty(type, version); if (status == CUDNN_STATUS_SUCCESS) { return true; } else { LOG(ERROR) << "Failed to get cuDNN version: " << cudnnGetErrorString(status); return false; } } bool GetCudnnVersionString(std::string* version) { int version_major = 0; int version_minor = 0; int version_patch = 0; if (!GetCudnnVersion(libraryPropertyType::MAJOR_VERSION, &version_major)) { return false; } if (!GetCudnnVersion(libraryPropertyType::MINOR_VERSION, &version_minor)) { return false; } if (!GetCudnnVersion(libraryPropertyType::PATCH_LEVEL, &version_patch)) { return false; } *version = std::to_string(version_major) + "." + std::to_string(version_minor) + "." 
      + std::to_string(version_patch);
  return true;
}

void CudaDumpVersionInfo() {
  {
    int cuda_runtime_version = 0;
    cudaError_t err = cudaRuntimeGetVersion(&cuda_runtime_version);
    if (err == cudaSuccess) {
      LOG(INFO) << "CUDA runtime version: " << GetCudaVersionString(cuda_runtime_version);
    } else {
      LOG(ERROR) << "Failed to get cuda runtime version: " << cudaGetErrorString(err);
    }
  }
  {
    std::string cudnn_version_string;
    if (GetCudnnVersionString(&cudnn_version_string)) {
      LOG(INFO) << "cuDNN version: " << cudnn_version_string;
    }
  }
  {
    int nccl_version = 0;
    ncclResult_t result = ncclGetVersion(&nccl_version);
    if (result == ncclSuccess) {
      int nccl_version_major =
          (nccl_version >= 20900) ? (nccl_version / 10000) : (nccl_version / 1000);
      int nccl_version_minor =
          (nccl_version >= 20900) ? (nccl_version % 10000) / 100 : (nccl_version % 1000) / 100;
      int nccl_version_patch = (nccl_version % 100);
      LOG(INFO) << "NCCL version: " << nccl_version_major << "." << nccl_version_minor << "."
                << nccl_version_patch;
    } else {
      LOG(ERROR) << "Failed to get NCCL version: " << ncclGetErrorString(result);
    }
  }
}

class CudaDeviceManagerFactory : public DeviceManagerFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManagerFactory);
  CudaDeviceManagerFactory() = default;
  ~CudaDeviceManagerFactory() override = default;
  std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {
    return std::make_unique<CudaDeviceManager>(registry);
  }
  DeviceType device_type() const override { return DeviceType::kCUDA; }
  std::string device_type_name() const override { return "cuda"; }
  void DumpVersionInfo() const override { CudaDumpVersionInfo(); }
};

COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(
    std::make_unique<CudaDeviceManagerFactory>()))

}  // namespace

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_CUDA

================================================
FILE: oneflow/core/ep/cuda/cuda_event.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/cuda/cuda_event.h"

#ifdef WITH_CUDA

namespace oneflow {

namespace ep {

CudaEvent::CudaEvent(unsigned int flags) : cuda_event_{} {
  OF_CUDA_CHECK(cudaEventCreateWithFlags(&cuda_event_, flags));
}

CudaEvent::~CudaEvent() { OF_CUDA_CHECK(cudaEventDestroy(cuda_event_)); }

Maybe<bool> CudaEvent::QueryDone() {
  cudaError_t err = cudaEventQuery(cuda_event_);
  if (err == cudaSuccess) {
    return Maybe<bool>(true);
  } else if (err == cudaErrorNotReady) {
    return Maybe<bool>(false);
  } else {
    return Error::RuntimeError() << cudaGetErrorString(err);
  }
}

Maybe<void> CudaEvent::Sync() {
  cudaError_t err = cudaEventSynchronize(cuda_event_);
  if (err == cudaSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << cudaGetErrorString(err);
  }
}

cudaEvent_t CudaEvent::cuda_event() { return cuda_event_; }

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_CUDA

================================================
FILE: oneflow/core/ep/cuda/cuda_event.h
================================================
/*
Copyright 2020 The OneFlow Authors.
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CUDA_CUDA_EVENT_H_ #define ONEFLOW_CORE_EP_CUDA_CUDA_EVENT_H_ #include "oneflow/core/ep/include/event.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace ep { class CudaEvent : public Event { public: OF_DISALLOW_COPY_AND_MOVE(CudaEvent); explicit CudaEvent(unsigned int flags); ~CudaEvent() override; Maybe QueryDone() override; Maybe Sync() override; cudaEvent_t cuda_event(); private: cudaEvent_t cuda_event_; }; } // namespace ep } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_EP_CUDA_CUDA_EVENT_H_ ================================================ FILE: oneflow/core/ep/cuda/cuda_matmul_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/cuda_matmul_mode.h" namespace oneflow { namespace ep { namespace { bool* GetMatmulAllowTF32() { static bool matmul_allow_tf32 = true; return &matmul_allow_tf32; } bool* GetMatmulAllowFP16ReducedPrecisionReducton() { static bool matmul_allow_fp16_reduced_precision_reduction = true; return &matmul_allow_fp16_reduced_precision_reduction; } } // namespace bool CudaMatmulMode::is_matmul_allow_tf32() { return *GetMatmulAllowTF32(); } void CudaMatmulMode::set_matmul_allow_tf32(bool matmul_allow_tf32) { *GetMatmulAllowTF32() = matmul_allow_tf32; } bool CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction() { return *GetMatmulAllowFP16ReducedPrecisionReducton(); } void CudaMatmulMode::set_matmul_allow_fp16_reduced_precision_reduction( bool matmul_allow_fp16_reduced_precision_reduction) { *GetMatmulAllowFP16ReducedPrecisionReducton() = matmul_allow_fp16_reduced_precision_reduction; } } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/cuda_matmul_mode.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_ #define ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_ namespace oneflow { namespace ep { struct CudaMatmulMode { static bool is_matmul_allow_tf32(); static void set_matmul_allow_tf32(bool matmul_allow_tf32); static bool is_matmul_allow_fp16_reduced_precision_reduction(); static void set_matmul_allow_fp16_reduced_precision_reduction( bool matmul_allow_fp16_reduced_precision_reduction); }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_ ================================================ FILE: oneflow/core/ep/cuda/cuda_random_generator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/ep/cuda/cuda_random_generator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/device/cuda_util.h" #include #include namespace oneflow { namespace ep { namespace { int GetThreadNum(const cudaDeviceProp& prop) { switch (prop.major) { case 3: // Kepler return 2 * 192; case 5: // Maxwell return 2 * 128; case 6: // Pascal if ((prop.minor == 1) || (prop.minor == 2)) { return 2 * 128; } else { return 2 * 64; } case 7: // Volta and Turing return 2 * 64; default: return 2 * 64; } } } // namespace CUDAGenerator::CUDAGenerator(uint64_t seed, int device_index) : RandomGenerator(), seed_(seed), device_index_(device_index), philox_offset_per_thread_(0) { int device_count; OF_CUDA_CHECK(cudaGetDeviceCount(&device_count)); CHECK_LT_OR_THROW(device_index, device_count) << "only " << device_count << " cuda devices are visible."; cudaDeviceProp prop; // NOLINT OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_index)); max_block_num_ = prop.multiProcessorCount; max_thread_num_ = GetThreadNum(prop); } void CUDAGenerator::set_current_seed(uint64_t seed) { seed_ = seed; philox_offset_per_thread_ = 0; } std::tuple CUDAGenerator::CalcExecutionPolicy(int64_t total_elements, ep::CudaStream* stream) { // NOTE(Liang Depeng): the implementation is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/DistributionTemplates.h const uint64_t numel = static_cast(total_elements); const uint32_t block_size = 256; // block_size_bound // number of randoms given by distributions like curand_uniform4, curand_uniform2_double // used in calculating philox offset. 
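  // Illustrative note (added for clarity; not in the original source, and the GPU figures
  // below are hypothetical): each thread consumes `unroll` (= curand4_engine_calls = 4)
  // randoms per Philox draw, so the counter advance computed further down is effectively
  //   counter_offset = ceil(numel / (block_size * grid.x * unroll)) * curand4_engine_calls.
  // Worked example, assuming a GPU with 80 SMs and 2048 resident threads per SM:
  //   numel = 1048576, block_size = 256
  //   blocks_per_sm = 2048 / 256 = 8, so grid.x = min(ceil(numel / 256), 80 * 8) = 640
  //   counter_offset = ceil(1048576 / (256 * 640 * 4)) * 4 = 2 * 4 = 8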
const uint32_t curand4_engine_calls = 4; const uint32_t unroll = curand4_engine_calls; dim3 dim_block(block_size); dim3 grid((numel + block_size - 1) / block_size); uint32_t blocks_per_sm = stream->device_properties().maxThreadsPerMultiProcessor / block_size; grid.x = std::min( static_cast(stream->device_properties().multiProcessorCount) * blocks_per_sm, grid.x); // number of times random will be generated per thread, to offset philox counter in thc random // state uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1) * curand4_engine_calls; return std::make_tuple(counter_offset, grid, dim_block); } // NOTE(Liang Depeng): The implementation of ` CUDAGenerator::get_philox_offset` is modified // from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGenerator.cpp#L269 // in order to make distribution related cuda kernels to have the same output as pytorch // when setting the same seed. uint64_t CUDAGenerator::get_philox_offset(uint64_t increment) { std::lock_guard lock(mutex_); // rounds increment up to the nearest multiple of 4 increment = ((increment + 3) / 4) * 4; CHECK_EQ(this->philox_offset_per_thread_ % 4, 0); uint64_t offset = this->philox_offset_per_thread_; this->philox_offset_per_thread_ += increment; return offset; } // NOTE: The RNG state comprises the seed, and an offset used for Philox. // The following line is just here for aligning Pytorch and it is also no // practical effect in Pytorch just for backward compatibility reason. // For more details pls refer to: // https://github.com/pytorch/pytorch/blob/v1.13.1/aten/src/ATen/cuda/CUDAGenerator.cpp#L152 static constexpr size_t states_size = 200 * sizeof(4120); static constexpr size_t seed_size = sizeof(uint64_t); static constexpr size_t offset_size = sizeof(int64_t); static constexpr size_t total_size = states_size + seed_size + offset_size; size_t CUDAGenerator::GetStateSize() const { return total_size; } void CUDAGenerator::GetState(size_t state_size, void* state) const { CHECK_EQ_OR_THROW(state_size, GetStateSize()) << "the state size of cuda generator should be equal to " << GetStateSize(); memset(static_cast(state), -1, states_size); memcpy(static_cast(state) + states_size, &seed_, seed_size); memcpy(static_cast(state) + states_size + seed_size, &philox_offset_per_thread_, offset_size); } void CUDAGenerator::SetState(size_t state_size, const void* state) { CHECK_EQ_OR_THROW(state_size, GetStateSize()) << "the state size of cuda generator should be equal to " << GetStateSize(); const uint8_t* data = static_cast(state); seed_ = *((uint64_t*)(data + states_size)); philox_offset_per_thread_ = *((uint64_t*)(data + states_size + seed_size)); } template<> std::string GetRandomGeneratorDeviceTypeName() { return "cuda"; } } // namespace ep } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/ep/cuda/cuda_random_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CUDA_RANDOM_GENERATOR_H_ #define ONEFLOW_CORE_EP_CUDA_RANDOM_GENERATOR_H_ #ifdef WITH_CUDA #include #include #include #include "oneflow/core/common/device_type.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/random_generator.h" namespace oneflow { namespace ep { class CUDAGenerator : public RandomGenerator { public: explicit CUDAGenerator(uint64_t seed, int device_index); virtual ~CUDAGenerator() = default; int32_t max_block_num() const { return max_block_num_; } int32_t max_thread_num() const { return max_thread_num_; } uint64_t current_seed() const override { return seed_; } void set_current_seed(uint64_t seed) override; std::string device_type_name() const override { return "cuda"; } int64_t device_index() const override { return device_index_; } size_t GetStateSize() const override; void GetState(size_t state_size, void* state) const override; void SetState(size_t state_size, const void* state) override; std::tuple CalcExecutionPolicy(int64_t total_elements, CudaStream* stream); uint64_t get_philox_offset(uint64_t increment); public: mutable std::mutex mutex_; private: uint64_t seed_; int64_t device_index_; int32_t max_block_num_; int32_t max_thread_num_; uint64_t philox_offset_per_thread_ = 0; }; } // namespace ep } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_EP_CUDA_RANDOM_GENERATOR_H_ ================================================ FILE: oneflow/core/ep/cuda/cuda_stream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/hardware/node_device_descriptor_manager.h" #include "oneflow/core/hardware/cuda_device_descriptor.h" #include "oneflow/core/ep/cuda/cuda_event.h" #include "oneflow/core/ep/cuda/cuda_device.h" #ifdef WITH_CUDA namespace oneflow { namespace ep { namespace { constexpr size_t kDefaultWorkspaceSizeMb = 4; // 4M void SetAffinityByDevice(int dev_id) { auto node_device_desc_mgr = Singleton::Get(); if (node_device_desc_mgr == nullptr) { return; } auto node_device_desc = node_device_desc_mgr->GetLocalNodeDeviceDescriptor(); auto cuda_device = std::dynamic_pointer_cast( node_device_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev_id)); if (!cuda_device) { return; } node_device_desc->Topology()->SetCPUAffinityByPCIBusID(cuda_device->PCIBusID()); node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID()); } void CheckVersionCompatibility(int compiletime_major, int compiletime_minor, int runtime_major, int runtime_minor, const std::string& name) { if (runtime_major != compiletime_major || runtime_minor < compiletime_minor) { LOG(WARNING) << "Runtime version " << runtime_major << "." 
<< runtime_minor << " of " << name << " incompatible with compiletime version " << compiletime_major << "." << compiletime_minor << "."; } } void CheckCudaRuntimeVersion() { #if !defined(CUDART_VERSION) #error #endif // !defined(CUDART_VERSION) const int compiletime_major = CUDART_VERSION / 1000; const int compiletime_minor = CUDART_VERSION % 1000 / 10; int runtime_version = 0; OF_CUDA_CHECK(cudaRuntimeGetVersion(&runtime_version)); const int runtime_major = runtime_version / 1000; const int runtime_minor = runtime_version % 1000 / 10; CheckVersionCompatibility(compiletime_major, compiletime_minor, runtime_major, runtime_minor, "CUDA Runtime"); } void CheckCublasVersion(cublasHandle_t handle) { #if CUDA_VERSION >= 10020 #if (!defined(CUBLAS_VER_MAJOR)) || (!defined(CUBLAS_VER_MINOR)) #error #endif // (!defined(CUBLAS_VER_MAJOR)) || (!defined(CUBLAS_VER_MINOR)) int runtime_version = 0; OF_CUBLAS_CHECK(cublasGetVersion(handle, &runtime_version)); int runtime_major = 0; int runtime_minor = 0; if (runtime_version >= 100000) { runtime_major = runtime_version / 10000; runtime_minor = runtime_version % 10000 / 100; } else { runtime_major = runtime_version / 1000; runtime_minor = runtime_version % 1000 / 100; } CheckVersionCompatibility(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, runtime_major, runtime_minor, "cuBLAS"); #endif // CUDA_VERSION >= 10020 } void CheckCudnnVersion() { #if (!defined(CUDNN_MAJOR)) || (!defined(CUDNN_MINOR)) #error #endif // (!defined(CUDNN_MAJOR)) || (!defined(CUDNN_MINOR)) int runtime_major = 0; int runtime_minor = 0; OF_CUDNN_CHECK(cudnnGetProperty(libraryPropertyType::MAJOR_VERSION, &runtime_major)); OF_CUDNN_CHECK(cudnnGetProperty(libraryPropertyType::MINOR_VERSION, &runtime_minor)); CheckVersionCompatibility(CUDNN_MAJOR, CUDNN_MINOR, runtime_major, runtime_minor, "cuDNN"); } } // namespace #ifdef WITH_CUDA_GRAPHS CudaGraphExecutable::CudaGraphExecutable() : graph_exec_(nullptr), dev_(-1) {} CudaGraphExecutable::~CudaGraphExecutable() { Reset(); } void CudaGraphExecutable::Update(cudaGraph_t graph) { int dev = -1; OF_CUDA_CHECK(cudaGetDevice(&dev)); if (dev != dev_) { Reset(); } dev_ = dev; if (graph_exec_ != nullptr) { #if CUDA_VERSION < 12000 cudaGraphExecUpdateResult update_result{}; cudaGraphNode_t error_node = nullptr; OF_CUDA_CHECK(cudaGraphExecUpdate(graph_exec_, graph, &error_node, &update_result)); if (update_result == cudaGraphExecUpdateSuccess) { return; } #else cudaGraphExecUpdateResultInfo update_result{}; OF_CUDA_CHECK(cudaGraphExecUpdate(graph_exec_, graph, &update_result)); if (update_result.result == cudaGraphExecUpdateSuccess) { return; } #endif // CUDA_VERSION < 12000 } Reset(); OF_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph, NULL, NULL, 0)); } void CudaGraphExecutable::Launch(cudaStream_t stream) const { OF_CUDA_CHECK(cudaGraphLaunch(graph_exec_, stream)); } bool CudaGraphExecutable::IsInstantiated() const { return graph_exec_ != nullptr; } void CudaGraphExecutable::Reset() { if (graph_exec_ != nullptr) { CudaCurrentDeviceGuard guard(dev_); OF_CUDA_CHECK(cudaGraphExecDestroy(graph_exec_)); } } #endif // WITH_CUDA_GRAPHS CudaStream::CudaStream(CudaDevice* device) : device_index_(device->device_index()), device_(device) { CudaCurrentDeviceGuard guard(device_index_); const bool need_check_version = []() { static std::atomic version_checked(false); return version_checked.exchange(true) == false; }(); if (need_check_version) { CheckCudaRuntimeVersion(); } // cuda_stream const char* stream_flags_env_name = "ONEFLOW_EP_CUDA_STREAM_FLAGS"; if 
(std::getenv(stream_flags_env_name) != nullptr) { const unsigned int stream_flags = ParseIntegerFromEnv(stream_flags_env_name, 0); OF_CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream_, stream_flags)); } else { OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_)); } // cublas_handle OF_CUBLAS_CHECK(cublasCreate(&cublas_handle_)); OF_CUBLAS_CHECK(cublasSetStream(cublas_handle_, cuda_stream_)); if (need_check_version) { CheckCublasVersion(cublas_handle_); } #if CUDA_VERSION >= 10010 // cublas_lt_handle OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_)); #endif #if CUBLAS_VERSION >= 11000 if (ParseBooleanFromEnv("ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION", true)) { OF_CUBLAS_CHECK(cublasSetMathMode(cublas_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); } #endif // CUBLAS_VERSION >= 11000 // cusolver_dn_handle #if CUDA_VERSION >= 11000 OF_CUSOLVER_CHECK(cusolverDnCreate(&cusolver_dn_handle_)); OF_CUSOLVER_CHECK(cusolverDnSetStream(cusolver_dn_handle_, cuda_stream_)); #endif workspace_size_ = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB", kDefaultWorkspaceSizeMb) * 1024 * 1024; OF_CUDA_CHECK(cudaMalloc(&workspace_, workspace_size_)); #if CUBLAS_VERSION >= 11200 OF_CUBLAS_CHECK(cublasSetWorkspace(cublas_handle_, workspace_, workspace_size_)); #endif // CUBLAS_VERSION >= 11200 // cudnn_handle OF_CUDNN_CHECK(cudnnCreate(&cudnn_handle_)); OF_CUDNN_CHECK(cudnnSetStream(cudnn_handle_, cuda_stream_)); if (need_check_version) { CheckCudnnVersion(); } } CudaStream::~CudaStream() { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); OF_CUDNN_CHECK(cudnnDestroy(cudnn_handle_)); OF_CUBLAS_CHECK(cublasDestroy(cublas_handle_)); #if CUDA_VERSION >= 11000 OF_CUSOLVER_CHECK(cusolverDnDestroy(cusolver_dn_handle_)); #endif #if CUDA_VERSION >= 10010 OF_CUBLAS_CHECK(cublasLtDestroy(cublas_lt_handle_)); #endif OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_)); OF_CUDA_CHECK(cudaFree(workspace_)); } Maybe CudaStream::OnExecutionContextSetup() { OF_CUDA_CHECK(cudaSetDevice(device_index_)); SetAffinityByDevice(device_index_); return Maybe::Ok(); } Maybe CudaStream::OnExecutionContextTeardown() { return Maybe::Ok(); } DeviceType CudaStream::device_type() const { return DeviceType::kCUDA; } CudaDevice* CudaStream::device() const { return device_; } Maybe CudaStream::Sync() { cudaError_t err = cudaStreamSynchronize(cuda_stream_); if (err == cudaSuccess) { return Maybe::Ok(); } else { return Error::RuntimeError() << cudaGetErrorString(err) << " (" << err << ") "; } } void CudaStream::RecordEvent(Event* event) { auto* cuda_event = static_cast(event); // NOLINT OF_CUDA_CHECK(cudaEventRecord(cuda_event->cuda_event(), cuda_stream_)); } void CudaStream::WaitEvent(Event* event) { auto* cuda_event = static_cast(event); // NOLINT OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream_, cuda_event->cuda_event(), 0)); } Maybe CudaStream::GetAsyncError() { cudaError_t err = cudaGetLastError(); if (err == cudaSuccess) { return Maybe::Ok(); } else { return Error::RuntimeError() << cudaGetErrorString(err) << " (" << err << ") "; } } Maybe CudaStream::AllocAsync(void** ptr, size_t size) { #if CUDA_VERSION >= 11020 if (!device_->IsStreamOrderedMemoryAllocationSupported()) { UNIMPLEMENTED_THEN_RETURN(); } cudaError_t err = cudaMallocFromPoolAsync(ptr, size, device_->mem_pool(), cuda_stream_); if (err == cudaSuccess) { return Maybe::Ok(); } else { return Error::RuntimeError() << cudaGetErrorString(err) << " (" << err << ") "; } #else UNIMPLEMENTED_THEN_RETURN(); #endif // CUDA_VERSION >= 11020 } Maybe 
CudaStream::FreeAsync(void* ptr) { #if CUDA_VERSION >= 11020 if (!device_->IsStreamOrderedMemoryAllocationSupported()) { UNIMPLEMENTED_THEN_RETURN(); } cudaError_t err = cudaFreeAsync(ptr, cuda_stream_); if (err == cudaSuccess) { return Maybe::Ok(); } else { return Error::RuntimeError() << cudaGetErrorString(err) << " (" << err << ") "; } #else UNIMPLEMENTED_THEN_RETURN(); #endif // CUDA_VERSION >= 11020 } cudaStream_t CudaStream::cuda_stream() const { return cuda_stream_; } cublasHandle_t CudaStream::cublas_handle() const { return cublas_handle_; } #if CUDA_VERSION >= 11000 cusolverDnHandle_t CudaStream::cusolver_dn_handle() const { return cusolver_dn_handle_; } #endif #if CUDA_VERSION >= 10010 cublasLtHandle_t CudaStream::cublas_lt_handle() const { return cublas_lt_handle_; } #endif void* CudaStream::cublas_workspace() const { return workspace_; } size_t CudaStream::cublas_workspace_size() const { return workspace_size_; } cudnnHandle_t CudaStream::cudnn_handle() const { return cudnn_handle_; } const cudaDeviceProp& CudaStream::device_properties() const { return device_->properties(); } int CudaStream::cuda_arch() const { return device_->properties().major * 100 + device_->properties().minor * 10; } #ifdef WITH_CUDA_GRAPHS void CudaStream::BeginGraphCapture() { CHECK(!is_graph_capturing_); is_graph_capturing_ = true; OF_CUDA_CHECK(cudaStreamBeginCapture(cuda_stream_, cudaStreamCaptureModeThreadLocal)); } void CudaStream::EndGraphCapture(CudaGraphExecutable* executable) { cudaGraph_t graph = nullptr; OF_CUDA_CHECK(cudaStreamEndCapture(cuda_stream_, &graph)); executable->Update(graph); OF_CUDA_CHECK(cudaGraphDestroy(graph)); is_graph_capturing_ = false; } bool CudaStream::IsGraphCapturing() const { return is_graph_capturing_; } void CudaStream::LaunchGraph(const CudaGraphExecutable* executable) { executable->Launch(cuda_stream_); } #endif // WITH_CUDA_GRAPHS } // namespace ep } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/ep/cuda/cuda_stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_CUDA_CUDA_STREAM_H_ #define ONEFLOW_CORE_EP_CUDA_CUDA_STREAM_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ep/cuda/cuda_device.h" #ifdef WITH_CUDA #include #include #if CUDA_VERSION >= 11000 #define WITH_CUDA_GRAPHS #endif // CUDA_VERSION >= 11000 #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace ep { class CudaDevice; #ifdef WITH_CUDA_GRAPHS class CudaGraphExecutable { public: OF_DISALLOW_COPY_AND_MOVE(CudaGraphExecutable); CudaGraphExecutable(); ~CudaGraphExecutable(); void Update(cudaGraph_t graph); void Launch(cudaStream_t stream) const; bool IsInstantiated() const; private: void Reset(); cudaGraphExec_t graph_exec_; int dev_; }; #endif // WITH_CUDA_GRAPHS struct CudaLaunchConfig { dim3 grid_dim; dim3 block_dim; size_t shared_mem_size; CudaLaunchConfig() : grid_dim{}, block_dim{}, shared_mem_size(0) {} CudaLaunchConfig(unsigned int grid_size, unsigned int block_size, size_t shared_mem_size) : grid_dim(grid_size), block_dim(block_size), shared_mem_size(shared_mem_size) {} }; class CudaStream : public Stream { public: OF_DISALLOW_COPY_AND_MOVE(CudaStream); explicit CudaStream(CudaDevice* device); ~CudaStream() override; static constexpr uint32_t kDefaultBlockSize = 256; DeviceType device_type() const override; CudaDevice* device() const override; Maybe Sync() override; void RecordEvent(Event* event) override; void WaitEvent(Event* event) override; Maybe GetAsyncError() override; Maybe AllocAsync(void** ptr, size_t size) override; Maybe FreeAsync(void* ptr) override; Maybe OnExecutionContextSetup() override; Maybe OnExecutionContextTeardown() override; cudaStream_t cuda_stream() const; cublasHandle_t cublas_handle() const; #if CUDA_VERSION >= 11000 cusolverDnHandle_t cusolver_dn_handle() const; #endif #if CUDA_VERSION >= 10010 cublasLtHandle_t cublas_lt_handle() const; #endif cudnnHandle_t cudnn_handle() const; void* cublas_workspace() const; size_t cublas_workspace_size() const; const cudaDeviceProp& device_properties() const; int cuda_arch() const; void InitLaunchConfigWithWaves(CudaLaunchConfig* config, size_t elem_cnt, size_t block_size, size_t max_waves) const { const uint32_t max_grid_size = max_waves * device_properties().multiProcessorCount * (device_properties().maxThreadsPerMultiProcessor / block_size); const uint32_t grid_size = std::min(max_grid_size, (elem_cnt + block_size - 1) / block_size); config->grid_dim = dim3(grid_size); config->block_dim = dim3(block_size); config->shared_mem_size = 0; } #ifdef __CUDACC__ template void LaunchKernel(void (*kernel)(Params...), const CudaLaunchConfig& launch_config, Args... args) { kernel<<>>(args...); } template void LaunchKernel(void (*kernel)(Params...), size_t elem_cnt, size_t max_waves, Args... args) { constexpr uint32_t block_size = kDefaultBlockSize; CudaLaunchConfig config{}; InitLaunchConfigWithWaves(&config, elem_cnt, block_size, max_waves); LaunchKernel(kernel, config, args...); } template void LaunchKernelDefaultWaves(void (*kernel)(Params...), size_t elem_cnt, Args... 
args) { const size_t default_waves = 32; LaunchKernel(kernel, elem_cnt, default_waves, args...); } #endif // __CUDACC__ #ifdef WITH_CUDA_GRAPHS void BeginGraphCapture(); void EndGraphCapture(CudaGraphExecutable* executable); bool IsGraphCapturing() const; void LaunchGraph(const CudaGraphExecutable* executable); #endif // WITH_CUDA_GRAPHS private: cudaStream_t cuda_stream_{}; cublasHandle_t cublas_handle_{}; #if CUDA_VERSION >= 11000 cusolverDnHandle_t cusolver_dn_handle_{}; #endif #if CUDA_VERSION >= 10010 cublasLtHandle_t cublas_lt_handle_{}; #endif cudnnHandle_t cudnn_handle_{}; int device_index_; void* workspace_{}; size_t workspace_size_{}; #ifdef WITH_CUDA_GRAPHS bool is_graph_capturing_{}; #endif // WITH_CUDA_GRAPHS CudaDevice* device_; }; } // namespace ep } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_EP_CUDA_CUDA_STREAM_H_ ================================================ FILE: oneflow/core/ep/cuda/primitive/add.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/device/cuda_pseudo_bfloat16.h" namespace oneflow { namespace ep { namespace primitive { namespace { template struct AddFunctor; template struct AddFunctor { __device__ T operator()(T x) const { return x; } }; template struct AddFunctor { __device__ T operator()(T x0, U x1, Args... xs) const { return x0 + AddFunctor()(x1, xs...); } }; template struct AddFunctor { __device__ cuComplex operator()(cuComplex x0, U x1, Args... xs) const { cuComplex xn = AddFunctor()(x1, xs...); return cuComplex{x0.x + xn.x, x0.y + xn.y}; } }; template struct AddFunctor { __device__ cuDoubleComplex operator()(cuDoubleComplex x0, U x1, Args... xs) const { cuDoubleComplex xn = AddFunctor()(x1, xs...); return cuDoubleComplex{x0.x + xn.x, x0.y + xn.y}; } }; template __global__ void AddGpu(const Args*... srcs, T* dst, size_t count) { CUDA_1D_KERNEL_LOOP_T(size_t, i, count) { dst[i] = AddFunctor()(srcs[i]...); } } template void LaunchAddGpu(cudaStream_t stream, const Args*... 
srcs, T* dst, size_t count) { AddGpu <<>>(srcs..., dst, count); } template void DispatchLaunch(cudaStream_t stream, const T* const* srcs, size_t arity, T* dst, size_t count) { if (arity == 0) { OF_CUDA_CHECK(cudaMemsetAsync(dst, 0, count * sizeof(T), stream)); } else if (arity == 1) { OF_CUDA_CHECK(cudaMemcpyAsync(dst, srcs[0], count * sizeof(T), cudaMemcpyDefault, stream)); } else if (arity == 2) { OF_CUDA_CHECK((cuda::elementwise::Binary, T, T, T>( AddFunctor(), count, dst, srcs[0], srcs[1], stream))); } else if (arity == 3) { OF_CUDA_CHECK((cuda::elementwise::Ternary, T, T, T, T>( AddFunctor(), count, dst, srcs[0], srcs[1], srcs[2], stream))); } else if (arity == 4) { LaunchAddGpu(stream, srcs[0], srcs[1], srcs[2], srcs[3], dst, count); } else if (arity == 5) { LaunchAddGpu(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], dst, count); } else if (arity == 6) { LaunchAddGpu(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], srcs[5], dst, count); } else if (arity == 7) { LaunchAddGpu(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], srcs[5], srcs[6], dst, count); } else if (arity == 8) { LaunchAddGpu(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], srcs[5], srcs[6], srcs[7], dst, count); } else { DispatchLaunch(stream, srcs + 7, arity - 7, dst, count); LaunchAddGpu(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], srcs[5], srcs[6], dst, dst, count); } } template class AddImpl : public Add { public: OF_DISALLOW_COPY_AND_MOVE(AddImpl); AddImpl() = default; ~AddImpl() override = default; using Add::Launch; void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst, size_t count) override { cudaStream_t cuda_stream = stream->As()->cuda_stream(); DispatchLaunch(cuda_stream, reinterpret_cast(srcs), arity, reinterpret_cast(dst), count); } }; template std::unique_ptr NewAdd() { return std::unique_ptr(new AddImpl()); } class AddFactoryImpl : public AddFactory { public: OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl); AddFactoryImpl() = default; ~AddFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd}, static const std::map()>> new_add_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CUDA_PRIMITIVE_REAL_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)}; #undef MAKE_NEW_ADD_ENTRY const auto it = new_add_handle.find(data_type); if (it != new_add_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, AddFactory, AddFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/binary_functor.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/ep/common/primitive/binary_functor.h" #include "oneflow/core/ep/cuda/primitive/unary_functor.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src0, float src1) const { return fmod(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src0, double src1) const { return fmod(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src0, float src1) const { return floor(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src0, double src1) const { return floor(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src0, float src1) const { return truncf(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src0, double src1) const { return trunc(src0 / src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src0, float src1) const { float trunc_mod = fmod(src0, src1); return (trunc_mod != static_cast(0)) && ((src1 < static_cast(0)) != (trunc_mod < static_cast(0))) ? trunc_mod + src1 : trunc_mod; } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src0, double src1) const { double trunc_mod = fmod(src0, src1); return (trunc_mod != static_cast(0)) && ((src1 < static_cast(0)) != (trunc_mod < static_cast(0))) ? 
trunc_mod + src1 : trunc_mod; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) { #if defined(__CUDA_ARCH__) coef = sqrt(static_cast(2.0) / acos(static_cast(-1.0))); #else coef = std::sqrt(static_cast(2.0) / std::acos(static_cast(-1.0))); #endif } OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast(0.5) * (static_cast(1.0) + erf(static_cast(M_SQRT1_2) * x) + x * coef * exp(static_cast(-0.5) * x * x)) * dy; } Src coef; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu const Src one = static_cast(1); const Src half = static_cast(0.5); const Src pow3 = x * x * x; const Src tanh_out = std::tanh(alpha * (x + beta * pow3)); const Src dtanh = alpha * (half * x + beta * static_cast(1.5) * pow3); return dy * (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out)); } private: const Src alpha = static_cast(0.7978845608028654); const Src beta = static_cast(0.044714998453855515); }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { const Src one = static_cast(1.0); const Src sigmoid = one / (one + exp(-x * alpha)); return dy * (sigmoid + alpha * x * (sigmoid * (one - sigmoid))); } private: const Src alpha = static_cast(1.702); }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast((x > static_cast(0.0)) ? static_cast(2.0) * x * dy : static_cast(0.0)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return static_cast(dy * (static_cast(1.0) - y * y)); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC Dst operator()(int src0, int src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC Dst operator()(int8_t src0, int8_t src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC Dst operator()(uint8_t src0, uint8_t src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} BinaryFunctor float_functor; OF_DEVICE_FUNC Dst operator()(int src0, int src1) const { return static_cast(float_functor(static_cast(src0), static_cast(src1))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { const Src one = static_cast(1.0); return dy * one / (one - static_cast(pow(x, 2))); } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : atol(attr0.Value()), rtol(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { bool close = src0 == src1; close |= (isnan(src0) 
and isnan(src1)); if (atol == 0 and rtol == 0) return close; Src allowed_error = static_cast(atol) + abs(static_cast(rtol) * src1); Src actual_error = abs(src0 - src1); close |= (isfinite(actual_error) and (actual_error <= allowed_error)); return close; } float atol, rtol; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : atol(attr0.Value()), rtol(attr1.Value()) {} OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { bool close = src0 == src1; if (atol == 0 and rtol == 0) return close; Src allowed_error = static_cast(atol) + abs(static_cast(rtol) * src1); Src actual_error = abs(src0 - src1); close |= (isfinite(actual_error) and (actual_error <= allowed_error)); return close; } float atol, rtol; }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { ep::primitive::UnaryFunctor trigamma_functor( 0, 0); Src trigamma_result = trigamma_functor(x); return trigamma_result * dy; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { ep::primitive::UnaryFunctor digamma_functor(0, 0); Dst digamma_result = digamma_functor(x); return digamma_result * dy; } }; template struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src x, Src q) const { // ref // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L302-L384 const Src MACHEP{1.11022302462515654042E-16}; constexpr Src zero{0}; constexpr Src half{0.5}; constexpr Src one{1}; static const Src A[] = { 12.0, -720.0, 30240.0, -1209600.0, 47900160.0, -1.8924375803183791606e9, /*1.307674368e12/691*/ 7.47242496e10, -2.950130727918164224e12, /*1.067062284288e16/3617*/ 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ }; int i = 0; Src a, b, k, s, t, w; // Short-circuits x -> +infty if (x == one) { return INFINITY; } // Short-circuits x < 1 -> NaN if (x < one) { return NAN; } // Short-circuits negative q integers map to +infty, // negative q non-integers map to NaN if (q <= zero) { if (q == floor(q)) { return INFINITY; } if (x != floor(x)) { return NAN; } } s = pow(q, -x); a = q; i = 0; b = zero; while ((i < 9) || (a <= Src{9.0})) { i += 1; a += one; b = pow(a, -x); s += b; if ((-MACHEP * s < b) && (b < MACHEP * s)) { return s; } } w = a; s += b * w / (x - one); s -= half * b; a = one; k = zero; for (int i = 0; i < 12; i++) { a *= x + k; b /= w; t = a * b / A[i]; s = s + t; t = fabs(t / s); if (t < MACHEP) { return s; } k += one; a *= x + k; b /= w; k += one; } return s; } }; #define SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(op, type) \ template \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ OF_DEVICE_FUNC Dst operator()(type src0, type src1) const { \ return float_functor(static_cast(src0), static_cast(src1)); \ } \ BinaryFunctor float_functor; \ }; SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, bool); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, char); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int8_t); 
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, uint8_t); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int64_t); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, bool); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, char); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int8_t); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, uint8_t); SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int64_t); /*********nv_bfloat16_kernel*******/ #if CUDA_VERSION >= 11000 #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \ template<> \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ \ BinaryFunctor float_functor; \ OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \ return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \ } \ }; SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kPow); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFmod); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorDiv); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTruncDiv); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorMod); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kZeta); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kIdentityBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyY); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyY); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSquareReLUBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX); 
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAsinBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAsinhBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCosBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCoshBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kErfBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kErfcBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kExpBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kExp2BackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kExpm1BackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLog2BackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLog10BackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLogSigmoidBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kReciprocalNoNanBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kRsqrtBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSinBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSinhBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSqrtBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyY); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAtanhBackwardWithDyX); SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLgammaBackwardWithDyX); #define SPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(op) \ template \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ BinaryFunctor float_functor; \ OF_DEVICE_FUNC Dst operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \ return float_functor(__bfloat162float(src0), __bfloat162float(src1)); \ } \ }; SPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan) SPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsClose) #endif // CUDA_VERSION >= 11000 #define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op) \ template<> \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ \ BinaryFunctor float_functor; \ OF_DEVICE_FUNC half operator()(half src0, half src1) const { \ return __float2half(float_functor(__half2float(src0), __half2float(src1))); \ } \ }; SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kPow); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFmod); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorDiv); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTruncDiv); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorMod); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kZeta); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyY); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX); 
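// The SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR expansions follow the same pattern for
// half precision: operands are widened with __half2float, the float functor is applied,
// and the result is narrowed back with __float2half.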
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyY); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSquareReLUBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAsinBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAsinhBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCosBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCoshBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kErfBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kErfcBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kExpBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kExp2BackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kExpm1BackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLog2BackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLog10BackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLogSigmoidBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kReciprocalNoNanBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kRsqrtBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSinBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSinhBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSqrtBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyY); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAtanhBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLgammaBackwardWithDyX); #define SPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(op) \ template \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ BinaryFunctor float_functor; \ OF_DEVICE_FUNC Dst operator()(half src0, half src1) const { \ return float_functor(__half2float(src0), __half2float(src1)); \ } \ }; SPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan) SPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsClose) template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(cuComplex src0, cuComplex src1) const { return cuCmulf(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} 
OF_DEVICE_FUNC cuComplex operator()(cuComplex src0, cuComplex src1) const { return cuCdivf(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src0, cuDoubleComplex src1) const { return cuCmul(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src0, cuDoubleComplex src1) const { return cuCdiv(src0, src1); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : unary_functor(attr0, attr1) {} UnaryFunctor unary_functor; OF_DEVICE_FUNC cuComplex operator()(cuComplex dy, cuComplex x) const { // dy / (2 * sqrt(x).conj()) cuComplex y = unary_functor(x); return cuCdivf(dy, cuComplex{2.0f * y.x, -2.0f * y.y}); } }; template<> struct BinaryFunctor { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : unary_functor(attr0, attr1) {} UnaryFunctor unary_functor; OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex dy, cuDoubleComplex x) const { // dy / (2 * sqrt(x).conj()) cuDoubleComplex y = unary_functor(x); return cuCdiv(dy, cuDoubleComplex{2.0 * y.x, -2.0 * y.y}); } }; #define SPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(op, complex_type, real_type) \ template<> \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : real_functor(attr0, attr1) {} \ BinaryFunctor real_functor; \ OF_DEVICE_FUNC complex_type operator()(complex_type src0, complex_type src1) const { \ return complex_type{real_functor(src0.x, src1.x), real_functor(src0.y, src1.y)}; \ } \ }; SPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(BinaryOp::kAdd, cuComplex, float); SPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(BinaryOp::kSub, cuComplex, float); SPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(BinaryOp::kAdd, cuDoubleComplex, double); SPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(BinaryOp::kSub, cuDoubleComplex, double); #define SPECIALIZATION_COMPLEX_EQAUL_BINARY_FUNCTOR(complex_type, real_type) \ template \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : real_functor(attr0, attr1) {} \ BinaryFunctor real_functor; \ OF_DEVICE_FUNC Dst operator()(complex_type src0, complex_type src1) const { \ return static_cast(real_functor(src0.x, src1.x) && real_functor(src0.y, src1.y)); \ } \ }; SPECIALIZATION_COMPLEX_EQAUL_BINARY_FUNCTOR(cuComplex, float); SPECIALIZATION_COMPLEX_EQAUL_BINARY_FUNCTOR(cuDoubleComplex, double); #define SPECIALIZATION_COMPLEX_NOT_EQAUL_BINARY_FUNCTOR(complex_type, real_type) \ template \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : real_functor(attr0, attr1) {} \ BinaryFunctor real_functor; \ OF_DEVICE_FUNC Dst operator()(complex_type src0, complex_type src1) const { \ return static_cast(real_functor(src0.x, src1.x) || real_functor(src0.y, src1.y)); \ } \ }; SPECIALIZATION_COMPLEX_NOT_EQAUL_BINARY_FUNCTOR(cuComplex, float); SPECIALIZATION_COMPLEX_NOT_EQAUL_BINARY_FUNCTOR(cuDoubleComplex, double); #define SPECIALIZATION_GPU_BINARY_FUNCTOR(op, type) \ template<> \ struct BinaryFunctor { \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {} \ \ BinaryFunctor int_functor; \ OF_DEVICE_FUNC type operator()(type src0, type src1) const { \ return static_cast(int_functor(static_cast(src0), static_cast(src1))); \ } \ }; SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kPow, bool); 
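// SPECIALIZATION_GPU_BINARY_FUNCTOR covers bool and char operands for ops such as kPow,
// kFmod, and the div/mod variants by routing them through the integer functor (the
// int_functor member) and casting the result back to the original narrow type.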
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, bool); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, bool); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, bool); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, bool); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kPow, char); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFmod, char); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, char); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, char); SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, char); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/primitive/binary_functor.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { template std::unique_ptr NewBroadcastElementwiseBinary(Scalar attr0, Scalar attr1); namespace { class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl); BroadcastElementwiseBinaryFactoryImpl() = default; ~BroadcastElementwiseBinaryFactoryImpl() override = default; std::unique_ptr New(BinaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims) override { return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar()); } std::unique_ptr New(BinaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0) override { return New(op, src_type, dst_type, max_num_dims, attr0, Scalar()); } std::unique_ptr New(BinaryOp binary_op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0, Scalar attr1) override { if (max_num_dims > kMaxNumDims) { return nullptr; } #define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \ OF_PP_PAIR_SECOND(data_type_pair)), \ NewBroadcastElementwiseBinary}, #define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \ binary_op, src_data_type_pair, dst_data_type_pair) \ {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair), \ OF_PP_PAIR_SECOND(dst_data_type_pair)), \ 
NewBroadcastElementwiseBinary}, #define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \ {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \ OF_PP_PAIR_SECOND(data_type_pair)), \ NewBroadcastElementwiseBinary}, static const std::map< std::tuple, std::function(Scalar, Scalar)>> new_broadcast_elementwise_binary_handle{ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_COMPLEX_MATH_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_FLOATING_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, BINARY_COMPLEX_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY, BINARY_ACTIVATION_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ_COMPLEX, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_BITWISE_OP_SEQ, CUDA_PRIMITIVE_INT_TYPE_SEQ CUDA_PRIMITIVE_BOOL_TYPE_SEQ)}; #undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY #undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY const auto it = new_broadcast_elementwise_binary_handle.find( std::make_tuple(binary_op, src_type, dst_type)); if (it != new_broadcast_elementwise_binary_handle.end()) { return it->second(attr0, attr1); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseBinaryFactory, BroadcastElementwiseBinaryFactoryImpl); } // namespace } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h" #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/primitive/binary_functor.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { namespace { template struct GetPackType { using type = typename std::aligned_storage::type; }; template using PackType = typename GetPackType::type; template union Pack { static_assert(sizeof(PackType) == sizeof(T) * N, ""); OF_DEVICE_FUNC Pack() { // do nothing } PackType storage; T elem[N]; }; template struct BroadcastElementwiseBinaryParams { NdIndexOffsetHelper src0_index_helper; NdIndexOffsetHelper src1_index_helper; NdIndexOffsetHelper dst_index_helper; size_t num_dims; IndexType src0_index_mask[max_dims]; IndexType src1_index_mask[max_dims]; IndexType count{}; const void* src0{}; const void* src1{}; void* dst{}; Scalar attr0; Scalar attr1; }; template __global__ void BroadcastElementwiseBinaryGpu( BroadcastElementwiseBinaryParams params) { constexpr size_t dst_pack_size = src0_pack_size > src1_pack_size ? src0_pack_size : src1_pack_size; static_assert(src0_pack_size == dst_pack_size || src0_pack_size == 1, ""); static_assert(src1_pack_size == dst_pack_size || src1_pack_size == 1, ""); const PackType* src0 = reinterpret_cast*>(params.src0); const PackType* src1 = reinterpret_cast*>(params.src1); PackType* dst = reinterpret_cast*>(params.dst); IndexType src0_index[max_dims]; IndexType src1_index[max_dims]; IndexType dst_index[max_dims]; size_t num_dims = params.num_dims; CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) { params.dst_index_helper.OffsetToNdIndex(offset, dst_index, num_dims); #pragma unroll for (int i = 0; i < max_dims; ++i) { if (i < num_dims) { src0_index[i] = params.src0_index_mask[i] * dst_index[i]; src1_index[i] = params.src1_index_mask[i] * dst_index[i]; } else { src0_index[i] = 0; src1_index[i] = 0; } } const IndexType src0_offset = params.src0_index_helper.NdIndexToOffset(src0_index, num_dims); const IndexType src1_offset = params.src1_index_helper.NdIndexToOffset(src1_index, num_dims); Pack src0_pack; src0_pack.storage = src0[src0_offset]; Pack src1_pack; src1_pack.storage = src1[src1_offset]; Pack dst_pack; BinaryFunctor functor(params.attr0, params.attr1); #pragma unroll for (int j = 0; j < dst_pack_size; ++j) { const Src src0_val = (src0_pack_size == dst_pack_size) ? src0_pack.elem[j] : src0_pack.elem[0]; const Src src1_val = (src1_pack_size == dst_pack_size) ? src1_pack.elem[j] : src1_pack.elem[0]; dst_pack.elem[j] = functor(src0_val, src1_val); } dst[offset] = dst_pack.storage; } } template void LaunchKernel(Stream* stream, int num_dims, const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst, size_t count, Scalar attr0, Scalar attr1) { BroadcastElementwiseBinaryParams params; for (size_t i = 0; i < num_dims; ++i) { params.src0_index_mask[i] = (src0_dims[i] == 1) ? 0 : 1; params.src1_index_mask[i] = (src1_dims[i] == 1) ? 
0 : 1; } params.src0_index_helper = NdIndexOffsetHelper(src0_dims, num_dims); params.src1_index_helper = NdIndexOffsetHelper(src1_dims, num_dims); params.dst_index_helper = NdIndexOffsetHelper(dst_dims, num_dims); params.num_dims = num_dims; params.src0 = src0; params.src1 = src1; params.dst = dst; params.count = static_cast(count); params.attr0 = attr0; params.attr1 = attr1; auto* cuda_stream = stream->As(); BroadcastElementwiseBinaryGpu <<cuda_stream()>>>(params); } template void DispatchIndexType(Stream* stream, size_t num_dims, const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0, Scalar attr1) { size_t count = GetElementCount(num_dims, dst_dims); if (count < GetMaxVal()) { LaunchKernel( stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1); } else { LaunchKernel( stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1); } } template void DispatchPackSize(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims, const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0, Scalar attr1) { void (*func)(Stream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/, const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/, void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr; if (src0_pack_size == 1 && src1_pack_size == 1) { func = DispatchIndexType; } else if (src0_pack_size == 4 && src1_pack_size == 4) { func = DispatchIndexType; } else if (src0_pack_size == 1 && src1_pack_size == 4) { func = DispatchIndexType; } else if (src0_pack_size == 4 && src1_pack_size == 1) { func = DispatchIndexType; } else { UNIMPLEMENTED(); } func(stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, attr0, attr1); } template void DispatchNumDims(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims, const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0, Scalar attr1) { void (*func)(Stream* /*stream*/, size_t /*src0_pack_size*/, size_t /*src1_pack_size*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/, const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/, void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr; CHECK_NE(num_dims, 1); if (num_dims == 2) { func = DispatchPackSize; } else if (num_dims == 3) { func = DispatchPackSize; } else if (num_dims == 4) { func = DispatchPackSize; } else if (num_dims <= 8) { func = DispatchPackSize; } else { UNIMPLEMENTED(); } func(stream, src0_pack_size, src1_pack_size, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, attr0, attr1); } template size_t GetPackSize(size_t num_src_dims, const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, const void* src1, void* dst) { static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, ""); CHECK(src0_dims[num_src_dims - 1] != 1 || src1_dims[num_src_dims - 1] != 1); auto dst_ptr = reinterpret_cast(dst); for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) { bool is_src0_supported = (src0_dims[num_src_dims - 1] == 1) || IsPackSizeSupported(pack_size, num_src_dims, src0_dims, src0); bool is_src1_supported = (src1_dims[num_src_dims - 1] == 1) || IsPackSizeSupported(pack_size, num_src_dims, src1_dims, 
src1); if (is_src0_supported && is_src1_supported && (dst_ptr % (pack_size * sizeof(R))) == 0) { return pack_size; } } return 1; } constexpr size_t kMaxPackSize = 4; template void LaunchWithSimplified(Stream* stream, size_t simplified_num_dims, int64_t* simplified_src0_dims, const void* src0, int64_t* simplified_src1_dims, const void* src1, int64_t* simplified_dst_dims, void* dst, Scalar attr0, Scalar attr1) { CHECK_LE(simplified_num_dims, kMaxNumDims); size_t pack_size = GetPackSize(simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, dst); size_t src0_pack_size = 1; size_t src1_pack_size = 1; if (simplified_src0_dims[simplified_num_dims - 1] != 1) { simplified_src0_dims[simplified_num_dims - 1] /= pack_size; src0_pack_size = pack_size; } if (simplified_src1_dims[simplified_num_dims - 1] != 1) { simplified_src1_dims[simplified_num_dims - 1] /= pack_size; src1_pack_size = pack_size; } simplified_dst_dims[simplified_num_dims - 1] /= pack_size; DispatchNumDims(stream, src0_pack_size, src1_pack_size, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, simplified_dst_dims, dst, attr0, attr1); } template struct BinaryLhsScalarFunctor { __host__ __device__ BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1) : scalar(scalar), functor(attr0, attr1) {} __device__ Dst operator()(Src src) const { return functor(scalar, src); } const Src scalar; BinaryFunctor functor; }; template struct BinaryRhsScalarFunctor { __host__ __device__ BinaryRhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1) : scalar(scalar), functor(attr0, attr1) {} __device__ Dst operator()(Src src) const { return functor(src, scalar); } const Src scalar; BinaryFunctor functor; }; template struct BinaryLhsScalarPtrFunctorFactory { __host__ __device__ BinaryLhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0, Scalar attr1) : scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {} __device__ BinaryLhsScalarFunctor operator()() const { return BinaryLhsScalarFunctor(*scalar_ptr, attr0, attr1); } const Src* scalar_ptr; Scalar attr0, attr1; }; template struct BinaryRhsScalarPtrFunctorFactory { __host__ __device__ explicit BinaryRhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0, Scalar attr1) : scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {} __device__ BinaryRhsScalarFunctor operator()() const { return BinaryRhsScalarFunctor(*scalar_ptr, attr0, attr1); } const Src* scalar_ptr; Scalar attr0, attr1; }; template void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const Src* src0, size_t num_src1_dims, const int64_t* src1_dims, const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) { auto* cuda_stream = stream->As(); size_t simplified_num_dims = 0; int64_t simplified_src0_dims[kMaxNumDims]; int64_t simplified_src1_dims[kMaxNumDims]; int64_t simplified_dst_dims[kMaxNumDims]; SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, &simplified_num_dims, simplified_src0_dims, simplified_src1_dims, simplified_dst_dims); CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_dst_dims, dst); CheckInplace(simplified_num_dims, simplified_src1_dims, src1, simplified_dst_dims, dst); if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims, simplified_src1_dims)) { const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims); OF_CUDA_CHECK((cuda::elementwise::Binary( BinaryFunctor(attr0, attr1), elem_cnt, dst, src0, src1, cuda_stream->cuda_stream()))); } 
else { if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) { OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory( BinaryLhsScalarPtrFunctorFactory(src0, attr0, attr1), simplified_src1_dims[0], dst, src1, cuda_stream->cuda_stream()))); } else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) { OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory( BinaryRhsScalarPtrFunctorFactory(src1, attr0, attr1), simplified_src0_dims[0], dst, src0, cuda_stream->cuda_stream()))); } else { LaunchWithSimplified(stream, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, simplified_dst_dims, dst, attr0, attr1); } } } template T GetValue(Scalar value) { return value.Value(); } template<> half GetValue(Scalar value) { return static_cast(GetValue(value)); } template<> cuComplex GetValue(Scalar value) { const std::complex cpp_value = GetValue>(value); return cuFloatComplex{cpp_value.real(), cpp_value.imag()}; } template<> cuDoubleComplex GetValue(Scalar value) { const std::complex cpp_value = GetValue>(value); return cuDoubleComplex{cpp_value.real(), cpp_value.imag()}; } #if CUDA_VERSION >= 11000 template<> nv_bfloat16 GetValue(Scalar value) { return static_cast(GetValue(value)); } #endif // CUDA_VERSION >= 11000 template class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl); BroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {} ~BroadcastElementwiseBinaryImpl() override = default; void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims, const void* src1, void* dst) override { auto* cuda_stream = stream->As(); const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims); OF_CUDA_CHECK((cuda::elementwise::Unary( BinaryLhsScalarFunctor(GetValue(src0), attr0, attr1), elem_cnt, reinterpret_cast(dst), reinterpret_cast(src1), cuda_stream->cuda_stream()))); } void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, Scalar src1, void* dst) override { auto* cuda_stream = stream->As(); const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims); OF_CUDA_CHECK((cuda::elementwise::Unary( BinaryRhsScalarFunctor(GetValue(src1), attr0, attr1), elem_cnt, reinterpret_cast(dst), reinterpret_cast(src0), cuda_stream->cuda_stream()))); } void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, size_t num_src1_dims, const int64_t* src1_dims, const void* src1, void* dst) override { DispatchLaunch( stream, num_src0_dims, src0_dims, reinterpret_cast(src0), num_src1_dims, src1_dims, reinterpret_cast(src1), reinterpret_cast(dst), attr0, attr1); } private: Scalar attr0, attr1; }; } // namespace template std::unique_ptr NewBroadcastElementwiseBinary(Scalar attr0, Scalar attr1) { return std::unique_ptr( new BroadcastElementwiseBinaryImpl(attr0, attr1)); } #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_FLOATING_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: 
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_activation_grad_0.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, \ data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY, BINARY_ACTIVATION_BACKWARD_OP_SEQ_0, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_activation_grad_1.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, \ data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY, BINARY_ACTIVATION_BACKWARD_OP_SEQ_1, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_activation_grad_2.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, \ data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY, BINARY_ACTIVATION_BACKWARD_OP_SEQ_2, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_bitwise.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_BITWISE_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_BITWISE_ENTRY, BINARY_BITWISE_OP_SEQ, CUDA_PRIMITIVE_INT_TYPE_SEQ CUDA_PRIMITIVE_BOOL_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision_0.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY( \ binary_op, src_data_type_pair, dst_data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY, BINARY_COMPARISION_OP_SEQ_0, CUDA_PRIMITIVE_REAL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision_1.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY( \ binary_op, src_data_type_pair, dst_data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY, BINARY_COMPARISION_OP_SEQ_1, CUDA_PRIMITIVE_REAL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision_complex.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" #include "oneflow/core/ep/cuda/primitive/type_seq.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY( \ binary_op, src_data_type_pair, dst_data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY, BINARY_COMPLEX_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_logical.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \ dst_data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY, BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math_0.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_SEQ_0, CUDA_PRIMITIVE_REAL_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math_1.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_SEQ_1, CUDA_PRIMITIVE_REAL_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math_2.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"

namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_MATH_OP_SEQ_2, CUDA_PRIMITIVE_REAL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow


================================================
FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math_complex.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"

namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_COMPLEX_MATH_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow


================================================
FILE: oneflow/core/ep/cuda/primitive/broadcast_elementwise_unary.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/primitive/unary_functor.cuh" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_unary { namespace { #define CUDA_PRIMITIVE_CAST_REAL_TYPE_SEQ \ CUDA_PRIMITIVE_INT16_TYPE_SEQ \ CUDA_PRIMITIVE_UINT32_TYPE_SEQ \ CUDA_PRIMITIVE_REAL_TYPE_SEQ constexpr size_t kMaxPackSize = 4; template size_t GetPackSize(size_t num_dims, const int64_t* src_dims, const void* src, const int64_t* dst_dims, const void* dst) { static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, ""); for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) { bool is_src_supported = IsPackSizeSupported(pack_size, num_dims, src_dims, src); bool is_dst_supported = IsPackSizeSupported(pack_size, num_dims, dst_dims, dst); if (is_src_supported && is_dst_supported) { return pack_size; } } return 1; } template struct BroadcastElementwiseUnaryParams { OffsetToIndexWithStrideCalculator dst_offset_to_index_helper; size_t num_dims; int64_t src_strides[max_dims]; int64_t dst_strides[max_dims]; IndexType src_index_mask[max_dims]; IndexType count{}; const Src* src{}; Dst* dst{}; bool dst_is_contiguous; Scalar attr0; Scalar attr1; }; template struct UnaryScalarFunctor { __host__ __device__ explicit UnaryScalarFunctor(Src scalar) : scalar(scalar) {} __device__ Dst operator()() const { return UnaryFunctor()(scalar); } const Src scalar; }; template struct UnaryScalarPtrFunctorFactory { __host__ __device__ explicit UnaryScalarPtrFunctorFactory(const Src* scalar_ptr) : scalar_ptr(scalar_ptr) {} __device__ UnaryScalarFunctor operator()() const { return UnaryScalarFunctor(*scalar_ptr); } const Src* scalar_ptr; }; template __global__ void BroadcastElementwiseUnaryGpu( BroadcastElementwiseUnaryParams params) { using LoadPack = cuda::elementwise::Packed; using StorePack = cuda::elementwise::Packed; const LoadPack* src = reinterpret_cast(params.src); StorePack* dst = reinterpret_cast(params.dst); size_t num_dims = params.num_dims; const int64_t* src_strides = params.src_strides; const int64_t* dst_strides = params.dst_strides; auto functor = UnaryFunctor(params.attr0, params.attr1); CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) { IndexType src_offset = 0; IndexType dst_offset = 0; IndexType remaining = offset; #pragma unroll for (int i = 0; i < max_dims; ++i) { if (i < num_dims - 1) { IndexType dst_index = params.dst_offset_to_index_helper.divides(remaining, i); remaining = remaining - params.dst_offset_to_index_helper.mul(dst_index, i); dst_offset += dst_index * dst_strides[i]; src_offset += params.src_index_mask[i] * dst_index * src_strides[i]; } else if (i == num_dims - 1) { dst_offset += remaining * dst_strides[i]; src_offset += params.src_index_mask[i] * remaining * src_strides[i]; } else { break; } } LoadPack src_pack = src[src_offset]; StorePack dst_pack; #pragma unroll for (int j = 0; j < pack_size; ++j) { dst_pack.elem[j] = functor(src_pack.elem[j]); } dst[dst_offset] = dst_pack; } } template void LaunchKernel(CudaStream* stream, size_t num_dims, const int64_t* src_dims, const int64_t* src_strides, const Src* src, const int64_t* dst_dims, const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0, Scalar attr1, size_t count) { BroadcastElementwiseUnaryParams params; for (size_t i = 0; i < 
num_dims; ++i) { params.src_index_mask[i] = (src_dims[i] == 1) ? 0 : 1; params.src_strides[i] = src_strides[i]; params.dst_strides[i] = dst_strides[i]; } params.dst_offset_to_index_helper = OffsetToIndexWithStrideCalculator(dst_dims, num_dims); params.num_dims = num_dims; params.src = src; params.dst = dst; params.count = static_cast(count); params.attr0 = attr0; params.attr1 = attr1; params.dst_is_contiguous = continuous_output; BroadcastElementwiseUnaryGpu <<cuda_stream()>>>( params); } template void DispatchIndexType(CudaStream* stream, size_t num_dims, const int64_t* src_dims, const int64_t* src_strides, const Src* src, const int64_t* dst_dims, const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0, Scalar attr1) { size_t count = GetElementCount(num_dims, dst_dims); if (count < GetMaxVal() / 2) { LaunchKernel( stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst, continuous_output, attr0, attr1, count); } else { LaunchKernel( stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst, continuous_output, attr0, attr1, count); } } template void DispatchPackSize(CudaStream* stream, size_t pack_size, size_t num_dims, const int64_t* src_dims, const int64_t* src_strides, const Src* src, const int64_t* dst_dims, const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0, Scalar attr1) { void (*func)(CudaStream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src_dims*/, const int64_t* /*src_strides*/, const Src* /*src*/, const int64_t* /*dst_dims*/, const int64_t* /*dst_strides*/, Dst* /*dst*/, bool /*continuous_output*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr; if (pack_size == 1) { func = DispatchIndexType; } else if (pack_size == 4) { func = DispatchIndexType; } else { UNIMPLEMENTED(); } func(stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst, continuous_output, attr0, attr1); } template void DispatchNumDims(CudaStream* stream, size_t pack_size, size_t num_dims, const int64_t* src_dims, const int64_t* src_strides, const Src* src, const int64_t* dst_dims, const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0, Scalar attr1) { void (*func)(CudaStream* /*stream*/, size_t /*pack_size*/, size_t /*num_dims*/, const int64_t* /*src_dims*/, const int64_t* /*src_strides*/, const Src* /*src*/, const int64_t* /*dst_dims*/, const int64_t* /*dst_strides*/, Dst* /*dst*/, bool /*continuous_output*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr; if (num_dims == 1) { func = DispatchPackSize; } else if (num_dims == 2) { func = DispatchPackSize; } else if (num_dims == 3) { func = DispatchPackSize; } else if (num_dims == 4) { func = DispatchPackSize; } else if (num_dims <= kMaxNumDims) { func = DispatchPackSize; } else { UNIMPLEMENTED(); } func(stream, pack_size, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst, continuous_output, attr0, attr1); } template void LaunchWithSimplified(CudaStream* stream, size_t simplified_num_dims, int64_t* simplified_src_dims, int64_t* simplified_src_strides, const Src* src, int64_t* simplified_dst_dims, int64_t* simplified_dst_strides, Dst* dst, Scalar attr0, Scalar attr1) { CHECK_LE(simplified_num_dims, kMaxNumDims); bool src_enable_pack = (simplified_src_strides[simplified_num_dims - 1] == 1); bool dst_enable_pack = (simplified_dst_strides[simplified_num_dims - 1] == 1); size_t pack_size = 1; if (src_enable_pack && dst_enable_pack) { pack_size = GetPackSize(simplified_num_dims, simplified_src_dims, src, simplified_dst_dims, dst); } bool 
continuous_output = true; for (int i = simplified_num_dims - 1; i >= 0; i--) { if ((i == simplified_num_dims - 1 && simplified_dst_strides[i] != 1) || (i != simplified_num_dims - 1 && simplified_dst_strides[i] != simplified_dst_strides[i + 1] * simplified_dst_dims[i + 1])) { continuous_output = false; break; } } simplified_src_dims[simplified_num_dims - 1] /= pack_size; simplified_dst_dims[simplified_num_dims - 1] /= pack_size; for (int i = 0; i < simplified_num_dims - 1; i++) { simplified_src_strides[i] /= pack_size; simplified_dst_strides[i] /= pack_size; } DispatchNumDims(stream, pack_size, simplified_num_dims, simplified_src_dims, simplified_src_strides, src, simplified_dst_dims, simplified_dst_strides, dst, continuous_output, attr0, attr1); } template __global__ void LaunchFillKernel(UnaryFunctor functor, Dst* dst, const Src* src, size_t pack_count, size_t count, size_t tail_count, Dst* tail_dst) { using StorePack = cuda::elementwise::Packed; StorePack pack_value; Dst value = functor(*src); #pragma unroll for (size_t i = 0; i < pack; ++i) { pack_value.elem[i] = value; } StorePack* pack_dst = reinterpret_cast(dst); CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value; } if (tail) { CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; } } } template typename std::enable_if<(pack != 0), void>::type LaunchPackFill(CudaStream* stream, Dst* dst, const Src* src, size_t count, Scalar attr0, Scalar attr1) { const size_t pack_count = count / pack; const size_t tail_offset = pack_count * pack; const size_t tail_count = count - tail_offset; auto functor = UnaryFunctor(attr0, attr1); if (tail_count > 0) { LaunchFillKernel <<cuda_stream()>>>( functor, dst, src, pack_count, count, tail_count, dst + tail_offset); } else { LaunchFillKernel <<cuda_stream()>>>( functor, dst, src, pack_count, count, tail_count, dst + tail_offset); } } template typename std::enable_if<(pack == 0), void>::type LaunchPackFill(CudaStream* stream, Dst* dst, const Src* src, size_t count, Scalar attr0, Scalar attr1) { LOG(FATAL) << "wrong alignment"; } template void LaunchFill(CudaStream* stream, Dst* dst, const Src* src, size_t count, Scalar attr0, Scalar attr1) { auto uintptr = reinterpret_cast(dst); if (uintptr % 16 == 0 && count * sizeof(Dst) >= 16) { LaunchPackFill(stream, dst, src, count, attr0, attr1); } else if (uintptr % 8 == 0 && count * sizeof(Dst) >= 8) { LaunchPackFill(stream, dst, src, count, attr0, attr1); } else if (uintptr % 4 == 0 && count * sizeof(Dst) >= 4) { LaunchPackFill(stream, dst, src, count, attr0, attr1); } else if (uintptr % 2 == 0 && count * sizeof(Dst) >= 2) { LaunchPackFill(stream, dst, src, count, attr0, attr1); } else { LaunchPackFill(stream, dst, src, count, attr0, attr1); } } template class BroadcastElementwiseUnaryImpl : public BroadcastElementwiseUnary { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryImpl); BroadcastElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {} ~BroadcastElementwiseUnaryImpl() override = default; void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src, size_t num_dst_dims, const int64_t* dst_dims, void* dst) override { CHECK_GT(num_src_dims, 0) << "num_src_dims must greater than 0"; CHECK_GT(num_dst_dims, 0) << "num_dst_dims must greater than 0"; int64_t src_strides[kMaxNumDims]; int64_t dst_strides[kMaxNumDims]; // init stride for (int i = num_src_dims - 1; i < kMaxNumDims; ++i) { src_strides[i] = 1; } for (int i = num_src_dims - 2; i >= 0; --i) { 
src_strides[i] = src_dims[i + 1] * src_strides[i + 1]; } for (int i = num_dst_dims - 1; i < kMaxNumDims; ++i) { dst_strides[i] = 1; } for (int i = num_dst_dims - 2; i >= 0; --i) { dst_strides[i] = dst_dims[i + 1] * dst_strides[i + 1]; } Launch(stream, num_src_dims, src_dims, src_strides, src, num_dst_dims, dst_dims, dst_strides, dst); } void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const int64_t* src_strides, const void* src_ptr, size_t num_dst_dims, const int64_t* dst_dims, const int64_t* dst_strides, void* dst_ptr) override { CHECK_GT(num_src_dims, 0) << "num_src_dims must greater than 0"; CHECK_GT(num_dst_dims, 0) << "num_dst_dims must greater than 0"; auto* cuda_stream = stream->As(); Dst* dst = reinterpret_cast(dst_ptr); const Src* src = reinterpret_cast(src_ptr); size_t simplified_num_dims = 0; int permutation_list[kMaxNumDims]; int64_t permutation_src_dims[kMaxNumDims]; int64_t simplified_src_dims[kMaxNumDims]; int64_t simplified_dst_dims[kMaxNumDims]; int64_t simplified_src_strides[kMaxNumDims]; int64_t simplified_dst_strides[kMaxNumDims]; SimplifyBroadcastDims(num_src_dims, src_dims, src_strides, num_dst_dims, dst_dims, dst_strides, &simplified_num_dims, simplified_src_dims, simplified_src_strides, simplified_dst_dims, simplified_dst_strides); bool permutable = InferPermutable( simplified_num_dims, simplified_src_strides, simplified_dst_strides, simplified_src_dims, simplified_dst_dims, permutation_list, permutation_src_dims, unary_op); std::unique_ptr permute = NewPrimitive(DeviceType::kCUDA, simplified_num_dims); CheckInplace(simplified_num_dims, simplified_src_dims, src, simplified_dst_dims, dst); CheckInplace(simplified_num_dims, simplified_src_strides, src, simplified_dst_strides, dst); if (simplified_num_dims == 1 && simplified_src_dims[0] == 1) { const int64_t elem_cnt = simplified_dst_dims[0]; LaunchFill(cuda_stream, dst, src, elem_cnt, attr0, attr1); } else if (simplified_num_dims == 1 && simplified_src_strides[0] == 1 && simplified_dst_strides[0] == 1) { const int64_t elem_cnt = simplified_src_dims[0]; auto functor = UnaryFunctor(attr0, attr1); OF_CUDA_CHECK((cuda::elementwise::Unary( functor, elem_cnt, dst, src, cuda_stream->cuda_stream()))); } else if (permutable && src_type == dst_type && permute) { permute->Launch(stream, dst_type, simplified_num_dims, permutation_src_dims, src_ptr, permutation_list, dst_ptr); } else { // fall back to normal cases LaunchWithSimplified( cuda_stream, simplified_num_dims, simplified_src_dims, simplified_src_strides, src, simplified_dst_dims, simplified_dst_strides, dst, attr0, attr1); } } protected: Scalar attr0, attr1; }; template std::unique_ptr NewBroadcastElementwiseUnary(Scalar attr0, Scalar attr1) { return std::unique_ptr( new BroadcastElementwiseUnaryImpl(attr0, attr1)); } class BroadcastElementwiseUnaryFactoryImpl : public BroadcastElementwiseUnaryFactory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactoryImpl); BroadcastElementwiseUnaryFactoryImpl() = default; ~BroadcastElementwiseUnaryFactoryImpl() override = default; std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims) override { return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar()); } std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0) override { return New(op, src_type, dst_type, max_num_dims, attr0, Scalar()); } std::unique_ptr New(UnaryOp unary_op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0, 
Scalar attr1) override { if (max_num_dims > kMaxNumDims) { return nullptr; } #define MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \ {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \ NewBroadcastElementwiseUnary}, #define MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, src_dtype_pair, dst_dtype_pair) \ {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_dtype_pair), \ OF_PP_PAIR_SECOND(dst_dtype_pair)), \ NewBroadcastElementwiseUnary< \ unary_op, OF_PP_PAIR_FIRST(src_dtype_pair), OF_PP_PAIR_SECOND(src_dtype_pair), \ OF_PP_PAIR_FIRST(dst_dtype_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)>}, static const std::map, std::function(Scalar, Scalar)>> new_broadcast_elementwise_unary_handle{ // For All Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY, UNARY_IDENTITY_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY, UNARY_IDENTITY_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) // For Cast OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ, CUDA_PRIMITIVE_CAST_REAL_TYPE_SEQ, CUDA_PRIMITIVE_CAST_REAL_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) }; #undef MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY #undef MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY const auto iter = new_broadcast_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_type)); if (iter != new_broadcast_elementwise_unary_handle.end()) { return iter->second(attr0, attr1); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseUnaryFactory, BroadcastElementwiseUnaryFactoryImpl); } // namespace } // namespace broadcast_elementwise_unary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/broadcast_matmul.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
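GetPackSize above (together with IsPackSizeSupported) picks the widest vectorized access that both the source and destination tolerate: a pack is only usable when the innermost dimension is divisible by it and the pointer is aligned for the packed load/store. A minimal standalone sketch of that check for a single buffer, with illustrative names rather than OneFlow's actual helpers:

#include <cstdint>

// Widest pack (in elements) such that the innermost dimension is divisible by
// the pack and the pointer is aligned for accesses of pack * sizeof(T) bytes.
template<typename T>
size_t LargestSupportedPack(int64_t innermost_dim, const void* ptr, size_t max_pack = 4) {
  for (size_t pack = max_pack; pack > 1; pack /= 2) {
    const size_t pack_bytes = pack * sizeof(T);
    const bool dim_ok = (innermost_dim % static_cast<int64_t>(pack) == 0);
    const bool align_ok = (reinterpret_cast<std::uintptr_t>(ptr) % pack_bytes == 0);
    if (dim_ok && align_ok) { return pack; }
  }
  return 1;  // fall back to scalar loads/stores
}

Once a pack is chosen, the launcher divides the innermost dimensions and the outer strides by it, which is why LaunchWithSimplified rescales simplified_src_dims and simplified_dst_dims before dispatching.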
*/ #ifdef WITH_CUDA #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/broadcast_matmul.h" #include "oneflow/core/ep/common/primitive/broadcast_matmul.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/cuda/cuda_matmul_mode.h" #include #include namespace oneflow { namespace ep { namespace primitive { namespace broadcast_matmul { namespace internal { namespace { constexpr size_t kMaxNumDims = 8; Optional OptCudaDataType(DataType data_type) { switch (data_type) { case kFloat: return CUDA_R_32F; case kDouble: return CUDA_R_64F; case kFloat16: return CUDA_R_16F; case kComplex64: return CUDA_C_32F; case kComplex128: return CUDA_C_64F; #if CUDA_VERSION >= 11000 case kBFloat16: return CUDA_R_16BF; #endif // CUDA_VERSION >= 11000 default: return NullOpt; } } cudaDataType_t GetCudaDataType(DataType data_type) { auto cuda_data_type = OptCudaDataType(data_type); CHECK(cuda_data_type.has_value()); return cuda_data_type.value_or(CUDA_R_32F); } union CublasScalarParameter { double d; float s; half h; cuComplex c; cuDoubleComplex z; }; CublasScalarParameter GetCublasScalarParameter(Scalar scalar, cublasComputeType_t compute_type) { CublasScalarParameter sp{}; if (compute_type == CUBLAS_COMPUTE_64F) { sp.d = scalar.Value(); } else if (compute_type == CUBLAS_COMPUTE_32F_PEDANTIC || compute_type == CUBLAS_COMPUTE_32F_FAST_TF32 || compute_type == CUBLAS_COMPUTE_32F) { sp.s = scalar.Value(); } else if (compute_type == CUBLAS_COMPUTE_16F) { sp.h = static_cast(scalar.Value()); } else { UNIMPLEMENTED(); } return sp; } cudaDataType_t GetCublasScalarType(DataType data_type) { switch (data_type) { case kFloat: return CUDA_R_32F; case kDouble: return CUDA_R_64F; case kComplex64: return CUDA_C_32F; case kComplex128: return CUDA_C_64F; default: return CUDA_R_32F; } } cublasComputeType_t GetComputeType(DataType data_type, CudaStream* cuda_stream) { switch (data_type) { case kFloat: { if (CudaMatmulMode::is_matmul_allow_tf32()) { return CUBLAS_COMPUTE_32F_FAST_TF32; } else { // Starting with cuBLAS version 11.0.0, the library will automatically make use of Tensor // Core capabilities wherever possible, unless they are explicitly disabled by selecting // pedantic compute modes in cuBLAS return CUBLAS_COMPUTE_32F_PEDANTIC; } } case kDouble: return CUBLAS_COMPUTE_64F; case kFloat16: { if (cuda_stream->device_properties().major >= 5) { return CUBLAS_COMPUTE_32F; } else { return CUBLAS_COMPUTE_16F; } } case kComplex64: { if (CudaMatmulMode::is_matmul_allow_tf32()) { return CUBLAS_COMPUTE_32F_FAST_TF32; } else { return CUBLAS_COMPUTE_32F_PEDANTIC; } } case kComplex128: return CUBLAS_COMPUTE_64F; #if CUDA_VERSION >= 11000 case kBFloat16: return CUBLAS_COMPUTE_32F; #endif // CUDA_VERSION >= 11000 default: UNIMPLEMENTED(); return CUBLAS_COMPUTE_32F; } } void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b, int64_t num_batch_dims, const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims, const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m, int64_t n, int64_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c) { auto* cuda_stream = stream->As(); const auto cuda_data_type = GetCudaDataType(data_type); const auto compute_type = GetComputeType(data_type, cuda_stream); const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type); const auto GetCublasOperation = [](BlasTransposeType transpose_type, 
DataType data_type) { if (transpose_type == BlasTransposeType::N) { return CUBLAS_OP_N; } else if (transpose_type == BlasTransposeType::T) { return DType(data_type).is_complex() ? CUBLAS_OP_C : CUBLAS_OP_T; } else { UNIMPLEMENTED(); return CUBLAS_OP_N; } }; const cublasOperation_t cublas_trans_a = GetCublasOperation(transpose_b, data_type); const cublasOperation_t cublas_trans_b = GetCublasOperation(transpose_a, data_type); const int cublas_m = n; const int cublas_n = m; const int cublas_k = k; int cublas_lda = 0; if (transpose_b == BlasTransposeType::N) { cublas_lda = n; } else if (transpose_b == BlasTransposeType::T) { cublas_lda = k; } else { UNIMPLEMENTED(); } int cublas_ldb = 0; if (transpose_a == BlasTransposeType::N) { cublas_ldb = k; } else if (transpose_a == BlasTransposeType::T) { cublas_ldb = m; } else { UNIMPLEMENTED(); } const int cublas_ldc = n; CublasMathModeGuard guard(cuda_stream->cublas_handle()); if (data_type == DataType::kFloat16) { #if CUDA_VERSION < 11000 guard.SetMathMode(CUBLAS_TENSOR_OP_MATH); #else cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH; if (cuda_stream->device_properties().major >= 5 && CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction()) { cublas_flags = static_cast(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION); } guard.SetMathMode(cublas_flags); #endif // CUDA_VERSION < 11000 } #if CUDA_VERSION >= 11000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; #else cublasGemmAlgo_t algo = (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DEFAULT; #endif if (num_batch_dims == 1 && c_batch_dims[0] != 1) { const void* cublas_a = b; const void* cublas_b = a; void* cublas_c = c; const int64_t a_batch_count = a_batch_dims[0]; const int64_t b_batch_count = b_batch_dims[0]; CHECK(a_batch_count == 1 || b_batch_count == 1 || a_batch_count == b_batch_count); CHECK_GT(a_batch_count, 0); CHECK_GT(b_batch_count, 0); const int batch_count = std::max(a_batch_count, b_batch_count); const long long int cublas_stride_a = b_batch_count == 1 ? 0 : cublas_m * cublas_k; const long long int cublas_stride_b = a_batch_count == 1 ? 
0 : cublas_k * cublas_n; const long long int cublas_stride_c = cublas_m * cublas_n; const auto sp_beta = GetCublasScalarParameter(beta, compute_type); OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx( cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type, cublas_ldb, cublas_stride_b, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, cublas_stride_c, batch_count, compute_type, algo)); } else { auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) { const auto sp_beta = GetCublasScalarParameter(batch_beta, compute_type); const void* cublas_a = batch_b; const void* cublas_b = batch_a; void* cublas_c = batch_c; OF_CUBLAS_CHECK(cublasGemmEx( cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type, cublas_ldb, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo)); }; ForEachMatmul(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func); } } class BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl); BroadcastMatmulFactoryImpl() = default; ~BroadcastMatmulFactoryImpl() override = default; std::unique_ptr New(DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b, size_t max_num_dims) override { auto cuda_data_type = OptCudaDataType(data_type); if (max_num_dims <= kMaxNumDims && cuda_data_type.has_value()) { return std::make_unique>(data_type, transpose_a, transpose_b); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl); } // namespace } // namespace internal } // namespace broadcast_matmul } // namespace primitive } // namespace ep } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/ep/cuda/primitive/cast.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
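LaunchBroadcastMatmul above hands row-major A, B and C to cuBLAS, which only speaks column-major, by computing C^T = B^T * A^T: the operand pointers are swapped (cublas_a = b, cublas_b = a) and m and n trade places, so no data has to be rearranged. A minimal sketch of that mapping for a plain float GEMM with no transposes or batching, using only the standard cuBLAS API (RowMajorSgemm is an illustrative wrapper, not part of the file):

#include <cublas_v2.h>

// Row-major C[m x n] = alpha * A[m x k] * B[k x n] + beta * C, expressed through
// column-major cuBLAS: pass B as the first operand and A as the second, swap m
// and n, and use the row widths (n, k, n) as leading dimensions.
cublasStatus_t RowMajorSgemm(cublasHandle_t handle, int m, int n, int k, float alpha,
                             const float* a, const float* b, float beta, float* c) {
  return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                     /*m=*/n, /*n=*/m, /*k=*/k, &alpha,
                     /*A=*/b, /*lda=*/n,
                     /*B=*/a, /*ldb=*/k,
                     &beta, /*C=*/c, /*ldc=*/n);
}

The strided-batched branch in the file applies the same swap and additionally sets the batch stride of whichever operand has batch count 1 to zero, so that operand is broadcast across the batch.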
*/ #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace ep { namespace primitive { namespace { template struct CastFunctor { __device__ To operator()(From from) const { return static_cast(from); } }; template struct CastFunctor::value>::type> { __device__ To operator()(half from) const { return static_cast(static_cast(from)); } __device__ void Apply2(To* to, const half* from) const { const float2 f2 = __half22float2(*reinterpret_cast(from)); to[0] = static_cast(f2.x); to[1] = static_cast(f2.y); } }; template struct CastFunctor::value>::type> { __device__ half operator()(From from) const { return static_cast(static_cast(from)); } __device__ void Apply2(half* to, const From* from) const { float2 f2; f2.x = static_cast(from[0]); f2.y = static_cast(from[1]); *reinterpret_cast(to) = __float22half2_rn(f2); } }; #if CUDA_VERSION >= 11000 template struct CastFunctor::value || std::is_same::value)>::type> { __device__ To operator()(nv_bfloat16 from) const { return static_cast(static_cast(from)); } }; template struct CastFunctor::value || std::is_same::value)>::type> { __device__ nv_bfloat16 operator()(From from) const { return static_cast(static_cast(from)); } }; #endif // CUDA_VERSION >= 11000 template class CastImpl : public Cast { public: OF_DISALLOW_COPY_AND_MOVE(CastImpl); explicit CastImpl() = default; ~CastImpl() override = default; void Launch(Stream* stream, const void* from, void* to, size_t count) override { auto* cuda_stream = stream->As(); OF_CUDA_CHECK((cuda::elementwise::Unary, To, From>( CastFunctor(), count, reinterpret_cast(to), reinterpret_cast(from), cuda_stream->cuda_stream()))); } }; template std::unique_ptr NewCast() { return std::unique_ptr(new CastImpl()); } #define CUDA_PRIMITIVE_CAST_TYPE_SEQ \ CUDA_PRIMITIVE_BOOL_TYPE_SEQ \ CUDA_PRIMITIVE_CHAR_TYPE_SEQ \ CUDA_PRIMITIVE_INT8_TYPE_SEQ \ CUDA_PRIMITIVE_UINT8_TYPE_SEQ \ CUDA_PRIMITIVE_INT32_TYPE_SEQ \ CUDA_PRIMITIVE_UINT32_TYPE_SEQ \ CUDA_PRIMITIVE_INT64_TYPE_SEQ \ CUDA_PRIMITIVE_UINT64_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \ CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \ CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ class CastFactoryImpl : public CastFactory { public: OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl); CastFactoryImpl() = default; ~CastFactoryImpl() override = default; std::unique_ptr New(DataType from, DataType to) override { #define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \ {std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \ NewCast}, static const std::map, std::function()>> new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_CAST_ENTRY, CUDA_PRIMITIVE_CAST_TYPE_SEQ, CUDA_PRIMITIVE_CAST_TYPE_SEQ)}; #undef MAKE_NEW_CAST_ENTRY const auto it = new_cast_handle.find(std::make_pair(from, to)); if (it != new_cast_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CastFactory, CastFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/constant_pad.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/constant_pad.h" #include "oneflow/core/ep/common/primitive/constant_pad.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace ep { namespace primitive { namespace { template __global__ void ConstantPadKernel(ConstantPadParams params, StorageType packed_pad_val) { const StorageType* src = reinterpret_cast(params.src); StorageType* dst = reinterpret_cast(params.dst); IndexType src_index[num_dims]; IndexType dst_index[num_dims]; CUDA_1D_KERNEL_LOOP_T(IndexType, linear_index, params.elem_cnt) { params.dst_index_helper.OffsetToNdIndex(linear_index, dst_index); bool if_pad = false; #pragma unroll for (int i = 0; i < num_dims; i++) { if (dst_index[i] >= params.valid_start[i] && dst_index[i] < params.valid_end[i]) { src_index[i] = dst_index[i] - params.valid_start[i]; } else { if_pad = true; break; } } StorageType dst_val = packed_pad_val; if (!if_pad) { const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index); dst_val = src[src_offset]; } dst[linear_index] = dst_val; } } template<> cuComplex GetValue(Scalar value) { const std::complex cpp_value = GetValue>(value); return cuComplex{cpp_value.real(), cpp_value.imag()}; } template<> cuDoubleComplex GetValue(Scalar value) { const std::complex cpp_value = GetValue>(value); return cuDoubleComplex{cpp_value.real(), cpp_value.imag()}; } template<> half GetValue(Scalar value) { return static_cast(GetValue(value)); } #if CUDA_VERSION >= 11000 template<> nv_bfloat16 GetValue(Scalar value) { return static_cast(GetValue(value)); } #endif // CUDA_VERSION >= 11000 template void LaunchKernel(Stream* stream, ConstantPadParams params, StorageType packed_pad_val, size_t elem_cnt) { stream->As()->LaunchKernelDefaultWaves( (ConstantPadKernel), elem_cnt, params, packed_pad_val); } template void LaunchKernel(Stream* stream, void* dst, const int64_t* dst_dims, const void* src, const int64_t* src_dims, const int64_t* padding_before, const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) { ConstantPadParams params; params.dst_index_helper = OffsetToIndexCalculator(dst_dims); params.src_index_helper = NdIndexOffsetHelper(src_dims); params.dst = dst; params.src = src; for (int i = 0; i < num_dims; i++) { params.valid_start[i] = padding_before[i]; params.valid_end[i] = dst_dims[i] - padding_after[i]; } params.elem_cnt = elem_cnt; LaunchKernel(stream, params, packed_pad_val, elem_cnt); } template void DispatchIndexType(Stream* stream, void* dst, const int64_t* dst_dims, const void* src, const int64_t* src_dims, const int64_t* padding_before, const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) { if (elem_cnt < GetMaxVal()) { LaunchKernel(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val, elem_cnt); } else { LaunchKernel(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val, elem_cnt); } } template void DispatchPackSize(Stream* stream, void* dst, int64_t* dst_dims, const void* src, int64_t* src_dims, int64_t* padding_before, int64_t* 
padding_after, T pad_val) { constexpr int32_t max_packsize = GetMaxPackSize(); size_t launch_pack_size = GetLaunchPackSize(num_dims, dst, dst_dims, src, src_dims, padding_before, padding_after); dst_dims[num_dims - 1] /= launch_pack_size; src_dims[num_dims - 1] /= launch_pack_size; padding_before[num_dims - 1] /= launch_pack_size; padding_after[num_dims - 1] /= launch_pack_size; size_t elem_cnt = 1; for (int i = 0; i < num_dims; i++) { elem_cnt *= dst_dims[i]; } if (launch_pack_size == 1) { Pack packed_pad_val(pad_val); DispatchIndexType>(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else if (launch_pack_size == 2) { Pack packed_pad_val(pad_val); DispatchIndexType>(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else if (launch_pack_size == 4) { Pack packed_pad_val(pad_val); DispatchIndexType>(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else if (launch_pack_size == 8) { Pack packed_pad_val(pad_val); DispatchIndexType>(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else if (launch_pack_size == 16) { Pack packed_pad_val(pad_val); DispatchIndexType>(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, packed_pad_val.storage, elem_cnt); } else { UNIMPLEMENTED(); } } template void LaunchWithSimplified(Stream* stream, size_t num_dims, void* dst, int64_t* dst_dims, const void* src, int64_t* src_dims, int64_t* padding_before, int64_t* padding_after, T pad_val) { void (*func)(Stream* /*stream*/, void* /*dst*/, int64_t* /*dst_dims*/, const void* /*src*/, int64_t* /*src_dims*/, int64_t* /*padding_before*/, int64_t* /*padding_after*/, T) = nullptr; if (num_dims == 1) { func = DispatchPackSize<1, T>; } else if (num_dims == 2) { func = DispatchPackSize<2, T>; } else if (num_dims == 3) { func = DispatchPackSize<3, T>; } else if (num_dims == 4) { func = DispatchPackSize<4, T>; } else if (num_dims == 5) { func = DispatchPackSize<5, T>; } else if (num_dims == 6) { func = DispatchPackSize<6, T>; } else if (num_dims == 7) { func = DispatchPackSize<7, T>; } else if (num_dims == 8) { func = DispatchPackSize<8, T>; } else { UNIMPLEMENTED(); } func(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, pad_val); } template void SimplifyThenLaunch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src, const int64_t* padding_before, const int64_t* padding_after, T pad_val, void* dst) { CHECK_GT(num_dims, 0) << "num_dims must greater than 0"; CHECK_LE(num_dims, kMaxNumDims); int64_t simplified_dst_dims[kMaxNumDims]; int64_t simplified_src_dims[kMaxNumDims]; int64_t simplified_padding_before[kMaxNumDims]; int64_t simplified_padding_after[kMaxNumDims]; size_t simplified_num_dims = 1; SimplifyPadDims(num_dims, src_dims, padding_before, padding_after, &simplified_num_dims, simplified_dst_dims, simplified_src_dims, simplified_padding_before, simplified_padding_after); LaunchWithSimplified(stream, simplified_num_dims, dst, simplified_dst_dims, src, simplified_src_dims, simplified_padding_before, simplified_padding_after, pad_val); } template class ConstantPadImpl : public ConstantPad { public: OF_DISALLOW_COPY_AND_MOVE(ConstantPadImpl); ConstantPadImpl() = default; ~ConstantPadImpl() override = default; void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src, const int64_t* padding_before, const int64_t* 
padding_after, Scalar pad_val, void* dst) override { SimplifyThenLaunch(stream, num_dims, src_dims, src, padding_before, padding_after, GetValue(pad_val), dst); } }; template std::unique_ptr NewConstantPad() { return std::unique_ptr(new ConstantPadImpl()); } class ConstantPadFactoryImpl : public ConstantPadFactory { public: OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactoryImpl); ConstantPadFactoryImpl() = default; ~ConstantPadFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad}, static const std::map()>> new_constant_pad_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_CONSTANT_PAD_ENTRY, CUDA_PRIMITIVE_REAL_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)}; #undef MAKE_NEW_CONSTANT_PAD_ENTRY const auto it = new_constant_pad_handle.find(data_type); if (it != new_constant_pad_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ConstantPadFactory, ConstantPadFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/copy_nd.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
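ConstantPadKernel above decides, per output element, whether it falls inside the copied region or the padded border: the output index must lie in [valid_start[d], valid_end[d]) in every dimension d, where valid_start is padding_before and valid_end is dst_dims - padding_after. A minimal sketch of that mapping with illustrative names (the real kernel additionally operates on packed elements and uses OffsetToIndexCalculator for the linear-index math):

#include <cstdint>

// Returns true and fills src_index when dst_index lies inside the valid
// (non-padded) region in every dimension; returns false when the element
// should receive the pad value instead.
__host__ __device__ inline bool MapToSource(int num_dims, const int64_t* dst_index,
                                            const int64_t* valid_start,
                                            const int64_t* valid_end, int64_t* src_index) {
  for (int d = 0; d < num_dims; ++d) {
    if (dst_index[d] < valid_start[d] || dst_index[d] >= valid_end[d]) { return false; }
    src_index[d] = dst_index[d] - valid_start[d];
  }
  return true;
}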
*/ #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/common/primitive/copy_nd.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace ep { namespace primitive { namespace { template __global__ void CopyNdKernel(CopyNdKernelParams params) { using T = typename std::aligned_storage::type; const T* src = reinterpret_cast(params.src); T* dst = reinterpret_cast(params.dst); IndexType copy_index[num_dims]; IndexType src_index[num_dims]; IndexType dst_index[num_dims]; CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) { params.copy_index_helper.OffsetToNdIndex(i, copy_index); #pragma unroll for (size_t j = 0; j < num_dims; ++j) { src_index[j] = params.src_pos[j] + copy_index[j]; dst_index[j] = params.dst_pos[j] + copy_index[j]; } const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index); const IndexType dst_offset = params.dst_index_helper.NdIndexToOffset(dst_index); dst[dst_offset] = src[src_offset]; } } template void LaunchKernel(Stream* stream, CopyNdKernelParams params) { cudaStream_t cuda_stream = stream->As()->cuda_stream(); CopyNdKernel <<>>(params); } class CopyNdImpl : public CopyNd { public: OF_DISALLOW_COPY_AND_MOVE(CopyNdImpl); CopyNdImpl() = default; ~CopyNdImpl() override = default; void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) const override { SimplifyThenLaunch(stream, data_type, num_dims, dst, dst_dims, dst_pos, src, src_dims, src_pos, extent); } }; class CopyNdFactoryImpl : public CopyNdFactory { public: OF_DISALLOW_COPY_AND_MOVE(CopyNdFactoryImpl); CopyNdFactoryImpl() = default; ~CopyNdFactoryImpl() override = default; std::unique_ptr New(size_t max_num_dims) override { if (max_num_dims <= kMaxNumDims) { return std::unique_ptr(new CopyNdImpl()); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CopyNdFactory, CopyNdFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/elementwise_unary.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/common/primitive/elementwise_unary.h" #include "oneflow/core/ep/cuda/primitive/unary_functor.cuh" namespace oneflow { namespace ep { namespace primitive { namespace { template class ElementwiseUnaryImpl : public ElementwiseUnary { public: OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl); ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {} ~ElementwiseUnaryImpl() override = default; void Launch(Stream* stream, const void* src, void* dst, size_t count) override { auto* cuda_stream = stream->As(); auto functor = UnaryFunctor(attr0, attr1); OF_CUDA_CHECK((cuda::elementwise::Unary( functor, count, reinterpret_cast(dst), reinterpret_cast(src), cuda_stream->cuda_stream()))); } protected: Scalar attr0, attr1; }; template std::unique_ptr NewElementwiseUnary(Scalar attr0, Scalar attr1) { return std::unique_ptr( new ElementwiseUnaryImpl(attr0, attr1)); } class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory { public: OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl); ElementwiseUnaryFactoryImpl() = default; ~ElementwiseUnaryFactoryImpl() override = default; std::unique_ptr New(UnaryOp unary_op, DataType src_type, DataType dst_dtype) override { return New(unary_op, src_type, dst_dtype, Scalar(), Scalar()); } std::unique_ptr New(UnaryOp unary_op, DataType src_type, DataType dst_dtype, Scalar attr0) override { return New(unary_op, src_type, dst_dtype, attr0, Scalar()); } std::unique_ptr New(UnaryOp unary_op, DataType src_type, DataType dst_dtype, Scalar attr0, Scalar attr1) override { #define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \ {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \ NewElementwiseUnary}, #define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair) \ {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \ NewElementwiseUnary}, static const std::map, std::function(Scalar, Scalar)>> new_elementwise_unary_handle{ // For All Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_MATH_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ) // For Float Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_FLOATING_MATH_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ) // For Complex Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2C_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2R_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_R2C_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) // For Int Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_INT_MATH_OP_SEQ, CUDA_PRIMITIVE_INT_TYPE_SEQ) // For Utils OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ) // For Logical OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ) // For bitwise op OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_BITWISE_OP_SEQ, CUDA_PRIMITIVE_INT_TYPE_SEQ 
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)}; #undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY #undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY const auto it = new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype)); if (it != new_elementwise_unary_handle.end()) { return it->second(attr0, attr1); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ElementwiseUnaryFactory, ElementwiseUnaryFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/fill.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace ep { namespace primitive { namespace { template using Storage = typename std::aligned_storage::type; template union Pack { static constexpr size_t size = sizeof(T) * pack; explicit __device__ __host__ Pack(T value) { static_assert(sizeof(Pack) == size, ""); static_assert(alignof(Pack) == size, ""); #pragma unroll for (size_t i = 0; i < pack; ++i) { elem[i] = value; } } T elem[pack]; Storage storage; }; template __global__ void FillGpu(T* dst, T value, size_t count) { const size_t pack_count = count / pack; Pack pack_value(value); auto* pack_dst = reinterpret_cast(dst); CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; } T* tail_dst = dst + pack_count * pack; const size_t tail_count = count - pack_count * pack; CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; } } template T GetValue(Scalar value) { return value.Value(); } template<> half GetValue(Scalar value) { return static_cast(GetValue(value)); } template<> cuComplex GetValue(Scalar value) { const std::complex cpp_value = GetValue>(value); return cuComplex{cpp_value.real(), cpp_value.imag()}; } template<> cuDoubleComplex GetValue(Scalar value) { const std::complex cpp_value = GetValue>(value); return cuDoubleComplex{cpp_value.real(), cpp_value.imag()}; } #if CUDA_VERSION >= 11000 template<> nv_bfloat16 GetValue(Scalar value) { return static_cast(GetValue(value)); } #endif // CUDA_VERSION >= 11000 template typename std::enable_if<(pack != 0), void>::type LaunchPackFill(cudaStream_t stream, T* dst, T value, size_t count) { FillGpu <<>>(dst, value, count); } template typename std::enable_if<(pack == 0), void>::type LaunchPackFill(cudaStream_t stream, T* dst, T value, size_t count) { LOG(FATAL) << "wrong alignment"; } template void LaunchFill(cudaStream_t stream, T* dst, T value, size_t count) { auto uintptr = reinterpret_cast(dst); if (uintptr % 16 == 0) { LaunchPackFill(stream, dst, value, count); } else if (uintptr % 8 == 0) { LaunchPackFill(stream, dst, value, count); } else if (uintptr % 4 == 0) { LaunchPackFill(stream, dst, value, count); } else if (uintptr % 2 == 0) { 
LaunchPackFill(stream, dst, value, count); } else { LaunchPackFill(stream, dst, value, count); } } template class FillImpl : public Fill { public: OF_DISALLOW_COPY_AND_MOVE(FillImpl); FillImpl() = default; ~FillImpl() override = default; void Launch(Stream* stream, void* dst, Scalar value, size_t count) override { cudaStream_t cuda_stream = stream->As()->cuda_stream(); LaunchFill(cuda_stream, reinterpret_cast(dst), GetValue(value), count); } }; template std::unique_ptr NewFill() { return std::unique_ptr(new FillImpl()); } class FillFactoryImpl : public FillFactory { public: OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl); FillFactoryImpl() = default; ~FillFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill}, static const std::map()>> new_fill_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CUDA_PRIMITIVE_REAL_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ CUDA_PRIMITIVE_INT16_TYPE_SEQ)}; #undef MAKE_NEW_FILL_ENTRY const auto it = new_fill_handle.find(data_type); if (it != new_fill_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, FillFactory, FillFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_0.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ_0, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_1.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
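LaunchFill in fill.cu (and the broadcast-fill path in broadcast_elementwise_unary.cu) chooses its store width from the destination pointer's alignment: 16-byte-aligned buffers take 16/sizeof(T)-element packs, then 8, 4, 2, and finally plain scalar stores, with any tail that does not fill a whole pack written element by element. A condensed, self-contained sketch of the same idea with illustrative names and a simplified launch configuration:

#include <cstdint>

// A pack is `pack` elements stored as one aligned word.
template<typename T, size_t pack>
struct alignas(sizeof(T) * pack) FillPack {
  T elem[pack];
};

template<typename T, size_t pack>
__global__ void PackedFillKernel(T* dst, T value, size_t count) {
  FillPack<T, pack> packed;
  for (size_t i = 0; i < pack; ++i) { packed.elem[i] = value; }
  auto* packed_dst = reinterpret_cast<FillPack<T, pack>*>(dst);
  const size_t packed_count = count / pack;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < packed_count;
       i += static_cast<size_t>(blockDim.x) * gridDim.x) {
    packed_dst[i] = packed;  // one wide store per pack
  }
  // Leftover tail elements are written one at a time.
  for (size_t i = packed_count * pack + blockIdx.x * blockDim.x + threadIdx.x; i < count;
       i += static_cast<size_t>(blockDim.x) * gridDim.x) {
    dst[i] = value;
  }
}

// Alignment-driven dispatch: wider packs only when the pointer allows them.
template<typename T>
void DispatchFill(cudaStream_t stream, T* dst, T value, size_t count) {
  const auto addr = reinterpret_cast<std::uintptr_t>(dst);
  constexpr int kBlockSize = 256;
  const int num_blocks = static_cast<int>((count + kBlockSize - 1) / kBlockSize);
  if (addr % 16 == 0 && sizeof(T) <= 16) {
    PackedFillKernel<T, 16 / sizeof(T)><<<num_blocks, kBlockSize, 0, stream>>>(dst, value, count);
  } else if (addr % 8 == 0 && sizeof(T) <= 8) {
    PackedFillKernel<T, 8 / sizeof(T)><<<num_blocks, kBlockSize, 0, stream>>>(dst, value, count);
  } else {
    PackedFillKernel<T, 1><<<num_blocks, kBlockSize, 0, stream>>>(dst, value, count);
  }
}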
*/ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ_1, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_2.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ_2, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_3.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ_3, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_complex.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" namespace oneflow { namespace ep { namespace primitive { namespace broadcast_elementwise_binary { #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ template std::unique_ptr NewBroadcastElementwiseBinary< \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ Scalar attr0, Scalar attr1); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_BACKWARD_OP_SEQ_COMPLEX, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ); } // namespace broadcast_elementwise_binary } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/memcpy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace ep { namespace primitive { namespace { class MemcpyImpl : public Memcpy { public: OF_DISALLOW_COPY_AND_MOVE(MemcpyImpl); MemcpyImpl() = default; ~MemcpyImpl() override = default; void Launch(Stream* stream, void* dst, const void* src, size_t count) override { if (dst == src) { return; } auto* cuda_stream = stream->As(); OF_CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, cuda_stream->cuda_stream())); } }; class MemcpyFactoryImpl : public MemcpyFactory { public: OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl); MemcpyFactoryImpl() = default; ~MemcpyFactoryImpl() override = default; std::unique_ptr New(MemcpyKind kind) override { return std::unique_ptr(new MemcpyImpl()); } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemcpyFactory, MemcpyFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow #endif ================================================ FILE: oneflow/core/ep/cuda/primitive/memset.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace ep { namespace primitive { namespace { class MemsetImpl : public Memset { public: OF_DISALLOW_COPY_AND_MOVE(MemsetImpl); MemsetImpl() = default; ~MemsetImpl() override = default; void Launch(Stream* stream, void* ptr, int value, size_t count) override { auto* cuda_stream = stream->As(); OF_CUDA_CHECK(cudaMemsetAsync(ptr, value, count, cuda_stream->cuda_stream())); } }; class MemsetFactoryImpl : public MemsetFactory { public: OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl); MemsetFactoryImpl() = default; ~MemsetFactoryImpl() override = default; std::unique_ptr New() override { return std::unique_ptr(new MemsetImpl()); } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemsetFactory, MemsetFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow #endif ================================================ FILE: oneflow/core/ep/cuda/primitive/permute.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/common/primitive/permute_impl.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace ep { namespace primitive { namespace permute { namespace internal { namespace { constexpr int32_t kMov4TileSize = 32; constexpr int32_t kMov2TileSize = 64; constexpr int32_t kBlockRows = 8; template __global__ void PermuteKernel(PermuteKernelParams params) { using T = typename std::aligned_storage::type; const T* src = reinterpret_cast(params.src); T* dst = reinterpret_cast(params.dst); IndexType src_index[num_dims]; IndexType dst_index[num_dims]; CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) { params.dst_index_helper.OffsetToNdIndex(i, dst_index); #pragma unroll for (size_t dim = 0; dim < num_dims; ++dim) { src_index[params.permutation[dim]] = dst_index[dim]; } IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index); dst[i] = src[src_offset]; } } // (B, X, Y) -> (B, Y, X) // refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/ template __global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType rows, IndexType cols, IndexType num_tile_rows, IndexType num_tile_cols, int32_t block_nums) { const IndexType src_rows = rows; const IndexType src_cols = cols; const IndexType dst_rows = cols; const IndexType dst_cols = rows; using T = typename std::aligned_storage::type; __shared__ T tile[tile_size][tile_size + 1]; // To avoid bank conflict. const T* src = reinterpret_cast(src_ptr); T* dst = reinterpret_cast(dst_ptr); IndexType batch_num_tile = num_tile_rows * num_tile_cols; for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) { const IndexType batch_index = i / batch_num_tile; // the index of batch. const IndexType tile_index = i - batch_index * batch_num_tile; // equal to i % (num_tile_rows*num_tile_cols). the // flatten index of tile in a batch. const IndexType tile_row_index = tile_index / num_tile_cols; // the row index of tile in a batch. const IndexType tile_col_index = tile_index - tile_row_index * num_tile_cols; // equal to k % num_tile_cols. the col index of tile in a batch. const IndexType offset = batch_index * src_rows * src_cols; { IndexType col_in_tile = threadIdx.x; IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x; #pragma unroll for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size; row_in_tile += kBlockRows) { IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size; if (col_in_matrix < src_cols && row_in_matrix < src_rows) { tile[row_in_tile][col_in_tile] = src[offset + row_in_matrix * src_cols + col_in_matrix]; } } } __syncthreads(); { IndexType col_in_tile = threadIdx.x; IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x; #pragma unroll for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size; row_in_tile += kBlockRows) { IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size; if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) { dst[offset + row_in_matrix * dst_cols + col_in_matrix] = tile[col_in_tile][row_in_tile]; } } } __syncthreads(); } } /* Here is a Movementsie=2 version of Batch Transpose. When the H W can be divided by 2. we can read data use movementsize=4, and write back as movementsize=4. 
*/ template __global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows, IndexType cols, IndexType num_tile_rows, IndexType num_tile_cols, int32_t block_nums) { const IndexType src_rows = rows; const IndexType src_cols = cols; const IndexType dst_rows = cols; const IndexType dst_cols = rows; static_assert(tile_size % 2 == 0, ""); using T_MOV2 = typename std::aligned_storage<2, 2>::type; using T_MOV4 = typename std::aligned_storage<4, 4>::type; const T_MOV4* src = reinterpret_cast(src_ptr); T_MOV4* dst = reinterpret_cast(dst_ptr); // Use union structure to process Load and Store. __shared__ union { T_MOV2 tile_m2[tile_size][tile_size + 2]; // half [64][66] T_MOV4 tile_m4[tile_size][tile_size / 2 + 1]; // half2 [64][33] } tile_mem; IndexType batch_num_tile = num_tile_rows * num_tile_cols; for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) { const IndexType batch_index = i / batch_num_tile; // the index of batch. const IndexType tile_index = i - batch_index * batch_num_tile; // equal to i % (num_tile_rows*num_tile_cols). the // flatten index of tile in a batch. const IndexType tile_row_index = tile_index / num_tile_cols; // the row index of tile in a batch. const IndexType tile_col_index = tile_index - tile_row_index * num_tile_cols; // equal to k % num_tile_cols. the col index of tile in a batch. const IndexType offset = batch_index * src_rows * src_cols; { IndexType col_in_tile = threadIdx.x; IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x * 2; #pragma unroll for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size; row_in_tile += kBlockRows) { IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size; if (col_in_matrix < src_cols && row_in_matrix < src_rows) { tile_mem.tile_m4[row_in_tile][col_in_tile] = src[(offset + row_in_matrix * src_cols + col_in_matrix) / 2]; } } } __syncthreads(); { IndexType col_in_tile = threadIdx.x; IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x * 2; #pragma unroll for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size; row_in_tile += kBlockRows) { IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size; union { T_MOV4 m4; T_MOV2 m2[2]; } tmp_storage; if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) { tmp_storage.m2[0] = tile_mem.tile_m2[col_in_tile * 2][row_in_tile]; tmp_storage.m2[1] = tile_mem.tile_m2[col_in_tile * 2 + 1][row_in_tile]; dst[(offset + row_in_matrix * dst_cols + col_in_matrix) / 2] = tmp_storage.m4; } } } __syncthreads(); } } template void LaunchBatchTransposeKernel(cudaStream_t& cuda_stream, const PermuteKernelParams& params, const IndexType& num_batches, const IndexType& rows, const IndexType& cols) { IndexType num_tile_rows = (rows + tile_size - 1) / tile_size; IndexType num_tile_cols = (cols + tile_size - 1) / tile_size; const int32_t block_nums = num_batches * num_tile_rows * num_tile_cols; int32_t launched_block_nums = std::min(block_nums, kCudaMaxBlocksNum); if (tile_size == kMov2TileSize) { const int32_t half2_thread = tile_size / 2; // cause each thread process two half elements. BatchTransposeMovement2Kernel <<>>( params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums); // Set threads num as 32x8 cause each threads // process 4 elements to 64x66 half share memory. 
} else { BatchTransposeKernel <<>>( params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums); } } template bool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) { if (rows < tile_size || cols < tile_size) { return false; } return true; } template bool CheckLaunchBatchTranspose(const int* permutation, const IndexType& num_batches, const IndexType& rows, const IndexType& cols) { if (CheckIfGreaterEqualThanTileSize(rows, cols)) { if (num_batches == 1 && permutation[1] == 0 && permutation[0] == 1) { // 2d tensor case: (0, 1) -> (1, 0) return true; } else if (num_dims == 3 && permutation[2] == 1 && permutation[1] == 2) { // 3d tensor case: (0, 1, 2) -> (0, 2, 1) return true; } else { return false; } } return false; } template bool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) { auto src_ptr = reinterpret_cast(src); auto dst_ptr = reinterpret_cast(dst); return (movement_size == 2) && (rows % 2 == 0) && (cols % 2 == 0) && (src_ptr % 4 == 0) && (dst_ptr % 4 == 0); } template void InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows, IndexType* cols) { if (num_dims == 2) { *num_batches = 1; *rows = src_dims[0]; *cols = src_dims[1]; } else { *num_batches = src_dims[0]; *rows = src_dims[1]; *cols = src_dims[2]; } } template void LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation, void* dst, size_t count) { PermuteKernelParams params = MakePermuteParams(src_dims, src, permutation, dst, count); cudaStream_t cuda_stream = stream->As()->cuda_stream(); if (num_dims == 2 || num_dims == 3) { IndexType num_batches; IndexType rows; IndexType cols; InferBatchTransposeShape(src_dims, &num_batches, &rows, &cols); if (CheckLaunchBatchTranspose(params.permutation, num_batches, rows, cols)) { if (CheckUseMov2(rows, cols, src, dst)) { LaunchBatchTransposeKernel(cuda_stream, params, num_batches, rows, cols); } else { LaunchBatchTransposeKernel( cuda_stream, params, num_batches, rows, cols); } } else { if (params.count == 0) { return; } PermuteKernel <<>>(params); } } else { if (params.count == 0) { return; } PermuteKernel <<>>(params); } } class PermuteImpl : public Permute { public: OF_DISALLOW_COPY_AND_MOVE(PermuteImpl); PermuteImpl() = default; ~PermuteImpl() override = default; using Permute::Launch; void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims, const void* src, const int* permutation, void* dst) override { SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst); } }; class PermuteFactoryImpl : public PermuteFactory { public: OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl); PermuteFactoryImpl() = default; ~PermuteFactoryImpl() override = default; std::unique_ptr New(size_t max_num_dims) override { if (max_num_dims <= kMaxNumDims) { return std::unique_ptr(new PermuteImpl()); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, PermuteFactory, PermuteFactoryImpl); } // namespace } // namespace internal } // namespace permute } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/softmax.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/softmax.h" #include "oneflow/core/ep/include/primitive/log_softmax.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace ep { namespace primitive { namespace { enum class Algorithm { kSoftmax, kLogSoftmax, }; template void SoftmaxGpu(cudaStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) { using ComputeType = typename cuda::softmax::DefaultComputeType::type; oneflow::cuda::softmax::DirectLoad load(x, cols); oneflow::cuda::softmax::DirectStore store(y, cols); if (algorithm == Algorithm::kSoftmax) { OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( cuda_stream, load, store, rows, cols))); } else if (algorithm == Algorithm::kLogSoftmax) { OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax( cuda_stream, load, store, rows, cols))); } else { UNIMPLEMENTED(); } } template class SoftmaxImpl : public SoftmaxBase { public: OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl); SoftmaxImpl() = default; ~SoftmaxImpl() override = default; void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override { cudaStream_t cuda_stream = stream->As()->cuda_stream(); SoftmaxGpu(cuda_stream, rows, cols, reinterpret_cast(x), reinterpret_cast(y)); } }; template std::unique_ptr NewSoftmax() { return std::unique_ptr(new SoftmaxImpl()); } template class GenericSoftmaxFactoryImpl : public FactoryBase { public: OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl); GenericSoftmaxFactoryImpl() = default; ~GenericSoftmaxFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \ {type_proto, NewSoftmax}, static const std::map()>> new_softmax_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)}; #undef MAKE_NEW_SOFTMAX_ENTRY const auto it = new_softmax_handle.find(data_type); if (it != new_softmax_handle.end()) { return it->second(); } else { return nullptr; } } }; using SoftmaxFactoryImpl = GenericSoftmaxFactoryImpl; using LogSoftmaxFactoryImpl = GenericSoftmaxFactoryImpl; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxFactory, SoftmaxFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxFactory, LogSoftmaxFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/softmax_backward.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/include/primitive/softmax_backward.h" #include "oneflow/core/ep/include/primitive/log_softmax_backward.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace ep { namespace primitive { namespace { enum class Algorithm { kSoftmax, kLogSoftmax, }; template void SoftmaxBackwardGpu(cudaStream_t cuda_stream, size_t rows, size_t cols, const T* y, const T* dy, T* dx) { using ComputeType = typename cuda::softmax::DefaultComputeType::type; cuda::softmax::DirectLoad load_y(y, cols); cuda::softmax::DirectLoad load_dy(dy, cols); cuda::softmax::DirectStore store(dx, cols); if (algorithm == Algorithm::kSoftmax) { OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad( cuda_stream, load_y, load_dy, store, rows, cols))); } else if (algorithm == Algorithm::kLogSoftmax) { OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmaxGrad( cuda_stream, load_y, load_dy, store, rows, cols))); } else { UNIMPLEMENTED(); } } template class SoftmaxBackwardImpl : public SoftmaxBackwardBase { public: OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardImpl); SoftmaxBackwardImpl() = default; ~SoftmaxBackwardImpl() override = default; void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy, void* dx) override { cudaStream_t cuda_stream = stream->As()->cuda_stream(); SoftmaxBackwardGpu(cuda_stream, rows, cols, reinterpret_cast(y), reinterpret_cast(dy), reinterpret_cast(dx)); } }; template std::unique_ptr NewSoftmaxBackward() { return std::unique_ptr( new SoftmaxBackwardImpl()); } template class GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase { public: OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl); GenericSoftmaxBackwardFactoryImpl() = default; ~GenericSoftmaxBackwardFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \ {type_proto, NewSoftmaxBackward}, static const std::map()>> new_softmax_backward_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)}; #undef MAKE_NEW_SOFTMAX_ENTRY const auto it = new_softmax_backward_handle.find(data_type); if (it != new_softmax_backward_handle.end()) { return it->second(); } else { return nullptr; } } }; using SoftmaxBackwardFactoryImpl = GenericSoftmaxBackwardFactoryImpl; using LogSoftmaxBackwardFactoryImpl = GenericSoftmaxBackwardFactoryImpl; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxBackwardFactory, SoftmaxBackwardFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxBackwardFactory, LogSoftmaxBackwardFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/tensor_fill.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ep/include/primitive/tensor_fill.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace ep { namespace primitive { namespace { template using Storage = typename std::aligned_storage::type; template union Pack { static constexpr size_t size = sizeof(T) * pack; explicit __device__ __host__ Pack(const T value) { static_assert(sizeof(Pack) == size, ""); static_assert(alignof(Pack) == size, ""); #pragma unroll for (size_t i = 0; i < pack; ++i) { elem[i] = value; } } T elem[pack]; Storage storage; }; template __global__ void TensorFillGpu(T* dst, const T* value, size_t count) { const size_t pack_count = count / pack; const T fill_value = value[0]; Pack pack_value(fill_value); auto* pack_dst = reinterpret_cast(dst); CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; } T* tail_dst = dst + pack_count * pack; const size_t tail_count = count - pack_count * pack; CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = fill_value; } } template typename std::enable_if<(pack != 0), void>::type LaunchPackTensorFill(cudaStream_t stream, T* dst, const T* value, size_t count) { TensorFillGpu <<>>(dst, value, count); } template typename std::enable_if<(pack == 0), void>::type LaunchPackTensorFill(cudaStream_t stream, T* dst, const T* value, size_t count) { LOG(FATAL) << "wrong alignment"; } template void LaunchTensorFill(cudaStream_t stream, T* dst, const T* value, size_t count) { auto uintptr = reinterpret_cast(dst); if (uintptr % 16 == 0) { LaunchPackTensorFill(stream, dst, value, count); } else if (uintptr % 8 == 0) { LaunchPackTensorFill(stream, dst, value, count); } else if (uintptr % 4 == 0) { LaunchPackTensorFill(stream, dst, value, count); } else if (uintptr % 2 == 0) { LaunchPackTensorFill(stream, dst, value, count); } else { LaunchPackTensorFill(stream, dst, value, count); } } template class TensorFillImpl : public TensorFill { public: OF_DISALLOW_COPY_AND_MOVE(TensorFillImpl); TensorFillImpl() = default; ~TensorFillImpl() override = default; void Launch(Stream* stream, const void* src, void* dst, size_t count) override { cudaStream_t cuda_stream = stream->As()->cuda_stream(); const T* value = reinterpret_cast(src); LaunchTensorFill(cuda_stream, reinterpret_cast(dst), value, count); } }; template std::unique_ptr NewTensorFill() { return std::unique_ptr(new TensorFillImpl()); } class TensorFillFactoryImpl : public TensorFillFactory { public: OF_DISALLOW_COPY_AND_MOVE(TensorFillFactoryImpl); TensorFillFactoryImpl() = default; ~TensorFillFactoryImpl() override = default; std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_TENSOR_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewTensorFill}, static const std::map()>> new_fill_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_TENSOR_FILL_ENTRY, CUDA_PRIMITIVE_REAL_TYPE_SEQ)}; #undef MAKE_NEW_TENSOR_FILL_ENTRY const auto it = new_fill_handle.find(data_type); if (it != new_fill_handle.end()) { return it->second(); } else { return nullptr; } } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, TensorFillFactory, TensorFillFactoryImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/cuda/primitive/type_seq.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_ #define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_ #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/data_type.h" #ifdef WITH_CUDA #include #include #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #define CUDA_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool) #define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar) #define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) #define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) #define CUDA_PRIMITIVE_INT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int16_t, DataType::kInt16) #define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define CUDA_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) #define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define CUDA_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) #define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) #define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) #define CUDA_PRIMITIVE_COMPLEX64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64) #define CUDA_PRIMITIVE_COMPLEX128_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128) #if CUDA_VERSION >= 11000 #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16) #else #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ #endif // CUDA_VERSION >= 11000 #define CUDA_PRIMITIVE_REAL_TYPE_SEQ \ CUDA_PRIMITIVE_BOOL_TYPE_SEQ \ CUDA_PRIMITIVE_CHAR_TYPE_SEQ \ CUDA_PRIMITIVE_INT8_TYPE_SEQ \ CUDA_PRIMITIVE_UINT8_TYPE_SEQ \ CUDA_PRIMITIVE_INT32_TYPE_SEQ \ CUDA_PRIMITIVE_INT64_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \ CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \ CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ #define CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ \ CUDA_PRIMITIVE_COMPLEX64_TYPE_SEQ \ CUDA_PRIMITIVE_COMPLEX128_TYPE_SEQ #define CUDA_PRIMITIVE_FLOATING_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \ CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \ CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ #define CUDA_PRIMITIVE_INT_TYPE_SEQ \ CUDA_PRIMITIVE_UINT8_TYPE_SEQ \ CUDA_PRIMITIVE_INT8_TYPE_SEQ \ CUDA_PRIMITIVE_INT32_TYPE_SEQ \ CUDA_PRIMITIVE_INT64_TYPE_SEQ #define UTIL_OPS_DATA_TYPE_SEQ \ CUDA_PRIMITIVE_INT8_TYPE_SEQ \ CUDA_PRIMITIVE_UINT8_TYPE_SEQ \ CUDA_PRIMITIVE_INT32_TYPE_SEQ \ CUDA_PRIMITIVE_INT64_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \ CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \ CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ #endif // WITH_CUDA #endif // ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_ ================================================ FILE: oneflow/core/ep/cuda/primitive/unary_functor.cuh 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_UNARY_FUNCTOR_CUH #define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_UNARY_FUNCTOR_CUH #include "oneflow/core/ep/common/primitive/unary_functor.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include #include "oneflow/core/common/math_util.h" namespace oneflow { namespace ep { namespace primitive { template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(0.5) * src * (static_cast(1.0) + erf(static_cast(M_SQRT1_2) * src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu const Src half = static_cast(0.5); const Src one = static_cast(1); const Src tanh_in = alpha * (src + beta * src * src * src); return half * src * (one + tanh(tanh_in)); } private: // constant ref to: // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/transform/fusion/fast_gelu.py static constexpr Src alpha = static_cast(0.7978845608028654); static constexpr Src beta = static_cast(0.044714998453855515); }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { const Src sigmoid = static_cast(static_cast(1.0) / (static_cast(1.0) + exp(-src * alpha))); return src * sigmoid; } private: static constexpr Src alpha = static_cast(1.702); }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast((src > static_cast(0.0)) ? 
src * src : 0); } }; namespace unary_functor_internal { namespace { OF_DEVICE_FUNC float TanhApprox(float x) { #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) float r; asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); return r; #else return tanhf(x); #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) } } // namespace } // namespace unary_functor_internal template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} OF_DEVICE_FUNC half operator()(half src) const { #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) const float tanh_in = __half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src)); const float tanh_out = unary_functor_internal::TanhApprox(tanh_in); return __float2half_rn(0.5F) * src * (__float2half_rn(1.0F) + __float2half_rn(tanh_out)); #else return static_cast(float_functor(static_cast(src))); #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) } #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) __device__ void Apply2(half* dst, const half* src) const { const half2 src2 = *(reinterpret_cast(src)); const float2 tanh_in = __half22float2(__hmul2( __float2half2_rn(alpha), __hadd2(src2, __hmul2(__hmul2(__hmul2(__float2half2_rn(beta), src2), src2), src2)))); float2 tanh_out; tanh_out.x = unary_functor_internal::TanhApprox(tanh_in.x); tanh_out.y = unary_functor_internal::TanhApprox(tanh_in.y); const half2 dst2 = __hmul2(__hmul2(__float2half2_rn(0.5F), src2), __hadd2(__float2half2_rn(1.0F), __float22half2_rn(tanh_out))); *reinterpret_cast(dst) = dst2; } #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) private: static constexpr float alpha = 0.7978845608028654F; static constexpr float beta = 0.044714998453855515F; UnaryFunctor float_functor; }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src) const { return tanhf(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src) const { return tanh(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC half operator()(half src) const { return __float2half(tanhf(__half2float(src))); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(half src) const { return isinf(__half2float(src)); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(float src) const { return isinf(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(double src) const { return isinf(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(half src) const { return isnan(__half2float(src)); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(float src) const { return isnan(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(double src) const { return isnan(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(half src) const { return isfinite(__half2float(src)); } }; template<> struct 
UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(float src) const { return isfinite(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(double src) const { return isfinite(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} __device__ half operator()(half src) const { return htrunc(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC float operator()(float src) const { return truncf(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC double operator()(double src) const { return trunc(src); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src in) const { // references // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L3029-L3090 static const double PI_f64 = 3.14159265358979323846; const Src PSI_10 = 2.25175258906672110764; const Src A[] = { 8.33333333333333333333E-2, -2.10927960927960927961E-2, 7.57575757575757575758E-3, -4.16666666666666666667E-3, 3.96825396825396825397E-3, -8.33333333333333333333E-3, 8.33333333333333333333E-2, }; Src x = static_cast(in); if (x == static_cast(0)) { // As per C++ standard for gamma related functions and SciPy, // If the argument is ±0, ±∞ is returned return std::copysign(static_cast(INFINITY), -x); } bool x_is_integer = x == trunc(x); Src result = static_cast(0); if (x < 0) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, // If the argument is a negative integer, NaN is returned return static_cast(NAN); } // Extracts the fractional part of x as r, since tan(pi * r) is more numerically // accurate than tan(pi * x). While these operations are mathematically equivalent // since both x and r are in radians and tan() has a periodicity of pi, in practice // the computation of pi * x is a source of error (when |x| > 1). double q, r; r = modf(static_cast(x), &q); result = static_cast(-PI_f64 / tan(PI_f64 * r)); x = static_cast(1) - x; } while (x < 10) { result -= static_cast(1) / x; x += 1; } if (x == static_cast(10)) { return static_cast(result + PSI_10); } Src y = 0; if (x < 1.0e17) { Src z = static_cast(1) / (x * x); Src polevl_result = 0; for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; } y = z * polevl_result; } return static_cast(log(x) - (static_cast(0.5) / x) - y + result); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src x) const { // references // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L387-L410 const Src PI{3.14159265358979323846}; Src sign = 1; Src result = 0; if (x < Src{0.5}) { sign = -1; Src sin_pi_x = sin(PI * x); result -= (PI * PI) / (sin_pi_x * sin_pi_x); x = 1 - x; } for (int i = 0; i < 6; ++i) { result += Src{1} / (x * x); x += 1; } const Src one{1}; const Src ixx = one / (x * x); result += (one + one / (Src{2} * x) + ixx * (one / Src{6} - ixx * (one / Src{30} - ixx * (one / Src{42})))) / x; return sign * result; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} __device__ half operator()(half src) const { return __hlt(src, static_cast(0)) ? 
__hneg(src) : src; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return isnan(src) ? static_cast(0.0) : src; } }; #if CUDA_VERSION >= 11000 template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} __device__ nv_bfloat16 operator()(nv_bfloat16 src) const { #if CUDA_ARCH >= 800 return __habs(src); #else return __float2bfloat16(abs(__bfloat162float(src))); #endif // CUDA_ARCH >= 800 } }; #endif // CUDA_VERSION >= 11000 /*********half dtype support*********/ template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(half src) const { return static_cast(__half2float(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC half operator()(Src src) const { return __float2half(static_cast(src)); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC half operator()(half src) const { return src; } }; /*********nv_bfloat16 dtype support*********/ #if CUDA_VERSION >= 11000 template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC nv_bfloat16 operator()(half src) const { return __float2bfloat16(__half2float(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(nv_bfloat16 src) const { return static_cast(__bfloat162float(src)); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC nv_bfloat16 operator()(Src src) const { return __float2bfloat16(static_cast(src)); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC half operator()(nv_bfloat16 src) const { return __float2half(__bfloat162float(src)); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { return src; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(nv_bfloat16 src) const { return make_cuComplex((__bfloat162float(src)), 0.0); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(nv_bfloat16 src) const { return make_cuDoubleComplex(static_cast(__bfloat162float(src)), 0.0); } }; #endif // CUDA_VERSION >= 11000 #define SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(op) \ template<> \ struct UnaryFunctor { \ OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ \ UnaryFunctor float_functor; \ OF_DEVICE_FUNC half operator()(half src) const { \ return __float2half(float_functor(__half2float(src))); \ } \ }; SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kElu); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCelu); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kGelu); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kMish); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSelu); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSilu); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftSign); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftPlus); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAcos); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAcosh); 
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAsin); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAsinh); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAtan); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAtanh); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCeil); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCos); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCosh); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kDigamma); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kTrigamma); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kErf); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kErfc); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kExp); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kExp2); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kExpm1); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kFloor); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLgamma); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLog); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLog2); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLog10); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLog1p); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLogSigmoid); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kRint); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kRound); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kRsqrt); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSigmoid); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSin); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSinh); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSqrt); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSquare); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kTan); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kNotEqualZero); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kNanAssign); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kQuickGelu); SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSquareReLU); /*********nv_bfloat16_kernel*******/ #if CUDA_VERSION >= 11000 #define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op) \ template<> \ struct UnaryFunctor { \ OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ \ UnaryFunctor float_functor; \ OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { \ return __float2bfloat16(float_functor(__bfloat162float(src))); \ } \ }; SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold); 
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcos); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcosh); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsin); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsinh); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtan); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtanh); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCeil); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCos); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCosh); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErf); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErfc); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp2); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExpm1); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFloor); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLgamma); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog2); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog10); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog1p); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLogSigmoid); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRint); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRound); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRsqrt); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSigmoid); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSin); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSinh); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSqrt); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquare); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTan); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNanAssign); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquareReLU); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma); SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma); template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isfinite(__bfloat162float(src)); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} __device__ nv_bfloat16 operator()(nv_bfloat16 src) const { #if CUDA_ARCH >= 800 return htrunc(src); #else return __float2bfloat16(truncf(__bfloat162float(src))); #endif // CUDA_ARCH >= 800 } }; #endif // CUDA_VERSION >= 11000 /*********float complex dtype support*********/ template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src.x, -src.y}; } }; template struct UnaryFunctor 
{ OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src.x); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src.y); } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src, 0.0}; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{0.0, src}; } }; // avoid warning: narrowing conversion template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(double src) const { return cuComplex{static_cast(src), 0.0f}; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(double src) const { return cuComplex{0.0f, static_cast(src)}; } }; template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(Src src) const { return make_cuComplex(static_cast(src), 0.0); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(cuComplex src) const { return src; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(cuDoubleComplex src) const { return cuComplexDoubleToFloat(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(half src) const { return make_cuComplex((__half2float(src)), 0.0); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(cuComplex src) const { return src; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(cuComplex src) const { return cuComplex{src.x, -src.y}; } }; // reference : thrust: `thrust/detail/complex/csqrtf.h:csqrtf` template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuComplex operator()(cuComplex src) const { float a = src.x, b = src.y; float t = 0.0f; int scale = 1; cuComplex result; /* We risk spurious overflow for components >= FLT_MAX / (1 + sqrt(2)). */ const float THRESH = 1.40949553037932e+38f; /* Handle special cases. */ if (src.x == 0.0f && src.y == float()) { // return (complex(0, b)); return (cuComplex{0.0f, b}); } // FLT_MIN*2 const float low_thresh = 2.35098870164458e-38f; scale = 0; if (fabsf(a) >= THRESH || fabsf(b) >= THRESH) { /* Scale to avoid overflow. */ a *= 0.25f; b *= 0.25f; scale = 1; } else if (fabsf(a) <= low_thresh && fabsf(b) <= low_thresh) { /* Scale to avoid underflow. */ a *= 4.f; b *= 4.f; scale = 2; } /* Algorithm 312, CACM vol 10, Oct 1967. */ if (a >= 0.0f) { t = sqrtf((a + hypotf(a, b)) * 0.5f); // result = complex(t, b / (2.0f * t)); result.x = t; result.y = b / (2.0f * t); } else { t = sqrtf((-a + hypotf(a, b)) * 0.5f); // result = complex(fabsf(b) / (2.0f * t), copysignf(t, b)); result.x = fabsf(b) / (2.0f * t); result.y = copysignf(t, b); } /* Rescale. 
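If the inputs were scaled down by 0.25 to avoid overflow (scale == 1), the square root of the
scaled value is half the true result, so it is multiplied back by 2; if they were scaled up
by 4 to avoid underflow (scale == 2), the result is twice as large, so it is multiplied by 0.5.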
*/ if (scale == 1) { // return (result * 2.0f); result.x *= 2.0f; result.y *= 2.0f; } else if (scale == 2) { // return (result * 0.5f); result.x *= 0.5f; result.y *= 0.5f; } return (result); } }; /*********double complex dtype support*********/ template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(Src src) const { return make_cuDoubleComplex(static_cast(src), 0.0); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src) const { return src; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(cuComplex src) const { return cuComplexFloatToDouble(src); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(half src) const { return make_cuDoubleComplex(static_cast(__half2float(src)), 0.0); } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src) const { return src; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src) const { return cuDoubleComplex{src.x, -src.y}; } }; template<> struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src) const { double a = src.x, b = src.y; double t = 0.0; int scale = 1; cuDoubleComplex result; /* We risk spurious overflow for components >= DBL_MAX / (1 + sqrt(2)). */ const float THRESH = 7.446288774449766337959726e+307; /* Handle special cases. */ if (src.x == 0.0 && src.y == double()) { // return (complex(0, b)); return (cuDoubleComplex{0.0, b}); } // DBL_MIN*2 const double low_thresh = 4.450147717014402766180465e-308; scale = 0; if (fabs(a) >= THRESH || fabs(b) >= THRESH) { /* Scale to avoid overflow. */ a *= 0.25; b *= 0.25; scale = 1; } else if (fabs(a) <= low_thresh && fabs(b) <= low_thresh) { /* Scale to avoid underflow. */ a *= 4.0; b *= 4.0; scale = 2; } /* Algorithm 312, CACM vol 10, Oct 1967. */ if (a >= 0.0) { t = sqrt((a + hypot(a, b)) * 0.5); // result = complex(t, b / (2.0f * t)); result.x = t; result.y = b / (2 * t); } else { t = sqrt((-a + hypot(a, b)) * 0.5); // result = complex(fabsf(b) / (2.0f * t), copysignf(t, b)); result.x = fabs(b) / (2 * t); result.y = copysignf(t, b); } /* Rescale. 
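Same rescaling as the float path above: inputs scaled by 0.25 make the square root half as
large (multiply by 2), inputs scaled by 4 make it twice as large (multiply by 0.5).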
*/ if (scale == 1) { // return (result * 2.0f); result.x *= 2.0; result.y *= 2.0; } else if (scale == 2) { // return (result * 0.5f); result.x *= 0.5; result.y *= 0.5; } return (result); } }; #define SPECIALIZATION_COMPLEX_ARITHMETIC_UNARY_FUNCTOR(op, complex_type, real_type) \ template<> \ struct UnaryFunctor { \ OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : real_functor(attr0, attr1) {} \ UnaryFunctor real_functor; \ OF_DEVICE_FUNC complex_type operator()(complex_type src) const { \ return complex_type{real_functor(src.x), real_functor(src.y)}; \ } \ }; SPECIALIZATION_COMPLEX_ARITHMETIC_UNARY_FUNCTOR(UnaryOp::kNegative, cuComplex, float); SPECIALIZATION_COMPLEX_ARITHMETIC_UNARY_FUNCTOR(UnaryOp::kNegative, cuDoubleComplex, double); } // namespace primitive } // namespace ep } // namespace oneflow #endif ================================================ FILE: oneflow/core/ep/cuda/primitive/where.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/where.h" #include "oneflow/core/ep/common/primitive/where.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/elementwise.cuh" namespace oneflow { namespace ep { namespace primitive { namespace { using cuda::elementwise::GetNumBlocks; using cuda::elementwise::kBlockSize; template __global__ void BroadcastElementwiseWhereCudaKernel( BroadcastElementwiseWhereParams params) { constexpr size_t _pack_size = (x_pack_size > y_pack_size) ? x_pack_size : y_pack_size; constexpr size_t pack_size = (cond_pack_size > _pack_size) ? 
cond_pack_size : _pack_size; static_assert(cond_pack_size == pack_size || cond_pack_size == 1, ""); static_assert(x_pack_size == pack_size || x_pack_size == 1, ""); static_assert(y_pack_size == pack_size || y_pack_size == 1, ""); constexpr bool cond_pack_one = !(cond_pack_size == pack_size); constexpr bool x_pack_one = !(x_pack_size == pack_size); constexpr bool y_pack_one = !(y_pack_size == pack_size); const auto* cond_pack_ptr = reinterpret_cast*>(params.cond); const auto* x_pack_ptr = reinterpret_cast*>(params.x); const auto* y_pack_ptr = reinterpret_cast*>(params.y); auto* z_pack_ptr = reinterpret_cast*>(params.z); IndexT cond_index[ndim]; IndexT x_index[ndim]; IndexT y_index[ndim]; IndexT z_index[ndim]; WhereFunctor where_fn{}; CUDA_1D_KERNEL_LOOP_T(IndexT, offset, params.elem_cnt) { params.z_index_helper.OffsetToNdIndex(offset, z_index); #pragma unroll for (size_t i = 0; i < ndim; ++i) { cond_index[i] = params.cond_index_mask[i] * z_index[i]; x_index[i] = params.x_index_mask[i] * z_index[i]; y_index[i] = params.y_index_mask[i] * z_index[i]; } const IndexT cond_offset = params.cond_index_helper.NdIndexToOffset(cond_index); const IndexT x_offset = params.x_index_helper.NdIndexToOffset(x_index); const IndexT y_offset = params.y_index_helper.NdIndexToOffset(y_index); Packed cond_pack = cond_pack_ptr[cond_offset]; Packed x_pack = x_pack_ptr[x_offset]; Packed y_pack = y_pack_ptr[y_offset]; Packed z_pack; #pragma unroll for (size_t j = 0; j < pack_size; ++j) { const CondT cond_val = cond_pack_one ? cond_pack.elem[0] : cond_pack.elem[j]; const T x_val = x_pack_one ? x_pack.elem[0] : x_pack.elem[j]; const T y_val = y_pack_one ? y_pack.elem[0] : y_pack.elem[j]; z_pack.elem[j] = where_fn(cond_val, x_val, y_val); } z_pack_ptr[offset] = z_pack; } } template cudaError_t LaunchCudaKernel(cudaStream_t stream, const int64_t* cond_dims, const int64_t* x_dims, const int64_t* y_dims, const int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z) { BroadcastElementwiseWhereParams params; params.cond_index_helper = NdIndexOffsetHelper(cond_dims); params.x_index_helper = NdIndexOffsetHelper(x_dims); params.y_index_helper = NdIndexOffsetHelper(y_dims); params.z_index_helper = NdIndexOffsetHelper(z_dims); for (size_t i = 0; i < ndim; ++i) { params.cond_index_mask[i] = (cond_dims[i] == 1) ? 0 : 1; params.x_index_mask[i] = (x_dims[i] == 1) ? 0 : 1; params.y_index_mask[i] = (y_dims[i] == 1) ? 
0 : 1; } params.elem_cnt = static_cast(GetElementCount(ndim, z_dims)); params.cond = cond; params.x = x; params.y = y; params.z = z; int num_blocks; { cudaError_t err = GetNumBlocks(params.elem_cnt, &num_blocks); if (err != cudaSuccess) { return err; } } BroadcastElementwiseWhereCudaKernel<<>>(params); return cudaPeekAtLastError(); } template void LaunchKernel(Stream* stream, const int64_t* cond_dims, const int64_t* x_dims, const int64_t* y_dims, const int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z) { static_assert(ndim > 0, ""); auto cuda_stream = stream->As()->cuda_stream(); OF_CUDA_CHECK((LaunchCudaKernel( cuda_stream, cond_dims, x_dims, y_dims, z_dims, cond, x, y, z))); } template void LaunchScalarKernel(Stream* stream, const CondT* cond, const T* x, const T* y, T* z) { // should dispatch to elemwise tenary UNIMPLEMENTED(); } template void LaunchElemwiseTenary(CudaStream* stream, int64_t elem_cnt, const CondT* cond, const T* x, const T* y, T* z) { cudaStream_t cuda_stream = stream->cuda_stream(); WhereElemwiseFunctor where_fn{}; OF_CUDA_CHECK((cuda::elementwise::Ternary( where_fn, elem_cnt, z, cond, x, y, cuda_stream))); } template class WhereCudaImpl : public Where { public: OF_DISALLOW_COPY_AND_MOVE(WhereCudaImpl); explicit WhereCudaImpl() = default; ~WhereCudaImpl() override = default; void Launch(Stream* stream, size_t num_cond_dims, const int64_t* cond_dims, const void* cond, size_t num_x_dims, const int64_t* x_dims, const void* x, size_t num_y_dims, const int64_t* y_dims, const void* y, void* z) override { size_t compact_num_dims = 0; int64_t compact_cond_dims[kMaxNumDims] = {}; int64_t compact_x_dims[kMaxNumDims] = {}; int64_t compact_y_dims[kMaxNumDims] = {}; int64_t compact_z_dims[kMaxNumDims] = {}; GetCompactBroadcastDims(num_cond_dims, cond_dims, num_x_dims, x_dims, num_y_dims, y_dims, &compact_num_dims, compact_cond_dims, compact_x_dims, compact_y_dims, compact_z_dims); if (IsDimsEquals(compact_num_dims, compact_z_dims, compact_cond_dims) && IsDimsEquals(compact_num_dims, compact_z_dims, compact_x_dims) && IsDimsEquals(compact_num_dims, compact_z_dims, compact_y_dims)) { // elementwise const size_t elem_cnt = GetElementCount(compact_num_dims, compact_z_dims); LaunchElemwiseTenary(stream->As(), elem_cnt, static_cast(cond), static_cast(x), static_cast(y), static_cast(z)); } else { // broadcast LaunchByDispatchNDim(stream, compact_num_dims, compact_cond_dims, compact_x_dims, compact_y_dims, compact_z_dims, static_cast(cond), static_cast(x), static_cast(y), static_cast(z)); } } }; class WhereFactoryCudaImpl : public WhereFactory { public: OF_DISALLOW_COPY_AND_MOVE(WhereFactoryCudaImpl); WhereFactoryCudaImpl() = default; ~WhereFactoryCudaImpl() override = default; std::unique_ptr New(DataType cond_type, DataType data_type, size_t max_num_dims) override { return NewWhere(cond_type, data_type, max_num_dims); } }; REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, WhereFactory, WhereFactoryCudaImpl); } // namespace } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/include/active_device_guard.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_ACTIVE_DEVICE_GUARD_H_ #define ONEFLOW_CORE_EP_ACTIVE_DEVICE_GUARD_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/ep/include/device.h" namespace oneflow { namespace ep { class DeviceManager; class ActiveDeviceGuard { public: OF_DISALLOW_COPY_AND_MOVE(ActiveDeviceGuard); explicit ActiveDeviceGuard(Device* device); ~ActiveDeviceGuard(); private: size_t saved_active_device_; DeviceManager* device_manager_; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_ACTIVE_DEVICE_GUARD_H_ ================================================ FILE: oneflow/core/ep/include/allocation_options.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_ALLOCATION_ATTRIBUTE_H_ #define ONEFLOW_CORE_EP_ALLOCATION_ATTRIBUTE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/device_type.h" namespace oneflow { namespace ep { class AllocationOptions { public: AllocationOptions() : pinned_device_type_(DeviceType::kInvalidDevice), pinned_device_index_{}, numa_node_affinity_(-1) {} ~AllocationOptions() = default; bool HasPinnedDevice() const { return pinned_device_type_ != DeviceType::kInvalidDevice; } DeviceType GetPinnedDeviceType() const { CHECK(HasPinnedDevice()); return pinned_device_type_; } size_t GetPinnedDeviceIndex() const { CHECK(HasPinnedDevice()); return pinned_device_index_; } void SetPinnedDevice(DeviceType device_type, size_t device_index) { CHECK(!HasPinnedDevice()); CHECK_NE(device_type, DeviceType::kInvalidDevice); pinned_device_type_ = device_type; pinned_device_index_ = device_index; } void ClearPinnedDevice() { pinned_device_type_ = DeviceType::kInvalidDevice; } bool HasNumaNodeAffinity() const { return numa_node_affinity_ >= 0; } size_t GetNumaNodeAffinity() const { CHECK(HasNumaNodeAffinity()); return numa_node_affinity_; } void SetNumaNodeAffinity(size_t numa_node) { numa_node_affinity_ = numa_node; } void ClearNumaNodeAffinity() { numa_node_affinity_ = -1; } private: DeviceType pinned_device_type_; size_t pinned_device_index_; int32_t numa_node_affinity_; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_ALLOCATION_ATTRIBUTE_H_ ================================================ FILE: oneflow/core/ep/include/device.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_DEVICE_H_ #define ONEFLOW_CORE_EP_DEVICE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/ep/include/event.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ep/include/allocation_options.h" namespace oneflow { namespace ep { constexpr size_t kMaxAlignmentRequirement = 512; class DeviceManager; class Device { public: OF_DISALLOW_COPY_AND_MOVE(Device); Device() = default; virtual ~Device() = default; virtual void SetAsActiveDevice() = 0; virtual void Reset() = 0; virtual DeviceType device_type() const = 0; virtual size_t device_index() const = 0; virtual DeviceManager* device_manager() const = 0; virtual Stream* CreateStream() = 0; virtual void DestroyStream(Stream* stream) = 0; virtual Event* CreateEvent(); virtual void DestroyEvent(Event* event); virtual void CreateEvents(Event** events, size_t count) = 0; virtual void DestroyEvents(Event** events, size_t count) = 0; virtual Maybe Alloc(const AllocationOptions& options, void** ptr, size_t size) = 0; virtual void Free(const AllocationOptions& options, void* ptr) = 0; virtual Maybe AllocPinned(const AllocationOptions& options, void** ptr, size_t size) = 0; virtual void FreePinned(const AllocationOptions& options, void* ptr) = 0; virtual bool IsStreamOrderedMemoryAllocationSupported() const; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_DEVICE_H_ ================================================ FILE: oneflow/core/ep/include/device_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_EP_DEVICE_MANAGER_H_
#define ONEFLOW_CORE_EP_DEVICE_MANAGER_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/ep/include/device.h"
#include "oneflow/core/ep/include/random_generator.h"
#include "oneflow/core/common/auto_registration_factory.h"
#include "oneflow/core/common/device_type.h"

namespace oneflow {

namespace ep {

class DeviceManagerRegistry;

class DeviceManager {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DeviceManager);
  DeviceManager() = default;
  virtual ~DeviceManager() = default;

  virtual DeviceManagerRegistry* registry() const = 0;
  virtual std::shared_ptr<Device> GetDevice(size_t device_index) = 0;
  virtual size_t GetDeviceCount(size_t primary_device_index) = 0;
  virtual size_t GetDeviceCount() = 0;
  virtual size_t GetActiveDeviceIndex() = 0;
  virtual void SetActiveDeviceByIndex(size_t device_index) = 0;
  virtual bool IsStreamWaitEventSupported() const { return false; }

  virtual std::shared_ptr<RandomGenerator> CreateRandomGenerator(uint64_t seed,
                                                                 size_t device_index) = 0;
};

}  // namespace ep

}  // namespace oneflow

#endif  // ONEFLOW_CORE_EP_DEVICE_MANAGER_H_


================================================
FILE: oneflow/core/ep/include/device_manager_factory.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
*/
#ifndef ONEFLOW_CORE_EP_DEVICE_MANAGER_FACTORY_H_
#define ONEFLOW_CORE_EP_DEVICE_MANAGER_FACTORY_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/ep/include/device_manager.h"
#include "oneflow/core/common/device_type.h"

namespace oneflow {

namespace ep {

class DeviceManagerRegistry;

class DeviceManagerFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DeviceManagerFactory);
  DeviceManagerFactory() = default;
  virtual ~DeviceManagerFactory() = default;

  virtual std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) = 0;
  virtual DeviceType device_type() const = 0;
  virtual std::string device_type_name() const = 0;
  virtual void DumpVersionInfo() const {}
};

}  // namespace ep

}  // namespace oneflow

#endif  // ONEFLOW_CORE_EP_DEVICE_MANAGER_FACTORY_H_


================================================
FILE: oneflow/core/ep/include/device_manager_registry.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
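As a rough sketch of how the manager interface is consumed (assumed usage, not taken from the repository): device handles and per-device random generators are both obtained through the DeviceManager, which itself is owned by the registry declared next.

#include "oneflow/core/ep/include/device_manager.h"

void DeviceManagerUsageSketch(oneflow::ep::DeviceManager* manager) {
  if (manager->GetDeviceCount() == 0) { return; }
  // Device handles are shared_ptr<Device>, obtained per index.
  std::shared_ptr<oneflow::ep::Device> device = manager->GetDevice(/*device_index=*/0);
  // Random generators are created per device with an explicit seed.
  std::shared_ptr<oneflow::ep::RandomGenerator> generator =
      manager->CreateRandomGenerator(/*seed=*/42, /*device_index=*/0);
  // Temporarily switch the active device; ActiveDeviceGuard (see
  // active_device_guard.h earlier) automates this save/restore pattern.
  size_t saved = manager->GetActiveDeviceIndex();
  manager->SetActiveDeviceByIndex(0);
  manager->SetActiveDeviceByIndex(saved);
  (void)device;
  (void)generator;
}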
*/ #ifndef ONEFLOW_CORE_EP_DEVICE_MANAGER_REGISTRY_H_ #define ONEFLOW_CORE_EP_DEVICE_MANAGER_REGISTRY_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/ep/include/device_manager.h" #include "oneflow/core/ep/include/device_manager_factory.h" namespace oneflow { namespace ep { class DeviceManagerRegistry { public: OF_DISALLOW_COPY_AND_MOVE(DeviceManagerRegistry); DeviceManagerRegistry(); ~DeviceManagerRegistry(); DeviceManager* GetDeviceManager(DeviceType device_type); DeviceManager* GetDeviceManagerOrNull(DeviceType device_type); std::shared_ptr GetDevice(DeviceType device_type, size_t device_index); size_t GetDeviceCount(DeviceType device_type); size_t GetDeviceCount(const std::string& device_type_name); static void RegisterDeviceManagerFactory(std::unique_ptr&& factory); static void DumpVersionInfo(); static std::string GetDeviceTypeNameByDeviceType(DeviceType device_type); static DeviceType GetDeviceTypeByDeviceTypeName(const std::string& device_type_name); static std::set GetRegisteredDeviceTypes(); static bool IsDeviceTypeRegistered(DeviceType device_type); private: class Impl; std::unique_ptr impl_; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_DEVICE_MANAGER_REGISTRY_H_ ================================================ FILE: oneflow/core/ep/include/event.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_EVENT_H_ #define ONEFLOW_CORE_EP_EVENT_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" namespace oneflow { namespace ep { class Event { public: OF_DISALLOW_COPY_AND_MOVE(Event); Event() = default; virtual ~Event() = default; virtual Maybe QueryDone() = 0; virtual Maybe Sync() = 0; }; } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_EVENT_H_ ================================================ FILE: oneflow/core/ep/include/primitive/add.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_ADD_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_ADD_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class Add : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(Add); Add() = default; ~Add() override = default; virtual void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst, size_t count) = 0; virtual void Launch(Stream* stream, const void* src0, const void* src1, void* dst, size_t count); }; class AddFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(AddFactory); AddFactory() = default; ~AddFactory() override = default; virtual std::unique_ptr New(DataType data_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_ADD_H_ ================================================ FILE: oneflow/core/ep/include/primitive/batch_matmul.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_BATCH_MATMUL_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_BATCH_MATMUL_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/blas.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { class BatchMatmul : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(BatchMatmul); BatchMatmul() = default; ~BatchMatmul() override = default; virtual void Launch(Stream* stream, size_t batch_size, size_t m, size_t n, size_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c) = 0; }; class BatchMatmulFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(BatchMatmulFactory); BatchMatmulFactory() = default; ~BatchMatmulFactory() override = default; virtual std::unique_ptr New(DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_BATCH_MATMUL_H_ ================================================ FILE: oneflow/core/ep/include/primitive/binary_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
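A sketch of how these factories are typically reached through the NewPrimitive<...>() helper declared in primitive.h later in this directory; DataType::kFloat and the buffer arguments are assumptions for illustration only.

#include "oneflow/core/ep/include/primitive/add.h"

void AddSketch(oneflow::DeviceType device_type, oneflow::ep::Stream* stream, const void* src0,
               const void* src1, void* dst, size_t count) {
  using namespace oneflow::ep::primitive;
  std::unique_ptr<Add> add = NewPrimitive<AddFactory>(device_type, oneflow::DataType::kFloat);
  if (!add) { return; }  // No AddFactory registered for this device type.
  // The two-source overload forwards to the variadic (srcs, arity) Launch.
  add->Launch(stream, src0, src1, dst, count);
}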
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_BINARY_OP_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_BINARY_OP_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { enum class BinaryOp { // Math kAdd, kSub, kMul, kDiv, kMax, kMin, kPow, kFmod, kFloorDiv, kTruncDiv, kFloorMod, kScalarBasePowerGrad, kScalarExpPowerGrad, kZeta, // Comparision kEqual, kNotEqual, kLessThan, kLessEqual, kGreaterThan, kGreaterEqual, kIsClose, kIsCloseEqualNan, // Logical kLogicalAnd, kLogicalOr, kLogicalXor, // Bitwise kBitwiseAnd, kBitwiseOr, kBitwiseXor, // Unary Backward kIdentityBackwardWithDyX, kEluBackwardWithDyX, kCeluBackwardWithDyY, kGeluBackwardWithDyX, kHardswishBackwardWithDyX, kHardsigmoidBackwardWithDyX, kHardshrinkBackwardWithDyY, kHardtanhBackwardWithDyY, kLeakyReluBackwardWithDyX, kMishBackwardWithDyX, kReluBackwardWithDyY, kReluBackwardWithDyX, kSeluBackwardWithDyX, kSiluBackwardWithDyX, kSoftsignBackwardWithDyX, kSoftplusBackwardWithDyX, kSoftshrinkBackwardWithDyY, kTanhBackwardWithDyY, kThresholdBackwardWithDyX, kSigmoidBackwardWithDyY, kSigmoidBackwardWithDyX, kAbsBackwardWithDyX, kAcosBackwardWithDyX, kAcoshBackwardWithDyX, kAsinBackwardWithDyX, kAsinhBackwardWithDyX, kAtanBackwardWithDyX, kAtanhBackwardWithDyX, kCosBackwardWithDyX, kCoshBackwardWithDyX, kErfBackwardWithDyX, kErfcBackwardWithDyX, kExpBackwardWithDyX, kExp2BackwardWithDyX, kExpm1BackwardWithDyX, kLgammaBackwardWithDyX, kDigammaBackwardWithDyX, kLogBackwardWithDyX, kLog2BackwardWithDyX, kLog10BackwardWithDyX, kLog1pBackwardWithDyX, kLogSigmoidBackwardWithDyX, kReciprocalBackwardWithDyX, kReciprocalNoNanBackwardWithDyX, kRsqrtBackwardWithDyX, kSinBackwardWithDyX, kSinhBackwardWithDyX, kSqrtBackwardWithDyX, kSquareBackwardWithDyX, kTanBackwardWithDyX, kFastGeluBackwardWithDyX, kQuickGeluBackwardWithDyX, kSquareReLUBackwardWithDyX, }; } } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_BINARY_OP_H_ ================================================ FILE: oneflow/core/ep/include/primitive/blas.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_BLAS_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_BLAS_H_ namespace oneflow { namespace ep { namespace primitive { enum class BlasTransposeType { N = 0, T, }; } } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_BLAS_H_ ================================================ FILE: oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_BINARY_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_BINARY_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/binary_op.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { class BroadcastElementwiseBinary : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinary); BroadcastElementwiseBinary() = default; ~BroadcastElementwiseBinary() override = default; virtual void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, size_t num_src1_dims, const int64_t* src1_dims, const void* src1, void* dst) = 0; virtual void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims, const void* src1, void* dst) = 0; virtual void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, Scalar src1, void* dst) = 0; }; class BroadcastElementwiseBinaryFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactory); BroadcastElementwiseBinaryFactory() = default; ~BroadcastElementwiseBinaryFactory() override = default; virtual std::unique_ptr New(BinaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims) = 0; virtual std::unique_ptr New(BinaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0) = 0; virtual std::unique_ptr New(BinaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0, Scalar attr1) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_BINARY_H_ ================================================ FILE: oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
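For example, a broadcast add of a [4, 1, 8] tensor with a [1, 6, 8] tensor could be set up as below; this is a sketch assuming float data, using only the signatures shown above.

#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"

void BroadcastAddSketch(oneflow::DeviceType device_type, oneflow::ep::Stream* stream,
                        const void* src0, const void* src1, void* dst) {
  using namespace oneflow::ep::primitive;
  const int64_t src0_dims[3] = {4, 1, 8};
  const int64_t src1_dims[3] = {1, 6, 8};
  auto binary = NewPrimitive<BroadcastElementwiseBinaryFactory>(
      device_type, BinaryOp::kAdd, oneflow::DataType::kFloat, oneflow::DataType::kFloat,
      /*max_num_dims=*/3);
  if (!binary) { return; }
  // dst must be large enough for the broadcast result shape [4, 6, 8].
  binary->Launch(stream, 3, src0_dims, src0, 3, src1_dims, src1, dst);
}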
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_UNARY_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_UNARY_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/unary_op.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { class BroadcastElementwiseUnary : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnary); BroadcastElementwiseUnary() = default; ~BroadcastElementwiseUnary() override = default; virtual void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const int64_t* src_strides, const void* src, size_t num_dst_dims, const int64_t* dst_dims, const int64_t* dst_strides, void* dst) = 0; virtual void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src, size_t num_dst_dims, const int64_t* dst_dims, void* dst) = 0; }; class BroadcastElementwiseUnaryFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactory); BroadcastElementwiseUnaryFactory() = default; ~BroadcastElementwiseUnaryFactory() override = default; virtual std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims) = 0; virtual std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0) = 0; virtual std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, size_t max_num_dims, Scalar attr0, Scalar attr1) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_UNARY_H_ ================================================ FILE: oneflow/core/ep/include/primitive/broadcast_matmul.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_MATMUL_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_MATMUL_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/blas.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { class BroadcastMatmul : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmul); BroadcastMatmul() = default; ~BroadcastMatmul() override = default; virtual void Launch(Stream* stream, Scalar alpha, size_t num_a_dims, const int64_t* a_dims, const void* a, size_t num_b_dims, const int64_t* b_dims, const void* b, Scalar beta, size_t num_c_dims, const int64_t* c_dims, void* c) = 0; }; class BroadcastMatmulFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactory); BroadcastMatmulFactory() = default; ~BroadcastMatmulFactory() override = default; virtual std::unique_ptr New(DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b, size_t max_num_dims) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_MATMUL_H_ ================================================ FILE: oneflow/core/ep/include/primitive/cast.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_CAST_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_CAST_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class Cast : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(Cast); Cast() = default; ~Cast() override = default; virtual void Launch(Stream* stream, const void* from, void* to, size_t count) = 0; }; class CastFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(CastFactory); CastFactory() = default; ~CastFactory() override = default; virtual std::unique_ptr New(DataType from, DataType to) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_CAST_H_ ================================================ FILE: oneflow/core/ep/include/primitive/constant_pad.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
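A small sketch of the Cast primitive above, converting float to double; the concrete DataType values are assumed for the example.

#include "oneflow/core/ep/include/primitive/cast.h"

void CastSketch(oneflow::DeviceType device_type, oneflow::ep::Stream* stream,
                const void* src_float, void* dst_double, size_t count) {
  using namespace oneflow::ep::primitive;
  auto cast =
      NewPrimitive<CastFactory>(device_type, oneflow::DataType::kFloat, oneflow::DataType::kDouble);
  if (!cast) { return; }
  cast->Launch(stream, src_float, dst_double, count);
}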
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_CONSTANT_PAD_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_CONSTANT_PAD_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { class ConstantPad : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(ConstantPad); ConstantPad() = default; ~ConstantPad() override = default; virtual void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src, const int64_t* padding_before, const int64_t* padding_after, Scalar pad_val, void* dst) = 0; }; class ConstantPadFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactory); ConstantPadFactory() = default; ~ConstantPadFactory() override = default; virtual std::unique_ptr New(DataType data_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif ================================================ FILE: oneflow/core/ep/include/primitive/copy_nd.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_COPY_ND_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_COPY_ND_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class CopyNd : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(CopyNd); CopyNd() = default; ~CopyNd() override = default; virtual void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst, const int64_t* dst_dims, const int64_t* dst_pos, const void* src, const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) const = 0; }; class CopyNdFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(CopyNdFactory); CopyNdFactory() = default; ~CopyNdFactory() override = default; virtual std::unique_ptr New(size_t max_num_dims) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_COPY_ND_H_ ================================================ FILE: oneflow/core/ep/include/primitive/elementwise_unary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_ELEMENTWISE_UNARY_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_ELEMENTWISE_UNARY_H_ #include "oneflow/core/common/scalar.h" #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/unary_op.h" namespace oneflow { namespace ep { namespace primitive { class ElementwiseUnary : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnary); ElementwiseUnary() = default; ~ElementwiseUnary() override = default; virtual void Launch(Stream* stream, const void* src, void* dst, size_t count) = 0; }; class ElementwiseUnaryFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactory); ElementwiseUnaryFactory() = default; ~ElementwiseUnaryFactory() override = default; virtual std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type) = 0; virtual std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, Scalar attr0) = 0; virtual std::unique_ptr New(UnaryOp op, DataType src_type, DataType dst_type, Scalar attr0, Scalar attr1) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_ELEMENTWISE_UNARY_H_ ================================================ FILE: oneflow/core/ep/include/primitive/fast_integer_math.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_FAST_INTEGER_MATH_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_FAST_INTEGER_MATH_H_ #include "oneflow/core/common/data_type.h" #include namespace oneflow { /* Copyright microsoft/onnxruntime https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cuda/shared_inc/fast_divmod.h */ template struct FastIntegerMath { OF_DEVICE_FUNC FastIntegerMath() {} OF_DEVICE_FUNC explicit FastIntegerMath(T operand) { #if defined(__CUDA_ARCH__) int leading_zeroes = __clzll(operand); #else int leading_zeroes = __builtin_clz(operand); #endif bool is_power_2 = ((operand & (operand - 1)) == 0); if (is_power_2) { log2_operand_ = 31 - leading_zeroes; } else { log2_operand_ = -1; // Set as flag. } operand_ = operand == 0 ? 1 : operand; assert(operand_ >= 1 && operand_ <= GetMaxVal()); } OF_DEVICE_FUNC T divides(T n) const { if (log2_operand_ >= 0) { return n >> log2_operand_; } else { return n / operand_; } } OF_DEVICE_FUNC T mod(T n) const { return n - divides(n) * operand_; } OF_DEVICE_FUNC T mul(T n) const { if (log2_operand_ >= 0) { return n << log2_operand_; } else { return n * operand_; } } OF_DEVICE_FUNC T add(T n) const { return n + operand_; } OF_DEVICE_FUNC T sub(T n) const { return n - operand_; } OF_DEVICE_FUNC void divmod(T n, T* q, T* r) const { *q = divides(n); *r = n - *q * operand_; } T operand_; int32_t log2_operand_; }; template<> struct FastIntegerMath { OF_DEVICE_FUNC FastIntegerMath() {} OF_DEVICE_FUNC explicit FastIntegerMath(const int32_t operand) { operand_ = operand == 0 ? 
1 : operand; assert(operand_ >= 1 && operand_ <= GetMaxVal()); for (l_ = 0; l_ < 32; l_++) if ((1U << l_) >= operand_) break; uint64_t one = 1; uint64_t m = ((one << 32) * ((one << l_) - operand_)) / operand_ + 1; M_ = static_cast(m); assert(M_ > 0 && M_ == m); } OF_DEVICE_FUNC int32_t divides(const int32_t n) const { #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) uint32_t t = __umulhi(M_, n); return (t + n) >> l_; #else // Using uint64_t for t, then t + n won't overflow. uint64_t t = ((uint64_t)M_ * n) >> 32; return static_cast((t + n) >> l_); #endif } OF_DEVICE_FUNC int32_t mod(int32_t n) const { return n - divides(n) * operand_; } OF_DEVICE_FUNC int32_t mul(int32_t n) const { return n * operand_; } OF_DEVICE_FUNC int32_t add(int32_t n) const { return n + operand_; } OF_DEVICE_FUNC int32_t sub(int32_t n) const { return n - operand_; } OF_DEVICE_FUNC void divmod(int32_t n, int32_t* q, int32_t* r) const { *q = divides(n); *r = n - *q * operand_; } uint32_t operand_; uint32_t M_; // m' in the paper. uint32_t l_; // l_ = ceil(log2(d_)) }; } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_FAST_INTEGER_MATH_H_ ================================================ FILE: oneflow/core/ep/include/primitive/fill.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_FILL_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_FILL_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { class Fill : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(Fill); Fill() = default; ~Fill() override = default; virtual void Launch(Stream* stream, void* dst, Scalar value, size_t count) = 0; }; class FillFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(FillFactory); FillFactory() = default; ~FillFactory() override = default; virtual std::unique_ptr New(DataType data_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_FILL_H_ ================================================ FILE: oneflow/core/ep/include/primitive/log_softmax.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
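The int32_t specialization of FastIntegerMath above uses the classic magic-number division (as in the referenced onnxruntime fast_divmod): l = ceil(log2(d)), M = floor(2^32 * (2^l - d) / d) + 1, and n / d == (((M * n) >> 32) + n) >> l. A standalone sanity check with d = 6 (so l = 3 and M = 1431655766), written against the same formulas and offered only as an illustration:

#include <cassert>
#include <cstdint>

void FastIntegerMathSanityCheck() {
  const uint32_t d = 6;
  const uint32_t l = 3;  // smallest l with (1 << l) >= d
  const uint64_t one = 1;
  const uint32_t M = static_cast<uint32_t>(((one << 32) * ((one << l) - d)) / d + 1);
  assert(M == 1431655766u);
  for (int32_t n = 0; n < 1000; ++n) {
    const uint64_t t = (static_cast<uint64_t>(M) * static_cast<uint64_t>(n)) >> 32;
    const int32_t q = static_cast<int32_t>((t + static_cast<uint64_t>(n)) >> l);
    assert(q == n / static_cast<int32_t>(d));  // matches FastIntegerMath<int32_t>::divides(n)
  }
}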
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class LogSoftmax : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(LogSoftmax); LogSoftmax() = default; ~LogSoftmax() override = default; virtual void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) = 0; }; class LogSoftmaxFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(LogSoftmaxFactory); LogSoftmaxFactory() = default; ~LogSoftmaxFactory() override = default; virtual std::unique_ptr New(DataType data_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_H_ ================================================ FILE: oneflow/core/ep/include/primitive/log_softmax_backward.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_BACKWARD_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_BACKWARD_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class LogSoftmaxBackward : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(LogSoftmaxBackward); LogSoftmaxBackward() = default; ~LogSoftmaxBackward() override = default; virtual void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy, void* dx) = 0; }; class LogSoftmaxBackwardFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(LogSoftmaxBackwardFactory); LogSoftmaxBackwardFactory() = default; ~LogSoftmaxBackwardFactory() override = default; virtual std::unique_ptr New(DataType data_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_BACKWARD_H_ ================================================ FILE: oneflow/core/ep/include/primitive/matmul.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_MATMUL_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_MATMUL_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/blas.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { class Matmul : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(Matmul); Matmul() = default; ~Matmul() override = default; virtual void Launch(Stream* stream, size_t m, size_t n, size_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c) = 0; }; class MatmulFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(MatmulFactory); MatmulFactory() = default; ~MatmulFactory() override = default; virtual std::unique_ptr New(DataType data_type, BlasTransposeType transpose_a, BlasTransposeType transpose_b) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_MATMUL_H_ ================================================ FILE: oneflow/core/ep/include/primitive/memcpy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_MEMCPY_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_MEMCPY_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { enum class MemcpyKind { kAuto = 0, kHtoD, kDtoH, kDtoD, }; class Memcpy : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(Memcpy); Memcpy() = default; ~Memcpy() override = default; virtual void Launch(Stream* stream, void* dst, const void* src, size_t count) = 0; }; class MemcpyFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(MemcpyFactory); MemcpyFactory() = default; ~MemcpyFactory() override = default; virtual std::unique_ptr New(MemcpyKind kind) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_MEMCPY_H_ ================================================ FILE: oneflow/core/ep/include/primitive/memset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
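A combined sketch of the two primitives above: copy the A operand host-to-device, then compute C = alpha * A * B + beta * C. Float data and non-transposed operands are assumptions made for the example.

#include "oneflow/core/ep/include/primitive/matmul.h"
#include "oneflow/core/ep/include/primitive/memcpy.h"

void MatmulSketch(oneflow::DeviceType device_type, oneflow::ep::Stream* stream, const void* host_a,
                  void* device_a, size_t a_bytes, const void* device_b, void* device_c, size_t m,
                  size_t n, size_t k) {
  using namespace oneflow::ep::primitive;
  auto h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);
  auto matmul = NewPrimitive<MatmulFactory>(device_type, oneflow::DataType::kFloat,
                                            BlasTransposeType::N, BlasTransposeType::N);
  if (!h2d || !matmul) { return; }
  h2d->Launch(stream, device_a, host_a, a_bytes);
  matmul->Launch(stream, m, n, k, /*alpha=*/1.0, device_a, device_b, /*beta=*/0.0, device_c);
}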
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_MEMSET_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_MEMSET_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class Memset : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(Memset); Memset() = default; ~Memset() override = default; virtual void Launch(Stream* stream, void* ptr, int value, size_t count) = 0; }; class MemsetFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(MemsetFactory); MemsetFactory() = default; ~MemsetFactory() override = default; virtual std::unique_ptr New() = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_MEMSET_H_ ================================================ FILE: oneflow/core/ep/include/primitive/one_hot.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_ONE_HOT_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_ONE_HOT_H_ #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/common/scalar.h" namespace oneflow { namespace ep { namespace primitive { class OneHot : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(OneHot); OneHot() = default; ~OneHot() override = default; virtual void Launch(Stream* stream, const void* indices, void* out, Scalar on_value, Scalar off_value, size_t num_indices, size_t lower_bound, size_t upper_bound) = 0; }; class OneHotFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(OneHotFactory); OneHotFactory() = default; ~OneHotFactory() override = default; virtual std::unique_ptr New(DataType indices_type, DataType out_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_ONE_HOT_H_ ================================================ FILE: oneflow/core/ep/include/primitive/permute.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_EP_PRIMITIVE_PERMUTE_H_
#define ONEFLOW_CORE_EP_PRIMITIVE_PERMUTE_H_

#include "oneflow/core/ep/include/primitive/primitive.h"

namespace oneflow {

namespace ep {
namespace primitive {

class Permute : public Primitive {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Permute);
  Permute() = default;
  ~Permute() override = default;

  virtual void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,
                      const void* src, const int* permutation, void* dst) = 0;
};

class PermuteFactory : public Factory<Permute> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PermuteFactory);
  PermuteFactory() = default;
  ~PermuteFactory() override = default;

  virtual std::unique_ptr<Permute> New(size_t max_num_dims) = 0;
};

}  // namespace primitive
}  // namespace ep

}  // namespace oneflow

#endif  // ONEFLOW_CORE_EP_PRIMITIVE_PERMUTE_H_


================================================
FILE: oneflow/core/ep/include/primitive/primitive.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
*/
#ifndef ONEFLOW_CORE_EP_PRIMITIVE_PRIMITIVE_H_
#define ONEFLOW_CORE_EP_PRIMITIVE_PRIMITIVE_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/common/auto_registration_factory.h"
#include "oneflow/core/common/device_type.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/ep/include/stream.h"

namespace oneflow {

namespace ep {
namespace primitive {

class Primitive {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Primitive);
  Primitive() = default;
  virtual ~Primitive() = default;
};

template<typename PrimitiveT>
class Factory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Factory);
  Factory() = default;
  virtual ~Factory() = default;

  using PrimitiveType = PrimitiveT;
};

template<typename FactoryType, typename... Args>
static std::unique_ptr<typename FactoryType::PrimitiveType> NewPrimitive(DeviceType device_type,
                                                                          Args&&... args) {
  if (!IsClassRegistered<DeviceType, FactoryType>(device_type)) { return nullptr; }
  std::unique_ptr<FactoryType> factory = NewObjUniquePtr<DeviceType, FactoryType>(device_type);
  if (!factory) { return nullptr; }
  return factory->New(std::forward<Args>(args)...);
}

#define REGISTER_PRIMITIVE_FACTORY(device, Base, Derived) \
  REGISTER_CLASS(DeviceType, device, Base, Derived)

}  // namespace primitive
}  // namespace ep

}  // namespace oneflow

#endif  // ONEFLOW_CORE_EP_PRIMITIVE_PRIMITIVE_H_


================================================
FILE: oneflow/core/ep/include/primitive/softmax.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
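To make NewPrimitive<FillFactory>(...) succeed for a given device type, a backend registers a concrete factory with the macro above. The following is only a schematic registration: FillFactoryImpl is a hypothetical name and its New() body is a placeholder.

#include "oneflow/core/ep/include/primitive/fill.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace {

class FillFactoryImpl : public FillFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl);
  FillFactoryImpl() = default;
  ~FillFactoryImpl() override = default;

  std::unique_ptr<Fill> New(DataType data_type) override {
    // A real backend would return its device-specific Fill implementation here.
    return nullptr;
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, FillFactory, FillFactoryImpl);

}  // namespace
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow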
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class Softmax : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(Softmax); Softmax() = default; ~Softmax() override = default; virtual void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) = 0; }; class SoftmaxFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(SoftmaxFactory); SoftmaxFactory() = default; ~SoftmaxFactory() override = default; virtual std::unique_ptr New(DataType data_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_H_ ================================================ FILE: oneflow/core/ep/include/primitive/softmax_backward.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_BACKWARD_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_BACKWARD_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class SoftmaxBackward : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackward); SoftmaxBackward() = default; ~SoftmaxBackward() override = default; virtual void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy, void* dx) = 0; }; class SoftmaxBackwardFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardFactory); SoftmaxBackwardFactory() = default; ~SoftmaxBackwardFactory() override = default; virtual std::unique_ptr New(DataType data_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_BACKWARD_H_ ================================================ FILE: oneflow/core/ep/include/primitive/tensor_fill.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_TENSOR_FILL_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_TENSOR_FILL_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class TensorFill : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(TensorFill); TensorFill() = default; ~TensorFill() override = default; virtual void Launch(Stream* stream, const void* src, void* dst, size_t count) = 0; }; class TensorFillFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(TensorFillFactory); TensorFillFactory() = default; ~TensorFillFactory() override = default; virtual std::unique_ptr New(DataType data_type) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_TENSOR_FILL_H_ ================================================ FILE: oneflow/core/ep/include/primitive/unary_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_UNARY_OP_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_UNARY_OP_H_ namespace oneflow { namespace ep { namespace primitive { enum class UnaryOp { kIdentity, // activation op kElu, kCelu, kRelu, kGelu, kHardSwish, kHardSigmoid, kHardShrink, kHardTanh, kLeakyRelu, kMish, kSelu, kSilu, kSoftShrink, kSoftSign, kSoftPlus, kTanh, kThreshold, kFastGelu, kQuickGelu, kSquareReLU, // math op kAbs, kAcos, kAcosh, kAsin, kAsinh, kAtan, kAtanh, kCeil, kCos, kCosh, kDigamma, kTrigamma, kErf, kErfc, kExp, kExp2, kExpm1, kFloor, kLgamma, kLog, kLog2, kLog10, kLog1p, kLogSigmoid, kNegative, kReciprocal, kReciprocalNoNan, kRint, kRound, kRsqrt, kSigmoid, kSign, kSin, kSinh, kSqrt, kSquare, kTan, kTrunc, kNotEqualZero, // logical op kLogicalNot, // cast op kCast, // utils op kIsInf, kIsNan, kIsFinite, kNanAssign, // bitwise op kBitwiseNot, // complex op kConj, kReal, kImag, kRealGrad, kImagGrad }; } } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_UNARY_OP_H_ ================================================ FILE: oneflow/core/ep/include/primitive/where.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
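The UnaryOp values above plug into ElementwiseUnary (declared earlier in elementwise_unary.h). A sketch of an elementwise ReLU, with float data assumed for illustration:

#include "oneflow/core/ep/include/primitive/elementwise_unary.h"
#include "oneflow/core/ep/include/primitive/unary_op.h"

void ReluSketch(oneflow::DeviceType device_type, oneflow::ep::Stream* stream, const void* src,
                void* dst, size_t count) {
  using namespace oneflow::ep::primitive;
  auto relu = NewPrimitive<ElementwiseUnaryFactory>(device_type, UnaryOp::kRelu,
                                                    oneflow::DataType::kFloat,
                                                    oneflow::DataType::kFloat);
  if (!relu) { return; }
  relu->Launch(stream, src, dst, count);
}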
*/ #ifndef ONEFLOW_CORE_EP_PRIMITIVE_WHERE_H_ #define ONEFLOW_CORE_EP_PRIMITIVE_WHERE_H_ #include "oneflow/core/ep/include/primitive/primitive.h" namespace oneflow { namespace ep { namespace primitive { class Where : public Primitive { public: OF_DISALLOW_COPY_AND_MOVE(Where); Where() = default; ~Where() override = default; virtual void Launch(Stream* stream, size_t num_cond_dims, const int64_t* cond_dims, const void* cond, size_t num_x_dims, const int64_t* x_dims, const void* x, size_t num_y_dims, const int64_t* y_dims, const void* y, void* z) = 0; }; class WhereFactory : public Factory { public: OF_DISALLOW_COPY_AND_MOVE(WhereFactory); WhereFactory() = default; ~WhereFactory() override = default; virtual std::unique_ptr New(DataType cond_type, DataType data_type, size_t max_num_dims) = 0; }; } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_PRIMITIVE_WHERE_H_ ================================================ FILE: oneflow/core/ep/include/random_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_RANDOM_GENERATOR_H_ #define ONEFLOW_CORE_EP_RANDOM_GENERATOR_H_ #include namespace oneflow { namespace ep { class RandomGenerator { public: RandomGenerator() = default; virtual ~RandomGenerator() = default; virtual uint64_t current_seed() const = 0; virtual void set_current_seed(uint64_t seed) = 0; virtual std::string device_type_name() const = 0; virtual int64_t device_index() const = 0; virtual size_t GetStateSize() const = 0; virtual void GetState(size_t state_size, void* state) const = 0; virtual void SetState(size_t state_size, const void* state) = 0; }; template std::string GetRandomGeneratorDeviceTypeName(); } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_RANDOM_GENERATOR_H_ ================================================ FILE: oneflow/core/ep/include/stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_EP_STREAM_H_
#define ONEFLOW_CORE_EP_STREAM_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/common/device_type.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/ep/include/event.h"

namespace oneflow {

namespace ep {

class Device;

class Stream {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Stream);
  Stream() = default;
  virtual ~Stream() = default;

  virtual DeviceType device_type() const = 0;
  virtual Device* device() const = 0;
  virtual Maybe<void> Sync() = 0;
  virtual void RecordEvent(Event* event) = 0;
  virtual void WaitEvent(Event* event) { UNIMPLEMENTED(); }
  virtual Maybe<void> GetAsyncError() { return Maybe<void>::Ok(); }

  virtual Maybe<void> AllocAsync(void** ptr, size_t size) { UNIMPLEMENTED_THEN_RETURN(); }
  virtual Maybe<void> FreeAsync(void* ptr) { UNIMPLEMENTED_THEN_RETURN(); }

  template<typename T>
  Maybe<void> AllocAsync(T** ptr, size_t size) {
    return AllocAsync(reinterpret_cast<void**>(ptr), size);
  }

  virtual Maybe<void> OnExecutionContextSetup() { return Maybe<void>::Ok(); }
  virtual Maybe<void> OnExecutionContextTeardown() { return Maybe<void>::Ok(); }

  template<typename T>
  T* As() {
    return static_cast<T*>(this);
  }
};

}  // namespace ep

}  // namespace oneflow

#endif  // ONEFLOW_CORE_EP_STREAM_H_


================================================
FILE: oneflow/core/ep/test/primitive/add_test.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
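A sketch of the usual record/wait pattern expressed against the Stream, Event, and Device interfaces above; the fallback branch is an assumption that a host-side sync is acceptable when device-side stream waits are unsupported.

#include "oneflow/core/ep/include/device.h"
#include "oneflow/core/ep/include/device_manager.h"
#include "oneflow/core/ep/include/stream.h"

void StreamEventSketch(oneflow::ep::Device* device, oneflow::ep::Stream* producer,
                       oneflow::ep::Stream* consumer) {
  oneflow::ep::Event* event = device->CreateEvent();
  // ... enqueue work on `producer` ...
  producer->RecordEvent(event);
  if (device->device_manager()->IsStreamWaitEventSupported()) {
    consumer->WaitEvent(event);  // device-side wait; does not block the host
  } else {
    CHECK_JUST(event->Sync());  // host-side wait as a fallback
  }
  // ... enqueue dependent work on `consumer` ...
  device->DestroyEvent(event);
}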
*/ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/add.h" #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestAdd(DeviceManagerRegistry* registry, const std::set& device_types) { constexpr size_t max_arity = 10; using Matrix = Eigen::Matrix; std::vector srcs(max_arity); std::vector dsts(max_arity); for (size_t i = 0; i < max_arity; ++i) { srcs[i] = Matrix::Random(); if (i == 0) { dsts[i] = Matrix::Zero(); } else { dsts[i] = srcs[i - 1] + dsts[i - 1]; } } const size_t vector_size = n * sizeof(T); for (const auto& device_type : device_types) { auto device = registry->GetDevice(device_type, 0); std::vector host_srcs(max_arity); std::vector device_srcs(max_arity); std::vector host_dsts(max_arity); std::vector device_dsts(max_arity); AllocationOptions pinned_options; pinned_options.SetPinnedDevice(device_type, 0); AllocationOptions device_options; for (size_t i = 0; i < max_arity; ++i) { CHECK_JUST(device->AllocPinned(pinned_options, &host_srcs[i], vector_size)); CHECK_JUST(device->AllocPinned(pinned_options, &host_dsts[i], vector_size)); CHECK_JUST(device->Alloc(device_options, &device_srcs[i], vector_size)); CHECK_JUST(device->Alloc(device_options, &device_dsts[i], vector_size)); } ep::test::StreamGuard stream(device.get()); std::unique_ptr add = NewPrimitive(device_type, data_type); ASSERT_TRUE(add.operator bool()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); for (size_t i = 0; i < max_arity; ++i) { std::memcpy(host_srcs[i], srcs[i].data(), vector_size); h2d->Launch(stream.stream(), device_srcs[i], host_srcs[i], vector_size); } for (size_t i = 2; i < max_arity; ++i) { add->Launch(stream.stream(), device_srcs.data(), i, device_dsts.at(i), n); } for (size_t i = 2; i < max_arity; ++i) { d2h->Launch(stream.stream(), host_dsts[i], device_dsts[i], vector_size); } CHECK_JUST(stream.stream()->Sync()); for (size_t i = 2; i < max_arity; ++i) { auto res = Eigen::Map(reinterpret_cast(host_dsts[i]), n); ASSERT_TRUE(dsts[i].template isApprox(res)); } for (size_t i = 0; i < max_arity; ++i) { device->FreePinned(pinned_options, host_srcs[i]); device->FreePinned(pinned_options, host_dsts[i]); device->Free(device_options, device_srcs[i]); device->Free(device_options, device_dsts[i]); } } } } // namespace TEST_F(PrimitiveTest, TestAdd) { TestAdd(&device_manager_registry_, available_device_types_); TestAdd(&device_manager_registry_, available_device_types_); TestAdd(&device_manager_registry_, available_device_types_); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/batch_matmul_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/batch_matmul.h" #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestBatchMatmul(DeviceManagerRegistry* registry, const std::set& device_types, int batch_size, int m, int k, int n, bool transpose_a, bool transpose_b) { using Matrix = Eigen::Matrix; Eigen::Tensor in_a_buffer(batch_size, m, k); Eigen::Tensor in_b_buffer(batch_size, k, n); Eigen::Tensor out_c_buffer(batch_size, m, n); in_a_buffer.setRandom(); in_b_buffer.setRandom(); for (int i = 0; i < batch_size; ++i) { Eigen::Map a(in_a_buffer.data() + i * m * k, m, k); Eigen::Map b(in_b_buffer.data() + i * k * n, k, n); Eigen::Map c(out_c_buffer.data() + i * m * n, m, n); c = a * b; } int64_t a_size = batch_size * m * k * sizeof(T); int64_t b_size = batch_size * k * n * sizeof(T); int64_t c_size = batch_size * m * n * sizeof(T); Eigen::array shuffling({0, 2, 1}); Eigen::Tensor in_a_transposed = in_a_buffer.shuffle(shuffling); Eigen::Tensor in_b_transposed = in_b_buffer.shuffle(shuffling); for (const auto& device_type : device_types) { if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) { // CPU matmul not support float16 continue; } auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input_a(device.get(), a_size); ep::test::PinnedMemoryGuard input_b(device.get(), b_size); if (transpose_a) { std::memcpy(input_a.ptr(), in_a_transposed.data(), a_size); } else { std::memcpy(input_a.ptr(), in_a_buffer.data(), a_size); } if (transpose_b) { std::memcpy(input_b.ptr(), in_b_transposed.data(), b_size); } else { std::memcpy(input_b.ptr(), in_b_buffer.data(), b_size); } ep::test::PinnedMemoryGuard output(device.get(), c_size); ep::test::DeviceMemoryGuard device_a(device.get(), a_size); ep::test::DeviceMemoryGuard device_b(device.get(), b_size); ep::test::DeviceMemoryGuard device_c(device.get(), c_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); const auto trans_a = transpose_a ? BlasTransposeType::T : BlasTransposeType::N; const auto trans_b = transpose_b ? 
BlasTransposeType::T : BlasTransposeType::N; std::unique_ptr batch_matmul = NewPrimitive(device_type, data_type, trans_a, trans_b); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); ASSERT_TRUE(batch_matmul.operator bool()); h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size); h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size); batch_matmul->Launch(stream.stream(), batch_size, m, n, k, 1.0, device_a.ptr(), device_b.ptr(), 0.0, device_c.ptr()); d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size); CHECK_JUST(stream.stream()->Sync()); Eigen::Map, Eigen::Unaligned> eigen_out( out_c_buffer.data(), out_c_buffer.size()); Eigen::Map, Eigen::Unaligned> of_out( reinterpret_cast(output.ptr()), out_c_buffer.size()); ASSERT_TRUE(eigen_out.template isApprox(of_out, static_cast(0.001))); } } template void TestBatchMatmul(DeviceManagerRegistry* registry, const std::set& device_types, int batch_size, int m, int k, int n) { TestBatchMatmul(registry, device_types, batch_size, m, k, n, false, false); TestBatchMatmul(registry, device_types, batch_size, m, k, n, true, false); TestBatchMatmul(registry, device_types, batch_size, m, k, n, false, true); TestBatchMatmul(registry, device_types, batch_size, m, k, n, true, true); } template void TestBatchMatmul(DeviceManagerRegistry* registry, const std::set& device_types) { TestBatchMatmul(registry, device_types, 10, 64, 16, 8); TestBatchMatmul(registry, device_types, 12, 16, 7, 12); } } // namespace TEST_F(PrimitiveTest, TestBatchMatmul) { TestBatchMatmul(&device_manager_registry_, available_device_types_); TestBatchMatmul(&device_manager_registry_, available_device_types_); TestBatchMatmul(&device_manager_registry_, available_device_types_); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/binary_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template Scalar GetScalar(const T& value) { return Scalar(value); } template<> Scalar GetScalar(const Eigen::half& value) { return Scalar(static_cast(value)); } template void TestElementwiseBroadcastBinary(DeviceManagerRegistry* registry, const std::set& device_types, int test_type) { const int num_axes = 4; const int broadcast_dim0 = 16; const int broadcast_dim1 = 3; const int broadcast_dim2 = 4; const int broadcast_dim3 = 8; bool is_broadcast = false; bool left_scalar = false; bool right_scalar = false; if (test_type == 0) { // do nothing } else if (test_type == 1) { is_broadcast = true; } else if (test_type == 2) { left_scalar = true; } else if (test_type == 3) { right_scalar = true; } else { UNIMPLEMENTED(); } const int a_dim0 = left_scalar ? 1 : broadcast_dim0; const int a_dim1 = left_scalar ? 1 : broadcast_dim1; const int a_dim2 = left_scalar ? 1 : broadcast_dim2; const int a_dim3 = left_scalar ? 1 : (is_broadcast ? 1 : broadcast_dim3); const int b_dim0 = right_scalar ? 1 : broadcast_dim0; const int b_dim1 = right_scalar ? 1 : (is_broadcast ? 1 : broadcast_dim1); const int b_dim2 = right_scalar ? 1 : broadcast_dim2; const int b_dim3 = right_scalar ? 1 : broadcast_dim3; const int a_broadcast0 = left_scalar ? broadcast_dim0 : 1; const int a_broadcast1 = left_scalar ? broadcast_dim1 : 1; const int a_broadcast2 = left_scalar ? broadcast_dim2 : 1; const int a_broadcast3 = left_scalar ? broadcast_dim3 : (is_broadcast ? broadcast_dim3 : 1); const int b_broadcast0 = right_scalar ? broadcast_dim0 : 1; const int b_broadcast1 = right_scalar ? broadcast_dim1 : (is_broadcast ? broadcast_dim1 : 1); const int b_broadcast2 = right_scalar ? broadcast_dim2 : 1; const int b_broadcast3 = right_scalar ? 
broadcast_dim3 : 1; const Eigen::array a_broadcast = {a_broadcast0, a_broadcast1, a_broadcast2, a_broadcast3}; const Eigen::array b_broadcast = {b_broadcast0, b_broadcast1, b_broadcast2, b_broadcast3}; Eigen::Tensor a(a_dim0, a_dim1, a_dim2, a_dim3); Eigen::Tensor b(b_dim0, b_dim1, b_dim2, b_dim3); Eigen::Tensor c(broadcast_dim0, broadcast_dim1, broadcast_dim2, broadcast_dim3); a.setRandom(); b.setRandom(); if (binary_op == BinaryOp::kAdd) { c = (a.broadcast(a_broadcast) + b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kSub) { c = (a.broadcast(a_broadcast) - b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kMul) { c = (a.broadcast(a_broadcast) * b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kDiv) { Eigen::Tensor constant_value(b_dim0, b_dim1, b_dim2, b_dim3); // avoid div 0 if (src_data_type == kInt8 || src_data_type == kUInt8) { int rand_value = std::rand() % 127; constant_value.setConstant(static_cast(rand_value)); b = constant_value; } else { constant_value.setConstant(static_cast(1)); b += constant_value; } c = (a.broadcast(a_broadcast) / b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kMax) { c = (a.broadcast(a_broadcast).cwiseMax(b.broadcast(b_broadcast))).template cast(); } else if (binary_op == BinaryOp::kMin) { c = (a.broadcast(a_broadcast).cwiseMin(b.broadcast(b_broadcast))).template cast(); } else if (binary_op == BinaryOp::kEqual) { c = (a.broadcast(a_broadcast) == b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kNotEqual) { c = (a.broadcast(a_broadcast) != b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kLessThan) { c = (a.broadcast(a_broadcast) < b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kLessEqual) { c = (a.broadcast(a_broadcast) <= b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kGreaterThan) { c = (a.broadcast(a_broadcast) > b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kGreaterEqual) { c = (a.broadcast(a_broadcast) >= b.broadcast(b_broadcast)).template cast(); } else if (binary_op == BinaryOp::kLogicalAnd) { c = (a.broadcast(a_broadcast).template cast() && b.broadcast(b_broadcast).template cast()) .template cast(); } else if (binary_op == BinaryOp::kLogicalOr) { c = (a.broadcast(a_broadcast).template cast() || b.broadcast(b_broadcast).template cast()) .template cast(); } else if (binary_op == BinaryOp::kLogicalXor) { c = (a.broadcast(a_broadcast).template cast() ^ b.broadcast(b_broadcast).template cast()) .template cast(); } else { UNIMPLEMENTED(); } std::vector a_dims = {a.dimension(0), a.dimension(1), a.dimension(2), a.dimension(3)}; std::vector b_dims = {b.dimension(0), b.dimension(1), b.dimension(2), b.dimension(3)}; std::vector c_dims = {c.dimension(0), c.dimension(1), c.dimension(2), c.dimension(3)}; int64_t a_size = a.size() * sizeof(Src); int64_t b_size = b.size() * sizeof(Src); int64_t c_size = c.size() * sizeof(Dst); for (const auto& device_type : device_types) { auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input_a(device.get(), a_size); ep::test::PinnedMemoryGuard input_b(device.get(), b_size); std::memcpy(input_a.ptr(), a.data(), a_size); std::memcpy(input_b.ptr(), b.data(), b_size); ep::test::PinnedMemoryGuard output(device.get(), c_size); ep::test::DeviceMemoryGuard device_a(device.get(), a_size); ep::test::DeviceMemoryGuard device_b(device.get(), b_size); 
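// The rest of this device loop follows the common primitive-test pattern: allocate the
// destination buffer and a stream, build Memcpy (H2D/D2H) and BroadcastElementwiseBinary
// primitives for this device, upload the Eigen-generated operands, launch the binary op in
// the scalar-lhs, scalar-rhs, or fully broadcast form selected by test_type, download the
// result, synchronize the stream, and compare it against the Eigen reference tensor c.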
ep::test::DeviceMemoryGuard device_c(device.get(), c_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); std::unique_ptr binary = NewPrimitive(device_type, binary_op, src_data_type, dst_data_type, num_axes); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); ASSERT_TRUE(binary.operator bool()); h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size); h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size); if (left_scalar) { Src a_value = *reinterpret_cast(input_a.ptr()); binary->Launch(stream.stream(), GetScalar(a_value), num_axes, b_dims.data(), device_b.ptr(), device_c.ptr()); } else if (right_scalar) { Src b_value = *reinterpret_cast(input_b.ptr()); binary->Launch(stream.stream(), num_axes, a_dims.data(), device_a.ptr(), GetScalar(b_value), device_c.ptr()); } else { binary->Launch(stream.stream(), num_axes, a_dims.data(), device_a.ptr(), num_axes, b_dims.data(), device_b.ptr(), device_c.ptr()); } d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size); CHECK_JUST(stream.stream()->Sync()); Eigen::Map, Eigen::Unaligned> eigen_out(c.data(), c.size()); Eigen::Map, Eigen::Unaligned> of_out( reinterpret_cast(output.ptr()), c.size()); ASSERT_TRUE(eigen_out.template isApprox(of_out)); } } template void TestElementwiseBroadcastBinary(DeviceManagerRegistry* registry, const std::set& device_types) { TestElementwiseBroadcastBinary( registry, device_types, 0); TestElementwiseBroadcastBinary( registry, device_types, 1); TestElementwiseBroadcastBinary( registry, device_types, 2); TestElementwiseBroadcastBinary( registry, device_types, 3); } template void TestComputeBinary(DeviceManagerRegistry* registry, const std::set& device_types) { TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary(registry, device_types); } template void TestLogicalBinary(DeviceManagerRegistry* registry, const std::set& device_types) { TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); TestElementwiseBroadcastBinary( registry, device_types); } } // namespace TEST_F(PrimitiveTest, TestBinary) { TestComputeBinary(&device_manager_registry_, available_device_types_); TestComputeBinary(&device_manager_registry_, available_device_types_); TestComputeBinary(&device_manager_registry_, available_device_types_); TestComputeBinary(&device_manager_registry_, available_device_types_); TestComputeBinary(&device_manager_registry_, available_device_types_); TestComputeBinary(&device_manager_registry_, available_device_types_); TestLogicalBinary(&device_manager_registry_, available_device_types_); TestLogicalBinary(&device_manager_registry_, available_device_types_); TestLogicalBinary(&device_manager_registry_, available_device_types_); TestLogicalBinary(&device_manager_registry_, available_device_types_); TestLogicalBinary(&device_manager_registry_, 
available_device_types_); TestLogicalBinary(&device_manager_registry_, available_device_types_); TestLogicalBinary(&device_manager_registry_, available_device_types_); TestLogicalBinary(&device_manager_registry_, available_device_types_); TestLogicalBinary(&device_manager_registry_, available_device_types_); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/broadcast_matmul_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/broadcast_matmul.h" #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestBroadcastMatmul(DeviceManagerRegistry* registry, const std::set& device_types, int batch_size, int m, int k, int n, bool transpose_a, bool transpose_b, bool broadcast_a, bool broadcast_b, bool reduce_c) { using Matrix = Eigen::Matrix; CHECK((!broadcast_a) || (!broadcast_b)); int a_batch_dims = broadcast_a ? 1 : batch_size; int b_batch_dims = broadcast_b ? 1 : batch_size; int c_batch_dims = reduce_c ? 1 : batch_size; Eigen::Tensor in_a_buffer(a_batch_dims, m, k); Eigen::Tensor in_b_buffer(b_batch_dims, k, n); Eigen::Tensor out_c_buffer(c_batch_dims, m, n); Eigen::Tensor broadcast_c_buffer(batch_size, m, n); in_a_buffer.setRandom(); in_b_buffer.setRandom(); for (int i = 0; i < batch_size; ++i) { int64_t a_offset = broadcast_a ? 0 : i * m * k; int64_t b_offset = broadcast_b ? 0 : i * k * n; Eigen::Map a(in_a_buffer.data() + a_offset, m, k); Eigen::Map b(in_b_buffer.data() + b_offset, k, n); Eigen::Map c(broadcast_c_buffer.data() + i * m * n, m, n); c = a * b; } if (reduce_c) { Eigen::array reduce_dim = {0}; out_c_buffer = broadcast_c_buffer.sum(reduce_dim).eval().reshape(out_c_buffer.dimensions()); } else { out_c_buffer = broadcast_c_buffer; } int64_t a_size = a_batch_dims * m * k * sizeof(T); int64_t b_size = b_batch_dims * k * n * sizeof(T); int64_t c_size = c_batch_dims * m * n * sizeof(T); Eigen::array shuffling({0, 2, 1}); Eigen::Tensor in_a_transposed = in_a_buffer.shuffle(shuffling); Eigen::Tensor in_b_transposed = in_b_buffer.shuffle(shuffling); size_t num_a_dims = broadcast_a ? 2 : 3; std::vector a_dims; if (!broadcast_a) { a_dims.push_back(batch_size); } if (transpose_a) { a_dims.push_back(k); a_dims.push_back(m); } else { a_dims.push_back(m); a_dims.push_back(k); } size_t num_b_dims = broadcast_b ? 2 : 3; std::vector b_dims; if (!broadcast_b) { b_dims.push_back(batch_size); } if (transpose_b) { b_dims.push_back(n); b_dims.push_back(k); } else { b_dims.push_back(k); b_dims.push_back(n); } size_t num_c_dims = reduce_c ? 
2 : 3; std::vector c_dims; if (!reduce_c) { c_dims.push_back(batch_size); } c_dims.push_back(m); c_dims.push_back(n); for (const auto& device_type : device_types) { if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) { // CPU matmul not support float16 continue; } auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input_a(device.get(), a_size); ep::test::PinnedMemoryGuard input_b(device.get(), b_size); if (transpose_a) { std::memcpy(input_a.ptr(), in_a_transposed.data(), a_size); } else { std::memcpy(input_a.ptr(), in_a_buffer.data(), a_size); } if (transpose_b) { std::memcpy(input_b.ptr(), in_b_transposed.data(), b_size); } else { std::memcpy(input_b.ptr(), in_b_buffer.data(), b_size); } ep::test::PinnedMemoryGuard output(device.get(), c_size); ep::test::DeviceMemoryGuard device_a(device.get(), a_size); ep::test::DeviceMemoryGuard device_b(device.get(), b_size); ep::test::DeviceMemoryGuard device_c(device.get(), c_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); const auto trans_a = transpose_a ? BlasTransposeType::T : BlasTransposeType::N; const auto trans_b = transpose_b ? BlasTransposeType::T : BlasTransposeType::N; std::unique_ptr broadcast_matmul = NewPrimitive(device_type, data_type, trans_a, trans_b, 3); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); ASSERT_TRUE(broadcast_matmul.operator bool()); h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size); h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size); broadcast_matmul->Launch(stream.stream(), 1.0, num_a_dims, a_dims.data(), device_a.ptr(), num_b_dims, b_dims.data(), device_b.ptr(), 0.0, num_c_dims, c_dims.data(), device_c.ptr()); d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size); CHECK_JUST(stream.stream()->Sync()); Eigen::Map, Eigen::Unaligned> eigen_out( out_c_buffer.data(), out_c_buffer.size()); Eigen::Map, Eigen::Unaligned> of_out( reinterpret_cast(output.ptr()), out_c_buffer.size()); ASSERT_TRUE(eigen_out.template isApprox(of_out, static_cast(0.001))); } } template void TestBroadcastMatmul(DeviceManagerRegistry* registry, const std::set& device_types, int m, int k, int n, bool transpose_a, bool transpose_b) { TestBroadcastMatmul(registry, device_types, 10, m, k, n, transpose_a, transpose_b, false, false, true); TestBroadcastMatmul(registry, device_types, 10, m, k, n, transpose_a, transpose_b, false, false, false); TestBroadcastMatmul(registry, device_types, 10, m, k, n, transpose_a, transpose_b, false, true, true); TestBroadcastMatmul(registry, device_types, 10, m, k, n, transpose_a, transpose_b, false, true, false); TestBroadcastMatmul(registry, device_types, 12, m, k, n, transpose_a, transpose_b, true, false, true); TestBroadcastMatmul(registry, device_types, 12, m, k, n, transpose_a, transpose_b, true, false, false); } template void TestBroadcastMatmul(DeviceManagerRegistry* registry, const std::set& device_types, int m, int k, int n) { TestBroadcastMatmul(registry, device_types, m, k, n, false, false); TestBroadcastMatmul(registry, device_types, m, k, n, true, false); TestBroadcastMatmul(registry, device_types, m, k, n, false, true); TestBroadcastMatmul(registry, device_types, m, k, n, true, true); } template void TestBroadcastMatmul(DeviceManagerRegistry* registry, const std::set& device_types) { TestBroadcastMatmul(registry, device_types, 64, 16, 8); 
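// Both problem sizes exercised in this helper are expanded by the overloads above into all
// four transpose_a/transpose_b combinations, and each of those into the no-broadcast,
// broadcast-b, and broadcast-a cases, each with and without reducing the batch dimension of c.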
TestBroadcastMatmul(registry, device_types, 16, 7, 12); } } // namespace TEST_F(PrimitiveTest, TestBroadcastMatmul) { TestBroadcastMatmul(&device_manager_registry_, available_device_types_); TestBroadcastMatmul(&device_manager_registry_, available_device_types_); TestBroadcastMatmul(&device_manager_registry_, available_device_types_); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/cast_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/cast.h" #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestCast(DeviceManagerRegistry* registry, const std::set& device_types, int elem_cnt) { if (src_data_type == dst_data_type) { return; } if (dst_data_type == kFloat16 && src_data_type != kFloat) { return; } const int src_data_size = elem_cnt * sizeof(Src); const int dst_data_size = elem_cnt * sizeof(Dst); Eigen::Tensor cast_in(elem_cnt); Eigen::Tensor cast_out(elem_cnt); cast_in.setRandom(); cast_out = cast_in.template cast(); for (const auto& device_type : device_types) { auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input(device.get(), src_data_size); ep::test::PinnedMemoryGuard output(device.get(), dst_data_size); std::memcpy(input.ptr(), cast_in.data(), src_data_size); ep::test::DeviceMemoryGuard device_in(device.get(), src_data_size); ep::test::DeviceMemoryGuard device_out(device.get(), dst_data_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); ASSERT_TRUE(h2d.operator bool()); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); h2d->Launch(stream.stream(), device_in.ptr(), input.ptr(), src_data_size); std::unique_ptr cast = NewPrimitive(device_type, src_data_type, dst_data_type); ASSERT_TRUE(cast.operator bool()); cast->Launch(stream.stream(), device_in.ptr(), device_out.ptr(), elem_cnt); d2h->Launch(stream.stream(), output.ptr(), device_out.ptr(), dst_data_size); CHECK_JUST(stream.stream()->Sync()); Eigen::Map, Eigen::Unaligned> eigen_out(cast_out.data(), cast_out.size()); Eigen::Map, Eigen::Unaligned> of_out( reinterpret_cast(output.ptr()), cast_out.size()); ASSERT_TRUE(eigen_out.template isApprox(of_out)); } } template void TestCast(DeviceManagerRegistry* registry, const std::set& device_types, int elem_cnt) { TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); 
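// Each TestCast call in this helper instantiates the element-wise cast for a different
// destination data type, so a single source type is checked against the full set of
// supported destinations; the top-level TEST_F then repeats this for several element counts.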
TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); } void TestCast(DeviceManagerRegistry* registry, const std::set& device_types, int elem_cnt) { TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); TestCast(registry, device_types, elem_cnt); } } // namespace TEST_F(PrimitiveTest, TestCast) { std::vector elem_cnts = {1024, 3193, 5765}; for (int i = 0; i < elem_cnts.size(); ++i) { TestCast(&device_manager_registry_, available_device_types_, elem_cnts.at(i)); } } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/constant_pad_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/constant_pad.h" #include #include namespace oneflow { namespace ep { namespace primitive { namespace test { template void TestConstantPad2d(DeviceManagerRegistry* registry, const std::set& device_types, const int dims[2], const std::vector padding_before, const std::vector padding_after) { using EigenVec = Eigen::Matrix; int in_elem_cnt = 1; int out_elem_cnt = 1; for (int i = 0; i < 2; i++) { in_elem_cnt *= dims[i]; out_elem_cnt *= (dims[i] + padding_before[i] + padding_after[i]); } const int in_matrix_size = in_elem_cnt * sizeof(T); const int out_matrix_size = out_elem_cnt * sizeof(T); for (const auto& device_type : device_types) { Eigen::Tensor mat(dims[0], dims[1]); mat.setRandom(); auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard host_src(device.get(), in_matrix_size); ep::test::PinnedMemoryGuard host_dst(device.get(), out_matrix_size); ep::test::DeviceMemoryGuard device_src(device.get(), in_matrix_size); ep::test::DeviceMemoryGuard device_dst(device.get(), out_matrix_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr constant_pad = NewPrimitive(device_type, dtype); ASSERT_TRUE(constant_pad.operator bool()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); T* mat_data = mat.data(); std::memcpy(host_src.ptr(), mat_data, in_matrix_size); h2d->Launch(stream.stream(), device_src.ptr(), host_src.ptr(), in_matrix_size); const int64_t src_dims[2] = {dims[0], dims[1]}; constant_pad->Launch(stream.stream(), /*num_dims=*/2, src_dims, device_src.ptr(), padding_before.data(), padding_after.data(), Scalar(0), device_dst.ptr()); d2h->Launch(stream.stream(), host_dst.ptr(), 
device_dst.ptr(), out_matrix_size); CHECK_JUST(stream.stream()->Sync()); Eigen::array, 2> paddings; for (int i = 0; i < 2; i++) { paddings[i] = std::make_pair(padding_before[i], padding_after[i]); } Eigen::Tensor mat_padded = mat.pad(paddings); auto eigen_padded_res = Eigen::Map( reinterpret_cast(mat_padded.data()), out_elem_cnt); auto constant_pad_primitive_res = Eigen::Map(host_dst.ptr(), out_elem_cnt); ASSERT_TRUE(eigen_padded_res.template isApprox(constant_pad_primitive_res)); } } template void TestConstantPadNegative2d(DeviceManagerRegistry* registry, const std::set& device_types, const int dims[2], const std::vector padding_before, const std::vector padding_after) { using EigenVec = Eigen::Matrix; int in_elem_cnt = 1; int out_elem_cnt = 1; int offsets[2]; int extents[2]; for (int i = 0; i < 2; i++) { in_elem_cnt *= dims[i]; out_elem_cnt *= (dims[i] + padding_before[i] + padding_after[i]); offsets[i] = -padding_before[i]; extents[i] = dims[i] + padding_before[i] + padding_after[i]; } const int in_matrix_size = in_elem_cnt * sizeof(T); const int out_matrix_size = out_elem_cnt * sizeof(T); for (const auto& device_type : device_types) { Eigen::Tensor mat(dims[0], dims[1]); mat.setRandom(); auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard host_src(device.get(), in_matrix_size); ep::test::PinnedMemoryGuard host_dst(device.get(), out_matrix_size); ep::test::DeviceMemoryGuard device_src(device.get(), in_matrix_size); ep::test::DeviceMemoryGuard device_dst(device.get(), out_matrix_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr constant_pad = NewPrimitive(device_type, dtype); ASSERT_TRUE(constant_pad.operator bool()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); T* mat_data = mat.data(); std::memcpy(host_src.ptr(), mat_data, in_matrix_size); h2d->Launch(stream.stream(), device_src.ptr(), host_src.ptr(), in_matrix_size); const int64_t src_dims[2] = {dims[0], dims[1]}; constant_pad->Launch(stream.stream(), /*num_dims=*/2, src_dims, device_src.ptr(), padding_before.data(), padding_after.data(), Scalar(0), device_dst.ptr()); d2h->Launch(stream.stream(), host_dst.ptr(), device_dst.ptr(), out_matrix_size); CHECK_JUST(stream.stream()->Sync()); Eigen::array slice_offsets = {offsets[0], offsets[1]}; Eigen::array slice_extents = {extents[0], extents[1]}; Eigen::Tensor mat_padded = mat.slice(slice_offsets, slice_extents); auto eigen_padded_res = Eigen::Map( reinterpret_cast(mat_padded.data()), out_elem_cnt); auto constant_pad_primitive_res = Eigen::Map(host_dst.ptr(), out_elem_cnt); ASSERT_TRUE(eigen_padded_res.template isApprox(constant_pad_primitive_res)); } } TEST_F(PrimitiveTest, TestConstantPadPrimitive2d) { const int32_t dims1[2] = {4, 4}; const int32_t dims2[2] = {10, 3}; const int32_t dims3[2] = {31, 4}; const int32_t dims4[2] = {6, 8}; const int32_t dims5[2] = {4, 11}; const std::vector padding_before1 = {1, 1}; const std::vector padding_after1 = {1, 1}; const std::vector padding_before2 = {1, 2}; const std::vector padding_after2 = {2, 1}; const std::vector padding_before3 = {2, 1}; const std::vector padding_after3 = {1, 2}; const std::vector padding_before4 = {3, 1}; const std::vector padding_after4 = {1, 3}; const std::vector padding_before5 = {1, 3}; const std::vector padding_after5 = {3, 1}; TestConstantPad2d(&device_manager_registry_, available_device_types_, dims1, padding_before1, 
padding_after1); TestConstantPad2d(&device_manager_registry_, available_device_types_, dims2, padding_before2, padding_after2); TestConstantPad2d(&device_manager_registry_, available_device_types_, dims3, padding_before3, padding_after3); TestConstantPad2d(&device_manager_registry_, available_device_types_, dims4, padding_before4, padding_after4); TestConstantPad2d( &device_manager_registry_, available_device_types_, dims5, padding_before5, padding_after5); } TEST_F(PrimitiveTest, TestConstantPadPrimitiveNegative2d) { // const int32_t dims1[2] = {4, 4}; const int32_t dims1[2] = {7, 9}; const int32_t dims2[2] = {10, 7}; const int32_t dims3[2] = {12, 11}; const int32_t dims4[2] = {6, 8}; const int32_t dims5[2] = {4, 11}; const std::vector padding_before1 = {-1, -1}; const std::vector padding_after1 = {-1, -1}; const std::vector padding_before2 = {-2, 0}; const std::vector padding_after2 = {0, -1}; const std::vector padding_before3 = {-2, -1}; const std::vector padding_after3 = {-1, -2}; const std::vector padding_before4 = {-1, 0}; const std::vector padding_after4 = {0, -1}; const std::vector padding_before5 = {-1, -3}; const std::vector padding_after5 = {0, -1}; TestConstantPadNegative2d( &device_manager_registry_, available_device_types_, dims1, padding_before1, padding_after1); TestConstantPadNegative2d( &device_manager_registry_, available_device_types_, dims2, padding_before2, padding_after2); TestConstantPadNegative2d( &device_manager_registry_, available_device_types_, dims3, padding_before3, padding_after3); TestConstantPadNegative2d( &device_manager_registry_, available_device_types_, dims4, padding_before4, padding_after4); TestConstantPadNegative2d( &device_manager_registry_, available_device_types_, dims5, padding_before5, padding_after5); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/copy_nd_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestCopyNd(DeviceManagerRegistry* registry, const std::set& device_types, int64_t num_dims) { std::vector src_dims(num_dims, 0); std::vector src_pos(num_dims, 0); std::vector dst_pos(num_dims, 0); std::vector dst_dims(num_dims, 0); std::vector extent(num_dims, 0); int64_t src_elem = 1; int64_t dst_elem = 1; for (int i = 0; i < num_dims; ++i) { int64_t rand_dim = 8 + std::rand() % 32; int64_t rand_pos = std::rand() % 16; src_dims.at(i) = rand_dim; dst_pos.at(i) = rand_pos; dst_dims.at(i) = rand_pos + rand_dim; extent.at(i) = rand_dim; src_elem *= src_dims.at(i); dst_elem *= dst_dims.at(i); } int64_t src_size = src_elem * sizeof(T); int64_t dst_size = dst_elem * sizeof(T); for (const auto& device_type : device_types) { auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input(device.get(), src_size); ep::test::PinnedMemoryGuard output(device.get(), src_size); ep::test::DeviceMemoryGuard device0(device.get(), src_size); ep::test::DeviceMemoryGuard device1(device.get(), dst_size); for (size_t i = 0; i < src_elem; ++i) { *(input.ptr() + i) = static_cast(i); } ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); ASSERT_TRUE(h2d.operator bool()); std::unique_ptr copy_nd = NewPrimitive(device_type, num_dims); ASSERT_TRUE(copy_nd.operator bool()); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); std::unique_ptr memset = NewPrimitive(device_type); ASSERT_TRUE(memset.operator bool()); h2d->Launch(stream.stream(), device0.ptr(), input.ptr(), src_size); // contiguous device0 to noncontiguous device1 copy_nd->Launch(stream.stream(), data_type, num_dims, device1.ptr(), dst_dims.data(), dst_pos.data(), device0.ptr(), src_dims.data(), src_pos.data(), extent.data()); // memset device0 memset->Launch(stream.stream(), device0.ptr(), 0x55, src_size); // noncontiguous device1 to contiguous device0 copy_nd->Launch(stream.stream(), data_type, num_dims, device0.ptr(), src_dims.data(), src_pos.data(), device1.ptr(), dst_dims.data(), dst_pos.data(), extent.data()); d2h->Launch(stream.stream(), output.ptr(), device0.ptr(), src_size); CHECK_JUST(stream.stream()->Sync()); for (size_t i = 0; i < src_elem; ++i) { ASSERT_EQ(*(input.ptr() + i), *(output.ptr() + i)); } } } } // namespace TEST_F(PrimitiveTest, TestCopyNd) { for (int i = 1; i < 6; ++i) { TestCopyNd(&device_manager_registry_, available_device_types_, i); TestCopyNd(&device_manager_registry_, available_device_types_, i); TestCopyNd(&device_manager_registry_, available_device_types_, i); TestCopyNd(&device_manager_registry_, available_device_types_, i); TestCopyNd(&device_manager_registry_, available_device_types_, i); } } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/elementwise_unary_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/elementwise_unary.h" #include namespace oneflow { namespace ep { namespace primitive { namespace test { template struct ReluFunctor { Dst operator()(Src src) { if (src > zero_val) { return src; } return zero_val; } Src zero_val = static_cast(0.0); }; template struct GeluFunctor { Dst operator()(Src src) { return static_cast(0.5) * src * (static_cast(1.0) + std::erf(inv_sqrt2 * src)); } Src inv_sqrt2 = std::sqrt(0.5); }; template struct TanhFunctor { Dst operator()(Src src) { return static_cast(std::tanh(src)); } }; template struct LogicalNotFunctor { Dst operator()(Src src) { return static_cast(!src); } }; template void EigenElementwise(FunctorT functor, Src* src, Dst* dst, const size_t elem_cnt) { for (int idx = 0; idx < elem_cnt; idx++) { dst[idx] = functor(src[idx]); } } template class FunctorClass> void TestElementwise(DeviceManagerRegistry* registry, const std::set& device_types, const size_t elem_cnt, Scalar attr0 = Scalar(), Scalar attr1 = Scalar()) { for (const auto& device_type : device_types) { auto device = registry->GetDevice(device_type, 0); using EigenSrcVec = Eigen::Matrix; using EigenDstVec = Eigen::Matrix; const size_t src_data_size = elem_cnt * sizeof(Src); const size_t dst_data_size = elem_cnt * sizeof(Dst); EigenSrcVec eigen_src(elem_cnt); EigenDstVec eigen_dst(elem_cnt); eigen_src.setRandom(); eigen_dst.setZero(); ep::test::PinnedMemoryGuard host_src(device.get(), elem_cnt * sizeof(Src)); ep::test::PinnedMemoryGuard host_dst(device.get(), elem_cnt * sizeof(Dst)); ep::test::DeviceMemoryGuard device_src(device.get(), elem_cnt * sizeof(Src)); ep::test::DeviceMemoryGuard device_dst(device.get(), elem_cnt * sizeof(Dst)); ep::test::StreamGuard stream(device.get()); std::unique_ptr elementwise_primitive = NewPrimitive( device_type, unary_op, /*src_type=*/SrcType, /*dst_type=*/DstType, attr0, attr1); ASSERT_TRUE(elementwise_primitive.operator bool()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); Src* eigen_src_data = eigen_src.data(); std::memcpy(host_src.ptr(), eigen_src_data, src_data_size); h2d->Launch(stream.stream(), device_src.ptr(), host_src.ptr(), src_data_size); elementwise_primitive->Launch(stream.stream(), device_src.ptr(), device_dst.ptr(), elem_cnt); d2h->Launch(stream.stream(), host_dst.ptr(), device_dst.ptr(), dst_data_size); CHECK_JUST(stream.stream()->Sync()); FunctorClass functor{}; EigenElementwise>(functor, eigen_src.data(), eigen_dst.data(), elem_cnt); auto elementwise_primitive_res = Eigen::Map(host_dst.ptr(), elem_cnt); ASSERT_TRUE(eigen_dst.template isApprox(elementwise_primitive_res)); } } TEST_F(PrimitiveTest, TestElementwisePrimitive) { // Test Relu TestElementwise(&device_manager_registry_, available_device_types_, 16); TestElementwise(&device_manager_registry_, available_device_types_, 32); TestElementwise(&device_manager_registry_, 
available_device_types_, 64); TestElementwise(&device_manager_registry_, available_device_types_, 128); // Test Gelu TestElementwise(&device_manager_registry_, available_device_types_, 32); TestElementwise(&device_manager_registry_, available_device_types_, 128); // Test Tanh TestElementwise(&device_manager_registry_, available_device_types_, 32); TestElementwise(&device_manager_registry_, available_device_types_, 128); // Test Logical Not TestElementwise( &device_manager_registry_, available_device_types_, 32); TestElementwise( &device_manager_registry_, available_device_types_, 64); TestElementwise( &device_manager_registry_, available_device_types_, 16); TestElementwise( &device_manager_registry_, available_device_types_, 128); TestElementwise( &device_manager_registry_, available_device_types_, 96); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/fill_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/fill.h" #ifdef WITH_CUDA #include #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #endif // WITH_CUDA namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestFill(DeviceManagerRegistry* registry, const std::set& device_types, size_t n) { const size_t vector_size = n * sizeof(T); for (const auto& device_type : device_types) { #ifdef WITH_CUDA #if CUDA_VERSION >= 11000 if (device_type == DeviceType::kCPU && data_type == DataType::kBFloat16) { continue; } #endif // CUDA_VERSION >= 11000 #endif // WITH_CUDA auto device = registry->GetDevice(device_type, 0); ep::test::DeviceMemoryGuard device_mem(device.get(), vector_size); ep::test::PinnedMemoryGuard host_mem(device.get(), vector_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr fill = NewPrimitive(device_type, data_type); ASSERT_TRUE(fill.operator bool()); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); fill->Launch(stream.stream(), device_mem.ptr(), Scalar(15.0), n); d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), vector_size); CHECK_JUST(stream.stream()->Sync()); for (size_t i = 0; i < n; ++i) { ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), static_cast(15.0)); } fill->Launch(stream.stream(), device_mem.ptr(), Scalar(0), n); d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), vector_size); CHECK_JUST(stream.stream()->Sync()); for (size_t i = 0; i < n; ++i) { #ifdef WITH_CUDA if constexpr (std::is_same_v) { ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), __float2half(0.0)); #if CUDA_VERSION >= 11000 } else if constexpr (std::is_same_v) { 
ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), __float2bfloat16(0.0)); #endif // CUDA_VERSION >= 11000 } else { ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), static_cast(0)); } #else ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), static_cast(0)); #endif // WITH_CUDA } } } } // namespace TEST_F(PrimitiveTest, TestFill) { TestFill(&device_manager_registry_, available_device_types_, 1024); TestFill(&device_manager_registry_, available_device_types_, 1024); TestFill(&device_manager_registry_, available_device_types_, 1024); TestFill(&device_manager_registry_, available_device_types_, 1024); TestFill(&device_manager_registry_, available_device_types_, 1024); TestFill(&device_manager_registry_, available_device_types_, 1024); TestFill(&device_manager_registry_, available_device_types_, 1024); #ifdef WITH_CUDA TestFill(&device_manager_registry_, available_device_types_, 1024); #if CUDA_VERSION >= 11000 TestFill(&device_manager_registry_, available_device_types_, 1024); #endif // CUDA_VERSION >= 11000 #endif // WITH_CUDA TestFill(&device_manager_registry_, available_device_types_, 1024); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/matmul_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestMatmul(DeviceManagerRegistry* registry, const std::set& device_types, int m, int k, int n, bool transpose_a, bool transpose_b) { using Matrix = Eigen::Matrix; Matrix a = Matrix::Random(m, k); Matrix b = Matrix::Random(k, n); Matrix c = a * b; Matrix a_transpose = a.transpose(); Matrix b_transpose = b.transpose(); int64_t a_size = m * k * sizeof(T); int64_t b_size = k * n * sizeof(T); int64_t c_size = m * n * sizeof(T); for (const auto& device_type : device_types) { if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) { // CPU matmul not support float16 continue; } auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input_a(device.get(), a_size); ep::test::PinnedMemoryGuard input_b(device.get(), b_size); if (transpose_a) { std::memcpy(input_a.ptr(), a_transpose.data(), a_size); } else { std::memcpy(input_a.ptr(), a.data(), a_size); } if (transpose_b) { std::memcpy(input_b.ptr(), b_transpose.data(), b_size); } else { std::memcpy(input_b.ptr(), b.data(), b_size); } ep::test::PinnedMemoryGuard output(device.get(), c_size); ep::test::DeviceMemoryGuard device_a(device.get(), a_size); ep::test::DeviceMemoryGuard device_b(device.get(), b_size); ep::test::DeviceMemoryGuard device_c(device.get(), c_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); const auto trans_a = transpose_a ? BlasTransposeType::T : BlasTransposeType::N; const auto trans_b = transpose_b ? 
BlasTransposeType::T : BlasTransposeType::N; std::unique_ptr matmul = NewPrimitive(device_type, data_type, trans_a, trans_b); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); ASSERT_TRUE(matmul.operator bool()); h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size); h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size); matmul->Launch(stream.stream(), m, n, k, 1.0, device_a.ptr(), device_b.ptr(), 0.0, device_c.ptr()); d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size); CHECK_JUST(stream.stream()->Sync()); auto res = Eigen::Map(reinterpret_cast(output.ptr()), m, n); ASSERT_TRUE(c.template isApprox(res, static_cast(0.001))); } } template void TestMatmul(DeviceManagerRegistry* registry, const std::set& device_types, int m, int k, int n) { TestMatmul(registry, device_types, m, k, n, false, false); TestMatmul(registry, device_types, m, k, n, true, false); TestMatmul(registry, device_types, m, k, n, false, true); TestMatmul(registry, device_types, m, k, n, true, true); } template void TestMatmul(DeviceManagerRegistry* registry, const std::set& device_types) { TestMatmul(registry, device_types, 64, 16, 8); TestMatmul(registry, device_types, 16, 7, 12); } } // namespace TEST_F(PrimitiveTest, TestMatmul) { TestMatmul(&device_manager_registry_, available_device_types_); TestMatmul(&device_manager_registry_, available_device_types_); TestMatmul(&device_manager_registry_, available_device_types_); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/memcpy_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

#include <gtest/gtest.h>

#include "oneflow/core/ep/test/primitive/primitive_test.h"
#include "oneflow/core/ep/include/primitive/memcpy.h"

namespace oneflow {

namespace ep {

namespace primitive {

namespace test {

TEST_F(PrimitiveTest, TestMemcpy) {
  const size_t test_elem = 1024 * 1024;
  const size_t test_size = test_elem * sizeof(float);
  for (const auto& device_type : available_device_types_) {
    auto device = device_manager_registry_.GetDevice(device_type, 0);
    ep::test::PinnedMemoryGuard input(device.get(), test_size);
    ep::test::PinnedMemoryGuard output(device.get(), test_size);
    ep::test::DeviceMemoryGuard device0(device.get(), test_size);
    ep::test::DeviceMemoryGuard device1(device.get(), test_size);
    for (size_t i = 0; i < test_elem; ++i) { *(input.ptr<float>() + i) = i; }
    ep::test::StreamGuard stream(device.get());
    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);
    ASSERT_TRUE(h2d.operator bool());
    std::unique_ptr<Memcpy> d2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoD);
    ASSERT_TRUE(d2d.operator bool());
    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);
    ASSERT_TRUE(d2h.operator bool());
    h2d->Launch(stream.stream(), device0.ptr(), input.ptr(), test_size);
    d2d->Launch(stream.stream(), device1.ptr(), device0.ptr(), test_size);
    d2h->Launch(stream.stream(), output.ptr(), device1.ptr(), test_size);
    CHECK_JUST(stream.stream()->Sync());
    for (size_t i = 0; i < test_elem; ++i) {
      ASSERT_EQ(*(input.ptr<float>() + i), *(output.ptr<float>() + i));
    }
  }
}

}  // namespace test

}  // namespace primitive

}  // namespace ep

}  // namespace oneflow


================================================
FILE: oneflow/core/ep/test/primitive/memset_test.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include <gtest/gtest.h>

#include "oneflow/core/ep/test/primitive/primitive_test.h"
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/include/primitive/memcpy.h"

namespace oneflow {

namespace ep {

namespace primitive {

namespace test {

TEST_F(PrimitiveTest, TestMemset) {
  const size_t test_size = 1024 * 1024;
  for (const auto& device_type : available_device_types_) {
    auto device = device_manager_registry_.GetDevice(device_type, 0);
    ep::test::DeviceMemoryGuard device_mem(device.get(), test_size);
    ep::test::PinnedMemoryGuard host_mem(device.get(), test_size);
    ep::test::StreamGuard stream(device.get());
    std::unique_ptr<Memset> memset = NewPrimitive<MemsetFactory>(device_type);
    ASSERT_TRUE(memset.operator bool());
    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);
    ASSERT_TRUE(d2h.operator bool());
    memset->Launch(stream.stream(), device_mem.ptr(), 0x55, test_size);
    d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), test_size);
    CHECK_JUST(stream.stream()->Sync());
    for (size_t i = 0; i < test_size; ++i) { ASSERT_EQ(*(host_mem.ptr<uint8_t>() + i), 0x55); }
    memset->Launch(stream.stream(), device_mem.ptr(), 0, test_size);
    d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), test_size);
    CHECK_JUST(stream.stream()->Sync());
    for (size_t i = 0; i < test_size; ++i) { ASSERT_EQ(*(host_mem.ptr<uint8_t>() + i), 0); }
  }
}

}  // namespace test

}  // namespace primitive

}  // namespace ep

}  // namespace oneflow


================================================
FILE: oneflow/core/ep/test/primitive/permute_test.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/permute.h" #include #include namespace oneflow { namespace ep { namespace primitive { namespace test { template void TestPermute2D(DeviceManagerRegistry* registry, const std::set& device_types, const int dims[NumDims], const int permutation_list[NumDims]) { using EigenVec = Eigen::Matrix; const int elem_cnt = dims[0] * dims[1]; const int matrix_size = elem_cnt * sizeof(T); for (const auto& device_type : device_types) { Eigen::Tensor mat(dims[0], dims[1]); mat.setRandom(); auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard host_src(device.get(), matrix_size); ep::test::PinnedMemoryGuard host_dst(device.get(), matrix_size); ep::test::DeviceMemoryGuard device_src(device.get(), matrix_size); ep::test::DeviceMemoryGuard device_dst(device.get(), matrix_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr permute = NewPrimitive(device_type, /*max_num_dims=*/NumDims); ASSERT_TRUE(permute.operator bool()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); T* mat_data = mat.data(); std::memcpy(host_src.ptr(), mat_data, matrix_size); h2d->Launch(stream.stream(), device_src.ptr(), host_src.ptr(), matrix_size); const int64_t src_dims[NumDims] = {dims[0], dims[1]}; permute->Launch(stream.stream(), dtype, /*num_dims=*/NumDims, src_dims, device_src.ptr(), permutation_list, device_dst.ptr()); d2h->Launch(stream.stream(), host_dst.ptr(), device_dst.ptr(), matrix_size); CHECK_JUST(stream.stream()->Sync()); Eigen::array shuffle_index({permutation_list[0], permutation_list[1]}); Eigen::Tensor mat_transposed = mat.shuffle(shuffle_index); auto eigen_transposed_res = Eigen::Map( reinterpret_cast(mat_transposed.data()), elem_cnt); auto permute_primitive_res = Eigen::Map(host_dst.ptr(), elem_cnt); ASSERT_TRUE(eigen_transposed_res.template isApprox(permute_primitive_res)); } } template void TestPermute3D(DeviceManagerRegistry* registry, const std::set& device_types, const int dims[NumDims], const int permutation_list[NumDims]) { using EigenVec = Eigen::Matrix; const int elem_cnt = dims[0] * dims[1] * dims[2]; const int matrix_size = elem_cnt * sizeof(T); for (const auto& device_type : device_types) { Eigen::Tensor mat(dims[0], dims[1], dims[2]); mat.setRandom(); auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard host_src(device.get(), matrix_size); ep::test::PinnedMemoryGuard host_dst(device.get(), matrix_size); ep::test::DeviceMemoryGuard device_src(device.get(), matrix_size); ep::test::DeviceMemoryGuard device_dst(device.get(), matrix_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr permute = NewPrimitive(device_type, /*max_num_dims=*/NumDims); ASSERT_TRUE(permute.operator bool()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); T* mat_data = mat.data(); std::memcpy(host_src.ptr(), mat_data, matrix_size); h2d->Launch(stream.stream(), device_src.ptr(), host_src.ptr(), matrix_size); const int64_t src_dims[NumDims] = {dims[0], dims[1], dims[2]}; permute->Launch(stream.stream(), dtype, /*num_dims=*/NumDims, src_dims, device_src.ptr(), permutation_list, 
device_dst.ptr()); d2h->Launch(stream.stream(), host_dst.ptr(), device_dst.ptr(), matrix_size); CHECK_JUST(stream.stream()->Sync()); Eigen::array shuffle_index( {permutation_list[0], permutation_list[1], permutation_list[2]}); Eigen::Tensor mat_transposed = mat.shuffle(shuffle_index); auto eigen_transposed_res = Eigen::Map( reinterpret_cast(mat_transposed.data()), elem_cnt); auto permute_primitive_res = Eigen::Map(host_dst.ptr(), elem_cnt); ASSERT_TRUE(eigen_transposed_res.template isApprox(permute_primitive_res)); } } TEST_F(PrimitiveTest, TestBatchPermute) { const int permutation_list[2] = {1, 0}; const int32_t dims0[2] = {2, 3}; const int32_t dims1[2] = {7, 9}; const int32_t dims2[2] = {10, 3}; const int32_t dims3[2] = {31, 4}; const int32_t dims4[2] = {6, 8}; TestPermute2D(&device_manager_registry_, available_device_types_, dims0, permutation_list); TestPermute2D(&device_manager_registry_, available_device_types_, dims1, permutation_list); TestPermute2D(&device_manager_registry_, available_device_types_, dims2, permutation_list); TestPermute2D(&device_manager_registry_, available_device_types_, dims3, permutation_list); TestPermute2D( &device_manager_registry_, available_device_types_, dims4, permutation_list); } TEST_F(PrimitiveTest, TestPermute) { const int permutation_list0[3] = {0, 2, 1}; const int permutation_list1[3] = {1, 2, 0}; const int permutation_list2[3] = {1, 0, 2}; const int permutation_list3[3] = {2, 1, 0}; const int permutation_list4[3] = {2, 0, 1}; const int32_t dims0[3] = {2, 3, 9}; const int32_t dims1[3] = {7, 9, 4}; const int32_t dims2[3] = {10, 3, 2}; const int32_t dims3[3] = {3, 7, 2}; const int32_t dims4[3] = {8, 2, 5}; TestPermute3D(&device_manager_registry_, available_device_types_, dims0, permutation_list0); TestPermute3D(&device_manager_registry_, available_device_types_, dims1, permutation_list1); TestPermute3D(&device_manager_registry_, available_device_types_, dims2, permutation_list2); TestPermute3D(&device_manager_registry_, available_device_types_, dims3, permutation_list3); TestPermute3D( &device_manager_registry_, available_device_types_, dims4, permutation_list4); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/primitive_test.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_EP_TEST_PRIMITIVE_PRIMITIVE_TEST_ #define ONEFLOW_CORE_EP_TEST_PRIMITIVE_PRIMITIVE_TEST_ #include "oneflow/core/ep/test/test_util.h" namespace oneflow { namespace ep { namespace primitive { namespace test { class PrimitiveTest : public ep::test::TestCase {}; } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_TEST_PRIMITIVE_PRIMITIVE_TEST_ ================================================ FILE: oneflow/core/ep/test/primitive/softmax_backward_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/softmax_backward.h" #include "oneflow/core/ep/include/primitive/log_softmax_backward.h" #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestSoftmaxBackward(DeviceManagerRegistry* registry, const std::set& device_types, int num_rows, int num_cols, bool log_softmax) { const int elem_cnt = num_rows * num_cols; const int data_size = elem_cnt * sizeof(T); Eigen::Tensor softmax_y(num_rows, num_cols); Eigen::Tensor softmax_dy(num_rows, num_cols); Eigen::Tensor softmax_dx(num_rows, num_cols); softmax_y.setRandom(); softmax_dy.setRandom(); Eigen::array reduce_dim = {1}; Eigen::array reduced_shape = {num_rows, 1}; Eigen::array broadcast_shape = {1, num_cols}; Eigen::Tensor compute_y = softmax_y.template cast(); Eigen::Tensor compute_dy = softmax_dy.template cast(); Eigen::Tensor compute_dx; if (log_softmax) { compute_dx = compute_dy - compute_y.exp() * compute_dy.sum(reduce_dim).eval().reshape(reduced_shape).broadcast(broadcast_shape); } else { Eigen::Tensor row_buf = compute_dy * compute_y; compute_dx = (compute_dy - row_buf.sum(reduce_dim).eval().reshape(reduced_shape).broadcast(broadcast_shape)) * compute_y; } softmax_dx = compute_dx.template cast(); for (const auto& device_type : device_types) { if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) { // CPU softmax not support float16 continue; } auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input_y(device.get(), data_size); ep::test::PinnedMemoryGuard input_dy(device.get(), data_size); ep::test::PinnedMemoryGuard output_dx(device.get(), data_size); std::memcpy(input_y.ptr(), softmax_y.data(), data_size); std::memcpy(input_dy.ptr(), softmax_dy.data(), data_size); ep::test::DeviceMemoryGuard device_in_y(device.get(), data_size); ep::test::DeviceMemoryGuard device_in_dy(device.get(), data_size); ep::test::DeviceMemoryGuard device_out_dx(device.get(), data_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); ASSERT_TRUE(h2d.operator bool()); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); 
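// Upload y and dy to the device, run the (Log)SoftmaxBackward primitive to produce dx,
// copy dx back to pinned host memory, and compare it against the Eigen reference computed above.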
ASSERT_TRUE(d2h.operator bool()); h2d->Launch(stream.stream(), device_in_y.ptr(), input_y.ptr(), data_size); h2d->Launch(stream.stream(), device_in_dy.ptr(), input_dy.ptr(), data_size); if (log_softmax) { std::unique_ptr log_softmax = NewPrimitive(device_type, data_type); ASSERT_TRUE(log_softmax.operator bool()); log_softmax->Launch(stream.stream(), num_rows, num_cols, device_in_y.ptr(), device_in_dy.ptr(), device_out_dx.ptr()); } else { std::unique_ptr softmax = NewPrimitive(device_type, data_type); ASSERT_TRUE(softmax.operator bool()); softmax->Launch(stream.stream(), num_rows, num_cols, device_in_y.ptr(), device_in_dy.ptr(), device_out_dx.ptr()); } d2h->Launch(stream.stream(), output_dx.ptr(), device_out_dx.ptr(), data_size); CHECK_JUST(stream.stream()->Sync()); Eigen::Map, Eigen::Unaligned> eigen_out(softmax_dx.data(), softmax_dx.size()); Eigen::Map, Eigen::Unaligned> of_out( reinterpret_cast(output_dx.ptr()), softmax_dx.size()); ASSERT_TRUE(eigen_out.template isApprox(of_out, static_cast(0.001))); } } void TestSoftmaxBackward(DeviceManagerRegistry* registry, const std::set& device_types, int num_rows, int num_cols) { TestSoftmaxBackward(registry, device_types, num_rows, num_cols, true); TestSoftmaxBackward(registry, device_types, num_rows, num_cols, false); TestSoftmaxBackward(registry, device_types, num_rows, num_cols, true); TestSoftmaxBackward(registry, device_types, num_rows, num_cols, false); TestSoftmaxBackward(registry, device_types, num_rows, num_cols, true); TestSoftmaxBackward(registry, device_types, num_rows, num_cols, false); } } // namespace TEST_F(PrimitiveTest, TestSoftmaxBackward) { std::vector num_rows = {32, 33, 512, 511}; std::vector num_cols = {15, 16, 32, 768, 1536}; for (int i = 0; i < num_rows.size(); ++i) { for (int j = 0; j < num_cols.size(); ++j) { TestSoftmaxBackward(&device_manager_registry_, available_device_types_, num_rows.at(i), num_cols.at(j)); } } } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/softmax_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/softmax.h" #include "oneflow/core/ep/include/primitive/log_softmax.h" #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestSoftmax(DeviceManagerRegistry* registry, const std::set& device_types, int num_rows, int num_cols, bool log_softmax) { const int elem_cnt = num_rows * num_cols; const int data_size = elem_cnt * sizeof(T); Eigen::Tensor softmax_in(num_rows, num_cols); Eigen::Tensor softmax_out(num_rows, num_cols); softmax_in.setRandom(); Eigen::array reduce_dim = {1}; Eigen::array reduced_shape = {num_rows, 1}; Eigen::array broadcast_shape = {1, num_cols}; Eigen::Tensor row_buf = (softmax_in - softmax_in.maximum(reduce_dim).eval().reshape(reduced_shape).broadcast(broadcast_shape)); if (log_softmax) { softmax_out = row_buf - row_buf.exp() .sum(reduce_dim) .eval() .reshape(reduced_shape) .log() .broadcast(broadcast_shape); } else { row_buf = row_buf.exp(); softmax_out = row_buf / row_buf.sum(reduce_dim).eval().reshape(reduced_shape).broadcast(broadcast_shape); } for (const auto& device_type : device_types) { if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) { // CPU softmax not support float16 continue; } auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input(device.get(), data_size); ep::test::PinnedMemoryGuard output(device.get(), data_size); std::memcpy(input.ptr(), softmax_in.data(), data_size); ep::test::DeviceMemoryGuard device_in(device.get(), data_size); ep::test::DeviceMemoryGuard device_out(device.get(), data_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); ASSERT_TRUE(h2d.operator bool()); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); ASSERT_TRUE(d2h.operator bool()); h2d->Launch(stream.stream(), device_in.ptr(), input.ptr(), data_size); if (log_softmax) { std::unique_ptr log_softmax = NewPrimitive(device_type, data_type); ASSERT_TRUE(log_softmax.operator bool()); log_softmax->Launch(stream.stream(), num_rows, num_cols, device_in.ptr(), device_out.ptr()); } else { std::unique_ptr softmax = NewPrimitive(device_type, data_type); ASSERT_TRUE(softmax.operator bool()); softmax->Launch(stream.stream(), num_rows, num_cols, device_in.ptr(), device_out.ptr()); } d2h->Launch(stream.stream(), output.ptr(), device_out.ptr(), data_size); CHECK_JUST(stream.stream()->Sync()); Eigen::Map, Eigen::Unaligned> eigen_out(softmax_out.data(), softmax_out.size()); Eigen::Map, Eigen::Unaligned> of_out( reinterpret_cast(output.ptr()), softmax_out.size()); ASSERT_TRUE(eigen_out.template isApprox(of_out, static_cast(0.001))); } } void TestSoftmax(DeviceManagerRegistry* registry, const std::set& device_types, int num_rows, int num_cols) { TestSoftmax(registry, device_types, num_rows, num_cols, true); TestSoftmax(registry, device_types, num_rows, num_cols, false); TestSoftmax(registry, device_types, num_rows, num_cols, true); TestSoftmax(registry, device_types, num_rows, num_cols, false); TestSoftmax(registry, device_types, num_rows, num_cols, true); TestSoftmax(registry, device_types, num_rows, num_cols, false); } } // namespace TEST_F(PrimitiveTest, TestSoftmax) { std::vector num_rows = {32, 33, 512, 511}; std::vector num_cols = {15, 16, 32, 768, 1536}; for (int i = 0; i < 
num_rows.size(); ++i) { for (int j = 0; j < num_cols.size(); ++j) { TestSoftmax(&device_manager_registry_, available_device_types_, num_rows.at(i), num_cols.at(j)); } } } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/unary_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/elementwise_unary.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h" #include #include namespace oneflow { namespace ep { namespace primitive { namespace test { namespace { template void TestElementwiseBroadcastUnary(DeviceManagerRegistry* registry, const std::set& device_types) { const std::vector num_src_axes = {1, 4, 1, 4, 4}; const std::vector num_dst_axes = {4, 4, 1, 4, 4}; const std::vector> a_dims_vec = { {1, 1, 1, 1}, {1, 3, 2, 4}, {1, 1, 1, 1}, {1, 2, 3, 4}, {1, 2, 3, 4}}; const std::vector> broadcast_dims_vec = { {2, 3, 2, 4}, {2, 3, 2, 4}, {1, 1, 1, 1}, {1, 2, 3, 4}, {1, 2, 3, 4}}; const std::vector> a_broadcasts_vec = { {2, 3, 2, 4}, {2, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}; const std::vector> a_strides_vec = { {0, 0, 0, 0}, {a_dims_vec[1][1] * a_dims_vec[1][2] * a_dims_vec[1][3], a_dims_vec[1][2] * a_dims_vec[1][3], a_dims_vec[1][3], 1}, {0, 0, 0, 0}, {a_dims_vec[3][1] * a_dims_vec[3][2] * a_dims_vec[3][3], a_dims_vec[3][2] * a_dims_vec[3][3], a_dims_vec[3][3], 1}, {a_dims_vec[4][1] * a_dims_vec[4][2] * a_dims_vec[4][3], a_dims_vec[4][2] * a_dims_vec[4][3], a_dims_vec[4][3], 1}}; const std::vector> c_strides_vec = { {broadcast_dims_vec[0][1] * broadcast_dims_vec[0][2] * broadcast_dims_vec[0][3], broadcast_dims_vec[0][2] * broadcast_dims_vec[0][3], broadcast_dims_vec[0][3], 1}, {broadcast_dims_vec[1][2] * broadcast_dims_vec[1][3], broadcast_dims_vec[1][0] * broadcast_dims_vec[1][2] * broadcast_dims_vec[1][3], 1, broadcast_dims_vec[1][2]}, {0, 0, 0, 0}, {broadcast_dims_vec[3][1] * broadcast_dims_vec[3][2] * broadcast_dims_vec[3][3], broadcast_dims_vec[3][2], 1, broadcast_dims_vec[3][1] * broadcast_dims_vec[3][2]}, {1, broadcast_dims_vec[4][0], broadcast_dims_vec[4][0] * broadcast_dims_vec[4][1], broadcast_dims_vec[4][0] * broadcast_dims_vec[4][1] * broadcast_dims_vec[4][2]}}; for (int i = 0; i < 5; i++) { const std::vector& a_dims = a_dims_vec[i]; const std::vector& c_dims = broadcast_dims_vec[i]; const Eigen::array a_broadcast = {a_broadcasts_vec[i][0], a_broadcasts_vec[i][1], a_broadcasts_vec[i][2], a_broadcasts_vec[i][3]}; Eigen::Tensor a(a_dims[0], a_dims[1], a_dims[2], a_dims[3]); const std::vector& a_strides = a_strides_vec[i]; const std::vector& c_strides = c_strides_vec[i]; a.setRandom(); Eigen::Tensor t = a.broadcast(a_broadcast); Eigen::Tensor 
broadcast_a = t.template cast(); const int64_t a_size = a.size() * sizeof(Src); const int64_t c_count = std::accumulate(c_dims.begin(), c_dims.end(), 1, std::multiplies()); const int64_t c_size = c_count * sizeof(Dst); const int64_t broadcast_a_size = broadcast_a.size() * sizeof(Dst); ASSERT_TRUE(c_size == broadcast_a_size); for (const auto& device_type : device_types) { // broadcast a with non-broadcast elementwise unary primitive auto device = registry->GetDevice(device_type, 0); ep::test::PinnedMemoryGuard input_broadcast_a(device.get(), broadcast_a_size); std::memcpy(input_broadcast_a.ptr(), broadcast_a.data(), broadcast_a_size); ep::test::PinnedMemoryGuard broadcast_output(device.get(), c_size); ep::test::DeviceMemoryGuard device_broadcast_a(device.get(), broadcast_a_size); ep::test::DeviceMemoryGuard device_broadcast_c(device.get(), c_size); ep::test::StreamGuard stream(device.get()); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); std::unique_ptr unary = NewPrimitive( device_type, unary_op, src_data_type, dst_data_type); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); ASSERT_TRUE(unary.operator bool()); h2d->Launch(stream.stream(), device_broadcast_a.ptr(), input_broadcast_a.ptr(), broadcast_a_size); unary->Launch(stream.stream(), device_broadcast_a.ptr(), device_broadcast_c.ptr(), c_count); // c.size() is for count d2h->Launch(stream.stream(), broadcast_output.ptr(), device_broadcast_c.ptr(), c_size); // c_size is in bytes CHECK_JUST(stream.stream()->Sync()); ep::test::PinnedMemoryGuard input_a(device.get(), a_size); std::memcpy(input_a.ptr(), a.data(), a_size); ep::test::PinnedMemoryGuard output(device.get(), c_size); ep::test::DeviceMemoryGuard device_a(device.get(), a_size); ep::test::DeviceMemoryGuard device_c(device.get(), c_size); std::unique_ptr broadcast_unary = NewPrimitive(device_type, unary_op, src_data_type, dst_data_type, MAX(num_src_axes[i], num_dst_axes[i])); ASSERT_TRUE(broadcast_unary.operator bool()); h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size); broadcast_unary->Launch(stream.stream(), num_src_axes[i], a_dims.data(), a_strides.data(), device_a.ptr(), num_dst_axes[i], c_dims.data(), c_strides.data(), device_c.ptr()); d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size); CHECK_JUST(stream.stream()->Sync()); Dst thresh = 1e-4; bool res = true; std::vector a_broadcast_strides; for (int j = num_dst_axes[i] - 1; j >= 0; j--) { if (j == num_dst_axes[i] - 1) { a_broadcast_strides.push_back(1); } else { a_broadcast_strides.insert(a_broadcast_strides.begin(), a_broadcast_strides[0] * a_dims[j + 1] * a_broadcast[j + 1]); } } for (int i0 = 0; i0 < c_dims[0]; i0++) { for (int i1 = 0; i1 < c_dims[1]; i1++) { for (int i2 = 0; i2 < c_dims[2]; i2++) { for (int i3 = 0; i3 < c_dims[3]; i3++) { #define ABS(x) ((x > 0) ? 
(x) : (-x)) const size_t src_index = a_broadcast_strides[0] * i0 + a_broadcast_strides[1] * i1 + a_broadcast_strides[2] * i2 + a_broadcast_strides[3] * i3; const size_t dst_index = c_strides[0] * i0 + c_strides[1] * i1 + c_strides[2] * i2 + c_strides[3] * i3; if (ABS(reinterpret_cast(broadcast_output.ptr())[src_index] - reinterpret_cast(output.ptr())[dst_index]) > thresh) { res = false; } #undef ABS } } } } ASSERT_TRUE(res); } } } template void TestElementwiseBroadcastUnaryBatchPermute(DeviceManagerRegistry* registry, const std::set& device_types) { const std::vector& a_dims = {5, 2}; const std::vector& c_dims = {5, 2}; Eigen::Tensor a(5, 4); const std::vector>& a_strides = {{4, 1}, {2, 1}}; const std::vector>& c_strides = {{1, 5}, {1, 10}}; a.setRandom(); const int64_t a_size = a.size() * sizeof(Src); const int64_t c_count = std::accumulate(c_dims.begin(), c_dims.end(), 1, std::multiplies()); const int64_t c_size = MAX(c_count, a.size()) * sizeof(Dst); for (int i = 0; i < a_strides.size(); i++) { auto& a_stride = a_strides[i]; auto& c_stride = c_strides[i]; for (const auto& device_type : device_types) { // broadcast a with non-broadcast elementwise unary primitive auto device = registry->GetDevice(device_type, 0); ep::test::StreamGuard stream(device.get()); ep::test::PinnedMemoryGuard input_a(device.get(), a_size); std::memcpy(input_a.ptr(), a.data(), a_size); ep::test::PinnedMemoryGuard output(device.get(), c_size); ep::test::DeviceMemoryGuard device_a(device.get(), a_size); ep::test::DeviceMemoryGuard device_c(device.get(), c_size); std::unique_ptr h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); std::unique_ptr d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); std::unique_ptr broadcast_unary = NewPrimitive(device_type, UnaryOp::kIdentity, src_data_type, dst_data_type, 2); ASSERT_TRUE(broadcast_unary.operator bool()); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size); broadcast_unary->Launch(stream.stream(), 2, a_dims.data(), a_stride.data(), device_a.ptr(), 2, c_dims.data(), c_stride.data(), device_c.ptr()); d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size); CHECK_JUST(stream.stream()->Sync()); Dst thresh = 1e-4; bool res = true; for (int i0 = 0; i0 < c_dims[0]; i0++) { for (int i1 = 0; i1 < c_dims[1]; i1++) { #define ABS(x) ((x > 0) ? (x) : (-x)) const size_t src_index = a_stride[0] * i0 + a_stride[1] * i1; const size_t dst_index = c_stride[0] * i0 + c_stride[1] * i1; if (ABS(reinterpret_cast(input_a.ptr())[src_index] - reinterpret_cast(output.ptr())[dst_index]) > thresh) { res = false; } #undef ABS } } ASSERT_TRUE(res); } } } } // namespace TEST_F(PrimitiveTest, TestUnary) { TestElementwiseBroadcastUnary(&device_manager_registry_, available_device_types_); TestElementwiseBroadcastUnaryBatchPermute( &device_manager_registry_, available_device_types_); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/primitive/where_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/where.h" #include "oneflow/core/common/data_type.h" #include #include #include namespace oneflow { template<> struct GetDataType : std::integral_constant {}; namespace ep { namespace primitive { namespace test { namespace { template std::string DimsToString(const dims_type& dims, const std::string& name) { std::ostringstream ss; ss << name << "=("; for (size_t i = 0; i < dims.size(); ++i) { if (i > 0) { ss << ", "; } ss << dims[i]; } ss << ")"; return ss.str(); }; template void TestWhere(const std::vector& devices, size_t num_cond_dims, const int64_t* cond_dims, size_t num_x_dims, const int64_t* x_dims, size_t num_y_dims, const int64_t* y_dims) { ASSERT_TRUE(num_cond_dims <= ndim); ASSERT_TRUE(num_x_dims <= ndim); ASSERT_TRUE(num_y_dims <= ndim); std::array broadcast_dims{}; std::array broadcast_cond_dims{}; std::array broadcast_x_dims{}; std::array broadcast_y_dims{}; std::array extend_cond_dims{}; std::array extend_x_dims{}; std::array extend_y_dims{}; for (size_t i = 0; i < ndim; ++i) { size_t cond_lpad = ndim - num_cond_dims; size_t x_lpad = ndim - num_x_dims; size_t y_lpad = ndim - num_y_dims; int64_t cond_dim = (i < cond_lpad) ? 1 : cond_dims[i - cond_lpad]; int64_t x_dim = (i < x_lpad) ? 1 : x_dims[i - x_lpad]; int64_t y_dim = (i < y_lpad) ? 1 : y_dims[i - y_lpad]; int64_t max_dim = std::max(x_dim, y_dim); max_dim = std::max(max_dim, cond_dim); ASSERT_TRUE((cond_dim == 1 || cond_dim == max_dim) && (x_dim == 1 || x_dim == max_dim) && (y_dim == 1 || y_dim == max_dim)); broadcast_dims[i] = max_dim; broadcast_cond_dims[i] = (cond_dim == max_dim) ? 1 : max_dim; broadcast_x_dims[i] = (x_dim == max_dim) ? 1 : max_dim; broadcast_y_dims[i] = (y_dim == max_dim) ? 1 : max_dim; extend_cond_dims[i] = cond_dim; extend_x_dims[i] = x_dim; extend_y_dims[i] = y_dim; } size_t cond_size = std::accumulate(extend_cond_dims.begin(), extend_cond_dims.end(), 1, std::multiplies()); size_t x_size = std::accumulate(extend_x_dims.begin(), extend_x_dims.end(), 1, std::multiplies()); size_t y_size = std::accumulate(extend_y_dims.begin(), extend_y_dims.end(), 1, std::multiplies()); size_t z_size = std::accumulate(broadcast_dims.begin(), broadcast_dims.end(), 1, std::multiplies()); size_t cond_byte_size = cond_size * sizeof(CondT); size_t x_byte_size = x_size * sizeof(T); size_t y_byte_size = y_size * sizeof(T); size_t z_byte_size = z_size * sizeof(T); // Eigen contrast Eigen::Tensor tensor_c(extend_cond_dims); Eigen::Tensor tensor_x(extend_x_dims); Eigen::Tensor tensor_y(extend_y_dims); tensor_c.setRandom(); tensor_x.setRandom(); tensor_y.setRandom(); tensor_c = tensor_c.unaryExpr([](T x) -> T { return x > T{0} ? 
T{1} : T{0}; }); Eigen::Tensor tensor_cond = tensor_c.template cast(); auto broadcast_c = tensor_cond.broadcast(broadcast_cond_dims); auto broadcast_x = tensor_x.broadcast(broadcast_x_dims); auto broadcast_y = tensor_y.broadcast(broadcast_y_dims); Eigen::Tensor tensor_z = broadcast_c.select(broadcast_x, broadcast_y); ASSERT_TRUE(tensor_z.size() == z_size) << tensor_z.size() << " vs. " << z_size << ", "; // test on devices for (auto* device : devices) { if (device->device_type() == DeviceType::kCPU && GetDataType() == DataType::kFloat16) { // CPU matmul not support float16 continue; } ep::test::PinnedMemoryGuard host_cond(device, cond_byte_size); ep::test::PinnedMemoryGuard host_x(device, x_byte_size); ep::test::PinnedMemoryGuard host_y(device, y_byte_size); ep::test::DeviceMemoryGuard cond(device, cond_byte_size); ep::test::DeviceMemoryGuard x(device, x_byte_size); ep::test::DeviceMemoryGuard y(device, y_byte_size); ep::test::DeviceMemoryGuard z(device, z_byte_size); ep::test::PinnedMemoryGuard host_z(device, z_byte_size); std::memcpy(host_cond.ptr(), tensor_cond.data(), cond_byte_size); std::memcpy(host_x.ptr(), tensor_x.data(), x_byte_size); std::memcpy(host_y.ptr(), tensor_y.data(), y_byte_size); ep::test::StreamGuard stream(device); auto h2d = NewPrimitive(device->device_type(), MemcpyKind::kHtoD); auto d2h = NewPrimitive(device->device_type(), MemcpyKind::kDtoH); auto where = NewPrimitive(device->device_type(), GetDataType(), GetDataType(), ndim); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); ASSERT_TRUE(where.operator bool()); h2d->Launch(stream.stream(), cond.ptr(), host_cond.ptr(), cond_byte_size); h2d->Launch(stream.stream(), x.ptr(), host_x.ptr(), x_byte_size); h2d->Launch(stream.stream(), y.ptr(), host_y.ptr(), y_byte_size); where->Launch(stream.stream(), num_cond_dims, cond_dims, cond.ptr(), num_x_dims, x_dims, x.ptr(), num_y_dims, y_dims, y.ptr(), z.ptr()); d2h->Launch(stream.stream(), host_z.ptr(), z.ptr(), z_byte_size); CHECK_JUST(stream.stream()->Sync()); Eigen::Map, Eigen::Unaligned> eigen_out(tensor_z.data(), tensor_z.size()); Eigen::Map, Eigen::Unaligned> of_out( reinterpret_cast(host_z.ptr()), z_size); ASSERT_TRUE(eigen_out.template isApprox(of_out)); } } template void TestWhere(DeviceManagerRegistry* registry, const std::set& device_types, const std::vector& cond_dims, const std::vector& x_dims, const std::vector& y_dims) { std::vector devices; for (const auto& device_type : device_types) { auto device = registry->GetDevice(device_type, 0); ASSERT_TRUE(device); devices.push_back(device.get()); } TestWhere(devices, cond_dims.size(), cond_dims.data(), x_dims.size(), x_dims.data(), y_dims.size(), y_dims.data()); } template struct random {}; template<> struct random { bool operator()() { static std::default_random_engine e; static std::uniform_int_distribution<> dis(0, 1); return static_cast(dis(e)); } }; template struct random::value>> { T operator()() { static std::default_random_engine e; static std::normal_distribution<> dis(0, 2); return dis(e); } }; template<> struct random { Eigen::half operator()() { static std::default_random_engine e; static std::uniform_real_distribution<> dis(-1, 1); return Eigen::half{dis(e)}; } }; template struct random::value>> { T operator()() { static std::default_random_engine e; static std::uniform_real_distribution<> dis(-1, 1); return dis(e); } }; template void TestScalarWhere(DeviceManagerRegistry* registry, const std::set& device_types) { std::vector devices; for (const auto& device_type : device_types) { auto 
device_ptr = registry->GetDevice(device_type, 0); ASSERT_TRUE(device_ptr); Device* device = device_ptr.get(); CondT cond = random()(); T x = random()(); T y = random()(); T z = cond ? x : y; ep::test::PinnedMemoryGuard host_cond(device, sizeof(CondT)); ep::test::PinnedMemoryGuard host_x(device, sizeof(T)); ep::test::PinnedMemoryGuard host_y(device, sizeof(T)); ep::test::DeviceMemoryGuard device_cond(device, sizeof(CondT)); ep::test::DeviceMemoryGuard device_x(device, sizeof(T)); ep::test::DeviceMemoryGuard device_y(device, sizeof(T)); ep::test::DeviceMemoryGuard device_z(device, sizeof(T)); ep::test::PinnedMemoryGuard host_z(device, sizeof(T)); std::memcpy(host_cond.ptr(), &cond, sizeof(CondT)); std::memcpy(host_x.ptr(), &x, sizeof(T)); std::memcpy(host_y.ptr(), &y, sizeof(T)); ep::test::StreamGuard stream(device); auto h2d = NewPrimitive(device_type, MemcpyKind::kHtoD); auto d2h = NewPrimitive(device_type, MemcpyKind::kDtoH); auto where = NewPrimitive(device_type, GetDataType(), GetDataType(), 0); ASSERT_TRUE(d2h.operator bool()); ASSERT_TRUE(h2d.operator bool()); ASSERT_TRUE(where.operator bool()); h2d->Launch(stream.stream(), device_cond.ptr(), host_cond.ptr(), sizeof(CondT)); h2d->Launch(stream.stream(), device_x.ptr(), host_x.ptr(), sizeof(T)); h2d->Launch(stream.stream(), device_y.ptr(), host_y.ptr(), sizeof(T)); where->Launch(stream.stream(), 0, nullptr, device_cond.ptr(), 0, nullptr, device_x.ptr(), 0, nullptr, device_y.ptr(), device_z.ptr()); d2h->Launch(stream.stream(), host_z.ptr(), device_z.ptr(), sizeof(T)); CHECK_JUST(stream.stream()->Sync()); ASSERT_TRUE(*host_z.ptr() == z); } } } // namespace TEST_F(PrimitiveTest, TestWhere) { TestWhere(&device_manager_registry_, available_device_types_, {4, 8}, {4, 8}, {4, 8}); TestWhere(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8}, {1, 8}); TestWhere(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8}, {1, 8}); TestWhere(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8}, {1, 8}); TestWhere(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8}, {1, 8}); TestWhere(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8}, {1, 8}); TestWhere(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8}, {1}); TestWhere(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8}, {1}); TestWhere(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8}, {1}); TestWhere(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8}, {1}); TestWhere(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8}, {1}); TestWhere(&device_manager_registry_, available_device_types_, {1, 6}, {2, 6}, {2, 1}); TestWhere(&device_manager_registry_, available_device_types_, {3, 7}, {3, 1}, {1, 7}); TestWhere(&device_manager_registry_, available_device_types_, {1, 4, 8}, {4, 1, 8}, {1, 1, 8}); TestWhere(&device_manager_registry_, available_device_types_, {1, 4, 8}, {4, 4, 8}, {1}); TestWhere(&device_manager_registry_, available_device_types_, {2, 1, 4, 8}, {1, 3, 4, 1}, {4, 8}); TestScalarWhere(&device_manager_registry_, available_device_types_); TestScalarWhere(&device_manager_registry_, available_device_types_); TestScalarWhere(&device_manager_registry_, available_device_types_); TestScalarWhere(&device_manager_registry_, available_device_types_); TestScalarWhere(&device_manager_registry_, available_device_types_); TestScalarWhere(&device_manager_registry_, available_device_types_); TestScalarWhere(&device_manager_registry_, available_device_types_); 
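// The scalar cases launch Where with num_dims == 0 and nullptr dims, exercising the rank-0 path.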
TestScalarWhere(&device_manager_registry_, available_device_types_); TestScalarWhere(&device_manager_registry_, available_device_types_); TestScalarWhere(&device_manager_registry_, available_device_types_); } } // namespace test } // namespace primitive } // namespace ep } // namespace oneflow ================================================ FILE: oneflow/core/ep/test/test_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_EP_TEST_TEST_UTIL_ #define ONEFLOW_CORE_EP_TEST_TEST_UTIL_ #include #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace ep { namespace test { class TestCase : public ::testing::Test { protected: void SetUp() override { for (const auto& device_type : device_manager_registry_.GetRegisteredDeviceTypes()) { // ignore mock device if (device_type == DeviceType::kMockDevice || device_type == DeviceType::kMeta) { continue; } if (device_manager_registry_.GetDeviceManager(device_type)->GetDeviceCount() > 0) { available_device_types_.insert(device_type); } } } void TearDown() override { // do nothing } DeviceManagerRegistry device_manager_registry_; std::set available_device_types_; }; class DeviceMemoryGuard { public: OF_DISALLOW_COPY_AND_MOVE(DeviceMemoryGuard); DeviceMemoryGuard(Device* device, size_t size) : device_(device), options_{} { CHECK_JUST(device_->Alloc(options_, &ptr_, size)); } ~DeviceMemoryGuard() { device_->Free(options_, ptr_); } template T* ptr() { return reinterpret_cast(ptr_); } private: Device* device_; AllocationOptions options_; void* ptr_{}; }; class PinnedMemoryGuard { public: OF_DISALLOW_COPY_AND_MOVE(PinnedMemoryGuard); PinnedMemoryGuard(Device* device, size_t size) : device_(device) { options_.SetPinnedDevice(device->device_type(), 0); CHECK_JUST(device_->AllocPinned(options_, &ptr_, size)); } ~PinnedMemoryGuard() { device_->FreePinned(options_, ptr_); } template T* ptr() { return reinterpret_cast(ptr_); } private: AllocationOptions options_; Device* device_; void* ptr_{}; }; class StreamGuard { public: OF_DISALLOW_COPY_AND_MOVE(StreamGuard); explicit StreamGuard(Device* device) : device_(device) { stream_ = device_->CreateStream(); CHECK_JUST(stream_->OnExecutionContextSetup()); } ~StreamGuard() { CHECK_JUST(stream_->OnExecutionContextTeardown()); device_->DestroyStream(stream_); } Stream* stream() { return stream_; } private: Device* device_; Stream* stream_; }; } // namespace test } // namespace ep } // namespace oneflow #endif // ONEFLOW_CORE_EP_TEST_TEST_UTIL_ ================================================ FILE: oneflow/core/framework/arg_tuple.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/arg_tuple.h" #include namespace oneflow { namespace { std::pair GetPair(const std::string& bn) { int32_t index = 0; const size_t pos = bn.rfind('_'); if (pos != std::string::npos) { index = std::stoi(bn.substr(pos + 1)); } return std::make_pair(bn.substr(0, pos), index); } void InitArgName2BnIndex2TensorTupleIndex( const std::vector>& indexed_arg_pairs, std::unordered_map>* arg_name2bn_index2tensor_tuple_index) { for (int i = 0; i < indexed_arg_pairs.size(); i++) { const auto& pair = indexed_arg_pairs.at(i); const std::string& arg_name = pair.first; const int32_t bn_index = pair.second; // vector is auto created by [] if arg_name doesn't exist in map auto* bn_index2tensor_tuple_index = &(*arg_name2bn_index2tensor_tuple_index)[arg_name]; CHECK_EQ(bn_index2tensor_tuple_index->size(), bn_index) << "Duplicate index of " << arg_name << ": " << bn_index; bn_index2tensor_tuple_index->emplace_back(i); } } } // namespace ArgTuple::ArgTuple(const std::vector& indexed_bns) : indexed_bns_(indexed_bns) { indexed_arg_name_and_index_.reserve(indexed_bns.size()); for (const auto& bn : indexed_bns) { indexed_arg_name_and_index_.emplace_back(GetPair(bn)); } InitArgName2BnIndex2TensorTupleIndex(indexed_arg_name_and_index_, &arg_name2bn_index2tensor_tuple_index_); for (int i = 0; i < indexed_bns.size(); ++i) { bn_in_op2tensor_tuple_index_[indexed_bns.at(i)] = i; } } int32_t ArgTuple::TensorTupleIndex4ArgNameAndIndex(const std::string& name, int32_t index) const { const auto& map = arg_name2bn_index2tensor_tuple_index_; const auto& iter = map.find(name); if (iter == map.end()) { return -1; } const auto& vec = iter->second; return vec.at(index); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/arg_tuple.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_ARG_TUPLE_H_ #define ONEFLOW_CORE_FRAMEWORK_ARG_TUPLE_H_ #include #include #include namespace oneflow { class ArgTuple final { public: explicit ArgTuple(const std::vector& indexed_bns); ~ArgTuple() = default; std::size_t size() const { return indexed_bns_.size(); } const std::vector& indexed_bns() const { return indexed_bns_; } const std::vector>& indexed_arg_name_and_index() const { return indexed_arg_name_and_index_; } const std::unordered_map>& arg_name2bn_index2tensor_tuple_index() const { return arg_name2bn_index2tensor_tuple_index_; } const std::unordered_map& bn_in_op2tensor_tuple_index() const { return bn_in_op2tensor_tuple_index_; } // return -1 if not found int32_t TensorTupleIndex4ArgNameAndIndex(const std::string& name, int32_t index) const; private: std::vector indexed_bns_; std::vector> indexed_arg_name_and_index_; std::unordered_map> arg_name2bn_index2tensor_tuple_index_; std::unordered_map bn_in_op2tensor_tuple_index_; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_ARG_TUPLE_H_ ================================================ FILE: oneflow/core/framework/attr_map.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/framework/mutable_attr_map.h" namespace oneflow { AttrMap::AttrInternal::AttrInternal() : max_size(0), size(0), hash_value(0), ordered_attr_names(std::make_shared>()) {} AttrMap::AttrInternal::AttrInternal( size_t _max_size, size_t _size, size_t _hash_value, const std::shared_ptr>& _ordered_attr_names) : max_size(_max_size), size(_size), hash_value(_hash_value), ordered_attr_names(_ordered_attr_names) {} AttrMap::AttrMap() : internal_(std::make_shared()) {} AttrMap::AttrMap(const MutableAttrMap& other) : internal_(std::make_shared(other.max_size(), /*size*/ 0, /*hash_value*/ 0, other.ordered_attr_names())) { internal_->attrs.resize(internal_->max_size); for (int i = 0; i < internal_->max_size; ++i) { internal_->attrs[i].second = other.valid_masks()[i]; if (other.valid_masks()[i]) { ++(internal_->size); internal_->attrs[i].first = other.attrs()[i]; // compute hash code HashCombine(&internal_->hash_value, other.attrs()[i]->hash_value()); } } } AttrMap::AttrMap(const UserOpConf& user_conf) : internal_(std::make_shared()) { for (const auto& kv : user_conf.attr()) { auto cpp_attr_value = user_op::AttrValueUtil::ToCppAttrValue(kv.second); if (cpp_attr_value.IsOk()) { ++(internal_->size); internal_->ordered_attr_names->emplace_back(kv.first); internal_->attrs.emplace_back(CHECK_JUST(cpp_attr_value), true); // compute hash code HashCombine(&internal_->hash_value, internal_->attrs.back().first->hash_value()); } else { LOG(ERROR) << user_conf.DebugString() << " failed to convert to cpp 
attr value, key: " << kv.first; } } internal_->max_size = internal_->size; } AttrMap& AttrMap::operator=(const AttrMap& other) { internal_ = other.internal_; return *this; } bool AttrMap::operator==(const AttrMap& other) const { if (internal_->size != other.internal_->size || internal_->hash_value != other.internal_->hash_value) { return false; } for (int i = 0; i < std::min(internal_->size, other.internal_->size); ++i) { if (internal_->attrs[i].second != other.internal_->attrs[i].second) { return false; } if (internal_->attrs[i].second) { if ((*internal_->ordered_attr_names)[i] != (*other.internal_->ordered_attr_names)[i]) { return false; } if (*(internal_->attrs[i].first) != *(other.internal_->attrs[i].first)) { return false; } } } return true; } template Maybe AttrMap::GetAttr(const std::string& attr_name) const { const auto& attr = Attr4Name(attr_name); CHECK_OR_RETURN(attr) << Error::InvalidValueError() << "no attribute found. attribute name: " << attr_name; const auto* ptr = dynamic_cast*>(attr.get()); CHECK_NOTNULL_OR_RETURN(ptr) << Error::RuntimeError() << "Ptr should be non-null"; return ptr->val(); } const std::shared_ptr& AttrMap::Attr4Name( const std::string& attr_name) const { int idx = internal_->ordered_attr_names->order(attr_name); if (idx >= 0) { return internal_->attrs[idx].first; } static const std::shared_ptr none; return none; } bool AttrMap::Has(const std::string& attr_name) const { return Attr4Name(attr_name) != nullptr; } AttrMap::const_iterator::const_iterator(size_t pos, const AttrMap::AttrInternal* internal) : pos_(pos), internal_(internal) { UpdateKV(); } AttrMap::const_iterator& AttrMap::const_iterator::operator++() { ++pos_; UpdateKV(); return *this; } void AttrMap::const_iterator::UpdateKV() { while (pos_ < internal_->max_size) { if (internal_->attrs[pos_].second) { break; } ++pos_; } if (pos_ < internal_->max_size) { kv_.first = (*internal_->ordered_attr_names)[pos_]; kv_.second = internal_->attrs[pos_].first; } } std::string ComposedAttrMap::ToString() const { std::vector results; for (const auto& attr : prior_) { results.emplace_back(fmt::format("{}={}", attr.first, attr.second->ToString())); } for (const auto& attr : base_) { if (prior_.Has(attr.first)) { continue; } results.emplace_back(fmt::format("{}={}", attr.first, attr.second->ToString())); } return fmt::format("{}", fmt::join(results, ", ")); } AttrMap MakeAttrMapFromUserOpConf(const UserOpConf& user_conf) { return AttrMap(user_conf); } template Maybe ComposedAttrMap::GetAttr(const std::string& attr_name) const { const auto& attr = Attr4Name(attr_name); CHECK_OR_RETURN(attr) << Error::InvalidValueError() << "no attribute found. 
attribute name: " << attr_name; return dynamic_cast*>(attr.get())->val(); } const std::shared_ptr& ComposedAttrMap::Attr4Name( const std::string& attr_name) const { const auto& prior_attr = prior_.Attr4Name(attr_name); if (prior_attr) { return prior_attr; } return base_.Attr4Name(attr_name); } bool ComposedAttrMap::Has(const std::string& attr_name) const { return Attr4Name(attr_name) != nullptr; } #define DEFINE_ATTR_VALUE_MAP_GET_ATTR(field, T, attr_type) \ template Maybe AttrMap::GetAttr(const std::string& attr_name) const; \ template Maybe ComposedAttrMap::GetAttr(const std::string& attr_name) const; OF_PP_FOR_EACH_TUPLE(DEFINE_ATTR_VALUE_MAP_GET_ATTR, ATTR_SEQ); #undef DEFINE_ATTR_VALUE_MAP_GET_ATTR } // namespace oneflow ================================================ FILE: oneflow/core/framework/attr_map.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_ #define ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/small_vector.h" namespace oneflow { namespace user_op { class AttrVal; } class AttrValue; class MutableAttrMap; class UserOpConf; template class OrderedStringList; class AttrMap final { public: AttrMap(); AttrMap(const MutableAttrMap& other); AttrMap(const UserOpConf& user_conf); AttrMap(const AttrMap&) = default; AttrMap(AttrMap&&) = default; ~AttrMap() = default; bool Has(const std::string& attr_name) const; template Maybe GetAttr(const std::string& attr_name) const; const std::shared_ptr& Attr4Name(const std::string& attr_name) const; AttrMap& operator=(const AttrMap& other); bool operator==(const AttrMap& other) const; size_t size() const { return internal_->size; } bool empty() const { return internal_->size > 0; } size_t hash_value() const { return internal_->hash_value; } struct AttrInternal { AttrInternal(); AttrInternal(size_t max_size, size_t size, size_t hash_value, const std::shared_ptr>& ordered_attr_names); size_t max_size; size_t size; size_t hash_value; std::shared_ptr> ordered_attr_names; small_vector, bool>, 8> attrs; }; class const_iterator { public: using const_reference = const std::pair>&; using const_pointer = const std::pair>*; const_iterator(size_t pos, const AttrInternal* internal); ~const_iterator() = default; const_reference operator*() const { return kv_; } const_pointer operator->() const { return &kv_; } const_iterator& operator++(); bool operator==(const const_iterator& x) const { return pos_ == x.pos_ && internal_ == x.internal_; } bool operator!=(const const_iterator& x) const { return !(*this == x); } private: void UpdateKV(); size_t pos_; const AttrInternal* internal_; std::pair> kv_; }; const_iterator begin() const { return const_iterator(0, internal_.get()); } const_iterator end() const { return const_iterator(internal_->max_size, internal_.get()); } private: std::shared_ptr internal_; }; AttrMap MakeAttrMapFromUserOpConf(const UserOpConf& user_conf); class ComposedAttrMap 
final { public: ComposedAttrMap(const AttrMap& base) : base_(base) {} ComposedAttrMap(const AttrMap& prior, const AttrMap& base) : prior_(prior), base_(base) {} template Maybe GetAttr(const std::string& attr_name) const; const std::shared_ptr& Attr4Name(const std::string& attr_name) const; bool Has(const std::string& attr_name) const; void ResetPrior(const AttrMap& prior) { prior_ = prior; } void ResetBase(const AttrMap& base) { base_ = base; } std::string ToString() const; private: AttrMap prior_; AttrMap base_; }; } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::AttrMap& attr_map) const { return attr_map.hash_value(); } }; } // namespace std #endif // ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_ ================================================ FILE: oneflow/core/framework/attr_map_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/mutable_attr_map.h" namespace oneflow { namespace test { TEST(AttrMap, basic) { auto& mut_attr_map = THREAD_CACHED_MUTABLE_ATTR_MAP("zero", "one", "zeros", "ones"); mut_attr_map.SetAllAttrs(static_cast(0), static_cast(1), std::vector{0}, std::vector{1}); AttrMap attr_map(mut_attr_map); { const auto& val = CHECK_JUST(attr_map.GetAttr("zero")); ASSERT_EQ(val, 0); } { const auto& val = CHECK_JUST(attr_map.GetAttr("one")); ASSERT_EQ(val, 1); } { const auto& val = CHECK_JUST(attr_map.GetAttr>("zeros")); ASSERT_EQ(val.size(), 1); } { const auto& val = CHECK_JUST(attr_map.GetAttr>("zeros")); ASSERT_EQ(val.at(0), 0); } { const auto& val = CHECK_JUST(attr_map.GetAttr>("ones")); ASSERT_EQ(val.size(), 1); } { const auto& val = CHECK_JUST(attr_map.GetAttr>("ones")); ASSERT_EQ(val.at(0), 1); } } TEST(AttrMap, hash_value) { HashMap attr_map2int_value; auto& mut_attr_map = THREAD_CACHED_MUTABLE_ATTR_MAP("zero", "one", "zeros", "ones"); mut_attr_map.SetAllAttrs(static_cast(0), static_cast(1), std::vector{0}, std::vector{1}); ASSERT_EQ(AttrMap(mut_attr_map).hash_value(), AttrMap(mut_attr_map).hash_value()); ASSERT_TRUE(AttrMap(mut_attr_map) == AttrMap(mut_attr_map)); } TEST(AttrMap, hash_map) { HashMap attr_map2int_value; auto& mut_attr_map = THREAD_CACHED_MUTABLE_ATTR_MAP("zero", "one", "zeros", "ones"); attr_map2int_value[AttrMap(mut_attr_map)] = 0; ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 0); mut_attr_map.SetAttr<0>(static_cast(0)); attr_map2int_value[AttrMap(mut_attr_map)] = 1; ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 1); mut_attr_map.SetAttr<1>(static_cast(1)); attr_map2int_value[AttrMap(mut_attr_map)] = 2; ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 2); mut_attr_map.SetAttr<2>(std::vector{0}); attr_map2int_value[AttrMap(mut_attr_map)] = 3; ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 3); mut_attr_map.SetAttr<3>(std::vector{1}); 
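// A fresh AttrMap snapshot of the mutated map is used as the key for value 4 and looked up below.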
attr_map2int_value[AttrMap(mut_attr_map)] = 4; ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 4); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/framework/attr_value.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/attr_value.h" namespace oneflow { template const T& AttrValueCast(const user_op::AttrVal& attr_val) { const auto* typed_attr = dynamic_cast*>(&attr_val); return CHECK_NOTNULL(typed_attr)->val(); } template std::shared_ptr CastAttrValue(const T& attr_val) { return std::make_shared>(attr_val); } template std::shared_ptr CastAttrValue(const T* attr_val) { return std::make_shared>(attr_val); } template size_t HashTypedAttrVal(const T& val) { return std::hash()(val); } #define INITIALIZE_ATTR_VALUE_CAST(field, T, attr_type) \ template const T& AttrValueCast(const user_op::AttrVal& attr_val); \ template std::shared_ptr CastAttrValue(const T& attr_val); \ template std::shared_ptr CastAttrValue(const T* attr_val); \ template size_t HashTypedAttrVal(const T& attr_val); OF_PP_FOR_EACH_TUPLE(INITIALIZE_ATTR_VALUE_CAST, ATTR_SEQ) } // namespace oneflow ================================================ FILE: oneflow/core/framework/attr_value.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_ #define ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_ #include #include "fmt/core.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/hash.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { template size_t HashTypedAttrVal(const T& val); namespace user_op { // SEQ #define BASIC_ATTR_SEQ \ OF_PP_MAKE_TUPLE_SEQ(at_int32, int32_t, AttrType::kAtInt32) \ OF_PP_MAKE_TUPLE_SEQ(at_int64, int64_t, AttrType::kAtInt64) \ OF_PP_MAKE_TUPLE_SEQ(at_bool, bool, AttrType::kAtBool) \ OF_PP_MAKE_TUPLE_SEQ(at_float, float, AttrType::kAtFloat) \ OF_PP_MAKE_TUPLE_SEQ(at_double, double, AttrType::kAtDouble) \ OF_PP_MAKE_TUPLE_SEQ(at_string, std::string, AttrType::kAtString) #define ENUM_ATTR_SEQ \ OF_PP_MAKE_TUPLE_SEQ(at_data_type, DataType, AttrType::kAtDataType) \ OF_PP_MAKE_TUPLE_SEQ(at_memory_format, MemoryFormat, AttrType::kAtMemoryFormat) #define MESSAGE_ATTR_SEQ \ OF_PP_MAKE_TUPLE_SEQ(at_shape, Shape, AttrType::kAtShape) \ OF_PP_MAKE_TUPLE_SEQ(at_stride, Stride, AttrType::kAtStride) #define BYTES_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_bytes, std::vector, AttrType::kAtBytes) #define LIST_BASIC_ATTR_SEQ \ OF_PP_MAKE_TUPLE_SEQ(at_list_int32, std::vector, AttrType::kAtListInt32) \ OF_PP_MAKE_TUPLE_SEQ(at_list_int64, std::vector, AttrType::kAtListInt64) \ OF_PP_MAKE_TUPLE_SEQ(at_list_float, std::vector, AttrType::kAtListFloat) #define LIST_ENUM_ATTR_SEQ \ OF_PP_MAKE_TUPLE_SEQ(at_list_data_type, std::vector, AttrType::kAtListDataType) #define LIST_MESSAGE_ATTR_SEQ \ OF_PP_MAKE_TUPLE_SEQ(at_list_shape, std::vector, AttrType::kAtListShape) \ OF_PP_MAKE_TUPLE_SEQ(at_list_stride, std::vector, AttrType::kAtListStride) #define LIST_STRING_ATTR_SEQ \ OF_PP_MAKE_TUPLE_SEQ(at_list_string, std::vector, AttrType::kAtListString) #define DEVICE_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_device, Symbol, AttrType::kAtDevice) #define COMPLEX_DOUBLE_ATTR_SEQ \ OF_PP_MAKE_TUPLE_SEQ(at_complex_double, std::complex, AttrType::kAtComplexDouble) #define ATTR_SEQ \ BASIC_ATTR_SEQ \ ENUM_ATTR_SEQ \ MESSAGE_ATTR_SEQ \ BYTES_ATTR_SEQ \ LIST_BASIC_ATTR_SEQ \ LIST_ENUM_ATTR_SEQ \ LIST_MESSAGE_ATTR_SEQ \ LIST_STRING_ATTR_SEQ \ DEVICE_ATTR_SEQ \ COMPLEX_DOUBLE_ATTR_SEQ // Type Trait: GetAttrType, GetCppType template struct GetAttrType; template struct GetCppType; #define SPECIALIZE_GET_ATTR_TYPE(field, type_cpp, type_proto) \ template<> \ struct GetAttrType : std::integral_constant {}; \ template<> \ struct GetCppType { \ typedef type_cpp type; \ }; OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_ATTR_TYPE, ATTR_SEQ); #undef SPECIALIZE_GET_ATTR_TYPE class AttrVal { public: AttrVal() = default; virtual ~AttrVal() = default; virtual AttrType type() const = 0; virtual size_t hash_value() const = 0; virtual std::string ToString() const = 0; virtual const void* Ptr() const = 0; virtual bool operator==(const AttrVal& other) const = 0; bool operator!=(const AttrVal& other) const { return !(*this == other); } private: OF_DISALLOW_COPY_AND_MOVE(AttrVal); }; template class TypedAttrValIf : public AttrVal { public: virtual const T& val() const = 0; size_t hash_value() const override { return std::hash()(val()); } std::string ToString() const override { return fmt::format("{}", val()); } AttrType type() const override { return GetAttrType::value; } bool operator==(const AttrVal& other) const 
override { if (other.type() != GetAttrType::value) { return false; } return *static_cast(Ptr()) == *static_cast(other.Ptr()); } }; template class TypedAttrVal final : public TypedAttrValIf { public: TypedAttrVal(T v) : val_(v) {} ~TypedAttrVal() = default; const T& val() const override { return val_; } const void* Ptr() const override { return static_cast(&val_); } size_t hash_value() const override { return std::hash()(val_); } private: OF_DISALLOW_COPY_AND_MOVE(TypedAttrVal); T val_; }; template class TypedAttrValRef final : public TypedAttrValIf { public: TypedAttrValRef(const T* v) : val_(v) {} ~TypedAttrValRef() = default; const T& val() const override { return *val_; } const void* Ptr() const override { return static_cast(val_); } size_t hash_value() const override { return std::hash()(*val_); } private: OF_DISALLOW_COPY_AND_MOVE(TypedAttrValRef); const T* val_; }; } // namespace user_op template const T& AttrValueCast(const user_op::AttrVal& val); template std::shared_ptr CastAttrValue(const T& attr_val); template std::shared_ptr CastAttrValue(const T* attr_val); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_ ================================================ FILE: oneflow/core/framework/attr_value_accessor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
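The header above splits typed attributes into an owning wrapper (TypedAttrVal) and a borrowing one (TypedAttrValRef) behind a common read-only interface. A self-contained sketch of that split, with simplified names, no AttrType enum, and std::hash only:

#include <cassert>
#include <functional>
#include <string>
#include <utility>

// Common read-only interface, like TypedAttrValIf<T> above.
template<typename T>
struct ValIf {
  virtual ~ValIf() = default;
  virtual const T& val() const = 0;
  size_t hash_value() const { return std::hash<T>()(val()); }
};

// Owning wrapper: stores a copy of the value, like TypedAttrVal<T>.
template<typename T>
struct OwnedVal final : ValIf<T> {
  explicit OwnedVal(T v) : val_(std::move(v)) {}
  const T& val() const override { return val_; }
  T val_;
};

// Borrowing wrapper: points at a value that outlives it, like
// TypedAttrValRef<T>; this avoids copying large attributes.
template<typename T>
struct RefVal final : ValIf<T> {
  explicit RefVal(const T* v) : val_(v) {}
  const T& val() const override { return *val_; }
  const T* val_;
};

int main() {
  std::string s = "kernel_name";
  OwnedVal<std::string> owned(s);
  RefVal<std::string> ref(&s);
  assert(owned.hash_value() == ref.hash_value());  // same value, same hash
  return 0;
}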
*/ #include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_conf.h" namespace oneflow { namespace user_op { // Basic and Enum Attr #define BASIC_AND_ENUM_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \ template<> \ cpp_type AttrValueAccessor::Attr(const AttrValue& val) { \ CHECK(val.has_##field()); \ return val.field(); \ } \ template<> \ void AttrValueAccessor::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \ attr_val->set_##field(cpp_val); \ } #define BASIC_AND_ENUM_ATTR_SEQ \ BASIC_ATTR_SEQ \ ENUM_ATTR_SEQ OF_PP_FOR_EACH_TUPLE(BASIC_AND_ENUM_ATTR_SEQ_ENTRY, BASIC_AND_ENUM_ATTR_SEQ) #undef BASIC_AND_ENUM_ATTR_SEQ #undef BASIC_AND_ENUM_ATTR_SEQ_ENTRY // Customized Message Attr template<> Shape AttrValueAccessor::Attr(const AttrValue& val) { return Shape(val.at_shape()); } template<> void AttrValueAccessor::Attr(const Shape& cpp_val, AttrValue* attr_val) { cpp_val.ToProto(attr_val->mutable_at_shape()); } template<> Stride AttrValueAccessor::Attr(const AttrValue& val) { return Stride(val.at_stride()); } template<> void AttrValueAccessor::Attr(const Stride& cpp_val, AttrValue* attr_val) { cpp_val.ToProto(attr_val->mutable_at_stride()); } template<> Symbol AttrValueAccessor>::Attr(const AttrValue& val) { auto pb_device = val.at_device(); return CHECK_JUST(Device::New(*CHECK_JUST(DeviceTag4DeviceType(pb_device.device_type())), pb_device.device_id(), pb_device.rematable())); } template<> void AttrValueAccessor>::Attr(const Symbol& cpp_val, AttrValue* attr_val) { attr_val->mutable_at_device()->set_device_type(cpp_val->enum_type()); attr_val->mutable_at_device()->set_device_id(cpp_val->device_id()); attr_val->mutable_at_device()->set_rematable(cpp_val->rematable()); } template<> std::vector AttrValueAccessor>::Attr(const AttrValue& val) { return std::vector(val.at_bytes().begin(), val.at_bytes().end()); } template<> void AttrValueAccessor>::Attr(const std::vector& cpp_val, AttrValue* attr_val) { attr_val->mutable_at_bytes()->assign(cpp_val.data(), cpp_val.size()); } // List of Basic Attr #define LIST_BASIC_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \ template<> \ cpp_type AttrValueAccessor::Attr(const AttrValue& val) { \ return PbRf2StdVec(val.field().val()); \ } \ template<> \ void AttrValueAccessor::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \ *(attr_val->mutable_##field()->mutable_val()) = StdVec2PbRf(cpp_val); \ } OF_PP_FOR_EACH_TUPLE(LIST_BASIC_ATTR_SEQ_ENTRY, LIST_BASIC_ATTR_SEQ) #undef LIST_BASIC_ATTR_SEQ_ENTRY // List of Enum Attr #define LIST_ENUM_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \ template<> \ cpp_type AttrValueAccessor::Attr(const AttrValue& val) { \ std::vector ret; \ ret.reserve(val.field().val_size()); \ for (const auto& value : val.field().val()) { \ ret.emplace_back(static_cast(value)); \ } \ return ret; \ } \ template<> \ void AttrValueAccessor::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \ using proto_type = std::remove_reference_tfield().val())>::value_type; \ std::vector vec; \ vec.reserve(cpp_val.size()); \ for (const auto& value : cpp_val) { vec.emplace_back(static_cast(value)); } \ *(attr_val->mutable_##field()->mutable_val()) = StdVec2PbRf(vec); \ } OF_PP_FOR_EACH_TUPLE(LIST_ENUM_ATTR_SEQ_ENTRY, 
LIST_ENUM_ATTR_SEQ) #undef LIST_ENUM_ATTR_SEQ_ENTRY // List of Customized Message Attr template<> std::vector AttrValueAccessor>::Attr(const AttrValue& val) { std::vector ret; ret.reserve(val.at_list_shape().val_size()); for (const auto& value : val.at_list_shape().val()) { ret.emplace_back(value); } return ret; } template<> void AttrValueAccessor>::Attr(const std::vector& cpp_val, AttrValue* attr_val) { attr_val->mutable_at_list_shape()->clear_val(); FOR_RANGE(int32_t, i, 0, cpp_val.size()) { cpp_val.at(i).ToProto(attr_val->mutable_at_list_shape()->add_val()); } } template<> std::vector AttrValueAccessor>::Attr(const AttrValue& val) { std::vector ret; ret.reserve(val.at_list_stride().val_size()); for (const auto& value : val.at_list_stride().val()) { ret.emplace_back(value); } return ret; } template<> void AttrValueAccessor>::Attr(const std::vector& cpp_val, AttrValue* attr_val) { attr_val->mutable_at_list_stride()->clear_val(); FOR_RANGE(int32_t, i, 0, cpp_val.size()) { cpp_val.at(i).ToProto(attr_val->mutable_at_list_stride()->add_val()); } } // List of String Attr template<> std::vector AttrValueAccessor>::Attr(const AttrValue& val) { return PbRpf2StdVec(val.at_list_string().val()); } template<> void AttrValueAccessor>::Attr(const std::vector& cpp_val, AttrValue* attr_val) { *(attr_val->mutable_at_list_string()->mutable_val()) = StdVec2PbRpf(cpp_val); } // ComplexDouble Attr template<> std::complex AttrValueAccessor>::Attr(const AttrValue& val) { std::complex ret{val.at_complex_double().real(), val.at_complex_double().imag()}; return ret; } template<> void AttrValueAccessor>::Attr(const std::complex& cpp_val, AttrValue* attr_val) { attr_val->mutable_at_complex_double()->set_real(cpp_val.real()); attr_val->mutable_at_complex_double()->set_imag(cpp_val.imag()); } template Maybe MakeCppAttrValueFromProtoAttrValue(const ProtoT& attr_value) { switch (static_cast(attr_value.value_case())) { #define MAKE_ENTRY(field, T, attr_type) \ case static_cast(attr_type): \ return std::static_pointer_cast( \ std::make_shared>(AttrValueAccessor::Attr(attr_value))); OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ); #undef MAKE_ENTRY default: OF_UNIMPLEMENTED(); } } /* static */ Maybe AttrValueUtil::ToCppAttrValue(const AttrValue& proto_attr_value) { return MakeCppAttrValueFromProtoAttrValue(proto_attr_value); } /* static */ Maybe AttrValueUtil::ToProtoAttrValue(const AttrVal& cpp_attr_value, AttrValue* attr_value) { if (false) { // clang-format off #define MAKE_ENTRY(field, cpp_type, attr_type) \ } \ else if (dynamic_cast*>(&cpp_attr_value) != nullptr) { \ const auto* ptr = dynamic_cast*>(&cpp_attr_value); \ AttrValueAccessor::Attr(ptr->val(), attr_value); OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ); #undef MAKE_ENTRY // clang-format on } else { OF_UNIMPLEMENTED(); } return Maybe::Ok(); } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/attr_value_accessor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
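ToProtoAttrValue above is generated by a macro that expands into an `if (false) {} else if (dynamic_cast<...>) ...` chain, one branch per attribute type. A standalone sketch of that dispatch trick with two toy types (hypothetical names, not the OneFlow macros):

#include <iostream>
#include <string>
#include <utility>

struct Base { virtual ~Base() = default; };
template<typename T>
struct Typed final : Base {
  explicit Typed(T v) : val(std::move(v)) {}
  T val;
};

// Each expansion closes the previous block and opens an `else if`, so the
// chain hangs off a dead `if (false)` branch.
#define DISPATCH_ENTRY(T)                                   \
  } else if (auto* p = dynamic_cast<const Typed<T>*>(&v)) { \
    std::cout << "matched " #T ": " << p->val << "\n";

void Print(const Base& v) {
  if (false) {
    // clang-format off
    DISPATCH_ENTRY(int)
    DISPATCH_ENTRY(std::string)
    // clang-format on
  } else {
    std::cout << "unsupported type\n";
  }
}

int main() {
  Print(Typed<int>(7));
  Print(Typed<std::string>("abc"));
  return 0;
}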
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_
#define ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_

#include "oneflow/core/common/maybe.h"

namespace oneflow {

class AttrValue;

namespace user_op {

template<typename T>
struct AttrValueAccessor final {
  static T Attr(const AttrValue&);
  static void Attr(const T&, AttrValue*);
};

class AttrVal;

struct AttrValueUtil final {
  static Maybe<AttrVal> ToCppAttrValue(const AttrValue& proto_attr_value);
  static Maybe<void> ToProtoAttrValue(const AttrVal& cpp_attr_value, AttrValue* attr_value);
};

}  // namespace user_op
}  // namespace oneflow

#endif  // ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_

================================================
FILE: oneflow/core/framework/auto_random_generator.cpp
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/auto_random_generator.h"

#include "oneflow/core/common/str_util.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/platform/include/pthread_fork.h"

namespace oneflow {
namespace one {

struct AutoGeneratorState {
  uint64_t seed = 0;
  int64_t num = 0;
  int64_t device_tag_length = 0;
  int64_t state_length = 0;
  // std::vector<int64_t> state_sizes[num];
  // std::vector<char> device_tags[device_tag_length];
  // std::vector<char> states[state_sizes[0] + state_sizes[1] + ...
+ state_sizes[num - 1]] }; void AutoGenerator::set_current_seed(uint64_t seed) { std::lock_guard lock(mutex_); seed_ = seed; for (const auto& it : generators_) { if (unlikely(pthread_fork::IsForkedSubProcess() && it.first->type() != "cpu")) { continue; } it.second->set_current_seed(seed); } } size_t AutoGenerator::GetStateSize() const { std::lock_guard lock(mutex_); size_t state_size = sizeof(AutoGeneratorState) + generators_.size() * sizeof(uint64_t); std::stringstream ss; auto it = generators_.begin(); if (it != generators_.end()) { ss << it->second->device_type_name() << ":" << it->second->device_index(); ++it; } for (; it != generators_.end(); ++it) { ss << "," << it->second->device_type_name() << ":" << it->second->device_index(); } state_size += ss.str().size(); for (const auto& it : generators_) { state_size += it.second->GetStateSize(); } return state_size; } void AutoGenerator::GetState(size_t state_size, void* state) const { std::lock_guard lock(mutex_); AutoGeneratorState state_info; state_info.seed = current_seed(); state_info.num = generators_.size(); state_info.state_length = 0; std::vector state_sizes; state_sizes.reserve(generators_.size()); for (auto it = generators_.begin(); it != generators_.end(); ++it) { state_sizes.emplace_back(it->second->GetStateSize()); state_info.state_length += state_sizes.back(); } std::stringstream ss; auto it = generators_.begin(); if (it != generators_.end()) { ss << it->second->device_type_name() << ":" << it->second->device_index(); ++it; } for (; it != generators_.end(); ++it) { ss << "," << it->second->device_type_name() << ":" << it->second->device_index(); } std::string device_tags = ss.str(); state_info.device_tag_length = device_tags.size(); size_t total_size = sizeof(AutoGeneratorState) + state_info.num * sizeof(int64_t) + state_info.device_tag_length + state_info.state_length; CHECK_EQ_OR_THROW(state_size, total_size) << "the state size of auto generator should be equal to " << total_size; { uint8_t* data = static_cast(state); memcpy(data, &state_info, sizeof(AutoGeneratorState)); data += sizeof(AutoGeneratorState); memcpy(data, state_sizes.data(), state_info.num * sizeof(int64_t)); data += state_info.num * sizeof(int64_t); memcpy(data, device_tags.data(), state_info.device_tag_length); data += state_info.device_tag_length; int i = 0; for (auto it = generators_.begin(); it != generators_.end(); ++it, ++i) { it->second->GetState(state_sizes[i], data); data += state_sizes[i]; } } } void AutoGenerator::SetState(size_t state_size, const void* state) { AutoGeneratorState state_info; const uint8_t* data = static_cast(state); memcpy(reinterpret_cast(&state_info), data, sizeof(AutoGeneratorState)); if (state_size != sizeof(AutoGeneratorState) + state_info.num * sizeof(int64_t) + state_info.device_tag_length + state_info.state_length) { return THROW(RuntimeError) << "Invalid auto generator state, size is not match."; } data += sizeof(AutoGeneratorState); std::vector state_sizes(state_info.num); std::vector state_data(state_info.num); memcpy(state_sizes.data(), data, state_info.num * sizeof(int64_t)); data += state_info.num * sizeof(int64_t); std::string device_tags; device_tags.resize(state_info.device_tag_length); memcpy(const_cast(device_tags.data()), data, state_info.device_tag_length); data += state_info.device_tag_length; for (int i = 0; i < state_info.num; ++i) { state_data[i] = data; data += state_sizes[i]; } // set current seed. 
set_current_seed(state_info.seed); std::vector splits; Split(device_tags, ",", [&](std::string&& s) { splits.emplace_back(s); }); if (splits.size() != state_info.num) { return THROW(RuntimeError) << "Invalid auto generator state. The number of state is " << state_info.num << ", but device tags number is " << splits.size(); } std::lock_guard lock(mutex_); for (int i = 0; i < splits.size(); ++i) { const auto& device = CHECK_JUST(Device::ParseAndNew(splits[i])); auto generator = CHECK_JUST(GetOrCreate(device->type(), device->device_id())); generator->SetState(state_sizes[i], state_data[i]); } } Maybe AutoGenerator::GetOrCreate(const std::string& device, int device_index) { if (device_index == -1) { device_index = (device == "cpu" ? 0 : GlobalProcessCtx::LocalRank()); } std::lock_guard lock(mutex_); auto device_key = JUST(Device::New(device, device_index)); auto it = generators_.find(device_key); if (it == generators_.end()) { auto device_type = ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device); if (device_type == DeviceType::kInvalidDevice) { return Error::RuntimeError() << "Expected one of " << PrintGeneratorAvailableDevices() << " device type at start of device string: " << device; } auto device_mgr = Singleton::Get()->GetDeviceManager(device_type); it = generators_.emplace(device_key, device_mgr->CreateRandomGenerator(seed_, device_index)) .first; } return it->second; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/auto_random_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
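AutoGenerator::GetState/SetState above pack a fixed header, a size table, a comma-separated device-tag string, and the per-device state blobs into one flat buffer. A minimal sketch of that pack/unpack layout using plain structs and std::memcpy (toy types, not the OneFlow classes):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

struct Header {
  uint64_t seed = 0;
  int64_t num = 0;           // number of blobs
  int64_t tag_length = 0;    // length of the device-tag string
  int64_t state_length = 0;  // total bytes of all blobs
};

std::vector<uint8_t> Pack(const Header& h, const std::vector<int64_t>& sizes,
                          const std::string& tags, const std::vector<uint8_t>& blobs) {
  std::vector<uint8_t> buf(sizeof(Header) + sizes.size() * sizeof(int64_t) + tags.size()
                           + blobs.size());
  uint8_t* p = buf.data();
  std::memcpy(p, &h, sizeof(Header));
  p += sizeof(Header);
  std::memcpy(p, sizes.data(), sizes.size() * sizeof(int64_t));
  p += sizes.size() * sizeof(int64_t);
  std::memcpy(p, tags.data(), tags.size());
  p += tags.size();
  std::memcpy(p, blobs.data(), blobs.size());
  return buf;
}

int main() {
  std::string tags = "cpu:0,cuda:0";
  std::vector<int64_t> sizes = {3, 2};
  std::vector<uint8_t> blobs = {10, 11, 12, 20, 21};  // 3 bytes for cpu, 2 for cuda
  Header h;
  h.seed = 42;
  h.num = static_cast<int64_t>(sizes.size());
  h.tag_length = static_cast<int64_t>(tags.size());
  h.state_length = static_cast<int64_t>(blobs.size());

  std::vector<uint8_t> buf = Pack(h, sizes, tags, blobs);

  // Unpack: read the fixed header first, then use it to locate the variable parts.
  Header h2;
  std::memcpy(&h2, buf.data(), sizeof(Header));
  assert(h2.seed == 42 && h2.num == 2 && h2.state_length == 5);
  return 0;
}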
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_AUTO_RANDOM_GENERATOR_H_ #define ONEFLOW_CORE_FRAMEWORK_AUTO_RANDOM_GENERATOR_H_ #include #include #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/ep/include/random_generator.h" #include "oneflow/core/framework/device.h" namespace oneflow { namespace one { class AutoGenerator : public ep::RandomGenerator { public: AutoGenerator(uint64_t seed) : seed_(seed) {} virtual ~AutoGenerator() = default; uint64_t current_seed() const override { return seed_; } void set_current_seed(uint64_t seed) override; std::string device_type_name() const override { return "auto"; } int64_t device_index() const override { return 0; } size_t GetStateSize() const override; void GetState(size_t state_size, void* state) const override; void SetState(size_t state_size, const void* state) override; Maybe GetOrCreate(const std::string& device, int device_index); template Maybe GetOrCreate(int device_index) { return std::dynamic_pointer_cast( JUST(GetOrCreate(ep::GetRandomGeneratorDeviceTypeName(), device_index))); } private: mutable std::mutex mutex_; uint64_t seed_; std::unordered_map, std::shared_ptr> generators_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_AUTO_RANDOM_GENERATOR_H_ ================================================ FILE: oneflow/core/framework/autocast.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/autocast.h" #include "oneflow/core/job_rewriter/auto_mixed_precision.h" #include "oneflow/core/job_rewriter/auto_mixed_precision_lists.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace autocast { namespace { bool* autocast_enabled() { static thread_local bool autocast_enabled = false; return &autocast_enabled; } DeviceType* autocast_device_type() { static thread_local DeviceType autocast_device_type = kCUDA; return &autocast_device_type; } Symbol* autocast_dtype() { static thread_local Symbol autocast_dtype = DType::Float16(); return &autocast_dtype; } Symbol* autocast_cpu_dtype() { static thread_local Symbol autocast_cpu_dtype = DType::BFloat16(); return &autocast_cpu_dtype; } Symbol* autocast_gpu_dtype() { static thread_local Symbol autocast_gpu_dtype = DType::Float16(); return &autocast_gpu_dtype; } bool* cache_enabled() { static thread_local bool cache_enabled = true; return &cache_enabled; } inline Symbol get_lower_precision_fp_from_device_type(DeviceType device_type) { if (device_type == DeviceType::kCPU) { return get_autocast_cpu_dtype(); }; return get_autocast_gpu_dtype(); } // The structure below is referenced from PyTorch: // https://github.com/pytorch/pytorch/blob/41d79695907cd4105b8e7167cf8a57ba48e1f079/aten/src/ATen/autocast_mode.cpp#L60-L63 // The weakref keeps the source's TensorImpl from being deleted. We need to because we're // using the source TensorImpl* as the key. 
If it were deleted, another random Tensor could // be allocated whose TensorImpl* happened to have the same value. This TensorImpl* would // then mistakenly hit in cache: a rare, intermittent, unpredictable bug. using val_type = std::pair, std::shared_ptr>; using key_type = std::pair; using cached_map = std::unordered_map; std::unordered_map* cached_casts() { static thread_local std::unordered_map cached_casts; return &cached_casts; } } // namespace bool is_enabled() { return *autocast_enabled(); } void set_enabled(bool enabled) { *autocast_enabled() = enabled; } DeviceType get_autocast_device_type() { return *autocast_device_type(); } void set_autocast_device_type(DeviceType device_type) { *autocast_device_type() = device_type; } Symbol get_autocast_dtype() { return *autocast_dtype(); } Symbol get_autocast_cpu_dtype() { return *autocast_cpu_dtype(); } Symbol get_autocast_gpu_dtype() { return *autocast_gpu_dtype(); } void set_autocast_dtype(Symbol dtype) { *autocast_dtype() = dtype; } void set_autocast_cpu_dtype(Symbol dtype) { *autocast_cpu_dtype() = dtype; } void set_autocast_gpu_dtype(Symbol dtype) { *autocast_gpu_dtype() = dtype; } bool is_autocast_cache_enabled() { return *cache_enabled(); } void set_autocast_cache_enabled(bool enabled) { *cache_enabled() = enabled; } Maybe cached_cast(const std::shared_ptr& tensor, Symbol cast_type, DeviceType device_type) { bool use_cache = (is_autocast_cache_enabled() && tensor->requires_grad() && cast_type == get_lower_precision_fp_from_device_type(device_type) && tensor->dtype()->data_type() == DataType::kFloat && tensor->is_leaf() && !tensor->is_view()); if (use_cache) { auto it = cached_casts()->find( std::make_pair(JUST(tensor->mut_eager_local_tensor_impl()), cast_type->data_type())); if (it == cached_casts()->end() || it->second.first.lock() == nullptr) { const std::shared_ptr& result = JUST(one::functional::To(tensor, cast_type, /*copy*/ false)); if (it == cached_casts()->end()) { cached_casts()->emplace( std::make_pair(JUST(tensor->mut_eager_local_tensor_impl()), cast_type->data_type()), std::make_pair(tensor->weak_from_this(), result)); } else { it->second.first = tensor->weak_from_this(); it->second.second = result; } return result; } else { return it->second.second; } } else { return one::functional::To(tensor, cast_type, /*copy*/ false); } }; void clear_cache() { cached_casts()->clear(); } AutoCastColor AutoCastMeta::autocast_color() const { return autocast_color_; } void AutoCastMeta::set_autocast_color(AutoCastColor color) { autocast_color_ = color; } bool AutoCastMeta::is_autocast_eligible(DeviceType device_type, Symbol dtype) const { int device_index = static_cast(device_type); if (is_autocast_eligible_.size() > device_index) { int dtype_index = static_cast(dtype->data_type()); if (is_autocast_eligible_[device_index].size() > dtype_index) { return is_autocast_eligible_[device_index][dtype_index]; } } return false; } void AutoCastMeta::set_autocast_eligible(DeviceType device_type, Symbol dtype) { int device_index = static_cast(device_type); while (is_autocast_eligible_.size() <= device_index) { is_autocast_eligible_.resize(device_index + 1); } int dtype_index = static_cast(dtype->data_type()); while (is_autocast_eligible_[device_index].size() <= dtype_index) { is_autocast_eligible_[device_index].resize(dtype_index + 1); } is_autocast_eligible_[device_index][dtype_index] = true; } bool AutoCastMeta::is_args_autocast_eligible(int arg_index) const { CHECK_LT_OR_THROW(arg_index, is_args_autocast_eligible_.size()); // NOLINT return 
is_args_autocast_eligible_[arg_index]; } const std::vector& AutoCastMeta::is_args_autocast_eligible() const { return is_args_autocast_eligible_; } void AutoCastMeta::set_arg_autocast_eligible(int arg_index) { CHECK_LT_OR_THROW(arg_index, is_args_autocast_eligible_.size()); // NOLINT is_args_autocast_eligible_[arg_index] = true; } std::shared_ptr MakeAutoCastMeta( const std::string& op_type_name, const std::vector>& input_args) { auto autocast_meta = std::make_shared(input_args.size()); if (AutoMixedPrecisionLists::WhiteList().count(op_type_name)) { autocast_meta->set_autocast_color(kWhite); } else if (AutoMixedPrecisionLists::GrayList().count(op_type_name)) { autocast_meta->set_autocast_color(kGray); } else if (AutoMixedPrecisionLists::ClearList().count(op_type_name)) { autocast_meta->set_autocast_color(kClear); } else { autocast_meta->set_autocast_color(kBlack); } for (int i = 0; i < input_args.size(); ++i) { if (!amp::IsNoCast(op_type_name, input_args[i])) { autocast_meta->set_arg_autocast_eligible(i); } } // autocast only supports the following device type(s) and low precision type(s): // - device type: CUDA // - low precision type: half, bfloat16 static std::vector autocast_device_types{kCUDA}; static std::vector> autocast_dtypes{DType::Float16(), DType::BFloat16()}; if (autocast_meta->autocast_color() != kBlack) { for (auto device_type : autocast_device_types) { for (auto dtype : autocast_dtypes) { autocast_meta->set_autocast_eligible(device_type, dtype); } } } return autocast_meta; } } // namespace autocast } // namespace oneflow ================================================ FILE: oneflow/core/framework/autocast.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
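cached_cast in autocast.cpp above keys its cache on a raw impl pointer and stores a weak_ptr to detect when that pointer has been recycled, as the comment borrowed from PyTorch explains. A standalone sketch of that "weak_ptr validates a raw-pointer key" idea, with toy types rather than the OneFlow cache:

#include <cassert>
#include <memory>
#include <unordered_map>
#include <utility>

struct Src { int id; };
struct Casted { int id; };

// Key: raw source pointer; value: (weakref to the source, cached result).
// If the weakref is expired, the raw pointer may have been reused by a new
// object, so the cached entry must not be trusted.
using Cache =
    std::unordered_map<const Src*, std::pair<std::weak_ptr<Src>, std::shared_ptr<Casted>>>;

std::shared_ptr<Casted> CachedCast(Cache& cache, const std::shared_ptr<Src>& src) {
  auto it = cache.find(src.get());
  if (it != cache.end() && it->second.first.lock() != nullptr) {
    return it->second.second;  // cache hit and the source is still alive
  }
  auto result = std::make_shared<Casted>(Casted{src->id});
  cache[src.get()] = {std::weak_ptr<Src>(src), result};
  return result;
}

int main() {
  Cache cache;
  auto a = std::make_shared<Src>(Src{1});
  auto r1 = CachedCast(cache, a);
  auto r2 = CachedCast(cache, a);
  assert(r1 == r2);  // second call hits the cache
  return 0;
}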
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_AUTOCAST_H_ #define ONEFLOW_CORE_FRAMEWORK_AUTOCAST_H_ #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/tensor.h" namespace oneflow { namespace autocast { bool is_enabled(); void set_enabled(bool enabled); DeviceType get_autocast_device_type(); void set_autocast_device_type(DeviceType device_type); Symbol get_autocast_dtype(); Symbol get_autocast_cpu_dtype(); Symbol get_autocast_gpu_dtype(); void set_autocast_dtype(Symbol dtype); void set_autocast_cpu_dtype(Symbol dtype); void set_autocast_gpu_dtype(Symbol dtype); bool is_autocast_cache_enabled(); void set_autocast_cache_enabled(bool enabled); void clear_cache(); Maybe cached_cast(const std::shared_ptr& tensor, Symbol cast_type, DeviceType device_type); enum AutoCastColor { kNoColor, kWhite, kGray, kClear, kBlack }; class AutoCastMeta final { public: AutoCastMeta() : AutoCastMeta(0) {} explicit AutoCastMeta(int args_num) : autocast_color_(kNoColor), is_args_autocast_eligible_(args_num, false) {} AutoCastColor autocast_color() const; bool is_autocast_eligible(DeviceType device_type, Symbol dtype) const; bool is_args_autocast_eligible(int arg_index) const; const std::vector& is_args_autocast_eligible() const; void set_autocast_color(AutoCastColor color); void set_autocast_eligible(DeviceType device_type, Symbol dtype); void set_arg_autocast_eligible(int arg_index); private: AutoCastColor autocast_color_; std::vector> is_autocast_eligible_; std::vector is_args_autocast_eligible_; }; std::shared_ptr MakeAutoCastMeta( const std::string& op_type_name, const std::vector>& input_args); } // namespace autocast } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_AUTOCAST_H_ ================================================ FILE: oneflow/core/framework/compute_complexity_fn_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
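AutoCastMeta above keeps a lazily grown 2-D table indexed by (device enum, dtype enum) and answers false for any cell that was never grown. A small sketch of that pattern with plain ints standing in for the real enums:

#include <cassert>
#include <vector>

class EligibilityTable {
 public:
  // Mark (device, dtype) as eligible, growing the table only as far as needed.
  void Set(int device, int dtype) {
    if (table_.size() <= static_cast<size_t>(device)) { table_.resize(device + 1); }
    auto& row = table_[device];
    if (row.size() <= static_cast<size_t>(dtype)) { row.resize(dtype + 1, false); }
    row[dtype] = true;
  }
  // Anything outside the grown region is implicitly ineligible.
  bool Get(int device, int dtype) const {
    if (static_cast<size_t>(device) >= table_.size()) { return false; }
    const auto& row = table_[device];
    return static_cast<size_t>(dtype) < row.size() && row[dtype];
  }

 private:
  std::vector<std::vector<bool>> table_;
};

int main() {
  EligibilityTable t;
  t.Set(/*device=*/2, /*dtype=*/9);  // e.g. a CUDA x float16 cell in the real code
  assert(t.Get(2, 9));
  assert(!t.Get(0, 9));  // never grown -> not eligible
  return 0;
}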
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_COMPUTE_COMPLEXITY_FN_CONTEXT_H_ #define ONEFLOW_CORE_FRAMEWORK_COMPUTE_COMPLEXITY_FN_CONTEXT_H_ #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { class Shape; namespace user_op { class UserOpDefWrapper; class ComputeComplexityFnContext { public: virtual ~ComputeComplexityFnContext() = default; virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const Shape& Shape4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual DataType Dtype4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; virtual const NdSbp NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const = 0; virtual const NdSbpSignature* GetNdSbpSignature() const = 0; template T Attr(const std::string& attr_name) const { return conf_.attr(attr_name); } virtual const ParallelDesc& parallel_desc() const = 0; virtual bool IsDynamic4ArgNameAndIndex(const std::string&, int32_t) const = 0; const UserOpConfWrapper& user_op_conf() const { return conf_; } protected: explicit ComputeComplexityFnContext(UserOpConfWrapper&& conf) : conf_(std::move(conf)) {} private: UserOpConfWrapper conf_; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_COMPUTE_COMPLEXITY_FN_CONTEXT_H_ ================================================ FILE: oneflow/core/framework/config_def.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
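ComputeComplexityFnContext above mixes pure-virtual shape/dtype queries with a non-virtual templated Attr<T> that simply forwards to the stored op conf (virtual functions cannot be templates). A tiny sketch of that "template member delegating to a wrapped config" layout, with a hypothetical ToyConf in place of UserOpConfWrapper:

#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>

// Stand-in for UserOpConfWrapper: a typed lookup over named attributes.
struct ToyConf {
  std::unordered_map<std::string, int64_t> int_attrs;
  template<typename T>
  T attr(const std::string& name) const;  // only int64_t is specialized below
};
template<>
int64_t ToyConf::attr<int64_t>(const std::string& name) const { return int_attrs.at(name); }

class ToyContext {
 public:
  explicit ToyContext(ToyConf conf) : conf_(std::move(conf)) {}
  // Non-virtual template: attribute reads are forwarded to the owned conf.
  template<typename T>
  T Attr(const std::string& name) const { return conf_.attr<T>(name); }

 private:
  ToyConf conf_;
};

int main() {
  ToyContext ctx(ToyConf{{{"axis", 1}}});
  assert(ctx.Attr<int64_t>("axis") == 1);
  return 0;
}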
*/ #include "oneflow/core/framework/config_def.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace { template ConfigDef* MutGlobalConfigDef() { static ConfigDef config_def; return &config_def; } template AttrValue* AddAttrDef(const std::string& name, const std::string& description) { auto* name2flag_def = MutGlobalConfigDef()->mutable_attr_name2attr_def(); CHECK(name2flag_def->find(name) == name2flag_def->end()) << "Duplicate attribute: " << name; auto* flag_def = &(*name2flag_def)[name]; flag_def->set_name(name); flag_def->set_description(description); return flag_def->mutable_default_val(); } } // namespace const ConfigDef& GlobalEnvConfigDef() { return *MutGlobalConfigDef(); } const ConfigDef& GlobalSessionConfigDef() { return *MutGlobalConfigDef(); } const ConfigDef& GlobalFunctionConfigDef() { return *MutGlobalConfigDef(); } const ConfigDef& GlobalScopeConfigDef() { return *MutGlobalConfigDef(); } template const ConfigDefBuidler& ConfigDefBuidler::Bool( const std::string& name, bool default_val, const std::string& description) const { AddAttrDef(name, description)->set_at_bool(default_val); return *this; } template const ConfigDefBuidler& ConfigDefBuidler::Int64( const std::string& name, int64_t default_val, const std::string& description) const { AddAttrDef(name, description)->set_at_int64(default_val); return *this; } template const ConfigDefBuidler& ConfigDefBuidler::Double( const std::string& name, double default_val, const std::string& description) const { AddAttrDef(name, description)->set_at_double(default_val); return *this; } template const ConfigDefBuidler& ConfigDefBuidler::String( const std::string& name, const std::string& default_val, const std::string& description) const { AddAttrDef(name, description)->set_at_string(default_val); return *this; } template const ConfigDefBuidler& ConfigDefBuidler::ListInt64( const std::string& name, const std::vector& default_val, const std::string& description) const { auto* list = AddAttrDef(name, description)->mutable_at_list_int64(); *list->mutable_val() = {default_val.begin(), default_val.end()}; return *this; } template struct ConfigDefBuidler; template struct ConfigDefBuidler; template struct ConfigDefBuidler; template struct ConfigDefBuidler; } // namespace oneflow ================================================ FILE: oneflow/core/framework/config_def.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_CONFIG_DEF_H_ #define ONEFLOW_CORE_JOB_CONFIG_DEF_H_ #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/framework/config_def.pb.h" namespace oneflow { template struct ConfigDefBuidler final { const ConfigDefBuidler& Bool(const std::string& name, bool default_val, const std::string& description) const; const ConfigDefBuidler& Int64(const std::string& name, int64_t default_val, const std::string& description) const; const ConfigDefBuidler& Double(const std::string& name, double default_val, const std::string& description) const; const ConfigDefBuidler& String(const std::string& name, const std::string& default_val, const std::string& description) const; const ConfigDefBuidler& ListInt64(const std::string& name, const std::vector& default_val, const std::string& description) const; }; #define REGISTER_ENV_CONFIG_DEF() REGISTER_CONFIG_DEF(kEnvConfigDefType) #define REGISTER_SESSION_CONFIG_DEF() REGISTER_CONFIG_DEF(kSessionConfigDefType) #define REGISTER_FUNCTION_CONFIG_DEF() REGISTER_CONFIG_DEF(kFunctionConfigDefType) #define REGISTER_SCOPE_CONFIG_DEF() REGISTER_CONFIG_DEF(kScopeConfigDefType) #define REGISTER_CONFIG_DEF(config_def_type) \ static ConfigDefBuidler OF_PP_CAT(g_##config_def_type##_def_, __COUNTER__) = \ ConfigDefBuidler() const ConfigDef& GlobalEnvConfigDef(); const ConfigDef& GlobalSessionConfigDef(); const ConfigDef& GlobalFunctionConfigDef(); const ConfigDef& GlobalScopeConfigDef(); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_CONFIG_DEF_H_ ================================================ FILE: oneflow/core/framework/config_def.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/framework/user_op_attr.proto"; enum ConfigDefType { kEnvConfigDefType = 1; kSessionConfigDefType = 2; kFunctionConfigDefType = 3; kScopeConfigDefType = 4; } message ConfigDef { map attr_name2attr_def = 1; } ================================================ FILE: oneflow/core/framework/consistency_check.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
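The REGISTER_*_CONFIG_DEF macros above create a uniquely named static builder whose chained calls run before main and fill the global def. A standalone sketch of that static-registration-with-__COUNTER__ pattern (toy flag registry, not the OneFlow ConfigDef):

#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, int64_t>* MutFlags() {
  static std::unordered_map<std::string, int64_t> flags;
  return &flags;
}

struct FlagBuilder {
  const FlagBuilder& Int64(const std::string& name, int64_t default_val) const {
    (*MutFlags())[name] = default_val;
    return *this;  // return *this so registrations can be chained
  }
};

// Concatenate with __COUNTER__ so several registrations in one file get
// distinct variable names.
#define CAT_IMPL(a, b) a##b
#define CAT(a, b) CAT_IMPL(a, b)
#define REGISTER_FLAGS() static FlagBuilder CAT(g_flag_builder_, __COUNTER__) = FlagBuilder()

REGISTER_FLAGS().Int64("max_retry", 3).Int64("timeout_ms", 1000);

int main() {
  assert(MutFlags()->at("max_retry") == 3);
  assert(MutFlags()->at("timeout_ms") == 1000);
  return 0;
}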
*/ #include #include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/job/rank_group_scope.h" #include "oneflow/core/framework/synced_symbol_map.h" #include "oneflow/core/framework/sync_symbol_nd_sbp.h" #include "oneflow/core/framework/sync_symbol_parallel_desc.h" #include "oneflow/core/common/constant.h" #include "oneflow/core/common/check_level.h" #include "oneflow/core/framework/sync_symbol_global_tensor_meta.h" namespace oneflow { namespace { struct FlatMetaInfoConsistency; class CheckMetaInfoConsistencyAsyncTransportCtx : public AsyncTransportCtx { public: CheckMetaInfoConsistencyAsyncTransportCtx(const TransportToken& transport_token, const Symbol& placement, const Optional>& nd_sbp, const Optional>& grad_nd_sbp) : AsyncTransportCtx(transport_token), placement_(placement), nd_sbp_(nd_sbp), grad_nd_sbp_(grad_nd_sbp) {} ~CheckMetaInfoConsistencyAsyncTransportCtx() override = default; Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override; Maybe PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override; Maybe Check() const; private: Symbol placement_; Optional> nd_sbp_; Optional> grad_nd_sbp_; std::shared_ptr flat_meta_info_consistency_; }; // clang-format off FLAT_MSG_BEGIN(FlatMetaInfoConsistency); public: static Maybe New() { const auto& consistency = std::make_shared(); consistency->clear(); return consistency; } static Maybe New(const Symbol& placement, const Optional>& nd_sbp, const Optional>& grad_nd_sbp) { const auto& consistency = std::make_shared(); consistency->clear(); JUST(consistency->Init(placement, nd_sbp, grad_nd_sbp)); return consistency; } Maybe Check(const Symbol& placement, const Optional>& nd_sbp, const Optional>& grad_nd_sbp) { const auto& this_placement = JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( this->placement_symbol_id())); CHECK_OR_RETURN(this_placement == placement) << Error::RuntimeError() << "Each rank must have the same input placement"; CHECK_EQ_OR_RETURN(nd_sbp.has_value(), this->has_nd_sbp_symbol_id()) << Error::RuntimeError() << "Either all ranks have sbp or not"; if (this->has_nd_sbp_symbol_id()) { const auto& that_nd_sbp = JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( this->nd_sbp_symbol_id())); const auto& this_nd_sbp = JUST(nd_sbp); CHECK_OR_RETURN(this_nd_sbp == that_nd_sbp) << Error::RuntimeError() << "Each rank must have the same input sbp"; } CHECK_EQ_OR_RETURN(grad_nd_sbp.has_value(), this->has_grad_nd_sbp_symbol_id()) << Error::RuntimeError() << "Either all ranks have grad sbp or not"; if (this->has_grad_nd_sbp_symbol_id()) { const auto& that_grad_nd_sbp = JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( this->grad_nd_sbp_symbol_id())); const auto& this_grad_nd_sbp = JUST(grad_nd_sbp); CHECK_OR_RETURN(this_grad_nd_sbp == that_grad_nd_sbp)<< Error::RuntimeError() << "Each rank must have same input grad sbp"; } return Maybe::Ok(); } private: Maybe Init(const Symbol& placement, const Optional>& nd_sbp, const Optional>& grad_nd_sbp) { this->set_placement_symbol_id( JUST(SyncedSymbolMap::FindOrSync(placement, &SyncSymbolParallelDesc))); if (nd_sbp.has_value()) { this->set_nd_sbp_symbol_id( JUST(SyncedSymbolMap::FindOrSync(JUST(nd_sbp), &SyncSymbolNdSbp))); } if (grad_nd_sbp.has_value()) { this->set_grad_nd_sbp_symbol_id( JUST(SyncedSymbolMap::FindOrSync(JUST(grad_nd_sbp), 
&SyncSymbolNdSbp))); } return Maybe::Ok(); } FLAT_MSG_DEFINE_OPTIONAL(uint64_t, placement_symbol_id); FLAT_MSG_DEFINE_OPTIONAL(uint64_t, nd_sbp_symbol_id); FLAT_MSG_DEFINE_OPTIONAL(uint64_t, grad_nd_sbp_symbol_id); FLAT_MSG_END(FlatMetaInfoConsistency); // clang-format on Maybe CheckMetaInfoConsistencyAsyncTransportCtx::PrepareSendBufferAndCallback( int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { const auto& meta_info_consistency = JUST(FlatMetaInfoConsistency::New(placement_, nd_sbp_, grad_nd_sbp_)); *buffer = meta_info_consistency.get(); *size = sizeof(FlatMetaInfoConsistency); *Callback = [meta_info_consistency] {}; return Maybe::Ok(); } Maybe CheckMetaInfoConsistencyAsyncTransportCtx::PrepareRecvBufferAndCallback( int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { const auto& flat_meta_info_consistency = JUST(FlatMetaInfoConsistency::New()); *buffer = flat_meta_info_consistency.get(); *size = sizeof(FlatMetaInfoConsistency); *Callback = [flat_meta_info_consistency]() {}; flat_meta_info_consistency_ = flat_meta_info_consistency; return Maybe::Ok(); } Maybe CheckMetaInfoConsistencyAsyncTransportCtx::Check() const { if (!flat_meta_info_consistency_) { return Maybe::Ok(); } JUST(flat_meta_info_consistency_->Check(placement_, nd_sbp_, grad_nd_sbp_)); return Maybe::Ok(); } } // namespace Maybe DataConsistencyCheck(const void* buffer_ptr, size_t buffer_size, Symbol placement) { if (!placement->containing_current_rank() || placement->parallel_num() == 1) { return Maybe::Ok(); } const auto& rank_group = JUST(RankGroup::New(placement)); std::vector recv_buffer(buffer_size); char* recv_ptr = recv_buffer.data(); TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = const_cast(buffer_ptr); *size = buffer_size; *Cb = [] {}; return Maybe::Ok(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_ptr; *size = buffer_size; *Cb = [] {}; return Maybe::Ok(); }); JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg); CHECK_OR_RETURN(std::memcmp(buffer_ptr, reinterpret_cast(recv_ptr), buffer_size) == 0) << Error::RuntimeError() << "Each rank must have same input sequence or numpy array"; return Maybe::Ok(); } namespace { Maybe MetaInfoConsistencyCheckUtil(const Symbol& placement, const Optional>& nd_sbp, const Optional>& grad_nd_sbp) { const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeCheckRankGroupConsistency)); const auto& ctx = std::make_shared( transport_token, placement, nd_sbp, grad_nd_sbp); JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get())); JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get())); JUST_MSG(ctx->WaitDone(), kAsymmetricCodeErrorMsg); JUST(ctx->Check()); return Maybe::Ok(); } int64_t* MutThreadLocalMetaInfoConsistencyCheckDepth() { static thread_local int64_t recursive_depth = 0; return &recursive_depth; } inline bool IsMetaInfoConsistencyCheckDisable() { return *MutThreadLocalMetaInfoConsistencyCheckDepth() > 1; } } // namespace NonRecursiveMetaInfoConsistencyCheckScope::NonRecursiveMetaInfoConsistencyCheckScope() { 
auto* recursive_depth = MutThreadLocalMetaInfoConsistencyCheckDepth(); ++*recursive_depth; } NonRecursiveMetaInfoConsistencyCheckScope::~NonRecursiveMetaInfoConsistencyCheckScope() { auto* recursive_depth = MutThreadLocalMetaInfoConsistencyCheckDepth(); --*recursive_depth; } Maybe MetaInfoConsistencyCheck(const Symbol& placement, const Optional>& nd_sbp, const Optional>& grad_nd_sbp, const size_t debug_level, bool force_check) { if ((IsEnvEnabled(debug_level) || force_check) && !IsMetaInfoConsistencyCheckDisable()) { JUST(MetaInfoConsistencyCheckUtil(placement, nd_sbp, grad_nd_sbp)); } return Maybe::Ok(); } Maybe MetaInfoConsistencyCheck(const Symbol& placement, const Optional>& nd_sbp, const size_t debug_level, bool force_check) { if ((IsEnvEnabled(debug_level) || force_check) && !IsMetaInfoConsistencyCheckDisable()) { JUST(MetaInfoConsistencyCheckUtil(placement, nd_sbp, Optional>())); } return Maybe::Ok(); } Maybe MetaInfoConsistencyCheck(const Symbol& placement, const std::vector>& sbp_tuple, const std::vector>& grad_sbp_tuple, const size_t debug_level, bool force_check) { Optional> nd_sbp; Optional> grad_nd_sbp; if (!sbp_tuple.empty()) { grad_nd_sbp = JUST(GetNdSbp(sbp_tuple)); } if (!grad_sbp_tuple.empty()) { grad_nd_sbp = JUST(GetNdSbp(grad_sbp_tuple)); } JUST(MetaInfoConsistencyCheck(placement, nd_sbp, grad_nd_sbp, debug_level, force_check)); return Maybe::Ok(); } Maybe MetaInfoConsistencyCheck(const Symbol& placement, const std::vector>& sbp_tuple, const size_t debug_level, bool force_check) { Optional> nd_sbp; Optional> grad_nd_sbp; if (!sbp_tuple.empty()) { grad_nd_sbp = JUST(GetNdSbp(sbp_tuple)); } JUST(MetaInfoConsistencyCheck(placement, nd_sbp, grad_nd_sbp, debug_level, force_check)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/consistency_check.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
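NonRecursiveMetaInfoConsistencyCheckScope above bumps a thread-local depth counter so that nested calls can skip the (expensive) cross-rank check. A self-contained sketch of that RAII depth guard:

#include <cassert>
#include <cstdint>

int64_t* MutCheckDepth() {
  static thread_local int64_t depth = 0;
  return &depth;
}

// Entering a scope increments the depth; only the outermost level checks.
struct NonRecursiveCheckScope {
  NonRecursiveCheckScope() { ++*MutCheckDepth(); }
  ~NonRecursiveCheckScope() { --*MutCheckDepth(); }
};

bool ShouldRunCheck() { return *MutCheckDepth() <= 1; }

void Inner() {
  NonRecursiveCheckScope scope;
  assert(!ShouldRunCheck());  // nested: depth is 2, the check is skipped
}

void Outer() {
  NonRecursiveCheckScope scope;
  assert(ShouldRunCheck());   // outermost: depth is 1, the check runs
  Inner();
}

int main() {
  Outer();
  assert(*MutCheckDepth() == 0);  // destructors restored the counter
  return 0;
}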
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ #define ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/common/tensor_meta.h" namespace oneflow { class NonRecursiveMetaInfoConsistencyCheckScope final { public: OF_DISALLOW_COPY_AND_MOVE(NonRecursiveMetaInfoConsistencyCheckScope); NonRecursiveMetaInfoConsistencyCheckScope(); ~NonRecursiveMetaInfoConsistencyCheckScope(); }; Maybe DataConsistencyCheck(const void* buffer_ptr, size_t buffer_size, Symbol placement); Maybe MetaInfoConsistencyCheck(const Symbol& placement, const Optional>& nd_sbp, const Optional>& grad_nd_sbp, const size_t debug_level, bool force_check); Maybe MetaInfoConsistencyCheck(const Symbol& placement, const Optional>& nd_sbp, const size_t debug_level, bool force_check); Maybe MetaInfoConsistencyCheck(const Symbol& placement, const std::vector>& sbp_tuple, const std::vector>& grad_sbp_tuple, const size_t debug_level, bool force_check); Maybe MetaInfoConsistencyCheck(const Symbol& placement, const std::vector>& sbp_tuple, const size_t debug_level, bool force_check); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ ================================================ FILE: oneflow/core/framework/device.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/framework/device.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/to_string.h" namespace oneflow { namespace { void CheckDeviceType(const std::string& type) { if (!TRY(DeviceType4DeviceTag(type)).IsOk()) { std::string error_msg = "Expected one of " + PrintAvailableDevices() + " device type at start of device string: " + type; throw std::runtime_error(error_msg); } } } // namespace Device::Device(const std::string& type, int64_t device_id, bool rematable) : type_(type), enum_type_(kInvalidDevice), device_id_(device_id), rematable_(rematable), hash_value_(Hash(type, device_id, rematable)) {} Maybe Device::Init() { if (type_ == "auto") { return Maybe::Ok(); } enum_type_ = JUST(DeviceType4DeviceTag(type())); { DeviceType dev_type = enum_type_; if (dev_type == kMockDevice) { dev_type = DeviceType::kCPU; } mem_case_ = memory::MakeMemCaseShared(enum_type_, device_id_); } return Maybe::Ok(); } /* static */ Maybe> Device::New(const std::string& type, int64_t device_id, bool rematable) { CHECK_GE_OR_RETURN(device_id, 0) << Error::InvalidValueError() << "Device ID should be non-negative"; static thread_local HashMap, Symbol> map; auto key = std::make_tuple(type, device_id, rematable); auto iter = map.find(key); if (iter == map.end()) { Device device(type, device_id, rematable); JUST(device.Init()); iter = map.emplace(key, SymbolOf(device)).first; } return iter->second; } /* static */ Maybe> Device::New(const std::string& type, int64_t device_id) { return New(type, device_id, false); } /* static */ Maybe> Device::New(const std::string& type) { return New(type, GlobalProcessCtx::LocalRank()); } /* static */ Maybe> Device::ParseAndNew(const std::string& device_str) { static thread_local HashMap> map; auto iter = map.find(device_str); if (iter == map.end()) { auto [type, device_id, rematable] = *JUST(ParseDeviceString(device_str)); CheckDeviceType(type); if (device_id == -1) { device_id = GlobalProcessCtx::LocalRank(); } Device device(type, device_id, rematable); JUST(device.Init()); iter = map.emplace(device_str, SymbolOf(device)).first; } return iter->second; } std::string Device::ToRepr() const { auto rematable_suffix = ""; if (rematable_) { rematable_suffix = ", rematable=True"; } return fmt::format("device(type='{}', index={}{})", type_, device_id_, rematable_suffix); } std::ostream& operator<<(std::ostream& os, Symbol device) { os << device->ToRepr(); return os; } std::string Device::ToString() const { auto rematable_suffix = ""; if (rematable_) { rematable_suffix = "+remat"; } return fmt::format("{}:{}{}", type_, device_id_, rematable_suffix); } Maybe> Device::MakeDeviceByParallelDesc(const ParallelDesc& parallel_desc) { const std::string& type = parallel_desc.device_tag(); std::vector machine_device_ids; machine_device_ids.reserve(parallel_desc.parallel_conf().device_name().size()); for (const auto& item : parallel_desc.parallel_conf().device_name()) { machine_device_ids.emplace_back(item); } CHECK_EQ_OR_RETURN(machine_device_ids.size(), 1) << Error::InvalidValueError() << "Number of machine device should be one"; const std::string& machine_device_id = 
machine_device_ids.at(0); size_t pos = machine_device_id.find(':'); CHECK_NE_OR_RETURN(pos, std::string::npos) << Error::InvalidValueError() << "Invalid device ID: " << machine_device_id; std::string device_id = machine_device_id.substr(pos + 1); CHECK_EQ_OR_RETURN(device_id.find('-'), std::string::npos) << Error::InvalidValueError() << "Device ID should be non-negative"; CHECK_OR_RETURN(IsStrInt(device_id)) << Error::InvalidValueError() << "Device ID is not integer: " << device_id; return Device::New(type, std::stoi(device_id)); } namespace { Maybe> RawGetPlacement(const Device& device) { std::string machine_device_id = "@" + std::to_string(GlobalProcessCtx::Rank()) + ":" + std::to_string(device.device_id()); ParallelConf parallel_conf; parallel_conf.set_device_tag(device.type()); parallel_conf.add_device_name(machine_device_id); return SymbolOf(ParallelDesc(parallel_conf)); } Maybe> RawPlacement4Device(Symbol device) { return RawGetPlacement(*device); } } // namespace decltype(Device::GetPlacement) Device::GetPlacement = DECORATE(&RawGetPlacement, ThreadLocalCopiable); decltype(Placement4Device) Placement4Device = DECORATE(&RawPlacement4Device, ThreadLocal); Maybe> ParseDeviceString(std::string device_str) { bool rematable = false; if (device_str.size() > 6 && device_str.substr(device_str.size() - 6, 6) == "+remat") { rematable = true; device_str = device_str.substr(0, device_str.size() - 6); } std::string::size_type pos = device_str.find(':'); if (pos == std::string::npos) { return std::make_tuple(device_str, -1, rematable); } else { std::string index_str = device_str.substr(pos + 1); CHECK_OR_RETURN(IsStrInt(index_str)) << Error::InvalidValueError() << "Invalid device tag " << device_str; return std::make_tuple(device_str.substr(0, pos), std::stoi(index_str), rematable); } } } // namespace oneflow ================================================ FILE: oneflow/core/framework/device.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
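ParseDeviceString above accepts strings like "cuda", "cuda:1" and "cuda:1+remat", peeling off the "+remat" suffix first and using -1 to mean "no index given". A standalone sketch of that parse, returning a plain std::tuple and omitting the integer validation the real code performs:

#include <cassert>
#include <string>
#include <tuple>

// Returns (type, index, rematable); index -1 means "not specified".
std::tuple<std::string, int, bool> ParseDevice(std::string s) {
  bool rematable = false;
  const std::string suffix = "+remat";
  if (s.size() > suffix.size()
      && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0) {
    rematable = true;
    s = s.substr(0, s.size() - suffix.size());
  }
  auto pos = s.find(':');
  if (pos == std::string::npos) { return {s, -1, rematable}; }
  // NOTE: the real code checks that the index is an integer before converting it.
  return {s.substr(0, pos), std::stoi(s.substr(pos + 1)), rematable};
}

int main() {
  assert(ParseDevice("cpu") == std::make_tuple(std::string("cpu"), -1, false));
  assert(ParseDevice("cuda:1") == std::make_tuple(std::string("cuda"), 1, false));
  assert(ParseDevice("cuda:1+remat") == std::make_tuple(std::string("cuda"), 1, true));
  return 0;
}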
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_DEVICE_H_ #define ONEFLOW_CORE_FRAMEWORK_DEVICE_H_ #include #include #include #include #include #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/optional.h" namespace oneflow { class ParallelDesc; class MemoryCase; inline size_t GetInstructionHighWaterMark() { return 40000; } inline size_t GetInstructionLowWaterMark() { return 20000; } class Device final { public: Device(const Device&) = default; Device(Device&&) = default; ~Device() = default; Device& operator=(const Device&) = delete; const std::string& type() const { return type_; } DeviceType enum_type() const { return enum_type_; } int64_t device_id() const { return device_id_; } bool rematable() const { return rematable_; } std::string ToString() const; std::string ToRepr() const; size_t hash_value() const { return hash_value_; } bool operator==(const Device& device) const { return type_ == device.type() && device_id_ == device.device_id() && rematable_ == device.rematable(); } bool operator!=(const Device& device) const { return !operator==(device); } const std::shared_ptr& mem_case() const { return mem_case_; } static Maybe> New(const std::string& type, int64_t device_id, bool rematable); static Maybe> New(const std::string& type, int64_t device_id); static Maybe> New(const std::string& type); static Maybe> ParseAndNew(const std::string& type_or_type_with_device_id); static Maybe> MakeDeviceByParallelDesc(const ParallelDesc& parallel_desc); static Maybe> (*GetPlacement)(const Device& device); private: Device(const std::string& type, int64_t device_id, bool rematable); Maybe Init(); const std::string type_; DeviceType enum_type_; const int64_t device_id_; bool rematable_; const size_t hash_value_; std::shared_ptr mem_case_; }; std::ostream& operator<<(std::ostream& os, Symbol device); extern Maybe> (*Placement4Device)(Symbol device); Maybe> ParseDeviceString(std::string device_str); } // namespace oneflow template<> struct fmt::formatter> : ostream_formatter {}; namespace std { template<> struct hash final { size_t operator()(const oneflow::Device& device) const { return device.hash_value(); } }; } // namespace std #endif // ONEFLOW_CORE_FRAMEWORK_DEVICE_H_ ================================================ FILE: oneflow/core/framework/dtype.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
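Device::New above interns devices in a thread_local map so equal (type, device_id, rematable) triples share one canonical object. A reduced sketch of that interning cache, with shared_ptr standing in for Symbol:

#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <tuple>

struct Device {
  std::string type;
  int64_t device_id;
  bool rematable;
};

// One cache per thread avoids locking on the hot path; equal keys always
// return the same canonical instance within a thread.
std::shared_ptr<const Device> GetDevice(const std::string& type, int64_t id, bool rematable) {
  using Key = std::tuple<std::string, int64_t, bool>;
  static thread_local std::map<Key, std::shared_ptr<const Device>> cache;
  Key key{type, id, rematable};
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache.emplace(key, std::make_shared<Device>(Device{type, id, rematable})).first;
  }
  return it->second;
}

int main() {
  auto a = GetDevice("cuda", 0, false);
  auto b = GetDevice("cuda", 0, false);
  assert(a == b);  // same pointer: interned
  assert(a != GetDevice("cuda", 1, false));
  return 0;
}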
*/ #include "half.hpp" #include "oneflow/core/common/util.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/data_type_seq.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/framework/dtype.h" namespace oneflow { namespace { template std::size_t GetDataTypeBytes() { return sizeof(T); } #define MAKE_DATA_TYPE_BYTES_SWITCH_ENTRY(func_name, T) func_name DEFINE_STATIC_SWITCH_FUNC( std::size_t, GetDataTypeBytes, MAKE_DATA_TYPE_BYTES_SWITCH_ENTRY, MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ)); class DTypeMeta final { public: DTypeMeta(const std::string& name, bool is_signed, bool is_integer, bool is_floating_point, bool is_complex) : name_(name), is_signed_(is_signed), is_integer_(is_integer), is_floating_point_(is_floating_point), is_complex_(is_complex) {} DTypeMeta(const DTypeMeta&) = default; DTypeMeta(DTypeMeta&) = default; ~DTypeMeta() = default; const std::string& name() const { return name_; } bool is_signed() const { return is_signed_; } bool is_integer() const { return is_integer_; } bool is_floating_point() const { return is_floating_point_; } bool is_complex() const { return is_complex_; } private: const std::string name_; const bool is_signed_; const bool is_integer_; const bool is_floating_point_; const bool is_complex_; }; Maybe DTypeMeta4DataType(DataType data_type) { static const HashMap data_type2dtype_meta{ {DataType::kInvalidDataType, DTypeMeta("oneflow.invalid_data_type", false, false, false, false)}, {DataType::kChar, DTypeMeta("oneflow.char", false, false, false, false)}, {DataType::kFloat16, DTypeMeta("oneflow.float16", true, false, true, false)}, {DataType::kFloat, DTypeMeta("oneflow.float32", true, false, true, false)}, {DataType::kDouble, DTypeMeta("oneflow.float64", true, false, true, false)}, {DataType::kInt8, DTypeMeta("oneflow.int8", true, true, false, false)}, {DataType::kInt16, DTypeMeta("oneflow.int16", true, true, false, false)}, {DataType::kInt32, DTypeMeta("oneflow.int32", true, true, false, false)}, {DataType::kInt64, DTypeMeta("oneflow.int64", true, true, false, false)}, {DataType::kInt128, DTypeMeta("oneflow.int128", true, true, false, false)}, {DataType::kUInt8, DTypeMeta("oneflow.uint8", false, true, false, false)}, {DataType::kUInt16, DTypeMeta("oneflow.uint16", false, true, false, false)}, {DataType::kUInt32, DTypeMeta("oneflow.uint32", false, true, false, false)}, {DataType::kUInt64, DTypeMeta("oneflow.uint64", false, true, false, false)}, {DataType::kUInt128, DTypeMeta("oneflow.uint128", false, true, false, false)}, {DataType::kOFRecord, DTypeMeta("oneflow.of_record", false, false, false, false)}, {DataType::kTensorBuffer, DTypeMeta("oneflow.tensor_buffer", false, false, false, false)}, {DataType::kBFloat16, DTypeMeta("oneflow.bfloat16", true, false, true, false)}, {DataType::kBool, DTypeMeta("oneflow.bool", false, false, false, false)}, {DataType::kComplex32, DTypeMeta("oneflow.complex32", false, false, false, true)}, {DataType::kComplex64, DTypeMeta("oneflow.complex64", false, false, false, true)}, {DataType::kComplex128, DTypeMeta("oneflow.complex128", false, false, false, true)}, }; return MapAt(data_type2dtype_meta, data_type); }; } // namespace Maybe&> DType::Get(DataType data_type) { static HashMap> data_type2dtype{ #define MAKE_ENTRY(data_type) {OF_PP_CAT(DataType::k, data_type), data_type()}, OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, DTYPE_SEQ) #undef MAKE_ENTRY }; return 
MapAt(data_type2dtype, data_type); } Maybe DType::bytes() const { // DataType::OFRecord and DataType::TensorBuffer don't have fixed byte size if (data_type() == DataType::kInvalidDataType || data_type() == DataType::kOFRecord || data_type() == DataType::kTensorBuffer) { OF_UNIMPLEMENTED(); } return SwitchGetDataTypeBytes(SwitchCase(data_type())); } bool DType::is_signed() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_signed(); } bool DType::is_complex() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_complex(); } /* The order of datatype is: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf b1 u4 u8 u16 i2 i16 cp4 cp8 cp16 The priority order of datatype is: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 iv < b1 < u1 < c1 < i1 < i2 < u4 < i4 < u8 < i8 < u16 < i16 < f2 < f4 < f8 < cp4 < cp8 < cp16 < bf < re < bu. */ const int DType::priority_order[DataType_ARRAYSIZE] = {0, /*kInvalid*/ 3, /*kChar*/ 14, /*kFloat32*/ 15, /*kDouble*/ 4, /*kInt8*/ 8, /*kInt32*/ 10, /*kInt64*/ 2, /*kUInt8*/ 20, /*kOFRecord*/ 13, /*kFloat16*/ 21, /*kTensorBuffer*/ 19, /*kBFloat16*/ 1, /*kBool*/ 5, /*kUint16*/ 7, /*kUint32*/ 9, /*kUint64*/ 11, /*kUint128*/ 6, /*kInt16*/ 12, /*kInt128*/ 16, /*kComplex32*/ 17, /*kComplex64*/ 18 /*kComplex128*/}; bool DType::is_integer() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_integer(); } bool DType::is_floating_point() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_floating_point(); } const std::string& DType::name() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).name(); } #define DEFINE_GET_DATA_TYPE_FUNCTION(data_type) \ const Symbol& DType::data_type() { \ static const auto& dtype = SymbolOf(DType(OF_PP_CAT(DataType::k, data_type))); \ return dtype; \ } OF_PP_FOR_EACH_TUPLE(DEFINE_GET_DATA_TYPE_FUNCTION, DTYPE_SEQ) #undef DEFINE_GET_DATA_TYPE_FUNCTION Symbol promoteTypes(const Symbol a, const Symbol b) { const Symbol iv = CHECK_JUST(DType::Get(DataType::kInvalidDataType)); const Symbol c1 = CHECK_JUST(DType::Get(DataType::kChar)); const Symbol f4 = CHECK_JUST(DType::Get(DataType::kFloat)); const Symbol f8 = CHECK_JUST(DType::Get(DataType::kDouble)); const Symbol i1 = CHECK_JUST(DType::Get(DataType::kInt8)); const Symbol i4 = CHECK_JUST(DType::Get(DataType::kInt32)); const Symbol i8 = CHECK_JUST(DType::Get(DataType::kInt64)); const Symbol u1 = CHECK_JUST(DType::Get(DataType::kUInt8)); const Symbol re = CHECK_JUST(DType::Get(DataType::kOFRecord)); const Symbol f2 = CHECK_JUST(DType::Get(DataType::kFloat16)); const Symbol bu = CHECK_JUST(DType::Get(DataType::kTensorBuffer)); const Symbol bf = CHECK_JUST(DType::Get(DataType::kBFloat16)); const Symbol b1 = CHECK_JUST(DType::Get(DataType::kBool)); const Symbol u2 = CHECK_JUST(DType::Get(DataType::kUInt16)); const Symbol u4 = CHECK_JUST(DType::Get(DataType::kUInt32)); const Symbol u8 = CHECK_JUST(DType::Get(DataType::kUInt64)); const Symbol u16 = CHECK_JUST(DType::Get(DataType::kUInt128)); const Symbol i2 = CHECK_JUST(DType::Get(DataType::kInt16)); const Symbol i16 = CHECK_JUST(DType::Get(DataType::kInt128)); const Symbol cp4 = CHECK_JUST(DType::Get(DataType::kComplex32)); const Symbol cp8 = CHECK_JUST(DType::Get(DataType::kComplex64)); const Symbol cp16 = CHECK_JUST(DType::Get(DataType::kComplex128)); /* It is consistent with data_type.proto(except kInvalidDataType, kOFRecord and kTensorBuffer) kInvalidDataType = 0; kChar = 1; kFloat = 2; kDouble = 3; kInt8 = 4; kInt32 = 5; kInt64 = 6; kUInt8 = 7; kOFRecord = 8; 
kFloat16 = 9; kTensorBuffer = 10; kBFloat16 = 11; kBool = 12; kUInt16 = 13; kUInt32 = 14; kUInt64 = 15; kUInt128 = 16; kInt16 = 17; kInt128 = 18; kComplex32 = 19; kComplex64 = 20; kComplex128 = 21; The priority order of datatype is: iv < b1 < u1 < c1 < i1 < u2 < i2 < u4 < i4 < u8 < i8 < u16 < i16 < f2 < f4 < f8 < cp4 < cp8 < cp16 < bf < re < bu. When int8 + uint8, it need to promote to int16, etc. But in int8 + uint128, we should promote to int256, but it is not exist, so we set as Invalid. The new DataType should be add in the end of proto, and the Loopup table should be maintained as right priority (author:zhengzekang). */ // clang-format off static const Symbol _promoteTypesLookup[DataType_ARRAYSIZE][DataType_ARRAYSIZE] = { /* iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf b1 u2 u4 u8 u16 i2 i16 cp4 cp8 cp16 */ /* iv */ {iv, c1, f4, f8, i1, i4, i8, u1, re, f2, bu, bf, b1, u2, u4, u8, u16, i2, i16, cp4, cp8, cp16}, /* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, iv, f2, iv, bf, c1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16}, /* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, iv, f4, iv, bf, f4, f4, f4, f4, f4, f4, f4, iv, cp8, cp16}, /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, iv, f8, iv, bf, f8, f8, f8, f8, f8, f8, f8, iv, cp8, cp16}, /* i1 */ {i1, i1, f4, f8, i1, i4, i8, i2, iv, f2, iv, bf, i1, i4, i8, i16, iv, i2, i16, iv, cp8, cp16}, /* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, iv, f2, iv, bf, i4, i4, i8, i16, iv, i4, i16, iv, cp8, cp16}, /* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, iv, f2, iv, bf, i8, i8, i8, i16, iv, i8, i16, iv, cp8, cp16}, /* u1 */ {u1, c1, f4, f8, i2, i4, i8, u1, iv, f2, iv, bf, u1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16}, /* re */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv}, /* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, iv, f2, iv, bf, f2, f2, f2, f2, iv, f2, f2, iv, cp8, cp16}, /* bu */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, bu, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv}, /* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, iv, bf, iv, bf, bf, bf, bf, bf, iv, bf, bf, iv, cp8, cp16}, /* b1 */ {b1, c1, f4, f8, i1, i4, i8, u1, iv, f2, iv, bf, b1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16}, /* u2 */ {u2, u2, f4, f8, i4, i4, i8, u2, iv, f2, iv, bf, u2, u2, u4, u8, u16, i4, i16, iv, cp8, cp16}, /* u4 */ {u4, u4, f4, f8, i8, i8, i8, u4, iv, f2, iv, bf, u4, u4, u4, u8, u16, i8, i16, iv, cp8, cp16}, /* u8 */ {u8, u8, f4, f8, i16, i16, i16, u8, iv, f2, iv, bf, u8, u8, u8, u8, u16, i16, i16, iv, cp8, cp16}, /* u16 */ {u16, u16, f4, f8, iv, iv, iv, u16, iv, f2, iv, bf, u16, u16, u16, u16, u16, iv, iv, iv, cp8, cp16}, /* i2 */ {i2, i2, f4, f8, i2, i4, i8, i2, iv, f2, iv, bf, i2, i4, i8, i16, iv, i2, i16, iv, cp8, cp16}, /* i16 */ {i16, i16, f4, f8, i16, i16, i16, i16, iv, f2, iv, bf, i16, i16, i16, i16, iv, i16, i16, iv, cp8, cp16}, /* cp4 */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, cp4, cp8, cp16}, /* cp8 */ {cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, iv, cp8, iv, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp16}, /* cp16 */ {cp16,cp16,cp16,cp16,cp16,cp16,cp16,cp16,iv, cp16,iv, cp16,cp16,cp16,cp16,cp16,cp16, cp16,cp16, cp16, cp16, cp16}}; // clang-format on return _promoteTypesLookup[static_cast(a->data_type())][static_cast(b->data_type())]; } namespace { std::mutex default_dtype_mutex; Symbol* GetMutDefaultDTypeSymbol() { static Symbol default_dtype = CHECK_JUST(DType::Get(DataType::kFloat)); return &default_dtype; } } // namespace Maybe SetDefaultDType(const Symbol& dtype) { std::lock_guard lock(default_dtype_mutex); 
CHECK_OR_RETURN(dtype->is_floating_point()) << "only floating-point types are supported as the default type"; *GetMutDefaultDTypeSymbol() = dtype; return Maybe::Ok(); } Symbol GetDefaultDType() { std::lock_guard lock(default_dtype_mutex); return *GetMutDefaultDTypeSymbol(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/dtype.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_DTYPE_H_ #define ONEFLOW_CORE_FRAMEWORK_DTYPE_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/symbol.h" namespace oneflow { #define DTYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(InvalidDataType) \ OF_PP_MAKE_TUPLE_SEQ(Bool) \ OF_PP_MAKE_TUPLE_SEQ(Char) \ OF_PP_MAKE_TUPLE_SEQ(Float16) \ OF_PP_MAKE_TUPLE_SEQ(Float) \ OF_PP_MAKE_TUPLE_SEQ(Double) \ OF_PP_MAKE_TUPLE_SEQ(Int8) \ OF_PP_MAKE_TUPLE_SEQ(Int32) \ OF_PP_MAKE_TUPLE_SEQ(Int64) \ OF_PP_MAKE_TUPLE_SEQ(UInt8) \ OF_PP_MAKE_TUPLE_SEQ(OFRecord) \ OF_PP_MAKE_TUPLE_SEQ(TensorBuffer) \ OF_PP_MAKE_TUPLE_SEQ(BFloat16) \ OF_PP_MAKE_TUPLE_SEQ(UInt16) \ OF_PP_MAKE_TUPLE_SEQ(UInt32) \ OF_PP_MAKE_TUPLE_SEQ(UInt64) \ OF_PP_MAKE_TUPLE_SEQ(UInt128) \ OF_PP_MAKE_TUPLE_SEQ(Int16) \ OF_PP_MAKE_TUPLE_SEQ(Int128) \ OF_PP_MAKE_TUPLE_SEQ(Complex32) \ OF_PP_MAKE_TUPLE_SEQ(Complex64) \ OF_PP_MAKE_TUPLE_SEQ(Complex128) class DType final { public: DType(const DType&) = default; DType(DType&&) = delete; explicit DType(DataType data_type) : data_type_(data_type) {} ~DType() = default; bool operator==(const DType& other) const { return this->data_type() == other.data_type(); } DataType data_type() const { return data_type_; } bool is_signed() const; bool is_complex() const; bool is_integer() const; bool is_floating_point() const; const std::string& name() const; Maybe bytes() const; static Maybe&> Get(DataType); static const int priority_order[DataType_ARRAYSIZE]; #define DECLARE_GET_DATA_TYPE_FUNCTION(data_type) static const Symbol& data_type(); OF_PP_FOR_EACH_TUPLE(DECLARE_GET_DATA_TYPE_FUNCTION, DTYPE_SEQ) #undef DECLARE_GET_DATA_TYPE_FUNCTION private: DataType data_type_; }; Symbol promoteTypes(const Symbol a, const Symbol b); Maybe SetDefaultDType(const Symbol& dtype); Symbol GetDefaultDType(); } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::DType& dtype) const { return static_cast(dtype.data_type()); } }; } // namespace std #endif // ONEFLOW_CORE_FRAMEWORK_DTYPE_H_ ================================================ FILE: oneflow/core/framework/eager_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
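// -----------------------------------------------------------------------------
// Illustrative sketch (hypothetical call sites, not taken from the sources): a
// few promotions read directly off the lookup table above, using the static
// DType accessors generated from DTYPE_SEQ, plus the floating-point-only guard
// enforced by SetDefaultDType.
// -----------------------------------------------------------------------------
// promoteTypes(DType::Int8(),    DType::UInt8())  ->  int16   (no common signed 8-bit type, so widen)
// promoteTypes(DType::Float16(), DType::Int32())  ->  float16 (floating point outranks any integer)
// promoteTypes(DType::Bool(),    DType::Int8())   ->  int8    (bool defers to the other operand)
// SetDefaultDType(DType::Double())  ->  ok, Double is floating point
// SetDefaultDType(DType::Int32())   ->  error: "only floating-point types are supported as the default type"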
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_VM_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_VM_UTIL_H_ #endif // ONEFLOW_CORE_FRAMEWORK_VM_UTIL_H_ ================================================ FILE: oneflow/core/framework/framework.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_FRAMEWORK_H_ #define ONEFLOW_CORE_FRAMEWORK_FRAMEWORK_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/util.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/sbp_context.h" #include "oneflow/core/framework/infer_output_blob_time_shape_fn_context.h" #include "oneflow/core/framework/infer_nd_sbp_fn_context.h" #include "oneflow/core/framework/compute_complexity_fn_context.h" #include "oneflow/core/framework/get_nd_sbp_signature_list_context.h" #include "oneflow/core/framework/user_op_hob.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/multi_thread.h" #include "oneflow/core/framework/to_string.h" #endif // ONEFLOW_CORE_FRAMEWORK_FRAMEWORK_H_ ================================================ FILE: oneflow/core/framework/get_nd_sbp_signature_list_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_COMPUTE_GET_ND_SBP_SIGNATURE_LIST_CONTEXT_H_ #define ONEFLOW_CORE_FRAMEWORK_COMPUTE_GET_ND_SBP_SIGNATURE_LIST_CONTEXT_H_ #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { class Shape; namespace user_op { class UserOpDefWrapper; class GetNdSbpSignatureListContext { public: virtual ~GetNdSbpSignatureListContext() = default; virtual void AddNdSbpSignature(NdSbpSignature&) = 0; virtual std::vector* MutNdSbpSignatureList() = 0; virtual const Shape& parallel_hierarchy() = 0; virtual const Shape& BlobShape4InputArgNameAndIndex(const std::string& arg_name, int32_t index) const = 0; template T Attr(const std::string& attr_name) const { return conf_.attr(attr_name); } const UserOpConfWrapper& user_op_conf() const { return conf_; } protected: explicit GetNdSbpSignatureListContext(UserOpConfWrapper&& conf) : conf_(std::move(conf)) {} private: UserOpConfWrapper conf_; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_COMPUTE_GET_ND_SBP_SIGNATURE_LIST_CONTEXT_H_ ================================================ FILE: oneflow/core/framework/global_param_grad_sync_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/global_param_grad_sync_mode.h" namespace oneflow { namespace { bool* GetThreadLocalGradSyncMode() { static thread_local bool g_grad_mode = true; return &g_grad_mode; } } // namespace bool GlobalGradSyncMode::is_enabled() { return *GetThreadLocalGradSyncMode(); } void GlobalGradSyncMode::set_enabled(bool enabled) { *GetThreadLocalGradSyncMode() = enabled; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/global_param_grad_sync_mode.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
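// -----------------------------------------------------------------------------
// A minimal usage sketch (hypothetical caller, not taken from the sources): the
// thread-local switch defined in global_param_grad_sync_mode.cpp above defaults
// to enabled, and the header that follows wraps it in the RAII guard
// GlobalParamGradSyncMode, which records the previous value in its constructor
// and restores it in its destructor, so disabling gradient sync stays scoped.
// -----------------------------------------------------------------------------
void DisableGradSyncForScope() {  // hypothetical function, for illustration only
  oneflow::GlobalParamGradSyncMode guard(/*enabled=*/false);
  // oneflow::GlobalGradSyncMode::is_enabled() == false until `guard` is destroyed
}  // the destructor restores the previous thread-local value (true by default)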
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_GLOBAL_PARAM_GRAD_SYNC_MODE_ #define ONEFLOW_CORE_FRAMEWORK_GLOBAL_PARAM_GRAD_SYNC_MODE_ namespace oneflow { struct GlobalGradSyncMode { static bool is_enabled(); static void set_enabled(bool enabled); }; class GlobalParamGradSyncMode { public: GlobalParamGradSyncMode(bool enabled) : prev_mode_(GlobalGradSyncMode::is_enabled()) { GlobalGradSyncMode::set_enabled(enabled); } ~GlobalParamGradSyncMode() { GlobalGradSyncMode::set_enabled(prev_mode_); } bool prev_mode() const { return prev_mode_; } private: bool prev_mode_; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_GLOBAL_PARAM_GRAD_SYNC_MODE_ ================================================ FILE: oneflow/core/framework/global_tensor_infer_cache.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/global_tensor_infer_cache.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/env_var/eager.h" namespace oneflow { namespace one { namespace { bool OptionalEqual(const Optional>& lhs, const Optional>& rhs) { if (lhs.has_value() != rhs.has_value()) { return false; } if (!lhs.has_value()) { return true; } return CHECK_JUST(lhs) == CHECK_JUST(rhs); } } // namespace size_t InputGlobalTensorMeta::hash_value() const { size_t hash_value = std::hash>()(tensor_meta()); if (consumer_nd_sbp_constraint().has_value()) { AddHash(&hash_value, CHECK_JUST(consumer_nd_sbp_constraint())); } return hash_value; } bool InputGlobalTensorMeta::operator==(const InputGlobalTensorMeta& other) const { return this->tensor_meta() == other.tensor_meta() && OptionalEqual(this->consumer_nd_sbp_constraint(), other.consumer_nd_sbp_constraint()); } void InputGlobalTensorMeta::assign(Symbol tensor_meta, const Optional>& consumer_nd_sbp_constraint) { tensor_meta_ = tensor_meta; consumer_nd_sbp_constraint_ = consumer_nd_sbp_constraint; } size_t GlobalTensorMetaInferArgs::hash_value() const { size_t hash_value = std::hash()(attrs_); const auto& tensor_meta_hash_functor = std::hash(); for (const auto& tensor_meta : input_global_tensor_metas_) { HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta)); } return hash_value; } size_t SrcOpGlobalTensorMetaInferArgs::hash_value() const { size_t hash_value = std::hash()(attrs_); AddHash(&hash_value, parallel_desc_); AddHash(&hash_value, nd_sbp_); return hash_value; } bool GlobalTensorMetaInferArgs::operator==(const GlobalTensorMetaInferArgs& other) const { return this->attrs_ == other.attrs_ && this->input_global_tensor_metas_ == other.input_global_tensor_metas_; } bool SrcOpGlobalTensorMetaInferArgs::operator==(const SrcOpGlobalTensorMetaInferArgs& other) const { return 
this->attrs_ == other.attrs_ && this->parallel_desc_ == other.parallel_desc_ && this->nd_sbp_ == other.nd_sbp_; } Maybe GlobalTensorMetaInferArgs::MakeNdSbpConstraints( const UserOpExpr& user_op_expr, NdSbpSignature* nd_sbp_signature) const { const auto& input_arg_tuple = *user_op_expr.input_arg_tuple(); auto* map = nd_sbp_signature->mutable_bn_in_op2nd_sbp(); for (int i = 0; i < input_arg_tuple.size(); ++i) { const auto& constaint = input_global_tensor_metas_[i].consumer_nd_sbp_constraint(); if (constaint.has_value()) { (*map)[input_arg_tuple.indexed_bns().at(i)] = *JUST(constaint); } } return Maybe::Ok(); } Maybe GlobalTensorMetaInferArgs::MakeInputBlobDescs(const UserOpExpr& user_op_expr, std::vector* blob_descs) const { CHECK_OR_RETURN(blob_descs->empty()); const auto& input_arg_tuple = *user_op_expr.input_arg_tuple(); blob_descs->reserve(input_arg_tuple.size()); for (int i = 0; i < input_arg_tuple.size(); ++i) { const auto& tensor_meta = *input_global_tensor_metas_[i].tensor_meta(); blob_descs->emplace_back(tensor_meta.shape(), tensor_meta.stride(), tensor_meta.data_type(), tensor_meta.memory_format()); } return Maybe::Ok(); } Maybe GlobalTensorMetaInferArgs::MakeNdSbpInferHints( const UserOpExpr& user_op_expr, const std::vector& blob_descs, std::vector* hints) const { CHECK_OR_RETURN(hints->empty()); const auto& input_arg_tuple = *user_op_expr.input_arg_tuple(); hints->reserve(input_arg_tuple.size()); for (int i = 0; i < input_arg_tuple.size(); ++i) { const auto& tensor_meta = *input_global_tensor_metas_[i].tensor_meta(); const auto* parallel_desc = &*tensor_meta.parallel_desc(); const auto* blob_desc = &blob_descs.at(i); const auto* nd_sbp = &*tensor_meta.nd_sbp(); hints->emplace_back(parallel_desc, blob_desc, nd_sbp); } return Maybe::Ok(); } Maybe GlobalTensorMetaInferArgs::New(const AttrMap& attrs, const TensorTuple& input_tensors) { std::shared_ptr infer_args(new GlobalTensorMetaInferArgs()); infer_args->attrs_ = attrs; infer_args->input_global_tensor_metas_.resize(input_tensors.size()); JUST(infer_args->InitInputGlobalTensorMetas(input_tensors)); return infer_args; } Maybe SrcOpGlobalTensorMetaInferArgs::New( const AttrMap& attrs, Symbol parallel_desc, Symbol nd_sbp) { std::shared_ptr infer_args(new SrcOpGlobalTensorMetaInferArgs()); infer_args->attrs_ = attrs; infer_args->parallel_desc_ = parallel_desc; infer_args->nd_sbp_ = nd_sbp; return infer_args; } Maybe GlobalTensorMetaInferArgs::InitInputGlobalTensorMetas( const TensorTuple& input_tensors) { for (int i = 0; i < input_tensors.size(); ++i) { const auto& tensor = *input_tensors.at(i); const auto& tensor_meta = JUST(tensor.global_tensor_meta()); const auto& constraint = JUST(tensor.consumer_nd_sbp_constraint()); input_global_tensor_metas_[i].assign(tensor_meta, constraint); } return Maybe::Ok(); } namespace { Maybe MakeOp(const UserOpExpr& user_op_expr, const AttrMap& attrs, const std::string& device_tag) { OperatorConf op_conf; JUST(user_op_expr.BuildOpConf(&op_conf, attrs)); DeviceType device_type = JUST(DeviceType4DeviceTag(device_tag)); return JUST(ConstructOp(op_conf, device_type)); } Maybe CheckInputParallelDescIdentical(const GlobalTensorMetaInferArgs& infer_args, const UserOpExpr& user_op_expr) { if (infer_args.input_global_tensor_metas().empty()) { return Maybe::Ok(); } Symbol default_parallel_desc; for (int i = 0; i < infer_args.input_global_tensor_metas().size(); ++i) { if (user_op_expr.IsHostMemoryInput(i)) { continue; } default_parallel_desc = JUST(VectorAt(infer_args.input_global_tensor_metas(), 
i)).tensor_meta()->parallel_desc(); break; } for (int i = 0; i < infer_args.input_global_tensor_metas().size(); ++i) { if (user_op_expr.IsHostMemoryInput(i)) { continue; } CHECK_OR_RETURN( default_parallel_desc == JUST(VectorAt(infer_args.input_global_tensor_metas(), i)).tensor_meta()->parallel_desc()) << Error::RuntimeError() << "Expected all tensors to be on the same placement, but found " "at least two placements, " << *JUST(PlacementToString(default_parallel_desc)) << " (positional 0) and " << *JUST(PlacementToString(JUST(VectorAt(infer_args.input_global_tensor_metas(), i)) .tensor_meta() ->parallel_desc())) << " (positional " << i << ")!"; } return Maybe::Ok(); } Maybe CheckIsDeviceSupportedByOp(const ParallelDesc& parallel_desc, const std::string& op_type_name) { if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(parallel_desc.device_tag(), "cpu"); } return Maybe::Ok(); } class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { public: UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const GlobalTensorMetaInferArgs* infer_args) : user_op_expr_(user_op_expr), composed_attrs_(infer_args->attrs(), user_op_expr->base_attrs()), in_tensor_devices_(user_op_expr_->input_size()), out_tensor_devices_(user_op_expr_->output_size()) { for (int i = 0; i < user_op_expr_->input_size(); ++i) { const auto& parallel_desc = infer_args->input_global_tensor_metas().at(i).tensor_meta()->parallel_desc(); in_tensor_devices_.at(i) = CHECK_JUST(GetTensorDevice(parallel_desc)); } } const std::vector>& inputs() const override { return user_op_expr_->indexed_input_pairs(); } const std::vector>& outputs() const override { return user_op_expr_->indexed_output_pairs(); } Symbol* OutputTensorDevice4ArgNameAndIndex(const std::string& name, int64_t index) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); CHECK_LT(tuple_index, user_op_expr_->output_size()); return &out_tensor_devices_.at(tuple_index); } Symbol InputTensorDevice4ArgNameAndIndex(const std::string& name, int64_t index) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); CHECK_LT(tuple_index, user_op_expr_->input_size()); return in_tensor_devices_.at(tuple_index); } private: const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return composed_attrs_.Attr4Name(attr_name); } const UserOpExpr* user_op_expr_; const ComposedAttrMap composed_attrs_; std::vector> in_tensor_devices_; std::vector> out_tensor_devices_; }; } // namespace /* static */ Maybe> GlobalTensorInferCache::InferDeviceAndStream( const UserOpExpr& user_op_expr, const GlobalTensorMetaInferArgs& infer_args) { if (!user_op_expr.device_and_stream_infer_fn()) { Symbol parallel_desc = infer_args.input_global_tensor_metas()[0].tensor_meta()->parallel_desc(); return GetDefaultStreamByPlacement(parallel_desc); } else { UserOpExprDeviceAndStreamInferContext device_and_stream_ctx(&user_op_expr, &infer_args); return TRY(user_op_expr.device_and_stream_infer_fn()(&device_and_stream_ctx)); } } /* static */ Maybe GlobalTensorInferCache::Infer( const UserOpExpr& user_op_expr, const GlobalTensorMetaInferArgs& infer_args) { CHECK_GT_OR_RETURN(infer_args.input_global_tensor_metas().size(), 0); // NOLINT Symbol parallel_desc = 
infer_args.input_global_tensor_metas()[0].tensor_meta()->parallel_desc(); JUST(CheckInputParallelDescIdentical(infer_args, user_op_expr)); JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name())); std::vector output_mut_metas(user_op_expr.output_size()); { // Infer OpArgMutGlobalTensorMeta. const auto& input_metas = infer_args.input_global_tensor_metas(); JUST(user_op_expr.InferLogicalTensorDesc( infer_args.attrs(), parallel_desc, [&](int32_t i) { return &*input_metas.at(i).tensor_meta(); }, [&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); })); } const auto& op = JUST(MakeOp(user_op_expr, infer_args.attrs(), parallel_desc->device_tag())); JUST(op->FillOpParallelDesc(parallel_desc.shared_from_symbol())); JUST(op->InferParallelSignatureIf()); { // Infer parallel distribution. NdSbpSignature nd_sbp_constraints; JUST(infer_args.MakeNdSbpConstraints(user_op_expr, &nd_sbp_constraints)); std::vector blob_descs; JUST(infer_args.MakeInputBlobDescs(user_op_expr, &blob_descs)); std::vector pd_infer_hints; JUST(infer_args.MakeNdSbpInferHints(user_op_expr, blob_descs, &pd_infer_hints)); const auto& input_arg_tuple = *user_op_expr.input_arg_tuple(); const auto& NdSbpInferHint4Ibn = [&](const std::string& ibn) -> Maybe { int32_t input_index = input_arg_tuple.bn_in_op2tensor_tuple_index().at(ibn); CHECK_GE_OR_RETURN(input_index, 0); CHECK_LT_OR_RETURN(input_index, pd_infer_hints.size()); return &pd_infer_hints.at(input_index); }; // The inferred results can be retrieved by op->NdSbp4BnInOp(obn). JUST(op->InferNdSbpSignatureIf(nd_sbp_constraints, *parallel_desc, NdSbpInferHint4Ibn)); } auto result = std::make_unique(user_op_expr.input_size(), user_op_expr.output_size()); auto* input_metas = result->mut_input_tensor_metas(); for (int32_t i = 0; i < user_op_expr.input_size(); ++i) { const auto& old_global_tensor_meta = infer_args.input_global_tensor_metas()[i].tensor_meta(); const auto& ibn = user_op_expr.input_arg_tuple()->indexed_bns().at(i); const auto& nd_sbp = SymbolOf(*JUST(op->NdSbp4BnInOp(ibn))); GlobalTensorMeta global_tensor_meta( old_global_tensor_meta->shape(), old_global_tensor_meta->dtype(), old_global_tensor_meta->memory_format(), nd_sbp, old_global_tensor_meta->parallel_desc()); (*input_metas)[i] = SymbolOf(global_tensor_meta); } auto* output_metas = result->mut_output_tensor_metas(); for (int32_t i = 0; i < user_op_expr.output_size(); ++i) { const auto& output_mut_meta = output_mut_metas.at(i); const auto& shape = output_mut_meta.tensor_meta().shape(); DataType data_type = output_mut_meta.tensor_meta().data_type(); MemoryFormat memory_format = output_mut_meta.tensor_meta().memory_format(); const auto& obn = user_op_expr.output_arg_tuple()->indexed_bns().at(i); const auto& nd_sbp = SymbolOf(*JUST(op->NdSbp4BnInOp(obn))); GlobalTensorMeta tensor_meta(shape, data_type, memory_format, nd_sbp, parallel_desc); output_metas->at(i) = SymbolOf(tensor_meta); } result->set_stream(JUST(InferDeviceAndStream(user_op_expr, infer_args))); return std::shared_ptr(std::move(result)); } /* static */ Maybe GlobalTensorInferCache::Infer( const UserOpExpr& user_op_expr, const SrcOpGlobalTensorMetaInferArgs& infer_args) { Symbol parallel_desc = infer_args.parallel_desc(); JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name())); std::vector output_mut_metas(user_op_expr.output_size()); { // Infer OpArgMutGlobalTensorMeta. 
const auto& GetInputTensorMeta = [](int32_t i) { UNIMPLEMENTED(); return nullptr; }; JUST(user_op_expr.InferLogicalTensorDesc( infer_args.attrs(), parallel_desc, GetInputTensorMeta, [&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); })); } auto result = std::make_unique(user_op_expr.input_size(), user_op_expr.output_size()); auto* output_metas = result->mut_output_tensor_metas(); for (int32_t i = 0; i < user_op_expr.output_size(); ++i) { const auto& output_mut_meta = output_mut_metas.at(i); const auto& shape = output_mut_meta.tensor_meta().shape(); DataType data_type = output_mut_meta.tensor_meta().data_type(); MemoryFormat memory_format = output_mut_meta.tensor_meta().memory_format(); const auto& nd_sbp = infer_args.nd_sbp(); GlobalTensorMeta tensor_meta(shape, data_type, memory_format, nd_sbp, parallel_desc); output_metas->at(i) = SymbolOf(tensor_meta); } result->set_stream(JUST(GetDefaultStreamByPlacement(parallel_desc))); return std::shared_ptr(std::move(result)); } Maybe GlobalTensorInferCache::GetOrInfer( const GlobalTensorMetaInferArgs& infer_args) { auto iter = cache_.find(infer_args); if (iter == cache_.end()) { if (unlikely(cache_.size() >= ThreadLocalEnvInteger())) { cache_.clear(); } const auto& user_op_expr = user_op_expr_.lock(); CHECK_OR_RETURN(static_cast(user_op_expr)); const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args)); iter = cache_.emplace(infer_args, output_tensor_metas).first; } return iter->second; } Maybe GlobalTensorInferCache::GetOrInfer( const SrcOpGlobalTensorMetaInferArgs& infer_args) { auto iter = src_op_cache_.find(infer_args); if (iter == src_op_cache_.end()) { if (unlikely(src_op_cache_.size() >= ThreadLocalEnvInteger())) { src_op_cache_.clear(); } const auto& user_op_expr = user_op_expr_.lock(); CHECK_OR_RETURN(static_cast(user_op_expr)); const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args)); iter = src_op_cache_.emplace(infer_args, output_tensor_metas).first; } return iter->second; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/global_tensor_infer_cache.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_GLOBAL_TENSOR_INFER_CACHE_H_ #define ONEFLOW_CORE_FRAMEWORK_GLOBAL_TENSOR_INFER_CACHE_H_ #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/job/nd_sbp_infer_hint.h" namespace oneflow { class NdSbp; class ParallelDesc; namespace one { class GlobalTensorMeta; class InputGlobalTensorMeta final { public: InputGlobalTensorMeta() : tensor_meta_(), consumer_nd_sbp_constraint_() {} InputGlobalTensorMeta(Symbol tensor_meta, const Optional>& consumer_nd_sbp_constraint) : tensor_meta_(tensor_meta), consumer_nd_sbp_constraint_(consumer_nd_sbp_constraint) {} InputGlobalTensorMeta(const InputGlobalTensorMeta&) = default; InputGlobalTensorMeta(InputGlobalTensorMeta&&) = default; ~InputGlobalTensorMeta() = default; size_t hash_value() const; bool operator==(const InputGlobalTensorMeta& other) const; Symbol tensor_meta() const { return tensor_meta_; } const Optional>& consumer_nd_sbp_constraint() const { return consumer_nd_sbp_constraint_; } void assign(Symbol tensor_meta, const Optional>& consumer_nd_sbp_constraint); private: Symbol tensor_meta_; Optional> consumer_nd_sbp_constraint_; }; class TensorTuple; class UserOpExpr; class GlobalTensorMetaInferArgs final { public: GlobalTensorMetaInferArgs(const GlobalTensorMetaInferArgs&) = default; GlobalTensorMetaInferArgs(GlobalTensorMetaInferArgs&&) = default; ~GlobalTensorMetaInferArgs() = default; const std::vector& input_global_tensor_metas() const { return input_global_tensor_metas_; } const AttrMap& attrs() const { return attrs_; } size_t hash_value() const; bool operator==(const GlobalTensorMetaInferArgs& other) const; Maybe MakeNdSbpConstraints(const UserOpExpr& user_op_expr, NdSbpSignature* nd_sbp_signature) const; Maybe MakeInputBlobDescs(const UserOpExpr& user_op_expr, std::vector* blob_descs) const; Maybe MakeNdSbpInferHints(const UserOpExpr& user_op_expr, const std::vector& blob_descs, std::vector* hints) const; static Maybe New(const AttrMap& attrs, const TensorTuple& input_tensors); private: GlobalTensorMetaInferArgs() = default; Maybe InitInputGlobalTensorMetas(const TensorTuple& input_tensors); AttrMap attrs_; std::vector input_global_tensor_metas_; }; class SrcOpGlobalTensorMetaInferArgs final { public: SrcOpGlobalTensorMetaInferArgs(const SrcOpGlobalTensorMetaInferArgs&) = default; SrcOpGlobalTensorMetaInferArgs(SrcOpGlobalTensorMetaInferArgs&&) = default; ~SrcOpGlobalTensorMetaInferArgs() = default; Symbol parallel_desc() const { return parallel_desc_; } Symbol nd_sbp() const { return nd_sbp_; } const AttrMap& attrs() const { return attrs_; } size_t hash_value() const; bool operator==(const SrcOpGlobalTensorMetaInferArgs& other) const; static Maybe New(const AttrMap& attrs, Symbol parallel_desc, Symbol nd_sbp); private: SrcOpGlobalTensorMetaInferArgs() = default; AttrMap attrs_; Symbol parallel_desc_; Symbol nd_sbp_; }; class OpArgMutGlobalTensorMeta final { public: OpArgMutGlobalTensorMeta() : tensor_meta_(std::make_shared(), DataType::kInvalidDataType, MemoryFormat::kContiguous) {} OpArgMutGlobalTensorMeta(const OpArgMutGlobalTensorMeta&) = default; OpArgMutGlobalTensorMeta(OpArgMutGlobalTensorMeta&&) = default; ~OpArgMutGlobalTensorMeta() = default; const TensorMeta& tensor_meta() const { return 
tensor_meta_; } TensorMeta* mut_tensor_meta() { return &tensor_meta_; } private: MutTensorMeta tensor_meta_; }; } // namespace one } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::one::InputGlobalTensorMeta& val) const { return val.hash_value(); } }; template<> struct hash final { size_t operator()(const oneflow::one::GlobalTensorMetaInferArgs& val) const { return val.hash_value(); } }; template<> struct hash final { size_t operator()(const oneflow::one::SrcOpGlobalTensorMetaInferArgs& val) const { return val.hash_value(); } }; } // namespace std namespace oneflow { namespace one { class GlobalTensorInferResult final { public: GlobalTensorInferResult(size_t input_size, size_t output_size) : input_tensor_metas_(input_size), output_tensor_metas_(output_size) {} GlobalTensorInferResult(const GlobalTensorInferResult&) = delete; GlobalTensorInferResult(GlobalTensorInferResult&&) = delete; ~GlobalTensorInferResult() = default; const std::vector>& input_tensor_metas() const { return input_tensor_metas_; } const std::vector>& output_tensor_metas() const { return output_tensor_metas_; } std::vector>* mut_input_tensor_metas() { return &input_tensor_metas_; } std::vector>* mut_output_tensor_metas() { return &output_tensor_metas_; } const Symbol& stream() const { return stream_; } void set_stream(const Symbol& stream) { stream_ = stream; } private: std::vector> input_tensor_metas_; std::vector> output_tensor_metas_; Symbol stream_; }; class GlobalTensorInferCache final { public: GlobalTensorInferCache(const std::shared_ptr& user_op_expr) : user_op_expr_(user_op_expr) {} Maybe GetOrInfer(const GlobalTensorMetaInferArgs& infer_args); static Maybe Infer(const UserOpExpr& user_op_expr, const GlobalTensorMetaInferArgs& infer_args); Maybe GetOrInfer(const SrcOpGlobalTensorMetaInferArgs& infer_args); static Maybe Infer( const UserOpExpr& user_op_expr, const SrcOpGlobalTensorMetaInferArgs& infer_args); private: static Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, const GlobalTensorMetaInferArgs& infer_args); std::weak_ptr user_op_expr_; HashMap> cache_; HashMap> src_op_cache_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_GLOBAL_TENSOR_INFER_CACHE_H_ ================================================ FILE: oneflow/core/framework/id_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/id_util.h" namespace oneflow { Maybe UniqueStr(const std::string& prefix) { return prefix + NewUniqueId(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/id_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
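// -----------------------------------------------------------------------------
// A minimal stand-alone sketch (hypothetical names, not taken from the sources)
// of the bounded memoization pattern that GlobalTensorInferCache::GetOrInfer
// above implements: look the key up, and on a miss clear the whole cache once
// it has reached a configured limit before computing and inserting the new
// entry. Standard headers (<unordered_map>, <functional>) are assumed.
// -----------------------------------------------------------------------------
template<typename Key, typename Result>
const Result& GetOrCompute(std::unordered_map<Key, Result>* cache, const Key& key,
                           const std::function<Result(const Key&)>& compute, size_t limit) {
  auto it = cache->find(key);
  if (it == cache->end()) {
    if (cache->size() >= limit) { cache->clear(); }  // same wholesale eviction as GetOrInfer
    it = cache->emplace(key, compute(key)).first;
  }
  return it->second;
}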
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_ID_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_ID_UTIL_H_ #include #include "oneflow/core/common/maybe.h" namespace oneflow { Maybe UniqueStr(const std::string& prefix); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_ID_UTIL_H_ ================================================ FILE: oneflow/core/framework/infer_nd_sbp_fn_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_INFER_ND_SBP_FN_CONTEXT_H_ #define ONEFLOW_CORE_FRAMEWORK_INFER_ND_SBP_FN_CONTEXT_H_ #include "oneflow/core/framework/user_op_conf.h" namespace oneflow { namespace user_op { class InferNdSbpFnContext { public: InferNdSbpFnContext() = default; virtual ~InferNdSbpFnContext() = default; InferNdSbpFnContext(const InferNdSbpFnContext&) = delete; virtual const TensorDesc& LogicalTensorDesc4InputArgNameAndIndex( const std::string& input_arg_name, int32_t index) const = 0; virtual NdSbp* NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) = 0; virtual const NdSbp& NdSbpHint4InputArgNameAndIndex(const std::string& arg_name, int32_t index) const = 0; virtual const NdSbpSignature& nd_sbp_constraints() const = 0; virtual const UserOpConfWrapper& user_op_conf() const = 0; virtual int64_t parallel_num() const = 0; virtual const Shape& parallel_hierarchy() = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_INFER_ND_SBP_FN_CONTEXT_H_ ================================================ FILE: oneflow/core/framework/infer_output_blob_time_shape_fn_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_INFER_OUTPUT_BLOB_TIME_SHAPE_FN_CONTEXT_H_ #define ONEFLOW_CORE_FRAMEWORK_INFER_OUTPUT_BLOB_TIME_SHAPE_FN_CONTEXT_H_ #include "oneflow/core/framework/user_op_conf.h" namespace oneflow { namespace user_op { class InferOutputBlobTimeShapeFnContext { public: InferOutputBlobTimeShapeFnContext() = default; virtual ~InferOutputBlobTimeShapeFnContext() = default; InferOutputBlobTimeShapeFnContext(const InferOutputBlobTimeShapeFnContext&) = delete; virtual const Shape& TimeShape4InputArgNameAndIndex(const std::string& arg_name, int32_t index) = 0; virtual const UserOpConfWrapper& user_op_conf() const = 0; virtual Shape* mut_output_blob_time_shape() = 0; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_INFER_OUTPUT_BLOB_TIME_SHAPE_FN_CONTEXT_H_ ================================================ FILE: oneflow/core/framework/infer_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/user_op_def.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/attr_value_accessor.h" namespace oneflow { namespace user_op { Maybe TensorDescInferFnUtil::Unchanged(InferContext* ctx) { const TensorDesc* first_tensor_desc = nullptr; for (size_t i = 0; i < ctx->inputs().size(); ++i) { const std::pair& input_arg = ctx->inputs().at(i); if (first_tensor_desc) { const TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second); CHECK_EQ_OR_RETURN(tensor_desc.shape(), first_tensor_desc->shape()) << Error::RuntimeError() << "Tensor descriptions should have the same shape: expected " << first_tensor_desc->shape() << " but got " << tensor_desc.shape(); } else { first_tensor_desc = &ctx->InputTensorDesc(input_arg.first, input_arg.second); } } for (size_t i = 0; i < ctx->outputs().size(); ++i) { const std::pair& output_arg = ctx->outputs().at(i); ctx->SetOutputIsDynamic(output_arg.first, output_arg.second, // NOLINT first_tensor_desc->is_dynamic()); // NOLINT ctx->SetOutputShape(output_arg.first, output_arg.second, first_tensor_desc->shape()); // NOLINT } return Maybe::Ok(); } Maybe TensorDescInferFnUtil::UnchangedDataType(InferContext* ctx) { const TensorDesc* first_tensor_desc = nullptr; for (size_t i = 0; i < ctx->inputs().size(); ++i) { const std::pair& input_arg = ctx->inputs().at(i); if (first_tensor_desc) { const TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second); CHECK_EQ_OR_RETURN(tensor_desc.data_type(), first_tensor_desc->data_type()) << Error::TypeError() << "Tensor descriptions should have the same type. 
Expected " << DataType_Name(first_tensor_desc->data_type()) << ", but got " << DataType_Name(tensor_desc.data_type()); } else { first_tensor_desc = &ctx->InputTensorDesc(input_arg.first, input_arg.second); } } for (size_t i = 0; i < ctx->outputs().size(); ++i) { const std::pair& output_arg = ctx->outputs().at(i); ctx->SetOutputDType(output_arg.first, output_arg.second, // NOLINT first_tensor_desc->data_type()); // NOLINT } return Maybe::Ok(); } Maybe TensorDescInferFnUtil::InOutCorrespond(InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->inputs().size(), ctx->outputs().size()) << Error::InvalidValueError() << "Different input and output size. Input size :" << ctx->inputs().size() << ", output size: " << ctx->outputs().size(); for (size_t i = 0; i < ctx->inputs().size(); ++i) { const auto& input_arg = ctx->inputs().at(i); const auto& output_arg = ctx->outputs().at(i); *ctx->MutOutputTensorDesc(output_arg.first, output_arg.second) = ctx->InputTensorDesc(input_arg.first, input_arg.second); } return Maybe::Ok(); } Maybe CheckAttrFnUtil::NoCheck(const UserOpDefWrapper&, const UserOpConfWrapper&) { return Maybe::Ok(); } size_t TmpSizeInferFnUtil::ZeroTmpSize(InferContext*) { return 0; } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/infer_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_INFER_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_INFER_UTIL_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { class Shape; class JobDesc; class Device; namespace user_op { class AttrVal; } // namespace user_op template extern const T& AttrValueCast(const user_op::AttrVal& val); namespace user_op { class UserOpDefWrapper; class InferContext { public: virtual ~InferContext() = default; virtual const TensorDesc& InputTensorDesc(const std::string&, int32_t) const = 0; virtual const TensorDesc& OutputTensorDesc(const std::string&, int32_t) const = 0; virtual TensorDesc* MutOutputTensorDesc(const std::string&, int32_t) = 0; virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const Shape& InputShape(const std::string&, int32_t) const = 0; virtual const Shape& OutputShape(const std::string&, int32_t) const = 0; virtual void SetOutputShape(const std::string&, int32_t, const Shape&) = 0; virtual const Shape& Shape4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual void SetShape4ArgNameAndIndex(const std::string&, int32_t, const Shape&) = 0; virtual const Stride& InputStride(const std::string&, int32_t) const = 0; virtual const Stride& OutputStride(const std::string&, int32_t) const = 0; virtual void SetOutputStride(const std::string&, int32_t, const Stride&) = 0; virtual const Stride& Stride4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual void SetStride4ArgNameAndIndex(const std::string&, int32_t, const Stride&) = 0; virtual DataType InputDType(const std::string&, int32_t) const = 0; virtual DataType OutputDType(const std::string&, int32_t) const = 0; virtual void SetOutputDType(const std::string&, int32_t, DataType) = 0; virtual DataType Dtype4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual void SetDtype4ArgNameAndIndex(const std::string&, int32_t, DataType) = 0; virtual MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const = 0; virtual MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const = 0; virtual void SetOutputMemoryFormat(const std::string& arg_name, int32_t index, MemoryFormat memory_format) = 0; virtual MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index) const = 0; virtual void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index, MemoryFormat memory_format) = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; virtual const std::string& input(const std::string& arg_name, int32_t index) const = 0; virtual const std::string& output(const std::string& arg_name, int32_t index) const = 0; virtual bool has_input(const std::string& arg_name, int32_t index) const = 0; virtual bool has_output(const std::string& arg_name, int32_t index) const = 0; virtual int32_t input_size(const std::string& arg_name) const = 0; virtual int32_t output_size(const std::string& arg_name) const = 0; virtual const std::string& op_name() const = 0; virtual const std::string& op_type_name() const = 0; virtual const std::string& op_loc() const = 0; template const T& Attr(const std::string& attr_name) const { return AttrValueCast(*Attr4Name(attr_name)); } virtual const ParallelContext& parallel_ctx() const 
= 0; virtual const ParallelDesc& parallel_desc() const = 0; virtual const JobDesc* job_desc() const { UNIMPLEMENTED(); return nullptr; }; virtual const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual bool InputIsDynamic(const std::string&, int32_t) const = 0; virtual bool OutputIsDynamic(const std::string&, int32_t) const = 0; virtual void SetOutputIsDynamic(const std::string&, int32_t, bool) = 0; virtual bool IsDynamic4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual void SetIsDynamic4ArgNameAndIndex(const std::string&, int32_t, bool) = 0; virtual int64_t parallel_num() const = 0; protected: InferContext() = default; InferContext(const InferContext&) = delete; virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; }; class DeviceAndStreamInferContext { public: virtual ~DeviceAndStreamInferContext() = default; template const T& Attr(const std::string& attr_name) const { return AttrValueCast(*Attr4Name(attr_name)); } virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; virtual Symbol* OutputTensorDevice4ArgNameAndIndex(const std::string&, int64_t) = 0; virtual Symbol InputTensorDevice4ArgNameAndIndex(const std::string&, int64_t) const = 0; protected: DeviceAndStreamInferContext() = default; virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; }; struct TensorDescInferFnUtil { static Maybe Unchanged(InferContext*); static Maybe UnchangedDataType(InferContext*); static Maybe InOutCorrespond(InferContext*); }; struct CheckAttrFnUtil { static Maybe NoCheck(const UserOpDefWrapper&, const UserOpConfWrapper&); }; struct TmpSizeInferFnUtil { static size_t ZeroTmpSize(InferContext*); }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_INFER_UTIL_H_ ================================================ FILE: oneflow/core/framework/instructions_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include
#include
#include
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/stream_guard.h"
#include "oneflow/core/framework/symbol_storage_util.h"
#include "oneflow/core/device/event_record.h"
#include "oneflow/core/framework/parallel_conf_util.h"
#include "oneflow/core/operator/op_node_signature.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/scope_util.h"
#include "oneflow/core/framework/session_util.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/common/env_var/vm.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/vm/access_blob_arg_cb_instruction_policy.h"
#include "oneflow/core/vm/ep_record_event_instruction_policy.h"
#include "oneflow/core/vm/op_call_instruction_policy.h"
#include "oneflow/core/vm/barrier_instruction_policy.h"
#include "oneflow/core/vm/critical_section_instruction_policy.h"
#include "oneflow/core/vm/release_tensor_instruction_policy.h"
#include "oneflow/core/vm/lazy_job_instruction_policy.h"
#include "oneflow/core/vm/global_sync_instruction_policy.h"
#include "oneflow/core/vm/op_call_instruction_policy.h"
#include "oneflow/core/vm/stream_wait_instruction_policy.h"
#include "oneflow/core/vm/stream_record_event_instruction_policy.h"
#include "oneflow/core/vm/stream_wait_event_instruction_policy.h"
#include "oneflow/core/vm/sync_access_instruction_policy.h"
#include "oneflow/core/vm/touch_tensors_instruction_policy.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/framework/global_tensor_infer_cache.h"
#include "oneflow/core/eager/local_dep_object.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/framework/stream_need_soft_sync.h"
#include "oneflow/core/framework/stream_is_comm_net_stream.h"
#include "oneflow/core/framework/stream_support_stream_wait.h"
#include "oneflow/core/framework/stream_on_independent_thread.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/platform/include/pthread_fork.h"
#include "oneflow/core/vm/allocate_tensor_instruction_policy.h"

namespace oneflow {

namespace {

Maybe<Symbol<Stream>> RawGetCriticalSectionStream() {
  return Stream::New(JUST(Device::New("cpu")), StreamType::kCriticalSection);
}
static constexpr auto* GetCriticalSectionStream =
    DECORATE(&RawGetCriticalSectionStream, ThreadLocal);

Maybe<Symbol<Stream>> RawGetLazyJobLauncherStream() {
  return Stream::New(JUST(Device::New("cpu")), StreamType::kLazyJobLauncher);
}
static constexpr auto* GetLazyJobLauncherStream =
    DECORATE(&RawGetLazyJobLauncherStream, ThreadLocal);

}  // namespace

// clang-format off
// Job e.g.:
//                  [wait_and_send_ids]
//                           |
//                           V
//                           |
//              +-------------------+
//              |                   |
//              V              [cpu_decoder]
//              |                   |
//  [critical_section_wait]         V
//              |                   |
//              V            [forward_ops...]
//              |                   |
//              |                   V
//              +-------------------+
//                           |
//                      [copy_loss]
//                           |
//            +-----------------------+
//            |                       |
//            V                       V
//            |                       |
//    [backward_ops...]               |
//            |                       |
//            V        [critical_section_callback]
//            |                       |
//    [optimizer_ops...]              V
//            |                       |
//            V                       |
//            |                       |
//            +-----------------------+
//                           |
//                  [callback_notifier]
//
//
// clang-format on
// critical_section_wait is a blocking opkernel which waits tick signal from instruction
// CriticalSectionBegin.
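//
// Read together with LaunchLazyJob below, the instructions are emplaced in roughly this order
// (a sketch; the names follow the [CriticalSectionBegin] -> [CriticalSectionEnd] chain noted
// above):
//
//   1. CriticalSectionBegin for the input buffers    on GetCriticalSectionStream()
//   2. CriticalSectionBegin for the output buffers   on GetCriticalSectionStream()
//   3. LaunchLazyJob(nn_graph, parameters)           on GetLazyJobLauncherStream()
//   4. one CriticalSectionEnd per input op name      on GetCriticalSectionStream()
//   5. one CriticalSectionEnd per output op name     on GetCriticalSectionStream()
//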
// critical_section_callback is a non-blocking opkernel which notifies instruction // CriticalSectionEnd done. Maybe InstructionsBuilder::LaunchLazyJob(const vm::EagerBlobObjectListPtr& inputs, const vm::EagerBlobObjectListPtr& outputs, const vm::EagerBlobObjectListPtr& parameters, const std::shared_ptr& nn_graph) { JUST(SoftSyncNNGraphBuffers(inputs, nn_graph)); JUST(SoftSyncNNGraphBuffers(outputs, nn_graph)); JUST(SoftSyncNNGraphBuffers(parameters, nn_graph)); { // instruction chain: [CriticalSectionBegin] -> [CriticalSectionEnd] // instructions LaunchLazyJob are launched independent from instruction chains // [CriticalSectionBegin] -> [CriticalSectionEnd] const auto& input_op_name2end_event_record = std::make_shared>>(); { for (const auto& op_name : nn_graph->inputs_op_names()) { const auto& event_record = std::make_shared(); CHECK_OR_RETURN(input_op_name2end_event_record->emplace(op_name, event_record).second) << Error::RuntimeError() << "Duplicate Op name " << op_name; } auto stream = JUST(GetCriticalSectionStream()); auto* vm_stream = JUST(Singleton::Get()->GetVmStream(stream)); auto instruction = intrusive::make_shared( vm_stream, std::make_shared( nn_graph, inputs, input_op_name2end_event_record, vm_stream)); instruction_list_->EmplaceBack(std::move(instruction)); } const auto& output_op_name2end_event_record = std::make_shared>>(); { for (const auto& op_name : nn_graph->outputs_op_names()) { const auto& event_record = std::make_shared(); CHECK_OR_RETURN(output_op_name2end_event_record->emplace(op_name, event_record).second) << Error::RuntimeError() << "Duplicate Op name " << op_name; } auto stream = JUST(GetCriticalSectionStream()); auto* vm_stream = JUST(Singleton::Get()->GetVmStream(stream)); auto instruction = intrusive::make_shared( vm_stream, std::make_shared( nn_graph, outputs, output_op_name2end_event_record, vm_stream)); instruction_list_->EmplaceBack(std::move(instruction)); } { auto stream = JUST(GetLazyJobLauncherStream()); auto* vm_stream = JUST(Singleton::Get()->GetVmStream(stream)); auto instruction = intrusive::make_shared( vm_stream, std::make_shared(nn_graph, parameters)); instruction_list_->EmplaceBack(std::move(instruction)); } auto stream = JUST(GetCriticalSectionStream()); auto* vm_stream = JUST(Singleton::Get()->GetVmStream(stream)); for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) { const auto& eager_blob_object = inputs->at(i); const auto& op_name = nn_graph->inputs_op_names().at(i); const auto& event_record = JUST(MapAt(*input_op_name2end_event_record, op_name)); auto instruction = intrusive::make_shared( vm_stream, std::make_shared( eager_blob_object, event_record, vm_stream)); instruction_list_->EmplaceBack(std::move(instruction)); } for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) { const auto& eager_blob_object = outputs->at(i); const auto& op_name = nn_graph->outputs_op_names().at(i); const auto& event_record = JUST(MapAt(*output_op_name2end_event_record, op_name)); auto instruction = intrusive::make_shared( vm_stream, std::make_shared( eager_blob_object, event_record, vm_stream)); instruction_list_->EmplaceBack(std::move(instruction)); } } return Maybe::Ok(); } Maybe InstructionsBuilder::SoftSyncNNGraphBuffers( const vm::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr& nn_graph) { const auto& stream = JUST(GetCriticalSectionStream()); JUST(SoftSyncStream(*eager_blob_objects, stream)); return Maybe::Ok(); } namespace { int64_t NewSymbolId() { static std::atomic cnt(0); return cnt.fetch_add(1, 
std::memory_order_relaxed); } } // namespace Maybe InstructionsBuilder::GetJobConfSymbol(const JobConfigProto& job_conf) { return Singleton>::Get()->FindOrCreate(job_conf, &NewSymbolId); } Maybe InstructionsBuilder::GetParallelDescSymbol(const ParallelConf& parallel_conf) { return Singleton>::Get()->FindOrCreate(parallel_conf, &NewSymbolId); } Maybe InstructionsBuilder::GetScopeSymbol(const ScopeProto& scope_proto) { return Singleton>::Get()->FindOrCreate(scope_proto, &NewSymbolId); } Maybe InstructionsBuilder::GetOpConfSymbol(const OperatorConf& op_conf) { return Singleton>::Get()->FindOrCreate(op_conf, &NewSymbolId); } Maybe InstructionsBuilder::BuildInitialScope( int64_t session_id, const JobConfigProto& job_conf, const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy, bool is_local) { ScopeProto scope_proto; scope_proto.set_session_id(session_id); std::shared_ptr job_conf_sym = JUST(GetJobConfSymbol(job_conf)); scope_proto.set_job_desc_symbol_id(JUST(job_conf_sym->symbol_id())); std::shared_ptr parallel_conf = JUST(MakeParallelConf(device_tag, machine_device_ids, hierarchy)); std::shared_ptr device_parallel_desc_sym = JUST(GetParallelDescSymbol(*parallel_conf)); scope_proto.set_device_parallel_desc_symbol_id(JUST(device_parallel_desc_sym->symbol_id())); parallel_conf = JUST(MakeParallelConf("cpu", machine_device_ids, hierarchy)); std::shared_ptr host_parallel_desc_sym = JUST(GetParallelDescSymbol(*parallel_conf)); scope_proto.set_host_parallel_desc_symbol_id(JUST(host_parallel_desc_sym->symbol_id())); if (is_local) { scope_proto.mutable_opt_local_parallel_conf()->mutable_local_parallel(); } else { scope_proto.mutable_opt_local_parallel_conf()->clear_local_parallel(); } return GetScopeSymbol(scope_proto); } Maybe InstructionsBuilder::BuildInitialScopeWithPlacement(int64_t session_id, const JobConfigProto& job_conf, Symbol placement, bool is_local) { ScopeProto scope_proto; scope_proto.set_session_id(session_id); std::shared_ptr job_conf_sym = JUST(GetJobConfSymbol(job_conf)); scope_proto.set_job_desc_symbol_id(JUST(job_conf_sym->symbol_id())); std::shared_ptr device_parallel_desc_sym = JUST(GetParallelDescSymbol(placement->parallel_conf())); scope_proto.set_device_parallel_desc_symbol_id(JUST(device_parallel_desc_sym->symbol_id())); Symbol new_placement = JUST(ReplaceDeviceType(placement, DeviceType::kCPU)); std::shared_ptr host_parallel_desc_sym = JUST(GetParallelDescSymbol(new_placement->parallel_conf())); scope_proto.set_host_parallel_desc_symbol_id(JUST(host_parallel_desc_sym->symbol_id())); if (is_local) { scope_proto.mutable_opt_local_parallel_conf()->mutable_local_parallel(); } else { scope_proto.mutable_opt_local_parallel_conf()->clear_local_parallel(); } return GetScopeSymbol(scope_proto); } Maybe InstructionsBuilder::BuildScopeWithNewParallelDesc( const std::shared_ptr& scope, const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy) { const auto SetScopeProto = [this, &device_tag, &machine_device_ids, &hierarchy](const std::shared_ptr& scope_proto) { std::shared_ptr parallel_conf = CHECK_JUST(MakeParallelConf(device_tag, machine_device_ids, hierarchy)); std::shared_ptr device_parallel_desc_sym = CHECK_JUST(GetParallelDescSymbol(*parallel_conf)); parallel_conf = CHECK_JUST(MakeParallelConf("cpu", machine_device_ids, hierarchy)); std::shared_ptr host_parallel_desc_sym = CHECK_JUST(GetParallelDescSymbol(*parallel_conf)); scope_proto->set_device_parallel_desc_symbol_id( 
CHECK_JUST(device_parallel_desc_sym->symbol_id())); scope_proto->set_host_parallel_desc_symbol_id(CHECK_JUST(host_parallel_desc_sym->symbol_id())); }; return BuildScopeByProtoSetter(scope, SetScopeProto); } Maybe InstructionsBuilder::BuildScopeWithNewParallelConf(const std::shared_ptr& scope, const ParallelConf& parallel_conf) { const std::shared_ptr, std::shared_ptr>>& tag_and_dev_ids_and_hierarchy = JUST(GetDeviceTagAndMachineDeviceIdsAndHierarchy(parallel_conf)); std::shared_ptr hierarchy; if (std::get<2>(*tag_and_dev_ids_and_hierarchy)) { hierarchy.reset(new Shape(parallel_conf.hierarchy())); } return BuildScopeWithNewParallelDesc(scope, std::get<0>(*tag_and_dev_ids_and_hierarchy), std::get<1>(*tag_and_dev_ids_and_hierarchy), hierarchy); } Maybe InstructionsBuilder::BuildScopeWithNewIsLocal(const std::shared_ptr& scope, bool is_local) { const auto SetScopeProto = [is_local](const std::shared_ptr& scope_proto) { if (is_local) { scope_proto->mutable_opt_local_parallel_conf()->mutable_local_parallel(); } else { scope_proto->mutable_opt_local_parallel_conf()->clear_local_parallel(); } }; return BuildScopeByProtoSetter(scope, SetScopeProto); } Maybe InstructionsBuilder::BuildScopeWithNewScopeName(const std::shared_ptr& scope, const std::string& scope_name) { const auto SetScopeProto = [&scope_name](const std::shared_ptr& scope_proto) { scope_proto->add_scope_op_name_prefixes(scope_name); }; return BuildScopeByProtoSetter(scope, SetScopeProto); } Maybe InstructionsBuilder::BuildScopeByProtoSetter( const std::shared_ptr& scope, const std::function&)>& Setter) { std::shared_ptr scope_proto = JUST(scope->MakeChildScopeProto()); Setter(scope_proto); return GetScopeSymbol(*scope_proto); } Maybe InstructionsBuilder::BuildScopeByProtoStrSetter( const std::shared_ptr& scope, const std::function& StrSetter) { std::shared_ptr scope_proto = JUST(scope->MakeChildScopeProto()); std::string serialized_scope_proto = PbMessage2TxtString(*scope_proto); std::string new_serialized_scope_proto = StrSetter(serialized_scope_proto); CHECK_OR_RETURN(TxtString2PbMessage(new_serialized_scope_proto, scope_proto.get())) << Error::RuntimeError() << "scope_proto parse failed"; return GetScopeSymbol(*scope_proto); } Maybe InstructionsBuilder::Call(const std::shared_ptr& opkernel, vm::EagerBlobObjectList&& input_eager_blob_objects, vm::EagerBlobObjectList&& output_eager_blob_objects, const one::OpExprInterpContext& ctx, Symbol stream) { return Call(opkernel, std::move(input_eager_blob_objects), std::move(output_eager_blob_objects), nullptr, ctx, stream); } Maybe InstructionsBuilder::AllocateTensors(const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream) { // try soft sync eager blob objects which have memory allocated. 
JUST(SoftSyncStream(eager_blob_objects, stream)); auto* vm_stream = JUST(Singleton::Get()->GetVmStream(stream)); const auto& instruction_policy = std::make_shared(eager_blob_objects, vm_stream); auto instruction = intrusive::make_shared(vm_stream, instruction_policy); instruction_list_->EmplaceBack(std::move(instruction)); for (const auto& eager_blob_object : eager_blob_objects) { if (!eager_blob_object->producer_stream().has_value()) { JUST(eager_blob_object->init_producer_stream(stream)); } eager_blob_object->set_last_used_stream(stream); } return Maybe::Ok(); } Maybe InstructionsBuilder::Call( const std::shared_ptr& opkernel, vm::EagerBlobObjectList&& input_eager_blob_objects, vm::EagerBlobObjectList&& output_eager_blob_objects, const std::shared_ptr& global_tensor_infer_result, const one::OpExprInterpContext& ctx, Symbol stream) { stream = JUST(StreamGuard::TryConvertStream(stream)); Symbol allocator_stream = JUST(GetAllocatorStream(stream)); if (stream != allocator_stream) { JUST(AllocateTensors(output_eager_blob_objects, allocator_stream)); } JUST(SoftSyncStream(output_eager_blob_objects, stream)); JUST(SoftSyncStream(input_eager_blob_objects, stream)); for (const auto& output : output_eager_blob_objects) { if (!output->producer_stream().has_value()) { JUST(output->init_producer_stream(stream)); } output->set_last_used_stream(stream); } auto* vm_stream = JUST(Singleton::Get()->GetVmStream(stream)); auto instruction = intrusive::make_shared( vm_stream, JUST(vm::OpCallInstructionPolicy::New( vm_stream, opkernel, std::move(input_eager_blob_objects), std::move(output_eager_blob_objects), global_tensor_infer_result, ctx, *one::CurrentDevVmDepObjectConsumeMode()))); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } Maybe InstructionsBuilder::ReleaseTensor( const std::shared_ptr& eager_blob_object) { const auto& last_used_stream = JUST(eager_blob_object->last_used_stream()); const auto& producer_stream = JUST(eager_blob_object->producer_stream()); if (pthread_fork::IsForkedSubProcess() && producer_stream->device()->enum_type() != DeviceType::kCPU) { return Maybe::Ok(); } Optional> stream{}; if (*one::CurrentDevVmDepObjectConsumeMode() == one::DevVmDepObjectConsumeMode::NONE) { stream = Optional>(NullOpt); } else if (IsCommNetStream::Visit(last_used_stream->stream_type())) { // Disable inter-device instruction sequential for tensor used by communicative stream. // It's not acceptable for us that cuda compute stream is blocked by cuda nccl stream. stream = Optional>(NullOpt); } else if (IsCommNetStream::Visit(producer_stream->stream_type())) { // Disable inter-device instruction sequential for tensor produced by communicative stream. 
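// For example: a tensor produced by a ccl/NCCL stream takes this branch, `stream` stays empty,
// and the code below falls back to RecordEvent (when the last-used and producer streams differ)
// rather than a stream wait, so no compute stream is serialized behind the communication stream.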
stream = Optional>(NullOpt); } else { stream = producer_stream; } struct EnableStreamWaitOnReleaseTensor final : public StreamTypeVisitor { static bool VisitCompute() { return true; } static bool VisitHost2Device() { return true; } static bool VisitDevice2Host() { return true; } static bool VisitCcl() { return false; } static bool VisitBarrier() { return false; } static bool VisitCriticalSection() { return false; } static bool VisitLazyJobLauncher() { return false; } static bool VisitPinnedCompute() { return VisitCompute(); } }; const auto& EnableStreamWait = [&] { if (last_used_stream->device() != producer_stream->device()) { return false; } if (last_used_stream->stream_type() == producer_stream->stream_type()) { return true; } return EnableStreamWaitOnReleaseTensor::Visit(last_used_stream->stream_type()) && EnableStreamWaitOnReleaseTensor::Visit(producer_stream->stream_type()); }; if (last_used_stream != producer_stream) { if (stream.has_value() && EnableStreamWait()) { JUST(SoftSyncStreamBetween({JUST(eager_blob_object->compute_local_dep_object())}, last_used_stream, JUST(stream))); } else { JUST(RecordEvent({JUST(eager_blob_object->compute_local_dep_object())}, last_used_stream)); } eager_blob_object->set_last_used_stream(producer_stream); } auto vm_stream = stream.map([](Symbol stream) -> vm::Stream* { return CHECK_JUST(Singleton::Get()->GetVmStream(stream)); }); StreamType stream_type = producer_stream->stream_type(); auto instruction = intrusive::make_shared( JUST(Singleton::Get()->GetVmStream(producer_stream)), JUST(vm::MakeReleaseTensorInstructionPolicy::Visit(stream_type, eager_blob_object, vm_stream))); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } Maybe InstructionsBuilder::TouchTensors( const vm::EagerBlobObjectListPtr& eager_blob_objects) { Symbol device = JUST(Device::New("cpu")); Symbol stream = JUST(GetDefaultStreamByDevice(device)); return TouchTensors(eager_blob_objects, stream); } Maybe InstructionsBuilder::TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects, Symbol stream) { JUST(SoftSyncStream(*eager_blob_objects, stream)); auto instruction = intrusive::make_shared( JUST(Singleton::Get()->GetVmStream(stream)), std::make_unique(*eager_blob_objects)); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } namespace { template using SmallSet = small_vector; template std::pair::iterator, bool> SmallSetInsert(SmallSet* vec, const T& elem) { for (auto iter = vec->begin(); iter != vec->end(); ++iter) { if (*iter == elem) { return std::make_pair(iter, false); } } vec->push_back(elem); return std::make_pair(vec->end() - 1, true); } template Maybe ForEachEagerBlobObjectsNeedingSoftSync( const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream, const DoEachT& DoEach) { if (eager_blob_objects.size() <= kOpArgsReservedSize) { for (const auto& eager_blob_object : eager_blob_objects) { const auto& opt_last_used_stream = eager_blob_object->last_used_stream(); if (unlikely(!opt_last_used_stream.has_value())) { continue; } const auto& last_used_stream = JUST(opt_last_used_stream); if (last_used_stream != stream) { small_vector> dep_objects{ intrusive::shared_ptr( JUST(eager_blob_object->compute_local_dep_object()))}; JUST(DoEach(last_used_stream, std::move(dep_objects))); } } } else { SmallSet> last_used_streams; for (const auto& eager_blob_object : eager_blob_objects) { const auto& opt_last_used_stream = eager_blob_object->last_used_stream(); if (unlikely(!opt_last_used_stream.has_value())) { continue; } 
const auto& last_used_stream = JUST(opt_last_used_stream); if (last_used_stream != stream) { SmallSetInsert(&last_used_streams, last_used_stream); } } for (const auto& last_used_stream : last_used_streams) { small_vector> dep_objects{}; for (const auto& eager_blob_object : eager_blob_objects) { const auto& opt_stream = eager_blob_object->last_used_stream(); if (unlikely(!opt_stream.has_value())) { continue; } if (JUST(opt_stream) == last_used_stream) { dep_objects.emplace_back(JUST(eager_blob_object->compute_local_dep_object())); } } JUST(DoEach(last_used_stream, std::move(dep_objects))); } } return Maybe::Ok(); } } // namespace Maybe InstructionsBuilder::SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream) { JUST(ForEachEagerBlobObjectsNeedingSoftSync( eager_blob_objects, stream, [&](Symbol last_used_stream, auto&& dep_objects) -> Maybe { return SoftSyncStreamBetween(std::move(dep_objects), last_used_stream, stream); })); for (const auto& eager_blob_object : eager_blob_objects) { eager_blob_object->set_last_used_stream(stream); } return Maybe::Ok(); } namespace { bool SupportingStreamWait(Symbol from_stream, Symbol to_stream) { if (from_stream->device() == to_stream->device() && from_stream->stream_type() == to_stream->stream_type() && from_stream->thread_uid() == to_stream->thread_uid()) { CHECK(from_stream == to_stream); } if (unlikely(!ThreadLocalEnvBool())) { return false; } DeviceType from_device_type = from_stream->device()->enum_type(); DeviceType to_device_type = from_stream->device()->enum_type(); return from_stream->device() == to_stream->device() && from_stream->support_wait_event() && to_stream->support_wait_event() && StreamSupportStreamWait::Visit(from_stream->stream_type(), from_device_type) && StreamSupportStreamWait::Visit(to_stream->stream_type(), to_device_type) && !StreamOnIndependentThread::Visit(from_stream->stream_type()) && !StreamOnIndependentThread::Visit(to_stream->stream_type()); } } // namespace Maybe InstructionsBuilder::SoftSyncStreamBetween( small_vector>&& dependences, Symbol from_stream, Symbol to_stream) { CHECK(from_stream != to_stream) << "synchronization is unnecessary"; if (SupportingStreamWait(from_stream, to_stream)) { JUST(StreamWait(std::move(dependences), from_stream, to_stream)); } else { JUST(RecordEvent(std::move(dependences), from_stream)); } return Maybe::Ok(); } Maybe InstructionsBuilder::StreamWait( small_vector>&& dependences, Symbol from_stream, Symbol to_stream) { auto* from_vm_stream = JUST(Singleton::Get()->GetVmStream(from_stream)); auto* to_vm_stream = JUST(Singleton::Get()->GetVmStream(to_stream)); if (from_vm_stream->mut_thread_ctx() != to_vm_stream->mut_thread_ctx()) { auto stream_record_event = std::make_shared(dependences); auto record_instruction = intrusive::make_shared(from_vm_stream, stream_record_event); instruction_list_->EmplaceBack(std::move(record_instruction)); auto stream_wait_event = std::make_shared(dependences, stream_record_event); auto wait_instruction = intrusive::make_shared(to_vm_stream, stream_wait_event); instruction_list_->EmplaceBack(std::move(wait_instruction)); } else { auto instruction = intrusive::make_shared( to_vm_stream, std::make_unique( std::move(dependences), from_vm_stream, to_vm_stream)); instruction_list_->EmplaceBack(std::move(instruction)); } return Maybe::Ok(); } Maybe InstructionsBuilder::RecordEvent( small_vector>&& compute_local_dep_objects, Symbol last_used_stream) { DeviceType device_type = last_used_stream->device()->enum_type(); if 
(!NeedSoftSync::Visit(last_used_stream->stream_type(), device_type)) { return Maybe::Ok(); } std::string modifier = "mut"; StreamType stream_type = last_used_stream->stream_type(); auto instruction = intrusive::make_shared( JUST(Singleton::Get()->GetVmStream(last_used_stream)), JUST(GetRecordEventInstructionPolicy::Visit(stream_type, device_type, std::move(compute_local_dep_objects), modifier))); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } template Maybe InstructionsBuilder::SyncAccessBlobByCallback( const T tensor, const std::shared_ptr& btb, const std::function&)>& Callback, const std::string& modifier) { // We want balance the cpu overhead and notification latency. // // balanced timeline here: // // B: blocking wait // W: wake up // S: spin wait // // vm thread: |<--------------- prev ops ------------------>|<- Callback() ->| // // main thread: |<-------------------- B -------------------->|<- W ->|<- S ->| // // bad timeline with more notification latency: // // B: blocking wait // W: wake up // S: spin wait // // vm thread: |<--------------- prev ops ------------------>|<- Callback() ->| // // main thread: |<---------------------------- B ----------------------------->|<- W ->| // // bad timeline with more cpu overhead: // // B: blocking wait // W: wake up // S: spin wait // // vm thread: |<--------------- prev ops ------------------>|<- Callback() ->| // | | | // main thread: |<---------------------------- S ----------------------------->| const auto& CallbackWrapper = [btb, Callback]( ep::Stream* stream, const std::shared_ptr& eager_blob_object) { btb->mut_notifier()->Notify(); Callback(stream, eager_blob_object); btb->mut_spin_counter()->Decrease(); }; return AccessBlobByCallback(tensor, CallbackWrapper, modifier); } template Maybe InstructionsBuilder::SyncAccessBlobByCallback( const std::shared_ptr tensor, const std::shared_ptr& btb, const std::function&)>& Callback, const std::string& modifier); template Maybe InstructionsBuilder::SyncAccessBlobByCallback( const one::EagerLocalTensorImpl* tensor, const std::shared_ptr& btb, const std::function&)>& Callback, const std::string& modifier); namespace { Maybe> GetDevice(const std::shared_ptr& tensor) { return tensor->device(); // return Maybe> } Maybe> GetDevice(const one::EagerLocalTensorImpl* tensor) { return tensor->device(); // return const Symbol& } template Maybe> GetAccessStream(const T tensor) { Symbol device = JUST(GetDevice(tensor)); // Do not use producer_stream or last_used_stream. // Bug case when using producer_stream or last_used_stream: // // ```python // tensor = oneflow.ones((1024, 1024, 1024), device='cuda').cpu() // ndarray = tensor.numpy() # share memory // // ``` // `ndarray` may not be ones because instruction AccessBlobByCallback is prescheduled before // oneflow.ones actually finished. Symbol stream = JUST(GetDefaultStreamByDevice(device)); return StreamGuard::TryConvertStream(stream); } } // namespace template Maybe InstructionsBuilder::AccessBlobByCallback( const T tensor, const std::function&)>& callback, const std::string& modifier) { const std::shared_ptr& eager_blob_object = JUST(tensor->eager_blob_object()); Symbol stream = JUST(GetAccessStream(tensor)); JUST(SoftSyncStream({eager_blob_object}, stream)); auto instruction = intrusive::make_shared( // Never replace `stream` with producer_stream or last_used_stream. 
JUST(Singleton::Get()->GetVmStream(stream)), std::make_shared(eager_blob_object, callback, modifier)); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } template Maybe InstructionsBuilder::AccessBlobByCallback( const std::shared_ptr tensor, const std::function&)>& callback, const std::string& modifier); template Maybe InstructionsBuilder::AccessBlobByCallback( const one::EagerLocalTensorImpl* tensor, const std::function&)>& callback, const std::string& modifier); namespace { Maybe> GetBarrierStream() { auto device = JUST(Device::New("cpu")); return Stream::New(device, StreamType::kBarrier); } } // namespace Maybe InstructionsBuilder::GlobalSync() { auto stream = JUST(GetBarrierStream()); auto instruction = intrusive::make_shared( JUST(Singleton::Get()->GetVmStream(stream)), std::make_shared()); instruction_list_->PushBack(instruction.Mutable()); return Maybe::Ok(); } Maybe InstructionsBuilder::Barrier(const std::function& Callback) { auto stream = JUST(GetBarrierStream()); auto instruction = intrusive::make_shared( JUST(Singleton::Get()->GetVmStream(stream)), std::make_shared(Callback)); instruction_list_->PushBack(instruction.Mutable()); return Maybe::Ok(); } namespace { template Maybe MutThreadLocalInstruction(Symbol stream) { static thread_local std::vector> vec; if (unlikely(stream->unique_stream_id() >= vec.size())) { vec.resize(stream->unique_stream_id() + 1); } auto* instruction_ptr = &vec[stream->unique_stream_id()]; if (static_cast(*instruction_ptr) && (*instruction_ptr)->ref_cnt() != 1) { // This instruction should not be reusd because of being hold by other threads. instruction_ptr->Reset(); } if (unlikely(!static_cast(*instruction_ptr))) { *instruction_ptr = intrusive::make_shared( JUST(Singleton::Get()->GetVmStream(stream)), std::make_shared()); } return instruction_ptr->Mutable(); } } // namespace template Maybe SyncAccessSmallMem(char* mem_ptr, size_t bytes, const T tensor) { static thread_local vm::InstructionList instruction_list; static thread_local InstructionsBuilder instructions_builder(&instruction_list); const std::shared_ptr& eager_blob_object = JUST(tensor->eager_blob_object()); const Symbol stream = JUST(GetAccessStream(tensor)); if (eager_blob_object->last_used_stream().has_value() && stream != JUST(eager_blob_object->last_used_stream())) { // Synchronize stream. JUST(instructions_builder.SoftSyncStream({eager_blob_object}, stream)); } InstructionPolicyT* instruction_policy = nullptr; { // Construct instruction. auto* instruction = JUST(MutThreadLocalInstruction(stream)); instruction_policy = static_cast(instruction->mut_instruction_policy()); // NOLINT instruction_policy->Reset(mem_ptr, bytes, eager_blob_object.get()); instruction_list.PushBack(instruction); } // Dispatch instructions. JUST(vm::Run(&instruction_list)); { // This thread should blocking wait if and only if there is a lot of workload on worker thread. // When workload is small, we want better performance by skipping cond_.notify_xxx which costs // about 2us to 3us. auto* virtual_machine = JUST(SingletonMaybe()); static constexpr int kSkipBlockingThreshold = 2; if (virtual_machine->flying_instruction_cnt() < kSkipBlockingThreshold) { // skip pthread_cond_broadcast on worker thread. instruction_policy->mut_btb()->mut_notifier()->Notify(); } } // wait until done. 
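// (btb is presumably the blocking-then-busy helper declared in
// "oneflow/core/common/blocking_then_busy.h": the caller blocks on its notifier first and then
// spins on its counter, matching the balanced timeline sketched in SyncAccessBlobByCallback
// above. Notifying on this thread when few instructions are in flight lets the blocking phase
// return immediately, so the worker thread can skip its pthread_cond_broadcast.)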
JUST(instruction_policy->mut_btb()->WaitUntilCntEqualZero( VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return Maybe::Ok(); } template Maybe SyncReadSmallMem(char* mem_ptr, size_t bytes, const T tensor) { return SyncAccessSmallMem(mem_ptr, bytes, tensor); } template Maybe SyncReadSmallMem(char* mem_ptr, size_t bytes, const std::shared_ptr tensor); template Maybe SyncReadSmallMem(char* mem_ptr, size_t bytes, const one::EagerLocalTensorImpl* tensor); } // namespace oneflow ================================================ FILE: oneflow/core/framework/instructions_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_INSTRUCTIONS_BUILDER_H_ #define ONEFLOW_CORE_FRAMEWORK_INSTRUCTIONS_BUILDER_H_ #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/scope.pb.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/blocking_then_busy.h" #include "oneflow/core/operator/op_conf_symbol.h" #include "oneflow/core/vm/vm_util.h" namespace oneflow { namespace one { class StatefulOpKernel; class TensorTuple; class LocalTensor; class GlobalTensorInferResult; } // namespace one class NNGraphIf; class SharedEventRecord; class InstructionsBuilder : public std::enable_shared_from_this { public: InstructionsBuilder(const InstructionsBuilder&) = delete; InstructionsBuilder(InstructionsBuilder&&) = delete; explicit InstructionsBuilder(vm::InstructionList* instruction_list) : instruction_list_(instruction_list) {} ~InstructionsBuilder() { instruction_list_->Clear(); } const vm::InstructionList& instruction_list() const { return *instruction_list_; } vm::InstructionList* mut_instruction_list() { return instruction_list_; } // Build VM execution instructions with NNGraph's inputs/outputs/parameters for NNGraph execution. 
Maybe LaunchLazyJob(const vm::EagerBlobObjectListPtr& inputs, const vm::EagerBlobObjectListPtr& outputs, const vm::EagerBlobObjectListPtr& parameters, const std::shared_ptr& nn_graph); // soft sync for inputs/outputs buffers of NNGraph Maybe SoftSyncNNGraphBuffers(const vm::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr& nn_graph); Maybe GetJobConfSymbol(const JobConfigProto& job_conf); Maybe GetParallelDescSymbol(const ParallelConf& parallel_conf); Maybe GetScopeSymbol(const ScopeProto& scope_proto); Maybe GetOpConfSymbol(const OperatorConf& op_conf); Maybe ReleaseTensor(const std::shared_ptr& eager_blob_object); Maybe TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects); Maybe TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects, Symbol stream); template Maybe SyncAccessBlobByCallback( const T tensor, const std::shared_ptr& btb, const std::function&)>& Callback, const std::string& modifier); template Maybe AccessBlobByCallback( const T tensor, const std::function&)>& callback, const std::string& modifier); Maybe GlobalSync(); Maybe Barrier(const std::function& callback); Maybe BuildInitialScope(int64_t session_id, const JobConfigProto& job_conf, const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy, bool is_local); Maybe BuildInitialScopeWithPlacement(int64_t session_id, const JobConfigProto& job_conf, Symbol placement, bool is_local); Maybe BuildScopeWithNewParallelDesc(const std::shared_ptr& scope, const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy); Maybe BuildScopeWithNewParallelConf(const std::shared_ptr& scope, const ParallelConf& parallel_conf); Maybe BuildScopeWithNewIsLocal(const std::shared_ptr& scope, bool is_local); Maybe BuildScopeWithNewScopeName(const std::shared_ptr& scope, const std::string& scope_name); Maybe BuildScopeByProtoSetter( const std::shared_ptr& scope, const std::function&)>& Setter); Maybe BuildScopeByProtoStrSetter( const std::shared_ptr& scope, const std::function& StrSetter); Maybe Call(const std::shared_ptr& opkernel, vm::EagerBlobObjectList&& input_eager_blob_objects, vm::EagerBlobObjectList&& output_eager_blob_objects, const one::OpExprInterpContext& ctx, Symbol stream); Maybe Call( const std::shared_ptr& opkernel, vm::EagerBlobObjectList&& input_eager_blob_objects, vm::EagerBlobObjectList&& output_eager_blob_objects, const std::shared_ptr& global_tensor_infer_result, const one::OpExprInterpContext& ctx, Symbol stream); Maybe SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream); private: Maybe AllocateTensors(const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream); Maybe SoftSyncStreamBetween( small_vector>&& dependences, Symbol from_stream, Symbol to_stream); Maybe StreamWait(small_vector>&& dependences, Symbol from_stream, Symbol to_stream); Maybe RecordEvent( small_vector>&& compute_local_dep_objects, Symbol stream); vm::InstructionList* instruction_list_; }; // Make VM instructions with instruction builder and run instructions with physical/local view. 
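// A typical call site looks roughly like this (an illustrative sketch; the lambda body simply
// uses one of the builder methods declared above):
//
//   JUST(PhysicalRun([](InstructionsBuilder* builder) -> Maybe<void> {
//     return builder->GlobalSync();
//   }));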
template Maybe PhysicalRun(const CallbackT& Build) { vm::InstructionList instruction_list; InstructionsBuilder instructions_builder(&instruction_list); JUST(Build(&instructions_builder)); JUST(vm::Run(instructions_builder.mut_instruction_list())); return Maybe::Ok(); } template Maybe SyncReadSmallMem(char* mem_ptr, size_t bytes, const T tensor); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_INSTRUCTIONS_BUILDER_H_ ================================================ FILE: oneflow/core/framework/layout.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/layout.h" #include "oneflow/core/common/preprocessor.h" namespace oneflow { Symbol Layout::Get(LayoutType layout_type) { static const HashMap> layout_type2layout{ #define MAKE_ENTRY(layout_type) {OF_PP_CAT(LayoutType::k, layout_type), layout_type()}, OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, LAYOUT_SEQ) #undef MAKE_ENTRY }; return layout_type2layout.at(layout_type); } const std::string& GetLayoutTypeName(LayoutType layout_type) { static const HashMap layout_type2name{ {LayoutType::kStrided, "oneflow.strided"}}; return layout_type2name.at(layout_type); }; const std::string& Layout::name() const { return GetLayoutTypeName(layout_type_); } #define DEFINE_GET_LAYOUT_TYPE_FUNCTION(layout_type) \ Symbol Layout::layout_type() { \ static const auto& layout = SymbolOf(Layout(OF_PP_CAT(LayoutType::k, layout_type))); \ return layout; \ } OF_PP_FOR_EACH_TUPLE(DEFINE_GET_LAYOUT_TYPE_FUNCTION, LAYOUT_SEQ) #undef DEFINE_GET_LAYOUT_TYPE_FUNCTION } // namespace oneflow ================================================ FILE: oneflow/core/framework/layout.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_LAYOUT_H_
#define ONEFLOW_CORE_FRAMEWORK_LAYOUT_H_

#include
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/maybe.h"

namespace oneflow {

enum class LayoutType {
  kStrided,
};

#define LAYOUT_SEQ OF_PP_MAKE_TUPLE_SEQ(Strided)

class Layout final {
 public:
  Layout(const Layout&) = default;
  Layout(Layout&&) = delete;
  explicit Layout(LayoutType layout_type) : layout_type_(layout_type) {}
  ~Layout() = default;

  bool operator==(const Layout& other) const { return this->layout_type() == other.layout_type(); }

  const std::string& name() const;
  LayoutType layout_type() const { return layout_type_; }
  static Symbol<Layout> Get(LayoutType);

#define DECLARE_GET_LAYOUT_TYPE_FUNCTION(layout_type) static Symbol<Layout> layout_type();
  OF_PP_FOR_EACH_TUPLE(DECLARE_GET_LAYOUT_TYPE_FUNCTION, LAYOUT_SEQ)
#undef DECLARE_GET_LAYOUT_TYPE_FUNCTION

 private:
  LayoutType layout_type_;
};

}  // namespace oneflow

namespace std {
template<>
struct hash<oneflow::Layout> final {
  size_t operator()(const oneflow::Layout& layout) const {
    return static_cast<size_t>(layout.layout_type());
  }
};
}  // namespace std

#endif  // ONEFLOW_CORE_FRAMEWORK_LAYOUT_H_

================================================
FILE: oneflow/core/framework/load_library.cpp
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */
#include "oneflow/core/framework/load_library.h"

#include <dlfcn.h>

namespace oneflow {

Maybe<void> LoadLibrary(const std::string& lib_path) {
  void* handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_LOCAL);
  CHECK_OR_RETURN(handle) << " LoadLibrary ERROR! Cannot load library file: " + lib_path
                          << " the Error is: " << dlerror();
  return Maybe<void>::Ok();
}

}  // namespace oneflow

================================================
FILE: oneflow/core/framework/load_library.h
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */
#ifndef ONEFLOW_CORE_FRAMEWORK_LOAD_LIBRARY_H_
#define ONEFLOW_CORE_FRAMEWORK_LOAD_LIBRARY_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/common/maybe.h"

namespace oneflow {

Maybe<void> LoadLibrary(const std::string& lib_path);

}  // namespace oneflow

#endif  // ONEFLOW_CORE_FRAMEWORK_LOAD_LIBRARY_H_

================================================
FILE: oneflow/core/framework/local_tensor_infer_cache.cpp
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/local_tensor_infer_cache.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/env_var/eager.h" #include "oneflow/core/framework/infer_util.h" namespace oneflow { namespace one { namespace { Maybe CheckIsDeviceSupportedByOp(const Device& device, const std::string& op_type_name) { if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(device.type(), "cpu"); } // NOLINT return Maybe::Ok(); } Maybe CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args, Symbol default_device, const UserOpExpr& user_op_expr) { for (int i = 0; i < infer_args.input_local_tensor_metas().size(); ++i) { if (user_op_expr.IsHostMemoryInput(i)) { continue; } CHECK_OR_RETURN(default_device == JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device()) << Error::RuntimeError() << "Expected all tensors to be on the same device, but found " "at least two devices, " << default_device->ToString() << " (positional 0) and " << JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device()->ToString() << " (positional " << i << ")!"; } return Maybe::Ok(); } class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { public: UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const LocalTensorMetaInferArgs& infer_args, OpArgsVector* output_tensor_metas) : user_op_expr_(user_op_expr), composed_attrs_(infer_args.attrs(), user_op_expr->base_attrs()), infer_args_(infer_args), output_tensor_metas_(output_tensor_metas) {} const std::vector>& inputs() const override { return user_op_expr_->indexed_input_pairs(); } const std::vector>& outputs() const override { return user_op_expr_->indexed_output_pairs(); } Symbol* OutputTensorDevice4ArgNameAndIndex(const std::string& name, int64_t index) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0) << "tuple index should be non-negative, but got " << tuple_index; CHECK_LT(tuple_index, user_op_expr_->output_size()) << "tuple index " << tuple_index << " should be less than output size " << user_op_expr_->output_size(); return output_tensor_metas_->at(tuple_index).mut_device(); } Symbol InputTensorDevice4ArgNameAndIndex(const std::string& name, int64_t index) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0) << "tuple index should be non-negative, but got " << tuple_index; CHECK_LT(tuple_index, user_op_expr_->input_size()) << "tuple index " << tuple_index << " should be less than input size " << user_op_expr_->input_size(); return infer_args_.input_local_tensor_metas().at(tuple_index)->device(); } private: const 
std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return composed_attrs_.Attr4Name(attr_name); } const UserOpExpr* user_op_expr_; const ComposedAttrMap composed_attrs_; const LocalTensorMetaInferArgs& infer_args_; OpArgsVector* output_tensor_metas_; }; Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, const Symbol& default_device, const LocalTensorMetaInferArgs& infer_args, OpArgsVector* output_tensor_metas) { Symbol stream; if (!user_op_expr.has_device_and_stream_infer_fn()) { stream = JUST(GetDefaultStreamByDevice(default_device)); for (int i = 0; i < user_op_expr.output_size(); i++) { auto& tensor_meta = output_tensor_metas->at(i); *tensor_meta.mut_device() = default_device; } } else { if (!user_op_expr.device_and_stream_infer_fn()) { Symbol device = infer_args.input_local_tensor_metas().at(0)->device(); stream = JUST(GetDefaultStreamByDevice(device)); } else { UserOpExprDeviceAndStreamInferContext device_and_stream_ctx(&user_op_expr, infer_args, output_tensor_metas); stream = JUST(user_op_expr.device_and_stream_infer_fn()(&device_and_stream_ctx)); } } return stream; } } // namespace size_t LocalTensorMetaInferArgs::hash_value() const { size_t hash_value = std::hash()(attrs_); HashCombine(&hash_value, std::hash>()(default_device_)); const auto& tensor_meta_hash_functor = std::hash>(); for (const auto& tensor_meta : input_local_tensor_metas_) { HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta)); } return hash_value; } bool LocalTensorMetaInferArgs::operator==(const LocalTensorMetaInferArgs& other) const { return this->attrs_ == other.attrs_ && this->default_device_ == other.default_device_ && this->input_local_tensor_metas_ == other.input_local_tensor_metas_; } Maybe LocalTensorMetaInferArgs::Init(const AttrMap& attrs, Symbol default_device, const TensorTuple& input_tensors) { this->attrs_ = attrs; this->default_device_ = default_device; this->input_local_tensor_metas_.resize(input_tensors.size()); JUST(this->InitInputLocalTensorMetas(input_tensors)); return Maybe::Ok(); } Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTuple& input_tensors) { for (int i = 0; i < input_tensors.size(); ++i) { input_local_tensor_metas_.at(i) = JUST(input_tensors.at(i)->local_tensor_meta()); } return Maybe::Ok(); } /* static */ Maybe LocalTensorInferCache::Infer( const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args) { const auto& default_device = infer_args.default_device(); JUST(CheckInputDeviceIdentical(infer_args, default_device, user_op_expr)); JUST(CheckIsDeviceSupportedByOp(*default_device, user_op_expr.op_type_name())); auto result = std::make_unique(user_op_expr.output_size()); OpArgsVector output_mut_metas(user_op_expr.output_size()); // Infer devices Symbol stream = JUST(InferDeviceAndStream(user_op_expr, default_device, infer_args, &output_mut_metas)); result->set_stream(stream); { const auto& GetInputTensorMeta = [&](int32_t i) -> const TensorMeta* { return infer_args.input_local_tensor_metas().at(i).shared_from_symbol().get(); }; JUST(user_op_expr.InferPhysicalTensorDesc( infer_args.attrs(), stream->device()->type(), GetInputTensorMeta, [&](int32_t i) -> TensorMeta* { return &output_mut_metas.at(i); })); } auto* mut_output_tensor_metas = result->mut_output_tensor_metas(); for (int32_t i = 0; i < user_op_expr.output_size(); ++i) { if (!JUST(user_op_expr.SupportNonContiguous())) { Stride stride(output_mut_metas.at(i).shape()); output_mut_metas.at(i).set_stride(stride); } 
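// (When SupportNonContiguous() is false above, Stride(shape) yields the dense row-major
// strides; e.g. a shape of (2, 3, 4) gives a stride of (12, 4, 1).)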
CHECK_OR_RETURN(static_cast(output_mut_metas.at(i).device())) << Error::RuntimeError() << "device not infered"; mut_output_tensor_metas->at(i) = SymbolOf( LocalTensorMeta(output_mut_metas.at(i).shape(), output_mut_metas.at(i).stride(), output_mut_metas.at(i).data_type(), output_mut_metas.at(i).memory_format(), output_mut_metas.at(i).device())); } return std::shared_ptr(std::move(result)); } Maybe LocalTensorInferCache::GetOrInfer( const LocalTensorMetaInferArgs& infer_args) { if (ThreadLocalEnvBool()) { auto iter = cache_.find(infer_args); if (iter == cache_.end()) { if (unlikely(cache_.size() >= ThreadLocalEnvInteger())) { cache_.clear(); } const auto& user_op_expr = user_op_expr_.lock(); CHECK_OR_RETURN(static_cast(user_op_expr)); // NOLINT const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args)); iter = cache_.emplace(infer_args, output_tensor_metas).first; } return iter->second; } else { const auto& user_op_expr = user_op_expr_.lock(); return JUST(Infer(*user_op_expr, infer_args)); } } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/local_tensor_infer_cache.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_ #define ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_ #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/op_args_vector.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/common/tensor_meta.h" namespace oneflow { class Device; namespace one { class TensorTuple; class UserOpExpr; class LocalTensorMetaInferArgs final { public: LocalTensorMetaInferArgs() = default; LocalTensorMetaInferArgs(const LocalTensorMetaInferArgs&) = default; LocalTensorMetaInferArgs(LocalTensorMetaInferArgs&&) = default; ~LocalTensorMetaInferArgs() = default; const OpArgsVector>& input_local_tensor_metas() const { return input_local_tensor_metas_; } const AttrMap& attrs() const { return attrs_; } const Symbol& default_device() const { return default_device_; } size_t hash_value() const; bool operator==(const LocalTensorMetaInferArgs& other) const; Maybe Init(const AttrMap& attrs, Symbol default_device, const TensorTuple& input_tensors); private: Maybe InitInputLocalTensorMetas(const TensorTuple& input_tensors); AttrMap attrs_; Symbol default_device_; OpArgsVector> input_local_tensor_metas_; }; } // namespace one } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::one::LocalTensorMetaInferArgs& val) const { return val.hash_value(); } }; } // namespace std namespace oneflow { namespace one { class LocalTensorInferResult final { public: LocalTensorInferResult(size_t output_size) : output_tensor_metas_(output_size) {} LocalTensorInferResult(const LocalTensorInferResult&) = delete; LocalTensorInferResult(LocalTensorInferResult&&) = delete; ~LocalTensorInferResult() = default; const OpArgsVector>& output_tensor_metas() const { return output_tensor_metas_; } OpArgsVector>* mut_output_tensor_metas() { return &output_tensor_metas_; } const Symbol& stream() const { return stream_; } void set_stream(const Symbol& stream) { stream_ = stream; } private: OpArgsVector> output_tensor_metas_; Symbol stream_; }; class LocalTensorInferCache final { public: LocalTensorInferCache(const std::shared_ptr& user_op_expr) : user_op_expr_(user_op_expr) {} Maybe GetOrInfer(const LocalTensorMetaInferArgs& infer_args); private: static Maybe Infer(const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args); std::weak_ptr user_op_expr_; HashMap> cache_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_ ================================================ FILE: oneflow/core/framework/multi_client_session_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/framework/load_library.h" #include "oneflow/core/job/id_state.h" #include "oneflow/core/job/resource.pb.h" #include "oneflow/core/job/version.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/job_instance.h" #include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/runtime_context.h" #include "oneflow/core/job/runtime_job_descs.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/register/register_manager.h" #include "oneflow/user/summary/events_writer.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/memory/chunk_manager.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/job/collective_boxing/scheduler.h" #include "oneflow/core/graph/task_stream_index_manager.h" #include "oneflow/core/framework/variable_tensor_mgr.h" #ifdef WITH_CUDA #include #endif // WITH_CUDA namespace oneflow { namespace { int32_t GetCpuDeviceNum() { return std::thread::hardware_concurrency(); } } // namespace MultiClientSessionContext::MultiClientSessionContext( const std::shared_ptr& env_ctx) : env_ctx_(env_ctx) { CHECK(Singleton::Get() == nullptr) << "Duplicate multi client session context"; Singleton::SetAllocated(this); } MultiClientSessionContext::~MultiClientSessionContext() { CHECK_JUST(TryClose()); if (Singleton::Get() != nullptr) { Singleton::SetAllocated(nullptr); } env_ctx_.reset(); } Maybe MultiClientSessionContext::TryInit(const ConfigProto& config_proto) { if (!is_inited_) { DumpVersionInfo(); Resource resource = config_proto.resource(); { // NOTE(chengcheng): // In multi-client, user can NOT config cpu_device_num. // // cpu_device_num is a confusing name, it should be explained as: // in current rank, assign CPU actor compute stream in this optional range. // That is, the number of independent CPU devices that can be abstracted from // this machine and this process. // // NOTE: cpu_device_num NOT necessarily equal to the num of process // on this machine. resource.set_machine_num(GlobalProcessCtx::NodeSize()); resource.set_cpu_device_num(GetCpuDeviceNum()); } // NOTE(chengcheng): detele first because in EnvGlobalObjectScope has created ResourceDesc. if (Singleton::Get() != nullptr) { // TODO(chengcheng): reorganize dependency of all Global objects. 
Singleton::Delete(); } Singleton::New(resource, GlobalProcessCtx::NumOfProcessPerNode()); Singleton::New(); Singleton::New(); // TODO(chengcheng): refactor JobBuildAndInferCtxMgr Singleton::New(); { // NOTE(chengcheng): init runtime global objects Singleton>>::New(); Singleton>>::New(); Singleton::New(); Singleton::New(); Singleton::New(); Singleton::New(); Singleton::New(); Singleton::New(); Singleton::New(); Singleton::New(); Singleton::New(); Singleton::New(); } is_inited_ = true; } return Maybe::Ok(); } Maybe MultiClientSessionContext::TryInit(const std::string& config_proto_str) { ConfigProto config_proto; CHECK_OR_RETURN(TxtString2PbMessage(config_proto_str, &config_proto)) << Error::RuntimeError() << "failed to parse config_proto: " << config_proto_str; return TryInit(config_proto); } Maybe MultiClientSessionContext::UpdateResource(const Resource& reso_proto) { CHECK_OR_RETURN(is_inited_) << Error::RuntimeError() << " session must be inited when updating resource."; CHECK_NOTNULL_OR_RETURN((Singleton::Get())) << Error::RuntimeError() << "ResourceDesc get failed!"; Singleton::Get()->Update(reso_proto); return Maybe::Ok(); } Maybe MultiClientSessionContext::UpdateResource(const std::string& reso_proto_str) { Resource reso_proto; CHECK_OR_RETURN(TxtString2PbMessage(reso_proto_str, &reso_proto)) << Error::RuntimeError() << "failed to parse config_proto: " << reso_proto_str; return UpdateResource(reso_proto); } Maybe MultiClientSessionContext::TryClose() { if (is_inited_) { VLOG(1) << "Try to delete multi client session context." << std::endl; { // NOTE(chengcheng): delete runtime global objects Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton>>::Delete(); Singleton>>::Delete(); Singleton::Delete(); } Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); // TODO(chengcheng): remove template ForEnv and ForSession Singleton::Delete(); // NOTE(chengcheng): New after delete because in EnvGlobalObjectScope once created ResourceDesc. Singleton::New(Singleton::Get()->resource(), GlobalProcessCtx::NumOfProcessPerNode()); VLOG(1) << "Finish delete multi client session context." 
<< std::endl; is_inited_ = false; } return Maybe::Ok(); } void MultiClientSessionContext::StoreFreeEagerTensorWithNameByGraphName( const std::string& graph_name, const std::shared_ptr& tensor, const std::string& tensor_name) { auto it = graph_name2free_eager_tensors_.find(graph_name); if (it == graph_name2free_eager_tensors_.end()) { it = graph_name2free_eager_tensors_ .emplace(graph_name, std::vector>>()) .first; } it->second.emplace_back(std::make_pair(tensor_name, tensor)); } const std::vector>>& MultiClientSessionContext::GetFreeEagerTensorNamePairByGraphName(const std::string& graph_name) { auto it = graph_name2free_eager_tensors_.find(graph_name); if (it == graph_name2free_eager_tensors_.end()) { it = graph_name2free_eager_tensors_ .emplace(graph_name, std::vector>>()) .first; } return it->second; } void MultiClientSessionContext::RemoveGraphFreeEagerTensors(const std::string& graph_name) { graph_name2free_eager_tensors_.erase(graph_name); } IdState MultiClientSessionContext::GetIdState() { CHECK(Singleton::Get() != nullptr); CHECK(Singleton::Get() != nullptr); CHECK(Singleton::Get() != nullptr); IdState id_state; id_state.job_id_state_ = Singleton::Get()->GetJobIdCount(); Singleton::Get()->SaveIdAndTaskIndex(&id_state); Singleton::Get()->GetTaskStreamIndex(&id_state.stream_index_state_); return id_state; } void MultiClientSessionContext::SetIdState(const IdState& id_state) { CHECK(Singleton::Get() != nullptr); CHECK(Singleton::Get() != nullptr); CHECK(Singleton::Get() != nullptr); Singleton::Get()->TryUpdateIdAndTaskIndex(&id_state); Singleton::Get()->TryUpdateTaskStreamIndex(id_state.stream_index_state_); Singleton::Get()->TryUpdateJobIdCount(id_state.job_id_state_); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/multi_client_session_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_MULTI_CLIENT_SESSION_CONTEXT_H_ #define ONEFLOW_CORE_FRAMEWORK_MULTI_CLIENT_SESSION_CONTEXT_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/id_state.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/job/env_global_objects_scope.h" namespace oneflow { class MultiClientSessionContext { public: OF_DISALLOW_COPY_AND_MOVE(MultiClientSessionContext); explicit MultiClientSessionContext(const std::shared_ptr&); ~MultiClientSessionContext(); Maybe TryInit(const ConfigProto& config_proto); Maybe TryInit(const std::string& config_proto_str); Maybe UpdateResource(const Resource& reso_proto); Maybe UpdateResource(const std::string& reso_proto_str); Maybe TryClose(); // NOTE(chengcheng): for nn.Graph catch free EagerTensor in Graph.build(). 
// NNGraph should NOT hold ANY shared_ptr because NNGraph will send to VM stream in // RunLazyNNGraphInstruction, the tensor in NNGraph will Never be released for hold in VM // instrunction and compute stream. So we store free EagerTensor in MultiClientSessionContext, // and will be release in NNGraph destructor. void StoreFreeEagerTensorWithNameByGraphName(const std::string& graph_name, const std::shared_ptr& tensor, const std::string& tensor_name); const std::vector>>& GetFreeEagerTensorNamePairByGraphName(const std::string& graph_name); void RemoveGraphFreeEagerTensors(const std::string& graph_name); IdState GetIdState(); void SetIdState(const IdState& id_state); private: bool is_inited_ = false; std::shared_ptr env_ctx_; HashMap>>> graph_name2free_eager_tensors_; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_MULTI_CLIENT_SESSION_CONTEXT_H_ ================================================ FILE: oneflow/core/framework/multi_thread.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/multi_thread.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { namespace user_op { void MultiThreadLoopInOpKernel(size_t num, std::function Callback) { MultiThreadLoop(num, Callback); } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/multi_thread.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_MULTI_THREAD_H_ #define ONEFLOW_CORE_FRAMEWORK_MULTI_THREAD_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace user_op { void MultiThreadLoopInOpKernel(size_t num, std::function Callback); } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_MULTI_THREAD_H_ ================================================ FILE: oneflow/core/framework/mutable_attr_map.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_CACHED_ATTR_MAP_H_ #define ONEFLOW_CORE_FRAMEWORK_CACHED_ATTR_MAP_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/ordered_string_list.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/operator/op_conf.pb.h" namespace oneflow { class MutableAttrMap { public: OF_DISALLOW_COPY_AND_MOVE(MutableAttrMap); explicit MutableAttrMap(const std::vector& attr_names) : max_size_(attr_names.size()), valid_masks_(max_size_, 0), ordered_attr_names_(std::make_shared>()) { for (const auto& attr_name : attr_names) { ordered_attr_names_->emplace_back(attr_name); } attrs_.resize(max_size_); } ~MutableAttrMap() = default; size_t max_size() const { return max_size_; } const std::shared_ptr>& ordered_attr_names() const { return ordered_attr_names_; } const small_vector& valid_masks() const { return valid_masks_; } const small_vector, 8>& attrs() const { return attrs_; } inline void reset() { // mark all cached attributes as illegal values memset(valid_masks_.data(), 0, max_size_); } template inline void SetAttr(const char* attr_name, const T& attr_val) { auto idx = ordered_attr_names_->order(attr_name); CHECK_OR_THROW(idx != -1) << "has no attribute named " << attr_name; SetAttrNoThrow(idx, attr_val); } template inline void SetAttr(const T& attr_val) { CHECK_LT_OR_THROW(I, max_size_) << "index " << I << " is out of bound, and the max size is " << max_size_; SetAttrNoThrow(I, attr_val); } template inline void SetAllAttrs(Args&&... args) { CHECK_EQ_OR_THROW(sizeof...(args), max_size_) << "requires " << max_size_ << " arguments, but gives " << sizeof...(args); SetAttrNoThrow(std::forward(args)..., std::make_index_sequence{}); } private: template::value && !internal::IsOptional::value, int>::type = 0> inline void SetAttrNoThrow(int idx, const T& attr_val) { valid_masks_[idx] = true; if (!attrs_[idx] /*|| attrs_[idx]->type() != user_op::GetAttrType::value*/ || *static_cast(attrs_[idx]->Ptr()) != attr_val) { attrs_[idx] = std::make_shared>(attr_val); } } template::value, int>::type = 0> inline void SetAttrNoThrow(int idx, const T& attr_val) { if (attr_val) { using U = typename T::value_type; SetAttrNoThrow(idx, attr_val.value_or(U())); } } template::value, int>::type = 0> inline void SetAttrNoThrow(int idx, const T&) {} template inline void SetAttrNoThrow(Args&&... args, std::index_sequence) { (SetAttrNoThrow(I, std::forward(args)), ...); } // The actually count of all attributes size_t max_size_; small_vector valid_masks_; small_vector, 8> attrs_; // The ordered attribute names is determined and should be shared // between other AttrMap std::shared_ptr> ordered_attr_names_; }; #define THREAD_CACHED_MUTABLE_ATTR_MAP(...) 
\ []() -> MutableAttrMap& { \ thread_local static MutableAttrMap attrs(std::vector{__VA_ARGS__}); \ attrs.reset(); \ return attrs; \ }() } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_CACHED_ATTR_MAP_H_ ================================================ FILE: oneflow/core/framework/nd_sbp.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/sbp_parallel.h" namespace oneflow { namespace { Maybe> FindOrCreateNdSbpString(Symbol nd_sbp) { static thread_local auto* nd_sbp2nd_sbp_str = new HashMap, std::shared_ptr>>(); auto iter = nd_sbp2nd_sbp_str->find(nd_sbp); if (iter == nd_sbp2nd_sbp_str->end()) { std::shared_ptr> nd_sbp_str = std::make_shared>(nd_sbp->sbp_parallel_size()); for (int64_t i = 0; i < nd_sbp_str->size(); ++i) { nd_sbp_str->at(i) = SbpParallelToString(nd_sbp->sbp_parallel(i)); } iter = nd_sbp2nd_sbp_str->emplace(nd_sbp, nd_sbp_str).first; } return iter->second; } Maybe GetDualSbpParallel(const SbpParallel& sbp_parallel, SbpParallel* dual_sbp_parallel) { if (sbp_parallel.has_split_parallel()) { *dual_sbp_parallel = sbp_parallel; } else if (sbp_parallel.has_broadcast_parallel()) { dual_sbp_parallel->mutable_partial_sum_parallel(); } else if (sbp_parallel.has_partial_sum_parallel()) { dual_sbp_parallel->mutable_broadcast_parallel(); } else { UNIMPLEMENTED_THEN_RETURN(); } return Maybe::Ok(); } } // namespace Maybe> GetDualNdSbp(Symbol nd_sbp) { static thread_local HashMap, Symbol> map; auto iter = map.find(nd_sbp); if (iter == map.end()) { NdSbp dual_nd_sbp; auto* mut_sbp_parallel = dual_nd_sbp.mutable_sbp_parallel(); for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) { JUST(GetDualSbpParallel(sbp_parallel, mut_sbp_parallel->Add())); } iter = map.emplace(nd_sbp, SymbolOf(dual_nd_sbp)).first; } return iter->second; } Maybe> GetNdSbpStrList(const std::vector>& sbp_list) { return FindOrCreateNdSbpString(JUST(GetNdSbp(sbp_list))); } Maybe> GetNdSbpStrList(Symbol nd_sbp) { return FindOrCreateNdSbpString(nd_sbp); } Maybe> GetDualNdSbpStrList(Symbol nd_sbp) { return GetNdSbpStrList(JUST(GetDualNdSbp(nd_sbp))); } namespace private_details { Maybe> RawGetNdSbp(const std::vector>& sbp_list) { CHECK_OR_RETURN(!sbp_list.empty()) << Error::InvalidValueError() << "sbp_list should be non-empty"; NdSbp nd_sbp; for (const auto& sbp : sbp_list) { *(nd_sbp.mutable_sbp_parallel()->Add()) = *sbp; } return SymbolOf(nd_sbp); } Maybe>> RawGetSbpList(Symbol nd_sbp) { const auto& vec = std::make_shared>>(); CHECK_OR_RETURN(!nd_sbp->sbp_parallel().empty()) << Error::InvalidValueError() << "sbp_parallel should be non-empty"; for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) { vec->emplace_back(SymbolOf(sbp_parallel)); } return vec; } bool RawContainSplitSbp(Symbol nd_sbp) { for (int32_t i = 0; i < nd_sbp->sbp_parallel_size(); ++i) { if (nd_sbp->sbp_parallel(i).has_split_parallel()) { return true; } } return false; } 
Maybe>> RawNdSbpReplacePartialByBroadcast( const std::vector>& sbp_list) { auto result = std::make_shared>>(sbp_list.size()); for (int i = 0; i < sbp_list.size(); ++i) { const auto& sbp = sbp_list[i]; if (sbp->has_partial_sum_parallel()) { (*result)[i] = JUST(MakeBroadcastSbpParallel()); } else { (*result)[i] = sbp; } } return result; } } // namespace private_details const std::vector>& GetNoneSbpList() { static thread_local std::vector> none; return none; } std::string SbpToString(Symbol sbp_sym) { return SbpToString(*sbp_sym); } std::string NdSbpToString(Symbol nd_sbp_sym) { return NdSbpToString(*nd_sbp_sym); } std::string SbpToString(const SbpParallel& sbp) { std::ostringstream ss; if (sbp.has_broadcast_parallel()) { ss << "B"; } else if (sbp.has_partial_sum_parallel()) { ss << "P"; } else if (sbp.has_split_parallel()) { ss << "S(" << std::to_string(sbp.split_parallel().axis()) << ")"; } else { UNIMPLEMENTED(); } return ss.str(); } std::string NdSbpToString(const NdSbp& nd_sbp) { std::ostringstream ss; ss << "("; for (size_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (i > 0) { ss << ", "; } ss << SbpToString(nd_sbp.sbp_parallel(i)); } ss << ")"; return ss.str(); } Maybe> SetSbpAtAxis(Symbol nd_sbp, Symbol sbp, int axis) { return SetSbpAtAxis(*nd_sbp, *sbp, axis); } Maybe> SetSbpAtAxis(const NdSbp& nd_sbp, const SbpParallel& sbp, int axis) { CHECK_LT_OR_RETURN(axis, nd_sbp.sbp_parallel_size()) << Error::RuntimeError() << "Expected axis to be less than the size of sbp list (" << nd_sbp.sbp_parallel_size() << "), but got " << axis; NdSbp out_nd_sbp = nd_sbp; *out_nd_sbp.mutable_sbp_parallel(axis) = sbp; return SymbolOf(out_nd_sbp); } Maybe> SbpToNdSbp(Symbol sbp) { return SbpToNdSbp(*sbp); } Maybe> SbpToNdSbp(const SbpParallel& sbp) { NdSbp out_nd_sbp; *out_nd_sbp.add_sbp_parallel() = sbp; return SymbolOf(out_nd_sbp); } // If an nd sbp can be converted to a 1d sbp. bool Is1dSbp(const NdSbp& nd_sbp) { if (nd_sbp.sbp_parallel_size() == 0) { return false; } // Equivalent to // return std::all_of(nd_sbp.sbp_parallel().begin() + 1, nd_sbp.sbp_parallel().end(), // [&](const auto& sbp) { return sbp == nd_sbp.sbp_parallel(0); }); for (int32_t i = 1; i < nd_sbp.sbp_parallel_size(); i++) { if (nd_sbp.sbp_parallel(0) != nd_sbp.sbp_parallel(i)) { return false; } } return true; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/nd_sbp.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_ND_SBP_H_ #define ONEFLOW_CORE_FRAMEWORK_ND_SBP_H_ #include #include "oneflow/core/common/util.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/job/sbp_parallel.h" namespace oneflow { Maybe> GetDualNdSbp(Symbol nd_sbp); Maybe> GetDualNdSbp(Symbol sbp_list); Maybe> GetNdSbpStrList(const std::vector>& sbp_list); Maybe> GetNdSbpStrList(Symbol nd_sbp); Maybe> GetDualNdSbpStrList(Symbol nd_sbp); Maybe> GetDualNdSbpStrList(Symbol nd_sbp); namespace private_details { Maybe> RawGetNdSbp(const std::vector>& sbp_list); Maybe>> RawGetSbpList(Symbol nd_sbp); bool RawContainSplitSbp(Symbol nd_sbp); Maybe>> RawNdSbpReplacePartialByBroadcast( const std::vector>& sbp_list); } // namespace private_details static constexpr auto* GetNdSbp = DECORATE(&private_details::RawGetNdSbp, ThreadLocalCopiable); static constexpr auto* GetSbpList = DECORATE(&private_details::RawGetSbpList, ThreadLocal); static constexpr auto* ContainSplitSbp = DECORATE(&private_details::RawContainSplitSbp, ThreadLocal); const std::vector>& GetNoneSbpList(); static constexpr auto* NdSbpReplacePartialByBroadcast = DECORATE(&private_details::RawNdSbpReplacePartialByBroadcast, ThreadLocalCachedCopiable); std::string SbpToString(Symbol sbp_sym); std::string NdSbpToString(Symbol nd_sbp_sym); std::string SbpToString(const SbpParallel& sbp); std::string NdSbpToString(const NdSbp& nd_sbp); Maybe> SetSbpAtAxis(Symbol nd_sbp, Symbol sbp, int axis); Maybe> SetSbpAtAxis(const NdSbp& nd_sbp, const SbpParallel& sbp, int axis); Maybe> SbpToNdSbp(Symbol sbp); Maybe> SbpToNdSbp(const SbpParallel& sbp); // If an nd sbp can be converted to a 1d sbp. bool Is1dSbp(const NdSbp& nd_sbp); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_ND_SBP_H_ ================================================ FILE: oneflow/core/framework/nn_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/nn_graph.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/common/cost_util.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/compiler.h" #include "oneflow/core/job/rank_compiler.h" #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_instance.h" #include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/job/plan_util.h" #include "oneflow/core/job/utils/progress_bar.h" #include "oneflow/core/job_rewriter/job_completer.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/framework/variable_tensor_mgr.h" #include "oneflow/core/common/env_var/env_var.h" #include "oneflow/core/job/compile_mode.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { namespace { Maybe GetTensorValidInCurRank(const std::shared_ptr& tensor) { if (tensor->is_global()) { const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(JUST(tensor->parallel_desc()))); if (parallel_id->has_value()) { return true; } else { return false; } } else { return true; } } Maybe GetTensorMetaString(const std::shared_ptr& tensor) { std::string ret = "shape=" + tensor->shape()->ToString() + ", dtype=" + tensor->dtype()->name(); if (tensor->is_global()) { ret += ", placement=" + *JUST(PlacementToString(JUST(tensor->parallel_desc()))); ret += ", nd_sbp=" + NdSbpToString(JUST(tensor->nd_sbp())); } else { ret += ", device=" + JUST(tensor->device())->ToString(); } return ret; } template Maybe MakeEagerBlobObjectList(vm::EagerBlobObjectList* blob_list, const T& tensor_list) { blob_list->reserve(tensor_list.size()); for (const auto& tensor : tensor_list) { CHECK_OR_RETURN(tensor->is_eager()) << Error::RuntimeError() << "Tensors in nn.Graph should be eager"; if (tensor->is_global()) { blob_list->emplace_back(JUST(JUST(tensor->cur_rank_phy_tensor())->eager_blob_object())); } else { blob_list->emplace_back(JUST(tensor->eager_blob_object())); } } return Maybe::Ok(); } } // namespace NNGraph::~NNGraph() { VLOG(1) << "Graph destructor Try to close c nn graph name " << name_ << "." << std::endl; CHECK_JUST(Close()); } Maybe NNGraph::Close() { if (!is_closed_) { VLOG(1) << "Try to close c nn graph name " << name_ << "." << std::endl; CloseRuntimeBuffers(); runtime_.reset(); session_ctx_->RemoveGraphFreeEagerTensors(name_); VLOG(1) << "Finish close c nn graph name " << name_ << "." 
<< std::endl; session_ctx_.reset(); is_closed_ = true; } return Maybe::Ok(); } const std::vector& NNGraph::inputs_op_names() const { return inputs_op_names_; } const std::vector& NNGraph::outputs_op_names() const { return outputs_op_names_; } const std::vector& NNGraph::inputs_valid() const { return input_tensors_valid_; } const std::vector& NNGraph::outputs_valid() const { return output_tensors_valid_; } const std::vector& NNGraph::inputs_tensor_meta_str() const { return inputs_tensor_meta_str_; } const std::vector& NNGraph::outputs_tensor_meta_str() const { return outputs_tensor_meta_str_; } int64_t NNGraph::variable_op_size() const { return variable_op_names_.size(); } const std::shared_ptr& NNGraph::var_blobs() const { return variable_op_blobs_; } Maybe NNGraph::RegisterAdditionalVarOpNamesAndTensorsToBeLoaded( const std::vector& additional_var_names, const std::vector>& additional_var_tensors) { CHECK_EQ_OR_RETURN(additional_var_names.size(), additional_var_tensors.size()) << Error::RuntimeError() << "Number of additional variable names and tensors mismatch. " "Size of variable names: " << additional_var_names.size() << ", size of tensors: " << additional_var_tensors.size(); CHECK_OR_RETURN(additional_variable_op_tobe_loaded_name2tensor_.empty()) << Error::RuntimeError() << "The additional variables (states in Optimizer or LRScheduler) of nn.Graph " << name_ << " are registered repeatedly."; FOR_RANGE(size_t, i, 0, additional_var_names.size()) { CHECK_OR_RETURN(additional_variable_op_tobe_loaded_name2tensor_ .emplace(JUST(VectorAt(additional_var_names, i)), JUST(VectorAt(additional_var_tensors, i))) .second) << Error::RuntimeError() << "Duplicate variable name: " << additional_var_names[i]; } return Maybe::Ok(); } Maybe NNGraph::RegisterInputOpNamesAndTensors( const std::vector& inputs_op_names, const std::vector>& input_tensors) { CHECK_EQ_OR_RETURN(inputs_op_names.size(), input_tensors.size()) << Error::RuntimeError() << "Number of input op names and tensors mismatch. 
" "Size of op names: " << inputs_op_names.size() << ", size of tensors: " << input_tensors.size(); CHECK_OR_RETURN(inputs_op_names_.empty()) << Error::RuntimeError() << "The input tensors of nn.Graph " << name_ << " are registered repeatedly."; CHECK_OR_RETURN(input_tensors_valid_.empty()) << Error::RuntimeError() << "The input tensors of nn.Graph " << name_ << " are registered repeatedly."; CHECK_OR_RETURN(inputs_tensor_meta_str_.empty()) << Error::RuntimeError() << "The input tensors of nn.Graph " << name_ << " are registered repeatedly."; inputs_op_names_.assign(inputs_op_names.begin(), inputs_op_names.end()); input_tensors_valid_.reserve(input_tensors.size()); inputs_tensor_meta_str_.reserve(input_tensors.size()); for (const auto& input_tensor : input_tensors) { input_tensors_valid_.emplace_back(JUST(GetTensorValidInCurRank(input_tensor))); inputs_tensor_meta_str_.emplace_back(*JUST(GetTensorMetaString(input_tensor))); } CHECK_EQ_OR_RETURN(input_tensors_valid_.size(), input_tensors.size()); // NOLINE return Maybe::Ok(); } Maybe NNGraph::RegisterOutputOpNamesAndTensors( const std::vector& outputs_op_names, const std::vector>& output_tensors) { CHECK_EQ_OR_RETURN(outputs_op_names.size(), output_tensors.size()) << "Number of output op names and tensors mismatch " "Size of op names: " << outputs_op_names.size() << ", size of tensors: " << output_tensors.size(); CHECK_OR_RETURN(outputs_op_names_.empty()) << Error::RuntimeError() << "The output tensors of nn.Graph " << name_ << " are registered repeatedly."; CHECK_OR_RETURN(output_tensors_valid_.empty()) << Error::RuntimeError() << "The output tensors of nn.Graph " << name_ << " are registered repeatedly."; CHECK_OR_RETURN(outputs_tensor_meta_str_.empty()) << Error::RuntimeError() << "The output tensors of nn.Graph " << name_ << " are registered repeatedly."; outputs_op_names_.assign(outputs_op_names.begin(), outputs_op_names.end()); output_tensors_valid_.reserve(output_tensors.size()); outputs_tensor_meta_str_.reserve(output_tensors.size()); for (const auto& output_tensor : output_tensors) { output_tensors_valid_.emplace_back(JUST(GetTensorValidInCurRank(output_tensor))); outputs_tensor_meta_str_.emplace_back(*JUST(GetTensorMetaString(output_tensor))); } CHECK_EQ_OR_RETURN(output_tensors_valid_.size(), output_tensors.size()); // NOLINT return Maybe::Ok(); } Maybe NNGraph::RegisterVariableOpNamesAndTensors( const std::vector& variable_op_names, const std::vector>& variable_tensors) { JUST(vm::CurrentRankSync()); CHECK_EQ_OR_RETURN(variable_op_names.size(), variable_tensors.size()) << "Number of variable names and tensors mismatch. 
" "Size of variable names: " << variable_op_names.size() << ", size of tensors: " << variable_tensors.size(); CHECK_ISNULL_OR_RETURN(variable_op_blobs_); variable_op_blobs_ = std::make_shared(); JUST(MakeEagerBlobObjectList(variable_op_blobs_.get(), variable_tensors)); for (int32_t i = 0; i < variable_op_names.size(); ++i) { const std::shared_ptr& var = variable_tensors[i]; CHECK_OR_RETURN(var->is_eager()) << Error::InvalidValueError() << "Tensor variable to register in nn.Graph should be eager"; const std::string& var_name = variable_op_names.at(i); CHECK_OR_RETURN(!var_name.empty()) << Error::InvalidValueError() << "Empty variable name"; CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, var).second) << Error::RuntimeError() << "Duplicate variable name: " << var_name; CHECK_OR_RETURN(variable_op_names_.insert(var_name).second) << Error::RuntimeError() << "Duplicate variable name: " << var_name; } return Maybe::Ok(); } Maybe NNGraph::RegisterFreeEagerTensorsToVariableOpNames() { JUST(vm::CurrentRankSync()); const auto& free_eager_tensors = session_ctx_->GetFreeEagerTensorNamePairByGraphName(name_); for (const auto& pair : free_eager_tensors) { const std::string& var_name = pair.first; const std::shared_ptr& var = pair.second; CHECK_OR_RETURN(var->is_eager()) << Error::RuntimeError() << "Free tensor variable to register in nn.Graph should be eager"; CHECK_OR_RETURN(!var_name.empty()) << Error::RuntimeError() << "Empty variable name"; CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, var).second) << Error::RuntimeError() << "Duplicate variable name: " << var_name; CHECK_OR_RETURN(additional_variable_op_name_.insert(var_name).second) << Error::RuntimeError() << "Duplicate variable name: " << var_name; CHECK_OR_RETURN(variable_op_names_.insert(var_name).second) << Error::RuntimeError() << "Duplicate variable name: " << var_name; } return Maybe::Ok(); } Maybe> NNGraph::GetAdditionalVarOpNames() const { std::vector names; for (const auto& iter : additional_variable_op_name_) { names.push_back(iter); } return names; } Maybe>> NNGraph::GetAdditionalVarOpTensors() const { std::vector> tensors; for (const auto& iter : additional_variable_op_name_) { auto find_iter = variable_op_name2tensor_.find(iter); CHECK_OR_RETURN(find_iter != variable_op_name2tensor_.end()) << Error::RuntimeError() << "Additional variable op name " << iter << " not found."; tensors.push_back(find_iter->second); } return tensors; } Maybe NNGraph::RegisterNewVariableOpInJobPass() { OpGraph op_graph(job_); JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe::Ok(); } const Operator& variable_op = op_node->op(); const VariableOpConf& var_conf = variable_op.op_conf().variable_conf(); const std::string& var_name = variable_op.op_name(); CHECK_OR_RETURN(var_conf.has_initializer()) << Error::RuntimeError() << "nn.Graph ONLY support variable op with initializer conf."; if (var_conf.initializer().has_constant_conf() || var_conf.initializer().has_constant_int_conf() /* vairable ops inserted by system */) { CHECK_OR_RETURN(variable_op_names_.insert(var_name).second) << Error::RuntimeError() << "Variable_op_name: " << var_name << " has been added in nn.Graph: " << name_; CHECK_OR_RETURN( variable_op_name2tensor_.insert({var_name, std::shared_ptr()}).second) << Error::RuntimeError() << "Variable Tensor with op_name: " << var_name << " has been add in nn.Graph: " << name_; CHECK_OR_RETURN(additional_variable_op_name_.insert(var_name).second) 
<< Error::RuntimeError() << "Variable Tensor with op_name: " << var_name << " has been add in nn.Graph: " << name_; } else /* vairable ops from user code */ { CHECK_OR_RETURN(var_conf.initializer().has_empty_conf()) << Error::RuntimeError() << "nn.Graph ONLY support variable_op with empty conf, " << "because variable is inited by eager tensor. " << "This error variable conf is: " << variable_op.op_conf().DebugString() << " in nn.Graph " << name_; CHECK_OR_RETURN(variable_op_names_.find(var_name) != variable_op_names_.end()) << Error::RuntimeError() << var_name << " must be a variable created in nn.Graph: " << name_; } return Maybe::Ok(); })); return Maybe::Ok(); } Maybe NNGraph::DeleteOutdatedVariableInVariableTensorMgr() { const auto& var_get_func = [&]() -> Maybe> { std::set variable_names_; OpGraph op_graph(job_); JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe::Ok(); } variable_names_.insert(op_node->op().op_name()); return Maybe::Ok(); })); return variable_names_; }; std::set variable_names = *JUST(var_get_func()); auto mgr = Singleton::Get(); for (auto& name : mgr->DumpNames()) { if (variable_names.find(name) == variable_names.end()) { mgr->Delete(name); } } return Maybe::Ok(); } namespace { // A templated function that broadcasts data from the master process to worker processes in a // multi-threaded manner. Return push/pull keys only in master process. template std::set MultiThreadBroadcastFromMasterToWorkers(size_t world_size, const std::string& prefix, const X& master_data, Y* worker_data) { const size_t thread_num = ThreadLocalEnvInteger(); const size_t split_num = std::sqrt(world_size); BalancedSplitter bs(world_size, split_num); std::set keys; if (GlobalProcessCtx::IsThisProcessMaster()) { std::mutex mtx4keys; std::string data; master_data.SerializeToString(&data); MultiThreadLoop( split_num, [&](int i) { std::string key = prefix + std::to_string(i); Singleton::Get()->PushKV(key, data); std::lock_guard lock(mtx4keys); CHECK(keys.insert(key).second); }, thread_num); } else { const int64_t bs_index = bs.GetRangeIndexForVal(GlobalProcessCtx::Rank()); std::string key = prefix + std::to_string(bs_index); Singleton::Get()->PullKV(key, worker_data); } return keys; } // A templated function that pushes data from the master process to each worker process using the // control client. The function takes as input a prefix for the key used to store the data in the // control client, a pointer to the data to be pushed, and a callable object PrepareEach that // preprocesses the worker's data. Return push/pull keys only in master process. 
template<typename T, typename PrepareEachT>
std::set<std::string> MultiThreadPushFromMasterToWorkers(const std::string& prefix, T* data,
                                                         const PrepareEachT& PrepareEach) {
  const size_t thread_num = ThreadLocalEnvInteger();
  constexpr int kWorkerStartRank = 1;
  std::set<std::string> keys{};
  if (GlobalProcessCtx::IsThisProcessMaster()) {
    std::mutex mtx4keys;
    MultiThreadLoop(
        GlobalProcessCtx::WorldSize(),
        [&](int i) {
          if (i < kWorkerStartRank) { return; }
          T worker_data;
          std::string key = prefix + std::to_string(i);
          PrepareEach(&worker_data, i);
          Singleton<CtrlClient>::Get()->PushKV(key, worker_data);
          std::lock_guard<std::mutex> lock(mtx4keys);
          CHECK(keys.emplace(key).second) << "redundant pull key: " << key;
        },
        thread_num);
  } else {
    Singleton<CtrlClient>::Get()->PullKV(prefix + std::to_string(GlobalProcessCtx::Rank()), data);
  }
  return keys;
}

void DumpCalculationPassName(Job* job) {
  for (int i = 0; i < job->net().op_size(); ++i) {
    auto* op_conf = job->mutable_net()->mutable_op(i);
    if (op_conf->has_scope_symbol_id()) {
      const auto& scope =
          Singleton<symbol::Storage<Scope>>::Get()->Get(op_conf->scope_symbol_id());
      op_conf->set_calculation_pass_name(scope.scope_proto().calculation_pass_name());
    }
  }
}

}  // namespace

// The main logic of separation plan compilation. Each rank (process) compiles its related task
// nodes. This reduces plan compile time and avoids transporting a large plan protobuf.
// When the master compiles the full plan, some plan protos are much larger than 1GB, but protobuf
// has a 2GB limit and large files are slow to transport. So we must do separation plan compilation
// when the total rank num is large.
// Separation plan compilation is done by:
// a. Master broadcasts the job (or logical graph) to all workers, so that all ranks use the same
// job.
// b. Master compiles the BoxingTaskGraph and broadcasts it to all workers. BoxingTaskGraph needs
// to be done on the master rank.
// c. Each rank compiles its related task nodes with RankCompiler. RankCompiler compiles with the
// BoxingTaskGraph and the job.
Maybe<void> NNGraph::MasterAndWorkerRanksCompile() {
  // Separation compile mode only works with nccl-use-compute-stream and logical chain enabled.
  CHECK_OR_RETURN(EnableLogicalChain())
      << Error::RuntimeError()
      << "nn.Graph separate compilation needs to work with logical chain enabled.";
  // Note that nccl-use-compute-stream mode does not need to generate a CollectiveBoxingPlan.
  CHECK_OR_RETURN((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()))
      << Error::RuntimeError()
      << "nn.Graph separate compilation needs to work with nccl using compute stream enabled.";
  std::set<std::string> push_pull_keys{};
  const auto& MergeCommKeys = [&](std::set<std::string>&& keys) {
    push_pull_keys.insert(keys.begin(), keys.end());
  };
  if (GlobalProcessCtx::IsThisProcessMaster()) { DumpCalculationPassName(&job_); }
  // a. Master broadcasts the job (or logical graph) to all workers, so that all ranks use the
  // same job.
  const size_t world_size = GlobalProcessCtx::WorldSize();
  MergeCommKeys(MultiThreadBroadcastFromMasterToWorkers(
      world_size, name_ + std::string(__FUNCTION__) + "_job", job_, &job_));
  OpGraphSingletonGuard op_graph_guard(job_);
  size_t rank = GlobalProcessCtx::Rank();
  // b. Master compiles the BoxingTaskGraph and broadcasts it to all workers. BoxingTaskGraph
  // needs to be done on the master rank.
auto boxing_task_graph_proto = std::make_shared(); std::shared_ptr boxing_task_graph; if (GlobalProcessCtx::IsThisProcessMaster()) { const auto& ParallelLoop = [](size_t work_num, const std::function& Work) { MultiThreadLoop(work_num, Work, -1); }; boxing_task_graph = JUST(BoxingTaskGraph::New(ParallelLoop)); boxing_task_graph->ToProto([](TaskNode*) { return true; }, boxing_task_graph_proto.get()); if (Singleton::Get()->enable_debug_mode()) { TeePersistentLogStream::Create("boxing_task_" + name_ + "_plan" + std::to_string(0)) ->Write(*boxing_task_graph_proto); } } const auto& PrepareWorkerBoxingTaskGraphProto = [&](BoxingTaskGraphProto* proto, int64_t i) { boxing_task_graph->ToProto( [i](TaskNode* task_node) { return BoxingTaskGraph::SelectTaskNodeByRank(task_node, i); }, proto); if (Singleton::Get()->enable_debug_mode()) { TeePersistentLogStream::Create("boxing_task_" + name_ + "_plan" + std::to_string(i)) ->Write(*proto); } }; MergeCommKeys(MultiThreadPushFromMasterToWorkers( name_ + std::string(__FUNCTION__) + "_boxing_task_graph", boxing_task_graph_proto.get(), PrepareWorkerBoxingTaskGraphProto)); // c. Each rank compile it's related task node with RankCompiler. RankCompiler compile with the // BoxingTaskGraph and the job. auto* plan = &plan_; CHECK_JUST(RankCompiler(boxing_task_graph_proto, rank).Compile(variable_op_names_, &job_, plan)); PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(plan, variable_op_names_); if (Singleton::Get()->enable_debug_mode()) { TeePersistentLogStream::Create("job_" + name_ + "_plan" + std::to_string(rank))->Write(*plan); PlanUtil::ToDotFile(*plan, "job_" + name_ + "_plan_" + std::to_string(rank) + ".dot"); } PlanUtil::GenRegisterHint(plan); PlanUtil::DumpCtrlRegstInfoToPlan(plan); PlanUtil::PlanMemoryLog(&plan_, name_); if (Singleton::Get()->enable_debug_mode()) { PlanUtil::GenLightPlan(&plan_, name_, rank); } OF_SESSION_BARRIER(); for (const auto& k : push_pull_keys) { Singleton::Get()->ClearKV(k); } OF_SESSION_BARRIER(); return Maybe::Ok(); } // Master compile the full plan. Maybe NNGraph::NaiveCompile() { auto compile_tc = std::make_unique>(true, true); if (GlobalProcessCtx::IsThisProcessMaster()) { auto sub_compile_tc = std::make_unique>(true, true); // TODO(chengcheng): new memory reused by chunk Compiler().Compile(&job_, &plan_); sub_compile_tc->Count("[PlanCompile]" + name_ + " GenerateBasePlan", 1); PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(&plan_, variable_op_names_); sub_compile_tc->Count("[PlanCompile]" + name_ + " GenMemBlockAndChunk", 1); PlanUtil::GenRegisterHint(&plan_); sub_compile_tc->Count("[PlanCompile]" + name_ + " GenRegisterHint", 1); // TODO(chengcheng): test collective boxing for multi-job. PlanUtil::GenCollectiveBoxingPlan(&job_, &plan_); // PlanUtil::SetForceInplaceMemBlock(&plan_); NOTE(chengcheng): only for ssp. sub_compile_tc->Count("[PlanCompile]" + name_ + " GenCollectiveBoxingPlan", 1); PlanUtil::DumpCtrlRegstInfoToPlan(&plan_); sub_compile_tc->Count("[PlanCompile]" + name_ + " DumpCtrlRegstInfoToPlan", 1); PlanUtil::PlanMemoryLog(&plan_, name_); if (Singleton::Get()->enable_debug_mode()) { PlanUtil::GenLightPlan(&plan_, name_); } sub_compile_tc->Count("[GraphCompile]" + name_ + " GenMemAndLightPlanLog", 1, true); } compile_tc->Count("[GraphCompile]" + name_ + " CompilePlan", 0); if (GlobalProcessCtx::WorldSize() > 1) { std::string plan_name = "plan:" + job_name(); if (GlobalProcessCtx::IsThisProcessMaster()) { // TODO(chengcheng): split plan for each rank. 
Singleton::Get()->PushKV(plan_name, plan_); } else { Singleton::Get()->PullKV(plan_name, &plan_); } OF_SESSION_BARRIER(); if (GlobalProcessCtx::IsThisProcessMaster()) { Singleton::Get()->ClearKV(plan_name); } } compile_tc->Count("[GraphCompile]" + name_ + " SyncPlan", 0, true); return Maybe::Ok(); } // There are four plan compilation modes, with the first mode "master compilation" (default) and the // fourth mode "rank separation compilation" being the ones actually used. Maybe NNGraph::CompilePlanForRuntime() { // A global variable to get graph configurations. auto current_graph_config = std::make_unique(job_.job_conf(), job_id()); auto compile_tc = std::make_unique>(true, true); typedef Maybe (NNGraph::*CompileMethodT)(); struct GetCompileMethod final : public CompileModeVisitor { static CompileMethodT VisitNaive() { // Master rank compile the full plan. return &NNGraph::NaiveCompile; } static CompileMethodT VisitRankPerProcess() { // Multi process(rank) run seperation compile. return &NNGraph::MasterAndWorkerRanksCompile; } static CompileMethodT VisitInValid() { return nullptr; } }; JUST((this->*GetCompileMethod::Visit(JUST(CurrentCompileMode())))()); compile_tc->Count("[GraphCompile]" + name_ + " CompileAndSyncPlan", 0); PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table()); compile_tc->Count("[GraphCompile]" + name_ + " PopulateOpAttribute", 0); return Maybe::Ok(); } Maybe NNGraph::InitRuntime() { CHECK_OR_RETURN(!runtime_inited_) << Error::RuntimeError() << "nn.Graph runtime is already initialized"; auto compile_tc = std::make_unique>(true, true); NewRuntimeBuffers(); JUST(GetVariableRealBlobAfterSyncPlan()); // NOTE(strint): Do memory shrink to free cached memory in eager VM before graph runtime init. JUST(vm::CurrentRankSync()); auto* vm = JUST(SingletonMaybe()); JUST(vm->ShrinkAllMem()); if (Singleton::Get()->enable_debug_mode()) { auto cur_rank = GlobalProcessCtx::Rank(); auto plan_name = "job_" + name_ + "_plan"; if (JUST(CurrentCompileMode()) != CompileMode::kNaive) { plan_name += std::to_string(cur_rank); } if (cur_rank == 0 || JUST(CurrentCompileMode()) != CompileMode::kNaive) { TeePersistentLogStream::Create(plan_name)->Write(plan_); PlanUtil::ToDotFile(plan_, plan_name + ".dot"); } } runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object_)); compile_tc->Count("[GraphCompile]" + name_ + " InitRuntime", 0, true); JUST(LogProgress("[GraphCompile]" + name_ + " Done", true)); runtime_inited_ = true; return Maybe::Ok(); } Maybe NNGraph::AlignStatesAfterLogicalGraphCompile() { auto compile_tc = std::make_unique>(true, true); JUST(RegisterFreeEagerTensorsToVariableOpNames()); JUST(RegisterNewVariableOpInJobPass()); JUST(DeleteOutdatedVariableInVariableTensorMgr()); // NOTE(chengcheng): TensorNameScope need to be cleared after current graph is built. one::TensorNameScope::Global()->Clear(); // Clear all backward pass scope ClearAllBackwardPassScope(); compile_tc->Count("[GraphCompile]" + name_ + " AlignStates", 0); return Maybe::Ok(); } Maybe NNGraph::CompleteLogicalGraphForRuntime() { auto compile_tc = std::make_unique>(true, true); // A global variable to get graph configurations. auto current_graph_config = std::make_unique(job_.job_conf(), job_id()); // NOTE(chengcheng): do job compeleter for each rank. 
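// CompilePlanForRuntime above selects the compile strategy by mapping the current CompileMode
// to a pointer-to-member function and invoking it on `this`. A minimal standalone sketch of
// that dispatch pattern with toy types (ToyGraph and its members are illustrative names, not
// OneFlow code):
struct ToyGraph {
  int NaiveCompile() { return 0; }           // master compiles the full plan
  int RankPerProcessCompile() { return 1; }  // each rank compiles its own part
  using CompileFn = int (ToyGraph::*)();
  int Compile(bool rank_per_process) {
    CompileFn fn = rank_per_process ? &ToyGraph::RankPerProcessCompile : &ToyGraph::NaiveCompile;
    return (this->*fn)();  // same shape as JUST((this->*GetCompileMethod::Visit(...))())
  }
};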
JUST(JobCompleter::Complete(&job_)); compile_tc->Count("[GraphCompile]" + name_ + " CompleteJob", 0); return Maybe::Ok(); } Maybe NNGraph::BuildWithNewInputFromSharedGraph( const std::vector& shared_inputs_op_names, const std::vector>& new_input_tensors, const std::vector& shared_op_names_from_ordered_original_graph, const std::string& new_serialized_original_job) { CHECK_EQ_OR_RETURN(shared_inputs_op_names.size(), new_input_tensors.size()); // NOLINE auto compile_tc = std::make_unique>(true, true); // Register inputs. JUST(RegisterInputOpNamesAndTensors(shared_inputs_op_names, new_input_tensors)); // Generate new input tensor getter. HashMap> input_name2tensor; for (int64_t idx = 0; idx < shared_inputs_op_names.size(); ++idx) { input_name2tensor.emplace(shared_inputs_op_names[idx], new_input_tensors[idx]); } const auto& InputTensor4Name = [&input_name2tensor](const std::string& op_name) -> Maybe> { auto iter = input_name2tensor.find(op_name); CHECK_OR_RETURN(iter != input_name2tensor.end()) << "Can't find input tensor of " << op_name << "."; return iter->second; }; // Generate new OperatorConf getter. Job new_build_original_job; CHECK_OR_RETURN(new_build_original_job.ParseFromString(new_serialized_original_job)) << "nn.Graph " << name_ << " parse job proto of new build graph failed."; CHECK_EQ_OR_RETURN(new_build_original_job.net().op_size(), shared_op_names_from_ordered_original_graph.size()) << "nn.Graph " << name_ << " new_build_original_job op size and shared_op_names_from_ordered_original_graph " << "size are not equal."; HashMap shared_op_name2_new_op; for (int64_t op_idx = 0; op_idx < shared_op_names_from_ordered_original_graph.size(); ++op_idx) { // Assume that the new graph and the shared graph from nn.Graph.build have the same op order. const auto& op = new_build_original_job.mutable_net()->mutable_op()->at(op_idx); shared_op_name2_new_op.emplace(shared_op_names_from_ordered_original_graph[op_idx], &op); } const auto& NewOp4SharedOpName = [&shared_op_name2_new_op](const std::string& shared_op_name) -> Maybe { auto iter = shared_op_name2_new_op.find(shared_op_name); if (iter == shared_op_name2_new_op.end()) { VLOG(1) << "Can't find new traced operator conf for op " << shared_op_name << " in the shared graph from the base graph. This op is not shared between graphs."; return nullptr; } return iter->second; }; // A global variable to get graph configurations. auto current_graph_config = std::make_unique(job_.job_conf(), job_id()); // NOTE(chengcheng): do job compeleter for each rank. JUST(JobCompleter::UpdateSharedGraphForNewInput(&job_, InputTensor4Name, NewOp4SharedOpName)); compile_tc->Count("[GraphCompile]" + name_ + " CompleteJob", 0); return Maybe::Ok(); } Maybe NNGraph::CompileAndInitRuntime() { JUST(AlignStatesAfterLogicalGraphCompile()); JUST(CompleteLogicalGraphForRuntime()); JUST(CompilePlanForRuntime()); JUST(InitRuntime()); return Maybe::Ok(); } Maybe NNGraph::GetVariableRealBlobAfterSyncPlan() { CHECK_OR_RETURN(variable_op_name2eager_blob_object_.empty()) << Error::RuntimeError() << kOfBugIssueUploadPrompt; JUST(vm::CurrentRankSync()); // Create or Rebuild variable, then get the real blob. 
for (const std::string& var_name : variable_op_names_) { auto iter = variable_op_name2tensor_.find(var_name); CHECK_OR_RETURN(iter != variable_op_name2tensor_.end()) << Error::RuntimeError() << "variable op name " << var_name << " not found."; std::shared_ptr tensor = iter->second; vm::EagerBlobObject* var_blob = nullptr; if (plan_.job_id2op_attribute_ref_table().at(job_id_).op_name2op_attribute().find(var_name) == plan_.job_id2op_attribute_ref_table().at(job_id_).op_name2op_attribute().end()) { // Deal with variable tensor not used in nn.Graph build. CHECK_OR_RETURN(tensor != NULL) << Error::RuntimeError() << "The tensor of " << var_name << " does not exist in the job, so it's not created in nn.Graph and cannot be NULL."; if (tensor->is_global()) { const std::shared_ptr local_var = JUST(tensor->cur_rank_phy_tensor()); var_blob = JUST(local_var->eager_blob_object()).get(); } else { var_blob = JUST(tensor->eager_blob_object()).get(); } } else if (/*is_null=*/!tensor) { // Deal with tensors which are not in the nn.Module. // We can call these tensors as additional variables. const auto& op_attribute = plan_.job_id2op_attribute_ref_table().at(job_id_).op_name2op_attribute().at(var_name); // NOTE(chengcheng): handle constant variable created by job pass Symbol placement(op_attribute.parallel_conf_signature().op_parallel_conf()); NdSbp nd_sbp(NdSbpSignature(op_attribute.nd_sbp_signature()).bn_in_op2nd_sbp().at("out")); const BlobDesc blob_desc( op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc().at("out")); DType dtype(blob_desc.data_type()); std::shared_ptr>> sbp_tuple = JUST(GetSbpList(Symbol(nd_sbp))); auto load_tensor_iter = additional_variable_op_tobe_loaded_name2tensor_.find(var_name); if (load_tensor_iter == additional_variable_op_tobe_loaded_name2tensor_.end()) { // Create a additional variable tensor Scalar value; const VariableOpConf& var_conf = op_attribute.op_conf().variable_conf(); if (var_conf.initializer().has_constant_conf()) { value = var_conf.initializer().constant_conf().value(); } else if (var_conf.initializer().has_constant_int_conf()) { value = var_conf.initializer().constant_int_conf().value(); } else { OF_UNIMPLEMENTED(); } // NOTE(chengcheng): New EagerTensor need set LazyMode false. auto lazy_mode_disabled_guard = LazyMode::Guard(/*is_enabled*/ false); tensor = JUST(one::functional::GlobalConstant(blob_desc.shape(), value, Symbol(dtype), placement, *sbp_tuple)); JUST(vm::CurrentRankSync()); VLOG(2) << "Lazy nn.Graph name " << name_ << " op: " << op_attribute.op_conf().name() << " created in JobPass, nn.Graph has created a eager tensor for this variable.\n"; } else { // Load a additional variable tensor auto lazy_mode_disabled_guard = LazyMode::Guard(/*is_enabled*/ false); std::vector> grad_sbp_tuple; // To consistent from a local or global tensor. bool check_meta = load_tensor_iter->second->is_global() ? false : true; tensor = JUST(one::functional::ToGlobal(load_tensor_iter->second, placement, *sbp_tuple, grad_sbp_tuple, check_meta, /*copy=*/false)); JUST(vm::CurrentRankSync()); VLOG(2) << "Lazy nn.Graph name " << name_ << " op: " << op_attribute.op_conf().name() << " created in JobPass, nn.Graph has loaded the tensor from state dict for this " "variable.\n"; } // Register JUST(MapAt(variable_op_name2tensor_, var_name)) = tensor; // NOTE(chengcheng): Just for tensor lifetime hold by session context in graph lifetime // valid. 
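// The call below hands the newly created or loaded tensor to the session context, which keeps
// the shared_ptr alive for the whole graph lifetime and drops it again in
// RemoveGraphFreeEagerTensors (see multi_client_session_context.cpp above). A minimal
// standalone sketch of that registry idea with toy types (ToyTensor and ToySessionCtx are
// illustrative names, not OneFlow types):
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
struct ToyTensor {};
struct ToySessionCtx {
  // graph name -> [(tensor name, tensor)] kept alive until the graph is closed
  std::map<std::string, std::vector<std::pair<std::string, std::shared_ptr<ToyTensor>>>> kept;
  void Keep(const std::string& graph, std::string name, std::shared_ptr<ToyTensor> t) {
    kept[graph].emplace_back(std::move(name), std::move(t));
  }
  void Drop(const std::string& graph) { kept.erase(graph); }  // mirrors RemoveGraphFreeEagerTensors
};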
session_ctx_->StoreFreeEagerTensorWithNameByGraphName(name_, tensor, var_name); const std::shared_ptr local_var = JUST(tensor->cur_rank_phy_tensor()); var_blob = JUST(local_var->eager_blob_object()).get(); } else if (tensor->is_global()) { // Deal with tensors which need to change sbp. NdSbpSignature var_nd_sbp_signature = NdSbpSignature(plan_.job_id2op_attribute_ref_table() .at(job_id_) .op_name2op_attribute() .at(var_name) .nd_sbp_signature()); NdSbp optimized_nd_sbp = var_nd_sbp_signature.bn_in_op2nd_sbp().at("out"); // Change variable tensor's impl with new sbp when job pass has changed their sbp. if (*JUST(tensor->nd_sbp()) != optimized_nd_sbp) { VLOG(2) << "Graph with name " << name_ << " variable with name `" << var_name << "` changes its' sbp from " << NdSbpToString(*JUST(tensor->nd_sbp())) << " to " << NdSbpToString(optimized_nd_sbp) << " after compile optimization."; std::vector> optimized_sbp_parallels; for (int i = 0; i < optimized_nd_sbp.sbp_parallel_size(); ++i) { optimized_sbp_parallels.emplace_back(optimized_nd_sbp.sbp_parallel(i)); } { auto lazy_mode_disabled_guard = LazyMode::Guard(/* is_enabled */ false); const auto& new_tensor = JUST(one::functional::ToGlobal( tensor, JUST(tensor->parallel_desc()), optimized_sbp_parallels, {}, /* check_meta */ false, /*copy=*/false)); JUST(vm::CurrentRankSync()); // Use tensor.set_data inferface and make new TensorImpl instead of the old one. JUST(tensor->set_data(new_tensor)); } } const std::shared_ptr local_var = JUST(tensor->cur_rank_phy_tensor()); var_blob = JUST(local_var->eager_blob_object()).get(); } else { var_blob = JUST(tensor->eager_blob_object()).get(); } CHECK_OR_RETURN(var_blob != nullptr) << Error::RuntimeError() << kOfBugIssueUploadPrompt; CHECK_OR_RETURN(variable_op_name2eager_blob_object_.emplace(var_name, var_blob).second) << Error::RuntimeError() << kOfBugIssueUploadPrompt; } // Initialize or check mem_ptr_for_allocation_computation_pipelining by TouchTensors instruction. JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { auto eager_blob_objects = std::make_shared(); for (const auto& pair : variable_op_name2eager_blob_object_) { eager_blob_objects->push_back(pair.second->shared_from_this()); } return builder->TouchTensors(eager_blob_objects); })); JUST(vm::CurrentRankSync()); // Clear after load additional variable is finished. additional_variable_op_tobe_loaded_name2tensor_.clear(); return Maybe::Ok(); } void NNGraph::NewRuntimeBuffers() { // NOTE(chengcheng): // 1. The BufferSize comes from job_conf.concurrency_width configured by user (default = 128) // 2. In Pipeline Parallelism, this value need greater than pipeline stage num for pipelining. 
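// For example, with a 4-stage pipeline the buffers only need a width of at least 4 so that
// enough runs can be enqueued to keep every stage busy; the default concurrency_width of 128
// comfortably satisfies this.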
size_t concurrency_width = job_.job_conf().concurrency_width(); { auto* buffer_mgr = Singleton>>::Get(); buffer_mgr->NewBuffer(GetSourceTickBufferName(name_), concurrency_width); buffer_mgr->NewBuffer(GetCallbackNotifierBufferName(name_), concurrency_width); } { auto* buffer_mgr = Singleton>>::Get(); buffer_mgr->NewBuffer(GetInputCriticalSectionWaitBufferName(name_), concurrency_width); buffer_mgr->NewBuffer(GetInputCriticalSectionCallbackBufferName(name_), concurrency_width); buffer_mgr->NewBuffer(GetOutputCriticalSectionWaitBufferName(name_), concurrency_width); buffer_mgr->NewBuffer(GetOutputCriticalSectionCallbackBufferName(name_), concurrency_width); for (const std::string& input_op_name : inputs_op_names_) { buffer_mgr->NewBuffer(GetInputBufferName(name_, input_op_name), concurrency_width); } for (const std::string& output_op_name : outputs_op_names_) { buffer_mgr->NewBuffer(GetOutputBufferName(name_, output_op_name), concurrency_width); } } } void NNGraph::CloseRuntimeBuffers() { if (runtime_inited_) { { auto* buffer_mgr = Singleton>>::Get(); for (const std::string& output_op_name : outputs_op_names_) { buffer_mgr->Get(GetOutputBufferName(name_, output_op_name))->Close(); } for (const std::string& input_op_name : inputs_op_names_) { buffer_mgr->Get(GetInputBufferName(name_, input_op_name))->Close(); } buffer_mgr->Get(GetOutputCriticalSectionCallbackBufferName(name_))->Close(); buffer_mgr->Get(GetOutputCriticalSectionWaitBufferName(name_))->Close(); buffer_mgr->Get(GetInputCriticalSectionCallbackBufferName(name_))->Close(); buffer_mgr->Get(GetInputCriticalSectionWaitBufferName(name_))->Close(); } { auto* buffer_mgr = Singleton>>::Get(); buffer_mgr->Get(GetCallbackNotifierBufferName(name_))->Close(); buffer_mgr->Get(GetSourceTickBufferName(name_))->Close(); } } } Maybe RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTuple& outputs, const std::shared_ptr& nn_graph) { CHECK_EQ_OR_RETURN(inputs.size(), nn_graph->inputs_op_names().size()) << Error::RuntimeError() << "Number of inputs and NNGraph::inputs_op_names mismatch. " "Size of inputs: " << inputs.size() << ", size of NNGraph::inputs_op_names: " << nn_graph->inputs_op_names().size(); CHECK_EQ_OR_RETURN(outputs.size(), nn_graph->outputs_op_names().size()) << Error::RuntimeError() << "Number of outputs and NNGraph::outputs_op_names mismatch. " "Size of outputs: " << outputs.size() << ", size of NNGraph::outputs_op_names: " << nn_graph->outputs_op_names().size(); // NOTE(chengcheng): // parameters not used in LaunchLazyJobInstrucntion; // the args: parameters is all variable tensor hold by nn.Graph // but the NNGraph::variable_op_size may has FreeEagerTensor as sepcial variable op. CHECK_LE_OR_RETURN(nn_graph->var_blobs()->size(), nn_graph->variable_op_size()) << Error::RuntimeError() << "Parameter size should be less than or equal to variable size"; for (int i = 0; i < inputs.size(); ++i) { // TODO(chengcheng, liufengwei): // use TensorMeta.to_string and equal. std::string tensor_meta_str = *JUST(GetTensorMetaString(inputs.at(i))); const std::string& static_meta_str = nn_graph->inputs_tensor_meta_str().at(i); CHECK_OR_RETURN(static_meta_str == tensor_meta_str) << Error::RuntimeError() << "nn.Graph ONLY accepts static inputs tensor meta, please check whether your input " << "tensor meta each step is the same as the input of first call graph.\nThe excepted " << "tensor meta is: " << static_meta_str << ", but the actual tensor meta is: " << tensor_meta_str << ". 
The input index is " << i << "."; } for (int i = 0; i < outputs.size(); ++i) { CHECK_OR_RETURN(nn_graph->outputs_tensor_meta_str().at(i) == *JUST(GetTensorMetaString(outputs.at(i)))) << Error::RuntimeError() << "Output tensor meta string mismatch"; } vm::EagerBlobObjectList input_blobs; vm::EagerBlobObjectList output_blobs; JUST(MakeEagerBlobObjectList(&input_blobs, inputs)); JUST(MakeEagerBlobObjectList(&output_blobs, outputs)); const auto& input_blob_list_ptr = std::make_shared(std::move(input_blobs)); const auto& output_blob_list_ptr = std::make_shared(std::move(output_blobs)); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->LaunchLazyJob(input_blob_list_ptr, output_blob_list_ptr, nn_graph->var_blobs(), nn_graph); })); return Maybe::Ok(); } Maybe SoftSyncNNGraphBuffers(const one::TensorTuple& buffers, const std::shared_ptr& nn_graph) { const auto& eager_blob_objects = std::make_shared(); JUST(MakeEagerBlobObjectList(eager_blob_objects.get(), buffers)); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SoftSyncNNGraphBuffers(eager_blob_objects, nn_graph); })); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/nn_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
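// Illustrative sketch (not from the original file): how the runtime entry points
// defined above are typically driven. The op-name vectors, tensors and `session_ctx`
// here are assumed placeholders; in practice the Python nn.Graph frontend performs
// these steps. Template argument lists, which this extract elides elsewhere, are
// written out in the sketch.
#if 0
auto graph = std::make_shared<NNGraph>("my_graph", job, job_id, session_ctx);
JUST(graph->RegisterInputOpNamesAndTensors(input_op_names, input_tensors));
JUST(graph->RegisterOutputOpNamesAndTensors(output_op_names, output_tensors));
JUST(graph->RegisterVariableOpNamesAndTensors(variable_op_names, variable_tensors));
JUST(graph->CompileAndInitRuntime());  // build the plan and start the actor runtime

one::TensorTuple inputs;   // must keep the same tensor meta on every call
one::TensorTuple outputs;  // pre-allocated outputs, one per output op name
JUST(RunLazyNNGraph(inputs, outputs, graph));  // enqueue one run of the lazy job
JUST(SoftSyncNNGraphBuffers(outputs, graph));  // soft-sync the graph buffers if needed
#endif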
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_ #define ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/runtime.h" #include "oneflow/core/eager/eager_blob_object.h" namespace oneflow { class Blob; class NNGraph final : public NNGraphIf { public: explicit NNGraph(const std::string& name, const Job& job, int64_t job_id, const std::shared_ptr& session_ctx) : name_(name), job_(job), job_id_(job_id), session_ctx_(session_ctx), runtime_inited_(false), is_closed_(false), run_cnt_(0) {} explicit NNGraph(const std::string& name, const Plan& plan, int64_t job_id, const std::shared_ptr& session_ctx) : name_(name), job_id_(job_id), session_ctx_(session_ctx), plan_(plan), runtime_inited_(false), is_closed_(false), run_cnt_(0) {} OF_DISALLOW_COPY_AND_MOVE(NNGraph); ~NNGraph(); const std::string& job_name() const override { return name_; } const Job& job() const { return job_; } void restore_job(const Job& job) { job_ = job; } int64_t job_id() const { return job_id_; } void restore_job_id(int64_t job_id) { job_id_ = job_id; } const Plan& plan() const { return plan_; } void restore_plan(const Plan& plan) { plan_ = plan; } const std::vector& inputs_op_names() const override; const std::vector& outputs_op_names() const override; const std::vector& inputs_valid() const override; const std::vector& outputs_valid() const override; const std::vector& inputs_tensor_meta_str() const; const std::vector& outputs_tensor_meta_str() const; int64_t variable_op_size() const; const std::shared_ptr& var_blobs() const; int64_t run_cnt() const override { return run_cnt_; } void NextRunCnt() override { run_cnt_++; } Maybe RegisterAdditionalVarOpNamesAndTensorsToBeLoaded( const std::vector& additional_var_names, const std::vector>& additional_var_tensors); Maybe RegisterInputOpNamesAndTensors( const std::vector& inputs_op_names, const std::vector>& input_tensors); Maybe RegisterOutputOpNamesAndTensors( const std::vector& outputs_op_names, const std::vector>& output_tensors); Maybe RegisterVariableOpNamesAndTensors( const std::vector& variable_op_names, const std::vector>& variable_tensors); Maybe> GetAdditionalVarOpNames() const; Maybe>> GetAdditionalVarOpTensors() const; // After logical graph compile, some state variables should be cleaned or built. Maybe AlignStatesAfterLogicalGraphCompile(); // Add special operators into logical graph for lazy runtime. Maybe CompleteLogicalGraphForRuntime(); // Build graph with new inputs from a completed job of a shared graph. Maybe BuildWithNewInputFromSharedGraph( const std::vector& shared_inputs_op_names, const std::vector>& new_input_tensors, const std::vector& shared_op_names_from_ordered_original_graph, const std::string& new_serialized_original_job); // Generate execution plan for lazy runtime. Oneflow lazy runtime is an actor based runtime. Maybe CompilePlanForRuntime(); // Initialize lazy runtime. Maybe InitRuntime(); Maybe CompileAndInitRuntime(); Maybe Close(); const auto variable_op_name2tensor() const { return variable_op_name2tensor_; } std::vector> cached_op_exprs; private: // Compile the full task graph for all ranks and then broadcast to all ranks. 
Maybe NaiveCompile(); // Each rank compiles its task graph. Maybe MasterAndWorkerRanksCompile(); Maybe RegisterFreeEagerTensorsToVariableOpNames(); Maybe RegisterNewVariableOpInJobPass(); Maybe DeleteOutdatedVariableInVariableTensorMgr(); Maybe GetVariableRealBlobAfterSyncPlan(); void NewRuntimeBuffers(); void CloseRuntimeBuffers(); std::string name_; Job job_; int64_t job_id_; std::shared_ptr session_ctx_; std::vector inputs_op_names_; std::vector outputs_op_names_; std::vector input_tensors_valid_; std::vector output_tensors_valid_; std::vector inputs_tensor_meta_str_; std::vector outputs_tensor_meta_str_; HashMap> variable_op_name2tensor_; // Additional variables are variables other than model states, such as states in // optimizers/lr schedulers or free eager tensors. HashSet additional_variable_op_name_; // Additional state tensors loaded from the state dict; // they will be loaded into the job after the plan is generated. HashMap> additional_variable_op_tobe_loaded_name2tensor_; HashMap variable_op_name2eager_blob_object_; HashSet variable_op_names_; std::shared_ptr variable_op_blobs_; Plan plan_; // TODO(chengcheng): temp impl using runtime now, needs to be reimplemented for dynamic multi nn.Graph. std::unique_ptr runtime_; bool runtime_inited_; bool is_closed_; int64_t run_cnt_; }; Maybe RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTuple& outputs, const std::shared_ptr& nn_graph); Maybe SoftSyncNNGraphBuffers(const one::TensorTuple& buffers, const std::shared_ptr& nn_graph); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_ ================================================ FILE: oneflow/core/framework/nn_graph_if.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_IF_H_ #define ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_IF_H_ #include #include #include "oneflow/core/common/symbol.h" namespace oneflow { class Device; class NNGraphIf { public: virtual ~NNGraphIf() = default; virtual const std::string& job_name() const = 0; virtual const std::vector& inputs_op_names() const = 0; virtual const std::vector& outputs_op_names() const = 0; virtual const std::vector& inputs_valid() const = 0; virtual const std::vector& outputs_valid() const = 0; virtual int64_t run_cnt() const = 0; virtual void NextRunCnt() = 0; protected: NNGraphIf() = default; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_IF_H_ ================================================ FILE: oneflow/core/framework/op_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/id_util.h" namespace oneflow { namespace one { static constexpr char PositionalPlaceholderPrefix[] = "^Placeholder_"; OpBuilder::OpBuilder(const std::string& op_type_name) { *(proto_.mutable_op_type_name()) = op_type_name; op_name_ = *CHECK_JUST(UniqueStr(op_type_name)); } OpBuilder::OpBuilder(const std::string& op_type_name, const std::string& op_name) : op_name_(op_name) { *(proto_.mutable_op_type_name()) = op_type_name; } Maybe OpBuilder::MaybeInput(const std::string& input_name, const int count) { CHECK_GT_OR_RETURN(count, 0); CHECK_EQ_OR_RETURN(proto_.input().count(input_name), 0) << "The Input " << input_name << " has been specified more than once."; proto_.add_input_order(input_name); auto* input_list = &((*(proto_.mutable_input()))[input_name]); for (int i = 0; i < count; ++i) { const std::string& tensor_name = op_name_ + "/" + PositionalPlaceholderPrefix + std::to_string(input_pos_++); input_list->mutable_s()->Add()->assign(tensor_name); indexed_ibns_.emplace_back(input_name + "_" + std::to_string(i)); } CHECK_EQ_OR_RETURN(proto_.input().size(), proto_.input_order().size()); return *this; } OpBuilder& OpBuilder::Input(const std::string& input_name) { return CHECK_JUST(MaybeInput(input_name, 1)); } OpBuilder& OpBuilder::Input(const std::string& input_name, const int count) { return CHECK_JUST(MaybeInput(input_name, count)); } Maybe OpBuilder::MaybeOutput(const std::string& output_name, const int count) { CHECK_GT_OR_RETURN(count, 0); CHECK_EQ_OR_RETURN(proto_.output().count(output_name), 0) << "The output " << output_name << " has been specified more than once."; proto_.add_output_order(output_name); auto* output_list = &((*(proto_.mutable_output()))[output_name]); for (int i = 0; i < count; ++i) { const std::string& tensor_name = op_name_ + "/" + output_name + "_" + std::to_string(i); output_list->mutable_s()->Add()->assign(tensor_name); indexed_obns_.emplace_back(output_name + "_" + std::to_string(i)); } CHECK_EQ_OR_RETURN(proto_.output().size(), proto_.output_order().size()); return *this; } OpBuilder& OpBuilder::Output(const std::string& output_name) { return CHECK_JUST(MaybeOutput(output_name, 1)); } OpBuilder& OpBuilder::Output(const std::string& output_name, const int count) { return CHECK_JUST(MaybeOutput(output_name, count)); } template<> Maybe OpBuilder::MaybeAttr(const std::string& attr_name, const AttrValue& attr_value) { (*(proto_.mutable_attr()))[attr_name] = attr_value; return *this; } template<> OpBuilder& OpBuilder::Attr(const std::string& attr_name, const AttrValue& attr_value) { return CHECK_JUST(MaybeAttr(attr_name, attr_value)); } #define DEFINE_OP_BUILDER_ATTR_FUNC(field, cpp_type, attr_type) \ template<> \ Maybe OpBuilder::MaybeAttr(const std::string& attr_name, \ const cpp_type& val) { \ AttrValue attr_val; \ user_op::AttrValueAccessor::Attr(val, &attr_val); \ return this->MaybeAttr(attr_name, attr_val); \ } \ \ template<> \ OpBuilder& 
OpBuilder::Attr(const std::string& attr_name, const cpp_type& val) { \ return CHECK_JUST(MaybeAttr(attr_name, val)); \ } OF_PP_FOR_EACH_TUPLE(DEFINE_OP_BUILDER_ATTR_FUNC, ATTR_SEQ) #undef DEFINE_OP_BUILDER_ATTR_FUNC Maybe OpBuilder::Build() { return UserOpExpr::New(op_name_, std::move(proto_), indexed_ibns_, indexed_obns_); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_BUILDER_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_BUILDER_H_ #include #include "oneflow/core/framework/op_expr.h" namespace oneflow { namespace one { // The op builder for UserOp. // Note that the internal proto will be moved if the Build method is called. // Therefore, please make sure that the Build method be called at last, and do not perform any // operations on this builder instance after the calling. class OpBuilder { public: OpBuilder() = delete; explicit OpBuilder(const std::string& op_type_name); explicit OpBuilder(const std::string& op_type_name, const std::string& op_name); virtual ~OpBuilder() = default; Maybe MaybeInput(const std::string& input_name, const int count); OpBuilder& Input(const std::string& input_name); OpBuilder& Input(const std::string& input_name, const int count); Maybe MaybeOutput(const std::string& output_name, const int count); OpBuilder& Output(const std::string& output_name); OpBuilder& Output(const std::string& output_name, const int count); template Maybe MaybeAttr(const std::string& attr_name, const T& attr_value); template OpBuilder& Attr(const std::string& attr_name, const T& attr_value); Maybe Build(); private: std::string op_name_; UserOpConf proto_; int input_pos_ = 0; std::vector indexed_ibns_; std::vector indexed_obns_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_BUILDER_H_ ================================================ FILE: oneflow/core/framework/op_definition.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
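// Illustrative sketch (not from the original file): typical use of the OpBuilder
// declared above. Inputs/outputs are named, attrs are typed, and Build() is called
// last because it moves the internal UserOpConf out of the builder. The op type
// "matmul" and its attribute names are assumptions used only for illustration.
#if 0
std::shared_ptr<one::UserOpExpr> matmul_op =
    CHECK_JUST(one::OpBuilder("matmul")
                   .Input("a")
                   .Input("b")
                   .Output("out")
                   .Attr<bool>("transpose_a", false)
                   .Attr<bool>("transpose_b", false)
                   .Build());
// Do not touch the builder after Build(): its proto has been moved away.
#endif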
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_DEFINITION_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_DEFINITION_H_ #include #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/maybe.h" namespace oneflow { namespace user_op { class AttrVal; } // namespace user_op using AttrVal = user_op::AttrVal; class OpDefinitionBase { public: virtual ~OpDefinitionBase() = default; virtual Maybe Attr(const std::string& attr_name) const = 0; virtual const HashSet& AttributeNames() const = 0; protected: OpDefinitionBase() = default; }; template class OpDefinition : public OpDefinitionBase { public: virtual ~OpDefinition() = default; const HashSet& AttributeNames() const override { return Derived::AttrNames(); } protected: OpDefinition() : OpDefinitionBase() {} }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_DEFINITION_H_ ================================================ FILE: oneflow/core/framework/op_expr.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/error.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/dispatch_frame.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/framework/local_tensor_infer_cache.h" #include "oneflow/core/framework/global_tensor_infer_cache.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { Maybe OpExpr::GetOrCreateAutoCastMeta() const { static auto autocast_meta = std::make_shared(); return autocast_meta; } BuiltinOpExpr::BuiltinOpExpr(const std::string& op_name, const std::vector& indexed_ibns, const std::vector& indexed_obns) : op_name_(op_name), input_arg_tuple_(new ArgTuple(indexed_ibns)), output_arg_tuple_(new ArgTuple(indexed_obns)) {} #define DEFINE_BUILTIN_OPEXPR_OP(T, op_type, disable_grad, support_non_contiguous) \ template<> \ const std::string& BuiltinOpExprImpl::op_type_name() const { \ static const std::string& name(op_type); \ return name; \ } \ template<> \ Maybe BuiltinOpExprImpl::IsGradDisabled() const { \ return disable_grad; \ } \ template<> \ Maybe BuiltinOpExprImpl::SupportNonContiguous() const { \ return support_non_contiguous; \ } \ template<> \ Maybe BuiltinOpExprImpl::GetOrCreateAutoCastMeta() const { \ return OpExpr::GetOrCreateAutoCastMeta(); \ } DEFINE_BUILTIN_OPEXPR_OP(FeedInputOpConf, "feed_input", false, false); DEFINE_BUILTIN_OPEXPR_OP(FeedVariableOpConf, "feed_variable", false, false); DEFINE_BUILTIN_OPEXPR_OP(FetchOutputOpConf, "fetch_output", false, false); DEFINE_BUILTIN_OPEXPR_OP(ImageDecoderRandomCropResizeOpConf, "image_gpu_decode", true, false); 
DEFINE_BUILTIN_OPEXPR_OP(VariableOpConf, "variable", true, false); DEFINE_BUILTIN_OPEXPR_OP(CastToLocalOpConf, "cast_to_local", false, false); DEFINE_BUILTIN_OPEXPR_OP(CastFromLocalOpConf, "cast_from_local", false, false); DEFINE_BUILTIN_OPEXPR_OP(DistributeSplitOpConf, "distribute_split", false, false); DEFINE_BUILTIN_OPEXPR_OP(DistributeCloneOpConf, "distribute_clone", false, false); DEFINE_BUILTIN_OPEXPR_OP(DistributeConcatOpConf, "distribute_concat", false, false); DEFINE_BUILTIN_OPEXPR_OP(DistributeAddOpConf, "distribute_add", false, false); #undef DEFINE_BUILTIN_OPEXPR_OP template<> const std::string& BuiltinOpExprImpl::op_type_name() const { return op_proto_.op_type_name(); } const std::string& GlobalToGlobalOpExpr::op_type_name() const { static const std::string kOpTypeName = "global_to_global"; return kOpTypeName; } const std::string& LocalToGlobalOpExpr::op_type_name() const { static const std::string kOpTypeName = "local_to_global"; return kOpTypeName; } const std::string& GlobalToLocalOpExpr::op_type_name() const { static const std::string kOpTypeName = "global_to_local"; return kOpTypeName; } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_user_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); auto* user_op_conf = op_conf->mutable_user_conf(); for (const auto& it : attrs) { AttrValue attr_val; JUST(user_op::AttrValueUtil::ToProtoAttrValue(*it.second, &attr_val)); (*(user_op_conf->mutable_attr()))[it.first] = attr_val; } return Maybe::Ok(); } Maybe UserOpExpr::MutKernel4Stream(Symbol stream) const { const auto& it = stream2kernel_.find(stream); if (it != stream2kernel_.end()) { return it->second; } std::shared_ptr op_conf = std::make_shared(); JUST(BuildOpConf(op_conf.get(), {})); op_conf->set_device_tag(stream->device()->type()); auto parallel_desc = JUST(Placement4Device(stream->device())).shared_from_symbol(); const auto& opkernel = JUST(StatefulOpKernel::New(op_conf, stream, base_attrs(), parallel_desc, input_arg_tuple(), output_arg_tuple())); stream2kernel_.emplace(stream, opkernel); return opkernel; } template<> Maybe BuiltinOpExprImpl::IsGradDisabled() const { const auto* registry = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(proto().op_type_name()); CHECK_NOTNULL_OR_RETURN(registry); return registry->no_grad; } template<> Maybe BuiltinOpExprImpl::SupportNonContiguous() const { const auto* registry = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(proto().op_type_name()); CHECK_NOTNULL_OR_RETURN(registry) << "The op(operation) " << proto().op_type_name() << " is not found. Please check whether it has been registered correctly."; return registry->non_contiguous_supported; } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { if (!op_grad_func_.get()) { CHECK_OR_RETURN((IsClassRegistered(proto().op_type_name()))) << "The gradient function for op " << proto().op_type_name() << " is not found. 
Please check whether it has been implemented and registered correctly."; op_grad_func_.reset(NewObj(proto().op_type_name())); JUST(op_grad_func_->Init(*this)); } return std::make_shared(op_grad_func_); } template<> Maybe BuiltinOpExprImpl::GetOrCreateAutoCastMeta() const { if (!autocast_meta_) { autocast_meta_ = autocast::MakeAutoCastMeta(proto().op_type_name(), this->indexed_input_pairs()); } return autocast_meta_; } namespace { class UserOpExprInferContext : public user_op::InferContext { public: UserOpExprInferContext(const UserOpExpr* user_op_expr, const AttrMap& attrs, const std::string& device_tag, const std::function& TensorMeta4InputIndex, const std::function& TensorMeta4OutputIndex) : user_op_expr_(user_op_expr), composed_attrs_(attrs, user_op_expr->base_attrs()), tensor_meta4input_index_(TensorMeta4InputIndex), tensor_meta4output_index_(TensorMeta4OutputIndex) { loc_ = DispatchFrame::get_str(); } virtual ~UserOpExprInferContext() override = default; const std::vector>& inputs() const override { return user_op_expr_->indexed_input_pairs(); } const std::vector>& outputs() const override { return user_op_expr_->indexed_output_pairs(); } const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { return *TensorDesc4ArgNameAndIndex(arg_name, index); } const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, int32_t index) const override { return *TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* MutOutputTensorDesc(const std::string& name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(name, index); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) const { { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); if (tuple_index >= 0) { return tensor_meta4output_index_(tuple_index); } } { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); if (tuple_index >= 0) { return tensor_meta4input_index_(tuple_index); } } return nullptr; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) { { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); if (tuple_index >= 0) { TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index); CHECK_NOTNULL(dynamic_cast(tensor_meta_ptr)); return tensor_meta_ptr; } } { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); if (tuple_index >= 0) { const TensorMeta* tensor_meta_ptr = tensor_meta4input_index_(tuple_index); CHECK_NOTNULL(dynamic_cast(tensor_meta_ptr)); return const_cast(tensor_meta_ptr); } } PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } const Shape& InputShape(const std::string& name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4input_index_(tuple_index)->shape(); } const Shape& OutputShape(const std::string& name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4input_index_(tuple_index)->shape(); 
} void SetOutputShape(const std::string& name, int32_t index, const Shape& shape) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index); CHECK_NOTNULL(dynamic_cast(tensor_meta_ptr)); return tensor_meta_ptr->set_shape(shape); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->shape(); } void SetShape4ArgNameAndIndex(const std::string& arg_name, int32_t index, const Shape& shape) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_shape(shape); } const Stride& InputStride(const std::string& name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4input_index_(tuple_index)->stride(); } const Stride& OutputStride(const std::string& name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4output_index_(tuple_index)->stride(); } void SetOutputStride(const std::string& name, int32_t index, const Stride& stride) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index); CHECK_NOTNULL(dynamic_cast(tensor_meta_ptr)); return tensor_meta_ptr->set_stride(stride); } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->stride(); } void SetStride4ArgNameAndIndex(const std::string& arg_name, int32_t index, const Stride& stride) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_stride(stride); } DataType InputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } DataType OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } void SetOutputDType(const std::string& arg_name, int32_t index, DataType data_type) override { return SetDtype4ArgNameAndIndex(arg_name, index, data_type); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->data_type(); } void SetDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index, DataType data_type) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_data_type(data_type); } MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const override { return MemoryFormat4ArgNameAndIndex(arg_name, index); } MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const override { return MemoryFormat4ArgNameAndIndex(arg_name, index); } void SetOutputMemoryFormat(const std::string& arg_name, int32_t index, MemoryFormat memory_format) override { return SetMemoryFormat4ArgNameAndIndex(arg_name, index, memory_format); } MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, 
index)->memory_format(); } void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index, MemoryFormat memory_format) override { MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_memory_format(memory_format); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } void SetOutputIsDynamic(const std::string& arg_name, int32_t index, bool is_dynamic) override { return SetIsDynamic4ArgNameAndIndex(arg_name, index, is_dynamic); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->is_dynamic(); } void SetIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index, bool is_dynamic) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_is_dynamic(is_dynamic); } const std::string& input(const std::string& arg_name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); CHECK_GE(tuple_index, 0); return arg_tuple.indexed_bns().at(tuple_index); } const std::string& output(const std::string& arg_name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); CHECK_GE(tuple_index, 0); return arg_tuple.indexed_bns().at(tuple_index); } bool has_input(const std::string& arg_name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); return tuple_index >= 0; } bool has_output(const std::string& arg_name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); return tuple_index >= 0; } int32_t input_size(const std::string& arg_name) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); return arg_tuple.arg_name2bn_index2tensor_tuple_index().at(arg_name).size(); } int32_t output_size(const std::string& arg_name) const override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); return arg_tuple.arg_name2bn_index2tensor_tuple_index().at(arg_name).size(); } const std::string& op_name() const override { return user_op_expr_->op_name(); } const std::string& op_type_name() const override { return user_op_expr_->op_type_name(); } const std::string& op_loc() const override { return loc_; } private: const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return composed_attrs_.Attr4Name(attr_name); } const UserOpExpr* user_op_expr_; const ComposedAttrMap composed_attrs_; const std::function& tensor_meta4input_index_; const std::function& tensor_meta4output_index_; std::string loc_; }; namespace { Symbol Get1DBroadcastNdSbp() { NdSbp broadcast_nd_sbp; broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); return SymbolOf(broadcast_nd_sbp); } auto* CachedGet1DBroadcastNdSbp = DECORATE(&Get1DBroadcastNdSbp, ThreadLocalCached); } // namespace class UserOpExprPhysicalInferContext final : public UserOpExprInferContext { public: UserOpExprPhysicalInferContext( const UserOpExpr* user_op_expr, const AttrMap& attrs, const std::string& 
device_tag, const std::function& TensorMeta4InputIndex, const std::function& TensorMeta4OutputIndex) : UserOpExprInferContext(user_op_expr, attrs, device_tag, TensorMeta4InputIndex, TensorMeta4OutputIndex), parallel_desc_(CHECK_JUST(GetParallelDescOfThisRank(device_tag))) { parallel_ctx_.set_parallel_id(0); parallel_ctx_.set_parallel_num(1); } ~UserOpExprPhysicalInferContext() override = default; const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) const override { PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } const ParallelContext& parallel_ctx() const override { return parallel_ctx_; } const ParallelDesc& parallel_desc() const override { return *parallel_desc_; } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& name, int32_t index) const override { CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(name, index)); return CachedGet1DBroadcastNdSbp()->sbp_parallel(0); } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& name, int32_t index) const override { CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(name, index)); return *(CachedGet1DBroadcastNdSbp()); } int64_t parallel_num() const override { return 1; } private: // these member vars just used for physical infer Symbol parallel_desc_; ParallelContext parallel_ctx_; }; class UserOpExprLogicalInferContext final : public UserOpExprInferContext { public: UserOpExprLogicalInferContext( const UserOpExpr* user_op_expr, const AttrMap& attrs, Symbol parallel_desc, const std::function& TensorMeta4InputIndex, const std::function& TensorMeta4OutputIndex) : UserOpExprInferContext(user_op_expr, attrs, parallel_desc->device_tag(), TensorMeta4InputIndex, TensorMeta4OutputIndex), parallel_desc_(parallel_desc) { const auto& opt_parallel_id = CHECK_JUST(GetParallelId4CurrentProcessCtx(parallel_desc_)); // Default parallel_id = -1, which will not cause bad effects becauce it will never be used in // LogicalTensorDescInfer. 
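// Illustrative note (not from the original file): for a placement over ranks {2, 3},
// GetParallelId4CurrentProcessCtx yields 0 on rank 2, 1 on rank 3, and an empty
// Optional on any rank outside the placement, so parallel_id stays -1 there; the -1
// default is harmless because logical tensor-desc inference never reads parallel_id.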
int64_t parallel_id = -1; if (opt_parallel_id->has_value()) { parallel_id = CHECK_JUST(*opt_parallel_id); } parallel_ctx_.set_parallel_id(parallel_id); parallel_ctx_.set_parallel_num(parallel_desc_->parallel_num()); } ~UserOpExprLogicalInferContext() override = default; const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) const override { PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } const ParallelContext& parallel_ctx() const override { return parallel_ctx_; } const ParallelDesc& parallel_desc() const override { return *parallel_desc_; } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& name, int32_t index) const override { const GlobalTensorMeta* tensor_meta = dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); Symbol nd_sbp = tensor_meta->nd_sbp(); CHECK_EQ(nd_sbp->sbp_parallel_size(), 1); return nd_sbp->sbp_parallel(0); } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& name, int32_t index) const override { const GlobalTensorMeta* tensor_meta = dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); return *tensor_meta->nd_sbp(); } int64_t parallel_num() const override { return parallel_desc_->parallel_num(); } private: Symbol parallel_desc_; ParallelContext parallel_ctx_; }; class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { public: UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const AttrMap& attrs, const TensorTuple& input_tensors, TensorTuple* output_tensors) : user_op_expr_(user_op_expr), composed_attrs_(attrs, user_op_expr->base_attrs()), input_tensors_(&input_tensors), output_tensors_(output_tensors) {} const std::vector>& inputs() const override { return user_op_expr_->indexed_input_pairs(); } const std::vector>& outputs() const override { return user_op_expr_->indexed_output_pairs(); } Symbol* OutputTensorDevice4ArgNameAndIndex(const std::string& name, int64_t index) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return CHECK_JUST(output_tensors_->at(tuple_index)->mut_device()); } Symbol InputTensorDevice4ArgNameAndIndex(const std::string& name, int64_t index) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return CHECK_JUST(input_tensors_->at(tuple_index)->device()); } private: const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return composed_attrs_.Attr4Name(attr_name); } const UserOpExpr* user_op_expr_; const ComposedAttrMap composed_attrs_; const TensorTuple* input_tensors_; TensorTuple* output_tensors_; }; } // namespace UserOpExpr::UserOpExpr(const std::string& op_name, UserOpConf&& proto, const AttrMap& base_attrs, const std::vector& indexed_ibns, const std::vector& indexed_obns) : BuiltinOpExprImpl(op_name, std::move(proto), indexed_ibns, indexed_obns), base_attrs_(base_attrs) {} Maybe UserOpExpr::Init(const std::shared_ptr& self) { const auto& op_type_name = op_proto_.op_type_name(); const auto* registry = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name); CHECK_NOTNULL_OR_RETURN(registry); logical_tensor_desc_infer_fn_ = registry->logical_tensor_desc_infer_fn; CHECK_OR_RETURN(static_cast(logical_tensor_desc_infer_fn_)) << Error::RuntimeError() << "registry->logical_tensor_desc_infer_fn failed."; 
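// Illustrative sketch (not from the original file): what a registry-provided infer
// function typically looks like for a simple elementwise op, using only the
// InferContext accessors overridden just above. The function and argument names
// ("in"/"out") are assumptions for illustration.
#if 0
Maybe<void> InferEltwiseTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0));
  ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0));
  return Maybe<void>::Ok();
}
Maybe<void> InferEltwiseDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}
#endif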
physical_tensor_desc_infer_fn_ = registry->physical_tensor_desc_infer_fn; CHECK_OR_RETURN(static_cast(physical_tensor_desc_infer_fn_)) << Error::RuntimeError() << "registry->logical_tensor_desc_infer_fn failed."; dtype_infer_fn_ = registry->data_type_infer_fn; CHECK_OR_RETURN(static_cast(dtype_infer_fn_)) << Error::RuntimeError() << "registry->data_type_infer_fn failed."; if (registry->device_and_stream_infer_fn) { device_and_stream_infer_fn_ = registry->device_and_stream_infer_fn; } local_tensor_infer_cache_.reset(new LocalTensorInferCache(self)); global_tensor_infer_cache_.reset(new GlobalTensorInferCache(self)); const auto& indexed_input_pairs = this->indexed_input_pairs(); for (int32_t i = 0; i < indexed_input_pairs.size(); ++i) { const auto& input_pair = JUST(VectorAt(indexed_input_pairs, i)); if (user_op::UserOpHostMemoryInputRegistry::Get().IsHostMemoryInput4Op( op_type_name, input_pair.first, input_pair.second)) { host_memory_input_ids_.emplace_back(i); } } return Maybe::Ok(); } /* static */ Maybe UserOpExpr::New(const std::string& op_name, UserOpConf&& op_proto, const std::vector& indexed_ibns, const std::vector& indexed_obns) { JUST(AddAttrDefaultValueAndCheckValid(&op_proto)); AttrMap base_attrs = MakeAttrMapFromUserOpConf(op_proto); std::shared_ptr op_expr( new UserOpExpr(op_name, std::move(op_proto), base_attrs, indexed_ibns, indexed_obns)); JUST(op_expr->Init(op_expr)); return op_expr; } Maybe UserOpExpr::InferPhysicalTensorDesc( const AttrMap& attrs, const std::string& device_tag, const std::function& TensorMeta4InputIndex, const std::function& TensorMeta4OutputIndex) const { UserOpExprPhysicalInferContext infer_ctx(this, attrs, device_tag, TensorMeta4InputIndex, TensorMeta4OutputIndex); JUST(physical_tensor_desc_infer_fn_(&infer_ctx)); JUST(dtype_infer_fn_(&infer_ctx)); return Maybe::Ok(); } Maybe UserOpExpr::InferLogicalTensorDesc( const AttrMap& attrs, Symbol parallel_desc, const std::function& TensorMeta4InputIndex, const std::function& TensorMeta4OutputIndex) const { UserOpExprLogicalInferContext infer_ctx(this, attrs, parallel_desc, TensorMeta4InputIndex, TensorMeta4OutputIndex); JUST(logical_tensor_desc_infer_fn_(&infer_ctx)); JUST(dtype_infer_fn_(&infer_ctx)); return Maybe::Ok(); } Maybe> UserOpExpr::InferDeviceAndStream(const AttrMap& attrs, const TensorTuple& input_tensors, TensorTuple* output_tensors) const { CHECK_OR_RETURN(static_cast(device_and_stream_infer_fn_)); UserOpExprDeviceAndStreamInferContext device_infer_ctx(this, attrs, input_tensors, output_tensors); return TRY(device_and_stream_infer_fn_(&device_infer_ctx)); } GlobalToGlobalOpExpr::GlobalToGlobalOpExpr(const Optional>& grad_nd_sbp) : grad_nd_sbp_(grad_nd_sbp) {} /* static */ Maybe GlobalToGlobalOpExpr::New( const Optional>& grad_nd_sbp) { auto* ptr = new GlobalToGlobalOpExpr(grad_nd_sbp); return std::shared_ptr(ptr); } CastGlobalOpExpr::CastGlobalOpExpr(const std::string& op_name) : op_name_(op_name) {} LocalToGlobalOpExpr::LocalToGlobalOpExpr(const std::string& op_name) : CastGlobalOpExpr(op_name) {} /* static */ Maybe LocalToGlobalOpExpr::New(const std::string& op_name) { return std::shared_ptr(new LocalToGlobalOpExpr(op_name)); } GlobalToLocalOpExpr::GlobalToLocalOpExpr(const std::string& op_name) : CastGlobalOpExpr(op_name) {} /* static */ Maybe GlobalToLocalOpExpr::New(const std::string& op_name) { return std::shared_ptr(new GlobalToLocalOpExpr(op_name)); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); 
*(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_feed_input_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { if (!op_grad_func_.get()) { op_grad_func_.reset(NewObj("graph_feed_and_fetch")); CHECK_NOTNULL_OR_RETURN(op_grad_func_.get()); // NOLINT JUST(op_grad_func_->Init(*this)); } return std::make_shared(op_grad_func_); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_feed_variable_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { if (!op_grad_func_.get()) { op_grad_func_.reset(NewObj("graph_feed_and_fetch")); CHECK_NOTNULL_OR_RETURN(op_grad_func_.get()); // NOLINT JUST(op_grad_func_->Init(*this)); } return std::make_shared(op_grad_func_); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_fetch_output_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { if (!op_grad_func_.get()) { op_grad_func_.reset(NewObj("graph_feed_and_fetch")); CHECK_NOTNULL_OR_RETURN(op_grad_func_.get()); // NOLINT JUST(op_grad_func_->Init(*this)); } return std::make_shared(op_grad_func_); } template<> Maybe BuiltinOpExprImpl::BuildOpConf( OperatorConf* op_conf, const AttrMap& attrs) const { *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_image_decoder_random_crop_resize_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); auto* proto = op_conf->mutable_image_decoder_random_crop_resize_conf(); proto->set_target_width(JUST(attrs.GetAttr("target_width"))); proto->set_target_height(JUST(attrs.GetAttr("target_height"))); proto->set_num_workers(JUST(attrs.GetAttr("num_workers"))); proto->set_max_num_pixels(JUST(attrs.GetAttr("max_num_pixels"))); proto->set_warmup_size(JUST(attrs.GetAttr("warmup_size"))); proto->set_seed(JUST(attrs.GetAttr("seed"))); proto->set_num_attempts(JUST(attrs.GetAttr("num_attempts"))); proto->set_random_area_min(JUST(attrs.GetAttr("random_area_min"))); proto->set_random_area_max(JUST(attrs.GetAttr("random_area_max"))); proto->set_random_aspect_ratio_min(JUST(attrs.GetAttr("random_aspect_ratio_min"))); proto->set_random_aspect_ratio_max(JUST(attrs.GetAttr("random_aspect_ratio_max"))); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { UNIMPLEMENTED_THEN_RETURN(); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_variable_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { UNIMPLEMENTED_THEN_RETURN(); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_cast_to_local_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe 
BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { UNIMPLEMENTED_THEN_RETURN(); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_cast_from_local_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { UNIMPLEMENTED_THEN_RETURN(); } Maybe GlobalToGlobalOpExpr::GetOrCreateOpGradClosure() const { if (!op_grad_func_.get()) { op_grad_func_.reset(NewObj("global_to_global")); CHECK_NOTNULL_OR_RETURN(op_grad_func_.get()); JUST(op_grad_func_->Init(*this)); } return std::make_shared(op_grad_func_); } Maybe LocalToGlobalOpExpr::GetOrCreateOpGradClosure() const { if (!op_grad_func_.get()) { op_grad_func_.reset(NewObj("local_to_global")); CHECK_NOTNULL_OR_RETURN(op_grad_func_.get()); JUST(op_grad_func_->Init(*this)); } return std::make_shared(op_grad_func_); } Maybe GlobalToLocalOpExpr::GetOrCreateOpGradClosure() const { if (!op_grad_func_.get()) { op_grad_func_.reset(NewObj("global_to_local")); CHECK_NOTNULL_OR_RETURN(op_grad_func_.get()); JUST(op_grad_func_->Init(*this)); } return std::make_shared(op_grad_func_); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_distribute_split_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { UNIMPLEMENTED_THEN_RETURN(); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_distribute_clone_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { UNIMPLEMENTED_THEN_RETURN(); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_distribute_concat_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { UNIMPLEMENTED_THEN_RETURN(); } template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_distribute_add_conf()) = op_proto_; *(op_conf->mutable_loc()) = DispatchFrame::get_str(); return Maybe::Ok(); } template<> Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure() const { UNIMPLEMENTED_THEN_RETURN(); } Maybe SelectTopNOpExpr::GetOrCreateOpGradClosure() const { if (!op_grad_func_.get()) { op_grad_func_.reset(NewObj("select_top_n")); CHECK_NOTNULL_OR_RETURN(op_grad_func_.get()); JUST(op_grad_func_->Init(*this)); } return std::make_shared(op_grad_func_); } void FunctionOpExpr::reset_state() const { state_.reset(new FunctionAutoGradCaptureState); } Maybe FunctionOpExpr::GetOrCreateOpGradClosure() const { if (!op_grad_func_) { op_grad_func_.reset(new FunctionOpExprGradFunction(func_name_, backward_fn_)); } return std::make_shared(op_grad_func_, state_); } } // namespace one } // namespace oneflow 
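// Illustrative aside (not from the original file): each DEFINE_BUILTIN_OPEXPR_OP
// invocation earlier in this file stamps out per-conf specializations. Roughly, the
// "variable" entry expands to the specializations below (template argument lists and
// the Maybe payload types, which this extract elides, are reconstructed here as
// assumptions; the GetOrCreateAutoCastMeta forwarder is omitted).
#if 0
template<>
const std::string& BuiltinOpExprImpl<VariableOpConf>::op_type_name() const {
  static const std::string& name("variable");
  return name;
}
template<>
Maybe<bool> BuiltinOpExprImpl<VariableOpConf>::IsGradDisabled() const { return true; }
template<>
Maybe<bool> BuiltinOpExprImpl<VariableOpConf>::SupportNonContiguous() const { return false; }
#endif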
================================================ FILE: oneflow/core/framework/op_expr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_EXPR_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_EXPR_H_ #include #include "oneflow/core/common/util.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/autocast.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/user_op_conf.pb.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/arg_tuple.h" #include "oneflow/core/autograd/autograd_function.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/framework/op_interpreter/dispatch_frame.h" namespace oneflow { namespace one { class OpExprGradFunctionIf; class OpExprGradClosure; class OpExpr { public: virtual ~OpExpr() = default; virtual const std::string& op_type_name() const = 0; virtual int input_size() const = 0; virtual int output_size() const = 0; virtual Maybe IsGradDisabled() const = 0; virtual Maybe SupportNonContiguous() const = 0; virtual Maybe GetOrCreateOpGradClosure() const = 0; virtual Maybe GetOrCreateAutoCastMeta() const; protected: OpExpr() = default; }; class BuiltinOpExpr : public OpExpr { public: explicit BuiltinOpExpr(const std::string& op_name, const std::vector& indexed_ibns, const std::vector& indexed_obns); virtual ~BuiltinOpExpr() = default; const std::string& op_name() const { return op_name_; } int input_size() const override { return input_arg_tuple_->size(); } int output_size() const override { return output_arg_tuple_->size(); } const std::shared_ptr& input_arg_tuple() const { return input_arg_tuple_; } const std::shared_ptr& output_arg_tuple() const { return output_arg_tuple_; } const std::vector& indexed_ibns() const { return input_arg_tuple_->indexed_bns(); } const std::vector& indexed_obns() const { return output_arg_tuple_->indexed_bns(); } const std::vector>& indexed_input_pairs() const { return input_arg_tuple_->indexed_arg_name_and_index(); } const std::vector>& indexed_output_pairs() const { return output_arg_tuple_->indexed_arg_name_and_index(); } virtual Maybe BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const = 0; protected: std::string op_name_; std::shared_ptr input_arg_tuple_; std::shared_ptr output_arg_tuple_; }; class TensorMeta; template class BuiltinOpExprImpl : public BuiltinOpExpr { public: static Maybe> New(const std::string& op_name, ProtoType&& op_proto, const std::vector& indexed_ibns, const std::vector& indexed_obns) { return std::shared_ptr>( new BuiltinOpExprImpl(op_name, std::move(op_proto), indexed_ibns, indexed_obns)); } virtual ~BuiltinOpExprImpl() = default; const ProtoType& 
proto() const { return op_proto_; } ProtoType* mutable_proto() { return &op_proto_; } const std::string& op_type_name() const override; Maybe IsGradDisabled() const override; Maybe SupportNonContiguous() const override; Maybe GetOrCreateOpGradClosure() const override; Maybe GetOrCreateAutoCastMeta() const override; Maybe BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const override; protected: explicit BuiltinOpExprImpl(const std::string& op_name, ProtoType&& op_proto, const std::vector& indexed_ibns, const std::vector& indexed_obns) : BuiltinOpExpr(op_name, indexed_ibns, indexed_obns), op_proto_(std::move(op_proto)) {} ProtoType op_proto_; mutable std::shared_ptr op_grad_func_; mutable std::shared_ptr autocast_meta_; }; class StatefulOpKernel; class LocalTensorInferCache; class GlobalTensorInferCache; class UserOpExpr final : public BuiltinOpExprImpl { public: UserOpExpr() = delete; virtual ~UserOpExpr() = default; static Maybe New(const std::string& op_name, UserOpConf&& op_proto, const std::vector& indexed_ibns, const std::vector& indexed_obns); const AttrMap& base_attrs() const { return base_attrs_; } Maybe MutKernel4Stream(Symbol stream) const; bool has_device_and_stream_infer_fn() const { return static_cast(device_and_stream_infer_fn_); } const user_op::DeviceAndStreamInferFn& device_and_stream_infer_fn() const { return device_and_stream_infer_fn_; } bool IsHostMemoryInput(int32_t input_index) const { return std::find(host_memory_input_ids_.begin(), host_memory_input_ids_.end(), input_index) != host_memory_input_ids_.end(); } Maybe InferPhysicalTensorDesc( const AttrMap& attrs, const std::string& device_tag, const std::function& TensorMeta4InputIndex, const std::function& TensorMeta4OutputIndex) const; Maybe InferLogicalTensorDesc( const AttrMap& attrs, Symbol parallel_desc, const std::function& TensorMeta4InputIndex, const std::function& TensorMeta4OutputIndex) const; Maybe> InferDeviceAndStream(const AttrMap& attrs, const TensorTuple& inputs, TensorTuple* outputs) const; LocalTensorInferCache* mut_local_tensor_infer_cache() const { return local_tensor_infer_cache_.get(); } GlobalTensorInferCache* mut_global_tensor_infer_cache() const { return global_tensor_infer_cache_.get(); } private: UserOpExpr(const std::string& op_name, UserOpConf&& proto, const AttrMap& base_attrs, const std::vector& indexed_ibns, const std::vector& indexed_obns); Maybe Init(const std::shared_ptr& self); AttrMap base_attrs_; user_op::TensorDescInferFn logical_tensor_desc_infer_fn_; user_op::TensorDescInferFn physical_tensor_desc_infer_fn_; user_op::DataTypeInferFn dtype_infer_fn_; user_op::DeviceAndStreamInferFn device_and_stream_infer_fn_; mutable HashMap, std::shared_ptr> stream2kernel_; std::shared_ptr local_tensor_infer_cache_; std::shared_ptr global_tensor_infer_cache_; small_vector host_memory_input_ids_; }; class GlobalToGlobalOpExpr : public OpExpr { public: virtual ~GlobalToGlobalOpExpr() = default; static Maybe New(const Optional>& grad_nd_sbp); const Optional>& grad_nd_sbp() const { return grad_nd_sbp_; } const std::string& op_type_name() const override; int input_size() const override { return 1; } int output_size() const override { return 1; } Maybe IsGradDisabled() const override { return false; } Maybe SupportNonContiguous() const override { return false; } Maybe GetOrCreateOpGradClosure() const override; protected: GlobalToGlobalOpExpr(const Optional>& grad_nd_sbp); Optional> grad_nd_sbp_; // Reserved for configuring grad sbp mutable std::shared_ptr op_grad_func_; }; class 
CastGlobalOpExpr : public OpExpr { public: virtual ~CastGlobalOpExpr() = default; const std::string& op_name() const { return op_name_; } int input_size() const override { return 1; } int output_size() const override { return 1; } Maybe IsGradDisabled() const override { return false; } Maybe SupportNonContiguous() const override { return false; } protected: CastGlobalOpExpr(const std::string& op_name); std::string op_name_; mutable std::shared_ptr op_grad_func_; }; class LocalToGlobalOpExpr final : public CastGlobalOpExpr { public: ~LocalToGlobalOpExpr() = default; static Maybe New(const std::string& op_name); const std::string& op_type_name() const override; Maybe GetOrCreateOpGradClosure() const override; private: LocalToGlobalOpExpr(const std::string& op_name); }; class GlobalToLocalOpExpr final : public CastGlobalOpExpr { public: ~GlobalToLocalOpExpr() = default; static Maybe New(const std::string& op_name); const std::string& op_type_name() const override; Maybe GetOrCreateOpGradClosure() const override; private: GlobalToLocalOpExpr(const std::string& op_name); }; // NOTE(chengcheng): For Lazy nn.Graph Feed/Fetch EagerTensor to/from LazyTensor. using FeedInputOpExpr = BuiltinOpExprImpl; using FeedVariableOpExpr = BuiltinOpExprImpl; using FetchOutputOpExpr = BuiltinOpExprImpl; // NOTE(chengcheng): Special SystemOp for image gpu decode. using ImageDecoderRandomCropResizeOpExpr = BuiltinOpExprImpl; using VariableOpExpr = BuiltinOpExprImpl; using CastToLocalOpExpr = BuiltinOpExprImpl; using CastFromLocalOpExpr = BuiltinOpExprImpl; using DistributeSplitOpExpr = BuiltinOpExprImpl; using DistributeCloneOpExpr = BuiltinOpExprImpl; using DistributeConcatOpExpr = BuiltinOpExprImpl; using DistributeAddOpExpr = BuiltinOpExprImpl; class SelectTopNOpExpr final : public OpExpr { public: static Maybe New() { return std::shared_ptr(new SelectTopNOpExpr()); } const std::string& op_type_name() const override { static const std::string kOpTypeName = "select_top_n"; return kOpTypeName; } int input_size() const override { UNIMPLEMENTED(); return 0; } int output_size() const override { // output should be resized in apply function return 0; } Maybe IsGradDisabled() const override { return false; } Maybe SupportNonContiguous() const override { return false; } Maybe GetOrCreateOpGradClosure() const override; private: SelectTopNOpExpr() = default; mutable std::shared_ptr op_grad_func_; }; class AutoGradCaptureState; class FunctionOpExpr final : public OpExpr { public: using FType = AutogradFunctionBase::FType; FunctionOpExpr() = delete; static Maybe New(const std::string& func_name, const FType& forward_fn, const FType& backward_fn) { return std::shared_ptr(new FunctionOpExpr(func_name, forward_fn, backward_fn)); } const std::string& op_type_name() const override { return func_name_; } int input_size() const override { PRINT_BUG_PROMPT_AND_ABORT() << "You cannot get input_size here."; return 0; } int output_size() const override { PRINT_BUG_PROMPT_AND_ABORT() << "You cannot get output_size here."; return 0; } FType forward() const { return forward_fn_; } FType backward() const { return backward_fn_; } std::shared_ptr state() const { return state_; } void reset_state() const; Maybe IsGradDisabled() const override { return false; } Maybe SupportNonContiguous() const override { return false; } Maybe GetOrCreateOpGradClosure() const override; private: FunctionOpExpr(const std::string& func_name, const FType& forward_fn, const FType& backward_fn) : forward_fn_(forward_fn), backward_fn_(backward_fn), 
func_name_(func_name) {} FType forward_fn_; FType backward_fn_; std::string func_name_; mutable std::shared_ptr state_; mutable std::shared_ptr op_grad_func_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_EXPR_H_ ================================================ FILE: oneflow/core/framework/op_expr_grad_function.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/framework/saved_tensor_hooks.h" namespace oneflow { namespace one { void AutoGradCaptureState::unpack() { if (saved_tensors_.empty() && !hooks_.empty()) { for (const auto& hook : hooks_) { saved_tensors_.push_back(hook->unpack()); } hooks_.clear(); } } size_t AutoGradCaptureState::SaveTensorForBackward(const std::shared_ptr& tensor) { auto hook = []() -> std::unique_ptr { if (auto* hook_creator = Singleton::Get()) { return hook_creator->new_saved_tensor_hook(); } return nullptr; }(); if (hook) { hook->pack(tensor); size_t offset = hooks_.size(); hooks_.push_back(std::move(hook)); return offset; } else { size_t offset = saved_tensors_.size(); if (tensor->is_local() && tensor->is_eager()) { if (auto rematable_storage = std::dynamic_pointer_cast( CHECK_JUST(tensor->eager_blob_object())->tensor_storage())) { rematable_storage->set_needed_by_backward(); } } saved_tensors_.emplace_back(tensor); return offset; } } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_expr_grad_function.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_EXPR_GRAD_FUNCTION_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_EXPR_GRAD_FUNCTION_H_ #include "oneflow/core/autograd/autograd_captured_tensor.h" #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/common/op_args_vector.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/framework/saved_tensor_hooks.h" namespace oneflow { namespace one { static constexpr char kGradientOpSuffix[] = ".grad"; class AutoGradCaptureState { public: AutoGradCaptureState() = default; virtual ~AutoGradCaptureState() = default; void unpack(); const TensorTuple& SavedTensors() const { return saved_tensors_; } size_t SaveTensorForBackward(const std::shared_ptr& tensor); public: std::vector input_requires_grad; protected: TensorTuple saved_tensors_; small_vector, TensorTuple::kInitialSize> hooks_; }; class FunctionAutoGradCaptureState final : public AutoGradCaptureState, public std::enable_shared_from_this { public: FunctionAutoGradCaptureState() : pyobj_ptr_(nullptr, [](void*) {}) {} using AutoGradCaptureState::SavedTensors; using AutoGradCaptureState::SaveTensorForBackward; void MarkNonDifferentiable(const std::shared_ptr& tensor) { non_differentiable_tensors_.emplace(tensor.get()); } HashSet NonDifferentiableTensors() const { return non_differentiable_tensors_; } std::shared_ptr GetSharedFromThis() { return shared_from_this(); } // NOTE(wyg): Hold PyOjbect ptr to ensure getting the same object when casting to python. // And decrease the reference count when C++ object is destructed to avoid memory leaking. void* pyobject() const { return pyobj_ptr_.get(); } void set_pyobject_ptr(std::unique_ptr&& pyobj_ptr) { pyobj_ptr_ = std::move(pyobj_ptr); } public: std::vector input_requires_grad; private: HashSet non_differentiable_tensors_; std::unique_ptr pyobj_ptr_; }; // Stateless container base of the backward op exprs. // The backward op exprs should be contained in the derived class. class OpExprGradFunctionIf { public: virtual ~OpExprGradFunctionIf() = default; virtual std::shared_ptr MakeCustomState() const = 0; virtual Maybe Init(const OpExpr& op) = 0; // Capture forward inputs and outputs for backward. virtual Maybe CaptureIf(AutoGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const OpExprInterpContext& interp_ctx) const = 0; virtual Maybe ApplyIf(const AutoGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const = 0; }; template class OpExprGradFunction : public OpExprGradFunctionIf { public: std::shared_ptr MakeCustomState() const override { return std::make_shared(); } Maybe CaptureIf(AutoGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const OpExprInterpContext& interp_ctx) const override { StateT* state = dynamic_cast(ctx); CHECK_NOTNULL_OR_RETURN(state); // Convert outputs from `Tensor` to `AutogradCapturedTensor` to avoid // circular reference between `Tensor` and `FunctionNode`. 
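// NOTE (illustrative sketch, not part of the original file): because CaptureIf hands the derived
// Capture() tensors that are already wrapped as `AutogradCapturedTensor`, a subclass can save
// them via SaveTensorForBackward() without creating a Tensor <-> FunctionNode reference cycle.
// The op type name "my_identity" and the class names below are hypothetical.
#if 0
struct MyIdentityCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
};

class MyIdentityGrad : public OpExprGradFunction<MyIdentityCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(MyIdentityCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->requires_grad = inputs.at(0)->requires_grad();
    ctx->SaveTensorForBackward(inputs.at(0));  // safe: input is an AutogradCapturedTensor
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const MyIdentityCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(1);
    // identity op: the gradient passes through unchanged
    if (ctx->requires_grad) { (*in_grads)[0] = out_grads.at(0); }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("my_identity", MyIdentityGrad);
#endif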
OF_PROFILER_RANGE_PUSH("init inputs"); TensorTuple captured_inputs(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { captured_inputs[i] = JUST(AutogradCapturedTensor::MakeTensor(inputs.at(i))); } OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("init outputs"); TensorTuple captured_outputs(outputs.size()); for (int i = 0; i < outputs.size(); ++i) { captured_outputs[i] = JUST(AutogradCapturedTensor::MakeTensor(outputs.at(i))); } OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_GUARD("Capture"); return Capture(state, captured_inputs, captured_outputs, interp_ctx); } Maybe ApplyIf(const AutoGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const StateT* state = dynamic_cast(ctx); CHECK_NOTNULL_OR_RETURN(state); return Apply(state, out_grads, in_grads); } protected: virtual Maybe Capture(StateT* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const OpExprInterpContext& interp_ctx) const { return Capture(ctx, inputs, outputs, interp_ctx.attrs); } virtual Maybe Capture(StateT* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { UNIMPLEMENTED_THEN_RETURN(); } virtual Maybe Apply(const StateT* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const = 0; std::string GradientOpName(const std::string& prefix) const { return prefix + std::string(kGradientOpSuffix); } }; class FunctionOpExprGradFunction final : public OpExprGradFunctionIf { public: using FType = AutogradFunctionBase::FType; FunctionOpExprGradFunction(const std::string& func_name, const FType& backward_fn) : backward_fn_(backward_fn), op_name_(func_name) {} std::shared_ptr MakeCustomState() const override { PRINT_BUG_PROMPT_AND_ABORT() << "You should not construct AutoGradCaptureState by calling this function"; return std::make_shared(); } Maybe Init(const OpExpr& op) override { // do nothing return Maybe::Ok(); } Maybe CaptureIf(AutoGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const OpExprInterpContext& interp_ctx) const override { FunctionAutoGradCaptureState* func_ctx = dynamic_cast(ctx); func_ctx->input_requires_grad.resize(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { func_ctx->input_requires_grad[i] = inputs.at(i)->requires_grad(); } return Maybe::Ok(); } Maybe ApplyIf(const AutoGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const FunctionAutoGradCaptureState* func_ctx = dynamic_cast(ctx); CHECK_NOTNULL_OR_RETURN(func_ctx); const std::shared_ptr& out = backward_fn_( const_cast(func_ctx)->GetSharedFromThis(), out_grads); in_grads->resize(func_ctx->input_requires_grad.size()); CHECK_EQ_OR_RETURN(out->size(), in_grads->size()) << "RuntimeError: function " << op_name_ << " returned an incorrect number of gradients (expected " << in_grads->size() << ", got " << out->size() << ")"; for (int i = 0; i < in_grads->size(); ++i) { if (func_ctx->input_requires_grad[i]) { if (!out->at(i)) { return Error::RuntimeError() << "autograd.Function named " << op_name_ << "'s inputs[" << i << "] requires grad but got None grad. Please use Tensor.detach() for this " "input."; } in_grads->at(i) = out->at(i); } } return Maybe::Ok(); } protected: FType backward_fn_; std::string op_name_; }; // Stateful wrapper of the `OpExprGradFunction`. class OpExprGradClosure { public: // Use `shared_ptr` in order to keep `impl` alive even if the forward op has been released. 
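// NOTE (illustrative sketch, not part of the original file): a closure is typically created per
// forward call, captures while the op is applied, and is only consumed later during backward,
// possibly after the forward OpExpr has been released. The call site below is hypothetical.
#if 0
// at forward time:
std::shared_ptr<OpExprGradClosure> closure = JUST(op_expr->GetOrCreateOpGradClosure());
JUST(closure->Capture(inputs, outputs, interp_ctx));
// later, at backward time:
TensorTuple in_grads(op_expr->input_size());
JUST(closure->Apply(out_grads, &in_grads));
#endif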
explicit OpExprGradClosure(const std::shared_ptr& impl) : OpExprGradClosure(impl, impl->MakeCustomState()) {} explicit OpExprGradClosure(const std::shared_ptr& impl, const std::shared_ptr& state) : impl_(impl), state_(state) {} virtual ~OpExprGradClosure() = default; Maybe Capture(const TensorTuple& inputs, const TensorTuple& outputs, const OpExprInterpContext& interp_ctx) const { return impl_->CaptureIf(state_.get(), inputs, outputs, interp_ctx); } Maybe Apply(const TensorTuple& out_grads, TensorTuple* in_grads) const { state_->unpack(); return impl_->ApplyIf(state_.get(), out_grads, in_grads); } const std::shared_ptr& state() const { return state_; } private: std::shared_ptr impl_; std::shared_ptr state_; }; #define REGISTER_OP_EXPR_GRAD_FUNCTION(op_type, op_grad) \ REGISTER_CLASS_CREATOR(std::string, op_type, OpExprGradFunctionIf, ([]() { return new op_grad; })) } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_EXPR_GRAD_FUNCTION_H_ ================================================ FILE: oneflow/core/framework/op_interpreter/dispatch_frame.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_interpreter/dispatch_frame.h" #include namespace oneflow { /* static */ std::string* DispatchFrame::get_str_ptr() { static thread_local std::string frame_str = ""; return &frame_str; } /* static */ const std::string& DispatchFrame::get_str() { return *get_str_ptr(); } /* static */ void DispatchFrame::set_str(const std::string& str) { *get_str_ptr() = str; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_interpreter/dispatch_frame.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_DISPATCH_FRAME_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_DISPATCH_FRAME_H_ #include "oneflow/core/common/util.h" namespace oneflow { class DispatchFrame { public: OF_DISALLOW_COPY_AND_MOVE(DispatchFrame); DispatchFrame() = delete; ~DispatchFrame() = delete; static const std::string& get_str(); static void set_str(const std::string& str); class Guard { public: explicit Guard(const std::string& frame_str) : prev_frame_str_(DispatchFrame::get_str()) { DispatchFrame::set_str(frame_str); } ~Guard() { DispatchFrame::set_str(prev_frame_str_); } private: std::string prev_frame_str_; }; private: static std::string* get_str_ptr(); }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_DISPATCH_FRAME_H_ ================================================ FILE: oneflow/core/framework/op_interpreter/eager_global_op_interpreter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/framework/session_util.h" #include "oneflow/core/framework/symbol_storage_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/global_tensor_infer_cache.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/tensor_global_id.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/boxing/eager_boxing_logger.h" #include "oneflow/core/common/cpp_attribute.h" namespace oneflow { namespace one { namespace { bool IsEnvEnableGlobalInputsWithInConsistentPlacement() { const bool env_enable_inconsistent_placement = ParseBooleanFromEnv("ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT", false); return env_enable_inconsistent_placement; } Maybe IsInputsParallelDescIdentical( const std::shared_ptr& infer_args) { if (infer_args->input_global_tensor_metas().empty()) { return true; } Symbol default_parallel_desc = JUST(VectorAt(infer_args->input_global_tensor_metas(), 0)).tensor_meta()->parallel_desc(); for (int i = 1; i < infer_args->input_global_tensor_metas().size(); ++i) { const auto& parallel_desc = JUST(VectorAt(infer_args->input_global_tensor_metas(), i)) .tensor_meta() ->parallel_desc() ->data(); if (!default_parallel_desc->EqualsIgnoringDeviceType(parallel_desc)) { return false; } } return true; } constexpr auto* 
IsAllInputsParallelDescIdentical = DECORATE(&IsInputsParallelDescIdentical, ThreadLocalCopiable); Maybe MaxRankNumber(Symbol placement) { // Find max rank number of a tensor's placement // e.g. tensor's placement is [[0,1,2],[2,3,4],[7,8,9]] // then max rank number is 9 return placement->sorted_machine_ids().back(); } constexpr auto* GetMaxRankNumber = DECORATE(&MaxRankNumber, ThreadLocalCachedCopiable); Maybe> MaxRankTensorPlacement( const std::shared_ptr& infer_args) { // Find the max rank tensor id in all input tensors. // e.g. if there are three tensor in inputs // tensor parallel_desc // inputs[0] tensor a [0, 1, 2] // inputs[1] tensor b [3, 4, 5] // inputs[2] tensor c [2, 3, 4] // then max rank number is 5, max rank tensor is b, max rank tensor id is 1 const auto& global_tensor_metas = infer_args->input_global_tensor_metas(); CHECK_OR_RETURN(global_tensor_metas.size() > 0); // NOLINT int64_t max_rank_tensor_id = 0; int64_t max_rank = 0; for (int64_t i = 0; i < global_tensor_metas.size(); ++i) { int64_t tensor_max_rank = JUST( GetMaxRankNumber(JUST(VectorAt(global_tensor_metas, i)).tensor_meta()->parallel_desc())); if (tensor_max_rank >= max_rank) { max_rank = tensor_max_rank; max_rank_tensor_id = i; } } return JUST(VectorAt(global_tensor_metas, max_rank_tensor_id)).tensor_meta()->parallel_desc(); } constexpr auto* GetMaxRankTensorPlacement = DECORATE(&MaxRankTensorPlacement, ThreadLocalCachedCopiable); Maybe> GetParallelDesc(const TensorTuple& inputs, const OpExprInterpContext& ctx, const UserOpExpr& user_op_expr) { if (!inputs.empty()) { for (int32_t i = 0; i < inputs.size(); ++i) { if (!user_op_expr.IsHostMemoryInput(i)) { return inputs.at(i)->parallel_desc(); } } } return JUST(ctx.parallel_desc); } std::string GetDynamicOpGlobalFailedDebugString(const UserOpExpr& user_op_expr, const StatefulOpKernel& kernel) { CHECK(!kernel.output_tuple_indexes4mut2_obns().empty()); std::string plentysuffix = kernel.output_tuple_indexes4mut2_obns().size() == 1 ? "s" : ""; std::stringstream ss; ss << "operator `" << user_op_expr.op_type_name() << "`" << " does not support global mode because the shape" << plentysuffix << " of output tensor" << plentysuffix << " "; int i = 0; for (const auto& out_index : kernel.output_tuple_indexes4mut2_obns()) { if (i++ > 0) { ss << ", "; } ss << out_index; } ss << " are not infered before op computation."; return ss.str(); } Maybe IsAllZeroSizeTensorMeta(const std::vector>& tensor_metas) { if (tensor_metas.empty()) { return false; } for (const auto& tensor_meta : tensor_metas) { if (tensor_meta->shape().elem_cnt() != 0) { return false; } } return true; } constexpr auto* CachedIsAllZeroSizeTensorMeta = DECORATE(&IsAllZeroSizeTensorMeta, ThreadLocalCopiable); Maybe CalcBoxingOutput(const std::shared_ptr& input, Symbol out_nd_sbp, Symbol out_parallel_desc, bool current_rank_local_is_valid) { const auto& logical_shape = input->shape(); // If the input is a tensor of size 0, construct the output directly. 
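// NOTE (worked example, not part of the original file): a global tensor with logical shape
// (0, 8) has elem_cnt() == 0, so the branch below just materializes an empty GlobalTensor with
// the requested nd_sbp and placement, and no eager boxing kernel is launched.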
if (unlikely(logical_shape->elem_cnt() == 0)) { GlobalTensorMeta tensor_meta(*logical_shape, input->dtype()->data_type(), input->memory_format(), out_nd_sbp, out_parallel_desc); const auto& tensor_impl = JUST(EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false)); std::shared_ptr output = std::make_shared(tensor_impl); return output; } const auto* mgr = Singleton::Get(); // Eager boxing const auto& in_nd_sbp = JUST(input->nd_sbp()); const auto& in_parallel_desc = JUST(input->parallel_desc()); const auto& boxing_interpreter = JUST(mgr->GetEagerBoxingInterpreter( in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *logical_shape)); Singleton::Get()->Log( *JUST(boxing_interpreter->boxing_interpreter_status()), /* prefix */ ""); if (!current_rank_local_is_valid) { return input; } const auto& output = JUST(boxing_interpreter->Interpret(input, in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc)); return output; } auto* GetBoxingOutput = DECORATE(DECORATE(&CalcBoxingOutput, CheckGlobalTensorMeta), DisableRecusiveBoxingCall); Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); Symbol parallel_desc = JUST(GetParallelDesc(inputs, ctx, user_op_expr)); std::shared_ptr result; NonRecursiveMetaInfoConsistencyCheckScope scope; // extend lifetime of boxing outputs to the end of this function TensorTuple boxing_inputs = inputs; if (inputs.empty()) { // check consistency placement and nd_sbp, do not check in non-src op because it is assumed // that InferSbp in op is a deterministic algorithm JUST(MetaInfoConsistencyCheck(parallel_desc, ctx.nd_sbp, 1, /* force_check */ false)); const auto& infer_args = JUST(SrcOpGlobalTensorMetaInferArgs::New(ctx.attrs, parallel_desc, JUST(ctx.nd_sbp))); result = JUST(user_op_expr.mut_global_tensor_infer_cache()->GetOrInfer(*infer_args)); } else { for (int i = 0; i < outputs->size(); ++i) { if ((*outputs)[i]) { const auto& nd_sbp = JUST((*outputs)[i]->nd_sbp()); JUST((*outputs)[i]->set_consumer_nd_sbp_constraint(nd_sbp)); } } std::shared_ptr infer_args = JUST(GlobalTensorMetaInferArgs::New(ctx.attrs, boxing_inputs)); // is_identical is true indicating all input tensors have the same parallel_desc const bool is_identical = JUST(IsAllInputsParallelDescIdentical(infer_args)); // if is_identical is false and env 'ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT' is set to // true, then traverse all input tensors using GetBoxingOutput(); during this process, // each tensor will be cast to_global with the target parallel_desc if (IsEnvEnableGlobalInputsWithInConsistentPlacement() && !is_identical) { parallel_desc = JUST(GetMaxRankTensorPlacement(infer_args)); Optional parallel_id; JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id)); for (int i = 0; i < inputs.size(); ++i) { const auto& input = inputs.at(i); Optional input_parallel_id; JUST(GetTensorDevice4CurrentProcessCtx(JUST(input->parallel_desc()), &input_parallel_id)); const auto& final_input = JUST(GetBoxingOutput(input, JUST(inputs[i]->nd_sbp()), parallel_desc, input_parallel_id.has_value() || parallel_id.has_value())); boxing_inputs[i] = final_input; } infer_args = JUST(GlobalTensorMetaInferArgs::New(ctx.attrs, boxing_inputs)); } result = JUST(user_op_expr.mut_global_tensor_infer_cache()->GetOrInfer(*infer_args)); } const auto& output_tensor_metas = result->output_tensor_metas(); Optional parallel_id; const auto& tensor_device =
JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id)); for (int i = 0; i < outputs->size(); ++i) { if (!outputs->at(i)) { const auto& tensor_impl = JUST(EagerGlobalTensorImpl::New( output_tensor_metas[i], tensor_device, parallel_id, false, false)); (*outputs)[i].reset(new GlobalTensor(tensor_impl)); } else { JUST((*outputs)[i]->set_consumer_nd_sbp_constraint(NullOpt)); } } // Do nothing if output_tensors has 0-size shape. Since the input of some ops is 0-size but the // output is not 0-size, it cannot be judged based on the input, such as flow.cat if (unlikely(JUST(CachedIsAllZeroSizeTensorMeta(output_tensor_metas)))) { return Maybe::Ok(); } // Run instruction Call const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream())); CHECK_EQ_OR_RETURN(kernel->output_tuple_indexes4mut2_obns().size(), 0) << Error::UnimplementedError() << GetDynamicOpGlobalFailedDebugString(user_op_expr, *kernel); vm::EagerBlobObjectList input_eager_blob_objects(boxing_inputs.size()); // extend lifetime of boxing outputs to the end of this function TensorTuple boxing_outputs; for (int i = 0; i < boxing_inputs.size(); ++i) { std::shared_ptr input = boxing_inputs.at(i); const auto& infered_input_meta = result->input_tensor_metas().at(i); const auto& input_parallel_desc = JUST(input->parallel_desc()); CHECK_OR_RETURN(input_parallel_desc == infered_input_meta->parallel_desc()); bool is_host_input = user_op_expr.IsHostMemoryInput(i); Symbol dst_parallel_desc = is_host_input ? JUST(ReplaceDeviceType(infered_input_meta->parallel_desc(), DeviceType::kCPU)) : infered_input_meta->parallel_desc(); if ((input_parallel_desc->parallel_num() != 1 && infered_input_meta->nd_sbp() != JUST(input->nd_sbp())) || input_parallel_desc->device_type() != dst_parallel_desc->device_type()) { input = JUST(GetBoxingOutput(input, infered_input_meta->nd_sbp(), dst_parallel_desc, parallel_id.has_value())); boxing_outputs.emplace_back(input); } const auto& local_tensor = JUST(input->cur_rank_phy_tensor()); input_eager_blob_objects.at(i) = JUST(local_tensor->eager_blob_object()); } // Do nothing if the `parallel_desc` doesn't cover current ProcessCtx.
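// NOTE (worked example, not part of the original file): if the op's placement only spans ranks
// {0, 1} and this process is rank 2, GetTensorDevice4CurrentProcessCtx leaves parallel_id empty,
// this rank owns no local shard of the outputs, and the instruction call below is skipped.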
if (!parallel_id.has_value()) { return Maybe::Ok(); } vm::EagerBlobObjectList output_eager_blob_objects(outputs->size()); for (int i = 0; i < outputs->size(); ++i) { const auto& local_tensor = JUST(outputs->at(i)->cur_rank_phy_tensor()); output_eager_blob_objects.at(i) = JUST(local_tensor->eager_blob_object()); } if (tensor_device->enum_type() == DeviceType::kMeta) { return Maybe::Ok(); } JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->Call(kernel, std::move(input_eager_blob_objects), std::move(output_eager_blob_objects), result, ctx, result->stream()); })); return Maybe::Ok(); } auto* InterpretThenInitGlobalId = DECORATE(&Interpret, NonRecursiveInitGlobalId); } // namespace Maybe EagerGlobalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return InterpretThenInitGlobalId(op_expr, inputs, outputs, ctx); } Maybe EagerGlobalInterpreter::ApplyImpl(const VariableOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } namespace { static constexpr auto* RecursiveGetBoxingOutput = DECORATE(&CalcBoxingOutput, CheckGlobalTensorMeta); Maybe RawGlobalToGlobal(const GlobalToGlobalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs->size(), 1); const auto& input = inputs.at(0); CHECK_OR_RETURN(input->is_global()); // NOLINT CHECK_OR_RETURN(ctx.parallel_desc.has_value()); CHECK_OR_RETURN(ctx.nd_sbp.has_value()); const auto& in_parallel_desc = JUST(input->parallel_desc()); const auto& out_nd_sbp = JUST(ctx.nd_sbp); const auto& out_parallel_desc = JUST(ctx.parallel_desc); const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(in_parallel_desc)); const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc)); const auto& tensor = JUST(RecursiveGetBoxingOutput(input, out_nd_sbp, out_parallel_desc, in_parallel_id->has_value() || out_parallel_id->has_value())); CHECK_OR_RETURN(tensor); if (out_parallel_id->has_value()) { const auto& nd_sbp = JUST(tensor->nd_sbp()); const auto& parallel_desc = JUST(tensor->parallel_desc()); CHECK_OR_RETURN(nd_sbp == out_nd_sbp) << ". 
nd_sbp: " << NdSbpToString(nd_sbp) << ", out_nd_sbp" << NdSbpToString(out_nd_sbp); CHECK_OR_RETURN(parallel_desc == out_parallel_desc); outputs->at(0) = tensor; } else { GlobalTensorMeta tensor_meta(*tensor->shape(), tensor->dtype()->data_type(), tensor->memory_format(), out_nd_sbp, out_parallel_desc); const auto& tensor_impl = JUST(EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), tensor->requires_grad(), false)); (*outputs)[0].reset(new GlobalTensor(tensor_impl)); } CHECK_OR_RETURN(outputs->at(0)); return Maybe::Ok(); } static constexpr auto* GlobalToGlobal = DECORATE(&RawGlobalToGlobal, NonRecursiveInitGlobalId); } // namespace Maybe EagerGlobalInterpreter::ApplyImpl(const GlobalToGlobalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { JUST(GlobalToGlobal(op_expr, inputs, outputs, ctx)); return Maybe::Ok(); } Maybe EagerGlobalInterpreter::ApplyImpl(const LocalToGlobalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerGlobalInterpreter::ApplyImpl(const GlobalToLocalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { CHECK_EQ_OR_RETURN(inputs.size(), 1); const auto& input_tensor = inputs.at(0); const auto& local_tensor = JUST(JUST(input_tensor->cur_rank_phy_tensor())->detach()); bool requires_grad = autograd::GradMode::is_enabled() && input_tensor->requires_grad(); JUST(local_tensor->set_requires_grad(requires_grad)); local_tensor->set_is_leaf(!requires_grad); (*outputs)[0] = local_tensor; return Maybe::Ok(); } Maybe EagerGlobalInterpreter::ApplyImpl(const CastToLocalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerGlobalInterpreter::ApplyImpl(const CastFromLocalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerGlobalInterpreter::ApplyImpl(const DistributeSplitOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerGlobalInterpreter::ApplyImpl(const DistributeCloneOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerGlobalInterpreter::ApplyImpl(const DistributeConcatOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerGlobalInterpreter::ApplyImpl(const DistributeAddOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerGlobalInterpreter::ApplyImpl(const SelectTopNOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/framework/session_util.h" #include "oneflow/core/framework/symbol_storage_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/local_tensor_infer_cache.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/tensor_global_id.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/id_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { namespace { Maybe> RawGetDefaultCpuDevice() { return Device::New("cpu"); } constexpr auto* GetDefaultCpuDevice = DECORATE(&RawGetDefaultCpuDevice, ThreadLocal); Maybe> GetDefaultDevice(const TensorTuple& inputs, const OpExprInterpContext& ctx, const UserOpExpr& user_op_expr) { if (!inputs.empty()) { for (int32_t i = 0; i < inputs.size(); ++i) { if (!user_op_expr.IsHostMemoryInput(i)) { return JUST(inputs.at(i)->device()); } } } if (ctx.device.has_value()) { return JUST(ctx.device); } else { return GetDefaultCpuDevice(); } } Maybe TensorImpl4Tensor(const std::shared_ptr& tensor) { CHECK_OR_RETURN(static_cast(tensor)); return tensor->mut_eager_local_tensor_impl(); } } // namespace Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { OF_PROFILER_RANGE_GUARD("NaiveInterpret"); CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); // NOLINT Symbol default_device = JUST(GetDefaultDevice(inputs, ctx, user_op_expr)); const std::shared_ptr result = JUST([&]() -> Maybe { LocalTensorMetaInferArgs infer_args; JUST(infer_args.Init(ctx.attrs, default_device, inputs)); return JUST(user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); }()); vm::EagerBlobObjectList input_eager_blob_objects(inputs.size()); // expand lifetime of host_inputs to the end of this function TensorTuple host_inputs; for (int i = 0; i < inputs.size(); i++) { if (user_op_expr.IsHostMemoryInput(i)) { const auto& host_input = JUST(functional::To( inputs.at(i), Optional>(JUST(GetDefaultCpuDevice())), NullOpt, 
false)); input_eager_blob_objects.at(i) = JUST(host_input->eager_blob_object()); host_inputs.emplace_back(host_input); } else { input_eager_blob_objects.at(i) = JUST(inputs.at(i)->eager_blob_object()); } } const auto& output_tensor_metas = result->output_tensor_metas(); vm::EagerBlobObjectList output_eager_blob_objects(outputs->size()); const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream())); for (int i = 0; i < outputs->size(); i++) { if (!outputs->at(i)) { // NOTE: if op support stride(non-contiguous input), then output tensor's stride // should be inferred in InferLogicalTensorDesc. // otherwise, it will be set here(according to shape). std::shared_ptr mut_tensor_meta; { if (kernel->output_is_mut2_type(i)) { mut_tensor_meta = std::make_shared( output_tensor_metas.at(i)->shape(), output_tensor_metas.at(i)->stride(), output_tensor_metas.at(i)->dtype(), output_tensor_metas.at(i)->memory_format(), output_tensor_metas.at(i)->device()); } } std::shared_ptr tensor_impl = std::make_shared(false, false); const auto& dep_object = NewLocalDepObject(); JUST( tensor_impl->InitEagerBlobObject(output_tensor_metas.at(i), mut_tensor_meta, dep_object)); output_eager_blob_objects.at(i) = JUST(tensor_impl->eager_blob_object()); (*outputs)[i] = std::make_shared(tensor_impl); } else { const auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); // output i is inplaced. // check TensorMeta of infer result and TensorMeta of output i. CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() // NOLINT == output_tensor_metas.at(i)->shape()) // NOLINT << Error::RuntimeError() << tensor_impl->tensor_meta()->shape().ToString() // NOLINT << " .vs " // NOLINT << output_tensor_metas.at(i)->shape().ToString(); // NOLINT CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() // NOLINT == output_tensor_metas.at(i)->dtype()) // NOLINT << Error::RuntimeError() << DataType_Name(tensor_impl->tensor_meta()->dtype()) // NOLINT << " .vs " // NOLINT << DataType_Name(output_tensor_metas.at(i)->dtype()); // NOLINT bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object()); CHECK_OR_RETURN(has_eager_blob_object); // NOLINT output_eager_blob_objects.at(i) = JUST(outputs->at(i)->eager_blob_object()); // TODO(zhaoluyang):(thread_local TensorMeta set stride then check) // CHECK_OR_RETURN(tensor_impl->tensor_meta()->stride() == // output_tensor_metas->at(i)->stride()); } } if (default_device->enum_type() == DeviceType::kMeta) { return Maybe::Ok(); } JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->Call(kernel, std::move(input_eager_blob_objects), std::move(output_eager_blob_objects), ctx, result->stream()); })); for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) { const auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(index))); auto btb = std::make_shared(); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SyncAccessBlobByCallback( tensor_impl, btb, [](ep::Stream* stream, const std::shared_ptr&) {}, "const"); })); JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); const auto& mut_tensor_meta = const_cast(tensor_impl)->mut_tensor_meta(); Symbol new_tensor_meta = SymbolOf(LocalTensorMeta( mut_tensor_meta->shape(), mut_tensor_meta->stride(), mut_tensor_meta->dtype(), mut_tensor_meta->memory_format(), mut_tensor_meta->device())); std::shared_ptr final_tensor_impl = std::make_shared(JUST(tensor_impl->tensor_storage()), JUST(tensor_impl->storage_offset()), false, false); 
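// NOTE (descriptive comment added for clarity, not in the original file): for mut2 ("shape known
// only after execution") outputs, the code above synchronizes on the blob and re-reads
// mut_tensor_meta; the fresh impl below reuses the same tensor storage but is re-initialized with
// the now-final tensor meta and then swapped into the output tensor.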
JUST(final_tensor_impl->InitEagerBlobObject( new_tensor_meta, JUST(JUST(outputs->at(index)->eager_blob_object())->compute_local_dep_object()))); JUST(JUST(outputs->at(index)->AsLocalTensor())->set_impl(final_tensor_impl)); } return Maybe::Ok(); } Maybe EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return NaiveInterpret(op_expr, inputs, outputs, ctx); } Maybe EagerLocalInterpreter::ApplyImpl(const VariableOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } static Maybe BuildAndRunLocalCastInstruction(const BuiltinOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) { // TODO() OF_UNIMPLEMENTED(); } namespace { Maybe EagerCclBroadcast(Symbol parallel_desc, int64_t root, size_t size, const std::vector& shape_list) { return one::OpBuilder("eager_ccl_broadcast", *JUST(UniqueStr("eager_ccl_broadcast"))) .Input("in", size) .Output("out", size) .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Attr>("shape_list", shape_list) .Attr("root", root) .Build(); } auto* CachedEagerCclBroadcastOpExpr = DECORATE(&EagerCclBroadcast, ThreadLocalCachedCopiable); } // namespace Maybe Broadcast(const std::shared_ptr& tensor, int64_t src_rank, Symbol parallel_desc, bool inplace) { CHECK_OR_RETURN(parallel_desc->containing_current_rank()); if (parallel_desc->parallel_num() == 1 /* no broadcast */) { return tensor; } std::shared_ptr op_expr = JUST(CachedEagerCclBroadcastOpExpr(parallel_desc, src_rank, 1, {*tensor->shape()})); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("root"); attrs.SetAllAttrs(src_rank); if (inplace) { TensorTuple outputs{tensor}; JUST(OpInterpUtil::Dispatch(*op_expr, {tensor}, &outputs, one::OpExprInterpContext(attrs, parallel_desc))); return tensor; } else { return JUST(OpInterpUtil::Dispatch( *op_expr, {tensor}, one::OpExprInterpContext(attrs, parallel_desc))); } } Maybe Broadcast(const TensorTuple& inputs, int64_t src_rank, Symbol parallel_desc, bool inplace) { CHECK_OR_RETURN(parallel_desc->containing_current_rank()) << "Current rank is not contained in the placement argument"; if (parallel_desc->parallel_num() == 1 /* no broadcast */) { return inputs; } std::vector shape_list; for (const auto& tensor : inputs) { shape_list.emplace_back(*tensor->shape()); } std::shared_ptr op_expr = JUST(CachedEagerCclBroadcastOpExpr(parallel_desc, src_rank, inputs.size(), shape_list)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("root"); attrs.SetAllAttrs(src_rank); if (inplace) { auto outputs = std::make_shared(inputs); JUST(OpInterpUtil::Dispatch(*op_expr, inputs, outputs.get(), one::OpExprInterpContext(attrs, parallel_desc))); return outputs; } else { return JUST(OpInterpUtil::Dispatch( *op_expr, inputs, one::OpExprInterpContext(attrs, parallel_desc))); } } namespace { Maybe GetSyncedTensorIfBroadcast(const std::shared_ptr& tensor, Symbol parallel_desc, Symbol nd_sbp, bool inplace) { Optional parallel_id; JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id)); if (!parallel_id.has_value()) { return tensor; } const auto& broadcast_parallel_desc = JUST(GetBroadcastSubParallelDesc(parallel_desc, nd_sbp)); int64_t root = JUST(broadcast_parallel_desc->MachineId4ParallelId(0)); if (broadcast_parallel_desc->parallel_num() > 1 && inplace && GlobalProcessCtx::Rank() == 0) { LOG_FIRST_N(WARNING, 1) << "Casting a local tensor to a global tensor with Broadcast sbp will modify the data of 
" "input! " "If you want to keep the input local tensor unchanged, please set the arg copy to True."; } return Broadcast(tensor, root, broadcast_parallel_desc, inplace); } Maybe CalcPhysicalShape(Symbol global_tensor_meta) { const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(global_tensor_meta->parallel_desc())); int64_t parallel_id = JUST(*opt_parallel_id); return GetPhysicalShape(global_tensor_meta->shape(), *global_tensor_meta->nd_sbp(), *global_tensor_meta->parallel_desc(), parallel_id); } static constexpr auto* GetPhysicalShape = DECORATE(&CalcPhysicalShape, ThreadLocal); Maybe TryReshapeTensor(const std::shared_ptr& tensor, Symbol global_tensor_meta) { CHECK_OR_RETURN(tensor->is_local()); const auto& physical_shape = JUST(GetPhysicalShape(global_tensor_meta)); if (*physical_shape == *tensor->shape()) { return tensor; } CHECK_EQ_OR_RETURN(physical_shape->elem_cnt(), tensor->shape()->elem_cnt()); // TODO(lixinqi) inplace reshape. return tensor; } } // namespace Maybe EagerLocalInterpreter::ApplyImpl(const GlobalToGlobalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } namespace { Maybe RawLocalToGlobal(const LocalToGlobalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { std::shared_ptr input_local_tensor; { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_OR_RETURN(!inputs[0]->is_global()); // NOLINT const auto& input_tensor = JUST(inputs.at(0)->detach()); input_local_tensor = JUST(input_tensor->AsLocalTensor()); CHECK_OR_RETURN(input_local_tensor) << Error::InvalidValueError() << "Tensor Cast Error"; // NOLINT bool requires_grad = autograd::GradMode::is_enabled() && inputs.at(0)->requires_grad(); JUST(input_local_tensor->set_requires_grad(requires_grad)); input_local_tensor->set_is_leaf(!requires_grad); } std::shared_ptr global_tensor; { CHECK_OR_RETURN(ctx.parallel_desc.has_value()); CHECK_OR_RETURN(ctx.nd_sbp.has_value()); const auto& nd_sbp = JUST(ctx.nd_sbp); const auto& parallel_desc = JUST(ctx.parallel_desc); const auto& logical_shape = JUST(ctx.attrs.GetAttr("shape")); DataType dtype = JUST(ctx.attrs.GetAttr("dtype")); // MemoryFormat memory_format = JUST(ctx.attrs.GetAttr("memory_format")); GlobalTensorMeta tensor_meta(logical_shape, dtype, MemoryFormat::kContiguous, nd_sbp, parallel_desc); Optional parallel_id{}; const auto& device = JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id)); const auto& global_tensor_impl = JUST(EagerGlobalTensorImpl::New( SymbolOf(tensor_meta), device, parallel_id, input_local_tensor->requires_grad(), !input_local_tensor->requires_grad())); global_tensor = std::make_shared(global_tensor_impl); if (parallel_id.has_value()) { const auto& physical_shape = JUST(GetPhysicalShape(tensor_meta)); const auto& input_local_tensor_shape = input_local_tensor->shape(); CHECK_EQ_OR_RETURN(*physical_shape, *input_local_tensor_shape); // NOLINT CHECK_OR_RETURN(dtype == input_local_tensor->dtype()->data_type()); // NOLINT global_tensor_impl->reset_cur_rank_phy_tensor(input_local_tensor); } } (*outputs)[0] = global_tensor; return Maybe::Ok(); } static constexpr auto* LocalToGlobal = DECORATE(&RawLocalToGlobal, NonRecursiveInitGlobalId); } // namespace Maybe EagerLocalInterpreter::ApplyImpl(const LocalToGlobalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { bool sync_data = JUST(ctx.attrs.GetAttr("sync_data")); JUST(LocalToGlobal(op_expr, inputs, outputs, ctx));
const auto& global_tensor = JUST((*outputs)[0]->AsGlobalTensor()); JUST(WithConsistencyChecked(global_tensor, [&]() -> Maybe { if (IsGlobalTensorMetaCheckDisabled()) { return Maybe::Ok(); } const auto& parallel_desc = JUST(ctx.parallel_desc); const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc)); if (!parallel_id->has_value()) { return Maybe::Ok(); } const auto& nd_sbp = JUST(ctx.nd_sbp); const auto& tensor_meta = JUST(global_tensor->global_tensor_meta()); const auto& local_tensor = JUST(global_tensor->cur_rank_phy_tensor()); const auto& reshaped_tensor = JUST(TryReshapeTensor(local_tensor, tensor_meta)); std::shared_ptr synced_tensor = reshaped_tensor; if (sync_data) { bool inplace = JUST(ctx.attrs.GetAttr("inplace_when_sync_data")); synced_tensor = JUST(GetSyncedTensorIfBroadcast(reshaped_tensor, parallel_desc, nd_sbp, inplace)); } auto* global_tensor_impl = reinterpret_cast(global_tensor->mut_impl()); CHECK_NOTNULL_OR_RETURN(global_tensor_impl); global_tensor_impl->reset_cur_rank_phy_tensor(JUST(synced_tensor->AsLocalTensor())); return Maybe::Ok(); })); return Maybe::Ok(); } Maybe EagerLocalInterpreter::ApplyImpl(const GlobalToLocalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerLocalInterpreter::ApplyImpl(const CastToLocalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return BuildAndRunLocalCastInstruction(op_expr, inputs, outputs); } Maybe EagerLocalInterpreter::ApplyImpl(const CastFromLocalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return BuildAndRunLocalCastInstruction(op_expr, inputs, outputs); } static Maybe BuildAndRunDistributeSplitOrCloneInstruction(const BuiltinOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) { // TODO() OF_UNIMPLEMENTED(); } Maybe EagerLocalInterpreter::ApplyImpl(const DistributeSplitOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs); } Maybe EagerLocalInterpreter::ApplyImpl(const DistributeCloneOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs); } static Maybe BuildAndRunDistributeConcatAndAddInstruction(const BuiltinOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) { // TODO() OF_UNIMPLEMENTED(); } Maybe EagerLocalInterpreter::ApplyImpl(const DistributeConcatOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs); } Maybe EagerLocalInterpreter::ApplyImpl(const DistributeAddOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs); } Maybe EagerLocalInterpreter::ApplyImpl(const SelectTopNOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { int top_n = JUST(ctx.attrs.GetAttr("top_n")); outputs->resize(top_n); for (int i = 0; i < top_n; ++i) { (*outputs)[i] = JUST(JUST(VectorAt(inputs, i))->detach()); } return Maybe::Ok(); } } // namespace one } // namespace oneflow ================================================ FILE: 
oneflow/core/framework/op_interpreter/eager_local_op_interpreter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/eager/eager_blob_object.h" namespace oneflow { class Device; class TensorTuple; class ParallelDesc; namespace one { class Tensor; Maybe Broadcast(const std::shared_ptr& tensor, int64_t src_rank, Symbol parallel_desc, bool inplace); Maybe Broadcast(const TensorTuple& inputs, int64_t src_rank, Symbol parallel_desc, bool inplace); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_interpreter/lazy_op_interpreter.h" #include #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/framework/session_util.h" #include "oneflow/core/framework/symbol_storage_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { namespace { Maybe BuildTensor(const OpAttribute& op_attribute, const std::string& bn_in_op, const std::shared_ptr& parallel_desc, const bool is_lazy, const bool is_local) { CHECK_OR_RETURN(op_attribute.has_logical_blob_desc_signature()); // NOLINT(maybe-need-error-msg) const auto& blob_desc_sign_map = op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc(); auto blob_desc_it = blob_desc_sign_map.find(bn_in_op); CHECK_OR_RETURN(blob_desc_it != blob_desc_sign_map.end()) << "blob_desc of " << bn_in_op << " not found in op " << op_attribute.op_conf().name(); auto shape = std::make_shared(blob_desc_it->second.shape()); auto stride = std::make_shared(shape); auto dtype = blob_desc_it->second.data_type(); auto memory_format = blob_desc_it->second.memory_format(); if (is_local) { const auto& device = JUST(Device::MakeDeviceByParallelDesc(*parallel_desc)); const auto& tensor = JUST(LocalTensor::MakeTensor(shape, stride, dtype, memory_format, device, is_lazy, /* requires_grad= */ false, /* is_leaf= */ true)); return static_cast>(tensor); } else { const auto& nd_sbp_sign_map = op_attribute.nd_sbp_signature().bn_in_op2nd_sbp(); auto nd_sbp_it = nd_sbp_sign_map.find(bn_in_op); CHECK_OR_RETURN(nd_sbp_it != nd_sbp_sign_map.end()) << "nd_sbp of " << bn_in_op << " not found in op " << op_attribute.op_conf().name(); NdSbp nd_sbp(nd_sbp_it->second); const auto& tensor = JUST(GlobalTensor::MakeTensor( shape, dtype, memory_format, SymbolOf(nd_sbp), SymbolOf(*parallel_desc), is_lazy, /*requires_grad=*/false, /*is_leaf=*/true)); return static_cast>(tensor); } } Maybe CheckTensorMatchAttr(const std::shared_ptr& tensor, const OpAttribute& op_attribute, const std::string& bn_in_op, const std::shared_ptr& parallel_desc, const bool is_local) { CHECK_EQ_OR_RETURN(tensor->is_local(), is_local); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(op_attribute.has_logical_blob_desc_signature()); // NOLINT(maybe-need-error-msg) const auto& blob_desc_sign_map = op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc(); auto blob_desc_it = blob_desc_sign_map.find(bn_in_op); 
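// NOTE (descriptive comment added for clarity, not in the original file): this mirrors the lookup
// in BuildTensor above; an output tensor that already exists (e.g. an inplace output) must match
// the shape/dtype, device (local case) or nd_sbp/placement (global case; the nd_sbp check is
// skipped when auto parallel is enabled) that the lazy infer context derived for this bn_in_op.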
CHECK_OR_RETURN(blob_desc_it != blob_desc_sign_map.end()) << "blob_desc of " << bn_in_op << " not found in op " << op_attribute.op_conf().name(); auto shape = std::make_shared(blob_desc_it->second.shape()); auto dtype = blob_desc_it->second.data_type(); CHECK_EQ_OR_RETURN(*tensor->shape(), *shape); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(tensor->dtype()->data_type(), dtype); // NOLINT(maybe-need-error-msg) if (is_local) { const auto& device = JUST(Device::MakeDeviceByParallelDesc(*parallel_desc)); CHECK_OR_RETURN(JUST(tensor->device()) == device); // NOLINT(maybe-need-error-msg) } else { const auto& nd_sbp_sign_map = op_attribute.nd_sbp_signature().bn_in_op2nd_sbp(); auto nd_sbp_it = nd_sbp_sign_map.find(bn_in_op); CHECK_OR_RETURN(nd_sbp_it != nd_sbp_sign_map.end()) << "nd_sbp of " << bn_in_op << " not found in op " << op_attribute.op_conf().name(); // Only check the nd_sbp if auto parallel is not enable, // since the semi-auto parallellism rule might have inconsistency with the auto-parallel // strategy. if (!GlobalJobDesc().enable_auto_parallel()) { NdSbp nd_sbp(nd_sbp_it->second); CHECK_OR_RETURN(JUST(tensor->nd_sbp()) == SymbolOf(nd_sbp)) << "The input sbp is not valid for an inplace operation, please try to use non-inplace. " << NdSbpToString(JUST(tensor->nd_sbp())) << " vs " << NdSbpToString(nd_sbp); } CHECK_OR_RETURN(JUST(tensor->parallel_desc()) // NOLINT(maybe-need-error-msg) == SymbolOf(*parallel_desc)); // NOLINT(maybe-need-error-msg) } return Maybe::Ok(); } Maybe GetDeviceTagOfTensor(const std::shared_ptr& tensor) { if (tensor->is_global()) { return JUST(tensor->parallel_desc())->device_tag(); } return JUST(tensor->device())->type(); } bool GetIsDynamicOfTensor(const std::shared_ptr& tensor) { if (tensor->is_global()) { return false; } else { return true; } } Maybe GenNdSbpByTensor(NdSbp* nd_sbp, const std::shared_ptr& tensor) { nd_sbp->clear_sbp_parallel(); if (tensor->is_local()) { // NOTE(chengcheng): // OneFlow Lazy is always global. LocalTensor is a special case of GlobalTensor // which placement is only this rank, and SbpParallel is Broadcast. 
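// NOTE (worked example, not part of the original file): a local tensor on this rank's cuda:0 is
// treated as a global tensor placed on the current rank only with sbp = broadcast, which is
// exactly the single broadcast entry appended below; a global tensor contributes its own NdSbp
// unchanged.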
nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); } else { *nd_sbp = *JUST(tensor->nd_sbp()); } return Maybe::Ok(); } Maybe GenVariableOpConfNdSbpStringByTensor(VariableOpConf* var_conf, const std::shared_ptr& tensor) { var_conf->clear_nd_sbp(); if (tensor->is_local()) { SbpParallel broadcast; broadcast.mutable_broadcast_parallel(); var_conf->add_nd_sbp(SbpParallelToString(broadcast)); } else { const NdSbp& nd_sbp = *JUST(tensor->nd_sbp()); for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) { var_conf->add_nd_sbp(SbpParallelToString(sbp_parallel)); } } return Maybe::Ok(); } Maybe GetParallelDescOfTensor(const std::shared_ptr& tensor) { if (tensor->is_local()) { const auto& device = JUST(tensor->device()); const auto& placement = JUST(Placement4Device(device)); return placement.shared_from_symbol(); } else { return JUST(tensor->parallel_desc()).shared_from_symbol(); } } Maybe NewScopeWithParallelConfAndCurScope(const ParallelConf& parallel_conf) { std::shared_ptr new_scope; const auto& old_scope = JUST(GetCurrentScope()); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { new_scope = JUST(builder->BuildScopeWithNewParallelConf(old_scope, parallel_conf)); return Maybe::Ok(); })); // NOTE(chengcheng): need sync vm for get scope right now JUST(vm::CurrentRankSync()); CHECK_OR_RETURN(new_scope); // NOLINT(maybe-need-error-msg) return new_scope; } Maybe NewScopeWithParallelDescByTensor(const std::shared_ptr& tensor) { return NewScopeWithParallelConfAndCurScope( JUST(GetParallelDescOfTensor(tensor))->parallel_conf()); } Maybe GetGradAccStep() { const auto& infer_ctx = JUST(GetCurInferCtx()); const auto& job_conf = infer_ctx->job().job_conf(); if (job_conf.has_train_conf() && job_conf.has_num_gradient_accumulation_steps() && job_conf.num_gradient_accumulation_steps() > 1) { return job_conf.num_gradient_accumulation_steps(); } else { return 1; } } Maybe AddFreeEagerTensorToVariableOp(const std::shared_ptr& input_tensor) { if (!input_tensor->is_contiguous()) { LazyMode::Guard lazy_mode_disabled_guard(false); JUST(functional::InplaceToContiguous(input_tensor)); JUST(vm::CurrentRankSync()); } CHECK_OR_RETURN(input_tensor->is_eager()); // NOLINT(maybe-need-error-msg) const std::string& empty_lbn = TensorNameScope::Global()->Lookup(input_tensor); CHECK_OR_RETURN(empty_lbn.empty()); // NOLINT(maybe-need-error-msg) std::shared_ptr scope = JUST(NewScopeWithParallelDescByTensor(input_tensor)); OperatorConf op_conf; op_conf.set_scope_symbol_id(JUST(scope->symbol_id())); op_conf.set_device_tag(JUST(GetDeviceTagOfTensor(input_tensor))); VariableOpConf* var_conf = op_conf.mutable_variable_conf(); var_conf->set_out("out"); input_tensor->shape()->ToProto(var_conf->mutable_shape()); var_conf->set_data_type(input_tensor->dtype()->data_type()); // NOTE(chengcheng): VariableOpConf initializer_conf is useless because variable is inited // by EagerTensor. var_conf->mutable_initializer()->mutable_empty_conf(); JUST(GenVariableOpConfNdSbpStringByTensor(var_conf, input_tensor)); // NOTE(chengcheng): Free EagerTensor not trainable var_conf->set_trainable(false); auto infer_ctx = JUST(GetCurInferCtx()); // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp, FreeEagerTensor has no // name so just new a unique name for it. 
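// NOTE (descriptive summary added for clarity, not in the original file): the flow below is
// (1) ask the infer context for a fresh unique op name, (2) add the variable op to the lazy graph
// and infer its OpAttribute, (3) register the eager tensor with MultiClientSessionContext so the
// graph runtime can bind it, and (4) record the resulting lbn (or the repeat op's lbn under
// gradient accumulation) in TensorNameScope so later ops consume this tensor by name.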
const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(op_conf)); op_conf.set_name(new_op_name); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n" << op_conf.DebugString() << std::endl; OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(op_conf)); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n" << op_conf.name() << " for FreeEagerTensor.\n"; VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " infer and and op attr : \n" << op_attr.DebugString() << " for FreeEagerTensor.\n"; // NOTE(chengcheng): MUST store this tensor to MultiClientSessionContext for graph runtime bind. const std::string graph_name = *JUST(JUST(GlobalJobBuildAndInferCtxMgr())->GetCurrentJobName()); const std::string lbn = GenLogicalBlobName(new_op_name, "out"); Singleton::Get()->StoreFreeEagerTensorWithNameByGraphName( graph_name, input_tensor, new_op_name); int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); auto var_tensor = JUST(BuildTensor(op_attr, "out", blob_parallel_desc, /* is_lazy= */ true, /* is_local= */ input_tensor->is_local())); TensorNameScope::Global()->Record(var_tensor, lbn); // NOTE(chengcheng): MUST record this eager_tensor name as new variable output lbn. // NOTE(chengcheng): in GradAcc FreeEagerTensor need insert repeat op, but there is no need to // create a new tensor for repeat op out. We just set repeat lbn as this free eager tensor's lbn. auto repeat_tensor = JUST(GradAccTryInsertRepeatAfterVar(var_tensor)); const std::string& repeat_tensor_name = TensorNameScope::Global()->Lookup(repeat_tensor); CHECK_OR_RETURN(!repeat_tensor_name.empty()); // NOLINT(maybe-need-error-msg) TensorNameScope::Global()->Record(input_tensor, repeat_tensor_name); return Maybe::Ok(); } } // namespace Maybe GradAccTryInsertUnpackAfterInput(const std::shared_ptr& input) { int32_t grad_acc_step = JUST(GetGradAccStep()); if (grad_acc_step > 1) { // NOTE(chengcheng): // We assume that the input data is one mini-batch which containing multi micro-batches. // So we need unpack input data for each micro-batch. VLOG(2) << " Current OneFlow nn.Graph grad acc semantics is different from Torch. \n" << " Once call nn.Graph in OneFlow, it indicates a mini-batch. When grad acc steps > 1, \n" << " the input tensor of nn.Graph will be unpacked by 0th dim into multiple micro-batches " << " and exec them in order.\n"; const auto& infer_ctx = JUST(GetCurInferCtx()); const auto& input_lbn = TensorNameScope::Global()->Lookup(input); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add grad acc unpack op after input " << input_lbn << std::endl; return functional::GradAccUnpack(input, grad_acc_step); } else { return input; } } Maybe GradAccTryInsertRepeatAfterVar(const std::shared_ptr& variable) { int32_t grad_acc_step = JUST(GetGradAccStep()); if (grad_acc_step > 1) { // NOTE(chengcheng): // We assume that the nn.Graph once call is one mini-batch which containing multi // micro-batches. So we just repeat variable tensor for each micro-batch. VLOG(2) << " Current OneFlow nn.Graph grad acc semantics is different from Torch. \n" << " Once call nn.Graph in OneFlow, it indicates a mini-batch. When grad acc steps > 1, \n" << " the var tensor of nn.Graph will be repeated exec for multiple micro-batches. 
\n"; const auto& infer_ctx = JUST(GetCurInferCtx()); const auto& variable_lbn = TensorNameScope::Global()->Lookup(variable); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add grad acc repeat op after variable " << variable_lbn << std::endl; return functional::GradAccRepeat(variable, grad_acc_step); } else { return variable; } } Maybe GradAccTryInsertPackBeforeOutput(const std::shared_ptr& output) { int32_t grad_acc_step = JUST(GetGradAccStep()); if (grad_acc_step > 1) { // NOTE(chengcheng): // We assume that the nn.Graph once call is one mini-batch which containing multi // micro-batches. So we need pack output tensor for each micro-batch to one micro-batch. VLOG(2) << " Current OneFlow nn.Graph grad acc semantics is different from Torch. \n" << " Once call nn.Graph in OneFlow, it indicates a mini-batch. When grad acc steps > 1, \n" << " the output tensor of nn.Graph will be packed to a big tensor by 0th dim, after exec \n" << " for multiple micro-batches. \n"; const auto& infer_ctx = JUST(GetCurInferCtx()); const auto& output_lbn = TensorNameScope::Global()->Lookup(output); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add grad acc pack op before output " << output_lbn << std::endl; return functional::GradAccPack(output, grad_acc_step); } else { return output; } } Maybe GradAccTryInsertRepeatTickBeforeSource( const std::shared_ptr& source_op_conf, bool is_local) { int32_t grad_acc_step = JUST(GetGradAccStep()); if (grad_acc_step > 1) { // NOTE(chengcheng): // We assume that the nn.Graph once call is one mini-batch which containing multi // micro-batches. So we need repeat source op for each micro-batch in one micro-batch. VLOG(2) << " Current OneFlow nn.Graph grad acc semantics is different from Torch. \n" << " Once call nn.Graph in OneFlow, it indicates a mini-batch. 
When grad acc steps > 1, \n" << " the source op of nn.Graph will be repeated exec n-times for multiple micro-batches.\n"; const auto& infer_ctx = JUST(GetCurInferCtx()); // Insert Tick OperatorConf tick_conf{}; tick_conf.set_name("Sys-GradAcc-RepeatTick-DeviceTick-" + source_op_conf->name()); tick_conf.set_device_tag(source_op_conf->device_tag()); tick_conf.mutable_device_tick_conf()->set_out("out"); tick_conf.set_scope_symbol_id(source_op_conf->scope_symbol_id()); auto tick_lbn = GenLogicalBlobName(tick_conf.name(), tick_conf.device_tick_conf().out()); OpAttribute tick_op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(tick_conf)); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op: \n" << tick_conf.DebugString() << std::endl; VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " infer and and op attr : \n" << tick_op_attr.DebugString() << std::endl; const auto& scope = Singleton>::Get()->Get(source_op_conf->scope_symbol_id()); int64_t parallel_desc_sym_id = JUST(scope.GetParallelDescSymbolId(tick_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); auto tick_tensor = JUST(BuildTensor(tick_op_attr, tick_conf.device_tick_conf().out(), blob_parallel_desc, /* is_lazy= */ true, /* is_local= */ is_local)); TensorNameScope::Global()->Record(tick_tensor, tick_lbn); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add grad acc repeat op after tick op " << tick_conf.name() << " and before source op" << source_op_conf->name(); auto repeat_tensor = JUST(functional::GradAccRepeat(tick_tensor, grad_acc_step)); const std::string& repeat_tensor_name = TensorNameScope::Global()->Lookup(repeat_tensor); CHECK_OR_RETURN(!repeat_tensor_name.empty()); // NOLINT(maybe-need-error-msg) (*source_op_conf->mutable_user_conf()->mutable_input())[user_op::kUserSourceOpTickInputArgName] .add_s(repeat_tensor_name); } return Maybe::Ok(); } Maybe LazyInterpreter::ApplyImpl(const FeedInputOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { // NOTE(chengcheng): inputs[0] is the EagerTensor CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.input_size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& input_tensor = inputs.at(0); CHECK_OR_RETURN(input_tensor->is_eager()); // NOLINT(maybe-need-error-msg) std::shared_ptr scope = JUST(NewScopeWithParallelDescByTensor(input_tensor)); OperatorConf op_conf; op_conf.set_name(op_expr.op_name()); // construct by python nn.Graph op_conf.set_scope_symbol_id(JUST(scope->symbol_id())); op_conf.set_device_tag(JUST(GetDeviceTagOfTensor(input_tensor))); // NOTE(chengcheng): // We contruct InputOpConf instead of FeedInputOpConf because FeedInputOpExpr JUST for getting // input EagerTensor. InputOpConf* input_conf = op_conf.mutable_input_conf(); input_conf->set_out("out"); InterfaceBlobConf* blob_conf = input_conf->mutable_blob_conf(); input_tensor->shape()->ToProto(blob_conf->mutable_shape()); blob_conf->set_data_type(input_tensor->dtype()->data_type()); // NOTE(chengcheng): is_dynamic true has conflict in global lazy job even if world size 1. // this flag will be removed in the future. 
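  // Shape note (illustrative): with gradient accumulation enabled the fed mini-batch is later
  // unpacked along dim 0 by DispatchFeedInputOpExprFunctor, e.g. acc steps = 4 and an input of
  // shape [8, C] become 4 micro-batches of shape [2, C]; here the blob is always marked
  // non-dynamic regardless of the eager tensor.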
// blob_conf->set_is_dynamic(GetIsDynamicOfTensor(input_tensor)); blob_conf->set_is_dynamic(false); JUST(GenNdSbpByTensor(blob_conf->mutable_nd_sbp(), input_tensor)); auto infer_ctx = JUST(GetCurInferCtx()); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n: " << op_conf.DebugString() << std::endl; OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(op_conf)); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n" << op_conf.name() << std::endl; VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " infer and and op attr : \n" << op_attr.DebugString() << std::endl; int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); // Check outputs num and setup output tensor properties. CHECK_EQ_OR_RETURN(outputs->size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.output_size(), 1); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(!(*outputs)[0]); // NOLINT(maybe-need-error-msg) const std::string obn = "out"; // NOTE(chengcheng): obn is NOT op_expr.indexed_obns auto origin_input = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, /* is_local= */ input_tensor->is_local())); TensorNameScope::Global()->Record(origin_input, GenLogicalBlobName(op_conf.name(), obn)); TensorNameScope::Global()->Record(input_tensor, GenLogicalBlobName(op_conf.name(), obn)); // NOTE: The input will then be unpacked in DispatchFeedInputOpExprFunctor // if GradAcc is enabled (*outputs)[0] = origin_input; return Maybe::Ok(); } Maybe LazyInterpreter::ApplyImpl(const FeedVariableOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { // NOTE(chengcheng): inputs[0] is the EagerTensor CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.input_size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& input_tensor = inputs.at(0); CHECK_OR_RETURN(input_tensor->is_eager()); // NOLINT(maybe-need-error-msg) std::shared_ptr scope = JUST(NewScopeWithParallelDescByTensor(input_tensor)); OperatorConf op_conf; op_conf.set_name(op_expr.op_name()); // construct by python nn.Graph op_conf.set_scope_symbol_id(JUST(scope->symbol_id())); op_conf.set_device_tag(JUST(GetDeviceTagOfTensor(input_tensor))); // NOTE(chengcheng): // We contruct VariableOpConf instead of FeedVariableOpConf because FeedVariableOpExpr JUST // for getting input EagerTensor. VariableOpConf* var_conf = op_conf.mutable_variable_conf(); var_conf->set_out("out"); input_tensor->shape()->ToProto(var_conf->mutable_shape()); var_conf->set_data_type(input_tensor->dtype()->data_type()); // NOTE(chengcheng): VariableOpConf initializer_conf is useless because variable is inited // by EagerTensor. 
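  // The empty initializer set below acts as a placeholder; the variable's actual values come
  // from the bound EagerTensor at runtime (see the NOTE above).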
var_conf->mutable_initializer()->mutable_empty_conf(); JUST(GenVariableOpConfNdSbpStringByTensor(var_conf, input_tensor)); if (!input_tensor->requires_grad()) { var_conf->set_trainable(false); } if (input_tensor->requires_grad()) { double l2 = JUST(ctx.attrs.GetAttr("l2")); if (unlikely(l2 != 0.0)) { var_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l2(l2); } } auto infer_ctx = JUST(GetCurInferCtx()); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n: " << op_conf.DebugString() << std::endl; OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(op_conf)); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n" << op_conf.name() << std::endl; VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " infer and and op attr : \n" << op_attr.DebugString() << std::endl; int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); // Check outputs num and setup output tensor properties. CHECK_EQ_OR_RETURN(outputs->size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.output_size(), 1); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(!(*outputs)[0]); // NOLINT(maybe-need-error-msg) const std::string obn = "out"; // NOTE(chengcheng): obn is NOT op_expr.indexed_obns auto origin_var = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, /* is_local */ input_tensor->is_local())); // NOTE(chengcheng): Record variable op output LazyTenosr TensorNameScope::Global()->Record(origin_var, GenLogicalBlobName(op_conf.name(), obn)); // NOTE(chengcheng): Record EagerTensor as variable tensor name TensorNameScope::Global()->Record(input_tensor, GenLogicalBlobName(op_conf.name(), obn)); // NOTE: The output variable will then be repeat in DispatchFeedVariableOpExprFunctor // if GradAcc is enabled (*outputs)[0] = origin_var; return Maybe::Ok(); } Maybe LazyInterpreter::ApplyImpl(const FetchOutputOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { // NOTE: The input has been packed in DispatchFetchOutputOpExprFunctor // if GradAcc is enabled // NOTE(chengcheng): inputs[0] is the LazyTensor CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.input_size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& input_tensor = inputs.at(0); std::string input_lbn = TensorNameScope::Global()->Lookup(input_tensor); // Lazy tensor must has lbn. // Eager tensor may has lbn if it has already been treated as an output of a variable op // or an output of an inplace op. if (input_lbn.empty()) { CHECK_OR_RETURN(input_tensor->is_eager()); // NOLINT(maybe-need-error-msg) // This output tensor is a new free eager tensor, so treat it as a new variable op output. JUST(AddFreeEagerTensorToVariableOp(input_tensor)); input_lbn = TensorNameScope::Global()->Lookup(input_tensor); CHECK_OR_RETURN(!input_lbn.empty()); // NOLINT(maybe-need-error-msg) } std::shared_ptr scope = JUST(NewScopeWithParallelDescByTensor(input_tensor)); OperatorConf op_conf; op_conf.set_name(op_expr.op_name()); // construct by python nn.Graph op_conf.set_scope_symbol_id(JUST(scope->symbol_id())); op_conf.set_device_tag(JUST(GetDeviceTagOfTensor(input_tensor))); // NOTE(chengcheng): // We contruct OutputOpConf instead of FetchOutputOpConf because FetchOutputOpExpr JUST // for get nn.Graph output LazyTensor. 
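  // The output op added below consumes input_lbn; unlike the other lazy ops in this file, its
  // result tensor is built with is_lazy = false, so the caller of nn.Graph receives an eager
  // tensor holding the fetched output.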
OutputOpConf* output_conf = op_conf.mutable_output_conf(); output_conf->set_in(input_lbn); output_conf->set_out("out"); InterfaceBlobConf* blob_conf = output_conf->mutable_blob_conf(); input_tensor->shape()->ToProto(blob_conf->mutable_shape()); blob_conf->set_data_type(input_tensor->dtype()->data_type()); // NOTE(chengcheng): is_dynamic true has conflict in global lazy job even if world size 1. // this flag will be removed in the future. // blob_conf->set_is_dynamic(GetIsDynamicOfTensor(input_tensor)); blob_conf->set_is_dynamic(false); JUST(GenNdSbpByTensor(blob_conf->mutable_nd_sbp(), input_tensor)); auto infer_ctx = JUST(GetCurInferCtx()); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n" << op_conf.DebugString() << std::endl; OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(op_conf)); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n" << op_conf.name() << std::endl; VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " infer and and op attr : \n" << op_attr.DebugString() << std::endl; int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); // Check outputs num and setup output tensor properties. CHECK_EQ_OR_RETURN(outputs->size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.output_size(), 1); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(!(*outputs)[0]); // NOLINT(maybe-need-error-msg) const std::string obn = "out"; // NOTE(chengcheng): obn is NOT op_expr.indexed_obns (*outputs)[0] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ false, /* is_local= */ input_tensor->is_local())); return Maybe::Ok(); } Maybe LazyInterpreter::ApplyImpl(const ImageDecoderRandomCropResizeOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.input_size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& input_tensor = inputs.at(0); const std::string& input_lbn = TensorNameScope::Global()->Lookup(input_tensor); CHECK_OR_RETURN(!input_lbn.empty()); // NOLINT(maybe-need-error-msg) auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, ctx.attrs)); std::string device_tag; if (IsCpuOnly(*op_conf)) { device_tag = "cpu"; } else { device_tag = "cuda"; } ParallelConf parallel_conf = JUST(GetParallelDescOfTensor(input_tensor))->parallel_conf(); parallel_conf.set_device_tag(device_tag); // NOTE(chengcheng): only support gpu decode. 
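  // The device_tag chosen above ("cpu" only for cpu-only op confs, otherwise "cuda") is also
  // written into the copied parallel_conf, so the decode op gets a GPU placement whenever GPU
  // decoding is used; the new scope below is built from that parallel_conf.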
const auto& scope = JUST(NewScopeWithParallelConfAndCurScope(parallel_conf)); op_conf->set_scope_symbol_id(JUST(scope->symbol_id())); op_conf->set_device_tag(device_tag); // NOTE(chengcheng): replace right input_lbn and obn ReplaceInputLbnInOpCustomizedConf(op_conf.get(), /* ibn */ "in", input_lbn); op_conf->mutable_image_decoder_random_crop_resize_conf()->set_out("out"); auto infer_ctx = JUST(GetCurInferCtx()); // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(*op_conf)); op_conf->set_name(new_op_name); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n" << op_conf->DebugString() << std::endl; OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(*op_conf)); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n" << op_conf->name() << std::endl; VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " infer and and op attr : \n" << op_attr.DebugString() << std::endl; int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(*op_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); // Check outputs num and setup output tensor properties. CHECK_EQ_OR_RETURN(outputs->size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.output_size(), 1); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(!(*outputs)[0]); // NOLINT(maybe-need-error-msg) const std::string obn = "out"; // NOTE(chengcheng): obn is NOT op_expr.indexed_obns (*outputs)[0] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, /* is_local= */ input_tensor->is_local())); TensorNameScope::Global()->Record((*outputs)[0], GenLogicalBlobName(new_op_name, obn)); return Maybe::Ok(); } namespace { Maybe LazyInterpreterApplyImplForSourceUserOpExpr(const UserOpExpr& op_expr, TensorTuple* outputs, const OpExprInterpContext& ctx) { NonRecursiveMetaInfoConsistencyCheckScope non_scope; bool is_local; std::shared_ptr parallel_desc; if (ctx.parallel_desc.has_value()) { // NOTE(chengcheng): global CHECK_OR_RETURN(!ctx.device.has_value()); // NOLINT(maybe-need-error-msg) const auto& parallel_desc_sym = JUST(ctx.parallel_desc); parallel_desc = parallel_desc_sym.shared_from_symbol(); JUST(MetaInfoConsistencyCheck(parallel_desc_sym, ctx.nd_sbp, 1, /* force_check */ false)); is_local = false; } else { // NOTE(chengcheng): local CHECK_OR_RETURN(!ctx.nd_sbp.has_value()); // NOLINT(maybe-need-error-msg) if (ctx.device.has_value()) { const auto& device = JUST(ctx.device); const auto& placement = JUST(Placement4Device(device)); parallel_desc = placement.shared_from_symbol(); } else { // NOTE(chengcheng): if functor NOT set device, using cpu device default. 
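      // Placement selection for source (input-free) user ops, as implemented above/below:
      // a global ctx supplies parallel_desc (plus the nd_sbp consistency check), a local ctx
      // uses ctx.device when given, and otherwise the op falls back to a default CPU device.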
const auto& device = JUST(Device::New("cpu")); const auto& placement = JUST(Placement4Device(device)); parallel_desc = placement.shared_from_symbol(); } is_local = true; } const auto& parallel_conf = parallel_desc->parallel_conf(); const auto& scope = JUST(NewScopeWithParallelConfAndCurScope(parallel_conf)); auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, ctx.attrs)); op_conf->set_scope_symbol_id(JUST(scope->symbol_id())); op_conf->set_device_tag(parallel_conf.device_tag()); auto infer_ctx = JUST(GetCurInferCtx()); // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(*op_conf)); const std::string graph_name = infer_ctx->job().job_conf().job_name(); // NOTE(chengcheng): for UserOp, NOT only reset op_name, but also the output values. op_conf->set_name(new_op_name); for (auto& pair : *(op_conf->mutable_user_conf()->mutable_output())) { auto& list_s = pair.second; for (int i = 0; i < list_s.s_size(); ++i) { std::string old_lbn = list_s.s(i); LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn); // NOTE(chengcheng): MUST change the old_lbn to new op name. std::string new_lbn = GenLogicalBlobName(new_op_name, old_lbi.blob_name()); list_s.set_s(i, new_lbn); } } JUST(GradAccTryInsertRepeatTickBeforeSource(op_conf, is_local)); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n" << op_conf->DebugString() << std::endl; OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(*op_conf)); VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n" << op_conf->name() << std::endl; VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " infer and and op attr : \n" << op_attr.DebugString() << std::endl; int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(*op_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); // Check outputs num and setup output tensor properties. 
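  // Each output slot is either filled with a fresh lazy tensor via BuildTensor or, when it was
  // pre-allocated for inplace execution, validated against the inferred OpAttribute by
  // CheckTensorMatchAttr; every output lbn is then recorded in TensorNameScope.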
CHECK_EQ_OR_RETURN(outputs->size(), op_expr.output_size()); // NOLINT(maybe-need-error-msg) for (int i = 0; i < op_expr.output_size(); ++i) { const std::string& obn = op_expr.indexed_obns().at(i); if (!(*outputs)[i]) { (*outputs)[i] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, is_local)); } else { VLOG(2) << "Lazy nn.Graph name " << graph_name << " source op name " << new_op_name << " run with inplace."; const std::shared_ptr& inplace_out = (*outputs)[i]; JUST(CheckTensorMatchAttr(inplace_out, op_attr, obn, blob_parallel_desc, is_local)); } TensorNameScope::Global()->Record((*outputs)[i], GenLogicalBlobName(new_op_name, obn)); } return Maybe::Ok(); } Maybe LazyInterpreterApplyImplForCopyUserOpExpr(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { CHECK_OR_RETURN(op_expr.op_type_name() == "copy"); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.input_size(), 1); // NOLINT(maybe-need-error-msg) const std::shared_ptr& input_tensor = inputs.at(0); std::string input_lbn = TensorNameScope::Global()->Lookup(input_tensor); if (input_lbn.empty()) { JUST(AddFreeEagerTensorToVariableOp(input_tensor)); input_lbn = TensorNameScope::Global()->Lookup(input_tensor); } CHECK_OR_RETURN(!input_lbn.empty()); // NOLINT(maybe-need-error-msg) auto device = JUST(ctx.attrs.GetAttr>("device")); CHECK_EQ_OR_RETURN(outputs->size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.output_size(), 1); // NOLINT(maybe-need-error-msg) if (input_tensor->is_local()) { (*outputs)[0] = JUST(LocalTensor::MakeTensor( input_tensor->shape(), JUST(input_tensor->stride()), input_tensor->dtype()->data_type(), input_tensor->memory_format(), device, /* is_lazy= */ true, /*requires_grad=*/false, /*is_leaf=*/true)); } else { ParallelConf parallel_conf = JUST(input_tensor->parallel_desc())->parallel_conf(); parallel_conf.set_device_tag(device->type()); ParallelDesc parallel_desc(parallel_conf); (*outputs)[0] = JUST(GlobalTensor::MakeTensor( input_tensor->shape(), input_tensor->dtype()->data_type(), input_tensor->memory_format(), JUST(input_tensor->nd_sbp()), SymbolOf(parallel_desc), /* is_lazy= */ true, /*requires_grad=*/false, /*is_leaf=*/true)); } // NOTE(chengcheng): output tensor lbn is SAME with input tensor. TensorNameScope::Global()->Record(outputs->at(0), input_lbn); return Maybe::Ok(); } } // namespace Maybe LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { CHECK_EQ_OR_RETURN(inputs.size(), op_expr.input_size()); // NOLINT(maybe-need-error-msg) // NOTE(chengcheng): Handle special UserOp such as: // 1. [Source UserOp] : OFRecordReader, CoinFlip // 2. [Change Placement/ParallelDesc UserOp] : to(copy)/to_global/parallel_cast // 3. [Multi-Inputs & Different ParallelDesc for each input UserOp] : like there are 2 inputs, // one from CPU and the other from GPU. // ..., etc. // // Need add if for each special UserOp for infer: // 1. op_conf: device_tag, // 2. output tensor: is_local, // 3. op_parallel_conf for build new scope with parallel_desc // 4. output blob (different with tensor) -> parallel_conf // 5. 
need add to JobBuildAndInferCtx (like copy will NOT need) if (inputs.size() == 0) { // NOTE(chengcheng): handle for source UserOp like OFRecordReader, CoinFlip return LazyInterpreterApplyImplForSourceUserOpExpr(op_expr, outputs, ctx); } if (op_expr.op_type_name() == "copy") { // NOTE(chengcheng): handle for copy UserOp which will NOT add op to job. return LazyInterpreterApplyImplForCopyUserOpExpr(op_expr, inputs, outputs, ctx); } // NOTE(chengcheng): // Normal UserOp inputs size >= 1 for infer parallel_desc. CHECK_GE_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, ctx.attrs)); std::shared_ptr scope = JUST(NewScopeWithParallelDescByTensor(JUST(VectorAt(inputs, 0)))); op_conf->set_scope_symbol_id(JUST(scope->symbol_id())); const std::string device_tag = JUST(GetDeviceTagOfTensor(JUST(VectorAt(inputs, 0)))); const bool is_local = inputs.at(0)->is_local(); const std::shared_ptr parallel_desc = JUST(GetParallelDescOfTensor(inputs.at(0))); op_conf->set_device_tag(device_tag); auto infer_ctx = JUST(GetCurInferCtx()); // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(*op_conf)); const std::string graph_name = infer_ctx->job().job_conf().job_name(); for (int i = 0; i < inputs.size(); ++i) { const auto& input_tensor = inputs.at(i); CHECK_EQ_OR_RETURN(is_local, input_tensor->is_local()); // NOLINT(maybe-need-error-msg) if (!op_expr.IsHostMemoryInput(i)) { if (is_local) { CHECK_OR_RETURN(device_tag == JUST(GetDeviceTagOfTensor(input_tensor))) << Error::RuntimeError() << "Lazy nn.Graph name: " << graph_name << " encountered ERROR in module/op_name: " << new_op_name << ". Expected all tensors to be on the same device, but found at least two devices, " << JUST(JUST(VectorAt(inputs, 0))->device())->ToString() << " (positional 0) and " << JUST(JUST(VectorAt(inputs, i))->device())->ToString() << " (positional " << i << ")! Please use tensor.to() to synchronize all the input with the same device."; } else { // TODO: Print out all the placement CHECK_OR_RETURN(parallel_desc->Equals(*JUST(GetParallelDescOfTensor(input_tensor)))) << Error::RuntimeError() << "Lazy nn.Graph name: " << graph_name << " encountered ERROR in module/op_name: " << new_op_name << ". Expected all tensors to be on the same placement, but found at least two " "placements, " << *JUST(PlacementToString(JUST(JUST(VectorAt(inputs, 0))->parallel_desc()))) << " (positional 0) and " << *JUST(PlacementToString(JUST(JUST(VectorAt(inputs, i))->parallel_desc()))) << " (positional " << i << ")! Please use tensor.to_global() to synchronize all the input with the same " "placement."; } } const std::string& ibn = op_expr.indexed_ibns().at(i); std::string lbn = TensorNameScope::Global()->Lookup(input_tensor); if (lbn.empty()) { JUST(AddFreeEagerTensorToVariableOp(input_tensor)); lbn = TensorNameScope::Global()->Lookup(input_tensor); } CHECK_OR_RETURN(!lbn.empty()); // NOLINT(maybe-need-error-msg) ReplaceInputLbnInOpCustomizedConf(op_conf.get(), ibn, lbn); } // NOTE(chengcheng): for UserOp, NOT only reset op_name, but also the output values. op_conf->set_name(new_op_name); for (auto& pair : *(op_conf->mutable_user_conf()->mutable_output())) { auto& list_s = pair.second; for (int i = 0; i < list_s.s_size(); ++i) { std::string old_lbn = list_s.s(i); LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn); // NOTE(chengcheng): MUST change the old_lbn to new op name. 
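      // e.g. (illustrative): an inferred output value "some_old_name/out_0" is rewritten to
      // "<new_op_name>/out_0", keeping the blob name but switching to the unique op name.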
std::string new_lbn = GenLogicalBlobName(new_op_name, old_lbi.blob_name()); list_s.set_s(i, new_lbn); } } // Check outputs num and setup output tensor properties. CHECK_EQ_OR_RETURN(outputs->size(), op_expr.output_size()); // NOLINT(maybe-need-error-msg) // Disable boxing if the computation is inplace. for (int i = 0; i < op_expr.output_size(); ++i) { const auto& output = outputs->at(i); if (output) { const std::string& lbn = TensorNameScope::Global()->Lookup(output); CHECK_OR_RETURN(!lbn.empty()) << "The output which index is " << i << " has no tensor name, please check whether the inplaced " "output is also an input of the operation " << new_op_name; JUST(infer_ctx->DisableBoxing(lbn)); } } VLOG(2) << "Lazy nn.Graph name " << graph_name << " try to add op: \n" << op_conf->DebugString() << std::endl; OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(*op_conf)); VLOG(2) << "Lazy nn.Graph name " << graph_name << " add op : \n" << op_conf->name() << std::endl; VLOG(3) << "Lazy nn.Graph name " << graph_name << " infer and and op attr : \n" << op_attr.DebugString() << std::endl; int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(*op_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); for (int i = 0; i < op_expr.output_size(); ++i) { const std::string& obn = op_expr.indexed_obns().at(i); if (!(*outputs)[i]) { (*outputs)[i] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, is_local)); } else { VLOG(2) << "Lazy nn.Graph name " << graph_name << " op name " << new_op_name << " run with inplace."; const std::shared_ptr& inplace_out = (*outputs)[i]; JUST(CheckTensorMatchAttr(inplace_out, op_attr, obn, blob_parallel_desc, is_local)); } TensorNameScope::Global()->Record((*outputs)[i], GenLogicalBlobName(new_op_name, obn)); } return Maybe::Ok(); } Maybe LazyInterpreter::ApplyImpl(const FunctionOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext&) const { // Must reset ctx in each forward op_expr.reset_state(); std::shared_ptr ctx = op_expr.state(); *outputs = *(op_expr.forward()(ctx, inputs)); return Maybe::Ok(); } Maybe LazyInterpreter::ApplyImpl(const GlobalToGlobalOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { CHECK_EQ_OR_RETURN(op_expr.input_size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) const auto& input_tensor = inputs[0]; CHECK_OR_RETURN(input_tensor->is_global()); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(ctx.parallel_desc.has_value()); // NOLINT(maybe-need-error-msg) const auto& parallel_desc_sym = JUST(ctx.parallel_desc); CHECK_OR_RETURN(ctx.nd_sbp.has_value()); // NOLINT(maybe-need-error-msg) const auto& sbp_sym = JUST(ctx.nd_sbp); std::string input_lbn = TensorNameScope::Global()->Lookup(input_tensor); if (input_lbn.empty()) { JUST(AddFreeEagerTensorToVariableOp(input_tensor)); input_lbn = TensorNameScope::Global()->Lookup(input_tensor); CHECK_OR_RETURN(!input_lbn.empty()); // NOLINT(maybe-need-error-msg) } std::shared_ptr input_proxy; if (!JUST(GetParallelDescOfTensor(input_tensor)) ->Equals(*parallel_desc_sym.shared_from_symbol())) { // NOTE(zwx): The input tensor's parallel_desc is not equal to that of op's, // create a proxy input with the parallel_desc that is the same as op's input_proxy = JUST(GlobalTensor::MakeTensor( input_tensor->shape(), input_tensor->dtype()->data_type(), input_tensor->memory_format(), JUST(input_tensor->nd_sbp()), 
parallel_desc_sym, /* is_lazy= */ true, /*requires_grad=*/false, /*is_leaf=*/true)); TensorNameScope::Global()->Record(input_proxy, input_lbn); } CHECK_EQ_OR_RETURN(op_expr.output_size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs->size(), 1); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(!(*outputs)[0]); // NOLINT(maybe-need-error-msg) if (!op_expr.grad_nd_sbp().has_value() && sbp_sym == JUST(input_tensor->nd_sbp())) { // NOTE(chengcheng): if to_global ONLY change placement (nd_sbp and grad_nd_sbp is same), // there is no need to build hierarchical_parallel_cast op. if (input_proxy) { (*outputs)[0] = input_proxy; } else { (*outputs)[0] = input_tensor; } return Maybe::Ok(); } // build parallel cast op expr std::shared_ptr> sbp_list_ptr = JUST(GetNdSbpStrList(sbp_sym)); std::string grad_mode; std::vector grad_sbp_str_list; if (op_expr.grad_nd_sbp().has_value()) { grad_mode = "manual"; grad_sbp_str_list = *JUST(GetNdSbpStrList(JUST(op_expr.grad_nd_sbp()))); } else { grad_mode = "identity"; } std::shared_ptr parallel_cast_op_expr = JUST(OpBuilder("hierarchical_parallel_cast", "trivial_op_name") .Input("in") .Output("out") .Attr>("nd_sbp", *sbp_list_ptr) .Attr("grad_mode", grad_mode) .Attr>("grad_nd_sbp", grad_sbp_str_list) .Build()); if (input_proxy) { (*outputs)[0] = JUST(OpInterpUtil::Dispatch(*parallel_cast_op_expr, {input_proxy})); } else { (*outputs)[0] = JUST(OpInterpUtil::Dispatch(*parallel_cast_op_expr, {input_tensor})); } return Maybe::Ok(); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_interpreter/lazy_op_interpreter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace one { Maybe GradAccTryInsertUnpackAfterInput(const std::shared_ptr& input); Maybe GradAccTryInsertRepeatAfterVar(const std::shared_ptr& variable); Maybe GradAccTryInsertPackBeforeOutput(const std::shared_ptr& output); Maybe GradAccTryInsertRepeatTickBeforeSource( const std::shared_ptr& source_op_conf, bool is_local); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_interpreter/op_interpreter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { Maybe LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { #define APPLY_IF(op_type) \ if (const auto* op = dynamic_cast(&op_expr)) { \ return ApplyImpl(*op, inputs, outputs, ctx); \ } APPLY_IF(FeedInputOp); APPLY_IF(FeedVariableOp); APPLY_IF(FetchOutputOp); APPLY_IF(UserOp); APPLY_IF(GlobalToGlobalOp); APPLY_IF(FunctionOp); APPLY_IF(ImageDecoderRandomCropResizeOp); #undef APPLY_IF OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name() << " has not been supported in LazyInterpreter::Apply."; } Maybe EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { // In the op interpreter, judge whether to open the global mode to avoid recursion caused by // GlobalMode. // The global mode is enabled only if it was enabled and the current operation is a local // operation. auto global_mode_gurad = GlobalMode::Guard(GlobalMode::is_enabled() && is_local_, GlobalMode::nd_sbp(), GlobalMode::parallel_desc()); #define APPLY_IF(op_type) \ if (const auto* op = dynamic_cast(&op_expr)) { \ return ApplyImpl(*op, inputs, outputs, ctx); \ } APPLY_IF(UserOp); APPLY_IF(VariableOp); APPLY_IF(CastToLocalOp); APPLY_IF(CastFromLocalOp); APPLY_IF(GlobalToGlobalOp); APPLY_IF(LocalToGlobalOp); APPLY_IF(GlobalToLocalOp); APPLY_IF(DistributeSplitOp); APPLY_IF(DistributeCloneOp); APPLY_IF(DistributeConcatOp); APPLY_IF(DistributeAddOp); APPLY_IF(FunctionOp); APPLY_IF(SelectTopNOp) #undef APPLY_IF OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name() << " has not been supported in EagerInterpreter::Apply."; } Maybe EagerInterpreter::ApplyImpl(const FunctionOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext&) const { // Must reset ctx in each forward op_expr.reset_state(); std::shared_ptr ctx = op_expr.state(); *outputs = *(op_expr.forward()(ctx, inputs)); return Maybe::Ok(); } Maybe AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { bool requires_grad = false; if (autograd::GradMode::is_enabled() && !JUST(op_expr.IsGradDisabled())) { requires_grad = std::any_of(inputs.begin(), inputs.end(), [](const std::shared_ptr& tensor) { return tensor->requires_grad(); }); } { autograd::AutoGradMode mode(false); JUST(internal_->Apply(op_expr, inputs, outputs, ctx)); } // Lazy mode will construct backward compute graph in passes, so disable autograd if lazy mode. 
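  // When requires_grad is set, the flow below is: build (or reuse) the op's grad closure, wrap
  // it in a backward_fn, register an autograd node via AddNode, capture the inputs/outputs for
  // the backward pass, and finally refresh is_leaf / requires_grad on every output tensor.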
std::shared_ptr grad_closure(nullptr); if (requires_grad) { OF_PROFILER_RANGE_PUSH("autograd.GetOrCreateOpGradClosure"); grad_closure = JUST(op_expr.GetOrCreateOpGradClosure()); auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); JUST(grad_closure->Apply(out_grads, in_grads)); return Maybe::Ok(); }; backward_fn->status = [=]() { return grad_closure->state()->SavedTensors().size() > 0; }; OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("autograd.AddNode"); JUST(GetThreadLocalAutogradEngine()->AddNode(op_expr.op_type_name() + "Backward", backward_fn, inputs, outputs)); OF_PROFILER_RANGE_POP(); } if (requires_grad) { OF_PROFILER_RANGE_GUARD("autograd.Capture"); // Capture inputs and outputs after `AddNode` because of that grad function // node has been attached to them. JUST(grad_closure->Capture(inputs, *outputs, ctx)); } // Update outputs autograd meta // Note: if requires_grad is True, we will create a new autograd meta for each output // in `AddNode` to support inplace operation, so the update should after // `AddNode` for (auto& output : *outputs) { output->set_is_leaf(inputs.size() == 0 || !requires_grad); // If the output `requires_grad` is true, it means that the output is inplaced. // The output `requires_grad` should be determined by this: // - If the inplaced output `requires_grad` is true, then the autograd must be disabled, // so the output `requires_grad` should never be changed. // - If the inplaced output `requires_grad` is false, then the output `requires_grad` // shoule be inferred by autograd mode and inputs. For example, // // >>> import oneflow as flow // >>> x = flow.ones(4, 4, requires_grad=False) // >>> y = flow.ones(4, 4, requires_grad=True) // >>> x += y // >>> x.requires_grad // True // >>> with flow.no_grad(): // >>> x += y // >>> x.requires_grad // False // // - If there is no inplace, the output `requires_grad` should be inferred by autograd // mode and inputs. if (!output->requires_grad()) { JUST(output->set_requires_grad( requires_grad && IsSupportRequireGradDataType(output->dtype()->data_type()))); } } return Maybe::Ok(); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_interpreter/op_interpreter_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/tensor_impl.h" #include "oneflow/core/functional/tensor_processor.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { namespace { std::shared_ptr BuildEagerInterpreter(const bool& is_local) { std::shared_ptr internal; if (is_local) { internal = std::make_shared(); } else { internal = std::make_shared(); } return std::make_shared(internal); } std::shared_ptr BuildLazyInterpreter() { auto internal = std::make_shared(); return std::make_shared(internal); } std::string ErrorString4Inputs(const TensorTuple& inputs, const OpExpr& op_expr) { std::stringstream error_str; error_str << "Got input tensors with inconsistent attributes!\n" << "op_type_name: " << op_expr.op_type_name() << "\n" << "attributes of inputs is:\n"; int32_t idx = 0; for (const auto& tensor : inputs) { if (tensor->is_local()) { error_str << "local"; } else { error_str << "global"; } if (++idx != inputs.size()) { error_str << ", "; } } return error_str.str(); } Maybe GetInterpreter(const TensorTuple& inputs, const OpExprInterpContext& ctx, const OpExpr& op_expr) { static const auto& g_lazy_interpreter = BuildLazyInterpreter(); static const auto& g_eager_global_interpreter = BuildEagerInterpreter(/*is_local=*/false); static const auto& g_eager_local_interpreter = BuildEagerInterpreter(/*is_local=*/true); bool is_local = true; if (inputs.empty()) { if (ctx.parallel_desc.has_value()) { JUST(ctx.nd_sbp); CHECK_OR_RETURN(!ctx.device.has_value()); is_local = false; } else { CHECK_OR_RETURN(!ctx.nd_sbp.has_value()); } } else { if (inputs[0]->is_global()) { if (inputs.size() == 1) { // do nothing } else if (inputs.size() == 2) { CHECK_OR_RETURN(inputs[1]->is_global()) // NOLINT << ErrorString4Inputs(inputs, op_expr); // unroll loop for efficiency } else if (inputs.size() == 3) { CHECK_OR_RETURN(inputs[1]->is_global()) << ErrorString4Inputs(inputs, op_expr); // unroll loop for efficiency CHECK_OR_RETURN(inputs[2]->is_global()) << ErrorString4Inputs(inputs, op_expr); // unroll loop for efficiency } else { for (const auto& tensor : inputs) { CHECK_OR_RETURN(tensor->is_global()) << ErrorString4Inputs(inputs, op_expr); } } is_local = false; } else { if (inputs.size() == 1) { // do nothing } else if (inputs.size() == 2) { CHECK_OR_RETURN(inputs.at(1)->is_local()) << ErrorString4Inputs(inputs, op_expr); // unroll loop for efficiency } else if (inputs.size() == 3) { CHECK_OR_RETURN(inputs.at(1)->is_local()) << ErrorString4Inputs(inputs, op_expr); // unroll loop for efficiency CHECK_OR_RETURN(inputs.at(2)->is_local()) << ErrorString4Inputs(inputs, op_expr); // unroll loop for efficiency } else { for (const auto& tensor : inputs) { CHECK_OR_RETURN(tensor->is_local()) << ErrorString4Inputs(inputs, op_expr); } } } } if (!LazyMode::is_enabled()) { if (is_local) { return g_eager_local_interpreter; } else { return g_eager_global_interpreter; } } else { return g_lazy_interpreter; } } } // namespace template<> /* static */ Maybe OpInterpUtil::Dispatch( const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) { OF_PROFILER_RANGE_GUARD("Dispatch"); auto outputs = 
std::make_shared(op_expr.output_size()); JUST(Dispatch(op_expr, inputs, outputs.get(), ctx)); return outputs; } template<> /* static */ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) { OF_PROFILER_RANGE_GUARD("Dispatch"); return JUST(Dispatch(op_expr, inputs, ctx))->at(0); } /* static */ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { OF_PROFILER_RANGE_GUARD("Dispatch"); functional::TensorProcessorPipe processor(inputs, outputs); if (autocast::is_enabled()) { JUST(processor.Apply( *JUST(op_expr.GetOrCreateAutoCastMeta()))); } JUST(processor.Apply(JUST(op_expr.SupportNonContiguous()))); return JUST(GetInterpreter(processor.inputs(), ctx, op_expr)) ->Apply(op_expr, processor.inputs(), processor.outputs(), ctx); } /* static */ Maybe OpInterpUtil::GenBuiltinOpConf(const BuiltinOpExpr& op_expr, const AttrMap& attrs) { auto op_conf = std::make_shared(); JUST(op_expr.BuildOpConf(op_conf.get(), attrs)); return op_conf; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_interpreter/op_interpreter_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_UTIL_H_ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/framework/session_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" namespace oneflow { namespace one { class OpInterpUtil { public: template static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const AttrMap& attrs) { return Dispatch(op_expr, inputs, OpExprInterpContext(attrs)); } template static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs) { return Dispatch(op_expr, inputs, OpExprInterpContext(AttrMap{})); } template static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx); static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const AttrMap& attrs) { return Dispatch(op_expr, inputs, outputs, OpExprInterpContext(attrs)); } static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) { return Dispatch(op_expr, inputs, outputs, OpExprInterpContext(AttrMap{})); } static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx); static Maybe GenBuiltinOpConf(const BuiltinOpExpr& op_expr, const AttrMap& attrs); }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_UTIL_H_ ================================================ FILE: oneflow/core/framework/op_interpreter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_H_ #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/common/optional.h" namespace oneflow { class Device; class ParallelDesc; class NdSbp; namespace one { struct OpExprInterpContext { OpExprInterpContext(const AttrMap& attrs_arg) : attrs(attrs_arg) {} OpExprInterpContext(const AttrMap& attrs_arg, Symbol device_arg) : attrs(attrs_arg), device(device_arg) {} OpExprInterpContext(const AttrMap& attrs_arg, std::shared_ptr state_arg) : attrs(attrs_arg), state(state_arg) {} OpExprInterpContext(const AttrMap& attrs_arg, Symbol device_arg, std::shared_ptr state_arg) : attrs(attrs_arg), device(device_arg), state(state_arg) {} OpExprInterpContext(const AttrMap& attrs_arg, Symbol parallel_desc_arg) : attrs(attrs_arg), parallel_desc(parallel_desc_arg) {} OpExprInterpContext(const AttrMap& attrs_arg, Symbol parallel_desc_arg, Symbol nd_sbp_arg) : attrs(attrs_arg), parallel_desc(parallel_desc_arg), nd_sbp(nd_sbp_arg) {} OpExprInterpContext(const AttrMap& attrs_arg, Symbol parallel_desc_arg, Symbol nd_sbp_arg, std::shared_ptr state_arg) : attrs(attrs_arg), parallel_desc(parallel_desc_arg), nd_sbp(nd_sbp_arg), state(state_arg) {} AttrMap attrs; Optional> device; // for local op Optional> parallel_desc; // for global op Optional> nd_sbp; // for global op std::shared_ptr state; }; class OpExprInterpreter { public: OpExprInterpreter() = default; virtual ~OpExprInterpreter() = default; Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs, const AttrMap& attrs) const { return Apply(op, inputs, outputs, OpExprInterpContext(attrs)); } Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs) const { return Apply(op, inputs, outputs, AttrMap{}); } virtual Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const = 0; }; #define FOR_EACH_BUILTIN_OPS(_macro) \ _macro(UserOp); \ _macro(SelectTopNOp); \ _macro(VariableOp); \ _macro(CastToLocalOp); \ _macro(CastFromLocalOp); \ _macro(GlobalToGlobalOp); \ _macro(LocalToGlobalOp); \ _macro(GlobalToLocalOp); \ _macro(DistributeSplitOp); \ _macro(DistributeCloneOp); \ _macro(DistributeConcatOp); \ _macro(DistributeAddOp); #define DECLARE_NORMAL_APPLY_FUNC(op_type) \ virtual Maybe ApplyImpl(const op_type##Expr& op_expr, const TensorTuple& inputs, \ TensorTuple* outputs, const OpExprInterpContext& ctx) const #define DECLARE_PURE_VIRTUAL_APPLY_FUNC(op_type) DECLARE_NORMAL_APPLY_FUNC(op_type) = 0; #define DECLARE_OVERRIDE_APPLY_FUNC(op_type) \ Maybe ApplyImpl(const op_type##Expr& op_expr, const TensorTuple& inputs, \ TensorTuple* outputs, const OpExprInterpContext& ctx) const override; class LazyInterpreter : public OpExprInterpreter { public: LazyInterpreter() : OpExprInterpreter() {} virtual ~LazyInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const override; private: DECLARE_NORMAL_APPLY_FUNC(UserOp); DECLARE_NORMAL_APPLY_FUNC(FeedInputOp); DECLARE_NORMAL_APPLY_FUNC(FeedVariableOp); DECLARE_NORMAL_APPLY_FUNC(FetchOutputOp); DECLARE_NORMAL_APPLY_FUNC(FunctionOp); DECLARE_NORMAL_APPLY_FUNC(GlobalToGlobalOp); DECLARE_NORMAL_APPLY_FUNC(ImageDecoderRandomCropResizeOp); }; class 
EagerInterpreter : public OpExprInterpreter { public: EagerInterpreter(bool is_local) : OpExprInterpreter(), is_local_(is_local) {} virtual ~EagerInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const override; protected: // NOTE(lixiang): To ensure the correctness of GlobalMode, check whether it is a local operation // and initialize it as true when using EagerLocalInterpreter. // Used by Maybe EagerInterpreter::Apply. bool is_local_; private: FOR_EACH_BUILTIN_OPS(DECLARE_PURE_VIRTUAL_APPLY_FUNC); DECLARE_NORMAL_APPLY_FUNC(FunctionOp); }; class EagerGlobalInterpreter : public EagerInterpreter { public: EagerGlobalInterpreter() : EagerInterpreter(false) {} virtual ~EagerGlobalInterpreter() = default; private: FOR_EACH_BUILTIN_OPS(DECLARE_OVERRIDE_APPLY_FUNC); }; class EagerLocalInterpreter : public EagerInterpreter { public: EagerLocalInterpreter() : EagerInterpreter(true) {} virtual ~EagerLocalInterpreter() = default; private: FOR_EACH_BUILTIN_OPS(DECLARE_OVERRIDE_APPLY_FUNC); }; #undef DECLARE_OVERRIDE_APPLY_FUNC #undef DECLARE_PURE_VIRTUAL_APPLY_FUNC #undef DECLARE_NORMAL_APPLY_FUNC #undef FOR_EACH_BUILTIN_OPS class AutogradInterpreter { public: AutogradInterpreter() = delete; AutogradInterpreter(const std::shared_ptr& internal) : internal_(internal) {} virtual ~AutogradInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const AttrMap& attrs) const { return Apply(op_expr, inputs, outputs, OpExprInterpContext(attrs)); } Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) const { return Apply(op_expr, inputs, outputs, OpExprInterpContext(AttrMap{})); } Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const; private: std::shared_ptr internal_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_H_ ================================================ FILE: oneflow/core/framework/op_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/attr_value_accessor.h" namespace oneflow { namespace user_op { void OpKernel::InferShape(KernelInferContext* ctx) const { InferContext* op_infer_ctx = ctx->MutOpInferContext(); CHECK_NOTNULL(op_infer_ctx); ctx->GetOpInferFn()(op_infer_ctx); for (const auto& arg_pair : ctx->outputs()) { const Shape& shape = op_infer_ctx->OutputShape(arg_pair.first, arg_pair.second); auto mut_shape_view = ctx->MutShapeView4ArgNameAndIndex(arg_pair.first, arg_pair.second); mut_shape_view.set_shape(shape); } } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ #include #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/util.h" #include "oneflow/core/framework/user_op_tensor.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { class JobDesc; namespace user_op { class KernelInitContext { public: OF_DISALLOW_COPY_AND_MOVE(KernelInitContext); virtual ~KernelInitContext() = default; virtual ep::Stream* stream() = 0; virtual DeviceType device_type() const = 0; virtual const ParallelContext& parallel_ctx() const = 0; virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const ParallelDesc& parallel_desc() const = 0; virtual const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; const std::string& input(const std::string& arg_name, int32_t index) const { return user_op_conf().input(arg_name, index); } const std::string& output(const std::string& arg_name, int32_t index) const { return user_op_conf().output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const { return user_op_conf().has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const { return user_op_conf().has_output(arg_name, index); } int32_t input_size(const std::string& arg_name) const { return user_op_conf().input_size(arg_name); } int32_t output_size(const std::string& arg_name) const { return user_op_conf().output_size(arg_name); } const std::string& op_name() const { return user_op_conf().op_name(); } const std::string& op_type_name() const { return user_op_conf().op_type_name(); } const OperatorConf& op_conf() const { return user_op_conf().op_conf(); } template const T& Attr(const std::string& attr_name) const { return AttrValueCast(*Attr4Name(attr_name)); } template const T& attr(const std::string& attr_name) const; protected: KernelInitContext() = default; virtual const UserOpConfWrapper& user_op_conf() const = 0; virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; }; class KernelCacheContext { public: OF_DISALLOW_COPY_AND_MOVE(KernelCacheContext); virtual ~KernelCacheContext() = default; virtual ep::Stream* stream() = 0; virtual DeviceType device_type() const = 0; virtual const ParallelContext& parallel_ctx() const = 0; virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 
0; virtual const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const ParallelDesc& parallel_desc() const = 0; virtual const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; const std::string& input(const std::string& arg_name, int32_t index) const { return user_op_conf().input(arg_name, index); } const std::string& output(const std::string& arg_name, int32_t index) const { return user_op_conf().output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const { return user_op_conf().has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const { return user_op_conf().has_output(arg_name, index); } int32_t input_size(const std::string& arg_name) const { return user_op_conf().input_size(arg_name); } int32_t output_size(const std::string& arg_name) const { return user_op_conf().output_size(arg_name); } const std::string& op_name() const { return user_op_conf().op_name(); } const std::string& op_type_name() const { return user_op_conf().op_type_name(); } const OperatorConf& op_conf() const { return user_op_conf().op_conf(); } template const T& Attr(const std::string& attr_name) const { return AttrValueCast(*Attr4Name(attr_name)); } template const T& attr(const std::string& attr_name) const; protected: KernelCacheContext() = default; virtual const UserOpConfWrapper& user_op_conf() const = 0; virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; }; class KernelInferContext { public: OF_DISALLOW_COPY_AND_MOVE(KernelInferContext); virtual ~KernelInferContext() = default; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual DeviceType device_type() const = 0; virtual const ParallelContext& parallel_ctx() const = 0; virtual ep::Stream* stream() = 0; virtual Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) = 0; virtual ShapeView ShapeView4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) = 0; virtual MutShapeView MutShapeView4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) = 0; const std::string& input(const std::string& arg_name, int32_t index) const { return user_op_conf().input(arg_name, index); } const std::string& output(const std::string& arg_name, int32_t index) const { return user_op_conf().output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const { return user_op_conf().has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const { return user_op_conf().has_output(arg_name, index); } int32_t input_size(const std::string& arg_name) const { return user_op_conf().input_size(arg_name); } int32_t output_size(const std::string& arg_name) const { return user_op_conf().output_size(arg_name); } const std::string& op_name() const { return user_op_conf().op_name(); } const std::string& op_type_name() const { return user_op_conf().op_type_name(); } template const T& Attr(const std::string& attr_name) const { return AttrValueCast(*Attr4Name(attr_name)); } virtual InferContext* MutOpInferContext() { UNIMPLEMENTED(); return nullptr; } virtual const TensorDescInferFn& GetOpInferFn() const { 
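    // Default behavior: the generic KernelInferContext provides neither an op-level InferContext
    // nor a TensorDescInferFn. Contexts that support dynamic shape inference override
    // MutOpInferContext() and GetOpInferFn(); OpKernel::InferShape() (op_kernel.cpp) relies on both.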
UNIMPLEMENTED(); static TensorDescInferFn empty_fn; return empty_fn; } protected: KernelInferContext() = default; virtual const UserOpConfWrapper& user_op_conf() const = 0; virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; }; class Tensor; class KernelComputeContext { public: OF_DISALLOW_COPY_AND_MOVE(KernelComputeContext); virtual ~KernelComputeContext() = default; virtual Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) = 0; virtual ep::Stream* stream() = 0; virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const = 0; virtual DeviceType device_type() const = 0; virtual const ParallelContext& parallel_ctx() const = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; const std::string& input(const std::string& arg_name, int32_t index) const { return user_op_conf().input(arg_name, index); } const std::string& output(const std::string& arg_name, int32_t index) const { return user_op_conf().output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const { return user_op_conf().has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const { return user_op_conf().has_output(arg_name, index); } int32_t input_size(const std::string& arg_name) const { return user_op_conf().input_size(arg_name); } int32_t output_size(const std::string& arg_name) const { return user_op_conf().output_size(arg_name); } const std::string& op_name() const { return user_op_conf().op_name(); } const std::string& op_type_name() const { return user_op_conf().op_type_name(); } template const T& Attr(const std::string& attr_name) const { return AttrValueCast(*Attr4Name(attr_name)); } protected: KernelComputeContext() = default; virtual const UserOpConfWrapper& user_op_conf() const = 0; virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; }; class OpKernelState { public: virtual ~OpKernelState() = default; protected: OpKernelState() = default; }; class OpKernelCache { public: virtual ~OpKernelCache() = default; static const int32_t kAllMayChanged = 0; static const int32_t kShapeNotChanged = 1 << 0; static const int32_t kAttrNotChanged = 1 << 1; protected: OpKernelCache() = default; }; class OpKernel; template OpKernel* NewOpKernel(Args&&... args); class OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(OpKernel); virtual ~OpKernel() = default; virtual std::shared_ptr CreateOpKernelState(KernelInitContext* ctx) const { return std::shared_ptr(); } virtual std::shared_ptr InitOpKernelCache(KernelCacheContext* ctx) const { return std::shared_ptr(); } virtual void InitOpKernelCacheWithFlags(KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const { *cache_ptr = InitOpKernelCache(ctx); } virtual void Compute(KernelComputeContext* ctx, OpKernelState*, const OpKernelCache*) const { Compute(ctx); } virtual void Compute(KernelComputeContext* ctx) const { LOG(WARNING) << ctx->op_name() << " :UNIMPLEMENTED"; } virtual void InferShape(KernelInferContext* ctx) const; virtual bool AlwaysComputeWhenAllOutputsEmpty() const = 0; virtual bool IsKernelLaunchSynchronized() const { return true; } bool has_state_or_cache() const { return has_state_or_cache_; } protected: OpKernel() : has_state_or_cache_(true) {} private: template friend OpKernel* NewOpKernel(Args&&... args); bool has_state_or_cache_; }; template OpKernel* NewOpKernel(Args&&... 
args) { OpKernel* ptr = new T(std::forward(args)...); ptr->has_state_or_cache_ = !(std::is_same::value && std::is_same::value && std::is_same::value); return ptr; } } // namespace user_op } // namespace oneflow #endif ================================================ FILE: oneflow/core/framework/op_kernel_infer_cache.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_kernel_infer_cache.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace user_op { OpKernelInferCache::OpKernelInferCache(const KernelConf& kernel_conf, const void* scope) { const OperatorConf& op_conf = kernel_conf.op_attribute().op_conf(); std::shared_ptr op = CHECK_JUST(ConstructOp(op_conf)); cache_key_.scope = scope; cache_key_.op_conf_sym = op->GetOpConfWithoutOpNameAndLbn(); cache_key_.ibn_idx2shape_sym.resize(op->input_bns().size()); cache_key_.dtype_signature_sym = SymbolOf(kernel_conf.dtype_signature()); } bool OpKernelInferCache::IsCacheHit() const { size_t hash_value = std::hash()(cache_key_); HashEqTraitPtr ptr_wrapper(&cache_key_, hash_value); return cached_key2value_.find(ptr_wrapper) != cached_key2value_.end(); } OpKernelInferCache::ValueType OpKernelInferCache::GetCacheValue() const { size_t hash_value = std::hash()(cache_key_); HashEqTraitPtr ptr_wrapper(&cache_key_, hash_value); CHECK(cached_key2value_.find(ptr_wrapper) != cached_key2value_.end()); return cached_key2value_.at(ptr_wrapper); } void OpKernelInferCache::UpdateCacheKey(KernelInferContext* ctx) { auto GetSymbolOfShape = [&](const std::string& arg_name, int32_t arg_index) -> Symbol { Shape shape; ctx->ShapeView4ArgNameAndIndex(arg_name, arg_index).ToShape(&shape); return SymbolOf(shape); }; const auto& inputs = ctx->inputs(); FOR_RANGE(int, i, 0, inputs.size()) { const auto& arg_pair = inputs.at(i); cache_key_.ibn_idx2shape_sym.at(i) = GetSymbolOfShape(arg_pair.first, arg_pair.second); } } void OpKernelInferCache::UpdateCacheValue(KernelInferContext* ctx) { // TODO: make max size configurable if (cached_key2value_.size() >= kReleaseInIndependentThreadThreshold) { Reset(); } auto* cache_value = new OpInferCacheValue(); cache_value->obn_idx2shape_sym.resize(ctx->outputs().size()); FOR_RANGE(int, i, 0, ctx->outputs().size()) { const auto& out_arg_pair = ctx->outputs().at(i); const ShapeView& out_shape_view = ctx->ShapeView4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second); Shape out_shape; out_shape_view.ToShape(&out_shape); cache_value->obn_idx2shape_sym.at(i).reset(out_shape); } KeyType* new_key = new KeyType(cache_key_); key_storage_.emplace_back(new_key); size_t hash_value = std::hash()(cache_key_); HashEqTraitPtr ptr_wrapper(new_key, hash_value); CHECK(cached_key2value_.emplace(ptr_wrapper, ValueType(cache_value)).second); } void OpKernelInferCache::Reset() { CHECK_EQ(cached_key2value_.size(), key_storage_.size()); HashMap to_release_key2values; KeyStorage 
to_release_key_storage; std::swap(cached_key2value_, to_release_key2values); std::swap(key_storage_, to_release_key_storage); if (to_release_key2values.size() <= kReleaseInIndependentThreadThreshold) { to_release_key2values.clear(); to_release_key_storage.clear(); } else { std::thread( [](HashMap&& cache, KeyStorage&& key_storage) { cache.clear(); key_storage.clear(); }, std::move(to_release_key2values), std::move(to_release_key_storage)); } } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/op_kernel_infer_cache.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_INFER_CACHE_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_INFER_CACHE_H_ #include "oneflow/core/operator/op_infer_cache.h" #include "oneflow/core/common/hash_eq_trait_ptr.h" #include "oneflow/core/kernel/kernel.pb.h" namespace oneflow { namespace user_op { class KernelInferContext; class OpKernelInferCache final { public: using KeyType = OpInferCacheKey; using ValueType = std::shared_ptr; using HashMap = std::unordered_map, ValueType>; using KeyStorage = std::list>; static constexpr size_t kReleaseInIndependentThreadThreshold = 4096; OpKernelInferCache(const KernelConf& kernel_conf, const void* scope); ~OpKernelInferCache() = default; bool IsCacheHit() const; ValueType GetCacheValue() const; void UpdateCacheKey(KernelInferContext* ctx); void UpdateCacheValue(KernelInferContext* ctx); void Reset(); private: KeyType cache_key_; HashMap cached_key2value_; KeyStorage key_storage_; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_INFER_CACHE_H_ ================================================ FILE: oneflow/core/framework/ordered_string_list.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_ORDERED_STRING_LIST_H_ #define ONEFLOW_CORE_FRAMEWORK_ORDERED_STRING_LIST_H_ #include "llvm/ADT/StringRef.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/small_vector.h" namespace oneflow { template class OrderedStringList { public: OrderedStringList() = default; size_t size() const { return strings_.size(); } void emplace_back(llvm::StringRef s) { strings_.emplace_back(std::make_shared(s.str())); order_.emplace(*strings_.back(), order_.size()); } int order(llvm::StringRef s) const { const auto& it = order_.find(s); if (it == order_.end()) { return -1; } return it->second; } const std::string& operator[](int idx) { return *(strings_[idx]); } private: struct Hash { size_t operator()(llvm::StringRef val) const { return HashCombine(val.size(), val.size() > 0 ? static_cast(val.data()[0] - '0') : 0); } }; HashMap order_; // Use shared_ptr to prevent the appended element from being freed when the // vector increases small_vector, N> strings_; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_ORDERED_STRING_LIST_H_ ================================================ FILE: oneflow/core/framework/parallel_conf_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/str_util.h" #include "oneflow/core/framework/parallel_conf_util.h" #include "oneflow/core/common/shape.pb.h" namespace oneflow { Maybe, std::shared_ptr>> GetDeviceTagAndMachineDeviceIdsAndHierarchy(const ParallelConf& parallel_conf) { std::vector machine_device_ids; machine_device_ids.reserve(parallel_conf.device_name().size()); for (const std::string& device_name : parallel_conf.device_name()) { machine_device_ids.emplace_back(device_name); } std::shared_ptr hierarchy; if (parallel_conf.has_hierarchy()) { hierarchy.reset(new ShapeProto(parallel_conf.hierarchy())); } return std::make_tuple(parallel_conf.device_tag(), machine_device_ids, hierarchy); } Maybe MakeParallelConf(const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy) { std::shared_ptr parallel_conf = std::make_shared(); parallel_conf->set_device_tag(device_tag); for (const std::string& machine_device_id : machine_device_ids) { size_t pos = machine_device_id.find(':'); CHECK_NE_OR_RETURN(pos, std::string::npos) << "device_name: " << machine_device_id; std::string machine_id = machine_device_id.substr(0, pos); CHECK_OR_RETURN( (IsStrInt(machine_id) || (machine_id[0] == '@' && IsStrInt(machine_id.substr(1))))) << " machine_id: " << machine_id; std::string device_id = machine_device_id.substr(pos + 1); size_t minus_pos = device_id.rfind('-'); if (minus_pos == std::string::npos) { CHECK_OR_RETURN(IsStrInt(device_id)); } else { std::string min_id = device_id.substr(0, minus_pos); CHECK_OR_RETURN(IsStrInt(min_id)); std::string max_id = device_id.substr(minus_pos + 1); CHECK_OR_RETURN(IsStrInt(max_id)); } parallel_conf->add_device_name(machine_device_id); if (hierarchy) { ShapeProto proto; hierarchy->ToProto(&proto); parallel_conf->mutable_hierarchy()->CopyFrom(proto); } } return parallel_conf; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/parallel_conf_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_ #include #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/common/shape.h" namespace oneflow { Maybe, std::shared_ptr>> GetDeviceTagAndMachineDeviceIdsAndHierarchy(const ParallelConf& parallel_conf); Maybe MakeParallelConf(const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_ ================================================ FILE: oneflow/core/framework/parallel_conf_util_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include #include "oneflow/core/common/util.h" #include "oneflow/core/framework/parallel_conf_util.h" namespace oneflow { namespace test { TEST(ParallelConfUtil, MakeParallelConfSuccess) { std::string device_tag = "cpu"; std::vector machine_device_ids; machine_device_ids.emplace_back("0:0-3"); machine_device_ids.emplace_back("1:0-3"); auto parallel_conf = CHECK_JUST(MakeParallelConf(device_tag, machine_device_ids, nullptr)); ASSERT_EQ(parallel_conf->device_tag(), "cpu"); ASSERT_EQ(parallel_conf->device_name().size(), 2); ASSERT_EQ(parallel_conf->has_hierarchy(), false); } TEST(ParallelConfUtil, MakeParallelConfError) { std::string device_tag = "cpu"; std::vector machine_device_ids; machine_device_ids.emplace_back("0:0-3"); machine_device_ids.emplace_back("1:0-"); auto parallel_conf = TRY(MakeParallelConf(device_tag, machine_device_ids, nullptr)); ASSERT_EQ(parallel_conf.error()->has_check_failed_error(), true); } TEST(ParallelConfUtil, GetDeviceTagAndMachineDeviceIdsAndHierarchy) { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-1"); parallel_conf.add_device_name("0:2-3"); parallel_conf.add_device_name("1:0-1"); parallel_conf.add_device_name("1:2-3"); parallel_conf.mutable_hierarchy()->add_dim(2); parallel_conf.mutable_hierarchy()->add_dim(4); std::tuple, std::shared_ptr> tag_and_dev_ids_and_hierarchy = *CHECK_JUST(GetDeviceTagAndMachineDeviceIdsAndHierarchy(parallel_conf)); std::string device_tag = std::get<0>(tag_and_dev_ids_and_hierarchy); std::vector machine_device_ids = std::get<1>(tag_and_dev_ids_and_hierarchy); std::shared_ptr hierarchy = std::get<2>(tag_and_dev_ids_and_hierarchy); ASSERT_EQ(device_tag, "cpu"); ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), "0:0-1"), 0); ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), "0:2-3"), 0); ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), "1:0-1"), 0); ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), "1:2-3"), 0); ASSERT_EQ(std::count(machine_device_ids.begin(), machine_device_ids.end(), "2:0-3"), 0); ASSERT_EQ(hierarchy->dim(0), 2); ASSERT_EQ(hierarchy->dim(1), 4); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/framework/placed_nd_sbp.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/decorator.h" namespace oneflow { namespace { Maybe> RawNew(const Symbol& nd_sbp, const Symbol& placement) { CHECK_OR_RETURN(nd_sbp); CHECK_OR_RETURN(placement); CHECK_GT_OR_RETURN(nd_sbp->sbp_parallel_size(), 0); CHECK_EQ_OR_RETURN(nd_sbp->sbp_parallel_size(), placement->hierarchy()->NumAxes()); return SymbolOf(PlacedNdSbp(nd_sbp, placement)); } } // namespace decltype(PlacedNdSbp::New) PlacedNdSbp::New = DECORATE(&RawNew, ThreadLocal); } // namespace oneflow ================================================ FILE: oneflow/core/framework/placed_nd_sbp.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_PLACED_ND_SBP_H_ #define ONEFLOW_CORE_FRAMEWORK_PLACED_ND_SBP_H_ #include #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/util.h" namespace oneflow { class NdSbp; class ParallelDesc; class PlacedNdSbp final { public: PlacedNdSbp(const Symbol& nd_sbp, const Symbol& placement) : nd_sbp_(nd_sbp), placement_(placement) {} ~PlacedNdSbp() = default; static Maybe> (*New)(const Symbol&, const Symbol&); const Symbol& nd_sbp() const { return nd_sbp_; } const Symbol& placement() const { return placement_; } bool operator==(const PlacedNdSbp& other) const { return this->nd_sbp_ == other.nd_sbp_ && this->placement_ == other.placement_; } private: Symbol nd_sbp_; Symbol placement_; }; } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::PlacedNdSbp& placed_nd_sbp) const { return oneflow::Hash(placed_nd_sbp.nd_sbp(), placed_nd_sbp.placement()); } }; } // namespace std #endif // ONEFLOW_CORE_FRAMEWORK_PLACED_ND_SBP_H_ ================================================ FILE: oneflow/core/framework/placement_sbp_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/math_util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace private_details { namespace { using IndexVector = DimVector; Maybe GetIndexesFromOffset(const Stride& strides, int64_t offset, IndexVector* indexes) { indexes->resize(strides.size()); for (int i = 0; i < strides.size(); ++i) { indexes->at(i) = offset / strides.at(i); offset = offset % strides.at(i); } CHECK_EQ_OR_RETURN(offset, 0); return Maybe::Ok(); } Maybe GetOffsetFromIndexes(const Stride& strides, const IndexVector& indexes, int64_t* offset) { CHECK_EQ_OR_RETURN(strides.size(), indexes.size()) << Error::RuntimeError() << "Expected size of strides to match that of indexes"; *offset = 0; for (int i = 0; i < strides.size(); ++i) { *offset += indexes.at(i) * strides.at(i); } return Maybe::Ok(); } Maybe GetSelectedIndex2OriginIndex( const IndexVector& indexes, const std::vector& axis2is_selected, std::function* SelectedIndex2OriginIndex) { CHECK_EQ_OR_RETURN(axis2is_selected.size(), indexes.size()); *SelectedIndex2OriginIndex = [=](const DimVector& broadcast, DimVector* origin) { origin->resize(indexes.size()); for (int i = 0; i < indexes.size(); ++i) { origin->at(i) = axis2is_selected.at(i) ? 
broadcast.at(i) : indexes.at(i); } }; return Maybe::Ok(); } Maybe GetSelectedShape(const Shape& hierarchy_shape, const std::vector& axis2is_selected) { CHECK_EQ_OR_RETURN(hierarchy_shape.NumAxes(), axis2is_selected.size()); DimVector dim_vec = hierarchy_shape.dim_vec(); for (int i = 0; i < axis2is_selected.size(); ++i) { if (!axis2is_selected.at(i)) { dim_vec.at(i) = 1; } } return std::make_shared(dim_vec); } Maybe>> CalcAxis2IsBroadcast(Symbol nd_sbp) { std::vector axis2is_selected(nd_sbp->sbp_parallel_size()); for (int i = 0; i < axis2is_selected.size(); ++i) { axis2is_selected.at(i) = nd_sbp->sbp_parallel(i).has_broadcast_parallel(); } return SymbolOf(axis2is_selected); } static auto* GetAxis2IsBroadcast = DECORATE(&CalcAxis2IsBroadcast, ThreadLocal); Maybe> CalcSelectedSubParallelDesc(Symbol parallel_desc, Symbol> axis2is_selected) { const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc)); int64_t parallel_id = JUST(*opt_parallel_id); const auto& hierarchy_shape = *parallel_desc->hierarchy(); const auto& broadcast_parallel_ids = JUST(GetSelectedParallelIds(hierarchy_shape, *axis2is_selected, parallel_id)); ParallelConf parallel_conf; parallel_conf.set_device_tag(parallel_desc->device_tag()); bool found_parallel_id = false; for (int64_t i : *broadcast_parallel_ids) { found_parallel_id = found_parallel_id || (i == parallel_id); int64_t machine_id = JUST(parallel_desc->MachineId4ParallelId(i)); int64_t device_id = JUST(parallel_desc->DeviceId4ParallelId(i)); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":" + std::to_string(device_id)); } CHECK_OR_RETURN(found_parallel_id); return SymbolOf(ParallelDesc(parallel_conf)); } static auto* GetSelectedSubParallelDesc = DECORATE(&CalcSelectedSubParallelDesc, ThreadLocal); } // namespace Maybe> CalcSubParallelDesc4Axis(Symbol parallel_desc, int axis) { const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc)); int64_t parallel_id = JUST(*opt_parallel_id); const auto& hierarchy_shape = *parallel_desc->hierarchy(); Stride hierarchy_strides(hierarchy_shape); int64_t index = CalcIndex4Axis(parallel_id, hierarchy_strides, axis); int64_t stride = hierarchy_strides.at(axis); int64_t start_parallel_id = parallel_id - index * stride; ParallelConf parallel_conf; parallel_conf.set_device_tag(parallel_desc->device_tag()); for (int64_t i = 0; i < hierarchy_shape.At(axis); ++i) { int64_t id = start_parallel_id + i * stride; int64_t machine_id = JUST(parallel_desc->MachineId4ParallelId(id)); int64_t device_id = JUST(parallel_desc->DeviceId4ParallelId(id)); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":" + std::to_string(device_id)); } return SymbolOf(ParallelDesc(parallel_conf)); } Maybe> GetSelectedParallelIds(const Shape& hierarchy_shape, const std::vector& axis2is_selected, int64_t parallel_id) { CHECK_EQ_OR_RETURN(hierarchy_shape.NumAxes(), axis2is_selected.size()); Stride hierarchy_strides(hierarchy_shape); IndexVector indexes{}; JUST(GetIndexesFromOffset(hierarchy_strides, parallel_id, &indexes)); std::function SelectedIndex2OriginIndex; JUST(GetSelectedIndex2OriginIndex(indexes, axis2is_selected, &SelectedIndex2OriginIndex)); const auto& broadcast_shape = JUST(GetSelectedShape(hierarchy_shape, axis2is_selected)); Stride broadcast_strides(*broadcast_shape); const auto& origin_offsets = std::make_shared>(broadcast_shape->elem_cnt()); for (int64_t i = 0; i < broadcast_shape->elem_cnt(); ++i) { IndexVector broadcast_indexes{}; 
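    // Map offset i within the selected (collapsed) hierarchy back to a parallel id of the full
    // hierarchy: offset -> indexes in the selected shape -> indexes in the origin shape -> offset.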
JUST(GetIndexesFromOffset(broadcast_strides, i, &broadcast_indexes)); IndexVector origin_indexes{}; SelectedIndex2OriginIndex(broadcast_indexes, &origin_indexes); int64_t origin_offset = -1; JUST(GetOffsetFromIndexes(hierarchy_strides, origin_indexes, &origin_offset)); origin_offsets->at(i) = origin_offset; } return origin_offsets; } Maybe> GetBroadcastSubParallelDesc(Symbol parallel_desc, Symbol nd_sbp) { const auto& axis2is_selected = JUST(GetAxis2IsBroadcast(nd_sbp)); return GetSelectedSubParallelDesc(parallel_desc, axis2is_selected); } namespace { Maybe> MakeNdSbp(const SbpParallel& sbp) { NdSbp nd_sbp; nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(sbp); return SymbolOf(nd_sbp); } Maybe InitShapeAxis2NdSbpIndexes( Symbol nd_sbp, std::vector>* shape_axis2nd_sbp_indexes) { for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) { const auto& sbp = nd_sbp->sbp_parallel(i); if (sbp.has_split_parallel()) { int64_t axis = sbp.split_parallel().axis(); CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, shape_axis2nd_sbp_indexes->size()); shape_axis2nd_sbp_indexes->at(axis).emplace_back(i); } } return Maybe::Ok(); } Maybe CheckSplitAxisExpandable( const Shape& hierarchy, const std::vector>& shape_axis2src_nd_sbp_indexes, const std::vector>& shape_axis2dst_nd_sbp_indexes) { const auto& GetHierarchyDim = [&](int64_t axis) { return hierarchy.At(axis); }; for (int i = 0; i < shape_axis2src_nd_sbp_indexes.size(); ++i) { const auto& src_nd_sbp_indexes = JUST(VectorAt(shape_axis2src_nd_sbp_indexes, i)); if (src_nd_sbp_indexes.empty()) { continue; } const auto& dst_nd_sbp_indexes = JUST(VectorAt(shape_axis2dst_nd_sbp_indexes, i)); if (dst_nd_sbp_indexes.empty()) { continue; } std::vector src_nd_sbp_dims{}; src_nd_sbp_dims.reserve(src_nd_sbp_indexes.size()); std::transform(src_nd_sbp_indexes.begin(), src_nd_sbp_indexes.end(), std::back_inserter(src_nd_sbp_dims), GetHierarchyDim); std::vector dst_nd_sbp_dims{}; dst_nd_sbp_dims.reserve(dst_nd_sbp_indexes.size()); std::transform(dst_nd_sbp_indexes.begin(), dst_nd_sbp_indexes.end(), std::back_inserter(dst_nd_sbp_dims), GetHierarchyDim); CHECK_OR_RETURN(src_nd_sbp_dims == dst_nd_sbp_dims) << Error::BoxingNotSupportedError(); } return Maybe::Ok(); } Maybe InitShapAxis2ExpandedDim( std::vector* shape_axis2expanded_dims, const Shape& shape, const Shape& hierarchy, const std::vector>& shape_axis2src_nd_sbp_indexes, const std::vector>& shape_axis2dst_nd_sbp_indexes) { std::vector shape_axis2required_dim(shape.NumAxes()); for (int i = 0; i < shape.NumAxes(); ++i) { const auto& src_nd_sbp_indexes = shape_axis2src_nd_sbp_indexes.at(i); const auto& dst_nd_sbp_indexes = shape_axis2dst_nd_sbp_indexes.at(i); int64_t max_used_cnt = std::max(src_nd_sbp_indexes.size(), dst_nd_sbp_indexes.size()); for (int j = 0; j < max_used_cnt; ++j) { if (j < src_nd_sbp_indexes.size() && j < dst_nd_sbp_indexes.size()) { int64_t m = hierarchy.At(src_nd_sbp_indexes.at(j)); int64_t n = hierarchy.At(dst_nd_sbp_indexes.at(j)); shape_axis2required_dim.at(i).emplace_back(Lcm(m, n)); } else if (j < src_nd_sbp_indexes.size()) { shape_axis2required_dim.at(i).emplace_back(hierarchy.At(src_nd_sbp_indexes.at(j))); } else if (j < dst_nd_sbp_indexes.size()) { shape_axis2required_dim.at(i).emplace_back(hierarchy.At(dst_nd_sbp_indexes.at(j))); } else { UNIMPLEMENTED_THEN_RETURN(); } } } for (int i = 0; i < shape.NumAxes(); ++i) { int64_t total_dim = shape.At(i); shape_axis2expanded_dims->at(i).clear(); if (JUST(VectorAt(shape_axis2required_dim, i)).empty() || JUST(VectorAt(shape_axis2required_dim, 
i)).size() == 1) { shape_axis2expanded_dims->at(i).emplace_back(total_dim); } else { Shape inner_shape(shape_axis2required_dim.at(i)); CHECK_EQ_OR_RETURN(total_dim % inner_shape.elem_cnt(), 0) << "dim " << total_dim << "(axis " << i << " in shape " << shape.ToString() << ")" << " cannot be reshape into exapanded shape " << inner_shape.ToString(); auto* dim_vec = &shape_axis2expanded_dims->at(i); *dim_vec = shape_axis2required_dim.at(i); dim_vec->at(dim_vec->size() - 1) *= total_dim / inner_shape.elem_cnt(); } } return Maybe::Ok(); } Maybe Flatten(const std::vector& shape_axis2expanded_dims) { DimVector dim_vec; for (const auto& expanded_dims : shape_axis2expanded_dims) { CHECK_OR_RETURN(!expanded_dims.empty()); dim_vec.insert(dim_vec.end(), expanded_dims.begin(), expanded_dims.end()); } return std::make_shared(dim_vec); } Maybe InitOldAxis2NewAxisOffset(std::vector* old_axis2new_axis_offset, const std::vector& shape_axis2expanded_dims) { for (int i = 0, offset = 0; i < shape_axis2expanded_dims.size(); ++i) { old_axis2new_axis_offset->at(i) = offset; offset += shape_axis2expanded_dims.at(i).size(); } return Maybe::Ok(); } Maybe> ShiftSplitAxis( Symbol nd_sbp, const std::vector>& shape_axis2nd_sbp_indexes, const std::vector& old_axis2new_axis_offset) { CHECK_EQ_OR_RETURN(shape_axis2nd_sbp_indexes.size(), old_axis2new_axis_offset.size()); NdSbp new_nd_sbp(*nd_sbp); for (int axis = 0; axis < shape_axis2nd_sbp_indexes.size(); ++axis) { int64_t offset = old_axis2new_axis_offset.at(axis); for (int64_t j = 0; j < shape_axis2nd_sbp_indexes.at(axis).size(); ++j) { int64_t nd_sbp_index = shape_axis2nd_sbp_indexes.at(axis).at(j); CHECK_GE_OR_RETURN(nd_sbp_index, 0); CHECK_LT_OR_RETURN(nd_sbp_index, new_nd_sbp.sbp_parallel_size()); auto* sbp_parallel = new_nd_sbp.mutable_sbp_parallel(nd_sbp_index); CHECK_OR_RETURN(sbp_parallel->has_split_parallel()); CHECK_EQ_OR_RETURN(sbp_parallel->split_parallel().axis(), axis); sbp_parallel->mutable_split_parallel()->set_axis(offset + j); } } return SymbolOf(new_nd_sbp); } } // namespace Maybe, Symbol, Symbol>> CalcDecomposableEquivalentShapeAndNdSbpPair(const Shape& shape, const Shape& hierarchy, Symbol src_nd_sbp, Symbol dst_nd_sbp) { CHECK_EQ_OR_RETURN(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size()); std::vector> shape_axis2src_nd_sbp_indexes(shape.NumAxes()); JUST(InitShapeAxis2NdSbpIndexes(src_nd_sbp, &shape_axis2src_nd_sbp_indexes)); std::vector> shape_axis2dst_nd_sbp_indexes(shape.NumAxes()); JUST(InitShapeAxis2NdSbpIndexes(dst_nd_sbp, &shape_axis2dst_nd_sbp_indexes)); std::vector shape_axis2expanded_dims(shape.NumAxes()); CHECK_EQ_OR_RETURN(hierarchy.NumAxes(), src_nd_sbp->sbp_parallel_size()); JUST(CheckSplitAxisExpandable(hierarchy, shape_axis2src_nd_sbp_indexes, shape_axis2dst_nd_sbp_indexes)); JUST(InitShapAxis2ExpandedDim(&shape_axis2expanded_dims, shape, hierarchy, shape_axis2src_nd_sbp_indexes, shape_axis2dst_nd_sbp_indexes)); std::shared_ptr new_shape = JUST(Flatten(shape_axis2expanded_dims)); CHECK_EQ_OR_RETURN(new_shape->elem_cnt(), shape.elem_cnt()); std::vector old_axis2new_axis_offset(shape.NumAxes()); JUST(InitOldAxis2NewAxisOffset(&old_axis2new_axis_offset, shape_axis2expanded_dims)); Symbol new_src_nd_sbp = JUST(ShiftSplitAxis(src_nd_sbp, shape_axis2src_nd_sbp_indexes, old_axis2new_axis_offset)); Symbol new_dst_nd_sbp = JUST(ShiftSplitAxis(dst_nd_sbp, shape_axis2dst_nd_sbp_indexes, old_axis2new_axis_offset)); return std::make_tuple(new_shape, new_src_nd_sbp, new_dst_nd_sbp); } namespace { // nd_sbp is called decomposable if 
no particular axis is used to split tensor more than once. // e.g. // 1) (S0, S1) is decomposable. // 2) (S0, S0) is not decomposable. // 3) (S1, S1) is not decomposable. // although `nd_sbp (S0, S0) on shape (4, 4)` is not decomposable, they could be transformed into a // decomposable form: `n_sbp (S0, S1) on shape (2, 2, 4)`. Maybe, Symbol>> CalcDecomposableEquivalent( Symbol tensor_meta, Symbol dst_nd_sbp) { std::shared_ptr shape = tensor_meta->shape_ptr(); Symbol src_nd_sbp = tensor_meta->nd_sbp(); const auto& hierarchy = tensor_meta->parallel_desc()->hierarchy(); std::tie(shape, src_nd_sbp, dst_nd_sbp) = *JUST( CalcDecomposableEquivalentShapeAndNdSbpPair(*shape, *hierarchy, src_nd_sbp, dst_nd_sbp)); one::GlobalTensorMeta decomposible_tensor_meta(*shape, tensor_meta->dtype(), tensor_meta->memory_format(), src_nd_sbp, tensor_meta->parallel_desc()); return std::make_pair(SymbolOf(decomposible_tensor_meta), dst_nd_sbp); } static constexpr auto* GetDecomposableEquivalent = DECORATE(&CalcDecomposableEquivalent, ThreadLocal); Maybe InitDstNdSbpAxis2ExclusiveSrcNdSbpAxis( HashMap* dst_nd_sbp_axis2exclusive_src_nd_sbp_axis, Symbol src_nd_sbp, Symbol dst_nd_sbp) { HashMap split_axis2src_nd_sbp_axis; for (int i = 0; i < src_nd_sbp->sbp_parallel_size(); ++i) { const auto& sbp_parallel = src_nd_sbp->sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { split_axis2src_nd_sbp_axis[sbp_parallel.split_parallel().axis()] = i; } } for (int i = 0; i < dst_nd_sbp->sbp_parallel_size(); ++i) { const auto& sbp_parallel = dst_nd_sbp->sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { int64_t axis = sbp_parallel.split_parallel().axis(); const auto& iter = split_axis2src_nd_sbp_axis.find(axis); if (iter != split_axis2src_nd_sbp_axis.end() && iter->second != i) { (*dst_nd_sbp_axis2exclusive_src_nd_sbp_axis)[i] = iter->second; } } } return Maybe::Ok(); } Maybe MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis( std::function>(int64_t)>* ExclusiveSrcNdSbpAxis4DstNdSbpAxis, Symbol src_nd_sbp, Symbol dst_nd_sbp) { CHECK_EQ_OR_RETURN(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size()); HashMap split_axis2src_nd_sbp_axis; for (int i = 0; i < src_nd_sbp->sbp_parallel_size(); ++i) { const auto& sbp_parallel = src_nd_sbp->sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { int64_t split_axis = sbp_parallel.split_parallel().axis(); CHECK_OR_RETURN(split_axis2src_nd_sbp_axis.emplace(split_axis, i).second); } } { // check split_axis used only once. 
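    // The emplace().second check below fails (via CHECK_OR_RETURN) if dst_nd_sbp splits the same
    // tensor axis on more than one hierarchy axis.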
HashMap split_axis2dst_nd_sbp_axis; for (int i = 0; i < dst_nd_sbp->sbp_parallel_size(); ++i) { const auto& sbp_parallel = dst_nd_sbp->sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { int64_t split_axis = sbp_parallel.split_parallel().axis(); CHECK_OR_RETURN(split_axis2dst_nd_sbp_axis.emplace(split_axis, i).second); } } } *ExclusiveSrcNdSbpAxis4DstNdSbpAxis = [split_axis2src_nd_sbp_axis, src_nd_sbp, dst_nd_sbp](int64_t dst_axis) -> Maybe> { CHECK_GE_OR_RETURN(dst_axis, 0); CHECK_LT_OR_RETURN(dst_axis, dst_nd_sbp->sbp_parallel_size()); const auto& dst_sbp_parallel = dst_nd_sbp->sbp_parallel(dst_axis); if (!dst_sbp_parallel.has_split_parallel()) { return Optional(); } int64_t split_axis = dst_sbp_parallel.split_parallel().axis(); const auto& src_iter = split_axis2src_nd_sbp_axis.find(split_axis); if (src_iter == split_axis2src_nd_sbp_axis.end()) { return Optional(); } int64_t src_axis = src_iter->second; CHECK_GE_OR_RETURN(src_axis, 0); CHECK_LT_OR_RETURN(src_axis, dst_nd_sbp->sbp_parallel_size()); const auto& src_sbp_parallel = src_nd_sbp->sbp_parallel(src_axis); CHECK_OR_RETURN(src_sbp_parallel.has_split_parallel()); CHECK_EQ_OR_RETURN(src_sbp_parallel.split_parallel().axis(), split_axis); if (src_axis == dst_axis) { return Optional(); } return Optional(src_axis); }; return Maybe::Ok(); } Maybe IsNdSbpBoxingAcyclic( int64_t num_axes, const std::function>(int64_t)>& ExclusiveSrcNdSbpAxis4DstNdSbpAxis) { for (int start_axis = 0; start_axis < num_axes; ++start_axis) { int64_t axis = start_axis; HashSet visited_axes; for (int i = 0; i < num_axes + 1; ++i) { const auto& opt_axis = JUST(ExclusiveSrcNdSbpAxis4DstNdSbpAxis(axis)); if (!opt_axis->has_value()) { break; } axis = JUST(*opt_axis); if (!visited_axes.insert(axis).second) { return false; } } } return true; } Maybe InitNdSbpValidTransformationAxisSequence( std::vector* nd_sbp_axis_sequence, Symbol src_nd_sbp, Symbol dst_nd_sbp, const std::function>(int64_t)>& ExclusiveSrcNdSbpAxis4DstNdSbpAxis) { CHECK_EQ_OR_RETURN(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size()); int64_t num_axes = src_nd_sbp->sbp_parallel_size(); HashSet handled_axes; nd_sbp_axis_sequence->reserve(num_axes); const auto& HasNoExclusiveSrcNdSbpAxis = [&](int64_t axis) -> Maybe { const auto& opt_src_axis = JUST(ExclusiveSrcNdSbpAxis4DstNdSbpAxis(axis)); if (!opt_src_axis->has_value()) { return true; } return handled_axes.count(JUST(*opt_src_axis)) > 0; }; for (int i = 0; i < num_axes; ++i) { for (int axis = 0; axis < num_axes; ++axis) { if (handled_axes.count(axis) == 0 && JUST(HasNoExclusiveSrcNdSbpAxis(axis))) { if (!(src_nd_sbp->sbp_parallel(axis) == dst_nd_sbp->sbp_parallel(axis))) { nd_sbp_axis_sequence->emplace_back(axis); } handled_axes.insert(axis); } } } CHECK_EQ_OR_RETURN(handled_axes.size(), num_axes); return Maybe::Ok(); } } // namespace Maybe IsNdSbpBoxingAcyclic(Symbol src_nd_sbp, Symbol dst_nd_sbp) { std::function>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis; JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp, dst_nd_sbp)); return IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis); } Maybe> GetNdSbpValidTransformationAxisSequence(Symbol src_nd_sbp, Symbol dst_nd_sbp) { HashMap dst_nd_sbp_axis2exclusive_src_nd_sbp_axis; std::function>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis; JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp, dst_nd_sbp)); bool is_acyclic = JUST( 
IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis)); CHECK_OR_RETURN(is_acyclic) << Error::UnimplementedError() << "cyclic split axis boxing are not supported"; std::vector nd_sbp_axis_sequence; JUST(InitNdSbpValidTransformationAxisSequence(&nd_sbp_axis_sequence, src_nd_sbp, dst_nd_sbp, ExclusiveSrcNdSbpAxis4DstNdSbpAxis)); return nd_sbp_axis_sequence; } std::string GetCyclicBoxingDebugString( Symbol src_nd_sbp, Symbol dst_nd_sbp, const std::function>(int64_t)>& ExclusiveSrcNdSbpAxis4DstNdSbpAxis) { CHECK_EQ(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size()); std::stringstream ss; ss << "cyclic split axis boxing are not supported. " << "src_nd_sbp: " << NdSbpToString(src_nd_sbp) << ", dst_nd_sbp: " << NdSbpToString(dst_nd_sbp) << ". " << "dst_nd_sbp axis to exclusive src_nd_sbp axis: "; ss << "["; for (int i = 0; i < src_nd_sbp->sbp_parallel_size(); ++i) { const auto& opt_axis = CHECK_JUST(ExclusiveSrcNdSbpAxis4DstNdSbpAxis(i)); if (i) { ss << ", "; } if (opt_axis->has_value()) { ss << CHECK_JUST(*opt_axis); } else { ss << "None"; } } ss << "]"; return ss.str(); } Maybe GetPhysicalShape(const Shape& shape, Symbol nd_sbp, Symbol parallel_desc) { const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc)); return GetPhysicalShape(shape, *nd_sbp, *parallel_desc, JUST(*parallel_id)); } Maybe GetSubLogicalShape(Symbol tensor_meta, Symbol sub_parallel_desc, Symbol sub_nd_sbp) { CHECK_EQ_OR_RETURN(sub_nd_sbp->sbp_parallel_size(), 1); // NOLINT(maybe-need-error-msg) const auto& logical_shape = tensor_meta->shape(); const auto& physical_shape = JUST(GetPhysicalShape(logical_shape, tensor_meta->nd_sbp(), tensor_meta->parallel_desc())); std::shared_ptr sub_logical_shape = std::make_shared(*physical_shape); if (sub_nd_sbp->sbp_parallel(0).has_split_parallel()) { const int64_t split_axis = sub_nd_sbp->sbp_parallel(0).split_parallel().axis(); sub_logical_shape->Set(split_axis, logical_shape.At(split_axis)); } return sub_logical_shape; } Maybe> CalcSubGlobalTensorMeta( Symbol tensor_meta, Symbol sub_parallel_desc, Symbol sub_nd_sbp) { CHECK_EQ_OR_RETURN(sub_nd_sbp->sbp_parallel_size(), 1); // NOLINT(maybe-need-error-msg) const auto& logical_shape = JUST(GetSubLogicalShape(tensor_meta, sub_parallel_desc, sub_nd_sbp)); one::GlobalTensorMeta sub_global_tensor_meta(*logical_shape, tensor_meta->dtype(), tensor_meta->memory_format(), sub_nd_sbp, sub_parallel_desc); return SymbolOf(sub_global_tensor_meta); } static constexpr auto* GetSubGlobalTensorMeta = DECORATE(&CalcSubGlobalTensorMeta, ThreadLocal); Maybe> ReplaceNdSbpComponent(Symbol nd_sbp, int64_t axis, Symbol component) { CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, nd_sbp->sbp_parallel_size()); CHECK_EQ_OR_RETURN(component->sbp_parallel_size(), 1); NdSbp new_nd_sbp(*nd_sbp); *new_nd_sbp.mutable_sbp_parallel(axis) = component->sbp_parallel(0); return SymbolOf(new_nd_sbp); } Maybe> ReplaceNdSbp(Symbol tensor_meta, Symbol nd_sbp) { one::GlobalTensorMeta new_tensor_meta(tensor_meta->shape(), tensor_meta->dtype(), tensor_meta->memory_format(), nd_sbp, tensor_meta->parallel_desc()); return SymbolOf(new_tensor_meta); } Maybe> DecomposeIntoNaiveTransformations( Symbol tensor_meta, Symbol dst_nd_sbp) { std::tie(tensor_meta, dst_nd_sbp) = *JUST(GetDecomposableEquivalent(tensor_meta, dst_nd_sbp)); const auto& parallel_desc = tensor_meta->parallel_desc(); const auto& src_nd_sbp = tensor_meta->nd_sbp(); CHECK_EQ_OR_RETURN(src_nd_sbp->sbp_parallel_size(), 
dst_nd_sbp->sbp_parallel_size()); std::vector nd_sbp_axis_sequence; { std::function>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis; JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp, dst_nd_sbp)); bool is_acyclic = JUST( IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis)); CHECK_OR_RETURN(is_acyclic) << Error::UnimplementedError() << GetCyclicBoxingDebugString(src_nd_sbp, dst_nd_sbp, ExclusiveSrcNdSbpAxis4DstNdSbpAxis); JUST(InitNdSbpValidTransformationAxisSequence(&nd_sbp_axis_sequence, src_nd_sbp, dst_nd_sbp, ExclusiveSrcNdSbpAxis4DstNdSbpAxis)); } const auto& transformations = std::make_shared>(); for (int axis : nd_sbp_axis_sequence) { const auto& src_sbp = src_nd_sbp->sbp_parallel(axis); const auto& dst_sbp = dst_nd_sbp->sbp_parallel(axis); if (src_sbp == dst_sbp) { continue; } std::vector axis2selected(src_nd_sbp->sbp_parallel_size()); axis2selected[axis] = 1; const auto& sub_parallel_desc = JUST(GetSelectedSubParallelDesc(parallel_desc, SymbolOf(axis2selected))); const auto& sub_src_nd_sbp = JUST(MakeNdSbp(src_sbp)); const auto& sub_dst_nd_sbp = JUST(MakeNdSbp(dst_sbp)); const auto& sub_global_tensor_meta = JUST(GetSubGlobalTensorMeta(tensor_meta, sub_parallel_desc, sub_src_nd_sbp)); const auto& new_src_nd_sbp = JUST(ReplaceNdSbpComponent(tensor_meta->nd_sbp(), axis, sub_dst_nd_sbp)); tensor_meta = JUST(ReplaceNdSbp(tensor_meta, new_src_nd_sbp)); transformations->emplace_back(NaiveBoxingTransformation{ .global_tensor_meta = sub_global_tensor_meta, .dst_nd_sbp = sub_dst_nd_sbp, }); } return transformations; } } // namespace private_details namespace { Maybe>> CalcBroadcastGroup( Symbol src_parallel_desc, Symbol dst_parallel_desc, bool allow_across_node) { CHECK_EQ_OR_RETURN(src_parallel_desc->parallel_num(), src_parallel_desc->sorted_machine_ids().size()); CHECK_EQ_OR_RETURN(dst_parallel_desc->parallel_num(), dst_parallel_desc->sorted_machine_ids().size()); CHECK_EQ_OR_RETURN(src_parallel_desc->device_type(), dst_parallel_desc->device_type()); CHECK_LE_OR_RETURN(src_parallel_desc->parallel_num(), dst_parallel_desc->parallel_num()); const auto& src_process_ids = src_parallel_desc->sorted_machine_ids(); HashMap> process_id2group{}; HashMap> node_id2src_process_id{}; for (int64_t process_id : src_process_ids) { std::vector vec{process_id}; CHECK_OR_RETURN(process_id2group.emplace(process_id, vec).second); CHECK_OR_RETURN(dst_parallel_desc->ContainingMachineId(process_id)); node_id2src_process_id[GlobalProcessCtx::NodeId(process_id)].emplace_back(process_id); } std::vector remainder_process_ids{}; remainder_process_ids.reserve(dst_parallel_desc->sorted_machine_ids().size()); HashMap node_id2counter{}; for (int64_t process_id : dst_parallel_desc->sorted_machine_ids()) { if (!src_parallel_desc->ContainingMachineId(process_id)) { const auto& node_iter = node_id2src_process_id.find(GlobalProcessCtx::NodeId(process_id)); if (node_iter == node_id2src_process_id.end()) { CHECK_OR_RETURN(allow_across_node) << Error::UnimplementedError() << "\n----[src_placement]----\n" << src_parallel_desc->parallel_conf().DebugString() << "\n----[dst_placement]----\n" << dst_parallel_desc->parallel_conf().DebugString(); // handle `process_id` later. remainder_process_ids.emplace_back(process_id); } else { // balancedly put `process_id` into the groups within the same node.. 
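        // (a per-node counter round-robins dst-only processes over the src processes already
        // present on that node)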
int64_t node_id = node_iter->first; const auto& src_process_ids = node_iter->second; int64_t src_process_index = (node_id2counter[node_id]++) % src_process_ids.size(); int64_t src_process_id = src_process_ids.at(src_process_index); JUST(MapAt(process_id2group, src_process_id)).emplace_back(process_id); } } } // put remainder process ids into src groups. for (int i = 0; i < remainder_process_ids.size(); ++i) { int64_t src_process_id = src_process_ids.at(i % src_process_ids.size()); JUST(MapAt(process_id2group, src_process_id)) .emplace_back(JUST(oneflow::VectorAt(remainder_process_ids, i))); } const auto& map = std::make_shared>>(); for (const auto& pair : process_id2group) { const auto& group = pair.second; ParallelConf parallel_conf; parallel_conf.set_device_tag(dst_parallel_desc->parallel_conf().device_tag()); for (int64_t process_id : group) { const auto& device_ids = dst_parallel_desc->sorted_dev_phy_ids(process_id); CHECK_EQ_OR_RETURN(device_ids.size(), 1); parallel_conf.add_device_name(std::string("@") + std::to_string(process_id) + ":" + std::to_string(device_ids.at(0))); } const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); for (int64_t process_id : group) { CHECK_OR_RETURN(map->emplace(process_id, parallel_desc).second); } } return map; } auto* CachedBroadcastGroup = DECORATE(&CalcBroadcastGroup, ThreadLocal); Maybe RawCheckIsNdSbpBoxingAcyclic(Symbol in, Symbol out) { using namespace private_details; const auto& src_nd_sbp = in->nd_sbp(); const auto& dst_nd_sbp = out->nd_sbp(); std::function>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis; JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp, dst_nd_sbp)); bool is_acyclic = JUST( IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis)); CHECK_OR_RETURN(is_acyclic) << Error::UnimplementedError() << GetCyclicBoxingDebugString(src_nd_sbp, dst_nd_sbp, ExclusiveSrcNdSbpAxis4DstNdSbpAxis); return Maybe::Ok(); } Maybe RawCheckIsNdSbpBoxingAcyclicWithDecompose(Symbol in, Symbol out, const Shape& logical_shape) { using namespace private_details; Symbol src_nd_sbp = in->nd_sbp(); Symbol dst_nd_sbp = out->nd_sbp(); const auto& hierarchy = in->placement()->hierarchy(); std::shared_ptr shape; std::tie(shape, src_nd_sbp, dst_nd_sbp) = *JUST(CalcDecomposableEquivalentShapeAndNdSbpPair( logical_shape, *hierarchy, src_nd_sbp, dst_nd_sbp)); std::function>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis; JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp, dst_nd_sbp)); bool is_acyclic = JUST( IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis)); CHECK_OR_RETURN(is_acyclic) << Error::UnimplementedError() << GetCyclicBoxingDebugString(src_nd_sbp, dst_nd_sbp, ExclusiveSrcNdSbpAxis4DstNdSbpAxis); return Maybe::Ok(); } } // namespace int64_t CalcIndex4Axis(int64_t offset, const Stride& stride, int axis) { CHECK_LT(axis, stride.size()) << "Expected axis (" << axis << ") to be less than size of stride (" << stride.size() << ")"; if (axis == 0) { return offset / stride.at(0); } else { return offset % stride.at(axis - 1) / stride.at(axis); } } decltype(CheckIsNdSbpBoxingAcyclic) CheckIsNdSbpBoxingAcyclic = DECORATE(&RawCheckIsNdSbpBoxingAcyclic, ThreadLocal); decltype(CheckIsNdSbpBoxingAcyclicWithDecompose) CheckIsNdSbpBoxingAcyclicWithDecompose = DECORATE(&RawCheckIsNdSbpBoxingAcyclicWithDecompose, ThreadLocalCopiable); Maybe>> GetBroadcastGroup( Symbol src_parallel_desc, Symbol dst_parallel_desc) 
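// Thread-locally cached via DECORATE(&CalcBroadcastGroup, ThreadLocal); this variant allows groups
// that span nodes (allow_across_node = true), unlike GetBroadcastGroupWithoutAcrossNode below.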
{ return CachedBroadcastGroup(src_parallel_desc, dst_parallel_desc, true); } Maybe>> GetBroadcastGroupWithoutAcrossNode( Symbol src_parallel_desc, Symbol dst_parallel_desc) { return CachedBroadcastGroup(src_parallel_desc, dst_parallel_desc, false); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/placement_sbp_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_PLACEMENT_SBP_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_PLACEMENT_SBP_UTIL_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/common/stride.h" namespace oneflow { class Shape; class Stride; class ParallelDesc; class PlacedNdSbp; namespace one { class GlobalTensorMeta; } // 1) src_nd_sbp.sbp_parallel_size() == 1 // 2) dst_nd_sbp.sbp_parallel_size() == 1 struct NaiveBoxingTransformation { Symbol global_tensor_meta; Symbol dst_nd_sbp; }; namespace private_details { Maybe> GetSelectedParallelIds(const Shape& hierarchy_shape, const std::vector& axis2is_selected, int64_t parallel_id); Maybe, Symbol, Symbol>> CalcDecomposableEquivalentShapeAndNdSbpPair(const Shape& shape, const Shape& hierarchy, Symbol src_nd_sbp, Symbol dst_nd_sbp); Maybe> GetBroadcastSubParallelDesc(Symbol parallel_desc, Symbol nd_sbp); Maybe> DecomposeIntoNaiveTransformations( Symbol tensor_meta, Symbol dst_nd_sbp); Maybe IsNdSbpBoxingAcyclic(Symbol src_nd_sbp, Symbol dst_nd_sbp); Maybe> GetNdSbpValidTransformationAxisSequence(Symbol src_nd_sbp, Symbol dst_nd_sbp); Maybe> CalcSubGlobalTensorMeta( Symbol tensor_meta, Symbol sub_parallel_desc, Symbol sub_nd_sbp); Maybe> CalcSubParallelDesc4Axis(Symbol parallel_desc, int axis); } // namespace private_details extern Maybe (*CheckIsNdSbpBoxingAcyclic)(Symbol in, Symbol out); extern Maybe (*CheckIsNdSbpBoxingAcyclicWithDecompose)(Symbol in, Symbol out, const Shape& logical_shape); int64_t CalcIndex4Axis(int64_t offset, const Stride& stride, int axis); static constexpr auto* GetSubGlobalTensorMeta = DECORATE(&private_details::CalcSubGlobalTensorMeta, ThreadLocal); static constexpr auto* GetBroadcastSubParallelDesc = DECORATE(&private_details::GetBroadcastSubParallelDesc, ThreadLocal); static constexpr auto* DecomposeIntoNaiveTransformations = DECORATE(&private_details::DecomposeIntoNaiveTransformations, ThreadLocal); static constexpr auto* CalcSubParallelDesc4Axis = DECORATE(&private_details::CalcSubParallelDesc4Axis, ThreadLocal); Maybe>> GetBroadcastGroup( Symbol src_parallel_desc, Symbol dst_parallel_desc); Maybe>> GetBroadcastGroupWithoutAcrossNode( Symbol src_parallel_desc, Symbol dst_parallel_desc); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_PLACEMENT_SBP_UTIL_H_ ================================================ FILE: oneflow/core/framework/placement_sbp_util_test.cpp 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" #include "oneflow/core/job/sbp_parallel.h" namespace oneflow { namespace test { namespace { struct GlobaProcessCtxScope final { GlobaProcessCtxScope(GlobaProcessCtxScope&) = default; GlobaProcessCtxScope(GlobaProcessCtxScope&&) = default; GlobaProcessCtxScope& operator=(GlobaProcessCtxScope&) = default; GlobaProcessCtxScope& operator=(GlobaProcessCtxScope&&) = default; GlobaProcessCtxScope(int64_t node_size, int64_t world_size) { Singleton::New(); auto* ctx = Singleton::Get(); for (int i = 0; i < world_size; ++i) { ctx->mutable_ctrl_addr()->Add(); } ctx->set_rank(0); ctx->set_node_size(node_size); } ~GlobaProcessCtxScope() { Singleton::Delete(); } }; } // namespace TEST(GetSelectedParallelIds, 1d_broadcast) { int64_t parallel_size = 4; Shape hierarchy_shape(DimVector{parallel_size}); std::vector axis2is_selected{true}; const auto& expected = std::vector{0, 1, 2, 3}; for (int i = 0; i < parallel_size; ++i) { const auto& broadcast_parallel_ids = CHECK_JUST(private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, i)); ASSERT_TRUE(*broadcast_parallel_ids == expected); } } TEST(GetSelectedParallelIds, 1d_nonbroadcast) { int64_t parallel_size = 4; Shape hierarchy_shape(DimVector{parallel_size}); std::vector axis2is_selected{false}; for (int i = 0; i < parallel_size; ++i) { const auto& broadcast_parallel_ids = CHECK_JUST(private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, i)); const auto& expected = std::vector{i}; ASSERT_TRUE(*broadcast_parallel_ids == expected); } } TEST(GetSelectedParallelIds, 2d_broadcast_broadcast) { int64_t parallel_size = 4; Shape hierarchy_shape(DimVector{parallel_size, parallel_size}); std::vector axis2is_selected{true, true}; std::vector expected{}; for (int i = 0; i < parallel_size * parallel_size; ++i) { expected.emplace_back(i); } for (int i = 0; i < parallel_size * parallel_size; ++i) { const auto& broadcast_parallel_ids = CHECK_JUST(private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, i)); ASSERT_TRUE(*broadcast_parallel_ids == expected); } } TEST(GetSelectedParallelIds, 2d_nonbroadcast_nonbroadcast) { int64_t parallel_size = 4; Shape hierarchy_shape(DimVector{parallel_size, parallel_size}); std::vector axis2is_selected{false, false}; for (int i = 0; i < parallel_size * parallel_size; ++i) { const auto& broadcast_parallel_ids = CHECK_JUST(private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, i)); const auto& expected = std::vector{i}; ASSERT_TRUE(*broadcast_parallel_ids == expected); } } TEST(GetSelectedParallelIds, 2d_broadcast_nonbroadcast) { int64_t parallel_size = 4; 
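// Illustration of the expectation checked below: with a 4x4 hierarchy and axis2is_selected =
// {true, false}, the selected group of a rank varies only along axis 0. For
// parallel_id = i * 4 + j the expected ids are {j, 4 + j, 8 + j, 12 + j}, i.e. all ranks
// sharing column j, which is exactly what the nested loops assert.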
Shape hierarchy_shape(DimVector{parallel_size, parallel_size}); std::vector axis2is_selected{true, false}; for (int i = 0; i < parallel_size; ++i) { for (int j = 0; j < parallel_size; ++j) { std::vector expected{}; for (int k = 0; k < parallel_size; ++k) { expected.emplace_back(k * parallel_size + j); } int64_t parallel_id = i * parallel_size + j; const auto& broadcast_parallel_ids = CHECK_JUST( private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, parallel_id)); ASSERT_TRUE(*broadcast_parallel_ids == expected); } } } TEST(GetSelectedParallelIds, 2d_nonbroadcast_broadcast) { int64_t parallel_size = 4; Shape hierarchy_shape(DimVector{parallel_size, parallel_size}); std::vector axis2is_selected{false, true}; for (int i = 0; i < parallel_size; ++i) { std::vector expected{}; for (int j = 0; j < parallel_size; ++j) { expected.emplace_back(i * parallel_size + j); } for (int j = 0; j < parallel_size; ++j) { int64_t parallel_id = i * parallel_size + j; const auto& broadcast_parallel_ids = CHECK_JUST( private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, parallel_id)); ASSERT_TRUE(*broadcast_parallel_ids == expected); } } } namespace { void InitSbpParallel(SbpParallel* sbp_parallel, const std::string& sbp_tag) { CHECK(sbp_tag.size() == 1 || sbp_tag.size() == 2); if (sbp_tag[0] == 'S') { CHECK_EQ(sbp_tag.size(), 2); int64_t axis = sbp_tag[1] - '0'; sbp_parallel->mutable_split_parallel()->set_axis(axis); } else if (sbp_tag == "B") { sbp_parallel->mutable_broadcast_parallel(); } else if (sbp_tag == "P") { sbp_parallel->mutable_partial_sum_parallel(); } else { UNIMPLEMENTED(); } } template Symbol GetNdSbp(Args... sbps) { NdSbp nd_sbp; for (const auto& sbp : std::vector{sbps...}) { InitSbpParallel(nd_sbp.mutable_sbp_parallel()->Add(), sbp); } return SymbolOf(nd_sbp); } Symbol MakeGlobalTensorMeta(Symbol parallel_desc, Symbol nd_sbp) { auto shape = Shape(DimVector{256, 256}); one::GlobalTensorMeta tensor_meta(shape, DataType::kInt32, MemoryFormat::kContiguous, nd_sbp, parallel_desc); return SymbolOf(tensor_meta); } } // namespace TEST(DecomposeIntoNaiveTransformations, decompose_axis0) { GlobaProcessCtxScope scope(2, 8); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-3"); parallel_conf.add_device_name("1:0-3"); parallel_conf.mutable_hierarchy()->add_dim(2); parallel_conf.mutable_hierarchy()->add_dim(4); const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); const auto& src_nd_sbp = GetNdSbp("P", "B"); const auto& dst_nd_sbp = GetNdSbp("S0", "B"); const auto& tensor_meta = MakeGlobalTensorMeta(parallel_desc, src_nd_sbp); const auto& transformations = CHECK_JUST(private_details::DecomposeIntoNaiveTransformations(tensor_meta, dst_nd_sbp)); ASSERT_EQ(transformations->size(), 1); ParallelConf expected_parallel_conf; expected_parallel_conf.set_device_tag("cpu"); expected_parallel_conf.add_device_name(std::string("0:0")); expected_parallel_conf.add_device_name(std::string("1:0")); const auto& expected_parallel_desc = SymbolOf(ParallelDesc(expected_parallel_conf)); const auto& ctensor_meta = transformations->at(0).global_tensor_meta; ASSERT_TRUE(ctensor_meta->parallel_desc() == expected_parallel_desc); ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel_size(), 1); ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel_size(), 1); ASSERT_TRUE(ctensor_meta->nd_sbp()->sbp_parallel(0).has_partial_sum_parallel()); ASSERT_TRUE(transformations->at(0).dst_nd_sbp->sbp_parallel(0).has_split_parallel()); 
ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel(0).split_parallel().axis(), 0); } TEST(DecomposeIntoNaiveTransformations, decompose_axis1) { GlobaProcessCtxScope scope(2, 8); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-3"); parallel_conf.add_device_name("1:0-3"); parallel_conf.mutable_hierarchy()->add_dim(2); parallel_conf.mutable_hierarchy()->add_dim(4); const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); const auto& src_nd_sbp = GetNdSbp("S0", "P"); const auto& dst_nd_sbp = GetNdSbp("S0", "S1"); const auto& tensor_meta = MakeGlobalTensorMeta(parallel_desc, src_nd_sbp); const auto& transformations = CHECK_JUST(private_details::DecomposeIntoNaiveTransformations(tensor_meta, dst_nd_sbp)); ASSERT_EQ(transformations->size(), 1); ParallelConf expected_parallel_conf; expected_parallel_conf.set_device_tag("cpu"); expected_parallel_conf.add_device_name("0:0-3"); const auto& expected_parallel_desc = SymbolOf(ParallelDesc(expected_parallel_conf)); const auto& ctensor_meta = transformations->at(0).global_tensor_meta; ASSERT_TRUE(ctensor_meta->parallel_desc() == expected_parallel_desc); ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel_size(), 1); ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel_size(), 1); ASSERT_TRUE(ctensor_meta->nd_sbp()->sbp_parallel(0).has_partial_sum_parallel()); ASSERT_TRUE(transformations->at(0).dst_nd_sbp->sbp_parallel(0).has_split_parallel()); ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel(0).split_parallel().axis(), 1); } TEST(DecomposeIntoNaiveTransformations, decompose_two_axes) { GlobaProcessCtxScope scope(2, 8); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-1"); parallel_conf.add_device_name("1:0-1"); parallel_conf.mutable_hierarchy()->add_dim(2); parallel_conf.mutable_hierarchy()->add_dim(2); const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); const auto& src_nd_sbp = GetNdSbp("S0", "P"); const auto& dst_nd_sbp = GetNdSbp("B", "S0"); const auto& tensor_meta = MakeGlobalTensorMeta(parallel_desc, src_nd_sbp); const auto& transformations = CHECK_JUST(private_details::DecomposeIntoNaiveTransformations(tensor_meta, dst_nd_sbp)); ASSERT_EQ(transformations->size(), 2); { ParallelConf expected_parallel_conf; expected_parallel_conf.set_device_tag("cpu"); expected_parallel_conf.add_device_name(std::string("0:0")); expected_parallel_conf.add_device_name(std::string("1:0")); const auto& expected_parallel_desc = SymbolOf(ParallelDesc(expected_parallel_conf)); const auto& ctensor_meta = transformations->at(0).global_tensor_meta; ASSERT_TRUE(ctensor_meta->parallel_desc() == expected_parallel_desc); ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel_size(), 1); ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel_size(), 1); ASSERT_TRUE(ctensor_meta->nd_sbp()->sbp_parallel(0).has_split_parallel()); ASSERT_TRUE(transformations->at(0).dst_nd_sbp->sbp_parallel(0).has_broadcast_parallel()); ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel(0).split_parallel().axis(), 0); } { ParallelConf expected_parallel_conf; expected_parallel_conf.set_device_tag("cpu"); expected_parallel_conf.add_device_name("0:0-1"); const auto& expected_parallel_desc = SymbolOf(ParallelDesc(expected_parallel_conf)); const auto& ctensor_meta = transformations->at(1).global_tensor_meta; ASSERT_TRUE(ctensor_meta->parallel_desc() == expected_parallel_desc); ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel_size(), 1); 
ASSERT_EQ(transformations->at(1).dst_nd_sbp->sbp_parallel_size(), 1); ASSERT_TRUE(ctensor_meta->nd_sbp()->sbp_parallel(0).has_partial_sum_parallel()); ASSERT_TRUE(transformations->at(1).dst_nd_sbp->sbp_parallel(0).has_split_parallel()); ASSERT_EQ(transformations->at(1).dst_nd_sbp->sbp_parallel(0).split_parallel().axis(), 0); } } TEST(CalcDecomposableEquivalentShapeAndNdSbpPair, naive) { Shape shape(DimVector{4, 4}); Shape hierarchy(DimVector{4, 4}); const auto& src_nd_sbp = GetNdSbp("S0", "S1"); const auto& dst_nd_sbp = GetNdSbp("B", "P"); const auto& maybe_tuple = TRY(private_details::CalcDecomposableEquivalentShapeAndNdSbpPair( shape, hierarchy, src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_tuple.IsOk()); const auto& tuple = CHECK_JUST(maybe_tuple); ASSERT_TRUE(*std::get<0>(*tuple) == shape); ASSERT_TRUE(std::get<1>(*tuple) == src_nd_sbp); ASSERT_TRUE(std::get<2>(*tuple) == dst_nd_sbp); } TEST(CalcDecomposableEquivalentShapeAndNdSbpPair, expand_src) { Shape shape(DimVector{16, 4}); Shape hierarchy(DimVector{4, 4}); const auto& src_nd_sbp = GetNdSbp("S0", "S0"); const auto& dst_nd_sbp = GetNdSbp("B", "P"); const auto& maybe_tuple = TRY(private_details::CalcDecomposableEquivalentShapeAndNdSbpPair( shape, hierarchy, src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_tuple.IsOk()); const auto& tuple = CHECK_JUST(maybe_tuple); ASSERT_TRUE(*std::get<0>(*tuple) == Shape(DimVector{4, 4, 4})); ASSERT_TRUE(std::get<1>(*tuple) == GetNdSbp("S0", "S1")); ASSERT_TRUE(std::get<2>(*tuple) == dst_nd_sbp); } TEST(CalcDecomposableEquivalentShapeAndNdSbpPair, expand_failed) { Shape shape(DimVector{32, 4}); Shape hierarchy(DimVector{4, 4, 4}); const auto& src_nd_sbp = GetNdSbp("S0", "S0", "S0"); const auto& dst_nd_sbp = GetNdSbp("P", "S0", "S1"); const auto& maybe_tuple = TRY(private_details::CalcDecomposableEquivalentShapeAndNdSbpPair( shape, hierarchy, src_nd_sbp, dst_nd_sbp)); ASSERT_FALSE(maybe_tuple.IsOk()); } TEST(IsNdSbpBoxingAcyclic, yes) { const auto& src_nd_sbp = GetNdSbp("S0", "S1", "S2"); const auto& dst_nd_sbp = GetNdSbp("S1", "S2", "S3"); const auto& maybe_acyclic = TRY(private_details::IsNdSbpBoxingAcyclic(src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_acyclic.IsOk()); ASSERT_TRUE(CHECK_JUST(maybe_acyclic)); } TEST(IsNdSbpBoxingAcyclic, ring) { const auto& src_nd_sbp = GetNdSbp("S0", "S1", "S2"); const auto& dst_nd_sbp = GetNdSbp("S1", "S2", "S0"); const auto& maybe_acyclic = TRY(private_details::IsNdSbpBoxingAcyclic(src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_acyclic.IsOk()); ASSERT_FALSE(CHECK_JUST(maybe_acyclic)); } TEST(IsNdSbpBoxingAcyclic, partial_ring) { const auto& src_nd_sbp = GetNdSbp("B", "S0", "S1", "S2", "S5"); const auto& dst_nd_sbp = GetNdSbp("P", "S1", "S2", "S0", "S4"); const auto& maybe_acyclic = TRY(private_details::IsNdSbpBoxingAcyclic(src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_acyclic.IsOk()); ASSERT_FALSE(CHECK_JUST(maybe_acyclic)); } TEST(IsNdSbpBoxingAcyclic, dag) { const auto& src_nd_sbp = GetNdSbp("S0", "S1", "S2"); const auto& dst_nd_sbp = GetNdSbp("S1", "S2", "S3"); const auto& maybe_acyclic = TRY(private_details::IsNdSbpBoxingAcyclic(src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_acyclic.IsOk()); ASSERT_TRUE(CHECK_JUST(maybe_acyclic)); } TEST(GetNdSbpValidTransformationAxisSequence, naive) { const auto& src_nd_sbp = GetNdSbp("S0", "S1", "S2"); const auto& dst_nd_sbp = GetNdSbp("S0", "B", "S2"); const auto& maybe_axis_seq = TRY(private_details::GetNdSbpValidTransformationAxisSequence(src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_axis_seq.IsOk()); const auto& axis_seq = 
CHECK_JUST(maybe_axis_seq); ASSERT_TRUE(*axis_seq == std::vector{1}); } TEST(GetNdSbpValidTransformationAxisSequence, 2d) { const auto& src_nd_sbp = GetNdSbp("B", "S0"); const auto& dst_nd_sbp = GetNdSbp("S0", "S1"); const auto& maybe_axis_seq = TRY(private_details::GetNdSbpValidTransformationAxisSequence(src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_axis_seq.IsOk()); const auto& axis_seq = CHECK_JUST(maybe_axis_seq); ASSERT_TRUE(*axis_seq == (std::vector{1, 0})); } TEST(GetNdSbpValidTransformationAxisSequence, 3d) { const auto& src_nd_sbp = GetNdSbp("S0", "S1", "S2"); const auto& dst_nd_sbp = GetNdSbp("S1", "S2", "S3"); const auto& maybe_axis_seq = TRY(private_details::GetNdSbpValidTransformationAxisSequence(src_nd_sbp, dst_nd_sbp)); ASSERT_TRUE(maybe_axis_seq.IsOk()); const auto& axis_seq = CHECK_JUST(maybe_axis_seq); ASSERT_TRUE(*axis_seq == (std::vector{2, 1, 0})); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/framework/placement_utils.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/just.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/placement_utils.h" #include "oneflow/core/framework/parallel_conf_util.h" namespace oneflow { Maybe> ReplacePlacementDeviceTag(Symbol parallel_desc, const std::string& device_type) { ParallelConf parallel_conf = parallel_desc->parallel_conf(); parallel_conf.set_device_tag(device_type); std::shared_ptr out_parallel_desc; JUST(PhysicalRun( [&out_parallel_desc, ¶llel_conf](InstructionsBuilder* builder) -> Maybe { out_parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf)); return Maybe::Ok(); })); return SymbolOf(*out_parallel_desc); } Maybe TouchGlobalTensor(const std::shared_ptr& tensor) { CHECK_OR_RETURN(tensor->is_global()); // NOLINT return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/placement_utils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef _ONEFLOW_CORE_FRAMEWORK_PLACEMENT_UTILS_H_ #define _ONEFLOW_CORE_FRAMEWORK_PLACEMENT_UTILS_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { Maybe> ReplacePlacementDeviceTag(Symbol parallel_desc, const std::string& device_type); Maybe TouchGlobalTensor(const std::shared_ptr& tensor); constexpr auto* CheckMetaConsistency = DECORATE(&TouchGlobalTensor, CheckGlobalTensorMeta); } // namespace oneflow #endif ================================================ FILE: oneflow/core/framework/random_generator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/framework/auto_random_generator.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/platform/include/pthread_fork.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/functional/impl/common.h" namespace oneflow { namespace one { namespace { uint64_t GetNonDeterministicRandom() { std::random_device rd; // limit to 53 bits to ensure unique representation in double auto s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF; return s; } Maybe CPUSynchronize() { if (Singleton::Get() != nullptr) { return vm::CurrentRankSync(); } return Maybe::Ok(); } } // namespace Generator::Generator(const std::shared_ptr& internal) : internal_(internal) {} uint64_t Generator::current_seed() const { return internal_->current_seed(); } void Generator::add_children_generator(Symbol placement, Symbol nd_sbp, const std::shared_ptr& generator) { children_generators_.emplace(std::make_pair(placement, nd_sbp), generator); } const HashMap, Symbol>, std::shared_ptr>& Generator::children_generators() const { return children_generators_; } void Generator::set_current_seed(uint64_t seed) { CHECK_JUST(CPUSynchronize()); internal_->set_current_seed(seed); for (auto pair : children_generators_) { uint64_t rank_seed = seed; if (pair.first.first->parallel_num() > 1) { CHECK_JUST(one::functional::BroadcastSeedToAllRanks(&seed, /*root=*/0)); // NOLINT rank_seed = CHECK_JUST(GetRandomSeedForRank(*(pair.first.first), *(pair.first.second), // NOLINT seed, // NOLINT GlobalProcessCtx::Rank())); // NOLINT } pair.second->set_current_seed(rank_seed); } } uint64_t Generator::seed() { uint64_t seed = GetNonDeterministicRandom(); set_current_seed(seed); return seed; } Maybe> Generator::device() const { return 
Device::New(internal_->device_type_name(), internal_->device_index()); } Maybe Generator::GetState() const { JUST(CPUSynchronize()); int64_t state_size = internal_->GetStateSize(); std::vector state_data(state_size); internal_->GetState(state_size, state_data.data()); const auto& device = JUST(Device::New("cpu")); const auto& state = JUST(functional::Empty(Shape{state_size}, DType::UInt8(), device, /*requires_grad=*/false, /*pin_memory=*/false)); const auto& callback = [&](ep::Stream*, const std::shared_ptr& eager_blob_object) { memcpy(eager_blob_object->mut_dptr(), state_data.data(), state_size); }; JUST(SyncAccessTensorWithTimeOut(state, callback, "mut")); return state; } Maybe Generator::SetState(const std::shared_ptr& state) { const auto& device = JUST(state->device()); if (device->type() != "cpu") { return Error::RuntimeError() << "Generator state should be host tensor."; } if (state->dtype() != DType::UInt8()) { return Error::RuntimeError() << "Generator state should be dtype=flow.uint8"; } size_t state_size = state->shape()->elem_cnt(); std::vector state_data(state_size); const auto& callback = [&](ep::Stream*, const std::shared_ptr& eager_blob_object) { memcpy(state_data.data(), eager_blob_object->dptr(), state_size); }; JUST(SyncAccessTensorWithTimeOut(state, callback, "const")); JUST(CPUSynchronize()); internal_->SetState(state_size, state_data.data()); return Maybe::Ok(); } Maybe DefaultGenerator(const std::string& device, int device_index) { static auto* default_auto_generator = dynamic_cast(JUST(DefaultAutoGenerator())->internal().get()); if (device_index == -1) { device_index = (device == "cpu" ? 0 : GlobalProcessCtx::LocalRank()); } return std::make_shared( JUST(default_auto_generator->GetOrCreate(device, device_index))); } Maybe DefaultAutoGenerator() { // Skip destructing to avoid calling symbols in other dynamic libraries when the global object is // released. 
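// Note: the shared_ptr below is constructed with a no-op deleter, so the AutoGenerator is
// intentionally leaked; it stays alive for the whole process and its destructor never runs.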
static auto default_auto_generator = std::make_shared(std::shared_ptr( new AutoGenerator(GetNonDeterministicRandom()), [](AutoGenerator*) {})); return default_auto_generator; } Maybe DefaultCPUGenerator() { static auto default_cpu_generator = JUST(DefaultGenerator("cpu", 0)); return default_cpu_generator; } Maybe DefaultCUDAGenerator(int device_index) { #ifdef WITH_CUDA static int device_count = GetCudaDeviceCount(); #else static int device_count = 0; #endif // WITH_CUDA static std::vector init_flags(device_count); static std::vector> default_cuda_generator(device_count); if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); } CHECK_OR_RETURN(device_index >= 0 && device_index < device_count) << "Invalid device index " << device_index; std::call_once(init_flags[device_index], [&]() { default_cuda_generator[device_index] = CHECK_JUST(DefaultGenerator("cuda", device_index)); }); return default_cuda_generator.at(device_index); } Maybe MakeAutoGenerator() { return std::make_shared(std::make_shared(default_rng_seed_val)); } Maybe MakeCPUGenerator() { static auto device_mgr = Singleton::Get()->GetDeviceManager(DeviceType::kCPU); return std::make_shared(device_mgr->CreateRandomGenerator(default_rng_seed_val, 0)); } Maybe MakeCUDAGenerator(int device_index) { static auto device_mgr = Singleton::Get()->GetDeviceManager(DeviceType::kCUDA); if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); } return std::make_shared( device_mgr->CreateRandomGenerator(default_rng_seed_val, device_index)); } Maybe ManualSeedAllCudaGenerator(uint64_t seed) { #ifdef WITH_CUDA static int device_count = GetCudaDeviceCount(); FOR_RANGE(int, device_id, 0, device_count) { const auto& cuda_gen = JUST(DefaultCUDAGenerator(device_id)); cuda_gen->set_current_seed(seed); } #endif // WITH_CUDA return Maybe::Ok(); } Maybe MakeGenerator(const std::string& device, int device_index) { if (device == "auto") { return std::make_shared(std::make_shared(default_rng_seed_val)); } auto device_type = ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device); if (device_type == DeviceType::kInvalidDevice) { return Error::RuntimeError() << "Expected one of " << PrintGeneratorAvailableDevices() << " device type at start of device string: " << device; } auto device_mgr = Singleton::Get()->GetDeviceManager(device_type); if (device_index == -1) { device_index = (device == "cpu" ? 0 : GlobalProcessCtx::LocalRank()); } return std::make_shared( device_mgr->CreateRandomGenerator(default_rng_seed_val, device_index)); } Maybe DefaultGenerator(DeviceType device, int device_index) { return DefaultGenerator(*JUST(DeviceTag4DeviceType(device)), device_index); } Maybe MakeGenerator(DeviceType device, int device_index) { return MakeGenerator(*JUST(DeviceTag4DeviceType(device)), device_index); } Maybe ManualSeed(uint64_t seed) { const auto& default_auto_generator = JUST(DefaultAutoGenerator()); default_auto_generator->set_current_seed(seed); return default_auto_generator; } Maybe ManualSeed(uint64_t seed, const std::string& device, int device_index) { JUST(DefaultGenerator(device, device_index))->set_current_seed(seed); return Maybe::Ok(); } Maybe ManualSeed(uint64_t seed, DeviceType device, int device_index) { return ManualSeed(seed, *JUST(DeviceTag4DeviceType(device)), device_index); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/random_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_ #define ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_ #include #include "oneflow/core/ep/include/random_generator.h" #include "oneflow/core/framework/auto_random_generator.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/ep/cpu/cpu_random_generator.h" #include "oneflow/core/ep/cuda/cuda_random_generator.h" #include "oneflow/core/common/hash_container.h" namespace oneflow { class NdSbp; namespace one { // The default seed is selected to be a large number // with good distribution of 0s and 1s in bit representation. static constexpr uint64_t default_rng_seed_val = 67280421310721; class Tensor; class Generator final { public: explicit Generator(const std::shared_ptr& internal); ~Generator() = default; void set_current_seed(uint64_t seed); uint64_t current_seed() const; void add_children_generator(Symbol placement, Symbol nd_sbp, const std::shared_ptr& generator); const HashMap, Symbol>, std::shared_ptr>& children_generators() const; // Reset current generator by a non-deterministic random seed, and returns it. uint64_t seed(); Maybe> device() const; Maybe GetState() const; Maybe SetState(const std::shared_ptr& state); const std::shared_ptr& internal() const { return internal_; } template Maybe Get(int device_index = -1) const { if (auto* internal = dynamic_cast(internal_.get())) { return internal->GetOrCreate(device_index); } auto internal = std::dynamic_pointer_cast(internal_); CHECK_NOTNULL_OR_RETURN(internal); if (device_index != -1) { CHECK_EQ_OR_RETURN(device_index, internal->device_index()) << "Invalid device index " << device_index << " since the generator's device index is " << internal->device_index(); } return internal; } private: mutable std::mutex mutex_; std::shared_ptr internal_; // children generator for eager global mode HashMap, Symbol>, // NOLINT std::shared_ptr> // NOLINT children_generators_; // NOLINT }; Maybe MakeGenerator(const std::string& device, int device_index = -1); Maybe MakeGenerator(DeviceType device, int device_index = -1); Maybe MakeAutoGenerator(); Maybe MakeCPUGenerator(); Maybe MakeCUDAGenerator(); Maybe DefaultAutoGenerator(); Maybe DefaultCPUGenerator(); Maybe DefaultCUDAGenerator(int device_index = -1); Maybe DefaultGenerator(const std::string& device, int device_index = -1); Maybe DefaultGenerator(DeviceType device, int device_index = -1); Maybe ManualSeed(uint64_t seed); Maybe ManualSeed(uint64_t seed, const std::string& device, int device_index = -1); Maybe ManualSeed(uint64_t seed, DeviceType device, int device_index = -1); Maybe ManualSeedAllCudaGenerator(uint64_t seed); } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_ ================================================ FILE: oneflow/core/framework/rank_group_rpc_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/framework/rank_group_rpc_util.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/job/rank_group_scope.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/thread/thread_global_id.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { Maybe CheckTransportToken(Symbol rank_group) { const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeCheckRankGroupConsistency)); const auto& PrepareBuffer = [](void** buffer, std::size_t* size, std::function* Callback) -> Maybe { const auto& placeholder = std::make_shared(); *buffer = placeholder.get(); *size = sizeof(uint32_t); *Callback = [placeholder]() {}; return Maybe::Ok(); }; const auto& ctx = std::make_shared(transport_token, PrepareBuffer, PrepareBuffer); JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get())); JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get())); return ctx; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/rank_group_rpc_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_ #include "oneflow/core/framework/transport_token.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/job/rank_group.h" namespace oneflow { Maybe CheckTransportToken(Symbol rank_group); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_ ================================================ FILE: oneflow/core/framework/saved_tensor_hooks.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_SAVED_TENSOR_HOOKS_H_ #define ONEFLOW_CORE_FRAMEWORK_SAVED_TENSOR_HOOKS_H_ #include "oneflow/core/framework/tensor.h" namespace oneflow { namespace one { class SavedTensorHook { public: virtual ~SavedTensorHook() = default; virtual void pack(const std::shared_ptr& tensor) = 0; virtual std::shared_ptr unpack() = 0; }; class SavedTensorHookCreator { public: virtual ~SavedTensorHookCreator() = default; virtual std::unique_ptr new_saved_tensor_hook() const = 0; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SAVED_TENSOR_HOOKS_H_ ================================================ FILE: oneflow/core/framework/sbp_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/sbp_context.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/nd_sbp.h" namespace oneflow { namespace { inline void SplitImpl(SbpSignature* sbp_sign, const std::string& bn, int64_t axis) { (*sbp_sign->mutable_bn_in_op2sbp_parallel())[bn].mutable_split_parallel()->set_axis(axis); } inline void BroadcastImpl(SbpSignature* sbp_sign, const std::string& bn) { (*sbp_sign->mutable_bn_in_op2sbp_parallel())[bn].mutable_broadcast_parallel(); } inline void PartialSumImpl(SbpSignature* sbp_sign, const std::string& bn) { (*sbp_sign->mutable_bn_in_op2sbp_parallel())[bn].mutable_partial_sum_parallel(); } } // namespace namespace user_op { UserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Split(const OpArg& op_arg, int64_t axis) { SplitImpl(&sbp_sig_tmp_, GenRepeatedBn(op_arg.name(), op_arg.index()), axis); return *this; } UserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Split(const std::vector& op_args, int64_t axis) { for (const auto& op_arg : op_args) { Split(op_arg, axis); } return *this; } UserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Split( const std::vector>& args, int64_t axis) { for (const auto& pair : args) { SplitImpl(&sbp_sig_tmp_, GenRepeatedBn(pair.first, pair.second), axis); } return *this; } UserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Broadcast(const OpArg& op_arg) { BroadcastImpl(&sbp_sig_tmp_, GenRepeatedBn(op_arg.name(), op_arg.index())); return *this; } UserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Broadcast(const std::vector& op_args) { for (const auto& op_arg : op_args) { Broadcast(op_arg); } return *this; } UserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Broadcast( const std::vector>& op_args) { for (const auto& pair : op_args) { BroadcastImpl(&sbp_sig_tmp_, GenRepeatedBn(pair.first, pair.second)); } return *this; } UserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::PartialSum(const OpArg& op_arg) { PartialSumImpl(&sbp_sig_tmp_, GenRepeatedBn(op_arg.name(), op_arg.index())); return *this; } UserOpSbpSignatureBuilder& 
UserOpSbpSignatureBuilder::PartialSum( const std::vector& op_args) { for (const auto& op_arg : op_args) { PartialSum(op_arg); } return *this; } UserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::PartialSum( const std::vector>& op_args) { for (const auto& pair : op_args) { PartialSumImpl(&sbp_sig_tmp_, GenRepeatedBn(pair.first, pair.second)); } return *this; } Maybe GetSbpFnUtil::DefaultBroadcastToBroadcast(SbpContext* ctx) { return Maybe::Ok(); } Maybe GetSbpFnUtil::SplitForEachAxis(SbpContext* ctx) { const auto& inputs = ctx->inputs(); CHECK_GE_OR_RETURN(inputs.size(), 1) << "At least one input for op GetSbpFnUtil::SplitForEachAxis"; int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex(inputs.at(0).first, inputs.at(0).second) .shape() .NumAxes(); for (const auto& pair : inputs) { CHECK_EQ( num_axes, ctx->LogicalTensorDesc4InputArgNameAndIndex(pair.first, pair.second).shape().NumAxes()); } for (int64_t axis = 0; axis < num_axes; ++axis) { ctx->NewBuilder().Split(inputs, axis).Split(ctx->outputs(), axis).Build(); } return Maybe::Ok(); } Maybe InferNdSbp4SrcOp(user_op::InferNdSbpFnContext* ctx, const SbpParallel& default_sbp) { const Shape& hierarchy = ctx->parallel_hierarchy(); const auto& sbp_str_list = ctx->user_op_conf().attr>("nd_sbp"); // src op may have tick inputs whose sbp should be broadcast for (const auto& input_arg : ctx->inputs()) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(input_arg.first, input_arg.second); FOR_RANGE(int, i, 0, hierarchy.NumAxes()) { input_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); } } for (const auto& output_arg : ctx->outputs()) { NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(output_arg.first, output_arg.second); size_t nd_sbp_size = sbp_str_list.size(); if (nd_sbp_size == 0) { nd_sbp_size = hierarchy.NumAxes(); } else { CHECK_EQ_OR_RETURN(nd_sbp_size, hierarchy.NumAxes()); } FOR_RANGE(size_t, i, 0, nd_sbp_size) { SbpParallel* sbp = output_nd_sbp->add_sbp_parallel(); if (sbp_str_list.size() == 0) { *sbp = default_sbp; } else { CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str_list[i], sbp)); } CHECK_OR_RETURN(sbp->has_split_parallel() || sbp->has_broadcast_parallel()); } } return Maybe::Ok(); } Maybe SetSrcOpNdSbp(const NdSbpSignature& nd_sbp_sig, const std::string& blob_name, OperatorConf* op_conf) { CHECK_OR_RETURN(nd_sbp_sig.bn_in_op2nd_sbp().find(blob_name) != nd_sbp_sig.bn_in_op2nd_sbp().end()) << "blob `" << blob_name << "` can't found in NdSBP signature: " << nd_sbp_sig.DebugString(); const auto& nd_sbp = nd_sbp_sig.bn_in_op2nd_sbp().at(blob_name); std::vector nd_sbp_str_list = *JUST(GetNdSbpStrList(nd_sbp)); CHECK_OR_RETURN(op_conf->has_user_conf()) << "user_op::SetSrcOpNdSbp function only used to set user op conf"; CHECK_OR_RETURN(op_conf->user_conf().attr().find("nd_sbp") != op_conf->user_conf().attr().end()) << op_conf->name() << " has no attr named `nd_sbp`"; *op_conf->mutable_user_conf() ->mutable_attr() ->at("nd_sbp") .mutable_at_list_string() ->mutable_val() = {nd_sbp_str_list.begin(), nd_sbp_str_list.end()}; return Maybe::Ok(); } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/sbp_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_SBP_CONTEXT_H_ #define ONEFLOW_CORE_FRAMEWORK_SBP_CONTEXT_H_ #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/framework/infer_nd_sbp_fn_context.h" namespace oneflow { namespace user_op { class TensorDesc; class UserOpSbpSignatureBuilder final { public: UserOpSbpSignatureBuilder(SbpSignatureList* sbp_sig_list) : sbp_sig_list_(sbp_sig_list) {} UserOpSbpSignatureBuilder& Split(const OpArg& op_arg, int64_t axis); UserOpSbpSignatureBuilder& Split(const std::vector& op_args, int64_t axis); UserOpSbpSignatureBuilder& Split(const std::vector>& op_args, int64_t axis); UserOpSbpSignatureBuilder& Broadcast(const OpArg& op_arg); UserOpSbpSignatureBuilder& Broadcast(const std::vector& op_args); UserOpSbpSignatureBuilder& Broadcast(const std::vector>& op_args); UserOpSbpSignatureBuilder& PartialSum(const OpArg& op_arg); UserOpSbpSignatureBuilder& PartialSum(const std::vector& op_args); UserOpSbpSignatureBuilder& PartialSum( const std::vector>& op_args); void Build() { *(sbp_sig_list_->mutable_sbp_signature()->Add()) = sbp_sig_tmp_; } private: SbpSignatureList* sbp_sig_list_; SbpSignature sbp_sig_tmp_; }; class SbpContextBase { public: SbpContextBase() = default; virtual ~SbpContextBase() = default; virtual const TensorDesc& LogicalTensorDesc4InputArgNameAndIndex( const std::string& input_arg_name, int32_t index) const = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; virtual DeviceType device_type() const = 0; virtual int64_t parallel_num() const = 0; template T Attr(const std::string& attr_name) const { return user_op_conf().attr(attr_name); } virtual const UserOpConfWrapper& user_op_conf() const = 0; }; class SbpContext : public SbpContextBase { public: SbpContext() = default; ~SbpContext() override = default; // hierarchy value is the value at the dimension corresponding to the current SBP // For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4] // Suppose we have nd_sbp = (S0, B) // The hierarchy value corresponding to S0 is 2 // The hierarchy value corresponding to B is 4. 
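// For example (a hedged usage sketch, not taken from a specific op): a GetSbpFn might consult
// hierarchy_value() before registering a Split signature, e.g. requiring
// shape.At(axis) % ctx->hierarchy_value() == 0, or skip splitting when hierarchy_value() == 1.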
virtual int64_t hierarchy_value() const = 0; virtual UserOpSbpSignatureBuilder NewBuilder() = 0; }; class InferSbpSignatureFnContext : public SbpContextBase { public: InferSbpSignatureFnContext() = default; ~InferSbpSignatureFnContext() override = default; virtual SbpSignature* mutable_sbp_signature() = 0; virtual const SbpSignature& sbp_signature_conf() const = 0; virtual const SbpParallel& SbpParallelHint4InputArgNameAndIndex(const std::string& input_arg_name, int32_t index) const = 0; }; struct GetSbpFnUtil { static Maybe DefaultBroadcastToBroadcast(SbpContext*); static Maybe SplitForEachAxis(SbpContext*); }; Maybe InferNdSbp4SrcOp(user_op::InferNdSbpFnContext* ctx, const SbpParallel& default_sbp); Maybe SetSrcOpNdSbp(const NdSbpSignature& nd_sbp_sig, const std::string& blob_name, OperatorConf* op_conf); } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SBP_CONTEXT_H_ ================================================ FILE: oneflow/core/framework/sbp_infer_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/auto_parallel/algorithm_util.h" #include "oneflow/core/auto_parallel/boxing_collector.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/register/blob_desc.h" namespace oneflow { namespace { static const double kUnsupportedBoxing = GetMaxVal(); // check whether the sbp_parallel is legal bool CheckSbpParallel(const SbpParallel& sbp_parallel) { return sbp_parallel.has_split_parallel() || sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel(); } // check whether the nd_sbp is legal bool CheckNdSbp(const NdSbp& nd_sbp) { if (nd_sbp.sbp_parallel_size() <= 0) { return false; } for (const auto& sbp : nd_sbp.sbp_parallel()) { if (!CheckSbpParallel(sbp)) { return false; } } return true; } double Penalty4PartialInConsumer(double logical_blob_size, int32_t producer_parallel_num, int32_t consumer_parallel_num) { static const int64_t penalty4partial_in_consumer_tag = ParseIntegerFromEnv("ONEFLOW_PENALTY_FOR_PARTIAL_IN_CONSUMER_POLICY", 2); if (penalty4partial_in_consumer_tag == Penalty4PartialInConsumerTag::kSlight) { return 1.0; } else if (penalty4partial_in_consumer_tag == Penalty4PartialInConsumerTag::kMiddle) { return 4 * logical_blob_size * (producer_parallel_num + consumer_parallel_num); } else { return kUnsupportedBoxing; } } int32_t Ratio4Sbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const std::function& classifier) { int32_t ratio = 1; for 
(int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) { if (classifier(nd_sbp.sbp_parallel(sbp_id))) { ratio *= parallel_desc.hierarchy()->At(sbp_id); } } return ratio; } Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sbp_parallel, const SbpParallel& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, bool on_same_devices, int32_t producer_parallel_num, int32_t consumer_parallel_num) { if (!(CheckSbpParallel(producer_sbp_parallel) && CheckSbpParallel(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } // Not supporting S->P for lazy boxing now. if (LazyMode::is_enabled()) { if (consumer_sbp_parallel.has_partial_sum_parallel() && producer_sbp_parallel.has_split_parallel()) { return kUnsupportedBoxing; } } // NOTE: A tensor placed on cpu with a consumer operator that accepts cuda inputs would be // transferred to cuda later. We might not have correct parallel description at this moment. if (on_same_devices && producer_parallel_num == consumer_parallel_num) { // Same sbp, no cost: S->S, B->B, P->P if (producer_sbp_parallel == consumer_sbp_parallel) { return 0.0; } double logical_blob_size = TotalByteSize4BlobDesc(logical_blob_desc); // S->P for eager. It should be 0 as well. // NOTE: Similar to B->P, we just make the other part to be 0. You can consider P as S(i) for an // arbitrary i. // ? -> P if (consumer_sbp_parallel.has_partial_sum_parallel()) { return Penalty4PartialInConsumer(logical_blob_size, producer_parallel_num, consumer_parallel_num); } // B->S if (producer_sbp_parallel.has_broadcast_parallel()) { return 1.0; } // has S if (consumer_sbp_parallel.has_split_parallel() || producer_sbp_parallel.has_split_parallel()) { if (consumer_sbp_parallel.has_split_parallel() && producer_sbp_parallel.has_split_parallel()) { // S(0)->S(1), S(1)->S(0), etc. return logical_blob_size * (producer_parallel_num - 1) / producer_parallel_num; } else { // P->S, S->B/P return logical_blob_size * (producer_parallel_num - 1); } } // P->B return 2 * logical_blob_size * (producer_parallel_num - 1); } else { // Not supporting P->P for different placement if (LazyMode::is_enabled()) { if (consumer_sbp_parallel.has_partial_sum_parallel() && producer_sbp_parallel.has_partial_sum_parallel()) { return kUnsupportedBoxing; } } double logical_blob_size = TotalByteSize4BlobDesc(logical_blob_desc); double overall_cost = logical_blob_size; // ? -> B if (consumer_sbp_parallel.has_broadcast_parallel()) { overall_cost += (consumer_parallel_num - 1) * logical_blob_size; } // P -> ? if (producer_sbp_parallel.has_partial_sum_parallel()) { overall_cost += (producer_parallel_num - 1) * logical_blob_size; } // ? -> P if (consumer_sbp_parallel.has_partial_sum_parallel()) { overall_cost += Penalty4PartialInConsumer(logical_blob_size, producer_parallel_num, consumer_parallel_num); } // For B->S, S->S, overall_cost == logical_blob_size; return overall_cost; } } // compute copy cost for two SBPs. // They may be either different or on different devices. double ComputCopyCostBetweenTwoDiffSbpParallel(const SbpParallel& producer_sbp_parallel, const SbpParallel& consumer_sbp_parallel, double logical_blob_size, double parallel_num, bool on_same_devices) { // Not supporting S->P for now. 
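// kUnsupportedBoxing is a very large sentinel cost, so an S->P transfer is effectively
// treated as impossible by callers that compare these costs.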
if (consumer_sbp_parallel.has_partial_sum_parallel() && producer_sbp_parallel.has_split_parallel()) { return kUnsupportedBoxing; } if (on_same_devices) { // B->P if (consumer_sbp_parallel.has_partial_sum_parallel()) { return Penalty4PartialInConsumer(logical_blob_size, parallel_num, parallel_num); } // B->S if (producer_sbp_parallel.has_broadcast_parallel()) { return 1; } // has S if (consumer_sbp_parallel.has_split_parallel() || producer_sbp_parallel.has_split_parallel()) { if (consumer_sbp_parallel.has_split_parallel() && producer_sbp_parallel.has_split_parallel()) { // S(0)->S(1), S(1)->S(0), etc. return logical_blob_size * (parallel_num - 1) / parallel_num; } else { // P->S, S->B return logical_blob_size * (parallel_num - 1); } } // P->B (= P->S + S->B) return 2 * logical_blob_size * (parallel_num - 1); } else { // They have the same hierarchy at the transfer dimension. double overall_cost = logical_blob_size; // ? -> B if (consumer_sbp_parallel.has_broadcast_parallel()) { overall_cost += logical_blob_size * (parallel_num - 1); } // P -> ? if (producer_sbp_parallel.has_partial_sum_parallel()) { overall_cost += logical_blob_size * (parallel_num - 1); } if (consumer_sbp_parallel.has_partial_sum_parallel()) { overall_cost += Penalty4PartialInConsumer(logical_blob_size, parallel_num, parallel_num); } // For B->P, B->S, S->S, overall_cost == logical_blob_size; return overall_cost; } } Maybe ComputCopyCostBetweenTwoNdSbp(const NdSbp& producer_nd_sbp, const NdSbp& consumer_nd_sbp, double logical_blob_size, const Shape& hierarchy, bool on_same_devices) { if (hierarchy.NumAxes() != 2) { return kUnsupportedBoxing; } const auto& producer_sbp_size = producer_nd_sbp.sbp_parallel_size(); const auto& consumer_sbp_size = consumer_nd_sbp.sbp_parallel_size(); // One of the SBP should have size 2 CHECK_OR_RETURN((producer_sbp_size == 1 && consumer_sbp_size == 2) || (producer_sbp_size == 2 && consumer_sbp_size == 1) || (producer_sbp_size == 2 && consumer_sbp_size == 2)) << "Not supporting such boxing type. Check if we have bugs in auto parallel."; for (int32_t dim_same_sbp = 0; dim_same_sbp < 2; dim_same_sbp++) { // If the nd_sbp only have size 1, then make its dimension 0 int32_t dim_producer = dim_same_sbp; if (producer_sbp_size == 1) { dim_producer = 0; } int32_t dim_consumer = dim_same_sbp; if (consumer_sbp_size == 1) { dim_consumer = 0; } // The SBP parallel are the same at dimension (dim_same_sbp) if (producer_nd_sbp.sbp_parallel(dim_producer) == consumer_nd_sbp.sbp_parallel(dim_consumer)) { if (!producer_nd_sbp.sbp_parallel(dim_producer).has_split_parallel()) { logical_blob_size *= hierarchy.At(dim_same_sbp); } // The SBP parallel are different at dimension (dim_diff_sbp) int32_t dim_diff_sbp = 1 - dim_same_sbp; // If the nd_sbp only have size 1, then make its dimension 0. // Since we have already do this before, we just maintain the value. // Otherwise, switch the dimension to dim_diff_sbp if (producer_sbp_size == 2) { dim_producer = dim_diff_sbp; } if (consumer_sbp_size == 2) { dim_consumer = dim_diff_sbp; } // Spliting at the same dimension needs special cares! 
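// Example: hierarchy [2, 2], producer (S0, S0), consumer (S1, S0). The sbps agree on
// dimension 1 and differ on dimension 0, but the producer splits the same tensor axis on
// every hierarchy dimension, so the per-dimension transfer below cannot express the
// required regrouping and the cost is reported as unsupported.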
// Not supported by nccl if (dim_diff_sbp == 0 && producer_nd_sbp.sbp_parallel(dim_producer) != consumer_nd_sbp.sbp_parallel(dim_consumer) && (NdSbpAllSameSplitParallel(producer_nd_sbp) || NdSbpAllSameSplitParallel(consumer_nd_sbp))) { return kUnsupportedBoxing; } return ComputCopyCostBetweenTwoDiffSbpParallel( producer_nd_sbp.sbp_parallel(dim_producer), consumer_nd_sbp.sbp_parallel(dim_consumer), logical_blob_size, hierarchy.At(dim_diff_sbp), on_same_devices); } } return kUnsupportedBoxing; } Maybe ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp) { if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } // TODO: get copy cost from each EagerBoxingInterpreter if (!TRY(Singleton::Get()->GetEagerBoxingInterpreter( producer_sbp_parallel, consumer_sbp_parallel, producer_parallel_desc, consumer_parallel_desc, logical_blob_desc.shape())) .IsOk()) { return kUnsupportedBoxing; } bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); // Reduce before cost computation Shape reduced_in_hierarchy; NdSbp reduced_in_nd_sbp; Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy, &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape()); bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp; // Same sbp is always supported. if (same_nd_sbp && on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { return 0.0; } if (requires_same_sbp) { return kUnsupportedBoxing; } int32_t in_dim = reduced_in_hierarchy.NumAxes(); int32_t out_dim = reduced_out_hierarchy.NumAxes(); // We support different hierarchy for 1D sbp if (in_dim == 1 && out_dim == 1) { return ComputCopyCostBetweenTwoSbpParallel( reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0), logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(), reduced_out_hierarchy.elem_cnt()); } double total_cost = 1.0; if (on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { // NOTE: After analysis, transfer cost increase if spliting the same dimension. // Example 1: (S(1), S(0), S(1), S(0)) -> (S(0), S(0), S(0), S(0)) // Example 2: (B, S(0)) -> (S(0), S(0)) // The cost would be (1-1/n)T, where n is the product of hierarchy number in those splitting // dimensions. To give a more precise cost, we add a upper bound of those lost cost back for // simplification. bool normal_case = true; // nd to nd for (int32_t i = 0; i < in_dim; ++i) { const auto& in_sbp = reduced_in_nd_sbp.sbp_parallel(i); const auto& out_sbp = reduced_out_nd_sbp.sbp_parallel(i); // Have bugs here. (B, S0) -> (S0, S0) will give a cost 0. // Actually it is (1-1/m)T for hierarchy (n, m) // TODO: Fix that after support all sbp combination for eager. 
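// The n-D cost is approximated below by summing the 1D cost of each hierarchy dimension
// independently.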
total_cost += JUST(ComputCopyCostBetweenTwoSbpParallel( in_sbp, out_sbp, logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(), reduced_out_hierarchy.elem_cnt())); // Add the penalty for P in the consumer if (out_sbp.has_partial_sum_parallel() && (in_sbp != out_sbp)) { total_cost += Penalty4PartialInConsumer(TotalByteSize4BlobDesc(logical_blob_desc), producer_parallel_desc.parallel_num(), consumer_parallel_desc.parallel_num()); } // detect the cases that splits the same dimension before this splitting if (normal_case && in_sbp.has_split_parallel() && in_sbp == out_sbp) { for (int32_t j = 0; j < i; j++) { const auto& in_sbp_j = reduced_in_nd_sbp.sbp_parallel(j); const auto& out_sbp_j = reduced_out_nd_sbp.sbp_parallel(j); // in_sbp == out_sbp in this situation if ((in_sbp_j != out_sbp_j) && (in_sbp_j == in_sbp || out_sbp_j == in_sbp)) { normal_case = false; break; } } } } // Add the cost for the special case if (!normal_case) { total_cost += TotalByteSize4BlobDesc(logical_blob_desc); } } else { double logical_blob_size = TotalByteSize4BlobDesc(logical_blob_desc); { double in_cost = 1.0; for (int32_t i = 0; i < in_dim; ++i) { // P -> ? if (reduced_in_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { in_cost *= reduced_in_hierarchy.At(i); } } total_cost += logical_blob_size * in_cost; } { double out_cost = 1.0; for (int32_t i = 0; i < out_dim; ++i) { // ? -> B if (reduced_out_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { out_cost *= reduced_out_hierarchy.At(i); } // Add the penalty for P in the consumer if (reduced_out_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { total_cost += Penalty4PartialInConsumer(logical_blob_size, producer_parallel_desc.parallel_num(), consumer_parallel_desc.parallel_num()); } } total_cost += logical_blob_size * out_cost; } } return total_cost; } using CopyCostFunc = Maybe(const NdSbp&, const NdSbp&, const BlobDesc&, const ParallelDesc&, const ParallelDesc&, bool); Maybe GetComputeCopyCostFunc() { if (LazyMode::is_enabled()) { return &ComputeCopyCostWithMiddleNodes; } else { return &ComputeEagerCopyCostBetweenNdSbp; } } // Replace the hierarchy and then create a new parallel description void ReplaceHierarchy4ParallelDesc(const ParallelDesc& old_parallel_desc, const Shape& new_hierarchy, ParallelDesc* new_parallel_desc) { if (*old_parallel_desc.hierarchy() == new_hierarchy) { *new_parallel_desc = old_parallel_desc; } else { ParallelConf new_parallel_conf = old_parallel_desc.parallel_conf(); new_hierarchy.ToProto(new_parallel_conf.mutable_hierarchy()); *new_parallel_desc = ParallelDesc(new_parallel_conf); } } // We can not just simply merging two same split // For example, shape = [6], we are trying to merge [2, 2]: (S0, S0) -> [4]: S0 // For each rank, [4]: S0 has number of data: 2, 2, 1, 1 // For each rank, [2]: S0 has number of data: 3, 3 // For each rank, [2, 2]: (S0, S0) has number of data: 2, 1, 2, 1 // Thus {[2, 2]: (S0, S0)} != {[4]: S0} for shape [6] // However {[2, 2]: (S0, S0)} == {[4]: S0} for shape [4], [5], [7], [8] // More specifically, {[a, b]: (Si, Si)} == {[a*b]: Si} if and only if // shape value % (a * b) == 0, 1, a*b - 1 bool CanMergeSplit(int32_t shape_value, int32_t merged_split_hierarchy_value) { int32_t remainder = shape_value % merged_split_hierarchy_value; if (remainder <= 1 || remainder == merged_split_hierarchy_value - 1) { return true; } else { return false; } } } // namespace int32_t PartialRatio4Producer(const NdSbp& sbp_producer, const ParallelDesc& producer_parallel_desc) { return Ratio4Sbp(sbp_producer, 
producer_parallel_desc, &SbpParallel::has_partial_sum_parallel); } int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer, const ParallelDesc& consumer_parallel_desc) { return Ratio4Sbp(sbp_consumer, consumer_parallel_desc, &SbpParallel::has_broadcast_parallel); } void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy, NdSbp* reduced_nd_sbp, const Shape& logical_shape) { NdSbpsDimReduce(hierarchy, {&nd_sbp}, reduced_hierarchy, {reduced_nd_sbp}, logical_shape); } void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd_sbps, Shape* reduced_hierarchy, const std::vector& reduced_nd_sbps, const Shape& logical_shape) { int32_t sbp_num = nd_sbps.size(); // Speed up for 1d sbp if (hierarchy.NumAxes() == 1) { *reduced_hierarchy = hierarchy; for (int32_t index = 0; index < sbp_num; index++) { if (hierarchy.elem_cnt() == 1) { reduced_nd_sbps[index]->add_sbp_parallel()->mutable_broadcast_parallel(); } else { *reduced_nd_sbps[index] = *nd_sbps[index]; } } return; } reduced_hierarchy->clear(); for (auto& reduced_nd_sbp : reduced_nd_sbps) { reduced_nd_sbp->clear_sbp_parallel(); } // At this moment, if we have [2, 4, 3, 7]: (S0, S1, S0, S0) for logical shape [601, 301, 999] // We hold the split when accessing the current dimension // Do the true splitting until we reach the next step // dim = 0, split_axis2holding_reduced_shapes: {(0: 601)}, last split axis = -1 // dim = 1, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 301)}, last split axis = 0 // dim = 2, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 75, 76)}, last split axis = 1 // dim = 3, at this moment, last split axis (0) == current split axis (0), // dim = 3, but judging 300 % (3 * 7) = 6 fails the CanMergeSplit(), not merging // dim = 3, split_axis2holding_reduced_shapes: {(0: 100, 101), (1: 75, 76)}, last split axis = 0 std::vector>> index2split_axis2holding_reduced_shapes(sbp_num); std::vector> index2last_holding_reduced_shapes(sbp_num); std::vector last_split_axises(sbp_num, -1); std::vector indexes(sbp_num); for (int32_t index = 0; index < sbp_num; index++) { indexes[index] = index; } auto add_to_reduced_sbp_hierarchy = [&](int32_t hierarchy_dim) { // Clear the last holding split axis for (int32_t index = 0; index < sbp_num; index++) { auto& split_axis2holding_reduced_shapes = index2split_axis2holding_reduced_shapes[index]; auto& last_holding_reduced_shapes = index2last_holding_reduced_shapes[index]; auto& last_split_axis = last_split_axises[index]; auto& nd_sbp = nd_sbps[index]; auto& reduced_nd_sbp = reduced_nd_sbps[index]; if (last_split_axis >= 0) { auto& holding_reduced_shapes = split_axis2holding_reduced_shapes[last_split_axis]; holding_reduced_shapes.clear(); for (int32_t last_holding_reduced_shape : last_holding_reduced_shapes) { int32_t quotient = last_holding_reduced_shape / reduced_hierarchy->back(); if (last_holding_reduced_shape % reduced_hierarchy->back() != 0) { holding_reduced_shapes.insert(quotient + 1); } holding_reduced_shapes.insert(quotient); } } // Add a new sbp_parallel and a new hierarchy dimension const auto& curr_sbp_parallel = nd_sbp->sbp_parallel(hierarchy_dim); *reduced_nd_sbp->add_sbp_parallel() = curr_sbp_parallel; // Hold the current split shape if (curr_sbp_parallel.has_split_parallel()) { last_holding_reduced_shapes.clear(); last_split_axis = curr_sbp_parallel.split_parallel().axis(); auto it = split_axis2holding_reduced_shapes.find(last_split_axis); if (it == split_axis2holding_reduced_shapes.end()) { // Looking at a dimension which is never 
splitted before // Shape: [601, ...], sbp: (S0, ...) last_holding_reduced_shapes.push_back(logical_shape.At(last_split_axis)); } else { // This dimension is splitted before // Shape: [601, 301, ...], sbp: (S0, S1, B, S0, ...), hierarchy: [2, 3, 100, 7, ...] // Looking at i = 3, we hold the second S0, but 601 is already splitted by the first S0. // split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 100, 101)} last_holding_reduced_shapes.assign(it->second.begin(), it->second.end()); } } else { last_split_axis = -1; } } // Add a new hierarchy dimension reduced_hierarchy->emplace_back(hierarchy.At(hierarchy_dim)); }; for (int32_t hierarchy_dim = 0; hierarchy_dim < hierarchy.NumAxes(); hierarchy_dim++) { // Shrink those dimension with hierarchy value = 1 if (hierarchy.At(hierarchy_dim) == 1) { continue; } if (reduced_hierarchy->empty()) { // Empty hierarchy, add to the back add_to_reduced_sbp_hierarchy(hierarchy_dim); continue; } if (std::all_of(indexes.begin(), indexes.end(), [&](int32_t index) { // reduced_hierarchy->size() == reduced_nd_sbps[index]->sbp_parallel_size() // Basically, current nd sbp == reduced nd sbp.back() return nd_sbps[index]->sbp_parallel(hierarchy_dim) == reduced_nd_sbps[index]->sbp_parallel(reduced_hierarchy->size() - 1); })) { int32_t merged_hierarchy_value = reduced_hierarchy->back() * hierarchy.At(hierarchy_dim); // You can merge two sbp with B or P. // If sbp = S, then you need to make sure that all the shape value can be splitted if (std::all_of(indexes.begin(), indexes.end(), [&](int32_t index) { return !nd_sbps[index]->sbp_parallel(hierarchy_dim).has_split_parallel() || std::all_of(index2last_holding_reduced_shapes[index].begin(), index2last_holding_reduced_shapes[index].end(), [&](int32_t i) { return CanMergeSplit(i, merged_hierarchy_value); }); })) { // Merge sbp and hierarchy reduced_hierarchy->back() = merged_hierarchy_value; continue; } } // Can not merge, add to the back add_to_reduced_sbp_hierarchy(hierarchy_dim); } // [1, 1, ..., 1]: Any --> [1]: (B) if (reduced_hierarchy->empty()) { reduced_hierarchy->emplace_back(hierarchy.At(0)); for (auto& reduced_nd_sbp : reduced_nd_sbps) { reduced_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); } } } void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp, const Shape& logical_shape) { // Speed up for 1d sbp if (parallel_desc.hierarchy()->NumAxes() == 1) { *reduced_parallel_desc = parallel_desc; *reduced_nd_sbp = nd_sbp; return; } Shape reduced_hierarchy; NdSbpDimReduce(*parallel_desc.hierarchy(), nd_sbp, &reduced_hierarchy, reduced_nd_sbp, logical_shape); ReplaceHierarchy4ParallelDesc(parallel_desc, reduced_hierarchy, reduced_parallel_desc); } void InOutParallelDimReduce(const Shape& in_hierarchy, const Shape& out_hierarchy, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, Shape* reduced_in_hierarchy, Shape* reduced_out_hierarchy, NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp, const Shape& logical_shape) { if (in_hierarchy == out_hierarchy) { // [2, 4]: (S0, S0) -> [2, 4]: (S0, S1) NdSbpsDimReduce(in_hierarchy, {&in_nd_sbp, &out_nd_sbp}, reduced_in_hierarchy, {reduced_in_nd_sbp, reduced_out_nd_sbp}, logical_shape); *reduced_out_hierarchy = *reduced_in_hierarchy; } else { // [2, 4]: (S0, S0) -> [4, 2]: (S0, S1) // [2, 4]: (S0, S0) -> [3, 3]: (S0, S1) NdSbpDimReduce(in_hierarchy, in_nd_sbp, reduced_in_hierarchy, reduced_in_nd_sbp, logical_shape); NdSbpDimReduce(out_hierarchy, out_nd_sbp, reduced_out_hierarchy, 
reduced_out_nd_sbp, logical_shape); // Sbp of 3d or higher dimension would use general basic communication // Only looks at 1d to 2d or 2d to 1d if (reduced_in_hierarchy->NumAxes() + reduced_out_hierarchy->NumAxes() == 3 && reduced_in_hierarchy->elem_cnt() == reduced_out_hierarchy->elem_cnt()) { if (reduced_in_hierarchy->NumAxes() == 1) { // [8]: S0 -> [4, 2]: (S0, S1) // [8]: B -> [2, 4]: (S0, S1) const auto& in_sbp_parallel = reduced_in_nd_sbp->sbp_parallel(0); if (!in_sbp_parallel.has_split_parallel() || CanMergeSplit(logical_shape.At(in_sbp_parallel.split_parallel().axis()), reduced_in_hierarchy->elem_cnt())) { // Change [8]: S0 -> [4, 2]: (S0, S1) to [4, 2]: (S0, S0) -> [4, 2]: (S0, S1) // Change [8]: B -> [2, 4]: (S0, S1) to [2, 4]: (B, B) -> [2, 4]: (S0, S1) *reduced_in_nd_sbp->add_sbp_parallel() = in_sbp_parallel; *reduced_in_hierarchy = *reduced_out_hierarchy; } } else { // [2, 3]: (S0, P) -> [6]: S0 // [3, 4]: (B, S1) -> [12]: B const auto& out_sbp_parallel = reduced_out_nd_sbp->sbp_parallel(0); if (!out_sbp_parallel.has_split_parallel() || CanMergeSplit(logical_shape.At(out_sbp_parallel.split_parallel().axis()), reduced_out_hierarchy->elem_cnt())) { // Change [2, 3]: (S0, P) -> [6]: S0 to [2, 3]: (S0, P) -> [2, 3]: (S0, S0) // Change [3, 4]: (B, S1) -> [12]: B to [3, 4]: (B, S1) -> [3, 4]: (B, B) *reduced_out_nd_sbp->add_sbp_parallel() = out_sbp_parallel; *reduced_out_hierarchy = *reduced_in_hierarchy; } } } } } void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc, ParallelDesc* reduced_out_parallel_desc, NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp, const Shape& logical_shape) { // Speed up for 1d sbp if (in_parallel_desc.hierarchy()->NumAxes() == 1 && out_parallel_desc.hierarchy()->NumAxes() == 1) { *reduced_in_parallel_desc = in_parallel_desc; *reduced_out_parallel_desc = out_parallel_desc; *reduced_in_nd_sbp = in_nd_sbp; *reduced_out_nd_sbp = out_nd_sbp; } else { Shape reduced_in_hierarchy; Shape reduced_out_hierarchy; InOutParallelDimReduce(*in_parallel_desc.hierarchy(), *out_parallel_desc.hierarchy(), in_nd_sbp, out_nd_sbp, &reduced_in_hierarchy, &reduced_out_hierarchy, reduced_in_nd_sbp, reduced_out_nd_sbp, logical_shape); ReplaceHierarchy4ParallelDesc(in_parallel_desc, reduced_in_hierarchy, reduced_in_parallel_desc); ReplaceHierarchy4ParallelDesc(out_parallel_desc, reduced_out_hierarchy, reduced_out_parallel_desc); } } int64_t TotalByteSize4BlobDesc(const BlobDesc& logical_blob_desc) { return logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); } int64_t MaxByteSize4BlobDescSbp(const BlobDesc& logical_blob_desc, const NdSbp& nd_sbp, const Shape& hierarchy) { Shape blob_shape = logical_blob_desc.shape(); for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) { const auto& sbp = nd_sbp.sbp_parallel(sbp_id); if (sbp.has_split_parallel()) { int32_t split_axis = sbp.split_parallel().axis(); blob_shape.Set(split_axis, CeilQuotient(blob_shape.At(split_axis), hierarchy.At(sbp_id))); } } return blob_shape.elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); } Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp) { if (!(CheckNdSbp(producer_sbp_parallel) && 
CheckNdSbp(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); // Reduce before cost computation Shape reduced_in_hierarchy; NdSbp reduced_in_nd_sbp; Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy, &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape()); int32_t in_dim = reduced_in_hierarchy.NumAxes(); int32_t out_dim = reduced_out_hierarchy.NumAxes(); // Not supporting n-D sbp with n >= 3 // TODO: Support it in the future if (std::min(in_dim, out_dim) <= 0 || std::max(in_dim, out_dim) >= 3) { return kUnsupportedBoxing; } bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp; // Same sbp is always supported. if (same_nd_sbp && on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { return 0.0; } if (requires_same_sbp) { return kUnsupportedBoxing; } // We support different hierarchy for 1D sbp if (in_dim == 1 && out_dim == 1) { return GetTransferCost() + JUST(ComputCopyCostBetweenTwoSbpParallel( reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0), logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(), reduced_out_hierarchy.elem_cnt())); } #ifdef WITH_CUDA static const bool enable_general_basic_communication = ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false); // Use a general basic communication if no P in the consumer if ((((Singleton::Get()->nccl_use_compute_stream() && producer_parallel_desc == consumer_parallel_desc) || enable_general_basic_communication) && !NdSbpHasPartialParallel(consumer_sbp_parallel)) && producer_parallel_desc.device_type() == DeviceType::kCUDA && consumer_parallel_desc.device_type() == DeviceType::kCUDA) { return Cost4GeneralBasicCommunication(producer_sbp_parallel, consumer_sbp_parallel, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc) + GetTransferCost(); } #endif // WITH_CUDA // Not supporting different hierarchy without general basic communication if (reduced_in_hierarchy.elem_cnt() != reduced_out_hierarchy.elem_cnt()) { return kUnsupportedBoxing; } double logical_blob_size = TotalByteSize4BlobDesc(logical_blob_desc); if (in_dim == 2 && out_dim == 2) { // Not supporting different hierarchy // TODO: Support it in the future if (reduced_in_hierarchy != reduced_out_hierarchy) { return kUnsupportedBoxing; } return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, logical_blob_size, reduced_in_hierarchy, on_same_devices)); } // (in_dim == 2 && out_dim == 1) || (in_dim == 1 && out_dim == 2) if (in_dim == 2 && out_dim == 1) { return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, logical_blob_size, reduced_in_hierarchy, on_same_devices)); } if (in_dim == 1 && out_dim == 2) { return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, logical_blob_size, reduced_out_hierarchy, on_same_devices)); } return Error::RuntimeError() << "Should not reach here. 
Something went wrong in ComputCopyCostBetweenNdSbp() in " "sbp_util.cpp."; } double GetValidMaxCopyCost() { // We suppose that valid copy cost range is [0, FloatMax*0.8] static const double kValidMaxCopyCost = kUnsupportedBoxing * 0.8; return kValidMaxCopyCost; } double GetTransferCost() { // Each transfer would have cost. // Except for same parallel description and sbp static const double kTransferCost = ParseFloatFromEnv("AUTO_PARALLEL_TRANSFER_COST", 1.65e4); return kTransferCost; } void ResizeNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t size) { for (auto& pair : *nd_sbp_sig.mutable_bn_in_op2nd_sbp()) { if (pair.second.sbp_parallel_size() > size) { pair.second.clear_sbp_parallel(); } while (pair.second.sbp_parallel_size() < size) { pair.second.add_sbp_parallel(); } } } void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp_signature, int32_t sbp_axis) { for (const auto& pair : sbp_signature.bn_in_op2sbp_parallel()) { *((*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[pair.first].mutable_sbp_parallel(sbp_axis)) = pair.second; } } void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims, const Shape& hierarchy, const HashMap& hierarchy_value2sbp_sig_list, std::vector* nd_sbp_sig_list) { if (depth == dims) { nd_sbp_sig_list->push_back(nd_sbp_sig); } else { for (const auto& sbp_signature : hierarchy_value2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) { SetNdSbpSignature(&nd_sbp_sig, sbp_signature, depth); DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_value2sbp_sig_list, nd_sbp_sig_list); } } } namespace { // give a mesure value for NdSbp for sorting size_t MesureNdSbp(const NdSbp& nd_sbp) { // start from 1, B + P + max split axis (8) constexpr size_t kMaxSplitAxis = 8; constexpr size_t kCarryDigit = kMaxSplitAxis + 3; size_t value = 0; for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { size_t cur_dim_value = 0; const auto& sbp = nd_sbp.sbp_parallel(i); if (sbp.has_broadcast_parallel()) { cur_dim_value = 1; } else if (sbp.has_partial_sum_parallel()) { cur_dim_value = 2; } else if (sbp.has_split_parallel()) { CHECK_LT(sbp.split_parallel().axis(), kMaxSplitAxis); // from 3 to 10 cur_dim_value = 3 + sbp.split_parallel().axis(); } else { UNIMPLEMENTED(); } value = value * kCarryDigit + cur_dim_value; } return value; } size_t MesureNdSbpSignature(const NdSbpSignature& nd_sbp_sig, const std::vector& bns) { // big enough for 2d-sbp signatrue set // if want to extend to 3d-sbp, consider increase to 170 constexpr size_t kCarryDigit = 97; size_t value = 0; for (size_t i = 0; i < bns.size(); ++i) { auto nd_sbp_it = nd_sbp_sig.bn_in_op2nd_sbp().find(bns[i]); CHECK(nd_sbp_it != nd_sbp_sig.bn_in_op2nd_sbp().end()) << "can't find bn (" << bns[i] << ") in " << PbMessage2TxtString(nd_sbp_sig); size_t cur_arg_value = MesureNdSbp(nd_sbp_it->second); CHECK_LE(value + cur_arg_value / kCarryDigit, std::numeric_limits::max() / kCarryDigit); value = value * kCarryDigit + cur_arg_value; } return value; } } // namespace void DeduplicateNdSbpSignatureList(std::vector* nd_sbp_sig_list, const std::vector& bns) { if (bns.size() > 8) { return; } std::map value2nd_sbp_sig; for (auto& nd_sbp_sig : *nd_sbp_sig_list) { size_t order_value = MesureNdSbpSignature(nd_sbp_sig, bns); if (value2nd_sbp_sig.find(order_value) == value2nd_sbp_sig.end()) { value2nd_sbp_sig.emplace(order_value, std::move(nd_sbp_sig)); } } nd_sbp_sig_list->clear(); for (auto& nd_sbp_pair : value2nd_sbp_sig) { 
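// Editorial illustration (not in the original source): the ordering key produced by
// MesureNdSbp above is a base-kCarryDigit (= 8 + 3 = 11) encoding of the per-dimension
// SBP codes (B = 1, P = 2, S(a) = 3 + a). For example, nd_sbp = (S(0), B, P):
//   value = ((0 * 11 + 3) * 11 + 1) * 11 + 2 = 34 * 11 + 2 = 376
// Because every per-dimension code is below the carry digit, the encoding of a single
// NdSbp is injective for split axes < 8, which is what lets value2nd_sbp_sig both
// deduplicate the signature list and return it in a deterministic order.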
nd_sbp_sig_list->emplace_back(std::move(nd_sbp_pair.second)); } } // Compute storage per device for given NdSbp double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy) { if (nd_sbp.sbp_parallel_size() == 1) { double logical_blob_size = logical_shape.elem_cnt(); // Checking 1D sbp const auto& sbp_parallel = nd_sbp.sbp_parallel(0); if (sbp_parallel.has_split_parallel()) { const int64_t axis = sbp_parallel.split_parallel().axis(); if (axis >= logical_shape.NumAxes()) { return kUnsupportedBoxing; } if (logical_shape.At(axis) < parallel_hierarchy.At(0)) { return kUnsupportedBoxing; } logical_blob_size /= parallel_hierarchy.At(0); } return logical_blob_size; } else { for (int32_t dim_sbp = 0; dim_sbp < nd_sbp.sbp_parallel_size(); ++dim_sbp) { const auto& sbp_parallel = nd_sbp.sbp_parallel(dim_sbp); if (sbp_parallel.has_split_parallel()) { // Split axis and store result back to logical shape const int64_t axis = sbp_parallel.split_parallel().axis(); if (axis >= logical_shape.NumAxes()) { return kUnsupportedBoxing; } // Use completely average split to count the storage if (logical_shape.At(axis) < parallel_hierarchy.At(dim_sbp)) { return kUnsupportedBoxing; } logical_shape.Set(axis, logical_shape.At(axis) / parallel_hierarchy.At(dim_sbp)); } } return logical_shape.elem_cnt(); } } // Judge whether an NdSbp could be applied on a tensor with given logical shape // True means this NdSbp is not valid. Maybe FilterNdSbpByLogicalShape(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy) { return Storage4NdSbp(nd_sbp, logical_shape, parallel_hierarchy) > GetValidMaxCopyCost(); } Maybe ComputeCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp) { return JUST(GetComputeCopyCostFunc())(producer_sbp_parallel, consumer_sbp_parallel, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, requires_same_sbp); } Maybe ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp) { // In 90% of the transfer, we would have the same parallel description for producer and consumer // We need to speed it up and give an approximation of the cost if (producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc)) { // [2, 2]: (S0, S1) -> [2, 2]: (S0, S1) if (*producer_parallel_desc.hierarchy() == *consumer_parallel_desc.hierarchy() && producer_sbp_parallel == consumer_sbp_parallel) { return 0.0; } // Reduce before cost computation Shape reduced_in_hierarchy; NdSbp reduced_in_nd_sbp; Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy, &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape()); // [2, 2]: (B, B) -> [4]: B if (reduced_in_hierarchy == reduced_out_hierarchy && reduced_in_nd_sbp == reduced_out_nd_sbp) { return 1.0; } } if (requires_same_sbp) { return kUnsupportedBoxing; } #ifdef WITH_CUDA static const bool enable_general_basic_communication = ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false); // Use a general basic communication if 
no P in the consumer if ((((Singleton::Get()->nccl_use_compute_stream() && producer_parallel_desc == consumer_parallel_desc) || enable_general_basic_communication) && !NdSbpHasPartialParallel(consumer_sbp_parallel)) && producer_parallel_desc.device_type() == DeviceType::kCUDA && consumer_parallel_desc.device_type() == DeviceType::kCUDA) { return Cost4GeneralBasicCommunication(producer_sbp_parallel, consumer_sbp_parallel, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc) + GetTransferCost(); } #endif // WITH_CUDA // Initialize boxing collector constexpr int32_t kRegularMaxSplitAxes = 6; static thread_local BoxingCollector boxing_collector(kRegularMaxSplitAxes); std::vector middle_sbps; // Ask for middle nodes int32_t diag_node = 0; JUST(boxing_collector.AskSbpCombination( producer_sbp_parallel, consumer_sbp_parallel, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, /*is_customized=*/false, middle_sbps, &diag_node, /*compute_cost=*/true)); // Parameters double total_cost = 0.0; // Set up the information of the first node in the first connection const NdSbp* pre_nd_sbp = &producer_sbp_parallel; const ParallelDesc* pre_parallel_desc = &producer_parallel_desc; const ParallelDesc* middle_parallel_desc = nullptr; // Connection for the next middle node for (int32_t middle_node_id = 0; middle_node_id < middle_sbps.size(); middle_node_id++) { const auto& middle_sbp = middle_sbps[middle_node_id]; if (middle_node_id < diag_node) { middle_parallel_desc = &producer_parallel_desc; } else { middle_parallel_desc = &consumer_parallel_desc; } // We use the parallel description of consumer as the parallel description for all the middle // nodes, following the same procedure in boxing_with_middle_nodes.cpp // TODO: Needs more effort if dealing with different placement total_cost += JUST(ComputeLazyCopyCostBetweenNdSbp(*pre_nd_sbp, middle_sbp, logical_blob_desc, *pre_parallel_desc, *middle_parallel_desc, requires_same_sbp)); // Set up the information of the first node in the next connection pre_nd_sbp = &middle_sbp; pre_parallel_desc = middle_parallel_desc; } // Connection between the last middle node and consumer total_cost += JUST(ComputeLazyCopyCostBetweenNdSbp(*pre_nd_sbp, consumer_sbp_parallel, logical_blob_desc, *pre_parallel_desc, consumer_parallel_desc, requires_same_sbp)); return total_cost; } // Decide the priority to infer sbp double ComputeSbpInferPriority(const NdSbp& producer_nd_sbp, const NdSbp& consumer_nd_sbp, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp, const Shape& logical_shape) { if (producer_nd_sbp == consumer_nd_sbp && producer_parallel_desc == consumer_parallel_desc) { // Highest priority: this blob have the same placement and sbp on both the producer and // consumer return 0.0; } // Reduce before cost computation Shape reduced_in_hierarchy; NdSbp reduced_in_nd_sbp; Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), producer_nd_sbp, consumer_nd_sbp, &reduced_in_hierarchy, &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_shape); if (requires_same_sbp) { // This blob does not support boxing if (reduced_in_nd_sbp == reduced_out_nd_sbp && reduced_in_hierarchy == reduced_out_hierarchy && producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc)) { // Normal priority: No transfer occurs but we have different sbp // For example: [1]:S0 -> [1]:B // [1, 2]:(P, 
S0) -> [1, 2]:(S0, S0) return 1.0; } else { // Penalty: this blob have different placements and sbps but it does not support boxing return 2.0; } } else { // This blob supports boxing if (producer_nd_sbp.sbp_parallel_size() == consumer_nd_sbp.sbp_parallel_size()) { if (producer_nd_sbp == consumer_nd_sbp) { // Highest priority: this blob have the same sbp on both the producer and consumer // Not just [0-3] -> [4-7], but also cpu:[0] -> cuda:[0-3] return 0.0; } } else { if (reduced_in_nd_sbp == reduced_out_nd_sbp) { // Highest priority: this blob have the same sbp on both the producer and consumer // [2, 2]: (S0, S0) -> [2]: S0 // (learning rate) [1]: B -> [2, 2]: (B, B) return 0.0; } } // Normal priority: transfer might occurs // Or might not: [1, 2]: (P, S0) -> [1, 2]: (B, S0) // No transfer but not highest priority return 1.0; } } // The transfer ratio for general basic communication // Cost = ratio * data amount // When we get the this function, either producer_sbp_parallel != consumer_sbp_parallel // or producer_parallel_desc != consumer_parallel_desc double Cost4GeneralBasicCommunication(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc) { // The upper bound of the amount of the transferred data int32_t producer_partial_ratio = PartialRatio4Producer(producer_sbp_parallel, producer_parallel_desc); int32_t consumer_broadcast_ratio = BroadcastRatio4Consumer(consumer_sbp_parallel, consumer_parallel_desc); // More intersection on the same devices bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); // approximate intersection ratio double intersection_ratio = 1.0; // (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer if (producer_partial_ratio > 1 && consumer_broadcast_ratio > 1) { if (on_same_devices) { // Pure P in the producer or B in the consumer // (P, P, P) -> ? or ? 
-> (B, B) if (producer_partial_ratio == producer_parallel_desc.parallel_num() || consumer_broadcast_ratio == consumer_parallel_desc.parallel_num()) { // There some cases which is not applicable to this ratio // We just take the one with the largest possibility // For example: (P, S0) -> (B, B) for 1-D blob with machine hierarchy [n, m] // The path should be (P, S0) -> (S0, S0) -> (B, B) // true intersection ratio = 1/m + 1 intersection_ratio = 2.0; } else { // sbp_consumer = (B, Si) or (Si, B) for (int32_t sbp_id = 0; sbp_id < std::min(producer_sbp_parallel.sbp_parallel_size(), consumer_sbp_parallel.sbp_parallel_size()); sbp_id++) { if (consumer_sbp_parallel.sbp_parallel(sbp_id).has_split_parallel()) { const auto& producer_sbp4sbp_id = producer_sbp_parallel.sbp_parallel(sbp_id); // (B, P) or (Si, P) -> (Si, B) // (P, B) or (P, Si) -> (B, Si) if (producer_sbp4sbp_id.has_broadcast_parallel() || producer_sbp4sbp_id == consumer_sbp_parallel.sbp_parallel(sbp_id)) { intersection_ratio = 2.0; break; } } } // Judge whether the intersection ratio is given a value (2.0) if (intersection_ratio == 1.0) { // The true intersection ratio range from 0 to 2, // we just take a middle point of the range as the approximation // For example: (P, S0) -> (S0, B), Path: (P, S0) -> (S1, S0) -> (S0, B) // true intersection ratio = 1 + 1/m // For example: (P, S0) -> (S1, B), Path: (P, S0) -> (S1, S0) -> (S1, B) // true intersection ratio = 1 + 1 // For example: (P, S0) -> (B, S0), with a 1D blob // true intersection ratio = (n+p-1)/nm + (n+p-1)/nm // For example: (S0, P) -> (B, S0), Path: (S0, P) -> (S0, S1) -> (B, S0) // true intersection ratio = 1 + 1/n // We use the approximation 1 + (1/n + 1/m)/2 intersection_ratio = 1.0 + 0.5 / producer_parallel_desc.hierarchy()->At(0) + 0.5 / producer_parallel_desc.hierarchy()->At(1); } } } // Otherwise, on different devices // intersection_ratio = 1.0; } else { // No P in the producer or no B in the consumer, one-step transfer if (on_same_devices) { // We use simulation for nD sbp with n=1,2,3,... TensorSliceView in_second_slice = GetTensorSliceView4ParallelId(*producer_parallel_desc.hierarchy(), producer_sbp_parallel, logical_blob_desc.shape(), /*parallel_id=*/1); TensorSliceView out_second_slice = GetTensorSliceView4ParallelId(*consumer_parallel_desc.hierarchy(), consumer_sbp_parallel, logical_blob_desc.shape(), /*parallel_id=*/1); const TensorSliceView& intersection = in_second_slice.Intersect(out_second_slice); // The intersection ratio is design for two steps. // However, we only have one step here, we would increase the ratio by 1.0 // to eliminate the unused step intersection_ratio += std::min( 1.0, (double)(intersection.shape().elem_cnt() * producer_parallel_desc.parallel_num()) / logical_blob_desc.shape().elem_cnt()); } // Otherwise, on different devices // intersection_ratio = 1.0; } // Subtract the intersection part return (producer_partial_ratio + consumer_broadcast_ratio - intersection_ratio) * TotalByteSize4BlobDesc(logical_blob_desc); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/sbp_infer_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_ #include "oneflow/core/job/sbp_parallel.h" namespace oneflow { enum SbpInferRuleTag : int { kAllMatch = 1, // All match first, then lowest cost kMatchAMAP = 2, // Match as much as possible kMinCost = 3 // Lowest cost }; enum Penalty4PartialInConsumerTag : int { kSlight = 1, // Slight penalty kMiddle = 2, // Make sure we do not select P in the consumer kStrict = 3 // Not allow a transfer to P }; // [2, 3, 4, 5, 9, 100, 8]: (P, S0, P, P, B, S1, P) // partial ratio = 2 * 4 * 5 * 8 int32_t PartialRatio4Producer(const NdSbp& sbp_producer, const ParallelDesc& producer_parallel_desc); // [2, 3, 4, 5, 9, 100, 8]: (P, S0, B, P, B, S1, P) // broadcast ratio = 4 * 9 int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer, const ParallelDesc& consumer_parallel_desc); void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy, NdSbp* reduced_nd_sbp, const Shape& logical_shape); void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd_sbps, Shape* reduced_hierarchy, const std::vector& reduced_nd_sbps, const Shape& logical_shape); void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp, const Shape& logical_shape); void InOutParallelDimReduce(const Shape& in_hierarchy, const Shape& out_hierarchy, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, Shape* reduced_in_hierarchy, Shape* reduced_out_hierarchy, NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp, const Shape& logical_shape); void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc, ParallelDesc* reduced_out_parallel_desc, NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp, const Shape& logical_shape); double GetValidMaxCopyCost(); double GetTransferCost(); void ResizeNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t size); void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp_signature, int32_t sbp_axis); void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims, const Shape& hierarchy, const HashMap& hierarchy_value2sbp_sig_list, std::vector* nd_sbp_sig_list); void DeduplicateNdSbpSignatureList(std::vector* nd_sbp_sig_list, const std::vector& bns); // Compute storage for given NdSbp double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy); // Judge whether an NdSbp could be applied on a tensor with given logical shape Maybe FilterNdSbpByLogicalShape(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy); // TODO: Unify lazy and eager boxing Maybe ComputeCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp); // Cost for boxing in lazy Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel, const 
NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp); // The public interface for computing cost // It uses the middle nodes algorithm. Maybe ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp); // Decide the priority to infer sbp // 0: highest priority // 1.0: normal priority // 2.0: Penalty, the same as infinity double ComputeSbpInferPriority(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp, const Shape& logical_shape); // The transfer ratio for general basic communication // Cost = ratio * data amount double Cost4GeneralBasicCommunication(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc); int64_t TotalByteSize4BlobDesc(const BlobDesc& logical_blob_desc); int64_t MaxByteSize4BlobDescSbp(const BlobDesc& logical_blob_desc, const NdSbp& nd_sbp, const Shape& hierarchy); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_ ================================================ FILE: oneflow/core/framework/sbp_infer_util_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/framework/nd_sbp.h" #include namespace oneflow { namespace test { namespace { bool ParseNdSbpSignatureFromString(const std::string& nd_sbp_signature_str, NdSbpSignature& nd_sbp_signature) { auto* bn2nd_sbp = nd_sbp_signature.mutable_bn_in_op2nd_sbp(); std::string arg_name = "in"; bool meet_nd_sbp_group = false; bool meet_split = false; int nd_sbp_group_id = 0; std::vector nd_sbp_str_group; size_t pos = 0; while (pos < nd_sbp_signature_str.size()) { const char& c = nd_sbp_signature_str[pos]; pos++; if (c == ' ') { continue; } else if (c == '(') { if (!meet_nd_sbp_group) { // enter a nd-sbp group meet_nd_sbp_group = true; nd_sbp_str_group.emplace_back(); continue; } else { // meet left parentheses of S(x) meet_split = true; } } else if (c == ')') { if (meet_split) { // meet right parentheses of S(x) meet_split = false; } else if (meet_nd_sbp_group) { // leave a nd-sbp group meet_nd_sbp_group = false; std::string bn = arg_name + "_" + std::to_string(nd_sbp_group_id); if (!ParseNdSbpFromStringList(nd_sbp_str_group, &(*bn2nd_sbp)[bn])) { return false; } nd_sbp_str_group.clear(); continue; } else { return false; } } else if (c == ',') { if (meet_nd_sbp_group) { nd_sbp_str_group.emplace_back(); } else { nd_sbp_group_id += 1; } continue; } else if (c == '-') { if (pos < nd_sbp_signature_str.size() && nd_sbp_signature_str[pos] == '>') { // in args parsing has finished, parse out args arg_name = "out"; nd_sbp_group_id = 0; // skip '>' in substr '->' pos++; continue; } else { return false; } } else { // do nothing } nd_sbp_str_group.back() += c; } return true; } std::string NdSbpSignature2String(const NdSbpSignature& nd_sbp_signature, const std::vector& inputs, const std::vector& outputs) { std::ostringstream ss; auto BnNdSbpToString = [&](const std::string& bn) { auto iter = nd_sbp_signature.bn_in_op2nd_sbp().find(bn); CHECK(iter != nd_sbp_signature.bn_in_op2nd_sbp().end()); ss << NdSbpToString(iter->second); }; auto ArgsNdSbpToString = [&](const std::vector& arg_bns) { for (size_t i = 0; i < arg_bns.size(); ++i) { if (i > 0) { ss << ", "; } BnNdSbpToString(arg_bns[i]); } }; ArgsNdSbpToString(inputs); ss << " -> "; ArgsNdSbpToString(outputs); return ss.str(); } void TestDeduplicateNdSbpSignature(const std::vector& nd_sbp_signature_str_list, const std::vector& input_bns, const std::vector& output_bns) { // parse std::vector nd_sbp_sig_list; nd_sbp_sig_list.reserve(nd_sbp_signature_str_list.size()); for (const auto& nd_sbp_signature_str : nd_sbp_signature_str_list) { nd_sbp_sig_list.emplace_back(); ASSERT_TRUE(ParseNdSbpSignatureFromString(nd_sbp_signature_str, nd_sbp_sig_list.back())); } // shuffle and repeat std::random_device rd; std::mt19937 gen(rd()); std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen); nd_sbp_sig_list.reserve(nd_sbp_sig_list.size() + nd_sbp_sig_list.size() / 2); std::copy_n(nd_sbp_sig_list.begin(), nd_sbp_sig_list.size() / 2, std::back_inserter(nd_sbp_sig_list)); std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen); // dedup and sort std::vector bns; bns.insert(bns.end(), input_bns.begin(), input_bns.end()); bns.insert(bns.end(), output_bns.begin(), output_bns.end()); DeduplicateNdSbpSignatureList(&nd_sbp_sig_list, bns); // compare ASSERT_EQ(nd_sbp_signature_str_list.size(), nd_sbp_sig_list.size()); for (size_t i = 0; i < nd_sbp_sig_list.size(); ++i) { auto nd_sbp_sig_result = NdSbpSignature2String(nd_sbp_sig_list[i], input_bns, output_bns); ASSERT_EQ(nd_sbp_sig_result, 
nd_sbp_signature_str_list[i]); } } } // namespace TEST(SbpInferUtil, DeduplicateNdSbpSignatureList) { TestDeduplicateNdSbpSignature( { "(B, B) -> (B, B)", "(B, P) -> (B, P)", "(B, S(0)) -> (B, S(0))", "(B, S(1)) -> (B, S(1))", "(B, S(3)) -> (B, S(2))", "(P, B) -> (P, B)", "(P, P) -> (P, P)", "(P, S(0)) -> (P, S(0))", "(P, S(1)) -> (P, S(1))", "(P, S(3)) -> (P, S(2))", "(S(0), B) -> (S(0), B)", "(S(0), P) -> (S(0), P)", "(S(0), S(0)) -> (S(0), S(0))", "(S(0), S(1)) -> (S(0), S(1))", "(S(0), S(3)) -> (S(0), S(2))", "(S(1), B) -> (S(1), B)", "(S(1), P) -> (S(1), P)", "(S(1), S(0)) -> (S(1), S(0))", "(S(1), S(1)) -> (S(1), S(1))", "(S(1), S(3)) -> (S(1), S(2))", "(S(3), B) -> (S(2), B)", "(S(3), P) -> (S(2), P)", "(S(3), S(0)) -> (S(2), S(0))", "(S(3), S(1)) -> (S(2), S(1))", "(S(3), S(3)) -> (S(2), S(2))", }, {"in_0"}, {"out_0"}); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/framework/scope_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/common/just.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/session_util.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace { Maybe MakeDefaultScope() { JobConfigProto config_proto; config_proto.mutable_predict_conf(); config_proto.set_job_name(""); return MakeScope(config_proto, *JUST(Device::New("cpu"))); } std::list>* ThreadLocalScopeStack() { thread_local static std::list> scope_stack{CHECK_JUST(MakeDefaultScope())}; return &scope_stack; } } // namespace Maybe MakeScope(const JobConfigProto& config_proto, const Device& device) { std::shared_ptr scope; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { int64_t session_id = JUST(GetDefaultSessionId()); std::string device_tag = "cpu"; std::string machine_ids = "0"; std::string device_ids = "0"; if (device.type() != "cpu") { device_tag = device.type(); device_ids = std::to_string(device.device_id()); } scope = JUST(builder->BuildInitialScope(session_id, config_proto, device_tag, {machine_ids + ":" + device_ids}, nullptr, false)); return Maybe::Ok(); })); return scope; } Maybe MakeInitialScope(const JobConfigProto& job_conf, Symbol placement, bool is_local) { std::shared_ptr scope; JUST(PhysicalRun([&scope, &job_conf, placement, is_local](InstructionsBuilder* builder) -> Maybe { int64_t session_id = JUST(GetDefaultSessionId()); scope = JUST(builder->BuildInitialScopeWithPlacement(session_id, job_conf, placement, is_local)); return Maybe::Ok(); })); return scope; } Maybe GetCurrentScope() { auto* scope_stack = ThreadLocalScopeStack(); CHECK_GT_OR_RETURN(scope_stack->size(), 0); return scope_stack->back(); } Maybe InitThreadLocalScopeStack(const std::shared_ptr& scope) { auto* scope_stack = ThreadLocalScopeStack(); 
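// Editorial sketch (illustrative, not part of the original source): the thread-local scope
// stack defined above is normally driven in a push/use/pop pattern; assuming a JobConfigProto
// `config_proto` is in scope, e.g.:
//   auto scope = CHECK_JUST(MakeScope(config_proto, *CHECK_JUST(Device::New("cpu"))));
//   CHECK_JUST(ThreadLocalScopeStackPush(scope));
//   // ... ops built here pick this scope up via GetCurrentScope() ...
//   CHECK_JUST(ThreadLocalScopeStackPop());
// BackwardPassScopeGuard below wraps the same push/pop pair in RAII form for the backward pass.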
scope_stack->clear(); scope_stack->emplace_back(scope); return Maybe::Ok(); } Maybe ThreadLocalScopeStackPush(const std::shared_ptr& scope) { auto* scope_stack = ThreadLocalScopeStack(); scope_stack->emplace_back(scope); return Maybe::Ok(); } Maybe ThreadLocalScopeStackPop() { auto* scope_stack = ThreadLocalScopeStack(); scope_stack->pop_back(); return Maybe::Ok(); } BackwardPassScopeGuard::BackwardPassScopeGuard() { if (LazyMode::is_enabled()) { const auto& scope = CHECK_JUST(GetCurrentScope()); if (scope) { backward_pass_scope_ = CHECK_JUST(FindOrCreateBackwardPassScope(scope)); CHECK_JUST(ThreadLocalScopeStackPush(backward_pass_scope_)); } } } BackwardPassScopeGuard::BackwardPassScopeGuard(const std::shared_ptr& scope) { if (scope && LazyMode::is_enabled()) { backward_pass_scope_ = CHECK_JUST(FindOrCreateBackwardPassScope(scope)); CHECK_JUST(ThreadLocalScopeStackPush(backward_pass_scope_)); } } BackwardPassScopeGuard::~BackwardPassScopeGuard() { if (backward_pass_scope_) { CHECK_JUST(ThreadLocalScopeStackPop()); } } class BackwardPassScopeStorage { public: std::mutex mutex; static BackwardPassScopeStorage* Global() { static BackwardPassScopeStorage instance; return &instance; } HashMap>& get() { return scopes_; } private: HashMap> scopes_; }; extern const std::string kBackwardPass; Maybe FindOrCreateBackwardPassScope(const std::shared_ptr& scope) { auto* storage = BackwardPassScopeStorage::Global(); auto& scopes = storage->get(); std::lock_guard lock(storage->mutex); auto it = scopes.find(JUST(scope->symbol_id())); if (it != scopes.end()) { return it->second; } auto scope_proto = JUST((scope->MakeChildScopeProto())); scope_proto->set_calculation_pass_name(kBackwardPass); std::shared_ptr backward_pass_scope; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { backward_pass_scope = JUST(builder->GetScopeSymbol(*scope_proto)); return Maybe::Ok(); })); scopes.emplace(JUST(scope->symbol_id()), backward_pass_scope); return backward_pass_scope; } void ClearAllBackwardPassScope() { auto* storage = BackwardPassScopeStorage::Global(); std::lock_guard lock(storage->mutex); storage->get().clear(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/scope_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_SCOPE_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_SCOPE_UTIL_H_ #include #include "oneflow/core/job/scope.h" namespace oneflow { Maybe MakeScope(const JobConfigProto& config_proto, const Device& device); Maybe MakeInitialScope(const JobConfigProto& job_conf, Symbol placement, bool is_local); Maybe GetCurrentScope(); Maybe InitThreadLocalScopeStack(const std::shared_ptr& scope); Maybe ThreadLocalScopeStackPush(const std::shared_ptr& scope); Maybe ThreadLocalScopeStackPop(); class BackwardPassScopeGuard { public: BackwardPassScopeGuard(); explicit BackwardPassScopeGuard(const std::shared_ptr& scope); ~BackwardPassScopeGuard(); private: std::shared_ptr backward_pass_scope_; }; Maybe FindOrCreateBackwardPassScope(const std::shared_ptr& scope); void ClearAllBackwardPassScope(); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SCOPE_UTIL_H_ ================================================ FILE: oneflow/core/framework/session_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/session_util.h" namespace oneflow { namespace { std::mutex* GlobalSessionUtilMutex() { static std::mutex global_id2session_map_mutex; return &global_id2session_map_mutex; } std::vector* RegsiteredSessionIds() { static std::vector default_sess_id; return &default_sess_id; } Maybe SetDefaultSessionId(int64_t val) { std::vector* ids = RegsiteredSessionIds(); ids->emplace_back(val); return Maybe::Ok(); } } // namespace Maybe GetDefaultSessionId() { std::unique_lock lock(*GlobalSessionUtilMutex()); const auto& regsitered_ids = *(RegsiteredSessionIds()); CHECK_GT_OR_RETURN(regsitered_ids.size(), 0); return regsitered_ids.back(); } bool RegsterSessionId(int64_t session_id) { std::unique_lock lock(*GlobalSessionUtilMutex()); auto* regsitered_ids = RegsiteredSessionIds(); auto itor = std::find(regsitered_ids->begin(), regsitered_ids->end(), session_id); if (itor != regsitered_ids->end()) { return false; } regsitered_ids->push_back(session_id); return true; } bool ClearSessionId(int64_t session_id) { std::unique_lock lock(*GlobalSessionUtilMutex()); auto* regsitered_ids = RegsiteredSessionIds(); auto itor = std::find(regsitered_ids->begin(), regsitered_ids->end(), session_id); if (itor == regsitered_ids->end()) { return false; } regsitered_ids->erase(itor); return true; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/session_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_SESSION_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_SESSION_UTIL_H_ #include "oneflow/core/common/maybe.h" namespace oneflow { Maybe GetDefaultSessionId(); bool RegsterSessionId(int64_t session_id); bool ClearSessionId(int64_t session_id); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SESSION_UTIL_H_ ================================================ FILE: oneflow/core/framework/shut_down_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/shut_down_util.h" namespace oneflow { namespace { std::atomic* GetShuttingDown() { static std::atomic shutting_down{false}; return &shutting_down; } } // namespace bool IsShuttingDown() { auto* shutting_down = GetShuttingDown(); bool is_interpreter_shutdown = *shutting_down; return is_interpreter_shutdown; } void SetShuttingDown(bool arg_shutting_down) { auto* shutting_down = GetShuttingDown(); *shutting_down = arg_shutting_down; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/shut_down_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_PYTHON_INTERPRETER_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_PYTHON_INTERPRETER_UTIL_H_ #include "oneflow/core/common/maybe.h" namespace oneflow { bool IsShuttingDown(); void SetShuttingDown(bool arg_shutting_down = true); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_PYTHON_INTERPRETER_UTIL_H_ ================================================ FILE: oneflow/core/framework/stream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/thread/thread_global_id.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/static_global.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/framework/stream_mgr.h" #include "oneflow/core/vm/stream_get_allocator_stream_type.h" #include "oneflow/core/ep/include/device_manager.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { Stream::Stream(Symbol device, StreamType stream_type, size_t thread_uid) : device_(device), stream_type_(stream_type), thread_uid_(thread_uid), unique_stream_id_(-1), support_wait_event_(false) { ep::DeviceManager* device_mgr = Singleton::Get()->GetDeviceManagerOrNull(device->enum_type()); if (!device_mgr) { return; } support_wait_event_ = device_mgr->IsStreamWaitEventSupported(); } Maybe Stream::Init(size_t unique_stream_id) { unique_stream_id_ = unique_stream_id; return Maybe::Ok(); } /*static*/ Maybe> Stream::RawNew(Symbol device, StreamType stream_type, size_t thread_uid) { std::shared_ptr stream(new Stream(device, stream_type, thread_uid)); return JUST(SingletonMaybe()) ->AddStreamSymbol(*stream, [&](size_t unique_stream_id) -> Maybe> { JUST(stream->Init(unique_stream_id)); return SymbolOf(*stream); }); } /*static*/ Maybe> Stream::New(Symbol device, StreamType stream_type, size_t thread_uid) { constexpr auto* Make = DECORATE(&Stream::RawNew, ThreadLocalCopiable); return Make(device, stream_type, thread_uid); } namespace { Maybe> RawGetDefaultStreamByDevice(Symbol device) { return Stream::New(device, StreamType::kCompute); } Maybe> RawGetDefaultStreamByPlacement(Symbol parallel_desc) { return RawGetDefaultStreamByDevice(JUST(GetTensorDevice(parallel_desc))); } Maybe> RawGetAllocatorStream(Symbol stream) { StreamType allocator_stream_type = JUST(GetAllocatorStreamType::Visit(stream->stream_type())); if (allocator_stream_type == stream->stream_type()) { return stream; } return Stream::New(stream->device(), allocator_stream_type, stream->thread_uid()); } } // namespace int64_t Stream::kDefaultStreamThreadUid = 0; decltype(GetDefaultStreamByDevice) GetDefaultStreamByDevice = DECORATE(&RawGetDefaultStreamByDevice, ThreadLocal); decltype(GetDefaultStreamByPlacement) GetDefaultStreamByPlacement = DECORATE(&RawGetDefaultStreamByPlacement, ThreadLocal); decltype(GetAllocatorStream) GetAllocatorStream = DECORATE(&RawGetAllocatorStream, ThreadLocal); } // namespace oneflow ================================================ FILE: oneflow/core/framework/stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_H_ #include #include "oneflow/core/common/stream_type.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/device.h" namespace oneflow { class Stream final { public: Stream(const Stream&) = default; Stream(Stream&&) = default; ~Stream() = default; bool operator==(const Stream& that) const { return this->device() == that.device() && this->stream_type() == that.stream_type() && this->thread_uid() == that.thread_uid() && this->support_wait_event() == that.support_wait_event(); } bool operator!=(const Stream& that) const { return !(*this == that); } static Maybe> New(Symbol device, StreamType stream_type) { return New(device, stream_type, kDefaultStreamThreadUid); } static Maybe> New(Symbol device, StreamType stream_type, size_t thread_uid); Symbol device() const { return device_; } StreamType stream_type() const { return stream_type_; } size_t thread_uid() const { return thread_uid_; } size_t unique_stream_id() const { return unique_stream_id_; } bool support_wait_event() const { return support_wait_event_; } static int64_t kDefaultStreamThreadUid; private: Stream(Symbol device, StreamType stream_type, size_t thread_uid); static Maybe> RawNew(Symbol device, StreamType stream_type, size_t thread_uid); Maybe Init(size_t unique_stream_id); Symbol device_; StreamType stream_type_; size_t thread_uid_; size_t unique_stream_id_; bool support_wait_event_; }; extern Maybe> (*GetDefaultStreamByDevice)(Symbol); class ParallelDesc; extern Maybe> (*GetDefaultStreamByPlacement)(Symbol); extern Maybe> (*GetAllocatorStream)(Symbol); } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::Stream& stream) const { using namespace oneflow; return Hash(stream.device(), stream.stream_type(), stream.thread_uid()); } }; } // namespace std #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_H_ ================================================ FILE: oneflow/core/framework/stream_allocator_is_pinned.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
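A minimal sketch of how the Stream API declared in stream.h above is typically consumed. The angle-bracket template arguments (Maybe<...>, Symbol<...>) were stripped by the text extraction; the ones written here are an assumption based on the surrounding declarations, and the helper function itself is not part of the repository.

// Illustrative sketch, assuming Maybe<Symbol<Stream>> Stream::New(Symbol<Device>, StreamType).
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"

namespace oneflow {

Maybe<size_t> GetComputeStreamId(Symbol<Device> device) {
  // Streams are interned through StreamMgr: calling New twice with the same
  // (device, stream_type, thread_uid) yields the same Symbol<Stream> and therefore
  // the same unique_stream_id.
  Symbol<Stream> stream = JUST(Stream::New(device, StreamType::kCompute));
  return stream->unique_stream_id();
}

}  // namespace oneflow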
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_ALLOCATOR_IS_PINNED_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_ALLOCATOR_IS_PINNED_H_ #include "oneflow/core/common/stream_type.h" namespace oneflow { struct IsStreamAllocatorPinned : public StreamTypeVisitor { static bool VisitCompute() { return false; } static bool VisitHost2Device() { return false; } static bool VisitDevice2Host() { return false; } static bool VisitCcl() { return false; } static bool VisitBarrier() { return false; } static bool VisitCriticalSection() { return false; } static bool VisitLazyJobLauncher() { return false; } static bool VisitPinnedCompute() { return true; } }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_ALLOCATOR_IS_PINNED_H_ ================================================ FILE: oneflow/core/framework/stream_get_stream_type_name.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_TYPE_NAME_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_TYPE_NAME_H_ #include #include "oneflow/core/common/stream_type.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/framework/to_string.h" namespace oneflow { struct GetStreamTypeName : public StreamTypeVisitor { static const char* VisitCompute() { return "compute"; } static const char* VisitHost2Device() { return "h2d"; } static const char* VisitDevice2Host() { return "d2h"; } static const char* VisitCcl() { return "ccl"; } static const char* VisitBarrier() { return "barrier"; } static const char* VisitCriticalSection() { return "critical_section"; } static const char* VisitLazyJobLauncher() { return "lazy_job_launcher"; } static const char* VisitPinnedCompute() { return "pinned_compute"; } }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_TYPE_NAME_H_ ================================================ FILE: oneflow/core/framework/stream_guard.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/stream_guard.h" namespace oneflow { /*static*/ Optional* StreamGuard::MutCurrent() { static thread_local Optional current; return &current; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/stream_guard.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved.
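IsStreamAllocatorPinned and GetStreamTypeName above follow a CRTP "visitor over an enum" pattern: a base template dispatches on the stream type and forwards to one static VisitXxx() per enumerator. The enum and base class below are simplified, self-contained stand-ins (not OneFlow's StreamTypeVisitor) so the sketch compiles on its own.

// Illustrative analogue of the StreamTypeVisitor dispatch idea.
#include <iostream>
#include <utility>

enum class ExampleStreamType { kCompute, kHost2Device, kDevice2Host };

template<typename Derived>
struct ExampleStreamTypeVisitor {
  template<typename... Args>
  static auto Visit(ExampleStreamType type, Args&&... args) {
    switch (type) {
      case ExampleStreamType::kCompute: return Derived::VisitCompute(std::forward<Args>(args)...);
      case ExampleStreamType::kHost2Device: return Derived::VisitHost2Device(std::forward<Args>(args)...);
      case ExampleStreamType::kDevice2Host: return Derived::VisitDevice2Host(std::forward<Args>(args)...);
    }
    return Derived::VisitCompute(std::forward<Args>(args)...);  // unreachable fallback
  }
};

struct ExampleGetName : public ExampleStreamTypeVisitor<ExampleGetName> {
  static const char* VisitCompute() { return "compute"; }
  static const char* VisitHost2Device() { return "h2d"; }
  static const char* VisitDevice2Host() { return "d2h"; }
};

int main() {
  // Prints "h2d": the call is routed to ExampleGetName::VisitHost2Device().
  std::cout << ExampleGetName::Visit(ExampleStreamType::kHost2Device) << std::endl;
  return 0;
}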
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_GUARD_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_GUARD_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/env_var/stream.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/stream_set.h" #include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/thread/thread_global_id.h" namespace oneflow { class StreamConverter final { public: explicit StreamConverter(const std::shared_ptr& stream_set) : stream_set_(stream_set) {} Maybe> TryConvertStream(Symbol stream) { size_t thread_uid = stream_set_->worker_thread_id(); return Stream::New(stream->device(), stream->stream_type(), thread_uid); } private: const std::shared_ptr stream_set_; }; class StreamGuard final { public: explicit StreamGuard(const std::shared_ptr& stream_converter) { old_value_ = Current(); *MutCurrent() = stream_converter; } ~StreamGuard() { *MutCurrent() = old_value_; } static Maybe> TryConvertStream(Symbol stream) { if (!Current().has_value()) { return stream; } return JUST(Current())->TryConvertStream(stream); } private: static const Optional& Current() { return *MutCurrent(); } static Optional* MutCurrent(); Optional old_value_; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_GUARD_H_ ================================================ FILE: oneflow/core/framework/stream_is_comm_net_stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_IS_COMM_NET_STREAM_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_IS_COMM_NET_STREAM_H_ #include "oneflow/core/common/stream_type.h" namespace oneflow { struct IsCommNetStream final : public StreamTypeVisitor { static bool VisitCompute() { return false; } static bool VisitHost2Device() { return false; } static bool VisitDevice2Host() { return false; } static bool VisitCcl() { return true; } static bool VisitBarrier() { return false; } static bool VisitCriticalSection() { return false; } static bool VisitLazyJobLauncher() { return false; } static bool VisitPinnedCompute() { return VisitCompute(); } }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_IS_COMM_NET_STREAM_H_ ================================================ FILE: oneflow/core/framework/stream_mgr.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
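A minimal sketch of how the RAII StreamGuard declared in stream_guard.h above is typically scoped. The template arguments stripped by the extraction are filled in as assumptions, as is the Maybe/shared_ptr plumbing; the helper function is not part of the repository. While the guard is alive, TryConvertStream rebuilds streams onto the worker thread owned by the given StreamSet; on scope exit the previous converter (if any) is restored.

// Illustrative sketch, assuming StreamSet::New returns Maybe<StreamSet> holding a shared_ptr.
#include <memory>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/framework/stream_guard.h"
#include "oneflow/core/framework/stream_set.h"

namespace oneflow {

Maybe<Symbol<Stream>> RedirectToWorkerThread(Symbol<Stream> stream, int64_t worker_thread_id) {
  std::shared_ptr<StreamSet> stream_set = JUST(StreamSet::New(worker_thread_id));
  auto converter = std::make_shared<StreamConverter>(stream_set);
  StreamGuard guard(converter);
  // Inside this scope, TryConvertStream re-creates the stream with the worker thread
  // uid taken from stream_set; outside the scope streams pass through unchanged.
  return StreamGuard::TryConvertStream(stream);
}  // guard destroyed here; the previous thread-local converter is restored

}  // namespace oneflow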
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/stream_mgr.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/util.h" namespace oneflow { Maybe> StreamMgr::AddStreamSymbol( const Stream& stream, const std::function>(size_t unique_stream_id)>& CreateStreamSymbol) { Symbol stream_symbol; std::unique_lock lock(mutex_); if (stream2unique_stream_id_.count(stream) > 0) { size_t unique_stream_id = stream2unique_stream_id_[stream]; auto existed_stream_symbol = JUST(VectorAt(unique_stream_id2stream_symbol_, unique_stream_id)); stream_symbol = JUST(CreateStreamSymbol(unique_stream_id)); CHECK_OR_RETURN(existed_stream_symbol == stream_symbol) << "the result of current called CreateStreamSymbol is not the result of last called " "CreateStreamSymbol"; } else { size_t unique_stream_id = unique_stream_id2stream_symbol_.size(); stream2unique_stream_id_[stream] = unique_stream_id; stream_symbol = JUST(CreateStreamSymbol(unique_stream_id)); unique_stream_id2stream_symbol_.push_back(stream_symbol); CHECK_OR_RETURN(unique_stream_id2stream_symbol_[unique_stream_id] == stream) << "the result of CreateStreamSymbol is not the symbol of `stream`"; CHECK_EQ_OR_RETURN(unique_stream_id2stream_symbol_[unique_stream_id]->unique_stream_id(), unique_stream_id) << "unique_stream_id is wrongly initialized"; } return stream_symbol; } size_t StreamMgr::UniqueStreamSize() const { std::unique_lock lock(mutex_); return unique_stream_id2stream_symbol_.size(); } Maybe> StreamMgr::GetStreamSymbol(size_t unique_stream_id) const { std::unique_lock lock(mutex_); return JUST(VectorAt(unique_stream_id2stream_symbol_, unique_stream_id)); } COMMAND(Singleton::SetAllocated(new StreamMgr())); } // namespace oneflow ================================================ FILE: oneflow/core/framework/stream_mgr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
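StreamMgr::AddStreamSymbol above implements an interning scheme: each distinct Stream value is stored once and receives a dense, monotonically increasing unique_stream_id, and re-registering an equal value returns the id assigned the first time. The following self-contained sketch (not OneFlow code, using std::string as the interned value) shows only that pattern.

// Illustrative interning table: value -> dense id, id -> value, guarded by a mutex.
#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

class ExampleInternTable {
 public:
  size_t Add(const std::string& value) {
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = value2id_.find(value);
    if (it != value2id_.end()) { return it->second; }  // already interned
    size_t id = id2value_.size();                      // dense id == insertion order
    value2id_[value] = id;
    id2value_.push_back(value);
    return id;
  }
  const std::string& Get(size_t id) const {
    std::unique_lock<std::mutex> lock(mutex_);
    return id2value_.at(id);
  }

 private:
  mutable std::mutex mutex_;
  std::vector<std::string> id2value_;
  std::unordered_map<std::string, size_t> value2id_;
};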
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_ #include #include #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/stream.h" namespace oneflow { class StreamMgr final { public: StreamMgr() = default; ~StreamMgr() = default; Maybe> AddStreamSymbol( const Stream& stream, const std::function>(size_t unique_stream_id)>& CreateStreamSymbol); size_t UniqueStreamSize() const; Maybe> GetStreamSymbol(size_t unique_stream_id) const; private: mutable std::mutex mutex_; std::vector> unique_stream_id2stream_symbol_; std::unordered_map stream2unique_stream_id_; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_ ================================================ FILE: oneflow/core/framework/stream_need_soft_sync.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_NEED_SOFT_SYNC_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_NEED_SOFT_SYNC_H_ #include "oneflow/core/common/device_type.h" #include "oneflow/core/common/stream_type.h" namespace oneflow { struct NeedSoftSync : public StreamTypeVisitor { static bool VisitCompute(DeviceType device_type) { return device_type != kCPU; } static bool VisitHost2Device(DeviceType) { return false; } static bool VisitDevice2Host(DeviceType) { return false; } static bool VisitCcl(DeviceType device_type) { return false; } static bool VisitBarrier(DeviceType) { return false; } static bool VisitCriticalSection(DeviceType) { return false; } static bool VisitLazyJobLauncher(DeviceType) { return false; } static bool VisitPinnedCompute(DeviceType device_type) { return VisitCompute(device_type); } }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_NEED_SOFT_SYNC_H_ ================================================ FILE: oneflow/core/framework/stream_on_independent_thread.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_ #include "oneflow/core/common/stream_type.h" namespace oneflow { struct StreamOnIndependentThread : public StreamTypeVisitor { static bool VisitCompute() { return false; } static bool VisitHost2Device() { return false; } static bool VisitDevice2Host() { return false; } static bool VisitCcl() { return false; } static bool VisitBarrier() { return false; } static bool VisitCriticalSection() { return true; } static bool VisitLazyJobLauncher() { return true; } static bool VisitPinnedCompute() { return VisitCompute(); } }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_ ================================================ FILE: oneflow/core/framework/stream_set.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include #include "oneflow/core/framework/stream_set.h" #include "oneflow/core/thread/thread_global_id.h" #include "oneflow/core/common/env_var/stream.h" #include "oneflow/core/common/container_util.h" namespace oneflow { StreamSet::StreamSet(int64_t worker_thread_id) : worker_thread_id_(worker_thread_id) {} StreamSet::~StreamSet() {} /*static*/ Maybe StreamSet::New(int64_t worker_thread_id) { return std::shared_ptr(new StreamSet(worker_thread_id)); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/stream_set.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_SET_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_SET_H_ #include #include "oneflow/core/common/util.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/stream.h" namespace oneflow { class StreamSet final { public: ~StreamSet(); static Maybe New(int64_t worker_thread_id); int64_t worker_thread_id() const { return worker_thread_id_; } private: StreamSet(int64_t worker_thread_id); int64_t worker_thread_id_; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_SET_H_ ================================================ FILE: oneflow/core/framework/stream_support_stream_wait.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
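StreamSet above uses a private constructor together with a static New() factory, so instances can only be created through one validated path and are always owned by a shared_ptr. A self-contained sketch of that idiom follows; the class and its validation are hypothetical, not OneFlow code.

// Illustrative private-constructor + static factory idiom.
#include <cstdint>
#include <memory>
#include <stdexcept>

class ExampleWorkerHandle {
 public:
  static std::shared_ptr<ExampleWorkerHandle> New(int64_t worker_thread_id) {
    if (worker_thread_id < 0) { throw std::invalid_argument("negative worker_thread_id"); }
    // std::make_shared cannot reach the private constructor, hence the explicit new.
    return std::shared_ptr<ExampleWorkerHandle>(new ExampleWorkerHandle(worker_thread_id));
  }
  int64_t worker_thread_id() const { return worker_thread_id_; }

 private:
  explicit ExampleWorkerHandle(int64_t worker_thread_id) : worker_thread_id_(worker_thread_id) {}
  int64_t worker_thread_id_;
};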
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_ #define ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_ #include "oneflow/core/common/stream_type.h" namespace oneflow { struct StreamSupportStreamWait : public StreamTypeVisitor { static bool VisitCompute(DeviceType device_type) { return Supported(device_type); } static bool VisitHost2Device(DeviceType device_type) { return Supported(device_type); } static bool VisitDevice2Host(DeviceType device_type) { return Supported(device_type); } static bool VisitCcl(DeviceType device_type) { return Supported(device_type); } static bool VisitBarrier(DeviceType device_type) { return false; } static bool VisitCriticalSection(DeviceType device_type) { return false; } static bool VisitLazyJobLauncher(DeviceType device_type) { return false; } static bool VisitPinnedCompute(DeviceType device_type) { return VisitCompute(device_type); } private: static bool Supported(DeviceType device_type) { return device_type == kCUDA; } }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_ ================================================ FILE: oneflow/core/framework/symbol_storage_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/operator/op_node_signature.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/op_conf_symbol.h" #include "oneflow/core/vm/symbol_storage.h" namespace oneflow { COMMAND( Singleton>::SetAllocated(new symbol::Storage())); COMMAND(Singleton>::SetAllocated(new symbol::Storage())); COMMAND(Singleton>::SetAllocated(new symbol::Storage())); COMMAND(Singleton>::SetAllocated( new symbol::Storage())); } // namespace oneflow ================================================ FILE: oneflow/core/framework/symbol_storage_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
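symbol_storage_util.cpp above relies on a COMMAND(...) macro to allocate the singleton symbol storages during static initialization, i.e. as a side effect of loading the translation unit. The sketch below is a tiny, self-contained stand-in for that idiom, not OneFlow's COMMAND implementation.

// Illustrative "run this before main()" registration idiom.
#include <iostream>

namespace {

// The constructor body runs during static initialization of this translation unit.
struct RunAtStartup {
  explicit RunAtStartup(void (*fn)()) { fn(); }
};

RunAtStartup register_example([]() { std::cout << "registered before main()" << std::endl; });

}  // namespace

int main() { return 0; }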
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_SYMBOL_STORAGE_H_ #define ONEFLOW_CORE_FRAMEWORK_SYMBOL_STORAGE_H_ #include "oneflow/core/vm/symbol_storage.h" namespace oneflow { template Maybe GetSymbol(int64_t symbol_id) { const auto& symbol_storage = *Singleton>::Get(); const auto& ptr = JUST(symbol_storage.MaybeGetPtr(symbol_id)); JUST(ptr->symbol_id()); return ptr; } } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SYMBOL_STORAGE_H_ ================================================ FILE: oneflow/core/framework/sync_symbol_global_tensor_meta.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/sync_symbol_global_tensor_meta.h" #include "oneflow/core/framework/sync_symbol_parallel_desc.h" #include "oneflow/core/framework/sync_symbol_nd_sbp.h" #include "oneflow/core/framework/rank_group_rpc_util.h" #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/framework/synced_symbol_map.h" #include "oneflow/core/common/flat_shape.h" namespace oneflow { struct FlatGlobalTensorMeta final { static Maybe New(uint64_t symbol_id, Symbol global_tensor_meta) { const auto& meta = std::make_shared(); JUST(meta->Init(symbol_id, global_tensor_meta)); return meta; } Maybe Init(uint64_t symbol_id, Symbol global_tensor_meta) { this->symbol_id = symbol_id; JUST(this->shape.Init(global_tensor_meta->shape())); this->dtype = static_cast(global_tensor_meta->dtype()); this->is_dynamic = global_tensor_meta->is_dynamic(); this->nd_sbp = JUST(SyncedSymbolMap::FindOrSync(global_tensor_meta->nd_sbp(), &SyncSymbolNdSbp)); this->parallel_desc = JUST(SyncedSymbolMap::FindOrSync( global_tensor_meta->parallel_desc(), &SyncSymbolParallelDesc)); return Maybe::Ok(); } Maybe Check(uint64_t symbol_id, Symbol global_tensor_meta) { CHECK_EQ_OR_RETURN(this->symbol_id, symbol_id); JUST(this->shape.Check(global_tensor_meta->shape())); CHECK_EQ_OR_RETURN(static_cast(this->dtype), global_tensor_meta->dtype()); // NOLINT CHECK_EQ_OR_RETURN(this->is_dynamic, global_tensor_meta->is_dynamic()); // NOLINT const auto& nd_sbp = JUST(SyncedSymbolMap::Symbol4SyncedSymbolId(this->nd_sbp)); CHECK_OR_RETURN(nd_sbp == global_tensor_meta->nd_sbp()); // NOLINT const auto& parallel_desc = JUST(SyncedSymbolMap::Symbol4SyncedSymbolId(this->parallel_desc)); CHECK_OR_RETURN(parallel_desc == global_tensor_meta->parallel_desc()); // NOLINT return Maybe::Ok(); } uint64_t symbol_id; FlatShape shape; int32_t dtype; bool is_dynamic; uint64_t nd_sbp; uint64_t parallel_desc; }; Maybe SyncSymbolGlobalTensorMeta(uint64_t symbol_id, Symbol global_tensor_meta) { const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncSymbolGlobalTensorMeta)); const auto& recv_buffer = std::make_shared(); NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { const auto& send_buffer = 
JUST(FlatGlobalTensorMeta::New(symbol_id, global_tensor_meta)); *buffer = send_buffer.get(); *size = sizeof(FlatGlobalTensorMeta); *Cb = [send_buffer] {}; return Maybe::Ok(); }, [recv_buffer](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_buffer.get(); *size = sizeof(FlatGlobalTensorMeta); *Cb = [recv_buffer] {}; return Maybe::Ok(); }); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); JUST(ctx.WaitDone()); JUST(recv_buffer->Check(symbol_id, global_tensor_meta)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/sync_symbol_global_tensor_meta.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_GLOBAL_TENSOR_META_H_ #define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_GLOBAL_TENSOR_META_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/framework/transport_token.h" namespace oneflow { namespace one { class GlobalTensorMeta; } Maybe SyncSymbolGlobalTensorMeta(uint64_t symbol_id, Symbol); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_GLOBAL_TENSOR_META_H_ ================================================ FILE: oneflow/core/framework/sync_symbol_nd_sbp.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
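SyncSymbolGlobalTensorMeta above flattens each rank's view of a symbol into a fixed-size POD record, ships it one hop around the ring of ranks, and checks that the record received from the predecessor describes the same symbol; if every rank matches its predecessor, the symbol is identical everywhere. The self-contained sketch below shows only that structure; the record layout is invented and the transport hooks are stubbed as a single-rank loopback rather than OneFlow's TransportUtil.

// Illustrative ring consistency check over a trivially copyable record.
#include <cstdint>
#include <cstring>
#include <stdexcept>

struct FlatMetaRecord {  // fixed-size and trivially copyable, so it can be sent byte-wise
  uint64_t symbol_id;
  int32_t dtype;
  int64_t num_axes;
  int64_t dims[4];
};

namespace {
// Hypothetical transport hooks, stubbed as a loopback so the sketch is self-contained.
FlatMetaRecord loopback_slot{};
void SendToNextRank(const FlatMetaRecord& record) { loopback_slot = record; }
FlatMetaRecord ReceiveFromPrevRank() { return loopback_slot; }
}  // namespace

void CheckSymbolConsistentAcrossRanks(const FlatMetaRecord& mine) {
  SendToNextRank(mine);
  FlatMetaRecord from_prev_rank = ReceiveFromPrevRank();
  // "Equal to my predecessor" closing around the ring implies all ranks agree.
  bool same = from_prev_rank.symbol_id == mine.symbol_id && from_prev_rank.dtype == mine.dtype
              && from_prev_rank.num_axes == mine.num_axes
              && std::memcmp(from_prev_rank.dims, mine.dims, sizeof(mine.dims)) == 0;
  if (!same) { throw std::runtime_error("symbol mismatch between ranks"); }
}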
*/ #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/framework/sync_symbol_nd_sbp.h" #include "oneflow/core/framework/rank_group_rpc_util.h" #include "oneflow/core/job/rank_group_scope.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/common/constant.h" namespace oneflow { namespace { // clang-format off FLAT_MSG_BEGIN(FlatSplitParallel); FLAT_MSG_DEFINE_OPTIONAL(int64_t, axis); FLAT_MSG_END(FlatSplitParallel); FLAT_MSG_BEGIN(FlatBroadcastParallel); FLAT_MSG_END(FlatBroadcastParallel); FLAT_MSG_BEGIN(FlatPartialSumParallel); FLAT_MSG_END(FlatPartialSumParallel); FLAT_MSG_BEGIN(FlatSbpParallel); public: Maybe Init(const SbpParallel& sbp_parallel) { if (sbp_parallel.has_split_parallel()) { this->mutable_split_parallel()->set_axis(sbp_parallel.split_parallel().axis()); } else if (sbp_parallel.has_broadcast_parallel()) { this->mutable_broadcast_parallel(); } else if (sbp_parallel.has_partial_sum_parallel()) { this->mutable_partial_sum_parallel(); } else { OF_UNIMPLEMENTED(); } return Maybe::Ok(); } Maybe Check(const SbpParallel& sbp_parallel) const { if (sbp_parallel.has_split_parallel()) { CHECK_EQ_OR_RETURN(this->split_parallel().axis(), sbp_parallel.split_parallel().axis()); } else if (sbp_parallel.has_broadcast_parallel()) { CHECK_OR_RETURN(this->has_broadcast_parallel()); } else if (sbp_parallel.has_partial_sum_parallel()) { CHECK_OR_RETURN(this->has_partial_sum_parallel()); } else { OF_UNIMPLEMENTED(); } return Maybe::Ok(); } private: FLAT_MSG_DEFINE_ONEOF(parallel_type, FLAT_MSG_ONEOF_FIELD(FlatSplitParallel, split_parallel) FLAT_MSG_ONEOF_FIELD(FlatBroadcastParallel, broadcast_parallel) FLAT_MSG_ONEOF_FIELD(FlatPartialSumParallel, partial_sum_parallel)); FLAT_MSG_END(FlatSbpParallel); FLAT_MSG_BEGIN(FlatNdSbp); public: Maybe Init(uint64_t symbol_id, Symbol nd_sbp) { this->set_symbol_id(symbol_id); this->set_size(nd_sbp->sbp_parallel_size()); for (int i = 0; i < this->size(); ++i) { const auto& sbp_parallel = nd_sbp->sbp_parallel(i); JUST(this->mutable_sbp_parallel()->Mutable(i)->Init(sbp_parallel)); } return Maybe::Ok(); } Maybe Check(uint64_t symbol_id, Symbol nd_sbp) const { CHECK_EQ_OR_RETURN(this->symbol_id(), symbol_id); CHECK_EQ_OR_RETURN(this->size(), nd_sbp->sbp_parallel_size()); for (int i = 0; i < this->size(); ++i) { JUST(this->sbp_parallel().Get(i).Check(nd_sbp->sbp_parallel(i))); } return Maybe::Ok(); } private: FLAT_MSG_DEFINE_OPTIONAL(uint64_t, symbol_id); FLAT_MSG_DEFINE_OPTIONAL(size_t, size); FLAT_MSG_DEFINE_REPEATED(FlatSbpParallel, sbp_parallel, SHAPE_MAX_AXIS_SIZE); FLAT_MSG_END(FlatNdSbp); // clang-format on class FlatNdSbpAsyncTransportCtx : public AsyncTransportCtx { public: FlatNdSbpAsyncTransportCtx(const TransportToken& transport_token, uint64_t symbol_id, Symbol nd_sbp) : AsyncTransportCtx(transport_token), symbol_id_(symbol_id), nd_sbp_(nd_sbp) {} ~FlatNdSbpAsyncTransportCtx() override {} Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override { const auto& flat_nd_sbp = std::make_shared(); JUST(flat_nd_sbp->Init(symbol_id_, nd_sbp_)); *buffer = flat_nd_sbp.get(); *size = sizeof(FlatNdSbp); *Callback = [flat_nd_sbp]() {}; return Maybe::Ok(); } Maybe PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override { const auto& flat_nd_sbp = std::make_shared(); *buffer = flat_nd_sbp.get(); *size = sizeof(FlatNdSbp); *Callback = [flat_nd_sbp]() {}; flat_nd_sbp_ = flat_nd_sbp; 
return Maybe::Ok(); } Maybe Check() const { CHECK_NOTNULL_OR_RETURN(flat_nd_sbp_.get()); JUST(flat_nd_sbp_->Check(symbol_id_, nd_sbp_)); return Maybe::Ok(); } private: uint64_t symbol_id_; Symbol nd_sbp_; std::shared_ptr flat_nd_sbp_; }; } // namespace namespace {} Maybe SyncSymbolNdSbp(uint64_t symbol_id, Symbol symbol) { const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncSymbolNdSbp)); FlatNdSbpAsyncTransportCtx ctx(transport_token, symbol_id, symbol); JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg); JUST(ctx.Check()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/sync_symbol_nd_sbp.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_ND_SBP_H_ #define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_ND_SBP_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/framework/transport_token.h" namespace oneflow { class NdSbp; Maybe SyncSymbolNdSbp(uint64_t symbol_id, Symbol); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_ND_SBP_H_ ================================================ FILE: oneflow/core/framework/sync_symbol_parallel_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/sync_symbol_parallel_desc.h" #include "oneflow/core/framework/rank_group_rpc_util.h" #include "oneflow/core/job/rank_group_scope.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/constant.h" namespace oneflow { namespace { static const int kLimitParallelConfString = 1024 * 64; struct FlatParallelConf { size_t available_size() const { CHECK_GE(this->buffer_size, 0) << "Buffer size should be non-negative"; CHECK_LT(this->buffer_size, kLimitParallelConfString) << "Buffer size should be less than " << kLimitParallelConfString; return sizeof(FlatParallelConf) - kLimitParallelConfString + this->buffer_size; } size_t capacity() const { return sizeof(FlatParallelConf); } static Maybe New(uint64_t symbol_id, Symbol parallel_desc) { const auto& data = std::make_shared(); JUST(data->Init(symbol_id, parallel_desc)); return data; } Maybe Init(uint64_t symbol_id, Symbol parallel_desc) { const auto& parallel_conf = parallel_desc->parallel_conf(); int64_t byte_size = parallel_conf.ByteSize(); CHECK_LE_OR_RETURN(byte_size, kLimitParallelConfString) << Error::InvalidValueError() << "Byte size of parallel description should be less than " << kLimitParallelConfString << ", but got " << byte_size; this->symbol_id = symbol_id; this->buffer_size = byte_size; CHECK_OR_RETURN(parallel_conf.SerializeToArray(this->buffer, kLimitParallelConfString)) << Error::RuntimeError() << "Error serializing parallel description: " << parallel_conf.ShortDebugString(); return Maybe::Ok(); } Maybe Check(uint64_t symbol_id, Symbol parallel_desc) const { const auto& parallel_conf = parallel_desc->parallel_conf(); int64_t byte_size = parallel_conf.ByteSize(); const auto& debugString = parallel_conf.ShortDebugString(); CHECK_LE_OR_RETURN(byte_size, kLimitParallelConfString) << Error::InvalidValueError() << "Byte size of parallel description should be less than " << kLimitParallelConfString << ", but got " << byte_size; CHECK_EQ_OR_RETURN(this->symbol_id, symbol_id) << Error::RuntimeError() << "expected symbol id " << symbol_id << ", but got " << this->symbol_id; CHECK_EQ_OR_RETURN(this->buffer_size, byte_size) << Error::RuntimeError() << "Inconsistent parallel description: " << debugString; std::vector serialized(byte_size); CHECK_OR_RETURN(parallel_conf.SerializeToArray(serialized.data(), kLimitParallelConfString)) << Error::RuntimeError() << "Error serializing parallel description: " << debugString; CHECK_EQ_OR_RETURN(std::memcmp(serialized.data(), this->buffer, byte_size), 0) << Error::RuntimeError() << "Inconsistent parallel description: " << debugString; return Maybe::Ok(); } uint64_t symbol_id; uint64_t buffer_size; char buffer[kLimitParallelConfString]; }; } // namespace Maybe SyncSymbolParallelDesc(uint64_t symbol_id, Symbol parallel_desc) { const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncSymbolParallelDesc)); const auto& recv_buffer = std::make_shared(); NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { const auto& send_buffer = JUST(FlatParallelConf::New(symbol_id, parallel_desc)); *buffer = send_buffer.get(); *size = send_buffer->available_size(); *Cb = [send_buffer] {}; return Maybe::Ok(); }, [recv_buffer](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_buffer.get(); *size = recv_buffer->capacity(); *Cb = [recv_buffer] {}; return Maybe::Ok(); }); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); 
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg); JUST(recv_buffer->Check(symbol_id, parallel_desc)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/sync_symbol_parallel_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_ #define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/framework/transport_token.h" namespace oneflow { class ParallelDesc; Maybe SyncSymbolParallelDesc(uint64_t symbol_id, Symbol); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_ ================================================ FILE: oneflow/core/framework/synced_symbol_map.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/synced_symbol_map.h" namespace oneflow { uint64_t GetAutoIncrementalSymbolId() { static thread_local uint64_t id = 4096; return id++; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/synced_symbol_map.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
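FlatParallelConf above packs a variable-length serialized payload into a fixed-capacity trailing byte array and uses available_size() to report how many bytes of the struct actually need to travel. The following self-contained sketch illustrates that layout with an invented struct and a std::string payload in place of the protobuf serialization.

// Illustrative fixed-capacity "flat conf" carrying a variable-length payload.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>

constexpr size_t kExampleLimit = 1024;

struct ExampleFlatConf {
  uint64_t symbol_id;
  uint64_t buffer_size;  // number of valid bytes in `buffer`
  char buffer[kExampleLimit];

  // Bytes worth transferring: the fixed header plus only the used part of `buffer`.
  size_t available_size() const { return sizeof(ExampleFlatConf) - kExampleLimit + buffer_size; }

  void Init(uint64_t id, const std::string& serialized) {
    if (serialized.size() > kExampleLimit) { throw std::length_error("payload too large"); }
    symbol_id = id;
    buffer_size = serialized.size();
    std::memcpy(buffer, serialized.data(), serialized.size());
  }

  void Check(uint64_t id, const std::string& serialized) const {
    if (symbol_id != id || buffer_size != serialized.size()
        || std::memcmp(buffer, serialized.data(), serialized.size()) != 0) {
      throw std::runtime_error("inconsistent flat conf");
    }
  }
};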
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_ #define ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/type_traits.h" #include "oneflow/core/job/rank_group_scope.h" namespace oneflow { uint64_t GetAutoIncrementalSymbolId(); template struct SyncedSymbolMap final { template static Maybe FindOrSync(Symbol symbol, const SyncT& Sync) { auto* map = JUST(MutThreadLocalSymbol2SyncedSymbolId()); const auto& iter = map->find(symbol); if (iter != map->end()) { return iter->second; } uint64_t symbol_id = GetAutoIncrementalSymbolId(); JUST(Sync(symbol_id, symbol)); JUST(Emplace(symbol_id, symbol)); return symbol_id; } static Maybe> Symbol4SyncedSymbolId(uint64_t synced_symbol_id) { auto* map = JUST(MutThreadLocalSyncedSymbolId2Symbol()); return JUST(MapAt(*map, synced_symbol_id)); } private: static Maybe Emplace(uint64_t synced_symbol_id, Symbol symbol) { auto* id2symbol = JUST(MutThreadLocalSyncedSymbolId2Symbol()); CHECK_OR_RETURN(id2symbol->emplace(synced_symbol_id, symbol).second); auto* symbol2id = JUST(MutThreadLocalSymbol2SyncedSymbolId()); CHECK_OR_RETURN(symbol2id->emplace(symbol, synced_symbol_id).second); return Maybe::Ok(); } static Maybe>*> MutThreadLocalSyncedSymbolId2Symbol() { static thread_local auto* map = new std::unordered_map, std::unordered_map>>(); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); return &(*map)[rank_group]; } static Maybe, uint64_t>*> MutThreadLocalSymbol2SyncedSymbolId() { static thread_local auto* map = new std::unordered_map, std::unordered_map, uint64_t>>(); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); return &(*map)[rank_group]; } }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_ ================================================ FILE: oneflow/core/framework/tensor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
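SyncedSymbolMap above keeps two thread-local maps so a symbol can be looked up by its synced id and vice versa, and FindOrSync only pays the synchronization cost the first time a symbol is seen. The self-contained sketch below mirrors that bookkeeping with std::string standing in for Symbol<T>; it is not OneFlow code, and the starting id of 4096 simply mirrors the auto-incremental counter above.

// Illustrative two-way map with lazy, sync-on-first-use registration.
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

class ExampleSyncedIdMap {
 public:
  uint64_t FindOrSync(const std::string& symbol,
                      const std::function<void(uint64_t, const std::string&)>& sync) {
    auto it = symbol2id_.find(symbol);
    if (it != symbol2id_.end()) { return it->second; }  // already synced, no extra cost
    uint64_t id = next_id_++;                           // ids handed out in increasing order
    sync(id, symbol);                                   // e.g. agree on (id, symbol) with peers
    id2symbol_.emplace(id, symbol);
    symbol2id_.emplace(symbol, id);
    return id;
  }
  const std::string& SymbolForId(uint64_t id) const { return id2symbol_.at(id); }

 private:
  uint64_t next_id_ = 4096;
  std::unordered_map<uint64_t, std::string> id2symbol_;
  std::unordered_map<std::string, uint64_t> symbol2id_;
};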
*/ #include "oneflow/core/framework/tensor.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/framework/tensor_methods.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/job_build_and_infer_ctx.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/framework/op_interpreter/eager_local_op_interpreter.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/virtual_machine.h" namespace oneflow { namespace one { Maybe Tensor::BorrowTensorName(const Tensor* other) const { CHECK_OR_RETURN(other->is_lazy()) << Error::RuntimeError() << "can not borrow tensor name from an eager tensor"; const auto& lbn = TensorNameScope::Global()->Lookup(other); CHECK_OR_RETURN(!lbn.empty()) << "the input lazy tensor has no tensor name"; TensorNameScope::Global()->Record(this, lbn); return Maybe::Ok(); } Maybe Tensor::set_ref_tensor(const std::shared_ptr& ref) { ref_tensor_ = ref; return Maybe::Ok(); } Maybe Tensor::set_ref_index(const int64_t index) { ref_index_ = index; return Maybe::Ok(); } Maybe StaticZerosTensor::AsLocalTensor() { CHECK_OR_RETURN(is_local()); // NOLINT return std::dynamic_pointer_cast( JUST(functional::Constant(*shape_, Scalar(0), CHECK_JUST(DType::Get(dtype_)), device_))); } Parameter::Parameter(const std::shared_ptr& tensor, bool requires_grad) : ProxyTensor(tensor) { CHECK_JUST(this->tensor_->set_requires_grad(requires_grad)); if (tensor->is_local() && tensor->is_eager()) { if (auto rematable_storage = std::dynamic_pointer_cast( CHECK_JUST(tensor_->eager_blob_object())->tensor_storage()); rematable_storage != nullptr && tensor_->is_local() && tensor_->is_eager()) { rematable_storage->set_eviction_disabled(true); } } } Maybe Parameter::set_data(const std::shared_ptr& other) { if (is_local() && is_eager()) { auto rematable_storage = std::dynamic_pointer_cast( CHECK_JUST(tensor_->eager_blob_object())->tensor_storage()); bool enable_remat = rematable_storage != nullptr && tensor_->is_local() && tensor_->is_eager(); if (enable_remat) { rematable_storage->set_eviction_disabled(false); } JUST(tensor_->set_data(other)); if (enable_remat) { rematable_storage->set_eviction_disabled(true); } } else { JUST(tensor_->set_data(other)); } return Maybe::Ok(); } std::shared_ptr Parameter::contiguous() const { const auto& tensor = std::const_pointer_cast(shared_from_this()); if (tensor_->is_contiguous()) { return tensor; } return CHECK_JUST(functional::ToContiguous(tensor)); } std::shared_ptr Parameter::pin_memory() const { std::shared_ptr tensor = std::const_pointer_cast(shared_from_this()); return CHECK_JUST(functional::PinMemory(tensor)); } /* static */ Maybe LocalTensor::MakeTensor(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype, MemoryFormat memory_format, const Symbol& device, bool is_lazy, bool requires_grad, bool is_leaf) { const auto& tensor_meta = SymbolOf(LocalTensorMeta(*shape, dtype, memory_format, device)); if (is_lazy) { const auto& impl = std::make_shared(tensor_meta, requires_grad, is_leaf); return std::make_shared(impl); 
} else { const auto& impl = std::make_shared(requires_grad, is_leaf); const auto& dep_object = NewLocalDepObject(); JUST(impl->InitEagerBlobObject(tensor_meta, dep_object)); return std::make_shared(impl); } } bool LocalTensor::is_cpu() const { return CHECK_JUST(device())->type() == "cpu"; } bool LocalTensor::is_cuda() const { return CHECK_JUST(device())->type() == "cuda"; } Maybe LocalTensor::detach() const { std::shared_ptr tensor = std::make_shared(JUST(impl_->detach())); if (this->is_lazy()) { JUST(tensor->BorrowTensorName(this)); } return tensor; } std::shared_ptr LocalTensor::contiguous() const { std::shared_ptr tensor = std::const_pointer_cast(shared_from_this()); if (tensor->is_contiguous()) { return tensor; } return CHECK_JUST(functional::ToContiguous(tensor)); } std::shared_ptr LocalTensor::pin_memory() const { std::shared_ptr tensor = std::const_pointer_cast(shared_from_this()); return CHECK_JUST(functional::PinMemory(tensor)); } Maybe LocalTensor::clone() const { std::shared_ptr input = std::const_pointer_cast(shared_from_this()); const bool pin_memory = JUST(JUST(input->AsLocalTensor())->is_pinned()); return JUST(functional::Copy(input, JUST(this->device()), /*pin_memory=*/pin_memory)); } Maybe LocalTensor::set_data(const std::shared_ptr& other) { CHECK_OR_RETURN(this->is_leaf()) << "Can only set leaf tensor's data."; const auto& mirrored_tensor = std::dynamic_pointer_cast(JUST(other->detach())); CHECK_NOTNULL_OR_RETURN(mirrored_tensor) << "Can not set a global tensor to the data of a local tensor"; bool old_requires_grad = requires_grad(); impl_ = mirrored_tensor->impl_; JUST(set_requires_grad(old_requires_grad)); grad_fn_node_ = nullptr; if (other->is_lazy()) { JUST(this->BorrowTensorName(other.get())); } return Maybe::Ok(); } #define TENSOR_OFFLOAD_CHECK(is_offloaded, msg) \ if (is_cpu()) { \ LOG(WARNING) << "Only non-cpu tensor can be offloaded."; \ return Maybe::Ok(); \ } \ if (is_offloaded_ != is_offloaded) { \ LOG(WARNING) << "This tensor has already be " << msg << "."; \ return Maybe::Ok(); \ } Maybe LocalTensor::offload() { TENSOR_OFFLOAD_CHECK(false, "offloaded"); // Offload to cpu mem with a cpu tensor implantation. int64_t device_id = JUST(this->device())->device_id(); std::shared_ptr cuda_tensor = shared_from_this(); auto offloaded_tensor = JUST(functional::Copy(cuda_tensor, "cpu", device_id, /*pin_memory=*/JUST(is_pinned()))); JUST(vm::CurrentRankSync()); const auto& detached_tensor = std::dynamic_pointer_cast(JUST(offloaded_tensor->detach())); CHECK_NOTNULL_OR_RETURN(detached_tensor) << " detached_tensor must be a local tensor."; offloaded_impl_ = detached_tensor->impl_; // Release cuda memory, but the meta data is valid. auto eager_blob_obj = JUST(JUST(impl_->mut_eager_local_tensor_impl())->eager_blob_object()); JUST(eager_blob_obj->DeallocateBlobDataPtr()); auto* vm = JUST(SingletonMaybe()); JUST(vm->ShrinkAllMem()); is_offloaded_ = true; return Maybe::Ok(); } Maybe LocalTensor::load() { TENSOR_OFFLOAD_CHECK(true, "loaded"); // Load cpu to cuda. int64_t device_id = JUST(this->device())->device_id(); std::shared_ptr cpu_tensor = std::make_shared(offloaded_impl_); auto loaded_tensor = JUST(functional::Copy(cpu_tensor, "cuda", device_id, /*pin_memory=*/JUST(cpu_tensor->is_pinned()))); JUST(vm::CurrentRankSync()); JUST(set_data(loaded_tensor)); // Release cpu memory. 
cpu_tensor.reset(); offloaded_impl_.reset(); auto* vm = JUST(SingletonMaybe()); JUST(vm->ShrinkAllMem()); is_offloaded_ = false; return Maybe::Ok(); } std::shared_ptr GlobalTensor::contiguous() const { std::shared_ptr tensor = std::const_pointer_cast(shared_from_this()); if (tensor->is_contiguous()) { return tensor; } return CHECK_JUST(functional::ToContiguous(tensor)); } std::shared_ptr GlobalTensor::pin_memory() const { std::shared_ptr tensor = std::const_pointer_cast(shared_from_this()); return CHECK_JUST(functional::PinMemory(tensor)); } Maybe GlobalTensor::clone() const { std::shared_ptr input = std::const_pointer_cast(shared_from_this()); DisableCheckGlobalTensorMetaScope disable_meta_check{}; return JUST(functional::ToGlobal(input, JUST(parallel_desc()), *JUST(GetSbpList(JUST(nd_sbp()))), /*grad_sbp_parallels=*/{}, /* sync_data */ true, /*copy=*/true)); } Maybe GlobalTensor::MakeTensor(const std::shared_ptr& shape, DataType dtype, MemoryFormat memory_format, Symbol nd_sbp, Symbol parallel_desc, bool is_lazy, bool requires_grad, bool is_leaf) { std::shared_ptr impl; Symbol global_tensor_meta( GlobalTensorMeta(*shape, dtype, memory_format, nd_sbp, parallel_desc)); if (is_lazy) { impl = std::make_shared(global_tensor_meta, requires_grad, is_leaf); } else { impl = JUST(EagerGlobalTensorImpl::New(global_tensor_meta, requires_grad, is_leaf)); } return std::make_shared(impl); } bool GlobalTensor::is_cpu() const { return CHECK_JUST(parallel_desc())->device_type() == DeviceType::kCPU; } bool GlobalTensor::is_cuda() const { return CHECK_JUST(parallel_desc())->device_type() == DeviceType::kCUDA; } Maybe GlobalTensor::detach() const { std::shared_ptr tensor = std::make_shared(JUST(impl_->detach())); if (this->is_lazy()) { JUST(tensor->BorrowTensorName(this)); } return tensor; } Maybe GlobalTensor::set_data(const std::shared_ptr& other) { CHECK_OR_RETURN(this->is_leaf()) << "Only leaf tensor's data can be set, because non-leaf tensor's data has been captured in " "the backward graph in autograd."; const auto& global_tensor = std::dynamic_pointer_cast(JUST(other->detach())); CHECK_NOTNULL_OR_RETURN(global_tensor); // NOLINT JUST(WithConsistencyChecked(global_tensor, [&]() -> Maybe { return Maybe::Ok(); })); bool old_requires_grad = requires_grad(); impl_ = global_tensor->impl_; JUST(set_requires_grad(old_requires_grad)); grad_fn_node_ = nullptr; if (other->is_lazy()) { JUST(this->BorrowTensorName(other.get())); } return Maybe::Ok(); } Maybe GlobalTensor::offload() { TENSOR_OFFLOAD_CHECK(false, "offloaded"); // Offload to cpu mem with a cpu tensor implantation. std::shared_ptr cuda_tensor = shared_from_this(); auto offloaded_tensor = JUST(functional::Copy(cuda_tensor, "cpu", GlobalProcessCtx::LocalRank(), /*pin_memory=*/false)); JUST(vm::ClusterSync()); const auto& detached_tensor = std::dynamic_pointer_cast(JUST(offloaded_tensor->detach())); CHECK_NOTNULL_OR_RETURN(detached_tensor) << "detached_tensor must be a global tensor."; offloaded_impl_ = detached_tensor->impl_; // Release cuda memory, but the meta data is valid. auto eager_blob_obj = JUST(JUST(impl_->cur_rank_phy_tensor())->eager_blob_object()); JUST(eager_blob_obj->DeallocateBlobDataPtr()); auto* vm = JUST(SingletonMaybe()); JUST(vm->ShrinkAllMem()); is_offloaded_ = true; return Maybe::Ok(); } Maybe GlobalTensor::load() { TENSOR_OFFLOAD_CHECK(true, "loaded"); // Load cpu to cuda. 
std::shared_ptr cpu_tensor = std::make_shared(offloaded_impl_); auto loaded_tensor = JUST(functional::Copy(cpu_tensor, "cuda", GlobalProcessCtx::LocalRank(), /*pin_memory=*/false)); JUST(vm::ClusterSync()); JUST(set_data(loaded_tensor)); // Release cpu memory. cpu_tensor.reset(); offloaded_impl_.reset(); auto* vm = JUST(SingletonMaybe()); JUST(vm->ShrinkAllMem()); is_offloaded_ = false; return Maybe::Ok(); } #undef TENSOR_OFFLOAD_CHECK } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_H_ #include #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/memory/memory_case.pb.h" #include "oneflow/core/framework/tensor_impl.h" #include "oneflow/core/framework/transport_token.h" #include "oneflow/core/common/error.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/job/global_mode.h" namespace oneflow { class NdSbp; class Device; namespace one { class FunctionNode; class GlobalTensor; class LocalTensor; class Tensor : public std::enable_shared_from_this { public: virtual ~Tensor() = default; // Getters int64_t dim(int64_t index) const { return shape()->At(index); } int64_t nelement() const { return shape()->elem_cnt(); } int64_t ndim() const { return shape()->NumAxes(); } Maybe ref_tensor() const { return ref_tensor_.lock(); } int64_t ref_index() const { return ref_index_; } virtual std::shared_ptr shape() const = 0; virtual Symbol dtype() const = 0; virtual Maybe transport_token() const = 0; virtual Maybe> nd_sbp() const = 0; virtual Maybe> parallel_desc() const = 0; virtual Maybe> device() const = 0; virtual Maybe*> mut_device() = 0; virtual bool is_cpu() const = 0; virtual bool is_cuda() const = 0; virtual bool is_global() const = 0; virtual bool is_local() const { return !is_global(); } virtual bool is_lazy() const = 0; virtual bool is_eager() const { return !is_lazy(); } virtual bool is_contiguous() const = 0; virtual bool is_view() const = 0; virtual Maybe is_pinned() const = 0; virtual const TensorMeta& tensor_meta() const = 0; virtual Maybe data() = 0; virtual std::shared_ptr pin_memory() const = 0; virtual Maybe> local_tensor_meta() const { OF_UNIMPLEMENTED(); } virtual Maybe> global_tensor_meta() const { OF_UNIMPLEMENTED(); } // Getters valid only for EagerLocalTensor virtual Maybe mut_eager_local_tensor_impl() { OF_UNIMPLEMENTED(); } virtual Maybe eager_blob_object() const = 0; virtual Maybe compute_local_dep_object() const = 0; virtual Maybe has_eager_blob_object() const = 0; virtual Maybe tensor_storage() const { OF_UNIMPLEMENTED(); } virtual Maybe stride() const { OF_UNIMPLEMENTED(); } virtual Maybe storage_offset() const { OF_UNIMPLEMENTED(); } virtual 
MemoryFormat memory_format() const = 0; // Getters/Setters valid only for EagerGlobalTensor virtual Maybe>&> consumer_nd_sbp_constraint() const { OF_UNIMPLEMENTED(); } virtual Maybe cur_rank_phy_tensor() const { OF_UNIMPLEMENTED(); } virtual Maybe set_consumer_nd_sbp_constraint(const Optional>& val) { OF_UNIMPLEMENTED(); } // Getters for autograd virtual bool requires_grad() const = 0; virtual bool is_leaf() const = 0; virtual bool retain_grad() const = 0; virtual std::shared_ptr grad_fn_node() const = 0; virtual int32_t get_grad_fn_output_index() const = 0; virtual Maybe acc_grad() const = 0; virtual Maybe current_grad() const = 0; virtual Maybe detach() const = 0; virtual Maybe clone() const = 0; virtual std::shared_ptr contiguous() const = 0; // Setters for autograd virtual Maybe set_requires_grad(bool requires_grad) = 0; virtual Maybe set_retain_grad(bool retain_grad) = 0; virtual void set_grad_fn_node(const std::shared_ptr& grad_fn_node) = 0; virtual std::shared_ptr mut_grad_fn_node() = 0; virtual void set_grad_fn_output_index(int32_t idx) = 0; virtual Maybe set_acc_grad(const std::shared_ptr& grad) = 0; virtual Maybe mut_acc_grad() = 0; virtual void set_is_leaf(bool is_leaf) = 0; virtual std::shared_ptr autograd_meta() const = 0; virtual std::shared_ptr mut_autograd_meta() = 0; virtual void set_autograd_meta(const std::shared_ptr& autograd_meta) = 0; virtual user_op::TensorDesc* mut_tensor_meta() = 0; virtual Maybe set_data(const std::shared_ptr& other) = 0; // For offloading between devices virtual Maybe offload() = 0; virtual Maybe load() = 0; virtual Maybe is_offloaded() const = 0; virtual Maybe RegisterStorageDeleteHook(const std::function& hook) { OF_UNIMPLEMENTED(); }; virtual Maybe AsLocalTensor() = 0; virtual Maybe AsGlobalTensor() = 0; Maybe BorrowTensorName(const Tensor* other) const; Maybe set_ref_tensor(const std::shared_ptr& ref); Maybe set_ref_index(const int64_t index); // The same tensor instance should share the python object to ensure that // their id are consistent in Python. 
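// ---------------------------------------------------------------------------
// A minimal usage sketch for the offload()/load()/is_offloaded() interface
// declared above (not part of the original header; assumes an eager CUDA
// build and a local tensor `t` currently resident on the GPU):
//
//   JUST(t->offload());                         // copy data to host, free the device buffer
//   bool offloaded = JUST(t->is_offloaded());   // true
//   // ... run something else that needs the GPU memory ...
//   JUST(t->load());                            // copy data back to "cuda" on this rank
//
// As shown in tensor.cpp above, load() also synchronizes the VM and calls
// vm->ShrinkAllMem() so the temporary host copy is returned to the allocator.
// ---------------------------------------------------------------------------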
That is if x and y are hold the same tensor, // then `id(x)` should equal to `id(y)` void* pyobject() const { return pyobj_ptr_.get(); } void set_pyobject_ptr(std::unique_ptr&& pyobj_ptr) { pyobj_ptr_ = std::move(pyobj_ptr); } bool owns_pyobj() const { return owns_pyobj_; } void set_owns_pyobj(bool owns_pyobj) { owns_pyobj_ = owns_pyobj; } protected: Tensor() : pyobj_ptr_(nullptr, [](void*) {}), owns_pyobj_(false), ref_tensor_(std::weak_ptr()), ref_index_(0) {} private: std::unique_ptr pyobj_ptr_; bool owns_pyobj_; std::weak_ptr ref_tensor_; int64_t ref_index_; }; class StaticZerosTensor final : public Tensor { public: static Maybe MakeTensor(const std::shared_ptr& shape, DataType dtype, MemoryFormat memory_format, Symbol device) { return std::shared_ptr( new StaticZerosTensor(shape, dtype, memory_format, device)); } // Getters std::shared_ptr shape() const override { return shape_; } Symbol dtype() const override { return CHECK_JUST(DType::Get(dtype_)); } Maybe transport_token() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe> nd_sbp() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe> parallel_desc() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe> device() const override { return device_; } Maybe*> mut_device() override { RETURN_ERROR_WITH_BUG_PROMPT(); } bool is_cpu() const override { PRINT_BUG_PROMPT_AND_ABORT(); return false; } bool is_cuda() const override { PRINT_BUG_PROMPT_AND_ABORT(); return false; } bool is_global() const override { return false; } bool is_local() const override { return !is_global(); } bool is_lazy() const override { PRINT_BUG_PROMPT_AND_ABORT(); return false; } bool is_eager() const override { return !is_lazy(); } const TensorMeta& tensor_meta() const override { PRINT_BUG_PROMPT_AND_ABORT(); return *(TensorMeta*)nullptr; } Maybe data() override { RETURN_ERROR_WITH_BUG_PROMPT(); } std::shared_ptr pin_memory() const override { return std::const_pointer_cast(shared_from_this()); } Maybe> local_tensor_meta() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe> global_tensor_meta() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } // Getters valid only for EagerLocalTensor Maybe mut_eager_local_tensor_impl() override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe compute_local_dep_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe has_eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe tensor_storage() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe stride() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe storage_offset() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } MemoryFormat memory_format() const override { return memory_format_; } // Getters/Setters valid only for EagerGlobalTensor Maybe>&> consumer_nd_sbp_constraint() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe cur_rank_phy_tensor() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe set_consumer_nd_sbp_constraint(const Optional>& val) override { RETURN_ERROR_WITH_BUG_PROMPT(); } // Getters for autograd bool requires_grad() const override { PRINT_BUG_PROMPT_AND_ABORT(); return false; } bool is_leaf() const override { PRINT_BUG_PROMPT_AND_ABORT(); return false; } bool retain_grad() const override { PRINT_BUG_PROMPT_AND_ABORT(); return false; } bool is_contiguous() const override { PRINT_BUG_PROMPT_AND_ABORT(); return true; } bool is_view() const override { PRINT_BUG_PROMPT_AND_ABORT(); return false; } Maybe is_pinned() const 
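// ---------------------------------------------------------------------------
// A construction sketch for StaticZerosTensor above (not from the original
// source; the shape, dtype and memory format values are illustrative).
// The object only records shape/dtype/memory_format/device metadata, so it is
// a cheap stand-in for an all-zero tensor; AsLocalTensor() converts it into a
// real local tensor when actual storage is needed:
//
//   const auto shape = std::make_shared<Shape>(DimVector{2, 3});
//   const auto device = JUST(Device::New("cpu"));
//   const auto zeros = JUST(StaticZerosTensor::MakeTensor(
//       shape, DataType::kFloat, MemoryFormat::kContiguous, device));
//   const auto local = JUST(zeros->AsLocalTensor());
// ---------------------------------------------------------------------------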
override { RETURN_ERROR_WITH_BUG_PROMPT(); } std::shared_ptr grad_fn_node() const override { PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } int32_t get_grad_fn_output_index() const override { PRINT_BUG_PROMPT_AND_ABORT(); return 0; } Maybe acc_grad() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe current_grad() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe detach() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe clone() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } std::shared_ptr contiguous() const override { return std::const_pointer_cast(shared_from_this()); } // Setters for autograd Maybe set_requires_grad(bool requires_grad) override { PRINT_BUG_PROMPT_AND_ABORT(); return Maybe::Ok(); } Maybe set_retain_grad(bool retain_grad) override { RETURN_ERROR_WITH_BUG_PROMPT(); return Maybe::Ok(); } void set_grad_fn_node(const std::shared_ptr& grad_fn_node) override { PRINT_BUG_PROMPT_AND_ABORT(); } void set_grad_fn_output_index(int32_t idx) override { PRINT_BUG_PROMPT_AND_ABORT(); } std::shared_ptr mut_grad_fn_node() override { PRINT_BUG_PROMPT_AND_ABORT(); return *(std::shared_ptr*)nullptr; } Maybe set_acc_grad(const std::shared_ptr& grad) override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe mut_acc_grad() override { RETURN_ERROR_WITH_BUG_PROMPT(); } void set_is_leaf(bool is_leaf) override { PRINT_BUG_PROMPT_AND_ABORT(); } std::shared_ptr autograd_meta() const override { PRINT_BUG_PROMPT_AND_ABORT(); } std::shared_ptr mut_autograd_meta() override { PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } void set_autograd_meta(const std::shared_ptr& autograd_meta) override { PRINT_BUG_PROMPT_AND_ABORT(); } user_op::TensorDesc* mut_tensor_meta() override { PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } Maybe set_data(const std::shared_ptr& other) override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe offload() override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe load() override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe is_offloaded() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe AsLocalTensor() override; Maybe AsGlobalTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); } private: StaticZerosTensor(const std::shared_ptr& shape, DataType dtype, MemoryFormat memory_format, Symbol device) : shape_(shape), dtype_(dtype), memory_format_(memory_format), device_(device) {} const std::shared_ptr shape_; DataType dtype_; MemoryFormat memory_format_; Symbol device_; }; template class TensorIf : public Tensor { public: virtual ~TensorIf() = default; // Getters for autograd // acc_grad is tensor's accumulated grad in more than once backward operation, // and current_grad is temporary grad to shared data with different FunctionNode std::shared_ptr grad_fn_node() const override { return grad_fn_node_; } int32_t get_grad_fn_output_index() const override { return grad_fn_output_index_; } // Setters for autograd void set_grad_fn_node(const std::shared_ptr& grad_fn_node) override { grad_fn_node_ = grad_fn_node; } std::shared_ptr mut_grad_fn_node() override { return grad_fn_node_; } void set_grad_fn_output_index(int32_t idx) override { grad_fn_output_index_ = idx; } protected: TensorIf() = default; std::shared_ptr grad_fn_node_; int32_t grad_fn_output_index_ = -1; }; template class ProxyTensor : public TensorIf { public: ProxyTensor(const std::shared_ptr& tensor) : tensor_(tensor) { if (tensor->is_lazy()) { CHECK_JUST(this->BorrowTensorName(tensor.get())); } } virtual ~ProxyTensor() = default; virtual std::shared_ptr shape() const override { return tensor_->shape(); } virtual Symbol 
dtype() const override { return tensor_->dtype(); } virtual Maybe> nd_sbp() const override { return tensor_->nd_sbp(); } virtual Maybe> parallel_desc() const override { return tensor_->parallel_desc(); } virtual Maybe> device() const override { return tensor_->device(); } virtual Maybe*> mut_device() override { return tensor_->mut_device(); } virtual bool is_cpu() const override { return tensor_->is_cpu(); } virtual bool is_cuda() const override { return tensor_->is_cuda(); } virtual bool is_global() const override { return tensor_->is_global(); } virtual bool is_local() const override { return tensor_->is_local(); } virtual bool is_lazy() const override { return tensor_->is_lazy(); } virtual bool is_eager() const override { return tensor_->is_eager(); } virtual const TensorMeta& tensor_meta() const override { return tensor_->tensor_meta(); } virtual Maybe> local_tensor_meta() const override { return tensor_->local_tensor_meta(); } virtual Maybe> global_tensor_meta() const override { return tensor_->global_tensor_meta(); } virtual Maybe data() override { return tensor_->detach(); } virtual std::shared_ptr pin_memory() const override { return tensor_->pin_memory(); } // Must override grad_fn_node function. Otherwise grad_fn will belong to this not tensor_, // and it will be wrong when use Tensor.data() in operators. virtual std::shared_ptr grad_fn_node() const override { return tensor_->grad_fn_node(); } virtual void set_grad_fn_node(const std::shared_ptr& grad_fn_node) override { tensor_->set_grad_fn_node(grad_fn_node); } virtual std::shared_ptr mut_grad_fn_node() override { return tensor_->mut_grad_fn_node(); } virtual Maybe mut_eager_local_tensor_impl() override { return tensor_->mut_eager_local_tensor_impl(); } virtual Maybe eager_blob_object() const override { return tensor_->eager_blob_object(); } virtual Maybe compute_local_dep_object() const override { return tensor_->compute_local_dep_object(); } virtual Maybe has_eager_blob_object() const override { return tensor_->has_eager_blob_object(); } virtual Maybe tensor_storage() const override { return tensor_->tensor_storage(); } virtual Maybe stride() const override { return tensor_->stride(); } virtual Maybe storage_offset() const override { return tensor_->storage_offset(); } virtual MemoryFormat memory_format() const override { return tensor_->memory_format(); } virtual Maybe>&> consumer_nd_sbp_constraint() const override { return tensor_->consumer_nd_sbp_constraint(); } virtual Maybe transport_token() const override { return tensor_->transport_token(); } virtual Maybe cur_rank_phy_tensor() const override { return tensor_->cur_rank_phy_tensor(); } virtual Maybe set_consumer_nd_sbp_constraint(const Optional>& val) override { return tensor_->set_consumer_nd_sbp_constraint(val); } virtual bool requires_grad() const override { return tensor_->requires_grad(); } virtual bool is_leaf() const override { return tensor_->is_leaf(); } virtual bool retain_grad() const override { return tensor_->retain_grad(); } virtual bool is_contiguous() const override { return tensor_->is_contiguous(); } virtual bool is_view() const override { return tensor_->is_view(); } virtual Maybe is_pinned() const override { return tensor_->is_pinned(); } virtual Maybe acc_grad() const override { return tensor_->acc_grad(); } virtual Maybe current_grad() const override { return tensor_->current_grad(); } virtual Maybe detach() const override { return tensor_->detach(); } virtual Maybe clone() const override { return tensor_->clone(); } virtual Maybe 
set_requires_grad(bool requires_grad) override { return tensor_->set_requires_grad(requires_grad); } virtual Maybe set_retain_grad(bool retain_grad) override { return tensor_->set_retain_grad(retain_grad); } virtual Maybe set_acc_grad(const std::shared_ptr& grad) override { return tensor_->set_acc_grad(grad); } virtual Maybe mut_acc_grad() override { return tensor_->mut_acc_grad(); } virtual void set_is_leaf(bool is_leaf) override { return tensor_->set_is_leaf(is_leaf); } virtual std::shared_ptr autograd_meta() const override { return tensor_->autograd_meta(); } virtual std::shared_ptr mut_autograd_meta() override { return tensor_->mut_autograd_meta(); } virtual void set_autograd_meta(const std::shared_ptr& autograd_meta) override { return tensor_->set_autograd_meta(autograd_meta); } virtual user_op::TensorDesc* mut_tensor_meta() override { return tensor_->mut_tensor_meta(); } virtual Maybe set_data(const std::shared_ptr& other) override { bool old_requires_grad = tensor_->requires_grad(); this->tensor_ = JUST(other->detach()); JUST(this->tensor_->set_requires_grad(old_requires_grad)); if (other->is_lazy()) { JUST(this->BorrowTensorName(other.get())); } return Maybe::Ok(); } virtual Maybe offload() override { JUST(tensor_->offload()); return Maybe::Ok(); } virtual Maybe load() override { JUST(tensor_->load()); return Maybe::Ok(); } Maybe is_offloaded() const override { return JUST(tensor_->is_offloaded()); } virtual Maybe AsLocalTensor() override { if (const auto& local_tensor = std::dynamic_pointer_cast(tensor_)) { return local_tensor; } RETURN_ERROR_WITH_BUG_PROMPT(); } virtual Maybe AsGlobalTensor() override { if (const auto& global_tensor = std::dynamic_pointer_cast(tensor_)) { return global_tensor; } RETURN_ERROR_WITH_BUG_PROMPT(); } protected: std::shared_ptr tensor_; }; class Parameter final : public ProxyTensor { public: static Maybe MakeTensor(const std::shared_ptr& tensor, bool requires_grad) { return std::shared_ptr(new Parameter(JUST(tensor->detach()), requires_grad)); } bool is_leaf() const override { return true; } std::shared_ptr contiguous() const override; std::shared_ptr pin_memory() const override; Maybe set_data(const std::shared_ptr& other) override; private: Parameter(const std::shared_ptr& tensor, bool requires_grad); }; class LocalTensor final : public TensorIf { public: OF_DISALLOW_COPY_AND_MOVE(LocalTensor); LocalTensor() = default; explicit LocalTensor(const std::shared_ptr& impl) { impl_ = impl; } ~LocalTensor() override = default; // Getters std::shared_ptr shape() const override { return impl_->shape(); } Symbol dtype() const override { return CHECK_JUST(DType::Get(impl_->dtype())); } Maybe transport_token() const override { OF_RUNTIME_ERROR() << "Only global tensors have 'global_id', global id is used to " "synchronize rank"; } Maybe> nd_sbp() const override { OF_RUNTIME_ERROR() << "Local tensor has no sbp property. " "sbp is the description in the oneflow distributed case, you can refer to " "https://docs.oneflow.org/master/parallelism/03_global_tensor.html; " "For example, create a global tensor like this : 'x = oneflow.tensor((2,3, " "placement=oneflow.placement(\"cuda\", {0: 0}), sbp=oneflow.sbp.broadcast))', then " "'x.sbp' is 'oneflow.sbp.broadcast'"; } Maybe> parallel_desc() const override { OF_RUNTIME_ERROR() << "Only global tensors have 'placement'. Placement is used to describe " "the distribution of global tensor in multiple GPUs. 
Please use " "'.device' for local tensors."; } Maybe> device() const override { return impl_->device(); } Maybe*> mut_device() override { return impl_->mut_device(); } bool is_lazy() const override { return impl_->is_lazy(); } bool is_global() const override { return false; } bool is_cpu() const override; bool is_cuda() const override; std::shared_ptr contiguous() const override; const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); } Maybe data() override { return this->detach(); } std::shared_ptr pin_memory() const override; // Getters valid only for EagerLocalTensor Maybe eager_blob_object() const override { return impl_->eager_blob_object(); } Maybe compute_local_dep_object() const override { return impl_->compute_local_dep_object(); } Maybe tensor_storage() const override { return impl_->tensor_storage(); } Maybe has_eager_blob_object() const override { return impl_->has_eager_blob_object(); } Maybe stride() const override { return impl_->stride(); } Maybe storage_offset() const override { return impl_->storage_offset(); } MemoryFormat memory_format() const override { return impl_->memory_format(); } // Getters for autograd Maybe acc_grad() const override { return impl_->acc_grad(); } Maybe current_grad() const override { return impl_->current_grad(); } bool requires_grad() const override { return impl_->requires_grad(); } bool is_leaf() const override { return impl_->is_leaf(); } bool retain_grad() const override { return impl_->retain_grad(); } bool is_contiguous() const override { return impl_->is_contiguous(); } bool is_view() const override { return impl_->is_view(); } Maybe is_pinned() const override { return impl_->is_pinned(); }; Maybe> local_tensor_meta() const override { return impl_->tensor_meta(); } // Setters for autograd Maybe set_acc_grad(const std::shared_ptr& grad) override { if (!grad_fn_node_ && requires_grad()) { CHECK_OR_RETURN(is_leaf()) << "only leaf tensor may have no grad_fn"; AddAccumulateFunctionNode(shared_from_this()); } return impl_->set_acc_grad(grad); } Maybe set_requires_grad(bool requires_grad) override { JUST(impl_->set_requires_grad(requires_grad)); if (!requires_grad) { set_grad_fn_node(nullptr); } return Maybe::Ok(); } Maybe set_retain_grad(bool retain_grad) override { return impl_->set_retain_grad(retain_grad); } Maybe mut_acc_grad() override { return impl_->mut_acc_grad(); } void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); } std::shared_ptr autograd_meta() const override { return impl_->autograd_meta(); } std::shared_ptr mut_autograd_meta() override { return impl_->mut_autograd_meta(); } void set_autograd_meta(const std::shared_ptr& autograd_meta) override { impl_->set_autograd_meta(autograd_meta); } // Operators for tensor Maybe detach() const override; Maybe clone() const override; static Maybe MakeTensor(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype, MemoryFormat memory_format, const Symbol& device, bool is_lazy, bool requires_grad, bool is_leaf); LocalTensorImpl* mut_impl() { return impl_.get(); } Maybe mut_eager_local_tensor_impl() override { return impl_->mut_eager_local_tensor_impl(); } user_op::TensorDesc* mut_tensor_meta() override { return std::const_pointer_cast(impl_->mut_tensor_meta()).get(); } Maybe set_data(const std::shared_ptr& other) override; Maybe offload() override; Maybe load() override; Maybe is_offloaded() const override { return is_offloaded_; } Maybe set_impl(std::shared_ptr impl) { impl_ = impl; return Maybe::Ok(); } Maybe 
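// ---------------------------------------------------------------------------
// A usage sketch for Parameter and ProxyTensor::set_data() above (not part of
// the original header; `weight` and `new_value` are illustrative tensors):
//
//   // Wrap a detached copy of `weight` as a leaf tensor that requires grad.
//   const auto param = JUST(Parameter::MakeTensor(weight, /*requires_grad=*/true));
//
//   // Re-bind to a new value: ProxyTensor::set_data() stores a detached copy
//   // of `new_value` and carries the previous requires_grad flag over, so the
//   // wrapper keeps its autograd configuration across the swap.
//   JUST(param->set_data(new_value));
// ---------------------------------------------------------------------------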
RegisterStorageDeleteHook(const std::function& hook) override { return impl_->RegisterStorageDeleteHook(hook); } Maybe AsLocalTensor() override { return std::dynamic_pointer_cast(shared_from_this()); } Maybe AsGlobalTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); } private: std::shared_ptr impl_; std::shared_ptr offloaded_impl_; bool is_offloaded_{false}; }; class GlobalTensor final : public TensorIf { public: OF_DISALLOW_COPY_AND_MOVE(GlobalTensor); GlobalTensor() = default; explicit GlobalTensor(const std::shared_ptr& impl) { impl_ = impl; } ~GlobalTensor() override = default; // Getters std::shared_ptr shape() const override { return impl_->shape(); } Symbol dtype() const override { return CHECK_JUST(DType::Get(impl_->dtype())); } Maybe transport_token() const override { return impl_->transport_token(); } Maybe> nd_sbp() const override { return impl_->nd_sbp(); } Maybe> parallel_desc() const override { return impl_->parallel_desc(); } Maybe> device() const override { if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); const auto& device_tag = JUST(parallel_desc())->device_tag(); return JUST(Device::New(device_tag)); } OF_RUNTIME_ERROR() << "Only local tensors have 'device'. Please use " "'.placement' for global tensors."; } Maybe*> mut_device() override { OF_RUNTIME_ERROR() << "GlobalTensor has no mut_device property"; } bool is_lazy() const override { return impl_->is_lazy(); } bool is_global() const override { return true; } Maybe>&> consumer_nd_sbp_constraint() const override { return impl_->consumer_nd_sbp_constraint(); } Maybe cur_rank_phy_tensor() const override { return impl_->cur_rank_phy_tensor(); } bool is_cpu() const override; bool is_cuda() const override; std::shared_ptr contiguous() const override; Maybe data() override { return this->detach(); } Maybe stride() const override { return impl_->stride(); } MemoryFormat memory_format() const override { return impl_->memory_format(); } std::shared_ptr pin_memory() const override; // Getters valid only for EagerLocalTensor Maybe eager_blob_object() const override { return impl_->eager_blob_object(); } Maybe compute_local_dep_object() const override { return impl_->compute_local_dep_object(); } const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); } Maybe tensor_storage() const override { return impl_->tensor_storage(); } Maybe has_eager_blob_object() const override { return impl_->has_eager_blob_object(); } // Setters Maybe set_consumer_nd_sbp_constraint(const Optional>& val) override { impl_->set_consumer_nd_sbp_constraint(val); return Maybe::Ok(); } // Getters for autograd Maybe acc_grad() const override { return impl_->acc_grad(); } Maybe current_grad() const override { return impl_->current_grad(); } bool requires_grad() const override { return impl_->requires_grad(); } bool is_leaf() const override { return impl_->is_leaf(); } bool retain_grad() const override { return impl_->retain_grad(); } bool is_contiguous() const override { return impl_->is_contiguous(); } bool is_view() const override { return impl_->is_view(); } Maybe is_pinned() const override { OF_RUNTIME_ERROR() << "Global tensor has no is_pinned method"; } // Setters for autograd Maybe set_acc_grad(const std::shared_ptr& grad) override { if (!grad_fn_node_ && requires_grad()) { CHECK_OR_RETURN(is_leaf()) << "only leaf tensor may have no grad_fn"; AddAccumulateFunctionNode(shared_from_this()); } return impl_->set_acc_grad(grad); } Maybe mut_acc_grad() override { return impl_->mut_acc_grad(); } Maybe 
set_requires_grad(bool requires_grad) override { JUST(impl_->set_requires_grad(requires_grad)); if (!requires_grad) { set_grad_fn_node(nullptr); } return Maybe::Ok(); } Maybe set_retain_grad(bool retain_grad) override { return impl_->set_retain_grad(retain_grad); } void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); } std::shared_ptr autograd_meta() const override { return impl_->autograd_meta(); } std::shared_ptr mut_autograd_meta() override { return impl_->mut_autograd_meta(); } void set_autograd_meta(const std::shared_ptr& autograd_meta) override { impl_->set_autograd_meta(autograd_meta); } // Operators for tensor Maybe detach() const override; Maybe clone() const override; static Maybe MakeTensor(const std::shared_ptr& shape, DataType dtype, MemoryFormat memory_format, Symbol nd_sbp, Symbol parallel_desc, bool is_lazy, bool requires_grad, bool is_leaf); GlobalTensorImpl* mut_impl() { return impl_.get(); } Maybe> global_tensor_meta() const override { return impl_->tensor_meta(); } user_op::TensorDesc* mut_tensor_meta() override { return impl_->mut_tensor_meta(); } Maybe set_data(const std::shared_ptr& other) override; Maybe offload() override; Maybe load() override; Maybe is_offloaded() const override { return is_offloaded_; } Maybe AsLocalTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe AsGlobalTensor() override { return std::dynamic_pointer_cast(shared_from_this()); } private: std::shared_ptr impl_; std::shared_ptr offloaded_impl_; bool is_offloaded_{false}; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_H_ ================================================ FILE: oneflow/core/framework/tensor_arg.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/tensor_arg.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { bool TensorArg::Empty() const { return !acc_tensor_; } void TensorArg::Release() { acc_tensor_.reset(); } Maybe TensorArg::PushPartialTensor(const std::shared_ptr& partial_tensor) { if (!acc_tensor_) { acc_tensor_ = partial_tensor; } else { // Should not inplace accumulate grad. For example, // >>> z = x + y // >>> p = x / z // >>> p.sum().backward() // // As we know that dx = dz + dp / z and dy = dz, so it will lead to wrong value // for dy if dx is shared with dz. acc_tensor_ = JUST(functional::Add(partial_tensor, acc_tensor_, /*alpha=*/1, /*inplace=*/false)); } return Maybe::Ok(); } Maybe TensorArg::GetAccTensor() const { CHECK_OR_RETURN(Empty() == false) << "Can not GetAccTensor because it is empty"; return acc_tensor_; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor_arg.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_ARG_H_
#define ONEFLOW_CORE_FRAMEWORK_TENSOR_ARG_H_

#include
#include

#include "oneflow/core/common/util.h"
#include "oneflow/core/autograd/autograd_meta.h"

namespace oneflow {
namespace one {

class Tensor;

// This class will be used in TensorImpl and Autograd. It will share data with different
// FunctionNodes.
class TensorArg final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(TensorArg);
  TensorArg() = default;
  ~TensorArg() = default;

  bool Empty() const;
  void Release();
  Maybe PushPartialTensor(const std::shared_ptr& partial_tensor);
  Maybe GetAccTensor() const;

 private:
  std::shared_ptr acc_tensor_;
};

}  // namespace one
}  // namespace oneflow

#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_ARG_H_


================================================
FILE: oneflow/core/framework/tensor_global_id.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/transport_token.h"
#include "oneflow/core/framework/tensor_global_id.h"

namespace oneflow {

namespace {

Maybe> RawGetMetaTransportToken() {
  const auto& token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta));
  return std::make_shared(token);
}
static constexpr auto* GetMetaTransportToken = DECORATE(&RawGetMetaTransportToken, ThreadLocal);

}  // namespace

Maybe NewTensorGlobalId() { return ++**JUST(GetMetaTransportToken()); }

namespace one {

int64_t* MutThreadLocalGlobalIdDepth() {
  static thread_local int64_t recursive_depth = 0;
  return &recursive_depth;
}

Maybe InitGlobalId(TensorTuple* outputs) {
  for (const auto& output : *outputs) {
    CHECK_OR_RETURN(output);
    const auto& global_tensor = JUST(output->AsGlobalTensor());
    CHECK_OR_RETURN(global_tensor) << Error::UnimplementedError()
                                   << "global tensors supported only.";
    const auto& transport_token = JUST(NewTensorGlobalId());
    JUST(global_tensor->mut_impl()->set_transport_token(transport_token));
  }
  return Maybe::Ok();
}

}  // namespace one
}  // namespace oneflow


================================================
FILE: oneflow/core/framework/tensor_global_id.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
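// ---------------------------------------------------------------------------
// A usage sketch for TensorArg above (not part of the original header; the two
// partial gradients are illustrative). PushPartialTensor() accumulates out of
// place, so a previously pushed gradient is never mutated even when it is
// shared with another FunctionNode:
//
//   TensorArg arg;
//   JUST(arg.PushPartialTensor(grad_from_add));   // arg now holds g1
//   JUST(arg.PushPartialTensor(grad_from_div));   // arg now holds g1 + g2 (a new tensor)
//   const auto total = JUST(arg.GetAccTensor());
// ---------------------------------------------------------------------------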
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_GLOBAL_ID_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_GLOBAL_ID_ #include "oneflow/core/common/maybe.h" namespace oneflow { Maybe NewTensorGlobalId(); namespace one { class TensorTuple; int64_t* MutThreadLocalGlobalIdDepth(); Maybe InitGlobalId(TensorTuple* outputs); template struct NonRecursiveInitGlobalId; template struct NonRecursiveInitGlobalId, Arg0, Arg1, TensorTuple*, Args...> { template (*func)(Arg0, Arg1, TensorTuple*, Args...)> static Maybe Call(Arg0 arg0, Arg1 arg1, TensorTuple* outputs, Args... args) { auto* recursive_depth = MutThreadLocalGlobalIdDepth(); ++*recursive_depth; Maybe ret = func(arg0, arg1, outputs, args...); --*recursive_depth; if (*recursive_depth == 0 && ret.IsOk()) { JUST(InitGlobalId(outputs)); } return ret; } }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_GLOBAL_ID_ ================================================ FILE: oneflow/core/framework/tensor_impl.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/common/blocking_then_busy.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/stream_type.h" #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor_impl.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/stream_allocator_is_pinned.h" namespace oneflow { namespace one { Maybe TensorImpl::set_requires_grad(bool requires_grad) { if (requires_grad) { const DataType tensor_dtype = dtype(); CHECK_OR_RETURN(IsSupportRequireGradDataType(tensor_dtype)) << "RuntimeError: only Tensors of floating point or complex can require gradients"; } autograd_meta_->set_requires_grad(requires_grad); return Maybe::Ok(); } Maybe TensorImpl::acc_grad() const { return autograd_meta_->acc_grad(); } Maybe TensorImpl::current_grad() const { return autograd_meta_->current_grad(); } Maybe TensorImpl::set_acc_grad(const std::shared_ptr& grad) { return autograd_meta_->set_acc_grad(grad); } Maybe TensorImpl::mut_acc_grad() { return autograd_meta_->mut_acc_grad(); } Maybe TensorImpl::set_retain_grad(bool retain_grad) { if (!requires_grad() && retain_grad) { return Error::RuntimeError() << "Can't retain_grad on Tensor that has requires_grad=False"; } if (!is_leaf() && retain_grad) { autograd_meta_->set_retain_grad(retain_grad); } return Maybe::Ok(); } Maybe LazyLocalTensorImpl::detach() const { auto detached_impl = std::make_shared(tensor_meta_, false, true); return std::shared_ptr(detached_impl); } EagerLocalTensorImpl::EagerLocalTensorImpl(const std::shared_ptr& tensor_storage, int64_t storage_offset, bool requires_grad, bool is_leaf) : LocalTensorImpl(requires_grad, is_leaf), tensor_storage_(tensor_storage), storage_offset_(storage_offset) {} EagerLocalTensorImpl::~EagerLocalTensorImpl() {} Maybe EagerLocalTensorImpl::UpdateTensorStorage() { const auto& eager_blob_object = eager_blob_object_; tensor_storage_ = std::make_shared(eager_blob_object->tensor_storage()); tensor_storage_->set_releaser_hook([eager_blob_object]( const std::shared_ptr&) { auto ret = PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { if (eager_blob_object->producer_stream().has_value()) { JUST(builder->ReleaseTensor(eager_blob_object)); } return Maybe::Ok(); }); // We should not use CHECK_JUST here because it will throw an exception // in destructor. if (!ret.IsOk()) { LOG(WARNING) << "Release hook gets an error. 
Release hooks are executed in destructor, so the error " "is possibly only a secondary error caused by another unrelated exception."; LOG(WARNING) << "======= Error message begin ======="; LOG(WARNING) << ret.GetSerializedError(); LOG(WARNING) << "======= Error message end ======="; } }); return Maybe::Ok(); } const std::shared_ptr& EagerLocalTensorImpl::mut_tensor_meta() { return eager_blob_object_->mut_tensor_meta(); } // Getters const Symbol& EagerLocalTensorImpl::tensor_meta() const { return eager_blob_object_->tensor_meta(); } Maybe EagerLocalTensorImpl::compute_local_dep_object() const { return JUST(eager_blob_object())->compute_local_dep_object(); } Maybe EagerLocalTensorImpl::InitEagerBlobObject( const Symbol& local_tensor_meta, const std::shared_ptr& mut_local_tensor_meta, const intrusive::shared_ptr& dep_object) { CHECK_OR_RETURN(static_cast(local_tensor_meta->device())); // NOLINT const auto& mem_case = local_tensor_meta->device()->mem_case(); if (tensor_storage_) { auto tensor_storage = tensor_storage_->storage(); eager_blob_object_ = std::make_shared( mem_case, local_tensor_meta, mut_local_tensor_meta, local_tensor_meta->dtype(), local_tensor_meta->memory_format(), tensor_storage, dep_object); } else { auto device = local_tensor_meta->device(); auto storage = device->rematable() ? std::make_shared(device) : std::make_shared(true, device); const auto& eager_blob_object = std::make_shared( mem_case, local_tensor_meta, mut_local_tensor_meta, local_tensor_meta->dtype(), local_tensor_meta->memory_format(), storage, dep_object); JUST(set_eager_blob_object(eager_blob_object)); } return Maybe::Ok(); } Maybe EagerLocalTensorImpl::is_pinned() const { if (this->device() == JUST(Device::New("meta"))) { return false; } if (!eager_blob_object_) { return false; } return IsStreamAllocatorPinned::Visit(JUST(eager_blob_object_->producer_stream())->stream_type()); } Maybe EagerLocalTensorImpl::set_eager_blob_object( std::shared_ptr eager_blob_object) { eager_blob_object_ = eager_blob_object; CHECK_OR_RETURN(eager_blob_object_->shape() == tensor_meta()->shape()) << kOfBugIssueUploadPrompt; CHECK_OR_RETURN(eager_blob_object_->data_type() == tensor_meta()->dtype()) << kOfBugIssueUploadPrompt; JUST(UpdateTensorStorage()); return Maybe::Ok(); } std::shared_ptr EagerLocalTensorImpl::shape() const { if (!eager_blob_object_) { return tensor_meta()->shape_ptr(); } return eager_blob_object_->shape_ptr(); } std::shared_ptr EagerLocalTensorImpl::stride() const { if (!eager_blob_object_) { return tensor_meta()->stride_ptr(); } return eager_blob_object_->stride_ptr(); } MemoryFormat EagerLocalTensorImpl::memory_format() const { if (!eager_blob_object_) { return tensor_meta()->memory_format(); } return eager_blob_object_->memory_format(); } Maybe EagerLocalTensorImpl::detach() const { auto detached_impl = std::make_shared(tensor_storage_, false, true); detached_impl->eager_blob_object_ = eager_blob_object_; return std::shared_ptr(detached_impl); } Maybe EagerLocalTensorImpl::RegisterStorageDeleteHook(const std::function& hook) { CHECK_OR_RETURN(eager_blob_object_) << "EagerBlobObject has not initialized"; eager_blob_object_->RegisterStorageDeleteHook(hook); return Maybe::Ok(); } Maybe LazyGlobalTensorImpl::detach() const { auto detached_impl = std::make_shared(tensor_meta_, false, true); return std::shared_ptr(detached_impl); } EagerGlobalTensorImpl::EagerGlobalTensorImpl( Symbol global_tensor_meta, const std::shared_ptr& cur_rank_phy_tensor) : GlobalTensorImpl(global_tensor_meta, 
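// ---------------------------------------------------------------------------
// A usage sketch for EagerLocalTensorImpl::RegisterStorageDeleteHook() above
// (not from the original source; assumes the hook is a nullary callback).
// The hook runs when the tensor's underlying storage is finally released,
// which is handy for cache invalidation or memory accounting:
//
//   JUST(impl->RegisterStorageDeleteHook([] {
//     LOG(INFO) << "tensor storage released";
//   }));
// ---------------------------------------------------------------------------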
cur_rank_phy_tensor->requires_grad(), cur_rank_phy_tensor->is_leaf()), cur_rank_phy_tensor_(cur_rank_phy_tensor) {} /* static */ Maybe EagerGlobalTensorImpl::New( Symbol global_tensor_meta, bool requires_grad, bool is_leaf) { const auto& parallel_desc = global_tensor_meta->parallel_desc(); Optional parallel_id; const auto& device = JUST(parallel_desc->GetTensorDevice4CurrentProcessCtx(¶llel_id)); return EagerGlobalTensorImpl::New(global_tensor_meta, device, parallel_id, requires_grad, is_leaf); } namespace { Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const Optional& parallel_id) { if (parallel_id.has_value()) { return GetPhysicalShape(logical_shape, nd_sbp, parallel_desc, JUST(parallel_id)); } else { return std::make_shared(DimVector(logical_shape.NumAxes(), 0)); } } } // namespace /* static */ Maybe EagerGlobalTensorImpl::New( Symbol global_tensor_meta, Symbol device, const Optional& parallel_id, bool requires_grad, bool is_leaf) { const auto& shape = global_tensor_meta->shape_ptr(); const auto& dtype = global_tensor_meta->dtype(); const auto& memory_format = global_tensor_meta->memory_format(); const auto& nd_sbp = global_tensor_meta->nd_sbp(); const auto& parallel_desc = global_tensor_meta->parallel_desc(); const auto& cur_rank_phy_shape = JUST(GetPhysicalShape(*shape, *nd_sbp, *parallel_desc, parallel_id)); std::shared_ptr cur_rank_phy_tensor; // If the `'parallel_desc` doesn't cover current ProcessCtx or the tensor has 0-size shape, there // is no need to compute through the corresponding opkernel, and can be obtained directly through // empty op. if (parallel_id.has_value() && shape->elem_cnt() != 0) { const auto& cur_rank_phy_tensor_meta = SymbolOf(LocalTensorMeta(*cur_rank_phy_shape, dtype, memory_format, device)); auto cur_rank_phy_tensor_impl = std::make_shared(requires_grad, is_leaf); const auto& dep_object = NewLocalDepObject(); JUST(cur_rank_phy_tensor_impl->InitEagerBlobObject(cur_rank_phy_tensor_meta, dep_object)); cur_rank_phy_tensor = std::make_shared(cur_rank_phy_tensor_impl); } else { const auto& dtype_symbol = JUST(DType::Get(dtype)); const auto& empty = JUST(functional::Empty(*cur_rank_phy_shape, dtype_symbol, device, /*requires_grad=*/requires_grad, /*pin_memory=*/false)); cur_rank_phy_tensor = JUST(empty->AsLocalTensor()); JUST(cur_rank_phy_tensor->set_requires_grad(requires_grad)); cur_rank_phy_tensor->set_is_leaf(is_leaf); } auto* tensor_impl = new EagerGlobalTensorImpl(global_tensor_meta, cur_rank_phy_tensor); return std::shared_ptr(tensor_impl); } Maybe EagerGlobalTensorImpl::detach() const { auto detached_impl = std::shared_ptr(new EagerGlobalTensorImpl( tensor_meta_, JUST(JUST(cur_rank_phy_tensor_->detach())->AsLocalTensor()))); detached_impl->consumer_nd_sbp_constraint_ = consumer_nd_sbp_constraint_; detached_impl->transport_token_ = transport_token_; return std::shared_ptr(detached_impl); } std::shared_ptr EagerGlobalTensorImpl::stride() const { if (!cur_rank_phy_tensor_) { return tensor_meta()->stride_ptr(); } return cur_rank_phy_tensor_->tensor_meta().stride_ptr(); } MemoryFormat EagerGlobalTensorImpl::memory_format() const { if (!cur_rank_phy_tensor_) { return tensor_meta()->memory_format(); } return cur_rank_phy_tensor_->tensor_meta().memory_format(); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor_impl.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
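// ---------------------------------------------------------------------------
// Worked example for EagerGlobalTensorImpl::New() above (not part of the
// original source; the numbers are illustrative). The per-rank physical tensor
// only holds this rank's slice of the logical shape, as computed by
// GetPhysicalShape():
//
//   logical shape {8, 4}, sbp = split(0), placement over 4 ranks
//     -> physical shape {2, 4} on each rank
//   logical shape {8, 4}, sbp = broadcast
//     -> physical shape {8, 4} on each rank
//   rank not covered by the placement (parallel_id is empty)
//     -> physical shape of all zeros; a 0-size tensor is created via
//        functional::Empty instead of running an opkernel.
// ---------------------------------------------------------------------------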
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_IMPL_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_IMPL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/tensor_storage.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/framework/transport_token.h" #include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/eager/local_dep_object.h" namespace oneflow { class MemoryCase; class Shape; class Stride; class Device; namespace vm { class EagerBlobObject; class TensorStorage; } // namespace vm namespace one { class Tensor; class TensorArg; class TensorImpl { public: virtual ~TensorImpl() = default; // Getters virtual std::shared_ptr shape() const = 0; virtual std::shared_ptr stride() const = 0; virtual MemoryFormat memory_format() const = 0; virtual DataType dtype() const = 0; virtual bool is_lazy() const = 0; // Getters valid only for EagerLocalTensorImpl virtual Maybe eager_blob_object() const = 0; virtual Maybe compute_local_dep_object() const = 0; virtual Maybe tensor_storage() const { OF_UNIMPLEMENTED(); } virtual Maybe has_eager_blob_object() const = 0; virtual Maybe storage_offset() const { OF_UNIMPLEMENTED(); } virtual bool is_contiguous() const = 0; virtual bool is_view() const = 0; virtual Maybe is_pinned() const { OF_UNIMPLEMENTED(); } // Getters for autograd Maybe acc_grad() const; Maybe current_grad() const; bool requires_grad() const { return autograd_meta_->requires_grad(); } bool is_leaf() const { return autograd_meta_->is_leaf(); } bool retain_grad() const { return autograd_meta_->retain_grad(); } // Setters for autograd Maybe set_acc_grad(const std::shared_ptr& grad); Maybe mut_acc_grad(); Maybe set_requires_grad(bool requires_grad); Maybe set_retain_grad(bool retain_grad); void set_is_leaf(bool is_leaf) { autograd_meta_->set_is_leaf(is_leaf); } std::shared_ptr autograd_meta() const { return autograd_meta_; } std::shared_ptr mut_autograd_meta() { return autograd_meta_; } void set_autograd_meta(const std::shared_ptr& autograd_meta) { autograd_meta_ = autograd_meta; } virtual Maybe RegisterStorageDeleteHook(const std::function& hook) { OF_UNIMPLEMENTED(); } protected: TensorImpl(bool requires_grad, bool is_leaf) : autograd_meta_(std::make_shared(requires_grad, is_leaf)) {} protected: std::shared_ptr autograd_meta_; }; class EagerLocalTensorImpl; class LocalTensorImpl : public TensorImpl { public: virtual ~LocalTensorImpl() = default; // Getters DataType dtype() const override { return tensor_meta()->dtype(); } const Symbol& device() const { return tensor_meta()->device(); } bool is_contiguous() const override { return tensor_meta()->is_contiguous(); } bool is_view() const override { return tensor_meta()->is_view(); } virtual const Symbol& tensor_meta() const = 0; // Setters virtual const std::shared_ptr& 
mut_tensor_meta() = 0; Maybe*> mut_device() { return std::const_pointer_cast(mut_tensor_meta())->mut_device(); } virtual Maybe mut_eager_local_tensor_impl() { RETURN_ERROR_WITH_BUG_PROMPT(); } virtual Maybe detach() const { RETURN_ERROR_WITH_BUG_PROMPT(); } protected: LocalTensorImpl(bool requires_grad, bool is_leaf) : TensorImpl(requires_grad, is_leaf) {} }; class LocalTensor; class GlobalTensorImpl : public TensorImpl { public: virtual ~GlobalTensorImpl() = default; // Getters std::shared_ptr shape() const override { return tensor_meta_->shape_ptr(); } std::shared_ptr stride() const override { return tensor_meta_->stride_ptr(); } MemoryFormat memory_format() const override { return tensor_meta_->memory_format(); } DataType dtype() const override { return tensor_meta_->dtype(); } Symbol nd_sbp() const { return tensor_meta_->nd_sbp(); } Symbol parallel_desc() const { return tensor_meta_->parallel_desc(); } const Optional>& consumer_nd_sbp_constraint() const { return consumer_nd_sbp_constraint_; } virtual Maybe cur_rank_phy_tensor() const { RETURN_ERROR_WITH_BUG_PROMPT(); } Symbol tensor_meta() const { return tensor_meta_; } // Getters valid only for EagerLocalTensorImpl Maybe eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe compute_local_dep_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe has_eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } // Setters void set_consumer_nd_sbp_constraint(const Optional>& val) { consumer_nd_sbp_constraint_ = val; } GlobalTensorMeta* mut_tensor_meta() { PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } Maybe transport_token() const { return JUST(transport_token_); } Maybe set_transport_token(const TransportToken& transport_token) { transport_token_ = transport_token; return Maybe::Ok(); } virtual Maybe detach() const { RETURN_ERROR_WITH_BUG_PROMPT(); } protected: GlobalTensorImpl(Symbol tensor_meta, bool requires_grad, bool is_leaf) : TensorImpl(requires_grad, is_leaf), tensor_meta_(tensor_meta), consumer_nd_sbp_constraint_(), transport_token_() {} Symbol tensor_meta_; Optional> consumer_nd_sbp_constraint_; Optional transport_token_; }; class LazyLocalTensorImpl final : public LocalTensorImpl { public: OF_DISALLOW_COPY_AND_MOVE(LazyLocalTensorImpl); LazyLocalTensorImpl(const Symbol& tensor_meta, bool requires_grad, bool is_leaf) : LocalTensorImpl(requires_grad, is_leaf), tensor_meta_(tensor_meta) {} ~LazyLocalTensorImpl() override = default; // Getters const Symbol& tensor_meta() const override { return tensor_meta_; } std::shared_ptr shape() const override { return tensor_meta()->shape_ptr(); } std::shared_ptr stride() const override { return tensor_meta()->stride_ptr(); } MemoryFormat memory_format() const override { return tensor_meta()->memory_format(); } bool is_lazy() const override { return true; } bool is_contiguous() const override { // TODO:(zhaoluyang) default return true for now, // but should return real status while stride/view mechanism is ready in lazy-local mode return true; } bool is_view() const override { return false; } Maybe is_pinned() const override { return false; } const std::shared_ptr& mut_tensor_meta() override { PRINT_BUG_PROMPT_AND_ABORT(); } // Getters valid only for EagerLocalTensorImpl Maybe eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe compute_local_dep_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe tensor_storage() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe has_eager_blob_object() const override 
{ RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe detach() const override; private: Symbol tensor_meta_; }; class EagerLocalTensorImpl final : public LocalTensorImpl { public: OF_DISALLOW_COPY_AND_MOVE(EagerLocalTensorImpl); EagerLocalTensorImpl() : EagerLocalTensorImpl(std::shared_ptr(), 0, false, false) {} EagerLocalTensorImpl(const std::shared_ptr& tensor_storage, bool requires_grad, bool is_leaf) : EagerLocalTensorImpl(tensor_storage, 0, requires_grad, is_leaf) {} EagerLocalTensorImpl(const std::shared_ptr& tensor_storage, int64_t storage_offset, bool requires_grad, bool is_leaf); EagerLocalTensorImpl(bool requires_grad, bool is_leaf) : EagerLocalTensorImpl(std::shared_ptr(), 0, requires_grad, is_leaf) {} ~EagerLocalTensorImpl() override; const std::shared_ptr& mut_tensor_meta() override; // Getters const Symbol& tensor_meta() const override; std::shared_ptr shape() const override; std::shared_ptr stride() const override; MemoryFormat memory_format() const override; Maybe detach() const override; bool is_lazy() const override { return false; } bool is_contiguous() const override { return tensor_meta()->is_contiguous(); } bool is_view() const override { return tensor_meta()->is_view(); } Maybe is_pinned() const override; // Getters valid only for EagerLocalTensorImpl Maybe eager_blob_object() const override { CHECK_OR_RETURN(eager_blob_object_); return eager_blob_object_; } Maybe compute_local_dep_object() const override; Maybe tensor_storage() const override { CHECK_OR_RETURN(eager_blob_object_); return tensor_storage_; } Maybe has_eager_blob_object() const override { return eager_blob_object_.get(); } Maybe storage_offset() const override { return storage_offset_; } // Setters TensorStorage* mut_tensor_storage() { return tensor_storage_.get(); } void set_storage_offset(int64_t offset) { storage_offset_ = offset; } Maybe InitEagerBlobObject( const Symbol& local_tensor_meta, const std::shared_ptr& mut_local_tensor_meta, const intrusive::shared_ptr& dep_object); Maybe InitEagerBlobObject(const Symbol& local_tensor_meta, const intrusive::shared_ptr& dep_object) { JUST(InitEagerBlobObject(local_tensor_meta, std::shared_ptr(), dep_object)); return Maybe::Ok(); } Maybe mut_eager_local_tensor_impl() override { return this; } Maybe RegisterStorageDeleteHook(const std::function& hook) override; private: Maybe UpdateTensorStorage(); Maybe set_eager_blob_object(std::shared_ptr eager_blob_object); std::shared_ptr tensor_storage_; int64_t storage_offset_; std::shared_ptr eager_blob_object_; }; class LazyGlobalTensorImpl final : public GlobalTensorImpl { public: OF_DISALLOW_COPY_AND_MOVE(LazyGlobalTensorImpl); LazyGlobalTensorImpl(Symbol global_tensor_meta, bool requires_grad, bool is_leaf) : GlobalTensorImpl(global_tensor_meta, requires_grad, is_leaf) {} ~LazyGlobalTensorImpl() override = default; // Getters bool is_lazy() const override { return true; } bool is_contiguous() const override { // TODO:(zhaoluyang) default return true for now, // but should return real status while stride/view mechanism is ready in lazy-global mode return true; } bool is_view() const override { return false; } Maybe detach() const override; }; class EagerGlobalTensorImpl final : public GlobalTensorImpl { public: OF_DISALLOW_COPY_AND_MOVE(EagerGlobalTensorImpl); ~EagerGlobalTensorImpl() override = default; // Getters std::shared_ptr stride() const override; MemoryFormat memory_format() const override; bool is_lazy() const override { return false; } bool is_contiguous() const override { // TODO:(zhaoluyang) default return true 
for now, // but should return real status while stride/view mechanism is ready in eager-global mode return true; } bool is_view() const override { return false; } Maybe cur_rank_phy_tensor() const override { return cur_rank_phy_tensor_; } void reset_cur_rank_phy_tensor(const std::shared_ptr& val) { cur_rank_phy_tensor_ = val; } static Maybe New(Symbol global_tensor_meta, bool requires_grad, bool is_leaf); static Maybe New(Symbol global_tensor_meta, Symbol device, const Optional& parallel_id, bool requires_grad, bool is_leaf); Maybe detach() const override; private: EagerGlobalTensorImpl(Symbol global_tensor_meta, const std::shared_ptr& cur_rank_phy_tensor); std::shared_ptr cur_rank_phy_tensor_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_IMPL_H_ ================================================ FILE: oneflow/core/framework/tensor_methods.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/tensor_methods.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/common/wrap_dim_utils.h" #include "oneflow/core/functional/functional_api.yaml.h" namespace oneflow { namespace one { namespace view { // NOTE: use env variable 'ONEFLOW_DISABLE_VIEW' control use view mechanism or not // If set true, then do not use view mechanism(and view ops) bool IsEnvViewDisabled() { static const bool env_view_disabled = ParseBooleanFromEnv("ONEFLOW_DISABLE_VIEW", false); return env_view_disabled; } bool IsViewApplicable(const std::shared_ptr& input) { if (IsEnvViewDisabled()) { return false; } // NOTE: only eager local tensor support view for now // elem_cnt() >= 1 used to excluding 0 shape tensor if (input->is_local() && !(LazyMode::is_enabled()) && input->shape()->elem_cnt() >= 1) { return true; } return false; } static bool IsOverlappingMemorys(const std::vector& sizes, const std::vector& strides) { // reference: torch/csrc/autograd/FunctionsManual.cpp _maybe_overlapping_memory() if (sizes.size() > 0) { std::vector argsort(sizes.size()); std::iota(argsort.begin(), argsort.end(), 0); std::sort(argsort.begin(), argsort.end(), [&](std::size_t i, std::size_t j) { return strides[i] < strides[j]; }); int64_t max_index_in_slice = 0; for (auto i : argsort) { auto stride_ = strides[i]; if (stride_ <= max_index_in_slice) { return true; } max_index_in_slice += stride_ * (sizes[i] - 1); } } return false; } static int64_t MinStorageSize(const std::vector& sizes, const std::vector& strides, 
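// ---------------------------------------------------------------------------
// Worked example for IsOverlappingMemorys() above (not part of the original
// source). Dimensions are visited in order of increasing stride; a dimension
// whose stride is <= the largest index already reachable may address memory
// that earlier dimensions have already covered:
//
//   sizes {2, 3}, strides {3, 1}:
//     stride 1 >  0 -> max_index_in_slice = 1 * (3 - 1) = 2
//     stride 3 >  2 -> max_index_in_slice = 2 + 3 * (2 - 1) = 5   -> no overlap
//
//   sizes {2, 2}, strides {1, 1}:
//     stride 1 >  0 -> max_index_in_slice = 1 * (2 - 1) = 1
//     stride 1 <= 1 -> overlapping memory
// ---------------------------------------------------------------------------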
int64_t storage_offset) { int64_t storage_size = storage_offset + 1; int64_t ndim = sizes.size(); for (size_t i = 0; i < ndim; i++) { auto size_i = sizes[i]; if (size_i == 0) { return storage_offset; } storage_size += (size_i - 1) * strides[i]; } return storage_size; } Maybe BasicView(const std::shared_ptr& input, const Shape& target_shape, const int64_t storage_offset) { /** * This function provides basic view capabilities which * accept input tensor with target shape, and return viewed tensor. * * The viewed tensor shared memory with input tensor, and both of * them are memory contiguous, but has different shapes/strides. */ Stride target_stride(target_shape); return BasicView(input, target_shape, target_stride, storage_offset); } Maybe BasicView(const std::shared_ptr& input, const Shape& target_shape, const Stride& target_stride, const int64_t storage_offset) { auto device = JUST(input->device()); auto tensor_meta = SymbolOf(LocalTensorMeta(target_shape, target_stride, input->dtype()->data_type(), input->memory_format(), device, /*is_view=*/true)); CHECK_OR_RETURN(JUST(input->has_eager_blob_object())); // new output tensor const auto& blob_object = JUST(input->eager_blob_object()); bool requires_grad = (autograd::GradMode::is_enabled() && input->requires_grad()); auto tensor_impl = std::make_shared(JUST(input->tensor_storage()), storage_offset, requires_grad, /*is_leaf=*/!requires_grad); JUST( tensor_impl->InitEagerBlobObject(tensor_meta, JUST(blob_object->compute_local_dep_object()))); auto view_tensor = std::make_shared(tensor_impl); const std::shared_ptr& view_eager_blob_object = JUST(view_tensor->eager_blob_object()); view_eager_blob_object->set_storage_offset(JUST(view_tensor->storage_offset())); view_eager_blob_object->set_input_of_view_op(blob_object); return std::static_pointer_cast(view_tensor); } Maybe InplaceView(const std::shared_ptr& input, const Shape& target_shape, const Stride& target_stride, const int64_t storage_offset) { Symbol new_tensor_meta = SymbolOf(LocalTensorMeta(target_shape, target_stride, input->dtype()->data_type(), input->memory_format(), JUST(input->device()))); bool requires_grad = (autograd::GradMode::is_enabled() && input->requires_grad()); std::shared_ptr new_tensor_impl = std::make_shared( JUST(input->tensor_storage()), storage_offset, /*requires_grad=*/requires_grad, /*is_leaf=*/!requires_grad); JUST(new_tensor_impl->InitEagerBlobObject( new_tensor_meta, JUST(JUST(input->eager_blob_object())->compute_local_dep_object()))); JUST(JUST(input->AsLocalTensor())->set_impl(new_tensor_impl)); return Maybe::Ok(); } Maybe Reshape(const std::shared_ptr& input, const Shape& target_shape) { Stride target_stride(target_shape); return Reshape(input, target_shape, target_stride); } Maybe Reshape(const std::shared_ptr& input, const Shape& target_shape, const Stride& target_stride) { int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); std::shared_ptr output = JUST(BasicView(input, target_shape, target_stride, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { Shape input_shape(input->shape()->dim_vec()); auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Reshape(JUST(oneflow::VectorAt(out_grads, 0)), input_shape)); return 
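// ---------------------------------------------------------------------------
// A behaviour sketch for view::Reshape() above (not from the original source;
// `x` is an illustrative eager local tensor with shape {2, 3} that requires
// grad). The result produced by BasicView() shares storage with the input;
// only shape, stride and storage_offset differ:
//
//   auto y = JUST(view::Reshape(x, Shape(DimVector{3, 2})));
//   // y aliases x's storage: writes through y are visible through x.
//   // Because x requires grad, a "view::reshape_backward" node is added that
//   // reshapes the upstream gradient back to {2, 3}.
// ---------------------------------------------------------------------------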
Maybe::Ok(); }; backward_fn->status = []() { return false; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::reshape_backward", backward_fn, {input}, &outputs)); } return output; } Maybe Slice(const std::shared_ptr& input, const std::vector& starts, const std::vector& ends, const std::vector& steps) { const auto& shape = input->shape(); const auto& strides = JUST(input->stride()); const int64_t ndim = starts.size(); CHECK_OR_RETURN(ndim == shape->NumAxes()) << Error::RuntimeError() << "view::Slice(): starts size is expected " << shape->NumAxes() << ", but got " << ndim; CHECK_OR_RETURN(ends.size() == ndim && steps.size() == ndim) << Error::RuntimeError() << "view::Slice(): " << (ends.size() != ndim ? "ends" : "steps") << " size is not equal to start."; DimVector target_dims(ndim); Stride target_strides(ndim); int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); for (int i = 0; i < ndim; ++i) { int64_t step = std::min(steps[i], shape->At(i)); CHECK_OR_RETURN(step >= 0) << Error::RuntimeError() << "Step must be greater than zero."; int64_t start = std::min(starts[i], shape->At(i)); int64_t end = std::min(ends[i], shape->At(i)); if (start < 0) { start += shape->At(i); } if (start < 0) start = 0; if (end < 0) { end += shape->At(i); } if (end < start) end = start; int64_t length = start == end ? 0 : (end - start + step - 1) / step; target_dims[i] = length; target_strides[i] = step * strides->at(i); storage_offset += start * strides->at(i); } auto output = JUST(BasicView(input, Shape(target_dims), target_strides, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { const Shape in_shape = *input->shape(); auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); (*in_grads)[0] = JUST(functional::SliceGrad(out_grads[0], in_shape, starts, ends, steps)); return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::slice_backward", backward_fn, {input}, &outputs)); } return output; } Maybe Unsqueeze(const std::shared_ptr& input, const int32_t expand_dim) { const auto& shape = input->shape(); const auto& strides = JUST(input->stride()); const auto& ndim = shape->NumAxes(); DimVector target_dim_vec(ndim + 1); Stride target_stride_vec(ndim + 1); { int cnt = 0; for (int i = 0; i < ndim; i++) { if (i == expand_dim) { cnt++; } target_dim_vec[cnt] = shape->at(i); target_stride_vec[cnt] = strides->at(i); cnt++; } target_dim_vec[expand_dim] = 1; target_stride_vec[expand_dim] = expand_dim < ndim ? 
strides->at(expand_dim) * target_dim_vec.at(expand_dim + 1) : 1; } int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); std::shared_ptr output = JUST(BasicView(input, Shape(target_dim_vec), target_stride_vec, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Reshape(JUST(oneflow::VectorAt(out_grads, 0)), *shape)); return Maybe::Ok(); }; backward_fn->status = []() { return false; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::unsqueeze_backward", backward_fn, {input}, &outputs)); } return output; } Maybe InplaceUnsqueeze(const std::shared_ptr& input, const int32_t expand_dim) { const auto& shape = input->shape(); const auto& strides = JUST(input->stride()); const auto& ndim = shape->NumAxes(); DimVector target_dim_vec(ndim + 1); Stride target_stride_vec(ndim + 1); { int cnt = 0; for (int i = 0; i < ndim; i++) { if (i == expand_dim) { cnt++; } target_dim_vec[cnt] = shape->at(i); target_stride_vec[cnt] = strides->at(i); cnt++; } target_dim_vec[expand_dim] = 1; target_stride_vec[expand_dim] = expand_dim < ndim ? strides->at(expand_dim) * target_dim_vec.at(expand_dim + 1) : 1; } int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); JUST(view::InplaceView(input, Shape(target_dim_vec), target_stride_vec, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Reshape(JUST(oneflow::VectorAt(out_grads, 0)), *shape)); return Maybe::Ok(); }; backward_fn->status = []() { return false; }; TensorTuple outputs{input}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::inplace_unsqueeze_backward", backward_fn, {input}, &outputs)); } return Maybe::Ok(); } Maybe Squeeze(const std::shared_ptr& input, const std::vector& squeeze_dims) { const auto& shape = input->shape(); const auto& strides = JUST(input->stride()); const int64_t ndim = shape->NumAxes(); const int target_ndim = ndim - squeeze_dims.size(); DimVector target_dim_vec(target_ndim); Stride target_stride_vec(target_ndim); { int cnt = 0; for (int i = 0; i < ndim; i++) { if (find(squeeze_dims.begin(), squeeze_dims.end(), i) == squeeze_dims.end()) { target_dim_vec[cnt] = shape->At(i); target_stride_vec[cnt] = strides->at(i); cnt++; } } } int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); std::shared_ptr output = JUST(BasicView(input, Shape(target_dim_vec), target_stride_vec, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Reshape( 
JUST(oneflow::VectorAt(out_grads, 0)), Shape(input->shape()->dim_vec()))); return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::squeeze_backward", backward_fn, {input}, &outputs)); } return output; } Maybe InplaceSqueeze(const std::shared_ptr& input, const std::vector& squeeze_dims) { const auto& shape = input->shape(); const auto& strides = JUST(input->stride()); const int64_t ndim = shape->NumAxes(); const int target_ndim = ndim - squeeze_dims.size(); DimVector target_dim_vec(target_ndim); Stride target_stride_vec(target_ndim); { int cnt = 0; for (int i = 0; i < ndim; i++) { if (find(squeeze_dims.begin(), squeeze_dims.end(), i) == squeeze_dims.end()) { target_dim_vec[cnt] = shape->At(i); target_stride_vec[cnt] = strides->at(i); cnt++; } } } int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); JUST(view::InplaceView(input, Shape(target_dim_vec), target_stride_vec, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) in_grads->resize(1); JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Reshape( JUST(oneflow::VectorAt(out_grads, 0)), Shape(input->shape()->dim_vec()))); return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{input}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::inplace_squeeze_backward", backward_fn, {input}, &outputs)); } return Maybe::Ok(); } Maybe Expand(const std::shared_ptr& input, const Shape& expand_shape) { const Shape& input_shape = *input->shape(); const Stride& input_stride = *JUST(input->stride()); size_t lpad = expand_shape.size() - input_shape.size(); CHECK_GE_OR_RETURN(lpad, 0); // NOLINT(maybe-need-error-msg) Stride expand_stride(expand_shape.size(), 0); std::vector reduce_dims; reduce_dims.reserve(expand_shape.size()); for (int i = expand_shape.size() - 1; i >= 0; --i) { int64_t dim = i < lpad ? 
1 : input_shape[i - lpad]; if (dim == expand_shape[i]) { if (i >= lpad) { expand_stride[i] = input_stride[i - lpad]; } else if (i < expand_shape.size() - 1) { expand_stride[i] = expand_stride[i + 1] * expand_shape[i + 1]; } } else { CHECK_EQ_OR_RETURN(dim, 1); // NOLINT(maybe-need-error-msg) reduce_dims.push_back(i); } } if (input_shape.size() == 0) { // handle scalar expand backward reduce dims reduce_dims.clear(); for (int32_t axis = 0; axis < expand_shape.size(); ++axis) { reduce_dims.push_back(axis); } } int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); std::shared_ptr output = JUST(BasicView(input, expand_shape, expand_stride, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out grad size should be 1, but got " << out_grads.size(); in_grads->resize(1); in_grads->at(0) = out_grads[0]; bool keep_dims = (input_shape.size() > 0); if (reduce_dims.size() > 0) { in_grads->at(0) = JUST(functional::ReduceSum(in_grads->at(0), reduce_dims, keep_dims, NullOpt)); } if (lpad > 0 && keep_dims) { in_grads->at(0) = JUST(functional::Flatten(in_grads->at(0), 0, lpad)); } return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::expand_backward", backward_fn, {input}, &outputs)); } return output; } Maybe InplaceExpand(const std::shared_ptr& input, const Shape& expand_shape) { const Shape& input_shape = *input->shape(); const Stride& input_stride = *JUST(input->stride()); size_t lpad = expand_shape.size() - input_shape.size(); CHECK_GE_OR_RETURN(lpad, 0); // NOLINT(maybe-need-error-msg) Stride expand_stride(expand_shape.size(), 0); std::vector reduce_dims; reduce_dims.reserve(expand_shape.size()); for (int i = expand_shape.size() - 1; i >= 0; --i) { int64_t dim = i < lpad ? 
1 : input_shape[i - lpad]; if (dim == expand_shape[i]) { if (i >= lpad) { expand_stride[i] = input_stride[i - lpad]; } else if (i < expand_shape.size() - 1) { expand_stride[i] = expand_stride[i + 1] * expand_shape[i + 1]; } } else { CHECK_EQ_OR_RETURN(dim, 1); // NOLINT(maybe-need-error-msg) reduce_dims.push_back(i); } } if (input_shape.size() == 0) { // handle scalar expand backward reduce dims reduce_dims.clear(); for (int32_t axis = 0; axis < expand_shape.size(); ++axis) { reduce_dims.push_back(axis); } } int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); JUST(view::InplaceView(input, expand_shape, expand_stride, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out grad size should be 1, but got " << out_grads.size(); in_grads->resize(1); in_grads->at(0) = out_grads[0]; bool keep_dims = (input_shape.size() > 0); if (reduce_dims.size() > 0) { in_grads->at(0) = JUST(functional::ReduceSum(in_grads->at(0), reduce_dims, keep_dims, NullOpt)); } if (lpad > 0 && keep_dims) { in_grads->at(0) = JUST(functional::Flatten(in_grads->at(0), 0, lpad)); } return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{input}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::expand_backward", backward_fn, {input}, &outputs)); } return Maybe::Ok(); } Maybe Narrow(const std::shared_ptr& input, const int64_t dim, const int64_t start, const int64_t length) { const auto& shape = input->shape(); const auto& strides = JUST(input->stride()); const int64_t ndim = shape->NumAxes(); DimVector dim_vec; dim_vec.insert(dim_vec.end(), shape->dim_vec().cbegin(), shape->dim_vec().cbegin() + dim); dim_vec.insert(dim_vec.end(), length); dim_vec.insert(dim_vec.end(), shape->dim_vec().cbegin() + dim + 1, shape->dim_vec().end()); int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); Shape target_shape(dim_vec); Stride stride(ndim); for (int i = 0; i < ndim; ++i) { stride[i] = strides->at(i); if (dim == i) { storage_offset += start * strides->at(i); } } auto output = JUST(BasicView(input, target_shape, stride, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out grad size should be 1, but got " << out_grads.size(); auto like = JUST(functional::Empty(Shape(input->shape()->dim_vec()), input->dtype(), JUST(input->device()), /*requires_grad=*/input->requires_grad(), /*pin_memory=*/false)); in_grads->resize(1); (*in_grads)[0] = JUST(functional::NarrowGrad(out_grads[0], like, dim, start, length)); return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::narrow_backward", backward_fn, {input}, &outputs)); } return output; } Maybe AsStridedGrad(const std::shared_ptr& dy, const std::shared_ptr& input, const std::vector& sizes, const std::vector& strides, const int64_t storage_offset) { CHECK_OR_RETURN(input->is_local()) << "input must be local tensor."; // reference: torch/csrc/autograd/FunctionsManual.cpp const size_t odim = dy->ndim(); 
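// [Editorial aside] The rest of this backward leans on the two static helpers defined near
// the top of this file: IsOverlappingMemorys() reports whether a (sizes, strides) geometry
// *may* address the same storage element twice, and MinStorageSize() computes the smallest
// backing storage such a geometry can touch. A minimal standalone restatement of that stride
// math is sketched below; the lambda names are hypothetical, the lambdas are intentionally
// unused, and they only mirror (not replace) the helpers above.
const auto maybe_overlapping_sketch = [](const std::vector<int64_t>& sizes,
                                         const std::vector<int64_t>& strides) -> bool {
  // Walk dimensions from smallest stride to largest; if some stride does not jump past
  // every index reachable so far, two multi-dimensional indices can collide.
  std::vector<size_t> order(sizes.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](size_t i, size_t j) { return strides[i] < strides[j]; });
  int64_t max_index_in_slice = 0;
  for (size_t i : order) {
    if (strides[i] <= max_index_in_slice) { return true; }
    max_index_in_slice += strides[i] * (sizes[i] - 1);
  }
  return false;
};
const auto min_storage_size_sketch = [](const std::vector<int64_t>& sizes,
                                        const std::vector<int64_t>& strides,
                                        int64_t storage_offset) -> int64_t {
  // Largest linear index reachable by the view, plus one, plus the storage offset.
  int64_t storage_size = storage_offset + 1;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 0) { return storage_offset; }
    storage_size += (sizes[i] - 1) * strides[i];
  }
  return storage_size;
};
(void)maybe_overlapping_sketch;
(void)min_storage_size_sketch;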
std::vector out_sizes_, out_strides_; out_sizes_.reserve(odim); out_strides_.reserve(odim); auto grad = dy; for (int64_t i = odim - 1; i >= 0; i--) { auto size_i = sizes[i]; auto stride_i = strides[i]; if (size_i == 0) { return functional::Constant(*dy->shape(), 0, grad->dtype(), JUST(grad->device())); } else if (size_i == 1) { grad = JUST(functional::Squeeze(grad, std::vector{int(i)})); } else if (stride_i == 0) { grad = JUST(functional::ReduceSum(grad, std::vector{int(i)}, false, NullOpt)); } else { out_sizes_.insert(out_sizes_.begin(), size_i); out_strides_.insert(out_strides_.begin(), stride_i); } } // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A // Strided Tensor ] // on output geometry const bool out_maybe_overlap = IsOverlappingMemorys(out_sizes_, out_strides_); // For input geometry, // check for size 0 dimensions, // skip size 1 dimensions, // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A // Strided Tensor ] // on input geometry auto idim = input->ndim(); std::vector inp_sizes(input->shape()->begin(), input->shape()->end()); std::vector inp_strides(JUST(input->stride())->begin(), JUST(input->stride())->end()); std::vector inp_sizes_, inp_strides_; inp_sizes_.reserve(idim); inp_strides_.reserve(idim); for (int64_t i = idim - 1; i >= 0; i--) { auto size_i = inp_sizes[i]; auto stride_i = inp_strides[i]; if (size_i == 0) { return functional::Constant(*input->shape(), 0, grad->dtype(), JUST(grad->device())); } else if (size_i != 1) { inp_sizes_.insert(inp_sizes_.begin(), size_i); inp_strides_.insert(inp_strides_.begin(), stride_i); } } // Step (1)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A // Strided Tensor ] // on input geometry const bool inp_maybe_overlap = IsOverlappingMemorys(inp_sizes_, inp_strides_); // Rest of this function implements // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and // layout-aware/agnostic autograd ] // TODO: Raise if not all output values are visible in input geometry. // Technically speaking, if you treat those values as constants, not // raising is fine, and mathematically correct. However, these values // really are contained in some base tensor, and by treating them as // constants we are ignoring this tight dependency. Therefore, it is // more sensible to raise here. 
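// [Editorial aside] A small worked example of Steps (1)~(4) below, assuming a 1-D input of 5
// elements (sizes=[5], strides=[1], offset=0) and an output view produced by
// as_strided(sizes=[2,2], strides=[1,1], offset=0), i.e. out[i][j] = base[i + j]:
//   - The output geometry overlaps (base[1] is read by both out[0][1] and out[1][0]);
//     the input geometry does not.
//   - MinStorageSize gives 3 for the output geometry and 5 for the input geometry, so the
//     scratch "storage" tensor created in Step (1) holds base_size = max(3, 5) = 5 zeros.
//   - Step (2) scatters the incoming gradient with index_add at the flat indices
//     [0, 1, 1, 2], yielding storage = [g00, g01 + g10, g11, 0, 0].
//   - Step (3) is skipped because the input geometry has no overlap.
//   - Step (4) re-views storage with the input geometry and returns
//     [g00, g01 + g10, g11, 0, 0] as the gradient w.r.t. the input.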
// Step (1): create underlying tensor as "storage" auto input_storage_offset = JUST(input->storage_offset()); auto shared_offset = std::min(input_storage_offset, storage_offset); auto inp_effective_offset = input_storage_offset - shared_offset; auto out_effective_offset = storage_offset - shared_offset; auto base_size = std::max(MinStorageSize(inp_sizes_, inp_strides_, inp_effective_offset), MinStorageSize(out_sizes_, out_strides_, out_effective_offset)); auto storage = JUST(functional::Constant(Shape({base_size}), 0, grad->dtype(), JUST(grad->device()))); std::shared_ptr flatten_full_indices; if (inp_maybe_overlap || out_maybe_overlap) { flatten_full_indices = JUST(functional::Arange(Scalar(0), Scalar(base_size), Scalar(1), DType::Int64(), JUST(grad->device()))); } // Step (2): use output geometry to scatter gradients into storage if (out_maybe_overlap) { auto out_indices = JUST(functional::AsStrided(flatten_full_indices, out_sizes_, out_strides_, out_effective_offset)); storage = JUST(functional::IndexAddInplace( storage, 0, JUST(functional::Reshape(out_indices, Shape({out_indices->shape()->elem_cnt()}))), JUST(functional::Reshape(grad, Shape({grad->shape()->elem_cnt()}))), Scalar(1.0))); } else { // assume that new tensors have 0 storage offset // torch impl: storage.as_strided(out_sizes_, out_strides_, out_effective_offset) // .copy_(grad); // TODO(wangyinggang): use functional::copy_ replace this TensorSetItem storage = JUST(functional::AsStrided(storage, out_sizes_, out_strides_, out_effective_offset)); functional::TensorIndex ellipsis_index; ellipsis_index.emplace_back(functional::detail::EllipsisIndex()); JUST(functional::TensorSetItem(storage, ellipsis_index, grad)); } // Step (3): if input tensor has overlapping memory, divide scattered gradient // at storage[i] by the number of times i shows up in input geometry if (inp_maybe_overlap) { auto count = JUST(functional::Constant(*storage->shape(), 0, storage->dtype(), JUST(storage->device()))); flatten_full_indices = JUST(functional::AsStrided(flatten_full_indices, inp_sizes_, inp_strides_, inp_effective_offset)); auto inp_indices = JUST(functional::Reshape( flatten_full_indices, Shape({flatten_full_indices->shape()->elem_cnt()}))); auto ones = JUST(functional::Constant(Shape({1}), 0, grad->dtype(), JUST(grad->device()))); count = JUST(functional::IndexAddInplace(count, 0, inp_indices, ones, Scalar(1.0))); count = JUST(functional::Expand(count, *inp_indices->shape())); storage = JUST(functional::Div(storage, count)); // this will give nan outside visible range } // Step (4): return as_strided view of the storage tensor with input geometry return functional::AsStrided(storage, inp_sizes, inp_strides, inp_effective_offset); } Maybe AsStrided(const std::shared_ptr& input, const std::vector& sizes, const std::vector& strides, const int64_t storage_offset) { DimVector dim_vec; dim_vec.insert(dim_vec.end(), sizes.begin(), sizes.end()); Shape target_shape(dim_vec); Stride stride(strides.begin(), strides.end()); auto output = JUST(view::BasicView(input, target_shape, stride, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out grad size should be 1, but got " << out_grads.size(); in_grads->resize(1); (*in_grads)[0] = JUST(AsStridedGrad(out_grads[0], input, sizes, strides, 
storage_offset)); return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::as_strided_backward", backward_fn, {input}, &outputs)); } return output; } Maybe InplaceAsStrided(const std::shared_ptr& input, const std::vector& sizes, const std::vector& strides, const int64_t storage_offset) { DimVector dim_vec; dim_vec.insert(dim_vec.end(), sizes.begin(), sizes.end()); Shape target_shape(dim_vec); Stride stride(strides.begin(), strides.end()); JUST(view::InplaceView(input, target_shape, stride, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out grad size should be 1, but got " << out_grads.size(); in_grads->resize(1); (*in_grads)[0] = JUST(AsStridedGrad(out_grads[0], input, sizes, strides, storage_offset)); return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{input}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::inplace_as_strided_backward", backward_fn, {input}, &outputs)); } return Maybe::Ok(); } Maybe Transpose(const std::shared_ptr& input, const std::vector& permute) { const auto& shape = input->shape(); const auto& strides = JUST(input->stride()); const int64_t ndim = shape->NumAxes(); int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); CHECK_EQ_OR_RETURN(permute.size(), ndim) << "permute size should be equal to input tensor's ndim, but got " << permute.size(); auto positive_perm = permute; for (auto i = 0; i < positive_perm.size(); i++) { positive_perm[i] = JUST(maybe_wrap_dim(positive_perm[i], ndim)); } DimVector target_dims(ndim); Stride stride(ndim); for (int i = 0; i < ndim; ++i) { target_dims[i] = shape->At(permute[i]); stride[i] = strides->at(permute[i]); } auto output = JUST(BasicView(input, Shape(target_dims), stride, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { std::vector grad_perm; grad_perm.resize(ndim); for (int i = 0; i < ndim; ++i) { grad_perm[permute[i]] = i; } autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out grad size should be 1, but got " << out_grads.size(); in_grads->resize(1); (*in_grads)[0] = JUST(functional::Transpose(out_grads[0], grad_perm)); return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::transpose_backward", backward_fn, {input}, &outputs)); } return output; } Maybe UnfoldTensor(const std::shared_ptr& input, const int32_t dimension, const int32_t size, const int32_t step) { const auto& shape = input->shape(); const auto& stride = JUST(input->stride()); const int64_t ndim = shape->NumAxes(); int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); CHECK_GE_OR_RETURN(dimension, 0) << "attibute dimension should be >= 0, but got " << dimension; CHECK_LE_OR_RETURN(dimension, ndim) << "attibute dimension should be <= input tensor's ndim, but got " << dimension; const int32_t max_size = ndim == 0 ? 
1 : shape->At(dimension); CHECK_GT_OR_RETURN(size, 0) << "attribute size should be > 0, but got " << size; CHECK_LE_OR_RETURN(size, max_size) << "attribute size should be <= max_size(" << max_size << ") but got " << size; CHECK_GT_OR_RETURN(step, 0) << "attribute step should be > 0, but got " << step; DimVector out_shape(ndim + 1); Stride out_stride(ndim + 1); out_shape[ndim] = size; out_stride[ndim] = ndim == 0 ? 1 : stride->at(dimension); for (int64_t d = 0; d < ndim; ++d) { const int64_t in_size_at_d = shape->At(d); if (d == dimension) { out_shape.at(d) = (in_size_at_d - size) / step + 1; out_stride.at(d) = step * stride->at(d); } else { out_shape.at(d) = in_size_at_d; out_stride.at(d) = stride->at(d); } } auto output = JUST(BasicView(input, Shape(out_shape), out_stride, storage_offset)); if (autograd::GradMode::is_enabled() && input->requires_grad()) { auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out grad size should be 1, but got " << out_grads.size(); in_grads->resize(1); (*in_grads)[0] = JUST(functional::UnfoldTensorGrad(out_grads[0], input, dimension, size, step)); return Maybe::Ok(); }; backward_fn->status = []() { return true; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::unfold_tensor_backward", backward_fn, {input}, &outputs)); } return output; } Maybe Diagonal(const std::shared_ptr& input, const int32_t offset, const int32_t dim1, const int32_t dim2) { const auto& shape = input->shape(); const auto& stride = JUST(input->stride()); const int64_t ndim = shape->NumAxes(); int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset()); // infer output storage_offset int64_t diag_size = 0; if (offset >= 0) { diag_size = std::max(std::min(shape->At(dim1), shape->At(dim2) - offset), 0); } else { diag_size = std::max(std::min(shape->At(dim1) + offset, shape->At(dim2)), 0); } if (diag_size == 0) { // skip } else if (offset >= 0) { storage_offset += offset * stride->at(dim2); } else { storage_offset -= offset * stride->at(dim1); } CHECK_GE_OR_RETURN(ndim, 2) << "input tensor's ndim should be >= 2, but got " << ndim; // infer output shape and stride DimVector out_shape(shape->dim_vec()); Stride out_stride(*stride); out_shape.erase(out_shape.begin() + std::max(dim1, dim2)); out_stride.erase(out_stride.begin() + std::max(dim1, dim2)); out_shape.erase(out_shape.begin() + std::min(dim1, dim2)); out_stride.erase(out_stride.begin() + std::min(dim1, dim2)); out_shape.emplace_back(diag_size); out_stride.emplace_back(stride->at(dim1) + stride->at(dim2)); // generate view tensor auto output = JUST(BasicView(input, Shape(out_shape), out_stride, storage_offset)); // autograd if (autograd::GradMode::is_enabled() && input->requires_grad()) { std::vector input_index{dim1, dim2}; for (int32_t i = 0; i < ndim; i++) { if (i != dim1 && i != dim2) { input_index.push_back(i); } } auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, bool create_graph) -> Maybe { autograd::AutoGradMode mode(create_graph); CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out grad size should be 1, but got " << out_grads.size(); in_grads->resize(1); std::shared_ptr d_x = JUST(functional::Transpose(input, input_index)); (*in_grads)[0] = JUST(functional::DiagonalGrad(out_grads[0], d_x, offset)); return Maybe::Ok(); }; backward_fn->status =
[]() { return true; }; TensorTuple outputs{output}; JUST(GetThreadLocalAutogradEngine()->AddNode("view::diagonal_backward", backward_fn, {input}, &outputs)); } return output; } } // namespace view Maybe Touch(std::shared_ptr input, Symbol stream) { auto eager_blob_objects = std::make_shared(); if (input->is_global()) { input = JUST(input->cur_rank_phy_tensor()); } if (input) { eager_blob_objects->push_back(JUST(input->eager_blob_object())); } JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->TouchTensors(eager_blob_objects, stream); })); return Maybe::Ok(); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor_methods.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_METHODS_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_METHODS_H_ #include "oneflow/core/framework/tensor.h" namespace oneflow { class Stream; namespace one { class Tensor; namespace view { bool IsEnvViewDisabled(); bool IsViewApplicable(const std::shared_ptr& input); static bool IsOverlappingMemorys(const std::vector& sizes, const std::vector& strides); static int64_t MinStorageSize(const std::vector& sizes, const std::vector& strides, int64_t storage_offset); Maybe BasicView(const std::shared_ptr& input, const Shape& target_shape, const int64_t storage_offset); Maybe BasicView(const std::shared_ptr& input, const Shape& target_shape, const Stride& target_stride, const int64_t storage_offset); Maybe InplaceView(const std::shared_ptr& input, const Shape& target_shape, const Stride& target_stride, int64_t const storage_offset); Maybe Reshape(const std::shared_ptr& input, const Shape& target_shape); Maybe Reshape(const std::shared_ptr& input, const Shape& target_shape, const Stride& target_stride); Maybe Slice(const std::shared_ptr& input, const std::vector& starts, const std::vector& ends, const std::vector& steps); Maybe Unsqueeze(const std::shared_ptr& input, const int32_t expand_dim); Maybe InplaceUnsqueeze(const std::shared_ptr& input, const int32_t expand_dim); Maybe Squeeze(const std::shared_ptr& input, const std::vector& squeeze_dims); Maybe InplaceSqueeze(const std::shared_ptr& input, const std::vector& squeeze_dims); Maybe Expand(const std::shared_ptr& input, const Shape& expand_shape); Maybe InplaceExpand(const std::shared_ptr& input, const Shape& expand_shape); Maybe Narrow(const std::shared_ptr& input, const int64_t dim, const int64_t start, const int64_t length); Maybe AsStridedGrad(const std::shared_ptr& dy, const std::shared_ptr& input, const std::vector& sizes, const std::vector& strides, const int64_t storage_offset); Maybe AsStrided(const std::shared_ptr& input, const std::vector& sizes, const std::vector& strides, const int64_t storage_offset); Maybe InplaceAsStrided(const std::shared_ptr& input, const std::vector& sizes, const std::vector& strides, const int64_t storage_offset); Maybe Transpose(const 
std::shared_ptr& input, const std::vector& permute); Maybe UnfoldTensor(const std::shared_ptr& input, const int32_t dimension, const int32_t size, const int32_t step); Maybe Diagonal(const std::shared_ptr& input, const int32_t offset, const int32_t dim1, const int32_t dim2); } // namespace view Maybe Touch(std::shared_ptr input, Symbol stream); } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_METHOD_H_ ================================================ FILE: oneflow/core/framework/tensor_name_scope.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/tensor_name_scope.h" #include namespace oneflow { namespace one { /* static */ TensorNameScope* TensorNameScope::Global() { static TensorNameScope scope; return &scope; } const std::string& TensorNameScope::Lookup(const Tensor* tensor) const { uint64_t key = reinterpret_cast(tensor); const auto* tensor_names = [&]() { if (tensor->is_lazy()) { return &lazy_tensor_names_; } return &eager_tensor_names_; }(); std::lock_guard lock(mutex_); const auto& it = tensor_names->find(key); if (it != tensor_names->end()) { return it->second; } else { return default_tensor_name_; } } const std::string& TensorNameScope::Lookup(const std::shared_ptr& tensor) const { return Lookup(tensor.get()); } void TensorNameScope::Record(const Tensor* tensor, const std::string& name) { uint64_t key = reinterpret_cast(tensor); auto* tensor_names = [&]() { if (tensor->is_lazy()) { return &lazy_tensor_names_; } return &eager_tensor_names_; }(); std::lock_guard lock(mutex_); // We assume that the name of the tensor will be update more than once. (*tensor_names)[key] = name; } void TensorNameScope::Record(const std::shared_ptr& tensor, const std::string& name) { Record(tensor.get(), name); } void TensorNameScope::Clear() { std::lock_guard lock(mutex_); lazy_tensor_names_.clear(); eager_tensor_names_.clear(); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor_name_scope.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_NAME_SCOPE_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_NAME_SCOPE_H_ #include #include "oneflow/core/framework/tensor.h" namespace oneflow { namespace one { class TensorNameScope { public: static TensorNameScope* Global(); const std::string& Lookup(const Tensor* tensor) const; const std::string& Lookup(const std::shared_ptr& tensor) const; void Record(const Tensor* tensor, const std::string& name); void Record(const std::shared_ptr& tensor, const std::string& name); void Clear(); private: TensorNameScope() : default_tensor_name_("") {} virtual ~TensorNameScope() = default; private: mutable std::mutex mutex_; std::string default_tensor_name_; // uint64_t(Tensor*) -> the name of the tensor. std::unordered_map lazy_tensor_names_; std::unordered_map eager_tensor_names_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_NAME_SCOPE_H_ ================================================ FILE: oneflow/core/framework/tensor_rpc_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/sync_symbol_global_tensor_meta.h" #include "oneflow/core/framework/sync_symbol_nd_sbp.h" #include "oneflow/core/framework/synced_symbol_map.h" #include "oneflow/core/framework/rank_group_rpc_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/common/flat_shape.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/job/rank_group_scope.h" #include "oneflow/core/common/constant.h" namespace oneflow { namespace private_details { struct FlatTensorConsistency; class CheckConsistencyAsyncTransportCtx : public AsyncTransportCtx { public: CheckConsistencyAsyncTransportCtx(const TransportToken& transport_token, Symbol tensor_meta, const Optional>& consumer_nd_sbp_constraint, const TransportToken& tensor_transport_token) : AsyncTransportCtx(transport_token), tensor_meta_(tensor_meta), consumer_nd_sbp_constraint_(consumer_nd_sbp_constraint), tensor_transport_token_(tensor_transport_token) {} ~CheckConsistencyAsyncTransportCtx() override; Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override; Maybe PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override; Maybe Check() const; private: Symbol tensor_meta_; Optional> consumer_nd_sbp_constraint_; TransportToken tensor_transport_token_; std::shared_ptr flat_tensor_consistency_; }; // clang-format off FLAT_MSG_BEGIN(FlatTensorConsistency); public: static Maybe New() { const auto& consistency = std::make_shared(); consistency->clear(); return consistency; } static Maybe New( Symbol tensor_meta, const Optional>& consumer_nd_sbp_constraint, const TransportToken& tensor_transport_token) { const auto& consistency = 
std::make_shared(); consistency->clear(); JUST(consistency->Init(tensor_meta, consumer_nd_sbp_constraint, tensor_transport_token)); return consistency; } Maybe Check(Symbol tensor_meta, const Optional>& consumer_nd_sbp_constraint, const TransportToken& tensor_transport_token) { const auto& this_synced_tensor_meta = JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( this->synced_tensor_meta_symbol_id())); CHECK_OR_RETURN(this_synced_tensor_meta == tensor_meta); CHECK_EQ_OR_RETURN(consumer_nd_sbp_constraint.has_value(), this->has_consumer_nd_sbp_constraint_symbol_id()); if (this->has_consumer_nd_sbp_constraint_symbol_id()) { const auto& that_rank_constaint = JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( this->consumer_nd_sbp_constraint_symbol_id()))->nd_sbp(); const auto& this_rank_constaint = JUST(consumer_nd_sbp_constraint); CHECK_OR_RETURN(this_rank_constaint == that_rank_constaint); } CHECK_EQ_OR_RETURN(this->tensor_transport_token(), tensor_transport_token); return Maybe::Ok(); } private: Maybe Init(Symbol tensor_meta, const Optional>& consumer_nd_sbp_constraint, const TransportToken& tensor_transport_token) { this->set_synced_tensor_meta_symbol_id(JUST(SyncedSymbolMap::FindOrSync( tensor_meta, &SyncSymbolGlobalTensorMeta))); if (consumer_nd_sbp_constraint.has_value()) { const auto& this_rank_constaint = JUST(consumer_nd_sbp_constraint); this->set_consumer_nd_sbp_constraint_symbol_id( JUST(SyncedSymbolMap::FindOrSync( this_rank_constaint, &SyncSymbolNdSbp))); } else { this->clear_consumer_nd_sbp_constraint_symbol_id(); } this->set_tensor_transport_token(static_cast(tensor_transport_token)); return Maybe::Ok(); } FLAT_MSG_DEFINE_OPTIONAL(uint64_t, synced_tensor_meta_symbol_id); FLAT_MSG_DEFINE_OPTIONAL(uint64_t, consumer_nd_sbp_constraint_symbol_id); FLAT_MSG_DEFINE_OPTIONAL(uint64_t, tensor_transport_token); FLAT_MSG_END(FlatTensorConsistency); // clang-format on CheckConsistencyAsyncTransportCtx::~CheckConsistencyAsyncTransportCtx() {} Maybe CheckConsistencyAsyncTransportCtx::PrepareSendBufferAndCallback( int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { const auto& tensor_consistency = JUST(FlatTensorConsistency::New( tensor_meta_, consumer_nd_sbp_constraint_, tensor_transport_token_)); *buffer = tensor_consistency.get(); *size = sizeof(FlatTensorConsistency); *Callback = [tensor_consistency] {}; return Maybe::Ok(); } Maybe CheckConsistencyAsyncTransportCtx::PrepareRecvBufferAndCallback( int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { const auto& flat_tensor_consistency = JUST(FlatTensorConsistency::New()); *buffer = flat_tensor_consistency.get(); *size = sizeof(FlatTensorConsistency); *Callback = [flat_tensor_consistency]() {}; flat_tensor_consistency_ = flat_tensor_consistency; return Maybe::Ok(); } Maybe CheckConsistencyAsyncTransportCtx::Check() const { if (!flat_tensor_consistency_) { return Maybe::Ok(); } JUST(flat_tensor_consistency_->Check(tensor_meta_, consumer_nd_sbp_constraint_, tensor_transport_token_)); return Maybe::Ok(); } int64_t* MutThreadLocalTensorMetaCheckDepth() { static thread_local int64_t depth = 0; return &depth; } Maybe LaunchTensorMetaConsistencyCheck( const one::Tensor& tensor) { const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeCheckTensorConsistency)); const auto& tensor_meta = JUST(tensor.global_tensor_meta()); const auto& constaint = JUST(tensor.consumer_nd_sbp_constraint()); const TransportToken& 
tensor_transport_token = JUST(tensor.transport_token()); const auto& ctx = std::make_shared( transport_token, tensor_meta, constaint, tensor_transport_token); JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get())); JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get())); return ctx; } Maybe BusyWaitAndCheck(std::shared_ptr& ctx) { JUST_MSG(ctx->WaitDone(), kAsymmetricCodeErrorMsg); JUST(ctx->Check()); return Maybe::Ok(); } Maybe RunCallback(const std::shared_ptr& tensor, const std::function()>& Callback) { return Callback(); } } // namespace private_details } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor_rpc_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_ #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/common/check_level.h" namespace oneflow { namespace private_details { class CheckConsistencyAsyncTransportCtx; int64_t* MutThreadLocalTensorMetaCheckDepth(); Maybe LaunchTensorMetaConsistencyCheck( const one::Tensor& tensor); Maybe BusyWaitAndCheck(std::shared_ptr& ctx); Maybe RunCallback(const std::shared_ptr& tensor, const std::function()>& Callback); } // namespace private_details inline bool IsGlobalTensorMetaCheckDisabled() { return *private_details::MutThreadLocalTensorMetaCheckDepth() > 1; } template struct CheckGlobalTensorMeta; template struct CheckGlobalTensorMeta&, Args...> { static_assert(is_maybe::value, "returned value type must be Maybe."); template&, Args...)> static RetT Call(const std::shared_ptr& tensor, Args... args) { std::shared_ptr ctx; static bool is_env_enabled_check = IsEnvEnabled(/* check_level */ 1); int64_t* depth = private_details::MutThreadLocalTensorMetaCheckDepth(); if (*depth == 0 && is_env_enabled_check) { ctx = JUST(private_details::LaunchTensorMetaConsistencyCheck(*tensor)); } ++*depth; RetT ret = func(tensor, args...); --*depth; // Always synchronize global tensor meta even if `func` failed. 
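// [Editorial aside] The thread-local depth counter makes this decorator a re-entrancy guard:
// only the outermost decorated call (depth == 0 on entry) launches the transport round and
// busy-waits on it afterwards; nested decorated calls, and anything wrapped in
// DisableCheckGlobalTensorMetaScope below, merely bump the counter. Stripped of the transport
// details, the control flow of Call() is roughly the following (hypothetical shorthand,
// illustration only):
//   ctx = (depth == 0 && check_enabled) ? LaunchCheck(tensor) : nullptr;
//   ++depth;
//   ret = func(tensor, args...);          // may itself re-enter decorated calls
//   --depth;
//   if (depth == 0 && check_enabled) { BusyWaitAndCheck(ctx); }
//   return ret;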
if (*depth == 0 && is_env_enabled_check) { JUST(private_details::BusyWaitAndCheck(ctx)); } return ret; } }; struct DisableCheckGlobalTensorMetaScope final { DisableCheckGlobalTensorMetaScope() { ++*private_details::MutThreadLocalTensorMetaCheckDepth(); } ~DisableCheckGlobalTensorMetaScope() { --*private_details::MutThreadLocalTensorMetaCheckDepth(); } }; static constexpr auto* WithConsistencyChecked = DECORATE(&private_details::RunCallback, CheckGlobalTensorMeta); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_ ================================================ FILE: oneflow/core/framework/tensor_storage.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/tensor_storage.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/framework/shut_down_util.h" namespace oneflow { namespace one { TensorStorage::TensorStorage(const std::shared_ptr& tensor_storage) : storage_(tensor_storage) {} TensorStorage::~TensorStorage() { if (!IsShuttingDown() && releaser_hook_) { (*releaser_hook_)(storage_); } } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor_storage.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_STORAGE_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_STORAGE_H_ #include #include namespace oneflow { class ParallelDesc; namespace vm { class TensorStorage; } // namespace vm namespace one { class TensorStorage final { public: explicit TensorStorage(const std::shared_ptr& tensor_storage); ~TensorStorage(); using ReleaserHookT = std::function&)>; const std::shared_ptr storage() const { return storage_; } void set_releaser_hook(const ReleaserHookT& releaser_hook) { releaser_hook_ = std::make_shared(releaser_hook); } private: std::shared_ptr storage_; std::shared_ptr releaser_hook_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_STORAGE_H_ ================================================ FILE: oneflow/core/framework/tensor_tuple.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/tensor_tuple.h" namespace oneflow { namespace one { TensorTuple::TensorTuple(std::vector>::size_type size) { resize(size); } TensorTuple::TensorTuple(std::initializer_list> init_list) { for (const auto& tensor : init_list) { emplace_back(tensor); } } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor_tuple.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_TUPLE_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_TUPLE_H_ #include #include #include "oneflow/core/common/small_vector.h" #include "oneflow/core/common/op_args_reserved_size.h" namespace oneflow { namespace one { class Tensor; class TensorTuple final : public small_vector>, public std::enable_shared_from_this { public: // TensorTuple(const TensorTuple&) = delete; // TensorTuple(TensorTuple&) = delete; TensorTuple() = default; TensorTuple(std::vector>::size_type size); TensorTuple(std::initializer_list> init_list); ~TensorTuple() = default; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_TUPLE_H_ ================================================ FILE: oneflow/core/framework/tensor_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/common/blocking_then_busy.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { Maybe SyncAccessTensorWithTimeOut( const std::shared_ptr& tensor, const std::function&)>& Callback, const std::string& modifier) { auto btb = std::make_shared(); auto local_tensor = JUST(tensor->AsLocalTensor()); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SyncAccessBlobByCallback(local_tensor, btb, Callback, modifier); })); JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return Maybe::Ok(); } Maybe CopyLocalTensorDataTo(const std::shared_ptr& input, void* mem_ptr, size_t size) { CHECK_OR_RETURN(input->is_local()); // NOLINT CHECK_OR_RETURN(input->is_contiguous()) << Error::RuntimeError() << kOfBugIssueUploadPrompt; CHECK_EQ_OR_RETURN(input->shape()->elem_cnt() * JUST(input->dtype()->bytes()), size) << Error::RuntimeError() << kOfBugIssueUploadPrompt; if (input->nelement() == 1) { return GetItemInScalarTensor(input, mem_ptr, size); } std::shared_ptr local_tensor = JUST(input->AsLocalTensor()); const auto& Callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, mem_ptr, eager_blob_object->dptr(), size, memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; auto btb = std::make_shared(); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SyncAccessBlobByCallback(local_tensor, btb, Callback, "const"); })); JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return Maybe::Ok(); } Maybe GetTensorScope(const std::shared_ptr& tensor) { CHECK_OR_RETURN(LazyMode::is_enabled()) << "it's not allowed to access tensor scope in eager mode"; const auto& lbn = TensorNameScope::Global()->Lookup(tensor); CHECK_OR_RETURN(!lbn.empty()) << "can not access tensor scope since it is not a lazy tensor or a " "captured eager tensor in graph"; const auto& infer_ctx = JUST(GetCurInferCtx()); auto lbi = GenLogicalBlobId(lbn); const auto* op = JUST(infer_ctx->Op4OpName(lbi.op_name())); return Singleton>::Get()->MaybeGetPtr(op->op_conf().scope_symbol_id()); } Maybe GetItemInScalarTensor(const std::shared_ptr& scalar_tensor, void* scalar_ptr, size_t size) { CHECK_EQ_OR_RETURN(GetSizeOfDataType(scalar_tensor->dtype()->data_type()), size) << "invalid size"; CHECK_OR_RETURN(scalar_tensor->is_eager()) << "Only eager scalar tensor support GetItem."; CHECK_EQ_OR_RETURN(scalar_tensor->nelement(), 1) << "can only convert a tensor of size 1 to a Python scalar"; std::shared_ptr local_tensor; { auto tensor = scalar_tensor; if (tensor->is_global()) { Symbol parallel_desc; { const ParallelConf parallel_conf = GenParallelConfOfCpuOnAllRanks(); JUST(PhysicalRun( [¶llel_desc, ¶llel_conf](InstructionsBuilder* builder) -> Maybe { parallel_desc = SymbolOf(*JUST(builder->GetParallelDescSymbol(parallel_conf))); return Maybe::Ok(); })); } const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel()); tensor = JUST(functional::ToGlobal(tensor, parallel_desc, {broadcast_sbp}, /*grad_sbp=*/{}, /*check_meta=*/false, /*copy=*/false)); tensor = 
JUST(functional::GlobalToLocal(tensor, /*copy=*/false)); } local_tensor = JUST(tensor->AsLocalTensor()); } JUST(SyncReadSmallMem(reinterpret_cast(scalar_ptr), size, local_tensor)); return Maybe::Ok(); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/framework/tensor_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ #include #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/common/data_type.h" namespace oneflow { namespace ep { class Stream; } namespace vm { class EagerBlobObject; } namespace one { class Tensor; Maybe SyncAccessTensorWithTimeOut( const std::shared_ptr& tensor, const std::function&)>& callback, const std::string& modifier); Maybe CopyLocalTensorDataTo(const std::shared_ptr& input, void* mem_ptr, size_t size); Maybe GetTensorScope(const std::shared_ptr& tensor); Maybe GetItemInScalarTensor(const std::shared_ptr& scalar_tensor, void* scalar_ptr, size_t size); template Maybe GetItemInScalarTensor(const std::shared_ptr& scalar_tensor) { T scalar{0}; if constexpr (GetDataType() == kInt64) { if (scalar_tensor->dtype()->data_type() == DataType::kInt8 || scalar_tensor->dtype()->data_type() == kUInt8) { int8_t int8_integer = 0; JUST(GetItemInScalarTensor(scalar_tensor, &int8_integer, sizeof(int8_t))); scalar = static_cast(int8_integer); } else if (scalar_tensor->dtype()->data_type() == DataType::kInt16 || scalar_tensor->dtype()->data_type() == kUInt16) { int16_t int16_integer = 0; JUST(GetItemInScalarTensor(scalar_tensor, &int16_integer, sizeof(int16_t))); scalar = static_cast(int16_integer); } else if (scalar_tensor->dtype()->data_type() == DataType::kInt32 || scalar_tensor->dtype()->data_type() == kUInt32) { int32_t int32_integer = 0; JUST(GetItemInScalarTensor(scalar_tensor, &int32_integer, sizeof(int32_t))); scalar = static_cast(int32_integer); } else if (scalar_tensor->dtype()->data_type() == DataType::kInt64 || scalar_tensor->dtype()->data_type() == kUInt64) { int64_t int64_integer = 0; JUST(GetItemInScalarTensor(scalar_tensor, &int64_integer, sizeof(int64_t))); scalar = static_cast(int64_integer); } } else { JUST(GetItemInScalarTensor(scalar_tensor, &scalar, sizeof(T))); } return scalar; } } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ ================================================ FILE: oneflow/core/framework/to_string.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <string>

#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/ep/include/device_manager_registry.h"

namespace oneflow {

Maybe<std::string> DeviceTag4DeviceType(DeviceType device_type) {
  auto device_tag = ep::DeviceManagerRegistry::GetDeviceTypeNameByDeviceType(device_type);
  if (device_tag.empty()) {
    return Error::DeviceTagNotFoundError() << "invalid_device";
  } else {
    return device_tag;
  }
}

Maybe<DeviceType> DeviceType4DeviceTag(const std::string& device_tag) {
  auto device_type = ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device_tag);
  if (device_type == DeviceType::kInvalidDevice) {
    return Error::DeviceTagNotFoundError() << "device tag `" << device_tag << "' not found";
  } else {
    return device_type;
  }
}

}  // namespace oneflow

================================================
FILE: oneflow/core/framework/to_string.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_TO_STRING_H_
#define ONEFLOW_CORE_FRAMEWORK_TO_STRING_H_

#include "oneflow/core/common/to_string.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/device_type.h"
#include "oneflow/core/common/maybe.h"

namespace oneflow {

Maybe<std::string> DeviceTag4DeviceType(DeviceType device_type);

Maybe<DeviceType> DeviceType4DeviceTag(const std::string& device_tag);

template<>
inline std::string ToString(const DataType& data_type) {
  return DataType_Name(data_type);
}

template<>
inline std::string ToString(const DeviceType& device_type) {
  return DeviceType_Name(device_type);
}

template<>
inline std::string ToString(const std::string& value) {
  return value;
}

}  // namespace oneflow

#endif  // ONEFLOW_CORE_FRAMEWORK_TO_STRING_H_

================================================
FILE: oneflow/core/framework/transport_token.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
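// --- Annotation, not part of the original sources -------------------------------
// A minimal sketch of how CopyLocalTensorDataTo() from tensor_util.{h,cpp} above
// can be used to read an eager, contiguous local tensor back into host memory.
// DumpTensorToHost is a hypothetical helper written for illustration only.
#include <vector>
#include "oneflow/core/framework/tensor_util.h"

namespace oneflow {
Maybe<void> DumpTensorToHost(const std::shared_ptr<one::Tensor>& t, std::vector<char>* out) {
  // The byte count must equal elem_cnt() * dtype bytes, otherwise the
  // CHECK_EQ_OR_RETURN inside CopyLocalTensorDataTo() fails.
  size_t bytes = t->shape()->elem_cnt() * JUST(t->dtype()->bytes());
  out->resize(bytes);
  JUST(one::CopyLocalTensorDataTo(t, out->data(), bytes));
  return Maybe<void>::Ok();
}
}  // namespace oneflow
// ---------------------------------------------------------------------------------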
*/
#include
#include "oneflow/core/framework/transport_token.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/thread/thread_global_id.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"

namespace oneflow {

/*static*/ Maybe<TransportToken> TransportToken::NewTransportToken(TransportTokenType type) {
  int32_t thread_global_id = GetThisThreadGlobalId();
  CHECK_GE_OR_RETURN(thread_global_id, 0);                             // NOLINT
  CHECK_LT_OR_RETURN(thread_global_id, MaxNumberOfThreadGlobalUId());  // NOLINT
  return TransportToken(type, thread_global_id);
}

Maybe<void> TransportToken::CheckThreadGlobalId() const {
  int32_t thread_global_id = GetThisThreadGlobalId();
  CHECK_EQ_OR_RETURN(thread_global_id, this->thread_global_id());  // NOLINT
  return Maybe<void>::Ok();
}

Maybe<void> TransportToken::set_src_rank(int64_t val) {
  CHECK_GE_OR_RETURN(val, 0);
  CHECK_LT_OR_RETURN(val, GetMaxVal<uint16_t>());
  src_rank_ = val;
  return Maybe<void>::Ok();
}

Maybe<void> TransportToken::set_dst_rank(int64_t val) {
  CHECK_GE_OR_RETURN(val, 0);
  CHECK_LT_OR_RETURN(val, GetMaxVal<uint16_t>());
  dst_rank_ = val;
  return Maybe<void>::Ok();
}

}  // namespace oneflow

================================================
FILE: oneflow/core/framework/transport_token.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_
#define ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_

#include
#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"

namespace oneflow {

const static int kTransportTokenTypeBit = 5;
const static int kTransportTokenThreadGlobalIdBit = 3;

enum TransportTokenType {
  // Begin
  kTransportTokenTypeInvalid = 0,
  kTransportTokenTypeData,  // e.g. for tensor data transportation
  kTransportTokenTypeMeta,  // e.g.
for consistent id generating kTransportTokenTypeSyncSymbolParallelDesc, kTransportTokenTypeSyncSymbolNdSbp, kTransportTokenTypeSyncSymbolGlobalTensorMeta, kTransportTokenTypeCheckRankGroupConsistency, kTransportTokenTypeCheckTensorConsistency, kTransportTokenTypeSyncLocalShapeDtype, // End kTransportTokenTypeSize, }; static_assert(kTransportTokenTypeSize <= (1 << kTransportTokenTypeBit), ""); class TransportToken; template<> struct IsScalarType final { static const bool value = true; }; class TransportToken final { public: TransportToken() : TransportToken(kTransportTokenTypeInvalid, 0) {} TransportToken(const TransportToken&) = default; TransportToken(TransportToken&) = default; ~TransportToken() = default; static Maybe NewTransportToken(TransportTokenType type); static constexpr size_t MaxNumberOfThreadGlobalUId() { return (1 << kTransportTokenThreadGlobalIdBit); } Maybe CheckThreadGlobalId() const; bool operator==(const TransportToken& other) const { return static_cast(*this) == static_cast(other); } // Getters TransportTokenType type() const { return static_cast(type_); } int thread_global_id() const { return thread_global_id_; } int32_t seq_id() const { return seq_id_; } // Setters Maybe set_src_rank(int64_t val); Maybe set_dst_rank(int64_t val); operator uint64_t() const { return *reinterpret_cast(this); } TransportToken& operator++() { ++seq_id_; return *this; } private: TransportToken(TransportTokenType type, uint8_t thread_global_id) : src_rank_(0), dst_rank_(0), type_(static_cast(type)), thread_global_id_(thread_global_id), seq_id_(0) {} uint16_t src_rank_; uint16_t dst_rank_; uint8_t type_ : kTransportTokenTypeBit; // TransportTokenType uint8_t thread_global_id_ : kTransportTokenThreadGlobalIdBit; uint32_t seq_id_ : (32 - kTransportTokenTypeBit - kTransportTokenThreadGlobalIdBit); }; static_assert(sizeof(TransportToken) == sizeof(uint64_t), ""); } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::TransportToken& token) const { return std::hash()(static_cast(token)); } }; } // namespace std #endif // ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ ================================================ FILE: oneflow/core/framework/transport_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
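// --- Annotation, not part of the original sources -------------------------------
// A simplified sketch of the token/ctx idiom that transport_util.cpp below builds
// on: make a fresh TransportToken, wrap the send/recv buffer preparation in a
// NaiveAsyncTransportCtx, broadcast/collect, then wait. The helper name, the
// token type choice and the single int64_t payload are assumptions for
// illustration; real callers usually prepare a separate buffer per peer rank via
// the rank-aware Prepare callbacks.
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/framework/transport_util.h"

namespace oneflow {
Maybe<void> BroadcastAndCheckOneInt64(Symbol<RankGroup> rank_group, int64_t value) {
  TransportToken token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta));
  int64_t sent = value;
  int64_t received = 0;
  NaiveAsyncTransportCtx ctx(
      token,
      [&sent](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
        *buffer = &sent;          // what this rank sends to every other rank
        *size = sizeof(int64_t);
        *Cb = []() {};
        return Maybe<void>::Ok();
      },
      [&received](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
        *buffer = &received;      // where data from other ranks is written
        *size = sizeof(int64_t);
        *Cb = []() {};
        return Maybe<void>::Ok();
      });
  JUST(TransportUtil::BroadcastToAllOtherRanks(rank_group, token, &ctx));
  JUST(TransportUtil::CollectFromAllOtherRanks(rank_group, token, &ctx));
  JUST(ctx.WaitDone());  // blocks until every prepared send/recv callback has run
  CHECK_EQ_OR_RETURN(received, sent);  // every rank is expected to broadcast the same value
  return Maybe<void>::Ok();
}
}  // namespace oneflow
// ---------------------------------------------------------------------------------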
*/ #include #include #include "oneflow/core/framework/transport_token.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/transport/transport.h" #include "oneflow/core/thread/thread_global_id.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/spin_counter.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace { template (*SendOrRecv)(const TransportToken&, int64_t, void*, std::size_t, const std::function&), Maybe (AsyncTransportCtx::*Prepare)(int64_t, void**, std::size_t*, std::function*), typename ForEachRankT> Maybe AccessToOtherRanks(const ForEachRankT& ForEachRank, const TransportToken& token, AsyncTransportCtx* ctx) { auto* blocking_counter = ctx->mut_blocking_counter(); JUST(ForEachRank([&, blocking_counter](int64_t rank) -> Maybe { if (rank == GlobalProcessCtx::Rank()) { return Maybe::Ok(); } blocking_counter->Increase(); void* buffer = nullptr; std::size_t size = 0; std::function Callback; JUST((ctx->*Prepare)(rank, &buffer, &size, &Callback)); JUST(SendOrRecv(token, rank, buffer, size, [blocking_counter, Callback]() { Callback(); blocking_counter->Decrease(); })); return Maybe::Ok(); })); return Maybe::Ok(); } template (*SendOrRecv)(const TransportToken&, int64_t, void*, std::size_t, const std::function&), Maybe (AsyncTransportCtx::*Prepare)(int64_t, void**, std::size_t*, std::function*)> Maybe AccessToAllOtherRanks(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx) { const auto& ForEachRank = [&](const std::function(int64_t)>& DoEach) -> Maybe { return rank_group->ForEachRank(DoEach); }; return AccessToOtherRanks(ForEachRank, token, ctx); } template (RankGroup::*GetPrevOrNext)() const, Maybe (*SendOrRecv)(const TransportToken&, int64_t, void*, std::size_t, const std::function&), Maybe (AsyncTransportCtx::*Prepare)(int64_t, void**, std::size_t*, std::function*)> Maybe AccessToNearbyRank(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx) { CHECK_OR_RETURN(rank_group->ContainingCurrentRank()); const auto& ForEachRank = [&](const std::function(int64_t)>& DoEach) -> Maybe { return DoEach(JUST(((*rank_group).*GetPrevOrNext)())); }; return AccessToOtherRanks(ForEachRank, token, ctx); } namespace { Maybe> RawGetTransportToken(const TransportToken& token) { CHECK_EQ_OR_RETURN(token.seq_id(), 0); JUST(token.CheckThreadGlobalId()); auto auto_token = std::make_shared(token); return auto_token; } static constexpr auto* GetTransportToken = DECORATE(&RawGetTransportToken, ThreadLocal); Maybe GetAutoIncrementalTransportToken(int64_t src_rank, int64_t dst_rank, TransportToken token) { CHECK_EQ_OR_RETURN(token.seq_id(), 0); JUST(token.set_src_rank(src_rank)); JUST(token.set_dst_rank(dst_rank)); return ++**JUST(GetTransportToken(token)); } } // namespace Maybe Send(const TransportToken& token, int64_t rank, void* buffer, std::size_t size, const std::function& Callback) { #ifdef __linux__ int64_t src_rank = GlobalProcessCtx::Rank(); int64_t dst_rank = rank; TransportToken send_token = JUST(GetAutoIncrementalTransportToken(src_rank, dst_rank, token)); auto* transport = JUST(SingletonMaybe()); transport->Send(static_cast(send_token), rank, buffer, size, Callback); return Maybe::Ok(); #else UNIMPLEMENTED(); return Maybe::Ok(); #endif // __linux__ } Maybe Recv(const TransportToken& token, int64_t rank, void* buffer, std::size_t size, const std::function& Callback) { #ifdef __linux__ int64_t 
src_rank = rank; int64_t dst_rank = GlobalProcessCtx::Rank(); TransportToken recv_token = JUST(GetAutoIncrementalTransportToken(src_rank, dst_rank, token)); auto* transport = JUST(SingletonMaybe()); transport->Receive(static_cast(recv_token), rank, buffer, size, Callback); return Maybe::Ok(); #else UNIMPLEMENTED(); return Maybe::Ok(); #endif // __linux__ } } // namespace /*static*/ Maybe TransportUtil::BroadcastToAllOtherRanks(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx) { CHECK_OR_RETURN(rank_group->ContainingCurrentRank()); JUST(AccessToAllOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(rank_group, token, ctx)); return Maybe::Ok(); } /*static*/ Maybe TransportUtil::CollectFromAllOtherRanks(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx) { CHECK_OR_RETURN(rank_group->ContainingCurrentRank()); JUST(AccessToAllOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(rank_group, token, ctx)); return Maybe::Ok(); } /*static*/ Maybe TransportUtil::BroadcastToOtherRanks(Symbol src_rank_group, Symbol dst_rank_group, const TransportToken& token, AsyncTransportCtx* ctx) { if (src_rank_group->ContainingCurrentRank()) { JUST(AccessToAllOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>( dst_rank_group, token, ctx)); } return Maybe::Ok(); } /*static*/ Maybe TransportUtil::CollectFromOtherRanks(Symbol src_rank_group, Symbol dst_rank_group, const TransportToken& token, AsyncTransportCtx* ctx) { if (dst_rank_group->ContainingCurrentRank()) { JUST(AccessToAllOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>( src_rank_group, token, ctx)); } return Maybe::Ok(); } /*static*/ Maybe TransportUtil::SendToNextRankInRing(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx) { JUST( AccessToNearbyRank<&RankGroup::GetNextRankInRing, &Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(rank_group, token, ctx)); return Maybe::Ok(); } /*static*/ Maybe TransportUtil::ReceiveFromPrevRankInRing(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx) { JUST( AccessToNearbyRank<&RankGroup::GetPrevRankInRing, &Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(rank_group, token, ctx)); return Maybe::Ok(); } namespace { Maybe GetCurrentRankIndex(const std::vector& rank_heap) { for (int i = 0; i < rank_heap.size(); ++i) { if (rank_heap.at(i) == GlobalProcessCtx::Rank()) { return i; } } UNIMPLEMENTED_THEN_RETURN(); } } // namespace /*static*/ Maybe TransportUtil::SendDataToChildrenInHeap( const std::vector& rank_heap, const TransportToken& token, AsyncTransportCtx* ctx) { int64_t current_rank_index = JUST(GetCurrentRankIndex(rank_heap)); const auto& ForEachRank = [&](const std::function(int64_t)>& DoEach) -> Maybe { int64_t left_index = current_rank_index * 2 + 1; if (left_index < rank_heap.size()) { JUST(DoEach(rank_heap.at(left_index))); } int64_t right_index = current_rank_index * 2 + 2; if (right_index < rank_heap.size()) { JUST(DoEach(rank_heap.at(right_index))); } return Maybe::Ok(); }; return AccessToOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(ForEachRank, token, ctx); } /*static*/ Maybe TransportUtil::ReceiveDataFromParentInHeap( const std::vector& rank_heap, const TransportToken& token, AsyncTransportCtx* ctx) { int64_t current_rank_index = JUST(GetCurrentRankIndex(rank_heap)); const auto& ForEachRank = [&](const std::function(int64_t)>& DoEach) -> Maybe { if (current_rank_index == 0) { return Maybe::Ok(); } return 
DoEach(rank_heap.at((current_rank_index - 1) / 2)); }; return AccessToOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(ForEachRank, token, ctx); } /*static*/ Maybe TransportUtil::ReceiveDataFromRank(int64_t rank, const TransportToken& token, AsyncTransportCtx* ctx) { const auto& ForEachRank = [&](const std::function(int64_t)>& DoEach) -> Maybe { return DoEach(rank); }; return AccessToOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(ForEachRank, token, ctx); } /*static*/ Maybe TransportUtil::SendDataToRank(int64_t rank, const TransportToken& token, AsyncTransportCtx* ctx) { const auto& ForEachRank = [&](const std::function(int64_t)>& DoEach) -> Maybe { return DoEach(rank); }; return AccessToOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(ForEachRank, token, ctx); } } // namespace oneflow ================================================ FILE: oneflow/core/framework/transport_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/framework/transport_token.h" namespace oneflow { class AsyncTransportCtx { public: explicit AsyncTransportCtx(const TransportToken& transport_token) : transport_token_(transport_token), blocking_counter_(1) {} virtual ~AsyncTransportCtx() = default; const TransportToken& transport_token() const { return transport_token_; } BlockingCounter* mut_blocking_counter() { return &blocking_counter_; } Maybe WaitDone() { mut_blocking_counter()->Decrease(); return mut_blocking_counter()->WaitUntilCntEqualZero([]() -> Maybe { return true; }); } virtual Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) = 0; virtual Maybe PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) = 0; private: TransportToken transport_token_; BlockingCounter blocking_counter_; }; class NaiveAsyncTransportCtx final : public AsyncTransportCtx { public: NaiveAsyncTransportCtx( const TransportToken& transport_token, const std::function(void**, std::size_t*, std::function*)>& PrepareSend, const std::function(void**, std::size_t*, std::function*)>& PrepareRecv) : AsyncTransportCtx(transport_token), prepare_send_(PrepareSend), prepare_recv_(PrepareRecv) {} NaiveAsyncTransportCtx( const TransportToken& transport_token, const std::function(void**, std::size_t*, std::function*)>& PrepareSend, const std::function(int64_t, void**, std::size_t*, std::function*)>& PrepareRecvWithRank) : AsyncTransportCtx(transport_token), prepare_send_(PrepareSend), prepare_recv_with_rank_(PrepareRecvWithRank) {} NaiveAsyncTransportCtx( const TransportToken& transport_token, const std::function(int64_t, void**, std::size_t*, std::function*)>& 
PrepareSendWithRank, const std::function(void**, std::size_t*, std::function*)>& PrepareRecv) : AsyncTransportCtx(transport_token), prepare_send_with_rank_(PrepareSendWithRank), prepare_recv_(PrepareRecv) {} NaiveAsyncTransportCtx( const TransportToken& transport_token, const std::function(int64_t, void**, std::size_t*, std::function*)>& PrepareSendWithRank, const std::function(int64_t, void**, std::size_t*, std::function*)>& PrepareRecvWithRank) : AsyncTransportCtx(transport_token), prepare_send_with_rank_(PrepareSendWithRank), prepare_recv_with_rank_(PrepareRecvWithRank) {} ~NaiveAsyncTransportCtx() override = default; Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override { if (prepare_send_with_rank_) { return prepare_send_with_rank_(rank, buffer, size, Callback); } return prepare_send_(buffer, size, Callback); } Maybe PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override { if (prepare_recv_with_rank_) { return prepare_recv_with_rank_(rank, buffer, size, Callback); } return prepare_recv_(buffer, size, Callback); } private: std::function(void**, std::size_t*, std::function*)> prepare_send_; std::function(int64_t, void**, std::size_t*, std::function*)> prepare_send_with_rank_; std::function(void**, std::size_t*, std::function*)> prepare_recv_; std::function(int64_t, void**, std::size_t*, std::function*)> prepare_recv_with_rank_; }; class RankGroup; struct TransportUtil final { static Maybe SendToNextRankInRing(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe ReceiveFromPrevRankInRing(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe BroadcastToAllOtherRanks(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe CollectFromAllOtherRanks(Symbol rank_group, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe BroadcastToOtherRanks(Symbol src_rank_group, Symbol dst_rank_group, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe CollectFromOtherRanks(Symbol src_rank_group, Symbol dst_rank_group, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe SendDataToChildrenInHeap(const std::vector& rank_heap, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe ReceiveDataFromParentInHeap(const std::vector& rank_heap, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe ReceiveDataFromRank(int64_t rank, const TransportToken& token, AsyncTransportCtx* ctx); static Maybe SendDataToRank(int64_t rank, const TransportToken& token, AsyncTransportCtx* ctx); }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_ ================================================ FILE: oneflow/core/framework/user_op_attr.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/sequential.proto"; import "oneflow/core/common/data_type.proto"; import "oneflow/core/common/device.proto"; import "oneflow/core/common/memory_format.proto"; enum AttrType { kAtInt32 = 1; kAtInt64 = 2; kAtBool = 3; kAtFloat = 4; kAtDouble = 5; kAtString = 6; kAtShape = 7; kAtDataType = 8; kAtListInt32 = 9; kAtListInt64 = 10; kAtListFloat = 11; kAtListDataType = 12; kAtListShape = 13; kAtListString = 14; kAtStride = 15; kAtListStride = 16; kAtDevice = 17; kAtComplexDouble = 18; kAtMemoryFormat = 19; kAtBytes = 20; } message AttrValue { message 
ListInt32 { repeated int32 val = 1; }
  message ListInt64 { repeated int64 val = 1; }
  message ListFloat { repeated float val = 1; }
  message ListDataType { repeated DataType val = 1; }
  message ListShape { repeated ShapeProto val = 1; }
  message ListStride { repeated Int64ListProto val = 1; }
  // order and naming convention of the oneof field must be consistent with the enum AttrType
  message ListString { repeated string val = 1; }
  message ComplexDouble {
    required double real = 1;
    required double imag = 2;
  }
  oneof value {
    int32 at_int32 = 1;
    int64 at_int64 = 2;
    bool at_bool = 3;
    float at_float = 4;
    double at_double = 5;
    string at_string = 6;
    ShapeProto at_shape = 7;
    DataType at_data_type = 8;
    ListInt32 at_list_int32 = 9;
    ListInt64 at_list_int64 = 10;
    ListFloat at_list_float = 11;
    ListDataType at_list_data_type = 12;
    ListShape at_list_shape = 13;
    ListString at_list_string = 14;
    Int64ListProto at_stride = 15;
    ListStride at_list_stride = 16;
    DeviceProto at_device = 17;
    ComplexDouble at_complex_double = 18;
    MemoryFormat at_memory_format = 19;
    bytes at_bytes = 20;
  }
}

message AttrDef {
  required string name = 1;
  required string description = 2;
  required AttrValue default_val = 3;
}

================================================
FILE: oneflow/core/framework/user_op_conf.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
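// --- Annotation, not part of the original sources -------------------------------
// A small sketch of filling attributes through the protobuf API generated from
// user_op_attr.proto and user_op_conf.proto above. The attribute names "alpha"
// and "axes" and their values are arbitrary examples, not part of any schema.
#include "oneflow/core/framework/user_op_conf.pb.h"

namespace oneflow {
inline void FillExampleAttrs(UserOpConf* user_conf) {
  // Scalar attr: setting at_float selects the kAtFloat member of the oneof.
  (*user_conf->mutable_attr())["alpha"].set_at_float(0.1f);
  // List attr: ListInt32 is the nested message with "repeated int32 val = 1".
  auto* axes = (*user_conf->mutable_attr())["axes"].mutable_at_list_int32();
  axes->add_val(0);
  axes->add_val(2);
}
}  // namespace oneflow
// ---------------------------------------------------------------------------------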
*/ #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/attr_value_accessor.h" namespace oneflow { namespace user_op { UserOpConfWrapper::UserOpConfWrapper(std::shared_ptr op_conf) : op_conf_(op_conf) { CHECK(op_conf_); CHECK(op_conf_->has_user_conf()); attrs_ = MakeAttrMapFromUserOpConf(op_conf_->user_conf()); } UserOpConfWrapper::UserOpConfWrapper(const OperatorConf& op_conf) : UserOpConfWrapper(std::make_shared(op_conf)) {} const OperatorConf& UserOpConfWrapper::op_conf() const { return *op_conf_; } const UserOpConf& UserOpConfWrapper::user_op_conf() const { return op_conf_->user_conf(); } const std::string& UserOpConfWrapper::op_name() const { return op_conf_->name(); } const std::string& UserOpConfWrapper::op_type_name() const { return op_conf_->user_conf().op_type_name(); } const std::string& UserOpConfWrapper::input(const std::string& arg_name, int32_t index) const { auto it = op_conf_->user_conf().input().find(arg_name); CHECK(it != op_conf_->user_conf().input().end()) << "arg_name: " << arg_name << ", index: " << index; CHECK(index >= 0 && index < it->second.s_size()); return it->second.s(index); } const std::string& UserOpConfWrapper::output(const std::string& arg_name, int32_t index) const { auto it = op_conf_->user_conf().output().find(arg_name); CHECK(it != op_conf_->user_conf().output().end()) << "arg_name: " << arg_name << ", index: " << index; CHECK(index >= 0 && index < it->second.s_size()); return it->second.s(index); } bool UserOpConfWrapper::has_input(const std::string& arg_name, int32_t index) const { return input_size(arg_name) > index; } bool UserOpConfWrapper::has_output(const std::string& arg_name, int32_t index) const { return output_size(arg_name) > index; } int32_t UserOpConfWrapper::input_size(const std::string& arg_name) const { auto it = op_conf_->user_conf().input().find(arg_name); if (it == op_conf_->user_conf().input().end()) { return 0; } return it->second.s_size(); } int32_t UserOpConfWrapper::output_size(const std::string& arg_name) const { auto it = op_conf_->user_conf().output().find(arg_name); if (it == op_conf_->user_conf().output().end()) { return 0; } return it->second.s_size(); } const std::shared_ptr& UserOpConfWrapper::Attr4Name( const std::string& attr_name) const { const auto& attr = attrs_.Attr4Name(attr_name); CHECK(attr.get() != nullptr) << "attr_name: " << attr_name; return attr; } #define OP_WRAPPER_ATTR_MEMBER_FUNC(field, cpp_type, attr_type) \ template<> \ UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Attr(const std::string& attr_name, \ const cpp_type& val) { \ AttrValue attr_val; \ AttrValueAccessor::Attr(val, &attr_val); \ attr_.emplace(attr_name, attr_val); \ return *this; \ } OF_PP_FOR_EACH_TUPLE(OP_WRAPPER_ATTR_MEMBER_FUNC, ATTR_SEQ) #undef OP_WRAPPER_ATTR_MEMBER_FUNC UserOpWrapper::UserOpWrapper( const OperatorConf& op, const std::function& LogicalBlobDesc4BnInOp, const std::function& DiffLbi4BnInOp) : conf_(op), diff_fn_(DiffLbi4BnInOp) { auto InitTensorDescFromOpArgs = [&](const PbMap& args) { for (const auto& pair : args) { for (int32_t i = 0; i < pair.second.s_size(); ++i) { std::string bn = GenRepeatedBn(pair.first, i); const BlobDesc& blob_desc = LogicalBlobDesc4BnInOp(bn); CHECK((&blob_desc) != nullptr); BlobDescProto proto; 
blob_desc.ToProto(&proto); NaiveTensorDesc tensor_desc(proto); CHECK(bn2tensor_desc_.emplace(bn, tensor_desc).second); } } }; InitTensorDescFromOpArgs(op.user_conf().input()); InitTensorDescFromOpArgs(op.user_conf().output()); } const TensorDesc& UserOpWrapper::arg_tensor_desc(const std::string& arg_name, int32_t index) const { std::string bn = GenRepeatedBn(arg_name, index); CHECK(bn2tensor_desc_.find(bn) != bn2tensor_desc_.end()); return bn2tensor_desc_.at(bn); } const TensorDesc& UserOpWrapper::TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { return arg_tensor_desc(arg_name, index); } UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::InputBind( const std::string& arg_name, const std::string& logical_blob_name) { if (input_.find(arg_name) == input_.end()) { input_order_.emplace_back(arg_name); } input_[arg_name].emplace_back(logical_blob_name); CHECK_EQ(input_.size(), input_order_.size()); return *this; } UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Input(const std::string& arg_name, const std::string& logical_blob_name) { return InputBind(arg_name, logical_blob_name); } UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Output(const std::string& arg_name) { return Output(arg_name, 1); } UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Output(const std::string& arg_name, int32_t num) { CHECK(num >= 0); if (output_.find(arg_name) == output_.end()) { output_order_.emplace_back(arg_name); } output_[arg_name].resize(num); for (int32_t i = 0; i < num; ++i) { std::string bn = GenRepeatedBn(arg_name, i); output_[arg_name].at(i) = GenLogicalBlobName(op_name_, bn); } CHECK_EQ(output_.size(), output_order_.size()); return *this; } UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::ScopeSymbolId(int64_t scope_symbol_id) { scope_symbol_id_.set_value(scope_symbol_id); return *this; } UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::DeviceTag(const std::string& device_tag) { device_tag_ = device_tag; return *this; } UserOpConfWrapper UserOpConfWrapperBuilder::Build() { OperatorConf op_conf; op_conf.set_name(op_name_); if (!device_tag_.empty()) { op_conf.set_device_tag(device_tag_); } if (scope_symbol_id_.has_value()) { op_conf.set_scope_symbol_id(scope_symbol_id_.value()); } UserOpConf* user_conf = op_conf.mutable_user_conf(); user_conf->set_op_type_name(op_type_name_); auto GenArgs = [&](const HashMap>& src, PbMap* arg_name2lbns) { for (const auto& pair : src) { *(*arg_name2lbns)[pair.first].mutable_s() = StdVec2PbRpf(pair.second); } }; GenArgs(input_, user_conf->mutable_input()); GenArgs(output_, user_conf->mutable_output()); for (const auto& arg_name : input_order_) { user_conf->add_input_order(arg_name); } for (const auto& arg_name : output_order_) { user_conf->add_output_order(arg_name); } for (const auto& pair : attr_) { (*user_conf->mutable_attr())[pair.first] = pair.second; } wrapper_ = UserOpConfWrapper(*CHECK_JUST(CheckAndCompleteUserOpConfImpl(op_conf))); return wrapper_; } } // namespace user_op Maybe CheckArgDefIsValidInUserOpConf( const OperatorConf& op_conf, const PbMap& arg_name2lbns, const PbRpf& args) { const std::string& op_name = op_conf.name(); const std::string& op_type_name = op_conf.user_conf().op_type_name(); HashSet op_def_arg_names; for (const auto& arg : args) { int32_t arg_blob_num = 0; if (arg_name2lbns.find(arg.name()) != arg_name2lbns.end()) { arg_blob_num = arg_name2lbns.at(arg.name()).s_size(); } if (arg_blob_num == 0) { CHECK_OR_RETURN(arg.is_optional()) << " op_name: " << op_name << " op_type_name: " << op_type_name << 
" arg name: " << arg.name() << " in OpDef must have blob in op_conf: \n" << op_conf.DebugString(); } op_def_arg_names.insert(arg.name()); } for (const auto& pair : arg_name2lbns) { CHECK_OR_RETURN(op_def_arg_names.find(pair.first) != op_def_arg_names.end()) << " op_name: " << op_name << " op_type_name: " << op_type_name << " has not arg name: " << pair.first << " in OpDef"; } return Maybe::Ok(); } Maybe CheckUserOpConfArgOrderValid( const OperatorConf& op_conf, const PbMap& arg_name2lbns, const PbRpf& arg_order) { CHECK_EQ_OR_RETURN(arg_name2lbns.size(), arg_order.size()) << " op_conf: " << op_conf.DebugString() << " io order is not valid."; HashSet arg_names; for (const std::string& arg_name : arg_order) { CHECK_OR_RETURN(arg_names.insert(arg_name).second) << " op_conf: " << op_conf.DebugString() << " io order is not valid."; CHECK_OR_RETURN(arg_name2lbns.find(arg_name) != arg_name2lbns.end()) << " op_conf: " << op_conf.DebugString() << " io order is not valid."; } return Maybe::Ok(); } Maybe AddAttrDefaultValueAndCheckValid(const UserOpDef& op_def, UserOpConf* user_conf, const std::string& error_msg_prefix) { auto* attr_name2attr = user_conf->mutable_attr(); HashSet op_def_attr_names; for (const auto& attr : op_def.attr()) { if (attr_name2attr->find(attr.name()) == attr_name2attr->end()) { CHECK_OR_RETURN(attr.has_default_val()) << error_msg_prefix << " op_type_name: " << user_conf->op_type_name() << " must set attr val for attr_name: " << attr.name(); (*attr_name2attr)[attr.name()] = attr.default_val(); } op_def_attr_names.insert(attr.name()); } for (const auto& pair : user_conf->attr()) { CHECK_OR_RETURN(op_def_attr_names.find(pair.first) != op_def_attr_names.end()) << error_msg_prefix << " op_type_name: " << user_conf->op_type_name() << " has not attr_name: " << pair.first << " in OpDef"; } for (const auto& attr : op_def.attr()) { CHECK_OR_RETURN(static_cast(attr.type()) == static_cast(attr_name2attr->at(attr.name()).value_case())) << error_msg_prefix << " op_type_name: " << user_conf->op_type_name() << " attr_name: " << attr.name() << " has different attr type in OpDef and OpConf, it should be with type: " << AttrType_Name(attr.type()); } return Maybe::Ok(); } Maybe AddAttrDefaultValueAndCheckValid(UserOpConf* user_conf) { const user_op::OpRegistryResult* val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_conf->op_type_name()); CHECK_OR_RETURN(val) << " Cannot find op_type_name: " << user_conf->op_type_name(); const UserOpDef& op_def = val->op_def; return AddAttrDefaultValueAndCheckValid(op_def, user_conf, ""); } Maybe AddAttrDefaultValueAndCheckValid(const UserOpDef& op_def, OperatorConf* op_conf) { UserOpConf* user_conf = op_conf->mutable_user_conf(); std::string error_msg_prefix = " op_name: " + op_conf->name(); return AddAttrDefaultValueAndCheckValid(op_def, user_conf, error_msg_prefix); } Maybe GetAttrTypeImpl(const std::string& op_type_name, const std::string& attr_name) { const user_op::OpRegistryResult* val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name); CHECK_OR_RETURN(val) << " Cannot find op " << op_type_name; const UserOpDef& op_def = val->op_def; for (int32_t i = 0; i < op_def.attr_size(); ++i) { if (op_def.attr(i).name() == attr_name) { return op_def.attr(i).type(); } } CHECK_OR_RETURN(false) << " Cannot find attr " << attr_name << " in op " << op_type_name; } Maybe CheckAndCompleteUserOpConfImpl(const OperatorConf& op_conf) { CHECK_OR_RETURN(op_conf.has_user_conf()) << " Add default value only for user op"; OperatorConf ret = 
op_conf; UserOpConf* user_conf = ret.mutable_user_conf(); const user_op::OpRegistryResult* val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_conf->op_type_name()); CHECK_OR_RETURN(val) << " Cannot find op_type_name: " << user_conf->op_type_name(); const UserOpDef& op_def = val->op_def; JUST(AddAttrDefaultValueAndCheckValid(op_def, &ret)); // check input and output valid JUST(CheckArgDefIsValidInUserOpConf(op_conf, user_conf->input(), op_def.input())); JUST(CheckArgDefIsValidInUserOpConf(op_conf, user_conf->output(), op_def.output())); JUST(CheckUserOpConfArgOrderValid(op_conf, user_conf->input(), user_conf->input_order())); JUST(CheckUserOpConfArgOrderValid(op_conf, user_conf->output(), user_conf->output_order())); // check attr valid by user JUST(val->check_fn(user_op::UserOpDefWrapper(op_def), user_op::UserOpConfWrapper(ret))); return ret; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/user_op_conf.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_CONF_H_ #define ONEFLOW_CORE_FRAMEWORK_USER_OP_CONF_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/user_op_def.pb.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/framework/user_op_conf.pb.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/operator/op_conf.pb.h" namespace oneflow { class BlobDesc; namespace user_op { class OpArg final { public: OpArg(std::string&& name, int32_t index) : name_(std::move(name)), index_(index) {} const std::string& name() const { return name_; } int32_t index() const { return index_; } private: std::string name_; int32_t index_; }; class AttrVal; class UserOpConfWrapper final { public: UserOpConfWrapper(const OperatorConf&); UserOpConfWrapper(std::shared_ptr op_conf); const OperatorConf& op_conf() const; const UserOpConf& user_op_conf() const; const std::string& op_name() const; const std::string& op_type_name() const; const std::string& input(const std::string& arg_name, int32_t index) const; const std::string& output(const std::string& arg_name, int32_t index) const; bool has_input(const std::string& arg_name, int32_t index) const; bool has_output(const std::string& arg_name, int32_t index) const; int32_t input_size(const std::string& arg_name) const; int32_t output_size(const std::string& arg_name) const; template const T& attr(const std::string& attr_name) const { return CHECK_JUST(attrs_.GetAttr(attr_name)); } template const T& attr_or_default(const std::string& attr_name, const T& default_val) const { if (attrs_.Has(attr_name)) { return CHECK_JUST(attrs_.GetAttr(attr_name)); } else { return default_val; } } const std::shared_ptr& Attr4Name(const std::string& attr_name) const; private: UserOpConfWrapper() = default; friend class UserOpConfWrapperBuilder; std::shared_ptr 
op_conf_; AttrMap attrs_; }; class UserOpWrapper final { public: UserOpWrapper(const OperatorConf& op, const std::function&, const std::function&); public: const UserOpConfWrapper& user_op_conf() const { return conf_; } const OperatorConf& op_conf() const { return conf_.op_conf(); } const std::string& op_name() const { return conf_.op_name(); } const std::string& op_type_name() const { return conf_.op_type_name(); } int32_t input_size(const std::string& arg_name) const { return conf_.input_size(arg_name); } const std::string& input(const std::string& arg_name, int32_t index) const { return conf_.input(arg_name, index); } int32_t output_size(const std::string& arg_name) const { return conf_.output_size(arg_name); } const std::string& output(const std::string& arg_name, int32_t index) const { return conf_.output(arg_name, index); } template T attr(const std::string& attr_name) const { return conf_.attr(attr_name); } template T attr_or_default(const std::string& attr_name, const T& default_val) const { return conf_.attr_or_default(attr_name, default_val); } const TensorDesc& arg_tensor_desc(const std::string& arg_name, int32_t index) const; const TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const; private: UserOpConfWrapper conf_; std::function diff_fn_; HashMap bn2tensor_desc_; }; class UserOpConfWrapperBuilder final { public: UserOpConfWrapperBuilder(const std::string& op_name) : op_name_(op_name) {} UserOpConfWrapperBuilder& OpTypeName(const std::string& op_type_name) { op_type_name_ = op_type_name; return *this; } UserOpConfWrapperBuilder& Op(const std::string& op_type_name) { return OpTypeName(op_type_name); } UserOpConfWrapperBuilder& InputBind(const std::string& arg_name, const std::string& logical_blob_name); UserOpConfWrapperBuilder& Input(const std::string& arg_name, const std::string& logical_blob_name); UserOpConfWrapperBuilder& Output(const std::string& arg_name, int32_t num); UserOpConfWrapperBuilder& Output(const std::string& arg_name); template UserOpConfWrapperBuilder& Attr(const std::string& attr_name, const T& val); UserOpConfWrapperBuilder& ScopeSymbolId(int64_t scope_symbol_id); UserOpConfWrapperBuilder& DeviceTag(const std::string& device_tag); UserOpConfWrapper Build(); private: UserOpConfWrapper wrapper_; std::string op_name_; std::string op_type_name_; HashMap> input_; HashMap> output_; HashMap attr_; std::vector input_order_; std::vector output_order_; OptInt64 scope_symbol_id_; std::string device_tag_; }; } // namespace user_op Maybe GetAttrTypeImpl(const std::string& op_type_name, const std::string& attr_name); Maybe CheckAndCompleteUserOpConfImpl(const OperatorConf& op_conf); Maybe AddAttrDefaultValueAndCheckValid(UserOpConf* user_conf); } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_USER_OP_CONF_H_ ================================================ FILE: oneflow/core/framework/user_op_conf.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/framework/user_op_attr.proto"; message UserOpConf { message ListString { repeated string s = 1; } required string op_type_name = 1; map input = 2; map output = 3; map attr = 4; // NOTE(chengcheng): specify the input/output order according to the order called by // UserOpBuilder. 
repeated string input_order = 5; repeated string output_order = 6; } ================================================ FILE: oneflow/core/framework/user_op_def.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/attr_value_accessor.h" namespace oneflow { namespace user_op { UserOpDefWrapper::UserOpDefWrapper(const UserOpDef& def) : def_(def), inputs_(), outputs_(), attrs_() { for (int32_t i = 0; i < def_.input_size(); ++i) { inputs_.emplace(def_.input(i).name(), def_.mutable_input(i)); } for (int32_t i = 0; i < def_.output_size(); ++i) { outputs_.emplace(def_.output(i).name(), def_.mutable_output(i)); } for (int32_t i = 0; i < def_.attr_size(); ++i) { attrs_.emplace(def_.attr(i).name(), def_.mutable_attr(i)); } } bool UserOpDefWrapper::IsInputArgName(const std::string& name) const { return inputs_.find(name) != inputs_.end(); } bool UserOpDefWrapper::IsOutputArgName(const std::string& name) const { return outputs_.find(name) != outputs_.end(); } bool UserOpDefWrapper::IsAttrName(const std::string& name) const { return attrs_.find(name) != attrs_.end(); } bool UserOpDefWrapper::IsArgOptional(const std::string& name) const { const UserOpDef::ArgDef* arg_def = GetArgPointer(name); CHECK_NOTNULL(arg_def); return arg_def->is_optional(); } const UserOpDef::ArgDef* UserOpDefWrapper::GetArgPointer(const std::string& name) const { auto it = inputs_.find(name); if (it != inputs_.end()) { return it->second; } it = outputs_.find(name); if (it != outputs_.end()) { return it->second; } return nullptr; } AttrType UserOpDefWrapper::GetAttrType(const std::string& name) const { return attrs_.at(name)->type(); } bool UserOpDefWrapper::AttrHasDefaultVal(const std::string& name) const { return attrs_.at(name)->has_default_val(); } #define ATTR_TYPE_SPECIALIZATION(field, cpp_type, attr_type) \ template<> \ cpp_type UserOpDefWrapper::GetAttrDefaultVal(const std::string& name) const { \ CHECK(AttrHasDefaultVal(name)); \ const AttrValue& default_val = attrs_.at(name)->default_val(); \ CHECK_EQ(static_cast(attr_type), default_val.value_case()); \ return AttrValueAccessor::Attr(default_val); \ } OF_PP_FOR_EACH_TUPLE(ATTR_TYPE_SPECIALIZATION, ATTR_SEQ) #undef ATTR_TYPE_SPECIALIZATION } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/user_op_def.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
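// --- Annotation, not part of the original sources -------------------------------
// A sketch of the fluent UserOpConfWrapperBuilder API from user_op_conf.h above.
// The op name, the producer blob name "producer_op/y_0" and the scope id are
// illustrative values; "leaky_relu" with a float attr "alpha" is assumed to be a
// registered user op.
#include "oneflow/core/framework/user_op_conf.h"

namespace oneflow {
inline user_op::UserOpConfWrapper BuildExampleLeakyRelu(int64_t scope_symbol_id) {
  return user_op::UserOpConfWrapperBuilder("example_leaky_relu_op")
      .Op("leaky_relu")                 // op_type_name looked up in the op registry
      .Input("x", "producer_op/y_0")    // binds a logical blob name to arg "x"
      .Output("y")
      .Attr<float>("alpha", 0.1f)
      .ScopeSymbolId(scope_symbol_id)
      .Build();                         // fills attr defaults and validates the conf
}
}  // namespace oneflow
// ---------------------------------------------------------------------------------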
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_DEF_WRAPPER_H_ #define ONEFLOW_CORE_FRAMEWORK_USER_OP_DEF_WRAPPER_H_ #include "oneflow/core/framework/user_op_def.pb.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace user_op { class UserOpDefWrapper final { public: UserOpDefWrapper(const UserOpDef&); ~UserOpDefWrapper() = default; UserOpDefWrapper(const UserOpDefWrapper&) = delete; UserOpDefWrapper(UserOpDefWrapper&&) = delete; const std::string& name() const { return def_.name(); } bool IsInputArgName(const std::string&) const; bool IsOutputArgName(const std::string&) const; bool IsAttrName(const std::string&) const; bool IsArgOptional(const std::string&) const; AttrType GetAttrType(const std::string&) const; bool AttrHasDefaultVal(const std::string&) const; template T GetAttrDefaultVal(const std::string&) const; private: const UserOpDef::ArgDef* GetArgPointer(const std::string&) const; UserOpDef def_; HashMap inputs_; HashMap outputs_; HashMap attrs_; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_USER_OP_DEF_WRAPPER_H_ ================================================ FILE: oneflow/core/framework/user_op_def.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/framework/user_op_attr.proto"; message UserOpDef { required string name = 1; message ArgDef { required string name = 1; optional bool is_optional = 2 [default = false]; } repeated ArgDef input = 2; repeated ArgDef output = 3; message AttrDef { required string name = 1; required AttrType type = 2; optional AttrValue default_val = 3; } repeated AttrDef attr = 4; } ================================================ FILE: oneflow/core/framework/user_op_hob.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
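// --- Annotation, not part of the original sources -------------------------------
// A sketch of querying an op schema through UserOpDefWrapper (user_op_def.h above).
// The attribute name "axis" and the fallback value are examples only.
#include "oneflow/core/framework/user_op_def.h"

namespace oneflow {
inline int32_t AxisDefaultOrZero(const UserOpDef& def) {
  user_op::UserOpDefWrapper wrapper(def);
  if (wrapper.IsAttrName("axis") && wrapper.AttrHasDefaultVal("axis")) {
    // Default value comes from AttrDef.default_val in the op definition.
    return wrapper.GetAttrDefaultVal<int32_t>("axis");
  }
  return 0;
}
}  // namespace oneflow
// ---------------------------------------------------------------------------------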
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_ #define ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_ #include #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/common/high_order_bool.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_registry_manager.h" namespace oneflow { namespace user_op { ALWAYS_INLINE inline auto HobTrue() { std::ostringstream string_stream; string_stream << "\" always true \""; return hob::LiteralBool(string_stream.str(), true); } ALWAYS_INLINE inline auto HobFalse() { std::ostringstream string_stream; string_stream << "\" always false \""; return hob::LiteralBool(string_stream.str(), false); } ALWAYS_INLINE inline auto HobDataType(const std::string& tensor_name, int tensor_idx) { std::ostringstream string_stream; string_stream << "data_type of tensor \'" << tensor_name << "\'"; return hob::make_custom( string_stream.str(), [tensor_name, tensor_idx](const KernelRegContext& ctx) -> DataType { const user_op::TensorDesc* desc = ctx.TensorDesc4ArgNameAndIndex(tensor_name, tensor_idx); CHECK(desc != nullptr) << "key `" << tensor_name << "_" << tensor_idx << "` not found."; return desc->data_type(); }); } ALWAYS_INLINE inline auto HobInputSize(const std::string& tensor_name) { std::ostringstream string_stream; string_stream << "size of input \'" << tensor_name << "\'"; return hob::make_custom(string_stream.str(), [tensor_name](const KernelRegContext& ctx) -> int32_t { return ctx.user_op_conf().input_size(tensor_name); }); } template ALWAYS_INLINE inline auto HobAttr(const std::string& attr_name) { return hob::make_custom(attr_name, [attr_name](const user_op::KernelRegContext& ctx) -> const T& { return ctx.Attr(attr_name); }); } ALWAYS_INLINE inline auto HobDeviceType() { return hob::make_custom( "device_type", [](const KernelRegContext& ctx) -> DeviceType { return ctx.device_type(); }); } ALWAYS_INLINE inline auto HobDeviceSubTag() { return hob::make_custom("device_sub_tag", [](const KernelRegContext& ctx) -> const std::string& { return ctx.Attr("device_sub_tag"); }); } ALWAYS_INLINE inline auto HobEnvBool(const std::string& env_var, bool default_value) { std::ostringstream string_stream; string_stream << "environment variable \'" << env_var << "\'"; return hob::make_custom(string_stream.str(), [env_var, default_value](const KernelRegContext& ctx) -> bool { return ParseBooleanFromEnv(env_var, default_value); }); } } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_ ================================================ FILE: oneflow/core/framework/user_op_kernel_registry.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
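// --- Annotation, not part of the original sources -------------------------------
// A sketch of how the hob helpers above are combined into a kernel-matching
// predicate; "in", DataType::kFloat and the environment variable name are example
// values. Expressions of this shape are what SetIsMatchedHob() in the kernel
// registry (next file) consumes.
//
//   auto is_matched = (user_op::HobDeviceType() == DeviceType::kCPU)
//                     && (user_op::HobDataType("in", 0) == DataType::kFloat)
//                     && user_op::HobEnvBool("ONEFLOW_EXAMPLE_SWITCH", false);
// ---------------------------------------------------------------------------------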
*/ #include "oneflow/core/framework/user_op_kernel_registry.h" #include "oneflow/core/framework/user_op_hob.h" namespace oneflow { namespace user_op { OpKernelRegistry& OpKernelRegistry::Name(const std::string& op_type_name) { result_.op_type_name = op_type_name; return *this; } OpKernelRegistry& OpKernelRegistry::SetCreateFn(OpKernelCreateFn fn) { result_.create_fn = std::move(fn); return *this; } OpKernelRegistry& OpKernelRegistry::SetInferTmpSizeFn(InferTmpSizeFn fn) { result_.infer_tmp_size_fn = std::move(fn); return *this; } OpKernelRegistry& OpKernelRegistry::SetInplaceProposalFn(InplaceProposalFn fn) { result_.inplace_proposal_fn = std::move(fn); return *this; } OpKernelRegistry& OpKernelRegistry::SetPriority(int32_t priority) { result_.priority = priority; return *this; } Maybe OpKernelRegistry::Finish() { CHECK_OR_RETURN(result_.create_fn != nullptr) << "No Create function for " << result_.op_type_name; result_.need_temp_storage = (result_.infer_tmp_size_fn != nullptr); if (!result_.need_temp_storage) { result_.infer_tmp_size_fn = TmpSizeInferFnUtil::ZeroTmpSize; } if (result_.inplace_proposal_fn == nullptr) { result_.inplace_proposal_fn = [](const InferContext&, AddInplaceArgPair) { return Maybe::Ok(); }; } if (result_.is_matched_hob == nullptr) { static auto hob_true = std::make_shared(user_op::HobTrue()); result_.is_matched_hob = hob_true; } return *this; } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/user_op_kernel_registry.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_KERNEL_REGISTRY_H_ #define ONEFLOW_CORE_FRAMEWORK_USER_OP_KERNEL_REGISTRY_H_ #include "oneflow/core/common/device_type.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/common/high_order_bool.h" namespace oneflow { namespace user_op { class OpKernel; class TensorDesc; class InferContext; class KernelRegContext { public: virtual ~KernelRegContext() = default; virtual DeviceType device_type() const = 0; virtual const ParallelContext& parallel_ctx() const = 0; virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; virtual const UserOpConfWrapper& user_op_conf() const = 0; template const T& Attr(const std::string& attr_name) const { return AttrValueCast(*Attr4Name(attr_name)); } protected: KernelRegContext() = default; KernelRegContext(const KernelRegContext&) = delete; virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; }; using OpKernelCreateFn = std::function; using InferTmpSizeFn = std::function; using AddInplaceArgPair = std::function( const std::string& out_arg_name, int32_t out_arg_index, const std::string& in_arg_name, int32_t in_arg_index, bool is_mutable)>; using InplaceProposalFn = std::function(const InferContext&, AddInplaceArgPair)>; using IsMatchedHob = std::shared_ptr>; constexpr int kKernelPriorityFallback = -10; constexpr int kKernelPriorityDefault = 0; constexpr int kKernelPriorityOptimized = 10; constexpr int kKernelPriorityExperimental = 100; struct OpKernelRegistryResult { std::string op_type_name; OpKernelCreateFn create_fn; bool need_temp_storage; InferTmpSizeFn infer_tmp_size_fn; InplaceProposalFn inplace_proposal_fn; IsMatchedHob is_matched_hob; int32_t priority = kKernelPriorityDefault; }; class OpKernelRegistry final { public: OpKernelRegistry& Name(const std::string& op_type_name); template OpKernelRegistry& SetCreateFn() { return SetCreateFn([]() -> const OpKernel* { return NewOpKernel(); }); } template OpKernelRegistry& SetIsMatchedHob(const T& hob) { result_.is_matched_hob = std::make_shared(hob); return *this; } OpKernelRegistry& SetInferTmpSizeFn(InferTmpSizeFn fn); OpKernelRegistry& SetInplaceProposalFn(InplaceProposalFn fn); Maybe Finish(); OpKernelRegistryResult GetResult() { return result_; } OpKernelRegistry& SetCreateFn(OpKernelCreateFn fn); OpKernelRegistry& SetPriority(int32_t priority); private: OpKernelRegistryResult result_; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_USER_OP_KERNEL_REGISTRY_H_ ================================================ FILE: oneflow/core/framework/user_op_registry.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
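// --- Annotation, not part of the original sources -------------------------------
// A sketch of driving OpKernelRegistry (declared above) by hand; in practice
// kernels are normally registered through the REGISTER_USER_KERNEL macro.
// ExampleCpuKernel and the op type name "example_op" are hypothetical.
//
//   user_op::OpKernelRegistry registry;
//   registry.Name("example_op")
//       .SetCreateFn<ExampleCpuKernel>()   // creates a new ExampleCpuKernel instance
//       .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCPU)
//       .SetPriority(user_op::kKernelPriorityOptimized);
//   CHECK_JUST(registry.Finish());         // installs fallback tmp-size/inplace fns
//   user_op::OpKernelRegistryResult result = registry.GetResult();
// ---------------------------------------------------------------------------------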
*/ #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/sbp_context.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace user_op { namespace { bool InsertIfNotExists(const std::string& name, HashSet* unique_names) { if (unique_names->find(name) != unique_names->end()) { return false; } unique_names->emplace(name); return true; } } // namespace OpRegistry& OpRegistry::Name(const std::string& op_type_name) { CHECK(InsertIfNotExists(op_type_name, &unique_names_)); result_.op_type_name = op_type_name; return *this; } OpRegistry& OpRegistry::ArgImpl(bool is_input, const std::string& name, bool is_optional) { CHECK(InsertIfNotExists(name, &unique_names_)) << "op arg registered, name: " << name << ", op: " << result_.op_type_name; UserOpDef::ArgDef arg_def; { arg_def.set_name(name); arg_def.set_is_optional(is_optional); } if (is_input) { *(result_.op_def.mutable_input()->Add()) = arg_def; } else { *(result_.op_def.mutable_output()->Add()) = arg_def; } return *this; } #define OP_REG_ARG_MEMBER_FUNC(name_prefix, is_input, is_optional) \ OpRegistry& OpRegistry::name_prefix(const std::string& name) { \ return ArgImpl(is_input, name, is_optional); \ } OP_REG_ARG_MEMBER_FUNC(Input, true, false) OP_REG_ARG_MEMBER_FUNC(OptionalInput, true, true) OP_REG_ARG_MEMBER_FUNC(Output, false, false) OP_REG_ARG_MEMBER_FUNC(OptionalOutput, false, true) #undef OP_REG_ARG_MEMBER_FUNC OpRegistry& OpRegistry::SupportCpuOnly() { result_.cpu_only_supported = true; return *this; } OpRegistry& OpRegistry::SupportNonContiguous() { result_.non_contiguous_supported = true; return *this; } OpRegistry& OpRegistry::NoGrad() { result_.no_grad = true; return *this; } OpRegistry& OpRegistry::SetOutputBufferNum(int32_t num) { result_.same_output_regst_num = num; return *this; } OpRegistry& OpRegistry::Attr(const std::string& name, AttrType type) { CHECK(InsertIfNotExists(name, &unique_names_)); UserOpDef::AttrDef attr_def; attr_def.set_name(name); attr_def.set_type(type); *(result_.op_def.mutable_attr()->Add()) = attr_def; return *this; } namespace { void AddAttrWithDefault(OpRegistryResult* result, const std::string& name, AttrType type, std::function handler) { UserOpDef::AttrDef attr_def; attr_def.set_name(name); attr_def.set_type(type); handler(&attr_def); *(result->op_def.mutable_attr()->Add()) = std::move(attr_def); } } // namespace #define ATTR_MEMBER_FUNC(field, cpp_type, attr_type) \ template<> \ OpRegistry& OpRegistry::Attr(const std::string& name, AttrType type, \ const cpp_type& default_val) { \ CHECK_EQ(type, attr_type); \ return DefaultedAttr(name, type, [default_val](UserOpDef::AttrDef* attr_def) { \ AttrValueAccessor::Attr(default_val, attr_def->mutable_default_val()); \ }); \ } \ template<> \ OpRegistry& OpRegistry::Attr(const std::string& name, const cpp_type& default_val) { \ return DefaultedAttr( \ name, GetAttrType::value, [default_val](UserOpDef::AttrDef* attr_def) { \ AttrValueAccessor::Attr(default_val, attr_def->mutable_default_val()); \ }); \ } \ template<> \ OpRegistry& OpRegistry::Attr(const std::string& name) { \ return Attr(name, cpp_type()); \ } OF_PP_FOR_EACH_TUPLE(ATTR_MEMBER_FUNC, ATTR_SEQ) #undef ATTR_MEMBER_FUNC OpRegistry& OpRegistry::DefaultedAttr(const 
std::string& name, AttrType type, const std::function& SetDefault) { CHECK(InsertIfNotExists(name, &unique_names_)); AddAttrWithDefault(&result_, name, type, SetDefault); return *this; } OpRegistry& OpRegistry::SetTensorDescInferFn(TensorDescInferFn tensor_desc_infer_fn) { SetLogicalTensorDescInferFn(tensor_desc_infer_fn); SetPhysicalTensorDescInferFn(tensor_desc_infer_fn); return *this; } OpRegistry& OpRegistry::SetLogicalTensorDescInferFn(TensorDescInferFn tensor_desc_infer_fn) { result_.logical_tensor_desc_infer_fn = std::move(tensor_desc_infer_fn); return *this; } OpRegistry& OpRegistry::SetPhysicalTensorDescInferFn(TensorDescInferFn tensor_desc_infer_fn) { result_.physical_tensor_desc_infer_fn = std::move(tensor_desc_infer_fn); return *this; } OpRegistry& OpRegistry::SetCheckAttrFn(CheckAttrFn fn) { result_.check_fn = std::move(fn); return *this; } OpRegistry& OpRegistry::SetGetSbpFn(GetSbpFn get_sbp_fn) { result_.get_sbp_fn = std::move(get_sbp_fn); return *this; } OpRegistry& OpRegistry::SetSbpSignatureInferFn(SbpSignatureInferFn sbp_signature_infer_fn) { result_.sbp_signature_infer_fn = std::move(sbp_signature_infer_fn); return *this; } OpRegistry& OpRegistry::SetInputArgModifyFn(InputArgModifyFn input_arg_modify_fn) { result_.input_arg_modify_fn = std::move(input_arg_modify_fn); return *this; } OpRegistry& OpRegistry::SetOutputArgModifyFn(OutputArgModifyFn output_arg_modify_fn) { result_.output_arg_modify_fn = std::move(output_arg_modify_fn); return *this; } OpRegistry& OpRegistry::SetOutputBlobTimeShapeInferFn( OutputBlobTimeShapeInferFn output_blob_time_shape_infer_fn) { result_.output_blob_time_shape_infer_fn = std::move(output_blob_time_shape_infer_fn); return *this; } OpRegistry& OpRegistry::SetNdSbpInferFn(NdSbpInferFn nd_sbp_infer_fn) { result_.nd_sbp_infer_fn = std::move(nd_sbp_infer_fn); return *this; } OpRegistry& OpRegistry::SetDataTypeInferFn(DataTypeInferFn data_type_infer_fn) { result_.data_type_infer_fn = std::move(data_type_infer_fn); return *this; } OpRegistry& OpRegistry::SetDeviceAndStreamInferFn( DeviceAndStreamInferFn device_and_stream_infer_fn) { result_.device_and_stream_infer_fn = std::move(device_and_stream_infer_fn); return *this; } OpRegistry& OpRegistry::SetComputeComplexityFn(ComputeComplexityFn compute_complexity_fn) { result_.compute_complexity_fn = std::move(compute_complexity_fn); return *this; } OpRegistry& OpRegistry::SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn get_nd_sbp_list_fn) { result_.get_nd_sbp_list_fn = std::move(get_nd_sbp_list_fn); return *this; } OpRegistry& OpRegistry::SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn) { result_.enumerate_nd_sbp_signatures_fn = std::move(fn); return *this; } OpRegistry& OpRegistry::SetDumpNdSbpSignatureForOpConfFn( Operator::DumpNdSbpSignatureForOpConfFn fn) { result_.dump_nd_sbp_signature_for_op_conf_fn = std::move(fn); return *this; } Maybe OpRegistry::Finish() { CHECK_OR_RETURN(result_.logical_tensor_desc_infer_fn != nullptr) << "No TensorDescInfer function for " << result_.op_type_name; if (!result_.physical_tensor_desc_infer_fn) { const auto& logical_fn = result_.logical_tensor_desc_infer_fn; result_.physical_tensor_desc_infer_fn = [logical_fn](user_op::InferContext* ctx) -> Maybe { if (ctx->parallel_num() == 1) { logical_fn(ctx); } else { for (const auto& pair : ctx->inputs()) { const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(pair.first, pair.second); const TensorDesc* in_logical = ctx->LogicalTensorDesc4ArgNameAndIndex(pair.first, pair.second); const TensorDesc& in_physical = 
ctx->InputTensorDesc(pair.first, pair.second); CHECK_OR_RETURN(*JUST(GetPhysicalShape(in_logical->shape(), nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx())) == in_physical.shape()); } for (const auto& pair : ctx->outputs()) { TensorDesc* desc = ctx->MutOutputTensorDesc(pair.first, pair.second); *desc = *ctx->LogicalTensorDesc4ArgNameAndIndex(pair.first, pair.second); const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(pair.first, pair.second); desc->set_shape(*JUST( GetPhysicalShape(desc->shape(), nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx()))); desc->set_stride(Stride(desc->shape())); } } return Maybe::Ok(); }; } if (result_.check_fn == nullptr) { result_.check_fn = CheckAttrFnUtil::NoCheck; } CHECK_OR_RETURN(result_.get_sbp_fn != nullptr) << "No Sbp function for " << result_.op_type_name; if (result_.cpu_only_supported && result_.device_and_stream_infer_fn == nullptr) { result_.device_and_stream_infer_fn = [](DeviceAndStreamInferContext* ctx) -> Maybe> { for (const auto& pair : ctx->inputs()) { const Symbol& input_device = ctx->InputTensorDevice4ArgNameAndIndex(pair.first, pair.second); CHECK_EQ(input_device->type(), "cpu"); } Symbol default_device; { if (ctx->inputs().size() != 0) { const auto& first_input_name = ctx->inputs().begin()->first; default_device = ctx->InputTensorDevice4ArgNameAndIndex(first_input_name, 0); } else { default_device = JUST(Device::New("cpu")); } } for (const auto& pair : ctx->outputs()) { *ctx->OutputTensorDevice4ArgNameAndIndex(pair.first, pair.second) = default_device; } return Stream::New(default_device, StreamType::kCompute); }; } return *this; } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/user_op_registry.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_OP_REGISTRY_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_REGISTRY_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/user_op_def.pb.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/framework/user_op_conf.pb.h" #include "oneflow/core/operator/op_attribute.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { class Device; class Stream; namespace user_op { class UserOpDefWrapper; class UserOpConfWrapper; class InferContext; class SbpContext; class InferSbpSignatureFnContext; class InferOutputBlobTimeShapeFnContext; class InferNdSbpFnContext; class DeviceAndStreamInferContext; class ComputeComplexityFnContext; class GetNdSbpSignatureListContext; using CheckAttrFn = std::function(const UserOpDefWrapper&, const UserOpConfWrapper&)>; using TensorDescInferFn = std::function(InferContext*)>; using DataTypeInferFn = std::function(InferContext*)>; using DeviceAndStreamInferFn = std::function>(DeviceAndStreamInferContext*)>; using GetSbpFn = std::function(SbpContext*)>; using SbpSignatureInferFn = std::function(InferSbpSignatureFnContext*)>; using InputArgModifier = InputBlobModifier; using GetInputArgModifier = std::function; using InputArgModifyFn = std::function(const GetInputArgModifier&, const UserOpConfWrapper&)>; using OutputArgModifier = OutputBlobModifier; using GetOutputArgModifier = std::function; using OutputArgModifyFn = std::function(const GetOutputArgModifier&, const UserOpConfWrapper&)>; using OutputBlobTimeShapeInferFn = std::function(InferOutputBlobTimeShapeFnContext*)>; using NdSbpInferFn = std::function(InferNdSbpFnContext*)>; using ComputeComplexityFn = std::function(ComputeComplexityFnContext*)>; // TODO: set up another context using GetNdSbpSignatureListFn = std::function(GetNdSbpSignatureListContext*)>; using EnumerateNdSbpSignaturesFn = std::function(GetNdSbpSignatureListContext*)>; struct OpRegistryResult { OpRegistryResult() : cpu_only_supported(false), no_grad(false), non_contiguous_supported(false), same_output_regst_num(-1) {} ~OpRegistryResult() = default; std::string op_type_name; bool cpu_only_supported; bool no_grad; bool non_contiguous_supported; int32_t same_output_regst_num; UserOpDef op_def; CheckAttrFn check_fn; TensorDescInferFn logical_tensor_desc_infer_fn; TensorDescInferFn physical_tensor_desc_infer_fn; GetSbpFn get_sbp_fn; SbpSignatureInferFn sbp_signature_infer_fn; DataTypeInferFn data_type_infer_fn; DeviceAndStreamInferFn device_and_stream_infer_fn; // TODO(niuchong): move input_arg_modify_fn out of OpRegistryResult since it is more about // performance other than op definition InputArgModifyFn input_arg_modify_fn; OutputArgModifyFn output_arg_modify_fn; OutputBlobTimeShapeInferFn output_blob_time_shape_infer_fn; NdSbpInferFn nd_sbp_infer_fn; ComputeComplexityFn compute_complexity_fn; GetNdSbpSignatureListFn get_nd_sbp_list_fn; EnumerateNdSbpSignaturesFn enumerate_nd_sbp_signatures_fn; Operator::DumpNdSbpSignatureForOpConfFn dump_nd_sbp_signature_for_op_conf_fn; }; class OpRegistry final { public: OpRegistry& Name(const std::string& op_type_name); OpRegistry& Input(const std::string& name); OpRegistry& Input(const std::string& name, int32_t num); OpRegistry& InputWithMinimum(const std::string& name, int32_t min_num); OpRegistry& OptionalInput(const std::string& name); OpRegistry& OptionalInput(const std::string& name, int32_t num); OpRegistry& OptionalInputWithMinimum(const 
std::string& name, int32_t min_num); OpRegistry& Output(const std::string& name); OpRegistry& Output(const std::string& name, int32_t num); OpRegistry& OutputWithMinimum(const std::string& name, int32_t min_num); OpRegistry& OptionalOutput(const std::string& name); OpRegistry& OptionalOutput(const std::string& name, int32_t num); OpRegistry& OptionalOutputWithMinimum(const std::string& name, int32_t min_num); OpRegistry& SupportCpuOnly(); OpRegistry& SupportNonContiguous(); OpRegistry& NoGrad(); OpRegistry& SetOutputBufferNum(int32_t num); __attribute__((deprecated)) OpRegistry& Attr(const std::string& name, AttrType type); template __attribute__((deprecated)) OpRegistry& Attr(const std::string& name, AttrType type, const T& default_val); template OpRegistry& Attr(const std::string& name, const T& default_val); template OpRegistry& Attr(const std::string& name); OpRegistry& SetTensorDescInferFn(TensorDescInferFn fn); OpRegistry& SetLogicalTensorDescInferFn(TensorDescInferFn fn); OpRegistry& SetPhysicalTensorDescInferFn(TensorDescInferFn fn); OpRegistry& SetGetSbpFn(GetSbpFn fn); OpRegistry& SetSbpSignatureInferFn(SbpSignatureInferFn fn); OpRegistry& SetInputArgModifyFn(InputArgModifyFn fn); OpRegistry& SetOutputArgModifyFn(OutputArgModifyFn fn); OpRegistry& SetOutputBlobTimeShapeInferFn(OutputBlobTimeShapeInferFn fn); OpRegistry& SetNdSbpInferFn(NdSbpInferFn fn); OpRegistry& SetCheckAttrFn(CheckAttrFn fn); OpRegistry& SetDataTypeInferFn(DataTypeInferFn fn); OpRegistry& SetDeviceAndStreamInferFn(DeviceAndStreamInferFn fn); OpRegistry& SetComputeComplexityFn(ComputeComplexityFn fn); OpRegistry& SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn fn); OpRegistry& SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn); OpRegistry& SetDumpNdSbpSignatureForOpConfFn(Operator::DumpNdSbpSignatureForOpConfFn fn); Maybe Finish(); OpRegistryResult GetResult() { return result_; } private: OpRegistry& ArgImpl(bool is_input, const std::string& name, bool is_optional); OpRegistry& DefaultedAttr(const std::string& name, AttrType type, const std::function& SetDefault); private: HashSet unique_names_; OpRegistryResult result_; }; static const std::string kUserSourceOpTickInputArgName = "UserSourceOpTickInput"; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_OP_REGISTRY_H_ ================================================ FILE: oneflow/core/framework/user_op_registry_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
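A minimal op-definition sketch against the OpRegistry builder above. It assumes the infer functions take `user_op::InferContext*` / `user_op::SbpContext*` and return `Maybe<void>`, and that assigning the input TensorDesc to the output also carries the data type; "my_relu", its argument names, and the "alpha" attribute are made up for illustration.

REGISTER_USER_OP("my_relu")
    .Input("in")
    .Output("out")
    .Attr<float>("alpha", 1.0f)
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      // elementwise op: the output descriptor mirrors the input one
      *ctx->MutOutputTensorDesc("out", 0) = ctx->InputTensorDesc("in", 0);
      return Maybe<void>::Ok();
    })
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      // a dedicated data-type infer fn is mandatory at Register() time
      // (see user_op_registry_manager.cpp below); the desc copy also sets dtype
      *ctx->MutOutputTensorDesc("out", 0) = ctx->InputTensorDesc("in", 0);
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
      // Finish() insists on a GetSbpFn; a real op would enumerate its
      // split/broadcast signatures via ctx, which is only forward-declared here
      return Maybe<void>::Ok();
    });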
*/ #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/kernel/kernel.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { DEFINE_ENV_BOOL(ONEFLOW_KERNEL_ENABLE_PRIORITY_EXPERIMENTAL, false); namespace user_op { UserOpRegistryMgr& UserOpRegistryMgr::Get() { static UserOpRegistryMgr mgr; return mgr; } OpRegistry UserOpRegistryMgr::CheckAndGetOpRegistry(const std::string& op_type_name) { CHECK(!op_type_name.empty()); auto it = op_reg_result_.find(op_type_name); CHECK(it == op_reg_result_.end()); return OpRegistry().Name(op_type_name); } Maybe UserOpRegistryMgr::Register(OpRegistryResult result) { CHECK_OR_RETURN(result.data_type_infer_fn); CHECK_OR_RETURN(op_reg_result_.emplace(result.op_type_name, result).second); return Maybe::Ok(); } const OpRegistryResult* UserOpRegistryMgr::GetOpRegistryResult(const std::string& op_type_name) { auto it = op_reg_result_.find(op_type_name); if (it != op_reg_result_.end()) { return &(it->second); } return nullptr; } OpKernelRegistry UserOpRegistryMgr::CheckAndGetOpKernelRegistry(const std::string& op_type_name) { CHECK(!op_type_name.empty()); return OpKernelRegistry().Name(op_type_name); } Maybe UserOpRegistryMgr::Register(OpKernelRegistryResult result) { op_kernel_reg_result_[result.op_type_name].emplace_back(result); return Maybe::Ok(); } namespace { std::string GetErrorMsgOfSearchedOp(const KernelRegContext& ctx) { const auto& op_conf = ctx.user_op_conf(); std::stringstream ss; ss << " The Info of OperatorConf are " << "\n op_name: " << op_conf.op_name() << "\n op_type_name: " << op_conf.op_type_name() << "\n DeviceType_Name: " << DeviceType_Name(ctx.device_type()); for (const auto& pair : ctx.inputs()) { ss << "\n DataType_Name of " << pair.first << "_" << pair.second << ": " << DataType_Name(ctx.TensorDesc4ArgNameAndIndex(pair.first, pair.second)->data_type()); } for (const auto& pair : ctx.outputs()) { ss << "\n DataType_Name of " << pair.first << "_" << pair.second << ": " << DataType_Name(ctx.TensorDesc4ArgNameAndIndex(pair.first, pair.second)->data_type()); } return ss.str(); } } // namespace Maybe UserOpRegistryMgr::GetOpKernelRegistryResult( const std::string& op_type_name, const KernelRegContext& ctx) { auto it = op_kernel_reg_result_.find(op_type_name); if (it == op_kernel_reg_result_.end()) { return Error::OpKernelNotFoundError({}) << "There is no kernel registered for Current OperatorConf. " << GetErrorMsgOfSearchedOp(ctx); } const OpKernelRegistryResult* ret = nullptr; int32_t cur_priority = kKernelPriorityFallback; const bool enable_priority_experimental = EnvBool(); for (const auto& reg_val : it->second) { if (reg_val.priority >= kKernelPriorityExperimental && (!enable_priority_experimental)) { continue; } if (reg_val.is_matched_hob->get(ctx)) { if (ret == nullptr || reg_val.priority > cur_priority) { ret = ®_val; cur_priority = reg_val.priority; } else if (ret != nullptr && reg_val.priority == cur_priority) { LOG(WARNING) << "There are more than one kernels with same priority matching Current OperatorConf. 
" << GetErrorMsgOfSearchedOp(ctx); } else { // do nothing } } } if (ret == nullptr) { std::vector debug_msgs; for (const auto& reg_val : it->second) { debug_msgs.emplace_back(reg_val.is_matched_hob->DebugStr(ctx)); } return Error::OpKernelNotFoundError(debug_msgs) << "Cannot find the kernel matching Current OperatorConf. " << GetErrorMsgOfSearchedOp(ctx); } return ret; } Maybe UserOpRegistryMgr::IsOpKernelRegistered(const std::string& op_type_name, const KernelRegContext& ctx) { auto it = op_kernel_reg_result_.find(op_type_name); if (it == op_kernel_reg_result_.end()) { return false; } const bool enable_priority_experimental = EnvBool(); for (const auto& reg_val : it->second) { if (reg_val.priority >= kKernelPriorityExperimental && (!enable_priority_experimental)) { continue; } if (reg_val.is_matched_hob->get(ctx)) { return true; } } return false; } UserOpHostMemoryInputRegistry& UserOpHostMemoryInputRegistry::Get() { static UserOpHostMemoryInputRegistry mgr; return mgr; } Maybe UserOpHostMemoryInputRegistry::SetHostMemoryInput4Op(const std::string& op_type_name, const std::string& arg_name, int32_t index) { auto it = op_type_name2host_memory_input_args_.find(op_type_name); if (it == op_type_name2host_memory_input_args_.end()) { auto pair = op_type_name2host_memory_input_args_.emplace( op_type_name, small_vector>()); CHECK_OR_RETURN(pair.second); it = pair.first; } it->second.emplace_back(std::make_pair(arg_name, index)); return Maybe::Ok(); } bool UserOpHostMemoryInputRegistry::IsHostMemoryInput4Op(const std::string& op_type_name, const std::string& arg_name, int32_t index) const { auto it = op_type_name2host_memory_input_args_.find(op_type_name); if (it == op_type_name2host_memory_input_args_.end()) { return false; } return std::find(it->second.begin(), it->second.end(), std::make_pair(arg_name, index)) != it->second.end(); } bool UserOpHostMemoryInputRegistry::HasHostMemoryInput(const std::string& op_type_name) const { return op_type_name2host_memory_input_args_.find(op_type_name) != op_type_name2host_memory_input_args_.end(); } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/framework/user_op_registry_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_REGISTRY_MANAGER_H_ #define ONEFLOW_CORE_FRAMEWORK_USER_OP_REGISTRY_MANAGER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/user_op_kernel_registry.h" #include "oneflow/core/common/registry_error.h" #include "oneflow/core/common/op_args_reserved_size.h" namespace oneflow { namespace user_op { class UserOpRegistryMgr final { private: UserOpRegistryMgr() {} public: UserOpRegistryMgr(UserOpRegistryMgr const&) = delete; UserOpRegistryMgr& operator=(UserOpRegistryMgr const&) = delete; static UserOpRegistryMgr& Get(); public: OpRegistry CheckAndGetOpRegistry(const std::string& op_type_name); Maybe Register(OpRegistryResult result); const OpRegistryResult* GetOpRegistryResult(const std::string& op_type_name); OpKernelRegistry CheckAndGetOpKernelRegistry(const std::string& op_type_name); Maybe Register(OpKernelRegistryResult result); Maybe GetOpKernelRegistryResult(const std::string& op_type_name, const KernelRegContext& ctx); Maybe IsOpKernelRegistered(const std::string& op_type_name, const KernelRegContext& ctx); const HashMap& GetAllOpRegistryResults() { return op_reg_result_; }; private: HashMap op_reg_result_; HashMap> op_kernel_reg_result_; }; template struct UserOpRegisterTrigger final { UserOpRegisterTrigger(RegistryT& registry) { CatchRegistryError([&]() -> Maybe { return UserOpRegistryMgr::Get().Register(JUST(registry.Finish()).GetResult()); }); } }; class UserOpHostMemoryInputRegistry final { public: UserOpHostMemoryInputRegistry(UserOpHostMemoryInputRegistry const&) = delete; UserOpHostMemoryInputRegistry& operator=(UserOpHostMemoryInputRegistry const&) = delete; ~UserOpHostMemoryInputRegistry() = default; static UserOpHostMemoryInputRegistry& Get(); Maybe SetHostMemoryInput4Op(const std::string& op_type_name, const std::string& arg_name, int32_t index); bool IsHostMemoryInput4Op(const std::string& op_type_name, const std::string& arg_name, int32_t index) const; bool HasHostMemoryInput(const std::string& op_type_name) const; private: UserOpHostMemoryInputRegistry() {} HashMap>> op_type_name2host_memory_input_args_; }; } // namespace user_op } // namespace oneflow #define REGISTER_OP_HOST_MEMORY_INPUT(op_type_name, arg_name, index) \ COMMAND(CHECK_JUST(user_op::UserOpHostMemoryInputRegistry::Get().SetHostMemoryInput4Op( \ op_type_name, arg_name, index))); #define REGISTER_USER_OP(name) \ static ::oneflow::user_op::UserOpRegisterTrigger<::oneflow::user_op::OpRegistry> OF_PP_CAT( \ g_register_trigger, __COUNTER__) = \ ::oneflow::user_op::UserOpRegistryMgr::Get().CheckAndGetOpRegistry(name) #define REGISTER_CPU_ONLY_USER_OP(name) REGISTER_USER_OP(name).SupportCpuOnly() #define REGISTER_NO_GRAD_USER_OP(name) REGISTER_USER_OP(name).NoGrad() #define REGISTER_NO_GRAD_CPU_ONLY_USER_OP(name) REGISTER_NO_GRAD_USER_OP(name).SupportCpuOnly() #define REGISTER_USER_KERNEL(name) \ static ::oneflow::user_op::UserOpRegisterTrigger<::oneflow::user_op::OpKernelRegistry> \ OF_PP_CAT(g_register_trigger, __COUNTER__) = \ ::oneflow::user_op::UserOpRegistryMgr::Get().CheckAndGetOpKernelRegistry(name) #endif // ONEFLOW_CORE_FRAMEWORK_USER_OP_REGISTRY_MANAGER_H_ ================================================ FILE: oneflow/core/framework/user_op_tensor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
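A sketch of how the priority tiers and hob matching interact at lookup time, with `GenericAddKernel` and `FusedAddKernel` as hypothetical user_op::OpKernel subclasses: GetOpKernelRegistryResult keeps, among all registrations whose is_matched_hob holds for the current KernelRegContext, the one with the highest priority, and skips kernels at kKernelPriorityExperimental or above unless ONEFLOW_KERNEL_ENABLE_PRIORITY_EXPERIMENTAL is set.

REGISTER_USER_KERNEL("my_add")
    .SetCreateFn<GenericAddKernel>()
    .SetIsMatchedHob(user_op::HobTrue());             // stays at kKernelPriorityDefault

REGISTER_USER_KERNEL("my_add")
    .SetCreateFn<FusedAddKernel>()
    .SetIsMatchedHob(user_op::HobTrue())
    .SetPriority(user_op::kKernelPriorityOptimized);  // wins whenever both registrations match

Two matching kernels at the same priority only produce a warning; the first one encountered is used.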
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_TENSOR_H_ #define ONEFLOW_CORE_FRAMEWORK_USER_OP_TENSOR_H_ #include #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/memory_format.pb.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/memory/memory_case.pb.h" #include "oneflow/core/common/error.h" namespace oneflow { namespace user_op { class Tensor { public: #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wnon-virtual-dtor" // NOTE: Performance will be degraded if the destructor is virtual. // So please do NOT implement custom destructor in any child classes of user_op::Tensor, // and every fields of child classes should be of POD type. ~Tensor() = default; #pragma GCC diagnostic pop virtual ShapeView shape_view() const = 0; virtual MutShapeView mut_shape_view() = 0; virtual const Stride& stride() const = 0; virtual DataType data_type() const = 0; virtual MemoryFormat memory_format() const = 0; virtual const MemoryCase& mem_case() const = 0; virtual const void* raw_dptr() const = 0; virtual void* mut_raw_dptr() = 0; template const T* dptr() const { CheckDataType(); return reinterpret_cast(raw_dptr()); } template T* mut_dptr() { CheckDataType(); return reinterpret_cast(mut_raw_dptr()); } protected: template void CheckDataType() const { LOG_IF(FATAL, (std::is_same::value == false && std::is_same::value == false && data_type() != DataType::kChar && data_type() != GetDataType::value)) << "tensor data_type mismatched. value: " << DataType_Name(data_type()) << ", template T:" << DataType_Name(GetDataType::value); } }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_USER_OP_TENSOR_H_ ================================================ FILE: oneflow/core/framework/util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
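A short sketch of typed access through the user_op::Tensor interface above; `ShapeView::elem_cnt()` is assumed from shape_view.h. dptr<T>() and mut_dptr<T>() go through CheckDataType<T>(), so asking for float data on a non-float tensor aborts with a FATAL log.

// hypothetical elementwise ReLU body over two user_op::Tensor pointers
void ReluImpl(const user_op::Tensor* in, user_op::Tensor* out) {
  const float* x = in->dptr<float>();          // checked against in->data_type()
  float* y = out->mut_dptr<float>();
  const int64_t n = in->shape_view().elem_cnt();
  for (int64_t i = 0; i < n; ++i) { y[i] = x[i] > 0.f ? x[i] : 0.f; }
}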
*/ #ifndef ONEFLOW_CORE_FRAMEWORK_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_UTIL_H_ #include "oneflow/core/common/util.h" namespace std { template<> struct hash> { std::size_t operator()(const std::pair& p) const { return oneflow::Hash(p.first, p.second); } }; } // namespace std namespace oneflow { namespace user_op {} } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_UTIL_H_ ================================================ FILE: oneflow/core/framework/variable_meta_info.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/data_type.proto"; message VariableMetaInfo { required ShapeProto shape = 2; required DataType data_type = 3; } ================================================ FILE: oneflow/core/framework/variable_tensor_mgr.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/variable_tensor_mgr.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/functional/functional.h" namespace oneflow { Maybe VariableTensorMgr::Set(const std::string& variable_op_name, const std::shared_ptr& variable_tensor, const Symbol& dtype) { if (dtype && variable_tensor->dtype() != dtype) { LazyMode::Guard guard{false}; variables_[variable_op_name] = JUST(one::functional::Cast(variable_tensor, dtype, false)); } else { variables_[variable_op_name] = variable_tensor; } return Maybe::Ok(); } Maybe VariableTensorMgr::Get(const std::string& variable_op_name, const Symbol& dtype) { if (variables_.find(variable_op_name) != variables_.end()) { const auto variable_tensor = variables_[variable_op_name]; if (dtype && variable_tensor->dtype() != dtype) { LazyMode::Guard guard{false}; return JUST(one::functional::Cast(variable_tensor, dtype, false)); } return variable_tensor; } return std::shared_ptr(nullptr); } void VariableTensorMgr::Delete(const std::string& variable_op_name) { if (variables_.find(variable_op_name) != variables_.end()) { variables_.erase(variable_op_name); } } Maybe VariableTensorMgr::Fill( const std::vector& variable_op_names, const std::vector>& variable_tensors) { CHECK_EQ_OR_THROW(variable_op_names.size(), variable_tensors.size()) << "The number of variable op names is not equal with the number of variable tensors."; for (size_t i = 0; i < variable_op_names.size(); ++i) { JUST(Set(JUST(oneflow::VectorAt(variable_op_names, i)), JUST(oneflow::VectorAt(variable_tensors, i)))); } return Maybe::Ok(); } std::tuple, std::vector>> VariableTensorMgr::Dump() { std::vector variable_op_names; std::vector> variable_tensors; for (const auto& x : variables_) { variable_op_names.push_back(x.first); variable_tensors.push_back(x.second); } return std::make_tuple(variable_op_names, 
variable_tensors); } void VariableTensorMgr::Reset() { std::map>().swap(variables_); } std::vector VariableTensorMgr::DumpNames() { std::vector variable_op_names; for (const auto& x : variables_) { variable_op_names.push_back(x.first); } return variable_op_names; } } // namespace oneflow ================================================ FILE: oneflow/core/framework/variable_tensor_mgr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FRAMEWORK_VARIABLE_TENSOR_MGR_H_ #define ONEFLOW_CORE_FRAMEWORK_VARIABLE_TENSOR_MGR_H_ #include #include #include #include "oneflow/core/common/just.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/dtype.h" namespace oneflow { template class Singleton; namespace one { class Tensor; } class VariableTensorMgr final { public: OF_DISALLOW_COPY_AND_MOVE(VariableTensorMgr); ~VariableTensorMgr() = default; Maybe Set(const std::string& variable_op_name, const std::shared_ptr& variable_tensor, const Symbol& dtype = Symbol()); Maybe Get(const std::string& variable_op_name, const Symbol& dtype = Symbol()); void Delete(const std::string& variable_op_name); Maybe Fill(const std::vector& variable_op_names, const std::vector>& variable_tensors); std::tuple, std::vector>> Dump(); std::vector DumpNames(); void Reset(); private: friend class Singleton; VariableTensorMgr() = default; std::map> variables_; }; } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_VARIABLE_TENSOR_MGR_H_ ================================================ FILE: oneflow/core/functional/function_library.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
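A usage sketch for the variable tensor manager above, assuming it is reached as a process-level singleton via `Singleton<VariableTensorMgr>::Get()` (from oneflow/core/common/singleton.h) and that CHECK_JUST unwraps the Maybe return values; `SaveAndReload` is a made-up helper.

void SaveAndReload(const std::string& name, const std::shared_ptr<one::Tensor>& tensor) {
  auto* mgr = Singleton<VariableTensorMgr>::Get();
  CHECK_JUST(mgr->Set(name, tensor));                            // keeps the tensor's own dtype
  std::shared_ptr<one::Tensor> loaded = CHECK_JUST(mgr->Get(name));
  if (!loaded) { LOG(WARNING) << "variable " << name << " is not registered"; }
}

Both Set() and Get() cast to the requested dtype when one is passed, running the functional Cast with lazy mode disabled.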
*/ #ifndef ONEFLOW_CORE_FUNCTIONAL_FUNCTION_LIBRARY_H_ #define ONEFLOW_CORE_FUNCTIONAL_FUNCTION_LIBRARY_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/wrap_dim_utils.h" #include "oneflow/core/functional/packed_functor.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/framework/tensor_methods.h" #include "oneflow/core/common/throw.h" namespace oneflow { namespace one { namespace functional { class FunctionLibrary { public: virtual ~FunctionLibrary() = default; template struct PackedFuncCreatorMap; template struct PackedFuncCreatorMap { using FunctorCreator = typename std::function()>; static HashMap* Get() { static HashMap functors; return &functors; } }; template void add_functor(const std::string& func_name, const Func& func) { using func_type = typename function_traits::func_type; add_functor_creator( func_name, [=]() { return PackedFunctorMaker::make(func_name, func); }); } template void add_one_functor(const std::string& func_name) { using func_type = typename function_traits::func_type; add_functor_creator(func_name, [=]() { // Lazily construct functor since ops maybe have not been registered. Func func; return PackedFunctorMaker::make(func_name, func); }); } template void add_functor(const std::string& func_name) { static_assert(sizeof...(Fs) > 0, "at least one functor is expected"); (add_one_functor(func_name), ...); } template auto find(const std::string& func_name) -> Maybe::FType>> { auto* functors = PackedFuncCreatorMap::FType>::Get(); const auto& it = functors->find(func_name); CHECK_OR_RETURN(it != functors->end()) << Error::RuntimeError() << "Functor was not found for \"" << func_name << "\", please check whether the functor has been registered correctly or not."; return it->second(); } static FunctionLibrary* Global() { static FunctionLibrary global_function_library; return &global_function_library; } private: FunctionLibrary() = default; template void add_functor_creator(const std::string& func_name, Creator creator) { using func_type = typename function_traits::func_type; auto* functors = PackedFuncCreatorMap::FType>::Get(); CHECK_OR_THROW(functors->count(func_name) == 0) << Error::RuntimeError() << "The functor with name " << func_name << " has been registered more than once."; functors->emplace(func_name, creator); } }; #define ONEFLOW_FUNCTION_LIBRARY(m) ONEFLOW_FUNCTION_LIBRARY_IMPL(m, __COUNTER__) #define ONEFLOW_FUNCTION_LIBRARY_IMPL(m, uuid) \ static void OF_PP_CAT(_oneflow_function_library_, uuid)(FunctionLibrary & m); \ static int OF_PP_CAT(_oneflow_function_library_dummy_, uuid) = []() { \ FunctionLibrary* library = FunctionLibrary::Global(); \ OF_PP_CAT(_oneflow_function_library_, uuid)(*library); \ return 0; \ }(); \ void OF_PP_CAT(_oneflow_function_library_, uuid)(FunctionLibrary & m) } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_FUNCTION_LIBRARY_H_ ================================================ FILE: oneflow/core/functional/functional.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
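A minimal registration sketch for the function library above: `MyAddFunctor` is a hypothetical functor whose operator() forwards to the generated functional::Add overload (its exact generated signature is assumed here from the "add" entry in functional_api.yaml below).

namespace impl {
struct MyAddFunctor {
  Maybe<one::Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
                                const std::shared_ptr<one::Tensor>& other) const {
    // alpha=1, inplace=false, matching the defaults declared in the yaml signature
    return functional::Add(input, other, /*alpha=*/1, /*inplace=*/false);
  }
};
}  // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::MyAddFunctor>("MyAdd"); };

add_functor only stores a creator; the functor itself is constructed lazily on the first find(), so op registrations it depends on are already in place.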
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FUNCTIONAL_FUNCTIONAL_H_ #define ONEFLOW_CORE_FUNCTIONAL_FUNCTIONAL_H_ #include "oneflow/core/functional/functional_api.yaml.h" #endif // ONEFLOW_CORE_FUNCTIONAL_FUNCTIONAL_H_ ================================================ FILE: oneflow/core/functional/functional_api.yaml ================================================ # Copyright 2020 The OneFlow Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # The following data types are allowed, # { # "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool", # "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList", # "BoolList", "DataType", "Shape", "Generator", "TensorIndex", "Device", "Placement", # "Sbp", "SbpList", "Layout", "MemoryFormat", # } - name: "add" signature: [ "Tensor (Tensor input, Tensor other, *, Scalar alpha=1, Bool inplace=False) => Add", "Tensor (Tensor input, Scalar other, *, Scalar alpha=1, Bool inplace=False) => ScalarAdd", "Tensor (Scalar input, Tensor other, *, Scalar alpha=1) => ScalarAdd", "Tensor (TensorTuple inputs, *, Bool inplace=False) => Add", ] bind_python: true # this api just for test host memory input - name: "host_scalar_add_by_tensor" signature: "Tensor (Tensor x, Tensor scalar) => HostScalarAddByTensor" bind_python: true - name: "amin" signature: "Tensor (Tensor input, Int32List[1] dim=None, Bool keepdim=False) => Amin" bind_python: True - name: "sub" signature: [ "Tensor (Tensor input, Tensor other, *, Scalar alpha=1, Bool inplace=False) => Sub", "Tensor (Tensor input, Scalar other, *, Scalar alpha=1, Bool inplace=False) => ScalarSub", "Tensor (Scalar input, Tensor other, *, Scalar alpha=1) => ScalarSub", ] bind_python: true - name: "mul" signature: [ "Tensor (Tensor input, Tensor other) => Mul", "Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarMul", "Tensor (Scalar input, Tensor other) => ScalarMul", ] bind_python: true - name: "mul_" signature: [ "Tensor (Tensor input, Tensor other) => InplaceMul", "Tensor (Tensor input, Scalar other) => InplaceScalarMul", ] bind_python: true - name: "addcmul" signature: "Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => Addcmul" bind_python: true - name: "addcmul_" signature: "Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => InplaceAddcmul" bind_python: true - name: "addcdiv" signature: "Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => AddCDiv" bind_python: true - name: "addcdiv_" signature: "Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => InplaceAddCDiv" bind_python: true - name: "div" signature: [ "Tensor (Tensor input, Tensor other) => Div", "Tensor (Tensor input, Scalar other) => ScalarDiv", "Tensor (Scalar input, Tensor other) => ScalarDiv", "Tensor (Tensor input, Tensor other, *, String rounding_mode=None) => DivMode", "Tensor 
(Tensor input, Scalar other, *, String rounding_mode=None) => ScalarDivMode", "Tensor (Scalar input, Tensor other, *, String rounding_mode=None) => ScalarDivMode", ] bind_python: true - name: "div_" signature: [ "Tensor (Tensor input, Tensor other) => InplaceDiv", "Tensor (Tensor input, Scalar other) => InplaceScalarDiv", ] bind_python: true - name: "div_grad" signature: "Tensor (Tensor dz, Tensor z, Tensor y) => DivGrad" bind_python: False - name: "equal" signature: "Bool (Tensor input, Tensor other) => Equal" bind_python: true - name: "broadcast_equal" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastEqual", "Tensor (Tensor input, Scalar other) => ScalarLogicalEqual", "Tensor (Scalar input, Tensor other) => ScalarLogicalEqual", ] bind_python: true - name: "not_equal" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastNotEqual", "Tensor (Tensor input, Scalar other) => ScalarLogicalNotEqual", "Tensor (Scalar input, Tensor other) => ScalarLogicalNotEqual", ] bind_python: true - name: "greater" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastGreater", "Tensor (Tensor input, Scalar other) => ScalarLogicalGreater", "Tensor (Scalar input, Tensor other) => ScalarLogicalGreater", ] bind_python: true - name: "greater_" signature: [ "Tensor (Tensor input, Tensor other) => InplaceBroadcastGreater", "Tensor (Tensor input, Scalar other) => InplaceScalarLogicalGreater", ] bind_python: true - name: "greater_equal" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastGreaterEqual", "Tensor (Tensor input, Scalar other) => ScalarLogicalGreaterEqual", "Tensor (Scalar input, Tensor other) => ScalarLogicalGreaterEqual", ] bind_python: true - name: "logical_and" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastLogicalAnd", "Tensor (Tensor input, Scalar other) => ScalarLogicalAnd", "Tensor (Scalar input, Tensor other) => ScalarLogicalAnd", ] bind_python: true - name: "logical_or" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastLogicalOr", "Tensor (Tensor input, Scalar other) => ScalarLogicalOr", "Tensor (Scalar input, Tensor other) => ScalarLogicalOr", ] bind_python: true - name: "logical_not" signature: "Tensor (Tensor input) => LogicalNot" bind_python: true - name: "logical_xor" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastLogicalXor", "Tensor (Tensor input, Scalar other) => ScalarLogicalXor", "Tensor (Scalar input, Tensor other) => ScalarLogicalXor", ] bind_python: true - name: "bitwise_and" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastBitwiseAnd", "Tensor (Tensor input, Scalar other) => ScalarBitwiseAnd", "Tensor (Scalar input, Tensor other) => ScalarBitwiseAnd", ] bind_python: true - name: "bitwise_or" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastBitwiseOr", "Tensor (Tensor input, Scalar other) => ScalarBitwiseOr", "Tensor (Scalar input, Tensor other) => ScalarBitwiseOr", ] bind_python: true - name: "bitwise_xor" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastBitwiseXor", "Tensor (Tensor input, Scalar other) => ScalarBitwiseXor", "Tensor (Scalar input, Tensor other) => ScalarBitwiseXor", ] bind_python: true - name: "bitwise_not" signature: "Tensor (Tensor input) => BitwiseNot" bind_python: true - name: "less" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastLess", "Tensor (Tensor input, Scalar other) => ScalarLogicalLess", "Tensor (Scalar input, Tensor other) => ScalarLogicalLess", ] bind_python: True - name: "less_equal" signature: [ "Tensor (Tensor 
input, Tensor other) => BroadcastLessEqual", "Tensor (Tensor input, Scalar other) => ScalarLogicalLessEqual", "Tensor (Scalar input, Tensor other) => ScalarLogicalLessEqual", ] bind_python: True - name: "pow" signature: [ "Tensor (Tensor input, Tensor exponent) => Pow", "Tensor (Tensor input, Scalar exponent, *, Bool inplace=False) => ScalarPow", "Tensor (Tensor input, Scalar exponent) => ScalarPow", "Tensor (Scalar exponent, Tensor input) => ScalarReversePow", ] bind_python: True - name: "pow_x_grad" signature: "Tensor (Tensor x, Tensor y, Tensor dz) => PowXGrad" bind_python: False - name: "pow_y_grad" signature: "Tensor (Tensor x, Tensor y, Tensor dz) => PowYGrad" bind_python: False - name: "searchsorted" signature: [ "Tensor (Tensor sorted_sequence, Tensor values, Bool out_int32=False, Bool right=False) => SearchSorted", "Tensor (Tensor sorted_sequence, Scalar values, Bool out_int32=False, Bool right=False) => SearchSortedScalar", ] bind_python: True - name: "scalar_pow_grad" signature: "Tensor (Tensor input, Tensor dy, Scalar exponent) => ScalarPowGrad" bind_python: False - name: "scalar_reverse_pow_grad" signature: "Tensor (Tensor input, Tensor dy, Scalar exponent) => ScalarReversePowGrad" bind_python: False - name: "broadcast_pow" signature: "Tensor (Tensor x, Tensor y) => BroadcastPow" bind_python: False - name: "broadcast_pow_x_grad" signature: "Tensor (Tensor x, Tensor y, Tensor dz) => BroadcastPowXGrad" bind_python: False - name: "broadcast_pow_y_grad" signature: "Tensor (Tensor x, Tensor y, Tensor dz) => BroadcastPowYGrad" bind_python: False - name: "floor_divide" signature: [ "Tensor (Tensor input, Tensor other) => FloorDiv", "Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarFloorDiv", "Tensor (Tensor input, Scalar other) => ScalarFloorDiv", ] bind_python: True - name: "floordiv_x_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => FloorDivXGrad" bind_python: False - name: "floordiv_y_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => FloorDivYGrad" bind_python: False - name: "lerp" signature: [ "Tensor (Tensor start, Tensor end, Tensor weight) => Lerp", "Tensor (Tensor start, Tensor end, Scalar weight) => ScalarLerp" ] bind_python: True - name: "lerp_" signature: [ "Tensor (Tensor start, Tensor end, Tensor weight) => InplaceLerp", "Tensor (Tensor start, Tensor end, Scalar weight) => ScalarInplaceLerp", ] bind_python: True - name: "lerp_grad" signature: "TensorTuple (Tensor start, Tensor end, Tensor weight, Tensor out_diff) => LerpGrad" bind_python: False - name: "scalar_lerp_grad" signature: "TensorTuple (Tensor start, Tensor end, Tensor out_diff, Scalar weight) => ScalarLerpGrad" bind_python: False - name: "trunc_divide" signature: [ "Tensor (Tensor input, Tensor other) => TruncDiv", "Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarTruncDiv", ] bind_python: True - name: "truncdiv_x_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => TruncDivXGrad" bind_python: False - name: "truncdiv_y_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => TruncDivYGrad" bind_python: False - name: "xdivy_x_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => XdivyXGrad" bind_python: False - name: "xdivy_y_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => XdivyYGrad" bind_python: False - name: "xlogy_x_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => XlogyXGrad" bind_python: False - name: "xlogy_y_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => XlogyYGrad" bind_python: False - name: 
"max" signature: [ "Tensor (Tensor input) => Max", "Tensor (Tensor input, Tensor other) => Max", "TensorTuple[values, indices] (Tensor input, Int32 dim, Bool keepdim=False) => Max", ] bind_python: True - name: "min" signature: [ "Tensor (Tensor input) => Min", "TensorTuple[values, indices] (Tensor input, Int32 dim, Bool keepdim=False) => Min", "Tensor (Tensor input, Tensor other) => Min", ] bind_python: True - name: "median" signature: [ "Tensor (Tensor input) => Median", "TensorTuple[values, indices] (Tensor input, Int32 dim=-1, Bool keepdim=False) => MedianWithIndices", ] bind_python: True - name: "reduce_max" signature: "Tensor (Tensor x, Int32List axis, Bool keepdim=False) => ReduceMax" bind_python: True - name: "reduce_min" signature: "Tensor (Tensor x, Int32List axis, Bool keepdim=False) => ReduceMin" bind_python: True - name: "reduce_sum" signature: [ "Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False, *, DataType dtype=None) => ReduceSum", "Tensor (Tensor x, *, DataType dtype=None) => ReduceSumWhole", ] bind_python: True - name: "reduce_nansum" signature: [ "Tensor (Tensor input, Int32List[1] dim, Bool keepdim=False, *, DataType dtype=None) => ReduceNanSum", "Tensor (Tensor input, *, DataType dtype=None) => ReduceNanSumWhole" ] bind_python: True - name: "reduce_mean" signature: [ "Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceMean", "Tensor (Tensor x) => ReduceMeanWhole", ] bind_python: True - name: "reduce_all" signature: [ "Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceAll", "Tensor (Tensor x) => ReduceAllWhole", ] bind_python: True - name: "reduce_any" signature: [ "Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceAny", "Tensor (Tensor x) => ReduceAnyWhole", ] bind_python: True - name: "reduce_prod" signature: [ "Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False, *, DataType dtype=None) => ReduceProd", "Tensor (Tensor x, *, DataType dtype=None) => ReduceProdWhole", ] bind_python: True - name: "reduce_min_device_stage" signature: "TensorTuple (Tensor in, Int32List axis) => ReduceMinDeviceStage" bind_python: True - name: "reduce_min_device_stage_grad" signature: "Tensor (Tensor out_diff, Tensor mask, Tensor count, Int32List axis) => ReduceMinDeviceStageGrad" bind_python: False - name: "reduce_max_device_stage" signature: "TensorTuple (Tensor in, Int32List axis) => ReduceMaxDeviceStage" bind_python: True - name: "reduce_max_device_stage_grad" signature: "Tensor (Tensor out_diff, Tensor mask, Tensor count, Int32List axis) => ReduceMaxDeviceStageGrad" bind_python: False - name: "reduce_min_global_stage" signature: "TensorTuple (Tensor in, Tensor device_count, Int32List axis, Bool keepdims=False) => ReduceMinGlobalStage" bind_python: True - name: "reduce_min_global_stage_grad" signature: "Tensor (Tensor out_diff, Tensor mask, Tensor device_count, Int32List axis, Bool keepdims=False) => ReduceMinGlobalStageGrad" bind_python: False - name: "reduce_max_global_stage" signature: "TensorTuple (Tensor in, Tensor device_count, Int32List axis, Bool keepdims=False) => ReduceMaxGlobalStage" bind_python: True - name: "reduce_max_global_stage_grad" signature: "Tensor (Tensor out_diff, Tensor mask, Tensor device_count, Int32List axis, Bool keepdims=False) => ReduceMaxGlobalStageGrad" bind_python: False - name: "logsumexp" signature: "Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => LogSumExp" bind_python: True - name: "logaddexp" signature: "Tensor (Tensor x, Tensor y) => LogAddExp" bind_python: True - name: "quantile" 
signature: [ 'Tensor (Tensor input, Tensor q, Int64 dim=None, Bool keepdim=False, String interpolation="linear", Bool ignore_nan=False) => Quantile', 'Tensor (Tensor input, Scalar q, Int64 dim=None, Bool keepdim=False, String interpolation="linear", Bool ignore_nan=False) ==> ScalarQuantile' ] bind_python: True - name: "transpose" signature: [ "Tensor (Tensor input, Int32List perm) => Transpose", "Tensor (Tensor input, Int32 dim0, Int32 dim1) => Transpose2dim", ] bind_python: True - name: "as_strided" signature: "Tensor (Tensor input, Int64List size, Int64List stride, Int64 storage_offset=0) => AsStrided" bind_python: True - name: "as_strided_grad" signature: "Tensor (Tensor dy, Tensor input, Int64List size, Int64List stride, Int64 storage_offset=0) => AsStridedGrad" bind_python: False - name: "as_strided_" signature: "Tensor (Tensor input, Int64List size, Int64List stride, Int64 storage_offset=0) => InplaceAsStrided" bind_python: True - name: "select" signature: "Tensor (Tensor input, Int32 dim, Int32 index) => Select" bind_python: True - name: "swapaxes" signature: "Tensor (Tensor input, Int32 dim0, Int32 dim1) => Swapaxes" bind_python: True - name: "swapdims" signature: "Tensor (Tensor input, Int32 dim0, Int32 dim1) => Swapdims" bind_python: True - name: "amax" signature: "Tensor (Tensor input, Int32List[1] dim=None, Bool keepdim=False) => Amax" bind_python: True - name: "permute" signature: "Tensor (Tensor input, Int32List dims) => Permute" bind_python: True - name: "T" signature: "Tensor (Tensor input) => TransposeAllDimProperty" bind_python: True - name: "t" signature: "Tensor (Tensor input) => TransposeAllDimFunction" bind_python: True - name: "not_equal_zero" signature: "Tensor (Tensor x) => NotEqualZero" bind_python: False - name: "not_equal_zero_grad" signature: "Tensor (Tensor x, Tensor dy) => NotEqualZeroGrad" bind_python: False - name: "reciprocal" signature: "Tensor (Tensor x) => Reciprocal" bind_python: True - name: "reciprocal_grad" signature: "Tensor (Tensor x, Tensor dy) => ReciprocalGrad" bind_python: False - name: "reciprocal_no_nan" signature: "Tensor (Tensor x) => ReciprocalNoNan" bind_python: True - name: "reciprocal_no_nan_grad" signature: "Tensor (Tensor x, Tensor dy) => ReciprocalNoNanGrad" bind_python: False - name: "image_flip" signature: "Tensor (Tensor x, Tensor flip_code) => ImageFlip" bind_python: True - name: "sin" signature: "Tensor (Tensor x) => Sin" bind_python: True - name: "sin_grad" signature: "Tensor (Tensor x, Tensor dy) => SinGrad" bind_python: False - name: "sin_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => SinGradGrad" bind_python: False - name: "sin_" signature: "Tensor (Tensor x) => Sin_" bind_python: True - name: "cos" signature: "Tensor (Tensor x) => Cos" bind_python: True - name: "cos_grad" signature: "Tensor (Tensor x, Tensor dy) => CosGrad" bind_python: False - name: "cos_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => CosGradGrad" bind_python: False - name: "cosh" signature: "Tensor (Tensor x) => Cosh" bind_python: True - name: "cosh_grad" signature: "Tensor (Tensor x, Tensor dy) => CoshGrad" bind_python: True - name: "fmod" signature: [ "Tensor (Tensor input, Tensor other) => BroadcastFMod", "Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarFMod", "Tensor (Tensor input, Scalar other) => ScalarFMod", ] bind_python: true - name: "log" signature: "Tensor (Tensor x) => Log" bind_python: True - name: "log_grad" signature: "Tensor (Tensor x, Tensor dy) => LogGrad" bind_python: False - name: "log2" 
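# A sketched (hypothetical) entry illustrating the format used throughout this file:
# "name" is the functional API name, each "Signature => Functor" pair maps that call
# onto a functor registered in FunctionLibrary (see function_library.h above), and
# arguments after "*" follow Python's keyword-only convention.
#   - name: "my_add"
#     signature: "Tensor (Tensor input, Tensor other, *, Scalar alpha=1) => MyAdd"
#     bind_python: True
# With bind_python True the generated wrapper is also exposed to Python (as a binding
# such as oneflow._C.my_add); with bind_python False it is only callable from C++.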
signature: "Tensor (Tensor x) => Log2" bind_python: True - name: "log2_grad" signature: "Tensor (Tensor x, Tensor dy) => Log2Grad" bind_python: False - name: "log10" signature: "Tensor (Tensor x) => Log10" bind_python: True - name: "log10_grad" signature: "Tensor (Tensor x, Tensor dy) => Log10Grad" bind_python: False - name: "sqrt" signature: "Tensor (Tensor x) => Sqrt" bind_python: True - name: "sqrt_grad" signature: "Tensor (Tensor x, Tensor dy) => SqrtGrad" bind_python: False - name: "rsqrt" signature: "Tensor (Tensor x) => Rsqrt" bind_python: True - name: "rsqrt_grad" signature: "Tensor (Tensor x, Tensor dy) => RsqrtGrad" bind_python: False - name: "square" signature: "Tensor (Tensor x) => Square" bind_python: True - name: "square_grad" signature: "Tensor (Tensor x, Tensor dy) => SquareGrad" bind_python: False - name: "sqrt_square_sum" signature: "Tensor (Tensor x) => SqrtSquareSum" bind_python: True - name: "std" signature: "Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=None, Bool keepdim=None) => StandardDeviation" bind_python: True - name: "var" signature: "Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=None, Bool keepdim=None) => Variance" bind_python: True - name: "rms_layer_norm" signature: "Tensor (Tensor hidden_states, Tensor weight, Float variance_epsilon) => RMSLayerNormalization" bind_python: True - name: "relu" signature: "Tensor (Tensor x, Bool inplace=False) => Relu" bind_python: True - name: "relu_grad" signature: "Tensor (Tensor dy, Tensor y) => ReluGrad" bind_python: False - name: "hann_window" signature: [ "Tensor (Int64 window_length, Bool periodic=True, *, Device device=None, DataType dtype=None, Bool requires_grad=False) => HannWindow", "Tensor (Int64 window_length, Bool periodic=True, *, Placement placement, SbpList sbp, DataType dtype=None, Bool requires_grad=False) => GlobalHannWindow", ] bind_python: True - name: "hardtanh" signature: "Tensor (Tensor x, Double min_val, Double max_val) => HardTanh" bind_python: True - name: "hardtanh_grad" signature: "Tensor (Tensor y, Tensor dy, Double min_val, Double max_val) => HardTanhGrad" bind_python: False - name: "tan" signature: "Tensor (Tensor x) => Tan" bind_python: True - name: "tan_grad" signature: "Tensor (Tensor x, Tensor dy) => TanGrad" bind_python: True - name: "tanh" signature: "Tensor (Tensor x) => Tanh" bind_python: True - name: "tanh_grad" signature: "Tensor (Tensor y, Tensor dy) => TanhGrad" bind_python: True - name: "threshold" signature: "Tensor (Tensor x, *, Double threshold, Double value) => Threshold" bind_python: True - name: "threshold_grad" signature: "Tensor (Tensor x, Tensor dy, Double threshold) => ThresholdGrad" bind_python: False - name: "elu" signature: "Tensor (Tensor x, Double alpha) => Elu" bind_python: True - name: "elu_grad" signature: "Tensor (Tensor x, Tensor dy, Double alpha) => EluGrad" bind_python: False - name: "celu" signature: "Tensor (Tensor x, *, Double alpha=1.0, Bool inplace=False) => Celu" bind_python: True - name: "celu_grad" signature: "Tensor (Tensor y, Tensor dy, Double alpha=1.0) => CeluGrad" bind_python: False - name: "gelu" signature: "Tensor (Tensor x) => Gelu" bind_python: True - name: "gelu_grad" signature: "Tensor (Tensor dy, Tensor x) => GeluGrad" bind_python: False - name: "fast_gelu" signature: "Tensor (Tensor x) => FastGelu" bind_python: True - name: "fast_gelu_grad" signature: "Tensor (Tensor dy, Tensor x) => FastGeluGrad" bind_python: False - name: "quick_gelu" signature: "Tensor (Tensor x) => QuickGelu" bind_python: True - name: 
"quick_gelu_grad" signature: "Tensor (Tensor dy, Tensor x) => QuickGeluGrad" bind_python: False - name: "square_relu" signature: "Tensor (Tensor x) => SquareReLU" bind_python: True - name: "square_relu_grad" signature: "Tensor (Tensor dy, Tensor x) => SquareReLUGrad" bind_python: False - name: "gelu_with_approximate" signature: 'Tensor (Tensor x, String approximate="none") => GeluWithApproximate' bind_python: True - name: "glu" signature: "Tensor (Tensor input, Int64 dim=-1) => Glu" bind_python: True - name: "fused_glu" signature: "Tensor (Tensor x, Tensor w, Tensor b=None, Tensor v=None, Tensor c=None, String activation=\"none\") => FusedGlu" bind_python: True - name: "fused_glu_without_linear_grad" signature: "TensorTuple (Tensor dy, Tensor matmul_wx, Tensor matmul_vx=None, String activation=\"none\") => FusedGluWithoutLinearGrad" bind_python: False - name: "sigmoid" signature: "Tensor (Tensor x) => Sigmoid" bind_python: True - name: "sigmoid_grad" signature: "Tensor (Tensor y, Tensor dy) => SigmoidGrad" bind_python: True - name: "hardsigmoid" signature: "Tensor (Tensor input, Bool inplace=False, *) => HardSigmoid" bind_python: True - name: "hardsigmoid_grad" signature: "Tensor (Tensor dy, Tensor x) => HardSigmoidGrad" bind_python: False - name: "hardshrink" signature: "Tensor (Tensor x, *, Double lambd=0.5, Bool inplace=False) => HardShrink" bind_python: True - name: "hardshrink_grad" signature: "Tensor (Tensor y, Tensor dy, Double lambd=0.5) => HardShrinkGrad" bind_python: False - name: "softmax" signature: "Tensor (Tensor x, Int64 dim=None) => Softmax" bind_python: True - name: "softmax_grad" signature: "Tensor (Tensor dy, Tensor y) => SoftmaxGrad" bind_python: False - name: "gumbel_softmax" signature: "Tensor (Tensor x, Double tau=1., Int64 dim=None, Bool hard=False, Generator generator=None) => GumbelSoftmax" bind_python: True - name: "log_softmax" signature: "Tensor (Tensor x, Int64 dim=None) => LogSoftmax" bind_python: True - name: "log_softmax_grad" signature: "Tensor (Tensor dy, Tensor y) => LogSoftmaxGrad" bind_python: False - name: "hardswish" signature: "Tensor (Tensor x) => HardSwish" bind_python: True - name: "hardswish_grad" signature: "Tensor (Tensor dy, Tensor x) => HardSwishGrad" bind_python: False - name: "leaky_relu" signature: "Tensor (Tensor x, Float alpha, Bool inplace=False) => LeakyRelu" bind_python: True - name: "leaky_relu_grad" signature: "Tensor (Tensor x, Tensor dy, Float alpha) => LeakyReluGrad" bind_python: False - name: "rrelu" signature: "Tensor (Tensor x, Float lower=0.125, Float upper=0.3333333333333333, Bool training=False, Bool inplace=False) => RRelu" bind_python: True - name: "rrelu_" signature: "Tensor (Tensor x, Float lower=0.125, Float upper=0.3333333333333333, Bool training=False) => RReluInplace" bind_python: True - name: "normal_" signature: "Tensor (Tensor x, Float mean=0.0, Float std=1.0, Generator generator=None) => Normal_" bind_python: True - name: "normal" signature: [ "Tensor (Tensor mean, Tensor std, *, Tensor out=None, Generator generator=None, Bool requires_grad=False) => TensorTensorNormal", "Tensor (Tensor mean, Float std=1.0, *, Tensor out=None, Generator generator=None, Bool requires_grad=False) => TensorScalarNormal", "Tensor (Float mean, Tensor std, *, Tensor out=None, Generator generator=None, Bool requires_grad=False) => ScalarTensorNormal", "Tensor (Float mean, Float std, Shape size, *, Tensor out=None, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False) => Normal", "Tensor (Float 
mean, Float std, Int32 size, *, Tensor out=None, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False) => Normal2", "Tensor (Float mean, Float std, Shape size, *, Tensor out=None, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalNormal", "Tensor (Float mean, Float std, Int32 size, *, Tensor out=None, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalNormal2", ] bind_python: True - name: "normalization" signature: "Tensor (Tensor x, Tensor moving_mean=None, Tensor moving_variance=None, Tensor gamma=None, Tensor beta=None, Int32 axis=1, Float epsilon=1e-5, Float momentum=0.9, Bool is_training=False) => Normalization" bind_python: True - name: "normalization_grad" signature: "TensorTuple (Tensor grad, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Float epsilon, Int32 axis) => NormalizationGrad" bind_python: False - name: "normalization_add_relu" signature: "Tensor (Tensor x, Tensor addend=None, Tensor moving_mean=None, Tensor moving_variance=None, Tensor gamma, Tensor beta, Int32 axis=1, Float epsilon=1e-5, Float momentum=0.9, Bool is_training=False) => NormalizationAddRelu" bind_python: True - name: "normalization_add_relu_grad" signature: "TensorTuple (Tensor x, Tensor dy, Tensor moving_mean, Tensor moving_variance, Tensor gamma, Tensor beta, Tensor reserve_space, Tensor y, Int32 axis=1, Float epsilon=1e-5, Bool has_addend) => NormalizationAddReluGrad" bind_python: False - name: "eye" signature: [ "Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Device device=None, Bool requires_grad=False) => Eye", "Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, String device, Bool requires_grad=False) => Eye", "Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Bool requires_grad=False, Placement placement, SbpList sbp) => Eye", "Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Bool requires_grad=False, Placement placement, Sbp sbp) => Eye", ] bind_python: True - name: "eye_" signature: "Tensor (Tensor x) => EyeInplace" bind_python: True - name: "erfinv" signature: "Tensor (Tensor x) => Erfinv" bind_python: True - name: "erfinv_" signature: "Tensor (Tensor x) => ErfinvInplace" bind_python: True - name: "arange" signature: [ "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=None, Device device=None) => Arange", "Tensor (Scalar end, *, DataType dtype=None, Device device=None) => Arange", ] bind_python: True - name: "global_arange" signature: [ "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=None, Placement placement, SbpList sbp) => GlobalArange", "Tensor (Scalar end, *, DataType dtype=None, Placement placement, SbpList sbp) => GlobalArange", ] bind_python: True - name: "flatten" signature: "Tensor (Tensor x, Int32 start_dim=0, Int32 end_dim=-1) => Flatten" bind_python: True - name: "argmax" signature: "Tensor (Tensor x, Int32 dim=None, Bool keepdim=None, DataType dtype=None) => ArgMax" bind_python: True - name: "argmin" signature: "Tensor (Tensor x, Int32 dim=None, Bool keepdim=None, DataType dtype=None) => ArgMin" bind_python: True - name: "argwhere" signature: "TensorTuple (Tensor x, DataType dtype=kInt32) => ArgWhere" bind_python: True - name: "nonzero" signature: "TensorTuple (Tensor x, Bool as_tuple=False) => NonZero" bind_python: True - name: "broadcast_like" signature: "Tensor (Tensor x, Tensor like, Int32List broadcast_axes=[]) => 
BroadcastLike" bind_python: True - name: "cast" signature: "Tensor (Tensor x, DataType dtype, Bool pin_memory=False) => Cast" bind_python: True - name: "global_tensor_constant" signature: "Tensor (Shape shape, Tensor value, *, DataType dtype, Placement placement, SbpList sbp) => GlobalTensorConstant" bind_python: True - name: "tensor_constant" signature: "Tensor (Shape shape, Tensor value, *, DataType dtype, Device device=None) => TensorConstant" bind_python: True - name: "constant" signature: [ "Tensor (Shape shape, Scalar value, *, DataType dtype, Device device=None) => Constant", ] bind_python: True - name: "global_constant" signature: [ "Tensor (Shape shape, Scalar value, *, DataType dtype, Placement placement, SbpList sbp) => GlobalConstant", ] bind_python: True - name: "empty" signature: "Tensor (Shape shape, *, DataType dtype, Device device=None, Bool requires_grad=False, Bool pin_memory=False) => Empty" bind_python: True - name: "empty_strided" signature: "Tensor (Int64List shape, Int64List stride, DataType dtype=None, Device device=None, Bool requires_grad=False, Bool pin_memory=False) => EmptyStrided" bind_python: True - name: "global_empty" signature: [ "Tensor (Shape shape, *, DataType dtype, Placement placement, SbpList sbp) => GlobalEmpty", ] bind_python: True - name: "zeros_like" signature: "Tensor (Tensor x) => ZerosLike" bind_python: False - name: "ones_like" signature: "Tensor (Tensor x) => OnesLike" bind_python: False - name: "full_like" signature: "Tensor (Tensor x, Scalar fill_value) => FullLike" bind_python: False - name: "bernoulli" signature: [ "Tensor (Tensor input, *, DataType dtype=kFloat, Generator generator=None, Bool inplace=False) => Bernoulli", "Tensor (Tensor input, Double p, *, DataType dtype=kFloat, Generator generator=None, Bool inplace=False) => BernoulliProb", ] bind_python: True - name: "bernoulli_" signature: [ "Tensor (Tensor input, *, DataType dtype=kFloat, Generator generator=None) => BernoulliInplace", "Tensor (Tensor input, Double p, *, DataType dtype=kFloat, Generator generator=None) => BernoulliProbInplace", ] bind_python: True - name: "concat" signature: "Tensor (TensorTuple inputs, Int64 dim=0) => Concat" bind_python: True - name: "bias_add" signature: "Tensor (Tensor x, Tensor bias, Int32 axis=1) => BiasAdd" bind_python: True - name: "conv1d" signature: 'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[1] stride=1, Int32List[1] padding=0, Int32List[1] dilation=1, Int32 groups=1, String channel_pos="channels_first") => Conv1d' bind_python: True - name: "conv2d" signature: 'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[2] stride=1, Int32List[2] padding=0, Int32List[2] dilation=1, Int32 groups=1, String channel_pos="channels_first") => Conv2d' bind_python: True - name: "conv3d" signature: 'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[3] stride=1, Int32List[3] padding=0, Int32List[3] dilation=1, Int32 groups=1, String channel_pos="channels_first") => Conv3d' bind_python: True - name: "fake_quantization" signature: "Tensor (Tensor in, Tensor scale, Tensor zero_point, String quantization_formula, Int32 quantization_bit, String quantization_scheme) => FakeQuantization" bind_python: True - name: "quantization" signature: "Tensor (Tensor in, Tensor scale, Tensor zero_point, String quantization_formula, Int32 quantization_bit, String quantization_scheme) => Quantization" bind_python: True - name: "min_max_observer" signature: "TensorTuple (Tensor in, String quantization_formula, Int32 
quantization_bit, String quantization_scheme, Bool per_layer_quantization) => MinMaxObserver" bind_python: True - name: "moving_average_min_max_observer" signature: "TensorTuple (Tensor in, Tensor current_train_step, Tensor moving_max, Tensor moving_min, Bool training, Int64 stop_update_after_iters, String quantization_formula, Int32 quantization_bit, String quantization_scheme, Float momentum) => MovingAverageMinMaxObserver" bind_python: True - name: "groupwise_dequantize" signature: 'Tensor (Tensor in, Tensor scale, *, Tensor zero=None, Int32 num_bits=8, Bool symmetric=True, Int64 group_dim=-1, Int64 group_size=-1) => GroupwiseDequantize' bind_python: True - name: "fused_linear_with_groupwise_quantized_weight" signature: 'Tensor (Tensor x, Tensor w, Tensor w_scale, *, Tensor w_zero=None, Tensor b=None, Int32 num_bits=8, Bool symmetric=True, Int64 group_dim=-1, Int64 group_size=-1) => FusedLinearWithGroupwiseQuantizedWeight' bind_python: True - name: "conv_data_grad" signature: 'Tensor (Tensor dy, Tensor weight, Tensor x, Int32 num_spatial_dims, Int32List kernel_size, Int32List strides, Int32List padding_before, Int32List dilation_rate, Int32 groups=1, String data_format="channels_first") => ConvDataGrad' bind_python: False - name: "conv_filter_grad" signature: 'Tensor (Tensor dy, Tensor x, Int32 num_spatial_dims, Int32List kernel_size, Int32List strides, Int32List padding_before, Int32List dilation_rate, Int32 groups=1, String data_format="channels_first") => ConvFilterGrad' bind_python: False - name: "conv_bias_grad" signature: 'Tensor (Tensor dy, Int32 num_spatial_dims, String data_format="channels_first") => ConvBiasGrad' bind_python: False - name: "deconv1d" signature: 'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[1] stride=1, Int32List[1] padding=0, Int32List[1] output_padding=0, Int32 groups=1, Int32List[1] dilation=1, String data_format="channels_first") => Deconv1d' bind_python: True - name: "deconv2d" signature: 'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[2] stride=1, Int32List[2] padding=0, Int32List[2] output_padding=0, Int32 groups=1, Int32List[2] dilation=1, String data_format="channels_first") => Deconv2d' bind_python: True - name: "deconv3d" signature: 'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[3] stride=1, Int32List[3] padding=0, Int32List[3] output_padding=0, Int32 groups=1, Int32List[3] dilation=1, String data_format="channels_first") => Deconv3d' bind_python: True - name: "expand" signature: "Tensor (Tensor x, Shape shape) => Expand" bind_python: True - name: "repeat" signature: "Tensor (Tensor input, Shape repeat_shape) => Repeat" bind_python: True - name: "repeat_interleave_index" signature: "Tensor (Tensor input, Tensor cumsum, Int32 dim) => RepeatInterLeaveIndex" bind_python: False - name: "repeat_interleave" signature: [ "Tensor (Tensor input, Int32 repeats, Int32 dim=None) => RepeatInterLeaveInt", "Tensor (Tensor input, Tensor repeats, Int32 dim, Int32 output_size=None) => RepeatInterLeaveTensor", ] bind_python: True - name: "tile" signature: "Tensor (Tensor input, Shape dims) => Tile" bind_python: True - name: "roll" signature: "Tensor (Tensor x, Int32List[1] shifts, Int32List[1] dims=None) => Roll" bind_python: True - name: "expand_dims" signature: "Tensor (Tensor input, Int32 dim) => ExpandDims" bind_python: True - name: "unsqueeze" signature: "Tensor (Tensor input, Int32 dim) => Unsqueeze" bind_python: True - name: "unsqueeze_multiple" signature: "Tensor (Tensor input, Int32List dim, Int32 dims) 
=> UnsqueezeMultiple" bind_python: False - name: "unsqueeze_" signature: "Tensor (Tensor input, Int32 dim) => InplaceUnsqueeze" bind_python: True - name: "squeeze" signature: "Tensor (Tensor x, Int32List[1] dim=None) => Squeeze" bind_python: True - name: "squeeze_" signature: "Tensor (Tensor x, Int32List[1] dim=None) => InplaceSqueeze" bind_python: True - name: "exp" signature: "Tensor (Tensor x) => Exp" bind_python: True - name: "exp2" signature: "Tensor (Tensor x) => Exp2" bind_python: True - name: "exp_grad" signature: "Tensor (Tensor x, Tensor dy) => ExpGrad" bind_python: False - name: "exp2_grad" signature: "Tensor (Tensor x, Tensor dy) => Exp2Grad" bind_python: False - name: "gather" signature: "Tensor (Tensor x, Tensor indices, Int64 axis) => Gather" bind_python: True - name: "dim_gather" signature: " Tensor (Tensor input, Int64 dim, Tensor index, Bool sparse_grad=False) => DimGather" bind_python: True - name: "embedding_renorm_" signature: " Tensor (Tensor in, Tensor indices, Double max_norm, Double norm_type) => EmbeddingReNorm" bind_python: True - name: "embedding" signature: " Tensor (Tensor weight, Tensor indices, Int64 padding_idx=None, Bool scale_grad_by_freq=False) => Embedding" bind_python: True - name: "embedding_grad" signature: " Tensor (Tensor dy, Tensor weight, Tensor indices, Int64 padding_idx, Bool scale_grad_by_freq=False) => EmbeddingGrad" bind_python: False - name: "arg_sort" signature: "Tensor (Tensor in, String direction) => ArgSort" bind_python: True - name: "gather_nd" signature: "Tensor (Tensor params, Tensor indices) => GatherNd" bind_python: True - name: "scatternd" signature: "Tensor (Tensor indices, Tensor updates, Shape shape) => ScatterNd" bind_python: True - name: "tensor_scatter_nd_update" signature: "Tensor (Tensor tensor, Tensor indices, Tensor updates, Bool inplace=False) => TensorScatterNdUpdate" bind_python: True - name: "scatterndlike" signature: "Tensor (Tensor like, Tensor updates, Tensor indices) => ScatterNdLike" bind_python: True - name: "matmul" signature: "Tensor (Tensor input, Tensor other, Bool transpose_a=False, Bool transpose_b=False, Double alpha=1.0) => MatMul" bind_python: True - name: "mm" signature: "Tensor (Tensor input, Tensor mat2) => MatMulNoBroadCast" bind_python: True - name: "fused_mlp" signature: "Tensor (Tensor x, TensorTuple weights, TensorTuple biases, Bool skip_final_activation) => FusedMLP" bind_python: True - name: "fused_matmul_bias" signature: "Tensor (Tensor x, Tensor weight, Tensor bias, Tensor _add_to_output=None, Double alpha=1.0, Double beta=1.0) => FusedMatmulBias" bind_python: True - name: "fused_mlp_grad" signature: "TensorTuple (Tensor dy, Tensor x, TensorTuple weights, TensorTuple cublas_aux, TensorTuple hidden, FloatList alpha_list) => FusedMLPGrad" bind_python: False - name: "cublas_bias_add_relu_matmul_grad" signature: "TensorTuple (Tensor dy, Tensor weight, Tensor aux, Double alpha=1.0) => CublasBiasAddReluMatmulGrad" bind_python: False - name: "cublas_matmul_bias_add_grad" signature: "TensorTuple (Tensor dy, Tensor x) => CublasMatmulBiasAddGrad" bind_python: False - name: "fused_matmul_bias_add_relu_dropout" signature: "Tensor (Tensor x, TensorTuple weights, TensorTuple biases, Bool skip_final_activation, FloatList dropout_rate_list, Generator generator=None) => FusedMatmulBiasAddReluDropout" bind_python: True - name: "fused_apply_rotary_emb" signature: 'Tensor (Tensor x, *, Tensor cos=None, Tensor sin=None, Tensor position_ids=None, String x_layout="BHMK", String output_layout=None, String 
mode="plane", Int64 tensor_index=None, Int64 k_size=None, Float base=1e4, Int64 rotary_size=None) => FusedApplyRotaryEmb' bind_python: True - name: "fused_relu_dropout_grad" signature: "Tensor (Tensor dy, Tensor mask, Float scale) => FusedReluDropoutGrad" bind_python: False - name: "broadcast_matmul_grad_b" signature: "Tensor (Tensor a, Tensor b, Double alpha=1.0) => BroadcastMatmulGradB" bind_python: False - name: "batch_matmul" signature: "Tensor (Tensor a, Tensor b, Bool transpose_a=False, Bool transpose_b=False, Double alpha=1.0) => BatchMatMul" bind_python: True - name: "baddbmm" signature: "Tensor (Tensor input, Tensor batch1, Tensor batch2, *, Double beta=1.0, Double alpha=1.0) => BaddBmm" bind_python: True - name: "matrix_vector_product" signature: "Tensor (Tensor input, Tensor vec) => MatrixVectorProduct" bind_python: True - name: "matrix_vector_product_grad_a" signature: "Tensor (Tensor dy, Tensor b) => MatrixVectorProductGradA" bind_python: False - name: "matrix_vector_product_grad_b" signature: "Tensor (Tensor dy, Tensor a) => MatrixVectorProductGradB" bind_python: False - name: "vector_matrix_product" signature: "Tensor (Tensor vec, Tensor input) => VectorMatrixProduct" bind_python: False - name: "vector_matrix_product_grad_a" signature: "Tensor (Tensor dy, Tensor b) => VectorMatrixProductGradA" bind_python: False - name: "vector_matrix_product_grad_b" signature: "Tensor (Tensor dy, Tensor a) => VectorMatrixProductGradB" bind_python: False - name: "tensordot" signature: [ "Tensor (Tensor a, Tensor b, Int32List dims_a, Int32List dims_b) => TensorDot", "Tensor (Tensor a, Tensor b, Int32 dims) => TensorDotIntDims", ] bind_python: True - name: "l1_loss" signature: 'Tensor(Tensor input, Tensor target, String reduction="mean") => L1Loss' bind_python: True - name: "mse_loss" signature: 'Tensor(Tensor input, Tensor target, String reduction="mean") => MseLoss' bind_python: True - name: "kl_div_loss" signature: 'Tensor(Tensor input, Tensor target, Bool log_target=False, String reduction="mean") => KLDivLoss' bind_python: True - name: "kl_div_loss_grad" signature: "Tensor(Tensor dy, Tensor input, Tensor target, Bool log_target) => KLDivLossGrad" bind_python: False - name: "kl_div_loss_target_grad" signature: "Tensor(Tensor dy, Tensor input, Tensor target, Bool log_target) => KLDivLossTargetGrad" bind_python: False - name: "nll_loss" signature: "Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index, String reduction) => NLLLoss" bind_python: True - name: "nll_grad" signature: "Tensor(Tensor out_grad, Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index) => NLLGrad" bind_python: False - name: "binary_cross_entropy_loss" signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, String reduction="mean") => BinaryCrossEntropyLoss' bind_python: True - name: "binary_cross_entropy_loss_grad" signature: "Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None) => BinaryCrossEntropyLossGrad" bind_python: False - name: "binary_cross_entropy_loss_target_grad" signature: "Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None) => BinaryCrossEntropyLossTargetGrad" bind_python: False - name: "binary_cross_entropy_with_logits_loss" signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, Tensor pos_weight=None, String reduction="mean") => BinaryCrossEntropyWithLogitsLoss' bind_python: True - name: "binary_cross_entropy_with_logits_loss_grad" signature: "Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None, Tensor 
pos_weight=None) => BinaryCrossEntropyWithLogitsLossGrad" bind_python: True - name: "binary_cross_entropy_with_logits_loss_target_grad" signature: "Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None, Tensor pos_weight=None) => BinaryCrossEntropyWithLogitsLossTargetGrad" bind_python: False - name: "binary_cross_entropy_with_logits_reduce_mean_loss_grad" signature: "Tensor(Tensor dy, Tensor input, Tensor target) => BinaryCrossEntropyWithLogitsReduceMeanLossGrad" bind_python: False - name: "binary_cross_entropy_with_logits_reduce_mean_loss_target_grad" signature: "Tensor(Tensor dy, Tensor input, Tensor target) => BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad" bind_python: False - name: "sparse_cross_entropy" signature: "Tensor (Tensor prediction, Tensor label, Int64 depth) => SparseCrossEntropy" bind_python: True - name: "sparse_cross_entropy_grad" signature: "Tensor (Tensor prediction, Tensor label, Tensor dy, Int64 depth) => SparseCrossEntropyGrad" bind_python: False - name: "distributed_sparse_cross_entropy" signature: "Tensor (Tensor prediction, Tensor label, Int64 depth) => SparseCrossEntropyMs" bind_python: True - name: "cross_entropy" signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index=-100, String reduction="mean", Double label_smoothing=0.0) => CrossEntropy' bind_python: True - name: "cross_entropy_label_smoothing" signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index=-100, String reduction="mean", Double label_smoothing=0.0) => CrossEntropyLabelSmoothing' bind_python: False - name: "cross_entropy_prob" signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, String reduction="mean", Double label_smoothing=0.0) => CrossEntropyProb' bind_python: False - name: "distributed_sparse_cross_entropy_grad" signature: "Tensor (Tensor prediction, Tensor label, Tensor dy, Int64 depth) => SparseCrossEntropyMsGrad" bind_python: False - name: "sparse_softmax_cross_entropy" signature: "Tensor (Tensor logits, Tensor label) => SparseSoftmaxCrossEntropy" bind_python: True - name: "sparse_softmax_cross_entropy_grad" signature: "Tensor (Tensor dy, Tensor prob, Tensor label, Int64 depth) => SparseSoftmaxCrossEntropyGrad" bind_python: False - name: "sparse_softmax_cross_entropy_ms_grad" signature: "Tensor (Tensor dy, Tensor prob, Tensor label, Int64 depth) => SparseSoftmaxCrossEntropyMsGrad" bind_python: False - name: "softmax_cross_entropy" signature: "Tensor (Tensor logits, Tensor label) => SoftmaxCrossEntropy" bind_python: True - name: "softmax_cross_entropy_grad" signature: "Tensor (Tensor dy, Tensor label, Tensor prob) => SoftmaxCrossEntropyGrad" bind_python: True - name: "smooth_l1_loss" signature: "Tensor (Tensor logits, Tensor label, Float beta, String reduction) => SmoothL1Loss" bind_python: True - name: "smooth_l1_loss_grad" signature: "Tensor (Tensor loss_grad, Tensor prediction, Tensor label, Float beta) => SmoothL1LossGrad" bind_python: False - name: "combined_margin_loss" signature: "Tensor (Tensor x, Tensor label, Float m1, Float m2, Float m3) => CombinedMarginLoss" bind_python: True - name: "combined_margin_loss_grad" signature: "Tensor (Tensor dy, Tensor label, Tensor theta, Float m1, Float m2, Float m3, Int64 depth) => CombinedMarginLossGrad" bind_python: False - name: "triplet_margin_loss" signature: "Tensor (Tensor anchor, Tensor positive, Tensor negative, *, Float margin, Float p, Float eps, Bool swap, String reduction) => TripletMarginLoss" bind_python: True - name: 
"margin_ranking_loss" signature: "Tensor (Tensor input_1, Tensor input_2, Tensor target, Float margin, String reduction) => MarginRankingLoss" bind_python: True - name: "ctc_loss" signature: "Tensor (Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Int64 max_target_length, Int64 blank, Bool zero_infinity, String reduction) => CtcLoss" bind_python: True - name: "affine_grid" signature: "Tensor (Tensor theta, *, Shape size, Bool align_corners) => AffineGrid" bind_python: True - name: "affine_grid_grad" signature: "Tensor (Tensor dgrid, *, Shape size, Bool align_corners) => AffineGridGrad" bind_python: False - name: "grid_sample" signature: "Tensor (Tensor input, Tensor grid, *, String interpolation_mode, String padding_mode, Bool align_corners) => GridSample" bind_python: True - name: "grid_sample_grad" signature: "TensorTuple (Tensor doutput, Tensor input, Tensor grid, *, String interpolation_mode, String padding_mode, Bool align_corners) => GridSampleGrad" bind_python: False - name: "where" signature: [ "Tensor (Tensor condition, Tensor x, Tensor y) => Where", "Tensor (Tensor condition, Scalar x, Tensor y) => WhereScalarX", "Tensor (Tensor condition, Tensor x, Scalar y) => WhereScalarY", "Tensor (Tensor condition, Scalar x, Scalar y) => WhereScalarXY", ] bind_python: true - name: "masked_fill" signature: "Tensor (Tensor input, Tensor mask, Scalar value) => MaskedFill" bind_python: true - name: "masked_fill_" signature: "Tensor (Tensor input, Tensor mask, Scalar value) => MaskedFillInplace" bind_python: true - name: "movedim" signature: [ "Tensor (Tensor input, Int32 source, Int32 destination) => MovedimInt", "Tensor (Tensor input, Int32List source, Int32List destination) => MovedimVec", ] bind_python: True - name: "tensor_split" signature: [ "TensorTuple (Tensor input, Int32 indices_or_sections, Int32 dim=0) => TensorSplitInt", "TensorTuple (Tensor input, Int32List indices_or_sections, Int32 dim=0) => TensorSplitVec", ] bind_python: True - name: "hsplit" signature: [ "TensorTuple (Tensor input, Int32 indices_or_sections) => HsplitInt", "TensorTuple (Tensor input, Int32List indices_or_sections) => HsplitVec", ] bind_python: True - name: "vsplit" signature: [ "TensorTuple (Tensor input, Int32 indices_or_sections) => VsplitInt", "TensorTuple (Tensor input, Int32List indices_or_sections) => VsplitVec", ] bind_python: True - name: "negative" signature: "Tensor (Tensor x) => Negative" bind_python: True - name: "layer_norm_affine" signature: "Tensor (Tensor x, Tensor gamma, Tensor beta, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => LayerNormAffine" bind_python: True - name: "skip_layer_norm" signature: "Tensor (Tensor x, *, Tensor gamma=None, Tensor beta=None, Tensor bias=None, Tensor skip=None, Double epsilon=1e-5, Double alpha=1e1) => SkipLayerNorm" bind_python: True - name: "layer_norm" signature: "Tensor (Tensor x, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => LayerNorm" bind_python: True - name: "layer_norm_grad" signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Double epsilon) => LayerNormGrad" bind_python: False - name: "layer_norm_affine_grad" signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad" bind_python: False - name: "fuse_layer_norm_grad" signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 
begin_params_axis, Double epsilon) => FuseLayerNormGrad" bind_python: False - name: "layer_norm_param_grad" signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_params_axis) => LayerNormParamGrad" bind_python: False - name: "rms_norm" signature: "Tensor (Tensor x, Tensor weight=None, Shape normalized_shape, Float epsilon=1e-6) => RMSNorm" bind_python: True - name: "rms_norm_grad" signature: "Tensor (Tensor dy, Tensor x, Tensor inv_rms, Tensor weight=None, Bool param_grad) => RMSNormGrad" bind_python: False - name: "skip_rms_norm" signature: "Tensor (Tensor x, *, Tensor weight=None, Tensor bias=None, Tensor skip=None, Double epsilon=1e-5, Double alpha=1e1) => SkipRMSNorm" bind_python: True - name: "group_norm" signature: 'Tensor (Tensor x, Tensor gamma=None, Tensor beta=None, Bool affine, Int32 num_groups, Double epsilon, String data_format="channels_first", String activation="none") => GroupNorm' bind_python: True - name: "group_norm_grad" signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma=None, Int32 num_groups, Double epsilon) => GroupNormGrad" bind_python: False - name: "group_norm_param_grad" signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance) => GroupNormParamGrad" bind_python: False - name: "avg_pool2d_nhwc" signature: 'Tensor (Tensor x, Int32List kernel_size, Int32List stride, String padding, Int32List padding_before, Int32List padding_after, String data_format="channels_first", Bool ceil_mode=False) => TFAvgPool2D' bind_python: True - name: "ctc_loss_grad" signature: "Tensor (Tensor loss_grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor loss, Tensor alpha, Int64 blank, Bool zero_infinity, Int64 max_target_length) => CtcLossGrad" bind_python: False - name: "adaptive_avg_pool1d" signature: 'Tensor (Tensor x, Int64List[1] output_size, String data_format="channels_first") => AdaptiveAvgPool1D' bind_python: True - name: "adaptive_avg_pool2d" signature: 'Tensor (Tensor x, Int64List[2] output_size, String data_format="channels_first") => AdaptiveAvgPool2D' bind_python: True - name: "adaptive_avg_pool3d" signature: 'Tensor (Tensor x, Int64List[3] output_size, String data_format="channels_first") => AdaptiveAvgPool3D' bind_python: True - name: "adaptive_pool_grad" signature: 'Tensor (Tensor x, Tensor dy, String mode, Int32 ndims, String data_format="channels_first") => AdaptivePoolNdGrad' bind_python: False - name: "tf_pool_grad" signature: "Tensor (Tensor x, Tensor y, Tensor dy, String mode, Int32 ndims, String data_format, String padding, Int32List padding_before, Int32List padding_after, Int32List pool_size, Int32List strides, Bool ceil_mode) => TFPoolNdGrad" bind_python: False - name: "max_pool1d" signature: 'TensorTuple (Tensor input, Int32List[1] kernel_size, Int32List[1] stride=None, Int32List[1] padding=0, Int32List[1] dilation=1, Bool return_indices=True, Bool ceil_mode=False, String data_format="channels_first") => MaxPool1D' bind_python: True - name: "max_pool2d" signature: 'TensorTuple (Tensor input, Int32List[2] kernel_size, Int32List[2] stride=None, Int32List[2] padding=0, Int32List[2] dilation=1, Bool return_indices=True, Bool ceil_mode=False, String data_format="channels_first") => MaxPool2D' bind_python: True - name: "max_pool3d" signature: 'TensorTuple (Tensor input, Int32List[3] kernel_size, Int32List[3] stride=None, Int32List[3] padding=0, Int32List[3] dilation=1, Bool return_indices=True, Bool ceil_mode=False, String
data_format="channels_first") => MaxPool3D' bind_python: True - name: "max_pool_grad" signature: "Tensor (Tensor x, Tensor indice, Tensor dy, Int32 ndims, String data_format, Int32List padding, Int32List kernel_size, Int32List stride, Int32List dilation, Bool return_indices, Bool ceil_mode) => MaxPoolNdGrad" bind_python: False - name: "max_unpool1d" signature: 'Tensor (Tensor input, Tensor indices, Int32List[1] kernel_size, Int32List[1] stride=None, Int32List[1] padding=0, Shape output_size=None) => MaxUnpool1D' bind_python: True - name: "max_unpool2d" signature: 'Tensor (Tensor input, Tensor indices, Int32List[2] kernel_size, Int32List[2] stride=None, Int32List[2] padding=0, Shape output_size=None) => MaxUnpool2D' bind_python: True - name: "max_unpool3d" signature: 'Tensor (Tensor input, Tensor indices, Int32List[3] kernel_size, Int32List[3] stride=None, Int32List[3] padding=0, Shape output_size=None) => MaxUnpool3D' bind_python: True - name: "max_unpool1d_grad" signature: "Tensor (Tensor x, Tensor indice, Tensor dy) => MaxUnpool1dGrad" bind_python: False - name: "max_unpool2d_grad" signature: "Tensor (Tensor x, Tensor indice, Tensor dy) => MaxUnpool2dGrad" bind_python: False - name: "max_unpool3d_grad" signature: "Tensor (Tensor x, Tensor indice, Tensor dy) => MaxUnpool3dGrad" bind_python: False - name: "prelu" signature: "Tensor (Tensor x, Tensor alpha) => PRelu" bind_python: True - name: "prelu_grad" signature: "TensorTuple (Tensor dy, Tensor x, Tensor alpha) => PReluGrad" bind_python: False - name: "reshape" signature: "Tensor (Tensor x, Shape shape) => Reshape" bind_python: True - name: "view" signature: "Tensor (Tensor x, Shape shape) => View" bind_python: True - name: "contiguous" signature: "Tensor (Tensor input) => ToContiguous" bind_python: True - name: "contiguous_" signature: "Tensor (Tensor input) => InplaceToContiguous" bind_python: True - name: "slice_view_1d_contiguous" signature: "Tensor (Tensor x, Int64 start, Int64 end) => SliceView1dContiguous" bind_python: True - name: "narrow" signature: "Tensor (Tensor input, Int64 dim, Int64 start, Int64 length) => Narrow" bind_python: True - name: "narrow_grad" signature: "Tensor (Tensor dy, Tensor like, Int64 dim, Int64 start, Int64 length) => NarrowGrad" bind_python: False - name: "slice" signature: "Tensor (Tensor x, Int64List start, Int64List stop, Int64List step, Bool enable_view_slice=None) => Slice" bind_python: True - name: "slice_update" signature: "Tensor (Tensor ref, Tensor value, Int64List start, Int64List stop, Int64List step, Bool inplace=False) => SliceUpdate" bind_python: True - name: "slice_grad" signature: "Tensor (Tensor dy, Shape like_shape, Int64List start, Int64List stop, Int64List step) => SliceGrad" bind_python: False - name: "copy" signature: [ "Tensor (Tensor x, String device_type, Int64 device_id, Bool pin_memory=False) => Copy", "Tensor (Tensor x, Device device, Bool pin_memory=False) => Copy" ] bind_python: True - name: "to" signature: [ # type of device must be string for global tensor to perform argument validation "Tensor (Tensor x, String device=None, DataType dtype=None, Bool copy=False) => To", "Tensor (Tensor x, Device device=None, DataType dtype=None, Bool copy=False) => To", "Tensor (Tensor x, DataType dtype=None, Bool copy=False) => To", "Tensor (Tensor x, Tensor other, Bool copy=False) => To", "Tensor (Tensor x, String device=None) => To", "Tensor (Tensor x, *, MemoryFormat memory_format) => To", ] bind_python: True - name: "flip" signature: "Tensor (Tensor x, Int32List[1] dims) => Flip" 
bind_python: True - name: "upsample" signature: 'Tensor (Tensor x, Double height_scale, Double width_scale, Bool align_corners, String interpolation, String data_format="channels_first") => Upsample' bind_python: True - name: "upsample_grad" signature: "Tensor (Tensor dy, Tensor x, Double height_scale, Double width_scale, Bool align_corners, String data_format, String interpolation) => UpsampleGrad" bind_python: False - name: "upsample_linear_1d" signature: 'Tensor (Tensor x, Double scale_factor=0.0, Bool align_corners=False, Int64List[1] output_size=None, String data_format="channels_first") => UpsampleLinear1D' bind_python: True - name: "upsample_linear_1d_grad" signature: 'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Bool align_corners=False, Int64List[1] output_size=None, String data_format="channels_first") => UpsampleLinear1DGrad' bind_python: False - name: "upsample_nearest_1d" signature: 'Tensor (Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None, String data_format="channels_first") => UpsampleNearest1D' bind_python: True - name: "upsample_nearest_1d_grad" signature: 'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None, String data_format="channels_first") => UpsampleNearest1DGrad' bind_python: False - name: "upsample_nearest_2d" signature: 'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleNearest2D' bind_python: True - name: "upsample_nearest_2d_grad" signature: 'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleNearest2DGrad' bind_python: False - name: "upsample_bilinear_2d" signature: 'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleBilinear2D' bind_python: True - name: "upsample_bilinear_2d_grad" signature: 'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleBilinear2DGrad' bind_python: False - name: "upsample_bicubic_2d" signature: 'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleBicubic2D' bind_python: True - name: "upsample_bicubic_2d_grad" signature: 'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleBicubic2DGrad' bind_python: False - name: "upsample_nearest_3d" signature: 'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None, String data_format="channels_first") => UpsampleNearest3D' bind_python: True - name: "upsample_nearest_3d_grad" signature: 'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None, String data_format="channels_first") => UpsampleNearest3DGrad' bind_python: False - name: "upsample_trilinear_3d" signature: 'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[3] output_size=None, String data_format="channels_first") => UpsampleTrilinear3D' bind_python: True - name: 
"upsample_trilinear_3d_grad" signature: 'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[3] output_size=None, String data_format="channels_first") => UpsampleTrilinear3DGrad' bind_python: False - name: "fused_get_boundding_boxes_coord" signature: "TensorTuple (Tensor x1, Tensor y1, Tensor w1, Tensor h1, Tensor x2, Tensor y2, Tensor w2, Tensor h2) => FusedGetBounddingBoxesCoord" bind_python: True - name: "fused_get_boundding_boxes_coord_grad" signature: "TensorTuple (Tensor b1_x1_diff, Tensor b1_x2_diff, Tensor b1_y1_diff, Tensor b1_y2_diff, Tensor b2_x1_diff, Tensor b2_x2_diff, Tensor b2_y1_diff, Tensor b2_y2_diff) => FusedGetBounddingBoxesCoordGrad" bind_python: False - name: "fused_get_ciou_result" signature: "TensorTuple (Tensor v, Tensor iou, Tensor rho2, Tensor c2, Float eps) => FusedGetCiouResult" bind_python: True - name: "fused_get_ciou_result_grad" signature: "TensorTuple (Tensor dy ,Tensor alpha, Tensor rho2, Tensor c2) => FusedGetCiouResultGrad" bind_python: False - name: "fused_codegeex_qkv_reshape" signature: "TensorTuple (Tensor query, Tensor key, Tensor value, Int32 num_attention_heads) => FusedCodegeexQkvReshape" bind_python: True - name: "fused_get_iou" signature: "Tensor (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor inter, Float eps) => FusedGetIou" bind_python: True - name: "fused_get_iou_grad" signature: "TensorTuple (Tensor diou, Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor inter, Float eps) => FusedGetIouGrad" bind_python: False - name: "abs" signature: "Tensor (Tensor x) => Abs" bind_python: True - name: "abs_grad" signature: "Tensor (Tensor x, Tensor dy) => AbsGrad" bind_python: False - name: "acos" signature: "Tensor (Tensor x) => Acos" bind_python: True - name: "acos_grad" signature: "Tensor (Tensor x, Tensor dy) => AcosGrad" bind_python: False - name: "acosh" signature: "Tensor (Tensor x) => Acosh" bind_python: True - name: "acosh_grad" signature: "Tensor (Tensor x, Tensor dy) => AcoshGrad" bind_python: False - name: "asin" signature: "Tensor (Tensor x) => Asin" bind_python: True - name: "asin_grad" signature: "Tensor (Tensor x, Tensor dy) => AsinGrad" bind_python: False - name: "asinh" signature: "Tensor (Tensor x) => Asinh" bind_python: True - name: "asinh_grad" signature: "Tensor (Tensor x, Tensor dy) => AsinhGrad" bind_python: False - name: "atan" signature: "Tensor (Tensor x) => Atan" bind_python: True - name: "atan_grad" signature: "Tensor (Tensor x, Tensor dy) => AtanGrad" bind_python: False - name: "atan2" signature: "Tensor (Tensor input, Tensor other) => Atan2" bind_python: True - name: "atan2_x_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => Atan2XGrad" bind_python: False - name: "atan2_y_grad" signature: "Tensor (Tensor dz, Tensor x, Tensor y) => Atan2YGrad" bind_python: False - name: "atanh" signature: "Tensor (Tensor x) => Atanh" bind_python: True - name: "atanh_grad" signature: "Tensor (Tensor x, Tensor dy) => AtanhGrad" bind_python: False - name: "ceil" signature: "Tensor (Tensor x) => Ceil" bind_python: True - name: "ceil_" signature: "Tensor (Tensor x) => Ceil_" bind_python: True - name: "ceil_grad" signature: "Tensor (Tensor x, Tensor dy) => CeilGrad" bind_python: False - name: "erf" signature: "Tensor (Tensor x) => Erf" bind_python: True - name: "erf_grad" signature: "Tensor (Tensor x, Tensor dy) => ErfGrad" bind_python: False - name: "erfc" signature: "Tensor (Tensor x) => Erfc" bind_python: True - name: "erfc_grad" signature: 
"Tensor (Tensor x, Tensor dy) => ErfcGrad" bind_python: False - name: "expm1" signature: "Tensor (Tensor x) => Expm1" bind_python: True - name: "expm1_grad" signature: "Tensor (Tensor x, Tensor dy) => Expm1Grad" bind_python: False - name: "floor" signature: "Tensor (Tensor x) => Floor" bind_python: True - name: "floor_" signature: "Tensor (Tensor x) => Floor_" bind_python: True - name: "floor_grad" signature: "Tensor (Tensor x, Tensor dy) => FloorGrad" bind_python: False - name: "lgamma" signature: "Tensor (Tensor x) => Lgamma" bind_python: True - name: "lgamma_grad" signature: "Tensor (Tensor x, Tensor dy) => LgammaGrad" bind_python: False - name: "log1p" signature: "Tensor (Tensor x) => Log1p" bind_python: True - name: "log1p_grad" signature: "Tensor (Tensor x, Tensor dy) => Log1pGrad" bind_python: False - name: "logsigmoid" signature: "Tensor (Tensor x) => LogSigmoid" bind_python: True - name: "logsigmoid_grad" signature: "Tensor (Tensor x, Tensor dy) => LogSigmoidGrad" bind_python: False - name: "rint" signature: "Tensor (Tensor x) => Rint" bind_python: True - name: "rint_grad" signature: "Tensor (Tensor x, Tensor dy) => RintGrad" bind_python: False - name: "round" signature: "Tensor (Tensor x) => Round" bind_python: True - name: "round_" signature: "Tensor (Tensor x) => Round_" bind_python: True - name: "round_grad" signature: "Tensor (Tensor x, Tensor dy) => RoundGrad" bind_python: False - name: "sign" signature: "Tensor (Tensor x) => Sign" bind_python: True - name: "sign_grad" signature: "Tensor (Tensor x, Tensor dy) => SignGrad" bind_python: False - name: "sinh" signature: "Tensor (Tensor x) => Sinh" bind_python: True - name: "sinh_grad" signature: "Tensor (Tensor x, Tensor dy) => SinhGrad" bind_python: False - name: "softplus" signature: "Tensor (Tensor x, Double beta=1.0, Double threshold=20.0) => Softplus" bind_python: True - name: "softplus_grad" signature: "Tensor (Tensor x, Tensor dy, Double beta=1.0, Double threshold=20.0) => SoftplusGrad" bind_python: False - name: "softshrink" signature: "Tensor (Tensor x, *, Double alpha=0.5, Bool inplace=False) => SoftShrink" bind_python: True - name: "softshrink_grad" signature: "Tensor (Tensor y, Tensor dy, Double alpha=0.5) => SoftShrinkGrad" bind_python: False - name: "one_hot" signature: "Tensor (Tensor input, Int64 num_classes=-1, Scalar on_value=1, Scalar off_value=0) => OneHot" bind_python: True - name: "unsorted_segment_sum_like" signature: "Tensor (Tensor x, Tensor segment_ids, Tensor like, Int64 axis) => UnsortedSegmentSumLike" bind_python: True - name: "unsorted_segment_sum" signature: "Tensor (Tensor x, Tensor segment_ids, Int64 axis, Int64 num_segments) => UnsortedSegmentSum" bind_python: True - name: "tril" signature: "Tensor (Tensor x, Int64 diagonal=0) => Tril" bind_python: True - name: "tril_" signature: "Tensor (Tensor x, Int64 diagonal=0) => InplaceTril" bind_python: True - name: "triu" signature: "Tensor (Tensor x, Int64 diagonal=0) => Triu" bind_python: True - name: "triu_" signature: "Tensor (Tensor x, Int64 diagonal=0) => InplaceTriu" bind_python: True - name: "clamp" signature: "Tensor (Tensor input, Scalar min=None, Scalar max=None) => Clamp" bind_python: true - name: "clamp_" signature: "Tensor (Tensor input, Scalar min=None, Scalar max=None) => ClampInplace" bind_python: true - name: "clamp_min" signature: "Tensor (Tensor input, Scalar min) => ClampMin" bind_python: true - name: "clamp_min_" signature: "Tensor (Tensor input, Scalar min) => ClampMinInplace" bind_python: true - name: "clamp_max" signature: 
"Tensor (Tensor input, Scalar max) => ClampMax" bind_python: true - name: "clamp_max_" signature: "Tensor (Tensor input, Scalar min) => ClampMaxInplace" bind_python: true - name: "clip" signature: ["Tensor (Tensor input, Scalar min=None, Scalar max=None) => Clip"] bind_python: true - name: "clip_" signature: ["Tensor (Tensor input, Scalar min=None, Scalar max=None) => ClipInplace"] bind_python: true - name: "clamp_grad" signature: "Tensor (Tensor dy, Tensor x, Scalar min=None, Scalar max=None) => ClampGrad" bind_python: False - name: "vector_norm" signature: [ "Tensor (Tensor input, Scalar ord=2, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => VectorNorm", "Tensor (Tensor input, Scalar ord=2, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => VectorNorm", ] bind_python: True - name: "matrix_norm" signature: [ "Tensor (Tensor input, Scalar ord, Int32List dim, Bool keepdim=False, *, DataType dtype=None) => MatrixNorm", "Tensor (Tensor input, String ord, Int32List dim, Bool keepdim=False, *, DataType dtype=None) => MatrixNorm", ] bind_python: True - name: "norm" signature: [ "Tensor (Tensor input, Scalar ord=None, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None, Bool for_norm=False) => Norm", "Tensor (Tensor input, String ord, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => Norm", "Tensor (Tensor input, Scalar ord=None, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => ScalarNorm", "Tensor (Tensor input, String ord, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => ScalarNorm", ] bind_python: True - name: "inv" signature: "Tensor (Tensor x) => Inv" bind_python: True - name: "linalg_cross" signature: "Tensor (Tensor input, Tensor other, Int64 dim=None) => LinalgCross" bind_python: True - name: "det" signature: "Tensor (Tensor x) => Det" bind_python: True - name: "dropout" signature: "Tensor (Tensor input, Float p=0.5, Bool training=True, Bool inplace=False, Generator generator=None, *, Tensor addend=None) => Dropout" bind_python: True - name: "dropout_grad" signature: "Tensor (Tensor dy, Tensor mask, Float scale) => DropoutGrad" bind_python: False - name: "dropout1d" signature: "Tensor (Tensor input, Float p=0.5, Bool training=True) => Dropout1d" bind_python: True - name: "dropout2d" signature: "Tensor (Tensor input, Float p=0.5, Bool training=True) => Dropout2d" bind_python: True - name: "dropout3d" signature: "Tensor (Tensor input, Float p=0.5, Bool training=True) => Dropout3d" bind_python: True - name: "constant_pad" signature: "Tensor (Tensor x, Int64List pad, Scalar value=0) => ConstantPad" bind_python: False - name: "reflection_pad" signature: "Tensor (Tensor x, Int64List pad) => ReflectionPad" bind_python: False - name: "replication_pad" signature: "Tensor (Tensor x, Int64List pad) => ReplicationPad" bind_python: False - name: "pad" signature: 'Tensor (Tensor x, Int64List pad, String mode="constant", Scalar value=0) => Pad' bind_python: True - name: "pad_grad" signature: 'Tensor (Tensor dy, Int64List pad, String mode="constant", Scalar value=0) => PadGrad' bind_python: False - name: "silu" signature: "Tensor (Tensor x) => Silu" bind_python: True - name: "silu_grad" signature: "Tensor (Tensor dy, Tensor x) => SiluGrad" bind_python: False - name: "mish" signature: "Tensor (Tensor x) => Mish" bind_python: True - name: "mish_grad" signature: "Tensor (Tensor dy, Tensor x) => MishGrad" bind_python: False - name: "selu" signature: "Tensor (Tensor x) => Selu" bind_python: True - name: "selu_grad" signature: "Tensor 
(Tensor dy, Tensor x) => SeluGrad" bind_python: False - name: "softsign" signature: "Tensor (Tensor x) => SoftSign" bind_python: True - name: "softsign_grad" signature: "Tensor (Tensor dy, Tensor x) => SoftSignGrad" bind_python: False - name: "diag" signature: "Tensor (Tensor x, Int32 diagonal=0) => Diag" bind_python: True - name: "diag_grad" signature: "Tensor (Tensor dy, Tensor in, Int32 diagonal=0) => DiagGrad" bind_python: False - name: "diagonal" signature: "Tensor (Tensor x, Int32 offset=0, Int32 dim1=0, Int32 dim2=1) => Diagonal" bind_python: True - name: "diagonal_grad" signature: "Tensor (Tensor dy, Tensor in, Int32 offset=0) => DiagonalGrad" bind_python: False - name: "tensor_getitem" signature: "Tensor (Tensor x, TensorIndex index) => TensorGetItem" bind_python: False - name: "scatter" signature: [ "Tensor (Tensor input, Int32 dim, Tensor index, Tensor src, *, String reduce=None, Bool inplace=False) => DimScatter", "Tensor (Tensor input, Int32 dim, Tensor index, Scalar src, *, String reduce=None, Bool inplace=False) => DimScatterScalar", ] bind_python: True - name: "scatter_update" signature: [ "Tensor (Tensor input, Int32 dim, Tensor index, Tensor src, *, Bool inplace=False) => DimScatterUpdate", "Tensor (Tensor input, Int32 dim, Tensor index, Scalar src, *, Bool inplace=False) => DimScatterUpdateScalar", ] bind_python: False - name: "scatter_add" signature: [ "Tensor (Tensor input, Int32 dim, Tensor index, Tensor src, *, Bool inplace=False) => DimScatterAdd", "Tensor (Tensor input, Int32 dim, Tensor index, Scalar src, *, Bool inplace=False) => DimScatterAddScalar", ] bind_python: True - name: "scatter_mul" signature: [ "Tensor (Tensor input, Int32 dim, Tensor index, Tensor src, *, Bool inplace=False) => DimScatterMul", "Tensor (Tensor input, Int32 dim, Tensor index, Scalar src, *, Bool inplace=False) => DimScatterMulScalar", ] bind_python: False - name: "scatter_add_like" signature: "Tensor (Tensor like, Int32 dim, Tensor index, Tensor src) => DimScatterAddLike" bind_python: False - name: "tensor_setitem" signature: "Void (Tensor x, TensorIndex index, Tensor value) => TensorSetItem" bind_python: True - name: "avg_pool1d" signature: 'Tensor (Tensor input, Int32List[1] kernel_size, Int32List[1] stride=None, Int32List[1] padding=0, Bool ceil_mode=False, Bool count_include_pad=True, Int32 divisor_override=0, String data_format="channels_first") => AvgPool1D' bind_python: True - name: "avg_pool2d" signature: 'Tensor (Tensor input, Int32List[2] kernel_size, Int32List[2] stride=None, Int32List[2] padding=0, Bool ceil_mode=False, Bool count_include_pad=True, Int32 divisor_override=0, String data_format="channels_first") => AvgPool2D' bind_python: True - name: "avg_pool3d" signature: 'Tensor (Tensor input, Int32List[3] kernel_size, Int32List[3] stride=None, Int32List[3] padding=0, Bool ceil_mode=False, Bool count_include_pad=True, Int32 divisor_override=0, String data_format="channels_first") => AvgPool3D' bind_python: True - name: "avg_pool_grad" signature: "Tensor (Tensor x, Tensor dy, Int32 ndims, String data_format, Int32List padding, Int32List kernel_size, Int32List stride, Bool ceil_mode, Bool count_include_pad, Int32 divisor_override=0) => AvgPoolNdGrad" bind_python: False - name: "minimum" signature: "Tensor (Tensor input, Tensor other) => Minimum" bind_python: True - name: "maximum" signature: "Tensor (Tensor input, Tensor other) => Maximum" bind_python: True - name: "elementwise_min_grad" signature: "TensorTuple (Tensor dz, Tensor x, Tensor y) => ElementwiseMinGrad" 
bind_python: False - name: "elementwise_max_grad" signature: "TensorTuple (Tensor dz, Tensor x, Tensor y) => ElementwiseMaxGrad" bind_python: False - name: "stack" signature: "Tensor (TensorTuple inputs, Int64 dim=0) => Stack" bind_python: True - name: "stack_grad" signature: "TensorTuple (Tensor x, TensorTuple like, Int64 axis) => StackGrad" bind_python: False - name: "atleast_1d" signature: [ "Tensor (Tensor input) => AtLeast1D", "TensorTuple (TensorTuple tensors) => AtLeast1D", ] bind_python: True - name: "atleast_2d" signature: [ "Tensor (Tensor input) => AtLeast2D", "TensorTuple (TensorTuple tensors) => AtLeast2D", ] bind_python: True - name: "atleast_3d" signature: [ "Tensor (Tensor input) => AtLeast3D", "TensorTuple (TensorTuple tensors) => AtLeast3D", ] bind_python: True - name: "hstack" signature: "Tensor (TensorTuple tensors) => HStack" bind_python: True - name: "vstack" signature: "Tensor (TensorTuple tensors) => VStack" bind_python: True - name: "dstack" signature: "Tensor (TensorTuple tensors) => DStack" bind_python: True - name: "column_stack" signature: "Tensor (TensorTuple tensors) => ColumnStack" bind_python: True - name: "row_stack" signature: "Tensor (TensorTuple tensors) => RowStack" bind_python: True - name: "local_to_global" signature: "Tensor (Tensor x, Placement placement, SbpList sbp, Shape shape, DataType dtype, Bool sync_data, Bool copy=False) => LocalToGlobal" bind_python: False - name: "to_global" signature: "Tensor (Tensor x, Placement placement, SbpList sbp, SbpList grad_sbp, Bool check_meta, Bool copy=False) => ToGlobal" bind_python: True - name: "to_local" signature: "Tensor (Tensor x, Bool copy=False) => GlobalToLocal" bind_python: True - name: "stream_touch" signature: "Void (TensorTuple x) => StreamTouch" bind_python: True - name: "comm_broadcast" signature: [ "Tensor (Tensor x, *, Int64 src_rank=0, Bool inplace=True) => CommBroadcast", "TensorTuple (TensorTuple inputs, *, Int64 src_rank=0, Bool inplace=True) => CommBroadcastTensors", ] bind_python: True - name: "local_all_reduce" signature: "Tensor (Tensor x, Bool inplace=False) => LocalAllReduce" bind_python: True - name: "local_all_gather" signature: "Tensor (Tensor output, Tensor input) => LocalAllGather" bind_python: True - name: "local_reduce_scatter" signature: "Tensor (Tensor output, Tensor input) => LocalReduceScatter" bind_python: True - name: "local_reduce" signature: "Tensor (Tensor x, *, Int64 dst=0, Bool inplace=True) => LocalReduce" bind_python: True - name: "eager_p_to_b" signature: "Tensor (Tensor x, Placement in_placement, Placement out_placement, Shape shape) => EagerPToB" bind_python: False - name: "eager_b_to_s" signature: "Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList out_sbp, Shape shape) => EagerBToS" bind_python: False - name: "eager_s_to_b" signature: "Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList in_sbp, Shape shape) => EagerSToB" bind_python: False - name: "eager_naive_s_to_s" signature: "Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList in_sbp, SbpList out_sbp, Shape shape) => EagerNaiveSToS" bind_python: False - name: "eager_p_to_s" signature: "Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList out_sbp, Shape shape) => EagerPToS" bind_python: False - name: "eager_s_to_p" signature: "Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList out_sbp, Shape shape) => EagerSToP" bind_python: False - name: "global_all_reduce" signature: "Tensor (Tensor x) => 
GlobalAllReduce" bind_python: False - name: "global_reduce_scatter" signature: "Tensor (Tensor x, String op_type) => GlobalReduceScatter" bind_python: False - name: "global_all_gather" signature: "Tensor (Tensor x) => GlobalAllGather" bind_python: False - name: "global_s2s" signature: "Tensor (Tensor x, SbpList out_sbp) => GlobalS2S" bind_python: False - name: "select_top_n" signature: "TensorTuple (TensorTuple inputs, Int32 n) => SelectTopN" bind_python: True - name: "cast_like" signature: "Tensor (Tensor x, Tensor like) => CastLike" bind_python: False - name: "identity" signature: "Tensor (Tensor in) => Identity" bind_python: True - name: "amp_white_identity" signature: "Tensor (Tensor in) => AmpWhiteIdentity" bind_python: True - name: "amp_black_identity" signature: "Tensor (Tensor in) => AmpBlackIdentity" bind_python: True - name: "reshape_like" signature: "Tensor (Tensor in, Tensor like) => ReshapeLike" bind_python: True - name: "reduce_sum_like" signature: "Tensor (Tensor in, Tensor like, Int32List axis) => ReduceSumLike" bind_python: True - name: "broadcast_reduce_sum_like" signature: "Tensor (Tensor in, Tensor like) => BroadcastReduceSumLike" bind_python: False - name: "rand" signature: [ "Tensor (Shape size, *, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False) => Rand", "Tensor (Shape size, *, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalRand", ] bind_python: True - name: "randn" signature: [ "Tensor (Shape size, *, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False, Layout layout=kStrided) => RandN", "Tensor (Shape size, *, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalRandN", ] bind_python: True - name: "randn_like" signature: [ "Tensor (Tensor input, *, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False) => RandnLike", "Tensor (Tensor input, *, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalRandnLike", ] bind_python: True - name: "randint" signature: [ "Tensor (Int64 low, Int64 high, Shape size, *, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False)=> RandInt", "Tensor (Int64 high, Shape size, *, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False)=> RandInt", "Tensor (Int64 low, Int64 high, Shape size, *, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False)=> GlobalRandInt", "Tensor (Int64 high, Shape size, *, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False)=> GlobalRandInt", ] bind_python: True - name: "randint_like" signature: [ "Tensor (Tensor x, Int64 low, Int64 high, *, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False)=> RandIntLike", "Tensor (Tensor x, Int64 high, *, DataType dtype=None, Device device=None, Generator generator=None, Bool requires_grad=False)=> RandIntLike", "Tensor (Tensor x, Int64 low, Int64 high, *, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False)=> GlobalRandIntLike", "Tensor (Tensor x, Int64 high, *, Placement placement, SbpList sbp, DataType dtype=None, Generator generator=None, Bool requires_grad=False)=> GlobalRandIntLike", ] 
bind_python: True - name: "randperm" signature: [ "Tensor (Int32 n, *, Generator generator=None, DataType dtype=kInt64, Device device=None, Bool requires_grad=False) => RandPerm", "Tensor (Int32 n, *, Placement placement, SbpList sbp, Generator generator=None, DataType dtype=kInt64, Bool requires_grad=False) => GlobalRandPerm", ] bind_python: True - name: "unfold_tensor" signature: "Tensor (Tensor x, Int32 dimension, Int32 size, Int32 step) => UnfoldTensor" bind_python: True - name: "unfold_tensor_grad" signature: "Tensor (Tensor dy, Tensor x, Int32 dimension, Int32 size, Int32 step) => UnfoldTensorGrad" bind_python: False - name: "unfold" signature: 'Tensor (Tensor x, Int32List[2] kernel_size, Int32List[2] dilation=1, Int32List[2] padding=0, Int32List[2] stride=1, String data_format="channels_first") => Unfold' bind_python: True - name: "fold" signature: 'Tensor (Tensor x, Int32List[1] output_size, Int32List[2] kernel_size, Int32List[2] dilation=1, Int32List[2] padding=0, Int32List[2] stride=1, String data_format="channels_first") => Fold' bind_python: True - name: "split" signature: [ "TensorTuple (Tensor x, Int64 split_size_or_sections, Int64 dim=0) => Split", "TensorTuple (Tensor x, Int64List split_size_or_sections, Int64 dim=0) => SplitWithSize", ] bind_python: True - name: "unbind" signature: ["TensorTuple (Tensor x, Int64 dim=0) => Unbind"] bind_python: True - name: "chunk" signature: ["TensorTuple (Tensor x, Int64 chunks, Int64 dim=0) => Chunk"] bind_python: True - name: "split_like" signature: "TensorTuple (Tensor x, TensorTuple like, Int64 axis) => SplitLike" bind_python: True - name: "pairwise_distance" signature: "Tensor (Tensor x1, Tensor x2, Float p=2.0, Double eps=1e-6, Bool keepdim=False) => PairwiseDistance" bind_python: True - name: "cosine_similarity" signature: "Tensor (Tensor x, Tensor y, Int32 dim=1, Double eps=1e-8) => CosineSimilarity" bind_python: True - name: "normalize" signature: "Tensor (Tensor input, Float p=2.0, Int32 dim=1, Float eps=1e-12, Bool use_l2_norm_kernel=True) => Normalize" bind_python: True - name: "l2_normalize" signature: "Tensor (Tensor input, Int32 axis=0, Float epsilon=1e-12) => L2Normalize" bind_python: False - name: "l2_normalize_grad" signature: "Tensor (Tensor dy, Tensor y, Tensor square_x_sum, Int32 axis, Float epsilon) => L2NormalizeGrad" bind_python: False - name: "fused_self_attention" signature: "TensorTuple (Tensor hidden_states, Int64 head_size=8, Float alpha=1.0) => FusedSelfAttention" bind_python: True - name: "fused_self_attention_grad" signature: "Tensor (Tensor query_mul_key_grad, Tensor value_grad, Tensor hidden_states, Float alpha=1.0) => FusedSelfAttentionGrad" bind_python: False - name: "fused_scale_tril" signature: "Tensor (Tensor x, Int64 diagonal=0, Scalar fill_value=0, Scalar scale=1) => FusedScaleTril" bind_python: True - name: "fused_bias_add_gelu" signature: "Tensor (Tensor a, Tensor b, *, Int32 axis) => FusedBiasAddGelu" bind_python: True - name: "fused_bias_add_gelu_grad" signature: "Tensor (Tensor a, Tensor b, Tensor dy, Int32 axis) => FusedBiasAddGeluGrad" bind_python: false - name: "fused_bias_add_dropout" signature: "Tensor (Tensor a, Tensor b, *, Float p=0.5, Int32 axis, Generator generator=None) => FusedBiasAddDropout" bind_python: True - name: "fused_scale_mask_softmax" signature: "Tensor (Tensor x, Tensor mask, *, Float fill_value=0.0, Float scale=1.0) => FusedScaleMaskSoftmax" bind_python: True - name: "fused_scale_mask_softmax_grad" signature: "Tensor (Tensor y, Tensor dy, Tensor mask, Float scale=1.0) 
=> FusedScaleMaskSoftmaxGrad" bind_python: False - name: "fused_scale_mask_softmax_dropout" signature: "TensorTuple (Tensor x, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedScaleMaskSoftmaxDropout" bind_python: True - name: "fused_scale_mask_softmax_dropout_grad" signature: "Tensor (Tensor softmax_y, Tensor dy, Tensor mask, Tensor dropout_mask, Float scale=1.0, Float dropout_scale=1.0) => FusedScaleMaskSoftmaxDropoutGrad" bind_python: False - name: "fused_scale_tril_softmax_mask_scale" signature: "TensorTuple (Tensor a, *, Float p=0.5, Int64 diagonal, Float tril_scale_value, Float tril_fill_value=0.0, Generator generator=None) => FusedScaleTrilSoftmaxMaskScale" bind_python: True - name: "fused_scale_tril_softmax_mask_scale_grad" signature: "Tensor (Tensor softmax_y, Tensor dy, Tensor mask, Int64 diagonal, Float tril_scale_value, Float mask_scale_value) => FusedScaleTrilSoftmaxMaskScaleGrad" bind_python: False - name: "fused_bias_add_scale_mask_softmax_dropout" signature: "TensorTuple (Tensor x, Tensor bias, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedBiasAddScaleMaskSoftmaxDropout" bind_python: True - name: "scaled_dot_product_attention" signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention" bind_python: True - name: "scaled_dot_product_attention_grad" signature: "TensorTuple (Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor softmax_lse, Tensor rng_state, Float dropout_p=0.0, Bool is_causal=False, Float scale=0.0) => ScaledDotProductFlashAttentionGrad" bind_python: False - name: "fused_multi_head_attention_inference" signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference" bind_python: True - name: "fused_multi_head_attention_inference_v2" signature: 'Tensor (*, Tensor query, String query_layout, Int64 query_head_size=None, Tensor query_seq_start=None, Int64 query_max_seq_len=None, Tensor key=None, String key_layout=None, Tensor key_seq_start=None, Tensor key_seq_len=None, Int64 key_max_seq_len=None, Tensor value=None, String value_layout=None, Tensor attn_bias=None, String output_layout="BM(HK)", Float scale=None, Bool causal=None, String attn_mask_type=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInferenceV2' bind_python: True - name: "fused_attention_concat_past_key_value" signature: 'TensorTuple (*, Tensor past_key=None, String past_key_layout, Tensor past_value=None, String past_value_layout, Tensor key, String key_layout, Tensor value, String value_layout, Int64 key_head_size=None) => FusedAttentionConcatPastKeyValue' bind_python: True - name: "fused_scale_mask_bias_softmax" signature: 'Tensor (Tensor x, Tensor mask, Tensor bias=None, Float scale=0.35355, Bool inplace=False) => FusedScaleMaskBiasSoftmax' bind_python: True - name: "fused_scale_mask_bias_softmax_grad" signature: 'Tensor (Tensor y, Tensor dy, Float scale=0.35355) => FusedScaleMaskBiasSoftmaxGrad' bind_python: False - name: "noncontiguous_binary_op" signature: 'Tensor 
(Tensor lhs, Tensor rhs, String op="add", Bool inplace=False) => NonContiguousBinaryOp' bind_python: True - name: "noncontiguous_binary_op_grad" signature: 'TensorTuple (Tensor dy, Tensor lhs, Tensor rhs, String op="add", Bool inplace=False) => NonContiguousBinaryOpGrad' bind_python: False - name: "fused_get_center_dist" signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedCenter" bind_python: True - name: "fused_get_center_dist_grad" signature: "TensorTuple (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Tensor rho2_diff) => FusedCenterGrad" bind_python: False - name: "fused_get_intersection_area" signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedGetIntersectionArea" bind_python: True - name: "fused_get_intersection_area_grad" signature: "TensorTuple (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Tensor inter_diff) => FusedGetIntersectionAreaGrad" bind_python: False - name: "fused_get_ciou_diagonal_angle" signature: "Tensor (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Float eps) => FusedGetCiouDiagonalAngle" bind_python: True - name: "fused_get_ciou_diagonal_angle_grad" signature: "TensorTuple (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor v_diff, Float eps) => FusedGetCiouDiagonalAngleGrad" bind_python: False - name: "fused_get_convex_diagonal_squared" signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Float eps) => FusedGetConvexDiagonalSquared" bind_python: True - name: "fused_get_convex_diagonal_squared_grad" signature: "TensorTuple (Tensor c2_diff, Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Float eps) => FusedGetConvexDiagonalSquaredGrad" bind_python: False - name: "grouped_matmul_bias" signature: "TensorTuple (TensorTuple xs, TensorTuple weights, TensorTuple biases) => GroupedMatmulBias" bind_python: True - name: "grouped_matmul" signature: "TensorTuple (TensorTuple xs, TensorTuple weights) => GroupedMatmul" bind_python: True - name: "send" signature: "Void (Tensor input, Int64 dst, Bool send_meta=True) => Send" bind_python: True - name: "recv" signature: "Tensor (Int64 src, Shape shape=None, DataType dtype=None, Device device=None, *, Tensor out=None) => Recv" bind_python: True - name: "batch_gather" signature: "Tensor (Tensor in, Tensor indices) => BatchGather" bind_python: True - name: "unsorted_batch_segment_sum" signature: "Tensor (Tensor data, Tensor segment_ids, Int64 num_segments) => UnsortedBatchSegmentSum" bind_python: False - name: "ctc_greedy_decoder" signature: "TensorTuple (Tensor log_probs, Tensor input_lengths, Bool merge_repeated=True) => CtcGreedyDecoder" bind_python: True - name: "distributed_partial_fc_sample_disable_boxing" signature: "TensorTuple (Tensor sampled_weight_diff, Tensor sampled_label) => DistributedPariticalFCSampleDisableBoxing" bind_python: False - name: "nms" signature: "Tensor (Tensor x, Float iou_threshold, Int32 keep_n=-1) => Nms" bind_python: True - name: "roi_align" signature: "Tensor (Tensor x, Tensor rois, Float spatial_scale, Int32 pooled_h, Int32 pooled_w, Int32 sampling_ratio, Bool aligned) => RoiAlign" bind_python: True - name: "roi_align_grad" signature: "Tensor (Tensor dy, Tensor 
x_like, Tensor rois, Float spatial_scale, Int32 pooled_h, Int32 pooled_w, Int32 sampling_ratio, Bool aligned) => RoiAlignGrad" bind_python: False - name: "meshgrid" signature: 'TensorTuple (TensorTuple tensors, String indexing="ij") => Meshgrid' bind_python: True - name: "index_select" signature: "Tensor (Tensor input, Int64 dim, Tensor index) => IndexSelect" bind_python: True - name: "dot" signature: "Tensor (Tensor input, Tensor other) => Dot" bind_python: True - name: "fused_dot_feature_interaction" signature: 'Tensor (TensorTuple features, Tensor output_concat=None, Bool self_interaction=False, Int32 output_padding=0, String pooling="none") => FusedDotFeatureInteraction' bind_python: True - name: "fused_dot_feature_interaction_grad" signature: 'TensorTuple (Tensor dy, TensorTuple features, Bool has_output_concat_grad=False, Bool self_interaction=False, Int32 output_concat_grad_dim=0, String pooling="none") => FusedDotFeatureInteractionGrad' bind_python: False - name: "fused_cross_feature_interaction" signature: "Tensor (Tensor x, Tensor weight, Tensor x_0, Tensor bias, String interaction_mode) => FusedCrossFeatureInteraction" bind_python: True - name: "fused_cross_feature_interaction_v1_grad" signature: "TensorTuple (Tensor dy, Tensor weight, Tensor x, Tensor x_0, Tensor matmul_result) => FusedCrossFeatureInteractionV1Grad" bind_python: False - name: "fused_cross_feature_interaction_v2_grad" signature: "TensorTuple (Tensor dy, Tensor weight, Tensor bias, Tensor x, Tensor x_0, Tensor matmul_result) => FusedCrossFeatureInteractionV2Grad" bind_python: False - name: "tensor_buffer_to_tensor" signature: "Tensor (Tensor input, Shape instance_shape, DataType dtype) => TensorBufferToTensor" bind_python: True - name: "tensor_to_tensor_buffer" signature: "Tensor (Tensor input, Int32 instance_dims) => TensorToTensorBuffer" bind_python: True - name: "gen_tensor_buffer" signature: "Tensor (Shape shape, ShapeList shape_list, FloatList value_list, DataType data_type, Bool dynamic_out) => GenTensorBuffer" bind_python: True - name: "topk" signature: "TensorTuple[values, indices] (Tensor input, Int32 k, Int32 dim=None, Bool largest=True, Bool sorted=True) => TopK" bind_python: True - name: "in_top_k" signature: "Tensor (Tensor targets, Tensor predictions, Int32 k) => InTopK" bind_python: True - name: "cumsum" signature: "Tensor (Tensor input, Int64 dim, *, DataType dtype=None) => Cumsum" bind_python: True - name: "cumprod" signature: "Tensor (Tensor input, Int64 dim, *, DataType dtype=None) => Cumprod" bind_python: True - name: "cumprod_grad" signature: "Tensor (Tensor input, Tensor y, Tensor x, Int64 dim) => CumprodGrad" bind_python: False - name: "one_embedding_id_shuffle" signature: "TensorTuple (Tensor ids, Tensor table_ids=None, Int32 num_tables=1, String embedding_name) => OneEmbeddingIdShuffle" bind_python: True - name: "one_embedding_embedding_shuffle" signature: "Tensor (Tensor cur_rank_embeddings, Tensor num_unique_matrix, Tensor cur_rank_inverse_indices, Tensor inverse_unique_partition_indices, String embedding_name) => OneEmbeddingEmbeddingShuffle" bind_python: True - name: "one_embedding_embedding_gradient_shuffle" signature: "Tensor (Tensor embedding_grad, Tensor num_unique_matrix, Tensor cur_rank_inverse_indices, Tensor inverse_unique_partition_indices, String embedding_name) => OneEmbeddingEmbeddingGradientShuffle" bind_python: True - name: "one_embedding_lookup" signature: "Tensor (Tensor num_unique_ids, Tensor unique_ids, Tensor table_ids, DataType dtype, DataType embedding_dtype, 
Int64 line_size, Int64 embedding_size, String embedding_name, String embedding_tables, String state_initializer, Int64 seed=0) => OneEmbeddingLookup" bind_python: True - name: "one_embedding_fused_lookup" signature: "Tensor (Tensor shadow, Tensor ids, Tensor table_ids=None, DataType dtype, String embedding_name, Int64 line_size, Int64 embedding_size, Bool is_full_cache, Int32 num_tables, String embedding_tables, Int64 padding_idx=None, Int64 seed=0) => OneEmbeddingFusedLookup" bind_python: True - name: "one_embedding_fused_lookup_grad" signature: "Void (Tensor ids, Tensor embedding_grad, String embedding_name, Int64 line_size, Int64 embedding_size) => OneEmbeddingFusedLookupGrad" bind_python: True - name: "one_embedding_unique_key_value_pair" signature: "TensorTuple (Tensor keys, Tensor values=None, Int32 num_tables, String embedding_name) => OneEmbeddingUniqueKeyValuePair" bind_python: True - name: "one_embedding_embedding_put" signature: "Void (Tensor num_unique_ids, Tensor unique_ids, Tensor unique_embeddings, String embedding_name, Int64 line_size) => OneEmbeddingEmbeddingPut" bind_python: True - name: "one_embedding_sgd_update" signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate=None, Tensor down_scale_by_tensor=None, Tensor skip_if=None, Float learning_rate_val, Double scale, Float weight_decay, Float momentum, Int64 line_size, Int64 embedding_size, String embedding_name) => OneEmbeddingSgdUpdate" bind_python: True - name: "one_embedding_adam_update" signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate=None, Tensor down_scale_by_tensor=None, Tensor skip_if=None, Tensor bias_correction1=None, Tensor bias_correction2=None, Float learning_rate_val, Double scale, Float weight_decay, Float beta1, Float beta2, Float bias_correction1_val, Float bias_correction2_val, Float epsilon, Bool do_bias_correction, Int64 line_size, Int64 embedding_size, String embedding_name) => OneEmbeddingAdamUpdate" bind_python: True - name: "one_embedding_adagrad_update" signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate=None, Tensor down_scale_by_tensor=None, Tensor skip_if=None, Tensor train_step=None, Int64 train_step_val, Float learning_rate_val, Double scale, Float weight_decay, Float lr_decay, Float epsilon, Int64 line_size, Int64 embedding_size, String embedding_name) => OneEmbeddingAdagradUpdate" bind_python: True - name: "one_embedding_ftrl_update" signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate=None, Tensor down_scale_by_tensor=None, Tensor skip_if=None, Float learning_rate_val, Double scale, Float weight_decay, Float lr_power, Float lambda1, Float lambda2, Float beta, Int64 line_size, Int64 embedding_size, String embedding_name) => OneEmbeddingFtrlUpdate" bind_python: True - name: "einsum" signature: "Tensor (String equation, TensorTuple operands) => EinSum" bind_python: True - name: "pixel_shuffle" signature: "Tensor (Tensor input, Int64 h_upscale_factor, Int64 w_upscale_factor) => PixelShuffle" bind_python: True - name: "isnan" signature: "Tensor (Tensor input) => IsNan" bind_python: True - name: "isinf" signature: "Tensor (Tensor input) => IsInf" bind_python: True - name: "isfinite" signature: "Tensor (Tensor input) => IsFinite" bind_python: True - name: "depend" signature: [ "Tensor (Tensor input, Tensor depend) => Depend", "Tensor (Tensor input, 
TensorTuple depends) => DependTuple", ] bind_python: True - name: "roc_auc_score" signature: "Tensor (Tensor label, Tensor pred) => RocAucScore" bind_python: True - name: "pin_memory" signature: "Tensor (Tensor input) => PinMemory" bind_python: True - name: "fill_" signature: [ "Tensor (Tensor in, Tensor value) => FillTensor", "Tensor (Tensor in, Scalar value) => Fill", ] bind_python: True - name: "index_add" signature: "Tensor (Tensor input, Int64 dim, Tensor index, Tensor source, Scalar alpha=1.0) => IndexAdd" bind_python: True - name: "index_add_" signature: "Tensor (Tensor input, Int64 dim, Tensor index, Tensor source, Scalar alpha=1.0) => IndexAddInplace" bind_python: True - name: "rnn_tanh_cell" signature: "Tensor (Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih=None, Tensor b_hh=None) => RnnTanhCell" bind_python: True - name: "rnn_relu_cell" signature: "Tensor (Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih=None, Tensor b_hh=None) => RnnReluCell" bind_python: True - name: "lstm_cell" signature: "TensorTuple (Tensor input, TensorTuple hx, Tensor w_ih, Tensor w_hh, Tensor b_ih=None, Tensor b_hh=None) => LstmCell" bind_python: True - name: "gru_cell" signature: "Tensor (Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih=None, Tensor b_hh=None) => GruCell" bind_python: True - name: "_fused_gru_cell" signature: "TensorTuple (Tensor igates, Tensor hgates, Tensor hx, Tensor b_ih=None, Tensor b_hh=None) => FusedGruCell" bind_python: False - name: "_fused_gru_cell_grad" signature: "TensorTuple (Tensor grad_hy, Tensor workspace, Bool has_bias, Bool hx_needs_grad) => FusedGruCellGrad" bind_python: False - name: "_fused_lstm_cell" signature: "TensorTuple (Tensor igates, Tensor hgates, Tensor cx, Tensor b_ih=None, Tensor b_hh=None) => FusedLstmCell" bind_python: False - name: "_fused_lstm_cell_grad" signature: "TensorTuple (Tensor grad_hy, Tensor grad_cy, Tensor cx, Tensor cy, Tensor workspace, Bool need_cx_grad, Bool has_bias) => FusedLstmCellGrad" bind_python: False - name: "rnn_tanh" signature: [ "TensorTuple (Tensor input, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional, Bool batch_first) => RnnTanhInput", "TensorTuple (Tensor data, Tensor batch_sizes, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional) => RnnTanhData", ] bind_python: True - name: "rnn_relu" signature: [ "TensorTuple (Tensor input, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional, Bool batch_first) => RnnReluInput", "TensorTuple (Tensor data, Tensor batch_sizes, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional) => RnnReluData", ] bind_python: True - name: "lstm" signature: [ "TensorTuple (Tensor input, TensorTuple hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional, Bool batch_first) => LstmInput", "TensorTuple (Tensor data, Tensor batch_sizes, TensorTuple hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional) => LstmData", ] bind_python: True - name: "gru" signature: [ "TensorTuple (Tensor input, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional, Bool batch_first) => GruInput", "TensorTuple (Tensor data, Tensor batch_sizes, Tensor hx, TensorTuple params, Bool has_biases, 
Int32 num_layers, Float dropout, Bool train, Bool bidirectional) => GruData", ] bind_python: True - name: "pack_padded_sequence" signature: "TensorTuple (Tensor input, Tensor lengths, Bool batch_first) => PackPaddedSequence" bind_python: True - name: "multi_tensor_sgd_update" signature: "Void (TensorTuple model, TensorTuple model_diff, Double scale, Float weight_decay, Float learning_rate_val) => MultiTensorSgdUpdate" bind_python: True - name: "multi_tensor_yolov5_weight_update" signature: "Void (TensorTuple model, TensorTuple model_update, Float d) => MultiTensorYoloV5WeightUpdate" bind_python: True - name: "multi_tensor_momentum_update" signature: "Void (TensorTuple model, TensorTuple model_diff, TensorTuple momentum_buf, Double scale, Float weight_decay, Float learning_rate_val, Float momentum, Float dampening, Bool nesterov, Bool maximize) => MultiTensorMomentumUpdate" bind_python: True - name: "multi_tensor_adam_update" signature: "Void (TensorTuple model, TensorTuple model_diff, TensorTuple m, TensorTuple v, Float learning_rate_val, Float l2, Float beta1, Float beta2, Float bias_correction1_val, Float bias_correction2_val, Bool do_bias_correction, Double scale, Float weight_decay, Float epsilon) => MultiTensorAdamUpdate" bind_python: True - name: "grad_acc_repeat" signature: "Tensor (Tensor input, Int32 repeat_num) => GradAccRepeat" bind_python: False - name: "grad_acc_collect" signature: "Tensor (Tensor input, Int32 collect_num) => GradAccCollect" bind_python: False - name: "grad_acc_pack" signature: "Tensor (Tensor input, Int32 pack_num) => GradAccPack" bind_python: False - name: "grad_acc_unpack" signature: "Tensor (Tensor input, Int32 unpack_num) => GradAccUnpack" bind_python: False - name: "trunc" signature: "Tensor (Tensor input) => Trunc" bind_python: True - name: "silu_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => SiluGradGrad" bind_python: False - name: "mish_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => MishGradGrad" bind_python: False - name: "selu_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => SeluGradGrad" bind_python: False - name: "softsign_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => SoftSignGradGrad" bind_python: False - name: "gelu_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => GeluGradGrad" bind_python: False - name: "hardsigmoid_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => HardSigmoidGradGrad" bind_python: False - name: "hardswish_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => HardSwishGradGrad" bind_python: False - name: "softplus_grad_grad" signature: "Tensor (Tensor x, Tensor dydx, Double beta=1.0, Double threshold=20.0) => SoftplusGradGrad" bind_python: False - name: "elu_grad_grad" signature: "Tensor (Tensor x, Tensor dydx, Double alpha) => EluGradGrad" bind_python: False - name: "celu_grad_grad" signature: "Tensor (Tensor y, Tensor dydx, Double alpha) => CeluGradGrad" bind_python: False - name: "batch_norm_stats" signature: "TensorTuple (Tensor input, Int32 axis, Float eps) => BatchNormStats" bind_python: True - name: "batch_norm_gather_stats_with_counts" signature: "TensorTuple (Tensor input, Tensor mean, Tensor invstd, Tensor running_mean=None, Tensor running_var=None, Float momentum, Float eps, Tensor counts) => BatchNormGatherStatsWithCounts" bind_python: True - name: "batch_norm_elemt" signature: "Tensor (Tensor input, Tensor weight, Tensor bias, Tensor mean, Tensor invstd, Int32 axis, Float eps) => BatchNormElemt" bind_python: True - name: "batch_norm_backward_reduce" 
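  # NOTE (illustrative sketch): batch_norm_stats, batch_norm_gather_stats_with_counts
  # and batch_norm_elemt above, together with the batch_norm_backward_* entries
  # below, are the low-level building blocks a SyncBatchNorm-style module
  # composes: per-rank statistics, cross-rank reduction of those statistics,
  # and the element-wise normalize/backward steps. Assuming these bindings are
  # reached via oneflow._C like other bind_python: True entries, the forward
  # composition looks roughly like:
  #
  #   mean, invstd = flow._C.batch_norm_stats(x, 1, 1e-5)
  #   # ...all-gather per-rank mean/invstd/counts across ranks...
  #   mean, invstd = flow._C.batch_norm_gather_stats_with_counts(
  #       x, all_mean, all_invstd, running_mean, running_var, 0.1, 1e-5, counts)
  #   y = flow._C.batch_norm_elemt(x, weight, bias, mean, invstd, 1, 1e-5)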
signature: "TensorTuple (Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Int32 axis) => BatchNormBackwardReduce" bind_python: True - name: "batch_norm_backward_elemt" signature: "Tensor (Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, Int32 axis) => BatchNormBackwardElemt" bind_python: True - name: "adaptive_max_pool1d" signature: 'TensorTuple (Tensor input, Int64List[1] output_size, String data_format="channels_first") => AdaptiveMaxPool1D' bind_python: True - name: "adaptive_max_pool2d" signature: 'TensorTuple (Tensor input, Int64List[2] output_size, String data_format="channels_first") => AdaptiveMaxPool2D' bind_python: True - name: "adaptive_max_pool3d" signature: 'TensorTuple (Tensor input, Int64List[3] output_size, String data_format="channels_first") => AdaptiveMaxPool3D' bind_python: True - name: "adaptive_max_pool_grad" signature: 'Tensor (Tensor x, Tensor index, Tensor dy, Int32 ndims, String data_format="channels_first") => AdaptiveMaxPoolNdGrad' bind_python: False - name: "tan_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => TanGradGrad" bind_python: False - name: "sinh_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => SinhGradGrad" bind_python: False - name: "cosh_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => CoshGradGrad" bind_python: False - name: "tanh_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => TanhGradGrad" bind_python: False - name: "acos_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => AcosGradGrad" bind_python: False - name: "asin_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => AsinGradGrad" bind_python: False - name: "atan_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => AtanGradGrad" bind_python: False - name: "asinh_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => AsinhGradGrad" bind_python: False - name: "acosh_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => AcoshGradGrad" bind_python: False - name: "atanh_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => AtanhGradGrad" bind_python: False - name: "erf_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => ErfGradGrad" bind_python: False - name: "erfc_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => ErfcGradGrad" bind_python: False - name: "exp_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => ExpGradGrad" bind_python: False - name: "exp2_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => Exp2GradGrad" bind_python: False - name: "expm1_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => Expm1GradGrad" bind_python: False - name: "log_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => LogGradGrad" bind_python: False - name: "logsigmoid_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => LogSigmoidGradGrad" bind_python: False - name: "log2_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => Log2GradGrad" bind_python: False - name: "log10_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => Log10GradGrad" bind_python: False - name: "log1p_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => Log1pGradGrad" bind_python: False - name: "reciprocal_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => ReciprocalGradGrad" bind_python: False - name: "reciprocal_no_nan_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => ReciprocalNoNanGradGrad" bind_python: False - name: "rsqrt_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => RsqrtGradGrad" bind_python: False - name: "sqrt_grad_grad" 
signature: "Tensor (Tensor x, Tensor dydx) => SqrtGradGrad" bind_python: False - name: "square_grad_grad" signature: "Tensor (Tensor x, Tensor dydx) => SquareGradGrad" bind_python: False - name: "sigmoid_grad_grad" signature: "Tensor (Tensor y, Tensor dydx) => SigmoidGradGrad" bind_python: False - name: "max_pool_grad_grad" signature: "Tensor (Tensor dydx, Tensor indices, Int32 ndims) => MaxPoolNdGradGrad" bind_python: False - name: "exponential_" signature: "Tensor (Tensor x, Float lambd=1.0, Generator generator=None) => Exponential" bind_python: True - name: "multinomial" signature: "Tensor (Tensor x, Int32 num_samples, Bool replacement=False, Generator generator=None) => Multinomial" bind_python: True - name: "max_pool_grad_grad" signature: "Tensor (Tensor dydx, Tensor indices, Int32 ndims) => MaxPoolNdGradGrad" bind_python: False - name: "deform_conv2d" signature: "Tensor (Tensor input,Tensor weight,Tensor offset,Tensor mask,Tensor bias=None, Int32 stride_h,Int32 stride_w,Int32 pad_h, Int32 pad_w,Int32 dilation_h,Int32 dilation_w,Int32 groups,Int32 offset_groups,Bool use_mask) => DeformConv2d" bind_python: True - name: "deform_conv2d_input_grad" signature: "TensorTuple (Tensor output_grad,Tensor input,Tensor weight,Tensor offset,Tensor mask=None, Int32 stride_h,Int32 stride_w,Int32 pad_h, Int32 pad_w,Int32 dilation_h,Int32 dilation_w,Int32 groups,Int32 offset_groups,Bool use_mask) => DeformConv2dInputGrad" bind_python: False - name: "deform_conv2d_param_grad" signature: "Tensor (Tensor output_grad,Tensor input,Tensor weight,Tensor offset,Tensor mask, Int32 stride_h,Int32 stride_w,Int32 pad_h, Int32 pad_w,Int32 dilation_h,Int32 dilation_w,Int32 groups,Int32 offset_groups,Bool use_mask) => DeformConv2dParamGrad" bind_python: False - name: "broadcast_shapes" signature: "Shape (ShapeList shapes) => BroadcastShapes" bind_python: True - name: "broadcast_tensors" signature: "TensorTuple (TensorTuple tensors) => BroadcastTensors" bind_python: True - name: "broadcast_to" signature: "Tensor (Tensor x, Shape shape) => BroadcastTo" bind_python: True - name: "bincount" signature: "Tensor (Tensor input, Tensor weights=None, Int64 minlength=None) => BinCount" bind_python: True - name: "stft" signature: 'Tensor (Tensor input, Int64 n_fft,Int64 hop_length=None, Int64 win_length=None, Tensor window=None,Bool center=True,String pad_mode="reflect",Bool normalized=False,Bool onesided=True,Bool return_complex=False) =>Stft' bind_python: True - name: "fft_c2c" signature: 'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool forward=True, Bool normalized=False) => FftC2C' bind_python: False - name: "fft_r2c" signature: 'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool onesided=False, Bool forward=True, Bool normalized=False) => FftR2C' bind_python: False - name: "fft_c2r" signature: 'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool forward=True, Bool normalized=False) =>FftC2R' bind_python: False - name: "fft" signature: 'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => Fft' bind_python: True - name: "ifft" signature: 'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IFft' bind_python: True - name: "fft2" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => Fft2' bind_python: True - name: "ifft2" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IFft2' bind_python: True - name: "fftn" signature: 
'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => FftN' bind_python: True - name: "ifftn" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IFftN' bind_python: True - name: "rfft" signature: 'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => RFft' bind_python: True - name: "irfft" signature: 'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IRFft' bind_python: True - name: "rfft2" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => RFft2' bind_python: True - name: "irfft2" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IRFft2' bind_python: True - name: "rfftn" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => RFftN' bind_python: True - name: "irfftn" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IRFftN' bind_python: True - name: "hfft" signature: 'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => HFft' bind_python: True - name: "ihfft" signature: 'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IHFft' bind_python: True - name: "hfft2" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => HFft2' bind_python: True - name: "ihfft2" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IHFft2' bind_python: True - name: "hfftn" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => HFftN' bind_python: True - name: "ihfftn" signature: 'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IHFftN' bind_python: True - name: "isclose" signature: "Tensor (Tensor input, Tensor other, Float atol=1e-08, Float rtol=1e-05, Bool equal_nan=False) => IsClose" bind_python: True - name: "uniform_" signature: "Tensor (Tensor x,Scalar from, Scalar to) => InplaceUniform" bind_python: True - name: "fused_fast_gelu_mul" signature: "Tensor (Tensor x, Tensor multiplier) => FusedFastGeluMul" bind_python: True - name: "fused_fast_gelu_mul_grad" signature: "TensorTuple (Tensor dy, Tensor x, Tensor multiplier) => FusedFastGeluMulGrad" bind_python: False - name: "unique" signature: [ "Tensor (Tensor x, Bool sorted=True, DataType dtype=kInt32) => Unique", "TensorTuple (Tensor x, Bool sorted=True, Bool return_inverse=False, Bool return_counts=False, DataType dtype=kInt32) => UniqueWithCounts" ] bind_python: True - name: "fused_weighted_sum" signature: "Tensor (TensorTuple in, FloatList weights, Float alpha=1.0) => FusedWeightedSum" bind_python: True - name: "sort" signature: "TensorTuple[values, indices] (Tensor input, Int32 dim=-1, Bool descending=False) => Sort" bind_python: True - name: "throw_error" signature: "Tensor (Tensor input) => ThrowError" bind_python: True - name: "mode" signature: "TensorTuple[values, indices] (Tensor input, Int32 dim=-1, Bool keepdim=False) => Mode" bind_python: True - name: "clone" signature: "Tensor (Tensor input) => Clone" bind_python: True - name: "real" signature: "Tensor (Tensor x) => Real" bind_python: True - name: "real_grad" signature: "Tensor (Tensor dout) => RealGrad" bind_python: False - name: "imag" signature: "Tensor (Tensor x) => Imag" bind_python: True - name: "imag_grad" signature: "Tensor (Tensor dout) => ImagGrad" bind_python: False - name: "conj" signature: "Tensor (Tensor x) => Conj" bind_python: True - name: "conj_physical" 
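  # NOTE: fft_c2c / fft_r2c / fft_c2r above are the low-level complex/real FFT
  # kernels (bind_python: False), while fft, ifft, rfft, irfft, hfft, ihfft and
  # their 2-D/N-D variants are the Python-facing entries (bind_python: True).
  # Together with the real / imag / conj / conj_physical entries here, they
  # form the complex-tensor surface of this file.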
signature: "Tensor (Tensor x) => ConjPhysical" bind_python: True - name: "frac" signature: "Tensor (Tensor x) => Frac" bind_python: True - name: "frac_" signature: "Tensor (Tensor x) => FracInplace" bind_python: True - name: "digamma" signature: "Tensor (Tensor x) => Digamma" bind_python: True - name: "digamma_grad" signature: "Tensor (Tensor x, Tensor dy) => DigammaGrad" bind_python: False - name: "trigamma" signature: "Tensor (Tensor x) => Trigamma" bind_python: False - name: "zeta" signature: [ "Tensor (Tensor x, Tensor other) => BroadcastZeta", "Tensor (Scalar x, Tensor other) => ZetaScalarTensor", "Tensor (Tensor x, Scalar other) => ZetaTensorScalar", ] bind_python: True - name: "fused_clip_grad" signature: "Tensor (TensorTuple model_diff, Float max_norm, Float norm_type) => FusedClipGrad" bind_python: True ================================================ FILE: oneflow/core/functional/impl/activation_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/error.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/functional/impl/binary_functor.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace one { namespace functional { namespace impl { class ReluFunctor { public: ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); } Maybe operator()(const std::shared_ptr& x, bool inplace) const { if (inplace) { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), AttrMap{})); return outputs->at(0); } else { return OpInterpUtil::Dispatch(*op_, {x}); } } private: std::shared_ptr op_; }; class ReluGradFunctor : public BinaryFunctor { public: ReluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu_grad").Input("dy").Input("y").Output("dx").Build()); } }; class PReluFunctor { public: PReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("prelu").Input("x").Input("alpha").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& alpha) 
const { int num_params = alpha->dim(0); CHECK_OR_RETURN(((num_params == 1) || (num_params == x->shape()->At(1)))) << Error::RuntimeError() << "num_parameters in prelu must be 1 or " << x->shape()->At(1); return OpInterpUtil::Dispatch(*op_, {x, alpha}); } private: std::shared_ptr op_; }; class PReluGradFunctor { public: PReluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("prelu_grad") .Input("dy") .Input("x") .Input("alpha") .Output("dx") .Output("alpha_diff") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha_requires_grad"); attrs.SetAllAttrs(alpha->requires_grad()); return OpInterpUtil::Dispatch(*op_, {dy, x, alpha}, attrs); } private: std::shared_ptr op_; }; class HardTanhFunctor { public: HardTanhFunctor() { op_ = CHECK_JUST(one::OpBuilder("hardtanh").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const double& min_val, const double& max_val) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("min_val", "max_val"); attrs.SetAllAttrs(min_val, max_val); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class HardTanhGradFunctor { public: HardTanhGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("hardtanh_grad").Input("y").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& dy, const double& min_val, const double& max_val) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("min_val", "max_val"); attrs.SetAllAttrs(min_val, max_val); return OpInterpUtil::Dispatch(*op_, {y, dy}, attrs); } private: std::shared_ptr op_; }; class EluFunctor { public: EluFunctor() { op_ = CHECK_JUST(one::OpBuilder("elu").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const double& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class EluGradFunctor { public: EluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("elu_grad").Input("x").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const double& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); return OpInterpUtil::Dispatch(*op_, {x, dy}, attrs); } private: std::shared_ptr op_; }; class CeluFunctor { public: CeluFunctor() { op_ = CHECK_JUST(one::OpBuilder("celu").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const double& alpha, bool inplace) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); if (inplace) { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); (*outputs)[0] = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs)); return outputs->at(0); } else { return OpInterpUtil::Dispatch(*op_, {x}, attrs); } } private: std::shared_ptr op_; }; class CeluGradFunctor { public: CeluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("celu_grad").Input("y").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& dy, const double& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); return OpInterpUtil::Dispatch(*op_, {y, dy}, attrs); } private: std::shared_ptr op_; }; class GeluFunctor : public UnaryFunctor { public: GeluFunctor() { op_ = 
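    // NOTE: functors deriving from UnaryFunctor / BinaryFunctor (see
    // impl/unary_functor.h and impl/binary_functor.h included at the top of
    // this file) only build op_ in their constructor; the base class supplies
    // the operator() that dispatches the single input, or the two inputs for
    // the *Grad variants, through OpInterpUtil. GeluFunctor here follows that
    // pattern.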
CHECK_JUST(one::OpBuilder("gelu").Input("in").Output("out").Build()); } }; class GeluGradFunctor : public BinaryFunctor { public: GeluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("gelu_grad").Input("dy").Input("x").Output("dx").Build()); } }; class FastGeluFunctor : public UnaryFunctor { public: FastGeluFunctor() { op_ = CHECK_JUST(one::OpBuilder("fast_gelu").Input("in").Output("out").Build()); } }; class FastGeluGradFunctor : public BinaryFunctor { public: FastGeluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fast_gelu_grad").Input("dy").Input("x").Output("dx").Build()); } }; class QuickGeluFunctor : public UnaryFunctor { public: QuickGeluFunctor() { op_ = CHECK_JUST(one::OpBuilder("quick_gelu").Input("x").Output("y").Build()); } }; class QuickGeluGradFunctor : public BinaryFunctor { public: QuickGeluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("quick_gelu_grad").Input("dy").Input("x").Output("dx").Build()); } }; class SquareReLUFunctor : public UnaryFunctor { public: SquareReLUFunctor() { op_ = CHECK_JUST(one::OpBuilder("square_relu").Input("x").Output("y").Build()); } }; class SquareReLUGradFunctor : public BinaryFunctor { public: SquareReLUGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("square_relu_grad").Input("dy").Input("x").Output("dx").Build()); } }; class GluFunctor { public: GluFunctor() {} Maybe operator()(const std::shared_ptr& input, int64_t dim) const { const auto ndim = input->ndim(); CHECK_GT_OR_RETURN(ndim, 0) << Error::RuntimeError() << "glu does not support scalars because halving size must be even"; dim = JUST(maybe_wrap_dim(dim, ndim)); if (dim < 0) { dim += ndim; } int64_t nc = input->dim(dim); CHECK_EQ_OR_RETURN(nc % 2, 0) << Error::RuntimeError() << "Halving dimension must be even, but dimension " << dim << " is size " << nc; nc = nc / 2; std::vector split_sizes(2, nc); const auto split_x = JUST(SplitWithSize(input, split_sizes, dim)); return sequence_function(functional::Sigmoid) .then(std::bind(functional::Mul, (*split_x)[0], std::placeholders::_1)) .call((*split_x)[1]); } }; class HardSigmoidFunctor { public: HardSigmoidFunctor() { op_ = CHECK_JUST(one::OpBuilder("hardsigmoid").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, bool inplace) const { if (inplace) { JUST(CheckInplaceValid(input)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = input; JUST(OpInterpUtil::Dispatch(*op_, {input}, outputs.get(), AttrMap{})); return outputs->at(0); } else { return OpInterpUtil::Dispatch(*op_, {input}); } } private: std::shared_ptr op_; }; class HardSigmoidGradFunctor : public BinaryFunctor { public: HardSigmoidGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("hardsigmoid_grad").Input("dy").Input("x").Output("dx").Build()); } }; class HardShrinkFunctor { public: HardShrinkFunctor() { op_ = CHECK_JUST(one::OpBuilder("hardshrink").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const double& lambd, bool inplace) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("lambd"); attrs.SetAllAttrs(lambd); if (inplace) { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); JUST(oneflow::VectorAt(*outputs, 0)) = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs)); return JUST(oneflow::VectorAt(*outputs, 0)); } else { return OpInterpUtil::Dispatch(*op_, {x}, attrs); } } private: std::shared_ptr op_; }; class HardShrinkGradFunctor { public: HardShrinkGradFunctor() { op_ = 
CHECK_JUST(one::OpBuilder("hardshrink_grad").Input("y").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& dy, const double& lambd) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("lambd"); attrs.SetAllAttrs(lambd); return OpInterpUtil::Dispatch(*op_, {y, dy}, attrs); } private: std::shared_ptr op_; }; class SoftmaxFunctorBase { public: Maybe operator()(const std::shared_ptr& input, const Optional& dim) const { const auto input_shape = input->shape(); const int64_t num_axes = input_shape->NumAxes(); const auto get_dim = [num_axes]() -> int64_t { const int64_t ndim = num_axes; if (ndim == 0 || ndim == 1 || ndim == 3) { return 0; } else { return 1; } }; int64_t dim_ = dim ? JUST(dim) : get_dim(); dim_ = JUST(maybe_wrap_dim(dim_, num_axes)); if (dim_ != num_axes - 1) { std::vector input_perm(input_shape->dim_vec().size(), 0); for (size_t i = 1; i < input_perm.size(); ++i) { input_perm[i] = i; } input_perm[dim_] = input_perm[input_perm.size() - 1]; input_perm[input_perm.size() - 1] = dim_; return sequence_function(functional::Transpose) .then([&](const std::shared_ptr& x) { return OpInterpUtil::Dispatch(*op_, {x}); }) .then(std::bind(functional::Transpose, std::placeholders::_1, input_perm)) .call(input, input_perm); } return OpInterpUtil::Dispatch(*op_, {input}); } protected: SoftmaxFunctorBase() = default; virtual ~SoftmaxFunctorBase() = default; std::shared_ptr op_; }; class SoftmaxFunctor : public SoftmaxFunctorBase { public: SoftmaxFunctor() { op_ = CHECK_JUST(one::OpBuilder("softmax").Input("in").Output("out").Build()); } }; class SoftmaxGradFunctor { public: SoftmaxGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("softmax_grad").Input("y").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& y) const { return OpInterpUtil::Dispatch(*op_, {y, dy}); } private: std::shared_ptr op_; }; class LogSoftmaxFunctor : public SoftmaxFunctorBase { public: LogSoftmaxFunctor() { op_ = CHECK_JUST(one::OpBuilder("log_softmax").Input("in").Output("prob").Build()); } }; class LogSoftmaxGradFunctor { public: LogSoftmaxGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("log_softmax_grad").Input("prob").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& y) const { return OpInterpUtil::Dispatch(*op_, {y, dy}); } private: std::shared_ptr op_; }; class GumbelSoftmaxFunctor { public: Maybe operator()(const std::shared_ptr& in, const double& tau, const Optional& dim, bool hard, const Optional& generator) const { auto in_shape = in->shape(); auto device = JUST(in->device()); auto dtype = in->dtype(); const int64_t num_axes = in_shape->NumAxes(); const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); auto random_tensor = JUST(functional::Rand(*in_shape.get(), dtype, device, gen, /*requires_grad=*/false)); auto gumbel_noise_tensor = JUST(functional::ScalarSub( Scalar(0.0), JUST(functional::Log(JUST(functional::ScalarSub( Scalar(0.0), JUST(functional::Log(random_tensor)), /*alpha=*/1.0)))), /*alpha=*/1.0)); auto gumbel_in_tensor = JUST(functional::ScalarDiv( JUST(functional::Add(in, gumbel_noise_tensor, /*alpha=*/1.0, /*inplace=*/false)), Scalar(tau))); auto out_soft = JUST(functional::Softmax(gumbel_in_tensor, dim)); if (hard) { const auto get_dim = [num_axes]() -> int64_t { const int64_t ndim = num_axes; if (ndim == 0 || ndim == 1 || ndim == 3) { return 0; } else { return 1; } }; int64_t dim_ = dim ? 
JUST(dim) : get_dim(); dim_ = JUST(maybe_wrap_dim(dim_, num_axes)); auto out_max = JUST(functional::ArgMax(out_soft, dim_, /*keepdim=*/true, dtype)); auto index = JUST(functional::To(out_max, JUST(DType::Get(DataType::kInt64)), /*copy=*/false)); auto zero = JUST(functional::ZerosLike(out_soft)); auto out_hard = JUST(functional::DimScatterUpdateScalar(zero, dim_, index, 1.0, /*inplace=*/false)); auto out_hard_has_grad = functional::Add(JUST(functional::Sub(out_hard, JUST(out_soft->detach()), /*alpha=*/1.0, /*inplace=*/false)), out_soft, /*alpha=*/1.0, /*inplace=*/false); return out_hard_has_grad; } else { return out_soft; } } }; class HardSwishFunctor : public UnaryFunctor { public: HardSwishFunctor() { op_ = CHECK_JUST(one::OpBuilder("hardswish").Input("in").Output("out").Build()); } }; class HardSwishGradFunctor : public BinaryFunctor { public: HardSwishGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("hardswish_grad").Input("dy").Input("x").Output("dx").Build()); } }; class LeakyReluFunctor { public: LeakyReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("leaky_relu").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const float& alpha, bool inplace) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); if (inplace) { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); JUST(oneflow::VectorAt(*outputs, 0)) = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs)); return JUST(oneflow::VectorAt(*outputs, 0)); } else { return OpInterpUtil::Dispatch(*op_, {x}, attrs); } } private: std::shared_ptr op_; }; class LeakyReluGradFunctor { public: LeakyReluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("leaky_relu_grad").Input("x").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const float& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); return OpInterpUtil::Dispatch(*op_, {x, dy}, attrs); } private: std::shared_ptr op_; }; class RReluFunctor { public: RReluFunctor() { op_ = CHECK_JUST( one::OpBuilder("rrelu").Input("in").Output("output").Output("noise_data").Build()); } Maybe operator()(const std::shared_ptr& x, const float& lower, const float& upper, bool training, bool inplace) const { if (!training) { return JUST(functional::LeakyRelu(x, ((lower + upper) / 2), inplace)); } auto gen = JUST( GetGeneratorForLazyOrGlobal(JUST(one::DefaultAutoGenerator()), LazyMode::is_enabled(), x)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("seed", "lower", "upper", "training"); attrs.SetAllAttrs(static_cast(gen->current_seed()), lower, upper, training); const auto& state = std::make_shared(gen); OpExprInterpContext ctx(attrs, state); std::shared_ptr outputs = std::make_shared(2); if (inplace) { JUST(CheckInplaceValid(x)); outputs->at(0) = x; } JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), ctx)); return outputs->at(0); } private: std::shared_ptr op_; }; class RReluInplaceFunctor { public: Maybe operator()(const std::shared_ptr& x, const float& lower, const float& upper, bool training) const { return JUST(functional::RRelu(x, lower, upper, training, true /*inplace*/)); } }; class SoftplusFunctor { public: SoftplusFunctor() { op_ = CHECK_JUST(one::OpBuilder("softplus").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const double& beta, const double& threshold) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("beta", "threshold"); attrs.SetAllAttrs(beta, threshold); return 
OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class SoftplusGradFunctor { public: SoftplusGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("softplus_grad").Input("x").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const double& beta, const double& threshold) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("beta", "threshold"); attrs.SetAllAttrs(beta, threshold); return OpInterpUtil::Dispatch(*op_, {x, dy}, attrs); } private: std::shared_ptr op_; }; class SiluFunctor : public UnaryFunctor { public: SiluFunctor() { op_ = CHECK_JUST(one::OpBuilder("silu").Input("in").Output("out").Build()); } }; class SiluGradFunctor : public BinaryFunctor { public: SiluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("silu_grad").Input("dy").Input("x").Output("dx").Build()); } }; class MishFunctor : public UnaryFunctor { public: MishFunctor() { op_ = CHECK_JUST(one::OpBuilder("mish").Input("in").Output("out").Build()); } }; class MishGradFunctor : public BinaryFunctor { public: MishGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("mish_grad").Input("dy").Input("x").Output("dx").Build()); } }; class SeluFunctor : public UnaryFunctor { public: SeluFunctor() { op_ = CHECK_JUST(one::OpBuilder("selu").Input("in").Output("out").Build()); } }; class SeluGradFunctor : public BinaryFunctor { public: SeluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("selu_grad").Input("dy").Input("x").Output("dx").Build()); } }; class SoftSignFunctor : public UnaryFunctor { public: SoftSignFunctor() { op_ = CHECK_JUST(one::OpBuilder("softsign").Input("in").Output("out").Build()); } }; class SoftSignGradFunctor : public BinaryFunctor { public: SoftSignGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("softsign_grad").Input("dy").Input("x").Output("dx").Build()); } }; class SoftShrinkFunctor { public: SoftShrinkFunctor() { op_ = CHECK_JUST(one::OpBuilder("softshrink").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const double& alpha, bool inplace) const { CHECK_GE_OR_RETURN(alpha, 0) << Error::RuntimeError() << "alpha must be greater or equal to 0, but found to be " << alpha << "."; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); if (inplace) { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); JUST(oneflow::VectorAt(*outputs, 0)) = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs)); return JUST(oneflow::VectorAt(*outputs, 0)); } else { return OpInterpUtil::Dispatch(*op_, {x}, attrs); } } private: std::shared_ptr op_; }; class ThresholdFunctor { public: ThresholdFunctor() { op_ = CHECK_JUST(one::OpBuilder("threshold").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const double& threshold, const double& value) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("threshold_val", "value"); attrs.SetAllAttrs(threshold, value); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class ThresholdGradFunctor { public: ThresholdGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("threshold_grad").Input("x").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const double& threshold) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("threshold_val"); attrs.SetAllAttrs(threshold); return OpInterpUtil::Dispatch(*op_, {x, dy}, attrs); } private: std::shared_ptr op_; }; class SoftShrinkGradFunctor { public: SoftShrinkGradFunctor() { op_ 
= CHECK_JUST(one::OpBuilder("softshrink_grad").Input("y").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& dy, const double& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); return OpInterpUtil::Dispatch(*op_, {y, dy}, attrs); } private: std::shared_ptr op_; }; class FracFunctor { public: FracFunctor() { op_ = CHECK_JUST(one::OpBuilder("frac").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x) const { return OpInterpUtil::Dispatch(*op_, {x}); } private: std::shared_ptr op_; }; class FracInplaceFunctor { public: FracInplaceFunctor() { op_ = CHECK_JUST(one::OpBuilder("frac").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x) const { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), AttrMap{})); return outputs->at(0); } private: std::shared_ptr op_; }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Frac"); m.add_functor("FracInplace"); m.add_functor("Relu"); m.add_functor("ReluGrad"); m.add_functor("PRelu"); m.add_functor("PReluGrad"); m.add_functor("HardTanh"); m.add_functor("HardTanhGrad"); m.add_functor("Elu"); m.add_functor("EluGrad"); m.add_functor("Celu"); m.add_functor("CeluGrad"); m.add_functor("Gelu"); m.add_functor("GeluGrad"); m.add_functor("FastGelu"); m.add_functor("FastGeluGrad"); m.add_functor("QuickGelu"); m.add_functor("QuickGeluGrad"); m.add_functor("SquareReLU"); m.add_functor("SquareReLUGrad"); m.add_functor("Glu"); m.add_functor("HardSigmoid"); m.add_functor("HardSigmoidGrad"); m.add_functor("HardShrink"); m.add_functor("HardShrinkGrad"); m.add_functor("Softmax"); m.add_functor("SoftmaxGrad"); m.add_functor("LogSoftmax"); m.add_functor("LogSoftmaxGrad"); m.add_functor("GumbelSoftmax"); m.add_functor("HardSwish"); m.add_functor("HardSwishGrad"); m.add_functor("LeakyRelu"); m.add_functor("LeakyReluGrad"); m.add_functor("RRelu"); m.add_functor("RReluInplace"); m.add_functor("Softplus"); m.add_functor("SoftplusGrad"); m.add_functor("Silu"); m.add_functor("SiluGrad"); m.add_functor("Mish"); m.add_functor("MishGrad"); m.add_functor("Selu"); m.add_functor("SeluGrad"); m.add_functor("SoftSign"); m.add_functor("SoftSignGrad"); m.add_functor("Threshold"); m.add_functor("ThresholdGrad"); m.add_functor("SoftShrink"); m.add_functor("SoftShrinkGrad"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/array_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/placement_utils.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/job/global_mode.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/eager/tensor_storage.h" #include namespace oneflow { namespace one { namespace functional { namespace impl { class ArgMaxFunctor { public: ArgMaxFunctor() { op_ = CHECK_JUST(one::OpBuilder("argmax").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const Optional& dim, const Optional& keepdim, const Optional>& dtype) const { if (dim.has_value() == false) { return SequenceFunction()>([&]() { return Flatten(input, 0, -1); }) .then([&](const std::shared_ptr& x) { return OpInterpUtil::Dispatch(*op_, {x}); }) .call(); } int new_dim = JUST(dim); const int32_t ndims = input->shape()->NumAxes(); new_dim = JUST(maybe_wrap_dim(new_dim, ndims)); if (new_dim < 0) { new_dim += ndims; } const auto do_cast = [&](const std::shared_ptr& x) -> Maybe { return Cast(x, JUST(dtype), /*pin_memory=*/false); }; if (new_dim == ndims - 1) { return SequenceFunction()>( [&]() { return OpInterpUtil::Dispatch(*op_, {input}); }) .then_if(keepdim.has_value() && JUST(keepdim) == true, std::bind(ExpandDims, std::placeholders::_1, -1)) .then_if(dtype.has_value(), do_cast) .call(); } std::vector permute; permute.reserve(ndims); for (int32_t i = 0; i < ndims - 1; i++) { permute.emplace_back(i < new_dim ? 
i : i + 1); } permute.emplace_back(new_dim); std::vector permute_inv(ndims, 0); for (int32_t i = 0; i < ndims; i++) { permute_inv[i] = -1; } for (int32_t i = 0; i < ndims; i++) { permute_inv[permute[i]] = i; } std::vector squeeze_dim = {new_dim}; return SequenceFunction()>([&]() { return Transpose(input, permute); }) .then([&](const std::shared_ptr& x) { return OpInterpUtil::Dispatch(*op_, {x}); }) .then(std::bind(ExpandDims, std::placeholders::_1, -1)) .then(std::bind(Transpose, std::placeholders::_1, permute_inv)) .then_if((!keepdim.has_value()) || (keepdim.has_value() && JUST(keepdim) == false), std::bind(Squeeze, std::placeholders::_1, squeeze_dim)) .then_if(dtype.has_value(), do_cast) .call(); } private: std::shared_ptr op_; }; class ArgMinFunctor { public: ArgMinFunctor() {} Maybe operator()(const std::shared_ptr& input, const Optional& dim, const Optional& keepdim, const Optional>& dtype) const { TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({input}, DType::Float()).Apply()); const auto x = JUST(tensor_processor.GetInputs()).at(0); return sequence_function(Negative) .then(std::bind(ArgMax, std::placeholders::_1, dim, keepdim, dtype)) .call(x); } }; class GlobalTensorConstantFunctor { public: GlobalTensorConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder("tensor_constant").Input("in").Output("out").Build()); } Maybe operator()(const Shape& shape, const std::shared_ptr& value, const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { CHECK_OR_RETURN(value->ndim() <= 1 && value->nelement() == 1) << "Only tensor with single element or scalar tensor are supported as value!"; CHECK_OR_RETURN(value->is_global()) << "The value tensor should be global tensor"; // NOTE: this op is an source op, so the value(scalar tensor) should not have autograd status. autograd::AutoGradMode mode(false); JUST(CheckDeviceIdsIsValid(placement)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype", "nd_sbp"); attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt); auto dispatch_constant = [&](const std::vector>& sbp_tuple) -> Maybe { std::vector nd_sbp(sbp_tuple.size()); { for (int i = 0; i < sbp_tuple.size(); ++i) { nd_sbp[i] = SbpParallelToString(*sbp_tuple[i]); } } attrs.SetAttr<2>(nd_sbp); return OpInterpUtil::Dispatch(*op_, {value}, attrs); }; bool has_partial_parallel = std::any_of(sbp_tuple.begin(), sbp_tuple.end(), [](const Symbol& sbp) { return sbp->has_partial_sum_parallel(); }); // The source op does not support Partial if (has_partial_parallel) { const auto& fixed_sbp_tuple = JUST(NdSbpReplacePartialByBroadcast(sbp_tuple)); const auto& tensor = JUST(dispatch_constant(*fixed_sbp_tuple)); return functional::ToGlobal(tensor, placement, sbp_tuple, {}, /* check_meta */ false, /*copy*/ false); } else { return dispatch_constant(sbp_tuple); } } private: std::shared_ptr op_; }; class TensorConstantFunctor { public: TensorConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder("tensor_constant").Input("in").Output("out").Build()); } Maybe operator()(const Shape& shape, const std::shared_ptr& value, const Symbol& dtype, const Optional>& device) const { CHECK_OR_RETURN(value->ndim() <= 1 && value->nelement() == 1) << "Only tensor with single element or scalar tensor are supported as value!"; // NOTE: this op is an source op, so the value(scalar tensor) should not have autograd status. 
autograd::AutoGradMode mode(false); if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST(functional::GlobalTensorConstant(shape, value, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())))); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype"); attrs.SetAllAttrs(shape, dtype->data_type()); if (device.has_value()) { Symbol device_symbol = JUST(device); return OpInterpUtil::Dispatch(*op_, {value}, OpExprInterpContext(attrs, device_symbol)); } else { return OpInterpUtil::Dispatch(*op_, {value}, attrs); } } private: std::shared_ptr op_; }; class GlobalConstantFunctor { public: GlobalConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder("constant").Output("out").Build()); } Maybe operator()(const Shape& shape, const Scalar& value, const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { JUST(CheckDeviceIdsIsValid(placement)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype", "complex_value", "is_complex_value", "floating_value", "is_floating_value", "integer_value", "nd_sbp"); if (IsComplexDataType(dtype->data_type())) { attrs.SetAllAttrs(shape, dtype->data_type(), value.Value>(), true, NullOpt, false, NullOpt, NullOpt); } else if (IsIntegralDataType(dtype->data_type())) { attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt, false, NullOpt, false, value.As(), NullOpt); } else { attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt, false, value.As(), true, NullOpt, NullOpt); } auto dispatch_constant = [&](const std::vector>& sbp_tuple) -> Maybe { if (LazyMode::is_enabled()) { std::vector nd_sbp(sbp_tuple.size()); { for (int i = 0; i < sbp_tuple.size(); ++i) { nd_sbp[i] = SbpParallelToString(*sbp_tuple[i]); } } attrs.SetAttr<7>(nd_sbp); } const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); return OpInterpUtil::Dispatch(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp)); }; bool has_partial_parallel = [&]() { for (const auto& sbp : sbp_tuple) { if (sbp->has_partial_sum_parallel()) { return true; } } return false; }(); // Since the source op does not support Partial, it is necessary to replace Partial // with Broadcast, and then convert it to Partial if (has_partial_parallel) { const auto& fixed_sbp_tuple = JUST(NdSbpReplacePartialByBroadcast(sbp_tuple)); const auto& tensor = JUST(dispatch_constant(*fixed_sbp_tuple)); return functional::ToGlobal(tensor, placement, sbp_tuple, {}, /* check_meta */ false, /*copy*/ false); } else { return dispatch_constant(sbp_tuple); } } private: std::shared_ptr op_; }; class ConstantFunctor { public: ConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder("constant").Output("out").Build()); } Maybe operator()(const Shape& shape, const Scalar& value, const Symbol& dtype, const Optional>& device) const { if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST(functional::GlobalConstant(shape, value, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())))); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype", "complex_value", "is_complex_value", "floating_value", "is_floating_value", "integer_value"); if (IsComplexDataType(dtype->data_type())) { attrs.SetAllAttrs(shape, dtype->data_type(), value.Value>(), true, NullOpt, false, NullOpt); } else if (IsIntegralDataType(dtype->data_type())) { attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt, false, NullOpt, false, value.As()); } else { attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt, false, value.As(), 
true, NullOpt); } if (device.has_value()) { Symbol device_symbol = JUST(device); return OpInterpUtil::Dispatch(*op_, {}, OpExprInterpContext(attrs, device_symbol)); } else { return OpInterpUtil::Dispatch(*op_, {}, attrs); } } private: std::shared_ptr op_; }; class EmptyFunctor { public: EmptyFunctor() { op_ = CHECK_JUST(one::OpBuilder("empty").Output("out").Build()); } Maybe operator()(const Shape& shape, const Symbol& dtype, const Optional>& device, const bool requires_grad, const bool pin_memory) const { std::shared_ptr empty; if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); empty = JUST(functional::GlobalEmpty(shape, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())))); if (dtype->is_floating_point()) { JUST(empty->set_requires_grad(requires_grad)); } return empty; } Symbol device_symbol = device.value_or(JUST(Device::New("cpu"))); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype", "pin_memory", "device_type", "device_id"); attrs.SetAllAttrs(shape, dtype->data_type(), pin_memory, device_symbol->type(), device_symbol->device_id()); if (device.has_value()) { Symbol device_symbol = JUST(device); empty = JUST(OpInterpUtil::Dispatch(*op_, {}, OpExprInterpContext(attrs, device_symbol))); } else { empty = JUST(OpInterpUtil::Dispatch(*op_, {}, attrs)); } if (dtype->is_floating_point()) { JUST(empty->set_requires_grad(requires_grad)); } return empty; } private: std::shared_ptr op_; }; class EmptyStridedFunctor { public: Maybe operator()(const std::vector& shape, const std::vector& stride, const Optional>& dtype, const Optional>& device, const bool requires_grad, const bool pin_memory) const { Symbol data_type = GetDefaultDType(); if (dtype.has_value()) { data_type = JUST(dtype); } auto empty = JUST(functional::Empty(Shape(shape), dtype.value_or(GetDefaultDType()), device, requires_grad, pin_memory)); CHECK_OR_RETURN(view::IsViewApplicable(empty)) << "oneflow.empty_strided() only support in eager local mode!"; return view::AsStrided(empty, shape, stride, 1); } }; class GlobalEmptyFunctor { public: GlobalEmptyFunctor() { op_ = CHECK_JUST(one::OpBuilder("empty").Output("out").Build()); } Maybe operator()(const Shape& shape, const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { JUST(CheckDeviceIdsIsValid(placement)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype", "nd_sbp"); if (LazyMode::is_enabled()) { std::vector nd_sbp(sbp_tuple.size()); { for (int i = 0; i < sbp_tuple.size(); ++i) { nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i)); } } attrs.SetAllAttrs(shape, dtype->data_type(), nd_sbp); } else { attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt); } const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); return OpInterpUtil::Dispatch(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp)); } private: std::shared_ptr op_; }; class ZerosLikeFunctor : public UnaryFunctor { public: ZerosLikeFunctor() { op_ = CHECK_JUST(one::OpBuilder("zero_like").Input("like").Output("out").Build()); } }; class OnesLikeFunctor : public UnaryFunctor { public: OnesLikeFunctor() { op_ = CHECK_JUST(one::OpBuilder("ones_like").Input("like").Output("out").Build()); } }; class FullLikeFunctor { public: FullLikeFunctor() {} Maybe operator()(const std::shared_ptr& x, const Scalar& fill_value) const { std::shared_ptr out; if (x->is_local()) { out = JUST(functional::Empty(*(x->shape()), x->dtype(), JUST(x->device()), /*requires_grad=*/false, /*pin_memory=*/false)); } else { out = 
JUST(functional::GlobalEmpty(*(x->shape()), x->dtype(), JUST(x->parallel_desc()), *JUST(private_details::RawGetSbpList(JUST(x->nd_sbp()))))); } out = JUST(functional::Fill(out, fill_value)); return out; } }; class FlattenFunctor { public: FlattenFunctor() = default; Maybe operator()(const std::shared_ptr& x, const int32_t& start_dim, const int32_t& end_dim) const { const Shape& in_shape = *x->shape(); int32_t ndim = in_shape.size(); auto CheckAndWrapDim = [&](int32_t dim) -> Maybe { // handle scalar if (ndim == 0 && (dim == 0 || dim == -1)) { return 0; } if (dim < -ndim || dim >= ndim) { return Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim << ", " << ndim - 1 << "], but got " << dim << ")"; } return dim >= 0 ? dim : dim + ndim; }; // -n dim (negative dim) indicate ndim-n // for example, when ndim == 3, (-3) == (0), (-2) == (1), (-1) == (2) int32_t true_start_dim = JUST(CheckAndWrapDim(start_dim)); int32_t true_end_dim = JUST(CheckAndWrapDim(end_dim)); if (true_start_dim > true_end_dim) { return Error::RuntimeError() << "flatten() has invalid args: start_dim (" << start_dim << ") cannot come after end_dim (" << end_dim << ")"; } // identity when start_dim == end_dim if (true_start_dim == true_end_dim) { return x; } DimVector dim_vec{in_shape.begin(), in_shape.begin() + true_start_dim + 1}; for (int i = true_start_dim + 1; i <= true_end_dim; ++i) { dim_vec.back() *= in_shape[i]; } dim_vec.insert(dim_vec.end(), in_shape.begin() + true_end_dim + 1, in_shape.end()); Shape reshape_shape{dim_vec}; CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), reshape_shape.elem_cnt()) << Error::RuntimeError() << "invalid reshape from " << in_shape.ToString() << " to " << reshape_shape.ToString(); return JUST(Reshape(x, reshape_shape)); } }; class WhereFunctor { public: WhereFunctor() { op_ = CHECK_JUST( one::OpBuilder("where").Input("condition").Input("x").Input("y").Output("out").Build()); } Maybe operator()(const std::shared_ptr& condition, const std::shared_ptr& x, const std::shared_ptr& y) const { return OpInterpUtil::Dispatch(*op_, {condition, x, y}); } private: std::shared_ptr op_; }; class WhereScalarXFunctor { public: WhereScalarXFunctor() = default; Maybe operator()(const std::shared_ptr& condition, const Scalar& scalar, const std::shared_ptr& y) const { std::shared_ptr x; if (y->is_local()) { x = JUST(functional::Constant(Shape({}), scalar, y->dtype(), JUST(y->device()))); } else { const size_t sbp_ndim = JUST(y->nd_sbp())->sbp_parallel_size(); std::vector> nd_sbp_vec; nd_sbp_vec.reserve(sbp_ndim); for (int i = 0; i < sbp_ndim; ++i) { SbpParallel sbp; sbp.mutable_broadcast_parallel(); nd_sbp_vec.push_back(SymbolOf(sbp)); } const auto& parallel_desc = JUST(y->parallel_desc()); x = JUST( functional::GlobalConstant(Shape({}), scalar, y->dtype(), parallel_desc, nd_sbp_vec)); } return functional::Where(condition, x, y); } }; class WhereScalarYFunctor { public: WhereScalarYFunctor() = default; Maybe operator()(const std::shared_ptr& condition, const std::shared_ptr& x, const Scalar& scalar) const { std::shared_ptr y; if (x->is_local()) { y = JUST(functional::Constant(Shape({}), scalar, x->dtype(), JUST(x->device()))); } else { const size_t sbp_ndim = JUST(x->nd_sbp())->sbp_parallel_size(); std::vector> nd_sbp_vec; nd_sbp_vec.reserve(sbp_ndim); for (int i = 0; i < sbp_ndim; ++i) { SbpParallel sbp; sbp.mutable_broadcast_parallel(); nd_sbp_vec.push_back(SymbolOf(sbp)); } const auto& parallel_desc = JUST(x->parallel_desc()); y = JUST( functional::GlobalConstant(Shape({}), 
scalar, x->dtype(), parallel_desc, nd_sbp_vec)); } return functional::Where(condition, x, y); } }; class WhereScalarXYFunctor { public: WhereScalarXYFunctor() = default; Maybe operator()(const std::shared_ptr& condition, const Scalar& x_scalar, const Scalar& y_scalar) const { std::shared_ptr x; std::shared_ptr y; DataType dtype = DataType::kInvalidDataType; if (x_scalar.IsBool() && y_scalar.IsBool()) { dtype = DataType::kBool; } else if (x_scalar.IsFloatingPoint() && y_scalar.IsFloatingPoint()) { double x_val = x_scalar.As(); double y_val = y_scalar.As(); if (x_val >= GetMinVal>() && x_val <= GetMaxVal>() && y_val >= GetMinVal>() && y_val <= GetMaxVal>()) { dtype = DataType::kFloat; } else { dtype = DataType::kDouble; } } else if (x_scalar.IsIntegral() && y_scalar.IsIntegral()) { if (x_scalar.IsUnsigned() && y_scalar.IsUnsigned()) { uint64_t x_val = x_scalar.As(); uint64_t y_val = y_scalar.As(); if (x_val <= GetMaxVal>() && y_val <= GetMaxVal>()) { dtype = DataType::kUInt32; } else { dtype = DataType::kUInt64; } } else if (x_scalar.IsSigned() && y_scalar.IsSigned()) { int64_t x_val = x_scalar.As(); int64_t y_val = y_scalar.As(); if (x_val >= GetMinVal>() && x_val <= GetMaxVal>() && y_val >= GetMinVal>() && y_val <= GetMaxVal>()) { dtype = DataType::kInt32; } else { dtype = DataType::kInt64; } } else { UNIMPLEMENTED_THEN_RETURN() << "The x scalar and y scalar in Where shoule be signed or unsigned at the same time."; } } else { UNIMPLEMENTED_THEN_RETURN() << "The x scalar and y in Where shoule be bool, float or int at the same time."; } if (condition->is_local()) { x = JUST(functional::Constant(Shape({}), x_scalar, DType(dtype), JUST(condition->device()))); y = JUST(functional::Constant(Shape({}), y_scalar, DType(dtype), JUST(condition->device()))); } else { const size_t sbp_ndim = JUST(condition->nd_sbp())->sbp_parallel_size(); std::vector> nd_sbp_vec; nd_sbp_vec.reserve(sbp_ndim); for (int i = 0; i < sbp_ndim; ++i) { SbpParallel sbp; sbp.mutable_broadcast_parallel(); nd_sbp_vec.push_back(SymbolOf(sbp)); } const auto& parallel_desc = JUST(condition->parallel_desc()); x = JUST( functional::GlobalConstant(Shape({}), x_scalar, DType(dtype), parallel_desc, nd_sbp_vec)); y = JUST( functional::GlobalConstant(Shape({}), y_scalar, DType(dtype), parallel_desc, nd_sbp_vec)); } return functional::Where(condition, x, y); } }; class ArgWhereFunctor { public: ArgWhereFunctor() { op_ = CHECK_JUST( one::OpBuilder("argwhere").Input("input").Output("output").Output("output_size").Build()); } Maybe operator()(const std::shared_ptr& x, const Symbol& dtype) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype"); attrs.SetAllAttrs(dtype->data_type()); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class NonZeroFunctor { public: NonZeroFunctor() {} Maybe operator()(const std::shared_ptr& x, bool as_tuple) const { std::shared_ptr input = x; if (as_tuple && input->ndim() == 0) { input = JUST(functional::Unsqueeze(input, 0)); } int64_t ndim = input->ndim(); const auto& output_tuple = JUST(functional::ArgWhere(input, JUST(DType::Get(DataType::kInt64)))); const std::shared_ptr& size = JUST(VectorAt(*output_tuple, 1)); CHECK_EQ_OR_RETURN(size->shape()->elem_cnt(), 1) << Error::RuntimeError() << kOfBugIssueUploadPrompt; CHECK_OR_RETURN(size->dtype() == JUST(DType::Get(DataType::kInt64))) << Error::RuntimeError() << kOfBugIssueUploadPrompt; int64_t size_val = -1; { if (size->is_global()) { CHECK_OR_RETURN(JUST(size->parallel_desc())->parallel_num() == 1 // NOLINT || 
NdSbpIsAllBroadcast(*JUST(size->nd_sbp()))); // NOLINT } JUST(GetItemInScalarTensor(size->is_local() ? size : JUST(size->cur_rank_phy_tensor()), &size_val, sizeof(size_val))); } std::vector start{0, 0}; std::vector stop{size_val, ndim}; std::vector step{1, 1}; const auto& output = JUST( functional::Slice(output_tuple->at(0), start, stop, step, /*enable_view_slice=*/false)); std::shared_ptr outputs = std::make_shared(); if (as_tuple) { const auto& transposed_output = JUST(functional::Transpose2dim(output, 1, 0)); for (int64_t i = 0; i < ndim; ++i) { outputs->emplace_back( JUST(functional::TensorGetItem(transposed_output, {functional::detail::IndexItem(i)}))); } } else { outputs->emplace_back(output); } return outputs; } }; class BroadcastLikeFunctor { public: BroadcastLikeFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_like").Input("x").Input("like").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& like, const std::vector& broadcast_axes) const { const Shape& x_shape = *x->shape(); const Shape& like_shape = *like->shape(); if (x_shape == like_shape) { return x; } CHECK_GE_OR_RETURN(like_shape.NumAxes(), x_shape.NumAxes()) << Error::RuntimeError() << "The number of sizes provided (" << like_shape.NumAxes() << ") must be greater or equal to the number of dimensions in the tensor (" << x_shape.NumAxes() << ")" << ". Target sizes: " << like_shape.ToString() << ". Tensor sizes: " << x_shape.ToString(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("broadcast_axes"); if (broadcast_axes.empty()) { int64_t like_ndim = like_shape.NumAxes(); int64_t x_ndim = x_shape.NumAxes(); int64_t num_prepend = like_ndim - x_ndim; std::vector prepend_shape(num_prepend, 1); std::vector broadcast_axes; for (int i = 0; i < x_ndim; ++i) { prepend_shape.emplace_back(x_shape.At(i)); } for (int i = 0; i < num_prepend; ++i) { broadcast_axes.emplace_back(i); } for (int i = num_prepend; i < prepend_shape.size(); ++i) { if (prepend_shape[i] != like_shape.At(i)) { if (prepend_shape[i] == 1) { broadcast_axes.emplace_back(i); } else { return Error::RuntimeError() << "The expanded size of the tensor " << "(" << like_shape.At(i) << ")" << " must match the existing size (" << prepend_shape[i] << ") at non-singleton dimension " << i << ". Target sizes: " << like_shape.ToString() << ". 
Tensor sizes: " << x_shape.ToString(); } } } attrs.SetAllAttrs(broadcast_axes); } else { attrs.SetAllAttrs(broadcast_axes); } return OpInterpUtil::Dispatch(*op_, {x, JUST(like->detach())}, attrs); } private: std::shared_ptr op_; }; class ConcatFunctor { public: ConcatFunctor() { ops_.resize(kMaxInputCount); for (int n = 0; n < ops_.size(); ++n) { ops_[n] = CHECK_JUST(one::OpBuilder("cat").Input("in", n + 1).Output("out").Build()); } } Maybe operator()(const TensorTuple& inputs, const int64_t& dim) const { const int64_t ninput = inputs.size(); int64_t axis = dim; int64_t ndim = inputs[0]->ndim(); int64_t nelement = inputs[0]->nelement(); int64_t max_dim_size = 0; CHECK_GE_OR_RETURN(ninput, 1) << Error::RuntimeError() << "inputs size must greater than 0"; axis = JUST(maybe_wrap_dim(axis, ndim)); const std::shared_ptr& shape = inputs[0]->shape(); for (const auto& input : inputs) { if (nelement == 0 and ndim == 1) { if (input->nelement() != 0 or input->ndim() != 1) { ndim = input->ndim(); nelement = input->nelement(); } else { continue; } } else if (input->nelement() != 0 or input->ndim() != 1) { CHECK_OR_RETURN(input->ndim() == ndim) << Error::RuntimeError() << "Tensors must have same number of dimensions: got " << ndim << " and " << input->ndim() << " is expected."; } for (int i = 0; i < ndim; ++i) { if (input->nelement() == 0 and input->ndim() == 1) { continue; } if (axis == i) { max_dim_size += input->shape()->At(i); } else if (inputs[0]->nelement() != 0) { CHECK_OR_RETURN(input->shape()->At(i) == shape->At(i)) << Error::RuntimeError() << "Sizes of tensors must match except in dimension " << axis << ". Got " << input->shape()->At(i) << " and " << shape->At(i) << " is expected in dimension " << i << "."; } } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "max_dim_size"); attrs.SetAllAttrs(axis, max_dim_size); TensorTuple outputs; for (int i = 0; i < ninput; i += kMaxInputCount) { size_t size = (i + kMaxInputCount) < ninput ? 
kMaxInputCount : ninput - i; TensorTuple partial_inputs(size); TensorProcessor tensor_processor; for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; } JUST(tensor_processor.PromoteInputsToCommonDtype(true) .AddInputs(partial_inputs, inputs.at(i)->dtype()) .Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); outputs.emplace_back( JUST(OpInterpUtil::Dispatch(*ops_[size - 1], input_tuple, attrs))); } if (outputs.size() == 1) { return outputs.at(0); } return this->operator()(outputs, axis); } private: std::vector> ops_; }; class StackFunctor { public: StackFunctor() { ops_.resize(kMaxInputCount); for (int n = 0; n < ops_.size(); ++n) { ops_[n] = CHECK_JUST(one::OpBuilder("stack").Input("in", n + 1).Output("out").Build()); } } Maybe operator()(const TensorTuple& inputs, const int64_t& dim) const { const int64_t ninput = inputs.size(); int64_t ndims = inputs[0]->ndim(); int64_t stack_dim = dim; stack_dim = JUST(maybe_wrap_dim(stack_dim, ndims + 1)); const std::shared_ptr& first_in_shape = inputs[0]->shape(); for (const auto& input : inputs) { for (int i = 0; i < ndims; ++i) { CHECK_OR_RETURN(input->shape()->At(i) == first_in_shape->At(i)) << Error::RuntimeError() << "stack expects each tensor to be equal size, but got " << first_in_shape->ToString() << " at first input and " << input->shape()->ToString() << " which index is " << i; } } int64_t max_dim_size = ninput; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "max_dim_size"); attrs.SetAllAttrs(stack_dim, max_dim_size); TensorTuple outputs; for (int i = 0; i < ninput; i += kMaxInputCount) { size_t size = (i + kMaxInputCount) < ninput ? kMaxInputCount : ninput - i; TensorTuple partial_inputs(size); for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; } if (partial_inputs.size() == 1) { // Use ExpandDims functor for only one input outputs.emplace_back(JUST(functional::ExpandDims(partial_inputs[0], dim))); } else { outputs.emplace_back( JUST(OpInterpUtil::Dispatch(*ops_[size - 1], partial_inputs, attrs))); } } if (outputs.size() == 1) { return outputs.at(0); } return Concat(outputs, stack_dim); } private: std::vector> ops_; }; class StackGradFunctor { public: StackGradFunctor() { ops_.resize(kMaxInputCount); for (int n = 1; n < ops_.size(); ++n) { ops_[n] = CHECK_JUST(one::OpBuilder("stack_grad") .Input("in") .Input("like", n + 1) .Output("out", n + 1) .Build()); } } Maybe operator()(const std::shared_ptr& x, const TensorTuple& like, const int64_t& axis) const { CHECK_GE_OR_RETURN(like.size(), 2) << Error::RuntimeError() << "like.size() must not less than 2, but got " << like.size(); CHECK_LE_OR_RETURN(like.size(), kMaxInputCount) << Error::RuntimeError() << "like.size() must not greater than " << kMaxInputCount << ", but got " << like.size(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); TensorTuple inputs(like.size() + 1); inputs[0] = x; for (int i = 0; i < like.size(); ++i) { inputs[i + 1] = like[i]; } return OpInterpUtil::Dispatch(*ops_.at(like.size() - 1), inputs, attrs); } private: std::vector> ops_; }; class AtLeast1DFunctor { public: Maybe operator()(const std::shared_ptr& x) const { if (x->ndim() == 0) { return JUST(Reshape(x, {1})); } else return x; } }; class AtLeast1DListFunctor { public: Maybe operator()(const TensorTuple& inputs) const { TensorTuple result = TensorTuple(inputs.size()); for (int32_t i = 0; i < inputs.size(); i++) { result.at(i) = JUST(AtLeast1D(JUST(VectorAt(inputs, i)))); } return result; } }; class AtLeast2DFunctor { public: 
Maybe operator()(const std::shared_ptr& x) const { if (x->ndim() == 0) { return JUST(Reshape(x, {1, 1})); } else if (x->ndim() == 1) { return JUST(Unsqueeze(x, 0)); } else return x; } }; class AtLeast2DListFunctor { public: Maybe operator()(const TensorTuple& inputs) const { TensorTuple result = TensorTuple(inputs.size()); for (int32_t i = 0; i < inputs.size(); i++) { result.at(i) = JUST(AtLeast2D(JUST(VectorAt(inputs, i)))); } return result; } }; class AtLeast3DFunctor { public: Maybe operator()(const std::shared_ptr& x) const { if (x->ndim() == 0) { return JUST(Reshape(x, {1, 1, 1})); } else if (x->ndim() == 1) { return JUST(Reshape(x, {1, x->shape()->At(0), 1})); } else if (x->ndim() == 2) { return JUST(Unsqueeze(x, -1)); } else return x; } }; class AtLeast3DListFunctor { public: Maybe operator()(const TensorTuple& inputs) const { TensorTuple result = TensorTuple(inputs.size()); for (int32_t i = 0; i < inputs.size(); i++) { result.at(i) = JUST(AtLeast3D(JUST(VectorAt(inputs, i)))); } return result; } }; class ColumnStackFunctor { public: Maybe operator()(const TensorTuple& inputs) const { std::shared_ptr new_inputs = std::make_shared(inputs.size()); for (int32_t i = 0; i < inputs.size(); i++) { const auto& t = JUST(VectorAt(inputs, i)); if (t->ndim() <= 1) new_inputs->at(i) = JUST(Reshape(t, {t->nelement(), 1})); else new_inputs->at(i) = t; } return HStack(*new_inputs); } }; class HStackFunctor { public: Maybe operator()(const TensorTuple& inputs) const { std::shared_ptr new_inputs = JUST(AtLeast1D(inputs)); if (new_inputs->at(0)->ndim() == 1) return Concat(*new_inputs, 0); else return Concat(*new_inputs, 1); } }; class VStackFunctor { public: Maybe operator()(const TensorTuple& inputs) const { std::shared_ptr new_inputs = JUST(AtLeast2D(inputs)); return Concat(*new_inputs, 0); } }; class RowStackFunctor { public: Maybe operator()(const TensorTuple& inputs) const { return VStack(inputs); } }; class DStackFunctor { public: Maybe operator()(const TensorTuple& inputs) const { std::shared_ptr new_inputs = JUST(AtLeast3D(inputs)); return Concat(*new_inputs, 2); } }; class ExpandFunctor { public: ExpandFunctor() { op_ = CHECK_JUST(one::OpBuilder("expand").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Shape& shape) const { const Shape& in_shape = *x->shape(); int lpad = shape.size() - in_shape.size(); if (lpad < 0) { return Error::RuntimeError() << "expand(tensor{" << in_shape.ToString() << "}, size=" << in_shape.size() << "): the number of sizes provided (" << shape.size() << ") " << "must be greater or equal to the number of dimensions in the tensor (" << in_shape.size() << ")"; } DimVector expand_shape_vec = shape.dim_vec(); for (size_t i = 0; i < shape.size(); ++i) { const auto& t_dim = shape[i]; if (t_dim < -1) { return Error::RuntimeError() << "Trying to create tensor with negative dimension " << t_dim; } if (i >= lpad) { const auto& dim = in_shape[i - lpad]; if (dim != 1 && t_dim != -1 && t_dim != dim) { return Error::RuntimeError() << "The expanded size of the tensor (" << t_dim << ") must match the existing size (" << dim << ") at non-singleton dimension " << i << ". Target sizes: " << shape.ToString() << ". 
Tensor sizes: " << in_shape.ToString(); } if (t_dim == -1) { expand_shape_vec[i] = dim; } } else { if (t_dim == -1) { return Error::RuntimeError() << "The expanded size of the tensor (-1) isn't allowed in a " "leading, non-existing dimension " << i; } } } // if input tensor is eager local, then try return tensor's view Shape expand_shape(expand_shape_vec); if (view::IsViewApplicable(x)) { return view::Expand(x, expand_shape); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("expand_shape"); attrs.SetAllAttrs(expand_shape); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class ExpandDimsFunctor { public: ExpandDimsFunctor() { op_ = CHECK_JUST(one::OpBuilder("expand_dims").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const int32_t& dim) const { int32_t expand_dim = dim; const int32_t ndim = input->shape()->NumAxes(); expand_dim = JUST(maybe_wrap_dim(dim, ndim + 1)); if (view::IsViewApplicable(input)) { return view::Unsqueeze(input, expand_dim); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(expand_dim); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class UnsqueezeMultipleFunctor { public: UnsqueezeMultipleFunctor() {} Maybe operator()(const std::shared_ptr& x, const std::vector& dim, const int32_t& n_dims) const { if (dim.size() == 0 || x->ndim() == n_dims) { return x; } else if (dim.size() == 1) { return JUST(functional::Unsqueeze(x, JUST(VectorAt(dim, 0)))); } else { std::shared_ptr tensor = x; const auto& dims_to_unsqueeze = JUST(dim_list_to_bitset(dim, n_dims)); // Unsqueeze is called several times to extend the dimension when the View mechanism is // enabled. Otherwise, calculate the target shape and call reshape. 
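// Illustrative example (hypothetical shapes): unsqueezing dims {0, 2} of a (4, 5) tensor with
// n_dims == 4 yields target_dims == {1, 4, 1, 5}; the view path below applies view::Unsqueeze
// once per flagged dim, while the fallback computes this full target shape and issues a single
// Reshape.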
if (view::IsViewApplicable(tensor)) { for (int32_t i = 0; i < n_dims; i++) { if ((*dims_to_unsqueeze)[i]) { tensor = JUST(view::Unsqueeze(tensor, i)); } } } else { std::vector target_dims(n_dims, 0); int32_t tensor_index = 0; for (int32_t i = 0; i < n_dims; i++) { if ((*dims_to_unsqueeze)[i]) { target_dims[i] = 1; } else { CHECK_LT_OR_RETURN(tensor_index, tensor->ndim()); // NOLINT(maybe-need-error-msg) target_dims[i] = tensor->shape()->at(tensor_index); tensor_index++; } } Shape infered_shape(DimVector(target_dims.begin(), target_dims.end())); tensor = JUST(functional::Reshape(tensor, infered_shape)); } return tensor; } } }; class InplaceUnsqueezeFunctor { public: Maybe operator()(const std::shared_ptr& input, const int32_t& dim) const { JUST(CheckInplaceValid(input)); const int64_t expand_dim = JUST(maybe_wrap_dim(dim, input->shape()->NumAxes() + 1)); CHECK_OR_RETURN(view::IsViewApplicable(input)) << "inplace unsqueeze(tensor.unsqueeze_) only support in eager local mode!"; JUST(view::InplaceUnsqueeze(input, expand_dim)); return input; } }; class SqueezeFunctor { public: SqueezeFunctor() { op_ = CHECK_JUST(one::OpBuilder("squeeze").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Optional>& dim) const { int32_t ndim = x->shape()->NumAxes(); std::vector squeeze_dims; squeeze_dims.reserve(ndim); if (dim.has_value()) { std::vector dims = *JUST(dim); for (int32_t dim_i : dims) { dim_i = JUST(maybe_wrap_dim(dim_i, ndim)); if (x->shape()->At(dim_i) == 1) { squeeze_dims.emplace_back(dim_i); } } } else { for (int i = 0; i < ndim; ++i) { if (x->shape()->At(i) == 1) { squeeze_dims.emplace_back(i); } } } if (view::IsViewApplicable(x)) { return view::Squeeze(x, squeeze_dims); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axes"); attrs.SetAllAttrs(squeeze_dims); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class InplaceSqueezeFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& dim) const { JUST(CheckInplaceValid(input)); const int32_t ndim = input->shape()->NumAxes(); std::vector squeeze_dims; squeeze_dims.reserve(ndim); if (dim.has_value()) { std::vector dims = *JUST(dim); for (int32_t dim_i : dims) { dim_i = JUST(maybe_wrap_dim(dim_i, ndim)); if (input->shape()->At(dim_i) == 1) { squeeze_dims.emplace_back(dim_i); } } } else { for (int i = 0; i < ndim; ++i) { if (input->shape()->At(i) == 1) { squeeze_dims.emplace_back(i); } } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axes"); attrs.SetAllAttrs(squeeze_dims); CHECK_OR_RETURN(view::IsViewApplicable(input)) << "inplace squeeze(tensor.squeeze_) only support in eager local mode!"; JUST(view::InplaceSqueeze(input, squeeze_dims)); return input; } }; class RollFunctor { public: RollFunctor() { op_ = CHECK_JUST(one::OpBuilder("roll").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& shifts, const Optional>& dims) const { std::vector actual_dims; if (dims.has_value()) { actual_dims = *JUST(dims); } else { actual_dims.emplace_back(-1); } CHECK_EQ_OR_RETURN(shifts.size(), actual_dims.size()) << Error::RuntimeError() << "shifts and dimensions must align. 
shifts: " << shifts.size() << ", dims: " << actual_dims.size(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shifts", "dims"); attrs.SetAllAttrs(shifts, actual_dims); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class GatherFunctor { public: GatherFunctor() { op_ = CHECK_JUST(one::OpBuilder("gather").Input("in").Input("indices").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& indices, const int64_t& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch(*op_, {x, indices}, attrs); } private: std::shared_ptr op_; }; class DimGatherFunctor { public: DimGatherFunctor() { op_ = CHECK_JUST( one::OpBuilder("dim_gather").Input("input").Input("index").Output("output").Build()); } Maybe operator()(const std::shared_ptr& input, const int64_t& dim, const std::shared_ptr& index, const bool sparse_grad) const { CHECK_OR_RETURN(index->dtype()->data_type() == kInt64 || index->dtype()->data_type() == kInt32) << Error::RuntimeError() << "gather(): Expected dtype int32 or int64 for index"; CHECK_EQ_OR_RETURN(sparse_grad, false) << Error::RuntimeError() << "Only support bool = False for now!"; int64_t new_dim = JUST(maybe_wrap_dim(dim, index->ndim())); if (input->ndim() > 0 && index->ndim() > 0) { CHECK_EQ_OR_RETURN(input->ndim(), index->ndim()) << Error::RuntimeError() << "Index tensor must have the same number of dimensions as input tensor"; } else if (input->ndim() == 0) { CHECK_LE_OR_RETURN(index->ndim(), 1) << Error::RuntimeError() << "Index tensor must have the same number of dimensions as input tensor"; } else { CHECK_LE_OR_RETURN(input->ndim(), 1) << Error::RuntimeError() << "Index tensor must have the same number of dimensions as input tensor"; } if (input->ndim() > 0 && index->ndim() > 0) { FOR_RANGE(int32_t, i, 0, input->ndim()) { if (i != new_dim) { CHECK_LE_OR_RETURN(index->shape()->At(i), input->shape()->At(i)) << Error::RuntimeError() << "Size does not match at dimension " << i << " expected index " << *(index->shape()) << " to be smaller than self " << *(input->shape()) << " apart from dimension " << new_dim; } } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim"); attrs.SetAllAttrs(static_cast(new_dim)); return OpInterpUtil::Dispatch(*op_, {input, index}, attrs); } private: std::shared_ptr op_; }; enum class DimScatterType { kUpdate, kAdd, kMultiply }; template std::string DimScatterTypeToString() { switch (T) { case DimScatterType::kUpdate: return "_update"; case DimScatterType::kAdd: return "_add"; case DimScatterType::kMultiply: return "_mul"; } return ""; } template class DimScatterFunctorImpl { public: DimScatterFunctorImpl() : op_(CHECK_JUST(one::OpBuilder("dim_scatter" + DimScatterTypeToString()) .Input("input") .Input("index") .Input("src") .Output("output") .Build())) {} Maybe operator()(const std::shared_ptr& input, const int32_t& dim, const std::shared_ptr& index, const std::shared_ptr& src, bool inplace) const { const int32_t ndim = input->shape()->NumAxes(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim"); attrs.SetAllAttrs(static_cast(JUST(maybe_wrap_dim(dim, ndim)))); if (inplace) { JUST(CheckInplaceValid(input)); auto outputs = std::make_shared(1); outputs->at(0) = input; JUST(OpInterpUtil::Dispatch(*op_, {input, index, src}, outputs.get(), attrs)); return outputs->at(0); } return OpInterpUtil::Dispatch(*op_, {input, index, src}, attrs); } private: std::shared_ptr op_; }; class DimScatterFunctor { public: Maybe 
operator()(const std::shared_ptr& input, const int32_t& dim, const std::shared_ptr& index, const std::shared_ptr& src, const Optional& reduce, bool inplace) const { if (reduce.has_value()) { const std::string& reduce_str = *JUST(reduce); if (reduce_str == "add") { return DimScatterAdd(input, dim, index, src, inplace); } else if (reduce_str == "multiply") { return DimScatterMul(input, dim, index, src, inplace); } else { CHECK_OR_RETURN(false) << Error::RuntimeError() << "Invalid reduce type: " << reduce_str; } } return functional::DimScatterUpdate(input, dim, index, src, inplace); } }; template class DimScatterScalarFunctorImpl { public: DimScatterScalarFunctorImpl() : op_(CHECK_JUST(one::OpBuilder("dim_scatter" + DimScatterTypeToString() + "_scalar") .Input("input") .Input("index") .Output("output") .Build())) {} Maybe operator()(const std::shared_ptr& input, const int32_t& dim, const std::shared_ptr& index, const Scalar& src, bool inplace) const { const int32_t ndim = input->shape()->NumAxes(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim", "src_scalar"); attrs.SetAllAttrs(static_cast(JUST(maybe_wrap_dim(dim, ndim))), src.As()); if (inplace) { JUST(CheckInplaceValid(input)); auto outputs = std::make_shared(1); outputs->at(0) = input; JUST(OpInterpUtil::Dispatch(*op_, {input, index}, outputs.get(), attrs)); return outputs->at(0); } return OpInterpUtil::Dispatch(*op_, {input, index}, attrs); } private: std::shared_ptr op_; }; class DimScatterScalarFunctor { public: Maybe operator()(const std::shared_ptr& input, const int32_t& dim, const std::shared_ptr& index, const Scalar& src, const Optional& reduce, bool inplace) const { if (reduce.has_value()) { const std::string& reduce_str = *JUST(reduce); if (reduce_str == "add") { return DimScatterAddScalar(input, dim, index, src, inplace); } else if (reduce_str == "multiply") { return DimScatterMulScalar(input, dim, index, src, inplace); } else { CHECK_OR_RETURN(false) << Error::RuntimeError() << "Invalid reduce type: " << reduce_str; } } return functional::DimScatterUpdateScalar(input, dim, index, src, inplace); } }; class DimScatterAddLikeFunctor { public: DimScatterAddLikeFunctor() { op_ = CHECK_JUST(one::OpBuilder("dim_scatter_add_like") .Input("like") .Input("index") .Input("src") .Output("output") .Build()); } Maybe operator()(const std::shared_ptr& like, const int32_t& dim, const std::shared_ptr& index, const std::shared_ptr& src) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim"); attrs.SetAllAttrs(dim); return OpInterpUtil::Dispatch(*op_, {like, index, src}, attrs); } private: std::shared_ptr op_; }; class ArgSortFunctor { public: ArgSortFunctor() { op_ = CHECK_JUST(one::OpBuilder("arg_sort").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, const std::string& direction) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("direction"); attrs.SetAllAttrs(direction); CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING") << Error::RuntimeError() << "expected the input direction parameter value is \"ASCENDING\" or \"DESCENDING\", " << "but found the value is " << "\"" << direction << "\""; return OpInterpUtil::Dispatch(*op_, {in}, attrs); } private: std::shared_ptr op_; }; class SearchSortedFunctor { public: SearchSortedFunctor() { op_ = CHECK_JUST(one::OpBuilder("searchsorted") .Input("sorted_sequence") .Input("values") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& sorted_sequence, const std::shared_ptr& values, bool out_int32, bool right) const { // checks 
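// Illustrative example (hypothetical values): with sorted_sequence = {1, 3, 5} and
// values = {2, 4}, the op returns the insertion indices {1, 2}; with right == true the index
// after any run of equal elements is returned instead (e.g. a value of 3 maps to 2 rather than 1).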
CHECK_OR_RETURN(values->shape()->NumAxes() > 0) << "for searchsorted op, input values tensor should have positive dimension"; CHECK_OR_RETURN(sorted_sequence->shape()->NumAxes() > 0) << "for searchsorted op, input sorted_sequence should have positive dimension, " << "but got 0 dimension"; CHECK_OR_RETURN(sorted_sequence->shape()->NumAxes() == 1 || sorted_sequence->shape()->MatchBeforeLastDim(*(values->shape()))) << "for searchsorted op, sorted_sequence should be 1 dimension or the first N-1 dimensions " << "of boundaries tensor and input value tensor must match"; if (out_int32) { CHECK_OR_RETURN(sorted_sequence->shape()->At(sorted_sequence->shape()->NumAxes() - 1) < INT32_MAX) << "for searchsorted op, the size of input sorted_sequence' last dimension should " << "be less than " << INT32_MAX; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("out_int32", "right"); attrs.SetAllAttrs(out_int32, right); return OpInterpUtil::Dispatch(*op_, {sorted_sequence, values}, attrs); } private: std::shared_ptr op_; }; class SearchSortedScalarFunctor { public: SearchSortedScalarFunctor() { op_ = CHECK_JUST( one::OpBuilder("searchsorted_scalar").Input("sorted_sequence").Output("out").Build()); } Maybe operator()(const std::shared_ptr& sorted_sequence, const Scalar& values, bool out_int32, bool right) const { // checks CHECK_OR_RETURN(sorted_sequence->shape()->NumAxes() == 1) << "for searchsorted op, input value can be a scalar only when sorted_sequence tensor " << "dimension is 1, but we got sorted_sequence dim(" << sorted_sequence->shape()->NumAxes() << ")"; if (out_int32) { CHECK_OR_RETURN(sorted_sequence->shape()->At(sorted_sequence->shape()->NumAxes() - 1) < INT32_MAX) << "for searchsorted op, the size of input sorted_sequence' last dimension should " << "be less than " << INT32_MAX; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("out_int32", "right", "values"); attrs.SetAllAttrs(out_int32, right, values.As()); return OpInterpUtil::Dispatch(*op_, {sorted_sequence}, attrs); } private: std::shared_ptr op_; }; class GatherNdFunctor { public: GatherNdFunctor() { op_ = CHECK_JUST( one::OpBuilder("gather_nd").Input("params").Input("indices").Output("out").Build()); } Maybe operator()(const std::shared_ptr& params, const std::shared_ptr& indices) const { return OpInterpUtil::Dispatch(*op_, {params, indices}); } private: std::shared_ptr op_; }; class ScatterNdFunctor { public: ScatterNdFunctor() { op_ = CHECK_JUST( one::OpBuilder("scatter_nd").Input("indices").Input("updates").Output("out").Build()); } Maybe operator()(const std::shared_ptr& indices, const std::shared_ptr& updates, const Shape& shape) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape"); attrs.SetAllAttrs(shape); return OpInterpUtil::Dispatch(*op_, {indices, updates}, attrs); } private: std::shared_ptr op_; }; class TensorScatterNdUpdateFunctor { public: TensorScatterNdUpdateFunctor() { op_ = CHECK_JUST(one::OpBuilder("tensor_scatter_nd_update") .Input("params") .Input("indices") .Input("updates") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& tensor, const std::shared_ptr& indices, const std::shared_ptr& updates, bool inplace) const { CHECK_OR_RETURN(*tensor->dtype() == *updates->dtype()) << Error::RuntimeError() << "The dtype of tensor and updates must be same."; std::shared_ptr contiguous_index = JUST(functional::ToContiguous(indices)); if (inplace) { if (tensor->is_global()) { // NOTE: global tensor_scatter_nd_update inplace must calculate on another tensor and assign // back because of input's sbp limited auto 
output = JUST(OpInterpUtil::Dispatch(*op_, {tensor, contiguous_index, updates})); int64_t ndim = tensor->shape()->NumAxes(); // TODO: use inplace copy op to write back to origin tensor std::vector start(ndim, 0); std::vector stop(tensor->shape()->begin(), tensor->shape()->end()); std::vector step(ndim, 1); return functional::SliceUpdate(tensor, output, start, stop, step, /*inplace=*/true); } else { JUST(CheckInplaceValid(tensor)); auto outputs = std::make_shared(1); (*outputs)[0] = tensor; JUST(OpInterpUtil::Dispatch(*op_, {tensor, contiguous_index, updates}, outputs.get())); return (*outputs)[0]; } } else { return OpInterpUtil::Dispatch(*op_, {tensor, contiguous_index, updates}); } } private: std::shared_ptr op_; }; class ScatterNdLikeFunctor { public: ScatterNdLikeFunctor() { op_ = CHECK_JUST(one::OpBuilder("scatter_nd_like") .Input("like") .Input("updates") .Input("indices") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& like, const std::shared_ptr& updates, const std::shared_ptr& indices) const { return OpInterpUtil::Dispatch(*op_, {like, updates, indices}); } private: std::shared_ptr op_; }; class ReshapeFunctor { public: ReshapeFunctor() { op_ = CHECK_JUST(one::OpBuilder("reshape").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Shape& shape) const { Shape infered_shape = *JUST(InferShapeUnspecifiedDim(x->shape()->Count(0), shape)); if (view::IsViewApplicable(x)) { Optional infered_stride = ComputeStride(*(x->shape()), *JUST(x->stride()), infered_shape); if (infered_stride.has_value()) { return view::Reshape(x, infered_shape, *JUST(infered_stride)); } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape"); attrs.SetAllAttrs(infered_shape); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class ViewFunctor { public: ViewFunctor() { op_ = CHECK_JUST(one::OpBuilder("reshape").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Shape& shape) const { Shape infered_shape = *JUST(InferShapeUnspecifiedDim(x->shape()->Count(0), shape)); if (view::IsViewApplicable(x)) { Optional infered_stride = ComputeStride(*(x->shape()), *JUST(x->stride()), infered_shape); CHECK_OR_RETURN_ERROR(infered_stride.has_value()) << Error::RuntimeError() << "view size is not compatible with input tensor's size and stride (at least one " "dimension spans across two contiguous subspaces). Use .reshape(...) 
instead."; return view::Reshape(x, infered_shape, *JUST(infered_stride)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape"); attrs.SetAllAttrs(infered_shape); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class ToContiguousFunctor { public: ToContiguousFunctor() { op_ = CHECK_JUST(one::OpBuilder("to_contiguous").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input) const { if (input->is_global() || input->is_lazy()) { return input; } return OpInterpUtil::Dispatch(*op_, {input}); } private: std::shared_ptr op_; }; class InplaceToContiguousFunctor { public: InplaceToContiguousFunctor() { assign_op_ = CHECK_JUST(one::OpBuilder("assign").Input("ref").Input("value").Build()); } Maybe operator()(const std::shared_ptr& input) const { // TODO: use original "inplace_to_contiguous" op replace assign if (input->is_contiguous()) { return input; } auto contiguous_tensor = JUST(functional::ToContiguous(input)); CHECK_OR_RETURN(input->is_local() && contiguous_tensor->is_local()) << "Both ref and value must be local tensor."; const Stride stride(*input->shape()); // update stride const auto& blob_object = JUST(input->eager_blob_object()); Symbol old_tensor_meta = JUST(input->local_tensor_meta()); Symbol new_tensor_meta = SymbolOf(LocalTensorMeta(old_tensor_meta->shape(), stride, old_tensor_meta->dtype(), old_tensor_meta->memory_format(), old_tensor_meta->device())); std::shared_ptr final_tensor_impl = std::make_shared(JUST(input->tensor_storage()), JUST(input->storage_offset()), input->requires_grad(), input->is_leaf()); JUST(final_tensor_impl->set_retain_grad(input->retain_grad())); JUST(final_tensor_impl->InitEagerBlobObject(new_tensor_meta, JUST(blob_object->compute_local_dep_object()))); JUST(JUST(input->AsLocalTensor())->set_impl(final_tensor_impl)); // assign contiguous tensor data JUST(OpInterpUtil::Dispatch(*assign_op_, {input, contiguous_tensor})); return input; } private: std::shared_ptr assign_op_; }; class NarrowFunctor { public: NarrowFunctor() { op_ = CHECK_JUST(one::OpBuilder("narrow").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const int64_t& dim, const int64_t& start, const int64_t& length) const { int64_t narrow_dim = dim; int64_t narrow_start = start; const int64_t ndim = input->shape()->NumAxes(); CHECK_GT_OR_RETURN(ndim, 0) << Error::RuntimeError() << "narrow() cannot be applied to a 0-dim tensor."; narrow_dim = JUST(maybe_wrap_dim(narrow_dim, ndim)); int64_t dim_length = input->shape()->At(narrow_dim); CHECK_OR_RETURN((-dim_length <= start) && (start <= dim_length)) << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim << ", " << ndim << "], but got " << start << ")"; if (narrow_start < 0) { narrow_start += ndim; } CHECK_GE_OR_RETURN(dim_length, narrow_start + length) << Error::RuntimeError() << "start (" << narrow_start << ") + length (" << length << ") exceeds dimension size (" << dim_length << ")"; if (view::IsViewApplicable(input)) { return JUST(view::Narrow(input, narrow_dim, narrow_start, length)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim", "start", "length"); attrs.SetAllAttrs(narrow_dim, start, length); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class NarrowGradFunctor { public: NarrowGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("narrow_grad").Input("dy").Input("like").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& like, 
const int64_t& dim, const int64_t& start, const int64_t& length) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim", "start", "length"); attrs.SetAllAttrs(dim, start, length); return OpInterpUtil::Dispatch(*op_, {dy, like}, attrs); } private: std::shared_ptr op_; }; class SliceFunctor { public: SliceFunctor() { op_ = CHECK_JUST(one::OpBuilder("slice").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& start, const std::vector& stop, const std::vector& step, const Optional& enable_view_slice) const { if (view::IsViewApplicable(x) && enable_view_slice.value_or(false)) { return view::Slice(x, start, stop, step); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("start", "stop", "step"); attrs.SetAllAttrs(start, stop, step); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } protected: std::shared_ptr op_; }; class SliceUpdateFunctor { public: SliceUpdateFunctor() { op_ = CHECK_JUST(one::OpBuilder("slice_update").Input("ref").Input("value").Output("y").Build()); } Maybe operator()(const std::shared_ptr& ref, const std::shared_ptr& value, const std::vector& start, const std::vector& stop, const std::vector& step, bool inplace) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("start", "stop", "step"); attrs.SetAllAttrs(start, stop, step); TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({ref, value}) .PromoteInputsToCommonDtype(true, ref->dtype()) .Apply()); if (inplace) { auto outputs = std::make_shared(1); JUST(CheckInplaceValid(ref)); JUST(VectorAt(*outputs, 0)) = ref; JUST(OpInterpUtil::Dispatch(*op_, JUST(tensor_processor.GetInputs()), outputs.get(), attrs)); return JUST(VectorAt(*outputs, 0)); } else { return OpInterpUtil::Dispatch(*op_, JUST(tensor_processor.GetInputs()), attrs); } } private: std::shared_ptr op_; }; class SliceGradFunctor { public: SliceGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("slice_grad").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const Shape& like_shape, const std::vector& start, const std::vector& stop, const std::vector& step) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("like_shape", "start", "stop", "step"); attrs.SetAllAttrs(like_shape, start, stop, step); return OpInterpUtil::Dispatch(*op_, {dy}, attrs); } protected: std::shared_ptr op_; }; class UpsampleGradFunctor { public: UpsampleGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const double& height_scale, const double& width_scale, const bool& align_corners, const std::string& data_format, const std::string& interpolation) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("height_scale", "width_scale", "align_corners", "interpolation", "data_format"); attrs.SetAllAttrs(height_scale, width_scale, align_corners, interpolation, data_format); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class CopyToDeviceFunctor { public: CopyToDeviceFunctor() { op_ = CHECK_JUST(one::OpBuilder("copy").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, Symbol device, const bool pin_memory) const { if (x->is_local()) { if (auto x_device = JUST(x->device()); x_device != device && x_device->rematable()) { std::dynamic_pointer_cast( JUST(x->eager_blob_object())->tensor_storage()) ->Remat(); } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("device", "pin_memory"); attrs.SetAllAttrs(device, pin_memory); // 
Trigger the construction of device context in advance if (device->enum_type() != DeviceType::kCPU) { TouchEpDevice(device); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: void TouchEpDevice(Symbol device) const { ep::DeviceManager* device_mgr = Singleton::Get()->GetDeviceManagerOrNull(device->enum_type()); if (!device_mgr) { return; } device_mgr->GetDevice(device->device_id()); } std::shared_ptr op_; }; class CopyFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::string& device_type, const int64_t& device_id, const bool pin_memory) const { return functional::Copy(x, JUST(Device::New(device_type, device_id)), pin_memory); } }; class FlipFunctor { public: FlipFunctor() { op_ = CHECK_JUST(one::OpBuilder("flip").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& dims) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims"); if (dims.empty()) { attrs.SetAllAttrs(dims); } else { std::vector flip_dims = *JUST(CheckAxis(dims, x->ndim())); attrs.SetAllAttrs(flip_dims); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UnfoldTensorFunctor { public: UnfoldTensorFunctor() { op_ = CHECK_JUST(one::OpBuilder("unfold_tensor").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const int32_t& dimension, const int32_t& size, const int32_t& step) const { // if input tensor is eager local, than try return tensor's view if (view::IsViewApplicable(x)) { return view::UnfoldTensor(x, dimension, size, step); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dimension", "size", "step"); attrs.SetAllAttrs(dimension, size, step); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UnfoldTensorGradFunctor { public: UnfoldTensorGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("unfold_tensor_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const int32_t& dimension, const int32_t& size, const int32_t& step) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dimension", "size", "step"); attrs.SetAllAttrs(dimension, size, step); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class UpsampleLinear1DFunctor { public: UpsampleLinear1DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_linear_1d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const double& scale_factor, const bool& align_corners, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_factor", "align_corners", "data_format", "output_size"); if (output_size.has_value()) { attrs.SetAllAttrs(scale_factor, align_corners, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(scale_factor, align_corners, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UpsampleLinear1DGradFunctor { public: UpsampleLinear1DGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("upsample_linear_1d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const double& scale_factor, const bool& align_corners, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_factor", "align_corners", "data_format", "output_size"); if (output_size.has_value()) { attrs.SetAllAttrs(scale_factor, 
align_corners, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(scale_factor, align_corners, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class UpsampleNearest1DFunctor { public: UpsampleNearest1DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_1d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const double& scale_factor, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_factor", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(scale_factor, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(scale_factor, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UpsampleNearest1DGradFunctor { public: UpsampleNearest1DGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("upsample_nearest_1d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const double& scale_factor, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_factor", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(scale_factor, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(scale_factor, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class UpsampleNearest2DFunctor { public: UpsampleNearest2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_2d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const double& height_scale, const double& width_scale, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("height_scale", "width_scale", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(height_scale, width_scale, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(height_scale, width_scale, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UpsampleNearest2DGradFunctor { public: UpsampleNearest2DGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("upsample_nearest_2d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const double& height_scale, const double& width_scale, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("height_scale", "width_scale", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(height_scale, width_scale, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(height_scale, width_scale, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class UpsampleBilinear2DFunctor { public: UpsampleBilinear2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_bilinear_2d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const double& height_scale, const double& width_scale, const bool& align_corners, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("height_scale", "width_scale", "align_corners", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(height_scale, width_scale, align_corners, 
data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UpsampleBilinear2DGradFunctor { public: UpsampleBilinear2DGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("upsample_bilinear_2d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const double& height_scale, const double& width_scale, const bool& align_corners, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("height_scale", "width_scale", "align_corners", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class UpsampleBicubic2DFunctor { public: UpsampleBicubic2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_bicubic_2d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const double& height_scale, const double& width_scale, const bool& align_corners, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("height_scale", "width_scale", "align_corners", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UpsampleBicubic2DGradFunctor { public: UpsampleBicubic2DGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("upsample_bicubic_2d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const double& height_scale, const double& width_scale, const bool& align_corners, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("height_scale", "width_scale", "align_corners", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class UpsampleNearest3DFunctor { public: UpsampleNearest3DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_3d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const double& depth_scale, const double& height_scale, const double& width_scale, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth_scale", "height_scale", "width_scale", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(depth_scale, height_scale, width_scale, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(depth_scale, height_scale, width_scale, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UpsampleNearest3DGradFunctor { public: UpsampleNearest3DGradFunctor() { op_ = CHECK_JUST( 
one::OpBuilder("upsample_nearest_3d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const double& depth_scale, const double& height_scale, const double& width_scale, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth_scale", "height_scale", "width_scale", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(depth_scale, height_scale, width_scale, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(depth_scale, height_scale, width_scale, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class UpsampleTrilinear3DFunctor { public: UpsampleTrilinear3DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_trilinear_3d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const double& depth_scale, const double& height_scale, const double& width_scale, const bool& align_corners, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth_scale", "height_scale", "width_scale", "align_corners", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(depth_scale, height_scale, width_scale, align_corners, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(depth_scale, height_scale, width_scale, align_corners, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class UpsampleTrilinear3DGradFunctor { public: UpsampleTrilinear3DGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("upsample_trilinear_3d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const double& depth_scale, const double& height_scale, const double& width_scale, const bool& align_corners, const Optional>& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth_scale", "height_scale", "width_scale", "align_corners", "data_format", "output_size"); if (output_size) { attrs.SetAllAttrs(depth_scale, height_scale, width_scale, align_corners, data_format, *JUST(output_size)); } else { attrs.SetAllAttrs(depth_scale, height_scale, width_scale, align_corners, data_format, NullOpt); } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class UnsortedSegmentSumLikeFunctor { public: UnsortedSegmentSumLikeFunctor() { op_ = CHECK_JUST(one::OpBuilder("unsorted_segment_sum_like") .Input("data") .Input("segment_ids") .Input("like") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& segment_ids, const std::shared_ptr& like, const int64_t& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch(*op_, {x, segment_ids, like}, attrs); } private: std::shared_ptr op_; }; class UnsortedSegmentSumFunctor { public: UnsortedSegmentSumFunctor() { op_ = CHECK_JUST(one::OpBuilder("unsorted_segment_sum") .Input("data") .Input("segment_ids") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& segment_ids, const int64_t& axis, const int64_t& num_segments) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "num_segments"); attrs.SetAllAttrs(axis, num_segments); return OpInterpUtil::Dispatch(*op_, {x, segment_ids}, attrs); } private: std::shared_ptr op_; }; class 
TrilFunctor { public: TrilFunctor() { op_ = CHECK_JUST(one::OpBuilder("tril").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const int64_t& diagonal) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("diagonal", "is_floating_fill_value", "integer_fill_value"); attrs.SetAllAttrs(diagonal, false, static_cast(0)); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class InplaceTrilFunctor { public: InplaceTrilFunctor() { op_ = CHECK_JUST(one::OpBuilder("tril").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const int64_t& diagonal) const { JUST(CheckInplaceValid(x)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("diagonal"); attrs.SetAllAttrs(diagonal); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs)); return outputs->at(0); } private: std::shared_ptr op_; }; class TriuFunctor { public: TriuFunctor() { op_ = CHECK_JUST(one::OpBuilder("triu").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const int64_t& diagonal) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("diagonal"); attrs.SetAllAttrs(diagonal); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class InplaceTriuFunctor { public: InplaceTriuFunctor() { op_ = CHECK_JUST(one::OpBuilder("triu").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const int64_t& diagonal) const { JUST(CheckInplaceValid(x)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("diagonal"); attrs.SetAllAttrs(diagonal); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs)); return outputs->at(0); } private: std::shared_ptr op_; }; class DiagFunctor { public: DiagFunctor() { op_ = CHECK_JUST(one::OpBuilder("diag").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const int32_t& diagonal) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("diagonal"); attrs.SetAllAttrs(diagonal); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class DiagGradFunctor { public: DiagGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("diag_grad").Input("dy").Input("in").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const int32_t& diagonal) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("diagonal"); attrs.SetAllAttrs(diagonal); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class DiagonalFunctor { public: DiagonalFunctor() { op_ = CHECK_JUST(one::OpBuilder("diagonal").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const int32_t& offset, const int32_t& dim1, const int32_t& dim2) const { int64_t ndims = x->shape()->NumAxes(); int32_t p_dim1 = dim1; int32_t p_dim2 = dim2; p_dim1 = JUST(maybe_wrap_dim(p_dim1, ndims)); p_dim2 = JUST(maybe_wrap_dim(p_dim2, ndims)); CHECK_NE_OR_RETURN(p_dim1, p_dim2) << Error::RuntimeError() << "diagonal dimensions cannot be identical " << dim1 << ", " << dim2; if (view::IsViewApplicable(x)) { return view::Diagonal(x, offset, p_dim1, p_dim2); } else { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("offset"); attrs.SetAllAttrs(offset); std::vector input_index{p_dim1, p_dim2}; for (int32_t i = 0; i < ndims; i++) { if (i != p_dim1 && i != p_dim2) { input_index.push_back(i); } } std::shared_ptr d_x = JUST(Transpose(x, 
input_index)); return OpInterpUtil::Dispatch(*op_, {d_x}, attrs); } } private: std::shared_ptr op_; }; class DiagonalGradFunctor { public: DiagonalGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("diagonal_grad").Input("dy").Input("in").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const int32_t& offset) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("offset"); attrs.SetAllAttrs(offset); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; // Only for ddp gradient grouping class SliceView1dContiguousFunctor { public: SliceView1dContiguousFunctor() = default; Maybe operator()(const std::shared_ptr& x, int64_t start, int64_t end) const { if (view::IsViewApplicable(x)) { return JUST(view::Slice(x, {start}, {end}, {1})); } return JUST(functional::Slice(x, {start}, {end}, {1}, /*enable_view_slice=*/true)); } }; class TensorGetItemFunctor { public: TensorGetItemFunctor() {} Maybe operator()(const std::shared_ptr& x, const TensorIndex& index) const { if (x->is_local() && !(LazyMode::is_enabled()) && x->requires_grad() == false && index.size() == 1 && index[0].IsInteger()) { // NOTE: speed up in special case, e.g. dataloader(refer to torch) // function call chain of pytorch : tensor getitem -> select -> as_strided // function call chain of oneflow : tensor getitem -> as_strided return ApplySelectIndexing(x, index); } std::vector slice_indices; TensorTuple tensor_indices; std::vector target_dims; std::vector expand_dims; JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &expand_dims, &target_dims)); auto expand_input = x; for (int i = 0; i < expand_dims.size(); ++i) { int64_t dim = expand_dims.at(i); expand_input = JUST(functional::ExpandDims(expand_input, dim + i)); } int64_t ndims = expand_input->shape()->NumAxes(); CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << Error::RuntimeError() << "Failed to prepare slice indices."; Shape target_shape(DimVector(target_dims.begin(), target_dims.end())); std::vector start(ndims), end(ndims), step(ndims); for (int i = 0; i < ndims; ++i) { const auto& slice = slice_indices.at(i); start[i] = slice.start(); end[i] = slice.end(); step[i] = slice.step(); } bool is_identity = [&]() { if (target_shape.NumAxes() == 0) { return false; } for (int i = 0; i < ndims; ++i) { if (start[i] != 0 || end[i] != expand_input->shape()->At(i) || step[i] != 1) { return false; } } return true; }(); std::shared_ptr result; if (is_identity) { result = expand_input; } else { result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/true)); } Shape shape(DimVector(target_dims.begin(), target_dims.end())); if (shape != *(result->shape())) { result = JUST(Reshape(result, shape)); } if (!tensor_indices.empty()) { JUST(UnifyInputAndIndicesOnDevice(x, tensor_indices)); result = JUST(ApplyAdvancedIndexing(result, tensor_indices)); } return result; } }; class TensorSetItemFunctor { public: TensorSetItemFunctor() {} Maybe operator()(const std::shared_ptr& x, const TensorIndex& index, const std::shared_ptr& value) const { std::vector slice_indices; TensorTuple tensor_indices; std::vector expand_dims; std::vector target_dims; JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &expand_dims, &target_dims)); auto expand_input = x; if (!expand_dims.empty()) { CHECK_OR_RETURN(view::IsViewApplicable(x)) << "expand dims must enable view, " "please try to set ONEFLOW_DISABLE_VIEW=0"; for (int i = 0; i < expand_dims.size(); ++i) { int64_t dim 
= expand_dims[i]; expand_input = JUST(functional::ExpandDims(expand_input, dim + i)); } } int64_t ndims = expand_input->shape()->NumAxes(); CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << Error::RuntimeError() << "Failed to prepare slice indices."; Shape target_shape(DimVector(target_dims.begin(), target_dims.end())); if (target_shape.Count(0) == 0) { return Maybe::Ok(); } const auto& value_shape = value->shape(); bool matched = [&]() { for (int i = 0; i < value_shape->NumAxes() - target_shape.NumAxes(); ++i) { if (value_shape->At(i) != 1) { return false; } } return true; }(); CHECK_OR_RETURN(matched) << Error::RuntimeError() << "The tensor size mismatch. Target sizes: " << target_shape.ToString() << ", value sizes: " << value_shape->ToString(); std::shared_ptr value_tensor(value); // TODO: replace reshape by unsqueeze with view mechanism. // after here, each scalar tensor will be one with one dimension. for (auto& tensor : tensor_indices) { if (tensor && tensor->ndim() == 0) { tensor = JUST(functional::Reshape(tensor, Shape({1}))); } } DimVector slice_dims(ndims); std::vector start(ndims), end(ndims), step(ndims); for (int i = 0; i < ndims; ++i) { const auto& slice = slice_indices[i]; start[i] = slice.start(); end[i] = slice.end(); step[i] = slice.step(); slice_dims[i] = (end[i] - start[i] + step[i] - 1) / step[i]; } if (tensor_indices.empty()) { Shape slice_shape(slice_dims); if (slice_shape != *(value_tensor->shape())) { // NOTE: // 1. The value shape must can be broadcasted to the target shape. // 2. The slice shape must have equal element count with the target shape. // // So, we should be expand to target_shape and then reshape to slice_shape. // // For example: // x = flow.rand(2, 3, 4) // y = flow.rand(3) // x[:, :, 1] = y // // value_shape = (3,), target_shape = (2, 3), slice_shape = (2, 3, 1) // We must change value shape to slice_shape if it uses SliceUpdate op. if (target_shape != *(value_tensor->shape()) && target_shape.NumAxes() > 0) { value_tensor = JUST(Expand(value_tensor, target_shape)); } if (slice_shape != *(value_tensor->shape())) { value_tensor = JUST(Reshape(value_tensor, slice_shape)); } } JUST(SliceUpdate(expand_input, value_tensor, start, end, step, /*inplace=*/true)); } else { bool is_identity = [&]() { if (target_shape.NumAxes() == 0) { return false; } for (int i = 0; i < ndims; ++i) { if (start[i] != 0 || end[i] != expand_input->shape()->At(i) || step[i] != 1) { return false; } } return true; }(); std::shared_ptr result; if (is_identity) { result = expand_input; } else { if (expand_input->is_local()) { CHECK_OR_RETURN(view::IsViewApplicable(expand_input)) << "combined slice setitem must enable view, please try to set " "ONEFLOW_DISABLE_VIEW=0"; result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/true)); } else { // global tensor result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/false)); } } const Shape& slice_result_shape = *(result->shape()); if (target_shape != slice_result_shape) { result = JUST(functional::View(result, target_shape)); } JUST(UnifyInputAndIndicesOnDevice(result, tensor_indices)); result = JUST(ApplyAdvancedIndexingUpdate(result, tensor_indices, value)); // Write the sliced tensor back to the original tensor. if (result->is_global()) { if (*result->shape() != slice_result_shape) { CHECK_EQ_OR_RETURN(result->shape()->elem_cnt(), slice_result_shape.elem_cnt()) << Error::RuntimeError() << "The global tensor size mismatch. 
Target sizes: " << slice_result_shape.ToString() << ", value sizes: " << result->shape()->ToString(); result = JUST(functional::View(result, slice_result_shape)); } JUST(SliceUpdate(expand_input, result, start, end, step, /*inplace=*/true)); } } return Maybe::Ok(); } }; class CastLikeFunctor { public: CastLikeFunctor() { op_ = CHECK_JUST( one::OpBuilder("cast_like").Input("in").Input("dtype_like").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& like) const { return OpInterpUtil::Dispatch(*op_, {x, like}); } private: std::shared_ptr op_; }; class ElementwiseMinimumGradFunctor { public: ElementwiseMinimumGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("elementwise_minimum_backward") .Input("dz") .Input("x") .Input("y") .Output("dx") .Output("dy") .Build()); } Maybe operator()(const std::shared_ptr& dz, const std::shared_ptr& x, const std::shared_ptr& y) const { return OpInterpUtil::Dispatch(*op_, {dz, x, y}); } private: std::shared_ptr op_; }; class ElementwiseMaximumGradFunctor { public: ElementwiseMaximumGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("elementwise_maximum_backward") .Input("dz") .Input("x") .Input("y") .Output("dx") .Output("dy") .Build()); } Maybe operator()(const std::shared_ptr& dz, const std::shared_ptr& x, const std::shared_ptr& y) const { return OpInterpUtil::Dispatch(*op_, {dz, x, y}); } private: std::shared_ptr op_; }; class DivGradFunctor { public: DivGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_div_grad") .Input("dz") .Input("z") .Input("y") .Output("dy") .Build()); } Maybe operator()(const std::shared_ptr& dz, const std::shared_ptr& z, const std::shared_ptr& y) const { return OpInterpUtil::Dispatch(*op_, {dz, z, y}); } private: std::shared_ptr op_; }; class BroadcastPowXGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, const std::shared_ptr& dz) const { auto y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false)); auto result = functional::sequence_function(functional::BroadcastPow) .then(std::bind(functional::Mul, std::placeholders::_1, y)) .then(std::bind(functional::Mul, std::placeholders::_1, dz)) .then(std::bind(functional::BroadcastReduceSumLike, std::placeholders::_1, x)) .call(x, y_sub_one); return result; } }; class BroadcastPowYGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, const std::shared_ptr& dz) const { auto result = functional::sequence_function(functional::BroadcastPow) .then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Log(x)))) .then(std::bind(functional::Mul, std::placeholders::_1, dz)) .then(std::bind(functional::BroadcastReduceSumLike, std::placeholders::_1, y)) .call(x, y); return result; } }; class IdentityFunctor { public: IdentityFunctor() { op_ = CHECK_JUST(one::OpBuilder("identity").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in) const { return OpInterpUtil::Dispatch(*op_, {in}); } private: std::shared_ptr op_; }; class AmpWhiteIdentityFunctor { public: AmpWhiteIdentityFunctor() { op_ = CHECK_JUST(one::OpBuilder("amp_white_identity").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in) const { return OpInterpUtil::Dispatch(*op_, {in}); } private: std::shared_ptr op_; }; class AmpBlackIdentityFunctor { public: AmpBlackIdentityFunctor() { op_ = CHECK_JUST(one::OpBuilder("amp_black_identity").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in) const { return 
OpInterpUtil::Dispatch(*op_, {in}); } private: std::shared_ptr op_; }; class ReduceSumLikeFunctor { public: ReduceSumLikeFunctor() { op_ = CHECK_JUST(one::OpBuilder("reduce_sum_like").Input("x").Input("like").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& like, const std::vector& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch(*op_, {x, JUST(like->detach())}, attrs); } private: std::shared_ptr op_; }; class BroadcastReduceSumLikeFunctor { public: BroadcastReduceSumLikeFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& like) const { const auto& in_shape = *(input->shape()); const auto& like_shape = *(like->shape()); if (in_shape != like_shape) { const Shape& left_extended_shape = CreateLeftExtendedShape(ShapeView(like_shape), in_shape.NumAxes()); if (in_shape == left_extended_shape) { return JUST(ReshapeLike(input, like)); } else { const AxisVector& broadcast_axis_vec = left_extended_shape.Axes4BroadcastTo(in_shape); return JUST(ReduceSumLike( input, like, std::vector{broadcast_axis_vec.begin(), broadcast_axis_vec.end()})); } } return JUST(Identity(input)); } }; class SplitFunctor { public: SplitFunctor() {} Maybe operator()(const std::shared_ptr& x, const int64_t& split_size_or_sections, const int64_t& dim) const { int64_t axis = dim; axis = JUST(maybe_wrap_dim(axis, x->ndim())); CHECK_GE_OR_RETURN(split_size_or_sections, 0) << Error::RuntimeError() << "split expects split_size be non-negative, but got split_size=" << split_size_or_sections; int64_t dim_size = x->shape()->At(axis); int64_t num_splits = std::max((dim_size + split_size_or_sections - 1) / split_size_or_sections, 1); TensorTuple splits(num_splits); int64_t last_split_size = split_size_or_sections - (split_size_or_sections * num_splits - dim_size); for (int i = 0; i < num_splits; ++i) { int64_t length = i < num_splits - 1 ? 
split_size_or_sections : last_split_size; splits[i] = JUST(Narrow(x, axis, i * split_size_or_sections, length)); } return splits; } }; class UnbindFunctor { public: UnbindFunctor() {} Maybe operator()(const std::shared_ptr& x, const int64_t& dim) const { int32_t axis = dim; const int32_t ndim = x->ndim(); axis = JUST(maybe_wrap_dim(axis, ndim)); int32_t dim_size = x->shape()->At(axis); std::shared_ptr chunk_res = JUST(functional::Chunk(x, dim_size, axis)); TensorTuple unbinds(dim_size); std::vector dims = {axis}; for (int i = 0; i < dim_size; ++i) { unbinds[i] = JUST(functional::Squeeze((*chunk_res)[i], dims)); } return unbinds; } }; class ChunkFunctor { public: ChunkFunctor() {} Maybe operator()(const std::shared_ptr& x, const int64_t& chunks, const int64_t& dim) const { const int64_t ndim = x->ndim(); int64_t infferd_dim = dim; CHECK_OR_RETURN(ndim > 0) << Error::RuntimeError() << "chunk expects at least a 1-dimensional tensor."; CHECK_OR_RETURN(chunks > 0) << Error::RuntimeError() << "chunk expects `chunks` to be greater than 0, got: " << chunks; infferd_dim = JUST(maybe_wrap_dim(infferd_dim, ndim)); const auto dim_size = x->shape()->At(infferd_dim); int64_t split_size = (dim_size + chunks - 1) / chunks; if (split_size == 0 && dim_size == 0) { std::vector split_sizes(chunks, split_size); split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size); return functional::SplitWithSize(x, split_sizes, infferd_dim); } else { return functional::Split(x, split_size, infferd_dim); } } }; class SplitLikeFunctor { public: SplitLikeFunctor() { ops_.resize(kMaxInputCount); for (int n = 1; n < ops_.size(); ++n) { ops_[n] = CHECK_JUST(one::OpBuilder("split_like") .Input("in") .Input("like", n + 1) .Output("out", n + 1) .Build()); } } Maybe operator()(const std::shared_ptr& x, const TensorTuple& like, const int64_t& axis) const { CHECK_GE_OR_RETURN(like.size(), 2) << Error::RuntimeError() << "like.size() must not less than 2, but got " << like.size(); CHECK_LE_OR_RETURN(like.size(), kMaxInputCount) << Error::RuntimeError() << "like.size() must not greater than " << kMaxInputCount << ", but got " << like.size(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); TensorTuple inputs(like.size() + 1); inputs[0] = x; for (int i = 0; i < like.size(); ++i) { inputs[i + 1] = JUST(like[i]->detach()); } return OpInterpUtil::Dispatch(*ops_.at(like.size() - 1), inputs, attrs); } private: std::vector> ops_; }; class SplitWithSizeFunctor { public: SplitWithSizeFunctor() {} Maybe operator()(const std::shared_ptr& x, const std::vector& split_size_or_sections, const int64_t& dim) const { int64_t axis = dim; axis = JUST(maybe_wrap_dim(axis, x->ndim())); int64_t dim_size = x->shape()->At(axis); int64_t num_splits = split_size_or_sections.size(); TensorTuple splits(num_splits); int64_t start_idx = 0; for (int i = 0; i < num_splits; ++i) { int64_t length = split_size_or_sections[i]; CHECK_GE_OR_RETURN(length, 0) << Error::RuntimeError() << "split_with_sizes expects split_sizes have only " "non-negative entries, but split_sizes[" << i << "] = " << length; splits[i] = JUST(Narrow(x, axis, start_idx, length)); start_idx += length; } CHECK_EQ_OR_RETURN(start_idx, dim_size) << Error::RuntimeError() << "split_with_sizes expects split_sizes to sum exactly to " << dim_size << " (input tensor's size at dimension " << axis << "), " << "but got sum(split_sizes)=" << start_idx; return splits; } }; class BatchGatherFunctor { public: BatchGatherFunctor() { op_ = CHECK_JUST( 
one::OpBuilder("batch_gather").Input("in").Input("indices").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& indices) const { return OpInterpUtil::Dispatch(*op_, {in, indices}); } protected: std::shared_ptr op_; }; class UnsortedBatchSegmentSumFunctor { public: UnsortedBatchSegmentSumFunctor() { op_ = CHECK_JUST(one::OpBuilder("unsorted_batch_segment_sum") .Input("data") .Input("segment_ids") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& data, const std::shared_ptr& segment_ids, const int64_t& num_segments) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_segments"); attrs.SetAllAttrs(num_segments); return OpInterpUtil::Dispatch(*op_, {data, segment_ids}, attrs); } protected: std::shared_ptr op_; }; template class MaskedFillFunctor { public: MaskedFillFunctor() { op_ = CHECK_JUST(one::OpBuilder("masked_fill").Input("x").Input("mask").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& mask, const Scalar& value) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand", "bool_operand", "has_bool_operand"); if (IsFloatingDataType(x->dtype()->data_type())) { attrs.SetAllAttrs(value.As(), true, NullOpt, false, NullOpt, false); } else if (IsIntegralDataType(x->dtype()->data_type())) { attrs.SetAllAttrs(NullOpt, false, value.As(), true, NullOpt, false); } else if (IsBoolDataType(x->dtype()->data_type())) { attrs.SetAllAttrs(NullOpt, false, NullOpt, false, value.As(), true); } else { UNIMPLEMENTED_THEN_RETURN() << "Only support floating or integral data type."; } const auto& x_shape = *(x->shape()); const auto& mask_shape = *(mask->shape()); std::shared_ptr outputs = std::make_shared(1); if (inplace) { JUST(CheckInplaceValid(x)); (*outputs)[0] = x; } if (x_shape != mask_shape) { Shape max_shape = Shape::Ones(std::max(x_shape.NumAxes(), mask_shape.NumAxes())); const Shape& x_extend_shape = CreateLeftExtendedShape(ShapeView(x_shape), max_shape.NumAxes()); const Shape& mask_extend_shape = CreateLeftExtendedShape(ShapeView(mask_shape), max_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) { max_shape.Set(i, std::max(x_extend_shape.At(i), mask_extend_shape.At(i))); } JUST(OpInterpUtil::Dispatch(*op_, {JUST(Expand(x, max_shape)), JUST(Expand(mask, max_shape))}, outputs.get(), attrs)); return outputs->at(0); } JUST(OpInterpUtil::Dispatch(*op_, {x, mask}, outputs.get(), attrs)); return outputs->at(0); } private: std::shared_ptr op_; }; class MeshgridFunctor { public: Maybe operator()(const TensorTuple& tensors, const std::string& indexing) const { int size = tensors.size(); CHECK_GT_OR_RETURN(size, 0) << Error::RuntimeError() << "meshgrid expects a non-empty TensorList"; for (int i = 0; i < size - 1; ++i) { const auto& cur_tensor = JUST(VectorAt(tensors, i)); const auto& next_tensor = JUST(VectorAt(tensors, i + 1)); CHECK_OR_RETURN(cur_tensor->dtype() == next_tensor->dtype()) << Error::RuntimeError() << "meshgrid expects all tensors to have the same dtype"; if (cur_tensor->is_local()) { CHECK_OR_RETURN(next_tensor->is_local()) << Error::RuntimeError() << "meshgrid expects all tensors are local tensor"; CHECK_OR_RETURN(JUST(cur_tensor->device())->type() == JUST(next_tensor->device())->type()) << Error::RuntimeError() << "meshgrid expects all tensors to have the same device"; } else { CHECK_OR_RETURN(!next_tensor->is_local()) << Error::RuntimeError() << "meshgrid expects all tensors are global tensor"; 
CHECK_OR_RETURN(JUST(cur_tensor->parallel_desc()) == JUST(next_tensor->parallel_desc())) << Error::RuntimeError() << "meshgrid expects all tensors to have the same placement"; } } std::vector> tensor_consts(tensors.begin(), tensors.end()); bool swap_first_and_second_tensors = false; if (indexing == "xy") { swap_first_and_second_tensors = (size >= 2); if (swap_first_and_second_tensors) { std::swap(tensor_consts[0], tensor_consts[1]); } } else { CHECK_EQ_OR_RETURN(indexing, "ij") << Error::RuntimeError() << "meshgrid: indexing must be one of \"xy\" or \"ij\", " "but received: " << indexing; } TensorTuple grids(size); DimVector grids_vec(size); for (int i = 0; i < size; ++i) { CHECK_LE_OR_RETURN(tensor_consts[i]->shape()->NumAxes(), 1) << Error::RuntimeError() << "Expected scalar or 1D tensor in the tensor list but got " << tensor_consts[i]->shape()->NumAxes(); if (tensor_consts[i]->shape()->NumAxes() == 0) { grids_vec[i] = 1; } else { grids_vec[i] = tensor_consts[i]->shape()->At(0); } } Shape grids_shape(grids_vec); DimVector view_shape_vec(size, 1); Shape view_shape(view_shape_vec); for (int i = 0; i < size; ++i) { view_shape.Set(i, -1); std::shared_ptr reshaped = JUST(Reshape(tensor_consts.at(i), view_shape)); grids[i] = JUST(Expand(reshaped, grids_shape)); view_shape.Set(i, 1); } if (swap_first_and_second_tensors) { std::swap(grids[0], grids[1]); } return grids; } }; class IndexSelectFunctor { public: Maybe operator()(const std::shared_ptr& input, const int64_t& dim, const std::shared_ptr& index) const { const int64_t input_num_axes = input->shape()->NumAxes(); const int64_t index_num_axes = index->shape()->NumAxes(); CHECK_LE_OR_RETURN(index_num_axes, 1) << Error::IndexError() << "index_select(): Index is supposed to be a vector"; bool index_dtype_flag = (index->dtype()->data_type() == kInt32) || (index->dtype()->data_type() == kInt64); CHECK_EQ_OR_RETURN(index_dtype_flag, true) << Error::RuntimeError() << "index_select(): Expected dtype int32 or int64 for index"; int64_t new_dim = dim; new_dim = JUST(maybe_wrap_dim(new_dim, input_num_axes)); return JUST(functional::Gather(input, index, new_dim)); } }; namespace { Maybe LocalTensorTo(const std::shared_ptr& x, Symbol device, const Symbol& dtype, const bool& copy) { std::shared_ptr tensor = x; if (device != JUST(x->device())) { tensor = JUST(Copy(tensor, device, /*pin_memory=*/false)); } if (dtype != x->dtype()) { tensor = JUST(Cast(tensor, dtype, /*pin_memory=*/false)); } if (copy && tensor == x) { tensor = JUST(Copy(tensor, device, /*pin_memory=*/false)); } return tensor; } Maybe GlobalTensorTo(const std::shared_ptr& x, const std::string& device_type, const Symbol& dtype, const bool& copy) { std::shared_ptr tensor; auto input_placement = JUST(x->parallel_desc()); std::string input_device_tag = input_placement->device_tag(); if (input_device_tag == "gpu") { input_device_tag = "cuda"; } if (device_type == input_device_tag) { if (dtype == x->dtype()) { return (copy ? JUST(x->clone()) : x); } else { return JUST(Cast(x, dtype, /*pin_memory=*/false)); } } if (LazyMode::is_enabled()) { if (dtype != x->dtype()) { tensor = JUST(Cast(x, dtype, /*pin_memory=*/false)); } if (device_type != JUST(x->parallel_desc())->device_tag()) { tensor = JUST(Copy(tensor ? 
tensor : x, device_type, 0, /*pin_memory=*/false)); } return tensor; } else { CheckMetaConsistency(x).GetOrThrow(); auto placement = JUST(ReplacePlacementDeviceTag(input_placement, device_type)); auto nd_sbp = JUST(x->nd_sbp()); std::vector> sbp_tuple(nd_sbp->sbp_parallel().size()); for (int i = 0; i < sbp_tuple.size(); ++i) { sbp_tuple[i] = nd_sbp->sbp_parallel().Get(i); } tensor = JUST(GlobalToLocal(x, /*copy=*/false)); Symbol device = JUST(Device::New(device_type)); tensor = JUST(LocalTensorTo(tensor, device, dtype, copy)); JUST(tensor->set_requires_grad(x->requires_grad())); return JUST(LocalToGlobal(tensor, placement, sbp_tuple, *(x->shape()), dtype, /* sync_data */ true, /*copy=*/false)); } } } // namespace class ToFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional& device_, const Optional>& dtype_, bool copy) const { Symbol dtype = dtype_.value_or(input->dtype()); if (input->is_global()) { std::string device_type = device_.value_or(JUST(input->parallel_desc())->device_tag()); CHECK_OR_RETURN(ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device_type) != DeviceType::kInvalidDevice) << Error::RuntimeError() << "Only string device without device id (eg. \"cpu\" or \"cuda\") is expected " << "for global tensor, but got " << device_.value_or(""); return JUST(GlobalTensorTo(input, device_type, dtype, copy)); } else { Symbol device = device_ .map([](const std::shared_ptr& str) -> Symbol { return CHECK_JUST(Device::ParseAndNew(*str)); }) .value_or(JUST(input->device())); return JUST(LocalTensorTo(input, device, dtype, copy)); } } }; class To2Functor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& device_, const Optional>& dtype_, bool copy) const { if (input->is_global()) { if (!device_.has_value()) { std::string device_type = JUST(input->parallel_desc())->device_tag(); return JUST(GlobalTensorTo(input, device_type, dtype_.value_or(input->dtype()), copy)); } else { if (!GlobalMode::is_enabled()) { CHECK_OR_RETURN(!device_.has_value()) << Error::RuntimeError() << "Only string device without device id (eg. 
\"cpu\" or \"cuda\") is expected " << "for global tensor, but got " << device_.value_or(Symbol())->ToRepr(); } std::string device_type = device_.value_or(Symbol())->type(); return JUST(GlobalTensorTo(input, device_type, dtype_.value_or(input->dtype()), copy)); } } else { auto dtype = dtype_.value_or(input->dtype()); auto device = device_.value_or(JUST(input->device())); return JUST(LocalTensorTo(input, device, dtype, copy)); } } }; class To3Functor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& dtype_, bool copy) const { Symbol dtype = dtype_.value_or(input->dtype()); if (input->is_global()) { return GlobalTensorTo(input, JUST(input->parallel_desc())->device_tag(), dtype, copy); } else { auto device = JUST(input->device()); return LocalTensorTo(input, device, dtype, copy); } } }; class To4Functor { public: Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& other, bool copy) const { CHECK_OR_RETURN(!input->is_global() && !other->is_global()) << Error::RuntimeError() << "tensor.to(other) can only be called when tensor and other are local tensors"; Symbol dtype = other->dtype(); Symbol device = JUST(other->device()); return LocalTensorTo(input, device, dtype, copy); } }; class ToDeviceFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional& device_) const { Symbol dtype = input->dtype(); const bool copy = false; if (input->is_global()) { std::string device_type = device_.value_or(JUST(input->parallel_desc())->device_tag()); CHECK_OR_RETURN(ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device_type) != DeviceType::kInvalidDevice) << Error::RuntimeError() << "Only string device without device id (eg. \"cpu\" or \"cuda\") is expected " << "for global tensor, but got " << device_.value_or(""); return JUST(GlobalTensorTo(input, device_type, dtype, copy)); } else { Symbol device = device_ .map([](const std::shared_ptr& str) -> Symbol { return CHECK_JUST(Device::ParseAndNew(*str)); }) .value_or(JUST(input->device())); return JUST(LocalTensorTo(input, device, dtype, copy)); } } }; class ToMemoryFormatFunctor { public: ToMemoryFormatFunctor() { op_ = CHECK_JUST(one::OpBuilder("convert_memory_format").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, MemoryFormat memory_format) const { if (input->memory_format() == memory_format) { return input; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("memory_format"); attrs.SetAllAttrs(memory_format); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class TopKFunctor { public: TopKFunctor() { op_ = CHECK_JUST(one::OpBuilder("top_k").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const int32_t k, const Optional& dim, const bool largest, const bool sorted) const { auto outputs = std::make_shared(2); std::shared_ptr values; std::shared_ptr indices; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("k", "sorted"); attrs.SetAllAttrs(k, sorted); int32_t dim_value = dim.value_or(-1); int32_t axis = dim_value; axis = JUST(maybe_wrap_dim(axis, input->ndim())); if (axis == input->ndim() - 1) { if (largest) { indices = JUST(OpInterpUtil::Dispatch(*op_, {input}, attrs)); } else { auto neg_input = JUST(ScalarMul(input, -1, false)); indices = JUST(OpInterpUtil::Dispatch(*op_, {neg_input}, attrs)); } values = JUST(DimGather(input, axis, indices, false)); } else { auto perm = JUST(GetPermWhenTransposeAxisToLastDim(input->ndim(), dim_value)); auto x = JUST(Transpose(input, *perm)); if (largest) 
{ indices = JUST(OpInterpUtil::Dispatch(*op_, {x}, attrs)); } else { auto neg_input = JUST(ScalarMul(x, -1, false)); indices = JUST(OpInterpUtil::Dispatch(*op_, {neg_input}, attrs)); } auto inversed_perm = JUST(GetInversedPerm(*perm)); indices = JUST(Transpose(indices, *inversed_perm)); values = JUST(DimGather(input, axis, indices, false)); } (*outputs)[0] = values; (*outputs)[1] = indices; return outputs; } private: std::shared_ptr op_; }; class InTopKFunctor { public: InTopKFunctor() { op_ = CHECK_JUST( one::OpBuilder("in_top_k").Input("targets").Input("predictions").Output("out").Build()); } Maybe operator()(const std::shared_ptr& targets, const std::shared_ptr& predictions, int32_t k) const { CHECK_EQ_OR_RETURN(targets->shape()->At(0), predictions->shape()->At(0)) << Error::RuntimeError() << "The num of targets must equal the num of predictions"; CHECK_EQ_OR_RETURN(targets->ndim(), 1) << Error::RuntimeError() << "The dimension of targets must be 1"; CHECK_EQ_OR_RETURN(predictions->ndim(), 2) << Error::RuntimeError() << "The dimension of predictions must be 2"; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("k"); attrs.SetAllAttrs(k); return OpInterpUtil::Dispatch(*op_, {targets, predictions}, attrs); } private: std::shared_ptr op_; }; class TensorBufferToTensorFunctor { public: TensorBufferToTensorFunctor() { op_ = CHECK_JUST(one::OpBuilder("tensor_buffer_to_tensor").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const Shape& instance_shape, const Symbol& dtype) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("instance_shape", "dtype"); attrs.SetAllAttrs(instance_shape, dtype->data_type()); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class TensorToTensorBufferFunctor { public: TensorToTensorBufferFunctor() { op_ = CHECK_JUST(one::OpBuilder("tensor_to_tensor_buffer").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, int32_t instance_dims) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("instance_dims"); attrs.SetAllAttrs(instance_dims); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class GenTensorBufferFunctor { public: GenTensorBufferFunctor() { op_ = CHECK_JUST(one::OpBuilder("gen_tensor_buffer").Output("out").Build()); } Maybe operator()(const Shape& shape, const std::vector& shape_list, const std::vector& value_list, const Symbol& dtype, bool dynamic_out) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "shape_list", "value_list", "data_type", "dynamic_out"); attrs.SetAllAttrs(shape, shape_list, value_list, dtype->data_type(), dynamic_out); return OpInterpUtil::Dispatch(*op_, {}, attrs); } private: std::shared_ptr op_; }; class RepeatFunctor { public: RepeatFunctor() {} Maybe operator()(const std::shared_ptr& input, const Shape& repeat_shape) const { Shape input_shape = *(input->shape()); std::vector input_reshape_vec; std::vector expand_shape_vec; std::vector output_reshape_vec; int32_t numaxes_diff = repeat_shape.NumAxes() - input_shape.NumAxes(); CHECK_GE_OR_RETURN(numaxes_diff, 0) << Error::RuntimeError() << "Number of dimensions of repeat dims can not be " "smaller than number of dimensions of tensor"; for (int32_t i = repeat_shape.NumAxes() - 1; i >= 0; i--) { if (i >= numaxes_diff) { int32_t input_shape_val = input_shape.At(i - numaxes_diff); int32_t repeat_shape_val = repeat_shape.At(i); if (repeat_shape_val > 1) { if (input_shape_val > 1) { input_reshape_vec.insert(input_reshape_vec.begin(), 
input_shape_val); input_reshape_vec.insert(input_reshape_vec.begin(), 1); expand_shape_vec.insert(expand_shape_vec.begin(), input_shape_val); expand_shape_vec.insert(expand_shape_vec.begin(), repeat_shape_val); output_reshape_vec.insert(output_reshape_vec.begin(), repeat_shape_val * input_shape_val); } else { input_reshape_vec.insert(input_reshape_vec.begin(), input_shape_val); expand_shape_vec.insert(expand_shape_vec.begin(), repeat_shape_val); output_reshape_vec.insert(output_reshape_vec.begin(), repeat_shape_val); } } else { input_reshape_vec.insert(input_reshape_vec.begin(), input_shape_val); // For 0-size tensor, align with PyTorch. if (repeat_shape_val == 0) { expand_shape_vec.insert(expand_shape_vec.begin(), 0); output_reshape_vec.insert(output_reshape_vec.begin(), 0); } else { expand_shape_vec.insert(expand_shape_vec.begin(), input_shape_val); output_reshape_vec.insert(output_reshape_vec.begin(), input_shape_val); } } } else { expand_shape_vec.insert(expand_shape_vec.begin(), repeat_shape.At(i)); output_reshape_vec.insert(output_reshape_vec.begin(), repeat_shape.At(i)); } } Shape input_reshape(DimVector(input_reshape_vec.begin(), input_reshape_vec.end())); Shape expand_shape(DimVector(expand_shape_vec.begin(), expand_shape_vec.end())); Shape output_reshape(DimVector(output_reshape_vec.begin(), output_reshape_vec.end())); std::shared_ptr reshaped_tensor = JUST(Reshape(input, input_reshape)); std::shared_ptr expanded_tensor = JUST(Expand(reshaped_tensor, expand_shape)); std::shared_ptr result = JUST(Reshape(expanded_tensor, output_reshape)); return result->contiguous(); } }; class RepeatInterLeaveIndexFunctor { public: RepeatInterLeaveIndexFunctor() { op_ = CHECK_JUST( one::OpBuilder("repeat_interleave").Input("in").Input("cumsum").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& cumsum, const int32_t& repeat_num) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("repeat_num"); attrs.SetAllAttrs(static_cast(repeat_num)); return OpInterpUtil::Dispatch(*op_, {input, cumsum}, attrs); } private: std::shared_ptr op_; }; class RepeatInterLeaveIntFunctor { public: RepeatInterLeaveIntFunctor() {} Maybe operator()(const std::shared_ptr& input, const int32_t& repeats, const Optional& dim) const { CHECK_OR_RETURN(input->is_local() == true) << Error::RuntimeError() << "repeat_interleave only support local tensor now"; std::shared_ptr res; if (!dim.has_value()) { std::shared_ptr flatten_input = JUST(Flatten(input, 0, -1)); std::shared_ptr repeats_expand = JUST( Expand(JUST(Constant(Shape{1}, Scalar(repeats), DType::Int32(), JUST(input->device()))), Shape{flatten_input->shape()->At(0)})); std::shared_ptr cumsum = JUST(Cumsum(repeats_expand, 0, DType::Int32())); int64_t output_size = flatten_input->shape()->At(0); if (repeats > 0) { output_size *= repeats; } res = JUST(IndexSelect(flatten_input, 0, JUST(RepeatInterLeaveIndex(repeats_expand, cumsum, output_size)))); } else { int32_t dim_ = JUST(dim); const auto& input_shape = input->shape(); const int64_t& num_axes = input_shape->NumAxes(); dim_ = JUST(maybe_wrap_dim(dim_, num_axes)); std::shared_ptr repeats_expand = JUST( Expand(JUST(Constant(Shape{1}, Scalar(repeats), DType::Int32(), JUST(input->device()))), Shape{input->shape()->At(dim_)})); std::shared_ptr cumsum = JUST(Cumsum(repeats_expand, 0, DType::Int32())); int64_t output_size = input->shape()->At(dim_); if (repeats > 0) { output_size *= repeats; } res = JUST(IndexSelect(input, dim_, JUST(RepeatInterLeaveIndex(repeats_expand, cumsum, 
output_size)))); } return res; } }; class RepeatInterLeaveTensorFunctor { public: RepeatInterLeaveTensorFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& repeats, const int32_t& dim, const Optional& output_size) const { CHECK_OR_RETURN(input->is_local() == true) << Error::RuntimeError() << "repeat_interleave only support local tensor now"; const auto repeats_shape = repeats->shape(); const int64_t& repeat_num_axes = repeats_shape->NumAxes(); CHECK_OR_RETURN(repeat_num_axes == 1) << Error::RuntimeError() << "repeat_interleave only accept 1D vector as repeat"; CHECK_OR_RETURN(repeats->dtype() == DType::Int64()) << Error::RuntimeError() << "repeats has to be Long tensor"; std::vector repeats_value(repeats_shape->elem_cnt()); if (!output_size.has_value()) { const auto& callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, repeats_value.data(), eager_blob_object->dptr(), repeats_value.size() * sizeof(int64_t), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; SyncAccessTensorWithTimeOut(repeats, callback, "const").GetOrThrow(); for (const auto x : repeats_value) { CHECK_OR_RETURN(x >= 0) << Error::RuntimeError() << "repeats can not be negative"; } } else { repeats_value.push_back(JUST(output_size)); } int32_t dim_ = dim; const auto& input_shape = input->shape(); const int64_t& num_axes = input_shape->NumAxes(); dim_ = JUST(maybe_wrap_dim(dim_, num_axes)); CHECK_OR_RETURN(repeats_shape->At(0) == input->shape()->At(dim_)) << Error::RuntimeError() << "repeats must have the same size as input along dim"; std::shared_ptr cumsum = JUST(Cumsum(repeats, 0, DType::Int32())); const int64_t& output_size_value = std::accumulate(repeats_value.begin(), repeats_value.end(), 0); return JUST( IndexSelect(input, dim_, JUST(RepeatInterLeaveIndex(repeats, cumsum, output_size_value)))); } }; class TileFunctor { public: TileFunctor() {} Maybe operator()(const std::shared_ptr& input, const Shape& dims) const { std::vector new_dims_vec; int32_t numaxes_diff = input->shape()->NumAxes() - dims.NumAxes(); for (int32_t i = dims.NumAxes() - 1; i >= 0; i--) { CHECK_GE_OR_RETURN(dims.At(i), 0) << Error::RuntimeError() << "Trying to create tensor with negative dimension " << dims.At(i); new_dims_vec.insert(new_dims_vec.begin(), dims.At(i)); } for (int32_t i = 0; i < numaxes_diff; i++) { new_dims_vec.insert(new_dims_vec.begin(), 1); } Shape new_dims(DimVector(new_dims_vec.begin(), new_dims_vec.end())); return JUST(Repeat(input, new_dims)); } }; class TransposeAllDimPropertyFunctor { public: TransposeAllDimPropertyFunctor() {} Maybe operator()(const std::shared_ptr& x) const { const int64_t ndim = x->ndim(); std::vector permute; permute.resize(ndim); std::iota(permute.begin(), permute.end(), 0); std::reverse(permute.begin(), permute.end()); return Transpose(x, permute); } }; class TransposeAllDimFunctionFunctor { public: TransposeAllDimFunctionFunctor() {} Maybe operator()(const std::shared_ptr& x) const { const int64_t ndim = x->ndim(); CHECK_OR_RETURN(ndim <= 2) << Error::RuntimeError() << "t() expects a tensor with <= 2 dimensions, but input tensor is " << ndim << "D"; if (ndim == 0 || ndim == 1) { return x; } return Transpose2dim(x, 0, 1); } }; class ReshapeLikeFunctor { public: ReshapeLikeFunctor() { op_ = CHECK_JUST(one::OpBuilder("reshape_like").Input("in").Input("like").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& like) const { return OpInterpUtil::Dispatch(*op_, {x, 
JUST(like->detach())}); } private: std::shared_ptr op_; }; class PinMemoryFunctor { public: PinMemoryFunctor() { op_ = CHECK_JUST(one::OpBuilder("slice_update").Input("ref").Input("value").Output("y").Build()); } Maybe operator()(const std::shared_ptr& input) const { // TODO:(zhaoluyang) support global tensor.pin_memory() CHECK_OR_RETURN(input->is_local() && !(LazyMode::is_enabled())) << Error::RuntimeError() << "Tensor.pin_memory() only support local tensor for now!"; // if tensor already pinned, then just return if (JUST(JUST(input->AsLocalTensor())->is_pinned())) { return input; } auto shape = input->shape(); auto device = JUST(input->device()); const bool requires_grad = input->requires_grad(); CHECK_EQ_OR_RETURN(device->enum_type(), DeviceType::kCPU) << Error::RuntimeError() << "cannot pin tensor with device: " << device->ToString() << ", only dense CPU tensors can be pinned."; auto empty = JUST(functional::Empty(*shape.get(), input->dtype(), device, requires_grad, /*pin_memory=*/true)); const int32_t ndim = input->ndim(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("start", "stop", "step"); if (ndim == 0) { // TODO(wyg): use TensorSetItem after supporting non-requires_grad tensor inplace // for 0-dim tensor empty = JUST(functional::ExpandDims(empty, 0)); // expand to [1, ] auto expand_input = JUST(functional::ExpandDims(input, 0)); // expand to [1, ] attrs.SetAllAttrs(std::vector{0}, std::vector{1}, std::vector{1}); auto outputs = TensorTuple{empty}; JUST(OpInterpUtil::Dispatch(*op_, TensorTuple{empty, expand_input}, &outputs, attrs)); return outputs[0]; } else { std::vector starts(ndim, 0); std::vector stops(ndim); std::vector steps(ndim, 1); for (int i = 0; i < ndim; ++i) { stops[i] = input->shape()->At(i); } attrs.SetAllAttrs(starts, stops, steps); JUST(empty->set_requires_grad(requires_grad)); auto outputs = TensorTuple{empty}; JUST(OpInterpUtil::Dispatch(*op_, TensorTuple{empty, input}, &outputs, attrs)); return outputs[0]; } } private: std::shared_ptr op_; }; class FillFunctor { public: FillFunctor() { op_ = CHECK_JUST(one::OpBuilder("fill_").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, const Scalar& value) const { JUST(CheckInplaceValid(in)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("floating_value", "is_floating_value", "integral_value"); if (IsFloatingDataType(in->dtype()->data_type())) { attrs.SetAllAttrs(value.As(), true, NullOpt); } else if (IsIntegralDataType(in->dtype()->data_type())) { attrs.SetAllAttrs(NullOpt, false, value.As()); } else { UNIMPLEMENTED_THEN_RETURN() << "Only support floating or integral data type."; } auto outputs = std::make_shared(1); (*outputs)[0] = in; JUST(OpInterpUtil::Dispatch(*op_, {in}, outputs.get(), attrs)); return (*outputs)[0]; } private: std::shared_ptr op_; }; class FillTensorFunctor { public: FillTensorFunctor() { op_ = CHECK_JUST(one::OpBuilder("fill_tensor_").Input("in").Input("value").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& value) const { JUST(CheckInplaceValid(in)); const int64_t ndim = value->ndim(); CHECK_EQ_OR_RETURN(ndim, 0) << Error::RuntimeError() << "fill_ only supports 0-dimension value tensor but got tensor with " << ndim << " dimensions."; TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true, in->dtype()) .AddInputs({in, value}) .Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); auto outputs = std::make_shared(1); (*outputs)[0] = in; JUST(OpInterpUtil::Dispatch(*op_, 
{input_tuple[0], input_tuple[1]}, outputs.get())); return (*outputs)[0]; } private: std::shared_ptr op_; }; class IndexAddFunctor { public: IndexAddFunctor() { op_ = CHECK_JUST(one::OpBuilder("index_add") .Input("input") .Input("index") .Input("source") .Output("output") .Build()); } Maybe operator()(const std::shared_ptr& input, const int64_t& dim, const std::shared_ptr& index, const std::shared_ptr& source, const Scalar& alpha) const { CHECK_OR_RETURN(source->ndim() == 0 || index->shape()->Count(0) == source->shape()->At(dim)) << "index_copy_(): Number of indices (," << index->shape()->Count(0) << ", \") should be equal to source.size(dim) (," << source->shape()->At(dim) << ", \")"; CHECK_OR_RETURN(index->dtype()->data_type() != DataType::kInt32 || index->dtype()->data_type() != DataType::kInt64) << "Input(Index) holds the wrong type, it holds " << DataType_Name(index->dtype()->data_type()) << " , but " "desires to be int32_t or int64_t"; const float alpha_value = alpha.As(); int64_t dim_ = dim; dim_ = JUST(maybe_wrap_dim(dim_, input->ndim())); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim", "alpha"); attrs.SetAllAttrs(dim_, alpha_value); TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true, input->dtype()) .AddInputs({input, source}) .Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, {input, index, input_tuple.at(1)}, attrs); } private: std::shared_ptr op_; }; class IndexAddInplaceFunctor { public: IndexAddInplaceFunctor() { op_ = CHECK_JUST(one::OpBuilder("index_add") .Input("input") .Input("index") .Input("source") .Output("output") .Build()); } Maybe operator()(const std::shared_ptr& input, const int64_t& dim, const std::shared_ptr& index, const std::shared_ptr& source, const Scalar& alpha) const { CHECK_OR_RETURN(source->ndim() == 0 || index->shape()->Count(0) == source->shape()->At(dim)) << "index_copy_(): Number of indices (," << index->shape()->Count(0) << ", \") should be equal to source.size(dim) (," << source->shape()->At(dim) << ", \")"; CHECK_OR_RETURN(index->dtype()->data_type() != DataType::kInt32 || index->dtype()->data_type() != DataType::kInt64) << "Input(Index) holds the wrong type, it holds " << DataType_Name(index->dtype()->data_type()) << " , but " "desires to be int32_t or int64_t"; const float alpha_value = alpha.As(); int64_t dim_ = dim; dim_ = JUST(maybe_wrap_dim(dim_, input->ndim())); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim", "alpha"); attrs.SetAllAttrs(dim_, alpha_value); JUST(CheckInplaceValid(input)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = input; TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true, input->dtype()) .AddInputs({input, source}) .Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); JUST(OpInterpUtil::Dispatch(*op_, {input, index, input_tuple.at(1)}, outputs.get(), attrs)); return outputs->at(0); } private: std::shared_ptr op_; }; class BroadcastShapesFunctor { public: Maybe operator()(const std::vector& shapes) const { return InferUnifiedShapeForBroadcasting(shapes); } }; class BroadcastTensorsFunctor { public: Maybe operator()(const TensorTuple& tensors) const { if (tensors.empty()) { return Error::RuntimeError() << "tensors should not be empty."; } Shape shape_to_broadcast; std::deque need_to_broadcast; std::tie(shape_to_broadcast, need_to_broadcast) = *JUST(InferUnifiedShapeForBroadcastingWithInfo([&tensors]() { std::vector shapes; for (auto& x : tensors) { 
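// NOTE(illustrative, added explanation): InferUnifiedShapeForBroadcastingWithInfo returns the
// unified broadcast shape together with a per-tensor flag saying whether that tensor needs
// expanding; only flagged tensors are passed through functional::Expand below, e.g. (assumed
// example) shapes (3, 1) and (1, 4) unify to (3, 4) and both get expanded, while an input that
// already has the unified shape is returned unchanged.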
shapes.push_back(*x->shape()); } return shapes; }())); std::shared_ptr outputs = std::make_shared(); for (size_t i = 0; i < tensors.size(); ++i) { outputs->emplace_back(need_to_broadcast.at(i) // NOLINT ? JUST(functional::Expand(tensors.at(i), shape_to_broadcast)) : tensors.at(i)); } return outputs; } }; class BinCountFunctor { public: BinCountFunctor() { op_ = CHECK_JUST(OpBuilder("bincount").Input("in").Output("out").Build()); weight_op_ = CHECK_JUST(OpBuilder("bincount").Input("in").Input("weight").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const Optional& weight, const Optional& minlength) const { CHECK_OR_RETURN(!input->dtype()->is_floating_point()) << "bincount can only support int tensor"; TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({input}, DType::Int64()).Apply()); const auto x = JUST(tensor_processor.GetInputs()).at(0); std::shared_ptr local_tensor = x; int64_t max = 0; // check min value { if (x->is_global()) { local_tensor = JUST(GlobalToLocal(x, false)); } auto tensor_min = JUST(functional::Min(local_tensor)); int64_t min = 0; const auto& callback_min = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, &min, eager_blob_object->dptr(), sizeof(min), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; JUST(SyncAccessTensorWithTimeOut(tensor_min, callback_min, "const")); CHECK_GE_OR_RETURN(min, 0) << "bincount only supports 1-d non-negative integral inputs."; auto tensor_max = JUST(functional::Max(local_tensor)); const auto& callback_max = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, &max, eager_blob_object->dptr(), sizeof(max), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; JUST(SyncAccessTensorWithTimeOut(tensor_max, callback_max, "const")); max += 1; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("size"); if (minlength) { CHECK_GE_OR_RETURN(JUST(minlength), 0) << "minlength should be >= 0"; max = std::max(JUST(minlength), max); } attrs.SetAllAttrs(max); if (weight) { CHECK_EQ_OR_RETURN(JUST(weight)->nelement(), x->nelement()) << "input and weights should have the same length"; return OpInterpUtil::Dispatch(*weight_op_, {x, JUST(weight)}, attrs); } else { return OpInterpUtil::Dispatch(*op_, {x}, attrs); } } private: std::shared_ptr op_; std::shared_ptr weight_op_; }; class UniqueFunctor { public: UniqueFunctor() { op_ = CHECK_JUST( OpBuilder("unique").Input("x").Output("y").Output("idx").Output("num_unique").Build()); }; Maybe operator()(const std::shared_ptr& x, const bool sorted, const Symbol& dtype) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("out_idx", "sorted"); DataType out_idx = dtype->data_type(); attrs.SetAllAttrs(out_idx, sorted); std::shared_ptr output = JUST( OpInterpUtil::Dispatch(*op_, {JUST(functional::Flatten(x, 0, -1))}, attrs)); int64_t num_unique = 0; std::shared_ptr num_unique_tensor = output->at(2); { if (num_unique_tensor->is_global()) { num_unique_tensor = JUST(GlobalToLocal(num_unique_tensor, false)); } const auto& callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, &num_unique, eager_blob_object->dptr(), GetSizeOfDataType(dtype->data_type()), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; JUST(SyncAccessTensorWithTimeOut(num_unique_tensor, callback, "const")); } return functional::Slice(output->at(0), /*start=*/{0}, /*end=*/{num_unique}, /*step=*/{1}, false); } private: std::shared_ptr op_; }; class 
UniqueWithCountsFunctor { public: UniqueWithCountsFunctor() { unique_op_ = CHECK_JUST( OpBuilder("unique").Input("x").Output("y").Output("idx").Output("num_unique").Build()); unique_with_counts_op_ = CHECK_JUST(OpBuilder("unique_with_counts") .Input("x") .Output("y") .Output("idx") .Output("num_unique") .Output("count") .Build()); }; Maybe operator()(const std::shared_ptr& x, const bool sorted, const bool return_inverse, const bool return_counts, const Symbol& dtype) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("out_idx", "sorted"); attrs.SetAllAttrs(dtype->data_type(), sorted); std::shared_ptr output; if (return_counts) { output = JUST(OpInterpUtil::Dispatch( *unique_with_counts_op_, {JUST(functional::Flatten(x, 0, -1))}, attrs)); } else { output = JUST(OpInterpUtil::Dispatch( *unique_op_, {JUST(functional::Flatten(x, 0, -1))}, attrs)); } int64_t num_unique = 0; std::shared_ptr num_unique_tensor = output->at(2); { if (num_unique_tensor->is_global()) { num_unique_tensor = JUST(GlobalToLocal(num_unique_tensor, false)); } const auto& callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, &num_unique, eager_blob_object->dptr(), GetSizeOfDataType(dtype->data_type()), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; JUST(SyncAccessTensorWithTimeOut(num_unique_tensor, callback, "const")); } auto result = std::make_shared(); const auto& y = JUST( functional::Slice(output->at(0), /*start=*/{0}, /*end=*/{num_unique}, /*step=*/{1}, false)); result->emplace_back(y); if (return_inverse) { result->emplace_back(JUST(functional::Reshape(output->at(1), *x->shape()))); } if (return_counts) { const auto count = JUST(functional::Slice(output->at(3), /*start=*/{0}, /*end=*/{num_unique}, /*step=*/{1}, false)); result->emplace_back(count); } return result; } private: std::shared_ptr unique_op_; std::shared_ptr unique_with_counts_op_; }; class BaddBmmFunctor { public: Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& batch1, const std::shared_ptr& batch2, const double& beta, const double& alpha) const { const int32_t batch1_ndim = batch1->ndim(); const int32_t batch2_ndim = batch2->ndim(); CHECK_EQ_OR_RETURN(batch1_ndim, 3) << Error::RuntimeError() << "batch1 must be a 3D tensor"; CHECK_EQ_OR_RETURN(batch2_ndim, 3) << Error::RuntimeError() << "batch2 must be a 3D tensor"; CHECK_EQ_OR_RETURN(batch1->dim(0), batch2->dim(0)) << Error::RuntimeError() << "batch1 and batch2 must have same number of batches, got ," << batch1->dim(0) << " and " << batch2->dim(0); CHECK_EQ_OR_RETURN(batch1->dim(2), batch2->dim(1)) << "Incompatible matrix sizes for bmm (" << batch1->dim(1) << "x" << batch1->dim(2) << " and " << batch2->dim(1) << "x" << batch2->dim(2) << ")"; if (beta == 0.0) { // In stable diffsion, the beta param is always 0.0, so we can avoid use add and mul op to // optimize speed and bandwidth in cuda. 
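      // For reference, baddbmm computes: out = beta * input + alpha * (batch1 @ batch2),
      // so with beta == 0 the `input` term vanishes and the scaled batch matmul below is the
      // whole result.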
return JUST(functional::BatchMatMul(batch1, batch2, false, false, alpha)); } else { // TODO(add a fuse kernel to optimize speed and bancwidth in cuda) return JUST( functional::Add(JUST(functional::ScalarMul(beta, input)), JUST(functional::BatchMatMul(batch1, batch2, false, false, alpha)), /*alpha=*/1.0, /*inplace=*/false)); } } }; class SortFunctor { public: Maybe operator()(const std::shared_ptr& input, const int32_t& dim, const bool descending) const { auto outputs = std::make_shared(2); std::shared_ptr values; std::shared_ptr indices; int32_t axis = dim; axis = JUST(maybe_wrap_dim(axis, input->ndim())); std::string direction("ASCENDING"); if (descending) { direction.assign("DESCENDING"); } if (axis == input->ndim() - 1) { indices = JUST(ArgSort(input, direction)); values = JUST(DimGather(input, axis, indices, false)); } else { std::shared_ptr> perm = JUST(GetPermWhenTransposeAxisToLastDim(input->ndim(), dim)); auto x = JUST(Transpose(input, *perm)); auto indices_temp = JUST(ArgSort(x, direction)); auto inversed_perm = JUST(GetInversedPerm(*perm)); indices = JUST(Transpose(indices_temp, *inversed_perm)); values = JUST(DimGather(input, axis, indices, false)); } (*outputs)[0] = values; (*outputs)[1] = indices; return outputs; } }; class CloneFunctor { public: Maybe operator()(const std::shared_ptr& input) const { return input->clone(); } }; class FusedCodegeexQkvReshapeFunctor { public: FusedCodegeexQkvReshapeFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_codegeex_qkv_reshape") .Input("query") .Input("key") .Input("value") .Output("new_query") .Output("new_key") .Output("new_value") .Build()); } Maybe operator()(const std::shared_ptr& query, const std::shared_ptr& key, const std::shared_ptr& value, const int32_t num_attention_heads) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_attention_heads"); attrs.SetAllAttrs(num_attention_heads); return OpInterpUtil::Dispatch(*op_, {query, key, value}, attrs); } private: std::shared_ptr op_; }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("ArgMax"); m.add_functor("ArgMin"); m.add_functor("GlobalTensorConstant"); m.add_functor("TensorConstant"); m.add_functor("GlobalConstant"); m.add_functor("Constant"); m.add_functor("GlobalEmpty"); m.add_functor("Empty"); m.add_functor("EmptyStrided"); m.add_functor("ZerosLike"); m.add_functor("OnesLike"); m.add_functor("FullLike"); m.add_functor("Flatten"); m.add_functor("Fill"); m.add_functor("FillTensor"); m.add_functor("Where"); m.add_functor("WhereScalarX"); m.add_functor("WhereScalarY"); m.add_functor("WhereScalarXY"); m.add_functor("ArgWhere"); m.add_functor("NonZero"); m.add_functor("BroadcastLike"); m.add_functor("Concat"); m.add_functor("Stack"); m.add_functor("StackGrad"); m.add_functor("AtLeast1D"); m.add_functor("AtLeast1D"); m.add_functor("AtLeast2D"); m.add_functor("AtLeast2D"); m.add_functor("AtLeast3D"); m.add_functor("AtLeast3D"); m.add_functor("HStack"); m.add_functor("ColumnStack"); m.add_functor("VStack"); m.add_functor("RowStack"); m.add_functor("DStack"); m.add_functor("Expand"); m.add_functor("ExpandDims"); m.add_functor("Unsqueeze"); m.add_functor("UnsqueezeMultiple"); m.add_functor("InplaceUnsqueeze"); m.add_functor("Squeeze"); m.add_functor("InplaceSqueeze"); m.add_functor("Roll"); m.add_functor("Gather"); m.add_functor("DimGather"); m.add_functor("ArgSort"); m.add_functor("SearchSorted"); m.add_functor("SearchSortedScalar"); m.add_functor("GatherNd"); m.add_functor("ScatterNd"); m.add_functor("TensorScatterNdUpdate"); m.add_functor("ScatterNdLike"); 
m.add_functor("Reshape"); m.add_functor("View"); m.add_functor("ToContiguous"); m.add_functor("InplaceToContiguous"); m.add_functor("Narrow"); m.add_functor("NarrowGrad"); m.add_functor("SliceUpdate"); m.add_functor("Slice"); m.add_functor("SliceGrad"); m.add_functor("SliceView1dContiguous"); m.add_functor("Copy"); m.add_functor("Flip"); m.add_functor("UnfoldTensor"); m.add_functor("UnfoldTensorGrad"); m.add_functor("UpsampleGrad"); m.add_functor("UpsampleNearest2D"); m.add_functor("UpsampleNearest2DGrad"); m.add_functor("UpsampleBilinear2D"); m.add_functor("UpsampleBilinear2DGrad"); m.add_functor("UpsampleLinear1D"); m.add_functor("UpsampleLinear1DGrad"); m.add_functor("UpsampleNearest1D"); m.add_functor("UpsampleNearest1DGrad"); m.add_functor("UpsampleBicubic2D"); m.add_functor("UpsampleBicubic2DGrad"); m.add_functor("UpsampleNearest3D"); m.add_functor("UpsampleNearest3DGrad"); m.add_functor("UpsampleTrilinear3D"); m.add_functor("UpsampleTrilinear3DGrad"); m.add_functor("UnsortedSegmentSumLike"); m.add_functor("UnsortedSegmentSum"); m.add_functor("Tril"); m.add_functor("InplaceTril"); m.add_functor("Triu"); m.add_functor("InplaceTriu"); m.add_functor("Diag"); m.add_functor("DiagGrad"); m.add_functor("Diagonal"); m.add_functor("DiagonalGrad"); m.add_functor("TensorGetItem"); m.add_functor>("DimScatterUpdate"); m.add_functor>("DimScatterAdd"); m.add_functor>("DimScatterMul"); m.add_functor("DimScatter"); m.add_functor>( "DimScatterUpdateScalar"); m.add_functor>( "DimScatterAddScalar"); m.add_functor>( "DimScatterMulScalar"); m.add_functor("DimScatterScalar"); m.add_functor("DimScatterAddLike"); m.add_functor("TensorSetItem"); m.add_functor("CastLike"); m.add_functor("ElementwiseMinGrad"); m.add_functor("ElementwiseMaxGrad"); m.add_functor("BroadcastPowXGrad"); m.add_functor("BroadcastPowYGrad"); m.add_functor("DivGrad"); m.add_functor("Identity"); m.add_functor("AmpWhiteIdentity"); m.add_functor("AmpBlackIdentity"); m.add_functor("ReduceSumLike"); m.add_functor("BroadcastReduceSumLike"); m.add_functor("Split"); m.add_functor("Unbind"); m.add_functor("Chunk"); m.add_functor("SplitLike"); m.add_functor("SplitWithSize"); m.add_functor("BatchGather"); m.add_functor("UnsortedBatchSegmentSum"); m.add_functor>("MaskedFill"); m.add_functor>("MaskedFillInplace"); m.add_functor("Meshgrid"); m.add_functor("IndexSelect"); m.add_functor("To"); m.add_functor("TopK"); m.add_functor("InTopK"); m.add_functor("TensorToTensorBuffer"); m.add_functor("TensorBufferToTensor"); m.add_functor("GenTensorBuffer"); m.add_functor("Repeat"); m.add_functor("RepeatInterLeaveIndex"); m.add_functor("RepeatInterLeaveInt"); m.add_functor("RepeatInterLeaveTensor"); m.add_functor("Tile"); m.add_functor("TransposeAllDimProperty"); m.add_functor("TransposeAllDimFunction"); m.add_functor("ReshapeLike"); m.add_functor("PinMemory"); m.add_functor("BroadcastShapes"); m.add_functor("BroadcastTensors"); m.add_functor("BroadcastTo"); // BroadcastTo is an alias of Expand m.add_functor("BinCount"); m.add_functor("IndexAdd"); m.add_functor("IndexAddInplace"); m.add_functor("Unique"); m.add_functor("UniqueWithCounts"); m.add_functor("BaddBmm"); m.add_functor("Sort"); m.add_functor("Clone"); m.add_functor("FusedCodegeexQkvReshape"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/binary_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/functional/impl/binary_functor.h" #include "oneflow/core/common/error.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" namespace oneflow { namespace one { namespace functional { namespace impl { namespace { bool IsCPUScalarTensor(const std::shared_ptr& tensor) { return tensor->shape()->NumAxes() == 0 && TensorDeviceToString(tensor).find("cpu") != std::string::npos; } } // namespace std::string TensorDeviceToString(const std::shared_ptr& tensor) { if (tensor->is_global()) { return CHECK_JUST(tensor->parallel_desc())->device_tag(); } return CHECK_JUST(tensor->device())->ToString(); } Maybe CastDeviceForCPUScalarTensor(std::shared_ptr& tensor, std::shared_ptr& other, bool inplace) { if (TensorDeviceToString(tensor) != TensorDeviceToString(other)) { if (IsCPUScalarTensor(other)) { other = JUST(functional::To(other, TensorDeviceToString(tensor))); } else if (!inplace && IsCPUScalarTensor(tensor)) { tensor = JUST(functional::To(tensor, TensorDeviceToString(other))); } } return Maybe::Ok(); } class AddFunctor { public: AddFunctor() { add_op_ = CHECK_JUST(one::OpBuilder("add_n").Input("in", 2).Output("out").Build()); broadcast_add_op_ = CHECK_JUST(one::OpBuilder("broadcast_add").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& other, const Scalar& alpha, bool inplace) const { auto input_tensor = input; if (IsIntegralDataType(input_tensor->dtype()->data_type()) && IsIntegralDataType(other->dtype()->data_type()) && alpha.IsFloatingPoint()) { return Error::RuntimeError() << "For integral input tensors, argument alpha must not be a floating point number."; } bool input_static_zeros = IsStaticZerosTensor(input_tensor); if (input_static_zeros || IsStaticZerosTensor(other)) { CHECK_OR_RETURN(JUST(input_tensor->device()) == JUST(other->device())) << Error::RuntimeError() << "Expected all tensors to be on the same device, but found at least two devices, " << JUST(input_tensor->device())->ToString() << " and " << JUST(other->device())->ToString() << "!"; CHECK_OR_RETURN(*input_tensor->shape() == *other->shape()) << Error::RuntimeError() << "The size of tensor a " << input_tensor->shape()->ToString() << " must match the size of tensor b " << other->shape(); if (input_static_zeros) { if ((alpha.IsIntegral() && alpha.Value() == 1) || (alpha.IsFloatingPoint() && 
std::fabs(alpha.Value() - 1.0) < std::numeric_limits::epsilon())) { return other; } else { return JUST(functional::ScalarMul(alpha, other)); } } return input_tensor; } const OpExpr* op = nullptr; Optional> promote_dtype; if (inplace) { promote_dtype = input_tensor->dtype(); } TensorProcessor tensor_processor; if ((alpha.IsIntegral() && alpha.Value() == 1) || (alpha.IsFloatingPoint() && std::fabs(alpha.Value() - 1.0) < std::numeric_limits::epsilon())) { JUST(tensor_processor.PromoteInputsToCommonDtype(true, promote_dtype) .AddInputs({input_tensor, other}) .Apply()); } else { JUST(tensor_processor.PromoteInputsToCommonDtype(true, promote_dtype) .AddInputs({input_tensor, JUST(functional::ScalarMul(alpha, other))}) .Apply()); } TensorTuple input_vec = JUST(tensor_processor.GetInputs()); const std::shared_ptr& input_cast = input_vec[0]; const std::shared_ptr& other_cast = input_vec[1]; JUST(CastDeviceForCPUScalarTensor(input_vec[0], input_vec[1], inplace)); if (*input_cast->shape() == *other_cast->shape()) { op = add_op_.get(); } else { op = broadcast_add_op_.get(); } if (inplace) { JUST(CheckInplaceCastValid(input_tensor, input_cast)); JUST(CheckInplaceValid(input_tensor)); JUST(CheckInplaceShapeCanExpandTo(*other_cast->shape(), *input_cast->shape())); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = input_cast; JUST(OpInterpUtil::Dispatch(*op, input_vec, outputs.get())); return outputs->at(0); } return OpInterpUtil::Dispatch(*op, input_vec); } private: std::shared_ptr add_op_; std::shared_ptr broadcast_add_op_; }; class BroadcastPowFunctor : public BinaryFloatFunctor { public: BroadcastPowFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_pow").Input("x").Input("y").Output("z").Build()); } }; class SubFunctor : public InplaceableBinaryFunctor { public: SubFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_sub").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& other, const Scalar& alpha, bool inplace) const { if (IsIntegralDataType(input->dtype()->data_type()) && IsIntegralDataType(other->dtype()->data_type()) && alpha.IsFloatingPoint()) { return Error::RuntimeError() << "For integral input tensors, argument alpha must not be a floating point number."; } if ((alpha.IsIntegral() && alpha.Value() == 1) || (alpha.IsFloatingPoint() && std::fabs(alpha.Value() - 1.0) < std::numeric_limits::epsilon())) { return InplaceableBinaryFunctor::operator()(input, other, inplace); } else { return InplaceableBinaryFunctor::operator()(input, JUST(functional::ScalarMul(alpha, other)), inplace); } } }; class MulFunctor { public: MulFunctor() { broadcast_mul_op_ = CHECK_JUST(one::OpBuilder("broadcast_mul").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { auto tensor_x = x; auto tensor_y = y; JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false)); TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply()); TensorTuple input_vec = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*broadcast_mul_op_, input_vec); } private: std::shared_ptr broadcast_mul_op_; }; class InplaceMulFunctor { public: InplaceMulFunctor() { broadcast_mul_op_ = CHECK_JUST(one::OpBuilder("broadcast_mul").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { TensorProcessor tensor_processor; if 
(y->requires_grad()) { JUST(tensor_processor.PromoteInputsToCommonDtype(true) .AddInputs({JUST(Identity(x)), y}) .Apply()); } else { JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply()); } const TensorTuple& input_vec = JUST(tensor_processor.GetInputs()); const std::shared_ptr& x_cast = input_vec.at(0); const std::shared_ptr& y_cast = input_vec.at(1); JUST(CheckInplaceValid(x)); JUST(CheckInplaceCastValid(x, x_cast)); JUST(CheckInplaceShapeCanExpandTo(*y_cast->shape(), *x_cast->shape())); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*broadcast_mul_op_, input_vec, outputs.get())); return outputs->at(0); } private: std::shared_ptr broadcast_mul_op_; }; class AddcmulBaseFunctor { public: AddcmulBaseFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& tensor1, const std::shared_ptr& tensor2, const Scalar& value, bool inplace) const { return SequenceFunction()>([&]() { return functional::Mul(tensor1, tensor2); }) .then([&](const auto& x) { return functional::ScalarMul(value, x); }) .then([&](const auto& x) { return functional::Add(input, x, /*alpha=*/1, inplace); }) .call(); } }; class AddcmulFunctor : public AddcmulBaseFunctor { public: AddcmulFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& tensor1, const std::shared_ptr& tensor2, const Scalar& value) const { return AddcmulBaseFunctor::operator()(input, tensor1, tensor2, value, /*inplace=*/false); } }; class InplaceAddcmulFunctor : public AddcmulBaseFunctor { public: InplaceAddcmulFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& tensor1, const std::shared_ptr& tensor2, const Scalar& value) const { return AddcmulBaseFunctor::operator()(input, tensor1, tensor2, value, /*inplace=*/true); } }; class DivFunctor : public BinaryFloatFunctor { public: DivFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_div").Input("x").Input("y").Output("z").Build()); } }; class DivFunctorMode { public: DivFunctorMode() {} Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, const Optional& rounding_mode) const { std::string rmode = rounding_mode.value_or(""); if (rmode == "floor") { return JUST(functional::FloorDiv(x, y)); } else if (rmode == "trunc") { return JUST(functional::TruncDiv(x, y)); } CHECK_OR_RETURN(rmode == "") << "div expected rounding_mode to be one of None," " 'trunc', or 'floor' but found " << rmode; return JUST(functional::Div(x, y)); } private: std::shared_ptr op_; }; class InplaceDivFunctor { public: InplaceDivFunctor() { broadcast_div_op_ = CHECK_JUST(one::OpBuilder("broadcast_div").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { auto tensor_x = x; auto tensor_y = y; JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/true)); // NOTE: div operator will cast inputs to float when dtype is integral TensorProcessor tensor_processor; TensorTuple tensor_processor_inputs; { if (tensor_y->requires_grad()) { tensor_processor_inputs.assign({JUST(Identity(tensor_x)), tensor_y}); } else { tensor_processor_inputs.assign({tensor_x, tensor_y}); } } if (promoteTypes(tensor_x->dtype(), tensor_y->dtype())->is_integer()) { tensor_processor.AddInputs(tensor_processor_inputs, DType::Float()); } else { tensor_processor.AddInputs(tensor_processor_inputs) .PromoteInputsToCommonDtype(true) .PromoteIntegerInputsToFloatDtype(true); } JUST(tensor_processor.Apply()); 
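    // The promoted (possibly float-cast) views of x and y are fetched below; the subsequent
    // inplace checks verify that writing the promoted result back into x is still legal
    // (the dtype cast is allowed and y's shape can expand to x's shape).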
const TensorTuple& input_vec = JUST(tensor_processor.GetInputs()); const std::shared_ptr& x_cast = input_vec.at(0); const std::shared_ptr& y_cast = input_vec.at(1); JUST(CheckInplaceValid(x)); JUST(CheckInplaceCastValid(x, x_cast)); JUST(CheckInplaceShapeCanExpandTo(*y_cast->shape(), *x_cast->shape())); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*broadcast_div_op_, input_vec, outputs.get())); return outputs->at(0); } private: std::shared_ptr broadcast_div_op_; }; class Atan2Functor : public BinaryFloatFunctor { public: Atan2Functor() { op_ = CHECK_JUST(one::OpBuilder("atan2").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { const int64_t x_element = x->nelement(); const int64_t y_element = y->nelement(); CHECK_GT_OR_RETURN(x_element, 0) << Error::RuntimeError() << "the size of input should be > 0, but got " << x_element; CHECK_GT_OR_RETURN(y_element, 0) << Error::RuntimeError() << "the size of input should be > 0, but got " << y_element; if ((x_element != 1 && y_element != 1) && (x->shape()->NumAxes() == y->shape()->NumAxes())) { return BinaryFloatFunctor::operator()(x, y); } auto broad_x_ = x; auto broad_y_ = y; if (x_element == 1) { broad_x_ = JUST(functional::Expand(x, *y->shape())); } else if (y_element == 1) { broad_y_ = JUST(functional::Expand(y, *x->shape())); } else if (x->shape()->NumAxes() != y->shape()->NumAxes()) { return Error::RuntimeError() << "The size of tensor a (" << x->shape()->NumAxes() << ") must match the size of tensor b " "(" << y->shape()->NumAxes() << ") at non-singleton dimension 1"; } else { return Error::RuntimeError() << ""; } return BinaryFloatFunctor::operator()(broad_x_, broad_y_); } }; class PowFunctor : public BinaryFloatFunctor { public: PowFunctor() { op_ = CHECK_JUST(one::OpBuilder("pow").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { if (*x->shape() != *y->shape()) { return BroadcastPow(x, y); } return BinaryFloatFunctor::operator()(x, y); } }; class FloorDivFunctor : public BinaryFunctor { public: FloorDivFunctor() { op_ = CHECK_JUST(one::OpBuilder("floordiv").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { return BinaryFunctor::operator()(x, y); } }; class TruncDivFunctor : public BinaryFunctor { public: TruncDivFunctor() { op_ = CHECK_JUST(one::OpBuilder("truncdiv").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { return BinaryFunctor::operator()(x, y); } }; class LerpFunctor { public: LerpFunctor() { op_ = CHECK_JUST( one::OpBuilder("lerp").Input("start").Input("end").Input("weight").Output("out").Build()); } Maybe operator()(const std::shared_ptr& start, const std::shared_ptr& end, const std::shared_ptr& weight) const { const int64_t weight_elem_cnt = weight->nelement(); CHECK_EQ_OR_RETURN(start->shape()->NumAxes(), end->shape()->NumAxes()) << Error::RuntimeError() << "expected dim" << start->shape()->NumAxes() << "for `end` but got dim" << end->shape()->NumAxes(); CHECK_EQ_OR_RETURN(start->dtype()->data_type(), weight->dtype()->data_type()) << Error::RuntimeError() << "expected dtype " << start->dtype()->name() << " for `weights` but got dtype " << weight->dtype()->name(); auto broadcast_shape = *start->shape(); if (*start->shape() != *end->shape() || *start->shape() != *weight->shape()) { 
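      // Shapes differ, so compute the unified shape that start, end and weight all broadcast to.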
broadcast_shape = *JUST( InferUnifiedShapeForBroadcasting({*start->shape(), *end->shape(), *weight->shape()})); } if (weight_elem_cnt == 1 && weight->is_eager() && !weight->requires_grad()) { std::shared_ptr cast_double_weight = JUST(functional::Cast(weight, DType::Double(), /*pin_memory=*/false)); double weight_scalar = JUST(GetItemInScalarTensor(cast_double_weight)); return functional::ScalarLerp(start, end, weight_scalar); } std::shared_ptr broadcast_start = start; std::shared_ptr broadcast_end = end; std::shared_ptr broadcast_weight = weight; if (*start->shape() != broadcast_shape) { broadcast_start = JUST(functional::Expand(start, broadcast_shape)); } if (*end->shape() != broadcast_shape) { broadcast_end = JUST(functional::Expand(end, broadcast_shape)); } if (*weight->shape() != broadcast_shape) { broadcast_weight = JUST(functional::Expand(weight, broadcast_shape)); } return OpInterpUtil::Dispatch(*op_, {broadcast_start, broadcast_end, broadcast_weight}); } private: std::shared_ptr op_; }; class InplaceLerpFunctor { public: InplaceLerpFunctor() { lerp_op_ = CHECK_JUST( one::OpBuilder("lerp").Input("start").Input("end").Input("weight").Output("out").Build()); } Maybe operator()(const std::shared_ptr& start, const std::shared_ptr& end, const std::shared_ptr& weight) const { const int64_t weight_elem_cnt = weight->nelement(); CHECK_EQ_OR_RETURN(start->shape()->NumAxes(), end->shape()->NumAxes()) << Error::RuntimeError() << "expected dim" << start->shape()->NumAxes() << "for `end` but got dim" << end->shape()->NumAxes(); CHECK_EQ_OR_RETURN(start->dtype()->data_type(), weight->dtype()->data_type()) << Error::RuntimeError() << "expected dtype " << start->dtype()->name() << " for `weights` but got dtype " << weight->dtype()->name(); if (weight_elem_cnt == 1 && weight->is_eager() && !weight->requires_grad()) { std::shared_ptr cast_double_weight = JUST(functional::Cast(weight, DType::Double(), /*pin_memory=*/false)); double weight_scalar = JUST(GetItemInScalarTensor(cast_double_weight)); JUST(functional::ScalarInplaceLerp(start, end, weight_scalar)); return start; } auto broadcast_shape = *start->shape(); if (*start->shape() != *end->shape() || *start->shape() != *weight->shape()) { broadcast_shape = *JUST( InferUnifiedShapeForBroadcasting({*start->shape(), *end->shape(), *weight->shape()})); } std::shared_ptr broadcast_start = JUST(Identity(start)); std::shared_ptr broadcast_end = JUST(Identity(end)); std::shared_ptr broadcast_weight = JUST(Identity(weight)); if (*start->shape() != broadcast_shape) { broadcast_start = JUST(view::Expand(start, broadcast_shape)); } if (*end->shape() != broadcast_shape) { broadcast_end = JUST(view::Expand(end, broadcast_shape)); } if (*weight->shape() != broadcast_shape) { broadcast_weight = JUST(view::Expand(weight, broadcast_shape)); } TensorProcessor tensor_processor; if (broadcast_end->requires_grad() || broadcast_weight->requires_grad()) { JUST(tensor_processor.PromoteInputsToCommonDtype(true) .AddInputs({JUST(Identity(broadcast_start)), broadcast_end, broadcast_weight}) .Apply()); } else { JUST(tensor_processor.PromoteInputsToCommonDtype(true) .AddInputs({broadcast_start, broadcast_end, broadcast_weight}) .Apply()); } const TensorTuple& input_vec = JUST(tensor_processor.GetInputs()); const std::shared_ptr& start_cast = input_vec.at(0); const std::shared_ptr& end_cast = input_vec.at(1); JUST(CheckInplaceValid(broadcast_start)); JUST(CheckInplaceCastValid(broadcast_start, start_cast)); JUST(CheckInplaceShapeCanExpandTo(*start_cast->shape(), 
*end_cast->shape())); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = start; JUST(OpInterpUtil::Dispatch(*lerp_op_, input_vec, outputs.get())); return outputs->at(0); } private: std::shared_ptr lerp_op_; }; class LerpGradFunctor { public: LerpGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("lerp_grad") .Input("start") .Input("end") .Input("weight") .Input("out_diff") .Output("start_diff") .Output("end_diff") .Output("weight_diff") .Build()); } Maybe operator()(const std::shared_ptr& start, const std::shared_ptr& end, const std::shared_ptr& weight, const std::shared_ptr& out_diff) const { return OpInterpUtil::Dispatch(*op_, {start, end, weight, out_diff}, {}); } private: std::shared_ptr op_; }; class BroadcastFModFunctor : public BinaryFunctor { public: BroadcastFModFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_fmod").Input("x").Input("y").Output("z").Build()); } }; class BroadcastEqualFunctor : public BinaryFunctor { public: BroadcastEqualFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_equal").Input("x").Input("y").Output("z").Build()); } }; class EqualFunctor { public: EqualFunctor() { broadcast_equal_op_ = CHECK_JUST(one::OpBuilder("broadcast_equal").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { if (*x->shape() != *y->shape()) { return false; } if (x->nelement() == 0) { return true; } std::shared_ptr output = JUST( ReduceAllWhole(JUST(OpInterpUtil::Dispatch(*broadcast_equal_op_, {x, y}, {})))); bool status = JUST(GetItemInScalarTensor(output)); return status; } private: std::shared_ptr broadcast_equal_op_; }; class BroadcastNotEqualFunctor : public BinaryFunctor { public: BroadcastNotEqualFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_not_equal").Input("x").Input("y").Output("z").Build()); } }; class BroadcastGreaterFunctor : public BinaryFunctor { public: BroadcastGreaterFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_greater").Input("x").Input("y").Output("z").Build()); } }; class InplaceBroadcastGreaterFunctor { public: InplaceBroadcastGreaterFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_inplace_greater").Input("x").Input("y").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply()); const TensorTuple& input_vec = JUST(tensor_processor.GetInputs()); const std::shared_ptr& x_cast = input_vec.at(0); const std::shared_ptr& y_cast = input_vec.at(1); JUST(CheckInplaceValid(x)); JUST(CheckInplaceCastValid(x, x_cast)); JUST(CheckInplaceShapeCanExpandTo(*y_cast->shape(), *x_cast->shape())); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_, input_vec, outputs.get())); return outputs->at(0); } private: std::shared_ptr op_; }; class BroadcastGreaterEqualFunctor : public BinaryFunctor { public: BroadcastGreaterEqualFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_greater_equal").Input("x").Input("y").Output("z").Build()); } }; class BroadcastLogicalAndFunctor : public BinaryFunctor { public: BroadcastLogicalAndFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_logical_and").Input("x").Input("y").Output("z").Build()); } }; class BroadcastLogicalOrFunctor : public BinaryFunctor { public: BroadcastLogicalOrFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_logical_or").Input("x").Input("y").Output("z").Build()); } }; class 
BroadcastLogicalXorFunctor : public BinaryFunctor { public: BroadcastLogicalXorFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_logical_xor").Input("x").Input("y").Output("z").Build()); } }; class BroadcastBitwiseAndFunctor : public BinaryFunctor { public: BroadcastBitwiseAndFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_bitwise_and").Input("x").Input("y").Output("z").Build()); } }; class BroadcastBitwiseOrFunctor : public BinaryFunctor { public: BroadcastBitwiseOrFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_bitwise_or").Input("x").Input("y").Output("z").Build()); } }; class BroadcastBitwiseXorFunctor : public BinaryFunctor { public: BroadcastBitwiseXorFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_bitwise_xor").Input("x").Input("y").Output("z").Build()); } }; class BroadcastLessFunctor : public BinaryFunctor { public: BroadcastLessFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_less").Input("x").Input("y").Output("z").Build()); } }; class BroadcastLessEqualFunctor : public BinaryFunctor { public: BroadcastLessEqualFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_less_equal").Input("x").Input("y").Output("z").Build()); } }; class BroadcastIsCloseFunctor { public: BroadcastIsCloseFunctor() { eq_nan_op_ = CHECK_JUST( one::OpBuilder("broadcast_isclose_eq_nan").Input("x").Input("y").Output("z").Build()); neq_nan_op_ = CHECK_JUST( one::OpBuilder("broadcast_isclose_neq_nan").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, const float atol, const float rtol, const bool equal_nan) const { auto& attr = THREAD_CACHED_MUTABLE_ATTR_MAP("atol", "rtol", "equal_nan"); attr.SetAllAttrs(atol, rtol, equal_nan); if (equal_nan) { return OpInterpUtil::Dispatch(*eq_nan_op_, {x, y}, attr); } else { return OpInterpUtil::Dispatch(*neq_nan_op_, {x, y}, attr); } } private: std::shared_ptr eq_nan_op_; std::shared_ptr neq_nan_op_; }; class ScalarAddByTensorFunctor : public InplaceableBinaryFunctor { public: ScalarAddByTensorFunctor() { op_ = CHECK_JUST( one::OpBuilder("scalar_add_by_tensor").Input("x").Input("scalar").Output("y").Build()); } }; // this functor just for test host memory input class HostScalarAddByTensorFunctor { public: HostScalarAddByTensorFunctor() { op_ = CHECK_JUST( one::OpBuilder("host_scalar_add_by_tensor").Input("x").Input("scalar").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& scalar) const { return OpInterpUtil::Dispatch(*op_, {x, scalar}); } private: std::shared_ptr op_; }; class ScalarSubByTensorFunctor : public BinaryFunctor { public: ScalarSubByTensorFunctor() { op_ = CHECK_JUST( one::OpBuilder("scalar_sub_by_tensor").Input("x").Input("scalar").Output("y").Build()); } }; class ScalarMulByTensorFunctor : public BinaryFunctor { public: ScalarMulByTensorFunctor() { op_ = CHECK_JUST( one::OpBuilder("scalar_mul_by_tensor").Input("x").Input("scalar").Output("y").Build()); } }; class ScalarDivByTensorFunctor : public BinaryFunctor { public: ScalarDivByTensorFunctor() { op_ = CHECK_JUST( one::OpBuilder("scalar_div_by_tensor").Input("x").Input("scalar").Output("y").Build()); } }; class BroadcastZetaFunctor : public BinaryFloatFunctor { public: BroadcastZetaFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_zeta").Input("x").Input("y").Output("z").Build()); } }; class ZetaScalarTensorFunctor { public: Maybe operator()(const Scalar x, const std::shared_ptr& y) const { auto scalar_tensor = JUST(functional::FullLike(y, x)); // wrap scalar 
to tensor return functional::BroadcastZeta(scalar_tensor, y); } }; class ZetaTensorScalarFunctor { public: Maybe operator()(const std::shared_ptr& x, const Scalar y) const { auto scalar_tensor = JUST(functional::FullLike(x, y)); // wrap scalar to tensor return functional::BroadcastZeta(x, scalar_tensor); } }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Add"); m.add_functor("Addcmul"); m.add_functor("InplaceAddcmul"); m.add_functor("Atan2"); m.add_functor("Sub"); m.add_functor("Mul"); m.add_functor("InplaceMul"); m.add_functor("InplaceDiv"); m.add_functor("Div"); m.add_functor("DivMode"); m.add_functor("Pow"); m.add_functor("BroadcastPow"); m.add_functor("BroadcastEqual"); m.add_functor("Equal"); m.add_functor("BroadcastNotEqual"); m.add_functor("BroadcastGreater"); m.add_functor("InplaceBroadcastGreater"); m.add_functor("BroadcastGreaterEqual"); m.add_functor("BroadcastLogicalAnd"); m.add_functor("BroadcastLogicalOr"); m.add_functor("BroadcastLogicalXor"); m.add_functor("BroadcastBitwiseAnd"); m.add_functor("BroadcastBitwiseOr"); m.add_functor("BroadcastBitwiseXor"); m.add_functor("BroadcastLess"); m.add_functor("BroadcastLessEqual"); m.add_functor("ScalarAddByTensor"); m.add_functor("HostScalarAddByTensor"); m.add_functor("ScalarSubByTensor"); m.add_functor("ScalarMulByTensor"); m.add_functor("ScalarDivByTensor"); m.add_functor("BroadcastFMod"); m.add_functor("FloorDiv"); m.add_functor("TruncDiv"); m.add_functor("IsClose"); m.add_functor("Lerp"); m.add_functor("InplaceLerp"); m.add_functor("LerpGrad"); m.add_functor("BroadcastZeta"); m.add_functor("ZetaScalarTensor"); m.add_functor("ZetaTensorScalar"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/binary_functor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_ #define ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_ #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/tensor_processor.h" #include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { namespace functional { namespace impl { std::string TensorDeviceToString(const std::shared_ptr& tensor); Maybe CastDeviceForCPUScalarTensor(std::shared_ptr& tensor, std::shared_ptr& other, bool inplace); class BinaryFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { auto tensor_x = x; auto tensor_y = y; JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false)); TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple); } protected: BinaryFunctor() = default; virtual ~BinaryFunctor() = default; std::shared_ptr op_; }; class BinaryFloatFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { auto tensor_x = x; auto tensor_y = y; JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false)); TensorProcessor tensor_processor; if (promoteTypes(tensor_x->dtype(), tensor_y->dtype())->is_integer()) { tensor_processor.AddInputs({tensor_x, tensor_y}, DType::Float()); } else { tensor_processor.AddInputs({tensor_x, tensor_y}) .PromoteInputsToCommonDtype(true) .PromoteIntegerInputsToFloatDtype(true); } JUST(tensor_processor.Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple); } protected: BinaryFloatFunctor() = default; virtual ~BinaryFloatFunctor() = default; std::shared_ptr op_; }; class BinaryGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, const std::shared_ptr& dz) const { TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y, dz}).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple); } protected: BinaryGradFunctor() = default; virtual ~BinaryGradFunctor() = default; std::shared_ptr op_; }; class InplaceableBinaryFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, bool inplace) const { auto tensor_x = x; auto tensor_y = y; JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, inplace)); TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); if (inplace) { std::shared_ptr& x_cast = input_tuple.at(0); std::shared_ptr& y_cast = input_tuple.at(1); JUST(CheckInplaceCastValid(x, x_cast)); JUST(CheckInplaceShapeCanExpandTo(*y_cast->shape(), *x_cast->shape())); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x_cast; JUST(OpInterpUtil::Dispatch(*op_, input_tuple, outputs.get())); return outputs->at(0); } else { return OpInterpUtil::Dispatch(*op_, input_tuple); } } protected: InplaceableBinaryFunctor() = default; virtual ~InplaceableBinaryFunctor() = default; std::shared_ptr op_; }; } // namespace impl } // namespace functional } // namespace one } // 
namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_ ================================================ FILE: oneflow/core/functional/impl/binary_grad_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/functional/impl/binary_functor.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/user/ops/math_binary_elementwise_seq.h" namespace oneflow { namespace one { namespace functional { namespace impl { #define BINARY_ELEMENTWISE_GRAD_FUNCTOR(op_type_name, class_name, base) \ class class_name##XGradFunctor : public base { \ public: \ class_name##XGradFunctor() { \ op_ = CHECK_JUST(one::OpBuilder(std::string("") + op_type_name + "_x_grad") \ .Input("x") \ .Input("y") \ .Input("dz") \ .Output("dx") \ .Build()); \ } \ }; \ class class_name##YGradFunctor : public base { \ public: \ class_name##YGradFunctor() { \ op_ = CHECK_JUST(one::OpBuilder(std::string("") + op_type_name + "_y_grad") \ .Input("x") \ .Input("y") \ .Input("dz") \ .Output("dy") \ .Build()); \ } \ }; #define INSTANTIAT_BINARY_ELEMENTWISE_GRAD_FUNCTOR(op_type_name, class_name) \ BINARY_ELEMENTWISE_GRAD_FUNCTOR(op_type_name, class_name, BinaryGradFunctor); OF_PP_FOR_EACH_TUPLE(INSTANTIAT_BINARY_ELEMENTWISE_GRAD_FUNCTOR, MATH_BINARY_ELEMENTWISE_FUNC_SEQ); } // namespace impl using namespace impl; #define ADD_BINARY_GRAD_FUNCTOR(class_name, functor_name) \ m.add_functor(std::string("") + functor_name + "XGrad"); \ m.add_functor(std::string("") + functor_name + "YGrad"); ONEFLOW_FUNCTION_LIBRARY(m) { ADD_BINARY_GRAD_FUNCTOR(Pow, "Pow"); ADD_BINARY_GRAD_FUNCTOR(Atan2, "Atan2"); ADD_BINARY_GRAD_FUNCTOR(FloorDiv, "FloorDiv"); ADD_BINARY_GRAD_FUNCTOR(TruncDiv, "TruncDiv"); ADD_BINARY_GRAD_FUNCTOR(Xdivy, "Xdivy"); ADD_BINARY_GRAD_FUNCTOR(Xlogy, "Xlogy"); }; #undef ADD_BINARY_GRAD_FUNCTOR } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/comm_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_interpreter/eager_local_op_interpreter.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/ccl/ccl.h" #include "oneflow/core/job/rank_group_scope.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/common/flat_shape.h" #include "oneflow/core/framework/user_op_registry_manager.h" namespace oneflow { namespace one { namespace functional { namespace impl { namespace { #define OF_KERNEL_NOT_SUPPORT_ERROR(op_type, device_type) \ Error::RuntimeError() << op_type << " not suport for the device (" \ << DeviceType_Name(device_type) << ") because eager kernel of " << op_type \ << " is not registered" class EagerCclKernelRegContext final : public user_op::KernelRegContext { public: explicit EagerCclKernelRegContext(DeviceType device_type) : device_type_(device_type) {} ~EagerCclKernelRegContext() = default; DeviceType device_type() const override { return device_type_; } const ParallelContext& parallel_ctx() const override { PRINT_BUG_PROMPT_AND_ABORT(); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { PRINT_BUG_PROMPT_AND_ABORT(); } const std::vector>& inputs() const override { PRINT_BUG_PROMPT_AND_ABORT(); } const std::vector>& outputs() const override { PRINT_BUG_PROMPT_AND_ABORT(); } const user_op::UserOpConfWrapper& user_op_conf() const override { PRINT_BUG_PROMPT_AND_ABORT(); } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { PRINT_BUG_PROMPT_AND_ABORT(); } private: DeviceType device_type_; }; Maybe RawCheckCclKernelRegistered(const std::string& op_type_name, DeviceType device_type) { EagerCclKernelRegContext reg_ctx(device_type); return user_op::UserOpRegistryMgr::Get().IsOpKernelRegistered(op_type_name, reg_ctx); } static constexpr auto* CheckCclKernelRegistered = DECORATE(&RawCheckCclKernelRegistered, ThreadLocalCachedCopiable); bool IsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } Maybe EagerCclAllReduce(Symbol parallel_desc) { CHECK_OR_RETURN( JUST(CheckCclKernelRegistered("eager_ccl_all_reduce", parallel_desc->device_type()))) << OF_KERNEL_NOT_SUPPORT_ERROR("AllReduce", parallel_desc->device_type()); return one::OpBuilder("eager_ccl_all_reduce", *JUST(UniqueStr("eager_ccl_all_reduce"))) .Input("in") .Output("out") .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Build(); } static constexpr auto* CachedEagerCclAllReduceOpExpr = DECORATE(&EagerCclAllReduce, ThreadLocal); Maybe EagerCclReduceScatter(Symbol parallel_desc, const std::string& op_type) { CHECK_OR_RETURN( JUST(CheckCclKernelRegistered("eager_ccl_reduce_scatter", parallel_desc->device_type()))) << OF_KERNEL_NOT_SUPPORT_ERROR("ReduceScatter", 
parallel_desc->device_type()); return one::OpBuilder("eager_ccl_reduce_scatter", *JUST(UniqueStr("eager_ccl_reduce_scatter"))) .Input("in") .Output("out") .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Attr("op_type", op_type) .Build(); } static constexpr auto* CachedCclReduceScatterOpExpr = DECORATE(&EagerCclReduceScatter, ThreadLocalCopiable); Maybe EagerCclAllGather(Symbol parallel_desc) { CHECK_OR_RETURN( JUST(CheckCclKernelRegistered("eager_ccl_all_gather", parallel_desc->device_type()))) << OF_KERNEL_NOT_SUPPORT_ERROR("AllGather", parallel_desc->device_type()); return one::OpBuilder("eager_ccl_all_gather", *JUST(UniqueStr("eager_ccl_all_gather"))) .Input("in") .Output("out") .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Build(); } static constexpr auto* CachedEagerCclAllGatherOpExpr = DECORATE(&EagerCclAllGather, ThreadLocal); Maybe EagerCclS2S(Symbol parallel_desc, Symbol src_sbp, Symbol dst_sbp) { return one::OpBuilder("eager_ccl_s2s", *JUST(UniqueStr("eager_ccl_s2s"))) .Input("in") .Output("out") .Attr("in_split_axis", src_sbp->split_parallel().axis()) .Attr("out_split_axis", dst_sbp->split_parallel().axis()) .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Build(); } auto* CachedEagerCclS2SOpExpr = DECORATE(&EagerCclS2S, ThreadLocal); Maybe EagerCclReduce(Symbol parallel_desc, int64_t root) { CHECK_OR_RETURN(JUST(CheckCclKernelRegistered("eager_ccl_reduce", parallel_desc->device_type()))) << OF_KERNEL_NOT_SUPPORT_ERROR("Reduce", parallel_desc->device_type()); return one::OpBuilder("eager_ccl_reduce", *JUST(UniqueStr("eager_ccl_reduce"))) .Input("in") .Output("out") .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Attr("root", root) .Build(); } auto* CachedEagerCclReduceOpExpr = DECORATE(&EagerCclReduce, ThreadLocal); Maybe RankGroupAndDeviceType2AllReduceOpExpr(Symbol rank_group, DeviceType device_type) { CHECK_OR_RETURN(JUST(CheckCclKernelRegistered("eager_ccl_all_reduce", device_type))) << OF_KERNEL_NOT_SUPPORT_ERROR("AllReduce", device_type); const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group)); return one::OpBuilder("eager_ccl_all_reduce") .Input("in") .Output("out") .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Build(); } auto* CachedRankGroupAndDeviceType2AllReduceOpExpr = DECORATE(&RankGroupAndDeviceType2AllReduceOpExpr, ThreadLocal); Maybe RankGroupAndDeviceType2AllGatherOpExpr(Symbol rank_group, DeviceType device_type) { CHECK_OR_RETURN(JUST(CheckCclKernelRegistered("eager_ccl_all_gather", device_type))) << OF_KERNEL_NOT_SUPPORT_ERROR("AllGather", device_type); const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group)); return one::OpBuilder("eager_ccl_all_gather") .Input("in") .Output("out") .Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Build(); } auto* CachedRankGroupAndDeviceType2AllGatherOpExpr = DECORATE(&RankGroupAndDeviceType2AllGatherOpExpr, ThreadLocal); Maybe RankGroupAndDeviceType2ReduceScatterOpExpr(Symbol rank_group, DeviceType device_type) { CHECK_OR_RETURN(JUST(CheckCclKernelRegistered("eager_ccl_reduce_scatter", device_type))) << OF_KERNEL_NOT_SUPPORT_ERROR("ReduceScatter", device_type); const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group)); return one::OpBuilder("eager_ccl_reduce_scatter", *JUST(UniqueStr("eager_ccl_reduce_scatter"))) .Input("in") .Output("out") 
.Attr("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf())) .Build(); } auto* CachedRankGroupAndDeviceType2ReduceScatterOpExpr = DECORATE(&RankGroupAndDeviceType2ReduceScatterOpExpr, ThreadLocal); #undef OF_KERNEL_NOT_SUPPORT_ERROR } // namespace class CommBroadcastFunctor { public: CommBroadcastFunctor() = default; Maybe operator()(const std::shared_ptr& x, int64_t src_rank, bool inplace) const { const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); DeviceType device_type = JUST(x->device())->enum_type(); const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group)); return one::Broadcast(x, src_rank, parallel_desc, inplace); } }; class CommBroadcastTensorsFunctor { public: CommBroadcastTensorsFunctor() = default; Maybe operator()(const one::TensorTuple& inputs, int64_t src_rank, bool inplace) const { if (inputs.empty()) { return inputs; } const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); const auto& x = JUST(VectorAt(inputs, 0)); DeviceType device_type = JUST(x->device())->enum_type(); const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group)); return one::Broadcast(inputs, src_rank, parallel_desc, inplace); } }; namespace { Maybe RawStreamTouchFunctorOpExpr(size_t input_size) { return one::OpBuilder("eager_ccl_touch", *JUST(UniqueStr("eager_ccl_touch"))) .Input("in", input_size) .Build(); } static constexpr auto* StreamTouchFunctorOpExpr = DECORATE(&RawStreamTouchFunctorOpExpr, ThreadLocal); } // namespace class StreamTouchFunctor { public: StreamTouchFunctor() = default; Maybe operator()(const one::TensorTuple& inputs) const { if (inputs.empty()) { return Maybe::Ok(); } std::shared_ptr op_expr = JUST(StreamTouchFunctorOpExpr(inputs.size())); TensorTuple outputs{}; JUST(OpInterpUtil::Dispatch(*op_expr, inputs, &outputs)); return Maybe::Ok(); } }; class LocalAllReduceFunctor { public: LocalAllReduceFunctor() = default; Maybe operator()(const std::shared_ptr& x, bool inplace) const { const auto& device = JUST(x->device()); CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank()); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); DeviceType device_type = device->enum_type(); std::shared_ptr op_expr = JUST(CachedRankGroupAndDeviceType2AllReduceOpExpr(rank_group, device_type)); auto op_input = x; if (const auto& static_zeros_tensor = std::dynamic_pointer_cast(x)) { op_input = std::dynamic_pointer_cast(JUST(static_zeros_tensor->AsLocalTensor())); } if (inplace) { JUST(CheckInplaceValid(op_input)); TensorTuple outputs{op_input}; JUST(OpInterpUtil::Dispatch(*op_expr, {op_input}, &outputs)); return outputs[0]; } else { return OpInterpUtil::Dispatch(*op_expr, {op_input}, {}); } } }; class GlobalAllReduceFunctor { public: GlobalAllReduceFunctor() = default; Maybe operator()(const std::shared_ptr& x) const { { CHECK_OR_RETURN(x->is_global()) << "Tensor is not global"; CHECK_OR_RETURN(NdSbpIsAllPartialSum(*JUST(x->nd_sbp()))) << "Tensor's sbp must be partial_sum"; } std::shared_ptr op_expr = JUST(CachedEagerCclAllReduceOpExpr(JUST(x->parallel_desc()))); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class GlobalReduceScatterFunctor { public: GlobalReduceScatterFunctor() = default; Maybe operator()(const std::shared_ptr& x, const std::string& op_type) const { { CHECK_OR_RETURN(x->is_global()); // NOLINT if (op_type == "max") { CHECK_OR_RETURN(NdSbpIsAllBroadcast(*JUST(x->nd_sbp()))) << "Tensor's sbp must be broadcast to get reduce_max"; 
CHECK_EQ_OR_RETURN(JUST(x->parallel_desc())->device_type(), DeviceType::kCUDA) << "reduce_max only supports CUDA"; } else if (op_type == "sum") { CHECK_OR_RETURN(NdSbpIsAllPartialSum(*JUST(x->nd_sbp()))) << "Tensor's sbp must be partial_sum to get reduce_sum"; } else { UNIMPLEMENTED_THEN_RETURN(); } } std::shared_ptr op_expr = JUST(CachedCclReduceScatterOpExpr(JUST(x->parallel_desc()), op_type)); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class LocalReduceScatterFunctor { public: LocalReduceScatterFunctor() = default; Maybe operator()(const std::shared_ptr& output, const std::shared_ptr& input) const { DataType dtype_val = input->dtype()->data_type(); CHECK_EQ_OR_RETURN(input->shape()->elem_cnt(), output->nelement() * GlobalProcessCtx::WorldSize()) << Error::RuntimeError() << "input tensor size must be equal to world_size times output tensor size"; CHECK_EQ_OR_RETURN(dtype_val, output->dtype()->data_type()) << Error::RuntimeError() << "output tensor must have the same type as input tensor"; const Shape& shape = *output->shape(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("output_shape", "output_dtype"); attrs.SetAllAttrs(shape, dtype_val); const auto& device = JUST(input->device()); CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank()); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); DeviceType device_type = device->enum_type(); std::shared_ptr op_expr = JUST(CachedRankGroupAndDeviceType2ReduceScatterOpExpr(rank_group, device_type)); auto op_input = input; if (const auto& static_zeros_tensor = std::dynamic_pointer_cast(input)) { op_input = std::dynamic_pointer_cast(JUST(static_zeros_tensor->AsLocalTensor())); } TensorTuple outputs{output}; JUST(OpInterpUtil::Dispatch(*op_expr, {op_input}, &outputs, attrs)); return outputs[0]; } }; class GlobalAllGatherFunctor { public: GlobalAllGatherFunctor() = default; Maybe operator()(const std::shared_ptr& x) const { { CHECK_OR_RETURN(x->is_global()) << "Tensor is not global"; CHECK_OR_RETURN(NdSbpIsAllSplit(*JUST(x->nd_sbp()), 0)) << "Tensor's sbp must be split to get all_gather"; } std::shared_ptr op_expr = JUST(CachedEagerCclAllGatherOpExpr(JUST(x->parallel_desc()))); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class LocalAllGatherFunctor { public: LocalAllGatherFunctor() = default; Maybe operator()(const std::shared_ptr& output, const std::shared_ptr& input) const { DataType dtype_val = input->dtype()->data_type(); CHECK_EQ_OR_RETURN(input->shape()->elem_cnt() * GlobalProcessCtx::WorldSize(), output->nelement()) << Error::RuntimeError() << "output tensor size must be equal to world_size times input tensor size"; CHECK_EQ_OR_RETURN(dtype_val, output->dtype()->data_type()) << Error::RuntimeError() << "output tensor must have the same type as input tensor"; const Shape& shape = *output->shape(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("output_shape", "output_dtype"); attrs.SetAllAttrs(shape, dtype_val); const auto& device = JUST(input->device()); CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank()); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); DeviceType device_type = device->enum_type(); std::shared_ptr op_expr = JUST(CachedRankGroupAndDeviceType2AllGatherOpExpr(rank_group, device_type)); auto op_input = input; if (const auto& static_zeros_tensor = std::dynamic_pointer_cast(input)) { op_input = std::dynamic_pointer_cast(JUST(static_zeros_tensor->AsLocalTensor())); } TensorTuple outputs{output};
JUST(OpInterpUtil::Dispatch(*op_expr, {op_input}, &outputs, attrs)); return outputs[0]; } }; class GlobalS2SFunctor { public: GlobalS2SFunctor() = default; Maybe operator()(const std::shared_ptr& x, const std::vector>& sbp_parallels) const { Symbol in_nd_sbp = JUST(x->nd_sbp()); Symbol out_nd_sbp = JUST(GetNdSbp(sbp_parallels)); { CHECK_OR_RETURN(x->is_global()); // NOLINT CHECK_EQ_OR_RETURN(in_nd_sbp->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsSplitSbp(in_nd_sbp->sbp_parallel(0))); CHECK_EQ_OR_RETURN(out_nd_sbp->sbp_parallel_size(), 1); CHECK_OR_RETURN(IsSplitSbp(out_nd_sbp->sbp_parallel(0))); CHECK_NE_OR_RETURN(in_nd_sbp->sbp_parallel(0).split_parallel().axis(), out_nd_sbp->sbp_parallel(0).split_parallel().axis()); } std::shared_ptr op_expr = JUST(CachedEagerCclS2SOpExpr(JUST(x->parallel_desc()), SymbolOf(in_nd_sbp->sbp_parallel(0)), SymbolOf(out_nd_sbp->sbp_parallel(0)))); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class SendFunctor { public: SendFunctor() { op_expr_ = CHECK_JUST(one::OpBuilder("send").Input("in").Build()); } Maybe operator()(const std::shared_ptr& x, int64_t dst, bool send_meta) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dst_process_id"); attrs.SetAllAttrs(dst); if (send_meta) { std::shared_ptr flat_shape = JUST(FlatShape::New(*x->shape())); JUST(ccl::CpuSend(flat_shape.get(), sizeof(*flat_shape), dst)); DataType dtype = x->dtype()->data_type(); JUST(ccl::CpuSend(&dtype, sizeof(dtype), dst)); DeviceType device_type = JUST(Device::GetPlacement(*JUST(x->device())))->device_type(); JUST(ccl::CpuSend(&device_type, sizeof(device_type), dst)); } JUST(OpInterpUtil::Dispatch(*op_expr_, {x}, attrs)); return Maybe::Ok(); } private: std::shared_ptr op_expr_; }; class RecvFunctor { public: RecvFunctor() { op_expr_ = CHECK_JUST(one::OpBuilder("recv").Output("out").Build()); } Maybe operator()(int64_t src, const Optional& optional_shape, const Optional>& optional_dtype, const Optional>& optional_device, const Optional& out) const { Shape shape; DataType data_type = DataType::kInvalidDataType; Symbol device; if (optional_shape.has_value() && optional_dtype.has_value() && optional_device.has_value()) { shape = *JUST(optional_shape); data_type = JUST(optional_dtype)->data_type(); device = JUST(optional_device); } else if (!optional_shape.has_value() && !optional_dtype.has_value() && !optional_device.has_value()) { FlatShape flat_shape{}; JUST(ccl::CpuRecv(&flat_shape, sizeof(flat_shape), src)); shape = *JUST(flat_shape.ToShape()); JUST(ccl::CpuRecv(&data_type, sizeof(data_type), src)); DeviceType device_type = DeviceType::kInvalidDevice; JUST(ccl::CpuRecv(&device_type, sizeof(device_type), src)); device = JUST(Device::New(*JUST(DeviceTag4DeviceType(device_type)))); } else { UNIMPLEMENTED_THEN_RETURN() << "All or none of shape, dtype and device should have value."; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("src_process_id", "shape", "dtype", "device_type", "device_id"); attrs.SetAllAttrs(src, shape, data_type, device->type(), device->device_id()); OpExprInterpContext op_expr_interp_context(attrs, device); if (out.has_value()) { std::shared_ptr out_tensor = JUST(out); Symbol out_tensor_device = JUST(out_tensor->device()); CHECK_OR_RETURN(out_tensor_device == device); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = out_tensor; JUST(OpInterpUtil::Dispatch(*op_expr_, {}, outputs.get(), op_expr_interp_context)); return outputs->at(0); } return OpInterpUtil::Dispatch(*op_expr_, {}, op_expr_interp_context); } private: std::shared_ptr op_expr_; }; 
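// [Editor's note] A minimal usage sketch, not part of the original file: the "Send"/"Recv"
// functors registered below keep a side-band meta exchange symmetric between the two ranks.
// The generated functional API signatures come from functional_api.yaml, so the exact call
// shapes below are an assumption; only the pairing convention is taken from the code above.
//
//   // rank 0: ship the tensor together with its shape/dtype/device meta
//   JUST(functional::Send(tensor, /*dst=*/1, /*send_meta=*/true));
//
//   // rank 1: pass empty optionals so RecvFunctor reads the meta via ccl::CpuRecv in the
//   // same order SendFunctor wrote it (flat shape, then dtype, then device type)
//   const auto& received = JUST(functional::Recv(/*src=*/0, NullOpt, NullOpt, NullOpt, NullOpt));
//
// The invariant that matters is that send_meta=true on the sender must pair with "no explicit
// meta" on the receiver; otherwise the three CpuSend/CpuRecv calls get out of step.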
class LocalReduceFunctor { public: LocalReduceFunctor() = default; Maybe operator()(const std::shared_ptr& x, int64_t dst, bool inplace) const { const auto& device = JUST(x->device()); { CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank()); } static thread_local std::unordered_map, Symbol>, Symbol> rank_group_with_device2parallel_desc; const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); auto iter = rank_group_with_device2parallel_desc.find({rank_group, device}); Symbol parallel_desc; if (iter == rank_group_with_device2parallel_desc.end()) { ParallelConf parallel_conf; parallel_conf.set_device_tag(device->type()); JUST(rank_group->ForEachRank([&parallel_conf](int64_t rank) -> Maybe { parallel_conf.add_device_name("@" + std::to_string(rank) + ":" + std::to_string(GlobalProcessCtx::LocalRank(rank))); return Maybe::Ok(); })); parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); rank_group_with_device2parallel_desc[{rank_group, device}] = parallel_desc; } else { parallel_desc = iter->second; } std::shared_ptr op_expr = JUST(CachedEagerCclReduceOpExpr(parallel_desc, dst)); if (inplace) { TensorTuple outputs{x}; JUST(OpInterpUtil::Dispatch(*op_expr, {x}, &outputs)); return x; } else { return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } } }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("StreamTouch"); m.add_functor("CommBroadcast"); m.add_functor("CommBroadcastTensors"); m.add_functor("LocalAllReduce"); m.add_functor("LocalAllGather"); m.add_functor("LocalReduceScatter"); m.add_functor("GlobalAllReduce"); m.add_functor("GlobalReduceScatter"); m.add_functor("GlobalAllGather"); m.add_functor("GlobalS2S"); m.add_functor("Send"); m.add_functor("Recv"); m.add_functor("LocalReduce"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/common.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "fmt/core.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/common/wrap_dim_utils.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/ccl/ccl.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/memory/memory_case_util.h" namespace oneflow { namespace one { namespace functional { namespace { Maybe InferUnifiedShapeForBroadcasting(const Shape& input_shape, const Shape& other_shape) { // same shapes need no broadcasting if (input_shape == other_shape) { return input_shape; } const auto unify_shapes_with_same_num_axes = [](const Shape& input_shape, const Shape& other_shape) -> Maybe { // num_axes.first == num_axes.second Shape target; for (size_t i = 0; i < input_shape.NumAxes() /* both input_shape and other_shape are ok */; ++i) { const auto num_in_curr_dim = std::make_pair(input_shape.At(i), other_shape.At(i)); // A = (2, ), B = (2, ), A[0] == B[0], so C = (2, ) if (num_in_curr_dim.first == num_in_curr_dim.second) { target.push_back(num_in_curr_dim.first); continue; } // A = (2, ), B = (3, ), A[0] != B[0] and A[0] != 1 and B[0] != 1, so raise RuntimeError if (num_in_curr_dim.first != 1 && num_in_curr_dim.second != 1) { return Error::RuntimeError() << fmt::format("input and other can't be broadcasted to a single shape. [input's " "shape: {}, other's shape: {}].", input_shape.ToString(), other_shape.ToString()); } // A = (2, ), B = (1, ), A[0] != B[0] but B[0] == 1, so C = (2, ) target.push_back( num_in_curr_dim.first == 1 ? 
num_in_curr_dim.second : num_in_curr_dim.first); // num_in_curr_dim.first and num_in_curr_dim.second can't // be 1 at the same time } return target; }; const int64_t input_num_axes = input_shape.NumAxes(); const int64_t other_num_axes = other_shape.NumAxes(); if (input_num_axes == other_num_axes) { return unify_shapes_with_same_num_axes(input_shape, other_shape); } const int64_t unified_num_axes = std::max(input_num_axes, other_num_axes); // shape = (3, 4) and unified_num_axes = 3 ==> shape will be (1, 3, 4) const auto expand_shape_if_necessary = [unified_num_axes](const Shape& shape_to_expand) { const int64_t shape_to_expand_num_axes = shape_to_expand.NumAxes(); if (shape_to_expand_num_axes < unified_num_axes) { auto new_shape = Shape::Ones(unified_num_axes); std::copy(shape_to_expand.begin(), shape_to_expand.end(), new_shape.begin() + (unified_num_axes - shape_to_expand_num_axes)); return new_shape; } return shape_to_expand; }; return unify_shapes_with_same_num_axes(expand_shape_if_necessary(input_shape), expand_shape_if_necessary(other_shape)); } } // namespace bool IsStaticZerosTensor(const std::shared_ptr& x) { return nullptr != std::dynamic_pointer_cast(x); } bool IsInplaceValid(const std::shared_ptr& x) { return !autograd::GradMode::is_enabled() || !(x->is_leaf() && x->requires_grad()); } bool IsScalarTensor(const std::shared_ptr& x) { return x->shape()->NumAxes() == 0 && x->shape()->elem_cnt() == 1; } Maybe ComputeNonOverlappingAndDense(const std::shared_ptr& x) { // A function used to check whether the tensor is non-overlapping and dense, reference: (pytorch) // c10/core/TensorImpl.cpp const int64_t ndim = x->ndim(); const auto& shape = x->shape(); const auto& stride = JUST(x->stride()); // If 1D tensor and shape(0) < 2 or stride(0) == 1 then true if (ndim == 1) { return shape->at(0) < 2 || stride->at(0) == 1; } small_vector perm; perm.resize(ndim); for (int64_t i = 0; i < ndim; ++i) { perm[i] = i; } // Sort by strides, leaving 0 and 1 sized dims at the end of the array std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { if (shape->at(a) < 2) { return false; } else if (shape->at(b) < 2) { return true; } return stride->at(a) < stride->at(b); }); // Check if target stride == required stride auto require_stride = 1; for (int64_t i = 0; i < ndim; ++i) { const auto size_perm_i = shape->at(perm[i]); if (size_perm_i < 2) { return true; } if (stride->at(perm[i]) != require_stride) { return false; } require_stride *= size_perm_i; } return true; } Maybe IsNonOverlappingAndDense(const std::shared_ptr& x) { // If the tensor is_contiguous or ComputeNonOverlappingAndDense returns true, then its memory // layout is non-overlapping and dense. 
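// [Editor's note] A small worked example of the check above (illustrative values, not taken
// from the original source): for shape (2, 3) with stride (3, 1), walking the dims sorted by
// stride expects required strides 1 and then 3, both match, so the layout is dense; for shape
// (2, 3) with stride (4, 1) (a padded row), the second step expects stride 3 but finds 4, so
// ComputeNonOverlappingAndDense returns false even though no two elements overlap.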
return x->is_contiguous() || JUST(ComputeNonOverlappingAndDense(x)); } Maybe> CheckAxis(const std::vector& axis, const int32_t& ndim) { const int32_t naxis = axis.size(); int32_t reduce_ndim = naxis; if (naxis == 0 || ndim == 0) { reduce_ndim = ndim; }; std::vector reduce_axis(reduce_ndim); if (naxis == 0) { std::iota(reduce_axis.begin(), reduce_axis.end(), 0); } else { JUST(dim_list_to_bitset(axis, ndim)); // checking axis[dim]'s validation for (int32_t i = 0; i < naxis; i++) { if (i < reduce_ndim) { reduce_axis[i] = JUST(maybe_wrap_dim(axis[i], ndim)); }; } } return reduce_axis; } Maybe CheckInplaceValid(const std::shared_ptr& x) { CHECK_OR_RETURN(IsInplaceValid(x)) << Error::RuntimeError() << "a leaf Tensor that requires grad is being used in an in-place operation"; return Maybe::Ok(); } Maybe CheckInplaceCastValid(const std::shared_ptr& x, const std::shared_ptr& x_cast) { CHECK_OR_RETURN(*x->dtype() == *x_cast->dtype()) << Error::RuntimeError() << "result type " << x_cast->dtype()->name() << " can't be cast to the desired output type " << x->dtype()->name(); return Maybe::Ok(); } Maybe CheckInplaceShapeCanExpandTo(const Shape& shape, const Shape& expand_shape) { if (shape == expand_shape) { return Maybe::Ok(); } CHECK_OR_RETURN(expand_shape.NumAxes() >= shape.NumAxes()) << Error::RuntimeError() << "Can not expand origin shape " << shape.ToString() << " to " << expand_shape.ToString() << " in an inplace operation"; int shift = expand_shape.NumAxes() - shape.NumAxes(); for (int i = expand_shape.NumAxes() - 1; i >= 0; --i) { int index = i - shift; if (index >= 0) { int dim_a = expand_shape.At(i); int dim_b = shape.At(index); // NOTE(lixiang): When a dimension of tensor a and tensor b are not equal in size, dim_a needs // to be greater than or equal 0, and dim_b should be equal to 1. CHECK_OR_RETURN(!(dim_a != dim_b && (dim_a < 0 || dim_b != 1))) << Error::RuntimeError() << "Tensor with shape " << expand_shape.ToString() << " doesn't match the broadcast shape in an inplace operation"; } else { // For 0-size tensor, expand_shape.At(i) can equal to 0. 
CHECK_OR_RETURN(expand_shape.At(i) >= 0); // NOLINT(maybe-need-error-msg) } } return Maybe::Ok(); } Optional ComputeStride(const Shape& shape, const Stride& stride, const Shape& target_shape) { /************************************************* * Description: in some case, view operate is not allowed, so need to check it's validation, * the check refers to torch(aten/src/ATen/native/TensorShape.cpp) *************************************************/ if (stride.size() == 0) { // for scalar input tensor return Stride(target_shape.NumAxes(), 1); } int64_t elem_count = shape.elem_cnt(); int64_t ndim = shape.NumAxes(); int64_t tgt_ndim = target_shape.NumAxes(); DimVector shape_vec = shape.dim_vec(); DimVector tgt_shape_vec = target_shape.dim_vec(); if (elem_count == 0) { return NullOpt; } int64_t view_d = tgt_ndim - 1; int64_t chunk_base_stride = stride.back(); Stride target_stride(tgt_ndim); // stride for each subspace in the chunk // numel in current chunk int64_t tensor_numel = 1; int64_t view_numel = 1; for (int64_t tensor_d = ndim - 1; tensor_d >= 0; tensor_d--) { tensor_numel *= shape_vec[tensor_d]; // if end of tensor size chunk, check view if ((tensor_d == 0) || (shape_vec[tensor_d - 1] != 1 && stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) { while (view_d >= 0 && (view_numel < tensor_numel || tgt_shape_vec[view_d] == 1)) { target_stride[view_d] = view_numel * chunk_base_stride; view_numel *= tgt_shape_vec[view_d]; view_d--; } if (view_numel != tensor_numel) { return NullOpt; } if (tensor_d > 0) { chunk_base_stride = stride[tensor_d - 1]; tensor_numel = 1; view_numel = 1; } } } if (view_d != -1) { return NullOpt; } return target_stride; } Maybe InferShapeUnspecifiedDim(const int64_t& elem_count, const Shape& shape) { int need_infer_axis = -1; int64_t target_elem_count = 1; for (int i = 0; i < shape.NumAxes(); ++i) { if (shape.At(i) < -1) { return Error::RuntimeError() << "Invalid shape dimension " << shape.At(i); } else if (shape.At(i) == -1) { CHECK_OR_RETURN_ERROR(need_infer_axis == -1) << Error::RuntimeError() << "only one dimension can be inferred"; need_infer_axis = i; } else { target_elem_count *= shape.At(i); } } Shape infered_shape = shape; if (need_infer_axis == -1) { if (elem_count > 0) { // For 0-size tensor, we don't need to check the element size. CHECK_OR_RETURN_ERROR(target_elem_count == elem_count) << Error::RuntimeError() << "shape '" << shape.ToString() << "' is invalid for input of size " << elem_count; } } else { infered_shape.Set(need_infer_axis, elem_count / target_elem_count); CHECK_OR_RETURN_ERROR(target_elem_count * infered_shape.At(need_infer_axis) == elem_count) << Error::RuntimeError() << "shape '" << shape.ToString() << "' is invalid for input of size " << elem_count; } return infered_shape; } Maybe InferUnifiedShapeForBroadcasting(const std::vector& shapes) { if (shapes.empty()) { return Error::RuntimeError() << "shapes should not be empty."; } if (shapes.size() == 1) { return JUST(VectorAt(shapes, 0)); } auto result = *JUST(InferUnifiedShapeForBroadcasting(JUST(VectorAt(shapes, 0)), JUST(VectorAt(shapes, 1)))); // (1, 2) vs (3, 2) => (3, 2) if (shapes.size() == 2) { return result; } /* (1, 3) vs (3, 1) vs (3, 1, 1) 1. (1, 3) vs (3, 1) => (3, 3) 2. (3, 3) vs (3, 1, 1) => (3, 3, 3) 3. 
(3, 3, 3) is the final result */ for (auto iter = shapes.begin() + 2; iter != shapes.end(); ++iter) { result = *JUST(InferUnifiedShapeForBroadcasting(result, *iter)); } return result; } /* if input shapes are [(1, 3), (3, 1), (3, 1, 1)] will return ((3, 3, 3), [true, true, true]) means the shape to broadcast to is (3, 3, 3) and all three shapes need broadcasting */ Maybe>> InferUnifiedShapeForBroadcastingWithInfo( const std::vector& shapes) { const auto unified_shape = *JUST(InferUnifiedShapeForBroadcasting(shapes)); std::deque need_to_broadcast; for (const auto& x : shapes) { need_to_broadcast.emplace_back(x != unified_shape); } return std::make_tuple(unified_shape, need_to_broadcast); } Maybe BroadcastSeedToAllRanks(uint64_t* seed, int64_t root) { CHECK_NOTNULL_OR_RETURN(seed) << "seed is not allowed to be nullptr"; const auto& rank_group = JUST(RankGroup::DefaultRankGroup()); const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(DeviceType::kCPU, rank_group)); const auto& meta_transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta)); JUST(ccl::CpuBroadcast(seed, seed, sizeof(*seed), root, parallel_desc, meta_transport_token)); return Maybe::Ok(); } Maybe> GetPermWhenTransposeAxisToLastDim(const int32_t& ndim, const int32_t& axis) { auto wrap_dim = JUST(maybe_wrap_dim(axis, ndim)); std::vector perm(ndim); for (int i = 0; i < ndim - 1; i++) { if (i < wrap_dim) { perm[i] = i; } else { perm[i] = i + 1; } } perm[ndim - 1] = wrap_dim; return perm; } Maybe> GetInversedPerm(const std::vector& perm) { std::vector inversed_perm(perm.size()); for (int i = 0; i < perm.size(); i++) { inversed_perm[perm[i]] = i; } return inversed_perm; } Maybe, bool>> batchify(const std::shared_ptr& input, const int64_t num_spatial_dims, const std::string& func_name) { const int64_t dim_count_no_batch = num_spatial_dims + 1; const int64_t dim_count_batch = dim_count_no_batch + 1; const bool is_batched = (input->ndim() == dim_count_batch); CHECK_EQ_OR_RETURN(input->ndim() == dim_count_no_batch || is_batched, true) << fmt::format( "Expected `{}`D (unbatched) or `{}`D (batched) input to `{}`, but got input of size: `{}`", dim_count_no_batch, dim_count_batch, func_name, input->shape()->DebugStr()); return std::make_tuple(is_batched ? input : JUST(functional::Unsqueeze(input, 0)), is_batched); } template T GetTensorItemValue(const std::shared_ptr& input) { CHECK_EQ_OR_THROW(input->nelement(), 1) << "Input tensor must have exactly one element"; T value; const auto& callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, &value, eager_blob_object->dptr(), sizeof(T), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; SyncAccessTensorWithTimeOut(input, callback, "const").GetOrThrow(); return value; } Maybe CheckNormalTensorStd(const std::shared_ptr& std) { CHECK_OR_RETURN(!std->dtype()->is_complex()) << "normal expects standard deviation to be non-complex"; if (std->nelement() > 0) { auto std_check = CHECK_JUST(ScalarLogicalGreaterEqual(CHECK_JUST(Min(std)), Scalar(0.0))); CHECK_OR_THROW(GetTensorItemValue(std_check)) << "normal expects all elements of std >= 0.0"; } return Maybe::Ok(); } Maybe CheckNormalTensorStd(const float std) { CHECK_GE_OR_RETURN(std, 0.0) << "normal expects std >= 0.0, but found std " << (std) << ". 
This may cause an error."; return Maybe::Ok(); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/common.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_ #define ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_ #include "oneflow/core/framework/tensor.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/maybe.h" #include "fmt/core.h" namespace oneflow { namespace one { namespace functional { static constexpr size_t kMaxInputCount = 128; static constexpr size_t kMaxOutputCount = 128; bool IsStaticZerosTensor(const std::shared_ptr& x); bool IsInplaceValid(const std::shared_ptr& x); bool IsScalarTensor(const std::shared_ptr& x); Maybe ComputeNonOverlappingAndDense(const std::shared_ptr& x); Maybe IsNonOverlappingAndDense(const std::shared_ptr& x); Maybe> CheckAxis(const std::vector& axis, const int32_t& ndim); Maybe CheckInplaceValid(const std::shared_ptr& x); Maybe CheckInplaceCastValid(const std::shared_ptr& x, const std::shared_ptr& x_cast); Maybe CheckInplaceShapeCanExpandTo(const Shape& shape, const Shape& expand_shape); Optional ComputeStride(const Shape& shape, const Stride& stride, const Shape& target_shape); Maybe InferShapeUnspecifiedDim(const int64_t& elem_count, const Shape& shape); // returns unified_shape Maybe InferUnifiedShapeForBroadcasting(const std::vector& shapes); // returns tuple Maybe>> InferUnifiedShapeForBroadcastingWithInfo( const std::vector& shapes); Maybe BroadcastSeedToAllRanks(uint64_t* seed, int64_t root = 0); Maybe> GetPermWhenTransposeAxisToLastDim(const int32_t& ndim, const int32_t& axis); Maybe> GetInversedPerm(const std::vector& perm); // batchify function is referenced from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Convolution.cpp#L729 Maybe, bool>> batchify(const std::shared_ptr& input, const int64_t num_spatial_dims, const std::string& func_name); // CheckNormalTensorStd function is referenced from // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/DistributionTemplates.h#L171-L182 Maybe CheckNormalTensorStd(const std::shared_ptr& std); Maybe CheckNormalTensorStd(const float std); } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_ ================================================ FILE: oneflow/core/functional/impl/dataset_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/framework/nd_sbp.h" namespace oneflow { namespace one { namespace functional { namespace impl { class ImageFlipFuntor { public: ImageFlipFuntor() { op_ = CHECK_JUST( one::OpBuilder("image_flip").Input("in").Input("flip_code").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& flip_code) const { return OpInterpUtil::Dispatch(*op_, {x, flip_code}); } private: std::shared_ptr op_; }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("ImageFlip"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/eye_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/job/sbp_parallel.h" namespace oneflow { namespace one { namespace functional { namespace impl { class EyeDevcieFunctor { public: EyeDevcieFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); } Maybe operator()(const Scalar& rows, const Optional& cols, const Symbol& dtype, const Optional>& device, const bool& requires_grad) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rows", "cols", "dtype"); attrs.SetAllAttrs(rows.As(), cols.value_or(rows).As(), dtype->data_type()); OpExprInterpContext ctx(attrs); ctx.device = device; auto res = JUST(OpInterpUtil::Dispatch(*op_, {}, ctx)); JUST(res->set_requires_grad(requires_grad)); return res; } private: std::shared_ptr op_; }; class EyeDeviceStrFunctor { public: Maybe operator()(const Scalar& rows, const Optional& cols, const Symbol& dtype, const std::string& device, const bool& requires_grad) const { const Symbol& dev = JUST(Device::ParseAndNew(device)); return JUST(functional::Eye(rows, cols, dtype, dev, requires_grad)); } }; class GlobalEyeSbpListFunctor { public: GlobalEyeSbpListFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); } Maybe operator()(const Scalar& rows, const Optional& cols, const Symbol& dtype, const bool& requires_grad, const Symbol& placement, const std::vector>& sbp_tuple) const { CHECK_EQ_OR_RETURN(sbp_tuple.size(), placement->hierarchy()->NumAxes()) << "len(sbp) == len(placement.hierarchy) required, but " << "len(sbp)==" << sbp_tuple.size() << ", " << "len(placement.hierarchy)==" << placement->hierarchy()->NumAxes(); FOR_RANGE(int32_t, i, 0, sbp_tuple.size()) { CHECK_OR_RETURN(sbp_tuple.at(i)->has_broadcast_parallel()) << "sbp of eye should be broadcast only"; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rows", "cols", "dtype", "nd_sbp"); if (LazyMode::is_enabled()) { std::vector nd_sbp(sbp_tuple.size()); { for (int i = 0; i < sbp_tuple.size(); ++i) { nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i)); } } attrs.SetAllAttrs(rows.As(), cols.value_or(rows).As(), dtype->data_type(), nd_sbp); } else { attrs.SetAllAttrs(rows.As(), cols.value_or(rows).As(), dtype->data_type(), NullOpt); } const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto res = JUST( OpInterpUtil::Dispatch(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp))); JUST(res->set_requires_grad(requires_grad)); return res; } private: std::shared_ptr op_; }; class GlobalEyeSbpFunctor { public: Maybe operator()(const Scalar& rows, const Optional& cols, const Symbol& dtype, const bool& requires_grad, const Symbol& placement, const Symbol& sbp) const { std::vector> sbp_tuple{sbp}; return JUST(functional::Eye(rows, cols, dtype, 
requires_grad, placement, sbp_tuple)); } }; } // namespace impl class EyeInplaceFunctor { public: EyeInplaceFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x) const { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rows", "cols", "dtype"); attrs.SetAllAttrs(x->shape()->At(0), x->shape()->At(1), x->dtype()->data_type()); OpExprInterpContext ctx(attrs); ctx.device = JUST(x->device()); JUST(OpInterpUtil::Dispatch(*op_, {}, outputs.get(), ctx)); return outputs->at(0); } private: std::shared_ptr op_; }; using namespace impl; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Eye"); m.add_functor("EyeInplace"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/fused_attention_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "fmt/core.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/random_mask_like_kernel.h" #include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/core/common/container_util.h" #include "oneflow/user/kernels/distributions/common.h" namespace oneflow { namespace one { namespace functional { namespace impl { namespace { Maybe ParseDims(const std::string& name, const Shape& shape, const std::string& layout, const Optional& batch_size, const Optional& seq_len, const Optional& num_heads, const Optional& head_size, int64_t* b, int64_t* m, int64_t* h, int64_t* k, bool* bm_packed) { if (shape.NumAxes() == 2) { if (layout == "(BM)(HK)" || layout == "(BM)(H2K)" || layout == "(BM)(H3K)") { *bm_packed = true; CHECK_OR_RETURN(batch_size); CHECK_OR_RETURN(seq_len); *b = JUST(batch_size); *m = JUST(seq_len); int64_t packed_n = 0; if (layout == "(BM)(HK)") { packed_n = 1; } else if (layout == "(BM)(H2K)") { CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be '(BM)(H2K)'"; packed_n = 2; } else if (layout == "(BM)(H3K)") { packed_n = 3; } else { UNIMPLEMENTED_THEN_RETURN(); } const int64_t hidden_size = shape.At(1); if (num_heads) { const int64_t expected_h = JUST(num_heads); const int64_t packed_h = packed_n * expected_h; CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0) << "The size of the last dimension of the " << name << " tensor should be a multiple of " << packed_h << "."; *h = expected_h; *k = hidden_size / packed_h; } else if (head_size) { const int64_t expected_k = JUST(head_size); const int64_t packed_k = expected_k * packed_n; 
CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0) << "The size of the last dimension of the " << name << " tensor should be a multiple of " << packed_k << "."; *h = hidden_size / packed_k; *k = expected_k; } else { UNIMPLEMENTED_THEN_RETURN(); } } else { UNIMPLEMENTED_THEN_RETURN() << name << "_layout should be '(BM)(HK)', '(BM)(H2K)', or '(BM)(H3K)' " "when the number of dimensions of " << name << " tensor is 2."; } } else if (shape.NumAxes() == 3) { if (layout == "BM(HK)" || layout == "MB(HK)" || layout == "BM(H2K)" || layout == "MB(H2K)" || layout == "BM(H3K)" || layout == "MB(H3K)") { *bm_packed = false; int64_t packed_n = 0; if (layout == "BM(HK)") { *b = shape.At(0); *m = shape.At(1); packed_n = 1; } else if (layout == "MB(HK)") { *b = shape.At(1); *m = shape.At(0); packed_n = 1; } else if (layout == "BM(H2K)") { CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be 'BM(H2K)'"; *b = shape.At(0); *m = shape.At(1); packed_n = 2; } else if (layout == "MB(H2K)") { CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be 'MB(H2K)'"; *b = shape.At(1); *m = shape.At(0); packed_n = 2; } else if (layout == "BM(H3K)") { *b = shape.At(0); *m = shape.At(1); packed_n = 3; } else if (layout == "MB(H3K)") { *b = shape.At(1); *m = shape.At(0); packed_n = 3; } else { UNIMPLEMENTED_THEN_RETURN(); } const int64_t hidden_size = shape.At(2); if (num_heads) { const int64_t expected_h = JUST(num_heads); const int64_t packed_h = packed_n * expected_h; CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0) << "The size of the last dimension of the " << name << " tensor should be a multiple of " << packed_h << "."; *h = expected_h; *k = hidden_size / packed_h; } else if (head_size) { const int64_t expected_k = JUST(head_size); const int64_t packed_k = expected_k * packed_n; CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0) << "The size of the last dimension of the " << name << " tensor should be a multiple of " << packed_k << "."; *h = hidden_size / packed_k; *k = expected_k; } else { UNIMPLEMENTED_THEN_RETURN(); } } else if (layout == "(BM)HK") { *bm_packed = true; CHECK_OR_RETURN(batch_size); CHECK_OR_RETURN(seq_len); *b = JUST(batch_size); *m = JUST(seq_len); *h = shape.At(1); *k = shape.At(2); } else { UNIMPLEMENTED_THEN_RETURN() << name << "_layout should be 'BM(HK)', 'MB(HK)', 'BM(H2K)', 'MB(H2K)', 'BM(H3K)', " "'MB(H3K)' or '(BM)HK' when the number of dimensions of " << name << " tensor is 3."; } } else if (shape.NumAxes() == 4) { *bm_packed = false; if (layout == "BMHK") { *b = shape.At(0); *m = shape.At(1); *h = shape.At(2); *k = shape.At(3); } else if (layout == "BHMK") { *b = shape.At(0); *m = shape.At(2); *h = shape.At(1); *k = shape.At(3); } else if (layout == "MBHK") { *b = shape.At(1); *m = shape.At(0); *h = shape.At(2); *k = shape.At(3); } else { UNIMPLEMENTED_THEN_RETURN() << name << "_layout should be 'BMHK', 'BHMK' or 'MBHK' when the number of dimensions of " << name << " tensor is 4."; } } else { UNIMPLEMENTED_THEN_RETURN() << "The number of dimensions of the " << name << " tensor should be 3 or 4"; }; if (batch_size) { const int64_t expected_b = JUST(batch_size); CHECK_EQ_OR_RETURN(*b, expected_b) << "The size of dimension 'B' of " << name << " tensor should be " << expected_b << "."; } if (seq_len) { const int64_t expected_m = JUST(seq_len); CHECK_EQ_OR_RETURN(*m, expected_m) << "The size of dimension 'M' of " << name << " tensor should be " << expected_m << "."; } if (num_heads) { const int64_t expected_h = JUST(num_heads); CHECK_EQ_OR_RETURN(*h, expected_h) << "The size of dimension 
'H' of " << name << " tensor should be " << expected_h << "."; } if (head_size) { const int64_t expected_k = JUST(head_size); CHECK_EQ_OR_RETURN(*k, expected_k) << "The size of dimension 'K' of " << name << " tensor should be " << expected_k << "."; } return Maybe::Ok(); } Maybe ParseDims(const std::string& name, const Shape& shape, const std::string& layout, const Optional& num_heads, const Optional& head_size, int64_t* b, int64_t* m, int64_t* h, int64_t* k) { bool bm_packed{}; return ParseDims(name, shape, layout, Optional(), Optional(), num_heads, head_size, b, m, h, k, &bm_packed); } } // namespace class FusedMultiHeadAttentionInferenceFunctor { public: FusedMultiHeadAttentionInferenceFunctor() = default; Maybe operator()( const std::shared_ptr& query, const std::shared_ptr& key, const std::shared_ptr& value, const int64_t& num_heads, const bool& causal, const int64_t& query_hidden_slice_start, const int64_t& query_hidden_slice_end, const int64_t& key_hidden_slice_start, const int64_t& key_hidden_slice_end, const int64_t& value_hidden_slice_start, const int64_t& value_hidden_slice_end, const Optional& attn_bias, const int64_t& causal_diagonal_offset) const { CHECK_OR_RETURN(query_hidden_slice_start == 0 && key_hidden_slice_start == 0 && value_hidden_slice_start == 0 && query_hidden_slice_end == -1 && key_hidden_slice_end == -1 && value_hidden_slice_end == -1) << "The parameters 'query_hidden_slice_start', 'query_hidden_slice_end', " "'key_hidden_slice_start', 'key_hidden_slice_end', 'value_hidden_slice_start', " "'value_hidden_slice_end' have been deprecated."; const int64_t query_hidden_size = query->shape()->At(2); CHECK_EQ_OR_RETURN(query_hidden_size % num_heads, 0) << "The hidden size of the query tensor should be a multiple of num_heads."; const int64_t query_head_size = query_hidden_size / num_heads; return functional::FusedMultiHeadAttentionInferenceV2( query, "BM(HK)", query_head_size, Optional(), Optional(), key, "BM(HK)", Optional(), Optional(), Optional(), value, "BM(HK)", attn_bias, "BM(HK)", Optional(), causal, Optional(), causal_diagonal_offset); } }; class FusedMultiHeadAttentionInferenceV2Functor { public: struct OpExprCacheKey { bool has_attn_bias = false; bool has_seq_start = false; bool has_key_seq_len = false; bool operator==(const OpExprCacheKey& rhs) const { return this->has_attn_bias == rhs.has_attn_bias && this->has_seq_start == rhs.has_seq_start && this->has_key_seq_len == rhs.has_key_seq_len; } }; struct OpExprCacheKeyHash { size_t operator()(const OpExprCacheKey& key) const { return Hash(key.has_attn_bias, key.has_seq_start, key.has_key_seq_len); } }; using OpExprCache = std::unordered_map, OpExprCacheKeyHash>; FusedMultiHeadAttentionInferenceV2Functor() { for (bool has_attn_bias : {false, true}) { for (bool has_seq_start : {false, true}) { for (bool has_key_seq_len : {false, true}) { auto builder = one::OpBuilder("fused_multi_head_attention_inference") .Input("query") .Input("key") .Input("value"); if (has_attn_bias) { builder.Input("attn_bias"); } if (has_seq_start) { builder.Input("query_seq_start").Input("key_seq_start"); } if (has_key_seq_len) { builder.Input("key_seq_len"); } auto op = CHECK_JUST(builder.Output("out").Build()); OpExprCacheKey key; key.has_attn_bias = has_attn_bias; key.has_seq_start = has_seq_start; key.has_key_seq_len = has_key_seq_len; op_cache_.emplace(key, op); } } } } Maybe operator()( const std::shared_ptr& query, const std::string& query_layout, const Optional& query_head_size, const Optional& query_seq_start, const Optional& 
query_max_seq_len, const Optional& key, const Optional& key_layout, const Optional& key_seq_start, const Optional& key_seq_len, const Optional& key_max_seq_len, const Optional& value, const Optional& value_layout, const Optional& attn_bias, const std::string& output_layout, const Optional& scale, const Optional& causal, const Optional& attn_mask_type, const int64_t& causal_diagonal_offset) const { std::string attn_mask_type_val = "none"; if (attn_mask_type) { CHECK(!causal) << "Only one of attn_mask_type and causal can be specified at the same time."; attn_mask_type_val = *JUST(attn_mask_type); CHECK_OR_RETURN(attn_mask_type_val == "none" || attn_mask_type_val == "causal_from_top_left" || attn_mask_type_val == "causal_from_bottom_right") << "The value of attn_mask_type should be one of 'none', 'causal_from_top_left' or " "'causal_from_bottom_right'"; } else if (causal && JUST(causal)) { attn_mask_type_val = "causal_from_top_left"; } else { // do nothing } CHECK_GE_OR_RETURN(causal_diagonal_offset, 0) << "The value of causal_diagonal_offset should be greater or equal to 0."; Optional batch_size; std::shared_ptr query_seq_start_tensor; std::shared_ptr key_seq_start_tensor; if (query_seq_start) { CHECK_OR_RETURN(key_seq_start) << "The tensors query_seq_start and key_seq_start should both " "be None or both not be None at the same time."; CHECK_OR_RETURN(query_max_seq_len) << "query_max_seq_len should not be None when query_seq_start is not None."; CHECK_OR_RETURN(key_max_seq_len) << "key_max_seq_len should not be None when key_seq_start is not None."; query_seq_start_tensor = JUST(query_seq_start); key_seq_start_tensor = JUST(key_seq_start); CHECK_EQ_OR_RETURN(query_seq_start_tensor->shape()->NumAxes(), 1) << "The number of dimensions of query_seq_start tensor should be 1."; CHECK_OR_RETURN(*query_seq_start_tensor->shape() == *key_seq_start_tensor->shape()) << "The shapes of the query_seq_start and key_seq_start tensors should match."; CHECK_GT_OR_RETURN(query_seq_start_tensor->shape()->At(0), 1) << "The size of query_seq_start should be greater than 1."; batch_size = query_seq_start_tensor->shape()->At(0) - 1; if (key_seq_len) { CHECK_EQ_OR_RETURN(JUST(key_seq_len)->shape()->NumAxes(), 1) << "The number of dimensions of key_seq_len tensor should be 1."; CHECK_EQ_OR_RETURN(JUST(key_seq_len)->shape()->At(0), JUST(batch_size)) << "The size of the key_seq_len tensor should be " << JUST(batch_size) << "."; } } else { CHECK_OR_RETURN(!key_seq_start) << "The tensors query_seq_start and key_seq_start should both " "be None or both not be None at the same time."; CHECK_OR_RETURN(!key_seq_len) << "The key_seq_len tensor should be None when query_seq_start is None."; } std::shared_ptr key_tensor; std::string key_tensor_layout; std::shared_ptr value_tensor; std::string value_tensor_layout; int64_t q_b = 0; int64_t q_m = 0; int64_t q_h = 0; int64_t q_k = 0; bool q_bm_packed = false; JUST(ParseDims("query", *query->shape(), query_layout, batch_size, query_max_seq_len, Optional(), query_head_size, &q_b, &q_m, &q_h, &q_k, &q_bm_packed)); CHECK_EQ_OR_RETURN(q_k % 8, 0) << "The size of dimension 'K' of the query tensor should be a multiple of 8."; if (q_bm_packed) { CHECK_OR_RETURN(query_seq_start) << "The query_seq_start tensor should not be None when the query tensor is BM-Packed."; } int64_t k_b = 0; int64_t k_m = 0; int64_t k_h = 0; int64_t k_k = 0; bool k_bm_packed = false; if (key) { key_tensor = JUST(key); key_tensor_layout = *JUST(key_layout); JUST(ParseDims("key", *key_tensor->shape(), 
key_tensor_layout, q_b, key_max_seq_len, Optional(), q_k, &k_b, &k_m, &k_h, &k_k, &k_bm_packed)); CHECK_EQ_OR_RETURN(k_b, q_b) << "The size of dimension 'B' of the key tensor should be the " "same as that of the query tensor."; CHECK_EQ_OR_RETURN(k_h, q_h) << "The size of dimension 'H' of the key tensor should be the " "same as that of the query tensor."; CHECK_EQ_OR_RETURN(k_bm_packed, q_bm_packed) << "The query tensor and the key tensor should either both be BM-Packed or both not be " "BM-Packed at the same time."; } else { CHECK_OR_RETURN(query_layout == "BM(H3K)" || query_layout == "MB(H3K)") << "The value of query_layout should be 'BM(H3K)' or 'MB(H3K)' when the key tensor is " "None."; key_tensor = query; key_tensor_layout = query_layout; k_b = q_b; k_m = q_m; k_h = q_h; k_k = q_k; k_bm_packed = q_bm_packed; } int64_t v_b = 0; int64_t v_m = 0; int64_t v_h = 0; int64_t v_k = 0; bool v_bm_packed = false; if (value) { value_tensor = JUST(value); value_tensor_layout = *JUST(value_layout); JUST(ParseDims("value", *value_tensor->shape(), value_tensor_layout, q_b, k_m, q_h, Optional(), &v_b, &v_m, &v_h, &v_k, &v_bm_packed)); CHECK_EQ_OR_RETURN(v_b, q_b) << "The size of dimension 'B' of the value tensor should be the " "same as that of the query tensor."; CHECK_EQ_OR_RETURN(v_m, k_m) << "The size of dimension 'M' of the value tensor should be the " "same as that of the key tensor."; CHECK_EQ_OR_RETURN(v_k % 8, 0) << "The size of dimension 'K' of the value tensor should be a multiple of 8."; CHECK_EQ_OR_RETURN(v_bm_packed, k_bm_packed) << "The key tensor and the value tensor should either both be BM-Packed or both not be " "BM-Packed at the same time."; } else { CHECK_OR_RETURN(key_tensor_layout == "BM(H2K)" || key_tensor_layout == "MB(H2K)" || key_tensor_layout == "BM(H3K)" || key_tensor_layout == "MB(H3K)") << "The value of key_layout should be 'BM(H3K)', 'MB(H3K)', 'BM(H2K)' or 'MB(H2K)' when " "the value tensor is None."; value_tensor = key_tensor; value_tensor_layout = key_tensor_layout; v_b = k_b; v_m = k_m; v_h = k_h; v_k = k_k; v_bm_packed = k_bm_packed; } if (attn_bias) { const auto attn_bias_shape = JUST(attn_bias)->shape(); const int64_t num_attn_bias_axes = attn_bias_shape->NumAxes(); CHECK_OR_RETURN(num_attn_bias_axes > 0 && num_attn_bias_axes <= 4) << "The number of dimensions of attn_bias should be greater than 0 and less than or " "equal to 4."; CHECK_GE_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 1), k_m) << "The size of the -1 dimension of attn_bias should be greater than or equal to the " "dimension 'M' of the key tensor"; CHECK_EQ_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 1) % 8, 0) << "The size of the -1 dimension of attn_bias should be a multiple of 8."; if (num_attn_bias_axes >= 2) { CHECK_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 2) == 1 || attn_bias_shape->At(num_attn_bias_axes - 2) >= q_m) << "The size of the -2 dimension of attn_bias should be greater than or equal to the " "dimension 'M' of the query tensor or equal to 1."; } if (num_attn_bias_axes >= 3) { CHECK_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 3) == 1 || attn_bias_shape->At(num_attn_bias_axes - 3) == q_h) << "The size of the -3 dimension of attn_bias should be equal to the dimension 'H' of " "the query tensor or equal to 1."; } if (num_attn_bias_axes == 4) { CHECK_OR_RETURN(attn_bias_shape->At(0) == 1 || attn_bias_shape->At(0) == q_b) << "The size of the -4 dimension of attn_bias should be equal to the dimension 'B' of " "the query tensor or equal to 1."; } } const bool 
o_bm_packed = output_layout == "(BM)(HK)"; CHECK_EQ_OR_RETURN(o_bm_packed, q_bm_packed) << "The query tensor and the output tensor should either both be BM-Packed or both not be " "BM-Packed at the same time."; std::string op_output_layout; if (output_layout == "BM(HK)" || output_layout == "(BM)(HK)") { op_output_layout = output_layout; } else if (output_layout == "MB(HK)") { if (q_b == 1) { op_output_layout = output_layout; } else { op_output_layout = "BM(HK)"; } } else { UNIMPLEMENTED_THEN_RETURN() << "output_layout should be 'BM(HK)', 'MB(HK)' or (BM)(HK)"; } double scale_value = 0.0; if (scale) { scale_value = JUST(scale); } else { scale_value = 1.0 / std::sqrt(static_cast(q_k)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("query_layout", "key_layout", "value_layout", "output_layout", "query_head_size", "attn_mask_type", "causal_diagonal_offset", "query_max_seq_len", "key_max_seq_len", "scale"); attrs.SetAllAttrs(query_layout, key_tensor_layout, value_tensor_layout, op_output_layout, q_k, attn_mask_type_val, causal_diagonal_offset, query_max_seq_len.value_or(0), key_max_seq_len.value_or(0), scale_value); OpExprCacheKey cache_key{}; std::vector> inputs; inputs.emplace_back(query); inputs.emplace_back(key_tensor); inputs.emplace_back(value_tensor); if (attn_bias) { inputs.emplace_back(JUST(attn_bias)); cache_key.has_attn_bias = true; } else { cache_key.has_attn_bias = false; } if (query_seq_start && key_seq_start) { inputs.emplace_back(JUST(query_seq_start)); inputs.emplace_back(JUST(key_seq_start)); cache_key.has_seq_start = true; } else { cache_key.has_seq_start = false; } if (key_seq_len) { inputs.emplace_back(JUST(key_seq_len)); cache_key.has_key_seq_len = true; } else { cache_key.has_key_seq_len = false; } auto it = op_cache_.find(cache_key); CHECK_OR_RETURN(it != op_cache_.end()); TensorTuple input_tuple(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { input_tuple[i] = std::move(inputs[i]); } std::shared_ptr op_output = JUST(OpInterpUtil::Dispatch(*it->second, input_tuple, attrs)); if (op_output_layout == output_layout) { return op_output; } else { if (op_output_layout == "BM(HK)" && output_layout == "MB(HK)") { return functional::Transpose(op_output, {1, 0, 2}); } else { UNIMPLEMENTED_THEN_RETURN(); } } } private: OpExprCache op_cache_; }; class FusedAttentionConcatPastKeyValueFunctor { public: FusedAttentionConcatPastKeyValueFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_attention_concat_past_key_value") .Input("key") .Input("value") .Input("past_key") .Input("past_value") .Output("output_key") .Output("output_value") .Build()); op_without_past_ = CHECK_JUST(one::OpBuilder("fused_attention_concat_past_key_value") .Input("key") .Input("value") .Output("output_key") .Output("output_value") .Build()); } Maybe operator()( const Optional& past_key, const std::string& past_key_layout, const Optional& past_value, const std::string& past_value_layout, const std::shared_ptr& key, const std::string& key_layout, const std::shared_ptr& value, const std::string& value_layout, const Optional& key_head_size) const { int64_t k_b = 0; int64_t k_m = 0; int64_t k_h = 0; int64_t k_k = 0; JUST(ParseDims("key", *key->shape(), key_layout, Optional(), key_head_size, &k_b, &k_m, &k_h, &k_k)); int64_t v_b = 0; int64_t v_m = 0; int64_t v_h = 0; int64_t v_k = 0; JUST(ParseDims("value", *value->shape(), value_layout, k_h, k_k, &v_b, &v_m, &v_h, &v_k)); CHECK_EQ_OR_RETURN(v_b, k_b) << "The size of dimension 'B' of the value tensor should be " "the same as that of the key tensor."; 
CHECK_EQ_OR_RETURN(v_m, k_m) << "The size of dimension 'M' of the value tensor should be the " "same as that of the key tensor."; if (past_key) { CHECK_OR_RETURN(past_value) << "Tensor past_key and tensor past_value should both be None or " "both not be None at the same time."; int64_t past_k_b = 0; int64_t past_k_m = 0; int64_t past_k_h = 0; int64_t past_k_k = 0; JUST(ParseDims("past_key", *JUST(past_key)->shape(), past_key_layout, k_h, k_k, &past_k_b, &past_k_m, &past_k_h, &past_k_k)); CHECK_EQ_OR_RETURN(past_k_b, k_b) << "The size of dimension 'B' of the past_key tensor should be " "the same as that of the key tensor."; int64_t past_v_b = 0; int64_t past_v_m = 0; int64_t past_v_h = 0; int64_t past_v_k = 0; JUST(ParseDims("past_value", *JUST(past_value)->shape(), past_value_layout, k_h, k_k, &past_v_b, &past_v_m, &past_v_h, &past_v_k)); CHECK_EQ_OR_RETURN(past_v_b, k_b) << "The size of dimension 'B' of the past_value tensor " "should be the same as that of the key tensor."; CHECK_EQ_OR_RETURN(past_v_m, past_k_m) << "The size of dimension 'M' of the past_value tensor " "should be the same as that of the past_key tensor."; } else { CHECK_OR_RETURN(!past_value) << "Tensor past_key and tensor past_value should both be None or " "both not be None at the same time."; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("past_key_layout", "past_value_layout", "key_layout", "value_layout", "key_head_size"); attrs.SetAllAttrs(past_key_layout, past_value_layout, key_layout, value_layout, k_k); if (past_key) { return JUST(OpInterpUtil::Dispatch( *op_, {key, value, JUST(past_key), JUST(past_value)}, attrs)); } else { return JUST(OpInterpUtil::Dispatch(*op_without_past_, {key, value}, attrs)); } } private: std::shared_ptr op_; std::shared_ptr op_without_past_; }; class FusedApplyRotaryEmbFunctor { public: FusedApplyRotaryEmbFunctor() { op_with_position_sinuous_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb") .Input("x") .Input("cos") .Input("sin") .Input("position_ids") .Output("out") .Build()); op_with_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb") .Input("x") .Input("position_ids") .Output("out") .Build()); op_without_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb") .Input("x") .Input("cos") .Input("sin") .Output("out") .Build()); op_without_position_sinuous_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb").Input("x").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Optional& cos, const Optional& sin, const Optional& position_ids, const std::string& x_layout, const Optional& output_layout, const std::string& mode, const Optional& tensor_index, const Optional& k_size, const float base, const Optional& rotary_size) const { int64_t b = 0, m = 0, h = 0, k = 0; if (tensor_index) { CHECK_OR_RETURN((JUST(tensor_index) >= 0) && (JUST(tensor_index) <= 2)) << "tensor_index should be set between [0, 2]"; } CHECK_OR_RETURN((mode == "interval") || (mode == "plane")) << "mode should be \"interval\" or \"plane\""; JUST(ParseDims("x", *x->shape(), x_layout, Optional(), k_size, &b, &m, &h, &k)); if (k_size) { CHECK_EQ_OR_RETURN(JUST(k_size), k) << "k_size if given should be equal to K of cos, sin and x."; } if (rotary_size) { CHECK_LE_OR_RETURN(JUST(rotary_size), k) << "rotary_size should be no more than k."; } int64_t rotary_emd_dim = 1; if (position_ids) { CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->NumAxes(), 3) << "ndims of position_ids should be equal to 3, either in form of B1M or B2M."; CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(0), b) << 
"1st dim of position_ids should be equal to B."; CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(2), m) << "3rd dim of position_ids should be equal to M."; rotary_emd_dim = JUST(position_ids)->shape()->At(1); CHECK_OR_RETURN(rotary_emd_dim == 1 || rotary_emd_dim == 2) << "2nd dim of position_ids should be 1 or 2."; } const int64_t actual_rotary_size = rotary_size.value_or(k) / rotary_emd_dim; CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0) << "k ,or rotary_size if given, should be a multiple of 2 * rotary_encoding_dim."; if (cos && sin) { CHECK_EQ_OR_RETURN(JUST(cos)->shape()->NumAxes(), 2) << "The number of dimensions of cos should be equal to 2."; CHECK_OR_RETURN(JUST(cos)->shape() == JUST(sin)->shape()) << "Each dimension of cos & sin should be the same."; CHECK_EQ_OR_RETURN(JUST(cos)->shape()->At(1), actual_rotary_size) << "The 1st dimension of cos & sin should equal to rotary_size // " "rotary_embedding_dimension."; } else if (!cos && !sin) { // do nothing } else { UNIMPLEMENTED_THEN_RETURN() << "cos & sin should both be given or not given."; } if (!position_ids) { if (cos && sin) { CHECK_GE_OR_RETURN(JUST(cos)->shape()->At(0), m) << "M of cos & sin should be to no less than " "M of x when position_ids is not " "given."; // K of cos & sin is checked // inside ParseDims } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("x_layout", "output_layout", "mode", "tensor_index", "k_size", "base", "rotary_size"); attrs.SetAllAttrs(x_layout, output_layout.value_or(x_layout), mode, tensor_index.value_or(0), k_size.value_or(k), base, rotary_size.value_or(k)); if (position_ids) { if (cos && sin) { return OpInterpUtil::Dispatch(*op_with_position_sinuous_, {x, JUST(cos), JUST(sin), JUST(position_ids)}, attrs); } else { return OpInterpUtil::Dispatch(*op_with_position_, {x, JUST(position_ids)}, attrs); } } else { if (cos && sin) { return OpInterpUtil::Dispatch(*op_without_position_, {x, JUST(cos), JUST(sin)}, attrs); } else { return OpInterpUtil::Dispatch(*op_without_position_sinuous_, {x}, attrs); } } } private: std::shared_ptr op_with_position_; std::shared_ptr op_with_position_sinuous_; std::shared_ptr op_without_position_; std::shared_ptr op_without_position_sinuous_; }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("FusedMultiHeadAttentionInference"); m.add_functor( "FusedMultiHeadAttentionInferenceV2"); m.add_functor("FusedAttentionConcatPastKeyValue"); m.add_functor("FusedApplyRotaryEmb"); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/global_cast.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/rank_group_scope.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/framework/transport_token.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/common/flat_shape.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/ccl/ccl.h" #include "oneflow/core/common/constant.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/user/kernels/collective_communication/include/broadcast.h" namespace oneflow { namespace one { namespace functional { namespace impl { namespace { // NOTE: use env variable 'ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE' indicate whether the // shape and dtype of input tensor on each rank is the same when cast local tensor to global tensor. // If set true, there will be no meta-information synchronization on each rank. 
Optional ParseEagerLocalToGlobalBalancedOverride() { const char* env_p = std::getenv("ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE"); if (env_p == nullptr) { return Optional(); } else { return ParseBooleanFromEnv("ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE", false); } } bool NeedSyncAndCheckShapeAndDtype(bool check_meta_hint) { thread_local Optional eager_local_to_global_balanced_override = ParseEagerLocalToGlobalBalancedOverride(); if (eager_local_to_global_balanced_override.has_value()) { return IsInDebugMode() || !CHECK_JUST(eager_local_to_global_balanced_override); } else { return IsInDebugMode() || check_meta_hint; } } // clang-format off FLAT_MSG_BEGIN(FlatShapeAndDataType); // Methods static Maybe New() { const auto& flat_shape_dtype = std::make_shared(); flat_shape_dtype->clear(); return flat_shape_dtype; } static Maybe New(const Shape& shape, DataType dtype) { const auto& flat_shape_dtype = JUST(New()); JUST(flat_shape_dtype->mutable_shape()->Init(shape)); flat_shape_dtype->set_dtype(dtype); return flat_shape_dtype; } Maybe Check(const Shape& shape, DataType dtype) const { JUST(this->shape().Check(shape)); CHECK_EQ_OR_RETURN(this->dtype(), dtype) << Error::RuntimeError() << "Expected all tensors on each rank to be the same dtype, but found " "at least two dtypes, " << DType(this->dtype()).name() << " and " << DType(dtype).name() << "!"; return Maybe::Ok(); } Maybe Check(const FlatShapeAndDataType& flat_shape_dtype) const { JUST(this->shape().Check(flat_shape_dtype.shape())); CHECK_EQ_OR_RETURN(this->dtype(), flat_shape_dtype.dtype()) << Error::RuntimeError() << "Expected input of each rank must have the same dtype, but got at least two dtypes, " << DType(this->dtype()).name() << " and " << DType(flat_shape_dtype.dtype()).name(); return Maybe::Ok(); } Maybe ToShape(Shape* shape) const { return this->shape().ToShape(shape); } Maybe ToShape() const { return shape().ToShape(); } int64_t At(int i) const { return shape().At(i); } int64_t NumAxes() const { return shape().NumAxes(); } private: // Fields FLAT_MSG_DEFINE_OPTIONAL(FlatShape, shape); FLAT_MSG_DEFINE_OPTIONAL(DataType, dtype); FLAT_MSG_END(FlatShapeAndDataType); // clang-format on Maybe ShapeAndDataTypeConsistencyCheck(const Symbol& placement, const Shape& shape, DataType dtype) { if (!placement->containing_current_rank() || placement->parallel_num() == 1) { return Maybe::Ok(); } const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncLocalShapeDtype)); const auto& send_buffer = JUST(FlatShapeAndDataType::New(shape, dtype)); const auto& recv_buffer = JUST(FlatShapeAndDataType::New()); recv_buffer->clear(); NaiveAsyncTransportCtx ctx( transport_token, [send_buffer](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = send_buffer.get(); *size = sizeof(FlatShapeAndDataType); *Cb = [send_buffer] {}; return Maybe::Ok(); }, [recv_buffer](int64_t rank, void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_buffer.get(); *size = sizeof(FlatShapeAndDataType); *Cb = [recv_buffer] {}; return Maybe::Ok(); }); const auto& rank_group = JUST(RankGroup::New(placement)); JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg); JUST(send_buffer->Check(*recv_buffer)); return Maybe::Ok(); } Maybe>> BroadcastGatherShapeAndDataType( const Shape& shape, DataType dtype, Symbol parallel_desc) { const auto& 
transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncLocalShapeDtype)); const auto& send_buffer = JUST(FlatShapeAndDataType::New(shape, dtype)); const auto& map = std::make_shared>>(); map->emplace(GlobalProcessCtx::Rank(), send_buffer); NaiveAsyncTransportCtx ctx( transport_token, [send_buffer](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = send_buffer.get(); *size = sizeof(FlatShapeAndDataType); *Cb = [send_buffer] {}; return Maybe::Ok(); }, [map](int64_t rank, void** buffer, std::size_t* size, std::function* Cb) -> Maybe { const auto& recv_buffer = JUST(FlatShapeAndDataType::New()); recv_buffer->clear(); *buffer = recv_buffer.get(); *size = sizeof(FlatShapeAndDataType); *Cb = [recv_buffer] {}; CHECK_OR_RETURN(map->emplace(rank, recv_buffer).second); // NOLINT(maybe-need-error-msg) return Maybe::Ok(); }); const auto& rank_group = JUST(RankGroup::New(parallel_desc)); JUST(TransportUtil::BroadcastToOtherRanks(rank_group, rank_group, transport_token, &ctx)); JUST(TransportUtil::CollectFromOtherRanks(rank_group, rank_group, transport_token, &ctx)); JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg); return map; } Maybe FindRoot(Symbol broadcast_parallel_desc, Symbol src_parallel_desc) { for (int64_t process_id : broadcast_parallel_desc->sorted_machine_ids()) { if (src_parallel_desc->ContainingMachineId(process_id)) { return process_id; } } UNIMPLEMENTED_THEN_RETURN(); } auto* CachedFindRoot = DECORATE(&FindRoot, ThreadLocal); Maybe BroadcastShapeAndDtype(const Shape& shape, DataType dtype, Symbol parallel_desc) { const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); const auto& rank_group_parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(parallel_desc->device_type(), rank_group)); const auto& process_id2broadcast_group = JUST(GetBroadcastGroup(parallel_desc, rank_group_parallel_desc)); const auto& broadcast_parallel_desc = JUST(MapAt(*process_id2broadcast_group, GlobalProcessCtx::Rank())); const auto& in_flat_shape_dtype = JUST(FlatShapeAndDataType::New(shape, dtype)); const auto& out_flat_shape_dtype = JUST(FlatShapeAndDataType::New()); int64_t root = JUST(CachedFindRoot(broadcast_parallel_desc, parallel_desc)); const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncLocalShapeDtype)); JUST(ccl::CpuBroadcast(in_flat_shape_dtype.get(), out_flat_shape_dtype.get(), sizeof(FlatShapeAndDataType), root, broadcast_parallel_desc, transport_token)); return out_flat_shape_dtype; } Maybe GetConcatenatedShapeAndCheckDtype( Shape* logical_shape, DataType* dtype, const HashMap>& rank2flat_shape_dtype, Symbol parallel_desc, Symbol nd_sbp) { *dtype = rank2flat_shape_dtype.begin()->second->dtype(); HashMap> rank2logical_shape; for (const auto& pair : rank2flat_shape_dtype) { rank2logical_shape.emplace(pair.first, JUST(pair.second->ToShape())); CHECK_EQ_OR_RETURN(*dtype, pair.second->dtype()) << Error::RuntimeError() << "Expected all tensors on each rank to be the same dtype, but found " "at least two dtypes, " << DType(*dtype).name() << "(rank " << rank2flat_shape_dtype.begin()->first << ") and " << DType(pair.second->dtype()).name() << "(rank " << pair.first << ")!"; } const auto& GetRankPhyShapeByParallelId = [&](Symbol parallel_desc, int64_t parallel_id) -> Maybe { int64_t machine_id = JUST(parallel_desc->MachineId4ParallelId(parallel_id)); return JUST(MapAt(rank2logical_shape, machine_id)); }; const auto& parallel_hierarchy = parallel_desc->hierarchy(); Stride parallel_stride(*parallel_hierarchy); 
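// The loop below reconstructs the logical shape from the gathered per-rank physical shapes.
// It scans the nd-sbp from the innermost hierarchy axis outwards; for every split axis it
// groups the ranks that differ only in their index on that hierarchy axis, sums the split
// dimension over each group, and validates the per-rank sizes against BalancedSplitter, so
// only shapes matching a balanced split of the logical dimension are accepted.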
for (int32_t i = nd_sbp->sbp_parallel_size() - 1; i >= 0; --i) { if (nd_sbp->sbp_parallel(i).has_split_parallel()) { int64_t concat_axis = nd_sbp->sbp_parallel(i).split_parallel().axis(); int64_t group_size = parallel_hierarchy->Count(0, i); int64_t stride = parallel_stride.at(i); for (int group_id = 0; group_id < group_size; ++group_id) { int64_t parallel_num_in_group = parallel_hierarchy->At(i); for (int64_t stride_id = 0; stride_id < stride; ++stride_id) { ParallelConf parallel_conf; parallel_conf.set_device_tag(parallel_desc->device_tag()); int64_t start_parallel_id = group_id * parallel_num_in_group + stride_id; for (int64_t parallel_id_in_group = 0; parallel_id_in_group < parallel_num_in_group; ++parallel_id_in_group) { int64_t id = start_parallel_id + parallel_id_in_group * stride; int64_t machine_id = JUST(parallel_desc->MachineId4ParallelId(id)); int64_t device_id = JUST(parallel_desc->DeviceId4ParallelId(id)); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":" + std::to_string(device_id)); } Symbol sub_parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); std::shared_ptr first_shape = JUST(GetRankPhyShapeByParallelId(sub_parallel_desc, 0)); CHECK_GE_OR_RETURN(concat_axis, 0) << Error::RuntimeError() << "Split axis must not be negative, but got " << concat_axis << "!"; CHECK_LT_OR_RETURN(concat_axis, first_shape->NumAxes()) << Error::RuntimeError() << "Split axis out of range (expected to be in range of [" << 0 << ", " << first_shape->NumAxes() << "), but got " << concat_axis << "!)"; int64_t logical_concat_dim = first_shape->At(concat_axis); for (int parallel_id = 1; parallel_id < sub_parallel_desc->parallel_num(); ++parallel_id) { const auto& rank_shape = JUST(GetRankPhyShapeByParallelId(sub_parallel_desc, parallel_id)); CHECK_EQ_OR_RETURN(rank_shape->NumAxes(), first_shape->NumAxes()) << Error::RuntimeError() << "Sizes of tensors must match except in dimension " << concat_axis << ", but found " << first_shape->ToString() << "(rank " << JUST(sub_parallel_desc->MachineId4ParallelId(0)) << ") and " << rank_shape->ToString() << "(rank " << JUST(sub_parallel_desc->MachineId4ParallelId(parallel_id)) << ")!"; logical_concat_dim += rank_shape->At(concat_axis); } BalancedSplitter bs(logical_concat_dim, sub_parallel_desc->parallel_num()); CHECK_EQ_OR_RETURN(first_shape->At(concat_axis), bs.At(0).size()) << Error::RuntimeError() << "Sizes of tensors in dimension " << concat_axis << " must be same or match balanced split distribution. See " "https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/common/" "balanced_splitter.h " "for details of balanced split"; first_shape->Set(concat_axis, logical_concat_dim); for (int parallel_id = 1; parallel_id < sub_parallel_desc->parallel_num(); ++parallel_id) { std::shared_ptr rank_shape = JUST(GetRankPhyShapeByParallelId(sub_parallel_desc, parallel_id)); for (int i = 0; i < first_shape->NumAxes(); ++i) { if (i == concat_axis) { CHECK_EQ_OR_RETURN(rank_shape->At(i), bs.At(parallel_id).size()) << Error::RuntimeError() << "Sizes of tensors in dimension " << concat_axis << " must be same or match balanced split distribution. See " "https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/common/" "balanced_splitter.h " "for details of balanced split"; } else { CHECK_EQ_OR_RETURN(rank_shape->At(i), first_shape->At(i)) << Error::RuntimeError() << "Sizes of tensors must match except in dimension " << concat_axis << ". 
Expected size " << first_shape->At(i) << " but got size " << rank_shape->At(i) << " for tensor on rank " << JUST(sub_parallel_desc->MachineId4ParallelId(parallel_id)) << "!"; } } rank_shape->Set(concat_axis, logical_concat_dim); } } } } } *logical_shape = *JUST(GetRankPhyShapeByParallelId(parallel_desc, 0)); return Maybe::Ok(); } Maybe GetLogicalShapeAndDataType(Shape* logical_shape, DataType* /* in and out */ dtype, std::shared_ptr physical_shape, Symbol parallel_desc, Symbol nd_sbp, bool sync_and_check_meta) { if (!sync_and_check_meta) { *logical_shape = *JUST(GetLogicalShape(*physical_shape, *nd_sbp, *parallel_desc)); } else { if (ContainSplitSbp(nd_sbp)) { *logical_shape = *physical_shape; if (parallel_desc->containing_current_rank()) { const auto& rank2flat_shape_dtype = JUST(BroadcastGatherShapeAndDataType(*logical_shape, *dtype, parallel_desc)); JUST(GetConcatenatedShapeAndCheckDtype(logical_shape, dtype, *rank2flat_shape_dtype, parallel_desc, nd_sbp)); } } else { *logical_shape = *physical_shape; JUST(ShapeAndDataTypeConsistencyCheck(parallel_desc, *logical_shape, *dtype)); } } if (JUST(RankGroup::New(parallel_desc)) != JUST(RankGroupScope::CurrentRankGroup())) { const auto& flat_shape_dtype = JUST(BroadcastShapeAndDtype(*logical_shape, *dtype, parallel_desc)); *logical_shape = *JUST(flat_shape_dtype->ToShape()); *dtype = flat_shape_dtype->dtype(); } return Maybe::Ok(); } Maybe CheckNdSbpValid(Symbol nd_sbp, const Shape& logical_shape) { for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) { const auto& sbp_parallel = nd_sbp->sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { CHECK_LT_OR_RETURN(sbp_parallel.split_parallel().axis(), logical_shape.NumAxes()) << Error::RuntimeError() << "Split axis out of range (expected to be in range of [" << 0 << ", " << logical_shape.NumAxes() << "), but got " << sbp_parallel.split_parallel().axis() << "!)"; } } return Maybe::Ok(); } namespace { Maybe RawGetGlobalToGlobalOpExpr( const std::vector>& grad_sbp_parallels) { Optional> grad_nd_sbp; if (!grad_sbp_parallels.empty()) { grad_nd_sbp = JUST(GetNdSbp(grad_sbp_parallels)); } std::shared_ptr op_expr = JUST(one::GlobalToGlobalOpExpr::New(grad_nd_sbp)); return op_expr; } } // namespace static constexpr auto* GetGlobalToGlobalOpExpr = DECORATE(&RawGetGlobalToGlobalOpExpr, ThreadLocalCopiable); Maybe GlobalToGlobal(const std::shared_ptr& x, Symbol parallel_desc, const std::vector>& sbp_parallels, const std::vector>& grad_sbp_parallels, bool copy) { const auto& global_tensor = JUST(x->AsGlobalTensor()); CHECK_NOTNULL_OR_RETURN(global_tensor) << "global tensors supported only"; const auto& nd_sbp = JUST(GetNdSbp(sbp_parallels)); JUST(CheckNdSbpValid(nd_sbp, *x->shape())); std::shared_ptr op; if (unlikely(!LazyMode::is_enabled() && JUST(x->parallel_desc())->hierarchy()->NumAxes() != parallel_desc->hierarchy()->NumAxes() && grad_sbp_parallels.size() == 0)) { op = JUST(GetGlobalToGlobalOpExpr(*JUST(GetSbpList(JUST(x->nd_sbp()))))); } else { op = JUST(GetGlobalToGlobalOpExpr(grad_sbp_parallels)); } if (!LazyMode::is_enabled() && JUST(x->nd_sbp()) == nd_sbp && JUST(x->parallel_desc()) == parallel_desc && (grad_sbp_parallels.size() == 0 || !autograd::GradMode::is_enabled())) { if (copy) { return functional::Identity(x); } return x; } const auto& tensor = JUST(OpInterpUtil::Dispatch( *op, {global_tensor}, OpExprInterpContext(AttrMap{}, parallel_desc, nd_sbp))); if (!LazyMode::is_enabled() && tensor != x && !IsGlobalTensorMetaCheckDisabled()) { const auto& input_global_id = JUST(x->transport_token()); 
const auto& output_consistend_id = JUST(tensor->transport_token()); CHECK_NE_OR_RETURN(input_global_id, output_consistend_id); // NOLINT(maybe-need-error-msg) } return tensor; } Maybe LocalToGlobal(const std::shared_ptr& x, Symbol parallel_desc, const std::vector>& sbp_parallels, const Optional& opt_shape, const Optional& opt_dtype, const std::shared_ptr& op, bool check_meta_hint, bool sync_data, bool copy) { CHECK_OR_RETURN(!x->is_lazy()) << Error::RuntimeError() << "local_tensor.to_global() is not supported within nn.Graph for now"; CHECK_OR_RETURN(x->is_local()) << Error::RuntimeError() << "local tensors supported only"; std::shared_ptr input = x->contiguous(); // copy to right device first if input's device type is wrong if (JUST(input->device())->type() != parallel_desc->device_tag()) { VLOG(2) << "The device_type of the input tensor is different from placement, now copy it to " << parallel_desc->device_tag(); input = JUST(functional::Copy(x, parallel_desc->device_tag(), GlobalProcessCtx::LocalRank(), /*pin_memory=*/false)); } // copy to default device of the current rank if input's device type is right but not on default // device bool device_mismatch = JUST(input->device())->device_id() != GlobalProcessCtx::LocalRank(); if (copy || device_mismatch) { if (device_mismatch) { VLOG(2) << "The tensor isn't on default device of the current rank, now copy it to " << parallel_desc->device_tag() << ": " << GlobalProcessCtx::LocalRank(); } input = JUST(functional::Copy(x, parallel_desc->device_tag(), GlobalProcessCtx::LocalRank(), /*pin_memory=*/false)); } const auto& device = JUST(input->device()); CHECK_EQ_OR_RETURN(device->type(), parallel_desc->device_tag()) << Error::UnimplementedError() << "tensor' device type must be same with placement."; CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank()) << Error::UnimplementedError() << "tensor must be on default device of the current rank."; Symbol nd_sbp = JUST(GetNdSbp(sbp_parallels)); DataType dtype = x->dtype()->data_type(); std::shared_ptr shape = std::make_shared(); if (opt_shape.has_value() && opt_dtype.has_value()) { shape = JUST(opt_shape); dtype = JUST(opt_dtype); } else { bool sync_and_check_meta = NeedSyncAndCheckShapeAndDtype(check_meta_hint); JUST(GetLogicalShapeAndDataType(shape.get(), &dtype, x->shape(), parallel_desc, nd_sbp, sync_and_check_meta)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype", "sync_data", "inplace_when_sync_data"); attrs.SetAllAttrs(*shape, dtype, sync_data, !copy); const auto& output = JUST(OpInterpUtil::Dispatch( *op, {input}, OpExprInterpContext(attrs, parallel_desc, nd_sbp))); return output; } } // namespace class LocalToGlobalFunctor { public: LocalToGlobalFunctor() { op_ = CHECK_JUST(one::LocalToGlobalOpExpr::New(*CHECK_JUST(UniqueStr("local_to_global")))); } Maybe operator()(const std::shared_ptr& x, Symbol parallel_desc, const std::vector>& sbp_parallels, const Shape& shape, const Symbol& dtype, bool sync_data, bool copy) const { JUST(CheckDeviceIdsIsValid(parallel_desc)); NonRecursiveMetaInfoConsistencyCheckScope no_recursive_meta_info_conisitency_check_scope; JUST(MetaInfoConsistencyCheck(parallel_desc, sbp_parallels, 1, /* force_check */ false)); DisableCheckGlobalTensorMetaScope scope{}; std::shared_ptr tensor; DeviceType device_type = parallel_desc->device_type(); if (ccl::IsBroadcastRegistered(device_type) || !sync_data || device_type == DeviceType::kMeta) { tensor = JUST(LocalToGlobal(x, parallel_desc, sbp_parallels, shape, dtype->data_type(), op_, /* 
check_meta */ false, sync_data, copy)); } else { // Assuming that the newly adapted hardware device does not support collective // communication, since local to global may need to synchronize data (through the // broadcast API), if device_type is neither cpu nor cuda, generate global tensor // with the corresponding cpu placement first, then convert the cpu global tensor // to the desired placement. Symbol cpu_parallel_desc = JUST(ReplaceDeviceType(parallel_desc, DeviceType::kCPU)); std::shared_ptr cpu_tensor = JUST(LocalToGlobal(x, cpu_parallel_desc, sbp_parallels, shape, dtype->data_type(), op_, /* check_meta */ false, sync_data, copy)); tensor = JUST(GlobalToGlobal(cpu_tensor, parallel_desc, sbp_parallels, GetNoneSbpList(), copy)); } return tensor; } private: std::shared_ptr op_; }; class ToGlobalFunctor { public: ToGlobalFunctor() { local_to_global_op_ = CHECK_JUST(one::LocalToGlobalOpExpr::New(*CHECK_JUST(UniqueStr("local_to_global")))); } Maybe operator()(const std::shared_ptr& x, Symbol parallel_desc, const std::vector>& sbp_parallels, const std::vector>& grad_sbp_parallels, bool check_meta, bool copy) const { JUST(CheckDeviceIdsIsValid(parallel_desc)); NonRecursiveMetaInfoConsistencyCheckScope scope; JUST(MetaInfoConsistencyCheck(parallel_desc, sbp_parallels, grad_sbp_parallels, 1, /* force_check */ check_meta)); std::shared_ptr tensor; if (x->is_global()) { tensor = JUST(GlobalToGlobal(x, parallel_desc, sbp_parallels, grad_sbp_parallels, copy)); } else { DeviceType device_type = parallel_desc->device_type(); if (ccl::IsBroadcastRegistered(device_type)) { tensor = JUST(LocalToGlobal(x, parallel_desc, sbp_parallels, NullOpt, NullOpt, local_to_global_op_, check_meta, /* sync_data */ true, copy)); } else { // Assuming that the newly adapted hardware device does not support collective // communication, since local to global may need to synchronize data (through the // broadcast API), if device_type is neither cpu nor cuda, generate global tensor // with the corresponding cpu placement first, then convert the cpu global tensor // to the desired placement. 
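// Fallback (mirrors the branch in LocalToGlobalFunctor above): build the global tensor on an
// equivalent CPU placement first, where broadcast-based data synchronization is available,
// then convert it to the requested placement via GlobalToGlobal.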
Symbol cpu_parallel_desc = JUST(ReplaceDeviceType(parallel_desc, DeviceType::kCPU)); std::shared_ptr cpu_tensor = JUST(LocalToGlobal(x, cpu_parallel_desc, sbp_parallels, NullOpt, NullOpt, local_to_global_op_, check_meta, /* sync_data */ true, copy)); tensor = JUST(GlobalToGlobal(cpu_tensor, parallel_desc, sbp_parallels, GetNoneSbpList(), copy)); } } return tensor; } private: std::shared_ptr local_to_global_op_; }; class GlobalToLocalFunctor { public: GlobalToLocalFunctor() { op_ = CHECK_JUST(one::GlobalToLocalOpExpr::New(*CHECK_JUST(UniqueStr("global_to_local")))); } Maybe operator()(const std::shared_ptr& x, bool copy) const { CHECK_OR_RETURN(!x->is_lazy()) << Error::RuntimeError() << "global_tensor.to_local() is not supported within nn.Graph for now"; CHECK_OR_RETURN(x->is_global()) << Error::RuntimeError() << "Expected global tensor for to_local but got local tensor!"; const auto& local_tensor = JUST(OpInterpUtil::Dispatch(*op_, {x})); if (copy) { return local_tensor->clone(); } return local_tensor; } private: std::shared_ptr op_; }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("LocalToGlobal"); m.add_functor("ToGlobal"); m.add_functor("GlobalToLocal"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/gradient_accumulation_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/impl/common.h" namespace oneflow { namespace one { namespace functional { namespace impl { class GradAccRepeatFunctor { public: GradAccRepeatFunctor() { op_ = CHECK_JUST(one::OpBuilder("repeat").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, int32_t repeat_num) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("repeat_num"); attrs.SetAllAttrs(repeat_num); return OpInterpUtil::Dispatch(*op_, {in}, attrs); } private: std::shared_ptr op_; }; class GradAccCollectFunctor { public: GradAccCollectFunctor() { op_ = CHECK_JUST(one::OpBuilder("acc").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, int32_t collect_num) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("max_acc_num"); attrs.SetAllAttrs(collect_num); return OpInterpUtil::Dispatch(*op_, {in}, attrs); } private: std::shared_ptr op_; }; class GradAccPackFunctor { public: GradAccPackFunctor() { op_ = CHECK_JUST(one::OpBuilder("pack").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, int32_t pack_num) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("pack_num"); attrs.SetAllAttrs(pack_num); return OpInterpUtil::Dispatch(*op_, {in}, attrs); } private: std::shared_ptr op_; }; class GradAccUnpackFunctor { public: GradAccUnpackFunctor() { op_ = CHECK_JUST(one::OpBuilder("unpack").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, int32_t unpack_num) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("unpack_num"); attrs.SetAllAttrs(unpack_num); return OpInterpUtil::Dispatch(*op_, {in}, attrs); } private: std::shared_ptr op_; }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("GradAccRepeat"); m.add_functor("GradAccCollect"); m.add_functor("GradAccPack"); m.add_functor("GradAccUnpack"); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/higher_derivative_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/common/scalar.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/impl/unary_functor.h" namespace oneflow { namespace one { namespace functional { namespace impl { class SinGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto res = sequence_function(functional::Sin) .then(functional::Negative) .then(std::bind(functional::Mul, dydx, std::placeholders::_1)) .call(x); return res; } }; class CosGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto res = sequence_function(functional::Cos) .then(functional::Negative) .then(std::bind(functional::Mul, dydx, std::placeholders::_1)) .call(x); return res; } }; class TanGradGradFunctor { public: // dx = 1/cos^2(x), ddx = 2*sinx/cos^3(x) = tan_grad(x)*tan(x)*2 Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto r = sequence_function(functional::Mul) .then([](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(2), input); }) .call(JUST(functional::Tan(x)), JUST(functional::TanGrad(x, dydx))); return r; } }; class SinhGradGradFunctor { public: // dx = cosh(x), ddx = sinh(x) = cosh_grad(x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::CoshGrad(x, dydx); } }; class CoshGradGradFunctor { public: // dx = sinh(x), ddx = cosh(x) = sinh_grad(x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::SinhGrad(x, dydx); } }; class TanhGradGradFunctor { public: // dx = sech^2(x), ddx = -2*sech^2(x)*tanh(x) = dydx*tanh(x)*(-2) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto r = sequence_function(functional::Mul) .then([](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(-2), input); }) .call(dydx, x); return r; } }; class AsinGradGradFunctor { public: // dx = 1/sqrt(1-x*x)=rsqrt(1-x*x), ddx = rsqrt_grad(1-x*x)*(-2x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto r = sequence_function(functional::Square) .then([](const std::shared_ptr& input) { return functional::ScalarSub(Scalar(1), input, /*alpha=*/1.0); }) .then(std::bind(functional::RsqrtGrad, std::placeholders::_1, dydx)) .then(std::bind(functional::Mul, std::placeholders::_1, x)) .then([](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(-2), input); }) .call(x); return r; } }; class AcosGradGradFunctor { public: // dx = -1/sqrt(1-x*x)=-rsqrt(1-x*x), ddx = rsqrt_grad(1-x*x)*(2x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto r = sequence_function(functional::Square) .then([](const std::shared_ptr& input) { return functional::ScalarSub(Scalar(1), input, /*alpha=*/1.0); }) .then(std::bind(functional::RsqrtGrad, std::placeholders::_1, dydx)) .then(std::bind(functional::Mul, std::placeholders::_1, x)) .then([](const std::shared_ptr& input) { 
return functional::ScalarMul(Scalar(2), input); }) .call(x); return r; } }; class AtanGradGradFunctor { public: // dx = 1/(1+x*x), ddx = reci_grad(1+x*x)*(2x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto r = sequence_function(functional::Square) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(Scalar(1), input, /*alpha=*/1.0); }) .then(std::bind(functional::ReciprocalGrad, std::placeholders::_1, dydx)) .then(std::bind(functional::Mul, std::placeholders::_1, x)) .then([](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(2), input); }) .call(x); return r; } }; class AsinhGradGradFunctor { public: // dx = 1/sqrt(1+x*x)=rsqrt(1+x*x), ddx = rsqrt_grad(1+x*x)*(2x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto r = sequence_function(functional::Square) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(Scalar(1), input, /*alpha=*/1.0); }) .then(std::bind(functional::RsqrtGrad, std::placeholders::_1, dydx)) .then(std::bind(functional::Mul, std::placeholders::_1, x)) .then([](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(2), input); }) .call(x); return r; } }; class AcoshGradGradFunctor { public: // dx = 1/sqrt(x*x-1)=rsqrt(x*x-1), ddx = rsqrt_grad(x*x-1)*(2x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto r = sequence_function(functional::Square) .then([](const std::shared_ptr& input) { return functional::ScalarSub(input, Scalar(1), /*alpha=*/1.0, /*inplace=*/false); }) .then(std::bind(functional::RsqrtGrad, std::placeholders::_1, dydx)) .then(std::bind(functional::Mul, std::placeholders::_1, x)) .then([](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(2), input); }) .call(x); return r; } }; class AtanhGradGradFunctor { public: // dx = 1/(1-x*x), ddx = reci_grad(1-x*x)*(-2x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto r = sequence_function(functional::Square) .then([](const std::shared_ptr& input) { return functional::ScalarSub(Scalar(1), input, /*alpha=*/1.0); }) .then(std::bind(functional::ReciprocalGrad, std::placeholders::_1, dydx)) .then(std::bind(functional::Mul, std::placeholders::_1, x)) .then([](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(-2), input); }) .call(x); return r; } }; class ErfGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ScalarMul(Scalar(-2), JUST(functional::Mul(x, JUST(functional::ErfGrad(x, dydx))))); } }; class ErfcGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ScalarMul(Scalar(-2), JUST(functional::Mul(x, JUST(functional::ErfcGrad(x, dydx))))); } }; class ExpGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ExpGrad(x, dydx); } }; class Exp2GradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ScalarMul(Scalar(std::log(2)), JUST(functional::Exp2Grad(x, dydx))); } }; class Expm1GradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ExpGrad(x, dydx); } }; class LogGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ReciprocalGrad(x, dydx); } }; class 
Log2GradGradFunctor { public: // dx = 1/(x*ln2), ddx = 1/ln2 * -1/(x*x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ScalarMul(Scalar(1.0 / std::log(2.0f)), JUST(functional::ReciprocalGrad(x, dydx))); } }; class Log10GradGradFunctor { public: // dx = 1/(x*ln10), ddx = 1/ln10 * -1/(x*x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ScalarMul(Scalar(1.0 / std::log(10.0f)), JUST(functional::ReciprocalGrad(x, dydx))); } }; class Log1pGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ReciprocalGrad( JUST(functional::ScalarAdd(Scalar(1), x, /*alpha=*/Scalar(1))), dydx); } }; class LogSigmoidGradGradFunctor { public: // dx = exp(-x)/(1+exp(-x)), ddx = -exp(-x)/(1+exp(-x))^2 = -sigmoid_grad(x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::Negative(JUST(functional::SigmoidGrad(JUST(functional::Sigmoid(x)), dydx))); } }; class ReciprocalGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::Negative(JUST(functional::ScalarPowGrad(x, dydx, Scalar(-2)))); } }; class ReciprocalNoNanGradGradFunctor { public: // dx = -pow(x,-2), ddx = -pow_grad(x,-2) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::Negative(JUST(functional::ScalarPowGrad(x, dydx, Scalar(-2)))); } }; class RsqrtGradGradFunctor { public: // dx = -0.5*pow(x,-1.5), ddx = -0.5*pow_grad(x,-1.5) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ScalarMul(Scalar(-0.5), JUST(functional::ScalarPowGrad(x, dydx, Scalar(-1.5)))); } }; class SqrtGradGradFunctor { public: // dx = 0.5*pow(x,-0.5), ddx = -0.25*pow(x,-1.5) = 0.5*rsqrt_grad(x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ScalarMul(Scalar(0.5), JUST(functional::RsqrtGrad(x, dydx))); } }; class SquareGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ScalarMul(2, dydx); } }; class SigmoidGradGradFunctor { public: // dy = y * (1 - y), ddy = 1 - 2*y Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& dydx) const { return functional::Mul(JUST(functional::ScalarSub(1, y, /*alpha=*/2)), dydx); } }; class SiluGradGradFunctor { public: // y = x ∗ sigmoid(x) // y' = (sig(x) + x * sig_grad(x)) // y'' = (sig(x) + x*sig_grad(x))' = sig_grad(x)*(x+2-2*silu(x)) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto res = functional::sequence_function(functional::Silu) .then([](const std::shared_ptr& input) { return functional::ScalarSub(Scalar(2.0), input, /*alpha=*/Scalar(2.0)); }) .then([&x](const std::shared_ptr& input) { return functional::Add(x, input, /*alpha=*/Scalar(1.0), /*inplace=*/false); }) // Since we use y to compute SigmoidGrad, here we need to use sigmoid with x to // compute x first. // TODO(zzk): Implement SigmoidGradXDy func. 
.then(std::bind(functional::SigmoidGrad, JUST(functional::Sigmoid(x)), std::placeholders::_1)) .then(std::bind(functional::Mul, dydx, std::placeholders::_1)) .call(x); return res; } }; class SeluGradGradFunctor { public: // y'' = scale * alpha * exp(x) (x < 0) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto condition = JUST(functional::ScalarLogicalLess(x, Scalar(0.0))); auto res = functional::Where(condition, JUST(functional::SeluGrad(dydx, x)), JUST(functional::ZerosLike(x))); return res; } }; class SoftSignGradGradFunctor { public: // y = x/(1+abs(x)), y' = 1/(1+abs(x))^2, y'' = -2/(1+abs(x))^3*abs_grad(x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto res = functional::sequence_function(functional::Abs) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(Scalar(1.0), input, /*alpha=*/Scalar(1)); }) .then([](const std::shared_ptr& input) { return functional::ScalarPow(input, Scalar(-3), /*inplace=*/false); }) .then([](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(-2), input); }) .then(std::bind(functional::AbsGrad, x, std::placeholders::_1)) .then(std::bind(functional::Mul, dydx, std::placeholders::_1)) .call(x); return res; } }; class HardSigmoidGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { return functional::ZerosLike(x); } }; class HardSwishGradGradFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { auto condition = JUST(functional::ScalarLogicalGreater( (JUST(functional::ScalarLogicalLess(x, Scalar(3.0)))), Scalar(-3.0))); return functional::Where(condition, JUST(functional::ScalarDiv(dydx, Scalar(3.0))), JUST(functional::ZerosLike(x))); } }; class SoftplusGradGradFunctor { public: // beta*x <= threshold: // y = 1/beta*ln(1+exp(beta*x)), y' = 1/(1+exp(beta*x))*exp(beta*x) // y'' = beta*exp(beta*x)/(1+exp(beta*x))^2 = beta*sig(beta*x)(1-sig(beta*x)) // = beta*sig_grad(beta*x) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx, const double& beta, const double& threshold) const { auto beta_x = JUST(functional::ScalarMul(x, beta, /*inplace=*/false)); auto condition = JUST(functional::ScalarLogicalLess(beta_x, Scalar(threshold))); auto zero_out = JUST(functional::ZerosLike(x)); auto res = functional::sequence_function(functional::Sigmoid) .then(std::bind(functional::SigmoidGrad, std::placeholders::_1, dydx)) .then([&beta](const std::shared_ptr& input) { return functional::ScalarMul(Scalar(beta), input); }) .then(std::bind(functional::Where, condition, std::placeholders::_1, zero_out)) .call(beta_x); return res; } }; class EluGradGradFunctor { public: // y = max(0,x) + min(0,alpha∗(exp(x)−1)) Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx, const double& alpha) const { auto condition = JUST(functional::ScalarLogicalLess(x, Scalar(0.0))); return functional::Where(condition, JUST(functional::EluGrad(x, dydx, alpha)), JUST(functional::ZerosLike(x))); } }; class CeluGradGradFunctor { public: // y = max(0,x) + min(0,alpha∗(exp(x/alpha)−1)) Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& dydx, const double& alpha) const { auto condition = JUST(functional::ScalarLogicalLess(y, Scalar(0))); auto r = functional::Where(condition, JUST(functional::ScalarDiv(dydx, alpha)), JUST(functional::ZerosLike(y))); return r; } }; class MaxPoolNdGradGradFunctor { public: Maybe operator()(const std::shared_ptr& dydx, const 
std::shared_ptr& indices, const int ndims) const { if (indices->nelement()) { Shape view_shape(indices->shape()->begin(), indices->shape()->end() - ndims); view_shape.push_back(-1); auto indices_view = JUST(functional::Reshape(indices, view_shape)); auto outgrad_view = JUST(functional::Reshape(dydx, view_shape)); return functional::sequence_function(functional::DimGather) .then(std::bind(functional::Reshape, std::placeholders::_1, *indices->shape())) .call(outgrad_view, -1, indices_view, /*sparse_grad=*/false); } else { // empty inputs, return 0size tensor return functional::ZerosLike(indices); } } }; class MishGradGradFunctor { public: // y = x ∗ tanh(softplus(x)) // ddx = grad_tsp * sig * (2 + x * (1 + (-1 - 2 * tsp) * sig)), sig equal grad_sp here Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { const auto sig = JUST(functional::Sigmoid(x)); const auto sp = JUST(functional::Log1p(JUST(functional::Exp(x)))); const auto tanh_sp = JUST(functional::Tanh(sp)); const auto grad_tsp = JUST(functional::TanhGrad(tanh_sp, dydx)); auto r = functional::sequence_function(functional::Tanh) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(-1, input, /*alpha=*/-2); }) .then(std::bind(functional::Mul, std::placeholders::_1, sig)) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(1, input, /*alpha=*/1); }) .then(std::bind(functional::Mul, std::placeholders::_1, x)) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(2, input, /*alpha=*/1); }) .then(std::bind(functional::Mul, std::placeholders::_1, sig)) .then(std::bind(functional::Mul, std::placeholders::_1, grad_tsp)) .call(sp); return r; } }; class GeluGradGradFunctor { public: // y = gussian(x) = 0.5 * x * (1.0 + erf(sqrt(0.5) * x)); // dx = 0.5 * (1.0 + erf(sqrt(0.5)*x) + x * coef * exp(-0.5*x*x)) * dy), coef = sqrt(-2.0/pi) // ddx = coef * grad1 * grad2 * flow.exp(t) * (1+t), t = -0.5*x*x Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dydx) const { const auto& tmp = JUST(functional::ScalarMul(-0.5, JUST(functional::Square(x)))); const auto& tmp_add_one = JUST(functional::ScalarAdd(1, tmp, 1)); const Scalar coef = std::sqrt(2.0 / std::acos(-1.0)); auto r = functional::sequence_function(functional::Exp) .then(std::bind(functional::Mul, std::placeholders::_1, tmp_add_one)) .then(std::bind(functional::Mul, std::placeholders::_1, dydx)) .then([&coef](const std::shared_ptr& input) { return functional::ScalarMul(coef, input); }) .call(tmp); return r; } }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("SinGradGrad"); m.add_functor("CosGradGrad"); m.add_functor("TanGradGrad"); m.add_functor("SinhGradGrad"); m.add_functor("CoshGradGrad"); m.add_functor("TanhGradGrad"); m.add_functor("AsinGradGrad"); m.add_functor("AcosGradGrad"); m.add_functor("AtanGradGrad"); m.add_functor("AsinhGradGrad"); m.add_functor("AcoshGradGrad"); m.add_functor("AtanhGradGrad"); m.add_functor("ErfGradGrad"); m.add_functor("ErfcGradGrad"); m.add_functor("ExpGradGrad"); m.add_functor("Exp2GradGrad"); m.add_functor("Expm1GradGrad"); m.add_functor("LogGradGrad"); m.add_functor("Log2GradGrad"); m.add_functor("Log10GradGrad"); m.add_functor("Log1pGradGrad"); m.add_functor("LogSigmoidGradGrad"); m.add_functor("ReciprocalGradGrad"); m.add_functor("ReciprocalNoNanGradGrad"); m.add_functor("RsqrtGradGrad"); m.add_functor("SqrtGradGrad"); m.add_functor("SquareGradGrad"); m.add_functor("SigmoidGradGrad"); m.add_functor("SiluGradGrad"); m.add_functor("SeluGradGrad"); 
m.add_functor("SoftSignGradGrad"); m.add_functor("HardSigmoidGradGrad"); m.add_functor("HardSwishGradGrad"); m.add_functor("SoftplusGradGrad"); m.add_functor("EluGradGrad"); m.add_functor("CeluGradGrad"); m.add_functor("MaxPoolNdGradGrad"); m.add_functor("MishGradGrad"); m.add_functor("GeluGradGrad"); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/linalg_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "fmt/core.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/error.h" #include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/impl/common.h" namespace oneflow { namespace one { namespace functional { namespace impl { namespace linalg { class CrossFunctor { public: CrossFunctor() { op_ = CHECK_JUST(OpBuilder("linalg_cross").Input("input").Input("other").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& other, const Optional& dim) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim"); const auto do_dispatch_base_on_device = [&attrs, this]( const std::shared_ptr& input, const std::shared_ptr& other, const int64_t dim) -> Maybe { DeviceType device{}; if (input->is_global()) { device = JUST(input->parallel_desc())->device_type(); } else { device = JUST(input->device())->enum_type(); } const int64_t final_dim = input->ndim() - 1; if (device == DeviceType::kCUDA && dim != final_dim) { attrs.SetAllAttrs(final_dim); std::vector perm(input->ndim(), 0); for (size_t i = 0; i < perm.size(); ++i) { perm[i] = static_cast(i); } std::swap(perm[dim], perm[final_dim]); return functional::Transpose( JUST(OpInterpUtil::Dispatch(*op_, {JUST(functional::Transpose(input, perm)), JUST(functional::Transpose(other, perm))}, attrs)), perm); } attrs.SetAllAttrs(dim); return OpInterpUtil::Dispatch(*op_, {input, other}, attrs); }; Shape shape_to_broadcast; std::deque need_to_broadcast; std::tie(shape_to_broadcast, need_to_broadcast) = *JUST(InferUnifiedShapeForBroadcastingWithInfo({*input->shape(), *other->shape()})); CHECK_EQ_OR_RETURN(need_to_broadcast.size(), 2) << fmt::format("The number of boolean values to determine if the tensor is to be broadcast " "should be 2 (which is {})", need_to_broadcast.size()); const auto new_input = need_to_broadcast[0] ? 
JUST(functional::Expand(input, shape_to_broadcast)) : input; const auto new_other = need_to_broadcast[1] ? JUST(functional::Expand(other, shape_to_broadcast)) : other; if (!dim.has_value()) { return do_dispatch_base_on_device(new_input, new_other, JUST(FindValidDim(shape_to_broadcast))); } int64_t new_dim = JUST(dim); if (new_dim < 0) { new_dim += shape_to_broadcast.NumAxes(); } CHECK_EQ_OR_RETURN(shape_to_broadcast.At(new_dim), 3) << Error::RuntimeError() << fmt::format("the size of the specified dimension(which is {}) is not 3.", JUST(dim)); return do_dispatch_base_on_device(new_input, new_other, new_dim); } private: Maybe FindValidDim(const Shape& shape) const { int64_t valid_dim = -1; const auto& dim_vec = shape.dim_vec(); for (size_t i = 0; i < dim_vec.size(); ++i) { if (dim_vec[i] == 3) { valid_dim = i; break; } } if (valid_dim == -1) { return Error::RuntimeError() << "no dimension of size 3 in input."; } return valid_dim; } std::shared_ptr op_; }; } // namespace linalg } // namespace impl using namespace impl::linalg; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("LinalgCross"); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/math_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/impl/binary_functor.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/functional/tensor_processor.h" #include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { namespace functional { namespace impl { class AddNFunctor { public: AddNFunctor() { op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 1; n < op_.size(); ++n) { op_[n] = CHECK_JUST(one::OpBuilder("add_n").Input("in", n + 1).Output("out").Build()); } } Maybe operator()(const TensorTuple& inputs, bool inplace) const { CHECK_GE_OR_RETURN(inputs.size(), 2); TensorTuple outputs; for (int i = 0; i < inputs.size(); i += kMaxInputCount) { size_t size = (i + kMaxInputCount) < inputs.size() ? 
kMaxInputCount : inputs.size() - i; TensorTuple partial_inputs(size); std::copy(inputs.begin() + i, inputs.begin() + i + size, partial_inputs.begin()); if (i == 0 && inplace) { JUST(CheckInplaceValid(partial_inputs.at(0))); std::shared_ptr outs = std::make_shared(1); (*outs)[0] = partial_inputs[0]; JUST(OpInterpUtil::Dispatch(*op_.at(size - 1), partial_inputs, outs.get())); outputs.emplace_back((*outs)[0]); } else { outputs.emplace_back( JUST(OpInterpUtil::Dispatch(*op_.at(size - 1), partial_inputs))); } } if (outputs.size() == 1) { return outputs.at(0); } return this->operator()(outputs, inplace); } private: std::vector> op_; }; class ScalarMathBaseFunctor { public: explicit ScalarMathBaseFunctor(std::string op_name) { op_ = CHECK_JUST(one::OpBuilder(op_name).Input("in").Output("out").Build()); } virtual ~ScalarMathBaseFunctor() = default; Maybe operator()(const std::shared_ptr& x, const Scalar& scalar, bool inplace) const { if (std::dynamic_pointer_cast(x) && op_->op_type_name() == "scalar_mul") { return x; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand"); TensorProcessor tensor_processor; Symbol lowest_dtype; if (scalar.IsFloatingPoint() || scalar.IsComplex()) { attrs.SetAllAttrs(scalar.As(), true, NullOpt, false); // Only promote type to Float32 when tensor is Int type but scalar is float type. if (DType::priority_order[x->dtype()->data_type()] < DType::priority_order[DType::Float16()->data_type()]) { lowest_dtype = DType::Float(); } else { lowest_dtype = x->dtype(); } } else if (scalar.IsIntegral()) { attrs.SetAllAttrs(NullOpt, false, scalar.As(), true); // Promote type to Int64 when tensor is Bool type but scalar is int type. // Promote type to Float32 when op is scalar_div. if (DType::priority_order[x->dtype()->data_type()] == DType::priority_order[DType::Bool()->data_type()]) { lowest_dtype = DType::Int64(); } else if (op_->op_type_name() == "scalar_div") { lowest_dtype = x->dtype() == DType::Float16() ? 
DType::Float16() : DType::Float(); } else { lowest_dtype = x->dtype(); } } else { UNIMPLEMENTED_THEN_RETURN() << "The scalar in " << op_->op_type_name() << " should be float or int."; } JUST(tensor_processor.AddInputs({x}, lowest_dtype).Apply()); TensorTuple casted_vec = JUST(tensor_processor.GetInputs()); if (inplace) { JUST(CheckInplaceCastValid(x, casted_vec[0])); JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); (*outputs)[0] = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), OpExprInterpContext(attrs))); return outputs->at(0); } else { return OpInterpUtil::Dispatch(*op_, casted_vec, attrs); } } private: std::shared_ptr op_; }; class ScalarAddFunctor : public ScalarMathBaseFunctor { public: ScalarAddFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_add") {} Maybe operator()(const std::shared_ptr& input, const Scalar& other, const Scalar& alpha, const bool& inplace) const { if (IsIntegralDataType(input->dtype()->data_type()) && other.IsIntegral() && alpha.IsFloatingPoint()) { return Error::RuntimeError() << "For integral input tensors, argument alpha must not be a floating point number."; } Scalar scalar; if (other.IsFloatingPoint() || alpha.IsFloatingPoint()) { scalar = Scalar(other.Value() * alpha.Value()); } else { scalar = Scalar(other.Value() * alpha.Value()); } return ScalarMathBaseFunctor::operator()(input, scalar, inplace); } }; class ScalarAdd2Functor { public: Maybe operator()(const Scalar& input, const std::shared_ptr& other, const Scalar& alpha) const { if (IsIntegralDataType(other->dtype()->data_type()) && input.IsIntegral() && alpha.IsFloatingPoint()) { return Error::RuntimeError() << "For integral input tensors, argument alpha must not be a floating point number."; } std::shared_ptr other_; if ((alpha.IsIntegral() && alpha.Value() == 1) || (alpha.IsFloatingPoint() && std::fabs(alpha.Value() - 1.0) < std::numeric_limits::epsilon())) { other_ = other; } else { other_ = JUST(ScalarMul(alpha, other)); } return ScalarAdd(other_, input, /*alpha=*/1, /*inplace=*/false); } }; class ScalarSubFunctor { public: Maybe operator()(const std::shared_ptr& input, const Scalar& scalar, const Scalar& alpha, bool inplace) const { return ScalarAdd(input, Scalar(-1) * scalar, alpha, inplace); } }; class ScalarSub2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& input, const Scalar& alpha) const { return ScalarAdd(scalar, input, Scalar(-1) * alpha); } }; class ScalarMulFunctor : public ScalarMathBaseFunctor { public: ScalarMulFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_mul") {} }; class ScalarMul2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarMul(x, scalar, false); } }; class InplaceScalarMulFunctor : public ScalarMathBaseFunctor { public: InplaceScalarMulFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_mul") {} Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { return ScalarMathBaseFunctor::operator()(x, scalar, true); } }; class ScalarDivFunctor : public ScalarMathBaseFunctor { public: ScalarDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_div") {} Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { return ScalarMathBaseFunctor::operator()(x, scalar, false); } }; class ScalarDivModeFunctor { public: Maybe operator()(const std::shared_ptr& x, const Scalar& scalar, const Optional& rounding_mode) const { std::string rmode = rounding_mode.value_or(""); CHECK_OR_RETURN(rmode == "" || rmode == "floor" || rmode 
== "trunc") << "div expected rounding_mode to be one of None," " 'trunc', or 'floor' but found " << rmode; std::shared_ptr ret = JUST(functional::ScalarDiv(x, scalar)); if (rmode == "floor") { return JUST(functional::Floor(ret)); } else if (rmode == "trunc") { return JUST(functional::Trunc(ret)); } return ret; } }; class ScalarDiv2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return functional::ScalarMul(JUST(functional::Reciprocal(x)), scalar, /*inplace=*/false); } }; class ScalarDivMode2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x, const Optional& rounding_mode) const { std::string rmode = rounding_mode.value_or(""); CHECK_OR_RETURN(rmode == "" || rmode == "floor" || rmode == "trunc") << "div expected rounding_mode to be one of None," " 'trunc', or 'floor' but found " << rmode; std::shared_ptr ret = JUST(functional::ScalarDiv(scalar, x)); if (rmode == "floor") { return JUST(functional::Floor(ret)); } else if (rmode == "trunc") { return JUST(functional::Trunc(ret)); } return ret; } }; class InplaceScalarDivFunctor : public ScalarMathBaseFunctor { public: InplaceScalarDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_mul") {} Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { return ScalarMathBaseFunctor::operator()(x, Scalar(1.0) / scalar, true); } }; class ScalarPowFunctor : public ScalarMathBaseFunctor { public: ScalarPowFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_pow") {} }; class ScalarPowGradFunctor { public: ScalarPowGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("scalar_pow_grad").Input("x").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const Scalar& scalar) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand"); if (scalar.IsFloatingPoint()) { attrs.SetAllAttrs(scalar.As(), true, NullOpt, false); } else if (scalar.IsIntegral()) { attrs.SetAllAttrs(NullOpt, false, scalar.As(), true); } else { UNIMPLEMENTED_THEN_RETURN() << "The scalar in ScalarPowGrad should be float or int."; } return OpInterpUtil::Dispatch(*op_, {x, dy}, attrs); } private: std::shared_ptr op_; }; class ScalarReversePowFunctor : public ScalarMathBaseFunctor { public: ScalarReversePowFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_reverse_pow") {} Maybe operator()(const Scalar& scalar, const std::shared_ptr& input) const { return ScalarMathBaseFunctor::operator()(input, scalar, false); } }; class ScalarReversePowGradFunctor { public: ScalarReversePowGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("scalar_reverse_pow_grad").Input("x").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const Scalar& scalar) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand"); if (scalar.IsFloatingPoint()) { attrs.SetAllAttrs(scalar.As(), true, NullOpt, false); } else if (scalar.IsIntegral()) { attrs.SetAllAttrs(NullOpt, false, scalar.As(), true); } else { UNIMPLEMENTED_THEN_RETURN() << "The scalar in ScalarTensorPowGrad should be float or int."; } return OpInterpUtil::Dispatch(*op_, {x, dy}, attrs); } private: std::shared_ptr op_; }; class ScalarFloorDivFunctor : public ScalarMathBaseFunctor { public: ScalarFloorDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_floordiv") {} }; class ScalarTruncDivFunctor : public ScalarMathBaseFunctor { public: 
ScalarTruncDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_truncdiv") {} }; class ScalarFModFunctor : public ScalarMathBaseFunctor { public: ScalarFModFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_fmod") {} }; class ReduceMaxFunctor { public: ReduceMaxFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_max").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool& keepdims) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); if (axis.empty()) { std::vector reduce_axis(x->ndim()); std::iota(reduce_axis.begin(), reduce_axis.end(), 0); attrs.SetAllAttrs(reduce_axis, keepdims); } else { attrs.SetAllAttrs(axis, keepdims); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class ReduceMinFunctor { public: ReduceMinFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_min").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool& keepdims) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); if (axis.empty()) { std::vector reduce_axis(x->ndim()); std::iota(reduce_axis.begin(), reduce_axis.end(), 0); attrs.SetAllAttrs(reduce_axis, keepdims); } else { attrs.SetAllAttrs(axis, keepdims); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class MaxFunctor { public: Maybe operator()(const std::shared_ptr& x) const { std::vector axis(x->ndim()); std::iota(axis.begin(), axis.end(), 0); return ReduceMax(x, axis, /*keepdims=*/false); } }; class Max2Functor { public: Maybe operator()(const std::shared_ptr& x, const int32_t& dim, const bool& keepdims) const { auto outputs = std::make_shared(2); int32_t axis = dim; axis = JUST(maybe_wrap_dim(axis, x->ndim())); (*outputs)[0] = JUST(ReduceMax(x, {axis}, keepdims)); (*outputs)[1] = JUST(ArgMax(x, dim, keepdims, NullOpt)); return outputs; } }; class MinFunctor { public: Maybe operator()(const std::shared_ptr& x) const { std::vector axis(x->ndim()); std::iota(axis.begin(), axis.end(), 0); return ReduceMin(x, axis, /*keepdims=*/false); } }; class Min2Functor { public: Maybe operator()(const std::shared_ptr& x, const int32_t& dim, const bool& keepdims) const { auto outputs = std::make_shared(2); int32_t axis = dim; axis = JUST(maybe_wrap_dim(axis, x->ndim())); (*outputs)[0] = JUST(ReduceMin(x, {axis}, keepdims)); (*outputs)[1] = JUST(ArgMin(x, dim, keepdims, NullOpt)); return outputs; } }; class AminFunctor { public: Maybe operator()(const std::shared_ptr& x, const Optional>& dim, const bool& keepdim) const { if (!dim.has_value()) { return ReduceMin(x, {}, keepdim); } const int32_t ndim = x->ndim(); std::vector& dims = *JUST(dim); for (int i = 0; i < dims.size(); i++) { dims[i] = JUST(maybe_wrap_dim(dims[i], ndim)); } return ReduceMin(x, dims, keepdim); } }; class AmaxFunctor { public: Maybe operator()(const std::shared_ptr& x, const Optional>& dim, const bool& keepdim) const { if (!dim.has_value()) { return ReduceMax(x, {}, keepdim); } const int32_t ndim = x->ndim(); std::vector& dims = *JUST(dim); for (int i = 0; i < dims.size(); i++) { dims[i] = JUST(maybe_wrap_dim(dims[i], ndim)); } return ReduceMax(x, dims, keepdim); } }; class ReduceSumWholeFunctor { public: ReduceSumWholeFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_sum").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const Optional>& dtype) const { 
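    // Hedged sketch of the whole-tensor reduction below: the input is optionally cast to `dtype`,
    // a 0-dim tensor is returned as-is, and otherwise every dimension is reduced, i.e.
    // axis = {0, 1, ..., ndim-1} with keepdims = false, so e.g. a (2, 3) input yields a 0-dim
    // (scalar) result. Bool/integer inputs appear to be promoted to at least int64 by the
    // TensorProcessor (lowest_dtype = DType::Int64()) before dispatching "reduce_sum".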
std::shared_ptr tensor = x; if (dtype.has_value() && (dtype != x->dtype())) { tensor = JUST(Cast(x, JUST(dtype), /*pin_memory=*/false)); } const int32_t naxis = tensor->ndim(); if (naxis == 0) { return x; } // for 0-dim Tensor std::vector axis(naxis); std::iota(axis.begin(), axis.end(), 0); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(axis, false); TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({tensor}, /*lowest_dtype=*/DType::Int64()).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple, attrs); } private: std::shared_ptr op_; }; class ReduceSumFunctor { public: ReduceSumFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_sum").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool keepdims, const Optional>& dtype) const { std::shared_ptr tensor = x; if (dtype.has_value() && (dtype != x->dtype())) { tensor = JUST(Cast(x, JUST(dtype), /*pin_memory=*/false)); } std::vector reduce_axis = *JUST(CheckAxis(axis, x->ndim())); if (reduce_axis.size() == 0) { return tensor; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(reduce_axis, keepdims); TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({tensor}, /*lowest_dtype=*/DType::Int64()).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple, attrs); } private: std::shared_ptr op_; }; class ReduceNanSumFunctor { public: ReduceNanSumFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_nansum").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool& keepdims, const Optional>& dtype) const { std::shared_ptr tensor = x; if (dtype.has_value() && (dtype != x->dtype())) { tensor = JUST(Cast(x, JUST(dtype), /*pin_memory=*/false)); } std::vector reduce_axis = *JUST(CheckAxis(axis, tensor->ndim())); if (reduce_axis.size() == 0) { return tensor; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(reduce_axis, keepdims); TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({tensor}, /*lowest_dtype=*/DType::Int64()).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple, attrs); } private: std::shared_ptr op_; }; class ReduceNanSumWholeFunctor { public: ReduceNanSumWholeFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_nansum").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const Optional>& dtype) const { std::shared_ptr tensor = x; if (dtype.has_value() && (dtype != x->dtype())) { tensor = JUST(Cast(x, JUST(dtype), /*pin_memory=*/false)); } const int32_t ndim = tensor->ndim(); if (ndim == 0) { return tensor; } // for 0-dim Tensor std::vector axis(ndim); std::iota(axis.begin(), axis.end(), 0); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(axis, false); TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({tensor}, /*lowest_dtype=*/DType::Int64()).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple, attrs); } private: std::shared_ptr op_; }; class ReduceAllWholeFunctor { public: ReduceAllWholeFunctor() { op_ = CHECK_JUST( 
one::OpBuilder("reduce_all").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x) const { std::vector reduce_axis(x->ndim()); std::iota(reduce_axis.begin(), reduce_axis.end(), 0); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(reduce_axis, false); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class ReduceAllFunctor { public: ReduceAllFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_all").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool& keepdims) const { std::vector reduce_axis = *JUST(CheckAxis(axis, x->ndim())); if (reduce_axis.size() == 0) { return x; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(reduce_axis, keepdims); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class ReduceAnyWholeFunctor { public: ReduceAnyWholeFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_any").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x) const { std::vector reduce_axis(x->ndim()); std::iota(reduce_axis.begin(), reduce_axis.end(), 0); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(reduce_axis, false); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class ReduceAnyFunctor { public: ReduceAnyFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_any").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool& keepdims) const { std::vector reduce_axis = *JUST(CheckAxis(axis, x->ndim())); if (reduce_axis.size() == 0) { return x; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(reduce_axis, keepdims); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; template class ReduceDeviceStageBaseFunctor { public: ReduceDeviceStageBaseFunctor() : op_(CHECK_JUST(one::OpBuilder(T::GetOpName()) .Input("in") .Output("out") .Output("mask") .Output("count") .Build())) {} Maybe operator()(const std::shared_ptr& in, const std::vector& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch(*op_, {in}, attrs); } virtual ~ReduceDeviceStageBaseFunctor() = default; private: std::shared_ptr op_; }; template class ReduceDeviceStageGradBaseFunctor { public: ReduceDeviceStageGradBaseFunctor() : op_(CHECK_JUST(one::OpBuilder(T::GetOpName()) .Input("out_diff") .Input("mask") .Input("count") .Output("in_diff") .Build())) {} Maybe operator()(const std::shared_ptr& out_diff, const std::shared_ptr& mask, const std::shared_ptr& count, const std::vector& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch(*op_, {out_diff, mask, count}, attrs); } virtual ~ReduceDeviceStageGradBaseFunctor() = default; private: std::shared_ptr op_; }; class ReduceMinDeviceStageFunctor : public ReduceDeviceStageBaseFunctor { public: static std::string GetOpName() { return "reduce_min_device_stage"; } }; class ReduceMaxDeviceStageFunctor : public ReduceDeviceStageBaseFunctor { public: static std::string GetOpName() { return "reduce_max_device_stage"; } }; class ReduceMinDeviceStageGradFunctor : public ReduceDeviceStageGradBaseFunctor { public: static std::string GetOpName() 
{ return "reduce_min_device_stage_grad"; } }; class ReduceMaxDeviceStageGradFunctor : public ReduceDeviceStageGradBaseFunctor { public: static std::string GetOpName() { return "reduce_max_device_stage_grad"; } }; template class ReduceGlobalStageBaseFunctor { public: ReduceGlobalStageBaseFunctor() : op_(CHECK_JUST(one::OpBuilder(T::GetOpName()) .Input("in") .Input("device_count") .Output("out") .Output("mask") .Build())) {} Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& device_count, const std::vector& axis, const bool& keepdims) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(axis, keepdims); return OpInterpUtil::Dispatch(*op_, {in, device_count}, attrs); } virtual ~ReduceGlobalStageBaseFunctor() = default; private: std::shared_ptr op_; }; template class ReduceGlobalStageGradBaseFunctor { public: ReduceGlobalStageGradBaseFunctor() : op_(CHECK_JUST(one::OpBuilder(T::GetOpName()) .Input("out_diff") .Input("mask") .Input("device_count") .Output("in_diff") .Build())) {} Maybe operator()(const std::shared_ptr& out_diff, const std::shared_ptr& mask, const std::shared_ptr& device_count, const std::vector& axis, const bool& keepdims) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(axis, keepdims); return OpInterpUtil::Dispatch(*op_, {out_diff, mask, device_count}, attrs); } virtual ~ReduceGlobalStageGradBaseFunctor() = default; private: std::shared_ptr op_; }; class ReduceMinGlobalStageFunctor : public ReduceGlobalStageBaseFunctor { public: static std::string GetOpName() { return "reduce_min_global_stage"; } }; class ReduceMinGlobalStageGradFunctor : public ReduceGlobalStageGradBaseFunctor { public: static std::string GetOpName() { return "reduce_min_global_stage_grad"; } }; class ReduceMaxGlobalStageFunctor : public ReduceGlobalStageBaseFunctor { public: static std::string GetOpName() { return "reduce_max_global_stage"; } }; class ReduceMaxGlobalStageGradFunctor : public ReduceGlobalStageGradBaseFunctor { public: static std::string GetOpName() { return "reduce_max_global_stage_grad"; } }; class ReduceMeanWholeFunctor { public: ReduceMeanWholeFunctor() {} Maybe operator()(const std::shared_ptr& x) const { // ReduceMean only calculate floating values. CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type()) || IsComplexDataType(x->dtype()->data_type())) << "RuntimeError: Can only calculate the mean of floating types or complex types."; size_t reduce_count = 1; reduce_count = x->shape()->Count(0); const auto& sum = JUST(functional::ReduceSumWhole(x, NullOpt)); if (reduce_count == 1 || reduce_count == 0) { return sum; } return functional::ScalarMul(sum, 1.0 / reduce_count, false); } }; class ReduceMeanFunctor { public: ReduceMeanFunctor() {} Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool& keepdims) const { // ReduceMean only calculate floating values. // NOTE: Should use original reduce_mean op/kernel rather than current way(ReduceSum / // reduce_count) because it could encounter precision problem(like overflow) in float16 case. 
CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type())) << "RuntimeError: Can only calculate the mean of floating types."; const auto& sum = JUST(functional::ReduceSum(x, axis, keepdims, NullOpt)); size_t reduce_count = 1; if (axis.empty()) { reduce_count = x->shape()->Count(0); } else { std::vector reduce_axis = *JUST(CheckAxis(axis, x->ndim())); for (int32_t& i : reduce_axis) { reduce_count *= x->shape()->At(i); } } if (reduce_count == 1 || reduce_count == 0) { return sum; } return functional::ScalarMul(sum, 1.0 / reduce_count, false); } }; class ReduceProdWholeFunctor { public: ReduceProdWholeFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_prod").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const Optional>& dtype) const { std::shared_ptr tensor = x; if (dtype.has_value() && (dtype != x->dtype())) { tensor = JUST(Cast(tensor, JUST(dtype), /*pin_memory=*/false)); } TensorProcessor tensor_processor; Symbol lowest_dtype; if (DType::priority_order[tensor->dtype()->data_type()] == DType::priority_order[DType::Bool()->data_type()]) { lowest_dtype = DType::Int64(); } else { lowest_dtype = tensor->dtype(); } JUST(tensor_processor.AddInputs({tensor}, lowest_dtype).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); std::vector reduce_axis(tensor->ndim()); std::iota(reduce_axis.begin(), reduce_axis.end(), 0); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(reduce_axis, false); return JUST(OpInterpUtil::Dispatch(*op_, input_tuple, attrs)); } private: std::shared_ptr op_; }; class MedianFunctor { public: MedianFunctor() { op_ = CHECK_JUST(one::OpBuilder("median").Input("input").Output("output").Build()); } Maybe operator()(const std::shared_ptr& x) const { if (x->shape()->elem_cnt() == 0) { return functional::To( JUST(functional::Constant(Shape({1}).RemoveOnes({0}), Scalar(std::numeric_limits::quiet_NaN()), JUST(DType::Get(DataType::kFloat)), NullOpt)), x, false); } return OpInterpUtil::Dispatch(*op_, {x}); } private: std::shared_ptr op_; }; class MedianWithIndicesFunctor { public: MedianWithIndicesFunctor() { op_ = CHECK_JUST(one::OpBuilder("median_with_indices") .Input("input") .Output("values") .Output("indices") .Build()); } Maybe operator()(const std::shared_ptr& x, const int32_t& dim, const bool& keepdim) const { int32_t axis = dim; const int64_t ndim = x->ndim(); axis = JUST(maybe_wrap_dim(axis, ndim)); std::shared_ptr tensor = x; if (x->dim(axis) == 0) { return Error::IndexError() << "IndexError: Expected reduction dim " << axis << " to have non-zero size."; } if (axis != ndim - 1) { tensor = JUST(functional::Squeeze( JUST(functional::Transpose2dim(JUST(functional::Unsqueeze(x, -1)), axis, -1)), std::vector({axis}))); } std::shared_ptr result; result = JUST(OpInterpUtil::Dispatch(*op_, {tensor})); if (keepdim) { JUST(VectorAt(*result, 0)) = JUST(functional::Unsqueeze(JUST(VectorAt(*result, 0)), axis)); JUST(VectorAt(*result, 1)) = JUST(functional::Unsqueeze(JUST(VectorAt(*result, 1)), axis)); } return result; } private: std::shared_ptr op_; }; class ModeFunctor { public: ModeFunctor() { op_ = CHECK_JUST( one::OpBuilder("mode").Input("input").Output("values").Output("indices").Build()); } Maybe operator()(const std::shared_ptr& x, const int32_t& dim, const bool keepdim) const { int32_t axis = dim; const int64_t ndim = x->ndim(); axis = JUST(maybe_wrap_dim(axis, ndim)); std::shared_ptr tensor = x; if (x->dim(axis) == 0) { return Error::IndexError() << "IndexError: Expected 
reduction dim " << axis << " to have non-zero size."; } if (axis != ndim - 1) { tensor = JUST(functional::Squeeze( JUST(functional::Transpose2dim(JUST(functional::Unsqueeze(x, -1)), axis, -1)), std::vector({axis}))); } std::shared_ptr result; result = JUST(OpInterpUtil::Dispatch(*op_, {tensor})); if (keepdim) { JUST(VectorAt(*result, 0)) = JUST(functional::Unsqueeze(JUST(VectorAt(*result, 0)), axis)); JUST(VectorAt(*result, 1)) = JUST(functional::Unsqueeze(JUST(VectorAt(*result, 1)), axis)); } return result; } private: std::shared_ptr op_; }; class ReduceProdFunctor { public: ReduceProdFunctor() { op_ = CHECK_JUST( one::OpBuilder("reduce_prod").Input("input_tensor").Output("output_tensor").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool& keepdims, const Optional>& dtype) const { std::shared_ptr tensor = x; if (dtype.has_value() && (dtype != x->dtype())) { tensor = JUST(Cast(tensor, JUST(dtype), /*pin_memory=*/false)); } TensorProcessor tensor_processor; Symbol lowest_dtype; if (DType::priority_order[tensor->dtype()->data_type()] == DType::priority_order[DType::Bool()->data_type()]) { lowest_dtype = DType::Int64(); } else { lowest_dtype = tensor->dtype(); } JUST(tensor_processor.AddInputs({tensor}, lowest_dtype).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); std::vector reduce_axis = *JUST(CheckAxis(axis, x->ndim())); if (reduce_axis.size() == 0) { return x; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); attrs.SetAllAttrs(reduce_axis, keepdims); return JUST(OpInterpUtil::Dispatch(*op_, input_tuple, attrs)); } private: std::shared_ptr op_; }; class LogSumExpFunctor { public: LogSumExpFunctor() {} Maybe operator()(const std::shared_ptr& x, const std::vector& axis, const bool& keepdims) const { if (x->ndim() == 0) { // can't take amax of 0-dim tensor return To(x, JUST(DType::Get(DataType::kFloat)), false); } else if (x->nelement() == 0) { // can't take amax of empty tensor std::shared_ptr exp_out = JUST(Exp(x)); return Log(JUST(ReduceSum(exp_out, axis, keepdims, NullOpt))); } else { const std::shared_ptr& maxes = JUST(Amax(x, axis, true)); const std::shared_ptr& maxes_squeezed = (keepdims ? 
maxes : JUST(SqueezeMultiple(maxes, axis))); JUST(MaskedFillInplace(maxes_squeezed, JUST(ScalarLogicalEqual(JUST(Abs(maxes_squeezed)), INFINITY)), 0)); std::shared_ptr exp_out = JUST(Exp(JUST(Sub(x, maxes, 1, false)))); return Add(JUST(Log(JUST(ReduceSum(exp_out, axis, keepdims, NullOpt)))), maxes_squeezed, 1, false); } } private: Maybe SqueezeMultiple(const std::shared_ptr& x, const std::vector& axis) const { int ndims = x->ndim(); const auto& dims_to_squeeze = JUST(dim_list_to_bitset(axis, ndims)); std::shared_ptr result = x; for (int i = ndims - 1; i >= 0; --i) { if ((*dims_to_squeeze)[i]) { std::vector dims = {i}; result = JUST(Squeeze(result, dims)); } } return result; } }; class LogAddExpFunctor { public: LogAddExpFunctor() {} Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { CHECK_OR_RETURN(x->nelement() > 0 && y->nelement() > 0) << "logaddexp do not support 0-size tensor."; const std::shared_ptr& maxes = JUST(Maximum(x, y)); std::shared_ptr exp_out = JUST(Exp(JUST(Negative(JUST(Abs(JUST(Sub(x, y, 1, false)))))))); std::shared_ptr add_out = JUST(ScalarAdd(1.0, exp_out, 1)); return Add(maxes, JUST(Log(add_out)), 1, false); } }; class QuantileFunctor { public: QuantileFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& q, const Optional& dim, const bool keepdim, const std::string& interpolation, const bool ignore_nan) const { CHECK_GT_OR_RETURN(input->nelement(), 0) << "oneflow.quantile input tensor must be non-empty"; CHECK_LE_OR_RETURN(q->ndim(), 1) << "oneflow.quantile only support `q` tensor is a scalar or 1D tensor."; int64_t wrapped_dim = JUST(maybe_wrap_dim(dim.value_or(0), input->ndim())); // NOTE(hujiakui): this check is only performed when running on the CPU to avoid // synchronizing an accelerator with the CPU // For q is a Tensor. 
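    // Hedged overview of the steps below, mirroring the code rather than adding to it: the input
    // is sorted along the reduction dim and each quantile is located at rank = q * (n - 1), where
    // n is the size of that dim (or the count of non-NaN entries when ignore_nan is set).
    // "lower"/"higher"/"nearest" floor/ceil/round that rank; "linear" and "midpoint" gather the
    // values at floor(rank) and ceil(rank) and lerp between them, e.g. q = 0.5 over {1, 2, 3, 4}
    // gives rank 1.5 and a linear result of 2.5.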
DeviceType input_device{}; if (input->is_global()) { input_device = JUST(input->parallel_desc())->device_type(); } else { input_device = JUST(input->device())->enum_type(); } if (input_device == DeviceType::kCPU) { std::shared_ptr condition = JUST(functional::ReduceAllWhole(JUST(functional::BroadcastLogicalAnd( JUST(functional::ScalarLogicalGreaterEqual(q, Scalar(0.0))), JUST(functional::ScalarLogicalLessEqual(q, Scalar(1.0))))))); CHECK_OR_RETURN(JUST(functional::Equal( condition, JUST(functional::Cast(JUST(functional::OnesLike(condition)), DType::Bool(), false))))) << "oneflow.quantile q values must be in the range [0, 1]"; } // calculate the shape of output auto out_shape = quantile_output_shape(dim, input, q, keepdim, wrapped_dim); std::shared_ptr sorted; if (!dim.has_value()) { sorted = JUST(functional::Flatten(input, 0, -1)); sorted = JUST(functional::Sort(sorted, -1, false))->at(0); } else if (wrapped_dim == input->ndim() - 1) { sorted = JUST(functional::Sort(input, -1, false))->at(0); } else { sorted = JUST(functional::Unsqueeze(input, input->ndim() - 1)); std::vector perm(sorted->ndim()); std::iota(perm.begin(), perm.end(), 0); std::swap(perm[wrapped_dim], perm[perm.size() - 1]); sorted = JUST(view::Transpose(sorted, perm)); sorted = JUST(functional::Sort(sorted, -1, false))->at(0); } std::vector in_shape(out_shape.size()); std::copy(out_shape.begin() + 1, out_shape.end(), in_shape.begin()); in_shape[in_shape.size() - 1] = sorted->dim(sorted->ndim() - 1); DimVector inv(in_shape.size()); for (int i = 0; i < in_shape.size(); ++i) { inv[i] = in_shape[i]; } const Shape step_shape(inv); sorted = JUST(functional::View(sorted->contiguous(), step_shape)); CHECK_LE_OR_RETURN(sorted->dim(sorted->ndim() - 1), std::pow(2, 24)) << "oneflow.quantile input tensor is too large"; std::shared_ptr ranks; if (ignore_nan) { ranks = JUST( functional::Mul(JUST(functional::ScalarSub( JUST(functional::ReduceSum( JUST(functional::LogicalNot(JUST(functional::IsNan(sorted)))), std::vector({static_cast(sorted->ndim() - 1)}), /*keepdim=*/true, NullOpt)), Scalar(1), Scalar(1), /*inplace=*/false)), q)); ranks = JUST(functional::MaskedFill( ranks, JUST(functional::ScalarLogicalLess(ranks, Scalar(0))), Scalar(0))); } else { int64_t last_index = sorted->dim(sorted->ndim() - 1) - 1; std::shared_ptr tl = JUST(functional::BroadcastTensors( {JUST(functional::ScalarMul(q, last_index, /*inplace=*/false)), JUST(functional::ReduceAny( JUST(functional::IsNan(sorted)), std::vector({static_cast(sorted->ndim() - 1)}), /*keepdim=*/true))})); ranks = JUST(functional::MaskedFill(tl->at(0), tl->at(1), Scalar(last_index))); } if (interpolation == "lower") { JUST(functional::Floor_(ranks)); } else if (interpolation == "higher") { JUST(functional::Ceil_(ranks)); } else if (interpolation == "nearest") { JUST(functional::Round_(ranks)); } std::shared_ptr ranks_below = JUST(functional::Cast(ranks, DType::Int64(), /*pin_memory=*/false)); std::shared_ptr values_below = JUST(functional::DimGather(sorted, sorted->ndim() - 1, ranks_below, false)); if (interpolation == "linear" || interpolation == "midpoint") { std::shared_ptr weights = interpolation == "midpoint" ? 
JUST(functional::FullLike(ranks, Scalar(0.5))) : JUST(functional::Sub(ranks, ranks_below, Scalar(1.0), /*inplace=*/false)); JUST(functional::Ceil_(ranks)); std::shared_ptr ranks_above = JUST(functional::Cast(ranks, DType::Int64(), /*pin_memory=*/false)); std::shared_ptr values_above = JUST(functional::DimGather(sorted, sorted->ndim() - 1, ranks_above, false)); values_below = JUST(functional::Lerp(values_below, values_above, weights)); } values_below = JUST(view::Unsqueeze(values_below, 0)); int32_t ndim = values_below->ndim(); std::vector perm(ndim); std::iota(perm.begin(), perm.end(), 0); std::swap(perm[0], perm[perm.size() - 1]); values_below = JUST(view::Transpose(values_below, perm)); return view::Squeeze(values_below, std::vector({static_cast(values_below->ndim() - 1)})); } private: static inline std::vector quantile_output_shape(const Optional& dim, const std::shared_ptr& input, const std::shared_ptr& q, const bool keepdim, int64_t wrapped_dim) { // Compute output shape: q_size + reduced_size std::vector out_shape; if (dim.has_value() && input->ndim() > 0) { out_shape = std::vector(input->shape()->dim_vec().begin(), input->shape()->dim_vec().end()); if (keepdim) { out_shape[wrapped_dim] = 1; } else { out_shape.erase(out_shape.begin() + wrapped_dim); } } else if (keepdim) { out_shape = std::vector(input->ndim(), 1); } out_shape.insert(out_shape.begin(), q->nelement()); return out_shape; } }; class ScalarQuantileFunctor { public: ScalarQuantileFunctor() {} Maybe operator()(const std::shared_ptr& input, const Scalar& q, const Optional& dim, const bool& keepdim, const std::string& interpolation, const bool& ignore_nan) const { CHECK_GT_OR_RETURN(input->nelement(), 0) << "oneflow.quantile input tensor must be non-empty"; int64_t wrapped_dim = JUST(maybe_wrap_dim(dim.value_or(0), input->ndim())); double qf = 0; if (q.IsIntegral()) { qf = static_cast(q.As()); } else { qf = q.As(); } CHECK_OR_RETURN(qf <= 1.0 && qf >= 0.0) << "oneflow.quantile q values must be in the range [0, 1]"; // calculate the shape of output auto out_shape = quantile_output_shape(dim, input, q, keepdim, wrapped_dim); std::shared_ptr sorted; if (!dim.has_value()) { sorted = JUST(functional::Flatten(input, 0, -1)); sorted = JUST(functional::Sort(sorted, -1, false))->at(0); } else if (wrapped_dim == input->ndim() - 1) { sorted = JUST(functional::Sort(input, -1, false))->at(0); } else { sorted = JUST(functional::Unsqueeze(input, input->ndim() - 1)); std::vector perm(sorted->ndim()); std::iota(perm.begin(), perm.end(), 0); std::swap(perm[wrapped_dim], perm[perm.size() - 1]); sorted = JUST(view::Transpose(sorted, perm)); sorted = JUST(functional::Sort(sorted, -1, false))->at(0); } // q ==> 1-D Tensor out_shape.insert(out_shape.begin(), 1); std::vector in_shape(out_shape.size()); std::copy(out_shape.begin() + 1, out_shape.end(), in_shape.begin()); in_shape[in_shape.size() - 1] = sorted->dim(sorted->ndim() - 1); DimVector inv(in_shape.size()); for (int i = 0; i < in_shape.size(); ++i) { inv[i] = in_shape[i]; } const Shape step_shape(inv); sorted = JUST(functional::View(sorted->contiguous(), step_shape)); CHECK_LE_OR_RETURN(sorted->dim(sorted->ndim() - 1), std::pow(2, 24)) << "oneflow.quantile input tensor is too large"; std::shared_ptr ranks; if (ignore_nan) { ranks = JUST(functional::ScalarMul( JUST(functional::ScalarSub( JUST(functional::ReduceSum( JUST(functional::LogicalNot(JUST(functional::IsNan(sorted)))), std::vector({static_cast(sorted->ndim() - 1)}), /*keepdim=*/true, NullOpt)), Scalar(1), Scalar(1), 
/*inplace=*/false)), q, /*inplace=*/false)); ranks = JUST(functional::MaskedFill( ranks, JUST(functional::ScalarLogicalLess(ranks, Scalar(0))), Scalar(0))); } else { int64_t last_index = sorted->dim(sorted->ndim() - 1) - 1; std::shared_ptr tl_index = JUST( functional::ReduceAny(JUST(functional::IsNan(sorted)), std::vector({static_cast(sorted->ndim() - 1)}), /*keepdim=*/true)); std::shared_ptr tl_value; if (input->is_local()) { tl_value = JUST(functional::Empty(*(tl_index->shape()), DType::Float(), JUST(tl_index->device()), /*requires_grad=*/false, /*pin_memory=*/false)); } else { tl_value = JUST(functional::GlobalEmpty( *(tl_index->shape()), DType::Float(), JUST(tl_index->parallel_desc()), *JUST(private_details::RawGetSbpList(JUST(tl_index->nd_sbp()))))); } tl_value = JUST(functional::Fill(tl_value, Scalar(qf * last_index))); ranks = JUST(functional::MaskedFill(tl_value, tl_index, Scalar(last_index))); } // adjust ranks based on the interpolation mode if (interpolation == "lower") { JUST(functional::Floor_(ranks)); } else if (interpolation == "higher") { JUST(functional::Ceil_(ranks)); } else if (interpolation == "nearest") { JUST(functional::Round_(ranks)); } std::shared_ptr ranks_below = JUST(functional::Cast(ranks, DType::Int64(), /*pin_memory=*/false)); std::shared_ptr values_below = JUST(functional::DimGather(sorted, sorted->ndim() - 1, ranks_below, false)); if (interpolation == "linear" || interpolation == "midpoint") { std::shared_ptr weights = interpolation == "midpoint" ? JUST(functional::FullLike(ranks, Scalar(0.5))) : JUST(functional::Sub(ranks, ranks_below, Scalar(1.0), /*inplace=*/false)); JUST(functional::Ceil_(ranks)); std::shared_ptr ranks_above = JUST(functional::Cast(ranks, DType::Int64(), /*pin_memory=*/false)); std::shared_ptr values_above = JUST(functional::DimGather(sorted, sorted->ndim() - 1, ranks_above, false)); values_below = JUST(functional::Lerp(values_below, values_above, weights)); } return view::Squeeze(values_below, std::vector({static_cast(values_below->ndim() - 1)})); } private: static inline std::vector quantile_output_shape(const Optional& dim, const std::shared_ptr& input, const Scalar& q, const bool keepdim, int64_t wrapped_dim) { // Compute output shape: q_size + reduced_size std::vector out_shape; if (dim.has_value() && input->ndim() > 0) { out_shape = std::vector(input->shape()->dim_vec().begin(), input->shape()->dim_vec().end()); if (keepdim) { out_shape[wrapped_dim] = 1; } else { out_shape.erase(out_shape.begin() + wrapped_dim); } } else if (keepdim) { out_shape = std::vector(input->ndim(), 1); } return out_shape; } }; class TransposeFunctor { public: TransposeFunctor() { op_ = CHECK_JUST(one::OpBuilder("transpose").Input("input").Output("output").Build()); } Maybe operator()(const std::shared_ptr& input, const std::vector& permute) const { auto ndim = input->ndim(); CHECK_EQ_OR_RETURN(ndim, permute.size()) << "number of dims don't match in permute"; // handle negative permute value here, because of permute is const, // so copy it to local var and do modification. 
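  // Hedged example of the wrapping done below: negative entries are mapped to dim + ndim, so for
  // a 3-D input permute {-1, 0, 1} becomes {2, 0, 1}; afterwards output->dim(i) equals
  // input->dim(positive_perm[i]). When the input supports views (eager, local), the transpose is
  // expressed as a stride permutation via view::Transpose instead of dispatching the "transpose" op.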
auto positive_perm = permute; for (auto i = 0; i < positive_perm.size(); i++) { positive_perm[i] = JUST(maybe_wrap_dim(positive_perm[i], ndim)); } // currently, view only support eager and local mode if (view::IsViewApplicable(input)) { return JUST(view::Transpose(input, positive_perm)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("perm"); attrs.SetAllAttrs(positive_perm); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class Transpose2dimFunctor { public: Transpose2dimFunctor() { op_ = CHECK_JUST(one::OpBuilder("transpose").Input("input").Output("output").Build()); } Maybe operator()(const std::shared_ptr& input, const int32_t dim0, const int32_t dim1) const { const int64_t ndim = input->ndim(); std::vector permute; permute.reserve(ndim); int32_t dim_0 = dim0; int32_t dim_1 = dim1; dim_0 = JUST(maybe_wrap_dim(dim_0, ndim)); dim_1 = JUST(maybe_wrap_dim(dim_1, ndim)); for (int32_t i = 0; i < ndim; ++i) { permute.emplace_back(i); } std::swap(permute[dim_0], permute[dim_1]); Shape shape(DimVector(permute.begin(), permute.end())); if (view::IsViewApplicable(input)) { return JUST(view::Transpose(input, permute)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("perm"); attrs.SetAllAttrs(permute); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class AsStridedFunctor { public: AsStridedFunctor() { op_ = CHECK_JUST(one::OpBuilder("as_strided").Input("input").Output("output").Build()); } Maybe operator()(const std::shared_ptr& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const { CHECK_OR_RETURN(size.size() == stride.size()) << "mismatch in length of strides and shape"; for (size_t i = 0; i < size.size(); i++) { CHECK_OR_RETURN(size[i] >= 0) << "Trying to create tensor with negative dimension" << size[i]; CHECK_OR_RETURN(stride[i] >= 0) << "as_strided: Negative strides are not supported at the moment, got strides:" << stride[i]; } if (view::IsViewApplicable(input)) { return JUST(view::AsStrided(input, size, stride, storage_offset)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("size", "stride", "storage_offset"); attrs.SetAllAttrs(size, stride, storage_offset); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class AsStridedGradFunctor { public: AsStridedGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("as_strided_grad").Input("dy").Input("input").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const { if (view::IsViewApplicable(input)) { return JUST(view::AsStridedGrad(dy, input, size, stride, storage_offset)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("size", "stride", "storage_offset"); attrs.SetAllAttrs(size, stride, storage_offset); return OpInterpUtil::Dispatch(*op_, {dy, input}, attrs); } private: std::shared_ptr op_; }; class InplaceAsStridedFunctor { public: Maybe operator()(const std::shared_ptr& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const { JUST(CheckInplaceValid(input)); CHECK_OR_RETURN(size.size() == stride.size()) << "mismatch in length of strides and shape"; for (size_t i = 0; i < size.size(); i++) { CHECK_OR_RETURN(size[i] >= 0) << "Trying to create tensor with negative dimension" << size[i]; CHECK_OR_RETURN(stride[i] >= 0) << "as_strided: Negative strides are not supported at the moment, got strides:" << stride[i]; } 
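    // Hedged note on as_strided addressing (same convention as the out-of-place functor above):
    // element (i_0, ..., i_{k-1}) of the result reads the input's storage at
    //   storage_offset + i_0 * stride[0] + ... + i_{k-1} * stride[k-1]
    // e.g. size {2, 2} with stride {1, 2} and offset 0 views storage indices {{0, 2}, {1, 3}}.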
CHECK_OR_RETURN(view::IsViewApplicable(input)) << "Only support as_strided_ in eager local mode"; JUST(view::InplaceAsStrided(input, size, stride, storage_offset)); return input; } }; class ArangeFunctor { public: ArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder("arange").Output("out").Build()); } Maybe operator()(const Scalar& start, const Scalar& limit, const Scalar& delta, const Optional>& dtype, const Optional>& device) const { if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST(functional::GlobalArange(start, limit, delta, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())))); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("integer_start", "integer_limit", "integer_delta", "float_start", "float_limit", "float_delta", "dtype"); if (dtype.has_value()) { const DataType range_dtype = JUST(dtype)->data_type(); if (IsIntegralDataType(range_dtype)) { attrs.SetAllAttrs(start.As(), limit.As(), delta.As(), NullOpt, NullOpt, NullOpt, range_dtype); } else { attrs.SetAllAttrs(NullOpt, NullOpt, NullOpt, start.As(), limit.As(), delta.As(), range_dtype); } } else { if (start.IsIntegral() && limit.IsIntegral() && delta.IsIntegral()) { attrs.SetAllAttrs(start.As(), limit.As(), delta.As(), NullOpt, NullOpt, NullOpt, DType::Int64()->data_type()); } else { attrs.SetAllAttrs(NullOpt, NullOpt, NullOpt, start.As(), limit.As(), delta.As(), DType::Float()->data_type()); } } OpExprInterpContext ctx(attrs); ctx.device = device; return OpInterpUtil::Dispatch(*op_, {}, ctx); } private: std::shared_ptr op_; }; class Arange2Functor { public: Maybe operator()(const Scalar& limit, const Optional>& dtype, const Optional>& device) const { return Arange(Scalar(0), limit, Scalar(1), dtype, device); } }; class GlobalArangeFunctor { public: GlobalArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder("arange").Output("out").Build()); } Maybe operator()(const Scalar& start, const Scalar& limit, const Scalar& delta, const Optional>& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { JUST(CheckDeviceIdsIsValid(placement)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("integer_start", "integer_limit", "integer_delta", "float_start", "float_limit", "float_delta", "dtype", "nd_sbp"); if (dtype.has_value()) { const DataType range_dtype = JUST(dtype)->data_type(); if (IsIntegralDataType(range_dtype)) { attrs.SetAllAttrs(start.As(), limit.As(), delta.As(), NullOpt, NullOpt, NullOpt, range_dtype, NullOpt); } else { attrs.SetAllAttrs(NullOpt, NullOpt, NullOpt, start.As(), limit.As(), delta.As(), range_dtype, NullOpt); } } else { if (start.IsIntegral() && limit.IsIntegral() && delta.IsIntegral()) { attrs.SetAllAttrs(start.As(), limit.As(), delta.As(), NullOpt, NullOpt, NullOpt, DType::Int64()->data_type(), NullOpt); } else { attrs.SetAllAttrs(NullOpt, NullOpt, NullOpt, start.As(), limit.As(), delta.As(), DType::Float()->data_type(), NullOpt); } } if (LazyMode::is_enabled()) { std::vector nd_sbp(sbp_tuple.size()); { for (int i = 0; i < sbp_tuple.size(); ++i) { nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i)); } } attrs.SetAttr<7>(nd_sbp); } const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); return OpInterpUtil::Dispatch(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp)); } private: std::shared_ptr op_; }; class GlobalArange2Functor { public: Maybe operator()(const Scalar& limit, const Optional>& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { JUST(CheckDeviceIdsIsValid(placement)); return GlobalArange(Scalar(0), 
limit, Scalar(1), dtype, placement, sbp_tuple); } }; class HannWindowFunctor { public: Maybe operator()(const int64_t window_length, const bool& periodic, const Optional>& device, const Optional>& dtype, const bool& requires_grad) const { if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST(functional::GlobalHannWindow( window_length, periodic, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, requires_grad)); } autograd::AutoGradMode mode(false); if (dtype.has_value() && !IsFloatingDataType(JUST(dtype)->data_type())) { return Error::RuntimeError() << "hann_window expects floating point dtypes, got: " << JUST(dtype)->name(); } // TODO: speedup auto result = JUST(Arange(1, 2, 1, dtype, device)); if (window_length != 1) { if (periodic) { const auto indice = JUST(Arange(window_length + 1, dtype, device)); const auto div_result = JUST(ScalarDiv(JUST(ScalarMul(2 * M_PI, indice)), window_length)); result = JUST(Slice(JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2)), {0}, {window_length}, {1}, /*enable_view_slice=*/false)); } else { const auto indice = JUST(Arange(window_length, dtype, device)); const auto div_result = JUST(ScalarDiv(JUST(ScalarMul(2 * M_PI, indice)), window_length - 1)); result = JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2)); } } JUST(result->set_requires_grad(requires_grad)); return result; } }; class GlobalHannWindowFunctor { public: Maybe operator()(const int64_t window_length, const bool& periodic, const Symbol& placement, const std::vector>& sbp, const Optional>& dtype, const bool& requires_grad) const { autograd::AutoGradMode mode(false); JUST(CheckDeviceIdsIsValid(placement)); if (dtype.has_value() && !IsFloatingDataType(JUST(dtype)->data_type())) { return Error::RuntimeError() << "hann_window expects floating point dtypes, got: " << JUST(dtype)->name(); } auto result = JUST(GlobalArange(1, 1 + window_length, 1, dtype, placement, sbp)); if (window_length != 1) { if (periodic) { const auto indice = JUST(GlobalArange(window_length + 8, dtype, placement, sbp)); const auto div_result = JUST(ScalarDiv(JUST(ScalarMul(2 * M_PI, indice)), window_length)); result = JUST(Slice(JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2)), {0}, {window_length}, {1}, /*enable_view_slice=*/false)); } else { const auto indice = JUST(GlobalArange(window_length, dtype, placement, sbp)); const auto div_result = JUST(ScalarDiv(JUST(ScalarMul(2 * M_PI, indice)), window_length - 1)); result = JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2)); } } result = JUST(ToGlobal(result, placement, sbp, {}, true, /*copy=*/false)); JUST(result->set_requires_grad(requires_grad)); return result; } }; class CastFunctor { public: CastFunctor() { op_ = CHECK_JUST(one::OpBuilder("cast").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Symbol& dtype, const bool pin_memory) const { if (x->dtype() == dtype) { return x; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype", "pin_memory"); attrs.SetAllAttrs(dtype->data_type(), pin_memory); // refers to pytorch's tensor.to (to_impl function at // aten/src/ATen/native/TensorConversions.cpp) if (JUST(IsNonOverlappingAndDense(x))) { return OpInterpUtil::Dispatch(*op_, {x}, attrs); } else { return OpInterpUtil::Dispatch(*op_, {x->contiguous()}, attrs); } } private: std::shared_ptr op_; }; class ClampBaseFunctor { public: ClampBaseFunctor() { clip_op_ = 
CHECK_JUST(one::OpBuilder("clip_by_scalar").Input("x").Output("y").Build()); clip_min_op_ = CHECK_JUST(one::OpBuilder("clip_by_scalar_min").Input("x").Output("y").Build()); clip_max_op_ = CHECK_JUST(one::OpBuilder("clip_by_scalar_max").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const Optional& min, const Optional& max, bool inplace) const { CHECK_OR_RETURN(min.has_value() || max.has_value()) << "Requires one of argument `min` and `max` at least in clip."; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("floating_min", "integral_min", "floating_max", "integral_max"); if (IsFloatingDataType(x->dtype()->data_type())) { if (min.has_value()) { const auto& min_val = JUST(min); attrs.SetAttr<0>(min_val->As()); attrs.SetAttr<1>(static_cast(0)); } if (max.has_value()) { const auto& max_val = JUST(max); attrs.SetAttr<2>(max_val->As()); attrs.SetAttr<3>(static_cast(0)); } } else if (IsIntegralDataType(x->dtype()->data_type())) { if (min.has_value()) { const auto& min_val = JUST(min); attrs.SetAttr<0>(static_cast(0)); attrs.SetAttr<1>(min_val->As()); } if (max.has_value()) { const auto& max_val = JUST(max); attrs.SetAttr<2>(static_cast(0)); attrs.SetAttr<3>(max_val->As()); } } else { UNIMPLEMENTED_THEN_RETURN() << "Only support floating or integral data type."; } const OpExpr* op = nullptr; if (!min.has_value()) { op = clip_max_op_.get(); } else if (!max.has_value()) { op = clip_min_op_.get(); } else { op = clip_op_.get(); } if (inplace) { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; if (x->requires_grad()) { JUST(OpInterpUtil::Dispatch(*op, {JUST(functional::Identity(x))}, outputs.get(), attrs)); } else { JUST(OpInterpUtil::Dispatch(*op, {x}, outputs.get(), attrs)); } return outputs->at(0); } else { return OpInterpUtil::Dispatch(*op, {x}, attrs); } } private: std::shared_ptr clip_op_; std::shared_ptr clip_min_op_; std::shared_ptr clip_max_op_; }; class ClampFunctor : public ClampBaseFunctor { public: Maybe operator()(const std::shared_ptr& x, const Optional& min, const Optional& max) const { return ClampBaseFunctor::operator()(x, min, max, /* inplace=*/false); } }; class ClampMinFunctor : public ClampBaseFunctor { public: Maybe operator()(const std::shared_ptr& x, const Scalar& min) const { return ClampBaseFunctor::operator()(x, min, NullOpt, /* inplace=*/false); } }; class ClampMaxFunctor : public ClampBaseFunctor { public: Maybe operator()(const std::shared_ptr& x, const Scalar& max) const { return ClampBaseFunctor::operator()(x, NullOpt, max, /* inplace=*/false); } }; class ClampInplaceFunctor : public ClampBaseFunctor { public: Maybe operator()(const std::shared_ptr& x, const Optional& min, const Optional& max) const { return ClampBaseFunctor::operator()(x, min, max, /* inplace=*/true); } }; class ClampMinInplaceFunctor : public ClampBaseFunctor { public: Maybe operator()(const std::shared_ptr& x, const Scalar& min) const { return ClampBaseFunctor::operator()(x, min, NullOpt, /* inplace=*/true); } }; class ClampMaxInplaceFunctor : public ClampBaseFunctor { public: Maybe operator()(const std::shared_ptr& x, const Scalar& max) const { return ClampBaseFunctor::operator()(x, NullOpt, max, /* inplace=*/true); } }; class ClipFunctor { public: Maybe operator()(const std::shared_ptr& x, const Optional& min, const Optional& max) const { return Clamp(x, min, max); } }; class ClipInplaceFunctor { public: Maybe operator()(const std::shared_ptr& x, const Optional& min, const Optional& max) const { return ClampInplace(x, min, 
max); } }; class SqrtSquareSumFunctor { public: SqrtSquareSumFunctor() { op_ = CHECK_JUST(one::OpBuilder("sqrt_square_sum").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x) const { return OpInterpUtil::Dispatch(*op_, {x}, {}); } private: std::shared_ptr op_; }; class VectorNormFunctor { public: VectorNormFunctor() {} Maybe operator()(const std::shared_ptr& x, const Scalar& ord, const Optional>& input_dim, const bool& keepdim, const Optional>& dtype) const { std::shared_ptr res; Symbol dtype_val; if (dtype) { dtype_val = JUST(dtype); if (!(dtype_val->data_type() == DataType::kFloat || dtype_val->data_type() == DataType::kDouble || dtype_val->data_type() == DataType::kFloat16 || dtype_val->data_type() == DataType::kBFloat16)) { UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports floating point and " "complex dtypes, but got: Int."; } } else { if (!IsFloatingDataType(x->dtype()->data_type())) { UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports floating point and " "complex dtypes, but got: Int."; } dtype_val = x->dtype(); } bool full_dim_flag = true; std::vector dim; if (!input_dim.has_value()) { std::vector reduce_axis(x->ndim()); std::iota(reduce_axis.begin(), reduce_axis.end(), 0); dim = reduce_axis; } else { std::vector dim_check; dim_check = *JUST(input_dim); for (int i = 0; i < dim_check.size(); ++i) { if (dim_check[i] >= 0) { dim.emplace_back(dim_check[i]); } else { dim.emplace_back(dim_check[i] + x->ndim()); } if (dim[i] != i) { full_dim_flag = false; } } if ((int)dim.size() < x->ndim()) { full_dim_flag = false; } } if (ord.IsIntegral() || ord.IsFloatingPoint()) { double ord_val = ord.As(); if (ord_val == 0) { res = JUST(ReduceSum(JUST(functional::NotEqualZero(x)), dim, keepdim, NullOpt)); } else if (ord_val == INFINITY) { res = JUST(ReduceMax(JUST(Abs(x)), dim, keepdim)); } else if (ord_val == -INFINITY) { res = JUST(ReduceMin(JUST(Abs(x)), dim, keepdim)); } else if (ord_val == 2.0 && keepdim == false && full_dim_flag && x->requires_grad() == false) { res = JUST(SqrtSquareSum(x)); } else { res = JUST(ScalarPow( JUST(ReduceSum(JUST(ScalarPow(JUST(Abs(x)), ord, false)), dim, keepdim, NullOpt)), Scalar(1.0) / ord, false)); } res = JUST(Cast(res, dtype_val, /*pin_memory=*/false)); return res; } else { UNIMPLEMENTED_THEN_RETURN() << "linalg_vector_norm(): argument 'ord' must be Number, not str."; } } }; class ScalarVectorNormFunctor { public: ScalarVectorNormFunctor() {} Maybe operator()(const std::shared_ptr& x, const Scalar& ord, const Scalar& input_dim, const bool& keepdim, const Optional>& dtype) const { if (dtype) { Symbol dtype_val = JUST(dtype); if (!(dtype_val->data_type() == DataType::kFloat || dtype_val->data_type() == DataType::kDouble || dtype_val->data_type() == DataType::kFloat16 || dtype_val->data_type() == DataType::kBFloat16)) { UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports the float, double, " "cfloat and cdouble dtypes, but got: Int."; } } else { if (!IsFloatingDataType(x->dtype()->data_type())) { UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports the float, double, " "cfloat and cdouble dtypes, but got: Int."; } } if (input_dim.IsIntegral()) { std::vector dim(1, input_dim.As()); return functional::VectorNorm(x, ord, dim, keepdim, dtype); } else { UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only support int dim."; } } }; class ScalarMatrixNormFunctor { public: ScalarMatrixNormFunctor() {} Maybe operator()(const std::shared_ptr& x, const Scalar& ord, const 
std::vector& input_dim, const bool& keepdim, const Optional>& dtype) const { std::shared_ptr res; auto num_dims = x->ndim(); auto axis = input_dim.size(); CHECK_OR_RETURN(num_dims >= 2) << "linalg.matrix_norm(): input tensor must be a matrix or batch of matrices"; CHECK_OR_RETURN(axis == 2 && input_dim[0] != input_dim[1]) << "linalg.matrix_norm(): input_dim must be a 2-tuple of ints with different elements"; Symbol dtype_val; if (dtype) { dtype_val = JUST(dtype); if (!(dtype_val->data_type() == DataType::kFloat || dtype_val->data_type() == DataType::kDouble || dtype_val->data_type() == DataType::kFloat16 || dtype_val->data_type() == DataType::kBFloat16)) { UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): only supports the float, double, " "cfloat and cdouble dtypes, but got: Int."; } } else { if (!IsFloatingDataType(x->dtype()->data_type())) { UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): only supports the float, double, " "cfloat and cdouble dtypes, but got: Int."; } dtype_val = x->dtype(); } std::vector dim_tmp; dim_tmp.reserve(axis); for (int i = 0; i < axis; ++i) { if (input_dim[i] >= 0) { dim_tmp.emplace_back(input_dim[i]); } else { dim_tmp.emplace_back(input_dim[i] + num_dims); } } std::vector dim(2); double ord_tmp = ord.As(); if (ord_tmp == INFINITY || ord_tmp == -INFINITY) { dim = dim_tmp; dim[0] = dim_tmp[1]; dim[1] = dim_tmp[0]; } else if (ord_tmp == 1 || ord_tmp == -1) { dim = dim_tmp; } else { UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): Only support INFINITY,-INFINITY,1 or -1 data type."; } if (dim[1] > dim[0] && keepdim == false) { dim[1] -= 1; } std::vector dim_tmp0_vec(1, dim[0]); std::vector dim_tmp1_vec(1, dim[1]); res = JUST(ReduceSum(JUST(Abs(x)), dim_tmp0_vec, keepdim, NullOpt)); if (ord_tmp == INFINITY || ord_tmp == 1) { res = JUST(ReduceMax(res, dim_tmp1_vec, keepdim)); } else if (ord_tmp == -INFINITY || ord_tmp == -1) { res = JUST(ReduceMin(res, dim_tmp1_vec, keepdim)); } res = JUST(Cast(res, dtype_val, /*pin_memory=*/false)); return res; } }; class MatrixNormFunctor { public: MatrixNormFunctor() {} Maybe operator()(const std::shared_ptr& x, const std::string& ord, const std::vector& input_dim, const bool& keepdim, const Optional>& dtype) const { std::shared_ptr res; Symbol dtype_val; if (dtype) { dtype_val = JUST(dtype); if (!(dtype_val->data_type() == DataType::kFloat || dtype_val->data_type() == DataType::kDouble || dtype_val->data_type() == DataType::kFloat16 || dtype_val->data_type() == DataType::kBFloat16)) { UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): only supports the float, double, " "cfloat and cdouble dtypes, but got: Int."; } } else { if (!IsFloatingDataType(x->dtype()->data_type())) { UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): only supports the float, double, " "cfloat and cdouble dtypes, but got: Int."; } dtype_val = x->dtype(); } auto num_dims = x->ndim(); auto axis = input_dim.size(); std::vector dim_tmp(axis); for (int i = 0; i < axis; ++i) { if (input_dim[i] >= 0) { dim_tmp[i] = input_dim[i]; } else { dim_tmp[i] = input_dim[i] + num_dims; } } if (ord == "nuc") { UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): Not support ord is nuc."; } else if (ord == "fro") { res = JUST(Sqrt(JUST(ReduceSum(JUST(Square(x)), dim_tmp, keepdim, NullOpt)))); } else { UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): could not convert string to float:" << ord; } res = JUST(Cast(res, dtype_val, /*pin_memory=*/false)); return res; } }; class NormFunctor { public: NormFunctor() {} Maybe operator()(const 
std::shared_ptr& x, const Optional& ord, const Optional>& input_dim, const bool& keepdim, const Optional>& dtype, const bool& for_norm) const { // If for_norm, the functor will be used to oneflow.norm. std::shared_ptr res; if (dtype) { Symbol dtype_val = JUST(dtype); if (!(dtype_val->data_type() == DataType::kFloat || dtype_val->data_type() == DataType::kDouble || dtype_val->data_type() == DataType::kFloat16 || dtype_val->data_type() == DataType::kBFloat16)) { UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and " "cdouble dtypes, but got: Int."; } } else { if (!IsFloatingDataType(x->dtype()->data_type())) { UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and " "cdouble dtypes, but got: Int."; } } Scalar ord_sca; bool ord_type = false; if (ord.has_value()) { ord_type = (*JUST(ord)).IsIntegral(); if (ord_type) { ord_sca = Scalar((*JUST(ord)).As()); } else { ord_sca = *JUST(ord); } } if (input_dim.has_value()) { auto axis = (*JUST(input_dim)).size(); if (axis == 1) { Scalar ord_val; if (!ord.has_value()) { ord_val = Scalar(2.0); } else { ord_val = ord_sca; } res = JUST(VectorNorm(x, ord_val, input_dim, keepdim, dtype)); } else if (axis > 2) { res = JUST(MatrixNorm(x, ord_sca, *JUST(input_dim), keepdim, dtype)); } else if (axis == 2) { if (!ord.has_value()) { res = JUST(MatrixNorm(x, "fro", *JUST(input_dim), keepdim, dtype)); } else { res = JUST(MatrixNorm(x, ord_sca, *JUST(input_dim), keepdim, dtype)); } } } else { if (ord.has_value()) { CHECK_OR_RETURN(x->ndim() <= 2) << "linalg.norm(): input must be 1-D or 2-D when dim is None and ord is not None"; if (ord_type) { const double ord_double = (*JUST(ord)).As(); if (for_norm && (ord_double >= 2 || ord_double <= -2)) { const int32_t num_axes = x->shape()->NumAxes(); std::vector axes_vec(num_axes); std::iota(axes_vec.begin(), axes_vec.end(), 0); return ScalarPow(JUST(ReduceSum(JUST(ScalarPow(JUST(Abs(x)), ord_sca, false)), axes_vec, /*keepdims=*/false, NullOpt)), 1 / ord_double, false); } } if (x->ndim() == 1) { res = JUST(VectorNorm(x, ord_sca, input_dim, keepdim, dtype)); } else { std::vector dim{0, 1}; res = JUST(MatrixNorm(x, ord_sca, dim, keepdim, dtype)); } } else { res = JUST(VectorNorm(x, Scalar(2.0), input_dim, keepdim, dtype)); } } return res; } }; class Norm2Functor { public: Norm2Functor() {} Maybe operator()(const std::shared_ptr& x, const std::string& ord, const Optional>& input_dim, const bool& keepdim, const Optional>& dtype) const { std::shared_ptr res; std::vector dim(x->ndim()); std::iota(dim.begin(), dim.end(), 0); if (dtype) { Symbol dtype_val = JUST(dtype); if (!(dtype_val->data_type() == DataType::kFloat || dtype_val->data_type() == DataType::kDouble || dtype_val->data_type() == DataType::kFloat16 || dtype_val->data_type() == DataType::kBFloat16)) { UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and " "cdouble dtypes, but got: Int."; } } else { if (!IsFloatingDataType(x->dtype()->data_type())) { UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and " "cdouble dtypes, but got: Int."; } } if (input_dim.has_value()) { res = JUST(MatrixNorm(x, ord, *JUST(input_dim), keepdim, dtype)); } else { res = JUST(MatrixNorm(x, ord, dim, keepdim, dtype)); } return res; } }; class ScalarNormFunctor { public: ScalarNormFunctor() {} Maybe operator()(const std::shared_ptr& x, const Optional& ord, const Scalar& input_dim, const bool& keepdim, const Optional>& dtype) const { if (dtype) { 
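      // Hedged note: like the other norm functors in this file, an explicitly requested dtype must
      // be floating point (float, double, float16 or bfloat16); integral dtypes are rejected, so a
      // caller wanting an integer input normalised would cast it to a float type first (assumed
      // usage, not stated here).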
Symbol dtype_val = JUST(dtype); if (!(dtype_val->data_type() == DataType::kFloat || dtype_val->data_type() == DataType::kDouble || dtype_val->data_type() == DataType::kFloat16 || dtype_val->data_type() == DataType::kBFloat16)) { UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and " "cdouble dtypes, but got: Int."; } } else { if (!IsFloatingDataType(x->dtype()->data_type())) { UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and " "cdouble dtypes, but got: Int."; } } if (input_dim.IsIntegral()) { std::vector dim(1, input_dim.As()); return functional::Norm(x, ord, dim, keepdim, dtype, /*for_norm=*/false); } else { UNIMPLEMENTED_THEN_RETURN() << "linalg_norm(): only supports int dim."; } } }; class ScalarNorm2Functor { public: ScalarNorm2Functor() {} Maybe operator()(const std::shared_ptr& x, const std::string& ord, const Scalar& input_dim, const bool& keepdim, const Optional>& dtype) const { if (dtype) { Symbol dtype_val = JUST(dtype); if (!(dtype_val->data_type() == DataType::kFloat || dtype_val->data_type() == DataType::kDouble || dtype_val->data_type() == DataType::kFloat16 || dtype_val->data_type() == DataType::kBFloat16)) { UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and " "cdouble dtypes, but got: Int."; } } else { if (!IsFloatingDataType(x->dtype()->data_type())) { UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and " "cdouble dtypes, but got: Int."; } } if (input_dim.IsIntegral()) { std::vector dim(1, input_dim.As()); return functional::Norm(x, ord, dim, keepdim, dtype); } else { UNIMPLEMENTED_THEN_RETURN() << "linalg_norm(): only supports int dim."; } } }; class InvFunctor { public: InvFunctor() { op_ = CHECK_JUST(one::OpBuilder("inv").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x) const { if (x->ndim() < 2) { return Error::RuntimeError() << "linalg.inv: The input tensor must be at least 2 dimensions."; } if (x->dim(x->ndim() - 1) != x->dim(x->ndim() - 2)) { return Error::RuntimeError() << "linalg.inv: A must be batches of square matrices, " << "but they are " << x->dim(x->ndim() - 2) << " by " << x->dim(x->ndim() - 1) << " matrices"; } return OpInterpUtil::Dispatch(*op_, {x}, {}); } private: std::shared_ptr op_; }; class DetFunctor { public: DetFunctor() { det_op_ = CHECK_JUST(one::OpBuilder("det").Input("x").Output("y").Build()); lu_decomposition_op_ = CHECK_JUST( one::OpBuilder("lu_decomposition").Input("x").Output("LU").Output("pivot").Build()); } Maybe GetPivotDet(const std::shared_ptr& pivot) const { std::shared_ptr arange = nullptr; int64_t end = pivot->shape()->At(pivot->ndim() - 1) + 1; if (pivot->is_local()) { arange = JUST(functional::Arange(1, end, 1, pivot->dtype(), JUST(pivot->device()))); } else { auto pivot_nd_sbp = JUST(pivot->nd_sbp()); std::vector> nd_sbp(pivot_nd_sbp->sbp_parallel_size()); { for (int i = 0; i < nd_sbp.size(); ++i) { nd_sbp[i] = pivot_nd_sbp->sbp_parallel(i); } } arange = JUST(functional::GlobalArange(1, end, 1, pivot->dtype(), JUST(pivot->parallel_desc()), nd_sbp)); } return sequence_function(functional::BroadcastNotEqual) .then([](const auto& x) { return functional::ReduceSum(x, {-1}, false, NullOpt); }) .then([](const auto& x) { return functional::ScalarFMod(x, Scalar(2), true); }) .then([](const auto& x) { return functional::ScalarMul(x, Scalar(-2), true); }) .then([](const auto& x) { return functional::ScalarAdd(x, Scalar(1), Scalar(1), true); }) 
.call(arange, pivot); } Maybe operator()(const std::shared_ptr& x) const { const int64_t xdims = x->ndim(); if (xdims < 2) { return Error::RuntimeError() << "linalg.det: The input tensor must be at least 2 dimensions."; } if (x->dim(xdims - 1) != x->dim(xdims - 2)) { return Error::RuntimeError() << "linalg.det: A must be batches of square matrices, " << "but they are " << x->dim(xdims - 2) << " by " << x->dim(xdims - 1) << " matrices"; } DeviceType x_device_type = DeviceType::kInvalidDevice; if (x->is_local()) { x_device_type = JUST(x->device())->enum_type(); } else if (x->is_global()) { x_device_type = JUST(x->parallel_desc())->device_type(); } if (x_device_type == DeviceType::kCPU) { return JUST(OpInterpUtil::Dispatch(*det_op_, {x}, {})); } else if (x_device_type == DeviceType::kCUDA) { auto result = JUST(OpInterpUtil::Dispatch(*lu_decomposition_op_, {x}, {})); auto LU = result->at(0); auto pivot = result->at(1); auto LU_det = JUST( functional::ReduceProd(JUST(functional::Diagonal(LU, 0, -2, -1)), {-1}, false, NullOpt)); return functional::Mul(JUST(GetPivotDet(pivot)), LU_det); } else { UNIMPLEMENTED_THEN_RETURN() << "Det: Only support cpu and cuda device."; } } private: std::shared_ptr det_op_; std::shared_ptr lu_decomposition_op_; }; class ClampGradFunctor { public: ClampGradFunctor() { clip_op_ = CHECK_JUST( one::OpBuilder("clip_by_scalar_grad").Input("dy").Input("x").Output("dx").Build()); clip_min_op_ = CHECK_JUST( one::OpBuilder("clip_by_scalar_min_grad").Input("dy").Input("x").Output("dx").Build()); clip_max_op_ = CHECK_JUST( one::OpBuilder("clip_by_scalar_max_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const Optional& min, const Optional& max) const { CHECK_OR_RETURN(min.has_value() || max.has_value()) << "Requires one of argument `min` and `max` at least in clip_grad."; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("floating_min", "integral_min", "floating_max", "integral_max"); if (IsFloatingDataType(x->dtype()->data_type())) { if (min.has_value()) { const auto& min_val = JUST(min); attrs.SetAttr<0>(min_val->As()); attrs.SetAttr<1>(static_cast(0)); } if (max.has_value()) { const auto& max_val = JUST(max); attrs.SetAttr<2>(max_val->As()); attrs.SetAttr<3>(static_cast(0)); } } else if (IsIntegralDataType(x->dtype()->data_type())) { if (min.has_value()) { const auto& min_val = JUST(min); attrs.SetAttr<0>(static_cast(0)); attrs.SetAttr<1>(min_val->As()); } if (max.has_value()) { const auto& max_val = JUST(max); attrs.SetAttr<2>(static_cast(0)); attrs.SetAttr<3>(max_val->As()); } } else { UNIMPLEMENTED_THEN_RETURN() << "Only support floating or integral data type."; } const OpExpr* op = nullptr; if (!min.has_value()) { op = clip_max_op_.get(); } else if (!max.has_value()) { op = clip_min_op_.get(); } else { op = clip_op_.get(); } return OpInterpUtil::Dispatch(*op, {dy, x}, attrs); } private: std::shared_ptr clip_op_; std::shared_ptr clip_min_op_; std::shared_ptr clip_max_op_; }; class SelectFunctor { public: SelectFunctor() = default; Maybe operator()(const std::shared_ptr& input, const int32_t& dim, const int32_t& index) const { int32_t ndim = input->ndim(); CHECK_OR_RETURN(ndim > 0) << "select() cannot be applied to a 0-dim tensor."; int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim)); auto size = input->dim(pos_dim); CHECK_OR_RETURN((index >= -size) && (index < size)) << "Index out of range (expected to be in range of [" << -size << "," << size - 1 << "], but got " << index << ")"; int32_t pos_index = 
index >= 0 ? index : index + size; std::vector sizes(input->shape()->dim_vec().begin(), input->shape()->dim_vec().end()); const auto& stride = *JUST(input->stride()); std::vector strides(stride.begin(), stride.end()); auto storage_offset = JUST(input->storage_offset()) + pos_index * strides[pos_dim]; sizes.erase(sizes.begin() + pos_dim); strides.erase(strides.begin() + pos_dim); return AsStrided(input, sizes, strides, storage_offset); } }; class SelectTopNFunctor { public: SelectTopNFunctor() { op_ = CHECK_JUST(one::SelectTopNOpExpr::New()); } Maybe operator()(const TensorTuple& inputs, int32_t n) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("top_n"); attrs.SetAllAttrs(n); std::vector require_grad(n); for (int i = 0; i < n; ++i) { require_grad[i] = JUST(VectorAt(inputs, i))->requires_grad(); } const auto& output = JUST(OpInterpUtil::Dispatch(*op_, inputs, attrs)); for (int i = 0; i < output->size(); ++i) { (*output)[i]->set_is_leaf(false); JUST((*output)[i]->set_requires_grad(require_grad[i])); } return output; } private: std::shared_ptr op_; }; class MinimumFunctor { public: MinimumFunctor() { elementwise_minimum_op_ = CHECK_JUST(one::OpBuilder("elementwise_minimum").Input("x").Input("y").Output("z").Build()); broadcast_minimum_op_ = CHECK_JUST(one::OpBuilder("broadcast_minimum").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { auto tensor_x = x; auto tensor_y = y; JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false)); TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); if (*x->shape() == *y->shape()) { return OpInterpUtil::Dispatch(*elementwise_minimum_op_, {input_tuple[0], input_tuple[1]}); } else { return OpInterpUtil::Dispatch(*broadcast_minimum_op_, {input_tuple[0], input_tuple[1]}); } } private: std::shared_ptr elementwise_minimum_op_; std::shared_ptr broadcast_minimum_op_; }; class MaximumFunctor { public: MaximumFunctor() { elementwise_maximum_op_ = CHECK_JUST(one::OpBuilder("elementwise_maximum").Input("x").Input("y").Output("z").Build()); broadcast_maximum_op_ = CHECK_JUST(one::OpBuilder("broadcast_maximum").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { auto tensor_x = x; auto tensor_y = y; JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false)); TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); if (*x->shape() == *y->shape()) { return OpInterpUtil::Dispatch(*elementwise_maximum_op_, {input_tuple[0], input_tuple[1]}); } else { return OpInterpUtil::Dispatch(*broadcast_maximum_op_, {input_tuple[0], input_tuple[1]}); } } private: std::shared_ptr elementwise_maximum_op_; std::shared_ptr broadcast_maximum_op_; }; class ScalarLogicalBaseFunctor { public: explicit ScalarLogicalBaseFunctor(std::string op_name) { op_ = CHECK_JUST(one::OpBuilder(op_name).Input("in").Output("out").Build()); } virtual ~ScalarLogicalBaseFunctor() = default; Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { TensorProcessor tensor_processor; Symbol lowest_dtype; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand"); if (scalar.IsFloatingPoint()) { 
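      // The scalar operand is handed to the kernel as the attr quadruple
      // (float_operand, has_float_operand, int_operand, has_int_operand); exactly one
      // of the two has_* flags is set per call, so the kernel knows which field to read.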
attrs.SetAllAttrs(scalar.As(), true, NullOpt, false); // Only promote type to Float32 when tensor is Int type but scalar is float type. if (DType::priority_order[x->dtype()->data_type()] < DType::priority_order[DType::Float16()->data_type()]) { lowest_dtype = DType::Float(); } else { lowest_dtype = x->dtype(); } } else if (scalar.IsIntegral() || scalar.IsBool()) { attrs.SetAllAttrs(NullOpt, false, scalar.As(), true); // Only promote type to Int64 when tensor is Bool type but scalar is int type. if (DType::priority_order[x->dtype()->data_type()] == DType::priority_order[DType::Bool()->data_type()]) { lowest_dtype = DType::Int64(); } else { lowest_dtype = x->dtype(); } } else { UNIMPLEMENTED_THEN_RETURN() << "The scalar in " << op_->op_type_name() << " should be float or int."; } JUST(tensor_processor.AddInputs({x}, lowest_dtype).Apply()); TensorTuple casted_vec = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, {casted_vec}, attrs); } private: std::shared_ptr op_; }; class ScalarLogicalEqualFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalEqualFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_equal") {} }; // (scalar == x) = (x == scalar) class ScalarLogicalEqual2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalEqual(x, scalar); } }; class ScalarLogicalNotEqualFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalNotEqualFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_not_equal") {} }; // (scalar != x) = (x != scalar) class ScalarLogicalNotEqual2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalNotEqual(x, scalar); } }; class ScalarLogicalGreaterFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalGreaterFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_greater") {} }; // (scalar > x) = (x < scalar) class ScalarLogicalGreater2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalLess(x, scalar); } }; class InplaceScalarLogicalGreaterFunctor { public: InplaceScalarLogicalGreaterFunctor() { op_ = CHECK_JUST( one::OpBuilder("scalar_logical_inplace_greater").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { TensorProcessor tensor_processor; Symbol lowest_dtype; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand"); if (scalar.IsFloatingPoint()) { attrs.SetAllAttrs(scalar.As(), true, NullOpt, false); // Only promote type to Float32 when tensor is Int type but scalar is float type. if (DType::priority_order[x->dtype()->data_type()] < DType::priority_order[DType::Float16()->data_type()]) { lowest_dtype = DType::Float(); } else { lowest_dtype = x->dtype(); } } else if (scalar.IsIntegral() || scalar.IsBool()) { attrs.SetAllAttrs(NullOpt, false, scalar.As(), true); // Only promote type to Int64 when tensor is Bool type but scalar is int type. 
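      // DType::priority_order ranks dtypes for type promotion (roughly bool below
      // integers below floating point), so the comparison below is effectively an
      // "input is bool" test, mirroring the non-inplace ScalarLogicalBaseFunctor above.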
if (DType::priority_order[x->dtype()->data_type()] == DType::priority_order[DType::Bool()->data_type()]) { lowest_dtype = DType::Int64(); } else { lowest_dtype = x->dtype(); } } else { UNIMPLEMENTED_THEN_RETURN() << "The scalar in " << op_->op_type_name() << " should be float or int."; } JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x}, lowest_dtype).Apply()); TensorTuple input_vec = JUST(tensor_processor.GetInputs()); const std::shared_ptr& x_cast = input_vec.at(0); JUST(CheckInplaceValid(x)); JUST(CheckInplaceCastValid(x, x_cast)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_, input_vec, outputs.get(), attrs)); return outputs->at(0); } private: std::shared_ptr op_; }; class ScalarLogicalGreaterEqualFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalGreaterEqualFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_greater_equal") {} }; // (scalar >= x) = (x <= scalar) class ScalarLogicalGreaterEqual2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalLessEqual(x, scalar); } }; class ScalarLogicalLessFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalLessFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_less") {} }; // (scalar < x) = (x > scalar) class ScalarLogicalLess2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalGreater(x, scalar); } }; class ScalarLogicalLessEqualFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalLessEqualFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_less_equal") {} }; // (scalar <= x) = (x >= scalar) class ScalarLogicalLessEqual2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalGreaterEqual(x, scalar); } }; class ScalarLogicalAndFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalAndFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_and") {} }; // (scalar && x) = (x && scalar) class ScalarLogicalAnd2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalAnd(x, scalar); } }; class ScalarLogicalOrFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalOrFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_or") {} }; // (scalar || x) = (x || scalar) class ScalarLogicalOr2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalOr(x, scalar); } }; class ScalarLogicalXorFunctor : public ScalarLogicalBaseFunctor { public: ScalarLogicalXorFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_xor") {} }; // (scalar ^ x) = (x ^ scalar) class ScalarLogicalXor2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarLogicalXor(x, scalar); } }; class ScalarBitwiseBaseFunctor { public: explicit ScalarBitwiseBaseFunctor(std::string op_name) { op_ = CHECK_JUST(one::OpBuilder(op_name).Input("in").Output("out").Build()); } virtual ~ScalarBitwiseBaseFunctor() = default; Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { TensorProcessor tensor_processor; Symbol lowest_dtype; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("operand"); CHECK_OR_RETURN(scalar.IsIntegral() || scalar.IsBool()) << "Bitwise ops only support int and bool dtype"; attrs.SetAllAttrs(scalar.As()); // Only promote type to Int64 when tensor is 
Bool type but scalar is int type. if (DType::priority_order[x->dtype()->data_type()] == DType::priority_order[DType::Bool()->data_type()]) { lowest_dtype = DType::Int64(); } else { lowest_dtype = x->dtype(); } JUST(tensor_processor.AddInputs({x}, lowest_dtype).Apply()); TensorTuple casted_vec = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, {casted_vec}, attrs); } private: std::shared_ptr op_; }; class ScalarLerpFunctor { public: ScalarLerpFunctor() { op_ = CHECK_JUST(one::OpBuilder("scalar_lerp").Input("start").Input("end").Output("out").Build()); } Maybe operator()(const std::shared_ptr& start, const std::shared_ptr& end, const Scalar& weight) const { CHECK_EQ_OR_RETURN(start->shape()->NumAxes(), end->shape()->NumAxes()) << Error::RuntimeError() << "expected dim" << start->shape()->NumAxes() << "for `end` but got dim" << end->shape()->NumAxes(); auto broadcast_shape = *start->shape(); if (*start->shape() != *end->shape()) { broadcast_shape = *JUST(InferUnifiedShapeForBroadcasting({*start->shape(), *end->shape()})); } std::shared_ptr broadcast_start = start; std::shared_ptr broadcast_end = end; if (*start->shape() != broadcast_shape) { broadcast_start = JUST(functional::Expand(start, broadcast_shape)); } if (*end->shape() != broadcast_shape) { broadcast_end = JUST(functional::Expand(end, broadcast_shape)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand"); if (weight.IsFloatingPoint()) { attrs.SetAllAttrs(weight.As(), true, NullOpt, false); } else if (weight.IsIntegral() || weight.IsBool()) { attrs.SetAllAttrs(NullOpt, false, weight.As(), true); } else { UNIMPLEMENTED_THEN_RETURN() << "The scalar in " << op_->op_type_name() << " should be float or int."; } return OpInterpUtil::Dispatch(*op_, {broadcast_start, broadcast_end}, attrs); } private: std::shared_ptr op_; }; class ScalarInplaceLerpFunctor { public: ScalarInplaceLerpFunctor() { op_ = CHECK_JUST(one::OpBuilder("scalar_lerp").Input("start").Input("end").Output("out").Build()); } Maybe operator()(const std::shared_ptr& start, const std::shared_ptr& end, const Scalar& weight) const { CHECK_EQ_OR_RETURN(start->shape()->NumAxes(), end->shape()->NumAxes()) << Error::RuntimeError() << "expected dim" << start->shape()->NumAxes() << "for `end` but got dim" << end->shape()->NumAxes(); auto broadcast_shape = *start->shape(); if (*start->shape() != *end->shape()) { broadcast_shape = *JUST(InferUnifiedShapeForBroadcasting({*start->shape(), *end->shape()})); } std::shared_ptr broadcast_start = JUST(Identity(start)); std::shared_ptr broadcast_end = JUST(Identity(end)); if (*start->shape() != broadcast_shape) { broadcast_start = JUST(view::Expand(start, broadcast_shape)); } if (*end->shape() != broadcast_shape) { broadcast_end = JUST(view::Expand(end, broadcast_shape)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand"); if (weight.IsFloatingPoint()) { attrs.SetAllAttrs(weight.As(), true, NullOpt, false); } else if (weight.IsIntegral() || weight.IsBool()) { attrs.SetAllAttrs(NullOpt, false, weight.As(), true); } else { UNIMPLEMENTED_THEN_RETURN() << "The scalar in " << op_->op_type_name() << " should be float or int."; } TensorProcessor tensor_processor; if (broadcast_end->requires_grad()) { JUST(tensor_processor.PromoteInputsToCommonDtype(true) .AddInputs({JUST(Identity(broadcast_start)), broadcast_end}) .Apply()); } else { JUST(tensor_processor.PromoteInputsToCommonDtype(true) 
.AddInputs({broadcast_start, broadcast_end}) .Apply()); } const TensorTuple& input_vec = JUST(tensor_processor.GetInputs()); const std::shared_ptr& start_cast = input_vec.at(0); const std::shared_ptr& end_cast = input_vec.at(1); JUST(CheckInplaceValid(broadcast_start)); JUST(CheckInplaceCastValid(broadcast_start, start_cast)); JUST(CheckInplaceShapeCanExpandTo(*start_cast->shape(), *end_cast->shape())); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = start; JUST(OpInterpUtil::Dispatch(*op_, input_vec, outputs.get(), attrs)); return outputs->at(0); } private: std::shared_ptr op_; }; class ScalarLerpGradFunctor { public: ScalarLerpGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("scalar_lerp_grad") .Input("start") .Input("end") .Input("out_diff") .Output("start_diff") .Output("end_diff") .Build()); } Maybe operator()(const std::shared_ptr& start, const std::shared_ptr& end, const std::shared_ptr& out_diff, const Scalar& weight) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("float_operand", "has_float_operand", "int_operand", "has_int_operand"); if (weight.IsFloatingPoint()) { attrs.SetAllAttrs(weight.As(), true, NullOpt, false); } else if (weight.IsIntegral()) { attrs.SetAllAttrs(NullOpt, false, weight.As(), true); } else { UNIMPLEMENTED_THEN_RETURN() << "The scalar in ScalarLerpGrad should be float or int."; } return OpInterpUtil::Dispatch(*op_, {start, end, out_diff}, attrs); } private: std::shared_ptr op_; }; class ScalarBitwiseAndFunctor : public ScalarBitwiseBaseFunctor { public: ScalarBitwiseAndFunctor() : ScalarBitwiseBaseFunctor(/*op_name=*/"scalar_bitwise_and") {} }; class ScalarBitwiseAnd2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarBitwiseAnd(x, scalar); } }; class ScalarBitwiseOrFunctor : public ScalarBitwiseBaseFunctor { public: ScalarBitwiseOrFunctor() : ScalarBitwiseBaseFunctor(/*op_name=*/"scalar_bitwise_or") {} }; class ScalarBitwiseOr2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarBitwiseOr(x, scalar); } }; class ScalarBitwiseXorFunctor : public ScalarBitwiseBaseFunctor { public: ScalarBitwiseXorFunctor() : ScalarBitwiseBaseFunctor(/*op_name=*/"scalar_bitwise_xor") {} }; class ScalarBitwiseXor2Functor { public: Maybe operator()(const Scalar& scalar, const std::shared_ptr& x) const { return ScalarBitwiseXor(x, scalar); } }; class StandardDeviationFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& dim, const Optional& unbiased, const Optional& keepdim) const { std::vector axis; if (!dim) { for (int i = 0; i < input->ndim(); i++) { axis.emplace_back(i); } } else { axis = *JUST(CheckAxis(*JUST(dim), input->ndim())); } bool unbias = true; bool keepdims = false; if (unbiased.has_value()) { unbias = JUST(unbiased); } if (keepdim.has_value()) { keepdims = JUST(keepdim); } if (axis.size() == 0) { return functional::Constant(*input->shape(), Scalar(0), *input->dtype(), NullOpt); } int32_t reduce_count = 1; if (axis.size() == 1) { reduce_count *= input->shape()->At(axis[0]); } else { for (int i = 0; i < axis.size(); ++i) { reduce_count *= input->shape()->At(axis[i]); } } bool is_double = input->dtype()->data_type() == DataType::kDouble; if (is_double) { const auto& sum = JUST(functional::ScalarDiv( JUST(functional::ReduceSum(JUST(functional::Square(input)), axis, keepdims, NullOpt)), Scalar((double)reduce_count))); const auto& square = JUST(functional::Square( 
JUST(functional::ScalarDiv(JUST(functional::ReduceSum(input, axis, keepdims, NullOpt)), Scalar((double)reduce_count))))); const auto& sub = JUST(functional::Sub(sum, square, /*alpha=*/1.0, /*inplace=*/false)); if (unbias) { return functional::Sqrt(JUST(functional::ScalarMul( sub, Scalar((double)reduce_count / (double)(reduce_count - 1)), false))); } /* According to the std calculation formula, StandardDeviation = \sqrt {\frac {\sum _ {i=1}^ {N}X_ {i}^ {2}}{N} - \mu ^ {2}} = \sqrt{\frac {1}{N}\sum _ {i=1}^ {n} (x_ {i}-\mu )^ {2} -\frac {1}{N} N \mu ^ {2}} = \sqrt{\frac {\sum _ {i=1}^ {N}X_ {i}^ {2}}{N} - \mu ^ {2}} when we are in the last sqrt, if the value in the radical is <= 0, it may cause the result gradient to appear undefined(nan), which is normal. In this case, the gradient of ours and pytorch are different. Use abs(absolute value) can keep it consistent with pytorch: const auto& abs = JUST(functional::Abs(sub)); return functional::Sqrt(abs); */ // const auto& abs = JUST(functional::Abs(sub)); // return functional::Sqrt(abs); return functional::Sqrt(sub); } else { // If input tensor's dtype is float32, than cast it to double dtype, // because float dtype has accuracy problem in float dtype, see: // https://github.com/Oneflow-Inc/oneflow/issues/6526 const auto& double_input = JUST(functional::Cast(input, DType::Double(), /*pin_memory=*/false)); const auto& sum = JUST( functional::ScalarDiv(JUST(functional::ReduceSum(JUST(functional::Square(double_input)), axis, keepdims, NullOpt)), Scalar((double)reduce_count))); const auto& square = JUST(functional::Square(JUST( functional::ScalarDiv(JUST(functional::ReduceSum(double_input, axis, keepdims, NullOpt)), Scalar((double)reduce_count))))); const auto& sub = JUST(functional::Sub(sum, square, /*alpha=*/1.0, /*inplace=*/false)); if (unbias) { return functional::Cast( JUST(functional::Sqrt(JUST(functional::ScalarMul( sub, Scalar((double)reduce_count / (double)(reduce_count - 1)), false)))), input->dtype(), /*pin_memory=*/false); } return functional::Cast(JUST(functional::Sqrt(sub)), input->dtype(), /*pin_memory=*/false); } } }; class VarianceFunctor { public: VarianceFunctor() { op_ = CHECK_JUST(one::OpBuilder("var").Input("input").Output("output").Build()); } Maybe operator()(const std::shared_ptr& input, const Optional>& dim, const Optional& unbiased, const Optional& keepdim) const { if (!(IsFloatingDataType(input->dtype()->data_type()) || IsHalfDataType(input->dtype()->data_type()))) { return Error::RuntimeError() << "var only support floating point dtypes"; } std::vector axis; const int ndim = input->ndim(); axis.reserve(ndim); if (!dim) { for (int i = 0; i < ndim; i++) { axis.emplace_back(i); } } else { std::vector& dims = *JUST(dim); JUST(maybe_wrap_dim(dims.size(), ndim)); // only check validation std::sort(dims.begin(), dims.end()); axis.assign(dims.begin(), dims.end()); } for (size_t i = 0; i < axis.size(); i++) { if (axis[i] < 0) { axis[i] += ndim; } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("unbiased", "keepdim", "dim", "dtype"); attrs.SetAllAttrs(unbiased, keepdim, axis, input->dtype()->data_type()); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; class RMSLayerNormalizationFunctor { public: Maybe operator()(const std::shared_ptr& hidden_states, const std::shared_ptr& weight, const float& variance_epsilon) const { std::shared_ptr cast_hidden_states = hidden_states; if (hidden_states->dtype() != DType::Float()) { cast_hidden_states = JUST(functional::Cast(hidden_states, DType::Float(), 
/*pin_memory=*/false)); } std::shared_ptr normalized_hidden_states = JUST(functional::Mul( cast_hidden_states, JUST(functional::Rsqrt(JUST(functional::ScalarAdd( JUST(functional::ReduceMean(JUST(Square(hidden_states)), std::vector{-1}, true)), Scalar(variance_epsilon), 1.0, false)))))); if (weight->dtype() == DType::Float16()) { normalized_hidden_states = JUST(functional::Cast(normalized_hidden_states, weight->dtype(), /*pin_memory=*/false)); } return JUST(functional::Mul(normalized_hidden_states, weight)); } }; class DotFunctor { public: DotFunctor() { op_ = CHECK_JUST(one::OpBuilder("dot").Input("x").Input("y").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& other) const { return OpInterpUtil::Dispatch(*op_, {input, other}); } private: std::shared_ptr op_; }; class MovedimVecFunctor { public: MovedimVecFunctor() = default; static Maybe CheckNoRepeat(const std::vector& perm, std::vector& perm_out, int32_t ndim, const std::string& desc) { std::vector is_used(ndim, false); FOR_RANGE(size_t, i, 0, perm.size()) { int32_t item = perm[i]; item = JUST(maybe_wrap_dim(item, ndim)); CHECK_EQ_OR_RETURN(is_used[item], false) << "repeated dim in " << desc; is_used[item] = true; perm_out[i] = item; } return Maybe::Ok(); } Maybe operator()(const std::shared_ptr& input, const std::vector& source, const std::vector& destination) const { int32_t ndim = input->ndim(); int32_t dim = source.size(); CHECK_EQ_OR_RETURN(source.size(), destination.size()) << "movedim: Invalid source or destination dims: source (" << source.size() << " dims ) should contain the same number of dims as destination (" << destination.size() << " dims)"; std::vector source_nopeat(dim); std::vector destination_nopeat(dim); JUST(CheckNoRepeat(source, source_nopeat, ndim, "source")); JUST(CheckNoRepeat(destination, destination_nopeat, ndim, "destination")); std::vector order(ndim); std::vector source_dims(ndim); std::vector destination_dims(ndim); std::iota(source_dims.begin(), source_dims.end(), 0); std::iota(destination_dims.begin(), destination_dims.end(), 0); FOR_RANGE(size_t, i, 0, dim) { order[destination_nopeat[i]] = source_nopeat[i]; source_dims[source_nopeat[i]] = -1; destination_dims[destination_nopeat[i]] = -1; } std::remove(source_dims.begin(), source_dims.end(), -1); std::remove(destination_dims.begin(), destination_dims.end(), -1); int64_t rest_dim = ndim - dim; FOR_RANGE(size_t, i, 0, rest_dim) { order[destination_dims[i]] = source_dims[i]; } return Transpose(input, order); } }; class MovedimIntFunctor { public: MovedimIntFunctor() = default; Maybe operator()(const std::shared_ptr& input, const int32_t& source, const int32_t& destination) const { std::vector src{source}; std::vector dest{destination}; return MovedimVec(input, src, dest); } }; class TensorSplitVecFunctor { public: TensorSplitVecFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::vector& indices_or_sections, const int32_t& dim) const { int32_t ndim = input->ndim(); int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim)); std::vector start(ndim, 0); std::vector stop(ndim); std::vector step(ndim, 1); for (int32_t i = 0; i < ndim; i++) { stop[i] = input->dim(i); } int32_t num_indices = indices_or_sections.size(); TensorTuple output(num_indices + 1); for (int32_t i = 0; i < num_indices; i++) { int32_t end_idx = indices_or_sections[i]; stop[pos_dim] = end_idx; output[i] = JUST(Slice(input, start, stop, step, /*enable_view_slice=*/false)); start[pos_dim] = end_idx; } stop[pos_dim] = 
input->shape()->At(pos_dim); output[num_indices] = JUST(Slice(input, start, stop, step, /*enable_view_slice=*/false)); return output; } }; class TensorSplitIntFunctor { public: TensorSplitIntFunctor() = default; Maybe operator()(const std::shared_ptr& input, const int32_t& indices_or_sections, const int32_t& dim) const { int32_t ndim = input->ndim(); int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim)); CHECK_OR_RETURN(indices_or_sections > 0) << "number of sections must be larger than 0, got ," << indices_or_sections << ");"; const auto dim_size = input->dim(pos_dim); int64_t min_split_size = dim_size / indices_or_sections; int64_t num_splits_one_extra = dim_size % indices_or_sections; std::vector start(ndim, 0); std::vector stop(ndim); std::vector step(ndim, 1); for (int32_t i = 0; i < ndim; i++) { stop[i] = input->dim(i); } stop[pos_dim] = 0; TensorTuple output(indices_or_sections); for (int32_t i = 0; i < indices_or_sections; i++) { int64_t split_size = (i < num_splits_one_extra) ? (min_split_size + 1) : min_split_size; stop[pos_dim] += split_size; output[i] = JUST(Slice(input, start, stop, step, /*enable_view_slice=*/false)); start[pos_dim] += split_size; } return output; } }; class HsplitIntFunctor { public: HsplitIntFunctor() = default; Maybe operator()(const std::shared_ptr& input, const int32_t& indices_or_sections) const { int32_t ndim = input->ndim(); CHECK_OR_RETURN(ndim >= 1) << "flow.hsplit requires a tensor with at least 1 dimension, but got a tensor with " << ndim << " dimensions!"; CHECK_OR_RETURN(indices_or_sections > 0) << "indices_or_sections must greater than 0"; int32_t dim = (ndim == 1) ? 0 : 1; CHECK_OR_RETURN(input->dim(dim) % indices_or_sections == 0) << "flow.hsplit attempted to split along dimension " << dim << ", but the size of the dimension " << input->shape()->At(dim) << " is not divisible by the split_size " << indices_or_sections << "!"; return TensorSplitInt(input, indices_or_sections, dim); } }; class HsplitVecFunctor { public: HsplitVecFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::vector& indices_or_sections) const { int32_t ndim = input->ndim(); CHECK_OR_RETURN(ndim >= 1) << "flow.hsplit requires a tensor with at least 1 dimension, but got a tensor with " << ndim << " dimensions!"; int32_t dim = (ndim == 1) ? 
0 : 1; return TensorSplitVec(input, indices_or_sections, dim); } }; class VsplitIntFunctor { public: VsplitIntFunctor() = default; Maybe operator()(const std::shared_ptr& input, const int32_t& indices_or_sections) const { int32_t ndim = input->ndim(); CHECK_OR_RETURN(ndim >= 2) << "flow.vsplit requires a tensor with at least 2 dimension, but got a tensor with " << ndim << " dimensions!"; CHECK_OR_RETURN(indices_or_sections > 0) << "indices_or_sections must greater than 0"; CHECK_OR_RETURN(input->dim(0) % indices_or_sections == 0) << "flow.vsplit attempted to split along dimension " << 0 << ", but the size of the dimension " << input->dim(0) << " is not divisible by the split_size " << indices_or_sections << "!"; return TensorSplitInt(input, indices_or_sections, 0); } }; class VsplitVecFunctor { public: VsplitVecFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::vector& indices_or_sections) const { int32_t ndim = input->ndim(); CHECK_OR_RETURN(ndim >= 2) << "flow.vsplit requires a tensor with at least 1 dimension, but got a tensor with " << ndim << " dimensions!"; return TensorSplitVec(input, indices_or_sections, 0); } }; class ErfinvFunctor { public: ErfinvFunctor() { op_ = CHECK_JUST(one::OpBuilder("erfinv").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x) const { return OpInterpUtil::Dispatch(*op_, {x}, {}); } private: std::shared_ptr op_; }; class ErfinvInplaceFunctor { public: ErfinvInplaceFunctor() { op_ = CHECK_JUST(one::OpBuilder("erfinv").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x) const { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), {})); return outputs->at(0); } private: std::shared_ptr op_; }; class GeluWithApproximateFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::string& approximate) const { if (approximate != "none" && approximate != "tanh") { return Error::RuntimeError() << "the approximate argument should be 'none' or 'tanh'"; } if (approximate == "tanh") { return FastGelu(x); } return Gelu(x); } }; class CumBaseFunctor { public: explicit CumBaseFunctor(std::string op_name) { op_ = CHECK_JUST(one::OpBuilder(op_name).Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& input, int64_t dim, const Optional>& dtype) const { auto ndim = input->ndim(); dim = JUST(maybe_wrap_dim(dim, ndim)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim"); attrs.SetAllAttrs(dim); TensorProcessor tensor_processor; if (dtype) { JUST(tensor_processor.AddInputs({input}, JUST(dtype)).Apply()); } else { JUST(tensor_processor.AddInputs({input}, DType::Int64()).Apply()); } TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple, attrs); } private: std::shared_ptr op_; }; class CumsumFunctor : public CumBaseFunctor { public: CumsumFunctor() : CumBaseFunctor("cumsum") {} }; class CumProdFunctor : public CumBaseFunctor { public: CumProdFunctor() : CumBaseFunctor("cumprod") {} }; class CumGradBaseFunctor { protected: std::shared_ptr op_; }; class CumProdGradFunctor : public CumGradBaseFunctor { public: CumProdGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("cumprod_grad") .Input("dy") .Input("output") .Input("input") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& y, const std::shared_ptr& x, int64_t dim) const { // No need to check dim validation here, while 
CumProdFunctor handled already auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim"); attrs.SetAllAttrs(dim); return OpInterpUtil::Dispatch(*op_, {dy, y, x}, attrs); } }; // NOTE(Liang Depeng): The implementation of sumproduct_pair are mostly taken from pytorch. // For more details pls refer to: // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Linear.cpp#L65 // sumproduct_pair computes `(left*right).sum(sumdims)` by means of permutation and // batch matrix multiplication // its main purpose is to provide a pairwise reduction for einsum static Maybe sumproduct_pair(const std::shared_ptr& left_, const std::shared_ptr& right_, const std::vector& sum_dims_, bool keepdim) { // assumes that tensors have been pre-unsqueezed (so that all dimensions match - after // broadcasting) but makes no other assumptions on the order of dimensions CHECK_OR_RETURN(left_->ndim() == right_->ndim()) << "number of dimensions must match"; if (sum_dims_.size() == 0) return functional::Mul(left_, right_); int64_t dim = left_->ndim(); constexpr size_t dim_bitset_size = 64; CHECK_OR_RETURN(dim <= (int64_t)dim_bitset_size) << "only tensors with up to " << dim_bitset_size << " dims are supported"; std::bitset sum_dims; for (int i = 0; i < sum_dims_.size(); ++i) { size_t d = sum_dims_[i]; CHECK_OR_RETURN(!sum_dims[d]) << "dim " << d << " appears multiple times in the list of dims"; sum_dims[d] = true; } // dimensions that will be part of the output (i.e. not summed over) in three vectors // dims in lro appear in left, right and output, similarly lo: left and output, ro: right and // output also the sizes are kept track of for reshaping std::vector lro, lo, ro; int32_t lro_size = 1, lo_size = 1, ro_size = 1, sum_size = 1; std::shared_ptr left = left_; std::shared_ptr right = right_; for (int i = 0; i < dim; ++i) { auto sl = left->shape()->At(i) > 1; auto sr = right->shape()->At(i) > 1; if (sum_dims[i]) { // first dimensions that will be summed over after multiplication if (sl && sr) { // dimensions nontrivially in both left and right must be of the same size CHECK_OR_RETURN(left->shape()->At(i) == right->shape()->At(i)) << "non-broadcast dimensions must match"; sum_size *= left->shape()->At(i); } else if (sl) { // if it is only in one of left and right, we can sum right away left = JUST(functional::ReduceSum(left, {i}, true, NullOpt)); } else if (sr) { right = JUST(functional::ReduceSum(right, {i}, true, NullOpt)); } } else if (sl && sr) { // now deal with dimensions dimensions that will be in the output // dimensions nontrivially in both left and right must be of the same size CHECK_OR_RETURN(left->shape()->At(i) == right->shape()->At(i)) << "non-broadcast dimensions must match"; lro.push_back(i); lro_size *= left->shape()->At(i); } else if (sl) { // keep track of dimensions appearing only once lo.push_back(i); lo_size *= left->shape()->At(i); } else { ro.push_back(i); ro_size *= right->shape()->At(i); } } // we now work with the following permutations / shapes. 
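  // Concrete instance of the classification above (shapes borrowed from the einsum
  // example "ik,jkl,il->ij" further below): with left of shape [2, 1, 3, 1], right of
  // shape [1, 4, 3, 5] and sum_dims_ = {2}, we get lro = {}, lo = {0}, ro = {1, 3},
  // hence lro_size = 1, lo_size = 2, ro_size = 20 and sum_size = 3. The pipeline
  // described next reshapes left to [1, 2, 3] and right to [1, 3, 20], batch-matmuls
  // them into [1, 2, 20], and finally restores the requested output layout.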
// the pipeline is permute inputs -> reshape inputs -> batch matrix mul -> reshape(view) output // -> permute output output: "lro, lo, 1-for-summed-dims, ro" with orgiginal shape dimensions // left: "lro, lo, summed" permuted with lpermutation and the three flattened right: "lro, // summed, ro" permuted with rpermutation and the three flattened then the permuted output is a // view of bmm(left, right) finally, opermutation reverts the permutation to the original order // of dimensions std::vector out_size; for (auto& d : lro) out_size.push_back(left->shape()->At(d)); for (auto& d : lo) out_size.push_back(left->shape()->At(d)); for (auto& d : sum_dims_) { out_size.push_back(1); (void)(d); }; // avoid warining about not using d for (auto& d : ro) out_size.push_back(right->shape()->At(d)); std::vector lpermutation(lro); lpermutation.insert(lpermutation.end(), lo.begin(), lo.end()); lpermutation.insert(lpermutation.end(), sum_dims_.begin(), sum_dims_.end()); lpermutation.insert(lpermutation.end(), ro.begin(), ro.end()); std::vector rpermutation(lro); rpermutation.insert(rpermutation.end(), sum_dims_.begin(), sum_dims_.end()); rpermutation.insert(rpermutation.end(), ro.begin(), ro.end()); rpermutation.insert(rpermutation.end(), lo.begin(), lo.end()); std::vector opermutation(lro.size() + lo.size() + sum_dims_.size() + ro.size(), -1); { int32_t i = 0; for (auto it = lro.cbegin(); it != lro.cend(); i++, it++) { opermutation[*it] = i; } for (auto it = lo.cbegin(); it != lo.cend(); i++, it++) { opermutation[*it] = i; } for (auto it = sum_dims_.cbegin(); it != sum_dims_.cend(); i++, it++) { opermutation[*it] = i; } for (auto it = ro.cbegin(); it != ro.cend(); i++, it++) { opermutation[*it] = i; } } // now we can execute the operations above left = JUST(functional::Permute(left, lpermutation)); DimVector lsv(3); lsv[0] = lro_size; lsv[1] = lo_size; lsv[2] = sum_size; const Shape ls(lsv); left = JUST(functional::Reshape(left, ls)); right = JUST(functional::Permute(right, rpermutation)); DimVector rsv(3); rsv[0] = lro_size; rsv[1] = sum_size; rsv[2] = ro_size; const Shape rs(rsv); right = JUST(functional::Reshape(right, rs)); std::shared_ptr result = JUST(functional::BatchMatMul(left, right, false, false, 1.0)); DimVector osv(out_size.size()); for (int i = 0; i < out_size.size(); ++i) { osv[i] = out_size[i]; } const Shape os(osv); // TODO(Liang Depeng): change reshape to veiw result = JUST(functional::Reshape(result, os)); result = JUST(functional::Permute(result, opermutation)); // finally squeeze summed dimensions if desired if (!keepdim) { auto sizes = result->shape()->dim_vec(); for (int i = dim - 1; i >= 0; i--) { if (sum_dims[i]) { sizes.erase(sizes.begin() + i); } } // TODO(Liang Depeng): change reshape to veiw const Shape s(sizes); result = JUST(functional::Reshape(result, s)); } return result; } namespace { bool einsum_check_label(unsigned char label) { return std::isalpha(label); } uint8_t einsum_label_to_index(unsigned char label) { constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1; return std::isupper(label) ? label - 'A' : NUM_OF_LETTERS + (label - 'a'); } unsigned char einsum_index_to_label(uint8_t index) { constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1; return index < NUM_OF_LETTERS ? index + 'A' : index - NUM_OF_LETTERS + 'a'; } } // namespace // NOTE(Liang Depeng): The implementation of EinSumFunctor are mostly taken from pytorch. 
// For more details pls refer to: // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Linear.cpp#L190 // There are roughly three parts to compute einsum: // 1. Parse equation to extract the labels for each input operand and output // 2. Unsqueeze missing dimensions from input operands and permute to align them // 3. Compute result by multiplying input operands and summing contraction // dimensions We do the last part by reducing to batch matmul. class EinSumFunctor { public: EinSumFunctor() {} Maybe operator()(const std::string& equation, const one::TensorTuple& operands) const { CHECK_OR_RETURN(operands.size() > 0) << "einsum(): must provide at least one input tensor."; // NOTE(Liang Depeng): In order to better understand what einsum is doing, // the following comments will give a detailed explaination of // how the operands of equation "ik,jkl,il->ij" (bilinear) // are transformed during the computation. // Assume that the size of each operands "ik", "jkl" and "il" are // [2, 3], [4, 3, 5], [2, 5] respectively. // Code used to identify ELLIPSIS ("...") constexpr uint8_t ELLIPSIS = 52; // Find arrow (->) to split equation into lhs (input equations) and rhs (output equation) const auto arrow_pos = equation.find("->"); const auto lhs = equation.substr(0, arrow_pos); const auto num_ops = operands.size(); // Convert each input equations into indexes in range [0, 52] and store // them in op_labels for each operand along with ELLIPSIS if present. std::vector> op_labels(num_ops); // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". // After running the following for loop, `op_labels` contains 3 vectors. // The contents of each vectors are: // op_labels[0]: [34('i'-'a'+26), 36('k'-'a'+26)] // op_labels[1]: [35('j'-'a'+26), 36('k'-'a'+26), 37('l'-'a'+26)] // op_labels[2]: [34('i'-'a'+26), 37('l'-'a'+26)] bool found_ell = false; std::size_t curr_op = 0; for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) { const unsigned char label = lhs[i]; switch (label) { case ' ': // Ignore spaces break; case '.': // process ellipsis CHECK_OR_RETURN( // Only one ellipsis per operand can be given !found_ell) << "einsum(): found \'.\' for operand " << curr_op << " for which an ellipsis was already found"; CHECK_OR_RETURN( // Ensure it's a valid ellipsis i + 2 < lhs.length() && lhs[++i] == '.' && lhs[++i] == '.') << "einsum(): found \'.\' for operand " << curr_op << " that is not part of any ellipsis"; op_labels[curr_op].push_back(ELLIPSIS); found_ell = true; break; case ',': // Move onto next operand ++curr_op; CHECK_OR_RETURN(curr_op < num_ops) << "einsum(): fewer operands were provided than specified in the equation"; found_ell = false; break; default: // Parse label CHECK_OR_RETURN(einsum_check_label(label)) << "einsum(): invalid subscript given at index " << i << " in the equation string, subscripts must be in [a-zA-Z]"; op_labels[curr_op].push_back(einsum_label_to_index(label)); } } CHECK_OR_RETURN(curr_op == num_ops - 1) << "einsum(): more operands were provided than specified in the equation"; // Labels must be within [a-zA-Z]. constexpr uint8_t TOTAL_LABELS = 52; std::vector label_count(TOTAL_LABELS, 0); // The maximum number of dimensions covered by any ellipsis, needed when // unsqueezing missing dimensions from operands to permute and broadcast int32_t ell_num_dim = 0; // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". 
// After running the following for loop, // the none zero indexes of `label_count` are: // op_labels[34] = 2 // op_labels[35] = 1 // op_labels[36] = 2 // op_labels[37] = 2 // `ell_num_dim` equals to 0 because no ellipsis in equation // Compute label frequency and number of dimensions covered by ellipsis // We do this after parsing labels to make it more readable and simpler // to compute the number of dimensions covered by ellipsis. for (auto i = 0; i < num_ops; i++) { const auto operand = operands[i]; const auto labels = op_labels[i]; const int ndims = operand->ndim(); int32_t nlabels = static_cast(labels.size()); bool has_ellipsis = false; for (const auto& label : labels) { if (label == ELLIPSIS) { --nlabels; has_ellipsis = true; ell_num_dim = std::max(ell_num_dim, ndims - nlabels); } else { ++label_count[label]; } } if (has_ellipsis) { CHECK_OR_RETURN(nlabels <= ndims) << "einsum() the number of subscripts in the equation (" << nlabels << ") is more than the number of dimensions (" << ndims << ") for operand " << i; } else { CHECK_OR_RETURN(nlabels == ndims) << "einsum(): the number of subscripts in the equation (" << nlabels << ") does not match the number of dimensions (" << ndims << ") for operand " << i << " and no ellipsis was given"; } } // We want to align the dimensions of every input tensor to have // shape out_dims + sum_dims. For this, we create a mapping of label // to index into the permuted shape. std::vector label_perm_index(TOTAL_LABELS, -1); // Current index in the permuted shape int32_t perm_index = 0; // Start index of ellipsis dimensions in the permuted shape int32_t ell_index = 0; found_ell = false; // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". // After running the following if-else code block, // the none -1 indexes of `label_perm_index` are: // label_perm_index[34] = 0 // label_perm_index[35] = 1 // `perm_index` equals to 2 // `ell_index` equals to 0 because no ellipsis in equation // `found_ell` equals to false because no ellipsis in equation if (arrow_pos == std::string::npos) { // Implicit output is ellipsis (...) + labels seen only once perm_index = ell_num_dim; found_ell = true; for (auto label = 0; label < TOTAL_LABELS; label++) { if (label_count[label] == 1) { label_perm_index[label] = perm_index++; } } } else { // Parse explicit output const auto rhs = equation.substr(arrow_pos + 2); for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) { const unsigned char label = rhs[i]; switch (label) { case ' ': // Ignore spaces break; case '.': // process ellipsis CHECK_OR_RETURN( // There can only be one ellipsis in the output !found_ell) << "einsum(): found \'.\' for output but an ellipsis (...) was already found"; CHECK_OR_RETURN( // Ensure ellipsis is correct i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.') << "einsum(): found \'.\' for output that is not part of any ellipsis (...)"; ell_index = perm_index; perm_index += ell_num_dim; found_ell = true; break; default: CHECK_OR_RETURN(einsum_check_label(label)) << "einsum(): invalid subscript given at index " << lhs.size() + 2 + i << " in the equation string, subscripts must be in [a-zA-Z]"; const auto index = einsum_label_to_index(label); CHECK_OR_RETURN( // Ensure label appeared at least once for some input operand // and at most once for the output label_count[index] > 0 && label_perm_index[index] == -1) << "einsum(): output subscript " << label << (label_perm_index[index] > -1 ? 
" appears more than once in the output" : " does not appear in the equation for any input operand"); label_perm_index[index] = perm_index++; } } } // Save output size before adding contraction dims (dims to sum out) const int32_t out_size = perm_index; // If ellipsis is not part of the output, add to contraction dimensions if (!found_ell) { ell_index = perm_index; perm_index += ell_num_dim; } // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". // After running the following foor loop, // the none -1 indexes of `label_perm_index` are: // label_perm_index[34] = 0 ('i') // label_perm_index[35] = 1 ('j') // label_perm_index[36] = 2 ('k') // label_perm_index[37] = 3 ('l') // `out_size` equals to 2 // `perm_index` equals to 4 // Add contraction labels (labels not present in output) for (auto label = 0; label < TOTAL_LABELS; label++) { if (label_count[label] > 0 && label_perm_index[label] == -1) { label_perm_index[label] = perm_index++; } } // Here we unsqueeze missing dimensions to make all operands have the same // number of dimensions. We take diagonals for repeated labels within the // same operand. Finally we permute the operands to align dimensions as // per the perm_out_index we computed above. TensorTuple permuted_operands; for (auto i = 0; i < num_ops; i++) { std::vector perm_shape(perm_index, -1); std::vector label_dim(TOTAL_LABELS, -1); std::shared_ptr operand = operands[i]; const auto labels = op_labels[i]; const auto original_sizes = operand->shape()->dim_vec(); int32_t j = 0; for (const auto& label : labels) { if (label == ELLIPSIS) { // Add missing dimensions covered by the ellipsis const auto num_missing_dim = ell_num_dim - (original_sizes.size() - labels.size() + 1); for (auto k = 0; k < num_missing_dim; k++) { operand = JUST(functional::Unsqueeze(operand, j)); } for (auto k = 0; k < ell_num_dim; k++) { perm_shape[ell_index + k] = j++; } } else if (label_dim[label] != -1) { // Repeated label, take diagonal const auto dim = label_dim[label]; CHECK_OR_RETURN(operand->dim(j) == operand->dim(dim)) << "einsum() subscript " << einsum_index_to_label(label) << " is repeated for operand " << i << " but the sizes don't match, " << operand->dim(j) << " != " << operand->dim(dim); operand = JUST(functional::Diagonal(operand, 0, dim, j)); operand = JUST(functional::MovedimInt(operand, -1, dim)); } else { // Lookup output index for label label_dim[label] = j; perm_shape[label_perm_index[label]] = j++; } } // Add dimensions for missing labels for (int32_t& index : perm_shape) { if (index == -1) { operand = JUST(functional::Unsqueeze(operand, -1)); index = j++; } } permuted_operands.emplace_back(JUST(functional::Permute(operand, perm_shape))); // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". // What is going on within this foor loop? 
// For operand "ik" size = [2, 3]: // `perm_shape` equals to [0, 2, 1, 3] // first unsqueeze "ik" to 4 dim, from [2, 3] to [2, 3, 1, 1] // then permute with `perm_shape`, from [2, 3, 1, 1] to [2, 1, 3, 1] // // For operand "jkl" size = [4, 3, 5]: // `perm_shape` equals to [3, 0, 1, 2] // first unsqueeze "jkl" to 4 dim, from [4, 3, 5] to [4, 3, 5, 1] // then permute with `perm_shape`, from [4, 3, 5, 1] to [1, 4, 3, 5] // // For operand "il" size = [2, 5]: // `perm_shape` equals to [0, 2, 3, 1] // first unsqueeze "ik" to 4 dim, from [2, 5] to [2, 5, 1, 1] // then permute with `perm_shape`, from [2, 5, 1, 1] to [2, 1, 1, 5] } // Check if operands broadcast and keep track of last operand with // dimension size != 1 for optimizing reductions std::vector dim_last_op(perm_index, 0); bool has_zero_size_dim = false; // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". // After running the following foor loop, // The contents of `dim_last_op` are: // dim_last_op[0] = 2 // dim_last_op[1] = 1 // dim_last_op[2] = 1 // dim_last_op[3] = 2 // `has_zero_size_dim` equals to false for (auto dim = 0; dim < perm_index; dim++) { auto broadcast_size = permuted_operands[0]->dim(dim); for (auto i = 1; i < num_ops; i++) { const auto dim_size = permuted_operands[i]->dim(dim); if (broadcast_size != dim_size && broadcast_size != 1 && dim_size != 1) { std::ostringstream msg; msg << "einsum(): operands do not broadcast with remapped shapes [original->remapped]:"; for (auto j = 0; j < num_ops; j++) { msg << " " << operands[j]->shape()->DebugStr() << "->" << permuted_operands[j]->shape()->DebugStr(); } CHECK_OR_RETURN(false) << msg.str(); } if (dim_size != 1) { broadcast_size = dim_size; dim_last_op[dim] = i; } } has_zero_size_dim |= broadcast_size == 0; } // Compute result std::shared_ptr result = permuted_operands[0]; // Fast path for when an operand has zero sized dim if (has_zero_size_dim) { DimVector out_shape(out_size); for (auto i = 0; i < out_size; i++) { out_shape[i] = permuted_operands[dim_last_op[i]]->dim(i); } const Shape shape(out_shape); return functional::Constant(shape, Scalar(0), *permuted_operands[0]->dtype(), NullOpt); } // Sum out or squeeze dimensions that are size 1 for all later operands int dim = out_size; for (int i = dim; i < perm_index; ++i, ++dim) { if (dim_last_op[i] == 0) { if (result->dim(dim) == 1) { std::vector dims = {dim--}; result = JUST(functional::Squeeze(result, dims)); } else { result = JUST(functional::ReduceSum(result, {dim--}, false, NullOpt)); } } } for (auto i = 1; i < num_ops; i++) { auto operand = permuted_operands[i]; std::vector sum_dims; // Sum out or squeeze dimensions that are size 1 for all later operands dim = out_size; for (int j = dim; j < perm_index; ++j, ++dim) { if (dim_last_op[j] < i) { std::vector dims = {dim--}; operand = JUST(functional::Squeeze(operand, dims)); } else if (dim_last_op[j] == i) { if (result->dim(dim) == 1) { operand = JUST(functional::ReduceSum(operand, {dim}, false, NullOpt)); std::vector dims = {dim--}; result = JUST(functional::Squeeze(result, dims)); } else { sum_dims.push_back(dim); } } } // Multiply tensors and sum out dimensions in sum_dims if (sum_dims.empty()) { result = JUST(functional::Mul(result, operand)); } else if (sum_dims.size() == result->ndim()) { auto flatten_result = JUST(functional::Flatten(result, 0, -1)); auto flatten_operand = JUST(functional::Flatten(operand, 0, -1)); result = JUST(functional::Dot(flatten_result, flatten_operand)); } else { result = JUST(sumproduct_pair(result, operand, sum_dims, 
false)); } // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". // What is going on within this foor loop? // For iter i = 1: // result = permuted_operands[0], size = [2, 1, 3, 1] // operand = permuted_operands[1], size = [1, 4, 3, 5] // sum_dims = [2, ] // what happened in `sumproduct_pair` ? // result [2, 1, 3, 1] will be permuted to [2, 3, 1, 1] then // reshaped to [1, 2, 3] // operand [1, 4, 3, 5] will be permuted to [3, 4, 5, 1] then // reshape to [1, 3, 4 * 5] // perform batch_matmul(result, operand) => [1, 2, 4 * 5] // then reshape to [2, 1, 4, 5] then permute to // [2, 4, 1, 5], at last reshape to [2, 4, 5] // // For iter i = 2: // result, size = [2, 4, 5] // operand = permuted_operands[2], size = [2, 1, 1, 5] // squeeze operand from [2, 1, 1, 5] to [2, 1, 5] // sum_dims = [2,] // what happened in `sumproduct_pair` ? // result [2, 4, 5] will be permuted to [2, 4, 5] then // reshaped to [2, 4, 5] // operand [2, 1, 5] will be permuted to [2, 5, 1] then // reshape to [2, 5, 1] // perform batch_matmul(result, operand)=>[2, 4, 1] // then reshape to [2, 4, 1] then permute to [2, 4, 1] // at last reshape to [2, 4] } return result; } }; class TruncFunctor { public: TruncFunctor() { op_ = CHECK_JUST(one::OpBuilder("trunc").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x) const { return OpInterpUtil::Dispatch(*op_, {x}); } private: std::shared_ptr op_; }; class AddCDivFunctor { public: AddCDivFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& tensor1, const std::shared_ptr& tensor2, const Scalar& value) const { return JUST(Add(input, JUST(ScalarMul(JUST(Div(tensor1, tensor2)), value, false)), 1, false)); } }; class InplaceAddCDivFunctor { public: InplaceAddCDivFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& tensor1, const std::shared_ptr& tensor2, const Scalar& value) const { JUST(CheckInplaceValid(input)); std::shared_ptr outputs = std::make_shared(1); JUST(VectorAt(*outputs, 0)) = input; JUST(Add(input, JUST(ScalarMul(JUST(Div(tensor1, tensor2)), value, false)), 1, true)); return JUST(VectorAt(*outputs, 0)); } }; namespace { constexpr int64_t cufft_max_ndim = 3; // must keep Equal to `oneflow/user/kernels/cufft_plan_cache.h:max_rank` enum class fft_norm_mode { none = 0, // No normalization by_root_n, // Divide by sqrt(signal_size) by_n, // Divide by signal_size }; bool use_optimized_cufft_path(const std::vector& fft_dims) { // For performance reason, when dim starts with (0, 1), do not use the optimized path. if (fft_dims.size() > cufft_max_ndim || (fft_dims.size() >= 2 && fft_dims[0] == 0 && fft_dims[1] == 1)) { return false; } else { return true; } } // Convert NumPy compatible normalization mode string to enum values // In Numpy, "forward" translates to `by_n` for a forward transform and `none` for backward. static fft_norm_mode fft_norm_from_string(const Optional& norm_op, bool forward) { std::string norm_str = norm_op.value_or("backward"); if (norm_str == "backward") { return forward ? fft_norm_mode::none : fft_norm_mode::by_n; } else if (norm_str == "forward") { return forward ? 
fft_norm_mode::by_n : fft_norm_mode::none; } else if (norm_str == "ortho") { return fft_norm_mode::by_root_n; } return fft_norm_mode::none; } template static T fft_compute_fct(int64_t size, fft_norm_mode normalization) { constexpr auto one = static_cast(1); switch (normalization) { case fft_norm_mode::none: return one; case fft_norm_mode::by_n: return one / static_cast(size); case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast(size)); } return static_cast(0); } template static T fft_compute_fct(const Shape& in_shape, const std::vector& dims, fft_norm_mode normalization) { if (normalization == fft_norm_mode::none) { return static_cast(1); } int64_t n = 1; for (int64_t idx : dims) { n *= in_shape.At(idx); } return fft_compute_fct(n, normalization); } } // namespace class FftBaseFunctor { public: explicit FftBaseFunctor() {} explicit FftBaseFunctor(std::string op_name) { op_ = CHECK_JUST(one::OpBuilder(op_name).Input("input").Output("out").Build()); } virtual ~FftBaseFunctor() = default; Maybe resize_fft_input(const std::shared_ptr& x, const std::vector& dims, const std::vector& sizes) const { CHECK_EQ_OR_THROW(dims.size(), sizes.size()) << "dims.size() != sizes.size()."; bool must_copy = false; auto x_sizes = x->shape()->dim_vec(); std::vector pad_amount(x_sizes.size() * 2); std::vector slice_st(x_sizes.size()); std::vector slice_end(x_sizes.size()); std::vector slice_step(x_sizes.size(), 1); FOR_RANGE(int64_t, i, 0, x_sizes.size()) { slice_st[i] = 0; slice_end[i] = x_sizes[i]; } FOR_RANGE(int64_t, i, 0, sizes.size()) { if (sizes[i] == -1) { continue; } if (x_sizes[dims[i]] < sizes[i]) { must_copy = true; auto pad_idx = pad_amount.size() - 2 * dims[i] - 1; pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]; } if (x_sizes[dims[i]] > sizes[i]) { // slice in dims[i] slice_end[dims[i]] = sizes[i]; } } auto sliced_tenosr = JUST(functional::Slice(x, slice_st, slice_end, slice_step, false)); return must_copy ? 
functional::ConstantPad(sliced_tenosr, pad_amount, 0) : sliced_tenosr; } Maybe> promote_type_fft(Symbol type, bool require_complex = false) const { if (type->is_complex()) { return type; } if (!type->is_floating_point()) { type = GetDefaultDType(); } CHECK_OR_RETURN(type->data_type() == kFloat || type->data_type() == kDouble) << "Unsupported dtype " << type->name() << ", " << "support kFloat and kDouble"; if (!require_complex) { return type; } switch (type->data_type()) { // TO-DO: add kFloat16 case (kFloat): return CHECK_JUST(DType::Get(DataType::kComplex64)); case (kDouble): return CHECK_JUST(DType::Get(DataType::kComplex128)); default: CHECK_OR_RETURN(false) << "RuntimeError: dtype can't be handled"; } CHECK_OR_RETURN(false) << "RuntimeError: dtype can't be handled"; } Maybe promote_tensor_fft(const std::shared_ptr& x, bool require_complex = false) const { auto cur_type = x->dtype(); auto new_type = JUST(promote_type_fft(cur_type, require_complex)); if (cur_type->data_type() == new_type->data_type()) { return x; } else { TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({x}, {new_type}).Apply()); return JUST(oneflow::VectorAt(JUST(tensor_processor.GetInputs()), 0)); } } Maybe maybe_wrap_dims(std::vector& dims, int64_t dim_post_expr, bool wrap_scalar = true) const { if (dim_post_expr <= 0) { if (!wrap_scalar) { CHECK_OR_RETURN(false) << "RuntimeError: dimension specified as " << dims[0] << " but tensor has no dimensions"; } dim_post_expr = 1; // this will make range [-1, 0] } int64_t min = -dim_post_expr; int64_t max = dim_post_expr - 1; for (auto& dim : dims) { if (dim < min || dim > max) { CHECK_OR_RETURN(false) << "RuntimeError: Dimension out of range (expected to be in range of [" << min << ", " << max << "], but got " << dim << ")"; } if (dim < 0) dim += dim_post_expr; } return Maybe::Ok(); } Maybe calculate_fftn_shape_and_dims(const std::shared_ptr& x, const Optional>& n, const Optional>& dims, std::vector& fft_shape, std::vector& fft_dims) const { if (dims.has_value()) { fft_dims = *JUST(dims); JUST(maybe_wrap_dims(fft_dims, x->ndim())); std::vector copy = fft_dims; std::sort(copy.begin(), copy.end()); auto duplicate = std::adjacent_find(copy.begin(), copy.end()); CHECK_OR_RETURN(duplicate == copy.end()) << "RuntimeError: FFT dims must be unique"; } else { fft_dims.resize(x->ndim()); for (int i = 0; i < x->ndim(); i++) { fft_dims[i] = i; } } if (!n.has_value()) { fft_shape.resize(fft_dims.size()); for (int i = 0; i < fft_dims.size(); i++) { fft_shape[i] = x->dim(fft_dims[i]); } } else { fft_shape = *JUST(n); if (dims.has_value()) { // got n, also got dim for (int i = 0; i < fft_dims.size(); i++) { if (fft_shape[i] == -1) { fft_shape[i] = x->dim(fft_dims[i]); } } } else { // got n, but not got dim fft_dims.resize(fft_shape.size()); FOR_RANGE(size_t, i, 0, fft_dims.size()) { fft_dims[i] = x->ndim() - fft_dims.size() + i; } } } return Maybe::Ok(); } Maybe parse_input_n_and_dims(const std::shared_ptr& x, const Optional>& n, const Optional>& dims, std::vector& fft_len, std::vector& wrapped_dims) const { if (n.has_value() && dims.has_value()) { CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size()) << "RuntimeError: When dim and shape were both given, they must have the same length"; } wrapped_dims.resize(x->ndim()); fft_len.resize(x->ndim()); if (dims.has_value() && (*JUST(dims)).size() == 1) { // 1D-discrete fourier transform wrapped_dims = *JUST(dims); JUST(maybe_wrap_dims(wrapped_dims, x->ndim())); fft_len.resize(wrapped_dims.size()); fft_len[0] = n.has_value() 
== true ? (*JUST(n))[0] : x->dim(wrapped_dims[0]); if (fft_len[0] == -1) { fft_len[0] = x->dim(wrapped_dims[0]); } CHECK_OR_RETURN(fft_len[0] >= 1) << "RuntimeError: Expected n >= 1, but got " << fft_len[0]; } else if (n.has_value() && JUST(n)->size() == 1) { // 1D-discrete fourier transform fft_len = *(JUST(n)); if (fft_len[0] == -1) { fft_len[0] = x->shape()->back(); } CHECK_OR_RETURN(fft_len[0] >= 1) << "RuntimeError: Expected n >= 1, but got " << fft_len[0]; wrapped_dims.resize(1); wrapped_dims[0] = x->ndim() - 1; } else { // ND-discrete fourier transform JUST(calculate_fftn_shape_and_dims(x, n, dims, fft_len, wrapped_dims)); } return Maybe::Ok(); } Maybe permute_and_reshape(const std::shared_ptr& self, const std::vector& out_sizes, const std::vector& fft_dims, std::vector& out_strides) const { // Permute and reshape `self` Tensor. // This can maximizes data locality const int64_t ndim = self->ndim(); const int64_t fft_ndim = fft_dims.size(); const int64_t batch_dims = ndim - fft_ndim; const auto& in_stride = JUST(self->stride()); // Permute dimensions to make batch dims come first, and this maximizes data locality std::vector dim_permute(ndim); std::iota(dim_permute.begin(), dim_permute.end(), int32_t(0)); std::vector is_transformed_dim(ndim, false); for (const auto& dim : fft_dims) { is_transformed_dim[dim] = true; } auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(), [&](int64_t d) { return !is_transformed_dim[d]; }); std::sort(dim_permute.begin(), batch_end, [&](int64_t a, int64_t b) { return in_stride->at(a) > in_stride->at(b); }); std::copy(fft_dims.begin(), fft_dims.end(), batch_end); // permute auto input = JUST(functional::Permute(self, dim_permute)); std::vector batched_sizes(fft_ndim + 1); batched_sizes[0] = -1; std::copy(input->shape()->begin() + batch_dims, input->shape()->end(), batched_sizes.begin() + 1); // reshape Shape batched_shape(batched_sizes); input = JUST(functional::Reshape(input, batched_shape)); const auto batch_size = input->shape()->At(0); batched_sizes[0] = batch_size; std::vector batched_out_sizes(batched_sizes.begin(), batched_sizes.end()); FOR_RANGE(int64_t, i, 0, fft_dims.size()) { batched_out_sizes[i + 1] = out_sizes[fft_dims[i]]; } // Inplace reshaping to original batch shape and inverting the dimension permutation out_strides.resize(ndim, 0); int64_t batch_numel = 1; Stride contiguous_out_strides = Stride(batched_out_sizes); for (int64_t i = batch_dims - 1; i >= 0; --i) { out_strides[dim_permute[i]] = batch_numel * contiguous_out_strides[0]; batch_numel *= out_sizes[dim_permute[i]]; } FOR_RANGE(int64_t, i, batch_dims, ndim) { out_strides[dim_permute[i]] = contiguous_out_strides[1 + (i - batch_dims)]; } // Judge if the input needs to be cloned int64_t signal_ndim = input->shape()->size() - 1; const Stride& batched_input_strides = *(JUST(input->stride())); auto last_stride = JUST(oneflow::VectorAt(batched_input_strides, signal_ndim)); bool must_clone_input = false; if (JUST(oneflow::VectorAt(batched_input_strides, 0)) == 0) { must_clone_input = true; } for (auto i = signal_ndim - 1; !must_clone_input && i > 0; i--) { auto stride = JUST(oneflow::VectorAt(batched_input_strides, i)); if (JUST(oneflow::VectorAt(*(input->shape()), i)) == 1) { continue; } else if (stride > 0 && stride % last_stride == 0) { last_stride = stride; } else { must_clone_input = true; } } if (must_clone_input) { input = JUST(functional::ToContiguous(input)); } return input; } Maybe parse_c2r_input_n_and_dims(const std::shared_ptr& x, const Optional>& n, const 
Optional>& dims, int64_t& last_dim_size, std::vector& fft_len, std::vector& wrapped_dims) const { JUST(parse_input_n_and_dims(x, n, dims, fft_len, wrapped_dims)); // infer last_dim_size last_dim_size = 0; if (!n.has_value() || JUST(n)->back() == -1) { int64_t last_dim = wrapped_dims.back(); last_dim_size = 2 * (x->dim(last_dim) - 1); } else { last_dim_size = JUST(n)->back(); } CHECK_OR_RETURN(last_dim_size >= 1) << "RuntimeError: Invalid number of last_dim_size (" << last_dim_size << ") specified"; fft_len.back() = last_dim_size / 2 + 1; return Maybe::Ok(); } protected: std::shared_ptr op_; }; class FftC2CFunctor : public FftBaseFunctor { public: FftC2CFunctor() : FftBaseFunctor("fft_c2c") {} Maybe operator()(const std::shared_ptr& x, const Optional>& n, const Optional>& dims, int32_t norm_mode, bool forward, bool normalized) const { // NOTE: The parameter `normalized` indicates whether the FFT results need to be normalized // using `ScalarMul`. This parameter is only valid when using CUDA devices. This parameter is // not valid when using a CPU device, because the cpu's fft operator will be normalized inside // the cpu oprator according to the parameter `forward` and the type of FFT transform CHECK_OR_RETURN(x->dtype()->is_complex()) << "RuntimeError: expects the dtype of input Tensor is Complex, but gets " << x->dtype()->name(); std::vector fft_len(x->ndim(), 0); std::vector wrapped_dims(x->ndim(), 0); JUST(parse_input_n_and_dims(x, n, dims, fft_len, wrapped_dims)); auto resized_tensor = n.has_value() == true ? JUST(resize_fft_input(x, wrapped_dims, fft_len)) : x; DeviceType input_device{}; if (x->is_global()) { input_device = JUST(x->parallel_desc())->device_type(); } else { input_device = JUST(x->device())->enum_type(); } double norm_fct = fft_compute_fct(*(resized_tensor->shape()), wrapped_dims, static_cast(norm_mode)); if (input_device == DeviceType::kCPU) { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "forward", "norm_mode", "norm_fct"); attrs.SetAllAttrs(wrapped_dims, forward, norm_mode, norm_fct); return OpInterpUtil::Dispatch(*op_, {resized_tensor}, attrs); } else if (input_device == DeviceType::kCUDA) { if (wrapped_dims.empty()) { return resized_tensor; } std::vector out_sizes(resized_tensor->shape()->dim_vec().begin(), resized_tensor->shape()->dim_vec().end()); std::vector sorted_dims(wrapped_dims.begin(), wrapped_dims.end()); auto working_tensor = resized_tensor; std::vector out_strides; std::shared_ptr output; while (true) { // Sort Dimemsions every iteration auto strides = *JUST(working_tensor->stride()); std::sort(sorted_dims.begin(), sorted_dims.end(), [&](int64_t a, int64_t b) { return strides[a] > strides[b]; }); const auto max_dims = std::min(static_cast(cufft_max_ndim), sorted_dims.size()); auto first_dims_end = sorted_dims.end(); auto first_dims_begin = first_dims_end - max_dims; std::vector first_dims(first_dims_begin, first_dims_end); auto input = JUST(permute_and_reshape(working_tensor, out_sizes, first_dims, out_strides)); std::vector fft_dims(input->ndim() - 1); // must >= 1 std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "forward", "norm_mode", "norm_fct"); attrs.SetAllAttrs(fft_dims, forward, norm_mode, norm_fct); output = JUST(OpInterpUtil::Dispatch(*op_, {input}, attrs)); output = JUST( functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset()))); sorted_dims.resize(sorted_dims.size() - max_dims); if (sorted_dims.empty()) { break; } working_tensor = 
std::move(output); } if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), true)); } return output; } else { CHECK_OR_RETURN(false) << "RuntimeError: FFTC2C Only support cpu and cuda device."; } } }; class FftR2CFunctor : public FftBaseFunctor { public: FftR2CFunctor() : FftBaseFunctor("fft_r2c") {} Maybe operator()(const std::shared_ptr& x, const Optional>& n, const Optional>& dims, int32_t norm_mode, bool onesided, bool forward, bool normalized) const { // NOTE: The parameter `normalized` indicates whether the FFT results need to be normalized // using `ScalarMul`. This parameter is only valid when using CUDA devices. This parameter is // not valid when using a CPU device, because the cpu's fft operator will be normalized inside // the cpu oprator according to the parameter `forward` and the type of FFT transform CHECK_OR_RETURN(!(x->dtype()->is_complex())) << "RuntimeError: expects the dtype of input Tensor is Real, but gets " << x->dtype()->name(); auto input_tensor = JUST(promote_tensor_fft(x)); if (n.has_value() && dims.has_value()) { CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size()) << "RuntimeError: When dim and shape were both given, they must have the same length"; } std::vector fft_len(input_tensor->ndim(), 0); std::vector wrapped_dims(input_tensor->ndim(), 0); JUST(parse_input_n_and_dims(input_tensor, n, dims, fft_len, wrapped_dims)); auto resized_tensor = n.has_value() == true ? JUST(resize_fft_input(input_tensor, wrapped_dims, fft_len)) : input_tensor; DeviceType input_device{}; if (x->is_global()) { input_device = JUST(x->parallel_desc())->device_type(); } else { input_device = JUST(x->device())->enum_type(); } double norm_fct = fft_compute_fct(*(resized_tensor->shape()), wrapped_dims, static_cast(norm_mode)); std::shared_ptr output; if (input_device == DeviceType::kCPU) { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "onesided"); attrs.SetAllAttrs(wrapped_dims, norm_mode, norm_fct, onesided); output = JUST(OpInterpUtil::Dispatch(*op_, {resized_tensor}, attrs)); } else if (input_device == DeviceType::kCUDA) { std::vector input_sizes(resized_tensor->shape()->begin(), resized_tensor->shape()->end()); std::vector onesided_sizes = input_sizes; int64_t last_dim = wrapped_dims.back(); int64_t last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1; onesided_sizes[last_dim] = last_dim_halfsize; std::vector out_sizes = onesided ? 
onesided_sizes : input_sizes; if (use_optimized_cufft_path(wrapped_dims)) { std::vector out_strides; auto input = JUST(permute_and_reshape(resized_tensor, out_sizes, wrapped_dims, out_strides)); std::vector fft_dims(input->ndim() - 1); // must >= 1 std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "onesided"); attrs.SetAllAttrs(fft_dims, norm_mode, norm_fct, onesided); output = JUST(OpInterpUtil::Dispatch(*op_, {input}, attrs)); output = JUST( functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset()))); } else { // First do the **one-sided** R2C transform on the last dimension const std::shared_ptr& working_tensor = resized_tensor; { std::vector out_strides; auto input = JUST( permute_and_reshape(/*self=*/working_tensor, /*out_sizes=*/onesided_sizes, /*fft_dims=*/{wrapped_dims.back()}, /*out_strides=*/out_strides)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "onesided"); int64_t last_dim = input->shape()->size() - 1; std::vector fft_last_dim_vec = {last_dim}; attrs.SetAllAttrs(fft_last_dim_vec, norm_mode, norm_fct, /*onesided=*/true); output = JUST(OpInterpUtil::Dispatch(*op_, {input}, attrs)); output = JUST(functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset()))); } // Then any remaining C2C transforms std::vector sorted_dims(wrapped_dims.begin(), wrapped_dims.end() - 1); if (!sorted_dims.empty()) { output = JUST(functional::FftC2C(output, NullOpt, sorted_dims, norm_mode, /*forward=*/true, /*normalize=*/false)); } } if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), true)); } } else { CHECK_OR_RETURN(false) << "RuntimeError: FFTR2C Only support cpu and cuda device."; } if (!forward) { return functional::ConjPhysical(output); } else { return output; } } }; class FftC2RFunctor : public FftBaseFunctor { public: FftC2RFunctor() : FftBaseFunctor("fft_c2r") {} Maybe operator()(const std::shared_ptr& x, const Optional>& n, const Optional>& dims, int32_t norm_mode, bool forward, bool normalized) const { // NOTE: The parameter `normalized` indicates whether the FFT results need to be normalized // using `ScalarMul`. This parameter is only valid when using CUDA devices. This parameter is // not valid when using a CPU device, because the cpu's fft operator will be normalized inside // the cpu oprator according to the parameter `forward` and the type of FFT transform CHECK_OR_RETURN(x->dtype()->is_complex()) << "RuntimeError: expects the dtype of input Tensor is Complex, but gets " << x->dtype()->name(); if (n.has_value() && dims.has_value()) { CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size()) << "RuntimeError: When dim and shape were both given, they must have the same length"; } std::vector wrapped_dims(x->ndim(), 0); std::vector fft_len(x->ndim(), 0); int64_t last_dim_size = 0; JUST(parse_c2r_input_n_and_dims(x, n, dims, last_dim_size, fft_len, wrapped_dims)); auto resized_tensor = n.has_value() == true ? 
JUST(resize_fft_input(x, wrapped_dims, fft_len)) : x; Shape out_shape = *(resized_tensor->shape()); out_shape[wrapped_dims.back()] = last_dim_size; double norm_fct = fft_compute_fct(out_shape, wrapped_dims, static_cast(norm_mode)); if (forward) { resized_tensor = JUST(functional::ConjPhysical(resized_tensor)); } DeviceType input_device{}; if (x->is_global()) { input_device = JUST(x->parallel_desc())->device_type(); } else { input_device = JUST(x->device())->enum_type(); } if (input_device == DeviceType::kCPU) { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "last_dim_size"); attrs.SetAllAttrs(wrapped_dims, norm_mode, norm_fct, last_dim_size); return OpInterpUtil::Dispatch(*op_, {resized_tensor}, attrs); } else if (input_device == DeviceType::kCUDA) { std::shared_ptr output; if (use_optimized_cufft_path(wrapped_dims)) { auto input = JUST(functional::ToContiguous(resized_tensor)); std::vector out_sizes(out_shape.dim_vec().begin(), out_shape.dim_vec().end()); std::vector out_strides; input = JUST(permute_and_reshape(input, out_sizes, wrapped_dims, out_strides)); std::vector fft_dims(input->ndim() - 1); // must >= 1 std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "last_dim_size"); attrs.SetAllAttrs(fft_dims, norm_mode, norm_fct, last_dim_size); output = JUST(OpInterpUtil::Dispatch(*op_, {input}, attrs)); output = JUST( functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset()))); } else { // First complete any C2C transforms std::shared_ptr temp; if (wrapped_dims.size() > 1) { std::vector any_c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1); temp = JUST(functional::FftC2C(resized_tensor, NullOpt, any_c2c_dims, static_cast(fft_norm_mode::none), /*forward=*/false, /*normalized=*/false)); } else { temp = JUST(functional::ToContiguous(resized_tensor)); } // Finally, do the 1D C2R transforms on the last dim std::vector out_strides; std::vector out_sizes(out_shape.dim_vec().begin(), out_shape.dim_vec().end()); auto input = JUST(permute_and_reshape(/*self=*/temp, /*out_sizes=*/out_sizes, /*fft_dims=*/{wrapped_dims.back()}, /*out_strides=*/out_strides)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "last_dim_size"); int64_t last_dim = input->shape()->size() - 1; std::vector fft_last_dim_vec = {last_dim}; attrs.SetAllAttrs(fft_last_dim_vec, norm_mode, norm_fct, /*last_dim_size=*/last_dim_size); output = JUST(OpInterpUtil::Dispatch(*op_, {input}, attrs)); output = JUST( functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset()))); } if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), /*inplace=*/true)); } return output; } else { CHECK_OR_RETURN(false) << "RuntimeError: FFTC2R Only support cpu and cuda device."; } } }; class FftFunctor { public: Maybe operator()(const std::shared_ptr& input, int64_t n, int64_t dim, const Optional& norm) const { std::string norm_str = norm.value_or("backward"); std::vector fft_dim{dim}; bool forward = true; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); std::vector len{n}; return input->dtype()->is_complex() ? 
functional::FftC2C(input, len, fft_dim, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true) : functional::FftR2C(input, len, fft_dim, static_cast(norm_mode), /*onesided=*/false, /*forward=*/forward, /*normalized=*/true); } }; class IFftFunctor { public: Maybe operator()(const std::shared_ptr& input, int64_t n, int64_t dim, const Optional& norm) const { auto norm_str = norm.value_or("backward"); std::vector fft_dim{dim}; bool forward = false; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); std::vector len{n}; return input->dtype()->is_complex() ? functional::FftC2C(input, len, fft_dim, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true) : functional::FftR2C(input, len, fft_dim, static_cast(norm_mode), /*onesided=*/false, /*forward=*/forward, /*normalized=*/true); } }; class Fft2Functor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const std::vector& dim, const Optional& norm) const { return functional::FftN(input, s, dim, norm); } }; class IFft2Functor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const std::vector& dim, const Optional& norm) const { return functional::IFftN(input, s, dim, norm); } }; class FftNFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const Optional>& dim, const Optional& norm) const { std::string norm_str = norm.value_or("backward"); bool forward = true; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); if (!(input->dtype()->is_complex())) { // cast to complex TensorProcessor tensor_processor; Symbol complex_dtype; if (input->dtype() == DType::Double()) { complex_dtype = DType::Complex128(); } else { complex_dtype = DType::Complex64(); } JUST(tensor_processor.AddInputs({input}, {complex_dtype}).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return functional::FftC2C(JUST(oneflow::VectorAt(input_tuple, 0)), s, dim, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true); } else { return functional::FftC2C(input, s, dim, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true); } } }; class IFftNFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const Optional>& dim, const Optional& norm) const { std::string norm_str = norm.value_or("backward"); bool forward = false; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); if (!(input->dtype()->is_complex())) { // cast to complex TensorProcessor tensor_processor; Symbol complex_dtype; if (input->dtype() == DType::Double()) { complex_dtype = DType::Complex128(); } else { complex_dtype = DType::Complex64(); } JUST(tensor_processor.AddInputs({input}, {complex_dtype}).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return functional::FftC2C(JUST(oneflow::VectorAt(input_tuple, 0)), s, dim, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true); } else { return functional::FftC2C(input, s, dim, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true); } } }; class RFftFunctor { public: Maybe operator()(const std::shared_ptr& input, int64_t n, int64_t dim, const Optional& norm) const { CHECK_OR_RETURN(!(input->dtype()->is_complex())) << "RuntimeError: expects the dtype of input Tensor is Real, but gets " << input->dtype()->name(); std::string norm_str = norm.value_or("backward"); std::vector fft_dim{dim}; bool forward = true; 
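// NOTE: worked example (illustrative, not from the original source) of how the NumPy-style
// `norm` string chosen above is turned into a scale factor by `fft_norm_from_string` and
// `fft_compute_fct` defined earlier in this file. For a forward transform over one dim of
// length n = 8:
//   norm = "backward" -> fft_norm_mode::none      -> factor 1
//   norm = "forward"  -> fft_norm_mode::by_n      -> factor 1 / 8
//   norm = "ortho"    -> fft_norm_mode::by_root_n -> factor 1 / sqrt(8)
// For multi-dim transforms, n is the product of the transformed dimension sizes.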
fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); std::vector len{n}; return functional::FftR2C(input, len, fft_dim, static_cast(norm_mode), /*onesided=*/true, /*forward=*/forward, /*normalized=*/true); } }; class IRFftFunctor { public: Maybe operator()(const std::shared_ptr& input, int64_t n, int64_t dim, const Optional& norm) const { std::string norm_str = norm.value_or("backward"); std::vector fft_dim{dim}; bool forward = false; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); std::vector len{n}; return functional::FftC2R(input, len, fft_dim, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true); } }; class RFft2Functor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const std::vector& dim, const Optional& norm) const { return functional::RFftN(input, s, dim, norm); } }; class IRFft2Functor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const std::vector& dim, const Optional& norm) const { return functional::IRFftN(input, s, dim, norm); } }; class RFftNFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const Optional>& dim, const Optional& norm) const { CHECK_OR_RETURN(!(input->dtype()->is_complex())) << "RuntimeError: expects the dtype of input Tensor is Real, but gets " << input->dtype()->name(); std::string norm_str = norm.value_or("backward"); bool forward = true; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); return functional::FftR2C(input, s, dim, static_cast(norm_mode), /*onesided=*/true, /*forward=*/forward, /*normalized=*/true); } }; class IRFftNFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const Optional>& dim, const Optional& norm) const { CHECK_OR_RETURN(input->dtype()->is_complex()) << "RuntimeError: expects the dtype of input Tensor is Complex, but gets " << input->dtype()->name(); std::string norm_str = norm.value_or("backward"); bool forward = false; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); return functional::FftC2R(input, s, dim, static_cast(norm_mode), /*forward=*/false, /*normalized=*/true); } }; class HFftFunctor { public: Maybe operator()(const std::shared_ptr& input, int64_t n, int64_t dim, const Optional& norm) const { CHECK_OR_RETURN(input->dtype()->is_complex()) << "RuntimeError: expects the dtype of input Tensor is Complex, but gets " << input->dtype()->name(); std::string norm_str = norm.value_or("backward"); std::vector fft_dim{dim}; bool forward = true; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); std::vector len{n}; return functional::FftC2R(input, len, fft_dim, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true); } }; class IHFftFunctor { public: Maybe operator()(const std::shared_ptr& input, int64_t n, int64_t dim, const Optional& norm) const { CHECK_OR_RETURN(!(input->dtype()->is_complex())) << "RuntimeError: expects the dtype of input Tensor is Real, but gets " << input->dtype()->name(); std::string norm_str = norm.value_or("backward"); std::vector fft_dim{dim}; bool forward = false; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); std::vector len{n}; return functional::FftR2C(input, len, fft_dim, static_cast(norm_mode), /*onesided=*/true, /*forward=*/forward, /*normalized=*/true); } 
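// NOTE: a sketch (not taken from the original source) of the length bookkeeping shared by
// HFftFunctor above and this functor: `hfft` treats its complex input of length m as one side
// of a Hermitian-symmetric spectrum and produces a real signal of default length
// n = 2 * (m - 1) (see `parse_c2r_input_n_and_dims`), while `ihfft` maps a real signal of
// length n to a one-sided complex result of length n / 2 + 1.
// Example: 4 complex points -> hfft -> 6 real points -> ihfft -> 4 complex points.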
}; class HFft2Functor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const std::vector& dim, const Optional& norm) const { return functional::HFftN(input, s, dim, norm); } }; class IHFft2Functor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& s, const std::vector& dim, const Optional& norm) const { return functional::IHFftN(input, s, dim, norm); } }; class HFftNFunctor : FftBaseFunctor { public: HFftNFunctor() : FftBaseFunctor() {} Maybe operator()(const std::shared_ptr& input, const Optional>& s, const Optional>& dim, const Optional& norm) const { CHECK_OR_RETURN(input->dtype()->is_complex()) << "RuntimeError: expects the dtype of input Tensor is Complex, but gets " << input->dtype()->name(); std::string norm_str = norm.value_or("backward"); bool forward = true; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); if (s.has_value() && dim.has_value()) { CHECK_OR_RETURN((*JUST(s)).size() == (*JUST(dim)).size()) << "RuntimeError: When dim and shape were both given, they must have the same length"; } std::vector wrapped_dims(input->ndim(), 0); std::vector fft_len(input->ndim(), 0); int64_t last_dim_size = 0; JUST(parse_c2r_input_n_and_dims(input, s, dim, last_dim_size, fft_len, wrapped_dims)); auto resized_tensor = s.has_value() == true ? JUST(resize_fft_input(input, wrapped_dims, fft_len)) : input; std::shared_ptr temp; if (wrapped_dims.size() > 1) { // ND Fast Fourier Transform std::vector c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1); temp = JUST(functional::FftC2C(resized_tensor, NullOpt, c2c_dims, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true)); } else { temp = resized_tensor; } // Finally, do 1D fft_c2r int64_t last_dim = wrapped_dims.back(); std::vector last_dim_vec = {last_dim}; std::vector last_dim_size_vec = {last_dim_size}; return functional::FftC2R(temp, last_dim_size_vec, last_dim_vec, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true); } }; class IHFftNFunctor : FftBaseFunctor { public: IHFftNFunctor() : FftBaseFunctor() {} Maybe operator()(const std::shared_ptr& input, const Optional>& s, const Optional>& dim, const Optional& norm) const { CHECK_OR_RETURN(!(input->dtype()->is_complex())) << "RuntimeError: expects the dtype of input Tensor is Real, but gets " << input->dtype()->name(); std::string norm_str = norm.value_or("backward"); bool forward = false; fft_norm_mode norm_mode = fft_norm_mode::none; norm_mode = fft_norm_from_string(norm_str, forward); auto input_tensor = JUST(promote_tensor_fft(input, false)); if (s.has_value() && dim.has_value()) { CHECK_OR_RETURN((*JUST(s)).size() == (*JUST(dim)).size()) << "RuntimeError: When dim and shape were both given, they must have the same length"; } std::vector fft_len(input_tensor->ndim(), 0); std::vector wrapped_dims(input_tensor->ndim(), 0); JUST(parse_input_n_and_dims(input_tensor, s, dim, fft_len, wrapped_dims)); auto resized_tensor = s.has_value() == true ? 
JUST(resize_fft_input(input_tensor, wrapped_dims, fft_len)) : input_tensor; // First do 1D R2C Transform on the last dim const auto last_dim_len = fft_len.back(); const auto last_dim = wrapped_dims.back(); std::vector r2c_fft_len = {last_dim_len}; std::vector r2c_fft_dim = {last_dim}; auto temp = JUST(functional::FftR2C(resized_tensor, r2c_fft_len, r2c_fft_dim, static_cast(norm_mode), /*onesided=*/true, /*forward=*/forward, /*normalized=*/true)); // NOTE: `temp` is already conjugated in `functional::FftR2C` if (wrapped_dims.size() == 1) { return temp; } // Finally do C2C Transform on the remaining dims std::vector c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1); return functional::FftC2C(temp, NullOpt, c2c_dims, static_cast(norm_mode), /*forward=*/forward, /*normalized=*/true); } }; class StftFunctor { public: StftFunctor() { op_ = CHECK_JUST(one::OpBuilder("stft").Input("input").Output("output").Build()); } Maybe operator()(const std::shared_ptr& input, const int64_t n_fft, const Optional& hop_length, const Optional& win_length, const Optional& window, const bool center, const std::string& mode, const bool normalized, const bool onesided, const bool return_complex) const { CHECK_OR_RETURN(n_fft > 0) << Error::RuntimeError() << "Expected 0 < n_fft , but got " << n_fft; int64_t new_hop_length = hop_length.has_value() == true ? JUST(hop_length) : n_fft / 4; int64_t new_win_length = win_length.has_value() == true ? JUST(win_length) : n_fft; auto input_tensor = input; // TODO(yzm):Remove this line when complex numbers are supported CHECK_OR_RETURN(return_complex == false) << Error::RuntimeError() << "return_complex parameter is not supported at this time"; const auto& NumAxes = input_tensor->shape()->NumAxes(); CHECK_OR_RETURN(NumAxes == 2 || NumAxes == 1) << Error::RuntimeError() << "Expected a 1D or 2D tensor,but got " << NumAxes << "D"; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("normalized", "onesided", "return_complex"); attrs.SetAllAttrs(normalized, onesided, return_complex); if (NumAxes == 1) { input_tensor = JUST(functional::Unsqueeze(input_tensor, 0)); } if (center) { const auto& input_shape = input_tensor->shape(); const auto input_dim = input_tensor->shape()->NumAxes(); const auto extra_dims = std::max(size_t{3}, (size_t)input_dim) - input_dim; const auto pad_amount = n_fft / 2; DimVector extended_shape(extra_dims, 1); extended_shape.append(input_shape->begin(), input_shape->end()); input_tensor = JUST(functional::Pad(JUST(functional::View(input_tensor, Shape(extended_shape))), {pad_amount, pad_amount}, mode, Scalar(0))); DimVector view_shape; if (input_dim == 1) { view_shape = {input_tensor->shape()->back()}; } else { view_shape = {input_shape->at(0), input_tensor->shape()->back()}; } input_tensor = JUST(functional::View(input_tensor, Shape(view_shape))); } int32_t batch = input_tensor->shape()->At(0); int32_t len = input_tensor->shape()->At(1); int32_t n_frames = 1 + (len - n_fft) / new_hop_length; int32_t fft_size = static_cast(n_fft); CHECK_OR_RETURN(n_fft > 0 && n_fft <= len) << Error::RuntimeError() << "Expected 0 < n_fft < " << len << " ,but got " << n_fft; CHECK_GT_OR_RETURN(new_hop_length, 0) << Error::RuntimeError() << "Expected hop_length > 0, but got " << new_hop_length; CHECK_OR_RETURN(new_win_length > 0 && new_win_length <= n_fft) << Error::RuntimeError() << "Expected 0 < win_length <=n_fft ,but got " << new_win_length; const auto& stride = *JUST(input_tensor->stride()); std::vector strides(stride.begin(), stride.end()); input_tensor = 
JUST(view::AsStrided(input_tensor, {batch, n_frames, fft_size}, {JUST(VectorAt(strides, 0)), static_cast(new_hop_length) * JUST(VectorAt(strides, 1)), JUST(VectorAt(strides, 1))}, 0)); std::shared_ptr temp_tensor; if (window.has_value()) { temp_tensor = JUST(window); CHECK_OR_RETURN(temp_tensor->shape()->NumAxes() == 1 && temp_tensor->shape()->at(0) == new_win_length) << Error::RuntimeError() << "Expected a 1D window tensor of size equal to win_length=" << new_win_length << ", but got window with size " << temp_tensor->shape()->ToString(); } if (new_win_length < n_fft) { temp_tensor = JUST(functional::Fill(temp_tensor, 0)); const int64_t left = (n_fft - new_win_length) / 2; if (window.has_value()) { // TODO(yzm):Copy the window matrix to the defined range,such as //''' // functional::AssignLocalTensor(JUST(functional::Narrow(temp_tensor, 0, // left,new_win_length)), window); //''' // Remove the following check after support CHECK_OR_RETURN(false) << Error::RuntimeError() << "The following conditions are not currently supported: " "win_length( *op_, {JUST(functional::ToContiguous(input_tensor))}, attrs)); if (NumAxes == 2 && input->shape()->At(0) == 1) { output = JUST(functional::Unsqueeze(output, 0)); } return output; } private: std::shared_ptr op_; }; class FusedWeightedSumFunctor { public: FusedWeightedSumFunctor() { op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 1; n < op_.size(); ++n) { op_[n] = CHECK_JUST(one::OpBuilder("fused_weighted_sum").Input("in", n).Output("out").Build()); } } Maybe operator()(const TensorTuple& in, const std::vector& weights, const float& alpha) const { CHECK_GE_OR_RETURN(in.size(), 1); CHECK_LT_OR_RETURN(in.size(), kMaxInputCount); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("weights", "alpha"); attrs.SetAllAttrs(weights, alpha); return JUST(OpInterpUtil::Dispatch(*op_[in.size()], in, attrs)); } private: std::vector> op_; }; class FusedCenterFunctor { public: FusedCenterFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_center_dist") .Input("b1_x1") .Input("b1_x2") .Input("b2_x1") .Input("b2_x2") .Input("b1_y1") .Input("b1_y2") .Input("b2_y1") .Input("b2_y2") .Output("rho2") .Build()); } Maybe operator()( const std::shared_ptr& b1_x1, const std::shared_ptr& b1_x2, const std::shared_ptr& b2_x1, const std::shared_ptr& b2_x2, const std::shared_ptr& b1_y1, const std::shared_ptr& b1_y2, const std::shared_ptr& b2_y1, const std::shared_ptr& b2_y2) const { return OpInterpUtil::Dispatch( *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2}, {}); } private: std::shared_ptr op_; }; class FusedCenterGradFunctor { public: FusedCenterGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_center_dist_grad") .Input("b1_x1") .Input("b1_x2") .Input("b2_x1") .Input("b2_x2") .Input("b1_y1") .Input("b1_y2") .Input("b2_y1") .Input("b2_y2") .Input("rho2_diff") .Output("b1_x1_diff") .Output("b1_x2_diff") .Output("b2_x1_diff") .Output("b2_x2_diff") .Output("b1_y1_diff") .Output("b1_y2_diff") .Output("b2_y1_diff") .Output("b2_y2_diff") .Build()); } Maybe operator()( const std::shared_ptr& b1_x1, const std::shared_ptr& b1_x2, const std::shared_ptr& b2_x1, const std::shared_ptr& b2_x2, const std::shared_ptr& b1_y1, const std::shared_ptr& b1_y2, const std::shared_ptr& b2_y1, const std::shared_ptr& b2_y2, const std::shared_ptr& rho2_diff) const { return OpInterpUtil::Dispatch( *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, rho2_diff}, {}); } private: std::shared_ptr op_; }; class FusedGetIntersectionAreaFunctor { public: 
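// NOTE: a brief sketch of what "fused_get_intersection_area" is assumed to compute for two
// axis-aligned boxes given their corner coordinates (an assumption based on the op and
// argument names, as used by CIoU-style box losses; the kernel source is not shown here):
//   inter = clamp(min(b1_x2, b2_x2) - max(b1_x1, b2_x1), 0)
//         * clamp(min(b1_y2, b2_y2) - max(b1_y1, b2_y1), 0)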
FusedGetIntersectionAreaFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_intersection_area") .Input("b1_x1") .Input("b1_x2") .Input("b2_x1") .Input("b2_x2") .Input("b1_y1") .Input("b1_y2") .Input("b2_y1") .Input("b2_y2") .Output("inter") .Build()); } Maybe operator()( const std::shared_ptr& b1_x1, const std::shared_ptr& b1_x2, const std::shared_ptr& b2_x1, const std::shared_ptr& b2_x2, const std::shared_ptr& b1_y1, const std::shared_ptr& b1_y2, const std::shared_ptr& b2_y1, const std::shared_ptr& b2_y2) const { return OpInterpUtil::Dispatch( *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2}, {}); } private: std::shared_ptr op_; }; class FusedGetIntersectionAreaGradFunctor { public: FusedGetIntersectionAreaGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_intersection_area_grad") .Input("b1_x1") .Input("b1_x2") .Input("b2_x1") .Input("b2_x2") .Input("b1_y1") .Input("b1_y2") .Input("b2_y1") .Input("b2_y2") .Input("inter_diff") .Output("b1_x1_diff") .Output("b1_x2_diff") .Output("b2_x1_diff") .Output("b2_x2_diff") .Output("b1_y1_diff") .Output("b1_y2_diff") .Output("b2_y1_diff") .Output("b2_y2_diff") .Build()); } Maybe operator()( const std::shared_ptr& b1_x1, const std::shared_ptr& b1_x2, const std::shared_ptr& b2_x1, const std::shared_ptr& b2_x2, const std::shared_ptr& b1_y1, const std::shared_ptr& b1_y2, const std::shared_ptr& b2_y1, const std::shared_ptr& b2_y2, const std::shared_ptr& inter_diff) const { return OpInterpUtil::Dispatch( *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, inter_diff}, {}); } private: std::shared_ptr op_; }; class FusedGetBounddingBoxesCoordFunctor { public: FusedGetBounddingBoxesCoordFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_boundding_boxes_coord") .Input("x1") .Input("y1") .Input("w1") .Input("h1") .Input("x2") .Input("y2") .Input("w2") .Input("h2") .Output("b1_x1") .Output("b1_x2") .Output("b1_y1") .Output("b1_y2") .Output("b2_x1") .Output("b2_x2") .Output("b2_y1") .Output("b2_y2") .Build()); } Maybe operator()( const std::shared_ptr& x1, const std::shared_ptr& y1, const std::shared_ptr& w1, const std::shared_ptr& h1, const std::shared_ptr& x2, const std::shared_ptr& y2, const std::shared_ptr& w2, const std::shared_ptr& h2) const { return OpInterpUtil::Dispatch(*op_, {x1, y1, w1, h1, x2, y2, w2, h2}, {}); } private: std::shared_ptr op_; }; class FusedGetBounddingBoxesCoordGradFunctor { public: FusedGetBounddingBoxesCoordGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_boundding_boxes_coord_grad") .Input("b1_x1_diff") .Input("b1_x2_diff") .Input("b1_y1_diff") .Input("b1_y2_diff") .Input("b2_x1_diff") .Input("b2_x2_diff") .Input("b2_y1_diff") .Input("b2_y2_diff") .Output("x1_diff") .Output("y1_diff") .Output("w1_diff") .Output("h1_diff") .Output("x2_diff") .Output("y2_diff") .Output("w2_diff") .Output("h2_diff") .Build()); } Maybe operator()(const std::shared_ptr& b1_x1_diff, const std::shared_ptr& b1_x2_diff, const std::shared_ptr& b1_y1_diff, const std::shared_ptr& b1_y2_diff, const std::shared_ptr& b2_x1_diff, const std::shared_ptr& b2_x2_diff, const std::shared_ptr& b2_y1_diff, const std::shared_ptr& b2_y2_diff) const { return OpInterpUtil::Dispatch(*op_, {b1_x1_diff, b1_x2_diff, b1_y1_diff, b1_y2_diff, b2_x1_diff, b2_x2_diff, b2_y1_diff, b2_y2_diff}, {}); } private: std::shared_ptr op_; }; class FusedGetCiouDiagonalAngleFunctor { public: FusedGetCiouDiagonalAngleFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_ciou_diagonal_angle") .Input("w1") .Input("h1") .Input("w2") .Input("h2") 
.Output("v") .Build()); } Maybe operator()(const std::shared_ptr& w1, const std::shared_ptr& h1, const std::shared_ptr& w2, const std::shared_ptr& h2, const float eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("eps"); attrs.SetAllAttrs(eps); return OpInterpUtil::Dispatch(*op_, {w1, h1, w2, h2}, attrs); } private: std::shared_ptr op_; }; class FusedGetCiouResultFunctor { public: FusedGetCiouResultFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_ciou_result") .Input("v") .Input("iou") .Input("rho2") .Input("c2") .Output("y") .Output("alpha") .Build()); } Maybe operator()(const std::shared_ptr& v, const std::shared_ptr& iou, const std::shared_ptr& rho2, const std::shared_ptr& c2, const float& eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("eps"); attrs.SetAllAttrs(eps); return OpInterpUtil::Dispatch(*op_, {v, iou, rho2, c2}, attrs); } private: std::shared_ptr op_; }; class FusedGetCiouDiagonalAngleGradFunctor { public: FusedGetCiouDiagonalAngleGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_ciou_diagonal_angle_grad") .Input("w1") .Input("h1") .Input("w2") .Input("h2") .Input("v_diff") .Output("w1_diff") .Output("h1_diff") .Output("w2_diff") .Output("h2_diff") .Build()); } Maybe operator()(const std::shared_ptr& w1, const std::shared_ptr& h1, const std::shared_ptr& w2, const std::shared_ptr& h2, const std::shared_ptr& v_diff, const float eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("eps"); attrs.SetAllAttrs(eps); return OpInterpUtil::Dispatch(*op_, {w1, h1, w2, h2, v_diff}, attrs); } private: std::shared_ptr op_; }; class FusedGetCiouResultGradFunctor { public: FusedGetCiouResultGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_ciou_result_grad") .Input("dy") .Input("alpha") .Input("rho2") .Input("c2") .Output("dv") .Output("diou") .Output("drho2") .Output("dc2") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& alpha, const std::shared_ptr& rho2, const std::shared_ptr& c2) const { return OpInterpUtil::Dispatch(*op_, {dy, alpha, rho2, c2}, {}); } private: std::shared_ptr op_; }; class FusedGetIouFunctor { public: FusedGetIouFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_iou") .Input("w1") .Input("h1") .Input("w2") .Input("h2") .Input("inter") .Output("iou") .Build()); } Maybe operator()(const std::shared_ptr& w1, const std::shared_ptr& h1, const std::shared_ptr& w2, const std::shared_ptr& h2, const std::shared_ptr& inter, const float& eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("eps"); attrs.SetAllAttrs(eps); return OpInterpUtil::Dispatch(*op_, {w1, h1, w2, h2, inter}, attrs); } private: std::shared_ptr op_; }; class FusedGetIouGradFunctor { public: FusedGetIouGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_iou_grad") .Input("diou") .Input("w1") .Input("h1") .Input("w2") .Input("h2") .Input("inter") .Output("dw1") .Output("dh1") .Output("dinter") .Build()); } Maybe operator()(const std::shared_ptr& diou, const std::shared_ptr& w1, const std::shared_ptr& h1, const std::shared_ptr& w2, const std::shared_ptr& h2, const std::shared_ptr& inter, const float& eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("eps"); attrs.SetAllAttrs(eps); return OpInterpUtil::Dispatch(*op_, {diou, w1, h1, w2, h2, inter}, attrs); } private: std::shared_ptr op_; }; class FusedGetConvexDiagonalSquaredFunctor { public: FusedGetConvexDiagonalSquaredFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_convex_diagonal_squared") .Input("b1_x1") .Input("b1_x2") .Input("b2_x1") .Input("b2_x2") 
.Input("b1_y1") .Input("b1_y2") .Input("b2_y1") .Input("b2_y2") .Output("c2") .Build()); } Maybe operator()(const std::shared_ptr& b1_x1, const std::shared_ptr& b1_x2, const std::shared_ptr& b2_x1, const std::shared_ptr& b2_x2, const std::shared_ptr& b1_y1, const std::shared_ptr& b1_y2, const std::shared_ptr& b2_y1, const std::shared_ptr& b2_y2, const float& eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("eps"); attrs.SetAllAttrs(eps); return OpInterpUtil::Dispatch( *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2}, attrs); } private: std::shared_ptr op_; }; class FusedGetConvexDiagonalSquaredGradFunctor { public: FusedGetConvexDiagonalSquaredGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_get_convex_diagonal_squared_grad") .Input("c2_diff") .Input("b1_x1") .Input("b1_x2") .Input("b2_x1") .Input("b2_x2") .Input("b1_y1") .Input("b1_y2") .Input("b2_y1") .Input("b2_y2") .Output("b1_x1_diff") .Output("b1_x2_diff") .Output("b2_x1_diff") .Output("b2_x2_diff") .Output("b1_y1_diff") .Output("b1_y2_diff") .Output("b2_y1_diff") .Output("b2_y2_diff") .Build()); } Maybe operator()( const std::shared_ptr& c2_diff, const std::shared_ptr& b1_x1, const std::shared_ptr& b1_x2, const std::shared_ptr& b2_x1, const std::shared_ptr& b2_x2, const std::shared_ptr& b1_y1, const std::shared_ptr& b1_y2, const std::shared_ptr& b2_y1, const std::shared_ptr& b2_y2, const float& eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("eps"); attrs.SetAllAttrs(eps); return OpInterpUtil::Dispatch( *op_, {c2_diff, b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2}, attrs); } private: std::shared_ptr op_; }; class RealFunctor { public: RealFunctor() { op_ = CHECK_JUST(one::OpBuilder("real").Input("x").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x) const { if (!x->dtype()->is_complex()) { return x; } return OpInterpUtil::Dispatch(*op_, {x}); } private: std::shared_ptr op_; }; class RealGradFunctor { public: RealGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("real_grad").Input("dout").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dout) const { return OpInterpUtil::Dispatch(*op_, {dout}); } private: std::shared_ptr op_; }; class ImagFunctor { public: ImagFunctor() { op_ = CHECK_JUST(one::OpBuilder("imag").Input("x").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x) const { CHECK_OR_RETURN(x->dtype()->is_complex()) << "RuntimeError: imag is implemented for tensors with complex dtypes, but gets" << x->dtype()->name(); return OpInterpUtil::Dispatch(*op_, {x}); } private: std::shared_ptr op_; }; class ImagGradFunctor { public: ImagGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("imag_grad").Input("dout").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dout) const { return OpInterpUtil::Dispatch(*op_, {dout}); } private: std::shared_ptr op_; }; class ConjFunctor { public: ConjFunctor() { op_ = CHECK_JUST(one::OpBuilder("conj_physical").Input("x").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x) const { if (!x->dtype()->is_complex()) { return x; } return OpInterpUtil::Dispatch(*op_, {x}); } private: std::shared_ptr op_; }; class ConjPhysicalFunctor { public: ConjPhysicalFunctor() { op_ = CHECK_JUST(one::OpBuilder("conj_physical").Input("x").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x) const { if (!IsComplexDataType(x->dtype()->data_type())) { return x; } return OpInterpUtil::Dispatch(*op_, {x}); } private: std::shared_ptr op_; }; } // namespace impl using namespace impl; 
ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Add"); m.add_functor("ScalarAdd"); m.add_functor("ScalarSub"); m.add_functor("ScalarMul"); m.add_functor("InplaceScalarMul"); m.add_functor("AddCDiv"); m.add_functor("InplaceAddCDiv"); m.add_functor("ScalarDiv"); m.add_functor("ScalarDivMode"); m.add_functor("InplaceScalarDiv"); m.add_functor("ScalarPow"); m.add_functor("ScalarReversePow"); m.add_functor("ScalarPowGrad"); m.add_functor("ScalarReversePowGrad"); m.add_functor("ReduceMax"); m.add_functor("Max"); m.add_functor("ReduceMean"); m.add_functor("ReduceMeanWhole"); m.add_functor("ReduceMin"); m.add_functor("Min"); m.add_functor("Amin"); m.add_functor("Median"); m.add_functor("MedianWithIndices"); m.add_functor("Mode"); m.add_functor("Amax"); m.add_functor("ReduceSum"); m.add_functor("ReduceSumWhole"); m.add_functor("ReduceNanSum"); m.add_functor("ReduceNanSumWhole"); m.add_functor("ReduceAll"); m.add_functor("ReduceAllWhole"); m.add_functor("ReduceAny"); m.add_functor("ReduceAnyWhole"); m.add_functor("ReduceProd"); m.add_functor("ReduceProdWhole"); m.add_functor("ReduceMinDeviceStage"); m.add_functor("ReduceMaxDeviceStage"); m.add_functor("ReduceMinGlobalStage"); m.add_functor("ReduceMaxGlobalStage"); m.add_functor("ReduceMinDeviceStageGrad"); m.add_functor("ReduceMaxDeviceStageGrad"); m.add_functor("ReduceMinGlobalStageGrad"); m.add_functor("ReduceMaxGlobalStageGrad"); m.add_functor("LogSumExp"); m.add_functor("LogAddExp"); m.add_functor("Quantile"); m.add_functor("ScalarQuantile"); m.add_functor("Transpose"); m.add_functor("Transpose2dim"); m.add_functor("Permute"); m.add_functor("AsStrided"); m.add_functor("AsStridedGrad"); m.add_functor("InplaceAsStrided"); m.add_functor("Swapaxes"); m.add_functor("Swapdims"); m.add_functor("Arange"); m.add_functor("GlobalArange"); m.add_functor("HannWindow"); m.add_functor("GlobalHannWindow"); m.add_functor("Cast"); m.add_functor("Clamp"); m.add_functor("ClampMin"); m.add_functor("ClampMax"); m.add_functor("ClampInplace"); m.add_functor("ClampMinInplace"); m.add_functor("ClampMaxInplace"); m.add_functor("Clip"); m.add_functor("ClipInplace"); m.add_functor("SqrtSquareSum"); m.add_functor("VectorNorm"); m.add_functor("MatrixNorm"); m.add_functor("Norm"); m.add_functor("ScalarNorm"); m.add_functor("ClampGrad"); m.add_functor("Select"); m.add_functor("SelectTopN"); m.add_functor("Minimum"); m.add_functor("Min"); m.add_functor("Maximum"); m.add_functor("Max"); m.add_functor("ScalarFMod"); m.add_functor("ScalarFloorDiv"); m.add_functor("ScalarTruncDiv"); m.add_functor("ScalarLogicalEqual"); m.add_functor( "ScalarLogicalNotEqual"); m.add_functor("ScalarLogicalGreater"); m.add_functor("InplaceScalarLogicalGreater"); m.add_functor( "ScalarLogicalGreaterEqual"); m.add_functor("ScalarLogicalLess"); m.add_functor( "ScalarLogicalLessEqual"); m.add_functor("ScalarLogicalAnd"); m.add_functor("ScalarLogicalOr"); m.add_functor("ScalarLogicalXor"); m.add_functor("ScalarLerp"); m.add_functor("ScalarInplaceLerp"); m.add_functor("ScalarLerpGrad"); m.add_functor("StandardDeviation"); m.add_functor("Variance"); m.add_functor("RMSLayerNormalization"); m.add_functor("Dot"); m.add_functor("MovedimVec"); m.add_functor("MovedimInt"); m.add_functor("TensorSplitVec"); m.add_functor("TensorSplitInt"); m.add_functor("HsplitInt"); m.add_functor("HsplitVec"); m.add_functor("VsplitInt"); m.add_functor("VsplitVec"); m.add_functor("Erfinv"); m.add_functor("ErfinvInplace"); m.add_functor("Cumsum"); m.add_functor("Cumprod"); m.add_functor("CumprodGrad"); m.add_functor("EinSum"); 
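// NOTE: illustrative shapes for the "EinSum" functor registered above, matching the
// "ik,jkl,il->ij" walkthrough in EinSumFunctor's comments (call form is a sketch; the exact
// generated signature is not shown in this file):
//   a: [2, 3] (i, k), b: [4, 3, 5] (j, k, l), c: [2, 5] (i, l)
//   functional::EinSum("ik,jkl,il->ij", {a, b, c}) -> out: [2, 4]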
m.add_functor("Inv"); m.add_functor("Det"); m.add_functor("GeluWithApproximate"); m.add_functor("Trunc"); m.add_functor("Stft"); m.add_functor("FftC2C"); m.add_functor("FftR2C"); m.add_functor("FftC2R"); m.add_functor("Fft"); m.add_functor("IFft"); m.add_functor("Fft2"); m.add_functor("IFft2"); m.add_functor("FftN"); m.add_functor("IFftN"); m.add_functor("RFft"); m.add_functor("IRFft"); m.add_functor("RFft2"); m.add_functor("IRFft2"); m.add_functor("RFftN"); m.add_functor("IRFftN"); m.add_functor("HFft"); m.add_functor("IHFft"); m.add_functor("HFft2"); m.add_functor("IHFft2"); m.add_functor("HFftN"); m.add_functor("IHFftN"); m.add_functor("FusedWeightedSum"); m.add_functor("FusedCenter"); m.add_functor("FusedCenterGrad"); m.add_functor("FusedGetBounddingBoxesCoord"); m.add_functor("FusedGetBounddingBoxesCoordGrad"); m.add_functor("FusedGetCiouDiagonalAngle"); m.add_functor("FusedGetCiouDiagonalAngleGrad"); m.add_functor("FusedGetCiouResult"); m.add_functor("FusedGetCiouResultGrad"); m.add_functor("FusedGetIntersectionArea"); m.add_functor("FusedGetIntersectionAreaGrad"); m.add_functor("FusedGetIou"); m.add_functor("FusedGetIouGrad"); m.add_functor("FusedGetConvexDiagonalSquared"); m.add_functor( "FusedGetConvexDiagonalSquaredGrad"); m.add_functor("ScalarBitwiseAnd"); m.add_functor("ScalarBitwiseOr"); m.add_functor("ScalarBitwiseXor"); m.add_functor("Real"); m.add_functor("RealGrad"); m.add_functor("Imag"); m.add_functor("ImagGrad"); m.add_functor("Conj"); m.add_functor("ConjPhysical"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/nn_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/random_mask_like_kernel.h" #include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/user/kernels/scaled_dot_product_attention_kernel.h" #include "oneflow/core/common/container_util.h" #include "fmt/core.h" namespace oneflow { namespace one { namespace functional { namespace impl { class BiasAddFunctor { public: BiasAddFunctor() { op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& bias, const int32_t& axis) const { int32_t axis_val = axis; if (axis_val < 0) { const int64_t num_axes = x->shape()->NumAxes(); axis_val += num_axes; } CHECK_LT_OR_RETURN(axis_val, x->shape()->NumAxes()) << Error::IndexError() << "Dimension out of range (expected to be in range of [-" << x->shape()->NumAxes() << "," << x->shape()->NumAxes() - 1 << "], but got " << axis_val << ")"; CHECK_EQ_OR_RETURN(x->shape()->At(axis_val), bias->shape()->At(0)) << Error::RuntimeError() << "The size of tensor x " << x->shape()->ToString() << " must match the size of tensor b " << bias->shape()->ToString() << " at dimension " << axis_val; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis_val); return OpInterpUtil::Dispatch(*op_, {x, bias}, attrs); } private: std::shared_ptr op_; }; class ConvBaseFunctor { public: explicit ConvBaseFunctor(const int& num_spatial_dims) : num_spatial_dims_(num_spatial_dims) { bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); enable_fused_conv_bias_ = ParseBooleanFromEnv("ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS", false); } virtual ~ConvBaseFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& weight, const Optional& bias, const std::vector& stride, const std::vector& padding, const std::vector& dilation, const int32_t& groups, const std::string& channel_pos) const { std::shared_ptr unsqueezed_input; bool is_batched = true; std::string func_name; if (num_spatial_dims_ == 1) { func_name = "conv1d"; } else if (num_spatial_dims_ == 2) { func_name = "conv2d"; } else { func_name = "conv3d"; } std::tie(unsqueezed_input, is_batched) = *JUST(batchify(input, num_spatial_dims_, func_name)); std::vector kernel_size_vec(num_spatial_dims_); int32_t channel_idx = 1; int32_t kernel_idx_offset = 2; if (channel_pos == "channels_last") { kernel_idx_offset = 1; channel_idx = kernel_idx_offset + num_spatial_dims_; } for (int i = 0; i < num_spatial_dims_; i++) { kernel_size_vec.at(i) = ((weight->shape())->At(i + kernel_idx_offset)); } auto& conv_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("filters", "kernel_size", "padding_before", "strides", "dilation_rate", "groups", "data_format"); conv_attrs.SetAllAttrs(static_cast(weight->shape()->At(0)), kernel_size_vec, padding, stride, dilation, groups, channel_pos); if (bias && enable_fused_conv_bias_) { return OpInterpUtil::Dispatch(*conv_bias_op_, 
{input, weight, JUST(bias)}, conv_attrs); } const std::shared_ptr& conv_out = JUST(OpInterpUtil::Dispatch(*conv_op_, {unsqueezed_input, weight}, conv_attrs)); std::shared_ptr squeezed_conv_output = conv_out; if (!is_batched) { squeezed_conv_output = JUST(functional::Squeeze(conv_out, std::vector{0})); channel_idx -= 1; } if (bias) { return functional::BiasAdd(squeezed_conv_output, JUST(bias), channel_idx); } else { return squeezed_conv_output; } } protected: std::shared_ptr conv_op_; std::shared_ptr bias_op_; std::shared_ptr conv_bias_op_; int32_t num_spatial_dims_; bool enable_fused_conv_bias_; }; class Conv1dFunctor : public ConvBaseFunctor { public: Conv1dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/1) { conv_op_ = CHECK_JUST(one::OpBuilder("conv1d").Input("in").Input("weight").Output("out").Build()); conv_bias_op_ = CHECK_JUST( one::OpBuilder("conv1d").Input("in").Input("weight").Input("bias").Output("out").Build()); } }; class Conv2dFunctor : public ConvBaseFunctor { public: Conv2dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/2) { conv_op_ = CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); conv_bias_op_ = CHECK_JUST( one::OpBuilder("conv2d").Input("in").Input("weight").Input("bias").Output("out").Build()); } }; class Conv3dFunctor : public ConvBaseFunctor { public: Conv3dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/3) { conv_op_ = CHECK_JUST(one::OpBuilder("conv3d").Input("in").Input("weight").Output("out").Build()); conv_bias_op_ = CHECK_JUST( one::OpBuilder("conv3d").Input("in").Input("weight").Input("bias").Output("out").Build()); } }; class DeConvBaseFunctor { public: explicit DeConvBaseFunctor(const int& num_spatial_dims) : num_spatial_dims_(num_spatial_dims) { bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); } virtual ~DeConvBaseFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& weight, const Optional& bias, const std::vector& stride, const std::vector& padding, const std::vector& output_padding, const int32_t& groups, const std::vector& dilation, const std::string& data_format) const { std::shared_ptr unsqueezed_input; bool is_batched = true; std::string func_name; if (num_spatial_dims_ == 1) { func_name = "deconv1d"; } else if (num_spatial_dims_ == 2) { func_name = "deconv2d"; } else { func_name = "deconv3d"; } std::tie(unsqueezed_input, is_batched) = *JUST(batchify(input, num_spatial_dims_, func_name)); int32_t channel_idx = 1; std::vector kernel_size_vec(num_spatial_dims_); int32_t kernel_idx_offset = 2; if (data_format == "channels_last") { kernel_idx_offset = 1; } for (int i = 0; i < num_spatial_dims_; i++) { kernel_size_vec[i] = ((weight->shape())->At(i + kernel_idx_offset)); } auto& deconv_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("filters", "kernel_size", "padding_before", "output_padding", "strides", "dilation_rate", "groups", "data_format"); deconv_attrs.SetAllAttrs(static_cast(weight->shape()->At(1) * groups), kernel_size_vec, padding, output_padding, stride, dilation, groups, data_format); std::shared_ptr deconv_out = JUST(OpInterpUtil::Dispatch(*deconv_op_, {unsqueezed_input, weight}, deconv_attrs)); std::shared_ptr squeezed_deconv_output = deconv_out; if (!is_batched) { squeezed_deconv_output = JUST(functional::Squeeze(deconv_out, std::vector{0})); channel_idx -= 1; } if (bias) { auto& bias_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); bias_attrs.SetAllAttrs(static_cast(channel_idx)); return OpInterpUtil::Dispatch(*bias_op_, 
{squeezed_deconv_output, JUST(bias)}, bias_attrs); } else { return squeezed_deconv_output; } } protected: std::shared_ptr deconv_op_; std::shared_ptr bias_op_; int32_t num_spatial_dims_; }; class DeConv1dFunctor : public DeConvBaseFunctor { public: DeConv1dFunctor() : DeConvBaseFunctor(/*num_spatial_dims_=*/1) { deconv_op_ = CHECK_JUST(one::OpBuilder("deconv1d").Input("in").Input("weight").Output("out").Build()); } }; class DeConv2dFunctor : public DeConvBaseFunctor { public: DeConv2dFunctor() : DeConvBaseFunctor(/*num_spatial_dims_=*/2) { deconv_op_ = CHECK_JUST(one::OpBuilder("deconv2d").Input("in").Input("weight").Output("out").Build()); } }; class DeConv3dFunctor : public DeConvBaseFunctor { public: DeConv3dFunctor() : DeConvBaseFunctor(/*num_spatial_dims_=*/3) { deconv_op_ = CHECK_JUST(one::OpBuilder("deconv3d").Input("in").Input("weight").Output("out").Build()); } }; class EmbeddingReNormFunctor { public: EmbeddingReNormFunctor() { op_ = CHECK_JUST( one::OpBuilder("embedding_renorm").Input("in").Input("indices").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& indices, const double& max_norm, const double& norm_type) const { CHECK_EQ_OR_RETURN(in->ndim(), 2) << Error::RuntimeError() << "The dimension of input should be 2."; std::shared_ptr outputs = std::make_shared(1); JUST(oneflow::VectorAt(*outputs, 0)) = in; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("max_norm", "norm_type"); attrs.SetAllAttrs(max_norm, norm_type); JUST(OpInterpUtil::Dispatch(*op_, {in, indices}, outputs.get(), attrs)); return JUST(oneflow::VectorAt(*outputs, 0)); } private: std::shared_ptr op_; }; class EmbeddingFunctor { public: EmbeddingFunctor() { op_ = CHECK_JUST( one::OpBuilder("embedding").Input("weight").Input("indices").Output("out").Build()); } Maybe operator()(const std::shared_ptr& weight, const std::shared_ptr& indices, const Optional& padding_idx, const bool& scale_grad_by_freq) const { CHECK_EQ_OR_RETURN(weight->ndim(), 2) << "The dimension of weight should be 2"; int64_t new_padding_idx = -1; if (padding_idx.has_value()) { new_padding_idx = JUST(padding_idx); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("padding_idx", "scale_grad_by_freq"); attrs.SetAllAttrs(new_padding_idx, scale_grad_by_freq); return OpInterpUtil::Dispatch(*op_, {weight, indices}, attrs); } private: std::shared_ptr op_; }; class MatMulNoBroadCastFunctor { public: Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& mat2) const { const auto& input_shape = input->shape(); const auto& mat2_shape = mat2->shape(); CHECK_EQ_OR_RETURN(input_shape->NumAxes(), 2) << Error::RuntimeError() << "self must be a matrix"; CHECK_EQ_OR_RETURN(mat2_shape->NumAxes(), 2) << Error::RuntimeError() << "mat2 must be a matrix"; CHECK_EQ_OR_RETURN(input_shape->at(1), mat2_shape->at(0)) << Error::RuntimeError() << "mat1 and mat2 shapes cannot be multiplied (" << std::to_string(input_shape->at(0)) << "x" << std::to_string(input_shape->at(1)) << " and " << std::to_string(mat2_shape->at(0)) << "x" << std::to_string(mat2_shape->at(1)) << ")"; return JUST(functional::MatMul(input, mat2, false, false, 1.0)); } }; class MatMulFunctor { public: MatMulFunctor() { matmul_op_ = CHECK_JUST(one::OpBuilder("matmul").Input("a").Input("b").Output("out").Build()); batch_matmul_op_ = CHECK_JUST(one::OpBuilder("batch_matmul").Input("a").Input("b").Output("out").Build()); bcast_matmul_op_ = CHECK_JUST(one::OpBuilder("broadcast_matmul").Input("a").Input("b").Output("out").Build()); } Maybe operator()(const 
std::shared_ptr& a, const std::shared_ptr& b, const bool& transpose_a, const bool& transpose_b, const double& alpha) const { const auto& a_shape = a->shape(); const auto& b_shape = b->shape(); CHECK_GE_OR_RETURN(a_shape->NumAxes(), 1) << Error::RuntimeError() << "Tensor a's dim should >= 1"; CHECK_GE_OR_RETURN(b_shape->NumAxes(), 1) << Error::RuntimeError() << "Tensor b's dim should >= 1"; DeviceType device_type{}; if (a->is_global()) { device_type = JUST(a->parallel_desc())->device_type(); } else { device_type = JUST(a->device())->enum_type(); } std::shared_ptr cast_a = a; std::shared_ptr cast_b = b; std::shared_ptr result; if ((cast_a->dtype()->is_integer()) && (device_type == DeviceType::kCPU)) { cast_a = JUST(functional::Cast(a, JUST(DType::Get(DataType::kFloat)), /*pin_memory=*/false)); cast_b = JUST(functional::Cast(b, JUST(DType::Get(DataType::kFloat)), /*pin_memory=*/false)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("transpose_a", "transpose_b", "alpha"); attrs.SetAllAttrs(transpose_a, transpose_b, alpha); const int64_t a_num_axes = a_shape->NumAxes(); const int64_t b_num_axes = b_shape->NumAxes(); if (a_num_axes == 1 && b_num_axes == 2) { result = JUST(VectorMatrixProduct(cast_a, cast_b)); } else if (a_num_axes == 2 && b_num_axes == 1) { result = JUST(MatrixVectorProduct(cast_a, cast_b)); } else if (a_num_axes == 2 && b_num_axes == 2) { result = JUST(OpInterpUtil::Dispatch(*matmul_op_, {cast_a, cast_b}, attrs)); } else if (a_num_axes == b_num_axes) { bool if_batch_matmul = true; for (int i = 0; i < a_num_axes - 2; ++i) { if (a_shape->At(i) != b_shape->At(i)) { if_batch_matmul = false; break; } } if (if_batch_matmul) { result = JUST(OpInterpUtil::Dispatch(*batch_matmul_op_, {cast_a, cast_b}, attrs)); } else { result = JUST(OpInterpUtil::Dispatch(*bcast_matmul_op_, {cast_a, cast_b}, attrs)); } } else { result = JUST(OpInterpUtil::Dispatch(*bcast_matmul_op_, {cast_a, cast_b}, attrs)); } if ((a->dtype()->is_integer()) && (device_type == DeviceType::kCPU)) { return JUST(functional::Cast(result, a->dtype(), /*pin_memory=*/false)); } else { return result; } } private: std::shared_ptr matmul_op_; std::shared_ptr batch_matmul_op_; std::shared_ptr bcast_matmul_op_; }; class BatchMatMulFunctor { public: BatchMatMulFunctor() { batch_matmul_op_ = CHECK_JUST(one::OpBuilder("batch_matmul").Input("a").Input("b").Output("out").Build()); } Maybe operator()(const std::shared_ptr& a, const std::shared_ptr& b, const bool& transpose_a, const bool& transpose_b, const double& alpha) const { const auto& a_shape = a->shape(); const auto& b_shape = b->shape(); CHECK_EQ_OR_RETURN(a_shape->NumAxes(), 3) << Error::RuntimeError() << "Expected 3-dimensional tensor, but got " << a_shape->NumAxes() << "-dimensional tensor for argument #1"; CHECK_EQ_OR_RETURN(b_shape->NumAxes(), 3) << Error::RuntimeError() << "Expected 3-dimensional tensor, but got " << b_shape->NumAxes() << "-dimensional tensor for argument #2"; CHECK_EQ_OR_RETURN(a_shape->At(0), b_shape->At(0)) << Error::RuntimeError() << "Batch dim not match, please check input!"; const int64_t matmul_dim_a = transpose_a ? a_shape->At(1) : a_shape->At(2); const int64_t matmul_dim_b = transpose_b ? 
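// --- Illustrative sketch (not part of the original source) -------------------
// MatMulFunctor above routes to one of several kernels purely from the ranks of
// the two operands and, for equal ranks above 2, from whether the leading batch
// dims match.  (Integer inputs on CPU are first cast to float and the result is
// cast back, as the functor body shows.)  The helper below reproduces just that
// decision table on plain shape vectors; it is a sketch, not the dispatcher.
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

std::string SelectMatmulOp(const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
  const auto an = a.size(), bn = b.size();
  if (an == 1 && bn == 2) { return "vector_matrix_product"; }
  if (an == 2 && bn == 1) { return "matrix_vector_product"; }
  if (an == 2 && bn == 2) { return "matmul"; }
  if (an == bn) {
    // Equal rank: use batch matmul only if every leading (batch) dim matches.
    for (std::size_t i = 0; i + 2 < an; ++i) {
      if (a[i] != b[i]) { return "broadcast_matmul"; }
    }
    return "batch_matmul";
  }
  return "broadcast_matmul";
}
// e.g. SelectMatmulOp({4, 8, 16}, {4, 16, 32}) -> "batch_matmul"
// -----------------------------------------------------------------------------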
b_shape->At(2) : b_shape->At(1); CHECK_EQ_OR_RETURN(matmul_dim_a, matmul_dim_b) << Error::RuntimeError() << "Matmul dim not match, got " << matmul_dim_a << " of mat1 and " << matmul_dim_b << " of mat2, please check input!"; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("transpose_a", "transpose_b", "alpha"); attrs.SetAllAttrs(transpose_a, transpose_b, alpha); DeviceType device_type{}; if (a->is_global()) { device_type = JUST(a->parallel_desc())->device_type(); } else { device_type = JUST(a->device())->enum_type(); } std::shared_ptr cast_a = a; std::shared_ptr cast_b = b; if ((a->dtype()->is_integer()) && (device_type == DeviceType::kCPU)) { cast_a = JUST(functional::Cast(a, JUST(DType::Get(DataType::kFloat)), /*pin_memory=*/false)); cast_b = JUST(functional::Cast(b, JUST(DType::Get(DataType::kFloat)), /*pin_memory=*/false)); } auto result = JUST(OpInterpUtil::Dispatch(*batch_matmul_op_, {cast_a, cast_b}, attrs)); if ((a->dtype()->is_integer()) && (device_type == DeviceType::kCPU)) { return JUST(functional::Cast(result, a->dtype(), /*pin_memory=*/false)); } else { return result; } } private: std::shared_ptr batch_matmul_op_; }; class VectorMatrixProductFunctor { public: VectorMatrixProductFunctor() { vector_matrix_product_op_ = CHECK_JUST( one::OpBuilder("vector_matrix_product").Input("a").Input("b").Output("out").Build()); } Maybe operator()(const std::shared_ptr& vec, const std::shared_ptr& input) const { const auto& vec_shape = vec->shape(); const auto& input_shape = input->shape(); CHECK_OR_RETURN(input_shape->NumAxes() == 2 && vec_shape->NumAxes() == 1) << Error::RuntimeError() << "vector @ matrix expected, got " << "1, " << input_shape->NumAxes() << ", " << vec_shape->NumAxes(); CHECK_EQ_OR_RETURN(vec_shape->at(0), input_shape->at(0)) << Error::RuntimeError() << "size mismatch, got " << 1 << ", " << std::to_string(vec_shape->at(0)) << " x " << std::to_string(input_shape->at(0)) << ", " << std::to_string(input_shape->at(1)); return OpInterpUtil::Dispatch(*vector_matrix_product_op_, {vec, input}); } private: std::shared_ptr vector_matrix_product_op_; }; class TensorDotIntDimsFunctor { public: Maybe operator()(const std::shared_ptr& a, const std::shared_ptr& b, const int32_t dims) const { CHECK_GE_OR_RETURN(dims, 0) << Error::RuntimeError() << "tensordot expects dims >= 0, but got dims=" << dims; CHECK_LE_OR_RETURN(dims, a->ndim()) << Error::RuntimeError() << "tensordot expects dims <= a.ndim which is " << a->ndim() << ", but got " << dims; CHECK_LE_OR_RETURN(dims, b->ndim()) << Error::RuntimeError() << "tensordot expects dims <= b.ndim which is " << b->ndim() << ", but got " << dims; std::vector dot_dims_a(dims), dot_dims_b(dims); for (int32_t i = 0; i < dims; i++) { dot_dims_a[i] = a->ndim() - dims + i; dot_dims_b[i] = i; } return JUST(functional::TensorDot(a, b, dot_dims_a, dot_dims_b)); } }; class TensorDotFunctor { public: Maybe operator()(const std::shared_ptr& a, const std::shared_ptr& b, const std::vector& dims_a, const std::vector& dims_b) const { // dims_a and dims_b represent dim indices to calculate dot, and are copied to variables // dot_dims_a and dot_dims_b when they need to be modified CHECK_EQ_OR_RETURN(dims_a.size(), dims_b.size()) << Error::RuntimeError() << "both dimension lists should have same length, got " << dims_a.size() << " and " << dims_b.size(); // dims_a.size() == dims_b.size(), and specially treat if both are empty if (dims_a.empty()) { DimVector shape_sum(a->ndim() + b->ndim()); for (int64_t i = 0; i < a->ndim(); i++) { shape_sum[i] = a->shape()->At(i); } 
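// --- Illustrative sketch (not part of the original source) -------------------
// TensorDotIntDimsFunctor above turns an integer `dims` into two explicit axis
// lists: the last `dims` axes of `a` are contracted against the first `dims`
// axes of `b`.  A minimal standalone version of that mapping:
#include <cstdint>
#include <utility>
#include <vector>

std::pair<std::vector<int32_t>, std::vector<int32_t>> TensordotAxes(int32_t a_ndim, int32_t dims) {
  std::vector<int32_t> dims_a(dims), dims_b(dims);
  for (int32_t i = 0; i < dims; ++i) {
    dims_a[i] = a_ndim - dims + i;  // trailing axes of a
    dims_b[i] = i;                  // leading axes of b
  }
  return {dims_a, dims_b};
}
// TensordotAxes(/*a_ndim=*/4, /*dims=*/2) -> ({2, 3}, {0, 1})
// -----------------------------------------------------------------------------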
for (int64_t i = 0; i < b->ndim(); i++) { shape_sum[i + a->ndim()] = b->shape()->At(i); } std::shared_ptr reshape_a = JUST(Reshape(a, Shape(DimVector{-1, 1}))); std::shared_ptr reshape_b = JUST(Reshape(b, Shape(DimVector{1, -1}))); return JUST(Reshape(JUST(functional::MatMul(reshape_a, reshape_b, false, false, 1.0)), Shape(DimVector(shape_sum.begin(), shape_sum.end())))); } std::vector dot_dims_a(dims_a.begin(), dims_a.end()); std::vector dot_dims_b(dims_b.begin(), dims_b.end()); for (int64_t i = 0; i < dot_dims_a.size(); i++) { dot_dims_a[i] = JUST(maybe_wrap_dim(dot_dims_a[i], a->ndim())); dot_dims_b[i] = JUST(maybe_wrap_dim(dot_dims_b[i], b->ndim())); } std::vector if_dot_dims_a(a->ndim(), false); std::vector if_dot_dims_b(b->ndim(), false); for (const int32_t dim_idx : dot_dims_a) { CHECK_EQ_OR_RETURN(if_dot_dims_a[dim_idx], false) << Error::RuntimeError() << "dim " << dim_idx << " appears multiple times in the list of dims"; if_dot_dims_a[dim_idx] = true; } for (const int32_t dim_idx : dot_dims_b) { CHECK_EQ_OR_RETURN(if_dot_dims_b[dim_idx], false) << Error::RuntimeError() << "dim " << dim_idx << " appears multiple times in the list of dims"; if_dot_dims_b[dim_idx] = true; } std::vector broadcast_dims_a, broadcast_dims_b; for (int64_t i = 0; i < dot_dims_a.size(); i++) { int64_t size_a = a->shape()->At(dot_dims_a[i]); int64_t size_b = b->shape()->At(dot_dims_b[i]); if (size_a == 1 && size_b > 1) { broadcast_dims_b.emplace_back(dot_dims_b[i]); } else if (size_b == 1 && size_a > 1) { broadcast_dims_a.emplace_back(dot_dims_a[i]); } else { CHECK_EQ_OR_RETURN(size_a, size_b) << Error::RuntimeError() << "contracted dimensions need to match, but first has size " << size_a << " in dim " << dot_dims_a[i] << " and second has size " << size_b << " in dim " << dot_dims_b[i]; } } // calculate ReduceSum for broadcasting of some axis std::shared_ptr reduced_sum_a = a; std::shared_ptr reduced_sum_b = b; if (!broadcast_dims_a.empty()) reduced_sum_a = JUST(functional::ReduceSum(a, broadcast_dims_a, true, NullOpt)); if (!broadcast_dims_b.empty()) reduced_sum_b = JUST(functional::ReduceSum(b, broadcast_dims_b, true, NullOpt)); // int64_t non_dot_size_a = 1, non_dot_size_b = 1; std::vector non_dot_shape_a, non_dot_shape_b; non_dot_shape_a.reserve(a->ndim() - dot_dims_a.size() + b->ndim() - dot_dims_b.size()); non_dot_shape_b.reserve(b->ndim() - dot_dims_b.size()); std::vector permuted_dims_a, permuted_dims_b; permuted_dims_a.reserve(a->ndim()); permuted_dims_b.reserve(b->ndim()); for (int32_t i = 0; i < a->ndim(); i++) { if (!if_dot_dims_a[i]) { permuted_dims_a.emplace_back(i); // non_dot_size_a *= reduced_sum_a->shape()->At(i); non_dot_shape_a.emplace_back(reduced_sum_a->shape()->At(i)); } } for (const int32_t dim_idx : dot_dims_a) permuted_dims_a.emplace_back(dim_idx); for (const int32_t dim_idx : dot_dims_b) permuted_dims_b.emplace_back(dim_idx); for (int32_t i = 0; i < b->ndim(); i++) { if (!if_dot_dims_b[i]) { permuted_dims_b.emplace_back(i); // non_dot_size_b *= reduced_sum_b->shape()->At(i); non_dot_shape_b.emplace_back(reduced_sum_b->shape()->At(i)); } } non_dot_shape_a.insert(non_dot_shape_a.end(), non_dot_shape_b.begin(), non_dot_shape_b.end()); int64_t dot_size = 1; for (const int32_t dim_idx : dot_dims_a) dot_size *= reduced_sum_a->shape()->At(dim_idx); std::shared_ptr permuted_a = JUST( Reshape(JUST(Permute(reduced_sum_a, permuted_dims_a)), Shape(DimVector({-1, dot_size})))); std::shared_ptr permuted_b = JUST( Reshape(JUST(Permute(reduced_sum_b, permuted_dims_b)), Shape(DimVector({dot_size, 
-1})))); return Reshape(JUST(functional::MatMul(permuted_a, permuted_b, false, false, 1.0)), Shape(DimVector({non_dot_shape_a.begin(), non_dot_shape_a.end()}))); } }; class FusedMLPFunctor { public: FusedMLPFunctor() { #if CUDA_VERSION >= 11060 fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 1; n < fused_op_.size(); ++n) { fused_op_[n] = CHECK_JUST(one::OpBuilder("cublas_fused_mlp") .Input("x") .Input("weights", n) .Input("biases", n) .Output("out") .Output("cublas_aux", n) .Output("hidden", n) .Build()); } #endif } Maybe operator()(const std::shared_ptr& x, const TensorTuple& weights, const TensorTuple& biases, bool skip_final_activation) const { const int64_t weight_size = weights.size(); const int64_t bias_size = biases.size(); CHECK_GE_OR_RETURN(weight_size, 1) << Error::RuntimeError() << "The number of weights should be greater equal than 1. "; CHECK_EQ_OR_RETURN(weight_size, bias_size) << Error::RuntimeError() << "The number of weights should be equal to biases. "; int64_t n = 0, k = 0; /* x: (m, k) weight: (n, k) need transpose bias: (n) */ const auto& x_shape = x->shape(); k = x_shape->At(1); for (int64_t i = 0; i < weight_size; i++) { const auto& weight_shape = weights[i]->shape(); const auto& bias_shape = biases[i]->shape(); // TODO(): Support Fused batch/broadcast matmul. CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2) << Error::RuntimeError() << "Weight's dim size should == 2"; CHECK_EQ_OR_RETURN(bias_shape->NumAxes(), 1) << Error::RuntimeError() << "Bias's dim size should == 1"; n = weight_shape->At(0); CHECK_EQ_OR_RETURN(bias_shape->At(0), n) << Error::RuntimeError() << "Bias's dim is not equal to weight's first dim. "; CHECK_EQ_OR_RETURN(weight_shape->At(1), k) << Error::RuntimeError() << "weight's second dim should be equal to input's second dim. "; // Set for next layer. k = n; } #if CUDA_VERSION >= 11060 DeviceType device_type{}; if (x->is_global()) { device_type = JUST(x->parallel_desc())->device_type(); } else { device_type = JUST(x->device())->enum_type(); } if ((device_type == DeviceType::kCUDA) && (weight_size <= kMaxInputCount) && (!ParseBooleanFromEnv("ONEFLOW_FUNCTOR_DISABLE_FUSED_MLP", false))) { TensorTuple input(2 * weight_size + 1); input[0] = x; std::copy(weights.begin(), weights.end(), input.begin() + 1); std::copy(biases.begin(), biases.end(), input.begin() + 1 + weight_size); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("skip_final_activation"); attrs.SetAllAttrs(skip_final_activation); return OpInterpUtil::Dispatch(*fused_op_[weight_size], input, attrs); } #endif // CUDA_VERSION >= 11060 // Fall back to Naive matmul + bias_add + relu std::shared_ptr out = x; for (int32_t layer_idx = 0; layer_idx < weight_size; layer_idx++) { out = JUST( functional::BiasAdd(JUST(functional::MatMul(out, weights[layer_idx], false, true, 1.0)), biases[layer_idx], 1)); if ((layer_idx != weight_size - 1) || (!skip_final_activation)) { /* When it is not last dense layer, or it is last dense layer and skip_final_activate=False, we add relu Layer. 
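// --- Illustrative sketch (not part of the original source) -------------------
// FusedMLPFunctor above checks that the layers chain: x is (m, k), every
// weight i is (n_i, k_i) with k_i equal to the previous output width, and every
// bias i has shape (n_i).  When the fused CUDA kernel is not usable it falls
// back to matmul(out, W_i, transpose_b=true) + bias_add (+ relu) per layer.
// A minimal shape-propagation sketch over hypothetical layer shapes:
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int64_t FusedMlpOutputWidth(int64_t k, const std::vector<std::pair<int64_t, int64_t>>& weights,
                            const std::vector<int64_t>& biases) {
  assert(weights.size() == biases.size());
  for (std::size_t i = 0; i < weights.size(); ++i) {
    const auto [n, k_in] = weights[i];  // weight i has shape (n, k_in), used transposed
    assert(k_in == k);                  // must match the current feature width
    assert(biases[i] == n);             // bias i has shape (n)
    k = n;                              // the next layer consumes n features
  }
  return k;
}
// FusedMlpOutputWidth(128, {{256, 128}, {64, 256}}, {256, 64}) -> 64
// -----------------------------------------------------------------------------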
*/ out = JUST(functional::Relu(out, false)); } } return out; } private: #if CUDA_VERSION >= 11060 std::vector> fused_op_; #endif }; class FusedMatmulBiasFunctor { public: FusedMatmulBiasFunctor() { _with_add_to_output_op = CHECK_JUST(one::OpBuilder("fused_matmul_bias") .Input("x") .Input("weight") .Input("bias") .Input("_add_to_output") .Output("out") .Build()); _without_add_to_output_op = CHECK_JUST(one::OpBuilder("fused_matmul_bias") .Input("x") .Input("weight") .Input("bias") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& weight, const std::shared_ptr& bias, const Optional& _add_to_output, const double& alpha, const double& beta) const { /* x: (m_i, ... m_0, k) weight: (n, k) need transpose bias: (n) */ const auto& x_shape = x->shape(); const int64_t k = x_shape->At(x->shape()->NumAxes() - 1); const auto& weight_shape = weight->shape(); const auto& bias_shape = bias->shape(); CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2) << Error::RuntimeError() << "Weight's dim size should == 2"; CHECK_EQ_OR_RETURN(bias_shape->NumAxes(), 1) << Error::RuntimeError() << "Bias's dim size should == 1"; const int64_t n = weight_shape->At(0); CHECK_EQ_OR_RETURN(bias_shape->At(0), n) << Error::RuntimeError() << "Bias's dim is not equal to weight's first dim. "; CHECK_EQ_OR_RETURN(weight_shape->At(1), k) << Error::RuntimeError() << "weight's second dim should be equal to input's second dim. "; #if CUDA_VERSION >= 11020 DeviceType device_type{}; if (x->is_global()) { device_type = JUST(x->parallel_desc())->device_type(); } else { device_type = JUST(x->device())->enum_type(); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha", "beta"); attrs.SetAllAttrs(alpha, beta); if (device_type == DeviceType::kCUDA) { if (_add_to_output) { return OpInterpUtil::Dispatch(*_with_add_to_output_op, {x, weight, bias, JUST(_add_to_output)}, attrs); } else { return OpInterpUtil::Dispatch(*_without_add_to_output_op, {x, weight, bias}, attrs); } } #endif // CUDA_VERSION >= 11020 auto matmul_bias = JUST(functional::BiasAdd( JUST(functional::MatMul(x, weight, false, true, alpha)), bias, x->shape()->NumAxes() - 1)); if (_add_to_output && beta != 0.0) { if (beta == 1.0) { return JUST(functional::Add({matmul_bias, JUST(_add_to_output)}, false)); } else { return JUST(functional::Add( {matmul_bias, JUST(functional::ScalarMul(JUST(_add_to_output), beta, false))}, false)); } } else { return matmul_bias; } } private: std::shared_ptr _with_add_to_output_op; std::shared_ptr _without_add_to_output_op; }; class FusedMatmulBiasAddReluDropoutFunctor { public: FusedMatmulBiasAddReluDropoutFunctor() { #if CUDA_VERSION >= 11060 fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 1; n < fused_op_.size(); ++n) { fused_op_[n] = CHECK_JUST(one::OpBuilder("fused_matmul_bias_add_relu_dropout") .Input("x") .Input("weights", n) .Input("biases", n) .Output("out") .Output("cublas_aux", n) .Output("hidden", n) .Build()); } #endif } Maybe operator()(const std::shared_ptr& x, const TensorTuple& weights, const TensorTuple& biases, bool skip_final_activation, const std::vector& dropout_rate_list, const Optional& generator) const { const int64_t weight_size = weights.size(); const int64_t bias_size = biases.size(); CHECK_GE_OR_RETURN(weight_size, 1) << Error::RuntimeError() << "The number of weights should be greater equal than 1. "; CHECK_EQ_OR_RETURN(weight_size, bias_size) << Error::RuntimeError() << "The number of weights should be equal to biases. 
"; CHECK_EQ_OR_RETURN(weight_size, dropout_rate_list.size()) << Error::RuntimeError() << "The dropout rate list length should be equal to the number of weights. "; int64_t n = 0, k = 0; /* x: (m, k) weight: (n, k) need transpose bias: (n) */ const auto& x_shape = x->shape(); k = x_shape->At(1); for (int64_t i = 0; i < weight_size; i++) { CHECK_GE_OR_RETURN(dropout_rate_list[i], 0.0f) << Error::RuntimeError() << "Dropout rate should be >= 0.0"; const auto& weight_shape = weights[i]->shape(); const auto& bias_shape = biases[i]->shape(); // TODO(): Support Fused batch/broadcast matmul. CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2) << "Weight's dim should == 2"; CHECK_EQ_OR_RETURN(bias_shape->NumAxes(), 1) << "Bias's dim should == 1"; n = weight_shape->At(0); CHECK_EQ_OR_RETURN(bias_shape->At(0), n) << "Bias's dim is not equal to weight's last dim. "; CHECK_EQ_OR_RETURN(weight_shape->At(1), k) << "weight's first dim should be equal to input's last dim. "; // Set for next layer. k = n; } auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); #if CUDA_VERSION >= 11060 DeviceType device_type{}; if (x->is_global()) { device_type = JUST(x->parallel_desc())->device_type(); } else { device_type = JUST(x->device())->enum_type(); } if ((device_type == DeviceType::kCUDA) && (weight_size <= kMaxInputCount) && (!ParseBooleanFromEnv("ONEFLOW_FUNCTOR_DISABLE_FUSED_MLP", false))) { TensorTuple input(2 * weight_size + 1); input[0] = x; std::copy(weights.begin(), weights.end(), input.begin() + 1); std::copy(biases.begin(), biases.end(), input.begin() + 1 + weight_size); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("skip_final_activation", "seed", "dropout_rate_list"); attrs.SetAllAttrs(skip_final_activation, static_cast(gen->current_seed()), dropout_rate_list); const auto& dropout_state = std::make_shared(gen); return OpInterpUtil::Dispatch(*fused_op_[weight_size], input, OpExprInterpContext(attrs, dropout_state)); } #endif // CUDA_VERSION >= 11060 // Fall back to Naive matmul + bias_add + relu + dropout std::shared_ptr out = x; for (int32_t layer_idx = 0; layer_idx < weight_size; layer_idx++) { out = JUST( functional::BiasAdd(JUST(functional::MatMul(out, weights[layer_idx], false, true, 1.0)), biases[layer_idx], 1)); if ((layer_idx != weight_size - 1) || !skip_final_activation) { out = JUST(functional::Relu(out, false)); out = JUST(functional::Dropout(out, JUST(VectorAt(dropout_rate_list, layer_idx)), /*training=*/true, /*inplace=*/false, /*generator=*/gen, /*addend=*/NullOpt)); } else { out = JUST(functional::Dropout(out, JUST(VectorAt(dropout_rate_list, layer_idx)), /*training=*/true, /*inplace=*/false, /*generator=*/gen, /*addend=*/NullOpt)); } } return out; } private: #if CUDA_VERSION >= 11060 std::vector> fused_op_; #endif }; class LayerNormFunctor { public: LayerNormFunctor() { op_ = CHECK_JUST(one::OpBuilder("layer_norm") .Input("x") .Output("y") .Output("mean") .Output("inv_variance") .Build()); } Maybe operator()(const std::shared_ptr& x, const int64_t& begin_norm_axis, const int64_t& begin_params_axis, const double& epsilon) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon", "center", "scale"); attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon, false, false); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class SkipLayerNormFunctor { public: SkipLayerNormFunctor() { std::vector bool_list = {true, false}; /* number 
of skip */ for (bool has_skip : bool_list) { /* has_gamma */ for (bool has_gamma : bool_list) { /* has_beta */ for (bool has_beta : bool_list) { /* has_bias */ for (bool has_bias : bool_list) { one::OpBuilder op_builder = one::OpBuilder("skip_layer_norm").Input("x"); if (has_gamma) { op_builder = op_builder.Input("gamma"); } if (has_beta) { op_builder = op_builder.Input("beta"); } if (has_bias) { op_builder = op_builder.Input("bias"); } if (has_skip) { op_builder = op_builder.Input("skip"); } op_builder = op_builder.Output("y").Output("mean").Output("inv_variance"); std::shared_ptr op_expr = CHECK_JUST(op_builder.Build()); ops_.insert(std::pair, std::shared_ptr>( std::tuple(has_skip, has_gamma, has_beta, has_bias), op_expr)); } // has_bias } // has_beta } // has_gamma } // has_skip } Maybe operator()(const std::shared_ptr& x, const Optional& gamma, const Optional& beta, const Optional& bias, const Optional& skip, const double& epsilon, const double& alpha) const { // check shape of x const auto& x_shape = *(x->shape()); CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2) << "number of axes of \'x\' should be greater than or equal to 2, yet get " << x_shape.NumAxes(); if (gamma) { const auto& gamma_shape = *(JUST(gamma)->shape()); CHECK_EQ_OR_RETURN(gamma_shape.NumAxes(), 1) << "number of axes of \'gamma\' should have be equal to 1, yet get " << gamma_shape.NumAxes(); CHECK_EQ_OR_RETURN(gamma_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'gamma\'(" << gamma_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } if (beta) { const auto& beta_shape = *(JUST(beta)->shape()); CHECK_EQ_OR_RETURN(beta_shape.NumAxes(), 1) << "number of axes of \'beta\' should have be equal to 1, yet get " << beta_shape.NumAxes(); CHECK_EQ_OR_RETURN(beta_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "dimension 1 of \'beta\'(" << beta_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } if (bias) { const auto& bias_shape = *(JUST(bias)->shape()); CHECK_EQ_OR_RETURN(bias_shape.NumAxes(), 1) << "number of axes of \'bias\' should have be equal to 1, yet get " << bias_shape.NumAxes(); CHECK_EQ_OR_RETURN(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "dimension 1 of \'bias\'(" << bias_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } if (skip) { const auto& skip_shape = *(JUST(skip)->shape()); CHECK_EQ_OR_RETURN(skip_shape, x_shape) << "shape of \'skip\' is not the same as \'x\'"; } // set attributes auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("epsilon", "alpha"); attrs.SetAllAttrs(epsilon, alpha); // count number of all input tensors size_t nb_inputs = 1; // count x if (skip) nb_inputs += 1; // count skip if (gamma) nb_inputs += 1; // count gamma if (beta) nb_inputs += 1; // count beta if (bias) nb_inputs += 1; // count bias // construct input tensor tuple size_t tensor_index = 1; TensorTuple input(nb_inputs); bool has_gamma = false, has_beta = false, has_bias = false, has_skip = false; input[0] = x; if (gamma) { input[tensor_index] = JUST(gamma); tensor_index += 1; has_gamma = true; } if (beta) { input[tensor_index] = JUST(beta); tensor_index += 1; has_beta = true; } if (bias) { input[tensor_index] = JUST(bias); tensor_index += 1; has_bias = true; } if (skip) { input[tensor_index] = JUST(skip); tensor_index += 1; has_skip = true; } return OpInterpUtil::Dispatch( *(ops_.find(std::tuple(has_skip, has_gamma, 
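// --- Illustrative sketch (not part of the original source) -------------------
// SkipLayerNormFunctor above pre-builds one op expression per combination of
// optional inputs and, at call time, packs the present tensors in the fixed
// order (x, gamma, beta, bias, skip) while looking the op up by the
// (has_skip, has_gamma, has_beta, has_bias) key.  A minimal analogue of that
// key/packing logic, with strings standing in for tensors:
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

std::pair<std::tuple<bool, bool, bool, bool>, std::vector<std::string>> PackSkipLayerNormInputs(
    const std::string& x, const std::optional<std::string>& gamma,
    const std::optional<std::string>& beta, const std::optional<std::string>& bias,
    const std::optional<std::string>& skip) {
  std::vector<std::string> inputs{x};
  if (gamma) { inputs.push_back(*gamma); }
  if (beta) { inputs.push_back(*beta); }
  if (bias) { inputs.push_back(*bias); }
  if (skip) { inputs.push_back(*skip); }
  return {std::make_tuple(skip.has_value(), gamma.has_value(), beta.has_value(), bias.has_value()),
          inputs};
}
// -----------------------------------------------------------------------------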
has_beta, has_bias)) ->second), input, attrs); } private: /* (nb_skip, has_gamma, has_beta, has_bias) -> op */ std::map, std::shared_ptr> ops_; }; class LayerNormAffineFunctor { public: LayerNormAffineFunctor() { op_ = CHECK_JUST(one::OpBuilder("layer_norm") .Input("x") .Input("gamma") .Input("beta") .Output("y") .Output("mean") .Output("inv_variance") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& gamma, const std::shared_ptr& beta, const int64_t& begin_norm_axis, const int64_t& begin_params_axis, const double& epsilon) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon", "center", "scale"); attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon, true, true); return OpInterpUtil::Dispatch(*op_, {x, gamma, beta}, attrs); } private: std::shared_ptr op_; }; class GroupNormFunctor { public: GroupNormFunctor() { op_ = CHECK_JUST(one::OpBuilder("group_norm") .Input("x") .Output("y") .Output("mean") .Output("inv_variance") .Attr("affine", false) .Build()); affine_op_ = CHECK_JUST(one::OpBuilder("group_norm") .Input("x") .Input("gamma") .Input("beta") .Output("y") .Output("mean") .Output("inv_variance") .Attr("affine", true) .Build()); } Maybe operator()(const std::shared_ptr& x, const Optional& gamma, const Optional& beta, const bool affine, const int32_t num_groups, const double& epsilon, const std::string& data_format, const std::string& activation) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_groups", "epsilon", "data_format", "activation"); attrs.SetAllAttrs(num_groups, epsilon, data_format, activation); if (affine) { return OpInterpUtil::Dispatch(*affine_op_, {x, JUST(gamma), JUST(beta)}, attrs); } else { return OpInterpUtil::Dispatch(*op_, {x}, attrs); } } private: std::shared_ptr op_; std::shared_ptr affine_op_; }; bool CheckNormShape(const Shape& x_shape, const Shape& normalized_shape) { if (x_shape.size() < normalized_shape.size()) { return false; } size_t b_ndim = x_shape.size() - normalized_shape.size(); for (int i = 0; i < x_shape.size(); ++i) { if (i >= b_ndim) { if (x_shape[i] != normalized_shape[i - b_ndim]) { return false; } } } return true; } class RMSNormFunctor { public: RMSNormFunctor() { op_ = CHECK_JUST(one::OpBuilder("rms_norm").Input("x").Output("y").Output("inv_rms").Build()); op_affine_ = CHECK_JUST(one::OpBuilder("rms_norm") .Input("x") .Input("weight") .Output("y") .Output("inv_rms") .Build()); } Maybe operator()(const std::shared_ptr& x, const Optional& weight, const Shape& normalized_shape, const float epsilon) const { const Shape& x_shape = *x->shape(); if (weight) { const Shape& w_shape = *JUST(weight)->shape(); CHECK_EQ_OR_RETURN(w_shape, normalized_shape) << "Expected weight be the same shape with normalized_shape " << normalized_shape.ToString() << ", but got " << w_shape.ToString(); } if (!CheckNormShape(x_shape, normalized_shape)) { auto shape_str_without_parentheses = x_shape.ToString().substr(1, x_shape.ToString().size() - 2); return Error::RuntimeError() << "Given normalized_shape=" << normalized_shape.ToString() << ", expected input with shape (*, " << shape_str_without_parentheses << "), but got input of " << x_shape.ToString(); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("normalized_shape", "epsilon"); attrs.SetAllAttrs(normalized_shape, epsilon); if (weight) { const DataType dtype = x->dtype()->data_type(); if (JUST(weight)->dtype()->data_type() != dtype) { auto weight_cast = JUST(functional::Cast(JUST(weight), DType{dtype}, /*pin_memory=*/false)); 
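// --- Illustrative sketch (not part of the original source) -------------------
// CheckNormShape above accepts x only when normalized_shape is a suffix of x's
// shape, which is what RMSNormFunctor relies on before dispatching.  A
// standalone version of the same check:
#include <cstddef>
#include <cstdint>
#include <vector>

bool IsNormalizedSuffix(const std::vector<int64_t>& x_shape,
                        const std::vector<int64_t>& normalized_shape) {
  if (x_shape.size() < normalized_shape.size()) { return false; }
  const std::size_t offset = x_shape.size() - normalized_shape.size();
  for (std::size_t i = 0; i < normalized_shape.size(); ++i) {
    if (x_shape[offset + i] != normalized_shape[i]) { return false; }
  }
  return true;
}
// IsNormalizedSuffix({8, 16, 768}, {768}) == true; IsNormalizedSuffix({8, 16, 768}, {16, 12}) == false
// -----------------------------------------------------------------------------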
return OpInterpUtil::Dispatch(*op_affine_, {x, weight_cast}, attrs); } return OpInterpUtil::Dispatch(*op_affine_, {x, JUST(weight)}, attrs); } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; std::shared_ptr op_affine_; }; class SkipRMSNormFunctor { public: SkipRMSNormFunctor() { std::vector bool_list = {true, false}; for (bool has_weight : bool_list) { for (bool has_skip : bool_list) { for (bool has_bias : bool_list) { one::OpBuilder op_builder = one::OpBuilder("skip_rms_norm").Input("x"); if (has_weight) { op_builder = op_builder.Input("weight"); } if (has_bias) { op_builder = op_builder.Input("bias"); } if (has_skip) { op_builder = op_builder.Input("skip"); } op_builder = op_builder.Output("y").Output("inv_rms"); std::shared_ptr op_expr = CHECK_JUST(op_builder.Build()); ops_.insert(std::pair, std::shared_ptr>( std::tuple(has_weight, has_skip, has_bias), op_expr)); } // has_bias } // has_skip } // has_weight } Maybe operator()(const std::shared_ptr& x, const Optional& weight, const Optional& bias, const Optional& skip, const double& epsilon, const double& alpha) const { // check shape of x const auto& x_shape = *(x->shape()); CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2) << "number of axes of \'x\' should be greater than or equal to 2, yet get " << x_shape.NumAxes(); if (weight) { const auto& weight_shape = *(JUST(weight)->shape()); CHECK_EQ_OR_RETURN(weight_shape.NumAxes(), 1) << "number of axes of \'weight\' should have be equal to 1, yet get " << weight_shape.NumAxes(); CHECK_EQ_OR_RETURN(weight_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "dimension 1 of \'weight\'(" << weight_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } if (bias) { const auto& bias_shape = *(JUST(bias)->shape()); CHECK_EQ_OR_RETURN(bias_shape.NumAxes(), 1) << "number of axes of \'bias\' should have be equal to 1, yet get " << bias_shape.NumAxes(); CHECK_EQ_OR_RETURN(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "dimension 1 of \'bias\'(" << bias_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } if (skip) { const auto& skip_shape = *(JUST(skip)->shape()); CHECK_EQ_OR_RETURN(skip_shape, x_shape) << "shape of \'skip\' is not the same as \'x\'"; } // set attributes auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("epsilon", "alpha"); attrs.SetAllAttrs(epsilon, alpha); // count number of all input tensors size_t nb_inputs = 1; // count x if (skip) nb_inputs += 1; // count skip if (weight) nb_inputs += 1; // count weight if (bias) nb_inputs += 1; // count bias // construct input tensor tuple size_t tensor_index = 1; TensorTuple input(nb_inputs); bool has_weight = false, has_bias = false, has_skip = false; input[0] = x; if (weight) { input[tensor_index] = JUST(weight); tensor_index += 1; has_weight = true; } if (bias) { input[tensor_index] = JUST(bias); tensor_index += 1; has_bias = true; } if (skip) { input[tensor_index] = JUST(skip); tensor_index += 1; has_skip = true; } return OpInterpUtil::Dispatch( *(ops_.find(std::tuple(has_weight, has_skip, has_bias))->second), input, attrs); } private: /* (has_weight, has_skip, has_bias) -> op */ std::map, std::shared_ptr> ops_; }; class PixelShuffleFunctor { public: PixelShuffleFunctor() {} Maybe operator()(const std::shared_ptr& x, const int64_t& h_upscale_factor, const int64_t& w_upscale_factor) const { CHECK_OR_RETURN(x->ndim() == 4) << Error::RuntimeError() << "Only Accept 4D Tensor"; const int64_t 
batch = x->shape()->At(0); const int64_t channel = x->shape()->At(1); const int64_t height = x->shape()->At(2); const int64_t width = x->shape()->At(3); std::shared_ptr out; CHECK_OR_RETURN(channel % (h_upscale_factor * w_upscale_factor) == 0) << Error::RuntimeError() << "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or " "(h_upscale_factor * w_upscale_factor)"; const int64_t new_c = static_cast(channel / (h_upscale_factor * w_upscale_factor)); std::vector permute_vec = {0, 1, 4, 2, 5, 3}; std::vector reshape_vec_1 = {batch, new_c, h_upscale_factor * w_upscale_factor, height, width}; Shape reshape_1(DimVector(reshape_vec_1.begin(), reshape_vec_1.end())); std::vector reshape_vec_2 = {batch, new_c, h_upscale_factor, w_upscale_factor, height, width}; Shape reshape_2(DimVector(reshape_vec_2.begin(), reshape_vec_2.end())); std::vector reshape_vec_3 = {batch, new_c, height * h_upscale_factor, width * w_upscale_factor}; Shape reshape_3(DimVector(reshape_vec_3.begin(), reshape_vec_3.end())); out = JUST(Reshape(x, reshape_1)); out = JUST(Reshape(out, reshape_2)); out = JUST(Permute(out, permute_vec)); out = JUST(Reshape(out, reshape_3)); return out; } }; class TFPoolNDFunctor { public: TFPoolNDFunctor() = default; virtual ~TFPoolNDFunctor() = default; Maybe operator()(const std::shared_ptr& x, const std::vector& kernel_size, const std::vector& strides, const std::string& padding, const std::vector& padding_before, const std::vector& padding_after, const std::string& data_format, const bool& ceil_mode) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("pool_size", "strides", "padding", "padding_before", "padding_after", "data_format", "ceil_mode"); attrs.SetAllAttrs(kernel_size, strides, padding, padding_before, padding_after, data_format, ceil_mode); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } protected: std::shared_ptr op_; }; class MaxPoolNDFunctor { public: explicit MaxPoolNDFunctor(const int& num_spatial_dims) : num_spatial_dims_(num_spatial_dims) {} virtual ~MaxPoolNDFunctor() = default; Maybe operator()(const std::shared_ptr& input, const std::vector& kernel_size, const Optional>& stride, const std::vector& padding, const std::vector& dilation, const bool& return_indices, const bool& ceil_mode, const std::string& data_format) const { // channels_last case if (input->is_cuda() && num_spatial_dims_ == 2 && data_format == "channels_last") { if (!return_indices && dilation.at(0) == 1 && dilation.at(1) == 1) { // legacy tf style maxpool2d , use cudnn implementation // with high performance but do not support dilation/return_indices std::vector padding_before{padding.at(0), padding.at(1)}; std::vector padding_after{padding.at(0), padding.at(1)}; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("pool_size", "strides", "padding", "padding_before", "padding_after", "data_format", "ceil_mode"); attrs.SetAllAttrs(kernel_size, stride ? 
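// --- Illustrative sketch (not part of the original source) -------------------
// PixelShuffleFunctor above implements pixel shuffle as
// reshape -> reshape -> permute(0, 1, 4, 2, 5, 3) -> reshape.  The net effect
// on an NCHW tensor is (B, C, H, W) -> (B, C / (r_h * r_w), H * r_h, W * r_w).
// A standalone shape calculator for that transformation (example shapes only):
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> PixelShuffleOutShape(const std::vector<int64_t>& in, int64_t r_h,
                                          int64_t r_w) {
  assert(in.size() == 4);            // only 4-D input is accepted
  assert(in[1] % (r_h * r_w) == 0);  // channels must be divisible by the factors
  return {in[0], in[1] / (r_h * r_w), in[2] * r_h, in[3] * r_w};
}
// PixelShuffleOutShape({1, 16, 7, 7}, 2, 2) -> {1, 4, 14, 14}
// -----------------------------------------------------------------------------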
*JUST(stride) : kernel_size, std::string("customized"), padding_before, padding_after, data_format, ceil_mode); TensorTuple output; output.emplace_back(JUST(OpInterpUtil::Dispatch(*tf_maxpool_op_, {input}, attrs))); return output; } } std::shared_ptr unsqueezed_input; bool is_batched = true; std::string func_name; if (num_spatial_dims_ == 1) { func_name = "max_pool1d"; } else if (num_spatial_dims_ == 2) { func_name = "max_pool2d"; } else { func_name = "max_pool3d"; } std::tie(unsqueezed_input, is_batched) = *JUST(batchify(input, num_spatial_dims_, func_name)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("kernel_size", "padding", "stride", "dilation", "data_format", "return_indices", "ceil_mode"); // If stride is None, we set it as kernel_size to align Pytorch. attrs.SetAllAttrs(kernel_size, padding, stride ? *JUST(stride) : kernel_size, dilation, data_format, return_indices, ceil_mode); const auto& pooling_out = JUST(OpInterpUtil::Dispatch(*op_, {unsqueezed_input}, attrs)); if (!is_batched) { TensorTuple squeezed_pooling_out; // (y,indices) squeezed_pooling_out.emplace_back( JUST(functional::Squeeze(pooling_out->at(0), std::vector{0}))); squeezed_pooling_out.emplace_back( JUST(functional::Squeeze(pooling_out->at(1), std::vector{0}))); return squeezed_pooling_out; } return pooling_out; } protected: int32_t num_spatial_dims_; std::shared_ptr op_; std::shared_ptr tf_maxpool_op_; }; class AvgPoolNDFunctor { public: AvgPoolNDFunctor() = default; virtual ~AvgPoolNDFunctor() = default; Maybe operator()(const std::shared_ptr& x, const std::vector& kernel_size, const Optional>& stride, const std::vector& padding, const bool& ceil_mode, const bool& count_include_pad, const int32_t& divisor_override, const std::string& data_format) const { // legacy tf style avgpool2d , use cudnn implementation with high performance but not support // count_include_pad and divisor_override. if (x->is_cuda() && x->ndim() == 4 && data_format == "channels_last") { CHECK_OR_THROW(count_include_pad) << "AvgPool2d with channels_last data format don't support count_include_pad for now."; CHECK_EQ_OR_THROW(divisor_override, 0) << "AvgPool2d with channels_last data format don't support divisor_override for now."; std::vector padding_before{JUST(VectorAt(padding, 0)), JUST(VectorAt(padding, 1))}; std::vector padding_after{JUST(VectorAt(padding, 0)), JUST(VectorAt(padding, 1))}; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("pool_size", "strides", "padding", "padding_before", "padding_after", "data_format", "ceil_mode"); attrs.SetAllAttrs(kernel_size, stride ? *JUST(stride) : kernel_size, std::string("customized"), padding_before, padding_after, data_format, ceil_mode); return JUST(OpInterpUtil::Dispatch(*tf_avgpool_op_, {x}, attrs)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("kernel_size", "padding", "stride", "data_format", "ceil_mode", "count_include_pad", "divisor_override"); // If stride is None, we set it as kernel_size to align Pytorch. attrs.SetAllAttrs(kernel_size, padding, stride ? 
*JUST(stride) : kernel_size, data_format, ceil_mode, count_include_pad, divisor_override); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } protected: std::shared_ptr op_; std::shared_ptr tf_avgpool_op_; }; class TFAvgPool2DFunctor : public TFPoolNDFunctor { public: TFAvgPool2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("tf_avg_pool_2d").Input("x").Output("y").Build()); } }; class MaxPool1DFunctor : public MaxPoolNDFunctor { public: MaxPool1DFunctor() : MaxPoolNDFunctor(/*num_spatial_dims_=*/1) { op_ = CHECK_JUST(one::OpBuilder("max_pool_1d").Input("x").Output("y").Output("indice").Build()); } }; class MaxPool2DFunctor : public MaxPoolNDFunctor { public: MaxPool2DFunctor() : MaxPoolNDFunctor(/*num_spatial_dims_=*/2) { op_ = CHECK_JUST(one::OpBuilder("max_pool_2d").Input("x").Output("y").Output("indice").Build()); tf_maxpool_op_ = CHECK_JUST(one::OpBuilder("tf_max_pool_2d").Input("x").Output("y").Build()); } }; class MaxPool3DFunctor : public MaxPoolNDFunctor { public: MaxPool3DFunctor() : MaxPoolNDFunctor(/*num_spatial_dims_=*/3) { op_ = CHECK_JUST(one::OpBuilder("max_pool_3d").Input("x").Output("y").Output("indice").Build()); } }; class AvgPool1DFunctor : public AvgPoolNDFunctor { public: AvgPool1DFunctor() { op_ = CHECK_JUST(one::OpBuilder("avg_pool_1d").Input("x").Output("y").Build()); } }; class AvgPool2DFunctor : public AvgPoolNDFunctor { public: AvgPool2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("avg_pool_2d").Input("x").Output("y").Build()); tf_avgpool_op_ = CHECK_JUST(one::OpBuilder("tf_avg_pool_2d").Input("x").Output("y").Build()); } }; class AvgPool3DFunctor : public AvgPoolNDFunctor { public: AvgPool3DFunctor() { op_ = CHECK_JUST(one::OpBuilder("avg_pool_3d").Input("x").Output("y").Build()); } }; template class MaxUnpoolNDFunctor { public: MaxUnpoolNDFunctor() : op_(CHECK_JUST(one::OpBuilder(fmt::format("max_unpool_{}d", N)) .Input("x") .Input("indices") .Output("y") .Build())){}; Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& indices, const std::vector& kernel_size, const Optional>& stride, const std::vector& padding, const Optional& output_size) const { const auto fmt_error_msg = [](const std::string& name, int32_t num, bool check_element) { if (check_element) { return fmt::format("each element in `{}` must be greater than 0, got {}", name, num); } return fmt::format("`{}` must be an integer or a list of {} integers", name, N); }; CHECK_EQ_OR_RETURN(kernel_size.size(), N) << fmt_error_msg("kernel_size", N, false); for (int32_t pool_dim : kernel_size) { CHECK_GT_OR_RETURN(pool_dim, 0) << fmt_error_msg("kernel_size", pool_dim, true); } if (stride) { CHECK_EQ_OR_RETURN(JUST(stride)->size(), N) << fmt_error_msg("stride", N, false); for (int32_t stride_dim : *JUST(stride)) { CHECK_GT_OR_RETURN(stride_dim, 0) << fmt_error_msg("stride", stride_dim, true); } } for (int32_t i = 0; i < padding.size(); i++) { CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i]) << "pad should be smaller than half of kernel size"; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("kernel_size", "padding", "stride", "has_output_size", "output_size"); attrs.SetAllAttrs(kernel_size, padding, stride ? *JUST(stride) : kernel_size, output_size.has_value(), output_size.has_value() ? 
*JUST(output_size) : Shape()); return OpInterpUtil::Dispatch(*op_, {x, indices}, attrs); } protected: std::shared_ptr op_; }; class AdaptivePoolNDFunctor { public: AdaptivePoolNDFunctor() = default; virtual ~AdaptivePoolNDFunctor() = default; Maybe operator()(const std::shared_ptr& x, const std::vector& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("output_size", "data_format"); attrs.SetAllAttrs(output_size, data_format); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } protected: std::shared_ptr op_; }; class AdaptiveAvgPool1DFunctor : public AdaptivePoolNDFunctor { public: AdaptiveAvgPool1DFunctor() { op_ = CHECK_JUST(one::OpBuilder("adaptive_avg_pool1d").Input("x").Output("y").Build()); } }; class AdaptiveAvgPool2DFunctor : public AdaptivePoolNDFunctor { public: AdaptiveAvgPool2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("adaptive_avg_pool2d").Input("x").Output("y").Build()); } }; class AdaptiveAvgPool3DFunctor : public AdaptivePoolNDFunctor { public: AdaptiveAvgPool3DFunctor() { op_ = CHECK_JUST(one::OpBuilder("adaptive_avg_pool3d").Input("x").Output("y").Build()); } }; class AdaptiveMaxPoolBaseFunctor { public: AdaptiveMaxPoolBaseFunctor() = default; virtual ~AdaptiveMaxPoolBaseFunctor() = default; Maybe operator()(const std::shared_ptr& x, const std::vector& output_size, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("output_size", "data_format"); attrs.SetAllAttrs(output_size, data_format); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } protected: std::shared_ptr op_; }; class AdaptiveMaxPool1DFunctor : public AdaptiveMaxPoolBaseFunctor { public: AdaptiveMaxPool1DFunctor() { op_ = CHECK_JUST( one::OpBuilder("adaptive_max_pool1d").Input("x").Output("y").Output("index").Build()); } }; class AdaptiveMaxPool2DFunctor : public AdaptiveMaxPoolBaseFunctor { public: AdaptiveMaxPool2DFunctor() { op_ = CHECK_JUST( one::OpBuilder("adaptive_max_pool2d").Input("x").Output("y").Output("index").Build()); } }; class AdaptiveMaxPool3DFunctor : public AdaptiveMaxPoolBaseFunctor { public: AdaptiveMaxPool3DFunctor() { op_ = CHECK_JUST( one::OpBuilder("adaptive_max_pool3d").Input("x").Output("y").Output("index").Build()); } }; class LossFunctorBase { public: Maybe apply_reduction(const Maybe& x, const std::string& reduction) const { CHECK_OR_RETURN(reduction == "none" || reduction == "sum" || reduction == "mean") << Error::RuntimeError() << "Reduction should be none, sum or mean."; if (reduction == "sum") { return functional::ReduceSum(JUST(x), {}, false, NullOpt); } if (reduction == "mean") { return functional::ReduceMean(JUST(x), {}, false); } return x; } protected: LossFunctorBase() = default; virtual ~LossFunctorBase() = default; }; class MseLossFunctor : public LossFunctorBase { public: MseLossFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const std::string& reduction) const { const auto out = sequence_function(functional::Sub) .then(functional::Square) .call(input, target, /*alpha=*/1.0, /*inplace=*/false); return apply_reduction(out, reduction); } }; class L1LossFunctor : public LossFunctorBase { public: L1LossFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const std::string& reduction) const { const auto out = sequence_function(functional::Sub) .then(functional::Abs) .call(input, target, /*alpha=*/1.0, /*inplace=*/false); return apply_reduction(out, reduction); } }; class SmoothL1LossFunctor : LossFunctorBase { 
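// --- Illustrative sketch (not part of the original source) -------------------
// LossFunctorBase::apply_reduction above accepts exactly "none", "sum" or
// "mean".  The scalar analogue below shows the same contract on a flat vector
// of per-element losses; KLDivLoss (further down) additionally realizes
// "batchmean" as sum / batch_size on top of the "sum" path.
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

std::vector<double> ReduceLoss(const std::vector<double>& elementwise,
                               const std::string& reduction) {
  if (reduction == "none") { return elementwise; }  // keep the elementwise losses
  const double sum = std::accumulate(elementwise.begin(), elementwise.end(), 0.0);
  if (reduction == "sum") { return {sum}; }
  if (reduction == "mean") { return {sum / static_cast<double>(elementwise.size())}; }
  throw std::invalid_argument("Reduction should be none, sum or mean.");
}
// -----------------------------------------------------------------------------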
public: SmoothL1LossFunctor() { op_ = CHECK_JUST( one::OpBuilder("smooth_l1_loss").Input("input").Input("target").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const float& beta, const std::string& reduction) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("beta"); attrs.SetAllAttrs(beta); return apply_reduction(OpInterpUtil::Dispatch(*op_, {input, target}, attrs), reduction); } private: std::shared_ptr op_; }; class KLDivLossFunctor : public LossFunctorBase { public: KLDivLossFunctor() { op_ = CHECK_JUST( one::OpBuilder("kl_div_loss").Input("input").Input("target").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const bool log_target, const std::string& reduction) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("log_target"); attrs.SetAllAttrs(log_target); if (reduction == "batchmean" && input->ndim() != 0) { const auto& result = JUST( apply_reduction(OpInterpUtil::Dispatch(*op_, {input, target}, attrs), "sum")); return ScalarDiv(result, input->shape()->At(0)); } else { return apply_reduction(OpInterpUtil::Dispatch(*op_, {input, target}, attrs), reduction); } } private: std::shared_ptr op_; }; class MarginRankingLossFunctor : public LossFunctorBase { public: Maybe operator()(const std::shared_ptr& input_1, const std::shared_ptr& input_2, const std::shared_ptr& target, const float margin, const std::string& reduction) const { const auto out = sequence_function(functional::Sub) .then(functional::Negative) .then(std::bind(functional::Mul, target, std::placeholders::_1)) .then([&margin](const std::shared_ptr& x) { return functional::ScalarAdd(x, Scalar(margin), /*alpha=*/1, /*inplace=*/true); }) .then(std::bind(functional::Clamp, std::placeholders::_1, Scalar(0), NullOpt)) .call(input_1, input_2, /*alpha=*/1.0, /*inplace=*/false); return apply_reduction(out, reduction); } }; class BinaryCrossEntropyLossFunctor : public LossFunctorBase { public: BinaryCrossEntropyLossFunctor() { op_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy") .Input("input") .Input("target") .Output("out") .Build()); op_weight_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy") .Input("input") .Input("target") .Input("weight") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const std::string& reduction) const { auto out = weight ? 
OpInterpUtil::Dispatch(*op_weight_, {input, target, JUST(weight)}) : OpInterpUtil::Dispatch(*op_, {input, target}); return apply_reduction(out, reduction); } private: std::shared_ptr op_; std::shared_ptr op_weight_; }; class BinaryCrossEntropyWithLogitsLossFunctor : public LossFunctorBase { public: BinaryCrossEntropyWithLogitsLossFunctor() { op_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits") .Input("input") .Input("target") .Output("out") .Build()); op_weight_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits") .Input("input") .Input("target") .Input("weight") .Output("out") .Build()); op_pos_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits") .Input("input") .Input("target") .Input("pos_weight") .Output("out") .Build()); op_weight_pos_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits") .Input("input") .Input("target") .Input("weight") .Input("pos_weight") .Output("out") .Build()); op_reduce_mean_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits_reduce_mean") .Input("input") .Input("target") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const Optional& pos_weight, const std::string& reduction) const { if (pos_weight) { const auto pos_weight_shape = JUST(pos_weight)->shape(); // pos weight shape = (), (1,), (1,1)... or (input/target.shape[-1],) const bool is_pos_weight_shape_valid = (pos_weight_shape->elem_cnt() == 1) || (pos_weight_shape->NumAxes() == 1 && pos_weight_shape->At(0) == target->shape()->back()); CHECK_OR_RETURN(is_pos_weight_shape_valid) << Error::RuntimeError() << "pos_weight must be a vector with length equal to the number of classes."; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("has_pos_weight"); attrs.SetAllAttrs(pos_weight.has_value()); std::shared_ptr out; if (weight) { if (pos_weight) { out = JUST(OpInterpUtil::Dispatch( *op_weight_pos_, {input, target, JUST(weight), JUST(pos_weight)}, attrs)); } else { out = JUST(OpInterpUtil::Dispatch(*op_weight_, {input, target, JUST(weight)}, attrs)); } } else { if (pos_weight) { out = JUST( OpInterpUtil::Dispatch(*op_pos_, {input, target, JUST(pos_weight)}, attrs)); } else { if (reduction == "mean") { return OpInterpUtil::Dispatch(*op_reduce_mean_, {input, target}); } out = JUST(OpInterpUtil::Dispatch(*op_, {input, target}, attrs)); } } return apply_reduction(out, reduction); } private: std::shared_ptr op_; std::shared_ptr op_weight_; std::shared_ptr op_pos_; std::shared_ptr op_weight_pos_; std::shared_ptr op_reduce_mean_; }; class NLLLossFunctor { public: NLLLossFunctor() { op_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") .Output("output") .Output("out_weight") .Build()); op_weight_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") .Input("weight") .Output("output") .Output("out_weight") .Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const int64_t& ignore_index, const std::string& reduction) const { CHECK_OR_RETURN(reduction == "none" || reduction == "sum" || reduction == "mean") << Error::RuntimeError() << "Reduction should be none, sum or mean."; const auto& input_shape = input->shape(); const int64_t K = input_shape->NumAxes(); CHECK_GE_OR_RETURN(K, 2) << Error::RuntimeError() << "Expected 2 or more dimensions"; const int64_t N = input_shape->At(0); const int64_t C = input_shape->At(1); const auto& target_shape = target->shape(); 
CHECK_EQ_OR_RETURN(target_shape->NumAxes(), K - 1) << Error::RuntimeError() << "Expected target dimensions (" << K - 1 << ") to match input dimensions (" << K << "), got " << target_shape->NumAxes(); CHECK_EQ_OR_RETURN(target_shape->At(0), N) << Error::RuntimeError() << "Expected input batch_size (" << N << ") to match target batch_size (" << target_shape->At(0) << ")"; std::shared_ptr input_; std::shared_ptr target_; if (K > 2) { DimVector idea_target_dim_vec; idea_target_dim_vec.push_back(N); for (int64_t i = 2; i < K; ++i) { idea_target_dim_vec.push_back(input_shape->At(i)); } Shape idea_target_shape(idea_target_dim_vec); CHECK_EQ_OR_RETURN(*target_shape, idea_target_shape) << Error::RuntimeError() << "Expected target shape " << idea_target_shape.ToString() << ", got " << target_shape->ToString(); std::vector perm(input_shape->dim_vec().size(), 0); perm[perm.size() - 1] = 1; for (size_t i = 1; i < perm.size() - 1; ++i) { perm[i] = i + 1; } input_ = JUST(sequence_function(functional::Transpose) .then(std::bind(functional::Reshape, std::placeholders::_1, Shape({-1, C}))) .call(input, perm)); target_ = JUST(functional::Flatten(target, 0, K - 2)); } else { input_ = input; target_ = target; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index"); attrs.SetAllAttrs(ignore_index); std::shared_ptr nll_result; if (weight) { nll_result = JUST( OpInterpUtil::Dispatch(*op_weight_, {input_, target_, JUST(weight)}, attrs)); } else { nll_result = JUST(OpInterpUtil::Dispatch(*op_, {input_, target_}, attrs)); } auto output = JUST(VectorAt(*nll_result, 0)); if (K > 2) { output = JUST(functional::Reshape(output, *target_shape)); } if (reduction == "none") { return output; } auto sum = JUST(functional::ReduceSum(output, {}, false, NullOpt)); if (reduction == "sum") { return sum; } auto total_weight = JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false, NullOpt)); return functional::Div(sum, total_weight); } private: std::shared_ptr op_; std::shared_ptr op_weight_; }; class CrossEntropyFunctor { public: CrossEntropyFunctor() { op_log_softmax_ = CHECK_JUST(one::OpBuilder("log_softmax").Input("in").Output("prob").Build()); op_nll_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") .Output("output") .Output("out_weight") .Build()); op_nll_weight_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") .Input("weight") .Output("output") .Output("out_weight") .Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const int64_t& ignore_index, const std::string& reduction, const double& label_smoothing) const { if (input->shape() == target->shape()) { CHECK_OR_RETURN(target->dtype()->is_floating_point()) << "Expected floating point type for target with class probabilities, got " << target->dtype()->name(); CHECK_LT_OR_RETURN(ignore_index, 0) << "ignore_index is not supported for floating point targe"; return CrossEntropyProb(input, target, weight, reduction, label_smoothing); } if (label_smoothing > 0.0) return CrossEntropyLabelSmoothing(input, target, weight, ignore_index, reduction, label_smoothing); CHECK_OR_RETURN(reduction == "none" || reduction == "sum" || reduction == "mean") << Error::RuntimeError() << "Reduction should be none, sum or mean."; const auto& input_shape = input->shape(); const auto& target_shape = target->shape(); std::vector input_perm(input_shape->dim_vec().size(), 0); input_perm[input_perm.size() - 1] = 1; for (size_t i = 1; i < input_perm.size() - 1; ++i) { input_perm[i] = i + 
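// --- Illustrative sketch (not part of the original source) -------------------
// NLLLossFunctor above flattens any extra spatial axes, lets the "nll" op
// produce per-element losses plus per-element weights, and realizes
// reduction="mean" as sum(loss) / sum(weight).  The sketch below follows the
// usual NLL convention (loss_i = -w[target_i] * log_prob[i][target_i], entries
// equal to ignore_index contribute nothing); it is a reference calculation,
// not the kernel.
#include <cstddef>
#include <cstdint>
#include <vector>

double NllMean(const std::vector<std::vector<double>>& log_prob,
               const std::vector<int64_t>& target, const std::vector<double>& class_weight,
               int64_t ignore_index) {
  double loss_sum = 0.0, weight_sum = 0.0;
  for (std::size_t i = 0; i < target.size(); ++i) {
    if (target[i] == ignore_index) { continue; }
    const double w = class_weight.empty() ? 1.0 : class_weight[target[i]];
    loss_sum += -w * log_prob[i][target[i]];
    weight_sum += w;
  }
  return loss_sum / weight_sum;
}
// -----------------------------------------------------------------------------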
1; } const auto input_ = JUST(sequence_function(functional::Transpose) .then(std::bind(functional::Reshape, std::placeholders::_1, Shape({-1, input_shape->At(1)}))) .then([this](const std::shared_ptr& x) { return OpInterpUtil::Dispatch(*op_log_softmax_, {x}); }) .call(input, input_perm)); const auto target_ = JUST(functional::Flatten(target, 0, target->shape()->NumAxes() - 1)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index"); attrs.SetAllAttrs(ignore_index); std::shared_ptr nll_result; if (weight) { nll_result = JUST(OpInterpUtil::Dispatch( *op_nll_weight_, {input_, target_, JUST(weight)}, attrs)); } else { nll_result = JUST(OpInterpUtil::Dispatch(*op_nll_, {input_, target_}, attrs)); } auto output = JUST(VectorAt(*nll_result, 0)); output = JUST(functional::Reshape(output, *target_shape)); if (reduction == "none") { return output; } auto sum = JUST(functional::ReduceSum(output, {}, false, NullOpt)); if (reduction == "sum") { return sum; } auto total_weight = JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false, NullOpt)); return functional::Div(sum, total_weight); } private: std::shared_ptr op_log_softmax_; std::shared_ptr op_nll_; std::shared_ptr op_nll_weight_; }; class CrossEntropyLabelSmoothingFunctor { public: CrossEntropyLabelSmoothingFunctor() { op_log_softmax_ = CHECK_JUST(one::OpBuilder("log_softmax").Input("in").Output("prob").Build()); op_nll_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") .Output("output") .Output("out_weight") .Build()); op_nll_weight_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") .Input("weight") .Output("output") .Output("out_weight") .Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const int64_t& ignore_index, const std::string& reduction, const double& label_smoothing) const { CHECK_OR_RETURN(reduction == "none" || reduction == "sum" || reduction == "mean") << Error::RuntimeError() << "Reduction should be none, sum or mean."; const auto& input_shape = input->shape(); const auto& target_shape = target->shape(); std::vector input_perm(input_shape->dim_vec().size(), 0); input_perm[input_perm.size() - 1] = 1; for (size_t i = 1; i < input_perm.size() - 1; ++i) { input_perm[i] = i + 1; } CHECK_OR_RETURN(label_smoothing > 0.0 && label_smoothing <= 1.0) << "label_smoothing must be between 0.0 and 1.0. 
Got: " << label_smoothing; const auto& input_ = JUST(sequence_function(functional::Transpose) .then(std::bind(functional::Reshape, std::placeholders::_1, Shape({-1, input_shape->At(1)}))) .then([this](const std::shared_ptr& x) { return OpInterpUtil::Dispatch(*op_log_softmax_, {x}); }) .call(input, input_perm)); const auto& target_ = JUST(functional::Flatten(target, 0, target->shape()->NumAxes() - 1)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index"); attrs.SetAllAttrs(ignore_index); std::shared_ptr nll_result; if (weight) { nll_result = JUST(OpInterpUtil::Dispatch( *op_nll_weight_, {input_, target_, JUST(weight)}, attrs)); } else { nll_result = JUST(OpInterpUtil::Dispatch(*op_nll_, {input_, target_}, attrs)); } const auto& ignore_mask = JUST(Reshape(JUST(ScalarLogicalEqual(target_, ignore_index)), {-1})); // smooth_loss = (-(input_ * weight.reshape(1, -1)).sum(1) * ~ignore_mask).reshape_as(target) std::shared_ptr smooth_loss = input_; if (weight) { const auto& weight_2d = JUST(Reshape(JUST(weight), {1, -1})); smooth_loss = JUST(Mul(smooth_loss, weight_2d)); } smooth_loss = JUST(Negative(JUST(ReduceSum(smooth_loss, {1}, false, NullOpt)))); smooth_loss = JUST(MaskedFill(smooth_loss, ignore_mask, 0.0)); smooth_loss = JUST(Reshape(smooth_loss, *target_shape)); int64_t n_classes = input->shape()->At(1); auto nll_loss = JUST(VectorAt(*nll_result, 0)); nll_loss = JUST(functional::Reshape(nll_loss, *target_shape)); // loss = nll_loss * (1 - label_smoothing) + smooth_loss * label_smoothing / num_classes if (reduction == "none") { return JUST(Add(JUST(ScalarMul(nll_loss, 1 - label_smoothing, false)), JUST(ScalarMul(smooth_loss, label_smoothing / n_classes, false)), 1, false)); } const auto& nll_loss_sum = JUST(ReduceSum(nll_loss, {}, false, NullOpt)); const auto& smooth_loss_sum = JUST(ReduceSum(smooth_loss, {}, false, NullOpt)); const auto& cross_entropy_loss_sum = JUST(Add(JUST(ScalarMul(nll_loss_sum, 1 - label_smoothing, false)), JUST(ScalarMul(smooth_loss_sum, label_smoothing / n_classes, false)), 1, false)); if (reduction == "sum") { return cross_entropy_loss_sum; } const auto& total_weight = JUST(ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false, NullOpt)); return Div(cross_entropy_loss_sum, total_weight); } private: std::shared_ptr op_log_softmax_; std::shared_ptr op_nll_; std::shared_ptr op_nll_weight_; }; class CrossEntropyProbFunctor : public LossFunctorBase { public: CrossEntropyProbFunctor() { op_log_softmax_ = CHECK_JUST(one::OpBuilder("log_softmax").Input("in").Output("prob").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const std::string& reduction, const double& label_smoothing) const { const auto& input_shape = input->shape(); const auto& target_shape = target->shape(); std::vector input_perm(input_shape->NumAxes(), 0); input_perm[input_perm.size() - 1] = 1; for (size_t i = 1; i < input_perm.size() - 1; ++i) { input_perm[i] = i + 1; } const auto input_ = JUST(sequence_function(functional::Transpose) .then(std::bind(functional::Reshape, std::placeholders::_1, Shape({-1, input_shape->At(1)}))) .then([this](const std::shared_ptr& x) { return OpInterpUtil::Dispatch(*op_log_softmax_, {x}); }) .call(input, input_perm)); std::shared_ptr target_ = JUST(sequence_function(functional::Transpose) .then(std::bind(functional::Reshape, std::placeholders::_1, Shape({-1, target_shape->At(1)}))) .call(target, input_perm)); if (label_smoothing > 0) { int32_t num_classes = input_->shape()->At(1); target_ = 
JUST(ScalarAdd(JUST(ScalarMul(target_, static_cast(1) - label_smoothing, false)), label_smoothing / static_cast(num_classes), 1, false)); } auto nll_result = JUST(Negative(JUST(Mul(input_, target_)))); if (weight) { const auto& weight_expand = JUST(Unsqueeze(JUST(weight), 0)); nll_result = JUST(Mul(nll_result, weight_expand)); } DimVector target_reshape_(input->ndim() - 1); for (size_t i = 0; i < target_reshape_.size(); ++i) { target_reshape_[i] = input_shape->At(input_perm[i]); } nll_result = JUST(ReduceSum(nll_result, {-1}, false, NullOpt)); nll_result = JUST(Reshape(nll_result, Shape(target_reshape_))); return apply_reduction(nll_result, reduction); } private: std::shared_ptr op_log_softmax_; }; class SparseCrossEntropyFunctor { public: SparseCrossEntropyFunctor() { op_ = CHECK_JUST(one::OpBuilder("sparse_cross_entropy") .Input("prediction") .Input("label") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& prediction, const std::shared_ptr& label, const int64_t& depth) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth"); attrs.SetAllAttrs(depth); return OpInterpUtil::Dispatch(*op_, {prediction, label}, attrs); } private: std::shared_ptr op_; }; class SparseCrossEntropyMsFunctor { public: SparseCrossEntropyMsFunctor() { op_ = CHECK_JUST(one::OpBuilder("sparse_cross_entropy_ms") .Input("prediction") .Input("label") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& prediction, const std::shared_ptr& label, const int64_t& depth) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth"); attrs.SetAllAttrs(depth); return OpInterpUtil::Dispatch(*op_, {prediction, label}, attrs); } private: std::shared_ptr op_; }; class SparseSoftmaxCrossEntropyFunctor { public: SparseSoftmaxCrossEntropyFunctor() { // SparseSoftmaxCrossEntropy op_sparse_softmax_cross_entropy_ = CHECK_JUST(one::OpBuilder("sparse_softmax_cross_entropy") .Input("prediction") .Input("label") .Output("prob") .Output("out") .Build()); // lazy model SparseSoftmaxCrossEntropyMs op_sparse_softmax_cross_entropy_ms_ = CHECK_JUST(one::OpBuilder("sparse_softmax_cross_entropy_ms") .Input("prediction") .Input("label") .Output("prob") .Output("out") .Build()); // eager model SparseSoftmaxCrossEntropyMs op_reduce_max_device_stage_ = CHECK_JUST(one::OpBuilder("reduce_max_device_stage") .Input("in") .Output("out") .Output("mask") .Output("count") .Build()); op_reduce_max_global_stage_ = CHECK_JUST(one::OpBuilder("reduce_max_global_stage") .Input("in") .Input("device_count") .Output("out") .Output("mask") .Build()); op_sparse_cross_entropy_ms_ = CHECK_JUST(one::OpBuilder("sparse_cross_entropy_ms") .Input("prediction") .Input("label") .Output("out") .Build()); op_broadcast_sub_ = CHECK_JUST(one::OpBuilder("broadcast_sub").Input("x").Input("y").Output("z").Build()); op_broadcast_div_ = CHECK_JUST(one::OpBuilder("broadcast_div").Input("x").Input("y").Output("z").Build()); op_reduce_sum_ = CHECK_JUST( one::OpBuilder("reduce_sum").Input("input_tensor").Output("output_tensor").Build()); op_exp_ = CHECK_JUST(one::OpBuilder("exp").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& logits, const std::shared_ptr& label) const { if (JUST(RunWithMsVersion(logits, label))) { if (LazyMode::is_enabled()) { return LazySparseSoftmaxCrossEntropyMsOperator(logits, label); } else { return EagerSparseSoftmaxCrossEntropyMsOperator(logits, label); } } else { return SparseSoftmaxCrossEntropyOperator(logits, label); } } Maybe RunWithMsVersion(const std::shared_ptr& logits, const std::shared_ptr& 
label) const {
    if (!(logits->is_global() && label->is_global())) { return false; }
    // The NPU/MLU implementations do not support the ms version yet.
#if defined(WITH_NPU) || defined(WITH_MLU)
    return false;
#endif
    if (JUST(logits->parallel_desc())->parallel_num() == 1) { return false; }
    if (logits->shape()->NumAxes() != 2) { return false; }
    const NdSbp& logits_nd_sbp = *(JUST(logits->nd_sbp()));
    const int32_t split_axis = logits->shape()->NumAxes() - 1;
    bool has_split_axis_parallel = false;
    for (int64_t i = 0; i < logits_nd_sbp.sbp_parallel_size(); ++i) {
      const auto& sbp = logits_nd_sbp.sbp_parallel(i);
      if (sbp.has_split_parallel() && sbp.split_parallel().axis() == split_axis) {
        has_split_axis_parallel = true;
      } else {
        if (sbp.has_partial_sum_parallel()) { return false; }
      }
    }
    if (!has_split_axis_parallel) { return false; }
    return true;
  }

  Maybe<Tensor> SparseSoftmaxCrossEntropyOperator(const std::shared_ptr<one::Tensor>& logits,
                                                  const std::shared_ptr<one::Tensor>& label) const {
    int64_t depth = logits->shape()->At(logits->shape()->NumAxes() - 1);
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth");
    attrs.SetAllAttrs(depth);
    const auto& result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_sparse_softmax_cross_entropy_,
                                                                  {logits, label}, attrs));
    return result->at(1);
  }

  Maybe<Tensor> LazySparseSoftmaxCrossEntropyMsOperator(
      const std::shared_ptr<one::Tensor>& logits, const std::shared_ptr<one::Tensor>& label) const {
    int64_t depth = logits->shape()->At(logits->shape()->NumAxes() - 1);
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth");
    attrs.SetAllAttrs(depth);
    const auto& result = JUST(OpInterpUtil::Dispatch<TensorTuple>(
        *op_sparse_softmax_cross_entropy_ms_, {logits, label}, attrs));
    return result->at(1);
  }

  Maybe<Tensor> EagerSparseSoftmaxCrossEntropyMsOperator(
      const std::shared_ptr<one::Tensor>& logits, const std::shared_ptr<one::Tensor>& label) const {
    // op_reduce_max_device_stage_
    int64_t depth = logits->shape()->At(logits->shape()->NumAxes() - 1);
    int32_t axis = logits->shape()->NumAxes() - 1;
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis");
    attrs.SetAllAttrs(std::vector<int32_t>{axis});
    const auto& max_device_stage =
        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_reduce_max_device_stage_, {logits}, attrs));
    std::shared_ptr<Tensor> max_global_stage_input0 = max_device_stage->at(0);
    std::shared_ptr<Tensor> max_global_stage_input1 = max_device_stage->at(2);
    const NdSbp& logits_nd_sbp = *(JUST(logits->nd_sbp()));
    std::vector<Symbol<SbpParallel>> new_sbp_parallels;
    std::vector<Symbol<SbpParallel>> s0s1_sbp_parallels;
    if (logits_nd_sbp.sbp_parallel_size() == 2) {
      for (int i = 0; i < logits_nd_sbp.sbp_parallel_size(); ++i) {
        const auto& sbp_parallel = logits_nd_sbp.sbp_parallel(i);
        if (sbp_parallel.has_split_parallel()) {
          const int64_t& split_axis = sbp_parallel.split_parallel().axis();
          if (split_axis == axis) {
            SbpParallel sbp;
            sbp.mutable_broadcast_parallel();
            new_sbp_parallels.emplace_back(sbp);
          } else {
            CHECK_EQ_OR_RETURN(split_axis, 0)
                << Error::RuntimeError() << "Split axis must be equal to 0.
"; new_sbp_parallels.emplace_back(sbp_parallel); } } else { new_sbp_parallels.emplace_back(sbp_parallel); } } s0s1_sbp_parallels.emplace_back(logits_nd_sbp.sbp_parallel(0)); s0s1_sbp_parallels.emplace_back(logits_nd_sbp.sbp_parallel(1)); max_global_stage_input0 = JUST(functional::ToGlobal( (*max_device_stage)[0], JUST((*max_device_stage)[0]->parallel_desc()), new_sbp_parallels, s0s1_sbp_parallels, /* check_meta */ false, /*copy=*/false)); max_global_stage_input1 = JUST(functional::ToGlobal( (*max_device_stage)[2], JUST((*max_device_stage)[0]->parallel_desc()), new_sbp_parallels, s0s1_sbp_parallels, /* check_meta */ false, /*copy=*/false)); } // op_reduce_max_global_stage_ auto& reduce_max_global_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); reduce_max_global_attrs.SetAllAttrs(std::vector{axis}, true); const auto& max_global_stage = JUST(OpInterpUtil::Dispatch( *op_reduce_max_global_stage_, {max_global_stage_input0, max_global_stage_input1}, reduce_max_global_attrs)); auto& broadcast_sub_input = max_global_stage->at(0); if (logits_nd_sbp.sbp_parallel_size() == 2) { broadcast_sub_input = JUST(functional::ToGlobal( broadcast_sub_input, JUST((*max_device_stage)[0]->parallel_desc()), new_sbp_parallels, new_sbp_parallels, /* check_meta */ false, /*copy=*/false)); } // op_broadcast_sub_ const auto& output_broadcast_sub = JUST( OpInterpUtil::Dispatch(*op_broadcast_sub_, {logits, broadcast_sub_input})); // op_exp_ const auto& output_exp = JUST(OpInterpUtil::Dispatch(*op_exp_, {(*output_broadcast_sub)[0]})); // op_reduce_sum_ auto& reduce_sum_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "keepdims"); reduce_sum_attrs.SetAllAttrs(std::vector{axis}, true); const auto& output_reduce_sum = JUST( OpInterpUtil::Dispatch(*op_reduce_sum_, {(*output_exp)[0]}, reduce_sum_attrs)); std::shared_ptr broadcast_div_input1 = output_reduce_sum->at(0); if (logits_nd_sbp.sbp_parallel_size() == 2) { std::vector> empty_grad_sbp_parallels; broadcast_div_input1 = JUST(functional::ToGlobal( (*output_reduce_sum)[0], JUST((*output_reduce_sum)[0]->parallel_desc()), new_sbp_parallels, new_sbp_parallels, /* check_meta */ false, /*copy=*/false)); } // op_broadcast_div_ const auto& predictions = JUST(OpInterpUtil::Dispatch( *op_broadcast_div_, {(*output_exp)[0], broadcast_div_input1})); // op_sparse_cross_entropy_ms_ auto& sparse_cross_entropy_ms_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth"); sparse_cross_entropy_ms_attrs.SetAllAttrs(depth); const auto& output = JUST(OpInterpUtil::Dispatch( *op_sparse_cross_entropy_ms_, {(*predictions)[0], label}, sparse_cross_entropy_ms_attrs)); return output; } private: // SparseSoftmaxCrossEntropy std::shared_ptr op_sparse_softmax_cross_entropy_; // lazy model SparseSoftmaxCrossEntropyMs std::shared_ptr op_sparse_softmax_cross_entropy_ms_; // SparseSoftmaxCrossEntropyMs std::shared_ptr op_reduce_max_device_stage_; std::shared_ptr op_reduce_max_global_stage_; std::shared_ptr op_broadcast_sub_; std::shared_ptr op_exp_; std::shared_ptr op_reduce_sum_; std::shared_ptr op_broadcast_div_; std::shared_ptr op_sparse_cross_entropy_ms_; }; class SoftmaxCrossEntropyFunctor { public: SoftmaxCrossEntropyFunctor() { op_ = CHECK_JUST(one::OpBuilder("softmax_cross_entropy") .Input("prediction") .Input("label") .Output("out") .Output("prob") .Build()); } Maybe operator()(const std::shared_ptr& logits, const std::shared_ptr& label) const { return OpInterpUtil::Dispatch(*op_, {logits, label}); } private: std::shared_ptr op_; }; class SoftmaxCrossEntropyGradFunctor { public: 
SoftmaxCrossEntropyGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("softmax_cross_entropy_grad") .Input("dy") .Input("label") .Input("prob") .Output("prediction_diff") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& label, const std::shared_ptr& prob) const { return OpInterpUtil::Dispatch(*op_, {dy, label, prob}); } private: std::shared_ptr op_; }; class CombinedMarginLossFunctor { public: CombinedMarginLossFunctor() { op_ = CHECK_JUST(one::OpBuilder("combined_margin_loss") .Input("x") .Input("label") .Output("y") .Output("theta") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& label, const float& m1, const float& m2, const float& m3) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("m1", "m2", "m3", "depth"); attrs.SetAllAttrs(m1, m2, m3, x->shape()->At(1)); return OpInterpUtil::Dispatch(*op_, {x, label}, attrs); } private: std::shared_ptr op_; }; class CtcLossFunctor { public: CtcLossFunctor() { op_ = CHECK_JUST(one::OpBuilder("ctc_loss") .Input("log_probs") .Input("targets") .Input("input_lengths") .Input("target_lengths") .Output("loss") .Output("alpha") .Build()); op_xdivy_ = CHECK_JUST(one::OpBuilder("xdivy").Input("x").Input("y").Output("z").Build()); } Maybe operator()(const std::shared_ptr& log_probs, const std::shared_ptr& targets, const std::shared_ptr& input_lengths, const std::shared_ptr& target_lengths, const int64_t& max_target_length, const int64_t& blank, const bool& zero_infinity, const std::string& reduction) const { // FIXME: global ctc loss sometimes segfaults auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("max_target_length", "blank", "zero_infinity"); attrs.SetAllAttrs(max_target_length, blank, zero_infinity); std::shared_ptr out; DeviceType log_probs_device_type; // NOLINT if (log_probs->is_local()) { log_probs_device_type = JUST(log_probs->device())->enum_type(); } else { log_probs_device_type = JUST(log_probs->parallel_desc())->device_type(); } const std::string& log_probs_device_str = *JUST(DeviceTag4DeviceType(log_probs_device_type)); std::shared_ptr target_lengths_on_log_probs_device = JUST(functional::To(target_lengths, log_probs_device_str)); if (targets->dtype()->data_type() == DataType::kInt32) { out = JUST(OpInterpUtil::Dispatch( *op_, { log_probs, JUST(functional::To(targets, log_probs_device_str)), JUST(functional::To(input_lengths, log_probs_device_str)), target_lengths_on_log_probs_device, }, attrs)); } else { out = JUST(OpInterpUtil::Dispatch( *op_, { log_probs, JUST(functional::To(targets, Optional(log_probs_device_str), DType::Int64(), false)), JUST(functional::To(input_lengths, log_probs_device_str)), target_lengths_on_log_probs_device, }, attrs)); } if (zero_infinity) { if (out->is_local()) { const auto create_constant = [&](const Scalar& scalar) -> Maybe { return functional::Constant(*out->shape(), scalar, out->dtype(), JUST(out->device())); }; out = JUST(sequence_function(functional::Constant) .then(std::bind(functional::BroadcastEqual, out, std::placeholders::_1)) .then(std::bind(functional::Where, std::placeholders::_1, JUST(create_constant(Scalar(0))), out)) .call(*out->shape(), Scalar(std::numeric_limits::infinity()), out->dtype(), JUST(out->device()))); } else { const auto& placement = JUST(out->parallel_desc()); const auto& nd_sbp = *JUST(GetSbpList(JUST(out->nd_sbp()))); const auto create_constant = [&](const Scalar& scalar) -> Maybe { return functional::GlobalConstant(*out->shape(), scalar, out->dtype(), placement, nd_sbp); }; out = 
JUST(sequence_function(functional::GlobalConstant) .then(std::bind(functional::BroadcastEqual, out, std::placeholders::_1)) .then(std::bind(functional::Where, std::placeholders::_1, JUST(create_constant(Scalar(0))), out)) .call(*out->shape(), Scalar(std::numeric_limits::infinity()), out->dtype(), placement, nd_sbp)); } } CHECK_OR_RETURN([&]() -> bool { if ((reduction != "none") && (reduction != "sum") && (reduction != "mean")) return false; return true; }()) << Error::RuntimeError() << "Reduction should be none, sum or mean."; if (reduction == "sum") { return functional::ReduceSum(out, {}, false, NullOpt); } if (reduction == "mean") { return sequence_function(functional::Clamp) .then(std::bind(functional::Cast, std::placeholders::_1, log_probs->dtype(), /*pin_memory=*/false)) .then([&](const std::shared_ptr& x) { return OpInterpUtil::Dispatch(*op_xdivy_, {out, x}); }) .then(std::bind(functional::ReduceMean, std::placeholders::_1, std::vector({}), false)) .call(target_lengths_on_log_probs_device, Scalar(1), NullOpt); } return out; } private: std::shared_ptr op_; std::shared_ptr op_xdivy_; }; class TripletMarginLossFunctor { public: TripletMarginLossFunctor() {} Maybe operator()(const std::shared_ptr& anchor, const std::shared_ptr& positive, const std::shared_ptr& negative, const float& margin, const float& p, const float& eps, const bool& swap, const std::string& reduction) const { int32_t dim_norm = anchor->ndim() - 1; std::vector dim(1, dim_norm); CHECK_OR_RETURN([&]() -> bool { if ((reduction != "none") && (reduction != "sum") && (reduction != "mean")) return false; return true; }()) << Error::RuntimeError() << "Reduction should be none, sum or mean."; auto da_p = JUST(VectorNorm( JUST(ScalarAdd(eps, JUST(Sub(anchor, positive, /*alpha=*/1.0, /*inplace=*/false)), /*alpha=*/1)), p, dim, /*keepdim=*/false, anchor->dtype())); auto da_n = JUST(VectorNorm( JUST(ScalarAdd(eps, JUST(Sub(anchor, negative, /*alpha=*/1.0, /*inplace=*/false)), /*alpha=*/1)), p, dim, /*keepdim=*/false, anchor->dtype())); if (swap) { auto distance_swap = JUST(VectorNorm( JUST(ScalarAdd(eps, JUST(Sub(positive, negative, /*alpha=*/1.0, /*inplace=*/false)), /*alpha=*/1)), p, dim, /*keepdim=*/false, positive->dtype())); da_n = JUST(Minimum(distance_swap, da_n)); } auto triplet_loss = JUST(Clamp(JUST(ScalarAdd(JUST(Sub(da_p, da_n, /*alpha=*/1.0, /*inplace=*/false)), margin, /*alpha=*/1, /*inplace=*/false)), /*min=*/0.0, NullOpt)); int32_t ndim = triplet_loss->ndim() - 1; std::vector axis(1, ndim); if (reduction == "mean") { triplet_loss = JUST(ReduceMean(triplet_loss, axis, /*keepdim=*/false)); } else if (reduction == "sum") { triplet_loss = JUST(ReduceSum(triplet_loss, axis, /*keepdim=*/false, NullOpt)); } return triplet_loss; } }; class AffineGridFunctor { public: AffineGridFunctor() { op_ = CHECK_JUST(one::OpBuilder("affine_grid").Input("theta").Output("grid").Build()); } Maybe operator()(const std::shared_ptr& theta, const Shape& size, const bool& align_corners) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("size", "align_corners"); attrs.SetAllAttrs(size, align_corners); return OpInterpUtil::Dispatch(*op_, {theta}, attrs); } private: std::shared_ptr op_; }; class GridSampleFunctor { public: GridSampleFunctor() { op_ = CHECK_JUST( one::OpBuilder("grid_sample").Input("input").Input("grid").Output("output").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& grid, const std::string& interpolation_mode, const std::string& padding_mode, const bool& align_corners) const { auto& attrs = 
THREAD_CACHED_MUTABLE_ATTR_MAP("interpolation_mode", "padding_mode", "align_corners"); attrs.SetAllAttrs(interpolation_mode, padding_mode, align_corners); return OpInterpUtil::Dispatch(*op_, {input, grid}, attrs); } private: std::shared_ptr op_; }; class NormalizationFunctor { public: NormalizationFunctor() { norm_eval_op_ = CHECK_JUST(one::OpBuilder("normalization") .Input("x") .Input("moving_mean") .Input("moving_variance") .Input("gamma") .Input("beta") .Output("y") .Attr("training", false) .Build()); norm_training_stats_op_ = CHECK_JUST(one::OpBuilder("normalization") .Input("x") .Input("moving_mean") .Input("moving_variance") .Input("gamma") .Input("beta") .Output("y") .Output("mean") .Output("inv_variance") .Attr("training", true) .Build()); norm_training_no_stats_op_ = CHECK_JUST(one::OpBuilder("normalization") .Input("x") .Input("gamma") .Input("beta") .Output("y") .Output("mean") .Output("inv_variance") .Attr("training", true) .Build()); cast_op_ = CHECK_JUST(one::OpBuilder("cast").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Optional& moving_mean, const Optional& moving_variance, const Optional& gamma, const Optional& beta, const int32_t& axis, const float& epsilon, const float& momentum, const bool& training) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "epsilon", "momentum"); // convert torch momentum to tensorflow momentum attrs.SetAllAttrs(axis, epsilon, static_cast(1.0 - momentum)); CHECK_OR_RETURN((moving_mean && moving_variance) || (!moving_mean && !moving_variance)) << Error::RuntimeError() << "Both running_mean and running_variance should be None or Tensor."; const DataType dtype = x->dtype()->data_type(); std::shared_ptr gamma_val; std::shared_ptr beta_val; CHECK_GE_OR_RETURN(x->shape()->NumAxes(), 2) << Error::RuntimeError() << "NumAxes of x should be greater or equal than 2. 
"; if (gamma.has_value() && beta.has_value()) { gamma_val = JUST(gamma); beta_val = JUST(beta); } else { const Shape gamma_beta_shape = Shape({x->shape()->At(1)}); gamma_val = JUST(functional::Constant(gamma_beta_shape, 1.0, x->dtype(), JUST(x->device()))); beta_val = JUST(functional::Constant(gamma_beta_shape, 0.0, x->dtype(), JUST(x->device()))); } const DataType gamma_dtype = gamma_val->dtype()->data_type(); const DataType beta_dtype = beta_val->dtype()->data_type(); CHECK_EQ_OR_RETURN(gamma_dtype, beta_dtype) << Error::RuntimeError() << "gamma and beta have different data types."; if (gamma_dtype != dtype) { gamma_val = JUST(functional::Cast(gamma_val, DType{dtype}, /*pin_memory=*/false)); beta_val = JUST(functional::Cast(beta_val, DType{dtype}, /*pin_memory=*/false)); } std::shared_ptr moving_mean_val; std::shared_ptr moving_variance_val; bool need_cast_moving_stats = false; if (moving_mean) { const DataType moving_mean_dtype = JUST(moving_mean)->dtype()->data_type(); CHECK_EQ_OR_RETURN(JUST(moving_variance)->dtype()->data_type(), moving_mean_dtype) << Error::RuntimeError() << "moving_mean and moving_variance have different data types."; need_cast_moving_stats = (moving_mean_dtype != dtype); if (need_cast_moving_stats) { moving_mean_val = JUST(functional::Cast(JUST(moving_mean), DType{dtype}, /*pin_memory=*/false)); moving_variance_val = JUST(functional::Cast(JUST(moving_variance), DType{dtype}, /*pin_memory=*/false)); } else { moving_mean_val = JUST(moving_mean); moving_variance_val = JUST(moving_variance); } } std::shared_ptr res; if (!training) { CHECK_OR_RETURN(moving_mean && moving_variance) << Error::RuntimeError() << "Must have moving_mean and moving_variance in eval mode."; res = JUST(OpInterpUtil::Dispatch( *norm_eval_op_, {x, moving_mean_val, moving_variance_val, gamma_val, beta_val}, attrs)); } else if (moving_mean) { res = JUST(OpInterpUtil::Dispatch( *norm_training_stats_op_, {x, moving_mean_val, moving_variance_val, gamma_val, beta_val}, attrs)); } else { res = JUST(OpInterpUtil::Dispatch(*norm_training_no_stats_op_, {x, gamma_val, beta_val}, attrs)); } if (need_cast_moving_stats) { // For inplace update moving_mean and moving_variance JUST(CheckInplaceValid(JUST(moving_mean))); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = JUST(moving_mean); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype", "pin_memory"); attrs.SetAllAttrs(JUST(moving_mean)->dtype()->data_type(), false); JUST(OpInterpUtil::Dispatch(*cast_op_, {moving_mean_val}, outputs.get(), attrs)); JUST(CheckInplaceValid(JUST(moving_variance))); outputs->at(0) = JUST(moving_variance); JUST(OpInterpUtil::Dispatch(*cast_op_, {moving_variance_val}, outputs.get(), attrs)); } return res; } private: std::shared_ptr norm_eval_op_; std::shared_ptr norm_training_stats_op_; std::shared_ptr norm_training_no_stats_op_; std::shared_ptr cast_op_; }; class NormalizationAddReluFunctor { public: NormalizationAddReluFunctor() { fused_norm_training_stats_op_ = CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/true, /*addend=*/false, /*training=*/true)); fused_addend_norm_training_stats_op_ = CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/true, /*addend=*/true, /*training=*/true)); fused_norm_training_no_stats_op_ = CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/false, /*addend=*/false, /*training=*/true)); fused_addend_norm_training_no_stats_op_ = CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/false, /*addend=*/true, /*training=*/true)); fused_norm_eval_stats_op_ = CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/true, 
/*addend=*/false, /*training=*/false)); fused_addend_norm_eval_stats_op_ = CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/true, /*addend=*/true, /*training=*/false)); } Maybe BuildFusedNormalizationOp(bool stats, bool addend, bool training) { auto op_builder = one::OpBuilder("normalization_add_relu") .Input("x") .Output("y") .Output("reserve_space") .Attr("training", training); if (addend) { op_builder.Input("addend"); } if (stats) { op_builder.Input("moving_mean").Input("moving_variance"); } op_builder.Input("gamma").Input("beta"); if (training) { op_builder.Output("mean").Output("inv_variance"); } return op_builder.Build(); } Maybe operator()(const std::shared_ptr& x, const Optional& addend, const Optional& moving_mean, const Optional& moving_variance, const std::shared_ptr& gamma, const std::shared_ptr& beta, const int32_t& axis, const float& epsilon, const float& momentum, const bool& is_training) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "epsilon", "momentum"); // convert torch momentum to tensorflow momentum attrs.SetAllAttrs(axis, epsilon, static_cast(1.0 - momentum)); CHECK_OR_RETURN((moving_mean && moving_variance) || (!moving_mean && !moving_variance)) << Error::RuntimeError() << "Both moving_mean and moving_variance should be None or Tensor."; if (!is_training) { CHECK_OR_RETURN(moving_mean && moving_variance) << Error::RuntimeError() << "Must have moving_mean and moving_variance in eval mode."; if (addend) { return OpInterpUtil::Dispatch( *fused_addend_norm_eval_stats_op_, {x, JUST(addend), JUST(moving_mean), JUST(moving_variance), gamma, beta}, attrs); } else { return OpInterpUtil::Dispatch( *fused_norm_eval_stats_op_, {x, JUST(moving_mean), JUST(moving_variance), gamma, beta}, attrs); } } else if (moving_mean) { if (addend) { return OpInterpUtil::Dispatch( *fused_addend_norm_training_stats_op_, {x, JUST(addend), JUST(moving_mean), JUST(moving_variance), gamma, beta}, attrs); } else { return OpInterpUtil::Dispatch( *fused_norm_training_stats_op_, {x, JUST(moving_mean), JUST(moving_variance), gamma, beta}, attrs); } } else { if (addend) { return OpInterpUtil::Dispatch(*fused_addend_norm_training_no_stats_op_, {x, JUST(addend), gamma, beta}, attrs); } else { return OpInterpUtil::Dispatch(*fused_norm_training_no_stats_op_, {x, gamma, beta}, attrs); } } } private: std::shared_ptr fused_norm_training_stats_op_; std::shared_ptr fused_addend_norm_training_stats_op_; std::shared_ptr fused_norm_training_no_stats_op_; std::shared_ptr fused_addend_norm_training_no_stats_op_; std::shared_ptr fused_norm_eval_stats_op_; std::shared_ptr fused_addend_norm_eval_stats_op_; }; class ConstantPadFunctor { public: ConstantPadFunctor() { constant_pad_ = CHECK_JUST(one::OpBuilder("pad").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& input, const std::vector& pad, const Scalar& value) const { const int64_t ndim = input->shape()->NumAxes(); const int64_t pad_size = pad.size(); CHECK_LE_OR_RETURN(pad_size, 2 * ndim) << Error::RuntimeError() << "Pad size should less than or equal to input axes * 2."; CHECK_EQ_OR_RETURN(pad_size % 2, 0) << Error::RuntimeError() << "Length of pad must be even but instead it equals " << pad_size; std::vector pad_before(ndim, 0); std::vector pad_after(ndim, 0); const int64_t pad_pair = pad_size / 2; for (int64_t i = 0; i < pad_pair; ++i) { pad_before[ndim - i - 1] = pad[2 * i]; pad_after[ndim - i - 1] = pad[2 * i + 1]; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("padding", "floating_constant_value", "integral_constant_value", 
"padding_before", "padding_after"); if (IsFloatingDataType(input->dtype()->data_type()) || IsComplexDataType(input->dtype()->data_type())) { attrs.SetAllAttrs(pad, value.As(), static_cast(0), pad_before, pad_after); } else if (IsIntegralDataType(input->dtype()->data_type())) { attrs.SetAllAttrs(pad, static_cast(0), value.As(), pad_before, pad_after); } else if (input->dtype() == DType::Bool()) { int64_t bool_value = value.As(); CHECK_OR_RETURN(bool_value == 1 || bool_value == 0) << "value must be 1/0 or True/False for bool Tensor"; attrs.SetAllAttrs(pad, static_cast(0), value.As(), pad_before, pad_after); } else { UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating, bool or integral type."; } return OpInterpUtil::Dispatch(*constant_pad_, {input}, attrs); } private: std::shared_ptr constant_pad_; }; class ReflectionPadFunctor { public: ReflectionPadFunctor() { reflect_pad1d_ = CHECK_JUST(one::OpBuilder("reflection_pad1d").Input("x").Output("y").Build()); reflect_pad2d_ = CHECK_JUST(one::OpBuilder("reflection_pad2d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& input, const std::vector& pad) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("padding"); attrs.SetAllAttrs(pad); const int64_t pad_size = pad.size(); const size_t ndim = input->ndim(); CHECK_LE_OR_RETURN(pad_size, 2 * ndim) << Error::RuntimeError() << "Pad size should less than or equal to input axes * 2."; if (pad_size == 2) { // 2D/3D reflect padding CHECK_OR_RETURN((ndim == 2 && input->shape()->At(1) != 0) || (ndim == 3 && input->shape()->At(1) != 0 && input->shape()->At(2) != 0)) << "2D or 3D (batch mode) tensor expected for input, but got: " << ndim; const int64_t pad_left = pad[0]; const int64_t pad_right = pad[1]; const int64_t dim_w = (ndim == 3) ? 2 : 1; const int64_t input_width = input->shape()->At(dim_w); const int64_t output_w = input_width + pad_left + pad_right; CHECK_OR_RETURN(pad_left < input_width && pad_right < input_width) << "Padding size should be less than the corresponding input dimension, but got: " "padding (" << pad_left << ", " << pad_right << ") at dimension " << dim_w << " of input " << input->shape()->ToString(); CHECK_OR_RETURN(output_w >= 1) << "input (W: " << input_width << ")is too small. 
Calculated output W: " << output_w; if (ndim == 2) { // for 2D input auto unsqueezed_input = JUST(functional::Unsqueeze(input, 0)); auto unsqueezed_output = JUST(OpInterpUtil::Dispatch(*reflect_pad1d_, {unsqueezed_input}, attrs)); return JUST(functional::Squeeze(unsqueezed_output, std::vector{0})); } return OpInterpUtil::Dispatch(*reflect_pad1d_, {input}, attrs); } else if (pad_size == 4) { // 3D/4D reflect padding bool valid_dims = input->shape()->At(1) != 0 && input->shape()->At(2) != 0; CHECK_OR_RETURN((ndim == 3 && valid_dims) || (ndim == 4 && valid_dims && input->shape()->At(3) != 0)) << "3D or 4D (batch mode) tensor expected for input, but got: " << ndim; int dim_h = 1; int dim_w = 2; if (ndim == 4) { dim_w++; dim_h++; } const int64_t pad_left = pad[0]; const int64_t pad_right = pad[1]; const int64_t pad_top = pad[2]; const int64_t pad_bottom = pad[3]; const int64_t input_h = input->shape()->At(dim_h); const int64_t input_w = input->shape()->At(dim_w); const int64_t output_h = input_h + pad_top + pad_bottom; const int64_t output_w = input_w + pad_left + pad_right; CHECK_OR_RETURN(pad_left < input_w && pad_right < input_w) << Error::RuntimeError() << "Padding size should be less than the corresponding input " "dimension, but got: padding (" << pad_left << ", " << pad_right << ") at dimension " << dim_w << " of input " << ndim; CHECK_OR_RETURN(pad_top < input_h && pad_bottom < input_h) << Error::RuntimeError() << "Padding size should be less than the corresponding input " "dimension, but got: padding (" << pad_top << ", " << pad_bottom << ") at dimension " << dim_h << " of input " << ndim; CHECK_OR_RETURN(output_w >= 1 || output_h >= 1) << Error::RuntimeError() << "input (H: " << input_h << ", W: " << input_w << ")is too small. Calculated output H: " << output_h << " W: " << output_w; if (ndim == 3) { // for 3D input auto unsqueezed_input = JUST(functional::Unsqueeze(input, 0)); auto unsqueezed_output = JUST(OpInterpUtil::Dispatch(*reflect_pad2d_, {unsqueezed_input}, attrs)); return JUST(functional::Squeeze(unsqueezed_output, std::vector{0})); } return OpInterpUtil::Dispatch(*reflect_pad2d_, {input}, attrs); } else if (pad_size == 6) { UNIMPLEMENTED_THEN_RETURN() << "5D reflect padding are not supported for now"; } else { UNIMPLEMENTED_THEN_RETURN() << "Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now"; } } private: std::shared_ptr reflect_pad1d_; std::shared_ptr reflect_pad2d_; }; class ReplicationPadFunctor { public: ReplicationPadFunctor() { replicate_pad1d_ = CHECK_JUST(one::OpBuilder("replication_pad1d").Input("x").Output("y").Build()); replicate_pad2d_ = CHECK_JUST(one::OpBuilder("replication_pad2d").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& input, const std::vector& pad) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("padding"); attrs.SetAllAttrs(pad); const int64_t pad_size = pad.size(); const size_t ndim = input->ndim(); CHECK_LE_OR_RETURN(pad_size, 2 * ndim) << Error::RuntimeError() << "Pad size should less than or equal to input axes * 2."; if (pad_size == 2) { // 2D/3D replicate padding CHECK_OR_RETURN((ndim == 2 && input->shape()->At(0) != 0 && input->shape()->At(1) != 0) || (ndim == 3 && input->shape()->At(1) != 0 && input->shape()->At(2) != 0)) << "Expected 2D or 3D (batch mode) tensor with possibly 0 batch size and other " "non-zero dimensions for input, but got: " << ndim; const int64_t pad_left = pad[0]; const int64_t pad_right = pad[1]; const int64_t dim_w = (ndim == 3) ? 
2 : 1; const int64_t input_width = input->shape()->At(dim_w); const int64_t output_w = input_width + pad_left + pad_right; CHECK_OR_RETURN(output_w >= 1) << "input (W: " << input_width << ")is too small. Calculated output W: " << output_w; if (ndim == 2) { // for 2D input auto unsqueezed_input = JUST(functional::Unsqueeze(input, 0)); auto unsqueezed_output = JUST(OpInterpUtil::Dispatch(*replicate_pad1d_, {unsqueezed_input}, attrs)); return JUST(functional::Squeeze(unsqueezed_output, std::vector{0})); } return OpInterpUtil::Dispatch(*replicate_pad1d_, {input}, attrs); } else if (pad_size == 4) { // 3D/4D replicate padding bool valid_dims = input->shape()->At(1) != 0 && input->shape()->At(2) != 0; CHECK_OR_RETURN((ndim == 3 && valid_dims) || (ndim == 4 && valid_dims && input->shape()->At(3) != 0)) << "3D or 4D (batch mode) tensor expected for input, but got: " << ndim; int dim_h = 1; int dim_w = 2; if (ndim == 4) { dim_w++; dim_h++; } const int64_t pad_left = pad[0]; const int64_t pad_right = pad[1]; const int64_t pad_top = pad[2]; const int64_t pad_bottom = pad[3]; const int64_t input_h = input->shape()->At(dim_h); const int64_t input_w = input->shape()->At(dim_w); const int64_t output_h = input_h + pad_top + pad_bottom; const int64_t output_w = input_w + pad_left + pad_right; CHECK_OR_RETURN(output_w >= 1 || output_h >= 1) << Error::RuntimeError() << "input (H: " << input_h << ", W: " << input_w << ")is too small. Calculated output H: " << output_h << " W: " << output_w; if (ndim == 3) { // for 3D input auto unsqueezed_input = JUST(functional::Unsqueeze(input, 0)); auto unsqueezed_output = JUST(OpInterpUtil::Dispatch(*replicate_pad2d_, {unsqueezed_input}, attrs)); return JUST(functional::Squeeze(unsqueezed_output, std::vector{0})); } return OpInterpUtil::Dispatch(*replicate_pad2d_, {input}, attrs); } else if (pad_size == 6) { UNIMPLEMENTED_THEN_RETURN() << "5D replicate padding are not supported for now"; } else { UNIMPLEMENTED_THEN_RETURN() << "Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now"; } } private: std::shared_ptr replicate_pad1d_; std::shared_ptr replicate_pad2d_; }; class PadFunctor { public: Maybe operator()(const std::shared_ptr& input, const std::vector& pad, const std::string& mode, const Scalar& value) const { if (mode == "constant") { return functional::ConstantPad(input, pad, value); } else if (mode == "reflect") { return functional::ReflectionPad(input, pad); } else if (mode == "replicate") { return functional::ReplicationPad(input, pad); } else { UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but only constant, reflect and replicate are valid."; } } }; class DropoutFunctor { public: DropoutFunctor() { dropout_op_ = CHECK_JUST(one::OpBuilder("dropout").Input("in").Output("out").Output("mask").Build()); dropout_addend_op_ = CHECK_JUST(one::OpBuilder("dropout") .Input("in") .Input("_add_to_output") .Output("out") .Output("mask") .Build()); add_op_ = CHECK_JUST(one::OpBuilder("add_n").Input("in", 2).Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const float& p, const bool& training, const bool& inplace, const Optional& generator, const Optional& addend) const { auto outputs = std::make_shared(1); if (inplace) { JUST(CheckInplaceValid(x)); (*outputs)[0] = x; } auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x)); auto& dropout_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed"); dropout_attrs.SetAllAttrs(p, 
static_cast(gen->current_seed())); const auto& dropout_state = std::make_shared(gen); OpExprInterpContext ctx(dropout_attrs, dropout_state); if (addend) { if ((!training) || p == 0.0) { JUST(OpInterpUtil::Dispatch(*add_op_, {x, JUST(addend)}, outputs.get())); } else { outputs->resize(2); JUST(OpInterpUtil::Dispatch(*dropout_addend_op_, {x, JUST(addend)}, outputs.get(), ctx)); } } else { if (!training || p == 0.0) { return x; } else { outputs->resize(2); JUST(OpInterpUtil::Dispatch(*dropout_op_, {x}, outputs.get(), ctx)); } } return (*outputs)[0]; } private: std::shared_ptr dropout_op_; std::shared_ptr dropout_addend_op_; std::shared_ptr add_op_; }; namespace { Maybe MakeFeatureNoise(const std::shared_ptr& x) { const int64_t ndim = x->ndim(); CHECK_GE_OR_RETURN(ndim, 2) << Error::RuntimeError() << "Feature dropout requires at least 2 dimensions in the input"; std::vector sizes; sizes.reserve(ndim); sizes.push_back(x->shape()->At(0)); sizes.push_back(x->shape()->At(1)); for (int i = 2; i < ndim; i++) { sizes.push_back(1); } return JUST(Empty(Shape(sizes), x->dtype(), JUST(x->device()), /*requires_grad=*/x->requires_grad(), /*pin_memory=*/false)); } Maybe DropoutImpl(const std::shared_ptr& input, const float& p, const bool& train) { CHECK_EQ_OR_RETURN(p >= 0 && p <= 1, true) << "dropout probability has to be between 0 and 1, but got " << p; if (p == 0 || !train || input->shape()->elem_cnt() == 0) { return input; } if (p == 1) { std::shared_ptr other = JUST(Constant(*input->shape(), Scalar(0.0), input->dtype(), JUST(input->device()))); return Mul(input, other); } std::shared_ptr noise = JUST(MakeFeatureNoise(input)); noise = JUST(BernoulliProb(noise, 1.0 - p, noise->dtype(), JUST(one::DefaultAutoGenerator()), false)); noise = JUST(InplaceScalarDiv(noise, Scalar(1.0 - p))); return JUST(Mul(input, noise)); } } // namespace class Dropout1dFunctor { public: Maybe operator()(const std::shared_ptr& input, const float& p, const bool& training) const { CHECK_EQ_OR_RETURN(p < 0 || p > 1.0, false) << "dropout probability has to be between 0 and 1, but got " << p; const int input_dim = input->ndim(); CHECK_EQ_OR_RETURN(input_dim != 2 && input_dim != 3, false) << "dropout1d: Expected 2D or 3D input, but received a " << input_dim << "D input. " "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 " "spatial dimension, a channel dimension, and an optional batch dimension " "(i.e. 2D or 3D inputs)."; bool is_batched = (input_dim == 3); std::shared_ptr result = input; if (!is_batched) { result = JUST(Unsqueeze(input, 0)); } result = JUST(DropoutImpl(result, p, training)); if (!is_batched) { result = JUST(Squeeze(result, std::vector{0})); } return result; } }; class Dropout2dFunctor { public: Maybe operator()(const std::shared_ptr& input, const float& p, const bool& training) const { CHECK_EQ_OR_RETURN(p < 0 || p > 1.0, false) << "dropout probability has to be between 0 and 1, but got " << p; const int input_dim = input->ndim(); if (input_dim != 3 && input_dim != 4) { LOG(WARNING) << "dropout2d: Received a " << input_dim << "-D input to dropout2d, which is deprecated " "and will result in an error in a future release. To retain the behavior " "and silence this warning, please use dropout instead. Note that dropout2d " "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, " "a channel dimension, and an optional batch dimension (i.e. 
3D or 4D inputs)."; } if (input_dim == 3) { LOG(WARNING) << "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise " "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C " "is the channel dim. This behavior will change in a future release to interpret the " "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D " "channel-wise dropout behavior, please switch to using dropout1d instead."; } return JUST(DropoutImpl(input, p, training)); } }; class Dropout3dFunctor { public: Maybe operator()(const std::shared_ptr& input, const float& p, const bool& training) const { CHECK_EQ_OR_RETURN(p < 0 || p > 1.0, false) << "dropout probability has to be between 0 and 1, but got " << p; const int input_dim = input->ndim(); if (input_dim != 4 && input_dim != 5) { LOG(WARNING) << "dropout3d: Received a " << input_dim << "-D input to dropout3d, which is deprecated " "and will result in an error in a future release. To retain the behavior " "and silence this warning, please use dropout instead. Note that dropout3d " "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, " "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)."; } bool is_batched = (input_dim == 5); std::shared_ptr result = input; if (!is_batched) { result = JUST(Unsqueeze(input, 0)); } result = JUST(DropoutImpl(result, p, training)); if (!is_batched) { result = JUST(Squeeze(result, std::vector{0})); } return result; } }; class DropoutGradFunctor { public: DropoutGradFunctor() { dropout_grad_op_ = CHECK_JUST(one::OpBuilder("dropout_grad").Input("dy").Input("mask").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& mask, const float& scale) const { auto& dropout_grad_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale"); dropout_grad_attrs.SetAllAttrs(scale); return OpInterpUtil::Dispatch(*dropout_grad_op_, {dy, mask}, dropout_grad_attrs); } private: std::shared_ptr dropout_grad_op_; }; class UnfoldFunctor { public: UnfoldFunctor() { unfold_op_ = CHECK_JUST(one::OpBuilder("unfold").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& kernel_size, const std::vector& dilation_rate, const std::vector& padding, const std::vector& strides, const std::string& data_format) const { const auto& x_shape = x->shape(); // Only Support 4d tensor now. CHECK_EQ_OR_RETURN(x_shape->NumAxes(), 4) << Error::RuntimeError() << "Input Tensor dim should == 4"; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("kernel_size", "dilation_rate", "padding", "strides", "data_format"); attrs.SetAllAttrs(kernel_size, dilation_rate, padding, strides, data_format); return OpInterpUtil::Dispatch(*unfold_op_, {x}, attrs); } private: std::shared_ptr unfold_op_; }; class FoldFunctor { public: FoldFunctor() { fold_op_ = CHECK_JUST(one::OpBuilder("fold").Input("x").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::vector& output_size, const std::vector& kernel_size, const std::vector& dilation_rate, const std::vector& padding, const std::vector& strides, const std::string& data_format) const { const auto& x_shape = x->shape(); // Only Support 3d tensor fold now. 
format is (N, C*K*K, L) CHECK_EQ_OR_RETURN(x_shape->NumAxes(), 3) << Error::RuntimeError() << "Input Tensor dim should == 3"; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("output_size", "kernel_size", "dilation_rate", "padding", "strides", "data_format"); attrs.SetAllAttrs(output_size, kernel_size, dilation_rate, padding, strides, data_format); return OpInterpUtil::Dispatch(*fold_op_, {x}, attrs); } private: std::shared_ptr fold_op_; }; class OneHotFunctor { public: OneHotFunctor() { one_hot_op_ = CHECK_JUST(one::OpBuilder("one_hot").Input("indices").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const int64_t& num_classes, const Scalar& on_value, const Scalar& off_value) const { CHECK_OR_RETURN(!IsFloatingDataType(input->dtype()->data_type())) << Error::RuntimeError() << "one_hot is only applicable to index tensor."; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth", "dtype", "floating_on_value", "floating_off_value", "integer_on_value", "integer_off_value"); int64_t depth = num_classes; if (num_classes == -1) { std::vector axis(input->ndim()); std::iota(axis.begin(), axis.end(), 0); auto tensor_max = JUST(functional::ReduceMax(input, axis, false)); int64_t max = 0; const auto& callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, &max, eager_blob_object->dptr(), sizeof(max), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; JUST(SyncAccessTensorWithTimeOut(tensor_max, callback, "const")); depth = max + 1; } // Refer to: https://github.com/Oneflow-Inc/oneflow/pull/5315/files#r755823506 bool is_on_value_double = on_value.IsFloatingPoint(); bool is_off_value_double = off_value.IsFloatingPoint(); if (is_on_value_double || is_off_value_double) { attrs.SetAllAttrs(depth, kFloat, on_value.As(), off_value.As(), static_cast(0), static_cast(0)); } else { attrs.SetAllAttrs(depth, kInt64, static_cast(0), static_cast(0), on_value.As(), off_value.As()); } return OpInterpUtil::Dispatch(*one_hot_op_, {input}, attrs); } private: std::shared_ptr one_hot_op_; }; class PairwiseDistanceFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, const float& p, const double& eps, bool keepdim) const { const int64_t xdim = x->ndim(); const int64_t ydim = y->ndim(); const int64_t output_dim = xdim > ydim ? xdim : ydim; const auto& sub = JUST(ScalarAdd(JUST(Sub(x, y, 1, false)), eps, 1, false)); return ScalarNorm(sub, p, output_dim - 1, keepdim, NullOpt); } }; class CosineSimilarityFunctor { public: Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, const int32_t& dim, const double& eps) const { const auto& x_shape = *(x->shape()); const auto& y_shape = *(y->shape()); std::shared_ptr x_extend = x; std::shared_ptr y_extend = y; if (x_shape != y_shape) { Shape max_shape = Shape::Ones(std::max(x_shape.NumAxes(), y_shape.NumAxes())); for (int64_t i = max_shape.NumAxes() - 1; i >= 0; i--) { int64_t offset = max_shape.NumAxes() - 1 - i; int64_t dim_x = x_shape.NumAxes() - 1 - offset; int64_t dim_y = y_shape.NumAxes() - 1 - offset; int64_t size_x = (dim_x >= 0) ? x_shape.At(dim_x) : 1; int64_t size_y = (dim_y >= 0) ? 
y_shape.At(dim_y) : 1; if (!(size_x == size_y || size_x == 1 || size_y == 1)) { return Error::RuntimeError() << "The size of tensor a (" << size_x << ") must match the size of tensor b (" << size_y << ") at non-singleton dimension " << i; } max_shape.Set(i, std::max(size_x, size_y)); } x_extend = JUST(Expand(x, max_shape)); y_extend = JUST(Expand(y, max_shape)); } TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x_extend, y_extend}).Apply()); TensorTuple input_vec = JUST(tensor_processor.GetInputs()); const auto common_dtype = JUST(oneflow::VectorAt(input_vec, 0))->dtype(); if (!IsFloatingDataType(common_dtype->data_type())) { return Error::RuntimeError() << "expected common dtype to be floating point, yet common dtype is " << common_dtype->name(); } auto& x_ = JUST(oneflow::VectorAt(input_vec, 0)); auto& y_ = JUST(oneflow::VectorAt(input_vec, 1)); std::shared_ptr w12 = JUST(functional::ReduceSum(JUST(functional::Mul(x_, y_)), {dim}, false, NullOpt)); std::shared_ptr w1 = JUST(functional::ReduceSum(JUST(functional::Mul(x_, x_)), {dim}, false, NullOpt)); std::shared_ptr w2 = JUST(functional::ReduceSum(JUST(functional::Mul(y_, y_)), {dim}, false, NullOpt)); std::shared_ptr n12 = JUST(functional::Sqrt( JUST(functional::Clamp(JUST(functional::Mul(w1, w2)), Scalar(eps * eps), NullOpt)))); return functional::Div(w12, n12); } }; class L2NormalizeFunctor { public: L2NormalizeFunctor() { op_ = CHECK_JUST( one::OpBuilder("l2_normalize").Input("x").Output("y").Output("square_x_sum").Build()); } Maybe operator()(const std::shared_ptr& input, const int32_t& axis, const float& epsilon) const { const int32_t ndims = input->shape()->NumAxes(); const int32_t final_dim = ndims - 1; auto axis_ = axis >= 0 ? axis : axis + ndims; CHECK_GE_OR_RETURN(axis_, 0) << Error::RuntimeError() << "Axis should >=0 but axis is " << axis_ << " now."; CHECK_LE_OR_RETURN(axis_, final_dim) << Error::RuntimeError() << "Axis should < " << ndims << " but axis is " << axis_ << " now."; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("epsilon", "axis"); attrs.SetAllAttrs(epsilon, final_dim); if (axis_ == final_dim) { return OpInterpUtil::Dispatch(*op_, {input}, attrs); } std::vector input_perm(input->shape()->dim_vec().size(), 0); for (size_t i = 0; i < input_perm.size(); ++i) { input_perm[i] = static_cast(i); } std::swap(input_perm[final_dim], input_perm[static_cast(axis_)]); const auto result = JUST(OpInterpUtil::Dispatch( *op_, {JUST(functional::Transpose(input, input_perm))}, attrs)); return functional::Transpose((*result)[0], input_perm); } private: std::shared_ptr op_; }; class NormalizeFunctor { public: Maybe operator()(const std::shared_ptr& input, const float& p, const int32_t& dim, const float& eps, const bool& use_l2_norm_kernel) const { if (use_l2_norm_kernel && (std::fabs(p - 2.0f) < std::numeric_limits::min())) { return functional::L2Normalize(input, dim, eps); } return SequenceFunction(const std::shared_ptr&, const float&, const int32_t&)>( [](const auto& x, const float& p, const int32_t& dim) -> Maybe { return functional::ScalarNorm(x, p, dim, true, NullOpt); }) .then([&](const auto& x) { return functional::Clamp(x, eps, NullOpt); }) .then([&](const auto& x) { return functional::Div(input, x); }) .call(input, p, dim); } }; class FusedSelfAttentionFunctor { public: FusedSelfAttentionFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_self_attention_query_mul_key_and_value") .Input("hidden_states") .Output("query_mul_key") .Output("value") .Build()); } Maybe 
operator()(const std::shared_ptr& hidden_states, const int64_t& head_size, const float& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("head_size", "alpha"); attrs.SetAllAttrs(head_size, alpha); return OpInterpUtil::Dispatch(*op_, {hidden_states}, attrs); } private: std::shared_ptr op_; }; class FusedSelfAttentionGradFunctor { public: FusedSelfAttentionGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_self_attention_query_mul_key_and_value_grad") .Input("query_mul_key_grad") .Input("value_grad") .Input("hidden_states") .Output("hidden_states_grad") .Build()); } Maybe operator()(const std::shared_ptr& query_mul_key_grad, const std::shared_ptr& value_grad, const std::shared_ptr& hidden_states, const float& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); return OpInterpUtil::Dispatch(*op_, {query_mul_key_grad, value_grad, hidden_states}, attrs); } private: std::shared_ptr op_; }; class FusedScaleTrilSoftmaxMaskScaleFunctor { public: FusedScaleTrilSoftmaxMaskScaleFunctor() { random_mask_like_op_ = CHECK_JUST(one::OpBuilder("random_mask_like").Input("like").Output("out").Build()); fused_op_ = CHECK_JUST(one::OpBuilder("fused_tril_scale_softmax_mask_scale") .Input("x") .Input("mask") .Output("y") .Output("softmax_y") .Build()); } Maybe operator()(const std::shared_ptr& x, const float p, const int64_t diagonal, const float tril_scale_value, const float tril_fill_value, const Optional& generator) const { auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x)); auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed"); random_mask_like_attrs.SetAllAttrs(p, static_cast(gen->current_seed())); const auto& random_mask_like_state = std::make_shared(gen); const auto& mask = JUST(OpInterpUtil::Dispatch( *random_mask_like_op_, {x}, OpExprInterpContext(random_mask_like_attrs, random_mask_like_state))); float mask_scale_value = 1.0; if (p != 1.0) { mask_scale_value = 1.0 / (1.0 - p); } auto& fused_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("diagonal", "tril_scale_value", "mask_scale_value", "tril_fill_value"); fused_attrs.SetAllAttrs(diagonal, tril_scale_value, mask_scale_value, tril_fill_value); return OpInterpUtil::Dispatch(*fused_op_, {x, mask}, fused_attrs); } private: std::shared_ptr fused_op_; std::shared_ptr random_mask_like_op_; }; class L2NormalizeGradFunctor { public: L2NormalizeGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("l2_normalize_grad") .Input("dy") .Input("y") .Input("square_x_sum") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& y, const std::shared_ptr& square_x_sum, const int32_t& axis, const float& epsilon) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "epsilon"); attrs.SetAllAttrs(axis, epsilon); return OpInterpUtil::Dispatch(*op_, {dy, y, square_x_sum}, attrs); } private: std::shared_ptr op_; }; class FusedBiasAddGeluFunctor { public: FusedBiasAddGeluFunctor() { op_ = CHECK_JUST( one::OpBuilder("fused_bias_add_gelu").Input("a").Input("b").Output("out").Build()); } Maybe operator()(const std::shared_ptr& a, const std::shared_ptr& b, const int32_t& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch(*op_, {a, b}, attrs); } private: std::shared_ptr op_; }; class FusedBiasAddGeluGradFunctor { public: FusedBiasAddGeluGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_bias_add_gelu_grad") 
.Input("a") .Input("b") .Input("dy") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& a, const std::shared_ptr& b, const std::shared_ptr& dy, const int32_t& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch(*op_, {a, b, dy}, attrs); } private: std::shared_ptr op_; }; class FusedBiasAddDropoutFunctor { public: FusedBiasAddDropoutFunctor() { random_mask_like_op_ = CHECK_JUST(one::OpBuilder("random_mask_like").Input("like").Output("out").Build()); fused_bias_add_mask_scale_op_ = CHECK_JUST(one::OpBuilder("fused_bias_add_mask_scale") .Input("a") .Input("b") .Input("mask") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& a, const std::shared_ptr& b, const float& p, const int32_t& axis, const Optional& generator) const { int32_t axis_val = axis; if (axis_val < 0) { const int64_t num_axes = a->shape()->NumAxes(); axis_val += num_axes; } if (p > 0.0) { auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), a)); auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed"); random_mask_like_attrs.SetAllAttrs(p, static_cast(gen->current_seed())); const auto& random_mask_like_state = std::make_shared(gen); float scale = 0.0; if (p != 1.0) { scale = 1.0 / (1.0 - p); } auto& fused_bias_add_mask_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale", "axis"); fused_bias_add_mask_attrs.SetAllAttrs(scale, axis_val); return SequenceFunction()>([&]() -> Maybe { return OpInterpUtil::Dispatch( *random_mask_like_op_, {a}, OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)); }) .then([&](const std::shared_ptr& x) { return OpInterpUtil::Dispatch(*fused_bias_add_mask_scale_op_, {a, b, x}, fused_bias_add_mask_attrs); }) .call(); } else { return functional::BiasAdd(a, b, axis_val); } } private: std::shared_ptr random_mask_like_op_; std::shared_ptr fused_bias_add_mask_scale_op_; }; class FusedGluFunctor { public: FusedGluFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_glu") .Input("x") .Input("w") .Input("b") .Output("y") .Output("matmul_wx") .Build()); op_without_bias_ = CHECK_JUST( one::OpBuilder("fused_glu").Input("x").Input("w").Output("y").Output("matmul_wx").Build()); split_op_ = CHECK_JUST(one::OpBuilder("fused_glu") .Input("x") .Input("w") .Input("b") .Input("v") .Input("c") .Output("y") .Output("matmul_wx") .Output("matmul_vx") .Build()); split_op_without_bias_ = CHECK_JUST(one::OpBuilder("fused_glu") .Input("x") .Input("w") .Input("v") .Output("y") .Output("matmul_wx") .Output("matmul_vx") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& w, const Optional& b, const Optional& v, const Optional& c, const std::string& activation) const { // check whether the user provide weight tensor v bool is_split_mode = false; if (v) { is_split_mode = true; } else { is_split_mode = false; } // check whether the user provide bias tensors bool has_bias = false; if (b) { has_bias = true; if (is_split_mode) { CHECK_OR_RETURN(c) << "expected existance of c, when provide tensors w, v and b"; } } else { CHECK_OR_RETURN(!c) << "expected existance of b while providing c"; has_bias = false; } // obtain input shape const auto& x_shape = *(x->shape()); const auto& w_shape = *(w->shape()); std::shared_ptr b_shape = nullptr; if (has_bias) { b_shape = (JUST(b)->shape()); } // check number of axes of x, w and b CHECK_GT_OR_RETURN(x_shape.NumAxes(), 1) << "number of axes of \'x\' should have be 
greater than 1, yet get " << x_shape.NumAxes(); CHECK_EQ_OR_RETURN(w_shape.NumAxes(), 2) << "number of axes of \'w\' should have be equal to 2, yet get " << w_shape.NumAxes(); if (has_bias) { CHECK_EQ_OR_RETURN(b_shape->NumAxes(), 1) << "number of axes of \'b\' should have be equal to 1, yet get " << b_shape->NumAxes(); } // check input shapes of w and b size_t x_num_axes = x_shape.NumAxes(); CHECK_EQ_OR_RETURN(w_shape.At(1), x_shape.At(x_num_axes - 1)) << "dimension 1 of \'w\'(" << w_shape.At(1) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_num_axes - 1) << ")"; if (has_bias) { CHECK_EQ_OR_RETURN(b_shape->At(0), w_shape.At(0)) << "dimension 0 of \'b\'(" << b_shape->At(0) << ") is not consistant with dimension 0 of \'w\'(" << w_shape.At(0) << ")"; } if (!is_split_mode) { CHECK_EQ_OR_RETURN(w_shape.At(1) % 2, 0) << "dimension 1 of \'w\' is not divisible by 2"; } // check both dimensions and input shapes of v and c (optional) if (is_split_mode) { const auto& v_shape = *(JUST(v)->shape()); std::shared_ptr c_shape = NULL; if (has_bias) { c_shape = (JUST(c)->shape()); } CHECK_EQ_OR_RETURN(v_shape.NumAxes(), 2) << "number of axes of \'v\' should have be equal to 2, yet get " << v_shape.NumAxes(); if (has_bias) { CHECK_EQ_OR_RETURN(c_shape->NumAxes(), 1) << "number of axes of \'c\' should have be equal to 1, yet get " << c_shape->NumAxes(); } CHECK_OR_RETURN(v_shape == w_shape) << "the shape of \'v\' is not consistant with \'w\'"; if (has_bias) { CHECK_OR_RETURN((*c_shape) == (*b_shape)) << "the shape of \'c\' is not consistant with \'b\'"; } } // set activation attribute auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("activation", "has_bias", "is_split"); attrs.SetAllAttrs(activation, has_bias, is_split_mode); // dispatch corresponding operator if (is_split_mode && has_bias) { return OpInterpUtil::Dispatch(*split_op_, {x, w, JUST(b), JUST(v), JUST(c)}, attrs); } else if (!is_split_mode && has_bias) { return OpInterpUtil::Dispatch(*op_, {x, w, JUST(b)}, attrs); } else if (is_split_mode && !has_bias) { return OpInterpUtil::Dispatch(*split_op_without_bias_, {x, w, JUST(v)}, attrs); } else if (!is_split_mode && !has_bias) { return OpInterpUtil::Dispatch(*op_without_bias_, {x, w}, attrs); } else { UNIMPLEMENTED_THEN_RETURN(); } } private: std::shared_ptr op_; std::shared_ptr op_without_bias_; std::shared_ptr split_op_; std::shared_ptr split_op_without_bias_; }; class FusedScaleTrilFunctor { public: FusedScaleTrilFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_scale_tril").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const int64_t& diagonal, const Scalar& fill_value, const Scalar& scale) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "diagonal", "floating_fill_value", "is_floating_fill_value", "integer_fill_value", "floating_scale_value", "is_floating_scale_value", "integer_scale_value"); bool is_fill_value_double = fill_value.IsFloatingPoint(); bool is_scale_double = scale.IsFloatingPoint(); double floating_fill_value = 0; int64_t integer_fill_value = 0; if (is_fill_value_double) { floating_fill_value = fill_value.As(); } else { integer_fill_value = fill_value.As(); } double floating_scale_value = 0; int64_t integer_scale_value = 0; if (is_scale_double) { floating_scale_value = scale.As(); } else { integer_scale_value = scale.As(); } attrs.SetAllAttrs(diagonal, floating_fill_value, is_fill_value_double, integer_fill_value, floating_scale_value, is_scale_double, integer_scale_value); return OpInterpUtil::Dispatch(*op_, {x}, 
attrs); } private: std::shared_ptr op_; }; class FusedScaleMaskSoftmaxFunctor { public: FusedScaleMaskSoftmaxFunctor() { op_ = CHECK_JUST( one::OpBuilder("fused_scale_mask_softmax").Input("x").Input("mask").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& mask, const float& fill_value, const float& scale) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_value", "mask_fill_value"); attrs.SetAllAttrs(scale, fill_value); return OpInterpUtil::Dispatch(*op_, {x, mask}, attrs); } private: std::shared_ptr op_; }; class FusedScaleMaskSoftmaxDropoutFunctor { public: FusedScaleMaskSoftmaxDropoutFunctor() { random_mask_like_op_ = CHECK_JUST(one::OpBuilder("random_mask_like").Input("like").Output("out").Build()); fused_scale_mask_softmax_dropout_op_ = CHECK_JUST(one::OpBuilder("fused_scale_mask_softmax_dropout") .Input("x") .Input("mask") .Input("dropout_mask") .Output("y") .Output("softmax_y") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& mask, const float& fill_value, const float& scale, const float& p, const bool& training, const Optional& generator) const { float rate = p; if (!training) rate = 0.0; auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x)); auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed"); random_mask_like_attrs.SetAllAttrs(rate, static_cast(gen->current_seed())); const auto& random_mask_like_state = std::make_shared(gen); const auto& dropout_mask = JUST(OpInterpUtil::Dispatch( *random_mask_like_op_, {x}, OpExprInterpContext(random_mask_like_attrs, random_mask_like_state))); float dropout_scale = 0.0; if (rate != 1.0) { dropout_scale = 1.0 / (1.0 - rate); } auto& fused_scale_mask_softmax_dropout_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_value", "mask_fill_value", "dropout_scale_value"); fused_scale_mask_softmax_dropout_attrs.SetAllAttrs(scale, fill_value, dropout_scale); return OpInterpUtil::Dispatch(*fused_scale_mask_softmax_dropout_op_, {x, mask, dropout_mask}, fused_scale_mask_softmax_dropout_attrs); } private: std::shared_ptr random_mask_like_op_; std::shared_ptr fused_scale_mask_softmax_dropout_op_; }; // Equivalent to // masked = (x + bias) * mask * scale_value // unmask = (1 - mask).bool() // masked.masked_fill_(unmask, mask_fill_value) // softmax_y = softmax(masked, dim=-1) // y = dropout(softmax_y, p) class FusedBiasAddScaleMaskSoftmaxDropoutFunctor { public: FusedBiasAddScaleMaskSoftmaxDropoutFunctor() { random_mask_op_ = CHECK_JUST(one::OpBuilder("random_mask_like").Input("like").Output("out").Build()); fused_op_ = CHECK_JUST(one::OpBuilder("fused_bias_add_scale_mask_softmax_dropout") .Input("x") .Input("bias") .Input("mask") .Input("dropout_mask") .Output("y") .Output("softmax_y") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& bias, const std::shared_ptr& mask, const float& fill_value, const float& scale, const float& p, const bool& training, const Optional& generator) const { float rate = p; if (!training) rate = 0.0; auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x)); auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed"); random_mask_like_attrs.SetAllAttrs(rate, static_cast(gen->current_seed())); const auto& random_mask_like_state = std::make_shared(gen); const auto& dropout_mask = JUST(OpInterpUtil::Dispatch( *random_mask_op_, {x}, 
OpExprInterpContext(random_mask_like_attrs, random_mask_like_state))); float dropout_scale = 0.0; if (rate != 1.0) { dropout_scale = 1.0 / (1.0 - rate); } auto& fused_scale_mask_softmax_dropout_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_value", "mask_fill_value", "dropout_scale_value"); fused_scale_mask_softmax_dropout_attrs.SetAllAttrs(scale, fill_value, dropout_scale); return OpInterpUtil::Dispatch(*fused_op_, {x, bias, mask, dropout_mask}, fused_scale_mask_softmax_dropout_attrs); } private: std::shared_ptr random_mask_op_; std::shared_ptr fused_op_; }; class CtcGreedyDecoderFunctor { public: CtcGreedyDecoderFunctor() { op_ = CHECK_JUST(one::OpBuilder("ctc_greedy_decoder") .Input("log_probs") .Input("input_lengths") .Output("decoded") .Output("neg_sum_logits") .Build()); } Maybe operator()(const std::shared_ptr& log_probs, const std::shared_ptr& input_lengths, const bool& merge_repeated) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("merge_repeated"); attrs.SetAllAttrs(merge_repeated); return OpInterpUtil::Dispatch(*op_, {log_probs, input_lengths}, attrs); } private: std::shared_ptr op_; }; class PariticalFCSampleDisableBoxing { public: PariticalFCSampleDisableBoxing() { op_ = CHECK_JUST(one::OpBuilder("distributed_partial_fc_sample_disable_boxing") .Input("sampled_weight_diff") .Input("sampled_label") .Output("boxing_disabled_sampled_weight_diff") .Output("boxing_disabled_sampled_label") .Build()); } Maybe operator()(const std::shared_ptr& sampled_weight_diff, const std::shared_ptr& sampled_label) const { return OpInterpUtil::Dispatch(*op_, {sampled_weight_diff, sampled_label}); } private: std::shared_ptr op_; }; class NmsFunctor { public: NmsFunctor() { op_ = CHECK_JUST(one::OpBuilder("nms").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const float& iou_threshold, const int32_t& keep_n) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("iou_threshold", "keep_n"); attrs.SetAllAttrs(iou_threshold, keep_n); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_; }; class RoiAlignFunctor { public: RoiAlignFunctor() { op_ = CHECK_JUST(one::OpBuilder("roi_align").Input("x").Input("rois").Output("y").Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& rois, const float& spatial_scale, const int32_t& pooled_h, const int32_t& pooled_w, const int32_t& sampling_ratio, const bool& aligned) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("spatial_scale", "pooled_h", "pooled_w", "sampling_ratio", "aligned"); attrs.SetAllAttrs(spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned); return OpInterpUtil::Dispatch(*op_, {x, rois}, attrs); } private: std::shared_ptr op_; }; class RoiAlignGradFunctor { public: RoiAlignGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("roi_align_grad") .Input("dy") .Input("x_like") .Input("rois") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x_like, const std::shared_ptr& rois, const float& spatial_scale, const int32_t& pooled_h, const int32_t& pooled_w, const int32_t& sampling_ratio, const bool& aligned) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("spatial_scale", "pooled_h", "pooled_w", "sampling_ratio", "aligned"); attrs.SetAllAttrs(spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned); return OpInterpUtil::Dispatch(*op_, {dy, x_like, rois}, attrs); } private: std::shared_ptr op_; }; class FusedDotFeatureInteractionFunctor { public: FusedDotFeatureInteractionFunctor() { 
ops_has_output_concat_.resize(kMaxInputCount); ops_no_output_concat_.resize(kMaxInputCount); for (int n = 0; n < ops_has_output_concat_.size(); ++n) { ops_has_output_concat_[n] = CHECK_JUST(one::OpBuilder("fused_dot_feature_interaction") .Input("features", n + 1) .Input("output_concat") .Output("out") .Build()); } for (int n = 0; n < ops_no_output_concat_.size(); ++n) { ops_no_output_concat_[n] = CHECK_JUST(one::OpBuilder("fused_dot_feature_interaction") .Input("features", n + 1) .Output("out") .Build()); } } Maybe operator()(const TensorTuple& features, const Optional& output_concat, const bool& self_interaction, const int32_t& output_padding, const std::string& pooling) const { const int64_t n_features = features.size(); TensorTuple inputs; if (n_features > kMaxInputCount) { inputs.push_back(JUST(functional::Concat(features, 1))); } else { inputs = features; } CHECK_OR_RETURN(pooling == "sum" || pooling == "none") << Error::RuntimeError() << "pooling should be sum or none, but get " << pooling; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("self_interaction", "output_padding", "pooling", "has_output_concat"); if (pooling == "sum") { CHECK_EQ_OR_RETURN(output_padding, 0) << Error::RuntimeError() << "output_padding should be equal to 0. "; CHECK_OR_RETURN(!output_concat) << Error::RuntimeError() << "output_concat should not exist"; attrs.SetAllAttrs(self_interaction, output_padding, pooling, false); const std::shared_ptr& bi_interaction = JUST(OpInterpUtil::Dispatch( *JUST(oneflow::VectorAt(ops_no_output_concat_, n_features - 1)), inputs, attrs)); std::vector reduce_axes_vec = {1}; return functional::ReduceSum(bi_interaction, reduce_axes_vec, true, NullOpt); } if (output_concat) { attrs.SetAllAttrs(self_interaction, output_padding, pooling, true); inputs.push_back(JUST(output_concat)); return OpInterpUtil::Dispatch( *JUST(oneflow::VectorAt(ops_has_output_concat_, n_features - 1)), inputs, attrs); } else { attrs.SetAllAttrs(self_interaction, output_padding, pooling, false); return OpInterpUtil::Dispatch( *JUST(oneflow::VectorAt(ops_no_output_concat_, n_features - 1)), inputs, attrs); } } private: std::vector> ops_has_output_concat_; std::vector> ops_no_output_concat_; }; class FusedCrossFeatureInteractionFunctor { public: FusedCrossFeatureInteractionFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_cross_feature_interaction") .Input("x") .Input("weight") .Input("x0") .Input("bias") .Output("out") .Output("matmul_result") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& weight, const std::shared_ptr& x0, const std::shared_ptr& bias, const std::string& interaction_mode) const { if (interaction_mode != "vector" && interaction_mode != "matrix") { UNIMPLEMENTED_THEN_RETURN() << "Fused Cross Interaction mode only support `vector` and `matrix`. 
"; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("interaction_mode"); attrs.SetAllAttrs(interaction_mode); return OpInterpUtil::Dispatch(*op_, {x, weight, x0, bias}, attrs); } private: std::shared_ptr op_; }; class OneEmbeddingIdShuffleFunctor { public: OneEmbeddingIdShuffleFunctor() { op_table_ids_has_in_out_ = CHECK_JUST(one::OpBuilder("id_shuffle") .Input("ids") .Input("table_ids") .Output("num_unique_matrix") .Output("inverse_unique_partition_indices") .Output("cur_rank_num_unique") .Output("cur_rank_unique_ids") .Output("cur_rank_unique_table_ids") .Output("cur_rank_inverse_indices") .Build()); op_table_ids_no_in_has_out_ = CHECK_JUST(one::OpBuilder("id_shuffle") .Input("ids") .Output("num_unique_matrix") .Output("inverse_unique_partition_indices") .Output("cur_rank_num_unique") .Output("cur_rank_unique_ids") .Output("cur_rank_unique_table_ids") .Output("cur_rank_inverse_indices") .Build()); } Maybe operator()(const std::shared_ptr& ids, const Optional& table_ids, const int32_t& num_tables, const std::string& embedding_name) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_tables", "embedding_name"); attrs.SetAllAttrs(num_tables, embedding_name); if (table_ids) { return OpInterpUtil::Dispatch(*op_table_ids_has_in_out_, {ids, JUST(table_ids)}, attrs); } else { return OpInterpUtil::Dispatch(*op_table_ids_no_in_has_out_, {ids}, attrs); } } private: std::shared_ptr op_table_ids_has_in_out_; std::shared_ptr op_table_ids_no_in_has_out_; }; class OneEmbeddingEmbeddingShuffleFunctor { public: OneEmbeddingEmbeddingShuffleFunctor() { op_ = CHECK_JUST(one::OpBuilder("embedding_shuffle") .Input("cur_rank_embeddings") .Input("num_unique_matrix") .Input("cur_rank_inverse_indices") .Input("inverse_unique_partition_indices") .Output("embeddings") .Build()); } Maybe operator()(const std::shared_ptr& cur_rank_embeddings, const std::shared_ptr& num_unique_matrix, const std::shared_ptr& cur_rank_inverse_indices, const std::shared_ptr& inverse_unique_partition_indices, const std::string& embedding_name) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("embedding_size", "embedding_name"); const int64_t num_axes = cur_rank_embeddings->shape()->NumAxes(); attrs.SetAllAttrs(cur_rank_embeddings->shape()->At(num_axes - 1), embedding_name); return OpInterpUtil::Dispatch( *op_, {cur_rank_embeddings, num_unique_matrix, cur_rank_inverse_indices, inverse_unique_partition_indices}, attrs); } private: std::shared_ptr op_; }; class OneEmbeddingEmbeddingGradientShuffleFunctor { public: OneEmbeddingEmbeddingGradientShuffleFunctor() { op_ = CHECK_JUST(one::OpBuilder("embedding_gradient_shuffle") .Input("embedding_grad") .Input("num_unique_matrix") .Input("cur_rank_inverse_indices") .Input("inverse_unique_partition_indices") .Output("cur_rank_unique_embedding_grad") .Build()); } Maybe operator()(const std::shared_ptr& embedding_grad, const std::shared_ptr& num_unique_matrix, const std::shared_ptr& cur_rank_inverse_indices, const std::shared_ptr& inverse_unique_partition_indices, const std::string& embedding_name) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("embedding_size", "embedding_name"); const int64_t num_axes = embedding_grad->shape()->NumAxes(); attrs.SetAllAttrs(embedding_grad->shape()->At(num_axes - 1), embedding_name); return OpInterpUtil::Dispatch( *op_, {embedding_grad, num_unique_matrix, cur_rank_inverse_indices, inverse_unique_partition_indices}, attrs); } private: std::shared_ptr op_; }; class OneEmbeddingLookupFunctor { public: OneEmbeddingLookupFunctor() { op_ = 
CHECK_JUST(one::OpBuilder("embedding_lookup") .Input("num_unique_ids") .Input("unique_ids") .Input("table_ids") .Output("unique_values") .Build()); } Maybe operator()(const std::shared_ptr& num_unique_ids, const std::shared_ptr& unique_ids, const std::shared_ptr& table_ids, const Symbol& dtype, const Symbol& embedding_dtype, const int64_t line_size, const int64_t embedding_size, const std::string& embedding_name, const std::string& embedding_tables, const std::string& state_initializer, const int64_t seed) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype", "embedding_dtype", "line_size", "embedding_size", "embedding_name", "embedding_tables", "state_initializer", "seed"); attrs.SetAllAttrs(dtype->data_type(), embedding_dtype->data_type(), line_size, embedding_size, embedding_name, embedding_tables, state_initializer, seed); return OpInterpUtil::Dispatch(*op_, {num_unique_ids, unique_ids, table_ids}, attrs); } private: std::shared_ptr op_; }; class OneEmbeddingFusedLookupFunctor { public: OneEmbeddingFusedLookupFunctor() { op_has_table_ids_ = CHECK_JUST(one::OpBuilder("one_embedding_fused_lookup") .Input("shadow") .Input("ids") .Input("table_ids") .Output("embeddings") .Build()); op_no_table_ids_ = CHECK_JUST(one::OpBuilder("one_embedding_fused_lookup") .Input("shadow") .Input("ids") .Output("embeddings") .Build()); } Maybe operator()(const std::shared_ptr& shadow, const std::shared_ptr& ids, const Optional& table_ids, const Symbol& dtype, const std::string& embedding_name, const int64_t line_size, const int64_t embedding_size, const bool is_full_cache, const int32_t num_tables, const std::string& embedding_tables, const Optional& padding_idx, const int64_t seed) const { int64_t padding_idx_val = -1; bool has_padding_idx = false; if (padding_idx.has_value()) { padding_idx_val = JUST(padding_idx); has_padding_idx = true; } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "dtype", "embedding_name", "line_size", "embedding_size", "is_full_cache", "num_tables", "embedding_tables", "seed", "padding_idx", "has_padding_idx"); attrs.SetAllAttrs(dtype->data_type(), embedding_name, line_size, embedding_size, is_full_cache, num_tables, embedding_tables, seed, padding_idx_val, has_padding_idx); if (table_ids) { const auto& table_ids_shape = *(JUST(table_ids)->shape()); const auto& ids_shape = *(ids->shape()); auto broadcast_table_ids = JUST(table_ids); if (table_ids_shape != ids_shape) { CHECK_LE_OR_RETURN(table_ids_shape.NumAxes(), ids_shape.NumAxes()) << "table_ids num_axes should be less equal to ids num_axes, but got table_ids " "num_axes " << table_ids_shape.NumAxes() << " and ids num_axes " << ids_shape.NumAxes(); const int64_t left_extend_dims = ids_shape.NumAxes() - table_ids_shape.NumAxes(); for (int64_t i = 0; i < table_ids_shape.NumAxes(); i++) { CHECK_EQ_OR_RETURN(table_ids_shape.at(i), ids_shape.at(left_extend_dims + i)) << "when table_ids's shape not equals ids shape, table_ids must be able to be " "broadcast to ids_shape " "but got table_ids_shape: " << table_ids_shape.DebugStr() << ", ids_shape: " << ids_shape.DebugStr(); } broadcast_table_ids = JUST(functional::BroadcastLike(JUST(table_ids), ids, std::vector{})); } return OpInterpUtil::Dispatch(*op_has_table_ids_, {shadow, ids, broadcast_table_ids}, attrs); } else { return OpInterpUtil::Dispatch(*op_no_table_ids_, {shadow, ids}, attrs); } } private: std::shared_ptr op_has_table_ids_; std::shared_ptr op_no_table_ids_; }; class OneEmbeddingFusedLookupGradFunctor { public: OneEmbeddingFusedLookupGradFunctor() { op_ = 
CHECK_JUST(one::OpBuilder("one_embedding_fused_lookup_grad") .Input("ids") .Input("embedding_grad") .Build()); } Maybe operator()(const std::shared_ptr& ids, const std::shared_ptr& embedding_grad, const std::string& embedding_name, const int64_t line_size, const int64_t embedding_size) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("embedding_name", "line_size", "embedding_size"); attrs.SetAllAttrs(embedding_name, line_size, embedding_size); JUST(OpInterpUtil::Dispatch(*op_, {ids, embedding_grad}, attrs)); return Maybe::Ok(); } private: std::shared_ptr op_; }; class OneEmbeddingEmbeddingPutFunctor { public: OneEmbeddingEmbeddingPutFunctor() { op_ = CHECK_JUST(one::OpBuilder("embedding_put") .Input("num_unique_ids") .Input("unique_ids") .Input("unique_embeddings") .Build()); } Maybe operator()(const std::shared_ptr& num_unique_ids, const std::shared_ptr& unique_ids, const std::shared_ptr& unique_embeddings, const std::string& embedding_name, const int64_t line_size) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("embedding_name", "line_size"); attrs.SetAllAttrs(embedding_name, line_size); JUST(OpInterpUtil::Dispatch(*op_, {num_unique_ids, unique_ids, unique_embeddings}, attrs)); return Maybe::Ok(); } private: std::shared_ptr op_; }; class OneEmbeddingUniqueKeyValuePairFunctor { public: OneEmbeddingUniqueKeyValuePairFunctor() { op_has_input_value_ = CHECK_JUST(one::OpBuilder("unique_key_value_pair") .Input("keys") .Input("values") .Output("num_unique") .Output("unique_keys") .Output("unique_values") .Output("inverse_indices") .Build()); op_no_input_value_ = CHECK_JUST(one::OpBuilder("unique_key_value_pair") .Input("keys") .Output("num_unique") .Output("unique_keys") .Output("unique_values") .Output("inverse_indices") .Build()); } Maybe operator()(const std::shared_ptr& keys, const Optional& values, const int32_t num_tables, const std::string& embedding_name) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_tables", "embedding_name"); attrs.SetAllAttrs(num_tables, embedding_name); if (values) { return OpInterpUtil::Dispatch(*op_has_input_value_, {keys, JUST(values)}, attrs); } else { return OpInterpUtil::Dispatch(*op_no_input_value_, {keys}, attrs); } } private: std::shared_ptr op_has_input_value_; std::shared_ptr op_no_input_value_; }; class OneEmbeddingSgdUpdateFunctor { public: OneEmbeddingSgdUpdateFunctor() { // This functor is only used in one_embedding eager mode with lr passed by attr and no optional // input, we also define functor with all optional input just for unittest. when the optional // input learning_rate tensor has passed in, we think all optional input are not None and check // them. 
sgd_no_optional_input_op_ = CHECK_JUST(one::OpBuilder("one_embedding_sgd_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Output("updated_unique_embeddings") .Build()); momentum_no_optional_input_op_ = CHECK_JUST(one::OpBuilder("one_embedding_momentum_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Output("updated_unique_embeddings") .Build()); // This functor is just for unittest sgd_op_ = CHECK_JUST(one::OpBuilder("one_embedding_sgd_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Input("learning_rate") .Input("down_scale_by_tensor") .Input("skip_if") .Output("updated_unique_embeddings") .Build()); momentum_op_ = CHECK_JUST(one::OpBuilder("one_embedding_momentum_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Input("learning_rate") .Input("down_scale_by_tensor") .Input("skip_if") .Output("updated_unique_embeddings") .Build()); } Maybe operator()(const std::shared_ptr& num_unique_ids, const std::shared_ptr& unique_embeddings, const std::shared_ptr& embedding_grad, const Optional& learning_rate, const Optional& down_scale_by_tensor, const Optional& skip_if, const float learning_rate_val, const double scale, const float weight_decay, const float momentum, const int64_t line_size, const int64_t embedding_size, const std::string& embedding_name) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "weight_decay", "line_size", "embedding_size", "embedding_name", "beta"); if (momentum == 0) { attrs.SetAllAttrs(learning_rate_val, scale, weight_decay, line_size, embedding_size, embedding_name, NullOpt); if (learning_rate) { CHECK(down_scale_by_tensor); CHECK(skip_if); return OpInterpUtil::Dispatch( *sgd_op_, {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate), JUST(down_scale_by_tensor), JUST(skip_if)}, attrs); } else { CHECK(!down_scale_by_tensor); CHECK(!skip_if); return OpInterpUtil::Dispatch( *sgd_no_optional_input_op_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs); } } else { attrs.SetAllAttrs(learning_rate_val, scale, weight_decay, line_size, embedding_size, embedding_name, momentum); if (learning_rate) { CHECK(down_scale_by_tensor); CHECK(skip_if); return OpInterpUtil::Dispatch( *momentum_op_, {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate), JUST(down_scale_by_tensor), JUST(skip_if)}, attrs); } else { CHECK(!down_scale_by_tensor); CHECK(!skip_if); return OpInterpUtil::Dispatch(*momentum_no_optional_input_op_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs); } } } private: std::shared_ptr sgd_no_optional_input_op_; std::shared_ptr sgd_op_; std::shared_ptr momentum_no_optional_input_op_; std::shared_ptr momentum_op_; }; class OneEmbeddingAdamUpdateFunctor { public: OneEmbeddingAdamUpdateFunctor() { // This functor is only used in one_embedding eager mode with lr passed by attr and no optional // input, we also define functor with all optional input just for unittest. when the optional // input learning_rate tensor has passed in, we think all optional input are not None and check // them. 
no_optional_input_op_ = CHECK_JUST(one::OpBuilder("one_embedding_adam_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Output("updated_unique_embeddings") .Build()); // This functor is just for unittest no_bias_correction_op_ = CHECK_JUST(one::OpBuilder("one_embedding_adam_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Input("learning_rate") .Input("down_scale_by_tensor") .Input("skip_if") .Output("updated_unique_embeddings") .Build()); do_bias_correction_op_ = CHECK_JUST(one::OpBuilder("one_embedding_adam_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Input("learning_rate") .Input("down_scale_by_tensor") .Input("skip_if") .Input("bias_correction1") .Input("bias_correction2") .Output("updated_unique_embeddings") .Build()); } Maybe operator()( const std::shared_ptr& num_unique_ids, const std::shared_ptr& unique_embeddings, const std::shared_ptr& embedding_grad, const Optional& learning_rate, const Optional& down_scale_by_tensor, const Optional& skip_if, const Optional& bias_correction1, const Optional& bias_correction2, const float learning_rate_val, const double scale, const float weight_decay, const float beta1, const float beta2, const float& bias_correction1_val, const float& bias_correction2_val, const float epsilon, const bool do_bias_correction, const int64_t line_size, const int64_t embedding_size, const std::string& embedding_name) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "learning_rate_val", "scale", "weight_decay", "beta1", "beta2", "epsilon", "bias_correction1_val", "bias_correction2_val", "do_bias_correction", "line_size", "embedding_size", "embedding_name"); attrs.SetAllAttrs(learning_rate_val, scale, weight_decay, beta1, beta2, epsilon, bias_correction1_val, bias_correction2_val, do_bias_correction, line_size, embedding_size, embedding_name); if (learning_rate) { CHECK(down_scale_by_tensor); CHECK(skip_if); if (do_bias_correction) { CHECK(bias_correction1); CHECK(bias_correction2); return OpInterpUtil::Dispatch( *do_bias_correction_op_, {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate), JUST(down_scale_by_tensor), JUST(skip_if), JUST(bias_correction1), JUST(bias_correction2)}, attrs); } else { return OpInterpUtil::Dispatch( *no_bias_correction_op_, {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate), JUST(down_scale_by_tensor), JUST(skip_if)}, attrs); } } else { CHECK(!down_scale_by_tensor); CHECK(!skip_if); CHECK(!bias_correction1); CHECK(!bias_correction2); return OpInterpUtil::Dispatch( *no_optional_input_op_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs); } } private: std::shared_ptr no_bias_correction_op_; std::shared_ptr do_bias_correction_op_; std::shared_ptr no_optional_input_op_; }; class OneEmbeddingAdagradUpdateFunctor { public: OneEmbeddingAdagradUpdateFunctor() { // This functor is only used in one_embedding eager mode with lr passed by attr and no optional // input, we also define functor with all optional input just for unittest. when the optional // input learning_rate tensor has passed in, we think all optional input are not None and check // them. 
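    // For contrast with the attr-only eager path, a hedged sketch of the unit-test path for this
    // Adagrad variant (illustrative call; the exact functional-API binding name may differ): when
    // a learning_rate tensor is supplied, down_scale_by_tensor, skip_if and train_step must also
    // be supplied, and the full-input op below is dispatched:
    //
    //   OneEmbeddingAdagradUpdate(num_unique_ids, unique_embeddings, embedding_grad,
    //                             learning_rate, down_scale_by_tensor, skip_if, train_step,
    //                             /*train_step_val=*/0, /*learning_rate_val=*/0.f, /*scale=*/1.0,
    //                             /*weight_decay=*/0.f, /*lr_decay=*/0.f, /*epsilon=*/1e-10f,
    //                             line_size, embedding_size, embedding_name);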
op_no_optional_input_ = CHECK_JUST(one::OpBuilder("one_embedding_adagrad_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Output("updated_unique_embeddings") .Build()); // This functor is just for unittest op_ = CHECK_JUST(one::OpBuilder("one_embedding_adagrad_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Input("learning_rate") .Input("down_scale_by_tensor") .Input("skip_if") .Input("train_step") .Output("updated_unique_embeddings") .Build()); } Maybe operator()(const std::shared_ptr& num_unique_ids, const std::shared_ptr& unique_embeddings, const std::shared_ptr& embedding_grad, const Optional& learning_rate, const Optional& down_scale_by_tensor, const Optional& skip_if, const Optional& train_step, const int64_t train_step_val, const float learning_rate_val, const double scale, const float weight_decay, const float lr_decay, const float epsilon, const int64_t line_size, const int64_t embedding_size, const std::string& embedding_name) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("train_step_val", "learning_rate_val", "scale", "weight_decay", "lr_decay", "epsilon", "line_size", "embedding_size", "embedding_name"); attrs.SetAllAttrs(train_step_val, learning_rate_val, scale, weight_decay, lr_decay, epsilon, line_size, embedding_size, embedding_name); if (learning_rate) { CHECK(down_scale_by_tensor); CHECK(skip_if); CHECK(train_step); return OpInterpUtil::Dispatch( *op_, {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate), JUST(down_scale_by_tensor), JUST(skip_if), JUST(train_step)}, attrs); } else { CHECK(!down_scale_by_tensor); CHECK(!skip_if); CHECK(!train_step); return OpInterpUtil::Dispatch( *op_no_optional_input_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs); } } private: std::shared_ptr op_; std::shared_ptr op_no_optional_input_; }; class OneEmbeddingFtrlUpdateFunctor { public: OneEmbeddingFtrlUpdateFunctor() { // This functor is only used in one_embedding eager mode with lr passed by attr and no optional // input, we also define functor with all optional input just for unittest. when the optional // input learning_rate tensor has passed in, we think all optional input are not None and check // them. 
op_no_optional_input_ = CHECK_JUST(one::OpBuilder("one_embedding_ftrl_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Output("updated_unique_embeddings") .Build()); // This functor is just for unittest op_ = CHECK_JUST(one::OpBuilder("one_embedding_ftrl_update") .Input("num_unique_ids") .Input("unique_embeddings") .Input("embedding_grad") .Input("learning_rate") .Input("down_scale_by_tensor") .Input("skip_if") .Output("updated_unique_embeddings") .Build()); } Maybe operator()(const std::shared_ptr& num_unique_ids, const std::shared_ptr& unique_embeddings, const std::shared_ptr& embedding_grad, const Optional& learning_rate, const Optional& down_scale_by_tensor, const Optional& skip_if, const float learning_rate_val, const double scale, const float weight_decay, const float lr_power, const float lambda1, const float lambda2, const float beta, const int64_t line_size, const int64_t embedding_size, const std::string& embedding_name) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "weight_decay", "lr_power", "lambda1", "lambda2", "beta", "line_size", "embedding_size", "embedding_name"); attrs.SetAllAttrs(learning_rate_val, scale, weight_decay, lr_power, lambda1, lambda2, beta, line_size, embedding_size, embedding_name); if (learning_rate) { CHECK(down_scale_by_tensor); CHECK(skip_if); return OpInterpUtil::Dispatch( *op_, {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate), JUST(down_scale_by_tensor), JUST(skip_if)}, attrs); } else { CHECK(!down_scale_by_tensor); CHECK(!skip_if); return OpInterpUtil::Dispatch( *op_no_optional_input_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs); } } private: std::shared_ptr op_; std::shared_ptr op_no_optional_input_; }; class DeformConv2dFunctor { public: DeformConv2dFunctor() { bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); deformconv2d_op_ = CHECK_JUST(one::OpBuilder("deform_conv2d") .Input("input") .Input("weight") .Input("offset") .Input("mask") .Output("output") .Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& weight, const std::shared_ptr& offset, const std::shared_ptr& mask, const Optional& bias, const int32_t& stride_h, const int32_t& stride_w, const int32_t& pad_h, const int32_t& pad_w, const int32_t& dilation_h, const int32_t& dilation_w, const int32_t& groups, const int32_t& offset_groups, const bool& use_mask) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("stride_h", "stride_w", "pad_h", "pad_w", "dilation_h", "dilation_w", "groups", "offset_groups", "use_mask"); attrs.SetAllAttrs(stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, groups, offset_groups, use_mask); const std::shared_ptr& deformconv2d_out = JUST( OpInterpUtil::Dispatch(*deformconv2d_op_, {input, weight, offset, mask}, attrs)); if (bias) { auto bias_shape = JUST(bias)->shape(); auto& bias_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); bias_attrs.SetAllAttrs(static_cast(1)); return OpInterpUtil::Dispatch(*bias_op_, {deformconv2d_out, JUST(bias)}, bias_attrs); } return deformconv2d_out; } private: std::shared_ptr deformconv2d_op_; std::shared_ptr bias_op_; }; class RocAucScoreFunctor { public: RocAucScoreFunctor() { op_ = CHECK_JUST( one::OpBuilder("roc_auc_score").Input("label").Input("pred").Output("out").Build()); } Maybe operator()(const std::shared_ptr& label, const std::shared_ptr& pred) const { return OpInterpUtil::Dispatch(*op_, {label, pred}); } private: std::shared_ptr op_; 
}; class MultiTensorSgdUpdateFunctor { public: MultiTensorSgdUpdateFunctor() { op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 0; n < op_.size(); ++n) { op_[n] = CHECK_JUST(one::OpBuilder("multi_tensor_sgd_update") .Input("model", n + 1) .Input("model_diff", n + 1) .Build()); } } Maybe operator()(const TensorTuple& model, const TensorTuple& model_diff, const double& scale, const float& weight_decay, const float& learning_rate_val) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale", "weight_decay", "learning_rate_val"); attrs.SetAllAttrs(scale, weight_decay, learning_rate_val); const int64_t weight_size = model.size(); for (int i = 0; i < weight_size; i += kMaxInputCount) { size_t size = (i + kMaxInputCount) < weight_size ? kMaxInputCount : weight_size - i; TensorTuple input(2 * size); std::copy(model.begin() + i, model.begin() + i + size, input.begin()); std::copy(model_diff.begin() + i, model_diff.begin() + i + size, input.begin() + size); JUST(OpInterpUtil::Dispatch(*op_[size - 1], input, attrs)); } return Maybe::Ok(); } private: std::vector> op_; }; class MultiTensorMomentumUpdateFunctor { public: MultiTensorMomentumUpdateFunctor() { op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 0; n < op_.size(); ++n) { op_[n] = CHECK_JUST(one::OpBuilder("multi_tensor_momentum_update") .Input("model", n + 1) .Input("model_diff", n + 1) .Input("momentum_buf", n + 1) .Build()); } } Maybe operator()(const TensorTuple& model, const TensorTuple& model_diff, const TensorTuple& momentum_buf, const double& scale, const float& weight_decay, const float& learning_rate_val, const float& momentum, const float& dampening, const bool& nesterov, const bool& maximize) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale", "weight_decay", "learning_rate_val", "momentum", "dampening", "nesterov", "maximize"); attrs.SetAllAttrs(scale, weight_decay, learning_rate_val, momentum, dampening, nesterov, maximize); const int64_t weight_size = model.size(); for (int i = 0; i < weight_size; i += kMaxInputCount) { size_t size = (i + kMaxInputCount) < weight_size ? 
kMaxInputCount : weight_size - i; TensorTuple input(3 * size); std::copy(model.begin() + i, model.begin() + i + size, input.begin()); std::copy(model_diff.begin() + i, model_diff.begin() + i + size, input.begin() + size); std::copy(momentum_buf.begin() + i, momentum_buf.begin() + i + size, input.begin() + 2 * size); JUST(OpInterpUtil::Dispatch(*op_[size - 1], input, attrs)); } return Maybe::Ok(); } private: std::vector> op_; }; class MultiTensorAdamUpdateFunctor { public: MultiTensorAdamUpdateFunctor() { op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 0; n < op_.size(); ++n) { op_[n] = CHECK_JUST(one::OpBuilder("multi_tensor_adam_update") .Input("model", n + 1) .Input("model_diff", n + 1) .Input("m", n + 1) .Input("v", n + 1) .Build()); } } Maybe operator()(const TensorTuple& model, const TensorTuple& model_diff, const TensorTuple& m, const TensorTuple& v, const float& learning_rate_val, const float& l2, const float& beta1, const float& beta2, const float& bias_correction1_val, const float& bias_correction2_val, const bool& do_bias_correction, const double& scale, const float& weight_decay, const float& epsilon) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP( "scale", "weight_decay", "beta1", "beta2", "bias_correction1_val", "bias_correction2_val", "do_bias_correction", "learning_rate_val", "l2", "epsilon"); attrs.SetAllAttrs(scale, weight_decay, beta1, beta2, bias_correction1_val, bias_correction2_val, do_bias_correction, learning_rate_val, l2, epsilon); const int64_t weight_size = model.size(); for (int i = 0; i < weight_size; i += kMaxInputCount) { size_t size = (i + kMaxInputCount) < weight_size ? kMaxInputCount : weight_size - i; TensorTuple input(4 * size); std::copy(model.begin() + i, model.begin() + i + size, input.begin()); std::copy(model_diff.begin() + i, model_diff.begin() + i + size, input.begin() + size); std::copy(m.begin() + i, m.begin() + i + size, input.begin() + 2 * size); std::copy(v.begin() + i, v.begin() + i + size, input.begin() + 3 * size); JUST(OpInterpUtil::Dispatch(*op_[size - 1], input, attrs)); } return Maybe::Ok(); } private: std::vector> op_; }; class MatrixVectorProductFunctor { public: MatrixVectorProductFunctor() { matrix_vector_product_op_ = CHECK_JUST( one::OpBuilder("matrix_vector_product").Input("a").Input("b").Output("out").Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& vec) const { const auto& input_shape = input->shape(); const auto& vec_shape = vec->shape(); CHECK_OR_RETURN(input_shape->NumAxes() == 2 && vec_shape->NumAxes() == 1) << Error::RuntimeError() << "vector + matrix @ vector expected, got " << "1, " << input_shape->NumAxes() << ", " << vec_shape->NumAxes(); CHECK_EQ_OR_RETURN(input_shape->at(1), vec_shape->at(0)) << Error::RuntimeError() << "size mismatch, got " << std::to_string(input_shape->at(0)) << ", " << std::to_string(input_shape->at(0)) << "x" << std::to_string(input_shape->at(1)) << ", " << std::to_string(vec_shape->at(0)); return OpInterpUtil::Dispatch(*matrix_vector_product_op_, {input, vec}); } private: std::shared_ptr matrix_vector_product_op_; }; class BatchNormStatsFunctor { public: BatchNormStatsFunctor() { op_ = CHECK_JUST( one::OpBuilder("batch_norm_stats").Input("input").Output("mean").Output("invstd").Build()); } Maybe operator()(const std::shared_ptr& input, const int& axis, const float& eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "eps"); attrs.SetAllAttrs(axis, eps); return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: 
std::shared_ptr op_; }; class BatchNormGatherStatsWithCountsFunctor { public: BatchNormGatherStatsWithCountsFunctor() { op_with_running_mean_and_var_ = CHECK_JUST(one::OpBuilder("batch_norm_gather_stats_with_counts") .Input("input") .Input("mean") .Input("invstd") .Input("counts") .Input("running_mean") .Input("running_var") .Output("global_mean") .Output("global_invstd") .Build()); op_without_running_mean_and_var_ = CHECK_JUST(one::OpBuilder("batch_norm_gather_stats_with_counts") .Input("input") .Input("mean") .Input("invstd") .Input("counts") .Output("global_mean") .Output("global_invstd") .Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& mean, const std::shared_ptr& invstd, const Optional& running_mean, const Optional& running_var, const float& momentum, const float& eps, const std::shared_ptr& counts) const { CHECK_OR_RETURN((running_mean && running_var) || (!running_mean && !running_var)) << Error::RuntimeError() << "Both running_mean and running_var should be None or Tensor at the same time."; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("eps", "momentum"); attrs.SetAllAttrs(eps, momentum); if (running_mean) { return OpInterpUtil::Dispatch( *op_with_running_mean_and_var_, {input, mean, invstd, counts, JUST(running_mean), JUST(running_var)}, attrs); } return OpInterpUtil::Dispatch(*op_without_running_mean_and_var_, {input, mean, invstd, counts}, attrs); } private: std::shared_ptr op_with_running_mean_and_var_; std::shared_ptr op_without_running_mean_and_var_; }; class BatchNormElemtFunctor { public: BatchNormElemtFunctor() { op_ = CHECK_JUST(one::OpBuilder("batch_norm_elemt") .Input("input") .Input("weight") .Input("bias") .Input("mean") .Input("invstd") .Output("output") .Build()); } Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& weight, const std::shared_ptr& bias, const std::shared_ptr& mean, const std::shared_ptr& invstd, const int& axis, const float& eps) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "eps"); attrs.SetAllAttrs(axis, eps); return OpInterpUtil::Dispatch(*op_, {input, weight, bias, mean, invstd}, attrs); } private: std::shared_ptr op_; }; class BatchNormBackwardReduceFunctor { public: BatchNormBackwardReduceFunctor() { op_ = CHECK_JUST(one::OpBuilder("batch_norm_backward_reduce") .Input("grad_out") .Input("input") .Input("mean") .Input("invstd") .Output("sum_dy") .Output("sum_dy_xmu") .Output("grad_weight") .Output("grad_bias") .Build()); } Maybe operator()(const std::shared_ptr& grad_out, const std::shared_ptr& input, const std::shared_ptr& mean, const std::shared_ptr& invstd, const int& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch(*op_, {grad_out, input, mean, invstd}, attrs); } private: std::shared_ptr op_; }; class BatchNormBackwardElemtFunctor { public: BatchNormBackwardElemtFunctor() { op_ = CHECK_JUST(one::OpBuilder("batch_norm_backward_elemt") .Input("grad_out") .Input("input") .Input("mean") .Input("invstd") .Input("weight") .Input("sum_dy") .Input("sum_dy_xmu") .Input("count") .Output("grad_in") .Build()); } Maybe operator()(const std::shared_ptr& grad_out, const std::shared_ptr& input, const std::shared_ptr& mean, const std::shared_ptr& invstd, const std::shared_ptr& weight, const std::shared_ptr& sum_dy, const std::shared_ptr& sum_dy_xmu, const std::shared_ptr& count, const int& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis"); attrs.SetAllAttrs(axis); return OpInterpUtil::Dispatch( *op_, 
{grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count}, attrs); } private: std::shared_ptr op_; }; class FusedFastGeluMulFunctor { public: FusedFastGeluMulFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_fast_gelu_mul") .Input("in") .Input("multiplier") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& multiplier) const { return OpInterpUtil::Dispatch(*op_, {x, multiplier}); } private: std::shared_ptr op_; }; class FusedFastGeluMulGradFunctor { public: FusedFastGeluMulGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_fast_gelu_mul_grad") .Input("out_diff") .Input("in") .Input("multiplier") .Output("in_diff") .Output("multiplier_diff") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& multiplier) const { return OpInterpUtil::Dispatch(*op_, {dy, x, multiplier}); } private: std::shared_ptr op_; }; class GroupedMatmulBiasFunctor { public: GroupedMatmulBiasFunctor() { fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 1; n < fused_op_.size(); ++n) { fused_op_[n] = CHECK_JUST(one::OpBuilder("grouped_matmul_bias") .Input("xs", n) .Input("weights", n) .Input("biases", n) .Output("ys", n) .Build()); } } Maybe operator()(const TensorTuple& xs, const TensorTuple& weights, const TensorTuple& biases) const { const int64_t input_size = xs.size(); const int64_t weight_size = weights.size(); const int64_t bias_size = biases.size(); CHECK_GE_OR_RETURN(input_size, 1) << Error::RuntimeError() << "The number of xs should be greater equal than 1."; CHECK_EQ_OR_RETURN(weight_size, input_size) << Error::RuntimeError() << "The number of weights should be equal to xs."; CHECK_EQ_OR_RETURN(bias_size, input_size) << Error::RuntimeError() << "The number of bias should be equal to xs."; for (int64_t i = 0; i < input_size; ++i) { const auto& input_shape = xs[i]->shape(); const auto& weight_shape = weights[i]->shape(); const auto& bias_shape = biases[i]->shape(); CHECK_GE_OR_RETURN(input_shape->NumAxes(), 2) << Error::RuntimeError() << "x's dim size should greater equal than 2."; CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2) << Error::RuntimeError() << "Weight's dim size should == 2"; CHECK_EQ_OR_RETURN(bias_shape->NumAxes(), 1) << Error::RuntimeError() << "Bias's dim size should == 1"; const int64_t k = input_shape->At(input_shape->NumAxes() - 1); CHECK_EQ_OR_RETURN(weight_shape->At(1), k) << Error::RuntimeError() << "weight's second dim should be equal to input's last dim. "; const int64_t n = weight_shape->At(0); CHECK_EQ_OR_RETURN(bias_shape->At(0), n) << Error::RuntimeError() << "Bias's dim is not equal to weight's first dim. 
"; } TensorTuple input(3 * input_size); std::copy(xs.begin(), xs.end(), input.begin() + 0 * input_size); std::copy(weights.begin(), weights.end(), input.begin() + 1 * input_size); std::copy(biases.begin(), biases.end(), input.begin() + 2 * input_size); return OpInterpUtil::Dispatch(*fused_op_[input_size], input); } private: std::vector> fused_op_; }; class GroupedMatmulFunctor { public: GroupedMatmulFunctor() { fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 1; n < fused_op_.size(); ++n) { fused_op_[n] = CHECK_JUST(one::OpBuilder("grouped_matmul_bias") .Input("xs", n) .Input("weights", n) .Output("ys", n) .Build()); } } Maybe operator()(const TensorTuple& xs, const TensorTuple& weights) const { const int64_t input_size = xs.size(); const int64_t weight_size = weights.size(); CHECK_LT_OR_RETURN(input_size, kMaxInputCount) << Error::RuntimeError() << "input_size size should not be greater than 128"; CHECK_GE_OR_RETURN(input_size, 1) << Error::RuntimeError() << "The number of xs should be greater equal than 1."; CHECK_EQ_OR_RETURN(weight_size, input_size) << Error::RuntimeError() << "The number of weights should be equal to xs."; for (int64_t i = 0; i < input_size; ++i) { const auto& input_shape = xs[i]->shape(); const auto& weight_shape = weights[i]->shape(); CHECK_GE_OR_RETURN(input_shape->NumAxes(), 2) << Error::RuntimeError() << "x's dim size should greater equal than 2."; CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2) << Error::RuntimeError() << "Weight's dim size should == 2"; const int64_t k = input_shape->At(input_shape->NumAxes() - 1); CHECK_EQ_OR_RETURN(weight_shape->At(1), k) << Error::RuntimeError() << "weight's second dim should be equal to input's last dim. "; } TensorTuple input(2 * input_size); std::copy(xs.begin(), xs.end(), input.begin() + 0 * input_size); std::copy(weights.begin(), weights.end(), input.begin() + 1 * input_size); return OpInterpUtil::Dispatch(*fused_op_[input_size], input); } private: std::vector> fused_op_; }; class MultiTensorYoloV5WeightUpdateFunctor { public: MultiTensorYoloV5WeightUpdateFunctor() { op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 0; n < op_.size(); ++n) { op_[n] = CHECK_JUST(one::OpBuilder("multi_tensor_yolov5_weight_update") .Input("model", n + 1) .Input("model_update", n + 1) .Build()); } } Maybe operator()(const TensorTuple& model, const TensorTuple& model_update, const float& d) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("d"); attrs.SetAllAttrs(d); const int64_t weight_size = model.size(); for (int i = 0; i < weight_size; i += kMaxInputCount) { size_t size = (i + kMaxInputCount) < weight_size ? 
kMaxInputCount : weight_size - i; TensorTuple input(size * 2); std::copy(model.begin() + i, model.begin() + i + size, input.begin()); std::copy(model_update.begin() + i, model_update.begin() + i + size, input.begin() + 1 * size); JUST(OpInterpUtil::Dispatch(*op_[size - 1], input, attrs)); } return Maybe::Ok(); } private: std::vector> op_; }; class FusedScaleMaskBiasSoftmaxFunctor { public: FusedScaleMaskBiasSoftmaxFunctor() { op_with_bias_ = CHECK_JUST(one::OpBuilder("fused_scale_mask_bias_softmax") .Input("x") .Input("mask") .Input("bias") .Output("out") .Build()); op_without_bias_ = CHECK_JUST(one::OpBuilder("fused_scale_mask_bias_softmax") .Input("x") .Input("mask") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& mask, const Optional& bias, const float& scale, const bool& inplace = false) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale", "inplace"); attrs.SetAllAttrs(scale, inplace); if (bias) { if (inplace) { std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_with_bias_, {x, mask, JUST(bias)}, outputs.get(), attrs)); return outputs->at(0); } return OpInterpUtil::Dispatch(*op_with_bias_, {x, mask, JUST(bias)}, attrs); ; } if (inplace) { std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_without_bias_, {x, mask}, outputs.get(), attrs)); return outputs->at(0); } return OpInterpUtil::Dispatch(*op_without_bias_, {x, mask}, attrs); } private: std::shared_ptr op_without_bias_; std::shared_ptr op_with_bias_; }; class FusedScaleMaskBiasSoftmaxGradFunctor { public: FusedScaleMaskBiasSoftmaxGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_scale_mask_bias_softmax_grad") .Input("y") .Input("dy") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& dy, const float& scale) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale"); attrs.SetAllAttrs(scale); return OpInterpUtil::Dispatch(*op_, {y, dy}, attrs); } private: std::shared_ptr op_; }; class FusedClipGradFunctor { public: FusedClipGradFunctor() { op_.resize(kMaxInputCount /*the maximum number of inputs*/); for (int n = 0; n < op_.size(); ++n) { op_[n] = CHECK_JUST( one::OpBuilder("fused_clip_grad").Input("model_diff", n + 1).Output("out").Build()); } } Maybe operator()(const TensorTuple& model_diff, const float& max_norm, const float& norm_type) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("max_norm", "norm_type"); attrs.SetAllAttrs(max_norm, norm_type); const int64_t input_size = model_diff.size(); CHECK_LE_OR_RETURN(input_size, kMaxInputCount) << Error::RuntimeError() << "model_diff size should not be greater than 128"; return JUST(OpInterpUtil::Dispatch(*op_[input_size - 1], model_diff, attrs)); } private: std::vector> op_; }; class NonContiguousBinaryOpFunctor { public: NonContiguousBinaryOpFunctor() { op_ = CHECK_JUST( one::OpBuilder("noncontiguous_binary_op").Input("lhs").Input("rhs").Output("y").Build()); } Maybe operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs, const std::string& op, const bool& inplace = false) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("op", "inplace"); attrs.SetAllAttrs(op, inplace); if (inplace) { std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = lhs; JUST(OpInterpUtil::Dispatch(*op_, {lhs, rhs}, outputs.get(), attrs)); return outputs->at(0); } return OpInterpUtil::Dispatch(*op_, {lhs, rhs}, attrs); } private: std::shared_ptr op_; }; class 
NonContiguousBinaryOpGradFunctor { public: NonContiguousBinaryOpGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("noncontiguous_binary_op_grad") .Input("dy") .Input("lhs") .Input("rhs") .Output("dlhs") .Output("drhs") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& lhs, const std::shared_ptr& rhs, const std::string& op, const bool& inplace = false) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("op", "inplace"); attrs.SetAllAttrs(op, inplace); return OpInterpUtil::Dispatch(*op_, {dy, lhs, rhs}, attrs); } private: std::shared_ptr op_; }; namespace { template Maybe pad_last_dim(const std::shared_ptr& input) { auto num_dims = input->shape()->NumAxes(); auto last_dim_size = input->shape()->At(num_dims - 1); if (last_dim_size % alignment_size == 0) { return input; } auto pad_count = alignment_size - (last_dim_size % alignment_size); return JUST(functional::Pad(input, {0, pad_count}, "constant", Scalar(0))); ; } } // namespace class ScaledDotProductFlashAttentionFunctor { public: ScaledDotProductFlashAttentionFunctor() { #if CUDA_VERSION >= 11070 op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention") .Input("query") .Input("key") .Input("value") .Output("out") .Output("softmax_lse") .Output("rng_state") .Build()); #endif // CUDA_VERSION >= 11070 } Maybe operator()(const std::shared_ptr& query, const std::shared_ptr& key, const std::shared_ptr& value, const Optional& attn_mask, const float& dropout_p, const bool& is_causal, const Optional& scale, const int64_t& seed = 0) const { #if CUDA_VERSION >= 11070 const auto og_size = query->shape()->At(3); const auto batch_size = query->shape()->At(0); const auto seqlen_q = query->shape()->At(2); const auto num_heads = query->shape()->At(1); const auto num_heads_k = key->shape()->At(1); const auto max_seqlen_batch_k = key->shape()->At(2); const auto max_seqlen_batch_v = value->shape()->At(2); CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0)) << " key has different batch size from query."; CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0)) << " value has different batch size from query."; CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(1)) << " value has different num_heads from key."; CHECK_EQ_OR_RETURN(max_seqlen_batch_k, max_seqlen_batch_v) << "value has different seqlen from key."; CHECK_EQ_OR_RETURN(og_size, key->shape()->At(3)) << " key has different head dims from query."; CHECK_EQ_OR_RETURN(og_size, value->shape()->At(3)) << " value has different head dims from query."; // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) std::shared_ptr q_padded, k_padded, v_padded; bool padded = og_size % 8; if (padded) { q_padded = JUST(pad_last_dim<8>(query)); k_padded = JUST(pad_last_dim<8>(key)); v_padded = JUST(pad_last_dim<8>(value)); } else { q_padded = query; k_padded = key; v_padded = value; } auto q_ = JUST(functional::Transpose(q_padded, {0, 2, 1, 3})); auto k_ = JUST(functional::Transpose(k_padded, {0, 2, 1, 3})); auto v_ = JUST(functional::Transpose(v_padded, {0, 2, 1, 3})); // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) // Key -> Key (Batch x KV_seq_len x Num_heads x Dim_per_head) // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) const auto& scale_ = scale.has_value() ? 
scale : (1.0f / std::sqrt(static_cast(query->shape()->At(3)))); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal", "window_size_left", "window_size_right", "seed"); attrs.SetAllAttrs(dropout_p, scale_, is_causal, -1, -1, seed); auto gen = JUST(one::DefaultAutoGenerator()); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), query)); const auto& state = std::make_shared(gen); OpExprInterpContext ctx(attrs, state); std::shared_ptr output_ = JUST(OpInterpUtil::Dispatch(*op_, {q_, k_, v_}, ctx)); auto output_padded = JUST(functional::Transpose(output_, {0, 2, 1, 3})); std::shared_ptr output; if (padded) { output = JUST(functional::Slice(output_padded, {0, 0, 0, 0}, {batch_size, num_heads, seqlen_q, og_size}, {1, 1, 1, 1}, false)); } else { output = output_padded; } return output; #endif // CUDA_VERSION >= 11070 UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070."; } private: #if CUDA_VERSION >= 11070 std::shared_ptr op_; #endif // CUDA_VERSION >= 11070 }; class ScaledDotProductFlashAttentionGradFunctor { public: ScaledDotProductFlashAttentionGradFunctor() { #if CUDA_VERSION >= 11070 op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention_grad") .Input("grad_out") .Input("query") .Input("key") .Input("value") .Input("out") .Input("softmax_lse") .Input("rng_state") .Output("grad_q") .Output("grad_k") .Output("grad_v") .Build()); #endif } Maybe operator()( const std::shared_ptr& grad_out, const std::shared_ptr& query, const std::shared_ptr& key, const std::shared_ptr& value, const std::shared_ptr& out, const std::shared_ptr& softmax_lse, const std::shared_ptr& rng_state, const float& dropout_p, const bool& is_causal, const float& scale) const { #if CUDA_VERSION >= 11070 // grad_out(batch x q_sqe_len x num_heads x head_size) // query (batch x q_seq_len x num_heads x head_size_padded) // key (batch x kv_seq_len x num_heads_k x head_size_padded) // value (batch x kv_seq_len x num_heads_k x head_size_padded) // out (batch x kv_seq_len x num_heads x head_size_padded) // softmax_lse (batch x num_heads x q_seq_len) const auto head_size = grad_out->shape()->At(3); const auto head_size_padded = query->shape()->At(3); const auto batch_size = query->shape()->At(0); const auto seqlen_q = query->shape()->At(1); const auto seqlen_k = key->shape()->At(1); const auto num_heads = query->shape()->At(2); const auto num_heads_k = key->shape()->At(2); CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0)) << " key has different batch size from query."; CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0)) << " value has different batch size from query."; CHECK_EQ_OR_RETURN(batch_size, grad_out->shape()->At(0)) << " grad_out has different batch size from query."; CHECK_EQ_OR_RETURN(batch_size, out->shape()->At(0)) << " out has different batch size from query."; CHECK_EQ_OR_RETURN(batch_size, softmax_lse->shape()->At(0)) << " softmax_lse has different batch size from query."; CHECK_EQ_OR_RETURN(num_heads, grad_out->shape()->At(2)) << " grad_out has different num_heads from query."; CHECK_EQ_OR_RETURN(num_heads, softmax_lse->shape()->At(1)) << " softmax_lse has different num_heads from query."; CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(2)) << " value has different num_heads from key."; CHECK_EQ_OR_RETURN(seqlen_q, grad_out->shape()->At(1)) << " grad_out has different seq_len from query."; CHECK_EQ_OR_RETURN(seqlen_q, softmax_lse->shape()->At(2)) << " softmax_lse has different seq_len from query."; CHECK_EQ_OR_RETURN(head_size_padded, 
key->shape()->At(3)) << " key has different head dims from query."; CHECK_EQ_OR_RETURN(head_size_padded, value->shape()->At(3)) << " key has different head dims from query."; CHECK_EQ_OR_RETURN(head_size_padded, out->shape()->At(3)) << " out has different head dims from query."; bool padded = head_size % 8; auto grad_out_ = padded ? JUST(pad_last_dim<8>(grad_out)) : grad_out; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal", "window_size_left", "window_size_right"); attrs.SetAllAttrs(dropout_p, scale, is_causal, -1, -1); auto output = std::make_shared(3); auto output_ = JUST(OpInterpUtil::Dispatch( *op_, {grad_out_, query, key, value, out, softmax_lse, rng_state}, attrs)); CHECK_EQ(output_->size(), 3); auto grad_q_ = (*output_)[0]; auto grad_k_ = (*output_)[1]; auto grad_v_ = (*output_)[2]; std::shared_ptr grad_q_padded, grad_k_padded, grad_v_padded; bool expanded = num_heads != num_heads_k; grad_q_padded = grad_q_; if (expanded) { grad_k_padded = JUST(functional::ReduceSum( JUST(functional::Reshape(grad_k_, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_padded})), {3}, false, grad_k_->dtype())); grad_v_padded = JUST(functional::ReduceSum( JUST(functional::Reshape(grad_v_, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_padded})), {3}, false, grad_v_->dtype())); } else { grad_k_padded = grad_k_; grad_v_padded = grad_v_; } auto grad_q = padded ? JUST(functional::Slice(grad_q_padded, {0, 0, 0, 0}, {batch_size, seqlen_q, num_heads, head_size}, {1, 1, 1, 1}, false)) : grad_q_padded; auto grad_k = padded ? JUST(functional::Slice(grad_k_padded, {0, 0, 0, 0}, {batch_size, seqlen_k, num_heads_k, head_size}, {1, 1, 1, 1}, false)) : grad_k_padded; auto grad_v = padded ? JUST(functional::Slice(grad_v_padded, {0, 0, 0, 0}, {batch_size, seqlen_k, num_heads_k, head_size}, {1, 1, 1, 1}, false)) : grad_v_padded; (*output)[0] = grad_q; (*output)[1] = grad_k; (*output)[2] = grad_v; return output; #endif // CUDA_VERSION >= 11070 UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070."; } private: #if CUDA_VERSION >= 11070 std::shared_ptr op_; #endif // CUDA_VERSION >= 11070 }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("BiasAdd"); m.add_functor("Conv1d"); m.add_functor("Conv2d"); m.add_functor("Conv3d"); m.add_functor("Deconv1d"); m.add_functor("Deconv2d"); m.add_functor("Deconv3d"); m.add_functor("EmbeddingReNorm"); m.add_functor("Embedding"); m.add_functor("MatMul"); m.add_functor("MatMulNoBroadCast"); m.add_functor("BatchMatMul"); m.add_functor("MatrixVectorProduct"); m.add_functor("VectorMatrixProduct"); m.add_functor("TensorDot"); m.add_functor("TensorDotIntDims"); m.add_functor("FusedMLP"); m.add_functor("FusedMatmulBias"); m.add_functor("FusedMatmulBiasAddReluDropout"); m.add_functor("LayerNorm"); m.add_functor("SkipLayerNorm"); m.add_functor("LayerNormAffine"); m.add_functor("GroupNorm"); m.add_functor("TFAvgPool2D"); m.add_functor("MaxPool1D"); m.add_functor("MaxPool2D"); m.add_functor("MaxPool3D"); m.add_functor>("MaxUnpool1D"); m.add_functor>("MaxUnpool2D"); m.add_functor>("MaxUnpool3D"); m.add_functor("AdaptiveAvgPool1D"); m.add_functor("AdaptiveAvgPool2D"); m.add_functor("AdaptiveAvgPool3D"); m.add_functor("AdaptiveMaxPool1D"); m.add_functor("AdaptiveMaxPool2D"); m.add_functor("AdaptiveMaxPool3D"); m.add_functor("L1Loss"); m.add_functor("MseLoss"); m.add_functor("KLDivLoss"); m.add_functor("NLLLoss"); m.add_functor("BinaryCrossEntropyLoss"); 
m.add_functor("BinaryCrossEntropyWithLogitsLoss"); m.add_functor("SparseCrossEntropy"); m.add_functor("SparseCrossEntropyMs"); m.add_functor("CrossEntropy"); m.add_functor("CrossEntropyLabelSmoothing"); m.add_functor("CrossEntropyProb"); m.add_functor("SparseSoftmaxCrossEntropy"); m.add_functor("SoftmaxCrossEntropy"); m.add_functor("SoftmaxCrossEntropyGrad"); m.add_functor("SmoothL1Loss"); m.add_functor("CombinedMarginLoss"); m.add_functor("TripletMarginLoss"); m.add_functor("MarginRankingLoss"); m.add_functor("CtcLoss"); m.add_functor("AffineGrid"); m.add_functor("GridSample"); m.add_functor("Normalization"); m.add_functor("NormalizationAddRelu"); m.add_functor("ConstantPad"); m.add_functor("ReflectionPad"); m.add_functor("ReplicationPad"); m.add_functor("Pad"); m.add_functor("Dropout"); m.add_functor("DropoutGrad"); m.add_functor("Dropout1d"); m.add_functor("Dropout2d"); m.add_functor("Dropout3d"); m.add_functor("PixelShuffle"); m.add_functor("AvgPool1D"); m.add_functor("AvgPool2D"); m.add_functor("AvgPool3D"); m.add_functor("Unfold"); m.add_functor("Fold"); m.add_functor("OneHot"); m.add_functor("FusedSelfAttention"); m.add_functor("FusedSelfAttentionGrad"); m.add_functor("PairwiseDistance"); m.add_functor("CosineSimilarity"); m.add_functor("Normalize"); m.add_functor("L2Normalize"); m.add_functor("L2NormalizeGrad"); m.add_functor("FusedBiasAddGelu"); m.add_functor("FusedBiasAddGeluGrad"); m.add_functor("FusedGlu"); m.add_functor("FusedBiasAddDropout"); m.add_functor("FusedScaleMaskSoftmax"); m.add_functor("FusedScaleMaskSoftmaxDropout"); m.add_functor( "FusedBiasAddScaleMaskSoftmaxDropout"); m.add_functor("FusedScaleTrilSoftmaxMaskScale"); m.add_functor("FusedScaleTril"); m.add_functor("CtcGreedyDecoder"); m.add_functor("DistributedPariticalFCSampleDisableBoxing"); m.add_functor("Nms"); m.add_functor("RoiAlign"); m.add_functor("RoiAlignGrad"); m.add_functor("FusedDotFeatureInteraction"); m.add_functor("FusedCrossFeatureInteraction"); m.add_functor("OneEmbeddingIdShuffle"); m.add_functor("OneEmbeddingEmbeddingShuffle"); m.add_functor( "OneEmbeddingEmbeddingGradientShuffle"); m.add_functor("OneEmbeddingLookup"); m.add_functor("OneEmbeddingFusedLookup"); m.add_functor("OneEmbeddingFusedLookupGrad"); m.add_functor("OneEmbeddingEmbeddingPut"); m.add_functor("OneEmbeddingUniqueKeyValuePair"); m.add_functor("OneEmbeddingSgdUpdate"); m.add_functor("OneEmbeddingAdamUpdate"); m.add_functor("OneEmbeddingAdagradUpdate"); m.add_functor("OneEmbeddingFtrlUpdate"); m.add_functor("RocAucScore"); m.add_functor("MultiTensorSgdUpdate"); m.add_functor("MultiTensorMomentumUpdate"); m.add_functor("MultiTensorAdamUpdate"); m.add_functor("DeformConv2d"); m.add_functor("BatchNormStats"); m.add_functor("BatchNormGatherStatsWithCounts"); m.add_functor("BatchNormElemt"); m.add_functor("BatchNormBackwardReduce"); m.add_functor("BatchNormBackwardElemt"); m.add_functor("FusedFastGeluMul"); m.add_functor("FusedFastGeluMulGrad"); m.add_functor("GroupedMatmulBias"); m.add_functor("GroupedMatmul"); m.add_functor("RMSNorm"); m.add_functor("SkipRMSNorm"); m.add_functor("FusedScaleMaskBiasSoftmax"); m.add_functor("FusedScaleMaskBiasSoftmaxGrad"); m.add_functor("NonContiguousBinaryOp"); m.add_functor("NonContiguousBinaryOpGrad"); m.add_functor("MultiTensorYoloV5WeightUpdate"); m.add_functor("FusedClipGrad"); m.add_functor("ScaledDotProductFlashAttention"); m.add_functor( "ScaledDotProductFlashAttentionGrad"); } } // namespace functional } // namespace one } // namespace oneflow 
================================================ FILE: oneflow/core/functional/impl/nn_grad_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "fmt/core.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { namespace functional { namespace impl { class ConvBiasGradFunctor { public: ConvBiasGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("conv_bias_grad").Input("dy").Output("bias_diff").Build()); } Maybe operator()(const std::shared_ptr& dy, const int32_t& num_spatial_dims, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_spatial_dims", "data_format"); attrs.SetAllAttrs(num_spatial_dims, data_format); return OpInterpUtil::Dispatch(*op_, {dy}, attrs); } private: std::shared_ptr op_; }; class ConvFilterGradFunctor { public: ConvFilterGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("conv_filter_grad").Input("dy").Input("x").Output("filter_diff").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const int32_t& num_spatial_dims, const std::vector& kernel_size, const std::vector& strides, const std::vector& padding_before, const std::vector& dilation_rate, const int32_t& groups, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_spatial_dims", "kernel_size", "strides", "padding_before", "dilation_rate", "groups", "data_format"); attrs.SetAllAttrs(num_spatial_dims, kernel_size, strides, padding_before, dilation_rate, groups, data_format); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } private: std::shared_ptr op_; }; class ConvDataGradFunctor { public: ConvDataGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("conv_data_grad") .Input("dy") .Input("filter") .Input("x_like") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& weight, const std::shared_ptr& x, const int32_t& num_spatial_dims, const std::vector& kernel_size, const std::vector& strides, const std::vector& padding_before, const std::vector& dilation_rate, const int32_t& groups, const std::string& data_format) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_spatial_dims", "kernel_size", "strides", "padding_before", "dilation_rate", "groups", "data_format"); attrs.SetAllAttrs(num_spatial_dims, kernel_size, strides, padding_before, dilation_rate, groups, data_format); return OpInterpUtil::Dispatch(*op_, {dy, weight, JUST(x->detach())}, attrs); } private: std::shared_ptr op_; }; class EmbeddingGradFunctor { public: EmbeddingGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("embedding_grad") .Input("dy") .Input("weight") .Input("indices") .Output("dx") .Build()); } Maybe operator()(const 
std::shared_ptr& dy, const std::shared_ptr& weight, const std::shared_ptr& indices, const int64_t& padding_idx, const bool& scale_grad_by_freq) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("padding_idx", "scale_grad_by_freq"); attrs.SetAllAttrs(padding_idx, scale_grad_by_freq); return OpInterpUtil::Dispatch(*op_, {dy, weight, indices}, attrs); } private: std::shared_ptr op_; }; class MaxPoolNdGradFunctor { public: MaxPoolNdGradFunctor() { for (int ndims = 1; ndims <= 3; ++ndims) { const auto& op_type_name = GetOpTypeName(ndims); op_expr_map_[op_type_name] = CHECK_JUST( one::OpBuilder(op_type_name).Input("dy").Input("x").Input("indice").Output("dx").Build()); } } static std::string GetOpTypeName(const int32_t& ndims) { return "max_pool_" + std::to_string(ndims) + "d_grad"; } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& indice, const std::shared_ptr& dy, const int32_t& ndims, const std::string& data_format, const std::vector& padding, const std::vector& kernel_size, const std::vector& stride, const std::vector& dilation, const bool& return_indices, const bool& ceil_mode) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format", "padding", "kernel_size", "stride", "dilation", "return_indices", "ceil_mode"); attrs.SetAllAttrs(data_format, padding, kernel_size, stride, dilation, return_indices, ceil_mode); const auto& op_type_name = GetOpTypeName(ndims); const auto& it = op_expr_map_.find(op_type_name); CHECK_OR_RETURN(it != op_expr_map_.end()) << Error::RuntimeError() << "Encounter unsupported op " << op_type_name << " in MaxPoolNdGradFunctor."; CHECK_NOTNULL_OR_RETURN(it->second); // NOLINT(maybe-need-error-msg) return OpInterpUtil::Dispatch(*it->second, {dy, x, indice}, attrs); } protected: std::unordered_map> op_expr_map_; }; template class MaxUnpoolNdGradFunctor { public: MaxUnpoolNdGradFunctor() : op_(CHECK_JUST(one::OpBuilder(fmt::format("max_unpool_{}d_grad", N)) .Input("dy") .Input("x") .Input("indices") .Output("dx") .Build())) {} Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& indice, const std::shared_ptr& dy) const { return OpInterpUtil::Dispatch(*op_, {dy, x, indice}); } protected: std::shared_ptr op_; }; class AdaptiveMaxPoolNdGradFunctor { public: AdaptiveMaxPoolNdGradFunctor() { for (int ndims = 1; ndims <= 3; ++ndims) { const auto& op_type_name = GetOpTypeName(ndims); op_expr_map_[op_type_name] = CHECK_JUST( one::OpBuilder(op_type_name).Input("dy").Input("x").Input("index").Output("dx").Build()); } } static std::string GetOpTypeName(const int32_t& ndims) { return "adaptive_max_pool" + std::to_string(ndims) + "d_grad"; } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const std::shared_ptr& index, const int32_t& ndims, const std::string& data_format) const { const auto& op_type_name = GetOpTypeName(ndims); const auto& it = op_expr_map_.find(op_type_name); CHECK_OR_RETURN(it != op_expr_map_.end()) << Error::RuntimeError() << "Encounter unsupported op " << op_type_name << " in AdaptiveMaxPoolNdGradFunctor."; CHECK_NOTNULL_OR_RETURN(it->second); // NOLINT(maybe-need-error-msg) auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format"); attrs.SetAllAttrs(data_format); return OpInterpUtil::Dispatch(*it->second, {dy, x, index}, attrs); } protected: std::unordered_map> op_expr_map_; }; class TFPoolNdGradFunctor { public: TFPoolNdGradFunctor() { for (const auto& mode : {"tf_max", "tf_avg"}) { for (int ndims = 1; ndims <= 3; ++ndims) { const auto& op_type_name = GetOpTypeName(mode, ndims); 
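// Cache one grad-op expression per (mode, ndims) pair; GetOpTypeName yields names like "tf_avg_pool_2d_grad".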
op_expr_map_[op_type_name] = CHECK_JUST( one::OpBuilder(op_type_name).Input("x").Input("y").Input("dy").Output("dx").Build()); } } } static std::string GetOpTypeName(const std::string& mode, const int32_t& ndims) { return mode + "_pool_" + std::to_string(ndims) + "d_grad"; } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y, const std::shared_ptr& dy, const std::string& mode, const int32_t& ndims, const std::string& data_format, const std::string& padding, const std::vector& padding_before, const std::vector& padding_after, const std::vector& pool_size, const std::vector& strides, const bool& ceil_mode) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format", "padding", "padding_before", "padding_after", "pool_size", "strides", "ceil_mode"); attrs.SetAllAttrs(data_format, padding, padding_before, padding_after, pool_size, strides, ceil_mode); const auto& op_type_name = GetOpTypeName(mode, ndims); const auto& it = op_expr_map_.find(op_type_name); CHECK_OR_RETURN(it != op_expr_map_.end()) << Error::RuntimeError() << "Encounter unsupported op " << op_type_name << " in TFPoolNdGradFunctor."; CHECK_NOTNULL_OR_RETURN(it->second); // NOLINT(maybe-need-error-msg) return OpInterpUtil::Dispatch(*it->second, {x, y, dy}, attrs); } protected: std::unordered_map> op_expr_map_; }; class AdaptivePoolNdGradFunctor { public: AdaptivePoolNdGradFunctor() { for (const auto& mode : {"avg"}) { for (int ndims = 1; ndims <= 3; ++ndims) { const auto& op_type_name = GetOpTypeName(mode, ndims); op_expr_map_[op_type_name] = CHECK_JUST(one::OpBuilder(op_type_name).Input("dy").Input("x").Output("dx").Build()); } } } static std::string GetOpTypeName(const std::string& mode, const int32_t& ndims) { return "adaptive_" + mode + "_pool" + std::to_string(ndims) + "d_grad"; } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const std::string& mode, const int32_t& ndims, const std::string& data_format) const { const auto& op_type_name = GetOpTypeName(mode, ndims); const auto& it = op_expr_map_.find(op_type_name); CHECK_OR_RETURN(it != op_expr_map_.end()) << Error::RuntimeError() << "Encounter unsupported op " << op_type_name << " in AdaptivePoolNdGradFunctor."; CHECK_NOTNULL_OR_RETURN(it->second); // NOLINT(maybe-need-error-msg) auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format"); attrs.SetAllAttrs(data_format); return OpInterpUtil::Dispatch(*it->second, {dy, x}, attrs); } protected: std::unordered_map> op_expr_map_; }; class SparseCrossEntropyGradFunctor { public: SparseCrossEntropyGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("sparse_cross_entropy_grad") .Input("prediction") .Input("label") .Input("dy") .Output("prediction_diff") .Build()); } Maybe operator()(const std::shared_ptr& prediction, const std::shared_ptr& label, const std::shared_ptr& dy, const int64_t& depth) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth"); attrs.SetAllAttrs(depth); return OpInterpUtil::Dispatch(*op_, {prediction, label, dy}, attrs); } private: std::shared_ptr op_; }; class SparseCrossEntropyMsGradFunctor { public: SparseCrossEntropyMsGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("sparse_cross_entropy_ms_grad") .Input("prediction") .Input("label") .Input("dy") .Output("prediction_diff") .Build()); } Maybe operator()(const std::shared_ptr& prediction, const std::shared_ptr& label, const std::shared_ptr& dy, const int64_t& depth) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth"); attrs.SetAllAttrs(depth); return OpInterpUtil::Dispatch(*op_, {prediction, label, 
dy}, attrs); } private: std::shared_ptr op_; }; class SparseSoftmaxCrossEntropyGrad { public: SparseSoftmaxCrossEntropyGrad() { op_ = CHECK_JUST(one::OpBuilder("sparse_softmax_cross_entropy_grad") .Input("prob") .Input("label") .Input("dy") .Output("prediction_diff") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& prob, const std::shared_ptr& label, const int64_t& depth) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth"); attrs.SetAllAttrs(depth); return OpInterpUtil::Dispatch(*op_, {prob, label, dy}, attrs); } private: std::shared_ptr op_; }; class SparseSoftmaxCrossEntropyMsGrad { public: SparseSoftmaxCrossEntropyMsGrad() { op_ = CHECK_JUST(one::OpBuilder("sparse_softmax_cross_entropy_ms_grad") .Input("prob") .Input("label") .Input("dy") .Output("prediction_diff") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& prob, const std::shared_ptr& label, const int64_t& depth) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("depth"); attrs.SetAllAttrs(depth); return OpInterpUtil::Dispatch(*op_, {prob, label, dy}, attrs); } private: std::shared_ptr op_; }; class SmoothL1LossGradFunctor { public: SmoothL1LossGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("smooth_l1_loss_grad") .Input("dy") .Input("input") .Input("target") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target, const float& beta) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("beta"); attrs.SetAllAttrs(beta); return OpInterpUtil::Dispatch(*op_, {dy, input, target}, attrs); } private: std::shared_ptr op_; }; class KLDivLossGradFunctor { public: KLDivLossGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("kl_div_loss_grad") .Input("dy") .Input("input") .Input("target") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target, const bool log_target) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("log_target"); attrs.SetAllAttrs(log_target); return OpInterpUtil::Dispatch(*op_, {dy, input, target}, attrs); } private: std::shared_ptr op_; }; class KLDivLossTargetGradFunctor { public: Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target, const bool log_target) const { if (log_target) { return functional::sequence_function(functional::Sub) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(1, input, /*alpha=*/Scalar(1)); }) .then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Exp(target)))) .then(std::bind(functional::Mul, std::placeholders::_1, dy)) .call(target, input, /*alpha=*/1, /*inplace=*/false); } else { return functional::sequence_function(functional::Log) .then([](const std::shared_ptr& input) { return functional::ScalarAdd(1, input, /*alpha=*/Scalar(1)); }) .then(std::bind(functional::Sub, std::placeholders::_1, input, /*alpha=*/1, /*inplace=*/false)) .then(std::bind(functional::Mul, std::placeholders::_1, dy)) .call(target); } } }; class NLLGradFunctor { public: NLLGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("nll_grad") .Input("out_grad") .Input("input") .Input("target") .Output("in_grad") .Build()); op_weight_ = CHECK_JUST(one::OpBuilder("nll_grad") .Input("out_grad") .Input("input") .Input("target") .Input("weight") .Output("in_grad") .Build()); } Maybe operator()(const std::shared_ptr& out_grad, const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const int64_t 
ignore_index) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index"); attrs.SetAllAttrs(ignore_index); if (weight) { return OpInterpUtil::Dispatch( *op_weight_, {out_grad, input, target, JUST(JUST(weight)->detach())}, attrs); } else { return OpInterpUtil::Dispatch(*op_, {out_grad, input, target}, attrs); } } private: std::shared_ptr op_; std::shared_ptr op_weight_; }; class BinaryCrossEntropyLossGradFunctor { public: BinaryCrossEntropyLossGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_grad") .Input("dy") .Input("input") .Input("target") .Output("dx") .Build()); op_weight_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_grad") .Input("dy") .Input("input") .Input("target") .Input("weight") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight) const { if (weight) { return OpInterpUtil::Dispatch(*op_weight_, {dy, input, target, JUST(weight)}); } else { return OpInterpUtil::Dispatch(*op_, {dy, input, target}); } } private: std::shared_ptr op_; std::shared_ptr op_weight_; }; class BinaryCrossEntropyLossTargetGradFunctor { public: Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight) const { auto log_one_sub_input = JUST(functional::Log(JUST(ScalarSub(1, input, /*alpha=*/1)))); auto grad = functional::sequence_function(functional::Log) .then(std::bind(functional::Sub, log_one_sub_input, std::placeholders::_1, /*alpha=*/1, /*inplace=*/false)) .then(std::bind(functional::Mul, dy, std::placeholders::_1)) .call(input); return weight ? Mul(JUST(grad), JUST(weight)) : grad; } }; class BinaryCrossEntropyWithLogitsLossGradFunctor { public: BinaryCrossEntropyWithLogitsLossGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits_grad") .Input("dy") .Input("input") .Input("target") .Output("dx") .Build()); op_weight_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits_grad") .Input("dy") .Input("input") .Input("target") .Input("weight") .Output("dx") .Build()); op_pos_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits_grad") .Input("dy") .Input("input") .Input("target") .Input("pos_weight") .Output("dx") .Build()); op_weight_pos_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits_grad") .Input("dy") .Input("input") .Input("target") .Input("weight") .Input("pos_weight") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const Optional& pos_weight) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("has_pos_weight"); attrs.SetAllAttrs(pos_weight.has_value()); if (weight) { if (pos_weight) { return OpInterpUtil::Dispatch( *op_weight_pos_, {dy, input, target, JUST(weight), JUST(pos_weight)}, attrs); } else { return OpInterpUtil::Dispatch(*op_weight_, {dy, input, target, JUST(weight)}, attrs); } } else { if (pos_weight) { return OpInterpUtil::Dispatch(*op_pos_, {dy, input, target, JUST(pos_weight)}, attrs); } else { return OpInterpUtil::Dispatch(*op_, {dy, input, target}, attrs); } } } private: std::shared_ptr op_; std::shared_ptr op_weight_; std::shared_ptr op_pos_; std::shared_ptr op_weight_pos_; }; class BinaryCrossEntropyWithLogitsLossTargetGradFunctor { public: Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const Optional& pos_weight) const { if 
(pos_weight) { auto sig = JUST(functional::Sigmoid(input)); auto log_one_sub_sig = JUST(functional::Log(JUST(functional::ScalarSub(1, sig, /*alpha=*/1)))); auto grad = functional::sequence_function(functional::Log) .then(std::bind(functional::Mul, std::placeholders::_1, JUST(pos_weight))) .then(std::bind(functional::Sub, log_one_sub_sig, std::placeholders::_1, /*alpha=*/1, false)) .call(sig); return weight ? functional::Mul(JUST(grad), JUST(weight)) : grad; } else { auto grad = functional::sequence_function(functional::Negative) .then(std::bind(functional::Mul, std::placeholders::_1, dy)) .call(input); return weight ? functional::Mul(JUST(grad), JUST(weight)) : grad; } } }; class BinaryCrossEntropyWithLogitsReduceMeanLossGradFunctor { public: BinaryCrossEntropyWithLogitsReduceMeanLossGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("binary_cross_entropy_with_logits_reduce_mean_grad") .Input("dy") .Input("input") .Input("target") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target) const { return OpInterpUtil::Dispatch(*op_, {dy, input, target}); } private: std::shared_ptr op_; std::shared_ptr op_weight_; std::shared_ptr op_pos_; std::shared_ptr op_weight_pos_; }; class BinaryCrossEntropyWithLogitsReduceMeanLossTargetGradFunctor { public: Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& input, const std::shared_ptr& target) const { auto neg_mean_dy = JUST(functional::ScalarMul(-1.0 / input->nelement(), dy)); return functional::Mul(input, neg_mean_dy); } }; class CombinedMarginLossGradFunctor { public: CombinedMarginLossGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("combined_margin_loss_grad") .Input("dy") .Input("label") .Input("theta") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& label, const std::shared_ptr& theta, const float& m1, const float& m2, const float& m3, const int64_t& depth) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("m1", "m2", "m3", "depth"); attrs.SetAllAttrs(m1, m2, m3, depth); return OpInterpUtil::Dispatch(*op_, {dy, label, theta}, attrs); } private: std::shared_ptr op_; }; class AffineGridGradFunctor { public: AffineGridGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("affine_grid_grad").Input("dgrid").Output("dtheta").Build()); } Maybe operator()(const std::shared_ptr& dgrid, const Shape& size, const bool& align_corners) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("size", "align_corners"); attrs.SetAllAttrs(size, align_corners); return OpInterpUtil::Dispatch(*op_, {dgrid}, attrs); } private: std::shared_ptr op_; }; class GridSampleGradFunctor { public: GridSampleGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("grid_sample_grad") .Input("doutput") .Input("input") .Input("grid") .Output("dinput") .Output("dgrid") .Build()); } Maybe operator()(const std::shared_ptr& doutput, const std::shared_ptr& input, const std::shared_ptr& grid, const std::string& interpolation_mode, const std::string& padding_mode, const bool& align_corners) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("interpolation_mode", "padding_mode", "align_corners"); attrs.SetAllAttrs(interpolation_mode, padding_mode, align_corners); return OpInterpUtil::Dispatch(*op_, {doutput, input, grid}, attrs); } private: std::shared_ptr op_; }; class CtcLossGradFunctor { public: CtcLossGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("ctc_loss_grad") .Input("grad_out") .Input("log_probs") .Input("targets") .Input("input_lengths") 
.Input("target_lengths") .Input("loss") .Input("alpha") .Output("grad") .Build()); } Maybe operator()(const std::shared_ptr& grad_out, const std::shared_ptr& log_probs, const std::shared_ptr& targets, const std::shared_ptr& input_lengths, const std::shared_ptr& target_lengths, const std::shared_ptr& loss, const std::shared_ptr& alpha, const int64_t& blank, const bool& zero_infinity, const int64_t& max_target_length) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("blank", "zero_infinity", "max_target_length"); attrs.SetAllAttrs(blank, zero_infinity, max_target_length); if (targets->dtype()->data_type() == DataType::kInt32) { return OpInterpUtil::Dispatch( *op_, {grad_out, log_probs, targets, input_lengths, target_lengths, loss, alpha}, attrs); } else { return OpInterpUtil::Dispatch( *op_, {grad_out, log_probs, JUST(functional::Cast(targets, DType::Int64(), false)), input_lengths, target_lengths, loss, alpha}, attrs); } return OpInterpUtil::Dispatch( *op_, {grad_out, log_probs, targets, input_lengths, target_lengths, loss, alpha}, attrs); } private: std::shared_ptr op_; }; class PadGradFunctor { public: PadGradFunctor() { reflect_pad1d_grad_ = CHECK_JUST(one::OpBuilder("reflection_pad1d_grad").Input("dy").Output("dx").Build()); reflect_pad2d_grad_ = CHECK_JUST(one::OpBuilder("reflection_pad2d_grad").Input("dy").Output("dx").Build()); replicate_pad1d_grad_ = CHECK_JUST(one::OpBuilder("replication_pad1d_grad").Input("dy").Output("dx").Build()); replicate_pad2d_grad_ = CHECK_JUST(one::OpBuilder("replication_pad2d_grad").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::vector& pad, const std::string& mode, const Scalar& value) const { const int64_t ndim = dy->shape()->NumAxes(); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("padding"); attrs.SetAllAttrs(pad); if (mode == "reflect") { if (ndim == 3) { return OpInterpUtil::Dispatch(*reflect_pad1d_grad_, {dy}, attrs); } else if (ndim == 4) { return OpInterpUtil::Dispatch(*reflect_pad2d_grad_, {dy}, attrs); } else { UNIMPLEMENTED_THEN_RETURN() << "only 3D/4D reflect padding are supported for now"; } } else if (mode == "replicate") { if (ndim == 3) { return OpInterpUtil::Dispatch(*replicate_pad1d_grad_, {dy}, attrs); } else if (ndim == 4) { return OpInterpUtil::Dispatch(*replicate_pad2d_grad_, {dy}, attrs); } else { UNIMPLEMENTED_THEN_RETURN() << "only 3D/4D replicate padding are supported for now"; } } else { UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but only constant, reflect and replicate are valid."; } } private: std::shared_ptr reflect_pad1d_grad_; std::shared_ptr reflect_pad2d_grad_; std::shared_ptr replicate_pad1d_grad_; std::shared_ptr replicate_pad2d_grad_; }; class AvgPoolNdGradFunctor { public: AvgPoolNdGradFunctor() { for (int ndims = 1; ndims <= 3; ++ndims) { const auto& op_type_name = GetOpTypeName(ndims); op_expr_map_[op_type_name] = CHECK_JUST(one::OpBuilder(op_type_name).Input("dy").Input("x").Output("dx").Build()); } } static std::string GetOpTypeName(const int32_t& ndims) { return "avg_pool_" + std::to_string(ndims) + "d_grad"; } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& dy, const int32_t& ndims, const std::string& data_format, const std::vector& padding, const std::vector& kernel_size, const std::vector& stride, const bool& ceil_mode, const bool& count_include_pad, const int32_t& divisor_override) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format", "padding", "kernel_size", "stride", "ceil_mode", "count_include_pad", 
"divisor_override"); attrs.SetAllAttrs(data_format, padding, kernel_size, stride, ceil_mode, count_include_pad, divisor_override); const auto& op_type_name = GetOpTypeName(ndims); const auto& it = op_expr_map_.find(op_type_name); CHECK_OR_RETURN(it != op_expr_map_.end()) << Error::RuntimeError() << "Encounter unsupported op " << op_type_name << " in AvgPoolNdGradFunctor."; CHECK_NOTNULL_OR_RETURN(it->second); // NOLINT(maybe-need-error-msg) return OpInterpUtil::Dispatch(*it->second, {dy, x}, attrs); } protected: std::unordered_map> op_expr_map_; }; class NormalizationGradFunctor { public: NormalizationGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("normalization_grad") .Input("dy") .Input("x") .Input("mean") .Input("inv_variance") .Input("gamma") .Output("dx") .Output("gamma_diff") .Output("beta_diff") .Build()); } Maybe operator()(const std::shared_ptr& grad, const std::shared_ptr& x, const std::shared_ptr& mean, const std::shared_ptr& inv_variance, const std::shared_ptr& gamma, const float& epsilon, const int32_t& axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("epsilon", "axis"); attrs.SetAllAttrs(epsilon, axis); return OpInterpUtil::Dispatch(*op_, {grad, x, mean, inv_variance, gamma}, attrs); } private: std::shared_ptr op_; }; class NormalizationAddReluGradFunctor { public: NormalizationAddReluGradFunctor() { addend_op_ = CHECK_JUST(one::OpBuilder("normalization_add_relu_grad") .Input("x") .Input("dy") .Input("mean") .Input("inv_variance") .Input("gamma") .Input("beta") .Input("reserve_space") .Input("y") .Output("dx") .Output("gamma_diff") .Output("beta_diff") .Output("addend_diff") .Build()); no_addend_op_ = CHECK_JUST(one::OpBuilder("normalization_add_relu_grad") .Input("x") .Input("dy") .Input("mean") .Input("inv_variance") .Input("gamma") .Input("beta") .Input("reserve_space") .Input("y") .Output("dx") .Output("gamma_diff") .Output("beta_diff") .Build()); } Maybe operator()( const std::shared_ptr& x, const std::shared_ptr& grad, const std::shared_ptr& mean, const std::shared_ptr& inv_variance, const std::shared_ptr& gamma, const std::shared_ptr& beta, const std::shared_ptr& reserve_space, const std::shared_ptr& y, const int32_t& axis, const float& epsilon, bool has_addend) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis", "epsilon"); attrs.SetAllAttrs(axis, epsilon); if (has_addend) { return OpInterpUtil::Dispatch( *addend_op_, {x, grad, mean, inv_variance, gamma, beta, reserve_space, y}, attrs); } else { return OpInterpUtil::Dispatch( *no_addend_op_, {x, grad, mean, inv_variance, gamma, beta, reserve_space, y}, attrs); } } private: std::shared_ptr addend_op_; std::shared_ptr no_addend_op_; }; class LayerNormGradFunctor { public: LayerNormGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("layer_norm_grad") .Input("dy") .Input("x") .Input("mean") .Input("inv_variance") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& mean, const std::shared_ptr& inv_variance, const int64_t& begin_norm_axis, const double& epsilon) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "epsilon"); attrs.SetAllAttrs(begin_norm_axis, epsilon); return OpInterpUtil::Dispatch(*op_, {dy, x, mean, inv_variance}, attrs); } private: std::shared_ptr op_; }; class LayerNormAffineGradFunctor { public: LayerNormAffineGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("layer_norm_grad") .Input("dy") .Input("x") .Input("mean") .Input("inv_variance") .Input("gamma") .Output("dx") .Build()); } Maybe 
operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& mean, const std::shared_ptr& inv_variance, const std::shared_ptr& gamma, const int64_t& begin_norm_axis, const double& epsilon) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "epsilon"); attrs.SetAllAttrs(begin_norm_axis, epsilon); return OpInterpUtil::Dispatch(*op_, {dy, x, mean, inv_variance, gamma}, attrs); } private: std::shared_ptr op_; }; class FuseLayerNormGradFunctor { public: FuseLayerNormGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad") .Input("dy") .Input("x") .Input("mean") .Input("inv_variance") .Input("gamma") .Output("dx") .Output("gamma_diff") .Output("beta_diff") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& mean, const std::shared_ptr& inv_variance, const std::shared_ptr& gamma, const int64_t& begin_norm_axis, const int64_t& begin_params_axis, const double& epsilon) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon"); attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon); return OpInterpUtil::Dispatch(*op_, {dy, x, mean, inv_variance, gamma}, attrs); } private: std::shared_ptr op_; }; class LayerNormParamGradFunctor { public: LayerNormParamGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("layer_norm_param_grad") .Input("dy") .Input("x") .Input("mean") .Input("inv_variance") .Output("gamma_diff") .Output("beta_diff") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& mean, const std::shared_ptr& inv_variance, const int64_t& begin_params_axis) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_params_axis"); attrs.SetAllAttrs(begin_params_axis); return OpInterpUtil::Dispatch(*op_, {dy, x, mean, inv_variance}, attrs); } private: std::shared_ptr op_; }; class GroupNormGradFunctor { public: GroupNormGradFunctor() { affine_grad_op_ = CHECK_JUST(one::OpBuilder("group_norm_grad") .Input("dy") .Input("x") .Input("mean") .Input("inv_variance") .Input("gamma") .Output("dx") .Build()); grad_op_ = CHECK_JUST(one::OpBuilder("group_norm_grad") .Input("dy") .Input("x") .Input("mean") .Input("inv_variance") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& mean, const std::shared_ptr& inv_variance, const Optional& gamma, const int32_t& num_groups, const double& epsilon) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_groups", "epsilon"); attrs.SetAttr("num_groups", num_groups); attrs.SetAttr("epsilon", epsilon); if (gamma) { return OpInterpUtil::Dispatch(*affine_grad_op_, {dy, x, mean, inv_variance, JUST(gamma)}, attrs); } else { return OpInterpUtil::Dispatch(*grad_op_, {dy, x, mean, inv_variance}, attrs); } } private: std::shared_ptr affine_grad_op_; std::shared_ptr grad_op_; }; class GroupNormParamGradFunctor { public: GroupNormParamGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("group_norm_param_grad") .Input("dy") .Input("x") .Input("mean") .Input("inv_variance") .Output("dgamma") .Output("dbeta") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& mean, const std::shared_ptr& inv_variance) const { return OpInterpUtil::Dispatch(*op_, {dy, x, mean, inv_variance}); } private: std::shared_ptr op_; }; class RMSNormGradFunctor { public: RMSNormGradFunctor() { grad_op_ = CHECK_JUST(one::OpBuilder("rms_norm_grad") .Input("dy") .Input("x") 
.Input("inv_rms") .Output("dx") .Build()); affine_grad_op_ = CHECK_JUST(one::OpBuilder("rms_norm_grad") .Input("dy") .Input("x") .Input("inv_rms") .Input("weight") .Output("dx") .Build()); param_grad_op_ = CHECK_JUST(one::OpBuilder("rms_norm_param_grad") .Input("dy") .Input("x") .Input("inv_rms") .Output("weight_grad") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const std::shared_ptr& inv_rms, const Optional& weight, const bool param_grad) const { if (param_grad) { return OpInterpUtil::Dispatch(*param_grad_op_, {dy, x, inv_rms}); } else if (weight) { return OpInterpUtil::Dispatch(*affine_grad_op_, {dy, x, inv_rms, JUST(weight)}); } else { return OpInterpUtil::Dispatch(*grad_op_, {dy, x, inv_rms}); } } private: std::shared_ptr grad_op_; std::shared_ptr affine_grad_op_; std::shared_ptr param_grad_op_; }; class BroadcastMatmulGradBFunctor { public: BroadcastMatmulGradBFunctor() { op_ = CHECK_JUST( one::OpBuilder("broadcast_matmul_grad_b").Input("a").Input("b").Output("out").Build()); } Maybe operator()(const std::shared_ptr& a, const std::shared_ptr& b, double alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); return OpInterpUtil::Dispatch(*op_, {a, b}, attrs); } private: std::shared_ptr op_; }; class FusedScaleTrilSoftmaxMaskScaleGradFunctor { public: FusedScaleTrilSoftmaxMaskScaleGradFunctor() { fused_op_ = CHECK_JUST(one::OpBuilder("fused_tril_scale_softmax_mask_scale_grad") .Input("softmax_y") .Input("dy") .Input("mask") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& softmax_y, const std::shared_ptr& dy, const std::shared_ptr& mask, const int64_t diagonal, const float tril_scale_value, const float mask_scale_value) const { auto& fused_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("diagonal", "tril_scale_value", "mask_scale_value"); fused_attrs.SetAllAttrs(diagonal, tril_scale_value, mask_scale_value); return OpInterpUtil::Dispatch(*fused_op_, {softmax_y, dy, mask}, fused_attrs); } private: std::shared_ptr fused_op_; }; class FusedScaleMaskSoftmaxGradFunctor { public: FusedScaleMaskSoftmaxGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_scale_mask_softmax_grad") .Input("y") .Input("dy") .Input("mask") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& y, const std::shared_ptr& dy, const std::shared_ptr& mask, const float& scale) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_value"); attrs.SetAllAttrs(scale); return OpInterpUtil::Dispatch(*op_, {y, dy, mask}, attrs); } private: std::shared_ptr op_; }; class FusedScaleMaskSoftmaxDropoutGradFunctor { public: FusedScaleMaskSoftmaxDropoutGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_scale_mask_softmax_dropout_grad") .Input("softmax_y") .Input("dy") .Input("mask") .Input("dropout_mask") .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& softmax_y, const std::shared_ptr& dy, const std::shared_ptr& mask, const std::shared_ptr& dropout_mask, const float& scale, const float& dropout_scale) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale_value", "dropout_scale_value"); attrs.SetAllAttrs(scale, dropout_scale); return OpInterpUtil::Dispatch(*op_, {softmax_y, dy, mask, dropout_mask}, attrs); } private: std::shared_ptr op_; }; class CublasBiasAddReluMatmulGradFunctor { public: CublasBiasAddReluMatmulGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("cublas_bias_add_relu_matmul_grad") .Input("dy") .Input("weight") .Input("aux") .Output("d_grad") .Output("d_bias") .Build()); } Maybe 
operator()(const std::shared_ptr& dy, const std::shared_ptr& weight, const std::shared_ptr& aux, const double& alpha) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha"); attrs.SetAllAttrs(alpha); return OpInterpUtil::Dispatch(*op_, {dy, weight, aux}, attrs); } private: std::shared_ptr op_; }; class CublasMatmulBiasAddGradFunctor { public: CublasMatmulBiasAddGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("cublas_matmul_bias_add_grad") .Input("dy") .Input("x") .Output("w_grad") .Output("b_grad") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x) const { return OpInterpUtil::Dispatch(*op_, {dy, x}); } private: std::shared_ptr op_; }; class FusedReluDropoutGradFunctor { public: FusedReluDropoutGradFunctor() { op_ = CHECK_JUST( one::OpBuilder("fused_relu_dropout_grad").Input("dy").Input("mask").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& mask, const float& scale) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("scale"); attrs.SetAllAttrs(scale); return OpInterpUtil::Dispatch(*op_, {dy, mask}, attrs); } private: std::shared_ptr op_; }; class FusedDotFeatureInteractionGradFunctor { public: FusedDotFeatureInteractionGradFunctor() { ops_has_output_concat_grad_.resize(kMaxInputCount); ops_no_output_concat_grad_.resize(kMaxInputCount); for (int n = 0; n < ops_has_output_concat_grad_.size(); ++n) { ops_has_output_concat_grad_[n] = CHECK_JUST(one::OpBuilder("fused_dot_feature_interaction_grad") .Input("dy") .Input("features", n + 1) .Output("features_grad", n + 1) .Output("output_concat_grad") .Build()); } for (int n = 0; n < ops_no_output_concat_grad_.size(); ++n) { ops_no_output_concat_grad_[n] = CHECK_JUST(one::OpBuilder("fused_dot_feature_interaction_grad") .Input("dy") .Input("features", n + 1) .Output("features_grad", n + 1) .Build()); } } Maybe operator()(const std::shared_ptr& dy, const TensorTuple& features, const bool& has_output_concat, const bool& self_interaction, const int32_t& output_concat_grad_dim, const std::string& pooling) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("self_interaction", "output_concat_grad_dim", "pooling"); attrs.SetAllAttrs(self_interaction, output_concat_grad_dim, pooling); CHECK_OR_RETURN(pooling == "sum" || pooling == "none") << Error::RuntimeError() << "pooling should be sum or none, but get " << pooling << ". 
"; const int64_t n_features_grad = features.size(); CHECK_LE_OR_RETURN(n_features_grad, kMaxInputCount) << Error::RuntimeError() << "The number of tensors in features should be less than 128."; TensorTuple inputs(n_features_grad + 1); inputs[0] = dy; for (int32_t i = 0; i < n_features_grad; ++i) { inputs[i + 1] = features[i]; } if (has_output_concat) { return OpInterpUtil::Dispatch( *JUST(oneflow::VectorAt(ops_has_output_concat_grad_, n_features_grad - 1)), inputs, attrs); } else { return OpInterpUtil::Dispatch( *JUST(oneflow::VectorAt(ops_no_output_concat_grad_, n_features_grad - 1)), inputs, attrs); } } private: std::vector> ops_has_output_concat_grad_; std::vector> ops_no_output_concat_grad_; }; class FusedCrossFeatureInteractionV1GradFunctor { public: FusedCrossFeatureInteractionV1GradFunctor() { v1_grad_op_ = CHECK_JUST(one::OpBuilder("fused_cross_feature_interaction_v1_grad") .Input("dy") .Input("weight") .Input("x") .Input("x0") .Input("matmul_result") .Output("dx") .Output("dw") .Output("dx0") .Output("dbias") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& weight, const std::shared_ptr& x, const std::shared_ptr& x0, const std::shared_ptr& matmul_result) const { return OpInterpUtil::Dispatch(*v1_grad_op_, {dy, weight, x, x0, matmul_result}); } private: std::shared_ptr v1_grad_op_; }; class FusedCrossFeatureInteractionV2GradFunctor { public: FusedCrossFeatureInteractionV2GradFunctor() { v2_grad_op_ = CHECK_JUST(one::OpBuilder("fused_cross_feature_interaction_v2_grad") .Input("dy") .Input("weight") .Input("bias") .Input("x") .Input("x0") .Input("matmul_result") .Output("dx") .Output("dw") .Output("dx0") .Output("dbias") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& weight, const std::shared_ptr& bias, const std::shared_ptr& x, const std::shared_ptr& x0, const std::shared_ptr& matmul_result) const { return OpInterpUtil::Dispatch(*v2_grad_op_, {dy, weight, bias, x, x0, matmul_result}); } private: std::shared_ptr v2_grad_op_; }; class MatrixVectorProductGradAFunctor { public: MatrixVectorProductGradAFunctor() { matrix_vector_product_grad_a_op_ = CHECK_JUST( one::OpBuilder("matrix_vector_product_grad_a").Input("dy").Input("b").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& b) const { return OpInterpUtil::Dispatch(*matrix_vector_product_grad_a_op_, {dy, b}); } private: std::shared_ptr matrix_vector_product_grad_a_op_; }; class MatrixVectorProductGradBFunctor { public: MatrixVectorProductGradBFunctor() { matrix_vector_product_grad_b_op_ = CHECK_JUST( one::OpBuilder("matrix_vector_product_grad_b").Input("dy").Input("a").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& a) const { return OpInterpUtil::Dispatch(*matrix_vector_product_grad_b_op_, {dy, a}); } private: std::shared_ptr matrix_vector_product_grad_b_op_; }; class VectorMatrixProductGradAFunctor { public: VectorMatrixProductGradAFunctor() { vector_matrix_product_grad_a_op_ = CHECK_JUST( one::OpBuilder("vector_matrix_product_grad_a").Input("dy").Input("b").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& b) const { return OpInterpUtil::Dispatch(*vector_matrix_product_grad_a_op_, {dy, b}); } private: std::shared_ptr vector_matrix_product_grad_a_op_; }; class VectorMatrixProductGradBFunctor { public: VectorMatrixProductGradBFunctor() { vector_matrix_product_grad_b_op_ = CHECK_JUST( 
one::OpBuilder("vector_matrix_product_grad_b").Input("dy").Input("a").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& a) const { return OpInterpUtil::Dispatch(*vector_matrix_product_grad_b_op_, {dy, a}); } private: std::shared_ptr vector_matrix_product_grad_b_op_; }; class DeformConv2dInputGradFunctor { public: DeformConv2dInputGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("deform_conv2d_input_grad") .Input("output_grad") .Input("input") .Input("weight") .Input("offset") .Output("input_grad") .Output("offset_grad") .Build()); mask_op_ = CHECK_JUST(one::OpBuilder("deform_conv2d_input_grad") .Input("output_grad") .Input("input") .Input("weight") .Input("offset") .Input("mask") .Output("input_grad") .Output("offset_grad") .Output("mask_grad") .Build()); } Maybe operator()(const std::shared_ptr& output_grad, const std::shared_ptr& input, const std::shared_ptr& weight, const std::shared_ptr& offset, const Optional& mask, const int32_t& stride_h, const int32_t& stride_w, const int32_t& pad_h, const int32_t& pad_w, const int32_t& dilation_h, const int32_t& dilation_w, const int32_t& groups, const int32_t& offset_groups, const bool& use_mask) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("stride_h", "stride_w", "pad_h", "pad_w", "dilation_h", "dilation_w", "groups", "offset_groups", "use_mask"); attrs.SetAllAttrs(stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, groups, offset_groups, use_mask); if (mask) { return OpInterpUtil::Dispatch( *mask_op_, {output_grad, input, weight, offset, JUST(mask)}, attrs); } else { return OpInterpUtil::Dispatch(*op_, {output_grad, input, weight, offset}, attrs); } } private: std::shared_ptr op_; std::shared_ptr mask_op_; }; class DeformConv2dParamGradFunctor { public: DeformConv2dParamGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("deform_conv2d_param_grad") .Input("output_grad") .Input("input") .Input("weight") .Input("offset") .Input("mask") .Output("weight_grad") .Build()); } Maybe operator()(const std::shared_ptr& output_grad, const std::shared_ptr& input, const std::shared_ptr& weight, const std::shared_ptr& offset, const std::shared_ptr& mask, const int32_t& stride_h, const int32_t& stride_w, const int32_t& pad_h, const int32_t& pad_w, const int32_t& dilation_h, const int32_t& dilation_w, const int32_t& groups, const int32_t& offset_groups, const bool& use_mask) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("stride_h", "stride_w", "pad_h", "pad_w", "dilation_h", "dilation_w", "groups", "offset_groups", "use_mask"); attrs.SetAllAttrs(stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, groups, offset_groups, use_mask); return OpInterpUtil::Dispatch(*op_, {output_grad, input, weight, offset, mask}, attrs); } private: std::shared_ptr op_; }; class FusedGluWithoutLinearGradFunctor { public: FusedGluWithoutLinearGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("fused_glu_without_linear_grad") .Input("dy") .Input("matmul_wx") .Output("d_matmul_wx") .Build()); split_op_ = CHECK_JUST(one::OpBuilder("fused_glu_without_linear_grad") .Input("dy") .Input("matmul_wx") .Input("matmul_vx") .Output("d_matmul_wx") .Output("d_matmul_vx") .Build()); } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& matmul_wx, const Optional& matmul_vx, const std::string& activation) const { // check whether the user provide splited tensors bool is_split_mode = false; if (matmul_vx) { is_split_mode = true; } // obtain input shape const auto& dy_shape = *(dy->shape()); const auto& matmul_wx_shape = 
*(matmul_wx->shape()); // check number of axes of dy and matmul_wx size_t dy_num_axes = dy_shape.NumAxes(); size_t matmul_wx_num_axes = matmul_wx_shape.NumAxes(); CHECK_GT_OR_RETURN(dy_num_axes, 1) << "number of axes of \'dy\' should have be greater than 1, yet get " << dy_num_axes; CHECK_GE_OR_RETURN(matmul_wx_num_axes, 2) << "number of axes of \'matmul_wx\' should have be greater than 1, yet get " << matmul_wx_num_axes; CHECK_EQ_OR_RETURN(dy_num_axes, matmul_wx_num_axes) << "number of axes of \'matmul_wx\' (" << matmul_wx_num_axes << ") should equal to the one of \'dy\' (" << dy_num_axes << ")"; // check input shapes of dy and matmul_wx for (uint64_t i = 0; i < dy_num_axes - 1; i++) { size_t dy_size = dy_shape.At(i); size_t matmul_wx_size = matmul_wx_shape.At(i); CHECK_EQ_OR_RETURN(dy_size, matmul_wx_size) << "dimension " << i << "of \'dy\'(" << dy_size << ") and \'matmul_wx\'(" << matmul_wx_size << ") is not consistent"; } if (is_split_mode) { CHECK_EQ_OR_RETURN(dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1)) << "last dimension of \'dy\'(" << dy_shape.At(dy_num_axes - 1) << ") and \'matmul_wx\'(" << matmul_wx_shape.At(matmul_wx_num_axes - 1) << ") is not consistent"; } else { CHECK_EQ_OR_RETURN(2 * dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1)) << "two times of the last dimension of \'dy\'(" << 2 * (dy_shape.At(dy_num_axes - 1)) << ") and \'matmul_wx\'(" << matmul_wx_shape.At(matmul_wx_num_axes - 1) << ") is not consistent"; } if (is_split_mode) { // obtain input shape const auto& matmul_vx_shape = *(JUST(matmul_vx)->shape()); // check number of axes of dy and matmul_vx size_t matmul_vx_num_axes = matmul_vx_shape.NumAxes(); CHECK_EQ_OR_RETURN(dy_num_axes, matmul_vx_num_axes) << "number of axes of \'matmul_vx\' (" << matmul_vx_num_axes << ") should equal to the one of \'dy\' (" << dy_num_axes << ")"; // check input shapes of dy and matmul_vx for (uint64_t i = 0; i < dy_num_axes - 1; i++) { size_t dy_size = dy_shape.At(i); size_t matmul_vx_size = matmul_vx_shape.At(i); CHECK_EQ_OR_RETURN(dy_size, matmul_vx_size) << "dimension " << i << "of \'dy\'(" << dy_size << ") and \'matmul_vx\'(" << matmul_vx_size << ") is not consistent"; } CHECK_EQ_OR_RETURN(dy_shape.At(dy_num_axes - 1), matmul_vx_shape.At(matmul_vx_num_axes - 1)) << "last dimension of \'dy\'(" << dy_shape.At(dy_num_axes - 1) << ") and \'matmul_vx\'(" << matmul_vx_shape.At(matmul_vx_num_axes - 1) << ") is not consistent"; } // set activation attribute auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("activation"); attrs.SetAllAttrs(activation); // dispatch corresponding operator if (is_split_mode) { return OpInterpUtil::Dispatch(*split_op_, {dy, matmul_wx, JUST(matmul_vx)}, attrs); } else { return OpInterpUtil::Dispatch(*op_, {dy, matmul_wx}, attrs); } } private: std::shared_ptr op_; std::shared_ptr split_op_; }; class FusedMLPGradFunctor { public: FusedMLPGradFunctor() { #if CUDA_VERSION >= 11060 fused_op_.resize(kMaxInputCount /*the maximum number of layers*/); for (int n = 1; n < fused_op_.size(); ++n) { fused_op_[n] = CHECK_JUST(one::OpBuilder("cublas_fused_mlp_grad") .Input("dy") .Input("x") .Input("weights", n) .Input("cublas_aux", n) .Input("hidden", n) .Output("d_x") .Output("d_biases", n) .Output("d_weights", n) .Build()); } #endif } Maybe operator()(const std::shared_ptr& dy, const std::shared_ptr& x, const TensorTuple& weights, const TensorTuple& cublas_aux, const TensorTuple& hidden, const std::vector& alpha_list) const { const int64_t weight_size = weights.size(); 
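// Inputs for the grad op are packed below as [dy, x, weights..., cublas_aux..., hidden...];
// alpha_list must provide one scale per inter-layer matmul, i.e. weight_size - 1 entries.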
CHECK_EQ_OR_RETURN(alpha_list.size(), weight_size - 1) << "Alpha list size should be equal to weight_size - 1. "; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("alpha_list"); attrs.SetAllAttrs(alpha_list); TensorTuple input(2 + 3 * weight_size); input[0] = dy; input[1] = x; std::copy(weights.begin(), weights.end(), input.begin() + 2); std::copy(cublas_aux.begin(), cublas_aux.end(), input.begin() + 2 + weight_size); std::copy(hidden.begin(), hidden.end(), input.begin() + 2 + 2 * weight_size); #if CUDA_VERSION >= 11060 return OpInterpUtil::Dispatch(*fused_op_[weight_size], input, attrs); #endif UNIMPLEMENTED_THEN_RETURN() << "Only Support in CUDA_VERSION >= 11060"; } private: #if CUDA_VERSION >= 11060 std::vector> fused_op_; #endif }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("ConvBiasGrad"); m.add_functor("ConvFilterGrad"); m.add_functor("ConvDataGrad"); m.add_functor("EmbeddingGrad"); m.add_functor("TFPoolNdGrad"); m.add_functor("AdaptivePoolNdGrad"); m.add_functor("KLDivLossGrad"); m.add_functor("KLDivLossTargetGrad"); m.add_functor("NLLGrad"); m.add_functor("BinaryCrossEntropyLossGrad"); m.add_functor("BinaryCrossEntropyLossTargetGrad"); m.add_functor( "BinaryCrossEntropyWithLogitsLossGrad"); m.add_functor( "BinaryCrossEntropyWithLogitsLossTargetGrad"); m.add_functor("SparseCrossEntropyGrad"); m.add_functor("SparseCrossEntropyMsGrad"); m.add_functor("SparseSoftmaxCrossEntropyGrad"); m.add_functor("SparseSoftmaxCrossEntropyMsGrad"); m.add_functor("SmoothL1LossGrad"); m.add_functor("CombinedMarginLossGrad"); m.add_functor("AffineGridGrad"); m.add_functor("GridSampleGrad"); m.add_functor("MaxPoolNdGrad"); m.add_functor>("MaxUnpool1dGrad"); m.add_functor>("MaxUnpool2dGrad"); m.add_functor>("MaxUnpool3dGrad"); m.add_functor("AdaptiveMaxPoolNdGrad"); m.add_functor("PadGrad"); m.add_functor("AvgPoolNdGrad"); m.add_functor("NormalizationGrad"); m.add_functor("NormalizationAddReluGrad"); m.add_functor("LayerNormGrad"); m.add_functor("LayerNormAffineGrad"); m.add_functor("LayerNormParamGrad"); m.add_functor("FuseLayerNormGrad"); m.add_functor("GroupNormGrad"); m.add_functor("GroupNormParamGrad"); m.add_functor("BroadcastMatmulGradB"); m.add_functor("CtcLossGrad"); m.add_functor( "FusedScaleTrilSoftmaxMaskScaleGrad"); m.add_functor("FusedScaleMaskSoftmaxGrad"); m.add_functor("FusedScaleMaskSoftmaxDropoutGrad"); m.add_functor("CublasBiasAddReluMatmulGrad"); m.add_functor("CublasMatmulBiasAddGrad"); m.add_functor("FusedReluDropoutGrad"); m.add_functor("FusedDotFeatureInteractionGrad"); m.add_functor( "FusedCrossFeatureInteractionV1Grad"); m.add_functor( "FusedCrossFeatureInteractionV2Grad"); m.add_functor("FusedGluWithoutLinearGrad"); m.add_functor("FusedMLPGrad"); m.add_functor( "BinaryCrossEntropyWithLogitsReduceMeanLossGrad"); m.add_functor( "BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad"); m.add_functor("MatrixVectorProductGradA"); m.add_functor("MatrixVectorProductGradB"); m.add_functor("VectorMatrixProductGradA"); m.add_functor("VectorMatrixProductGradB"); m.add_functor("DeformConv2dInputGrad"); m.add_functor("DeformConv2dParamGrad"); m.add_functor("RMSNormGrad"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/quantization.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/functional/impl/binary_functor.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/function_library.h" namespace oneflow { namespace one { namespace functional { namespace impl { class MinMaxObserverFunctor { public: MinMaxObserverFunctor() { op_ = CHECK_JUST(one::OpBuilder("min_max_observer") .Input("in") .Output("scale") .Output("zero_point") .Build()); } Maybe operator()(const std::shared_ptr& in, const std::string& quantization_formula, const int32_t& quantization_bit, const std::string& quantization_scheme, const bool& per_layer_quantization) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("quantization_formula", "quantization_bit", "quantization_scheme", "per_layer_quantization"); attrs.SetAllAttrs(quantization_formula, quantization_bit, quantization_scheme, per_layer_quantization); return OpInterpUtil::Dispatch(*op_, {in}, attrs); } private: std::shared_ptr op_; }; class MovingAverageMinMaxObserverFunctor { public: MovingAverageMinMaxObserverFunctor() { op_ = CHECK_JUST(one::OpBuilder("moving_average_min_max_observer") .Input("in") .Input("current_train_step") .Input("moving_max") .Input("moving_min") .Output("scale") .Output("zero_point") .Build()); } Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& current_train_step, const std::shared_ptr& moving_max, const std::shared_ptr& moving_min, const bool& training, const int64_t& stop_update_after_iters, const std::string& quantization_formula, const int32_t& quantization_bit, const std::string& quantization_scheme, const float& momentum) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("training", "quantization_formula", "stop_update_after_iters", "quantization_bit", "quantization_scheme", "momentum"); attrs.SetAllAttrs(training, quantization_formula, stop_update_after_iters, quantization_bit, quantization_scheme, momentum); return OpInterpUtil::Dispatch( *op_, {in, current_train_step, moving_max, moving_min}, attrs); } private: std::shared_ptr op_; }; class FakeQuantizationFunctor { public: FakeQuantizationFunctor() { op_ = CHECK_JUST(one::OpBuilder("fake_quantization") .Input("in") .Input("scale") .Input("zero_point") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& scale, const std::shared_ptr& zero_point, const std::string& quantization_formula, const int32_t& quantization_bit, const std::string& quantization_scheme) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("quantization_formula", "quantization_bit", "quantization_scheme"); attrs.SetAllAttrs(quantization_formula, quantization_bit, quantization_scheme); return OpInterpUtil::Dispatch(*op_, {in, scale, zero_point}, attrs); } private: std::shared_ptr op_; }; class QuantizationFunctor { public: QuantizationFunctor() { op_ = CHECK_JUST(one::OpBuilder("quantization") 
.Input("in") .Input("scale") .Input("zero_point") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& scale, const std::shared_ptr& zero_point, const std::string quantization_formula, const int32_t& quantization_bit, const std::string quantization_scheme) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("quantization_formula", "quantization_bit", "quantization_scheme"); attrs.SetAllAttrs(quantization_formula, quantization_bit, quantization_scheme); return OpInterpUtil::Dispatch(*op_, {in, scale, zero_point}, attrs); } private: std::shared_ptr op_; }; class GroupwiseDequantizeFunctor { public: GroupwiseDequantizeFunctor() { symmetric_op_ = CHECK_JUST( one::OpBuilder("groupwise_dequantize").Input("in").Input("scale").Output("out").Build()); asymmetric_op_ = CHECK_JUST(one::OpBuilder("groupwise_dequantize") .Input("in") .Input("scale") .Input("zero") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& scale, const Optional& zero, const int32_t& num_bits, const bool& symmetric, const int64_t& group_dim, const int64_t& group_size) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_bits", "symmetric", "group_dim", "group_size"); CHECK_OR_RETURN(num_bits == 4 || num_bits == 8) << "num_bits should be 4 or 8."; CHECK_GE_OR_RETURN(in->shape()->NumAxes(), 1) << "The number of dimensions for tensor in should be greater than or equal to 1."; const int64_t regularized_group_dim = group_dim < 0 ? in->shape()->NumAxes() + group_dim : group_dim; CHECK_OR_RETURN(regularized_group_dim >= 0 && regularized_group_dim < in->shape()->NumAxes()) << "group_dim should be in range [-" << in->shape()->NumAxes() << "," << in->shape()->NumAxes() << ")."; const int64_t group_dim_size = in->shape()->At(regularized_group_dim) * (regularized_group_dim == in->shape()->NumAxes() - 1 ? 8 / num_bits : 1); const int64_t regularized_group_size = group_size < 0 ? 
group_dim_size : group_size; CHECK_OR_RETURN(regularized_group_size > 0 && regularized_group_size <= group_dim_size) << "group_size should be in range (0," << group_dim_size << "]."; CHECK_EQ_OR_RETURN(group_dim_size % regularized_group_size, 0) << "group_size should be a divisor of " << group_dim_size << "."; const int64_t num_groups = group_dim_size / regularized_group_size; if (symmetric) { CHECK_OR_RETURN(in->dtype()->data_type() == DataType::kUInt8 || in->dtype()->data_type() == DataType::kInt8) << "The dtype of tensor in should be int8 or uint8."; } else { CHECK_OR_RETURN(in->dtype()->data_type() == DataType::kUInt8) << "The dtype of tensor in should be uint8."; } CHECK_EQ_OR_RETURN(scale->shape()->NumAxes(), in->shape()->NumAxes()) << "The number of dimensions of tensor scale should be equal to tensor in."; for (int64_t i = 0; i < in->shape()->NumAxes(); ++i) { if (i == regularized_group_dim) { CHECK_EQ_OR_RETURN(scale->shape()->At(i), num_groups) << "The size of the " << i << "-th dimension of tensor scale should be equal to " << num_groups; } else if (i == in->shape()->NumAxes() - 1) { CHECK_EQ_OR_RETURN(scale->shape()->At(i), in->shape()->At(i) * (8 / num_bits)) << "The size of the " << i << "-th dimension of tensor scale should be equal to " << in->shape()->At(i) * (8 / num_bits) << "."; } else { CHECK_EQ_OR_RETURN(scale->shape()->At(i), in->shape()->At(i)) << "The size of the " << i << "-th dimension of tensor scale should be equal to tensor in."; } } if (!symmetric) { CHECK_OR_RETURN(zero) << "When symmetric is False, tensor zero should be specified."; CHECK_OR_RETURN(JUST(zero)->dtype() == scale->dtype()) << "The dtype of the zero tensor should be the same as the scale " "tensor."; CHECK_OR_RETURN(*JUST(zero)->shape() == *scale->shape()) << "The shape of zero tensor should be equal to tensor scale."; } else { CHECK_OR_RETURN(!zero) << "When symmetric is True, tensor zero should be None."; } attrs.SetAllAttrs(num_bits, symmetric, regularized_group_dim, regularized_group_size); if (symmetric) { return OpInterpUtil::Dispatch(*symmetric_op_, {in, scale}, attrs); } else { return OpInterpUtil::Dispatch(*asymmetric_op_, {in, scale, JUST(zero)}, attrs); } } private: std::shared_ptr symmetric_op_; std::shared_ptr asymmetric_op_; }; class FusedLinearWithGroupwiseQuantizedWeightFunctor { public: FusedLinearWithGroupwiseQuantizedWeightFunctor() { symmetric_with_bias_op_ = CHECK_JUST(one::OpBuilder("fused_linear_with_groupwise_quantized_weight") .Input("x") .Input("w") .Input("w_scale") .Input("b") .Output("out") .Build()); symmetric_without_bias_op_ = CHECK_JUST(one::OpBuilder("fused_linear_with_groupwise_quantized_weight") .Input("x") .Input("w") .Input("w_scale") .Output("out") .Build()); asymmetric_with_bias_op_ = CHECK_JUST(one::OpBuilder("fused_linear_with_groupwise_quantized_weight") .Input("x") .Input("w") .Input("w_scale") .Input("w_zero") .Input("b") .Output("out") .Build()); asymmetric_without_bias_op_ = CHECK_JUST(one::OpBuilder("fused_linear_with_groupwise_quantized_weight") .Input("x") .Input("w") .Input("w_scale") .Input("w_zero") .Output("out") .Build()); } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& w, const std::shared_ptr& w_scale, const Optional& w_zero, const Optional& b, const int32_t& num_bits, const bool& symmetric, const int64_t& group_dim, const int64_t& group_size) const { CHECK_GE_OR_RETURN(x->shape()->NumAxes(), 2) << "The number of dimensions for tensor x should be greater than or equal to 2."; const int64_t m = x->shape()->Count(0, 
x->shape()->NumAxes() - 1); const int64_t k = x->shape()->At(x->shape()->NumAxes() - 1); CHECK_OR_RETURN(num_bits == 4 || num_bits == 8) << "num_bits should be 4 or 8."; CHECK_EQ_OR_RETURN(w->shape()->NumAxes(), 2) << "The number of dimensions for tensor w should be equal to 2."; CHECK_EQ_OR_RETURN(k % (8 / num_bits), 0) << "The size of the last dimension of x should be a multiple of (8/num_bits)."; CHECK_EQ_OR_RETURN(w->shape()->At(1), k / (8 / num_bits)) << "The size of second dimension of tensor w should be equal to " << k / (8 / num_bits); const int64_t n = w->shape()->At(0); const int64_t regularized_group_dim = group_dim < 0 ? w->shape()->NumAxes() + group_dim : group_dim; CHECK_OR_RETURN(regularized_group_dim == 0 || regularized_group_dim == 1) << "group_dim should be in range [-2,2)."; const int64_t group_dim_size = regularized_group_dim == 0 ? n : k; const int64_t regularized_group_size = group_size < 0 ? group_dim_size : group_size; CHECK_OR_RETURN(regularized_group_size > 0 && regularized_group_size <= group_dim_size) << "group_size should be in range (0," << group_dim_size << "]."; CHECK_EQ_OR_RETURN(group_dim_size % regularized_group_size, 0) << "group_size should be a divisor of " << group_dim_size << "."; const int64_t num_groups = group_dim_size / regularized_group_size; if (symmetric) { CHECK_OR_RETURN(w->dtype()->data_type() == DataType::kUInt8 || w->dtype()->data_type() == DataType::kInt8) << "The dtype of tensor w should be int8 or uint8."; } else { CHECK_OR_RETURN(w->dtype()->data_type() == DataType::kUInt8) << "The dtype of tensor w should be uint8."; } CHECK_EQ_OR_RETURN(w_scale->shape()->NumAxes(), 2) << "The number of dimensions of tensor w_scale should be equal to 2."; for (int64_t i = 0; i < 2; ++i) { if (i == regularized_group_dim) { CHECK_EQ_OR_RETURN(w_scale->shape()->At(i), num_groups) << "The size of the " << i << "-th dimension of tensor w_scale should be equal to " << num_groups; } else if (i == 1) { CHECK_EQ_OR_RETURN(w_scale->shape()->At(i), k) << "The size of the " << i << "-th dimension of tensor w_scale should be equal to " << k << "."; } else { CHECK_EQ_OR_RETURN(w_scale->shape()->At(i), w->shape()->At(i)) << "The size of the " << i << "-th dimension of tensor w_scale should be equal to tensor w."; } } CHECK_OR_RETURN(w_scale->dtype() == x->dtype()) << "The dtype of the w_scale tensor should be the same as the x tensor."; if (!symmetric) { CHECK_OR_RETURN(w_zero) << "When symmetric is False, tensor w_zero should be specified."; CHECK_OR_RETURN(JUST(w_zero)->dtype() == w_scale->dtype()) << "The dtype of the w_zero tensor should be the same as the w_scale " "tensor."; CHECK_OR_RETURN(*JUST(w_zero)->shape() == *w_scale->shape()) << "The shape of w_zero tensor should be equal to tensor w_scale."; } else { CHECK_OR_RETURN(!w_zero) << "When symmetric is True, tensor w_zero should be None."; } if (b) { CHECK_OR_RETURN(JUST(b)->dtype() == x->dtype()) << "The dtype of the b tensor should be the same as the x tensor."; CHECK_EQ_OR_RETURN(JUST(b)->shape()->NumAxes(), 1) << "The number of dimensions for tensor b should be equal to 1."; CHECK_EQ_OR_RETURN(JUST(b)->shape()->At(0), n) << "The size of first dimension of tensor b should be equal to the size of first " "dimension of tensor w"; } if (m > 8) { const auto w_dequantized = JUST(functional::GroupwiseDequantize( w, w_scale, w_zero, num_bits, symmetric, group_dim, group_size)); if (b) { return JUST(functional::FusedMatmulBias(x, w_dequantized, JUST(b), Optional(), 1.0, 1.0)); } else { return 
JUST(functional::MatMul(x, w_dequantized, false, true, 1.0)); } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_bits", "symmetric", "group_dim", "group_size"); attrs.SetAllAttrs(num_bits, symmetric, regularized_group_dim, regularized_group_size); if (symmetric) { if (b) { return OpInterpUtil::Dispatch(*symmetric_with_bias_op_, {x, w, w_scale, JUST(b)}, attrs); } else { return OpInterpUtil::Dispatch(*symmetric_without_bias_op_, {x, w, w_scale}, attrs); } } else { if (b) { return OpInterpUtil::Dispatch(*asymmetric_with_bias_op_, {x, w, w_scale, JUST(w_zero), JUST(b)}, attrs); } else { return OpInterpUtil::Dispatch(*asymmetric_without_bias_op_, {x, w, w_scale, JUST(w_zero)}, attrs); } } } private: std::shared_ptr symmetric_with_bias_op_; std::shared_ptr symmetric_without_bias_op_; std::shared_ptr asymmetric_with_bias_op_; std::shared_ptr asymmetric_without_bias_op_; }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("FakeQuantization"); }; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Quantization"); }; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("MinMaxObserver"); }; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("MovingAverageMinMaxObserver"); }; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("GroupwiseDequantize"); m.add_functor( "FusedLinearWithGroupwiseQuantizedWeight"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/random_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/framework/layout.h" #include "oneflow/core/framework/mutable_attr_map.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/job/global_mode.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/functional/functional.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace one { namespace functional { namespace impl { class BernoulliFunctor { public: BernoulliFunctor() { bernoulli_op_ = CHECK_JUST(one::OpBuilder("bernoulli").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Symbol& dtype, const Optional& generator, const bool& inplace) const { if (x->is_global()) { JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc()))); } auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x)); auto& bernoulli_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype", "seed", "p"); // p == -1 means bernoulli op doesn't use p to generate random number bernoulli_attrs.SetAllAttrs(dtype->data_type(), static_cast(gen->current_seed()), static_cast(-1)); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(bernoulli_attrs, distribution_state); if (inplace) { auto outputs = std::make_shared(1); JUST(CheckInplaceValid(x)); (*outputs)[0] = x; JUST(OpInterpUtil::Dispatch(*bernoulli_op_, {x}, outputs.get(), ctx)); return outputs->at(0); } else { return OpInterpUtil::Dispatch(*bernoulli_op_, {x}, ctx); } } private: std::shared_ptr bernoulli_op_; }; class BernoulliInplaceFunctor { public: Maybe operator()(const std::shared_ptr& x, const Symbol& dtype, const Optional& generator) const { return Bernoulli(x, dtype, generator, true); } }; class BernoulliProbFunctor { public: BernoulliProbFunctor() { bernoulli_op_ = CHECK_JUST(one::OpBuilder("bernoulli").Input("in").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const double& p, const Symbol& dtype, const Optional& generator, const bool& inplace) const { CHECK_OR_THROW(p >= 0.0 && p <= 1.0) << "bernoulli expects p to be in [0, 1], but got p=" << p; if (x->is_global()) { JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc()))); } auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x)); auto& bernoulli_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype", "seed", "p"); bernoulli_attrs.SetAllAttrs(dtype->data_type(), static_cast(gen->current_seed()), p); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(bernoulli_attrs, distribution_state); if (inplace) { auto outputs = std::make_shared(1); JUST(CheckInplaceValid(x)); (*outputs)[0] = x; JUST(OpInterpUtil::Dispatch(*bernoulli_op_, {x}, outputs.get(), ctx)); return outputs->at(0); } else { return OpInterpUtil::Dispatch(*bernoulli_op_, {x}, ctx); } } private: std::shared_ptr bernoulli_op_; }; class BernoulliProbInplaceFunctor { public: Maybe operator()(const std::shared_ptr& x, const double& p, const Symbol& dtype, const Optional& generator) const { return BernoulliProb(x, p, dtype, generator, true); } }; class InplaceUniformFunctor { public: InplaceUniformFunctor() { uniform_op_ = 
CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); uniform_int_op_ = CHECK_JUST(one::OpBuilder("uniform_int").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const Scalar& from, const Scalar& to) const { JUST(CheckInplaceValid(x)); const Shape& shape = *(x->shape()); std::shared_ptr exec_op; const auto& dtype = x->dtype(); bool IsInteger = false; if (dtype->is_floating_point()) { exec_op = uniform_op_; } else if (dtype->is_integer()) { exec_op = uniform_int_op_; IsInteger = true; } else { OF_UNIMPLEMENTED() << "Only support floating and int dtype."; } DataType dtype_val = dtype->data_type(); Optional> device; Optional> placement; Optional> nd_sbp; auto gen = JUST(one::DefaultAutoGenerator()); if (x->is_global()) { JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc()))); placement = JUST(x->parallel_desc()); nd_sbp = JUST(x->nd_sbp()); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp)); } else { device = JUST(x->device()); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("from", "to", "shape", "dtype", "seed", "nd_sbp"); Optional> attr_nd_sbp{NullOpt}; if (nd_sbp) { attr_nd_sbp = *JUST(GetNdSbpStrList(JUST(nd_sbp))); } if (IsInteger) { attrs.SetAllAttrs(from.Value(), to.Value(), shape, dtype_val, static_cast(gen->current_seed()), attr_nd_sbp); } else { attrs.SetAllAttrs(from.Value(), to.Value(), shape, dtype_val, static_cast(gen->current_seed()), attr_nd_sbp); } const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(attrs, distribution_state); ctx.parallel_desc = placement; ctx.nd_sbp = nd_sbp; ctx.device = device; auto outputs = std::make_shared(1); (*outputs)[0] = x; JUST(OpInterpUtil::Dispatch(*exec_op, {}, outputs.get(), ctx)); return outputs->at(0); } private: std::shared_ptr uniform_op_; std::shared_ptr uniform_int_op_; }; class RandFunctor { public: RandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); } Maybe operator()(const Shape& shape, const Optional>& dtype, const Optional>& device, const Optional& generator, const bool& requires_grad) const { if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST(functional::GlobalRand(shape, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, requires_grad)); } DataType dtype_val = GetDefaultDType()->data_type(); if (dtype.has_value()) { dtype_val = JUST(dtype)->data_type(); if (!JUST(dtype)->is_floating_point()) { OF_UNIMPLEMENTED() << "Only support floating dtype in rand()."; } } auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("from", "to", "shape", "dtype", "seed"); attrs.SetAllAttrs(static_cast(0), static_cast(1), shape, dtype_val, static_cast(gen->current_seed())); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(attrs, distribution_state); ctx.device = device; auto result = JUST(OpInterpUtil::Dispatch(*op_, {}, ctx)); JUST(result->set_requires_grad(requires_grad)); return result; } private: std::shared_ptr op_; }; class GlobalRandFunctor { public: GlobalRandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); } Maybe operator()(const Shape& shape, const Symbol& placement, const std::vector>& sbp_tuple, const Optional>& dtype, const Optional& 
generator, const bool& requires_grad) const { DataType dtype_val = GetDefaultDType()->data_type(); if (dtype.has_value()) { dtype_val = JUST(dtype)->data_type(); if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) { OF_UNIMPLEMENTED() << "Only support floating dtype in rand()."; } } JUST(CheckDeviceIdsIsValid(placement)); const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto attr_nd_sbp = *JUST(GetNdSbpStrList(nd_sbp)); auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("from", "to", "shape", "dtype", "seed", "nd_sbp"); attrs.SetAllAttrs(static_cast(0), static_cast(1), shape, dtype_val, static_cast(gen->current_seed()), attr_nd_sbp); const auto& distribution_state = std::make_shared(gen); auto result = JUST(OpInterpUtil::Dispatch( *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state))); JUST(result->set_requires_grad(requires_grad)); return result; } private: std::shared_ptr op_; }; class RandNFunctor { public: Maybe operator()(const Shape& shape, const Optional>& dtype, const Optional>& device, const Optional& generator, const bool& requires_grad, const Symbol& layout) const { if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST(functional::GlobalRandN(shape, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, requires_grad)); } if (dtype.has_value() && !JUST(dtype)->is_floating_point()) { OF_UNIMPLEMENTED() << "Only support floating dtype in randn()."; } const auto& out = Optional(); return Normal(static_cast(0), static_cast(1), shape, out, dtype, device, generator, requires_grad); } }; class GlobalRandNFunctor { public: Maybe operator()(const Shape& shape, const Symbol& placement, const std::vector>& sbp_tuple, const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { if (dtype.has_value() && !JUST(dtype)->is_floating_point()) { OF_UNIMPLEMENTED() << "Only support floating dtype in randn()."; } const auto& out = Optional(); return GlobalNormal(static_cast(0), static_cast(1), shape, out, placement, sbp_tuple, dtype, generator, requires_grad); } }; class NormalFunctor { public: NormalFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); } Maybe operator()(const float mean, const float std, const Shape& shape, const Optional& out, const Optional>& optional_dtype, const Optional>& optional_device, const Optional& optional_generator, const bool requires_grad) const { Symbol dtype = GetDefaultDType(); if (optional_dtype.has_value()) { if (!JUST(optional_dtype)->is_floating_point()) { OF_UNIMPLEMENTED() << "Only support float and double in normal()."; } dtype = JUST(optional_dtype); } Symbol device = JUST(Device::New("cpu")); if (optional_device.has_value()) { device = JUST(optional_device); } if (out.has_value()) { auto out_tensor = JUST(out); CHECK_OR_RETURN(shape == (*out_tensor->shape())) << "Shape of out_tensor does not match shape. 
" << "Expected shape: " << shape << ", actual shape: " << *out_tensor->shape(); Symbol output_tensor_dtype = out_tensor->dtype(); if (optional_dtype.has_value()) { CHECK_OR_RETURN(output_tensor_dtype == dtype) << Error::RuntimeError() << "data type " << dtype->name() << " does not match data type of out parameter " << output_tensor_dtype->name(); } dtype = output_tensor_dtype; Symbol out_tensor_device = JUST(out_tensor->device()); if (optional_device.has_value()) { CHECK_OR_RETURN(out_tensor_device == JUST(optional_device)) << Error::RuntimeError() << "device type " << device->ToString() << " does not match device type of out parameter " << out_tensor_device->ToString(); } device = out_tensor_device; } auto gen = optional_generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("mean", "std", "shape", "dtype", "seed"); attrs.SetAllAttrs(static_cast(mean), static_cast(std), shape, dtype->data_type(), static_cast(gen->current_seed())); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(attrs, device, distribution_state); if (out.has_value()) { std::shared_ptr outputs = std::make_shared(1); (*outputs)[0] = JUST(out); JUST(OpInterpUtil::Dispatch(*op_, {}, outputs.get(), ctx)); return (*outputs)[0]; } auto result = JUST(OpInterpUtil::Dispatch(*op_, {}, ctx)); JUST(result->set_requires_grad(requires_grad)); return result; } private: std::shared_ptr op_; }; class Normal2Functor { public: Maybe operator()(const float mean, const float std, const int32_t shape, const Optional& out, const Optional>& optional_dtype, const Optional>& optional_device, const Optional& optional_generator, const bool requires_grad) const { const Shape size = Shape({shape}); return Normal(mean, std, size, out, optional_dtype, optional_device, optional_generator, requires_grad); } }; class InplaceNormalFuctor { public: Maybe operator()(const std::shared_ptr& x, const float mean, const float std, const Optional& optional_generator) const { return Normal(mean, std, *x->shape(), x, x->dtype(), JUST(x->device()), optional_generator, x->requires_grad()); } }; class TensorTensorNormalFunctor { public: Maybe operator()(const std::shared_ptr& mean, const std::shared_ptr& std, const Optional& out, const Optional& optional_generator, const bool requires_grad) const { JUST(CheckNormalTensorStd(std)); auto out_shape = *JUST(InferUnifiedShapeForBroadcasting({*mean->shape(), *std->shape()})); auto output = JUST(Normal(0, 1, out_shape, out, Symbol(mean->dtype()), JUST(mean->device()), optional_generator, requires_grad)); // mean + output * std JUST(InplaceMul(output, std)); JUST(Add(output, mean, 1, true)); JUST(output->set_requires_grad(requires_grad)); return output; } }; class TensorScalarNormalFunctor { public: // TODO : performance optimizing Write as a kenerl Maybe operator()(const std::shared_ptr& mean, const float std, const Optional& out, const Optional& optional_generator, const bool requires_grad) const { JUST(CheckNormalTensorStd(std)); auto output = JUST(Normal(0, std, *(mean->shape()), out, mean->dtype(), JUST(mean->device()), optional_generator, requires_grad)); JUST(Add(output, mean, 1, true)); JUST(output->set_requires_grad(requires_grad)); return output; } }; class ScalarTensorNormalFunctor { public: // TODO : performance optimizing one multiplication and one addition Write as a kenerl Maybe operator()(const float mean, const std::shared_ptr& std, const Optional& 
out, const Optional& optional_generator, const bool requires_grad) const { JUST(CheckNormalTensorStd(std)); auto output = JUST(Normal(0.0, 1.0, *(std->shape()), out, std->dtype(), JUST(std->device()), optional_generator, requires_grad)); JUST(InplaceMul(output, std)); JUST(ScalarAdd(output, mean, 1, true)); JUST(output->set_requires_grad(requires_grad)); return output; } }; class GlobalNormalFunctor { public: GlobalNormalFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); } Maybe operator()(const float& mean, const float& std, const Shape& shape, const Optional& out, const Symbol& placement, const std::vector>& sbp_tuple, const Optional>& optional_dtype, const Optional& optional_generator, const bool& requires_grad) const { Symbol dtype = DType::Float(); if (optional_dtype.has_value()) { if (!JUST(optional_dtype)->is_floating_point()) { OF_UNIMPLEMENTED() << "Only support float and double in normal()."; } dtype = JUST(optional_dtype); } if (out.has_value()) { auto out_tensor = JUST(out); Symbol output_tensor_dtype = out_tensor->dtype(); if (optional_dtype.has_value()) { CHECK_OR_RETURN(output_tensor_dtype == dtype) << Error::RuntimeError() << "data type " << dtype->name() << " does not match data type of out parameter (" << output_tensor_dtype->name(); } dtype = output_tensor_dtype; } JUST(CheckDeviceIdsIsValid(placement)); const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto attr_nd_sbp = *JUST(GetNdSbpStrList(nd_sbp)); std::shared_ptr gen = optional_generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("mean", "std", "shape", "dtype", "seed", "nd_sbp"); attrs.SetAllAttrs(static_cast(mean), static_cast(std), shape, dtype->data_type(), static_cast(gen->current_seed()), attr_nd_sbp); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(attrs, placement, nd_sbp, distribution_state); if (out.has_value()) { std::shared_ptr outputs = std::make_shared(1); (*outputs)[0] = JUST(out); JUST(OpInterpUtil::Dispatch(*op_, {}, outputs.get(), ctx)); return (*outputs)[0]; } auto result = JUST(OpInterpUtil::Dispatch(*op_, {}, ctx)); JUST(result->set_requires_grad(requires_grad)); return result; } private: std::shared_ptr op_; }; class GlobalNormal2Functor { public: Maybe operator()(const float& mean, const float& std, const int32_t& shape, const Optional& out, const Symbol& placement, const std::vector>& sbp_tuple, const Optional>& optional_dtype, const Optional& optional_generator, const bool& requires_grad) const { const Shape size = Shape({shape}); return GlobalNormal(mean, std, size, out, placement, sbp_tuple, optional_dtype, optional_generator, requires_grad); } }; class RandnLikeFunctor { public: Maybe operator()(const std::shared_ptr& input, const Optional>& dtype, const Optional>& device, const Optional& generator, const bool& requires_grad) const { return RandN(*input->shape(), dtype.value_or(input->dtype()), device.value_or(JUST(input->device())), generator, requires_grad, Layout::Strided()); } }; class GlobalRandnLikeFunctor { public: Maybe operator()(const std::shared_ptr& input, const Symbol& placement, const std::vector>& sbp, const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { return GlobalRandN(*input->shape(), placement, sbp, dtype.value_or(input->dtype()), generator, requires_grad); } }; class RandIntFunctor { public: RandIntFunctor() { op_ = 
CHECK_JUST(one::OpBuilder("uniform_int").Output("out").Build()); } Maybe operator()(const int64_t low, const int64_t high, const Shape& shape, const Optional>& dtype, const Optional>& device, const Optional& generator, const bool& requires_grad) const { if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST(functional::GlobalRandInt( low, high, shape, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, requires_grad)); } DataType dtype_val = DataType::kInt64; if (dtype) { dtype_val = JUST(dtype)->data_type(); } auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "from", "to", "dtype", "seed"); attrs.SetAllAttrs(shape, low, high, dtype_val, static_cast(gen->current_seed())); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(attrs, distribution_state); ctx.device = device; auto result = JUST(OpInterpUtil::Dispatch(*op_, {}, ctx)); JUST(result->set_requires_grad(requires_grad)); return result; } private: std::shared_ptr op_; }; class RandInt2Functor { public: Maybe operator()(const int64_t high, const Shape& shape, const Optional>& dtype, const Optional>& device, const Optional& generator, const bool& requires_grad) const { return RandInt(/*low*/ 0, high, shape, dtype, device, generator, requires_grad); } }; class RandIntLikeFunctor { public: Maybe operator()(const std::shared_ptr& input, const int64_t low, const int64_t high, const Optional>& dtype, const Optional>& device, const Optional& generator, const bool& requires_grad) const { const Shape shape = *input->shape(); return RandInt(low, high, shape, dtype, device, generator, requires_grad); } }; class RandIntLike2Functor { public: Maybe operator()(const std::shared_ptr& input, const int64_t high, const Optional>& dtype, const Optional>& device, const Optional& generator, const bool& requires_grad) const { const Shape shape = *input->shape(); return RandInt(/*low*/ 0, high, shape, dtype, device, generator, requires_grad); } }; class GlobalRandIntFunctor { public: GlobalRandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform_int").Output("out").Build()); } Maybe operator()(const int64_t low, const int64_t high, const Shape& shape, const Symbol& placement, const std::vector>& sbp, const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { JUST(CheckDeviceIdsIsValid(placement)); DataType dtype_val = DataType::kInt64; if (dtype) { dtype_val = JUST(dtype)->data_type(); } const auto& nd_sbp = JUST(GetNdSbp(sbp)); auto attr_nd_sbp = *JUST(GetNdSbpStrList(nd_sbp)); auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "from", "to", "dtype", "seed", "nd_sbp"); attrs.SetAllAttrs(shape, low, high, dtype_val, static_cast(gen->current_seed()), attr_nd_sbp); const auto& distribution_state = std::make_shared(gen); auto result = JUST(OpInterpUtil::Dispatch( *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state))); JUST(result->set_requires_grad(requires_grad)); return result; } private: std::shared_ptr op_; }; class GlobalRandInt2Functor { public: Maybe operator()(const int64_t high, const Shape& shape, const Symbol& placement, const std::vector>& sbp, const Optional>& dtype, const 
Optional& generator, const bool& requires_grad) const { JUST(CheckDeviceIdsIsValid(placement)); return GlobalRandInt(/*low*/ 0, high, shape, placement, sbp, dtype, generator, requires_grad); } }; class GlobalRandIntLikeFunctor { public: Maybe operator()(const std::shared_ptr& input, const int64_t low, const int64_t high, const Symbol& placement, const std::vector>& sbp, const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { const Shape shape = *input->shape(); return GlobalRandInt(low, high, shape, placement, sbp, dtype, generator, requires_grad); } }; class GlobalRandIntLike2Functor { public: Maybe operator()(const std::shared_ptr& input, const int64_t high, const Symbol& placement, const std::vector>& sbp, const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { const Shape shape = *input->shape(); return GlobalRandInt(/*low*/ 0, high, shape, placement, sbp, dtype, generator, requires_grad); } }; class RandPermFunctor { public: RandPermFunctor() { randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build()); } Maybe operator()(const int32_t n, const Optional& generator, const Symbol& dtype, const Optional>& device, const bool& requires_grad) const { if (GlobalMode::is_enabled()) { auto global_mode_gurad = GlobalMode::Guard(false); return JUST(functional::GlobalRandPerm(n, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), generator, dtype, requires_grad)); } auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("n", "seed"); attrs.SetAllAttrs(n, static_cast(gen->current_seed())); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(attrs, distribution_state); ctx.device = device; auto result = JUST(OpInterpUtil::Dispatch(*randperm_op_, {}, ctx)); JUST(result->set_requires_grad(requires_grad)); return functional::Cast(result, dtype, /*pin_memory=*/false); } private: std::shared_ptr randperm_op_; }; class GlobalRandPermFunctor { public: GlobalRandPermFunctor() { randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build()); } Maybe operator()(const int32_t n, const Symbol& placement, const std::vector>& sbp_tuple, const Optional& generator, const Symbol& dtype, const bool& requires_grad) const { JUST(CheckDeviceIdsIsValid(placement)); const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto attr_nd_sbp = *JUST(GetNdSbpStrList(nd_sbp)); auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("n", "seed", "nd_sbp"); attrs.SetAllAttrs(n, static_cast(gen->current_seed()), attr_nd_sbp); const auto& distribution_state = std::make_shared(gen); auto result = JUST(OpInterpUtil::Dispatch( *randperm_op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state))); JUST(result->set_requires_grad(requires_grad)); return functional::Cast(result, dtype, /*pin_memory=*/false); } private: std::shared_ptr randperm_op_; }; class ExponentialFunctor { public: ExponentialFunctor() { op_ = CHECK_JUST(one::OpBuilder("exponential").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const float& lambd, const Optional& generator) const { DataType dtype_val = x->dtype()->data_type(); Optional> device; Optional> placement; Optional> nd_sbp; auto gen = 
generator.value_or(JUST(one::DefaultAutoGenerator())); if (x->is_global()) { JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc()))); placement = JUST(x->parallel_desc()); nd_sbp = JUST(x->nd_sbp()); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp)); } else { device = JUST(x->device()); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt)); } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("seed", "lambd", "dtype", "out_shape", "nd_sbp"); const Shape& out_shape = *(x->shape()); Optional> attr_nd_sbp{NullOpt}; if (nd_sbp) { attr_nd_sbp = *JUST(GetNdSbpStrList(JUST(nd_sbp))); } attrs.SetAllAttrs(static_cast(gen->current_seed()), lambd, dtype_val, out_shape, attr_nd_sbp); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(attrs, distribution_state); ctx.device = device; ctx.parallel_desc = placement; ctx.nd_sbp = nd_sbp; std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; JUST(OpInterpUtil::Dispatch(*op_, {}, outputs.get(), ctx)); return outputs->at(0); } private: std::shared_ptr op_; }; // NOTE(Liang Depeng): The implementation of MultinomialFunctor is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Distributions.cpp#L548 class MultinomialFunctor { public: MultinomialFunctor() { op_cpu_ = CHECK_JUST(one::OpBuilder("multinomial_with_replacement").Input("x").Output("out").Build()); op_gpu_ = CHECK_JUST(one::OpBuilder("multinomial_with_replacement") .Input("x") .Input("prefix_sum") .Output("out") .Build()); op_npu_ = CHECK_JUST(one::OpBuilder("multinomial_with_replacement").Input("x").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, const int& num_samples, const bool& replacement, const Optional& generator) const { CHECK_OR_RETURN(x->ndim() > 0 && x->ndim() <= 2) << "The input probability tensor must be 1 or 2 dim, " << "but got: " << x->ndim(); CHECK_OR_RETURN(x->dtype()->is_floating_point()) << "multinomial only supports floating-point dtypes for input, but got: " << x->dtype()->name(); CHECK_OR_RETURN(num_samples > 0) << "cannot sample num_samples <= 0 samples"; int64_t num_categories = x->dim(x->ndim() - 1); CHECK_OR_RETURN(replacement || num_samples <= num_categories) << "cannot sample num_samples > prob_dist.size(-1) samples without replacement"; /* The largest consecutive integer representable in float32 (2^24) */ constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG); // Since the index tensor is float, numCategories cannot exceed max float integer precision CHECK_OR_RETURN(num_categories <= FLOAT32_MAX_CONSECUTIVE_INT) << "number of categories cannot exceed 2^24"; DeviceType input_device = DeviceType::kCPU; if (x->is_global()) { JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc()))); input_device = JUST(x->parallel_desc())->device_type(); } else { input_device = JUST(x->device())->enum_type(); } // Fast-path for no replacement. // Reference: // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 if (!replacement && input_device != DeviceType::kNPU) { // The algorithm is from gumbel softmax. // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1) // Here we can apply exp to the formula which will not affect result of // argmax or topk. Then we have // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1). 
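// (The simplification below uses inverse-CDF sampling: for eps ~ U(0, 1),
//  -log(eps) is distributed as Exp(1), so dividing by an Exp(1) sample is
//  equivalent to the Gumbel-max form above.)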
// We can also simplify the formula above by // s = argmax( p / q ) where q ~ Exp(1) std::shared_ptr q = JUST(functional::Empty(*(x->shape()), x->dtype(), JUST(x->device()), /*requires_grad=*/x->requires_grad(), /*pin_memory=*/false)); q = JUST(functional::Exponential(q, 1, generator)); // In theory the probability to generate 0 from exponential distribution is // 0. However, on CUDA side there is a protection to avoid 0s, but on CPU // side, there is a very low probability to generate 0 from // exponential. The probability is about 2^(-DBL_MANT_DIG). We just // ignore it here, but there may be some risk to get invalid output on CPU. q = JUST(functional::Div(x, q)); std::shared_ptr result; if (num_samples == 1) { result = JUST(functional::ArgMax(q, -1, true, JUST(DType::Get(DataType::kInt64)))); } else if (input_device == DeviceType::kNPU) { } else { std::shared_ptr temp = JUST(functional::TopK(q, num_samples, -1, /*largest=*/true, /*sorted=*/true)); result = (*temp)[1]; } return result; } auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("seed", "num_samples", "replacement"); attrs.SetAllAttrs(static_cast(gen->current_seed()), num_samples, replacement); const auto& distribution_state = std::make_shared(gen); OpExprInterpContext ctx(attrs, distribution_state); if (input_device == DeviceType::kCPU) { return OpInterpUtil::Dispatch(*op_cpu_, {x}, ctx); } else if (input_device == DeviceType::kNPU) { return OpInterpUtil::Dispatch(*op_npu_, {x}, ctx); } else { std::shared_ptr sum_last_dim = JUST(functional::ReduceSum(x, {-1}, true, NullOpt)); std::shared_ptr norm_dist = JUST(functional::Div(x, sum_last_dim)); std::shared_ptr prefix_sum = JUST(functional::Cumsum(norm_dist, -1, x->dtype())); return OpInterpUtil::Dispatch(*op_gpu_, {norm_dist, prefix_sum}, ctx); } } private: std::shared_ptr op_cpu_; std::shared_ptr op_gpu_; std::shared_ptr op_npu_; }; } // namespace impl using namespace impl; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Bernoulli"); m.add_functor("BernoulliInplace"); m.add_functor("BernoulliProb"); m.add_functor("BernoulliProbInplace"); m.add_functor("RandPerm"); m.add_functor("GlobalRandPerm"); m.add_functor("Rand"); m.add_functor("GlobalRand"); m.add_functor("RandN"); m.add_functor("GlobalRandN"); m.add_functor("Normal"); m.add_functor("Normal2"); m.add_functor("TensorTensorNormal"); m.add_functor("TensorScalarNormal"); m.add_functor("ScalarTensorNormal"); m.add_functor("Normal_"); m.add_functor("GlobalNormal"); m.add_functor("GlobalNormal2"); m.add_functor("RandnLike"); m.add_functor("GlobalRandnLike"); m.add_functor("RandInt"); m.add_functor("GlobalRandInt"); m.add_functor("RandIntLike"); m.add_functor("GlobalRandIntLike"); m.add_functor("Exponential"); m.add_functor("Multinomial"); m.add_functor("InplaceUniform"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/rnn_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/error.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/sequence_function.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/core/framework/nd_sbp.h" namespace oneflow { namespace one { namespace functional { namespace impl { // NOTE(Liang Depeng): The implementation of rnn related functors are modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp struct tanh_f { Maybe operator()(const std::shared_ptr& t) const { return JUST(functional::Tanh(t)); } }; struct relu_f { Maybe operator()(const std::shared_ptr& t) const { return JUST(functional::Relu(t, false)); } }; Maybe check_rnn_cell_forward_input(const std::shared_ptr& input, int64_t input_size) { CHECK_OR_RETURN(input->shape()->At(1) == input_size) << "input has inconsistent input_size: got " << input->shape()->At(1) << " expected " << input_size; return Maybe::Ok(); } Maybe check_rnn_cell_forward_hidden(const std::shared_ptr& input, const std::shared_ptr& hx, int64_t hidden_size, int64_t hidden_label) { CHECK_OR_RETURN(input->shape()->At(0) == hx->shape()->At(0)) << "Input batch size " << input->shape()->At(0) << " doesn't match hidden" << hidden_label << " batch size " << hx->shape()->At(0); CHECK_OR_RETURN(hx->shape()->At(1) == hidden_size) << "hidden" << hidden_label << " has inconsistent hidden_size: got " << hx->shape()->At(1) << ", expected " << hidden_size; return Maybe::Ok(); } Maybe check_attributes(const std::shared_ptr& input, const TensorTuple& params, const TensorTuple& hiddens, bool check_dtype = false) { DeviceType input_device{}; if (input->is_global()) { input_device = JUST(input->parallel_desc())->device_type(); } else { input_device = JUST(input->device())->enum_type(); } DataType input_dtype = input->dtype()->data_type(); auto check_tensors = [&](const std::string& name, const std::shared_ptr& t) -> Maybe { DeviceType t_device{}; if (t->is_global()) { t_device = JUST(t->parallel_desc())->device_type(); } else { t_device = JUST(t->device())->enum_type(); } CHECK_OR_RETURN(input_device == t_device) << "Input and " << name << " tensors are not at the same device, found input tensor at " << input_device << " and " << 
name << " tensor at " << t_device; if (check_dtype) { DataType t_dtype = t->dtype()->data_type(); CHECK_OR_RETURN(input_dtype == t_dtype) << "Input and " << name << " tensors are not the same dtype, found input tensor with " << input_dtype << " and " << name << " tensor with " << t_dtype; } return Maybe::Ok(); }; for (const auto& h : hiddens) JUST(check_tensors("hidden", h)); for (const auto& p : params) JUST(check_tensors("parameter", p)); return Maybe::Ok(); } Maybe linear(const std::shared_ptr& input, const std::shared_ptr& weight, const std::shared_ptr& bias) { if (bias != nullptr) { TensorTuple weights; weights.emplace_back(weight); TensorTuple biases; biases.emplace_back(bias); return functional::FusedMLP(input, weights, biases, true); } else { return functional::MatMul(input, weight, false, true, 1.0); } } struct CellParams { CellParams(const std::shared_ptr _w_ih, // NOLINT const std::shared_ptr _w_hh, // NOLINT const std::shared_ptr _b_ih, // NOLINT const std::shared_ptr _b_hh, // NOLINT const std::shared_ptr _w_hr) // NOLINT : w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh), w_hr(_w_hr){}; const std::shared_ptr w_ih; const std::shared_ptr w_hh; const std::shared_ptr b_ih_; const std::shared_ptr b_hh_; const std::shared_ptr w_hr; // only defined for LSTMs with projections Maybe matmul_ih(const std::shared_ptr& input) const { return functional::MatMul(input, w_ih, false, true, 1.0); } Maybe matmul_hh(const std::shared_ptr& h) const { return functional::MatMul(h, w_hh, false, true, 1.0); } Maybe matmul_hr(const std::shared_ptr& h) const { if (w_hr != nullptr) { return functional::MatMul(h, w_hr, false, true, 1.0); } return h; } Maybe linear_ih(const std::shared_ptr& input) const { return linear(input, w_ih, b_ih_); } Maybe linear_hh(const std::shared_ptr& h) const { return linear(h, w_hh, b_hh_); } const std::shared_ptr& b_ih() const { return b_ih_; } const std::shared_ptr& b_hh() const { return b_hh_; } }; // Parses a flat list of parameter tensors into a list of CellParams static Maybe> gather_params(const TensorTuple& params, bool has_biases, bool has_projections = false) { std::vector result; if (has_biases) { if (has_projections) { CHECK_OR_RETURN(params.size() % 5 == 0) << "got an incorrect number of RNN parameters"; for (size_t i = 0; i < params.size(); i += 5) { result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], params[i + 4]); } } else { CHECK_OR_RETURN(params.size() % 4 == 0) << "got an incorrect number of RNN parameters"; for (size_t i = 0; i < params.size(); i += 4) { result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], nullptr); } } } else { if (has_projections) { CHECK_OR_RETURN(params.size() % 3 == 0) << "got an incorrect number of RNN parameters"; for (size_t i = 0; i < params.size(); i += 3) { result.emplace_back(params[i], params[i + 1], nullptr, nullptr, params[i + 2]); } } else { CHECK_OR_RETURN(params.size() % 2 == 0) << "got an incorrect number of RNN parameters"; for (size_t i = 0; i < params.size(); i += 2) { result.emplace_back(params[i], params[i + 1], nullptr, nullptr, nullptr); } } } return result; } template struct SimpleCell { Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& hidden, const cell_params& params, bool pre_compute_input = false) const { std::shared_ptr hh = JUST(params.linear_hh(hidden)); std::shared_ptr output; if (pre_compute_input) { output = JUST(functional::Add(hh, input, 1.0, true)); } else { std::shared_ptr ih = JUST(params.linear_ih(input)); output = 
JUST(functional::Add(hh, ih, 1.0, true)); } return nonlinearity{}(output); } }; template struct GRUCell { Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& hidden, const cell_params& params, bool pre_compute_input = false) const { DeviceType input_device{}; if (input->is_global()) { input_device = JUST(input->parallel_desc())->device_type(); } else { input_device = JUST(input->device())->enum_type(); } if (input_device == DeviceType::kCUDA) { CHECK_OR_RETURN(!pre_compute_input); std::shared_ptr igates = JUST(params.matmul_ih(input)); std::shared_ptr hgates = JUST(params.matmul_hh(hidden)); std::shared_ptr result = JUST(functional::FusedGruCell(igates, hgates, hidden, params.b_ih(), params.b_hh())); return (*result)[0]; } std::shared_ptr chunked_igates; if (pre_compute_input) { chunked_igates = JUST(functional::Chunk(input, 3, 1)); } else { std::shared_ptr gates_ih = JUST(params.linear_ih(input)); chunked_igates = JUST(functional::Chunk(gates_ih, 3, 1)); } std::shared_ptr tmp = JUST(params.linear_hh(hidden)); std::shared_ptr chunked_hgates = JUST(functional::Chunk(tmp, 3, 1)); std::shared_ptr reset_gate = JUST(functional::Add((*chunked_hgates)[0], (*chunked_igates)[0], 1.0, false)); reset_gate = JUST(functional::Sigmoid(reset_gate)); std::shared_ptr input_gate = JUST(functional::Add((*chunked_hgates)[1], (*chunked_igates)[1], 1.0, false)); input_gate = JUST(functional::Sigmoid(input_gate)); std::shared_ptr new_gate = JUST(functional::Mul((*chunked_hgates)[2], reset_gate)); new_gate = JUST(functional::Add((*chunked_igates)[2], new_gate, 1.0, false)); new_gate = JUST(functional::Tanh(new_gate)); std::shared_ptr output = JUST(functional::Sub(hidden, new_gate, 1.0, false)); output = JUST(functional::Mul(output, input_gate)); output = JUST(functional::Add(output, new_gate, 1.0, false)); return output; } }; template struct LSTMCell { Maybe operator()(const std::shared_ptr& input, const one::TensorTuple& hidden, const cell_params& params, bool pre_compute_input = false) const { const std::shared_ptr& hx = hidden[0]; const std::shared_ptr& cx = hidden[1]; DeviceType input_device{}; if (input->is_global()) { input_device = JUST(input->parallel_desc())->device_type(); } else { input_device = JUST(input->device())->enum_type(); } if (input_device == DeviceType::kCUDA) { CHECK_OR_RETURN(!pre_compute_input); std::shared_ptr igates = JUST(params.matmul_ih(input)); std::shared_ptr hgates = JUST(params.matmul_hh(hx)); std::shared_ptr result = JUST(functional::FusedLstmCell(igates, hgates, cx, params.b_ih(), params.b_hh())); auto outputs = std::make_shared(2); (*outputs)[0] = JUST(params.matmul_hr((*result)[0])); (*outputs)[1] = (*result)[1]; return outputs; } std::shared_ptr gates = JUST(params.linear_hh(hx)); if (pre_compute_input) { gates = JUST(functional::Add(gates, input, 1.0, true)); } else { std::shared_ptr gates_ih = JUST(params.linear_ih(input)); gates = JUST(functional::Add(gates, gates_ih, 1.0, true)); } std::shared_ptr chunked_gates = JUST(functional::Chunk(gates, 4, 1)); std::shared_ptr ingate = JUST(functional::Sigmoid((*chunked_gates)[0])); std::shared_ptr forgetgate = JUST(functional::Sigmoid((*chunked_gates)[1])); std::shared_ptr cellgate = JUST(functional::Tanh((*chunked_gates)[2])); std::shared_ptr outgate = JUST(functional::Sigmoid((*chunked_gates)[3])); std::shared_ptr cy = JUST(functional::Mul(forgetgate, cx)); cellgate = JUST(functional::Mul(ingate, cellgate)); cy = JUST(functional::Add(cy, cellgate, 1.0, true)); std::shared_ptr tanh_cy = 
JUST(functional::Tanh(cy)); std::shared_ptr hy = JUST(functional::Mul(outgate, tanh_cy)); auto outputs = std::make_shared(2); (*outputs)[0] = JUST(params.matmul_hr(hy)); (*outputs)[1] = cy; return outputs; } }; class RnnTanhCellFunctor { public: RnnTanhCellFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& hx, const std::shared_ptr& w_ih, const std::shared_ptr& w_hh, const Optional& b_ih, const Optional& b_hh) const { JUST(check_rnn_cell_forward_input(input, w_ih->shape()->At(1))); JUST(check_rnn_cell_forward_hidden(input, hx, w_hh->shape()->At(1), 0)); std::shared_ptr bias_ih = nullptr; std::shared_ptr bias_hh = nullptr; if (b_ih.has_value() && b_hh.has_value()) { bias_ih = JUST(b_ih); bias_hh = JUST(b_hh); } return SimpleCell{}(input, hx, CellParams{w_ih, w_hh, bias_ih, bias_hh, nullptr}); } }; class RnnReluCellFunctor { public: RnnReluCellFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& hx, const std::shared_ptr& w_ih, const std::shared_ptr& w_hh, const Optional& b_ih, const Optional& b_hh) const { JUST(check_rnn_cell_forward_input(input, w_ih->shape()->At(1))); JUST(check_rnn_cell_forward_hidden(input, hx, w_hh->shape()->At(1), 0)); std::shared_ptr bias_ih = nullptr; std::shared_ptr bias_hh = nullptr; if (b_ih.has_value() && b_hh.has_value()) { bias_ih = JUST(b_ih); bias_hh = JUST(b_hh); } return SimpleCell{}(input, hx, CellParams{w_ih, w_hh, bias_ih, bias_hh, nullptr}); } }; class GruCellFunctor { public: GruCellFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& hx, const std::shared_ptr& w_ih, const std::shared_ptr& w_hh, const Optional& b_ih, const Optional& b_hh) const { JUST(check_rnn_cell_forward_input(input, w_ih->shape()->At(1))); JUST(check_rnn_cell_forward_hidden(input, hx, w_hh->shape()->At(1), 0)); std::shared_ptr bias_ih = nullptr; std::shared_ptr bias_hh = nullptr; if (b_ih.has_value() && b_hh.has_value()) { bias_ih = JUST(b_ih); bias_hh = JUST(b_hh); } return GRUCell{}(input, hx, CellParams{w_ih, w_hh, bias_ih, bias_hh, nullptr}); } }; class LstmCellFunctor { public: LstmCellFunctor() {} Maybe operator()(const std::shared_ptr& input, const one::TensorTuple& hx, const std::shared_ptr& w_ih, const std::shared_ptr& w_hh, const Optional& b_ih, const Optional& b_hh) const { CHECK_OR_RETURN(hx.size() == 2) << "lstm_cell expects two hidden states"; JUST(check_rnn_cell_forward_input(input, w_ih->shape()->At(1))); auto hidden_size = w_hh->shape()->At(1); JUST(check_rnn_cell_forward_hidden(input, hx[0], hidden_size, 0)); JUST(check_rnn_cell_forward_hidden(input, hx[1], hidden_size, 0)); std::shared_ptr bias_ih = nullptr; std::shared_ptr bias_hh = nullptr; if (b_ih.has_value() && b_hh.has_value()) { bias_ih = JUST(b_ih); bias_hh = JUST(b_hh); } return LSTMCell{}(input, hx, CellParams{w_ih, w_hh, bias_ih, bias_hh, nullptr}); } }; class FusedGruCellFunctor { public: FusedGruCellFunctor() { op_with_bias_ = CHECK_JUST(one::OpBuilder("fused_gru_cell") .Input("input_gates") .Input("hidden_gates") .Input("hx") .Input("input_bias") .Input("hidden_bias") .Output("hy") .Output("workspace") .Build()); op_without_bias_ = CHECK_JUST(one::OpBuilder("fused_gru_cell") .Input("input_gates") .Input("hidden_gates") .Input("hx") .Output("hy") .Output("workspace") .Build()); } Maybe operator()(const std::shared_ptr& igates, const std::shared_ptr& hgates, const std::shared_ptr& hx, const Optional& b_ih, const Optional& b_hh) const { std::shared_ptr kernel_result; if (b_ih.has_value() && b_hh.has_value()) { 
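// Both biases are expected together: when they are present, dispatch to the op expression
// that declares the "input_bias"/"hidden_bias" inputs; otherwise fall back to the bias-free
// variant built above.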
kernel_result = JUST(OpInterpUtil::Dispatch( *op_with_bias_, {igates, hgates, hx, JUST(b_ih), JUST(b_hh)})); } else { kernel_result = JUST(OpInterpUtil::Dispatch(*op_without_bias_, {igates, hgates, hx})); } return kernel_result; } private: std::shared_ptr op_with_bias_; std::shared_ptr op_without_bias_; }; class FusedGruCellGradFunctor { public: FusedGruCellGradFunctor() { op_with_bias_ = CHECK_JUST(one::OpBuilder("fused_gru_cell_grad") .Input("grad_hy") .Input("workspace") .Output("grad_input_gates") .Output("grad_hidden_gates") .Output("grad_hx") .Output("grad_input_bias") .Output("grad_hidden_bias") .Build()); op_with_bias_without_hx_ = CHECK_JUST(one::OpBuilder("fused_gru_cell_grad") .Input("grad_hy") .Input("workspace") .Output("grad_input_gates") .Output("grad_hidden_gates") .Output("grad_input_bias") .Output("grad_hidden_bias") .Build()); op_without_bias_ = CHECK_JUST(one::OpBuilder("fused_gru_cell_grad") .Input("grad_hy") .Input("workspace") .Output("grad_input_gates") .Output("grad_hidden_gates") .Output("grad_hx") .Build()); op_without_bias_without_hx_ = CHECK_JUST(one::OpBuilder("fused_gru_cell_grad") .Input("grad_hy") .Input("workspace") .Output("grad_input_gates") .Output("grad_hidden_gates") .Build()); } Maybe operator()(const std::shared_ptr& grad_hy, const std::shared_ptr& workspace, bool has_bias, bool hx_needs_grad) const { std::shared_ptr kernel_result; if (has_bias) { if (hx_needs_grad) { kernel_result = JUST(OpInterpUtil::Dispatch(*op_with_bias_, {grad_hy, workspace})); } else { kernel_result = JUST( OpInterpUtil::Dispatch(*op_with_bias_without_hx_, {grad_hy, workspace})); } } else { if (hx_needs_grad) { kernel_result = JUST(OpInterpUtil::Dispatch(*op_without_bias_, {grad_hy, workspace})); } else { kernel_result = JUST(OpInterpUtil::Dispatch(*op_without_bias_without_hx_, {grad_hy, workspace})); } } return kernel_result; } private: std::shared_ptr op_with_bias_; std::shared_ptr op_with_bias_without_hx_; std::shared_ptr op_without_bias_; std::shared_ptr op_without_bias_without_hx_; }; class FusedLstmCellFunctor { public: FusedLstmCellFunctor() { op_with_bias_ = CHECK_JUST(one::OpBuilder("fused_lstm_cell") .Input("input_gates") .Input("hidden_gates") .Input("cx") .Input("input_bias") .Input("hidden_bias") .Output("hy") .Output("cy") .Output("workspace") .Build()); op_without_bias_ = CHECK_JUST(one::OpBuilder("fused_lstm_cell") .Input("input_gates") .Input("hidden_gates") .Input("cx") .Output("hy") .Output("cy") .Output("workspace") .Build()); } Maybe operator()(const std::shared_ptr& igates, const std::shared_ptr& hgates, const std::shared_ptr& cx, const Optional& b_ih, const Optional& b_hh) const { std::shared_ptr kernel_result; if (b_ih.has_value() && b_hh.has_value()) { kernel_result = JUST(OpInterpUtil::Dispatch( *op_with_bias_, {igates, hgates, cx, JUST(b_ih), JUST(b_hh)})); } else { kernel_result = JUST(OpInterpUtil::Dispatch(*op_without_bias_, {igates, hgates, cx})); } return kernel_result; } private: std::shared_ptr op_with_bias_; std::shared_ptr op_without_bias_; }; class FusedLstmCellGradFunctor { public: FusedLstmCellGradFunctor() { op_with_bias_ = CHECK_JUST(one::OpBuilder("fused_lstm_cell_grad") .Input("grad_hy") .Input("grad_cy") .Input("cx") .Input("cy") .Input("workspace") .Output("grad_gates") .Output("grad_cx") .Output("grad_bias") .Build()); op_without_bias_ = CHECK_JUST(one::OpBuilder("fused_lstm_cell_grad") .Input("grad_hy") .Input("grad_cy") .Input("cx") .Input("cy") .Input("workspace") .Output("grad_gates") .Output("grad_cx") .Build()); 
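// "fused_lstm_cell_grad" is built in four variants so that optional outputs are declared
// only when they are needed: "grad_bias" only when the forward cell used biases, and
// "grad_cx" only when the cell state requires a gradient. The two variants below omit the
// "grad_cx" output; operator() selects the matching variant from (has_bias, need_cx_grad).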
op_with_bias_no_grad_cx_ = CHECK_JUST(one::OpBuilder("fused_lstm_cell_grad") .Input("grad_hy") .Input("grad_cy") .Input("cx") .Input("cy") .Input("workspace") .Output("grad_gates") .Output("grad_bias") .Build()); op_without_bias_no_grad_cx_ = CHECK_JUST(one::OpBuilder("fused_lstm_cell_grad") .Input("grad_hy") .Input("grad_cy") .Input("cx") .Input("cy") .Input("workspace") .Output("grad_gates") .Build()); } Maybe operator()(const std::shared_ptr& grad_hy, const std::shared_ptr& grad_cy, const std::shared_ptr& cx, const std::shared_ptr& cy, const std::shared_ptr& workspace, bool need_cx_grad, bool has_bias) const { std::shared_ptr kernel_result; if (has_bias) { if (need_cx_grad) { kernel_result = JUST(OpInterpUtil::Dispatch( *op_with_bias_, {grad_hy, grad_cy, cx, cy, workspace})); } else { kernel_result = JUST(OpInterpUtil::Dispatch( *op_with_bias_no_grad_cx_, {grad_hy, grad_cy, cx, cy, workspace})); } } else { if (need_cx_grad) { kernel_result = JUST(OpInterpUtil::Dispatch( *op_without_bias_, {grad_hy, grad_cy, cx, cy, workspace})); } else { kernel_result = JUST(OpInterpUtil::Dispatch( *op_without_bias_no_grad_cx_, {grad_hy, grad_cy, cx, cy, workspace})); } } return kernel_result; } private: std::shared_ptr op_with_bias_; std::shared_ptr op_with_bias_no_grad_cx_; std::shared_ptr op_without_bias_; std::shared_ptr op_without_bias_no_grad_cx_; }; template Maybe _rnn_impl(const std::shared_ptr& input, const std::shared_ptr& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional, const bool& batch_first) { TensorTuple hiddens; hiddens.emplace_back(hx); JUST(check_attributes(input, params, hiddens)); std::shared_ptr rnn_input = input; if (batch_first) { std::vector dims = {1, 0, 2}; rnn_input = JUST(functional::Permute(input, dims)); } auto rnn_params = JUST(gather_params(params, has_biases)); std::shared_ptr rnn_hiddens = JUST(functional::Unbind(hx, 0)); std::shared_ptr rnn_inputs = JUST(functional::Unbind(rnn_input, 0)); auto generator = JUST(one::DefaultAutoGenerator()); TensorTuple final_hiddens; if (bidirectional) { std::shared_ptr fw_outputs = std::make_shared(rnn_inputs->size()); std::shared_ptr bw_outputs = std::make_shared(rnn_inputs->size()); for (int32_t l = 0; l < num_layers; ++l) { // forward direction std::shared_ptr fw_hidden = (*rnn_hiddens)[l * 2]; auto& fw_cell_param = (*rnn_params)[l * 2]; for (int32_t i = 0; i < rnn_inputs->size(); ++i) { fw_hidden = JUST(cell_type{}((*rnn_inputs)[i], fw_hidden, fw_cell_param)); (*fw_outputs)[i] = fw_hidden; } final_hiddens.emplace_back(fw_hidden); // reverse direction std::shared_ptr bw_hidden = (*rnn_hiddens)[l * 2 + 1]; auto& bw_cell_param = (*rnn_params)[l * 2 + 1]; for (int32_t i = rnn_inputs->size() - 1; i >= 0; i--) { bw_hidden = JUST(cell_type{}((*rnn_inputs)[i], bw_hidden, bw_cell_param)); (*bw_outputs)[i] = bw_hidden; } final_hiddens.emplace_back(bw_hidden); // concat fw_outputs and bw_outputs for (int32_t i = 0; i < rnn_inputs->size(); ++i) { (*rnn_inputs)[i] = JUST(functional::Concat({(*fw_outputs)[i], (*bw_outputs)[i]}, bw_hidden->shape()->NumAxes() - 1)); } if (dropout != 0 && train && l < num_layers - 1) { std::shared_ptr stack_res = JUST(functional::Stack(*rnn_inputs, 0)); std::shared_ptr dropout_res = JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr)); rnn_inputs = JUST(functional::Unbind(dropout_res, 0)); } } } else { for (int32_t l = 0; l < num_layers; ++l) { std::shared_ptr hidden = 
(*rnn_hiddens)[l]; auto& cell_param = (*rnn_params)[l]; for (int32_t i = 0; i < rnn_inputs->size(); ++i) { hidden = JUST(cell_type{}((*rnn_inputs)[i], hidden, cell_param)); (*rnn_inputs)[i] = hidden; } final_hiddens.emplace_back(hidden); if (dropout != 0 && train && l < num_layers - 1) { std::shared_ptr stack_res = JUST(functional::Stack(*rnn_inputs, 0)); std::shared_ptr dropout_res = JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr)); rnn_inputs = JUST(functional::Unbind(dropout_res, 0)); } } } TensorTuple output; std::shared_ptr output_0 = JUST(functional::Stack(*rnn_inputs, 0)); if (batch_first) { std::vector dims = {1, 0, 2}; output.emplace_back(JUST(functional::Permute(output_0, dims))); } else { output.emplace_back(output_0); } output.emplace_back(JUST(functional::Stack(final_hiddens, 0))); return output; } class RnnTanhInputFunctor { public: RnnTanhInputFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional, const bool& batch_first) const { return _rnn_impl>(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); } }; class RnnReluInputFunctor { public: RnnReluInputFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional, const bool& batch_first) const { return _rnn_impl>(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); } }; template Maybe _rnn_pack_sequence_impl(const std::shared_ptr& input, const std::shared_ptr& batch_sizes, const std::shared_ptr& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional) { auto rnn_params = JUST(gather_params(params, has_biases)); std::shared_ptr rnn_hiddens = JUST(functional::Unbind(hx, 0)); auto generator = JUST(one::DefaultAutoGenerator()); TensorTuple final_hiddens; std::vector batch_sizes_vec; batch_sizes_vec.resize(batch_sizes->nelement()); const auto& callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, batch_sizes_vec.data(), eager_blob_object->dptr(), batch_sizes_vec.size() * sizeof(int64_t), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; JUST(SyncAccessTensorWithTimeOut(batch_sizes, callback, "const")); int64_t num_steps = batch_sizes->shape()->At(0); std::shared_ptr rnn_inputs = std::make_shared(num_steps); int64_t input_offset = 0; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; (*rnn_inputs)[i] = JUST(functional::Narrow(input, 0, input_offset, batch_size)); input_offset += batch_size; } if (bidirectional) { std::shared_ptr fw_outputs = std::make_shared(rnn_inputs->size()); std::shared_ptr bw_outputs = std::make_shared(rnn_inputs->size()); for (int32_t l = 0; l < num_layers; ++l) { // forward direction int64_t last_batch_size = batch_sizes_vec[0]; std::shared_ptr fw_hidden = (*rnn_hiddens)[l * 2]; auto& fw_cell_param = (*rnn_params)[l * 2]; TensorTuple fw_final_hiddens_for_single_layer; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; const int64_t dec = last_batch_size - batch_size; if (dec > 0) { 
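// dec > 0 means `dec` sequences ended before this step: their rows are split off the
// running forward hidden state and saved as their final hidden states, and the step
// continues with the remaining `batch_size` rows.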
fw_final_hiddens_for_single_layer.emplace_back( JUST(functional::Narrow(fw_hidden, 0, last_batch_size - dec, dec))); fw_hidden = JUST(functional::Narrow(fw_hidden, 0, 0, last_batch_size - dec)); } last_batch_size = batch_size; fw_hidden = JUST(cell_type{}((*rnn_inputs)[i], fw_hidden, fw_cell_param)); (*fw_outputs)[i] = fw_hidden; } fw_final_hiddens_for_single_layer.emplace_back(fw_hidden); std::reverse(fw_final_hiddens_for_single_layer.begin(), fw_final_hiddens_for_single_layer.end()); final_hiddens.emplace_back(JUST(functional::Concat(fw_final_hiddens_for_single_layer, 0))); // reverse direction last_batch_size = batch_sizes_vec[num_steps - 1]; std::shared_ptr bw_hidden = JUST(functional::Narrow((*rnn_hiddens)[l * 2 + 1], 0, 0, last_batch_size)); auto& bw_cell_param = (*rnn_params)[l * 2 + 1]; // Here the situation is similar to that above, except we start out with // the smallest batch size (and a small set of hidden states we actually use), // and progressively expand the hidden states, as we move backwards over the // 1D list of inputs. for (int64_t i = num_steps - 1; i >= 0; --i) { const int64_t batch_size = batch_sizes_vec[i]; const int64_t inc = batch_size - last_batch_size; if (inc > 0) { std::shared_ptr hidden_slice = JUST(functional::Narrow( (*rnn_hiddens)[l * 2 + 1], 0, last_batch_size, batch_size - last_batch_size)); std::shared_ptr tmp = std::make_shared(2); (*tmp)[0] = bw_hidden; (*tmp)[1] = hidden_slice; bw_hidden = JUST(functional::Concat(*tmp, 0)); } last_batch_size = batch_size; bw_hidden = JUST(cell_type{}((*rnn_inputs)[i], bw_hidden, bw_cell_param)); (*bw_outputs)[i] = bw_hidden; } final_hiddens.emplace_back(bw_hidden); // concat fw_outputs and bw_outputs for (int32_t i = 0; i < num_steps; ++i) { (*rnn_inputs)[i] = JUST(functional::Concat({(*fw_outputs)[i], (*bw_outputs)[i]}, bw_hidden->shape()->NumAxes() - 1)); } if (dropout != 0 && train && l < num_layers - 1) { std::shared_ptr stack_res = JUST(functional::Concat(*rnn_inputs, 0)); std::shared_ptr dropout_res = JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr)); int64_t input_offset = 0; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; (*rnn_inputs)[i] = JUST(functional::Narrow(dropout_res, 0, input_offset, batch_size)); input_offset += batch_size; } } } } else { // Batch sizes is a sequence of decreasing lengths, which are offsets // into a 1D list of inputs. At every step we slice out batch_size elements, // and possibly account for the decrease in the batch size since the last step, // which requires us to slice the hidden state (since some sequences // are completed now). The sliced parts are also saved, because we will need // to return a tensor of final hidden state. 
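// For example, with batch_sizes = {4, 3, 1} the packed input holds 4 + 3 + 1 = 8 rows:
// step 0 consumes rows [0, 4), step 1 rows [4, 7) and step 2 rows [7, 8). At step 1 the
// batch shrinks by dec = 4 - 3 = 1, so the last row of the running hidden state is saved
// as the final hidden state of the sequence that just ended, and the loop continues with
// the remaining 3 rows.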
for (int32_t l = 0; l < num_layers; ++l) { int64_t last_batch_size = batch_sizes_vec[0]; std::shared_ptr hidden = (*rnn_hiddens)[l]; auto& cell_param = (*rnn_params)[l]; TensorTuple final_hiddens_for_single_layer; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; const int64_t dec = last_batch_size - batch_size; if (dec > 0) { final_hiddens_for_single_layer.emplace_back( JUST(functional::Narrow(hidden, 0, last_batch_size - dec, dec))); hidden = JUST(functional::Narrow(hidden, 0, 0, last_batch_size - dec)); } last_batch_size = batch_size; hidden = JUST(cell_type{}((*rnn_inputs)[i], hidden, cell_param)); (*rnn_inputs)[i] = hidden; } final_hiddens_for_single_layer.emplace_back(hidden); std::reverse(final_hiddens_for_single_layer.begin(), final_hiddens_for_single_layer.end()); final_hiddens.emplace_back(JUST(functional::Concat(final_hiddens_for_single_layer, 0))); if (dropout != 0 && train && l < num_layers - 1) { std::shared_ptr stack_res = JUST(functional::Concat(*rnn_inputs, 0)); std::shared_ptr dropout_res = JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr)); int64_t input_offset = 0; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; (*rnn_inputs)[i] = JUST(functional::Narrow(dropout_res, 0, input_offset, batch_size)); input_offset += batch_size; } } } } TensorTuple output; output.emplace_back(JUST(functional::Concat(*rnn_inputs, 0))); output.emplace_back(JUST(functional::Stack(final_hiddens, 0))); return output; } class RnnTanhDataFunctor { public: RnnTanhDataFunctor() {} Maybe operator()(const std::shared_ptr& data, const std::shared_ptr& batch_sizes, const std::shared_ptr& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional) const { return _rnn_pack_sequence_impl>( data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); } }; class RnnReluDataFunctor { public: RnnReluDataFunctor() {} Maybe operator()(const std::shared_ptr& data, const std::shared_ptr& batch_sizes, const std::shared_ptr& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional) const { return _rnn_pack_sequence_impl>( data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); } }; Maybe _lstm_impl(const std::shared_ptr& input, const one::TensorTuple& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional, const bool& batch_first) { CHECK_OR_RETURN(hx.size() == 2) << "lstm expects two hidden states"; // if cells are of different size, that means projections are used bool has_projections = (hx[0]->shape()->At(2) != hx[1]->shape()->At(2)); JUST(check_attributes(input, params, hx)); std::shared_ptr rnn_input = input; if (batch_first) { std::vector dims = {1, 0, 2}; rnn_input = JUST(functional::Permute(input, dims)); } auto rnn_params = JUST(gather_params(params, has_biases, has_projections)); std::shared_ptr layer_hxs = JUST(functional::Unbind(hx[0], 0)); std::shared_ptr layer_cxs = JUST(functional::Unbind(hx[1], 0)); std::shared_ptr rnn_inputs = JUST(functional::Unbind(rnn_input, 0)); auto generator = JUST(one::DefaultAutoGenerator()); TensorTuple final_hy; TensorTuple final_cy; if (bidirectional) { std::shared_ptr fw_outputs = std::make_shared(rnn_inputs->size()); 
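// In the bidirectional case each layer l owns two hidden/parameter slots: index l * 2
// drives the forward pass and l * 2 + 1 the reverse pass. After both passes the per-step
// outputs of the two directions are concatenated along the last (feature) axis and become
// the input of the next layer.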
std::shared_ptr lstm_cell_out = std::make_shared(2); std::shared_ptr bw_outputs = std::make_shared(rnn_inputs->size()); for (int32_t l = 0; l < num_layers; ++l) { // forward direction (*lstm_cell_out)[0] = (*layer_hxs)[l * 2]; (*lstm_cell_out)[1] = (*layer_cxs)[l * 2]; auto& fw_cell_param = (*rnn_params)[l * 2]; for (int32_t i = 0; i < rnn_inputs->size(); ++i) { lstm_cell_out = JUST(LSTMCell{}((*rnn_inputs)[i], *lstm_cell_out, fw_cell_param)); (*fw_outputs)[i] = (*lstm_cell_out)[0]; } final_hy.emplace_back((*lstm_cell_out)[0]); final_cy.emplace_back((*lstm_cell_out)[1]); // reverse direction (*lstm_cell_out)[0] = (*layer_hxs)[l * 2 + 1]; (*lstm_cell_out)[1] = (*layer_cxs)[l * 2 + 1]; auto& bw_cell_param = (*rnn_params)[l * 2 + 1]; for (int32_t i = rnn_inputs->size() - 1; i >= 0; i--) { lstm_cell_out = JUST(LSTMCell{}((*rnn_inputs)[i], *lstm_cell_out, bw_cell_param)); (*bw_outputs)[i] = (*lstm_cell_out)[0]; } final_hy.emplace_back((*lstm_cell_out)[0]); final_cy.emplace_back((*lstm_cell_out)[1]); // concat fw_outputs and bw_outputs for (int32_t i = 0; i < rnn_inputs->size(); ++i) { (*rnn_inputs)[i] = JUST(functional::Concat({(*fw_outputs)[i], (*bw_outputs)[i]}, (*bw_outputs)[0]->shape()->NumAxes() - 1)); } if (dropout != 0 && train && l < num_layers - 1) { std::shared_ptr stack_res = JUST(functional::Stack(*rnn_inputs, 0)); std::shared_ptr dropout_res = JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr)); rnn_inputs = JUST(functional::Unbind(dropout_res, 0)); } } } else { std::shared_ptr lstm_cell_out = std::make_shared(2); for (int32_t l = 0; l < num_layers; ++l) { auto& cell_param = (*rnn_params)[l]; (*lstm_cell_out)[0] = (*layer_hxs)[l]; (*lstm_cell_out)[1] = (*layer_cxs)[l]; for (int32_t i = 0; i < rnn_inputs->size(); ++i) { lstm_cell_out = JUST(LSTMCell{}((*rnn_inputs)[i], *lstm_cell_out, cell_param)); (*rnn_inputs)[i] = (*lstm_cell_out)[0]; } final_hy.emplace_back((*lstm_cell_out)[0]); final_cy.emplace_back((*lstm_cell_out)[1]); if (dropout != 0 && train && l < num_layers - 1) { std::shared_ptr stack_res = JUST(functional::Stack(*rnn_inputs, 0)); std::shared_ptr dropout_res = JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr)); rnn_inputs = JUST(functional::Unbind(dropout_res, 0)); } } } TensorTuple output; std::shared_ptr output_0 = JUST(functional::Stack(*rnn_inputs, 0)); if (batch_first) { std::vector dims = {1, 0, 2}; output.emplace_back(JUST(functional::Permute(output_0, dims))); } else { output.emplace_back(output_0); } output.emplace_back(JUST(functional::Stack(final_hy, 0))); output.emplace_back(JUST(functional::Stack(final_cy, 0))); return output; } class LstmInputFunctor { public: LstmInputFunctor() {} Maybe operator()(const std::shared_ptr& input, const one::TensorTuple& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional, const bool& batch_first) const { return _lstm_impl(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); } }; Maybe _lstm_pack_sequence_impl(const std::shared_ptr& input, const std::shared_ptr& batch_sizes, const one::TensorTuple& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional) { CHECK_OR_RETURN(hx.size() == 2) << "lstm expects two hidden states"; // if cells are of different size, that means projections are used bool has_projections = (hx[0]->shape()->At(2) != 
hx[1]->shape()->At(2)); auto rnn_params = JUST(gather_params(params, has_biases, has_projections)); std::shared_ptr layer_hxs = JUST(functional::Unbind(hx[0], 0)); std::shared_ptr layer_cxs = JUST(functional::Unbind(hx[1], 0)); std::vector batch_sizes_vec; batch_sizes_vec.resize(batch_sizes->nelement()); const auto& callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, batch_sizes_vec.data(), eager_blob_object->dptr(), batch_sizes_vec.size() * sizeof(int64_t), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; JUST(SyncAccessTensorWithTimeOut(batch_sizes, callback, "const")); int64_t num_steps = batch_sizes->shape()->At(0); std::shared_ptr rnn_inputs = std::make_shared(num_steps); int64_t input_offset = 0; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; (*rnn_inputs)[i] = JUST(functional::Narrow(input, 0, input_offset, batch_size)); input_offset += batch_size; } auto generator = JUST(one::DefaultAutoGenerator()); TensorTuple final_hy; TensorTuple final_cy; if (bidirectional) { std::shared_ptr fw_outputs = std::make_shared(rnn_inputs->size()); std::shared_ptr lstm_cell_out = std::make_shared(2); std::shared_ptr bw_outputs = std::make_shared(rnn_inputs->size()); for (int32_t l = 0; l < num_layers; ++l) { int64_t last_batch_size = batch_sizes_vec[0]; // forward direction (*lstm_cell_out)[0] = (*layer_hxs)[l * 2]; (*lstm_cell_out)[1] = (*layer_cxs)[l * 2]; auto& fw_cell_param = (*rnn_params)[l * 2]; TensorTuple final_hy_for_single_layer; TensorTuple final_cy_for_single_layer; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; const int64_t dec = last_batch_size - batch_size; if (dec > 0) { final_hy_for_single_layer.emplace_back( JUST(functional::Narrow((*lstm_cell_out)[0], 0, last_batch_size - dec, dec))); (*lstm_cell_out)[0] = JUST(functional::Narrow((*lstm_cell_out)[0], 0, 0, last_batch_size - dec)); final_cy_for_single_layer.emplace_back( JUST(functional::Narrow((*lstm_cell_out)[1], 0, last_batch_size - dec, dec))); (*lstm_cell_out)[1] = JUST(functional::Narrow((*lstm_cell_out)[1], 0, 0, last_batch_size - dec)); } last_batch_size = batch_size; lstm_cell_out = JUST(LSTMCell{}((*rnn_inputs)[i], *lstm_cell_out, fw_cell_param)); (*fw_outputs)[i] = (*lstm_cell_out)[0]; } final_hy_for_single_layer.emplace_back((*lstm_cell_out)[0]); final_cy_for_single_layer.emplace_back((*lstm_cell_out)[1]); std::reverse(final_hy_for_single_layer.begin(), final_hy_for_single_layer.end()); std::reverse(final_cy_for_single_layer.begin(), final_cy_for_single_layer.end()); final_hy.emplace_back(JUST(functional::Concat(final_hy_for_single_layer, 0))); final_cy.emplace_back(JUST(functional::Concat(final_cy_for_single_layer, 0))); // reverse direction last_batch_size = batch_sizes_vec[num_steps - 1]; (*lstm_cell_out)[0] = JUST(functional::Narrow((*layer_hxs)[l * 2 + 1], 0, 0, last_batch_size)); (*lstm_cell_out)[1] = JUST(functional::Narrow((*layer_cxs)[l * 2 + 1], 0, 0, last_batch_size)); auto& bw_cell_param = (*rnn_params)[l * 2 + 1]; for (int64_t i = num_steps - 1; i >= 0; --i) { const int64_t batch_size = batch_sizes_vec[i]; const int64_t inc = batch_size - last_batch_size; if (inc > 0) { std::shared_ptr hxs_slice = JUST(functional::Narrow( (*layer_hxs)[l * 2 + 1], 0, last_batch_size, batch_size - last_batch_size)); std::shared_ptr tmp = std::make_shared(2); (*tmp)[0] = (*lstm_cell_out)[0]; (*tmp)[1] = hxs_slice; (*lstm_cell_out)[0] = JUST(functional::Concat(*tmp, 0)); 
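// Moving backwards over the packed steps the batch can only grow, so the running hidden
// state was just extended with the matching slice of the layer's initial hidden state;
// the cell state is extended the same way below.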
std::shared_ptr cxs_slice = JUST(functional::Narrow( (*layer_cxs)[l * 2 + 1], 0, last_batch_size, batch_size - last_batch_size)); (*tmp)[0] = (*lstm_cell_out)[1]; (*tmp)[1] = cxs_slice; (*lstm_cell_out)[1] = JUST(functional::Concat(*tmp, 0)); } last_batch_size = batch_size; lstm_cell_out = JUST(LSTMCell{}((*rnn_inputs)[i], *lstm_cell_out, bw_cell_param)); (*bw_outputs)[i] = (*lstm_cell_out)[0]; } final_hy.emplace_back((*lstm_cell_out)[0]); final_cy.emplace_back((*lstm_cell_out)[1]); // concat fw_outputs and bw_outputs for (int32_t i = 0; i < rnn_inputs->size(); ++i) { (*rnn_inputs)[i] = JUST(functional::Concat({(*fw_outputs)[i], (*bw_outputs)[i]}, (*bw_outputs)[0]->shape()->NumAxes() - 1)); } if (dropout != 0 && train && l < num_layers - 1) { std::shared_ptr stack_res = JUST(functional::Concat(*rnn_inputs, 0)); std::shared_ptr dropout_res = JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr)); int64_t input_offset = 0; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; (*rnn_inputs)[i] = JUST(functional::Narrow(dropout_res, 0, input_offset, batch_size)); input_offset += batch_size; } } } } else { std::shared_ptr lstm_cell_out = std::make_shared(2); for (int32_t l = 0; l < num_layers; ++l) { int64_t last_batch_size = batch_sizes_vec[0]; (*lstm_cell_out)[0] = (*layer_hxs)[l]; (*lstm_cell_out)[1] = (*layer_cxs)[l]; auto& cell_param = (*rnn_params)[l]; TensorTuple final_hy_for_single_layer; TensorTuple final_cy_for_single_layer; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; const int64_t dec = last_batch_size - batch_size; if (dec > 0) { final_hy_for_single_layer.emplace_back( JUST(functional::Narrow((*lstm_cell_out)[0], 0, last_batch_size - dec, dec))); (*lstm_cell_out)[0] = JUST(functional::Narrow((*lstm_cell_out)[0], 0, 0, last_batch_size - dec)); final_cy_for_single_layer.emplace_back( JUST(functional::Narrow((*lstm_cell_out)[1], 0, last_batch_size - dec, dec))); (*lstm_cell_out)[1] = JUST(functional::Narrow((*lstm_cell_out)[1], 0, 0, last_batch_size - dec)); } last_batch_size = batch_size; lstm_cell_out = JUST(LSTMCell{}((*rnn_inputs)[i], *lstm_cell_out, cell_param)); (*rnn_inputs)[i] = (*lstm_cell_out)[0]; } final_hy_for_single_layer.emplace_back((*lstm_cell_out)[0]); final_cy_for_single_layer.emplace_back((*lstm_cell_out)[1]); std::reverse(final_hy_for_single_layer.begin(), final_hy_for_single_layer.end()); std::reverse(final_cy_for_single_layer.begin(), final_cy_for_single_layer.end()); final_hy.emplace_back(JUST(functional::Concat(final_hy_for_single_layer, 0))); final_cy.emplace_back(JUST(functional::Concat(final_cy_for_single_layer, 0))); if (dropout != 0 && train && l < num_layers - 1) { std::shared_ptr stack_res = JUST(functional::Concat(*rnn_inputs, 0)); std::shared_ptr dropout_res = JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr)); int64_t input_offset = 0; for (int32_t i = 0; i < num_steps; ++i) { const int64_t batch_size = batch_sizes_vec[i]; (*rnn_inputs)[i] = JUST(functional::Narrow(dropout_res, 0, input_offset, batch_size)); input_offset += batch_size; } } } } TensorTuple output; std::shared_ptr output_0 = JUST(functional::Concat(*rnn_inputs, 0)); output.emplace_back(output_0); output.emplace_back(JUST(functional::Stack(final_hy, 0))); output.emplace_back(JUST(functional::Stack(final_cy, 0))); return output; } class LstmDataFunctor { public: LstmDataFunctor() {} Maybe operator()(const std::shared_ptr& data, const std::shared_ptr& 
batch_sizes, const one::TensorTuple& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional) const { return _lstm_pack_sequence_impl(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); } }; class GruInputFunctor { public: GruInputFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional, const bool& batch_first) const { return _rnn_impl>(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); } }; class GruDataFunctor { public: GruDataFunctor() {} Maybe operator()(const std::shared_ptr& data, const std::shared_ptr& batch_sizes, const std::shared_ptr& hx, const one::TensorTuple& params, const bool& has_biases, const int32_t& num_layers, const float& dropout, const bool& train, const bool& bidirectional) const { return _rnn_pack_sequence_impl>(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); } }; Maybe checkLongTensor(const std::shared_ptr& tensor) { auto& device = JUST(tensor->device())->type(); CHECK_OR_RETURN(tensor->ndim() == 1 && device == "cpu" && tensor->dtype() == DType::Int64()) << "'lengths' argument should be a 1D CPU int64 tensor, but got " << tensor->ndim() << "D " << device << " " << tensor->dtype()->name() << " tensor"; return Maybe::Ok(); } class PackPaddedSequenceFunctor { public: PackPaddedSequenceFunctor() {} Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& lengths, const bool& batch_first) const { CHECK_OR_RETURN(input->is_local() && lengths->is_local()) << "pack_padded_sequence only accept local tensors as input."; std::shared_ptr new_input = input; if (batch_first) { std::vector dims; dims.resize(input->shape()->NumAxes()); dims[0] = 1; dims[1] = 0; for (int i = 2; i < input->shape()->NumAxes(); ++i) { dims[i] = i; } new_input = JUST(functional::Permute(input, dims)); } JUST(checkLongTensor(lengths)); int64_t batch_size = new_input->shape()->At(1); std::vector lengths_vec; lengths_vec.resize(lengths->nelement()); const auto& callback = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, lengths_vec.data(), eager_blob_object->dptr(), lengths_vec.size() * sizeof(int64_t), memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; JUST(SyncAccessTensorWithTimeOut(lengths, callback, "const")); CHECK_OR_RETURN(new_input->nelement() > 0) << "Cannot pack empty tensors."; CHECK_OR_RETURN(lengths->shape()->At(0) == batch_size) << "Expected `len(lengths)` to be equal to batch_size, but got " << lengths->shape()->At(0) << " (batch_size=" << batch_size << ")"; CHECK_OR_RETURN(lengths_vec[batch_size - 1] > 0) << "Length of all samples has to be greater than 0, but found an element in 'lengths' that " "is <= 0"; for (int i = 0; i < batch_size - 1; ++i) { if (lengths_vec[batch_size - 1 - i] > lengths_vec[batch_size - 2 - i]) { CHECK_OR_RETURN(false) << "`lengths` array must be sorted in decreasing order when " "`enforce_sorted` is True. 
You can pass `enforce_sorted=False` " "to pack_padded_sequence and/or pack_sequence to sidestep this " "requirement if you do not need ONNX exportability."; } } std::vector step_shape_vec; // == [-1, *input.shape[2:]] { const auto& input_sizes = new_input->shape(); step_shape_vec.push_back(-1); for (int i = 2; i < input_sizes->NumAxes(); ++i) { step_shape_vec.push_back(input_sizes->At(i)); } } DimVector rsv(step_shape_vec.size()); for (int i = 0; i < step_shape_vec.size(); ++i) { rsv[i] = step_shape_vec[i]; } const Shape step_shape(rsv); // To understand what's going on in this loop imagine that the input is a padded 2D // array that looks like this (x = valid entry, . = padding) // // 1 1 1 1 1 // 2 2 2 . . // 2 2 2 . . // 4 . . . . // 4 . . . . // // Where the vertical dimension corresponds to time, and horizontal dim to batch. // In this example, the lengths array will be equal to [5, 3, 3, 1, 1], and we will // iterate over them in reverse order (from the rightmost column to the left). // We want to avoid eager slicing of the input at every time step, and wait for // the moments where the length increases. In this example, that will happen at the // first, second and fourth steps. Then, we slice out the whole block of the input // that corresponds to this length, and hasn't been sliced yet (the steps at which each // element is sliced are annotated in the array above). You can think of this as if we // were scanning the sequences from the shortest one, and every time we realize there's // more elements below in our column, we lower the counter (prev_l), and append the new // block to the output. std::vector batch_sizes; batch_sizes.resize(lengths_vec[0]); int64_t* batch_sizes_ptr = batch_sizes.data(); TensorTuple steps; int64_t prev_l = 0; for (int i = 0; i < batch_size; ++i) { int64_t l = lengths_vec[batch_size - 1 - i]; if (l > prev_l) { auto current_batch_size = batch_size - i; std::shared_ptr slice_res = JUST(functional::Narrow(new_input, 0, prev_l, l - prev_l)); slice_res = JUST(functional::Narrow(slice_res, 1, 0, current_batch_size)); slice_res = JUST(functional::View(slice_res->contiguous(), step_shape)); steps.emplace_back(slice_res); for (int64_t j = 0; j < (l - prev_l); ++j) { (*batch_sizes_ptr++) = current_batch_size; } prev_l = l; } CHECK_OR_RETURN(l >= prev_l) << "PackPaddedSequenceFunctor: `lengths` array must be sorted in decreasing order."; } DimVector lsv(1); lsv[0] = lengths_vec[0]; const Shape ls(lsv); std::shared_ptr batch_sizes_t = JUST(functional::Empty(ls, lengths->dtype(), JUST(lengths->device()), /*requires_grad=*/lengths->requires_grad(), /*pin_memory=*/false)); const auto& callback2 = [&](ep::Stream* stream, const std::shared_ptr& eager_blob_object) { SyncAutoMemcpy(stream, eager_blob_object->mut_dptr(), batch_sizes.data(), batch_sizes.size() * sizeof(int64_t), eager_blob_object->mem_case(), memory::MakeHostMemCase()); // copy 1 scalar(int64_t) tensor's value to max }; JUST(SyncAccessTensorWithTimeOut(batch_sizes_t, callback2, "const")); std::shared_ptr output = std::make_shared(2); (*output)[0] = JUST(functional::Concat(steps, 0)); (*output)[1] = batch_sizes_t; return output; } }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("RnnTanhCell"); m.add_functor("RnnReluCell"); m.add_functor("LstmCell"); m.add_functor("GruCell"); m.add_functor("FusedLstmCell"); m.add_functor("FusedLstmCellGrad"); m.add_functor("FusedGruCell"); m.add_functor("FusedGruCellGrad"); m.add_functor("RnnTanhInput"); m.add_functor("RnnTanhData"); 
m.add_functor("RnnReluInput"); m.add_functor("RnnReluData"); m.add_functor("LstmInput"); m.add_functor("LstmData"); m.add_functor("GruInput"); m.add_functor("GruData"); m.add_functor("PackPaddedSequence"); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/slice_boxing_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/impl/common.h" namespace oneflow { namespace one { namespace functional { namespace impl { namespace { bool IsSplitSbp(Symbol sbp_parallel) { return sbp_parallel->has_split_parallel(); } Maybe EagerSToB(Symbol in_parallel_desc, Symbol out_parallel_desc, Symbol src_sbp, const Shape& shape) { return one::OpBuilder("eager_s_to_b", *JUST(UniqueStr("eager_s_to_b"))) .Input("in") .Output("out") .Attr("in_split_axis", src_sbp->split_parallel().axis()) .Attr("in_parallel_conf", PbMessage2TxtString(in_parallel_desc->parallel_conf())) .Attr("out_parallel_conf", PbMessage2TxtString(out_parallel_desc->parallel_conf())) .Attr("shape", shape) .Build(); } static constexpr auto* CachedEagerSToBOpExpr = DECORATE(&EagerSToB, ThreadLocalCopiable); Maybe EagerPToB(Symbol in_parallel_desc, Symbol out_parallel_desc, const Shape& shape) { return one::OpBuilder("eager_p_to_b", *JUST(UniqueStr("eager_p_to_b"))) .Input("in") .Output("out") .Attr("in_parallel_conf", PbMessage2TxtString(in_parallel_desc->parallel_conf())) .Attr("out_parallel_conf", PbMessage2TxtString(out_parallel_desc->parallel_conf())) .Attr("shape", shape) .Build(); } static constexpr auto* CachedEagerPToBOpExpr = DECORATE(&EagerPToB, ThreadLocalCopiable); Maybe EagerNaiveSToS(Symbol in_parallel_desc, Symbol out_parallel_desc, Symbol src_sbp, Symbol dst_sbp, const Shape& shape) { return one::OpBuilder("eager_naive_s_to_s", *JUST(UniqueStr("eager_naive_s_to_s"))) .Input("in") .Output("out") .Attr("in_split_axis", src_sbp->split_parallel().axis()) .Attr("out_split_axis", dst_sbp->split_parallel().axis()) .Attr("in_parallel_conf", PbMessage2TxtString(in_parallel_desc->parallel_conf())) .Attr("out_parallel_conf", PbMessage2TxtString(out_parallel_desc->parallel_conf())) .Attr("shape", shape) .Build(); } static constexpr auto* CachedEagerNaiveSToSOpExpr = DECORATE(&EagerNaiveSToS, ThreadLocalCopiable); Maybe EagerBToS(Symbol in_parallel_desc, Symbol out_parallel_desc, Symbol dst_sbp, const Shape& shape) { return one::OpBuilder("eager_b_to_s", *JUST(UniqueStr("eager_b_to_s"))) .Input("in") 
.Output("out") .Attr("out_split_axis", dst_sbp->split_parallel().axis()) .Attr("in_parallel_conf", PbMessage2TxtString(in_parallel_desc->parallel_conf())) .Attr("out_parallel_conf", PbMessage2TxtString(out_parallel_desc->parallel_conf())) .Attr("shape", shape) .Build(); } static constexpr auto* CachedEagerBToSOpExpr = DECORATE(&EagerBToS, ThreadLocalCopiable); Maybe EagerPToS(Symbol in_parallel_desc, Symbol out_parallel_desc, Symbol dst_sbp, const Shape& shape) { return one::OpBuilder("eager_p_to_s", *JUST(UniqueStr("eager_p_to_s"))) .Input("in") .Output("out") .Attr("out_split_axis", dst_sbp->split_parallel().axis()) .Attr("in_parallel_conf", PbMessage2TxtString(in_parallel_desc->parallel_conf())) .Attr("out_parallel_conf", PbMessage2TxtString(out_parallel_desc->parallel_conf())) .Attr("shape", shape) .Build(); } static constexpr auto* CachedEagerPToSOpExpr = DECORATE(&EagerPToS, ThreadLocalCopiable); Maybe EagerSToP(Symbol in_parallel_desc, Symbol out_parallel_desc, Symbol src_sbp, const Shape& shape) { return one::OpBuilder("eager_s_to_p", *JUST(UniqueStr("eager_s_to_p"))) .Input("in") .Output("out") .Attr("in_split_axis", src_sbp->split_parallel().axis()) .Attr("in_parallel_conf", PbMessage2TxtString(in_parallel_desc->parallel_conf())) .Attr("out_parallel_conf", PbMessage2TxtString(out_parallel_desc->parallel_conf())) .Attr("shape", shape) .Build(); } static constexpr auto* CachedEagerSToPOpExpr = DECORATE(&EagerSToP, ThreadLocalCopiable); } // namespace class EagerSToBFunctor { public: EagerSToBFunctor() = default; Maybe operator()(const std::shared_ptr& x, Symbol in_parallel_desc, Symbol out_parallel_desc, const std::vector>& in_sbp_parallels, const Shape& shape) const { Symbol in_nd_sbp = JUST(GetNdSbp(in_sbp_parallels)); { CHECK_OR_RETURN(x->is_local()) << Error::RuntimeError() << "input tensors `.is_local` should be true"; CHECK_OR_RETURN(x->is_eager()) << Error::RuntimeError() << "input tensors `.is_eager` should be true"; CHECK_OR_RETURN((in_nd_sbp->sbp_parallel_size() == 1) && IsSplitSbp(in_nd_sbp->sbp_parallel(0))) << Error::RuntimeError() << "The input tensor's sbp should be (split, )"; } std::shared_ptr op_expr = JUST(CachedEagerSToBOpExpr( in_parallel_desc, out_parallel_desc, SymbolOf(in_nd_sbp->sbp_parallel(0)), shape)); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class EagerPToBFunctor { public: EagerPToBFunctor() = default; Maybe operator()(const std::shared_ptr& x, Symbol in_parallel_desc, Symbol out_parallel_desc, const Shape& shape) const { { CHECK_OR_RETURN(x->is_local()) << Error::RuntimeError() << "input tensors `.is_local` should be true"; CHECK_OR_RETURN(x->is_eager()) << Error::RuntimeError() << "input tensors `.is_eager` should be true"; } std::shared_ptr op_expr = JUST(CachedEagerPToBOpExpr(in_parallel_desc, out_parallel_desc, shape)); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class EagerNaiveSToSFunctor { public: EagerNaiveSToSFunctor() = default; Maybe operator()(const std::shared_ptr& x, Symbol in_parallel_desc, Symbol out_parallel_desc, const std::vector>& in_sbp_parallels, const std::vector>& out_sbp_parallels, const Shape& shape) const { Symbol in_nd_sbp = JUST(GetNdSbp(in_sbp_parallels)); Symbol out_nd_sbp = JUST(GetNdSbp(out_sbp_parallels)); { CHECK_OR_RETURN(x->is_local()) << Error::RuntimeError() << "input tensors `.is_local` should be true"; CHECK_OR_RETURN(x->is_eager()) << Error::RuntimeError() << "input tensors `.is_eager` should be true"; CHECK_OR_RETURN((in_nd_sbp->sbp_parallel_size() == 1) && 
IsSplitSbp(in_nd_sbp->sbp_parallel(0))) << Error::RuntimeError() << "The input tensor's sbp should be (split, )"; CHECK_OR_RETURN((out_nd_sbp->sbp_parallel_size() == 1) && IsSplitSbp(out_nd_sbp->sbp_parallel(0))) << Error::RuntimeError() << "The output tensor's sbp should be (split, )"; } std::shared_ptr op_expr = JUST(CachedEagerNaiveSToSOpExpr( in_parallel_desc, out_parallel_desc, SymbolOf(in_nd_sbp->sbp_parallel(0)), SymbolOf(out_nd_sbp->sbp_parallel(0)), shape)); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class EagerBToSFunctor { public: EagerBToSFunctor() = default; Maybe operator()(const std::shared_ptr& x, Symbol in_parallel_desc, Symbol out_parallel_desc, const std::vector>& out_sbp_parallels, const Shape& shape) const { Symbol out_nd_sbp = JUST(GetNdSbp(out_sbp_parallels)); { CHECK_OR_RETURN(x->is_local()) << Error::RuntimeError() << "input tensors `.is_local` should be true"; CHECK_OR_RETURN(x->is_eager()) << Error::RuntimeError() << "input tensors `.is_eager` should be true"; CHECK_OR_RETURN((out_nd_sbp->sbp_parallel_size() == 1) && IsSplitSbp(out_nd_sbp->sbp_parallel(0))) << Error::RuntimeError() << "The output tensor's sbp should be (split, )"; } std::shared_ptr op_expr = JUST(CachedEagerBToSOpExpr( in_parallel_desc, out_parallel_desc, SymbolOf(out_nd_sbp->sbp_parallel(0)), shape)); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class EagerPToSFunctor { public: EagerPToSFunctor() = default; Maybe operator()(const std::shared_ptr& x, Symbol in_parallel_desc, Symbol out_parallel_desc, const std::vector>& out_sbp_parallels, const Shape& shape) const { Symbol out_nd_sbp = JUST(GetNdSbp(out_sbp_parallels)); { CHECK_OR_RETURN(x->is_local()) << Error::RuntimeError() << "input tensors `.is_local` should be true"; CHECK_OR_RETURN(x->is_eager()) << Error::RuntimeError() << "input tensors `.is_eager` should be true"; CHECK_OR_RETURN((out_nd_sbp->sbp_parallel_size() == 1) && IsSplitSbp(out_nd_sbp->sbp_parallel(0))) << Error::RuntimeError() << "The output tensor's sbp should be (split, )"; } std::shared_ptr op_expr = JUST(CachedEagerPToSOpExpr( in_parallel_desc, out_parallel_desc, SymbolOf(out_nd_sbp->sbp_parallel(0)), shape)); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; class EagerSToPFunctor { public: EagerSToPFunctor() = default; Maybe operator()(const std::shared_ptr& x, Symbol in_parallel_desc, Symbol out_parallel_desc, const std::vector>& in_sbp_parallels, const Shape& shape) const { Symbol in_nd_sbp = JUST(GetNdSbp(in_sbp_parallels)); { CHECK_OR_RETURN(x->is_local()) << Error::RuntimeError() << "input tensors `.is_local` should be true"; CHECK_OR_RETURN(x->is_eager()) << Error::RuntimeError() << "input tensors `.is_eager` should be true"; CHECK_OR_RETURN((in_nd_sbp->sbp_parallel_size() == 1) && IsSplitSbp(in_nd_sbp->sbp_parallel(0))) << Error::RuntimeError() << "The input tensor's sbp should be (split, )"; } std::shared_ptr op_expr = JUST(CachedEagerSToPOpExpr( in_parallel_desc, out_parallel_desc, SymbolOf(in_nd_sbp->sbp_parallel(0)), shape)); return JUST(OpInterpUtil::Dispatch(*op_expr, {x})); } }; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("EagerSToB"); m.add_functor("EagerPToB"); m.add_functor("EagerNaiveSToS"); m.add_functor("EagerBToS"); m.add_functor("EagerPToS"); m.add_functor("EagerSToP"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/test_functor.cpp ================================================ /* 
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/impl/common.h"

namespace oneflow {
namespace one {
namespace functional {

namespace impl {

class ThrowErrorFunctor final {
 public:
  ThrowErrorFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("throw_error").Input("x").Output("y").Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input) const {
    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}));
  }

 protected:
  std::shared_ptr<OpExpr> op_;
};

}  // namespace impl

using namespace impl;

ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::ThrowErrorFunctor>("ThrowError"); };

}  // namespace functional
}  // namespace one
}  // namespace oneflow


================================================
FILE: oneflow/core/functional/impl/unary_functor.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/functional/impl/binary_functor.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/user/ops/math_unary_elementwise_seq.h" namespace oneflow { namespace one { namespace functional { namespace impl { #define INPLACE_UNARY_FLOAT_FUNC_SEQ \ OF_PP_MAKE_TUPLE_SEQ("sin", InplaceSin) \ OF_PP_MAKE_TUPLE_SEQ("floor", InplaceFloor) \ OF_PP_MAKE_TUPLE_SEQ("ceil", InplaceCeil) \ OF_PP_MAKE_TUPLE_SEQ("round", InplaceRound) #define UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ \ OF_PP_MAKE_TUPLE_SEQ("abs", Abs) \ OF_PP_MAKE_TUPLE_SEQ("acos", Acos) \ OF_PP_MAKE_TUPLE_SEQ("cosh", Cosh) \ OF_PP_MAKE_TUPLE_SEQ("lgamma", Lgamma) \ OF_PP_MAKE_TUPLE_SEQ("log_sigmoid", LogSigmoid) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan", ReciprocalNoNan) #define FLOAT_UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ \ OF_PP_MAKE_TUPLE_SEQ("acosh", Acosh) \ OF_PP_MAKE_TUPLE_SEQ("asin", Asin) \ OF_PP_MAKE_TUPLE_SEQ("asinh", Asinh) \ OF_PP_MAKE_TUPLE_SEQ("atan", Atan) \ OF_PP_MAKE_TUPLE_SEQ("atanh", Atanh) \ OF_PP_MAKE_TUPLE_SEQ("sin", Sin) \ OF_PP_MAKE_TUPLE_SEQ("cos", Cos) \ OF_PP_MAKE_TUPLE_SEQ("erf", Erf) \ OF_PP_MAKE_TUPLE_SEQ("erfc", Erfc) \ OF_PP_MAKE_TUPLE_SEQ("exp", Exp) \ OF_PP_MAKE_TUPLE_SEQ("exp2", Exp2) \ OF_PP_MAKE_TUPLE_SEQ("expm1", Expm1) \ OF_PP_MAKE_TUPLE_SEQ("log", Log) \ OF_PP_MAKE_TUPLE_SEQ("log2", Log2) \ OF_PP_MAKE_TUPLE_SEQ("log10", Log10) \ OF_PP_MAKE_TUPLE_SEQ("log1p", Log1p) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal", Reciprocal) \ OF_PP_MAKE_TUPLE_SEQ("rsqrt", Rsqrt) \ OF_PP_MAKE_TUPLE_SEQ("sinh", Sinh) \ OF_PP_MAKE_TUPLE_SEQ("sqrt", Sqrt) \ OF_PP_MAKE_TUPLE_SEQ("square", Square) \ OF_PP_MAKE_TUPLE_SEQ("tan", Tan) \ OF_PP_MAKE_TUPLE_SEQ("digamma", Digamma) #define FLOAT_UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_Y_SEQ \ OF_PP_MAKE_TUPLE_SEQ("sigmoid", Sigmoid) \ OF_PP_MAKE_TUPLE_SEQ("tanh", Tanh) #define UNARY_FUNC_BWD_WITH_FILL_SEQ \ OF_PP_MAKE_TUPLE_SEQ("rint", Rint) \ OF_PP_MAKE_TUPLE_SEQ("round", Round) \ OF_PP_MAKE_TUPLE_SEQ("floor", Floor) \ OF_PP_MAKE_TUPLE_SEQ("ceil", Ceil) #define FLOAT_UNARY_FUNC_BWD_WITH_FILL_SEQ \ OF_PP_MAKE_TUPLE_SEQ("sign", Sign) \ OF_PP_MAKE_TUPLE_SEQ("not_equal_zero", NotEqualZero) #define LOGICAL_FLOAT_UNARY_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ("logical_not", LogicalNot) #define UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, base) \ class class_name##Functor : public base { \ public: \ class_name##Functor() { \ op_ = CHECK_JUST(one::OpBuilder(op_type_name).Input("x").Output("y").Build()); \ } \ }; #define UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, base) \ class class_name##WithDyXGradFunctor : public base { \ public: \ class_name##WithDyXGradFunctor() { \ op_ = CHECK_JUST(one::OpBuilder(std::string("") + op_type_name + "_grad") \ .Input("x") \ .Input("dy") \ .Output("dx") \ .Build()); \ } \ }; #define UNARY_ELEMENTWISE_BWD_WITH_DY_Y_FUNCTOR(op_type_name, class_name, base) \ class class_name##WithDyYGradFunctor : public base { \ public: \ class_name##WithDyYGradFunctor() { \ op_ = CHECK_JUST(one::OpBuilder(std::string("") + op_type_name + "_grad") \ .Input("y") \ .Input("dy") \ .Output("dx") \ .Build()); \ } \ }; #define INPLACE_UNARY_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, InplaceUnaryFunctor) #define INPLACE_FLOAT_UNARY_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, InplaceFloatUnaryFunctor) #define LOGICAL_FLOAT_UNARY_FUNCTORS(op_type_name, 
class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor) #define UNARY_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, UnaryFunctor) \ UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, BinaryFunctor) #define FLOAT_UNARY_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor) \ UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, BinaryFunctor) #define UNARY_BWD_WITH_DY_X_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, UnaryFunctor) \ UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, BinaryFunctor) #define FLOAT_UNARY_BWD_WITH_DY_X_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor) \ UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, BinaryFunctor) #define FLOAT_UNARY_WITH_DY_Y_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor) \ UNARY_ELEMENTWISE_BWD_WITH_DY_Y_FUNCTOR(op_type_name, class_name, BinaryFunctor) #define FLOAT_UNARY_BWD_WITH_FILL_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor) #define UNARY_BWD_WITH_FILL_FUNCTORS(op_type_name, class_name) \ UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, UnaryFunctor) OF_PP_FOR_EACH_TUPLE(INPLACE_FLOAT_UNARY_FUNCTORS, INPLACE_UNARY_FLOAT_FUNC_SEQ); OF_PP_FOR_EACH_TUPLE(LOGICAL_FLOAT_UNARY_FUNCTORS, LOGICAL_FLOAT_UNARY_FUNC_SEQ); OF_PP_FOR_EACH_TUPLE(UNARY_BWD_WITH_DY_X_FUNCTORS, UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ); OF_PP_FOR_EACH_TUPLE(FLOAT_UNARY_BWD_WITH_DY_X_FUNCTORS, FLOAT_UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ); OF_PP_FOR_EACH_TUPLE(FLOAT_UNARY_WITH_DY_Y_FUNCTORS, FLOAT_UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_Y_SEQ); OF_PP_FOR_EACH_TUPLE(UNARY_BWD_WITH_FILL_FUNCTORS, UNARY_FUNC_BWD_WITH_FILL_SEQ); OF_PP_FOR_EACH_TUPLE(FLOAT_UNARY_BWD_WITH_FILL_FUNCTORS, FLOAT_UNARY_FUNC_BWD_WITH_FILL_SEQ); UNARY_ELEMENTWISE_FUNCTOR("negative", Negative, FloatUnaryFunctor) UNARY_ELEMENTWISE_FUNCTOR("bitwise_not", BitwiseNot, UnaryFunctor) UNARY_ELEMENTWISE_FUNCTOR("trigamma", Trigamma, FloatUnaryFunctor) } // namespace impl using namespace impl; #define ADD_UNARY_FUNCTOR_WITH_DY_X(class_name, functor_name) \ m.add_functor(functor_name); \ m.add_functor(std::string("") + functor_name + "Grad"); #define ADD_UNARY_FUNCTOR_WITH_DY_Y(class_name, functor_name) \ m.add_functor(functor_name); \ m.add_functor(std::string("") + functor_name + "Grad"); ONEFLOW_FUNCTION_LIBRARY(m) { ADD_UNARY_FUNCTOR_WITH_DY_X(Abs, "Abs"); ADD_UNARY_FUNCTOR_WITH_DY_X(Acos, "Acos"); ADD_UNARY_FUNCTOR_WITH_DY_X(Acosh, "Acosh"); ADD_UNARY_FUNCTOR_WITH_DY_X(Asin, "Asin"); ADD_UNARY_FUNCTOR_WITH_DY_X(Asinh, "Asinh"); ADD_UNARY_FUNCTOR_WITH_DY_X(Atan, "Atan"); ADD_UNARY_FUNCTOR_WITH_DY_X(Atanh, "Atanh"); m.add_functor("Ceil"); ADD_UNARY_FUNCTOR_WITH_DY_X(Cos, "Cos"); ADD_UNARY_FUNCTOR_WITH_DY_X(Cosh, "Cosh"); ADD_UNARY_FUNCTOR_WITH_DY_X(Digamma, "Digamma"); ADD_UNARY_FUNCTOR_WITH_DY_X(Erf, "Erf"); ADD_UNARY_FUNCTOR_WITH_DY_X(Erfc, "Erfc"); ADD_UNARY_FUNCTOR_WITH_DY_X(Exp, "Exp"); ADD_UNARY_FUNCTOR_WITH_DY_X(Exp2, "Exp2"); ADD_UNARY_FUNCTOR_WITH_DY_X(Expm1, "Expm1"); m.add_functor("Floor"); ADD_UNARY_FUNCTOR_WITH_DY_X(Lgamma, "Lgamma"); ADD_UNARY_FUNCTOR_WITH_DY_X(Log, "Log"); ADD_UNARY_FUNCTOR_WITH_DY_X(Log2, "Log2"); ADD_UNARY_FUNCTOR_WITH_DY_X(Log10, "Log10"); ADD_UNARY_FUNCTOR_WITH_DY_X(Log1p, 
"Log1p"); ADD_UNARY_FUNCTOR_WITH_DY_X(LogSigmoid, "LogSigmoid"); m.add_functor("Negative"); m.add_functor("BitwiseNot"); ADD_UNARY_FUNCTOR_WITH_DY_X(Reciprocal, "Reciprocal"); ADD_UNARY_FUNCTOR_WITH_DY_X(ReciprocalNoNan, "ReciprocalNoNan"); m.add_functor("Rint"); m.add_functor("Round"); ADD_UNARY_FUNCTOR_WITH_DY_X(Rsqrt, "Rsqrt"); ADD_UNARY_FUNCTOR_WITH_DY_Y(Sigmoid, "Sigmoid"); m.add_functor("Sign"); ADD_UNARY_FUNCTOR_WITH_DY_X(Sin, "Sin"); ADD_UNARY_FUNCTOR_WITH_DY_X(Sinh, "Sinh"); ADD_UNARY_FUNCTOR_WITH_DY_X(Sqrt, "Sqrt"); ADD_UNARY_FUNCTOR_WITH_DY_X(Square, "Square"); ADD_UNARY_FUNCTOR_WITH_DY_X(Tan, "Tan"); ADD_UNARY_FUNCTOR_WITH_DY_Y(Tanh, "Tanh"); m.add_functor("NotEqualZero"); m.add_functor("LogicalNot"); m.add_functor("Sin_"); m.add_functor("Floor_"); m.add_functor("Ceil_"); m.add_functor("Round_"); m.add_functor("Trigamma"); }; #undef ADD_UNARY_FUNCTOR_WITH_DY_X #undef ADD_UNARY_FUNCTOR_WITH_DY_Y } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/impl/unary_functor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_ #define ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_ #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/tensor_processor.h" namespace oneflow { namespace one { namespace functional { namespace impl { class UnaryFunctor { public: Maybe operator()(const std::shared_ptr& x) const { return OpInterpUtil::Dispatch(*op_, {x}); } protected: UnaryFunctor() = default; virtual ~UnaryFunctor() = default; std::shared_ptr op_; }; class InplaceUnaryFunctor { public: Maybe operator()(const std::shared_ptr& x) const { JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; if (x->requires_grad()) { JUST(OpInterpUtil::Dispatch(*op_, {JUST(functional::Identity(x))}, outputs.get())); } else { JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get())); } return outputs->at(0); } protected: InplaceUnaryFunctor() = default; virtual ~InplaceUnaryFunctor() = default; std::shared_ptr op_; }; class FloatUnaryFunctor { public: Maybe operator()(const std::shared_ptr& x) const { // The functor lowest Dtype is Float32. (For sigmoid, tanh and etc. 
) TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({x}, DType::Float()) .PromoteInputsToCommonDtype(true) .PromoteIntegerInputsToFloatDtype(true) .Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); return OpInterpUtil::Dispatch(*op_, input_tuple); } protected: FloatUnaryFunctor() = default; virtual ~FloatUnaryFunctor() = default; std::shared_ptr op_; }; class InplaceFloatUnaryFunctor { public: Maybe operator()(const std::shared_ptr& x) const { TensorProcessor tensor_processor; JUST(tensor_processor.AddInputs({x}, DType::Float()).Apply()); TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); JUST(CheckInplaceCastValid(x, input_tuple.at(0))); JUST(CheckInplaceValid(x)); std::shared_ptr outputs = std::make_shared(1); outputs->at(0) = x; if (x->requires_grad()) { // It should copy input tensor in autograd_mode because these operators can't calculate // in_grad with output. JUST(OpInterpUtil::Dispatch(*op_, {JUST(functional::Identity(x))}, outputs.get())); } else { JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get())); } return outputs->at(0); } protected: InplaceFloatUnaryFunctor() = default; virtual ~InplaceFloatUnaryFunctor() = default; std::shared_ptr op_; }; } // namespace impl } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_ ================================================ FILE: oneflow/core/functional/impl/util_ops_functor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/impl/common.h" namespace oneflow { namespace one { namespace functional { namespace impl { class UtilOpsFunctor { public: Maybe operator()(const std::shared_ptr& input) const { return JUST(OpInterpUtil::Dispatch(*op_, {input})); } protected: std::shared_ptr op_; }; class IsNanFunctor final : public UtilOpsFunctor { public: IsNanFunctor() { op_ = CHECK_JUST(one::OpBuilder("isnan").Input("in").Output("out").Build()); } }; class IsInfFunctor final : public UtilOpsFunctor { public: IsInfFunctor() { op_ = CHECK_JUST(one::OpBuilder("isinf").Input("in").Output("out").Build()); } }; class IsFiniteFunctor final : public UtilOpsFunctor { public: IsFiniteFunctor() { op_ = CHECK_JUST(one::OpBuilder("isfinite").Input("in").Output("out").Build()); } }; class DependFunctor { public: DependFunctor() { op_ = CHECK_JUST( one::OpBuilder("depend").Input("in").Input("depend_tensor").Output("out").Build()); } Maybe operator()(const std::shared_ptr& in, const std::shared_ptr& depend_tensor) const { return OpInterpUtil::Dispatch(*op_, {in, depend_tensor}); } private: std::shared_ptr op_; }; class DependTupleFunctor { public: DependTupleFunctor() { ops_.resize(kMaxInputCount); for (int n = 0; n < ops_.size(); ++n) { ops_[n] = CHECK_JUST( one::OpBuilder("depend").Input("in").Input("depend_tensor").Output("out").Build()); } } Maybe operator()(const std::shared_ptr& in, const one::TensorTuple& depends) const { return _dispatch(in, depends, 0); } private: Maybe _dispatch(const std::shared_ptr& in, const one::TensorTuple& depends, const int pos) const { const size_t ndepend = depends.size(); Maybe output = OpInterpUtil::Dispatch(*ops_[pos], {in, depends[pos]}); if (pos == ndepend - 1) { return output; } return _dispatch(JUST(output), depends, pos + 1); } std::vector> ops_; }; } // namespace impl using namespace impl; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("IsNan"); }; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("IsInf"); }; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("IsFinite"); }; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Depend"); }; ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("DependTuple"); }; } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/packed_functor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FUNCTIONAL_FUNCTOR_H_ #define ONEFLOW_CORE_FUNCTIONAL_FUNCTOR_H_ #include #include "oneflow/core/common/function_traits.h" #include "oneflow/core/common/type_traits.h" namespace oneflow { namespace one { namespace functional { template using remove_cvref_t = oneflow::detail::remove_cvref_t; template class PackedFunctor; template class PackedFunctor { public: PackedFunctor(const std::string& func_name, const std::function& impl) : func_name_(func_name), impl_(impl) {} virtual ~PackedFunctor() = default; template R call(TArgs&&... args) const { return impl_(std::forward(args)...); } private: std::string func_name_; std::function impl_; }; template class PackedFunctorMaker; template class PackedFunctorMaker { public: using FType = R(const remove_cvref_t&...); template::func_type, R(Args...)>::value, int>::type = 0> static PackedFunctor make(const std::string& func_name, const Func& func) { return PackedFunctor(func_name, [func](const remove_cvref_t&... args) -> R { return func(std::forward&>(args)...); }); } }; } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_FUNCTOR_H_ ================================================ FILE: oneflow/core/functional/sequence_function.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_FUNCTIONAL_SEQUENCE_FUNCTION_H_ #define ONEFLOW_CORE_FUNCTIONAL_SEQUENCE_FUNCTION_H_ #include #include #include "oneflow/core/common/maybe.h" namespace oneflow { namespace one { namespace functional { template class SequenceFunction; template class SequenceFunction { public: using first_f_type = std::function; using f_type = std::function().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile())&)>; explicit SequenceFunction(first_f_type&& f) : fn_(std::forward(f)) {} explicit SequenceFunction(const first_f_type& f) : fn_(f) {} SequenceFunction& then(f_type&& f) { auto fn_ = std::move(this->fn_); this->fn_ = [fn_, f](Args&&... args) -> R { return f(JUST(fn_(std::forward(args)...))); }; return *this; } SequenceFunction& then_if(bool condition, f_type&& f) { return condition ? then(std::forward(f)) : *this; } SequenceFunction& operator<<(f_type&& f) { return then(std::forward(f)); } R call(Args&&... args) const { return fn_(std::forward(args)...); } private: std::function fn_; }; #define sequence_function(f) SequenceFunction(f) } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_SEQUENCE_FUNCTION_H_ ================================================ FILE: oneflow/core/functional/tensor_index.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/functional/tensor_index.h"

#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/kernel/kernel_util.h"

namespace oneflow {
namespace one {
namespace functional {

namespace {

// Count how many input dimensions the index consumes: slices and integers consume one
// dimension each, boolean/byte mask tensors consume one dimension per mask axis, and
// other index tensors consume one dimension.
int64_t CountSpecifiedDims(const TensorIndex& index) {
  int64_t specified_ndims = 0;
  for (int i = 0; i < index.size(); ++i) {
    const auto& index_item = index.at(i);
    if (index_item.IsSlice() || index_item.IsInteger()) {
      specified_ndims++;
    } else if (index_item.IsTensor()) {
      const auto& tensor = index_item.tensor();
      if (IsMaskTensor(tensor)) {
        specified_ndims += tensor->ndim();
      } else {
        specified_ndims++;
      }
    }
  }
  return specified_ndims;
}

Maybe<TensorTuple> ExpandMaskIndex(const std::shared_ptr<Tensor>& index) {
  auto indices = std::make_shared<TensorTuple>();
  const auto& res = JUST(functional::ArgWhere(index, DType::Int64()));
  if (res->size() != 2) {
    return Error::RuntimeError() << "ArgWhere should return 2 tensors, but got " << res->size();
  }
  auto size_tensor = res->at(1);
  if (!size_tensor->is_eager()) {
    return Error::RuntimeError()
           << "Advanced indexing by a boolean (mask) tensor is only valid in eager mode.";
  }
  if (size_tensor->is_global()) {
    // TODO(): check size_tensor sbp is broadcast.
    size_tensor = JUST(functional::GlobalToLocal(size_tensor, /*copy=*/false));
  }

  int64_t size = 0;
  const auto& callback = [&](ep::Stream* stream,
                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
    AutoMemcpy(stream, &size, eager_blob_object->dptr(), sizeof(size), memory::MakeHostMemCase(),
               eager_blob_object->mem_case());
  };
  JUST(SyncAccessTensorWithTimeOut(size_tensor, callback, "const"));

  for (int i = 0; i < index->ndim(); ++i) {
    auto item = JUST(functional::Slice((*res)[0], {0, i}, {size, i + 1}, {1, 1},
                                       /*enable_view_slice=*/false));
    item = JUST(functional::Reshape(item, {size}));
    indices->emplace_back(item);
  }
  return indices;
}

// NOTE: expand each non-empty index tensor to the same (broadcast) shape.
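// For example (hypothetical shapes, for illustration only): index tensors of shape
// (4, 1) and (3,) are broadcast to the common shape (4, 3), matching sizes from the
// trailing dimension and stretching size-1 dimensions; shapes (4, 2) and (3,) fail
// with a size-mismatch error because 2 != 3 at a non-singleton dimension. This is
// the usual NumPy/PyTorch-style broadcasting rule for advanced-indexing index tensors.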
Maybe ExpandIndices(const TensorTuple& indices) { std::shared_ptr expanded_shape; { bool first = true; for (int i = 0; i < indices.size(); ++i) { if (!indices.at(i)) { continue; } if (first) { expanded_shape = indices.at(i)->shape(); first = false; } else { const auto& shape = indices.at(i)->shape(); int ndims = std::max(shape->NumAxes(), expanded_shape->NumAxes()); DimVector sizes(ndims); for (int j = ndims - 1; j >= 0; --j) { int dim = j - (ndims - shape->NumAxes()); int expanded_dim = j - (ndims - expanded_shape->NumAxes()); if (dim < 0) { sizes[j] = expanded_shape->At(expanded_dim); } else if (expanded_dim < 0) { sizes[j] = shape->At(dim); } else { int size = shape->At(dim); int expanded_size = expanded_shape->At(expanded_dim); CHECK_OR_RETURN(size == expanded_size || size == 1 || expanded_size == 1) << Error::RuntimeError() << "The size of tensor a (" << size << ") must match the size of tensor b (" << expanded_size << ") at non-singleton dimension " << i; sizes[j] = size == 1 ? expanded_size : size; } } expanded_shape.reset(new Shape(sizes)); } } } auto expanded_indices = std::make_shared(indices.size()); for (int i = 0; i < indices.size(); ++i) { if (!indices.at(i)) { continue; } if (*(indices.at(i)->shape()) != *expanded_shape) { expanded_indices->at(i) = JUST(Expand(indices.at(i), *expanded_shape)); } else { expanded_indices->at(i) = indices.at(i); } } return expanded_indices; } // NOTE(wyg): // Judge whether all index dims are contiguous. // e.g. [:, index0, index1, :] -> True // [index0, :, index1] -> False // [index0, index1, :] -> True Maybe IsContinuousSubspace(const TensorTuple& indices) { int token = 0; for (int i = 0; i < indices.size(); ++i) { if (indices.at(i) && !token) { token = 1; } else if (indices.at(i) && token) { if (token != 1) { return false; } } else if (token) { token += 1; } } return true; } // NOTE(wyg): // Move indices subspace to be contiguous and ahead. // e.g. 
[:, index0, index1] -> [index0, index1, :] Maybe> TransposeFront(const std::shared_ptr& input, const TensorTuple& indices, std::shared_ptr* output, TensorTuple* valid_indices) { std::vector permute; permute.reserve(input->ndim()); for (int i = 0; i < input->ndim(); ++i) { if (i < indices.size() && indices.at(i)) { permute.emplace_back(i); valid_indices->emplace_back(indices.at(i)); } } for (int i = 0; i < input->ndim(); ++i) { if (i >= indices.size() || !indices.at(i)) { permute.emplace_back(i); } } bool need_transpose = [&]() { for (int i = 0; i < permute.size(); ++i) { if (permute.at(i) != i) { return true; } } return false; }(); if (need_transpose) { *output = JUST(Transpose(input, permute)); } else { *output = input; } return permute; } Maybe AdjustSubspace(const std::shared_ptr& input, const TensorTuple& indices, const int& index_ndim, bool reverse = false) { int index_subspace_pos = -1; for (int i = 0; i < indices.size(); ++i) { if (indices.at(i)) { index_subspace_pos = i; break; } } if (index_subspace_pos <= 0) { return input; } int ndim = input->ndim(); CHECK_LE_OR_RETURN(index_subspace_pos + index_ndim, ndim) << Error::IndexError() << "Failed to adjust subspace since the index is out of bounds for tensor dimension " << ndim; std::vector permute; { permute.reserve(ndim); if (reverse) { for (int i = 0; i < index_ndim; ++i) { permute.emplace_back(index_subspace_pos + i); } for (int i = 0; i < index_subspace_pos; ++i) { permute.emplace_back(i); } } else { for (int i = 0; i < index_subspace_pos; ++i) { permute.emplace_back(i + index_ndim); } for (int i = 0; i < index_ndim; ++i) { permute.emplace_back(i); } } for (int i = permute.size(); i < ndim; ++i) { permute.emplace_back(i); } } return Transpose(input, permute); } Maybe HasFalseIndex(const TensorIndex& index) { return std::any_of(index.begin(), index.end(), [](const detail::IndexItem& item) { return item.IsBoolean() && !item.boolean(); }); } bool IsValidScalarTensorIndex(const std::shared_ptr& tensor) { if (!(tensor->dtype()->is_integer() || tensor->dtype() == DType::Bool())) { return false; } return tensor->shape()->NumAxes() == 0 && tensor->shape()->elem_cnt() == 1; } // Permute back for global tensor which transpose dims to front Maybe PermuteBackForGlobalTensor(const std::shared_ptr& result, const std::vector& permute) { CHECK_OR_RETURN(result->is_global()); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(result->ndim(), permute.size()); // NOLINT(maybe-need-error-msg) std::vector inv_permute(permute.size()); for (int32_t i = 0; i < permute.size(); ++i) { inv_permute[permute[i]] = i; } bool not_permute = true; { for (int32_t i = 0; i < permute.size(); ++i) { if (inv_permute[i] != i) { not_permute = false; break; } } } if (!not_permute) { return Transpose(result, inv_permute); } else { return result; } } } // namespace bool IsMaskTensor(const std::shared_ptr& tensor) { return tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8() || tensor->dtype() == DType::Bool(); } Maybe PrepareSliceIndices(const TensorIndex& index, const Shape& shape, std::vector* slice_indices, TensorTuple* tensor_indices, std::vector* expand_dims, std::vector* target_dims) { int64_t ndims = shape.NumAxes(); int64_t specified_ndims = CountSpecifiedDims(index); CHECK_LE_OR_RETURN(specified_ndims, ndims) << Error::IndexError() << "Too many indices for tensor of dimension " << ndims; bool has_false_index = JUST(HasFalseIndex(index)); bool has_expand_boolean_dim = false; int dim = 0; for (int i = 0; i < index.size(); ++i) { const auto& 
index_item = index.at(i); if (index_item.IsNone()) { expand_dims->emplace_back(dim); slice_indices->emplace_back(0, 1, 1); target_dims->emplace_back(1); continue; } if (index_item.IsBoolean()) { if (!has_expand_boolean_dim) { int boolean_index = !has_false_index; expand_dims->emplace_back(dim); slice_indices->emplace_back(0, boolean_index, 1); target_dims->emplace_back(boolean_index); has_expand_boolean_dim = true; } continue; } if (index_item.IsEllipsis()) { int64_t unspecified_ndims = ndims - specified_ndims; unspecified_ndims = std::min(ndims - dim, unspecified_ndims); for (int j = 0; j < unspecified_ndims; ++j) { slice_indices->emplace_back(0, shape.At(dim + j), 1); target_dims->emplace_back(shape.At(dim + j)); } dim += unspecified_ndims; continue; } CHECK_LT_OR_RETURN(dim, ndims) << Error::IndexError() << "Invalid index for tensor of dimension " << ndims; if (index_item.IsSlice()) { const auto& slice = index_item.slice(); CHECK_GT_OR_RETURN(slice.step(), 0) << Error::RuntimeError() << "Step must be greater than zero."; int64_t step = std::min(slice.step(), shape.At(dim)); int64_t end = std::min(slice.end(), shape.At(dim)); int64_t start = std::min(slice.start(), shape.At(dim)); if (start < 0) { start += shape.At(dim); } if (start < 0) { start = 0; } if (end < 0) { end += shape.At(dim); } if (end < start) { end = start; } if (start == end) { step = 1; } slice_indices->emplace_back(start, end, step); int64_t length = start == end ? 0 : (end - start + step - 1) / step; target_dims->emplace_back(length); dim++; } else if (index_item.IsInteger()) { int64_t integer = index_item.integer(); if (integer < 0) { integer += shape.At(dim); } if (integer < 0 || integer >= shape.At(dim)) { return Error::IndexError() << "Index " << index_item.integer() << " is out of bounds for dimension " << dim << " with size " << shape.At(dim); } slice_indices->emplace_back(integer, integer + 1, 1); dim++; } else if (index_item.IsTensor()) { const auto& tensor = index_item.tensor(); if (IsValidScalarTensorIndex(tensor) && !LazyMode::is_enabled()) { if (tensor->dtype()->is_integer() && tensor->dtype()->data_type() != DataType::kBool) { int64_t integer = JUST(GetItemInScalarTensor(tensor)); if (integer < 0) { integer += shape.At(dim); } if (integer < 0 || integer >= shape.At(dim)) { return Error::IndexError() << "Index " << index_item.integer() << " is out of bounds for dimension " << dim << " with size " << shape.At(dim); } slice_indices->emplace_back(integer, integer + 1, 1); dim++; } else { bool boolean_index = JUST(GetItemInScalarTensor(tensor)); if (!has_expand_boolean_dim) { expand_dims->emplace_back(dim); slice_indices->emplace_back(0, boolean_index, 1); target_dims->emplace_back(boolean_index); has_expand_boolean_dim = true; } } } else { auto indices = std::make_shared(); if (tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8() || tensor->dtype() == DType::Bool()) { for (int j = 0; j < tensor->ndim(); ++j) { if (tensor->shape()->At(j) != shape.At(dim + j)) { return Error::IndexError() << "The shape of the mask " << tensor->shape()->ToString() << " at index " << j << " does not match the shape of the indexed tensor " << shape.ToString() << " at index " << dim + j; } } indices = JUST(ExpandMaskIndex(tensor)); } else { indices->emplace_back(tensor); } for (int j = 0; j < indices->size(); ++j) { slice_indices->emplace_back(0, shape.At(dim), 1); tensor_indices->resize(target_dims->size()); tensor_indices->emplace_back(indices->at(j)); target_dims->emplace_back(shape.At(dim)); dim++; } } } } for 
(int i = dim; i < ndims; ++i) {
    slice_indices->emplace_back(0, shape.At(i), 1);
    target_dims->emplace_back(shape.At(i));
  }
  return Maybe<void>::Ok();
}

Maybe<std::vector<detail::Slice>> RemoveExpandDimSlice(
    const std::vector<detail::Slice>& expand_slices, const std::vector<int64_t>& expand_dims) {
  auto slices = std::make_shared<std::vector<detail::Slice>>();
  std::vector<int> mask(expand_slices.size(), 0);
  for (const auto& dim : expand_dims) {
    if (dim >= expand_slices.size()) {
      return Error::IndexError() << "Dimension " << dim << " is out of bounds for size "
                                 << expand_slices.size();
    }
    mask[dim] = 1;
  }
  for (int i = 0; i < expand_slices.size(); ++i) {
    if (!mask[i]) { slices->emplace_back(expand_slices.at(i)); }
  }
  return slices;
}

Maybe<Tensor> ApplyAdvancedIndexing(const std::shared_ptr<Tensor>& input,
                                    const TensorTuple& indices) {
  CHECK_GE_OR_RETURN(input->ndim(), indices.size())
      << Error::IndexError() << "Too many indices for tensor of dimension " << input->ndim();
  const auto& expanded_indices = JUST(ExpandIndices(indices));
  bool is_continuous_subspace = JUST(IsContinuousSubspace(indices));
  // Since the start dimension cannot be specified for `gather_nd`, we transpose the input
  // whenever the first index is null, so that all indexed dimensions come first.
  std::shared_ptr<Tensor> transposed_input;
  TensorTuple valid_indices;
  JUST(TransposeFront(input, *expanded_indices, &transposed_input, &valid_indices));
  if (valid_indices.empty()) { return input; }
  int index_ndim = valid_indices.at(0)->ndim();
  std::shared_ptr<Tensor> packed_indices;
  if (valid_indices.size() == 1) {
    packed_indices = JUST(functional::Unsqueeze(valid_indices.at(0), -1));
  } else {
    packed_indices = JUST(Stack(valid_indices, 0));
    int packed_ndim = packed_indices->ndim();
    CHECK_GT_OR_RETURN(packed_ndim, 0)
        << Error::RuntimeError() << "Index array dimension should be greater than 0.";
    std::vector<int> permute(packed_ndim);
    permute[packed_ndim - 1] = 0;
    std::iota(permute.begin(), permute.end() - 1, 1);
    packed_indices = JUST(Transpose(packed_indices, permute))->contiguous();
  }
  CHECK_EQ_OR_RETURN(transposed_input->is_local(), packed_indices->is_local())
      << Error::RuntimeError() << "The input and indices must both be local or both be global.";
  auto result = JUST(GatherNd(transposed_input, packed_indices));

  int required_ndim = input->ndim() - valid_indices.size() + index_ndim;
  CHECK_EQ_OR_RETURN(result->ndim(), required_ndim)
      << Error::RuntimeError() << "The indexing result dimension is " << result->ndim()
      << ", but should be " << required_ndim;
  if (is_continuous_subspace) {
    result = JUST(AdjustSubspace(result, indices, index_ndim, /*reverse*/ false));
  }
  return result;
}

Maybe<Tensor> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,
                                          const TensorTuple& indices,
                                          const std::shared_ptr<Tensor>& value) {
  CHECK_GE_OR_RETURN(input->ndim(), indices.size())
      << Error::IndexError() << "Too many indices for tensor of dimension " << input->ndim();
  const auto& expanded_indices = JUST(ExpandIndices(indices));
  bool is_continuous_subspace = JUST(IsContinuousSubspace(indices));
  // Since the start dimension cannot be specified for `scatter_nd`, we transpose the input
  // whenever the first index is null, so that all indexed dimensions come first.
  std::shared_ptr<Tensor> transposed_input;
  TensorTuple valid_indices;
  const auto& transposed_input_permute =
      JUST(TransposeFront(input, *expanded_indices, &transposed_input, &valid_indices));
  // NOTE: For a local tensor, we make sure that transposed_input is a view of input.
  // Therefore we need not transpose it back, because the tensor_scatter_nd_update operator
  // updates the value in the same memory.
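  // For example (a hypothetical setitem such as x[:, idx] = v on a local tensor):
  // TransposeFront returns a permuted view that shares storage with `input`, so the
  // in-place TensorScatterNdUpdate below is already visible through `input` and no
  // inverse transpose is required; the storage-equality check below asserts exactly
  // that. Global tensors cannot rely on views, hence PermuteBackForGlobalTensor at
  // the end of this function.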
  if (input->is_local()) {
    CHECK_EQ_OR_RETURN(JUST(transposed_input->tensor_storage()), JUST(input->tensor_storage()))
        << Error::RuntimeError()
        << "This setitem operator must enable the view mechanism, please try to set "
           "ONEFLOW_DISABLE_VIEW=0";
  }
  if (valid_indices.empty()) {
    CHECK_EQ_OR_RETURN(value->nelement(), 0) << Error::IndexError() << "invalid indices";
    return input;
  }
  int index_ndim = valid_indices[0]->ndim();
  auto packed_indices = JUST(Stack(valid_indices, 0));
  {
    int packed_ndim = packed_indices->ndim();
    CHECK_GT_OR_RETURN(packed_ndim, 0)
        << Error::RuntimeError() << "Index array dimension should be greater than 0.";
    std::vector<int> permute(packed_ndim);
    permute[packed_ndim - 1] = 0;
    std::iota(permute.begin(), permute.end() - 1, 1);
    packed_indices = JUST(Transpose(packed_indices, permute))->contiguous();
  }
  CHECK_EQ_OR_RETURN(transposed_input->is_local(), packed_indices->is_local())
      << Error::RuntimeError() << "The input and indices must both be local or both be global.";

  Shape expand_shape;
  {
    if (is_continuous_subspace) {
      bool index_subspace_begin = true;
      for (int i = 0; i < indices.size(); ++i) {
        // if this is the first non-null index, insert the whole index subspace shape here
        if (indices[i]) {
          if (!index_subspace_begin) { continue; }
          for (int j = 0; j < index_ndim; ++j) {
            expand_shape.emplace_back(valid_indices[0]->shape()->At(j));
          }
          index_subspace_begin = false;
        } else {
          expand_shape.emplace_back(input->shape()->At(i));
        }
      }
    } else {
      expand_shape = *(valid_indices[0]->shape());
      for (int i = 0; i < indices.size(); ++i) {
        if (!indices[i]) { expand_shape.emplace_back(input->shape()->At(i)); }
      }
    }
    for (int i = indices.size(); i < input->ndim(); ++i) {
      expand_shape.emplace_back(input->shape()->At(i));
    }
  }
  std::shared_ptr<Tensor> expand_value = JUST(Expand(value, expand_shape));
  // Reverse-adjust the value if the index subspace is contiguous but was transposed,
  // since the start dimension cannot be specified for `scatter_nd`.
  if (is_continuous_subspace) {
    expand_value = JUST(AdjustSubspace(expand_value, indices, index_ndim, /*reverse*/ true));
  }
  JUST(TensorScatterNdUpdate(transposed_input, packed_indices, expand_value, /*inplace=*/true));
  // Global tensors do not support views, so permute back and copy to the original input if needed.
  if (transposed_input->is_global()) {
    return PermuteBackForGlobalTensor(transposed_input, *transposed_input_permute);
  }
  return transposed_input;
}

Maybe<Tensor> ApplySelectIndexing(const std::shared_ptr<Tensor>& input,
                                  const TensorIndex& tensor_index) {
  const int32_t index = tensor_index[0].integer();
  const int32_t ndim = input->ndim();
  CHECK_OR_RETURN(ndim > 0) << Error::RuntimeError()
                            << "select() cannot be applied to a 0-dim tensor.";
  const int32_t pos_dim = 0;
  auto size = input->dim(pos_dim);
  CHECK_OR_RETURN(index >= -size && index < size)
      << Error::IndexError() << "Index out of range (expected to be in range of [" << -size << ","
      << size - 1 << "], but got " << index << ")";
  int32_t pos_index = index >= 0 ?
index : index + size; std::vector sizes(input->shape()->dim_vec().begin() + 1, input->shape()->dim_vec().end()); const auto& stride = *JUST(input->stride()); const int64_t storage_offset = JUST(input->storage_offset()) + pos_index * stride[pos_dim]; std::vector strides(stride.begin() + 1, stride.end()); return functional::AsStrided(input, sizes, strides, storage_offset); } Maybe UnifyInputAndIndicesOnDevice(const std::shared_ptr& x, TensorTuple& tensor_indices) { if (x->is_local()) { const auto x_device = JUST(x->device()); for (int64_t i = 0; i < tensor_indices.size(); ++i) { const auto tensor_index = tensor_indices[i]; if (tensor_index == nullptr) { continue; } if (tensor_index->is_global()) { return Maybe::Ok(); } const auto tensor_index_device = JUST(tensor_index->device()); if ((tensor_index_device->type() != x_device->type()) || (tensor_index_device->device_id() != x_device->device_id())) { tensor_indices[i] = JUST(Copy(tensor_index, x_device->type(), x_device->device_id(), /*pin_memory=*/false)); } } } else { // global tensor const auto& placement = JUST(x->parallel_desc()); const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel()); int n = JUST(x->nd_sbp())->sbp_parallel_size(); std::vector> grad_sbp_tuple; for (int64_t i = 0; i < tensor_indices.size(); ++i) { const auto tensor_index = tensor_indices[i]; if (tensor_index == nullptr) { continue; } if (tensor_index->is_local()) { // NOTE: LocalToGlobal should be called in eager mode LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); tensor_indices[i] = JUST(ToGlobal(tensor_index, placement, std::vector>(n, broadcast_sbp), grad_sbp_tuple, /*check_meta=*/false, /*copy=*/false)); } } } return Maybe::Ok(); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/tensor_index.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_ #define ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_ #include #include #include #include "oneflow/core/common/shape.h" namespace oneflow { namespace one { class Tensor; class TensorTuple; namespace functional { namespace detail { struct NoneIndex {}; struct EllipsisIndex {}; class Slice { public: Slice() : Slice(0, std::numeric_limits::max(), 1) {} explicit Slice(int64_t start) : Slice(start, std::numeric_limits::max(), 1) {} explicit Slice(int64_t start, int64_t end) : Slice(start, end, 1) {} explicit Slice(int64_t start, int64_t end, int64_t step) : start_(start), end_(end), step_(step) {} int64_t start() const { return start_; } int64_t end() const { return end_; } int64_t step() const { return step_; } std::string ToString() const { std::stringstream ss; ss << "[" << start_ << ":" << end_ << ":" << step_ << "]\n"; return ss.str(); } private: int64_t start_; int64_t end_; int64_t step_; }; class IndexItem { public: IndexItem() : IndexItem(NoneIndex()) {} explicit IndexItem(NoneIndex none) : item_{.dummy = 0}, tag_(HAS_NONE) {} explicit IndexItem(int64_t start, int64_t end, int64_t step) : item_{.slice = Slice{start, end, step}}, tag_(HAS_SLICE) {} explicit IndexItem(const Slice& slice) : item_{.slice = slice}, tag_(HAS_SLICE) {} explicit IndexItem(int64_t index) : item_{.i = index}, tag_(HAS_INT) {} explicit IndexItem(bool boolean) : item_{.b = boolean}, tag_(HAS_BOOLEAN) {} explicit IndexItem(EllipsisIndex ellipsis) : item_{.dummy = 0}, tag_(HAS_ELLIPSIS) {} explicit IndexItem(const std::shared_ptr& tensor) : item_{.dummy = 0}, tensor_(tensor), tag_(HAS_TENSOR) {} bool IsSlice() const { return tag_ == HAS_SLICE; } const Slice& slice() const { return item_.slice; } bool IsInteger() const { return tag_ == HAS_INT; } int64_t integer() const { return item_.i; } bool IsBoolean() const { return tag_ == HAS_BOOLEAN; } bool boolean() const { return item_.b; } bool IsEllipsis() const { return tag_ == HAS_ELLIPSIS; } bool IsNone() const { return tag_ == HAS_NONE; } bool IsTensor() const { return tag_ == HAS_TENSOR; } const std::shared_ptr& tensor() const { return tensor_; } private: union { Slice slice; bool b; int64_t i; char dummy; } item_; std::shared_ptr tensor_; enum { HAS_SLICE, HAS_BOOLEAN, HAS_INT, HAS_ELLIPSIS, HAS_NONE, HAS_TENSOR } tag_; }; } // namespace detail class TensorIndex : public std::vector { public: using std::vector::vector; }; bool IsMaskTensor(const std::shared_ptr& tensor); Maybe PrepareSliceIndices(const TensorIndex& index, const Shape& shape, std::vector* slice_indices, TensorTuple* tensor_indices, std::vector* expand_dims, std::vector* target_dims); Maybe> RemoveExpandDimSlice( const std::vector& expand_slices, const std::vector& expand_dims); Maybe ApplyAdvancedIndexing(const std::shared_ptr& input, const TensorTuple& indices); Maybe ApplySelectIndexing(const std::shared_ptr& input, const TensorIndex& index); Maybe UnifyInputAndIndicesOnDevice(const std::shared_ptr& x, TensorTuple& tensor_indices); Maybe ApplyAdvancedIndexingUpdate(const std::shared_ptr& input, const TensorTuple& indices, const std::shared_ptr& value); } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_ ================================================ FILE: oneflow/core/functional/tensor_processor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/functional/tensor_processor.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace one { namespace functional { namespace { Symbol ComputeCommonDType(const TensorTuple& tensor_tuple) { Symbol common_dtype = DType::InvalidDataType(); bool all_scalar_tensors = std::all_of( tensor_tuple.begin(), tensor_tuple.end(), [](const std::shared_ptr& tensor) { return tensor->shape()->NumAxes() == 0; }); for (auto& tensor_ptr : tensor_tuple) { // skip scalar tensor if (!all_scalar_tensors && tensor_ptr->shape()->NumAxes() == 0 && !(tensor_ptr->dtype()->is_complex())) { continue; } common_dtype = promoteTypes(tensor_ptr->dtype(), common_dtype); } return common_dtype; } bool CheckHasDifferentInputDType(const TensorTuple& tensor_tuple) { if (tensor_tuple.size() <= 1) { return false; } Symbol common_dtype = tensor_tuple[0]->dtype(); for (auto& tensor_ptr : tensor_tuple) { if (common_dtype != tensor_ptr->dtype()) { return true; } } return false; } Maybe CastToSameType(TensorTuple& tensor_tuple, const Symbol& common_dtype) { for (auto& tensor_ptr : tensor_tuple) { if (tensor_ptr->dtype() != common_dtype) { tensor_ptr = JUST(functional::Cast(tensor_ptr, common_dtype, /*pin_memory=*/false)); } } return Maybe::Ok(); } } // namespace TensorProcessor& TensorProcessor::AddInputs(const TensorTuple& init_tensor_or_tuple) { for (const auto& tensor : init_tensor_or_tuple) { tensor_tuple_.emplace_back(tensor); inputs_lowest_dtype_vec_.emplace_back(DType::InvalidDataType()); } return *this; } TensorProcessor& TensorProcessor::AddInputs(const TensorTuple& init_tensor_or_tuple, Symbol tensor_lowest_dtype) { for (const auto& tensor : init_tensor_or_tuple) { tensor_tuple_.emplace_back(tensor); inputs_lowest_dtype_vec_.emplace_back(tensor_lowest_dtype); } return *this; } TensorProcessor& TensorProcessor::PromoteInputsToCommonDtype(bool is_promote) { promote_inputs_to_common_dtype_ = is_promote; return *this; } TensorProcessor& TensorProcessor::PromoteInputsToCommonDtype( bool is_promote, const Optional>& promote_dtype) { promote_inputs_to_common_dtype_ = is_promote; promote_dtype_ = promote_dtype; return *this; } TensorProcessor& TensorProcessor::PromoteIntegerInputsToFloatDtype(bool is_promote) { promote_integer_inputs_to_float_ = is_promote; CHECK_OR_THROW(!promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_) << "when set promote_integer_inputs_to_float to 'True', then promote_inputs_to_common_dtype " "should be set to 'True' first!"; return *this; } Maybe TensorProcessor::Apply() { if (promote_inputs_to_common_dtype_) { bool has_different_input_dtype = CheckHasDifferentInputDType(tensor_tuple_); if (has_different_input_dtype) { if (promote_dtype_.has_value()) { common_dtype_ = CHECK_JUST(promote_dtype_); } else { common_dtype_ = ComputeCommonDType(tensor_tuple_); } if 
(promote_integer_inputs_to_float_ && common_dtype_->is_integer()) { // Promotes common dtype to the default float scalar type, if needed. // same to pytorch's computeTypes() in torch/csrc/jit/codegen/cuda/type_promotion.cpp common_dtype_ = DType::Float(); } JUST(CastToSameType(tensor_tuple_, common_dtype_)); } else { if (tensor_tuple_.size() == 1 && !((tensor_tuple_[0]->dtype()->is_floating_point()) || tensor_tuple_[0]->dtype()->is_complex())) { Symbol cast_dtype = (inputs_lowest_dtype_vec_[0] == DType::InvalidDataType()) ? DType::Float() : inputs_lowest_dtype_vec_[0]; JUST(CastToSameType(tensor_tuple_, cast_dtype)); } } } else { for (int i = 0; i < tensor_tuple_.size(); ++i) { // Cast all the inputs to it's attribute `lowest_dtype` if the input tensor dtype is lower // than attribute `lowest_dtype`. Symbol base_dtype = inputs_lowest_dtype_vec_.at(i); if (base_dtype->data_type() && DType::priority_order[base_dtype->data_type()] > DType::priority_order[tensor_tuple_.at(i)->dtype()->data_type()]) { tensor_tuple_[i] = JUST(one::functional::Cast(tensor_tuple_[i], base_dtype, /*pin_memory=*/false)); } } } return Maybe::Ok(); } static bool IsAllContiguous(const TensorTuple& tensors) { for (const auto& t : tensors) { if (t && !t->is_contiguous()) { return false; } } return true; } Maybe TensorLayoutProcessor::Apply() { if (LazyMode::is_enabled()) { return Maybe::Ok(); } if (!non_contiguous_enabled_ && !IsAllContiguous(inputs_)) { contiguous_inputs_.resize(inputs_.size()); for (int i = 0; i < inputs_.size(); ++i) { contiguous_inputs_[i] = inputs_[i]->contiguous(); } } // inplace operation is not allowed if input is non-contiguous and non-contiguous is // not supported for this operation if (!non_contiguous_enabled_ && outputs_ && !IsAllContiguous(*outputs_)) { post_process_outputs_.reserve(outputs_->size()); post_process_output_indices_.reserve(outputs_->size()); for (int i = 0; i < outputs_->size(); ++i) { if ((*outputs_)[i] && !(*outputs_)[i]->is_contiguous()) { post_process_outputs_.emplace_back((*outputs_)[i]); post_process_output_indices_.emplace_back(i); (*outputs_)[i] = nullptr; } } } return Maybe::Ok(); } TensorLayoutProcessor::~TensorLayoutProcessor() { for (int i = 0; i < post_process_output_indices_.size(); ++i) { int output_index = post_process_output_indices_[i]; CHECK_OR_THROW((*outputs_)[output_index]) << "the output which index is " << i << " should not be nullptr"; functional::TensorIndex ellipsis_index; ellipsis_index.emplace_back(functional::detail::EllipsisIndex()); CHECK_JUST(functional::TensorSetItem(post_process_outputs_[i], ellipsis_index, (*outputs_)[output_index])); (*outputs_)[output_index] = post_process_outputs_[i]; } } Maybe TensorAutoCastProcessor::Apply() { if (!autocast::is_enabled()) { return Maybe::Ok(); } if (autocast_meta_.autocast_color() == autocast::kNoColor) { return Maybe::Ok(); } auto autocast_device_type = autocast::get_autocast_device_type(); auto autocast_dtype = autocast::get_autocast_dtype(); auto IsDeviceType = [](const std::shared_ptr& tensor, DeviceType device_type) -> Maybe { return tensor->is_local() ? 
JUST(tensor->device())->enum_type() == device_type : JUST(tensor->parallel_desc())->device_type() == device_type; }; bool is_autocast_eligible = [&]() { if (!autocast_meta_.is_autocast_eligible(autocast_device_type, autocast_dtype)) { return false; } // Skip autocast if output data type is float32 if (outputs_) { for (const auto& output : *outputs_) { if (output && output->dtype() != autocast_dtype) { return false; } } } // Skip autocast if any input is float32 for gray or clear list if (autocast_meta_.autocast_color() != autocast::kWhite) { for (int i = 0; i < inputs_.size(); ++i) { if (autocast_meta_.is_args_autocast_eligible(i) && inputs_[i]->dtype()->is_floating_point() && inputs_[i]->dtype() != autocast_dtype) { return false; } } } return true; }(); // Disable autocast temporarily to avoid going into a dead loop autocast::set_enabled(false); if (is_autocast_eligible) { const auto& args_eligible = autocast_meta_.is_args_autocast_eligible(); CHECK_EQ_OR_RETURN(args_eligible.size(), inputs_.size()) << Error::RuntimeError() << "argument autocast eligible size should equal to input size"; autocast_inputs_.resize(inputs_.size()); for (int i = 0; i < inputs_.size(); ++i) { if (args_eligible[i] && JUST(IsDeviceType(inputs_[i], autocast_device_type)) && inputs_[i]->dtype()->is_floating_point() && inputs_[i]->dtype() != autocast_dtype) { autocast_inputs_[i] = JUST(autocast::cached_cast(inputs_[i], autocast_dtype, JUST(inputs_[i]->device())->enum_type())); } else { autocast_inputs_[i] = inputs_[i]; } } } else { // Fallback to float32 auto common_dtype = ComputeCommonDType(inputs_); auto promote_dtype = promoteTypes(common_dtype, DType::Float()); autocast_inputs_.resize(inputs_.size()); for (int i = 0; i < inputs_.size(); ++i) { if (JUST(IsDeviceType(inputs_[i], autocast_device_type)) && inputs_[i]->dtype()->is_floating_point() && inputs_[i]->dtype() != promote_dtype) { autocast_inputs_[i] = JUST(functional::To(inputs_[i], promote_dtype, /*copy*/ false)); } else { autocast_inputs_[i] = inputs_[i]; } } } // Enable autocast to restore autocast state autocast::set_enabled(true); return Maybe::Ok(); } } // namespace functional } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/functional/tensor_processor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_FUNCTIONAL_TENSOR_PROCESSOR_H_ #define ONEFLOW_CORE_FUNCTIONAL_TENSOR_PROCESSOR_H_ #include #include #include #include #include "oneflow/core/common/symbol.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/framework/autocast.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/common/optional.h" namespace oneflow { namespace one { namespace functional { class TensorProcessor final { public: TensorProcessor() : common_dtype_(DType::InvalidDataType()), promote_dtype_(NullOpt), promote_inputs_to_common_dtype_(false), promote_integer_inputs_to_float_(false){}; TensorProcessor& AddInputs(const TensorTuple& init_list); TensorProcessor& AddInputs(const TensorTuple& init_list, Symbol tensor_lowest_dtype); Maybe Apply(); TensorProcessor& PromoteInputsToCommonDtype(bool is_promote); TensorProcessor& PromoteInputsToCommonDtype(bool is_promote, const Optional>& promote_dtype); TensorProcessor& PromoteIntegerInputsToFloatDtype(bool is_promote); Maybe GetInputs() { return tensor_tuple_; }; private: TensorTuple tensor_tuple_; Symbol common_dtype_; Optional> promote_dtype_; std::vector> inputs_lowest_dtype_vec_; bool promote_inputs_to_common_dtype_; bool promote_integer_inputs_to_float_; }; class TensorLayoutProcessor final { public: TensorLayoutProcessor(const TensorTuple& inputs, bool non_contiguous_enabled) : TensorLayoutProcessor(inputs, nullptr, non_contiguous_enabled) {} TensorLayoutProcessor(const TensorTuple& inputs, TensorTuple* outputs, bool non_contiguous_enabled) : inputs_(inputs), outputs_(outputs), non_contiguous_enabled_(non_contiguous_enabled) {} ~TensorLayoutProcessor(); Maybe Apply(); const TensorTuple& inputs() const { if (!contiguous_inputs_.empty()) { return contiguous_inputs_; } return inputs_; } TensorTuple* outputs() const { return outputs_; } private: const TensorTuple& inputs_; TensorTuple* outputs_; bool non_contiguous_enabled_; TensorTuple contiguous_inputs_; std::vector post_process_output_indices_; TensorTuple post_process_outputs_; }; class TensorAutoCastProcessor final { public: TensorAutoCastProcessor(const TensorTuple& inputs, const autocast::AutoCastMeta& autocast_meta) : TensorAutoCastProcessor(inputs, nullptr, autocast_meta) {} TensorAutoCastProcessor(const TensorTuple& inputs, TensorTuple* outputs, const autocast::AutoCastMeta& autocast_meta) : inputs_(inputs), outputs_(outputs), autocast_meta_(autocast_meta) {} ~TensorAutoCastProcessor() = default; Maybe Apply(); const TensorTuple& inputs() const { if (!autocast_inputs_.empty()) { return autocast_inputs_; } return inputs_; } TensorTuple* outputs() const { return outputs_; } private: const TensorTuple& inputs_; TensorTuple* outputs_; const autocast::AutoCastMeta& autocast_meta_; TensorTuple autocast_inputs_; }; template struct TupleTrait { constexpr static size_t size = sizeof...(TPArgs); constexpr static size_t max_storage_size = std::max({sizeof(TPArgs)...}); constexpr static size_t alignment = std::max({alignof(TPArgs)...}); using type = std::tuple; }; struct TensorProcessorTuple { using trait = TupleTrait; constexpr static size_t size = trait::size; constexpr static size_t max_storage_size = trait::max_storage_size; constexpr static size_t alignment = trait::alignment; using type = typename trait::type; }; class TensorProcessorStorage { public: constexpr static size_t TPMaxStorageSize = TensorProcessorTuple::max_storage_size; TensorProcessorStorage() = default; TensorProcessorStorage(TensorProcessorStorage&& other) = default; 
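  // New<TP>() placement-constructs one of the processor types listed in TensorProcessorTuple
  // into the aligned `buffer_` below and records a type-erased deleter so the destructor can
  // invoke the correct ~TP(). A hedged usage sketch (hypothetical call site, for illustration
  // only):
  //
  //   TensorProcessorStorage storage;
  //   storage.New<TensorLayoutProcessor>(inputs, &outputs, /*non_contiguous_enabled=*/false);
  //   JUST(storage.As<TensorLayoutProcessor>()->Apply());
  //
  // TensorProcessorPipe below chains such slots so that each processor sees the (possibly
  // rewritten) inputs/outputs of the previous one.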
~TensorProcessorStorage() { if (deleter_) { deleter_(buffer_); } } template void New(Args&&... args) { static_assert(sizeof(TP) <= TPMaxStorageSize, "Insufficient buffer size"); new (buffer_) TP(std::forward(args)...); deleter_ = [](char* buffer) { reinterpret_cast(buffer)->~TP(); }; } template TP* As() { return reinterpret_cast(buffer_); } private: alignas(TensorProcessorTuple::alignment) char buffer_[TPMaxStorageSize]; std::function deleter_; }; class TensorProcessorPipe final { public: constexpr static size_t TPSize = TensorProcessorTuple::size; TensorProcessorPipe(const TensorTuple& inputs) : TensorProcessorPipe(inputs, nullptr) {} TensorProcessorPipe(const TensorTuple& inputs, TensorTuple* outputs) : inputs_(&inputs), outputs_(outputs), index_(0) {} template Maybe Apply(Args&&... args) { CHECK_LT_OR_RETURN(index_, static_cast(TPSize)) << Error::RuntimeError() << "The tensor processor pipe can only be applied up to " << static_cast(TPSize) << " times"; processors_[index_].New(*inputs_, outputs_, std::forward(args)...); auto* processor = processors_[index_].As(); JUST(processor->Apply()); inputs_ = &(processor->inputs()); outputs_ = processor->outputs(); ++index_; return Maybe::Ok(); } const TensorTuple& inputs() const { return *inputs_; } TensorTuple* outputs() const { return outputs_; } private: const TensorTuple* inputs_; TensorTuple* outputs_; int index_; TensorProcessorStorage processors_[TPSize]; }; } // namespace functional } // namespace one } // namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_TENSOR_PROCESSOR_H_ ================================================ FILE: oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" namespace oneflow { Maybe B21SubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel()) && out_parallel_desc.parallel_num() == 1) { const int64_t out_parallel_id = 0; const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( in_parallel_desc, out_parallel_desc, out_parallel_id); sorted_ctrl_tasks->resize(1); FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { TaskNode* in_node = sorted_in_tasks.at(i); if (i == nearest_in_parallel_id) { TaskNode* proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, out_parallel_id); sorted_out_tasks->emplace_back(proxy); } else { sorted_ctrl_tasks->at(0).emplace_back(in_node); } } return TRY(BuildSubTskGphBuilderStatus("B21SubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/b21_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_BOXING_B21_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_B21_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" namespace oneflow { class B21SubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(B21SubTskGphBuilder); B21SubTskGphBuilder() = default; ~B21SubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_B21_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/boxing_logger.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/boxing_logger.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/framework/nd_sbp.h" namespace oneflow { namespace { #define OF_BOXING_LOGGER_CSV_COLNUM_NAME_FIELD \ "src_op_name,dst_op_name,src_parallel_desc,dst_parallel_desc," \ "src_nd_sbp," \ "dst_nd_sbp,lbi,dtype,shape,builder,comment\n" std::string ShapeToString(const Shape& shape) { std::stringstream shape_ss; auto dim_vec = shape.dim_vec(); shape_ss << "("; for (int32_t i = 0; i < dim_vec.size(); ++i) { shape_ss << dim_vec.at(i); if (i != dim_vec.size() - 1) { shape_ss << " "; } } shape_ss << ")"; return shape_ss.str(); } std::string ParallelDescToString(const ParallelDesc& parallel_desc) { std::string serialized_parallel_desc; std::string device_type; device_type = *CHECK_JUST(DeviceTag4DeviceType(parallel_desc.device_type())); auto sorted_machine_ids = parallel_desc.sorted_machine_ids(); serialized_parallel_desc += "{"; for (int64_t i = 0; i < sorted_machine_ids.size(); ++i) { const int64_t machine_id = sorted_machine_ids.at(i); serialized_parallel_desc += std::to_string(machine_id) + ":" + device_type + ":"; int64_t min_id = parallel_desc.sorted_dev_phy_ids(machine_id).front(); int64_t max_id = parallel_desc.sorted_dev_phy_ids(machine_id).back(); serialized_parallel_desc += std::to_string(min_id) + "-" + std::to_string(max_id); serialized_parallel_desc += " "; } serialized_parallel_desc += ShapeToString(*parallel_desc.hierarchy()); serialized_parallel_desc += "}"; return serialized_parallel_desc; } std::string NdSbpToCsvString(const NdSbp& nd_sbp) { std::ostringstream ss; ss << "("; for (size_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (i > 0) { ss << " "; } ss << SbpToString(nd_sbp.sbp_parallel(i)); } ss << ")"; return ss.str(); } std::string MakeBoxingLoggerCsvRow(const SubTskGphBuilderStatus& status, const std::string& src_op_name, const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) { std::string serialized_status; serialized_status += src_op_name + ","; serialized_status += dst_op_name + ","; serialized_status += ParallelDescToString(src_parallel_desc) + ","; serialized_status += ParallelDescToString(dst_parallel_desc) + ","; serialized_status += NdSbpToCsvString(src_nd_sbp) + ","; serialized_status += NdSbpToCsvString(dst_nd_sbp) + ","; serialized_status += GenLogicalBlobName(lbi) + ","; serialized_status += DataType_Name(logical_blob_desc.data_type()) + ","; serialized_status += ShapeToString(logical_blob_desc.shape()) + ","; serialized_status += status.builder_name() + ","; if (status.comment().empty()) { serialized_status += "-"; } else { serialized_status += status.comment(); } serialized_status += "\n"; return serialized_status; } } // namespace CsvBoxingLogger::CsvBoxingLogger(std::string path) { log_stream_ = TeePersistentLogStream::Create(path); log_stream_ << OF_BOXING_LOGGER_CSV_COLNUM_NAME_FIELD; } CsvBoxingLogger::~CsvBoxingLogger() { log_stream_->Flush(); } void CsvBoxingLogger::Log(const 
SubTskGphBuilderStatus& status, const std::string& src_op_name, const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) { log_stream_ << MakeBoxingLoggerCsvRow(status, src_op_name, dst_op_name, src_parallel_desc, dst_parallel_desc, src_nd_sbp, dst_nd_sbp, lbi, logical_blob_desc); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/boxing_logger.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_ #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h" namespace oneflow { class BoxingLogger { public: OF_DISALLOW_COPY_AND_MOVE(BoxingLogger); BoxingLogger() = default; virtual ~BoxingLogger() = default; virtual void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name, const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) = 0; }; class NullBoxingLogger final : public BoxingLogger { public: OF_DISALLOW_COPY_AND_MOVE(NullBoxingLogger); NullBoxingLogger() = default; ~NullBoxingLogger() override = default; void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name, const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) override{}; }; class CsvBoxingLogger final : public BoxingLogger { public: OF_DISALLOW_COPY_AND_MOVE(CsvBoxingLogger); CsvBoxingLogger() = delete; CsvBoxingLogger(std::string path); ~CsvBoxingLogger() override; void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name, const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) override; private: std::unique_ptr log_stream_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_ ================================================ FILE: oneflow/core/graph/boxing/ccl_sub_task_graph_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/ccl_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/graph/collective_boxing_task_node.h" #include "oneflow/core/graph/collective_boxing_pack_task_node.h" #include "oneflow/core/graph/collective_boxing_unpack_task_node.h" #include "oneflow/core/graph/slice_boxing_task_node.h" #include "oneflow/core/graph/task_stream_id.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { using namespace boxing::collective; namespace { void CclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node, const ParallelDesc& parallel_desc, int64_t parallel_id, const std::string& name, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, OpType op_type, DeviceType device_type, int64_t root) { OperatorConf op_conf; op_conf.set_name(name); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type))); CollectiveBoxingGenericOpConf* conf = op_conf.mutable_collective_boxing_generic_conf(); *conf->mutable_lbi() = lbi; RankDesc* rank_desc = conf->mutable_rank_desc(); OpDesc* op_desc = rank_desc->mutable_op_desc(); op_desc->set_name(name); op_desc->set_op_type(op_type); if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeReduceScatter || op_type == OpType::kOpTypeReduce) { op_desc->set_reduce_method(ReduceMethod::kReduceMethodSum); } op_desc->set_data_type(logical_blob_desc.data_type()); logical_blob_desc.shape().ToProto(op_desc->mutable_shape()); op_desc->set_num_ranks(parallel_desc.parallel_num()); if (op_type == OpType::kOpTypeBroadcast || op_type == OpType::kOpTypeReduce) { CHECK_GE(root, 0); CHECK_LT(root, parallel_desc.parallel_num()); op_desc->set_root(root); } else { CHECK_EQ(root, -1); } op_desc->set_device_type(device_type); rank_desc->set_rank(parallel_id); const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); const int64_t thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId( machine_id, device_type, device_index, *CHECK_JUST(DeviceTag4DeviceType(device_type)))); node->Init(machine_id, thrd_id, lbi, op_conf); } int64_t FindRootParallelId(const ParallelDesc& multi_device, const ParallelDesc& sole_device) { CHECK_EQ(sole_device.parallel_num(), 1); const int64_t root_machine_id = CHECK_JUST(sole_device.MachineId4ParallelId(0)); const int64_t root_device_id = CHECK_JUST(sole_device.DeviceId4ParallelId(0)); int64_t root_parallel_id = -1; FOR_RANGE(int64_t, i, 0, multi_device.parallel_num()) { if (CHECK_JUST(multi_device.MachineId4ParallelId(i)) == root_machine_id && CHECK_JUST(multi_device.DeviceId4ParallelId(i)) == root_device_id) { root_parallel_id = i; break; } } return root_parallel_id; } } // namespace bool IsSourceTimeShape(const Shape& shape) { return shape.elem_cnt() == 1; } Maybe CclAllReduceSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const 
BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (out_parallel_desc.Equals(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1 && SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel)) { const std::string op_name = "System-Boxing-CclBoxingAllReduce-" + NewUniqueId(); FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { TaskNode* in_node = sorted_in_tasks.at(i); // NOLINT auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllReduce, device_type_, -1); ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi); sorted_out_tasks->emplace_back(collective_node); } return TRY(BuildSubTskGphBuilderStatus("CclBoxingAllReduceSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } Maybe CclReduceScatterSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (out_parallel_desc.Equals(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1 && logical_blob_desc.shape().NumAxes() > 0 && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0 && SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel) && out_sbp_parallel.split_parallel().axis() == 0) { const std::string op_name = "System-Boxing-CclBoxingReduceScatter-" + NewUniqueId(); FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { TaskNode* in_node = sorted_in_tasks.at(i); // NOLINT auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeReduceScatter, device_type_, -1); ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi); sorted_out_tasks->emplace_back(collective_node); } return TRY(BuildSubTskGphBuilderStatus("CclBoxingReduceScatterSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } Maybe CclP2SNoncontinuousSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { const Shape& shape = logical_blob_desc.shape(); const int64_t out_split_axis = out_sbp_parallel.split_parallel().axis(); if (out_parallel_desc.Equals(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1 && SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel) && shape.NumAxes() > out_split_axis && shape.At(out_split_axis) % out_parallel_desc.parallel_num() == 0 && out_sbp_parallel.split_parallel().axis() != 0) { const 
std::string op_name = "System-Boxing-CclBoxingP2SNoncontinuous-" + NewUniqueId(); FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i)); const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i)); const int64_t thrd_id = EncodeStreamIdToInt64( GenerateComputeTaskStreamId(machine_id, device_type_, device_index)); TaskNode* in_node = sorted_in_tasks.at(i); // NOLINT CollectiveBoxingPackTaskNode* pack_node = ctx->task_graph()->NewNode(); pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); ctx->task_graph()->ConnectWithLbi(in_node, pack_node, lbi); auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode( collective_node, in_parallel_desc, i, op_name, lbi, BlobDesc({logical_blob_desc.shape().elem_cnt()}, logical_blob_desc.data_type(), logical_blob_desc.memory_format()), OpType::kOpTypeReduceScatter, device_type_, -1); ctx->task_graph()->ConnectWithLbi(pack_node, collective_node, lbi); CollectiveBoxingUnpackTaskNode* unpack_node = ctx->task_graph()->NewNode(); unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); ctx->task_graph()->ConnectWithLbi(collective_node, unpack_node, lbi); sorted_out_tasks->emplace_back(unpack_node); } return TRY(BuildSubTskGphBuilderStatus("CclBoxingP2SNoncontinuousSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } Maybe CclAllGatherSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (out_parallel_desc.EqualsIgnoringDeviceType(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && SubTskGphBuilderUtil::IsDeviceTypeCPUOr(in_parallel_desc, device_type_) && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1 && logical_blob_desc.shape().NumAxes() > 0 && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0 && SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel) && in_sbp_parallel.split_parallel().axis() == 0) { const std::string op_name = "System-Boxing-CclBoxingAllGather-" + NewUniqueId(); FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { TaskNode* in_node = sorted_in_tasks.at(i); // NOLINT TaskNode* in_node_proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i); auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllGather, device_type_, -1); ctx->task_graph()->ConnectWithLbi(in_node_proxy, collective_node, lbi); sorted_out_tasks->emplace_back(collective_node); } return TRY(BuildSubTskGphBuilderStatus("CclBoxingAllGatherSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } Maybe CclS2BNoncontinuousSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, 
const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { const Shape& shape = logical_blob_desc.shape(); const int64_t in_split_axis = in_sbp_parallel.split_parallel().axis(); if (out_parallel_desc.EqualsIgnoringDeviceType(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && SubTskGphBuilderUtil::IsDeviceTypeCPUOr(in_parallel_desc, device_type_) && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1 && SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel) && shape.NumAxes() > in_split_axis && in_split_axis > 0 && shape.At(in_split_axis) % out_parallel_desc.parallel_num() == 0) { const std::string op_name = "System-Boxing-CclBoxingS2BNoncontinuous-" + NewUniqueId(); FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { const int64_t machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(i)); const int64_t device_index = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(i)); const int64_t thrd_id = EncodeStreamIdToInt64( GenerateComputeTaskStreamId(machine_id, device_type_, device_index)); TaskNode* in_node = sorted_in_tasks.at(i); // NOLINT TaskNode* in_node_proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i); CollectiveBoxingPackTaskNode* pack_node = ctx->task_graph()->NewNode(); pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); ctx->task_graph()->ConnectWithLbi(in_node_proxy, pack_node, lbi); auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode( collective_node, out_parallel_desc, i, op_name, lbi, BlobDesc({logical_blob_desc.shape().elem_cnt()}, logical_blob_desc.data_type(), logical_blob_desc.memory_format()), OpType::kOpTypeAllGather, device_type_, -1); ctx->task_graph()->ConnectWithLbi(pack_node, collective_node, lbi); CollectiveBoxingUnpackTaskNode* unpack_node = ctx->task_graph()->NewNode(); unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); ctx->task_graph()->ConnectWithLbi(collective_node, unpack_node, lbi); sorted_out_tasks->emplace_back(unpack_node); } return TRY(BuildSubTskGphBuilderStatus("CclBoxingS2BNoncontinuousSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } Maybe CclReduceSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (in_parallel_desc.parallel_num() > 1 && out_parallel_desc.parallel_num() == 1 && in_parallel_desc.device_type() == device_type_ && out_parallel_desc.device_type() == device_type_ && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && in_sbp_parallel.has_partial_sum_parallel()) { const int64_t root_parallel_id = FindRootParallelId(in_parallel_desc, out_parallel_desc); if (root_parallel_id == -1) { return Error::BoxingNotSupportedError(); } const std::string op_name = "System-Boxing-CclBoxingReduce-" + NewUniqueId(); sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num()); FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { TaskNode* in_node = sorted_in_tasks.at(i); 
// NOLINT auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeReduce, device_type_, root_parallel_id); ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi); if (i == root_parallel_id) { sorted_out_tasks->emplace_back(collective_node); } else { sorted_ctrl_tasks->at(0).emplace_back(collective_node); // NOLINT } } return TRY(BuildSubTskGphBuilderStatus("CclBoxingReduceSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } Maybe CclScatterThenAllGatherSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() > 1 && in_parallel_desc.device_type() == DeviceType::kCPU && out_parallel_desc.device_type() == device_type_ && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && logical_blob_desc.shape().elem_cnt() >= 1024 && out_sbp_parallel.has_broadcast_parallel() // a potential optimization: flat the blob and then relax this requirement && logical_blob_desc.shape().NumAxes() > 0 && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0) { const TensorSliceView in_slice = GetBroadcastTensorSliceView(logical_blob_desc); SbpParallel split_sbp_parallel; split_sbp_parallel.mutable_split_parallel()->set_axis(0); std::vector out_slices = GetTensorSliceView(out_parallel_desc.parallel_num(), split_sbp_parallel, logical_blob_desc); const std::string op_name = "System-Boxing-CclBoxingAllGather-" + NewUniqueId(); FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { const TensorSliceView& out_slice = out_slices.at(out_id); // NOLINT const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( in_parallel_desc, out_parallel_desc, out_id); TaskNode* in_node = sorted_in_tasks.at(nearest_in_parallel_id); // NOLINT SliceBoxingTaskNode* slice_node = ctx->task_graph()->NewNode(); // slice on cpu const auto in_machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(0)); int64_t thrd_id = EncodeStreamIdToInt64(GenerateComputeTaskStreamId(in_machine_id, DeviceType::kCPU, 0)); slice_node->Init(lbi, out_slice, kSliceBoxingTaskModeCopy, in_machine_id, thrd_id); slice_node->ConnectToSrcNodeWithSlice(in_node, ctx->task_graph()->NewEdge(), in_slice); // copy to dst gpu TaskNode* slice_node_proxy = ctx->task_graph()->GetProxyNode(slice_node, lbi, out_parallel_desc, out_id); // allgather auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode(collective_node, out_parallel_desc, out_id, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllGather, device_type_, -1); ctx->task_graph()->ConnectWithLbi(slice_node_proxy, collective_node, lbi); sorted_out_tasks->emplace_back(collective_node); } return TRY(BuildSubTskGphBuilderStatus("BoxingCclScatterThenAllGatherSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } Maybe CclBroadcastSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const 
LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() > 1 && (in_parallel_desc.device_type() == device_type_ || (in_parallel_desc.device_type() == DeviceType::kCPU && logical_blob_desc.shape().elem_cnt() >= 1024)) && out_parallel_desc.device_type() == device_type_ && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && out_sbp_parallel.has_broadcast_parallel()) { TaskNode* gpu_in_node = nullptr; int64_t root_parallel_id = -1; if (in_parallel_desc.device_type() == DeviceType::kCPU) { auto* cpu_in_node = sorted_in_tasks.front(); root_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(out_parallel_desc, in_parallel_desc, 0); gpu_in_node = ctx->task_graph()->GetProxyNode(cpu_in_node, lbi, out_parallel_desc, root_parallel_id); } else if (in_parallel_desc.device_type() == device_type_) { root_parallel_id = FindRootParallelId(out_parallel_desc, in_parallel_desc); gpu_in_node = sorted_in_tasks.front(); } else { return Error::BoxingNotSupportedError(); } if (root_parallel_id == -1) { return Error::BoxingNotSupportedError(); } const std::string op_name = "System-Boxing-CclBoxingBroadcast-" + NewUniqueId(); FOR_RANGE(int64_t, i, 0, out_parallel_desc.parallel_num()) { auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeBroadcast, device_type_, root_parallel_id); if (i == root_parallel_id) { ctx->task_graph()->ConnectWithLbi(gpu_in_node, collective_node, lbi); } else { std::string regst_desc_name; gpu_in_node->BuildCtrlRegstDesc(collective_node, ®st_desc_name); TaskEdge* edge = ctx->task_graph()->NewEdge(); Connect(gpu_in_node, edge, collective_node); gpu_in_node->BindEdgeWithProducedRegst(edge, regst_desc_name); } sorted_out_tasks->emplace_back(collective_node); } return TRY(BuildSubTskGphBuilderStatus("CclBoxingBroadcastSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } Maybe CclAll2AllSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { const Shape& shape = logical_blob_desc.shape(); const int64_t in_split_axis = in_sbp_parallel.split_parallel().axis(); const int64_t out_split_axis = out_sbp_parallel.split_parallel().axis(); if (out_parallel_desc.EqualsIgnoringDeviceType(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && in_parallel_desc.device_type() == device_type_ && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1 && shape.NumAxes() > std::max(in_split_axis, out_split_axis) && shape.At(in_split_axis) % in_parallel_desc.parallel_num() == 0 && shape.At(out_split_axis) % out_parallel_desc.parallel_num() == 0 && in_sbp_parallel.split_parallel().axis() != out_sbp_parallel.split_parallel().axis() && SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel)) { const std::string op_name = "System-Boxing-CclBoxingAll2All-" + NewUniqueId(); FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { const int64_t machine_id = 
CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i)); const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i)); const int64_t thrd_id = EncodeStreamIdToInt64( GenerateComputeTaskStreamId(machine_id, device_type_, device_index)); TaskNode* in_node = sorted_in_tasks.at(i); // NOLINT CollectiveBoxingPackTaskNode* pack_node = ctx->task_graph()->NewNode(); pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); ctx->task_graph()->ConnectWithLbi(in_node, pack_node, lbi); auto* collective_node = ctx->task_graph()->NewNode(); CclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAll2All, device_type_, -1); ctx->task_graph()->ConnectWithLbi(pack_node, collective_node, lbi); CollectiveBoxingUnpackTaskNode* unpack_node = ctx->task_graph()->NewNode(); unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); ctx->task_graph()->ConnectWithLbi(collective_node, unpack_node, lbi); sorted_out_tasks->emplace_back(unpack_node); } return TRY(BuildSubTskGphBuilderStatus("CclBoxingAll2AllSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/ccl_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
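
// --------------------------------------------------------------------------------------
// Editorial summary, not part of the OneFlow sources: which SBP transition each CCL builder
// defined above accepts, as encoded in the `if` guards of its Build() method. All of them
// additionally require a static (non-dynamic) blob shape and more than one device on the
// multi-device side.
//
//   CclAllReduceSubTskGphBuilder             P -> B on the same placement (all-reduce)
//   CclReduceScatterSubTskGphBuilder         P -> S(0), dim 0 divisible by parallel_num
//   CclP2SNoncontinuousSubTskGphBuilder      P -> S(axis != 0): pack -> reduce-scatter -> unpack
//   CclAllGatherSubTskGphBuilder             S(0) -> B, dim 0 divisible by parallel_num
//   CclS2BNoncontinuousSubTskGphBuilder      S(axis > 0) -> B: pack -> all-gather -> unpack
//   CclReduceSubTskGphBuilder                P on N devices -> one of those devices (rooted reduce)
//   CclScatterThenAllGatherSubTskGphBuilder  one CPU device -> B on N devices: slice, copy, all-gather
//   CclBroadcastSubTskGphBuilder             one device -> B on N devices (rooted broadcast)
//   CclAll2AllSubTskGphBuilder               S(i) -> S(j), i != j: pack -> all-to-all -> unpack
// --------------------------------------------------------------------------------------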
*/ #ifndef ONEFLOW_CORE_GRAPH_BOXING_CCL_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_CCL_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" namespace oneflow { bool IsSourceTimeShape(const Shape& shape); class CclAllReduceSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclAllReduceSubTskGphBuilder); CclAllReduceSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclAllReduceSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; class CclReduceScatterSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclReduceScatterSubTskGphBuilder); CclReduceScatterSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclReduceScatterSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; class CclP2SNoncontinuousSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclP2SNoncontinuousSubTskGphBuilder); CclP2SNoncontinuousSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclP2SNoncontinuousSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; class CclAllGatherSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclAllGatherSubTskGphBuilder); CclAllGatherSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclAllGatherSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; class CclS2BNoncontinuousSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclS2BNoncontinuousSubTskGphBuilder); CclS2BNoncontinuousSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclS2BNoncontinuousSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& 
in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; class CclReduceSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclReduceSubTskGphBuilder); CclReduceSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclReduceSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; class CclScatterThenAllGatherSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclScatterThenAllGatherSubTskGphBuilder); CclScatterThenAllGatherSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclScatterThenAllGatherSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; class CclBroadcastSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclBroadcastSubTskGphBuilder); CclBroadcastSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclBroadcastSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; class CclAll2AllSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CclAll2AllSubTskGphBuilder); CclAll2AllSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {} ~CclAll2AllSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: DeviceType device_type_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_CCL_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/chain_sub_task_graph_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"

namespace oneflow {

Maybe<SubTskGphBuilderStatus> ChainSubTskGphBuilder::Build(
    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
    std::vector<TaskNode*>* sorted_out_tasks,
    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,
    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,
    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,
    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {
  for (const auto& builder : builders_) {
    Maybe<SubTskGphBuilderStatus> boxing_builder_status = TRY(builder->Build(
        ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,
        out_parallel_desc, lbi, logical_blob_desc, in_sbp_parallel, out_sbp_parallel, time_shape));
    if (!boxing_builder_status.IsOk()
        && SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*boxing_builder_status.error())) {
      continue;
    } else {
      return boxing_builder_status;
    }
  }
  return Error::BoxingNotSupportedError();
}

}  // namespace oneflow


================================================
FILE: oneflow/core/graph/boxing/chain_sub_task_graph_builder.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
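
// --------------------------------------------------------------------------------------
// Illustrative sketch, not part of the OneFlow sources: composing 1-D builders into a
// ChainSubTskGphBuilder. The same pattern appears in Make1DSubTskGphBuilder() in
// hierarchical_sub_task_graph_builder_impl.cpp further below; ordering matters, because
// Build() returns the result of the first builder that does not report
// BoxingNotSupportedError.
std::shared_ptr<SubTskGphBuilder> MakeExampleChain() {
  std::vector<std::shared_ptr<SubTskGphBuilder>> builders;
  builders.emplace_back(new OneToOneSubTskGphBuilder());          // trivial 1-to-1 transfer
  builders.emplace_back(new B21SubTskGphBuilder());               // broadcast/N -> 1
  builders.emplace_back(new CollectiveBoxingSubTskGphBuilder());  // NCCL/CCL collectives
  builders.emplace_back(new SliceBoxingSubTskGphBuilder());       // generic slice copy/add
  return std::make_shared<ChainSubTskGphBuilder>(builders);
}
// --------------------------------------------------------------------------------------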
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_CHAIN_SUB_TASK_GRAPH_BUILDER_H_
#define ONEFLOW_CORE_GRAPH_BOXING_CHAIN_SUB_TASK_GRAPH_BUILDER_H_

#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"

namespace oneflow {

class ChainSubTskGphBuilder final : public SubTskGphBuilder {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ChainSubTskGphBuilder);
  explicit ChainSubTskGphBuilder(std::vector<std::shared_ptr<SubTskGphBuilder>> builders)
      : builders_(std::move(builders)) {}
  ~ChainSubTskGphBuilder() override = default;

  Maybe<SubTskGphBuilderStatus> Build(
      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
      std::vector<TaskNode*>* sorted_out_tasks,
      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,
      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,
      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,
      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;

 private:
  std::vector<std::shared_ptr<SubTskGphBuilder>> builders_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_GRAPH_BOXING_CHAIN_SUB_TASK_GRAPH_BUILDER_H_


================================================
FILE: oneflow/core/graph/boxing/collective_boxing.proto
================================================
syntax = "proto2";

package oneflow.boxing.collective;

import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/common/device_type.proto";

enum OpType {
  kOpTypeInvalid = 0;
  kOpTypeAllReduce = 1;
  kOpTypeReduceScatter = 2;
  kOpTypeAllGather = 3;
  kOpTypeReduce = 4;
  kOpTypeBroadcast = 5;
  kOpTypeAll2All = 6;
}

enum ReduceMethod {
  kReduceMethodInvalid = 0;
  kReduceMethodSum = 1;
}

message DeviceDesc {
  required int64 machine_id = 1;
  required DeviceType device_type = 2;
  required int64 device_id = 3;
}

message DeviceSet {
  repeated DeviceDesc device = 1;
}

message OpDesc {
  required string name = 1;
  required OpType op_type = 2;
  optional ReduceMethod reduce_method = 3;
  optional int64 root = 4;
  required DataType data_type = 5;
  required ShapeProto shape = 6;
  required int64 num_ranks = 7;
  required DeviceType device_type = 8;
}

message RequestDesc {
  required OpDesc op_desc = 1;
  required DeviceSet device_set = 2;
  required int64 order = 3;
  required int64 dependency_depth = 4;
}

message RequestSet {
  repeated RequestDesc request = 1;
}

message RankDesc {
  required OpDesc op_desc = 1;
  required int64 rank = 2;
}


================================================
FILE: oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
*/ #include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/ccl_sub_task_graph_builder.h" namespace oneflow { CollectiveBoxingSubTskGphBuilder::CollectiveBoxingSubTskGphBuilder() { const CollectiveBoxingConf collective_boxing_conf = Singleton::Get()->collective_boxing_conf(); std::vector> builders; builders.emplace_back(new CclAllReduceSubTskGphBuilder(DeviceType::kCUDA)); builders.emplace_back(new CclReduceScatterSubTskGphBuilder(DeviceType::kCUDA)); builders.emplace_back(new CclP2SNoncontinuousSubTskGphBuilder(DeviceType::kCUDA)); builders.emplace_back(new CclAllGatherSubTskGphBuilder(DeviceType::kCUDA)); builders.emplace_back(new CclS2BNoncontinuousSubTskGphBuilder(DeviceType::kCUDA)); builders.emplace_back(new CclReduceSubTskGphBuilder(DeviceType::kCUDA)); builders.emplace_back(new CclScatterThenAllGatherSubTskGphBuilder(DeviceType::kCUDA)); builders.emplace_back(new CclBroadcastSubTskGphBuilder(DeviceType::kCUDA)); if (collective_boxing_conf.nccl_enable_all_to_all()) { #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 builders.emplace_back(new CclAll2AllSubTskGphBuilder(DeviceType::kCUDA)); #elif defined(WITH_NPU) builders.emplace_back(new CclAll2AllSubTskGphBuilder(DeviceType::kNPU)); #elif defined(WITH_MLU) builders.emplace_back(new CclAll2AllSubTskGphBuilder(DeviceType::kMLU)); #else LOG(WARNING) << "nccl_enable_all_to_all is unavailable unless NCCL_VERSION > 2.7.0"; #endif } chain_builder_.reset(new ChainSubTskGphBuilder(builders)); } Maybe CollectiveBoxingSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (!GlobalJobDesc().Bool("__is_user_function__")) { return Error::BoxingNotSupportedError(); } if (!IsSourceTimeShape(time_shape)) { return Error::BoxingNotSupportedError(); } return chain_builder_->Build(ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_sbp_parallel, out_sbp_parallel, time_shape); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_BOXING_COLLECTIVE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_COLLECTIVE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" namespace oneflow { class CollectiveBoxingSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingSubTskGphBuilder); CollectiveBoxingSubTskGphBuilder(); ~CollectiveBoxingSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: std::unique_ptr chain_builder_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_COLLECTIVE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/collective_boxing_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/collective_boxing_util.h" namespace oneflow { namespace boxing { namespace collective { namespace { Shape GetSplitShape(const RankDesc& rank_desc) { Shape shape(rank_desc.op_desc().shape()); CHECK_GT(shape.NumAxes(), 0); CHECK(shape.At(0) % rank_desc.op_desc().num_ranks() == 0); shape.Set(0, shape.At(0) / rank_desc.op_desc().num_ranks()); return shape; } Shape GetFlattenSplitShape(const RankDesc& rank_desc) { Shape shape(rank_desc.op_desc().shape()); CHECK_GT(shape.NumAxes(), 0); CHECK(shape.elem_cnt() % rank_desc.op_desc().num_ranks() == 0); Shape return_shape({shape.elem_cnt() / rank_desc.op_desc().num_ranks()}); return return_shape; } } // namespace bool GenericOpHasInput(const RankDesc& rank_desc) { const OpType op_type = rank_desc.op_desc().op_type(); if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeAllGather || op_type == OpType::kOpTypeReduceScatter || op_type == OpType::kOpTypeReduce || op_type == OpType::kOpTypeAll2All) { return true; } else if (op_type == OpType::kOpTypeBroadcast) { CHECK(rank_desc.op_desc().has_root()); return rank_desc.rank() == rank_desc.op_desc().root(); } else { UNIMPLEMENTED(); return false; } } bool GenericOpHasOutput(const RankDesc& rank_desc) { const OpType op_type = rank_desc.op_desc().op_type(); if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeAllGather || op_type == OpType::kOpTypeReduceScatter || op_type == OpType::kOpTypeBroadcast || op_type == OpType::kOpTypeAll2All) { return true; } else if (op_type == OpType::kOpTypeReduce) { CHECK(rank_desc.op_desc().has_root()); return rank_desc.rank() == rank_desc.op_desc().root(); } else { UNIMPLEMENTED(); return false; } } Shape GenericOpGetInputShape(const RankDesc& rank_desc) { CHECK(GenericOpHasInput(rank_desc)); 
const OpType op_type = rank_desc.op_desc().op_type(); if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeReduceScatter || op_type == OpType::kOpTypeReduce || op_type == OpType::kOpTypeBroadcast) { return Shape(rank_desc.op_desc().shape()); } else if (op_type == OpType::kOpTypeAllGather) { return GetSplitShape(rank_desc); } else if (op_type == OpType::kOpTypeAll2All) { return GetFlattenSplitShape(rank_desc); } else { UNIMPLEMENTED(); return Shape(); } } Shape GenericOpGetOutputShape(const RankDesc& rank_desc) { CHECK(GenericOpHasOutput(rank_desc)); const OpType op_type = rank_desc.op_desc().op_type(); if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeAllGather || op_type == OpType::kOpTypeReduce || op_type == OpType::kOpTypeBroadcast) { return Shape(rank_desc.op_desc().shape()); } else if (op_type == OpType::kOpTypeReduceScatter) { return GetSplitShape(rank_desc); } else if (op_type == OpType::kOpTypeAll2All) { return GetFlattenSplitShape(rank_desc); } else { UNIMPLEMENTED(); return Shape(); } } } // namespace collective } // namespace boxing } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/collective_boxing_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
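
// --------------------------------------------------------------------------------------
// Illustrative sketch, not part of the OneFlow sources: filling a RankDesc for one rank of
// a hypothetical 4-rank all-gather and querying its per-rank shapes with the helpers
// defined above. The op name, shape and rank count are made-up example values.
void RankDescShapeExample() {
  using namespace oneflow;
  using namespace oneflow::boxing::collective;
  RankDesc rank_desc;
  OpDesc* op_desc = rank_desc.mutable_op_desc();
  op_desc->set_name("example-all-gather");
  op_desc->set_op_type(OpType::kOpTypeAllGather);
  op_desc->set_data_type(DataType::kFloat);
  Shape({16, 1024}).ToProto(op_desc->mutable_shape());  // logical blob shape
  op_desc->set_num_ranks(4);
  op_desc->set_device_type(DeviceType::kCUDA);
  rank_desc.set_rank(1);
  // Each all-gather rank feeds a dim-0 slice and produces the full logical blob.
  const Shape in_shape = GenericOpGetInputShape(rank_desc);    // (4, 1024)
  const Shape out_shape = GenericOpGetOutputShape(rank_desc);  // (16, 1024)
  (void)in_shape;
  (void)out_shape;
}
// --------------------------------------------------------------------------------------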
*/ #ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_UTIL_H_ #define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_UTIL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/graph/boxing/collective_boxing.pb.h" #include "oneflow/core/common/shape.h" namespace oneflow { namespace boxing { namespace collective { inline bool operator==(const OpDesc& lhs, const OpDesc& rhs) { return PbMd::Equals(lhs, rhs); } inline bool operator==(const DeviceDesc& lhs, const DeviceDesc& rhs) { return PbMd::Equals(lhs, rhs); } inline bool operator==(const DeviceSet& lhs, const DeviceSet& rhs) { return PbMd::Equals(lhs, rhs); } inline bool operator!=(const DeviceSet& lhs, const DeviceSet& rhs) { return !(lhs == rhs); } bool GenericOpHasInput(const RankDesc& rank_desc); bool GenericOpHasOutput(const RankDesc& rank_desc); Shape GenericOpGetInputShape(const RankDesc& rank_desc); Shape GenericOpGetOutputShape(const RankDesc& rank_desc); } // namespace collective } // namespace boxing } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::boxing::collective::DeviceDesc& device_desc) const { size_t hash = std::hash()(device_desc.machine_id()); oneflow::HashCombine(&hash, std::hash()(device_desc.device_type())); oneflow::HashCombine(&hash, std::hash()(device_desc.device_id())); return hash; } }; template<> struct hash { size_t operator()(const oneflow::boxing::collective::DeviceSet& device_set) const { size_t hash = 0; for (const auto& device : device_set.device()) { oneflow::AddHash(&hash, device); } return hash; } }; } // namespace std #endif // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_UTIL_H_ ================================================ FILE: oneflow/core/graph/boxing/fallback_to_cpu_slice_boxing_sub_task_graph_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/boxing/fallback_to_cpu_slice_boxing_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" namespace oneflow { Maybe FallbackToCpuSliceBoxingSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { std::vector status; std::vector cpu_in_tasks; std::vector cpu_out_tasks; std::vector> cpu_ctrl_tasks; cpu_out_tasks.reserve(out_parallel_desc.parallel_num()); FOR_RANGE(int64_t, in_id, 0, in_parallel_desc.parallel_num()) { TaskNode* in_node = sorted_in_tasks.at(in_id); TaskNode* proxy_on_src_host = ctx->task_graph()->GetProxyNode( in_node, lbi, GetNodeCPUMemZoneId(in_node->MemZoneId121().rank())); cpu_in_tasks.push_back(proxy_on_src_host); } status.emplace_back("MoveToCpu", "-"); ParallelConf cpu_in_parallel_conf = in_parallel_desc.parallel_conf(); cpu_in_parallel_conf.set_device_tag("cpu"); ParallelConf cpu_out_parallel_conf = out_parallel_desc.parallel_conf(); cpu_out_parallel_conf.set_device_tag("cpu"); Maybe boxing_builder_status = TRY(builder_->Build(ctx, cpu_in_tasks, &cpu_out_tasks, &cpu_ctrl_tasks, ParallelDesc(cpu_in_parallel_conf), ParallelDesc(cpu_out_parallel_conf), lbi, logical_blob_desc, in_sbp_parallel, out_sbp_parallel, time_shape)); if (!boxing_builder_status.IsOk() && SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*boxing_builder_status.error())) { return Error::BoxingNotSupportedError(); } status.push_back(*JUST(boxing_builder_status)); FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { TaskNode* out_node = ctx->task_graph()->GetProxyNode(cpu_out_tasks.at(out_id), lbi, out_parallel_desc, out_id); sorted_out_tasks->push_back(out_node); } status.emplace_back("MoveBackToDevice", "-"); if (!cpu_ctrl_tasks.empty()) { CHECK_EQ(cpu_ctrl_tasks.size(), sorted_out_tasks->size()); FOR_RANGE(size_t, i, 0, sorted_out_tasks->size()) { for (TaskNode* ctrl_node : cpu_ctrl_tasks.at(i)) { std::string regst_desc_name; ctrl_node->BuildCtrlRegstDesc(sorted_out_tasks->at(i), ®st_desc_name); TaskEdge* edge = ctx->task_graph()->NewEdge(); Connect(ctrl_node, edge, sorted_out_tasks->at(i)); ctrl_node->BindEdgeWithProducedRegst(edge, regst_desc_name); } } } return TRY(MakeComposedSubTskGphBuilderStatus(status)); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/fallback_to_cpu_slice_boxing_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_BOXING_FALLBACK_TO_CPU_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_FALLBACK_TO_CPU_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h" namespace oneflow { class FallbackToCpuSliceBoxingSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(FallbackToCpuSliceBoxingSubTskGphBuilder); FallbackToCpuSliceBoxingSubTskGphBuilder() { builder_.reset(new SliceBoxingSubTskGphBuilder()); } ~FallbackToCpuSliceBoxingSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: std::unique_ptr builder_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_FALLBACK_TO_CPU_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h" namespace oneflow { class HierarchicalSubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(HierarchicalSubTskGphBuilder); HierarchicalSubTskGphBuilder() = default; virtual ~HierarchicalSubTskGphBuilder() = default; virtual Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& time_shape) const = 0; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
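
// --------------------------------------------------------------------------------------
// Editorial sketch, not part of the OneFlow sources: the routing decision that the
// hierarchical builders below implement, written out as a standalone function so the three
// branches are easier to see. The real DispatchHierarchicalSubTskGphBuilder also reduces
// trivial hierarchy dimensions before dispatching, which is omitted here.
enum class HierarchicalRoute { kFlat1D, kSame2DHierarchy, kNdNcclSendRecv };

HierarchicalRoute Route(const ParallelDesc& in_parallel_desc,
                        const ParallelDesc& out_parallel_desc) {
  const Shape& in_hierarchy = *in_parallel_desc.hierarchy();
  const Shape& out_hierarchy = *out_parallel_desc.hierarchy();
  if (in_hierarchy.NumAxes() == 1 && out_hierarchy.NumAxes() == 1) {
    // both placements are flat: run the ordinary 1-D chain (FlatSubTskGphBuilder)
    return HierarchicalRoute::kFlat1D;
  }
  if (in_hierarchy.NumAxes() == 2 && in_hierarchy == out_hierarchy) {
    // same 2-D hierarchy: intra-group or dim0-mismatched builders (Same2DHierarchySubTskGphBuilder)
    return HierarchicalRoute::kSame2DHierarchy;
  }
  // anything else: generic nd nccl send/recv boxing
  return HierarchicalRoute::kNdNcclSendRecv;
}
// --------------------------------------------------------------------------------------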
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h" #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_util.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/fallback_to_cpu_slice_boxing_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/graph/task_stream_id.h" #include "oneflow/core/job/job_desc.h" namespace oneflow { namespace { std::shared_ptr Make1DSubTskGphBuilder() { std::vector> builders; builders.emplace_back(new OneToOneSubTskGphBuilder()); builders.emplace_back(new B21SubTskGphBuilder()); if (!Singleton::Get()->nccl_use_compute_stream()) { builders.emplace_back(new CollectiveBoxingSubTskGphBuilder()); } builders.emplace_back(new SliceBoxingSubTskGphBuilder()); builders.emplace_back(new FallbackToCpuSliceBoxingSubTskGphBuilder()); builders.emplace_back(new NaiveB2BSubTskGphBuilder()); builders.emplace_back(new NaiveB2PSubTskGphBuilder()); return std::make_shared(builders); } void MergeParallelConf(const ParallelDesc& parallel_desc_0, const ParallelDesc& parallel_desc_1, ParallelConf* parallel_conf) { CHECK_EQ(parallel_desc_0.device_tag(), parallel_desc_1.device_tag()); std::set> machine_device_ids; for (int64_t machine_id : parallel_desc_0.sorted_machine_ids()) { for (int64_t device_id : parallel_desc_0.sorted_dev_phy_ids(machine_id)) { machine_device_ids.insert(std::make_pair(machine_id, device_id)); } } for (int64_t machine_id : parallel_desc_1.sorted_machine_ids()) { for (int64_t device_id : parallel_desc_1.sorted_dev_phy_ids(machine_id)) { machine_device_ids.insert(std::make_pair(machine_id, device_id)); } } parallel_conf->set_device_tag(parallel_desc_0.device_tag()); for (const auto& pair : machine_device_ids) { parallel_conf->add_device_name("@" + std::to_string(pair.first) + ":" + std::to_string(pair.second)); } } inline std::string NewUniqueIdGbc() { // The boxing task graph is built on rank 0 and broadcasted to all the ranks, // so the ids here are unique among all the ranks. 
static std::atomic counter(0); static std::atomic curr_job_id(0); if (curr_job_id != GlobalJobDesc().job_id()) { curr_job_id = GlobalJobDesc().job_id(); counter = 0; } return std::to_string(counter.fetch_add(1, std::memory_order_relaxed)); } class NDNcclSendRecvBoxingSubTskGphBuilder final : public HierarchicalSubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(NDNcclSendRecvBoxingSubTskGphBuilder); NDNcclSendRecvBoxingSubTskGphBuilder() {} ~NDNcclSendRecvBoxingSubTskGphBuilder() override = default; Maybe Build(SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& time_shape) const override { if (in_parallel_desc.device_type() == out_parallel_desc.device_type() && in_parallel_desc.device_type() != DeviceType::kCPU && !NdSbpHasPartialParallel(out_nd_sbp)) { #if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU) ParallelConf merged_parallel_conf; MergeParallelConf(in_parallel_desc.parallel_conf(), out_parallel_desc.parallel_conf(), &merged_parallel_conf); ParallelDesc merged_parallel_desc(merged_parallel_conf); TaskNode* first_in_node = sorted_in_tasks.front(); sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num()); std::string stream_name = "NCCL_SEND_RECV_BOXING" + NewUniqueIdGbc(); FOR_RANGE(int64_t, id, 0, merged_parallel_desc.parallel_num()) { NcclSendRecvBoxingTaskNode* node = ctx->task_graph()->NewNode(); const int64_t machine_id = JUST(merged_parallel_desc.MachineId4ParallelId(id)); int64_t device_index = JUST(merged_parallel_desc.DeviceId4ParallelId(id)); int64_t thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId( machine_id, merged_parallel_desc.device_type(), device_index, stream_name)); bool has_input = in_parallel_desc.Containing(machine_id, device_index); bool has_output = out_parallel_desc.Containing(machine_id, device_index); node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), logical_blob_desc.data_type(), in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, id, merged_parallel_desc, has_input, has_output, stream_name); if (has_input) { int64_t in_id = JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); ctx->task_graph()->ConnectWithLbi(sorted_in_tasks.at(in_id), node, lbi); } else { // TODO: find nearest std::string regst_desc_name; first_in_node->BuildCtrlRegstDesc(node, ®st_desc_name); TaskEdge* edge = ctx->task_graph()->NewEdge(); Connect(first_in_node, edge, node); first_in_node->BindEdgeWithProducedRegst(edge, regst_desc_name); } if (has_output) { sorted_out_tasks->push_back(node); } } return BuildSubTskGphBuilderStatus("NDNcclSendRecvBoxingSubTskGphBuilder", ""); #else return Error::BoxingNotSupportedError() << "No Device or low NCCL version"; #endif } else { return Error::BoxingNotSupportedError() << "Partial SBP in the consumer or not running on CUDA"; } } }; class Dim0NdSbpMismatchedSubTskGphBuilder final : public HierarchicalSubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(Dim0NdSbpMismatchedSubTskGphBuilder); Dim0NdSbpMismatchedSubTskGphBuilder() { inter_group_sub_tsk_gph_builder_.reset( new InterGroupSubTskGphBuilder(Make1DSubTskGphBuilder())); } ~Dim0NdSbpMismatchedSubTskGphBuilder() override = default; Maybe Build(SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* 
sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& time_shape) const override { if (in_parallel_desc.hierarchy()->NumAxes() == 2 && (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy()) && in_nd_sbp.sbp_parallel(0) != out_nd_sbp.sbp_parallel(0) && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1) && !(NdSbpAllSameSplitParallel(in_nd_sbp) || NdSbpAllSameSplitParallel(out_nd_sbp))) { return inter_group_sub_tsk_gph_builder_->Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); } else { return nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); } } private: std::unique_ptr inter_group_sub_tsk_gph_builder_; std::unique_ptr nd_nccl_send_recv_boxing_sub_tsk_gph_builder_; }; class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(Same2DHierarchySubTskGphBuilder); Same2DHierarchySubTskGphBuilder() { intra_group_sub_tsk_gph_builder_.reset( new IntraGroupSubTskGphBuilder(Make1DSubTskGphBuilder())); dim0_nd_sbp_mismatched_sub_tsk_gph_builder_.reset(new Dim0NdSbpMismatchedSubTskGphBuilder()); } ~Same2DHierarchySubTskGphBuilder() override = default; Maybe Build(SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& time_shape) const override { if (in_parallel_desc.hierarchy()->NumAxes() == 2 && (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy())) { if (in_nd_sbp.sbp_parallel(0) == out_nd_sbp.sbp_parallel(0)) { return intra_group_sub_tsk_gph_builder_->Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); } else { return dim0_nd_sbp_mismatched_sub_tsk_gph_builder_->Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); } } else { return Error::BoxingNotSupportedError(); } } private: std::unique_ptr intra_group_sub_tsk_gph_builder_; std::unique_ptr dim0_nd_sbp_mismatched_sub_tsk_gph_builder_; }; } // namespace struct DispatchHierarchicalSubTskGphBuilder::Impl { Impl(); std::unique_ptr flat_sub_tsk_gph_builder_; std::unique_ptr same_2d_hierarchy_sub_tsk_gph_builder_; std::unique_ptr nd_nccl_send_recv_boxing_sub_tsk_gph_builder_; }; DispatchHierarchicalSubTskGphBuilder::Impl::Impl() { flat_sub_tsk_gph_builder_.reset(new FlatSubTskGphBuilder(Make1DSubTskGphBuilder())); same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder()); nd_nccl_send_recv_boxing_sub_tsk_gph_builder_.reset(new NDNcclSendRecvBoxingSubTskGphBuilder()); } DispatchHierarchicalSubTskGphBuilder::DispatchHierarchicalSubTskGphBuilder() { impl_.reset(new Impl()); } DispatchHierarchicalSubTskGphBuilder::~DispatchHierarchicalSubTskGphBuilder() = default; Maybe 
DispatchHierarchicalSubTskGphBuilder::Build(
    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
    std::vector<TaskNode*>* sorted_out_tasks,
    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,
    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,
    const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,
    const Shape& time_shape) const {
  ParallelDesc reduced_in_parallel_desc = in_parallel_desc;
  ParallelDesc reduced_out_parallel_desc = out_parallel_desc;
  NdSbp reduced_in_nd_sbp;
  NdSbp reduced_out_nd_sbp;
  // The 1d-to-2d and 2d-to-1d cases are considered in this function.
  // If it produces a 1d sbp and a 2d sbp simultaneously, then the 2d sbp cannot be converted
  // to a 1d sbp and the 1d sbp cannot be expanded to a 2d sbp.
  InOutParallelDimReduce(in_parallel_desc, out_parallel_desc, in_nd_sbp, out_nd_sbp,
                         &reduced_in_parallel_desc, &reduced_out_parallel_desc,
                         &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape());
  const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy();
  const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy();
  if ((in_hierarchy->NumAxes() > 2 || out_hierarchy->NumAxes() > 2)
      && reduced_in_parallel_desc.device_type() == reduced_out_parallel_desc.device_type()
      && reduced_in_parallel_desc.device_type() != DeviceType::kCPU) {
    return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(
        ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
        reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,
        time_shape);
  }
  if (in_hierarchy->NumAxes() <= 2 && out_hierarchy->NumAxes() <= 2) {
    if (in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 1) {
      return impl_->flat_sub_tsk_gph_builder_->Build(
          ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
          reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp,
          reduced_out_nd_sbp, time_shape);
    } else if ((in_hierarchy->NumAxes() == 2) && (*in_hierarchy == *out_hierarchy)) {
      return impl_->same_2d_hierarchy_sub_tsk_gph_builder_->Build(
          ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
          reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp,
          reduced_out_nd_sbp, time_shape);
    } else if (reduced_in_parallel_desc.device_type() != DeviceType::kCPU
               && reduced_out_parallel_desc.device_type() != DeviceType::kCPU) {
      return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(
          ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
          reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp,
          reduced_out_nd_sbp, time_shape);
    } else {
      return Error::BoxingNotSupportedError();
    }
  }
  return Error::BoxingNotSupportedError();
}

}  // namespace oneflow
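// A condensed view of the dispatch order implemented by DispatchHierarchicalSubTskGphBuilder
// above (a sketch for orientation only; the authoritative conditions are the if/else chain in
// Build, after InOutParallelDimReduce has reduced the in/out hierarchies):
//
//   1. either hierarchy has rank > 2, same non-CPU device type on both sides
//        -> NDNcclSendRecvBoxingSubTskGphBuilder
//   2. both reduced hierarchies are 1-D
//        -> FlatSubTskGphBuilder wrapping the 1-D chain from Make1DSubTskGphBuilder()
//   3. both reduced hierarchies are 2-D and identical
//        -> Same2DHierarchySubTskGphBuilder (intra-group, inter-group, or nd send/recv)
//   4. remaining cases with both sides on non-CPU devices
//        -> NDNcclSendRecvBoxingSubTskGphBuilder
//   5. everything else
//        -> Error::BoxingNotSupportedError()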
================================================
FILE: oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_
#define ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_

#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder.h"

namespace oneflow {

class DispatchHierarchicalSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DispatchHierarchicalSubTskGphBuilder);
  DispatchHierarchicalSubTskGphBuilder();
  ~DispatchHierarchicalSubTskGphBuilder() override;

  Maybe<SubTskGphBuilderStatus> Build(
      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
      std::vector<TaskNode*>* sorted_out_tasks,
      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
      const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc,
      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp,
      const NdSbp& out_nd_sbp, const Shape& time_shape) const override;

 private:
  struct Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_


================================================
FILE: oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_util.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_util.h" namespace oneflow { Maybe FlatSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& time_shape) const { if (in_parallel_desc.hierarchy()->NumAxes() == 1 && out_parallel_desc.hierarchy()->NumAxes() == 1) { return sub_tsk_gph_builder_->Build(ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp.sbp_parallel(0), out_nd_sbp.sbp_parallel(0), time_shape); } else { return Error::BoxingNotSupportedError(); } } Maybe IntraGroupSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& time_shape) const { if (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy() && in_parallel_desc.hierarchy()->NumAxes() == 2 && in_nd_sbp.sbp_parallel(0) == out_nd_sbp.sbp_parallel(0) && in_nd_sbp.sbp_parallel(1) != out_nd_sbp.sbp_parallel(1)) { const auto& hierarchy = in_parallel_desc.hierarchy(); std::vector status; const int64_t num_groups = hierarchy->At(0); const int64_t group_size = hierarchy->At(1); status.reserve(num_groups); sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num()); sorted_out_tasks->resize(out_parallel_desc.parallel_num()); FOR_RANGE(int64_t, i, 0, num_groups) { std::vector in_tasks; std::vector out_tasks; std::vector> ctrl_tasks; ParallelConf in_parallel_conf; in_parallel_conf.set_device_tag(in_parallel_desc.device_tag()); in_parallel_conf.mutable_hierarchy()->add_dim(group_size); ParallelConf out_parallel_conf; out_parallel_conf.set_device_tag(out_parallel_desc.device_tag()); out_parallel_conf.mutable_hierarchy()->add_dim(group_size); FOR_RANGE(int64_t, j, 0, group_size) { const int64_t parallel_id = i * group_size + j; in_tasks.emplace_back(sorted_in_tasks.at(parallel_id)); // NOLINT in_parallel_conf.add_device_name( "@" + std::to_string(JUST(in_parallel_desc.MachineId4ParallelId(parallel_id))) + ":" + std::to_string(JUST(in_parallel_desc.DeviceId4ParallelId(parallel_id)))); out_parallel_conf.add_device_name( "@" + std::to_string(JUST(out_parallel_desc.MachineId4ParallelId(parallel_id))) + ":" + std::to_string(JUST(out_parallel_desc.DeviceId4ParallelId(parallel_id)))); } DimVector dim_vec = logical_blob_desc.shape().dim_vec(); if (in_nd_sbp.sbp_parallel(0).has_split_parallel()) { const int64_t axis = in_nd_sbp.sbp_parallel(0).split_parallel().axis(); dim_vec.at(axis) /= hierarchy->At(0); } BlobDesc new_blob_desc(Shape(dim_vec), logical_blob_desc.data_type(), logical_blob_desc.memory_format()); std::shared_ptr boxing_builder_status = JUST(sub_tsk_gph_builder_->Build( ctx, in_tasks, &out_tasks, &ctrl_tasks, ParallelDesc(in_parallel_conf), ParallelDesc(out_parallel_conf), lbi, new_blob_desc, in_nd_sbp.sbp_parallel(1), out_nd_sbp.sbp_parallel(1), time_shape)); status.emplace_back(*boxing_builder_status); CHECK_EQ_OR_RETURN(out_tasks.size(), group_size); // NOLINT FOR_RANGE(int64_t, j, 0, group_size) { const int64_t parallel_id = i * group_size + 
j; sorted_out_tasks->at(parallel_id) = out_tasks.at(j); // NOLINT if (!ctrl_tasks.empty()) { for (TaskNode* ctrl_node : ctrl_tasks.at(j)) { // NOLINT sorted_ctrl_tasks->at(parallel_id).emplace_back(ctrl_node); // NOLINT } } } } return MakeComposedSubTskGphBuilderStatus(status); } else { return Error::BoxingNotSupportedError(); } } Maybe InterGroupSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& time_shape) const { if (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy() && in_parallel_desc.hierarchy()->NumAxes() == 2 && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1) && in_nd_sbp.sbp_parallel(0) != out_nd_sbp.sbp_parallel(0) && !NdSbpAllSameSplitParallel(in_nd_sbp) && !NdSbpAllSameSplitParallel(out_nd_sbp)) { const auto& hierarchy = in_parallel_desc.hierarchy(); std::vector status; const int64_t num_groups = hierarchy->At(0); const int64_t group_size = hierarchy->At(1); status.reserve(group_size); sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num()); sorted_out_tasks->resize(out_parallel_desc.parallel_num()); FOR_RANGE(int64_t, i, 0, group_size) { std::vector in_tasks; std::vector out_tasks; std::vector> ctrl_tasks; ParallelConf in_parallel_conf; in_parallel_conf.set_device_tag(in_parallel_desc.device_tag()); in_parallel_conf.mutable_hierarchy()->add_dim(num_groups); ParallelConf out_parallel_conf; out_parallel_conf.set_device_tag(out_parallel_desc.device_tag()); out_parallel_conf.mutable_hierarchy()->add_dim(num_groups); FOR_RANGE(int64_t, j, 0, num_groups) { const int64_t parallel_id = j * group_size + i; in_tasks.emplace_back(sorted_in_tasks.at(parallel_id)); // NOLINT in_parallel_conf.add_device_name( "@" + std::to_string(JUST(in_parallel_desc.MachineId4ParallelId(parallel_id))) + ":" + std::to_string(JUST(in_parallel_desc.DeviceId4ParallelId(parallel_id)))); out_parallel_conf.add_device_name( "@" + std::to_string(JUST(out_parallel_desc.MachineId4ParallelId(parallel_id))) + ":" + std::to_string(JUST(out_parallel_desc.DeviceId4ParallelId(parallel_id)))); } DimVector dim_vec = logical_blob_desc.shape().dim_vec(); if (in_nd_sbp.sbp_parallel(1).has_split_parallel()) { const int64_t axis = in_nd_sbp.sbp_parallel(1).split_parallel().axis(); dim_vec.at(axis) /= hierarchy->At(1); } BlobDesc new_blob_desc(Shape(dim_vec), logical_blob_desc.data_type(), logical_blob_desc.memory_format()); std::shared_ptr boxing_builder_status = JUST(sub_tsk_gph_builder_->Build( ctx, in_tasks, &out_tasks, &ctrl_tasks, ParallelDesc(in_parallel_conf), ParallelDesc(out_parallel_conf), lbi, new_blob_desc, in_nd_sbp.sbp_parallel(0), out_nd_sbp.sbp_parallel(0), time_shape)); status.emplace_back(*boxing_builder_status); CHECK_EQ_OR_RETURN(out_tasks.size(), num_groups); // NOLINT FOR_RANGE(int64_t, j, 0, num_groups) { const int64_t parallel_id = j * group_size + i; sorted_out_tasks->at(parallel_id) = out_tasks.at(j); // NOLINT if (!ctrl_tasks.empty()) { for (TaskNode* ctrl_node : ctrl_tasks.at(j)) { // NOLINT sorted_ctrl_tasks->at(parallel_id).emplace_back(ctrl_node); // NOLINT } } } } return MakeComposedSubTskGphBuilderStatus(status); } else { return Error::BoxingNotSupportedError(); } } } // namespace oneflow ================================================ FILE: 
oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_UTIL_H_
#define ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_UTIL_H_

#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"

namespace oneflow {

class FlatSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FlatSubTskGphBuilder);
  FlatSubTskGphBuilder(const std::shared_ptr<SubTskGphBuilder>& sub_tsk_gph_builder)
      : sub_tsk_gph_builder_(sub_tsk_gph_builder) {}
  ~FlatSubTskGphBuilder() override = default;

  Maybe<SubTskGphBuilderStatus> Build(
      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
      std::vector<TaskNode*>* sorted_out_tasks,
      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
      const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc,
      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp,
      const NdSbp& out_nd_sbp, const Shape& time_shape) const override;

 private:
  std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;
};

class IntraGroupSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
 public:
  OF_DISALLOW_COPY_AND_MOVE(IntraGroupSubTskGphBuilder);
  IntraGroupSubTskGphBuilder(const std::shared_ptr<SubTskGphBuilder>& sub_tsk_gph_builder)
      : sub_tsk_gph_builder_(sub_tsk_gph_builder) {}
  ~IntraGroupSubTskGphBuilder() override = default;

  Maybe<SubTskGphBuilderStatus> Build(
      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
      std::vector<TaskNode*>* sorted_out_tasks,
      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
      const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc,
      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp,
      const NdSbp& out_nd_sbp, const Shape& time_shape) const override;

 private:
  std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;
};

class InterGroupSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
 public:
  OF_DISALLOW_COPY_AND_MOVE(InterGroupSubTskGphBuilder);
  InterGroupSubTskGphBuilder(const std::shared_ptr<SubTskGphBuilder>& sub_tsk_gph_builder)
      : sub_tsk_gph_builder_(sub_tsk_gph_builder) {}
  ~InterGroupSubTskGphBuilder() override = default;

  Maybe<SubTskGphBuilderStatus> Build(
      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
      std::vector<TaskNode*>* sorted_out_tasks,
      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
      const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc,
      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp,
      const NdSbp& out_nd_sbp, const Shape& time_shape) const override;

 private:
  std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_UTIL_H_


================================================
FILE: oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
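// A worked example of the grouping used by IntraGroupSubTskGphBuilder and
// InterGroupSubTskGphBuilder above (a sketch; the actual index math lives in
// hierarchical_sub_task_graph_builder_util.cpp). Assume a hierarchy of (2, 4), i.e.
// num_groups = 2 and group_size = 4, with parallel ids laid out row-major:
//
//   group 0: parallel ids 0 1 2 3      (parallel_id = i * group_size + j)
//   group 1: parallel ids 4 5 6 7
//
// Intra-group boxing (dim-0 sbp equal, dim-1 sbp differs) runs one 1-D sub graph per group,
// over ids {0,1,2,3} and {4,5,6,7}.
// Inter-group boxing (dim-1 sbp equal, dim-0 sbp differs) runs one 1-D sub graph per position
// inside a group, over ids {0,4}, {1,5}, {2,6} and {3,7} (parallel_id = j * group_size + i
// with the loop roles swapped).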
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" namespace oneflow { Maybe NaiveB2BSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel()) && (out_parallel_desc.parallel_num() == 1 || out_sbp_parallel.has_broadcast_parallel())) { FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( in_parallel_desc, out_parallel_desc, out_id); TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id); TaskNode* proxy = ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id); sorted_out_tasks->emplace_back(proxy); } return TRY(BuildSubTskGphBuilderStatus("NaiveB2BSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2B_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2B_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" namespace oneflow { class NaiveB2BSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(NaiveB2BSubTskGphBuilder); NaiveB2BSubTskGphBuilder() = default; ~NaiveB2BSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2B_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/graph/boxing_zeros_task_node.h" #include "oneflow/core/graph/task_stream_id.h" namespace oneflow { Maybe NaiveB2PSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel()) && out_parallel_desc.parallel_num() != 1 && out_sbp_parallel.has_partial_sum_parallel()) { HashMap out_id2nearest_in_id; int64_t nearest_out_node_idx = -1; int64_t nearest_out_node_distance = -1; FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( in_parallel_desc, out_parallel_desc, out_id); out_id2nearest_in_id.emplace(out_id, nearest_in_parallel_id); const int64_t distance = SubTskGphBuilderUtil::GetDistance( in_parallel_desc, nearest_in_parallel_id, out_parallel_desc, out_id); if (nearest_out_node_idx == -1 || distance < nearest_out_node_distance) { nearest_out_node_idx = out_id; nearest_out_node_distance = distance; } } FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { const int64_t nearest_in_id = out_id2nearest_in_id.at(out_id); TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_id); if (out_id == nearest_out_node_idx) { TaskNode* proxy = ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id); sorted_out_tasks->emplace_back(proxy); } else { int64_t 
out_machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(out_id)); int64_t out_dev_phy_id = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(out_id)); if (out_parallel_desc.device_type() == DeviceType::kCPU) { out_dev_phy_id = 0; } int64_t thrd_id = EncodeStreamIdToInt64(GenerateComputeTaskStreamId( out_machine_id, out_parallel_desc.device_type(), out_dev_phy_id)); auto* zeros_node = ctx->task_graph()->NewNode(); zeros_node->Init(out_machine_id, thrd_id, lbi, logical_blob_desc.shape(), logical_blob_desc.data_type(), time_shape); nearest_in_node->BuildCtrlRegstDesc(zeros_node); ctx->task_graph()->ConnectWithLbi(nearest_in_node, zeros_node, lbi); sorted_out_tasks->emplace_back(zeros_node); } } return TRY(BuildSubTskGphBuilderStatus("NaiveB2PSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" namespace oneflow { class NaiveB2PSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(NaiveB2PSubTskGphBuilder); NaiveB2PSubTskGphBuilder() = default; ~NaiveB2PSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" namespace oneflow { Maybe OneToOneSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if ((in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() == 1) || (in_parallel_desc.parallel_num() == out_parallel_desc.parallel_num() && in_sbp_parallel == out_sbp_parallel)) { for (int64_t i = 0; i < in_parallel_desc.parallel_num(); ++i) { TaskNode* in_node = sorted_in_tasks.at(i); TaskNode* proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i); sorted_out_tasks->emplace_back(proxy); } return TRY(BuildSubTskGphBuilderStatus("OneToOneSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_BOXING_ONE_TO_ONE_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_ONE_TO_ONE_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" namespace oneflow { class OneToOneSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(OneToOneSubTskGphBuilder); OneToOneSubTskGphBuilder() = default; ~OneToOneSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_ONE_TO_ONE_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h" #include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/graph/slice_boxing_task_node.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/graph/task_stream_id.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" namespace oneflow { namespace { bool IsCopyNdPrimitiveSupported(DeviceType device_type, int64_t ndims) { auto primitive = ep::primitive::NewPrimitive(device_type, ndims); return primitive.operator bool(); } } // namespace Maybe SliceBoxingSubTskGphBuilder::Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (!IsCopyNdPrimitiveSupported(in_parallel_desc.device_type(), logical_blob_desc.shape().NumAxes())) { return Error::BoxingNotSupportedError(); } if (!IsCopyNdPrimitiveSupported(out_parallel_desc.device_type(), logical_blob_desc.shape().NumAxes())) { return Error::BoxingNotSupportedError(); } if (SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)) { return Error::BoxingNotSupportedError(); } if (SubTskGphBuilderUtil::HasEmptySliceIfSplit(in_parallel_desc.parallel_num(), in_sbp_parallel, logical_blob_desc)) { return Error::BoxingNotSupportedError(); } if (SubTskGphBuilderUtil::HasEmptySliceIfSplit(out_parallel_desc.parallel_num(), out_sbp_parallel, logical_blob_desc)) { return Error::BoxingNotSupportedError(); } if (!(SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel) || SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel) || SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel) || SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel) || SubTskGphBuilderUtil::IsBoxingB2S(in_sbp_parallel, out_sbp_parallel))) { return Error::BoxingNotSupportedError(); } const auto NewEdge = [&ctx]() -> TaskEdge* { return ctx->task_graph()->NewEdge(); }; const auto CreateSliceBoxingNode = [&ctx, &lbi](const ParallelDesc& pd, const int64_t parallel_id, const TensorSliceView& slice, SliceBoxingTaskMode mode) -> SliceBoxingTaskNode* { SliceBoxingTaskNode* node = ctx->task_graph()->NewNode(); const int64_t machine_id = CHECK_JUST(pd.MachineId4ParallelId(parallel_id)); int64_t device_index = (pd.device_type() == DeviceType::kCPU) ? 
0 : CHECK_JUST(pd.DeviceId4ParallelId(parallel_id)); int64_t thrd_id = EncodeStreamIdToInt64( GenerateComputeTaskStreamId(machine_id, pd.device_type(), device_index)); node->Init(lbi, slice, mode, machine_id, thrd_id); return node; }; const auto GetSliceCopyNode = [&CreateSliceBoxingNode, &NewEdge]( TaskNode* in_node, const TensorSliceView& in_slice, const ParallelDesc& in_pd, const int64_t in_id, const TensorSliceView& intersection) -> TaskNode* { if (in_slice == intersection) { return in_node; } else { SliceBoxingTaskNode* slice_copy_node = CreateSliceBoxingNode(in_pd, in_id, intersection, kSliceBoxingTaskModeCopy); slice_copy_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice); return slice_copy_node; } }; const auto BuildSubTaskGphS2B = [&ctx, &CreateSliceBoxingNode, &NewEdge, &lbi]( const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp, const SbpParallel& out_sbp, const BlobDesc& blob_desc, const std::vector& in_nodes, std::vector* out_nodes) { CHECK(SubTskGphBuilderUtil::IsBoxingS2B(in_sbp, out_sbp)); const std::vector in_slices = GetTensorSliceView(in_pd.parallel_num(), in_sbp, blob_desc); const TensorSliceView& out_slice = GetBroadcastTensorSliceView(blob_desc); FOR_RANGE(int64_t, out_id, 0, out_pd.parallel_num()) { SliceBoxingTaskNode* out_node = CreateSliceBoxingNode(out_pd, out_id, out_slice, kSliceBoxingTaskModeCopy); FOR_RANGE(int64_t, in_id, 0, in_pd.parallel_num()) { const TensorSliceView& in_slice = in_slices.at(in_id); TaskNode* in_node = in_nodes.at(in_id); TaskNode* proxy_node = ctx->task_graph()->GetProxyNode( in_node, lbi, dynamic_cast(out_node)->MemZoneId121()); out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), in_slice); } out_nodes->emplace_back(out_node); } }; const auto BuildSubTaskGphS2S = [&ctx, &lbi, &CreateSliceBoxingNode, &GetSliceCopyNode, &NewEdge]( const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp, const SbpParallel& out_sbp, const BlobDesc& blob_desc, const std::vector& in_nodes, std::vector* out_nodes) { CHECK(SubTskGphBuilderUtil::IsBoxingS2S(in_sbp, out_sbp)); const std::vector in_slices = GetTensorSliceView(in_pd.parallel_num(), in_sbp, blob_desc); const std::vector out_slices = GetTensorSliceView(out_pd.parallel_num(), out_sbp, blob_desc); for (int64_t out_id = 0; out_id < out_pd.parallel_num(); ++out_id) { const TensorSliceView& out_slice = out_slices.at(out_id); SliceBoxingTaskNode* out_node = CreateSliceBoxingNode(out_pd, out_id, out_slice, kSliceBoxingTaskModeCopy); for (int64_t in_id = 0; in_id < in_pd.parallel_num(); ++in_id) { const TensorSliceView& in_slice = in_slices.at(in_id); const TensorSliceView& intersection = out_slice.Intersect(in_slice); if (intersection.IsEmpty()) { continue; } TaskNode* in_node = in_nodes.at(in_id); TaskNode* slice_copy_node = GetSliceCopyNode(in_node, in_slice, in_pd, in_id, intersection); TaskNode* proxy_node = ctx->task_graph()->GetProxyNode( slice_copy_node, lbi, dynamic_cast(out_node)->MemZoneId121()); out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), intersection); } out_nodes->emplace_back(out_node); } }; const auto BuildSubTaskGphP2S = [&ctx, &lbi, &CreateSliceBoxingNode, &GetSliceCopyNode, &NewEdge]( const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp, const SbpParallel& out_sbp, const BlobDesc& blob_desc, const std::vector& in_nodes, std::vector* out_nodes) { CHECK(SubTskGphBuilderUtil::IsBoxingP2S(in_sbp, out_sbp)); const TensorSliceView& in_slice = GetBroadcastTensorSliceView(blob_desc); 
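      // P2S in a nutshell (a small worked example, not part of the control flow): every
      // producer rank holds a full-shape partial-sum tensor, so in_slice here is the broadcast
      // (whole-tensor) view. Each consumer rank out_id builds an Add-mode SliceBoxingTaskNode
      // for its own out_slice and accumulates the intersection of that slice with every
      // producer's copy. E.g. for an (8, 4) blob, 2 partial producers and 4 consumers with
      // split(0): consumer 0 owns rows [0, 2) and sums rows [0, 2) received from producer 0
      // and producer 1.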
const std::vector out_slices = GetTensorSliceView(out_pd.parallel_num(), out_sbp, blob_desc); for (int64_t out_id = 0; out_id < out_pd.parallel_num(); ++out_id) { const TensorSliceView& out_slice = out_slices.at(out_id); SliceBoxingTaskNode* out_node = CreateSliceBoxingNode(out_pd, out_id, out_slice, kSliceBoxingTaskModeAdd); for (int64_t in_id = 0; in_id < in_pd.parallel_num(); ++in_id) { const TensorSliceView& intersection = out_slice.Intersect(in_slice); if (intersection.IsEmpty()) { continue; } TaskNode* in_node = in_nodes.at(in_id); TaskNode* slice_copy_node = GetSliceCopyNode(in_node, in_slice, in_pd, in_id, intersection); TaskNode* proxy_node = ctx->task_graph()->GetProxyNode( slice_copy_node, lbi, dynamic_cast(out_node)->MemZoneId121()); out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), intersection); } out_nodes->emplace_back(out_node); } }; const auto BuildSubTaskGphP2B = [&ctx, &lbi, &CreateSliceBoxingNode, &NewEdge]( const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp, const SbpParallel& out_sbp, const BlobDesc& blob_desc, const std::vector& in_nodes, std::vector* out_nodes) { CHECK(SubTskGphBuilderUtil::IsBoxingP2B(in_sbp, out_sbp)); const TensorSliceView& slice = GetBroadcastTensorSliceView(blob_desc); for (int64_t out_id = 0; out_id < out_pd.parallel_num(); ++out_id) { SliceBoxingTaskNode* out_node = CreateSliceBoxingNode(out_pd, out_id, slice, kSliceBoxingTaskModeAdd); for (int64_t in_id = 0; in_id < in_pd.parallel_num(); ++in_id) { TaskNode* in_node = in_nodes.at(in_id); TaskNode* proxy_node = ctx->task_graph()->GetProxyNode( in_node, lbi, dynamic_cast(out_node)->MemZoneId121()); out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), slice); } out_nodes->emplace_back(out_node); } }; const auto BuildSubTaskGphB2S = [&ctx, &lbi, &CreateSliceBoxingNode, &NewEdge]( const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp, const SbpParallel& out_sbp, const BlobDesc& blob_desc, const std::vector& in_nodes, std::vector* out_nodes) { CHECK(SubTskGphBuilderUtil::IsBoxingB2S(in_sbp, out_sbp)); const TensorSliceView& in_slice = GetBroadcastTensorSliceView(blob_desc); const std::vector out_slices = GetTensorSliceView(out_pd.parallel_num(), out_sbp, blob_desc); FOR_RANGE(int64_t, out_id, 0, out_pd.parallel_num()) { const TensorSliceView& out_slice = out_slices.at(out_id); const int64_t nearest_idx = SubTskGphBuilderUtil::FindNearestSrcParallelId(in_pd, out_pd, out_id); TaskNode* in_node = in_nodes.at(nearest_idx); SliceBoxingTaskNode* slice_node = CreateSliceBoxingNode(in_pd, nearest_idx, out_slice, kSliceBoxingTaskModeCopy); slice_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice); TaskNode* out_node = ctx->task_graph()->GetProxyNode(slice_node, lbi, out_pd, out_id); out_nodes->emplace_back(out_node); } }; std::string comment; if (SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel)) { BuildSubTaskGphS2B(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphS2B"; } else if (SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel)) { BuildSubTaskGphS2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphS2S"; } else if (SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel)) { BuildSubTaskGphP2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, 
logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphP2S"; } else if (SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel)) { if (logical_blob_desc.shape().elem_cnt() < out_parallel_desc.parallel_num()) { BuildSubTaskGphP2B(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphP2B"; } else { BlobDesc flat_blob_desc(logical_blob_desc.data_type(), logical_blob_desc.memory_format()); flat_blob_desc.set_shape(Shape({logical_blob_desc.shape().elem_cnt()})); std::vector middle_nodes; SbpParallel middle_sbp; middle_sbp.mutable_split_parallel()->set_axis(0); BuildSubTaskGphP2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, middle_sbp, flat_blob_desc, sorted_in_tasks, &middle_nodes); BuildSubTaskGphS2B(out_parallel_desc, out_parallel_desc, middle_sbp, out_sbp_parallel, flat_blob_desc, middle_nodes, sorted_out_tasks); comment = "BuildSubTaskGphP2S->BuildSubTaskGphS2B"; for (TaskNode* out_node : *sorted_out_tasks) { auto* slice_boxing_node = dynamic_cast(out_node); CHECK_NOTNULL(slice_boxing_node); slice_boxing_node->SetOutShape(logical_blob_desc.shape()); } } } else if (SubTskGphBuilderUtil::IsBoxingB2S(in_sbp_parallel, out_sbp_parallel)) { BuildSubTaskGphB2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphB2S"; } else { UNIMPLEMENTED(); } return TRY(BuildSubTskGphBuilderStatus("SliceBoxingSubTskGphBuilder", comment)); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_BOXING_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/graph/boxing/sub_task_graph_builder.h" namespace oneflow { class SliceBoxingSubTskGphBuilder final : public SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingSubTskGphBuilder); SliceBoxingSubTskGphBuilder() = default; ~SliceBoxingSubTskGphBuilder() override = default; Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/sub_task_graph_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_H_ #define ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h" namespace oneflow { class SubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(SubTskGphBuilder); SubTskGphBuilder() = default; virtual ~SubTskGphBuilder() = default; virtual Maybe Build( SubTskGphBuilderCtx* ctx, const std::vector& sorted_in_tasks, std::vector* sorted_out_tasks, std::vector>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, const SbpParallel& out_sbp_parallel, const Shape& time_shape) const = 0; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_H_ ================================================ FILE: oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h" namespace oneflow { SubTskGphBuilderCtx::SubTskGphBuilderCtx(TaskGraph* task_graph) : task_graph_(task_graph) {} TaskGraph* SubTskGphBuilderCtx::task_graph() { return task_graph_; } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/sub_task_graph_builder_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_CONTEXT_H_ #define ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_CONTEXT_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/graph/task_graph.h" namespace oneflow { class SubTskGphBuilderCtx final { public: OF_DISALLOW_COPY_AND_MOVE(SubTskGphBuilderCtx); explicit SubTskGphBuilderCtx(TaskGraph* task_graph); virtual ~SubTskGphBuilderCtx() = default; virtual TaskGraph* task_graph(); private: TaskGraph* task_graph_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_CONTEXT_H_ ================================================ FILE: oneflow/core/graph/boxing/sub_task_graph_builder_status_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h" namespace oneflow { Maybe BuildSubTskGphBuilderStatus(const std::string& builder_name, const std::string& comment) { SubTskGphBuilderStatus status(builder_name, comment); return status; } Maybe MakeComposedSubTskGphBuilderStatus( const std::vector& status_vec) { std::string builder_name = "ComposedBuilder:"; std::string comment = "ComposedComment:"; for (auto status : status_vec) { builder_name += " "; builder_name += status.builder_name(); comment += " "; if (status.comment().empty()) { comment += "None"; } else { comment += status.comment(); } } SubTskGphBuilderStatus composed_status(builder_name, comment); return composed_status; } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_ #define ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_ #include "oneflow/core/graph/compute_task_node.h" namespace oneflow { class SubTskGphBuilderStatus; Maybe BuildSubTskGphBuilderStatus(const std::string& builder_name, const std::string& comment); Maybe MakeComposedSubTskGphBuilderStatus( const std::vector& status); class SubTskGphBuilderStatus final { public: SubTskGphBuilderStatus(const std::string& builder_name, const std::string& comment) : builder_name_(builder_name), comment_(comment){}; ~SubTskGphBuilderStatus() = default; // Getters const std::string& builder_name() const { return builder_name_; } const std::string& comment() const { return comment_; } private: std::string builder_name_; std::string comment_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_ ================================================ FILE: oneflow/core/graph/boxing/sub_task_graph_builder_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { bool SubTskGphBuilderUtil::IsDeviceTypeCPUOr(const ParallelDesc& parallel_desc, DeviceType device_type) { return parallel_desc.device_type() == DeviceType::kCPU || parallel_desc.device_type() == device_type; } bool SubTskGphBuilderUtil::HasEmptySliceIfSplit(int64_t parallel_num, const SbpParallel& sbp_parallel, const BlobDesc& blob_desc) { if (sbp_parallel.has_split_parallel()) { return blob_desc.shape().At(sbp_parallel.split_parallel().axis()) < parallel_num; } else { return false; } } bool SubTskGphBuilderUtil::IsOnSameDevice(const TaskNode* lhs, const TaskNode* rhs) { return lhs->stream_id().device_id() == rhs->stream_id().device_id() && lhs->stream_id().device_id().device_type() != DeviceType::kCPU; } bool SubTskGphBuilderUtil::IsBoxingS2S(const SbpParallel& src, const SbpParallel& dst) { return src.has_split_parallel() && dst.has_split_parallel(); } bool SubTskGphBuilderUtil::IsBoxingS2B(const SbpParallel& src, const SbpParallel& dst) { return src.has_split_parallel() && dst.has_broadcast_parallel(); } bool SubTskGphBuilderUtil::IsBoxingP2S(const SbpParallel& src, const SbpParallel& dst) { return src.has_partial_sum_parallel() && dst.has_split_parallel(); } bool SubTskGphBuilderUtil::IsBoxingP2B(const SbpParallel& src, const SbpParallel& dst) { return src.has_partial_sum_parallel() && dst.has_broadcast_parallel(); } bool SubTskGphBuilderUtil::IsBoxingB2B(const SbpParallel& src, const SbpParallel& dst) { return src.has_broadcast_parallel() && dst.has_broadcast_parallel(); } bool SubTskGphBuilderUtil::IsBoxingB2S(const SbpParallel& src, const SbpParallel& dst) { return src.has_broadcast_parallel() && dst.has_split_parallel(); } bool SubTskGphBuilderUtil::BlobHasDynamicShape(const 
BlobDesc& blob_desc) { return blob_desc.is_dynamic(); } bool SubTskGphBuilderUtil::IsErrorBoxingNotSupported(const ErrorProto& error) { return error.has_boxing_not_supported_error(); } int64_t SubTskGphBuilderUtil::GetDistance( const int64_t src_machine_id, const int64_t src_dev_phy_id, const DeviceType src_device_type, const int64_t dst_machine_id, const int64_t dst_dev_phy_id, const DeviceType dst_device_type) { if (src_machine_id != dst_machine_id) { return kDistanceDiffMachine; } else if (src_device_type != dst_device_type) { return kDistanceSameMachine; } else if (src_device_type == DeviceType::kCPU) { return kDistanceSameDevice; } else { if (src_dev_phy_id == dst_dev_phy_id) { return kDistanceSameDevice; } else { return kDistanceSameMachine; } } } int64_t SubTskGphBuilderUtil::GetDistance(const ParallelDesc& src_parallel_desc, const int64_t src_parallel_id, const ParallelDesc& dst_parallel_desc, const int64_t dst_parallel_id) { const int64_t src_machine_id = CHECK_JUST(src_parallel_desc.MachineId4ParallelId(src_parallel_id)); const int64_t src_dev_phy_id = CHECK_JUST(src_parallel_desc.DeviceId4ParallelId(src_parallel_id)); const int64_t dst_machine_id = CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id)); const int64_t dst_dev_phy_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id)); return GetDistance(src_machine_id, src_dev_phy_id, src_parallel_desc.device_type(), dst_machine_id, dst_dev_phy_id, dst_parallel_desc.device_type()); } int64_t SubTskGphBuilderUtil::GetDistance(const TaskNode* src, const TaskNode* dst) { const auto GetDevPhyId = [](const TaskNode* node) -> int64_t { const DeviceId& device_id = node->stream_id().device_id(); if (device_id.device_type() == DeviceType::kCPU) { return 0; } else { return device_id.device_index(); } }; const DeviceType src_device_type = src->device_type(); const int64_t src_dev_phy_id = GetDevPhyId(src); const DeviceType dst_device_type = dst->device_type(); const int64_t dst_dev_phy_id = GetDevPhyId(dst); return GetDistance(src->machine_id(), src_dev_phy_id, src_device_type, dst->machine_id(), dst_dev_phy_id, dst_device_type); } int64_t SubTskGphBuilderUtil::FindNearestSrcParallelId(const ParallelDesc& from_parallel_desc, const ParallelDesc& to_parallel_desc, const int64_t to_parallel_id) { int64_t nearest_from_parallel_idx = -1; int64_t nearest_distance = SubTskGphBuilderUtil::kDistanceMax; for (int64_t i = 0; i < from_parallel_desc.parallel_num(); ++i) { const int64_t distance = SubTskGphBuilderUtil::GetDistance(from_parallel_desc, i, to_parallel_desc, to_parallel_id); if (distance < nearest_distance) { nearest_from_parallel_idx = i; nearest_distance = distance; } } CHECK_NE(nearest_from_parallel_idx, -1); return nearest_from_parallel_idx; } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing/sub_task_graph_builder_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_UTIL_H_ #define ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_UTIL_H_ #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/graph/task_node.h" namespace oneflow { struct SubTskGphBuilderUtil { static constexpr int64_t kDistanceSameDevice = 0; static constexpr int64_t kDistanceSameMachine = 1; static constexpr int64_t kDistanceDiffMachine = 2; static constexpr int64_t kDistanceMax = 3; static bool IsDeviceTypeCPUOr(const ParallelDesc& parallel_desc, DeviceType device_type); static bool HasEmptySliceIfSplit(int64_t parallel_num, const SbpParallel& sbp_parallel, const BlobDesc& blob_desc); static bool IsOnSameDevice(const TaskNode* lhs, const TaskNode* rhs); static bool IsBoxingS2S(const SbpParallel& src, const SbpParallel& dst); static bool IsBoxingS2B(const SbpParallel& src, const SbpParallel& dst); static bool IsBoxingP2S(const SbpParallel& src, const SbpParallel& dst); static bool IsBoxingP2B(const SbpParallel& src, const SbpParallel& dst); static bool IsBoxingB2B(const SbpParallel& src, const SbpParallel& dst); static bool IsBoxingB2S(const SbpParallel& src, const SbpParallel& dst); static bool BlobHasDynamicShape(const BlobDesc& blob_desc); static bool IsErrorBoxingNotSupported(const ErrorProto& error); static int64_t GetDistance(int64_t src_machine_id, int64_t src_dev_phy_id, DeviceType src_device_type, int64_t dst_machine_id, int64_t dst_dev_phy_id, DeviceType dst_device_type); static int64_t GetDistance(const ParallelDesc& src_parallel_desc, int64_t src_parallel_id, const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id); static int64_t GetDistance(const TaskNode* src, const TaskNode* dst); template static int64_t FindNearestNodeIndex(const std::vector from_nodes, const NodeType* to_node) { CHECK(!from_nodes.empty()); int64_t nearest_from_node_idx = -1; int64_t nearest_distance = SubTskGphBuilderUtil::kDistanceMax; for (int64_t i = 0; i < from_nodes.size(); ++i) { NodeType* from_node = from_nodes.at(i); int64_t distance = SubTskGphBuilderUtil::GetDistance(from_node, to_node); if (distance < nearest_distance) { nearest_from_node_idx = i; nearest_distance = distance; } } return nearest_from_node_idx; } template static NodeType* FindNearestNode(const std::vector from_nodes, const NodeType* to_node) { const int64_t idx = FindNearestNodeIndex(from_nodes, to_node); return from_nodes.at(idx); } static int64_t FindNearestSrcParallelId(const ParallelDesc& from_parallel_desc, const ParallelDesc& to_parallel_desc, int64_t to_parallel_id); }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_UTIL_H_ ================================================ FILE: oneflow/core/graph/boxing_identity_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/boxing_identity_task_node.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" namespace oneflow { void BoxingIdentityTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi) { set_machine_id(machine_id); set_thrd_id(thrd_id); set_lbi(lbi); } void BoxingIdentityTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", true, 1, 1); this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); } void BoxingIdentityTaskNode::ConsumeAllRegsts() { this->ForEachInDataEdge( [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); } void BoxingIdentityTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); OperatorConf op_conf; op_conf.set_name("System-Boxing-Identity-" + NewUniqueId()); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); *op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi(); std::shared_ptr sole_op = CHECK_JUST(ConstructOp(op_conf)); node->mut_op() = sole_op; node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); (node->*GetInferBlobDescsMethod())(nullptr); } void BoxingIdentityTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } Maybe BoxingIdentityTaskNode::InitTransportTaskFromProto( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK_OR_RETURN(transport_task_proto.has_boxing_identity_task()) << "not a serialized BoxingIdentityTaskNode. debug string: " << transport_task_proto.DebugString(); return Maybe::Ok(); } void BoxingIdentityTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); transport_task_proto->mutable_boxing_identity_task(); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing_identity_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

#ifndef ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_

#include "oneflow/core/graph/transport_task_node.h"

namespace oneflow {

class BoxingIdentityTaskNode : public TransportTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BoxingIdentityTaskNode);
  BoxingIdentityTaskNode() = default;
  ~BoxingIdentityTaskNode() override = default;

  void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi);

  TaskType GetTaskType() const override { return TaskType::kBoxingIdentity; }

  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,
                                         const TaskGraphRebuildCtx& ctx) override;
  void ToTransportTaskProto(TransportTaskProto*) const override;

 private:
  void BuildExecGphAndRegst() override;
  void ProduceAllRegstsAndBindEdges() override;
  void ConsumeAllRegsts() final;
  void InferProducedDataRegstTimeShape() final;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_


================================================
FILE: oneflow/core/graph/boxing_task_graph.proto
================================================
syntax = "proto2";

package oneflow;

import "oneflow/core/register/logical_blob_id.proto";
import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/job/sbp_parallel.proto";
import "oneflow/core/job/task.proto";
import "oneflow/core/job/placement.proto";
import "oneflow/core/graph/task_edge.proto";
import "oneflow/core/operator/op_conf.proto";
import "oneflow/core/register/tensor_slice_view.proto";

message ComputeTasksProto {
  map<int64, TaskProto> parallel_id2task = 2;
}

message CollectiveBoxingGenericTaskProto {
  required OperatorConf op_conf = 1;
}

message NcclSendRecvBoxingTaskProto {
  required ShapeProto logical_shape = 1;
  required DataType data_type = 2;
  required NdSbp src_nd_sbp = 3;
  required NdSbp dst_nd_sbp = 4;
  required ParallelConf src_parallel_conf = 5;
  required ParallelConf dst_parallel_conf = 6;
  required ParallelConf parallel_conf = 7;
  required ParallelContext parallel_ctx = 8;
  required bool has_input = 9;
  required bool has_output = 10;
  required string stream_name = 11;
}

enum CopyHdType {
  H2D = 0;
  D2H = 1;
}

message CopyHdTaskProto {
  required CopyHdType copy_type = 1;
}

message CopyCommNetTaskProto {
}

message BoxingZerosTaskProto {
  required ShapeProto shape = 1;
  required DataType data_type = 2;
  required ShapeProto time_shape = 3;
}

enum SliceBoxingTaskMode {
  kSliceBoxingTaskModeInvalid = 0;
  kSliceBoxingTaskModeCopy = 1;
  kSliceBoxingTaskModeAdd = 2;
}

message SliceBoxingTaskProto {
  map<int64, TensorSliceViewProto> in_data_edge_uid2slice = 1;
  repeated int64 ordered_in_data_edge_uid = 2;
  required TensorSliceViewProto out_slice = 3;
  required ShapeProto out_shape = 4;
  required SliceBoxingTaskMode mode = 5;
}

message CollectiveBoxingPackTaskProto {
  required ShapeProto logical_shape = 1;
  required SbpParallel src_sbp_parallel = 2;
  required SbpParallel dst_sbp_parallel = 3;
  required int64 parallel_num = 4;
}

message CollectiveBoxingUnpackTaskProto {
  required ShapeProto logical_shape = 1;
  required SbpParallel src_sbp_parallel = 2;
  required SbpParallel dst_sbp_parallel = 3;
  required int64 parallel_num = 4;
}

message BoxingIdentityTaskProto {
}

message TransportTaskProto {
  required TaskProto task_proto = 1;
  required LogicalBlobId lbi = 11;
  oneof transport_task_type {
    CollectiveBoxingGenericTaskProto collective_boxing_generic_task = 2;
    NcclSendRecvBoxingTaskProto nccl_send_recv_boxing_task = 3;
    CopyHdTaskProto copy_hd_task = 4;
    CopyCommNetTaskProto copy_comm_net_task = 5;
    BoxingZerosTaskProto boxing_zeros_task = 6;
    SliceBoxingTaskProto slice_boxing_task = 7;
    CollectiveBoxingPackTaskProto collective_boxing_pack_task = 8;
    CollectiveBoxingUnpackTaskProto collective_boxing_unpack_task = 9;
    BoxingIdentityTaskProto boxing_identity_task = 10;
  }
}

message TaskIdsProto {
  repeated int64 task_id = 1;
}

message BoxingTaskGraphProto {
  map<string, ComputeTasksProto> boxing_related_op_name2compute_tasks = 1;
  repeated TransportTaskProto transport_task = 2;
  repeated TaskEdgeProto task_edge = 3;
  map<string, TaskIdsProto> boxing_unrelated_op_name2task_ids = 4;
}


================================================
FILE: oneflow/core/graph/boxing_zeros_task_node.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing_zeros_task_node.h"
#include "oneflow/core/graph/boxing_task_graph.pb.h"

namespace oneflow {

void BoxingZerosTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
                               const Shape& shape, DataType data_type, const Shape& time_shape) {
  set_machine_id(machine_id);
  set_thrd_id(thrd_id);
  set_lbi(lbi);
  shape_ = shape;
  data_type_ = data_type;
  time_shape_ = time_shape;
}

void BoxingZerosTaskNode::ProduceAllRegstsAndBindEdges() {
  std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out", false, 1, 1);
  this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); });
}

void BoxingZerosTaskNode::ConsumeAllRegsts() {
  // do nothing
}

void BoxingZerosTaskNode::BuildExecGphAndRegst() {
  ExecNode* node = mut_exec_gph().NewNode();
  OperatorConf op_conf;
  op_conf.set_name("System-Boxing-Zeros-" + NewUniqueId());
  op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
  *op_conf.mutable_boxing_zeros_conf()->mutable_lbi() = lbi();
  shape_.ToProto(op_conf.mutable_boxing_zeros_conf()->mutable_shape());
  op_conf.mutable_boxing_zeros_conf()->set_data_type(data_type_);
  std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));
  node->mut_op() = sole_op;
  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
  out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
  node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
  (node->*GetInferBlobDescsMethod())(nullptr);
}

void BoxingZerosTaskNode::InferProducedDataRegstTimeShape() {
  GetProducedRegst("out")->mut_data_regst_time_shape()->reset(new Shape(time_shape_));
}

Maybe<void> BoxingZerosTaskNode::InitTransportTaskFromProto(
    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {
  CHECK_OR_RETURN(transport_task_proto.has_boxing_zeros_task())
      << "not a serialized BoxingZerosTaskNode.
debug string: " << transport_task_proto.DebugString(); const auto& proto = transport_task_proto.boxing_zeros_task(); shape_ = Shape(proto.shape()); data_type_ = proto.data_type(); time_shape_ = Shape(proto.time_shape()); return Maybe::Ok(); } void BoxingZerosTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); auto* proto = transport_task_proto->mutable_boxing_zeros_task(); shape_.ToProto(proto->mutable_shape()); proto->set_data_type(data_type_); time_shape_.ToProto(proto->mutable_time_shape()); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/boxing_zeros_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_ #include "oneflow/core/graph/transport_task_node.h" namespace oneflow { class BoxingZerosTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(BoxingZerosTaskNode); BoxingZerosTaskNode() = default; ~BoxingZerosTaskNode() override = default; void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& shape, DataType data_type, const Shape& time_shape); TaskType GetTaskType() const override { return TaskType::kBoxingZeros; } Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) override; void ToTransportTaskProto(TransportTaskProto*) const override; private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() final; void InferProducedDataRegstTimeShape() final; Shape shape_; DataType data_type_; Shape time_shape_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/collective_boxing_pack_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/collective_boxing_pack_task_node.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" namespace oneflow { void CollectiveBoxingPackTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& logical_shape, const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, const int64_t parallel_num) { set_machine_id(machine_id); set_thrd_id(thrd_id); set_lbi(lbi); logical_shape_ = logical_shape; parallel_num_ = parallel_num; src_sbp_parallel_ = src_sbp_parallel; dst_sbp_parallel_ = dst_sbp_parallel; } void CollectiveBoxingPackTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", true, 1, 1); this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); } void CollectiveBoxingPackTaskNode::ConsumeAllRegsts() { this->ForEachInDataEdge( [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); } void CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); OperatorConf op_conf; op_conf.set_name("System-Collective-Boxing-Pack-" + NewUniqueId()); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); auto* collective_boxing_pack_conf = op_conf.mutable_collective_boxing_pack_conf(); *collective_boxing_pack_conf->mutable_lbi() = lbi(); logical_shape_.ToProto(collective_boxing_pack_conf->mutable_logical_shape()); *collective_boxing_pack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_; *collective_boxing_pack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_; collective_boxing_pack_conf->set_num_ranks(parallel_num_); std::shared_ptr sole_op = CHECK_JUST(ConstructOp(op_conf)); node->mut_op() = sole_op; node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); (node->*GetInferBlobDescsMethod())(nullptr); } void CollectiveBoxingPackTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } Maybe CollectiveBoxingPackTaskNode::InitTransportTaskFromProto( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_pack_task()) << "not a serialized CollectiveBoxingPackTaskNode. debug string: " << transport_task_proto.DebugString(); const auto& proto = transport_task_proto.collective_boxing_pack_task(); logical_shape_ = Shape(proto.logical_shape()); src_sbp_parallel_ = proto.src_sbp_parallel(); dst_sbp_parallel_ = proto.dst_sbp_parallel(); parallel_num_ = proto.parallel_num(); return Maybe::Ok(); } void CollectiveBoxingPackTaskNode::ToTransportTaskProto( TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); auto* proto = transport_task_proto->mutable_collective_boxing_pack_task(); logical_shape_.ToProto(proto->mutable_logical_shape()); *proto->mutable_src_sbp_parallel() = src_sbp_parallel_; *proto->mutable_dst_sbp_parallel() = dst_sbp_parallel_; proto->set_parallel_num(parallel_num_); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/collective_boxing_pack_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_ #include "oneflow/core/graph/transport_task_node.h" namespace oneflow { class CollectiveBoxingPackTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackTaskNode); CollectiveBoxingPackTaskNode() = default; ~CollectiveBoxingPackTaskNode() override = default; void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& logical_shape, const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, const int64_t parallel_num); TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingPack; } Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) override; void ToTransportTaskProto(TransportTaskProto*) const override; private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() final; void InferProducedDataRegstTimeShape() final; Shape logical_shape_; SbpParallel src_sbp_parallel_; SbpParallel dst_sbp_parallel_; int64_t parallel_num_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/collective_boxing_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/graph/collective_boxing_task_node.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" namespace oneflow { void CollectiveBoxingGenericTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const OperatorConf& op_conf) { set_machine_id(machine_id); set_thrd_id(thrd_id); set_lbi(lbi); op_conf_ = op_conf; } void CollectiveBoxingGenericTaskNode::ProduceAllRegstsAndBindEdges() { if (boxing::collective::GenericOpHasOutput( op_conf_.collective_boxing_generic_conf().rank_desc())) { const bool enable_mem_reuse = ParseBooleanFromEnv("ONEFLOW_GRAPH_BOXING_ENABLE_MEM_REUSE", false); std::shared_ptr out_regst = ProduceRegst("out", enable_mem_reuse, 1, 1); this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); } } void CollectiveBoxingGenericTaskNode::ConsumeAllRegsts() { this->ForEachInDataEdge( [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); } void CollectiveBoxingGenericTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); std::shared_ptr boxing_op = CHECK_JUST(ConstructOp(op_conf_)); node->mut_op() = boxing_op; for (const std::string& ibn : boxing_op->input_bns()) { node->BindBnWithRegst(ibn, GetSoleConsumedRegst("in")); } std::shared_ptr out_regst = GetProducedRegst("out"); for (const std::string& obn : boxing_op->output_bns()) { CHECK(out_regst != nullptr); node->BindBnWithRegst(obn, out_regst); out_regst->AddLbi(boxing_op->BnInOp2Lbi(obn)); } (node->*GetInferBlobDescsMethod())(nullptr); } void CollectiveBoxingGenericTaskNode::InferProducedDataRegstTimeShape() { auto out_regst = GetProducedRegst("out"); if (out_regst != nullptr) { out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); } } Maybe CollectiveBoxingGenericTaskNode::InitTransportTaskFromProto( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_generic_task()) << "not a serialized CollectiveBoxingGenericTaskNode. debug string: " << transport_task_proto.DebugString(); op_conf_ = transport_task_proto.collective_boxing_generic_task().op_conf(); return Maybe::Ok(); } void CollectiveBoxingGenericTaskNode::ToTransportTaskProto( TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); *transport_task_proto->mutable_collective_boxing_generic_task()->mutable_op_conf() = op_conf_; } } // namespace oneflow ================================================ FILE: oneflow/core/graph/collective_boxing_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_ #include "oneflow/core/graph/transport_task_node.h" namespace oneflow { class CollectiveBoxingGenericTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericTaskNode); CollectiveBoxingGenericTaskNode() = default; ~CollectiveBoxingGenericTaskNode() override = default; void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const OperatorConf& op_conf); Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) override; void ToTransportTaskProto(TransportTaskProto*) const override; private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() final; void InferProducedDataRegstTimeShape() final; TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingGeneric; } OperatorConf op_conf_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/collective_boxing_unpack_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/graph/collective_boxing_unpack_task_node.h" namespace oneflow { void CollectiveBoxingUnpackTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& logical_shape, const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, const int64_t parallel_num) { set_machine_id(machine_id); set_thrd_id(thrd_id); set_lbi(lbi); logical_shape_ = logical_shape; parallel_num_ = parallel_num; src_sbp_parallel_ = src_sbp_parallel; dst_sbp_parallel_ = dst_sbp_parallel; } void CollectiveBoxingUnpackTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", true, 1, 1); this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); } void CollectiveBoxingUnpackTaskNode::ConsumeAllRegsts() { this->ForEachInDataEdge( [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); } void CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); OperatorConf op_conf; op_conf.set_name("System-Collective-Boxing-Unpack-" + NewUniqueId()); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); auto* collective_boxing_unpack_conf = op_conf.mutable_collective_boxing_unpack_conf(); *collective_boxing_unpack_conf->mutable_lbi() = lbi(); logical_shape_.ToProto(collective_boxing_unpack_conf->mutable_logical_shape()); *collective_boxing_unpack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_; *collective_boxing_unpack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_; collective_boxing_unpack_conf->set_num_ranks(parallel_num_); std::shared_ptr sole_op = CHECK_JUST(ConstructOp(op_conf)); node->mut_op() = sole_op; node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); (node->*GetInferBlobDescsMethod())(nullptr); } void CollectiveBoxingUnpackTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } Maybe CollectiveBoxingUnpackTaskNode::InitTransportTaskFromProto( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_unpack_task()) << "not a serialized CollectiveBoxingUnpackTaskNode. debug string: " << transport_task_proto.DebugString(); const auto& proto = transport_task_proto.collective_boxing_unpack_task(); logical_shape_ = Shape(proto.logical_shape()); src_sbp_parallel_ = proto.src_sbp_parallel(); dst_sbp_parallel_ = proto.dst_sbp_parallel(); parallel_num_ = proto.parallel_num(); return Maybe::Ok(); } void CollectiveBoxingUnpackTaskNode::ToTransportTaskProto( TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); auto* proto = transport_task_proto->mutable_collective_boxing_unpack_task(); logical_shape_.ToProto(proto->mutable_logical_shape()); *proto->mutable_src_sbp_parallel() = src_sbp_parallel_; *proto->mutable_dst_sbp_parallel() = dst_sbp_parallel_; proto->set_parallel_num(parallel_num_); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/collective_boxing_unpack_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_ #include "oneflow/core/graph/transport_task_node.h" namespace oneflow { class CollectiveBoxingUnpackTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackTaskNode); CollectiveBoxingUnpackTaskNode() = default; ~CollectiveBoxingUnpackTaskNode() override = default; void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& logical_shape, const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, const int64_t parallel_num); TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingUnpack; } Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) override; void ToTransportTaskProto(TransportTaskProto*) const override; private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() final; void InferProducedDataRegstTimeShape() final; Shape logical_shape_; SbpParallel src_sbp_parallel_; SbpParallel dst_sbp_parallel_; int64_t parallel_num_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/graph/normal_forward_compute_task_node.h" namespace oneflow { namespace { const OpNode* OpNodeOnEdge(TaskEdge* edge, TaskNode* (TaskEdge::*GetNode)() const, void (TaskNode::*ForEachDataEdge)(const std::function&) const) { CompTaskNode* target_node = nullptr; do { TaskNode* tmp_node = (edge->*GetNode)(); target_node = dynamic_cast(tmp_node); edge = nullptr; (tmp_node->*ForEachDataEdge)([&](TaskEdge* e) { if (edge == nullptr) { edge = e; } }); } while (!target_node && edge); if (target_node) { return target_node->op_node(); } return nullptr; } std::vector GetCompTaskNodesOnEdge( TaskEdge* edge, TaskNode* (TaskEdge::*GetNode)() const, void (TaskNode::*ForEachDataEdge)(const std::function&) const) { std::queue nodes; HashSet visited_nodes; nodes.push((edge->*GetNode)()); CHECK(visited_nodes.emplace((edge->*GetNode)()).second); std::vector comp_task_nodes; while (!nodes.empty()) { TaskNode* node = nodes.front(); nodes.pop(); CompTaskNode* comp_task_node = dynamic_cast(node); if (comp_task_node) { comp_task_nodes.emplace_back(comp_task_node); } else { (node->*ForEachDataEdge)([&](TaskEdge* task_edge) { if (visited_nodes.find((task_edge->*GetNode)()) == visited_nodes.end()) { nodes.push((task_edge->*GetNode)()); CHECK(visited_nodes.emplace((task_edge->*GetNode)()).second); } }); } } return comp_task_nodes; } std::shared_ptr NewFakeDataRegstDesc() { auto regst_desc = std::make_shared(); regst_desc->mut_regst_desc_type()->mutable_data_regst_desc(); return regst_desc; } } // namespace void CompTaskNode::ConsumeFakeRegst(const std::string& regst_name) { ConsumeRegst(regst_name, NewFakeDataRegstDesc()); fake_consumed_regst_names_.insert(regst_name); } void CompTaskNode::ConsumeFakeRegstsIf() { ConsumeFakeRegsts(); RegstDesc* data_regst_desc = nullptr; for (const auto& pair : consumed_regsts()) { for (const auto& regst_desc : pair.second) { if (regst_desc->regst_desc_type().has_data_regst_desc()) { // Only one fake data regst is creatd for each CompTaskNode with ConsumeFakeRegsts(). CHECK(data_regst_desc == nullptr); data_regst_desc = CHECK_NOTNULL(regst_desc.get()); } else if (regst_desc->regst_desc_type().has_ctrl_regst_desc()) { // do nothing. } else { UNIMPLEMENTED(); } } } if (data_regst_desc != nullptr) { for (const auto& ibn : op_node()->op().input_bns()) { // Only one fake data regst is creatd and just use it for all input_bns as a placeholder. 
data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn)); } } } void CompTaskNode::EraseFakeRegstsIf() { for (const auto& fake_consumed_regst_name : fake_consumed_regst_names_) { EraseConsumedRegstsByName(fake_consumed_regst_name); } fake_consumed_regst_names_.clear(); } std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); } void CompTaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& proto) { TaskNode::InitFromProtoExceptConsumedRegsts(proto); parallel_ctx_ = proto.parallel_ctx(); } void CompTaskNode::ToProto(TaskProto* task_proto, bool check) const { TaskNode::ToProto(task_proto, check); *(task_proto->mutable_parallel_ctx()) = parallel_ctx_; } const OpNode* CompTaskNode::GetOneSuccOpNodeOnEdge(TaskEdge* edge) { return OpNodeOnEdge(edge, &TaskEdge::dst_node, &TaskNode::ForEachOutDataEdge); } const OpNode* CompTaskNode::GetOnePredOpNodeOnEdge(TaskEdge* edge) { return OpNodeOnEdge(edge, &TaskEdge::src_node, &TaskNode::ForEachInDataEdge); } std::vector CompTaskNode::GetSuccCompTaskNodesOnEdge(TaskEdge* edge) const { return GetCompTaskNodesOnEdge(edge, &TaskEdge::dst_node, &TaskNode::ForEachOutDataEdge); } std::vector CompTaskNode::GetPredCompTaskNodesOnEdge(TaskEdge* edge) const { return GetCompTaskNodesOnEdge(edge, &TaskEdge::src_node, &TaskNode::ForEachInDataEdge); } void CompTaskNode::InferProducedDataRegstTimeShape() { std::shared_ptr op_time_shape(new Shape(*CHECK_JUST(op()->GetOpTimeShape()))); ForEachProducedDataRegst([op_time_shape](const std::string& name, RegstDesc* regst) { *regst->mut_data_regst_time_shape() = op_time_shape; }); } CompTaskNode* NewCompTaskNode4OpNode(const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); if (op_conf.has_user_conf()) { const std::string& op_type_name = op_conf.user_conf().op_type_name(); if (IsClassRegistered(op_type_name)) { return std::unique_ptr( NewObj(op_type_name)) ->NewCompTaskNode(op_conf); } else { return new NormalForwardCompTaskNode; } } else { OperatorConf::OpTypeCase op_type_case = op_conf.op_type_case(); if (IsClassRegistered(op_type_case)) { return std::unique_ptr( NewObj(op_type_case)) ->NewCompTaskNode(op_conf); } else { return new NormalForwardCompTaskNode; } } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/compute_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

#ifndef ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_

#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/fake_consumed_regst_provider.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/job/compile_mode.h"

namespace oneflow {

class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CompTaskNode);
  CompTaskNode() = default;
  virtual ~CompTaskNode() = default;

  virtual void ToProto(TaskProto*, bool check) const override;
  virtual void InitFromProtoExceptConsumedRegsts(const TaskProto&) override;

  void ConsumeFakeRegstsIf() override;
  void EraseFakeRegstsIf() override;
  // ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks.
  virtual void ConsumeFakeRegsts() = 0;
  void ConsumeFakeRegst(const std::string& regst_name);

  // parallel_ctx_
  int64_t parallel_id() const { return parallel_ctx_.parallel_id(); }
  const ParallelContext* parallel_ctx() const override { return &parallel_ctx_; }
  ParallelContext* mut_parallel_ctx() { return &parallel_ctx_; }

  // op_node_
  const OpNode* op_node() const { return op_node_; }
  void set_op_node(const OpNode* val) { op_node_ = val; }
  std::string VisualStr() const override;

  // op
  std::shared_ptr<const Operator> op() const { return op_node_->shared_op(); }

  ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override {
    // For default compilation mode, compute task node use input blob desc to infer output blob
    // desc; For separate compilation mode, compute task node use NdSBP to infer output blob desc.
    return InferBlobDescsMethodGetter::Visit(CHECK_JUST(CurrentCompileMode()));
  }

 protected:
  const OpNode* GetOneSuccOpNodeOnEdge(TaskEdge* edge);
  const OpNode* GetOnePredOpNodeOnEdge(TaskEdge* edge);
  std::vector<CompTaskNode*> GetSuccCompTaskNodesOnEdge(TaskEdge* edge) const;
  std::vector<CompTaskNode*> GetPredCompTaskNodesOnEdge(TaskEdge* edge) const;

  void InferProducedDataRegstTimeShape() override;

 private:
  struct InferBlobDescsMethodGetter final : public CompileModeVisitor<InferBlobDescsMethodGetter> {
    static ExecNode::InferBlobDescsMethod VisitNaive() { return &ExecNode::InferBlobDescsByInputs; }
    static ExecNode::InferBlobDescsMethod VisitRankPerProcess() {
      return &ExecNode::InferBlobDescsByNdSbp;
    }
    static ExecNode::InferBlobDescsMethod VisitInValid() { return nullptr; }
  };

  ParallelContext parallel_ctx_;
  const OpNode* op_node_;
  HashSet<std::string> fake_consumed_regst_names_;
};

class OpCompTaskNodeCreator {
 public:
  virtual ~OpCompTaskNodeCreator() = default;
  virtual CompTaskNode* NewCompTaskNode(const OperatorConf& op_conf) = 0;
};

template<typename CompTaskNodeType>
class StaticOpCompTaskNodeCreator : public OpCompTaskNodeCreator {
 public:
  StaticOpCompTaskNodeCreator() = default;
  ~StaticOpCompTaskNodeCreator() override = default;

 private:
  CompTaskNode* NewCompTaskNode(const OperatorConf& op_conf) override {
    return new CompTaskNodeType();
  }
};

class FnOpCompTaskNodeCreator : public OpCompTaskNodeCreator {
 public:
  using CreateFn = std::function<CompTaskNode*(const OperatorConf& op_conf)>;
  explicit FnOpCompTaskNodeCreator(CreateFn fn) : fn_(std::move(fn)) {}
  ~FnOpCompTaskNodeCreator() override = default;

 private:
  CompTaskNode* NewCompTaskNode(const OperatorConf& op_conf) override { return fn_(op_conf); }
  CreateFn fn_;
};

#define REGISTER_USER_OP_COMP_TASK_NODE_TYPE(op_type_name, comp_task_node_type)                   \
  REGISTER_CLASS_CREATOR(std::string, op_type_name, OpCompTaskNodeCreator,                        \
                         ([] { return new StaticOpCompTaskNodeCreator<comp_task_node_type>(); }));
#define REGISTER_USER_OP_COMP_TASK_NODE_TYPE_WITH_FUNC(op_type_name, func) \
REGISTER_CLASS_CREATOR(std::string, op_type_name, OpCompTaskNodeCreator, \ ([] { return new FnOpCompTaskNodeCreator(func); })); #define REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(op_type_case, comp_task_node_type) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, OpCompTaskNodeCreator, \ ([] { return new StaticOpCompTaskNodeCreator(); })); #define REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE_WITH_FUNC(op_type_case, func) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, OpCompTaskNodeCreator, \ ([] { return new FnOpCompTaskNodeCreator(func); })); CompTaskNode* NewCompTaskNode4OpNode(const OpNode* op_node); } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/copy_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/copy_task_node.h" #include "oneflow/core/graph/task_stream_id.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/framework/user_op_registry_manager.h" namespace oneflow { void CopyTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("copy_out", false); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("copy_out", out_regst); }); } void CopyTaskNode::ConsumeAllRegsts() { ConsumeRegst("copy_in", SoleInDataEdge()->GetSoleRegst()); } void CopyTaskNode::BuildExecGphAndRegst() { auto out_regst = GetProducedRegst("copy_out"); auto in_regst = GetSoleConsumedRegst("copy_in"); out_regst->CopyBlobDescFrom(in_regst.get()); ExecNode* node = mut_exec_gph().NewNode(); auto constructed = CHECK_JUST(ConstructOp(NewCopyOpConf())); // prevent filling parallel desc for copy commnet if (constructed->op_conf().has_user_conf()) { std::shared_ptr hierarchy = std::make_shared(Shape({1})); auto parallel_desc = ParallelDesc::New(constructed->op_conf().device_tag(), {"0:0-0"}, hierarchy).GetOrThrow(); CHECK_JUST(constructed->FillOpParallelDesc(parallel_desc)); } node->mut_op() = constructed; node->BindBnWithRegst(node->op()->SoleIbn(), in_regst); node->BindBnWithRegst(node->op()->SoleObn(), out_regst); } void CopyTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } void CopyHdTaskNode::Init(CopyHdType copy_type, const DeviceId& device_id, const LogicalBlobId& lbi) { copy_type_ = copy_type; set_machine_id(device_id.rank()); int64_t thrd_id = -1; if (copy_type == CopyHdType::H2D) { thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(device_id, "H2D")); } else if (copy_type == CopyHdType::D2H) { thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(device_id, "D2H")); } else { UNIMPLEMENTED(); } set_thrd_id(thrd_id); set_lbi(lbi); } void CopyHdTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) { if (copy_type_ == CopyHdType::H2D) { TaskNode::InitProducedRegstMemCase(mem_case); } else if (copy_type_ == CopyHdType::D2H) { mem_case->set_device_type(DeviceType::kCPU); mem_case->set_device_id(0); 
mem_case->set_pinned_device_type(device_type()); mem_case->set_pinned_device_id(stream_id().device_id().device_index()); } else { UNIMPLEMENTED(); } } void CopyHdTaskNode::ProduceAllRegstsAndBindEdges() { const bool enable_mem_reuse = ParseBooleanFromEnv("ONEFLOW_GRAPH_BOXING_ENABLE_MEM_REUSE", false) && (copy_type_ == CopyHdType::H2D); std::shared_ptr out_regst = ProduceRegst("copy_out", enable_mem_reuse); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("copy_out", out_regst); }); } OperatorConf CopyHdTaskNode::NewCopyOpConf() { OperatorConf conf; conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type()))); auto copy_type_name = "undefined"; if (copy_type_ == CopyHdType::D2H) { copy_type_name = "copy_d2h"; } else if (copy_type_ == CopyHdType::H2D) { copy_type_name = "copy_h2d"; } else { LOG(FATAL) << "unknow copy type: " << copy_type_; } conf.set_name(std::string(copy_type_name) + "_" + lbi().op_name() + "-" + lbi().blob_name() + "_" + std::to_string(task_id())); *conf.mutable_user_conf()->mutable_op_type_name() = copy_type_name; auto in_regst = GetSoleConsumedRegst("copy_in"); CHECK_EQ(in_regst->NumOfLbi(), 1); in_regst->ForEachLbi([&](const LogicalBlobId& lbi) { (*conf.mutable_user_conf()->mutable_input())["in"].add_s(GenLogicalBlobName(lbi)); (*conf.mutable_user_conf()->mutable_output())["out"].add_s( GenLogicalBlobName(conf.name(), GenRepeatedBn("out", 0))); }); return conf; } void CopyCommNetTaskNode::Init(int64_t machine_id, const LogicalBlobId& lbi) { set_machine_id(machine_id); set_thrd_id(EncodeStreamIdToInt64( GenerateNamedTaskStreamId(machine_id, DeviceType::kCPU, 0, "COMM_NET"))); set_lbi(lbi); } OperatorConf CopyCommNetTaskNode::NewCopyOpConf() { OperatorConf conf; conf.set_name("copy_comm_net_" + NewUniqueId()); conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); *(conf.mutable_copy_comm_net_conf()->mutable_lbi()) = lbi(); return conf; } Maybe CopyHdTaskNode::InitTransportTaskFromProto( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK_OR_RETURN(transport_task_proto.has_copy_hd_task()) << "not a serialized CopyHdTaskNode. debug string: " << transport_task_proto.DebugString(); copy_type_ = transport_task_proto.copy_hd_task().copy_type(); return Maybe::Ok(); } void CopyHdTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); transport_task_proto->mutable_copy_hd_task()->set_copy_type(copy_type_); } Maybe CopyCommNetTaskNode::InitTransportTaskFromProto( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK_OR_RETURN(transport_task_proto.has_copy_comm_net_task()) << "not a serialized CopyCommNetTaskNode. debug string: " << transport_task_proto.DebugString(); return Maybe::Ok(); } void CopyCommNetTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); transport_task_proto->mutable_copy_comm_net_task(); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/copy_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_ #include "oneflow/core/graph/transport_task_node.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" namespace oneflow { class CopyTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CopyTaskNode); CopyTaskNode() = default; virtual ~CopyTaskNode() = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void BuildExecGphAndRegst() override; protected: virtual OperatorConf NewCopyOpConf() = 0; private: void InferProducedDataRegstTimeShape() final; }; class CopyHdTaskNode final : public CopyTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CopyHdTaskNode); CopyHdTaskNode() = default; ~CopyHdTaskNode() = default; TaskType GetTaskType() const override { return TaskType::kCopyHd; } void Init(CopyHdType, const DeviceId& device_id, const LogicalBlobId& lbi); void ProduceAllRegstsAndBindEdges() override; CopyHdType copy_type() const { return copy_type_; } MemZoneId MemZoneId121() const override { if (copy_type_ == CopyHdType::H2D) { return TaskNode::MemZoneId121(); } else if (copy_type_ == CopyHdType::D2H) { return GetNodeCPUMemZoneId(this->machine_id()); } else { UNIMPLEMENTED(); } return kInvalidMemZoneId; } Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) override; void ToTransportTaskProto(TransportTaskProto*) const override; private: void InitProducedRegstMemCase(MemoryCase*) override; OperatorConf NewCopyOpConf() override; CopyHdType copy_type_; }; class CopyCommNetTaskNode final : public CopyTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CopyCommNetTaskNode); CopyCommNetTaskNode() = default; ~CopyCommNetTaskNode() = default; TaskType GetTaskType() const override { return TaskType::kCopyCommNet; } void Init(int64_t machine_id, const LogicalBlobId& lbi); Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) override; void ToTransportTaskProto(TransportTaskProto*) const override; private: OperatorConf NewCopyOpConf() override; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/exec_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/exec_graph.h" #include #include "oneflow/core/common/just.h" #include "oneflow/core/graph/op_graph.h" namespace oneflow { void ExecNode::BindBnWithRegst(const std::string& bn, std::shared_ptr regst) { CHECK(bn_in_op2regst_.emplace(bn, regst).second); } void ExecNode::BindBnsWithRegst(const PbRpf& (Operator::*bns_getter)() const, std::shared_ptr regst) { for (const std::string& bn : (op_.get()->*bns_getter)()) { BindBnWithRegst(bn, regst); } } void ExecNode::AddBnToRegstAndBindIt(const PbRpf& (Operator::*bns_getter)() const, std::shared_ptr regst) { for (const std::string& bn : (op_.get()->*bns_getter)()) { regst->AddLbi(op_->BnInOp2Lbi(bn)); } BindBnsWithRegst(bns_getter, regst); } bool ExecNode::TryBindBnWithOneOfTheRegsts(const std::string& bn, const std::list>& regsts) { const LogicalBlobId& lbi = op()->BnInOp2Lbi(bn); bool has_binded = false; for (std::shared_ptr regst : regsts) { if (regst->GetBlobDesc(lbi) == nullptr) { continue; } BindBnWithRegst(bn, regst); has_binded = true; break; } return has_binded; } void ExecNode::BindBnWithOneOfTheRegsts(const std::string& bn, const std::list>& regsts) { CHECK(TryBindBnWithOneOfTheRegsts(bn, regsts)); } void ExecNode::UnbindBnWithEmptyRegst() { EraseIf>( &bn_in_op2regst_, [](HashMap>::iterator it) { return it->second->regst_desc_type().has_data_regst_desc() && it->second->NumOfLbi() == 0; }); } void ExecNode::ToProto(const ParallelContext* parallel_ctx, ExecNodeProto* ret) const { op_->GenKernelConf(GetBlobDesc4BnInOpFunc(), parallel_ctx, ret->mutable_kernel_conf()); for (const auto& bn_regst : bn_in_op2regst_) { const std::string& bn_in_op = bn_regst.first; auto regst = bn_regst.second; CHECK(regst); PbMapPair pair{bn_in_op, regst->regst_desc_id()}; CHECK(ret->mutable_bn_in_op2regst_desc_id()->insert(pair).second); } } namespace { Maybe CheckPhysicalBlobDesc(const BlobDesc& logical, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const ParallelContext* parallel_ctx, const BlobDesc& physical) { CHECK_EQ_OR_RETURN(physical.shape(), *JUST(GetPhysicalShape(logical.shape(), nd_sbp, parallel_desc, *parallel_ctx))); return Maybe::Ok(); } Maybe CheckPhysicalBlobDesc( const Operator& op, const PbRpf& bns, const std::function(const std::string&)>& GetLogicalBlobDesc, const NdSbpSignature* nd_sbp_signature, const ParallelContext* parallel_ctx, const std::function& GetPhysicalBlobDesc) { const std::shared_ptr op_parallel_desc = JUST(op.GetOpParallelDesc()); for (const auto& bn : bns) { const BlobDesc* physical_blob_desc = GetPhysicalBlobDesc(bn); if (physical_blob_desc == nullptr) { // TODO(liujuncheng): remove this hotfix continue; } if (*JUST(op.GetParallelDesc4BnInOp(bn)) == *op_parallel_desc) { JUST_MSG(CheckPhysicalBlobDesc(*JUST(GetLogicalBlobDesc(bn)), nd_sbp_signature->bn_in_op2nd_sbp().at(bn), *op_parallel_desc, parallel_ctx, *physical_blob_desc), std::stringstream() << " check physical shape failed, op name " << op.op_loc()); } } return Maybe::Ok(); } // A helper function to infer blob's physical shape with ND SBP. 
Maybe InferPhysicalBlobDesc( const Operator& op, const PbRpf& bns, const std::function(const std::string&)>& GetLogicalBlobDesc, const NdSbpSignature* nd_sbp_signature, const ParallelContext* parallel_ctx, const std::function& GetPhysicalBlobDesc) { const std::shared_ptr op_parallel_desc = JUST(op.GetOpParallelDesc()); for (const auto& bn : bns) { BlobDesc* physical_blob_desc = GetPhysicalBlobDesc(bn); const auto& logical_blob_desc = *JUST(GetLogicalBlobDesc(bn)); CHECK_NOTNULL_OR_RETURN(physical_blob_desc) << "physical_blob_desc should not be nullptr. op location: " << op.op_loc(); *physical_blob_desc = logical_blob_desc; const auto& physical_shape = JUST_MSG( GetPhysicalShape(logical_blob_desc.shape(), nd_sbp_signature->bn_in_op2nd_sbp().at(bn), *op_parallel_desc, *parallel_ctx), std::stringstream() << " check physical shape failed, op name " << op.op_loc()); physical_blob_desc->set_shape(*physical_shape); } return Maybe::Ok(); } } // namespace void ExecNode::InferBlobDescsByInputs(const ParallelContext* parallel_ctx) { auto GetBlobDesc4BnInOp = GetBlobDesc4BnInOpFunc(); const OpNode* op_node = Singleton::Get()->OpNode4OpName(op()->op_name()); const NdSbpSignature* nd_sbp_signature = nullptr; if (op_node != nullptr) { nd_sbp_signature = &op_node->nd_sbp_signature(); } if (op_node != nullptr && parallel_ctx->parallel_num() > 1 && nd_sbp_signature != nullptr) { CHECK_JUST(CheckPhysicalBlobDesc( *op(), op()->input_bns(), std::bind(&Operator::GetLogicalBlobDesc4Ibn, op().get(), std::placeholders::_1), nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp)); } CHECK_JUST_MSG(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, &GlobalJobDesc()), std::stringstream() << " infer blob descs is failed, op name " << op_->op_loc()); if (op_node != nullptr && parallel_ctx->parallel_num() > 1 && nd_sbp_signature != nullptr) { CHECK_JUST(CheckPhysicalBlobDesc( *op(), op()->output_bns(), std::bind(&Operator::GetLogicalBlobDesc4Obn, op().get(), std::placeholders::_1), nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp)); } CHECK_JUST_MSG(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_, GetBlobDesc4BnInOp, parallel_ctx), std::stringstream() << " infer inplace obn to ibn is failed, op name " << op_->op_loc()); } void ExecNode::InferBlobDescsByNdSbp(const ParallelContext* parallel_ctx) { const HashSet ibns{op()->input_bns().begin(), op()->input_bns().end()}; HashMap ibn2blob_desc{}; const auto& GetBlobDesc4BnInOp = [&](const std::string& bn_in_op) -> BlobDesc* { // Generate temp regst to store input blob desc, and will be released after infer output blob // desc. if (ibns.count(bn_in_op) > 0) { auto iter = ibn2blob_desc.find(bn_in_op); if (iter == ibn2blob_desc.end()) { iter = ibn2blob_desc.emplace(bn_in_op, BlobDesc(kInvalidDataType, kContiguous)).first; } return &iter->second; } auto it = bn_in_op2regst_.find(bn_in_op); if (it == bn_in_op2regst_.end()) { return nullptr; } std::shared_ptr regst = it->second; CHECK(regst); return regst->MutBlobDesc(op()->BnInOp2Lbi(bn_in_op)); }; const OpNode* op_node = Singleton::Get()->OpNode4OpName(op()->op_name()); const NdSbpSignature* nd_sbp_signature = &CHECK_NOTNULL(op_node)->nd_sbp_signature(); // TODO(strint): user op can infer output with SBP, so there is no need to infer the input. // Reference: https://github.com/Oneflow-Inc/oneflow/pull/8971 // Infer input blob desc with SBP, the infer results are set into the temp input blob desc. 
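// The flow below is: (1) InferPhysicalBlobDesc derives each input's physical BlobDesc from its
// logical BlobDesc plus the op's ND SBP, writing into the temporary map above; (2) the op's own
// InferBlobDescsIf fills the output registers from those inputs; (3) CheckPhysicalBlobDesc then
// verifies that every produced output matches the shape implied by its logical BlobDesc and ND SBP.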
CHECK_JUST(InferPhysicalBlobDesc( *op(), op()->input_bns(), std::bind(&Operator::GetLogicalBlobDesc4Ibn, op().get(), std::placeholders::_1), nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp)); // Infer output blob desc with input. CHECK_JUST_MSG(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, &GlobalJobDesc()), std::stringstream() << " infer blob descs is failed, op name " << op_->op_loc()); CHECK_JUST(CheckPhysicalBlobDesc( *op(), op()->output_bns(), std::bind(&Operator::GetLogicalBlobDesc4Obn, op().get(), std::placeholders::_1), nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp)); CHECK_JUST_MSG(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_, GetBlobDesc4BnInOp, parallel_ctx), std::stringstream() << " infer inplace obn to ibn is failed, op name " << op_->op_loc()); } std::function ExecNode::GetBlobDesc4BnInOpFunc() const { return [this](const std::string& bn_in_op) -> BlobDesc* { auto it = bn_in_op2regst_.find(bn_in_op); if (it == bn_in_op2regst_.end()) { return nullptr; } std::shared_ptr regst = it->second; CHECK(regst); return regst->MutBlobDesc(op()->BnInOp2Lbi(bn_in_op)); }; } void ExecGraph::ToExecSequence(const ParallelContext* parallel_ctx, ExecSequence* ret) const { TopoForEachNode([&](ExecNode* node) { node->ToProto(parallel_ctx, ret->add_exec_node()); }); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/exec_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_ #include "oneflow/core/common/protobuf.h" #include "oneflow/core/graph/exec_sequence.pb.h" #include "oneflow/core/graph/graph.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/register/register_desc.h" namespace oneflow { class ExecNode; class ExecEdge final : public Edge { public: OF_DISALLOW_COPY_AND_MOVE(ExecEdge); ExecEdge() = default; ~ExecEdge() = default; // Getters const LogicalBlobId& lbi() const { return lbi_; } const std::string& src_bn() const { return src_bn_; } const std::string& dst_bn() const { return dst_bn_; } // Setters void set_lbi(const LogicalBlobId& lbi) { lbi_ = lbi; } std::string& mut_src_bn() { return src_bn_; } std::string& mut_dst_bn() { return dst_bn_; } private: // various names for one blob LogicalBlobId lbi_; std::string src_bn_; std::string dst_bn_; }; class ExecNode final : public Node { public: OF_DISALLOW_COPY_AND_MOVE(ExecNode); ExecNode() {} ~ExecNode() = default; std::shared_ptr op() const { return op_; } std::shared_ptr& mut_op() { return op_; } RegstDesc* RegstDesc4BnInOp(const std::string& bn) const { return bn_in_op2regst_.at(bn).get(); } void BindBnWithRegst(const std::string& bn, std::shared_ptr); void BindBnsWithRegst(const PbRpf& (Operator::*bns_getter)() const, std::shared_ptr); void AddBnToRegstAndBindIt(const PbRpf& (Operator::*bns_getter)() const, std::shared_ptr); bool TryBindBnWithOneOfTheRegsts(const std::string&, const std::list>&); void BindBnWithOneOfTheRegsts(const std::string&, const std::list>&); void UnbindBnWithEmptyRegst(); std::string VisualStr() const override { return op_->op_name(); } void ToProto(const ParallelContext*, ExecNodeProto*) const; typedef void (ExecNode::*InferBlobDescsMethod)(const ParallelContext*); void InferBlobDescsByInputs(const ParallelContext* parallel_ctx); void InferBlobDescsByNdSbp(const ParallelContext* parallel_ctx); const HashMap& mut_inplace_obn2ibn() const { return mut_inplace_obn2ibn_; } const HashMap& con_inplace_obn2ibn() const { return con_inplace_obn2ibn_; } private: std::function GetBlobDesc4BnInOpFunc() const; std::shared_ptr op_; HashMap> bn_in_op2regst_; HashMap mut_inplace_obn2ibn_; HashMap con_inplace_obn2ibn_; }; class ExecGraph final : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(ExecGraph); ExecGraph() = default; ~ExecGraph() = default; void ToExecSequence(const ParallelContext*, ExecSequence*) const; const char* TypeName() const override { return "ExecGraph"; } private: }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_ ================================================ FILE: oneflow/core/graph/exec_sequence.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/kernel/kernel.proto"; message ExecNodeProto { required KernelConf kernel_conf = 1; map bn_in_op2regst_desc_id = 2; } message ExecSequence { repeated ExecNodeProto exec_node = 1; } ================================================ FILE: oneflow/core/graph/fake_consumed_regst_provider.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ #define ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ namespace oneflow { // Provide a compute task node with a fake input regst, and its output regst can be inferred using // SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob // desc, mainly to ensure that the transport task node has the correct input blob desc. class FakeConsumedRegstProvider { public: FakeConsumedRegstProvider() = default; virtual ~FakeConsumedRegstProvider() = default; virtual void ConsumeFakeRegstsIf() = 0; virtual void EraseFakeRegstsIf() = 0; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ ================================================ FILE: oneflow/core/graph/graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_GRAPH_H_ #include #include #include "oneflow/core/common/str_util.h" #include "oneflow/core/graph/node.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" namespace oneflow { template class Graph { public: OF_DISALLOW_COPY_AND_MOVE(Graph); Graph() = default; virtual ~Graph() = default; // For Each void ForEachNode(std::function NodeHandler) const; Maybe MaybeForEachNode(std::function(NodeType*)> NodeHandler) const; // In case you want to change the topological structure during the node handler. // For example, adding/deleting a node or an edge. // Still, it might have bugs even if you use TopoForEachNodeDynamic. 
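// Roughly: the Dynamic variant re-examines every in-node's queued state before visiting a node,
// so it tolerates (to a limited degree) nodes or edges created inside the handler, while the
// non-Dynamic variant relies on in-degree counters and assumes the topology stays fixed.
// A typical read-only traversal (MyNode and Visit are placeholder names):
//   graph.TopoForEachNode([](MyNode* node) { Visit(node); });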
void TopoForEachNodeDynamic(std::function NodeHandler) const; void TopoForEachNode(std::function NodeHandler) const; Maybe TopoForEachNodeDynamicWithErrorCaptured( std::function(NodeType*)> NodeHandler) const; Maybe TopoForEachNodeWithErrorCaptured( std::function(NodeType*)> NodeHandler) const; void ReverseTopoForEachNode(std::function NodeHandler) const; void ForEachEdge(std::function EdgeHandler) const; Maybe MaybeForEachEdge(std::function(EdgeType*)> EdgeHandler) const; void SortedTopoForEachNode(std::function LessThan, std::function NodeHandler) const; void BfsForEachNode( const std::list& starts, const std::function&)>& ForEachNext, const std::function& Handler) const; void DfsForEachNode( const std::list& starts, const std::function&)>& ForEachNext, const std::function& Handler) const; void TopoForEachNodeDynamic( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; void TopoForEachNode( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; void TopoForEachNode( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; Maybe TopoForEachNodeDynamicWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; Maybe TopoForEachNodeWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; Maybe TopoForEachNodeWithErrorCaptured( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; void DfsTopoForEachNode( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; void DfsTopoForEachNodeSortByDistanceToSink( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; std::function MakePredicatorIsReachable() const; std::function MakePredicatorIsReachable( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode) const; void ForEachConnectedComponent( const std::function&)>& Handler) const; void ForEachConnectedComponent( const std::function&)>& ForEachConnected, const std::function&)>& Handler) const; void ForEachConnectedComponent( const std::function&)>& ForEachNodeAsStart, const std::function&)>& ForEachConnected, const std::function&)>& Handler) const; // find first nontrivial strongly connected component std::unique_ptr> FindFirstNontrivialSCC( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode) const; std::unique_ptr> FindFirstNontrivialSCC( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode) const; std::unique_ptr> FindFirstNontrivialSCC() const; // Getters std::list source_nodes() const; std::list sink_nodes() const; NodeType* SoleSourceNode() const; NodeType* SoleSinkNode() const; NodeType* SoleNode() const; size_t node_num() const { return nodes_.size(); } size_t edge_num() const { return edges_.size(); } virtual const char* TypeName() const { return ""; } // Setters template DerivedNodeType* NewNode(); template EdgeType* NewEdge(Args&&... 
args); void AddAllocatedNode(NodeType*); void AddAllocatedEdge(EdgeType*); void DeleteNode(NodeType*); // ToDot template void ToDotWithStream(StreamT& out_stream) const; template void ToDotWithStream(const std::function& IsNodeAllowed, const std::function& IsEdgeAllowed, const std::function& AddNodeAttribute, const std::function& AddEdgeAttribute, StreamT& out_stream) const; void ToDotWithFilePath(const std::string& file_path) const; void ToDotWithFilePath(const std::function& AddNodeAttribute, const std::function& AddEdgeAttribute, const std::string& file_path) const; void ToDotWithFilePath(const std::function& IsNodeAllowed, const std::function& IsEdgeAllowed, const std::string& file_path) const; void ToDotWithAutoFilePath() const; private: std::unique_ptr> FindFirstNontrivialSCC( const std::function&)>& ForEachStart, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode) const; // finish time first search void FfsForEachNode( const std::function&)>& ForEachStart, const std::function&)>& ForEachNext, const std::function& Handler) const; void FfsForEachNode(const std::function& Handler) const; std::vector> nodes_; std::vector> edges_; }; template void Graph::ForEachNode(std::function NodeHandler) const { for (auto& x : nodes_) { NodeHandler(x.get()); } } template Maybe Graph::MaybeForEachNode( std::function(NodeType*)> NodeHandler) const { for (auto& x : nodes_) { JUST(NodeHandler(x.get())); } return Maybe::Ok(); } template std::list Graph::source_nodes() const { std::list ret; ForEachNode([&](NodeType* node) { if (node->in_edges().empty()) { ret.emplace_back(node); } }); return ret; } template std::list Graph::sink_nodes() const { std::list ret; ForEachNode([&](NodeType* node) { if (node->out_edges().empty()) { ret.emplace_back(node); } }); return ret; } template NodeType* Graph::SoleSourceNode() const { std::list source_nodes_list = source_nodes(); CHECK_EQ(source_nodes_list.size(), 1); return source_nodes_list.front(); } template NodeType* Graph::SoleSinkNode() const { std::list sink_nodes_list = sink_nodes(); CHECK_EQ(sink_nodes_list.size(), 1); return sink_nodes_list.front(); } template void Graph::TopoForEachNodeDynamic( std::function NodeHandler) const { TopoForEachNodeDynamic(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template void Graph::TopoForEachNode(std::function NodeHandler) const { CHECK_JUST(TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) { NodeHandler(node); return Maybe::Ok(); })); } template Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( std::function(NodeType*)> NodeHandler) const { return TopoForEachNodeDynamicWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template Maybe Graph::TopoForEachNodeWithErrorCaptured( std::function(NodeType*)> NodeHandler) const { return TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template void Graph::SortedTopoForEachNode( std::function LessThan, std::function NodeHandler) const { ForEachNode([&](NodeType* node) { node->SortInOutEdges(LessThan); }); TopoForEachNode(&NodeType::ForEachNodeOnSortedInEdge, &NodeType::ForEachNodeOnSortedOutEdge, NodeHandler); } template void Graph::ReverseTopoForEachNode( std::function NodeHandler) const { TopoForEachNode(&NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, NodeHandler); } template void 
Graph::ForEachEdge(std::function EdgeHandler) const { for (auto& x : edges_) { if (x->src_node() == nullptr && x->dst_node() == nullptr) { continue; } EdgeHandler(x.get()); } } template Maybe Graph::MaybeForEachEdge( std::function(EdgeType*)> EdgeHandler) const { for (auto& x : edges_) { if (x->src_node() == nullptr && x->dst_node() == nullptr) { continue; } JUST(EdgeHandler(x.get())); } return Maybe::Ok(); } template NodeType* Graph::SoleNode() const { CHECK_EQ(nodes_.size(), 1); return nodes_.front().get(); } template template DerivedNodeType* Graph::NewNode() { DerivedNodeType* ret = new DerivedNodeType; AddAllocatedNode(ret); return ret; } template template EdgeType* Graph::NewEdge(Args&&... args) { EdgeType* ret = new EdgeType(std::forward(args)...); AddAllocatedEdge(ret); return ret; } template void Graph::AddAllocatedNode(NodeType* node) { nodes_.emplace_back(node); } template void Graph::AddAllocatedEdge(EdgeType* edge) { edges_.emplace_back(edge); } template void Graph::DeleteNode(NodeType* node) { Erase>>( nodes_, [node](const std::unique_ptr& node_ptr) { return node_ptr.get() == node; }); } template template void Graph::ToDotWithStream(StreamT& out_stream) const { ToDotWithStream([](NodeType*) { return true; }, [](EdgeType*) { return true; }, [](NodeType*) { return ""; }, [](EdgeType*) { return ""; }, out_stream); } template template void Graph::ToDotWithStream( const std::function& IsNodeAllowed, const std::function& IsEdgeAllowed, const std::function& AddNodeAttribute, const std::function& AddEdgeAttribute, StreamT& out_stream) const { out_stream << "digraph {\n"; this->ForEachNode([&](NodeType* node) { if (IsNodeAllowed(node) == false) { return; } out_stream << "\"" << node->node_id_str() << "\" [label=\"" << node->VisualStr() << "\"" << AddNodeAttribute(node) << "]\n"; }); this->ForEachEdge([&](EdgeType* edge) { if (IsEdgeAllowed(edge) == false) { return; } if (IsNodeAllowed(edge->src_node()) == false) { return; } if (IsNodeAllowed(edge->dst_node()) == false) { return; } out_stream << "\"" << edge->src_node()->node_id_str() << "\" -> " << "\"" << edge->dst_node()->node_id_str() << "\"" << "[label=\"" << edge->VisualStr() << "\"" << AddEdgeAttribute(edge) << "];\n"; }); out_stream << "}\n"; } template void Graph::ToDotWithFilePath(const std::string& file_path) const { auto log_stream = TeePersistentLogStream::Create(file_path); ToDotWithStream(log_stream); log_stream->Flush(); } template void Graph::ToDotWithFilePath( const std::function& AddNodeAttribute, const std::function& AddEdgeAttribute, const std::string& file_path) const { auto log_stream = TeePersistentLogStream::Create(file_path); ToDotWithStream([](NodeType*) { return true; }, [](EdgeType*) { return true; }, AddNodeAttribute, AddEdgeAttribute, log_stream); log_stream->Flush(); } template void Graph::ToDotWithFilePath( const std::function& IsNodeAllowed, const std::function& IsEdgeAllowed, const std::string& file_path) const { auto log_stream = TeePersistentLogStream::Create(file_path); ToDotWithStream( IsNodeAllowed, IsEdgeAllowed, [](NodeType*) { return ""; }, [](EdgeType*) { return ""; }, log_stream); log_stream->Flush(); } template void Graph::ToDotWithAutoFilePath() const { std::string file_path = JoinPath("dot", TypeName(), NewUniqueId() + ".dot"); ToDotWithFilePath(file_path); } template void Graph::BfsForEachNode( const std::list& starts, const std::function&)>& ForEachNext, const std::function& Handler) const { HashSet queued_nodes; std::queue queue; for (NodeType* start : starts) { if 
(queued_nodes.find(start) == queued_nodes.end()) { queue.push(start); queued_nodes.insert(start); } } while (!queue.empty()) { NodeType* cur_node = queue.front(); queue.pop(); Handler(cur_node); ForEachNext(cur_node, [&](NodeType* next) { if (queued_nodes.find(next) == queued_nodes.end()) { queue.push(next); queued_nodes.insert(next); } }); } } template void Graph::DfsForEachNode( const std::list& starts, const std::function&)>& ForEachNext, const std::function& Handler) const { HashSet visited_nodes; std::stack stack; for (NodeType* start : starts) { stack.push(start); } while (!stack.empty()) { NodeType* cur_node = stack.top(); stack.pop(); if (visited_nodes.find(cur_node) == visited_nodes.end()) { Handler(cur_node); visited_nodes.insert(cur_node); ForEachNext(cur_node, [&](NodeType* next) { if (visited_nodes.find(next) == visited_nodes.end()) { stack.push(next); } }); } } } template void Graph::FfsForEachNode( const std::function&)>& ForEachStart, const std::function&)>& ForEachNext, const std::function& Handler) const { HashSet visited_nodes; HashSet handled_nodes; ForEachStart([&](NodeType* start) { if (visited_nodes.find(start) != visited_nodes.end()) { return; } std::stack> stack; stack.emplace(std::queue{}); stack.top().push(start); while (!stack.empty()) { if (stack.top().empty()) { stack.pop(); continue; } if (handled_nodes.find(stack.top().front()) != handled_nodes.end()) { stack.top().pop(); continue; } NodeType* cur_node = stack.top().front(); if (visited_nodes.find(cur_node) == visited_nodes.end()) { visited_nodes.insert(cur_node); } int64_t next_unvisited_cnt = 0; ForEachNext(cur_node, [&](NodeType* next) { if (visited_nodes.find(next) == visited_nodes.end()) { if (next_unvisited_cnt == 0) { stack.emplace(std::queue()); } stack.top().push(next); ++next_unvisited_cnt; } }); if (next_unvisited_cnt == 0) { Handler(cur_node); handled_nodes.insert(cur_node); } } }); } template std::unique_ptr> Graph::FindFirstNontrivialSCC( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode) const { auto ForEachStart = [&](const std::function& Handler) { for (NodeType* start : starts) { Handler(start); } }; return FindFirstNontrivialSCC(ForEachStart, ForEachInNode, ForEachOutNode); } template std::unique_ptr> Graph::FindFirstNontrivialSCC( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode) const { return FindFirstNontrivialSCC( [&](const std::function& Handler) { ForEachNode(Handler); }, ForEachInNode, ForEachOutNode); } template std::unique_ptr> Graph::FindFirstNontrivialSCC() const { return FindFirstNontrivialSCC( [&](const std::function& Handler) { ForEachNode(Handler); }, &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge); } template std::unique_ptr> Graph::FindFirstNontrivialSCC( const std::function&)>& ForEachStart, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode) const { std::stack stack; FfsForEachNode(ForEachStart, ForEachOutNode, [&](NodeType* node) { stack.push(node); }); HashSet visited; auto ForEachUnvisitedInNode = [&](NodeType* node, const std::function& Handler) { ForEachInNode(node, [&](NodeType* in_node) { if (visited.find(in_node) == visited.end()) { Handler(in_node); } }); }; while (stack.empty() == false) { NodeType* cur_node = stack.top(); stack.pop(); auto ret = std::make_unique>(); DfsForEachNode({cur_node}, ForEachUnvisitedInNode, [&](NodeType* node) { CHECK(ret->insert(node).second); }); for (const auto& node : *ret) { visited.insert(node); } if 
(ret->size() > 1) { return ret; } } return std::unique_ptr>(); } template void Graph::TopoForEachNodeDynamic( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const { CHECK_JUST(TopoForEachNodeDynamicWithErrorCaptured(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { Handler(node); return Maybe::Ok(); })); } template void Graph::TopoForEachNode( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const { CHECK_JUST( TopoForEachNodeWithErrorCaptured(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { Handler(node); return Maybe::Ok(); })); } template void Graph::TopoForEachNode( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const { CHECK_JUST(TopoForEachNodeWithErrorCaptured(ForEachInNode, ForEachOutNode, [&](NodeType* node) { Handler(node); return Maybe::Ok(); })); } template Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const { HashMap has_queued; std::queue queue; for (NodeType* start : starts) { queue.push(start); has_queued[start] = true; ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << "not a source"; }); } while (!queue.empty()) { NodeType* cur_node = queue.front(); queue.pop(); JUST(Handler(cur_node)); ForEachOutNode(cur_node, [&](NodeType* out) { bool is_ready = true; ForEachInNode(out, [&](NodeType* in) { if (is_ready && !has_queued[in]) { is_ready = false; } }); if (is_ready && !has_queued[out]) { queue.push(out); has_queued[out] = true; } }); } return Maybe::Ok(); } template Maybe Graph::TopoForEachNodeWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const { HashMap counter_in; std::queue queue; for (NodeType* start : starts) { queue.push(start); counter_in[start] = 0; ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << "not a source"; }); } while (!queue.empty()) { NodeType* cur_node = queue.front(); queue.pop(); JUST(Handler(cur_node)); ForEachOutNode(cur_node, [&](NodeType* out) { auto it = counter_in.find(out); // Move the initialization here if (it == counter_in.end()) { int32_t count = 0; ForEachInNode(out, [&](NodeType* out_in) { count++; }); counter_in[out] = count; it = counter_in.find(out); } it->second--; if (it->second == 0) { queue.push(out); } }); } return Maybe::Ok(); } template Maybe Graph::TopoForEachNodeWithErrorCaptured( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const { HashMap counter_in; std::queue queue; ForEachNode([&](NodeType* node) { int32_t count = 0; ForEachInNode(node, [&](NodeType*) { count++; }); counter_in[node] = count; if (count == 0) { queue.push(node); } }); while (!queue.empty()) { NodeType* cur_node = queue.front(); queue.pop(); JUST(Handler(cur_node)); ForEachOutNode(cur_node, [&](NodeType* out) { --counter_in[out]; if (counter_in[out] == 0) { queue.push(out); } }); } return Maybe::Ok(); } template void Graph::DfsTopoForEachNodeSortByDistanceToSink( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const { HashMap node2distance_to_sink; { std::list nodes; 
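// The two passes below first walk the graph in topological order to enumerate the nodes and
// locate the sinks, then walk in reverse topological order so that each node's distance becomes
// 1 + the maximum distance of its out-neighbors (sinks end up with distance 0). The DFS later
// visits out-edges sorted by this distance in descending order, because the traversal uses a stack.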
TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { nodes.emplace_back(node); }); std::list sinks; for (NodeType* node : nodes) { bool is_sink = true; ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; }); if (is_sink) { sinks.emplace_back(node); } } TopoForEachNode(ForEachOutNode, ForEachInNode, [&](NodeType* node) { int64_t distance_to_sink = -1; ForEachOutNode(node, [&](NodeType* out_node) { distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]); }); node2distance_to_sink[node] = distance_to_sink + 1; }); } auto ForEachOutNodeSortedByDistanceToSink = [&](NodeType* node, const std::function& Handler) { std::vector out_nodes; ForEachOutNode(node, [&](NodeType* out_node) { out_nodes.emplace_back(out_node); }); std::sort(out_nodes.begin(), out_nodes.end(), [&](NodeType* lhs, NodeType* rhs) { // DfsTopoForEachNode use stack, so sort desc return node2distance_to_sink.at(lhs) > node2distance_to_sink.at(rhs); }); for (NodeType* out_node : out_nodes) { Handler(out_node); } }; DfsTopoForEachNode(starts, ForEachInNode, ForEachOutNodeSortedByDistanceToSink, Handler); } template void Graph::DfsTopoForEachNode( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const { HashMap be_visited; std::stack stack; for (NodeType* start : starts) { stack.push(start); ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << "not a source"; }); } while (!stack.empty()) { NodeType* cur_node = stack.top(); stack.pop(); Handler(cur_node); be_visited[cur_node] = true; ForEachOutNode(cur_node, [&](NodeType* out) { bool is_ready = true; ForEachInNode(out, [&](NodeType* in) { if (is_ready && !be_visited[in]) { is_ready = false; } }); if (is_ready && !be_visited[out]) { stack.push(out); } }); } } template std::function Graph::MakePredicatorIsReachable() const { return MakePredicatorIsReachable(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge); } template std::function Graph::MakePredicatorIsReachable( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode) const { static constexpr int64_t BITSET_SIZE = 512; // size of cache line class BitSet { public: BitSet() = default; ~BitSet() = default; void Insert(int64_t k) { bitset_vec_.at(k / BITSET_SIZE).set(k % BITSET_SIZE, true); } bool Contains(int64_t k) { return bitset_vec_.at(k / BITSET_SIZE).test(k % BITSET_SIZE); } void Merge(const BitSet& other) { CHECK_EQ(bitset_vec_.size(), other.bitset_vec_.size()); for (int64_t i = 0; i < bitset_vec_.size(); ++i) { bitset_vec_.at(i) |= other.bitset_vec_.at(i); } } void Resize(size_t size) { const int64_t bitset_vec_size = RoundUp(size, BITSET_SIZE) / BITSET_SIZE; bitset_vec_.resize(bitset_vec_size); } private: using bitset_vec = std::vector>; bitset_vec bitset_vec_; }; using NodePtr2Id = HashMap; using Id2Ancestor = std::vector; std::shared_ptr node2id(new NodePtr2Id); std::shared_ptr id2ancestor(new Id2Ancestor(node_num())); int64_t id = 0; node2id->reserve(node_num()); TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { node2id->emplace(node, id); id2ancestor->at(id).Resize(node_num()); id += 1; }); TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { const int64_t node_id = node2id->at(node); auto& ancestor_bitset_vec = id2ancestor->at(node_id); ForEachInNode(node, [&](NodeType* in_node) { const int64_t in_node_id = node2id->at(in_node); ancestor_bitset_vec.Insert(in_node_id); 
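// Merging the in-node's own ancestor set right after inserting its id makes the bitset
// transitive: by the time a node is visited in topological order, its bitset already contains
// every ancestor id. The returned predicate then answers IsReachable(src, dst) with a single
// bit test on dst's bitset. BITSET_SIZE is 512 bits, i.e. one 64-byte cache line per chunk.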
ancestor_bitset_vec.Merge(id2ancestor->at(in_node_id)); }); }); return [id2ancestor, node2id](const NodeType* src, const NodeType* dst) -> bool { const int64_t dst_id = node2id->at(dst); return id2ancestor->at(dst_id).Contains(node2id->at(src)); }; } template void Graph::ForEachConnectedComponent( const std::function&)>& Handler) const { ForEachConnectedComponent( [&](const std::function& Handler) { ForEachNode(Handler); }, &NodeType::ForEachNodeOnInOutEdge, Handler); } template void Graph::ForEachConnectedComponent( const std::function&)>& ForEachConnected, const std::function&)>& Handler) const { ForEachConnectedComponent( [&](const std::function& Handler) { ForEachNode(Handler); }, ForEachConnected, Handler); } template void Graph::ForEachConnectedComponent( const std::function&)>& ForEachNodeAsStart, const std::function&)>& ForEachConnected, const std::function&)>& Handler) const { HashMap node2component_id; int32_t cur_component_id = 0; ForEachNodeAsStart([&](NodeType* start) { if (node2component_id.find(start) != node2component_id.end()) { return; } ++cur_component_id; BfsForEachNode({start}, ForEachConnected, [&](NodeType* node) { CHECK(node2component_id.emplace(node, cur_component_id).second); }); }); HashMap> component_id2nodes; for (const auto& pair : node2component_id) { component_id2nodes[pair.second].insert(pair.first); } for (const auto& pair : component_id2nodes) { Handler(pair.second); } } } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_GRAPH_H_ ================================================ FILE: oneflow/core/graph/inplace_lbi_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/inplace_lbi_graph.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { namespace { bool IsSourceNode(const Operator& op) { const auto& op_conf = op.op_conf(); if (op_conf.has_user_conf() && op_conf.user_conf().input().size() == 0 && op_conf.user_conf().output().size() == 1) { return true; } if (op_conf.has_user_conf() && op_conf.user_conf().op_type_name() == "mutable_cast_once") { return true; } if (op_conf.has_variable_conf()) { return true; } if (op_conf.has_distribute_clone_conf() && op_conf.distribute_clone_conf().is_variable_ref()) { return true; } if (op_conf.has_distribute_split_conf() && op_conf.distribute_split_conf().is_variable_ref()) { return true; } return false; } void CheckSubGraph(const HashSet& nodes) { size_t source_op_node_cnt = 0; size_t updt_node_cnt = 0; size_t source_cnt = 0; for (const auto* node : nodes) { if (node->in_edges().empty()) { CHECK_EQ(++source_cnt, 1); } if (dynamic_cast(node) != nullptr) { CHECK_EQ(++source_op_node_cnt, 1); CHECK(node->in_edges().empty()); } if (dynamic_cast(node) != nullptr) { CHECK_EQ(++updt_node_cnt, 1); CHECK(dynamic_cast(node->SoleInEdge()->src_node()) != nullptr) << "UpdateInplaceLbiNode-lbi: " << PbMessage2TxtString(node->lbi()) << ", src_node.in_edges_size: " << node->SoleInEdge()->src_node()->in_edges().size() << ", SoleInNode: " << typeid(node->SoleInEdge()->src_node()).name() << ", " << PbMessage2TxtString(node->SoleInEdge()->src_node()->lbi()); } } } const InplaceLbiNode* GetRoot(const HashSet& nodes, const std::function& IsValidEdge) { const InplaceLbiNode* root = nullptr; for (const InplaceLbiNode* node : nodes) { if (node->GetValidInEdge(IsValidEdge) == nullptr) { CHECK_ISNULL(root); root = node; } } return root; } const InplaceLbiNode* FindSoleIsMutableIbnConsumer(const SourceOpInplaceLbiNode* node) { const InplaceLbiNode* ret = nullptr; for (const InplaceLbiEdge* edge : node->out_edges()) { if (dynamic_cast(edge->dst_node()) != nullptr) { CHECK_ISNULL(ret); ret = edge->dst_node(); } } return ret; } InplaceLbiNode* CreateNode(const LogicalBlobId& lbi, const std::function& Op4OpName) { const Operator& op = *Op4OpName(lbi.op_name()); if (IsSourceNode(op)) { return new SourceOpInplaceLbiNode(lbi); } else if (std::find_if(op.output_bns().begin(), op.output_bns().end(), [&](const std::string& obn) { return op.BnInOp2Lbi(obn) == lbi; }) != op.output_bns().end()) { return new NormalInplaceLbiNode(lbi); } else { return new UpdateInplaceLbiNode(lbi); } } void GetUnconnectedNodes(const HashSet& nodes, const std::function& IsValidEdge, HashSet* cur_disabled_nodes) { for (const InplaceLbiNode* node : nodes) { size_t cnt = 0; for (const InplaceLbiEdge* edge : node->in_edges()) { cnt += IsValidEdge(edge); } for (const InplaceLbiEdge* edge : node->out_edges()) { cnt += IsValidEdge(edge); } if (cnt == 0) { CHECK(cur_disabled_nodes->emplace(node).second); } } } const InplaceLbiNode* GetFirstDiffNode(const std::vector& lhs, const std::vector& rhs) { FOR_RANGE(int32_t, i, 0, std::min(lhs.size(), rhs.size())) { if (lhs.at(i) != rhs.at(i)) { return lhs.at(i); } } return nullptr; }; std::function&)> GetForEachValidInNode(const HashSet* nodes, std::function IsValidEdge) { return [nodes, IsValidEdge](const InplaceLbiNode* node, const std::function& Handler) { const InplaceLbiEdge* in_edge = node->GetValidInEdge(IsValidEdge); if (in_edge == nullptr) { return; } if (nodes->find(in_edge->src_node()) != nodes->end()) { Handler(in_edge->src_node()); } }; } std::function&)> GetForEachValidOutNode(const 
HashSet* nodes, std::function IsValidEdge) { return [nodes, IsValidEdge](const InplaceLbiNode* node, const std::function& Handler) { node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) { if (nodes->find(out_node) != nodes->end()) { Handler(out_node); } }); }; } bool IsOtherIbnBoundToOneOfLbis(const HashSet& lbis, const InplaceLbiEdge* edge) { const Operator& op = edge->op(); for (const std::string& ibn : op.input_bns()) { if (ibn != edge->ibn() && lbis.find(op.BnInOp2Lbi(ibn)) != lbis.end()) { return true; } } return false; } void RemoveUnconnectedNodes(HashSet* nodes, const std::function& IsValidEdge) { HashSet cur_disabled_nodes; GetUnconnectedNodes(*nodes, IsValidEdge, &cur_disabled_nodes); for (const auto* node : cur_disabled_nodes) { nodes->erase(node); } } } // namespace const InplaceLbiEdge* InplaceLbiNode::GetValidInEdge( const std::function& IsValidEdge) const { if (!in_edges().empty() && IsValidEdge(SoleInEdge())) { return SoleInEdge(); } return nullptr; } const InplaceLbiEdge* InplaceLbiNode::GetSoleValidInEdge( const std::function& IsValidEdge) const { const auto* edge = GetValidInEdge(IsValidEdge); CHECK_NOTNULL(edge); return edge; } void InplaceLbiNode::ForEachNodeOnValidOutEdge( const std::function& IsValidEdge, const std::function& Handler) const { for (const auto* edge : out_edges()) { if (IsValidEdge(edge)) { Handler(edge->dst_node()); } } } bool InplaceLbiNode::IsMutRef(const std::function& IsValidEdge) const { UNIMPLEMENTED(); } bool InplaceLbiNode::IsConstRef( const std::function& IsValidEdge) const { return !IsMutRef(IsValidEdge); } bool NormalInplaceLbiNode::IsMutRef( const std::function& IsValidEdge) const { const InplaceLbiEdge* in_edge = GetValidInEdge(IsValidEdge); return in_edge != nullptr && in_edge->IsMutRef(); } bool InplaceLbiEdge::IsMutRef() const { CHECK_NOTNULL(dynamic_cast(dst_node())); return is_mut_ref_; } std::function InplaceLbiGraph::MakeMutFindOrCreateNode( std::function Op4OpName) { auto lbi2node = std::make_shared>(); return [this, lbi2node, Op4OpName](const LogicalBlobId& lbi) -> InplaceLbiNode* { auto node_it = lbi2node->find(lbi); if (node_it == lbi2node->end()) { auto* node = CreateNode(lbi, Op4OpName); AddAllocatedNode(node); node_it = lbi2node->emplace(lbi, node).first; } return node_it->second; }; } void InplaceLbiGraph::Init(const InplaceObasInfo& obas_info, const std::function& Op4OpName) { auto FindOrCreateNode = MakeMutFindOrCreateNode(Op4OpName); auto AddEdge = [&](const Operator& op, const LogicalBlobId& lbi, const std::string& ibn, const std::string& obn, bool is_mut) { auto* edge = new InplaceLbiEdge(&op, ibn, obn, is_mut); AddAllocatedEdge(edge); Connect(FindOrCreateNode(op.BnInOp2Lbi(ibn)), edge, FindOrCreateNode(lbi)); }; auto BuildNodeAndEdge4InplacePairs = [&](const OpBlobArgPairs& pairs, bool is_mut) { for (const auto& pair : pairs.pair()) { CHECK_EQ(pair.first().op_name(), pair.second().op_name()); const Operator& op = *Op4OpName(pair.first().op_name()); std::string ibn = pair.first().bn_in_op(); std::string obn = pair.second().bn_in_op(); LogicalBlobId lbi = op.BnInOp2Lbi(obn); CHECK(std::find(op.input_bns().begin(), op.input_bns().end(), ibn) != op.input_bns().end()); CHECK(std::find(op.output_bns().begin(), op.output_bns().end(), obn) != op.output_bns().end()); AddEdge(op, lbi, ibn, obn, is_mut); } }; for (const auto& oba : obas_info.mut_in_obas.oba()) { const Operator& op = *Op4OpName(oba.op_name()); std::string ibn = oba.bn_in_op(); std::string obn = ibn + "_updated"; LogicalBlobId lbi; 
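// For each mutable-input oba, a synthetic lbi named "<ibn>_updated" is fabricated on the same op,
// so the in-place update can be modelled as an extra edge that ends at an UpdateInplaceLbiNode.
// The CHECKs below assert that ibn really is one of the op's inputs and that no genuine output
// already produces a blob with this synthetic name.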
lbi.set_op_name(op.op_name()); lbi.set_blob_name(obn); CHECK(std::find(op.input_bns().begin(), op.input_bns().end(), ibn) != op.input_bns().end()); CHECK(std::find_if(op.output_bns().begin(), op.output_bns().end(), [&](const std::string& obn) { return op.BnInOp2Lbi(obn) == lbi; }) == op.output_bns().end()); AddEdge(op, lbi, ibn, obn, true); } BuildNodeAndEdge4InplacePairs(obas_info.mut_inplace_oba_pairs, true); BuildNodeAndEdge4InplacePairs(obas_info.con_inplace_oba_pairs, false); ForEachNode([](const InplaceLbiNode* node) { CHECK_LE(node->in_edges().size(), 1); }); CHECK(!FindFirstNontrivialSCC()); } void InplaceLbiGraph::ComputeSafeInplaceObns( InplaceObasInfo* obas_info, const std::function& IsReachableFromLbiToOpName) const { ComputeSafeInplaceEdges(IsReachableFromLbiToOpName, [&](const InplaceLbiEdge* edge) { CHECK_NOTNULL(dynamic_cast(edge->dst_node())); if (edge->IsMutRef()) { auto* pair = obas_info->mut_inplace_oba_pairs.mutable_pair()->Add(); *pair->mutable_first() = GenOpBlobArg(edge->op().op_name(), edge->ibn()); *pair->mutable_second() = GenOpBlobArg(edge->op().op_name(), edge->obn()); } else { auto* pair = obas_info->con_inplace_oba_pairs.mutable_pair()->Add(); *pair->mutable_first() = GenOpBlobArg(edge->op().op_name(), edge->ibn()); *pair->mutable_second() = GenOpBlobArg(edge->op().op_name(), edge->obn()); } }); } void InplaceLbiGraph::ComputeSafeInplaceEdges( const std::function& IsReachableFromLbiToOpName, const std::function& Handler) const { ForEachConnectedComponent([&](const HashSet& nodes) { ComputeSafeInplaceEdges(nodes, IsReachableFromLbiToOpName, Handler); }); } void InplaceLbiGraph::ForEachSafeInplaceEdgeInSourceOpSubTree( const HashSet& nodes, const std::function& IsReachableFromLbiToOpName, const std::function& Handler, HashSet* disabled_edges) const { disabled_edges->clear(); auto IsValidEdge = [&](const InplaceLbiEdge* edge) { return disabled_edges->find(edge) == disabled_edges->end(); }; const InplaceLbiNode* root = GetRoot(nodes, [](const InplaceLbiEdge*) { return true; }); const auto* source_op_root = dynamic_cast(root); if (source_op_root != nullptr) { const InplaceLbiNode* updt_node = FindSoleIsMutableIbnConsumer(source_op_root); if (updt_node != nullptr) { HashSet cur_disabled_edges; FixConstRefOrMutRefConflictsToUpdtNode(nodes, IsReachableFromLbiToOpName, &cur_disabled_edges); disabled_edges->insert(cur_disabled_edges.begin(), cur_disabled_edges.end()); } { HashSet cur_disabled_edges; FixMutRefConflictsFromSourceOpNode(source_op_root, IsValidEdge, &cur_disabled_edges); disabled_edges->insert(cur_disabled_edges.begin(), cur_disabled_edges.end()); } { // disconnect edges in the subtree containning `root` HashSet cur_disabled_edges; auto ForEachNext = GetForEachValidOutNode(&nodes, IsValidEdge); BfsForEachNode({root}, ForEachNext, [&](const InplaceLbiNode* node) { const InplaceLbiEdge* in_edge = node->GetValidInEdge(IsValidEdge); if (in_edge != nullptr) { CHECK(cur_disabled_edges.emplace(in_edge).second); } if (dynamic_cast(node) != nullptr) { CHECK_NOTNULL(in_edge); if (node->IsConstRef(IsValidEdge)) { Handler(in_edge); } } }); disabled_edges->insert(cur_disabled_edges.begin(), cur_disabled_edges.end()); } } } void InplaceLbiGraph::ComputeSafeInplaceEdges( const HashSet& nodes, const std::function& IsReachableFromLbiToOpName, const std::function& Handler) const { CheckSubGraph(nodes); HashSet remainder_nodes(nodes); HashSet disabled_edges; { // compute safe inplace edges in the subtree containning SourceOpInplaceLbiNode as root HashSet cur_disabled_edges; 
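// ComputeSafeInplaceEdges first resolves the subtree rooted at the SourceOpInplaceLbiNode (the
// call just below), then loops over the remaining nodes: each round walks every tree, disables
// the first inter-op, const-ref and intra-op mut-ref conflict edge it finds, accepts every edge
// of a conflict-free tree as a safe inplace edge, and prunes nodes left without valid edges;
// dead_loop_check bounds the number of rounds.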
ForEachSafeInplaceEdgeInSourceOpSubTree(remainder_nodes, IsReachableFromLbiToOpName, Handler, &cur_disabled_edges); disabled_edges.insert(cur_disabled_edges.begin(), cur_disabled_edges.end()); } auto IsValidEdge = [&](const InplaceLbiEdge* edge) { return remainder_nodes.find(edge->src_node()) != remainder_nodes.end() && remainder_nodes.find(edge->dst_node()) != remainder_nodes.end() && disabled_edges.find(edge) == disabled_edges.end(); }; RemoveUnconnectedNodes(&remainder_nodes, IsValidEdge); size_t dead_loop_check = remainder_nodes.size(); while (!remainder_nodes.empty()) { ForEachTree(remainder_nodes, IsValidEdge, [&](const HashSet& nodes) { const InplaceLbiEdge* cur_disabled_edge = FindFirstInterOpRefConflictMutRefEdge(nodes, IsValidEdge, IsReachableFromLbiToOpName); if (cur_disabled_edge != nullptr) { disabled_edges.insert(cur_disabled_edge); } }); ForEachTree(remainder_nodes, IsValidEdge, [&](const HashSet& nodes) { const InplaceLbiEdge* cur_disabled_edge = FindFirstConstRefConflictMutRefEdge(nodes, IsValidEdge, IsReachableFromLbiToOpName); if (cur_disabled_edge != nullptr) { disabled_edges.insert(cur_disabled_edge); } }); ForEachTree(remainder_nodes, IsValidEdge, [&](const HashSet& nodes) { const InplaceLbiEdge* cur_disabled_edge = FindFirstIntraOpRefConflictMutRefEdge(nodes, IsValidEdge); if (cur_disabled_edge != nullptr) { disabled_edges.insert(cur_disabled_edge); } }); { HashSet cur_safe_inplace_obn_edges; GetSafeInplaceObnEdges(remainder_nodes, IsValidEdge, IsReachableFromLbiToOpName, &cur_safe_inplace_obn_edges); for (const auto* edge : cur_safe_inplace_obn_edges) { Handler(edge); } disabled_edges.insert(cur_safe_inplace_obn_edges.begin(), cur_safe_inplace_obn_edges.end()); } RemoveUnconnectedNodes(&remainder_nodes, IsValidEdge); CHECK_GE(--dead_loop_check, 0); } } void InplaceLbiGraph::FindAllEdges(const HashSet& nodes, const std::function& IsValidEdge, HashSet* cur_disabled_edges) const { for (const auto* node : nodes) { node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) { CHECK(cur_disabled_edges->emplace(out_node->GetSoleValidInEdge(IsValidEdge)).second); }); } } const InplaceLbiEdge* InplaceLbiGraph::FindFirstIntraOpRefConflictMutRefEdge( const HashSet& nodes, const std::function& IsValidEdge) const { const InplaceLbiEdge* ret = nullptr; HashSet lbis; for (const auto* node : nodes) { CHECK(lbis.insert(node->lbi()).second); } const auto* root = GetRoot(nodes, IsValidEdge); auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (ret != nullptr) { return; } if (node->IsMutRef(IsValidEdge) && IsOtherIbnBoundToOneOfLbis(lbis, node->SoleInEdge())) { ret = node->SoleInEdge(); } }); return ret; } bool InplaceLbiGraph::IsConstRefConflictMutRefNode( const InplaceLbiNode* mut_ref_node, const HashSet& nodes, const std::function& IsValidEdge, const std::function& IsLbiAllConsumerReachableToOpName) const { CHECK(mut_ref_node->IsMutRef(IsValidEdge)); auto ForEachNext = [&](const InplaceLbiNode* node, const std::function& Handler) { node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) { if (out_node != mut_ref_node) { Handler(out_node); } }); }; bool conflict = false; const auto& op_name = mut_ref_node->lbi().op_name(); BfsForEachNode({GetRoot(nodes, IsValidEdge)}, ForEachNext, [&](const InplaceLbiNode* node) { conflict = conflict || 
!IsLbiAllConsumerReachableToOpName(node->lbi(), op_name); }); return conflict; } const InplaceLbiEdge* InplaceLbiGraph::FindFirstConstRefConflictMutRefEdge( const HashSet& nodes, const std::function& IsValidEdge, const std::function& IsLbiAllConsumerReachableToOpName) const { const InplaceLbiNode* root = GetRoot(nodes, IsValidEdge); auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); const InplaceLbiEdge* ret = nullptr; TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (ret != nullptr) { return; } if (node->IsMutRef(IsValidEdge) && IsConstRefConflictMutRefNode(node, nodes, IsValidEdge, IsLbiAllConsumerReachableToOpName)) { ret = node->GetValidInEdge(IsValidEdge); } }); return ret; } const InplaceLbiEdge* InplaceLbiGraph::FindFirstInterOpRefConflictMutRefEdge( const HashSet& nodes, const std::function& IsValidEdge, const std::function& IsLbiAllConsumerReachableToOpName) const { HashSet mut_ref_nodes; HashMap> node2mut_ref_ancestors; { const InplaceLbiNode* root = GetRoot(nodes, IsValidEdge); auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (node->IsMutRef(IsValidEdge)) { mut_ref_nodes.insert(node); } size_t in_edges_size_check = 0; ForEachInNode(node, [&](const InplaceLbiNode* in_node) { node2mut_ref_ancestors[node] = node2mut_ref_ancestors[in_node]; if (in_node->IsMutRef(IsValidEdge)) { node2mut_ref_ancestors[node].emplace_back(in_node); } CHECK_EQ(++in_edges_size_check, 1); }); }); } std::vector last_mut_ref_nodes; { HashMap mut_ref_node2descendents_size; for (const InplaceLbiNode* descendent : mut_ref_nodes) { for (const InplaceLbiNode* ancestor : node2mut_ref_ancestors.at(descendent)) { ++mut_ref_node2descendents_size[ancestor]; } } for (const InplaceLbiNode* node : mut_ref_nodes) { if (mut_ref_node2descendents_size[node] == 0) { last_mut_ref_nodes.emplace_back(node); } } } if (last_mut_ref_nodes.size() <= 1) { return nullptr; } const InplaceLbiNode* first_diff_node = nullptr; { const auto& first = node2mut_ref_ancestors.at(last_mut_ref_nodes.at(0)); const auto& second = node2mut_ref_ancestors.at(last_mut_ref_nodes.at(1)); first_diff_node = GetFirstDiffNode(first, second); if (first_diff_node == nullptr) { first_diff_node = last_mut_ref_nodes.at(first.size() < second.size() ? 
0 : 1); } } return first_diff_node->GetSoleValidInEdge(IsValidEdge); } void InplaceLbiGraph::GetSafeInplaceObnEdges( const HashSet& nodes, const std::function& IsValidEdge, const std::function& IsLbiAllConsumerReachableToOpName, HashSet* cur_disabled_edges) const { ForEachTree(nodes, IsValidEdge, [&](const HashSet& nodes) { // no inter-op reference conflicts const InplaceLbiEdge* inter_op_conflict_ref_edge = FindFirstInterOpRefConflictMutRefEdge( nodes, IsValidEdge, IsLbiAllConsumerReachableToOpName); // mutable reference always goes after const reference const InplaceLbiEdge* const_ref_conflict_ref_edge = FindFirstConstRefConflictMutRefEdge(nodes, IsValidEdge, IsLbiAllConsumerReachableToOpName); // no intra-op reference conflicts const InplaceLbiEdge* intra_op_conflict_ref_edge = FindFirstIntraOpRefConflictMutRefEdge(nodes, IsValidEdge); if (const_ref_conflict_ref_edge == nullptr && intra_op_conflict_ref_edge == nullptr && inter_op_conflict_ref_edge == nullptr) { FindAllEdges(nodes, IsValidEdge, cur_disabled_edges); } }); } void InplaceLbiGraph::ForEachTree( const HashSet& nodes, const std::function& IsValidEdge, const std::function&)>& Handler) const { auto ForEachNode = [&](const std::function& Handler) { for (const auto* node : nodes) { Handler(node); } }; auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); auto ForEachConnected = [&](const InplaceLbiNode* node, const std::function& Handler) { ForEachInNode(node, Handler); ForEachOutNode(node, Handler); }; ForEachConnectedComponent(ForEachNode, ForEachConnected, Handler); } void InplaceLbiGraph::FixConstRefOrMutRefConflictsToUpdtNode( const HashSet& nodes, const std::function& IsLbiAllConsumerReachableToOpName, HashSet* cur_disabled_edges) const { auto IsValidEdge = [](const InplaceLbiEdge*) { return true; }; const InplaceLbiNode* updt_node = nullptr; HashSet safe_const_ref_nodes; const InplaceLbiNode* root = GetRoot(nodes, IsValidEdge); CHECK_NOTNULL(root); { const auto* source_op_root = dynamic_cast(root); CHECK_NOTNULL(source_op_root); updt_node = FindSoleIsMutableIbnConsumer(source_op_root); CHECK_NOTNULL(updt_node); auto ForEachNext = [&](const InplaceLbiNode* node, const std::function& Handler) { node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) { if (dynamic_cast(out_node) == nullptr) { return; } if (out_node->IsMutRef(IsValidEdge)) { return; } if (!IsLbiAllConsumerReachableToOpName(out_node->lbi(), updt_node->lbi().op_name())) { return; } Handler(out_node); }); }; BfsForEachNode({root}, ForEachNext, [&](const InplaceLbiNode* node) { if (node == root) { return; } CHECK(safe_const_ref_nodes.emplace(node).second); }); } for (const auto* node : safe_const_ref_nodes) { node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) { if (safe_const_ref_nodes.find(out_node) == safe_const_ref_nodes.end() && out_node != updt_node) { CHECK(nodes.find(out_node) != nodes.end()); CHECK(cur_disabled_edges->emplace(out_node->GetSoleValidInEdge(IsValidEdge)).second); } }); } // remove mutable inplace edges from root which are not end with model update node root->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) { const auto* node = dynamic_cast(out_node); if (node != nullptr && node->IsMutRef(IsValidEdge)) { CHECK(nodes.find(out_node) != nodes.end()); CHECK(cur_disabled_edges->emplace(node->GetSoleValidInEdge(IsValidEdge)).second); } }); } void 
InplaceLbiGraph::FixMutRefConflictsFromSourceOpNode( const SourceOpInplaceLbiNode* root, const std::function& IsValidEdge, HashSet* cur_disabled_edges) const { HashSet safe_const_ref_nodes; { auto ForEachNext = [&](const InplaceLbiNode* node, const std::function& Handler) { node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) { if (dynamic_cast(out_node) == nullptr) { Handler(out_node); } else if (out_node->IsConstRef(IsValidEdge)) { Handler(out_node); } else { // do nothing } }); }; BfsForEachNode({root}, ForEachNext, [&](const InplaceLbiNode* node) { if (dynamic_cast(node) != nullptr) { CHECK(safe_const_ref_nodes.emplace(node).second); } }); } for (const auto* node : safe_const_ref_nodes) { node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) { if (safe_const_ref_nodes.find(out_node) == safe_const_ref_nodes.end() && dynamic_cast(out_node) != nullptr && out_node->IsMutRef(IsValidEdge)) { CHECK(cur_disabled_edges->emplace(out_node->GetSoleValidInEdge(IsValidEdge)).second); } }); } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/inplace_lbi_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/graph/graph.h" #include "oneflow/core/register/op_blob_arg_info.h" namespace oneflow { class InplaceLbiEdge; class InplaceLbiNode : public Node { public: virtual ~InplaceLbiNode() = default; const LogicalBlobId& lbi() const { return lbi_; } const InplaceLbiEdge* GetValidInEdge( const std::function& IsValidEdge) const; const InplaceLbiEdge* GetSoleValidInEdge( const std::function& IsValidEdge) const; void ForEachNodeOnValidOutEdge(const std::function& IsValidEdge, const std::function& Handler) const; virtual bool IsMutRef(const std::function& IsValidEdge) const; bool IsConstRef(const std::function& IsValidEdge) const; std::string VisualStr() const override { return GenLogicalBlobName(lbi_); } protected: OF_DISALLOW_COPY_AND_MOVE(InplaceLbiNode); explicit InplaceLbiNode(const LogicalBlobId& lbi) : lbi_(lbi) {} private: LogicalBlobId lbi_; }; class NormalInplaceLbiNode final : public InplaceLbiNode { public: OF_DISALLOW_COPY_AND_MOVE(NormalInplaceLbiNode); explicit NormalInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {} ~NormalInplaceLbiNode() override = default; bool IsMutRef(const std::function& IsValidEdge) const override; }; class SourceOpInplaceLbiNode final : public InplaceLbiNode { public: OF_DISALLOW_COPY_AND_MOVE(SourceOpInplaceLbiNode); explicit SourceOpInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {} ~SourceOpInplaceLbiNode() = default; }; class UpdateInplaceLbiNode final : public InplaceLbiNode { public: OF_DISALLOW_COPY_AND_MOVE(UpdateInplaceLbiNode); explicit UpdateInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {} ~UpdateInplaceLbiNode() = default; }; class InplaceLbiEdge final : public Edge { public: OF_DISALLOW_COPY_AND_MOVE(InplaceLbiEdge); InplaceLbiEdge(const Operator* op, const std::string& ibn, const std::string& obn, bool is_mut_ref) : op_(op), ibn_(ibn), obn_(obn), is_mut_ref_(is_mut_ref) {} ~InplaceLbiEdge() = default; const Operator& op() const { return *op_; } const std::string& ibn() const { return ibn_; } const std::string& obn() const { return obn_; } bool IsMutRef() const; bool IsConstRef() const { return !IsMutRef(); } std::string VisualStr() const override { return std::string(op_->op_name() + "/" + ibn_ + ":" + obn_); } private: const Operator* op_; const std::string ibn_; const std::string obn_; const bool is_mut_ref_; }; class InplaceLbiGraph final : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(InplaceLbiGraph); InplaceLbiGraph(const InplaceObasInfo& obas_info, const std::function& Op4OpName) { Init(obas_info, Op4OpName); } ~InplaceLbiGraph() = default; const char* TypeName() const override { return "InplaceLbiGraph"; } void ComputeSafeInplaceObns(InplaceObasInfo* obas_info, const std::function& IsLbiAllConsumerReachableToOpName) const; private: void Init(const InplaceObasInfo& obas_info, const std::function& Op4OpName); std::function MakeMutFindOrCreateNode( std::function Op4OpName); void ComputeSafeInplaceEdges(const std::function& IsLbiAllConsumerReachableToOpName, const std::function& Handler) const; void ComputeSafeInplaceEdges(const HashSet& nodes, const std::function& IsLbiAllConsumerReachableToOpName, const std::function& Handler) const; void ForEachSafeInplaceEdgeInSourceOpSubTree( const HashSet& nodes, const std::function& IsLbiAllConsumerReachableToOpName, const std::function& Handler, HashSet* 
cur_disabled_edges) const; void GetSafeInplaceObnEdges(const HashSet& nodes, const std::function& IsValidEdge, const std::function& IsLbiAllConsumerReachableToOpName, HashSet* cur_disabled_edges) const; const InplaceLbiEdge* FindFirstConstRefConflictMutRefEdge( const HashSet& nodes, const std::function& IsValidEdge, const std::function& IsLbiAllConsumerReachableToOpName) const; const InplaceLbiEdge* FindFirstIntraOpRefConflictMutRefEdge( const HashSet& nodes, const std::function& IsValidEdge) const; const InplaceLbiEdge* FindFirstInterOpRefConflictMutRefEdge( const HashSet& nodes, const std::function& IsValidEdge, const std::function& IsLbiAllConsumerReachableToOpName) const; bool IsConstRefConflictMutRefNode( const InplaceLbiNode* mut_ref_node, const HashSet& nodes, const std::function& IsValidEdge, const std::function& IsLbiAllConsumerReachableToOpName) const; void FixConstRefOrMutRefConflictsToUpdtNode( const HashSet& nodes, const std::function& IsLbiAllConsumerReachableToOpName, HashSet* cur_disabled_edges) const; void FixMutRefConflictsFromSourceOpNode( const SourceOpInplaceLbiNode* root, const std::function& IsValidEdge, HashSet* cur_disabled_edges) const; void ForEachTree(const HashSet& nodes, const std::function& IsValidEdge, const std::function&)>& Handler) const; void FindAllEdges(const HashSet& nodes, const std::function& IsValidEdge, HashSet* cur_disabled_edges) const; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_ ================================================ FILE: oneflow/core/graph/inplace_regst_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/inplace_regst_graph.h" namespace oneflow { namespace { std::function MakeGetterRegstDesc4RegstDescId( const HashSet& regst_descs) { auto regst_desc_id2regst_desc = std::make_shared>(); for (const auto* regst_desc : regst_descs) { CHECK(regst_desc_id2regst_desc->emplace(regst_desc->regst_desc_id(), regst_desc).second); } return [regst_desc_id2regst_desc](int64_t regst_desc_id) -> const RegstDescProto* { auto it = regst_desc_id2regst_desc->find(regst_desc_id); return it == regst_desc_id2regst_desc->end() ? 
nullptr : it->second; }; } } // namespace InplaceRegstGraph::InplaceRegstGraph(const HashSet& regst_descs) { auto RegstDesc4RegstDescId = MakeGetterRegstDesc4RegstDescId(regst_descs); auto FindOrCreate = MakeMutFindOrCreateNode(); for (const RegstDescProto* regst_desc : regst_descs) { if (regst_desc->has_hint_inplace_consumed_regst_desc_id()) { const RegstDescProto* in_regst_desc = RegstDesc4RegstDescId(regst_desc->hint_inplace_consumed_regst_desc_id()); if (in_regst_desc != nullptr) { auto* edge = new InplaceRegstEdge(); AddAllocatedEdge(edge); Connect(FindOrCreate(in_regst_desc), edge, FindOrCreate(regst_desc)); } } } } std::function InplaceRegstGraph::MakeMutFindOrCreateNode() { auto regst_desc2node = std::make_shared>(); return [regst_desc2node, this](const RegstDescProto* regst_desc) -> InplaceRegstNode* { auto it = regst_desc2node->find(regst_desc); if (it == regst_desc2node->end()) { InplaceRegstNode* node = new InplaceRegstNode(regst_desc); AddAllocatedNode(node); it = regst_desc2node->emplace(regst_desc, node).first; } return it->second; }; } } // namespace oneflow ================================================ FILE: oneflow/core/graph/inplace_regst_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_INPLACE_REGST_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_INPLACE_REGST_GRAPH_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/register/register_desc.pb.h" #include "oneflow/core/graph/graph.h" namespace oneflow { class InplaceRegstEdge; class InplaceRegstNode final : public Node { public: OF_DISALLOW_COPY_AND_MOVE(InplaceRegstNode); explicit InplaceRegstNode(const RegstDescProto* regst_desc) : regst_desc_(regst_desc) {} ~InplaceRegstNode() = default; const RegstDescProto* regst_desc() const { return regst_desc_; } private: const RegstDescProto* regst_desc_; }; class InplaceRegstEdge final : public Edge { public: OF_DISALLOW_COPY_AND_MOVE(InplaceRegstEdge); InplaceRegstEdge() = default; ~InplaceRegstEdge() = default; }; class InplaceRegstGraph final : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(InplaceRegstGraph); explicit InplaceRegstGraph(const HashSet& regst_descs); private: std::function MakeMutFindOrCreateNode(); }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_INPLACE_REGST_GRAPH_H_ ================================================ FILE: oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/job/placement.pb.h" namespace oneflow { void NcclSendRecvBoxingTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& logical_shape, const DataType& data_type, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, const int64_t parallel_id, const ParallelDesc& parallel_desc, const bool has_input, const bool has_output, const std::string& stream_name) { set_machine_id(machine_id); set_thrd_id(thrd_id); set_lbi(lbi); logical_shape_ = logical_shape; src_nd_sbp_ = src_nd_sbp; dst_nd_sbp_ = dst_nd_sbp; src_parallel_conf_ = src_parallel_desc.parallel_conf(); dst_parallel_conf_ = dst_parallel_desc.parallel_conf(); parallel_conf_ = parallel_desc.parallel_conf(); parallel_ctx_.set_parallel_id(parallel_id); parallel_ctx_.set_parallel_num(parallel_desc.parallel_num()); has_input_ = has_input; has_output_ = has_output; data_type_ = data_type; stream_name_ = stream_name; } void NcclSendRecvBoxingTaskNode::ProduceAllRegstsAndBindEdges() { if (has_output_) { std::shared_ptr out_regst = ProduceRegst("out", true, 1, 1); this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); } ProduceRegst("tmp", true); } void NcclSendRecvBoxingTaskNode::ConsumeAllRegsts() { this->ForEachInDataEdge( [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); } void NcclSendRecvBoxingTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); OperatorConf op_conf; op_conf.set_name("System-Nccl-Send-Recv-Boxing-" + NewUniqueId()); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); op_conf.set_stream_name_hint(stream_name_); auto* nccl_send_recv_boxing_conf = op_conf.mutable_nccl_send_recv_boxing_conf(); *nccl_send_recv_boxing_conf->mutable_lbi() = lbi(); logical_shape_.ToProto(nccl_send_recv_boxing_conf->mutable_logical_shape()); nccl_send_recv_boxing_conf->set_data_type(data_type_); *nccl_send_recv_boxing_conf->mutable_src_nd_sbp() = src_nd_sbp_; *nccl_send_recv_boxing_conf->mutable_dst_nd_sbp() = dst_nd_sbp_; *nccl_send_recv_boxing_conf->mutable_parallel_conf() = parallel_conf_; *nccl_send_recv_boxing_conf->mutable_src_parallel_conf() = src_parallel_conf_; *nccl_send_recv_boxing_conf->mutable_dst_parallel_conf() = dst_parallel_conf_; nccl_send_recv_boxing_conf->set_has_input(has_input_); nccl_send_recv_boxing_conf->set_has_output(has_output_); std::shared_ptr sole_op = CHECK_JUST(ConstructOp(op_conf)); node->mut_op() = sole_op; CHECK_JUST(sole_op->FillOpParallelDesc(parallel_conf_)); if (has_input_) { node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); } if (has_output_) { std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); } node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void NcclSendRecvBoxingTaskNode::InferProducedDataRegstTimeShape() { auto out_regst = GetProducedRegst("out"); if (out_regst != nullptr) { out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); } auto tmp_regst = GetProducedRegst("tmp"); 
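// A minimal sketch (illustrative only, not code from this file) of how a boxing pass
// would typically allocate and initialize this task node before the regst/exec-graph
// steps above run. The builder-side names (TaskGraph*, src_pd, dst_pd, merged_pd,
// parallel_id, the stream name) are assumptions for illustration.
void SketchBuildNcclSendRecvBoxingNode(TaskGraph* task_graph, const LogicalBlobId& lbi,
                                       const Shape& logical_shape, const NdSbp& src_nd_sbp,
                                       const NdSbp& dst_nd_sbp, const ParallelDesc& src_pd,
                                       const ParallelDesc& dst_pd, const ParallelDesc& merged_pd,
                                       int64_t machine_id, int64_t thrd_id, int64_t parallel_id) {
  auto* node = new NcclSendRecvBoxingTaskNode();
  task_graph->AddAllocatedNode(node);  // the graph takes ownership of allocated nodes
  node->Init(machine_id, thrd_id, lbi, logical_shape, DataType::kFloat, src_nd_sbp, dst_nd_sbp,
             src_pd, dst_pd, parallel_id, merged_pd,
             /*has_input=*/true, /*has_output=*/true, "DEFAULT_STREAM");
  // After Init, ProduceAllRegstsAndBindEdges / ConsumeAllRegsts / BuildExecGphAndRegst
  // are driven by the task graph as for any other TaskNode.
}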
tmp_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); } Maybe NcclSendRecvBoxingTaskNode::InitTransportTaskFromProto( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK_OR_RETURN(transport_task_proto.has_nccl_send_recv_boxing_task()) << "not a serialized NcclSendRecvBoxingTaskNode. debug string: " << transport_task_proto.DebugString(); const auto& proto = transport_task_proto.nccl_send_recv_boxing_task(); logical_shape_ = Shape(proto.logical_shape()); data_type_ = proto.data_type(); src_nd_sbp_ = proto.src_nd_sbp(); dst_nd_sbp_ = proto.dst_nd_sbp(); src_parallel_conf_ = proto.src_parallel_conf(); dst_parallel_conf_ = proto.dst_parallel_conf(); parallel_conf_ = proto.parallel_conf(); parallel_ctx_ = proto.parallel_ctx(); has_input_ = proto.has_input(); has_output_ = proto.has_output(); stream_name_ = proto.stream_name(); return Maybe::Ok(); } void NcclSendRecvBoxingTaskNode::ToTransportTaskProto( TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); auto* proto = transport_task_proto->mutable_nccl_send_recv_boxing_task(); logical_shape_.ToProto(proto->mutable_logical_shape()); proto->set_data_type(data_type_); *proto->mutable_src_nd_sbp() = src_nd_sbp_; *proto->mutable_dst_nd_sbp() = dst_nd_sbp_; *proto->mutable_src_parallel_conf() = src_parallel_conf_; *proto->mutable_dst_parallel_conf() = dst_parallel_conf_; *proto->mutable_parallel_conf() = parallel_conf_; *proto->mutable_parallel_ctx() = parallel_ctx_; proto->set_has_input(has_input_); proto->set_has_output(has_output_); proto->set_stream_name(stream_name_); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/nccl_send_recv_boxing_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

#ifndef ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_

#include "oneflow/core/graph/transport_task_node.h"
#include "oneflow/core/graph/boxing_task_graph.pb.h"
#include "oneflow/core/job/placement.pb.h"

namespace oneflow {

class NcclSendRecvBoxingTaskNode : public TransportTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingTaskNode);
  NcclSendRecvBoxingTaskNode() = default;
  ~NcclSendRecvBoxingTaskNode() override = default;

  void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
            const Shape& logical_shape, const DataType& data_type, const NdSbp& src_nd_sbp,
            const NdSbp& dst_nd_sbp, const ParallelDesc& src_parallel_desc,
            const ParallelDesc& dst_parallel_desc, const int64_t parallel_id,
            const ParallelDesc& parallel_desc, const bool has_input, const bool has_output,
            const std::string& stream_name);
  TaskType GetTaskType() const override { return TaskType::kNcclSendRecvBoxing; }
  const ParallelContext* parallel_ctx() const override { return &parallel_ctx_; }
  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,
                                         const TaskGraphRebuildCtx& ctx) override;
  void ToTransportTaskProto(TransportTaskProto*) const override;

 private:
  void BuildExecGphAndRegst() override;
  void ProduceAllRegstsAndBindEdges() override;
  void ConsumeAllRegsts() final;
  void InferProducedDataRegstTimeShape() final;

  Shape logical_shape_;
  DataType data_type_;
  NdSbp src_nd_sbp_;
  NdSbp dst_nd_sbp_;
  ParallelConf src_parallel_conf_;
  ParallelConf dst_parallel_conf_;
  ParallelConf parallel_conf_;
  ParallelContext parallel_ctx_;
  bool has_input_;
  bool has_output_;
  std::string stream_name_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_


================================================
FILE: oneflow/core/graph/node.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/node.h"

namespace oneflow {

int64_t NewNodeId() {
  static int64_t node_id = 0;
  return node_id++;
}

int64_t NewEdgeId() {
  static int64_t edge_id = 0;
  return edge_id++;
}

}  // namespace oneflow


================================================
FILE: oneflow/core/graph/node.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_CORE_GRAPH_NODE_H_ #define ONEFLOW_CORE_GRAPH_NODE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/id_manager.h" namespace oneflow { template void Connect(NodeType* src_node, EdgeType* edge, NodeType* dst_node) { CHECK(src_node->out_edges_.insert(edge).second); CHECK(dst_node->in_edges_.insert(edge).second); CHECK(edge->src_node_ == nullptr); CHECK(edge->dst_node_ == nullptr); edge->src_node_ = src_node; edge->dst_node_ = dst_node; } template void DisConnect(EdgeType* edge) { CHECK_EQ(edge->src_node_->out_edges_.erase(edge), 1); CHECK_EQ(edge->dst_node_->in_edges_.erase(edge), 1); edge->src_node_ = nullptr; edge->dst_node_ = nullptr; } int64_t NewNodeId(); int64_t NewEdgeId(); template class Edge { public: OF_DISALLOW_COPY_AND_MOVE(Edge); Edge() { edge_id_ = NewEdgeId(); src_node_ = nullptr; dst_node_ = nullptr; } virtual ~Edge() = default; int64_t edge_id() const { return edge_id_; } NodeType* src_node() const { return src_node_; } NodeType* dst_node() const { return dst_node_; } virtual std::string VisualStr() const { return ""; } private: friend void Connect(NodeType* src_node, EdgeType* edge, NodeType* dst_node); friend void DisConnect(EdgeType* edge); int64_t edge_id_; NodeType* src_node_; NodeType* dst_node_; }; template class Node { public: OF_DISALLOW_COPY_AND_MOVE(Node); Node() { node_id_ = NewNodeId(); } virtual ~Node() = default; int64_t node_id() const { return node_id_; } std::string node_id_str() const { return std::to_string(node_id_); } EdgeType* SoleInEdge() const { CHECK_EQ(in_edges_.size(), 1); return *(in_edges_.begin()); } EdgeType* SoleOutEdge() const { CHECK_EQ(out_edges_.size(), 1); return *(out_edges_.begin()); } const std::unordered_set& in_edges() const { return in_edges_; } const std::unordered_set& out_edges() const { return out_edges_; } void ForEachNodeOnInEdge(std::function Handler) const { for (EdgeType* edge : in_edges_) { Handler(edge->src_node()); } } void ForEachNodeOnOutEdge(std::function Handler) const { for (EdgeType* edge : out_edges_) { Handler(edge->dst_node()); } } void ForEachNodeOnInOutEdge(std::function Handler) const { ForEachNodeOnInEdge(Handler); ForEachNodeOnOutEdge(Handler); } Maybe ForEachInNode(std::function(NodeType*)> Handler) const { for (EdgeType* edge : in_edges_) { JUST(Handler(edge->src_node())); } return Maybe::Ok(); } Maybe ForEachOutNode(std::function(NodeType*)> Handler) const { for (EdgeType* edge : out_edges_) { JUST(Handler(edge->dst_node())); } return Maybe::Ok(); } Maybe ForEachInOutNode(std::function(NodeType*)> Handler) const { JUST(ForEachNodeOnInEdge(Handler)); JUST(ForEachNodeOnOutEdge(Handler)); return Maybe::Ok(); } void ForEachNodeOnSortedInEdge(std::function Handler) const { for (EdgeType* edge : sorted_in_edges_) { Handler(edge->src_node()); } } void ForEachNodeOnSortedOutEdge(std::function Handler) const { for (EdgeType* edge : sorted_out_edges_) { Handler(edge->dst_node()); } } void ForEachNodeOnSortedInOutEdge(std::function Handler) const { ForEachNodeOnSortedInEdge(Handler); ForEachNodeOnSortedOutEdge(Handler); } void DisconnectAllEdges() { for (EdgeType* edge : in_edges_) { DisConnect(edge); } for (EdgeType* edge : out_edges_) { DisConnect(edge); } } virtual std::string VisualStr() const { return ""; } void SortInOutEdges(std::function LessThan) { sorted_in_edges_.assign(in_edges_.begin(), in_edges_.end()); sorted_out_edges_.assign(out_edges_.begin(), out_edges_.end()); std::sort(sorted_in_edges_.begin(), sorted_in_edges_.end(), LessThan); 
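// A minimal sketch (not part of this header) of how a concrete graph declares a
// node/edge pair on top of these templates and wires them with Connect(). The Demo*
// names are hypothetical; real examples in this directory are OpNode/OpEdge and
// PlanTaskNode/PlanTaskEdge.
class DemoEdge;
class DemoNode final : public Node<DemoNode, DemoEdge> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DemoNode);
  DemoNode() = default;
  ~DemoNode() override = default;
};
class DemoEdge final : public Edge<DemoNode, DemoEdge> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DemoEdge);
  DemoEdge() = default;
  ~DemoEdge() override = default;
};
inline void DemoConnect() {
  // In a real graph these allocations are handed to Graph::AddAllocatedNode /
  // AddAllocatedEdge, which then own them.
  auto* src = new DemoNode();
  auto* dst = new DemoNode();
  auto* edge = new DemoEdge();
  Connect(src, edge, dst);  // registers edge in src->out_edges_ and dst->in_edges_
  CHECK_EQ(src->SoleOutEdge(), edge);
  CHECK_EQ(edge->dst_node(), dst);
  DisConnect(edge);  // detaches the edge from both endpoints again
}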
std::sort(sorted_out_edges_.begin(), sorted_out_edges_.end(), LessThan); } private: friend void Connect(NodeType* src_node, EdgeType* edge, NodeType* dst_node); friend void DisConnect(EdgeType* edge); int64_t node_id_; HashSet in_edges_; HashSet out_edges_; std::vector sorted_in_edges_; std::vector sorted_out_edges_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_NODE_H_ ================================================ FILE: oneflow/core/graph/normal_forward_compute_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ #include "oneflow/core/graph/compute_task_node.h" namespace oneflow { size_t RegstNum4Op(const Operator& sole_op); class NormalForwardCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(NormalForwardCompTaskNode); NormalForwardCompTaskNode() = default; ~NormalForwardCompTaskNode() = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kNormalForward; } private: void ProduceOutRegstByNameAndBlockNum(const std::string& name, size_t mem_block_num); void BuildExecGphAndRegst() override; void BuildExecGphStructAndBindInRegst(); void BuildOutRegst(); void BuildTmp7BufRegsts(); }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/op_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/local_sig_infer_hint.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/auto_parallel/algorithm_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/sbp_infer_util.h" namespace oneflow { bool OpEdge::NeedBoxing() const { if (src_node()->parallel_desc_sym() != dst_node()->parallel_desc_sym()) { return true; } if (src_node()->parallel_desc().parallel_num() == 1) { return false; } for (const auto& lbi : *lbis_) { Shape src_reduced_hierarchy; Shape dst_reduced_hierarchy; NdSbp src_reduced_nd_sbp; NdSbp dst_reduced_nd_sbp; InOutParallelDimReduce(*src_node()->parallel_desc().hierarchy(), *dst_node()->parallel_desc().hierarchy(), src_node()->NdSbp4Lbi(lbi), dst_node()->NdSbp4Lbi(lbi), &src_reduced_hierarchy, &dst_reduced_hierarchy, &src_reduced_nd_sbp, &dst_reduced_nd_sbp, src_node()->LogicalBlobDesc4Lbi(lbi).shape()); if (src_reduced_hierarchy != dst_reduced_hierarchy || src_reduced_nd_sbp != dst_reduced_nd_sbp) { // Not one to one return true; } } return false; } std::string OpEdge::VisualStr() const { std::string str; int32_t idx = 0; for (const LogicalBlobId& lbi : *lbis_) { if (idx++ > 0) { str += "\\n"; } str += lbi.blob_name() + ":"; str += src_node()->LogicalBlobDesc4Lbi(lbi).shape().ToString(); } return str; } const SbpParallel& OpNode::SbpParallel4BnInOp(const std::string& bn_in_op) const { return *CHECK_JUST(op().SbpParallel4BnInOp(bn_in_op)); } const SbpParallel& OpNode::SbpParallel4Lbi(const LogicalBlobId& lbi) const { auto it = lbi2nd_sbp_.find(lbi); CHECK(it != lbi2nd_sbp_.end()); CHECK_EQ(it->second.sbp_parallel_size(), 1); return it->second.sbp_parallel(0); } const NdSbp& OpNode::NdSbp4BnInOp(const std::string& bn_in_op) const { return *CHECK_JUST(op().NdSbp4BnInOp(bn_in_op)); } const NdSbp& OpNode::NdSbp4Lbi(const LogicalBlobId& lbi) const { auto it = lbi2nd_sbp_.find(lbi); CHECK(it != lbi2nd_sbp_.end()); return it->second; } OpNode::OpNode(Symbol parallel_desc, const OperatorConf& op_conf) : parallel_desc_(parallel_desc), op_(CHECK_JUST(ConstructOp(op_conf, parallel_desc->device_type()))), ibns_(op_->input_bns().begin(), op_->input_bns().end()) { CHECK_JUST(op_->FillOpParallelDesc(parallel_desc.shared_from_symbol())); } std::string OpNode::VisualStr() const { std::string str = op().op_name(); { for (int64_t machine_id : parallel_desc().sorted_machine_ids()) { const std::string dev_type = *CHECK_JUST(DeviceTag4DeviceType(parallel_desc().device_type())); std::string parallel_desc_str = std::to_string(machine_id) + ":" + dev_type + ":"; const auto& dev_phy_ids = parallel_desc().sorted_dev_phy_ids(machine_id); parallel_desc_str += std::to_string(dev_phy_ids.front()); if (dev_phy_ids.back() > dev_phy_ids.front()) { parallel_desc_str += "-" + std::to_string(dev_phy_ids.back()); } str += "\\n" + parallel_desc_str; } } auto GetTimeShapeStr = [&](const Shape& shape, const std::string& prefix) { std::string time_shape_str = prefix + ":"; time_shape_str += shape.ToString(); return time_shape_str; }; if (in_edges().empty() == false) { str += "\\n" + GetTimeShapeStr(*CHECK_JUST(op().GetInputBlobFastestTimeShape()), "in_blob_time_shape"); } str += "\\n" + GetTimeShapeStr(*CHECK_JUST(op().GetOpTimeShape()), "op_time_shape"); return str; } const BlobDesc& OpNode::LogicalBlobDesc4Lbi(const LogicalBlobId& lbi) const { const 
OpNode& producer = ProducerOpNode4Lbi(lbi); const int32_t index = CHECK_JUST(producer.op().GetOutputIndex(lbi)); const BlobDesc* blob_desc = CHECK_JUST(producer.op().GetLogicalBlobDescPtr4OutputIndex(index)); return *blob_desc; } const OpNode& OpNode::SrcNode4Ibn(const std::string& bn_in_op) const { return *MutSrcNode4Ibn(bn_in_op); } OpNode* OpNode::MutSrcNode4Ibn(const std::string& bn_in_op) const { const LogicalBlobId& lbi = op().BnInOp2Lbi(bn_in_op); CHECK(ibns_.find(bn_in_op) != ibns_.end()); return MutSrcNode4InputLbi(lbi); } const OpNode& OpNode::ProducerOpNode4Lbi(const LogicalBlobId& lbi) const { const OpNode* producer = MutSrcNode4InputLbi(lbi); if (producer == nullptr) { producer = this; } return *producer; } OpNode* OpNode::MutSrcNode4InputLbi(const LogicalBlobId& lbi) const { auto it = lbi2source_node_.find(lbi); if (it == lbi2source_node_.end()) { return nullptr; } else { return it->second; } } bool OpNode::IsTimeShapeIdentity() const { std::shared_ptr in_shape = CHECK_JUST(op().GetInputBlobFastestTimeShape()); if (!in_shape) { return true; } std::shared_ptr op_shape = CHECK_JUST(op().GetOpTimeShape()); return *in_shape == *op_shape; } void OpNode::InitLbi2SourceNode() { for (OpEdge* edge : in_edges()) { for (const LogicalBlobId& lbi : edge->lbis()) { CHECK(lbi2source_node_.emplace(lbi, edge->src_node()).second); } } } void OpNode::InitLbi2NdSbp() { const auto Update = [&](const PbRpf& bns) { for (const auto& bn : bns) { const LogicalBlobId& lbi = op().BnInOp2Lbi(bn); const NdSbp& nd_sbp = NdSbp4BnInOp(bn); auto it = lbi2nd_sbp_.find(lbi); if (it == lbi2nd_sbp_.end()) { lbi2nd_sbp_[lbi] = nd_sbp; } else { CHECK(it->second == nd_sbp); } } }; Update(op().input_bns()); Update(op().output_bns()); } Maybe OpGraph::New(const Job& job) { const auto& op_graph = std::make_shared(); JUST(op_graph->Init(job)); return op_graph; } Maybe OpGraph::Init(const Job& job) { InitNodes(job); op_name2op_node_.reserve(job.net().op_size()); ForEachNode([&](OpNode* node) { CHECK(op_name2op_node_.emplace(node->op().op_name(), node).second) << "op_name: " << node->op().op_name(); }); InitEdges(); InitProducerOpName2CtrlConsumerOpNames(job); CheckIsDAG(); ForEachNode([](OpNode* node) { node->InitLbi2SourceNode(); }); InferBlobLastUsed(); InferTimeShape(); { LazyMode::Guard enable_lazy_mode_guard(true); JUST(InferLogicalBlobDesc(job)); } return Maybe::Ok(); } void OpGraph::CheckIsDAG() const { CHECK(!FindFirstNontrivialSCC()); auto ForEachIn = [&](OpNode* node, const std::function& Handler) { ForEachDataAndCtrlInNode(node, Handler); }; auto ForEachOut = [&](OpNode* node, const std::function& Handler) { ForEachDataAndCtrlOutNode(node, Handler); }; CHECK(!FindFirstNontrivialSCC(ForEachIn, ForEachOut)); } namespace { std::function(const std::string&)> MakeGetterParallelDesc4OpName( const Job& job) { const Placement& placement = job.placement(); auto op_name2parallel_desc = std::make_shared>>(); op_name2parallel_desc->reserve(job.net().op_size()); for (const auto& placement_group : placement.placement_group()) { const ParallelConf& parallel_conf = placement_group.parallel_conf(); Symbol parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); for (const std::string& op_name : placement_group.op_set().op_name()) { CHECK(op_name2parallel_desc->emplace(op_name, parallel_desc).second) << "op_name: " << op_name; } } return [op_name2parallel_desc](const std::string& op_name) { return op_name2parallel_desc->at(op_name); }; } } // namespace void OpGraph::InitNodes(const Job& job) { auto ParallelDesc4OpName = 
MakeGetterParallelDesc4OpName(job); for (const auto& op_conf : job.net().op()) { op_names_.emplace_back(op_conf.name()); OpNode* node = new OpNode(ParallelDesc4OpName(op_conf.name()), op_conf); AddAllocatedNode(node); } } void OpGraph::InitEdges() { HashMap lbi2producer; HashMap>> producer_op_name2lbi2obn; ForEachNode([&](OpNode* op_node) { for (const auto& obn : op_node->op().output_bns()) { const auto& lbi = op_node->op().BnInOp2Lbi(obn); CHECK(lbi2producer.emplace(lbi, op_node).second); auto& lbi2obn = producer_op_name2lbi2obn[op_node->op().op_name()]; if (!lbi2obn) { lbi2obn.reset(new HashMap()); } CHECK(lbi2obn->emplace(lbi, obn).second); } }); ForEachNode([&](OpNode* op_node) { HashMap> producer_op_name2lbis; std::shared_ptr>> consumer_lbi2ibns( new HashMap>); op_node->input_index2producer_and_output_index_.reserve(op_node->op().input_bns().size()); for (const auto& ibn : op_node->op().input_bns()) { const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn); producer_op_name2lbis[lbi.op_name()].insert(lbi); (*consumer_lbi2ibns)[lbi].emplace_back(ibn); auto producer_it = lbi2producer.find(lbi); CHECK(producer_it != lbi2producer.end()) << "producer not found: " << GenLogicalBlobName(lbi); const int32_t output_index = CHECK_JUST(producer_it->second->op().GetOutputIndex(lbi)); op_node->input_index2producer_and_output_index_.emplace_back(producer_it->second, output_index); } for (const auto& pair : producer_op_name2lbis) { std::shared_ptr> lbis( new std::vector({pair.second.begin(), pair.second.end()})); const auto it = producer_op_name2lbi2obn.find(pair.first); CHECK(it != producer_op_name2lbi2obn.end()) << "producer_op_name: " << pair.first; const auto& lbi2obn = it->second; auto producer_it = lbi2producer.find(lbis->front()); CHECK(producer_it != lbi2producer.end()) << "producer not found: " << GenLogicalBlobName(lbis->front()); Connect(producer_it->second, NewEdge(lbis, lbi2obn, consumer_lbi2ibns), op_node); } }); } void OpGraph::InitProducerOpName2CtrlConsumerOpNames(const Job& job) { for (const auto& op_conf : job.net().op()) { for (const auto& ctrl_in_op_name : op_conf.ctrl_in_op_name()) { auto* consumer_op_names = &producer_op_name2ctrl_consumer_op_names_[ctrl_in_op_name]; CHECK(consumer_op_names->emplace(op_conf.name()).second); } } } void OpGraph::InferBlobLastUsed() const { HashSet visisted_lbi; for (auto iter = op_names_.rbegin(); iter != op_names_.rend(); iter++) { Operator* op = op_name2op_node_.at(*iter)->mut_op(); auto* map = op->mut_blob_last_used_signature()->mutable_bn_in_op2blob_last_used(); const auto InferLastUsed = [&](const std::string& bn_in_op) { (*map)[bn_in_op] = visisted_lbi.insert(op->BnInOp2Lbi(bn_in_op)).second; }; for (const auto& obn : op->output_bns()) { InferLastUsed(obn); } for (const auto& ibn : op->input_bns()) { InferLastUsed(ibn); } } } void OpGraph::InferTimeShape() const { TopoForEachNode([&](OpNode* op_node) { auto GetInputBlobTimeShape = [&](int32_t index) -> Maybe { CHECK_LT_OR_RETURN(index, op_node->input_index2producer_and_output_index_.size()); return op_node->input_index2producer_and_output_index_.at(index).first->op().GetOpTimeShape(); }; CHECK_JUST(op_node->mut_op()->FillInputBlobTimeShape(GetInputBlobTimeShape)); CHECK_JUST(op_node->mut_op()->InferOpTimeShapeIf()); }); } void OpGraph::InferOpNodeNdSbpSignature(OpNode* op_node, const NdSbpSignature& nd_sbp_sig_conf) const { HashMap ibn2nd_sbp_infer_hint; for (const std::string& ibn : op_node->op().input_bns()) { const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn); OpNode* producer = 
op_node->MutSrcNode4Ibn(ibn); const std::string& producer_lbn = *CHECK_JUST(producer->op().obn4lbi(lbi)); const ParallelDesc* parallel_desc = CHECK_JUST(producer->op().GetParallelDesc4BnInOp(producer_lbn)).get(); const BlobDesc* logical_blob_desc = &producer->LogicalBlobDesc4Lbi(lbi); const NdSbp* nd_sbp = &producer->NdSbp4Lbi(lbi); ibn2nd_sbp_infer_hint.emplace(ibn, NdSbpInferHint(parallel_desc, logical_blob_desc, nd_sbp)); } const auto NdSbpInferHint4Ibn = [&](const std::string& bn) -> Maybe { auto it = ibn2nd_sbp_infer_hint.find(bn); CHECK_OR_RETURN(it != ibn2nd_sbp_infer_hint.end()); return Maybe(&it->second); }; CHECK_JUST(op_node->mut_op()->InferNdSbpSignatureIf(nd_sbp_sig_conf, op_node->parallel_desc(), NdSbpInferHint4Ibn)); op_node->InitLbi2NdSbp(); } Maybe OpGraph::InferOpNodeLocalSignature(OpNode* op_node, bool is_local_conf) const { HashMap ibn2local_sig_infer_hint; for (const std::string& ibn : op_node->op().input_bns()) { const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn); const auto* producer = op_node->MutSrcNode4Ibn(ibn); const ParallelDesc* parallel_desc = &producer->parallel_desc(); const auto& producer_obn = *JUST(producer->op().obn4lbi(lbi)); const auto& opt_local_parallel = *JUST(producer->op().OptLocalParallel4BnInOp(producer_obn)); LocalSigInferHint infer_ctx(parallel_desc, opt_local_parallel.has_local_parallel()); ibn2local_sig_infer_hint.emplace(ibn, infer_ctx); } const auto& LocalSigInferHint4Ibn = [&](const std::string& ibn) -> Maybe { const auto& iter = ibn2local_sig_infer_hint.find(ibn); CHECK_OR_RETURN(iter != ibn2local_sig_infer_hint.end()) << "input blob not found. ibn: " << ibn; return &iter->second; }; JUST(op_node->mut_op()->InferLocalSignatureIf(LocalSigInferHint4Ibn, is_local_conf, op_node->parallel_desc())); return Maybe::Ok(); } const OpNode* OpGraph::OpNode4OpName(const std::string& op_name) const { const auto& op_node_it = op_name2op_node_.find(op_name); if (op_node_it == op_name2op_node_.end()) { return nullptr; } return op_node_it->second; } Maybe OpGraph::InferLogicalBlobDesc(const Job& job) const { JobParallelViewConf job_parallel_view_conf(job.job_parallel_view_conf()); JUST(TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe { auto LogicalBlobDesc4InputIndex = [&](int32_t index) -> Maybe { CHECK_LT_OR_RETURN(index, op_node->input_index2producer_and_output_index_.size()); const auto& producer_info = op_node->input_index2producer_and_output_index_.at(index); return producer_info.first->op().GetLogicalBlobDesc4OutputIndex(producer_info.second); }; JUST(op_node->mut_op()->FillLogicalInBlobDesc(LogicalBlobDesc4InputIndex)); // Infer ParallelSignature JUST(op_node->mut_op()->InferParallelSignatureIf()); // Infer local_signature bool is_local_conf = false; { const auto& op_name2is_local = job_parallel_view_conf.op_name2is_local_parallel_view(); const auto& iter = op_name2is_local.find(op_node->op().op_name()); if (iter != op_name2is_local.end()) { is_local_conf = iter->second; } } JUST(InferOpNodeLocalSignature(op_node, is_local_conf)); NdSbpSignature nd_sbp_sig_conf; { const auto& op_name2nd_sbp_sig_conf = job_parallel_view_conf.op_name2nd_sbp_signature_conf(); const auto& iter = op_name2nd_sbp_sig_conf.find(op_node->op().op_name()); if (iter != op_name2nd_sbp_sig_conf.end()) { nd_sbp_sig_conf = NdSbpSignature(iter->second); if (op_node->parallel_desc().hierarchy()->NumAxes() == 1) { const auto& op_name2sbp_sig_conf = job_parallel_view_conf.op_name2sbp_signature_conf(); const auto& op_name2sbp_sig_conf_it = 
op_name2sbp_sig_conf.find(op_node->op().op_name()); CHECK_OR_RETURN(op_name2sbp_sig_conf_it != op_name2sbp_sig_conf.end()) << op_node->op().op_name(); CheckSbpSignatureAndNdSbpEquals(SbpSignature(op_name2sbp_sig_conf_it->second), NdSbpSignature(iter->second)); } else { // do nothing } } } InferOpNodeNdSbpSignature(op_node, nd_sbp_sig_conf); JUST(op_node->mut_op()->InferLogicalOutBlobDescsIf()); return Maybe::Ok(); })); return Maybe::Ok(); } int64_t OpGraph::GetParallelNum(const std::string& op_name) const { return op_name2op_node_.at(op_name)->parallel_desc().parallel_num(); } const SbpParallel& OpGraph::GetSbpParallel(const std::string& op_name, const LogicalBlobId& lbi) const { return op_name2op_node_.at(GetOpNameKey(op_name, lbi)) ->SbpParallel4Lbi(GetLogicalBlobIdKey(op_name, lbi)); } const NdSbp& OpGraph::GetNdSbp(const std::string& op_name, const LogicalBlobId& lbi) const { return op_name2op_node_.at(GetOpNameKey(op_name, lbi)) ->NdSbp4Lbi(GetLogicalBlobIdKey(op_name, lbi)); } DataType OpGraph::GetBlobDataType(const LogicalBlobId& lbi) const { return op_name2op_node_.at(lbi.op_name()) ->LogicalBlobDesc4Lbi(GetLogicalBlobIdKey(lbi.op_name(), lbi)) .data_type(); } const BlobDesc& OpGraph::GetLogicalBlobDesc(const LogicalBlobId& lbi) const { return op_name2op_node_.at(lbi.op_name()) ->LogicalBlobDesc4Lbi(GetLogicalBlobIdKey(lbi.op_name(), lbi)); } std::string OpGraph::GetOpNameKey(const std::string& op_name, const LogicalBlobId& lbi) const { if (op_name2op_node_.find(op_name) != op_name2op_node_.end()) { return op_name; } else { UNIMPLEMENTED(); } } LogicalBlobId OpGraph::GetLogicalBlobIdKey(const std::string& op_name, const LogicalBlobId& lbi) const { if (op_name2op_node_.find(op_name) != op_name2op_node_.end()) { return lbi; } else { UNIMPLEMENTED(); } } void OpGraph::ForEachDataAndCtrlInNode(OpNode* node, const std::function& Handler) const { node->ForEachNodeOnInEdge(Handler); for (const auto& ctrl_in_op_name : node->op().op_conf().ctrl_in_op_name()) { CHECK(op_name2op_node_.find(ctrl_in_op_name) != op_name2op_node_.end()) << " cannot find ctrl_in_op_name: [" << ctrl_in_op_name << "] of op: [" << node->op().op_name() << "] in OpGraph. 
"; Handler(op_name2op_node_.at(ctrl_in_op_name)); } } void OpGraph::ForEachDataAndCtrlOutNode(OpNode* node, const std::function& Handler) const { node->ForEachNodeOnOutEdge(Handler); const auto& op_name_it = producer_op_name2ctrl_consumer_op_names_.find(node->op().op_name()); if (op_name_it == producer_op_name2ctrl_consumer_op_names_.end()) { return; } for (const std::string& ctrl_consumer_op_name : op_name_it->second) { CHECK(op_name2op_node_.find(ctrl_consumer_op_name) != op_name2op_node_.end()) << " cannot find ctrl_consumer_op_name: [" << ctrl_consumer_op_name << "] of op: [" << node->op().op_name() << "] in OpGraph."; Handler(op_name2op_node_.at(ctrl_consumer_op_name)); } } void OpGraph::TopoForEachNodeWithCtrlEdge(const std::function& NodeHandler) const { auto OpGraphForEachInDataAndCtrlNode = [&](OpNode* node, const std::function& Handler) { ForEachDataAndCtrlInNode(node, Handler); }; auto OpGraphForEachOutDataAndCtrlNode = [&](OpNode* node, const std::function& Handler) { ForEachDataAndCtrlOutNode(node, Handler); }; TopoForEachNode(OpGraphForEachInDataAndCtrlNode, OpGraphForEachOutDataAndCtrlNode, NodeHandler); } std::function OpGraph::MakePredicatorIsOpNameDataOrCtrlReachable() const { auto IsDataOrCtrlReachable = MakePredicatorIsDataOrCtrlReachable(); return [IsDataOrCtrlReachable, this](const std::string& lhs, const std::string& rhs) { const auto& src_node_it = op_name2op_node_.find(lhs); if (src_node_it == op_name2op_node_.end()) { return false; } const auto& dst_node_it = op_name2op_node_.find(rhs); if (dst_node_it == op_name2op_node_.end()) { return false; } return (src_node_it->second == dst_node_it->second) || IsDataOrCtrlReachable(src_node_it->second, dst_node_it->second); }; } std::function OpGraph::MakePredicatorIsDataOrCtrlReachable() const { auto _1 = std::placeholders::_1; auto _2 = std::placeholders::_2; return MakePredicatorIsReachable(DataOrCtrlSourceNodes(), std::bind(&OpGraph::ForEachDataAndCtrlInNode, this, _1, _2), std::bind(&OpGraph::ForEachDataAndCtrlOutNode, this, _1, _2)); } std::list OpGraph::DataOrCtrlSourceNodes() const { std::list ret; ForEachNode([&](OpNode* op_node) { size_t in_edges_cnt = 0; ForEachDataAndCtrlInNode(op_node, [&](OpNode*) { ++in_edges_cnt; }); if (in_edges_cnt == 0) { ret.emplace_back(op_node); } }); return ret; } void OpGraph::DumpLogicalBlobDesc(Job* job) const { auto* helper = job->mutable_helper(); ForEachNode([&](const OpNode* node) { for (const auto& obn : node->op().output_bns()) { const auto& lbi = node->op().BnInOp2Lbi(obn); node->LogicalBlobDesc4Lbi(lbi).ToProto( &(*helper->mutable_lbn2logical_blob_desc())[GenLogicalBlobName(lbi)]); } }); } void OpGraph::DumpNdSbpSignature(Job* job) const { ForEachNode([&](const OpNode* node) -> void { (*job->mutable_job_parallel_view_conf() ->mutable_op_name2nd_sbp_signature_conf())[node->op().op_name()] = *CHECK_JUST(node->op().nd_sbp_signature()); if (node->parallel_desc().hierarchy()->NumAxes() == 1) { (*job->mutable_job_parallel_view_conf() ->mutable_op_name2sbp_signature_conf())[node->op().op_name()] = node->sbp_signature(); } }); } void OpGraph::DumpArgSignature(Job* job) const { ForEachNode([&](const OpNode* node) { auto* op_arg_signature = &(*job->mutable_helper()->mutable_op_name2arg_signature())[node->op().op_name()]; for (const auto& ibn : node->op().input_bns()) { const auto& lbi = node->op().BnInOp2Lbi(ibn); (*op_arg_signature->mutable_bn_in_op2lbi())[ibn] = lbi; } for (const auto& obn : node->op().output_bns()) { const auto& lbi = node->op().BnInOp2Lbi(obn); 
(*op_arg_signature->mutable_bn_in_op2lbi())[obn] = lbi; } }); } Maybe OpGraph::ForEachOpNode(const std::function(const OpNode&)>& DoEach) const { HashMap visited; for (const auto& op_name : op_names_) { const OpNode& op_node = *op_name2op_node_.at(op_name); for (const auto& ibn : op_node.op().input_bns()) { const auto& lbi = op_node.op().BnInOp2Lbi(ibn); CHECK_OR_RETURN(visited[lbi]) << "input blob '" << ibn << "' is not defined\n" << lbi.DebugString() << "\n==== op_conf ====\n" << op_node.op().op_conf().DebugString(); } for (const auto& obn : op_node.op().output_bns()) { const auto& lbi = op_node.op().BnInOp2Lbi(obn); CHECK_OR_RETURN(!visited[lbi]) << "output blob '" << obn << "' is defined\n" << lbi.DebugString() << "\n==== op_conf ====\n" << op_node.op().op_conf().DebugString(); visited[lbi] = true; } JUST(DoEach(op_node)); } return Maybe::Ok(); } std::function OpGraph::CreatePredicatorIsReachable() const { return MakePredicatorIsReachable(); } // Print the graph with SBP in order void OpGraph::PrintSBPGraphDebugInfo() const { // test debug std::cout << "Get Into Print Op Graph" << std::endl; // Collect op_node std::vector NodeList; ForEachNode([&](OpNode* op_node) { NodeList.push_back(op_node); }); // test debug std::cout << "Deciding order" << std::endl; // Decide the order to vist the op std::vector order; auto_parallel::DecideOrder(NodeList, order, [&](OpNode* a, OpNode* b) { return a->op().op_name().compare(b->op().op_name()) > 0; }); std::vector str_order; // test debug std::cout << "Finish deciding order" << std::endl; for (int32_t i = 0; i < NodeList.size(); i++) { OpNode* op_node = NodeList[order[i]]; std::cout << op_node->op().op_name() << " (^_^):" << std::endl; // Sort before printing const auto& op_input_bns = op_node->op().input_bns(); auto comp = [](const std::string& a, const std::string& b) { return a.compare(b) > 0; }; auto_parallel::DecideOrder(op_input_bns, str_order, comp); // Print out SBP information for input operator for (int32_t j : str_order) { const auto& ibn = op_input_bns[j]; auto producer_node = op_node->MutSrcNode4Ibn(ibn); std::cout << "Pre Op:" << producer_node->op().op_name() << ": " << ibn; const auto& this_sbp_parallel = op_node->NdSbp4BnInOp(ibn); std::cout << ", " << NdSbpToString(this_sbp_parallel); const auto input_blob_modifier_ = op_node->op().InputBlobModifier4Ibn(ibn); bool is_same_sbp = input_blob_modifier_.has_is_mutable() && input_blob_modifier_.is_mutable(); if (is_same_sbp) std::cout << ", same SBP"; std::cout << ", " << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(ibn)).shape(); std::cout << std::endl; } // Sort before printing const auto& op_output_bns = op_node->op().output_bns(); auto_parallel::DecideOrder(op_output_bns, str_order, comp); // Print out SBP information for output blobs for (int32_t j : str_order) { const auto& obn = op_output_bns[j]; std::cout << "Out Op:" << obn; const auto& this_sbp_parallel = op_node->NdSbp4BnInOp(obn); std::cout << ", " << NdSbpToString(this_sbp_parallel); std::cout << ", " << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(obn)).shape(); std::cout << std::endl; } std::cout << std::endl; } } OpGraphSingletonGuard::OpGraphSingletonGuard(const Job& job) { // new Singleton and set log configs. 
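// A minimal sketch (not from this file) of typical OpGraph consumption, assuming a
// fully populated Job proto named `job`. OpGraph::New builds the node/edge structure
// and runs time-shape, SBP and logical-blob-desc inference; OpGraphSingletonGuard
// below additionally installs the graph as the process-wide singleton and, in debug
// mode, dumps a .dot file. The op name queried here is hypothetical.
inline Maybe<void> SketchInspectOpGraph(const Job& job) {
  std::shared_ptr<OpGraph> op_graph = JUST(OpGraph::New(job));
  op_graph->ForEachNode([](OpNode* node) {
    VLOG(3) << node->op().op_name() << " parallel_num=" << node->parallel_desc().parallel_num();
  });
  const OpNode* node = op_graph->OpNode4OpName("some_op_name");  // nullptr if absent
  if (node != nullptr) { VLOG(3) << node->VisualStr(); }
  return Maybe<void>::Ok();
}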
Singleton::New(job); const JobDesc& job_desc = GlobalJobDesc(); if (Singleton::Get()->enable_debug_mode()) { TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(job); Singleton::Get()->ToDotWithFilePath( "optimized_dlnet_" + std::to_string(job_desc.job_id()) + "_op_graph.dot"); } } OpGraphSingletonGuard::~OpGraphSingletonGuard() { Singleton::Delete(); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/op_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_OP_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_OP_GRAPH_H_ #include "oneflow/core/graph/graph.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/local_parallel.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/balanced_splitter.h" namespace oneflow { namespace auto_parallel { class SbpConstructor; } class OpEdge; class OpGraph; class OpNode final : public Node { public: OF_DISALLOW_COPY_AND_MOVE(OpNode); explicit OpNode(Symbol parallel_desc, const OperatorConf& op_conf); ~OpNode() = default; // Getters bool IsTimeShapeIdentity() const; const Operator& op() const { return *op_; } std::shared_ptr shared_op() const { return op_; } const ParallelDesc& parallel_desc() const { return *parallel_desc_; } Symbol parallel_desc_sym() const { return parallel_desc_; } const SbpSignature& sbp_signature() const { return *CHECK_JUST(op().sbp_signature()); } const NdSbpSignature& nd_sbp_signature() const { return *CHECK_JUST(op().nd_sbp_signature()); } const SbpParallel& SbpParallel4Lbi(const LogicalBlobId& lbi) const; const SbpParallel& SbpParallel4BnInOp(const std::string& bn_in_op) const; const NdSbp& NdSbp4Lbi(const LogicalBlobId& lbi) const; const NdSbp& NdSbp4BnInOp(const std::string& bn_in_op) const; const BlobDesc& LogicalBlobDesc4Lbi(const LogicalBlobId& lbi) const; const OpNode& ProducerOpNode4Lbi(const LogicalBlobId& lbi) const; const OpNode& SrcNode4Ibn(const std::string& bn_in_op) const; std::string VisualStr() const override; private: friend class OpGraph; friend class OpEdge; friend class auto_parallel::SbpConstructor; // Setters Operator* mut_op() { return op_.get(); } OpNode* MutSrcNode4Ibn(const std::string& bn_in_op) const; OpNode* MutSrcNode4InputLbi(const LogicalBlobId& lbi) const; void InitLbi2SourceNode(); void InitLbi2NdSbp(); Symbol parallel_desc_; std::shared_ptr op_; HashSet ibns_; HashMap lbi2source_node_; HashMap lbi2nd_sbp_; std::vector> input_index2producer_and_output_index_; }; class OpEdge final : public Edge { public: OF_DISALLOW_COPY_AND_MOVE(OpEdge); explicit OpEdge(std::shared_ptr> lbis, std::shared_ptr> lbi2obn, std::shared_ptr>> lbi2ibns) : lbis_(std::move(lbis)), lbi2obn_(std::move(lbi2obn)), lbi2ibns_(std::move(lbi2ibns)) {} ~OpEdge() override = default; // Getters const std::vector& lbis() const { return *lbis_; } const HashMap& lbi2obn() const { return 
*lbi2obn_; } const HashMap>& lbi2ibns() const { return *lbi2ibns_; } bool NeedBoxing() const; std::string VisualStr() const override; private: std::shared_ptr> lbis_; std::shared_ptr> lbi2obn_; std::shared_ptr>> lbi2ibns_; }; class OpGraph final : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(OpGraph); explicit OpGraph(const Job& job) { CHECK_JUST(Init(job)); } explicit OpGraph() = default; ~OpGraph() override = default; static Maybe New(const Job& job); Maybe ForEachOpNode(const std::function(const OpNode&)>& DoEach) const; const OpNode* OpNode4OpName(const std::string& name) const; int64_t GetParallelNum(const std::string& op_name) const; const SbpParallel& GetSbpParallel(const std::string& op_name, const LogicalBlobId& lbi) const; const NdSbp& GetNdSbp(const std::string& op_name, const LogicalBlobId& lbi) const; DataType GetBlobDataType(const LogicalBlobId& lbi) const; const BlobDesc& GetLogicalBlobDesc(const LogicalBlobId& lbi) const; std::function MakePredicatorIsOpNameDataOrCtrlReachable() const; void ForEachDataAndCtrlInNode(OpNode* node, const std::function& Handler) const; void ForEachDataAndCtrlOutNode(OpNode* node, const std::function& Handler) const; void TopoForEachNodeWithCtrlEdge(const std::function& NodeHandler) const; // NOTE(chengcheng): For topo for each with ctrl edges. OpEdge is ONLY data edge. std::list DataOrCtrlSourceNodes() const; void DumpLogicalBlobDesc(Job* job) const; void DumpArgSignature(Job* job) const; void DumpNdSbpSignature(Job* job) const; Maybe Init(const Job& job); std::function CreatePredicatorIsReachable() const; // Print the graph with SBP in order void PrintSBPGraphDebugInfo() const; private: friend class auto_parallel::SbpConstructor; void InitNodes(const Job& job); void InitEdges(); void InitProducerOpName2CtrlConsumerOpNames(const Job& job); void CheckIsDAG() const; void InferBlobLastUsed() const; void InferTimeShape() const; void InferOpNodeNdSbpSignature(OpNode* op_node, const NdSbpSignature& nd_sbp_sig_conf) const; Maybe InferOpNodeLocalSignature(OpNode* op_node, bool is_local_conf) const; Maybe InferLogicalBlobDesc(const Job& job) const; std::string GetOpNameKey(const std::string& op_name, const LogicalBlobId& lbi) const; LogicalBlobId GetLogicalBlobIdKey(const std::string& op_name, const LogicalBlobId& lbi) const; std::function MakePredicatorIsDataOrCtrlReachable() const; HashMap op_name2op_node_; std::list op_names_; HashMap> producer_op_name2ctrl_consumer_op_names_; }; class OpGraphSingletonGuard { public: OF_DISALLOW_COPY_AND_MOVE(OpGraphSingletonGuard); explicit OpGraphSingletonGuard(const Job& job); ~OpGraphSingletonGuard(); }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_OP_GRAPH_H_ ================================================ FILE: oneflow/core/graph/plan_task_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/graph/plan_task_graph.h" namespace oneflow { PlanTaskGraph::PlanTaskGraph(const Plan& plan) : plan_(&plan) { InitNodes(); InitEdges(); } void PlanTaskGraph::InitNodes() { for (const auto& task : plan_->task()) { PlanTaskNode* plan_task_node = new PlanTaskNode(task); task_id2plan_task_node_.insert({task.task_id(), plan_task_node}); AddAllocatedNode(plan_task_node); } } void PlanTaskGraph::InitEdges() { for (const auto& task_id_and_plan_task_node : task_id2plan_task_node_) { PlanTaskNode* producer_node = task_id_and_plan_task_node.second; for (const auto& pair : producer_node->task_proto()->produced_regst_desc()) { for (int64_t consumer_task_id : pair.second.consumer_task_id()) { PlanTaskNode* consumer_node = CHECK_JUST(MapAt(task_id2plan_task_node_, consumer_task_id)); TryConnect(producer_node, consumer_node); } } } } void PlanTaskGraph::TryConnect(PlanTaskNode* src, PlanTaskNode* dst) { if (edges_.insert({src, dst}).second) { Connect(src, NewEdge(), dst); } } const TaskProto* PlanTaskGraph::TaskProto4TaskId(int64_t task_id) const { return CHECK_JUST(MapAt(task_id2plan_task_node_, task_id))->task_proto(); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/plan_task_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_PLAN_TASK_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_PLAN_TASK_GRAPH_H_ #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/graph/graph.h" namespace oneflow { class PlanTaskNode; class PlanTaskEdge final : public Edge { public: OF_DISALLOW_COPY_AND_MOVE(PlanTaskEdge); PlanTaskEdge() = default; ~PlanTaskEdge() = default; }; class PlanTaskNode final : public Node { public: OF_DISALLOW_COPY_AND_MOVE(PlanTaskNode); explicit PlanTaskNode(const TaskProto& task_proto) : task_proto_(&task_proto) {} ~PlanTaskNode() = default; const TaskProto* task_proto() const { return task_proto_; } int64_t task_id() const { return task_proto_->task_id(); } int64_t chain_id() const { return task_proto_->chain_id(); } int64_t order_in_chain() const { return task_proto_->order_in_chain(); } private: const TaskProto* task_proto_; }; class PlanTaskGraph : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(PlanTaskGraph); explicit PlanTaskGraph(const Plan& plan); virtual ~PlanTaskGraph() = default; const TaskProto* TaskProto4TaskId(int64_t task_id) const; const Plan& plan() const { return *plan_; } protected: void InitNodes(); void InitEdges(); void TryConnect(PlanTaskNode* src, PlanTaskNode* dst); const Plan* plan_; HashMap task_id2plan_task_node_; HashSet> edges_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_PLAN_TASK_GRAPH_H_ ================================================ FILE: oneflow/core/graph/slice_boxing_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/slice_boxing_task_node.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/task_graph_rebuild_ctx.h" namespace oneflow { void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, const SliceBoxingTaskMode mode, int64_t machine_id, int64_t thrd_id) { out_slice_ = out_slice; out_shape_ = out_slice.shape(); mode_ = mode; set_machine_id(machine_id); set_thrd_id(thrd_id); set_lbi(lbi); } void SliceBoxingTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst_desc = ProduceRegst("out", true); this->ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst_desc); }); ProduceRegst("tmp", true); } void SliceBoxingTaskNode::ConsumeAllRegsts() { HashMap edge2order_; FOR_RANGE(int64_t, i, 0, ordered_in_data_edges_.size()) { edge2order_.emplace(ordered_in_data_edges_.at(i), i); } int64_t in_data_edge_cnt = 0; ForEachInDataEdge([&](TaskEdge* edge) { const auto order_it = edge2order_.find(edge); CHECK(order_it != edge2order_.end()); ConsumeRegst("in_" + std::to_string(order_it->second), edge->GetSoleRegst()); in_data_edge_cnt += 1; }); CHECK_EQ(in_data_edge_cnt, ordered_in_data_edges_.size()); } void SliceBoxingTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); std::shared_ptr op = CHECK_JUST(ConstructOp(GetBoxingOpConf())); node->mut_op() = op; FOR_RANGE(size_t, i, 0, op->input_bns().size()) { const std::string& ibn = op->input_bns().Get(i); CHECK_EQ(GenUnRepeatedBn(ibn).second, i); node->BindBnWithRegst(ibn, GetSoleConsumedRegst("in_" + std::to_string(i))); } std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(lbi()); node->BindBnWithRegst(op->SoleObn(), out_regst); node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void SliceBoxingTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } void SliceBoxingTaskNode::SetInDataEdgeSlice(const TaskEdge* edge, const TensorSliceView& slice) { CHECK(in_data_edge2slice_.emplace(edge, slice).second); ordered_in_data_edges_.emplace_back(edge); } void SliceBoxingTaskNode::ConnectToSrcNodeWithSlice(TaskNode* src, TaskEdge* edge, const TensorSliceView& slice) { edge->AddLbi(lbi()); Connect(src, edge, this); SetInDataEdgeSlice(edge, slice); } void SliceBoxingTaskNode::SetOutShape(const Shape& shape) { out_shape_ = shape; } OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() { OperatorConf op_conf{}; op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type()))); SliceBoxingConf boxing_conf{}; *boxing_conf.mutable_lbi() = lbi(); out_slice_.ToProto(boxing_conf.mutable_out_slice()); out_shape_.ToProto(boxing_conf.mutable_out_shape()); for (const TaskEdge* edge : ordered_in_data_edges_) { in_data_edge2slice_.at(edge).ToProto(boxing_conf.mutable_in_slice()->Add()); } if (mode_ == kSliceBoxingTaskModeCopy) { op_conf.set_name("System-Boxing-BoxingCopy-" + NewUniqueId()); 
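// A minimal sketch (not from this file) of how a sub-task-graph builder uses this
// node: Init() fixes the output slice and mode, and each input edge is registered
// together with the slice it carries, so that GetBoxingOpConf() can emit one in_slice
// per ordered input. The caller-provided names here (copy_node, src_node, edge, the
// slices) are assumptions for illustration.
void SketchConnectSliceBoxingCopy(SliceBoxingTaskNode* copy_node, TaskNode* src_node,
                                  TaskEdge* edge, const LogicalBlobId& lbi,
                                  const TensorSliceView& in_slice,
                                  const TensorSliceView& out_slice, int64_t machine_id,
                                  int64_t thrd_id) {
  copy_node->Init(lbi, out_slice, kSliceBoxingTaskModeCopy, machine_id, thrd_id);
  // Adds the lbi to the edge, connects src_node -> copy_node, and records the slice
  // carried by this input edge in declaration order.
  copy_node->ConnectToSrcNodeWithSlice(src_node, edge, in_slice);
}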
SliceBoxingCopyOpConf* conf = op_conf.mutable_slice_boxing_copy_conf(); *conf->mutable_slice_boxing_conf() = boxing_conf; } else if (mode_ == kSliceBoxingTaskModeAdd) { op_conf.set_name("System-Boxing-BoxingAdd-" + NewUniqueId()); SliceBoxingAddOpConf* conf = op_conf.mutable_slice_boxing_add_conf(); *conf->mutable_slice_boxing_conf() = boxing_conf; } else { UNIMPLEMENTED(); } return op_conf; } Maybe SliceBoxingTaskNode::InitTransportTaskFromProto( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK_OR_RETURN(transport_task_proto.has_slice_boxing_task()) << "not a serialized SliceBoxingTaskNode. debug string: " << transport_task_proto.DebugString(); const auto& proto = transport_task_proto.slice_boxing_task(); for (const auto& pair : proto.in_data_edge_uid2slice()) { const auto* edge = JUST(ctx.TaskEdge4Uid(pair.first)); CHECK_OR_RETURN(in_data_edge2slice_.emplace(edge, pair.second).second) << "redundant edge found. edge_uid: " << pair.first; } for (int64_t edge_uid : proto.ordered_in_data_edge_uid()) { ordered_in_data_edges_.push_back(JUST(ctx.TaskEdge4Uid(edge_uid))); } out_slice_ = TensorSliceView(proto.out_slice()); out_shape_ = Shape(proto.out_shape()); mode_ = proto.mode(); return Maybe::Ok(); } void SliceBoxingTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); auto* proto = transport_task_proto->mutable_slice_boxing_task(); for (const auto& pair : in_data_edge2slice_) { int64_t edge_uid = reinterpret_cast(pair.first); pair.second.ToProto(&(*proto->mutable_in_data_edge_uid2slice())[edge_uid]); } for (const auto* edge : ordered_in_data_edges_) { proto->add_ordered_in_data_edge_uid(reinterpret_cast(edge)); } out_slice_.ToProto(proto->mutable_out_slice()); out_shape_.ToProto(proto->mutable_out_shape()); proto->set_mode(mode_); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/slice_boxing_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_ #include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/graph/transport_task_node.h" #include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/memory/memory_zone.h" namespace oneflow { class SliceBoxingTaskNode final : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingTaskNode); SliceBoxingTaskNode() = default; ~SliceBoxingTaskNode() override = default; void Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, SliceBoxingTaskMode mode, int64_t machine_id, int64_t thrd_id); void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; TaskType GetTaskType() const override { return TaskType::kSliceBoxing; } void SetInDataEdgeSlice(const TaskEdge* edge, const TensorSliceView& slice); void ConnectToSrcNodeWithSlice(TaskNode* src, TaskEdge* edge, const TensorSliceView& slice); void SetOutShape(const Shape& shape); Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) override; void ToTransportTaskProto(TransportTaskProto*) const override; private: void BuildExecGphAndRegst() override; void InferProducedDataRegstTimeShape() override; OperatorConf GetBoxingOpConf(); HashMap in_data_edge2slice_; std::vector ordered_in_data_edges_; TensorSliceView out_slice_; Shape out_shape_; SliceBoxingTaskMode mode_ = kSliceBoxingTaskModeInvalid; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/straighten_nodes.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/util.h" #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/straighten_nodes.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/graph/task_node.h" #include "oneflow/core/graph/transport_task_node.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/task.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/register/runtime_register_desc.h" namespace oneflow { namespace { enum TaskClassifier : int { kWaitingOverlapNode = 0, kWaitingMainComputation = 1, kRunASAP = 2, kRunALAP = 3 }; // The difference between a descending order and its corresponding ascending order static const int kDiff4AscendDescend = 100; class TopoStruct { public: TaskNode* node = nullptr; int32_t min_layer = -1; int32_t tributary_layer = -1; bool on_trunk = false; int32_t counter = 0; int32_t min_distance2overlap = -1; int64_t memory_increment = -1; int32_t exceed_time = -1; int32_t min_lifetime = -1; int64_t memory_volume = -1; int32_t max_layer = -1; TaskClassifier task_classifier; std::string key; // We can have some other nodes in it for example // SbpNode* node; // SbpEdge* node; // Or we can omit all the pointers and leave all the useful parameters. int32_t ComputeMinLayer(HashMap* task_node2topo_struct, std::map>* key2topo_structs); // Drop down the tributary layer void DropTributaryLayer(int32_t upper_bound); void SpreadTributaryLayer(HashMap* task_node2topo_struct); void ComputeMaxLayer(HashMap* task_node2topo_struct); void SpreadTrunk(HashMap* task_node2topo_struct); // The minimum computation distance from the beginning of this op to the next overlap node int32_t GetMinDistance2Overlap(HashMap* task_node2topo_struct); // Memory increment = (memory of out registers) - (memory of in registers) void ComputeMemoryIncrement(); // Exceed time = time of cpu - time of gpu // For most operators, the execution time on gpu exceed the execution time on cpu. // However, overlap is needed if time of cpu > time of gpu. void ComputeExceedTime(); // Memory volume is memory * lifetime, but we might change the formula void ComputeMemoryVolume(); // TODO: We might design more deciding parameter and choose a right combination of them in the // future. 
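// For example, with decide_parameters == {kLayerDescend, kTributaryLayerDescend}, the
// waiting-set comparator defined later in this file compares -min_layer first (last in,
// first out) and only breaks ties with -tributary_layer.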
// deciding parameter // kTributaryLayerAscend = 0, // small tributary layers go first // kDistanceToOverlapAscend = 1, // small minimum distance to overlap go first // kLayerAscend = 2, // first in first out // kMemoryIncrementAscend = 3, // small memory increment go first // kExceedTimeAscend = 4, // small exceed time go first // kTributaryLayerDescend = 100, // large tributary layers go first // kDistanceToOverlapDescend = 101, // long distance to overlap go first // kLayerDescend = 102, // last in first out // kMemoryIncrementDescend = 103, // large memory increment go first // kExceedTimeDescend = 104, // large exceed time go first int64_t GetDecidingParameter(StraightenOrder so) const; }; static StraightenAlgorithmTag sat; // NOTE: Leave these code for debugging in the future // static std::vector decide_parameters({ParseIntegerFromEnv("Parameter0", 3), // ParseIntegerFromEnv("Parameter1", 0), // ParseIntegerFromEnv("Parameter2", 3)}); // The best parameter set for saving time is {102, 100} // The best parameter set for saving memory is {3, 0} static std::vector decide_parameters; // move the head from source to target void MoveFrontBetweenMaps(std::map& source, std::map& target) { if (!source.empty()) { const auto& front = source.begin(); target[front->first] = front->second; source.erase(front); } }; bool ShouldRunASAP(TaskType task_type) { // They are sorted according to frequency of occurrences switch (task_type) { // We mark the number of occurrences in bert case TaskType::kDeviceTick: // 38 case TaskType::kTick: // 8 case TaskType::kSrcSubsetTick: // 6 case TaskType::kDstSubsetTick: // 6 case TaskType::kCriticalSectionWaitTick: // 4 case TaskType::kWaitAndSendIds: // 2 case TaskType::kPack: // 0 case TaskType::kUnpack: // 0 case TaskType::kRepeat: // 0 case TaskType::kAcc: // 0 case TaskType::kSourceTick: // 0 case TaskType::kAccTick: // 0 case TaskType::kAccCtrlTick: // ? case TaskType::kCase: // 0 case TaskType::kEsac: // 0 case TaskType::kReentrantLock: return true; // 0 default: return false; } } bool IsTransferNode(TaskType task_type) { // return task_type == 12 || task_type == 13 || (48 <= task_type && task_type <= 64); // They are sorted according to frequency of occurrences switch (task_type) { // We mark the number of occurrences in bert case TaskType::kCollectiveBoxingGeneric: // 76 case TaskType::kNcclSendRecvBoxing: // ? 
case TaskType::kCopyHd: // 27 case TaskType::kSliceBoxing: // 16 case TaskType::kCopyCommNet: // 12 case TaskType::kCollectiveBoxingPack: // 8 case TaskType::kCollectiveBoxingUnpack: // 8 case TaskType::kBoxingZeros: // 3 case TaskType::kDistributeConcat: // 0 case TaskType::kDistributeSplit: // 0 case TaskType::kBoxingIdentity: // 0 case TaskType::kDecodeH2D: // 0 case TaskType::kSspVariableProxy: return true; // 0 default: return false; } } // Classifier for the set according to the task type TaskClassifier GetTaskClassifier(const TaskNode* node, bool nccl_use_compute_stream) { // Check task.pb.h for detail // They are sorted according to frequency of judgement // frequency of judgement = the number of occurrences / the times of judgement TaskType task_type = node->GetTaskType(); if (task_type == TaskType::kNormalForward) { const auto& op_conf = dynamic_cast(node)->op()->op_conf(); if (sat == StraightenAlgorithmTag::kOverlap4CpuGpu && ShortGpuTime(op_conf)) { return TaskClassifier::kWaitingOverlapNode; } else { return TaskClassifier::kWaitingMainComputation; } } if (IsTransferNode(task_type)) { if (sat == StraightenAlgorithmTag::kCompressMemory && nccl_use_compute_stream) { // Overlap is not the first consideration, memory is return TaskClassifier::kWaitingMainComputation; } else { return TaskClassifier::kWaitingOverlapNode; } } if (task_type == TaskType::kCallbackNotify) { return TaskClassifier::kRunALAP; } if (ShouldRunASAP(task_type)) { return TaskClassifier::kRunASAP; } CHECK(false) << "Unclassified or invalid task type (" << task_type << ") showing up"; // Throw a kRunASAP which means ignoring this node in the algorithm return TaskClassifier::kRunASAP; } int32_t MaxProducerMinLayer(HashMap* task_node2topo_struct, std::map>* key2topo_structs, TaskNode* node) { int32_t max_min_layer = 0; node->ForEachNodeOnInEdge([&](TaskNode* in) { max_min_layer = std::max(max_min_layer, task_node2topo_struct->at(in).ComputeMinLayer( task_node2topo_struct, key2topo_structs)); }); return max_min_layer + 1; } int32_t TopoStruct::ComputeMinLayer( HashMap* task_node2topo_struct, std::map>* key2topo_structs) { // Directly return the value if computed if (min_layer > -1) { return min_layer; } auto transport_task_node = dynamic_cast(node); if (transport_task_node) { // Only compute the minimum layer for this transport node min_layer = MaxProducerMinLayer(task_node2topo_struct, key2topo_structs, node); // Generate the key to determine the same task nodes // Since the key is connected with the min_layer for transport nodes key = transport_task_node->lbi().ShortDebugString() + "MinLayer:" + std::to_string(min_layer); // Gather all the task nodes with the same key (*key2topo_structs)[key].push_back(this); } else { // Compute the minimum layer for all the nodes with the same key simultaneously int32_t max_min_layer = -1; for (auto& curr_topo_struct : key2topo_structs->at(key)) { max_min_layer = std::max( max_min_layer, MaxProducerMinLayer(task_node2topo_struct, key2topo_structs, curr_topo_struct->node)); } for (auto& curr_topo_struct : key2topo_structs->at(key)) { curr_topo_struct->min_layer = max_min_layer; } } return min_layer; } // Drop down the maximum layer with the minimum layer from consumer void TopoStruct::DropTributaryLayer(int32_t upper_bound) { if (upper_bound < tributary_layer || tributary_layer < 0) { tributary_layer = upper_bound; } } // Should initialize the counter to be the number of out edges // Compute maximum layer for tributaries void TopoStruct::SpreadTributaryLayer(HashMap* 
task_node2topo_struct) { if (counter || min_layer <= 0) { return; } int32_t producer_max_lay = 0; if (on_trunk) { producer_max_lay = min_layer - 1; } else { // On a tributary, the operator could be run later. producer_max_lay = tributary_layer; } node->ForEachNodeOnInEdge([&](TaskNode* in) { auto& topo_struct_in = task_node2topo_struct->at(in); topo_struct_in.DropTributaryLayer(producer_max_lay); --topo_struct_in.counter; if (topo_struct_in.counter == 0) { topo_struct_in.SpreadTributaryLayer(task_node2topo_struct); } }); // Reduce counter to -1 to avoid visiting again counter--; } void TopoStruct::ComputeMaxLayer(HashMap* task_node2topo_struct) { node->ForEachNodeOnOutEdge([&](TaskNode* out) { max_layer = std::max(max_layer, task_node2topo_struct->at(out).min_layer); }); } // Judge if this node is on the trunk // If so, judge it for its producer/upstream nodes void TopoStruct::SpreadTrunk(HashMap* task_node2topo_struct) { // Skip it if this node is already judged. if (on_trunk) { return; } CHECK_GE(min_layer, 0) << "TopoStruct not initialized!"; on_trunk = true; // If I am in the trunk, then all the children with (min_layer >= my layer id - 1) would be // considered as in the trunk node->ForEachNodeOnInEdge([&](TaskNode* in) { auto& topo_struct_in = task_node2topo_struct->at(in); if (topo_struct_in.min_layer == min_layer - 1) { topo_struct_in.SpreadTrunk(task_node2topo_struct); } }); } // The minimum computation distance from the beginning of this op to the next overlap int32_t TopoStruct::GetMinDistance2Overlap(HashMap* task_node2topo_struct) { if (min_distance2overlap >= 0) { return min_distance2overlap; } // if this node should be overlapped by main computation nodes if (task_classifier == TaskClassifier::kWaitingOverlapNode) { min_distance2overlap = 0; return min_distance2overlap; } // Otherwise, initialize it with a large number // Well, the total number in the task graph is large enough min_distance2overlap = task_node2topo_struct->size(); node->ForEachNodeOnOutEdge([&](TaskNode* out) { min_distance2overlap = std::min(min_distance2overlap, task_node2topo_struct->at(out).GetMinDistance2Overlap(task_node2topo_struct)); }); ++min_distance2overlap; return min_distance2overlap; } // Memory increment = (memory of out registers) - (memory of in registers) void TopoStruct::ComputeMemoryIncrement() { if (memory_increment < 0) { memory_increment = 0; for (const auto& produced_register : node->produced_regsts()) { if (produced_register.second->enable_reuse_mem()) { RegstDescProto temp_proto; produced_register.second->ToProto(&temp_proto); memory_increment += RtRegstDesc(temp_proto).TotalMainByteSize4AllRegst(); } } for (const auto& consumed_register_list : node->consumed_regsts()) { for (const auto& consumed_register : consumed_register_list.second) { if (consumed_register->enable_reuse_mem()) { RegstDescProto temp_proto; consumed_register->ToProto(&temp_proto); memory_increment -= RtRegstDesc(temp_proto).TotalMainByteSize4AllRegst() / consumed_register->consumers().size(); } } } } } // Exceed time = time of cpu - time of gpu void TopoStruct::ComputeExceedTime() { if (node->GetTaskType() == TaskType::kNormalForward && ShortGpuTime(dynamic_cast(node)->op()->op_conf())) { exceed_time = 1; } else { exceed_time = 0; } } // Memory volume is memory * lifetime, but we might change the formula void TopoStruct::ComputeMemoryVolume() { static float lifetime_order = ParseFloatFromEnv("LifetimeOrder", 1.0); // We might get a large tensor multiply by a long life time, we need some rescaling memory_volume = 
static_cast( (memory_increment * pow(static_cast(min_lifetime), lifetime_order)) / 1000.0); // We need to distinguish zero or negative memory increment from slight positive memory increment. // Make sure that we execute -0.1, 0, -0.003 before 0.1, 0.2 if (memory_increment > 0) { memory_volume += 1; } } // deciding parameter // kTributaryLayerAscend = 0, // small tributary layers go first // kDistanceToOverlapAscend = 1, // small minimum distance to overlap go first // kLayerAscend = 2, // first in first out // kMemoryIncrementAscend = 3, // small memory increment go first // kExceedTimeAscend = 4, // small exceed time go first // kMemoryVolumeAscend = 5, // small memory volume go first // kTributaryLayerDescend = 100, // large tributary layers go first // kDistanceToOverlapDescend = 101, // long distance to overlap go first // kLayerDescend = 102, // last in first out // kMemoryIncrementDescend = 103, // large memory increment go first // kExceedTimeDescend = 104, // large exceed time go first // kMemoryVolumeAscend = 105, // large memory volume go first int64_t TopoStruct::GetDecidingParameter(StraightenOrder so) const { int64_t sign = 1; if (so >= kDiff4AscendDescend) { so = StraightenOrder(int(so) - kDiff4AscendDescend); sign = -1; } switch (so) { case StraightenOrder::kTributaryLayerAscend: return sign * tributary_layer; case StraightenOrder::kDistanceToOverlapAscend: return sign * min_distance2overlap; case StraightenOrder::kLayerAscend: return sign * min_layer; case StraightenOrder::kMemoryIncrementAscend: return sign * memory_increment; case StraightenOrder::kExceedTimeAscend: return sign * exceed_time; case StraightenOrder::kMemoryVolumeAscend: return sign * memory_volume; case StraightenOrder::kMaxLayerAscend: return sign * max_layer; default: return 0; } } // Find the trunk of the task graph, then reduce the wait time for tributaries void FindTrunk(HashMap* task_node2topo_struct) { // Find the maximum layer number int32_t max_min_layer = -1; for (const auto& pair : *task_node2topo_struct) { if (max_min_layer < pair.second.min_layer) { max_min_layer = pair.second.min_layer; } } // All the nodes with min_layer>=trunk_end_id would be considered as trunk nodes // The last 5 layers would be considered as in trunk anyway. int32_t trunk_end_id = max_min_layer - 4; for (auto& pair : *task_node2topo_struct) { auto& topo_struct = pair.second; // Initialize the counter and Tributary Layer topo_struct.counter = pair.first->out_edges().size(); topo_struct.tributary_layer = max_min_layer; // Find out all the nodes on the trunk. if (topo_struct.min_layer >= trunk_end_id) { topo_struct.SpreadTrunk(task_node2topo_struct); } } for (auto& pair : *task_node2topo_struct) { // Compute maximum layer for tributaries pair.second.SpreadTributaryLayer(task_node2topo_struct); // Set the min_distance2overlap for each topological structure pair.second.GetMinDistance2Overlap(task_node2topo_struct); } // The computation of maximum layer must behind those of minimum layer for the whole graph. for (auto& pair : *task_node2topo_struct) { pair.second.ComputeMaxLayer(task_node2topo_struct); } } // Find the minimum life time of the task graph, // which is the maximum of the minimum layer among all the consumers. 
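// (Worked example with assumed layer numbers: a register produced by a node at min_layer 3
//  whose last consumer sits at min_layer 7 ends up with min_lifetime = 7 - 3 + 1 = 5.)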
// The function must be executed after generating min layer void FindMinLifetime(HashMap* task_node2topo_struct) { // Find the maximum consumer layer for (auto& pair : *task_node2topo_struct) { int32_t curr_min_layer = pair.second.min_layer; pair.first->ForEachNodeOnInDataEdge([&](TaskNode* in) { auto& max_consumer_layer = task_node2topo_struct->at(in).min_lifetime; if (max_consumer_layer < curr_min_layer) { max_consumer_layer = curr_min_layer; } }); } // Compute the life time for (auto& pair : *task_node2topo_struct) { if (pair.second.min_layer >= pair.second.min_lifetime) { // No consumer, the register will be killed after the execution of the current operator // The life time is 1 (including the current operator) pair.second.min_lifetime = 1; } else { // The life time is the distance between two operators + 1 // For example, a ---(x)---> b // Register x is created while executing a, and x is killed after the execution of b. // The life time is 2 (including a and b) == b.lifetime - a.lifetime pair.second.min_lifetime -= pair.second.min_layer - 1; } pair.second.ComputeMemoryVolume(); } } } // anonymous namespace // Some operators have longer time in cpu and less time in gpu. // Running those operators without overlap would cause large gap during each iteration. // For example, expand dims would not execute any kernel on gpu but still need 10us to execute some // functions on cpu. bool ShortGpuTime(const OperatorConf& op_conf) { if (op_conf.has_variable_conf()) { // Variable operators would not be run. They just create tensors. // We do not visualize any execution in NVTX. (Even a tick operator has something in NVTX.) return true; } if (op_conf.has_user_conf()) { const auto& op_type_name = op_conf.user_conf().op_type_name(); // They are sorted according to frequency of occurrences in stable diffusion if (op_type_name == "expand_dims" // 90 || op_type_name == "cast" // 16 || op_type_name == "expand" // 2 ) { return true; } } return false; } // SAT, a.k.a. Scholastic Aptitude Test, // is the college admission test in the United States of America. void InitDecideParameters(StraightenAlgorithmTag sat, std::vector* decide_parameters) { decide_parameters->clear(); if (sat == StraightenAlgorithmTag::kCompressMemory) { decide_parameters->push_back(StraightenOrder::kMemoryVolumeAscend); decide_parameters->push_back(StraightenOrder::kMemoryIncrementAscend); decide_parameters->push_back(StraightenOrder::kTributaryLayerAscend); } else if (sat == StraightenAlgorithmTag::kOverlap4Transfer) { decide_parameters->push_back(StraightenOrder::kLayerDescend); decide_parameters->push_back(StraightenOrder::kTributaryLayerDescend); } else if (sat == StraightenAlgorithmTag::kOverlap4CpuGpu) { decide_parameters->push_back(StraightenOrder::kExceedTimeDescend); decide_parameters->push_back(StraightenOrder::kLayerDescend); decide_parameters->push_back(StraightenOrder::kMemoryIncrementAscend); } else if (sat == StraightenAlgorithmTag::kDelayShortGpu) { decide_parameters->push_back(StraightenOrder::kExceedTimeAscend); decide_parameters->push_back(StraightenOrder::kMaxLayerAscend); decide_parameters->push_back(StraightenOrder::kMemoryIncrementAscend); } else { // sat == StraightenAlgorithmTag::kDisable decide_parameters->push_back(StraightenOrder::kLayerAscend); } } // Maximum overlap number // While running an overlap operator, we would run some other operators simultaneously. 
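// In the main straightening loop below, this cap enters roughly as
//   computation_num = min(#waiting main / #waiting overlap, remaining main / remaining overlap, cap);
// e.g. (assumed numbers) with 40 remaining computation nodes, 4 remaining overlap nodes and the
// default cap of 10, up to 10 computation nodes are interleaved behind each overlap node.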
int32_t MaximumOverlapNum(StraightenAlgorithmTag sat, bool nccl_use_compute_stream) { if (sat == StraightenAlgorithmTag::kOverlap4CpuGpu) { // 10 operators on GPU is enough to cover the time for a CPU operator return 10; } // This condition should be following the sat == StraightenAlgorithmTag::kOverlap4CpuGpu // Since the kOverlap4CpuGpu would not be affected by transfer. if (nccl_use_compute_stream) { // Using nccl compute stream would disable the overlap for transfer // We need to reduce it to 1 return 1; } if (sat == StraightenAlgorithmTag::kCompressMemory) { // Actually we do not need the overlap. // Time is not the main consideration, memory is. return 2; } // The default number is 10. Mainly for sat == StraightenAlgorithmTag::kOverlap4Transfer // sat == StraightenAlgorithmTag::kDisable does not need a maximum overlap number. return 10; } void StraightenNodes(TaskGraph* task_graph, std::vector* ordered_task_nodes, bool nccl_use_compute_stream) { // Generate topological data structure for each task node HashMap task_node2topo_struct; // Determine the same nodes which should run simultaneously by the keys std::map> key2topo_structs; task_graph->TopoForEachNode([&](TaskNode* node) { auto& topo_struct = task_node2topo_struct[node]; topo_struct.node = node; topo_struct.ComputeMemoryIncrement(); topo_struct.ComputeExceedTime(); // Generate the key to determine the same task nodes if (dynamic_cast(node)) { // Deal with the key and the same task nodes later return; // topo_struct.key = dynamic_cast(node)->lbi().ShortDebugString(); } else if (node->GetTaskType() == TaskType::kNormalForward) { topo_struct.key = dynamic_cast(node)->op()->op_name(); } else { topo_struct.key = node->VisualStr(); } // Gather all the task nodes with the same key key2topo_structs[topo_struct.key].push_back(&topo_struct); }); // Compute all the min layer and generate the rest of the keys for (auto& pair : task_node2topo_struct) { pair.second.ComputeMinLayer(&task_node2topo_struct, &key2topo_structs); } // Generate other parameters in the topological data structure FindTrunk(&task_node2topo_struct); FindMinLifetime(&task_node2topo_struct); // Update sat, since sat might be changed in previous jobs UpdateSat(task_node2topo_struct, &sat); // Decide the task classifier after updating sat for (auto& pair : task_node2topo_struct) { pair.second.task_classifier = GetTaskClassifier(pair.first, nccl_use_compute_stream); } // Check the task classifier for all the nodes with the same key for (auto& pair : key2topo_structs) { TaskClassifier first_task_classifier = pair.second.at(0)->task_classifier; for (auto& topo_struct : pair.second) { CHECK_EQ(first_task_classifier, topo_struct->task_classifier) << " We have different task classifier " << first_task_classifier << " and " << topo_struct->task_classifier << " for the nodes with the same key: " << pair.first; } } // Decide which node should run first InitDecideParameters(sat, &decide_parameters); VLOG(3) << "Straightening order: "; for (int32_t decide_parameter : decide_parameters) { VLOG(3) << decide_parameter; } // Order in the waiting sets struct comp { bool operator()(const TopoStruct* a, const TopoStruct* b) const { for (auto decide_parameter : decide_parameters) { auto decide_parameter_a = a->GetDecidingParameter(decide_parameter); auto decide_parameter_b = b->GetDecidingParameter(decide_parameter); if (decide_parameter_a != decide_parameter_b) { return decide_parameter_a < decide_parameter_b; } } return a->node < b->node; } }; // Classify sets for the task nodes // 0, 
TaskClassifier::kWaitingOverlapNode // It contains transfer nodes, and those with less time in gpu if request. // std::set waiting_overlap_node; // 1, TaskClassifier::kWaitingMainComputation // std::set waiting_main_computation; // 2, TaskClassifier::kRunASAP , run as soon as possible // std::set run_asap; // 3, TaskClassifier::kRunALAP , run as late as possible // std::set run_alap; const int32_t num_classifier = 4; std::vector> waiting_lists(num_classifier); std::vector remain_task_nums(num_classifier, 0); auto AddOrderedNodes = [&](TaskNode* task_node) { ordered_task_nodes->emplace_back(task_node); }; // wait in the list auto wait = [&](TaskNode* node) { TopoStruct* first_topo_struct = &task_node2topo_struct[node]; // Check if all the same nodes are ready simultaneously for (auto& curr_topo_struct : key2topo_structs.at(first_topo_struct->key)) { if (curr_topo_struct->counter) { return; } } // Add all the same nodes at the same time auto& waiting_list = waiting_lists[first_topo_struct->task_classifier]; for (auto& curr_topo_struct : key2topo_structs.at(first_topo_struct->key)) { waiting_list.insert(curr_topo_struct); // Reduce counter then this node will never be added again // Though inserting into a map twice does not matter because of the same keys curr_topo_struct->counter--; } }; // initialization task_graph->ForEachNode([&](TaskNode* node) { int32_t count = node->in_edges().size(); auto& topo_struct = task_node2topo_struct[node]; topo_struct.counter = count; if (count == 0) { wait(node); } remain_task_nums[topo_struct.task_classifier]++; }); // Finish execution auto finish_execution = [&](TaskNode* node) { node->ForEachNodeOnOutEdge([&](TaskNode* out) { --(task_node2topo_struct[out].counter); if (task_node2topo_struct[out].counter == 0) { wait(out); } }); }; // Find the iterator of an element in set // Make sure that the element exist in the set before using this function auto FindElementInSet = [&](TopoStruct* element, std::set& set) { auto it = set.find(element); // NOTE: In some cases, the set can not find this element // Tested in machine-16: // Deleting: 0x7f75041d64c0, size: 4: // 0x7f75041d64c0, 0x7f75040d7390, 0x7f7504384540, 0x7f75042bc410, // Find: 0x4 // Or it may have the chance to delete multiple elements while deleting one element. CHECK(it != set.end() && *it == element) << " Something happens. If you make sure that the element exist in the set but you still " "can not find that element, please report this issue to Oneflow Inc."; // TODO: One simple resolution is to traverse all the elements in the set and find the // corresponding iterator. But it is not recommended. If std::set do have problem, we may need // to implement our own set. Or we find out the problematic version of std and make it clear to // the users that we do not support that version. // We may be able to reproduce the bug in the commit 0c06021c7e48d2e84d20e555e4f4dfbaf04a5e7b // by running // ONEFLOW_LAZY_COMPILE_MODE="rank_per_thread" ONEFLOW_TEST_DEVICE_NUM=4 python3 -m // oneflow.distributed.launch --nproc_per_node 4 -m unittest discover . --failfast --verbose // under the path oneflow/python/oneflow/test/graph // We still need to delete the file test_alexnet_auto_parallel.py before running the command. 
return it; }; // Since the erase function call the find function // we also need to reset the erase function auto EraseElementInSet = [&](TopoStruct* element, std::set& set) { set.erase(FindElementInSet(element, set)); }; // Move the first node of the waiting list to the execution list auto move2execution_list = [&](std::set& waiting_list, std::vector& execution_list) { TaskNode* first_node = (*waiting_list.begin())->node; int32_t execution_num = 0; TopoStruct* first_topo_struct = &task_node2topo_struct[first_node]; // Find all the same nodes in different machines which should be run simultaneously for (auto& curr_topo_struct : key2topo_structs.at(first_topo_struct->key)) { execution_num++; execution_list.push_back(curr_topo_struct->node); EraseElementInSet(curr_topo_struct, waiting_list); } CHECK_GT(execution_num, 0) << "Error, no task nodes are moved to the execution list"; }; // Execute the first n nodes in the waiting list auto execute = [&](int32_t list_classifier, int32_t n, bool if_reverse = false) { // n > 0 if (n <= 0) { return; } auto& waiting_list = waiting_lists[list_classifier]; std::vector execution_list; int32_t count = 0; // Move to the execution list while (!waiting_list.empty()) { move2execution_list(waiting_list, execution_list); count++; if (count >= n) { break; } } remain_task_nums[list_classifier] -= execution_list.size(); // Set the order and then remove from the execution list for (auto* node : execution_list) { AddOrderedNodes(node); finish_execution(node); } }; // straightening int32_t maximum_overlap_num = MaximumOverlapNum(sat, nccl_use_compute_stream); while (true) { if (waiting_lists[TaskClassifier::kRunASAP].empty()) { if (waiting_lists[TaskClassifier::kWaitingOverlapNode].empty()) { if (waiting_lists[TaskClassifier::kWaitingMainComputation].empty()) { if (waiting_lists[TaskClassifier::kRunALAP].empty()) { // All the waiting lists are empty break; } else { // Execute all the nodes left execute(TaskClassifier::kRunALAP, waiting_lists[TaskClassifier::kRunALAP].size()); } } else { // Execute one computation node execute(TaskClassifier::kWaitingMainComputation, 1); } } else { int32_t computation_num = std::min( std::min(int32_t(waiting_lists[TaskClassifier::kWaitingMainComputation].size() / (waiting_lists[TaskClassifier::kWaitingOverlapNode].size())), remain_task_nums[TaskClassifier::kWaitingMainComputation] / remain_task_nums[TaskClassifier::kWaitingOverlapNode]), maximum_overlap_num); // Holding the node to be overlapped std::vector overlap_execution_list; move2execution_list(waiting_lists[TaskClassifier::kWaitingOverlapNode], overlap_execution_list); remain_task_nums[TaskClassifier::kWaitingOverlapNode] -= overlap_execution_list.size(); for (auto* overlap_node : overlap_execution_list) { AddOrderedNodes(overlap_node); } // Overlap the node with computation from the trunk execute(TaskClassifier::kWaitingMainComputation, computation_num); // Release the overlap node for (auto* overlap_node : overlap_execution_list) { finish_execution(overlap_node); } } } else { execute(TaskClassifier::kRunASAP, waiting_lists[TaskClassifier::kRunASAP].size()); } } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/straighten_nodes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_ #define ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_ #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/job/job_conf.pb.h" namespace oneflow { // The difference between a descending order and its corresponding ascending order const int kDiff4AscendDescend = 100; // deciding parameter // The sorting order of nodes for the straighten algorithm enum StraightenOrder : int { kTributaryLayerAscend = 0, // small tributary layers go first kDistanceToOverlapAscend = 1, // small minimum distance to overlap go first kLayerAscend = 2, // first in first out kMemoryIncrementAscend = 3, // small memory increment go first kExceedTimeAscend = 4, // small exceed time go first kMemoryVolumeAscend = 5, // small memory volume go first kMaxLayerAscend = 6, // the urgent one go first kTributaryLayerDescend = kDiff4AscendDescend + kTributaryLayerAscend, // large tributary layers go first kDistanceToOverlapDescend = kDiff4AscendDescend + kDistanceToOverlapAscend, // long distance to overlap go first kLayerDescend = kDiff4AscendDescend + kLayerAscend, // last in first out kMemoryIncrementDescend = kDiff4AscendDescend + kMemoryIncrementAscend, // large memory increment go first kExceedTimeDescend = kDiff4AscendDescend + kExceedTimeAscend, // large exceed time go first kMemoryVolumeDescend = kDiff4AscendDescend + kMemoryVolumeAscend, // large memory volume go first kMaxLayerDescent = kDiff4AscendDescend + kMaxLayerAscend, // the non-urgent one go first }; // Some operators have longer time in cpu and less time in gpu. // Running those operators without overlap would cause large gap during each iteration. // For example, expand dims would not execute any kernel on gpu but still need 10us to execute some // functions on cpu. bool ShortGpuTime(const OperatorConf& op_conf); // SAT, a.k.a. Scholastic Aptitude Test, // is the college admission test in the United States of America. void InitDecideParameters(StraightenAlgorithmTag sat, std::vector* decide_parameters); // Maximum overlap number // While running an overlap operator, we would run some other operators simultaneously. int32_t MaximumOverlapNum(StraightenAlgorithmTag sat, bool nccl_use_compute_stream); template void UpdateSat(const HashMapType& node2topo_struct, StraightenAlgorithmTag* sat) { *sat = GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph(); if (*sat == StraightenAlgorithmTag::kOverlap4CpuGpu) { // If not cpu nodes, then the overlap strategy between cpu and gpu might consume large memory bool exist_cpu_nodes = false; for (const auto& pair : node2topo_struct) { // Found a cpu node if (pair.second.exceed_time == 1) { exist_cpu_nodes = true; break; } } if (!exist_cpu_nodes) { // Switch to the compress memory strategy, the default one // Since the overlap strategy for transfer might not be working on 1n1d. 
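      // For instance (assumed workload): if no node is flagged by ShortGpuTime (no variable,
      // expand_dims, cast or expand ops), every exceed_time stays 0, so kOverlap4CpuGpu falls
      // back to the memory-compression strategy below.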
      *sat = StraightenAlgorithmTag::kCompressMemory;
    }
  }
}

// Make sure that we use the same boolean value nccl_use_compute_stream through the straighten
// algorithm
void StraightenNodes(TaskGraph* task_graph, std::vector<TaskNode*>* ordered_task_nodes,
                     bool nccl_use_compute_stream);

}  // namespace oneflow

#endif  // ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_

================================================
FILE: oneflow/core/graph/stream_id.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/stream_id.h"

#include <climits>

namespace oneflow {

// StreamId encoding (bits)
// | reserved |   node_index   | device_type | device_index  | stream_index |
// | -- 18 -- | ----- 19 ----- | ---- 5 ---- | ----- 7 ----- |              |
// |          |                  DeviceId                    |              |
// |          | ------------------- 31 --------------------- | ---- 15 ---- |
// |                                StreamId                                |
// | -------------------------------- 64 --------------------------------- |

namespace {

constexpr size_t kInt64Bits = sizeof(int64_t) * CHAR_BIT;

constexpr size_t kDeviceIndexShift = StreamId::kStreamIndexBits;
constexpr size_t kDeviceTypeShift = kDeviceIndexShift + DeviceId::kDeviceIndexBits;
constexpr size_t kRankShift = kDeviceTypeShift + DeviceId::kDeviceTypeBits;
static_assert(kRankShift + DeviceId::kRankBits < kInt64Bits, "");

constexpr int64_t kStreamIndexInt64Mask = (int64_t{1} << StreamId::kStreamIndexBits) - 1;
constexpr int64_t kDeviceIndexInt64Mask = ((int64_t{1} << DeviceId::kDeviceIndexBits) - 1)
                                          << kDeviceIndexShift;
constexpr int64_t kDeviceTypeInt64Mask = ((int64_t{1} << DeviceId::kDeviceTypeBits) - 1)
                                         << kDeviceTypeShift;
constexpr int64_t kRankInt64Mask = ((int64_t{1} << DeviceId::kRankBits) - 1) << kRankShift;

}  // namespace

int64_t EncodeStreamIdToInt64(const StreamId& stream_id) {
  int64_t id = static_cast<int64_t>(stream_id.stream_index());
  id |= static_cast<int64_t>(stream_id.device_index()) << kDeviceIndexShift;
  id |= static_cast<int64_t>(stream_id.device_type()) << kDeviceTypeShift;
  id |= static_cast<int64_t>(stream_id.rank()) << kRankShift;
  return id;
}

StreamId DecodeStreamIdFromInt64(int64_t stream_id_val) {
  int64_t rank = (stream_id_val & kRankInt64Mask) >> kRankShift;
  int64_t device_type = (stream_id_val & kDeviceTypeInt64Mask) >> kDeviceTypeShift;
  int64_t device_index = (stream_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;
  int64_t stream_index = (stream_id_val & kStreamIndexInt64Mask);
  return StreamId{static_cast<DeviceId::rank_t>(rank), static_cast<DeviceType>(device_type),
                  static_cast<DeviceId::device_index_t>(device_index),
                  static_cast<StreamId::stream_index_t>(stream_index)};
}

}  // namespace oneflow

================================================
FILE: oneflow/core/graph/stream_id.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_STREAM_ID_H_
#define ONEFLOW_CORE_GRAPH_STREAM_ID_H_

#include "oneflow/core/device/device_id.h"

namespace oneflow {

class StreamId {
 public:
  using stream_index_t = uint32_t;

  constexpr static size_t kStreamIndexBits = 15;
  constexpr static stream_index_t kMaxStreamIndex =
      (stream_index_t{1} << kStreamIndexBits) - stream_index_t{1};

  StreamId(const DeviceId& device_id, stream_index_t stream_index)
      : device_id_(device_id), stream_index_(stream_index) {
    CHECK_LE(stream_index, kMaxStreamIndex);
  }
  StreamId(DeviceId::rank_t node_index, DeviceType device_type,
           DeviceId::device_index_t device_index, stream_index_t stream_index)
      : device_id_(node_index, device_type, device_index), stream_index_(stream_index) {
    CHECK_LE(stream_index, kMaxStreamIndex);
  }

  const DeviceId& device_id() const { return device_id_; }
  DeviceId::rank_t rank() const { return device_id_.rank(); }
  DeviceType device_type() const { return device_id_.device_type(); }
  DeviceId::device_index_t device_index() const { return device_id_.device_index(); }
  stream_index_t stream_index() const { return stream_index_; }

  bool operator==(const StreamId& rhs) const {
    return device_id_ == rhs.device_id_ && stream_index_ == rhs.stream_index_;
  }
  bool operator!=(const StreamId& rhs) const { return !(*this == rhs); }

  size_t hash() const {
    size_t hash = device_id_.hash();
    HashCombine(&hash, std::hash<stream_index_t>{}(stream_index_));
    return hash;
  }

 private:
  DeviceId device_id_;
  stream_index_t stream_index_;
};

int64_t EncodeStreamIdToInt64(const StreamId&);
StreamId DecodeStreamIdFromInt64(int64_t);

}  // namespace oneflow

namespace std {

template<>
struct hash<oneflow::StreamId> {
  size_t operator()(const oneflow::StreamId& stream_id) const { return stream_id.hash(); }
};

}  // namespace std

#endif  // ONEFLOW_CORE_GRAPH_STREAM_ID_H_

================================================
FILE: oneflow/core/graph/stream_index_generator.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/graph/stream_index_generator.h" #include #include "oneflow/core/job/id_state.h" namespace oneflow { StreamIndexGenerator::StreamIndexGenerator(stream_index_t stream_index) : next_stream_index_(stream_index) {} StreamIndexGenerator::stream_index_t StreamIndexGenerator::GenerateAnonymous() { std::unique_lock lck(mtx_); return next_stream_index_++; } StreamIndexGenerator::stream_index_t StreamIndexGenerator::GenerateNamed(const std::string& name) { return GenerateNamedRoundRobin(name, 1); } StreamIndexGenerator::stream_index_t StreamIndexGenerator::GenerateNamedRoundRobin( const std::string& name, size_t size) { CHECK_GT(size, 0); std::unique_lock lck(mtx_); auto it = name2rr_range_.find(name); if (it == name2rr_range_.end()) { it = name2rr_range_.emplace(name, RoundRobinRange{next_stream_index_, size}).first; next_stream_index_ += size; } else { CHECK_EQ(it->second.size, size) << name; } stream_index_t cur_stream_index = it->second.begin; if (size > 1) { size_t& offset = it->second.offset; cur_stream_index += offset++; if (offset >= size) { offset = 0; } } return cur_stream_index; } StreamIndexGenerator::stream_index_t StreamIndexGenerator::GetCurrStreamIndex() { std::unique_lock lck(mtx_); return next_stream_index_; } void StreamIndexGenerator::TryUpdateNextStreamIndex(stream_index_t next_stream_index) { std::unique_lock lck(mtx_); next_stream_index_ = std::max(next_stream_index_, next_stream_index); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/stream_index_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_GRAPH_STREAM_INDEX_GENERATOR_H_
#define ONEFLOW_CORE_GRAPH_STREAM_INDEX_GENERATOR_H_

#include <mutex>

#include "oneflow/core/graph/stream_id.h"
#include "oneflow/core/job/id_state.h"

namespace oneflow {

class StreamIndexGenerator final {
 public:
  using stream_index_t = StreamId::stream_index_t;

  explicit StreamIndexGenerator(stream_index_t stream_index);
  OF_DISALLOW_COPY_AND_MOVE(StreamIndexGenerator);
  ~StreamIndexGenerator() = default;

  stream_index_t GenerateAnonymous();
  stream_index_t GenerateNamed(const std::string& name);
  stream_index_t GenerateNamedRoundRobin(const std::string& name, size_t size);
  stream_index_t GetCurrStreamIndex();
  void TryUpdateNextStreamIndex(stream_index_t next_stream_index);

 private:
  struct RoundRobinRange {
    RoundRobinRange(stream_index_t begin, size_t size) : begin(begin), size(size), offset(0) {}
    stream_index_t begin;
    size_t size;
    size_t offset;
  };

  stream_index_t next_stream_index_;
  HashMap<std::string, RoundRobinRange> name2rr_range_;
  std::mutex mtx_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_GRAPH_STREAM_INDEX_GENERATOR_H_

================================================
FILE: oneflow/core/graph/task_edge.proto
================================================
syntax = "proto2";
package oneflow;

import "oneflow/core/register/logical_blob_id.proto";

message TaskEdgeProto {
  required int64 task_edge_uid = 1;
  required int64 src_task_id = 2;
  required int64 dst_task_id = 3;
  repeated LogicalBlobId lbi = 4;
  map<string, int64> name_in_producer2regst_desc_id = 5;
};

================================================
FILE: oneflow/core/graph/task_graph.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/graph/inplace_lbi_graph.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/task.pb.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/operator/variable_op.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/boxing_identity_task_node.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h" #include "oneflow/core/graph/task_stream_index_manager.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/graph/straighten_nodes.h" #include "oneflow/core/register/runtime_register_desc.h" #include "oneflow/core/common/env_var/env_var.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/graph/task_graph_rebuild_ctx.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/graph/task_type_visitor.h" namespace oneflow { // TODO(Chengcheng): default false. DEFINE_ENV_BOOL(ONEFLOW_ENABLE_OUTDATED_OPT_FW_CHAIN_MERGE, true); namespace { bool IsMemcpyPrimitiveSupported(DeviceType device_type, ep::primitive::MemcpyKind kind) { auto primitive = ep::primitive::NewPrimitive(device_type, kind); return primitive.operator bool(); } bool IsMemcpyHtoDSupported(DeviceType device_type) { return IsMemcpyPrimitiveSupported(device_type, ep::primitive::MemcpyKind::kHtoD); } bool IsMemcpyDtoHSupported(DeviceType device_type) { return IsMemcpyPrimitiveSupported(device_type, ep::primitive::MemcpyKind::kDtoH); } bool IsConnectToTickOp(const TaskNode* node) { const auto* comp_task_node = dynamic_cast(node); if (comp_task_node == nullptr) { return false; } const Operator* op = comp_task_node->op().get(); if (dynamic_cast(op) != nullptr) { return true; } return false; } bool IsSubsetTickOpConf(const OperatorConf& op_conf) { return op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf(); } bool IsTickOpConf(const OperatorConf& conf) { return IsClassRegistered(conf.op_type_case()); } const std::string& GetOpConfCalculationPassName(const OperatorConf& op_conf) { CHECK(op_conf.has_scope_symbol_id()); if (op_conf.has_calculation_pass_name()) { return op_conf.calculation_pass_name(); } int64_t scope_symbol_id = op_conf.scope_symbol_id(); CHECK(Singleton>::Get()->Has(scope_symbol_id)) << " Error! op : \n " << op_conf.DebugString() << " has error scope_symbol_id = " << scope_symbol_id << " which cannot find in Singleton>::Get()\n"; const Scope& scope = Singleton>::Get()->Get(scope_symbol_id); return scope.scope_proto().calculation_pass_name(); } bool IsOptimizerPassOp(const Operator* op) { // NOTE(chengcheng): use scope::calculation_pass_name instead of area_id to not merge optimizer // ops with fw/bw ops if (!op->op_conf().has_scope_symbol_id()) { // NOTE(chengcheng): Some system op insert to OpGraph may not set scope_symbol_id, it MUST NOT // optimizer subgraph ops. 
return false; } return GetOpConfCalculationPassName(op->op_conf()) == kOptimizerPass; } bool IsSpecialOpNotConsiderMergeInChain(const Operator* op) { const OperatorConf& op_conf = op->op_conf(); if (op_conf.has_variable_conf() || op_conf.has_tick_conf() || op_conf.has_device_tick_conf() || op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf() || op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf() || op_conf.has_acc_tick_conf()) { return true; } if (op_conf.has_user_conf()) { const std::string& user_type_name = op_conf.user_conf().op_type_name(); if (user_type_name == "repeat" || user_type_name == "acc" || user_type_name == "pack" || user_type_name == "unpack" || user_type_name == "identity_buffer") { return true; } } // NOTE(chengcheng): ONLY nccl_use_compute_stream = false will exclude optimizer pass ops if (!Singleton::Get()->nccl_use_compute_stream() && IsOptimizerPassOp(op) && EnvBool()) { return true; } return false; } bool IsTaskNodeProducedRegstHasMultiRegstNum(const TaskNode* node) { for (const auto& pair : node->produced_regsts()) { if (pair.second->min_register_num() > 1) { return true; } } return false; } bool CanBeMergedInChain(const TaskNode* node) { // ONLY the node which is NormalForward and in GPU and NOT variable can be merged. if (IsTaskNodeProducedRegstHasMultiRegstNum(node)) { return false; } const auto* fw_comp_node = dynamic_cast(node); if (fw_comp_node == nullptr) { return false; } if (fw_comp_node->device_type() == DeviceType::kCPU) { return false; } const Operator* op = fw_comp_node->op().get(); if (IsSpecialOpNotConsiderMergeInChain(op)) { return false; } return true; } std::shared_ptr GetTaskNodeTimeShape(const TaskNode* node) { const auto* fw_comp_node = dynamic_cast(node); CHECK(fw_comp_node != nullptr); return CHECK_JUST(fw_comp_node->op()->GetOpTimeShape()); } void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_t this_chain_id) { CHECK(IsValidChainId(this_chain_id)); CHECK(!IsValidChainId(this_node->chain_id())); // bfs search all node can be merged in this chain std::shared_ptr seed_time_shape = GetTaskNodeTimeShape(this_node); HashSet visited_nodes; std::queue queued_nodes; queued_nodes.push(this_node); visited_nodes.insert(this_node); while (!queued_nodes.empty()) { TaskNode* cur_node = queued_nodes.front(); queued_nodes.pop(); CHECK(!IsValidChainId(cur_node->chain_id())); cur_node->set_chain_id(this_chain_id); cur_node->ForEachNodeOnInOutDataEdge([&](TaskNode* next_node) { if (visited_nodes.find(next_node) == visited_nodes.end() && CanBeMergedInChain(next_node) && this_node->thrd_id() == next_node->thrd_id() && (*GetTaskNodeTimeShape(next_node)) == (*seed_time_shape)) { if (!IsValidChainId(next_node->chain_id())) { queued_nodes.push(next_node); visited_nodes.insert(next_node); } else { CHECK_EQ(next_node->chain_id(), this_chain_id); } } }); } } std::function MakeGetterTaskNode4SoleOpName( const HashSet& task_nodes) { auto op_name2task_nodes = std::make_shared>>(); for (TaskNode* task_node : task_nodes) { if (task_node->exec_gph().node_num() == 1) { ExecNode* exec_node = task_node->exec_gph().SoleNode(); CHECK((*op_name2task_nodes)[exec_node->op()->op_name()].emplace(task_node).second); } } return [op_name2task_nodes](const std::string& op_name) -> TaskNode* { const auto& iter = op_name2task_nodes->find(op_name); if (iter == op_name2task_nodes->end()) { return nullptr; } if (iter->second.size() > 1) { return nullptr; } return *iter->second.begin(); }; } bool IsLbiOnTaskEdge(const TaskEdge* edge, 
const LogicalBlobId& lbi) { for (const auto& regst_desc : edge->GetRegsts()) { if (regst_desc->HasLbi(lbi)) { return true; } } return false; } std::function MakePredicatorIsLbiAllConsumersReachable( const std::function& TaskNode4SoleOpName, const std::function& IsOpNameDataOrCtrlReachable) { auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node, const TaskNode* dst_node) -> bool { if (IsValidChainId(src_node->chain_id()) && IsValidChainId(dst_node->chain_id()) && src_node->chain_id() == dst_node->chain_id() && src_node->order_in_chain() <= dst_node->order_in_chain()) { return true; } const CompTaskNode* comp_src_node = dynamic_cast(src_node); if (comp_src_node == nullptr) { return false; } const CompTaskNode* comp_dst_node = dynamic_cast(dst_node); if (comp_dst_node == nullptr) { return false; } return IsOpNameDataOrCtrlReachable(comp_src_node->op()->op_name(), comp_dst_node->op()->op_name()); }; return [TaskNode4SoleOpName, IsDataOrCtrlReachable](const LogicalBlobId& lbi, const std::string& op_name) -> bool { const TaskNode* src_task_node = TaskNode4SoleOpName(lbi.op_name()); const TaskNode* dst_task_node = TaskNode4SoleOpName(op_name); size_t out_edges_size = 0; size_t reachable_out_edges_size = 0; for (TaskEdge* out_edge : src_task_node->out_edges()) { if (IsLbiOnTaskEdge(out_edge, lbi)) { out_edges_size += 1; reachable_out_edges_size += IsDataOrCtrlReachable(out_edge->dst_node(), dst_task_node); } } return out_edges_size > 0 && out_edges_size == reachable_out_edges_size; }; } bool IsInplaceAllowed( TaskNode* task_node, const std::vector& bns, const std::function& TaskNode4SoleOpName) { if (task_node->exec_gph().node_num() != 1) { return false; } const auto& exec_node = *task_node->exec_gph().SoleNode(); for (const auto& bn : bns) { // TaskNode for bn is not nullptr if it's on the same device with `task_node` if (TaskNode4SoleOpName(exec_node.op()->BnInOp2Lbi(bn).op_name()) == nullptr) { return false; } const RegstDesc& regst_desc = *exec_node.RegstDesc4BnInOp(bn); if (regst_desc.NumOfLbi() != 1) { return false; } } const BlobDesc* first_blob = nullptr; for (const auto& bn : bns) { const BlobDesc* blob_desc = exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc(); if (first_blob == nullptr) { first_blob = blob_desc; } else { if (!(first_blob->shape().elem_cnt() == blob_desc->shape().elem_cnt() && first_blob->data_type() == blob_desc->data_type())) { return false; } } } return true; } std::unique_ptr CreateBoxingLogger() { if (Singleton::Get()->enable_debug_mode()) { return std::unique_ptr( new CsvBoxingLogger(StrCat("boxing/log/", GlobalJobDesc().job_id()) + ".csv")); } else { return std::unique_ptr(new NullBoxingLogger()); } } Maybe MakeGetterTaskNode4MachineId7ThrdId( const std::vector& task_nodes, std::function(int64_t mchn_id, int64_t thrd_id)>* Getter) { // ticks are shared within a machine/process auto machine_id2task_node = std::make_shared>(); for (auto* task_node : task_nodes) { machine_id2task_node->emplace(task_node->machine_id(), task_node); } *Getter = [machine_id2task_node](int64_t mchn_id, int64_t thrd_id) -> Maybe { const auto& iter = machine_id2task_node->find(mchn_id); CHECK_OR_RETURN(iter != machine_id2task_node->end()); return iter->second; }; return Maybe::Ok(); } namespace { StreamId GetStreamId(const OpNode* op_node, int64_t parallel_id, TaskType task_type) { const ParallelDesc& parallel_desc = op_node->parallel_desc(); int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); int64_t dev_phy_id = 
CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); DeviceId::device_index_t device_index = parallel_desc.device_type() == DeviceType::kCPU ? 0 : static_cast(dev_phy_id); DeviceId device_id{static_cast(machine_id), parallel_desc.device_type(), device_index}; StreamId::stream_index_t stream_index = 0; if (op_node->op().op_conf().has_stream_name_hint()) { const std::string& stream_name_hint = op_node->op().op_conf().stream_name_hint(); VLOG(3) << "set op: " << op_node->op().op_name() << " to stream: " << stream_name_hint; stream_index = Singleton::Get()->GetNamedTaskStreamIndex( device_id, stream_name_hint); } else { stream_index = Singleton::Get()->GetTaskStreamIndex(task_type, device_id); } return StreamId{device_id, stream_index}; } TaskType TaskType4OpNode(const OpNode* op_node) { std::unique_ptr comp_task_node(NewCompTaskNode4OpNode(op_node)); return comp_task_node->GetTaskType(); } } // namespace CompTaskNode* GenCompTaskNode( const OpNode* op_node, int64_t parallel_id, const std::function& GetOrCreateStreamId) { const ParallelDesc& parallel_desc = op_node->parallel_desc(); int64_t parallel_num = parallel_desc.parallel_num(); CompTaskNode* comp_task_node = NewCompTaskNode4OpNode(op_node); int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); comp_task_node->set_machine_id(machine_id); comp_task_node->mut_parallel_ctx()->set_parallel_id(parallel_id); comp_task_node->mut_parallel_ctx()->set_parallel_num(parallel_num); StreamId stream_id = GetOrCreateStreamId(op_node, parallel_id, comp_task_node->GetTaskType()); comp_task_node->set_thrd_id(EncodeStreamIdToInt64(stream_id)); comp_task_node->set_op_node(op_node); return comp_task_node; } void GenSortedCompTaskNodes(const OpNode* op_node, std::vector* sorted_comp_tasks) { int64_t parallel_idx = 0; const ParallelDesc& parallel_desc = op_node->parallel_desc(); for (int64_t machine_id : parallel_desc.sorted_machine_ids()) { for (int64_t dev_phy_id : parallel_desc.sorted_dev_phy_ids(machine_id)) { sorted_comp_tasks->emplace_back(GenCompTaskNode(op_node, parallel_idx++, &GetStreamId)); (void)dev_phy_id; } (void)machine_id; } } bool IsConnectedLbisAllSameNdSbp(const OpEdge* op_edge) { const OpNode* src_node = op_edge->src_node(); const OpNode* dst_node = op_edge->dst_node(); CHECK_GT(op_edge->lbis().size(), 0); HashSet predicators; for (const LogicalBlobId& lbi : op_edge->lbis()) { const NdSbp& src_nd_sbp = src_node->NdSbp4Lbi(lbi); const NdSbp& dst_nd_sbp = dst_node->NdSbp4Lbi(lbi); predicators.insert(src_nd_sbp == dst_nd_sbp); } CHECK_EQ(predicators.size(), 1); return *predicators.begin(); } BldSubTskGphMthd GetMthdForBldSubTskGph(const OpEdge* op_edge) { const OpNode* src_node = op_edge->src_node(); const OpNode* dst_node = op_edge->dst_node(); const ParallelDesc& src_pd = src_node->parallel_desc(); const ParallelDesc& dst_pd = dst_node->parallel_desc(); const OperatorConf& src_op_conf = src_node->op().op_conf(); const OperatorConf& dst_op_conf = dst_node->op().op_conf(); // WaitAndSendIds -> Reentrantlock if (src_op_conf.has_wait_and_send_ids_conf() && dst_op_conf.has_reentrant_lock_conf()) { CHECK_EQ(src_pd.parallel_num(), 1); CHECK_EQ(dst_pd.parallel_num(), 1); return &TaskGraph::BldSubTskGphByBoxing; } // *Tick -> *Tick if (IsTickOpConf(src_op_conf) || IsTickOpConf(dst_op_conf)) { if (src_op_conf.has_source_tick_conf()) { CHECK(dst_op_conf.has_tick_conf()); CHECK_EQ(src_pd.parallel_num(), 1); CHECK_EQ(dst_pd.parallel_num(), 1); return &TaskGraph::BldSubTskGphByBoxing; } else if 
(dst_op_conf.has_sink_tick_conf()) { CHECK(src_op_conf.has_tick_conf() || src_op_conf.has_sink_tick_conf()); CHECK_EQ(src_pd.parallel_num(), 1); CHECK_EQ(dst_pd.parallel_num(), 1); return &TaskGraph::BldSubTskGphByBoxing; } else if (IsSubsetTickOpConf(src_op_conf)) { return &TaskGraph::BldSubTskGphBySrcSubsetConnect; } else if (IsSubsetTickOpConf(dst_op_conf)) { return &TaskGraph::BldSubTskGphByDstSubsetConnect; } else if (IsTickOpConf(src_op_conf) && IsTickOpConf(dst_op_conf)) { if (src_pd.parallel_num() == dst_pd.parallel_num()) { return &TaskGraph::BldSubTskGphByOneToOne; } else { CHECK_EQ(src_pd.parallel_num(), 1); return &TaskGraph::BldSubTskGphByBroadcastToBroadcast; } } } std::shared_ptr src_comp_task(NewCompTaskNode4OpNode(src_node)); std::shared_ptr dst_comp_task(NewCompTaskNode4OpNode(dst_node)); // NOTE(chengcheng): MUST use TaskType instead of OpTypeCase because may // Multi-op corresponding to SAME TaskType such as: // DistributeConcatOpConf and DistributeAddOpConf -> TaskType::kDistributeConcat // DistributeSplitOpConf and DistributeCloneOpConf -> TaskType::kDistributeSplit // * -> DistributeConcat if (dst_comp_task->GetTaskType() == TaskType::kDistributeConcat) { return &TaskGraph::BldSubTskGphByPartialInLbiConnect; } // DistributeSplit -> * if (src_comp_task->GetTaskType() == TaskType::kDistributeSplit) { return &TaskGraph::BldSubTskGphByPartialOutLbiConnect; } // NormalForward -> DecodeH2D if (src_comp_task->GetTaskType() == TaskType::kNormalForward && dst_comp_task->GetTaskType() == TaskType::kDecodeH2D) { return &TaskGraph::BldSubTskGphNormalForwardToDecodeH2D; } if (src_pd.parallel_num() == 1 && dst_pd.parallel_num() == 1) { return &TaskGraph::BldSubTskGphByOneToOne; } // one to one if (src_pd.parallel_num() == dst_pd.parallel_num() && *src_pd.hierarchy() == *dst_pd.hierarchy() && IsConnectedLbisAllSameNdSbp(op_edge)) { return &TaskGraph::BldSubTskGphByOneToOne; } return &TaskGraph::BldSubTskGphByBoxing; } void ForEachOpGraphNecessaryCtrlEdge( const OpGraph* op_graph, const std::function& Handler) { auto IsOpGraphDataReachable = op_graph->CreatePredicatorIsReachable(); op_graph->ForEachNode([&](OpNode* dst) { for (const auto& ctrl_in_op_name : dst->op().op_conf().ctrl_in_op_name()) { const OpNode* src = op_graph->OpNode4OpName(ctrl_in_op_name); CHECK(!IsOpGraphDataReachable(dst, src)); // src has ctrl to dst, but src has no data path to dst. if (!IsOpGraphDataReachable(src, dst)) { CHECK_EQ(dst->parallel_desc().parallel_num(), src->parallel_desc().parallel_num()); const Shape* src_time_shape = CHECK_JUST(src->op().GetOpTimeShape()).get(); const Shape* dst_time_shape = CHECK_JUST(dst->op().GetInputBlobFastestTimeShape()).get(); if (dst_time_shape == nullptr) { dst_time_shape = CHECK_JUST(dst->op().GetOpTimeShape()).get(); } if (src_time_shape->elem_cnt() != dst_time_shape->elem_cnt()) { // NOTE(chengcheng): acc / pack op node can be merged and add ctrl edge. 
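// Illustrative sketch: a self-contained, simplified restatement of the one-to-one vs. boxing
// decision made by GetMthdForBldSubTskGph above (the tick and Distribute* special cases are
// omitted). EdgeView and PickBuildMethod are hypothetical illustration-only types, not OneFlow
// APIs: two placements connect one-to-one only when both sides have the same parallel number,
// the same hierarchy and identical SBP on every connected blob; anything else goes through boxing.
#include <vector>

enum class BuildMethod { kOneToOne, kBoxing };

struct EdgeView {
  int src_parallel_num = 1;
  int dst_parallel_num = 1;
  std::vector<int> src_hierarchy;
  std::vector<int> dst_hierarchy;
  std::vector<bool> lbi_sbp_matches;  // one flag per connected logical blob id
};

inline BuildMethod PickBuildMethod(const EdgeView& e) {
  // trivial placements connect directly
  if (e.src_parallel_num == 1 && e.dst_parallel_num == 1) { return BuildMethod::kOneToOne; }
  bool all_sbp_match = true;
  for (bool m : e.lbi_sbp_matches) { all_sbp_match = all_sbp_match && m; }
  // same degree of parallelism, same hierarchy, identical SBP on every blob:
  // each producer feeds exactly one consumer
  if (e.src_parallel_num == e.dst_parallel_num && e.src_hierarchy == e.dst_hierarchy
      && all_sbp_match) {
    return BuildMethod::kOneToOne;
  }
  return BuildMethod::kBoxing;  // general resharding path
}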
CHECK(src->op().op_conf().has_user_conf()); const std::string& op_type_name = src->op().op_conf().user_conf().op_type_name(); CHECK(op_type_name == "acc" || op_type_name == "pack"); const Shape* src_input_time_shape = CHECK_JUST(src->op().GetInputBlobFastestTimeShape()).get(); CHECK_EQ(src_input_time_shape->elem_cnt(), dst_time_shape->elem_cnt()); } else { CHECK_EQ(src_time_shape->elem_cnt(), dst_time_shape->elem_cnt()); } if (!src->parallel_desc().EqualsIgnoringHierarchy(dst->parallel_desc())) { LOG(WARNING) << " Warning, there is a ctrl edge connected across placement from: " << src->op().op_name() << " [" << src->parallel_desc().parallel_conf().DebugString() << "] to: " << dst->op().op_name() << " [" << dst->parallel_desc().parallel_conf().DebugString() << "]"; } Handler(src, dst); } } }); } void GetHostInputLbis4OpNode(const OpNode* op_node, std::vector* host_mem_input_lbis) { host_mem_input_lbis->clear(); if (op_node->op().op_conf().has_user_conf()) { const auto& user_conf = op_node->op().op_conf().user_conf(); const auto& op_type_name = user_conf.op_type_name(); if (user_op::UserOpHostMemoryInputRegistry::Get().HasHostMemoryInput(op_type_name)) { const auto& inputs = [&]() -> std::vector> { const auto& arg_map = op_node->op().op_conf().user_conf().input(); std::vector> arg_vec; for (auto it = arg_map.begin(); it != arg_map.end(); ++it) { for (int32_t i = 0; i < it->second.s_size(); ++i) { arg_vec.emplace_back(std::make_pair(it->first, i)); } } return arg_vec; }(); for (const auto& pair : inputs) { if (user_op::UserOpHostMemoryInputRegistry::Get().IsHostMemoryInput4Op( op_type_name, pair.first, pair.second)) { const LogicalBlobId& host_input_lbi = GenLogicalBlobId(user_conf.input().at(pair.first).s(pair.second)); host_mem_input_lbis->emplace_back(host_input_lbi); } } } } } HashMap* GlobalDeviceType2CreateSubTskGphBuilderFn() { static HashMap global_device_type_create_sub_tsk_gph_builder_fn; return &global_device_type_create_sub_tsk_gph_builder_fn; } } // namespace TaskGraph::TaskGraph() = default; TaskGraph::~TaskGraph() = default; Maybe RegisterCreateSubTskGphBuilderFn(DeviceType device_type, const CreateSubTskGphBuilderFn& fn) { auto* global_device_type_create_sub_tsk_gph_builder_fn = GlobalDeviceType2CreateSubTskGphBuilderFn(); global_device_type_create_sub_tsk_gph_builder_fn->emplace(device_type, fn); return Maybe::Ok(); } TaskEdge* TaskGraph::NewTaskEdgeWithLbi(const LogicalBlobId& lbi) { TaskEdge* edge = NewEdge(); edge->AddLbi(lbi); return edge; } TaskEdge* TaskGraph::NewTaskEdgeWithLbis(const std::vector& lbis) { TaskEdge* edge = NewEdge(); edge->AddLbis(lbis); return edge; } TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, const MemZoneId& dst_mem_zone_id) { const auto& src_mem_zone_id = src_node->MemZoneId121(); const ProxyKey key(src_node, lbi, dst_mem_zone_id); auto it = proxy2node.find(key); if (it != proxy2node.cend()) { // hit cache return it->second; } else { if (src_mem_zone_id == dst_mem_zone_id) { // in the same memory zone proxy2node[key] = src_node; return src_node; } else if (dst_mem_zone_id.device_type() == DeviceType::kCPU) { if (src_mem_zone_id.rank() == dst_mem_zone_id.rank()) { // on the same node, not on the same device // src must be not on the cpu mem zone, copy d2h first CHECK(IsMemcpyDtoHSupported(src_mem_zone_id.device_type())); CopyHdTaskNode* copy_task = NewNode(); copy_task->Init(CopyHdType::D2H, src_mem_zone_id, lbi); Connect(src_node, NewTaskEdgeWithLbi(lbi), copy_task); proxy2node[key] = copy_task; return copy_task; 
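// Illustrative sketch: GetProxyNode continues below with the cross-rank (CopyCommNet) and
// host-to-device branches. As a reading aid, here is a self-contained sketch of the same
// memoized decision tree that returns the sequence of copy steps instead of task nodes.
// Zone, CopySteps and StepsToMove are hypothetical simplifications for illustration only,
// not OneFlow APIs.
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Zone {
  int rank = 0;
  bool on_cpu = true;
  bool operator==(const Zone& o) const { return rank == o.rank && on_cpu == o.on_cpu; }
  bool operator<(const Zone& o) const { return rank != o.rank ? rank < o.rank : on_cpu < o.on_cpu; }
};
using CopySteps = std::vector<std::string>;

inline CopySteps StepsToMove(const Zone& src, const Zone& dst,
                             std::map<std::pair<Zone, Zone>, CopySteps>* memo) {
  const auto key = std::make_pair(src, dst);
  auto it = memo->find(key);
  if (it != memo->end()) { return it->second; }  // hit cache, reuse the existing proxy
  CopySteps steps;
  if (src == dst) {
    // same memory zone: nothing to do
  } else if (dst.on_cpu && src.rank == dst.rank) {
    steps = {"D2H"};  // same node, device memory to host memory
  } else if (dst.on_cpu) {
    steps = StepsToMove(src, Zone{src.rank, true}, memo);  // reach the src-side host first
    steps.push_back("CommNet");                            // then cross the network
  } else {
    steps = StepsToMove(src, Zone{dst.rank, true}, memo);  // reach the dst-side host first
    steps.push_back("H2D");                                // then host memory to device
  }
  (*memo)[key] = steps;
  return steps;
}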
} else { // not on the same node, need CopyCommNet from src to dst // build src cpu proxy first TaskNode* proxy_on_src_host = GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(src_mem_zone_id.rank())); CopyCommNetTaskNode* copy_comm_net_task = NewNode(); copy_comm_net_task->Init(dst_mem_zone_id.rank(), lbi); Connect(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task); proxy2node[key] = copy_comm_net_task; return copy_comm_net_task; } } else { TaskNode* proxy_on_dst_host = GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(dst_mem_zone_id.rank())); CHECK(IsMemcpyHtoDSupported(dst_mem_zone_id.device_type())); CopyHdTaskNode* copy_task = NewNode(); copy_task->Init(CopyHdType::H2D, dst_mem_zone_id, lbi); Connect(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task); proxy2node[key] = copy_task; return copy_task; } } return nullptr; } TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) { const int64_t dst_machine_id = CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id)); const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id)); DeviceType device_type = dst_parallel_desc.device_type(); auto device_index = (device_type == DeviceType::kCPU ? 0 : static_cast(dev_id)); MemZoneId mem_zone_id{static_cast(dst_machine_id), device_type, device_index}; return GetProxyNode(src_node, lbi, mem_zone_id); } void TaskGraph::ConnectCtrlEdge(CompTaskNode* src_task_node, CompTaskNode* dst_task_node) { std::string regst_desc_name; src_task_node->BuildCtrlRegstDesc(dst_task_node, ®st_desc_name); TaskEdge* edge = NewEdge(); Connect(src_task_node, edge, dst_task_node); src_task_node->BindEdgeWithProducedRegst(edge, regst_desc_name); } void TaskGraph::ConnectCtrlEdges(const std::vector& src_task_nodes, const std::vector& dst_task_nodes) { CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size()); FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) { ConnectCtrlEdge(src_task_nodes.at(i), dst_task_nodes.at(i)); } } void TaskGraph::RemoveEmptyRegsts() { ForEachNode([&](TaskNode* node) { node->EraseUninitializedShapeProducedBlob(); }); ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); }); ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedRegst(); }); ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); }); } void TaskGraph::MergeChainAndAddOrderingCtrlEdgeInSameChain() { if (EnableLogicalChain()) { // Ctrl edges in chain has already been added in logical chain pass, so // there is no need to call BuildCtrlRegstDescInSameChain here. MergeChainByLogicalChainId(); } else { // TODO(chengcheng): erase old chain version in the future. MergeChainByPhysicalTaskGraph(); BuildCtrlRegstDescInSameChain(); } } void TaskGraph::InitOrderedTaskNodes() { // NOTE(chengcheng): Warning, ordered_task_nodes_ by topo is NOT valid in process // parallel compile, because the current rank task graph is Incomplete. TopoForEachNode([&](TaskNode* task_node) { ordered_task_nodes_.emplace_back(task_node); }); } void TaskGraph::MergeChainByPhysicalTaskGraph() { int64_t chain_id = 0; for (auto* this_node : ordered_task_nodes_) { // skip if this node has been set in a chain. 
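// Illustrative sketch: a compact restatement of the BFS performed by
// TraverseConnectedSubGraphMergeInThisChain (defined earlier in this file), which this loop
// relies on. Starting from a seed, it pulls in every data-edge neighbour that is mergeable,
// runs on the same thread and has the seed's time shape. ChainNode and its fields are
// hypothetical simplifications for illustration only, not OneFlow types.
#include <queue>
#include <unordered_set>
#include <vector>

struct ChainNode {
  long chain_id = -1;       // -1 means "not assigned to a chain yet"
  long thrd_id = 0;
  long time_shape_cnt = 1;  // stand-in for the op time shape
  bool mergeable = true;
  std::vector<ChainNode*> data_neighbors;
};

inline void MergeIntoChain(ChainNode* seed, long chain_id) {
  std::queue<ChainNode*> todo;
  std::unordered_set<ChainNode*> visited{seed};
  todo.push(seed);
  while (!todo.empty()) {
    ChainNode* cur = todo.front();
    todo.pop();
    cur->chain_id = chain_id;
    for (ChainNode* next : cur->data_neighbors) {
      if (visited.count(next) || !next->mergeable || next->chain_id != -1
          || next->thrd_id != seed->thrd_id || next->time_shape_cnt != seed->time_shape_cnt) {
        continue;
      }
      visited.insert(next);
      todo.push(next);
    }
  }
}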
if (IsValidChainId(this_node->chain_id())) { continue; } if (CanBeMergedInChain(this_node)) { TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id); } else { this_node->set_chain_id(chain_id); } ++chain_id; } // set order_in_chain by ordered_task_nodes_ HashMap chain_id2order; for (auto* node : ordered_task_nodes_) { CHECK(IsValidChainId(node->chain_id())); int64_t this_chain_id = node->chain_id(); if (chain_id2order.find(this_chain_id) == chain_id2order.end()) { chain_id2order.emplace(this_chain_id, 0); } node->set_order_in_chain(chain_id2order.at(this_chain_id)++); } } void TaskGraph::MergeChainByLogicalChainId() { for (TaskNode* this_node : ordered_task_nodes_) { CompTaskNode* comp_node = dynamic_cast(this_node); if (!comp_node) { continue; } const OperatorConf& conf = comp_node->op()->op_conf(); if (conf.has_logical_chain_id()) { const int64_t logical_chain_id = conf.logical_chain_id(); CHECK(IsValidChainId(logical_chain_id)); this_node->set_chain_id(logical_chain_id); CHECK(conf.has_order_in_logical_chain()); this_node->set_order_in_chain(conf.order_in_logical_chain()); } } } void TaskGraph::BuildCtrlRegstDescInSameChain() { auto GenPhysicalChainId = [](TaskNode* node) { // NOTE(chengcheng): different rank cannot use same chain id for bad ctrl link. return (node->chain_id() << 31) | (node->machine_id()); }; HashMap physical_chain_id2node; // Note that ordered_task_nodes_'s topology order in seperation plan compile is not gerenteed, // So add ctrl edge with ordered_task_nodes_ in seperation plan compile may case dead lock. for (auto* node : ordered_task_nodes_) { if (IsConnectToTickOp(node)) { continue; } // NOTE(chengcheng): skip invalid chain id if (!IsValidChainId(node->chain_id())) { continue; } int64_t physical_chain_id = GenPhysicalChainId(node); auto iter = physical_chain_id2node.find(physical_chain_id); if (iter == physical_chain_id2node.end()) { CHECK(physical_chain_id2node.emplace(physical_chain_id, node).second); } else { TaskNode* src_node = iter->second; TaskNode* dst_node = node; std::string ctrl_regst_name; bool build_ctrl_edge = src_node->BuildCtrlRegstDescIfNeed(dst_node, &ctrl_regst_name); if (build_ctrl_edge) { CHECK(!ctrl_regst_name.empty()); TaskEdge* edge = NewEdge(); Connect(src_node, edge, dst_node); src_node->BindEdgeWithProducedRegst(edge, ctrl_regst_name); } iter->second = dst_node; } } } void TaskGraph::GetInplaceOpBlobArgList( InplaceObasInfo* obas_info, const HashSet& dev_nodes, const std::function& TaskNode4OpName) const { auto AddMutableInplaceArgPair = [&](TaskNode* node, const std::string& ibn, const std::string& obn, const std::string& op_name) { if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) { auto* pair = obas_info->mut_inplace_oba_pairs.mutable_pair()->Add(); *pair->mutable_first() = GenOpBlobArg(op_name, ibn); *pair->mutable_second() = GenOpBlobArg(op_name, obn); } }; auto AddConstInplaceArgPair = [&](TaskNode* node, const std::string& ibn, const std::string& obn, const std::string& op_name) { if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) { auto* pair = obas_info->con_inplace_oba_pairs.mutable_pair()->Add(); *pair->mutable_first() = GenOpBlobArg(op_name, ibn); *pair->mutable_second() = GenOpBlobArg(op_name, obn); } }; for (TaskNode* task_node : dev_nodes) { if (task_node->exec_gph().node_num() != 1) { continue; } const auto& op = *task_node->exec_gph().SoleNode()->op(); for (const std::string& ibn : op.input_bns()) { if (op.InputBlobModifier4Ibn(ibn).is_mutable()) { CHECK(IsInplaceAllowed(task_node, {ibn}, 
TaskNode4OpName)); *obas_info->mut_in_obas.mutable_oba()->Add() = GenOpBlobArg(op.op_name(), ibn); } } for (const auto& pair : task_node->exec_gph().SoleNode()->mut_inplace_obn2ibn()) { AddMutableInplaceArgPair(task_node, pair.second, pair.first, op.op_name()); } for (const auto& pair : task_node->exec_gph().SoleNode()->con_inplace_obn2ibn()) { AddConstInplaceArgPair(task_node, pair.second, pair.first, op.op_name()); } } } void TaskGraph::GetSafeInplaceOpBlobArgList( InplaceObasInfo* safe_obas_info, const HashSet& dev_nodes, const std::function& IsOpNameDataOrCtrlReachable) const { auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes); InplaceObasInfo obas_info; GetInplaceOpBlobArgList(&obas_info, dev_nodes, TaskNode4SoleOpName); auto Op4OpName = [&](const std::string& op_name) -> const Operator* { return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get(); }; auto IsLbiAllConsumersReachable = MakePredicatorIsLbiAllConsumersReachable(TaskNode4SoleOpName, IsOpNameDataOrCtrlReachable); InplaceLbiGraph origin_graph(obas_info, Op4OpName); InplaceLbiGraph safe_graph(*safe_obas_info, Op4OpName); origin_graph.ComputeSafeInplaceObns(safe_obas_info, IsLbiAllConsumersReachable); if (Singleton::Get()->enable_debug_mode()) { origin_graph.ToDotWithFilePath( JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_origin.dot")); safe_graph.ToDotWithFilePath( JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_safe.dot")); } } void TaskGraph::SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info, const HashSet& dev_nodes) const { auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes); auto Op4OpName = [&](const std::string& op_name) -> const Operator* { return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get(); }; InplaceLbiGraph inplace_gph(obas_info, Op4OpName); inplace_gph.ForEachConnectedComponent([&](const HashSet& inplace_nodes) { for (const auto* inplace_node : inplace_nodes) { if (inplace_node->in_edges().empty()) { continue; } const auto* inplace_edge = inplace_node->SoleInEdge(); auto* exec_node = TaskNode4SoleOpName(inplace_edge->op().op_name())->exec_gph().SoleNode(); RegstDesc* in_regst = exec_node->RegstDesc4BnInOp(inplace_edge->ibn()); RegstDesc* out_regst = exec_node->RegstDesc4BnInOp(inplace_edge->obn()); out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id()); } }); } void TaskGraph::ForEachGpuDeviceNodes( const std::function& dev_nodes)>& Handler) const { HashMap, HashSet> global_dev_phy_id2nodes; ForEachNode([&](TaskNode* task_node) { if (task_node->device_type() == DeviceType::kCPU) { return; } int64_t dev_phy_id = task_node->stream_id().device_id().device_index(); global_dev_phy_id2nodes[{task_node->machine_id(), dev_phy_id}].emplace(task_node); }); for (const auto& pair : global_dev_phy_id2nodes) { Handler(pair.second); } } void TaskGraph::EnableInplaceMemSharing( const std::function& IsOpNameDataOrCtrlReachable) { ForEachGpuDeviceNodes([&](const HashSet& dev_nodes) { EnableInplaceMemSharing(dev_nodes, IsOpNameDataOrCtrlReachable); }); } void TaskGraph::EnableInplaceMemSharing( const HashSet& dev_nodes, const std::function& IsOpNameDataOrCtrlReachable) { InplaceObasInfo safe_inplace_obas_info; GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable); SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes); } void TaskGraph::DecideExecutionOrder() { // For one machine with no transfer available, the straighten algorithm for overlaps consume a lot 
// of memory StraightenAlgorithmTag straighten_algorithm_tag = GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph(); if (straighten_algorithm_tag == StraightenAlgorithmTag::kDisableStraighten || (straighten_algorithm_tag == StraightenAlgorithmTag::kOverlap4Transfer && GlobalProcessCtx::WorldSize() == 1)) { InitOrderedTaskNodes(); } else { StraightenNodes(this, &ordered_task_nodes_, Singleton::Get()->nccl_use_compute_stream()); } } #define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \ void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS() DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { const OpNode* src_op_node = op_edge->src_node(); const OpNode* dst_op_node = op_edge->dst_node(); std::vector host_mem_input_lbis; GetHostInputLbis4OpNode(dst_op_node, &host_mem_input_lbis); for (const LogicalBlobId& lbi : op_edge->lbis()) { std::vector in_nodes(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end()); std::vector out_nodes; out_nodes.reserve(sorted_dst_comp_tasks.size()); std::vector> sorted_ctrl_tasks; const NdSbp& src_nd_sbp = src_op_node->NdSbp4Lbi(lbi); const NdSbp& dst_nd_sbp = dst_op_node->NdSbp4Lbi(lbi); const ParallelDesc& src_parallel_desc = src_op_node->parallel_desc(); const ParallelDesc& dst_parallel_desc = [&]() { if (std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi) != host_mem_input_lbis.end()) { return *CHECK_JUST( ReplaceDeviceType(SymbolOf(dst_op_node->parallel_desc()), DeviceType::kCPU)); } else { return dst_op_node->parallel_desc(); } }(); const BlobDesc& blob_desc = src_op_node->LogicalBlobDesc4Lbi(lbi); VLOG(3) << "src op: " << src_op_node->op().op_name() << " dst op: " << dst_op_node->op().op_name() << " src_parallel_conf: " << src_parallel_desc.parallel_conf().DebugString() << " dst parallel conf: " << dst_parallel_desc.parallel_conf().DebugString() << " src_nd_sbp " << src_nd_sbp.DebugString() << " dst nd_sbp " << dst_nd_sbp.DebugString(); std::shared_ptr status; const DeviceType device_type = [&src_parallel_desc, &dst_parallel_desc]() { return src_parallel_desc.device_type() != DeviceType::kCPU ? 
src_parallel_desc.device_type() : dst_parallel_desc.device_type(); }(); if (device_type != DeviceType::kCPU && device_type2sub_tsk_gph_builder_.find(device_type) != device_type2sub_tsk_gph_builder_.end()) { auto maybe_status = // NOLINT device_type2sub_tsk_gph_builder_ // NOLINT .at(device_type) // NOLINT ->Build(sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, // NOLINT &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, // NOLINT blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())); // NOLINT if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); } } if (!status) { status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build( sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp, *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get()))); } boxing_logger_->Log(*status, src_op_node->op().op_name(), dst_op_node->op().op_name(), src_parallel_desc, dst_parallel_desc, src_nd_sbp, dst_nd_sbp, lbi, blob_desc); CHECK_EQ(out_nodes.size(), sorted_dst_comp_tasks.size()); FOR_RANGE(size_t, i, 0, out_nodes.size()) { ConnectWithLbi(out_nodes.at(i), sorted_dst_comp_tasks.at(i), lbi); } if (!sorted_ctrl_tasks.empty()) { CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size()); FOR_RANGE(size_t, i, 0, sorted_dst_comp_tasks.size()) { for (TaskNode* ctrl_node : sorted_ctrl_tasks.at(i)) { std::string regst_desc_name; ctrl_node->BuildCtrlRegstDesc(sorted_dst_comp_tasks.at(i), ®st_desc_name); TaskEdge* edge = NewEdge(); Connect(ctrl_node, edge, sorted_dst_comp_tasks.at(i)); ctrl_node->BindEdgeWithProducedRegst(edge, regst_desc_name); } } } } } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) { std::vector host_mem_input_lbis; GetHostInputLbis4OpNode(op_edge->dst_node(), &host_mem_input_lbis); CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size()); FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) { for (const LogicalBlobId& lbi : op_edge->lbis()) { bool is_host_mem_input = std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi) != host_mem_input_lbis.end(); BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(i), lbi, is_host_mem_input); } } } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) { std::vector host_mem_input_lbis; GetHostInputLbis4OpNode(op_edge->dst_node(), &host_mem_input_lbis); for (CompTaskNode* dst_node : sorted_dst_comp_tasks) { CompTaskNode* nearest_src_node = SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node); CHECK_NOTNULL(nearest_src_node); for (const LogicalBlobId& lbi : op_edge->lbis()) { bool is_host_mem_input = std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi) != host_mem_input_lbis.end(); BuildTaskPath(nearest_src_node, dst_node, lbi, is_host_mem_input); } } } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) { const Operator& src_op = op_edge->src_node()->op(); const Operator& dst_op = op_edge->dst_node()->op(); HashSet lbis; std::vector host_mem_input_lbis; GetHostInputLbis4OpNode(op_edge->dst_node(), &host_mem_input_lbis); for (const auto& obn : src_op.output_bns()) { lbis.insert(src_op.BnInOp2Lbi(obn)); } CHECK_EQ(sorted_src_comp_tasks.size(), 1); CHECK_EQ(dst_op.input_bns().size(), sorted_dst_comp_tasks.size()); FOR_RANGE(int, i, 0, sorted_dst_comp_tasks.size()) { const auto& lbi = dst_op.BnInOp2Lbi(dst_op.input_bns().Get(i)); if (lbis.find(lbi) != lbis.end()) { bool 
is_host_mem_input = std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi) != host_mem_input_lbis.end(); BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(i), lbi, is_host_mem_input); } } } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) { const Operator& src_op = op_edge->src_node()->op(); const Operator& dst_op = op_edge->dst_node()->op(); HashSet lbis; std::vector host_mem_input_lbis; GetHostInputLbis4OpNode(op_edge->dst_node(), &host_mem_input_lbis); for (const auto& ibn : dst_op.input_bns()) { lbis.insert(dst_op.BnInOp2Lbi(ibn)); } CHECK_EQ(sorted_dst_comp_tasks.size(), 1); CHECK_EQ(src_op.output_bns().size(), sorted_src_comp_tasks.size()); FOR_RANGE(int, i, 0, sorted_src_comp_tasks.size()) { const auto& lbi = src_op.BnInOp2Lbi(src_op.output_bns().Get(i)); if (lbis.find(lbi) != lbis.end()) { bool is_host_mem_input = std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi) != host_mem_input_lbis.end(); BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(0), lbi, is_host_mem_input); } } } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect) { std::function(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId; CHECK_JUST( MakeGetterTaskNode4MachineId7ThrdId(sorted_src_comp_tasks, &TaskNode4MachineId7ThrdId)); for (CompTaskNode* dst_task_node : sorted_dst_comp_tasks) { CompTaskNode* src_task_node = CHECK_JUST( TaskNode4MachineId7ThrdId(dst_task_node->machine_id(), dst_task_node->thrd_id())); Connect(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node); } } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByDstSubsetConnect) { std::function(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId; CHECK_JUST( MakeGetterTaskNode4MachineId7ThrdId(sorted_dst_comp_tasks, &TaskNode4MachineId7ThrdId)); for (CompTaskNode* src_task_node : sorted_src_comp_tasks) { CompTaskNode* dst_task_node = CHECK_JUST( TaskNode4MachineId7ThrdId(src_task_node->machine_id(), src_task_node->thrd_id())); Connect(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node); } } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) { CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size()); FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) { CompTaskNode* src = sorted_src_comp_tasks.at(i); CompTaskNode* dst = sorted_dst_comp_tasks.at(i); for (const LogicalBlobId& lbi : op_edge->lbis()) { ConnectWithLbi(src, dst, lbi); } } } void TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) { if (src_node == dst_node) { return; } for (TaskEdge* out_edge : src_node->out_edges()) { TaskNode* out_node = out_edge->dst_node(); if (out_node == dst_node) { out_edge->AddLbi(lbi); return; } } TaskEdge* connected_edge = NewEdge(); connected_edge->AddLbi(lbi); Connect(src_node, connected_edge, dst_node); } void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi, bool is_host_mem_input) { const MemZoneId dst_mem_zone_id = [&]() { if (is_host_mem_input) { MemZoneId mem_zone_id = dst_node->MemZoneId121(); return MemZoneId(mem_zone_id.rank(), DeviceType::kCPU, 0); } else { return dst_node->MemZoneId121(); } }(); TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_mem_zone_id); ConnectWithLbi(proxy_node, dst_node, lbi); } Maybe GlobalTaskGraph::Init() { OpGraph* op_graph = Singleton::Get(); sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); boxing_logger_ = CreateBoxingLogger(); // Register the 
corresponding task graph builder based on the device type and store them to map const auto* global_device_type_create_sub_tsk_gph_builder_fn = GlobalDeviceType2CreateSubTskGphBuilderFn(); for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) { device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second()); } hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); HashMap> op_node2sorted_comp_tasks; op_graph->ForEachNode([&](const OpNode* op_node) { std::vector* sorted_comp_tasks = &(op_node2sorted_comp_tasks[op_node]); GenSortedCompTaskNodes(op_node, sorted_comp_tasks); for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); } }); op_graph->ForEachEdge([&](const OpEdge* op_edge) { BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge); (this->*method)(op_edge, op_node2sorted_comp_tasks.at(op_edge->src_node()), op_node2sorted_comp_tasks.at(op_edge->dst_node())); }); ForEachOpGraphNecessaryCtrlEdge(op_graph, [&](const OpNode* src, const OpNode* dst) { const auto& src_task_nodes = op_node2sorted_comp_tasks.at(src); const auto& dst_task_nodes = op_node2sorted_comp_tasks.at(dst); if (src->op().op_conf().has_src_subset_tick_conf()) { UNIMPLEMENTED(); } else if (dst->op().op_conf().has_dst_subset_tick_conf()) { UNIMPLEMENTED(); } else { ConnectCtrlEdges(src_task_nodes, dst_task_nodes); } }); if (Singleton::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); } return Maybe::Ok(); } Maybe BoxingTaskGraph::Init( const std::function&)>& ParallelRunLoop) { OpGraph* op_graph = Singleton::Get(); sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); boxing_logger_ = CreateBoxingLogger(); // Register the corresponding task graph builder based on the device type and store them to map const auto* global_device_type_create_sub_tsk_gph_builder_fn = GlobalDeviceType2CreateSubTskGphBuilderFn(); for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) { device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second()); } hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); const auto& TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) { if (boxing_related_op_node2sorted_comp_tasks_.count(op_node) > 0) { return; } std::vector* sorted_comp_tasks = &(boxing_related_op_node2sorted_comp_tasks_[op_node]); GenSortedCompTaskNodes(op_node, sorted_comp_tasks); for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); } }; op_graph->ForEachEdge([&](const OpEdge* op_edge) { if (!op_edge->NeedBoxing()) { return; } TryCreateSortedCompTaskNodes(op_edge->src_node()); TryCreateSortedCompTaskNodes(op_edge->dst_node()); BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge); (this->*method)(op_edge, boxing_related_op_node2sorted_comp_tasks_.at(op_edge->src_node()), boxing_related_op_node2sorted_comp_tasks_.at(op_edge->dst_node())); }); ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, std::placeholders::_1)); CreateOpNode2TaskIds(ParallelRunLoop); return Maybe::Ok(); } void BoxingTaskGraph::CreateOpNode2TaskIds( const std::function&)>& ParallelRunLoop) { const OpGraph* op_graph = Singleton::Get(); std::vector op_nodes; op_nodes.reserve(op_graph->node_num()); op_graph->ForEachNode([&](OpNode* op_node) { if (boxing_related_op_node2sorted_comp_tasks_.count(op_node) == 0) { op_nodes.push_back(op_node); boxing_unrelated_op_node2sorted_task_ids_[op_node].reserve( op_node->parallel_desc().parallel_num()); } }); ParallelRunLoop(op_nodes.size(), 
[&](size_t i) { const OpNode* op_node = op_nodes.at(i); TaskType task_type = TaskType4OpNode(op_node); const auto& parallel_desc = op_node->parallel_desc(); auto* task_ids = &boxing_unrelated_op_node2sorted_task_ids_[op_node]; for (int parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { const auto& stream_id = GetStreamId(op_node, parallel_id, task_type); task_ids->push_back(Singleton::Get()->GetTaskIdGenerator()->Generate(stream_id)); } }); } namespace { bool IsComputTaskNodeDutyRank(int64_t current_rank, const ParallelDesc& parallel_desc, int64_t task_node_rank) { if (current_rank == 0) { // make sure master knows at least one op_node. return CHECK_JUST(parallel_desc.MachineId4ParallelId(0)) == task_node_rank; } else if (parallel_desc.HasMachineId(current_rank)) { // workers only care their own rank. return current_rank == task_node_rank; } else { return false; } } // A template function to process task node for different task node type. // RetT, function return type // HandleTansportTaskNode, if the task node is a transport task node, call this processing function // HandleComputeTaskNode, if the task node is a compute task node, call this processing // task_node, the input task node template RetT TaskNodeVisitor(TaskNode* task_node, const HandleTansportTaskNodeT& HandleTansportTaskNode, const HandleComputeTaskNodeT& HandleComputeTaskNode) { auto* transport_task_node = dynamic_cast(task_node); if (transport_task_node != nullptr) { return HandleTansportTaskNode(transport_task_node); } else { auto* comp_task_node = dynamic_cast(task_node); if (comp_task_node != nullptr) { return HandleComputeTaskNode(comp_task_node); } else { UNIMPLEMENTED(); } } } } // namespace /*static*/ bool BoxingTaskGraph::SelectTaskNodeByRank(TaskNode* task_node, int64_t rank) { return TaskNodeVisitor( task_node, [&](TransportTaskNode* task_node) { return task_node->machine_id() == rank; }, [&](CompTaskNode* task_node) { const auto& machine_id = task_node->machine_id(); return IsComputTaskNodeDutyRank(rank, task_node->op_node()->parallel_desc(), machine_id); }); } void BoxingTaskGraph::ToProto(const std::function& Pick, BoxingTaskGraphProto* proto) const { const auto sources = [&]() -> std::list { HashSet sources; ForEachNode([&](TaskNode* task_node) { if (Pick(task_node)) { sources.insert(task_node); } }); HashSet sources_out; for (auto* source : sources) { // The consumed task_ids must be generated from out_nodes. 
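// Illustrative sketch: the TaskNodeVisitor helper above dispatches on the dynamic type of a
// task node (transport vs. compute). The self-contained example below shows the same pattern
// with hypothetical illustration-only types; TaskBase, TransportTask and ComputeTask are not
// OneFlow classes.
#include <cstdlib>

struct TaskBase { virtual ~TaskBase() = default; };
struct TransportTask : TaskBase {};
struct ComputeTask : TaskBase {};

template<typename RetT, typename OnTransportT, typename OnComputeT>
RetT VisitTask(TaskBase* task, const OnTransportT& OnTransport, const OnComputeT& OnCompute) {
  if (auto* t = dynamic_cast<TransportTask*>(task)) { return OnTransport(t); }
  if (auto* c = dynamic_cast<ComputeTask*>(task)) { return OnCompute(c); }
  std::abort();  // neither kind: mirrors the UNIMPLEMENTED() branch above
}

// usage:
//   bool keep = VisitTask<bool>(node,
//       [](TransportTask*) { return true;  },
//       [](ComputeTask*)   { return false; });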
source->ForEachNodeOnOutEdge([&](TaskNode* out_node) { if (!sources.count(out_node)) { sources_out.insert(out_node); } }); } sources.insert(sources_out.begin(), sources_out.end()); return std::list{sources.begin(), sources.end()}; }(); const auto& TransportTaskNodeToProto = [&](TransportTaskNode* task_node) { task_node->ToTransportTaskProtoIf(proto->mutable_transport_task()->Add()); }; const auto& ComputeTaskNodeToProto = [&](CompTaskNode* task_node) { auto* map = proto->mutable_boxing_related_op_name2compute_tasks(); const auto& op_name = task_node->op_node()->op().op_name(); auto* parallel_id2task_proto = (*map)[op_name].mutable_parallel_id2task(); int64_t parallel_id = task_node->parallel_id(); task_node->ToProto(&(*parallel_id2task_proto)[parallel_id], /*check=*/false); }; HashSet rank_task_nodes; BfsForEachNode(sources, &TaskNode::ForEachNodeOnInEdge, [&](TaskNode* task_node) { rank_task_nodes.insert(task_node); TaskNodeVisitor(task_node, TransportTaskNodeToProto, ComputeTaskNodeToProto); }); const auto rank_task_edges = [&] { HashSet rank_task_edges; const auto& TryInsertEdge = [&](TaskEdge* edge) { if (rank_task_nodes.count(edge->src_node()) > 0 && rank_task_nodes.count(edge->dst_node()) > 0) { rank_task_edges.insert(edge); } }; for (const auto* task_node : rank_task_nodes) { for (auto* in_edge : task_node->in_edges()) { TryInsertEdge(in_edge); } for (auto* out_edge : task_node->out_edges()) { TryInsertEdge(out_edge); } } return rank_task_edges; }(); for (auto* edge : rank_task_edges) { edge->ToProto(proto->mutable_task_edge()->Add()); } for (const auto& pair : boxing_unrelated_op_node2sorted_task_ids_) { const auto& op_name = pair.first->op().op_name(); auto* vec = &(*proto->mutable_boxing_unrelated_op_name2task_ids())[op_name]; for (const auto& task_id : pair.second) { vec->add_task_id(EncodeTaskIdToInt64(task_id)); } } } RankTaskGraph::RankTaskGraph(const std::shared_ptr& boxing_task_graph_proto, int64_t current_rank) : boxing_task_graph_proto_(boxing_task_graph_proto), current_rank_(current_rank), task_graph_rebuild_ctx_(std::make_unique()) {} Maybe RankTaskGraph::TryGetBoxingRelatedComTaskNode(const OpNode* op_node, int64_t parallel_id) { const auto& op_name = op_node->op().op_name(); auto iter = boxing_task_graph_proto_->boxing_related_op_name2compute_tasks().find(op_name); if (iter == boxing_task_graph_proto_->boxing_related_op_name2compute_tasks().end()) { return nullptr; } if (iter == boxing_task_graph_proto_->boxing_related_op_name2compute_tasks().end()) { return nullptr; } auto task_iter = iter->second.parallel_id2task().find(parallel_id); if (task_iter == iter->second.parallel_id2task().end()) { return nullptr; } int64_t task_id = task_iter->second.task_id(); auto* task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_id)); auto* comp_task_node = dynamic_cast(task_node); CHECK_NOTNULL_OR_RETURN(comp_task_node) << "invalid task_type. 
task_id: " << task_id; return comp_task_node; } Maybe RankTaskGraph::CreateOrFindRankCompTaskNodeByParallelId(const OpNode* op_node, int64_t parallel_id) { auto* comp_task_node = JUST(TryGetBoxingRelatedComTaskNode(op_node, parallel_id)); if (comp_task_node != nullptr) { return comp_task_node; } auto iter = op_node2comp_task_node_.find(op_node); if (iter != op_node2comp_task_node_.end()) { return iter->second; } const TaskId task_id = *JUST([&]() -> Maybe { const auto& map = boxing_task_graph_proto_->boxing_unrelated_op_name2task_ids(); const auto& iter = map.find(op_node->op().op_name()); CHECK_OR_RETURN(iter != map.end()); CHECK_LT_OR_RETURN(parallel_id, iter->second.task_id_size()); return DecodeTaskIdFromInt64(iter->second.task_id().Get(parallel_id)); }()); const auto& GetStreamIdFromMaster = [&](const OpNode* op_node, int64_t parallel_id, TaskType) { return task_id.stream_id(); }; auto comp_task_node_ptr = GenCompTaskNode(op_node, parallel_id, GetStreamIdFromMaster); comp_task_node_ptr->update_new_task_id(task_id); AddAllocatedNode(comp_task_node_ptr); CHECK_OR_RETURN(op_node2comp_task_node_.emplace(op_node, comp_task_node_ptr).second) << "Got dupliacted op_node " << op_node->op().op_name(); return comp_task_node_ptr; } Maybe RankTaskGraph::CreateOrFindRankCompTaskNodeByRank(const OpNode* op_node, int64_t rank) { CHECK_OR_RETURN(op_node->parallel_desc().HasMachineId(rank)) << "rank is not contained in the placment"; int64_t parallel_id = -1; CHECK_OR_RETURN(JUST(op_node->parallel_desc().TryGetParallelId(rank, ¶llel_id))) << "parallel_id not found."; return CreateOrFindRankCompTaskNodeByParallelId(op_node, parallel_id); } Maybe RankTaskGraph::TryGetRankCompTaskNode(const OpNode* op_node, int64_t rank) { if (!op_node->parallel_desc().HasMachineId(rank)) { return nullptr; } int64_t parallel_id = -1; CHECK_OR_RETURN(JUST(op_node->parallel_desc().TryGetParallelId(rank, ¶llel_id))) << "parallel_id not found."; auto* comp_task_node = JUST(TryGetBoxingRelatedComTaskNode(op_node, parallel_id)); if (comp_task_node != nullptr) { return comp_task_node; } auto iter = op_node2comp_task_node_.find(op_node); CHECK_OR_RETURN(iter != op_node2comp_task_node_.end()) << "op_node " << op_node->op().op_name() << " not found."; return iter->second; } Maybe RankTaskGraph::AddBoxingReletedCompTaskNodesFromProto() { OpGraph* op_graph = Singleton::Get(); for (const auto& pair : boxing_task_graph_proto_->boxing_related_op_name2compute_tasks()) { const OpNode* op_node = op_graph->OpNode4OpName(pair.first); for (const auto& pair : pair.second.parallel_id2task()) { const auto& task_proto = pair.second; CHECK_OR_RETURN(task_id2task_proto_.emplace(task_proto.task_id(), &task_proto).second) << "redundant task_id."; CompTaskNode* comp_task_node = NewCompTaskNode4OpNode(op_node); comp_task_node->set_op_node(op_node); AddAllocatedNode(comp_task_node); // Note here has no consume regst // Init task node and produce regst comp_task_node->InitFromProtoExceptConsumedRegsts(task_proto); JUST(task_graph_rebuild_ctx_->AddTaskNode(comp_task_node)); } } return Maybe::Ok(); } Maybe RankTaskGraph::CreateAndPartiallyInitTransportTaskNodesFromProto() { for (const auto& transport_task_proto : boxing_task_graph_proto_->transport_task()) { const auto& task_proto = transport_task_proto.task_proto(); CHECK_OR_RETURN(task_id2task_proto_.emplace(task_proto.task_id(), &task_proto).second) << "redundant task_id."; auto* task_node = JUST(CreateTransportTask::Visit(task_proto.task_type())); AddAllocatedNode(task_node); // Init task node and 
produce regst task_node->InitFromProtoExceptConsumedRegsts(transport_task_proto.task_proto()); JUST(task_graph_rebuild_ctx_->AddTaskNode(task_node)); } return Maybe::Ok(); } Maybe RankTaskGraph::AddTransportTaskEdgesFromProto() { for (const auto& task_edge_proto : boxing_task_graph_proto_->task_edge()) { TaskEdge* edge = NewEdge(); auto* src_task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_edge_proto.src_task_id())); auto* dst_task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_edge_proto.dst_task_id())); Connect(src_task_node, edge, dst_task_node); JUST(edge->InitFromProto(task_edge_proto, *task_graph_rebuild_ctx_)); JUST(task_graph_rebuild_ctx_->AddTaskEdge(edge, task_edge_proto.task_edge_uid())); } return Maybe::Ok(); } Maybe RankTaskGraph::InitTransportTaskNodesFromProto() { for (const auto& transport_task_proto : boxing_task_graph_proto_->transport_task()) { int64_t task_id = transport_task_proto.task_proto().task_id(); auto* task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_id)); auto* transport_task_node = dynamic_cast(task_node); CHECK_NOTNULL_OR_RETURN(transport_task_node) << "task node is not a TransportTaskNode. task_id" << task_id; JUST(transport_task_node->InitTransportTaskFromProtoIf(transport_task_proto, *task_graph_rebuild_ctx_)); } return Maybe::Ok(); } bool RankTaskGraph::ContainRank(const OpNode* op_node, int64_t rank) const { return op_node->parallel_desc().HasMachineId(rank); } Maybe RankTaskGraph::ConnectDataEdges(const OpEdge* op_edge, int64_t rank) { if (!op_edge->NeedBoxing()) { auto* src_task_node = JUST(TryGetRankCompTaskNode(op_edge->src_node(), rank)); auto* dst_task_node = JUST(TryGetRankCompTaskNode(op_edge->dst_node(), rank)); if (ContainRank(op_edge->src_node(), rank)) { CHECK_NOTNULL_OR_RETURN(src_task_node) << "src_task_node should not be nullptr. op_name: " << op_edge->src_node()->op().op_name(); } if (ContainRank(op_edge->dst_node(), rank)) { CHECK_NOTNULL_OR_RETURN(dst_task_node) << "dst_task_node should not be nullptr. op_name: " << op_edge->dst_node()->op().op_name(); } if (src_task_node != nullptr && dst_task_node != nullptr) { for (const auto& lbi : op_edge->lbis()) { ConnectWithLbi(src_task_node, dst_task_node, lbi); } } } return Maybe::Ok(); } Maybe RankTaskGraph::ConnectCtrlEdges(const OpNode* src, const OpNode* dst, int64_t rank) { if ((ContainRank(src, rank) && ContainRank(dst, rank))) { auto* src_task_node = CHECK_JUST(TryGetRankCompTaskNode(src, rank)); auto* dst_task_node = CHECK_JUST(TryGetRankCompTaskNode(dst, rank)); if (src->op().op_conf().has_src_subset_tick_conf()) { UNIMPLEMENTED_THEN_RETURN() << "ctrl edge from src_subset_tick is not supported."; } else if (dst->op().op_conf().has_dst_subset_tick_conf()) { UNIMPLEMENTED_THEN_RETURN() << "ctrl edge to dst_subset_tick is not supported."; } else { ConnectCtrlEdge(CHECK_NOTNULL(src_task_node), CHECK_NOTNULL(dst_task_node)); } } return Maybe::Ok(); } bool RankTaskGraph::IsDutyRank(const ParallelDesc& parallel_desc, int64_t rank) const { return IsComputTaskNodeDutyRank(current_rank_, parallel_desc, rank); } template Maybe RankTaskGraph::DoRankDuty(const ParallelDesc& parallel_desc, const DoEachRankT& DoWithRank) { if (current_rank_ == 0) { // make sure master knows at least one op_node. JUST(DoWithRank(JUST(parallel_desc.MachineId4ParallelId(0)))); } else if (parallel_desc.HasMachineId(current_rank_)) { // workers only care their own rank. JUST(DoWithRank(current_rank_)); } else { // Do nothing. 
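// Illustrative sketch: a minimal restatement of the rank-duty rule implemented by DoRankDuty
// and IsComputTaskNodeDutyRank above. The master (rank 0) always covers the placement's first
// rank so it knows at least one task per op, every other rank only handles itself, and ranks
// outside the placement do nothing. placement_ranks stands in for the ParallelDesc machine ids
// (assumed sorted ascending) and is a hypothetical simplification, not an OneFlow API.
#include <set>

inline bool IsDutyRankSketch(long current_rank, const std::set<long>& placement_ranks,
                             long task_rank) {
  if (placement_ranks.empty()) { return false; }
  if (current_rank == 0) {
    return task_rank == *placement_ranks.begin();  // master mirrors parallel_id 0
  }
  if (placement_ranks.count(current_rank) > 0) {
    return task_rank == current_rank;  // workers only care about their own rank
  }
  return false;  // not in the placement: no duty
}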
} return Maybe::Ok(); } Maybe RankTaskGraph::InitRegstDescsConsumers() { const auto& RegstDesc4Id = [&](int64_t regst_desc_id) -> Maybe { return JUST(task_graph_rebuild_ctx_->RegstDesc4Id(regst_desc_id)); }; JUST(MaybeForEachNode([&](TaskNode* task_node) -> Maybe { const auto& task_proto = *JUST(MapAt(task_id2task_proto_, task_node->task_id())); JUST(task_node->InitConsumedRegstsFromProto(task_proto, RegstDesc4Id)); return Maybe::Ok(); })); return Maybe::Ok(); } Maybe RankTaskGraph::Init(const HashSet& var_op_names) { JUST(AddBoxingReletedCompTaskNodesFromProto()); JUST(CreateAndPartiallyInitTransportTaskNodesFromProto()); JUST(AddTransportTaskEdgesFromProto()); JUST(InitTransportTaskNodesFromProto()); JUST(InitRegstDescsConsumers()); // Note that tasks currently added in above code are from BoxingTaskGraph, so they are all // boxing related. OpGraph* op_graph = Singleton::Get(); JUST(op_graph->MaybeForEachNode([&](OpNode* op_node) -> Maybe { JUST(DoRankDuty(op_node->parallel_desc(), [&](int64_t rank) -> Maybe { JUST(CreateOrFindRankCompTaskNodeByRank(op_node, rank)); return Maybe::Ok(); })); if (var_op_names.count(op_node->op().op_name()) > 0 && !IsDutyRank(op_node->parallel_desc(), current_rank_)) { // To makes sure all ranks know all var_op_names, at least one task for variable op is // needed in the plan. JUST(CreateOrFindRankCompTaskNodeByParallelId(op_node, /*parallel_id=*/0)); } return Maybe::Ok(); })); JUST(op_graph->MaybeForEachEdge([&](const OpEdge* op_edge) -> Maybe { return DoRankDuty(op_edge->src_node()->parallel_desc(), [&](int64_t rank) { return ConnectDataEdges(op_edge, rank); }); })); ForEachOpGraphNecessaryCtrlEdge(op_graph, [&](const OpNode* src, const OpNode* dst) { if (!src->parallel_desc_sym()->EqualsIgnoringHierarchy(*dst->parallel_desc_sym())) { LOG(INFO) << " src " << src->parallel_desc_sym()->data().DebugString() << " dst " << dst->parallel_desc_sym()->data().DebugString(); return; } CHECK_JUST(DoRankDuty(src->parallel_desc(), [&](int64_t rank) { return ConnectCtrlEdges(src, dst, rank); })); }); if (Singleton::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); } ForEachNode([&](TaskNode* task_node) { task_node->ProduceAllRegstsAndBindEdges(); }); ForEachEdge([&](TaskEdge* edge) { CHECK(edge->HasRegst()) << "Found edge which has not bound a regst, src task " << edge->src_node()->VisualStr(); }); return Maybe::Ok(); } RankTaskGraph::~RankTaskGraph() {} } // namespace oneflow ================================================ FILE: oneflow/core/graph/task_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_ #include "oneflow/core/graph/task_node.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/copy_task_node.h" #include "oneflow/core/register/op_blob_arg_info.h" #include "oneflow/core/graph/boxing/boxing_logger.h" #include "oneflow/core/memory/memory_zone.h" namespace oneflow { class SubTskGphBuilderCtx; class HierarchicalSubTskGphBuilder; #define BLD_SUB_TSK_GPH_MTHD_ARGS() \ (const OpEdge* op_edge, const std::vector& sorted_src_comp_tasks, \ const std::vector& sorted_dst_comp_tasks) class TaskGraph; using BldSubTskGphMthd = void(TaskGraph::*) BLD_SUB_TSK_GPH_MTHD_ARGS(); class TaskGraph : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(TaskGraph); virtual ~TaskGraph() override; const char* TypeName() const override { return "TaskGraph"; } void RemoveEmptyRegsts(); void MergeChainAndAddOrderingCtrlEdgeInSameChain(); void DecideExecutionOrder(); void EnableInplaceMemSharing(const std::function& IsOpNameDataOrCtrlReachable); void EnableInplaceMemSharing(const HashSet& dev_nodes, const std::function& IsOpNameDataOrCtrlReachable); TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, const MemZoneId& dst_mem_zone_id); TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id); TaskEdge* NewTaskEdgeWithLbi(const LogicalBlobId& lbi); TaskEdge* NewTaskEdgeWithLbis(const std::vector& lbis); void ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi); #define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS(); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByDstSubsetConnect); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D); void ForEachGpuDeviceNodes( const std::function& dev_nodes)>& Handler) const; protected: explicit TaskGraph(); void BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi, bool is_host_mem_input); void ConnectCtrlEdges(const std::vector& src_task_nodes, const std::vector& dst_task_nodes); void ConnectCtrlEdge(CompTaskNode* src_task_node, CompTaskNode* dst_task_node); void InitOrderedTaskNodes(); void MergeChainByPhysicalTaskGraph(); void MergeChainByLogicalChainId(); void BuildCtrlRegstDescInSameChain(); // inplace void GetInplaceOpBlobArgList( InplaceObasInfo* obas_info, const HashSet& dev_nodes, const std::function& TaskNode4OpName) const; void GetSafeInplaceOpBlobArgList( InplaceObasInfo* safe_obas_info, const HashSet& dev_nodes, const std::function& IsOpNameDataOrCtrlReachable) const; void SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info, const HashSet& dev_nodes) const; std::vector ordered_task_nodes_; HashMap> device_type2sub_tsk_gph_builder_; std::unique_ptr hierarchical_sub_tsk_gph_builder_; std::unique_ptr sub_tsk_gph_builder_ctx_; std::unique_ptr boxing_logger_; struct ProxyKey { TaskNode* 
src_node; LogicalBlobId lbi; MemZoneId dst_mem_zone_id; ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, const MemZoneId& arg_mem_zone_id) : src_node(src), lbi(arg_lbi), dst_mem_zone_id(arg_mem_zone_id) {} bool operator==(const ProxyKey& other) const { return src_node == other.src_node && lbi == other.lbi && dst_mem_zone_id == other.dst_mem_zone_id; } struct Hasher { inline size_t operator()(const ProxyKey& key) const { return Hash(key.src_node, key.lbi, key.dst_mem_zone_id.hash()); } }; }; HashMap proxy2node; }; class GlobalTaskGraph final : public TaskGraph { public: OF_DISALLOW_COPY_AND_MOVE(GlobalTaskGraph); ~GlobalTaskGraph() = default; static Maybe New() { std::shared_ptr graph(new GlobalTaskGraph()); JUST(graph->Init()); return graph; } private: GlobalTaskGraph() = default; Maybe Init(); }; class BoxingTaskGraphProto; class BoxingTaskGraph final : public TaskGraph { public: OF_DISALLOW_COPY_AND_MOVE(BoxingTaskGraph); ~BoxingTaskGraph() = default; static Maybe New( const std::function&)>& ParallelRunLoop) { std::shared_ptr graph(new BoxingTaskGraph()); JUST(graph->Init(ParallelRunLoop)); return graph; } void ToProto(const std::function& Pick, BoxingTaskGraphProto* proto) const; static bool SelectTaskNodeByRank(TaskNode*, int64_t rank); private: BoxingTaskGraph() = default; Maybe Init( const std::function&)>& ParallelRunLoop); void CreateOpNode2TaskIds( const std::function&)>& ParallelRunLoop); HashMap> boxing_related_op_node2sorted_comp_tasks_; HashMap> boxing_unrelated_op_node2sorted_task_ids_; }; class TaskGraphRebuildCtx; class RankTaskGraph final : public TaskGraph { public: OF_DISALLOW_COPY_AND_MOVE(RankTaskGraph); ~RankTaskGraph(); static Maybe New( const std::shared_ptr& boxing_task_graph_proto, const HashSet& var_op_names, int64_t current_rank) { std::shared_ptr graph(new RankTaskGraph(boxing_task_graph_proto, current_rank)); JUST(graph->Init(var_op_names)); return graph; } // Is `rank` my duty. 
bool IsDutyRank(const ParallelDesc& parallel_desc, int64_t rank) const; private: RankTaskGraph(const std::shared_ptr& boxing_task_graph_proto, int64_t rank); Maybe Init(const HashSet& var_op_names); bool ContainRank(const OpNode* op_node, int64_t rank) const; Maybe AddBoxingReletedCompTaskNodesFromProto(); Maybe CreateAndPartiallyInitTransportTaskNodesFromProto(); Maybe AddTransportTaskEdgesFromProto(); Maybe InitTransportTaskNodesFromProto(); Maybe InitRegstDescsConsumers(); template Maybe DoRankDuty(const ParallelDesc& parallel_desc, const DoEachRankT& DoWithRank); Maybe TryGetBoxingRelatedComTaskNode(const OpNode* op_node, int64_t parallel_id); Maybe CreateOrFindRankCompTaskNodeByParallelId(const OpNode* op_node, int64_t parallel_id); Maybe CreateOrFindRankCompTaskNodeByRank(const OpNode* op_node, int64_t rank); Maybe TryGetRankCompTaskNode(const OpNode* op_node, int64_t rank); Maybe ConnectDataEdges(const OpEdge* op_edge, int64_t rank); Maybe ConnectCtrlEdges(const OpNode* src, const OpNode* dst, int64_t rank); std::shared_ptr boxing_task_graph_proto_; HashMap task_id2task_proto_; const int64_t current_rank_; std::unique_ptr task_graph_rebuild_ctx_; HashMap op_node2comp_task_node_; }; using CreateSubTskGphBuilderFn = std::function()>; Maybe RegisterCreateSubTskGphBuilderFn(DeviceType device_type, const CreateSubTskGphBuilderFn& fn); #define REGISTER_CREATE_SUB_TASK_GRAPH_BUILDER_FN(device_type, fn) \ COMMAND(CHECK_JUST(RegisterCreateSubTskGphBuilderFn(device_type, fn))) } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_ ================================================ FILE: oneflow/core/graph/task_graph_rebuild_ctx.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/container_util.h" #include "oneflow/core/graph/task_node.h" #include "oneflow/core/graph/task_graph_rebuild_ctx.h" namespace oneflow { Maybe TaskGraphRebuildCtx::TaskNode4Id(int64_t task_id) const { auto* task_node = JUST(MapAt(id2task_node_, task_id)); CHECK_EQ_OR_RETURN(task_node->task_id(), task_id); // NOLINT return task_node; } Maybe TaskGraphRebuildCtx::TaskEdge4Uid(int64_t task_edge_uid) const { return JUST(MapAt(uid2task_edge_, task_edge_uid)); } Maybe TaskGraphRebuildCtx::RegstDesc4Id(int64_t regst_desc_id) const { return JUST(MapAt(id2regst_desc_, regst_desc_id)); } Maybe TaskGraphRebuildCtx::AddTaskNode(TaskNode* task_node) { CHECK_OR_RETURN(id2task_node_.emplace(task_node->task_id(), task_node).second) << "redundant task id found. value: " << task_node->task_id(); for (const auto& pair : task_node->produced_regsts()) { JUST(AddRegstDesc(pair.second)); } return Maybe::Ok(); } Maybe TaskGraphRebuildCtx::AddTaskEdge(TaskEdge* task_edge, int64_t task_edge_uid) { CHECK_OR_RETURN(uid2task_edge_.emplace(task_edge_uid, task_edge).second) << "redundant task edge uid found. 
value: " << task_edge_uid; return Maybe::Ok(); } Maybe TaskGraphRebuildCtx::AddRegstDesc(const std::shared_ptr& regst_desc) { CHECK_OR_RETURN(id2regst_desc_.emplace(regst_desc->regst_desc_id(), regst_desc).second) << "redundant register descriptor id found. value: " << regst_desc->regst_desc_id(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/graph/task_graph_rebuild_ctx.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ #define ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/register/register_desc.h" namespace oneflow { class TaskNode; class TaskEdge; class TaskGraphRebuildCtx { public: TaskGraphRebuildCtx() = default; ~TaskGraphRebuildCtx() = default; Maybe TaskNode4Id(int64_t task_id) const; Maybe TaskEdge4Uid(int64_t task_edge_uid) const; Maybe RegstDesc4Id(int64_t regst_desc_id) const; Maybe AddTaskNode(TaskNode* task_node); Maybe AddTaskEdge(TaskEdge* task_edge, int64_t task_edge_uid); Maybe AddRegstDesc(const std::shared_ptr& regst_desc); private: HashMap id2task_node_; HashMap uid2task_edge_; HashMap> id2regst_desc_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ ================================================ FILE: oneflow/core/graph/task_id.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

#include "oneflow/core/graph/task_id.h"
#include <climits>

namespace oneflow {

// TaskId encoding (maybe extended to 128 bits in future)
// |            rank            | device_type | device_index  |                           |
// | ----------- 16 ----------- | ---- 5 ---- | ----- 7 ----- |                           |
// |                         DeviceId                         | stream_index |            |
// | ------------------------- 31 --------------------------- | ---- 15 ---- |            |
// |                                 StreamId                                 | task_index |
// | -------------------------------- 43 ----------------------------------- | --- 21 --- |
// |                                        TaskId                                        |
// | ----------------------------------- 64 bit ----------------------------------------- |

namespace {

constexpr size_t kInt64Bits = sizeof(int64_t) * CHAR_BIT;

constexpr size_t kStreamIndexShift = TaskId::kTaskIndexBits;
constexpr size_t kDeviceIndexShift = kStreamIndexShift + StreamId::kStreamIndexBits;
constexpr size_t kDeviceTypeShift = kDeviceIndexShift + DeviceId::kDeviceIndexBits;
constexpr size_t kRankShift = kDeviceTypeShift + DeviceId::kDeviceTypeBits;
static_assert(kInt64Bits == kRankShift + DeviceId::kRankBits, "");

constexpr int64_t kTaskIndexInt64Mask = (int64_t{1} << TaskId::kTaskIndexBits) - 1;
constexpr int64_t kStreamIndexInt64Mask = ((int64_t{1} << StreamId::kStreamIndexBits) - 1)
                                          << kStreamIndexShift;
constexpr int64_t kDeviceIndexInt64Mask = ((int64_t{1} << DeviceId::kDeviceIndexBits) - 1)
                                          << kDeviceIndexShift;
constexpr int64_t kDeviceTypeInt64Mask = ((int64_t{1} << DeviceId::kDeviceTypeBits) - 1)
                                         << kDeviceTypeShift;
constexpr int64_t kRankInt64Mask = ((int64_t{1} << DeviceId::kRankBits) - 1) << kRankShift;

}  // namespace

int64_t EncodeTaskIdToInt64(const TaskId& task_id) {
  int64_t id = static_cast<int64_t>(task_id.task_index());
  id |= static_cast<int64_t>(task_id.stream_id().stream_index()) << kStreamIndexShift;
  id |= static_cast<int64_t>(task_id.stream_id().device_index()) << kDeviceIndexShift;
  id |= static_cast<int64_t>(task_id.stream_id().device_type()) << kDeviceTypeShift;
  id |= static_cast<int64_t>(task_id.stream_id().rank()) << kRankShift;
  return id;
}

TaskId DecodeTaskIdFromInt64(int64_t task_id_val) {
  int64_t rank = (task_id_val & kRankInt64Mask) >> kRankShift;
  int64_t device_type = (task_id_val & kDeviceTypeInt64Mask) >> kDeviceTypeShift;
  int64_t device_index = (task_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;
  int64_t stream_index = (task_id_val & kStreamIndexInt64Mask) >> kStreamIndexShift;
  int64_t task_index = task_id_val & kTaskIndexInt64Mask;
  StreamId stream_id{static_cast<DeviceId::rank_t>(rank), static_cast<DeviceType>(device_type),
                     static_cast<DeviceId::device_index_t>(device_index),
                     static_cast<StreamId::stream_index_t>(stream_index)};
  return TaskId{stream_id, static_cast<TaskId::task_index_t>(task_index)};
}

int64_t MachineId4ActorId(int64_t actor_id) {
  return DecodeTaskIdFromInt64(actor_id).stream_id().rank();
}

int64_t ThrdId4ActorId(int64_t actor_id) {
  return EncodeStreamIdToInt64(DecodeTaskIdFromInt64(actor_id).stream_id());
}

}  // namespace oneflow


================================================
FILE: oneflow/core/graph/task_id.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_CORE_GRAPH_TASK_ID_H_ #define ONEFLOW_CORE_GRAPH_TASK_ID_H_ #include "oneflow/core/graph/stream_id.h" namespace oneflow { class TaskId { public: using task_index_t = uint32_t; const static size_t kTaskIndexBits = 21; constexpr static task_index_t kMaxTaskIndex = (task_index_t{1} << kTaskIndexBits) - task_index_t{1}; TaskId(const StreamId& stream_id, task_index_t task_index) : stream_id_(stream_id), task_index_(task_index) { CHECK_LE(task_index_, kMaxTaskIndex); } const StreamId& stream_id() const { return stream_id_; } task_index_t task_index() const { return task_index_; } bool operator==(const TaskId& rhs) const { return stream_id_ == rhs.stream_id_ && task_index_ == rhs.task_index_; } bool operator!=(const TaskId& rhs) const { return !(*this == rhs); } size_t hash() const { size_t hash = stream_id_.hash(); HashCombine(&hash, std::hash{}(task_index_)); return hash; } private: StreamId stream_id_; task_index_t task_index_; }; int64_t EncodeTaskIdToInt64(const TaskId&); TaskId DecodeTaskIdFromInt64(int64_t); int64_t MachineId4ActorId(int64_t actor_id); int64_t ThrdId4ActorId(int64_t actor_id); } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::TaskId& task_id) const { return task_id.hash(); } }; } // namespace std #endif // ONEFLOW_CORE_GRAPH_TASK_ID_H_ ================================================ FILE: oneflow/core/graph/task_id_generator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/graph/stream_id.h" #include "oneflow/core/graph/task_id.h" #include "oneflow/core/graph/task_id_generator.h" namespace oneflow { void TaskIdGenerator::GetTaskIndex(HashMap* task_index_state) { for (const auto& pair : stream_id2task_index_counter_) { const int64_t i64_stream_id = EncodeStreamIdToInt64(pair.first); (*task_index_state)[i64_stream_id] = pair.second; } } void TaskIdGenerator::TryUpdateTaskIndex(const HashMap& task_index_state) { for (auto& pair : stream_id2task_index_counter_) { const int64_t i64_stream_id = EncodeStreamIdToInt64(pair.first); uint32_t initial_task_index = 0; if (task_index_state.count(i64_stream_id) != 0) { initial_task_index = task_index_state.at(i64_stream_id); } pair.second = std::max(pair.second, initial_task_index); } // try update the task_index_init_state for (const auto& pair : task_index_state) { const auto& key = pair.first; const auto& val = pair.second; if (task_index_init_state_.count(key) != 0) { task_index_init_state_[key] = std::max(task_index_init_state_.at(key), val); } else { task_index_init_state_[key] = val; } } } TaskId TaskIdGenerator::Generate(const StreamId& stream_id) { std::unique_lock lock(mutex_); if (stream_id2task_index_counter_.count(stream_id) == 0) { uint32_t init_task_index = 0; const int64_t i64_stream_id = EncodeStreamIdToInt64(stream_id); if (task_index_init_state_.count(i64_stream_id) != 0) { init_task_index = task_index_init_state_.at(i64_stream_id); } stream_id2task_index_counter_[stream_id] = init_task_index; } task_index_t task_index = stream_id2task_index_counter_[stream_id]++; return TaskId{stream_id, task_index}; } } // namespace oneflow ================================================ FILE: oneflow/core/graph/task_id_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_TASK_ID_GENERATOR_H_ #define ONEFLOW_CORE_GRAPH_TASK_ID_GENERATOR_H_ #include "oneflow/core/graph/task_id.h" #include "oneflow/core/job/id_state.h" namespace oneflow { class TaskIdGenerator final { public: using task_index_t = TaskId::task_index_t; TaskIdGenerator() = default; OF_DISALLOW_COPY_AND_MOVE(TaskIdGenerator); ~TaskIdGenerator() = default; TaskId Generate(const StreamId& stream_id); void GetTaskIndex(HashMap* task_index_state); void TryUpdateTaskIndex(const HashMap& task_index_state); private: std::mutex mutex_; HashMap stream_id2task_index_counter_; // The task_index_init_state is used to initialize the `stream_id2task_index_counter_` hashmap. HashMap task_index_init_state_{}; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_TASK_ID_GENERATOR_H_ ================================================ FILE: oneflow/core/graph/task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/task_node.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/graph/task_graph_rebuild_ctx.h" namespace oneflow { namespace { void ForEachDataEdge(const std::unordered_set& edges, const std::function& Handler) { for (TaskEdge* edge : edges) { const auto& regsts = edge->GetRegsts(); int32_t data_regst_size = std::count_if(regsts.begin(), regsts.end(), [](const std::shared_ptr& regst) { return regst->regst_desc_type().has_data_regst_desc(); }); if (data_regst_size == regsts.size()) { Handler(edge); } else { CHECK_EQ(data_regst_size, 0); } } } } // namespace TaskNode::TaskNode() : machine_id_(-1), thrd_id_(-1), task_id_(-1), chain_id_(-1), order_in_chain_(-1) {} std::shared_ptr TaskNode::GetProducedRegst(const std::string& name) { auto produced_regsts_it = produced_regsts_.find(name); if (produced_regsts_it == produced_regsts_.end()) { return nullptr; } else { return produced_regsts_it->second; } } const std::list>& TaskNode::GetConsumedRegst(const std::string& name) { return consumed_regsts_.at(name); } std::shared_ptr TaskNode::GetSoleConsumedRegst(const std::string& name) { auto it = consumed_regsts_.find(name); if (it == consumed_regsts_.end()) { return nullptr; } const std::list>& vec = it->second; CHECK_EQ(vec.size(), 1); return vec.front(); } const StreamId& TaskNode::stream_id() const { CHECK(new_task_id_); return new_task_id_->stream_id(); } DeviceType TaskNode::device_type() const { return stream_id().device_id().device_type(); } void TaskNode::set_machine_id(int64_t val) { CHECK_EQ(machine_id_, -1); machine_id_ = val; if (thrd_id_ != -1) { UpdateTaskId(); } } void TaskNode::set_thrd_id(int64_t val) { CHECK_EQ(thrd_id_, -1); thrd_id_ = val; CHECK_GE(thrd_id_, 0); if (machine_id_ != -1) { UpdateTaskId(); } } void TaskNode::set_chain_id(int64_t val) { CHECK(!IsValidChainId(chain_id_)); chain_id_ = val; } void TaskNode::set_order_in_chain(int64_t val) { CHECK_EQ(order_in_chain_, -1); order_in_chain_ = val; } void TaskNode::PinConsumedRegst() { for (auto& pair : consumed_regsts_) { for (const std::shared_ptr& regst : pair.second) { PinConsumedRegstMemCase(regst->mut_mem_case()); } } } void TaskNode::NaiveInferProducedDataRegstTimeShape() { if (IsMeaningLess()) { return; } std::shared_ptr time_shape; ForEachConsumedDataRegst([&time_shape](const std::string& name, const RegstDesc* regst) { if (time_shape) { CHECK_EQ(*time_shape.get(), *regst->data_regst_time_shape().get()); } else { time_shape = regst->data_regst_time_shape(); } }); CHECK(time_shape); ForEachProducedDataRegst([time_shape](const std::string& name, RegstDesc* regst) { *regst->mut_data_regst_time_shape() = time_shape; }); } void TaskNode::InferTimeShapeIfMeaningful() { if (!IsMeaningLess()) { InferProducedDataRegstTimeShape(); } } std::shared_ptr TaskNode::GetFastestInputOutputTimeShape() const { std::shared_ptr shape; auto UpdateRetShape = [&](TaskEdge* edge) { for (const 
auto& regst : edge->GetRegsts()) { if (!shape || shape->elem_cnt() < regst->data_regst_time_shape()->elem_cnt()) { shape = regst->data_regst_time_shape(); } } }; ForEachOutDataEdge(UpdateRetShape); if (shape) { return shape; } ForEachInDataEdge(UpdateRetShape); return shape; } void TaskNode::ForEachConsumedDataRegst( const std::function& Handler) const { for (const auto& pair : consumed_regsts_) { for (const auto& regst : pair.second) { if (!regst->regst_desc_type().has_data_regst_desc()) { continue; } Handler(pair.first, regst.get()); } } } void TaskNode::ForEachProducedDataRegst( const std::function& Handler) { for (auto& pair : produced_regsts_) { if (!pair.second->regst_desc_type().has_data_regst_desc()) { continue; } Handler(pair.first, pair.second.get()); } } void TaskNode::Build() { BuildExecGphAndRegst(); } void TaskNode::EraseUninitializedShapeProducedBlob() { for (auto& pair : produced_regsts_) { pair.second->EraseUninitializedShapeBlob(); } } void TaskNode::EraseZeroSizeConsumedRegst() { for (auto& pair : consumed_regsts_) { for (auto it = pair.second.begin(); it != pair.second.end();) { auto regst_ptr = *it; CHECK(regst_ptr); if (regst_ptr->regst_desc_type().has_data_regst_desc() && regst_ptr->NumOfLbi() == 0) { it = pair.second.erase(it); } else { ++it; } } } EraseIf>>( &consumed_regsts_, [](HashMap>>::iterator it) { return it->second.empty(); }); } void TaskNode::EraseZeroSizeProducedRegst() { EraseIf>( &produced_regsts_, [](HashMap>::iterator it) { return it->second->regst_desc_type().has_data_regst_desc() && it->second->NumOfLbi() == 0; }); } void TaskNode::UnbindBnWithEmptyRegst() { exec_gph_.ForEachNode([&](ExecNode* exec_node) { exec_node->UnbindBnWithEmptyRegst(); }); } std::string TaskNode::VisualStr() const { std::stringstream ss; ss << TaskType_Name(GetTaskType()) << "\\n" << machine_id_ << ":" << thrd_id_ << "\\n" << task_id_; return ss.str(); } bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); } void TaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto) { // Step1: init some scalar items. CHECK(task_proto.task_type() == GetTaskType()); machine_id_ = task_proto.machine_id(); thrd_id_ = task_proto.thrd_id(); task_id_ = task_proto.task_id(); new_task_id_.reset(new TaskId(DecodeTaskIdFromInt64(task_id_))); CHECK(task_proto.job_id() == GlobalJobDesc().job_id()); chain_id_ = task_proto.chain_id(); order_in_chain_ = task_proto.order_in_chain(); // Step2: check exec_gph empty. CHECK(task_proto.exec_sequence().exec_node().empty()); // Step3: init produced_regst. for (const auto& pair : task_proto.produced_regst_desc()) { const auto& regst_desc = ProduceRegst(pair.first, pair.second.enable_reuse_mem()); // regst_desc->consumers_ will be initialized by RegstDesc::InitConsumersFromProto. regst_desc->InitFromProtoExceptConsumers(pair.second); } } Maybe TaskNode::InitConsumedRegstsFromProto( const TaskProto& task_proto, const std::function(int64_t regst_desc_id)>& RegstDesc4Id) { // init consumed_regst. for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (int64_t regst_desc_id : pair.second.regst_desc_id()) { ConsumeRegst(pair.first, JUST(RegstDesc4Id(regst_desc_id))); } } return Maybe::Ok(); } void TaskNode::ToProto(TaskProto* task_proto, bool check) const { // Step1: process some scalar items. 
task_proto->set_task_type(GetTaskType()); task_proto->set_machine_id(machine_id_); task_proto->set_thrd_id(thrd_id_); task_proto->set_task_id(task_id_); task_proto->set_job_id(GlobalJobDesc().job_id()); task_proto->set_chain_id(chain_id_); task_proto->set_order_in_chain(order_in_chain_); // Step2: process exec_gph. exec_gph_.ToExecSequence(parallel_ctx(), task_proto->mutable_exec_sequence()); // Step3: process produced_regst. auto* produced_regst_proto = task_proto->mutable_produced_regst_desc(); for (auto& pair : produced_regsts_) { RegstDescProto regst_desc_proto; pair.second->ToProto(®st_desc_proto, check); CHECK(produced_regst_proto->insert({pair.first, regst_desc_proto}).second); } // Step4: process consumed_regst. auto* consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id(); for (const auto& pair : consumed_regsts_) { RegstDescIdSet regst_desc_ids; for (const std::shared_ptr& regst : pair.second) { regst_desc_ids.add_regst_desc_id(regst->regst_desc_id()); } CHECK(consumed_regst_proto->insert({pair.first, regst_desc_ids}).second); } } MemZoneId TaskNode::MemZoneId121() const { StreamId stream_id = DecodeStreamIdFromInt64(thrd_id_); return stream_id.device_id(); } bool TaskNode::BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name) { if (IsMeaningLess() || dst_node->IsMeaningLess()) { return false; } for (const TaskEdge* in_edge : dst_node->in_edges()) { if (in_edge->src_node() == this) { return false; } } BuildCtrlRegstDesc(dst_node, name); return true; } RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node) { std::string name; return BuildCtrlRegstDesc(dst_node, &name); } RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) { RegstDescTypeProto regst_desc_type; regst_desc_type.mutable_ctrl_regst_desc(); auto regst = NewProducedRegst(false, 1, kMaxRegisterNum, regst_desc_type); *name = "out_ctrl_" + std::to_string(regst->regst_desc_id()); CHECK(produced_regsts_.emplace(*name, regst).second); dst_node->ConsumeRegst("in_ctrl", regst); return regst.get(); } void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) { if (edge->HasRegst(name)) { return; } edge->AddRegst(name, GetProducedRegst(name)); } std::shared_ptr TaskNode::GetAndCheckRegst(const std::string& name, bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num) const { auto iter = produced_regsts_.find(name); if (iter == produced_regsts_.end()) { return nullptr; } const auto& regst = (iter->second); CHECK_EQ(regst->min_register_num(), min_register_num); CHECK_EQ(regst->max_register_num(), max_register_num); CHECK_EQ(regst->enable_reuse_mem(), enable_reuse_mem); return regst; } std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) { return ProduceRegst(name, enable_reuse_mem, 1, kMaxRegisterNum); } std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num) { // Because the Regst of separate compilation is not created in order, some Regst may have been // built. This implementation can avoid ProduceRegst being called multiple times. 
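// For example, when the same node reaches this point twice with identical arguments, the second
// call is expected to return the regst created by the first instead of tripping the
// duplicate-name CHECK below (illustrative sketch from within a TaskNode subclass):
//   auto a = ProduceRegst("out", /*enable_reuse_mem=*/false);
//   auto b = ProduceRegst("out", /*enable_reuse_mem=*/false);  // returns the same regst as `a`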
const auto& regst = GetAndCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num); if (regst) { return regst; } RegstDescTypeProto regst_desc_type; regst_desc_type.mutable_data_regst_desc(); return ProduceRegst(name, enable_reuse_mem, min_register_num, max_register_num, regst_desc_type); } std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num, const RegstDescTypeProto& regst_desc_type) { auto regst = NewProducedRegst(enable_reuse_mem, min_register_num, max_register_num, regst_desc_type); CHECK(produced_regsts_.emplace(name, regst).second); return regst; } std::shared_ptr TaskNode::NewProducedRegst(bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num, const RegstDescTypeProto& regst_desc_type) { auto regst = std::make_shared(); regst->set_producer(this); *(regst->mut_regst_desc_type()) = regst_desc_type; regst->UpdtMinRegstNumIfNeed(min_register_num); regst->UpdtMaxRegstNumIfNeed(max_register_num); regst->set_enable_reuse_mem(GlobalJobDesc().enable_reuse_mem() && enable_reuse_mem); InitProducedRegstMemCase(regst.get()); return regst; } void TaskNode::InitProducedRegstMemCase(RegstDesc* regst) { InitProducedRegstMemCase(regst->mut_mem_case()); } void TaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) { mem_case->set_device_type(device_type()); mem_case->set_device_id(stream_id().device_id().device_index()); } void TaskNode::PinConsumedRegstMemCase(MemoryCase* mem_case) { // When a node located on non-cpu device consumes a cpu regst, // the regst memory should be pinned on host memory (locked page memory). // When the regst is not on host, skip pinning if (!memory::IsHostMem(*mem_case)) { return; } // When the node is located on host, skip pinning if (device_type() == DeviceType::kCPU) { return; } mem_case->set_pinned_device_type(device_type()); mem_case->set_pinned_device_id(stream_id().device_id().device_index()); } void TaskNode::ConsumeRegst(const std::string& name) { consumed_regsts_.emplace(name, std::list>{}); } void TaskNode::ConsumeRegst(const std::string& name, const std::shared_ptr& regst) { regst->AddConsumer(this); consumed_regsts_[name].emplace_back(regst); } void TaskNode::UpdateTaskId() { CHECK_NE(machine_id_, -1); CHECK_NE(thrd_id_, -1); StreamId stream_id = DecodeStreamIdFromInt64(thrd_id_); new_task_id_.reset( new TaskId(Singleton::Get()->GetTaskIdGenerator()->Generate(stream_id))); task_id_ = EncodeTaskIdToInt64(*new_task_id_); } void TaskNode::update_new_task_id(const TaskId& task_id) { CHECK(static_cast(new_task_id_)); CHECK(new_task_id_->stream_id() == task_id.stream_id()); *new_task_id_ = task_id; task_id_ = EncodeTaskIdToInt64(*new_task_id_); } void TaskNode::EraseConsumedRegstsByName(const std::string& name) { if (consumed_regsts_.find(name) != consumed_regsts_.end()) { for (auto& regst : consumed_regsts_[name]) { regst->DeleteConsumer(this); } CHECK_EQ(consumed_regsts_.erase(name), 1); } } std::shared_ptr TaskEdge::GetRegst(const std::string& name_in_producer) const { return name_in_producer2regst_.at(name_in_producer); } bool TaskEdge::HasRegst(const std::string& name_in_producer) const { return (name_in_producer2regst_.find(name_in_producer) != name_in_producer2regst_.end()); } std::shared_ptr TaskEdge::GetSoleRegst() const { CHECK_EQ(name_in_producer2regst_.size(), 1) << "edge: " << this << ", src: " << src_node()->task_id() << ", dst: " << dst_node()->task_id(); return name_in_producer2regst_.begin()->second; } std::vector> 
TaskEdge::GetRegsts() const { std::vector> regst_descs; regst_descs.reserve(name_in_producer2regst_.size()); for (auto& pair : name_in_producer2regst_) { regst_descs.emplace_back(pair.second); } return regst_descs; } void TaskEdge::AddRegst(const std::string& name_in_producer, const std::shared_ptr& regst) { if (HasRegst(name_in_producer)) { CHECK(CHECK_JUST(MapAt(name_in_producer2regst_, name_in_producer))->regst_desc_id() == regst->regst_desc_id()); return; } CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second); } void TaskEdge::CheckRegstLbiValid() const { HashMap> lbi2data_regst; for (auto& pair : name_in_producer2regst_) { std::shared_ptr regst = pair.second; if (regst->regst_desc_type().has_data_regst_desc()) { // NOTE(chengcheng): regst_desc_type is Set, BUT regst_desc_type.data_regst_desc is UNSET! // So you can ONLY use NumOfLbi and ForEachLbi interface. CHECK_EQ(regst->NumOfLbi(), 1); regst->ForEachLbi( [&](const LogicalBlobId& lbi) { CHECK(lbi2data_regst.emplace(lbi, regst).second); }); } } CHECK_EQ(lbi2data_regst.size(), lbis_.size()) << " \n\n TaskEdge lbi and regst NOT match." << " TaskEdge: edge_id = " << edge_id() << " From: [" << src_node()->VisualStr() << "] To: [" << dst_node()->VisualStr() << "]\n"; for (auto& lbi : lbis_) { CHECK(lbi2data_regst.find(lbi) != lbi2data_regst.end()) << " \n\n Cannot find lbi: " << lbi.DebugString() << " in TaskEdge From: [" << src_node()->VisualStr() << "] To: [" << dst_node()->VisualStr() << "]\n\n"; } } RegstDescProto* FindOrCreateProducedCtrlRegstDesc(TaskProto* task_proto, const std::string& regst_desc_name) { auto* produced_regst_desc = task_proto->mutable_produced_regst_desc(); if (produced_regst_desc->find(regst_desc_name) == produced_regst_desc->end()) { RegstDescProto ctrl_regst_desc; InitCtrlRegstDesc(task_proto->task_id(), &ctrl_regst_desc); CHECK(produced_regst_desc->insert({regst_desc_name, ctrl_regst_desc}).second); } return &produced_regst_desc->at(regst_desc_name); } RegstDescIdSet* FindOrCreateConsumedCtrlRegstDescIdSet(TaskProto* task_proto, const std::string& regst_desc_name) { auto* consumed_regst_desc_id_sets = task_proto->mutable_consumed_regst_desc_id(); if (consumed_regst_desc_id_sets->find(regst_desc_name) == consumed_regst_desc_id_sets->end()) { CHECK(consumed_regst_desc_id_sets->insert({regst_desc_name, RegstDescIdSet()}).second); } return &consumed_regst_desc_id_sets->at(regst_desc_name); } void TaskNode::ForEachInDataEdge(const std::function& Handler) const { ForEachDataEdge(in_edges(), Handler); } void TaskNode::ForEachOutDataEdge(const std::function& Handler) const { ForEachDataEdge(out_edges(), Handler); } void TaskNode::ForEachNodeOnInDataEdge(const std::function& Handler) const { ForEachInDataEdge([&](TaskEdge* in_edge) { Handler(in_edge->src_node()); }); } void TaskNode::ForEachNodeOnOutDataEdge(const std::function& Handler) const { ForEachOutDataEdge([&](TaskEdge* out_edge) { Handler(out_edge->dst_node()); }); } void TaskNode::ForEachNodeOnInOutDataEdge(const std::function& Handler) const { ForEachNodeOnInDataEdge(Handler); ForEachNodeOnOutDataEdge(Handler); } TaskEdge* TaskNode::GetSoleEdge(void (TaskNode::*ForEachEdge)(const std::function&) const) const { TaskEdge* ret = nullptr; (this->*ForEachEdge)([&](TaskEdge* edge) { CHECK(ret == nullptr); ret = edge; }); CHECK_NOTNULL(ret); return ret; } size_t TaskNode::GetEdgesSize(void (TaskNode::*ForEachEdge)(const std::function&) const) const { size_t size = 0; (this->*ForEachEdge)([&](TaskEdge* edge) { ++size; }); return size; } TaskEdge* 
TaskNode::SoleInDataEdge() const { return GetSoleEdge(&TaskNode::ForEachInDataEdge); } TaskEdge* TaskNode::SoleOutDataEdge() const { return GetSoleEdge(&TaskNode::ForEachOutDataEdge); } size_t TaskNode::in_data_edges_size() const { return GetEdgesSize(&TaskNode::ForEachInDataEdge); } size_t TaskNode::out_data_edges_size() const { return GetEdgesSize(&TaskNode::ForEachOutDataEdge); } Maybe TaskEdge::InitFromProto(const TaskEdgeProto& proto, const TaskGraphRebuildCtx& task_graph_rebuild_ctx) { CHECK_NE_OR_RETURN(proto.src_task_id(), proto.dst_task_id()) << "self-loop are not supported"; JUST(task_graph_rebuild_ctx.TaskNode4Id(proto.src_task_id())); JUST(task_graph_rebuild_ctx.TaskNode4Id(proto.dst_task_id())); // Note that edge id from proto is ignored. lbis_.insert(proto.lbi().begin(), proto.lbi().end()); for (const auto& pair : proto.name_in_producer2regst_desc_id()) { AddRegst(pair.first, JUST(task_graph_rebuild_ctx.RegstDesc4Id(pair.second))); } return Maybe::Ok(); } void TaskEdge::ToProto(TaskEdgeProto* proto) const { // proto->set_task_edge_uid(edge_id()); proto->set_task_edge_uid(reinterpret_cast(this)); proto->set_src_task_id(src_node()->task_id()); proto->set_dst_task_id(dst_node()->task_id()); *proto->mutable_lbi() = {lbis_.begin(), lbis_.end()}; auto* map = proto->mutable_name_in_producer2regst_desc_id(); for (const auto& pair : name_in_producer2regst_) { CHECK(map->insert({pair.first, pair.second->regst_desc_id()}).second); } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_TASK_NODE_H_ #include "oneflow/core/graph/exec_graph.h" #include "oneflow/core/job/task.pb.h" #include "oneflow/core/graph/task_edge.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/memory/memory_zone.h" namespace std { template<> struct hash { std::size_t operator()(const oneflow::TaskType& task_type) const { return std::hash{}(static_cast(task_type)); } }; } // namespace std namespace oneflow { RegstDescProto* FindOrCreateProducedCtrlRegstDesc(TaskProto* task_proto, const std::string& regst_desc_name); RegstDescIdSet* FindOrCreateConsumedCtrlRegstDescIdSet(TaskProto* task_proto, const std::string& regst_desc_name); bool inline IsValidChainId(int64_t val) { return val >= 0; } class TaskEdge; class TaskNode : public Node { public: OF_DISALLOW_COPY_AND_MOVE(TaskNode); TaskNode(); ~TaskNode() override = default; // Getters int64_t machine_id() const { return machine_id_; } int64_t thrd_id() const { return thrd_id_; } int64_t task_id() const { return task_id_; } const StreamId& stream_id() const; int64_t chain_id() const { return chain_id_; } int64_t order_in_chain() const { return order_in_chain_; } const ExecGraph& exec_gph() const { return exec_gph_; } std::shared_ptr GetProducedRegst(const std::string& name); const std::list>& GetConsumedRegst(const std::string& name); std::shared_ptr GetSoleConsumedRegst(const std::string& name); const HashMap>& produced_regsts() const { return produced_regsts_; } const HashMap>>& consumed_regsts() const { return consumed_regsts_; } DeviceType device_type() const; virtual const ParallelContext* parallel_ctx() const { return nullptr; } // Different types of TaskNode/Compile Mode choose different output BlobDesc inference methods virtual ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const = 0; // Setters void set_machine_id(int64_t val); void set_thrd_id(int64_t val); void set_chain_id(int64_t val); void set_order_in_chain(int64_t val); // Build virtual void ProduceAllRegstsAndBindEdges() = 0; virtual void ConsumeAllRegsts() = 0; void PinConsumedRegst(); void InferTimeShapeIfMeaningful(); void ForEachProducedDataRegst(const std::function& Handler); void ForEachConsumedDataRegst( const std::function& Handler) const; void Build(); void EraseUninitializedShapeProducedBlob(); void EraseZeroSizeConsumedRegst(); void EraseZeroSizeProducedRegst(); void UnbindBnWithEmptyRegst(); // Others virtual TaskType GetTaskType() const { return TaskType::kInvalid; } std::string VisualStr() const override; virtual bool IsMeaningLess(); void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); } // Used to create task node from proto in plan separation compilation. 
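// The rebuild is two-phase: every node is first restored from its TaskProto without consumers
// (so all produced RegstDescs exist), and only then are consumed regsts wired up through a
// regst-id lookup such as TaskGraphRebuildCtx::RegstDesc4Id. A hedged sketch (names
// illustrative, not code from this file):
//   node->InitFromProtoExceptConsumedRegsts(task_proto);
//   JUST(node->InitConsumedRegstsFromProto(
//       task_proto, [&](int64_t id) { return rebuild_ctx.RegstDesc4Id(id); }));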
virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto); Maybe InitConsumedRegstsFromProto( const TaskProto& task_proto, const std::function(int64_t regst_desc_id)>& RegstDesc4Id); virtual void ToProto(TaskProto* task_proto, bool check) const; void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); virtual MemZoneId MemZoneId121() const; bool BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name); RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node); RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name); std::shared_ptr GetFastestInputOutputTimeShape() const; void ForEachInDataEdge(const std::function& Handler) const; void ForEachOutDataEdge(const std::function& Handler) const; void ForEachNodeOnInDataEdge(const std::function& Handler) const; void ForEachNodeOnOutDataEdge(const std::function& Handler) const; void ForEachNodeOnInOutDataEdge(const std::function& Handler) const; TaskEdge* SoleInDataEdge() const; TaskEdge* SoleOutDataEdge() const; size_t in_data_edges_size() const; size_t out_data_edges_size() const; const TaskId& new_task_id() const { CHECK(has_new_task_id()); return *new_task_id_; } void update_new_task_id(const TaskId& task_id); bool has_new_task_id() const { return static_cast(new_task_id_); } protected: std::shared_ptr ProduceRegst(const std::string& name, bool enable_reuse_mem); std::shared_ptr ProduceRegst(const std::string& name, bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num); std::shared_ptr ProduceRegst(const std::string& name, bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num, const RegstDescTypeProto&); std::shared_ptr NewProducedRegst(bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num, const RegstDescTypeProto&); virtual void InitProducedRegstMemCase(RegstDesc* regst); virtual void InitProducedRegstMemCase(MemoryCase*); virtual void PinConsumedRegstMemCase(MemoryCase*); void ConsumeRegst(const std::string& name); void ConsumeRegst(const std::string& name, const std::shared_ptr&); ExecGraph& mut_exec_gph() { return exec_gph_; } void EraseConsumedRegstsByName(const std::string& name); virtual void BuildExecGphAndRegst() = 0; virtual void InferProducedDataRegstTimeShape() = 0; void NaiveInferProducedDataRegstTimeShape(); TaskEdge* GetSoleEdge(void (TaskNode::*ForEachEdge)(const std::function&) const) const; size_t GetEdgesSize(void (TaskNode::*ForEachEdge)(const std::function&) const) const; private: void UpdateTaskId(); std::shared_ptr GetAndCheckRegst(const std::string& name, bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num) const; int64_t machine_id_; int64_t thrd_id_; int64_t task_id_; int64_t chain_id_; int64_t order_in_chain_; std::unique_ptr new_task_id_; ExecGraph exec_gph_; HashMap> produced_regsts_; HashMap>> consumed_regsts_; }; class TaskGraphRebuildCtx; class TaskEdge final : public Edge { public: OF_DISALLOW_COPY_AND_MOVE(TaskEdge); TaskEdge() = default; ~TaskEdge() override = default; std::shared_ptr GetRegst(const std::string& name_in_producer) const; bool HasRegst(const std::string& name_in_producer) const; std::shared_ptr GetSoleRegst() const; std::vector> GetRegsts() const; const HashSet& GetLbis() const { return lbis_; } void AddRegst(const std::string& name_in_producer, const std::shared_ptr& regst); void AddLbi(const LogicalBlobId& lbi) { lbis_.insert(lbi); } void AddLbis(const std::vector& lbis) { lbis_.insert(lbis.begin(), lbis.end()); } void CheckRegstLbiValid() const; bool HasRegst() const { 
return !name_in_producer2regst_.empty(); } Maybe InitFromProto(const TaskEdgeProto& proto, const TaskGraphRebuildCtx& task_graph_rebuild_ctx); void ToProto(TaskEdgeProto* proto) const; private: HashSet lbis_; HashMap> name_in_producer2regst_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph/task_stream_id.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_TASK_STREAM_ID_H_ #define ONEFLOW_CORE_GRAPH_TASK_STREAM_ID_H_ #include "oneflow/core/graph/stream_id.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { inline StreamId GenerateComputeTaskStreamId(const DeviceId& device_id) { auto stream_index = Singleton::Get()->GetComputeTaskStreamIndex(device_id); return StreamId{device_id, stream_index}; } inline StreamId GenerateComputeTaskStreamId(int64_t rank, DeviceType device_type, int64_t device_index) { DeviceId device_id{static_cast(rank), device_type, static_cast(device_index)}; return GenerateComputeTaskStreamId(device_id); } inline StreamId GenerateNamedTaskStreamId(const DeviceId& device_id, const std::string& name) { auto stream_index = Singleton::Get()->GetNamedTaskStreamIndex(device_id, name); return StreamId{device_id, stream_index}; } inline StreamId GenerateNamedTaskStreamId(int64_t rank, DeviceType device_type, int64_t device_index, const std::string& name) { DeviceId device_id{static_cast(rank), device_type, static_cast(device_index)}; return GenerateNamedTaskStreamId(device_id, name); } } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_TASK_STREAM_ID_H_ ================================================ FILE: oneflow/core/graph/task_stream_index_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/task_stream_index_manager.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/id_state.h" #include "oneflow/core/job/resource_desc.h" namespace oneflow { StreamIndexGenerator* TaskStreamIndexManager::GetGenerator(const DeviceId& device_id) { std::unique_lock lck(mtx_); auto iter = generators_.find(device_id); if (iter == generators_.end()) { uint32_t init_stream_index = 0; const int64_t i64_device_id = EncodeDeviceIdToInt64(device_id); if (stream_index_init_state_.count(i64_device_id) != 0) { init_stream_index = stream_index_init_state_.at(i64_device_id); } iter = generators_.emplace(device_id, std::make_unique(init_stream_index)) .first; } return iter->second.get(); } TaskStreamIndexManager::stream_index_t TaskStreamIndexManager::GetTaskStreamIndex( TaskType task_type, const DeviceId& device_id) { auto* generator = GetGenerator(device_id); auto stream_index = CHECK_JUST(TaskStreamIndexGetterRegistry::Instance().Dispatch( device_id.device_type(), task_type, generator)); return stream_index; } TaskStreamIndexManager::stream_index_t TaskStreamIndexManager::GetComputeTaskStreamIndex( const DeviceId& device_id) { auto* generator = GetGenerator(device_id); return GenerateComputeTaskStreamIndex(device_id.device_type(), generator); } TaskStreamIndexManager::stream_index_t TaskStreamIndexManager::GetNamedTaskStreamIndex( const DeviceId& device_id, const std::string& name) { auto* generator = GetGenerator(device_id); return generator->GenerateNamed(name); } void TaskStreamIndexManager::GetTaskStreamIndex(HashMap* stream_index_state) { for (auto& pair : generators_) { const int64_t i64_device_id = EncodeDeviceIdToInt64(pair.first); (*stream_index_state)[i64_device_id] = pair.second->GetCurrStreamIndex(); } } void TaskStreamIndexManager::TryUpdateTaskStreamIndex( const HashMap& stream_index_state) { // Try Update generator's new_stream_index for (auto& pair : generators_) { const int64_t i64_device_id = EncodeDeviceIdToInt64(pair.first); uint32_t initial_stream_index = 0; if (stream_index_state.count(i64_device_id) != 0) { initial_stream_index = stream_index_state.at(i64_device_id); } pair.second->TryUpdateNextStreamIndex(initial_stream_index); } // try update stream_index_init_state for (const auto& pair : stream_index_state) { const auto& key = pair.first; const auto& val = pair.second; if (stream_index_init_state_.count(key) != 0) { stream_index_init_state_[key] = std::max(stream_index_init_state_.at(key), val); } else { stream_index_init_state_[key] = val; } } } void TaskStreamIndexGetterRegistry::Register(const key_t& key, const stream_index_getter& getter) { bool insert_success = stream_index_getter_map_.emplace(key, getter).second; if (!insert_success) { std::cerr << "DeviceType " << key.first << ", TaskType " << key.second << " was already registered"; abort(); } } Maybe TaskStreamIndexGetterRegistry::Dispatch( DeviceType device_type, TaskType task_type, StreamIndexGenerator* generator) { auto key = std::make_pair(device_type, task_type); auto it = stream_index_getter_map_.find(key); CHECK_OR_RETURN(it != stream_index_getter_map_.end()) << "TaskType: " << key.second << ", DeviceType: " << key.first << " has not been registered"; return it->second(generator); } StreamId::stream_index_t GenerateComputeTaskStreamIndex(DeviceType device_type, StreamIndexGenerator* generator) { if (device_type == DeviceType::kCPU) { size_t cpu_device_num = Singleton::Get()->CpuDeviceNum(); return 
generator->GenerateNamedRoundRobin("CPU_COMPUTE", cpu_device_num); } else { return generator->GenerateNamed("COMPUTE"); } } } // namespace oneflow ================================================ FILE: oneflow/core/graph/task_stream_index_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_TASK_STREAM_INDEX_MANAGER_H_ #define ONEFLOW_CORE_GRAPH_TASK_STREAM_INDEX_MANAGER_H_ #include "oneflow/core/job/task.pb.h" #include "oneflow/core/graph/stream_index_generator.h" namespace oneflow { class TaskStreamIndexManager { public: using stream_index_t = StreamId::stream_index_t; OF_DISALLOW_COPY_AND_MOVE(TaskStreamIndexManager); TaskStreamIndexManager() = default; virtual ~TaskStreamIndexManager() = default; StreamIndexGenerator* GetGenerator(const DeviceId& device_id); stream_index_t GetTaskStreamIndex(TaskType task_type, const DeviceId& device_id); stream_index_t GetComputeTaskStreamIndex(const DeviceId& device_id); stream_index_t GetNamedTaskStreamIndex(const DeviceId& device_id, const std::string& name); void GetTaskStreamIndex(HashMap* stream_index_state); void TryUpdateTaskStreamIndex(const HashMap& stream_index_state); private: HashMap> generators_; // The stream_index_init_state is used to initialize the generator. 
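// The state can be round-tripped: GetTaskStreamIndex(&state) snapshots the per-device counters,
// and TryUpdateTaskStreamIndex(state) replays a snapshot so newly generated stream indices
// continue after the saved ones (hedged sketch, names illustrative):
//   HashMap<int64_t, uint32_t> state;
//   manager.GetTaskStreamIndex(&state);            // save current counters
//   other_manager.TryUpdateTaskStreamIndex(state); // restore / advance counters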
HashMap stream_index_init_state_{}; std::mutex mtx_; }; class TaskStreamIndexGetterRegistry final { public: using key_t = std::pair; using stream_index_getter = std::function; using map_t = HashMap; struct GetterRegister { GetterRegister(DeviceType device_type, TaskType task_type, const stream_index_getter& getter) { TaskStreamIndexGetterRegistry::Instance().Register(std::make_pair(device_type, task_type), getter); } }; static TaskStreamIndexGetterRegistry& Instance() { static TaskStreamIndexGetterRegistry registry; return registry; } OF_DISALLOW_COPY_AND_MOVE(TaskStreamIndexGetterRegistry); ~TaskStreamIndexGetterRegistry() = default; void Register(const key_t& key, const stream_index_getter& getter); Maybe Dispatch(DeviceType device_type, TaskType task_type, StreamIndexGenerator* generator); private: TaskStreamIndexGetterRegistry() = default; map_t stream_index_getter_map_; }; StreamId::stream_index_t GenerateComputeTaskStreamIndex(DeviceType device_type, StreamIndexGenerator* generator); } // namespace oneflow #define REGISTER_TASK_STREAM_INDEX_GETTER(device_type, task_type, getter) \ static auto OF_PP_CAT(g_stream_index_getter_register_, __COUNTER__) = \ ::oneflow::TaskStreamIndexGetterRegistry::GetterRegister(device_type, task_type, getter) #define REGISTER_NAMED_TASK_STREAM_INDEX_GETTER(device_type, task_type, name) \ REGISTER_TASK_STREAM_INDEX_GETTER( \ device_type, task_type, ([](StreamIndexGenerator* generator) -> StreamId::stream_index_t { \ return generator->GenerateNamed(name); \ })); #define REGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(task_type) \ REGISTER_TASK_STREAM_INDEX_GETTER( \ DeviceType::kCPU, task_type, \ ([](StreamIndexGenerator* generator) -> StreamId::stream_index_t { \ return generator->GenerateAnonymous(); \ })); #define REGISTER_TICK_TASK_STREAM_INDEX_GETTER(task_type) \ REGISTER_TASK_STREAM_INDEX_GETTER( \ DeviceType::kCPU, task_type, \ ([](StreamIndexGenerator* generator) -> StreamId::stream_index_t { \ return generator->GenerateNamed("TICK"); \ })); #define REGISTER_DEVICE_COMP_TASK_STREAM_INDEX_GETTER(device_type, task_type) \ REGISTER_TASK_STREAM_INDEX_GETTER( \ device_type, task_type, ([](StreamIndexGenerator* generator) -> StreamId::stream_index_t { \ return GenerateComputeTaskStreamIndex(device_type, generator); \ })); #define REGISTER_COMP_TASK_STREAM_INDEX_GETTER(task_type) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DEVICE_COMP_TASK_STREAM_INDEX_GETTER, DEVICE_TYPE_SEQ, \ (task_type)) #endif // ONEFLOW_CORE_GRAPH_TASK_STREAM_INDEX_MANAGER_H_ ================================================ FILE: oneflow/core/graph/task_type_visitor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/throw.h" #include "oneflow/core/job/task.pb.h" #include "oneflow/core/graph/collective_boxing_task_node.h" #include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h" #include "oneflow/core/graph/copy_task_node.h" #include "oneflow/core/graph/boxing_zeros_task_node.h" #include "oneflow/core/graph/slice_boxing_task_node.h" #include "oneflow/core/graph/collective_boxing_pack_task_node.h" #include "oneflow/core/graph/collective_boxing_unpack_task_node.h" #include "oneflow/core/graph/boxing_identity_task_node.h" namespace oneflow { template struct TransportTaskTypeVisitor { template static auto Visit(TaskType task_type, Args&&... args) { switch (task_type) { case TaskType::kInvalid: LOG(FATAL) << "invalid task type"; case TaskType::kCopyHd: return DerivedT::VisitCopyHd(std::forward(args)...); case TaskType::kCopyCommNet: return DerivedT::VisitCopyCommNet(std::forward(args)...); case TaskType::kSliceBoxing: return DerivedT::VisitSliceBoxing(std::forward(args)...); case TaskType::kCollectiveBoxingGeneric: return DerivedT::VisitCollectiveBoxingGeneric(std::forward(args)...); case TaskType::kBoxingIdentity: return DerivedT::VisitBoxingIdentity(std::forward(args)...); case TaskType::kNcclSendRecvBoxing: return DerivedT::VisitNcclSendRecvBoxing(std::forward(args)...); case TaskType::kBoxingZeros: return DerivedT::VisitBoxingZeros(std::forward(args)...); case TaskType::kCollectiveBoxingPack: return DerivedT::VisitCollectiveBoxingPack(std::forward(args)...); case TaskType::kCollectiveBoxingUnpack: return DerivedT::VisitCollectiveBoxingUnpack(std::forward(args)...); default: LOG(FATAL) << "invalid task type"; } } }; struct CreateTransportTask final : public TransportTaskTypeVisitor { static Maybe VisitCopyHd() { return new CopyHdTaskNode(); } static Maybe VisitCopyCommNet() { return new CopyCommNetTaskNode(); } static Maybe VisitSliceBoxing() { return new SliceBoxingTaskNode(); } static Maybe VisitCollectiveBoxingGeneric() { return new CollectiveBoxingGenericTaskNode(); } static Maybe VisitBoxingIdentity() { return new BoxingIdentityTaskNode(); } static Maybe VisitCollectiveBoxingPack() { return new CollectiveBoxingPackTaskNode(); } static Maybe VisitCollectiveBoxingUnpack() { return new CollectiveBoxingUnpackTaskNode(); } static Maybe VisitBoxingZeros() { return new BoxingZerosTaskNode(); } static Maybe VisitNcclSendRecvBoxing() { return new NcclSendRecvBoxingTaskNode(); } }; } // namespace oneflow ================================================ FILE: oneflow/core/graph/transport_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/transport_task_node.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" namespace oneflow { Maybe TransportTaskNode::InitTransportTaskFromProtoIf( const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { CHECK(has_new_task_id()); JUST(InitTransportTaskFromProto(transport_task_proto, ctx)); lbi_ = transport_task_proto.lbi(); return Maybe::Ok(); } void TransportTaskNode::ToTransportTaskProtoIf(TransportTaskProto* transport_task_proto) const { ToTransportTaskProto(transport_task_proto); *transport_task_proto->mutable_lbi() = lbi_; } } // namespace oneflow ================================================ FILE: oneflow/core/graph/transport_task_node.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_ #include "oneflow/core/graph/task_node.h" #include "oneflow/core/register/logical_blob_id.pb.h" namespace oneflow { class TransportTaskProto; class TaskGraphRebuildCtx; class TransportTaskNode : public TaskNode { public: OF_DISALLOW_COPY_AND_MOVE(TransportTaskNode); TransportTaskNode() = default; virtual ~TransportTaskNode() = default; void set_lbi(const LogicalBlobId& lbi) { lbi_ = lbi; } LogicalBlobId lbi() const { return lbi_; } Maybe InitTransportTaskFromProtoIf(const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx); void ToTransportTaskProtoIf(TransportTaskProto*) const; ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override { // TransportTaskNode infers output BlobDesc based on input BlobDesc, because it can't infers // output BlobDesc with SBP. return &ExecNode::InferBlobDescsByInputs; } private: virtual Maybe InitTransportTaskFromProto(const TransportTaskProto&, const TaskGraphRebuildCtx& ctx) = 0; virtual void ToTransportTaskProto(TransportTaskProto*) const = 0; LogicalBlobId lbi_; }; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_ ================================================ FILE: oneflow/core/graph_impl/acc_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class AccCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(AccCompTaskNode); AccCompTaskNode() = default; ~AccCompTaskNode() = default; TaskType GetTaskType() const override { return TaskType::kAcc; } void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; }; void AccCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr regst = ProduceRegst("out", false); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", regst); }); } void AccCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } void AccCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void AccCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); std::shared_ptr out_regst = GetProducedRegst("out"); ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op(); exec_node->BindBnWithRegst(op()->SoleIbn(), in_regst); out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn())); exec_node->BindBnWithRegst(op()->SoleObn(), out_regst); (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); out_regst->ForEachLbi([out_regst](const LogicalBlobId& lbi) { const BlobDesc* blob_desc = out_regst->GetBlobDesc(lbi); CHECK_EQ(blob_desc->is_dynamic(), false); }); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAcc); REGISTER_USER_OP_COMP_TASK_NODE_TYPE("acc", AccCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class AccCtrlTickCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(AccCtrlTickCompTaskNode); AccCtrlTickCompTaskNode() = default; ~AccCtrlTickCompTaskNode() = default; TaskType GetTaskType() const override { return TaskType::kAccCtrlTick; } void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void BuildExecGphAndRegst() override; void ConsumeFakeRegsts() override; }; void AccCtrlTickCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr regst = ProduceRegst("out", false); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", regst); }); } void AccCtrlTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } void AccCtrlTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); std::shared_ptr out_regst = GetProducedRegst("out"); std::shared_ptr op = this->op(); ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op; exec_node->BindBnWithRegst(op->SoleIbn(), in_regst); out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn())); exec_node->BindBnWithRegst(op->SoleObn(), out_regst); (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccCtrlTick); REGISTER_USER_OP_COMP_TASK_NODE_TYPE("acc_ctrl_tick", AccCtrlTickCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/acc_tick_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class AccTickCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(AccTickCompTaskNode); AccTickCompTaskNode() = default; ~AccTickCompTaskNode() = default; TaskType GetTaskType() const override { return TaskType::kAccTick; } void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; void AccTickCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr regst = ProduceRegst("out", false); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", regst); }); } void AccTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } void AccTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void AccTickCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); std::shared_ptr out_regst = GetProducedRegst("out"); std::shared_ptr op = this->op(); ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op; exec_node->BindBnWithRegst(op->SoleIbn(), in_regst); out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn())); exec_node->BindBnWithRegst(op->SoleObn(), out_regst); (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccTick); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kAccTickConf, AccTickCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/callback_notify_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class CallbackNotifyCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CallbackNotifyCompTaskNode); CallbackNotifyCompTaskNode() = default; ~CallbackNotifyCompTaskNode() = default; TaskType GetTaskType() const override { return TaskType::kCallbackNotify; } private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; void CallbackNotifyCompTaskNode::ProduceAllRegstsAndBindEdges() {} void CallbackNotifyCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } void CallbackNotifyCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void CallbackNotifyCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = this->op(); for (const std::string& ibn : node->op()->input_bns()) { node->BindBnWithOneOfTheRegsts(ibn, GetConsumedRegst("in")); } CHECK(node->op()->tmp_bns().empty()); CHECK(node->op()->output_bns().empty()); } REGISTER_NAMED_TASK_STREAM_INDEX_GETTER(DeviceType::kCPU, TaskType::kCallbackNotify, "CALLBACK_NOTIFY"); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kCallbackNotifyConf, CallbackNotifyCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/case_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class CaseCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CaseCompTaskNode); CaseCompTaskNode() = default; ~CaseCompTaskNode() override = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kCase; } private: void BuildExecGphAndRegst() override; void InferProducedDataRegstTimeShape() override; }; void CaseCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } void CaseCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void CaseCompTaskNode::ProduceAllRegstsAndBindEdges() { HashMap lbi2obn_id; FOR_RANGE(int64_t, obn_id, 0, op()->output_bns().size()) { CHECK(lbi2obn_id.emplace(op()->BnInOp2Lbi(GenRepeatedBn("out", obn_id)), obn_id).second); } ForEachOutDataEdge([&](TaskEdge* edge) { const OpNode* succ = GetOneSuccOpNodeOnEdge(edge); int64_t obn_id = -1; for (const std::string& ibn : succ->shared_op()->input_bns()) { const LogicalBlobId& lbi = succ->shared_op()->BnInOp2Lbi(ibn); if (lbi2obn_id.find(lbi) != lbi2obn_id.cend()) { CHECK_EQ(obn_id, -1); obn_id = lbi2obn_id.at(lbi); } } CHECK_NE(obn_id, -1); std::string name = "out_" + std::to_string(obn_id); CHECK(GetProducedRegst(name) == nullptr); edge->AddRegst("out", ProduceRegst(name, false)); }); } void CaseCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); std::shared_ptr sole_op = op(); node->mut_op() = sole_op; node->BindBnWithRegst("in", GetSoleConsumedRegst("in")); FOR_RANGE(int64_t, obn_id, 0, sole_op->output_bns().size()) { std::string name = "out_" + std::to_string(obn_id); std::shared_ptr out_regst = GetProducedRegst(name); out_regst->AddLbi(sole_op->BnInOp2Lbi(name)); node->BindBnWithRegst(name, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void CaseCompTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kCase); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kCaseConf, CaseCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class CriticalSectionWaitTickCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickCompTaskNode); CriticalSectionWaitTickCompTaskNode() = default; ~CriticalSectionWaitTickCompTaskNode() = default; bool IsMeaningLess() override { return false; } TaskType GetTaskType() const override { return TaskType::kCriticalSectionWaitTick; } private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; void CriticalSectionWaitTickCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", false, 128, 128); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void CriticalSectionWaitTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in"); ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } void CriticalSectionWaitTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); const std::list>& in_regsts = GetConsumedRegst("in"); for (const std::string& ibn : node->op()->input_bns()) { node->BindBnWithOneOfTheRegsts(ibn, in_regsts); } std::shared_ptr out_regst = GetProducedRegst("out"); for (const std::string& obn : node->op()->output_bns()) { const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(TaskType::kCriticalSectionWaitTick); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kCriticalSectionWaitTickConf, CriticalSectionWaitTickCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class DecodeH2DCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(DecodeH2DCompTaskNode); DecodeH2DCompTaskNode() = default; ~DecodeH2DCompTaskNode() override = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDecodeH2D; } private: void BuildExecGphAndRegst() override; }; void DecodeH2DCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } void DecodeH2DCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void DecodeH2DCompTaskNode::ProduceAllRegstsAndBindEdges() { auto regst_num = ParseIntegerFromEnv("ONEFLOW_DECODE_H2D_REGST_NUM", 2); std::shared_ptr out_regst = ProduceRegst("out", false, regst_num, regst_num); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst); }); ProduceRegst("tmp", false); } void DecodeH2DCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); std::shared_ptr sole_op = op(); node->mut_op() = sole_op; node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_NAMED_TASK_STREAM_INDEX_GETTER(DeviceType::kCUDA, TaskType::kDecodeH2D, "DECODE_H2D") namespace { CompTaskNode* CreateCompTaskNodeByOpDeviceType(const OperatorConf& op_conf) { if (CHECK_JUST(DeviceType4DeviceTag(op_conf.device_tag())) == DeviceType::kCUDA) { return new DecodeH2DCompTaskNode; } else { return new NormalForwardCompTaskNode; } } } // namespace REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE_WITH_FUNC(OperatorConf::kImageDecoderRandomCropResizeConf, CreateCompTaskNodeByOpDeviceType); REGISTER_USER_OP_COMP_TASK_NODE_TYPE_WITH_FUNC("raw_reader", CreateCompTaskNodeByOpDeviceType); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/device_tick_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class DeviceTickCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(DeviceTickCompTaskNode); DeviceTickCompTaskNode() = default; ~DeviceTickCompTaskNode() = default; bool IsMeaningLess() override { return false; } TaskType GetTaskType() const override { return TaskType::kDeviceTick; } private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; void DeviceTickCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", false, 1, 1); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void DeviceTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in"); ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } void DeviceTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void DeviceTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); const std::list>& in_regsts = GetConsumedRegst("in"); for (const std::string& ibn : node->op()->input_bns()) { node->BindBnWithOneOfTheRegsts(ibn, in_regsts); } std::shared_ptr out_regst = GetProducedRegst("out"); for (const std::string& obn : node->op()->output_bns()) { const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kDeviceTick); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDeviceTickConf, DeviceTickCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class DistributeConcatCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(DistributeConcatCompTaskNode); DistributeConcatCompTaskNode() = default; ~DistributeConcatCompTaskNode() = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDistributeConcat; } private: void BuildExecGphAndRegst() override; void BuildExecGphStructAndBindInRegst(); void BuildOutRegst(); }; void DistributeConcatCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", true); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void DistributeConcatCompTaskNode::ConsumeAllRegsts() { size_t cnt = 0; ForEachInDataEdge([&](TaskEdge* edge) { cnt += 1; ConsumeRegst("in", edge->GetSoleRegst()); }); CHECK_EQ(cnt, 1); } void DistributeConcatCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void DistributeConcatCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); mut_exec_gph().TopoForEachNode( [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); }); } void DistributeConcatCompTaskNode::BuildExecGphStructAndBindInRegst() { ExecNode* cur_node = mut_exec_gph().NewNode(); cur_node->mut_op() = this->op(); auto in_regst = GetSoleConsumedRegst("in"); mut_exec_gph().ForEachNode([&](ExecNode* cur_node) { const auto& ibn = cur_node->op()->input_bns().Get(parallel_ctx()->parallel_id()); cur_node->BindBnWithRegst(ibn, in_regst); CHECK(in_regst->HasLbi(cur_node->op()->BnInOp2Lbi(ibn))); }); } // namespace oneflow void DistributeConcatCompTaskNode::BuildOutRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); mut_exec_gph().ForEachNode([&](ExecNode* cur_node) { HashSet found_lbis; for (ExecEdge* out_edge : cur_node->out_edges()) { found_lbis.insert(out_edge->lbi()); } for (const std::string& obn : cur_node->op()->output_bns()) { out_regst->AddLbi(cur_node->op()->BnInOp2Lbi(obn)); cur_node->BindBnWithRegst(obn, out_regst); } }); // NOTE: we can ONLY set inplace when regst has ONLY ONE blob auto in_regst = GetSoleConsumedRegst("in"); if (in_regst->NumOfLbi() == 1) { out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id()); } } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kDistributeConcat); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDistributeConcatConf, DistributeConcatCompTaskNode); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDistributeAddConf, DistributeConcatCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/distribute_split_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class DistributeSplitCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(DistributeSplitCompTaskNode); DistributeSplitCompTaskNode() = default; ~DistributeSplitCompTaskNode() = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDistributeSplit; } private: void BuildExecGphAndRegst() override; void BuildExecGphStructAndBindInRegst(); void BuildOutRegst(); }; void DistributeSplitCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", true); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void DistributeSplitCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } void DistributeSplitCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void DistributeSplitCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); mut_exec_gph().TopoForEachNode( [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); }); } void DistributeSplitCompTaskNode::BuildExecGphStructAndBindInRegst() { ExecNode* cur_node = mut_exec_gph().NewNode(); cur_node->mut_op() = this->op(); for (const std::string& ibn : cur_node->op()->input_bns()) { cur_node->BindBnWithRegst(ibn, GetSoleConsumedRegst("in")); } } void DistributeSplitCompTaskNode::BuildOutRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); mut_exec_gph().ForEachNode([&](ExecNode* cur_node) { const auto& obn = cur_node->op()->output_bns().Get(parallel_ctx()->parallel_id()); out_regst->AddLbi(cur_node->op()->BnInOp2Lbi(obn)); cur_node->BindBnWithRegst(obn, out_regst); }); // NOTE: we can ONLY set inplace when regst has ONLY ONE blob auto in_regst = GetSoleConsumedRegst("in"); if (in_regst->NumOfLbi() == 1) { out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id()); } } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kDistributeSplit); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDistributeSplitConf, DistributeSplitCompTaskNode); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDistributeCloneConf, DistributeSplitCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class DstSubsetTickCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(DstSubsetTickCompTaskNode); DstSubsetTickCompTaskNode() = default; ~DstSubsetTickCompTaskNode() = default; bool IsMeaningLess() override { return false; } TaskType GetTaskType() const override { return TaskType::kDstSubsetTick; } private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; void DstSubsetTickCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", false, 2, 2); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void DstSubsetTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in"); ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } void DstSubsetTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void DstSubsetTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); const std::list>& in_regsts = GetConsumedRegst("in"); for (const std::string& ibn : node->op()->input_bns()) { node->TryBindBnWithOneOfTheRegsts(ibn, in_regsts); } std::shared_ptr out_regst = GetProducedRegst("out"); for (const std::string& obn : node->op()->output_bns()) { const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kDstSubsetTick); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDstSubsetTickConf, DstSubsetTickCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/esac_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class EsacCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(EsacCompTaskNode); EsacCompTaskNode() = default; ~EsacCompTaskNode() override = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override { UNIMPLEMENTED() << "EsacCompTaskNode is deprecated"; } TaskType GetTaskType() const override { return TaskType::kEsac; } private: void BuildExecGphAndRegst() override; void InferProducedDataRegstTimeShape() override; }; void EsacCompTaskNode::ConsumeAllRegsts() { HashMap lbi2ibn_id; FOR_RANGE(int64_t, ibn_id, 0, op()->input_bns().size()) { CHECK(lbi2ibn_id.emplace(op()->BnInOp2Lbi(GenRepeatedBn("in", ibn_id)), ibn_id).second); } ForEachInDataEdge([&](TaskEdge* edge) { const OpNode* pred = GetOnePredOpNodeOnEdge(edge); int64_t ibn_id = -1; for (const std::string& obn : pred->shared_op()->output_bns()) { const LogicalBlobId& lbi = pred->shared_op()->BnInOp2Lbi(obn); if (lbi2ibn_id.find(lbi) != lbi2ibn_id.cend()) { CHECK_EQ(ibn_id, -1); ibn_id = lbi2ibn_id.at(lbi); } } CHECK_NE(ibn_id, -1); ConsumeRegst("in_" + std::to_string(ibn_id), edge->GetSoleRegst()); }); } void EsacCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", false, 1, 1); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst); }); } void EsacCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); std::shared_ptr sole_op = this->op(); node->mut_op() = sole_op; FOR_RANGE(int64_t, ibn_id, 0, sole_op->input_bns().size()) { node->BindBnWithRegst(GenRepeatedBn("in", ibn_id), GetSoleConsumedRegst("in_" + std::to_string(ibn_id))); } std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi("out")); node->BindBnWithRegst("out", out_regst); (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void EsacCompTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kEsac); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kEsacConf, EsacCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/normal_forward_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { size_t RegstNum4OpSameOutputBlob(OperatorConf::OpTypeCase op_type_case) { if (IsClassRegistered(op_type_case)) { std::unique_ptr ptr; ptr.reset(NewObj(op_type_case)); return *ptr; } else { return -1; } } std::string GetOutRegstNameByObn(const std::string& obn) { return "__" + obn; } } // namespace void NormalForwardCompTaskNode::ProduceOutRegstByNameAndBlockNum(const std::string& name, size_t mem_block_num) { if (mem_block_num != -1) { CHECK_GT(mem_block_num, 0); ProduceRegst(name, false, mem_block_num, mem_block_num); } else { ProduceRegst(name, true); } } size_t RegstNum4Op(const Operator& sole_op) { size_t mem_block_num = RegstNum4OpSameOutputBlob(sole_op.op_conf().op_type_case()); if (sole_op.op_conf().has_user_conf()) { const std::string& op_type_name = sole_op.op_conf().user_conf().op_type_name(); const auto* op_reg_result = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name); CHECK(op_reg_result != nullptr) << "op_type_name " << op_type_name << " not register"; if (op_reg_result->same_output_regst_num > 0) { mem_block_num = op_reg_result->same_output_regst_num; } if (IsClassRegistered(op_type_name)) { std::unique_ptr ptr; ptr.reset(NewObj(op_type_name)); mem_block_num = *ptr; } if (op_type_name == "identity_buffer") { mem_block_num = user_op::UserOpConfWrapper(sole_op.op_conf()).attr("buffer_size"); } } return mem_block_num; } void NormalForwardCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr sole_op = op(); size_t mem_block_num = RegstNum4Op(*sole_op); // when output blob num > 1 and task node on out edge is all NormalForwardCompTaskNode , // create multi out regst by output blob name in op HashMap lbi2out_regst_name; for (const std::string& obn : sole_op->output_bns()) { const LogicalBlobId& lbi = sole_op->BnInOp2Lbi(obn); std::string out_regst_name = GetOutRegstNameByObn(obn); lbi2out_regst_name.insert({lbi, out_regst_name}); ProduceOutRegstByNameAndBlockNum(out_regst_name, mem_block_num); } ForEachOutDataEdge([&](TaskEdge* edge) { for (const LogicalBlobId& lbi : edge->GetLbis()) { auto it = lbi2out_regst_name.find(lbi); CHECK(it != lbi2out_regst_name.end()); BindEdgeWithProducedRegst(edge, it->second); } }); ProduceRegst("tmp", true); } void NormalForwardCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { for (const auto& regst : edge->GetRegsts()) { ConsumeRegst("in", regst); } }); } void NormalForwardCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void NormalForwardCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); BuildTmp7BufRegsts(); mut_exec_gph().TopoForEachNode( [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); }); } void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { ExecNode* cur_node = mut_exec_gph().NewNode(); cur_node->mut_op() = op(); const std::list>& in_regsts = GetConsumedRegst("in"); for (const std::string& ibn : cur_node->op()->input_bns()) { cur_node->BindBnWithOneOfTheRegsts(ibn, in_regsts); } } void NormalForwardCompTaskNode::BuildOutRegst() { ExecNode* exec_node = mut_exec_gph().SoleNode(); for (const std::string& obn : exec_node->op()->output_bns()) { std::string out_regst_name = GetOutRegstNameByObn(obn); std::shared_ptr out_regst = GetProducedRegst(out_regst_name); 
out_regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn)); exec_node->BindBnWithRegst(obn, out_regst); } } void NormalForwardCompTaskNode::BuildTmp7BufRegsts() { mut_exec_gph().ForEachNode([&](ExecNode* node) { node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); }); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kNormalForward); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/pack_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class PackCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(PackCompTaskNode); PackCompTaskNode() = default; ~PackCompTaskNode() override = default; TaskType GetTaskType() const override { return TaskType::kPack; } void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; private: void BuildExecGphAndRegst() override; }; void PackCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", false); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void PackCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } void PackCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void PackCompTaskNode::BuildExecGphAndRegst() { ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op(); std::shared_ptr in_regst = GetSoleConsumedRegst("in"); exec_node->BindBnWithRegst(op()->SoleIbn(), in_regst); std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn())); exec_node->BindBnWithRegst(op()->SoleObn(), out_regst); (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kPack); REGISTER_USER_OP_COMP_TASK_NODE_TYPE("pack", PackCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class ReentrantLockCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(ReentrantLockCompTaskNode); ReentrantLockCompTaskNode() = default; ~ReentrantLockCompTaskNode() = default; bool IsMeaningLess() override { return false; } TaskType GetTaskType() const override { return TaskType::kReentrantLock; } private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; void InferProducedDataRegstTimeShape() override; }; void ReentrantLockCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", false, 1, 1); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void ReentrantLockCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in"); ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } void ReentrantLockCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void ReentrantLockCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); const std::list>& in_regsts = GetConsumedRegst("in"); // no regst_desc for ibn "end" provided because TaskGraph hates cycle node->BindBnWithOneOfTheRegsts("start", in_regsts); std::shared_ptr out_regst = GetProducedRegst("out"); for (const std::string& obn : node->op()->output_bns()) { const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void ReentrantLockCompTaskNode::InferProducedDataRegstTimeShape() { std::shared_ptr time_shape(new Shape()); for (TaskEdge* edge : in_edges()) { if (edge->src_node()->GetFastestInputOutputTimeShape()) { *time_shape = *edge->src_node()->GetFastestInputOutputTimeShape(); } } CHECK_GT(time_shape->elem_cnt(), 0); ForEachProducedDataRegst([time_shape](const std::string& name, RegstDesc* regst) { *regst->mut_data_regst_time_shape() = time_shape; }); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kReentrantLock); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kReentrantLockConf, ReentrantLockCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/repeat_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class RepeatCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(RepeatCompTaskNode); RepeatCompTaskNode() = default; ~RepeatCompTaskNode() override = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kRepeat; } private: void BuildExecGphAndRegst() override; }; void RepeatCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } void RepeatCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void RepeatCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", false, 1, 1); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst); }); } void RepeatCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); ExecNode* node = mut_exec_gph().NewNode(); std::shared_ptr sole_op = op(); node->mut_op() = sole_op; node->BindBnWithRegst(sole_op->SoleIbn(), in_regst); std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); (node->*GetInferBlobDescsMethod())(parallel_ctx()); // NOTE(chengcheng): force inplace CHECK_EQ(in_regst->NumOfLbi(), 1); CHECK_EQ(out_regst->NumOfLbi(), 1); CHECK_EQ(in_regst->min_register_num(), 1); // NOTE(chengcheng): input need unreused mem in_regst->set_enable_reuse_mem(false); out_regst->set_force_inplace_consumed_regst_desc_id(in_regst->regst_desc_id()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kRepeat); REGISTER_USER_OP_COMP_TASK_NODE_TYPE("repeat", RepeatCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/source_tick_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class SourceTickCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(SourceTickCompTaskNode); SourceTickCompTaskNode() = default; ~SourceTickCompTaskNode() = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override {} void ConsumeFakeRegsts() override {} void BuildExecGphAndRegst() override; bool IsMeaningLess() override { return false; } TaskType GetTaskType() const override { return TaskType::kSourceTick; } }; void SourceTickCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", false, 2, 2); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst); }); } void SourceTickCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); for (const std::string& obn : node->op()->output_bns()) { const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kSourceTick); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kSourceTickConf, SourceTickCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class SrcSubsetTickCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(SrcSubsetTickCompTaskNode); SrcSubsetTickCompTaskNode() = default; ~SrcSubsetTickCompTaskNode() = default; bool IsMeaningLess() override { return false; } TaskType GetTaskType() const override { return TaskType::kSrcSubsetTick; } private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; void SrcSubsetTickCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", false, 2, 2); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void SrcSubsetTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in"); ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } void SrcSubsetTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void SrcSubsetTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); const std::list>& in_regsts = GetConsumedRegst("in"); for (const std::string& ibn : node->op()->input_bns()) { node->TryBindBnWithOneOfTheRegsts(ibn, in_regsts); } std::shared_ptr out_regst = GetProducedRegst("out"); for (const std::string& obn : node->op()->output_bns()) { const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kSrcSubsetTick); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kSrcSubsetTickConf, SrcSubsetTickCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/copy_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" #include "oneflow/core/framework/framework.h" namespace oneflow { class SspVariableProxyCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(SspVariableProxyCompTaskNode); SspVariableProxyCompTaskNode() = default; ~SspVariableProxyCompTaskNode() = default; void ProduceAllRegstsAndBindEdges() override { int64_t buffer_size = user_op::UserOpConfWrapper(op()->op_conf()).attr("buffer_size"); CHECK_GT(buffer_size, 0); ProduceRegst("value", false, buffer_size, buffer_size); ProduceRegst("ref", false, 1, 1); HashMap> out_regst_name2edges; ForEachOutDataEdge( [&](TaskEdge* edge) { { auto* copy_hd_node = dynamic_cast(edge->dst_node()); if (copy_hd_node != nullptr) { // The only possible regst_name is "value" because "ref" is always strictly one-to-one // connected. 
CHECK_EQ(*out_regst_name2edges["value"].insert(edge).first, edge); return; } } auto* dst_node = dynamic_cast(edge->dst_node()); CHECK(dst_node != nullptr) << "SspVariableProxyTaskNode must be consumed by CompTaskNode. got " << TaskType_Name(edge->dst_node()->GetTaskType()); for (const std::string& ibn : dst_node->op()->input_bns()) { const LogicalBlobId& dst_in_lbi = dst_node->op()->BnInOp2Lbi(ibn); if (dst_in_lbi == op()->BnInOp2Lbi("ref_0")) { CHECK_EQ(*out_regst_name2edges["ref"].insert(edge).first, edge); } else if (dst_in_lbi == op()->BnInOp2Lbi("value_0")) { CHECK_EQ(*out_regst_name2edges["value"].insert(edge).first, edge); } else { // do nothing } } }); for (const auto& pair : out_regst_name2edges) { for (TaskEdge* edge : pair.second) { BindEdgeWithProducedRegst(edge, pair.first); } } } void ConsumeAllRegsts() override { ConsumeRegst("var"); ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("var", edge->GetSoleRegst()); }); } void ConsumeFakeRegsts() override { ConsumeFakeRegst("var"); } TaskType GetTaskType() const override { return TaskType::kSspVariableProxy; } private: void BuildExecGphAndRegst() override { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); mut_exec_gph().TopoForEachNode( [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); }); } void BuildExecGphStructAndBindInRegst() { ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op(); exec_node->BindBnWithOneOfTheRegsts("var_0", GetConsumedRegst("var")); BindInplacebetweenVarAndRef(); } void BindInplacebetweenVarAndRef() { const auto& var_regst = GetSoleConsumedRegst("var"); CHECK_EQ(var_regst->NumOfLbi(), 1); CHECK_EQ(var_regst->min_register_num(), 1); CHECK_EQ(var_regst->max_register_num(), 1); const auto& ref_regst = GetProducedRegst("ref"); ref_regst->set_force_inplace_consumed_regst_desc_id(var_regst->regst_desc_id()); } void BuildOutRegst() { ExecNode* exec_node = mut_exec_gph().SoleNode(); const auto& AddLbiAndBindBn = [&](const std::string& regst_name) { // "ref_0" obn <-> "ref" regst_name // "value_0" obn <-> "value" regst_name const std::string& obn = regst_name + "_0"; const std::shared_ptr& regst = GetProducedRegst(regst_name); regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn)); exec_node->BindBnWithRegst(obn, regst); }; AddLbiAndBindBn("ref"); AddLbiAndBindBn("value"); } void InferProducedDataRegstTimeShape() override { NaiveInferProducedDataRegstTimeShape(); } }; REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kSspVariableProxy); REGISTER_USER_OP_COMP_TASK_NODE_TYPE("ssp_variable_proxy", SspVariableProxyCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/tick_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class TickCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(TickCompTaskNode); TickCompTaskNode() = default; ~TickCompTaskNode() = default; bool IsMeaningLess() override { return false; } TaskType GetTaskType() const override { return TaskType::kTick; } private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; void TickCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", false, 1, 1); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void TickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in"); ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } void TickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void TickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); const std::list>& in_regsts = GetConsumedRegst("in"); for (const std::string& ibn : node->op()->input_bns()) { node->BindBnWithOneOfTheRegsts(ibn, in_regsts); } std::shared_ptr out_regst = GetProducedRegst("out"); for (const std::string& obn : node->op()->output_bns()) { const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kTick); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kTickConf, TickCompTaskNode); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kSinkTickConf, TickCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/unpack_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class UnpackCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(UnpackCompTaskNode); UnpackCompTaskNode() = default; ~UnpackCompTaskNode() override = default; TaskType GetTaskType() const override { return TaskType::kUnpack; } void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void ConsumeFakeRegsts() override; private: void BuildExecGphAndRegst() override; }; void UnpackCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("out", false); ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); } void UnpackCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } void UnpackCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void UnpackCompTaskNode::BuildExecGphAndRegst() { ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op(); exec_node->BindBnWithRegst(op()->SoleIbn(), GetSoleConsumedRegst("in")); std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn())); exec_node->BindBnWithRegst(op()->SoleObn(), out_regst); (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kUnpack); REGISTER_USER_OP_COMP_TASK_NODE_TYPE("unpack", UnpackCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/task_stream_index_manager.h" namespace oneflow { class WaitAndSendIdsCompTaskNode final : public CompTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsCompTaskNode); WaitAndSendIdsCompTaskNode() = default; ~WaitAndSendIdsCompTaskNode() override = default; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override {} void ConsumeFakeRegsts() override {} void BuildExecGphAndRegst() override; bool IsMeaningLess() override { return false; } TaskType GetTaskType() const override { return TaskType::kWaitAndSendIds; } private: void InferProducedDataRegstTimeShape() override; }; void WaitAndSendIdsCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", false, 100, 100); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst); }); } void WaitAndSendIdsCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); for (const std::string& obn : node->op()->output_bns()) { const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void WaitAndSendIdsCompTaskNode::InferProducedDataRegstTimeShape() { std::shared_ptr time_shape(new Shape({1, 1})); ForEachProducedDataRegst([time_shape](const std::string& name, RegstDesc* regst) { *regst->mut_data_regst_time_shape() = time_shape; }); } REGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(TaskType::kWaitAndSendIds); REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kWaitAndSendIdsConf, WaitAndSendIdsCompTaskNode); } // namespace oneflow ================================================ FILE: oneflow/core/hardware/basic_device_descriptor_list.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/hardware/basic_device_descriptor_list.h" namespace oneflow { namespace hardware { BasicDeviceDescriptorList::BasicDeviceDescriptorList( std::vector> device_descriptor_list) : device_descriptor_list_(std::move(device_descriptor_list)) {} BasicDeviceDescriptorList::BasicDeviceDescriptorList() : BasicDeviceDescriptorList(std::vector>()) {} BasicDeviceDescriptorList::~BasicDeviceDescriptorList() = default; size_t BasicDeviceDescriptorList::DeviceCount() const { return device_descriptor_list_.size(); } std::shared_ptr BasicDeviceDescriptorList::GetDevice(size_t ordinal) const { if (ordinal < device_descriptor_list_.size()) { return device_descriptor_list_.at(ordinal); } else { return nullptr; } } } // namespace hardware } // namespace oneflow ================================================ FILE: oneflow/core/hardware/basic_device_descriptor_list.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_HARDWARE_BASIC_DEVICE_DESCRIPTOR_LIST_H_
#define ONEFLOW_CORE_HARDWARE_BASIC_DEVICE_DESCRIPTOR_LIST_H_

#include "oneflow/core/hardware/device_descriptor_list.h"
#include "oneflow/core/common/util.h"
#include <memory>
#include <vector>
#include <cstddef>

namespace oneflow {
namespace hardware {

class BasicDeviceDescriptorList : public DeviceDescriptorList {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BasicDeviceDescriptorList);
  explicit BasicDeviceDescriptorList(
      std::vector<std::shared_ptr<const DeviceDescriptor>> device_descriptor_list);
  BasicDeviceDescriptorList();
  ~BasicDeviceDescriptorList() override;

  size_t DeviceCount() const override;
  std::shared_ptr<const DeviceDescriptor> GetDevice(size_t ordinal) const override;

 private:
  std::vector<std::shared_ptr<const DeviceDescriptor>> device_descriptor_list_;
};

}  // namespace hardware
}  // namespace oneflow

#endif  // ONEFLOW_CORE_HARDWARE_BASIC_DEVICE_DESCRIPTOR_LIST_H_

================================================
FILE: oneflow/core/hardware/cuda_device_descriptor.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/hardware/cuda_device_descriptor.h" #ifdef WITH_CUDA #include #include #include "nlohmann/json.hpp" namespace oneflow { namespace hardware { namespace { constexpr char kJsonKeyOrdinal[] = "ordinal"; constexpr char kJsonKeyName[] = "name"; constexpr char kJsonKeyTotalGlobalMemory[] = "total_global_memory_bytes"; constexpr char kJsonKeyClockRate[] = "clock_rate_khz"; constexpr char kJsonKeyComputeCapabilityMajor[] = "compute_capability_major"; constexpr char kJsonKeyComputeCapabilityMinor[] = "compute_capability_minor"; constexpr char kJsonKeyMemoryClockRate[] = "memory_clock_rate_khz"; constexpr char kJsonKeyMemoryBusWidth[] = "memory_bus_width_bit"; constexpr char kJsonKeyPCIBusID[] = "pci_bus_id"; } // namespace struct CudaDeviceDescriptor::Impl { int32_t ordinal{}; std::string name; size_t total_global_memory_bytes{}; int32_t clock_rate_khz{}; int32_t compute_capability_major{}; int32_t compute_capability_minor{}; int32_t memory_clock_rate_khz{}; int32_t memory_bus_width_bit{}; std::string pci_bus_id; }; CudaDeviceDescriptor::CudaDeviceDescriptor() { impl_.reset(new Impl()); } CudaDeviceDescriptor::~CudaDeviceDescriptor() = default; int32_t CudaDeviceDescriptor::Ordinal() const { return impl_->ordinal; } const std::string& CudaDeviceDescriptor::Name() const { return impl_->name; } size_t CudaDeviceDescriptor::GlobalMemorySizeBytes() const { return impl_->total_global_memory_bytes; } int32_t CudaDeviceDescriptor::ClockRateKHz() const { return impl_->clock_rate_khz; } int32_t CudaDeviceDescriptor::ComputeCapabilityMajor() const { return impl_->compute_capability_major; } int32_t CudaDeviceDescriptor::ComputeCapabilityMinor() const { return impl_->compute_capability_minor; } int32_t CudaDeviceDescriptor::MemoryClockRateKHz() const { return impl_->memory_clock_rate_khz; } int32_t CudaDeviceDescriptor::MemoryBusWidthBit() const { return impl_->memory_bus_width_bit; } const std::string& CudaDeviceDescriptor::PCIBusID() const { return impl_->pci_bus_id; } std::shared_ptr CudaDeviceDescriptor::Query(int32_t ordinal) { cudaDeviceProp prop{}; const cudaError_t err = cudaGetDeviceProperties(&prop, ordinal); CHECK(err == cudaSuccess); static const std::set compiled_archs{CUDA_REAL_ARCHS}; if (compiled_archs.find(prop.major * 10 + prop.minor) == compiled_archs.cend() && compiled_archs.find(prop.major * 10) == compiled_archs.cend()) { static std::atomic once_flag(false); if (!once_flag.exchange(true)) { LOG(WARNING) << "The CUDA device '" << prop.name << "' with capability " << prop.major * 10 + prop.minor << " is not compatible with the current OneFlow installation. The current program " "may throw a 'no kernel image is available for execution " "on the device' error or hang for a long time. 
Please reinstall OneFlow " "compiled with a newer version of CUDA."; } } auto* desc = new CudaDeviceDescriptor(); desc->impl_->ordinal = ordinal; desc->impl_->name = prop.name; desc->impl_->total_global_memory_bytes = prop.totalGlobalMem; desc->impl_->clock_rate_khz = prop.clockRate; desc->impl_->compute_capability_major = prop.major; desc->impl_->compute_capability_minor = prop.minor; desc->impl_->memory_clock_rate_khz = prop.memoryClockRate; desc->impl_->memory_bus_width_bit = prop.memoryBusWidth; char pci_bus_id_buf[sizeof("00000000:00:00.0")]; if (CUDA_VERSION >= 11000 && cudaDeviceGetPCIBusId(pci_bus_id_buf, sizeof(pci_bus_id_buf), ordinal) == cudaSuccess) { for (int i = 0; i < sizeof(pci_bus_id_buf) - 1; ++i) { pci_bus_id_buf[i] = static_cast(std::tolower(pci_bus_id_buf[i])); } desc->impl_->pci_bus_id = pci_bus_id_buf; } else { desc->impl_->pci_bus_id = ""; } return std::shared_ptr(desc); } void CudaDeviceDescriptor::Serialize(std::string* serialized) const { nlohmann::json json_object; json_object[kJsonKeyOrdinal] = impl_->ordinal; json_object[kJsonKeyName] = impl_->name; json_object[kJsonKeyTotalGlobalMemory] = impl_->total_global_memory_bytes; json_object[kJsonKeyClockRate] = impl_->clock_rate_khz; json_object[kJsonKeyComputeCapabilityMajor] = impl_->compute_capability_major; json_object[kJsonKeyComputeCapabilityMinor] = impl_->compute_capability_minor; json_object[kJsonKeyMemoryClockRate] = impl_->memory_clock_rate_khz; json_object[kJsonKeyMemoryBusWidth] = impl_->memory_bus_width_bit; json_object[kJsonKeyPCIBusID] = impl_->pci_bus_id; *serialized = json_object.dump(2); } std::shared_ptr CudaDeviceDescriptor::Deserialize( const std::string& serialized) { auto json_object = nlohmann::json::parse(serialized); auto* desc = new CudaDeviceDescriptor(); desc->impl_->ordinal = json_object[kJsonKeyOrdinal]; desc->impl_->name = json_object[kJsonKeyName]; desc->impl_->total_global_memory_bytes = json_object[kJsonKeyTotalGlobalMemory]; desc->impl_->clock_rate_khz = json_object[kJsonKeyClockRate]; desc->impl_->compute_capability_major = json_object[kJsonKeyComputeCapabilityMajor]; desc->impl_->compute_capability_minor = json_object[kJsonKeyComputeCapabilityMinor]; desc->impl_->memory_clock_rate_khz = json_object[kJsonKeyMemoryClockRate]; desc->impl_->memory_bus_width_bit = json_object[kJsonKeyMemoryBusWidth]; desc->impl_->pci_bus_id = json_object[kJsonKeyPCIBusID]; return std::shared_ptr(desc); } } // namespace hardware } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/hardware/cuda_device_descriptor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_HARDWARE_CUDA_DEVICE_DESCRIPTOR_H_ #define ONEFLOW_CORE_HARDWARE_CUDA_DEVICE_DESCRIPTOR_H_ #include "oneflow/core/hardware/device_descriptor.h" #include "oneflow/core/common/util.h" #include #include #ifdef WITH_CUDA namespace oneflow { namespace hardware { constexpr char kCudaDeviceDescriptorClassName[] = "cuda"; class CudaDeviceDescriptor : public DeviceDescriptor { public: OF_DISALLOW_COPY_AND_MOVE(CudaDeviceDescriptor); ~CudaDeviceDescriptor() override; int32_t Ordinal() const; const std::string& Name() const; size_t GlobalMemorySizeBytes() const; int32_t ClockRateKHz() const; int32_t ComputeCapabilityMajor() const; int32_t ComputeCapabilityMinor() const; int32_t MemoryClockRateKHz() const; int32_t MemoryBusWidthBit() const; const std::string& PCIBusID() const; void Serialize(std::string* serialized) const; static std::shared_ptr Query(int32_t ordinal); static std::shared_ptr Deserialize(const std::string& serialized); private: CudaDeviceDescriptor(); struct Impl; std::unique_ptr impl_; }; } // namespace hardware } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_HARDWARE_CUDA_DEVICE_DESCRIPTOR_H_ ================================================ FILE: oneflow/core/hardware/cuda_device_descriptor_class.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/hardware/device_descriptor_class.h" #include "oneflow/core/hardware/cuda_device_descriptor.h" #include "oneflow/core/hardware/basic_device_descriptor_list.h" #include "oneflow/core/common/util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/common/str_util.h" #include "nlohmann/json.hpp" #ifdef WITH_CUDA #include namespace oneflow { namespace hardware { namespace { constexpr char kJsonKeyDevices[] = "devices"; } // namespace class CudaDeviceDescriptorClass : public DeviceDescriptorClass { public: CudaDeviceDescriptorClass() = default; ~CudaDeviceDescriptorClass() override = default; std::shared_ptr QueryDeviceDescriptorList() const override { int n_dev = 0; cudaError_t err = cudaGetDeviceCount(&n_dev); if (err != cudaSuccess) { LOG(WARNING) << cudaGetErrorString(err); return std::make_shared( std::vector>()); } std::vector> devices(n_dev); for (int dev = 0; dev < n_dev; ++dev) { devices.at(dev) = CudaDeviceDescriptor::Query(dev); } return std::make_shared(devices); } std::string Name() const override { return kCudaDeviceDescriptorClassName; } void SerializeDeviceDescriptorList(const std::shared_ptr& list, std::string* serialized) const override { std::vector serialized_devices; serialized_devices.reserve(list->DeviceCount()); for (size_t i = 0; i < list->DeviceCount(); ++i) { auto cuda_device = std::dynamic_pointer_cast(list->GetDevice(i)); CHECK(cuda_device); std::string serialized_device; cuda_device->Serialize(&serialized_device); serialized_devices.emplace_back(std::move(serialized_device)); } nlohmann::json json_object; json_object[kJsonKeyDevices] = serialized_devices; *serialized = json_object.dump(); } std::shared_ptr DeserializeDeviceDescriptorList( const std::string& serialized) const override { auto json_object = nlohmann::json::parse(serialized); std::vector serialized_devices = json_object[kJsonKeyDevices]; std::vector> devices(serialized_devices.size()); for (int i = 0; i < serialized_devices.size(); ++i) { devices.at(i) = CudaDeviceDescriptor::Deserialize(serialized_devices.at(i)); } return std::make_shared(devices); } void DumpDeviceDescriptorListSummary(const std::shared_ptr& list, const std::string& path) const override { for (size_t i = 0; i < list->DeviceCount(); ++i) { auto cuda_device = std::dynamic_pointer_cast(list->GetDevice(i)); CHECK(cuda_device); auto stream = TeePersistentLogStream::Create(JoinPath(path, std::to_string(i) + ".json")); std::string serialized; cuda_device->Serialize(&serialized); stream << serialized; } } }; COMMAND(DeviceDescriptorClass::RegisterClass(std::make_shared())); } // namespace hardware } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/hardware/device_descriptor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_H_ #define ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace hardware { class DeviceDescriptor { public: OF_DISALLOW_COPY_AND_MOVE(DeviceDescriptor); DeviceDescriptor() = default; virtual ~DeviceDescriptor() = default; }; } // namespace hardware } // namespace oneflow #endif // ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_H_ ================================================ FILE: oneflow/core/hardware/device_descriptor_class.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/hardware/device_descriptor_class.h" #include #include #include #include namespace oneflow { namespace hardware { namespace { class DeviceClassRegistryStorage { public: DeviceClassRegistryStorage() = default; ~DeviceClassRegistryStorage() = default; void Register(std::shared_ptr descriptor_class) { std::lock_guard lock(mutex_); const std::string name = descriptor_class->Name(); if (!name2index_.emplace(name, classes_.size()).second) { abort(); } classes_.emplace_back(std::make_shared(name), std::move(descriptor_class)); } size_t RegisteredCount() { std::lock_guard lock(mutex_); return classes_.size(); } const std::string& GetRegisteredClass(size_t index) { std::lock_guard lock(mutex_); return *classes_.at(index).first; } std::shared_ptr GetRegistered(size_t index) { std::lock_guard lock(mutex_); return classes_.at(index).second; } std::shared_ptr GetRegistered(const std::string& name) { std::lock_guard lock(mutex_); auto it = name2index_.find(name); if (it == name2index_.end()) { return std::shared_ptr(); } return classes_.at(it->second).second; } static DeviceClassRegistryStorage& Instance() { static DeviceClassRegistryStorage instance; return instance; } private: std::unordered_map name2index_; std::vector, std::shared_ptr>> classes_; std::mutex mutex_; }; } // namespace void DeviceDescriptorClass::RegisterClass( std::shared_ptr descriptor_class) { DeviceClassRegistryStorage::Instance().Register(std::move(descriptor_class)); } size_t DeviceDescriptorClass::GetRegisteredClassesCount() { return DeviceClassRegistryStorage::Instance().RegisteredCount(); } std::shared_ptr DeviceDescriptorClass::GetRegisteredClass( size_t index) { return DeviceClassRegistryStorage::Instance().GetRegistered(index); } std::shared_ptr DeviceDescriptorClass::GetRegisteredClass( const std::string& class_name) { return DeviceClassRegistryStorage::Instance().GetRegistered(class_name); } } // namespace hardware } // namespace oneflow ================================================ FILE: oneflow/core/hardware/device_descriptor_class.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_CLASS_H_ #define ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_CLASS_H_ #include "oneflow/core/hardware/device_descriptor_list.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace hardware { class DeviceDescriptorClass { public: OF_DISALLOW_COPY_AND_MOVE(DeviceDescriptorClass); DeviceDescriptorClass() = default; virtual ~DeviceDescriptorClass() = default; virtual std::shared_ptr QueryDeviceDescriptorList() const = 0; virtual std::string Name() const = 0; virtual void SerializeDeviceDescriptorList( const std::shared_ptr& list, std::string* serialized) const = 0; virtual std::shared_ptr DeserializeDeviceDescriptorList( const std::string& serialized) const = 0; virtual void DumpDeviceDescriptorListSummary( const std::shared_ptr& list, const std::string& path) const = 0; static void RegisterClass(std::shared_ptr descriptor_class); static size_t GetRegisteredClassesCount(); static std::shared_ptr GetRegisteredClass(size_t index); static std::shared_ptr GetRegisteredClass( const std::string& class_name); }; } // namespace hardware } // namespace oneflow #endif // ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_CLASS_H_ ================================================ FILE: oneflow/core/hardware/device_descriptor_list.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_LIST_H_ #define ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_LIST_H_ #include "oneflow/core/hardware/device_descriptor.h" #include "oneflow/core/common/util.h" #include #include namespace oneflow { namespace hardware { class DeviceDescriptorList { public: OF_DISALLOW_COPY_AND_MOVE(DeviceDescriptorList); DeviceDescriptorList() = default; virtual ~DeviceDescriptorList() = default; virtual size_t DeviceCount() const = 0; virtual std::shared_ptr GetDevice(size_t ordinal) const = 0; }; } // namespace hardware } // namespace oneflow #endif // ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_LIST_H_ ================================================ FILE: oneflow/core/hardware/net_ib_device_descriptor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/hardware/net_ib_device_descriptor.h" #ifdef WITH_RDMA #include "nlohmann/json.hpp" namespace oneflow { namespace hardware { namespace { constexpr char kJsonKeyOrdinal[] = "ordinal"; constexpr char kJsonKeyName[] = "name"; constexpr char kJsonKeyGUID[] = "guid"; constexpr char kJsonKeyPort[] = "port"; constexpr char kJsonKeyLankLayer[] = "link_layer"; constexpr char kJsonValueLinkLayerInfiniBand[] = "InfiniBand"; constexpr char kJsonValueLinkLayerEthernet[] = "Ethernet"; constexpr char kJsonKeyPCIBusID[] = "pci_bus_id"; void GetPCIBusID(const std::string& name, std::string* pci_bus_id) { #ifdef __linux__ const std::string device_path = "/sys/class/infiniband/" + name + "/device"; const char* device_real_path = realpath(device_path.data(), nullptr); if (device_real_path == nullptr) { return; } const std::string device_real_path_str = device_real_path; const size_t pos = device_real_path_str.rfind('/'); if (pos == std::string::npos) { return; } *pci_bus_id = device_real_path_str.substr(pos + 1); #endif } } // namespace struct NetIBDeviceDescriptor::Impl { int32_t ordinal{}; std::string name; uint64_t guid{}; uint8_t port{}; NetIBDeviceDescriptorLinkLayer link_layer{}; std::string pci_bus_id; }; NetIBDeviceDescriptor::NetIBDeviceDescriptor() { impl_.reset(new Impl()); } NetIBDeviceDescriptor::~NetIBDeviceDescriptor() = default; int32_t NetIBDeviceDescriptor::Ordinal() const { return impl_->ordinal; } const std::string& NetIBDeviceDescriptor::Name() const { return impl_->name; } uint64_t NetIBDeviceDescriptor::GUID() const { return impl_->guid; } uint8_t NetIBDeviceDescriptor::Port() const { return impl_->port; } NetIBDeviceDescriptorLinkLayer NetIBDeviceDescriptor::LinkLayer() const { return impl_->link_layer; } const std::string& NetIBDeviceDescriptor::PCIBusID() const { return impl_->pci_bus_id; } void NetIBDeviceDescriptor::Serialize(std::string* serialized) const { nlohmann::json json_object; json_object[kJsonKeyOrdinal] = impl_->ordinal; json_object[kJsonKeyName] = impl_->name; json_object[kJsonKeyGUID] = impl_->guid; json_object[kJsonKeyPort] = impl_->port; if (impl_->link_layer == kNetIBDeviceDescriptorLinkLayerInfiniBand) { json_object[kJsonKeyLankLayer] = kJsonValueLinkLayerInfiniBand; } else if (impl_->link_layer == kNetIBDeviceDescriptorLinkLayerEthernet) { json_object[kJsonKeyLankLayer] = kJsonValueLinkLayerEthernet; } else { UNIMPLEMENTED(); } json_object[kJsonKeyPCIBusID] = impl_->pci_bus_id; *serialized = json_object.dump(2); } std::shared_ptr NetIBDeviceDescriptor::Query(int32_t ordinal, ibv_context* context, uint8_t port) { CHECK(ibv::IsAvailable()); ibv_device_attr device_attr{}; if (ibv::wrapper.ibv_query_device(context, &device_attr) != 0) { VLOG(3) << "Unable to query device: " << context->device->name; return std::shared_ptr(); } ibv_port_attr port_attr{}; if (ibv::wrapper.ibv_query_port_wrap(context, port, &port_attr) != 0) { VLOG(3) << "Unable to query port: device " << context->device->name << " port " << port; return std::shared_ptr(); } if (port_attr.state != IBV_PORT_ACTIVE) { VLOG(3) << "Inactivate port: device " << context->device->name << " port " << port; 
return std::shared_ptr(); } if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND && port_attr.link_layer != IBV_LINK_LAYER_ETHERNET) { VLOG(3) << "Link layer is not supported: device " << context->device->name << " port " << port; return std::shared_ptr(); } auto* desc = new NetIBDeviceDescriptor(); desc->impl_->ordinal = ordinal; desc->impl_->name = context->device->name; desc->impl_->guid = device_attr.sys_image_guid; desc->impl_->port = port; if (port_attr.link_layer == IBV_LINK_LAYER_INFINIBAND) { desc->impl_->link_layer = kNetIBDeviceDescriptorLinkLayerInfiniBand; } else if (port_attr.link_layer == IBV_LINK_LAYER_ETHERNET) { desc->impl_->link_layer = kNetIBDeviceDescriptorLinkLayerEthernet; } else { UNIMPLEMENTED(); } GetPCIBusID(desc->impl_->name, &desc->impl_->pci_bus_id); return std::shared_ptr(desc); } std::shared_ptr NetIBDeviceDescriptor::Deserialize( const std::string& serialized) { auto json_object = nlohmann::json::parse(serialized); auto* desc = new NetIBDeviceDescriptor(); desc->impl_->ordinal = json_object[kJsonKeyOrdinal]; desc->impl_->name = json_object[kJsonKeyName]; desc->impl_->guid = json_object[kJsonKeyGUID]; desc->impl_->port = json_object[kJsonKeyPort]; const std::string link_layer_value = json_object[kJsonKeyLankLayer]; if (link_layer_value == kJsonValueLinkLayerInfiniBand) { desc->impl_->link_layer = kNetIBDeviceDescriptorLinkLayerInfiniBand; } else if (link_layer_value == kJsonValueLinkLayerEthernet) { desc->impl_->link_layer = kNetIBDeviceDescriptorLinkLayerEthernet; } else { UNIMPLEMENTED(); } desc->impl_->pci_bus_id = json_object[kJsonKeyPCIBusID]; return std::shared_ptr(desc); } } // namespace hardware } // namespace oneflow #endif // WITH_RDMA ================================================ FILE: oneflow/core/hardware/net_ib_device_descriptor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_HARDWARE_NET_IB_DEVICE_DESCRIPTOR_H_ #define ONEFLOW_CORE_HARDWARE_NET_IB_DEVICE_DESCRIPTOR_H_ #include "oneflow/core/hardware/device_descriptor.h" #include "oneflow/core/common/util.h" #include #include #ifdef WITH_RDMA #include "oneflow/core/platform/include/ibv.h" namespace oneflow { namespace hardware { constexpr char kNetIBDeviceDescriptorClassName[] = "net_ib"; enum NetIBDeviceDescriptorLinkLayer { kNetIBDeviceDescriptorLinkLayerInvalid = 0, kNetIBDeviceDescriptorLinkLayerInfiniBand = 1, kNetIBDeviceDescriptorLinkLayerEthernet = 2, }; class NetIBDeviceDescriptor : public DeviceDescriptor { public: OF_DISALLOW_COPY_AND_MOVE(NetIBDeviceDescriptor); ~NetIBDeviceDescriptor() override; int32_t Ordinal() const; const std::string& Name() const; uint64_t GUID() const; uint8_t Port() const; NetIBDeviceDescriptorLinkLayer LinkLayer() const; const std::string& PCIBusID() const; void Serialize(std::string* serialized) const; static std::shared_ptr Query(int32_t ordinal, ibv_context* context, uint8_t port); static std::shared_ptr Deserialize(const std::string& serialized); private: NetIBDeviceDescriptor(); struct Impl; std::unique_ptr impl_; }; } // namespace hardware } // namespace oneflow #endif // WITH_RDMA #endif // ONEFLOW_CORE_HARDWARE_NET_IB_DEVICE_DESCRIPTOR_H_ ================================================ FILE: oneflow/core/hardware/net_ib_device_descriptor_class.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/hardware/device_descriptor_class.h" #include "oneflow/core/hardware/net_ib_device_descriptor.h" #include "oneflow/core/hardware/basic_device_descriptor_list.h" #include "oneflow/core/common/util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/common/str_util.h" #include "nlohmann/json.hpp" #ifdef WITH_RDMA namespace oneflow { namespace hardware { namespace { constexpr char kJsonKeyDevices[] = "devices"; } // namespace class NetIBDeviceDescriptorClass : public DeviceDescriptorClass { public: NetIBDeviceDescriptorClass() = default; ~NetIBDeviceDescriptorClass() override = default; std::shared_ptr QueryDeviceDescriptorList() const override { std::vector> devices; int num_devices; if (!ibv::IsAvailable()) { return std::make_shared(devices); } ibv_device** device_list = ibv::wrapper.ibv_get_device_list(&num_devices); if (device_list == nullptr) { return std::make_shared(devices); } for (int i = 0; i < num_devices; ++i) { ibv_device* device = device_list[i]; ibv_context* context = ibv::wrapper.ibv_open_device(device); if (context == nullptr) { continue; } ibv_device_attr device_attr{}; if (ibv::wrapper.ibv_query_device(context, &device_attr) != 0) { CHECK_EQ(ibv::wrapper.ibv_close_device(context), 0); } for (int port = 1; port <= device_attr.phys_port_cnt; ++port) { auto device_desc = NetIBDeviceDescriptor::Query(static_cast(devices.size()), context, port); if (device_desc) { devices.emplace_back(device_desc); } } } ibv::wrapper.ibv_free_device_list(device_list); return std::make_shared(devices); } std::string Name() const override { return kNetIBDeviceDescriptorClassName; } void SerializeDeviceDescriptorList(const std::shared_ptr& list, std::string* serialized) const override { std::vector serialized_devices; serialized_devices.reserve(list->DeviceCount()); for (size_t i = 0; i < list->DeviceCount(); ++i) { auto ib_device = std::dynamic_pointer_cast(list->GetDevice(i)); CHECK(ib_device); std::string serialized_device; ib_device->Serialize(&serialized_device); serialized_devices.emplace_back(std::move(serialized_device)); } nlohmann::json json_object; json_object[kJsonKeyDevices] = serialized_devices; *serialized = json_object.dump(); } std::shared_ptr DeserializeDeviceDescriptorList( const std::string& serialized) const override { auto json_object = nlohmann::json::parse(serialized); std::vector serialized_devices = json_object[kJsonKeyDevices]; std::vector> devices(serialized_devices.size()); for (int i = 0; i < serialized_devices.size(); ++i) { devices.at(i) = NetIBDeviceDescriptor::Deserialize(serialized_devices.at(i)); } return std::make_shared(devices); } void DumpDeviceDescriptorListSummary(const std::shared_ptr& list, const std::string& path) const override { for (size_t i = 0; i < list->DeviceCount(); ++i) { auto ib_device = std::dynamic_pointer_cast(list->GetDevice(i)); CHECK(ib_device); auto stream = TeePersistentLogStream::Create(JoinPath(path, std::to_string(i) + ".json")); std::string serialized; ib_device->Serialize(&serialized); stream << serialized; } } }; COMMAND(DeviceDescriptorClass::RegisterClass(std::make_shared())); } // namespace hardware } // namespace oneflow #endif // WITH_RDMA ================================================ FILE: oneflow/core/hardware/net_socket_device_descriptor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef __linux__ #include "oneflow/core/hardware/net_socket_device_descriptor.h" #include "nlohmann/json.hpp" namespace oneflow { namespace hardware { namespace { constexpr char kJsonKeyOrdinal[] = "ordinal"; constexpr char kJsonKeyName[] = "name"; constexpr char kJsonKeyAddress[] = "address"; constexpr char kJsonKeyPCIBusID[] = "pci_bus_id"; void GetPCIBusID(const std::string& name, std::string* pci_bus_id) { #ifdef __linux__ const std::string device_path = "/sys/class/net/" + name + "/device"; char* device_real_path = realpath(device_path.data(), nullptr); if (device_real_path == nullptr) { return; } const std::string device_real_path_str = device_real_path; free(device_real_path); // NOLINT const size_t pos = device_real_path_str.rfind('/'); if (pos == std::string::npos) { return; } *pci_bus_id = device_real_path_str.substr(pos + 1); #endif } } // namespace struct NetSocketDeviceDescriptor::Impl { int32_t ordinal{}; std::string name; std::string address; std::string pci_bus_id; }; NetSocketDeviceDescriptor::NetSocketDeviceDescriptor() { impl_.reset(new Impl()); } NetSocketDeviceDescriptor::~NetSocketDeviceDescriptor() = default; int32_t NetSocketDeviceDescriptor::Ordinal() const { return impl_->ordinal; } const std::string& NetSocketDeviceDescriptor::Name() const { return impl_->name; } const std::string& NetSocketDeviceDescriptor::Address() const { return impl_->address; } const std::string& NetSocketDeviceDescriptor::PCIBusID() const { return impl_->pci_bus_id; } void NetSocketDeviceDescriptor::Serialize(std::string* serialized) const { nlohmann::json json_object; json_object[kJsonKeyOrdinal] = impl_->ordinal; json_object[kJsonKeyName] = impl_->name; json_object[kJsonKeyAddress] = impl_->address; json_object[kJsonKeyPCIBusID] = impl_->pci_bus_id; *serialized = json_object.dump(2); } std::shared_ptr NetSocketDeviceDescriptor::Query( int32_t ordinal, const std::string& name, const std::string& address) { auto* desc = new NetSocketDeviceDescriptor(); desc->impl_->ordinal = ordinal; desc->impl_->name = name; desc->impl_->address = address; GetPCIBusID(name, &desc->impl_->pci_bus_id); return std::shared_ptr(desc); } std::shared_ptr NetSocketDeviceDescriptor::Deserialize( const std::string& serialized) { auto json_object = nlohmann::json::parse(serialized); auto* desc = new NetSocketDeviceDescriptor(); desc->impl_->ordinal = json_object[kJsonKeyOrdinal]; desc->impl_->name = json_object[kJsonKeyName]; desc->impl_->address = json_object[kJsonKeyAddress]; desc->impl_->pci_bus_id = json_object[kJsonKeyPCIBusID]; return std::shared_ptr(desc); } } // namespace hardware } // namespace oneflow #endif // __linux__ ================================================ FILE: oneflow/core/hardware/net_socket_device_descriptor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_HARDWARE_NET_SOCKET_DEVICE_DESCRIPTOR_H_ #define ONEFLOW_CORE_HARDWARE_NET_SOCKET_DEVICE_DESCRIPTOR_H_ #ifdef __linux__ #include "oneflow/core/hardware/device_descriptor.h" #include "oneflow/core/common/util.h" #include #include #include namespace oneflow { namespace hardware { constexpr char kNetSocketDeviceDescriptorClassName[] = "net_socket"; class NetSocketDeviceDescriptor : public DeviceDescriptor { public: OF_DISALLOW_COPY_AND_MOVE(NetSocketDeviceDescriptor); ~NetSocketDeviceDescriptor() override; int32_t Ordinal() const; const std::string& Name() const; const std::string& Address() const; const std::string& PCIBusID() const; void Serialize(std::string* serialized) const; static std::shared_ptr Query(int32_t ordinal, const std::string& name, const std::string& address); static std::shared_ptr Deserialize( const std::string& serialized); private: NetSocketDeviceDescriptor(); struct Impl; std::unique_ptr impl_; }; } // namespace hardware } // namespace oneflow #endif // __linux__ #endif // ONEFLOW_CORE_HARDWARE_NET_SOCKET_DEVICE_DESCRIPTOR_H_ ================================================ FILE: oneflow/core/hardware/net_socket_device_descriptor_class.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef __linux__ #include "oneflow/core/hardware/device_descriptor_class.h" #include "oneflow/core/hardware/net_socket_device_descriptor.h" #include "oneflow/core/hardware/basic_device_descriptor_list.h" #include "oneflow/core/common/util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/common/str_util.h" #include "nlohmann/json.hpp" #include #include #include #include namespace oneflow { namespace hardware { namespace { constexpr char kJsonKeyDevices[] = "devices"; } // namespace class NetSocketDeviceDescriptorClass : public DeviceDescriptorClass { public: NetSocketDeviceDescriptorClass() = default; ~NetSocketDeviceDescriptorClass() override = default; std::shared_ptr QueryDeviceDescriptorList() const override { std::vector> devices; ifaddrs* interfaces = nullptr; if (getifaddrs(&interfaces) != 0) { return std::make_shared(); } ifaddrs* ifa = nullptr; for (ifa = interfaces; ifa != nullptr; ifa = ifa->ifa_next) { if (ifa->ifa_addr == nullptr) { continue; } const std::string name(ifa->ifa_name); if (name == "lo") { continue; } // TODO(liujuncheng): support ipv6 if (ifa->ifa_addr->sa_family != AF_INET) { continue; } if (std::count_if(devices.cbegin(), devices.cend(), [&](const std::shared_ptr& device) { return device->Name() == name; }) != 0) { continue; } char host[NI_MAXHOST]; const socklen_t sa_len = (ifa->ifa_addr->sa_family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); if (getnameinfo(ifa->ifa_addr, sa_len, host, NI_MAXHOST, nullptr, 0, NI_NUMERICHOST) != 0) { continue; } auto socket_device = NetSocketDeviceDescriptor::Query(static_cast(devices.size()), name, host); if (socket_device) { devices.emplace_back(socket_device); } } freeifaddrs(interfaces); return std::make_shared( std::vector>{devices.begin(), devices.end()}); } std::string Name() const override { return kNetSocketDeviceDescriptorClassName; } void SerializeDeviceDescriptorList(const std::shared_ptr& list, std::string* serialized) const override { std::vector serialized_devices; serialized_devices.reserve(list->DeviceCount()); for (size_t i = 0; i < list->DeviceCount(); ++i) { auto socket_device = std::dynamic_pointer_cast(list->GetDevice(i)); CHECK(socket_device); std::string serialized_device; socket_device->Serialize(&serialized_device); serialized_devices.emplace_back(std::move(serialized_device)); } nlohmann::json json_object; json_object[kJsonKeyDevices] = serialized_devices; *serialized = json_object.dump(); } std::shared_ptr DeserializeDeviceDescriptorList( const std::string& serialized) const override { auto json_object = nlohmann::json::parse(serialized); std::vector serialized_devices = json_object[kJsonKeyDevices]; std::vector> devices(serialized_devices.size()); for (int i = 0; i < serialized_devices.size(); ++i) { devices.at(i) = NetSocketDeviceDescriptor::Deserialize(serialized_devices.at(i)); } return std::make_shared(devices); } void DumpDeviceDescriptorListSummary(const std::shared_ptr& list, const std::string& path) const override { for (size_t i = 0; i < list->DeviceCount(); ++i) { auto socket_device = std::dynamic_pointer_cast(list->GetDevice(i)); CHECK(socket_device); auto stream = TeePersistentLogStream::Create(JoinPath(path, std::to_string(i) + ".json")); std::string serialized; socket_device->Serialize(&serialized); stream << serialized; } } }; COMMAND(DeviceDescriptorClass::RegisterClass(std::make_shared())); } // namespace hardware } // namespace oneflow #endif // __linux__ ================================================ FILE: 
oneflow/core/hardware/node_device_descriptor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/hardware/node_device_descriptor.h" #include "oneflow/core/hardware/device_descriptor_class.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "nlohmann/json.hpp" #ifdef WITH_HWLOC #include #endif // WITH_HWLOC namespace oneflow { namespace hardware { namespace { constexpr char kJsonKeyClasses[] = "classes"; constexpr char kJsonKeyClassName[] = "class_name"; constexpr char kJsonKeySerializedDescriptorList[] = "serialized_descriptor_list"; constexpr char kJsonKeyHostMemorySize[] = "host_memory_size_bytes"; constexpr char kJsonKeyTopology[] = "topology"; class DummyCPUAffinityDescriptor : public TopologyCPUAffinityDescriptor { public: DummyCPUAffinityDescriptor() = default; ~DummyCPUAffinityDescriptor() override = default; }; class DummyMemoryAffinityDescriptor : public TopologyMemoryAffinityDescriptor { public: DummyMemoryAffinityDescriptor() = default; ~DummyMemoryAffinityDescriptor() override = default; }; class DummyTopologyDescriptor : public TopologyDescriptor { public: DummyTopologyDescriptor() = default; ~DummyTopologyDescriptor() override = default; std::shared_ptr GetCPUAffinity() const override { return std::make_shared(); } std::shared_ptr GetMemoryAffinity() const override { return std::make_shared(); } std::shared_ptr GetCPUAffinityByPCIBusID( const std::string& bus_id) const override { return std::make_shared(); } std::shared_ptr GetMemoryAffinityByPCIBusID( const std::string& bus_id) const override { return std::make_shared(); } void SetCPUAffinity( const std::shared_ptr& affinity) const override {} void SetMemoryAffinity( const std::shared_ptr& affinity) const override {} }; #ifdef WITH_HWLOC class HWLocCPUAffinityDescriptor : public TopologyCPUAffinityDescriptor { public: OF_DISALLOW_COPY_AND_MOVE(HWLocCPUAffinityDescriptor); explicit HWLocCPUAffinityDescriptor(hwloc_cpuset_t hwloc_cpu_set) : hwloc_cpu_set_(hwloc_cpu_set) {} ~HWLocCPUAffinityDescriptor() override { hwloc_bitmap_free(hwloc_cpu_set_); } hwloc_cpuset_t HWLocCPUSet() const { return hwloc_cpu_set_; } private: hwloc_cpuset_t hwloc_cpu_set_; }; class HWLocMemoryAffinityDescriptor : public TopologyMemoryAffinityDescriptor { public: OF_DISALLOW_COPY_AND_MOVE(HWLocMemoryAffinityDescriptor); explicit HWLocMemoryAffinityDescriptor(hwloc_bitmap_t hwloc_bitmap, hwloc_membind_policy_t policy) : hwloc_bitmap_(hwloc_bitmap), policy_(policy) {} ~HWLocMemoryAffinityDescriptor() override { hwloc_bitmap_free(hwloc_bitmap_); } hwloc_bitmap_t HWLocBitmap() const { return hwloc_bitmap_; } hwloc_membind_policy_t HWLocPolicy() const { return policy_; } private: hwloc_bitmap_t hwloc_bitmap_; hwloc_membind_policy_t policy_; }; class HWLocTopologyDescriptor : public TopologyDescriptor { public: OF_DISALLOW_COPY_AND_MOVE(HWLocTopologyDescriptor); ~HWLocTopologyDescriptor() override { 
hwloc_topology_destroy(topology_); } std::shared_ptr GetCPUAffinity() const override { hwloc_bitmap_t set = hwloc_bitmap_alloc(); if (hwloc_get_cpubind(topology_, set, HWLOC_CPUBIND_THREAD) != 0) { return nullptr; } return std::make_shared(set); } std::shared_ptr GetMemoryAffinity() const override { hwloc_bitmap_t set = hwloc_bitmap_alloc(); hwloc_membind_policy_t policy{}; if (hwloc_get_membind(topology_, set, &policy, HWLOC_MEMBIND_THREAD) != 0) { return nullptr; } return std::make_shared(set, policy); } std::shared_ptr GetCPUAffinityByPCIBusID( const std::string& bus_id) const override { if (bus_id.empty()) { return nullptr; } hwloc_obj_t non_io_ancestor = GetNonIOAncestorByPCIBusID(bus_id); if (non_io_ancestor == nullptr) { return nullptr; } if (non_io_ancestor->cpuset == nullptr) { return nullptr; } return std::make_shared( hwloc_bitmap_dup(non_io_ancestor->cpuset)); } std::shared_ptr GetMemoryAffinityByPCIBusID( const std::string& bus_id) const override { if (bus_id.empty()) { return nullptr; } hwloc_obj_t non_io_ancestor = GetNonIOAncestorByPCIBusID(bus_id); if (non_io_ancestor == nullptr) { return nullptr; } if (non_io_ancestor->cpuset == nullptr) { return nullptr; } return std::make_shared( hwloc_bitmap_dup(non_io_ancestor->cpuset), HWLOC_MEMBIND_BIND); } void SetCPUAffinity( const std::shared_ptr& affinity) const override { auto hwloc_affinity = std::dynamic_pointer_cast(affinity); if (!hwloc_affinity) { return; } hwloc_set_cpubind(topology_, hwloc_affinity->HWLocCPUSet(), HWLOC_CPUBIND_THREAD); } void SetMemoryAffinity( const std::shared_ptr& affinity) const override { auto hwloc_affinity = std::dynamic_pointer_cast(affinity); if (!hwloc_affinity) { return; } hwloc_set_membind(topology_, hwloc_affinity->HWLocBitmap(), hwloc_affinity->HWLocPolicy(), HWLOC_MEMBIND_THREAD); } static std::shared_ptr Query() { hwloc_topology_t topology = nullptr; do { if (hwloc_topology_init(&topology) != 0) { break; } if (hwloc_topology_set_io_types_filter(topology, HWLOC_TYPE_FILTER_KEEP_ALL) != 0) { break; } if (hwloc_topology_load(topology) != 0) { break; } auto* desc = new HWLocTopologyDescriptor(topology); return std::shared_ptr(desc); } while (false); if (topology != nullptr) { hwloc_topology_destroy(topology); } return nullptr; } static std::shared_ptr Deserialize(const std::string& serialized) { hwloc_topology_t topology = nullptr; do { if (hwloc_topology_init(&topology) != 0) { break; } if (hwloc_topology_set_xmlbuffer(topology, serialized.data(), static_cast(serialized.size())) != 0) { break; } if (hwloc_topology_load(topology) != 0) { break; } auto* desc = new HWLocTopologyDescriptor(topology); return std::shared_ptr(desc); } while (false); if (topology != nullptr) { hwloc_topology_destroy(topology); } return nullptr; } void Serialize(std::string* serialized) const { char* buffer = nullptr; int len = 0; if (hwloc_topology_export_xmlbuffer(topology_, &buffer, &len, 0) == 0) { *serialized = buffer; hwloc_free_xmlbuffer(topology_, buffer); } } private: hwloc_obj_t GetNonIOAncestorByPCIBusID(const std::string& pci_bus_id) const { hwloc_obj_t device = hwloc_get_pcidev_by_busidstring(topology_, pci_bus_id.data()); if (device == nullptr) { return nullptr; } hwloc_obj_t non_io_ancestor = hwloc_get_non_io_ancestor_obj(topology_, device); return non_io_ancestor; } explicit HWLocTopologyDescriptor(hwloc_topology_t topology) : topology_(topology) {} hwloc_topology_t topology_{}; }; #endif // WITH_HWLOC std::shared_ptr QueryTopologyDescriptor() { std::shared_ptr topology; #ifdef WITH_HWLOC topology = 
HWLocTopologyDescriptor::Query(); #endif // WITH_HWLOC if (!topology) { topology.reset(new DummyTopologyDescriptor()); } return topology; } std::shared_ptr DeserializeTopologyDescriptor( const std::string& serialized) { std::shared_ptr topology; if (serialized.empty()) { topology.reset(new DummyTopologyDescriptor()); } else { #ifdef WITH_HWLOC topology = HWLocTopologyDescriptor::Deserialize(serialized); #else UNIMPLEMENTED(); #endif // WITH_HWLOC } if (!topology) { topology.reset(new DummyTopologyDescriptor()); } return topology; } void SerializeTopologyDescriptor(const std::shared_ptr& topology, std::string* serialized) { #ifdef WITH_HWLOC auto hwloc_topology = std::dynamic_pointer_cast(topology); if (hwloc_topology) { hwloc_topology->Serialize(serialized); } #endif // WITH_HWLOC } } // namespace struct NodeDeviceDescriptor::Impl { std::unordered_map> class_name2descriptor_list; size_t host_memory_size_bytes{}; std::shared_ptr topology; }; NodeDeviceDescriptor::NodeDeviceDescriptor() { impl_.reset(new Impl()); } NodeDeviceDescriptor::~NodeDeviceDescriptor() = default; bool NodeDeviceDescriptor::HasDeviceClass(const std::string& class_name) const { return impl_->class_name2descriptor_list.find(class_name) != impl_->class_name2descriptor_list.end(); } std::shared_ptr NodeDeviceDescriptor::GetDeviceDescriptorList( const std::string& class_name) const { auto it = impl_->class_name2descriptor_list.find(class_name); if (it != impl_->class_name2descriptor_list.end()) { return it->second; } else { return nullptr; } } std::shared_ptr NodeDeviceDescriptor::GetDevice( const std::string& class_name, size_t ordinal) const { const auto device_list = GetDeviceDescriptorList(class_name); if (device_list) { return device_list->GetDevice(ordinal); } else { return nullptr; } } size_t NodeDeviceDescriptor::HostMemorySizeBytes() const { return impl_->host_memory_size_bytes; } std::shared_ptr NodeDeviceDescriptor::Topology() const { return impl_->topology; } void NodeDeviceDescriptor::Serialize(std::string* serialized) const { nlohmann::json json_object; json_object[kJsonKeyHostMemorySize] = impl_->host_memory_size_bytes; for (const auto& pair : impl_->class_name2descriptor_list) { std::string serialized_descriptor_list; auto clz = DeviceDescriptorClass::GetRegisteredClass(pair.first); CHECK(clz); clz->SerializeDeviceDescriptorList(pair.second, &serialized_descriptor_list); json_object[kJsonKeyClasses].push_back( {{kJsonKeyClassName, clz->Name()}, {kJsonKeySerializedDescriptorList, serialized_descriptor_list}}); } std::string serialized_topology; SerializeTopologyDescriptor(impl_->topology, &serialized_topology); json_object[kJsonKeyTopology] = serialized_topology; *serialized = json_object.dump(); } void NodeDeviceDescriptor::DumpSummary(const std::string& path) const { std::string classes_base = JoinPath(path, "classes"); for (const auto& pair : impl_->class_name2descriptor_list) { auto clz = DeviceDescriptorClass::GetRegisteredClass(pair.first); CHECK(clz); clz->DumpDeviceDescriptorListSummary(pair.second, JoinPath(classes_base, pair.first)); } std::string serialized_topology; SerializeTopologyDescriptor(impl_->topology, &serialized_topology); if (!serialized_topology.empty()) { TeePersistentLogStream::Create(JoinPath(path, "topology"))->Write(serialized_topology); } } std::shared_ptr NodeDeviceDescriptor::Query() { auto* desc = new NodeDeviceDescriptor(); desc->impl_->host_memory_size_bytes = GetAvailableCpuMemSize(); const size_t num_classes = DeviceDescriptorClass::GetRegisteredClassesCount(); for 
(size_t i = 0; i < num_classes; ++i) { std::shared_ptr descriptor_class = DeviceDescriptorClass::GetRegisteredClass(i); desc->impl_->class_name2descriptor_list.emplace(descriptor_class->Name(), descriptor_class->QueryDeviceDescriptorList()); } desc->impl_->topology = QueryTopologyDescriptor(); return std::shared_ptr(desc); } std::shared_ptr NodeDeviceDescriptor::Deserialize( const std::string& serialized) { auto json_object = nlohmann::json::parse(serialized); auto* desc = new NodeDeviceDescriptor(); desc->impl_->host_memory_size_bytes = json_object[kJsonKeyHostMemorySize]; auto num_classes = json_object[kJsonKeyClasses].size(); for (int i = 0; i < num_classes; ++i) { const std::string class_name = json_object[kJsonKeyClasses].at(i)[kJsonKeyClassName]; const std::string serialized_descriptor_list = json_object[kJsonKeyClasses].at(i)[kJsonKeySerializedDescriptorList]; auto clz = DeviceDescriptorClass::GetRegisteredClass(class_name); CHECK(clz); const auto descriptor_list = clz->DeserializeDeviceDescriptorList(serialized_descriptor_list); desc->impl_->class_name2descriptor_list.emplace(class_name, descriptor_list); } desc->impl_->topology = DeserializeTopologyDescriptor(json_object[kJsonKeyTopology]); return std::shared_ptr(desc); } } // namespace hardware } // namespace oneflow ================================================ FILE: oneflow/core/hardware/node_device_descriptor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_H_ #define ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_H_ #include "oneflow/core/hardware/device_descriptor_list.h" #include "oneflow/core/hardware/topology_descriptor.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace hardware { class NodeDeviceDescriptor { public: OF_DISALLOW_COPY_AND_MOVE(NodeDeviceDescriptor); ~NodeDeviceDescriptor(); bool HasDeviceClass(const std::string& class_name) const; std::shared_ptr GetDeviceDescriptorList( const std::string& class_name) const; std::shared_ptr GetDevice(const std::string& class_name, size_t ordinal) const; size_t HostMemorySizeBytes() const; std::shared_ptr Topology() const; void Serialize(std::string* serialized) const; void DumpSummary(const std::string& path) const; static std::shared_ptr Query(); static std::shared_ptr Deserialize(const std::string& serialized); private: NodeDeviceDescriptor(); struct Impl; std::unique_ptr impl_; }; } // namespace hardware } // namespace oneflow #endif // ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_H_ ================================================ FILE: oneflow/core/hardware/node_device_descriptor_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/hardware/node_device_descriptor_manager.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace hardware { namespace { std::string MakeNodeDeviceDescriptorRpcKey(const int64_t rank) { return "NodeDeviceDescriptorRpcKey/" + std::to_string(rank); } } // namespace struct NodeDeviceDescriptorManager::Impl { Impl(int64_t rank, int64_t num_ranks) : rank(rank) { nodes.resize(num_ranks); } std::vector> nodes; int64_t rank; }; NodeDeviceDescriptorManager::NodeDeviceDescriptorManager() { impl_.reset(new Impl(GlobalProcessCtx::Rank(), GlobalProcessCtx::WorldSize())); std::shared_ptr local = NodeDeviceDescriptor::Query(); impl_->nodes.at(impl_->rank) = local; if (impl_->nodes.size() > 1) { std::string serialized_local_node; local->Serialize(&serialized_local_node); Singleton::Get()->PushKV(MakeNodeDeviceDescriptorRpcKey(impl_->rank), serialized_local_node); for (int64_t i = 0; i < impl_->nodes.size(); ++i) { if (i == impl_->rank) { continue; } Singleton::Get()->PullKV( MakeNodeDeviceDescriptorRpcKey(i), [&](const std::string& serialized) { impl_->nodes.at(i) = NodeDeviceDescriptor::Deserialize(serialized); }); } } } NodeDeviceDescriptorManager::~NodeDeviceDescriptorManager() = default; std::shared_ptr NodeDeviceDescriptorManager::GetNodeDeviceDescriptor( int64_t rank) const { CHECK_LT(rank, impl_->nodes.size()); return impl_->nodes.at(rank); } std::shared_ptr NodeDeviceDescriptorManager::GetLocalNodeDeviceDescriptor() const { return impl_->nodes.at(impl_->rank); } void NodeDeviceDescriptorManager::DumpSummary(const std::string& base) const { for (int64_t i = 0; i < impl_->nodes.size(); ++i) { impl_->nodes.at(i)->DumpSummary(JoinPath(base, "nodes", std::to_string(i))); } } } // namespace hardware } // namespace oneflow ================================================ FILE: oneflow/core/hardware/node_device_descriptor_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_MANAGER_H_
#define ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_MANAGER_H_

#include "oneflow/core/hardware/node_device_descriptor.h"
#include "oneflow/core/common/util.h"

namespace oneflow {
namespace hardware {

class NodeDeviceDescriptorManager {
 public:
  OF_DISALLOW_COPY_AND_MOVE(NodeDeviceDescriptorManager);
  NodeDeviceDescriptorManager();
  ~NodeDeviceDescriptorManager();

  std::shared_ptr<const NodeDeviceDescriptor> GetNodeDeviceDescriptor(int64_t rank) const;
  std::shared_ptr<const NodeDeviceDescriptor> GetLocalNodeDeviceDescriptor() const;
  void DumpSummary(const std::string& path) const;

 private:
  struct Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace hardware
}  // namespace oneflow

#endif  // ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_MANAGER_H_

================================================
FILE: oneflow/core/hardware/topology_descriptor.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/hardware/topology_descriptor.h"

namespace oneflow {
namespace hardware {

void TopologyDescriptor::SetCPUAffinityByPCIBusID(const std::string& bus_id) const {
  SetCPUAffinity(GetCPUAffinityByPCIBusID(bus_id));
}

void TopologyDescriptor::SetMemoryAffinityByPCIBusID(const std::string& bus_id) const {
  SetMemoryAffinity(GetMemoryAffinityByPCIBusID(bus_id));
}

}  // namespace hardware
}  // namespace oneflow

================================================
FILE: oneflow/core/hardware/topology_descriptor.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_HARDWARE_TOPOLOGY_DESCRIPTOR_H_
#define ONEFLOW_CORE_HARDWARE_TOPOLOGY_DESCRIPTOR_H_

#include <memory>
#include <string>

#include "oneflow/core/common/util.h"

namespace oneflow {
namespace hardware {

class TopologyCPUAffinityDescriptor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(TopologyCPUAffinityDescriptor);
  TopologyCPUAffinityDescriptor() = default;
  virtual ~TopologyCPUAffinityDescriptor() = default;
};

class TopologyMemoryAffinityDescriptor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(TopologyMemoryAffinityDescriptor);
  TopologyMemoryAffinityDescriptor() = default;
  virtual ~TopologyMemoryAffinityDescriptor() = default;
};

class TopologyDescriptor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(TopologyDescriptor);
  TopologyDescriptor() = default;
  virtual ~TopologyDescriptor() = default;

  virtual std::shared_ptr<const TopologyCPUAffinityDescriptor> GetCPUAffinity() const = 0;
  virtual std::shared_ptr<const TopologyMemoryAffinityDescriptor> GetMemoryAffinity() const = 0;
  virtual std::shared_ptr<const TopologyCPUAffinityDescriptor> GetCPUAffinityByPCIBusID(
      const std::string& bus_id) const = 0;
  virtual std::shared_ptr<const TopologyMemoryAffinityDescriptor> GetMemoryAffinityByPCIBusID(
      const std::string& bus_id) const = 0;
  virtual void SetCPUAffinity(
      const std::shared_ptr<const TopologyCPUAffinityDescriptor>& affinity) const = 0;
  virtual void SetMemoryAffinity(
      const std::shared_ptr<const TopologyMemoryAffinityDescriptor>& affinity) const = 0;
  virtual void SetCPUAffinityByPCIBusID(const std::string& bus_id) const;
  virtual void SetMemoryAffinityByPCIBusID(const std::string& bus_id) const;
};

}  // namespace hardware
}  // namespace oneflow

#endif  // ONEFLOW_CORE_HARDWARE_TOPOLOGY_DESCRIPTOR_H_

================================================
FILE: oneflow/core/intrusive/README.md
================================================
### Concepts and data structures

This subsystem makes it easy for users to define intrusive types. It has built-in support for the intrusive smart pointer `intrusive::shared_ptr` and for intrusive containers.

There are currently two main kinds of intrusive containers:

1. `intrusive::List`, a doubly linked list. On top of it, `intrusive::MutexedList` and `intrusive::Channel` are also provided.
2. `intrusive::SkipList`, a skip list, equivalent to a map.

To manage the lifetimes implied by element CRUD operations, intrusive containers rely on `intrusive::shared_ptr` for memory lifetime management. It differs from `std::shared_ptr` in that the reference count is embedded in the target struct itself.

### Interface

A class whose lifetime is managed by `intrusive::shared_ptr` must provide an `intrusive::Ref* mut_intrusive_ref();` method.

Because intrusive containers support more powerful iteration than the standard containers, and for the sake of performance, three kinds of iteration macros are provided (a usage sketch appears a few files below):

1. `INTRUSIVE_FOR_EACH`: the current element may be removed during iteration, and its lifetime is kept alive by an `intrusive::shared_ptr` for the duration of the step.
2. `INTRUSIVE_FOR_EACH_PTR`: the current element may be removed during iteration; the element is exposed as a raw pointer, so the macro does not manage its lifetime.
3. `INTRUSIVE_UNSAFE_FOR_EACH_PTR`: elements must not be removed during iteration, and the macro does not manage the current element's lifetime.

### Characteristics

The biggest difference between this component and boost::intrusive is that it implements complete lifetime management; it also provides container definitions that further reduce memory allocations (see intrusive::HeadFreeList).

================================================
FILE: oneflow/core/intrusive/base.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_INTRUSIVE_BASE_H_
#define ONEFLOW_CORE_INTRUSIVE_BASE_H_

namespace oneflow {
namespace intrusive {

class Base {
 public:
  void __Init__() {}
  void __Delete__() {}
};

}  // namespace intrusive
}  // namespace oneflow

#endif  // ONEFLOW_CORE_INTRUSIVE_BASE_H_

================================================
FILE: oneflow/core/intrusive/cpp_attribute.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
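// A minimal usage sketch for the iteration macros described in the intrusive
// README above; it is not part of the repository sources. The element type
// `Foo`, the container alias `FooList`, and the helpers `ShouldDrop` and
// `Visit` are hypothetical stand-ins; only the macro names and the lifetime
// guarantees are taken from the README and from for_each.h further below.
//
//   void Walk(FooList* list) {
//     // Safe form: `elem` is an intrusive::shared_ptr, so the current element
//     // stays alive for this iteration step even if it is erased from `list`.
//     INTRUSIVE_FOR_EACH(elem, list) {
//       if (ShouldDrop(*elem)) { DropFromList(list, elem); }  // removal allowed
//     }
//     // Raw-pointer form: removal during iteration is still allowed, but the
//     // macro does not extend the element's lifetime.
//     INTRUSIVE_FOR_EACH_PTR(ptr, list) { Visit(ptr); }
//     // Unsafe form: fastest, but elements must not be removed while iterating.
//     INTRUSIVE_UNSAFE_FOR_EACH_PTR(ptr, list) { Visit(ptr); }
//   }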
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_COMMON_INTRUSIVE_ATTRIBUTE_H_ #define ONEFLOW_CORE_COMMON_INTRUSIVE_ATTRIBUTE_H_ #define INTRUSIVE_PREDICT_TRUE GOOGLE_PREDICT_TRUE #define INTRUSIVE_PREDICT_FALSE GOOGLE_PREDICT_FALSE #endif // ONEFLOW_CORE_COMMON_INTRUSIVE_ATTRIBUTE_H_ ================================================ FILE: oneflow/core/intrusive/dss.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_DSS_H_ #define ONEFLOW_CORE_INTRUSIVE_DSS_H_ #include #include #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/intrusive/struct_traits.h" namespace oneflow { // DSS is short for domain specific struct #define DSS_BEGIN(field_counter, type) _DSS_BEGIN(field_counter, type) #define DSS_DEFINE_FIELD(field_counter, dss_type, field_type, field_name) \ _DSS_DEFINE_FIELD(field_counter, dss_type, field_type, field_name) #define DSS_END(field_counter, dss_type, type) _DSS_END(field_counter, dss_type, type) #define DSS_DEFINE_UNION_FIELD_VISITOR(field_counter, field_case, type7field7case_tuple_seq) \ _DSS_DEFINE_UNION_FIELD_VISITOR(field_counter, field_case, type7field7case_tuple_seq) #define DSS_GET_FIELD_COUNTER() __COUNTER__ // details #define _DSS_DEFINE_UNION_FIELD_VISITOR(field_counter, field_case, type7field7case_tuple_seq) \ private: \ template class F, typename WalkCtxType, typename DssFieldType, \ typename Enabled> \ struct __DssVisitField__ { \ template \ using PartialF = F; \ static void Call(WalkCtxType* ctx, DssFieldType* field_ptr) { \ switch (field_ptr->field_case) { \ OF_PP_FOR_EACH_TUPLE(_DSS_MAKE_UNION_FIELD_VISITOR_HOOK, type7field7case_tuple_seq) \ default:; \ } \ } \ }; \ template class F, typename WalkCtxType, typename DssFieldType, \ typename Enabled> \ struct __DssVisitVerboseField__ { \ template \ using PartialF = F; \ static void Call(WalkCtxType* ctx, DssFieldType* field_ptr, const char* __field_name__) { \ switch (field_ptr->field_case) { \ OF_PP_FOR_EACH_TUPLE(_DSS_MAKE_UNION_FIELD_VISITOR_HOOK_VERBOSE, \ type7field7case_tuple_seq) \ default:; \ } \ } \ }; \ template class F, typename WalkCtxType, \ typename DssFieldType, typename Enabled> \ struct __DssVisitStaticVerboseField__ { \ template \ using PartialF = F<__DssSelfType__, field_counter, WalkCtxType, __DssFieldType, true>; \ static void Call(WalkCtxType* ctx, const char* __oneof_name__) { \ OF_PP_FOR_EACH_TUPLE(_DSS_MAKE_UNION_FIELD_VISITOR_HOOK_STATIC_VERBOSE, \ type7field7case_tuple_seq) \ } \ }; 
\ template class F, typename WalkCtxType, typename DssFieldType, \ typename Enabled> \ struct __DssVisitFieldUntil__ { \ template \ using PartialF = F; \ static bool Call(WalkCtxType* ctx, DssFieldType* field_ptr) { \ switch (field_ptr->field_case) { \ OF_PP_FOR_EACH_TUPLE(_DSS_MAKE_UNION_FIELD_VISITOR_HOOK, type7field7case_tuple_seq) \ default:; \ } \ } \ }; #define _DSS_MAKE_UNION_FIELD_VISITOR_HOOK(field_type, field_name, field_case_value) \ case field_case_value: { \ return PartialF::Call(ctx, &field_ptr->field_name); \ } #define _DSS_MAKE_UNION_FIELD_VISITOR_HOOK_VERBOSE(field_type, field_name, field_case_value) \ case field_case_value: { \ const char* case_field_name = OF_PP_STRINGIZE(field_name); \ return PartialF::Call(ctx, &field_ptr->field_name, case_field_name); \ } #define _DSS_MAKE_UNION_FIELD_VISITOR_HOOK_STATIC_VERBOSE(field_type, field_name, \ field_case_value) \ { \ const char* case_field_name = OF_PP_STRINGIZE(field_name); \ PartialF::Call(ctx, case_field_name, __oneof_name__); \ } #define _DSS_BEGIN(field_counter, type) \ private: \ using __DssSelfType__ = type; \ \ public: \ template \ struct __DssFieldType__; \ template class F, typename WalkCtxType> \ void __WalkField__(WalkCtxType* ctx) { \ __DssFieldIter__::Call(ctx, this); \ } \ template class F, typename WalkCtxType> \ void __WalkVerboseField__(WalkCtxType* ctx) { \ __DssVerboseFieldIter__::Call(ctx, this); \ } \ template class F, typename WalkCtxType> \ static void __WalkStaticVerboseField__(WalkCtxType* ctx) { \ __DssStaticVerboseFieldIter__::Call(ctx); \ } \ template class F, typename WalkCtxType> \ bool __WalkFieldUntil__(WalkCtxType* ctx) { \ return __DssFieldIterUntil__::Call(ctx, this); \ } \ \ private: \ template class F, typename WalkCtxType, \ typename DssFieldType, typename Enabled = void> \ struct __DssVisitField__ { \ static void Call(WalkCtxType* ctx, DssFieldType* field_ptr) { \ F::Call(ctx, field_ptr); \ } \ }; \ template class F, typename WalkCtxType, \ typename DssFieldType, typename Enabled = void> \ struct __DssVisitVerboseField__ { \ static void Call(WalkCtxType* ctx, DssFieldType* field_ptr, const char* __field_name__) { \ F::Call(ctx, field_ptr, __field_name__); \ } \ }; \ template class F, \ typename WalkCtxType, typename DssFieldType, typename Enabled = void> \ struct __DssVisitStaticVerboseField__ { \ static void Call(WalkCtxType* ctx, const char* __field_name__) { \ const char* __oneof_name__ = nullptr; \ F<__DssSelfType__, tpl_fld_counter, WalkCtxType, DssFieldType, false>::Call( \ ctx, __field_name__, __oneof_name__); \ } \ }; \ template class F, typename WalkCtxType, \ typename DssFieldType, typename Enabled = void> \ struct __DssVisitFieldUntil__ { \ static bool Call(WalkCtxType* ctx, DssFieldType* field_ptr) { \ return F::Call(ctx, field_ptr); \ } \ }; \ template class F, typename WalkCtxType, \ typename Enabled = void> \ struct __DssFieldIter__ { \ static void Call(WalkCtxType* ctx, __DssSelfType__* self) { \ __DssFieldIter__::Call(ctx, self); \ } \ }; \ template class F, typename WalkCtxType, \ typename Enabled = void> \ struct __DssVerboseFieldIter__ { \ static void Call(WalkCtxType* ctx, __DssSelfType__* self) { \ __DssVerboseFieldIter__::Call(ctx, self); \ } \ }; \ template class F, \ typename WalkCtxType, typename Enabled = void> \ struct __DssStaticVerboseFieldIter__ { \ static void Call(WalkCtxType* ctx) { \ __DssStaticVerboseFieldIter__::Call(ctx); \ } \ }; \ template class F, typename WalkCtxType, \ typename Enabled = void> \ struct __DssFieldIterUntil__ { \ static bool 
Call(WalkCtxType* ctx, __DssSelfType__* self) { \ return __DssFieldIterUntil__::Call(ctx, self); \ } \ }; \ template class F, typename WalkCtxType, \ typename Enabled = void> \ struct __DssFieldReverseIter__ { \ static void Call(WalkCtxType* ctx, __DssSelfType__* self) { \ __DssFieldReverseIter__::Call(ctx, self); \ } \ }; \ template class F, typename WalkCtxType, typename Enabled> \ struct __DssFieldReverseIter__ { \ static void Call(WalkCtxType* ctx, __DssSelfType__* self) {} \ }; \ template \ struct __DssFieldAlign4Counter__ { \ static const int value = 1; \ }; \ template \ struct __DssFieldSize4Counter__ { \ static const int value = 0; \ }; \ template \ struct __DssFieldOffsetOfFieldNumber__ { \ constexpr static int Get() { \ return __DssFieldOffsetOfFieldNumber__::Get(); \ } \ }; \ template \ struct __DssFieldOffsetOfFieldNumber__ { \ constexpr static int Get() { return 0; } \ }; \ template \ struct __DssStaticAssertFieldCounter__ {}; \ \ template \ struct __DssAccumulatedAlignedSize4Counter__ { \ static const int value = \ ConstExprRoundUp<__DssAccumulatedAlignedSize4Counter__::value \ + __DssFieldSize4Counter__::value, \ __DssFieldAlign4Counter__::value>(); \ }; \ template \ struct __DssAccumulatedAlignedSize4Counter__ { \ static const int value = 0; \ }; \ \ public: \ template \ struct __DssFieldOffset4FieldIndex__ { \ static const int value = __DssAccumulatedAlignedSize4Counter__::value; \ }; #define DSS_ASSERT_VERBOSE(dss_type) \ "\n\n\n please check file " __FILE__ " (before line " OF_PP_STRINGIZE( \ __LINE__) ") carefully\n" \ " non " dss_type " member found before line " OF_PP_STRINGIZE(__LINE__) "\n\n" #define _DSS_DEFINE_FIELD(field_counter, dss_type, field_type, field) \ private: \ template class F, typename WalkCtxType, typename Enabled> \ struct __DssFieldIter__ { \ static void Call(WalkCtxType* ctx, __DssSelfType__* self) { \ __DssVisitField__field)>::Call(ctx, \ &self->field); \ __DssFieldIter__::Call(ctx, self); \ } \ }; \ template class F, typename WalkCtxType, typename Enabled> \ struct __DssVerboseFieldIter__ { \ static void Call(WalkCtxType* ctx, __DssSelfType__* self) { \ const char* __field_name__ = OF_PP_STRINGIZE(field); \ __DssVisitVerboseField__field)>::Call( \ ctx, &self->field, __field_name__); \ __DssVerboseFieldIter__::Call(ctx, self); \ } \ }; \ template class F, typename WalkCtxType, \ typename Enabled> \ struct __DssStaticVerboseFieldIter__ { \ static void Call(WalkCtxType* ctx) { \ const char* __field_name__ = OF_PP_STRINGIZE(field); \ __DssVisitStaticVerboseField__< \ field_counter, F, WalkCtxType, \ decltype(((__DssSelfType__*)nullptr)->field)>::Call(ctx, __field_name__); \ __DssStaticVerboseFieldIter__::Call(ctx); \ } \ }; \ template class F, typename WalkCtxType, typename Enabled> \ struct __DssFieldIterUntil__ { \ static bool Call(WalkCtxType* ctx, __DssSelfType__* self) { \ bool end = \ __DssVisitFieldUntil__field)>::Call( \ ctx, &self->field); \ if (end) { return true; } \ return __DssFieldIterUntil__::Call(ctx, self); \ } \ }; \ template class F, typename WalkCtxType, typename Enabled> \ struct __DssFieldReverseIter__ { \ static void Call(WalkCtxType* ctx, __DssSelfType__* self) { \ __DssVisitField__field)>::Call(ctx, \ &self->field); \ __DssFieldReverseIter__::Call(ctx, self); \ } \ }; \ template \ struct __DssFieldAlign4Counter__ { \ static const int value = alignof(field_type); \ }; \ template \ struct __DssFieldSize4Counter__ { \ static const int value = sizeof(field_type); \ }; \ template \ struct __DssFieldOffsetOfFieldNumber__ { \ 
constexpr static int Get() { \ static_assert(std::is_standard_layout<__DssSelfType__>::value, ""); \ return offsetof(__DssSelfType__, field); \ } \ }; \ template \ struct __DssStaticAssertFieldCounter__ { \ static void StaticAssert() { \ static const int kAccSize = __DssAccumulatedAlignedSize4Counter__::value; \ static_assert(kAccSize == __DssFieldOffsetOfFieldNumber__::Get(), \ DSS_ASSERT_VERBOSE(dss_type)); \ } \ }; \ \ public: \ template \ struct __DssFieldType__ { \ using type = field_type; \ }; \ [[maybe_unused]] static const int OF_PP_CAT(field, kDssFieldNumber) = field_counter; \ using OF_PP_CAT(field, DssFieldType) = field_type; \ [[maybe_unused]] static const int OF_PP_CAT(field, kDssFieldOffset) = \ __DssAccumulatedAlignedSize4Counter__::value; #define _DSS_END(field_counter, dss_type, type) \ public: \ template class F, typename WalkCtxType> \ void __ReverseWalkField__(WalkCtxType* ctx) { \ __DssFieldReverseIter__::Call(ctx, this); \ } \ \ private: \ template class F, typename WalkCtxType, typename Enabled> \ struct __DssFieldIter__ { \ static void Call(WalkCtxType* ctx, type* self) {} \ }; \ template class F, typename WalkCtxType, typename Enabled> \ struct __DssVerboseFieldIter__ { \ static void Call(WalkCtxType* ctx, type* self) {} \ }; \ template class F, typename WalkCtxType, \ typename Enabled> \ struct __DssStaticVerboseFieldIter__ { \ static void Call(WalkCtxType* ctx) {} \ }; \ template class F, typename WalkCtxType, typename Enabled> \ struct __DssFieldIterUntil__ { \ static bool Call(WalkCtxType* ctx, type* self) { return false; } \ }; \ static void __DssStaticAssertStructSize__() { \ static const int kSize = \ ConstExprRoundUp<__DssAccumulatedAlignedSize4Counter__::value, \ alignof(type)>(); \ static_assert((kSize == 0 && sizeof(type) == 1) || (kSize == sizeof(type)), \ DSS_ASSERT_VERBOSE(dss_type)); \ } template constexpr int ConstExprRoundUp() { return (x + y - 1) / y * y; } template struct GetterTrait {}; template struct GetterTrait { template static const T& Call(const T& data) { return data; } }; template struct GetterTrait { template static const T& Call(const T* data) { return *data; } }; template struct MutableTrait {}; template struct MutableTrait { template static T* Call(T* data) { return data; } }; template struct MutableTrait { template static T* Call(T** data) { return *data; } }; } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_DSS_H_ ================================================ FILE: oneflow/core/intrusive/dss_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "gtest/gtest.h" #include "oneflow/core/intrusive/dss.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace { struct Foo { DSS_BEGIN(DSS_GET_FIELD_COUNTER(), Foo); int x; int y; int* z; DSS_DEFINE_FIELD(DSS_GET_FIELD_COUNTER(), "demo dss", int, x); DSS_DEFINE_FIELD(DSS_GET_FIELD_COUNTER(), "demo dss", int, y); DSS_DEFINE_FIELD(DSS_GET_FIELD_COUNTER(), "demo dss", int*, z); DSS_END(DSS_GET_FIELD_COUNTER(), "demo dss", Foo); }; struct Bar { DSS_BEGIN(DSS_GET_FIELD_COUNTER(), Foo); DSS_END(DSS_GET_FIELD_COUNTER(), "demo dss", Bar); }; template struct IsPointer { static const bool value = std::is_pointer::value; }; template struct RemovePointer { using type = typename std::remove_pointer::type; }; template struct IsScalar { static const bool value = std::is_arithmetic::value || std::is_enum::value || std::is_same::value; }; template struct DumpFieldName { static void Call(WalkCtxType* ctx, FieldType* field, const char* field_name) { ctx->emplace_back(field_name); } }; TEST(DSS, walk_field) { Foo foo; std::vector field_names; foo.__WalkVerboseField__(&field_names); ASSERT_EQ(field_names.size(), 3); ASSERT_TRUE(field_names[0] == "x"); ASSERT_TRUE(field_names[1] == "y"); ASSERT_TRUE(field_names[2] == "z"); } template struct PushBackPtrFieldName { template static void Call(WalkCtxType* ctx, const char* field_name) {} }; template<> struct PushBackPtrFieldName { template static void Call(WalkCtxType* ctx, const char* field_name) { ctx->emplace_back(field_name); } }; template struct FilterPointerFieldName { static void Call(WalkCtxType* ctx, FieldType* field, const char* field_name) { PushBackPtrFieldName::value>::Call(ctx, field_name); } }; template struct FilterPointerFieldNameUntil { static bool Call(WalkCtxType* ctx, FieldType* field) { return true; PushBackPtrFieldName::value>::Call(ctx, ""); } }; TEST(DSS, filter_field) { Foo foo; std::vector field_names; foo.__WalkVerboseField__(&field_names); ASSERT_EQ(field_names.size(), 1); ASSERT_TRUE(field_names[0] == "z"); } TEST(DSS, filter_field_until) { Foo foo; std::vector field_names; ASSERT_TRUE(foo.__WalkFieldUntil__(&field_names)); ASSERT_TRUE(field_names.empty()); } #define DSS_DEFINE_TEST_UNION_FIELD(field_counter) \ DSS_DEFINE_FIELD(field_counter, "demo dss", UnionField, union_field); \ DSS_DEFINE_UNION_FIELD_VISITOR(field_counter, union_case, \ OF_PP_MAKE_TUPLE_SEQ(int32_t, x, 1) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, y, 2)); struct TestDssUnion { DSS_BEGIN(DSS_GET_FIELD_COUNTER(), TestDssUnion); public: struct UnionField { int32_t union_case; union { int32_t x; int64_t y; }; } union_field; DSS_DEFINE_TEST_UNION_FIELD(DSS_GET_FIELD_COUNTER()); DSS_END(DSS_GET_FIELD_COUNTER(), "demo dss", TestDssUnion); }; template struct StaticDumpFieldName { static void Call(WalkCtxType* ctx, const char* field_name, const char* oneof_name) { ctx->emplace_back(field_name); ctx->emplace_back(oneof_name); } }; TEST(DSS, union_field) { TestDssUnion foo; foo.union_field.union_case = 0; { std::vector field_names; foo.__WalkVerboseField__(&field_names); ASSERT_EQ(field_names.size(), 0); } foo.union_field.union_case = 1; { std::vector field_names; foo.__WalkVerboseField__(&field_names); ASSERT_EQ(field_names.size(), 1); ASSERT_EQ(field_names.at(0), "x"); } foo.union_field.union_case = 2; { std::vector field_names; foo.__WalkVerboseField__(&field_names); ASSERT_EQ(field_names.size(), 1); ASSERT_EQ(field_names.at(0), "y"); } } TEST(DSS, static_verbose_field) { std::vector field_names; TestDssUnion::__WalkStaticVerboseField__(&field_names); 
ASSERT_EQ(field_names.size(), 4); ASSERT_EQ(field_names.at(0), "x"); ASSERT_EQ(field_names.at(1), "union_field"); ASSERT_EQ(field_names.at(2), "y"); ASSERT_EQ(field_names.at(3), "union_field"); } } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/flat_msg.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_FLAT__H_ #define ONEFLOW_CORE_INTRUSIVE_FLAT__H_ #include #include #include "oneflow/core/common/throw.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/intrusive/dss.h" #include "oneflow/core/intrusive/static_counter.h" namespace oneflow { #define FLAT_MSG_BEGIN(struct_name) \ struct struct_name final { \ using self_type = struct_name; \ using self_value_type = struct_name; \ static const bool __is_flat_message_type__ = true; \ \ public: \ DEFINE_STATIC_COUNTER(field_counter); \ DSS_BEGIN(STATIC_COUNTER(field_counter), struct_name); \ FLAT_MSG_DEFINE_BASIC_METHODS(struct_name); \ FLAT_MSG_DEFINE_DEFAULT(struct_name); #define FLAT_MSG_END(struct_name) \ static_assert(__is_flat_message_type__, "this struct is not a flat message"); \ \ public: \ [[maybe_unused]] static const int __NumberOfFields__ = STATIC_COUNTER(field_counter); \ \ public: \ INCREASE_STATIC_COUNTER(field_counter); \ DSS_END(STATIC_COUNTER(field_counter), "flat message", struct_name); \ } \ ; #define FLAT_MSG_DEFINE_OPTIONAL(field_type, field_name) \ static_assert(__is_flat_message_type__, "this struct is not a flat message"); \ FLAT_MSG_DEFINE_ONEOF(OF_PP_CAT(__flat_msg_optional__, field_name), \ FLAT_MSG_ONEOF_FIELD(field_type, field_name)) #define FLAT_MSG_DEFINE_ONEOF(oneof_name, type_and_field_name_seq) \ _FLAT_MSG_DEFINE_ONEOF(_FLAT_MSG_DEFINE_NOTHING, oneof_name, type_and_field_name_seq); #define FLAT_MSG_DEFINE_STRICT_ONEOF(oneof_name, type_and_field_name_seq) \ _FLAT_MSG_DEFINE_ONEOF(_FLAT_MSG_DEFINE_ONEOF_VALUE4TYPE, oneof_name, type_and_field_name_seq); #define _FLAT_MSG_DEFINE_ONEOF(define_field_value4field_type, oneof_name, type_and_field_name_seq) \ static_assert(__is_flat_message_type__, "this struct is not a flat message"); \ FLAT_MSG_DEFINE_ONEOF_ENUM_TYPE(oneof_name, type_and_field_name_seq); \ FLAT_MSG_DEFINE_ONEOF_UNION(define_field_value4field_type, oneof_name, type_and_field_name_seq); \ FLAT_MSG_DEFINE_ONEOF_ACCESSOR(oneof_name, type_and_field_name_seq) \ public: \ INCREASE_STATIC_COUNTER(field_counter); \ FLAT_MSG_DSS_DEFINE_UION_FIELD(STATIC_COUNTER(field_counter), oneof_name, \ type_and_field_name_seq); #define FLAT_MSG_DEFINE_REPEATED(field_type, field_name, max_size) \ static_assert(__is_flat_message_type__, "this struct is not a flat message"); \ _FLAT_MSG_DEFINE_REPEATED_FIELD(FLAT_MSG_TYPE_CHECK(field_type), field_name, max_size); \ \ public: \ INCREASE_STATIC_COUNTER(field_counter); \ DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), "flat message", \ OF_PP_CAT(field_name, _RepeatedField), 
OF_PP_CAT(field_name, _)); #define FLAT_MSG_DEFINE_COMPARE_OPERATORS_BY_MEMCMP() _FLAT_MSG_DEFINE_COMPARE_OPERATORS_BY_MEMCMP() #define FLAT_MSG_ONEOF_FIELD(field_type, field_name) \ OF_PP_MAKE_TUPLE_SEQ(FLAT_MSG_TYPE_CHECK(field_type), field_name) #define FLAT_MSG_ONEOF_CASE(oneof_name) _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name) #define FLAT_MSG_ONEOF_CASE_VALUE(field) _FLAT_MSG_ONEOF_ENUM_VALUE(field) #define FLAT_MSG_ONEOF_NOT_SET_VALUE(field_type, oneof_name) \ field_type::_FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name) #define FLAT_MSG_TYPE_CHECK(type_name) FlatMsgSelfType::type // details #define FLAT_MSG_DSS_DEFINE_UION_FIELD(field_counter, oneof_name, type_and_field_name_seq) \ DSS_DEFINE_FIELD(field_counter, "flat message", OF_PP_CAT(oneof_name, _OneofType), \ OF_PP_CAT(oneof_name, _)); \ DSS_DEFINE_UNION_FIELD_VISITOR( \ field_counter, case_, \ OF_PP_FOR_EACH_TUPLE(FLAT_MSG_MAKE_UNION_TYPE7FIELD4CASE, type_and_field_name_seq)); #define FLAT_MSG_MAKE_UNION_TYPE7FIELD4CASE(field_type, field_name) \ OF_PP_MAKE_TUPLE_SEQ(field_type, OF_PP_CAT(field_name, _), _FLAT_MSG_ONEOF_ENUM_VALUE(field_name)) template struct FlatMsgSelfType { static_assert(T::__is_flat_message_type__, "T is not a flat message type"); using type = T; }; template struct FlatMsgSelfType< T, typename std::enable_if::value || std::is_enum::value>::type> { using type = T; }; template struct FlatMsg final { using value_type = T; using self_value_type = value_type; FlatMsg() { msg_.clear(); } FlatMsg(const FlatMsg& rhs) { msg_.CopyFrom(rhs.msg_); } FlatMsg(const T& msg) { msg_.CopyFrom(msg); } const value_type& operator*() const { return msg_; } value_type& operator*() { return msg_; } const value_type* operator->() const { return &msg_; } value_type* operator->() { return &msg_; } const value_type& Get() const { return msg_; } value_type* Mutable() { return &msg_; } template bool operator==(const RhsT& rhs) const { static_assert(std::is_same::value, ""); return msg_ == rhs.msg_; } template bool operator!=(const RhsT& rhs) const { static_assert(std::is_same::value, ""); return msg_ != rhs.msg_; } template bool operator>=(const RhsT& rhs) const { static_assert(std::is_same::value, ""); return msg_ >= rhs.msg_; } template bool operator<=(const RhsT& rhs) const { static_assert(std::is_same::value, ""); return msg_ <= rhs.msg_; } template bool operator>(const RhsT& rhs) const { static_assert(std::is_same::value, ""); return msg_ > rhs.msg_; } template bool operator<(const RhsT& rhs) const { static_assert(std::is_same::value, ""); return msg_ < rhs.msg_; } private: union { value_type msg_; }; }; #define FLAT_MSG_DEFINE_DEFAULT(flat_msg_type_name) \ const flat_msg_type_name& __Default__() const { \ static const FlatMsg default_flat_msg; \ return default_flat_msg.Get(); \ } template struct FlatMsgIsScalar final { static const bool value = std::is_arithmetic::value || std::is_enum::value; }; template struct FlatMsgGetDefault final { template static const T& Call(const T* val) { return val->__Default__(); } }; template<> struct FlatMsgGetDefault final { template static const T& Call(const T* val) { return *val; } }; #define _FLAT_MSG_ONEOF_CASE_NAME(oneof_name) OF_PP_CAT(oneof_name, _case) #define _FLAT_MSG_ONEOF_ENUM_VALUE(field) SNAKE_TO_CAMEL(field) #define _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name) SNAKE_TO_CAMEL(oneof_name) #define _FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name) OF_PP_CAT(k_, OF_PP_CAT(oneof_name, _not_set)) #define FLAT_MSG_DEFINE_BASIC_METHODS(T) _FLAT_MSG_DEFINE_BASIC_METHODS(T) #define _FLAT_MSG_DEFINE_BASIC_METHODS(T) \ 
public: \ void clear() { std::memset(reinterpret_cast(this), 0, sizeof(T)); } \ void CopyFrom(const self_type& rhs) { \ std::memcpy(reinterpret_cast(this), reinterpret_cast(&rhs), \ sizeof(self_type)); \ } #define FLAT_MSG_DEFINE_ONEOF_ENUM_TYPE(oneof_name, type_and_field_name_seq) \ public: \ enum _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name) { \ _FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name) = 0, \ OF_PP_FOR_EACH_TUPLE(MAKE_FLAT_MSG_ONEOF_ENUM_CASE, type_and_field_name_seq) \ } #define MAKE_FLAT_MSG_ONEOF_ENUM_CASE(field_type, field_name) \ _FLAT_MSG_ONEOF_ENUM_VALUE(field_name), #define FLAT_MSG_DEFINE_ONEOF_ACCESSOR(oneof_name, type_and_field_name_seq) \ _FLAT_MSG_DEFINE_ONEOF_CASE_ACCESSOR(oneof_name, _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name)); \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_FLAT_MSG_ONEOF_ACCESSOR, (_FLAT_MSG_ONEOF_ENUM_VALUE), \ (oneof_name), type_and_field_name_seq) #define MAKE_FLAT_MSG_ONEOF_ACCESSOR(get_enum_value, oneof_name, pair) \ public: \ const OF_PP_PAIR_FIRST(pair) & OF_PP_PAIR_SECOND(pair)() const { \ if (OF_PP_CAT(has_, OF_PP_PAIR_SECOND(pair))()) { \ return OF_PP_CAT(oneof_name, _).OF_PP_CAT(OF_PP_PAIR_SECOND(pair), _); \ } \ return FlatMsgGetDefault::value>::Call( \ &OF_PP_CAT(oneof_name, _).OF_PP_CAT(OF_PP_PAIR_SECOND(pair), _)); \ } \ bool OF_PP_CAT(has_, OF_PP_PAIR_SECOND(pair))() const { \ return _FLAT_MSG_ONEOF_CASE_NAME(oneof_name)() == get_enum_value(OF_PP_PAIR_SECOND(pair)); \ } \ void OF_PP_CAT(clear_, OF_PP_PAIR_SECOND(pair))() { \ if (!OF_PP_CAT(has_, OF_PP_PAIR_SECOND(pair))()) { return; } \ OF_PP_CAT(set_, _FLAT_MSG_ONEOF_CASE_NAME(oneof_name)) \ (_FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name)); \ } \ OF_PP_PAIR_FIRST(pair) * OF_PP_CAT(mut_, OF_PP_PAIR_SECOND(pair))() { \ OF_PP_CAT(set_, _FLAT_MSG_ONEOF_CASE_NAME(oneof_name)) \ (get_enum_value(OF_PP_PAIR_SECOND(pair))); \ return &OF_PP_CAT(oneof_name, _).OF_PP_CAT(OF_PP_PAIR_SECOND(pair), _); \ } \ OF_PP_PAIR_FIRST(pair) * OF_PP_CAT(mutable_, OF_PP_PAIR_SECOND(pair))() { \ OF_PP_CAT(set_, _FLAT_MSG_ONEOF_CASE_NAME(oneof_name)) \ (get_enum_value(OF_PP_PAIR_SECOND(pair))); \ return &OF_PP_CAT(oneof_name, _).OF_PP_CAT(OF_PP_PAIR_SECOND(pair), _); \ } \ void OF_PP_CAT(set_, OF_PP_PAIR_SECOND(pair))(const OF_PP_PAIR_FIRST(pair) & val) { \ *OF_PP_CAT(mutable_, OF_PP_PAIR_SECOND(pair))() = val; \ } #define FLAT_MSG_DEFINE_ONEOF_UNION(define_field_value4field_type, oneof_name, \ type_and_field_name_seq) \ public: \ struct OF_PP_CAT(oneof_name, _OneofType) { \ public: \ using self_oneof_type = OF_PP_CAT(oneof_name, _OneofType); \ using self_oneof_case_type = _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name); \ template \ struct FieldType4FieldValueStruct {}; \ template \ struct HasStruct {}; \ template \ struct GetStruct {}; \ template \ struct MutableStruct {}; \ OF_PP_FOR_EACH_TUPLE(_MAKE_FLAT_MSG_ONEOF_TEMPLATE_ACCESSOR, type_and_field_name_seq); \ define_field_value4field_type(type_and_field_name_seq); \ template \ bool Has() const { \ return HasStruct::Call(*this); \ } \ template \ const typename FieldType4FieldValueStruct::type& Get() const { \ return GetStruct::Call(*this); \ } \ template \ typename FieldType4FieldValueStruct::type* Mutable() { \ return MutableStruct::Call(this); \ } \ \ union { \ OF_PP_FOR_EACH_TUPLE(MAKE_FLAT_MSG_ONEOF_UNION_FIELD, type_and_field_name_seq) \ }; \ self_oneof_case_type case_; \ }; \ \ private: \ OF_PP_CAT(oneof_name, _OneofType) OF_PP_CAT(oneof_name, _); \ \ public: \ const OF_PP_CAT(oneof_name, _OneofType) & oneof_name() const { \ return OF_PP_CAT(oneof_name, _); \ } \ OF_PP_CAT(oneof_name, 
_OneofType) * OF_PP_CAT(mutable_, oneof_name)() { \ return &OF_PP_CAT(oneof_name, _); \ } #define _MAKE_FLAT_MSG_ONEOF_TEMPLATE_ACCESSOR(field_type, field_name) \ public: \ template \ struct FieldType4FieldValueStruct<_FLAT_MSG_ONEOF_ENUM_VALUE(field_name), Enabled> { \ using type = field_type; \ }; \ template \ struct HasStruct<_FLAT_MSG_ONEOF_ENUM_VALUE(field_name), Enabled> { \ static bool Call(const self_oneof_type& self) { \ return self.case_ == _FLAT_MSG_ONEOF_ENUM_VALUE(field_name); \ } \ }; \ template \ struct GetStruct<_FLAT_MSG_ONEOF_ENUM_VALUE(field_name), Enabled> { \ static const field_type& Call(const self_oneof_type& self) { \ return self.OF_PP_CAT(field_name, _); \ } \ }; \ template \ struct MutableStruct<_FLAT_MSG_ONEOF_ENUM_VALUE(field_name), Enabled> { \ static field_type* Call(self_oneof_type* self) { \ self->case_ = _FLAT_MSG_ONEOF_ENUM_VALUE(field_name); \ return &self->OF_PP_CAT(field_name, _); \ } \ }; #define _FLAT_MSG_DEFINE_NOTHING(type_and_field_name_seq) #define _FLAT_MSG_DEFINE_ONEOF_VALUE4TYPE(type_and_field_name_seq) \ public: \ template \ struct FieldValue4FieldType {}; \ OF_PP_FOR_EACH_TUPLE(_MAKE_FLAT_MSG_ONEOF_VALUE4TYPE, type_and_field_name_seq); \ template \ bool HasField() const { \ return Has::value>(); \ } \ template \ const T& GetField() const { \ return Get::value>(); \ } \ template \ T* MutableField() { \ return Mutable::value>(); \ } #define _MAKE_FLAT_MSG_ONEOF_VALUE4TYPE(field_type, field_name) \ template \ struct FieldValue4FieldType { \ static const self_oneof_case_type value = _FLAT_MSG_ONEOF_ENUM_VALUE(field_name); \ }; #define MAKE_FLAT_MSG_ONEOF_UNION_FIELD(field_type, field_name) field_type OF_PP_CAT(field_name, _); #define SNAKE_TO_CAMEL(name) OF_PP_CAT(__FlatMsgSnakeToCamel__, name) #define _FLAT_MSG_DEFINE_ONEOF_CASE_ACCESSOR(oneof_name, T) \ public: \ T OF_PP_CAT(oneof_name, _case)() const { return OF_PP_CAT(oneof_name, _).case_; } \ bool OF_PP_CAT(has_, oneof_name)() const { \ return OF_PP_CAT(oneof_name, _).case_ != _FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name); \ } \ \ private: \ void OF_PP_CAT(set_, OF_PP_CAT(oneof_name, _case))(T val) { \ OF_PP_CAT(oneof_name, _).case_ = val; \ } #define _FLAT_MSG_DEFINE_REPEATED_FIELD(T, field_name, N) \ public: \ using OF_PP_CAT(field_name, _RepeatedField) = FlatMsgRepeatedField; \ std::size_t OF_PP_CAT(field_name, _size)() const { return OF_PP_CAT(field_name, _).size(); } \ const OF_PP_CAT(field_name, _RepeatedField) & field_name() const { \ return OF_PP_CAT(field_name, _); \ } \ const T& field_name(int32_t i) const { return OF_PP_CAT(field_name, _).Get(i); } \ OF_PP_CAT(field_name, _RepeatedField) * OF_PP_CAT(mut_, field_name)() { \ return &OF_PP_CAT(field_name, _); \ } \ OF_PP_CAT(field_name, _RepeatedField) * OF_PP_CAT(mutable_, field_name)() { \ return &OF_PP_CAT(field_name, _); \ } \ T* OF_PP_CAT(mut_, field_name)(int32_t i) { return OF_PP_CAT(field_name, _).Mutable(i); } \ T* OF_PP_CAT(mutable_, field_name)(int32_t i) { return OF_PP_CAT(field_name, _).Mutable(i); } \ T* OF_PP_CAT(add_, field_name)() { return OF_PP_CAT(field_name, _).Add(); } \ void OF_PP_CAT(clear_, field_name)() { OF_PP_CAT(field_name, _).clear(); } \ \ private: \ OF_PP_CAT(field_name, _RepeatedField) \ OF_PP_CAT(field_name, _); #define _FLAT_MSG_DEFINE_COMPARE_OPERATORS_BY_MEMCMP() \ public: \ bool operator<(const self_type& rhs) const { \ return std::memcmp(reinterpret_cast(this), reinterpret_cast(&rhs), \ sizeof(self_type)) \ < 0; \ } \ bool operator<=(const self_type& rhs) const { \ return 
std::memcmp(reinterpret_cast(this), reinterpret_cast(&rhs), \ sizeof(self_type)) \ <= 0; \ } \ bool operator==(const self_type& rhs) const { \ return std::memcmp(reinterpret_cast(this), reinterpret_cast(&rhs), \ sizeof(self_type)) \ == 0; \ } \ bool operator!=(const self_type& rhs) const { \ return std::memcmp(reinterpret_cast(this), reinterpret_cast(&rhs), \ sizeof(self_type)) \ != 0; \ } \ bool operator>(const self_type& rhs) const { \ return std::memcmp(reinterpret_cast(this), reinterpret_cast(&rhs), \ sizeof(self_type)) \ > 0; \ } \ bool operator>=(const self_type& rhs) const { \ return std::memcmp(reinterpret_cast(this), reinterpret_cast(&rhs), \ sizeof(self_type)) \ >= 0; \ } template class FlatMsgRepeatedField final { public: using value_type = T; static const int capacity = N; bool empty() const { return size_ == 0; } std::size_t size() const { return size_; } void clear() { size_ = 0; } T* begin() { return &data_[0]; } T* end() { CHECK_GE(size_, 0); CHECK_LE(size_, N); return &data_[size_]; } const T* begin() const { return &data_[0]; } const T* end() const { CHECK_GE(size_, 0); CHECK_LE(size_, N); return &data_[size_]; } const T& Get(int32_t index) const { CHECK_GE(index, 0); CHECK_LT(index, N); return data_[index]; } T* Mutable(int32_t index) { CHECK_GE(index, 0); CHECK_LT(index, N); return &data_[index]; } const T* data() const { return &Get(0); } T* data() { return Mutable(0); } T* mut_data() { return Mutable(0); } T* Add() { CHECK_GE(size_, 0); CHECK_LT(size_, N); return &data_[size_++]; } private: std::size_t size_; std::array data_; }; } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_FLAT__H_ ================================================ FILE: oneflow/core/intrusive/flat_msg_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "gtest/gtest.h" #include "oneflow/core/common/util.h" #include "oneflow/core/intrusive/flat_msg.h" namespace oneflow { namespace { template struct DumpFieldName { static void Call(WalkCtxType* ctx, FieldType* field, const char* field_name) { ctx->emplace_back(field_name); } }; template std::vector GetFieldNames(T* flat_msg) { std::vector field_names; flat_msg->template __WalkVerboseField__(&field_names); return field_names; } template void CheckSoleFieldName(T* flat_msg, const std::string& expected) { const auto& field_names = GetFieldNames(flat_msg); ASSERT_EQ(field_names.size(), 1); ASSERT_EQ(field_names.at(0), expected); } // clang-format off FLAT_MSG_BEGIN(TestOptional) FLAT_MSG_DEFINE_OPTIONAL(int32_t, bar); FLAT_MSG_END(TestOptional) // clang-format on TEST(FlatMsg, optional) { static_assert(std::is_trivial::value, "TestOptional is not trivial"); FlatMsg foo_box; auto& foo = *foo_box.Mutable(); ASSERT_TRUE(!foo.has_bar()); ASSERT_EQ(foo.bar(), 0); ASSERT_TRUE(GetFieldNames(&foo).empty()); *foo.mutable_bar() = 9527; ASSERT_TRUE(foo.has_bar()); ASSERT_EQ(foo.bar(), 9527); auto field_names = GetFieldNames(&foo); ASSERT_EQ(field_names.size(), 1); ASSERT_EQ(field_names.at(0), "bar_"); } // clang-format off FLAT_MSG_BEGIN(FooOneof) FLAT_MSG_DEFINE_ONEOF(type, FLAT_MSG_ONEOF_FIELD(int32_t, case_0) FLAT_MSG_ONEOF_FIELD(int64_t, case_1) FLAT_MSG_ONEOF_FIELD(TestOptional, bar)); FLAT_MSG_END(FooOneof) // clang-format on TEST(FlatMsg, oneof) { FlatMsg foo_box; auto& foo = *foo_box.Mutable(); ASSERT_TRUE(GetFieldNames(&foo).empty()); ASSERT_TRUE(!foo.has_bar()); ASSERT_EQ(foo.bar().bar(), 0); foo.mutable_case_0(); CheckSoleFieldName(&foo, "case_0_"); ASSERT_TRUE(foo.has_case_0()); FooOneof::FLAT_MSG_ONEOF_CASE(type) x = foo.type_case(); ASSERT_TRUE(x == FooOneof::FLAT_MSG_ONEOF_CASE_VALUE(case_0)); *foo.mutable_case_1() = 9527; CheckSoleFieldName(&foo, "case_1_"); ASSERT_TRUE(foo.has_case_1()); ASSERT_EQ(foo.case_1(), 9527); } // clang-format off FLAT_MSG_BEGIN(FooRepeated) FLAT_MSG_DEFINE_REPEATED(char, char_field, 1); FLAT_MSG_DEFINE_REPEATED(TestOptional, bar, 10); FLAT_MSG_END(FooRepeated) // clang-format on TEST(FlatMsg, repeated) { FlatMsg foo_box; auto& foo = *foo_box.Mutable(); ASSERT_EQ(foo.bar_size(), 0); ASSERT_EQ(foo.bar().size(), 0); auto* bar = foo.mutable_bar()->Add(); ASSERT_TRUE(!bar->has_bar()); ASSERT_EQ(foo.bar_size(), 1); ASSERT_EQ(foo.bar().size(), 1); bar->set_bar(9527); ASSERT_TRUE(bar->has_bar()); ASSERT_EQ(bar->bar(), 9527); bar = foo.mutable_bar()->Add(); ASSERT_TRUE(!bar->has_bar()); ASSERT_EQ(foo.bar_size(), 2); ASSERT_EQ(foo.bar().size(), 2); bar->set_bar(9528); for (const auto& x : foo.bar()) { ASSERT_TRUE(x.has_bar()); } foo.clear_bar(); ASSERT_EQ(foo.bar_size(), 0); } // clang-format off template FLAT_MSG_BEGIN(TestTemplateFlatMsg); FLAT_MSG_DEFINE_REPEATED(char, char_field, N); FLAT_MSG_END(TestTemplateFlatMsg); // clang-format on TEST(FlatMsg, flat_msg_template) { FlatMsg> foo; ASSERT_TRUE(foo.Get().char_field().empty()); } } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/flat_msg_view.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_FLAT_MSG_VIEW_H_ #define ONEFLOW_CORE_INTRUSIVE_FLAT_MSG_VIEW_H_ #include #include "oneflow/core/common/throw.h" #include "oneflow/core/intrusive/dss.h" #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/intrusive/struct_traits.h" #include "oneflow/core/intrusive/static_counter.h" namespace oneflow { #define FLAT_MSG_VIEW_BEGIN(struct_name) \ struct struct_name final { \ using self_type = struct_name; \ static const bool __is_flat_message_view_type__ = true; \ FLAT_MSG_VIEW_DEFINE_BASIC_METHODS(struct_name); \ \ public: \ DEFINE_STATIC_COUNTER(field_counter); \ DSS_BEGIN(STATIC_COUNTER(field_counter), struct_name); #define FLAT_MSG_VIEW_END(struct_name) \ static_assert(__is_flat_message_view_type__, "this struct is not a flat message view"); \ \ public: \ static const int __LastFieldIndex__ = STATIC_COUNTER(field_counter); \ \ public: \ INCREASE_STATIC_COUNTER(field_counter); \ DSS_END(STATIC_COUNTER(field_counter), "flat message view", struct_name); \ } \ ; #define FLAT_MSG_VIEW_DEFINE_PATTERN(flat_msg_field_type, field_name) \ static_assert(__is_flat_message_view_type__, "this struct is not a flat message view"); \ _FLAT_MSG_VIEW_DEFINE_PATTERN(FLAT_MSG_TYPE_CHECK(flat_msg_field_type), field_name); \ \ public: \ INCREASE_STATIC_COUNTER(field_counter); \ FLAT_MSG_VIEW_SPECIALIZE_FIELD_TYPE(STATIC_COUNTER(field_counter), flat_msg_field_type); \ FLAT_MSG_VIEW_CHECK_LAST_FIELD_TYPE(STATIC_COUNTER(field_counter), flat_msg_field_type); \ DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), "flat message view", flat_msg_field_type*, \ OF_PP_CAT(field_name, _)); #define FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(flat_msg_field_type, field_name) \ static_assert(__is_flat_message_view_type__, "this struct is not a flat message view"); \ _FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(FLAT_MSG_TYPE_CHECK(flat_msg_field_type), field_name); \ \ public: \ INCREASE_STATIC_COUNTER(field_counter); \ _SPECIALIZE_IS_REPEATED_PATTERN(STATIC_COUNTER(field_counter)); \ FLAT_MSG_VIEW_SPECIALIZE_FIELD_TYPE(STATIC_COUNTER(field_counter), flat_msg_field_type); \ FLAT_MSG_VIEW_CHECK_LAST_FIELD_TYPE(STATIC_COUNTER(field_counter), flat_msg_field_type); \ DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), "flat message view", \ FlatMsgViewPatternVec, OF_PP_CAT(field_name, _)); // details #define _FLAT_MSG_VIEW_DEFINE_PATTERN(field_type, field_name) \ public: \ const field_type& field_name() const { return *OF_PP_CAT(field_name, _); } \ \ private: \ const field_type* OF_PP_CAT(field_name, _); #define _FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(field_type, field_name) \ public: \ const field_type& field_name(int i) const { return *OF_PP_CAT(field_name, _).at(i); } \ std::size_t OF_PP_CAT(field_name, _size)() const { return OF_PP_CAT(field_name, _).size(); } \ \ private: \ FlatMsgViewPatternVec OF_PP_CAT(field_name, _); #define FLAT_MSG_VIEW_DEFINE_BASIC_METHODS(T) \ public: \ template \ struct IsRepeatedPattern { \ static const bool value = false; \ }; \ \ private: \ template \ struct __FlatMsgViewFieldType__ { \ struct type {}; \ }; #define FLAT_MSG_VIEW_SPECIALIZE_FIELD_TYPE(field_index, field_type) \ private: \ 
template \ struct __FlatMsgViewFieldType__ { \ using type = field_type; \ }; #define FLAT_MSG_VIEW_CHECK_LAST_FIELD_TYPE(field_index, field_type) \ private: \ static void OF_PP_CAT(__CheckLastFieldType__, __LINE__)() { \ static_assert( \ !(IsRepeatedPattern::value \ && std::is_same<__FlatMsgViewFieldType__::type, field_type>::value), \ "repeated pattern shouldn't be followed by the pattern with same type"); \ } #define _SPECIALIZE_IS_REPEATED_PATTERN(field_index) \ template \ struct IsRepeatedPattern { \ static const bool value = true; \ } template struct FlatMsgViewPatternVec { using value_type = T; void __Init__() { new (&vec_buffer_) Vec(); } void __Delete__() { mut_vec()->~Vec(); } const T* at(int index) const { return vec().at(index); } size_t size() const { return vec().size(); } void clear() { mut_vec()->clear(); } void emplace_back(const T* ptr) { mut_vec()->emplace_back(ptr); } private: using Vec = std::vector; Vec* mut_vec() { Vec* __attribute__((__may_alias__)) ptr = reinterpret_cast(&vec_buffer_); return ptr; } const Vec& vec() const { const Vec* __attribute__((__may_alias__)) ptr = reinterpret_cast(&vec_buffer_); return *ptr; } union { char vec_buffer_[sizeof(Vec)]; int64_t align64_; }; }; template class FlatMsgViewFieldCtx { public: using flat_msg_view_type = FlatMsgViewT; static_assert(std::is_same::value, "invalid view match"); FlatMsgViewFieldCtx(const FlatMsgViewFieldCtx&) = delete; FlatMsgViewFieldCtx(FlatMsgViewFieldCtx&&) = delete; FlatMsgViewFieldCtx(const OneofValueType* repeated_flag_msg, std::size_t size) : repeated_flag_msg_(repeated_flag_msg), token_index_(0), size_(size) {} ~FlatMsgViewFieldCtx() = default; const OneofValueType* GetFlatMsg() const { return repeated_flag_msg_ + token_index_; } typename FlatMsgOneofField::field_type* GetOneof() const { return FlatMsgOneofField::FieldPtr4StructPtr(GetFlatMsg()); } bool is_token_index_valid() const { return token_index_ < size_; } void increase_token_index() { ++token_index_; } int32_t token_count() const { return token_index_; } private: const OneofValueType* repeated_flag_msg_; int32_t token_index_; const std::size_t size_; }; template struct _FlatMsgViewFieldMatcher {}; template struct FlatMsgViewFieldMatcher { static const bool is_repeated_pattern = WalkCtxType::flat_msg_view_type::template IsRepeatedPattern::value; // return true if error occured static bool Call(WalkCtxType* ctx, FieldPtrT* field) { return _FlatMsgViewFieldMatcher::Call(ctx, field); } }; template struct _FlatMsgViewFieldMatcher { // return true if error occured static bool Call(WalkCtxType* ctx, FieldPtrT* field) { if (!ctx->is_token_index_valid()) { return true; } using ConstFieldType = typename std::remove_pointer::type; using FieldType = typename std::remove_const::type; const auto* oneof = ctx->GetOneof(); if (!oneof->template HasField()) { return true; } *field = &oneof->template GetField(); ctx->increase_token_index(); return false; } }; template struct _FlatMsgViewFieldMatcher { // return true if error occured static bool Call(WalkCtxType* ctx, FieldPtrT* field) { field->clear(); using FieldType = typename FieldPtrT::value_type; while (ctx->is_token_index_valid()) { const auto* oneof = ctx->GetOneof(); if (!oneof->template HasField()) { break; } field->emplace_back(&oneof->template GetField()); ctx->increase_token_index(); } return false; } }; template struct FlatMsgViewUtil { static_assert(std::is_same::value, "invalid view match"); static bool Match(FlatMsgViewT* flat_msg_view, const ValueType* data_ptr, std::size_t size) { 
FlatMsgViewFieldCtx ctx(data_ptr, size); bool ret = !flat_msg_view->template __WalkFieldUntil__(&ctx); if (ret) { if (FlatMsgViewT::template IsRepeatedPattern::value) { ret = (ctx.token_count() == size) || /* last repeated field empty */ (ctx.token_count() - 1 == size); } else { ret = (ctx.token_count() == size); } } return ret; } }; template struct FlatMsgViewContainerUtil { using FlatMsgOneofField = intrusive::OffsetStructField; static bool Match(FlatMsgViewT* self, const ContainerT& container) { return FlatMsgViewUtil::Match( self, container.data(), container.size()); } }; template struct FlatMsgViewContainerUtil>, Enabled> { using FlatMsgOneofField = intrusive::OffsetStructField; static_assert(sizeof(ValueType) == sizeof(FlatMsg), ""); static_assert(alignof(ValueType) == alignof(FlatMsg), ""); static bool Match(FlatMsgViewT* self, const std::vector>& container) { return FlatMsgViewUtil::Match( self, &container.data()->Get(), container.size()); } }; template struct _FlatMsgViewFieldInit {}; template struct FlatMsgViewFieldInit { static const bool is_repeated_pattern = WalkCtxType::template IsRepeatedPattern::value; static void Call(WalkCtxType* ctx, FieldPtrT* field) { _FlatMsgViewFieldInit::Call(field); } }; template struct _FlatMsgViewFieldInit { static void Call(FieldPtrT* field) {} }; template struct _FlatMsgViewFieldInit { static void Call(FieldPtrT* field) { field->__Init__(); } }; template struct _FlatMsgViewFieldDelete {}; template struct FlatMsgViewFieldDelete { static const bool is_repeated_pattern = WalkCtxType::template IsRepeatedPattern::value; static void Call(WalkCtxType* ctx, FieldPtrT* field) { _FlatMsgViewFieldDelete::Call(field); } }; template struct _FlatMsgViewFieldDelete { static void Call(FieldPtrT* field) {} }; template struct _FlatMsgViewFieldDelete { static void Call(FieldPtrT* field) { field->__Delete__(); } }; template struct FlatMsgView final { FlatMsgView(const FlatMsgView&) = delete; FlatMsgView(FlatMsgView&&) = delete; static_assert(T::__is_flat_message_view_type__, "T is not a flat message view type"); FlatMsgView() { view_.template __WalkField__(&view_); } template explicit FlatMsgView(const RepeatedFlatMsgT& repeated_flat_msg) { view_.template __WalkField__(&view_); CHECK(this->template Match(repeated_flat_msg)); } ~FlatMsgView() { view_.template __ReverseWalkField__(&view_); } const T& operator*() const { return view_; } T& operator*() { return view_; } const T* operator->() const { return &view_; } T* operator->() { return &view_; } const T& Get() const { return view_; } T* Mutable() { return &view_; } template bool Match(const RepeatedFlatMsgT& repeated_flat_msg) { using OneofType = typename RepeatedFlatMsgT::value_type::self_value_type; return FlatMsgViewContainerUtil::Match(&view_, repeated_flat_msg); } private: union { T view_; }; }; } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_FLAT_MSG_VIEW_H_ ================================================ FILE: oneflow/core/intrusive/flat_msg_view_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/intrusive/flat_msg_view.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace test { namespace { // clang-format off FLAT_MSG_BEGIN(VariantFoo); FLAT_MSG_DEFINE_STRICT_ONEOF(_, FLAT_MSG_ONEOF_FIELD(int8_t, int8_value) FLAT_MSG_ONEOF_FIELD(int16_t, int16_value) FLAT_MSG_ONEOF_FIELD(int32_t, int32_value) FLAT_MSG_ONEOF_FIELD(float, float_value)); FLAT_MSG_END(VariantFoo); // clang-format on // clang-format off FLAT_MSG_BEGIN(VariantList); FLAT_MSG_DEFINE_REPEATED(VariantFoo, foo, 16); FLAT_MSG_END(VariantList); // clang-format on // clang-format off FLAT_MSG_VIEW_BEGIN(ViewFoo); FLAT_MSG_VIEW_DEFINE_PATTERN(int32_t, int32_value); FLAT_MSG_VIEW_DEFINE_PATTERN(int16_t, int16_value); FLAT_MSG_VIEW_DEFINE_PATTERN(float, float_value); FLAT_MSG_VIEW_END(ViewFoo); // clang-format on TEST(FlatMsgView, match_success) { FlatMsg variant_list; variant_list.Mutable()->mutable_foo()->Add()->set_int32_value(30); variant_list.Mutable()->mutable_foo()->Add()->set_int16_value(40); variant_list.Mutable()->mutable_foo()->Add()->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(view.template Match(variant_list.Get().foo())); ASSERT_EQ(view->int32_value(), 30); ASSERT_EQ(view->int16_value(), 40); ASSERT_EQ(view->float_value(), 50.0); } TEST(FlatMsgView, match_failed) { FlatMsg variant_list; variant_list.Mutable()->mutable_foo()->Add()->set_int16_value(40); variant_list.Mutable()->mutable_foo()->Add()->set_int32_value(30); variant_list.Mutable()->mutable_foo()->Add()->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(!view.template Match(variant_list.Get().foo())); } TEST(FlatMsgView, match_success_vector) { std::vector> variant_list(3); variant_list.at(0)->set_int32_value(30); variant_list.at(1)->set_int16_value(40); variant_list.at(2)->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(view.template Match(variant_list)); ASSERT_EQ(view->int32_value(), 30); ASSERT_EQ(view->int16_value(), 40); ASSERT_EQ(view->float_value(), 50.0); } TEST(FlatMsgView, match_failed_vector) { std::vector> variant_list(3); variant_list.at(0)->set_int16_value(40); variant_list.at(1)->set_int32_value(30); variant_list.at(2)->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(!view.template Match(variant_list)); } // clang-format off FLAT_MSG_VIEW_BEGIN(RepeatedFoo); FLAT_MSG_VIEW_DEFINE_PATTERN(int32_t, int32_value); FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(int16_t, int16_value); FLAT_MSG_VIEW_DEFINE_PATTERN(float, float_value); FLAT_MSG_VIEW_END(RepeatedFoo); // clang-format on TEST(FlatMsgView, repeated_empty) { std::vector> variant_list(2); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(view.Match(variant_list)); ASSERT_EQ(view->int16_value_size(), 0); } TEST(FlatMsgView, repeated_empty_failed) { std::vector> variant_list(2); variant_list.at(0)->set_float_value(50.0); variant_list.at(1)->set_int32_value(40); FlatMsgView view; ASSERT_TRUE(!view.Match(variant_list)); } TEST(FlatMsgView, repeated_one) { std::vector> variant_list(3); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_int16_value(45); variant_list.at(2)->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(view.Match(variant_list)); ASSERT_EQ(view->int16_value_size(), 1); ASSERT_EQ(view->int16_value(0), 45); } TEST(FlatMsgView, repeated_one_failed) { std::vector> variant_list(3); variant_list.at(0)->set_int32_value(40); 
variant_list.at(1)->set_float_value(50.0); variant_list.at(2)->set_int16_value(45); FlatMsgView view; ASSERT_TRUE(!view.Match(variant_list)); } TEST(FlatMsgView, repeated_many) { std::vector> variant_list(4); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_int16_value(45); variant_list.at(2)->set_int16_value(45); variant_list.at(3)->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(view.Match(variant_list)); ASSERT_EQ(view->int16_value_size(), 2); ASSERT_EQ(view->int16_value(0), 45); ASSERT_EQ(view->int16_value(1), 45); } TEST(FlatMsgView, repeated_many_failed) { std::vector> variant_list(4); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_int16_value(45); variant_list.at(2)->set_float_value(45.0); variant_list.at(3)->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(!view.Match(variant_list)); } // clang-format off FLAT_MSG_VIEW_BEGIN(LastFieldRepeatedFoo); FLAT_MSG_VIEW_DEFINE_PATTERN(int32_t, int32_value); FLAT_MSG_VIEW_DEFINE_PATTERN(float, float_value); FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(int16_t, int16_value); FLAT_MSG_VIEW_END(LastFieldRepeatedFoo); // clang-format on TEST(FlatMsgView, last_field_repeated_empty) { std::vector> variant_list(2); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(view.Match(variant_list)); ASSERT_EQ(view->int16_value_size(), 0); } TEST(FlatMsgView, last_field_repeated_empty_failed) { std::vector> variant_list(2); variant_list.at(0)->set_float_value(50.0); variant_list.at(1)->set_int32_value(40); FlatMsgView view; ASSERT_TRUE(!view.Match(variant_list)); } TEST(FlatMsgView, last_field_repeated_one) { std::vector> variant_list(3); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_float_value(50.0); variant_list.at(2)->set_int16_value(45); FlatMsgView view; ASSERT_TRUE(view.Match(variant_list)); ASSERT_EQ(view->int16_value_size(), 1); ASSERT_EQ(view->int16_value(0), 45); } TEST(FlatMsgView, last_field_repeated_one_failed) { std::vector> variant_list(3); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_int16_value(45); variant_list.at(2)->set_float_value(50.0); FlatMsgView view; ASSERT_TRUE(!view.Match(variant_list)); } TEST(FlatMsgView, last_field_repeated_many) { std::vector> variant_list(4); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_float_value(50.0); variant_list.at(2)->set_int16_value(45); variant_list.at(3)->set_int16_value(45); FlatMsgView view; ASSERT_TRUE(view.Match(variant_list)); ASSERT_EQ(view->int16_value_size(), 2); ASSERT_EQ(view->int16_value(0), 45); ASSERT_EQ(view->int16_value(1), 45); } TEST(FlatMsgView, last_field_repeated_many_failed) { std::vector> variant_list(4); variant_list.at(0)->set_int32_value(40); variant_list.at(1)->set_int16_value(45); variant_list.at(2)->set_float_value(50.0); variant_list.at(3)->set_int16_value(45); FlatMsgView view; ASSERT_TRUE(!view.Match(variant_list)); } } // namespace } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/for_each.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_FOR_EACH_H_ #define ONEFLOW_CORE_INTRUSIVE_FOR_EACH_H_ #include "oneflow/core/intrusive/list_hook.h" #include "oneflow/core/intrusive/struct_traits.h" namespace oneflow { namespace intrusive { #define INTRUSIVE_FOR_EACH(elem, container) \ _INTRUSIVE_FOR_EACH(std::remove_pointer::type, elem, container) #define INTRUSIVE_FOR_EACH_PTR(elem, container) \ _INTRUSIVE_FOR_EACH_PTR(std::remove_pointer::type, elem, container) #define INTRUSIVE_UNSAFE_FOR_EACH_PTR(elem, container) \ _INTRUSIVE_UNSAFE_FOR_EACH_PTR(std::remove_pointer::type, elem, container) // details #define _INTRUSIVE_FOR_EACH(container_type, elem, container) \ for (intrusive::shared_ptr elem, \ *end_if_not_null = nullptr; \ end_if_not_null == nullptr; end_if_not_null = nullptr, ++end_if_not_null) \ LIST_HOOK_FOR_EACH_WITH_EXPR( \ (intrusive::OffsetStructField< \ typename container_type, intrusive::ListHook, \ container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \ container_type::iterator_struct_field, elem_ptr, (elem.Reset(elem_ptr), true)) #define _INTRUSIVE_FOR_EACH_PTR(container_type, elem, container) \ LIST_HOOK_FOR_EACH((intrusive::OffsetStructField< \ typename container_type, intrusive::ListHook, \ container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \ container_type::iterator_struct_field, elem) #define _INTRUSIVE_UNSAFE_FOR_EACH_PTR(container_type, elem, container) \ LIST_HOOK_UNSAFE_FOR_EACH( \ (intrusive::OffsetStructField< \ typename container_type, intrusive::ListHook, \ container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \ container_type::iterator_struct_field, elem) } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_FOR_EACH_H_ ================================================ FILE: oneflow/core/intrusive/force_standard_layout.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_FORCE_STANDARD_LAYOUT_H_ #define ONEFLOW_CORE_INTRUSIVE_FORCE_STANDARD_LAYOUT_H_ namespace oneflow { namespace intrusive { template class ForceStandardLayout final { public: ForceStandardLayout() { new (&object_) T(); } template::type>::value>::type> explicit ForceStandardLayout(Arg&& arg) { new (&object_) T(std::forward(arg)); } template ForceStandardLayout(Arg0&& arg0, Arg1&& arg1, Args&&... 
args) { new (&object_) T(std::forward(arg0), std::forward(arg1), std::forward(args)...); } ~ForceStandardLayout() { Mutable()->~T(); } ForceStandardLayout(const ForceStandardLayout& other) { new (&object_) T(other.Get()); } ForceStandardLayout(ForceStandardLayout&& other) { new (&object_) T(std::move(*other.Mutable())); } ForceStandardLayout& operator=(const ForceStandardLayout& other) { *Mutable() = other.Get(); return *this; } ForceStandardLayout& operator=(ForceStandardLayout&& other) { *Mutable() = std::move(*other.Mutable()); return *this; } const T& Get() const { const auto* __attribute__((__may_alias__)) ptr = reinterpret_cast(&object_[0]); return *ptr; } T* Mutable() { auto* __attribute__((__may_alias__)) ptr = reinterpret_cast(&object_[0]); return ptr; } private: alignas(T) char object_[sizeof(T)]; }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_FORCE_STANDARD_LAYOUT_H_ ================================================ FILE: oneflow/core/intrusive/force_standard_layout_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // include sstream first to avoid some compiling error // caused by the following trick // reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899 #include #include "gtest/gtest.h" #include "oneflow/core/common/util.h" #include "oneflow/core/intrusive/force_standard_layout.h" namespace oneflow { namespace intrusive { namespace test { constexpr const int unstandard_value = 999; constexpr const int standard_value = 666; struct Unstandard { public: explicit Unstandard(int* ptr) : x(unstandard_value), ptr_(ptr) {} ~Unstandard() { *ptr_ = unstandard_value; } Unstandard(const Unstandard&) = default; Unstandard(Unstandard&&) = default; Unstandard& operator=(const Unstandard&) = default; Unstandard& operator=(Unstandard&&) = default; int* ptr() const { return ptr_; } void set_ptr(int* val) { ptr_ = val; } int x; private: int* ptr_; }; TEST(ForceStandardLayout, default_constructor) { int value = standard_value; ForceStandardLayout sl(&value); ASSERT_EQ(sl.Get().x, unstandard_value); ASSERT_EQ(sl.Get().ptr(), &value); } TEST(ForceStandardLayout, copy_constructor) { int value = standard_value; const ForceStandardLayout const_sl(&value); ForceStandardLayout sl(const_sl); // NOLINT ASSERT_EQ(sl.Get().x, unstandard_value); ASSERT_EQ(sl.Get().ptr(), &value); } TEST(ForceStandardLayout, move_constructor) { int value = standard_value; ForceStandardLayout old_sl(&value); ForceStandardLayout sl(std::move(old_sl)); ASSERT_EQ(sl.Get().x, unstandard_value); ASSERT_EQ(sl.Get().ptr(), &value); } TEST(ForceStandardLayout, copy_assign) { int value = standard_value; const ForceStandardLayout const_sl(&value); ForceStandardLayout sl = const_sl; // NOLINT ASSERT_EQ(sl.Get().x, unstandard_value); ASSERT_EQ(sl.Get().ptr(), &value); } TEST(ForceStandardLayout, move_assign) { int value = standard_value; ForceStandardLayout sl = ForceStandardLayout(&value); ASSERT_EQ(sl.Get().x, 
unstandard_value); ASSERT_EQ(sl.Get().ptr(), &value); } TEST(ForceStandardLayout, destructor) { int value = standard_value; { ForceStandardLayout sl(&value); } ASSERT_EQ(value, unstandard_value); } } // namespace test } // namespace intrusive } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/head_free_list.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_HEAD_FREE_LIST_H_ #define ONEFLOW_CORE_INTRUSIVE_HEAD_FREE_LIST_H_ #include "oneflow/core/intrusive/ref.h" #include "oneflow/core/intrusive/list_hook.h" #include "oneflow/core/intrusive/struct_traits.h" #include "oneflow/core/intrusive/reflective.h" namespace oneflow { namespace intrusive { template class HeadFreeList { public: static_assert(std::is_same::value, ""); HeadFreeList(const HeadFreeList&) = delete; HeadFreeList(HeadFreeList&&) = delete; HeadFreeList() { this->__Init__(); } ~HeadFreeList() { this->Clear(); } using value_type = typename ValueHookField::struct_type; using iterator_struct_field = ValueHookField; // field_counter is last field_number static const int field_number_in_countainter = field_counter + 1; template static constexpr int IteratorHookOffset() { return offsetof(HeadFreeList, list_head_) + intrusive::ListHead::IteratorHookOffset(); } std::size_t size() const { return list_head_.size(); } bool empty() const { return list_head_.empty(); } void __Init__() { list_head_.__Init__(); static_assert( std::is_same::value, "It's invalid to define fields between definition of head-free list type and definition of " "head-free list field."); using ThisInContainer = OffsetStructField; container_ = ThisInContainer::StructPtr4FieldPtr(this); } value_type* Begin() { if (list_head_.empty()) { return nullptr; } return list_head_.Begin(); } value_type* Next(value_type* ptr) { if (ptr == nullptr) { return nullptr; } value_type* next = list_head_.Next(ptr); if (next == list_head_.End()) { return nullptr; } return next; } value_type* Last() { if (list_head_.empty()) { return nullptr; } return list_head_.Last(); } constexpr value_type* End() const { return nullptr; } void MoveToDstBack(value_type* ptr, HeadFreeList* dst) { list_head_.MoveToDstBack(ptr, &dst->list_head_); MoveReference(ptr, dst); } void MoveToDstFront(value_type* ptr, HeadFreeList* dst) { list_head_.MoveToDstFront(ptr, &dst->list_head_); MoveReference(ptr, dst); } value_type* MoveFrontToDstBack(HeadFreeList* dst) { value_type* begin = list_head_.Begin(); MoveToDstBack(begin, dst); return begin; } value_type* MoveBackToDstBack(HeadFreeList* dst) { value_type* begin = list_head_.Last(); MoveToDstBack(begin, dst); return begin; } void PushBack(value_type* ptr) { list_head_.PushBack(ptr); if (container_ != ptr) { Ref::IncreaseRef(ptr); } } void PushFront(value_type* ptr) { list_head_.PushFront(ptr); if (container_ != ptr) { Ref::IncreaseRef(ptr); } } void EmplaceBack(intrusive::shared_ptr&& ptr) { value_type* raw_ptr = 
nullptr; if (container_ != ptr.Mutable()) { ptr.__UnsafeMoveTo__(&raw_ptr); } else { raw_ptr = ptr.Mutable(); } list_head_.PushBack(raw_ptr); } void EmplaceFront(intrusive::shared_ptr&& ptr) { value_type* raw_ptr = nullptr; if (container_ != ptr.Mutable()) { ptr.__UnsafeMoveTo__(&raw_ptr); } else { raw_ptr = ptr.Mutable(); } list_head_.PushFront(raw_ptr); } intrusive::shared_ptr Erase(value_type* ptr) { list_head_.Erase(ptr); if (container_ != ptr) { return intrusive::shared_ptr::__UnsafeMove__(ptr); } else { return intrusive::shared_ptr(ptr); } } intrusive::shared_ptr PopBack() { value_type* raw_ptr = nullptr; if (!list_head_.empty()) { raw_ptr = list_head_.PopBack(); } if (container_ != raw_ptr) { return intrusive::shared_ptr::__UnsafeMove__(raw_ptr); } else { return intrusive::shared_ptr(raw_ptr); } } intrusive::shared_ptr PopFront() { value_type* raw_ptr = nullptr; if (!list_head_.empty()) { raw_ptr = list_head_.PopFront(); } if (container_ != raw_ptr) { return intrusive::shared_ptr::__UnsafeMove__(raw_ptr); } else { return intrusive::shared_ptr(raw_ptr); } } void MoveTo(HeadFreeList* list) { MoveToDstBack(list); } void MoveToDstBack(HeadFreeList* list) { while (!empty()) { MoveToDstBack(list_head_.Begin(), list); } } void Clear() { while (!empty()) { auto* ptr = list_head_.PopFront(); if (container_ != ptr) { Ref::DecreaseRef(ptr); } } } private: void MoveReference(value_type* ptr, HeadFreeList* dst) { if (ptr == container_ && ptr != dst->container_) { Ref::IncreaseRef(ptr); } else if (ptr != container_ && ptr == dst->container_) { Ref::DecreaseRef(ptr); } else { // do nothing } } intrusive::ListHead list_head_; const value_type* container_; }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_HEAD_FREE_LIST_H_ ================================================ FILE: oneflow/core/intrusive/head_free_list_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ // include sstream first to avoid some compiling error // caused by the following trick // reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899 #include #include "gtest/gtest.h" #define private public #include "oneflow/core/common/util.h" #include "oneflow/core/intrusive/intrusive.h" namespace oneflow { namespace test { namespace { // clang-format off REFLECTIVE_CLASS_BEGIN(SelfLoopContainer); public: void __Init__() { clear_deleted(); } // Getters bool has_deleted() const { return deleted_ != nullptr; } bool deleted() const { return *deleted_; } bool is_hook_empty() const { return hook_.empty(); } // Setters bool* mut_deleted() { return deleted_; } void set_deleted(bool* val) { deleted_ = val; } void clear_deleted() { deleted_ = nullptr; } // methods void __Init__(bool* deleted) { __Init__(); set_deleted(deleted); } void __Delete__() { *mut_deleted() = true; } size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } SelfLoopContainer() : intrusive_ref_(), deleted_(), hook_(), head_() {} REFLECTIVE_CLASS_DEFINE_FIELD(intrusive::Ref, intrusive_ref_); // fields REFLECTIVE_CLASS_DEFINE_FIELD(bool*, deleted_); // list hooks REFLECTIVE_CLASS_DEFINE_FIELD(intrusive::ListHook, hook_); public: // Do not insert other REFLECTIVE_CLASS_DEFINE_FIELD between `using SelfLoopContainerList = ...;` and `REFLECTIVE_CLASS_DEFINE_FIELD(SelfLoopContainerList, ...);` using SelfLoopContainerList = intrusive::HeadFreeList; const SelfLoopContainerList& head() const { return head_; } SelfLoopContainerList* mut_head() { return &head_; } private: REFLECTIVE_CLASS_DEFINE_FIELD(SelfLoopContainerList, head_); REFLECTIVE_CLASS_END(SelfLoopContainer); // clang-format on TEST(HeadFreeList, __Init__) { bool deleted = false; auto self_loop_head = intrusive::make_shared(&deleted); ASSERT_EQ(self_loop_head->mut_head()->container_, self_loop_head.Mutable()); } TEST(HeadFreeList, PushBack) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable()); ASSERT_EQ(self_loop_head0->head().size(), 1); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable()); ASSERT_EQ(self_loop_head1->ref_cnt(), 2); ASSERT_EQ(self_loop_head0->head().size(), 2); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } TEST(HeadFreeList, PushFront) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); self_loop_head0->mut_head()->PushFront(self_loop_head0.Mutable()); ASSERT_EQ(self_loop_head0->head().size(), 1); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); self_loop_head0->mut_head()->PushFront(self_loop_head1.Mutable()); ASSERT_EQ(self_loop_head1->ref_cnt(), 2); ASSERT_EQ(self_loop_head0->head().size(), 2); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } TEST(HeadFreeList, EmplaceBack) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); 
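// A head-free list never holds a reference on its own container object: emplacing
// self_loop_head0 into its own list leaves its ref_cnt at 1, while the foreign element
// self_loop_head1 gains an extra reference (ref_cnt 2), as the assertions below check.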
self_loop_head0->mut_head()->EmplaceBack( intrusive::shared_ptr(self_loop_head0)); ASSERT_EQ(self_loop_head0->head().size(), 1); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); self_loop_head0->mut_head()->EmplaceBack( intrusive::shared_ptr(self_loop_head1)); ASSERT_EQ(self_loop_head1->ref_cnt(), 2); ASSERT_EQ(self_loop_head0->head().size(), 2); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } TEST(HeadFreeList, EmplaceFront) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); self_loop_head0->mut_head()->EmplaceFront( intrusive::shared_ptr(self_loop_head0)); ASSERT_EQ(self_loop_head0->head().size(), 1); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); self_loop_head0->mut_head()->EmplaceFront( intrusive::shared_ptr(self_loop_head1)); ASSERT_EQ(self_loop_head1->ref_cnt(), 2); ASSERT_EQ(self_loop_head0->head().size(), 2); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } TEST(HeadFreeList, Erase) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable()); self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable()); self_loop_head0->mut_head()->Erase(self_loop_head0.Mutable()); self_loop_head0->mut_head()->Erase(self_loop_head1.Mutable()); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } TEST(HeadFreeList, PopBack) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable()); self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable()); self_loop_head0->mut_head()->PopBack(); self_loop_head0->mut_head()->PopBack(); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } TEST(HeadFreeList, PopFront) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable()); self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable()); self_loop_head0->mut_head()->PopFront(); self_loop_head0->mut_head()->PopFront(); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } TEST(HeadFreeList, MoveTo) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable()); self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable()); self_loop_head0->mut_head()->MoveTo(self_loop_head1->mut_head()); ASSERT_EQ(self_loop_head0->ref_cnt(), 2); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } TEST(HeadFreeList, Clear) { bool deleted0 = false; bool deleted1 = false; { auto self_loop_head0 = intrusive::make_shared(&deleted0); auto self_loop_head1 = intrusive::make_shared(&deleted1); self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable()); 
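// Clear() below releases the reference held on the foreign element while the
// self-referencing container stays at ref_cnt 1; both objects are destroyed
// once the enclosing scope ends, flipping deleted0 and deleted1 to true.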
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable()); self_loop_head0->mut_head()->Clear(); ASSERT_EQ(self_loop_head0->ref_cnt(), 1); ASSERT_EQ(self_loop_head1->ref_cnt(), 1); } ASSERT_TRUE(deleted0); ASSERT_TRUE(deleted1); } } // namespace } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/intrusive.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_ #define ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_ #include "oneflow/core/intrusive/struct_traits.h" #include "oneflow/core/intrusive/base.h" #include "oneflow/core/intrusive/ref.h" #include "oneflow/core/intrusive/shared_ptr.h" #include "oneflow/core/intrusive/list.h" #include "oneflow/core/intrusive/head_free_list.h" #include "oneflow/core/intrusive/skiplist.h" #include "oneflow/core/intrusive/for_each.h" #include "oneflow/core/intrusive/reflective.h" #include "oneflow/core/intrusive/force_standard_layout.h" #endif // ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_ ================================================ FILE: oneflow/core/intrusive/intrusive_core_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ // include sstream first to avoid some compiling error // caused by the following trick // reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899 #include #include "gtest/gtest.h" #define private public #include "oneflow/core/common/util.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/common/preprocessor.h" namespace oneflow { namespace intrusive { namespace test { namespace { TEST(Ref, ref_cnt) { class Foo final : public Ref { public: Foo() = default; }; Foo foo; foo.InitRefCount(); foo.IncreaseRefCount(); foo.IncreaseRefCount(); ASSERT_EQ(foo.DecreaseRefCount(), 1); ASSERT_EQ(foo.DecreaseRefCount(), 0); } class IntrusiveFoo final : public intrusive::Base { public: void __Init__() { clear_is_deleted(); } void __Delete__(); // Getters int8_t x() const { return x_; } int32_t foo() const { return foo_; } int16_t bar() const { return bar_; } int64_t foobar() const { return foobar_; } bool has_is_deleted() const { return is_deleted_ != nullptr; } const std::string& is_deleted() const { return *is_deleted_; } // Setters void set_x(int8_t val) { x_ = val; } void set_foo(int32_t val) { foo_ = val; } void set_bar(int16_t val) { bar_ = val; } void set_foobar(int64_t val) { foobar_ = val; } void set_is_deleted(std::string* val) { is_deleted_ = val; } std::string* mut_is_deleted() { return is_deleted_; } void clear_is_deleted() { is_deleted_ = nullptr; } size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } IntrusiveFoo() : intrusive_ref_(), x_(), foo_(), bar_(), foobar_(), is_deleted_() {} intrusive::Ref intrusive_ref_; int8_t x_; int32_t foo_; int16_t bar_; int64_t foobar_; std::string* is_deleted_; }; void IntrusiveFoo::__Delete__() { if (mut_is_deleted()) { *mut_is_deleted() = "deleted"; } } TEST(intrusive, naive) { auto foo = intrusive::make_shared(); foo->set_bar(9527); ASSERT_TRUE(foo->bar() == 9527); } TEST(intrusive, __delete__) { std::string is_deleted; { auto foo = intrusive::make_shared(); foo->set_bar(9527); foo->set_is_deleted(&is_deleted); ASSERT_EQ(foo->bar(), 9527); } ASSERT_TRUE(is_deleted == "deleted"); } class IntrusiveBar final : public intrusive::Base { public: void __Init__() { clear_is_deleted(); } void __Delete__() { if (mut_is_deleted()) { *mut_is_deleted() = "bar_deleted"; } } // Getters const IntrusiveFoo& foo() const { if (foo_) { return foo_.Get(); } static const auto default_val = intrusive::make_shared(); return default_val.Get(); } const std::string& is_deleted() const { return *is_deleted_; } bool has_is_deleted() const { return is_deleted_ != nullptr; } // Setters IntrusiveFoo* mut_foo() { if (!foo_) { foo_ = intrusive::make_shared(); } return foo_.Mutable(); } std::string* mut_is_deleted() { return is_deleted_; } void set_is_deleted(std::string* val) { is_deleted_ = val; } void clear_is_deleted() { is_deleted_ = nullptr; } size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } IntrusiveBar() : intrusive_ref_(), foo_(), is_deleted_() {} intrusive::Ref intrusive_ref_; intrusive::shared_ptr foo_; std::string* is_deleted_; }; TEST(intrusive, nested_objects) { auto bar = intrusive::make_shared(); bar->mut_foo()->set_bar(9527); ASSERT_TRUE(bar->foo().bar() == 9527); } TEST(intrusive, nested_delete) { std::string bar_is_deleted; std::string is_deleted; { auto bar = 
intrusive::make_shared(); bar->set_is_deleted(&bar_is_deleted); auto* foo = bar->mut_foo(); foo->set_bar(9527); foo->set_is_deleted(&is_deleted); ASSERT_EQ(foo->bar(), 9527); ASSERT_EQ(bar->ref_cnt(), 1); ASSERT_EQ(foo->ref_cnt(), 1); } ASSERT_EQ(is_deleted, std::string("deleted")); ASSERT_EQ(bar_is_deleted, std::string("bar_deleted")); } // clang-format off FLAT_MSG_BEGIN(FlatMsgDemo) FLAT_MSG_DEFINE_ONEOF(type, FLAT_MSG_ONEOF_FIELD(int32_t, int32_field) FLAT_MSG_ONEOF_FIELD(float, float_field)); FLAT_MSG_END(FlatMsgDemo) // clang-format on class IntrusiveContainerDemo final : public intrusive::Base { public: // Getters const FlatMsgDemo& flat_field() const { return flat_field_.Get(); } // Setters FlatMsgDemo* mut_flat_field() { return flat_field_.Mutable(); } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } IntrusiveContainerDemo() : intrusive_ref_(), flat_field_() {} intrusive::Ref intrusive_ref_; FlatMsg flat_field_; }; TEST(intrusive, flat_msg_field) { auto obj = intrusive::make_shared(); ASSERT_TRUE(!obj->flat_field().has_int32_field()); obj->mut_flat_field()->set_int32_field(33); ASSERT_TRUE(obj->flat_field().has_int32_field()); ASSERT_EQ(obj->flat_field().int32_field(), 33); } // clang-format off REFLECTIVE_CLASS_BEGIN(TestIntrusiveField); TestIntrusiveField() = default; static_assert(REFLECTIVE_FIELD_COUNTER == 0, ""); static_assert(REFLECTIVE_FIELD_COUNTER == 0, ""); REFLECTIVE_CLASS_DEFINE_FIELD(int32_t, a); static_assert(REFLECTIVE_FIELD_COUNTER == 1, ""); static_assert(REFLECTIVE_FIELD_COUNTER == 1, ""); REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, b); static_assert(REFLECTIVE_FIELD_COUNTER == 2, ""); static_assert(REFLECTIVE_FIELD_COUNTER == 2, ""); REFLECTIVE_CLASS_DEFINE_FIELD(int8_t, c); static_assert(REFLECTIVE_FIELD_COUNTER == 3, ""); static_assert(REFLECTIVE_FIELD_COUNTER == 3, ""); REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, d); static_assert(REFLECTIVE_FIELD_COUNTER == 4, ""); static_assert(REFLECTIVE_FIELD_COUNTER == 4, ""); REFLECTIVE_CLASS_END(TestIntrusiveField); // clang-format on TEST(intrusive, intrusive_field_number) { static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, a) == 1, ""); static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, b) == 2, ""); static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, c) == 3, ""); static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, d) == 4, ""); } TEST(intrusive, intrusive_field_type) { static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); } TEST(intrusive, intrusive_field_offset) { static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 1) == 0, ""); static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 2) == 8, ""); static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 3) == 16, ""); static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 4) == 24, ""); } } // namespace } // namespace test } // namespace intrusive } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/list.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_LIST_H_ #define ONEFLOW_CORE_INTRUSIVE_LIST_H_ #include "oneflow/core/intrusive/ref.h" #include "oneflow/core/intrusive/list_hook.h" namespace oneflow { namespace intrusive { template class List { public: List(const List&) = delete; List(List&&) = delete; List() { this->__Init__(); } ~List() { this->Clear(); } using value_type = typename HookField::struct_type; using iterator_struct_field = HookField; template static constexpr int IteratorHookOffset() { return offsetof(List, list_head_) + intrusive::ListHead::IteratorHookOffset(); } std::size_t size() const { return list_head_.size(); } bool empty() const { return list_head_.empty(); } void CheckSize() const { list_head_.CheckSize(); } void __Init__() { list_head_.__Init__(); } value_type* Begin() { if (list_head_.empty()) { return nullptr; } return list_head_.Begin(); } value_type* Prev(value_type* ptr) { if (ptr == nullptr) { return nullptr; } value_type* prev = list_head_.Prev(ptr); if (prev == list_head_.End()) { return nullptr; } return prev; } value_type* Next(value_type* ptr) { if (ptr == nullptr) { return nullptr; } value_type* next = list_head_.Next(ptr); if (next == list_head_.End()) { return nullptr; } return next; } value_type* Last() { if (list_head_.empty()) { return nullptr; } return list_head_.Last(); } constexpr value_type* End() const { return nullptr; } void MoveToDstBack(value_type* ptr, List* dst) { list_head_.MoveToDstBack(ptr, &dst->list_head_); } void MoveToDstFront(value_type* ptr, List* dst) { list_head_.MoveToDstFront(ptr, &dst->list_head_); } value_type* MoveFrontToDstBack(List* dst) { value_type* begin = list_head_.Begin(); MoveToDstBack(begin, dst); return begin; } value_type* MoveBackToDstBack(List* dst) { value_type* begin = list_head_.Last(); MoveToDstBack(begin, dst); return begin; } void PushBack(value_type* ptr) { list_head_.PushBack(ptr); Ref::IncreaseRef(ptr); } void PushFront(value_type* ptr) { list_head_.PushFront(ptr); Ref::IncreaseRef(ptr); } void EmplaceBack(intrusive::shared_ptr&& ptr) { value_type* raw_ptr = nullptr; ptr.__UnsafeMoveTo__(&raw_ptr); list_head_.PushBack(raw_ptr); } void EmplaceFront(intrusive::shared_ptr&& ptr) { value_type* raw_ptr = nullptr; ptr.__UnsafeMoveTo__(&raw_ptr); list_head_.PushFront(raw_ptr); } intrusive::shared_ptr Erase(value_type* ptr) { list_head_.Erase(ptr); return intrusive::shared_ptr::__UnsafeMove__(ptr); } intrusive::shared_ptr PopBack() { value_type* raw_ptr = nullptr; if (!list_head_.empty()) { raw_ptr = list_head_.PopBack(); } return intrusive::shared_ptr::__UnsafeMove__(raw_ptr); } intrusive::shared_ptr PopFront() { value_type* raw_ptr = nullptr; if (!list_head_.empty()) { raw_ptr = list_head_.PopFront(); } return intrusive::shared_ptr::__UnsafeMove__(raw_ptr); } void MoveTo(List* list) { MoveToDstBack(list); } void MoveToDstBack(List* list) { list_head_.MoveToDstBack(&list->list_head_); } void Clear() { while (!empty()) { Ref::DecreaseRef(list_head_.PopFront()); } } private: intrusive::ListHead list_head_; }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_LIST_H_ 
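Taken together with list_test.cpp further down, the intended usage of intrusive::List is: the element type derives from intrusive::Base, embeds an intrusive::ListHook member, and the list is instantiated over a hook-field traits argument that names that member. The original template arguments are not preserved in this dump, so the minimal sketch below uses a placeholder HookFieldTraits for the real traits expression and only illustrates the reference-counting behaviour that list_test.cpp asserts (PushBack takes a reference, Erase/Pop hand it back).

// Sketch only, not from the repository. HookFieldTraits stands in for the real
// hook-field traits template argument, which this dump does not preserve.
// #include "oneflow/core/intrusive/intrusive.h"
class Item final : public intrusive::Base {
 public:
  int value() const { return value_; }
  void set_value(int v) { value_ = v; }

  intrusive::ListHook hook_;  // the hook the list threads through

 private:
  friend class intrusive::Ref;
  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
  Item() : hook_(), intrusive_ref_(), value_() {}

  intrusive::Ref intrusive_ref_;
  int value_;
};

// using ItemList = intrusive::List<HookFieldTraits>;  // HookFieldTraits names Item::hook_
// ItemList items;
// auto a = intrusive::make_shared<Item>();  // ref_cnt == 1
// items.PushBack(a.Mutable());              // the list takes a reference: ref_cnt == 2
// items.PopFront();                         // the reference is handed back: ref_cnt == 1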
================================================ FILE: oneflow/core/intrusive/list_hook.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_LIST_HOOK_H_ #define ONEFLOW_CORE_INTRUSIVE_LIST_HOOK_H_ #include "oneflow/core/intrusive/struct_traits.h" #include "oneflow/core/common/throw.h" namespace oneflow { namespace intrusive { struct ListHook { public: ListHook() { Clear(); } ListHook* prev() const { return prev_; } ListHook* next() const { return next_; } // NOLINT void __Init__() { Clear(); } void Clear() { prev_ = this; next_ = this; } bool empty() const { return prev_ == this || next_ == this; } void AppendTo(ListHook* prev) { prev->set_next(this); this->set_prev(prev); } void InsertAfter(ListHook* prev) { auto* next = prev->next(); this->AppendTo(prev); next->AppendTo(this); } void Erase() { next_->AppendTo(prev_); Clear(); } bool nullptr_empty() const { return prev_ == nullptr && next_ == nullptr; } void NullptrClear() { prev_ = nullptr; next_ = nullptr; } private: void set_prev(ListHook* prev) { prev_ = prev; } void set_next(ListHook* next) { next_ = next; } ListHook* prev_; ListHook* next_; }; #define LIST_HOOK_FOR_EACH(head_hook, elem_hook_struct_field, elem) \ LIST_HOOK_FOR_EACH_WITH_EXPR(head_hook, elem_hook_struct_field, elem, 0) #define LIST_HOOK_FOR_EACH_WITH_EXPR(head_hook, elem_hook_struct_field, elem, expr) \ for (typename elem_hook_struct_field::struct_type* elem = nullptr; elem == nullptr; \ elem = nullptr, elem++) \ LIST_HOOK_FOR_EACH_I(head_hook, __elem_hook__, \ ((elem = elem_hook_struct_field::StructPtr4FieldPtr(__elem_hook__)), expr)) #define LIST_HOOK_FOR_EACH_I(head_hook, elem_hook, expr) \ for (intrusive::ListHook* __head_hook__ = (head_hook), *elem_hook = __head_hook__->next(), \ *__next_hook__ = elem_hook->next(); \ (elem_hook != __head_hook__) && ((expr) || true); \ elem_hook = __next_hook__, __next_hook__ = __next_hook__->next()) #define LIST_HOOK_UNSAFE_FOR_EACH(head_hook, elem_hook_struct_field, elem) \ for (typename elem_hook_struct_field::struct_type* elem = nullptr; elem == nullptr; \ elem = nullptr, elem++) \ LIST_HOOK_UNSAFE_FOR_EACH_I(head_hook, __elem_hook__, \ (elem = elem_hook_struct_field::StructPtr4FieldPtr(__elem_hook__))) #define LIST_HOOK_UNSAFE_FOR_EACH_I(head_hook, elem_hook, expr) \ for (intrusive::ListHook* __head_hook__ = (head_hook), *elem_hook = __head_hook__->next(); \ (elem_hook != __head_hook__) && ((expr), true); elem_hook = elem_hook->next()) template class ListHead { public: ListHead() { Clear(); } using value_type = typename HookField::struct_type; static_assert(std::is_same::value, "no ListHook found"); template static constexpr int IteratorHookOffset() { return offsetof(ListHead, container_); } std::size_t size() const { return size_; } bool empty() const { bool list_empty = (&Begin() == &End()); bool size_empty = (size_ == 0); CHECK_EQ(list_empty, size_empty); return size_empty; } void CheckSize() const { size_t hook_size = 0; 
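// Walk the circular hook chain hanging off container_ and count the linked hooks;
// the cached size_ must agree with that count.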
for (ListHook* iter = container_.next(); iter != &container_; iter = iter->next()) { ++hook_size; } CHECK_EQ(size_, hook_size); } const value_type& Begin() const { return Next(End()); } const value_type& ReverseBegin() const { return Prev(End()); } const value_type& End() const { return *HookField::StructPtr4FieldPtr(&container()); } const value_type& Next(const value_type& current) const { return *HookField::StructPtr4FieldPtr(HookField::FieldPtr4StructPtr(¤t)->next()); } const value_type& Prev(const value_type& current) const { return *HookField::StructPtr4FieldPtr(HookField::FieldPtr4StructPtr(¤t)->prev()); } value_type* Begin() { return Next(End()); } value_type* Last() { return Prev(End()); } value_type* End() { return HookField::StructPtr4FieldPtr(mut_container()); } value_type* Next(value_type* current) { return HookField::StructPtr4FieldPtr(HookField::FieldPtr4StructPtr(current)->next()); } value_type* Prev(value_type* current) { return HookField::StructPtr4FieldPtr(HookField::FieldPtr4StructPtr(current)->prev()); } void __Init__() { Clear(); } void Clear() { container_.__Init__(); size_ = 0; } void Erase(value_type* elem) { CHECK_GT(size_, 0); CHECK_NE(elem, End()); ListHook* list_hook = HookField::FieldPtr4StructPtr(elem); CHECK(!list_hook->empty()); list_hook->Erase(); --size_; } void MoveToDstBack(value_type* elem, ListHead* dst) { CHECK(!container_.empty()); auto* dst_rbegin = dst->container_.prev(); auto* dst_end = &dst->container_; ListHook* elem_hook = HookField::FieldPtr4StructPtr(elem); elem_hook->next()->AppendTo(elem_hook->prev()); elem_hook->AppendTo(dst_rbegin); dst_end->AppendTo(elem_hook); --size_; ++dst->size_; } void MoveToDstFront(value_type* elem, ListHead* dst) { CHECK(!container_.empty()); auto* dst_end = &dst->container_; auto* dst_begin = dst->container_.next(); ListHook* elem_hook = HookField::FieldPtr4StructPtr(elem); elem_hook->next()->AppendTo(elem_hook->prev()); elem_hook->AppendTo(dst_end); dst_begin->AppendTo(elem_hook); --size_; ++dst->size_; } void PushBack(value_type* elem) { InsertAfter(Last(), elem); } void PushFront(value_type* elem) { InsertAfter(End(), elem); } value_type* PopBack() { CHECK(!empty()); value_type* last = Last(); Erase(last); return last; } value_type* PopFront() { CHECK(!empty()); value_type* first = Begin(); Erase(first); return first; } void MoveToDstBack(ListHead* dst) { if (container_.empty()) { return; } auto* dst_last = dst->container_.prev(); auto* dst_end = &dst->container_; auto* this_first = container_.next(); auto* this_last = container_.prev(); this_first->AppendTo(dst_last); dst_end->AppendTo(this_last); dst->size_ += size(); this->Clear(); } private: void InsertAfter(value_type* prev_elem, value_type* new_elem) { ListHook* prev_list_hook = HookField::FieldPtr4StructPtr(prev_elem); ListHook* next_list_hook = prev_list_hook->next(); ListHook* new_list_hook = HookField::FieldPtr4StructPtr(new_elem); CHECK(new_list_hook->empty()); new_list_hook->AppendTo(prev_list_hook); next_list_hook->AppendTo(new_list_hook); ++size_; } const ListHook& container() const { return container_; } ListHook* mut_container() { return &container_; } private: ListHook container_; volatile std::size_t size_; }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_LIST_HOOK_H_ ================================================ FILE: oneflow/core/intrusive/list_hook_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // include sstream first to avoid some compiling error // caused by the following trick // reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899 #include #include "gtest/gtest.h" #define private public #include "oneflow/core/intrusive/list_hook.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace intrusive { namespace test { struct ListItemBar final { ListItemBar() : value() { bar_list.__Init__(); } int value; ListHook bar_list; }; class TestListHook final : public ListHook { public: TestListHook() { this->__Init__(); } }; template class TestListHead : public intrusive::ListHead { public: TestListHead() { this->__Init__(); } }; using BarListHead = TestListHead; TEST(TestListHook, init) { TestListHook list_iterator; ASSERT_EQ(&list_iterator, list_iterator.prev()); ASSERT_EQ(&list_iterator, list_iterator.next()); } TEST(TestListHook, append_to) { TestListHook list_iter0; TestListHook list_iter1; list_iter1.AppendTo(&list_iter0); ASSERT_EQ(&list_iter0, list_iter1.prev()); ASSERT_EQ(&list_iter1, list_iter0.next()); } TEST(TestListHook, clear) { TestListHook list_head0; TestListHook list_head1; list_head1.AppendTo(&list_head0); list_head1.__Init__(); ASSERT_EQ(&list_head1, list_head1.prev()); ASSERT_EQ(&list_head1, list_head1.next()); } TEST(ListHead, empty) { BarListHead list_head; ASSERT_TRUE(list_head.empty()); } TEST(ListHead, push_front) { BarListHead list_head; ListHook& head = list_head.container_; ListItemBar item0; list_head.PushFront(&item0); ASSERT_EQ(head.next(), &item0.bar_list); ASSERT_EQ(head.prev(), &item0.bar_list); ASSERT_EQ(item0.bar_list.next(), &head); ASSERT_EQ(item0.bar_list.prev(), &head); ListItemBar item1; list_head.PushFront(&item1); ASSERT_EQ(head.next(), &item1.bar_list); ASSERT_EQ(item1.bar_list.prev(), &head); ASSERT_EQ(item1.bar_list.next(), &item0.bar_list); ASSERT_EQ(item0.bar_list.prev(), &item1.bar_list); ASSERT_EQ(item0.bar_list.next(), &head); ASSERT_EQ(head.prev(), &item0.bar_list); } TEST(ListHead, end) { BarListHead list_head; ListItemBar* end_item = list_head.End(); ListItemBar item0; list_head.PushFront(&item0); ASSERT_EQ(end_item, list_head.End()); } TEST(ListHead, begin) { BarListHead list_head; ASSERT_EQ(list_head.Begin(), list_head.End()); ListItemBar item0; list_head.PushFront(&item0); ASSERT_EQ(list_head.Begin(), &item0); ListItemBar item1; list_head.PushFront(&item1); ASSERT_EQ(list_head.Begin(), &item1); } TEST(ListHead, last) { BarListHead list_head; ASSERT_EQ(list_head.Begin(), list_head.End()); ListItemBar item0; list_head.PushFront(&item0); ASSERT_EQ(list_head.Last(), &item0); ListItemBar item1; list_head.PushFront(&item1); ASSERT_EQ(list_head.Last(), &item0); } TEST(ListHead, push_back) { BarListHead list_head; ASSERT_EQ(list_head.Begin(), list_head.End()); ListItemBar item0; list_head.PushBack(&item0); ASSERT_EQ(list_head.Last(), &item0); ListItemBar item1; list_head.PushBack(&item1); ASSERT_EQ(list_head.Last(), &item1); } TEST(ListHead, erase) { BarListHead list_head; ASSERT_EQ(list_head.Begin(), 
list_head.End()); ListItemBar item0; list_head.PushBack(&item0); ASSERT_EQ(list_head.Last(), &item0); ListItemBar item1; list_head.PushBack(&item1); ASSERT_EQ(list_head.Last(), &item1); list_head.Erase(&item0); ASSERT_EQ(list_head.Last(), &item1); ASSERT_EQ(list_head.Begin(), &item1); ASSERT_EQ(item0.bar_list.prev(), &item0.bar_list); ASSERT_EQ(item0.bar_list.next(), &item0.bar_list); } TEST(ListHead, pop_front) { BarListHead list_head; ASSERT_EQ(list_head.Begin(), list_head.End()); ListItemBar item0; list_head.PushBack(&item0); ASSERT_EQ(list_head.Last(), &item0); ListItemBar item1; list_head.PushBack(&item1); ASSERT_EQ(list_head.Last(), &item1); list_head.PopFront(); ASSERT_EQ(list_head.Last(), &item1); ASSERT_EQ(list_head.Begin(), &item1); ASSERT_EQ(item0.bar_list.prev(), &item0.bar_list); ASSERT_EQ(item0.bar_list.next(), &item0.bar_list); } TEST(ListHead, pop_back) { BarListHead list_head; ASSERT_EQ(list_head.Begin(), list_head.End()); ListItemBar item0; list_head.PushBack(&item0); ASSERT_EQ(list_head.Last(), &item0); ListItemBar item1; list_head.PushBack(&item1); ASSERT_EQ(list_head.Last(), &item1); list_head.PopBack(); ASSERT_EQ(list_head.Last(), &item0); ASSERT_EQ(list_head.Begin(), &item0); ASSERT_EQ(item1.bar_list.prev(), &item1.bar_list); ASSERT_EQ(item1.bar_list.next(), &item1.bar_list); } TEST(ListHead, Next) { BarListHead list_head; ListItemBar item0; list_head.PushBack(&item0); ListItemBar item1; list_head.PushBack(&item1); ListItemBar* item = list_head.Begin(); ASSERT_EQ(item, &item0); item = list_head.Next(item); ASSERT_EQ(item, &item1); item = list_head.Next(item); ASSERT_EQ(item, list_head.End()); item = list_head.Next(item); ASSERT_EQ(item, &item0); } TEST(ListHead, prev_item) { BarListHead list_head; ListItemBar item0; list_head.PushBack(&item0); ListItemBar item1; list_head.PushBack(&item1); ListItemBar* item = list_head.Begin(); ASSERT_EQ(item, &item0); item = list_head.Prev(item); ASSERT_EQ(item, list_head.End()); item = list_head.Prev(item); ASSERT_EQ(item, &item1); item = list_head.Prev(item); ASSERT_EQ(item, &item0); } } // namespace test } // namespace intrusive } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/list_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ // include sstream first to avoid some compiling error // caused by the following trick // reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899 #include #include "gtest/gtest.h" #define private public #include "oneflow/core/common/util.h" #include "oneflow/core/intrusive/intrusive.h" namespace oneflow { namespace test { namespace { class TestListItem : public intrusive::Base { public: void __Init__() { clear_cnt(); } void __Delete__() { if (has_cnt()) { --*mut_cnt(); } } // Getters bool has_cnt() const { return cnt_ != nullptr; } int cnt() const { return *cnt_; } bool is_foo_list_empty() const { return foo_list_.empty(); } // Setters void set_cnt(int* val) { cnt_ = val; } void clear_cnt() { cnt_ = nullptr; } int* mut_cnt() { return cnt_; } size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); } intrusive::ListHook foo_list_; private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } TestListItem() : foo_list_(), intrusive_ref_(), cnt_() {} intrusive::Ref intrusive_ref_; int* cnt_; }; using TestList = intrusive::List; TEST(List, empty) { TestList foo_list; ASSERT_TRUE(foo_list.empty()); ASSERT_EQ(foo_list.size(), 0); } TEST(List, empty_Begin) { TestList foo_list; intrusive::shared_ptr obj_ptr; obj_ptr = foo_list.Begin(); ASSERT_TRUE(!obj_ptr); intrusive::shared_ptr next; obj_ptr = foo_list.Begin(); next = foo_list.Next(obj_ptr.Mutable()); ASSERT_TRUE(!obj_ptr); } TEST(List, empty_Next) { TestList foo_list; intrusive::shared_ptr obj_ptr; intrusive::shared_ptr next; obj_ptr = foo_list.Begin(); next = foo_list.Next(obj_ptr.Mutable()); ASSERT_TRUE(!obj_ptr); ASSERT_TRUE(!next); obj_ptr = foo_list.Next(obj_ptr.Mutable()); ASSERT_TRUE(!obj_ptr); obj_ptr = next; next = foo_list.Next(next.Mutable()); ASSERT_TRUE(!obj_ptr); ASSERT_TRUE(!next); } TEST(List, PushFront) { TestList foo_list; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushFront(item0.Mutable()); foo_list.PushFront(item1.Mutable()); intrusive::shared_ptr obj_ptr; intrusive::shared_ptr next; obj_ptr = foo_list.Begin(); next = foo_list.Next(obj_ptr.Mutable()); ASSERT_TRUE(obj_ptr == item1); ASSERT_TRUE(next == item0); } TEST(List, destructor) { int elem_cnt = 2; { TestList foo_list; auto item0 = intrusive::make_shared(); item0->set_cnt(&elem_cnt); auto item1 = intrusive::make_shared(); item1->set_cnt(&elem_cnt); foo_list.PushFront(item0.Mutable()); foo_list.PushFront(item1.Mutable()); } ASSERT_EQ(elem_cnt, 0); elem_cnt = 2; auto item0 = intrusive::make_shared(); { TestList foo_list; item0->set_cnt(&elem_cnt); auto item1 = intrusive::make_shared(); item1->set_cnt(&elem_cnt); foo_list.PushFront(item0.Mutable()); foo_list.PushFront(item1.Mutable()); } ASSERT_EQ(elem_cnt, 1); } TEST(List, PushBack) { TestList foo_list; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); intrusive::shared_ptr obj_ptr; intrusive::shared_ptr next; obj_ptr = foo_list.Begin(); next = foo_list.Next(obj_ptr.Mutable()); ASSERT_TRUE(obj_ptr == item0); ASSERT_TRUE(next == item1); } TEST(List, Erase) { TestList foo_list; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); ASSERT_EQ(item1->ref_cnt(), 2); foo_list.Erase(item1.Mutable()); ASSERT_EQ(item1->ref_cnt(), 1); intrusive::shared_ptr obj_ptr; intrusive::shared_ptr next; obj_ptr = foo_list.Begin(); next = 
foo_list.Next(obj_ptr.Mutable()); ASSERT_TRUE(obj_ptr == item0); ASSERT_TRUE(!next); } TEST(List, PopBack) { TestList foo_list; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); ASSERT_EQ(item1->ref_cnt(), 2); foo_list.PopBack(); ASSERT_EQ(item1->ref_cnt(), 1); intrusive::shared_ptr obj_ptr; intrusive::shared_ptr next; obj_ptr = foo_list.Begin(); next = foo_list.Next(obj_ptr.Mutable()); ASSERT_TRUE(obj_ptr == item0); ASSERT_TRUE(!next); } TEST(List, PopFront) { TestList foo_list; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); ASSERT_EQ(item0->ref_cnt(), 2); foo_list.PopFront(); ASSERT_EQ(item0->ref_cnt(), 1); intrusive::shared_ptr obj_ptr; intrusive::shared_ptr next; obj_ptr = foo_list.Begin(); next = foo_list.Next(obj_ptr.Mutable()); ASSERT_TRUE(!next); } TEST(List, Clear) { TestList foo_list; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); ASSERT_EQ(item0->ref_cnt(), 2); ASSERT_EQ(item1->ref_cnt(), 2); foo_list.Clear(); ASSERT_TRUE(foo_list.empty()); ASSERT_EQ(item0->ref_cnt(), 1); ASSERT_EQ(item1->ref_cnt(), 1); } TEST(List, UNSAFE_FOR_EACH_PTR) { TestList foo_list; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); int i = 0; INTRUSIVE_UNSAFE_FOR_EACH_PTR(item, &foo_list) { if (i == 0) { ASSERT_TRUE(item == item0.Mutable()); } else if (i == 1) { ASSERT_TRUE(item == item1.Mutable()); } ++i; } ASSERT_EQ(i, 2); } TEST(List, FOR_EACH) { TestList foo_list; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); int i = 0; INTRUSIVE_FOR_EACH(item, &foo_list) { if (i == 0) { ASSERT_TRUE(item == item0); foo_list.Erase(item.Mutable()); } else if (i == 1) { ASSERT_TRUE(item == item1); foo_list.Erase(item.Mutable()); } ++i; } ASSERT_EQ(i, 2); ASSERT_TRUE(foo_list.empty()); ASSERT_EQ(item0->ref_cnt(), 1); ASSERT_EQ(item1->ref_cnt(), 1); } class TestIntrusiveListHead final : public intrusive::Base { public: // types using FooList = intrusive::List; // Getters const FooList& foo_list() const { return foo_list_; } // Setters FooList* mut_foo_list() { return &foo_list_; } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } TestIntrusiveListHead() : intrusive_ref_(), foo_list_() {} intrusive::Ref intrusive_ref_; FooList foo_list_; }; TEST(List, intrusive_list_for_each) { auto foo_list_head = intrusive::make_shared(); auto& foo_list = *foo_list_head->mut_foo_list(); auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); ASSERT_EQ(item0->ref_cnt(), 2); ASSERT_EQ(item1->ref_cnt(), 2); int i = 0; INTRUSIVE_FOR_EACH(item, &foo_list) { if (i == 0) { ASSERT_TRUE(item == item0); foo_list.Erase(item.Mutable()); } else if (i == 1) { ASSERT_TRUE(item == item1); foo_list.Erase(item.Mutable()); } ++i; } ASSERT_EQ(i, 2); ASSERT_TRUE(foo_list.empty()); ASSERT_EQ(item0->ref_cnt(), 1); ASSERT_EQ(item1->ref_cnt(), 1); } class TestIntrusiveListHeadWrapper final : public intrusive::Base { public: // Getters const TestIntrusiveListHead& head() const { if (head_) { return head_.Get(); } static 
const auto default_val = intrusive::make_shared(); return default_val.Get(); } // Setters TestIntrusiveListHead* mut_head() { if (!head_) { head_ = intrusive::make_shared(); } return head_.Mutable(); } void clear_head() { if (head_) { head_.Reset(); } } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } TestIntrusiveListHeadWrapper() : intrusive_ref_(), head_() {} intrusive::Ref intrusive_ref_; intrusive::shared_ptr head_; }; TEST(List, nested_list_delete) { auto foo_list_head = intrusive::make_shared(); auto& foo_list = *foo_list_head->mut_head()->mut_foo_list(); auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); ASSERT_EQ(item0->ref_cnt(), 2); ASSERT_EQ(item1->ref_cnt(), 2); int i = 0; INTRUSIVE_UNSAFE_FOR_EACH_PTR(item, &foo_list) { if (i == 0) { ASSERT_TRUE(item == item0.Mutable()); } else if (i == 1) { ASSERT_TRUE(item == item1.Mutable()); } ++i; } ASSERT_EQ(i, 2); foo_list_head->clear_head(); ASSERT_EQ(item0->ref_cnt(), 1); ASSERT_EQ(item1->ref_cnt(), 1); } TEST(List, MoveTo) { TestList foo_list; TestList foo_list0; auto item0 = intrusive::make_shared(); auto item1 = intrusive::make_shared(); ASSERT_EQ(item0->is_foo_list_empty(), true); ASSERT_EQ(item1->is_foo_list_empty(), true); foo_list.PushBack(item0.Mutable()); foo_list.PushBack(item1.Mutable()); ASSERT_EQ(item0->is_foo_list_empty(), false); ASSERT_EQ(item1->is_foo_list_empty(), false); ASSERT_EQ(foo_list.size(), 2); ASSERT_EQ(foo_list0.empty(), true); ASSERT_EQ(item0->ref_cnt(), 2); ASSERT_EQ(item1->ref_cnt(), 2); foo_list.MoveTo(&foo_list0); ASSERT_EQ(foo_list0.size(), 2); ASSERT_EQ(foo_list.empty(), true); ASSERT_TRUE(foo_list0.Begin() == item0.Mutable()); ASSERT_TRUE(foo_list0.Last() == item1.Mutable()); ASSERT_EQ(item0->ref_cnt(), 2); ASSERT_EQ(item1->ref_cnt(), 2); } } // namespace } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/mutexed_list.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_INTRUSIVE_MUTEXED_LIST_H_ #define ONEFLOW_CORE_INTRUSIVE_MUTEXED_LIST_H_ #include #include "oneflow/core/intrusive/list.h" namespace oneflow { namespace intrusive { template class MutexedList { public: using value_type = typename HookField::struct_type; using list_type = List; MutexedList(const MutexedList&) = delete; MutexedList(MutexedList&&) = delete; explicit MutexedList(std::mutex* mutex) { this->__Init__(mutex); } ~MutexedList() { this->Clear(); } std::size_t thread_unsafe_size() const { return list_head_.size(); } std::size_t size() const { std::unique_lock lock(*mutex_); return list_head_.size(); } bool empty() const { std::unique_lock lock(*mutex_); return list_head_.empty(); } void __Init__(std::mutex* mutex) { list_head_.__Init__(); mutex_ = mutex; } void EmplaceBack(intrusive::shared_ptr&& ptr) { std::unique_lock lock(*mutex_); return list_head_.EmplaceBack(std::move(ptr)); } void EmplaceFront(intrusive::shared_ptr&& ptr) { std::unique_lock lock(*mutex_); return list_head_.EmplaceFront(std::move(ptr)); } void PushBack(value_type* ptr) { EmplaceBack(intrusive::shared_ptr(ptr)); } void PushFront(value_type* ptr) { EmplaceFront(intrusive::shared_ptr(ptr)); } intrusive::shared_ptr PopBack() { std::unique_lock lock(*mutex_); return list_head_.PopBack(); } intrusive::shared_ptr PopFront() { std::unique_lock lock(*mutex_); return list_head_.PopFront(); } // Returns true if old list is empty. bool MoveFrom(list_type* src) { std::unique_lock lock(*mutex_); return ThreadUnsafeMoveFrom(src); } // Returns true if old list is empty. bool ThreadUnsafeMoveFrom(list_type* src) { bool old_list_empty = list_head_.empty(); src->MoveToDstBack(&list_head_); return old_list_empty; } void MoveTo(list_type* dst) { std::unique_lock lock(*mutex_); list_head_.MoveToDstBack(dst); } void ThreadUnsafeMoveTo(list_type* dst) { list_head_.MoveToDstBack(dst); } void Clear() { std::unique_lock lock(*mutex_); list_head_.Clear(); } private: list_type list_head_; std::mutex* mutex_; }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_MUTEXED_LIST_H_ ================================================ FILE: oneflow/core/intrusive/object_pool.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
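MutexedList above is a thin wrapper: it does not own its mutex, every public operation takes a std::unique_lock on the caller-supplied mutex, and MoveFrom/MoveTo splice an entire list in or out under a single lock. A standalone sketch of that pattern with std::list instead of the intrusive containers (class and method names here are illustrative):

#include <list>
#include <mutex>

template<typename T>
class MutexedQueue {  // illustrative name, not the OneFlow class
 public:
  explicit MutexedQueue(std::mutex* mutex) : mutex_(mutex) {}

  void PushBack(T value) {
    std::unique_lock<std::mutex> lock(*mutex_);
    items_.push_back(std::move(value));
  }

  // Splices the whole source list in under one lock and reports whether this
  // queue was empty beforehand, mirroring MutexedList::MoveFrom().
  bool MoveFrom(std::list<T>* src) {
    std::unique_lock<std::mutex> lock(*mutex_);
    bool was_empty = items_.empty();
    items_.splice(items_.end(), *src);
    return was_empty;
  }

 private:
  std::list<T> items_;
  std::mutex* mutex_;  // owned by the caller, exactly as in MutexedList
};

int main() {
  std::mutex mu;
  MutexedQueue<int> queue(&mu);
  queue.PushBack(1);
  std::list<int> pending{2, 3};
  queue.MoveFrom(&pending);  // both elements move in under a single lock
  return 0;
}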
*/ #ifndef ONEFLOW_CORE_INTRUSIVE_OBJECT_POOL_H_ #define ONEFLOW_CORE_INTRUSIVE_OBJECT_POOL_H_ #include #include "oneflow/core/intrusive/cpp_attribute.h" namespace oneflow { namespace intrusive { enum ObjectPoolStrategey { kThreadUnsafeAndDisableDestruct, }; template class ObjectPool; template class EnableObjectPool { public: EnableObjectPool() = default; EnableObjectPool(const EnableObjectPool&) = default; EnableObjectPool(EnableObjectPool&&) = default; ~EnableObjectPool() = default; using object_pool_type = ObjectPool; object_pool_type* mut_object_pool() { return object_pool_; } void set_object_pool(object_pool_type* val) { object_pool_ = val; } private: object_pool_type* object_pool_; }; template class ObjectPool { public: ObjectPool() { container_.reserve(kObjectPoolInitCap); } ObjectPool(const ObjectPool&) = delete; ObjectPool(ObjectPool&&) = delete; ~ObjectPool() { for (auto* elem : container_) { delete elem; } } template intrusive::shared_ptr make_shared(Args&&... args) { if (INTRUSIVE_PREDICT_FALSE(container_.empty())) { auto ptr = intrusive::make_shared(std::forward(args)...); InitObjectPoolFields4Element(ptr.get()); return ptr; } else { auto* ptr = container_.back(); container_.pop_back(); ptr->__Init__(std::forward(args)...); InitObjectPoolFields4Element(ptr); return intrusive::shared_ptr(ptr); } } static void Put(void* raw_ptr) { T* ptr = reinterpret_cast(raw_ptr); ptr->mut_object_pool()->container_.push_back(ptr); } private: inline void InitObjectPoolFields4Element(T* ptr) { ptr->set_object_pool(this); ptr->mut_intrusive_ref()->set_deleter(&ObjectPool::Put); } static constexpr int kObjectPoolInitCap = 65536; std::vector container_; }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_OBJECT_POOL_H_ ================================================ FILE: oneflow/core/intrusive/object_pool_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
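ObjectPool above keeps a vector of retired objects and registers ObjectPool::Put as the element's deleter, so releasing the last reference returns the object to the pool instead of freeing it; make_shared then reuses pooled storage when available. A standalone sketch of the same free-list idea (not the OneFlow API):

#include <cassert>
#include <vector>

template<typename T>
class SimplePool {  // illustrative stand-in, not the OneFlow ObjectPool
 public:
  ~SimplePool() {
    for (T* p : free_list_) { delete p; }
  }
  T* Get() {
    if (free_list_.empty()) { return new T(); }
    T* p = free_list_.back();
    free_list_.pop_back();
    return p;
  }
  void Put(T* p) { free_list_.push_back(p); }  // recycle instead of delete

 private:
  std::vector<T*> free_list_;
};

int main() {
  SimplePool<int> pool;
  int* a = pool.Get();
  pool.Put(a);             // "freeing" returns the storage to the pool
  int* b = pool.Get();
  assert(a == b);          // the same storage is handed out again,
  pool.Put(b);             // which is what the tests that follow assert
  return 0;
}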
*/ #include #include "gtest/gtest.h" #define private public #include "oneflow/core/common/util.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/intrusive/object_pool.h" namespace oneflow { namespace intrusive { namespace test { namespace { class IntrusiveFoo final // NOLINT : public intrusive::Base, public intrusive::EnableObjectPool { // NOLINT public: IntrusiveFoo() = default; // NOLINT intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } private: intrusive::Ref intrusive_ref_; }; TEST(ObjectPool_kThreadUnsafeAndDisableDestruct, append_to_pool) { ObjectPool object_pool; IntrusiveFoo* ptr = nullptr; { ptr = object_pool.make_shared().get(); } ASSERT_EQ(ptr, object_pool.make_shared().get()); } TEST(ObjectPool_kThreadUnsafeAndDisableDestruct, recycle) { ObjectPool object_pool; auto* ptr = object_pool.make_shared().get(); ASSERT_EQ(ptr, object_pool.make_shared().get()); } } // namespace } // namespace test } // namespace intrusive } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/ref.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_REF_H_ #define ONEFLOW_CORE_INTRUSIVE_REF_H_ #include #include "oneflow/core/common/throw.h" #include "oneflow/core/intrusive/cpp_attribute.h" namespace oneflow { namespace intrusive { class Ref { public: Ref() : ref_cnt_(), deleter_(nullptr) {} using RefCntType = int32_t; RefCntType ref_cnt() const { return ref_cnt_; } template static void NewAndInitRef(T** ptr) { *ptr = new T(); (*ptr)->mut_intrusive_ref()->InitRefCount(); IncreaseRef(*ptr); } template static void IncreaseRef(T* ptr) { ptr->mut_intrusive_ref()->IncreaseRefCount(); } template static void DecreaseRef(T* ptr) { CHECK_NOTNULL(ptr); auto* ref = ptr->mut_intrusive_ref(); if (INTRUSIVE_PREDICT_TRUE(ref->DecreaseRefCount() > 0)) { return; } if (INTRUSIVE_PREDICT_TRUE(ref->deleter_ == nullptr)) { ptr->__Delete__(); delete ptr; } else { ref->deleter_(ptr); } } void set_deleter(void (*deleter)(void*)) { deleter_ = deleter; } private: void InitRefCount() { ref_cnt_ = 0; } void IncreaseRefCount() { ref_cnt_++; } RefCntType DecreaseRefCount() { return --ref_cnt_; } std::atomic ref_cnt_; void (*deleter_)(void*); }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_REF_H_ ================================================ FILE: oneflow/core/intrusive/reflective.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
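intrusive::Ref above implements the core counting logic: DecreaseRef destroys the object when the count reaches zero, either with delete or through a registered deleter (which is how ObjectPool recycles elements). A minimal standalone sketch of that mechanism, illustrative rather than the real class:

#include <atomic>
#include <cstdint>

struct RefCounted {                   // illustrative analog of intrusive::Ref
  std::atomic<int32_t> ref_cnt{0};
  void (*deleter)(void*) = nullptr;
};

struct Widget {
  RefCounted ref;
};

void IncreaseRef(Widget* w) { ++w->ref.ref_cnt; }

void DecreaseRef(Widget* w) {
  if (--w->ref.ref_cnt > 0) { return; }   // still referenced elsewhere
  if (w->ref.deleter == nullptr) {
    delete w;                             // default: destroy the object
  } else {
    w->ref.deleter(w);                    // e.g. an object pool recycles it
  }
}

int main() {
  Widget* w = new Widget();
  IncreaseRef(w);
  IncreaseRef(w);
  DecreaseRef(w);  // one reference remains
  DecreaseRef(w);  // count hits zero, the object is deleted
  return 0;
}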
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_ #define ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_ #include "oneflow/core/intrusive/dss.h" #include "oneflow/core/intrusive/static_counter.h" #include "oneflow/core/intrusive/struct_traits.h" #include "oneflow/core/intrusive/base.h" namespace oneflow { #define REFLECTIVE_CLASS_BEGIN(class_name) \ struct class_name final : public intrusive::Base { \ public: \ using self_type = class_name; \ static const bool __has_intrusive_ref__ = true; \ \ private: \ DEFINE_STATIC_COUNTER(field_counter); \ DSS_BEGIN(STATIC_COUNTER(field_counter), class_name); #define REFLECTIVE_CLASS_END(class_name) \ static_assert(__has_intrusive_ref__, "this class is not intrusive-referenced"); \ \ public: \ static const int __NumberOfFields__ = STATIC_COUNTER(field_counter); \ \ private: \ INCREASE_STATIC_COUNTER(field_counter); \ DSS_END(STATIC_COUNTER(field_counter), "intrusive-referenced class", class_name); \ } \ ; #define REFLECTIVE_CLASS_DEFINE_FIELD(field_type, field_name) \ static_assert(__has_intrusive_ref__, "this class is not intrusive-referenced"); \ field_type field_name; \ INCREASE_STATIC_COUNTER(field_counter); \ DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), "intrusive-referenced class", field_type, \ field_name); #define REFLECTIVE_FIELD(struct_type, field_name) \ intrusive::OffsetStructField // Get field number by field name // note: field numbers start from 1 instead of 0. #define REFLECTIVE_FIELD_NUMBER(cls, field_name) cls::OF_PP_CAT(field_name, kDssFieldNumber) // Get field type by field number #define REFLECTIVE_FIELD_TYPE(cls, field_number) cls::template __DssFieldType__::type // Get field offset by field number #define REFLECTIVE_FIELD_OFFSET(cls, field_number) \ cls::template __DssFieldOffset4FieldIndex__::value // Get current defined field counter inside a intrusive-referenced class. // note: not used outside REFLECTIVE_CLASS_BEGIN ... REFLECTIVE_CLASS_END // e.g.: // REFLECTIVE_CLASS_BEGIN(Foo); // static_assert(REFLECTIVE_FIELD_COUNTER == 0, ""); // REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, a); // static_assert(REFLECTIVE_FIELD_COUNTER == 1, ""); // REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, b); // static_assert(REFLECTIVE_FIELD_COUNTER == 2, ""); // REFLECTIVE_CLASS_DEFINE_FIELD(int8_t, c); // static_assert(REFLECTIVE_FIELD_COUNTER == 3, ""); // REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, d); // REFLECTIVE_CLASS_END(Foo); #define REFLECTIVE_FIELD_COUNTER STATIC_COUNTER(field_counter) } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_ ================================================ FILE: oneflow/core/intrusive/shared_ptr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
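The REFLECTIVE_* macros above use the static counter and DSS machinery to give every field a 1-based number and to expose its type and offset at compile time. A hand-written standalone sketch of the end result, with the per-field reflection table spelled out manually instead of generated by macros (the names are illustrative):

#include <cstddef>
#include <cstdio>

struct Foo {
  long a;
  char c;
};

// Hand-rolled records analogous to what the macros generate per field.
template<int field_number> struct FooField;
template<> struct FooField<1> {
  using type = long;
  static constexpr std::size_t offset = offsetof(Foo, a);
};
template<> struct FooField<2> {
  using type = char;
  static constexpr std::size_t offset = offsetof(Foo, c);
};

int main() {
  Foo foo{42, 'x'};
  // Generic access through the (field number -> type, offset) table.
  auto* a = reinterpret_cast<FooField<1>::type*>(
      reinterpret_cast<char*>(&foo) + FooField<1>::offset);
  std::printf("field 1 = %ld\n", *a);
  return 0;
}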
*/ #ifndef ONEFLOW_CORE_INTRUSIVE_SHARED_PTR_H_ #define ONEFLOW_CORE_INTRUSIVE_SHARED_PTR_H_ #include "oneflow/core/intrusive/ref.h" namespace oneflow { namespace intrusive { template class shared_ptr final { public: using value_type = T; shared_ptr() : ptr_(nullptr) {} shared_ptr(value_type* ptr) : ptr_(nullptr) { Reset(ptr); } shared_ptr(const shared_ptr& obj_ptr) { ptr_ = nullptr; Reset(obj_ptr.ptr_); } shared_ptr(shared_ptr&& obj_ptr) noexcept { ptr_ = obj_ptr.ptr_; obj_ptr.ptr_ = nullptr; } // NOLINTNEXTLINE(google-explicit-constructor) operator shared_ptr() const { return shared_ptr(ptr_); } ~shared_ptr() { Clear(); } template static shared_ptr make_shared(Args&&... args) { shared_ptr ret; Ref::NewAndInitRef(&ret.ptr_); ret.Mutable()->__Init__(std::forward(args)...); return ret; } explicit operator bool() const { return ptr_ != nullptr; } value_type* get() const { return ptr_; } const value_type& Get() const { return *ptr_; } value_type* operator->() const { return ptr_; } value_type& operator*() const { return *ptr_; } bool operator==(const shared_ptr& rhs) const { return this->ptr_ == rhs.ptr_; } value_type* Mutable() { return ptr_; } void Reset() { Reset(nullptr); } void Reset(value_type* ptr) { Clear(); if (ptr == nullptr) { return; } ptr_ = ptr; Ref::IncreaseRef(ptr_); } shared_ptr& operator=(const shared_ptr& rhs) { Reset(rhs.ptr_); return *this; } shared_ptr& operator=(shared_ptr&& rhs) noexcept { ptr_ = rhs.ptr_; rhs.ptr_ = nullptr; return *this; } static shared_ptr __UnsafeMove__(value_type* ptr) { shared_ptr ret; ret.ptr_ = ptr; return ret; } void __UnsafeMoveTo__(value_type** ptr) { *ptr = ptr_; ptr_ = nullptr; } private: void Clear() { if (ptr_ == nullptr) { return; } Ref::DecreaseRef(ptr_); ptr_ = nullptr; } mutable value_type* ptr_; }; template shared_ptr make_shared(Args&&... args) { return shared_ptr::make_shared(std::forward(args)...); } } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_SHARED_PTR_H_ ================================================ FILE: oneflow/core/intrusive/skiplist.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_SKIPLIST_H_ #define ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_SKIPLIST_H_ #include "oneflow/core/intrusive/ref.h" #include "oneflow/core/intrusive/skiplist_hook.h" namespace oneflow { namespace intrusive { template class SkipList { public: SkipList(const SkipList&) = delete; SkipList(SkipList&&) = delete; SkipList() { this->__Init__(); } ~SkipList() { this->Clear(); } using value_type = typename ElemKeyField::struct_type; using key_type = typename ElemKeyField::field_type::key_type; using elem_key_level0_hook_struct_field = OffsetStructField; using iterator_struct_field = ComposeStructField; template static constexpr int IteratorHookOffset() { return offsetof(SkipList, skiplist_head_) + intrusive::SkipListHead::IteratorHookOffset(); } void __Init__() { skiplist_head_.__Init__(); } std::size_t size() const { return skiplist_head_.size(); } bool empty() const { return skiplist_head_.empty(); } value_type* Begin() { return skiplist_head_.Begin(); } intrusive::shared_ptr Find(const key_type& key) { intrusive::shared_ptr ret; ret.Reset(skiplist_head_.Find(key)); return ret; } value_type* FindPtr(const key_type& key) { return skiplist_head_.Find(key); } const value_type* FindPtr(const key_type& key) const { return skiplist_head_.Find(key); } bool EqualsEnd(const intrusive::shared_ptr& ptr) { return !ptr; } void Erase(const key_type& key) { Ref::DecreaseRef(skiplist_head_.Erase(key)); } void Erase(value_type* elem_ptr) { skiplist_head_.Erase(elem_ptr); Ref::DecreaseRef(elem_ptr); } std::pair, bool> Insert(value_type* elem_ptr) { value_type* ret_elem = nullptr; bool success = false; std::tie(ret_elem, success) = skiplist_head_.Insert(elem_ptr); std::pair, bool> ret; ret.first.Reset(ret_elem); ret.second = success; if (success) { Ref::IncreaseRef(elem_ptr); } return ret; } void Clear() { skiplist_head_.Clear([](value_type* elem) { Ref::DecreaseRef(elem); }); } private: intrusive::SkipListHead skiplist_head_; }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_SKIPLIST_H_ ================================================ FILE: oneflow/core/intrusive/skiplist_hook.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
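intrusive::SkipList above is an ordered, keyed container: Insert returns an (element, success) pair where success is false if the key already exists, and Find/Erase work by key while the list holds one reference per stored element. Ignoring the intrusive storage and ref counting, the interface contract is that of any ordered map, as this standalone std::map sketch shows:

#include <cassert>
#include <map>

int main() {
  std::map<int, const char*> m;
  assert(m.emplace(0, "first").second);    // fresh key -> inserted
  assert(!m.emplace(0, "second").second);  // duplicate key -> rejected
  assert(m.find(0) != m.end());
  m.erase(0);                              // erase by key
  assert(m.find(0) == m.end());
  return 0;
}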
*/ #ifndef ONEFLOW_CORE_INTRUSIVE_EMBEDDED_SKIPLIST_H_ #define ONEFLOW_CORE_INTRUSIVE_EMBEDDED_SKIPLIST_H_ #include #include #include #include "oneflow/core/common/throw.h" #include "oneflow/core/intrusive/struct_traits.h" #include "oneflow/core/intrusive/list_hook.h" namespace oneflow { namespace intrusive { template struct ListHookArray final { public: ListHookArray() { Clear(); } using self_type = ListHookArray; template static constexpr int LevelZeroHookOffset() { return 0; } bool empty() const { return hooks_[0].nullptr_empty(); } void __Init__() { Clear(); } void Clear() { for (auto& hook : hooks_) { hook.Clear(); } } void NullptrClear() { for (auto& hook : hooks_) { hook.NullptrClear(); } } void InsertAfter(ListHookArray* prev_skiplist_hook, int levels) { CHECK(empty()); ListHook* prev_hook = &prev_skiplist_hook->hooks_[0]; int i = 0; for (; i < levels; ++i, ++prev_hook) { while (prev_hook->nullptr_empty()) { prev_hook = (prev_hook - 1)->prev() + 1; } hooks_[i].InsertAfter(prev_hook); } } void Erase() { for (int i = 0; i < max_level; ++i) { if (hooks_[i].nullptr_empty()) { return; } hooks_[i].next()->AppendTo(hooks_[i].prev()); hooks_[i].NullptrClear(); } } static ListHookArray* ThisPtr4HookPtr(ListHook* slist_ptr, int level) { auto* hooks_ptr = (std::array*)(slist_ptr - level); return OffsetStructField::StructPtr4FieldPtr( hooks_ptr); } void CheckEmpty() const { for (const auto& hook : hooks_) { CHECK(hook.empty()); } } void CheckNullptrEmpty() const { for (const auto& hook : hooks_) { CHECK(hook.nullptr_empty()); } } ListHook* mutable_hook(int i) { return &hooks_[i]; } private: template static constexpr int HooksOffset() { return offsetof(self_type, hooks_); } std::array hooks_; }; template struct SkipListHook { public: SkipListHook() : key_() { __Init__(); } using self_type = SkipListHook; using hook_type = ListHookArray; using key_type = T; static const int max_level = N; static_assert(N > 0, "invalid number of levels"); template static constexpr int LevelZeroHookOffset() { return offsetof(SkipListHook, hook_) + hook_type::LevelZeroHookOffset(); } bool empty() const { return hook_.empty(); } void __Init__() { hook_.NullptrClear(); } const T& key() const { return key_; } T* mut_key() { return &key_; } void CheckEmpty() const { return hook_.CheckNullptrEmpty(); } void Clear() { hook_.NullptrClear(); mut_key()->__Delete__(); } static self_type* Find(const key_type& key, hook_type* head, int size_shift) { ListHook* last_hook_less_than_key = SearchLastBottomHookLessThan(key, head, size_shift); if (last_hook_less_than_key->next() == head->mutable_hook(0)) { return nullptr; } self_type* searched = ThisPtr4HookPtr(last_hook_less_than_key->next(), 0); if (searched->key() == key) { return searched; } return nullptr; } static self_type* Erase(const key_type& key, hook_type* head, int size_shift) { self_type* searched = Find(key, head, size_shift); CHECK_NOTNULL(searched); Erase(searched); return searched; } static void Erase(self_type* elem) { elem->hook_.Erase(); } // return true if success static std::pair Insert(self_type* elem, hook_type* head, int size_shift) { ListHook* prev_list_hook = SearchLastBottomHookLessThan(elem->key(), head, size_shift); self_type* maybe_searched = nullptr; if (prev_list_hook->next() == head->mutable_hook(0)) { maybe_searched = nullptr; } else { maybe_searched = ThisPtr4HookPtr(prev_list_hook->next(), 0); } self_type* ret_elem = nullptr; bool success = false; if (maybe_searched != nullptr && (maybe_searched->key() == elem->key())) { ret_elem = 
maybe_searched; success = false; } else { self_type* prev = ThisPtr4HookPtr(prev_list_hook, 0); ret_elem = elem; elem->hook_.InsertAfter(&prev->hook_, RandomNumLevels(size_shift)); success = true; } // CHECK_EQ(Find(ret_elem->key(), head), ret_elem, GetMaxVal() / 2); return std::make_pair(ret_elem, success); } static SkipListHook* ThisPtr4HookPtr(ListHook* list_hook_ptr, int level) { auto* skip_list_ptr = hook_type::ThisPtr4HookPtr(list_hook_ptr, level); using FieldUtil = OffsetStructField; return FieldUtil::StructPtr4FieldPtr(skip_list_ptr); } private: template static constexpr int SkipListIteratorOffset() { return offsetof(self_type, hook_); } static int32_t RandomNumLevels(int size_shift) { std::minstd_rand rand{std::random_device{}()}; int32_t max_num_levels = std::min(size_shift, N); int32_t num_levels = 1; for (int i = 1; (rand() % 2 == 0) && i < max_num_levels; ++i) { ++num_levels; } return num_levels; } static ListHook* SearchLastBottomHookLessThan(const key_type& key, hook_type* head, int size_shift) { int max_num_level = std::min(size_shift, N); ListHook* list_hook = head->mutable_hook(max_num_level); for (int level = max_num_level - 1; level >= 0; --level) { --list_hook; while (list_hook->next() != head->mutable_hook(level) && ThisPtr4HookPtr(list_hook->next(), level)->key() < key) { list_hook = list_hook->next(); } } return list_hook; } hook_type hook_; T key_; }; template class SkipListHead { public: SkipListHead() { __Init__(); } using value_type = typename ValueHookField::struct_type; using key_hook_type = typename ValueHookField::field_type; using key_type = typename key_hook_type::key_type; using value_key_level0_hook_struct_field = OffsetStructField; using value_level0_hook_struct_field = ComposeStructField; static const int max_level = key_hook_type::max_level; template static constexpr int IteratorHookOffset() { return offsetof(SkipListHead, skiplist_head_) + ListHookArray::LevelZeroHookOffset(); } void __Init__() { skiplist_head_.__Init__(); size_ = 0; } std::size_t size() const { return size_; } bool empty() const { return size_ == 0; } value_type* Begin() { ListHook* head_level0 = skiplist_head_.mutable_hook(0); ListHook* begin_list_hook = head_level0->next(); if (begin_list_hook == head_level0) { return nullptr; } return value_level0_hook_struct_field::StructPtr4FieldPtr(begin_list_hook); } value_type* Find(const key_type& key) { auto* key_hook_ptr = key_hook_type::Find(key, &skiplist_head_, size_shift()); if (key_hook_ptr == nullptr) { return nullptr; } return ValueHookField::StructPtr4FieldPtr(key_hook_ptr); } const value_type* Find(const key_type& key) const { auto* key_hook_ptr = key_hook_type::Find( key, const_cast*>(&skiplist_head_), size_shift()); if (key_hook_ptr == nullptr) { return nullptr; } return ValueHookField::StructPtr4FieldPtr(key_hook_ptr); } value_type* Erase(const key_type& key) { key_hook_type* erased = key_hook_type::Erase(key, &skiplist_head_, size_shift()); --size_; return ValueHookField::StructPtr4FieldPtr(erased); } void Erase(value_type* elem) { key_hook_type::Erase(ValueHookField::FieldPtr4StructPtr(elem)); --size_; } // return true if success std::pair Insert(value_type* elem) { key_hook_type* elem_key_hook = ValueHookField::FieldPtr4StructPtr(elem); key_hook_type* ret_key_hook = nullptr; bool success = false; std::tie(ret_key_hook, success) = key_hook_type::Insert(elem_key_hook, &skiplist_head_, size_shift()); if (success) { ++size_; } return std::make_pair(ValueHookField::StructPtr4FieldPtr(ret_key_hook), success); } template void 
Clear(const Callback& cb) { using hook_type = ListHookArray; for (; size_ > 0; --size_) { ListHook* begin_list_hook = skiplist_head_.mutable_hook(0)->next(); auto* begin = hook_type::ThisPtr4HookPtr(begin_list_hook, 0); if (begin == &skiplist_head_) { break; } begin->Erase(); cb(value_level0_hook_struct_field::StructPtr4FieldPtr(begin_list_hook)); } CHECK(empty_debug()); } void Clear() { Clear([](value_type*) {}); } bool empty_debug() const { bool ret = (size_ == 0); if (ret) { skiplist_head_.CheckEmpty(); } return ret; } private: int size_shift() const { return std::log2(size_ + 1); } ListHookArray skiplist_head_; volatile std::size_t size_; }; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_EMBEDDED_SKIPLIST_H_ ================================================ FILE: oneflow/core/intrusive/skiplist_hook_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/intrusive/skiplist_hook.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace intrusive { namespace test { template class TestSkipListHead final : public SkipListHead { // NOLINT public: TestSkipListHead() { this->__Init__(); } TestSkipListHead(const TestSkipListHead&) = delete; TestSkipListHead(TestSkipListHead&&) = delete; TestSkipListHead& operator==(const TestSkipListHead&) = delete; TestSkipListHead& operator==(TestSkipListHead&&) = delete; ~TestSkipListHead() { this->Clear(); } }; struct FooSkipListElem { FooSkipListElem() : value() { key.__Init__(); } int value; SkipListHook key; }; using FooSkipList = TestSkipListHead; TEST(SkipListHook, empty) { FooSkipList skiplist; ASSERT_TRUE(skiplist.empty_debug()); ASSERT_EQ(skiplist.size(), 0); } TEST(SkipListHook, insert_naive) { FooSkipList skiplist; FooSkipListElem elem0; *elem0.key.mut_key() = 0; elem0.value = 1; skiplist.Insert(&elem0); ASSERT_EQ(skiplist.size(), 1); { auto* searched = skiplist.Find(int(0)); ASSERT_EQ(searched, &elem0); } { auto* searched = skiplist.Find(int(-1)); ASSERT_TRUE(searched == nullptr); } } TEST(SkipListHook, erase_by_key) { FooSkipList skiplist; FooSkipListElem elem0; *elem0.key.mut_key() = 0; elem0.value = 1; skiplist.Insert(&elem0); ASSERT_EQ(skiplist.size(), 1); ASSERT_TRUE(skiplist.Find(int(0)) != nullptr); skiplist.Erase(int(0)); ASSERT_EQ(skiplist.size(), 0); ASSERT_TRUE(skiplist.Find(int(0)) == nullptr); } TEST(SkipListHook, erase_by_elem) { FooSkipList skiplist; FooSkipListElem elem0; *elem0.key.mut_key() = 0; elem0.value = 1; skiplist.Insert(&elem0); ASSERT_EQ(skiplist.size(), 1); ASSERT_TRUE(skiplist.Find(int(0)) != nullptr); skiplist.Erase(&elem0); ASSERT_EQ(skiplist.size(), 0); ASSERT_TRUE(skiplist.Find(int(0)) == nullptr); } TEST(SkipListHook, insert_many) { FooSkipList skiplist; FooSkipListElem exists[100]; for (int i = 0; i < 100; ++i) { int key = i - 50; if (key >= 0) { ++key; } *exists[i].key.mut_key() = key; skiplist.Insert(&exists[i]); ASSERT_EQ(skiplist.Find(key), 
&exists[i]); } FooSkipListElem elem0; *elem0.key.mut_key() = 0; elem0.value = 1; skiplist.Insert(&elem0); ASSERT_EQ(skiplist.size(), 101); { auto* searched = skiplist.Find(int(0)); ASSERT_EQ(searched, &elem0); } { auto* searched = skiplist.Find(int(-1001)); ASSERT_TRUE(searched == nullptr); } skiplist.Clear(); ASSERT_TRUE(skiplist.empty_debug()); } TEST(SkipListHook, erase_many_by_key) { FooSkipList skiplist; FooSkipListElem exists[100]; for (int i = 0; i < 100; ++i) { int key = i - 50; if (key >= 0) { ++key; } *exists[i].key.mut_key() = key; skiplist.Insert(&exists[i]); ASSERT_EQ(skiplist.Find(key), &exists[i]); } FooSkipListElem elem0; *elem0.key.mut_key() = 0; elem0.value = 1; skiplist.Insert(&elem0); ASSERT_EQ(skiplist.size(), 101); ASSERT_TRUE(skiplist.Find(int(0)) != nullptr); skiplist.Erase(int(0)); ASSERT_EQ(skiplist.size(), 100); ASSERT_TRUE(skiplist.Find(int(0)) == nullptr); skiplist.Clear(); ASSERT_TRUE(skiplist.empty_debug()); } TEST(SkipListHook, erase_many_by_elem) { FooSkipList skiplist; FooSkipListElem exists[100]; for (int i = 0; i < 100; ++i) { int key = i - 50; if (key >= 0) { ++key; } *exists[i].key.mut_key() = key; skiplist.Insert(&exists[i]); ASSERT_EQ(skiplist.Find(key), &exists[i]); } FooSkipListElem elem0; *elem0.key.mut_key() = 0; elem0.value = 1; skiplist.Insert(&elem0); ASSERT_EQ(skiplist.size(), 101); ASSERT_TRUE(skiplist.Find(int(0)) != nullptr); skiplist.Erase(&elem0); ASSERT_EQ(skiplist.size(), 100); ASSERT_TRUE(skiplist.Find(int(0)) == nullptr); skiplist.Clear(); ASSERT_TRUE(skiplist.empty_debug()); } } // namespace test } // namespace intrusive } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/skiplist_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
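RandomNumLevels above picks a node height by repeated fair coin flips, capped by both the compile-time max level and a bound derived from log2 of the current size, which is what keeps skiplist search O(log n) in expectation. A standalone sketch of the coin-flip level choice and the geometric distribution it produces:

#include <cstdio>
#include <random>

// Coin-flip level choice: each extra level is kept with probability 1/2.
int RandomLevel(int max_level) {
  static std::minstd_rand rng{std::random_device{}()};
  int level = 1;
  while (level < max_level && rng() % 2 == 0) { ++level; }
  return level;
}

int main() {
  // Expected shape: ~1/2 of nodes at level 1, ~1/4 at level 2, and so on.
  int histogram[8] = {0};
  for (int i = 0; i < 1000; ++i) { ++histogram[RandomLevel(8) - 1]; }
  for (int l = 0; l < 8; ++l) { std::printf("level %d: %d\n", l + 1, histogram[l]); }
  return 0;
}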
*/ #include "gtest/gtest.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace intrusive { namespace test { namespace { class SkipListFoo final : public intrusive::Base { public: void __Init__() { clear_is_deleted(); } void __Delete__() { if (has_is_deleted()) { ++*mut_is_deleted(); } } // Getters bool has_is_deleted() const { return is_deleted_ != nullptr; } int is_deleted() const { return *is_deleted_; } int32_t foo_map_key() const { return foo_map_key_.key(); } // Setters void set_is_deleted(int* val) { is_deleted_ = val; } void clear_is_deleted() { is_deleted_ = nullptr; } int* mut_is_deleted() { return is_deleted_; } void set_foo_map_key(int32_t val) { *foo_map_key_.mut_key() = val; } size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } SkipListFoo() : intrusive_ref_(), is_deleted_(), foo_map_key_() {} intrusive::Ref intrusive_ref_; int* is_deleted_; public: intrusive::SkipListHook foo_map_key_; }; class SkipListFooContainer final : public intrusive::Base { public: // types using Key2SkipListFoo = intrusive::SkipList; // Getters const Key2SkipListFoo& foo_map() const { return foo_map_; } // Setters Key2SkipListFoo* mut_foo_map() { return &foo_map_; } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } SkipListFooContainer() : intrusive_ref_(), foo_map_() {} intrusive::Ref intrusive_ref_; // maps Key2SkipListFoo foo_map_; }; using Key2SkipListFoo = intrusive::SkipList; TEST(SkipList, empty) { Key2SkipListFoo foo_map; ASSERT_TRUE(foo_map.empty()); ASSERT_EQ(foo_map.size(), 0); } TEST(SkipList, insert_naive) { Key2SkipListFoo foo_map; auto elem0 = intrusive::make_shared(); elem0->set_foo_map_key(0); foo_map.Insert(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 1); { auto searched = foo_map.Find(int(0)); ASSERT_TRUE(searched == elem0); } { auto searched = foo_map.Find(int(-1)); ASSERT_TRUE(foo_map.EqualsEnd(searched)); } } TEST(SkipList, insert_twice) { Key2SkipListFoo foo_map; auto elem0 = intrusive::make_shared(); elem0->set_foo_map_key(0); auto elem1 = intrusive::make_shared(); elem1->set_foo_map_key(0); ASSERT_TRUE(foo_map.Insert(elem0.Mutable()).second); ASSERT_TRUE(!foo_map.Insert(elem1.Mutable()).second); } TEST(SkipList, erase_by_key) { Key2SkipListFoo foo_map; auto elem0 = intrusive::make_shared(); elem0->set_foo_map_key(0); foo_map.Insert(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 1); ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0)))); foo_map.Erase(int(0)); ASSERT_EQ(foo_map.size(), 0); ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0)))); } TEST(SkipList, erase_by_elem) { Key2SkipListFoo foo_map; auto elem0 = intrusive::make_shared(); elem0->set_foo_map_key(0); foo_map.Insert(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 1); ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0)))); foo_map.Erase(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 0); ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0)))); } TEST(SkipList, insert_many) { Key2SkipListFoo foo_map; intrusive::shared_ptr exists[100]; for (int i = 0; i < 100; ++i) { exists[i] = intrusive::make_shared(); int key = i - 50; if (key >= 0) { ++key; } exists[i]->set_foo_map_key(key); foo_map.Insert(exists[i].Mutable()); ASSERT_TRUE(foo_map.Find(key) == exists[i]); } auto elem0 = intrusive::make_shared(); elem0->set_foo_map_key(0); foo_map.Insert(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 101); { auto searched = 
foo_map.Find(int(0)); ASSERT_TRUE(searched == elem0); } { auto searched = foo_map.Find(int(-1001)); ASSERT_TRUE(foo_map.EqualsEnd(searched)); } foo_map.Clear(); ASSERT_TRUE(foo_map.empty()); } TEST(SkipList, erase_many_by_key) { Key2SkipListFoo foo_map; intrusive::shared_ptr exists[100]; for (int i = 0; i < 100; ++i) { exists[i] = intrusive::make_shared(); int key = i - 50; if (key >= 0) { ++key; } exists[i]->set_foo_map_key(key); foo_map.Insert(exists[i].Mutable()); ASSERT_TRUE(foo_map.Find(key) == exists[i]); } auto elem0 = intrusive::make_shared(); elem0->set_foo_map_key(0); foo_map.Insert(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 101); ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0)))); foo_map.Erase(int(0)); ASSERT_EQ(foo_map.size(), 100); ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0)))); foo_map.Clear(); ASSERT_TRUE(foo_map.empty()); } TEST(SkipList, erase_many_by_elem) { Key2SkipListFoo foo_map; intrusive::shared_ptr exists[100]; for (int i = 0; i < 100; ++i) { exists[i] = intrusive::make_shared(); int key = i - 50; if (key >= 0) { ++key; } exists[i]->set_foo_map_key(key); foo_map.Insert(exists[i].Mutable()); ASSERT_TRUE(foo_map.Find(key) == exists[i]); } auto elem0 = intrusive::make_shared(); elem0->set_foo_map_key(0); foo_map.Insert(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 101); ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0)))); foo_map.Erase(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 100); ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0)))); foo_map.Clear(); ASSERT_TRUE(foo_map.empty()); } TEST(SkipList, MAP_HEAD) { int elem_cnt = 0; { auto foo_map_container = intrusive::make_shared(); auto& foo_map = *foo_map_container->mut_foo_map(); intrusive::shared_ptr exists[100]; for (int i = 0; i < 100; ++i) { exists[i] = intrusive::make_shared(); int key = i - 50; if (key >= 0) { ++key; } exists[i]->set_foo_map_key(key); exists[i]->set_is_deleted(&elem_cnt); foo_map.Insert(exists[i].Mutable()); ASSERT_TRUE(foo_map.Find(key) == exists[i]); ASSERT_EQ(exists[i]->ref_cnt(), 2); } auto elem0 = intrusive::make_shared(); elem0->set_foo_map_key(0); elem0->set_is_deleted(&elem_cnt); foo_map.Insert(elem0.Mutable()); ASSERT_EQ(foo_map.size(), 101); ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0)))); ASSERT_EQ(elem0->ref_cnt(), 2); foo_map.Erase(elem0->foo_map_key()); ASSERT_EQ(elem0->ref_cnt(), 1); ASSERT_EQ(foo_map.size(), 100); ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0)))); foo_map.Clear(); ASSERT_TRUE(foo_map.empty()); } ASSERT_EQ(elem_cnt, 101); } TEST(SkipList, FOR_EACH) { int elem_cnt = 0; { auto foo_map_container = intrusive::make_shared(); auto& foo_map = *foo_map_container->mut_foo_map(); intrusive::shared_ptr exists[100]; for (int i = 0; i < 100; ++i) { exists[i] = intrusive::make_shared(); int key = i - 50; exists[i]->set_foo_map_key(key); exists[i]->set_is_deleted(&elem_cnt); foo_map.Insert(exists[i].Mutable()); ASSERT_TRUE(foo_map.Find(key) == exists[i]); ASSERT_EQ(exists[i]->ref_cnt(), 2); } int value = -50; INTRUSIVE_UNSAFE_FOR_EACH_PTR(foo, &foo_map) { ASSERT_EQ(foo->foo_map_key(), value); ++value; } } ASSERT_EQ(elem_cnt, 100); } } // namespace } // namespace test } // namespace intrusive } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/static_counter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_STATIC_COUNTER_H_ #define ONEFLOW_CORE_INTRUSIVE_STATIC_COUNTER_H_ namespace oneflow { #define STATIC_COUNTER(counter_name) _STATIC_COUNTER_NAME(counter_name)<_AUTO_INCREMENT()>::value #define DEFINE_STATIC_COUNTER(counter_name) _DEFINE_STATIC_COUNTER(_AUTO_INCREMENT(), counter_name) #define INCREASE_STATIC_COUNTER(counter_name) \ _INCREASE_STATIC_COUNTER(_AUTO_INCREMENT(), counter_name) // details #define _STATIC_COUNTER_NAME(counter_name) StaticCounter_##counter_name #define _AUTO_INCREMENT() __COUNTER__ #define _DEFINE_STATIC_COUNTER(auto_counter, counter_name) \ template \ struct _STATIC_COUNTER_NAME(counter_name) { \ static const int value = _STATIC_COUNTER_NAME(counter_name)::value; \ }; \ template \ struct _STATIC_COUNTER_NAME(counter_name) { \ static const int value = 0; \ }; #define _INCREASE_STATIC_COUNTER(auto_counter, counter_name) \ template \ struct _STATIC_COUNTER_NAME(counter_name) { \ static const int value = _STATIC_COUNTER_NAME(counter_name)::value + 1; \ }; } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_STATIC_COUNTER_H_ ================================================ FILE: oneflow/core/intrusive/static_counter_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/intrusive/static_counter.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace test { namespace { DEFINE_STATIC_COUNTER(static_counter); static_assert(STATIC_COUNTER(static_counter) == 0, ""); TEST(StaticCounter, eq0) { static_assert(STATIC_COUNTER(static_counter) == 0, ""); } INCREASE_STATIC_COUNTER(static_counter); static_assert(STATIC_COUNTER(static_counter) == 1, ""); TEST(StaticCounter, eq1) { static_assert(STATIC_COUNTER(static_counter) == 1, ""); } static_assert(STATIC_COUNTER(static_counter) == 1, ""); TEST(StaticCounter, eq1_again) { static_assert(STATIC_COUNTER(static_counter) == 1, ""); } INCREASE_STATIC_COUNTER(static_counter); static_assert(STATIC_COUNTER(static_counter) == 2, ""); TEST(StaticCounter, eq2) { static_assert(STATIC_COUNTER(static_counter) == 2, ""); } // clang-format off REFLECTIVE_CLASS_BEGIN(FooBar); FooBar() = default; static_assert(STATIC_COUNTER(field_counter) == 0, ""); REFLECTIVE_CLASS_END(FooBar); // clang-format on } // namespace } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/intrusive/struct_traits.h ================================================ /* Copyright 2020 The OneFlow Authors. 
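DEFINE_STATIC_COUNTER/INCREASE_STATIC_COUNTER above rely on __COUNTER__ plus template specialization: every "increase" defines a new specialization, and reading the counter walks back through earlier keys until it finds the most recent one. A hand-written standalone sketch of the trick with fixed keys standing in for __COUNTER__ expansions (illustrative only):

// Fixed keys 0 and 5 stand in for two __COUNTER__ expansions.
template<int key, typename Enabled = void>
struct Counter {
  static const int value = Counter<key - 1>::value;  // fall through to the previous key
};
template<typename Enabled>
struct Counter<0, Enabled> {
  static const int value = 0;                        // like DEFINE_STATIC_COUNTER
};
template<typename Enabled>
struct Counter<5, Enabled> {
  static const int value = Counter<4>::value + 1;    // like INCREASE_STATIC_COUNTER
};

static_assert(Counter<3>::value == 0, "read before the increase");
static_assert(Counter<7>::value == 1, "read after the increase");

int main() { return 0; }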
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_INTRUSIVE_STRUCT_MACRO_TRAITS_H_ #define ONEFLOW_CORE_INTRUSIVE_STRUCT_MACRO_TRAITS_H_ #include #include #include "oneflow/core/common/preprocessor.h" namespace oneflow { namespace intrusive { template struct PtrStructField { using struct_type = T; using field_type = F; static T* StructPtr4FieldPtr(const F* field_ptr) { int offset_value = reinterpret_cast(&(((T*)nullptr)->*ptr2member)); return (T*)((const_cast(reinterpret_cast(field_ptr))) - offset_value); } static F* FieldPtr4StructPtr(const T* struct_ptr) { return &(const_cast(struct_ptr)->*ptr2member); } }; template struct OffsetStructField { using struct_type = T; using field_type = F; static const int offset_value = offset; static T* StructPtr4FieldPtr(const F* field_ptr) { return (T*)((const_cast(reinterpret_cast(field_ptr))) - offset_value); } static F* FieldPtr4StructPtr(const T* struct_ptr) { return (F*)((const_cast(reinterpret_cast(struct_ptr))) + offset_value); } }; #define INTRUSIVE_FIELD(struct_type, field_name) \ intrusive::PtrStructFieldfield_name), \ &struct_type::field_name> template struct ComposeStructField { static_assert(std::is_same::value, "invalid type"); using struct_type = typename X::struct_type; using field_type = typename Y::field_type; static struct_type* StructPtr4FieldPtr(const field_type* field_ptr) { return X::StructPtr4FieldPtr(Y::StructPtr4FieldPtr(field_ptr)); } static field_type* FieldPtr4StructPtr(const struct_type* struct_ptr) { return Y::FieldPtr4StructPtr(X::FieldPtr4StructPtr(struct_ptr)); } }; template struct ConstStruct { using type = const T; }; template struct ConstStruct { using type = const T; }; template using ConstType = typename ConstStruct::type; template struct ConstRefOrPtrStruct { using type = ConstType&; }; template struct ConstRefOrPtrStruct { using type = ConstType*; }; template using ConstRefOrPtr = typename ConstRefOrPtrStruct::type; } // namespace intrusive } // namespace oneflow #endif // ONEFLOW_CORE_INTRUSIVE_STRUCT_MACRO_TRAITS_H_ ================================================ FILE: oneflow/core/intrusive/struct_traits_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
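OffsetStructField/PtrStructField above implement two-way pointer arithmetic between a struct and one of its fields, i.e. the classic container_of idiom that the intrusive containers use to go from an embedded hook back to its owning object. A minimal standalone sketch:

#include <cassert>
#include <cstddef>

struct Foo {
  int x;
  int bar;
};

// Recover the enclosing struct from a pointer to its `bar` field.
Foo* StructPtr4FieldPtr(int* bar_ptr) {
  return reinterpret_cast<Foo*>(reinterpret_cast<char*>(bar_ptr) - offsetof(Foo, bar));
}

int main() {
  Foo foo{};
  assert(StructPtr4FieldPtr(&foo.bar) == &foo);
  return 0;
}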
*/ #include "gtest/gtest.h" #include "oneflow/core/intrusive/struct_traits.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace test { namespace { struct OneflowTestNamespaceFoo { OneflowTestNamespaceFoo() : x(0), bar(0), const_bar(0) {} int x; int bar; const int const_bar; }; TEST(StructField, mutable_struct_mutable_field) { OneflowTestNamespaceFoo foo; auto* bar = &foo.bar; auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar); auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo); ASSERT_EQ(struct_ptr, &foo); ASSERT_EQ(field_ptr, bar); } TEST(StructField, mutable_struct_const_field) { OneflowTestNamespaceFoo foo; auto* bar = &foo.const_bar; auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar); auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo); ASSERT_EQ(struct_ptr, &foo); ASSERT_EQ(field_ptr, bar); } TEST(StructField, const_struct_mutable_field) { const OneflowTestNamespaceFoo foo; auto* bar = &foo.bar; auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar); auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo); ASSERT_EQ(struct_ptr, &foo); ASSERT_EQ(field_ptr, bar); } TEST(StructField, const_struct_const_field) { const OneflowTestNamespaceFoo foo; auto* bar = &foo.const_bar; auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar); auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo); ASSERT_EQ(struct_ptr, &foo); ASSERT_EQ(field_ptr, bar); } struct X { int a; int b; }; struct Y { int c; X d; }; TEST(StructField, compose) { using BFieldInY = intrusive::ComposeStructField; Y y{}; int* field_b = &y.d.b; ASSERT_EQ(BFieldInY::FieldPtr4StructPtr(&y), field_b); ASSERT_EQ(BFieldInY::StructPtr4FieldPtr(field_b), &y); } } // namespace } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/ipc/shared_memory.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ipc/shared_memory.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/pcheck.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/env_var/env_var.h" #ifdef __linux__ #include #include #include #include #include #include #include #endif namespace oneflow { namespace ipc { namespace { #ifdef __linux__ // return errno int ShmOpen(const std::string& shm_name, int* fd, bool create) { SharedMemoryManager::get().AddShmName(shm_name); *fd = shm_open(("/" + shm_name).c_str(), (create ? O_CREAT : 0) | O_RDWR | O_EXCL, S_IRUSR | S_IWUSR); return *fd == -1 ? 
errno : 0; } // return errno int ShmOpen(std::string* shm_name, int* fd, bool create) { int err = EEXIST; while (true) { static constexpr int kNameLength = 8; *shm_name = std::string("ofshm_") + GenAlphaNumericString(kNameLength); err = ShmOpen(*shm_name, fd, create); if (err != EEXIST) { return err; } } return err; } int ShmMap(int fd, const size_t shm_size, void** ptr) { *ptr = mmap(NULL, shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); return (*ptr == MAP_FAILED) ? errno : 0; } #endif Maybe ShmSetUp(std::string* shm_name, size_t shm_size, bool create) { #ifdef __linux__ int fd = 0; PCHECK_OR_RETURN(ShmOpen(shm_name, &fd, create)); PCHECK_OR_RETURN(posix_fallocate(fd, 0, shm_size)) << ReturnEmptyStr([&] { close(fd); }); void* ptr = nullptr; PCHECK_OR_RETURN(ShmMap(fd, shm_size, &ptr)) << ReturnEmptyStr([&] { close(fd); }); close(fd); std::memset(ptr, 0, shm_size); return ptr; #else TODO_THEN_RETURN(); #endif } Maybe ShmSetUp(const std::string& shm_name, size_t* shm_size, bool create) { #ifdef __linux__ int fd = 0; PCHECK_OR_RETURN(ShmOpen(shm_name, &fd, create)); struct stat st; // NOLINT PCHECK_OR_RETURN(fstat(fd, &st)) << ReturnEmptyStr([&] { close(fd); }); *shm_size = st.st_size; void* ptr = nullptr; PCHECK_OR_RETURN(ShmMap(fd, *shm_size, &ptr)) << ReturnEmptyStr([&] { close(fd); }); close(fd); return ptr; #else TODO_THEN_RETURN(); #endif } Maybe> GetContentsOfShmDirectory() { #ifdef __linux__ std::set contents; DIR* dir = opendir("/dev/shm/"); CHECK_NOTNULL_OR_RETURN(dir) << "/dev/shm directory does not exist, there may be a problem with your machine!"; while (dirent* f = readdir(dir)) { if (f->d_name[0] == '.') continue; contents.insert(f->d_name); } closedir(dir); return contents; #else TODO_THEN_RETURN(); #endif } } // namespace SharedMemoryManager& SharedMemoryManager::get() { // Must be a static singleton variable instead of Singleton. // Subprocesses don't have chance to call `Singleton::Delete()` static SharedMemoryManager shared_memory_manager; return shared_memory_manager; } void SharedMemoryManager::FindAndDeleteOutdatedShmNames() { std::unique_lock lock(mutex_); static size_t counter = 0; const int delete_invalid_names_interval = EnvInteger(); if (counter % delete_invalid_names_interval == 0) { const auto& existing_shm_names = CHECK_JUST(GetContentsOfShmDirectory()); // std::remove_if doesn't support std::map for (auto it = shm_names_.begin(); it != shm_names_.end(); /* do nothing */) { if (existing_shm_names->find(*it) == existing_shm_names->end()) { it = shm_names_.erase(it); } else { it++; } } } counter++; } void SharedMemoryManager::AddShmName(const std::string& shm_name) { FindAndDeleteOutdatedShmNames(); std::unique_lock lock(mutex_); shm_names_.insert(shm_name); } Maybe SharedMemoryManager::DeleteShmName(const std::string& shm_name) { std::unique_lock lock(mutex_); auto it = std::find(shm_names_.begin(), shm_names_.end(), shm_name); if (it != shm_names_.end()) { shm_names_.erase(it); } else { return Error::RuntimeError() << "shared memory was not created but attempted to be freed."; } return Maybe::Ok(); } void SharedMemoryManager::UnlinkAllShms() { #ifdef __linux__ // Here we deliberately do not handle unlink errors. 
std::unique_lock lock(mutex_); for (const auto& shm : shm_names_) { shm_unlink(shm.c_str()); } shm_names_.clear(); #else UNIMPLEMENTED(); #endif } SharedMemoryManager::~SharedMemoryManager() { UnlinkAllShms(); } SharedMemory::~SharedMemory() { CHECK_JUST(Close()); } Maybe SharedMemory::Open(size_t shm_size, bool create) { std::string shm_name; char* ptr = static_cast(JUST(ShmSetUp(&shm_name, shm_size, create))); return std::shared_ptr(new SharedMemory(ptr, shm_name, shm_size)); } Maybe SharedMemory::Open(const std::string& shm_name, bool create) { size_t shm_size = 0; char* ptr = static_cast(JUST(ShmSetUp(shm_name, &shm_size, create))); return std::shared_ptr(new SharedMemory(ptr, shm_name, shm_size)); } Maybe SharedMemory::Close() { #ifdef __linux__ if (buf_ != nullptr) { PCHECK_OR_RETURN(munmap(buf_, size_)); buf_ = nullptr; } return Maybe::Ok(); #else TODO_THEN_RETURN(); #endif } Maybe SharedMemory::Unlink() { #ifdef __linux__ PCHECK_OR_RETURN(shm_unlink(name_.c_str())); JUST(SharedMemoryManager::get().DeleteShmName(name_)); return Maybe::Ok(); #else TODO_THEN_RETURN(); #endif } } // namespace ipc } // namespace oneflow ================================================ FILE: oneflow/core/ipc/shared_memory.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
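ShmSetUp above follows the standard POSIX sequence: shm_open the named segment, size it with posix_fallocate (or read its size with fstat when attaching), mmap it, and close the descriptor, which the mapping outlives. A minimal Linux-only standalone sketch with error handling reduced to asserts; the segment name is illustrative:

#include <cassert>
#include <cstddef>
#include <cstring>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

int main() {
  const char* name = "/ofshm_example";  // hypothetical segment name
  const size_t size = 4096;
  int fd = shm_open(name, O_CREAT | O_EXCL | O_RDWR, S_IRUSR | S_IWUSR);
  assert(fd != -1);
  assert(posix_fallocate(fd, 0, size) == 0);  // back the segment with storage
  void* ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
  assert(ptr != MAP_FAILED);
  close(fd);                                  // the mapping outlives the descriptor
  std::memset(ptr, 0, size);
  munmap(ptr, size);
  shm_unlink(name);                           // remove the name once done
  return 0;
}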
*/ #ifndef ONEFLOW_CORE_IPC_SHARED_MEMORY_H_ #define ONEFLOW_CORE_IPC_SHARED_MEMORY_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/singleton.h" namespace oneflow { namespace ipc { class SharedMemoryManager final { public: OF_DISALLOW_COPY_AND_MOVE(SharedMemoryManager); ~SharedMemoryManager(); void AddShmName(const std::string& shm_name); Maybe DeleteShmName(const std::string& shm_name); void UnlinkAllShms(); static SharedMemoryManager& get(); private: SharedMemoryManager() = default; void FindAndDeleteOutdatedShmNames(); std::set shm_names_; std::recursive_mutex mutex_; }; class SharedMemory final { public: SharedMemory(const SharedMemory&) = delete; SharedMemory(SharedMemory&&) = delete; ~SharedMemory(); static Maybe Open(size_t size, bool create); static Maybe Open(const std::string& name, bool create); const char* buf() const { return buf_; } char* mut_buf() { return buf_; } const std::string& name() const { return name_; } size_t size() const { return size_; } Maybe Close(); Maybe Unlink(); private: SharedMemory(char* buf, const std::string& name, size_t size) : buf_(buf), name_(name), size_(size) {} char* buf_; std::string name_; size_t size_; }; } // namespace ipc } // namespace oneflow #endif // ONEFLOW_CORE_IPC_SHARED_MEMORY_H_ ================================================ FILE: oneflow/core/job/blob_lifetime_signature.proto ================================================ syntax = "proto2"; package oneflow; message BlobLastUsedSignature { map bn_in_op2blob_last_used = 1; } message BlobBackwardUsedSignature { map bn_in_op2blob_backward_used = 1; } ================================================ FILE: oneflow/core/job/checkpointing_config_def.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/config_def.h" namespace oneflow { namespace { REGISTER_SCOPE_CONFIG_DEF().Bool( "checkpointing", false, "enable checkpointing op/tensor for backward recomputation to sublinear memory cost"); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job/cluster_instruction.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/job/cluster_instruction.h" #include "oneflow/core/job/cluster_instruction.pb.h" #include "oneflow/core/control/ctrl_server.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { namespace { std::string GetHaltAckCtrlKey(int64_t machine_id) { return "HaltAckCtrlKey/" + std::to_string(machine_id); } // return unique sequential key // because ctrl key is not allowed to push/pull twice std::string GetClusterInstructionKey() { static int64_t seq = 0; return "ClusterInstructionKey/" + std::to_string(seq++); } class ObsoleteCtrlKeys { public: ObsoleteCtrlKeys() = default; ~ObsoleteCtrlKeys() = default; template void ForEach(const CallbackT& Callback) const { std::unique_lock lck(mutex_); for (const std::string& k : keys_) { Callback(k); } } void Clear() { std::unique_lock lck(mutex_); keys_.clear(); } void Add(const std::string& key) { std::unique_lock lck(mutex_); keys_.emplace_back(key); } private: mutable std::mutex mutex_; std::vector keys_; }; COMMAND(Singleton::SetAllocated(new ObsoleteCtrlKeys())); void OccasionallyClearCtrlKV(const std::string& key) { static std::atomic seq(0LL); const static int64_t interval = 65536; Singleton::Get()->Add(key); // 1 instead of 0 is better for avoid clearing no ctrl kv if ((seq++) % interval == 1) { OF_ENV_BARRIER(); if (GlobalProcessCtx::IsThisProcessMaster()) { Singleton::Get()->ForEach( [](const std::string& k) { Singleton::Get()->ClearMasterKV(k); }); } Singleton::Get()->Clear(); OF_ENV_BARRIER(); } } void PushClusterInstruction(const ClusterInstructionProto& cluster_instruction) { const std::string& key = GetClusterInstructionKey(); Singleton::Get()->PushMasterKV(key, cluster_instruction); OccasionallyClearCtrlKV(key); } void PullClusterInstruction(ClusterInstructionProto* cluster_instruction) { const std::string& key = GetClusterInstructionKey(); Singleton::Get()->PullMasterKV(key, cluster_instruction); OccasionallyClearCtrlKV(key); } } // namespace void ClusterInstruction::NewSessionBarrier() { OF_ENV_BARRIER(); Singleton::Get()->Clear(); Singleton::Get()->Clear(); OF_ENV_BARRIER(); } void ClusterInstruction::MasterSendSessionStart() { ClusterInstructionProto cluster_instruction; cluster_instruction.mutable_cluster_ctrl_session_start(); PushClusterInstruction(cluster_instruction); NewSessionBarrier(); } void ClusterInstruction::MasterSendHalt() { ClusterInstructionProto cluster_instruction; cluster_instruction.mutable_cluster_ctrl_halt(); PushClusterInstruction(cluster_instruction); HaltBarrier(); } void ClusterInstruction::MasterSendAbort() { LOG(INFO) << "Sending abort instruction."; ClusterInstructionProto cluster_instruction; cluster_instruction.mutable_cluster_ctrl_abort(); PushClusterInstruction(cluster_instruction); } void ClusterInstruction::WorkerReceiveInstruction(ClusterInstructionProto* cluster_instruction) { PullClusterInstruction(cluster_instruction); } void ClusterInstruction::HaltBarrier() { OF_ENV_BARRIER(); } void ClusterInstruction::EagerSyncBarrier() { // TODO(jianhao): update here after eager instructions are run asynchronously OF_ENV_BARRIER(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/cluster_instruction.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
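OccasionallyClearCtrlKV above pushes each instruction under a fresh sequential key (a ctrl key cannot be pushed twice), remembers every key it has used, and only every 65536 pushes clears the whole backlog, wrapped in cluster-wide barriers so that only the master deletes keys and only after all ranks have consumed them. The sketch below shows just the single-process bookkeeping, with a plain std::map standing in for the ctrl KV store; the barriers and master-only cleanup are deliberately omitted:

#include <map>
#include <mutex>
#include <string>
#include <vector>

class KeyJanitor {  // illustrative name, not the OneFlow class
 public:
  void Push(std::map<std::string, std::string>* kv, const std::string& value) {
    const std::string key = "ClusterInstructionKey/" + std::to_string(seq_++);
    (*kv)[key] = value;                    // each push uses a fresh key
    std::unique_lock<std::mutex> lock(mutex_);
    pushed_keys_.push_back(key);
    if (seq_ % kInterval == 0) {           // occasional batch cleanup
      for (const auto& k : pushed_keys_) { kv->erase(k); }
      pushed_keys_.clear();
    }
  }

 private:
  static constexpr int kInterval = 65536;
  std::mutex mutex_;
  std::vector<std::string> pushed_keys_;
  long seq_ = 0;
};

int main() {
  std::map<std::string, std::string> kv;  // stand-in for the ctrl KV store
  KeyJanitor janitor;
  janitor.Push(&kv, "session_start");
  janitor.Push(&kv, "halt");
  return 0;
}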
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_JOB_CLUSTER_CONTROL_H_
#define ONEFLOW_CORE_JOB_CLUSTER_CONTROL_H_

#include "oneflow/core/job/cluster_instruction.pb.h"

namespace oneflow {

struct ClusterInstruction final {
  static void MasterSendSessionStart();
  static void MasterSendHalt();
  static void MasterSendAbort();
  static void MasterSendEagerSync();
  static void WorkerReceiveInstruction(ClusterInstructionProto* cluster_instruction);
  static void NewSessionBarrier();
  static void HaltBarrier();
  static void EagerSyncBarrier();
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_JOB_CLUSTER_CONTROL_H_

================================================
FILE: oneflow/core/job/cluster_instruction.proto
================================================
syntax = "proto2";
package oneflow;

message ClusterCtrlSessionStart {}

message ClusterCtrlHalt {}

message ClusterCtrlAbort {}

message ClusterInstructionProto {
  oneof instruction_type {
    ClusterCtrlSessionStart cluster_ctrl_session_start = 1;
    ClusterCtrlHalt cluster_ctrl_halt = 2;  // normal exit
    ClusterCtrlAbort cluster_ctrl_abort = 5;  // error exit
  }
}

================================================
FILE: oneflow/core/job/collective_boxing/coordinator.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_COORDINATOR_H_
#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_COORDINATOR_H_

#include "oneflow/core/common/util.h"

namespace oneflow {

namespace boxing {

namespace collective {

class RequestStore;
class Executor;
struct RequestId;

class Coordinator {
 public:
  Coordinator() = default;
  virtual ~Coordinator() = default;

  virtual void Init(std::shared_ptr<RequestStore> request_store,
                    std::shared_ptr<Executor> executor) = 0;
  virtual void InitJob(int64_t job_id) = 0;
  virtual void DeinitJob(int64_t job_id) = 0;
  virtual void AddRequest(void* coordinator_token) = 0;
  virtual void* CreateCoordinatorToken(const RequestId& request_id) = 0;
  virtual void DestroyCoordinatorToken(void* coordinator_token) = 0;
};

}  // namespace collective

}  // namespace boxing

}  // namespace oneflow

#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_COORDINATOR_H_

================================================
FILE: oneflow/core/job/collective_boxing/executor.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/collective_boxing/executor.h" namespace oneflow { namespace boxing { namespace collective { void Executor::ExecuteRequests(const std::vector& request_ids) { GroupRequests(request_ids, [&](std::vector&& group, GroupToken* group_token) { ExecuteGroup(group_token); }); } } // namespace collective } // namespace boxing } // namespace oneflow ================================================ FILE: oneflow/core/job/collective_boxing/executor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_H_ #define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace boxing { namespace collective { class RequestStore; struct RequestId; class GroupToken; class Executor { public: Executor() = default; virtual ~Executor() = default; virtual void Init(std::shared_ptr request_store) = 0; virtual void InitJob(int64_t job_id) = 0; virtual void DeinitJob(int64_t job_id) = 0; virtual void GroupRequests( const std::vector& request_ids, const std::function&&, GroupToken*)>& Handler) = 0; virtual void ExecuteGroup(GroupToken* group_token) = 0; virtual void DestroyGroupToken(GroupToken* group_token) = 0; virtual void ExecuteRequests(const std::vector& request_ids); }; } // namespace collective } // namespace boxing } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_H_ ================================================ FILE: oneflow/core/job/collective_boxing/executor_backend.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_H_ #define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace boxing { namespace collective { class RequestStore; struct RequestId; class ExecutorBackend { public: OF_DISALLOW_COPY_AND_MOVE(ExecutorBackend); ExecutorBackend() = default; virtual ~ExecutorBackend() = default; virtual void Init(std::shared_ptr request_store) = 0; virtual void InitJob(int64_t job_id) = 0; virtual void DeinitJob(int64_t job_id) = 0; virtual void GroupRequests( const std::vector& request_ids, const std::function&&, void*)>& Handler) = 0; virtual void ExecuteGroup(void* group_token) = 0; virtual void* CreateGroupToken(const std::vector& group) = 0; virtual void DestroyGroupToken(void* group_token) = 0; }; } // namespace collective } // namespace boxing } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_H_ ================================================ FILE: oneflow/core/job/collective_boxing/executor_backend_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/collective_boxing/executor_backend_manager.h" namespace oneflow { namespace boxing { namespace collective { ExecutorBackendMgr& ExecutorBackendMgr::Get() { static ExecutorBackendMgr mgr; return mgr; } } // namespace collective } // namespace boxing } // namespace oneflow ================================================ FILE: oneflow/core/job/collective_boxing/executor_backend_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_MANAGER_H_ #define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_MANAGER_H_ #include "oneflow/core/job/collective_boxing/executor_backend.h" #include "oneflow/core/common/device_type.h" namespace oneflow { namespace boxing { namespace collective { class ExecutorBackendMgr { public: using Creator = std::function()>; ExecutorBackendMgr(ExecutorBackendMgr const&) = delete; ExecutorBackendMgr& operator=(ExecutorBackendMgr const&) = delete; static ExecutorBackendMgr& Get(); template void RegisterExecutorBackendType(DeviceType device_type) { executor_backend_reg_result_.emplace(device_type, []() -> std::unique_ptr { return std::make_unique(); }); vaild_executor_device_types_.emplace_back(device_type); } std::unique_ptr NewExecutorBackend(DeviceType device_type) const { const auto& it = executor_backend_reg_result_.find(device_type); CHECK(it != executor_backend_reg_result_.end()); return it->second(); } const std::vector& vaild_executor_device_types() const { return vaild_executor_device_types_; } private: ExecutorBackendMgr() = default; HashMap executor_backend_reg_result_; std::vector vaild_executor_device_types_; }; #define REGISTER_EXECUTOR_BACKEND(device, Derived) \ COMMAND(ExecutorBackendMgr::Get().RegisterExecutorBackendType(device)) } // namespace collective } // namespace boxing } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_MANAGER_H_ ================================================ FILE: oneflow/core/job/collective_boxing/nccl_executor_backend.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/collective_boxing/executor_backend_manager.h" #include "oneflow/core/job/collective_boxing/request_store.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/device/cuda_util.h" #include #include #include namespace oneflow { namespace boxing { namespace collective { namespace { ncclRedOp_t GetNcclReduceOp(ReduceMethod reduce_method) { if (reduce_method == kReduceMethodSum) { return ncclRedOp_t::ncclSum; } else { UNIMPLEMENTED(); return ncclRedOp_t{}; } } std::string GetNcclUniqueIdRpcKey(const std::string& name, int64_t stream_id) { return "CollectiveBoxingExecutorNcclUniqueIdRpcKey-" + name + "-" + std::to_string(stream_id); } struct CopyParams { void* dst; const void* src; int64_t count; }; constexpr int64_t kMultiCopyParamsMaxSize = 128; constexpr int64_t kMultiCopyAlignSize = 32; int64_t GetMultiCopyAlignedSize(int64_t size) { return ((size + kMultiCopyAlignSize - 1) / kMultiCopyAlignSize) * kMultiCopyAlignSize; } struct MultiCopyParams { CopyParams params[kMultiCopyParamsMaxSize]; int64_t count; MultiCopyParams() : count(0), params{} {} void Add(void* dst, const void* src, int64_t count) { CHECK_LT(this->count, kMultiCopyParamsMaxSize); params[this->count].dst = dst; params[this->count].src = src; params[this->count].count = count; this->count += 1; } }; using BulkType = ulonglong2; __global__ void MultiCopyGpu(MultiCopyParams multi_params) { for (int64_t p = 0; p < multi_params.count; ++p) { const CopyParams params = multi_params.params[p]; auto* bulk_dst = reinterpret_cast(params.dst); const auto* bulk_src = reinterpret_cast(params.src); const int64_t bulk_count = params.count / sizeof(BulkType); CUDA_1D_KERNEL_LOOP_T(int64_t, i, bulk_count) { bulk_dst[i] = bulk_src[i]; } const int64_t tail_offset = bulk_count * sizeof(BulkType); auto* tail_dst = reinterpret_cast(params.dst) + tail_offset; const auto* tail_src = reinterpret_cast(params.src) + tail_offset; const int64_t tail_count = params.count - tail_offset; CUDA_1D_KERNEL_LOOP_T(int64_t, i, tail_count) { tail_dst[i] = tail_src[i]; } } } void MultiCopy(cudaStream_t stream, const MultiCopyParams& multi_params) { if (multi_params.count <= 0) { return; } CHECK_LE(multi_params.count, kMultiCopyParamsMaxSize); int64_t max_count = multi_params.params[0].count; for (int64_t i = 0; i < multi_params.count; ++i) { max_count = std::max(max_count, multi_params.params[i].count); } MultiCopyGpu<<>>( multi_params); } class CommRank final { public: OF_DISALLOW_COPY(CommRank); CommRank(int32_t device_id, int32_t global_rank, int32_t global_rank_count, int32_t local_rank, int32_t local_rank_count) : device_id_(device_id), global_rank_(global_rank), local_rank_(local_rank), nccl_comm_(nullptr) {} CommRank(CommRank&& rhs) noexcept { this->device_id_ = rhs.device_id_; this->global_rank_ = rhs.global_rank_; this->local_rank_ = rhs.local_rank_; this->nccl_comm_ = rhs.nccl_comm_; rhs.nccl_comm_ = nullptr; } ~CommRank() { if (nccl_comm_ != nullptr) { CudaCurrentDeviceGuard guard(device_id_); OF_NCCL_CHECK(ncclCommDestroy(nccl_comm_)); } } int32_t device_id() const { return device_id_; } ncclComm_t nccl_comm() const { return nccl_comm_; } void InitRank(ncclUniqueId unique_id, int32_t global_rank_count) { CudaCurrentDeviceGuard 
guard(device_id_); OF_NCCL_CHECK(ncclCommInitRank(&nccl_comm_, global_rank_count, unique_id, global_rank_)); } private: int32_t device_id_; int32_t global_rank_; int32_t local_rank_; ncclComm_t nccl_comm_; }; class CommGroup final { public: OF_DISALLOW_COPY(CommGroup); CommGroup() = default; ~CommGroup() = default; CommGroup(CommGroup&& rhs) noexcept { rank_vec_.swap(rhs.rank_vec_); global_rank_count_ = rhs.global_rank_count_; } void InitGroup(const DeviceSet& device_set, const std::string& unique_name) { CudaCurrentDeviceGuard guard; const int64_t this_machine_id = GlobalProcessCtx::Rank(); global_rank_count_ = device_set.device_size(); std::vector local_ranks; for (int32_t i = 0; i < global_rank_count_; ++i) { if (device_set.device(i).machine_id() == this_machine_id) { local_ranks.emplace_back(i); } } const int32_t local_rank_count = local_ranks.size(); CHECK_GT(local_rank_count, 0); ncclUniqueId nccl_unique_id{}; if (local_ranks.front() == 0) { OF_NCCL_CHECK(ncclGetUniqueId(&nccl_unique_id)); if (local_rank_count != global_rank_count_) { Singleton::Get()->PushKV(unique_name, NcclUniqueIdToString(nccl_unique_id)); } } else { Singleton::Get()->PullKV(unique_name, [&nccl_unique_id](const std::string& val) { NcclUniqueIdFromString(val, &nccl_unique_id); }); } rank_vec_.reserve(local_rank_count); OF_NCCL_CHECK(ncclGroupStart()); for (int32_t local_rank = 0; local_rank < local_ranks.size(); ++local_rank) { const int32_t global_rank = local_ranks.at(local_rank); const int32_t device_id = device_set.device(global_rank).device_id(); OF_CUDA_CHECK(cudaSetDevice(device_id)); rank_vec_.emplace_back(device_id, global_rank, global_rank_count_, local_rank, local_rank_count); rank_vec_.at(local_rank).InitRank(nccl_unique_id, global_rank_count_); } OF_NCCL_CHECK(ncclGroupEnd()); } int32_t global_rank_count() const { return global_rank_count_; } int32_t local_rank_count() const { return rank_vec_.size(); } const CommRank& GetCommRank(int32_t local_rank) const { return rank_vec_.at(local_rank); } private: std::vector rank_vec_; int32_t global_rank_count_ = 0; }; class StreamCtx { public: OF_DISALLOW_COPY(StreamCtx); StreamCtx(int32_t device_id, size_t fusion_buffer_size) : device_id_(device_id), fusion_buffer_size_(fusion_buffer_size) { CudaCurrentDeviceGuard guard(device_id_); int priority; OF_CUDA_CHECK(cudaDeviceGetStreamPriorityRange(nullptr, &priority)); OF_CUDA_CHECK(cudaStreamCreateWithPriority(&stream_, cudaStreamNonBlocking, priority)); OF_CUDA_CHECK(cudaMalloc(&fusion_buffer_, fusion_buffer_size_)); cb_event_poller_ = std::thread(&StreamCtx::PollEvent, this); } ~StreamCtx() { cb_event_chan_.Close(); cb_event_poller_.join(); CudaCurrentDeviceGuard guard(device_id_); OF_CUDA_CHECK(cudaStreamSynchronize(stream_)); OF_CUDA_CHECK(cudaStreamDestroy(stream_)); OF_CUDA_CHECK(cudaFree(fusion_buffer_)); } void PollEvent() { CudaCurrentDeviceGuard guard(device_id_); while (true) { std::pair> cb_event; ChannelStatus status = cb_event_chan_.Receive(&cb_event); if (status == kChannelStatusErrorClosed) { break; } CHECK_EQ(status, kChannelStatusSuccess); OF_CUDA_CHECK(cudaEventSynchronize(cb_event.first)); cb_event.second(); OF_CUDA_CHECK(cudaEventDestroy(cb_event.first)); } } void AddCallback(const std::function& callback) { cudaEvent_t event; OF_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); OF_CUDA_CHECK(cudaEventRecord(event, stream_)); CHECK_EQ(cb_event_chan_.Send(std::make_pair(event, callback)), kChannelStatusSuccess); } int32_t device_id() const { return device_id_; } cudaStream_t 
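// StreamCtx couples a non-blocking, high-priority CUDA stream with a per-device fusion buffer
// and a callback poller thread: AddCallback records a cudaEvent on the stream and sends the
// (event, callback) pair through a channel; PollEvent cudaEventSynchronize()s each event and
// only then runs the callback, so completion callbacks never execute on the NCCL stream itself.
// A minimal sketch of the same pattern outside this class (poller_queue is a hypothetical
// helper, not part of the original source):
//   cudaEvent_t done;
//   OF_CUDA_CHECK(cudaEventCreateWithFlags(&done, cudaEventDisableTiming));
//   OF_CUDA_CHECK(cudaEventRecord(done, stream));
//   poller_queue.push({done, [] { /* notify whoever is waiting on this collective */ }});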
stream() const { return stream_; } size_t fusion_buffer_size() const { return fusion_buffer_size_; } char* fusion_buffer() const { return fusion_buffer_; } private: int32_t device_id_; cudaStream_t stream_ = nullptr; size_t fusion_buffer_size_; char* fusion_buffer_ = nullptr; Channel>> cb_event_chan_; std::thread cb_event_poller_; }; void LaunchFusedAllReduce(const CommGroup& comm_group, const std::vector>& device_id2stream_ctx, const std::shared_ptr& request_store, const std::vector& request_ids) { CHECK_LE(request_ids.size(), kMultiCopyParamsMaxSize); RequestEntry* first_request_entry = request_store->MutRequestEntry(request_ids.front()); const ncclDataType_t nccl_data_type = GetNcclDataType(first_request_entry->desc().op_desc().data_type()); const ncclRedOp_t nccl_reduce_op = GetNcclReduceOp(first_request_entry->desc().op_desc().reduce_method()); const int64_t size_of_data_type = GetSizeOfDataType(first_request_entry->desc().op_desc().data_type()); std::vector offset_vec; offset_vec.reserve(request_ids.size()); int64_t offset = 0; request_store->ForEachMutRequestEntryForIdsInJob( request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { offset_vec.emplace_back(offset); offset += GetMultiCopyAlignedSize(request_entry->size_in_bytes()); }); const int64_t elem_cnt = offset / size_of_data_type; for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) { MultiCopyParams copy_in_params; const CommRank& comm_rank = comm_group.GetCommRank(local_rank); const StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get(); CHECK_LE(offset, stream_ctx->fusion_buffer_size()); request_store->ForEachMutRequestEntryForIdsInJob( request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { copy_in_params.Add(stream_ctx->fusion_buffer() + offset_vec.at(i), request_entry->GetRuntimeRequest(local_rank)->send_buff, request_entry->size_in_bytes()); }); OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id())); MultiCopy(stream_ctx->stream(), copy_in_params); } OF_NCCL_CHECK(ncclGroupStart()); for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) { const CommRank& comm_rank = comm_group.GetCommRank(local_rank); const StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get(); OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id())); OF_NCCL_CHECK(ncclAllReduce(stream_ctx->fusion_buffer(), stream_ctx->fusion_buffer(), elem_cnt, nccl_data_type, nccl_reduce_op, comm_rank.nccl_comm(), stream_ctx->stream())); } OF_NCCL_CHECK(ncclGroupEnd()); for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) { MultiCopyParams copy_out_params; const CommRank& comm_rank = comm_group.GetCommRank(local_rank); const StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get(); request_store->ForEachMutRequestEntryForIdsInJob( request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { copy_out_params.Add(request_entry->GetRuntimeRequest(local_rank)->recv_buff, stream_ctx->fusion_buffer() + offset_vec.at(i), request_entry->size_in_bytes()); }); OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id())); MultiCopy(stream_ctx->stream(), copy_out_params); } } void LaunchAggregatedOps(const CommGroup& comm_group, const std::vector>& device_id2stream_ctx, const std::shared_ptr& request_store, const std::vector& request_ids) { OF_NCCL_CHECK(ncclGroupStart()); for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); 
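// LaunchAggregatedOps issues every request in the group, for every local rank, inside a single
// ncclGroupStart()/ncclGroupEnd() window so NCCL can batch the launches. All-to-all has no
// dedicated NCCL collective here; it is emulated below with paired ncclSend/ncclRecv calls per
// peer (one chunk of elem_cnt / num_ranks / num_ranks elements each), which requires an NCCL
// newer than 2.7 (NCCL_VERSION_CODE > 2700).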
++local_rank) { const CommRank& comm_rank = comm_group.GetCommRank(local_rank); const auto comm = comm_rank.nccl_comm(); const StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get(); OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id())); request_store->ForEachMutRequestEntryForIdsInJob( request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { const auto& op_desc = request_entry->desc().op_desc(); const std::shared_ptr& runtime_request_info = request_entry->GetRuntimeRequest(local_rank); const OpType op_type = op_desc.op_type(); const void* send_buff = runtime_request_info->send_buff; void* recv_buff = runtime_request_info->recv_buff; const int64_t elem_cnt = request_entry->elem_cnt(); const ncclDataType_t nccl_data_type = GetNcclDataType(op_desc.data_type()); const int32_t num_ranks = comm_group.global_rank_count(); if (op_type == OpType::kOpTypeAllReduce) { OF_NCCL_CHECK(ncclAllReduce(send_buff, recv_buff, elem_cnt, nccl_data_type, GetNcclReduceOp(op_desc.reduce_method()), comm, stream_ctx->stream())); } else if (op_type == OpType::kOpTypeAllGather) { CHECK_EQ(elem_cnt % num_ranks, 0); OF_NCCL_CHECK(ncclAllGather(send_buff, recv_buff, elem_cnt / num_ranks, nccl_data_type, comm, stream_ctx->stream())); } else if (op_type == OpType::kOpTypeReduceScatter) { CHECK_EQ(elem_cnt % num_ranks, 0); OF_NCCL_CHECK(ncclReduceScatter( send_buff, recv_buff, elem_cnt / num_ranks, nccl_data_type, GetNcclReduceOp(op_desc.reduce_method()), comm, stream_ctx->stream())); } else if (op_type == OpType::kOpTypeReduce) { OF_NCCL_CHECK(ncclReduce(send_buff, recv_buff, elem_cnt, nccl_data_type, GetNcclReduceOp(op_desc.reduce_method()), op_desc.root(), comm, stream_ctx->stream())); } else if (op_type == OpType::kOpTypeBroadcast) { OF_NCCL_CHECK(ncclBroadcast(send_buff, recv_buff, elem_cnt, nccl_data_type, op_desc.root(), comm, stream_ctx->stream())); } else if (op_type == OpType::kOpTypeAll2All) { #if NCCL_VERSION_CODE > 2700 const int64_t elem_per_rank = elem_cnt / num_ranks; const int64_t elem_per_chunk = elem_per_rank / num_ranks; const int64_t dtype_size = GetSizeOfDataType(op_desc.data_type()); const int64_t chunk_size = elem_per_chunk * dtype_size; for (int64_t j = 0; j < num_ranks; ++j) { OF_NCCL_CHECK(ncclSend(reinterpret_cast( reinterpret_cast(send_buff) + j * chunk_size), elem_per_chunk, nccl_data_type, j, comm, stream_ctx->stream())); OF_NCCL_CHECK(ncclRecv( reinterpret_cast(reinterpret_cast(recv_buff) + j * chunk_size), elem_per_chunk, nccl_data_type, j, comm, stream_ctx->stream())); } #else UNIMPLEMENTED(); #endif } else { UNIMPLEMENTED(); } }); } OF_NCCL_CHECK(ncclGroupEnd()); } void AddCallbackAndResetRuntimeRequest( const CommGroup& comm_group, const std::vector>& device_id2stream_ctx, const std::shared_ptr& request_store, const std::vector& request_ids) { std::vector>> saved_runtime_request_info( request_ids.size()); request_store->ForEachMutRequestEntryForIdsInJob( request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { saved_runtime_request_info.at(i) = std::move(request_entry->ResetRuntimeRequest()); }); for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) { const CommRank& comm_rank = comm_group.GetCommRank(local_rank); StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get(); auto runtime_request_info_vec = std::make_shared>>(); runtime_request_info_vec->reserve(request_ids.size()); request_store->ForEachMutRequestEntryForIdsInJob( request_ids, 
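// AddCallbackAndResetRuntimeRequest first swaps each request's RuntimeRequestInfo vector out of
// the store (ResetRuntimeRequest), then, per local rank, moves the matching infos into a shared
// vector captured by the stream callback below. When the CUDA event recorded behind the
// collective completes, every info's callback is invoked with an Ok status, signalling that the
// collective for that rank has finished.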
[&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { runtime_request_info_vec->emplace_back( std::move(saved_runtime_request_info.at(i).at(local_rank))); }); OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id())); stream_ctx->AddCallback([runtime_request_info_vec]() { for (auto& runtime_request_info : *runtime_request_info_vec) { runtime_request_info->callback(Maybe::Ok()); } }); } } } // namespace class NcclExecutorBackend : public ExecutorBackend { public: OF_DISALLOW_COPY_AND_MOVE(NcclExecutorBackend); NcclExecutorBackend(); ~NcclExecutorBackend() override; private: void Init(std::shared_ptr request_store) override; void InitJob(int64_t job_id) override; void DeinitJob(int64_t job_id) override; void GroupRequests(const std::vector& request_ids, const std::function&&, void*)>& Handler) override; void ExecuteGroup(void* group_token) override; void* CreateGroupToken(const std::vector& group) override; void DestroyGroupToken(void* group_token) override; struct Impl; std::unique_ptr impl_; }; struct NcclExecutorBackend::Impl { Impl(const CollectiveBoxingConf& conf, std::shared_ptr request_store) : conf(conf), request_store(std::move(request_store)) { CHECK_GT(conf.nccl_num_streams(), 0); CHECK_GE(conf.nccl_fusion_threshold_mb(), 0); fusion_threshold = conf.nccl_fusion_threshold_mb() * 1024 * 1024; num_streams = conf.nccl_num_streams(); current_stream_id = 0; enable_mixed_fusion = (!conf.nccl_fusion_all_reduce_use_buffer()) && conf.nccl_enable_mixed_fusion(); int nccl_version; OF_NCCL_CHECK(ncclGetVersion(&nccl_version)); if (nccl_version == 21003) { LOG(WARNING) << "Current nccl version is 2.10.3, in this version, ncclGroup() with mixed " "datatype/element/collective could induce crash or corruption, so we will not " "fuse any request."; } InitStreamCtx(); InitIsOpTypeFusionEnabled(); } ~Impl() { stream_id2device_id2stream_ctx.clear(); device_set2stream_id2comm_group.clear(); } void InitCommGroup(int64_t job_id) { std::set local_device_ids; request_store->ForEachMutRequestEntryInJob( job_id, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { const auto& request = request_entry->desc(); if (request.op_desc().device_type() != DeviceType::kCUDA) { return; } if (!request_entry->HasRankOnThisNode()) { return; } const DeviceSet& device_set = request.device_set(); if (device_set2stream_id2comm_group.count(device_set) > 0) { return; } auto& stream_id2comm_group = device_set2stream_id2comm_group[device_set]; stream_id2comm_group.resize(num_streams); for (int32_t stream_id = 0; stream_id < num_streams; ++stream_id) { stream_id2comm_group.at(stream_id).InitGroup( device_set, GetNcclUniqueIdRpcKey(request.op_desc().name(), stream_id)); } for (int32_t j = 0; j < stream_id2comm_group.at(0).local_rank_count(); ++j) { local_device_ids.emplace(stream_id2comm_group.at(0).GetCommRank(j).device_id()); } }); for (int32_t stream_id = 0; stream_id < num_streams; ++stream_id) { for (const int64_t device_id : local_device_ids) { if (stream_id2device_id2stream_ctx.at(stream_id).at(device_id) == nullptr) { stream_id2device_id2stream_ctx.at(stream_id).at(device_id) = std::make_unique(device_id, fusion_threshold); } } } } void InitStreamCtx() { int32_t num_devices; OF_CUDA_CHECK(cudaGetDeviceCount(&num_devices)); stream_id2device_id2stream_ctx.resize(num_streams); for (int64_t stream_id = 0; stream_id < num_streams; ++stream_id) { stream_id2device_id2stream_ctx.at(stream_id).resize(num_devices); } } void InitIsOpTypeFusionEnabled() { 
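// Builds an OpType-indexed table of fusion switches from CollectiveBoxingConf: all-reduce,
// all-gather, reduce-scatter, reduce and broadcast can each be toggled independently
// (nccl_fusion_all_reduce(), nccl_fusion_all_gather(), ...), while all-to-all is never fused.
// CanRequestEntryFuse() below further refuses to fuse across device sets, across op types unless
// mixed fusion is enabled, and on NCCL 2.10.3 (workaround for NVIDIA/nccl issue 560).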
op_type2fusion_enabled.resize(OpType_ARRAYSIZE, false); op_type2fusion_enabled.at(OpType::kOpTypeAllReduce) = conf.nccl_fusion_all_reduce(); op_type2fusion_enabled.at(OpType::kOpTypeAllGather) = conf.nccl_fusion_all_gather(); op_type2fusion_enabled.at(OpType::kOpTypeReduceScatter) = conf.nccl_fusion_reduce_scatter(); op_type2fusion_enabled.at(OpType::kOpTypeReduce) = conf.nccl_fusion_reduce(); op_type2fusion_enabled.at(OpType::kOpTypeBroadcast) = conf.nccl_fusion_broadcast(); op_type2fusion_enabled.at(OpType::kOpTypeAll2All) = false; } int32_t NextStreamId() { const int32_t stream_id = current_stream_id; current_stream_id = (current_stream_id + 1) % num_streams; return stream_id; } bool IsOpTypeFusionEnabled(OpType op_type) const { return op_type2fusion_enabled.at(op_type); } bool IsRequestEntryFusionEnabled(const RequestEntry* entry) const { return IsOpTypeFusionEnabled(entry->desc().op_desc().op_type()); } bool CanRequestEntryFuse(const RequestEntry* lhs, const RequestEntry* rhs) const { { int nccl_version; OF_NCCL_CHECK(ncclGetVersion(&nccl_version)); // Workaround for https://github.com/NVIDIA/nccl/issues/560 if (nccl_version == 21003) { return false; } } if (lhs->device_set_symbol() != rhs->device_set_symbol()) { return false; } if ((!IsRequestEntryFusionEnabled(lhs)) || (!IsRequestEntryFusionEnabled(rhs))) { return false; } if ((!enable_mixed_fusion) && lhs->desc().op_desc().op_type() != rhs->desc().op_desc().op_type()) { return false; } if (conf.nccl_fusion_all_reduce_use_buffer()) { if (lhs->desc().op_desc().op_type() == OpType::kOpTypeAllReduce && rhs->desc().op_desc().op_type() == OpType::kOpTypeAllReduce) { CHECK(lhs->desc().op_desc().has_reduce_method()); CHECK(rhs->desc().op_desc().has_reduce_method()); return lhs->desc().op_desc().reduce_method() == rhs->desc().op_desc().reduce_method() && lhs->desc().op_desc().data_type() == rhs->desc().op_desc().data_type(); } else if (lhs->desc().op_desc().op_type() == OpType::kOpTypeAllReduce || rhs->desc().op_desc().op_type() == OpType::kOpTypeAllReduce) { return false; } else { return true; } } else { return true; } } void GroupRequests(const std::vector& request_ids, const std::function&&, void*)>& Handler) { std::vector group; int64_t group_size = 0; const int64_t fusion_max_ops = std::min(conf.nccl_fusion_max_ops(), kMultiCopyParamsMaxSize); request_store->ForEachMutRequestEntryForIdsInJob( request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { const auto& request = request_entry->desc(); const int64_t size = GetMultiCopyAlignedSize(request_entry->size_in_bytes()); if (group.empty() || !CanRequestEntryFuse(request_store->MutRequestEntry(group.back()), request_entry) || group_size + size > fusion_threshold || group.size() >= fusion_max_ops) { if (!group.empty()) { void* token = CreateGroupToken(group); Handler(std::move(group), token); group.clear(); group_size = 0; } } group.emplace_back(request_id); group_size += size; }); if (!group.empty()) { void* token = CreateGroupToken(group); Handler(std::move(group), token); } } struct GroupToken { GroupToken(const std::vector& group, std::vector* stream_id2comm_group) : request_ids(group), stream_id2comm_group(stream_id2comm_group) {} std::vector request_ids; std::vector* stream_id2comm_group; }; void* CreateGroupToken(const std::vector& group) { CHECK_GT(group.size(), 0); void* group_token; const DeviceSet& first_device_set = request_store->MutRequestEntry(group.front())->desc().device_set(); auto it = device_set2stream_id2comm_group.find(first_device_set); 
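// GroupRequests (above) forms groups greedily: a request starts a new group when it cannot fuse
// with the previous entry, when the accumulated aligned size would exceed fusion_threshold
// (nccl_fusion_threshold_mb() * 1 MiB), or when the group reaches fusion_max_ops /
// kMultiCopyParamsMaxSize (128). Sizes are padded to kMultiCopyAlignSize = 32 bytes, so for
// example a 100-byte request contributes GetMultiCopyAlignedSize(100) = ((100 + 31) / 32) * 32
// = 128 bytes toward the threshold. The group token built here simply pairs the request ids with
// the per-stream CommGroup vector of their single, shared device set.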
CHECK(it != device_set2stream_id2comm_group.end()); group_token = new GroupToken(group, &it->second); request_store->ForEachMutRequestEntryForIdsInJob( group, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { const DeviceSet& device_set = request_entry->desc().device_set(); CHECK(first_device_set == device_set); }); return group_token; } void DestroyGroupToken(void* group_token) { GroupToken* token = static_cast(group_token); delete token; } void ExecuteGroup(void* group_token) { GroupToken* token = static_cast(group_token); const std::vector& request_ids = token->request_ids; if (request_ids.empty()) { return; } const int32_t stream_id = NextStreamId(); CudaCurrentDeviceGuard device_guard; const auto& comm_group = token->stream_id2comm_group->at(stream_id); auto& device_id2stream_ctx = stream_id2device_id2stream_ctx.at(stream_id); if (request_store->MutRequestEntry(request_ids.front())->desc().op_desc().op_type() == OpType::kOpTypeAllReduce && conf.nccl_fusion_all_reduce_use_buffer() && request_ids.size() > 1) { LaunchFusedAllReduce(comm_group, device_id2stream_ctx, request_store, request_ids); } else { LaunchAggregatedOps(comm_group, device_id2stream_ctx, request_store, request_ids); } AddCallbackAndResetRuntimeRequest(comm_group, device_id2stream_ctx, request_store, request_ids); } CollectiveBoxingConf conf; int64_t fusion_threshold; int32_t num_streams; int32_t current_stream_id; bool enable_mixed_fusion; std::vector op_type2fusion_enabled; std::shared_ptr request_store; HashMap> device_set2stream_id2comm_group; std::vector>> stream_id2device_id2stream_ctx; }; NcclExecutorBackend::NcclExecutorBackend() = default; NcclExecutorBackend::~NcclExecutorBackend() = default; void NcclExecutorBackend::Init(std::shared_ptr request_store) { impl_ = std::make_unique( Singleton::Get()->collective_boxing_conf(), request_store); } void NcclExecutorBackend::InitJob(int64_t job_id) { CudaCurrentDeviceGuard guard; impl_->InitCommGroup(job_id); } void NcclExecutorBackend::DeinitJob(int64_t job_id) {} void NcclExecutorBackend::GroupRequests( const std::vector& request_ids, const std::function&&, void*)>& Handler) { impl_->GroupRequests(request_ids, Handler); } void* NcclExecutorBackend::CreateGroupToken(const std::vector& group) { return impl_->CreateGroupToken(group); } void NcclExecutorBackend::DestroyGroupToken(void* group_token) { return impl_->DestroyGroupToken(group_token); } void NcclExecutorBackend::ExecuteGroup(void* group_token) { impl_->ExecuteGroup(group_token); } REGISTER_EXECUTOR_BACKEND(DeviceType::kCUDA, NcclExecutorBackend); } // namespace collective } // namespace boxing } // namespace oneflow ================================================ FILE: oneflow/core/job/collective_boxing/request_store.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/collective_boxing/request_store.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/data_type.h" namespace oneflow { namespace boxing { namespace collective { RequestEntry::RequestEntry(const RequestDesc& desc) : desc_(desc) { std::set node_ids; for (int64_t global_rank = 0; global_rank < desc.device_set().device().size(); ++global_rank) { const DeviceDesc& device = desc.device_set().device(global_rank); if (device.machine_id() == GlobalProcessCtx::Rank()) { local_device_vec_.emplace_back(device); global_rank2local_rank_.emplace(global_rank, local_rank2global_rank_.size()); local_rank2global_rank_.emplace_back(global_rank); } node_ids.emplace(device.machine_id()); } const size_t local_rank_count = local_device_vec_.size(); node_count_ = node_ids.size(); state_.runtime_request_info_vec.resize(local_rank_count); state_.runtime_request_count = 0; elem_cnt_ = Shape(desc.op_desc().shape()).elem_cnt(); size_in_bytes_ = elem_cnt_ * GetSizeOfDataType(desc.op_desc().data_type()); device_set_symbol_.reset(desc.device_set()); } bool RequestEntry::AddRuntimeRequest( int32_t local_rank, std::shared_ptr runtime_request_info) { CHECK_LT(local_rank, state_.runtime_request_info_vec.size()); std::lock_guard lock(state_.mutex); CHECK(!state_.runtime_request_info_vec.at(local_rank)); state_.runtime_request_info_vec.at(local_rank) = std::move(runtime_request_info); state_.runtime_request_count += 1; return state_.runtime_request_count == state_.runtime_request_info_vec.size(); } const std::shared_ptr& RequestEntry::GetRuntimeRequest( int32_t local_rank) { std::lock_guard lock(state_.mutex); return state_.runtime_request_info_vec.at(local_rank); } std::vector> RequestEntry::ResetRuntimeRequest() { std::lock_guard lock(state_.mutex); std::vector> ret( state_.runtime_request_info_vec.size()); ret.swap(state_.runtime_request_info_vec); state_.runtime_request_count = 0; return ret; } void RequestStore::InitJob(int64_t job_id, const RequestSet& request_set) { std::vector>& request_entry_vec = job_id2request_entry_vec_[job_id]; CHECK_EQ(request_entry_vec.size(), 0); for (const RequestDesc& desc : request_set.request()) { request_entry_vec.emplace_back(std::make_unique(desc)); } for (int32_t i = 0; i < request_entry_vec.size(); ++i) { const std::unique_ptr& entry = request_entry_vec.at(i); CHECK(name2request_id_.emplace(entry->desc().op_desc().name(), RequestId(job_id, i)).second); } } void RequestStore::DeinitJob(int64_t job_id) { const auto& it = job_id2request_entry_vec_.find(job_id); CHECK(it != job_id2request_entry_vec_.end()); const auto& request_entry_vec = it->second; for (const auto& request_entry : request_entry_vec) { name2request_id_.erase(request_entry->desc().op_desc().name()); } job_id2request_entry_vec_.erase(job_id); } struct RequestEntryToken { RequestEntry* request_entry; }; void* RequestStore::CreateRequestEntryToken(const RequestId& request_id) { auto it = job_id2request_entry_vec_.find(request_id.job_id); CHECK(it != job_id2request_entry_vec_.end()); return new RequestEntryToken{it->second.at(request_id.request_index).get()}; } void RequestStore::DestroyRequestEntryToken(void* request_entry_token) { auto token = static_cast(request_entry_token); delete token; } RequestEntry* RequestStore::GetRequestEntry(void* request_entry_token) { return 
static_cast(request_entry_token)->request_entry; } } // namespace collective } // namespace boxing } // namespace oneflow ================================================ FILE: oneflow/core/job/collective_boxing/request_store.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_REQUEST_STORE_H_ #define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_REQUEST_STORE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/collective_boxing/runtime_request_info.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" namespace oneflow { namespace boxing { namespace collective { class RequestEntry final { public: OF_DISALLOW_COPY_AND_MOVE(RequestEntry); RequestEntry(const RequestDesc& desc); ~RequestEntry() = default; const RequestDesc& desc() const { return desc_; } int32_t LocalRankCount() const { return local_rank2global_rank_.size(); } int32_t LocalRankToGlobalRank(int32_t local_rank) const { return local_rank2global_rank_.at(local_rank); } int32_t GlobalRankToLocalRank(int32_t global_rank) const { return global_rank2local_rank_.at(global_rank); } bool HasRankOnThisNode() const { return !local_rank2global_rank_.empty(); } int32_t NodeCount() const { return node_count_; } const DeviceDesc& LocalDeviceDesc(int32_t local_rank) const { return local_device_vec_.at(local_rank); } bool IsRootOnThisNode() const { return (!local_rank2global_rank_.empty()) && local_rank2global_rank_.front() == 0; } bool AddRuntimeRequest(int32_t local_rank, std::shared_ptr runtime_request_info); const std::shared_ptr& GetRuntimeRequest(int32_t local_rank); std::vector> ResetRuntimeRequest(); int64_t elem_cnt() const { return elem_cnt_; } int64_t size_in_bytes() const { return size_in_bytes_; } const Symbol& device_set_symbol() const { return device_set_symbol_; } private: RequestDesc desc_; int32_t node_count_; std::vector local_device_vec_; std::vector local_rank2global_rank_; std::map global_rank2local_rank_; int64_t elem_cnt_; int64_t size_in_bytes_; Symbol device_set_symbol_; struct State { std::vector> runtime_request_info_vec; int32_t runtime_request_count; std::mutex mutex; }; State state_; }; struct RequestId { RequestId(int64_t job_id, int32_t request_index) : job_id(job_id), request_index(request_index) {} int64_t job_id; int32_t request_index; }; class RequestStore { public: OF_DISALLOW_COPY_AND_MOVE(RequestStore); RequestStore() = default; ~RequestStore() = default; void InitJob(int64_t job_id, const RequestSet& request_set); void DeinitJob(int64_t job_id); RequestEntry* MutRequestEntry(const RequestId& request_id) { auto it = job_id2request_entry_vec_.find(request_id.job_id); CHECK(it != job_id2request_entry_vec_.end()); return it->second.at(request_id.request_index).get(); } void ForEachMutRequestEntryForIdsInJob( const std::vector& request_ids, const std::function& Handler) { if (request_ids.size() == 0) { 
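// An empty id list is a no-op; for a non-empty list every id must carry the same job_id, which
// the CHECK_EQ inside the loop below enforces before the handler is invoked.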
return; } int64_t job_id = request_ids.front().job_id; auto it = job_id2request_entry_vec_.find(job_id); CHECK(it != job_id2request_entry_vec_.end()); for (int32_t i = 0; i < request_ids.size(); ++i) { CHECK_EQ(request_ids.at(i).job_id, job_id); Handler(it->second.at(request_ids.at(i).request_index).get(), i, request_ids.at(i)); } } void ForEachMutRequestEntryInJob( int64_t job_id, const std::function& Handler) { auto it = job_id2request_entry_vec_.find(job_id); CHECK(it != job_id2request_entry_vec_.end()); for (int32_t i = 0; i < it->second.size(); ++i) { RequestId request_id(job_id, i); Handler(it->second.at(i).get(), i, request_id); } } int32_t RequestCountForJob(int64_t job_id) const { const auto& it = job_id2request_entry_vec_.find(job_id); CHECK(it != job_id2request_entry_vec_.end()); return it->second.size(); } RequestId GetRequestIdByName(const std::string& name) const { return name2request_id_.at(name); } void* CreateRequestEntryToken(const RequestId& request_id); void DestroyRequestEntryToken(void* token); RequestEntry* GetRequestEntry(void* token); private: HashMap>> job_id2request_entry_vec_; HashMap name2request_id_; }; } // namespace collective } // namespace boxing } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_REQUEST_STORE_H_ ================================================ FILE: oneflow/core/job/collective_boxing/runtime_request_info.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_RUNTIME_REQUEST_INFO_H_ #define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_RUNTIME_REQUEST_INFO_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace boxing { namespace collective { struct RuntimeRequestInfo { const void* send_buff; void* recv_buff; std::function&)> callback; }; } // namespace collective } // namespace boxing } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_RUNTIME_REQUEST_INFO_H_ ================================================ FILE: oneflow/core/job/collective_boxing/scheduler.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/collective_boxing/scheduler.h" #include "oneflow/core/job/collective_boxing/executor.h" #include "oneflow/core/job/collective_boxing/request_store.h" #include "oneflow/core/job/collective_boxing/coordinator.h" #include "oneflow/core/job/collective_boxing/static_group_coordinator.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/collective_boxing/executor_backend_manager.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace boxing { namespace collective { namespace { bool CanMergeIntoCurGroup(RequestStore* request_store, const RequestEntry* request_entry, const RequestId& request_id, const std::vector& group_buffer) { if (group_buffer.empty()) { return true; } const RequestId& group_entry_id = group_buffer.front(); const auto* group_entry = request_store->MutRequestEntry(group_entry_id); return (request_id.job_id == group_entry_id.job_id && request_entry->desc().dependency_depth() == group_entry->desc().dependency_depth() && request_entry->desc().op_desc().device_type() == group_entry->desc().op_desc().device_type() && request_entry->device_set_symbol() == group_entry->device_set_symbol()); } bool HasRankInteraction(const DeviceSet& a, const DeviceSet& b) { for (int64_t i = 0; i < a.device_size(); ++i) { const DeviceDesc& a_device_desc = a.device(i); for (int64_t j = 0; j < b.device_size(); ++j) { if (a_device_desc.machine_id() == b.device(j).machine_id()) { return true; } } } return false; } } // namespace class RequestHandle final { public: OF_DISALLOW_COPY_AND_MOVE(RequestHandle); RequestHandle(int32_t local_rank, void* request_entry_token, void* coordinator_token) : local_rank_(local_rank), request_entry_token_(request_entry_token), coordinator_token_(coordinator_token) {} ~RequestHandle() = default; int32_t local_rank() const { return local_rank_; } void* request_entry_token() { return request_entry_token_; } void* coordinator_token() { return coordinator_token_; } private: int32_t local_rank_; void* request_entry_token_; void* coordinator_token_; }; class GroupToken final { public: OF_DISALLOW_COPY_AND_MOVE(GroupToken); GroupToken(DeviceType device_type, void* backend_group_token) : device_type_(device_type), backend_group_token_(backend_group_token) {} ~GroupToken() = default; DeviceType device_type() { return device_type_; } void* backend_group_token() { return backend_group_token_; } private: DeviceType device_type_; void* backend_group_token_; }; class ExecutorImpl : public Executor { public: ExecutorImpl() = default; ~ExecutorImpl() override = default; void Init(std::shared_ptr request_store) override; void InitJob(int64_t job_id) override; void DeinitJob(int64_t job_id) override; void GroupRequests( const std::vector& request_ids, const std::function&&, GroupToken*)>& Handler) override; void ExecuteGroup(GroupToken* group_token) override; void DestroyGroupToken(GroupToken* group_token) override; private: DeviceType GetUniqueDeviceType(const std::vector& group); GroupToken* CreateGroupToken(const std::vector& group, void* backend_group_token); std::vector> backends_; std::shared_ptr request_store_; std::vector group_buffer_; }; void ExecutorImpl::Init(std::shared_ptr request_store) { request_store_ = request_store; backends_.resize(DeviceType_ARRAYSIZE); const auto& vaild_executor_device_types = 
ExecutorBackendMgr::Get().vaild_executor_device_types(); CHECK_LE(vaild_executor_device_types.size(), 1) << "Currently only one backend is supported at the same time"; for (DeviceType device_type : vaild_executor_device_types) { size_t dev_count = Singleton::Get()->GetDeviceCount(device_type); if (dev_count > 0) { std::unique_ptr backend = ExecutorBackendMgr::Get().NewExecutorBackend(device_type); CHECK(backend); backend->Init(request_store_); backends_.at(device_type) = std::move(backend); } } } void ExecutorImpl::InitJob(int64_t job_id) { const auto& vaild_executor_device_types = ExecutorBackendMgr::Get().vaild_executor_device_types(); for (DeviceType device_type : vaild_executor_device_types) { CHECK(backends_.at(device_type)); backends_.at(device_type)->InitJob(job_id); } } void ExecutorImpl::DeinitJob(int64_t job_id) { const auto& vaild_executor_device_types = ExecutorBackendMgr::Get().vaild_executor_device_types(); for (DeviceType device_type : vaild_executor_device_types) { CHECK(backends_.at(device_type)); backends_.at(device_type)->DeinitJob(job_id); } } GroupToken* ExecutorImpl::CreateGroupToken(const std::vector& group, void* backend_group_token) { return new GroupToken(GetUniqueDeviceType(group), backend_group_token); } void ExecutorImpl::DestroyGroupToken(GroupToken* group_token) { const auto& vaild_executor_device_types = ExecutorBackendMgr::Get().vaild_executor_device_types(); for (DeviceType device_type : vaild_executor_device_types) { CHECK(backends_.at(device_type)); backends_.at(device_type)->DestroyGroupToken(group_token->backend_group_token()); } delete group_token; } void ExecutorImpl::GroupRequests( const std::vector& request_ids, const std::function&&, GroupToken*)>& Handler) { if (request_ids.empty()) { return; } const CollectiveBoxingConf& conf = Singleton::Get()->collective_boxing_conf(); auto BackendHandler = [&](std::vector&& group, void* backend_group_token) { GroupToken* group_token = CreateGroupToken(group, backend_group_token); Handler(std::move(group), group_token); }; auto HandleGroup = [&]() { if (group_buffer_.empty()) { return; } const auto device_type = request_store_->MutRequestEntry(group_buffer_.front())->desc().op_desc().device_type(); backends_.at(device_type)->GroupRequests(group_buffer_, BackendHandler); group_buffer_.clear(); }; request_store_->ForEachMutRequestEntryForIdsInJob( request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { if (request_entry->HasRankOnThisNode()) { if (!(conf.enable_fusion() && CanMergeIntoCurGroup(request_store_.get(), request_entry, request_id, group_buffer_))) { HandleGroup(); } group_buffer_.emplace_back(request_id); } else { if (!group_buffer_.empty() && HasRankInteraction( request_store_->MutRequestEntry(group_buffer_.back())->desc().device_set(), request_entry->desc().device_set())) { HandleGroup(); } } }); HandleGroup(); } void ExecutorImpl::ExecuteGroup(GroupToken* group_token) { const DeviceType device_type = group_token->device_type(); backends_.at(device_type)->ExecuteGroup(group_token->backend_group_token()); } DeviceType ExecutorImpl::GetUniqueDeviceType(const std::vector& group) { const DeviceType device_type = request_store_->MutRequestEntry(group.front())->desc().op_desc().device_type(); request_store_->ForEachMutRequestEntryForIdsInJob( group, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { CHECK_EQ(request_entry->desc().op_desc().device_type(), device_type); }); return device_type; } struct Scheduler::Impl { Impl(); std::shared_ptr 
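// Scheduler::Impl wires the three collaborators together: a RequestStore (per-job request
// metadata), an ExecutorImpl (groups requests and dispatches them to the device backend) and a
// StaticGroupCoordinator (decides when a group actually runs). AddPlan() walks
// plan.collective_boxing_plan().job_id2request_set() and calls InitJob on all three; at runtime
// Scheduler::Schedule() records one local rank's RuntimeRequestInfo and, once every local rank
// of that request has arrived (AddRuntimeRequest returns true), forwards the request to the
// coordinator through its coordinator token.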
request_store; std::shared_ptr executor; std::shared_ptr coordinator; }; Scheduler::Impl::Impl() { request_store.reset(new RequestStore()); executor.reset(new ExecutorImpl()); executor->Init(request_store); coordinator.reset(new StaticGroupCoordinator()); coordinator->Init(request_store, executor); } class SchedulerPlanToken { public: OF_DISALLOW_COPY_AND_MOVE(SchedulerPlanToken); explicit SchedulerPlanToken(const std::vector& job_ids) : job_ids_(job_ids) {} ~SchedulerPlanToken() = default; const std::vector& job_ids() const { return job_ids_; } private: std::vector job_ids_; }; SchedulerPlanToken* Scheduler::AddPlan(const Plan& plan) { std::vector job_ids; for (const auto& job_id7request_set : plan.collective_boxing_plan().job_id2request_set()) { const int64_t job_id = job_id7request_set.first; job_ids.emplace_back(job_id); impl_->request_store->InitJob(job_id, job_id7request_set.second); impl_->executor->InitJob(job_id); impl_->coordinator->InitJob(job_id); } return new SchedulerPlanToken(job_ids); } void Scheduler::DeletePlan(SchedulerPlanToken* plan_token) { const std::vector& job_ids = plan_token->job_ids(); for (const auto& job_id : job_ids) { impl_->coordinator->DeinitJob(job_id); impl_->executor->DeinitJob(job_id); impl_->request_store->DeinitJob(job_id); } delete plan_token; } Scheduler::Scheduler() { impl_.reset(new Impl()); } Scheduler::~Scheduler() = default; RequestHandle* Scheduler::CreateRequestHandle(const RankDesc& rank_desc) { const RequestId& request_id = impl_->request_store->GetRequestIdByName(rank_desc.op_desc().name()); auto* request_entry = impl_->request_store->MutRequestEntry(request_id); CHECK(rank_desc.op_desc() == request_entry->desc().op_desc()); const int32_t local_rank = request_entry->GlobalRankToLocalRank(rank_desc.rank()); void* request_entry_token = impl_->request_store->CreateRequestEntryToken(request_id); void* coordinator_token = impl_->coordinator->CreateCoordinatorToken(request_id); return new RequestHandle(local_rank, request_entry_token, coordinator_token); } void Scheduler::DestroyRequestHandle(RequestHandle* handle) { impl_->coordinator->DestroyCoordinatorToken(handle->coordinator_token()); impl_->request_store->DestroyRequestEntryToken(handle->request_entry_token()); } void Scheduler::Schedule(RequestHandle* handle, std::shared_ptr request_info) { const int32_t local_rank = handle->local_rank(); const bool ready = impl_->request_store->GetRequestEntry(handle->request_entry_token()) ->AddRuntimeRequest(local_rank, std::move(request_info)); if (ready) { impl_->coordinator->AddRequest(handle->coordinator_token()); } } } // namespace collective } // namespace boxing } // namespace oneflow ================================================ FILE: oneflow/core/job/collective_boxing/scheduler.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_SCHEDULER_H_ #define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_SCHEDULER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/collective_boxing/runtime_request_info.h" #include "oneflow/core/job/collective_boxing/request_store.h" #include "oneflow/core/job/plan.pb.h" namespace oneflow { namespace boxing { namespace collective { class RequestHandle; class SchedulerPlanToken; class Scheduler final { public: OF_DISALLOW_COPY_AND_MOVE(Scheduler); ~Scheduler(); RequestHandle* CreateRequestHandle(const RankDesc& rank_desc); void DestroyRequestHandle(RequestHandle*); void Schedule(RequestHandle* handle, std::shared_ptr request_info); SchedulerPlanToken* AddPlan(const Plan& plan); void DeletePlan(SchedulerPlanToken* plan_token); private: friend class Singleton; Scheduler(); struct Impl; std::unique_ptr impl_; }; } // namespace collective } // namespace boxing } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_SCHEDULER_H_ ================================================ FILE: oneflow/core/job/collective_boxing/static_group_coordinator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/collective_boxing/static_group_coordinator.h" #include "oneflow/core/job/collective_boxing/executor.h" #include "oneflow/core/job/collective_boxing/request_store.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/str_util.h" namespace oneflow { namespace boxing { namespace collective { namespace { void SortRequestIdsByOrder(RequestStore* request_store, std::vector* requests) { std::sort(requests->begin(), requests->end(), [request_store](const RequestId& a, const RequestId& b) { return request_store->MutRequestEntry(a)->desc().order() < request_store->MutRequestEntry(b)->desc().order(); }); } bool HasRankInteractionOnDeviceSet(const DeviceSet& a, const DeviceSet& b) { for (int64_t i = 0; i < a.device_size(); ++i) { const DeviceDesc& a_device_desc = a.device(i); for (int64_t j = 0; j < b.device_size(); ++j) { if (a_device_desc.machine_id() == b.device(j).machine_id()) { return true; } } } return false; } } // namespace struct GroupState { explicit GroupState(int32_t group_size) : index2is_ready(group_size), ready_request_count(0) {} void AddReadyRequest(int32_t index); bool IsReady() const; void Reset(); std::vector index2is_ready; int32_t ready_request_count; }; std::mutex mutex_; int64_t current_job_id_ = -1; int64_t current_group_idx_in_job_ = -1; struct RequestGroupIndex { int32_t group_id; int32_t index_in_group; }; class GroupToken; struct StaticGroupRequestsInfo { std::vector request_index2request_group_index; std::vector group_states; std::vector> group_id2request_ids; std::vector group_id2group_token; }; struct StaticGroupRequestsInfoToken { RequestId 
request_id; StaticGroupRequestsInfo* info; }; struct StaticGroupCoordinator::Impl { Impl(const std::shared_ptr& request_store, const std::shared_ptr& executor); std::shared_ptr request_store_; std::shared_ptr executor_; HashMap job_id2static_group_requests_info_; }; StaticGroupCoordinator::Impl::Impl(const std::shared_ptr& request_store, const std::shared_ptr& executor) : request_store_(request_store), executor_(executor) {} StaticGroupCoordinator::StaticGroupCoordinator() = default; StaticGroupCoordinator::~StaticGroupCoordinator() = default; void StaticGroupCoordinator::Init(std::shared_ptr request_store, std::shared_ptr executor) { impl_ = std::make_unique(request_store, executor); } void* StaticGroupCoordinator::CreateCoordinatorToken(const RequestId& request_id) { std::unique_lock lock(mutex_); auto it = impl_->job_id2static_group_requests_info_.find(request_id.job_id); CHECK(it != impl_->job_id2static_group_requests_info_.end()); return new StaticGroupRequestsInfoToken{request_id, &it->second}; } void StaticGroupCoordinator::DestroyCoordinatorToken(void* coordinator_token) { std::unique_lock lock(mutex_); auto token = static_cast(coordinator_token); delete token; } void StaticGroupCoordinator::InitJob(int64_t job_id) { std::unique_lock lock(mutex_); std::vector request_ids; impl_->request_store_->ForEachMutRequestEntryInJob( job_id, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { request_ids.emplace_back(request_id); }); SortRequestIdsByOrder(impl_->request_store_.get(), &request_ids); StaticGroupRequestsInfo info; std::vector& group_states = info.group_states; std::vector& request_index2request_group_index = info.request_index2request_group_index; std::vector>& group_id2request_ids = info.group_id2request_ids; std::vector& group_id2group_token = info.group_id2group_token; const int32_t request_count = impl_->request_store_->RequestCountForJob(job_id); request_index2request_group_index.resize(request_count); impl_->executor_->GroupRequests( request_ids, [&](std::vector&& group, GroupToken* group_token) { const int32_t group_id = group_states.size(); group_states.emplace_back(group.size()); for (int32_t idx_in_group = 0; idx_in_group < group.size(); ++idx_in_group) { const RequestId& request_id = group.at(idx_in_group); RequestGroupIndex request_group_index{group_id, idx_in_group}; request_index2request_group_index.at(request_id.request_index) = request_group_index; } group_id2request_ids.emplace_back(group); group_id2group_token.emplace_back(group_token); }); CHECK(impl_->job_id2static_group_requests_info_.emplace(job_id, info).second); if (group_states.size() != 0) { DumpSummary(job_id); } } void StaticGroupCoordinator::DeinitJob(int64_t job_id) { std::unique_lock lock(mutex_); const auto& it = impl_->job_id2static_group_requests_info_.find(job_id); CHECK(it != impl_->job_id2static_group_requests_info_.end()); const auto& group_id2group_token = it->second.group_id2group_token; for (int32_t group_id = 0; group_id < group_id2group_token.size(); ++group_id) { impl_->executor_->DestroyGroupToken(group_id2group_token.at(group_id)); } impl_->job_id2static_group_requests_info_.erase(job_id); } void StaticGroupCoordinator::AddRequest(void* coordinator_token) { std::unique_lock lock(mutex_); StaticGroupRequestsInfoToken* token = static_cast(coordinator_token); const RequestId& request_id = token->request_id; if (current_job_id_ == -1) { current_job_id_ = request_id.job_id; current_group_idx_in_job_ = 0; } else { CHECK_EQ(current_job_id_, request_id.job_id); } 
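// What follows is the static launch protocol: the arriving request only flips its slot in its
// group's GroupState to ready, while groups themselves are executed strictly in the order they
// were built for this job. Starting at current_group_idx_in_job_, every consecutive group whose
// requests are all ready is launched and reset, and the cursor advances round-robin; a fully
// ready group that is not yet the current one waits until all earlier groups of the job have
// been launched. When the cursor wraps back to 0 after launching at least one group, the
// per-job bookkeeping (current_job_id_, current_group_idx_in_job_) is cleared.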
StaticGroupRequestsInfo* info = token->info; const RequestGroupIndex& request_group_index = info->request_index2request_group_index.at(request_id.request_index); info->group_states.at(request_group_index.group_id) .AddReadyRequest(request_group_index.index_in_group); int64_t num_launched_groups = 0; while (true) { auto& group_state = info->group_states.at(current_group_idx_in_job_); if (group_state.IsReady()) { impl_->executor_->ExecuteGroup(info->group_id2group_token.at(current_group_idx_in_job_)); group_state.Reset(); current_group_idx_in_job_ = (current_group_idx_in_job_ + 1) % info->group_states.size(); num_launched_groups += 1; } else { break; } } if (current_group_idx_in_job_ == 0 && num_launched_groups > 0) { current_job_id_ = -1; current_group_idx_in_job_ = -1; } } void StaticGroupCoordinator::DumpSummary(const int64_t job_id) const { if (!Singleton::Get()->enable_debug_mode()) { return; } auto group_ls = TeePersistentLogStream::Create(StrCat("boxing/collective/job_", job_id)); const auto& it = impl_->job_id2static_group_requests_info_.find(job_id); CHECK(it != impl_->job_id2static_group_requests_info_.end()); const auto& group_id2request_ids = it->second.group_id2request_ids; for (int32_t group_id = 0; group_id < group_id2request_ids.size(); ++group_id) { group_ls << "group id: " << std::to_string(group_id) << "\n"; impl_->request_store_->ForEachMutRequestEntryForIdsInJob( group_id2request_ids.at(group_id), [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) { group_ls->Write(request_entry->desc()); }); } } void GroupState::AddReadyRequest(int32_t index) { CHECK(!index2is_ready.at(index)); CHECK(index2is_ready.at(index) = true); ready_request_count += 1; } bool GroupState::IsReady() const { return ready_request_count == index2is_ready.size(); } void GroupState::Reset() { ready_request_count = 0; std::fill(index2is_ready.begin(), index2is_ready.end(), false); } } // namespace collective } // namespace boxing } // namespace oneflow ================================================ FILE: oneflow/core/job/collective_boxing/static_group_coordinator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_STATIC_GROUP_COORDINATOR_H_ #define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_STATIC_GROUP_COORDINATOR_H_ #include "oneflow/core/job/collective_boxing/coordinator.h" namespace oneflow { namespace boxing { namespace collective { class RequestStore; class Executor; class StaticGroupCoordinator : public Coordinator { public: OF_DISALLOW_COPY_AND_MOVE(StaticGroupCoordinator); StaticGroupCoordinator(); ~StaticGroupCoordinator() override; void Init(std::shared_ptr request_store, std::shared_ptr executor) override; void InitJob(int64_t job_id) override; void DeinitJob(int64_t job_id) override; void AddRequest(void* coordinator_token) override; void* CreateCoordinatorToken(const RequestId& request_id) override; void DestroyCoordinatorToken(void* token) override; private: void DumpSummary(const int64_t job_id) const; struct Impl; std::unique_ptr impl_; }; } // namespace collective } // namespace boxing } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_STATIC_GROUP_COORDINATOR_H_ ================================================ FILE: oneflow/core/job/compile_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/compile_mode.h" #include "oneflow/core/common/env_var/env_var.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace { struct CompileModeName final : public CompileModeVisitor { static std::string VisitNaive() { return "naive"; } static std::string VisitRankPerProcess() { return "rank_per_process"; } static std::string VisitInValid() { return "invalid"; } }; std::unordered_map Name2CompileMode() { std::unordered_map name2compile_mode; for (int i = static_cast(CompileMode::kInvalid) + 1; i != static_cast(CompileMode::kEnd); ++i) { CompileMode compile_mode = static_cast(i); CHECK(name2compile_mode.emplace(CompileModeName::Visit(compile_mode), compile_mode).second); } return name2compile_mode; } std::string GetValidCompileModeNames() { std::stringstream ss; for (int i = static_cast(CompileMode::kInvalid) + 1; i != static_cast(CompileMode::kEnd); ++i) { if (i > static_cast(CompileMode::kInvalid) + 1) { ss << ", "; } CompileMode compile_mode = static_cast(i); ss << CompileModeName::Visit(compile_mode); } return ss.str(); } } // namespace Maybe CurrentCompileMode() { static thread_local CompileMode mode = JUST_MSG(MapAt(Name2CompileMode(), ThreadLocalEnvString()), std::stringstream() << "ONEFLOW_LAZY_COMPILER(value: " << ThreadLocalEnvString() << ") is invalid. valid options: \"" << GetValidCompileModeNames() << "\""); return mode; } } // namespace oneflow ================================================ FILE: oneflow/core/job/compile_mode.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_COMPILE_MODE_H_ #define ONEFLOW_CORE_JOB_COMPILE_MODE_H_ #include "oneflow/core/common/maybe.h" namespace oneflow { enum class CompileMode { kInvalid = 0, // make sure kInvalid is the first CompileMode kNaive, kRankPerProcess, kEnd, // make sure kEnd is the last CompileMode }; template struct CompileModeVisitor { template static auto Visit(CompileMode compile_mode, Args&&... args) { switch (compile_mode) { case CompileMode::kNaive: return DerivedT::VisitNaive(std::forward(args)...); case CompileMode::kRankPerProcess: return DerivedT::VisitRankPerProcess(std::forward(args)...); default: { LOG(FATAL) << "invalid compile mode"; return DerivedT::VisitInValid(std::forward(args)...); } } } }; Maybe CurrentCompileMode(); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COMPILE_MODE_H_ ================================================ FILE: oneflow/core/job/compiler.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/compiler.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/intra_job_mem_sharing_util.h" #include "oneflow/core/job/plan_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job_rewriter/job_completer.h" #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/cost_util.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { void Compiler::Compile(Job* job, Plan* plan) const { const auto& job_name = job->job_conf().job_name(); auto compile_tc = std::make_unique>(true, true); // Step1: new Singleton and set log configs. Singleton::New(*job); const JobDesc& job_desc = GlobalJobDesc(); compile_tc->Count("[GraphCompile]" + job_name + " NewOpGraph", 1); // Step2: build task_gph. // TODO(levi): we can rewrite this part of code in visitor pattern. 
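// Illustration (a sketch, not invoked by this pass): compile_mode.h above already uses the
// static-visitor (CRTP) pattern this TODO refers to. A visitor is a struct whose static
// VisitXxx() methods are dispatched by CompileModeVisitor::Visit, e.g. the name mapper from
// compile_mode.cpp:
//   struct CompileModeName final : public CompileModeVisitor<CompileModeName> {
//     static std::string VisitNaive() { return "naive"; }
//     static std::string VisitRankPerProcess() { return "rank_per_process"; }
//     static std::string VisitInValid() { return "invalid"; }
//   };
//   // CurrentCompileMode() reads ONEFLOW_LAZY_COMPILER and validates it against these names.
//   LOG(INFO) << "compile mode: " << CompileModeName::Visit(CHECK_JUST(CurrentCompileMode()));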
auto task_gph = CHECK_JUST(GlobalTaskGraph::New()); using std::placeholders::_1; LazyMode::Guard guard(true); task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); task_gph->TopoForEachNode(&TaskNode::Build); task_gph->RemoveEmptyRegsts(); task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful); task_gph->DecideExecutionOrder(); task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain(); auto IsReachable = Singleton::Get()->MakePredicatorIsOpNameDataOrCtrlReachable(); if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); } task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); }); compile_tc->Count("[GraphCompile]" + job_name + " BuildTaskGraph", 1, true); // Step3: put infomation from task_gph into plan. const int64_t node_num = task_gph->node_num(); const int64_t cpu_num = std::thread::hardware_concurrency(); const int64_t thread_pool_size = std::min(node_num, cpu_num); BlockingCounter counter(node_num); std::mutex mtx; ThreadPool thread_pool(thread_pool_size); task_gph->ForEachNode([&](TaskNode* task_node) { thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() { if (!task_node->IsMeaningLess()) { TaskProto task_proto; task_node->ToProto(&task_proto); { std::unique_lock guard(mtx); if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat || task_node->GetTaskType() == kAcc) { PlanUtil::CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); } plan->mutable_task()->Add(std::move(task_proto)); } // guard(mtx) } counter.Decrease(); } /* thread_pool.AddWork */); } /* task_gph->ForEachNode */); counter.WaitForeverUntilCntEqualZero(); // NOTE(levi): release task_gph here to decrise memory peak. task_gph.reset(); compile_tc->Count("[GraphCompile]" + job_name + " AddTaskToPlan", 1, true); // Step4: post-process for plan and delete Singleton. auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf(); (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf(); // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl // TODO(chengcheng): set inplace hint for cpu regst IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan); PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job); PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); PlanUtil::SetForceInplaceMemBlock(plan); compile_tc->Count("[GraphCompile]" + job_name + " InferMemShare", 1, true); Singleton::Delete(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/compiler.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_COMPILER_H_ #define ONEFLOW_CORE_JOB_COMPILER_H_ #include "oneflow/core/common/protobuf.h" #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { class Compiler final { public: OF_DISALLOW_COPY_AND_MOVE(Compiler); Compiler() = default; ~Compiler() = default; void Compile(Job*, Plan*) const; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_COMPILER_H_ ================================================ FILE: oneflow/core/job/critical_section.proto ================================================ syntax = "proto2"; package oneflow; message TotalJobCriticalSection {} message InputOutputCriticalSection { repeated string lbi_producer_op_name = 1; } message CriticalSection { required int64 job_id = 1; map machine_id2source_tick_op_name = 2; map machine_id2sink_tick_op_name = 3; repeated int64 mem_block_id = 4; repeated int64 chunk_id = 5; oneof type { TotalJobCriticalSection total_job_critical_section = 6; InputOutputCriticalSection input_output_critical_section = 7; } } ================================================ FILE: oneflow/core/job/critical_section_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/critical_section_desc.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include #include #include namespace oneflow { CriticalSection* CriticalSectionDesc::AddCriticalSection(int64_t job_id) { CHECK_EQ(inited_, false); auto critical_section = std::make_unique(); CriticalSection* ret = critical_section.get(); critical_section->set_job_id(job_id); critical_sections_.emplace_back(std::move(critical_section)); return ret; } void CriticalSectionDesc::Done() { CHECK_EQ(inited_, false); UpdateJobId2CriticalSectionIds(); UpdateJobId2TotalJobCriticalSectionId(); UpdateCriticalSectionIds2IntersectingIds(); CHECK_EQ(job_id2critical_section_ids_.size(), job_id2total_job_critical_section_id_.size()); CHECK_EQ(critical_sections_.size(), critical_section_id2intersecting_ids_.size()); inited_ = true; std::string all_output; int32_t i = 0; for (const auto& cs : critical_sections_) { all_output += "CriticalSection " + std::to_string(i) + "\n"; std::string output; google::protobuf::TextFormat::PrintToString(*cs, &output); all_output += output; all_output += "\n"; i++; } TeePersistentLogStream::Create("critical_section_desc")->Write(all_output); } const CriticalSection& CriticalSectionDesc::GetCriticalSection(int64_t critical_section_id) const { CHECK(inited_); return *critical_sections_.at(critical_section_id); } CriticalSection* CriticalSectionDesc::MutCriticalSection(int64_t critical_section_id) const { CHECK_EQ(inited_, false); return critical_sections_.at(critical_section_id).get(); } const std::vector& CriticalSectionDesc::CriticalSectionIds4JobId(int64_t job_id) const { CHECK(inited_); return job_id2critical_section_ids_.at(job_id); } void CriticalSectionDesc::DumpCriticalSectionId2IntersectinIds(PbRpf* id2id_list) const { CHECK(inited_); FOR_RANGE(int64_t, i, 0, critical_sections_.size()) { *id2id_list->Add()->mutable_value() = {critical_section_id2intersecting_ids_.at(i).begin(), critical_section_id2intersecting_ids_.at(i).end()}; } } void CriticalSectionDesc::UpdateJobId2CriticalSectionIds() { CHECK_EQ(inited_, false); job_id2critical_section_ids_.resize(critical_sections_.size()); int64_t max_job_id = -1; FOR_RANGE(int64_t, i, 0, critical_sections_.size()) { const auto& critical_section = *critical_sections_.at(i); int64_t job_id = critical_section.job_id(); job_id2critical_section_ids_[job_id].emplace_back(i); max_job_id = std::max(max_job_id, job_id); } job_id2critical_section_ids_.resize(max_job_id + 1); } void CriticalSectionDesc::UpdateJobId2TotalJobCriticalSectionId() { CHECK_EQ(inited_, false); HashSet unique_check; job_id2total_job_critical_section_id_.resize(critical_sections_.size()); FOR_RANGE(int64_t, i, 0, critical_sections_.size()) { const auto& critical_section = *critical_sections_.at(i); if (critical_section.has_total_job_critical_section()) { CHECK(unique_check.emplace(critical_section.job_id()).second); job_id2total_job_critical_section_id_.at(critical_section.job_id()) = i; } } job_id2total_job_critical_section_id_.resize(unique_check.size()); } void CriticalSectionDesc::UpdateCriticalSectionIds2IntersectingIds() { CHECK_EQ(inited_, false); critical_section_id2intersecting_ids_.resize(critical_sections_.size()); HashMap> mem_block_id2critical_section_ids; HashMap> chunk_id2critical_section_ids; FOR_RANGE(int64_t, i, 0, critical_sections_.size()) { for (int64_t mem_block_id : critical_sections_.at(i)->mem_block_id()) { mem_block_id2critical_section_ids[mem_block_id].insert(i); } for (int64_t chunk_id : 
critical_sections_.at(i)->chunk_id()) { chunk_id2critical_section_ids[chunk_id].insert(i); } } for (const auto& pair : mem_block_id2critical_section_ids) { for (int64_t first_id : pair.second) { for (int64_t second_id : pair.second) { if (first_id != second_id) { critical_section_id2intersecting_ids_[first_id].insert(second_id); } } } } for (const auto& pair : chunk_id2critical_section_ids) { for (int64_t first_id : pair.second) { for (int64_t second_id : pair.second) { if (first_id != second_id) { critical_section_id2intersecting_ids_[first_id].insert(second_id); } } } } } } // namespace oneflow ================================================ FILE: oneflow/core/job/critical_section_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_CRITICAL_SECTION_DESC_H_ #define ONEFLOW_CORE_JOB_CRITICAL_SECTION_DESC_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/job/critical_section.pb.h" namespace oneflow { class CriticalSectionDesc final { public: OF_DISALLOW_COPY_AND_MOVE(CriticalSectionDesc); ~CriticalSectionDesc() = default; CriticalSection* AddCriticalSection(int64_t job_id); void Done(); size_t CriticalSectionNum() const { return critical_sections_.size(); } const CriticalSection& GetCriticalSection(int64_t) const; CriticalSection* MutCriticalSection(int64_t) const; const std::vector& CriticalSectionIds4JobId(int64_t) const; void DumpCriticalSectionId2IntersectinIds(PbRpf* id2id_list) const; const std::vector>& job_id2critical_section_ids() const { return job_id2critical_section_ids_; } const std::vector& job_id2total_job_critical_section_id() const { return job_id2total_job_critical_section_id_; } private: friend class Singleton; CriticalSectionDesc() : inited_(false) {} void UpdateJobId2CriticalSectionIds(); void UpdateJobId2TotalJobCriticalSectionId(); void UpdateCriticalSectionIds2IntersectingIds(); bool inited_; std::vector> critical_sections_; std::vector> job_id2critical_section_ids_; std::vector job_id2total_job_critical_section_id_; std::vector> critical_section_id2intersecting_ids_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_CRITICAL_SECTION_DESC_H_ ================================================ FILE: oneflow/core/job/critical_section_instance.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_ #define ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_ #include #include "oneflow/core/common/util.h" namespace oneflow { class Blob; namespace ep { class Stream; } class CriticalSectionInstance { public: CriticalSectionInstance() = default; virtual const std::string& job_name() const = 0; virtual ~CriticalSectionInstance() = default; virtual void AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) const { UNIMPLEMENTED(); } virtual void Finish() const { UNIMPLEMENTED(); } }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_ ================================================ FILE: oneflow/core/job/distribute_hirarchy.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/job/sbp_parallel.proto"; enum DistributeType { kInvalidDistributeType = 0; kSpaceDistribute = 2; kTimeDistribute = 3; } message DistributeDim { required DistributeType distribute_type = 1; required SbpParallel sbp_parallel = 2; required int64 distribute_num = 3; } message DistributeHirarchy { repeated DistributeDim dim = 1; } ================================================ FILE: oneflow/core/job/dlnet_conf.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/operator/op_conf.proto"; message DLNetConf { repeated OperatorConf op = 1; } ================================================ FILE: oneflow/core/job/eager_ccl_comm_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/eager_ccl_comm_manager.h" namespace oneflow { const std::string EagerCclCommMgr::kDefaultCclStreamName = "DEFAULT"; EagerCclCommMgrBuilder& EagerCclCommMgrBuilder::Get() { static EagerCclCommMgrBuilder mgr; return mgr; } } // namespace oneflow ================================================ FILE: oneflow/core/job/eager_ccl_comm_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_EAGER_CCL_COMM_MANAGER_H_ #define ONEFLOW_CORE_JOB_EAGER_CCL_COMM_MANAGER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { class EagerCclCommMgr { public: static const std::string kDefaultCclStreamName; OF_DISALLOW_COPY_AND_MOVE(EagerCclCommMgr); virtual ~EagerCclCommMgr() = default; virtual void CreateCommFromPlan(const Plan& plan) = 0; virtual bool IsAsyncLaunchCclLogicalKernel() const = 0; virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0; virtual ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) = 0; virtual ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc, const std::string& stream_name) = 0; virtual ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc, const std::string& stream_name, const int64_t this_parallel_id, const std::string& comm_key) = 0; template T* As() { return dynamic_cast(this); } protected: EagerCclCommMgr() = default; }; class EagerCclCommMgrBuilder { public: using Creator = std::function; EagerCclCommMgrBuilder(EagerCclCommMgrBuilder const&) = delete; EagerCclCommMgrBuilder& operator=(EagerCclCommMgrBuilder const&) = delete; static EagerCclCommMgrBuilder& Get(); template void RegisterEagerCclCommMgrType(DeviceType device_type) { ccl_comm_mgr_reg_result_->emplace(device_type, []() -> EagerCclCommMgr* { return new Derived; }); vaild_ccl_comm_mgr_device_types_.emplace_back(device_type); } EagerCclCommMgr* NewCclCommMgr(DeviceType device_type) const { const auto& it = ccl_comm_mgr_reg_result_->find(device_type); CHECK(it != ccl_comm_mgr_reg_result_->end()); return it->second(); } const std::vector& vaild_ccl_comm_mgr_device_types() const { return vaild_ccl_comm_mgr_device_types_; } private: EagerCclCommMgrBuilder() { ccl_comm_mgr_reg_result_.reset(new std::map); } std::unique_ptr> ccl_comm_mgr_reg_result_; std::vector vaild_ccl_comm_mgr_device_types_; }; #define REGISTER_CCL_COMM_MGR(device, Derived) \ COMMAND(EagerCclCommMgrBuilder::Get().RegisterEagerCclCommMgrType(device)) class UserKernelUnifiedCclCommInitRegistry final { public: struct Trigger { explicit Trigger(const std::string& key) { UserKernelUnifiedCclCommInitRegistry::Instance().Register(key); } }; static UserKernelUnifiedCclCommInitRegistry& Instance() { static UserKernelUnifiedCclCommInitRegistry reg; return reg; } OF_DISALLOW_COPY_AND_MOVE(UserKernelUnifiedCclCommInitRegistry); ~UserKernelUnifiedCclCommInitRegistry() = default; void Register(const std::string& key) { bool insert_success = reg_set_.insert(key).second; if (!insert_success) { std::cerr << key << " was already registered in CclCommRegistry" << std::endl; abort(); } } bool IsRegistered(const std::string& key) const { return reg_set_.find(key) != reg_set_.end(); } private: UserKernelUnifiedCclCommInitRegistry() = default; std::set reg_set_; }; static const std::string kSystemCclOpPrefix = "sys_op_"; #define REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(op_type_name) \ static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) = \ ::oneflow::UserKernelUnifiedCclCommInitRegistry::Trigger(op_type_name) #define REGISTER_SYSTEM_OP_KERNEL_UNIFIED_CCL_COMM_INIT(op_type_case) \ static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) = \ ::oneflow::UserKernelUnifiedCclCommInitRegistry::Trigger(::oneflow::kSystemCclOpPrefix \ + std::to_string(op_type_case)) } // namespace oneflow #endif // 
ONEFLOW_CORE_JOB_EAGER_CCL_COMM_MANAGER_H_ ================================================ FILE: oneflow/core/job/eager_nccl_comm_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/vm/vm_util.h" #ifdef WITH_CUDA namespace oneflow { namespace { std::string GetNcclUniqueIdRpcKey(const std::vector>& sorted_devices) { std::ostringstream oss; oss << "eager_nccl_unique_id_rpc_key"; for (const auto& pair : sorted_devices) { oss << "," << pair.first << ":" << pair.second; } return oss.str(); } std::string NcclUniqueId2String(const ncclUniqueId& id) { std::stringstream ss; for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) { ss << std::hex << std::setfill('0') << std::setw(2) << static_cast(id.internal[i]); } return ss.str(); } bool CompareDeviceSetPair(const std::pair& a, const std::pair& b) { if (a.first == b.first) { return a.second < b.second; } else { return a.first < b.first; } } void CreateNcclComm(ncclComm_t* comm, const int dev, const std::string& key, const std::vector>& device_vec) { ncclUniqueId nccl_unique_id{}; int64_t machine = GlobalProcessCtx::Rank(); std::pair this_device(machine, dev); auto it = std::find(device_vec.cbegin(), device_vec.cend(), this_device); CHECK(it != device_vec.end()); int rank = std::distance(device_vec.cbegin(), it); if (rank == 0) { OF_NCCL_CHECK(ncclGetUniqueId(&nccl_unique_id)); Singleton::Get()->PushKV( key, std::string(nccl_unique_id.internal, NCCL_UNIQUE_ID_BYTES)); } else { Singleton::Get()->PullKV(key, [&nccl_unique_id](const std::string& val) { memcpy(nccl_unique_id.internal, val.data(), NCCL_UNIQUE_ID_BYTES); }); } VLOG(2) << " EagerNcclCommMgr::ncclCommInitRank device_vec.size() = " << device_vec.size() << ", nccl_unique_id = " << NcclUniqueId2String(nccl_unique_id) << ", rank = " << rank << ", key = {" << key << "}\n"; OF_NCCL_CHECK(ncclCommInitRank(comm, device_vec.size(), nccl_unique_id, rank)); VLOG(2) << " EagerNcclCommMgr::ncclCommInitRank succeed device_vec.size() = " << device_vec.size() << ", nccl_unique_id = " << NcclUniqueId2String(nccl_unique_id) << ", rank = " << rank << ", key = {" << key << "}\n"; } bool NeedUnifiedNcclCommInit(const OperatorConf& op_conf) { if (op_conf.has_user_conf()) { return UserKernelUnifiedNcclCommInitRegistry::Instance().IsRegistered( op_conf.user_conf().op_type_name()); } else { // Please check the .h file for hard-coding of the name return UserKernelUnifiedNcclCommInitRegistry::Instance().IsRegistered( kSystemOpPrefix + std::to_string(op_conf.op_type_case())); } } } // namespace const std::string EagerNcclCommMgr::kDefaultStreamName = "DEFAULT"; 
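// Usage sketch (hypothetical helper; the name and function below are not part of the original
// file): how callers typically obtain a communicator from EagerNcclCommMgr. The device set is
// the full list of (machine_id, device_id) pairs of a placement; CreateNcclComm above lets
// rank 0 publish the ncclUniqueId through the ctrl key-value store, every other rank pulls it,
// and each rank passes its index in the sorted device list to ncclCommInitRank.
inline ncclComm_t DemoGetCommForLocalDevices(EagerNcclCommMgr* mgr, int num_local_devices) {
  std::set<std::pair<int64_t, int64_t>> device_set;
  const int64_t machine_id = GlobalProcessCtx::Rank();
  for (int dev = 0; dev < num_local_devices; ++dev) { device_set.emplace(machine_id, dev); }
  // Cached per device set; the first call on each process creates the comm lazily.
  return mgr->GetCommForDevice(device_set);
}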
EagerNcclCommMgr::~EagerNcclCommMgr() { for (auto& device_set7device_id2comm : device_set2device_id2comm_) { for (auto& device_id7comm : device_set7device_id2comm.second) { OF_NCCL_CHECK(ncclCommDestroy(device_id7comm.second)); } } for (auto& pair : device7stream2device_id2comm_) { for (auto& device_id7comm : pair.second) { OF_NCCL_CHECK(ncclCommDestroy(device_id7comm.second)); } } } ncclComm_t EagerNcclCommMgr::GetCommForDevice( const std::set>& device_set) { int dev; OF_CUDA_CHECK(cudaGetDevice(&dev)); { std::lock_guard lock(mutex_); auto it = device_set2device_id2comm_.find(device_set); if (it != device_set2device_id2comm_.end()) { return it->second.at(dev); } } std::vector> device_vec(device_set.cbegin(), device_set.cend()); std::sort(device_vec.begin(), device_vec.end(), CompareDeviceSetPair); ncclComm_t comm; std::string nccl_unique_id_rpc_key = GetNcclUniqueIdRpcKey(device_vec); CreateNcclComm(&comm, dev, nccl_unique_id_rpc_key, device_vec); { std::lock_guard lock(mutex_); device_set2device_id2comm_[device_set][dev] = comm; } return comm; } ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName( const std::set>& device_set, const std::string& stream_name) { int dev; OF_CUDA_CHECK(cudaGetDevice(&dev)); std::vector> device_vec(device_set.cbegin(), device_set.cend()); std::sort(device_vec.begin(), device_vec.end(), CompareDeviceSetPair); std::string key = GetNcclUniqueIdRpcKey(device_vec) + "-stream_name_hint:" + stream_name; { std::lock_guard lock(mutex_); auto it = device7stream2device_id2comm_.find(key); if (it != device7stream2device_id2comm_.end()) { return it->second.at(dev); } } ncclComm_t comm; CreateNcclComm(&comm, dev, key, device_vec); { std::lock_guard lock(mutex_); device7stream2device_id2comm_[key][dev] = comm; } return comm; } ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) { std::set> device_set; FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) { int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } ncclComm_t comm = GetCommForDevice(device_set); std::shared_ptr ncclCommAdapter = std::make_shared(comm); ccl::CclComm ccl_comm(ncclCommAdapter); return ccl_comm; } ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescAndStreamName( const ParallelDesc& parallel_desc, const std::string& stream_name) { std::set> device_set; FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) { int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name); std::shared_ptr ncclCommAdapter = std::make_shared(comm); ccl::CclComm ccl_comm(ncclCommAdapter); return ccl_comm; } ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescNdHierarchy( const ParallelDesc& parallel_desc, const std::string& stream_name, const int64_t this_parallel_id, const std::string& comm_key) { std::set> device_set; const Shape& hierarchy = *parallel_desc.hierarchy(); CHECK_LE(hierarchy.NumAxes(), 2); // 1D if (hierarchy.NumAxes() == 1) { // 1D hierarchy for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); int64_t device_id = 
CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } } else if (hierarchy.NumAxes() == 2) { // 2D hierarchy CHECK(comm_key == "SameDim0" || comm_key == "SameDim1"); if (comm_key == "SameDim0") { const int64_t num_groups = hierarchy.At(0); const int64_t group_size = hierarchy.At(1); CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num()); const int64_t this_group_begin_parallel_id = this_parallel_id / group_size * group_size; CHECK_EQ(this_group_begin_parallel_id % group_size, 0); CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc.parallel_num()); for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { const int64_t parallel_id = this_group_begin_parallel_id + id_in_group; const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } } else if (comm_key == "SameDim1") { const int64_t group_size = hierarchy.At(0); const int64_t num_groups = hierarchy.At(1); CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num()); const int64_t this_group_begin_parallel_id = this_parallel_id % num_groups; CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups, parallel_desc.parallel_num()); for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups); const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } } else { UNIMPLEMENTED(); } } ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name); std::shared_ptr ncclCommAdapter = std::make_shared(comm); ccl::CclComm ccl_comm(ncclCommAdapter); return ccl_comm; } void EagerNcclCommMgr::CreateCommFromPlan(const Plan& plan) { const int64_t rank = GlobalProcessCtx::Rank(); const int64_t dev = GlobalProcessCtx::LocalRank(); std::map>> nccl_comm_key2devices; for (const auto& task_proto : plan.task()) { if (task_proto.machine_id() != rank) { continue; } if (task_proto.exec_sequence().exec_node_size() != 1) { continue; } const auto& kernel_conf = task_proto.exec_sequence().exec_node(0).kernel_conf(); const OpAttribute* op_attr = nullptr; if (kernel_conf.has_op_attribute()) { op_attr = &kernel_conf.op_attribute(); } else if (kernel_conf.has_op_attribute_ref()) { const auto& ref_name = kernel_conf.op_attribute_ref(); op_attr = &plan.job_id2op_attribute_ref_table() .at(task_proto.job_id()) .op_name2op_attribute() .at(ref_name); } else { continue; } const auto& op_conf = op_attr->op_conf(); if (!NeedUnifiedNcclCommInit(op_conf)) { continue; } if (!op_attr->has_parallel_conf_signature()) { continue; } if (!op_attr->parallel_conf_signature().has_op_parallel_conf()) { continue; } std::vector> device_vec; ParallelDesc parallel_desc(op_attr->parallel_conf_signature().op_parallel_conf()); for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); device_vec.emplace_back(machine_id, device_id); } std::string stream_name = kDefaultStreamName; if (op_conf.has_stream_name_hint()) { stream_name = 
op_conf.stream_name_hint(); } std::string key = GetNcclUniqueIdRpcKey(device_vec) + "-stream_name_hint:" + stream_name; VLOG(3) << " EagerNcclCommMgr create nccl comm for " << op_conf.name() << ", rank = " << rank << ", dev = " << dev << ", key = {" << key << "}\n"; nccl_comm_key2devices.emplace(std::move(key), std::move(device_vec)); } if (nccl_comm_key2devices.size() == 0) { return; } CHECK_JUST(vm::CurrentRankSync()); CudaCurrentDeviceGuard guard(dev); for (const auto& pair : nccl_comm_key2devices) { const auto& key = pair.first; auto device_id2comm_it = device7stream2device_id2comm_.find(key); if (device_id2comm_it != device7stream2device_id2comm_.end()) { auto comm_it = device_id2comm_it->second.find(dev); if (comm_it != device_id2comm_it->second.end()) { continue; } } ncclComm_t comm; CreateNcclComm(&comm, dev, key, pair.second); device7stream2device_id2comm_[key][dev] = comm; } } REGISTER_CCL_COMM_MGR(DeviceType::kCUDA, EagerNcclCommMgr); } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/job/eager_nccl_comm_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_EAGER_NCCL_COMM_MANAGER_H_ #define ONEFLOW_CORE_JOB_EAGER_NCCL_COMM_MANAGER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/eager_ccl_comm_manager.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace ccl { class NcclCommAdapter : public CommBase { public: explicit NcclCommAdapter(ncclComm_t comm) : comm_(comm) {} void* getComm() const override { return const_cast(static_cast(&comm_)); } private: ncclComm_t comm_; }; } // namespace ccl class EagerNcclCommMgr final : public EagerCclCommMgr { public: static const std::string kDefaultStreamName; OF_DISALLOW_COPY_AND_MOVE(EagerNcclCommMgr); ~EagerNcclCommMgr() override; ncclComm_t GetCommForDevice(const std::set>& device_set); ncclComm_t GetCommForDeviceAndStreamName(const std::set>& device_set, const std::string& stream_name); ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) override; ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc, const std::string& stream_name) override; ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc, const std::string& stream_name, const int64_t this_parallel_id, const std::string& comm_key) override; void CreateCommFromPlan(const Plan& plan) override; bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; } void SetAsyncLaunchCclLogicalKernel(bool val) override { async_launch_nccl_logical_kernel_ = val; } private: friend class EagerCclCommMgrBuilder; // NOTE(chengcheng): default async launch nccl logical kernel is true for better performence. 
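// Cache layout for the two comm maps declared below (implementation in the .cpp above):
// device_set2device_id2comm_ maps a set of (machine_id, device_id) pairs to one ncclComm_t per
// local CUDA device, shared by every op on that placement; device7stream2device_id2comm_ is
// keyed by the string "<nccl-unique-id-rpc-key>-stream_name_hint:<stream_name>", so ops with
// different stream name hints on the same placement get independent communicators. Both caches
// are guarded by mutex_, and all comms are destroyed in ~EagerNcclCommMgr().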
EagerNcclCommMgr() : EagerCclCommMgr(), async_launch_nccl_logical_kernel_(true) {} std::map>, HashMap> device_set2device_id2comm_; std::map> device7stream2device_id2comm_; std::mutex mutex_; bool async_launch_nccl_logical_kernel_; }; class UserKernelUnifiedNcclCommInitRegistry final { public: struct Trigger { explicit Trigger(const std::string& key) { UserKernelUnifiedNcclCommInitRegistry::Instance().Register(key); } }; static UserKernelUnifiedNcclCommInitRegistry& Instance() { static UserKernelUnifiedNcclCommInitRegistry reg; return reg; } OF_DISALLOW_COPY_AND_MOVE(UserKernelUnifiedNcclCommInitRegistry); ~UserKernelUnifiedNcclCommInitRegistry() = default; void Register(const std::string& key) { bool insert_success = reg_set_.insert(key).second; if (!insert_success) { std::cerr << key << " was already registered in NcclCommRegistry" << std::endl; abort(); } } bool IsRegistered(const std::string& key) const { return reg_set_.find(key) != reg_set_.end(); } private: UserKernelUnifiedNcclCommInitRegistry() = default; std::set reg_set_; }; static const std::string kSystemOpPrefix = "sys_op_"; } // namespace oneflow #define REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT(op_type_name) \ static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) = \ ::oneflow::UserKernelUnifiedNcclCommInitRegistry::Trigger(op_type_name) #define REGISTER_SYSTEM_OP_KERNEL_UNIFIED_NCCL_COMM_INIT(op_type_case) \ static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) = \ ::oneflow::UserKernelUnifiedNcclCommInitRegistry::Trigger(::oneflow::kSystemOpPrefix \ + std::to_string(op_type_case)) #endif // WITH_CUDA #endif // ONEFLOW_CORE_JOB_EAGER_NCCL_COMM_MANAGER_H_ ================================================ FILE: oneflow/core/job/env.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/control/ctrl_bootstrap.proto"; message Machine { required int64 id = 1; required string addr = 2; // domain name or ip optional int32 ctrl_port_agent = 3 [default = -1]; optional int32 data_port_agent = 4 [default = -1]; } message CppLoggingConf { optional string log_dir = 1 [default = "./log"]; optional int32 logtostderr = 2 [default = 1]; optional int32 logbuflevel = 3 [default = -1]; optional int32 minloglevel = 4 [default = 1]; } message EnvProto { repeated Machine machine = 1; required int32 ctrl_port = 2; optional int32 data_port = 3 [default = -1]; optional CppLoggingConf cpp_logging_conf = 4; optional BootstrapConf ctrl_bootstrap_conf = 5; } ================================================ FILE: oneflow/core/job/env_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/env_desc.h" #include "oneflow/core/job/global_for.h" namespace oneflow { const BootstrapConf& EnvDesc::bootstrap_conf() const { CHECK(has_ctrl_bootstrap_conf()); return env_proto_.ctrl_bootstrap_conf(); } int32_t EnvDesc::bootstrap_conf_ctrl_port() const { CHECK(has_bootstrap_conf_ctrl_port()); return env_proto_.ctrl_bootstrap_conf().ctrl_port(); } size_t EnvDesc::TotalMachineNum() const { if (env_proto_.has_ctrl_bootstrap_conf()) { return env_proto_.ctrl_bootstrap_conf().world_size(); } else { return env_proto_.machine().size(); } } int64_t EnvDesc::GetMachineId(const std::string& addr) const { int64_t machine_id = -1; int64_t machine_num = env_proto_.machine_size(); FOR_RANGE(int64_t, i, 0, machine_num) { if (addr == env_proto_.machine(i).addr()) { machine_id = i; break; } } CHECK_GE(machine_id, 0); CHECK_LT(machine_id, machine_num); return machine_id; } } // namespace oneflow ================================================ FILE: oneflow/core/job/env_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_CLUSTER_DESC_H_ #define ONEFLOW_CORE_JOB_CLUSTER_DESC_H_ #include "oneflow/core/job/env.pb.h" #include "oneflow/core/common/util.h" namespace oneflow { class EnvDesc final { public: OF_DISALLOW_COPY_AND_MOVE(EnvDesc); explicit EnvDesc(const EnvProto& env_proto) : env_proto_(env_proto) {} ~EnvDesc() = default; const EnvProto& env_proto() const { return env_proto_; } const Machine& machine(int32_t idx) const { return env_proto_.machine(idx); } int32_t ctrl_port() const { return env_proto_.ctrl_port(); } int32_t data_port() const { return env_proto_.data_port(); } bool has_ctrl_bootstrap_conf() const { return env_proto_.has_ctrl_bootstrap_conf(); } bool has_bootstrap_conf_ctrl_port() const { return has_ctrl_bootstrap_conf() && env_proto_.ctrl_bootstrap_conf().has_ctrl_port(); } const BootstrapConf& bootstrap_conf() const; int32_t bootstrap_conf_ctrl_port() const; size_t TotalMachineNum() const; int64_t GetMachineId(const std::string& addr) const; private: EnvProto env_proto_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_CLUSTER_DESC_H_ ================================================ FILE: oneflow/core/job/env_global_objects_scope.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/vm/remat/allocator.h" #ifdef WITH_CUDA #include #endif // WITH_CUDA #include #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/control/ctrl_server.h" #include "oneflow/core/control/ctrl_bootstrap.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/persistence/file_system.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/vm/virtual_machine_scope.h" #include "oneflow/core/vm/remat/util.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/eager_ccl_comm_manager.h" #include "oneflow/core/device/cudnn_conv_util.h" #include "oneflow/core/rpc/include/manager.h" #include "oneflow/core/transport/transport.h" #include "oneflow/core/hardware/node_device_descriptor_manager.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/operator/op_node_signature.pb.h" #include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/comm_network/epoll/epoll_comm_network.h" #include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h" #include "oneflow/core/kernel/chain_kernel_observer.h" #include "oneflow/core/kernel/sync_check_kernel_observer.h" #include "oneflow/core/kernel/blob_access_checker_kernel_observer.h" #include "oneflow/core/kernel/profiler_kernel_observer.h" #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/core/vm/remat/env.h" #ifdef WITH_RDMA #include "oneflow/core/platform/include/ibv.h" #include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h" #endif // WITH_RDMA #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/ep/cpu/cpu_device_manager.h" #include "oneflow/core/common/env_var/debug_mode.h" namespace oneflow { namespace { std::string LogDir(const std::string& log_dir) { char hostname[255]; CHECK_EQ(gethostname(hostname, sizeof(hostname)), 0); std::string v = JoinPath(log_dir, std::string(hostname)); return v; } void InitLogging(const CppLoggingConf& logging_conf) { FLAGS_log_dir = LogDir(logging_conf.log_dir()); FLAGS_logtostderr = logging_conf.logtostderr(); FLAGS_logbuflevel = logging_conf.logbuflevel(); FLAGS_minloglevel = logging_conf.minloglevel(); FLAGS_stderrthreshold = 1; // 1=WARNING google::InitGoogleLogging("oneflow"); if (IsInDebugMode()) { // record all level logs to file in debug mode FLAGS_logtostderr = 0; FLAGS_minloglevel = 0; // 0=INFO } if (!FLAGS_logtostderr) { LocalFS()->RecursivelyCreateDirIfNotExist(FLAGS_log_dir); } } int32_t GetDefaultCpuDeviceNum() { return std::thread::hardware_concurrency(); } Resource GetDefaultResource(const EnvProto& env_proto) { Resource resource; if (env_proto.has_ctrl_bootstrap_conf()) { resource.set_machine_num(GlobalProcessCtx::NodeSize()); } else { resource.set_machine_num(env_proto.machine_size()); } resource.set_cpu_device_num(GetDefaultCpuDeviceNum()); return resource; } void SetCpuDeviceManagerNumThreads() { ep::CpuDeviceManager* cpu_device_manager = dynamic_cast( Singleton::Get()->GetDeviceManager(DeviceType::kCPU)); constexpr size_t kDefaultUsedNumThreads = 2; int64_t cpu_logic_core = std::thread::hardware_concurrency(); int64_t 
default_num_threads = (cpu_logic_core / GlobalProcessCtx::NumOfProcessPerNode()) - kDefaultUsedNumThreads; int64_t num_threads = ParseIntegerFromEnv("OMP_NUM_THREADS", default_num_threads); cpu_device_manager->SetDeviceNumThreads(num_threads); } void ClearAllSymbol() { Singleton>::Get()->ClearAll(); Singleton>::Get()->ClearAll(); Singleton>::Get()->ClearAll(); Singleton>::Get()->ClearAll(); } #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) bool CommNetIBEnabled() { if (!ibv::IsAvailable()) { return false; } const auto* node_manager = Singleton::Get(); if (node_manager == nullptr) { return false; } for (int64_t rank = 0; rank < GlobalProcessCtx::WorldSize(); ++rank) { const auto& node = node_manager->GetNodeDeviceDescriptor(rank); if (!node) { return false; } const auto& list = node->GetDeviceDescriptorList("net_ib"); if (!list) { return false; } if (list->DeviceCount() == 0) { return false; } } return true; } #endif // WITH_RDMA && OF_PLATFORM_POSIX } // namespace EnvGlobalObjectsScope::EnvGlobalObjectsScope(const std::string& env_proto_str) { EnvProto env_proto; CHECK(TxtString2PbMessage(env_proto_str, &env_proto)) << "failed to parse env_proto: " << env_proto_str; CHECK_JUST(Init(env_proto)); } EnvGlobalObjectsScope::EnvGlobalObjectsScope(const EnvProto& env_proto) { CHECK_JUST(Init(env_proto)); } Maybe EnvGlobalObjectsScope::Init(const EnvProto& env_proto) { CHECK(Singleton::Get() == nullptr); Singleton::SetAllocated(this); InitLogging(env_proto.cpp_logging_conf()); Singleton::New(); Singleton::New(env_proto); Singleton::New(); // Avoid deadlock by using CHECK_JUST instead of JUST, because it may be blocked in // ~CtrlBootstrap. if ((env_proto.machine_size() == 1 && env_proto.has_ctrl_bootstrap_conf() == false) || (env_proto.has_ctrl_bootstrap_conf() && env_proto.ctrl_bootstrap_conf().world_size() == 1)) /*single process*/ { #ifdef RPC_BACKEND_LOCAL LOG(INFO) << "Using rpc backend: local"; Singleton::SetAllocated(new LocalRpcManager()); #else static_assert(false, "Requires rpc backend local to run oneflow in a single process"); #endif // RPC_BACKEND_LOCAL } else /*multi process, multi machine*/ { #ifdef RPC_BACKEND_GRPC LOG(INFO) << "Using rpc backend: gRPC"; Singleton::SetAllocated(new GrpcRpcManager()); #else UNIMPLEMENTED() << "To run distributed oneflow, you must enable at least one multi-node rpc " "backend by adding cmake argument, for instance: -DRPC_BACKEND=GRPC"; #endif // RPC_BACKEND_GRPC } CHECK_JUST(Singleton::Get()->CreateServer()); CHECK_JUST(Singleton::Get()->Bootstrap()); CHECK_JUST(Singleton::Get()->CreateClient()); Singleton::New(GetDefaultResource(env_proto), GlobalProcessCtx::NumOfProcessPerNode()); Singleton::New(GetDefaultResource(env_proto), GlobalProcessCtx::NumOfProcessPerNode()); Singleton::SetAllocated( new hardware::NodeDeviceDescriptorManager()); if (Singleton::Get()->enable_debug_mode()) { Singleton::Get()->DumpSummary("devices"); } Singleton::New(); Singleton::New(); Singleton::New(Singleton::Get()->ComputeThreadPoolSize()); SetCpuDeviceManagerNumThreads(); #ifdef WITH_CUDA Singleton::New(); Singleton::New(); Singleton::New(); #endif const auto& vaild_ccl_comm_mgr_device_types = EagerCclCommMgrBuilder::Get().vaild_ccl_comm_mgr_device_types(); CHECK_LE_OR_RETURN(vaild_ccl_comm_mgr_device_types.size(), 1) << "Only one kind of collective communication manager is supported at a time for " "now!"; if (!vaild_ccl_comm_mgr_device_types.empty() && !Singleton::Get()) { Singleton::SetAllocated(
EagerCclCommMgrBuilder::Get().NewCclCommMgr(vaild_ccl_comm_mgr_device_types.front())); } Singleton::New(Singleton::Get()->resource()); #ifdef __linux__ Singleton::New(); Singleton::New(); if (Singleton::Get()->process_ranks().size() > 1) { Singleton::SetAllocated(Singleton::Get()); } #endif // __linux__ { std::vector> kernel_observers; if (ParseBooleanFromEnv("ONEFLOW_DEBUG_KERNEL_SYNC_CHECK", false)) { LOG(WARNING) << "Environment variable ONEFLOW_DEBUG_KERNEL_SYNC_CHECK has been set to a truthy " "value, it will impact performance"; kernel_observers.emplace_back(new SyncCheckKernelObserver()); } if (!ParseBooleanFromEnv("ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER", true)) { kernel_observers.emplace_back(new BlobAccessCheckerKernelObserver()); } kernel_observers.emplace_back(new ProfilerKernelObserver()); Singleton::SetAllocated(new ChainKernelObserver(kernel_observers)); } TensorBufferPool::New(); return Maybe::Ok(); } EnvGlobalObjectsScope::~EnvGlobalObjectsScope() { VLOG(2) << "Try to close env global objects scope." << std::endl; OF_ENV_BARRIER(); if (is_normal_exit_.has_value() && !CHECK_JUST(is_normal_exit_)) { return; } TensorBufferPool::Delete(); Singleton::Delete(); #ifdef __linux__ if (Singleton::Get()->process_ranks().size() > 1) { if (Singleton::Get() != dynamic_cast(Singleton::Get())) { Singleton::Delete(); } } Singleton::Delete(); Singleton::Delete(); #endif // __linux__ Singleton::Delete(); #ifdef WITH_CUDA Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); #endif if (Singleton::Get() != nullptr) { Singleton::Delete(); } Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); if (Singleton::Get() != nullptr) { Singleton::Delete(); } Singleton::Delete(); Singleton::Delete(); CHECK_NOTNULL(Singleton::Get()); CHECK_NOTNULL(Singleton::Get()); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); ClearAllSymbol(); ClearAllBackwardPassScope(); if (Singleton::Get() != nullptr) { Singleton::SetAllocated(nullptr); } VLOG(2) << "Finish closing env global objects scope." << std::endl; google::ShutdownGoogleLogging(); } Maybe InitRDMA() { #ifdef __linux__ if (Singleton::Get()->process_ranks().size() > 1) { #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) if (CommNetIBEnabled()) { if (Singleton::Get() == nullptr) { Singleton::New(); Singleton::SetAllocated(Singleton::Get()); } else { LOG(INFO) << "Skip init RDMA because RDMA is already initialized!"; } } else { LOG(WARNING) << "Skip init RDMA because RDMA is unavailable!"; } #else LOG(WARNING) << "Skip init RDMA because RDMA is not compiled!"; #endif // WITH_RDMA && OF_PLATFORM_POSIX } else { LOG(INFO) << "Skip init RDMA because only one process in this group!"; } #endif // __linux__ return Maybe::Ok(); } Maybe RDMAIsInitialized() { #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) return Singleton::Get() != nullptr; #else return false; #endif // WITH_RDMA && OF_PLATFORM_POSIX } Maybe DestoryRDMA() { #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) if (JUST(RDMAIsInitialized())) { CHECK_NOTNULL(Singleton::Get()); CHECK_NOTNULL(Singleton::Get()); Singleton::Delete(); if (Singleton::Get()) { Singleton::SetAllocated(Singleton::Get()); } } #endif // WITH_RDMA && OF_PLATFORM_POSIX return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/env_global_objects_scope.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_CLUSTER_OBJECTS_SCOPE_H_ #define ONEFLOW_CORE_JOB_CLUSTER_OBJECTS_SCOPE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/framework/device.h" namespace oneflow { class ParallelDesc; class EnvGlobalObjectsScope final { public: OF_DISALLOW_COPY_AND_MOVE(EnvGlobalObjectsScope); explicit EnvGlobalObjectsScope(const std::string& env_proto_str); explicit EnvGlobalObjectsScope(const EnvProto& env_proto); ~EnvGlobalObjectsScope(); Maybe init_is_normal_exit(bool is_normal_exit) { CHECK_OR_RETURN(!is_normal_exit_.has_value()); is_normal_exit_ = is_normal_exit; return Maybe::Ok(); } private: Maybe Init(const EnvProto& env_proto); private: Optional is_normal_exit_; }; Maybe InitRDMA(); Maybe RDMAIsInitialized(); Maybe DestoryRDMA(); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_CLUSTER_OBJECTS_SCOPE_H_ ================================================ FILE: oneflow/core/job/function_config_def.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/config_def.h" namespace oneflow {} ================================================ FILE: oneflow/core/job/global_for.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/error.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/util.h" namespace oneflow { COMMAND(Singleton, MultiClient>::SetAllocated(new Optional())); } // namespace oneflow ================================================ FILE: oneflow/core/job/global_for.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_GLOBAL_FOR_H_ #define ONEFLOW_CORE_JOB_GLOBAL_FOR_H_ #include "oneflow/core/common/singleton.h" namespace oneflow { class ForSession {}; class ForEnv {}; class MultiClient {}; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_GLOBAL_FOR_H_ ================================================ FILE: oneflow/core/job/global_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/global_mode.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/framework/device.h" namespace oneflow { Symbol GetGlobalParallelDescFromDevice(const Optional>& device) { auto parallel_desc = GlobalMode::parallel_desc(); if (device.has_value()) { const auto& device_type = device.value_or(Symbol())->type(); if (parallel_desc->parallel_conf().device_tag() != device_type) { ParallelConf parallel_conf = parallel_desc->parallel_conf(); parallel_conf.set_device_tag(device_type); parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); } } return parallel_desc; } /* static */ bool* GlobalMode::get_mode_ptr() { thread_local bool mode = false; return &mode; } /* static */ bool GlobalMode::is_enabled() { return *get_mode_ptr(); } /* static */ void GlobalMode::set_enabled(bool enabled) { *get_mode_ptr() = enabled; } /* static */ Symbol* GlobalMode::get_nd_sbp_ptr() { thread_local Symbol nd_sbp; return &nd_sbp; } /* static */ Symbol GlobalMode::nd_sbp() { return *get_nd_sbp_ptr(); } /* static */ void GlobalMode::set_nd_sbp(Symbol nd_sbp) { *get_nd_sbp_ptr() = nd_sbp; } /* static */ Symbol* GlobalMode::get_parallel_desc_ptr() { thread_local Symbol parallel_desc; return ¶llel_desc; } /* static */ Symbol GlobalMode::parallel_desc() { return *get_parallel_desc_ptr(); } /* static */ void GlobalMode::set_parallel_desc(Symbol parallel_desc) { *get_parallel_desc_ptr() = parallel_desc; } } // namespace oneflow ================================================ FILE: oneflow/core/job/global_mode.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_GLOBAL_MODE_H_ #define ONEFLOW_CORE_JOB_GLOBAL_MODE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/sbp_parallel.pb.h" namespace oneflow { Symbol GetGlobalParallelDescFromDevice(const Optional>& device); class GlobalMode { public: OF_DISALLOW_COPY_AND_MOVE(GlobalMode); GlobalMode() = default; ~GlobalMode() = default; static bool is_enabled(); static Symbol nd_sbp(); static Symbol parallel_desc(); class Guard { public: explicit Guard(bool enabled) : prev_mode_(GlobalMode::is_enabled()), prev_nd_sbp_(GlobalMode::nd_sbp()), prev_parallel_desc_(GlobalMode::parallel_desc()) { CHECK(!enabled); GlobalMode::set_enabled(enabled); } explicit Guard(bool enabled, Symbol nd_sbp, Symbol parallel_desc) : prev_mode_(GlobalMode::is_enabled()), prev_nd_sbp_(GlobalMode::nd_sbp()), prev_parallel_desc_(GlobalMode::parallel_desc()) { GlobalMode::set_enabled(enabled); if (enabled) { GlobalMode::set_nd_sbp(nd_sbp); GlobalMode::set_parallel_desc(parallel_desc); } } ~Guard() { GlobalMode::set_enabled(prev_mode_); GlobalMode::set_nd_sbp(prev_nd_sbp_); GlobalMode::set_parallel_desc(prev_parallel_desc_); } private: bool prev_mode_; Symbol prev_nd_sbp_; Symbol prev_parallel_desc_; }; private: static bool* get_mode_ptr(); static Symbol* get_nd_sbp_ptr(); static Symbol* get_parallel_desc_ptr(); static void set_enabled(bool enabled); static void set_nd_sbp(Symbol nd_sbp); static void set_parallel_desc(Symbol parallel_desc); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_GLOBAL_MODE_H_ ================================================ FILE: oneflow/core/job/graph_scope_vars.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/graph_scope_vars.h" #include namespace oneflow { namespace { std::vector* GetPythonPathsToBeFilteredForDebuggingVar() { static thread_local std::vector filtered_paths; return &filtered_paths; } std::vector* GetPythonPathsToBeKeptForDebuggingVar() { static thread_local std::vector kept_paths; return &kept_paths; } bool* GetGraphVerboseStepLr() { static thread_local bool graph_verbose_step_lr = false; return &graph_verbose_step_lr; } int32_t* GetGraphDebugMaxPyStackDepthVar() { static thread_local int32_t graph_debug_max_py_stack_depth = 2; return &graph_debug_max_py_stack_depth; } bool* GetGraphDebugModeFlag() { static thread_local bool graph_debug_mode_flag = false; return &graph_debug_mode_flag; } bool* GetGraphDebugOnlyUserPyStackFlag() { static thread_local bool graph_debug_only_user_py_stack = true; return &graph_debug_only_user_py_stack; } } // namespace bool IsOpenGraphVerboseStepLr() { auto* graph_verbose_step_lr = GetGraphVerboseStepLr(); bool is_graph_verbose_step_lr = *graph_verbose_step_lr; return is_graph_verbose_step_lr; } void SetGraphVerboseStepLr(bool verbose) { auto* graph_verbose_step_lr = GetGraphVerboseStepLr(); *graph_verbose_step_lr = verbose; } void InitPythonPathsToBeKeptAndFilteredForDebugging(const std::string& python_base_dir) { std::vector* kept_paths = GetPythonPathsToBeKeptForDebuggingVar(); kept_paths->clear(); kept_paths->push_back(python_base_dir + "/test"); kept_paths->push_back(python_base_dir + "/nn/modules"); std::vector* filtered_paths = GetPythonPathsToBeFilteredForDebuggingVar(); filtered_paths->clear(); filtered_paths->push_back(python_base_dir); } const std::vector& GetPythonPathsToBeFilteredForDebugging() { return *GetPythonPathsToBeFilteredForDebuggingVar(); } const std::vector& GetPythonPathsToBeKeptForDebugging() { return *GetPythonPathsToBeKeptForDebuggingVar(); } void SetGraphDebugMaxPyStackDepth(int32_t depth) { *GetGraphDebugMaxPyStackDepthVar() = depth; } int32_t GetGraphDebugMaxPyStackDepth() { return *GetGraphDebugMaxPyStackDepthVar(); } void SetGraphDebugMode(bool mode) { *GetGraphDebugModeFlag() = mode; } bool GetGraphDebugMode() { return *GetGraphDebugModeFlag(); } void SetGraphDebugOnlyUserPyStack(bool flag) { *GetGraphDebugOnlyUserPyStackFlag() = flag; } bool GetGraphDebugOnlyUserPyStack() { return *GetGraphDebugOnlyUserPyStackFlag(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/graph_scope_vars.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_GRAPH_SCOPE_VARS_H_ #define ONEFLOW_CORE_JOB_GRAPH_SCOPE_VARS_H_ #include #include #include namespace oneflow { bool IsOpenGraphVerboseStepLr(); void SetGraphVerboseStepLr(bool verbose); void SetGraphDebugMaxPyStackDepth(int32_t depth); int32_t GetGraphDebugMaxPyStackDepth(); void SetGraphDebugMode(bool mode); bool GetGraphDebugMode(); void SetGraphDebugOnlyUserPyStack(bool flag); bool GetGraphDebugOnlyUserPyStack(); void InitPythonPathsToBeKeptAndFilteredForDebugging(const std::string& python_base_dir); const std::vector& GetPythonPathsToBeFilteredForDebugging(); const std::vector& GetPythonPathsToBeKeptForDebugging(); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_GRAPH_SCOPE_VARS_H_ ================================================ FILE: oneflow/core/job/id_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/id_manager.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/job/id_state.h" namespace oneflow { namespace { constexpr static int64_t kRankLimitShift = 16; constexpr static int64_t kIdLimitShift = (sizeof(int64_t) * 8 - kRankLimitShift); static_assert(kIdLimitShift > 0, ""); int64_t AddCurrentRankOffset(int64_t x) { CHECK_GE(x, 0); CHECK_LT(x, (static_cast(1) << kIdLimitShift)); return (static_cast(GlobalProcessCtx::Rank()) << kIdLimitShift) + x; } } // namespace IDMgr::IDMgr() { regst_desc_id_count_ = 0; mem_block_id_count_ = 0; chunk_id_count_ = 0; CHECK_LE(GlobalProcessCtx::WorldSize(), (static_cast(1) << kRankLimitShift)); } int64_t IDMgr::NewRegstDescId() { return AddCurrentRankOffset(regst_desc_id_count_++); } int64_t IDMgr::NewMemBlockId() { return AddCurrentRankOffset(mem_block_id_count_++); } int64_t IDMgr::NewChunkId() { return AddCurrentRankOffset(chunk_id_count_++); } void IDMgr::SaveIdAndTaskIndex(IdState* id_state) { id_state->regst_desc_id_state_ = regst_desc_id_count_; id_state->mem_block_id_state_ = mem_block_id_count_; id_state->chunk_id_state_ = chunk_id_count_; task_id_gen_.GetTaskIndex(&id_state->task_index_state_); } void IDMgr::TryUpdateIdAndTaskIndex(const IdState* id_state) { regst_desc_id_count_ = std::max(regst_desc_id_count_.load(std::memory_order_relaxed), id_state->regst_desc_id_state_); mem_block_id_count_ = std::max(mem_block_id_count_.load(std::memory_order_relaxed), id_state->mem_block_id_state_); chunk_id_count_ = std::max(chunk_id_count_.load(std::memory_order_relaxed), id_state->chunk_id_state_); task_id_gen_.TryUpdateTaskIndex(id_state->task_index_state_); } } // namespace oneflow ================================================ FILE: oneflow/core/job/id_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_ID_MANAGER_H_ #define ONEFLOW_CORE_JOB_ID_MANAGER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/id_state.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/graph/task_id_generator.h" namespace oneflow { class IDMgr final { public: OF_DISALLOW_COPY_AND_MOVE(IDMgr); ~IDMgr() = default; int64_t NewRegstDescId(); int64_t NewMemBlockId(); int64_t NewChunkId(); TaskIdGenerator* GetTaskIdGenerator() { return &task_id_gen_; } void SaveIdAndTaskIndex(IdState* id_state); void TryUpdateIdAndTaskIndex(const IdState* id_state); private: friend class Singleton; IDMgr(); std::atomic regst_desc_id_count_; std::atomic mem_block_id_count_; std::atomic chunk_id_count_; TaskIdGenerator task_id_gen_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_ID_MANAGER_H_ ================================================ FILE: oneflow/core/job/id_manager_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/global_for.h" namespace oneflow { namespace { static const int64_t machine_id_shl = 11 + 21 + 21; static const int64_t thread_id_shl = 21 + 21; static const int64_t local_work_stream_shl = 21; EnvProto GetEnvProto() { EnvProto ret; for (size_t i = 0; i < 10; ++i) { auto* machine = ret.add_machine(); machine->set_id(i); machine->set_addr("192.168.1." + std::to_string(i)); } ret.set_ctrl_port(9527); return ret; } Resource GetResource() { Resource ret; ret.set_machine_num(10); ret.set_cpu_device_num(5); ret.set_comm_net_worker_num(4); return ret; } void New() { Singleton::New(GetEnvProto()); Singleton::New(); Singleton::Get()->mutable_ctrl_addr()->Add(); Singleton::Get()->set_rank(0); Singleton::Get()->set_node_size(1); Singleton::New(GetResource(), GlobalProcessCtx::NumOfProcessPerNode()); Singleton::New(); } void Delete() { Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); Singleton::Delete(); } } // namespace TEST(IDMgr, compile_regst_desc_id) { New(); ASSERT_EQ(Singleton::Get()->NewRegstDescId(), 0); ASSERT_EQ(Singleton::Get()->NewRegstDescId(), 1); ASSERT_EQ(Singleton::Get()->NewRegstDescId(), 2); Delete(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/id_state.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_ID_STATE_H_ #define ONEFLOW_CORE_JOB_ID_STATE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/device/device_id.h" #include "oneflow/core/graph/stream_id.h" #include "oneflow/core/graph/task_id.h" namespace oneflow { class IdState { public: int64_t regst_desc_id_state_{}; int64_t mem_block_id_state_{}; int64_t chunk_id_state_{}; int64_t job_id_state_{}; HashMap task_index_state_{}; HashMap stream_index_state_{}; }; } // namespace oneflow #endif ================================================ FILE: oneflow/core/job/initializer_conf.proto ================================================ syntax = "proto2"; package oneflow; message ConstantInitializerConf { optional float value = 1 [default = 0]; } message ConstantIntInitializerConf { optional int64 value = 1 [default = 0]; } message RandomNormalInitializerConf { optional float mean = 1 [default = 0]; optional float std = 2 [default = 1]; } //output[D_0 ... D_(axis - 1) i D_(axis + 1) ... D_n] = start + i * stride message RangeInitializerConf { optional double start = 1 [default = 0]; optional double stride = 2 [default = 1]; optional int64 axis = 3 [default = -1]; } message IntRangeInitializerConf { optional int64 start = 1 [default = 0]; optional int64 stride = 2 [default = 1]; optional int64 axis = 3 [default = -1]; } message EmptyInitializerConf { } message InitializerConf { oneof type { ConstantInitializerConf constant_conf = 1; ConstantIntInitializerConf constant_int_conf = 2; RandomNormalInitializerConf random_normal_conf = 3; RangeInitializerConf range_conf = 4; IntRangeInitializerConf int_range_conf = 5; EmptyInitializerConf empty_conf = 6; } } message InitializeWithSnapshotConf { required string path = 1; optional string key = 2; } ================================================ FILE: oneflow/core/job/inter_job_mem_sharing_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/inter_job_mem_sharing_util.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/register/runtime_register_desc.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/plan_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" namespace oneflow { namespace { void GetOpName2JobId2TaskProtos( Plan* plan, const HashSet& op_names, HashMap>>* op_name2job_id2task_protos) { for (int64_t i = 0; i < plan->task_size(); ++i) { TaskProto* task = plan->mutable_task(i); if (task->exec_sequence().exec_node_size() == 1) { const KernelConf& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf(); std::string op_name = PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name(); if (op_names.find(op_name) != op_names.end()) { CHECK(task->has_parallel_ctx()); (*op_name2job_id2task_protos)[op_name][task->job_id()].emplace_back(task); } } } for (auto& op2job_task_pair : *op_name2job_id2task_protos) { for (auto& job2task_pair : op2job_task_pair.second) { std::vector& task_protos = job2task_pair.second; std::sort(task_protos.begin(), task_protos.end(), [](const TaskProto* lhs, const TaskProto* rhs) { return lhs->parallel_ctx().parallel_id() < rhs->parallel_ctx().parallel_id(); }); } } } HashMap> GetInterfaceOpName2JobIds( const std::vector>& jobs) { HashMap> interface_op_name2job_ids; HashSet unique_op_name_check; FOR_RANGE(int64_t, i, 0, jobs.size()) { const auto& job = *jobs.at(i); for (const auto& op : job.net().op()) { if (IsInterfaceOpConf(op)) { CHECK(interface_op_name2job_ids[op.name()].emplace(i).second); unique_op_name_check.emplace(op.name()); } else { // interface ops shouldn't share op_name with other ops CHECK(unique_op_name_check.find(op.name()) == unique_op_name_check.end()); } } } return interface_op_name2job_ids; } std::vector> InitJobId2MutualExclusionJobIds( const std::vector>& jobs) { int64_t job_size = jobs.size(); std::vector> job_id2mutual_exclusion_ids(job_size); for (const auto& pair : GetInterfaceOpName2JobIds(jobs)) { for (int64_t first_id : pair.second) { for (int64_t second_id : pair.second) { if (first_id != second_id) { job_id2mutual_exclusion_ids[first_id].emplace(second_id); } } } } const InterJobReuseMemStrategy* strategy = Singleton::Get(); if (strategy->has_custom_parallelism()) { auto* job_name2job_id = Singleton::Get(); for (const auto& group : strategy->custom_parallelism().nonparallel_group()) { for (const std::string& first_name : group.job_name()) { for (const std::string& second_name : group.job_name()) { if (first_name != second_name) { CHECK(job_name2job_id->find(first_name) != job_name2job_id->end()); CHECK(job_name2job_id->find(second_name) != job_name2job_id->end()); int64_t first_id = (*job_name2job_id)[first_name]; int64_t second_id = (*job_name2job_id)[second_name]; job_id2mutual_exclusion_ids[first_id].emplace(second_id); } } } } } return job_id2mutual_exclusion_ids; } std::vector> GetMutualExclusionJobGroups( const std::vector>& jobs) { int64_t job_size = jobs.size(); std::vector> job_groups; job_groups.reserve(job_size); if (Singleton::Get()->has_reuse_mem_priority()) { job_groups.emplace_back(HashSet()); FOR_RANGE(int64_t, i, 0, job_size) { job_groups.front().emplace(i); } return job_groups; } // default using parallelism_priority strategy std::vector> job_id2mutual_exclusion_ids = InitJobId2MutualExclusionJobIds(jobs); std::vector> job_id2enable_parallel_ids(job_size); FOR_RANGE(int64_t, i, 0, job_size) { 
FOR_RANGE(int64_t, j, 0, job_size) { if (job_id2mutual_exclusion_ids[i].find(j) == job_id2mutual_exclusion_ids[i].end()) { job_id2enable_parallel_ids[i].emplace(j); } } } int64_t mem_share_group_num = 0; std::vector job_id2mem_share_group_id(job_size, -1); FOR_RANGE(int64_t, this_job_id, 0, job_size) { HashSet mem_share_group_id_used; for (int64_t enable_parallel_job_id : job_id2enable_parallel_ids[this_job_id]) { int64_t group_id = job_id2mem_share_group_id[enable_parallel_job_id]; if (group_id != -1) { mem_share_group_id_used.emplace(group_id); } } FOR_RANGE(int64_t, this_group_id, 0, mem_share_group_num) { if (mem_share_group_id_used.find(this_group_id) == mem_share_group_id_used.end()) { job_id2mem_share_group_id[this_job_id] = this_group_id; break; } } if (job_id2mem_share_group_id[this_job_id] == -1) { job_id2mem_share_group_id[this_job_id] = mem_share_group_num; ++mem_share_group_num; CHECK_LE(mem_share_group_num, job_size); } } job_groups.resize(mem_share_group_num); FOR_RANGE(int64_t, this_job_id, 0, job_size) { job_groups[job_id2mem_share_group_id[this_job_id]].emplace(this_job_id); } { HashSet job_id_unique_check; for (auto& job_group : job_groups) { for (int64_t job_id : job_group) { CHECK(job_id_unique_check.emplace(job_id).second); } } } return job_groups; } void MergeReusedChunk(HashMap* chunk_id2chunk, HashMap* mem_block_id2mem_block, const std::vector>& reuse_mem_job_groups) { // mzuid = memory zone unique id HashMap> job_id2mzuid2chunk_id; HashMap> chunk_id2mem_blocks; for (auto& pair : *mem_block_id2mem_block) { MemBlockProto* mem_block = pair.second; if (mem_block->enable_reuse_mem() == false) { CHECK(mem_block->has_chunk_id() == false); CHECK(mem_block->has_chunk_offset() == false); continue; } CHECK(mem_block->has_chunk_id() && mem_block->chunk_id() >= 0); CHECK(mem_block->has_chunk_offset() && mem_block->chunk_offset() >= 0); CHECK(chunk_id2mem_blocks[mem_block->chunk_id()].insert(mem_block).second); } // merge chunk and delete useless chunk for (const auto& pair : *chunk_id2chunk) { const ChunkProto& chunk = pair.second; const MemoryCase& mem_case = chunk.mem_case(); // NOTE(zwx): do not reuse mem on cpu if (memory::IsHostMem(mem_case)) { continue; } int64_t mzuid = memory::GetUniqueMemCaseId(chunk.machine_id(), mem_case); CHECK_EQ(chunk.job_id_size(), 1); CHECK(job_id2mzuid2chunk_id[chunk.job_id(0)].emplace(mzuid, chunk.chunk_id()).second); } auto MergeMemChunkIdR2L = [&](int64_t left_chunk_id, int64_t right_chunk_id) { CHECK_NE(left_chunk_id, right_chunk_id); ChunkProto* chunk_l = &(chunk_id2chunk->at(left_chunk_id)); ChunkProto* chunk_r = &(chunk_id2chunk->at(right_chunk_id)); CHECK_GE(chunk_l->job_id_size(), 1); CHECK_EQ(chunk_r->job_id_size(), 1); CHECK_EQ(chunk_l->machine_id(), chunk_r->machine_id()); CHECK(chunk_l->mem_case() == chunk_r->mem_case()); CHECK_GT(chunk_l->mem_size(), 0); CHECK_GT(chunk_r->mem_size(), 0); for (MemBlockProto* mem_block : chunk_id2mem_blocks[right_chunk_id]) { CHECK_EQ(mem_block->machine_id(), chunk_l->machine_id()); CHECK(mem_block->mem_case() == chunk_l->mem_case()); mem_block->set_chunk_id(left_chunk_id); } chunk_l->add_job_id(chunk_r->job_id(0)); chunk_l->set_mem_size(std::max(chunk_l->mem_size(), chunk_r->mem_size())); chunk_id2chunk->erase(chunk_id2chunk->find(right_chunk_id)); }; auto InitMzuid2JobIdsInJobGroup = [&](const HashSet& job_group) -> HashMap> { HashMap> mzuid2job_ids; for (int64_t job_id : job_group) { for (const auto& pair : job_id2mzuid2chunk_id[job_id]) { CHECK(mzuid2job_ids[pair.first].emplace(job_id).second); 
} } return mzuid2job_ids; }; for (const HashSet& job_group : reuse_mem_job_groups) { if (job_group.size() <= 1) { continue; } HashMap> mzuid2job_ids = InitMzuid2JobIdsInJobGroup(job_group); for (const auto& pair : mzuid2job_ids) { const HashSet& job_ids = pair.second; if (job_ids.size() <= 1) { continue; } int64_t mzuid = pair.first; int64_t merged_job_id = *(job_ids.begin()); for (int64_t job_id : job_ids) { if (job_id == merged_job_id) { continue; } MergeMemChunkIdR2L(job_id2mzuid2chunk_id[merged_job_id].at(mzuid), job_id2mzuid2chunk_id[job_id].at(mzuid)); } } } } void MergeSharedMemBlockR2L(RegstDescProto* lhs, RegstDescProto* rhs, HashMap* mem_block_id2mem_block) { if (lhs == rhs) { return; } auto CheckValidAndGetMemBlock = [&](int64_t mem_block_id, int64_t mem_size, const MemoryCase& mem_case) { CHECK_NE(mem_block_id, -1); CHECK(mem_block_id2mem_block->find(mem_block_id) != mem_block_id2mem_block->end()); MemBlockProto* mem_block = &(mem_block_id2mem_block->at(mem_block_id)); CHECK(mem_block->enable_reuse_mem() == false); CHECK(mem_block->has_chunk_id() == false); CHECK(mem_block->has_chunk_offset() == false); CHECK_EQ(mem_block->mem_size(), mem_size); CHECK(mem_block->mem_case() == mem_case); return mem_block; }; auto MergeAndEraseMemBlock = [&](MemBlockProto* merged_block, MemBlockProto* erased_block) { CHECK_NE(merged_block->mem_block_id(), erased_block->mem_block_id()); CHECK_EQ(erased_block->job_id_size(), 1); CHECK_EQ(merged_block->mem_size(), erased_block->mem_size()); merged_block->add_job_id(erased_block->job_id(0)); CHECK_EQ(mem_block_id2mem_block->erase(erased_block->mem_block_id()), 1); }; int64_t merged_mem_block_id = lhs->mem_block_id(); int64_t erased_mem_block_id = rhs->mem_block_id(); CHECK(lhs->enable_reuse_mem() == false && rhs->enable_reuse_mem() == false); CHECK_EQ(lhs->mem_block_offset(), 0); CHECK_EQ(rhs->mem_block_offset(), 0); RtRegstDesc left_rt_regst(*lhs); RtRegstDesc right_rt_regst(*rhs); MemBlockProto* merged_mem_block = CheckValidAndGetMemBlock( merged_mem_block_id, left_rt_regst.TotalMainByteSize4AllRegst(), lhs->mem_case()); MemBlockProto* erased_mem_block = CheckValidAndGetMemBlock( erased_mem_block_id, right_rt_regst.TotalMainByteSize4AllRegst(), rhs->mem_case()); MergeAndEraseMemBlock(merged_mem_block, erased_mem_block); rhs->set_mem_block_id(merged_mem_block_id); int64_t separated_header_mem_size = left_rt_regst.TotalSeparatedHeaderByteSize4AllRegst(); if (separated_header_mem_size > 0) { CHECK_EQ(separated_header_mem_size, right_rt_regst.TotalSeparatedHeaderByteSize4AllRegst()); int64_t merged_header_id = lhs->separated_header_mem_block_id(); int64_t erased_header_id = rhs->separated_header_mem_block_id(); MemoryCase header_mem_case = memory::GetPinnedHostMemoryCase(lhs->mem_case()); MemBlockProto* merged_header_block = CheckValidAndGetMemBlock(merged_header_id, separated_header_mem_size, header_mem_case); MemBlockProto* erased_header_block = CheckValidAndGetMemBlock(erased_header_id, separated_header_mem_size, header_mem_case); MergeAndEraseMemBlock(merged_header_block, erased_header_block); rhs->set_separated_header_mem_block_id(merged_header_id); } } void MergeSharedInterfaceMemBlock(const std::vector>& jobs, Plan* plan, HashMap* mem_block_id2mem_block) { HashMap> interface_op_name2job_ids = GetInterfaceOpName2JobIds(jobs); HashSet interfaces_op_names; for (const auto& pair : interface_op_name2job_ids) { interfaces_op_names.insert(pair.first); } HashMap>> op_name2job_id2task_protos; GetOpName2JobId2TaskProtos(plan, interfaces_op_names, 
&op_name2job_id2task_protos); for (const auto& op_job_pair : interface_op_name2job_ids) { if (op_job_pair.second.size() <= 1) { continue; } const HashMap>& job_id2same_op_name_sorted_task_protos = op_name2job_id2task_protos.at(op_job_pair.first); const auto& first_vec = job_id2same_op_name_sorted_task_protos.begin()->second; std::vector common_mem_case_vec(first_vec.size()); std::transform( first_vec.cbegin(), first_vec.cend(), common_mem_case_vec.begin(), [](TaskProto* tp) { return PlanUtil::GetSoleProducedDataRegst(tp)->mem_case(); }); for (const auto& pair : job_id2same_op_name_sorted_task_protos) { const auto& task_protos = pair.second; CHECK_EQ(task_protos.size(), first_vec.size()); FOR_RANGE(int64_t, i, 0, first_vec.size()) { CHECK_EQ(task_protos.at(i)->machine_id(), first_vec.at(i)->machine_id()); RegstDescProto* first_regst_desc = PlanUtil::GetSoleProducedDataRegst(first_vec.at(i)); RegstDescProto* regst_desc = PlanUtil::GetSoleProducedDataRegst(task_protos.at(i)); MergeSharedMemBlockR2L(first_regst_desc, regst_desc, mem_block_id2mem_block); CHECK(memory::EqualsIgnorePinnedDevice(common_mem_case_vec.at(i), regst_desc->mem_case())); common_mem_case_vec[i] = regst_desc->mem_case(); } } for (const auto& pair : job_id2same_op_name_sorted_task_protos) { const auto& task_protos = pair.second; FOR_RANGE(int64_t, i, 0, task_protos.size()) { RegstDescProto* regst_desc = PlanUtil::GetSoleProducedDataRegst(task_protos.at(i)); *(regst_desc->mutable_mem_case()) = common_mem_case_vec.at(i); CHECK(mem_block_id2mem_block->find(regst_desc->mem_block_id()) != mem_block_id2mem_block->end()); *(mem_block_id2mem_block->at(regst_desc->mem_block_id()).mutable_mem_case()) = common_mem_case_vec.at(i); } } } } } // namespace void InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs( const std::vector>& jobs, Plan* plan) { if (jobs.size() == 1) { return; } HashMap mem_block_id2mem_block; for (const auto& mem_block : plan->block_chunk_list().mem_block()) { CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), mem_block).second); } plan->mutable_block_chunk_list()->clear_mem_block(); MergeSharedInterfaceMemBlock(jobs, plan, &mem_block_id2mem_block); for (const auto& pair : mem_block_id2mem_block) { *(plan->mutable_block_chunk_list()->add_mem_block()) = pair.second; } } void InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs( const std::vector>& user_jobs, Plan* plan) { if (user_jobs.size() == 1) { return; } std::vector> reuse_mem_job_groups = GetMutualExclusionJobGroups(user_jobs); HashMap chunk_id2chunk; HashMap mem_block_id2mem_block; for (const auto& chunk : plan->block_chunk_list().chunk()) { CHECK(chunk_id2chunk.emplace(chunk.chunk_id(), chunk).second); } plan->mutable_block_chunk_list()->clear_chunk(); for (MemBlockProto& mem_block : *plan->mutable_block_chunk_list()->mutable_mem_block()) { CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), &mem_block).second); } MergeReusedChunk(&chunk_id2chunk, &mem_block_id2mem_block, reuse_mem_job_groups); for (const auto& pair : chunk_id2chunk) { *(plan->mutable_block_chunk_list()->add_chunk()) = pair.second; } } } // namespace oneflow ================================================ FILE: oneflow/core/job/inter_job_mem_sharing_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_INTER_JOB_MEM_SHARING_UTIL_H_ #define ONEFLOW_CORE_JOB_INTER_JOB_MEM_SHARING_UTIL_H_ #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/plan.pb.h" namespace oneflow { struct InterJobMemSharingUtil { static void MergeMemSharedInterfaceMemBlockBetweenJobs( const std::vector>& jobs, Plan* plan); static void MergeMemReusedChunkBetweenUserJobs(const std::vector>& user_jobs, Plan* plan); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_INTER_JOB_MEM_SHARING_UTIL_H_ ================================================ FILE: oneflow/core/job/inter_user_job_info.proto ================================================ syntax = "proto2"; package oneflow; message InterUserJobInfo { map input_or_var_op_name2push_job_name = 1; map output_or_var_op_name2pull_job_name = 2; optional string global_model_init_job_name = 4; optional string global_model_load_job_name = 5; optional string global_model_save_job_name = 6; } ================================================ FILE: oneflow/core/job/intra_job_mem_sharing_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/intra_job_mem_sharing_util.h" #include #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/memory_share_strategy.h" #include "oneflow/core/register/runtime_register_desc.h" #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/graph/task_node.h" #include "oneflow/core/job/plan_util.h" namespace oneflow { enum MemAllocAlgoType { kMemSizeFirstAlgo = 0, kLifetimeFirstAlgo = 1, kTimeLineAlgo = 2, kMemVolumeFirstAlgo = 3, }; } // namespace oneflow namespace std { template<> struct hash<::oneflow::MemAllocAlgoType> { std::size_t operator()(const ::oneflow::MemAllocAlgoType& type) const { return std::hash()(static_cast(type)); } }; } // namespace std namespace oneflow { namespace { int64_t GenDeviceUniqueId(int64_t machine_id, int64_t device_id) { return (machine_id << 32) | device_id; } void TryConnectWithMemSafeGuardCtrlRegstDesc(TaskProto* src_task_proto, TaskProto* dst_task_proto) { RegstDescProto* ctrl_regst_desc = FindOrCreateProducedCtrlRegstDesc(src_task_proto, "out_ctrl_shared_mem_safe_guard"); int64_t dst_task_id = dst_task_proto->task_id(); if (!IsInRepeatedField(ctrl_regst_desc->consumer_task_id(), dst_task_id)) { ctrl_regst_desc->add_consumer_task_id(dst_task_id); int64_t ctrl_regst_desc_id = ctrl_regst_desc->regst_desc_id(); RegstDescIdSet* consumed_ctrl_regst_desc_ids = FindOrCreateConsumedCtrlRegstDescIdSet(dst_task_proto, "in_ctrl"); CHECK(!IsInRepeatedField(consumed_ctrl_regst_desc_ids->regst_desc_id(), ctrl_regst_desc_id)); consumed_ctrl_regst_desc_ids->add_regst_desc_id(ctrl_regst_desc_id); } } struct MemoryChain { std::vector sorted_tasks; HashSet mem_reused_regsts; int64_t total_mem_reused_size = 0; Shape time_shape; }; void InitMemoryChains(Plan* plan, HashMap>* device2chain2mem_chain, HashMap* mem_reused_regst2size) { for (int64_t i = 0; i < plan->task_size(); ++i) { TaskProto* task = plan->mutable_task(i); const StreamId stream_id = PlanUtil::GetStreamId(*task); int64_t machine_id = task->machine_id(); DeviceType device_type = stream_id.device_id().device_type(); // TODO(zwx): eliminate this special 'is cpu' determine if (device_type == DeviceType::kCPU) { continue; } if (!IsValidChainId(task->chain_id())) { continue; } int64_t device_id = stream_id.device_id().device_index(); int64_t device_unique_id = GenDeviceUniqueId(machine_id, device_id); MemoryChain* mem_chain = &((*device2chain2mem_chain)[device_unique_id][task->chain_id()]); mem_chain->sorted_tasks.emplace_back(task); for (auto& pair : *(task->mutable_produced_regst_desc())) { RegstDescProto* regst_desc = &pair.second; int64_t regst_total_main_size = RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst(); if (regst_desc->mem_case().device_type() == device_type && regst_desc->mem_case().device_id() == device_id && regst_desc->enable_reuse_mem() && regst_desc->register_num() == 1 && regst_desc->mem_block_id() == -1 && regst_desc->mem_block_offset() == -1 && regst_desc->regst_desc_type().has_data_regst_desc() && regst_total_main_size > 0) { CHECK(mem_chain->mem_reused_regsts.insert(regst_desc).second); (*mem_reused_regst2size)[regst_desc] = regst_total_main_size; mem_chain->total_mem_reused_size += regst_total_main_size; // 
for time shape in mem chain Shape regst_time_shape = Shape(regst_desc->regst_desc_type().data_regst_desc().time_shape()); if (!mem_chain->time_shape.is_initialized()) { mem_chain->time_shape = regst_time_shape; } else { CHECK(mem_chain->time_shape == regst_time_shape); } } } } for (auto& device_pair : *device2chain2mem_chain) { HashMap* chain2mem_chain = &device_pair.second; HashSet useless_chain_ids; for (auto& pair : *chain2mem_chain) { if (pair.second.mem_reused_regsts.empty()) { useless_chain_ids.insert(pair.first); } } for (int64_t chain_id : useless_chain_ids) { chain2mem_chain->erase(chain_id); } for (auto& pair : *chain2mem_chain) { MemoryChain* mem_chain = &pair.second; std::sort(mem_chain->sorted_tasks.begin(), mem_chain->sorted_tasks.end(), [&](const TaskProto* lhs, const TaskProto* rhs) { int64_t lhs_order_in_chain = lhs->order_in_chain(); int64_t rhs_order_in_chain = rhs->order_in_chain(); CHECK_NE(lhs_order_in_chain, rhs_order_in_chain); return lhs_order_in_chain < rhs_order_in_chain; }); } } } bool IsReachableToAnyOtherTask(const TaskProto* src_task, const HashSet& task_ids) { for (const auto& pair : src_task->produced_regst_desc()) { for (int64_t consumer : pair.second.consumer_task_id()) { if (task_ids.find(consumer) != task_ids.end()) { return true; } } } return false; } bool IsTaskConnectedL2R(const TaskProto* src, const TaskProto* dst) { for (const auto& pair : src->produced_regst_desc()) { for (int64_t consumer : pair.second.consumer_task_id()) { if (consumer == dst->task_id()) { return true; } } } return false; } void GenMemChainTasksAndRegsts( Plan* plan, HashMap>* mem_chain2sorted_tasks, HashMap>* mem_chain2mem_reused_regsts, HashMap>* mem_chain2regst_desc_id2reuse_regst_desc, HashMap* mem_reused_regst2size) { mem_chain2sorted_tasks->clear(); mem_chain2mem_reused_regsts->clear(); HashMap> device2chain2mem_chain; InitMemoryChains(plan, &device2chain2mem_chain, mem_reused_regst2size); int64_t mem_chain_id = 0; for (auto& device_chain_pair : device2chain2mem_chain) { if (device_chain_pair.second.empty()) { continue; } std::vector mem_chains; mem_chains.reserve(device_chain_pair.second.size()); for (auto& pair : device_chain_pair.second) { mem_chains.emplace_back(&pair.second); } for (MemoryChain* mem_chain : mem_chains) { std::vector* sorted_tasks = &((*mem_chain2sorted_tasks)[mem_chain_id]); CHECK(sorted_tasks->empty()); sorted_tasks->insert(sorted_tasks->end(), mem_chain->sorted_tasks.begin(), mem_chain->sorted_tasks.end()); std::vector* mem_reused_regsts = &((*mem_chain2mem_reused_regsts)[mem_chain_id]); CHECK(mem_reused_regsts->empty()); mem_reused_regsts->insert(mem_reused_regsts->end(), mem_chain->mem_reused_regsts.begin(), mem_chain->mem_reused_regsts.end()); // Merge HashSet mem_chain2mem_reused_regsts and HashMap regst_desc_id2reuse_regst_desc auto& regst_desc_id2reuse_regst_desc = (*mem_chain2regst_desc_id2reuse_regst_desc)[mem_chain_id]; CHECK(regst_desc_id2reuse_regst_desc.empty()); for (auto& mem_reused_regst : mem_chain->mem_reused_regsts) { regst_desc_id2reuse_regst_desc[mem_reused_regst->regst_desc_id()] = mem_reused_regst; } ++mem_chain_id; } } CHECK_EQ(mem_chain2sorted_tasks->size(), mem_chain2mem_reused_regsts->size()); // NOTE(chengcheng): add ctrl safe guard for each mem chain HashMap task_id2proto; for (int64_t i = 0; i < plan->task_size(); ++i) { TaskProto* task = plan->mutable_task(i); CHECK(task_id2proto.emplace(task->task_id(), task).second); } for (auto& pair : *mem_chain2sorted_tasks) { std::vector* sorted_tasks = &(pair.second); // 
NOTE(chengcheng): We CANNOT only add ctrl safe guard between first and last task, // because of the sorted_tasks may connected as a graph, has multi-tail tasks(sink task). const std::vector& mem_reused_regsts = mem_chain2mem_reused_regsts->at(pair.first); if (mem_reused_regsts.size() <= 1) { continue; } HashSet consumer_task_ids; for (const RegstDescProto* regst : mem_reused_regsts) { for (int64_t consumer : regst->consumer_task_id()) { consumer_task_ids.insert(consumer); } } std::vector sink_tasks; sink_tasks.reserve(consumer_task_ids.size()); for (int64_t src_task_id : consumer_task_ids) { auto it = task_id2proto.find(src_task_id); CHECK(it != task_id2proto.end()); if (!IsReachableToAnyOtherTask(it->second, consumer_task_ids)) { sink_tasks.emplace_back(it->second); } } TaskProto* first_task = sorted_tasks->front(); for (TaskProto* sink_task : sink_tasks) { CHECK(first_task != sink_task); if (!IsTaskConnectedL2R(first_task, sink_task)) { TryConnectWithMemSafeGuardCtrlRegstDesc(first_task, sink_task); } } } } void GenRegstAllocFreeTimeLineAndRegstLifetimes( const std::vector& sorted_tasks, const std::vector& mem_reused_regsts, const HashMap& regst_desc_id2reuse_regst_desc, const HashMap& mem_reused_regst2size, HashMap>* regst2lifetime, HashMap* consumer2inplaced_regst, size_t* peak_memory) { CHECK(consumer2inplaced_regst->empty()); std::vector> alloc_regsts_timeline(sorted_tasks.size()); std::vector> free_regsts_timeline(sorted_tasks.size()); HashMap task_id2sorted_id; for (int64_t i = 0; i < sorted_tasks.size(); ++i) { TaskProto* task = sorted_tasks.at(i); CHECK(task_id2sorted_id.emplace(task->task_id(), i).second); } auto FindLastFreeIndexInSortedTasks = [&](RegstDescProto* regst_desc) -> int64_t { // temp regst will set free index as same as alloc index int64_t free_index = task_id2sorted_id.at(regst_desc->producer_task_id()); for (int64_t consumer_task_id : regst_desc->consumer_task_id()) { // if consumer is not in this mem chain, set free index = last index int64_t this_sorted_index = sorted_tasks.size() - 1; if (task_id2sorted_id.find(consumer_task_id) != task_id2sorted_id.end()) { this_sorted_index = task_id2sorted_id.at(consumer_task_id); } free_index = std::max(free_index, this_sorted_index); } return free_index; }; auto TryFindFirstInplacedRegstDesc = [&](RegstDescProto* consumer_regst) -> RegstDescProto* { RegstDescProto* inplaced_regst = nullptr; while (consumer_regst->has_hint_inplace_consumed_regst_desc_id() && consumer_regst->hint_inplace_consumed_regst_desc_id() != -1) { const auto& iterator_hint_inplaced_regst = regst_desc_id2reuse_regst_desc.find( consumer_regst->hint_inplace_consumed_regst_desc_id()); if (iterator_hint_inplaced_regst != regst_desc_id2reuse_regst_desc.end()) { inplaced_regst = iterator_hint_inplaced_regst->second; consumer_regst = iterator_hint_inplaced_regst->second; } else { break; } } return inplaced_regst; }; HashMap regst_desc_id2free_index; for (RegstDescProto* regst_desc : mem_reused_regsts) { RegstDescProto* inplaced_regst_desc = TryFindFirstInplacedRegstDesc(regst_desc); if (inplaced_regst_desc != nullptr) { CHECK(consumer2inplaced_regst->emplace(regst_desc, inplaced_regst_desc).second); continue; } alloc_regsts_timeline[task_id2sorted_id.at(regst_desc->producer_task_id())].push_back( regst_desc); CHECK(regst_desc_id2free_index .emplace(regst_desc->regst_desc_id(), FindLastFreeIndexInSortedTasks(regst_desc)) .second); } // inplace extend regst free index for (auto pair : *consumer2inplaced_regst) { RegstDescProto* consumer_regst_desc = 
pair.first; int64_t inplaced_regst_desc_id = pair.second->regst_desc_id(); CHECK(regst_desc_id2free_index.find(inplaced_regst_desc_id) != regst_desc_id2free_index.end()); regst_desc_id2free_index.at(inplaced_regst_desc_id) = std::max(regst_desc_id2free_index.at(inplaced_regst_desc_id), FindLastFreeIndexInSortedTasks(consumer_regst_desc)); } for (const auto& pair : regst_desc_id2free_index) { free_regsts_timeline[pair.second].push_back(regst_desc_id2reuse_regst_desc.at(pair.first)); } HashSet remain_regsts; size_t remain_memory = 0; *peak_memory = 0; for (int64_t i = 0; i < sorted_tasks.size(); ++i) { for (RegstDescProto* alloc_regst : alloc_regsts_timeline.at(i)) { // Record the born time (*regst2lifetime)[alloc_regst].first = i; CHECK(remain_regsts.insert(alloc_regst).second); remain_memory += mem_reused_regst2size.at(alloc_regst); // NOTE(chengcheng): insert time line to regst proto alloc_regst->set_mem_block_total_actor_count(sorted_tasks.size()); alloc_regst->set_alloc_before_actor(i); } // Update the peak of memory during execution if (*peak_memory < remain_memory) { *peak_memory = remain_memory; } for (RegstDescProto* free_regst : free_regsts_timeline.at(i)) { CHECK_EQ(remain_regsts.erase(free_regst), 1); free_regst->set_free_after_actor(i); remain_memory -= mem_reused_regst2size.at(free_regst); // Record the die time (*regst2lifetime)[free_regst].second = i + 1; } } // Make sure that every register has a die time CHECK(remain_regsts.empty()); } void MemReusedLifetimeFirstAlgo( const bool compact_insert, const HashMap>& regst2lifetime, const HashMap& mem_reused_regst2size, MemBlockResultInfo* result) { std::vector order; order.reserve(regst2lifetime.size()); for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); } std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) { int64_t l_value = regst2lifetime.at(lhs).second - regst2lifetime.at(lhs).first; int64_t r_value = regst2lifetime.at(rhs).second - regst2lifetime.at(rhs).first; if (l_value == r_value) { return regst2lifetime.at(lhs).first < regst2lifetime.at(rhs).first; } return l_value > r_value; }); MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime, result); } void MemReusedTimeLineAlgo( const bool compact_insert, const HashMap>& regst2lifetime, const HashMap& mem_reused_regst2size, MemBlockResultInfo* result) { std::vector order; order.reserve(regst2lifetime.size()); for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); } std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) { int64_t l_value = regst2lifetime.at(lhs).first; int64_t r_value = regst2lifetime.at(rhs).first; if (l_value == r_value) { return regst2lifetime.at(lhs).second > regst2lifetime.at(rhs).second; } return l_value > r_value; }); MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime, result); } void MemReusedMemVolumeFirstAlgo( const bool compact_insert, const HashMap>& regst2lifetime, const HashMap& mem_reused_regst2size, MemBlockResultInfo* result) { std::vector order; order.reserve(regst2lifetime.size()); auto ComputeMemoryVolume = [&](RegstDescProto* key) { return mem_reused_regst2size.at(key) * (regst2lifetime.at(key).second - regst2lifetime.at(key).first) / 1000; }; for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); } std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) { size_t l_value = ComputeMemoryVolume(lhs); size_t 
r_value = ComputeMemoryVolume(rhs);
    if (l_value == r_value) { return mem_reused_regst2size.at(lhs) > mem_reused_regst2size.at(rhs); }
    return l_value > r_value;
  });
  MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime,
                                    result);
}

void SelectAlgorithmGenMemBlockOffset4Regsts(
    MemAllocAlgoType algo_id, const bool compact_insert,
    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,
    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,
    MemBlockResultInfo<RegstDescProto*>* result) {
  CHECK_EQ(result->mem_block_size, 0);
  CHECK(result->regst_desc2offset.empty());
  switch (algo_id) {
    case kMemSizeFirstAlgo:
      MemReusedMemSizeFirstAlgo(compact_insert, regst2lifetime, mem_reused_regst2size, result);
      break;
    case kLifetimeFirstAlgo:
      MemReusedLifetimeFirstAlgo(compact_insert, regst2lifetime, mem_reused_regst2size, result);
      break;
    case kTimeLineAlgo:
      MemReusedTimeLineAlgo(compact_insert, regst2lifetime, mem_reused_regst2size, result);
      break;
    case kMemVolumeFirstAlgo:
      MemReusedMemVolumeFirstAlgo(compact_insert, regst2lifetime, mem_reused_regst2size, result);
      break;
    default: UNIMPLEMENTED();
  }
  CHECK_GT(result->mem_block_size, 0);
  CHECK(!result->regst_desc2offset.empty());
}

int64_t CountMemAllocAlgoNum() {
  const MemoryAllocationAlgorithmConf& mem_alloc_algo_conf =
      GlobalJobDesc().job_conf().memory_allocation_algorithm_conf();
  int64_t alloc_algo_num = 0;
  if (mem_alloc_algo_conf.use_mem_size_first_algo()) { ++alloc_algo_num; }
  if (mem_alloc_algo_conf.use_lifetime_first_algo()) { ++alloc_algo_num; }
  if (mem_alloc_algo_conf.use_time_line_algo()) { ++alloc_algo_num; }
  if (mem_alloc_algo_conf.use_mem_volume_first_algo()) { ++alloc_algo_num; }
  CHECK_GT(alloc_algo_num, 0) << "Choose at least one type of memory allocation algorithm. We "
                                 "recommend use_mem_size_first_algo()";
  const MemoryCompactInsertConf& mem_compact_insert_conf =
      GlobalJobDesc().job_conf().memory_compact_insert_conf();
  int64_t compact_insert_num = 0;
  if (mem_compact_insert_conf.use_compact_insert()) { ++compact_insert_num; }
  if (mem_compact_insert_conf.use_non_compact_insert()) { ++compact_insert_num; }
  CHECK_GT(compact_insert_num, 0) << "Choose at least one type of memory arrangement algorithm "
                                     "during memory allocation. We recommend use_compact_insert()";
  return alloc_algo_num * compact_insert_num;
}

void InitAlgo2Result(
    HashMap<std::pair<MemAllocAlgoType, bool>, MemBlockResultInfo<RegstDescProto*>>* algo2result) {
  CHECK(algo2result->empty());
  std::vector<bool> compact_insert_algorithms;
  const MemoryCompactInsertConf& mem_compact_insert_conf =
      GlobalJobDesc().job_conf().memory_compact_insert_conf();
  if (mem_compact_insert_conf.use_compact_insert()) { compact_insert_algorithms.push_back(true); }
  if (mem_compact_insert_conf.use_non_compact_insert()) {
    compact_insert_algorithms.push_back(false);
  }
  const MemoryAllocationAlgorithmConf& mem_alloc_algo_conf =
      GlobalJobDesc().job_conf().memory_allocation_algorithm_conf();
  // NOTE: Experiments show that memory first might be good enough for some cases.
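  // For example, enabling use_mem_size_first_algo() and use_lifetime_first_algo() together with
  // both compact and non-compact insertion produces four (algorithm, compact_insert) entries
  // below; step 2 later fills each entry in its own thread-pool task and step 3 keeps the one
  // with the smallest mem_block_size.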
for (auto compact_insert : compact_insert_algorithms) { if (mem_alloc_algo_conf.use_mem_size_first_algo()) { (*algo2result)[{kMemSizeFirstAlgo, compact_insert}] = MemBlockResultInfo(); } if (mem_alloc_algo_conf.use_lifetime_first_algo()) { (*algo2result)[{kLifetimeFirstAlgo, compact_insert}] = MemBlockResultInfo(); } if (mem_alloc_algo_conf.use_time_line_algo()) { (*algo2result)[{kTimeLineAlgo, compact_insert}] = MemBlockResultInfo(); } if (mem_alloc_algo_conf.use_mem_volume_first_algo()) { (*algo2result)[{kMemVolumeFirstAlgo, compact_insert}] = MemBlockResultInfo(); } } } } // namespace void IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(Plan* plan) { // 1 device 1 mem chain HashMap> mem_chain2sorted_tasks; HashMap> mem_chain2mem_reused_regsts; // NOTE: We only store those reusable registers in mem_chain2regst_desc_id2reuse_regst_desc. // There are no duplicated registers in different memory chains. HashMap> mem_chain2regst_desc_id2reuse_regst_desc; HashMap mem_reused_regst2size; GenMemChainTasksAndRegsts(plan, &mem_chain2sorted_tasks, &mem_chain2mem_reused_regsts, &mem_chain2regst_desc_id2reuse_regst_desc, &mem_reused_regst2size); if (mem_chain2mem_reused_regsts.empty()) { return; } HashSet mem_chains; for (const auto& pair : mem_chain2mem_reused_regsts) { mem_chains.insert(pair.first); } // register lifetime HashMap>> mem_chain2regst2lifetime; // info for inplace HashMap> mem_chain2consumer2inplaced_regst; // info for straighten HashMap mem_chain2peak_memory; // step 1: generate regst alloc/free queue AND regst lifetimes for (const auto& pair : mem_chain2mem_reused_regsts) { GenRegstAllocFreeTimeLineAndRegstLifetimes( mem_chain2sorted_tasks.at(pair.first), pair.second, mem_chain2regst_desc_id2reuse_regst_desc.at(pair.first), mem_reused_regst2size, &mem_chain2regst2lifetime[pair.first], &mem_chain2consumer2inplaced_regst[pair.first], &mem_chain2peak_memory[pair.first]); } // step 2: multi-thread run several algorithm for each mem chain HashMap, MemBlockResultInfo>> mem_chain2algo2result; { int64_t work_size = mem_chain2mem_reused_regsts.size() * CountMemAllocAlgoNum(); int64_t thread_pool_size = std::min(work_size, std::thread::hardware_concurrency()); BlockingCounter counter(work_size); ThreadPool thread_pool(thread_pool_size); for (int64_t mem_chain_id : mem_chains) { InitAlgo2Result(&mem_chain2algo2result[mem_chain_id]); for (auto& pair : mem_chain2algo2result.at(mem_chain_id)) { MemAllocAlgoType algo_id = pair.first.first; bool compact_insert = pair.first.second; MemBlockResultInfo* result = &pair.second; thread_pool.AddWork([algo_id, compact_insert, mem_chain_id, &mem_chain2regst2lifetime, &mem_reused_regst2size, result, &counter]() { SelectAlgorithmGenMemBlockOffset4Regsts(algo_id, compact_insert, mem_chain2regst2lifetime.at(mem_chain_id), mem_reused_regst2size, result); counter.Decrease(); }); } } counter.WaitForeverUntilCntEqualZero(); } // step 3: choose best one for each mem chain and set offset for inplace consumer regst for (auto& pair : mem_chain2algo2result) { MemBlockResultInfo* best_result = nullptr; for (auto& algo_result_pair : pair.second) { if (!best_result || algo_result_pair.second.mem_block_size < best_result->mem_block_size) { best_result = &algo_result_pair.second; } } CHECK(best_result != nullptr); // Update the offset with a smaller total memory size if the current size is greater than the // lower bound if (GlobalJobDesc().job_conf().enable_compress_memory()) { MemoryShareStrategy mss; mss.AdaptivelyUpdateOffset(mem_reused_regst2size, 
mem_chain2regst2lifetime.at(pair.first), mem_chain2peak_memory[pair.first], &best_result->mem_block_size, &best_result->regst_desc2offset); } int64_t mem_block_id = Singleton::Get()->NewMemBlockId(); CHECK_EQ(mem_chain2mem_reused_regsts.at(pair.first).size(), (best_result->regst_desc2offset.size() + mem_chain2consumer2inplaced_regst.at(pair.first).size())); for (const auto& regst_offset_pair : best_result->regst_desc2offset) { RegstDescProto* regst_desc = regst_offset_pair.first; CHECK_EQ(regst_desc->mem_block_id(), -1); regst_desc->set_mem_block_id(mem_block_id); regst_desc->set_mem_block_offset(regst_offset_pair.second); } // set inplace for (auto& consumer_inplace_pair : mem_chain2consumer2inplaced_regst.at(pair.first)) { RegstDescProto* consumer_regst_desc = consumer_inplace_pair.first; CHECK_EQ(consumer_regst_desc->mem_block_id(), -1); RegstDescProto* inplaced_regst_desc = consumer_inplace_pair.second; CHECK_EQ(inplaced_regst_desc->mem_block_id(), mem_block_id); CHECK_NE(inplaced_regst_desc->mem_block_offset(), -1); consumer_regst_desc->set_mem_block_id(inplaced_regst_desc->mem_block_id()); consumer_regst_desc->set_mem_block_offset(inplaced_regst_desc->mem_block_offset()); } // set inplace hint and check const auto& regst_desc_id2reuse_regst_desc = mem_chain2regst_desc_id2reuse_regst_desc.at(pair.first); for (auto& consumer_inplace_pair : mem_chain2consumer2inplaced_regst.at(pair.first)) { RegstDescProto* consumer_regst_desc = consumer_inplace_pair.first; RegstDescProto* inplaced_regst_desc = consumer_inplace_pair.second; CHECK(consumer_regst_desc->has_inplace_consumed_regst_desc_id() == false); CHECK(consumer_regst_desc->has_hint_inplace_consumed_regst_desc_id()); int64_t hint = consumer_regst_desc->hint_inplace_consumed_regst_desc_id(); // NOTE(chengcheng): hint regst desc id may NOT be the inplaced_regst_desc_id // because of nest inplace. // NOTE: All the registers in mem_chain2consumer2inplaced_regst are reusable auto hint_it = regst_desc_id2reuse_regst_desc.find(hint); CHECK(hint_it != regst_desc_id2reuse_regst_desc.end()); RegstDescProto* in_regst_desc = hint_it->second; CHECK_EQ(consumer_regst_desc->mem_block_id(), in_regst_desc->mem_block_id()); CHECK_EQ(consumer_regst_desc->mem_block_offset(), in_regst_desc->mem_block_offset()); CHECK_EQ(in_regst_desc->mem_block_offset(), inplaced_regst_desc->mem_block_offset()); CHECK_EQ(consumer_regst_desc->register_num(), in_regst_desc->register_num()); consumer_regst_desc->set_inplace_consumed_regst_desc_id(hint); } } } } // namespace oneflow ================================================ FILE: oneflow/core/job/intra_job_mem_sharing_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_IN_JOB_MEM_SHARING_UTIL_H_ #define ONEFLOW_CORE_JOB_IN_JOB_MEM_SHARING_UTIL_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/hash_container.h" #include "oneflow/core/job/memory_share_strategy.h" #include "oneflow/core/job/plan.pb.h" #include #include namespace oneflow { struct IntraJobMemSharingUtil { static void InferMemBlockId4MemReusedRegst(Plan* plan); }; template struct MemBlockResultInfo { size_t mem_block_size; HashMap regst_desc2offset; }; // Judge whether a is suitable than b for a gap inline bool SuitableThan(int64_t a, int64_t b) { // The number have orders // A non-negative number is always more suitable than a negative number // If a number is non-negative, then the smaller the better // If a number is negative, then the larger the better // 0 > 1 > 2 > ... > 999999999 > -1 > -2 > ... > -99999999 // Now we flip the positive part to make it "the larger the better". if (a >= 0) { a = GetMaxVal() - a; } if (b >= 0) { b = GetMaxVal() - b; } return a > b; } template void MemReusedAlgorithmAllocateByOrder( const bool compact_insert, const std::vector& order, const HashMap& regst_desc2size, const HashMap>& regst2lifetime, MemBlockResultInfo* result) { HashMap* regst_desc2offset = &(result->regst_desc2offset); // NOTE: It is important to make the variables local. // It took me several days to find out that using passed-in vector for size, order, and lifetime // would double the running time. Switch HashMap to vector int32_t total_register_num = order.size(); std::vector order2size(total_register_num); std::vector> order2lifetime(total_register_num); std::vector order2offset(total_register_num); for (int32_t i = 0; i < total_register_num; i++) { order2size[i] = regst_desc2size.at(order[i]); order2lifetime[i] = regst2lifetime.at(order[i]); } size_t buffer_size = 1; // Sort by offset auto comp = [&order2offset](const auto& a, const auto& b) { if (order2offset[a] != order2offset[b]) { return order2offset[a] < order2offset[b]; } // Make sure we have a stable order even if we have the same offset for different registers return a < b; }; std::set sorted_registers(comp); // Decide offset following the given order for (int32_t inserting_id = 0; inserting_id < total_register_num; inserting_id++) { const auto& inserting_lifetime = order2lifetime[inserting_id]; // At the beginning, try to insert the offset in the front of the whole memory pool. 
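    // Two placement strategies follow: with compact_insert, the best-fitting gap between already
    // placed registers whose lifetimes overlap the inserting one is searched (see SuitableThan);
    // otherwise the register is pushed past every overlapping register, first-fit style, and
    // buffer_size grows when the register ends beyond it.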
int64_t inserting_offset = 0; int64_t inserting_end = inserting_offset + order2size[inserting_id]; if (compact_insert) { // Find the most suitable gap for the register int64_t gap_head = 0; int64_t inserting_size = order2size[inserting_id]; // difference = length of gap - length of the inserting register int64_t diff_gap = 0, suitable_diff_gap = -1 - inserting_size; for (const auto& curr_register : sorted_registers) { // Ignore those non-excluded registers if (IsLifetimeExcluded(inserting_lifetime, order2lifetime[curr_register])) { if (gap_head < order2offset[curr_register]) { // Find one gap diff_gap = (order2offset[curr_register] - gap_head) - inserting_size; // Compared with the previous suitable gap if (SuitableThan(diff_gap, suitable_diff_gap)) { suitable_diff_gap = diff_gap; // We may insert the register into the gap inserting_offset = gap_head; } // Update gap head gap_head = order2offset[curr_register] + order2size[curr_register]; } else { // No gap, update gap head gap_head = std::max(gap_head, order2offset[curr_register] + order2size[curr_register]); } } } // Deal with the buffer_size, which may be the final gap diff_gap = (buffer_size - gap_head) - inserting_size; // Compared with the previous suitable gap if (SuitableThan(diff_gap, suitable_diff_gap)) { suitable_diff_gap = diff_gap; // We may insert the register into the gap inserting_offset = gap_head; } // If no gap large enough to contain the current register if (suitable_diff_gap < 0) { // Prolong the maximum memory pool size by (-suitable_diff_gap) buffer_size -= suitable_diff_gap; int64_t gap_end = suitable_diff_gap + inserting_size + inserting_offset; for (auto reverse_it = sorted_registers.rbegin(); reverse_it != sorted_registers.rend(); reverse_it++) { // All the registers with offset < gap_end maintain their position if (order2offset[*reverse_it] < gap_end) { break; } // All the registers with offset >= gap_end move backward order2offset[*reverse_it] -= suitable_diff_gap; } } } else { for (const auto& curr_register : sorted_registers) { // i: inserting register, j: current register // x: register offset, l: register size // If x_i + l_i <= x_j, then the inserting register would be placed at x_i if (order2offset[curr_register] >= inserting_end) { break; } // If i and j are excluded, and x_i + l_i > x_j, // then we try to place i at x_j + l_j and check the following registers if (IsLifetimeExcluded(inserting_lifetime, order2lifetime[curr_register])) { int64_t curr_end = order2offset[curr_register] + order2size[curr_register]; // Can not set inserting offset = current end directly. 
// We might have two excluded registers like this: // register a: [100, 10000] // register b: [500, 600] if (inserting_offset < curr_end) { inserting_offset = curr_end; inserting_end = inserting_offset + order2size[inserting_id]; } } } // Update total size if (inserting_end > buffer_size) { buffer_size = inserting_end; } } // Either we break the loop or the loop terminated naturally, we can place i at inserting_offset order2offset[inserting_id] = inserting_offset; sorted_registers.insert(inserting_id); } result->mem_block_size = buffer_size; // Switch vector to HashMap for (int32_t i = 0; i < total_register_num; i++) { (*regst_desc2offset)[order[i]] = order2offset[i]; } } template void MemReusedMemSizeFirstAlgo(const bool compact_insert, const HashMap>& regst2lifetime, const HashMap& mem_reused_regst2size, MemBlockResultInfo* result) { std::vector order; order.reserve(regst2lifetime.size()); for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); } std::sort(order.begin(), order.end(), [&](T lhs, T rhs) { size_t l_value = mem_reused_regst2size.at(lhs); size_t r_value = mem_reused_regst2size.at(rhs); if (l_value == r_value) { return regst2lifetime.at(lhs).first < regst2lifetime.at(rhs).first; } return l_value > r_value; }); MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime, result); } } // namespace oneflow #endif // ONEFLOW_CORE_JOB_IN_JOB_MEM_SHARING_UTIL_H_ ================================================ FILE: oneflow/core/job/job.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/job/dlnet_conf.proto"; import "oneflow/core/job/placement.proto"; import "oneflow/core/job/job_conf.proto"; import "oneflow/core/register/logical_blob_id.proto"; import "oneflow/core/register/op_blob_arg.proto"; import "oneflow/core/register/blob_desc.proto"; import "oneflow/core/operator/op_conf.proto"; import "oneflow/core/job/sbp_parallel.proto"; import "oneflow/core/job/module_conf.proto"; message JobParallelViewConf { map op_name2sbp_signature_conf = 1; map op_name2is_local_parallel_view = 2; map op_name2nd_sbp_signature_conf = 3; } message JobHelperConf { map tag2lbi_relations = 1; map tag2op_name_relations = 2; map lbn2logical_blob_desc = 4; map lbn2logical_object_id = 5; map op_name2arg_signature = 9; } message MergedLogicalChainIdGroup { repeated int64 logical_chain_id_list = 1; } message Job { optional DLNetConf net = 1; optional Placement placement = 2; required JobConfigProto job_conf = 3; optional JobParallelViewConf job_parallel_view_conf = 4; optional JobHelperConf helper = 5; map module_name2module_conf = 6; repeated MergedLogicalChainIdGroup logical_chain_groups = 7; } ================================================ FILE: oneflow/core/job/job_build_and_infer_ctx.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/cost_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/config_def.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/job/job_build_and_infer_ctx.h" #include "oneflow/core/job/local_sig_infer_hint.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job_rewriter/autograd.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/user/summary/summary_converter.h" #include #include "nlohmann/json.hpp" namespace oneflow { static const std::string kAutoLocalBlobNamePrefix = "System-Local-Blob-Auto-Converted-From-Global-Blob"; namespace { void ResetOpConfName(OperatorConf* op_conf, const std::string& new_op_name) { op_conf->set_name(new_op_name); PbMessage* op_type_conf = MutableMessageInPbMessage(op_conf, op_conf->op_type_case()); UserOpConf* user_conf = dynamic_cast(op_type_conf); if (user_conf) { for (const auto& pair : user_conf->output()) { for (const std::string& old_lbn : pair.second.s()) { LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn); auto blob_name_id_pair = GenUnRepeatedBn(old_lbi.blob_name()); std::string new_lbn = GenLogicalBlobName(new_op_name, old_lbi.blob_name()); (*(user_conf->mutable_output()))[pair.first].set_s(blob_name_id_pair.second, new_lbn); } } } } Maybe GetOpNames(const Job& job, HashSet* op_names) { for (const auto& op_conf : job.net().op()) { CHECK_OR_RETURN(op_names->insert(op_conf.name()).second); } return Maybe::Ok(); } void UpdateOpName2AncestorsNeedNoGrad( const Operator& op, const std::function& Op4OpName, const bool is_train, HashMap* op_name2ancestors_need_no_grad) { bool no_grad = !is_train; auto IsTrainableVariableLbi = [&](const LogicalBlobId& lbi) { const auto& op_conf = Op4OpName(lbi.op_name())->op_conf(); return op_conf.has_variable_conf() && op_conf.variable_conf().trainable(); }; for (const auto& ibn : op.input_bns()) { const auto& lbi = op.BnInOp2Lbi(ibn); no_grad = no_grad && !IsTrainableVariableLbi(lbi); no_grad = no_grad && !op.InputBlobModifier4Ibn(ibn).requires_grad(); no_grad = no_grad && (*op_name2ancestors_need_no_grad)[lbi.op_name()]; } (*op_name2ancestors_need_no_grad)[op.op_name()] = no_grad; } } // namespace JobBuildAndInferCtx::JobBuildAndInferCtx(Job* job, int64_t job_id) : job_(job), job_id_(job_id), unique_op_name_index_(0) { is_job_conf_frozen_ = false; has_job_conf_ = false; } Maybe JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) { CHECK_OR_RETURN(!is_job_conf_frozen_) << Error::JobConfFrozenError(); CHECK_OR_RETURN(!has_job_conf_) << Error::JobConfRepeatedSetError(); has_job_conf_ = true; CHECK_EQ_OR_RETURN(job_->job_conf().job_name(), job_conf.job_name()) << Error::JobNameNotEqualError() << "job name you set: " << job_conf.job_name() << " not equal to origin job name: " << job_->job_conf().job_name(); job_->mutable_job_conf()->CopyFrom(job_conf); CHECK_ISNULL_OR_RETURN(Singleton::Get()); Singleton::New(job_conf, job_id_); return Maybe::Ok(); } Maybe JobBuildAndInferCtx::AddOpNameParallelConf2Placement( const std::string& op_name, const ParallelConf& parallel_conf) { ParallelDesc parallel_desc(parallel_conf); PlacementGroup* pg = nullptr; if (parallel_desc2placement_group_.find(parallel_desc) == parallel_desc2placement_group_.end()) { pg = job_->mutable_placement()->add_placement_group(); 
parallel_desc2placement_group_.emplace(parallel_desc, pg); *(pg->mutable_parallel_conf()) = parallel_conf; } else { pg = parallel_desc2placement_group_.at(parallel_desc); } pg->mutable_op_set()->add_op_name(op_name); return Maybe::Ok(); } Maybe JobBuildAndInferCtx::AddLbiParallelConf2BlobPlacement( const Operator* op, std::function ParallelDesc4Obn) { for (const auto& obn : op->output_bns()) { const auto& parallel_desc = *ParallelDesc4Obn(obn); auto iter = parallel_desc2blob_placement_group_.find(parallel_desc); if (iter == parallel_desc2blob_placement_group_.end()) { auto* blob_pg = job_->mutable_placement()->add_blob_placement_group(); *blob_pg->mutable_parallel_conf() = parallel_desc.parallel_conf(); iter = parallel_desc2blob_placement_group_.emplace(parallel_desc, blob_pg).first; } const auto& lbi = op->BnInOp2Lbi(obn); *iter->second->add_lbi() = lbi; } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::DecodeLbiHintAndReturnNewOpConf( const Operator& op, SbpSignature* sbp_sig_conf) const { auto op_conf_without_split_hint = std::make_shared(op.op_conf()); for (const std::string& ibn : op.input_bns()) { std::string lbn_may_with_hint = GetInputLbnInOpCustomizedConf(op.op_conf(), ibn); SbpParallel sbp_parallel; bool has_sbp_hint = JUST(GetSbpParallelInLbnOrNothing(lbn_may_with_hint, &sbp_parallel)); if (has_sbp_hint) { (*(sbp_sig_conf->mutable_bn_in_op2sbp_parallel()))[ibn] = sbp_parallel; const LogicalBlobId& lbi = op.BnInOp2Lbi(ibn); std::string lbn = GenLogicalBlobName(lbi); CHECK_EQ_OR_RETURN(lbn_may_with_hint, ReplaceInputLbnInOpCustomizedConf( op_conf_without_split_hint.get(), ibn, lbn)); } } return op_conf_without_split_hint; } void JobBuildAndInferCtx::AddOpAndUpdateJobParallelViewConf(const OperatorConf& operator_conf, const ParallelDesc& parallel_desc, const NdSbpSignature& nd_sbp_signature, bool is_local_parallel_view) const { auto* op_name2sbp_sig = job_->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf(); auto* op_name2nd_sbp_sig = job_->mutable_job_parallel_view_conf()->mutable_op_name2nd_sbp_signature_conf(); if (nd_sbp_signature.bn_in_op2nd_sbp().size() > 0) { (*op_name2nd_sbp_sig)[operator_conf.name()] = nd_sbp_signature; if (parallel_desc.hierarchy()->NumAxes() == 1) { SbpSignature sbp_signature; NdSbpSignatureToSbpSignature(nd_sbp_signature, &sbp_signature); (*op_name2sbp_sig)[operator_conf.name()] = sbp_signature; } } auto* op_name2is_local_parallel_view = job_->mutable_job_parallel_view_conf()->mutable_op_name2is_local_parallel_view(); if (is_local_parallel_view) { (*op_name2is_local_parallel_view)[operator_conf.name()] = true; } job_->mutable_net()->add_op()->CopyFrom(operator_conf); // set up the module config const auto& scope = Singleton>::Get()->Get(operator_conf.scope_symbol_id()); if (scope.scope_proto().has_module_name()) { const auto& module_name = scope.scope_proto().module_name(); auto* module_name2module_conf = job_->mutable_module_name2module_conf(); if (!(*module_name2module_conf)[module_name].has_name()) { (*module_name2module_conf)[module_name].set_name(scope.scope_proto().module_name()); } *((*module_name2module_conf)[module_name].add_ops()) = operator_conf.name(); } } Maybe JobBuildAndInferCtx::InferLocalSignature(Operator* op, bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) { HashMap ibn2local_sig_infer_hint; for (const std::string& ibn : op->input_bns()) { const LogicalBlobId& lbi = op->BnInOp2Lbi(ibn); CHECK_OR_RETURN(lbi2logical_blob_desc_.find(lbi) != lbi2logical_blob_desc_.end()) << 
Error::LogicalBlobNameNotExistError() << "infer blob desc not found, when infer op_name: \"" << op->op_name() << "\", consumed op_name: \"" << lbi.op_name() << "\", blob_name: \"" << lbi.blob_name(); const ParallelDesc* pd = &lbi2parallel_desc_from_producer_view_.at(lbi); const auto* producer_op = op_name2op_.at(lbi.op_name()).get(); const auto& producer_obn = *JUST(producer_op->obn4lbi(lbi)); const auto& opt_local_parallel = *CHECK_JUST(producer_op->OptLocalParallel4BnInOp(producer_obn)); ibn2local_sig_infer_hint.emplace( ibn, LocalSigInferHint(pd, opt_local_parallel.has_local_parallel())); } const auto& LocalSigInferHint4Ibn = [&](const std::string& ibn) -> Maybe { const auto& iter = ibn2local_sig_infer_hint.find(ibn); CHECK_OR_RETURN(iter != ibn2local_sig_infer_hint.end()) << "input blob not found. ibn: " << ibn; return &iter->second; }; JUST( op->InferLocalSignatureIf(LocalSigInferHint4Ibn, is_local_parallel_view_conf, parallel_desc)); return Maybe::Ok(); } Maybe JobBuildAndInferCtx::InferOpOutNdSbp(Operator* op, const NdSbpSignature& nd_sbp_sig_conf, const ParallelDesc& parallel_desc) { HashMap ibn2nd_sbp_infer_hint; for (const std::string& ibn : op->input_bns()) { const LogicalBlobId& lbi = op->BnInOp2Lbi(ibn); auto logical_blob_desc_it = lbi2logical_blob_desc_.find(lbi); CHECK_OR_RETURN(logical_blob_desc_it != lbi2logical_blob_desc_.end()) << Error::LogicalBlobNameNotExistError() << "infer blob desc not found, when infer op_name: \"" << op->op_name() << "\", consumed op_name: \"" << lbi.op_name() << "\", blob_name: \"" << lbi.blob_name(); const BlobDesc* logical_blob_desc = logical_blob_desc_it->second.get(); const ParallelDesc* pd = &lbi2parallel_desc_from_producer_view_.at(lbi); auto nd_sbp_it = lbi2nd_sbp_from_producer_view_.find(lbi); CHECK_OR_RETURN(nd_sbp_it != lbi2nd_sbp_from_producer_view_.end()) << Error::LogicalBlobNameNotExistError() << "when infer op_name: " << op->op_name() << " consumed op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name() << " not infer parallel distribution"; const NdSbp* nd_sbp = &nd_sbp_it->second; ibn2nd_sbp_infer_hint.emplace(ibn, NdSbpInferHint(pd, logical_blob_desc, nd_sbp)); } const auto NdSbpInferHint4Ibn = [&](const std::string& bn) -> Maybe { return &ibn2nd_sbp_infer_hint.at(bn); }; JUST(op->InferNdSbpSignatureIf(nd_sbp_sig_conf, parallel_desc, NdSbpInferHint4Ibn)); const auto& bn2nd_sbp = JUST(op->nd_sbp_signature())->bn_in_op2nd_sbp(); for (const auto& obn : op->output_bns()) { const LogicalBlobId& lbi = op->BnInOp2Lbi(obn); CHECK_OR_RETURN(bn2nd_sbp.find(obn) != bn2nd_sbp.end()) << Error::BlobSplitAxisInferError() << "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name() << " not infer split axis"; CHECK_OR_RETURN(lbi2nd_sbp_from_producer_view_.emplace(lbi, bn2nd_sbp.at(obn)).second) << Error::BlobSplitAxisInferError() << "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name() << " infer split axis repeated"; CHECK_OR_RETURN(lbi2parallel_desc_from_producer_view_.emplace(lbi, parallel_desc).second) << Error::BlobSplitAxisInferError() << "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name() << " add parallel desc repeated"; } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::GenOpProducedEmptyLogicalBlobDesc(Operator* op) { // check consumed blob for (const std::string& consumed_bn : op->input_bns()) { const LogicalBlobId& lbi = op->BnInOp2Lbi(consumed_bn); CHECK_OR_RETURN(lbi2logical_blob_desc_.find(lbi) != lbi2logical_blob_desc_.end()) << Error::LogicalBlobNameNotExistError() << "op_name: " << 
op->op_name() << " consumed_op_name:" << lbi.op_name() << " blob_name: " << lbi.blob_name() << " not exist"; } // create produced blob std::vector produced_bns; produced_bns.reserve(op->output_bns().size() + op->tmp_bns().size()); produced_bns.insert(produced_bns.end(), op->output_bns().begin(), op->output_bns().end()); produced_bns.insert(produced_bns.end(), op->tmp_bns().begin(), op->tmp_bns().end()); for (const std::string& produced_bn : produced_bns) { const LogicalBlobId& lbi = op->BnInOp2Lbi(produced_bn); CHECK_OR_RETURN(lbi2logical_blob_desc_.find(lbi) == lbi2logical_blob_desc_.end()) << Error::LogicalBlobNameExistError() << "duplicate logical blob name found. op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name(); lbi2logical_blob_desc_.emplace( lbi, std::make_unique(DataType::kInvalidDataType, MemoryFormat::kContiguous)); } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::CheckOpBlobSplitability(Operator* op, int64_t parallel_num) { const auto& parallel_hierarchy = JUST(op->GetOpParallelDesc())->hierarchy(); if (parallel_hierarchy->NumAxes() == 1) { HashSet obns(op->output_bns().begin(), op->output_bns().end()); auto GetParallelNum = [&](const std::string& bn_in_op) { if (obns.find(bn_in_op) == obns.end()) { return parallel_num; } return lbi2parallel_desc_from_producer_view_.at(op->BnInOp2Lbi(bn_in_op)).parallel_num(); }; for (const auto& pair : JUST(op->sbp_signature())->bn_in_op2sbp_parallel()) { if (!pair.second.has_split_parallel()) { continue; } if (JUST(op->OptLocalParallel4BnInOp(pair.first))->has_local_parallel()) { continue; } int64_t axis = pair.second.split_parallel().axis(); const LogicalBlobId& lbi = op->BnInOp2Lbi(pair.first); int64_t blob_parallel_num = GetParallelNum(pair.first); const BlobDesc& logical_blob_desc = *(lbi2logical_blob_desc_.at(lbi).get()); int64_t num_axes = logical_blob_desc.shape().NumAxes(); if (axis < 0) { axis += num_axes; } CHECK_GE_OR_RETURN(axis, 0); CHECK_LE_OR_RETURN(axis, num_axes) << "op: " << op->op_name() << ", blob: " << pair.first << ", axis: " << axis << ", shape: " << logical_blob_desc.shape(); if (logical_blob_desc.shape().NumAxes() > 0) { CHECK_GE_OR_RETURN(logical_blob_desc.shape().At(axis), blob_parallel_num) << "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name() << " shape: " << logical_blob_desc.shape() << " cannot be splitted by parallel_num: " << blob_parallel_num << " at axis " << axis; } } } else { for (const auto& pair : JUST(op->nd_sbp_signature())->bn_in_op2nd_sbp()) { if (JUST(op->OptLocalParallel4BnInOp(pair.first))->has_local_parallel()) { continue; } const LogicalBlobId& lbi = op->BnInOp2Lbi(pair.first); const BlobDesc& logical_blob_desc = *(lbi2logical_blob_desc_.at(lbi).get()); Shape current_shape = logical_blob_desc.shape(); for (int64_t i = 0; i < pair.second.sbp_parallel_size(); ++i) { const SbpParallel& sbp_parallel = pair.second.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { const int64_t axis = sbp_parallel.split_parallel().axis(); CHECK_GT_OR_RETURN(current_shape.At(axis), 0); // Support unbalanced splitting CHECK_GE_OR_RETURN(current_shape.At(axis), parallel_hierarchy->At(i)) << "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name() << " shape: " << logical_blob_desc.shape() << " cannot be splitted by nd sbp: " << NdSbpToString(pair.second) << " at axis " << axis << " with parallel_hierarchy: " << *parallel_hierarchy; // Split and take the minimum one current_shape.Set(axis, current_shape.At(axis) / parallel_hierarchy->At(i)); } } } } return Maybe::Ok(); } 
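// Splitability example: with a 1-D hierarchy, a blob of logical shape (8, 16) split at axis 0
// over 4 devices passes the check above because shape.At(0) = 8 >= 4; with nd_sbp (S(0), S(0))
// on a (2, 2) hierarchy the axis-0 extent is checked level by level: 8 >= 2, then 8 / 2 = 4 >= 2.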
Maybe JobBuildAndInferCtx::InferOpParallelConf( const Operator& op, const ParallelConf& origin_parallel_conf, const HashMap& ibn2disable_boxing) const { const ParallelDesc* parallel_desc = nullptr; for (const auto& ibn : op.input_bns()) { if (ibn2disable_boxing.at(ibn) == false) { continue; } const auto& lbi = op.BnInOp2Lbi(ibn); const auto& ibn_parallel_desc = lbi2parallel_desc_from_producer_view_.at(lbi); if (parallel_desc == nullptr) { parallel_desc = &ibn_parallel_desc; } else { CHECK_EQ_OR_RETURN(parallel_desc->parallel_num(), ibn_parallel_desc.parallel_num()); } } if (parallel_desc == nullptr) { return std::make_shared(origin_parallel_conf); } return std::make_shared(parallel_desc->parallel_conf()); } void JobBuildAndInferCtx::InitIbn2DisableBoxing(const Operator& op, HashMap* ibn2disable_boxing) { for (const auto& ibn : op.input_bns()) { (*ibn2disable_boxing)[ibn] = lbi2disable_boxing_[op.BnInOp2Lbi(ibn)]; } } Maybe JobBuildAndInferCtx::InitConstraitNdSbpSignature( const Operator& op, const HashMap& ibn2disable_boxing) const { auto nd_sbp_sig = std::make_shared(); for (const auto& it : ibn2disable_boxing) { if (it.second) { const auto& ibn = it.first; const LogicalBlobId& lbi = op.BnInOp2Lbi(ibn); const auto& nd_sbp_iter = lbi2nd_sbp_from_producer_view_.find(lbi); if (nd_sbp_iter == lbi2nd_sbp_from_producer_view_.end()) { return Error::RuntimeError() << "The nd_sbp of input " << ibn << " (tensor name is " << GenLogicalBlobName(lbi) << ") is not found for operation " << op.op_name() << ". It maybe caused by an invalid inplace operation."; } (*(nd_sbp_sig->mutable_bn_in_op2nd_sbp()))[ibn] = lbi2nd_sbp_from_producer_view_.at(lbi); } } return nd_sbp_sig; } bool JobBuildAndInferCtx::HasAnyLocalBlobInput(const Operator& op) const { for (const auto& ibn : op.input_bns()) { const auto& lbi = op.BnInOp2Lbi(ibn); if (local_lbi2sub_lbis_.find(lbi) != local_lbi2sub_lbis_.end()) { return true; } } return false; } Maybe JobBuildAndInferCtx::SbpParallel4Lbi(const LogicalBlobId& lbi) const { const auto& iter = lbi2nd_sbp_from_producer_view_.find(lbi); CHECK_OR_RETURN(iter != lbi2nd_sbp_from_producer_view_.end()) << "lbn: " << GenLogicalBlobName(lbi) << " undefined"; CHECK_EQ_OR_RETURN(iter->second.sbp_parallel_size(), 1); return &(iter->second.sbp_parallel(0)); } Maybe JobBuildAndInferCtx::ParallelDesc4Lbi(const LogicalBlobId& lbi) const { const auto& iter = lbi2parallel_desc_from_producer_view_.find(lbi); CHECK_OR_RETURN(iter != lbi2parallel_desc_from_producer_view_.end()) << "lbn: " << GenLogicalBlobName(lbi) << " undefined"; return &iter->second; } Maybe JobBuildAndInferCtx::AllInputsBroadcastParallel(const Operator& op) const { for (const auto& ibn : op.input_bns()) { const LogicalBlobId& lbi = op.BnInOp2Lbi(ibn); const auto& iter = local_lbi2sbp_parallel_.find(lbi); if (iter != local_lbi2sbp_parallel_.end()) { if (!iter->second.has_broadcast_parallel()) { return false; } } else { if (!JUST(SbpParallel4Lbi(lbi))->has_broadcast_parallel()) { return false; } } } return true; } bool JobBuildAndInferCtx::IsVariableLbi(const LogicalBlobId& lbi) const { return op_name2op_.at(lbi.op_name())->op_conf().has_variable_conf(); } Maybe JobBuildAndInferCtx::CheckAllInputsConvertableToLocalBlob(const Operator& op) const { for (const auto& ibn : op.input_bns()) { const auto& lbi = op.BnInOp2Lbi(ibn); if (local_lbi2sub_lbis_.find(lbi) != local_lbi2sub_lbis_.end()) { continue; } const auto& sbp = *JUST(SbpParallel4Lbi(lbi)); if (sbp.has_broadcast_parallel()) { continue; } if (sbp.has_split_parallel() && 
sbp.split_parallel().axis() == 0) { continue; } const std::string& lbn = GenLogicalBlobName(lbi); return Error::CheckFailedError() << "input lbn: " << lbn << " is not convertable to local blob"; } return Maybe::Ok(); } Maybe LazyJobBuildAndInferCtx::CheckAllInputsWithSameParallelNum(const Operator& op, int32_t parallel_num) const { for (const auto& ibn : op.input_bns()) { const auto& lbi = op.BnInOp2Lbi(ibn); const auto& iter = local_lbi2sub_lbis().find(lbi); int32_t ibn_parallel_num = 0; if (iter != local_lbi2sub_lbis().end()) { ibn_parallel_num = iter->second.size(); } else { ibn_parallel_num = JUST(ParallelDesc4Lbi(lbi))->parallel_num(); } CHECK_EQ_OR_RETURN(ibn_parallel_num, parallel_num) << "the parallel_num of input lbn: " << GenLogicalBlobName(lbi) << " is not equals to op' parallel_num"; } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::AddAndInferLocalOp(const OperatorConf& op_conf) { CHECK_OR_RETURN(op_conf.has_scope_symbol_id()); const auto& scope = Singleton>::Get()->Get(op_conf.scope_symbol_id()); const auto* job_desc = JUST(scope.job_desc()); const auto& parallel_desc = *JUST(scope.GetParallelDesc(op_conf)); auto op = JUST(ConstructOp(op_conf, parallel_desc.device_type())); JUST(CheckAllInputsConvertableToLocalBlob(*op)); int32_t parallel_num = parallel_desc.parallel_num(); JUST(CheckAllInputsWithSameParallelNum(*op, parallel_num)); auto GetSubOpName = [&](int index) { return GetLocalOpName(op_conf.name(), index); }; OperatorConf sub_op_conf(op_conf); int64_t sub_op_list_size = SizeOfSubGlobalOpList(parallel_num); auto last_op_attribute = std::make_shared(); FOR_RANGE(int32_t, i, 0, sub_op_list_size) { ResetOpConfName(&sub_op_conf, GetSubOpName(i)); for (const auto& ibn : op->input_bns()) { const auto& lbi = *JUST(GetSubLbi(op_conf.scope_symbol_id(), op->BnInOp2Lbi(ibn), i)); ReplaceInputLbnInOpCustomizedConf(&sub_op_conf, ibn, GenLogicalBlobName(lbi)); } const ParallelConf& parallel_conf = GetLocalOpParallelConf(parallel_desc, i); bool is_local_parallel_view = GetIsLocalParallelView(); last_op_attribute = JUST(AddAndInferOp(sub_op_conf, parallel_conf, job_desc, is_local_parallel_view)); } bool is_broadcast = JUST(AllInputsBroadcastParallel(*op)); for (const auto& obn : op->output_bns()) { const auto& lbi = op->BnInOp2Lbi(obn); auto* sub_lbis = &local_lbi2sub_lbis_[lbi]; sub_lbis->resize(sub_op_list_size, op->BnInOp2Lbi(obn)); FOR_RANGE(int32_t, i, 0, sub_op_list_size) { sub_lbis->at(i).set_op_name(GetSubOpName(i)); } CHECK(local_lbi2parallel_desc_.emplace(lbi, parallel_desc).second); auto* sbp_parallel = &local_lbi2sbp_parallel_[lbi]; if (is_broadcast) { sbp_parallel->mutable_broadcast_parallel(); } else { sbp_parallel->mutable_split_parallel()->set_axis(0); } } return last_op_attribute; } Maybe JobBuildAndInferCtx::GetSubLbi(int64_t scope_symbol_id, const LogicalBlobId& lbi, int32_t index) { auto lbi_vec_iter = local_lbi2sub_lbis_.find(lbi); if (lbi_vec_iter == local_lbi2sub_lbis_.end()) { const auto& new_lbi = JUST(FindOrCreateLocalLbiFromCompatibleGlobalBlob(scope_symbol_id, lbi)); lbi_vec_iter = local_lbi2sub_lbis_.find(*new_lbi); CHECK(lbi_vec_iter != local_lbi2sub_lbis_.end()); } return &lbi_vec_iter->second.at(index); } Maybe JobBuildAndInferCtx::AddAndInferGlobalOp(const OperatorConf& op_conf) { CHECK_OR_RETURN(op_conf.has_scope_symbol_id()); const auto& scope = Singleton>::Get()->Get(op_conf.scope_symbol_id()); const auto& parallel_desc = *JUST(scope.GetParallelDesc(op_conf)); const auto* job_desc = JUST(scope.job_desc()); return AddAndInferOp(op_conf, 
parallel_desc.parallel_conf(), job_desc, false); } // TODO(): add handle error of same interface op blob between jobs Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_conf, const ParallelConf& origin_parallel_conf, const JobDesc* job_desc, bool is_local_parallel_view) { CHECK_OR_RETURN(has_job_conf_) << Error::JobConfNotSetError(); if (!is_job_conf_frozen_) { is_job_conf_frozen_ = true; } const std::string& op_name = op_conf.name(); CHECK_OR_RETURN(op_name2op_.find(op_name) == op_name2op_.end()) << Error::OpNameExistError() << "op_name: " << op_name << " already exist in job: " << job_->job_conf().job_name(); CHECK_NE_OR_RETURN(op_conf.device_tag(), "invalid_device") << Error::OpConfDeviceTagNoSetError() << "op_name: " << op_name << " not set device tag"; op_name2op_.emplace(op_name, JUST(ConstructOp(op_conf))); Operator* op = op_name2op_.at(op_name).get(); SbpSignature sbp_sig_conf; HashMap ibn2disable_boxing; InitIbn2DisableBoxing(*op, &ibn2disable_boxing); auto new_op_conf = JUST(DecodeLbiHintAndReturnNewOpConf(*op, &sbp_sig_conf)); auto parallel_conf = JUST(InferOpParallelConf(*op, origin_parallel_conf, ibn2disable_boxing)); ParallelDesc parallel_desc(*parallel_conf); JUST(op->FillOpParallelDesc(parallel_desc)); JUST(AddOpNameParallelConf2Placement(op_name, *parallel_conf)); auto GetBlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* { const LogicalBlobId& lbi = op->BnInOp2Lbi(bn); if (lbi2logical_blob_desc_.find(lbi) != lbi2logical_blob_desc_.end()) { return lbi2logical_blob_desc_.at(lbi).get(); } return nullptr; }; JUST(op->FillLogicalInBlobDesc(GetBlobDesc4BnInOp)); JUST(op->InferParallelSignatureIf()); // infer local signature JUST(InferLocalSignature(op, is_local_parallel_view, parallel_desc)); // infer nd_sbp signature NdSbpSignature nd_sbp_sig_conf; // Only infer nd_sbp signature if auto parallel is not enable, // since the semi-auto parallellism rule might have inconsistency with the auto-parallel strategy. if (!job_desc->enable_auto_parallel()) { nd_sbp_sig_conf = *JUST(InitConstraitNdSbpSignature(*op, ibn2disable_boxing)); } // Override constrait nd_sbp if sbp hint is given if (!sbp_sig_conf.bn_in_op2sbp_parallel().empty()) { SbpSignatureToNdSbpSignature(sbp_sig_conf, &nd_sbp_sig_conf); } AddOpAndUpdateJobParallelViewConf(*new_op_conf, parallel_desc, nd_sbp_sig_conf, is_local_parallel_view); JUST(InferOpOutNdSbp(op, nd_sbp_sig_conf, parallel_desc)); // infer logical blob desc JUST(GenOpProducedEmptyLogicalBlobDesc(op)); JUST(op->InferLogicalOutBlobDescsIf()); for (const auto& bn : op->output_bns()) { *lbi2logical_blob_desc_.at(op->BnInOp2Lbi(bn)) = *JUST(op->GetLogicalBlobDesc4Obn(bn)); } // Infer ParallelDesc for output blobs. 
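  // Each output lbi is mapped to the producer op's parallel desc and cached in
  // lbi2parallel_desc_from_producer_view_, so downstream consumers and the blob placement groups
  // added below can look up the producer-side placement.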
auto ParallelDesc4Obn = [&](const std::string& obn) -> ParallelDesc* { const auto& lbi = op->BnInOp2Lbi(obn); auto iter = lbi2parallel_desc_from_producer_view_.find(lbi); if (iter == lbi2parallel_desc_from_producer_view_.end()) { iter = lbi2parallel_desc_from_producer_view_.emplace(lbi, parallel_desc).first; } return &iter->second; }; for (const auto& bn : op->output_bns()) { lbi2parallel_desc_from_producer_view_.emplace(op->BnInOp2Lbi(bn), *JUST(op->GetParallelDesc4BnInOp(bn))); } JUST(AddLbiParallelConf2BlobPlacement(op, ParallelDesc4Obn)); // Check splitability JUST(CheckOpBlobSplitability(op, parallel_desc.parallel_num())); return op->GetOpAttributeWithoutOpNameAndLbn(); } bool JobBuildAndInferCtx::HasJobConf() const { return has_job_conf_; } Maybe JobBuildAndInferCtx::SetTrainConf(const TrainConf& train_conf) { *job_->mutable_job_conf()->mutable_train_conf() = train_conf; return Maybe::Ok(); } Maybe JobBuildAndInferCtx::AddLossLogicalBlobName(const std::string& lbn) { if (IsLocalBlob(lbn)) { return AddLossLocalBlobName(lbn); } return AddLossGlobalBlobName(lbn); } Maybe JobBuildAndInferCtx::AddLossGlobalBlobName(const std::string& lbn) { JUST(CheckLbnValidAndExist(lbn)); CHECK_OR_RETURN(job_->job_conf().has_train_conf()) << Error::UnknownJobBuildAndInferError() << "job has no TrainConf when adding loss logical blob name"; job_->mutable_job_conf()->mutable_train_conf()->add_loss_lbn(lbn); return Maybe::Ok(); } Maybe JobBuildAndInferCtx::MarkVariableGradientBlobNames( const HashMap& variable_grad_lbns) { CHECK_OR_RETURN(job_->job_conf().has_train_conf()) << Error::UnknownJobBuildAndInferError() << "job has no TrainConf when add variable gradient logical blob name"; auto* train_conf = job_->mutable_job_conf()->mutable_train_conf(); for (int i = 0; i < train_conf->optimizer_conf_size(); ++i) { auto* optimizer_conf = train_conf->mutable_optimizer_conf(i); for (const auto& variable_op_name : optimizer_conf->variable_op_names()) { const auto& it = variable_grad_lbns.find(variable_op_name + "/out"); if (it != variable_grad_lbns.end()) { optimizer_conf->add_variable_grad_lbns(it->second); } else { // add an empty gradient lbn for variable that has no gradient optimizer_conf->add_variable_grad_lbns(""); } } } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::MarkOutputGradientBlobNames( const HashMap& output_gradient_lbns) { CHECK_OR_RETURN(job_->job_conf().has_train_conf()) << Error::UnknownJobBuildAndInferError() << "job has no TrainConf when add variable gradient logical blob name"; auto* train_conf = job_->mutable_job_conf()->mutable_train_conf(); for (const auto& loss_lbn : train_conf->loss_lbn()) { const auto& it = output_gradient_lbns.find(loss_lbn); CHECK_OR_RETURN(it != output_gradient_lbns.end()) << Error::UnknownJobBuildAndInferError() << "gradient is missing for loss " << loss_lbn; train_conf->add_loss_grad_lbn(it->second); } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::GetStaticShape(const std::string& lbn) const { JUST(CheckLbnValidAndExist(lbn)); return lbi2logical_blob_desc_.at(GenLogicalBlobId(lbn))->shape(); } Maybe JobBuildAndInferCtx::GetDataType(const std::string& lbn) const { JUST(CheckLbnValidAndExist(lbn)); return lbi2logical_blob_desc_.at(GenLogicalBlobId(lbn))->data_type(); } Maybe JobBuildAndInferCtx::IsDynamic(const std::string& lbn) const { JUST(CheckLbnValidAndExist(lbn)); return lbi2logical_blob_desc_.at(GenLogicalBlobId(lbn))->is_dynamic(); } Maybe JobBuildAndInferCtx::IsDisableBoxing(const std::string& lbn) const { JUST(CheckLbnValidAndExist(lbn)); 
LogicalBlobId lbi(GenLogicalBlobId(lbn)); const auto& iter = lbi2disable_boxing_.find(lbi); CHECK_OR_RETURN(iter != lbi2disable_boxing_.end()); return iter->second; } Maybe JobBuildAndInferCtx::DisableBoxing(const std::string& lbn) { JUST(CheckLbnValidAndExist(lbn)); LogicalBlobId lbi(GenLogicalBlobId(lbn)); lbi2disable_boxing_[lbi] = true; return Maybe::Ok(); } Maybe JobBuildAndInferCtx::Op4OpName(const std::string& op_name) const { const auto& op_iter = op_name2op_.find(op_name); CHECK_OR_RETURN(op_iter != op_name2op_.end()); auto* op = op_iter->second.get(); CHECK_NOTNULL_OR_RETURN(op); return op; } Maybe JobBuildAndInferCtx::GetSplitAxisFromProducerView(const std::string& lbn) const { JUST(CheckLbnValidAndExist(lbn)); OptInt64 ret; const auto& nd_sbp = lbi2nd_sbp_from_producer_view_.at(GenLogicalBlobId(lbn)); CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), 1); const auto& sbp = nd_sbp.sbp_parallel(0); if (sbp.has_split_parallel()) { ret.set_value(sbp.split_parallel().axis()); } return ret; } Maybe JobBuildAndInferCtx::GetParallelDescFromProducerView( const std::string& lbn) const { JUST(CheckLbnValidAndExist(lbn)); return &(lbi2parallel_desc_from_producer_view_.at(GenLogicalBlobId(lbn))); } Maybe JobBuildAndInferCtx::AddLossLocalBlobName(const std::string& lbn) { const auto& local_lbi = JUST(GetLocalLbi(lbn)); CHECK_OR_RETURN(job_->job_conf().has_train_conf()) << Error::UnknownJobBuildAndInferError() << "job has no TrainConf when adding loss logical blob name"; for (const auto& lbi : local_lbi2sub_lbis_[*local_lbi]) { job_->mutable_job_conf()->mutable_train_conf()->add_loss_lbn(GenLogicalBlobName(lbi)); } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::GetLocalLbi(const std::string& lbn_with_hint) const { const LogicalBlobId& lbi = GenLogicalBlobId(lbn_with_hint); if (local_lbi2sub_lbis_.find(lbi) != local_lbi2sub_lbis_.end()) { return lbi; } return Error::CheckFailedError() << lbn_with_hint << " is not a local blob name"; } Maybe JobBuildAndInferCtx::LocalBlobGetNumSubLbi(const std::string& lbn_with_hint) const { const auto& local_lbi = JUST(GetLocalLbi(lbn_with_hint)); return local_lbi2sub_lbis_.at(*local_lbi).size(); // NOLINT } Maybe JobBuildAndInferCtx::LocalBlobGetSubLbi( const std::string& lbn_with_hint, int index) const { const auto& local_lbi = JUST(GetLocalLbi(lbn_with_hint)); const auto& vec = local_lbi2sub_lbis_.at(*local_lbi); // NOLINT CHECK_GE_OR_RETURN(index, 0); CHECK_LT_OR_RETURN(index, vec.size()); return &vec.at(index); } bool JobBuildAndInferCtx::IsLocalBlob(const std::string& lbn) const { bool is_local_blob = TRY(GetLocalLbi(lbn)).IsOk(); if (is_local_blob) { return is_local_blob; } const LogicalBlobId& lbi = GenLogicalBlobId(lbn); CHECK(lbi2logical_blob_desc_.find(lbi) != lbi2logical_blob_desc_.end()) << "lbn: " << lbn; return false; } Maybe JobBuildAndInferCtx::LocalBlobGetStaticShape(const std::string& lbn_with_hint) const { const auto& lbi = *JUST(LocalBlobGetSubLbi(lbn_with_hint, 0)); return lbi2logical_blob_desc_.at(lbi)->shape(); } Maybe JobBuildAndInferCtx::LocalBlobGetDataType(const std::string& lbn_with_hint) const { const auto& lbi = *JUST(LocalBlobGetSubLbi(lbn_with_hint, 0)); return lbi2logical_blob_desc_.at(lbi)->data_type(); } Maybe JobBuildAndInferCtx::LocalBlobIsDynamic(const std::string& lbn_with_hint) const { const auto& lbi = *JUST(LocalBlobGetSubLbi(lbn_with_hint, 0)); return lbi2logical_blob_desc_.at(lbi)->is_dynamic(); } Maybe JobBuildAndInferCtx::LocalBlobGetSplitAxisFromProducerView( const std::string& lbn_with_hint) const { const auto& 
lbi = *JUST(LocalBlobGetSubLbi(lbn_with_hint, 0)); OptInt64 ret; const auto& nd_sbp = lbi2nd_sbp_from_producer_view_.at(lbi); CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), 1); const auto& sbp = nd_sbp.sbp_parallel(0); if (sbp.has_split_parallel()) { ret.set_value(sbp.split_parallel().axis()); } return ret; } Maybe JobBuildAndInferCtx::LocalBlobGetParallelDescFromProducerView( const std::string& lbn_with_hint) const { const auto& lbi = JUST(GetLocalLbi(lbn_with_hint)); return &(local_lbi2parallel_desc_.at(*lbi)); // NOLINT } Maybe JobBuildAndInferCtx::CheckJob() const { JUST(CheckPlacement()); JUST(CheckJobConf()); JUST(CheckOpScope()); return Maybe::Ok(); } Maybe JobBuildAndInferCtx::CheckPlacement() const { HashSet op_names_in_net; HashSet op_names_in_placement; for (const OperatorConf& op_conf : job_->net().op()) { CHECK_OR_RETURN(op_names_in_net.insert(op_conf.name()).second) << Error::OpNameExistError() << "op_name: " << op_conf.name() << " already exist in job: " << job_->job_conf().job_name() << " net"; } for (const PlacementGroup& placement_group : job_->placement().placement_group()) { for (const std::string& op_name : placement_group.op_set().op_name()) { CHECK_OR_RETURN(op_names_in_placement.insert(op_name).second) << Error::OpNameExistError() << "op_name: " << op_name << " already exist in job: " << job_->job_conf().job_name() << " placement"; } } CHECK_EQ_OR_RETURN(op_names_in_net.size(), op_names_in_placement.size()) << Error::PlacementError() << "job: " << job_->job_conf().job_name() << " op number not equal between net and placement"; for (const std::string& op_name : op_names_in_net) { CHECK_OR_RETURN(op_names_in_placement.find(op_name) != op_names_in_placement.end()) << Error::PlacementError() << "job: " << job_->job_conf().job_name() << " op_name: " << op_name << " defined in net cannot find its placement"; } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::CheckJobConf() const { if (job_->job_conf().job_type_case() == JobConfigProto::JOB_TYPE_NOT_SET) { return Error::JobTypeNotSetError() << "job_type not set, please set predict_conf or train_conf"; } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::CheckOpScope() const { for (const OperatorConf& op_conf : job_->net().op()) { if (!op_conf.has_scope_symbol_id()) { // NOTE(chengcheng): LOG(WARNING) instead of CHECK_OR_RETURN() for transition LOG(WARNING) << " ERROR! op_name: " << op_conf.name() << " has NOT set scope(scope_symbol_id) in job: " << job_->job_conf().job_name() << " net. 
\n op_conf = " << op_conf.DebugString(); } } return Maybe::Ok(); } Maybe JobBuildAndInferCtx::CheckLbnValidAndExist(const std::string& lbn) const { CHECK_OR_RETURN(lbn.find('/') != std::string::npos) << Error::LogicalBlobNameInvalidError() << "lbn:" << lbn; LogicalBlobId lbi = GenLogicalBlobId(lbn); #define CHECK_HAS_LBI_KEY(info_src) \ CHECK_OR_RETURN(info_src.find(lbi) != info_src.end()) \ << Error::LogicalBlobNameNotExistError() << "lbn:" << lbn; CHECK_HAS_LBI_KEY(lbi2logical_blob_desc_); CHECK_HAS_LBI_KEY(lbi2nd_sbp_from_producer_view_); CHECK_HAS_LBI_KEY(lbi2parallel_desc_from_producer_view_); #undef CHECK_HAS_LBI_KEY return Maybe::Ok(); } const Job& JobBuildAndInferCtx::job() const { return *job_; } std::string LazyJobBuildAndInferCtx::GetLocalOpName(const std::string& op_name, int64_t parallel_id) const { return op_name + "_" + std::to_string(parallel_id); } ParallelConf LazyJobBuildAndInferCtx::GetLocalOpParallelConf(const ParallelDesc& parallel_desc, int64_t parallel_id) const { return parallel_desc.GetParallelIdOnlyParallelConf(parallel_id); } Maybe LazyJobBuildAndInferCtx::FindOrCreateLocalLbiFromCompatibleGlobalBlob( int64_t scope_symbol_id, const LogicalBlobId& lbi) { const std::string& lbn = GenLogicalBlobName(lbi); const auto& sbn_it = mut_global_lbi2local_lbi()->find(lbi); if (sbn_it != mut_global_lbi2local_lbi()->end()) { return sbn_it->second; } const SbpParallel& sbp = *JUST(SbpParallel4Lbi(lbi)); const ParallelDesc& parallel_desc = *JUST(ParallelDesc4Lbi(lbi)); LogicalBlobId local_lbi; local_lbi.set_op_name(kAutoLocalBlobNamePrefix + NewUniqueId()); local_lbi.set_blob_name("out"); (*mut_global_lbi2local_lbi())[lbi] = local_lbi; auto* lbi_vec = &(*mut_local_lbi2sub_lbis())[local_lbi]; lbi_vec->reserve(parallel_desc.parallel_num()); auto PushBackSubLbi = [&](const std::string& op_name, const std::string& blob_name) { LogicalBlobId sub_lbi; sub_lbi.set_op_name(op_name); sub_lbi.set_blob_name(blob_name); lbi_vec->emplace_back(sub_lbi); }; OperatorConf op_conf; op_conf.set_scope_symbol_id(scope_symbol_id); op_conf.set_device_tag(*JUST(DeviceTag4DeviceType(parallel_desc.device_type()))); if (sbp.has_broadcast_parallel()) { op_conf.set_name(kAutoLocalBlobNamePrefix + "-DistributeClone-" + NewUniqueId()); auto* distribute_clone = op_conf.mutable_distribute_clone_conf(); distribute_clone->set_in(lbn); FOR_RANGE(int32_t, i, 0, parallel_desc.parallel_num()) { const std::string& blob_name = "out_" + std::to_string(i); distribute_clone->add_out(blob_name); distribute_clone->set_is_variable_ref(IsVariableLbi(lbi)); PushBackSubLbi(op_conf.name(), blob_name); } } else if (sbp.has_split_parallel()) { CHECK_EQ_OR_RETURN(sbp.split_parallel().axis(), 0) << "only `S(0)' global blob is compatible to local blob"; op_conf.set_name(kAutoLocalBlobNamePrefix + "-DistributeSplit-" + NewUniqueId()); auto* distribute_split = op_conf.mutable_distribute_split_conf(); distribute_split->set_in(lbn); distribute_split->set_axis(0); distribute_split->set_is_variable_ref(IsVariableLbi(lbi)); FOR_RANGE(int32_t, i, 0, parallel_desc.parallel_num()) { const std::string& blob_name = "out_" + std::to_string(i); distribute_split->add_out(blob_name); PushBackSubLbi(op_conf.name(), blob_name); } } else { OF_UNIMPLEMENTED() << "`P' global blob is not compatible to local blob"; } { const auto& producer_op_conf = JUST(Op4OpName(lbi.op_name()))->op_conf(); CHECK_OR_RETURN(producer_op_conf.has_scope_symbol_id()); const auto& scope = Singleton>::Get()->Get(scope_symbol_id); const auto* job_desc = JUST(scope.job_desc()); 
JUST(AddAndInferOp(op_conf, parallel_desc.parallel_conf(), job_desc, false)); } return local_lbi; } Maybe LazyJobBuildAndInferCtx::Complete() { CHECK_GT_OR_RETURN(job().net().op_size(), 0) << " Sorry, nn.Graph need at least 1 op in net, but get 0 now."; auto compile_tc = std::make_unique>(true, true); CHECK_NOTNULL(Singleton::Get()); // A global variable to get graph configurations. auto current_graph_config = std::make_unique(mut_job()->job_conf(), job_id()); JobPassCtx job_pass_ctx(GlobalJobDesc()); const auto job_name = job().job_conf().job_name(); auto LogJob = [&](const std::string& name_suffix) -> void { std::string full_log_name = job_name + "-job_id_" + std::to_string(job_id()) + "-" + name_suffix; TeePersistentLogStream::Create(full_log_name)->Write(job()); Singleton::New(job()); Singleton::Get()->ToDotWithFilePath(full_log_name + ".dot"); Singleton::Delete(); }; std::string debug_pass_name = GetStringFromEnv("ONEFLOW_DEBUG_PASS", ""); auto NeedLogJob = [&](const std::string& pass_name) -> bool { if ("ALL" == debug_pass_name) { return true; } else if (pass_name == debug_pass_name) { return true; } else { return false; } }; int32_t pass_cnt = 0; const int64_t prev_v = FLAGS_v; auto DoPass = [&](const std::string& pass_name, int32_t cnt = 0) -> Maybe { auto pass_tc = std::make_unique>(true, true); VLOG(1) << job_name << " start compiling with pass" << " pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name << (cnt > 0 ? std::to_string(cnt) : ""); if (unlikely(NeedLogJob(pass_name))) { std::string cnt_str = cnt > 0 ? std::to_string(cnt) : ""; LogJob("pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + cnt_str + "-before"); FLAGS_v = 3; } JUST(JobPass4Name(pass_name)(mut_job(), &job_pass_ctx)); if (unlikely(NeedLogJob(pass_name))) { FLAGS_v = prev_v; std::string cnt_str = cnt > 0 ? std::to_string(cnt) : ""; LogJob("pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + cnt_str + "-after"); } VLOG(1) << job_name << " finish compiling with pass" << " pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name << (cnt > 0 ? 
std::to_string(cnt) : ""); pass_tc->Count("[GraphCompile]" + job_name + " " + pass_name, 1, true); ++pass_cnt; return Maybe::Ok(); }; if (Singleton::Get()->enable_debug_mode()) { TeePersistentLogStream::Create(StrCat("forward_graph", job_id()))->Write(job()); Singleton::New(job()); Singleton::Get()->ToDotWithFilePath("forward_dlnet_" + std::to_string(job_id()) + "_op_graph.dot"); Singleton::Delete(); } if (GlobalJobDesc().Bool("__is_user_function__")) { // insert pinned identity to prevent the loss, loss initial gradient and // variable gradient from being eliminated by IRRoundTripBeforeAD pass JUST(DoPass("InsertPinnedIdentityOpPass")); // prune the dangling constant which are the 0 gradients initialized by // the autograd engine for those tensors that have no gradients JUST(DoPass("EliminateDeadNodesPass")); JUST(DoPass("NormalizationExponentialAverageAutoTickPass")); JUST(DoPass("AutoMixedPrecision")); // prune depend OP and and add ctrl_in_op to op_conf accordingly // to express the same semantics and avoid performance loss JUST(DoPass("PruneDependOpPass")); JUST(DoPass("PruneAmpWhiteIdentityOpPass")); JUST(DoPass("OptimizerPlacementOptimizationPass")); // run FuseAddToOutputPass before IRRoundTripBeforeAD since add_2 maybe // fused as add_n in IRRoundTripBeforeAD pass JUST(DoPass("FuseAddToOutputPass")); #ifdef WITH_MLIR JUST(DoPass("IRRoundTripBeforeAD")); #endif // WITH_MLIR // run DynamicLossScaleSchedulePass, AutoTrainStep and AutoLearningRate // after IRRoundTripBeforeAD since IRRoundTripBeforeAD will do DCE // optimization which could eliminate the nodes inserted by them JUST(DoPass("DynamicLossScaleSchedulePass")); JUST(DoPass("AutoTrainStep")); JUST(DoPass("AutoLearningRate")); JUST(DoPass("QuantAwareTraining")); JUST(DoPass("GenerateOptimizerOpConfs")); // pinned identity can be pruned since GenerateOptimizerOpConfs pass has // already construct a complete computational graph JUST(DoPass("PrunePinnedIdentityOpPass")); JUST(DoPass("ReplaceEmbeddingOps")); JUST(DoPass("SequentialOneEmbeddingOpsPass")); JUST(DoPass("FuseEmbeddingShuffleInteractionPass")); JUST(DoPass("FuseBCEReduceMeanFwBwPass")); JUST(DoPass("AddSspVariableProxy")); JUST(DoPass("CheckpointingPass")); JUST(DoPass("CudnnFusedNormalizationAddReluPass")); JUST(DoPass("PruneCastToStaticShapeOpsPass")); #ifdef WITH_MLIR JUST(DoPass("IRRoundTrip")); #endif // WITH_MLIR // run this pass again to fuse ops created in the first run. 
// TODO(guoran): loop multiple times inside the pass JUST(DoPass("FuseAddToOutputPass", 1)); JUST(DoPass("FuseConsecutiveAddPass")); JUST(DoPass("IndexedSlicesOptimizerRewritePass")); JUST(DoPass("SplitSparseSoftmaxCrossEntropyOpPass")); JUST(DoPass("DoParallelCastBeforeWideningTypeCast")); JUST(DoPass("FuseCastScalePass")); JUST(DoPass("PruneParallelCastOpsPass")); JUST(DoPass("FuseUpdateOpsPass")); JUST(DoPass("FuseModelUpdateCastOpsPass")); JUST(DoPass("MultiTensorModelUpdatePass")); JUST(DoPass("FixPipelineStageIdPass")); JUST(DoPass("PipelineBufferPass")); JUST(DoPass("AutoParallelPass")); JUST(DoPass("DelayVariableOpExecutionPass")); #ifdef WITH_CUTLASS JUST(DoPass("CutlassConvTuningWarmupPass")); #endif // WITH_CUTLASS JUST(DoPass("DumpVariableInfoPass")); } JUST(DoPass("DumpBlobParallelConfPass")); JUST(CheckJob()); compile_tc->Count("[GraphCompile]" + job_name + " OptimizationLogicalGraph", 0); return Maybe::Ok(); } namespace { std::string OpConf2ClassName(const OperatorConf& op_conf) { if (op_conf.has_user_conf()) { return op_conf.user_conf().op_type_name(); } else if (op_conf.has_variable_conf()) { return "variable"; } else if (op_conf.has_input_conf() && op_conf.has_return_conf()) { return "input"; } else if (op_conf.has_output_conf() && op_conf.has_return_conf()) { return "output"; } else { return "system_op"; } } void FormateUserConf(nlohmann::json& json_conf) { nlohmann::json user_conf = json_conf["user_conf"]; if (user_conf.is_null()) { json_conf.erase(json_conf.find("user_conf")); return; } std::string nomarl_array[] = {"at_int32", "at_int64", "at_bool", "at_float", "at_double", "at_string", "at_shape", "at_stride", "at_data_type"}; std::string list_array[] = {"at_list_int32", "at_list_int64", "at_list_float", "at_list_data_type", "at_list_shape", "at_list_stride", "at_list_string"}; nlohmann::json attr_json = user_conf["attr"]; for (int32_t i = 0; i < attr_json.size(); i++) { std::string key = attr_json[i]["key"]; nlohmann::json value_json = attr_json[i]["value"]; bool is_found_normal = false; for (int32_t j = 0; j < nomarl_array->length(); j++) { std::string value_key = nomarl_array[j]; if (value_json.contains(value_key)) { is_found_normal = true; if ("at_shape" == value_key || "at_stride" == value_key) { json_conf[key] = value_json[value_key]["dim"]; } else { json_conf[key] = value_json[value_key]; } break; } } if (is_found_normal) { continue; } for (int32_t j = 0; j < list_array->length(); j++) { std::string value_key = list_array[j]; if (value_json.contains(value_key)) { if (value_json[value_key].contains("val")) { json_conf[key] = value_json[value_key]["val"]; break; } else if (value_json[value_key].contains("dim")) { json_conf[key] = value_json[value_key]["dim"]; break; } } } } json_conf.erase(json_conf.find("user_conf")); } void FormateVariableConf(nlohmann::json& json_conf) { nlohmann::json variable_conf = json_conf["variable_conf"]; if (variable_conf == nullptr) { json_conf.erase(json_conf.find("variable_conf")); return; } for (nlohmann::json::iterator it = variable_conf.begin(); it != variable_conf.end(); ++it) { std::string key = it.key(); if ("shape" == key) { json_conf[key] = it.value()["dim"]; } else { json_conf[key] = it.value(); } } json_conf.erase(json_conf.find("variable_conf")); } } // namespace std::string oneflow::JobBuildAndInferCtx::GetJobStructureGraphJson( const std::string& job_name) const { HashSet inputs_op_names; HashSet outputs_op_names; std::vector layers_vec; layers_vec.reserve(op_name2op_.size()); for (const auto& pair : op_name2op_) { 
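// Sketch of one layers_vec entry built by this loop (op names are hypothetical,
// not taken from the source):
//   {
//     "name": "conv2d-3",
//     "class_name": "conv2d",                  // from OpConf2ClassName()
//     "config": { ...flattened user/variable conf... },
//     "inbound_nodes": ["input.0.0_2"]
//   }
// The entries are later wrapped into {"name": job_name, "layers": [...],
// "input_layers": [...], "output_layers": [...]}.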
nlohmann::json json_layers_pair; const Operator* op = pair.second.get(); const std::string& op_name = pair.first; HashSet inbound_nodes; for (const auto& ibn : op->input_bns()) { const LogicalBlobId& lbi = op->BnInOp2Lbi(ibn); if (op_name2op_.find(lbi.op_name()) != op_name2op_.end()) { inbound_nodes.insert(lbi.op_name()); } } if (op->op_conf().has_input_conf() && op->op_conf().has_return_conf()) { inputs_op_names.insert(op_name); } if (op->op_conf().has_output_conf() && op->op_conf().has_return_conf()) { outputs_op_names.insert(op_name); } json_layers_pair["name"] = op_name; std::string class_name = OpConf2ClassName(op->op_conf()); json_layers_pair["class_name"] = class_name; nlohmann::json json_conf; summary::ConvertProtobufMsg2Json(json_conf, op->op_conf()); FormateUserConf(json_conf); FormateVariableConf(json_conf); json_layers_pair["config"] = json_conf; std::vector inbound_nodes_vec; inbound_nodes_vec.reserve(inbound_nodes.size()); for (const auto& in_node_name : inbound_nodes) { inbound_nodes_vec.emplace_back(in_node_name); } json_layers_pair["inbound_nodes"] = inbound_nodes_vec; layers_vec.emplace_back(json_layers_pair); } nlohmann::json json_pair; json_pair["name"] = job_name; json_pair["layers"] = layers_vec; json_pair["input_layers"] = inputs_op_names; json_pair["output_layers"] = outputs_op_names; return json_pair.dump(); } Maybe JobBuildAndInferCtx::Rebuild() { // clear old state lbi2logical_blob_desc_.clear(); lbi2nd_sbp_from_producer_view_.clear(); lbi2parallel_desc_from_producer_view_.clear(); lbi2disable_boxing_.clear(); op_name2op_.clear(); parallel_desc2placement_group_.clear(); parallel_desc2blob_placement_group_.clear(); global_lbi2local_lbi_.clear(); local_lbi2sub_lbis_.clear(); local_lbi2parallel_desc_.clear(); local_lbi2sbp_parallel_.clear(); op_name2ancestors_need_no_grad_.clear(); // record op mirror view HashMap op_name2is_local; CHECK_OR_RETURN(job_->has_job_parallel_view_conf()); for (const auto& op_conf : job_->net().op()) { const auto& op_name = op_conf.name(); CHECK_OR_RETURN(op_name2is_local.find(op_name) == op_name2is_local.end()); // NOLINT op_name2is_local[op_name] = false; const auto& op_name2is_local_parallel_view = job_->job_parallel_view_conf().op_name2is_local_parallel_view(); if (op_name2is_local_parallel_view.find(op_name) != op_name2is_local_parallel_view.end()) { if (op_name2is_local_parallel_view.at(op_name)) { op_name2is_local[op_name] = true; } } } // build op graph OpGraph op_graph; if (Singleton::Get()) { JUST(op_graph.Init(*job_)); } else { auto scope = std::make_unique(job_->job_conf(), job_id()); JUST(op_graph.Init(*job_)); } // clear old job except job_conf job_->mutable_net()->Clear(); job_->mutable_placement()->Clear(); job_->mutable_job_parallel_view_conf()->Clear(); job_->mutable_helper()->Clear(); // topo traverse op_graph to AddAndInferOp op_graph.TopoForEachNode([&](OpNode* node) -> void { const auto& op_conf = node->op().op_conf(); CHECK(op_name2is_local.find(op_conf.name()) != op_name2is_local.end()); bool is_local = op_name2is_local.at(op_conf.name()); if (is_local) { CHECK_JUST(AddAndInferLocalOp(op_conf)); } else { CHECK_JUST(AddAndInferGlobalOp(op_conf)); } }); // updata job_helper op_graph.DumpLogicalBlobDesc(job_); op_graph.DumpNdSbpSignature(job_); return Maybe::Ok(); } Maybe JobBuildAndInferCtx::GetOpBlobLbn(const std::string& op_name, const std::string& bn_in_op) const { const auto& lbi = JUST(Op4OpName(op_name))->BnInOp2Lbi(bn_in_op); return GenLogicalBlobName(lbi); } Maybe 
JobBuildAndInferCtx::NewUniqueOpNameByFunctionalOpConf( const OperatorConf& op_conf) { // NOTE(chengcheng): arg op_conf has a default global op_name because it is created by // static functional op expr, so we need reset a unique op name for each functional op. // This op_conf can NOT be a input/output/variable op which has set correct name in nn.Graph. // But free eager tensor is treated as a special variable which needs to create name here. CHECK_OR_RETURN(!(op_conf.has_input_conf() || op_conf.has_output_conf())); const auto& scope = JUST(GetCurrentScope()); std::string op_name_prefix; for (const std::string& prefix : scope->scope_proto().scope_op_name_prefixes()) { op_name_prefix += (prefix + "-"); } std::string op_type_name; if (op_conf.has_user_conf()) { op_type_name = op_conf.user_conf().op_type_name(); } else if (op_conf.has_variable_conf()) { // NOTE(chengcheng): To support Free Eager Tensor caught by nn.Graph op_type_name = "FreeEagerTensor"; } else { op_type_name = "SystemOp"; } std::string op_name = op_name_prefix + op_type_name + "-" + std::to_string(unique_op_name_index_); ++unique_op_name_index_; return op_name; } } // namespace oneflow ================================================ FILE: oneflow/core/job/job_build_and_infer_ctx.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CTX_H_ #define ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CTX_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/register/blob_desc.h" namespace oneflow { class JobBuildAndInferCtx { public: OF_DISALLOW_COPY_AND_MOVE(JobBuildAndInferCtx); JobBuildAndInferCtx(Job* job, int64_t job_id); virtual ~JobBuildAndInferCtx() = default; Maybe SetJobConf(const JobConfigProto& job_conf); Maybe AddAndInferGlobalOp(const OperatorConf& op_conf); Maybe AddAndInferLocalOp(const OperatorConf& op_conf); Maybe AddLossLogicalBlobName(const std::string& lbn); Maybe SetTrainConf(const TrainConf& train_conf); Maybe MarkVariableGradientBlobNames( const HashMap& variable_grad_lbns); Maybe MarkOutputGradientBlobNames( const HashMap& output_gradient_lbns); bool HasJobConf() const; Maybe GetStaticShape(const std::string& lbn) const; Maybe GetDataType(const std::string& lbn) const; Maybe IsDynamic(const std::string& lbn) const; Maybe IsDisableBoxing(const std::string& lbn) const; Maybe DisableBoxing(const std::string& lbn); Maybe GetSplitAxisFromProducerView(const std::string& lbn) const; Maybe GetParallelDescFromProducerView(const std::string& lbn) const; bool IsLocalBlob(const std::string& lbn) const; Maybe LocalBlobGetNumSubLbi(const std::string& lbn) const; Maybe LocalBlobGetSubLbi(const std::string& lbn, int index) const; Maybe LocalBlobGetStaticShape(const std::string& lbn_with_hint) const; Maybe LocalBlobGetDataType(const std::string& lbn_with_hint) const; Maybe LocalBlobIsDynamic(const std::string& lbn_with_hint) const; Maybe LocalBlobGetSplitAxisFromProducerView(const std::string& lbn_with_hint) const; Maybe LocalBlobGetParallelDescFromProducerView( const std::string& lbn_with_hint) const; const Job& job() const; int64_t job_id() const { return job_id_; } Maybe CheckJob() const; std::string GetJobStructureGraphJson(const std::string& job_name) const; Maybe CheckLbnValidAndExist(const std::string& lbn) const; Maybe Rebuild(); Maybe GetOpBlobLbn(const std::string& op_name, const std::string& bn_in_op) const; // NOTE(chengcheng): Only used in multi-client. 
Maybe NewUniqueOpNameByFunctionalOpConf(const OperatorConf& op_conf); Maybe Op4OpName(const std::string& op_name) const; virtual Maybe Complete() = 0; protected: virtual Maybe CheckAllInputsWithSameParallelNum(const Operator& op, int32_t parallel_num) const = 0; virtual std::string GetLocalOpName(const std::string& op_name, int64_t parallel_id) const = 0; virtual int64_t SizeOfSubGlobalOpList(int64_t parallel_num) const = 0; virtual ParallelConf GetLocalOpParallelConf(const ParallelDesc&, int64_t parallel_id) const = 0; virtual bool GetIsLocalParallelView() const = 0; virtual Maybe FindOrCreateLocalLbiFromCompatibleGlobalBlob( int64_t scope_symbol_id, const LogicalBlobId& lbn) = 0; Job* mut_job() const { return job_; } const HashMap>& local_lbi2sub_lbis() const { return local_lbi2sub_lbis_; } HashMap>* mut_local_lbi2sub_lbis() { return &local_lbi2sub_lbis_; } Maybe ParallelDesc4Lbi(const LogicalBlobId& lbi) const; HashMap* mut_global_lbi2local_lbi() { return &global_lbi2local_lbi_; } Maybe SbpParallel4Lbi(const LogicalBlobId& lbi) const; bool IsVariableLbi(const LogicalBlobId& lbi) const; Maybe AddAndInferOp(const OperatorConf& op_conf, const ParallelConf& parallel_conf, const JobDesc* job_desc, bool is_local_parallel_view); private: Maybe InferOpParallelConf( const Operator& op, const ParallelConf& origin_parallel_conf, const HashMap& ibn2disable_boxing) const; Maybe AddOpNameParallelConf2Placement(const std::string& op_name, const ParallelConf& parallel_conf); void InitIbn2DisableBoxing(const Operator& op, HashMap* ibn2disable_boxing); Maybe InitConstraitNdSbpSignature( const Operator& op, const HashMap& ibn2disable_boxing) const; Maybe DecodeLbiHintAndReturnNewOpConf(const Operator& op, SbpSignature* sbp_sig_conf) const; Maybe AddLbiParallelConf2BlobPlacement( const Operator* op, std::function ParallelDesc4Obn); void AddOpAndUpdateJobParallelViewConf(const OperatorConf& operator_conf, const ParallelDesc& parallel_desc, const NdSbpSignature& nd_sbp_signature, bool is_local_parallel_view) const; Maybe InferLocalSignature(Operator*, bool is_local_parallel_view_conf, const ParallelDesc&); Maybe InferOpOutNdSbp(Operator*, const NdSbpSignature&, const ParallelDesc&); Maybe GenOpProducedEmptyLogicalBlobDesc(Operator* op); Maybe CheckOpBlobSplitability(Operator*, int64_t parallel_num); Maybe CheckPlacement() const; Maybe CheckJobConf() const; Maybe CheckOpScope() const; Maybe GetLocalLbi(const std::string& lbn_with_hint) const; bool HasAnyLocalBlobInput(const Operator& op) const; Maybe CheckAllInputsConvertableToLocalBlob(const Operator& op) const; Maybe AddLossGlobalBlobName(const std::string& lbn); Maybe AddLossLocalBlobName(const std::string& lbn); Maybe GetSubLbi(int64_t scope_symbol_id, const LogicalBlobId& lbi, int32_t index); Maybe AllInputsBroadcastParallel(const Operator& op) const; Job* job_; int64_t job_id_; HashMap> lbi2logical_blob_desc_; HashMap lbi2nd_sbp_from_producer_view_; HashMap lbi2parallel_desc_from_producer_view_; HashMap lbi2disable_boxing_; HashMap> op_name2op_; HashMap parallel_desc2placement_group_; HashMap parallel_desc2blob_placement_group_; HashMap global_lbi2local_lbi_; HashMap> local_lbi2sub_lbis_; HashMap local_lbi2parallel_desc_; HashMap local_lbi2sbp_parallel_; bool is_job_conf_frozen_; bool has_job_conf_; HashMap op_name2ancestors_need_no_grad_; int64_t unique_op_name_index_; }; class LazyJobBuildAndInferCtx : public JobBuildAndInferCtx { public: OF_DISALLOW_COPY_AND_MOVE(LazyJobBuildAndInferCtx); LazyJobBuildAndInferCtx(Job* job, int64_t job_id) : 
JobBuildAndInferCtx(job, job_id) {} virtual ~LazyJobBuildAndInferCtx() = default; private: Maybe Complete() override; Maybe CheckAllInputsWithSameParallelNum(const Operator& op, int32_t parallel_num) const override; std::string GetLocalOpName(const std::string& op_name, int64_t parallel_id) const override; int64_t SizeOfSubGlobalOpList(int64_t parallel_num) const override { return parallel_num; } ParallelConf GetLocalOpParallelConf(const ParallelDesc&, int64_t parallel_id) const override; bool GetIsLocalParallelView() const override { return false; } Maybe FindOrCreateLocalLbiFromCompatibleGlobalBlob( int64_t scope_symbol_id, const LogicalBlobId& lbn) override; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CTX_H_ ================================================ FILE: oneflow/core/job/job_build_and_infer_ctx_mgr.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/id_state.h" #include "oneflow/core/job/lazy_mode.h" #include "nlohmann/json.hpp" namespace oneflow { Maybe JobBuildAndInferCtxMgr::OpenJobBuildAndInferCtx(const std::string& job_name) { CHECK_OR_RETURN(!has_cur_job_) << Error::UnknownJobBuildAndInferError() << "cur job not leave before you enter this job_name:" << job_name; CHECK_OR_RETURN(!job_name.empty()) << Error::JobNameEmptyError(); CHECK_OR_RETURN(job_name2infer_ctx_.find(job_name) == job_name2infer_ctx_.end()) << Error::JobNameExistError() << "job name: " << job_name << " already exist"; int64_t job_id = job_id_count_++; Job* job = job_set_.add_job(); job->mutable_job_conf()->set_job_name(job_name); std::unique_ptr ctx(NewJobBuildAndInferCtx(job, job_id)); job_name2infer_ctx_.emplace(job_name, std::move(ctx)); cur_job_name_ = job_name; has_cur_job_ = true; return Maybe::Ok(); } JobBuildAndInferCtx* LazyJobBuildAndInferCtxMgr::NewJobBuildAndInferCtx(Job* job, int64_t job_id) const { return new LazyJobBuildAndInferCtx(job, job_id); } Maybe JobBuildAndInferCtxMgr::FindJobBuildAndInferCtx( const std::string& job_name) { CHECK_OR_RETURN(job_name2infer_ctx_.find(job_name) != job_name2infer_ctx_.end()) << Error::NoJobBuildAndInferCtxError() << "cannot find job name:" << job_name; return job_name2infer_ctx_.at(job_name).get(); } Maybe JobBuildAndInferCtxMgr::GetCurrentJobName() const { CHECK_OR_RETURN(has_cur_job_) << Error::NoJobBuildAndInferCtxError() << "current JobBuildAndInferCtx was closed, job name: " << cur_job_name_; return cur_job_name_; } Maybe JobBuildAndInferCtxMgr::CloseCurrentJobBuildAndInferCtx() { OF_RETURN_IF_ERROR(VirtualCloseJob()); has_cur_job_ = false; return Maybe::Ok(); } std::string JobBuildAndInferCtxMgr::structure_graph() const { nlohmann::json json_array; for (const auto& pair : 
job_name2infer_ctx_) { nlohmann::json json_pair; json_pair["class_name"] = "Model"; std::string tmp_json = pair.second->GetJobStructureGraphJson(pair.first); json_pair["config"] = nlohmann::json::parse(tmp_json); json_pair["backend"] = "oneflow"; json_array.emplace_back(json_pair); } return json_array.dump(); } void JobBuildAndInferCtxMgr::TryUpdateJobIdCount(int64_t id_count) { job_id_count_ = std::max(id_count, job_id_count_); } int64_t JobBuildAndInferCtxMgr::GetJobIdCount() const { return job_id_count_; } Maybe LazyJobBuildAndInferCtxMgr::VirtualCloseJob() { const JobDesc* job_desc = Singleton::Get(); if (job_desc == nullptr) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(job_desc->job_name(), *JUST(GetCurrentJobName())); Singleton::Delete(); return Maybe::Ok(); } Maybe GlobalJobBuildAndInferCtxMgr() { return JUST(SingletonMaybe()); } Maybe GetJobBuildAndInferCtx(const std::string& job_name) { auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr()); return mgr->FindJobBuildAndInferCtx(job_name); } Maybe GetCurInferCtx() { auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr()); return mgr->FindJobBuildAndInferCtx(*JUST(mgr->GetCurrentJobName())); } } // namespace oneflow ================================================ FILE: oneflow/core/job/job_build_and_infer_ctx_mgr.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CXT_MGR_H_ #define ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CXT_MGR_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/job_build_and_infer_ctx.h" namespace oneflow { class JobBuildAndInferCtxMgr { public: OF_DISALLOW_COPY_AND_MOVE(JobBuildAndInferCtxMgr); virtual ~JobBuildAndInferCtxMgr() = default; Maybe OpenJobBuildAndInferCtx(const std::string& job_name); Maybe FindJobBuildAndInferCtx(const std::string& job_name); Maybe GetCurrentJobName() const; Maybe CloseCurrentJobBuildAndInferCtx(); const JobSet& job_set() const { return job_set_; } std::string structure_graph() const; void TryUpdateJobIdCount(int64_t id_count); int64_t GetJobIdCount() const; protected: virtual JobBuildAndInferCtx* NewJobBuildAndInferCtx(Job* job, int64_t job_id) const = 0; JobBuildAndInferCtxMgr() : has_cur_job_(false) {} virtual Maybe VirtualCloseJob() = 0; JobSet* mut_job_set() { return &job_set_; } void clear_job_name2infer_ctx() { job_name2infer_ctx_.clear(); } private: JobSet job_set_; int64_t job_id_count_{0}; bool has_cur_job_; std::string cur_job_name_; HashMap> job_name2infer_ctx_; }; class LazyJobBuildAndInferCtxMgr : public JobBuildAndInferCtxMgr { public: OF_DISALLOW_COPY_AND_MOVE(LazyJobBuildAndInferCtxMgr); LazyJobBuildAndInferCtxMgr() : JobBuildAndInferCtxMgr() {} ~LazyJobBuildAndInferCtxMgr() override = default; private: friend class Singleton; Maybe VirtualCloseJob() override; JobBuildAndInferCtx* NewJobBuildAndInferCtx(Job* job, int64_t job_id) const override; }; Maybe GlobalJobBuildAndInferCtxMgr(); Maybe GetJobBuildAndInferCtx(const std::string& job_name); Maybe GetCurInferCtx(); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CXT_MGR_H_ ================================================ FILE: oneflow/core/job/job_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/job_builder.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/scope_util.h" namespace oneflow { namespace { int64_t GetParallelHierarchyNumAxes( const HashMap& op_name2parallel_conf, const std::string& op_name) { const auto& it = op_name2parallel_conf.find(op_name); CHECK(it != op_name2parallel_conf.end()); if (!it->second->has_hierarchy()) { return 1; } else if (it->second->hierarchy().dim_size() == 0) { return 1; } else { return it->second->hierarchy().dim_size(); } } void SetNdSbpSignature4Oba(Job* job, HashMap* op_name2nd_sbp_signature_map, const OpBlobArg& oba, const NdSbp& nd_sbp) { auto* nd_sbp_sig = &(*job->mutable_job_parallel_view_conf() ->mutable_op_name2nd_sbp_signature_conf())[oba.op_name()]; (*nd_sbp_sig->mutable_bn_in_op2nd_sbp())[oba.bn_in_op()] = nd_sbp; auto* op_name2nd_sbp_signature_conf = job->mutable_job_parallel_view_conf()->mutable_op_name2nd_sbp_signature_conf(); (*op_name2nd_sbp_signature_map)[oba.op_name()] = &(*op_name2nd_sbp_signature_conf)[oba.op_name()]; } void SetSbpSignature4Oba(Job* job, const OpBlobArg& oba, const SbpParallel& sbp_parallel) { auto* sbp_sig = &( *job->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf())[oba.op_name()]; (*sbp_sig->mutable_bn_in_op2sbp_parallel())[oba.bn_in_op()] = sbp_parallel; } void AddOrSetNdSbpSignature4OpName( Job* job, HashMap* op_name2nd_sbp_signature_map, const std::string& op_name, const NdSbpSignature& nd_sbp_signature) { const auto& it = op_name2nd_sbp_signature_map->find(op_name); if (it != op_name2nd_sbp_signature_map->end()) { *it->second = nd_sbp_signature; } else { auto* op_name2nd_sbp_signature_conf = job->mutable_job_parallel_view_conf()->mutable_op_name2nd_sbp_signature_conf(); (*op_name2nd_sbp_signature_conf)[op_name] = nd_sbp_signature; op_name2nd_sbp_signature_map->emplace(op_name, &(*op_name2nd_sbp_signature_conf)[op_name]); } } void AddOrSetSbpSignature4OpName(Job* job, const std::string& op_name, const SbpSignature& sbp_signature) { auto* op_name2sbp_signature_conf = job->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf(); (*op_name2sbp_signature_conf)[op_name] = sbp_signature; } } // namespace std::function MakeGetterParallelConf4OpName( const Placement& placement) { auto op_name2parallel_conf = std::make_shared>(); for (const auto& placement_group : placement.placement_group()) { for (const std::string& op_name : placement_group.op_set().op_name()) { const ParallelConf* parallel_conf = &placement_group.parallel_conf(); CHECK(op_name2parallel_conf->emplace(op_name, parallel_conf).second) << "op_name: " << op_name; } } return [op_name2parallel_conf](const std::string& op_name) { return op_name2parallel_conf->at(op_name); }; } JobBuilder::JobBuilder(Job* job) : job_(job) { FOR_RANGE(int32_t, i, 0, job->net().op_size()) { CHECK(op_name2op_conf_.emplace(job->net().op(i).name(), job->mutable_net()->mutable_op(i)) .second); } bool all_ops_1d_hierarchy = true; FOR_RANGE(int32_t, i, 0, job->placement().placement_group_size()) { auto* placemnt_group = job->mutable_placement()->mutable_placement_group(i); if (placemnt_group->parallel_conf().has_hierarchy() && placemnt_group->parallel_conf().hierarchy().dim_size() > 1) { 
all_ops_1d_hierarchy = false; } } auto* job_parallel_view_conf = job->mutable_job_parallel_view_conf(); for (auto& pair : *(job_parallel_view_conf->mutable_op_name2nd_sbp_signature_conf())) { op_name2nd_sbp_signature_conf_.emplace(pair.first, &pair.second); } if (all_ops_1d_hierarchy) { CHECK_EQ(job_parallel_view_conf->op_name2sbp_signature_conf_size(), job_parallel_view_conf->op_name2nd_sbp_signature_conf_size()); for (const auto& pair : job_parallel_view_conf->op_name2nd_sbp_signature_conf()) { const auto& op_name2sbp_sig = job_parallel_view_conf->op_name2sbp_signature_conf(); const auto it = op_name2sbp_sig.find(pair.first); CHECK(it != op_name2sbp_sig.end()); CheckSbpSignatureAndNdSbpEquals(SbpSignature(it->second), NdSbpSignature(pair.second)); } } FOR_RANGE(int32_t, i, 0, job->placement().blob_placement_group_size()) { auto* blob_pg = job->mutable_placement()->mutable_blob_placement_group(i); for (const auto& lbi : blob_pg->lbi()) { CHECK(lbi2blob_parallel_conf_.emplace(lbi, blob_pg->mutable_parallel_conf()).second); } } for (auto& placement_group : *job->mutable_placement()->mutable_placement_group()) { if (placement_group.op_set().op_name().empty()) { continue; } const ParallelConf& parallel_conf = placement_group.parallel_conf(); auto it = parallel_conf2placement_group_.find(parallel_conf); if (it == parallel_conf2placement_group_.end()) { parallel_conf2placement_group_.emplace(parallel_conf, &placement_group); for (const auto& op_name : placement_group.op_set().op_name()) { CHECK(op_name2parallel_conf_.emplace(op_name, placement_group.mutable_parallel_conf()) .second); } } else { PlacementGroup* existing_placement_group = it->second; for (const auto& op_name : placement_group.op_set().op_name()) { *existing_placement_group->mutable_op_set()->mutable_op_name()->Add() = op_name; CHECK(op_name2parallel_conf_ .emplace(op_name, existing_placement_group->mutable_parallel_conf()) .second); } placement_group.mutable_op_set()->mutable_op_name()->Clear(); } } } Maybe JobBuilder::MutableOpConf4OpName(const std::string& op_name) { const auto& it = op_name2op_conf_.find(op_name); CHECK_OR_RETURN(it != op_name2op_conf_.end()); return it->second; } Maybe JobBuilder::OpConf4OpName(const std::string& op_name) const { return *JUST(MapAt(op_name2op_conf_, op_name)); } Maybe JobBuilder::ParallelConf4Lbi(const LogicalBlobId& lbi) const { const auto& iter = lbi2blob_parallel_conf_.find(lbi); if (iter != lbi2blob_parallel_conf_.end()) { return *iter->second; } return ParallelConf4OpName(lbi.op_name()); } Maybe JobBuilder::AddOp(const ParallelConf& parallel_conf, const OperatorConf& op_conf) { CHECK_OR_RETURN(op_name2op_conf_.find(op_conf.name()) == op_name2op_conf_.end()); OperatorConf* mut_op_conf = job_->mutable_net()->add_op(); *mut_op_conf = op_conf; CHECK_OR_RETURN(op_name2op_conf_.emplace(op_conf.name(), mut_op_conf).second); AddOpToModuleConf(op_conf); AddOpNamesToPlacementGroup({op_conf.name()}, parallel_conf); return Maybe::Ok(); } void JobBuilder::AddOps(const ParallelConf& parallel_conf, const std::vector& op_confs) { if (op_confs.empty()) { return; } std::vector op_names; op_names.reserve(op_confs.size()); for (const auto& op_conf : op_confs) { CHECK(op_name2op_conf_.find(op_conf.name()) == op_name2op_conf_.end()); OperatorConf* mut_op_conf = job_->mutable_net()->add_op(); *mut_op_conf = op_conf; CHECK(op_name2op_conf_.emplace(op_conf.name(), mut_op_conf).second); op_names.emplace_back(op_conf.name()); AddOpToModuleConf(op_conf); } AddOpNamesToPlacementGroup(op_names, parallel_conf); } void 
JobBuilder::AddOpToModuleConf(const OperatorConf& op_conf) { // set up the module config if (Singleton>::Get()->Has(op_conf.scope_symbol_id())) { const auto& scope = Singleton>::Get()->Get(op_conf.scope_symbol_id()); if (scope.scope_proto().has_module_name()) { const auto& module_name = scope.scope_proto().module_name(); auto* module_name2module_conf = job_->mutable_module_name2module_conf(); if (!(*module_name2module_conf)[module_name].has_name()) { (*module_name2module_conf)[module_name].set_name(scope.scope_proto().module_name()); } *((*module_name2module_conf)[module_name].add_ops()) = op_conf.name(); return; } } const auto& module_name = job_->job_conf().job_name(); auto* module_name2module_conf = job_->mutable_module_name2module_conf(); if (!(*module_name2module_conf)[module_name].has_name()) { (*module_name2module_conf)[module_name].set_name(module_name); } *((*module_name2module_conf)[module_name].add_ops()) = op_conf.name(); } void JobBuilder::AddOpNamesToPlacementGroup(const std::vector& op_names, const ParallelConf& parallel_conf) { PlacementGroup* placement_group = nullptr; auto it = parallel_conf2placement_group_.find(parallel_conf); if (it != parallel_conf2placement_group_.end()) { placement_group = it->second; } else { placement_group = job_->mutable_placement()->add_placement_group(); *placement_group->mutable_parallel_conf() = parallel_conf; parallel_conf2placement_group_.emplace(parallel_conf, placement_group); } for (const auto& op_name : op_names) { placement_group->mutable_op_set()->add_op_name(op_name); CHECK(op_name2parallel_conf_.emplace(op_name, placement_group->mutable_parallel_conf()).second); } } void JobBuilder::MutParallelConfOnlyOnce(const std::string& op_name, const ParallelConf& parallel_conf) { CHECK(modified_parallel_conf_op_names_.emplace(op_name).second); const auto& parallel_conf_it = op_name2parallel_conf_.find(op_name); CHECK(parallel_conf_it != op_name2parallel_conf_.end()); auto old_placement_group_it = parallel_conf2placement_group_.find(*parallel_conf_it->second); CHECK(old_placement_group_it != parallel_conf2placement_group_.end()); op_name2parallel_conf_.erase(parallel_conf_it); Erase>(*old_placement_group_it->second->mutable_op_set()->mutable_op_name(), [&](const std::string& x) { return x == op_name; }); AddOpNamesToPlacementGroup({op_name}, parallel_conf); } void JobBuilder::RemoveOpByName(const std::string& op_name) { RemoveOpByName(std::unordered_set{op_name}); } void JobBuilder::RemoveOpByName(const std::unordered_set& removing_names) { // Update net DLNetConf net = job_->net(); job_->mutable_net()->clear_op(); for (const OperatorConf& op_conf : net.op()) { if (removing_names.count(op_conf.name()) == 0) { *(job_->mutable_net()->add_op()) = op_conf; } } // Update module conf auto module_confs_map = job_->module_name2module_conf(); job_->clear_module_name2module_conf(); for (const auto& module_conf_pair : module_confs_map) { const auto& module_name = module_conf_pair.first; auto* module_name2module_conf = job_->mutable_module_name2module_conf(); if (!(*module_name2module_conf)[module_name].has_name()) { (*module_name2module_conf)[module_name].set_name(module_name); } for (const auto& op_name : module_conf_pair.second.ops()) { if (removing_names.count(op_name) == 0) { *((*module_name2module_conf)[module_name].add_ops()) = op_name; } } } // Update placement auto placement_group = job_->placement().placement_group(); job_->mutable_placement()->clear_placement_group(); for (const PlacementGroup& place : placement_group) { PlacementGroup p; 
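// Rebuild each placement group with only the surviving op names; groups that end
// up empty are dropped below instead of being copied back into the job.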
OpNameSet* op_set = p.mutable_op_set(); for (const std::string& name : place.op_set().op_name()) { if (removing_names.count(name) == 0) { op_set->add_op_name(name); } } *(p.mutable_parallel_conf()) = place.parallel_conf(); if (op_set->op_name().size() > 0) { *(job_->mutable_placement()->add_placement_group()) = p; } } auto* op_name2sbp_signature_conf = job_->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf(); auto* op_name2nd_sbp_signature_conf = job_->mutable_job_parallel_view_conf()->mutable_op_name2nd_sbp_signature_conf(); for (const std::string& op_name : removing_names) { // Update NdSbp, Sbp if (op_name2nd_sbp_signature_conf->count(op_name) > 0) { op_name2nd_sbp_signature_conf->erase(op_name); if (GetParallelHierarchyNumAxes(op_name2parallel_conf_, op_name) == 1) { CHECK(op_name2sbp_signature_conf->count(op_name) > 0); op_name2sbp_signature_conf->erase(op_name); } } } // Update builder JobBuilder builder(job_); op_name2op_conf_.swap(builder.op_name2op_conf_); op_name2parallel_conf_.swap(builder.op_name2parallel_conf_); op_name2nd_sbp_signature_conf_.swap(builder.op_name2nd_sbp_signature_conf_); parallel_conf2placement_group_.swap(builder.parallel_conf2placement_group_); } void JobBuilder::DelOps(const std::vector& op_names) { std::unordered_set removing_names; for (const auto& op_name : op_names) { removing_names.insert(op_name); } RemoveOpByName(removing_names); } void JobBuilder::DelOps(const std::vector& op_confs) { std::unordered_set removing_names; for (const auto& op_conf : op_confs) { removing_names.insert(op_conf.name()); } RemoveOpByName(removing_names); } Maybe JobBuilder::MutOpOnlyOnce(const OperatorConf& op_conf) { CHECK_OR_RETURN(modified_op_conf_op_names_.emplace(op_conf.name()).second) << op_conf.name() << " is mut twice."; auto find_iter = op_name2op_conf_.find(op_conf.name()); CHECK_OR_RETURN(find_iter != op_name2op_conf_.end()) << op_conf.name() << " not found."; find_iter->second->CopyFrom(op_conf); return Maybe::Ok(); } void JobBuilder::MutOpsOnlyOnce(const std::vector& op_confs) { for (const auto& op_conf : op_confs) { CHECK(modified_op_conf_op_names_.emplace(op_conf.name()).second) << op_conf.name() << " is mut twice."; op_name2op_conf_.at(op_conf.name())->CopyFrom(op_conf); } } Maybe JobBuilder::IsInMutOpTransaction(const std::string& op_name) const { auto find_iter = mut_op_transaction_name2op_conf_.find(op_name); return find_iter != mut_op_transaction_name2op_conf_.end(); } Maybe JobBuilder::MutOpTransactionGet(const std::string& op_name) { return JUST(MapAt(mut_op_transaction_name2op_conf_, op_name)); } Maybe JobBuilder::MutOpTransactionMut(const OperatorConf& op_conf) { auto find_iter = mut_op_transaction_name2op_conf_.find(op_conf.name()); if (find_iter == mut_op_transaction_name2op_conf_.end()) { CHECK_OR_RETURN(mut_op_transaction_name2op_conf_.emplace(op_conf.name(), op_conf).second) << op_conf.name() << " has been added."; } else { find_iter->second.CopyFrom(op_conf); } return Maybe::Ok(); } Maybe JobBuilder::MutOpTransactionCommit() { for (const auto& pair : mut_op_transaction_name2op_conf_) { JUST(MutOpOnlyOnce(pair.second)); } return Maybe::Ok(); } void JobBuilder::AddOrMutOpsOnlyOnce(const ParallelConf& parallel_conf, const std::vector& op_confs) { std::vector add_ops; std::vector mut_ops; for (const auto& op_conf : op_confs) { if (op_name2op_conf_.find(op_conf.name()) == op_name2op_conf_.end()) { add_ops.emplace_back(op_conf); } else { mut_ops.emplace_back(op_conf); } } AddOps(parallel_conf, add_ops); MutOpsOnlyOnce(mut_ops); } 
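// Illustrative usage of AddOrMutOpsOnlyOnce (example only, op names are hypothetical):
//   JobBuilder builder(job);
//   builder.AddOrMutOpsOnlyOnce(parallel_conf, {scale_op_conf, relu_op_conf});
// Ops whose names are not yet in the net are appended via AddOps(); ops that
// already exist are rewritten in place via MutOpsOnlyOnce(), which CHECK-fails if
// the same op is mutated twice.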
Maybe JobBuilder::ForEachOperator( const std::function(const Operator&)>& Handler) const { for (const auto& pair : op_name2op_conf_) { auto it = op_name2parallel_conf_.find(pair.first); CHECK_OR_RETURN(it != op_name2parallel_conf_.end()) << "op_name: " << pair.first; DeviceType device_type = ParallelDesc(*it->second).device_type(); std::shared_ptr op = JUST(ConstructOp(*pair.second, device_type)); JUST(Handler(*op)); } return Maybe::Ok(); } Maybe JobBuilder::ParallelConf4OpName(const std::string& op_name) const { const auto& iter = op_name2parallel_conf_.find(op_name); CHECK_OR_RETURN(iter != op_name2parallel_conf_.end()); return *iter->second; } SbpParallel* JobBuilder::MutSbpParallel4Oba(const OpBlobArg& oba) const { // TODO(guoran): rm this func auto* sbp_sig = &( *job_->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf())[oba.op_name()]; return &(*sbp_sig->mutable_bn_in_op2sbp_parallel())[oba.bn_in_op()]; } void JobBuilder::SetSbpParallel4Oba(const OpBlobArg& oba, const SbpParallel& sbp_parallel) { CHECK_EQ(GetParallelHierarchyNumAxes(op_name2parallel_conf_, oba.op_name()), 1); SetSbpSignature4Oba(job_, oba, sbp_parallel); NdSbp nd_sbp; *nd_sbp.add_sbp_parallel() = sbp_parallel; SetNdSbpSignature4Oba(job_, &op_name2nd_sbp_signature_conf_, oba, nd_sbp); } void JobBuilder::SetNdSbp4Oba(const OpBlobArg& oba, const NdSbp& nd_sbp) { SetNdSbpSignature4Oba(job_, &op_name2nd_sbp_signature_conf_, oba, nd_sbp); if (GetParallelHierarchyNumAxes(op_name2parallel_conf_, oba.op_name()) == 1) { SetSbpSignature4Oba(job_, oba, nd_sbp.sbp_parallel(0)); } } const SbpSignature JobBuilder::SbpSignature4OpName(const std::string& op_name) const { CHECK_EQ(GetParallelHierarchyNumAxes(op_name2parallel_conf_, op_name), 1); const auto& it = op_name2nd_sbp_signature_conf_.find(op_name); CHECK(it != op_name2nd_sbp_signature_conf_.end()); SbpSignature sbp_sig_conf; NdSbpSignatureToSbpSignature(*it->second, &sbp_sig_conf); return sbp_sig_conf; } void JobBuilder::AddSbpSignature4OpName(const std::string& op_name, const SbpSignature& sbp_signature) { NdSbpSignature nd_sbp_signature; SbpSignatureToNdSbpSignature(sbp_signature, &nd_sbp_signature); AddOrSetNdSbpSignature4OpName(job_, &op_name2nd_sbp_signature_conf_, op_name, nd_sbp_signature); CHECK_EQ(GetParallelHierarchyNumAxes(op_name2parallel_conf_, op_name), 1); AddOrSetSbpSignature4OpName(job_, op_name, sbp_signature); } const NdSbpSignature& JobBuilder::NdSbpSignature4OpName(const std::string& op_name) const { const auto& it = op_name2nd_sbp_signature_conf_.find(op_name); CHECK(it != op_name2nd_sbp_signature_conf_.end()); return *(it->second); } void JobBuilder::AddNdSbpSignature4OpName(const std::string& op_name, const NdSbpSignature& nd_sbp_signature) { AddOrSetNdSbpSignature4OpName(job_, &op_name2nd_sbp_signature_conf_, op_name, nd_sbp_signature); if (GetParallelHierarchyNumAxes(op_name2parallel_conf_, op_name) == 1) { SbpSignature sbp_signature; NdSbpSignatureToSbpSignature(nd_sbp_signature, &sbp_signature); AddOrSetSbpSignature4OpName(job_, op_name, sbp_signature); } } } // namespace oneflow ================================================ FILE: oneflow/core/job/job_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_JOB_CONF_BUILDER_H_ #define ONEFLOW_CORE_JOB_JOB_CONF_BUILDER_H_ #include "oneflow/core/job/job_desc.h" #include "oneflow/core/register/op_blob_arg.pb.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { const static std::string kProducedLbi2ConsumedDiffLbi = "produced_lbi2consumed_diff_lbi"; std::function MakeGetterParallelConf4OpName( const Placement& placement); class SbpParallel; class LogicalBlobId; class Operator; class JobBuilder final { public: OF_DISALLOW_COPY_AND_MOVE(JobBuilder); explicit JobBuilder(Job* job); ~JobBuilder() = default; const Job& job() const { return *job_; } JobHelperConf* mutable_helper() { return job_->mutable_helper(); } JobParallelViewConf* mutable_job_parallel_view_conf() { return job_->mutable_job_parallel_view_conf(); } MergedLogicalChainIdGroup* add_logical_chain_groups() { return job_->add_logical_chain_groups(); } Maybe OpConf4OpName(const std::string& op_name) const; Maybe MutableOpConf4OpName(const std::string& op_name); Maybe AddOp(const ParallelConf& parallel_conf, const OperatorConf& op_conf); void AddOps(const ParallelConf& parallel_conf, const std::vector& op_confs); Maybe MutOpOnlyOnce(const OperatorConf& op_conf); void MutOpsOnlyOnce(const std::vector& op_confs); // Mut op with transaction Maybe IsInMutOpTransaction(const std::string& op_name) const; Maybe MutOpTransactionGet(const std::string& op_name); Maybe MutOpTransactionMut(const OperatorConf& op_conf); Maybe MutOpTransactionCommit(); void MutParallelConfOnlyOnce(const std::string& op_name, const ParallelConf& parallel_conf); void AddOrMutOpsOnlyOnce(const ParallelConf& parallel_conf, const std::vector& op_confs); void RemoveOpByName(const std::string& op_name); void RemoveOpByName(const std::unordered_set& removing_names); void DelOps(const std::vector& op_names); void DelOps(const std::vector& op_confs); SbpParallel* MutSbpParallel4Oba(const OpBlobArg& oba) const; void SetSbpParallel4Oba(const OpBlobArg& oba, const SbpParallel& sbp_parallel); void SetNdSbp4Oba(const OpBlobArg& oba, const NdSbp& nd_sbp); Maybe ForEachOperator(const std::function(const Operator&)>& Handler) const; Maybe ParallelConf4Lbi(const LogicalBlobId& lbi) const; Maybe ParallelConf4OpName(const std::string& op_name) const; const SbpSignature SbpSignature4OpName(const std::string& op_name) const; void AddSbpSignature4OpName(const std::string& op_name, const SbpSignature& sbp_signature); const NdSbpSignature& NdSbpSignature4OpName(const std::string& op_name) const; void AddNdSbpSignature4OpName(const std::string& op_name, const NdSbpSignature& nd_sbp_signature); private: void AddOpNamesToPlacementGroup(const std::vector& op_names, const ParallelConf& parallel_conf); void AddOpToModuleConf(const OperatorConf& op_conf); Job* job_; HashMap op_name2op_conf_; HashMap op_name2parallel_conf_; HashMap lbi2blob_parallel_conf_; HashSet modified_op_conf_op_names_; HashSet modified_parallel_conf_op_names_; HashMap op_name2nd_sbp_signature_conf_; HashMap parallel_conf2placement_group_; HashMap mut_op_transaction_name2op_conf_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_JOB_CONF_BUILDER_H_ 
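The JobBuilder declared above is the handle that graph-rewrite passes use to edit a Job in place. Below is a minimal sketch of how it is typically driven. The helper name and its arguments are hypothetical, and the Maybe<...> template arguments, which this extract has stripped down to plain "Maybe", are assumed; it is a sketch of the calling pattern, not code from the repository.

#include <string>
#include <unordered_set>

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/job_builder.h"

namespace oneflow {

// Hypothetical helper: drop every op whose name starts with `prefix`, then move
// one surviving op onto a new placement. Not part of the original sources.
Maybe<void> PruneOpsByPrefix(Job* job, const std::string& prefix,
                             const std::string& op_to_move,
                             const ParallelConf& new_parallel_conf) {
  JobBuilder builder(job);

  // Collect the victims from the net held by the builder's job.
  std::unordered_set<std::string> removing;
  for (const auto& op_conf : builder.job().net().op()) {
    if (op_conf.name().rfind(prefix, 0) == 0) { removing.insert(op_conf.name()); }
  }
  // RemoveOpByName also cleans up placement groups, module confs and
  // (nd-)sbp signatures for the removed ops.
  builder.RemoveOpByName(removing);

  // A placement change goes through MutParallelConfOnlyOnce, which enforces that
  // each op's parallel conf is rewritten at most once per builder. The caller
  // must make sure op_to_move was not removed above.
  builder.MutParallelConfOnlyOnce(op_to_move, new_parallel_conf);
  return Maybe<void>::Ok();
}

}  // namespace oneflow

The job passes dispatched from LazyJobBuildAndInferCtx::Complete() via JobPass4Name() operate on the same Job* and commonly rewrite it through this builder interface (adding, mutating, or deleting operators and their placements).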
================================================ FILE: oneflow/core/job/job_conf.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/data_type.proto"; import "oneflow/core/job/placement.proto"; import "oneflow/core/register/blob_desc.proto"; import "oneflow/core/job/sbp_parallel.proto"; import "oneflow/core/framework/user_op_attr.proto"; import "oneflow/core/job/initializer_conf.proto"; import "oneflow/core/job/learning_rate_schedule_conf.proto"; import "oneflow/core/register/logical_blob_id.proto"; import "oneflow/core/operator/interface_blob_conf.proto"; message NaiveModelUpdateConf { } message MomentumModelUpdateConf { optional float beta = 1 [default = 0.9]; optional float dampening = 2 [default = 0.0]; optional bool nesterov = 3 [default = false]; optional bool maximize = 4 [default = false]; } message RMSPropModelUpdateConf { optional float decay_rate = 1 [default = 0.99]; optional float epsilon = 2 [default = 1e-8]; optional bool centered = 3 [default = false]; } message LARSModelUpdateConf { optional float momentum_beta = 1 [default = 0.9]; optional float epsilon = 2 [default = 1e-9]; optional float lars_coefficient = 3 [default = 0.0001]; } message AdamModelUpdateConf { optional float beta1 = 1 [default = 0.9]; optional float beta2 = 2 [default = 0.999]; optional float epsilon = 3 [default = 1e-8]; optional bool do_bias_correction = 4 [default = true]; optional bool amsgrad = 5 [default = false]; optional bool smart_decay = 6 [default = false]; } message LazyAdamModelUpdateConf { optional float beta1 = 1 [default = 0.9]; optional float beta2 = 2 [default = 0.999]; optional float epsilon = 3 [default = 1e-8]; optional bool do_bias_correction = 4 [default = true]; optional bool amsgrad = 5 [default = false]; } message LambModelUpdateConf { optional float beta1 = 1 [default = 0.9]; optional float beta2 = 2 [default = 0.999]; optional float epsilon = 3 [default = 1e-8]; optional bool do_bias_correction = 4 [default = true]; } message AdagradModelUpdateConf { required float lr_decay = 1 [default = 0.0]; required float initial_accumulator_value = 2 [default = 0.0]; required float epsilon = 3 [default = 1e-10]; } message FtrlModelUpdateConf { required float initial_accumulator_value = 1 [default = 0.1]; required float lr_power = 2 [default = 0.5]; optional float lambda1 = 3 [default = 0.0]; optional float lambda2 = 4 [default = 0.0]; optional float beta = 5 [default = 0.0]; } message AdadeltaModelUpdateConf { required float rho = 1 [default = 0.9]; required float epsilon = 2 [default = 1e-6]; required bool maximize = 3 [default = false]; } message ClipByGlobalNormConf { optional float max_norm = 1 [default = 1.0]; optional double norm_type = 2 [default = 2.0]; } message ClipConf { oneof type { ClipByGlobalNormConf clip_by_global_norm = 1; } } message WeightDecayFilterPatternSet { repeated string pattern = 1; } message WeightDecayConf { required float weight_decay_rate = 1; oneof weight_decay_filter_type { WeightDecayFilterPatternSet includes = 2; WeightDecayFilterPatternSet excludes = 3; } } message OptimizerConf { repeated string variable_op_names = 1; optional float base_learning_rate = 2; repeated string variable_grad_lbns = 3; optional LearningRateDecayConf learning_rate_decay = 4; optional string learning_rate_lbn = 5; optional ClipConf clip_conf = 6; optional WeightDecayConf weight_decay_conf = 7; optional float lr_scale = 8 [default = 1.0]; oneof normal_mdupdt { NaiveModelUpdateConf naive_conf = 1000; 
MomentumModelUpdateConf momentum_conf = 1001; RMSPropModelUpdateConf rmsprop_conf = 1002; LARSModelUpdateConf lars_conf = 1003; AdamModelUpdateConf adam_conf = 1004; LazyAdamModelUpdateConf lazy_adam_conf = 1005; LambModelUpdateConf lamb_conf = 1006; AdagradModelUpdateConf adagrad_conf = 1007; FtrlModelUpdateConf ftrl_conf = 1008; AdadeltaModelUpdateConf adadelta_conf = 1009; } } message NormalModelUpdateOpUserConf { optional LearningRateDecayConf learning_rate_decay = 1; optional ClipConf clip_conf = 3; optional WeightDecayConf weight_decay_conf = 4; oneof normal_mdupdt { NaiveModelUpdateConf naive_conf = 1000; MomentumModelUpdateConf momentum_conf = 1001; RMSPropModelUpdateConf rmsprop_conf = 1002; LARSModelUpdateConf lars_conf = 1003; AdamModelUpdateConf adam_conf = 1004; LazyAdamModelUpdateConf lazy_adam_conf = 1005; LambModelUpdateConf lamb_conf = 1006; AdagradModelUpdateConf adagrad_conf = 1007; FtrlModelUpdateConf ftrl_conf = 1008; } } message DynamicLossScalePolicy { optional float initial_loss_scale = 1 [default = 1073741824.0]; optional float increment_period = 2 [default = 2000]; optional float multiplier = 3 [default=2.0]; } message TrainConf { repeated OptimizerConf optimizer_conf = 1; repeated string loss_lbn = 2; repeated string loss_grad_lbn = 6; optional string train_step_lbn = 3; oneof loss_scale_policy { float loss_scale_factor = 4 [default = 1]; DynamicLossScalePolicy dynamic_loss_scale_policy = 5; } // Deprecated model update conf, will be removed later. optional NormalModelUpdateOpUserConf model_update_conf = 101; optional float primary_lr = 102; optional float secondary_lr = 103; optional string primary_lr_lbn = 104; optional string secondary_lr_lbn = 105; } message PredictConf { } message MemoryAllocationAlgorithmConf { optional bool use_mem_size_first_algo = 1 [default = true]; optional bool use_lifetime_first_algo = 2 [default = false]; optional bool use_time_line_algo = 3 [default = false]; optional bool use_mem_volume_first_algo = 4 [default = false]; } message MemoryCompactInsertConf { optional bool use_compact_insert = 1 [default = false]; optional bool use_non_compact_insert = 2 [default = true]; } message QatConfig { optional bool per_channel_weight_quantization = 1 [default = false]; optional bool symmetric = 2 [default = true]; optional float moving_min_max_momentum = 3 [default = 0.95]; optional int64 moving_min_max_stop_update_after_iters = 4; optional string target_backend = 5 [default = ""]; } message IndexedSlicesOptimizerConf { optional bool enable = 1 [default = true]; required OpNameSet include_op_names = 2; } message ParallelBlobConf { required BlobDescProto logical_blob_desc_conf = 1; required ParallelConf parallel_conf = 2; required NdSbp nd_sbp = 3; } message JobInputDef { required LogicalBlobId lbi = 1; required InterfaceBlobConf blob_conf = 2; } message JobOutputDef { required LogicalBlobId lbi = 1; } message JobSignatureDef { map inputs = 1; map outputs = 2; } enum StraightenAlgorithmTag { kDisableStraighten = 1; kOverlap4Transfer = 2; kCompressMemory = 3; kOverlap4CpuGpu = 4; kDelayShortGpu = 5; } enum AutoMemoryStrategy { kDisableAutoMemory = 1; kSlightAutoMemory = 2; kModerateAutoMemory = 3; kHeavyAutoMemory = 4; kAdaptiveAutoMemory = 5; } message JobConfigProto { required string job_name = 1; oneof job_type { TrainConf train_conf = 3; PredictConf predict_conf = 4; } optional DataType default_data_type = 8 [default = kFloat]; // kFloat or kDouble oneof default_initialize_conf { InitializerConf default_initializer_conf = 10; string 
default_initialize_with_snapshot_path = 11; } optional MemoryAllocationAlgorithmConf memory_allocation_algorithm_conf = 102; optional MemoryCompactInsertConf memory_compact_insert_conf = 103; optional IndexedSlicesOptimizerConf indexed_slices_optimizer_conf = 104; optional bool enable_fuse_model_update_ops = 105 [default = false]; optional bool enable_gradients_stats_aggregation = 106 [default = true]; optional string optimizer_placement_optimization_mode = 107; optional int64 optimizer_placement_optimization_threshold = 108 [default = 1024]; optional int64 optimizer_placement_optimization_shard_restore_level = 110 [default = 2]; optional QatConfig qat_config = 109; optional bool enable_cudnn = 200 [default = true]; optional int64 cudnn_buf_limit_mbyte = 201 [default = 1024]; // 1GByte optional int32 cudnn_conv_force_fwd_algo = 202; optional int32 cudnn_conv_force_bwd_data_algo = 203; optional int32 cudnn_conv_force_bwd_filter_algo = 204; optional bool cudnn_conv_heuristic_search_algo = 205 [default = true]; optional bool cudnn_conv_use_deterministic_algo_only = 206 [default = false]; optional bool enable_cudnn_fused_normalization_add_relu = 207; optional bool enable_fuse_add_to_output = 208 [default = false]; optional bool enable_fuse_cast_scale = 209 [default = false]; optional int64 num_gradient_accumulation_steps = 210; optional bool enable_reuse_mem = 300 [default = true]; optional bool enable_inplace = 301 [default = true]; optional bool enable_inplace_in_reduce_struct = 302 [default = true]; optional bool do_parallel_cast_before_widening_type_cast = 403 [default = true]; optional bool prune_parallel_cast_ops = 509 [default = true]; optional bool prune_cast_to_static_shape_ops = 510 [default = true]; optional bool prune_amp_white_identity_ops = 511 [default = true]; optional bool prune_depend_ops = 512 [default = true]; optional bool cudnn_conv_enable_pseudo_half = 600 [default = true]; optional bool enable_auto_mixed_precision = 602 [default = false]; optional bool enable_quantization_aware_training = 603 [default = false]; optional DataType mixed_precision_data_type = 604 [default = kFloat16]; // kFloat16 or kBFloat16 optional bool enable_multi_tensor_update = 605 [default = false]; optional bool enable_fused_model_update_cast = 606 [default = false]; optional bool enable_auto_parallel = 700 [default = false]; optional double auto_parallel_computation_cost_ratio = 701 [default = 0.05]; optional double auto_parallel_wait_time = 702 [default = 1.65e4]; optional bool enable_auto_parallel_trunk_algo = 703 [default = true]; optional bool enable_auto_parallel_sbp_collector = 704 [default = false]; optional bool enable_auto_parallel_ignore_user_sbp_config = 705 [default = false]; optional AutoMemoryStrategy enable_auto_memory = 706 [default = kAdaptiveAutoMemory]; optional StraightenAlgorithmTag straighten_algorithm_tag_in_task_graph = 800 [default = kCompressMemory]; optional bool enable_compress_memory = 801 [default = false]; optional int64 concurrency_width = 1000 [default = 128]; map flag_name2flag_value = 2000; optional int64 logical_object_id = 3000; optional JobSignatureDef signature = 4000; } ================================================ FILE: oneflow/core/job/job_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/persistence/hadoop/hadoop_file_system.h" #include "oneflow/core/graph/graph.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/global_for.h" namespace oneflow { namespace { void CheckFunctionConfig(const JobConfigProto& job_conf) { const auto& attr_name2attr_def = GlobalFunctionConfigDef().attr_name2attr_def(); for (const auto& pair : job_conf.flag_name2flag_value()) { const auto& iter = attr_name2attr_def.find(pair.first); CHECK(iter != attr_name2attr_def.end()); CHECK_EQ(iter->second.default_val().value_case(), pair.second.value_case()); } } } // namespace JobDesc::JobDesc(const JobConfigProto& job_conf, int64_t job_id) : job_conf_(job_conf), job_id_(job_id), symbol_id_(NullOpt) { CHECK_JUST(Init()); Singleton::Get()->DumpCudnnConf(job_conf); } Maybe JobDesc::New(int64_t symbol_id, const JobConfigProto& job_conf) { auto job_desc = std::make_shared(job_conf); job_desc->symbol_id_ = symbol_id; return job_desc; } Maybe JobDesc::Init() { CheckFunctionConfig(job_conf_); return Maybe::Ok(); } const AttrValue& JobDesc::GetFunctionFlagVal(const std::string& field_name) const { const auto& iter = job_conf_.flag_name2flag_value().find(field_name); if (iter != job_conf_.flag_name2flag_value().end()) { return iter->second; } const auto& attr_name2attr_def = GlobalFunctionConfigDef().attr_name2attr_def(); const auto& def_iter = attr_name2attr_def.find(field_name); CHECK(def_iter != attr_name2attr_def.end()); return def_iter->second.default_val(); } bool IsInterfaceOpConf(const OperatorConf& op_conf) { return IsClassRegistered(op_conf.op_type_case()); } GlobalJobDescScope::GlobalJobDescScope(const JobConfigProto& job_conf, int64_t job_id) { if (Singleton::Get() != nullptr) { Singleton::Delete(); } Singleton::New(job_conf, job_id); } GlobalJobDescScope::~GlobalJobDescScope() { Singleton::Delete(); } const JobDesc& GlobalJobDesc() { return *Singleton::Get(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/job_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
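Illustrative sketch, not from the original header: how a populated JobConfigProto is typically consumed through the JobDesc declared below, using the standard protobuf-generated setters (the job name and flag choices here are made up):

  JobConfigProto job_conf;
  job_conf.set_job_name("example_job");
  job_conf.mutable_train_conf()->mutable_dynamic_loss_scale_policy()->set_initial_loss_scale(65536.0);
  job_conf.set_enable_auto_mixed_precision(true);
  {
    GlobalJobDescScope scope(job_conf, 0);   // job_id = 0; installs the JobDesc singleton
    const JobDesc& desc = GlobalJobDesc();
    CHECK(desc.IsTrain());                   // true because the train_conf branch of job_type is set
    CHECK(desc.enable_auto_mixed_precision());
  }                                          // the JobDesc singleton is deleted when the scope ends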
*/ #ifndef ONEFLOW_CORE_JOB_JOB_DESC_H_ #define ONEFLOW_CORE_JOB_JOB_DESC_H_ #include "oneflow/core/common/optional.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/dlnet_conf.pb.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/job/inter_user_job_info.pb.h" #include "oneflow/core/register/logical_blob_id.pb.h" #include "oneflow/core/framework/config_def.h" namespace oneflow { bool IsInterfaceOpConf(const OperatorConf& op_conf); class JobDesc final { public: OF_DISALLOW_COPY_AND_MOVE(JobDesc); JobDesc(const JobConfigProto& job_conf, int64_t job_id); explicit JobDesc(const JobConfigProto& job_conf) : JobDesc(job_conf, -1) {} ~JobDesc() = default; static Maybe New(int64_t symbol_id, const JobConfigProto& job_conf); const Optional& symbol_id() const { return symbol_id_; } // Common int64_t job_id() const { return job_id_; } const std::string& job_name() const { return job_conf_.job_name(); } int64_t concurrency_width() const { return job_conf_.concurrency_width(); } const JobConfigProto& job_conf() const { return job_conf_; } const JobConfigProto& data() const { return job_conf_; } DataType DefaultDataType() const { return job_conf_.default_data_type(); } bool EnableCudnn() const { return job_conf_.enable_cudnn(); } bool IsTrain() const { return job_conf_.has_train_conf(); } bool IsPredict() const { return job_conf_.has_predict_conf(); } bool enable_reuse_mem() const { return job_conf_.enable_reuse_mem(); } bool enable_inplace() const { return job_conf_.enable_inplace(); } bool enable_auto_mixed_precision() const { return job_conf_.enable_auto_mixed_precision(); } bool enable_multi_tensor_update() const { return job_conf_.enable_multi_tensor_update(); } bool enable_fused_model_update_cast() const { return job_conf_.enable_fused_model_update_cast(); } DataType mixed_precision_data_type() const { return job_conf_.mixed_precision_data_type(); } bool do_parallel_cast_before_widening_type_cast() const { return job_conf_.do_parallel_cast_before_widening_type_cast(); }; bool prune_parallel_cast_ops() const { return job_conf_.prune_parallel_cast_ops(); } bool prune_cast_to_static_shape_ops() const { return job_conf_.prune_cast_to_static_shape_ops(); } bool prune_amp_white_identity_ops() const { return job_conf_.prune_amp_white_identity_ops(); } bool prune_depend_ops() const { return job_conf_.prune_depend_ops(); } bool enable_auto_parallel() const { return job_conf_.enable_auto_parallel(); } int64_t cudnn_buf_limit_mbyte() const { return job_conf_.cudnn_buf_limit_mbyte(); } #define DEFINE_FUNCTION_CONFIG_GETTER(T, func_name, field_name) \ T func_name(const std::string& field_name) const { \ const AttrValue& attr_val = GetFunctionFlagVal(field_name); \ CHECK(attr_val.has_##field_name()); \ return attr_val.field_name(); \ } DEFINE_FUNCTION_CONFIG_GETTER(bool, Bool, at_bool); DEFINE_FUNCTION_CONFIG_GETTER(int64_t, Int64, at_int64); DEFINE_FUNCTION_CONFIG_GETTER(double, Double, at_double); DEFINE_FUNCTION_CONFIG_GETTER(const std::string&, String, at_string); private: Maybe Init(); const AttrValue& GetFunctionFlagVal(const std::string& field_name) const; JobConfigProto job_conf_; int64_t job_id_; Optional symbol_id_; }; typedef HashMap JobName2JobId; class GlobalJobDescScope final { public: GlobalJobDescScope(const JobConfigProto& job_conf, int64_t job_id); ~GlobalJobDescScope(); }; const JobDesc& GlobalJobDesc(); bool IsPullJob(const std::string& job_name, const 
InterUserJobInfo& inter_user_job_info); bool IsPushJob(const std::string& job_name, const InterUserJobInfo& inter_user_job_info); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_JOB_DESC_H_ ================================================ FILE: oneflow/core/job/job_instance.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_JOB_INSTANCE_H_ #define ONEFLOW_CORE_JOB_JOB_INSTANCE_H_ #include #include "oneflow/core/common/util.h" namespace oneflow { class JobInstance { public: JobInstance() = default; virtual ~JobInstance() = default; virtual std::string job_name() const { UNIMPLEMENTED(); } virtual void Finish() const { UNIMPLEMENTED(); } }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_JOB_INSTANCE_H_ ================================================ FILE: oneflow/core/job/job_interpreter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
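Illustrative sketch, not from the original source: the interpreter implemented below executes a compiled Job op by op through the eager interpreters. Assuming `nn_graph` is an already-built std::shared_ptr<NNGraph> and `x` is a tensor matching the graph's single input, a call mirrors the pattern used inside this file (JUST/CHECK_JUST unwrap the returned Maybe):

  one::TensorTuple inputs;
  inputs.emplace_back(x);
  const auto& outputs = *CHECK_JUST(one::InterpretJob(inputs, nn_graph));
  // `outputs` holds the graph outputs in the order the job's output ops appear in job.net()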
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/nn_graph.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/framework/local_tensor_infer_cache.h" #include "oneflow/core/framework/global_tensor_infer_cache.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/framework/tensor_global_id.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/boxing/eager_boxing_logger.h" namespace oneflow { namespace one { using Env = std::map>; using NameToParallelDescMap = std::map>; Maybe InitEnv(const one::TensorTuple& graph_inputs, const std::shared_ptr& graph) { Env env; for (const auto& [name, tensor] : graph->variable_op_name2tensor()) { env.emplace(name + "/out", tensor); } for (size_t i = 0; i < graph->inputs_op_names().size(); ++i) { const auto& name = graph->inputs_op_names()[i]; env.emplace(name + "/out", JUST(VectorAt(graph_inputs, i))); } return env; } Maybe OpConfToUserOpExpr(const OperatorConf& op_conf) { CHECK_OR_RETURN(op_conf.has_user_conf()); const auto& user_conf = op_conf.user_conf(); auto builder = OpBuilder(user_conf.op_type_name()); for (const auto& pair : user_conf.attr()) { builder.Attr(pair.first, pair.second); } for (const auto& pair : user_conf.input()) { // ignore "UserSourceOpTickInput" if (pair.first == "UserSourceOpTickInput") { continue; } builder.Input(pair.first, pair.second.s_size()); } for (const auto& pair : user_conf.output()) { builder.Output(pair.first, pair.second.s_size()); } return JUST(builder.Build()); } template Maybe>> GetInputTensors( const UserOpConf& user_conf, const Env& env, const Func& preprocess) { TensorTuple inputs; OpArgsVector ibns; for (const auto& [ibn, ibs] : user_conf.input()) { if (ibn == "UserSourceOpTickInput") { continue; } const auto& tensor_names = ibs.s(); for (int i = 0; i < tensor_names.size(); ++i) { inputs.emplace_back(preprocess(JUST(MapAt(env, tensor_names[i])))); ibns.emplace_back(ibn + '_' + std::to_string(i)); } } return std::make_pair(inputs, ibns); } OpArgsVector GetOutputNamesOfOp(const UserOpConf& user_conf) { OpArgsVector output_names; for (const auto& pair : user_conf.output()) { for (const auto& name : pair.second.s()) { output_names.emplace_back(name); } } return output_names; } // Only support a limited subset of view ops for now bool IsViewOp(const std::shared_ptr& op) { return op->op_type_name() == "reshape" || op->op_type_name() == "expand_dims"; } Maybe RunViewOp(const std::shared_ptr& op, Env& env, const TensorTuple& inputs, const OpArgsVector& output_names) { // eliminate the memcpy of view ops CHECK_OR_RETURN(IsViewOp(op)); const std::shared_ptr result = JUST([&]() -> Maybe { LocalTensorMetaInferArgs infer_args; JUST(infer_args.Init(op->base_attrs(), JUST(inputs[0]->device()), inputs)); return JUST(op->mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); }()); const auto& output_shape = result->output_tensor_metas()[0]->shape(); const auto output = JUST(view::BasicView(inputs[0], output_shape, JUST(inputs[0]->storage_offset()))); env.emplace(output_names[0], output); return Maybe::Ok(); } namespace { Maybe RawRunGlobalNormalOp(const std::shared_ptr& op, TensorTuple& inputs, TensorTuple* outputs, Env& env, const OpArgsVector& ibns, const OpArgsVector& output_names, const NdSbpSignature& ndsbp_signature, const Symbol& 
op_parallel_desc) { Optional parallel_id; const auto& tensor_device = JUST(GetTensorDevice4CurrentProcessCtx(op_parallel_desc, ¶llel_id)); const auto* mgr = Singleton::Get(); CHECK_OR_RETURN(inputs.size() == ibns.size()) << "inputs size != ibns size"; for (int i = 0; i < inputs.size(); ++i) { std::shared_ptr input_tensor = inputs[i]; std::string lbn = JUST(VectorAt(ibns, i)); const auto& logical_shape = input_tensor->shape(); CHECK_OR_RETURN(logical_shape->elem_cnt() > 0) << "tensor logical element empty"; const auto& in_nd_sbp = JUST(input_tensor->nd_sbp()); const auto& out_nd_sbp = SymbolOf(JUST(MapAt(ndsbp_signature.bn_in_op2nd_sbp(), lbn))); const auto& in_parallel_desc = JUST(input_tensor->parallel_desc()); const auto& out_parallel_desc = op_parallel_desc; CHECK_OR_RETURN(in_parallel_desc == out_parallel_desc) << "input placement != output placement"; if (in_parallel_desc->parallel_num() != 1 && in_nd_sbp != out_nd_sbp) { const auto& boxing_interpreter = JUST(mgr->GetEagerBoxingInterpreter( in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *logical_shape)); Singleton::Get()->Log( *JUST(boxing_interpreter->boxing_interpreter_status()), /* prefix */ ""); if (parallel_id.has_value()) { inputs.at(i) = JUST(boxing_interpreter->Interpret(input_tensor, in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc)); } } } static EagerGlobalInterpreter it; static OpExprInterpContext ctx = OpExprInterpContext(AttrMap{}, op_parallel_desc, SymbolOf(JUST(MapAt(ndsbp_signature.bn_in_op2nd_sbp(), "out_0")))); JUST(it.Apply(*op, inputs, outputs, ctx)); for (size_t i = 0; i < output_names.size(); ++i) { env.emplace(output_names[i], JUST(VectorAt(*outputs, i))); } return Maybe::Ok(); } auto* RunGlobalNormalOpThenInitGlobalId = DECORATE(&RawRunGlobalNormalOp, NonRecursiveInitGlobalId); } // namespace Maybe RunGlobalNormalOp(const std::shared_ptr& op, TensorTuple& inputs, Env& env, const OpArgsVector& ibns, const OpArgsVector& output_names, const NdSbpSignature& ndsbp_signature, const Symbol& op_parallel_desc) { TensorTuple outputs(output_names.size()); return RunGlobalNormalOpThenInitGlobalId(op, inputs, &outputs, env, ibns, output_names, ndsbp_signature, op_parallel_desc); } Maybe RunNormalOp(const std::shared_ptr& op, Env& env, const TensorTuple& inputs, const OpArgsVector& output_names) { TensorTuple outputs(output_names.size()); static EagerLocalInterpreter it; static AttrMap empty_attr_map; JUST(it.Apply(*op, inputs, &outputs, empty_attr_map)); for (size_t i = 0; i < output_names.size(); ++i) { env.emplace(output_names[i], JUST(VectorAt(outputs, i))); } return Maybe::Ok(); } // tensors in outdated_tensors_after_op[i] will not be accessed any more after i-th op // so they can be released once i-th op's execution finishes. 
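// Illustrative example, not from the original source: for a hypothetical three-op job
//   op0: a = conv(x)    op1: b = relu(a)    op2: c = add(a, b)
// the reverse scan below meets op2 first, so "a/out" and "b/out" are recorded in
// outdated_tensors_after_op[2] and can be erased from `env` (releasing the tensors) right after
// op2 runs; blobs consumed by an output op are marked visited up front and are never released.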
std::vector> GetOutdatedTensorsAfterOp(const Job& job) { std::vector> outdated_tensors_after_op(job.net().op_size()); std::set visited; for (int i = job.net().op_size() - 1; i >= 0; --i) { const auto& op_conf = job.net().op(i); // do not release the graph output tensors if (op_conf.has_output_conf()) { const auto& output_conf = op_conf.output_conf(); visited.insert(output_conf.in()); } else if (op_conf.has_user_conf()) { const auto& user_conf = op_conf.user_conf(); for (const auto& pair : user_conf.input()) { if (pair.first == "UserSourceOpTickInput") { continue; } for (const auto& name : pair.second.s()) { if (visited.find(name) == visited.end()) { outdated_tensors_after_op[i].push_back(name); visited.insert(name); } } } } } return outdated_tensors_after_op; } Maybe InitOpExprs(const std::shared_ptr& graph) { CHECK_OR_RETURN(graph->cached_op_exprs.empty()); const auto& job = graph->job(); for (int i = 0; i < job.net().op_size(); i++) { const auto& op_conf = job.net().op(i); if (op_conf.has_user_conf()) { const auto op_expr = JUST(OpConfToUserOpExpr(op_conf)); graph->cached_op_exprs.push_back(op_expr); } else { graph->cached_op_exprs.push_back(nullptr); } } return Maybe::Ok(); } Maybe InterpretJob(const one::TensorTuple& graph_inputs, const std::shared_ptr& graph) { if (graph->cached_op_exprs.empty()) { JUST(InitOpExprs(graph)); } const auto& job = graph->job(); auto env = *JUST(InitEnv(graph_inputs, graph)); // See comments above GetOutdatedTensorsAfterOp's definition for more details const auto outdated_tensors_after_op = GetOutdatedTensorsAfterOp(job); CHECK_OR_RETURN(job.has_placement()) << "no job placement"; const auto& job_placement = job.placement(); NameToParallelDescMap op2paralleldesc; for (const auto& blob_placement_group : job_placement.blob_placement_group()) { const auto parallel_desc = SymbolOf(ParallelDesc(blob_placement_group.parallel_conf())); for (const auto& logical_blob_id : blob_placement_group.lbi()) { op2paralleldesc.emplace(logical_blob_id.op_name(), parallel_desc); } } CHECK_OR_RETURN(job.has_job_parallel_view_conf()) << "no job parallel conf"; const auto& op_name2nd_sbp_signature_conf = job.job_parallel_view_conf().op_name2nd_sbp_signature_conf(); one::TensorTuple graph_outputs; for (int i = 0; i < job.net().op_size(); i++) { const auto& op_conf = job.net().op(i); if (op_conf.has_user_conf()) { auto op = CHECK_NOTNULL(graph->cached_op_exprs[i]); const auto& user_conf = op_conf.user_conf(); OF_PROFILER_RANGE_GUARD(user_conf.op_type_name()); auto [inputs, ibns] = *JUST(GetInputTensors(user_conf, env, [&op_conf](const std::shared_ptr& tensor) { return CHECK_JUST(functional::To(tensor, op_conf.device_tag())); })); OpArgsVector output_names = GetOutputNamesOfOp(user_conf); if (!inputs.empty() && inputs[0]->is_local()) { // All tensors maintain the same properties of is_local if (IsViewOp(op)) { JUST(RunViewOp(op, env, inputs, output_names)); } else { JUST(RunNormalOp(op, env, inputs, output_names)); } } else { const auto& op_parallel_desc = JUST(MapAt(op2paralleldesc, op_conf.name())); const auto& nd_sbp_signature_conf = JUST(MapAt(op_name2nd_sbp_signature_conf, op_conf.name())); JUST(RunGlobalNormalOp(op, inputs, env, ibns, output_names, nd_sbp_signature_conf, op_parallel_desc)); } for (const auto& name : outdated_tensors_after_op[i]) { CHECK_EQ_OR_RETURN(env.erase(name), 1); } } else if (op_conf.has_learning_rate_schedule_conf()) { // FIXME(daquexian): // It is a temporary hack to support learning_rate_schedule op. 
// Only the naive sgd without any lr decay is supported. const auto& lr_conf = op_conf.learning_rate_schedule_conf(); env.emplace( op_conf.name() + "/" + lr_conf.out(), JUST(functional::Constant({1}, lr_conf.learning_rate(), DType::Float(), NullOpt))); } else if (op_conf.has_identity_conf()) { const auto& identity_conf = op_conf.identity_conf(); const auto& in = identity_conf.in(); const auto& out = op_conf.name() + "/" + identity_conf.out(); env.emplace(out, JUST(functional::Identity(JUST(MapAt(env, in))))); } else if (op_conf.has_output_conf()) { const auto& output_conf = op_conf.output_conf(); graph_outputs.emplace_back(JUST(MapAt(env, output_conf.in()))); } } return graph_outputs; } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/job/job_interpreter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/maybe.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { class NNGraph; namespace one { class TensorTuple; Maybe InterpretJob(const one::TensorTuple& inputs, const std::shared_ptr& graph); } // namespace one } // namespace oneflow ================================================ FILE: oneflow/core/job/job_ir.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/job_ir.h" namespace oneflow { #ifndef WITH_MLIR Maybe ConvertJobToTosaIR(Job* job) { UNIMPLEMENTED_THEN_RETURN() << "ConvertJobToTosaIR is only supported WITH_MLIR"; } Maybe SaveJobToIR(Job* job, const std::string& path) { UNIMPLEMENTED_THEN_RETURN() << "SaveJobToIR is only supported WITH_MLIR"; } Maybe ConvertJobToIR(Job* job) { UNIMPLEMENTED_THEN_RETURN() << "ConvertJobToIR is only supported WITH_MLIR"; } Maybe LoadJobFromIR(Job* job, const std::string& path) { UNIMPLEMENTED_THEN_RETURN() << "LoadJobFromIR is only supported WITH_MLIR"; } #endif } // namespace oneflow ================================================ FILE: oneflow/core/job/job_ir.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
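Illustrative sketch, not from the original header: the declarations below are the MLIR round-trip entry points; the stubs in job_ir.cpp above return UNIMPLEMENTED unless OneFlow is built with WITH_MLIR. The path used here is hypothetical, and the Job is assumed to have been filled in elsewhere:

  Job job;
  CHECK_JUST(SaveJobToIR(&job, "/tmp/example_ir"));   // serialize the job as MLIR under the path
  Job restored;
  CHECK_JUST(LoadJobFromIR(&restored, "/tmp/example_ir"));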
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_JOB_IR_H_ #define ONEFLOW_CORE_JOB_JOB_IR_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { Maybe ConvertJobToTosaIR(Job* job); Maybe ConvertJobToIR(Job* job); Maybe SaveJobToIR(Job* job, const std::string& path); Maybe LoadJobFromIR(Job* job, const std::string& path); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_JOB_IR_H_ ================================================ FILE: oneflow/core/job/job_set.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/job/job.proto"; import "oneflow/core/job/resource.proto"; message ReuseMemPriorityStrategy { } message ParallelismPriorityStrategy { } message JobNameGroup { repeated string job_name = 1; } message CustomParallelismStrategy { repeated JobNameGroup nonparallel_group = 1; } message InterJobReuseMemStrategy { oneof strategy_case { ReuseMemPriorityStrategy reuse_mem_priority = 1; ParallelismPriorityStrategy parallelism_priority = 2; CustomParallelismStrategy custom_parallelism = 3; } } message ConfigProto { required Resource resource = 1; required int64 session_id = 5; } message JobSet { repeated Job job = 1; optional InterJobReuseMemStrategy inter_job_reuse_mem_strategy = 5; } ================================================ FILE: oneflow/core/job/job_set_compile_ctx.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_SET_COMPILE_CTX_ #define ONEFLOW_CORE_JOB_SET_COMPILE_CTX_ #include "oneflow/core/job/compiler.h" #include "oneflow/core/job/job_set_compile_ctx.pb.h" namespace oneflow { class JobSetCompileCtx final { public: JobSetCompileCtx() = default; ~JobSetCompileCtx() = default; PbMap* GetVarOpName2randomSeed() { return job_set_compile_ctx_proto_.mutable_var_op_name2random_seed(); } private: JobSetCompileCtxProto job_set_compile_ctx_proto_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_SET_COMPILE_CTX_ ================================================ FILE: oneflow/core/job/job_set_compile_ctx.proto ================================================ syntax = "proto2"; package oneflow; message JobSetCompileCtxProto { map var_op_name2random_seed = 1; } ================================================ FILE: oneflow/core/job/lazy_mode.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
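Illustrative sketch, not from the original source: lazy_mode.cpp and lazy_mode.h below expose a thread-local flag behind a RAII guard; the guard restores whatever mode was active before it:

  CHECK(!LazyMode::is_enabled());          // eager by default on a fresh thread
  {
    LazyMode::Guard guard(true);           // enable lazy (graph) mode for this scope
    // ops built on this thread now observe LazyMode::is_enabled() == true
  }
  CHECK(!LazyMode::is_enabled());          // previous mode restored by the guard's destructor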
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/lazy_mode.h" namespace oneflow { /* static */ bool* LazyMode::get_mode_ptr() { static thread_local bool mode = false; return &mode; } /* static */ bool LazyMode::is_enabled() { return *get_mode_ptr(); } /* static */ void LazyMode::set_enabled(bool enabled) { *get_mode_ptr() = enabled; } } // namespace oneflow ================================================ FILE: oneflow/core/job/lazy_mode.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_LAZY_MODE_H_ #define ONEFLOW_CORE_JOB_LAZY_MODE_H_ #include "oneflow/core/common/util.h" namespace oneflow { class LazyMode { public: OF_DISALLOW_COPY_AND_MOVE(LazyMode); LazyMode() = delete; ~LazyMode() = delete; static bool is_enabled(); class Guard { public: explicit Guard(bool enabled) : prev_mode_(LazyMode::is_enabled()) { LazyMode::set_enabled(enabled); } ~Guard() { LazyMode::set_enabled(prev_mode_); } private: bool prev_mode_; }; private: static bool* get_mode_ptr(); static void set_enabled(bool enabled); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_LAZY_MODE_H_ ================================================ FILE: oneflow/core/job/learning_rate_schedule_conf.proto ================================================ syntax = "proto2"; package oneflow; message ExponentialDecayConf { required int64 decay_batches = 1; required double decay_rate = 2; optional bool staircase = 3 [default = false]; } message InverseTimeDecayConf { required int64 decay_batches = 1; required double decay_rate = 2; optional bool staircase = 3 [default = false]; } message NaturalExpDecayConf { required int64 decay_batches = 1; required double decay_rate = 2; optional bool staircase = 3 [default = false]; } message PiecewiseConstantConf { repeated int64 boundaries = 1; repeated double values = 2; } message PolynomialDecayConf { required int64 decay_batches = 1; optional double end_learning_rate = 2 [default = 0.0001]; optional double power = 3 [default = 1.0]; optional bool cycle = 4 [default = false]; } message CosineDecayConf { required int64 decay_batches = 1; optional double alpha = 2 [default = 0.0]; } message CosineAnnealingDecayConf { required int64 t_max = 1; optional double eta_min = 2 [default = 0.0]; } message LinearCosineDecayConf { required int64 decay_batches = 1; optional double num_periods = 2 [default = 0.5]; optional double alpha = 3 [default = 0.0]; optional double beta = 4 [default = 0.001]; } message PiecewiseScalingConf { repeated int64 boundaries = 1; repeated double scales = 2; } message StepConf { required int64 step_size = 1; optional 
double gamma = 2 [default = 0.1]; } message MultiStepConf { repeated int64 milestones = 1; optional double gamma = 2 [default = 0.1]; } message LinearLRConf { required double start_factor = 1; required double end_factor = 2; required int64 total_iters = 3; } message ConstantLRConf { required double factor = 1; required int64 total_iters = 2; } message CosineAnnealingWarmRestartsConf { required int64 t_initial = 1; required int64 t_mult = 2; required double eta_min = 3; required double decay_rate = 4; required int64 restart_limit = 5; } message SequentialSchedulerConf { repeated LearningRateDecayConf schedulers = 1; repeated int64 milestones = 2; // NOTE(zwx): should be repeated bool, however it has bug in cfg repeated int32 interval_rescaling = 3; } // TODO(zwx): ChainedSchedulerConf message LearningRateDecayConf { oneof type { ExponentialDecayConf exponential_conf = 2000; InverseTimeDecayConf inverse_time_conf = 2001; NaturalExpDecayConf natural_exp_conf = 2002; PiecewiseConstantConf piecewise_constant_conf = 2003; PolynomialDecayConf polynomial_conf = 2004; CosineDecayConf cosine_conf = 2005; LinearCosineDecayConf linear_cosine_conf = 2006; PiecewiseScalingConf piecewise_scaling_conf = 2007; MultiStepConf multi_step_conf = 2008; StepConf step_conf = 2009; CosineAnnealingDecayConf cosine_annealing_conf = 2010; LinearLRConf linear_lr_conf = 2011; ConstantLRConf constant_lr_conf = 2012; CosineAnnealingWarmRestartsConf cosine_annealing_warm_restarts_conf = 2013; SequentialSchedulerConf sequential_scheduler_conf = 2014; } } ================================================ FILE: oneflow/core/job/local_parallel.proto ================================================ syntax = "proto2"; package oneflow; message LocalParallel { } message OptLocalParallel { optional LocalParallel local_parallel = 1; } message LocalSignature { map bn_in_op2opt_local_parallel = 1; } ================================================ FILE: oneflow/core/job/local_sig_infer_hint.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_MIRRORED_SIG_INFER_HINT_H_ #define ONEFLOW_CORE_JOB_MIRRORED_SIG_INFER_HINT_H_ #include "oneflow/core/job/parallel_desc.h" namespace oneflow { class LocalSigInferHint final { public: LocalSigInferHint(const ParallelDesc* parallel_desc, bool is_local_parallel_view) : parallel_desc_(parallel_desc), is_local_parallel_view_(is_local_parallel_view) {} const ParallelDesc& parallel_desc() const { return *parallel_desc_; } bool is_local_parallel_view() const { return is_local_parallel_view_; } private: const ParallelDesc* parallel_desc_; bool is_local_parallel_view_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_MIRRORED_SIG_INFER_HINT_H_ ================================================ FILE: oneflow/core/job/memory_share_strategy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
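Illustrative text-format sketch, not from the original source, of the schedule messages defined in learning_rate_schedule_conf.proto above as they would appear wherever a LearningRateDecayConf is embedded (for example the learning_rate_decay field of NormalModelUpdateOpUserConf); the milestone and gamma values are made up, and exactly one branch of the `type` oneof may be set:

  learning_rate_decay {
    multi_step_conf {
      milestones: 30
      milestones: 60
      milestones: 90
      gamma: 0.1
    }
  }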
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/memory_share_strategy.h" #include #include #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/register/runtime_register_desc.h" namespace oneflow { namespace { constexpr int32_t kMaxIterStep = 100; } // anonymous namespace bool IsLifetimeExcluded(const std::pair& a, const std::pair& b) { return a.first < b.second && b.first < a.second; } // Initialization void MemoryShareStrategy::InitRegister( const HashMap>& register2lifetime) { total_register_num_ = register2lifetime.size(); index2register_.resize(total_register_num_); int32_t register_id = 0; for (const auto& pair : register2lifetime) { index2register_[register_id] = pair.first; register_id++; } } void MemoryShareStrategy::InitRegisterInformation( const HashMap& mem_reused_regst2size) { total_register_num_ = index2register_.size(); register_size_.resize(total_register_num_); for (int32_t register_id = 0; register_id < total_register_num_; register_id++) { const auto& register_ = index2register_[register_id]; int64_t register_size = mem_reused_regst2size.at(register_); register_size_[register_id] = register_size; register2index_[register_] = register_id; } order_.resize(total_register_num_); for (int32_t i = 0; i < total_register_num_; i++) { order_[i] = i; } } // Steal a compact position as the initial strategy void MemoryShareStrategy::StealCompactPosition( const HashMap& regst_desc2offset, const HashMap& mem_reused_regst2size, const HashMap>& register2lifetime) { // Initialization InitRegister(register2lifetime); // Sort index2register_ std::sort(index2register_.begin(), index2register_.end(), [&](RegstDescProto* i, RegstDescProto* j) { return regst_desc2offset.at(i) < regst_desc2offset.at(j); }); // Update other information InitRegisterInformation(mem_reused_regst2size); left_registers_.clear(); left_registers_.resize(total_register_num_); excluded_registers_.clear(); excluded_registers_.resize(total_register_num_); // should_visit_[i] indicates whether we should visit register[i]. // should_visit_[i] = 0: should not visit i, or have already visited i.. 
// should_visit_[i] = 1: should visit i, i is excluded with j // should_visit_[i] = 2: should visit i, i is not excluded with j should_visit_.clear(); should_visit_.resize(total_register_num_, 0); register_offset_.resize(total_register_num_); // Generate a compact relationship of position // For example we have 3 relationship: x1 < x2, x2 < x3, x1 < x3 // We would delete the redundant relationship (x1 < x3) for (int32_t j = 0; j < total_register_num_; j++) { const auto& register_j = index2register_[j]; register_offset_[j] = regst_desc2offset.at(register_j); auto& excluded_register_j = excluded_registers_[j]; const auto& lifetime_j = register2lifetime.at(register_j); // Init should visit with all orders of the excluded register for (int32_t i = j + 1; i < total_register_num_; i++) { if (IsLifetimeExcluded(lifetime_j, register2lifetime.at(index2register_[i]))) { // Copy the data to excluded registers excluded_register_j.insert(i); excluded_registers_[i].insert(j); } } } for (int32_t j = 0; j < total_register_num_; j++) { ResetCompactPosition(j); } } // Generate a compact position with the order of occurrence // Not recommended void MemoryShareStrategy::GenerateCompactPosition( const HashMap& mem_reused_regst2size, const HashMap>& register2lifetime) { HashMap regst_desc2offset; int64_t offset = 0; for (const auto& pair : register2lifetime) { regst_desc2offset[pair.first] = offset; offset++; } StealCompactPosition(regst_desc2offset, mem_reused_regst2size, register2lifetime); } // Compute optimal cost with compact relationship size_t MemoryShareStrategy::ComputeOptimalCost4CompactRelationship() { int64_t mem_block_size = 0; for (int32_t i = 0; i < total_register_num_; i++) { mem_block_size = std::max(mem_block_size, ComputeOffset4CompactRelationship(i) + register_size_[i]); } mem_block_size_ = size_t(mem_block_size); return mem_block_size_; } // Compute offset with compact relationship int64_t MemoryShareStrategy::ComputeOffset4CompactRelationship(int32_t i) { if (register_offset_[i] < 0) { // An initial value x would be store as -x - 1. 
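// Illustrative example, not from the original source: with this encoding an entry whose intended
// starting value is 0 is stored as -0 - 1 = -1, and one whose intended starting value is 128 is
// stored as -128 - 1 = -129, so any negative entry means "offset not finalized yet". The decode
// below recovers that starting value and then pushes the offset past the right end
// (offset + size) of every register recorded in left_registers_[i].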
register_offset_[i] = -register_offset_[i] - 1; for (int32_t j : left_registers_[i]) { register_offset_[i] = std::max(register_offset_[i], ComputeOffset4CompactRelationship(j) + register_size_[j]); } } return register_offset_[i]; } size_t MemoryShareStrategy::ComputeOptimalAdjustedCost() { // Initial optimal cost size_t optimal_cost = ComputeOptimalCostFrom0(); // All the registers excluded with register i are sorted from left to right // std::vector order_; // auto CompareRegisterPosition = [&](int32_t i, int32_t j) { // return register_offset_[i] < register_offset_[j]; // }; backup_registers_.clear(); backup_registers_.resize(total_register_num_); // The number of steps that the optimal cost does not decrease int32_t step_no_decrease = 0; for (int32_t m = 0; m < max_iteration_step_; m++) { for (int32_t i = 0; i < total_register_num_; i++) { EliminateRegister(i); size_t cost_without_i = ComputeOptimalCostFrom0(); // Find the offset of i which has the minimum cost int64_t min_x_i = -1; if (cost_without_i < optimal_cost) { // Find the minimum cost int64_t min_cost = optimal_cost; // Back up the current register offset with elimination of i auto register_offset_backup = register_offset_; // Try to insert the register i into the sorted excluded registers HashSet all_x_i; for (int32_t j : excluded_registers_[i]) { // Insert i before j all_x_i.insert(register_offset_backup[j]); // Insert i after j all_x_i.insert(register_offset_backup[j] + register_size_[j]); } for (int64_t x_i : all_x_i) { int64_t cost_insert_i = ComputeOptimalCostWithOccupation(i, x_i, register_offset_backup); // Check if we found a smaller cost if (cost_insert_i < min_cost) { min_cost = cost_insert_i; min_x_i = x_i; if (min_cost <= cost_without_i) { break; } } } // Found a smaller cost if (min_x_i >= 0) { InsertRegister(i, min_x_i, register_offset_backup); optimal_cost = ComputeOptimalCostFrom0(); } } // Found a smaller cost if (min_x_i >= 0) { // Move to a new status with smaller cost, dump the backup of the offset. ClearBackup(); step_no_decrease = 0; } else { // Recover to the original status RecoverFromBackup(i); // Adjust the offset after recovery ComputeOptimalCostFrom0(); // Terminate it if no cost reduce for any of the adjustment. step_no_decrease++; if (step_no_decrease >= total_register_num_) { break; } } } if (step_no_decrease >= total_register_num_) { break; } } CHECK_JUST(CheckConflict()); return optimal_cost; } // Let x_i occupy some space [x_i, x_i + l_i), then we recompute the optimal cost size_t MemoryShareStrategy::ComputeOptimalCostWithOccupation( int32_t i, int64_t x_i, const std::vector& register_offset_backup) { // The end of register i. int64_t e_i = x_i + register_size_[i]; register_offset_.clear(); register_offset_.resize(total_register_num_, -1); for (int32_t k : excluded_registers_[i]) { // x_k + l_k > x_i // k is behind i if (register_offset_backup[k] + register_size_[k] > x_i) { register_offset_[k] = -e_i - 1; } else { register_offset_[k] = register_offset_backup[k]; } } register_offset_[i] = x_i; return ComputeOptimalCost4CompactRelationship(); } // Eliminate one register void MemoryShareStrategy::EliminateRegister(int32_t i) { // Init back up registers backup_registers_[i] = left_registers_[i]; for (auto j : excluded_registers_[i]) { if (register_offset_[i] < register_offset_[j]) { should_visit_.clear(); should_visit_.resize(total_register_num_, 0); // should_visit_[i] = 0: should not visit i, or have already visited i.. 
// should_visit_[i] = 1: should visit i, i is excluded with j // should_visit_[i] = 2: should visit i, i is not excluded with j // should_visit_[i] = -1: i is visited, i is excluded with j // should_visit_[i] = -2: i is visited, i is not excluded with j for (int32_t k = 0; k < total_register_num_; k++) { if (register_offset_[k] < register_offset_[j]) { if (Exclude(k, j)) { should_visit_[k] = 1; } else { should_visit_[k] = 2; } } } // Eliminate all the grandsons of the excluded registers for (int32_t k : excluded_registers_[j]) { if (should_visit_[k] == 1) { EliminateRedundantRelationshipIgnore(i, k); } } for (int32_t k : excluded_registers_[j]) { if (should_visit_[k] == -1) { if (left_registers_[j].insert(k).second) { backup_registers_[j].insert(k); } } } if (left_registers_[j].erase(i)) { backup_register_behind_i_.insert(j); } } } left_registers_[i].clear(); } // Whether i and j occurs simultaneously bool MemoryShareStrategy::Exclude(int32_t i, int32_t j) { return excluded_registers_[i].find(j) != excluded_registers_[i].end(); } // If the previous strategy has fewer cost, recover to the previous one from the backup. void MemoryShareStrategy::RecoverFromBackup(int32_t i) { for (int32_t j = 0; j < total_register_num_; j++) { if (i == j) { left_registers_[i] = backup_registers_[i]; } else { for (int32_t k : backup_registers_[j]) { left_registers_[j].erase(k); } } } for (int32_t j : backup_register_behind_i_) { left_registers_[j].insert(i); } ClearBackup(); } // Clear backup void MemoryShareStrategy::ClearBackup() { for (auto& backup_register : backup_registers_) { backup_register.clear(); } backup_register_behind_i_.clear(); } size_t MemoryShareStrategy::ComputeOptimalCostFrom0() { register_offset_.clear(); register_offset_.resize(total_register_num_, -1); return ComputeOptimalCost4CompactRelationship(); } // Insert register i at position [x_i, x_i + l_i) void MemoryShareStrategy::InsertRegister(int32_t i, int64_t x_i, const std::vector& original_register_offset) { ComputeOptimalCostWithOccupation(i, x_i, original_register_offset); std::sort(order_.begin(), order_.end(), [&](int32_t k, int32_t j) { return register_offset_[k] < register_offset_[j]; }); for (int32_t j : order_) { if (register_offset_[i] <= register_offset_[j]) { ResetCompactPosition(j); } } } // Eliminate children of j but ignore i. 
void MemoryShareStrategy::EliminateRedundantRelationshipIgnore(int32_t i, int32_t j) { // Ignore i if (i == j) { return; } if (should_visit_[j] > 0) { // Do not look into it again should_visit_[j] = -should_visit_[j]; for (int32_t k : left_registers_[j]) { EliminateRedundantRelationshipIgnore(i, k); should_visit_[k] = 0; } } } // Check whether the current offset does not introduce any conflict Maybe MemoryShareStrategy::CheckConflict() { CHECK_EQ_OR_RETURN(index2register_.size(), register_offset_.size()) << "Not equal size, we might be calling CheckConflict() at a wrong time."; for (int32_t i = 0; i < total_register_num_; i++) { CHECK_GE_OR_RETURN(register_offset_[i], 0) << "Register offset is not computed."; for (int32_t j : excluded_registers_[i]) { CHECK_OR_RETURN(register_offset_[i] + register_size_[i] <= register_offset_[j] || register_offset_[j] + register_size_[j] <= register_offset_[i]) << "Two registers overlap"; } } return Maybe::Ok(); } // Update the offset with the adjusted strategy void MemoryShareStrategy::UpdateOffset(size_t* mem_block_size, HashMap* regst_desc2offset) { size_t optimal_cost = ComputeOptimalAdjustedCost(); if (optimal_cost < *mem_block_size) { VLOG(3) << "Original cost: " << *mem_block_size << ", updated cost: " << optimal_cost; *mem_block_size = optimal_cost; for (auto& pair : *regst_desc2offset) { pair.second = register_offset_[register2index_[pair.first]]; } } } // Find all the k < i, eliminates k < j, // since k < i and i < j have already implied that. void MemoryShareStrategy::EliminateRedundantRelationship(int32_t i) { // If i is already eliminate, skip it. if (should_visit_[i]) { for (int32_t k : left_registers_[i]) { // Eliminate all the k < i EliminateRedundantRelationship(k); // Eliminate left[i] should_visit_[k] = 0; } } } // Reset the compact position for the registers void MemoryShareStrategy::ResetCompactPosition(int32_t j) { left_registers_[j].clear(); // Mark all the registers on the left for (int32_t i = 0; i < total_register_num_; i++) { if (register_offset_[i] < register_offset_[j]) { if (Exclude(i, j)) { should_visit_[i] = 1; } else { should_visit_[i] = 2; } } else { // Might be unnecessary since we clear up should_visit_ before. should_visit_[i] = 0; } } for (int32_t i = 0; i < total_register_num_; i++) { if (should_visit_[i] == 1) { // Find all the k < i, eliminates k < j, // since k < i and i < j have already implied that. // Also reset should_visit_[i] to false, // since we have already visited i. EliminateRedundantRelationship(i); } } for (int32_t i = 0; i < total_register_num_; i++) { if (should_visit_[i] == 1) { // i < j left_registers_[j].insert(i); } // Might be unnecessary since we clear up should_visit_ before. should_visit_[i] = 0; } } // Update the maximum iteration step with the current size and lower bound void MemoryShareStrategy::UpdateMaxIteration(size_t mem_block_size, size_t lower_bound) { if (lower_bound > 0) { max_iteration_step_ = ((mem_block_size - lower_bound) * 100) / lower_bound; } else { // A graph only containing several 0 size tensors might have lower bound = 0. // Check test_div.py::TestDiv::test_0_size_div for example. 
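// Illustrative arithmetic, not from the original source: in the branch above the budget grows with
// the relative gap to the lower bound, e.g. mem_block_size = 1.2 GiB and lower_bound = 1.0 GiB give
// max_iteration_step_ = (1.2 GiB - 1.0 GiB) * 100 / 1.0 GiB = 20, roughly one adjustment step per
// percentage point of potential saving; the degenerate zero-lower-bound case below disables the
// adjustment instead.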
max_iteration_step_ = 0; } // if mem_block_size is closed to the maximum number of type size_t, then we might have a negative // value for (mem_block_size - lower_bound) * 100 // In this case, we just set a large max_iteration_step_ if (max_iteration_step_ < 0) { max_iteration_step_ = kMaxIterStep; } } // Adaptively update the offset of registers to minimize the total memory void MemoryShareStrategy::AdaptivelyUpdateOffset( const HashMap& mem_reused_regst2size, const HashMap>& register2lifetime, size_t lower_bound, size_t* mem_block_size, HashMap* regst_desc2offset) { VLOG(3) << "Current memory size: " << *mem_block_size << ", lower bound : " << lower_bound; if (*mem_block_size > lower_bound) { UpdateMaxIteration(*mem_block_size, lower_bound); VLOG(3) << "max iteration step: " << max_iteration_step_; if (max_iteration_step_ > 0) { StealCompactPosition(*regst_desc2offset, mem_reused_regst2size, register2lifetime); UpdateOffset(mem_block_size, regst_desc2offset); } VLOG(3) << "After compression, memory size: " << *mem_block_size; } } // Set the offset of registers to minimize the total memory // Iterating from a random order might take a lot of steps to reach the optimal cost. // Therefore, this function is not recommended with an initial offset provided. void MemoryShareStrategy::GenerateOffset( const HashMap& mem_reused_regst2size, const HashMap>& register2lifetime, size_t* mem_block_size, HashMap* regst_desc2offset) { max_iteration_step_ = kMaxIterStep; VLOG(3) << "max iteration step: " << max_iteration_step_; GenerateCompactPosition(mem_reused_regst2size, register2lifetime); UpdateOffset(mem_block_size, regst_desc2offset); } } // namespace oneflow ================================================ FILE: oneflow/core/job/memory_share_strategy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_MEMORY_SHARE_STRATEGY_H_ #define ONEFLOW_CORE_JOB_MEMORY_SHARE_STRATEGY_H_ #include #include "oneflow/core/common/hash_container.h" #include "oneflow/core/register/register_desc.pb.h" #include "oneflow/core/common/maybe.h" namespace oneflow { // NOTE: Another trick to save times. // Comparing two numbers is faster than asking the existence in a HashSet. bool IsLifetimeExcluded(const std::pair& a, const std::pair& b); class MemoryShareStrategy { public: // Adaptively update the offset of registers to minimize the total memory void AdaptivelyUpdateOffset( const HashMap& mem_reused_regst2size, const HashMap>& register2lifetime, size_t lower_bound, size_t* mem_block_size, HashMap* regst_desc2offset); // Set the offset of registers to minimize the total memory // Iterating from a random order might take a lot of steps to reach the optimal cost. // Therefore, this function is not recommended with an initial offset provided. 
void GenerateOffset( const HashMap& mem_reused_regst2size, const HashMap>& register2lifetime, size_t* mem_block_size, HashMap* regst_desc2offset); private: size_t mem_block_size_; int32_t max_iteration_step_; std::vector register_offset_; std::vector register_size_; HashMap register2index_; std::vector index2register_; // left registers store the first registers on the left, which have smaller offsets. // For example, 1 < 2 < 3 < 5 // 2 < 4 < 5 // Then // left_registers_[1] = {} // left_registers_[2] = {1} // left_registers_[3] = {2} // left_registers_[4] = {2} // left_registers_[5] = {3, 4} // We know that 1 < 3, but 1 is not in left_registers_[3], // since we only store the first registers. std::vector> left_registers_; // Store all the registers which exist simultaneously. std::vector> excluded_registers_; // Back up the changes std::vector> backup_registers_; HashSet backup_register_behind_i_; // A buffer which implies whether we should visit a register std::vector should_visit_; int32_t total_register_num_; std::vector order_; // Mid-level interfaces // Steal a compact position as the initial strategy void StealCompactPosition( const HashMap& regst_desc2offset, const HashMap& mem_reused_regst2size, const HashMap>& register2lifetime); // Generate a compact position with the order of occurrence void GenerateCompactPosition( const HashMap& mem_reused_regst2size, const HashMap>& register2lifetime); // Update the offset with the adjusted strategy void UpdateOffset(size_t* mem_block_size, HashMap* regst_desc2offset); // Update the maximum iteration step with the current size and lower bound void UpdateMaxIteration(size_t mem_block_size, size_t lower_bound); // Initialization void InitRegister(const HashMap>& register2lifetime); void InitRegisterInformation(const HashMap& mem_reused_regst2size); // Adjust the original strategy, return the updated optimal cost size_t ComputeOptimalAdjustedCost(); // Eliminate one register void EliminateRegister(int32_t i); // Eliminate children of j but ignore i. void EliminateRedundantRelationshipIgnore(int32_t i, int32_t j); // Whether i and j occurs simultaneously bool Exclude(int32_t i, int32_t j); // If the previous strategy without the elimination of i has fewer cost, recover to the previous // one from the backup. void RecoverFromBackup(int32_t i); // Clear backup void ClearBackup(); // Let x_i occupy some space [x_i, x_i + l_i), then we recompute the optimal cost size_t ComputeOptimalCostWithOccupation(int32_t i, int64_t x_i, const std::vector& register_offset_backup); // Insert register i at position [x_i, x_i + l_i) void InsertRegister(int32_t i, int64_t x_i, const std::vector& original_register_offset); // Compute optimal cost with compact relationship size_t ComputeOptimalCost4CompactRelationship(); size_t ComputeOptimalCostFrom0(); // Compute offset with compact relationship int64_t ComputeOffset4CompactRelationship(int32_t i); // Check whether the current offset does not introduce any conflict Maybe CheckConflict(); // Reset the compact position for the registers with should_visit_ = 0 void ResetCompactPosition(int32_t j); // Find all the k < i, eliminates k < j, // since k < i and i < j have already implied that. 
void EliminateRedundantRelationship(int32_t i); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_MEMORY_SHARE_STRATEGY_H_ ================================================ FILE: oneflow/core/job/module_conf.proto ================================================ syntax = "proto2"; package oneflow; message ModuleConf { required string name = 1; repeated string ops = 2; } ================================================ FILE: oneflow/core/job/nd_sbp_infer_hint.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_ND_SBP_INFER_HINT_H_ #define ONEFLOW_CORE_JOB_ND_SBP_INFER_HINT_H_ #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/register/blob_desc.h" namespace oneflow { class NdSbpInferHint final { public: NdSbpInferHint(const ParallelDesc* parallel_desc, const BlobDesc* logical_blob_desc, const NdSbp* nd_sbp) : parallel_desc_(parallel_desc), logical_blob_desc_(logical_blob_desc), nd_sbp_(nd_sbp) {} NdSbpInferHint(const NdSbpInferHint&) = default; ~NdSbpInferHint() = default; // Getters const ParallelDesc& parallel_desc() const { return *parallel_desc_; } const BlobDesc& logical_blob_desc() const { return *logical_blob_desc_; } const NdSbp& nd_sbp() const { return *nd_sbp_; } private: const ParallelDesc* parallel_desc_; const BlobDesc* logical_blob_desc_; const NdSbp* nd_sbp_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_ND_SBP_INFER_HINT_H_ ================================================ FILE: oneflow/core/job/nd_sbp_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { std::vector GetTensorSliceView(const int64_t parallel_num, const SbpParallel& sbp_parallel, const BlobDesc& blob_desc) { const Shape& shape = blob_desc.shape(); std::vector ranges(shape.NumAxes()); FOR_RANGE(int64_t, i, 0, shape.NumAxes()) { ranges[i].mut_begin() = 0; ranges[i].mut_end() = shape.At(i); } if (shape.NumAxes() == 0) { // NOTE(chengcheng): For Scalar Tensor. 
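// Illustrative example, not from the original source: for the non-scalar case, a logical blob of
// shape (8, 4) with sbp S(0) over parallel_num = 4 is cut by BalancedSplitter into the row ranges
// [0, 2), [2, 4), [4, 6), [6, 8), one slice per rank, while broadcast (B) and partial-sum (P) give
// every rank the full [0, 8) x [0, 4) view; the scalar case noted above just gets a single [0, 1)
// range so that downstream code always sees at least one axis.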
ranges.emplace_back(0, 1); } std::vector views; views.reserve(parallel_num); if (sbp_parallel.has_partial_sum_parallel() || sbp_parallel.has_broadcast_parallel()) { FOR_RANGE(int64_t, i, 0, parallel_num) { views.emplace_back(ranges); } } else if (sbp_parallel.has_split_parallel()) { const int64_t axis = sbp_parallel.split_parallel().axis(); CHECK_LT(axis, shape.NumAxes()); const BalancedSplitter bs(shape.At(axis), parallel_num); FOR_RANGE(int64_t, i, 0, parallel_num) { if (bs.At(i).size() == 0) { views.emplace_back(); } else { ranges[axis] = bs.At(i); views.emplace_back(ranges); } } } else { UNIMPLEMENTED(); } return views; } TensorSliceView GetTensorSliceView4ParallelRank(const Shape& parallel_hierarchy, const NdSbp& nd_sbp, const Shape& logical_shape, const std::vector& parallel_rank) { std::vector ranges(logical_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, logical_shape.NumAxes()) { ranges[i].mut_begin() = 0; ranges[i].mut_end() = logical_shape.At(i); } if (parallel_hierarchy.elem_cnt() == 1) { return TensorSliceView(ranges); } if (parallel_hierarchy.NumAxes() == 1) { const SbpParallel& sbp_parallel = nd_sbp.sbp_parallel(0); if (sbp_parallel.has_split_parallel()) { const int64_t split_axis = sbp_parallel.split_parallel().axis(); CHECK_GE(split_axis, 0); CHECK_LT(split_axis, ranges.size()); const int64_t id = parallel_rank.front(); CHECK_GE(id, 0); CHECK_LT(id, parallel_hierarchy.elem_cnt()); const BalancedSplitter bs(logical_shape.At(split_axis), parallel_hierarchy.elem_cnt()); CHECK_GT(bs.At(id).size(), 0); ranges[split_axis] = bs.At(id); } } else { Shape physical_shape(logical_shape); FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) { const SbpParallel& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { const int64_t split_axis = sbp_parallel.split_parallel().axis(); CHECK_GE(split_axis, 0); CHECK_LT(split_axis, ranges.size()); CHECK_GE(ranges[split_axis].size(), parallel_hierarchy.At(i)); const BalancedSplitter bs(physical_shape.At(split_axis), parallel_hierarchy.At(i)); const auto& range = bs.At(parallel_rank.at(i)); const int64_t range_size = range.size(); const int64_t dim_start = ranges[split_axis].begin() + range.begin(); physical_shape.Set(split_axis, range_size); ranges[split_axis].mut_begin() = dim_start; ranges[split_axis].mut_end() = dim_start + range_size; } } } return TensorSliceView(ranges); } TensorSliceView GetTensorSliceView4ParallelId(const Shape& parallel_hierarchy, const NdSbp& nd_sbp, const Shape& logical_shape, int64_t parallel_id) { NdIndexOffsetHelper hierarchy_index_helper( parallel_hierarchy.dim_vec().data(), parallel_hierarchy.NumAxes()); std::vector parallel_rank(SHAPE_MAX_AXIS_SIZE); hierarchy_index_helper.OffsetToNdIndex(parallel_id, parallel_rank.data()); return GetTensorSliceView4ParallelRank(parallel_hierarchy, nd_sbp, logical_shape, parallel_rank); } std::vector GetTensorSliceView(const Shape& parallel_hierarchy, const NdSbp& nd_sbp, const Shape& logical_shape) { std::vector views; views.reserve(parallel_hierarchy.elem_cnt()); FOR_RANGE(int64_t, i, 0, parallel_hierarchy.elem_cnt()) { views.emplace_back(GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, i)); } return views; } TensorSliceView GetBroadcastTensorSliceView(const BlobDesc& blob_desc) { return TensorSliceView(blob_desc.shape()); } bool NdSbpHasPartialParallel(const NdSbp& nd_sbp) { CHECK_GT(nd_sbp.sbp_parallel_size(), 0); FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) { if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { 
return true; } } return false; }
bool NdSbpHasBroadcastParallel(const NdSbp& nd_sbp) { CHECK_GT(nd_sbp.sbp_parallel_size(), 0); FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) { if (nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { return true; } } return false; }
bool NdSbpIsAllBroadcast(const NdSbp& nd_sbp) { for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) { if (!sbp_parallel.has_broadcast_parallel()) { return false; } } return true; }
bool NdSbpIsAllPartialSum(const NdSbp& nd_sbp) { for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) { if (!sbp_parallel.has_partial_sum_parallel()) { return false; } } return true; }
bool NdSbpIsAllSplit(const NdSbp& nd_sbp, int64_t axis) { for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) { if (!(sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == axis)) { return false; } } return true; }
} // namespace oneflow
================================================ FILE: oneflow/core/job/nd_sbp_util.h ================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */
#ifndef ONEFLOW_CORE_JOB_ND_SBP_UTIL_H_ #define ONEFLOW_CORE_JOB_ND_SBP_UTIL_H_
#include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/job/sbp_parallel.h"
namespace oneflow {
std::vector<TensorSliceView> GetTensorSliceView(int64_t parallel_num, const SbpParallel& sbp_parallel, const BlobDesc& blob_desc);
std::vector<TensorSliceView> GetTensorSliceView(const Shape& parallel_hierarchy, const NdSbp& nd_sbp, const Shape& logical_shape);
TensorSliceView GetTensorSliceView4ParallelRank(const Shape& parallel_hierarchy, const NdSbp& nd_sbp, const Shape& logical_shape, const std::vector<int64_t>& parallel_rank);
TensorSliceView GetTensorSliceView4ParallelId(const Shape& parallel_hierarchy, const NdSbp& nd_sbp, const Shape& logical_shape, int64_t parallel_id);
TensorSliceView GetBroadcastTensorSliceView(const BlobDesc& blob_desc);
bool NdSbpIsAllBroadcast(const NdSbp& nd_sbp);
bool NdSbpIsAllPartialSum(const NdSbp& nd_sbp);
bool NdSbpIsAllSplit(const NdSbp& nd_sbp, int64_t axis);
bool NdSbpHasPartialParallel(const NdSbp& nd_sbp);
bool NdSbpHasBroadcastParallel(const NdSbp& nd_sbp);
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_ND_SBP_UTIL_H_
================================================ FILE: oneflow/core/job/oneflow.cpp ================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/common/constant.h" #include "oneflow/core/common/range.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/compiler.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/sub_plan.pb.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/job/oneflow.h" #include "oneflow/core/job/inter_job_mem_sharing_util.h" #include "oneflow/core/job/plan_util.h" #include "oneflow/core/operator/interface_op_util.h" #include "oneflow/core/job/critical_section_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/graph/plan_task_graph.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job_rewriter/job_completer.h" namespace std { template<> struct hash { size_t operator()(const oneflow::ParallelBlobConf& parallel_blob_conf) const { std::string serialized; parallel_blob_conf.SerializeToString(&serialized); return std::hash()(serialized); } }; } // namespace std namespace oneflow { bool operator==(const ParallelBlobConf& lhs, const ParallelBlobConf& rhs) { return BlobDesc(lhs.logical_blob_desc_conf()) == BlobDesc(rhs.logical_blob_desc_conf()) && lhs.parallel_conf() == rhs.parallel_conf() && lhs.nd_sbp() == rhs.nd_sbp(); } namespace { // There are circles in MainJob. // A MainJob is a Job like: // // wait_and_send_ids_op -> reentrant_lock_op -> case_op -> identity_op -> esac_op -> // \________________________________________________/ // // back edges esac_op -> reentrant_lock_op are linked by rewriting the plan instead of // compiling OpGraph to TaskGraph. // ReentrantLockBackEdge holds the key information of a back edge struct ReentrantLockBackEdge { std::string reentrant_lock_op_name; // back edge destination. LogicalBlobId critical_section_sink_lbi; // back edge source. 
}; std::string cluster_thrd_ids_key(const std::string& plan_name) { return plan_name + "_cluster_thrd_ids"; } std::string ctrl_regst_desc_info_key(const std::string& plan_name) { return plan_name + "_ctrl_regst_desc_info_key"; } std::string job_id2job_conf(const std::string& plan_name) { return plan_name + "_job_id2job_conf"; } std::string GetCollectiveBoxingPlanKey(const std::string& plan_name) { return plan_name + "_collective_boxing_plan"; } std::string sub_plan_key(const std::string& plan_name, int64_t machine_id, int64_t thrd_id) { return plan_name + "_" + std::to_string(machine_id) + "_" + std::to_string(thrd_id); } std::string block7chunk_key(const std::string& plan_name, int64_t machine_id) { return plan_name + "_" + std::to_string(machine_id) + "_block7chunk"; } void PushPlan(const std::string& plan_name, Plan&& plan) { HashMap> machine_id2thrd_id_set; HashMap, std::list> mchn_thrd_id2task_protos; HashMap machine_id2block7chunk; for (TaskProto& task : *plan.mutable_task()) { machine_id2thrd_id_set[task.machine_id()].insert(task.thrd_id()); mchn_thrd_id2task_protos[std::make_pair(task.machine_id(), task.thrd_id())].emplace_back( std::move(task)); } HashMap machine_id2thrd_ids; for (const auto& pair : machine_id2thrd_id_set) { CHECK(machine_id2thrd_ids.emplace(pair.first, ThrdIds()).second); std::vector thrd_id_vec(pair.second.begin(), pair.second.end()); *(machine_id2thrd_ids.at(pair.first).mutable_thrd_id()) = StdVec2PbRf(thrd_id_vec); } ClusterThrdIds cluster_thrd_ids; *(cluster_thrd_ids.mutable_machine_id2thrd_ids()) = HashMap2PbMap(machine_id2thrd_ids); Singleton::Get()->PushKV(cluster_thrd_ids_key(plan_name), cluster_thrd_ids); for (std::pair, std::list>& pair : mchn_thrd_id2task_protos) { SubPlan sub_plan; sub_plan.mutable_task()->Reserve(pair.second.size()); while (!pair.second.empty()) { sub_plan.mutable_task()->Add(std::move(pair.second.front())); pair.second.pop_front(); } Singleton::Get()->PushKV( sub_plan_key(plan_name, pair.first.first, pair.first.second), sub_plan); } for (const auto& mem_block : plan.block_chunk_list().mem_block()) { *machine_id2block7chunk[mem_block.machine_id()].add_mem_block() = mem_block; } for (const auto& chunk : plan.block_chunk_list().chunk()) { *machine_id2block7chunk[chunk.machine_id()].add_chunk() = chunk; } for (const auto& pair : machine_id2block7chunk) { Singleton::Get()->PushKV(block7chunk_key(plan_name, pair.first), pair.second); } Singleton::Get()->PushKV(ctrl_regst_desc_info_key(plan_name), plan.ctrl_regst_desc_info()); Singleton::Get()->PushKV(job_id2job_conf(plan_name), plan.job_confs()); Singleton::Get()->PushKV(GetCollectiveBoxingPlanKey(plan_name), plan.collective_boxing_plan()); } void PullPlan(const std::string& plan_name, Plan* plan) { ClusterThrdIds cluster_thrd_ids; Singleton::Get()->PullKV(cluster_thrd_ids_key(plan_name), &cluster_thrd_ids); PrintProtoToTextFile(cluster_thrd_ids, JoinPath(FLAGS_log_dir, cluster_thrd_ids_key(plan_name))); HashMap machine_id2thrd_ids; machine_id2thrd_ids = PbMap2HashMap(cluster_thrd_ids.machine_id2thrd_ids()); int64_t machine_id = GlobalProcessCtx::Rank(); auto thrd_ids_it = machine_id2thrd_ids.find(machine_id); CHECK(thrd_ids_it != machine_id2thrd_ids.end()); std::vector thrd_id_vec = PbRf2StdVec(thrd_ids_it->second.thrd_id()); for (auto thrd_id : thrd_id_vec) { SubPlan sub_plan; Singleton::Get()->PullKV(sub_plan_key(plan_name, machine_id, thrd_id), &sub_plan); plan->mutable_task()->MergeFrom(sub_plan.task()); } CtrlRegstDescInfo ctrl_regst_desc_info; 
Singleton::Get()->PullKV(ctrl_regst_desc_info_key(plan_name), &ctrl_regst_desc_info); *(plan->mutable_ctrl_regst_desc_info()) = ctrl_regst_desc_info; JobConfs job_confs; Singleton::Get()->PullKV(job_id2job_conf(plan_name), &job_confs); *(plan->mutable_job_confs()) = job_confs; Singleton::Get()->PullKV(GetCollectiveBoxingPlanKey(plan_name), plan->mutable_collective_boxing_plan()); MemBlockAndChunkList block7chunk; Singleton::Get()->PullKV(block7chunk_key(plan_name, machine_id), &block7chunk); plan->mutable_block_chunk_list()->CopyFrom(block7chunk); // pull op_attribute_info OpAttributeInfo op_attribute_info; Singleton::Get()->PullKV("op_attribute_info", &op_attribute_info); // populate op_attribute_info PlanUtil::PopulateOpAttribute(plan, op_attribute_info.job_id2op_attribute_ref_table()); } Maybe CompileCurJobOnMaster(Job* job, Plan* plan, bool need_job_complete) { const JobDesc& job_desc = GlobalJobDesc(); if (GlobalProcessCtx::IsThisProcessMaster()) { double start = GetCurTime(); if (need_job_complete) { JUST(JobCompleter::Complete(job)); } Compiler().Compile(job, plan); PlanUtil::GenMemBlockAndChunk4Plan(plan); LOG(INFO) << "\njob_id: " << job_desc.job_id() << " , job_name: " << job_desc.job_name() << " , compile time: " << (GetCurTime() - start) / 1000000000.0 << " seconds.\n"; if (Singleton::Get()->enable_debug_mode()) { TeePersistentLogStream::Create(StrCat("subplan_job_", job_desc.job_id()))->Write(*plan); } } PlanUtil::GenCollectiveBoxingPlan(job, plan); PlanUtil::GenRegisterHint(plan); return Maybe::Ok(); } void MergePlan(Plan* plan, Plan&& other) { PbRpf* dst_tasks = plan->mutable_task(); PbRpf* src_tasks = other.mutable_task(); dst_tasks->Reserve(dst_tasks->size() + src_tasks->size()); for (TaskProto& task : *src_tasks) { *(dst_tasks->Add()) = std::move(task); } plan->mutable_block_chunk_list()->MergeFrom(other.block_chunk_list()); for (const auto& pair : other.job_confs().job_id2job_conf()) { CHECK(plan->mutable_job_confs()->mutable_job_id2job_conf()->insert(pair).second); } for (const auto& pair : other.collective_boxing_plan().job_id2request_set()) { CHECK( plan->mutable_collective_boxing_plan()->mutable_job_id2request_set()->insert(pair).second); } for (auto& pair : *(other.mutable_job_id2op_attribute_ref_table())) { CHECK(plan->job_id2op_attribute_ref_table().find(pair.first) == plan->job_id2op_attribute_ref_table().end()) << "fail to merge op attribute info for job: " << pair.first; (*plan->mutable_job_id2op_attribute_ref_table())[pair.first] = std::move(pair.second); } } void MergeSubPlan(Plan* plan, std::vector&& sub_plans) { CHECK(!sub_plans.empty()); *plan = std::move(sub_plans.at(0)); FOR_RANGE(int32_t, i, 1, sub_plans.size()) { MergePlan(plan, std::move(sub_plans.at(i))); } } RegstDescProto* GetSoleDataRegstDescProto(TaskProto* task) { RegstDescProto* ret = nullptr; for (auto& pair : *task->mutable_produced_regst_desc()) { CHECK(pair.second.regst_desc_type().has_data_regst_desc()); CHECK_ISNULL(ret); ret = &pair.second; } CHECK_NOTNULL(ret); return ret; } const OperatorConf& GetSoleOpConf(Plan* plan, const TaskProto& task) { CHECK_EQ(task.exec_sequence().exec_node_size(), 1); return PlanUtil::GetOpAttribute(plan, task.job_id(), task.exec_sequence().exec_node(0).kernel_conf()) .op_conf(); } void UpdateSoleObnRegstDescId(Plan* plan, TaskProto* task) { CHECK_EQ(task->exec_sequence().exec_node_size(), 1); auto* exec_node = task->mutable_exec_sequence()->mutable_exec_node(0); const auto& obns = PlanUtil::GetOpAttribute(plan, task->job_id(), 
exec_node->kernel_conf()).output_bns(); CHECK_EQ(obns.size(), 1); int64_t regst_desc_id = GetSoleDataRegstDescProto(task)->regst_desc_id(); (*exec_node->mutable_bn_in_op2regst_desc_id())[obns.Get(0)] = regst_desc_id; } // example // given caller plan: op_A --> op_identity_tick --> op_B // given callee plan: op_src_tick --> op_C --> op_D --> op_E --> op_sink_tick // return: // op_A --> op_identity_tick --> op_C --> op_D --> op_E --> op_sink_tick --> op_B // / // op_src_tick -->/ // // note: after this function called, op_src_tick is illegal and need to be deleted from plan void LinkTickTaskProto(Plan* plan, TaskProto* identity_tick, TaskProto* src_tick, TaskProto* sink_tick) { CHECK(GetSoleOpConf(plan, *identity_tick).has_tick_conf()); CHECK(GetSoleOpConf(plan, *src_tick).has_source_tick_conf()); CHECK(GetSoleOpConf(plan, *sink_tick).has_sink_tick_conf()); RegstDescProto* id_tick_sole_regst = GetSoleDataRegstDescProto(identity_tick); RegstDescProto* src_tick_sole_regst = GetSoleDataRegstDescProto(src_tick); RegstDescProto* sink_tick_sole_regst = GetSoleDataRegstDescProto(sink_tick); sink_tick_sole_regst->set_regst_desc_id(id_tick_sole_regst->regst_desc_id()); *sink_tick_sole_regst->mutable_consumer_task_id() = id_tick_sole_regst->consumer_task_id(); UpdateSoleObnRegstDescId(plan, sink_tick); CHECK_EQ(identity_tick->machine_id(), sink_tick->machine_id()); id_tick_sole_regst->set_regst_desc_id(src_tick_sole_regst->regst_desc_id()); *id_tick_sole_regst->mutable_consumer_task_id() = src_tick_sole_regst->consumer_task_id(); UpdateSoleObnRegstDescId(plan, identity_tick); } void LinkMainPlan(Plan* plan, Plan&& main_plan, const std::vector>& identity_tick_op_names) { std::function IsInterfaceTickTockTask; { auto task_ids = std::make_shared>(); for (const auto& task : main_plan.task()) { if (task.task_type() == TaskType::kTick) { CHECK(task_ids->emplace(task.task_id()).second); } } IsInterfaceTickTockTask = [task_ids, plan](const TaskProto* task) { if (task_ids->find(task->task_id()) != task_ids->end()) { return true; } if (task->exec_sequence().exec_node_size() != 1) { return false; } const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf(); OperatorConf::OpTypeCase op_type_case = PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().op_type_case(); return op_type_case == OperatorConf::kSourceTickConf || op_type_case == OperatorConf::kSinkTickConf; }; } MergePlan(plan, std::move(main_plan)); HashMap sole_tick_op_name2sole_task; FOR_RANGE(int64_t, i, 0, plan->task_size()) { TaskProto* task = plan->mutable_task(i); if (IsInterfaceTickTockTask(task) == false) { continue; } const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf(); const auto& op_name = PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name(); CHECK(sole_tick_op_name2sole_task.emplace(op_name, task).second); } auto TaskProto4TaskId = PlanUtil::MakeGetterTaskProto4TaskId(*plan); const auto& process_ranks = Singleton::Get()->process_ranks(); FOR_RANGE(int32_t, i, 0, Singleton::Get()->CriticalSectionNum()) { const CriticalSection& cs = Singleton::Get()->GetCriticalSection(i); for (int64_t machine_id : process_ranks) { TaskProto* identity_tick = sole_tick_op_name2sole_task.at(identity_tick_op_names.at(i).at(machine_id)); LinkTickTaskProto( plan, identity_tick, sole_tick_op_name2sole_task.at(cs.machine_id2source_tick_op_name().at(machine_id)), sole_tick_op_name2sole_task.at(cs.machine_id2sink_tick_op_name().at(machine_id))); } } { // erase source_tick task_proto HashSet 
source_tick_op_names; FOR_RANGE(int32_t, i, 0, Singleton::Get()->CriticalSectionNum()) { const CriticalSection& cs = Singleton::Get()->GetCriticalSection(i); for (int64_t machine_id : process_ranks) { const auto& src_tick_op_name = cs.machine_id2source_tick_op_name().at(machine_id); CHECK(source_tick_op_names.emplace(src_tick_op_name).second); } } Erase>(*plan->mutable_task(), [&](const TaskProto& task) { if (task.task_type() == TaskType::kSourceTick) { CHECK(task.exec_sequence().exec_node_size() == 1); const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf(); const auto& op_conf = PlanUtil::GetOpAttribute(plan, task.job_id(), kernel_conf).op_conf(); CHECK(op_conf.has_source_tick_conf()); CHECK(source_tick_op_names.find(op_conf.name()) != source_tick_op_names.end()); return true; } else { return false; } }); } } void GetMemSharingOpBlobInfo(const JobBuilder& job_builder, const std::string& op_name, ParallelBlobConf* blob_conf) { std::string obn = "out"; std::string lbn; { const auto& op_conf = CHECK_JUST(job_builder.OpConf4OpName(op_name)); if (op_conf.has_variable_conf()) { lbn = op_name + "/" + op_conf.variable_conf().out(); } else if (op_conf.has_input_conf()) { lbn = op_name + "/" + op_conf.input_conf().out(); } else if (op_conf.has_output_conf()) { lbn = op_name + "/" + op_conf.output_conf().out(); } else if (op_conf.has_return_conf()) { lbn = op_name + "/" + op_conf.return_conf().out(); } else { UNIMPLEMENTED(); } } const auto& job = job_builder.job(); ParallelBlobConf ret; *blob_conf->mutable_parallel_conf() = CHECK_JUST(job_builder.ParallelConf4OpName(op_name)); *blob_conf->mutable_logical_blob_desc_conf() = job.helper().lbn2logical_blob_desc().at(lbn); *blob_conf->mutable_nd_sbp() = job.job_parallel_view_conf().op_name2nd_sbp_signature_conf().at(op_name).bn_in_op2nd_sbp().at( obn); } void FilterOpName2ParallelBlobConf( const HashSet& match, const std::vector>& jobs, HashMap* op_name2parallel_blob_conf) { FOR_RANGE(int64_t, job_id, 0, jobs.size()) { JobBuilder job_builder(jobs.at(job_id).get()); for (const OperatorConf& op_conf : jobs.at(job_id)->net().op()) { if (match.find(op_conf.op_type_case()) == match.end()) { continue; } ParallelBlobConf parallel_blob_conf; GetMemSharingOpBlobInfo(job_builder, op_conf.name(), ¶llel_blob_conf); auto iter = op_name2parallel_blob_conf->find(op_conf.name()); if (iter == op_name2parallel_blob_conf->end()) { CHECK(op_name2parallel_blob_conf->emplace(op_conf.name(), parallel_blob_conf).second); } else { CHECK(parallel_blob_conf == iter->second); } } } } void CheckNonDistributeOptimizerAvailable(const std::vector>& jobs) { bool has_job_enable_optimizer_placement_optimization = false; const auto IsEnabled = [](const Job& job) { return job.job_conf().has_train_conf() && job.job_conf().has_optimizer_placement_optimization_mode(); }; FOR_RANGE(int64_t, job_id, 0, jobs.size()) { if (IsEnabled(*jobs.at(job_id))) { has_job_enable_optimizer_placement_optimization = true; break; } } if (!has_job_enable_optimizer_placement_optimization) { return; } HashSet var_names; FOR_RANGE(int64_t, job_id, 0, jobs.size()) { if (!IsEnabled(*jobs.at(job_id))) { continue; } for (const OperatorConf& op_conf : jobs.at(job_id)->net().op()) { if (op_conf.op_type_case() != OperatorConf::kVariableConf) { continue; } if (var_names.find(op_conf.name()) == var_names.end()) { var_names.emplace(op_conf.name()); } else { // optimizer_placement_optimization jobs has a same variable in between them. 
LOG(FATAL) << "Only support optimizer_placement_optimization when jobs not sharing same variable"; } } } FOR_RANGE(int64_t, job_id, 0, jobs.size()) { if (IsEnabled(*jobs.at(job_id))) { continue; } for (const OperatorConf& op_conf : jobs.at(job_id)->net().op()) { if (op_conf.op_type_case() != OperatorConf::kVariableConf) { continue; } if (var_names.find(op_conf.name()) != var_names.end()) { // Other jobs has a same variable in optimizer_placement_optimization jobs. LOG(FATAL) << "Only support optimizer_placement_optimization when jobs not sharing same variable"; } } } } Maybe MakeMainJobComponent( const std::string& wait_and_send_ids_lbn, const Range& machine_id_range, JobBuilder* job_builder, std::vector>* identity_tick_op_names, std::vector>* cb_sink_tick_op_names) { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id_range.begin()) + ":0"); auto lock_back_edge = std::make_shared(); OperatorConf reentrant_lock_op_conf; { lock_back_edge->reentrant_lock_op_name = std::string("System-Main-ReentrantLock_") + NewUniqueId(); reentrant_lock_op_conf.set_name(lock_back_edge->reentrant_lock_op_name); auto* reentrant_lock_conf = reentrant_lock_op_conf.mutable_reentrant_lock_conf(); reentrant_lock_conf->set_start(wait_and_send_ids_lbn); // ibn "end" is set after plan generated because we don't like cycle in job reentrant_lock_conf->set_out("out"); Singleton::Get()->DumpCriticalSectionId2IntersectinIds( reentrant_lock_conf->mutable_lock_id2intersecting_lock_ids()); JUST(job_builder->AddOp(parallel_conf, reentrant_lock_op_conf)); } // critical section case op conf OperatorConf cs_case_op_conf; { cs_case_op_conf.set_name(std::string("System-Main-Case_") + NewUniqueId()); auto* cs_case_conf = cs_case_op_conf.mutable_case_conf(); cs_case_conf->set_in(reentrant_lock_op_conf.name() + "/out"); FOR_RANGE(int64_t, i, 0, Singleton::Get()->CriticalSectionNum()) { cs_case_conf->add_out(GenRepeatedBn("out", i)); } JUST(job_builder->AddOp(parallel_conf, cs_case_op_conf)); } const int64_t num_critial_sections = Singleton::Get()->CriticalSectionNum(); std::vector snk_tick_op_names; snk_tick_op_names.reserve(num_critial_sections * machine_id_range.size()); FOR_RANGE(int64_t, i, 0, num_critial_sections) { // source tick OperatorConf src_tick_op_conf; { std::string name_prefix = "System-Main-SourceTick_CriticalSection_"; src_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" + NewUniqueId()); auto* src_tick_conf = src_tick_op_conf.mutable_tick_conf(); src_tick_conf->add_tick(cs_case_op_conf.name() + "/" + GenRepeatedBn("out", i)); src_tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, src_tick_op_conf)); } auto* cur_cb_sink_tick_op_names = &cb_sink_tick_op_names->at(i); for (int64_t machine_id = machine_id_range.begin(); machine_id < machine_id_range.end(); ++machine_id) { // identity tick OperatorConf identity_tick_op_conf; { std::string name_prefix = "System-Main-Tick_CriticalSection_"; identity_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" + NewUniqueId()); auto* identity_tick_conf = identity_tick_op_conf.mutable_tick_conf(); identity_tick_conf->add_tick(src_tick_op_conf.name() + "/out"); identity_tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, identity_tick_op_conf)); auto* cur_id_tick_op_names = &identity_tick_op_names->at(i); CHECK_OR_RETURN( cur_id_tick_op_names->emplace(machine_id, identity_tick_op_conf.name()).second); } // callback { OperatorConf cb_sink_tick_op_conf; 
std::string name_prefix = "System-Main-CallbackSinkTick_"; cb_sink_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId()); auto* cb_sink_tick_conf = cb_sink_tick_op_conf.mutable_sink_tick_conf(); cb_sink_tick_conf->add_tick(identity_tick_op_conf.name() + "/out"); cb_sink_tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, cb_sink_tick_op_conf)); CHECK_OR_RETURN( cur_cb_sink_tick_op_names->emplace(machine_id, cb_sink_tick_op_conf.name()).second); } // sink tick { OperatorConf snk_tick_op_conf; std::string name_prefix = "System-Main-SinkTick_CriticalSection_"; snk_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId()); auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf(); snk_tick_conf->add_tick(identity_tick_op_conf.name() + "/out"); snk_tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf)); snk_tick_op_names.emplace_back(snk_tick_op_conf.name()); } } } // critical section esac op conf OperatorConf cs_esac_op_conf; { cs_esac_op_conf.set_name(std::string("System-Main-Esac_") + NewUniqueId()); // cs_esac_op_conf.set_pass_tag("main"); auto* cs_esac_conf = cs_esac_op_conf.mutable_esac_conf(); for (const auto& snk_tick_op_name : snk_tick_op_names) { cs_esac_conf->add_in(snk_tick_op_name + "/out"); } cs_esac_conf->set_out("out"); cs_esac_conf->set_data_type(DataType::kInt32); JUST(job_builder->AddOp(parallel_conf, cs_esac_op_conf)); } lock_back_edge->critical_section_sink_lbi.set_op_name(cs_esac_op_conf.name()); lock_back_edge->critical_section_sink_lbi.set_blob_name("out"); return lock_back_edge; } Maybe MakeCallbackNotifierSinkTick( const std::set& process_ranks, const std::vector>& cb_sink_tick_op_names, JobBuilder* job_builder, const std::function& DoEachSinkTickLbn) { const auto& MakeSinkTick = [&](const std::vector& job_cs_ids, int64_t machine_id) -> Maybe { if (job_cs_ids.size() == 1) { return cb_sink_tick_op_names.at(job_cs_ids.at(0)).at(machine_id) + "/out"; } ParallelConf machine_parallel_conf; { machine_parallel_conf.set_device_tag("cpu"); machine_parallel_conf.add_device_name("@" + std::to_string(machine_id) + ":0"); } OperatorConf snk_tick_op_conf; { std::string name_prefix = "System-Main-CallbackNotifier_CriticalSection_"; snk_tick_op_conf.set_name(name_prefix + NewUniqueId()); auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf(); for (int64_t job_cs_id : job_cs_ids) { const auto& cb_sink_tick_op_name = cb_sink_tick_op_names.at(job_cs_id).at(machine_id); snk_tick_conf->add_tick(cb_sink_tick_op_name + "/out"); } snk_tick_conf->set_out("out"); JUST(job_builder->AddOp(machine_parallel_conf, snk_tick_op_conf)); } return snk_tick_op_conf.name() + "/out"; }; ParallelConf parallel_conf; { parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0"); } for (const auto& cs_ids : Singleton::Get()->job_id2critical_section_ids()) { OperatorConf snk_tick_op_conf; { std::string name_prefix = "System-Main-CallbackNotifier_CriticalSection_"; snk_tick_op_conf.set_name(name_prefix + NewUniqueId()); snk_tick_op_conf.set_pass_tag(kMainOp); auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf(); for (int64_t machine_id : process_ranks) { snk_tick_conf->add_tick(*JUST(MakeSinkTick(cs_ids, machine_id))); } snk_tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf)); } DoEachSinkTickLbn(snk_tick_op_conf.name() + "/out"); } return Maybe::Ok(); } Maybe MakeMainJob(Job* main_job, std::vector>* identity_tick_op_names, std::vector* lock_back_edges) { JobBuilder 
job_builder(main_job); CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster()); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0"); OperatorConf wait_and_send_ids_op_conf; { wait_and_send_ids_op_conf.set_name(std::string("System-Main-WaitAndSendIds_") + NewUniqueId()); wait_and_send_ids_op_conf.set_pass_tag(kMainOp); auto* wait_and_send_ids_conf = wait_and_send_ids_op_conf.mutable_wait_and_send_ids_conf(); wait_and_send_ids_conf->set_out("out"); wait_and_send_ids_conf->set_wait_buffer_name(kBufferNameGlobalWaitJobId); wait_and_send_ids_conf->set_data_type(DataType::kInt32); auto* id_list = wait_and_send_ids_conf->mutable_id_list(); FOR_RANGE(int32_t, i, 0, Singleton::Get()->size()) { id_list->Add(); } HashSet unique_check; for (const auto& pair : *Singleton::Get()) { int64_t job_id = pair.second; CHECK_OR_RETURN(unique_check.insert(job_id).second); const auto& cs_idx = Singleton::Get()->CriticalSectionIds4JobId(job_id); *id_list->Mutable(job_id)->mutable_value() = {cs_idx.begin(), cs_idx.end()}; } JUST(job_builder.AddOp(parallel_conf, wait_and_send_ids_op_conf)); } const int64_t num_critial_sections = Singleton::Get()->CriticalSectionNum(); std::vector> cb_sink_tick_op_names; identity_tick_op_names->resize(num_critial_sections); cb_sink_tick_op_names.resize(num_critial_sections); const auto& process_ranks = Singleton::Get()->process_ranks(); for (int64_t machine_id : process_ranks) { Range sub_range(machine_id, machine_id + 1); const auto& in_lbn = wait_and_send_ids_op_conf.name() + "/out"; lock_back_edges->emplace_back(*JUST(MakeMainJobComponent( in_lbn, sub_range, &job_builder, identity_tick_op_names, &cb_sink_tick_op_names))); } OperatorConf callback_notify_esac_op_conf; { callback_notify_esac_op_conf.set_name(std::string("System-Main-Esac_") + NewUniqueId()); callback_notify_esac_op_conf.set_pass_tag(kMainOp); auto* callback_notify_esac_conf = callback_notify_esac_op_conf.mutable_esac_conf(); JUST(MakeCallbackNotifierSinkTick( process_ranks, cb_sink_tick_op_names, &job_builder, [&](const std::string& lbn) { callback_notify_esac_conf->add_in(lbn); })); callback_notify_esac_conf->set_out("out"); callback_notify_esac_conf->set_data_type(DataType::kInt32); JUST(job_builder.AddOp(parallel_conf, callback_notify_esac_op_conf)); } OperatorConf callback_notify_op_conf; { callback_notify_op_conf.set_name(std::string("System-Main-CallbackNotify_") + NewUniqueId()); callback_notify_op_conf.set_pass_tag(kMainOp); auto* callback_notify_conf = callback_notify_op_conf.mutable_callback_notify_conf(); callback_notify_conf->set_in(callback_notify_esac_op_conf.name() + "/out"); auto* buffer_names = callback_notify_conf->mutable_callback_buffer_name(); FOR_RANGE(int64_t, i, 0, Singleton::Get()->size()) { buffer_names->Add(); } for (const auto& pair : *Singleton::Get()) { int64_t job_id = pair.second; const auto& buffer_name = GetCallbackNotifierBufferName(pair.first); *buffer_names->Mutable(job_id) = buffer_name; } JUST(job_builder.AddOp(parallel_conf, callback_notify_op_conf)); } auto* job_conf = main_job->mutable_job_conf(); job_conf->set_job_name("MainJob-unamed"); job_conf->mutable_predict_conf(); job_conf->set_default_data_type(DataType::kInt32); return Maybe::Ok(); } Maybe ConnectCriticalSectionEndToReentrantLockEnd( Plan* main_plan, const ReentrantLockBackEdge& lock_back_edge) { TaskProto* reentrant_lock_task = nullptr; TaskProto* cs_sink_task = nullptr; FOR_RANGE(int64_t, i, 0, main_plan->task_size()) { auto* task = main_plan->mutable_task(i); 
CHECK_EQ_OR_RETURN(task->exec_sequence().exec_node_size(), 1); const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf(); const auto& op_name = PlanUtil::GetOpAttribute(main_plan, task->job_id(), kernel_conf).op_conf().name(); if (op_name == lock_back_edge.reentrant_lock_op_name) { CHECK_ISNULL_OR_RETURN(reentrant_lock_task); reentrant_lock_task = task; } else if (op_name == lock_back_edge.critical_section_sink_lbi.op_name()) { CHECK_ISNULL_OR_RETURN(cs_sink_task); cs_sink_task = task; } else { // do nothing } } CHECK_NOTNULL_OR_RETURN(reentrant_lock_task); CHECK_NOTNULL_OR_RETURN(cs_sink_task); RegstDescProto* cs_end_regst = PlanUtil::GetSoleProducedDataRegst(cs_sink_task); cs_end_regst->add_consumer_task_id(reentrant_lock_task->task_id()); reentrant_lock_task->mutable_consumed_regst_desc_id()->at("in").add_regst_desc_id( cs_end_regst->regst_desc_id()); auto* reentrant_exec_node = reentrant_lock_task->mutable_exec_sequence()->mutable_exec_node(0); (*reentrant_exec_node->mutable_bn_in_op2regst_desc_id())["end"] = cs_end_regst->regst_desc_id(); auto* op_attribute = reentrant_exec_node->mutable_kernel_conf()->mutable_op_attribute(); op_attribute->add_input_bns("end"); (*op_attribute->mutable_arg_signature()->mutable_bn_in_op2lbi())["end"] = lock_back_edge.critical_section_sink_lbi; const auto& blob_desc_signature_map = op_attribute->logical_blob_desc_signature().bn_in_op2blob_desc(); const auto it = blob_desc_signature_map.find("start"); CHECK_OR_RETURN(it != blob_desc_signature_map.end()); CHECK_OR_RETURN(blob_desc_signature_map.find("end") == blob_desc_signature_map.end()); (*op_attribute->mutable_logical_blob_desc_signature()->mutable_bn_in_op2blob_desc())["end"] = it->second; auto* reentrant_lock_conf = op_attribute->mutable_op_conf()->mutable_reentrant_lock_conf(); reentrant_lock_conf->set_end(GenLogicalBlobName(lock_back_edge.critical_section_sink_lbi)); return Maybe::Ok(); } Maybe CompileMainJob(Job* main_job, const std::vector& lock_back_edges, int64_t job_id, Plan* main_plan) { CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster()); { auto scope = std::make_unique(main_job->job_conf(), job_id); JUST(CompileCurJobOnMaster(main_job, main_plan, false)); } for (const auto& lock_back_edge : lock_back_edges) { JUST(ConnectCriticalSectionEndToReentrantLockEnd(main_plan, lock_back_edge)); } return Maybe::Ok(); } void AddJobName2JobId(const std::string& job_name, int64_t job_id) { if (!GlobalProcessCtx::IsThisProcessMaster()) { return; } CHECK(Singleton::Get()->emplace(job_name, job_id).second); } bool NeedAllocateMemory(const RegstDescTypeProto& regst_desc_type) { return regst_desc_type.has_data_regst_desc(); } void FinishGlobalCriticalSectionDesc(const Plan& plan, int64_t job_size) { std::vector>> job_id2sole_op_name2mem_block_ids(job_size); std::vector> job_id2mem_block_ids(job_size); std::vector> job_id2chunk_ids(job_size); for (const auto& task : plan.task()) { if (task.exec_sequence().exec_node_size() == 1) { const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf(); const std::string& op_name = PlanUtil::GetOpAttribute(&plan, task.job_id(), kernel_conf).op_conf().name(); HashSet* mem_block_ids = &(job_id2sole_op_name2mem_block_ids.at(task.job_id())[op_name]); for (const auto& pair : task.produced_regst_desc()) { if (NeedAllocateMemory(pair.second.regst_desc_type())) { mem_block_ids->emplace(pair.second.mem_block_id()); } if (pair.second.has_separated_header_mem_block_id() && pair.second.separated_header_mem_block_id() != -1) { 
mem_block_ids->emplace(pair.second.separated_header_mem_block_id()); } } } } for (const auto& mem_block : plan.block_chunk_list().mem_block()) { if (mem_block.mem_size() == 0) { continue; } for (int64_t job_id : mem_block.job_id()) { job_id2mem_block_ids.at(job_id).insert(mem_block.mem_block_id()); } } for (const auto& chunk : plan.block_chunk_list().chunk()) { if (chunk.mem_size() == 0) { continue; } for (int64_t job_id : chunk.job_id()) { job_id2chunk_ids.at(job_id).insert(chunk.chunk_id()); } } HashMap> job_id2input_output_mem_block_ids; auto* critical_section_desc = Singleton::Get(); // set mem_block_id for InputOutputCriticalSection FOR_RANGE(int64_t, i, 0, critical_section_desc->CriticalSectionNum()) { auto* critical_section = critical_section_desc->MutCriticalSection(i); int64_t job_id = critical_section->job_id(); auto* input_output_mem_block_ids = &job_id2input_output_mem_block_ids[job_id]; if (critical_section->has_input_output_critical_section()) { HashSet mem_block_ids; for (const auto& op_name : critical_section->input_output_critical_section().lbi_producer_op_name()) { const auto& cur_mem_block_ids = job_id2sole_op_name2mem_block_ids.at(job_id).at(op_name); mem_block_ids.insert(cur_mem_block_ids.begin(), cur_mem_block_ids.end()); } *critical_section->mutable_mem_block_id() = {mem_block_ids.begin(), mem_block_ids.end()}; input_output_mem_block_ids->insert(mem_block_ids.begin(), mem_block_ids.end()); } else { CHECK(critical_section->has_total_job_critical_section()); } } HashSet unique_job_id_check; // set mem_block_id for TotalJobCriticalSection FOR_RANGE(int64_t, i, 0, critical_section_desc->CriticalSectionNum()) { auto* critical_section = critical_section_desc->MutCriticalSection(i); int64_t job_id = critical_section->job_id(); const auto& input_output_mem_block_ids = job_id2input_output_mem_block_ids.at(job_id); if (critical_section->has_total_job_critical_section()) { CHECK(unique_job_id_check.emplace(job_id).second); auto* mem_block_ids = &job_id2mem_block_ids.at(job_id); { // exclude input/output criticalsection mem_blob_ids from total_job auto it = mem_block_ids->begin(); while (it != mem_block_ids->end()) { if (input_output_mem_block_ids.find(*it) == input_output_mem_block_ids.end()) { ++it; } else { it = mem_block_ids->erase(it); } } } *critical_section->mutable_mem_block_id() = {mem_block_ids->begin(), mem_block_ids->end()}; *critical_section->mutable_chunk_id() = {job_id2chunk_ids.at(job_id).begin(), job_id2chunk_ids.at(job_id).end()}; } } critical_section_desc->Done(); } REGISTER_FUNCTION_CONFIG_DEF().Bool("__is_user_function__", true, "is user defined function"); Maybe CompileJobsAndMergePlans(const PbRpf& job_confs, Plan& plan) { std::vector> jobs(job_confs.size()); FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(job_confs.Get(i))); } // These checks donot work in nn.Graph API because there is only on job compile each time. // And nn.Graph Support training and evaluation share the same variable. 
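// Overview of the remaining steps of this function:
//   1. complete and compile each user job into its own sub-plan (CompileCurJobOnMaster);
//   2. merge the sub-plans into one Plan and apply inter-job memory sharing/reuse;
//   3. finalize the global critical sections, then build and compile the main job
//      (wait_and_send_ids -> reentrant_lock -> case -> per-critical-section ticks -> esac -> callback_notify);
//   4. stitch the main plan into the merged plan via LinkMainPlan and clean up useless mem blocks.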
if (jobs.size() > 1) { CheckNonDistributeOptimizerAvailable(jobs); } HashMap var_op_name2parallel_blob_conf; FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, jobs, &var_op_name2parallel_blob_conf); std::vector> function_jobs; function_jobs.reserve(jobs.size()); FOR_RANGE(int, i, 0, jobs.size()) { JobDesc job_desc(jobs.at(i)->job_conf(), i); if (job_desc.Bool("__is_user_function__")) { function_jobs.emplace_back(jobs.at(i)); } } std::vector sub_plans(jobs.size()); FOR_RANGE(int64_t, i, 0, jobs.size()) { AddJobName2JobId(jobs.at(i)->job_conf().job_name(), i); auto scope = std::make_unique(jobs.at(i)->job_conf(), i); JUST(CompileCurJobOnMaster(jobs.at(i).get(), &sub_plans.at(i), true)); } MergeSubPlan(&plan, std::move(sub_plans)); InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(function_jobs, &plan); InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, &plan); PlanUtil::SetForceInplaceMemBlock(&plan); FinishGlobalCriticalSectionDesc(plan, jobs.size()); Plan main_plan; std::vector> identity_tick_op_names; { Job main_job; std::vector lock_back_edges; JUST(MakeMainJob(&main_job, &identity_tick_op_names, &lock_back_edges)); AddJobName2JobId(main_job.job_conf().job_name(), jobs.size()); JUST(CompileMainJob(&main_job, lock_back_edges, jobs.size(), &main_plan)); } LinkMainPlan(&plan, std::move(main_plan), identity_tick_op_names); PlanUtil::CleanUselessMemBlockAndCheckValid(&plan); PlanUtil::DumpCtrlRegstInfoToPlan(&plan); PlanUtil::PlanMemoryLog(&plan, "merged_plan"); if (Singleton::Get()->enable_debug_mode()) { TeePersistentLogStream::Create("merged_plan")->Write(plan); PlanUtil::ToDotFile(plan, "/dot/merged_plan.dot"); } return Maybe::Ok(); } Maybe CompileJobsAndPushMergedPlan(const PbRpf& job_confs) { if (GlobalProcessCtx::IsThisProcessMaster()) { Plan plan; JUST(CompileJobsAndMergePlans(job_confs, plan)); double start = GetCurTime(); // push op_attribute_info OpAttributeInfo op_attribute_info; *op_attribute_info.mutable_job_id2op_attribute_ref_table() = plan.job_id2op_attribute_ref_table(); Singleton::Get()->PushKV("op_attribute_info", op_attribute_info); // push plan PushPlan("merged_plan", std::move(plan)); LOG(INFO) << " PushPlan merged_plan time: " << (GetCurTime() - start) / 1e9 << " seconds.\n"; } OF_SESSION_BARRIER(); return Maybe::Ok(); } } // namespace Maybe Oneflow::Init(const oneflow::JobSet& job_set) { OF_PROFILER_RANGE_GUARD("Oneflow::Init"); // Runtime OF_PROFILER_RANGE_PUSH("CompileJobsAndPushMergedPlan"); JUST(CompileJobsAndPushMergedPlan(job_set.job())); OF_PROFILER_RANGE_POP(); // CompileJobsAndPushMergedPlan double start = GetCurTime(); PullPlan("merged_plan", &plan_); LOG(INFO) << " PullPlan merged_plan time: " << (GetCurTime() - start) / 1e9 << " seconds.\n"; if (GlobalProcessCtx::IsThisProcessMaster()) { runtime_buffers_scope_.reset(new RuntimeBuffersScope(plan_.job_confs())); } OF_PROFILER_RANGE_PUSH("new Runtime"); HashMap variable_op_name2eager_blob_object; runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object)); OF_PROFILER_RANGE_POP(); // new Runtime return Maybe::Ok(); } Oneflow::~Oneflow() { if (GlobalProcessCtx::IsThisProcessMaster()) { runtime_buffers_scope_.reset(); } runtime_.reset(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/oneflow.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_ONEFLOW_H_ #define ONEFLOW_CORE_JOB_ONEFLOW_H_ #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/control/ctrl_server.h" #include "oneflow/core/job/runtime.h" #include "oneflow/core/job/runtime_buffers_scope.h" #include "oneflow/core/job/inter_user_job_info.pb.h" namespace oneflow { class Oneflow final { public: OF_DISALLOW_COPY_AND_MOVE(Oneflow); Oneflow() {} ~Oneflow(); Maybe Init(const oneflow::JobSet& job_set); private: Plan plan_; std::unique_ptr runtime_buffers_scope_; std::unique_ptr runtime_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_ONEFLOW_H_ ================================================ FILE: oneflow/core/job/parallel_conf_signature.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/job/placement.proto"; message ParallelConfSignature { optional ParallelConf op_parallel_conf = 1; map bn_in_op2parallel_conf = 2; } ================================================ FILE: oneflow/core/job/parallel_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/parallel_conf_util.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace { int64_t GetDeviceCount(DeviceType device_type) { return Singleton::Get()->GetDeviceCount(device_type); } using MachineId2DeviceIdList = std::shared_ptr>>>; bool GlobalDeviceIdsContaining(const MachineId2DeviceIdList& bigger, const MachineId2DeviceIdList& smaller) { for (const auto& pair : *smaller) { if (bigger->find(pair.first) == bigger->end()) { return false; } const auto& bigger_device_ids = bigger->find(pair.first)->second; std::vector::iterator ret; for (int64_t device_id : *pair.second) { ret = std::find(bigger_device_ids->begin(), bigger_device_ids->end(), device_id); if (ret == bigger_device_ids->end()) { return false; } } } return true; } } // namespace Maybe> ParseDeviceNameConf(const std::string& device_name) { size_t delimiter_pos = device_name.rfind(":"); CHECK_NE_OR_RETURN(delimiter_pos, std::string::npos); int64_t mchn_id = oneflow_cast(device_name.substr(0, delimiter_pos)); std::string device_id_str = device_name.substr(delimiter_pos + 1); return std::make_pair(mchn_id, device_id_str); } Maybe ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf) { ParallelDesc parallel_desc; JUST(parallel_desc.MaybeInit(parallel_conf)); auto machine2device_list = std::make_shared(); auto* features = machine2device_list->mutable_feature(); for (int64_t machine_id : parallel_desc.sorted_machine_ids()) { Int32List* device_id_list = (*features)[std::to_string(machine_id)].mutable_int32_list(); for (int64_t device_id : parallel_desc.sorted_dev_phy_ids(machine_id)) { device_id_list->add_value(device_id); } } return machine2device_list; } ParallelDesc::ParallelDesc(const ParallelConf& user_conf) : symbol_id_(NullOpt) { // NOLINT CHECK_JUST(MaybeInit(user_conf)); } Maybe ParallelDesc::New(int64_t symbol_id, const ParallelConf& parallel_conf) { std::shared_ptr parallel_desc(new ParallelDesc(symbol_id)); JUST(parallel_desc->MaybeInit(parallel_conf)); return parallel_desc; } Maybe ParallelDesc::New(const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy) { const auto parallel_conf = JUST(MakeParallelConf(device_tag, machine_device_ids, hierarchy)); std::shared_ptr parallel_desc; JUST(PhysicalRun([¶llel_desc, ¶llel_conf](InstructionsBuilder* builder) -> Maybe { parallel_desc = JUST(builder->GetParallelDescSymbol(*parallel_conf)); return Maybe::Ok(); })); return parallel_desc; } Maybe ParallelDesc::MaybeInit(const ParallelConf& user_conf) { parallel_conf_ = user_conf; device_type_ = DeviceType::kInvalidDevice; const std::string& device_tag = parallel_conf_.device_tag(); DeviceType device_type = JUST(DeviceType4DeviceTag(device_tag)); CHECK_OR_RETURN(device_type_ == DeviceType::kInvalidDevice || device_type_ == device_type); device_type_ = device_type; machine_id2sorted_dev_phy_ids_ = std::make_shared>>>(); for (const std::string& device_name : parallel_conf_.device_name()) { if (device_name[0] == '@') { 
JUST(SetMachineIdAndDeviceIdsByParsingDeviceName(device_name.substr(1), 1)); } else { JUST(SetMachineIdAndDeviceIdsByParsingDeviceName(device_name, GlobalProcessCtx::NumOfProcessPerNode())); } } containing_current_rank_ = machine_id2sorted_dev_phy_ids_->count(GlobalProcessCtx::Rank()) > 0; ClearUp(); JUST(SanityCheck()); return Maybe::Ok(); } Maybe ParallelDesc::SetMachineIdAndDeviceIdsByParsingDeviceName( const std::string& device_name, size_t cols) { auto [node_id, device_id_str] = *JUST(ParseDeviceNameConf(device_name)); int64_t minus_pos = device_id_str.find("-"); if (minus_pos == std::string::npos) { device_id_str = device_id_str + "-" + device_id_str; minus_pos = device_id_str.find("-"); } int64_t min_id = oneflow_cast(device_id_str.substr(0, minus_pos)); int64_t max_id = oneflow_cast(device_id_str.substr(minus_pos + 1)); CHECK_LE_OR_RETURN(min_id, max_id); for (int64_t dev_phy_id = min_id; dev_phy_id <= max_id; ++dev_phy_id) { int64_t mchn_id = dev_phy_id % cols + node_id * cols; if (!(*machine_id2sorted_dev_phy_ids_)[mchn_id]) { (*machine_id2sorted_dev_phy_ids_)[mchn_id] = std::make_shared>(); } (*machine_id2sorted_dev_phy_ids_)[mchn_id]->emplace_back(dev_phy_id); } return Maybe::Ok(); } Maybe ParallelDesc::ParallelId4MachineDeviceId(int64_t machine_id, int64_t device_id) const { const auto& machine_iter = machine_id2device_id2parallel_id_.find(machine_id); CHECK_OR_RETURN(machine_iter != machine_id2device_id2parallel_id_.end()); const auto& device_iter = machine_iter->second.find(device_id); CHECK_OR_RETURN(device_iter != machine_iter->second.end()); return device_iter->second; } Maybe> ParallelDesc::GetTensorDevice4CurrentProcessCtx( Optional* parallel_id) const { int64_t machine_id = 0; int64_t device_id = 0; GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id); const auto& device = JUST(Device::New(device_tag(), device_id)); int64_t parallel_id_val = -1; if (TryGetParallelId(machine_id, device_id, ¶llel_id_val)) { *parallel_id = parallel_id_val; } else { *parallel_id = Optional(); } return device; } Maybe> GetTensorDevice4CurrentProcessCtx(Symbol parallel_desc, Optional* parallel_id) { static thread_local HashMap, Optional> parallel_desc2parallel_id; static thread_local HashMap, Symbol> parallel_desc2device; auto parallel_id_iter = parallel_desc2parallel_id.find(parallel_desc); auto device_iter = parallel_desc2device.find(parallel_desc); if (device_iter == parallel_desc2device.end()) { CHECK_OR_RETURN(parallel_id_iter == parallel_desc2parallel_id.end()); Optional id_val; const auto& device_symbol = JUST(parallel_desc->GetTensorDevice4CurrentProcessCtx(&id_val)); parallel_id_iter = parallel_desc2parallel_id.emplace(parallel_desc, id_val).first; device_iter = parallel_desc2device.emplace(parallel_desc, device_symbol).first; } else { CHECK_OR_RETURN(parallel_id_iter != parallel_desc2parallel_id.end()); } *parallel_id = parallel_id_iter->second; return device_iter->second; } bool ParallelDesc::TryGetParallelId(int64_t machine_id, int64_t device_id, int64_t* parallel_id) const { const auto& machine_iter = machine_id2device_id2parallel_id_.find(machine_id); if (machine_iter == machine_id2device_id2parallel_id_.end()) { return false; } const auto& device_iter = machine_iter->second.find(device_id); if (device_iter == machine_iter->second.end()) { return false; } *parallel_id = device_iter->second; return true; } Maybe ParallelDesc::TryGetParallelId(int64_t rank, int64_t* parallel_id) const { if (!HasMachineId(rank)) { return false; } const auto& device_ids = 
sorted_dev_phy_ids(rank); CHECK_EQ_OR_RETURN(device_ids.size(), 1) << "only sole device_id supported. parallel_conf: \n" << parallel_conf().DebugString(); return TryGetParallelId(rank, JUST(VectorAt(device_ids, 0)), parallel_id); } Maybe ParallelDesc::GetParallelContext(ParallelContext* parallel_ctx, int64_t machine_id, int64_t device_id) const { parallel_ctx->set_parallel_num(parallel_num()); parallel_ctx->set_parallel_id(JUST(ParallelId4MachineDeviceId(machine_id, device_id))); return Maybe::Ok(); } bool ParallelDesc::Equals(const ParallelDesc& rhs) const { return (this == &rhs) || (device_type_ == rhs.device_type_ && sorted_machine_ids_ == rhs.sorted_machine_ids_ && EqualsMachineId2SortedDevPhyIds(rhs) && *hierarchy_ == *rhs.hierarchy_); } bool ParallelDesc::EqualsIgnoringDeviceType(const ParallelDesc& rhs) const { return sorted_machine_ids_ == rhs.sorted_machine_ids_ && EqualsMachineId2SortedDevPhyIds(rhs) && *hierarchy_ == *rhs.hierarchy_; } bool ParallelDesc::EqualsIgnoringHierarchy(const ParallelDesc& rhs) const { return (this == &rhs) || (device_type_ == rhs.device_type_ && sorted_machine_ids_ == rhs.sorted_machine_ids_ && EqualsMachineId2SortedDevPhyIds(rhs)); } bool ParallelDesc::EqualsOnlyForMachineAndDeviceIds(const ParallelDesc& rhs) const { return (this == &rhs) || (sorted_machine_ids_ == rhs.sorted_machine_ids_ && EqualsMachineId2SortedDevPhyIds(rhs)); } bool ParallelDesc::EqualsMachineId2SortedDevPhyIds(const ParallelDesc& rhs) const { for (int64_t machine_id : sorted_machine_ids_) { if (*machine_id2sorted_dev_phy_ids_->at(machine_id) != *rhs.machine_id2sorted_dev_phy_ids_->at(machine_id)) { return false; } } return true; } void ParallelDesc::ClearUp() { EraseIf>>( machine_id2sorted_dev_phy_ids_.get(), [](HashMap>>::iterator it) { return it->second->empty(); }); sorted_machine_ids_.clear(); parallel_num_ = 0; for (auto& pair : *machine_id2sorted_dev_phy_ids_) { sorted_machine_ids_.emplace_back(pair.first); SortAndRemoveDuplication((pair.second).get()); parallel_num_ += pair.second->size(); } if (parallel_conf_.has_hierarchy() && parallel_conf_.hierarchy().dim_size() != 0) { hierarchy_.reset(new Shape(parallel_conf_.hierarchy())); CHECK_EQ(hierarchy_->elem_cnt(), parallel_num_); } else { hierarchy_.reset(new Shape({parallel_num_})); hierarchy_->ToProto(parallel_conf_.mutable_hierarchy()); } SortAndRemoveDuplication(&sorted_machine_ids_); parallel_conf_.clear_device_name(); int64_t parallel_id = 0; for (int64_t machine_id : sorted_machine_ids_) { for (int64_t device_id : *machine_id2sorted_dev_phy_ids_->at(machine_id)) { parallel_conf_.add_device_name(std::string("@") + std::to_string(machine_id) + ":" + std::to_string(device_id)); CHECK_EQ(parallel_id, parallel_id2machine_id_.size()); parallel_id2machine_id_.push_back(machine_id); CHECK_EQ(parallel_id, parallel_id2device_id_.size()); parallel_id2device_id_.push_back(device_id); machine_id2device_id2parallel_id_[machine_id][device_id] = parallel_id; parallel_id += 1; } } } void ParallelDesc::set_device_type(DeviceType device_type) { if (device_type == device_type_) { return; } device_type_ = device_type; const std::string tag = *CHECK_JUST(DeviceTag4DeviceType(device_type)); parallel_conf_.set_device_tag(tag); } Maybe ParallelDesc::SanityCheck() { device_num_of_each_machine_ = -1; for (auto& pair : *machine_id2sorted_dev_phy_ids_) { if (device_num_of_each_machine_ == -1) { device_num_of_each_machine_ = pair.second->size(); } else { CHECK_EQ_OR_RETURN(device_num_of_each_machine_, pair.second->size()); } } return Maybe::Ok(); } 
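// Usage sketch with an assumed ParallelConf (illustrative values only). A device name of the
// form "@<rank>:<dev>" or "@<rank>:<min>-<max>" pins explicit ranks, while a plain
// "<node>:<min>-<max>" is expanded across the processes of that node:
//
//   ParallelConf conf;
//   conf.set_device_tag("cpu");
//   conf.add_device_name("@0:0");
//   conf.add_device_name("@1:0");
//   ParallelDesc desc(conf);
//   // desc.parallel_num() == 2; parallel ids follow sorted (machine_id, device_id) order,
//   // so CHECK_JUST(desc.ParallelId4MachineDeviceId(0, 0)) == 0
//   // and CHECK_JUST(desc.MachineId4ParallelId(1)) == 1.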
Maybe ParallelDesc::CheckDeviceIdsIsValid() const { const auto& sorted_dev_phy_ids_iter = machine_id2sorted_dev_phy_ids_->find(GlobalProcessCtx::Rank()); for (int64_t machine_id : sorted_machine_ids_) { CHECK_LT_OR_RETURN(machine_id, GlobalProcessCtx::WorldSize()) << Error::RuntimeError() << "Placement is invalid because rank must be less than world size!"; } if (sorted_dev_phy_ids_iter != machine_id2sorted_dev_phy_ids_->end()) { for (int64_t dev_phy_id : *sorted_dev_phy_ids_iter->second) { if (device_type_ == DeviceType::kCPU) { CHECK_LT_OR_RETURN(dev_phy_id, GlobalProcessCtx::NumOfProcessPerNode()) << Error::RuntimeError() << "Placement is invalid because device id must be less than num of process per node"; } else { const int64_t device_count = GetDeviceCount(device_type_); CHECK_NE_OR_RETURN(device_count, 0) << Error::RuntimeError() << "Placement is invalid because there is no device!"; int64_t device_num = std::min(GlobalProcessCtx::NumOfProcessPerNode(), device_count); CHECK_LT_OR_RETURN(dev_phy_id, device_num) << Error::RuntimeError() << "Placement is invalid because device id must be less than " << (device_count < GlobalProcessCtx::NumOfProcessPerNode() ? "num devices on node" : "num of process per node"); } } } return Maybe::Ok(); } ParallelConf ParallelDesc::GetParallelIdOnlyParallelConf(int64_t parallel_id) const { ParallelConf parallel_conf; std::string rank = std::to_string(CHECK_JUST(MachineId4ParallelId(parallel_id))); std::string device_id = std::to_string(CHECK_JUST(DeviceId4ParallelId(parallel_id))); parallel_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type()))); parallel_conf.add_device_name(std::string("@") + rank + ":" + device_id); return parallel_conf; } Maybe ParallelDesc::MachineId4ParallelId(int64_t parallel_id) const { CHECK_LT_OR_RETURN(parallel_id, parallel_id2machine_id_.size()) << "parallel_id: " << parallel_id << "\n----[ parallel_conf ]----" << parallel_conf().DebugString(); return parallel_id2machine_id_.at(parallel_id); } Maybe ParallelDesc::DeviceId4ParallelId(int64_t parallel_id) const { CHECK_LT_OR_RETURN(parallel_id, parallel_id2device_id_.size()) << "parallel_id: " << parallel_id << "\n----[ parallel_conf ]----" << parallel_conf().DebugString(); return parallel_id2device_id_.at(parallel_id); } bool ParallelDesc::ContainingMachineId(int64_t machine_id) const { return machine_id2sorted_dev_phy_ids_->find(machine_id) != machine_id2sorted_dev_phy_ids_->end(); } bool ParallelDesc::Containing(int64_t machine_id, int64_t device_id) const { const auto& machine_iter = machine_id2sorted_dev_phy_ids_->find(machine_id); if (machine_iter == machine_id2sorted_dev_phy_ids_->end()) { return false; } const auto& vec = machine_iter->second; return std::find(vec->begin(), vec->end(), device_id) != vec->end(); } bool ParallelDesc::Bigger(const ParallelDesc& rhs) const { if (device_tag() != rhs.device_tag()) { return false; } return GlobalDeviceIdsContaining(machine_id2sorted_dev_phy_ids_, rhs.machine_id2sorted_dev_phy_ids()); } std::tuple GetPartIdAndPartNumFromParallelCtx( const ParallelContext* parallel_ctx) { return std::make_tuple(parallel_ctx->parallel_id(), parallel_ctx->parallel_num()); } ParallelConf GenParallelConfOfCpuZeroOnMaster() { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0"); return parallel_conf; } ParallelConf GenParallelConfOfCpuZeroOnAllMachines() { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); for (int64_t i : Singleton::Get()->process_ranks()) { 
parallel_conf.add_device_name(std::string("@") + std::to_string(i) + ":0"); } return parallel_conf; } ParallelConf GenParallelConfOfCpuOnAllRanks() { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); int64_t node_size = GlobalProcessCtx::NodeSize(); int64_t device_num = GlobalProcessCtx::NumOfProcessPerNode(); for (int64_t node_id = 0; node_id < node_size; ++node_id) { parallel_conf.add_device_name(std::to_string(node_id) + ":0-" + std::to_string(device_num - 1)); } return parallel_conf; } namespace { Maybe> CalcParallelId4CurrentProcessCtx(Symbol parallel_desc) { int64_t machine_id = 0; int64_t device_id = 0; GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id); int64_t parallel_id = -1; if (parallel_desc->TryGetParallelId(machine_id, device_id, ¶llel_id)) { return Optional(parallel_id); } else { return Optional(); } } Maybe CalcParallelContext4CurrentProcessCtx( Symbol parallel_desc) { int64_t machine_id = 0; int64_t device_id = 0; GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id); int64_t parallel_id_val = -1; CHECK_OR_RETURN(parallel_desc->TryGetParallelId(machine_id, device_id, ¶llel_id_val)); std::shared_ptr parallel_ctx = std::make_shared(); parallel_ctx->set_parallel_id(parallel_id_val); parallel_ctx->set_parallel_num(parallel_desc->parallel_num()); return std::shared_ptr(parallel_ctx); } Maybe> RawReplaceDeviceType(Symbol parallel_desc, DeviceType device_type) { ParallelConf parallel_conf(parallel_desc->parallel_conf()); parallel_conf.set_device_tag(*JUST(DeviceTag4DeviceType(device_type))); return SymbolOf(ParallelDesc(parallel_conf)); } Maybe RanksToString(int64_t axis, const int64_t* ranks, const Shape& shape) { if (axis == shape.NumAxes()) { return std::to_string(*ranks); } int64_t stride = shape.Count(axis) / shape.At(axis); std::string str = "["; for (int i = 0; i < shape.At(axis); ++i) { str += *JUST(RanksToString(axis + 1, ranks, shape)); ranks += stride; if (i != shape.At(axis) - 1) { str += ", "; } } str += "]"; return str; } Maybe RawPlacementToString(Symbol placement) { const std::string& device_type = placement->device_tag(); std::vector sorted_node_ids; sorted_node_ids.reserve(placement->sorted_machine_ids().size()); HashMap> node_id2sorted_dev_phy_ids; for (int64_t machine_id : placement->sorted_machine_ids()) { int64_t node_id = GlobalProcessCtx::NodeId(machine_id); if (!std::count(sorted_node_ids.begin(), sorted_node_ids.end(), node_id)) { sorted_node_ids.emplace_back(node_id); } for (int64_t device_id : placement->sorted_dev_phy_ids(machine_id)) { node_id2sorted_dev_phy_ids[node_id].emplace_back(device_id); } } std::vector ranks; for (int64_t node_id : sorted_node_ids) { for (int64_t device_id : node_id2sorted_dev_phy_ids.at(node_id)) { ranks.emplace_back(node_id * GlobalProcessCtx::NumOfProcessPerNode() + device_id); } } CHECK_EQ_OR_RETURN(ranks.size(), placement->hierarchy()->elem_cnt()) << "rank size is " << ranks.size() << ", but shape is " << placement->hierarchy()->ToString(); const auto& ranks_str = JUST(RanksToString(0, ranks.data(), *placement->hierarchy())); return "oneflow.placement(type=\"" + device_type + "\", ranks=" + *ranks_str + ")"; } Maybe> RawGetTensorDevice(Symbol parallel_desc) { int64_t machine_id = 0; int64_t device_id = 0; GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id); const auto& type = parallel_desc->device_tag(); return JUST(Device::New(type, device_id)); } Maybe> RawTxtStringToPlacement(const std::string& parallel_conf_str) { ParallelConf 
Maybe<Symbol<ParallelDesc>> RawTxtStringToPlacement(const std::string& parallel_conf_str) {
  ParallelConf parallel_conf;
  CHECK_OR_RETURN(TxtString2PbMessage(parallel_conf_str, &parallel_conf));
  return SymbolOf(ParallelDesc(parallel_conf));
}

Maybe<void> RawCheckDeviceIdsIsValid(Symbol<ParallelDesc> placement) {
  JUST(placement->CheckDeviceIdsIsValid());
  return Maybe<void>::Ok();
}

Maybe<Symbol<ParallelDesc>> RawGetParallelDescOfThisRank(const std::string& device_tag) {
  ParallelConf parallel_conf;
  parallel_conf.set_device_tag(device_tag);
  parallel_conf.add_device_name(std::to_string(GlobalProcessCtx::Rank()) + ":"
                                + std::to_string(GlobalProcessCtx::LocalRank()));
  return SymbolOf(ParallelDesc(parallel_conf));
}

}  // namespace

decltype(GetParallelId4CurrentProcessCtx) GetParallelId4CurrentProcessCtx =
    DECORATE(&CalcParallelId4CurrentProcessCtx, ThreadLocal);
decltype(GetParallelContext4CurrentProcessCtx) GetParallelContext4CurrentProcessCtx =
    DECORATE(&CalcParallelContext4CurrentProcessCtx, ThreadLocal);
decltype(ReplaceDeviceType) ReplaceDeviceType = DECORATE(&RawReplaceDeviceType, ThreadLocal);
decltype(PlacementToString) PlacementToString = DECORATE(&RawPlacementToString, ThreadLocal);
decltype(GetTensorDevice) GetTensorDevice = DECORATE(&RawGetTensorDevice, ThreadLocal);
decltype(TxtStringToPlacement) TxtStringToPlacement =
    DECORATE(&RawTxtStringToPlacement, ThreadLocalCopiable);
decltype(GetParallelDescOfThisRank) GetParallelDescOfThisRank =
    DECORATE(&RawGetParallelDescOfThisRank, ThreadLocalCopiable);
decltype(CheckDeviceIdsIsValid) CheckDeviceIdsIsValid =
    DECORATE(&RawCheckDeviceIdsIsValid, ThreadLocal);

}  // namespace oneflow

================================================
FILE: oneflow/core/job/parallel_desc.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_CORE_JOB_PARALLEL_DESC_H_ #define ONEFLOW_CORE_JOB_PARALLEL_DESC_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/record/record.pb.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/cached_caller.h" namespace oneflow { class ResourceDesc; Maybe ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf); Maybe> ParseDeviceNameConf(const std::string& device_name); class ParallelContext; class Device; class ParallelDesc final { public: ~ParallelDesc() = default; ParallelDesc(const ParallelDesc&) = default; ParallelDesc(const ParallelConf& user_conf); static Maybe New(int64_t symbol_id, const ParallelConf& parallel_conf); static Maybe New(const std::string& device_tag, const std::vector& machine_device_ids, const std::shared_ptr& hierarchy); Maybe MaybeInit(const ParallelConf& user_conf); // Getters const Optional& symbol_id() const { return symbol_id_; } bool containing_current_rank() const { return containing_current_rank_; } DeviceType device_type() const { return device_type_; } const std::string& device_tag() const { return parallel_conf_.device_tag(); } std::shared_ptr>>> machine_id2sorted_dev_phy_ids() const { return machine_id2sorted_dev_phy_ids_; } bool HasMachineId(int64_t machine_id) const { return machine_id2sorted_dev_phy_ids_->find(machine_id) != machine_id2sorted_dev_phy_ids_->end(); } const std::vector& sorted_machine_ids() const { return sorted_machine_ids_; } const std::vector& sorted_dev_phy_ids(int64_t machine_id) const { return *machine_id2sorted_dev_phy_ids_->at(machine_id); } int64_t parallel_num() const { return parallel_num_; } int64_t device_num_of_each_machine() const { return device_num_of_each_machine_; } const ParallelConf& parallel_conf() const { return parallel_conf_; } const ParallelConf& data() const { return parallel_conf_; } Maybe GetParallelContext(ParallelContext* parallel_ctx, int64_t machine_id, int64_t device_id) const; std::shared_ptr hierarchy() const { return hierarchy_; } // Setters void set_device_type(DeviceType device_type); ParallelConf GetParallelIdOnlyParallelConf(int64_t parallel_id) const; bool EqualsIgnoringDeviceType(const ParallelDesc& rhs) const; bool EqualsIgnoringHierarchy(const ParallelDesc& rhs) const; bool EqualsOnlyForMachineAndDeviceIds(const ParallelDesc& rhs) const; bool Equals(const ParallelDesc& rhs) const; bool operator==(const ParallelDesc& rhs) const { return Equals(rhs); } bool operator!=(const ParallelDesc& rhs) const { return !(*this == rhs); } bool Equals(const ParallelDesc* rhs) const { return Equals(*rhs); } const std::vector& parallel_id2machine_id() const { return parallel_id2machine_id_; } const std::vector& parallel_id2device_id() const { return parallel_id2device_id_; } Maybe MachineId4ParallelId(int64_t parallel_id) const; Maybe DeviceId4ParallelId(int64_t parallel_id) const; Maybe ParallelId4MachineDeviceId(int64_t machine_id, int64_t device_id) const; Maybe> GetTensorDevice4CurrentProcessCtx(Optional* parallel_id) const; bool Containing(int64_t machine_id, int64_t device_id) const; // this api is exported to python as Containing bool Bigger(const ParallelDesc& rhs) const; bool ContainingMachineId(int64_t machine_id) const; bool TryGetParallelId(int64_t machine_id, int64_t device_id, int64_t* parallel_id) 
const; Maybe TryGetParallelId(int64_t rank, int64_t* parallel_id) const; Maybe CheckDeviceIdsIsValid() const; private: friend Maybe ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf); ParallelDesc() : symbol_id_(NullOpt) {} ParallelDesc(int64_t symbol_id) : symbol_id_(symbol_id) {} void ClearUp(); Maybe SetMachineIdAndDeviceIdsByParsingDeviceName(const std::string& device_name, size_t cols); Maybe SanityCheck(); Maybe CheckWithResourceDesc(const ResourceDesc& resource_desc); bool EqualsMachineId2SortedDevPhyIds(const ParallelDesc& rhs) const; Optional symbol_id_; DeviceType device_type_; ParallelConf parallel_conf_; std::shared_ptr hierarchy_; std::vector sorted_machine_ids_; std::shared_ptr>>> machine_id2sorted_dev_phy_ids_; int64_t parallel_num_; int64_t device_num_of_each_machine_; std::vector parallel_id2machine_id_; std::vector parallel_id2device_id_; HashMap> machine_id2device_id2parallel_id_; // cached result of ContainingMachineId(GlobalProcessCtx::Rank()) for performace optimization. bool containing_current_rank_; }; Maybe> GetTensorDevice4CurrentProcessCtx(Symbol parallel_desc, Optional* parallel_id); extern Maybe> (*GetParallelId4CurrentProcessCtx)( Symbol parallel_desc); extern Maybe (*GetParallelContext4CurrentProcessCtx)( Symbol parallel_desc); extern Maybe> (*ReplaceDeviceType)(Symbol, DeviceType); extern Maybe (*PlacementToString)(Symbol placement); extern Maybe> (*GetTensorDevice)(Symbol parallel_desc); extern Maybe> (*TxtStringToPlacement)(const std::string& parallel_conf_str); extern Maybe (*CheckDeviceIdsIsValid)(Symbol placement); extern Maybe> (*GetParallelDescOfThisRank)(const std::string& device_tag); inline bool operator==(const ParallelConf& lhs, const ParallelConf& rhs) { return ParallelDesc(lhs) == ParallelDesc(rhs); } inline bool operator!=(const ParallelConf& lhs, const ParallelConf& rhs) { return ParallelDesc(lhs) != ParallelDesc(rhs); } std::tuple GetPartIdAndPartNumFromParallelCtx( const ParallelContext* parallel_ctx); ParallelConf GenParallelConfOfCpuZeroOnMaster(); ParallelConf GenParallelConfOfCpuZeroOnAllMachines(); ParallelConf GenParallelConfOfCpuOnAllRanks(); namespace private_details { Maybe> RawReplaceDeviceType(Symbol, DeviceType); Maybe RawPlacementToString(Symbol placement); Maybe> RawTxtStringToPlacement(const std::string& parallel_conf_str); } // namespace private_details } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::ParallelDesc& pr) const { using namespace oneflow; size_t ret = 0; int i = 0; int shift_roundtrip = (sizeof(size_t) / 2); for (int machine_id : pr.sorted_machine_ids()) { int shift = i++ % shift_roundtrip; AddHash(&ret, machine_id << shift_roundtrip << shift); AddHash(&ret, pr.sorted_dev_phy_ids(machine_id).size() << shift); } AddHash(&ret, *pr.hierarchy()); return hash()(ret); } }; template<> struct hash { size_t operator()(const oneflow::ParallelConf& parallel_conf) const { return std::hash()(oneflow::ParallelDesc(parallel_conf)); } }; } // namespace std #endif // ONEFLOW_CORE_JOB_PARALLEL_DESC_H_ ================================================ FILE: oneflow/core/job/parallel_desc_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "gtest/gtest.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" namespace oneflow { namespace test { namespace { struct GlobaProcessCtxScope final { GlobaProcessCtxScope(int64_t node_size, int64_t world_size) { Singleton::New(); auto* ctx = Singleton::Get(); for (int i = 0; i < world_size; ++i) { ctx->mutable_ctrl_addr()->Add(); } ctx->set_rank(0); ctx->set_node_size(node_size); } ~GlobaProcessCtxScope() { Singleton::Delete(); } }; } // namespace TEST(ParallelDesc, continuous_1n4d) { GlobaProcessCtxScope scope(1, 4); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-3"); ParallelDesc parallel_desc(parallel_conf); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 4); } TEST(ParallelDesc, continuous_1n4d_multi_process) { GlobaProcessCtxScope scope(1, 4); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-3"); ParallelDesc parallel_desc(parallel_conf); const std::vector& machine_ids = parallel_desc.sorted_machine_ids(); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 4); ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 0), 1); ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 1), 1); ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 2), 1); ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 3), 1); } TEST(ParallelDesc, continuous_1n4d_multi_process_with_rank) { GlobaProcessCtxScope scope(1, 4); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("@0:0-3"); ParallelDesc parallel_desc(parallel_conf); const std::vector& machine_ids = parallel_desc.sorted_machine_ids(); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 4); ASSERT_EQ(machine_ids.size(), 1); ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 0), 1); } TEST(ParallelDesc, discrete_1n4d) { GlobaProcessCtxScope scope(1, 4); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-1"); parallel_conf.add_device_name("0:2-3"); ParallelDesc parallel_desc(parallel_conf); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 4); } TEST(ParallelDesc, continuous_2n8d) { GlobaProcessCtxScope scope(2, 8); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-3"); parallel_conf.add_device_name("1:0-3"); ParallelDesc parallel_desc(parallel_conf); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 8); } TEST(ParallelDesc, discrete_2n8d) { GlobaProcessCtxScope scope(2, 8); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-1"); parallel_conf.add_device_name("0:2-3"); parallel_conf.add_device_name("1:0-1"); 
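  // NOTE (illustrative sketch, not from the original source): each device_name entry is
  // "<machine>:<first>-<last>", so the four entries of this test enumerate devices 0-3 on
  // machines 0 and 1 in two discrete pieces per machine, which is why parallel_num() is still
  // expected to be 8 below.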
parallel_conf.add_device_name("1:2-3"); ParallelDesc parallel_desc(parallel_conf); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 8); } TEST(GetBroadcastGroup, naive_1n1d) { GlobaProcessCtxScope scope(1, 1); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0"); const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); const auto& map = CHECK_JUST(GetBroadcastGroup(parallel_desc, parallel_desc)); ASSERT_EQ(map->size(), 1); ASSERT_EQ(map->begin()->first, 0); ASSERT_TRUE(map->begin()->second == parallel_desc); } TEST(GetBroadcastGroup, naive_1n4d) { GlobaProcessCtxScope scope(1, 4); ParallelConf src_parallel_conf; src_parallel_conf.set_device_tag("cpu"); src_parallel_conf.add_device_name("0:0"); const auto& src_parallel_desc = SymbolOf(ParallelDesc(src_parallel_conf)); ParallelConf dst_parallel_conf; dst_parallel_conf.set_device_tag("cpu"); dst_parallel_conf.add_device_name("0:0-3"); const auto& dst_parallel_desc = SymbolOf(ParallelDesc(dst_parallel_conf)); const auto& map = CHECK_JUST(GetBroadcastGroup(src_parallel_desc, dst_parallel_desc)); ASSERT_EQ(map->size(), 4); for (int i = 0; i < 4; ++i) { const auto& iter = map->find(i); ASSERT_TRUE(iter != map->end()); ASSERT_TRUE(iter->second == dst_parallel_desc); } } TEST(GetBroadcastGroup, naive_2n8d) { GlobaProcessCtxScope scope(2, 8); ParallelConf src_parallel_conf; src_parallel_conf.set_device_tag("cpu"); src_parallel_conf.add_device_name("0:0"); src_parallel_conf.add_device_name("1:0"); const auto& src_parallel_desc = SymbolOf(ParallelDesc(src_parallel_conf)); ParallelConf dst_parallel_conf; dst_parallel_conf.set_device_tag("cpu"); dst_parallel_conf.add_device_name("0:0-3"); dst_parallel_conf.add_device_name("1:0-3"); const auto& dst_parallel_desc = SymbolOf(ParallelDesc(dst_parallel_conf)); const auto& map = CHECK_JUST(GetBroadcastGroup(src_parallel_desc, dst_parallel_desc)); ASSERT_EQ(map->size(), 8); ParallelConf first_node_parallel_conf; first_node_parallel_conf.set_device_tag("cpu"); first_node_parallel_conf.add_device_name("0:0-3"); const auto& first_node_parallel_desc = SymbolOf(ParallelDesc(first_node_parallel_conf)); for (int i = 0; i < 4; ++i) { const auto& iter = map->find(i); ASSERT_TRUE(iter != map->end()); ASSERT_TRUE(iter->second == first_node_parallel_desc); } ParallelConf second_node_parallel_conf; second_node_parallel_conf.set_device_tag("cpu"); second_node_parallel_conf.add_device_name("1:0-3"); const auto& second_node_parallel_desc = SymbolOf(ParallelDesc(second_node_parallel_conf)); for (int i = 4; i < 8; ++i) { const auto& iter = map->find(i); ASSERT_TRUE(iter != map->end()); ASSERT_TRUE(iter->second == second_node_parallel_desc); } } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/job/parallel_signature.proto ================================================ syntax = "proto2"; package oneflow; message ParallelSignature { optional int64 op_parallel_desc_symbol_id = 1; map bn_in_op2parallel_desc_symbol_id = 2; } ================================================ FILE: oneflow/core/job/pipeline_config_def.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/config_def.h" namespace oneflow { namespace { REGISTER_SCOPE_CONFIG_DEF().Int64( "pipeline_stage_id_hint", 0, "Manually marking different stages of pipelining parallelism. \n Generally speaking, different " "stages are on different devices, and these stages are connected sequentially, so that the " "whole network can be pipeline parallel."); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job/placement.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/register/logical_blob_id.proto"; import "oneflow/core/common/shape.proto"; message ParallelContext { required int64 parallel_id = 1; required int64 parallel_num = 2; } message ParallelConf { repeated string device_name = 1; required string device_tag = 2; optional ShapeProto hierarchy = 3; } message OpNameSet { repeated string op_name = 1; } message PlacementGroup { required OpNameSet op_set = 1; required ParallelConf parallel_conf = 2; } message BlobPlacementGroup { repeated LogicalBlobId lbi = 1; required ParallelConf parallel_conf = 2; } message Placement { repeated PlacementGroup placement_group = 1; repeated BlobPlacementGroup blob_placement_group = 2; } ================================================ FILE: oneflow/core/job/placement_scope.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/placement_scope.h" #include "oneflow/core/operator/operator.h" namespace oneflow { Maybe> PlacementScope::GetParallelDesc(const std::string& device_tag, const OperatorConf& op_conf) const { if (device_tag == "cpu" || IsCpuOnly(op_conf)) { return host_parallel_desc_; } else { return device_parallel_desc_; } } Maybe> PlacementScope::GetParallelDesc(const std::string& device_tag, const std::string& op_type_name) const { if (device_tag == "cpu" || IsCpuOnly(op_type_name)) { return host_parallel_desc_; } else { return device_parallel_desc_; } } Maybe> PlacementScope::GetParallelDesc(const std::string& op_type_name) const { return GetParallelDesc(device_parallel_desc_->device_tag(), op_type_name); } } // namespace oneflow ================================================ FILE: oneflow/core/job/placement_scope.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_PLACEMENT_SCOPE_H_ #define ONEFLOW_CORE_JOB_PLACEMENT_SCOPE_H_ #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { class OperatorConf; class PlacementScope final { public: PlacementScope(Symbol device_parallel_desc, Symbol host_parallel_desc) : device_parallel_desc_(device_parallel_desc), host_parallel_desc_(host_parallel_desc) {} size_t hash_value() const { return Hash(device_parallel_desc_, host_parallel_desc_); } bool operator==(const PlacementScope& other) const { return this->device_parallel_desc_ == other.device_parallel_desc_ && this->host_parallel_desc_ == other.host_parallel_desc_; } Symbol device_parallel_desc() const { return device_parallel_desc_; } Symbol host_parallel_desc() const { return host_parallel_desc_; } Maybe> GetParallelDesc(const std::string& device_tag, const OperatorConf& op_conf) const; Maybe> GetParallelDesc(const std::string& device_tag, const std::string& op_type_name) const; Maybe> GetParallelDesc(const std::string& op_type_name) const; private: Symbol device_parallel_desc_; Symbol host_parallel_desc_; }; } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::PlacementScope& val) const { return val.hash_value(); } }; } // namespace std #endif // ONEFLOW_CORE_JOB_PLACEMENT_SCOPE_H_ ================================================ FILE: oneflow/core/job/plan.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/job/task.proto"; import "oneflow/core/job/job_conf.proto"; import "oneflow/core/memory/memory_block.proto"; import "oneflow/core/graph/boxing/collective_boxing.proto"; import "oneflow/core/operator/op_attribute.proto"; message MachineIds { repeated int64 machine_id = 1; } message JobConfs { map job_id2job_conf = 1; } message CollectiveBoxingPlan { map job_id2request_set = 1; } message CtrlRegstDescInfo { map ctrl_regst_desc_id2producer_task_id = 6; } message OpAttributeRefTable { map op_name2op_attribute = 1; } message OpAttributeInfo { map job_id2op_attribute_ref_table = 1; } message Plan { repeated TaskProto task = 1; required MemBlockAndChunkList block_chunk_list = 2; required JobConfs job_confs = 4; required CollectiveBoxingPlan collective_boxing_plan= 5; required CtrlRegstDescInfo ctrl_regst_desc_info = 6; map job_id2op_attribute_ref_table = 7; } ================================================ FILE: oneflow/core/job/plan_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/constant.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/plan_util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/graph/plan_task_graph.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" #include "oneflow/core/memory/chunk_manager.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/register/runtime_register_desc.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/graph/task_node.h" namespace oneflow { RegstDescProto* PlanUtil::GetSoleProducedDataRegst(TaskProto* task_proto) { RegstDescProto* ret = nullptr; for (auto& pair : *task_proto->mutable_produced_regst_desc()) { RegstDescProto* regst_desc = &pair.second; if (regst_desc->regst_desc_type().has_data_regst_desc()) { CHECK_ISNULL(ret); CHECK_EQ(regst_desc->regst_desc_type().data_regst_desc().lbi2blob_desc_size(), 1); ret = regst_desc; } } CHECK_NOTNULL(ret); return ret; } std::function PlanUtil::MakeGetterTaskProto4TaskId(const Plan& plan) { auto task_id2task_proto = std::make_shared>(); for (const TaskProto& task_proto : plan.task()) { task_id2task_proto->emplace(task_proto.task_id(), &task_proto); } return [task_id2task_proto](int64_t task_id) { return task_id2task_proto->at(task_id); }; } namespace { void SetVariableOpNamesForVariableAndRepeatRegst(Plan* plan) { // NOTE(chengcheng): set variable_op_name before set separated header because var regst alway // separated. HashMap regst_id2var_name; for (int i = 0; i < plan->task_size(); i++) { TaskProto* task = plan->mutable_task(i); if (task->exec_sequence().exec_node_size() == 1) { const auto& op_conf = PlanUtil::GetOpAttribute(plan, task->job_id(), task->exec_sequence().exec_node(0).kernel_conf()) .op_conf(); if (op_conf.has_variable_conf()) { RegstDescProto* regst = PlanUtil::GetSoleProducedDataRegst(task); regst_id2var_name.emplace(regst->regst_desc_id(), op_conf.name()); regst->set_variable_op_name(op_conf.name()); } } } for (int i = 0; i < plan->task_size(); i++) { TaskProto* task = plan->mutable_task(i); if (task->task_type() == TaskType::kRepeat) { RegstDescProto* regst = PlanUtil::GetSoleProducedDataRegst(task); CHECK(regst->has_force_inplace_consumed_regst_desc_id()); int64_t force_inplace_regst_id = regst->force_inplace_consumed_regst_desc_id(); auto var_name_it = regst_id2var_name.find(force_inplace_regst_id); if (var_name_it != regst_id2var_name.end()) { regst->set_variable_op_name(var_name_it->second); VLOG(3) << " set var op name to repeat regst : " << regst->DebugString(); } } } } } // namespace void PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan) { SetVariableOpNamesForVariableAndRepeatRegst(plan); for (int i = 0; i < plan->task_size(); i++) { TaskProto* task = plan->mutable_task(i); for (auto& pair : *task->mutable_produced_regst_desc()) { RegstDescProto* regst_desc = &pair.second; if (regst_desc->mem_block_id() == -1) { CHECK_EQ(regst_desc->mem_block_offset(), -1); regst_desc->set_mem_block_id(Singleton::Get()->NewMemBlockId()); regst_desc->set_mem_block_offset(0); } RtRegstDesc rt_regst_desc(*regst_desc); int64_t regst_separated_size = rt_regst_desc.TotalSeparatedHeaderByteSize4AllRegst(); if (regst_separated_size > 0) { int64_t separated_mem_block_id = 
Singleton::Get()->NewMemBlockId(); regst_desc->set_separated_header_mem_block_id(separated_mem_block_id); } } } } void PlanUtil::GenMemBlockAndChunk4Plan(Plan* plan) { HashSet variable_op_names; PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(plan, variable_op_names); } namespace { void GenChunkForMultiNNGraphMemoryReuseInMultiClient( Plan* plan, HashMap>* mem_block_id2mem_block) { HashMap> mzuid2mem_blocks; for (auto& pair : *mem_block_id2mem_block) { MemBlockProto* mem_block = pair.second.get(); CHECK(mem_block->has_chunk_id() == false); CHECK(mem_block->has_chunk_offset() == false); if (mem_block->has_variable_op_name()) { continue; } if (!mem_block->enable_reuse_mem()) { continue; } // NOTE(chengcheng): // only reused mem in cuda device. // special cpu memory like OFRecord pb and TensorBuffer CANNOT reused by another plan. if (memory::IsHostMem(mem_block->mem_case())) { continue; } int64_t mem_zone_uid = memory::GetUniqueMemCaseId(mem_block->machine_id(), mem_block->mem_case()); auto it = mzuid2mem_blocks.find(mem_zone_uid); if (it == mzuid2mem_blocks.end()) { it = mzuid2mem_blocks.emplace(mem_zone_uid, HashSet()).first; } CHECK(it->second.insert(mem_block).second); } std::vector all_chunks; HashSet unique_chunk_ids; for (auto& pair : mzuid2mem_blocks) { int64_t mem_zone_uid = pair.first; std::vector exist_chunks; Singleton::Get()->GetChunkProtosByMemZoneUniqueId(mem_zone_uid, &exist_chunks); auto chunk_it = exist_chunks.begin(); auto& mem_blocks = pair.second; int64_t current_chunk_offset = 0; HashSet remain_blocks; for (auto mem_block_it = mem_blocks.begin(); mem_block_it != mem_blocks.end(); ++mem_block_it) { if (chunk_it == exist_chunks.end()) { // NOTE(chengcheng): it means that exist chunk has run out. CHECK(remain_blocks.insert(*mem_block_it).second); } else { // NOTE(chengcheng): find chunk which has enough space left. while (chunk_it != exist_chunks.end() && (current_chunk_offset + (*mem_block_it)->mem_size() > (*chunk_it)->mem_size())) { // NOTE(chengcheng): current chunk has no space left, so we move to next chunk. ++chunk_it; current_chunk_offset = 0; } if (chunk_it != exist_chunks.end()) { // NOTE(chengcheng): lucky, we find a appropriate chunk. MemBlockProto* mem_block = *mem_block_it; const ChunkProto* chunk = *chunk_it; CHECK_EQ(mem_block->machine_id(), chunk->machine_id()); CHECK(mem_block->mem_case() == chunk->mem_case()); CHECK_LE(current_chunk_offset + mem_block->mem_size(), chunk->mem_size()); CHECK_GE(current_chunk_offset, 0); // CHECK_GT(mem_block->mem_size(), 0); NOTE(chengcheng): has mem block mem size = 0 CHECK_GE(chunk->mem_size(), 0); mem_block->set_chunk_id(chunk->chunk_id()); mem_block->set_chunk_offset(current_chunk_offset); current_chunk_offset += mem_block->mem_size(); VLOG(3) << "Lazy nn.Graph Reused MemBlock :[" << mem_block->DebugString() << "] to old Chunk :[" << chunk->DebugString() << "]\n"; } else { // NOTE(chengcheng): sad, no chunk can used, so this mem block need to insert in remain. 
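          // NOTE (illustrative sketch with hypothetical sizes, not from the original source):
          // with existing chunks of sizes {1024, 512} and mem blocks of sizes {300, 800, 600},
          // block 300 is placed in chunk 0 at offset 0; block 800 fits in no remaining chunk, so
          // it (and every later block, since chunk_it never rewinds) falls through to
          // remain_blocks and is packed into one newly created chunk further down.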
CHECK(remain_blocks.insert(*mem_block_it).second); } } } for (const ChunkProto* exist_chunk : exist_chunks) { all_chunks.emplace_back(*exist_chunk); CHECK(unique_chunk_ids.insert(exist_chunk->chunk_id()).second); } if (!remain_blocks.empty()) { auto remain_block_it = remain_blocks.begin(); MemBlockProto* first_block = *remain_block_it; ChunkProto new_chunk; new_chunk.set_chunk_id(Singleton::Get()->NewChunkId()); new_chunk.set_machine_id(first_block->machine_id()); *new_chunk.mutable_mem_case() = first_block->mem_case(); new_chunk.set_mem_size(first_block->mem_size()); first_block->set_chunk_id(new_chunk.chunk_id()); first_block->set_chunk_offset(0); ++remain_block_it; VLOG(3) << "Lazy nn.Graph Add MemBlock :[" << first_block->DebugString() << "] to NewChunk :[" << new_chunk.DebugString() << "]\n"; while (remain_block_it != remain_blocks.end()) { MemBlockProto* this_block = *remain_block_it; CHECK_EQ(this_block->machine_id(), new_chunk.machine_id()); CHECK(this_block->mem_case() == new_chunk.mem_case()); this_block->set_chunk_id(new_chunk.chunk_id()); this_block->set_chunk_offset(new_chunk.mem_size()); new_chunk.set_mem_size(new_chunk.mem_size() + this_block->mem_size()); VLOG(3) << "Lazy nn.Graph Add MemBlock :[" << this_block->DebugString() << "] to NewChunk :[" << new_chunk.DebugString() << "]\n"; ++remain_block_it; } all_chunks.emplace_back(new_chunk); CHECK(unique_chunk_ids.insert(new_chunk.chunk_id()).second); Singleton::Get()->AddChunkProto(new_chunk); } } CHECK_EQ(all_chunks.size(), unique_chunk_ids.size()); for (const ChunkProto& chunk : all_chunks) { *(plan->mutable_block_chunk_list()->add_chunk()) = chunk; } } } // namespace void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job, int64_t limited_rank) { if (job.logical_chain_groups_size() == 0) { return; } HashMap> logical_chain_id2machine_id2mem_block_id; for (int64_t i = 0; i < plan->task_size(); ++i) { TaskProto* task = plan->mutable_task(i); const StreamId stream_id = PlanUtil::GetStreamId(*task); int64_t machine_id = task->machine_id(); DeviceType device_type = stream_id.device_id().device_type(); // TODO(zwx): eliminate this special 'is cpu' determine if (device_type == DeviceType::kCPU) { continue; } if (!IsValidChainId(task->chain_id())) { continue; } int64_t logical_chain_id = task->chain_id(); for (auto& pair : *(task->mutable_produced_regst_desc())) { RegstDescProto* regst_desc = &pair.second; if (regst_desc->mem_block_id() != -1 && regst_desc->enable_reuse_mem() && regst_desc->mem_case().device_type() == device_type && regst_desc->regst_desc_type().has_data_regst_desc()) { int64_t mem_block_id = regst_desc->mem_block_id(); auto* rank2blocks = &(logical_chain_id2machine_id2mem_block_id[logical_chain_id]); if (rank2blocks->find(machine_id) == rank2blocks->end()) { rank2blocks->emplace(machine_id, mem_block_id); } else { CHECK_EQ(rank2blocks->at(machine_id), mem_block_id); } } } } HashMap mem_block_id2merged_mem_block_id; for (const auto& logical_chain_group : job.logical_chain_groups()) { CHECK_GE(logical_chain_group.logical_chain_id_list_size(), 2); int64_t merged_logical_chain_id = logical_chain_group.logical_chain_id_list(0); if (limited_rank == -1) { CHECK(logical_chain_id2machine_id2mem_block_id.find(merged_logical_chain_id) != logical_chain_id2machine_id2mem_block_id.end()); } else { if (logical_chain_id2machine_id2mem_block_id.find(merged_logical_chain_id) == logical_chain_id2machine_id2mem_block_id.end()) { // Skip when doing rank compile and this logical chain group is not related to this 
rank. continue; } } const auto& merged_rank2block = logical_chain_id2machine_id2mem_block_id.at(merged_logical_chain_id); for (int64_t i = 1; i < logical_chain_group.logical_chain_id_list_size(); ++i) { int64_t this_logical_chain_id = logical_chain_group.logical_chain_id_list(i); // NOTE(chengcheng): merge mem block id by each rank CHECK(logical_chain_id2machine_id2mem_block_id.find(this_logical_chain_id) != logical_chain_id2machine_id2mem_block_id.end()); const auto& this_rank2block = logical_chain_id2machine_id2mem_block_id.at(this_logical_chain_id); for (const auto& pair : this_rank2block) { int64_t this_machine_id = pair.first; int64_t this_mem_block_id = pair.second; if (limited_rank == -1) { CHECK(merged_rank2block.find(this_machine_id) != merged_rank2block.end()); } else { if (merged_rank2block.find(this_machine_id) == merged_rank2block.end()) { continue; } } int64_t merged_mem_block_id = merged_rank2block.at(this_machine_id); CHECK(mem_block_id2merged_mem_block_id.emplace(this_mem_block_id, merged_mem_block_id) .second); VLOG(2) << " merge mem_block_id: " << this_mem_block_id << " to " << merged_mem_block_id; } } } for (int64_t i = 0; i < plan->task_size(); ++i) { TaskProto* task = plan->mutable_task(i); const StreamId stream_id = PlanUtil::GetStreamId(*task); DeviceType device_type = stream_id.device_id().device_type(); // TODO(zwx): eliminate this special 'is cpu' determine if (device_type == DeviceType::kCPU) { continue; } if (!IsValidChainId(task->chain_id())) { continue; } for (auto& pair : *(task->mutable_produced_regst_desc())) { RegstDescProto* regst_desc = &pair.second; if (regst_desc->mem_block_id() != -1 && regst_desc->enable_reuse_mem() && regst_desc->mem_case().device_type() == device_type && regst_desc->regst_desc_type().has_data_regst_desc()) { int64_t mem_block_id = regst_desc->mem_block_id(); if (mem_block_id2merged_mem_block_id.find(mem_block_id) != mem_block_id2merged_mem_block_id.end()) { // merge mem_block_id int64_t merged_mem_block_id = mem_block_id2merged_mem_block_id.at(mem_block_id); regst_desc->set_mem_block_id(merged_mem_block_id); if (VLOG_IS_ON(3)) { const auto& data_regst = regst_desc->regst_desc_type().data_regst_desc(); CHECK_GE(data_regst.lbi2blob_desc_size(), 1); const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0); std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi()); VLOG(3) << " regst: " << tensor_name << " merge mem block id " << mem_block_id << " to " << merged_mem_block_id; } } } } } } void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( Plan* plan, const HashSet& variable_op_names) { HashMap> mem_block_id2mem_block; auto IsVariableRegst = [&](const TaskProto* task, std::string* name) -> bool { if (variable_op_names.empty()) { return false; } if (task->exec_sequence().exec_node_size() != 1) { return false; } const auto& op_conf = GetOpAttribute(plan, task->job_id(), task->exec_sequence().exec_node(0).kernel_conf()) .op_conf(); if (!op_conf.has_variable_conf()) { return false; } const std::string& var_name = op_conf.name(); if (variable_op_names.find(var_name) == variable_op_names.end()) { LOG(WARNING) << " Oh no! Cannot find variable_op_name: " << var_name << " in nn.Graph Compiler bind EagerTensor with VariableOp. 
" << " \n But each variable need bind with eager tensor for init."; return false; } *name = var_name; return true; }; auto GenMemBlock4RegstIfNeed = [&](RegstDescProto* regst_desc, const TaskProto* task) { const int64_t job_id = task->job_id(); const int64_t machine_id = task->machine_id(); const int64_t thrd_id = task->thrd_id(); int64_t mem_block_id = regst_desc->mem_block_id(); int64_t mem_block_offset = regst_desc->mem_block_offset(); CHECK_NE(mem_block_id, -1); CHECK_NE(mem_block_offset, -1); std::string var_name; bool is_variable_regst = IsVariableRegst(task, &var_name); if (is_variable_regst) { CHECK(!var_name.empty()); CHECK_EQ(regst_desc->register_num(), 1); CHECK_EQ(regst_desc->min_register_num(), 1); // NOTE(xuxiaoyu): this check cannot pass when open ZeRO // CHECK_EQ(regst_desc->max_register_num(), 1) << var_name; regst_desc->set_variable_op_name(var_name); } RtRegstDesc rt_regst_desc(*regst_desc); int64_t regst_main_size = rt_regst_desc.TotalMainByteSize4AllRegst(); int64_t regst_separated_size = rt_regst_desc.TotalSeparatedHeaderByteSize4AllRegst(); auto mem_block_it = mem_block_id2mem_block.find(mem_block_id); if (mem_block_it == mem_block_id2mem_block.end()) { MemBlockProto mem_block; mem_block.set_mem_block_id(mem_block_id); mem_block.add_job_id(job_id); mem_block.set_machine_id(machine_id); *(mem_block.mutable_mem_case()) = regst_desc->mem_case(); mem_block.set_enable_reuse_mem(regst_desc->enable_reuse_mem()); mem_block.set_mem_size(regst_main_size + mem_block_offset); mem_block.set_thrd_id_hint(thrd_id); if (is_variable_regst) { mem_block.set_variable_op_name(var_name); mem_block.set_is_separated_header(false); } CHECK(mem_block_id2mem_block .emplace(mem_block.mem_block_id(), std::make_unique(mem_block)) .second); } else { MemBlockProto* mem_block = mem_block_it->second.get(); CHECK_EQ(mem_block->job_id(0), job_id); CHECK_EQ(mem_block->machine_id(), machine_id); CHECK(mem_block->mem_case() == regst_desc->mem_case()); CHECK_EQ(mem_block->enable_reuse_mem(), regst_desc->enable_reuse_mem()); if (mem_block->enable_reuse_mem()) { mem_block->set_mem_size( std::max(mem_block->mem_size(), regst_main_size + mem_block_offset)); } else { CHECK_EQ(mem_block->mem_size(), regst_main_size); CHECK_EQ(mem_block_offset, 0); } if (is_variable_regst) { mem_block->set_variable_op_name(var_name); mem_block->set_is_separated_header(false); } } if (regst_separated_size > 0) { CHECK(regst_desc->has_separated_header_mem_block_id()) << regst_desc->DebugString(); int64_t separated_mem_block_id = regst_desc->separated_header_mem_block_id(); CHECK_NE(separated_mem_block_id, -1); if (mem_block_id2mem_block.find(separated_mem_block_id) == mem_block_id2mem_block.end()) { MemBlockProto mem_block; mem_block.set_mem_block_id(separated_mem_block_id); mem_block.add_job_id(job_id); mem_block.set_machine_id(machine_id); *(mem_block.mutable_mem_case()) = memory::GetPinnedHostMemoryCase(regst_desc->mem_case()); mem_block.set_enable_reuse_mem(false); mem_block.set_mem_size(regst_separated_size); mem_block.set_thrd_id_hint(thrd_id); if (is_variable_regst) { mem_block.set_variable_op_name(var_name); mem_block.set_is_separated_header(true); } CHECK(mem_block_id2mem_block .emplace(mem_block.mem_block_id(), std::make_unique(mem_block)) .second); } else { MemBlockProto* mem_block = mem_block_id2mem_block.at(separated_mem_block_id).get(); CHECK_EQ(mem_block->job_id(0), job_id); CHECK_EQ(mem_block->machine_id(), machine_id); CHECK(mem_block->mem_case() == memory::GetPinnedHostMemoryCase(regst_desc->mem_case())); 
CHECK_EQ(mem_block->enable_reuse_mem(), false); CHECK_EQ(mem_block->mem_size(), regst_separated_size); if (is_variable_regst) { mem_block->set_variable_op_name(var_name); mem_block->set_is_separated_header(true); } } } }; for (int i = 0; i < plan->task_size(); i++) { TaskProto* task = plan->mutable_task(i); for (auto& pair : *task->mutable_produced_regst_desc()) { GenMemBlock4RegstIfNeed(&pair.second, task); } } GenChunkForMultiNNGraphMemoryReuseInMultiClient(plan, &mem_block_id2mem_block); for (const auto& pair : mem_block_id2mem_block) { *(plan->mutable_block_chunk_list()->add_mem_block()) = *(pair.second); } } void PlanUtil::CleanUselessMemBlockAndCheckValid(Plan* plan) { HashMap chunk_id2chunk; HashMap mem_block_id2mem_block; for (const auto& chunk : plan->block_chunk_list().chunk()) { CHECK(chunk_id2chunk.emplace(chunk.chunk_id(), chunk).second); } for (const auto& mem_block : plan->block_chunk_list().mem_block()) { CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), mem_block).second); } plan->mutable_block_chunk_list()->clear_mem_block(); HashMap> chunk_id2job_ids; HashMap> mem_block_id2job_ids; for (const auto& pair : chunk_id2chunk) { for (int64_t job_id : pair.second.job_id()) { CHECK(chunk_id2job_ids[pair.first].insert(job_id).second); } } for (const auto& pair : mem_block_id2mem_block) { for (int64_t job_id : pair.second.job_id()) { CHECK(mem_block_id2job_ids[pair.first].insert(job_id).second); } } HashSet valid_mem_block_ids; for (const TaskProto& task : plan->task()) { for (const auto& pair : task.produced_regst_desc()) { const RegstDescProto& regst = pair.second; RtRegstDesc rt_regst(regst); int64_t regst_size = rt_regst.TotalMainByteSize4AllRegst(); CHECK(mem_block_id2mem_block.find(regst.mem_block_id()) != mem_block_id2mem_block.end()); const MemBlockProto& mem_block = mem_block_id2mem_block.at(regst.mem_block_id()); CHECK_GE(mem_block.mem_size(), regst.mem_block_offset() + regst_size); CHECK_EQ(task.machine_id(), mem_block.machine_id()); CHECK_EQ(mem_block.enable_reuse_mem(), regst.enable_reuse_mem()); CHECK(mem_block.mem_case() == regst.mem_case()); const auto& job_ids = mem_block_id2job_ids[regst.mem_block_id()]; CHECK(job_ids.find(task.job_id()) != job_ids.end()); valid_mem_block_ids.insert(regst.mem_block_id()); // separated_header int64_t separated_header_mem_size = rt_regst.TotalSeparatedHeaderByteSize4AllRegst(); if (separated_header_mem_size > 0) { int64_t header_block_id = regst.separated_header_mem_block_id(); CHECK_NE(header_block_id, -1); CHECK(mem_block_id2mem_block.find(header_block_id) != mem_block_id2mem_block.end()); const MemBlockProto& header_mem_block = mem_block_id2mem_block.at(header_block_id); CHECK_EQ(header_mem_block.mem_size(), separated_header_mem_size); CHECK_EQ(task.machine_id(), header_mem_block.machine_id()); CHECK(header_mem_block.mem_case() == memory::GetPinnedHostMemoryCase(regst.mem_case())); CHECK(header_mem_block.enable_reuse_mem() == false); const auto& header_block_job_ids = mem_block_id2job_ids[header_block_id]; CHECK(header_block_job_ids.find(task.job_id()) != header_block_job_ids.end()); valid_mem_block_ids.insert(regst.separated_header_mem_block_id()); } } } HashSet useless_mem_block_ids; HashSet valid_chunk_ids; for (const auto& pair : mem_block_id2mem_block) { if (valid_mem_block_ids.find(pair.first) == valid_mem_block_ids.end()) { CHECK(useless_mem_block_ids.insert(pair.first).second); continue; } const MemBlockProto& mem_block = pair.second; if (mem_block.has_chunk_id()) { CHECK(mem_block.has_chunk_offset()); 
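      // NOTE (illustrative sketch, not from the original source): the containment validated here
      // is regst -> mem block -> chunk. A regst occupies [mem_block_offset, mem_block_offset +
      // size) inside its mem block (checked above), and a reusable mem block occupies
      // [chunk_offset, chunk_offset + mem_size) inside its chunk (checked just below), with job
      // ids required to agree at every level.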
CHECK(mem_block.enable_reuse_mem()); CHECK(chunk_id2chunk.find(mem_block.chunk_id()) != chunk_id2chunk.end()); const ChunkProto& chunk = chunk_id2chunk.at(mem_block.chunk_id()); CHECK_GE(chunk.mem_size(), mem_block.chunk_offset() + mem_block.mem_size()); CHECK_EQ(mem_block.job_id_size(), 1); CHECK_GE(chunk.job_id_size(), 1); const HashSet& chunk_job_ids = chunk_id2job_ids.at(chunk.chunk_id()); CHECK(chunk_job_ids.find(mem_block.job_id(0)) != chunk_job_ids.end()); valid_chunk_ids.insert(mem_block.chunk_id()); } } CHECK_EQ(valid_chunk_ids.size(), chunk_id2chunk.size()); for (int64_t useless_block_id : useless_mem_block_ids) { mem_block_id2mem_block.erase(useless_block_id); } for (const auto& pair : mem_block_id2mem_block) { *(plan->mutable_block_chunk_list()->add_mem_block()) = pair.second; } } void PlanUtil::ToDotFile(const Plan& plan, const std::string& filepath) { const auto& process_ranks = Singleton::Get()->process_ranks(); size_t gpu_device_num = Singleton::Get()->GetDeviceCount(DeviceType::kCUDA); std::map>>> machine_id2job_id_device_id2node_list; for (size_t i : process_ranks) { for (const auto& pair : plan.job_confs().job_id2job_conf()) { machine_id2job_id_device_id2node_list[i][pair.first].resize(gpu_device_num); } } std::map>> machine_id2job_id2host_node_list; std::vector main_node_list; std::vector copy_comm_net_node_list; HashSet ctrl_regst_desc_ids; HashMap> task_id2consumer_regst_id2name; HashMap task_id2op_name; HashMap> task_id2producer_task_ids; std::vector> machine_id2device_id2node_list_job_ids(process_ranks.size()); std::vector> machine_id2host_node_list_job_ids(process_ranks.size()); auto InsertNodeDefByTaskProto = [&](const TaskProto& task_proto, const std::string& node_def, const std::string& pass_tag) { if (task_proto.task_type() == TaskType::kCopyCommNet) { copy_comm_net_node_list.emplace_back(node_def); return; } if (pass_tag == kNoPassTag) { const StreamId stream_id = PlanUtil::GetStreamId(task_proto); if (stream_id.device_id().device_type() == DeviceType::kCUDA) { machine_id2job_id_device_id2node_list[task_proto.machine_id()][task_proto.job_id()] [stream_id.device_id().device_index()] .emplace_back(node_def); machine_id2device_id2node_list_job_ids[task_proto.machine_id()].insert(task_proto.job_id()); } else { machine_id2job_id2host_node_list[task_proto.machine_id()][task_proto.job_id()].emplace_back( node_def); machine_id2host_node_list_job_ids[task_proto.machine_id()].insert(task_proto.job_id()); } } else if (pass_tag == kMainOp) { main_node_list.emplace_back(node_def); } else { UNIMPLEMENTED(); } }; auto GenEdgeColorStr = [](const RegstDescTypeProto& type) { if (type.has_ctrl_regst_desc()) { return "fontcolor=\"gray65\",color=\"gray65\""; } return "fontcolor=\"gray15\",color=\"gray15\""; }; auto IsEsac2ReentrantLockEdge = [](const std::string& src_name, const std::string& dst_name) { if (src_name.find("Esac") != std::string::npos && dst_name.find("ReentrantLock") != std::string::npos) { return true; } return false; }; auto IsEsacNode = [](const std::string& name) { if (name.find("Esac") != std::string::npos) { return true; } return false; }; auto log_stream = TeePersistentLogStream::Create(filepath); // task node for (const TaskProto& task_proto : plan.task()) { for (const auto& pair : task_proto.produced_regst_desc()) { const RegstDescProto& regst = pair.second; for (int64_t consumer_task_id : regst.consumer_task_id()) { task_id2producer_task_ids[consumer_task_id].emplace_back(task_proto.task_id()); } } } for (const TaskProto& task_proto : plan.task()) { 
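    // NOTE (illustrative sketch, not from the original source): each task becomes one
    // record-shaped dot node labelled "<task_id>:<machine_id>" plus its op names and produced
    // regsts (regst id and register_num), with a tooltip carrying the TaskType and thread /
    // parallel ids, a fill color keyed by job_id % 12, and later grouping into per-machine and
    // per-job subgraph clusters.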
std::string task_id_str = "task" + std::to_string(task_proto.task_id()); std::string task_class = task_id_str; for (const auto& in_task_id : task_id2producer_task_ids[task_proto.task_id()]) { task_class += " in" + std::to_string(in_task_id); } for (const auto& pair : task_proto.produced_regst_desc()) { const RegstDescProto& regst = pair.second; for (int64_t consumer_task_id : regst.consumer_task_id()) { task_class += " out" + std::to_string(consumer_task_id); } } task_class += " job_id" + std::to_string(task_proto.job_id()); task_class += " machine_id" + std::to_string(task_proto.machine_id()); std::string node_def = task_id_str + "[class=\"" + task_class + "\",label=\"{{"; node_def += std::to_string(task_proto.task_id()) + ":" + std::to_string(task_proto.machine_id()) + "\\n"; std::string op_name = ""; std::string pass_tag = kNoPassTag; for (const ExecNodeProto& exec_node : task_proto.exec_sequence().exec_node()) { const auto& op_conf = GetOpAttribute(&plan, task_proto.job_id(), exec_node.kernel_conf()).op_conf(); op_name += op_conf.name(); if (op_conf.has_pass_tag()) { pass_tag = op_conf.pass_tag(); } } task_id2op_name[task_proto.task_id()] = op_name; node_def += op_name; size_t index = 0; for (const auto& pair : task_proto.produced_regst_desc()) { std::string regst_id = std::to_string(pair.second.regst_desc_id()); if (index % 2 == 0) { node_def += "}|{"; } else { node_def += "|"; } // node_def += ""; node_def += (pair.first + ":" + regst_id + ":" + std::to_string(pair.second.register_num())); ++index; } node_def += "}}"; node_def += ("\",tooltip=\"" + TaskType_Name(task_proto.task_type()) + " " + std::to_string(task_proto.task_id()) + "-" + std::to_string(task_proto.machine_id()) + ":" + std::to_string(task_proto.thrd_id()) + ":" + std::to_string(task_proto.parallel_ctx().parallel_id()) + "\", shape=record, style=\"rounded,filled\"" + ",colorscheme=set312, fillcolor=" + std::to_string((task_proto.job_id() % 12) + 1)); if (IsEsacNode(op_name)) { node_def += ",width=5,height=1.5"; } node_def += "];\n"; InsertNodeDefByTaskProto(task_proto, node_def, pass_tag); for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (int64_t regst_desc_id : pair.second.regst_desc_id()) { task_id2consumer_regst_id2name[task_proto.task_id()][regst_desc_id] = pair.first; } } } log_stream << "digraph merged_plan_graph {\n"; log_stream << "#splines=\"ortho\";\n"; log_stream << "#rankdir=TB;\n"; log_stream << "#nodesep=1.3;\n"; log_stream << "#ranksep=1.3;\n"; log_stream << "node[color=\"gray\"];\n"; // main_node and copy_comm_net_node graph for (const std::string& main_node : main_node_list) { log_stream << main_node; } for (const std::string& copy_comm_net_node : copy_comm_net_node_list) { log_stream << copy_comm_net_node; } // sub graph for (size_t machine_id : process_ranks) { std::string machine_name = "machine_" + std::to_string(machine_id); log_stream << "subgraph cluster_" << machine_name << " { label = \"" << machine_name << "\";\n"; log_stream << "style=\"rounded\";\n"; { for (const auto& job_id : machine_id2host_node_list_job_ids[machine_id]) { std::string job_name = plan.job_confs().job_id2job_conf().at(job_id).job_name(); job_name += (std::string(":") + std::to_string(job_id)); if (job_id != plan.job_confs().job_id2job_conf().size() - 1) { log_stream << "subgraph cluster_job_" << std::to_string(job_id) << " { label = \"" << job_name << "\";\n"; log_stream << "style=\"rounded\";\n"; } for (const std::string& host_node_def : machine_id2job_id2host_node_list[machine_id][job_id]) { log_stream 
<< host_node_def; } if (machine_id2device_id2node_list_job_ids[machine_id].find(job_id) != machine_id2device_id2node_list_job_ids[machine_id].end()) { for (size_t device_id = 0; device_id < gpu_device_num; ++device_id) { std::string device_name = machine_name + "_device_" + std::to_string(device_id); log_stream << "#subgraph cluster_" << device_name << " { label = \"" << device_name << "\";\n"; log_stream << "#color=\"skyblue\";\n"; log_stream << "#fillcolor=\"azure\";\n"; log_stream << "#style=\"rounded,filled\";\n"; for (const auto& device_node_def : machine_id2job_id_device_id2node_list[machine_id][job_id][device_id]) { log_stream << device_node_def; } log_stream << "#}\n"; } machine_id2device_id2node_list_job_ids[machine_id].erase(job_id); } if (job_id != plan.job_confs().job_id2job_conf().size() - 1) { log_stream << "}\n"; } } for (const auto& job_id : machine_id2device_id2node_list_job_ids[machine_id]) { std::string job_name = plan.job_confs().job_id2job_conf().at(job_id).job_name(); job_name += (std::string(":") + std::to_string(job_id)); if (job_id != plan.job_confs().job_id2job_conf().size() - 1) { log_stream << "subgraph cluster_job_" << std::to_string(job_id) << " { label = \"" << job_name << "\";\n"; log_stream << "style=\"rounded\";\n"; } for (size_t device_id = 0; device_id < gpu_device_num; ++device_id) { std::string device_name = machine_name + "_device_" + std::to_string(device_id); log_stream << "#subgraph cluster_" << device_name << " { label = \"" << device_name << "\";\n"; log_stream << "#color=\"skyblue\";\n"; log_stream << "#fillcolor=\"azure\";\n"; log_stream << "#style=\"rounded,filled\";\n"; for (const auto& device_node_def : machine_id2job_id_device_id2node_list[machine_id][job_id][device_id]) { log_stream << device_node_def; } log_stream << "#}\n"; } if (job_id != plan.job_confs().job_id2job_conf().size() - 1) { log_stream << "}\n"; } } } log_stream << "}\n"; } // produce/consume edge for (const TaskProto& task_proto : plan.task()) { for (const auto& pair : task_proto.produced_regst_desc()) { const RegstDescProto& regst = pair.second; std::string src_node = "task" + std::to_string(task_proto.task_id()); // src_node += ":regst_desc_" + std::to_string(regst.regst_desc_id()); for (int64_t consumer_task_id : regst.consumer_task_id()) { std::string dst_node = "task" + std::to_string(consumer_task_id); // dst_node += ":task_node_" + std::to_string(consumer_task_id); std::string consumer_regst_name = task_id2consumer_regst_id2name[consumer_task_id][regst.regst_desc_id()]; std::string consumer_op_name = task_id2op_name[consumer_task_id]; std::string producer_regst_name = pair.first; std::string producer_op_name = task_id2op_name[task_proto.task_id()]; std::string tooltip = producer_op_name + " : " + producer_regst_name + " -> " + consumer_op_name + " : " + consumer_regst_name; if (IsEsac2ReentrantLockEdge(producer_op_name, consumer_op_name)) { log_stream << dst_node << "->" << src_node << "[arrowhead=\"invempty\",fontcolor=\"red\",color=\"red\",taillabel=\"" << consumer_regst_name << "\",tailtooltip=\"" << tooltip; } else { log_stream << src_node << "->" << dst_node << "[" << GenEdgeColorStr(regst.regst_desc_type()) << ",headlabel=\"" << consumer_regst_name << "\",headtooltip=\"" << tooltip; } log_stream << "\",tooltip=\"" << tooltip << "\",arrowsize=0.5,labeldistance=1.5,penwidth=2" << "];\n"; } } } log_stream << "}\n"; } std::function PlanUtil::MakeMutRegstDesc4Id(Plan* plan) { auto regst_desc_id2regst_desc = std::make_shared>(); for (int i = 0; i < plan->task_size(); 
i++) { TaskProto* task = plan->mutable_task(i); for (auto& pair : *task->mutable_produced_regst_desc()) { int64_t regst_desc_id = pair.second.regst_desc_id(); CHECK(regst_desc_id2regst_desc->insert({regst_desc_id, &pair.second}).second) << "regst_desc_id2regst_desc has got duplicated regst_desc_id " << regst_desc_id; } } return [regst_desc_id2regst_desc](int64_t regst_desc_id) -> RegstDescProto* { auto iter = regst_desc_id2regst_desc->find(regst_desc_id); CHECK(iter != regst_desc_id2regst_desc->end()) << "regst_desc_id " << regst_desc_id << " can't be found in plan."; return iter->second; }; } void PlanUtil::SetForceInplaceMemBlock(Plan* plan, int64_t limited_rank) { auto RegstDesc4Id = MakeMutRegstDesc4Id(plan); for (int i = 0; i < plan->task_size(); i++) { TaskProto* task = plan->mutable_task(i); // When do seperation compilation, some rank's plan (such as rank 0) has other ranks task node // for compilation. There is no need to set mem block for other ranks task node. if (limited_rank >= 0 && task->machine_id() != limited_rank) { continue; } for (auto& pair : *task->mutable_produced_regst_desc()) { RegstDescProto* regst_desc = &pair.second; if (regst_desc->has_force_inplace_consumed_regst_desc_id()) { int64_t force_id = regst_desc->force_inplace_consumed_regst_desc_id(); const RegstDescProto* in_regst_desc = RegstDesc4Id(force_id); CHECK(!in_regst_desc->enable_reuse_mem()); CHECK(!regst_desc->enable_reuse_mem()); CHECK_NE(in_regst_desc->mem_block_id(), -1); CHECK_EQ(in_regst_desc->mem_block_offset(), 0); CHECK_EQ(regst_desc->mem_block_offset(), 0); CHECK_EQ(in_regst_desc->register_num(), regst_desc->register_num()); CHECK(in_regst_desc->mem_case() == regst_desc->mem_case()); RtRegstDesc in_regst_rt(*in_regst_desc); RtRegstDesc regst_rt(*regst_desc); CHECK_EQ(in_regst_rt.TotalByteSize4AllRegst(), regst_rt.TotalByteSize4AllRegst()); CHECK_EQ(in_regst_rt.TotalMainByteSize4AllRegst(), regst_rt.TotalMainByteSize4AllRegst()); CHECK_EQ(in_regst_rt.TotalSeparatedHeaderByteSize4AllRegst(), regst_rt.TotalSeparatedHeaderByteSize4AllRegst()); regst_desc->set_mem_block_id(in_regst_desc->mem_block_id()); regst_desc->set_inplace_consumed_regst_desc_id(force_id); if (in_regst_desc->has_separated_header_mem_block_id()) { CHECK(regst_desc->has_separated_header_mem_block_id()); regst_desc->set_separated_header_mem_block_id( in_regst_desc->separated_header_mem_block_id()); } VLOG(3) << " set force inplace from " << regst_desc->DebugString() << " to " << in_regst_desc->DebugString(); } } } } void PlanUtil::DumpCtrlRegstInfoToPlan(Plan* plan) { auto* ctrl_regst_desc_id2producer_task_id = plan->mutable_ctrl_regst_desc_info()->mutable_ctrl_regst_desc_id2producer_task_id(); for (const TaskProto& task : plan->task()) { for (const auto& pair : task.produced_regst_desc()) { if (pair.second.regst_desc_type().has_ctrl_regst_desc()) { ctrl_regst_desc_id2producer_task_id->insert( {pair.second.regst_desc_id(), pair.second.producer_task_id()}); } } } } namespace { bool IsCollectiveBoxingTaskType(TaskType task_type) { return task_type == TaskType::kCollectiveBoxingGeneric; } bool IsCollectiveBoxingNode(const PlanTaskNode* node) { const TaskType task_type = node->task_proto()->task_type(); return IsCollectiveBoxingTaskType(task_type); } const boxing::collective::RankDesc& GetRankDesc(const OperatorConf& conf) { if (conf.has_collective_boxing_generic_conf()) { return conf.collective_boxing_generic_conf().rank_desc(); } else { UNIMPLEMENTED(); } } const boxing::collective::RankDesc& GetRankDesc(Plan* plan, const TaskProto& 
task_proto) { CHECK_EQ(task_proto.exec_sequence().exec_node_size(), 1); return GetRankDesc(PlanUtil::GetOpAttribute(plan, task_proto.job_id(), task_proto.exec_sequence().exec_node(0).kernel_conf()) .op_conf()); } struct CollectiveBoxingRequestInfo { boxing::collective::OpDesc op_desc; std::map rank2node; int64_t order; int64_t dependency_depth; }; void GetDeviceDesc(const TaskProto* task_proto, boxing::collective::DeviceDesc* device_desc) { device_desc->set_machine_id(task_proto->machine_id()); const StreamId stream_id = PlanUtil::GetStreamId(*task_proto); const DeviceId& device_id = stream_id.device_id(); device_desc->set_device_type(device_id.device_type()); device_desc->set_device_id(device_id.device_index()); } } // namespace void PlanUtil::GenCollectiveBoxingPlan(Job* job, Plan* plan) { using namespace boxing::collective; RequestSet* request_set = &(*plan->mutable_collective_boxing_plan() ->mutable_job_id2request_set())[GlobalJobDesc().job_id()]; const int64_t cb_task_count = std::count_if( plan->task().cbegin(), plan->task().cend(), [](const TaskProto& task) { return IsCollectiveBoxingTaskType(task.task_type()); }); if (cb_task_count == 0) { return; } PlanTaskGraph plan_task_graph(*plan); int64_t dependency_depth = 0; int64_t order = 0; HashSet all_visited; while (true) { std::list src_nodes; plan_task_graph.ForEachNode([&](const PlanTaskNode* node) { if (all_visited.count(node) != 0) { return; } int64_t in_cnt = 0; node->ForEachNodeOnInEdge([&](const PlanTaskNode* node_on_in_edge) { if (all_visited.count(node_on_in_edge) != 0) { return; } in_cnt += 1; }); if (in_cnt == 0) { src_nodes.emplace_back(node); } }); if (src_nodes.empty()) { break; } auto ForEachNodeOnInEdge = [&](const PlanTaskNode* node, const std::function& Handler) { node->ForEachNodeOnInEdge([&](const PlanTaskNode* node_on_in_edge) { if (all_visited.count(node_on_in_edge) == 0) { Handler(node_on_in_edge); } }); }; auto ForEachNodeOnOutEdge = [&](const PlanTaskNode* node, const std::function& Handler) { if (!IsCollectiveBoxingNode(node)) { node->ForEachNodeOnOutEdge([&](const PlanTaskNode* node_on_out_edge) { bool has_unvisited_collective_boxing_node_on_in_edges = false; node_on_out_edge->ForEachNodeOnInEdge([&](const PlanTaskNode* node_on_in_edge) { if (!has_unvisited_collective_boxing_node_on_in_edges && IsCollectiveBoxingNode(node_on_in_edge) && all_visited.count(node_on_in_edge) == 0) { has_unvisited_collective_boxing_node_on_in_edges = true; } }); if (!has_unvisited_collective_boxing_node_on_in_edges) { Handler(node_on_out_edge); } }); } }; HashSet visited; std::vector collective_boxing_nodes; plan_task_graph.TopoForEachNode(src_nodes, ForEachNodeOnInEdge, ForEachNodeOnOutEdge, [&](const PlanTaskNode* node) { visited.insert(node); if (IsCollectiveBoxingNode(node)) { collective_boxing_nodes.emplace_back(node); } }); if (collective_boxing_nodes.empty()) { break; } HashMap name2request_info; for (const PlanTaskNode* node : collective_boxing_nodes) { const TaskProto* task_proto = node->task_proto(); const RankDesc& rank_desc = GetRankDesc(plan, *task_proto); CHECK_GE(rank_desc.rank(), 0); CHECK_LT(rank_desc.rank(), rank_desc.op_desc().num_ranks()); const std::string& name = rank_desc.op_desc().name(); boxing::collective::DeviceDesc device_desc; GetDeviceDesc(task_proto, &device_desc); auto it = name2request_info.find(name); if (it == name2request_info.end()) { CollectiveBoxingRequestInfo request_info{ .op_desc = rank_desc.op_desc(), .rank2node = {std::make_pair(rank_desc.rank(), node)}, .order = order, 
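              // 'order' is the global creation order of collective boxing requests in this
              // plan; 'dependency_depth' (filled in just below) is the index of the
              // topological wave in which this request was first encountered. Both are
              // copied into the RequestDesc, presumably so the collective boxing scheduler
              // can respect creation order and wave depth when fusing or issuing requests.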
.dependency_depth = dependency_depth, }; name2request_info.emplace(std::make_pair(name, std::move(request_info))); order += 1; } else { CHECK(it->second.op_desc == rank_desc.op_desc()); CHECK(it->second.rank2node.emplace(std::make_pair(rank_desc.rank(), node)).second); } } int64_t collected = 0; for (const auto& name7request_info : name2request_info) { const CollectiveBoxingRequestInfo& info = name7request_info.second; if (info.rank2node.size() == info.op_desc.num_ranks()) { collected += 1; boxing::collective::RequestDesc* request_desc = request_set->mutable_request()->Add(); *request_desc->mutable_op_desc() = info.op_desc; for (int64_t i = 0; i < info.op_desc.num_ranks(); ++i) { GetDeviceDesc(info.rank2node.at(i)->task_proto(), request_desc->mutable_device_set()->mutable_device()->Add()); } request_desc->set_order(info.order); request_desc->set_dependency_depth(info.dependency_depth); } else { CHECK_LT(info.rank2node.size(), info.op_desc.num_ranks()); for (const auto& pair : info.rank2node) { visited.erase(pair.second); } } } CHECK_GT(collected, 0); all_visited.insert(visited.begin(), visited.end()); ++dependency_depth; } } void PlanUtil::GenRegisterHint(Plan* plan) { HashSet multi_regst_regst_desc_ids; for (const TaskProto& task : plan->task()) { for (const auto& pair : task.produced_regst_desc()) { if (pair.second.register_num() != 1 || task.task_type() == TaskType::kRepeat) { multi_regst_regst_desc_ids.emplace(pair.second.regst_desc_id()); } } } for (TaskProto& task : *(plan->mutable_task())) { bool all_register_num_eq_one = true; for (const auto& pair : task.produced_regst_desc()) { if (pair.second.register_num() != 1) { all_register_num_eq_one = false; break; } } for (const auto& pair : task.consumed_regst_desc_id()) { if (!all_register_num_eq_one) { break; } for (auto regst_desc_id : pair.second.regst_desc_id()) { if (multi_regst_regst_desc_ids.count(regst_desc_id) > 0) { all_register_num_eq_one = false; break; } } } task.set_all_register_num_eq_one_hint(all_register_num_eq_one); } } namespace { struct MemBlockMemoryInfo { int64_t mem_block_id; int64_t mem_block_mem_size; int64_t regst_num; std::vector ordered_regst_desc_id; MemBlockMemoryInfo() : mem_block_id(-1), mem_block_mem_size(-1), regst_num(-1) {} }; struct ChunkMemoryInfo { int64_t chunk_id; int64_t chunk_mem_size; std::vector mem_block_ids; ChunkMemoryInfo() : chunk_id(-1), chunk_mem_size(-1) {} }; struct RankDeviceMemoryInfo { int64_t rank_id; int64_t device_id; ChunkMemoryInfo chunk_info; int64_t total_mem_size; int64_t not_reused_mem_size; std::vector not_reused_mem_block_ids; int64_t eager_variable_total_mem_size; std::vector eager_variable_mem_block_ids; RankDeviceMemoryInfo() : rank_id(-1), device_id(-1), total_mem_size(0), not_reused_mem_size(0), eager_variable_total_mem_size(0) {} }; } // namespace void PlanUtil::PlanMemoryLog(Plan* plan, const std::string& plan_name) { std::vector rank_device_memory_infos(GlobalProcessCtx::WorldSize(), RankDeviceMemoryInfo()); HashMap mem_block_id2info; HashMap regst_desc_id2regst; for (const ChunkProto& chunk : plan->block_chunk_list().chunk()) { int64_t rank_id = chunk.machine_id(); auto& info = rank_device_memory_infos[rank_id]; info.rank_id = rank_id; if (!memory::IsHostMem(chunk.mem_case())) { info.device_id = chunk.mem_case().device_id(); } info.total_mem_size += chunk.mem_size(); info.chunk_info.chunk_id = chunk.chunk_id(); info.chunk_info.chunk_mem_size = chunk.mem_size(); } for (const MemBlockProto& mem_block : plan->block_chunk_list().mem_block()) { int64_t mem_block_id 
= mem_block.mem_block_id(); mem_block_id2info.emplace(mem_block_id, MemBlockMemoryInfo()); auto& info = mem_block_id2info.at(mem_block_id); info.mem_block_id = mem_block_id; info.mem_block_mem_size = mem_block.mem_size(); auto& rank_memory_info = rank_device_memory_infos.at(mem_block.machine_id()); if (!memory::IsHostMem(mem_block.mem_case())) { if (mem_block.has_chunk_id()) { rank_memory_info.chunk_info.mem_block_ids.push_back(mem_block_id); } else { if (mem_block.has_variable_op_name()) { rank_memory_info.eager_variable_mem_block_ids.push_back(mem_block_id); rank_memory_info.eager_variable_total_mem_size += mem_block.mem_size(); } else { rank_memory_info.not_reused_mem_block_ids.push_back(mem_block_id); rank_memory_info.not_reused_mem_size += mem_block.mem_size(); } rank_memory_info.total_mem_size += mem_block.mem_size(); } } } for (const auto& task : plan->task()) { for (const auto& pair : task.produced_regst_desc()) { const auto& regst = pair.second; if (regst.regst_desc_type().has_data_regst_desc() && mem_block_id2info.find(regst.mem_block_id()) != mem_block_id2info.end()) { mem_block_id2info.at(regst.mem_block_id()) .ordered_regst_desc_id.push_back(regst.regst_desc_id()); regst_desc_id2regst.emplace(regst.regst_desc_id(), ®st); } } } auto CompMemBlock = [&](int64_t a, int64_t b) { return mem_block_id2info[a].mem_block_mem_size > mem_block_id2info[b].mem_block_mem_size; }; auto B2MiB = [](int64_t val) { return val * 1.0 / 1000000.0; }; for (auto& rank_memory_info : rank_device_memory_infos) { std::sort(rank_memory_info.chunk_info.mem_block_ids.begin(), rank_memory_info.chunk_info.mem_block_ids.end(), CompMemBlock); std::sort(rank_memory_info.not_reused_mem_block_ids.begin(), rank_memory_info.not_reused_mem_block_ids.end(), CompMemBlock); std::sort(rank_memory_info.eager_variable_mem_block_ids.begin(), rank_memory_info.eager_variable_mem_block_ids.end(), CompMemBlock); LOG(INFO) << "\n Graph name " << plan_name << " in Rank: " << rank_memory_info.rank_id << ", Device: " << rank_memory_info.device_id << " needs to allocate [ " << B2MiB(rank_memory_info.total_mem_size) << " MiB ] device memory. 
\n In general, Chunk id: " << rank_memory_info.chunk_info.chunk_id << " memory is [ " << B2MiB(rank_memory_info.chunk_info.chunk_mem_size) << " MiB ] with mem_block_num = " << rank_memory_info.chunk_info.mem_block_ids.size() << "\n Unreused memory not eager var is [ " << B2MiB(rank_memory_info.not_reused_mem_size) << " MiB ] with mem_block_num = " << rank_memory_info.not_reused_mem_block_ids.size() << "\n Eager Variable Tensor total memory is [ " << B2MiB(rank_memory_info.eager_variable_total_mem_size) << " MiB ] with mem_block_num = " << rank_memory_info.eager_variable_mem_block_ids.size() << "\n"; } auto Vlog3ForMemBlockDetails = [&](int64_t device_id, const std::vector& mem_block_ids, const std::string& prefix) { for (int64_t mem_block_id : mem_block_ids) { CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end()); const auto& mem_block_info = mem_block_id2info.at(mem_block_id); if (mem_block_info.ordered_regst_desc_id.size() != 1) { continue; } const auto* regst = regst_desc_id2regst.at(mem_block_info.ordered_regst_desc_id.at(0)); const auto& data_regst = regst->regst_desc_type().data_regst_desc(); const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0); std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi()); const auto& blob_desc = lbi2blob_desc_pair.blob_desc(); VLOG(3) << "In Device: " << device_id << " Memblock id: " << mem_block_id << prefix << " size: " << B2MiB(mem_block_info.mem_block_mem_size) << " MiB, name: " << tensor_name << "\nshape: " << Shape(blob_desc.shape()).ToString() << " ,dtype: " << DataType_Name(blob_desc.data_type()); } }; for (const auto& rank_memory_info : rank_device_memory_infos) { int64_t chunk_id = rank_memory_info.chunk_info.chunk_id; int64_t device_id = rank_memory_info.device_id; VLOG(2) << "========================= " << "In Device : " << device_id << " Chunk Memory info details:"; for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) { CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end()); const auto& mem_block_info = mem_block_id2info.at(mem_block_id); VLOG(2) << " In Device: " << device_id << " Chunk id: " << chunk_id << " MemBlock id: " << mem_block_id << " has num = " << mem_block_info.ordered_regst_desc_id.size() << " tensor with mem size = " << B2MiB(mem_block_info.mem_block_mem_size); for (int64_t i = 0; i < mem_block_info.ordered_regst_desc_id.size(); ++i) { const auto* regst = regst_desc_id2regst.at(mem_block_info.ordered_regst_desc_id.at(i)); const auto& data_regst = regst->regst_desc_type().data_regst_desc(); const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0); std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi()); const auto& blob_desc = lbi2blob_desc_pair.blob_desc(); std::string alloc_order = "inplaced"; if (regst->has_alloc_before_actor()) { alloc_order = std::to_string(regst->alloc_before_actor()); } std::string free_order = "inplaced"; if (regst->has_free_after_actor()) { free_order = std::to_string(regst->free_after_actor()); } VLOG(3) << "In Chunk id: " << chunk_id << ", MemBlock id: " << mem_block_id << " Order: " << i << " ,duration: " << (regst->free_after_actor() - regst->alloc_before_actor() + 1) << " ,size: " << B2MiB(BlobDesc(blob_desc).AlignedTotalByteSize()) << " MiB, name: " << tensor_name << "\nshape: " << Shape(blob_desc.shape()).ToString() << " ,dtype: " << DataType_Name(blob_desc.data_type()) << " ,alloc_order: " << alloc_order << " ,free_order: " << free_order; } } Vlog3ForMemBlockDetails(device_id, 
rank_memory_info.not_reused_mem_block_ids, " Unreused "); Vlog3ForMemBlockDetails(device_id, rank_memory_info.eager_variable_mem_block_ids, " EagerVariable "); } } void PlanUtil::GenLightPlan(Plan* plan, const std::string& plan_name, int64_t limited_rank) { // NOTE(chengcheng): ordered_tasks is NOT exec order, just task id order. std::vector ordered_tasks; for (const TaskProto& task : plan->task()) { ordered_tasks.push_back(&task); } auto CompTask = [](const TaskProto* a, const TaskProto* b) { return a->task_id() < b->task_id(); }; std::sort(ordered_tasks.begin(), ordered_tasks.end(), CompTask); HashMap task_id2name; HashMap task_id2proto; HashMap regst_id2name; HashMap regst_id2proto; for (const auto* task : ordered_tasks) { const auto& exec_seq = task->exec_sequence(); std::string name; if (exec_seq.exec_node_size() >= 1) { const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf(); if (kernel_conf.has_op_attribute_ref()) { name = kernel_conf.op_attribute_ref(); } else { name = kernel_conf.op_attribute().op_conf().name(); } } else { name = TaskType_Name(task->task_type()); } task_id2name.emplace(task->task_id(), name); task_id2proto.emplace(task->task_id(), task); CHECK(!name.empty()); for (const auto& pair : task->produced_regst_desc()) { std::string regst_name = name + "/" + pair.first; regst_id2name.emplace(pair.second.regst_desc_id(), regst_name); regst_id2proto.emplace(pair.second.regst_desc_id(), pair.second); } } auto RegstId2TensorStr = [&](int64_t regst_id) -> std::string { CHECK(regst_id2proto.find(regst_id) != regst_id2proto.end()) << " regst_id2proto cannot find: " << regst_id; std::ostringstream ss; ss << "{"; const RegstDescProto& regst = regst_id2proto.at(regst_id); ss << "regust_num: " << std::to_string(regst.register_num()); ss << ", device: " << *CHECK_JUST(DeviceTag4DeviceType(regst.mem_case().device_type())); if (regst.regst_desc_type().has_data_regst_desc()) { const DataRegstDesc& data = regst.regst_desc_type().data_regst_desc(); ss << ", time_shape: " << Shape(data.time_shape()).ToString(); const BlobDescProto& blob = data.lbi2blob_desc(0).blob_desc(); ss << ", shape: " << Shape(blob.shape()).ToString(); ss << ", dtype: " << DataType_Name(blob.data_type()); } else { ss << ", ctrl"; } ss << "}"; return ss.str(); }; std::vector> rank2ordered_task(GlobalProcessCtx::WorldSize(), std::vector()); for (const auto* task : ordered_tasks) { CHECK_LT(task->machine_id(), rank2ordered_task.size()); rank2ordered_task.at(task->machine_id()).push_back(task); } for (int64_t rank = 0; rank < GlobalProcessCtx::WorldSize(); ++rank) { // Filter rank to generate log. 
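    // For every remaining rank a human-readable "light plan" is dumped to a file named
    // "<plan_name>_rank_<rank>_light_plan". Each task entry records its actor id, thread
    // id, chain id / order-in-chain, device type and stream index, followed by one line
    // per consumed and produced regst showing the producing or consuming actor (and its
    // rank, when it lives on another machine) plus a compact tensor summary built by
    // RegstId2TensorStr.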
if (limited_rank >= 0 && rank != limited_rank) { continue; } auto file_stream = TeePersistentLogStream::Create(plan_name + "_rank_" + std::to_string(rank) + "_light_plan"); file_stream << "rank : " << std::to_string(rank) << "\n"; CHECK_LT(rank, rank2ordered_task.size()); const auto& ordered_task_in_rank = rank2ordered_task.at(rank); for (int64_t i = 0; i < ordered_task_in_rank.size(); ++i) { CHECK_LT(i, ordered_task_in_rank.size()); const auto* task = ordered_task_in_rank.at(i); int64_t task_id = task->task_id(); CHECK(task_id2name.find(task_id) != task_id2name.end()) << " task_id2name cannot find" << task_id; int64_t thrd_id = task->thrd_id(); StreamId stream_id = DecodeStreamIdFromInt64(thrd_id); file_stream << "i : " << std::to_string(i) << " , actor id : " << std::to_string(task_id) << " thrd : " << std::to_string(thrd_id) << " name : " << task_id2name.at(task_id) << "\n chain_id : " << std::to_string(task->chain_id()) << " order_in_chain : " << std::to_string(task->order_in_chain()) << " device_type : " << DeviceType_Name(stream_id.device_type()) << " stream_index : " << std::to_string(stream_id.stream_index()) << " {\n"; for (const auto& key2consume_regst : task->consumed_regst_desc_id()) { std::string key = key2consume_regst.first; for (int64_t consume_regst_id : key2consume_regst.second.regst_desc_id()) { std::string other_rank_str = ""; CHECK(regst_id2proto.find(consume_regst_id) != regst_id2proto.end()) << " regst_id2proto cannot find: " << consume_regst_id; int64_t consume_task_id = regst_id2proto.at(consume_regst_id).producer_task_id(); CHECK(task_id2proto.find(consume_task_id) != task_id2proto.end()) << " task_id2proto cannot find: " << consume_task_id; int64_t other_rank = task_id2proto.at(consume_task_id)->machine_id(); if (other_rank != rank) { other_rank_str = " , rank: " + std::to_string(other_rank); } CHECK(regst_id2name.find(consume_regst_id) != regst_id2name.end()) << " regst_id2name cannot find: " << consume_regst_id; file_stream << " consume : " << key << " : <- [ " << regst_id2name.at(consume_regst_id) << " ] ( actor_id: " << std::to_string(consume_task_id) << other_rank_str << ", regst: " << RegstId2TensorStr(consume_regst_id) << " )\n"; } } for (const auto& key2produce_regst : task->produced_regst_desc()) { const RegstDescProto& regst = key2produce_regst.second; file_stream << " produce : " << key2produce_regst.first << " regst: " << RegstId2TensorStr(regst.regst_desc_id()) << " {\n"; for (int64_t consumer_task_id : regst.consumer_task_id()) { std::string other_rank_str = ""; CHECK(task_id2proto.find(consumer_task_id) != task_id2proto.end()) << " task_id2proto cannot find " << consumer_task_id; CHECK(task_id2name.find(consumer_task_id) != task_id2name.end()) << " task_id2name cannot find " << consumer_task_id; int64_t other_rank = task_id2proto.at(consumer_task_id)->machine_id(); if (other_rank != rank) { other_rank_str = " , rank: " + std::to_string(other_rank); } file_stream << " -> [ " << task_id2name.at(consumer_task_id) << " ] ( actor_id: " << std::to_string(consumer_task_id) << other_rank_str << " )\n"; } file_stream << " }\n"; } file_stream << "}\n"; } } } const oneflow::OpAttribute& PlanUtil::GetOpAttribute(const Plan* plan, int64_t job_id, const oneflow::KernelConf& kernel_conf) { if (kernel_conf.has_op_attribute()) { return kernel_conf.op_attribute(); } else if (kernel_conf.has_op_attribute_ref()) { auto table_it = plan->job_id2op_attribute_ref_table().find(job_id); CHECK(table_it != plan->job_id2op_attribute_ref_table().end()) << "op attribute ref 
table not found for job id: " << job_id; ; auto it = table_it->second.op_name2op_attribute().find(kernel_conf.op_attribute_ref()); CHECK(it != table_it->second.op_name2op_attribute().end()) << "op attribute ref: " << kernel_conf.op_attribute_ref() << " not found"; return it->second; } else { UNIMPLEMENTED() << "kernel_conf must has either op_attribute or op_attribute_ref. kernel_conf: " << kernel_conf.DebugString(); } } void PlanUtil::PopulateOpAttribute( Plan* plan, const PbMap& job_id2op_attribute_ref_table) { for (auto& task : *plan->mutable_task()) { if (task.exec_sequence().exec_node_size() == 1 && task.exec_sequence().exec_node(0).kernel_conf().has_op_attribute_ref()) { auto* kernel_conf = task.mutable_exec_sequence()->mutable_exec_node(0)->mutable_kernel_conf(); auto table_it = job_id2op_attribute_ref_table.find(task.job_id()); CHECK(table_it != job_id2op_attribute_ref_table.end()) << "op attribute ref table not found for job id: " << task.job_id(); auto it = table_it->second.op_name2op_attribute().find(kernel_conf->op_attribute_ref()); CHECK(it != table_it->second.op_name2op_attribute().end()) << "ref: " << kernel_conf->op_attribute_ref() << " not found"; *kernel_conf->mutable_op_attribute() = it->second; kernel_conf->clear_op_attribute_ref(); } else { for (auto& exec_node : task.exec_sequence().exec_node()) { CHECK(exec_node.kernel_conf().has_op_attribute()) << "op_attribute absent, exec_node: " << exec_node.DebugString(); } } } } /*static*/ StreamId PlanUtil::GetStreamId(const TaskProto& task) { return DecodeStreamIdFromInt64(task.thrd_id()); } /*static*/ int64_t PlanUtil::GetDeviceIndex(const TaskProto& task) { return GetStreamId(task).device_id().device_index(); } /*static*/ void PlanUtil::CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) { auto* job_id2op_attribute_ref_table = plan->mutable_job_id2op_attribute_ref_table(); CHECK(task_proto->exec_sequence().exec_node_size() == 1); auto* exec_node = task_proto->mutable_exec_sequence()->mutable_exec_node(0); CHECK(exec_node->kernel_conf().has_op_attribute()); const std::string op_name = exec_node->kernel_conf().op_attribute().op_conf().name(); auto* op_name2op_attribute = (*job_id2op_attribute_ref_table)[job_id].mutable_op_name2op_attribute(); auto find_it = op_name2op_attribute->find(op_name); if (find_it == op_name2op_attribute->end()) { op_name2op_attribute->insert( {op_name, task_proto->exec_sequence().exec_node(0).kernel_conf().op_attribute()}); } auto* kernel_conf = task_proto->mutable_exec_sequence()->mutable_exec_node(0)->mutable_kernel_conf(); kernel_conf->set_op_attribute_ref(op_name); // NOTE(levi): memory of op_attribute_ is released here. kernel_conf->set_allocated_op_attribute(nullptr); } } // namespace oneflow ================================================ FILE: oneflow/core/job/plan_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_PLAN_UTIL_H_ #define ONEFLOW_CORE_JOB_PLAN_UTIL_H_ #include #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/graph/stream_id.h" #include "oneflow/core/graph/plan_task_graph.h" namespace oneflow { struct PlanUtil { static RegstDescProto* GetSoleProducedDataRegst(TaskProto* task_proto); static std::function MakeGetterTaskProto4TaskId(const Plan& plan); // limited_rank equals -1 means taking care of all ranks. // Otherwise, only take care of rank limited_rank. static void MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job, int64_t limited_rank = -1); static void SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan); static void GenMemBlockAndChunk4Plan(Plan* plan); static void GenMemBlockAndChunkWithVariableOpNames4Plan( Plan* plan, const HashSet& variable_op_names); static void CleanUselessMemBlockAndCheckValid(Plan* plan); static void ToDotFile(const Plan& plan, const std::string& filepath); static std::function MakeMutRegstDesc4Id(Plan* plan); // limited_rank equals -1 means taking care of all ranks. // Otherwise, only take care of rank limited_rank. static void SetForceInplaceMemBlock(Plan* plan, int64_t limited_rank = -1); static void DumpCtrlRegstInfoToPlan(Plan* plan); static void GenCollectiveBoxingPlan(Job* job, Plan* plan); static void GenRegisterHint(Plan* plan); // Generate readable plan log from plan proto. // Use filter_rank to choose which rank to generate. When filter_rank is -1, all rank will be // generated. The default value of filter_rank is -1. static void GenLightPlan(Plan* plan, const std::string& plan_name, int64_t limited_rank = -1); static void PlanMemoryLog(Plan* plan, const std::string& plan_name); static const oneflow::OpAttribute& GetOpAttribute(const Plan* plan, int64_t job_id, const oneflow::KernelConf& kernel_conf); // NOTE(chengcheng): recovery op_attr static void PopulateOpAttribute( Plan* plan, const PbMap& job_id2op_attribute_ref_table); static StreamId GetStreamId(const TaskProto& task); static int64_t GetDeviceIndex(const TaskProto& task); static void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_PLAN_UTIL_H_ ================================================ FILE: oneflow/core/job/qat_config_def.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/config_def.h" namespace oneflow { namespace { REGISTER_SCOPE_CONFIG_DEF().Bool("quantization_aware_training", true, "enable quantization aware training"); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job/rank_compiler.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/rank_compiler.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/intra_job_mem_sharing_util.h" #include "oneflow/core/job/plan_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job_rewriter/job_completer.h" #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { Maybe RankCompiler::Compile(const HashSet& var_op_names, Job* job, Plan* plan) const { #ifdef WITH_CUDA // Use the right device when some plan compilation needs cuda to avoid creating unnecessary cuda // context on cuda:0. CudaCurrentDeviceGuard guard(GetCudaDeviceIndex()); #endif // WITH_CUDA auto task_gph = JUST(RankTaskGraph::New(boxing_task_graph_proto_, var_op_names, rank_)); using std::placeholders::_1; const auto& IsNotMyDuty = [&](const CompTaskNode* comp_task_node) { if (comp_task_node == nullptr) { return false; } const auto& parallel_desc = comp_task_node->op_node()->parallel_desc(); return !task_gph->IsDutyRank(parallel_desc, comp_task_node->machine_id()); }; task_gph->ForEachNode([&](TaskNode* task_node) { auto* comp_task_node = dynamic_cast(task_node); if (IsNotMyDuty(comp_task_node)) { auto* fake_consumed_regsts_provider = dynamic_cast(comp_task_node); CHECK_NOTNULL(fake_consumed_regsts_provider)->ConsumeFakeRegstsIf(); } else { task_node->ConsumeAllRegsts(); } }); task_gph->ForEachNode([&](TaskNode* task_node) { auto* comp_task_node = dynamic_cast(task_node); if (IsNotMyDuty(comp_task_node)) { // Do nothing. because all consumed registers are fake. } else { task_node->PinConsumedRegst(); } }); task_gph->TopoForEachNode(&TaskNode::Build); task_gph->RemoveEmptyRegsts(); task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful); task_gph->DecideExecutionOrder(); task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain(); auto IsReachable = Singleton::Get()->MakePredicatorIsOpNameDataOrCtrlReachable(); const JobDesc& job_desc = GlobalJobDesc(); if (job_desc.enable_inplace()) { task_gph->ForEachGpuDeviceNodes([&](const HashSet& dev_nodes) { if (dev_nodes.empty()) { return; } if ((*dev_nodes.begin())->machine_id() != rank_) { return; } // other ranks are ignored. task_gph->EnableInplaceMemSharing(dev_nodes, IsReachable); }); } task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); }); // put infomation from task_gph into plan. 
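  // Every meaningful task node is serialized into a TaskProto and appended to the plan.
  // Nodes that are not this rank's duty first drop their fake consumed regsts; for
  // kNormalForward, kRepeat and kAcc tasks the (potentially large) OpAttribute is
  // replaced by a named reference via PlanUtil::CreateOpAttributeRef, so identical op
  // attributes are stored only once per job in the plan.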
task_gph->ForEachNode([&](TaskNode* task_node) { if (task_node->IsMeaningLess()) { return; } auto* comp_task_node = dynamic_cast(task_node); if (comp_task_node != nullptr) { const auto& parallel_desc = comp_task_node->op_node()->parallel_desc(); if (!task_gph->IsDutyRank(parallel_desc, task_node->machine_id())) { auto* fake_consumed_regsts_provider = dynamic_cast(comp_task_node); CHECK_NOTNULL(fake_consumed_regsts_provider)->EraseFakeRegstsIf(); } } TaskProto task_proto; task_node->ToProto(&task_proto); if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat || task_node->GetTaskType() == kAcc) { PlanUtil::CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); } plan->mutable_task()->Add(std::move(task_proto)); }); // post-process for plan and delete Singleton. auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf(); (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf(); // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan); PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job, rank_); PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); PlanUtil::SetForceInplaceMemBlock(plan, rank_); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/rank_compiler.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_RANK_COMPILER_H_ #define ONEFLOW_CORE_JOB_RANK_COMPILER_H_ #include "oneflow/core/common/protobuf.h" #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { class RankCompiler final { public: OF_DISALLOW_COPY_AND_MOVE(RankCompiler); RankCompiler(const std::shared_ptr& boxing_task_graph_proto, int64_t rank) : boxing_task_graph_proto_(boxing_task_graph_proto), rank_(rank) {} ~RankCompiler() = default; Maybe Compile(const HashSet& var_op_names, Job* job, Plan* plan) const; private: std::shared_ptr boxing_task_graph_proto_; int64_t rank_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_RANK_COMPILER_H_ ================================================ FILE: oneflow/core/job/rank_group.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/job/rank_group.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { /*static*/ Maybe> RankGroup::New(Symbol parallel_desc) { return DECORATE(&RankGroup::RawNew, ThreadLocal)(parallel_desc); } /*static*/ Maybe> RankGroup::RawNew(Symbol parallel_desc) { CHECK_EQ_OR_RETURN(parallel_desc->sorted_machine_ids().size(), parallel_desc->parallel_num()); const auto& sorted_machine_ids = parallel_desc->sorted_machine_ids(); return New(std::set{sorted_machine_ids.begin(), sorted_machine_ids.end()}); } /*static*/ Maybe> RankGroup::New(const std::set& ranks) { static thread_local std::map, Symbol> map; auto iter = map.find(ranks); if (iter == map.end()) { RankGroup rank_group; JUST(rank_group.Init(ranks)); iter = map.emplace(ranks, SymbolOf(rank_group)).first; } return iter->second; } namespace { Maybe> CalcDefaultParallelDesc(DeviceType device_type, Symbol rank_group) { ParallelConf parallel_conf; parallel_conf.set_device_tag(*JUST(DeviceTag4DeviceType(device_type))); JUST(rank_group->ForEachRank([&](int64_t rank) -> Maybe { int64_t local_rank = GlobalProcessCtx::LocalRank(rank); parallel_conf.add_device_name(std::string("@") + std::to_string(rank) + ":" + std::to_string(local_rank)); return Maybe::Ok(); })); return SymbolOf(ParallelDesc(parallel_conf)); } auto* CachedDefaultParallelDesc = DECORATE(&CalcDefaultParallelDesc, ThreadLocal); } // namespace /*static*/ Maybe> RankGroup::GetDefaultParallelDesc( DeviceType device_type, Symbol rank_group) { return CachedDefaultParallelDesc(device_type, rank_group); } namespace { Maybe> AllWorldRanks() { const auto& ranks = std::make_shared>(); for (int i = 0; i < GlobalProcessCtx::WorldSize(); ++i) { ranks->insert(i); } return ranks; } } // namespace /*static*/ Maybe> RankGroup::DefaultRankGroup() { const auto& all_wold_ranks = JUST(AllWorldRanks()); const auto& rank_group = JUST(RankGroup::New(*all_wold_ranks)); return rank_group; } Maybe RankGroup::Init(const std::set& ranks) { ranks_ = ranks; // Initialize rank2next_rank_in_ring_ and rank2prev_rank_in_ring_ { CHECK_GT_OR_RETURN(ranks.size(), 0); int64_t last = *(--ranks.end()); for (int64_t i : ranks) { CHECK_OR_RETURN(rank2next_rank_in_ring_.emplace(last, i).second); CHECK_OR_RETURN(rank2prev_rank_in_ring_.emplace(i, last).second); last = i; } } // Initialize hash_value_ hash_value_ = 0; for (int64_t i : ranks) { HashCombine(&hash_value_, i); } return Maybe::Ok(); } Maybe RankGroup::GetNextRankInRing(int64_t rank) const { return MapAt(rank2next_rank_in_ring_, rank); } Maybe RankGroup::GetNextRankInRing() const { return GetNextRankInRing(GlobalProcessCtx::Rank()); } Maybe RankGroup::GetPrevRankInRing(int64_t rank) const { return MapAt(rank2prev_rank_in_ring_, rank); } Maybe RankGroup::GetPrevRankInRing() const { return GetPrevRankInRing(GlobalProcessCtx::Rank()); } bool RankGroup::ContainingCurrentRank() const { return rank2next_rank_in_ring_.count(GlobalProcessCtx::Rank()) > 0; } Maybe RankGroup::ForEachRank(const std::function(int64_t)>& DoEach) const { for (int64_t i : ranks_) { JUST(DoEach(i)); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/rank_group.h ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_RANK_GROUP_H_ #define ONEFLOW_CORE_JOB_RANK_GROUP_H_ #include #include #include #include #include #include #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/device_type.h" namespace oneflow { class ParallelDesc; class RankGroup final { public: ~RankGroup() = default; static Maybe> New(const std::set& ranks); static Maybe> New(Symbol parallel_desc); static Maybe> DefaultRankGroup(); static Maybe> GetDefaultParallelDesc(DeviceType device_type, Symbol rank_group); bool operator==(const RankGroup& that) const { return this->ranks_ == that.ranks_; } bool operator!=(const RankGroup& that) const { return !(*this == that); } size_t size() const { return ranks_.size(); } size_t hash_value() const { return hash_value_; } Maybe GetNextRankInRing(int64_t rank) const; Maybe GetNextRankInRing() const; Maybe GetPrevRankInRing(int64_t rank) const; Maybe GetPrevRankInRing() const; bool ContainingCurrentRank() const; Maybe ForEachRank(const std::function(int64_t)>&) const; private: RankGroup() = default; Maybe Init(const std::set& ranks); static Maybe> RawNew(Symbol parallel_desc); std::set ranks_; std::unordered_map rank2next_rank_in_ring_; std::unordered_map rank2prev_rank_in_ring_; size_t hash_value_; }; } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::RankGroup& rank_group) const { return rank_group.hash_value(); } }; } // namespace std #endif // ONEFLOW_CORE_JOB_RANK_GROUP_H_ ================================================ FILE: oneflow/core/job/rank_group_scope.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/rank_group_scope.h" namespace oneflow { /*static*/ Maybe> RankGroupScope::CurrentRankGroup() { return RankGroup::DefaultRankGroup(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/rank_group_scope.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_ #define ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_ #include "oneflow/core/job/rank_group.h" #include "oneflow/core/common/symbol.h" namespace oneflow { // NOTE(daquexian): this scope class is not actually used. We only keep // it in case we need it in the future. class RankGroupScope final { public: static Maybe> CurrentRankGroup(); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_ ================================================ FILE: oneflow/core/job/rank_group_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" #include #include #include "oneflow/core/common/util.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" namespace oneflow { namespace test { TEST(RankGroup, two_rank) { const auto& rank_group = CHECK_JUST(RankGroup::New(std::set{0, 1})); int64_t rank = 0; rank = CHECK_JUST(rank_group->GetNextRankInRing(0)); ASSERT_EQ(rank, 1); rank = CHECK_JUST(rank_group->GetNextRankInRing(1)); ASSERT_EQ(rank, 0); rank = CHECK_JUST(rank_group->GetPrevRankInRing(0)); ASSERT_EQ(rank, 1); rank = CHECK_JUST(rank_group->GetPrevRankInRing(1)); ASSERT_EQ(rank, 0); } TEST(RankGroup, nonconsecutive_rank) { const auto& rank_group = CHECK_JUST(RankGroup::New(std::set{0, 1, 3, 4})); int64_t rank = 0; rank = CHECK_JUST(rank_group->GetNextRankInRing(0)); ASSERT_EQ(rank, 1); rank = CHECK_JUST(rank_group->GetNextRankInRing(1)); ASSERT_EQ(rank, 3); rank = CHECK_JUST(rank_group->GetNextRankInRing(3)); ASSERT_EQ(rank, 4); rank = CHECK_JUST(rank_group->GetNextRankInRing(4)); ASSERT_EQ(rank, 0); bool is_ok = TRY(rank_group->GetNextRankInRing(2)).IsOk(); ASSERT_FALSE(is_ok); rank = CHECK_JUST(rank_group->GetPrevRankInRing(1)); ASSERT_EQ(rank, 0); rank = CHECK_JUST(rank_group->GetPrevRankInRing(3)); ASSERT_EQ(rank, 1); rank = CHECK_JUST(rank_group->GetPrevRankInRing(4)); ASSERT_EQ(rank, 3); rank = CHECK_JUST(rank_group->GetPrevRankInRing(0)); ASSERT_EQ(rank, 4); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/job/regularizer_conf.proto ================================================ syntax = "proto2"; package oneflow; message L1L2RegularizerConf { optional float l1 = 1 [default = 0.0]; optional float l2 = 2 [default = 0.0]; } message RegularizerConf { oneof type { L1L2RegularizerConf l1_l2_conf = 1; } } ================================================ FILE: oneflow/core/job/resource.proto ================================================ syntax = "proto2"; package 
oneflow; import public "oneflow/core/common/device_type.proto"; message CollectiveBoxingConf { // global optional bool enable_fusion = 1 [default = true]; optional int64 num_callback_threads = 2 [default = 4]; // nccl optional int64 nccl_num_streams = 101 [default = 1]; optional int64 nccl_fusion_threshold_mb = 102 [default = 16]; optional bool nccl_fusion_all_reduce = 103 [default = true]; optional bool nccl_fusion_reduce_scatter = 104 [default = false]; optional bool nccl_fusion_all_gather = 105 [default = false]; optional bool nccl_fusion_reduce = 106 [default = true]; optional bool nccl_fusion_broadcast = 107 [default = true]; optional bool nccl_fusion_all_reduce_use_buffer = 108 [default = false]; optional int64 nccl_fusion_max_ops = 109 [default = 64]; optional bool nccl_enable_all_to_all = 110 [default = false]; optional bool nccl_enable_mixed_fusion = 111 [default = false]; } message CudnnConfig { optional bool enable_cudnn = 1 [default = true]; optional int64 cudnn_buf_limit_mbyte = 2 [default = 1024]; // 1GByte optional int32 cudnn_conv_force_fwd_algo = 3; optional int32 cudnn_conv_force_bwd_data_algo = 4; optional int32 cudnn_conv_force_bwd_filter_algo = 5; optional bool cudnn_conv_heuristic_search_algo = 6 [default = true]; optional bool cudnn_conv_use_deterministic_algo_only = 7 [default = false]; optional bool enable_cudnn_fused_normalization_add_relu = 8; optional bool cudnn_conv_enable_pseudo_half = 9 [default = true]; } message Resource { optional int32 machine_num = 1 [default = 0]; optional int32 cpu_device_num = 5 [default = 0]; optional int32 comm_net_worker_num = 6 [default = 4]; optional int32 max_mdsave_worker_num = 7 [default = 64]; optional uint64 reserved_host_mem_mbyte = 12 [default = 500]; optional uint64 reserved_device_mem_mbyte = 13 [default = 500]; optional int32 compute_thread_pool_size = 15; optional bool enable_thread_local_cache = 16 [default = true]; optional int64 thread_local_cache_max_size = 17 [default = 67108864]; // 64M optional bool enable_debug_mode = 18 [default = false]; optional bool enable_tensor_float_32_compute = 20 [default = true]; optional CollectiveBoxingConf collective_boxing_conf = 19; // NOTE(chengcheng) to reuse nccl memory and speed up optional bool nccl_use_compute_stream = 30 [default = false]; optional bool disable_group_boxing_by_dst_parallel = 31 [default = false]; optional CudnnConfig cudnn_conf = 32; optional bool enable_legacy_model_io = 33 [default = true]; optional bool enable_legacy_model_io_v2 = 34 [default = false]; } ================================================ FILE: oneflow/core/job/resource_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/job/resource.pb.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/control/global_process_ctx.h" #ifdef WITH_CUDA #include #endif namespace oneflow { ResourceDesc::ResourceDesc(const Resource& resource, int64_t num_process_per_node) : resource_(resource) { CHECK_GT(resource_.machine_num(), 0); CHECK_LE(resource_.machine_num(), Singleton::Get()->TotalMachineNum()); for (int i = 0; i < GlobalProcessCtx::WorldSize(); ++i) { CHECK(process_ranks_.emplace(i).second); } } Machine ResourceDesc::machine(int32_t idx) const { CHECK_GE(idx, 0); CHECK(process_ranks().find(idx) != process_ranks().end()); if (Singleton::Get()->has_ctrl_bootstrap_conf()) { CHECK_NOTNULL(Singleton::Get()); CHECK_GE(Singleton::Get()->ctrl_addr().size(), process_ranks().size()); Machine machine; const Address& addr = Singleton::Get()->ctrl_addr(idx); machine.set_addr(addr.host()); return machine; } else { return Singleton::Get()->machine(idx); } } int32_t ResourceDesc::ComputeThreadPoolSize() const { if (resource_.has_compute_thread_pool_size()) { CHECK_GT(resource_.compute_thread_pool_size(), 0); return resource_.compute_thread_pool_size(); } else { return CpuDeviceNum(); } } bool ResourceDesc::enable_debug_mode() const { return IsInDebugMode() || resource_.enable_debug_mode(); } CollectiveBoxingConf ResourceDesc::collective_boxing_conf() const { if (resource_.has_collective_boxing_conf()) { return resource_.collective_boxing_conf(); } else { return CollectiveBoxingConf(); } } bool ResourceDesc::nccl_use_compute_stream() const { #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 return resource_.nccl_use_compute_stream(); #elif defined(WITH_NPU) return resource_.nccl_use_compute_stream(); #elif defined(WITH_MLU) return resource_.nccl_use_compute_stream(); #else return false; #endif } void ResourceDesc::DumpCudnnConf(const JobConfigProto& job_conf) { auto* cudnn_conf = resource_.mutable_cudnn_conf(); if (job_conf.has_enable_cudnn()) { cudnn_conf->set_enable_cudnn(job_conf.enable_cudnn()); } if (job_conf.has_cudnn_buf_limit_mbyte()) { cudnn_conf->set_cudnn_buf_limit_mbyte(job_conf.cudnn_buf_limit_mbyte()); } if (job_conf.has_cudnn_conv_force_fwd_algo()) { cudnn_conf->set_cudnn_conv_force_fwd_algo(job_conf.cudnn_conv_force_fwd_algo()); } if (job_conf.has_cudnn_conv_force_bwd_data_algo()) { cudnn_conf->set_cudnn_conv_force_bwd_data_algo(job_conf.cudnn_conv_force_bwd_data_algo()); } if (job_conf.has_cudnn_conv_force_bwd_filter_algo()) { cudnn_conf->set_cudnn_conv_force_bwd_filter_algo(job_conf.cudnn_conv_force_bwd_filter_algo()); } if (job_conf.has_cudnn_conv_heuristic_search_algo()) { cudnn_conf->set_cudnn_conv_heuristic_search_algo(job_conf.cudnn_conv_heuristic_search_algo()); } if (job_conf.has_cudnn_conv_use_deterministic_algo_only()) { cudnn_conf->set_cudnn_conv_use_deterministic_algo_only( job_conf.cudnn_conv_use_deterministic_algo_only()); } if (job_conf.has_enable_cudnn_fused_normalization_add_relu()) { cudnn_conf->set_enable_cudnn_fused_normalization_add_relu( job_conf.enable_cudnn_fused_normalization_add_relu()); } if (job_conf.has_cudnn_conv_enable_pseudo_half()) { cudnn_conf->set_cudnn_conv_enable_pseudo_half(job_conf.cudnn_conv_enable_pseudo_half()); } } void ResourceDesc::Update(const Resource& reso_conf) { resource_.CopyFrom(reso_conf); } } // namespace oneflow ================================================ FILE: oneflow/core/job/resource_desc.h 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_RESOURCE_DESC_H_ #define ONEFLOW_CORE_JOB_RESOURCE_DESC_H_ #include #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/resource.pb.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { static const size_t kMB = 1024 * 1024; class ResourceDesc final { public: OF_DISALLOW_COPY_AND_MOVE(ResourceDesc); ResourceDesc(const Resource& resource, int64_t num_process_per_node); ResourceDesc(const Resource& resource) : resource_(resource) {} // TODO(yaochi): Only for eager, remove it later ~ResourceDesc() = default; const std::set& process_ranks() const { return process_ranks_; } __attribute__((deprecated)) Machine machine(int32_t idx) const; size_t CommNetWorkerNum() const { return resource_.comm_net_worker_num(); } int32_t CpuDeviceNum() const { return resource_.cpu_device_num(); } int32_t MaxMdSaveWorkerNum() const { return resource_.max_mdsave_worker_num(); } size_t reserved_host_mem_byte() const { return resource_.reserved_host_mem_mbyte() * kMB; } size_t reserved_device_mem_byte() const { return resource_.reserved_device_mem_mbyte() * kMB; } bool enable_thread_local_cache() const { return resource_.enable_thread_local_cache(); } size_t thread_local_cache_max_size() const { return resource_.thread_local_cache_max_size(); } int32_t ComputeThreadPoolSize() const; bool enable_debug_mode() const; CollectiveBoxingConf collective_boxing_conf() const; bool nccl_use_compute_stream() const; void SetMachineNum(int32_t val) { resource_.set_machine_num(val); } void SetCpuDeviceNum(int32_t val) { resource_.set_cpu_device_num(val); } bool enable_tensor_float_32_compute() const { return resource_.enable_tensor_float_32_compute(); } const Resource& resource() const { return resource_; } void DumpCudnnConf(const JobConfigProto& job_conf); void Update(const Resource& reso_conf); private: Resource resource_; std::set process_ranks_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_RESOURCE_DESC_H_ ================================================ FILE: oneflow/core/job/runtime.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/runtime.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/runtime_context.h" #include "oneflow/core/job/runtime_job_descs.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/graph/task_node.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/register/register_manager.h" #include "oneflow/user/summary/events_writer.h" namespace oneflow { namespace { void SendCmdMsg(const std::vector& tasks, ActorCmd cmd) { for (const TaskProto* task : tasks) { ActorMsg msg = ActorMsg::BuildCommandMsg(task->task_id(), cmd); Singleton::Get()->SendMsg(msg); } } void HandoutTasks(const std::vector& tasks) { for (const TaskProto* task : tasks) { Singleton::Get()->GetThrd(task->thrd_id())->AddTask(*task); } SendCmdMsg(tasks, ActorCmd::kConstructActor); } bool HasNonCtrlConsumedRegstDescId(const TaskProto& task) { for (const auto& pair : task.consumed_regst_desc_id()) { if (pair.first == "in_ctrl") { continue; } return true; } return false; } } // namespace Runtime::Runtime( const Plan& plan, const HashMap& variable_op_name2eager_blob_object) { DumpThreadIdsFromPlan(plan); { // NOTE(chengcheng): All runtime global(singleton) objects AddPlan Singleton::Get()->AddPlan(plan, variable_op_name2eager_blob_object); Singleton::Get()->AddThreads(thread_ids_); Singleton::Get()->AddPlan(plan); collective_boxing_scheduler_plan_token_ = Singleton::Get()->AddPlan(plan); #if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU) const auto& vaild_ccl_comm_mgr_device_types = EagerCclCommMgrBuilder::Get().vaild_ccl_comm_mgr_device_types(); if (!vaild_ccl_comm_mgr_device_types.empty() && !Singleton::Get()) { Singleton::SetAllocated( EagerCclCommMgrBuilder::Get().NewCclCommMgr(vaild_ccl_comm_mgr_device_types.front())); } Singleton::Get()->CreateCommFromPlan(plan); #endif // defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU) } std::vector source_tasks; source_tasks.reserve(plan.task().size()); std::vector other_tasks; other_tasks.reserve(plan.task().size()); int64_t this_machine_task_num = 0; for (const TaskProto& task : plan.task()) { if (task.machine_id() != GlobalProcessCtx::Rank()) { continue; } if (!HasNonCtrlConsumedRegstDescId(task)) { source_tasks.emplace_back(&task); } else { other_tasks.emplace_back(&task); } auto it = job_id2actor_size_.find(task.job_id()); if (it == job_id2actor_size_.end()) { auto emplace_ret_pair = job_id2actor_size_.emplace(task.job_id(), 0); CHECK(emplace_ret_pair.second); it = emplace_ret_pair.first; } it->second++; this_machine_task_num++; } RuntimeCtx* runtime_ctx = Singleton::Get(); runtime_ctx->NewCounter("constructing_actor_cnt", this_machine_task_num); HandoutTasks(source_tasks); HandoutTasks(other_tasks); runtime_ctx->WaitUntilCntEqualZero("constructing_actor_cnt"); VLOG(3) << "Actors on this machine constructed"; OF_SESSION_BARRIER(); VLOG(3) << "Actors on every machine constructed"; for (auto pair : job_id2actor_size_) { runtime_ctx->NewCounter(GetRunningActorCountKeyByJobId(pair.first), pair.second); } SendCmdMsg(source_tasks, ActorCmd::kStart); } Runtime::~Runtime() { for (auto pair : job_id2actor_size_) { 
Singleton::Get()->WaitUntilCntEqualZero(GetRunningActorCountKeyByJobId(pair.first)); } OF_SESSION_BARRIER(); Singleton::Get()->DeleteThreads(independent_thread_ids_); Singleton::Get()->DeletePlan( collective_boxing_scheduler_plan_token_); } void Runtime::DumpThreadIdsFromPlan(const Plan& plan) { const int64_t this_rank = GlobalProcessCtx::Rank(); for (const TaskProto& task : plan.task()) { TaskId task_id = DecodeTaskIdFromInt64(task.task_id()); StreamId stream_id = task_id.stream_id(); if (stream_id.rank() != this_rank) { continue; } int64_t thrd_id = EncodeStreamIdToInt64(stream_id); thread_ids_.insert(thrd_id); // NOTE(chengcheng): there is not a interface to query whether a task type is indenpendent, // so use hard code. if (task.task_type() == TaskType::kWaitAndSendIds || task.task_type() == TaskType::kCriticalSectionWaitTick) { CHECK(independent_thread_ids_.insert(thrd_id).second) << " RuntimeError! Thread : " << thrd_id << " not independent with task proto: " << task.DebugString(); } } } } // namespace oneflow ================================================ FILE: oneflow/core/job/runtime.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_RUNTIME_H_ #define ONEFLOW_CORE_JOB_RUNTIME_H_ #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/job/collective_boxing/scheduler.h" namespace oneflow { namespace vm { class EagerBlobObject; } class Runtime final { public: OF_DISALLOW_COPY_AND_MOVE(Runtime); Runtime() = delete; ~Runtime(); // TODO(chengcheng): refactor Runtime interface about variable_op_name2eager_blob_object Runtime(const Plan& plan, const HashMap& variable_op_name2eager_blob_object); private: void DumpThreadIdsFromPlan(const Plan& plan); HashMap job_id2actor_size_; HashSet thread_ids_; HashSet independent_thread_ids_; boxing::collective::SchedulerPlanToken* collective_boxing_scheduler_plan_token_; }; } // namespace oneflow #endif ================================================ FILE: oneflow/core/job/runtime_buffer_managers_scope.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/runtime_buffer_managers_scope.h" #include "oneflow/core/job/job_instance.h" namespace oneflow { RuntimeBufferManagersScope::RuntimeBufferManagersScope() { Singleton>::New(); Singleton>>::New(); } RuntimeBufferManagersScope::~RuntimeBufferManagersScope() { Singleton>>::Delete(); Singleton>::Delete(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/runtime_buffer_managers_scope.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_RUNTIME_BUFFER_MANAGERS_SCOPE_H_ #define ONEFLOW_CORE_JOB_RUNTIME_BUFFER_MANAGERS_SCOPE_H_ #include "oneflow/core/common/util.h" namespace oneflow { class RuntimeBufferManagersScope final { public: OF_DISALLOW_COPY_AND_MOVE(RuntimeBufferManagersScope); RuntimeBufferManagersScope(); ~RuntimeBufferManagersScope(); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_RUNTIME_BUFFER_MANAGERS_SCOPE_H_ ================================================ FILE: oneflow/core/job/runtime_buffers_scope.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/runtime_buffers_scope.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_instance.h" namespace oneflow { RuntimeBuffersScope::RuntimeBuffersScope(const JobConfs& job_confs) { size_t job_size = Singleton::Get()->size(); Singleton>::Get()->NewBuffer(kBufferNameGlobalWaitJobId, job_size); auto* buffer_mgr = Singleton>>::Get(); for (const auto& pair : job_confs.job_id2job_conf()) { const auto& job_name = pair.second.job_name(); CHECK_EQ(pair.first, Singleton::Get()->at(job_name)); size_t concurrency_width = pair.second.concurrency_width(); buffer_mgr->NewBuffer(GetCallbackNotifierBufferName(job_name), concurrency_width); } } RuntimeBuffersScope::~RuntimeBuffersScope() { auto* buffer_mgr = Singleton>>::Get(); for (const auto& pair : *Singleton::Get()) { const auto& job_name = pair.first; buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Close(); } Singleton>::Get()->Get(kBufferNameGlobalWaitJobId)->Close(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/runtime_buffers_scope.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_RUNTIME_BUFFERS_SCOPE_H_ #define ONEFLOW_CORE_JOB_RUNTIME_BUFFERS_SCOPE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/plan.pb.h" namespace oneflow { class RuntimeBuffersScope final { public: OF_DISALLOW_COPY_AND_MOVE(RuntimeBuffersScope); RuntimeBuffersScope(const JobConfs& job_confs); ~RuntimeBuffersScope(); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_RUNTIME_BUFFERS_SCOPE_H_ ================================================ FILE: oneflow/core/job/runtime_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/runtime_context.h" namespace oneflow { void RuntimeCtx::NewCounter(const std::string& name, int64_t val) { VLOG(3) << "NewCounter " << name << " " << val; CHECK(counters_.emplace(name, std::make_unique(val)).second); } void RuntimeCtx::DecreaseCounter(const std::string& name) { auto it = counters_.find(name); CHECK(it != counters_.end()); int64_t cur_val = it->second->Decrease(); VLOG(3) << "DecreaseCounter " << name << ", current val is " << cur_val; } void RuntimeCtx::WaitUntilCntEqualZero(const std::string& name) { auto it = counters_.find(name); CHECK(it != counters_.end()); it->second->WaitForeverUntilCntEqualZero(); counters_.erase(it); } std::string GetRunningActorCountKeyByJobId(int64_t job_id) { return "job_" + std::to_string(job_id) + "_running_actor_count"; } } // namespace oneflow ================================================ FILE: oneflow/core/job/runtime_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
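// Sketch of the counter protocol implemented by RuntimeCtx in runtime_context.cpp above (the
// Singleton template argument is stripped in this extract; Singleton<RuntimeCtx> is assumed,
// mirroring runtime.cpp). Runtime::Runtime uses exactly this pattern for "constructing_actor_cnt":
// the counter is created sized to the number of local tasks, each actor decrements it once it is
// constructed, and the constructor blocks until it reaches zero.
#include "oneflow/core/job/runtime_context.h"

namespace oneflow {

void WaitUntilLocalActorsConstructed(int64_t this_machine_task_num) {
  RuntimeCtx* runtime_ctx = Singleton<RuntimeCtx>::Get();
  runtime_ctx->NewCounter("constructing_actor_cnt", this_machine_task_num);
  // Each actor thread would call
  //   Singleton<RuntimeCtx>::Get()->DecreaseCounter("constructing_actor_cnt");
  // when its construction finishes.
  runtime_ctx->WaitUntilCntEqualZero("constructing_actor_cnt");
}

}  // namespace oneflow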
*/ #ifndef ONEFLOW_CORE_JOB_RUNTIME_CONTEXT_H_ #define ONEFLOW_CORE_JOB_RUNTIME_CONTEXT_H_ #include "oneflow/core/common/blocking_counter.h" namespace oneflow { class RuntimeCtx final { public: OF_DISALLOW_COPY_AND_MOVE(RuntimeCtx); RuntimeCtx() = default; ~RuntimeCtx() = default; void NewCounter(const std::string& name, int64_t val); void DecreaseCounter(const std::string& name); void WaitUntilCntEqualZero(const std::string& name); private: HashMap> counters_; }; std::string GetRunningActorCountKeyByJobId(int64_t job_id); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_RUNTIME_CONTEXT_H_ ================================================ FILE: oneflow/core/job/runtime_job_descs.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/runtime_job_descs.h" namespace oneflow { void RuntimeJobDescs::AddPlan(const Plan& plan) { for (const auto& pair : plan.job_confs().job_id2job_conf()) { auto job_desc = std::make_unique(pair.second, pair.first); CHECK(job_id2job_desc_.emplace(pair.first, std::move(job_desc)).second); } } const JobDesc& RuntimeJobDescs::job_desc(int64_t job_id) const { auto it = job_id2job_desc_.find(job_id); CHECK(it != job_id2job_desc_.end()); return *(it->second); } } // namespace oneflow ================================================ FILE: oneflow/core/job/runtime_job_descs.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_RUNTIME_JOB_DESCS_H_ #define ONEFLOW_CORE_JOB_RUNTIME_JOB_DESCS_H_ #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/job_desc.h" namespace oneflow { class RuntimeJobDescs final { public: RuntimeJobDescs() = default; ~RuntimeJobDescs() = default; void AddPlan(const Plan& plan); const JobDesc& job_desc(int64_t job_id) const; private: HashMap> job_id2job_desc_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_RUNTIME_JOB_DESCS_H_ ================================================ FILE: oneflow/core/job/sbp_infer_hint.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
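// Small usage sketch for RuntimeJobDescs above: AddPlan builds one JobDesc per entry of
// plan.job_confs().job_id2job_conf(), and job_desc(job_id) returns it (CHECK-failing for unknown
// ids). The free-standing helper below is illustrative; the runtime itself registers the plan on a
// global RuntimeJobDescs instance.
#include "oneflow/core/job/runtime_job_descs.h"

namespace oneflow {

const JobDesc& LookUpJobDesc(RuntimeJobDescs* job_descs, const Plan& plan, int64_t job_id) {
  job_descs->AddPlan(plan);
  return job_descs->job_desc(job_id);
}

}  // namespace oneflow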
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_ #define ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_ #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/register/blob_desc.h" namespace oneflow { class SbpInferHint final { public: SbpInferHint(const ParallelDesc* parallel_desc, const BlobDesc* logical_blob_desc, const SbpParallel* sbp_parallel) : parallel_desc_(parallel_desc), logical_blob_desc_(logical_blob_desc), sbp_parallel_(sbp_parallel) {} SbpInferHint(const SbpInferHint&) = default; ~SbpInferHint() = default; // Getters const ParallelDesc& parallel_desc() const { return *parallel_desc_; } const BlobDesc& logical_blob_desc() const { return *logical_blob_desc_; } const SbpParallel& sbp_parallel() const { return *sbp_parallel_; } private: const ParallelDesc* parallel_desc_; const BlobDesc* logical_blob_desc_; const SbpParallel* sbp_parallel_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_ ================================================ FILE: oneflow/core/job/sbp_parallel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/framework/nd_sbp.h" namespace oneflow { bool operator==(const SbpParallel& lhs, const SbpParallel& rhs) { if (lhs.parallel_type_case() != rhs.parallel_type_case()) { return false; } if (lhs.has_split_parallel()) { return lhs.split_parallel().axis() == rhs.split_parallel().axis(); } else if (lhs.has_broadcast_parallel()) { return true; } else if (lhs.has_partial_sum_parallel()) { return true; } else { UNIMPLEMENTED(); } } bool operator==(const NdSbp& lhs, const NdSbp& rhs) { if (lhs.sbp_parallel_size() != rhs.sbp_parallel_size()) { return false; } for (int i = 0; i < lhs.sbp_parallel_size(); ++i) { if (lhs.sbp_parallel(i) != rhs.sbp_parallel(i)) { return false; } } return true; ; } bool operator==(const SbpSignature& lhs, const SbpSignature& rhs) { if (lhs.bn_in_op2sbp_parallel_size() != rhs.bn_in_op2sbp_parallel_size()) { return false; } const auto& lhs_map = lhs.bn_in_op2sbp_parallel(); const auto& rhs_map = rhs.bn_in_op2sbp_parallel(); for (const auto& lhs_pair : lhs_map) { const auto& rhs_iter = rhs_map.find(lhs_pair.first); if (rhs_iter == rhs_map.end()) { return false; } if (lhs_pair.second != rhs_iter->second) { return false; } } return true; } bool operator==(const NdSbpSignature& lhs, const NdSbpSignature& rhs) { if (lhs.bn_in_op2nd_sbp_size() != rhs.bn_in_op2nd_sbp_size()) { return false; } const auto& lhs_map = lhs.bn_in_op2nd_sbp(); const auto& rhs_map = rhs.bn_in_op2nd_sbp(); for (const auto& lhs_pair : lhs_map) { const auto& rhs_iter = rhs_map.find(lhs_pair.first); if (rhs_iter == rhs_map.end()) { return false; } if (lhs_pair.second != rhs_iter->second) { return false; } } return true; } Maybe> MakeSplitSbpParallel(int axis) { CHECK_LT_OR_RETURN(axis, kMaxSplitAxis); SbpParallel split_sbp_parallel; split_sbp_parallel.mutable_split_parallel()->set_axis(axis); return SymbolOf(split_sbp_parallel); } Maybe> MakeBroadcastSbpParallel() { SbpParallel broadcast_sbp; broadcast_sbp.mutable_broadcast_parallel(); return SymbolOf(broadcast_sbp); } Maybe> MakePartialSumSbpParallel() { SbpParallel partial_sum_sbp; partial_sum_sbp.mutable_partial_sum_parallel(); return SymbolOf(partial_sum_sbp); } // S -> S // P -> B // B -> P SbpParallel GetDualSbpParallel(const SbpParallel& sbp_parallel) { SbpParallel ret(sbp_parallel); if (sbp_parallel.has_split_parallel()) { // do nothing } else if (sbp_parallel.has_broadcast_parallel()) { ret.mutable_partial_sum_parallel(); } else if (sbp_parallel.has_partial_sum_parallel()) { ret.mutable_broadcast_parallel(); } else { UNIMPLEMENTED(); } return ret; } bool IsSbpSignatureContaining(const SbpSignature& bigger, const SbpSignature& smaller) { auto& bn2sbp = bigger.bn_in_op2sbp_parallel(); for (const auto& pair : smaller.bn_in_op2sbp_parallel()) { if (pair.second.parallel_type_case() == SbpParallel::PARALLEL_TYPE_NOT_SET) { continue; } CHECK(bn2sbp.find(pair.first) != bn2sbp.end()) << pair.first; if (bn2sbp.at(pair.first) != pair.second) { return false; } } return true; } void FilterSbpSignatureList(const SbpSignatureList& sbp_sig_list, const SbpSignature& sbp_sig_conf, SbpSignatureList* filtered_sbp_sig_list) { for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) { if (IsSbpSignatureContaining(sbp_signature, sbp_sig_conf)) { *filtered_sbp_sig_list->mutable_sbp_signature()->Add() = sbp_signature; } } } double ComputCopyCostBetweenTwoSbpParallel(const SbpInferHint& producer_sbp_infer_hint, const 
SbpParallel& consumer_sbp_parallel) { if (producer_sbp_infer_hint.sbp_parallel() == consumer_sbp_parallel) { return 0.0; } if (consumer_sbp_parallel.has_partial_sum_parallel()) { return GetMaxVal(); } if (producer_sbp_infer_hint.sbp_parallel().has_broadcast_parallel()) { return GetMaxVal(); } const auto& logical_blob_desc = producer_sbp_infer_hint.logical_blob_desc(); return logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); } double ComputeIbnCopyCost4SbpSig( const PbRpf& ibns, const std::function(const std::string&)>& SbpInferHint4Ibn, const SbpSignature& sbp_signature) { double cost = 0; for (const auto& ibn : ibns) { const auto& consumer_sbp_parallel = sbp_signature.bn_in_op2sbp_parallel().find(ibn)->second; cost += ComputCopyCostBetweenTwoSbpParallel(*CHECK_JUST(SbpInferHint4Ibn(ibn)), consumer_sbp_parallel); } return cost; } std::function MakeGetterIbnCopyCost4SbpSig( const PbRpf& ibns, const std::function(const std::string&)>& SbpInferHint4Ibn, const SbpSignatureList& sbp_sig_list) { auto sbp_sig2ibn_copy_cast = std::make_shared>(); for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) { double cost = ComputeIbnCopyCost4SbpSig(ibns, SbpInferHint4Ibn, sbp_signature); CHECK(sbp_sig2ibn_copy_cast->emplace(&sbp_signature, cost).second); } return [sbp_sig2ibn_copy_cast](const SbpSignature* sbp_sig) -> double { return sbp_sig2ibn_copy_cast->at(sbp_sig); }; } std::function MakeGetterOrderValue4SbpSig( const SbpSignatureList& sbp_sig_list, const std::function& CalcOrderValue4SbpSig) { auto sbp_sig2order_value = std::make_shared>(); for (const SbpSignature& sbp_signature : sbp_sig_list.sbp_signature()) { sbp_sig2order_value->emplace(&sbp_signature, CalcOrderValue4SbpSig(sbp_signature)); } return [sbp_sig2order_value](const SbpSignature* sbp_sig) { return sbp_sig2order_value->at(sbp_sig); }; } void SortSbpSignatureListByCopyCost( const SbpSignatureList& sbp_sig_list, const PbRpf& ibns, const std::function(const std::string&)>& SbpInferHint4Ibn, const std::function& CalcOrderValue4SbpSig, std::vector* sorted_sbp_signatures) { auto OrderValue4SbpSig = MakeGetterOrderValue4SbpSig(sbp_sig_list, CalcOrderValue4SbpSig); auto IbnCopyCost4SbpSig = MakeGetterIbnCopyCost4SbpSig(ibns, SbpInferHint4Ibn, sbp_sig_list); for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) { sorted_sbp_signatures->emplace_back(&sbp_signature); } std::sort(sorted_sbp_signatures->begin(), sorted_sbp_signatures->end(), [&](const SbpSignature* lhs, const SbpSignature* rhs) { if (OrderValue4SbpSig(lhs) < OrderValue4SbpSig(rhs)) { return true; } if (OrderValue4SbpSig(lhs) > OrderValue4SbpSig(rhs)) { return false; } return IbnCopyCost4SbpSig(lhs) < IbnCopyCost4SbpSig(rhs); }); } bool IsValidSbpParallelString(const std::string& sbp_str) { SbpParallel sbp_parallel; return ParseSbpParallelFromString(sbp_str, &sbp_parallel); } bool ParseNdSbpFromLongString(const std::string& nd_sbp_str, NdSbp* nd_sbp) { bool success = true; Split(nd_sbp_str, ",", [&](std::string&& sbp_str) { SbpParallel* sbp_parallel = nd_sbp->add_sbp_parallel(); bool ret = ParseSbpParallelFromString(sbp_str, sbp_parallel); if (!ret) { success = false; } }); if (nd_sbp->sbp_parallel_size() == 0) { return false; } return success; } std::string NdSbpToLongString(const NdSbp& nd_sbp) { std::string ret = ""; for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (i > 0) { ret += ","; } // NOTE(chengcheng): Separator ',' ret += SbpToString(nd_sbp.sbp_parallel(i)); } return ret; } bool 
ParseSbpParallelFromString(const std::string& sbp_str, SbpParallel* sbp_parallel) { bool success = false; if (sbp_str.length() >= 1) { if (sbp_str == "B") { sbp_parallel->mutable_broadcast_parallel(); success = true; } else if (sbp_str == "P") { sbp_parallel->mutable_partial_sum_parallel(); success = true; } else if (sbp_str[0] == 'S') { if (sbp_str.length() >= 4 && sbp_str[1] == '(' && sbp_str[sbp_str.length() - 1] == ')') { int split_axis = 0; if (sbp_str.length() == 4) { split_axis = sbp_str[2] - '0'; if (split_axis >= 0 && split_axis <= 9) { success = true; } } else { std::string split_axis_str = sbp_str.substr(2, sbp_str.length() - 3); if (std::all_of(split_axis_str.cbegin(), split_axis_str.cend(), [](char ch) { return std::isdigit(ch); })) { size_t pos = 0; split_axis = std::stoi(split_axis_str, &pos); if (pos == split_axis_str.length()) { success = true; } } } if (success) { sbp_parallel->mutable_split_parallel()->set_axis(split_axis); } } } } return success; } std::string SbpParallelToString(const SbpParallel& sbp_parallel) { return SbpToString(sbp_parallel); } bool ParseNdSbpFromStringList(const std::vector& sbp_str_list, NdSbp* nd_sbp) { for (const auto& sbp_str : sbp_str_list) { if (!ParseSbpParallelFromString(sbp_str, nd_sbp->add_sbp_parallel())) { return false; } } return true; } std::vector NdSbpToStringList(const NdSbp& nd_sbp) { std::vector sbp_str_list(nd_sbp.sbp_parallel_size()); for (size_t i = 0; i < sbp_str_list.size(); ++i) { sbp_str_list[i] = SbpToString(nd_sbp.sbp_parallel(i)); } return sbp_str_list; } void SbpSignatureToNdSbpSignature(const SbpSignature& sbp_signature, NdSbpSignature* nd_sbp_signature) { for (const auto& pair : sbp_signature.bn_in_op2sbp_parallel()) { *((*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[pair.first].add_sbp_parallel()) = pair.second; } } void NdSbpSignatureToSbpSignature(const NdSbpSignature& nd_sbp_signature, SbpSignature* sbp_signature) { for (const auto& pair : nd_sbp_signature.bn_in_op2nd_sbp()) { CHECK_EQ(pair.second.sbp_parallel_size(), 1); (*sbp_signature->mutable_bn_in_op2sbp_parallel())[pair.first] = pair.second.sbp_parallel(0); } } void CheckSbpSignatureAndNdSbpEquals(const SbpSignature& sbp_sig, const NdSbpSignature& nd_sbp_sig) { CHECK_EQ(sbp_sig.bn_in_op2sbp_parallel_size(), nd_sbp_sig.bn_in_op2nd_sbp_size()); for (const auto& pair : nd_sbp_sig.bn_in_op2nd_sbp()) { const auto& bn_in_op2sbp_parallel = sbp_sig.bn_in_op2sbp_parallel(); const auto it = bn_in_op2sbp_parallel.find(pair.first); CHECK(it != bn_in_op2sbp_parallel.end()); CHECK_EQ(pair.second.sbp_parallel_size(), 1); CHECK(pair.second.sbp_parallel(0) == it->second); } } bool NdSbpAllSameSplitParallel(const NdSbp& nd_sbp) { CHECK_GT(nd_sbp.sbp_parallel_size(), 0); const SbpParallel& first_sbp = nd_sbp.sbp_parallel(0); if (!first_sbp.has_split_parallel()) { return false; } FOR_RANGE(int64_t, i, 1, nd_sbp.sbp_parallel_size()) { if (nd_sbp.sbp_parallel(i) != first_sbp) { return false; } } return true; } Maybe NdSbpSignatureToString(const NdSbpSignature& nd_sbp_signature, const std::vector& inputs, const std::vector& outputs) { std::ostringstream ss; auto AppendBnNdSbpString = [&](const std::string& bn) -> Maybe { auto iter = nd_sbp_signature.bn_in_op2nd_sbp().find(bn); if (iter == nd_sbp_signature.bn_in_op2nd_sbp().end()) { return Error::RuntimeError() << "can't find " << bn << " in NdSbpSignature: " << nd_sbp_signature.DebugString(); } ss << " " << NdSbpToString(iter->second); return Maybe::Ok(); }; int bn_index = 0; for (const auto& ibn : inputs) { if (bn_index > 0) { 
ss << ", "; } ss << ibn; JUST(AppendBnNdSbpString(ibn)); bn_index++; } ss << " -> "; bn_index = 0; for (const auto& obn : outputs) { if (bn_index > 0) { ss << ", "; } ss << obn; JUST(AppendBnNdSbpString(obn)); bn_index++; } return ss.str(); } Maybe NdSbpSignatureToString(const NdSbpSignature& nd_sbp_signature, const PbRpf& inputs, const PbRpf& outputs) { return NdSbpSignatureToString(nd_sbp_signature, std::vector{inputs.begin(), inputs.end()}, std::vector{outputs.begin(), outputs.end()}); } Maybe NdSbpSignatureListToString(const std::vector& nd_sbp_sig_list, const std::vector& inputs, const std::vector& outputs) { std::ostringstream ss; if (nd_sbp_sig_list.empty()) { return ss.str(); } auto WalkIO = [&](const std::function(const std::string&)>& bn_handler) -> Maybe { ss << "("; for (size_t i = 0; i < inputs.size(); ++i) { ss << *JUST(bn_handler(inputs[i])); if (i != inputs.size() - 1) { ss << ", "; } } ss << ") -> ("; for (size_t i = 0; i < outputs.size(); ++i) { ss << *JUST(bn_handler(outputs[i])); if (i != outputs.size() - 1) { ss << ", "; } } ss << ")"; return Maybe::Ok(); }; ss << "\n"; JUST(WalkIO([](const std::string& bn) -> Maybe { return bn; })); ss << ": "; ss << "[\n"; for (const auto& nd_sbp_sig : nd_sbp_sig_list) { ss << "\t"; JUST(WalkIO([&](const std::string& bn) -> Maybe { auto it = nd_sbp_sig.bn_in_op2nd_sbp().find(bn); if (it == nd_sbp_sig.bn_in_op2nd_sbp().end()) { return Error::RuntimeError() << "can't find " << bn << " in NdSbpSignature: " << nd_sbp_sig.DebugString(); } return NdSbpToString(it->second); })); ss << ",\n"; } ss << "]"; return ss.str(); } Maybe NdSbpSignatureListToString(const std::vector& nd_sbp_sig_list, const PbRpf& inputs, const PbRpf& outputs) { return NdSbpSignatureListToString(nd_sbp_sig_list, std::vector{inputs.begin(), inputs.end()}, std::vector{outputs.begin(), outputs.end()}); } } // namespace oneflow ================================================ FILE: oneflow/core/job/sbp_parallel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_SBP_PARALLEL_H_ #define ONEFLOW_CORE_JOB_SBP_PARALLEL_H_ #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job/sbp_infer_hint.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/symbol.h" namespace oneflow { bool operator==(const SbpParallel& lhs, const SbpParallel& rhs); inline bool operator!=(const SbpParallel& lhs, const SbpParallel& rhs) { return !(lhs == rhs); } bool operator==(const NdSbp& lhs, const NdSbp& rhs); inline bool operator!=(const NdSbp& lhs, const NdSbp& rhs) { return !(lhs == rhs); } bool operator==(const SbpSignature& lhs, const SbpSignature& rhs); inline bool operator!=(const SbpSignature& lhs, const SbpSignature& rhs) { return !(lhs == rhs); } bool operator==(const NdSbpSignature& lhs, const NdSbpSignature& rhs); inline bool operator!=(const NdSbpSignature& lhs, const NdSbpSignature& rhs) { return !(lhs == rhs); } } // namespace oneflow namespace std { template<> struct hash : public oneflow::SerializedHashPb {}; template<> struct hash : public oneflow::SerializedHashPb {}; } // namespace std namespace oneflow { Maybe> MakeSplitSbpParallel(int axis); Maybe> MakeBroadcastSbpParallel(); Maybe> MakePartialSumSbpParallel(); SbpParallel GetDualSbpParallel(const SbpParallel&); bool IsSbpSignatureContaining(const SbpSignature& bigger, const SbpSignature& smaller); void FilterSbpSignatureList(const SbpSignatureList& sbp_sig_list, const SbpSignature& sbp_sig_conf, SbpSignatureList* filtered_sbp_sig_list); void SortSbpSignatureListByCopyCost( const SbpSignatureList& sbp_sig_list, const PbRpf& ibns, const std::function(const std::string&)>& SbpInferHint4Ibn, const std::function& OrderValue4SbpSig, std::vector* sorted_sbp_signatures); bool IsValidSbpParallelString(const std::string& sbp_str); bool ParseSbpParallelFromString(const std::string& sbp_str, SbpParallel* sbp_parallel); std::string SbpParallelToString(const SbpParallel& sbp_parallel); bool ParseNdSbpFromStringList(const std::vector& sbp_str_list, NdSbp* nd_sbp); std::vector NdSbpToStringList(const NdSbp& nd_sbp); bool ParseNdSbpFromLongString(const std::string& nd_sbp_str, NdSbp* nd_sbp); std::string NdSbpToLongString(const NdSbp& nd_sbp); void SbpSignatureToNdSbpSignature(const SbpSignature& sbp_signature, NdSbpSignature* nd_sbp_signature); void NdSbpSignatureToSbpSignature(const NdSbpSignature& nd_sbp_signature, SbpSignature* sbp_signature); void CheckSbpSignatureAndNdSbpEquals(const SbpSignature& sbp_sig, const NdSbpSignature& nd_sbp_sig); bool NdSbpAllSameSplitParallel(const NdSbp& nd_sbp); // Print functions Maybe NdSbpSignatureToString(const NdSbpSignature& nd_sbp_signature, const std::vector& inputs, const std::vector& outputs); Maybe NdSbpSignatureToString(const NdSbpSignature& nd_sbp_signature, const PbRpf& inputs, const PbRpf& outputs); Maybe NdSbpSignatureListToString(const std::vector& nd_sbp_sig_list, const std::vector& inputs, const std::vector& outputs); Maybe NdSbpSignatureListToString(const std::vector& nd_sbp_sig_list, const PbRpf& inputs, const PbRpf& outputs); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_SBP_PARALLEL_H_ ================================================ FILE: oneflow/core/job/sbp_parallel.proto ================================================ syntax = "proto2"; package oneflow; // Take matmal_op as an example. 
// Y = A * B // (m, n) (m, k) , (k, n) // candidate signature 0: // Y:Split(0), A:Split(0), B:Broadcast // ----------------------------------- // device0: Y0 = A0 * B // (m0, n) (m0, k) , (k, n) // ----------------------------------- // device1: Y1 = A1 * B // (m1, n) (m1, k) , (k, n) // ----------------------------------- // where (m0 + m1 == m) // and (A0 == A[0:m0, :]) and (A1 == A[m0:, :]) // and (Y0 == Y[0:m0, :]) and (Y1 == Y[m0:, :]) // candidate signature 1: // Y:Split(1), A:Broadcast, B:Split(1) // ----------------------------------- // device0: Y0 = A * B0 // (m, n0) (m, k) , (k, n0) // ----------------------------------- // device1: Y1 = A * B1 // (m, n1) (m, k) , (k, n1) // ----------------------------------- // where (n0 + n1 == n) // and (B0 == B[:, 0:n0]) and (B1 == B[:, n0:]) // and (Y0 == Y[:, 0:n0]) and (Y1 == Y[:, n0:]) // candidate signature 2: // Y:PartialSum, A:Split(1), B:Split(0) // ------------------------------------ // device0: Y0 = A0 * B0 // (m, n) (m, k0) , (k0, n) // ------------------------------------ // device1: Y1 = A1 * B1 // (m, n) (m, k1) , (k1, n) // ------------------------------------ // where (k0 + k1 == k) and (Y0 + Y1 == Y) message SplitParallel { required int64 axis = 1; } message BroadcastParallel { } message PartialSumParallel { } message SbpParallel { oneof parallel_type { SplitParallel split_parallel = 1; BroadcastParallel broadcast_parallel = 2; PartialSumParallel partial_sum_parallel = 3; } } message SbpSignature { map bn_in_op2sbp_parallel = 1; } message NdSbp { repeated SbpParallel sbp_parallel = 1; } message NdSbpSignature { map bn_in_op2nd_sbp = 1; } message SbpSignatureList { repeated SbpSignature sbp_signature = 1; } ================================================ FILE: oneflow/core/job/sbp_signature_builder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
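// The comment block at the top of sbp_parallel.proto above enumerates the legal SBP signatures of
// a matmul Y = A * B. The sketch below builds candidate signature 0 (Y:Split(0), A:Split(0),
// B:Broadcast) through the generated protobuf API; the blob names "a", "b" and "out" are
// illustrative assumptions.
#include "oneflow/core/job/sbp_parallel.pb.h"

namespace oneflow {

SbpSignature MakeMatmulCandidateSignature0() {
  SbpSignature sig;
  auto* bn2sbp = sig.mutable_bn_in_op2sbp_parallel();
  (*bn2sbp)["a"].mutable_split_parallel()->set_axis(0);    // A: S(0)
  (*bn2sbp)["b"].mutable_broadcast_parallel();             // B: broadcast
  (*bn2sbp)["out"].mutable_split_parallel()->set_axis(0);  // Y: S(0)
  return sig;
}

}  // namespace oneflow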
*/ #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/sbp_parallel.h" namespace oneflow { void SplitSbpSignatureListBuilder::CheckTemplate() { CHECK_GT(sbp_signature_template_.bn_in_op2sbp_parallel().size(), 0); const auto& first = sbp_signature_template_.bn_in_op2sbp_parallel().begin()->second; CHECK(first.has_split_parallel()); for (const auto& pair : sbp_signature_template_.bn_in_op2sbp_parallel()) { CHECK(first == pair.second); } } SplitSbpSignatureListBuilder&& SplitSbpSignatureListBuilder::SetNumAxes(int64_t num_axes) { num_axes_ = num_axes; return std::move(*this); } void SplitSbpSignatureListBuilder::Build(SbpSignatureList* list) const { CHECK_GE(num_axes_, 0); SbpSignature sbp_sig_template(sbp_signature_template_); FOR_RANGE(int32_t, axis, 0, num_axes_) { for (auto& pair : *sbp_sig_template.mutable_bn_in_op2sbp_parallel()) { pair.second.mutable_split_parallel()->set_axis(axis); } *list->mutable_sbp_signature()->Add() = sbp_sig_template; } } SbpSignatureBuilder&& SbpSignatureBuilder::Split(const std::string& bn_in_op, int64_t axis) { (*sbp_signature_.mutable_bn_in_op2sbp_parallel())[bn_in_op].mutable_split_parallel()->set_axis( axis); return std::move(*this); } SbpSignatureBuilder&& SbpSignatureBuilder::Broadcast(const std::string& bn_in_op) { (*sbp_signature_.mutable_bn_in_op2sbp_parallel())[bn_in_op].mutable_broadcast_parallel(); return std::move(*this); } SbpSignatureBuilder&& SbpSignatureBuilder::PartialSum(const std::string& bn_in_op) { (*sbp_signature_.mutable_bn_in_op2sbp_parallel())[bn_in_op].mutable_partial_sum_parallel(); return std::move(*this); } SbpSignatureBuilder&& SbpSignatureBuilder::Split(const PbRpf& bns, int64_t axis) { for (const auto& bn_in_op : bns) { Split(bn_in_op, axis); } return std::move(*this); } SbpSignatureBuilder&& SbpSignatureBuilder::Broadcast(const PbRpf& bns) { for (const auto& bn_in_op : bns) { Broadcast(bn_in_op); } return std::move(*this); } SbpSignatureBuilder&& SbpSignatureBuilder::PartialSum(const PbRpf& bns) { for (const auto& bn_in_op : bns) { PartialSum(bn_in_op); } return std::move(*this); } SbpSignatureBuilder&& SbpSignatureBuilder::Split(const std::initializer_list& bns, int64_t axis) { for (const auto& bn_in_op : bns) { Split(bn_in_op, axis); } return std::move(*this); } SbpSignatureBuilder&& SbpSignatureBuilder::Broadcast( const std::initializer_list& bns) { for (const auto& bn_in_op : bns) { Broadcast(bn_in_op); } return std::move(*this); } SbpSignatureBuilder&& SbpSignatureBuilder::PartialSum( const std::initializer_list& bns) { for (const auto& bn_in_op : bns) { PartialSum(bn_in_op); } return std::move(*this); } SplitSbpSignatureListBuilder SbpSignatureBuilder::MakeSplitSignatureListBuilder( int64_t num_axes) const { SbpSignature sbp_signature; Build(&sbp_signature); return SplitSbpSignatureListBuilder(sbp_signature).SetNumAxes(num_axes); } } // namespace oneflow ================================================ FILE: oneflow/core/job/sbp_signature_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_SBP_SIGNATURE_BUILDER_H_ #define ONEFLOW_CORE_JOB_SBP_SIGNATURE_BUILDER_H_ #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { class SplitSbpSignatureListBuilder final { public: SplitSbpSignatureListBuilder(const SplitSbpSignatureListBuilder&) = default; explicit SplitSbpSignatureListBuilder(const SbpSignature& sbp_signature_template) : sbp_signature_template_(sbp_signature_template), num_axes_(0) { CheckTemplate(); } ~SplitSbpSignatureListBuilder() = default; SplitSbpSignatureListBuilder&& SetNumAxes(int64_t num_axes); void Build(SbpSignatureList* list) const; private: void CheckTemplate(); SbpSignature sbp_signature_template_; int64_t num_axes_; }; class SbpSignatureBuilder final { public: OF_DISALLOW_COPY_AND_MOVE(SbpSignatureBuilder); SbpSignatureBuilder() = default; ~SbpSignatureBuilder() = default; // split SbpSignatureBuilder&& Split(const std::string& bn_in_op, int64_t axis); SbpSignatureBuilder&& Split(const PbRpf& bns, int64_t axis); SbpSignatureBuilder&& Split(const std::initializer_list& bns, int64_t axis); // broadcast SbpSignatureBuilder&& Broadcast(const std::string& bn_in_op); SbpSignatureBuilder&& Broadcast(const PbRpf& bns); SbpSignatureBuilder&& Broadcast(const std::initializer_list& bns); // partial_sum SbpSignatureBuilder&& PartialSum(const std::string& bn_in_op); SbpSignatureBuilder&& PartialSum(const PbRpf& bns); SbpSignatureBuilder&& PartialSum(const std::initializer_list& bns); SplitSbpSignatureListBuilder MakeSplitSignatureListBuilder(int64_t num_axes) const; void Build(SbpSignature* ret) const { *ret = sbp_signature_; } private: SbpSignature sbp_signature_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_SBP_SIGNATURE_BUILDER_H_ ================================================ FILE: oneflow/core/job/scope.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
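// Equivalent construction using the SbpSignatureBuilder declared above: each call stamps one blob
// name into the signature, and MakeSplitSignatureListBuilder expands a split-everywhere template
// into one signature per axis. Blob names and num_axes are illustrative.
#include "oneflow/core/job/sbp_signature_builder.h"

namespace oneflow {

void BuildExampleSignatures() {
  SbpSignature matmul_sig;
  SbpSignatureBuilder().Split("a", 0).Broadcast("b").Split("out", 0).Build(&matmul_sig);

  // One S(axis)-everywhere signature for each axis in [0, num_axes), e.g. for an elementwise op.
  SbpSignatureList split_sigs;
  SbpSignatureBuilder()
      .Split("in", 0)
      .Split("out", 0)
      .MakeSplitSignatureListBuilder(/*num_axes=*/2)
      .Build(&split_sigs);
}

}  // namespace oneflow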
*/ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/scope.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" namespace oneflow { Scope::Scope(const ScopeProto& scope_proto) : auto_increment_id_(0), symbol_id_(NullOpt), scope_proto_(scope_proto) { CHECK_OK(Init()) << scope_proto_.DebugString(); } Scope::Scope(int64_t symbol_id, const ScopeProto& scope_proto) : auto_increment_id_(0), symbol_id_(symbol_id), scope_proto_(scope_proto) {} Maybe Scope::New(int64_t symbol_id, const ScopeProto& scope_proto) { auto* ptr = new Scope(symbol_id, scope_proto); std::shared_ptr scope(ptr); JUST(scope->Init()); return scope; } Maybe Scope::Init() { { const auto& storage = *Singleton>::Get(); job_desc_ = JUST(storage.MaybeGetPtr(scope_proto_.job_desc_symbol_id())); } { const auto& storage = *Singleton>::Get(); const auto& device_parallel_desc = SymbolOf(*JUST(storage.MaybeGetPtr(scope_proto_.device_parallel_desc_symbol_id()))); const auto& host_parallel_desc = SymbolOf(*JUST(storage.MaybeGetPtr(scope_proto_.host_parallel_desc_symbol_id()))); placement_scope_ = SymbolOf(PlacementScope(device_parallel_desc, host_parallel_desc)); } { const auto& storage = *Singleton>::Get(); if (scope_proto_.has_parent_scope_symbol_id()) { parent_scope_symbol_ = JUST(storage.MaybeGetPtr(scope_proto_.parent_scope_symbol_id())); } } return Maybe::Ok(); } Maybe Scope::job_desc() const { CHECK_NOTNULL_OR_RETURN(job_desc_.get()); return job_desc_.get(); } Maybe Scope::GetParallelDescSymbolId(const OperatorConf& op_conf) const { if (op_conf.device_tag() == "cpu" || IsCpuOnly(op_conf)) { return scope_proto_.host_parallel_desc_symbol_id(); } else { return scope_proto_.device_parallel_desc_symbol_id(); } } Maybe> Scope::GetParallelDesc(const OperatorConf& op_conf) const { return placement_scope_->GetParallelDesc(op_conf.device_tag(), op_conf); } const AttrValue& Scope::GetAttrValue(const std::string& attr_name) const { const auto& iter = scope_proto_.attr_name2attr_value().find(attr_name); if (iter != scope_proto_.attr_name2attr_value().end()) { return iter->second; } const auto& attr_name2attr_def = GlobalScopeConfigDef().attr_name2attr_def(); const auto& def_iter = attr_name2attr_def.find(attr_name); CHECK(def_iter != attr_name2attr_def.end()); return def_iter->second.default_val(); } Maybe Scope::MakeChildScopeProto() const { auto child = std::make_shared(scope_proto_); child->set_parent_scope_symbol_id(JUST(symbol_id())); return child; } Maybe NewScopeSymbolId( int64_t old_scope_symbol_id, const std::function new_scope)>& InitNewScopeProto) { CHECK_OR_RETURN(Singleton>::Get()->Has(old_scope_symbol_id)); // NOLINT const Scope& old_scope = Singleton>::Get()->Get(old_scope_symbol_id); std::shared_ptr new_scope = JUST(old_scope.MakeChildScopeProto()); InitNewScopeProto(new_scope); std::shared_ptr new_scope_symbol; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { new_scope_symbol = JUST(builder->GetScopeSymbol(*new_scope)); return Maybe::Ok(); })); return JUST(new_scope_symbol->symbol_id()); } } // namespace oneflow ================================================ FILE: oneflow/core/job/scope.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_SCOPE_H_ #define ONEFLOW_CORE_JOB_SCOPE_H_ #include "oneflow/core/job/scope.pb.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/placement_scope.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/symbol.h" namespace oneflow { class OperatorConf; class Scope final { public: Scope(const Scope&) = delete; Scope(Scope&&) = delete; explicit Scope(const ScopeProto& scope_proto); ~Scope() = default; static Maybe New(int64_t symbol_id, const ScopeProto& scope_proto); const Optional& symbol_id() const { return symbol_id_; } int64_t auto_increment_id() { return ++auto_increment_id_; } int64_t session_id() const { return scope_proto().session_id(); } const std::shared_ptr& job_desc_symbol() const { return job_desc_; } Symbol placement_scope() const { return placement_scope_; } Symbol device_parallel_desc_symbol() const { return placement_scope_->device_parallel_desc(); } const std::shared_ptr& parent_scope_symbol() const { return parent_scope_symbol_; } Maybe MakeChildScopeProto() const; Maybe job_desc() const; Maybe GetParallelDescSymbolId(const OperatorConf& op_conf) const; Maybe> GetParallelDesc(const OperatorConf& op_conf) const; const OptLocalParallel& opt_local_parallel_conf() const { return scope_proto_.opt_local_parallel_conf(); } const ScopeProto& scope_proto() const { return scope_proto_; } const ScopeProto& data() const { return scope_proto_; } #define DEFINE_SCOPE_CONFIG_GETTER(T, func_name, field_name) \ T func_name(const std::string& field_name) const { \ const AttrValue& attr_val = GetAttrValue(field_name); \ CHECK(attr_val.has_##field_name()); \ return attr_val.field_name(); \ } DEFINE_SCOPE_CONFIG_GETTER(bool, Bool, at_bool); DEFINE_SCOPE_CONFIG_GETTER(int64_t, Int64, at_int64); DEFINE_SCOPE_CONFIG_GETTER(double, Double, at_double); DEFINE_SCOPE_CONFIG_GETTER(const std::string&, String, at_string); private: Scope(int64_t symbol_id, const ScopeProto& scope_proto); Maybe Init(); const AttrValue& GetAttrValue(const std::string& attr_name) const; int64_t auto_increment_id_; Optional symbol_id_; const ScopeProto scope_proto_; std::shared_ptr job_desc_; Symbol placement_scope_; std::shared_ptr parent_scope_symbol_; }; Maybe NewScopeSymbolId( int64_t old_scope_symbol_id, const std::function new_scope)>& InitNewScopeProto); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_SCOPE_H_ ================================================ FILE: oneflow/core/job/scope.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/job/local_parallel.proto"; import "oneflow/core/framework/user_op_attr.proto"; import "oneflow/core/job/module_conf.proto"; message ScopeProto { required int64 job_desc_symbol_id = 20; required int64 device_parallel_desc_symbol_id = 30; required int64 host_parallel_desc_symbol_id = 40; optional bool enable_cpu_alternative_op = 41 [default = true]; required OptLocalParallel opt_local_parallel_conf = 50; repeated string scope_op_name_prefixes = 60; optional int64 parent_scope_symbol_id = 70; 
required int64 session_id = 80; map attr_name2attr_value = 90; optional string calculation_pass_name = 100 [default = "forward_pass"]; optional string module_name = 110; } ================================================ FILE: oneflow/core/job/session.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/job/session.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/util.h" namespace oneflow { int64_t NewSessionId() { static std::atomic counter(0); return counter++; } ConfigProtoContext::ConfigProtoContext(const ConfigProto& config_proto) : session_id_(config_proto.session_id()) {} ConfigProtoContext::~ConfigProtoContext() {} LogicalConfigProtoContext::LogicalConfigProtoContext(const std::string& config_proto_str) { ConfigProto config_proto; CHECK(TxtString2PbMessage(config_proto_str, &config_proto)); // TODO(hanbinbin): init for worker machines config_proto_ctx_.reset(new ConfigProtoContext(config_proto)); } LogicalConfigProtoContext::~LogicalConfigProtoContext() { config_proto_ctx_.reset(); // TODO(hanbinbin): destroy ConfigProtoContext of worker machines } } // namespace oneflow ================================================ FILE: oneflow/core/job/session.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_SESSION_H_ #define ONEFLOW_CORE_JOB_SESSION_H_ #include #include namespace oneflow { int64_t NewSessionId(); class ConfigProto; class ConfigProtoContext { public: ConfigProtoContext(const ConfigProto& config_proto); ~ConfigProtoContext(); int64_t session_id() const { return session_id_; } private: int64_t session_id_; }; class LogicalConfigProtoContext { public: LogicalConfigProtoContext(const std::string& config_proto_str); ~LogicalConfigProtoContext(); std::unique_ptr config_proto_ctx_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_SESSION_H_ ================================================ FILE: oneflow/core/job/ssp_config_def.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/config_def.h" namespace oneflow { namespace { REGISTER_FUNCTION_CONFIG_DEF() .Bool("enable_ssp", false, "enable ssp") .String("ssp_partition_strategy", "naive_sequential", "ssp partition strategy, Avaiable strategies: naive_sequential | disable") .ListInt64("ssp_partition_scope_ids", {}, "type: list[int64]. ssp partition scope symbol ids"); REGISTER_SCOPE_CONFIG_DEF() .Int64("ssp_num_stages", -1, "total number of ssp stages") .Int64("ssp_stage_id", -1, "current ssp stage id "); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job/sub_plan.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/job/task.proto"; message ThrdIds { repeated int64 thrd_id = 1; } message ClusterThrdIds { map machine_id2thrd_ids = 1; } message SubPlan { repeated TaskProto task = 1; } ================================================ FILE: oneflow/core/job/task.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/graph/exec_sequence.proto"; import "oneflow/core/register/register_desc.proto"; import "oneflow/core/job/placement.proto"; enum TaskType { kInvalid = 0; kNormalForward = 1; kCopyHd = 12; kCopyCommNet = 13; kDeviceTick = 27; kPack = 30; kUnpack = 31; kRepeat = 32; kAcc = 33; kAccCtrlTick = 34; kSrcSubsetTick = 38; kDstSubsetTick = 39; kSourceTick = 40; kTick = 41; kAccTick = 42; kCase = 43; kEsac = 44; kWaitAndSendIds = 45; kReentrantLock = 46; kCallbackNotify = 47; kDistributeConcat = 55; kDistributeSplit = 56; kSliceBoxing = 57; kCollectiveBoxingGeneric = 58; kBoxingIdentity = 59; kDecodeH2D = 60; kCollectiveBoxingPack = 61; kCollectiveBoxingUnpack = 62; kSspVariableProxy = 63; kBoxingZeros = 64; kCriticalSectionWaitTick = 65; kNcclSendRecvBoxing = 66; }; message RegstDescIdSet { repeated int64 regst_desc_id = 1; } message TaskProto { // common required TaskType task_type = 1; required int64 machine_id = 2; required int64 thrd_id = 3; required int64 task_id = 4; required int64 job_id = 5; required ExecSequence exec_sequence = 7; map produced_regst_desc = 8; map consumed_regst_desc_id = 9; optional bool all_register_num_eq_one_hint = 10 [default = false]; required int64 chain_id = 20; required int64 order_in_chain = 21; // compute task optional ParallelContext parallel_ctx = 1000; // CompTask }; ================================================ FILE: oneflow/core/job/utils/progress_bar.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
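// How the scope-level definitions registered in ssp_config_def.cpp above are consumed: the getters
// generated by DEFINE_SCOPE_CONFIG_GETTER in scope.h look an attribute up in the ScopeProto and
// fall back to the registered default (-1 for both SSP attributes) when it was never set on the
// scope. Sketch only; the Scope instance is assumed to come from the symbol storage.
#include "oneflow/core/job/scope.h"

namespace oneflow {

bool IsSspConfiguredForScope(const Scope& scope) {
  const int64_t num_stages = scope.Int64("ssp_num_stages");  // registered default: -1
  const int64_t stage_id = scope.Int64("ssp_stage_id");      // registered default: -1
  return num_stages > 0 && stage_id >= 0;
}

}  // namespace oneflow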
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/job/utils/progress_bar.h" #include "oneflow/core/job/graph_scope_vars.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { Maybe LogProgress(const std::string& task_name, bool is_end) { const bool log_progress = GetGraphDebugMode() || ThreadLocalEnvBool(); if (!log_progress || OF_PREDICT_FALSE(GlobalProcessCtx::Rank() != 0)) { return Maybe::Ok(); } const static thread_local uint64_t progress_total_num = 60; static thread_local uint64_t progress_cnt = 1; static constexpr char clear_line[] = " \r"; auto const& limited_str = task_name.size() > 60 ? task_name.substr(0, 60) : task_name; std::cout << clear_line << "[" << progress_cnt << "/" << progress_total_num << "]" << limited_str << "\r" << std::flush; if (is_end) { progress_cnt = 0; std::cout << clear_line << std::endl << std::flush; } ++progress_cnt; return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/utils/progress_bar.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_UTILS_PROGRESS_BAR_H_ #define ONEFLOW_CORE_JOB_UTILS_PROGRESS_BAR_H_ #include #include "oneflow/core/common/util.h" #include "oneflow/core/common/env_var/env_var.h" namespace oneflow { DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_NNGRAPH_ENABLE_PROGRESS_BAR, false); Maybe LogProgress(const std::string& task_name = "", bool is_end = false); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_UTILS_PROGRESS_BAR_H_ ================================================ FILE: oneflow/core/job/version.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/version.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { void DumpVersionInfo() { LOG(INFO) << "OneFlow git version: " << GetOneFlowGitVersion(); ep::DeviceManagerRegistry::DumpVersionInfo(); } } // namespace oneflow ================================================ FILE: oneflow/core/job/version.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_VERSION_H_ #define ONEFLOW_CORE_JOB_VERSION_H_ #include "oneflow/core/common/util.h" namespace oneflow { const char* GetOneFlowGitVersion(); void DumpVersionInfo(); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_VERSION_H_ ================================================ FILE: oneflow/core/job_rewriter/adadelta_optim.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/initializer_conf.pb.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/variable_op.h" namespace oneflow { namespace { std::string GenVariableOutputLbn(const OperatorConf& op_conf) { CHECK(op_conf.has_variable_conf()); return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out()); } OperatorConf GenerateAdadeltaHelperVariableConf(const VariableOp& op, const std::string& name) { OperatorConf helper_variable_op(op.op_conf()); helper_variable_op.set_name(op.op_name() + "-" + name); helper_variable_op.mutable_variable_conf()->set_out("out"); InitializerConf constant_initializer; constant_initializer.mutable_constant_conf()->set_value(0.0f); *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer; helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id()); return helper_variable_op; } void GenerateAdadeltaOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); user_op::UserOpConfWrapperBuilder adadelta_update_op_builder(var_op->op_name() + "_optimizer"); float rho = 0.0; float epsilon = 0.0; bool maximize = false; const AdadeltaModelUpdateConf& adadelta_conf = optimizer_conf.adadelta_conf(); rho = adadelta_conf.rho(); epsilon = adadelta_conf.epsilon(); maximize = adadelta_conf.maximize(); const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn(); OperatorConf square_avgs_var(GenerateAdadeltaHelperVariableConf(*var_op, "square_avgs")); OperatorConf acc_deltas_var(GenerateAdadeltaHelperVariableConf(*var_op, "acc_deltas")); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {square_avgs_var, acc_deltas_var}); 
adadelta_update_op_builder.OpTypeName("adadelta_update") .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) .Input("model_diff", model_diff_lbn) .Input("learning_rate", learning_rate_lbn) .Input("square_avgs", GenVariableOutputLbn(square_avgs_var)) .Input("acc_deltas", GenVariableOutputLbn(acc_deltas_var)) .Attr("rho", rho) .Attr("epsilon", epsilon) .Attr("maximize", maximize) .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); if (optimizer_conf.has_lr_scale()) { adadelta_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale()); } SetDynamicLossScaleSkipIf(ctx, &adadelta_update_op_builder); const auto adadelta_update_op = adadelta_update_op_builder.Build(); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {adadelta_update_op.op_conf()}); } } // namespace REGISTER_OPTIMIZER(OptimizerConf::kAdadeltaConf, &GenerateAdadeltaOptimizerOpConf); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/adagrad_optm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/initializer_conf.pb.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/variable_op.h" namespace oneflow { namespace { std::string GenVariableOutputLbn(const OperatorConf& op_conf) { CHECK(op_conf.has_variable_conf()); return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out()); } OperatorConf GenerateAdagradHelperVariableConf(const VariableOp& op, const std::string& name, const float initial_value) { OperatorConf helper_variable_op(op.op_conf()); helper_variable_op.set_name(op.op_name() + "-" + name); helper_variable_op.mutable_variable_conf()->set_out("out"); InitializerConf constant_initializer; constant_initializer.mutable_constant_conf()->set_value(initial_value); *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer; helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id()); return helper_variable_op; } void GenerateAdagradOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); user_op::UserOpConfWrapperBuilder adagrad_update_op_builder(var_op->op_name() + "_optimizer"); float lr_decay = 0.0; float initial_accumulator_value = 0.0; float epsilon = 0.0; const AdagradModelUpdateConf& adagrad_conf = optimizer_conf.adagrad_conf(); lr_decay = adagrad_conf.lr_decay(); 
  initial_accumulator_value = adagrad_conf.initial_accumulator_value();
  epsilon = adagrad_conf.epsilon();
  const std::string& train_step_lbn = job_builder->job().job_conf().train_conf().train_step_lbn();
  const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn();
  OperatorConf sum_var(
      GenerateAdagradHelperVariableConf(*var_op, "sum", initial_accumulator_value));
  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {sum_var});
  adagrad_update_op_builder.OpTypeName("adagrad_update")
      .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out")))
      .Input("model_diff", model_diff_lbn)
      .Input("learning_rate", learning_rate_lbn)
      .Input("train_step", train_step_lbn)
      .Input("sum", GenVariableOutputLbn(sum_var))
      .Attr("epsilon", epsilon)
      .Attr("lr_decay", lr_decay)
      .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());
  if (optimizer_conf.has_lr_scale()) {
    adagrad_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale());
  }
  SetDynamicLossScaleSkipIf(ctx, &adagrad_update_op_builder);
  const auto adagrad_update_op = adagrad_update_op_builder.Build();
  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {adagrad_update_op.op_conf()});
}

}  // namespace

REGISTER_OPTIMIZER(OptimizerConf::kAdagradConf, &GenerateAdagradOptimizerOpConf);

}  // namespace oneflow


================================================
FILE: oneflow/core/job_rewriter/adam_optm.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" namespace oneflow { struct BiasCorrectionFactorCacheKey { float beta = 1.0; ParallelConf parallel_conf; }; bool operator==(const BiasCorrectionFactorCacheKey& lhs, const BiasCorrectionFactorCacheKey& rhs) { return (lhs.beta == rhs.beta) && (lhs.parallel_conf == rhs.parallel_conf); } } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::BiasCorrectionFactorCacheKey& key) const { using namespace oneflow; return Hash(key.beta, key.parallel_conf); } }; } // namespace std namespace oneflow { class BiasCorrectionFactorState final : public JobPassState { public: BiasCorrectionFactorState() {} ~BiasCorrectionFactorState() override = default; std::string GetLbn(float beta, std::string bias_correction_name, ParallelConf parallel_conf, const std::function& BiasCorrectionFactorStateOp) { BiasCorrectionFactorCacheKey cache_key; cache_key.beta = beta; cache_key.parallel_conf = parallel_conf; const auto& iter = key2lbn_.find(cache_key); if (iter != key2lbn_.end()) { return iter->second; } else { std::string lbn = BiasCorrectionFactorStateOp(beta, std::move(bias_correction_name)); key2lbn_.emplace(cache_key, lbn); return lbn; } } private: HashMap key2lbn_; }; namespace { std::string GenVariableOutputLbn(const OperatorConf& op_conf) { CHECK(op_conf.has_variable_conf()); return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out()); } OperatorConf GenerateAdamHelperVariableOpConf(const VariableOp& op, const std::string& name, const float initial_value) { OperatorConf helper_variable_op(op.op_conf()); helper_variable_op.set_name(op.op_name() + "-" + name); helper_variable_op.mutable_variable_conf()->set_out("out"); InitializerConf constant_initializer; constant_initializer.mutable_constant_conf()->set_value(initial_value); *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer; helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id()); return helper_variable_op; } void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); user_op::UserOpConfWrapperBuilder adam_update_op_builder(var_op->op_name() + "_optimizer"); float beta1 = 0.9; float beta2 = 0.999; float epsilon = 1e-8; bool do_bias_correction = true; bool amsgrad = false; if (optimizer_conf.has_adam_conf()) { const AdamModelUpdateConf& adam_conf = optimizer_conf.adam_conf(); beta1 = adam_conf.beta1(); beta2 = adam_conf.beta2(); epsilon = adam_conf.epsilon(); do_bias_correction = adam_conf.do_bias_correction(); amsgrad = adam_conf.amsgrad(); } else if (optimizer_conf.has_lazy_adam_conf()) { const LazyAdamModelUpdateConf& lazy_adam_conf = optimizer_conf.lazy_adam_conf(); beta1 = lazy_adam_conf.beta1(); beta2 = lazy_adam_conf.beta2(); epsilon = lazy_adam_conf.epsilon(); do_bias_correction = lazy_adam_conf.do_bias_correction(); amsgrad = lazy_adam_conf.amsgrad(); } else { UNIMPLEMENTED(); } OperatorConf m_var(GenerateAdamHelperVariableOpConf(*var_op, "m", 0.f)); OperatorConf v_var(GenerateAdamHelperVariableOpConf(*var_op, "v", 0.f)); OperatorConf max_v_var{}; if (amsgrad) { max_v_var = GenerateAdamHelperVariableOpConf(*var_op, "max_v", 0.f); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {m_var, v_var, max_v_var}); } else { 
job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {m_var, v_var}); } const std::string& train_step_lbn = job_builder->job().job_conf().train_conf().train_step_lbn(); const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn(); adam_update_op_builder.OpTypeName("adam_update") .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) .Input("model_diff", model_diff_lbn) .Input("learning_rate", learning_rate_lbn) .Input("m", GenVariableOutputLbn(m_var)) .Input("v", GenVariableOutputLbn(v_var)) .Attr("beta1", beta1) .Attr("beta2", beta2) .Attr("epsilon", epsilon) .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) .Attr("amsgrad", amsgrad) .Attr("do_bias_correction", do_bias_correction) .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); if (do_bias_correction) { const std::string& job_pass_state_key = "adam_bias_correction_factor"; const bool has_state = CHECK_JUST(ctx->HasState(job_pass_state_key)); if (!has_state) { CHECK_JUST( ctx->ResetState(job_pass_state_key, std::make_unique())); } auto* state = CHECK_JUST(ctx->MutableState(job_pass_state_key)); ParallelConf bias_correction_parallel_conf; const auto& lr_parallel_conf = CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(learning_rate_lbn))); const auto& train_step_parallel_conf = CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn))); if (lr_parallel_conf == train_step_parallel_conf) { bias_correction_parallel_conf = lr_parallel_conf; } else { bias_correction_parallel_conf = var_op_node.parallel_desc().parallel_conf(); } auto AddAdamBiasCorrectionFactorOp = [&](float beta_val, const std::string& op_name) -> std::string { user_op::UserOpConfWrapperBuilder op_builder(var_op->op_name() + op_name); const auto adam_bias_correction_factor_op = op_builder.OpTypeName("adam_bias_correction_factor") .Input("train_step", train_step_lbn) .Attr("beta", beta_val) .Output("out") .ScopeSymbolId(var_op->op_conf().scope_symbol_id()) .Build(); job_builder->AddOps(bias_correction_parallel_conf, {adam_bias_correction_factor_op.op_conf()}); return adam_bias_correction_factor_op.output("out", 0); }; const std::string bias_correction1_lbn = state->GetLbn(beta1, "adam_bias_correction_factor1", bias_correction_parallel_conf, AddAdamBiasCorrectionFactorOp); const std::string bias_correction2_lbn = state->GetLbn(beta2, "adam_bias_correction_factor2", bias_correction_parallel_conf, AddAdamBiasCorrectionFactorOp); adam_update_op_builder.Input("bias_correction1", bias_correction1_lbn) .Input("bias_correction2", bias_correction2_lbn); } if (amsgrad) { adam_update_op_builder.Input("max_v", GenVariableOutputLbn(max_v_var)); } if (optimizer_conf.has_lr_scale()) { adam_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale()); } SetDynamicLossScaleSkipIf(ctx, &adam_update_op_builder); const auto adam_update_op = adam_update_op_builder.Build(); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {adam_update_op.op_conf()}); } } // namespace REGISTER_OPTIMIZER(OptimizerConf::kAdamConf, &GenerateOptimizerOpConf); REGISTER_OPTIMIZER(OptimizerConf::kLazyAdamConf, &GenerateOptimizerOpConf); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/add_ssp_variable_proxy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class AddSspVariableProxyPass final : public JobPass { public: AddSspVariableProxyPass(const AddSspVariableProxyPass&) = delete; AddSspVariableProxyPass(AddSspVariableProxyPass&&) = delete; AddSspVariableProxyPass() = default; ~AddSspVariableProxyPass() = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain() && ctx.job_desc().Bool("enable_ssp"); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { HashMap> var2ref_value_pair; HashSet var_consumers; HashSet trainable_variable_op_names; const Job& job = job_builder->job(); for (const auto& optimizer_conf : job.job_conf().train_conf().optimizer_conf()) { for (const auto& variable_op_name : optimizer_conf.variable_op_names()) { trainable_variable_op_names.insert(variable_op_name); } } auto IsTrainableVarOp = [&](const OperatorConf& op_conf) { if (!op_conf.has_variable_conf()) { return false; } return trainable_variable_op_names.count(op_conf.name()) > 0; }; JUST(ForEachTrainableVarOpNode(op_graph, IsTrainableVarOp, [&](OpNode* op_node) -> Maybe { op_node->ForEachNodeOnOutEdge([&](OpNode* consumer) { var_consumers.insert(consumer); }); const auto& old_var_out_lbi = op_node->op().BnInOp2Lbi("out"); return AddSspVarProxyOp(op_node, job_builder, &var2ref_value_pair[old_var_out_lbi].first, &var2ref_value_pair[old_var_out_lbi].second); })); { const auto& NeedReplace = [&](const LogicalBlobId& var_lbi) -> bool { return var2ref_value_pair.count(var_lbi) > 0; }; const auto& Ref4Var = [&](const LogicalBlobId& var_lbi) -> const std::string& { return var2ref_value_pair.at(var_lbi).first; }; const auto& Val4Var = [&](const LogicalBlobId& var_lbi) -> const std::string& { return var2ref_value_pair.at(var_lbi).second; }; for (OpNode* op_node : var_consumers) { JUST(ReplaceVarWithSspVarProxyOp(op_node, job_builder, NeedReplace, Ref4Var, Val4Var)); } } return Maybe::Ok(); } Maybe ForEachTrainableVarOpNode( const OpGraph& op_graph, const std::function& IsTrainableVarOp, const std::function(OpNode*)>& DoEach) const { const auto& IsSspVarProxy = [](const OperatorConf& op_conf) { return op_conf.has_user_conf() && op_conf.user_conf().op_type_name() == "ssp_variable_proxy"; }; JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { const auto& op_conf = op_node->op().op_conf(); CHECK_OR_RETURN(!IsSspVarProxy(op_conf)) << "AddSspVariableProxy can not be applied twice"; if (IsTrainableVarOp(op_conf)) { return DoEach(op_node); } return Maybe::Ok(); })); return Maybe::Ok(); } Maybe AddSspVarProxyOp(OpNode* op_node, JobBuilder* job_builder, std::string* ref_lbn, std::string* value_lbn) const { const LogicalBlobId& old_var_out_lbi = 
op_node->op().BnInOp2Lbi("out"); int64_t scope_symbol_id = op_node->op().op_conf().scope_symbol_id(); JUST(AddSspVarProxyOp(old_var_out_lbi, scope_symbol_id, job_builder, ref_lbn, value_lbn)); return Maybe::Ok(); } Maybe ReplaceVarWithSspVarProxyOp( OpNode* op_node, JobBuilder* job_builder, const std::function& NeedReplace, const std::function& Ref4Var, const std::function& Val4Var) const { const auto& op = op_node->op(); std::unique_ptr> new_op_confs; for (const auto& ibn : op.input_bns()) { const auto& lbi = op.BnInOp2Lbi(ibn); if (!NeedReplace(lbi)) { continue; } if (!new_op_confs) { new_op_confs.reset(new std::vector({op.op_conf()})); } auto* new_op_conf = &new_op_confs->at(0); int64_t scope_symbol_id = op.op_conf().scope_symbol_id(); bool in_optimizer_pass = JUST(IsInOptimizerPass(scope_symbol_id)); const auto* lbn = (in_optimizer_pass ? &Ref4Var(lbi) : &Val4Var(lbi)); ReplaceInputLbnInOpCustomizedConf(new_op_conf, ibn, *lbn); } if (new_op_confs) { job_builder->MutOpsOnlyOnce(*new_op_confs); } return Maybe::Ok(); } Maybe IsInOptimizerPass(int64_t scope_symbol_id) const { const auto& scope = JUST(Singleton>::Get()->MaybeGet(scope_symbol_id)); return scope.scope_proto().calculation_pass_name() == kOptimizerPass; } Maybe AddSspVarProxyOp(const LogicalBlobId& old_var_out_lbi, int64_t scope_symbol_id, JobBuilder* job_builder, std::string* ref_lbn, std::string* value_lbn) const { const Scope& scope = JUST(Singleton>::Get()->MaybeGet(scope_symbol_id)); int64_t buffer_size = 0; { int64_t num_stages = scope.Int64("ssp_num_stages"); int64_t stage_id = scope.Int64("ssp_stage_id"); CHECK_GT(num_stages, 0); CHECK_GE(stage_id, 0); CHECK_LT(stage_id, num_stages); buffer_size = num_stages - stage_id; } std::string op_name = old_var_out_lbi.op_name() + "_ssp_variable_proxy"; const auto proxy_op = user_op::UserOpConfWrapperBuilder(op_name) .Op("ssp_variable_proxy") .ScopeSymbolId(scope_symbol_id) .Input("var", GenLogicalBlobName(old_var_out_lbi)) .Output("ref") .Output("value") .Attr("buffer_size", buffer_size) .Build(); const auto& parallel_desc = *JUST(scope.GetParallelDesc(proxy_op.op_conf())); job_builder->AddOps(parallel_desc.parallel_conf(), {proxy_op.op_conf()}); *ref_lbn = op_name + "/ref_0"; *value_lbn = op_name + "/value_0"; return Maybe::Ok(); } }; REGISTER_JOB_PASS("AddSspVariableProxy", AddSspVariableProxyPass); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/auto_learning_rate.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class AutoLearningRate final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(AutoLearningRate); AutoLearningRate() = default; ~AutoLearningRate() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); } Maybe Apply(const OpGraph& op_graph, Job* job) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); return Apply(op_graph, job); } }; Maybe AutoLearningRate::Apply(const OpGraph& op_graph, Job* job) const { JobBuilder job_builder(job); const TrainConf& train_conf = job->job_conf().train_conf(); auto AddScheduleOp = [&](const OptimizerConf& optimizer_conf, const std::string& op_name) -> std::string { const class oneflow::OpNode* op_node = op_graph.OpNode4OpName(GenLogicalBlobId(train_conf.train_step_lbn()).op_name()); CHECK_OR_RETURN(op_node != nullptr) << "op node not found in op graph, op name: " << op_name; const ParallelConf& parallel_conf = op_node->parallel_desc().parallel_conf(); OperatorConf schedule_op_conf{}; schedule_op_conf.set_name(op_name); auto* schedule_conf = schedule_op_conf.mutable_learning_rate_schedule_conf(); schedule_conf->set_train_step(train_conf.train_step_lbn()); schedule_conf->set_learning_rate(optimizer_conf.base_learning_rate()); schedule_conf->set_out("out"); if (optimizer_conf.has_learning_rate_decay()) { *schedule_conf->mutable_learning_rate_decay() = optimizer_conf.learning_rate_decay(); } schedule_op_conf.set_scope_symbol_id(op_node->op().op_conf().scope_symbol_id()); job_builder.AddOps(parallel_conf, {schedule_op_conf}); return GenLogicalBlobName(op_name, schedule_conf->out()); }; FOR_RANGE(int64_t, i, 0, train_conf.optimizer_conf_size()) { const auto& optimizer_conf = train_conf.optimizer_conf(i); const std::string& lbn = AddScheduleOp(optimizer_conf, "System-Train-LearningRate-Scheduler_" + NewUniqueId()); job->mutable_job_conf()->mutable_train_conf()->mutable_optimizer_conf(i)->set_learning_rate_lbn( lbn); } return Maybe::Ok(); } REGISTER_JOB_PASS("AutoLearningRate", AutoLearningRate); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/auto_mixed_precision.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/auto_mixed_precision.h" #include "oneflow/core/job_rewriter/auto_mixed_precision_lists.h" #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/pass_util.h" #include "oneflow/core/job/job_desc.h" namespace oneflow { namespace { void VerifyAMPList(const AMPList& amp_list) { for (const auto& op_type : amp_list) { CHECK(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr) << "Cannot find " << op_type << " of AutoMixedPrecision list in OpRegistry."; } } using NoCastRegistry = std::multimap; NoCastRegistry* GetNoCastRegistry() { static NoCastRegistry s_registry; return &s_registry; } bool FindInNoCastRegisry(const std::string& op_type, const OpArg& op_arg) { auto range = GetNoCastRegistry()->equal_range(op_type); for (auto it = range.first; it != range.second; ++it) { if (it->second == op_arg) { return true; } } return false; } std::function MakePredicatorIsAllowedToRunWithHalf(const OpGraph& op_graph) { auto allowed_set = std::make_shared>(); op_graph.ForEachNode([&](OpNode* node) { // half computation is not supported on cpu if (node->parallel_desc().device_type() == DeviceType::kCPU) { return; } if (node->op().output_bns().size() > 0 || IsUserOpWithTypeName(node->op().op_conf(), "one_embedding_fused_lookup_grad")) { INSERT_CHECK(allowed_set->insert(node)); } }); return [allowed_set](OpNode* node) -> bool { return IsKeyFound(*allowed_set, node); }; } void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet& white_set, const DataType mixed_precision_data_type, JobBuilder* job_builder) { HashSet white_set_edges; { std::function&(OpNode*)> Node2Edges = f2h ? &OpNode::in_edges : &OpNode::out_edges; std::function OppositeNode = f2h ? &OpEdge::src_node : &OpEdge::dst_node; op_graph.ForEachNode([&](OpNode* node) { if (IsKeyFound(white_set, node)) { for (OpEdge* edge : Node2Edges(node)) { if (!IsKeyFound(white_set, OppositeNode(edge))) { INSERT_CHECK(white_set_edges.insert(edge)); } } } }); auto EdgeName4Edge = [](OpEdge* const& edge) { return std::string("edge of\t") + edge->src_node()->op().op_name() + "\tto\t" + edge->dst_node()->op().op_name(); }; VLOG(3) << "white_set_edges for f2h value: " << f2h << " is " << Container2Str, OpEdge*>(white_set_edges, EdgeName4Edge); } HashMap> edges_group_by_lbn; { for (OpEdge* edge : white_set_edges) { for (const auto& lbi : edge->lbis()) { std::string lbn = GenLogicalBlobName(lbi); edges_group_by_lbn[lbn].emplace_back(edge); } } } HashMap dst_op_name2dst_op_confs; for (auto& pair : edges_group_by_lbn) { const std::string& lbn = pair.first; LogicalBlobId cur_lbi = GenLogicalBlobId(lbn); OpNode* src_node = pair.second.front()->src_node(); const BlobDesc& blob_desc = src_node->LogicalBlobDesc4Lbi(cur_lbi); if (blob_desc.data_type() != DataType::kFloat) { continue; } std::string cast_suffix = f2h ? "-cast_f2h" : "-cast_h2f"; DataType cast_data_type = f2h ? 
mixed_precision_data_type : DataType::kFloat; auto cast_op = user_op::UserOpConfWrapperBuilder(ReplaceSlashToDash4Lbn(lbn) + cast_suffix) .Op("cast") .Input("in", lbn) .Output("out") .Attr("dtype", cast_data_type) .ScopeSymbolId(src_node->op().op_conf().scope_symbol_id()) .Build(); bool cast_is_consumed = false; for (OpEdge* edge : pair.second) { CHECK(src_node == edge->src_node()); OpNode* dst_node = edge->dst_node(); const auto& dst_ibns = edge->lbi2ibns().at(cur_lbi); for (const auto& dst_ibn : dst_ibns) { if (dst_node->op().op_conf().has_user_conf()) { const std::string& op_type = dst_node->op().op_conf().user_conf().op_type_name(); const auto& op_arg = GenUnRepeatedBn(dst_ibn); if (FindInNoCastRegisry(op_type, op_arg)) { continue; } } cast_is_consumed = true; const std::string& dst_op_name = dst_node->op().op_name(); if (!IsKeyFound(dst_op_name2dst_op_confs, dst_op_name)) { INSERT_CHECK(dst_op_name2dst_op_confs.insert( std::make_pair(dst_op_name, dst_node->op().op_conf()))); } OperatorConf& dst_op_conf = dst_op_name2dst_op_confs.at(dst_op_name); std::string new_lbn = cast_op.op_name() + "/out_0"; CHECK_EQ(lbn, ReplaceInputLbnInOpCustomizedConf(&dst_op_conf, dst_ibn, new_lbn)); } } if (cast_is_consumed) { job_builder->AddOps(src_node->parallel_desc().parallel_conf(), std::vector{cast_op.op_conf()}); VLOG(3) << "Insert CastOp: " << cast_op.op_name() << " between " << lbn; } } std::vector dst_op_confs; dst_op_confs.reserve(dst_op_name2dst_op_confs.size()); for (const auto& pair : dst_op_name2dst_op_confs) { dst_op_confs.emplace_back(pair.second); } // make sure an op_conf can only be udpated once, cuz later update will override before job_builder->MutOpsOnlyOnce(dst_op_confs); } class AutoMixedPrecision final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision); AutoMixedPrecision() : white_list_(AutoMixedPrecisionLists::WhiteList()), black_list_(AutoMixedPrecisionLists::BlackList()), gray_list_(AutoMixedPrecisionLists::GrayList()), clear_list_(AutoMixedPrecisionLists::ClearList()) {} ~AutoMixedPrecision() = default; bool IsEnabled(const JobPassCtx& ctx) const { #if defined(WITH_CUDA) && defined(CUDA_VERSION) && CUDA_VERSION < 10000 return false; #else return ctx.job_desc().enable_auto_mixed_precision(); #endif } Maybe Apply(Job* job, JobPassCtx* ctx) const override; private: void FillBlackSet(const OpGraph& op_graph, HashSet* black_set) const; void FillWhiteSet(const OpGraph& op_graph, std::function IsAllowedToRunWithHalf, const HashSet& black_set, HashSet* white_set) const; void PropagateWhiteThroughClearNodes(const OpGraph& op_graph, std::function IsAllowedToRunWithHalf, const HashSet& black_set, HashSet* white_set) const; void InsertCastOp(const OpGraph& op_graph, const HashSet& white_set, const DataType mixed_precision_data_type, JobBuilder* job_builder) const; const AMPList& white_list_; const AMPList& black_list_; const AMPList& gray_list_; const AMPList& clear_list_; }; Maybe AutoMixedPrecision::Apply(Job* job, JobPassCtx* ctx) const { if (!ctx->job_desc().enable_auto_mixed_precision()) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); CHECK(GlobalJobDesc().DefaultDataType() == DataType::kFloat); VerifyAMPList(white_list_); VerifyAMPList(black_list_); VerifyAMPList(gray_list_); VerifyAMPList(clear_list_); std::function OpName4Node = [](OpNode* const& node) { return node->op().op_name(); }; HashSet black_set; HashSet white_set; FillBlackSet(op_graph, &black_set); VLOG(3) << "BlackSet include: " << Container2Str, 
OpNode*>(black_set, OpName4Node); auto IsAllowedToRunWithHalf = MakePredicatorIsAllowedToRunWithHalf(op_graph); FillWhiteSet(op_graph, IsAllowedToRunWithHalf, black_set, &white_set); VLOG(3) << "WhiteSet Before Propagate include: " << Container2Str, OpNode*>(white_set, OpName4Node); PropagateWhiteThroughClearNodes(op_graph, IsAllowedToRunWithHalf, black_set, &white_set); VLOG(2) << "WhiteSet include: " << Container2Str, OpNode*>(white_set, OpName4Node); const DataType mixed_precision_data_type = ctx->job_desc().mixed_precision_data_type(); CHECK(mixed_precision_data_type == DataType::kFloat16 || mixed_precision_data_type == DataType::kBFloat16); InsertCastOp(op_graph, white_set, mixed_precision_data_type, &job_builder); return Maybe::Ok(); } void AutoMixedPrecision::FillBlackSet(const OpGraph& op_graph, HashSet* black_set) const { HashSet upstream_or_part_of_black_and_gray; DfsTopoGraphTraversal( op_graph, true, [&](OpNode* node) { return IsNodeInList(black_list_, node) || IsNodeInList(gray_list_, node); }, [&](OpNode* node) { return IsNodeInList(clear_list_, node); }, [&](OpNode* node) { return IsKeyFound(upstream_or_part_of_black_and_gray, node); }, [&](OpNode* node) { INSERT_CHECK(upstream_or_part_of_black_and_gray.insert(node)); VLOG(3) << "FillBlackSet(): Insert " << node->op().op_name() << " to upstream_or_part_of_black_and_gray"; }); // propagate black through upstream_or_part_of_black_and_gray DfsTopoGraphTraversal( op_graph, false, [&](OpNode* node) { return IsNodeInList(black_list_, node); }, [&](OpNode* node) { return IsKeyFound(upstream_or_part_of_black_and_gray, node); }, [&](OpNode* node) { return IsKeyFound(*black_set, node); }, [&](OpNode* node) { INSERT_CHECK(black_set->insert(node)); VLOG(3) << "FillBlackSet(): Insert " << node->op().op_name() << " to black_set"; }); } void AutoMixedPrecision::FillWhiteSet(const OpGraph& op_graph, std::function IsAllowedToRunWithHalf, const HashSet& black_set, HashSet* white_set) const { auto IsWhiteOrSinkAndAllowedToRunHalf = [&](OpNode* node) { return IsAllowedToRunWithHalf(node) && (IsNodeInList(white_list_, node) || (node->out_edges().empty() && (IsNodeInList(gray_list_, node) || IsNodeInList(clear_list_, node)))); }; HashSet upstream_or_part_of_white; DfsTopoGraphTraversal( op_graph, true, IsWhiteOrSinkAndAllowedToRunHalf, [&](OpNode* node) { return !IsKeyFound(black_set, node) && IsAllowedToRunWithHalf(node) && (IsNodeInList(gray_list_, node) || IsNodeInList(clear_list_, node)); }, [&](OpNode* node) { return IsKeyFound(upstream_or_part_of_white, node); }, [&](OpNode* node) { INSERT_CHECK(upstream_or_part_of_white.insert(node)); VLOG(3) << "FillWhiteSet(): Insert " << node->op().op_name() << " to upstream_or_part_of_white"; }); auto IsWhiteAndAllowedToRunHalf = [&](OpNode* node) { return IsAllowedToRunWithHalf(node) && IsNodeInList(white_list_, node); }; DfsTopoGraphTraversal( op_graph, false, IsWhiteAndAllowedToRunHalf, [&](OpNode* node) { return IsKeyFound(upstream_or_part_of_white, node); }, [&](OpNode* node) { return IsKeyFound(*white_set, node); }, [&](OpNode* node) { INSERT_CHECK(white_set->insert(node)); VLOG(3) << "FillWhiteSet(): Insert " << node->op().op_name() << " to white_set"; }); } void AutoMixedPrecision::PropagateWhiteThroughClearNodes( const OpGraph& op_graph, std::function IsAllowedToRunWithHalf, const HashSet& black_set, HashSet* white_set) const { auto PropagateIntoOneDirection = [&](bool is_downward) { DfsTopoGraphTraversal( op_graph, !is_downward, [&](OpNode* node) { return false; }, [&](OpNode* node) { return 
!IsKeyFound(*white_set, node) && !IsKeyFound(black_set, node) && IsNodeInList(clear_list_, node) && IsAllowedToRunWithHalf(node); }, [&](OpNode* node) { return IsKeyFound(*white_set, node); }, [&](OpNode* node) { INSERT_CHECK(white_set->insert(node)); VLOG(3) << "PropagateWhiteThroughNonListNodes(): Insert " << node->op().op_name() << " to white_set"; }); }; PropagateIntoOneDirection(true); PropagateIntoOneDirection(false); } void AutoMixedPrecision::InsertCastOp(const OpGraph& op_graph, const HashSet& white_set, const DataType mixed_precision_data_type, JobBuilder* job_builder) const { InsertCastOpImpl(true, op_graph, white_set, mixed_precision_data_type, job_builder); InsertCastOpImpl(false, op_graph, white_set, mixed_precision_data_type, job_builder); } REGISTER_JOB_PASS("AutoMixedPrecision", AutoMixedPrecision); } // namespace namespace { struct NoCastRegistrar final { NoCastRegistrar(const std::string& op_type, OpArg&& op_arg) { auto* registry = GetNoCastRegistry(); registry->emplace(std::make_pair(op_type, std::move(op_arg))); } ~NoCastRegistrar() = default; }; #define REGISTER_NO_CAST_REGISTRY(op_type, input_arg_name, idx) \ static NoCastRegistrar OF_PP_CAT(g_registrar, __COUNTER__)(op_type, \ std::make_pair(input_arg_name, idx)); // For Example: // REGISTER_NO_CAST_REGISTRY("matmul", "b", 0); REGISTER_NO_CAST_REGISTRY("normalization", "moving_mean", 0) REGISTER_NO_CAST_REGISTRY("normalization", "moving_variance", 0) REGISTER_NO_CAST_REGISTRY("normalization", "gamma", 0) REGISTER_NO_CAST_REGISTRY("normalization", "beta", 0) REGISTER_NO_CAST_REGISTRY("normalization_grad", "gamma", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "moving_mean", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "moving_variance", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "gamma", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "beta", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "gamma", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "beta", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "mean", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "inv_variance", 0) REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "reserve_space", 0) REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "mean", 0) REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "inv_variance", 0) REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "mean", 0) REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "inv_variance", 0) REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "mean", 0) REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "inv_variance", 0) } // namespace namespace amp { bool IsNoCast(const std::string& op_type, const OpArg& op_arg) { return FindInNoCastRegisry(op_type, op_arg); } } // namespace amp } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/auto_mixed_precision.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_H_ #define ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_H_ #include namespace oneflow { using OpArg = std::pair; namespace amp { bool IsNoCast(const std::string& op_type, const OpArg& op_arg); } // namespace amp } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_H_ ================================================ FILE: oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/auto_mixed_precision_lists.h" namespace oneflow { const AMPList& AutoMixedPrecisionLists::WhiteList() { static AMPList white_list = {"matmul", "batch_matmul", "conv2d", "conv_data_grad", "conv_filter_grad", "conv_bias_grad", "amp_white_identity", "broadcast_matmul", "broadcast_matmul_grad_b", "fused_self_attention_query_mul_key_and_value", "fused_self_attention_query_mul_key_and_value_grad", "prelu", "prelu_grad", "tf_prelu", "tf_prelu_grad", "cublas_fused_mlp", "cublas_fused_mlp_grad", "fused_matmul_bias", "cublas_bias_add_relu_matmul_grad", "fused_glu", "fused_glu_without_linear_grad", "fused_matmul_bias_add_relu_dropout", "fused_relu_dropout_grad", "fused_dot_feature_interaction", "fused_dot_feature_interaction_grad", "one_embedding_fused_lookup", "one_embedding_fused_lookup_grad", "binary_cross_entropy_with_logits_reduce_mean", "binary_cross_entropy_with_logits_reduce_mean_grad", "fused_cross_feature_interaction", "fused_cross_feature_interaction_v1_grad", "fused_cross_feature_interaction_v2_grad", "fused_multi_head_attention_inference", "grouped_matmul_bias"}; return white_list; } const AMPList& AutoMixedPrecisionLists::BlackList() { // TODO(niuchong): reduce_mean? 
static AMPList black_list = {"amp_black_identity"}; return black_list; } const AMPList& AutoMixedPrecisionLists::GrayList() { static AMPList gray_list = {"add_n", "tf_avg_pool_1d", "tf_avg_pool_1d_grad", "tf_avg_pool_2d", "tf_avg_pool_2d_grad", "tf_avg_pool_3d", "tf_avg_pool_3d_grad", "avg_pool_1d", "avg_pool_1d_grad", "avg_pool_2d", "avg_pool_2d_grad", "avg_pool_3d", "avg_pool_3d_grad", "bias_add", "reduce_sum", "reduce_sum_like", "sigmoid_grad", "tanh", "tanh_grad", "sqrt", "sqrt_grad", "scalar_mul", "scalar_mul_by_tensor", "scalar_add", "scalar_div", "scalar_pow", "broadcast_add", "broadcast_sub", "broadcast_mul", "broadcast_div", "layer_norm", "layer_norm_param_grad", "layer_norm_grad", "fuse_layer_norm_grad", "skip_layer_norm", "rms_norm", "rms_norm_grad", "rms_norm_param_grad", "dropout", "dropout_grad", "softmax", "softmax_grad", "log_softmax", "log_softmax_grad", "gelu", "gelu_grad", "fast_gelu", "fast_gelu_grad", "normalization", "normalization_grad", "normalization_add_relu", "normalization_add_relu_grad", "sparse_softmax_cross_entropy", "sparse_softmax_cross_entropy_grad", "nll", "nll_grad", "fused_tril_scale_softmax_mask_scale", "fused_tril_scale_softmax_mask_scale_grad", "fused_scale_mask_softmax_dropout", "fused_scale_mask_softmax_dropout_grad", "fused_scale_mask_softmax", "fused_scale_mask_softmax_grad", "fused_bias_add_scale_mask_softmax_dropout", "fused_bias_add_gelu", "fused_bias_add_gelu_grad", "fused_bias_add_mask_scale", "fused_fast_gelu_mul", "fused_fast_gelu_mul_grad", "acc", "reciprocal", "reciprocal_no_nan", "group_norm", "group_norm_param_grad", "group_norm_grad", "silu", "silu_grad", "fused_weighted_sum"}; return gray_list; } const AMPList& AutoMixedPrecisionLists::ClearList() { // TODO(niuchong): tuple_identity static AMPList clear_list = {"broadcast_like", "gather", "gather_nd", "scatter_nd", "scatter_nd_like", "unsorted_segment_sum_like", "tf_max_pool_1d", "tf_max_pool_1d_grad", "tf_max_pool_2d", "tf_max_pool_2d_grad", "tf_max_pool_3d", "tf_max_pool_3d_grad", "max_pool_1d", "max_pool_1d_grad", "max_pool_2d", "max_pool_2d_grad", "max_pool_3d", "max_pool_3d_grad", "reshape", "reshape_like", "relu", "relu_grad", "transpose", "random_mask_like", "cat", "split_like", "pad", "same_padding", "same_padding_grad", "tril", "slice", "slice_grad", "fused_scale_tril", "identity", "squeeze", "embedding", "embedding_grad", "expand", "expand_dims", "cast_to_static_shape", "parallel_cast", "hierarchical_parallel_cast", "hierarchical_parallel_cast_like", "repeat", "unpack", "pack", "nvtx_start", "nvtx_end", "narrow", "narrow_grad", "ones_like", "pinned_identity", "to_contiguous", "copy", "where", "upsample_nearest_2d", "fill_"}; return clear_list; } } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/auto_mixed_precision_lists.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_LISTS_H_ #define ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_LISTS_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/operator/op_conf_util.h" namespace oneflow { typedef HashSet AMPList; class AutoMixedPrecisionLists final { public: // TODO(niuchong): list include grad static const AMPList& WhiteList(); static const AMPList& BlackList(); static const AMPList& GrayList(); static const AMPList& ClearList(); }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_LISTS_H_ ================================================ FILE: oneflow/core/job_rewriter/auto_parallel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/auto_parallel/sbp_constructor.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace { class AutoParallelPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(AutoParallelPass); AutoParallelPass() = default; ~AutoParallelPass() override = default; Maybe Apply(const OpGraph& op_graph, Job* job) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!job->job_conf().enable_auto_parallel()) { return Maybe::Ok(); } VLOG(3) << "=== Enable AutoParallel ==="; if (job->job_conf().enable_auto_parallel_ignore_user_sbp_config()) { JUST(RemoveParallelCastOps(job)); } const OpGraph op_graph(*job); return Apply(op_graph, job); } private: Maybe RemoveParallelCastOps(Job* job) const; }; Maybe AutoParallelPass::Apply(const OpGraph& op_graph, Job* job) const { // auto-parallel LOG(INFO) << "Start Auto Parallel"; auto time_begin = std::chrono::high_resolution_clock::now(); auto_parallel::SbpConstructor sbp_constructor(op_graph, job); JUST(sbp_constructor.FindBestSbpSignature()); JUST(sbp_constructor.DumpNdSbpSignatureForJob(op_graph, job)); auto time_end = std::chrono::high_resolution_clock::now(); VLOG(2) << "Auto parallel took " << std::chrono::duration_cast(time_end - time_begin).count() << " ms\n"; if (GlobalProcessCtx::Rank() == 0) { // sbp_constructor.PrintSBPGraphDebugInfo(); JUST(sbp_constructor.CheckSbpAgreement(*job)); } return Maybe::Ok(); } REGISTER_JOB_PASS("AutoParallelPass", AutoParallelPass); Maybe AutoParallelPass::RemoveParallelCastOps(Job* job) const { VLOG(3) << "Remove parallel cast ops for auto_parallel:"; const OpGraph op_graph(*job); JobBuilder job_builder(job); HashMap op_name2op_conf; HashMap op_name2nd_sbp_signature; HashSet ctrl_in_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) { ctrl_in_op_names.insert(ctrl_in_op_name); } }); const auto IsParallelCastOp = 
[](const OperatorConf& op_conf) -> bool { return op_conf.has_user_conf() && (op_conf.user_conf().op_type_name() == "parallel_cast" || op_conf.user_conf().op_type_name() == "hierarchical_parallel_cast" || op_conf.user_conf().op_type_name() == "hierarchical_parallel_cast_like"); }; std::vector del_op_names; HashSet del_op_name_set; std::function Try2Delete = [&](const OpNode* op_node) { if (del_op_name_set.find(op_node->op().op_name()) != del_op_name_set.end()) { return; } const OperatorConf& op_conf = op_node->op().op_conf(); if (!IsParallelCastOp(op_conf)) { return; } if (!op_conf.ctrl_in_op_name().empty()) { VLOG(3) << "Skip " << op_conf.name() << ", because it has ctrl edge."; return; } if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) { VLOG(3) << "Skip " << op_conf.name() << ", because it is a ctrl edge."; return; } if (op_node->in_edges().size() != 1) { return; } // Find the first op which won't be deleted const OpNode* source_op = op_node; const OpNode* producer = op_node->SoleInEdge()->src_node(); while (IsParallelCastOp(producer->op().op_conf())) { Try2Delete(producer); if (del_op_name_set.find(producer->op().op_name()) == del_op_name_set.end()) { break; } source_op = producer; producer = source_op->SoleInEdge()->src_node(); } user_op::UserOpConfWrapper conf_wrapper_in(source_op->op().op_conf()); const LogicalBlobId& parallel_cast_in_lbi = GenLogicalBlobId(conf_wrapper_in.input("in", 0)); user_op::UserOpConfWrapper conf_wrapper_out(op_conf); const LogicalBlobId& parallel_cast_out_lbi = GenLogicalBlobId(conf_wrapper_out.output("out", 0)); if (op_node->parallel_desc() != producer->parallel_desc()) { VLOG(3) << "Skip " << op_node->op().op_name() << "(with placement: " << *CHECK_JUST(PlacementToString(SymbolOf(op_node->parallel_desc()))) << "), because producer " << producer->op().op_name() << "'s placement is " << *CHECK_JUST(PlacementToString(SymbolOf(producer->parallel_desc()))); return; } for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); if (consumer->parallel_desc() != op_node->parallel_desc()) { VLOG(3) << "Skip " << op_node->op().op_name() << "(with placement: " << *CHECK_JUST(PlacementToString(SymbolOf(op_node->parallel_desc()))) << "), because consumer " << consumer->op().op_name() << "'s placement is " << *CHECK_JUST(PlacementToString(SymbolOf(consumer->parallel_desc()))); return; } } op_name2nd_sbp_signature[producer->op().op_name()] = producer->nd_sbp_signature(); for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); const std::string& consumer_op_name = consumer->op().op_name(); op_name2nd_sbp_signature[consumer_op_name] = consumer->nd_sbp_signature(); if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) { op_name2op_conf[consumer_op_name] = consumer->op().op_conf(); } OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name); for (const std::string& ibn : consumer->op().input_bns()) { if (consumer->op().BnInOp2Lbi(ibn) == parallel_cast_out_lbi) { const auto& new_val = GenLogicalBlobName(parallel_cast_in_lbi); const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val); CHECK_EQ(GenLogicalBlobName(parallel_cast_out_lbi), old_val); } } } del_op_names.emplace_back(op_conf.name()); del_op_name_set.insert(op_conf.name()); VLOG(3) << "\tremove " << op_conf.name(); }; op_graph.ForEachNode(Try2Delete); for (const auto& pair : op_name2op_conf) { job_builder.MutOpsOnlyOnce({pair.second}); } for (const auto& 
pair : op_name2nd_sbp_signature) { job_builder.AddNdSbpSignature4OpName(pair.first, pair.second); } job_builder.DelOps(del_op_names); return Maybe::Ok(); } } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/auto_train_step.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h" #include "oneflow/core/framework/scope_util.h" namespace oneflow { namespace { class AutoTrainStep final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(AutoTrainStep); AutoTrainStep() = default; ~AutoTrainStep() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe AutoTrainStep::Apply(Job* job, JobPassCtx* ctx) const { if (!ctx->job_desc().IsTrain()) { return Maybe::Ok(); } const OpGraph op_graph(*job); const TrainConf& train_conf = job->job_conf().train_conf(); if (train_conf.has_train_step_lbn()) { CHECK_OR_RETURN(!train_conf.has_dynamic_loss_scale_policy()); return Maybe::Ok(); } OperatorConf variable_op_conf{}; const std::string train_step_name = "System-Train-TrainStep"; variable_op_conf.set_name(train_step_name); VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf(); variable_conf->set_out("out"); *variable_conf->mutable_shape()->mutable_dim()->Add() = 1; variable_conf->set_data_type(DataType::kInt64); variable_conf->mutable_initializer()->mutable_constant_int_conf()->set_value(0); OperatorConf identity_op_conf{}; identity_op_conf.set_name(train_step_name + "-Identity"); IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf(); identity_conf->set_in(GenLogicalBlobName(variable_op_conf.name(), variable_conf->out())); identity_conf->set_out("out"); const std::string& train_step_lbn = GenLogicalBlobName(identity_op_conf.name(), identity_conf->out()); JobBuilder job_builder(job); ParallelConf parallel_conf; if (ParseBooleanFromEnv("ONEFLOW_GRAPH_PLACE_TRAINING_STATE_ON_ALL_RANKS", false)) { parallel_conf = GenParallelConfOfCpuOnAllRanks(); } else { parallel_conf = GenParallelConfOfCpuZeroOnMaster(); } int64_t scope_symbol_id = 0; { const auto& opt_scope_symbol_id = JUST(MakeInitialScope(job->job_conf(), SymbolOf(ParallelDesc(parallel_conf)), /* is_local */ false)) ->symbol_id(); CHECK_OR_RETURN(opt_scope_symbol_id.has_value()) << Error::RuntimeError() << "symbol_id not initialized"; scope_symbol_id = JUST(opt_scope_symbol_id); } auto scalar_add_op = user_op::UserOpConfWrapperBuilder(train_step_name + "-ScalarAdd") .Op("scalar_add") .Input("in", train_step_lbn) .Output("out") .Attr("has_float_operand", false) .Attr("float_operand", 0) .Attr("has_int_operand", true) .Attr("int_operand", 1) .ScopeSymbolId(scope_symbol_id) .Build(); variable_op_conf.set_scope_symbol_id(scope_symbol_id); 
identity_op_conf.set_scope_symbol_id(scope_symbol_id); job_builder.AddOps(parallel_conf, {variable_op_conf, identity_op_conf, scalar_add_op.op_conf()}); if (train_conf.has_dynamic_loss_scale_policy()) { const auto& dynamic_loss_scale_state = JUST(ctx->GetState("dynamic_loss_scale_state")); auto assign_op = user_op::UserOpConfWrapperBuilder(train_step_name + "-AssignIfNot") .Op("assign_if_not") .Input("ref", GenLogicalBlobName(variable_op_conf.name(), variable_conf->out())) .Input("value", scalar_add_op.output("out", 0)) .Input("condition", dynamic_loss_scale_state.count_not_finite_lbn()) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder.AddOps(parallel_conf, {assign_op.op_conf()}); } else { auto assign_op = user_op::UserOpConfWrapperBuilder(train_step_name + "-Assign") .Op("assign") .Input("ref", GenLogicalBlobName(variable_op_conf.name(), variable_conf->out())) .Input("value", scalar_add_op.output("out", 0)) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder.AddOps(parallel_conf, {assign_op.op_conf()}); } job->mutable_job_conf()->mutable_train_conf()->set_train_step_lbn(train_step_lbn); return Maybe::Ok(); } REGISTER_JOB_PASS("AutoTrainStep", AutoTrainStep); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/autograd.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/autograd.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job_rewriter/clone_grad.h" #include "oneflow/core/operator/variable_op.h" #include "oneflow/core/register/op_blob_arg.pb.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h" #include "oneflow/core/framework/scope_util.h" #include "oneflow/core/job_rewriter/clip_by_global_norm_job_pass_state.h" #include "oneflow/core/job_rewriter/pass_util.h" namespace oneflow { namespace { const TrainConf& GetTrainConf() { return GlobalJobDesc().job_conf().train_conf(); } int64_t ScopeSymbolId4Lbi(const OpGraph& op_graph, const LogicalBlobId& lbi) { return op_graph.OpNode4OpName(lbi.op_name())->op().op_conf().scope_symbol_id(); } bool AnyLbiWithDiffLbi(const OpEdge* op_edge) { const Operator& src_op = op_edge->src_node()->op(); const Operator& dst_op = op_edge->dst_node()->op(); auto IsOutputBlobModifierRequiresGrad = [&](const LogicalBlobId& lbi) { return src_op.OutputBlobModifier4Obn(op_edge->lbi2obn().at(lbi)).requires_grad(); }; auto IsInputBlobModifierRequiresGrad = [&](const LogicalBlobId& lbi) { const auto& ibns = op_edge->lbi2ibns().at(lbi); for (const std::string& ibn : ibns) { if (dst_op.InputBlobModifier4Ibn(ibn).requires_grad()) { return true; } } CHECK_GT(ibns.size(), 0); return false; }; for (const LogicalBlobId& lbi : op_edge->lbis()) { if (IsOutputBlobModifierRequiresGrad(lbi) && IsInputBlobModifierRequiresGrad(lbi)) { return true; } } CHECK_GT(op_edge->lbis().size(), 0); return false; } void CheckNotReachableAmongOpNodes(const OpGraph& op_graph, const std::list& op_nodes) { auto IsReachable = op_graph.MakePredicatorIsReachable(); for (OpNode* src_node : op_nodes) { for (OpNode* dst_node : op_nodes) { if (src_node == dst_node) { continue; } CHECK(!IsReachable(src_node, dst_node)); } } } Maybe GetLossOpNodes(const OpGraph& op_graph, std::list* loss_op_nodes) { const auto& train_conf = GetTrainConf(); HashSet loss_op_names; for (const std::string& loss_lbn : train_conf.loss_lbn()) { loss_op_names.emplace(GenLogicalBlobId(loss_lbn).op_name()); } op_graph.ForEachNode([&](OpNode* op_node) { if (loss_op_names.find(op_node->op().op_name()) != loss_op_names.end()) { loss_op_nodes->emplace_back(op_node); } }); if (loss_op_nodes->empty()) { return Error::LossBlobNotFoundError() << "Loss blob not found."; } return Maybe::Ok(); } Maybe GetLossOpNodesAndAscendants(const OpGraph& op_graph, HashSet* op_nodes) { std::list starts; JUST(GetLossOpNodes(op_graph, &starts)); auto ForEachNextNode = [&](OpNode* op_node, const std::function& Handler) { for (OpEdge* edge : op_node->in_edges()) { if (AnyLbiWithDiffLbi(edge)) { Handler(edge->src_node()); } } }; op_graph.BfsForEachNode(starts, ForEachNextNode, [&](OpNode* op_node) { op_nodes->emplace(op_node); }); return Maybe::Ok(); } const ParallelConf& ProducerParallelConf4Lbi(const OpGraph& op_graph, const LogicalBlobId& lbi) { return op_graph.OpNode4OpName(lbi.op_name())->parallel_desc().parallel_conf(); } void ScaleModelDiffByConstantLossInstanceNum(const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi, const int64_t loss_instance_num) { if (loss_instance_num == 1) { return; } const float scale_factor = 1.0f / static_cast(loss_instance_num); for (auto& pair : *lbi2diff_lbi) { const 
LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto scalar_mul_op = user_op::UserOpConfWrapperBuilder("Sys-DiffScale-ScalarMul-" + lbi.op_name() + "_" + lbi.blob_name() + "-" + NewUniqueId()) .Op("scalar_mul") .Input("in", GenLogicalBlobName(diff_lbi)) .Output("out") .Attr("has_float_operand", true) .Attr("float_operand", scale_factor) .Attr("has_int_operand", false) .Attr("int_operand", 0) .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi)) .Build(); job_builder->AddOps(ProducerParallelConf4Lbi(op_graph, lbi), {scalar_mul_op.op_conf()}); diff_lbi = GenLogicalBlobId(scalar_mul_op.output("out", 0)); } } Maybe TryLocalCastTotalLossInstanceNum( JobBuilder* job_builder, const HashMap& loss_lbi2loss_node, LogicalBlobId* total_loss_instance_num_lbi) { auto IsLocal4Lbi = [](const LogicalBlobId& lbi, OpNode* op_node) -> Maybe { const auto& obn = *JUST(op_node->op().obn4lbi(lbi)); const auto& opt_local_parallel = *JUST(op_node->op().OptLocalParallel4BnInOp(obn)); return opt_local_parallel.has_local_parallel(); }; const auto& begin = *loss_lbi2loss_node.begin(); bool is_local = JUST(IsLocal4Lbi(begin.first, begin.second)); for (const auto& pair : loss_lbi2loss_node) { bool is_other_local = JUST(IsLocal4Lbi(pair.first, pair.second)); CHECK_EQ_OR_RETURN(is_local, is_other_local); // NOLINT } if (is_local) { OperatorConf op_conf; op_conf.set_name("System-Cast-Local-TotalLossInstanceNum" + NewUniqueId()); CastFromLocalOpConf* cast_from_local = op_conf.mutable_cast_from_local_conf(); cast_from_local->set_in(GenLogicalBlobName(*total_loss_instance_num_lbi)); cast_from_local->set_out("out"); cast_from_local->mutable_sbp_parallel()->mutable_partial_sum_parallel(); const auto& parallel_conf = JUST(job_builder->ParallelConf4Lbi(*total_loss_instance_num_lbi)); int64_t scope_symbol_id = 0; { const auto& opt_scope_symbol_id = JUST(MakeInitialScope(job_builder->job().job_conf(), SymbolOf(ParallelDesc(parallel_conf)), /* is_local */ false)) ->symbol_id(); CHECK_OR_RETURN(opt_scope_symbol_id.has_value()) << Error::RuntimeError() << "symbol_id not initialized"; scope_symbol_id = JUST(opt_scope_symbol_id); } op_conf.set_scope_symbol_id(scope_symbol_id); job_builder->AddOps(parallel_conf, {op_conf}); total_loss_instance_num_lbi->set_op_name(op_conf.name()); total_loss_instance_num_lbi->set_blob_name("out"); } return Maybe::Ok(); } void ScaleModelDiffByDynamicLossInstanceNum( const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi, const HashMap& loss_lbi2loss_node) { auto BuildInstanceNumOpConf4LossOpNode = [&](const LogicalBlobId& loss_lbi, const OpNode* op_node, LogicalBlobId* lbi) { OperatorConf instance_num_op; instance_num_op.set_name("System-Autograd-" + loss_lbi.op_name() + "-" + loss_lbi.blob_name() + "-LossInstanceNum"); auto* instance_num_op_conf = instance_num_op.mutable_shape_elem_cnt_conf(); instance_num_op_conf->set_x(GenLogicalBlobName(loss_lbi)); instance_num_op_conf->set_y("y"); instance_num_op_conf->set_data_type(op_node->LogicalBlobDesc4Lbi(loss_lbi).data_type()); instance_num_op_conf->mutable_include_axis_conf(); instance_num_op.set_scope_symbol_id(op_node->op().op_conf().scope_symbol_id()); job_builder->AddOps(op_node->parallel_desc().parallel_conf(), {instance_num_op}); lbi->set_op_name(instance_num_op.name()); lbi->set_blob_name("y"); }; LogicalBlobId total_loss_instance_num_lbi; if (loss_lbi2loss_node.size() == 1) { const auto& pair_it = loss_lbi2loss_node.begin(); BuildInstanceNumOpConf4LossOpNode(pair_it->first, pair_it->second, 
&total_loss_instance_num_lbi); } else if (loss_lbi2loss_node.size() > 1) { OperatorConf op_conf; op_conf.set_name("System-Autograd-total_loss_instance_num"); TotalLossInstanceNumOpConf* total_loss_instance_num_conf = op_conf.mutable_total_loss_instance_num_conf(); for (const auto& pair : loss_lbi2loss_node) { LogicalBlobId loss_instance_num_lbi; BuildInstanceNumOpConf4LossOpNode(pair.first, pair.second, &loss_instance_num_lbi); total_loss_instance_num_conf->add_in(GenLogicalBlobName(loss_instance_num_lbi)); } total_loss_instance_num_conf->set_out("out"); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0"); int64_t scope_symbol_id = 0; { const auto& opt_scope_symbol_id = CHECK_JUST(MakeInitialScope(job_builder->job().job_conf(), SymbolOf(ParallelDesc(parallel_conf)), /* is_local */ false)) ->symbol_id(); if (!opt_scope_symbol_id.has_value()) { THROW(RuntimeError) << "symbol_id not initialized"; } scope_symbol_id = CHECK_JUST(opt_scope_symbol_id); } op_conf.set_scope_symbol_id(scope_symbol_id); job_builder->AddOps(parallel_conf, {op_conf}); total_loss_instance_num_lbi.set_op_name(op_conf.name()); total_loss_instance_num_lbi.set_blob_name("out"); } else { UNIMPLEMENTED(); } CHECK_JUST(TryLocalCastTotalLossInstanceNum(job_builder, loss_lbi2loss_node, &total_loss_instance_num_lbi)); for (auto& pair : *lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto scalar_div_op = user_op::UserOpConfWrapperBuilder("Sys-DiffScale-ScalarDiv-" + lbi.op_name() + "_" + lbi.blob_name() + "-" + NewUniqueId()) .Op("scalar_div_by_tensor") .Input("x", GenLogicalBlobName(diff_lbi)) .Input("scalar", GenLogicalBlobName(total_loss_instance_num_lbi)) .Output("y") .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi)) .Build(); job_builder->AddOps(ProducerParallelConf4Lbi(op_graph, lbi), {scalar_div_op.op_conf()}); diff_lbi = GenLogicalBlobId(scalar_div_op.output("y", 0)); } } bool AllSplitDistribution(const NdSbp& nd_sbp) { for (int64_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (!nd_sbp.sbp_parallel(i).has_split_parallel()) { return false; } } return true; } void ForEachAggregatedParamGroup( const OpGraph& op_graph, const HashMap& lbi2diff_lbi, const std::function& libs)>& Handler) { HashMap lbi2parallel_desc; HashMap, std::vector> group; for (auto& pair : lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name()); const ParallelDesc& parallel_desc = model_op_node->parallel_desc(); const NdSbp& nd_sbp = model_op_node->NdSbp4Lbi(lbi); group[std::make_pair(parallel_desc, nd_sbp)].emplace_back(lbi); } for (const auto& pair : group) { Handler(pair.first.first, pair.first.second, pair.second); } } int64_t MakeScopeSymbolId(const JobConfigProto& job_conf, const ParallelConf& parallel_conf) { const auto& opt_scope_symbol_id = CHECK_JUST(MakeInitialScope(job_conf, SymbolOf(ParallelDesc(parallel_conf)), /* is_local */ false)) ->symbol_id(); if (!opt_scope_symbol_id.has_value()) { THROW(RuntimeError) << "symbol_id not initialized"; } return CHECK_JUST(opt_scope_symbol_id); } std::string AddLbns(JobBuilder* job_builder, const std::vector& lbns, const ParallelConf& parallel_conf, int64_t scope_symbol_id, const std::string& op_name_prefix) { if (lbns.size() == 1) { return lbns.front(); } else { user_op::UserOpConfWrapperBuilder add_op_builder(op_name_prefix + NewUniqueId()); add_op_builder.Op("add_n"); for (const std::string& lbn : lbns) { add_op_builder.Input("in", lbn); } 
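      // "add_n" sums all of its "in" blobs elementwise, so AddLbns always hands back a single
      // logical blob name: either the sole input (handled above) or the new op's "out" blob.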
const auto add_op = add_op_builder.Output("out").ScopeSymbolId(scope_symbol_id).Build(); job_builder->AddOps(parallel_conf, {add_op.op_conf()}); return add_op.output("out", 0); } } std::string AddParallelCast(JobBuilder* job_builder, const std::string& in_lbn, const std::string& sbp_str, const ParallelConf& parallel_conf, const std::string& op_name_prefix) { ParallelConf flat_parallel_conf = parallel_conf; flat_parallel_conf.mutable_hierarchy()->clear_dim(); const int64_t scope_symbol_id = MakeScopeSymbolId(job_builder->job().job_conf(), flat_parallel_conf); std::vector sbp = {sbp_str}; auto parallel_cast_op = user_op::UserOpConfWrapperBuilder(op_name_prefix + NewUniqueId()) .Op("hierarchical_parallel_cast") .Input("in", in_lbn) .Output("out") .Attr>("nd_sbp", sbp) .Attr("grad_mode", "auto") .Attr>("grad_nd_sbp", std::vector{}) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(flat_parallel_conf, {parallel_cast_op.op_conf()}); return parallel_cast_op.output("out", 0); } bool IsBroadcast(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) { if (parallel_desc.parallel_num() == 1) { return true; } for (int64_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (!nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { return false; } } return true; } bool HasSplit(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) { if (parallel_desc.parallel_num() == 1) { return false; } for (const auto& sbp : nd_sbp.sbp_parallel()) { if (sbp.has_split_parallel()) { return true; } } return false; } OperatorConf GenConstantLikeOp(const std::string& op_name, int64_t scope_symbol_id, const std::string& like_lbn, double value, DataType dtype) { OperatorConf op_conf; op_conf.set_name(op_name); op_conf.set_scope_symbol_id(scope_symbol_id); ConstantLikeOpConf* constant_like_conf = op_conf.mutable_constant_like_conf(); constant_like_conf->set_like(like_lbn); if (dtype == DataType::kInt32) { constant_like_conf->set_int_operand(static_cast(value)); } else if (dtype == DataType::kInt64) { constant_like_conf->set_int_operand(static_cast(value)); } else if (dtype == DataType::kFloat) { constant_like_conf->set_float_operand(static_cast(value)); } else if (dtype == DataType::kDouble) { constant_like_conf->set_float_operand(value); } else { UNIMPLEMENTED(); } constant_like_conf->set_data_type(dtype); constant_like_conf->set_out("out"); return op_conf; } std::string GlobalAbsMaxMin(const OpGraph& op_graph, JobBuilder* job_builder, const HashMap& lbi2diff_lbi, bool max_or_min, ParallelConf* out_parallel_conf) { // max(abs(x)) bool all_same_parallel_desc = true; const ParallelDesc& any_parallel_desc = op_graph.OpNode4OpName(lbi2diff_lbi.begin()->first.op_name())->parallel_desc(); std::vector group_reduce_lbns; auto GroupReduce = [&](const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, const std::vector& lbis) { if (!parallel_desc.EqualsIgnoringHierarchy(any_parallel_desc)) { all_same_parallel_desc = false; } int64_t scope_symbol_id = MakeScopeSymbolId(job_builder->job().job_conf(), parallel_desc.parallel_conf()); bool has_split = HasSplit(nd_sbp, parallel_desc); if (job_builder->job().job_conf().enable_gradients_stats_aggregation()) { std::string multi_reduce_op_type_name = has_split ? (max_or_min ? "local_multi_reduce_max_abs" : "local_multi_reduce_min_abs") : (max_or_min ? 
"multi_reduce_max_abs" : "multi_reduce_min_abs"); std::string multi_reduce_op_name = "System-ClipGradient-GlobalNorm-MultiReduceXimumAbs-" + NewUniqueId(); auto multi_reduce_op_builder = user_op::UserOpConfWrapperBuilder(multi_reduce_op_name) .Op(multi_reduce_op_type_name) .Output("y") .ScopeSymbolId(scope_symbol_id); for (const auto& lbi : lbis) { multi_reduce_op_builder.Input("x", GenLogicalBlobName(lbi2diff_lbi.at(lbi))); } auto multi_reduce_op = multi_reduce_op_builder.Build(); job_builder->AddOps(parallel_desc.parallel_conf(), {multi_reduce_op.op_conf()}); if (has_split) { std::string group_reduce_op_type_name = max_or_min ? "reduce_max" : "reduce_min"; std::string group_reduce_op_name = "System-ClipGradient-GlobalNorm-GroupReduceXimum-" + NewUniqueId(); auto group_reduce_op = user_op::UserOpConfWrapperBuilder(group_reduce_op_name) .Op(group_reduce_op_type_name) .Input("input_tensor", multi_reduce_op.output("y", 0)) .Output("output_tensor") .Attr("axis", std::vector{0}) .Attr("keepdims", false) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_desc.parallel_conf(), {group_reduce_op.op_conf()}); group_reduce_lbns.push_back(group_reduce_op.output("output_tensor", 0)); } else { group_reduce_lbns.push_back(multi_reduce_op.output("y", 0)); } } else { UNIMPLEMENTED(); } }; ForEachAggregatedParamGroup(op_graph, lbi2diff_lbi, GroupReduce); CHECK_GT(group_reduce_lbns.size(), 0); *out_parallel_conf = all_same_parallel_desc ? any_parallel_desc.parallel_conf() : GenParallelConfOfCpuZeroOnMaster(); out_parallel_conf->mutable_hierarchy()->clear_dim(); if (group_reduce_lbns.size() == 1) { return group_reduce_lbns[0]; } else { // stack all group max and go on max const int64_t scope_symbol_id = MakeScopeSymbolId(job_builder->job().job_conf(), *out_parallel_conf); auto stack_op_builder = user_op::UserOpConfWrapperBuilder("System-ClipGradient-GlobalNorm-GlobalStack-" + NewUniqueId()) .Op("stack") .Output("out") .Attr("axis", int64_t(0)) .Attr("max_dim_size", static_cast(group_reduce_lbns.size())) .ScopeSymbolId(scope_symbol_id); for (const auto& lbn : group_reduce_lbns) { stack_op_builder.Input("in", lbn); } auto stack_op = stack_op_builder.Build(); job_builder->AddOps(*out_parallel_conf, {stack_op.op_conf()}); std::string reduce_op_type_name = max_or_min ? 
"reduce_max" : "reduce_min"; std::string reduce_op_name = "System-ClipGradient-GlobalNorm-GlobalReduceXimum-" + NewUniqueId(); auto reduce_op = user_op::UserOpConfWrapperBuilder(reduce_op_name) .Op(reduce_op_type_name) .Input("input_tensor", stack_op.output("out", 0)) .Output("output_tensor") .Attr("axis", std::vector{0}) .Attr("keepdims", false) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(*out_parallel_conf, {reduce_op.op_conf()}); return reduce_op.output("output_tensor", 0); } } std::string GlobalNorm(const OpGraph& op_graph, JobBuilder* job_builder, const HashMap& lbi2diff_lbi, float p, ParallelConf* out_parallel_conf) { bool all_same_parallel_desc = true; const ParallelDesc& any_parallel_desc = op_graph.OpNode4OpName(lbi2diff_lbi.begin()->first.op_name())->parallel_desc(); bool all_broadcast = true; std::vector group_lbns; std::vector group_parallel_confs; group_lbns.reserve(lbi2diff_lbi.size()); group_parallel_confs.reserve(lbi2diff_lbi.size()); auto GroupNorm = [&](const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, const std::vector& lbis) { if (!parallel_desc.EqualsIgnoringHierarchy(any_parallel_desc)) { all_same_parallel_desc = false; } int64_t scope_symbol_id = MakeScopeSymbolId(job_builder->job().job_conf(), parallel_desc.parallel_conf()); if (!IsBroadcast(nd_sbp, parallel_desc)) { all_broadcast = false; } group_parallel_confs.emplace_back(parallel_desc.parallel_conf()); if (job_builder->job().job_conf().enable_gradients_stats_aggregation()) { auto multi_reduce_sum_op_builder = user_op::UserOpConfWrapperBuilder("System-ClipGradient-GlobalNorm-MultiReduceSumPowAbs-" + NewUniqueId()) .Op("multi_reduce_sum_pow_abs") .Attr("p", p) .Output("y") .ScopeSymbolId(scope_symbol_id); for (const auto& lbi : lbis) { multi_reduce_sum_op_builder.Input("x", GenLogicalBlobName(lbi2diff_lbi.at(lbi))); } const auto multi_reduce_sum_op = multi_reduce_sum_op_builder.Build(); job_builder->AddOps(parallel_desc.parallel_conf(), {multi_reduce_sum_op.op_conf()}); group_lbns.emplace_back(multi_reduce_sum_op.output("y", 0)); } else { std::vector lbns_to_add; lbns_to_add.reserve(lbis.size()); for (const auto& lbi : lbis) { const LogicalBlobId& diff_lbi = lbi2diff_lbi.at(lbi); const auto square_sum_op = user_op::UserOpConfWrapperBuilder("System-ClipGradient-GlobalNorm-ReduceSumPowAbs-" + NewUniqueId()) .Op("multi_reduce_sum_pow_abs") .Input("x", GenLogicalBlobName(diff_lbi)) .Attr("p", p) .Output("y") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_desc.parallel_conf(), {square_sum_op.op_conf()}); lbns_to_add.emplace_back(square_sum_op.output("y", 0)); } group_lbns.emplace_back(AddLbns(job_builder, lbns_to_add, parallel_desc.parallel_conf(), scope_symbol_id, "System-ClipGradient-GlobalNorm-Add-")); } }; ForEachAggregatedParamGroup(op_graph, lbi2diff_lbi, GroupNorm); // sum in group *out_parallel_conf = all_same_parallel_desc ? 
      any_parallel_desc.parallel_conf() : GenParallelConfOfCpuZeroOnMaster();
  const int64_t scope_symbol_id =
      MakeScopeSymbolId(job_builder->job().job_conf(), *out_parallel_conf);
  std::vector<std::string> sum_group_lbns;
  if (all_broadcast) {
    sum_group_lbns = std::move(group_lbns);
  } else {
    sum_group_lbns.reserve(group_lbns.size());
    for (size_t i = 0; i < group_lbns.size(); ++i) {
      std::string lbn;
      if (all_same_parallel_desc) {
        // reduce many P->B (allreduce) casts to a single one
        lbn = AddParallelCast(job_builder, group_lbns.at(i), "P", group_parallel_confs.at(i),
                              "System-ClipGradient-ParallelCast-");
      } else {
        // the sum will run on cpu 0, so we need to do P->B first,
        // because when execution is on a single device, only B is accepted
        lbn = AddParallelCast(job_builder, group_lbns.at(i), "B", group_parallel_confs.at(i),
                              "System-ClipGradient-ParallelCast-");
      }
      sum_group_lbns.push_back(std::move(lbn));
    }
    out_parallel_conf->mutable_hierarchy()->clear_dim();
  }
  auto global_reduce_sum_lbn = AddLbns(job_builder, sum_group_lbns, *out_parallel_conf,
                                       scope_symbol_id, "System-ClipGradient-GlobalNorm-Add-");
  auto global_pow_op =
      user_op::UserOpConfWrapperBuilder("System-ClipGradient-GlobalNorm-GlobalPow-" + NewUniqueId())
          .Op("scalar_pow")
          .Input("in", global_reduce_sum_lbn)
          .Attr<double>("float_operand", 1.0 / p)
          .Attr<bool>("has_float_operand", true)
          .Output("out")
          .ScopeSymbolId(scope_symbol_id)
          .Build();
  job_builder->AddOps(*out_parallel_conf, {global_pow_op.op_conf()});
  return global_pow_op.output("out", 0);
}

void ClipGradientByGlobalNorm(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,
                              HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi,
                              const ClipByGlobalNormConf& conf) {
  if (lbi2diff_lbi->empty()) { return; }
  ParallelConf parallel_conf;
  std::string total_norm_lbn;
  CHECK(conf.has_norm_type());
  double norm_type = conf.norm_type();
  if (std::isinf(norm_type) && norm_type > 0) {
    total_norm_lbn = GlobalAbsMaxMin(op_graph, job_builder, *lbi2diff_lbi, true, &parallel_conf);
  } else if (std::isinf(norm_type) && norm_type < 0) {
    total_norm_lbn = GlobalAbsMaxMin(op_graph, job_builder, *lbi2diff_lbi, false, &parallel_conf);
  } else {
    total_norm_lbn = GlobalNorm(op_graph, job_builder, *lbi2diff_lbi, norm_type, &parallel_conf);
  }
  int64_t scope_symbol_id = MakeScopeSymbolId(job_builder->job().job_conf(), parallel_conf);
  auto add_eps_ops =
      user_op::UserOpConfWrapperBuilder("System-ClipGradient-GlobalNorm-AddEps-" + NewUniqueId())
          .Op("scalar_add")
          .Input("in", total_norm_lbn)
          .Attr<double>("float_operand", 1e-6)
          .Attr<bool>("has_float_operand", true)
          .Output("out")
          .ScopeSymbolId(scope_symbol_id)
          .Build();
  job_builder->AddOps(parallel_conf, {add_eps_ops.op_conf()});
  auto inv_op =
      user_op::UserOpConfWrapperBuilder("System-ClipGradient-GlobalNorm-Inv-" + NewUniqueId())
          .Op("reciprocal_no_nan")
          .Input("x", add_eps_ops.output("out", 0))
          .Output("y")
          .ScopeSymbolId(scope_symbol_id)
          .Build();
  job_builder->AddOps(parallel_conf, {inv_op.op_conf()});
  auto coeff_op =
      user_op::UserOpConfWrapperBuilder("System-ClipGradient-GlobalNorm-Coeff-" + NewUniqueId())
          .Op("scalar_mul")
          .Input("in", inv_op.output("y", 0))
          .Attr<double>("float_operand", static_cast<double>(conf.max_norm()))
          .Attr<bool>("has_float_operand", true)
          .Output("out")
          .ScopeSymbolId(scope_symbol_id)
          .Build();
  job_builder->AddOps(parallel_conf, {coeff_op.op_conf()});
  auto clamp_coeff_op =
      user_op::UserOpConfWrapperBuilder("System-ClipGradient-GlobalNorm-Clamp-" + NewUniqueId())
          .Op("clip_by_scalar_max")
          .Input("x", coeff_op.output("out", 0))
          .Attr<double>("floating_max", 1.0)
          .Output("y")
          .ScopeSymbolId(scope_symbol_id)
          .Build();
  job_builder->AddOps(parallel_conf,
{clamp_coeff_op.op_conf()}); const std::string& coeff_lbn = clamp_coeff_op.output("y", 0); for (auto& pair : *lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto mul_op_name = "System-ClipGradient-GlobalNorm-ScalarMul-" + NewUniqueId(); auto scalar_mul_op = user_op::UserOpConfWrapperBuilder(mul_op_name) .Op("scalar_mul_by_tensor") .Input("x", GenLogicalBlobName(diff_lbi)) .Input("scalar", coeff_lbn) .Output("y") .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi)) .Build(); job_builder->AddOps(op_graph.OpNode4OpName(lbi.op_name())->parallel_desc().parallel_conf(), {scalar_mul_op.op_conf()}); diff_lbi = GenLogicalBlobId(scalar_mul_op.output("y", 0)); } if (!CHECK_JUST(ctx->HasState("clip_by_global_norm_state"))) { CHECK_JUST(ctx->ResetState("clip_by_global_norm_state", std::make_unique())); } auto state = CHECK_JUST(ctx->MutableState("clip_by_global_norm_state")); const std::shared_ptr& total_norm_state = std::make_shared( total_norm_lbn, coeff_lbn, parallel_conf, scope_symbol_id); for (auto& pair : *lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; const std::string& variable_op_name = lbi.op_name(); state->AddTotalNormState(variable_op_name, total_norm_state); } } } // namespace Maybe MakeGetterLossOpNode4OpName( const OpGraph& op_graph, std::function* LossOpNode4OpName) { std::list loss_nodes; JUST(GetLossOpNodes(op_graph, &loss_nodes)); auto loss_op_name2op_node = std::make_shared>(); for (OpNode* op_node : loss_nodes) { CHECK(loss_op_name2op_node->emplace(op_node->op().op_name(), op_node).second); } *LossOpNode4OpName = [loss_op_name2op_node](const std::string& op_name) -> OpNode* { return loss_op_name2op_node->at(op_name); }; return Maybe::Ok(); } Maybe ScaleModelDiffByLossInstanceNum(const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi) { std::function LossOpNode4OpName; JUST(MakeGetterLossOpNode4OpName(op_graph, &LossOpNode4OpName)); const auto& train_conf = GetTrainConf(); HashMap loss_lbi2op_node; for (const auto& loss_lbn : train_conf.loss_lbn()) { const auto& lbi = GenLogicalBlobId(loss_lbn); CHECK(loss_lbi2op_node.emplace(lbi, LossOpNode4OpName(lbi.op_name())).second); } const Shape src_time_shape({1, 1}); const int64_t source_time_shape_elem_cnt = src_time_shape.elem_cnt(); bool all_loss_time_shape_eq_src = true; for (const auto& pair : loss_lbi2op_node) { const int64_t time_shape_elem_cnt = JUST(pair.second->op().GetOpTimeShape())->elem_cnt(); if (time_shape_elem_cnt != source_time_shape_elem_cnt) { CHECK_EQ(time_shape_elem_cnt % source_time_shape_elem_cnt, 0); all_loss_time_shape_eq_src = false; } } if (all_loss_time_shape_eq_src) { const BlobDesc* blob_desc = nullptr; for (const auto& pair : loss_lbi2op_node) { const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first); if (blob_desc != nullptr) { CHECK(*blob_desc == *cur_blob_desc); } blob_desc = cur_blob_desc; } if (blob_desc->is_dynamic()) { ScaleModelDiffByDynamicLossInstanceNum(op_graph, job_builder, lbi2diff_lbi, loss_lbi2op_node); } else { ScaleModelDiffByConstantLossInstanceNum(op_graph, job_builder, lbi2diff_lbi, blob_desc->shape().elem_cnt()); } } else { std::unique_ptr blob_desc; for (const auto& pair : loss_lbi2op_node) { const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first); // TODO: support dynamic CHECK(!cur_blob_desc->is_dynamic()); const DataType loss_data_type = cur_blob_desc->data_type(); const int64_t time_shape_elem_cnt = JUST(pair.second->op().GetOpTimeShape())->elem_cnt(); // TODO: consider 
sbp const int64_t loss_elem_cnt = cur_blob_desc->shape().elem_cnt() * time_shape_elem_cnt / source_time_shape_elem_cnt; if (blob_desc) { CHECK_EQ(blob_desc->data_type(), loss_data_type); CHECK_EQ(blob_desc->shape().elem_cnt(), loss_elem_cnt); } else { blob_desc.reset( new BlobDesc(Shape({loss_elem_cnt}), loss_data_type, cur_blob_desc->memory_format())); } } ScaleModelDiffByConstantLossInstanceNum(op_graph, job_builder, lbi2diff_lbi, blob_desc->shape().elem_cnt()); } return Maybe::Ok(); } Maybe ScaleInitialDiffByLossScale( JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, HashMap* loss_lbi2initial_diff_lbi) { const TrainConf& train_conf = ctx->job_desc().job_conf().train_conf(); if (!train_conf.has_dynamic_loss_scale_policy() && !train_conf.has_loss_scale_factor()) { return Maybe::Ok(); } for (auto& it : *loss_lbi2initial_diff_lbi) { const auto& loss_lbi = it.first; const auto& initial_diff_lbi = it.second; const OpNode* initial_diff_node = op_graph.OpNode4OpName(initial_diff_lbi.op_name()); int64_t scope_symbol_id = initial_diff_node->op().op_conf().scope_symbol_id(); const auto& parallel_conf = initial_diff_node->parallel_desc().parallel_conf(); std::string loss_diff_lbn = GenLogicalBlobName(initial_diff_lbi); const DataType init_diff_data_type = op_graph.GetLogicalBlobDesc(initial_diff_lbi).data_type(); // cast loss init diff from float16 to float32 since we need do loss scale (float32 multiply) // later if (init_diff_data_type != DataType::kFloat) { std::string cast_op_name = initial_diff_lbi.op_name() + "_" + initial_diff_lbi.blob_name() + "_loss_scale-cast_h2f"; auto cast_op = user_op::UserOpConfWrapperBuilder(cast_op_name) .Op("cast") .Input("in", loss_diff_lbn) .Output("out") .Attr("dtype", DataType::kFloat) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {cast_op.op_conf()}); loss_diff_lbn = cast_op.output("out", 0); } std::string loss_scale_val_lbn; if (train_conf.has_dynamic_loss_scale_policy()) { const auto& dynamic_loss_scale_state = JUST(ctx->GetState("dynamic_loss_scale_state")); loss_scale_val_lbn = dynamic_loss_scale_state.loss_scale_val_lbn(); } else if (train_conf.has_loss_scale_factor()) { OperatorConf constant_like_op{}; constant_like_op.set_name(loss_lbi.op_name() + "_" + loss_lbi.blob_name() + "_constant_like_loss_scale"); constant_like_op.set_scope_symbol_id(scope_symbol_id); ConstantLikeOpConf* constant_like_conf = constant_like_op.mutable_constant_like_conf(); constant_like_conf->set_like(loss_diff_lbn); constant_like_conf->set_out("out"); constant_like_conf->set_float_operand(train_conf.loss_scale_factor()); job_builder->AddOps(parallel_conf, {constant_like_op}); loss_scale_val_lbn = GenLogicalBlobName(constant_like_op.name(), constant_like_conf->out()); } else { UNIMPLEMENTED_THEN_RETURN() << "dynamic or static loss scale must be config"; } const int64_t time_shape_elem_cnt = JUST(initial_diff_node->op().GetInputBlobFastestTimeShape())->elem_cnt(); if (time_shape_elem_cnt != 1) { const auto repeat_op = user_op::UserOpConfWrapperBuilder(loss_lbi.op_name() + "_" + loss_lbi.blob_name() + "_loss_scale-repeat") .OpTypeName("repeat") .Input("in", loss_scale_val_lbn) .Output("out") .Attr("repeat_num", time_shape_elem_cnt) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {repeat_op.op_conf()}); loss_scale_val_lbn = repeat_op.output("out", 0); } auto scalar_mul_op = user_op::UserOpConfWrapperBuilder(initial_diff_lbi.op_name() + "_" + initial_diff_lbi.blob_name() + "_scale_initial_diff") 
.Op("scalar_mul_by_tensor") .Input("x", loss_diff_lbn) .Input("scalar", loss_scale_val_lbn) .Output("y") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {scalar_mul_op.op_conf()}); std::string scaled_initial_diff_lbn = scalar_mul_op.output("y", 0); // cast loss initial diff back to float16 if (init_diff_data_type != DataType::kFloat) { std::string cast_op_name = initial_diff_lbi.op_name() + "_" + initial_diff_lbi.blob_name() + "_loss_scale-cast_f2h"; auto cast_op = user_op::UserOpConfWrapperBuilder(cast_op_name) .Op("cast") .Input("in", scaled_initial_diff_lbn) .Output("out") .Attr("dtype", init_diff_data_type) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {cast_op.op_conf()}); scaled_initial_diff_lbn = cast_op.output("out", 0); } // update consumer input by scalar_mul_op output initial_diff_node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == initial_diff_lbi) { if (!CHECK_JUST(job_builder->IsInMutOpTransaction(out_node->op().op_name()))) { CHECK_JUST(job_builder->MutOpTransactionMut(out_node->op().op_conf())); } OperatorConf& mut_consumer_op = CHECK_JUST(job_builder->MutOpTransactionGet(out_node->op().op_name())); const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, scaled_initial_diff_lbn); CHECK_EQ(old_lbn, GenLogicalBlobName(initial_diff_lbi)); } } }); // update initial diff lbi it.second = GenLogicalBlobId(scaled_initial_diff_lbn); } JUST(job_builder->MutOpTransactionCommit()); return Maybe::Ok(); } void ScaleModelDiffByLossScale(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi) { auto ProducerOpNode4Lbi = [&](const LogicalBlobId& lbi) { return op_graph.OpNode4OpName(lbi.op_name()); }; auto ProducerOpNode4Lbn = [&](const std::string& lbn) { return ProducerOpNode4Lbi(GenLogicalBlobId(lbn)); }; const TrainConf& train_conf = ctx->job_desc().job_conf().train_conf(); if (train_conf.has_dynamic_loss_scale_policy()) { const auto& dynamic_loss_scale_state = CHECK_JUST(ctx->GetState("dynamic_loss_scale_state")); HashMap data_type2loss_scale_lbn; const auto LossScale4DataType = [&](DataType data_type) -> std::string { auto it = data_type2loss_scale_lbn.find(data_type); if (it == data_type2loss_scale_lbn.end()) { const std::string& loss_scale_val_lbn = dynamic_loss_scale_state.loss_scale_val_lbn(); const int64_t scope_symbol_id = ScopeSymbolId4Lbi(op_graph, GenLogicalBlobId(loss_scale_val_lbn)); const ParallelConf& parallel_conf = ProducerOpNode4Lbn(loss_scale_val_lbn)->parallel_desc().parallel_conf(); std::string loss_scale_lbn_with_data_type; if (data_type == DataType::kFloat) { loss_scale_lbn_with_data_type = loss_scale_val_lbn; } else { auto cast_op = user_op::UserOpConfWrapperBuilder("System-DynamicLossScale-Cast-" + NewUniqueId()) .Op("cast") .Input("in", loss_scale_val_lbn) .Output("out") .Attr("dtype", data_type) .ScopeSymbolId(scope_symbol_id) .Build(); loss_scale_lbn_with_data_type = cast_op.output("out", 0); job_builder->AddOps(parallel_conf, {cast_op.op_conf()}); } auto inv_scale_op = user_op::UserOpConfWrapperBuilder("System-DynamicLossScale-Reciprocal-" + NewUniqueId()) .Op("reciprocal") .Input("x", loss_scale_lbn_with_data_type) .Output("y") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {inv_scale_op.op_conf()}); std::string lbn = inv_scale_op.output("y", 0); data_type2loss_scale_lbn[data_type] = lbn; return lbn; } else { 
return it->second; } }; for (auto& pair : *lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto scalar_mul_op = user_op::UserOpConfWrapperBuilder("Sys-DiffScale-ScalarMul-" + lbi.op_name() + "_" + lbi.blob_name() + "-" + NewUniqueId()) .Op("scalar_mul_by_tensor") .Input("x", GenLogicalBlobName(diff_lbi)) .Input("scalar", LossScale4DataType(op_graph.GetLogicalBlobDesc(lbi).data_type())) .Output("y") .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi)) .Build(); job_builder->AddOps(ProducerParallelConf4Lbi(op_graph, lbi), {scalar_mul_op.op_conf()}); diff_lbi = GenLogicalBlobId(scalar_mul_op.output("y", 0)); } } else if (train_conf.has_loss_scale_factor()) { const float loss_scale_factor = train_conf.loss_scale_factor(); if (loss_scale_factor == 1) { return; } const float down_scale_factor = 1.0f / loss_scale_factor; for (auto& pair : *lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto scalar_mul_op = user_op::UserOpConfWrapperBuilder("Sys-DiffScale-ScalarMul-" + lbi.op_name() + "_" + lbi.blob_name() + "-" + NewUniqueId()) .Op("scalar_mul") .Input("in", GenLogicalBlobName(diff_lbi)) .Output("out") .Attr("has_float_operand", true) .Attr("float_operand", down_scale_factor) .Attr("has_int_operand", false) .Attr("int_operand", 0) .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi)) .Build(); job_builder->AddOps(ProducerParallelConf4Lbi(op_graph, lbi), {scalar_mul_op.op_conf()}); diff_lbi = GenLogicalBlobId(scalar_mul_op.output("out", 0)); } } else { return; } } void RegularizeGradient(const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi) { for (auto& pair : *lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name()); int64_t scope_symbol_id = model_op_node->op().op_conf().scope_symbol_id(); CHECK(model_op_node->op().op_conf().has_variable_conf()); const VariableOpConf& variable_conf = model_op_node->op().op_conf().variable_conf(); if (!variable_conf.has_regularizer()) { continue; } const RegularizerConf& regularizer_conf = variable_conf.regularizer(); if (regularizer_conf.has_l1_l2_conf()) { user_op::UserOpConfWrapper regularize_gradient_op = user_op::UserOpConfWrapperBuilder("System-RegularizeGradient-L1L2-" + NewUniqueId()) .Op("l1_l2_regularize_gradient") .Input("model", GenLogicalBlobName(lbi)) .Input("model_diff", GenLogicalBlobName(diff_lbi)) .Output("out") .Attr("l1", regularizer_conf.l1_l2_conf().l1()) .Attr("l2", regularizer_conf.l1_l2_conf().l2()) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(model_op_node->parallel_desc().parallel_conf(), {regularize_gradient_op.op_conf()}); diff_lbi = GenLogicalBlobId(regularize_gradient_op.output("out", 0)); } else { UNIMPLEMENTED(); } } } void ClipGradient(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi, const ClipConf& clip_conf) { if (clip_conf.has_clip_by_global_norm()) { ClipGradientByGlobalNorm(ctx, op_graph, job_builder, lbi2diff_lbi, clip_conf.clip_by_global_norm()); } else { UNIMPLEMENTED(); } } void AddDiffParallelCast(const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi) { for (auto& pair : *lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name()); if (model_op_node->parallel_desc().parallel_num() <= 1) { continue; } const int64_t scope_symbol_id = 
model_op_node->op().op_conf().scope_symbol_id(); std::vector nd_sbp; const std::string& variable_sole_obn = model_op_node->op().SoleObn(); nd_sbp.reserve(model_op_node->NdSbp4BnInOp(variable_sole_obn).sbp_parallel().size()); for (const auto& sbp_parallel : model_op_node->NdSbp4BnInOp(variable_sole_obn).sbp_parallel()) { nd_sbp.emplace_back(SbpParallelToString(sbp_parallel)); } auto parallel_cast_op = user_op::UserOpConfWrapperBuilder("System-AutoGrad-ParallelCast-" + NewUniqueId()) .Op("hierarchical_parallel_cast") .Input("in", GenLogicalBlobName(diff_lbi)) .Output("out") .Attr>("nd_sbp", nd_sbp) .Attr("grad_mode", "auto") .Attr>("grad_nd_sbp", std::vector()) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(model_op_node->parallel_desc().parallel_conf(), {parallel_cast_op.op_conf()}); diff_lbi = GenLogicalBlobId(parallel_cast_op.output("out", 0)); } } void AddDiffHalf2FloatCast(const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi) { for (auto& pair : *lbi2diff_lbi) { LogicalBlobId& diff_lbi = pair.second; auto data_type = op_graph.GetLogicalBlobDesc(diff_lbi).data_type(); if (data_type != DataType::kFloat) { std::string lbn = GenLogicalBlobName(diff_lbi); const OpNode* op_node = op_graph.OpNode4OpName(diff_lbi.op_name()); int64_t scope_symbol_id = op_node->op().op_conf().scope_symbol_id(); auto cast_op = user_op::UserOpConfWrapperBuilder(ReplaceSlashToDash4Lbn(lbn) + "-cast_h2f") .Op("cast") .Input("in", lbn) .Output("out") .Attr("dtype", DataType::kFloat) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(op_node->parallel_desc().parallel_conf(), {cast_op.op_conf()}); diff_lbi = GenLogicalBlobId(cast_op.output("out", 0)); } } } void AddDiffStaticShapeCast(const OpGraph& op_graph, JobBuilder* job_builder, HashMap* lbi2diff_lbi) { for (auto& pair : *lbi2diff_lbi) { const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name()); int64_t scope_symbol_id = model_op_node->op().op_conf().scope_symbol_id(); const auto cast_to_static_shape_op = user_op::UserOpConfWrapperBuilder("System-AutoGrad-StaticShapeCast-" + NewUniqueId()) .Op("cast_to_static_shape") .Input("input", GenLogicalBlobName(diff_lbi)) .Output("output") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(model_op_node->parallel_desc().parallel_conf(), {cast_to_static_shape_op.op_conf()}); diff_lbi = GenLogicalBlobId(cast_to_static_shape_op.output("output", 0)); } } Maybe CountNotFiniteIfNeeded(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const HashMap& lbi2diff_lbi) { if (lbi2diff_lbi.empty()) { return Maybe::Ok(); } if (!ctx->job_desc().job_conf().train_conf().has_dynamic_loss_scale_policy()) { return Maybe::Ok(); } bool all_same_parallel_desc = true; const ParallelDesc& any_parallel_desc = op_graph.OpNode4OpName(lbi2diff_lbi.begin()->first.op_name())->parallel_desc(); std::vector partial_count_not_finite_lbns; std::vector is_broadcast_nd_sbp; std::vector param_group_parallel_confs; ForEachAggregatedParamGroup( op_graph, lbi2diff_lbi, [&](const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, const std::vector& lbis) { if (!parallel_desc.EqualsIgnoringHierarchy(any_parallel_desc)) { all_same_parallel_desc = false; } const int64_t scope_symbol_id = MakeScopeSymbolId(job_builder->job().job_conf(), parallel_desc.parallel_conf()); is_broadcast_nd_sbp.emplace_back(IsBroadcast(nd_sbp, parallel_desc)); param_group_parallel_confs.emplace_back(parallel_desc.parallel_conf()); if 
(job_builder->job().job_conf().enable_gradients_stats_aggregation()) { auto multi_count_not_finite_op_builder = user_op::UserOpConfWrapperBuilder("System-DynamicLossScale-MultiCountNotFinite-" + NewUniqueId()) .Op("multi_count_not_finite") .Output("y") .ScopeSymbolId(scope_symbol_id); for (const auto& lbi : lbis) { multi_count_not_finite_op_builder.Input("x", GenLogicalBlobName(lbi2diff_lbi.at(lbi))); } const auto multi_count_not_finite_op = multi_count_not_finite_op_builder.Build(); job_builder->AddOps(parallel_desc.parallel_conf(), {multi_count_not_finite_op.op_conf()}); partial_count_not_finite_lbns.emplace_back(multi_count_not_finite_op.output("y", 0)); } else { std::vector lbns_to_add; for (const auto& lbi : lbis) { const auto count_not_finite_op = user_op::UserOpConfWrapperBuilder("System-DynamicLossScale-CountNotFinite-" + NewUniqueId()) .Op("count_not_finite") .Input("x", GenLogicalBlobName(lbi2diff_lbi.at(lbi))) .Output("y") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_desc.parallel_conf(), {count_not_finite_op.op_conf()}); lbns_to_add.emplace_back(count_not_finite_op.output("y", 0)); } partial_count_not_finite_lbns.emplace_back( AddLbns(job_builder, lbns_to_add, parallel_desc.parallel_conf(), scope_symbol_id, "System-DynamicLossScale-CountNotFinite-Add-")); } }); const bool all_group_broadcast = std::all_of(is_broadcast_nd_sbp.begin(), is_broadcast_nd_sbp.end(), [](bool i) { return i; }); std::vector count_not_finite_lbns_for_add; ParallelConf count_all_parallel_conf = all_same_parallel_desc ? any_parallel_desc.parallel_conf() : GenParallelConfOfCpuZeroOnMaster(); if (!all_group_broadcast) { for (int64_t i = 0; i < partial_count_not_finite_lbns.size(); ++i) { count_not_finite_lbns_for_add.emplace_back(AddParallelCast( job_builder, JUST(VectorAt(partial_count_not_finite_lbns, i)), "P", JUST(VectorAt(param_group_parallel_confs, i)), "System-DynamicLossScale-ParallelCast-")); } count_all_parallel_conf.mutable_hierarchy()->clear_dim(); } else { count_not_finite_lbns_for_add = std::move(partial_count_not_finite_lbns); } const int64_t scope_symbol_id = MakeScopeSymbolId(job_builder->job().job_conf(), count_all_parallel_conf); std::string count_all_lbn = AddLbns(job_builder, count_not_finite_lbns_for_add, count_all_parallel_conf, scope_symbol_id, "System-DynamicLossScale-CountNotFinite-Add-"); if (!all_group_broadcast) { std::vector cast_nd_sbp; cast_nd_sbp.emplace_back("B"); auto parallel_cast_op = user_op::UserOpConfWrapperBuilder( "System-DynamicLossScale-CountNotFinite-After-Add-ParallelCast-" + NewUniqueId()) .Op("hierarchical_parallel_cast") .Input("in", count_all_lbn) .Output("out") .Attr>("nd_sbp", cast_nd_sbp) .Attr("grad_mode", "auto") .Attr>("grad_nd_sbp", std::vector()) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(count_all_parallel_conf, {parallel_cast_op.op_conf()}); count_all_lbn = parallel_cast_op.output("out", 0); } const LogicalBlobId count_not_finite_lbi = GenLogicalBlobId(JUST(ctx->GetState("dynamic_loss_scale_state")) .count_not_finite_lbn()); auto count_not_finite_op = user_op::UserOpConfWrapperBuilder(count_not_finite_lbi.op_name()) .Op("identity") .Input("in", count_all_lbn) .Output("out") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->MutOpsOnlyOnce({count_not_finite_op.op_conf()}); job_builder->MutParallelConfOnlyOnce(count_not_finite_op.op_name(), count_all_parallel_conf); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/autograd.h 
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_JOB_REWRITER_AUTOGRAD_H_
#define ONEFLOW_CORE_JOB_REWRITER_AUTOGRAD_H_

#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/op_graph.h"

namespace oneflow {

class JobPassCtx;

void AddDiffHalf2FloatCast(const OpGraph& op_graph, JobBuilder* job_builder,
                           HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);

void AddDiffParallelCast(const OpGraph& op_graph, JobBuilder* job_builder,
                         HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);

void AddDiffStaticShapeCast(const OpGraph& op_graph, JobBuilder* job_builder,
                            HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);

Maybe<void> CountNotFiniteIfNeeded(JobPassCtx* ctx, const OpGraph& op_graph,
                                   JobBuilder* job_builder,
                                   const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi);

Maybe<void> MakeGetterLossOpNode4OpName(
    const OpGraph& op_graph, std::function<OpNode*(const std::string&)>* LossOpNode4OpName);

Maybe<void> ScaleModelDiffByLossInstanceNum(const OpGraph& op_graph, JobBuilder* job_builder,
                                            HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);

Maybe<void> ScaleInitialDiffByLossScale(
    JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,
    HashMap<LogicalBlobId, LogicalBlobId>* loss_lbi2initial_diff_lbi);

void ScaleModelDiffByLossScale(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,
                               HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);

void RegularizeGradient(const OpGraph& op_graph, JobBuilder* job_builder,
                        HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);

void ClipGradient(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,
                  HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi, const ClipConf& clip_conf);

}  // namespace oneflow

#endif  // ONEFLOW_CORE_JOB_REWRITER_AUTOGRAD_H_


================================================
FILE: oneflow/core/job_rewriter/autotick.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/job_rewriter/autotick.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/critical_section_desc.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { std::unique_ptr NewMutOpConTickInputHelper(const OperatorConf& op_conf) { std::unique_ptr ret; if (IsClassRegistered(op_conf.op_type_case())) { ret.reset(NewObj(op_conf.op_type_case())); ret->InitFromOpConf(op_conf); } return ret; } void PrependTickByParallelDesc(const OpGraph& op_graph, JobBuilder* job_builder) { HashMap> parallel_desc2op_node; op_graph.ForEachNode([&](OpNode* op_node) { auto mut_tick_input_helper = NewMutOpConTickInputHelper(op_node->op().op_conf()); if (!mut_tick_input_helper) { return; } if (mut_tick_input_helper->IsTickInputBound() == true) { return; } parallel_desc2op_node[op_node->parallel_desc()].emplace_back(op_node); }); for (const auto& pair : parallel_desc2op_node) { OperatorConf device_tick_op; device_tick_op.set_name("System-AutoTick-Prepend-DeviceTick_" + NewUniqueId()); auto* device_tick_op_conf = device_tick_op.mutable_device_tick_conf(); device_tick_op_conf->set_out("out"); job_builder->AddOps(pair.first.parallel_conf(), {device_tick_op}); for (const auto* op_node : pair.second) { auto mut_tick_input_helper = NewMutOpConTickInputHelper(op_node->op().op_conf()); job_builder->MutOpsOnlyOnce( {mut_tick_input_helper->NewTickInputBoundOpConf(device_tick_op.name() + "/out")}); } } } Maybe FindJobSoleSrcSubsetTickOpConf(const Job& job) { const OperatorConf* src_subset_tick_op_conf = nullptr; for (const auto& op_conf : job.net().op()) { if (!op_conf.has_src_subset_tick_conf()) { continue; } CHECK_ISNULL_OR_RETURN(src_subset_tick_op_conf); src_subset_tick_op_conf = &op_conf; } CHECK_NOTNULL_OR_RETURN(src_subset_tick_op_conf); return *src_subset_tick_op_conf; } Maybe BuildDstSubsetTickOpAndParallelConf(const HashSet& tick_lbis, OperatorConf* dst_subset_tick_op, JobBuilder* job_builder) { dst_subset_tick_op->set_name("System-AutoTick-DstSubsetTick_" + NewUniqueId()); auto* dst_subset_tick_op_conf = dst_subset_tick_op->mutable_dst_subset_tick_conf(); dst_subset_tick_op_conf->set_out("out"); for (const LogicalBlobId& tick_lbi : tick_lbis) { dst_subset_tick_op_conf->add_in(GenLogicalBlobName(tick_lbi)); } ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); for (int64_t machine_id : Singleton::Get()->process_ranks()) { parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0"); } JUST(job_builder->AddOp(parallel_conf, *dst_subset_tick_op)); return Maybe::Ok(); } Maybe CreateDstSubsetTickAndSinkTicks( const OperatorConf& src_subset_tick, const HashSet& tick_lbis, JobBuilder* job_builder, const std::function(int64_t machine_id, const std::string& op_name)>& DoEachSink) { OperatorConf dst_subset_tick; dst_subset_tick.mutable_dst_subset_tick_conf()->add_in( src_subset_tick.name() + "/" + src_subset_tick.src_subset_tick_conf().out()); JUST(BuildDstSubsetTickOpAndParallelConf(tick_lbis, &dst_subset_tick, job_builder)); const auto& process_ranks = Singleton::Get()->process_ranks(); HashMap machine_id2gather_tick_in_lbns; for (int64_t 
machine_id : process_ranks) { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0"); OperatorConf tick_op; { tick_op.set_name("System-AutoTick-Tick_" + NewUniqueId()); auto* tick_conf = tick_op.mutable_tick_conf(); tick_conf->add_tick(dst_subset_tick.name() + "/" + dst_subset_tick.dst_subset_tick_conf().out()); tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, tick_op)); } CHECK_OR_RETURN( machine_id2gather_tick_in_lbns.emplace(machine_id, tick_op.name() + "/out").second); } for (int64_t machine_id : process_ranks) { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0"); OperatorConf tick_op; { tick_op.set_name("System-SyncAllRanksSinkTick_" + NewUniqueId()); auto* tick_conf = tick_op.mutable_tick_conf(); // gather ticks from all processes. for (int64_t tick_machine_id : process_ranks) { tick_conf->add_tick(JUST(MapAt(machine_id2gather_tick_in_lbns, tick_machine_id))); } tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, tick_op)); } OperatorConf sink_tick_op; { sink_tick_op.set_name("System-AutoTick-SinkTick_" + NewUniqueId()); auto* sink_tick_conf = sink_tick_op.mutable_sink_tick_conf(); sink_tick_conf->add_tick(tick_op.name() + "/out"); sink_tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, sink_tick_op)); } JUST(DoEachSink(machine_id, sink_tick_op.name())); } return Maybe::Ok(); } Maybe CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section, const OperatorConf& src_subset_tick, const HashSet& tick_lbis, JobBuilder* job_builder) { auto* map = critical_section->mutable_machine_id2sink_tick_op_name(); const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe { (*map)[machine_id] = op_name; return Maybe::Ok(); }; JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink)); return Maybe::Ok(); } Maybe BuildSrcSubsetTickOpAndParallelConf(OperatorConf* src_subset_tick_op, JobBuilder* job_builder) { src_subset_tick_op->set_name("System-AutoTick-SrcSubsetTick_" + NewUniqueId()); src_subset_tick_op->mutable_src_subset_tick_conf()->set_out("out"); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); for (int64_t machine_id : Singleton::Get()->process_ranks()) { parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0"); } JUST(job_builder->AddOp(parallel_conf, *src_subset_tick_op)); return Maybe::Ok(); } Maybe CreateSourceTicksAndSrcSubsetTick( OperatorConf* src_subset_tick_op, JobBuilder* job_builder, const std::function(int64_t machine_id, const std::string& op_name)>& DoEachSrc) { for (int64_t machine_id : Singleton::Get()->process_ranks()) { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0"); OperatorConf src_tick_op; { src_tick_op.set_name("System-AutoTick-SourceTick_" + NewUniqueId()); src_tick_op.mutable_source_tick_conf()->set_out("out"); JUST(job_builder->AddOp(parallel_conf, src_tick_op)); } JUST(DoEachSrc(machine_id, src_tick_op.name())); OperatorConf tick_op; { tick_op.set_name("System-AutoTick-Tick_" + NewUniqueId()); tick_op.mutable_tick_conf()->add_tick(src_tick_op.name() + "/out"); tick_op.mutable_tick_conf()->set_out("out"); JUST(job_builder->AddOp(parallel_conf, tick_op)); } 
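    // Feed this rank's tick output into the shared src_subset_tick op, so the subset tick
    // fires only after every rank's source tick has run.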
src_subset_tick_op->mutable_src_subset_tick_conf()->add_in(tick_op.name() + "/out"); } JUST(job_builder->MutOpOnlyOnce(*src_subset_tick_op)); return Maybe::Ok(); } Maybe CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section, OperatorConf* src_subset_tick_op, JobBuilder* job_builder) { auto* map = critical_section->mutable_machine_id2source_tick_op_name(); const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe { (*map)[machine_id] = op_name; return Maybe::Ok(); }; JUST(CreateSourceTicksAndSrcSubsetTick(src_subset_tick_op, job_builder, DoEachSrc)); return Maybe::Ok(); } Maybe ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick_op, JobBuilder* job_builder) { CHECK_OR_RETURN(src_subset_tick_op.has_src_subset_tick_conf()); const std::string& src_lbn = src_subset_tick_op.name() + "/" + src_subset_tick_op.src_subset_tick_conf().out(); JUST(job_builder->ForEachOperator([&](const Operator& op) -> Maybe { if (op.op_name() != src_subset_tick_op.name()) { CHECK_OR_RETURN(!op.op_conf().has_src_subset_tick_conf()); } auto mut_helper = NewMutOpConTickInputHelper(op.op_conf()); if (!mut_helper) { return Maybe::Ok(); } if (mut_helper->IsTickInputBound() == true) { return Maybe::Ok(); } JUST(job_builder->MutOpOnlyOnce(mut_helper->NewTickInputBoundOpConf(src_lbn))); return Maybe::Ok(); })); return Maybe::Ok(); } Maybe GetSrcSubsetTickOpNode(const OpGraph& op_graph) { const OpNode* src_subset_tick = nullptr; JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { if (op_node->op().op_conf().has_src_subset_tick_conf()) { CHECK_ISNULL_OR_RETURN(src_subset_tick); src_subset_tick = op_node; } return Maybe::Ok(); })); CHECK_NOTNULL_OR_RETURN(src_subset_tick); return src_subset_tick; } OperatorConf MakeTickOpConf(const std::string& tick_name) { OperatorConf tick_op_conf; tick_op_conf.set_name(std::string("System-AutoTick-" + tick_name + "Tick_") + NewUniqueId()); auto* tick_conf = tick_op_conf.mutable_tick_conf(); tick_conf->set_out("out"); return tick_op_conf; } OperatorConf MakeDeviceTickOpConf(const std::string& tick_name) { OperatorConf device_tick_op_conf; device_tick_op_conf.set_name(std::string("System-AutoTick-" + tick_name + "DeviceTick_") + NewUniqueId()); auto* tick_conf = device_tick_op_conf.mutable_device_tick_conf(); tick_conf->set_out("out"); return device_tick_op_conf; } OperatorConf AppendTick(const std::string tick_name, const std::vector& op_names, const std::shared_ptr& time_shape, ParallelConf parallel_conf, JobBuilder* job_builder) { OperatorConf device_tick_op_conf = MakeDeviceTickOpConf(tick_name); if (time_shape) { time_shape->ToProto(device_tick_op_conf.mutable_device_tick_conf()->mutable_time_shape()); } for (const auto& op_name : op_names) { device_tick_op_conf.add_ctrl_in_op_name(op_name); } job_builder->AddOps(parallel_conf, {device_tick_op_conf}); return device_tick_op_conf; } OperatorConf AppendTick(const std::string tick_name, const std::list& op_nodes, const std::shared_ptr& time_shape, JobBuilder* job_builder) { std::vector op_names; op_names.reserve(op_nodes.size()); for (const auto* op_node : op_nodes) { CHECK(op_nodes.front()->parallel_desc() == op_node->parallel_desc()); op_names.emplace_back(op_node->op().op_name()); } return AppendTick(tick_name, op_names, time_shape, op_nodes.front()->parallel_desc().parallel_conf(), job_builder); } OperatorConf PrependTick(const HashSet& op_nodes, JobBuilder* job_builder) { CHECK_GE(op_nodes.size(), 1); OperatorConf tick_op_conf = MakeTickOpConf("Prepend"); std::vector 
op_confs; op_confs.reserve(op_nodes.size()); for (const OpNode* op_node : op_nodes) { OperatorConf op_conf(op_node->op().op_conf()); op_conf.add_ctrl_in_op_name(tick_op_conf.name()); op_confs.emplace_back(op_conf); } job_builder->MutOpsOnlyOnce({op_confs}); ParallelDesc pd((*op_nodes.begin())->parallel_desc()); pd.set_device_type(DeviceType::kCPU); job_builder->AddOps(pd.parallel_conf(), {tick_op_conf}); return tick_op_conf; } OperatorConf AppendAccTick(const Shape& src_shape, const std::list& op_nodes, JobBuilder* job_builder) { std::shared_ptr tick_shape = CHECK_JUST(op_nodes.front()->op().GetOpTimeShape()); CHECK_EQ(tick_shape->elem_cnt() % src_shape.elem_cnt(), 0); const OperatorConf& tick_op_conf = AppendTick("AppendAcc", op_nodes, tick_shape, job_builder); OperatorConf acc_op_conf; { acc_op_conf.set_name(std::string("System-AutoTick-AccTick_") + NewUniqueId()); auto* acc_conf = acc_op_conf.mutable_acc_tick_conf(); CHECK(tick_op_conf.has_device_tick_conf()); acc_conf->set_one(tick_op_conf.name() + "/" + tick_op_conf.device_tick_conf().out()); acc_conf->set_acc("acc"); acc_conf->set_max_acc_num(tick_shape->elem_cnt() / src_shape.elem_cnt()); } OperatorConf last_device_tick_op_conf; { last_device_tick_op_conf.set_name(std::string("System-AutoTick-Tick_") + NewUniqueId()); auto* device_tick_conf = last_device_tick_op_conf.mutable_device_tick_conf(); device_tick_conf->add_tick(acc_op_conf.name() + "/acc"); device_tick_conf->set_out("out"); } job_builder->AddOps(op_nodes.front()->parallel_desc().parallel_conf(), {acc_op_conf, last_device_tick_op_conf}); return last_device_tick_op_conf; } std::vector GetOpNames(const HashSet& op_nodes) { std::vector ret; ret.reserve(op_nodes.size()); for (const OpNode* op_node : op_nodes) { ret.emplace_back(op_node->op().op_name()); } return ret; }; Maybe InitOpTypeCase2OpNodes( const OpGraph& op_graph, HashMap>* op_type_case2op_nodes) { JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { const auto& op_conf = op_node->op().op_conf(); if (IsInterfaceOpConf(op_conf)) { CHECK_OR_RETURN((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second); } return Maybe::Ok(); })); return Maybe::Ok(); } Maybe ForEachInputCriticalSectionOpNodes( const OpGraph& op_graph, const std::function(const HashSet&, const std::vector&)>& Handler) { HashMap> op_type_case2op_nodes; JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes)); OperatorConf::OpTypeCase op_type_case = OperatorConf::kInputConf; if (op_type_case2op_nodes[op_type_case].empty()) { return Maybe::Ok(); } HashSet op_nodes = op_type_case2op_nodes[op_type_case]; for (const OpNode* op_node : op_type_case2op_nodes[op_type_case]) { op_node->ForEachNodeOnOutEdge([&](OpNode* out_node) { op_nodes.insert(out_node); }); } JUST(Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case]))); return Maybe::Ok(); } Maybe ForEachOutputCriticalSectionOpNodes( const OpGraph& op_graph, const std::function(const HashSet&, const std::vector&)>& Handler) { HashMap> op_type_case2op_nodes; JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes)); if (op_type_case2op_nodes[OperatorConf::kReturnConf].empty() == false) { JUST(Handler(op_type_case2op_nodes[OperatorConf::kReturnConf], GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf]))); } if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) { JUST(Handler(op_type_case2op_nodes[OperatorConf::kOutputConf], GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf]))); } return Maybe::Ok(); } Maybe> 
AddTickForTimeShape(const Shape& src_time_shape, const HashSet& op_nodes, JobBuilder* job_builder) { HashMap>, std::list> pd7ts2op_nodes; for (const OpNode* op_node : op_nodes) { auto ts = std::make_pair(*JUST(op_node->op().GetInputOutputFastestTimeShape()), *JUST(op_node->op().GetOpTimeShape())); pd7ts2op_nodes[{op_node->parallel_desc(), ts}].emplace_back(op_node); } std::vector op_confs; op_confs.reserve(pd7ts2op_nodes.size()); for (const auto& pair : pd7ts2op_nodes) { const std::pair& ts = pair.first.second; if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) { CHECK_GE_OR_RETURN(ts.first.elem_cnt(), ts.second.elem_cnt()); op_confs.emplace_back( AppendTick("Append", pair.second, std::make_shared(ts.second), job_builder)); } else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) { op_confs.emplace_back(AppendAccTick(src_time_shape, pair.second, job_builder)); } else { UNIMPLEMENTED_THEN_RETURN(); } } return op_confs; } Maybe AddGlobalInputOutputCriticalSection( const HashSet& op_nodes, const std::vector& lbi_producer_op_names, JobBuilder* job_builder) { auto* critical_section = Singleton::Get()->AddCriticalSection(GlobalJobDesc().job_id()); { auto* io_cs = critical_section->mutable_input_output_critical_section(); *io_cs->mutable_lbi_producer_op_name() = {lbi_producer_op_names.begin(), lbi_producer_op_names.end()}; } auto time_shape = std::make_unique(DimVector{1, 1}); HashMap> parallel_desc2op_nodes; for (const OpNode* op_node : op_nodes) { CHECK_OR_RETURN(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second); } std::vector source_ticks; std::vector sink_ticks; source_ticks.reserve(parallel_desc2op_nodes.size()); for (const auto& pair : parallel_desc2op_nodes) { source_ticks.emplace_back(PrependTick(pair.second, job_builder)); const auto& ops = JUST(AddTickForTimeShape(*time_shape, pair.second, job_builder)); for (const auto& sink_tick : *ops) { sink_ticks.emplace_back(sink_tick); } } OperatorConf src_subset_tick_op; { CHECK_EQ_OR_RETURN(source_ticks.empty(), false); JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder)); JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder)); for (auto& op_conf : source_ticks) { op_conf.mutable_tick_conf()->add_tick(src_subset_tick_op.name() + "/" + src_subset_tick_op.src_subset_tick_conf().out()); } job_builder->MutOpsOnlyOnce(source_ticks); } HashSet tick_lbis; for (const auto& op_conf : sink_ticks) { LogicalBlobId lbi; lbi.set_op_name(op_conf.name()); CHECK_OR_RETURN(op_conf.has_device_tick_conf()); lbi.set_blob_name(op_conf.device_tick_conf().out()); CHECK_OR_RETURN(tick_lbis.insert(lbi).second); } JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis, job_builder)); return Maybe::Ok(); } Maybe MultiClientAddOneWaitAndSendIdsOp(JobBuilder* job_builder, int64_t machine_id, const OperatorConf& src_op_consumer) { ParallelConf parallel_conf; { parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0"); } // add wait_and_send_ids op conf OperatorConf wait_and_send_ids_op_conf; { wait_and_send_ids_op_conf.set_name(std::string("System-Src-WaitAndSendIds_") + NewUniqueId()); wait_and_send_ids_op_conf.set_pass_tag(kMainOp); auto* wait_and_send_ids_conf = wait_and_send_ids_op_conf.mutable_wait_and_send_ids_conf(); wait_and_send_ids_conf->set_out("out"); wait_and_send_ids_conf->set_wait_buffer_name("UnimplementedBufferName"); wait_and_send_ids_conf->set_data_type(DataType::kInt32); // 
wait_and_send_ids_conf->id_list() is unused in multi-client mode. } JUST(job_builder->AddOp(parallel_conf, wait_and_send_ids_op_conf)); // connect wait_and_send_ids to tick op which was connected to the src tick op OperatorConf tick_op_conf; tick_op_conf.CopyFrom(src_op_consumer); CHECK_OR_RETURN(tick_op_conf.has_tick_conf()); CHECK_EQ_OR_RETURN(tick_op_conf.tick_conf().tick_size(), 1); tick_op_conf.mutable_tick_conf()->clear_tick(); tick_op_conf.mutable_tick_conf()->add_tick( GenLogicalBlobName(wait_and_send_ids_op_conf.name(), "out")); JUST(job_builder->MutOpOnlyOnce(tick_op_conf)); return Maybe::Ok(); } Maybe MultiClientAddWaitAndSendIds( JobBuilder* job_builder, const HashMap& machine_id2src_op_name) { // Prepare the consumer tick op for each Source op HashMap src_op_name2solo_consumer_tick_op; HashSet src_op_names; for (const auto& pair : machine_id2src_op_name) { CHECK_OR_RETURN(src_op_names.insert(pair.second).second) << " duplicated src op name " << pair.second; } JUST(job_builder->ForEachOperator([&](const Operator& op) -> Maybe { // skip if the op is not a tick op if (!op.op_conf().has_tick_conf()) { return Maybe::Ok(); } for (const auto& ibn : op.input_bns()) { const auto& input_lbi = op.BnInOp2Lbi(ibn); if (src_op_names.count(input_lbi.op_name()) == 0) { continue; } auto insert_pair = src_op_name2solo_consumer_tick_op.emplace(input_lbi.op_name(), op.op_conf()); CHECK_OR_RETURN(insert_pair.second) << " Duplicated src op name " << input_lbi.op_name() << " old op " << insert_pair.first->second.DebugString() << " new op " << op.op_conf().DebugString(); } return Maybe::Ok(); })); // Replace Source op with WaitAndSendIds op for (const auto& pair : machine_id2src_op_name) { auto tick_op_iter = src_op_name2solo_consumer_tick_op.find(pair.second); CHECK_OR_RETURN(tick_op_iter != src_op_name2solo_consumer_tick_op.end()) << "Can't find consumer tick op of source op name " << pair.second << " machine id " << pair.first; JUST(MultiClientAddOneWaitAndSendIdsOp(job_builder, pair.first, tick_op_iter->second)); } // Delete Source op std::vector src_op_name_vec{src_op_names.begin(), src_op_names.end()}; job_builder->DelOps(src_op_name_vec); return Maybe::Ok(); } Maybe MultiClientAddCallbackNotifier(JobBuilder* job_builder, int64_t machine_id, const std::string& sink_op_name) { ParallelConf parallel_conf; { parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0"); } OperatorConf callback_notify_op_conf; { callback_notify_op_conf.set_name(std::string("System-Sink-CallbackNotify_") + NewUniqueId()); callback_notify_op_conf.set_pass_tag(kMainOp); auto* callback_notify_conf = callback_notify_op_conf.mutable_callback_notify_conf(); callback_notify_conf->set_in(GenLogicalBlobName(sink_op_name, "out")); // callback_notify_conf->callback_buffer_name() is unused in multi-client mode. 
} JUST(job_builder->AddOp(parallel_conf, callback_notify_op_conf)); return Maybe::Ok(); } } // namespace Maybe AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) { PrependTickByParallelDesc(op_graph, job_builder); OperatorConf src_subset_tick_op; JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder)); JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder)); return Maybe::Ok(); } Maybe AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) { const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph)); const auto& src_time_shape = *JUST(op_node->op().GetOpTimeShape()); HashSet sink_op_nodes; JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { CHECK_OR_RETURN(!op_node->op().op_conf().has_sink_tick_conf()); size_t out_cnt = 0; op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); if (out_cnt == 0) { sink_op_nodes.insert(op_node); } return Maybe::Ok(); })); JUST(AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder)); return Maybe::Ok(); } Maybe AutoSourceAndSinkTick( const OpGraph& op_graph, JobBuilder* job_builder, const std::function(int64_t machine_id, const std::string& op_name)>& DoEachSrc, const std::function(int64_t machine_id, const std::string& op_name)>& DoEachSink) { JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe { CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf()); return Maybe::Ok(); })); const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph)); const auto& src_time_shape = JUST(op_node->op().GetOpTimeShape()); HashSet tick_lbis; JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { size_t out_cnt = 0; op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); if (out_cnt > 0) { return Maybe::Ok(); } CHECK_OR_RETURN(op_node->op().op_conf().has_device_tick_conf()); CHECK_OR_RETURN(JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape->elem_cnt()); CHECK_OR_RETURN(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second); return Maybe::Ok(); })); OperatorConf src_subset_tick = JUST(FindJobSoleSrcSubsetTickOpConf(job_builder->job())); JUST(CreateSourceTicksAndSrcSubsetTick(&src_subset_tick, job_builder, DoEachSrc)); JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink)); return Maybe::Ok(); } Maybe MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job) { HashMap machine_id2src_op_name; HashMap machine_id2sink_op_name; { JobBuilder job_builder(job); const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe { CHECK_OR_RETURN(machine_id2src_op_name.emplace(machine_id, op_name).second); return Maybe::Ok(); }; const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe { CHECK_OR_RETURN(machine_id2sink_op_name.emplace(machine_id, op_name).second); return Maybe::Ok(); }; JUST(AutoSourceAndSinkTick(op_graph, &job_builder, DoEachSrc, DoEachSink)); } { JobBuilder job_builder(job); JUST(MultiClientAddWaitAndSendIds(&job_builder, machine_id2src_op_name)); for (const auto& pair : machine_id2sink_op_name) { JUST(MultiClientAddCallbackNotifier(&job_builder, pair.first, pair.second)); } } return Maybe::Ok(); } namespace { Maybe InsertCriticalSectionSrcAndDstTicks( const std::vector& interface_op_nodes, JobBuilder* job_builder, std::vector* interface_src_tick_op_names, std::vector* interface_dst_tick_lbns) { HashMap> parallel_desc2interface_op_nodes; for (const auto* op_node : interface_op_nodes) { 
parallel_desc2interface_op_nodes[op_node->parallel_desc()].push_back(op_node); } for (const auto& pair : parallel_desc2interface_op_nodes) { const auto& parallel_conf = pair.first.parallel_conf(); for (const auto* op_node : pair.second) { OperatorConf interface_op(op_node->op().op_conf()); { OperatorConf device_tick_op; device_tick_op.set_name("System-EagerCriticalSection-Interface-Begin-Tick-" + NewUniqueId()); auto* device_tick_op_conf = device_tick_op.mutable_device_tick_conf(); device_tick_op_conf->set_out("out"); interface_src_tick_op_names->push_back(device_tick_op.name()); JUST(job_builder->AddOp(parallel_conf, device_tick_op)); interface_op.add_ctrl_in_op_name(device_tick_op.name()); JUST(job_builder->MutOpOnlyOnce(interface_op)); } { OperatorConf device_tick_op; device_tick_op.set_name("System-EagerCriticalSection-Interface-End-Tick-" + NewUniqueId()); device_tick_op.add_ctrl_in_op_name(interface_op.name()); auto* device_tick_op_conf = device_tick_op.mutable_device_tick_conf(); device_tick_op_conf->set_out("out"); interface_dst_tick_lbns->push_back(device_tick_op.name() + "/out"); JUST(job_builder->AddOp(parallel_conf, device_tick_op)); } } } return Maybe::Ok(); } Maybe InsertSrcSubsetTickAndDstSubsetTick( const std::vector& interface_src_tick_op_names, const std::vector& interface_dst_tick_lbns, JobBuilder* job_builder, std::string* src_subset_tick_op_name, LogicalBlobId* dst_subset_tick_lbi) { { OperatorConf src_subset_tick; JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick, job_builder)); *src_subset_tick_op_name = src_subset_tick.name(); } for (const auto& op_name : interface_src_tick_op_names) { OperatorConf op_conf(JUST(job_builder->OpConf4OpName(op_name))); CHECK_OR_RETURN(op_conf.has_device_tick_conf()); op_conf.mutable_device_tick_conf()->add_tick(*src_subset_tick_op_name + "/out"); JUST(job_builder->MutOpOnlyOnce(op_conf)); } HashSet dst_subset_tick_input_lbis; dst_subset_tick_input_lbis.insert(GenLogicalBlobId(*src_subset_tick_op_name + "/out")); for (const auto& lbn : interface_dst_tick_lbns) { const auto& lbi = GenLogicalBlobId(lbn); CHECK_OR_RETURN(dst_subset_tick_input_lbis.insert(lbi).second); } { OperatorConf dst_subset_tick_op; JUST(BuildDstSubsetTickOpAndParallelConf(dst_subset_tick_input_lbis, &dst_subset_tick_op, job_builder)); dst_subset_tick_lbi->set_op_name(dst_subset_tick_op.name()); CHECK_OR_RETURN(dst_subset_tick_op.has_dst_subset_tick_conf()); dst_subset_tick_lbi->set_blob_name(dst_subset_tick_op.dst_subset_tick_conf().out()); } return Maybe::Ok(); } Maybe InsertCriticalSectionWaitTicks(const OpGraph& op_graph, JobBuilder* job_builder, const std::string& src_subset_tick_op_name, const std::string& wait_buffer_name) { std::vector wait_and_send_id_op_nodes; op_graph.ForEachNode([&](OpNode* op_node) { if (!op_node->op().op_conf().has_wait_and_send_ids_conf()) { return; } wait_and_send_id_op_nodes.push_back(op_node); }); CHECK_GT_OR_RETURN(wait_and_send_id_op_nodes.size(), 0); OperatorConf src_subset_tick_op(JUST(job_builder->OpConf4OpName(src_subset_tick_op_name))); CHECK_OR_RETURN(src_subset_tick_op.has_src_subset_tick_conf()); for (const OpNode* wait_and_send_id_op_node : wait_and_send_id_op_nodes) { LogicalBlobId lbi; lbi.set_op_name(wait_and_send_id_op_node->op().op_name()); lbi.set_blob_name(wait_and_send_id_op_node->op().op_conf().wait_and_send_ids_conf().out()); OperatorConf critical_section_wait_op; { critical_section_wait_op.set_name("System-EagerCriticalSection-Wait-" + NewUniqueId()); auto* conf = 
critical_section_wait_op.mutable_critical_section_wait_tick_conf(); conf->add_tick(GenLogicalBlobName(lbi)); conf->set_out("out"); conf->set_buffer_name(wait_buffer_name); } const auto& parallel_conf = wait_and_send_id_op_node->parallel_desc().parallel_conf(); JUST(job_builder->AddOp(parallel_conf, critical_section_wait_op)); src_subset_tick_op.mutable_src_subset_tick_conf()->add_in(critical_section_wait_op.name() + "/out"); } JUST(job_builder->MutOpOnlyOnce(src_subset_tick_op)); return Maybe::Ok(); } Maybe InsertCriticalSectionCallbackTicks(const OpGraph& op_graph, JobBuilder* job_builder, const LogicalBlobId& dst_subset_tick_lbi, const std::string& callback_buffer_name) { OperatorConf critical_section_callback_op; critical_section_callback_op.set_name("System-EagerCriticalSection-Callback-" + NewUniqueId()); auto* conf = critical_section_callback_op.mutable_critical_section_callback_tick_conf(); conf->add_tick(GenLogicalBlobName(dst_subset_tick_lbi)); conf->set_out("out"); conf->set_buffer_name(callback_buffer_name); const auto& op_name = dst_subset_tick_lbi.op_name(); const auto& parallel_conf = JUST(job_builder->ParallelConf4OpName(op_name)); JUST(job_builder->AddOp(parallel_conf, critical_section_callback_op)); LogicalBlobId critical_section_callback_lbi; critical_section_callback_lbi.set_op_name(critical_section_callback_op.name()); critical_section_callback_lbi.set_blob_name("out"); return critical_section_callback_lbi; } Maybe MultiClientAutoCriticalSectionTick( const OpGraph& op_graph, JobBuilder* job_builder, const std::vector& interface_op_nodes, const std::string& wait_buffer_name, const std::string& callback_buffer_name) { std::vector interface_src_tick_op_names; std::vector interface_dst_tick_lbns; JUST(InsertCriticalSectionSrcAndDstTicks(interface_op_nodes, job_builder, &interface_src_tick_op_names, &interface_dst_tick_lbns)); std::string src_subset_tick_op_name; LogicalBlobId dst_subset_tick_lbi; JUST(InsertSrcSubsetTickAndDstSubsetTick(interface_src_tick_op_names, interface_dst_tick_lbns, job_builder, &src_subset_tick_op_name, &dst_subset_tick_lbi)); JUST(InsertCriticalSectionWaitTicks(op_graph, job_builder, src_subset_tick_op_name, wait_buffer_name)); const auto& lbi = JUST(InsertCriticalSectionCallbackTicks( op_graph, job_builder, dst_subset_tick_lbi, callback_buffer_name)); return lbi; } Maybe ConnectCriticalSectionCallbackToJobSoleDstSubsetTick( const OpGraph& op_graph, JobBuilder* job_builder, const std::vector>& critical_section_callback_lbis) { const OpNode* dst_subset_tick_op_node = nullptr; JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { if (!op_node->op().op_conf().has_dst_subset_tick_conf()) { return Maybe::Ok(); } CHECK_OR_RETURN(dst_subset_tick_op_node == nullptr); dst_subset_tick_op_node = op_node; return Maybe::Ok(); })); CHECK_NOTNULL_OR_RETURN(dst_subset_tick_op_node); OperatorConf dst_subset_tick_op(dst_subset_tick_op_node->op().op_conf()); auto* conf = dst_subset_tick_op.mutable_dst_subset_tick_conf(); for (const auto& lbi : critical_section_callback_lbis) { conf->add_in(GenLogicalBlobName(*lbi)); } JUST(job_builder->MutOpOnlyOnce(dst_subset_tick_op)); return Maybe::Ok(); } } // namespace Maybe MultiClientAutoInterfaceCriticalSectionTick(const OpGraph& op_graph, Job* job) { JobBuilder job_builder(job); std::vector> critical_section_callback_lbis; { std::vector interface_op_nodes; op_graph.ForEachNode([&](OpNode* node) { if (node->op().op_conf().has_input_conf()) { interface_op_nodes.push_back(node); } }); const auto& lbi = 
JUST(MultiClientAutoCriticalSectionTick( op_graph, &job_builder, interface_op_nodes, GetInputCriticalSectionWaitBufferName(job->job_conf().job_name()), GetInputCriticalSectionCallbackBufferName(job->job_conf().job_name()))); critical_section_callback_lbis.push_back(lbi); } { std::vector interface_op_nodes; op_graph.ForEachNode([&](OpNode* node) { if (node->op().op_conf().has_output_conf()) { interface_op_nodes.push_back(node); } }); const auto& lbi = JUST(MultiClientAutoCriticalSectionTick( op_graph, &job_builder, interface_op_nodes, GetOutputCriticalSectionWaitBufferName(job->job_conf().job_name()), GetOutputCriticalSectionCallbackBufferName(job->job_conf().job_name()))); critical_section_callback_lbis.push_back(lbi); } JUST(ConnectCriticalSectionCallbackToJobSoleDstSubsetTick(op_graph, &job_builder, critical_section_callback_lbis)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/autotick.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_REWRITER_AUTOTICK_H_ #define ONEFLOW_CORE_JOB_REWRITER_AUTOTICK_H_ #include "oneflow/core/job/job_desc.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/graph/op_graph.h" namespace oneflow { Maybe AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder); Maybe AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder); Maybe MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job); Maybe MultiClientAutoInterfaceCriticalSectionTick(const OpGraph& op_graph, Job* job); class MutOpConTickInputHelper { public: bool IsTickInputBound() const { return VirtualIsTickInputBound(); } virtual bool VirtualIsTickInputBound() const = 0; virtual OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const = 0; void InitFromOpConf(const OperatorConf& op_conf) { op_conf_ = &op_conf; } virtual ~MutOpConTickInputHelper() = default; protected: MutOpConTickInputHelper() : op_conf_(nullptr) {} const OperatorConf& op_conf() const { return *op_conf_; } private: const OperatorConf* op_conf_; }; #define REGISTER_AUTO_TICK(op_type_case, HelperType) \ REGISTER_CLASS(int32_t, op_type_case, MutOpConTickInputHelper, HelperType) } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_AUTOTICK_H_ ================================================ FILE: oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/boxing_with_middle_nodes.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/auto_parallel/boxing_collector.h" #include "oneflow/core/common/container_util.h" namespace oneflow { Maybe BoxingWithMiddleNodes(const OpGraph& op_graph, JobBuilder* job_builder) { // Not allowed two-step boxing and disable checking for debugging if (ParseBooleanFromEnv("ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK", false)) { return Maybe::Ok(); } // Initialize boxing collector BoxingCollector boxing_collector; std::vector middle_sbps; HashMap op_node2op_conf; // Fill other unsupported combinations op_graph.ForEachNode([&](const OpNode* node) -> Maybe { OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case(); if (IsClassRegistered(op_type_case)) { return Maybe::Ok(); } for (const std::string& ibn : node->op().input_bns()) { const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn); const OpNode& producer = node->ProducerOpNode4Lbi(lbi); const NdSbp& producer_nd_sbp = producer.NdSbp4Lbi(lbi); const NdSbp& consumer_nd_sbp = node->NdSbp4BnInOp(ibn); // If dealing with different placement if (producer.parallel_desc().parallel_num() != 1 || node->parallel_desc().parallel_num() != 1) { const auto& logical_blob_desc = producer.LogicalBlobDesc4Lbi(lbi); // Ask for middle nodes int32_t diag_node = 0; JUST(boxing_collector.AskSbpCombination(producer_nd_sbp, consumer_nd_sbp, logical_blob_desc, producer.parallel_desc(), node->parallel_desc(), /*is_customized=*/false, middle_sbps, &diag_node, /*compute_cost=*/false)); // move to the next ibn if no middle nodes needed if (middle_sbps.size() <= 0) { continue; } LogicalBlobId middle_node_lbi = lbi; VLOG(3) << " Lbi " << lbi.op_name() << "/" << lbi.blob_name() << " src sbp " << NdSbpToString(producer_nd_sbp); VLOG(3) << " Lbi " << lbi.op_name() << "/" << lbi.blob_name() << " dst sbp " << NdSbpToString(consumer_nd_sbp); for (int32_t middle_node_id = 0; middle_node_id < middle_sbps.size(); middle_node_id++) { VLOG(3) << " Lbi " << lbi.op_name() << "/" << lbi.blob_name() << " add middle node " << NdSbpToString(JUST(VectorAt(middle_sbps, middle_node_id))); // Create the middle operators OperatorConf identity_op_conf{}; identity_op_conf.set_name("System-Boxing-Middle-Identity-" + NewUniqueId()); IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf(); identity_conf->set_in(GenLogicalBlobName(middle_node_lbi)); identity_conf->set_out("out"); if (middle_node_id < diag_node) { job_builder->AddOps(producer.parallel_desc().parallel_conf(), {identity_op_conf}); } else { job_builder->AddOps(node->parallel_desc().parallel_conf(), {identity_op_conf}); } NdSbpSignature identity_nd_sbp_signature; (*identity_nd_sbp_signature.mutable_bn_in_op2nd_sbp())["in"] = middle_sbps[middle_node_id]; (*identity_nd_sbp_signature.mutable_bn_in_op2nd_sbp())["out"] = middle_sbps[middle_node_id]; job_builder->AddNdSbpSignature4OpName(identity_op_conf.name(), identity_nd_sbp_signature); // Connection for the next middle node middle_node_lbi.set_op_name(identity_op_conf.name()); middle_node_lbi.set_blob_name(identity_conf->out()); } // Replace input blob with configuration from middle nodes if (op_node2op_conf.find(node) == 
op_node2op_conf.end()) { op_node2op_conf[node] = node->op().op_conf(); } OperatorConf& consumer_op_conf = op_node2op_conf[node]; const auto& old_val = ReplaceInputLbnInOpCustomizedConf( &consumer_op_conf, ibn, GenLogicalBlobName(middle_node_lbi)); CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val); } } return Maybe::Ok(); }); for (const auto& op_node7op_conf : op_node2op_conf) { JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second)); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/boxing_with_middle_nodes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_REWRITER_BOXING_WITH_MIDDLE_NODES_H_ #define ONEFLOW_CORE_JOB_REWRITER_BOXING_WITH_MIDDLE_NODES_H_ #include "oneflow/core/graph/op_graph.h" namespace oneflow { class OpGraph; class Job; Maybe BoxingWithMiddleNodes(const OpGraph& op_graph, JobBuilder* job_builder); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_BOXING_WITH_MIDDLE_NODES_H_ ================================================ FILE: oneflow/core/job_rewriter/calculation_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/calculation_pass.h" namespace oneflow { const std::string kForwardPass = "forward_pass"; const std::string kBackwardPass = "backward_pass"; const std::string kOptimizerPass = "optimizer_pass"; } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/calculation_pass.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_ #define ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_ #include <string> namespace oneflow { extern const std::string kForwardPass; extern const std::string kBackwardPass; extern const std::string kOptimizerPass; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_ ================================================ FILE: oneflow/core/job_rewriter/checkpointing_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/common/env_var/debug_mode.h" namespace oneflow { namespace { // CheckpointingPass uses backward recomputation (activation checkpointing) to achieve sublinear memory cost. class CheckpointingPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(CheckpointingPass); CheckpointingPass() = default; ~CheckpointingPass() = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; const std::string kCheckpointingFakeOpNamePrefix = "Sys-Checkpointing-Fake-Fw-Op_"; const std::string kCheckpointingIdentityOpName = "Sys-Checkpointing-Identity"; const std::string kCheckpointingBadOpName = "Sys-CheckpointPassBadEndOpName"; const Scope& Scope4OpNode(const OpNode* op_node) { int64_t scope_symbol_id = op_node->op().op_conf().scope_symbol_id(); CHECK(Singleton>::Get()->Has(scope_symbol_id)) << "rank[" << GlobalProcessCtx::Rank() << "] " << "scope_symbol_id: " << scope_symbol_id; return Singleton>::Get()->Get(scope_symbol_id); } bool IsForwardPassScope(const Scope& scope) { return scope.scope_proto().calculation_pass_name() == kForwardPass; } bool IsForwardPass7CheckpointingScope(const Scope& scope) { return IsForwardPassScope(scope) && scope.Bool("checkpointing"); } void CollectAllCheckpointingOpsInForwardPass( const OpGraph& op_graph, HashMap* checkpointing_op_name2op_node) { // NOTE(chengcheng): // ignore batch_norm ops because recomputing bn would repeat the calculation of 'm' and 'v'. // in the future, we need to support a recomputation version of batch_norm which does NOT // update forward variables.
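// Ops whose user op_type_name is listed in the set below are never selected as checkpointing
// candidates, even if their scope has checkpointing enabled.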
HashSet ignore_op_type_names = {"normalization", "normalization_add_relu", "cudnn_fused_normalization_add_relu", "repeat", "unpack"}; op_graph.ForEachNode([&](const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); if (!op_conf.has_user_conf()) { return; } if (ignore_op_type_names.find(op_conf.user_conf().op_type_name()) != ignore_op_type_names.end()) { return; } if (IsForwardPass7CheckpointingScope(Scope4OpNode(op_node))) { CHECK(checkpointing_op_name2op_node->emplace(op_conf.name(), op_node).second); } }); } void GenConnectedCheckpointingSubgraphs( const HashMap& checkpointing_op_name2op_node, std::vector>* checkpointing_subgraphs) { HashSet visited_nodes; checkpointing_subgraphs->reserve(checkpointing_op_name2op_node.size()); for (const auto& pair : checkpointing_op_name2op_node) { const OpNode* node = pair.second; if (visited_nodes.find(node) != visited_nodes.end()) { continue; } // new subgraph checkpointing_subgraphs->emplace_back(HashSet()); CHECK(!checkpointing_subgraphs->empty()); auto& subgraph = checkpointing_subgraphs->back(); CHECK(subgraph.empty()); // bfs search all node in checkpointing ops CHECK(visited_nodes.insert(node).second); std::queue queued_nodes; queued_nodes.push(node); while (!queued_nodes.empty()) { const OpNode* cur_node = queued_nodes.front(); queued_nodes.pop(); CHECK(subgraph.insert(cur_node).second); cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) { const std::string& next_op_name = next_node->op().op_name(); if (checkpointing_op_name2op_node.find(next_op_name) != checkpointing_op_name2op_node.end() && cur_node->parallel_desc() == next_node->parallel_desc() && visited_nodes.find(next_node) == visited_nodes.end()) { queued_nodes.push(next_node); CHECK(visited_nodes.insert(next_node).second); } }); } } } Maybe CheckpointingPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { // step 1. collect all checkpointing ops in forwardpass. HashMap checkpointing_op_name2op_node; CollectAllCheckpointingOpsInForwardPass(op_graph, &checkpointing_op_name2op_node); if (checkpointing_op_name2op_node.empty()) { return Maybe::Ok(); } // step 2. get all connected subgraphs in checkpointing ops. std::vector> checkpointing_subgraphs; GenConnectedCheckpointingSubgraphs(checkpointing_op_name2op_node, &checkpointing_subgraphs); HashMap op_node2order; int32_t order = 0; op_graph.TopoForEachNode([&](const OpNode* op_node) { CHECK(op_node2order.emplace(op_node, order).second); ++order; }); // step 3. for each subgraphs: // NOTE(chengcheng): // maybe a bw consumer will consume multi subgraph for recompute. // so we need collect bw consumer between subgraphs, and update them in job builder only once. HashMap total_bw_consumers_op_name2conf; int32_t subgraph_id = 0; for (auto& subgraph : checkpointing_subgraphs) { // step 3.1 ignore this subgraph if there is no direct edge to backward pass op. 
HashSet bw_consumers; for (const OpNode* node : subgraph) { node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { if (!IsForwardPassScope(Scope4OpNode(out_node))) { bw_consumers.insert(out_node); CHECK(subgraph.find(out_node) == subgraph.end()); } }); } if (bw_consumers.empty()) { continue; } HashSet checkpointing_tensor; HashMap subgraph_op_name2op_node; ParallelConf parallel_conf; for (const OpNode* node : subgraph) { subgraph_op_name2op_node.emplace(node->op().op_name(), node); parallel_conf = node->parallel_desc().parallel_conf(); } // step 3.2 generate fake subgraph for recomputation HashMap fake_op_name2conf; HashSet source_node_in_fake_subgraph; for (const OpNode* node : subgraph) { OperatorConf fake_op_conf = node->op().op_conf(); std::string fake_op_name = kCheckpointingFakeOpNamePrefix + fake_op_conf.name(); fake_op_conf.set_name(fake_op_name); const int64_t old_scope_symbol_id = fake_op_conf.scope_symbol_id(); // update fake op conf scope from fw to bw const int64_t new_scope_symbol_id = JUST( NewScopeSymbolId(old_scope_symbol_id, [](const std::shared_ptr& new_scope) { CHECK_EQ(new_scope->calculation_pass_name(), kForwardPass); new_scope->set_calculation_pass_name(kBackwardPass); })); fake_op_conf.set_scope_symbol_id(new_scope_symbol_id); auto* user_conf = fake_op_conf.mutable_user_conf(); // change output lbns for (auto& pair : *(user_conf->mutable_output())) { auto& list_s = pair.second; for (int i = 0; i < list_s.s_size(); ++i) { std::string old_lbn = list_s.s(i); list_s.set_s(i, kCheckpointingFakeOpNamePrefix + old_lbn); // check valid LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn); CHECK_EQ(node->op().op_conf().name(), old_lbi.op_name()); CHECK_EQ(kCheckpointingFakeOpNamePrefix + old_lbi.op_name(), fake_op_name); std::string new_lbn = list_s.s(i); LogicalBlobId new_lbi = GenLogicalBlobId(new_lbn); CHECK_EQ(new_lbi.op_name(), fake_op_name); CHECK_EQ(old_lbi.blob_name(), new_lbi.blob_name()); } } int32_t input_num = 0; // change input lbns if in subgraph for (auto& pair : *(user_conf->mutable_input())) { auto& list_s = pair.second; for (int i = 0; i < list_s.s_size(); ++i) { ++input_num; std::string old_lbn = list_s.s(i); LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn); std::string old_input_op_name = old_lbi.op_name(); if (subgraph_op_name2op_node.find(old_input_op_name) != subgraph_op_name2op_node.end()) { list_s.set_s(i, kCheckpointingFakeOpNamePrefix + old_lbn); } else { source_node_in_fake_subgraph.insert(fake_op_name); checkpointing_tensor.insert(old_lbi); } } } if (input_num == 0) { source_node_in_fake_subgraph.insert(fake_op_name); } fake_op_name2conf.emplace(fake_op_name, fake_op_conf); } const OpNode* first_bw_consumer = nullptr; int32_t first_bw_order = std::numeric_limits::max(); // step 3.3 change bw consumers input from subgraph to fake subgraph for (const OpNode* node : bw_consumers) { std::string bw_consumer_name = node->op().op_name(); OperatorConf bw_consumer_op_conf; // NOTE(chengcheng): // reuse bw conumer op conf if it has been existed in map. 
if (total_bw_consumers_op_name2conf.find(bw_consumer_name) != total_bw_consumers_op_name2conf.end()) { bw_consumer_op_conf = total_bw_consumers_op_name2conf.at(bw_consumer_name); } else { bw_consumer_op_conf = node->op().op_conf(); } CHECK_EQ(bw_consumer_name, bw_consumer_op_conf.name()); auto* user_conf = bw_consumer_op_conf.mutable_user_conf(); // change input lbns if in subgraph for (auto& pair : *(user_conf->mutable_input())) { auto& list_s = pair.second; for (int i = 0; i < list_s.s_size(); ++i) { std::string old_lbn = list_s.s(i); LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn); std::string old_input_op_name = old_lbi.op_name(); if (subgraph_op_name2op_node.find(old_input_op_name) != subgraph_op_name2op_node.end()) { list_s.set_s(i, kCheckpointingFakeOpNamePrefix + old_lbn); } } } // NOTE(chengcheng): // emplace maybe repeated, so do not check the return value total_bw_consumers_op_name2conf.emplace(bw_consumer_name, bw_consumer_op_conf); CHECK(op_node2order.find(node) != op_node2order.end()); int32_t this_order = op_node2order.at(node); if (this_order < first_bw_order) { first_bw_consumer = node; first_bw_order = this_order; } } // step 3.4 add control edge from End Op to all source node in fake subgraph CHECK(first_bw_consumer != nullptr); std::string end_op_name = kCheckpointingBadOpName; int32_t end_order = -1; const OpNode* end_op_node = nullptr; first_bw_consumer->ForEachNodeOnInEdge([&](const OpNode* end_node) { CHECK(op_node2order.find(end_node) != op_node2order.end()); int32_t this_order = op_node2order.at(end_node); if (this_order > end_order) { end_order = this_order; end_op_name = end_node->op().op_name(); end_op_node = end_node; } }); CHECK_NE(end_order, -1); CHECK_NE(end_op_name, kCheckpointingBadOpName); CHECK_LT(end_order, first_bw_order); CHECK(end_op_node != nullptr); // NOTE(chengcheng): if end_op placement is different with first_bw_consumer, the ctrl edge // cannot be directly connected. if (!first_bw_consumer->parallel_desc().EqualsIgnoringHierarchy(end_op_node->parallel_desc())) { std::string lbn = ""; LogicalBlobId lbi; const OpEdge* end_op_edge = nullptr; for (const OpEdge* in_edge : first_bw_consumer->in_edges()) { if (in_edge->src_node() == end_op_node) { lbi = in_edge->lbis().front(); lbn = GenLogicalBlobName(lbi); end_op_edge = in_edge; break; } } CHECK(!lbn.empty()); auto id_op = user_op::UserOpConfWrapperBuilder(kCheckpointingIdentityOpName + NewUniqueId()) .Op("identity") .Input("in", lbn) .Output("out") .ScopeSymbolId(first_bw_consumer->op().op_conf().scope_symbol_id()) .Build(); std::string id_out = id_op.output("out", 0); for (const std::string& ibn : end_op_edge->lbi2ibns().at(lbi)) { std::string old_lbn = ReplaceInputLbnInOpCustomizedConf( &(total_bw_consumers_op_name2conf.at(first_bw_consumer->op().op_name())), ibn, id_out); CHECK_EQ(old_lbn, lbn); } JUST(job_builder->AddOp(first_bw_consumer->parallel_desc().parallel_conf(), id_op.op_conf())); end_op_name = id_op.op_name(); } for (const auto& source_op_name : source_node_in_fake_subgraph) { fake_op_name2conf.at(source_op_name).add_ctrl_in_op_name(end_op_name); } // step 3.5 add fake subgraph ops to job builder std::vector fake_op_confs; for (auto& pair : fake_op_name2conf) { fake_op_confs.emplace_back(pair.second); } job_builder->AddOps(parallel_conf, fake_op_confs); // step 3.6 log checkpointing tensor flow debug. 
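// checkpointing_tensor holds the boundary lbis produced outside this subgraph and consumed by the
// fake (recomputed) ops, i.e. the activations that must stay materialized to drive recomputation;
// the logging below is only emitted in debug mode.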
if (IsInDebugMode()) { VLOG(2) << " In subgraph: " << subgraph_id << " has checkpointing tensor num = " << checkpointing_tensor.size(); for (const auto& lbi : checkpointing_tensor) { const OpNode* node = op_graph.OpNode4OpName(lbi.op_name()); const BlobDesc& blob = node->LogicalBlobDesc4Lbi(lbi); VLOG(2) << "Checkpointing tensor: " << GenLogicalBlobName(lbi) << " ,shape: " << blob.shape().ToString() << " ,dtype: " << DataType_Name(blob.data_type()) << " ,placement: " << *JUST(PlacementToString(SymbolOf(node->parallel_desc()))) << " ,sbp: " << NdSbpToString(node->NdSbp4Lbi(lbi)); } subgraph_id++; } } // step 4. update bw consumers in job builder only once std::vector total_bw_consumer_op_confs; for (auto& pair : total_bw_consumers_op_name2conf) { total_bw_consumer_op_confs.emplace_back(pair.second); } job_builder->MutOpsOnlyOnce(total_bw_consumer_op_confs); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("CheckpointingPass", CheckpointingPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/clip_by_global_norm_job_pass_state.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_REWRITER_CLIP_BY_GLOBAL_NORM_JOB_PASS_STATE_H_ #define ONEFLOW_CORE_JOB_REWRITER_CLIP_BY_GLOBAL_NORM_JOB_PASS_STATE_H_ #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { class ClipByGlobalNormJobPassState : public JobPassState { public: OF_DISALLOW_COPY_AND_MOVE(ClipByGlobalNormJobPassState); ClipByGlobalNormJobPassState() = default; ~ClipByGlobalNormJobPassState() override = default; class TotalNormState { public: TotalNormState(const std::string& total_norm_lbn, const std::string& coeff_lbn, const ParallelConf& parallel_conf, int64_t scope_symbol_id) : total_norm_lbn_(total_norm_lbn), coeff_lbn_(coeff_lbn), parallel_conf_(parallel_conf), scope_symbol_id_(scope_symbol_id) {} void set_total_norm_lbn(const std::string& total_norm_lbn) { total_norm_lbn_ = total_norm_lbn; } const std::string& total_norm_lbn() const { return total_norm_lbn_; } const std::string& coeff_lbn() const { return coeff_lbn_; } const ParallelConf& parallel_conf() const { return parallel_conf_; } int64_t scope_symbol_id() const { return scope_symbol_id_; } private: std::string total_norm_lbn_; std::string coeff_lbn_; ParallelConf parallel_conf_; int64_t scope_symbol_id_; }; void AddTotalNormState(const std::string& variable_op_name, const std::shared_ptr& total_norm_state) { CHECK(variable_op_name2total_norm_state_.emplace(variable_op_name, total_norm_state).second) << variable_op_name; } const std::shared_ptr& GetTotalNormState(const std::string& variable_op_name) { const auto& it = variable_op_name2total_norm_state_.find(variable_op_name); CHECK(it != variable_op_name2total_norm_state_.end()); return it->second; } const bool HasTotalNormState(const std::string& variable_op_name) { const auto& it = variable_op_name2total_norm_state_.find(variable_op_name); return (it != 
variable_op_name2total_norm_state_.end()); } private: HashMap> variable_op_name2total_norm_state_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_CLIP_BY_GLOBAL_NORM_JOB_PASS_STATE_H_ ================================================ FILE: oneflow/core/job_rewriter/clone_grad.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/clone_grad.h" #include "oneflow/core/framework/framework.h" namespace oneflow { Maybe GenerateCloneGradOpIfNeed( const OpNode& op_node, JobBuilder* job_builder, const HashMap& in_oba2in_diff_lbi, HashMap* out_oba2out_diff_lbi, HashMap* out_oba2clone_bw_add_out_lbi) { HashMap out_lbi2out_oba; for (const auto& obn : op_node.op().output_bns()) { out_lbi2out_oba[op_node.op().BnInOp2Lbi(obn)] = GenOpBlobArg(op_node.op().op_name(), obn); } HashMap> out_oba2in_diff_lbis; op_node.ForEachNodeOnOutEdge([&](OpNode* out_node) { for (const auto& ibn : out_node->op().input_bns()) { const auto& oba_it = out_lbi2out_oba.find(out_node->op().BnInOp2Lbi(ibn)); if (oba_it == out_lbi2out_oba.end()) { continue; } const auto& in_diff_lbi_it = in_oba2in_diff_lbi.find(GenOpBlobArg(out_node->op().op_name(), ibn)); if (in_diff_lbi_it == in_oba2in_diff_lbi.end()) { continue; } out_oba2in_diff_lbis[oba_it->second].emplace_back(in_diff_lbi_it->second); } }); for (const auto& obn : op_node.op().output_bns()) { const OpBlobArg& oba = GenOpBlobArg(op_node.op().op_name(), obn); const LogicalBlobId& lbi = op_node.op().BnInOp2Lbi(obn); const std::vector& lbis_to_add = out_oba2in_diff_lbis[oba]; if (lbis_to_add.empty()) { continue; } else if (lbis_to_add.size() == 1) { out_oba2out_diff_lbi->emplace(oba, lbis_to_add.front()); } else { user_op::UserOpConfWrapperBuilder add_op_builder(op_node.op().op_name() + "_clone_grad_" + NewUniqueId()); add_op_builder.Op("add_n"); for (const LogicalBlobId& lbi_to_add : lbis_to_add) { add_op_builder.Input("in", GenLogicalBlobName(lbi_to_add)); } const auto& op_conf = JUST(job_builder->OpConf4OpName(lbi.op_name())); const auto add_op = add_op_builder.Output("out").ScopeSymbolId(op_conf.scope_symbol_id()).Build(); job_builder->AddOps(JUST(job_builder->ParallelConf4Lbi(lbi)), {add_op.op_conf()}); CHECK(out_oba2clone_bw_add_out_lbi->emplace(oba, lbis_to_add.front()).second); out_oba2out_diff_lbi->emplace(oba, GenLogicalBlobId(add_op.output("out", 0))); } } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/clone_grad.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_REWRITER_CLONE_GRAD_H_ #define ONEFLOW_CORE_JOB_REWRITER_CLONE_GRAD_H_ #include "oneflow/core/job_rewriter/autograd.h" namespace oneflow { Maybe GenerateCloneGradOpIfNeed( const OpNode& op_node, JobBuilder* job_builder, const HashMap& in_oba2in_diff_lbi, HashMap* out_oba2out_diff_lbi, HashMap* out_oba2clone_bw_add_out_lbi); } #endif // ONEFLOW_CORE_JOB_REWRITER_CLONE_GRAD_H_ ================================================ FILE: oneflow/core/job_rewriter/cudnn_fused_normalization_add_relu_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" #ifdef WITH_CUDA #include #endif // WITH_CUDA namespace oneflow { namespace { bool IsFusedBnAddReluSupported() { #if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401) return true; #else return false; #endif } bool IsNormalizationAddReluOp(const OperatorConf& op) { return op.has_user_conf() && (op.user_conf().op_type_name() == "normalization_add_relu" || op.user_conf().op_type_name() == "normalization_add_relu_grad"); } bool NeedDoPass(const Job& job) { return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsNormalizationAddReluOp); } } // namespace class CudnnFusedNormalizationAddReluPass final : public JobPass { public: CudnnFusedNormalizationAddReluPass() = default; ~CudnnFusedNormalizationAddReluPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { if (ctx.job_desc().job_conf().has_enable_cudnn_fused_normalization_add_relu()) { bool enabled = ctx.job_desc().job_conf().enable_cudnn_fused_normalization_add_relu(); CHECK(!enabled || IsFusedBnAddReluSupported()) << "Option 'enable_cudnn_fused_normalization_add_relu' is only supported when cuDNN " "version >= 7.4.1"; return enabled; } else { return IsFusedBnAddReluSupported(); } } Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe CudnnFusedNormalizationAddReluPass::Apply(Job* job, JobPassCtx* ctx) const { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } if (!NeedDoPass(*job)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); const DataType mixed_precision_data_type = ctx->job_desc().mixed_precision_data_type(); op_graph.ForEachNode([&](const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); if (!IsNormalizationAddReluOp(op_conf)) { return; } const std::string& op_type_name = op_conf.user_conf().op_type_name(); const user_op::UserOpConfWrapper user_op_conf(op_conf); const BlobDesc& x_desc = 
op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(user_op_conf.input("x", 0))); const int32_t axis = user_op_conf.attr("axis"); if (x_desc.data_type() != mixed_precision_data_type) { return; } const Shape& x_shape = x_desc.shape(); if (x_shape.Count(axis + 1) != 1) { return; } if (x_shape.At(axis) % 4 != 0) { return; } OperatorConf new_op_conf = op_conf; auto mute_attrs = new_op_conf.mutable_user_conf()->mutable_attr(); auto training_it = mute_attrs->find("training"); if (training_it != mute_attrs->end()) { const bool training = user_op_conf.attr("training"); if (!training) { return; } mute_attrs->erase(training_it); } new_op_conf.mutable_user_conf()->set_op_type_name("cudnn_fused_" + op_type_name); job_builder.MutOpsOnlyOnce({new_op_conf}); }); return Maybe::Ok(); } REGISTER_JOB_PASS("CudnnFusedNormalizationAddReluPass", CudnnFusedNormalizationAddReluPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/cutlass_conv_tuning_warmup_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUTLASS #include "oneflow/core/framework/to_string.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/user/kernels/cutlass_conv_tuner.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/framework/user_op_conf.h" #include namespace oneflow { namespace { constexpr size_t kMaxWorkspaceSize = 128 * 1024 * 1024; // 128MB constexpr size_t kBufferMallocAlign = 128 * 1024 * 1024; // 128MB class CutlassConvTuningWarmupPass final : public JobPass { public: CutlassConvTuningWarmupPass() = default; ~CutlassConvTuningWarmupPass() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe CutlassConvTuningWarmupPass::Apply(Job* job, JobPassCtx* ctx) const { // Compatible with typo `KERENL` if (!ParseBooleanFromEnv("ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL", ParseBooleanFromEnv("ONEFLOW_KERENL_CONV_ENABLE_CUTLASS_IMPL", false))) { return Maybe::Ok(); } if (!ParseBooleanFromEnv( "ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", ParseBooleanFromEnv("ONEFLOW_KERENL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", false))) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); auto device = Singleton::Get()->GetDevice(DeviceType::kCUDA, 0); ep::Stream* stream = device->CreateStream(); void* workspace = nullptr; char* buffer = nullptr; size_t buffer_size = 0; OF_CUDA_CHECK(cudaMalloc(&workspace, kMaxWorkspaceSize)); std::vector op_confs; op_graph.ForEachNode([&](const OpNode* node) { const OperatorConf& op_conf = node->op().op_conf(); if (!op_conf.has_user_conf()) { return; } if (op_conf.user_conf().op_type_name() != "conv2d") { return; } if (node->parallel_desc().device_type() != DeviceType::kCUDA) { return; } if (node->parallel_desc().parallel_num() != 1) { return; } if (!node->parallel_desc().containing_current_rank()) { return; } 
user_op::UserOpConfWrapper conv2d_op(op_conf); if (conv2d_op.attr("data_format") != "channels_last") { return; } if (conv2d_op.attr("groups") != 1) { return; } VLOG(3) << "Tuning " << op_conf.name(); const auto& in_desc = node->LogicalBlobDesc4Lbi(GenLogicalBlobId(conv2d_op.input("in", 0))); if (in_desc.data_type() != DataType::kFloat16) { return; } const auto& weight_desc = node->LogicalBlobDesc4Lbi(GenLogicalBlobId(conv2d_op.input("weight", 0))); const auto& out_desc = node->LogicalBlobDesc4Lbi(GenLogicalBlobId(conv2d_op.output("out", 0))); const auto& padding_before = conv2d_op.attr>("padding_before"); const auto& dilation_rate = conv2d_op.attr>("dilation_rate"); const auto& strides = conv2d_op.attr>("strides"); const int n = in_desc.shape().At(0); const int h = in_desc.shape().At(1); const int w = in_desc.shape().At(2); const int c = in_desc.shape().At(3); const int k = weight_desc.shape().At(0); const int r = weight_desc.shape().At(1); const int s = weight_desc.shape().At(2); CHECK_EQ(weight_desc.shape().At(3), c); const int p = out_desc.shape().At(1); const int q = out_desc.shape().At(2); cutlass::library::ConvFunctionalKey key( cutlass::library::Provider::kCUTLASS, cutlass::library::ConvKind::kFprop, cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC, cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC, cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC, cutlass::library::NumericTypeID::kF32, cutlass::library::NumericTypeID::kF32); const bool allow_half_accumulation = ParseBooleanFromEnv("ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION", false); if (allow_half_accumulation) { key.element_accumulator = cutlass::library::NumericTypeID::kF16; key.element_compute = cutlass::library::NumericTypeID::kF16; } const size_t x_size = GetCudaAlignedSize(in_desc.ByteSizeOfBlobBody()); const size_t w_size = GetCudaAlignedSize(weight_desc.ByteSizeOfBlobBody()); const size_t y_size = GetCudaAlignedSize(out_desc.ByteSizeOfBlobBody()); size_t bias_size = 0; if (conv2d_op.has_input("bias", 0)) { bias_size = GetCudaAlignedSize(node->LogicalBlobDesc4Lbi(GenLogicalBlobId(conv2d_op.input("bias", 0))) .ByteSizeOfBlobBody()); } const size_t total_buf_size = x_size + w_size + y_size + bias_size; if (total_buf_size > buffer_size) { size_t malloc_size = RoundUp(total_buf_size, kBufferMallocAlign); OF_CUDA_CHECK(cudaFree(buffer)); OF_CUDA_CHECK(cudaMalloc(&buffer, malloc_size)); buffer_size = malloc_size; } void* x_ptr = buffer; void* w_ptr = buffer + x_size; void* y_ptr = buffer + x_size + w_size; void* bias_ptr = nullptr; if (bias_size != 0) { bias_ptr = buffer + x_size + w_size + y_size; } cutlass::conv::Conv2dProblemSize problem_size( n, h, w, c, k, r, s, p, q, padding_before.at(0), padding_before.at(1), strides.at(0), strides.at(1), dilation_rate.at(0), dilation_rate.at(1), cutlass::conv::Mode::kCrossCorrelation); cutlass::library::Conv2dConfiguration configuraion; configuraion.split_k_mode = cutlass::conv::SplitKMode::kSerial; configuraion.problem_size = problem_size; configuraion.stride_a = {c, w * c, h * w * c}; configuraion.stride_b = {c, s * c, r * s * c}; configuraion.stride_c = {0, 0, 0}; cutlass::library::ConvArguments arguments; arguments.A = x_ptr; arguments.B = w_ptr; arguments.reordered_B = nullptr; arguments.C = bias_ptr; arguments.D = y_ptr; union SP { float f{}; half h; }; SP alpha; SP beta; if (allow_half_accumulation) { alpha.h = static_cast(1.0F); if (bias_ptr == nullptr) { beta.h = 
static_cast<half>(0.0F); } else { beta.h = static_cast<half>(1.0F); } } else { alpha.f = 1.0F; if (bias_ptr == nullptr) { beta.f = 0.0F; } else { beta.f = 1.0F; } } arguments.alpha = &alpha; arguments.beta = &beta; arguments.pointer_mode = cutlass::library::ScalarPointerMode::kHost; const cutlass::library::Operation* operation = CutlassConvTuner::Get().FindConv2dOperation( stream->As<ep::CudaStream>(), key, configuraion, arguments, workspace, kMaxWorkspaceSize); if (operation != nullptr) { VLOG(3) << "Fastest operation: " << operation->description().name; nlohmann::json tuning_cache; tuning_cache["cutlass"] = operation->description().name; OperatorConf new_op_conf = op_conf; (*(*new_op_conf.mutable_user_conf()->mutable_attr())["tuning_cache"].mutable_at_string()) = tuning_cache.dump(); op_confs.push_back(new_op_conf); } }); job_builder.MutOpsOnlyOnce(op_confs); OF_CUDA_CHECK(cudaFree(workspace)); OF_CUDA_CHECK(cudaFree(buffer)); device->DestroyStream(stream); return Maybe<void>::Ok(); } } // namespace REGISTER_JOB_PASS("CutlassConvTuningWarmupPass", CutlassConvTuningWarmupPass); } // namespace oneflow #endif // WITH_CUTLASS ================================================ FILE: oneflow/core/job_rewriter/delay_variable_op_execution_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { class DelayVariableOpExecutionPass final : public JobPass { public: DelayVariableOpExecutionPass() = default; ~DelayVariableOpExecutionPass() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe DelayVariableOpExecutionPass::Apply(Job* job, JobPassCtx* ctx) const { if (!ParseBooleanFromEnv("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", false)) { return Maybe::Ok(); } const JobConfigProto& job_conf = ctx->job_desc().job_conf(); if (job_conf.has_train_conf()) { return Maybe::Ok(); } if (job_conf.has_num_gradient_accumulation_steps() && job_conf.num_gradient_accumulation_steps() > 1) { return Maybe::Ok(); } if (GlobalProcessCtx::WorldSize() > 1) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](const OpNode* node) -> Maybe { const OperatorConf& op_conf = node->op().op_conf(); if (!op_conf.has_variable_conf()) { return Maybe::Ok(); } if (!op_conf.ctrl_in_op_name().empty()) { return Maybe::Ok(); } if (op_conf.variable_conf().has_tick()) { return Maybe::Ok(); } if (node->out_edges().size() != 1) { return Maybe::Ok(); } if (node->parallel_desc().parallel_num() != 1) { return Maybe::Ok(); } const OpNode* dst_node = (*node->out_edges().begin())->dst_node(); if (dst_node->parallel_desc() != node->parallel_desc()) { return Maybe::Ok(); } const OpEdge* none_variable_edge = nullptr; for (const OpEdge* edge : dst_node->in_edges()) { if (edge->src_node()->op().op_conf().has_variable_conf()) { continue; } if (edge->lbis().size() == 0) { continue; } if (edge->src_node()->parallel_desc() != node->parallel_desc()) { continue; } none_variable_edge = edge; break; } if (none_variable_edge == nullptr) { return Maybe::Ok(); } OperatorConf new_varibale_conf = op_conf; new_varibale_conf.mutable_variable_conf()->set_tick( GenLogicalBlobName(none_variable_edge->lbis().front())); job_builder.MutOpsOnlyOnce({new_varibale_conf}); return Maybe::Ok(); })); return Maybe::Ok(); } REGISTER_JOB_PASS("DelayVariableOpExecutionPass", DelayVariableOpExecutionPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/device_tick_autotick.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/autotick.h" namespace oneflow { namespace { class MutDeviceTickOpConTickInputHelper final : public MutOpConTickInputHelper { public: MutDeviceTickOpConTickInputHelper() : MutOpConTickInputHelper() {} bool VirtualIsTickInputBound() const override { return op_conf().device_tick_conf().tick_size() > 0; } OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override { OperatorConf ret(op_conf()); ret.mutable_device_tick_conf()->add_tick(lbn); return ret; } }; } // namespace REGISTER_AUTO_TICK(OperatorConf::kDeviceTickConf, MutDeviceTickOpConTickInputHelper); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/do_parallel_cast_before_widening_type_cast_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/pass_util.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace { class DoParallelCastBeforeWideningTypeCast final : public JobPass { public: DoParallelCastBeforeWideningTypeCast() = default; ~DoParallelCastBeforeWideningTypeCast() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().do_parallel_cast_before_widening_type_cast(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } if (GlobalProcessCtx::WorldSize() == 1) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe DoParallelCastBeforeWideningTypeCast::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { OpConfCache op_conf_cache; op_graph.ForEachNode([&op_conf_cache](OpNode* parallel_cast_node) { // find cast_fp16_to_fp32_or_double -> parallel_cast pattern const OperatorConf& parallel_cast_op_conf = op_conf_cache.GetLatest(parallel_cast_node->op().op_conf()); if (!(parallel_cast_op_conf.has_user_conf() && (parallel_cast_op_conf.user_conf().op_type_name() == "parallel_cast" || parallel_cast_op_conf.user_conf().op_type_name() == "hierarchical_parallel_cast"))) { return; } auto* cast_node = parallel_cast_node->SoleInEdge()->src_node(); if (cast_node->out_edges().size() != 1) { return; } auto cast_op_conf = op_conf_cache.GetLatest(cast_node->op().op_conf()); if (!(cast_op_conf.has_user_conf() && cast_op_conf.user_conf().op_type_name() == "cast")) { return; } user_op::UserOpConfWrapper cast_conf_wrapper(cast_op_conf); const auto cast_in_lbi = cast_node->SoleInEdge()->lbis().front(); const auto cast_in_dtype = cast_node->LogicalBlobDesc4Lbi(cast_in_lbi).data_type(); const auto cast_out_dtype = cast_conf_wrapper.attr("dtype"); if (!((cast_in_dtype == DataType::kFloat16 || cast_in_dtype == DataType::kBFloat16) && (cast_out_dtype == DataType::kFloat || 
cast_out_dtype == DataType::kDouble))) { return; } user_op::UserOpConfWrapper parallel_cast_conf_wrapper(parallel_cast_op_conf); // replace parallel_cast op input with cast op input { OperatorConf new_parallel_cast_op_conf(parallel_cast_op_conf); const auto& cast_input = cast_conf_wrapper.input("in", 0); const auto& parallel_cast_input = parallel_cast_conf_wrapper.input("in", 0); const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&new_parallel_cast_op_conf, "in_0", cast_input); CHECK_EQ(parallel_cast_input, old_val); op_conf_cache.Put(new_parallel_cast_op_conf); } // replace cast op input with parallel_cast op output { OperatorConf new_cast_op_conf(cast_op_conf); const auto& parallel_cast_output = parallel_cast_conf_wrapper.output("out", 0); const auto& cast_input = cast_conf_wrapper.input("in", 0); const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&new_cast_op_conf, "in_0", parallel_cast_output); CHECK_EQ(cast_input, old_val); op_conf_cache.Put(new_cast_op_conf); } // update all parallel_cast op consumers const std::string& cast_output = cast_conf_wrapper.output("out", 0); for (OpEdge* edge : parallel_cast_node->out_edges()) { CHECK_EQ(1, edge->lbis().size()); LogicalBlobId cur_lbi = edge->lbis().front(); const auto lbn = GenLogicalBlobName(cur_lbi); CHECK_EQ(1, edge->lbi2ibns().at(cur_lbi).size()); const std::string& dst_ibn = edge->lbi2ibns().at(cur_lbi).front(); OpNode* dst_node = edge->dst_node(); OperatorConf dst_op_conf = op_conf_cache.GetLatest(dst_node->op().op_conf()); CHECK_EQ(lbn, ReplaceInputLbnInOpCustomizedConf(&dst_op_conf, dst_ibn, cast_output)); op_conf_cache.Put(dst_op_conf); } }); job_builder->MutOpsOnlyOnce(op_conf_cache.op_confs()); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("DoParallelCastBeforeWideningTypeCast", DoParallelCastBeforeWideningTypeCast); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/dump_blob_parallel_conf_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { namespace { class DumpBlobParallelConfPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(DumpBlobParallelConfPass); DumpBlobParallelConfPass() = default; ~DumpBlobParallelConfPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return true; } Maybe Apply(const OpGraph& op_graph, Job* job) const { op_graph.DumpLogicalBlobDesc(job); op_graph.DumpArgSignature(job); op_graph.DumpNdSbpSignature(job); return Maybe::Ok(); } Maybe Apply(Job* job, JobPassCtx* ctx) const override { const OpGraph op_graph(*job); return Apply(op_graph, job); } }; REGISTER_JOB_PASS("DumpBlobParallelConfPass", DumpBlobParallelConfPass); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/dump_variable_info_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { namespace { std::string GetNdSbpString(const VariableOpConf& conf, const ParallelDesc& parallel_desc) { const bool has_nd_sbp_conf = (conf.nd_sbp_size() != 0); const int64_t num_axes = parallel_desc.hierarchy()->NumAxes(); if (has_nd_sbp_conf) { CHECK_EQ(conf.nd_sbp_size(), num_axes); } std::string nd_sbp_str; FOR_RANGE(int64_t, i, 0, num_axes) { if (has_nd_sbp_conf) { nd_sbp_str += conf.nd_sbp(i); } else { nd_sbp_str += "B"; } if (i != num_axes - 1) { nd_sbp_str += ", "; } } return nd_sbp_str; } class DumpVariableInfoPass final : public JobPass { public: DumpVariableInfoPass() = default; ~DumpVariableInfoPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return Singleton::Get()->enable_debug_mode(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe DumpVariableInfoPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { int64_t cnt = 0; const std::string sep = "\t"; auto log_stream = TeePersistentLogStream::Create("variable_table_" + std::to_string(GlobalJobDesc().job_id())); (*log_stream) << "id" << sep << "name" << sep << "device_tag" << sep << "parallel_hierarchy" << sep << "distribute" << sep << "data_type" << sep << "shape" << sep << "elem_cnt" << sep << "size" << "\n"; JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](const OpNode* node) -> Maybe { const OperatorConf& op_conf = node->op().op_conf(); if (!op_conf.has_variable_conf()) { return Maybe::Ok(); } const VariableOpConf& conf = op_conf.variable_conf(); (*log_stream) << std::to_string(cnt); (*log_stream) << sep; (*log_stream) << op_conf.name(); (*log_stream) << sep; 
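// Remaining columns written per variable: device tag, parallel hierarchy, nd-sbp (distribute),
// data type, shape, element count, and total byte size.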
(*log_stream) << op_conf.device_tag(); (*log_stream) << sep; (*log_stream) << node->parallel_desc().hierarchy()->DebugStr(); (*log_stream) << sep; (*log_stream) << GetNdSbpString(conf, node->parallel_desc()); (*log_stream) << sep; (*log_stream) << DataType_Name(conf.data_type()); (*log_stream) << sep; const Shape shape(conf.shape()); (*log_stream) << shape.ToString(); (*log_stream) << sep; (*log_stream) << std::to_string(shape.elem_cnt()); (*log_stream) << sep; (*log_stream) << std::to_string(shape.elem_cnt() * GetSizeOfDataType(conf.data_type())); (*log_stream) << "\n"; cnt += 1; return Maybe::Ok(); })); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("DumpVariableInfoPass", DumpVariableInfoPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_REWRITER_DYNAMIC_LOSS_SCALE_JOB_PASS_STATE_H_ #define ONEFLOW_CORE_JOB_REWRITER_DYNAMIC_LOSS_SCALE_JOB_PASS_STATE_H_ #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { class DynamicLossScaleJobPassState : public JobPassState { public: OF_DISALLOW_COPY_AND_MOVE(DynamicLossScaleJobPassState); DynamicLossScaleJobPassState() = default; ~DynamicLossScaleJobPassState() override = default; const std::string& count_not_finite_lbn() const { return count_not_finite_lbn_; } void set_count_not_finite_lbn(const std::string& lbn) { count_not_finite_lbn_ = lbn; } const std::string& loss_scale_val_lbn() const { return loss_scale_val_lbn_; } void set_loss_scale_val_lbn(const std::string& lbn) { loss_scale_val_lbn_ = lbn; } private: std::string count_not_finite_lbn_; std::string loss_scale_val_lbn_; }; } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_DYNAMIC_LOSS_SCALE_JOB_PASS_STATE_H_ ================================================ FILE: oneflow/core/job_rewriter/dynamic_loss_scale_schedule_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h" #include "oneflow/core/framework/scope_util.h" namespace oneflow { namespace { class DynamicLossScaleSchedulePass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(DynamicLossScaleSchedulePass); DynamicLossScaleSchedulePass() = default; ~DynamicLossScaleSchedulePass() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe DynamicLossScaleSchedulePass::Apply(Job* job, JobPassCtx* ctx) const { if (!ctx->job_desc().IsTrain()) { return Maybe::Ok(); } const TrainConf& train_conf = job->job_conf().train_conf(); if (!train_conf.has_dynamic_loss_scale_policy()) { return Maybe::Ok(); } const auto& policy = train_conf.dynamic_loss_scale_policy(); const OpGraph op_graph(*job); JobBuilder job_builder(job); const ParallelConf& parallel_conf = GenParallelConfOfCpuZeroOnMaster(); int64_t scope_symbol_id; { const auto& opt_scope_symbol_id = JUST(MakeInitialScope(job->job_conf(), SymbolOf(ParallelDesc(parallel_conf)), /* is_local */ false)) ->symbol_id(); CHECK_OR_RETURN(opt_scope_symbol_id.has_value()) << Error::RuntimeError() << "symbol_id not initialized"; scope_symbol_id = JUST(opt_scope_symbol_id); } OperatorConf loss_scale_var_op_conf{}; const std::string op_name_prefix = "System-Train-DynamicLossScale-"; { loss_scale_var_op_conf.set_name(op_name_prefix + job->job_conf().job_name() + "-LossScale"); VariableOpConf* variable_conf = loss_scale_var_op_conf.mutable_variable_conf(); variable_conf->set_out("out"); *variable_conf->mutable_shape()->mutable_dim()->Add() = 1; variable_conf->set_data_type(DataType::kFloat); variable_conf->mutable_initializer()->mutable_constant_conf()->set_value( policy.initial_loss_scale()); loss_scale_var_op_conf.set_scope_symbol_id(scope_symbol_id); } OperatorConf good_step_counter_var_conf{}; { good_step_counter_var_conf.set_name(op_name_prefix + job->job_conf().job_name() + "-GoodStepCounter"); VariableOpConf* variable_conf = good_step_counter_var_conf.mutable_variable_conf(); variable_conf->set_out("out"); *variable_conf->mutable_shape()->mutable_dim()->Add() = 1; variable_conf->set_data_type(DataType::kInt64); variable_conf->mutable_initializer()->mutable_constant_int_conf()->set_value(0); good_step_counter_var_conf.set_scope_symbol_id(scope_symbol_id); } OperatorConf loss_scale_val_op_conf{}; const std::string loss_scale_var_lbn = GenLogicalBlobName( loss_scale_var_op_conf.name(), loss_scale_var_op_conf.variable_conf().out()); { loss_scale_val_op_conf.set_name(loss_scale_var_op_conf.name() + "-Identity"); loss_scale_val_op_conf.set_scope_symbol_id(scope_symbol_id); IdentityOpConf* identity_conf = loss_scale_val_op_conf.mutable_identity_conf(); identity_conf->set_in(loss_scale_var_lbn); identity_conf->set_out("out"); } // will be replaced by real count of not finite auto count_not_finite_stub_op = user_op::UserOpConfWrapperBuilder(op_name_prefix + job->job_conf().job_name() + "-CountNotFinite") .Op("constant") .Output("out") .Attr("floating_value", 0.0) .Attr("integer_value", 0) .Attr("is_floating_value", false) .Attr("dtype", DataType::kInt64) .Attr("shape", Shape({1})) .ScopeSymbolId(scope_symbol_id) .Build(); const std::string loss_scale_val_lbn = GenLogicalBlobName( loss_scale_val_op_conf.name(), loss_scale_val_op_conf.identity_conf().out()); const std::string good_step_counter_var_lbn = GenLogicalBlobName( 
good_step_counter_var_conf.name(), good_step_counter_var_conf.variable_conf().out()); auto schedule = user_op::UserOpConfWrapperBuilder(op_name_prefix + job->job_conf().job_name() + "-Schedule") .Op("dynamic_loss_scale_schedule") .Input("count_not_finite", count_not_finite_stub_op.output("out", 0)) .Input("loss_scale", loss_scale_var_lbn) .Input("good_step_counter", good_step_counter_var_lbn) .Attr("increment_period", policy.increment_period()) .Attr("multiplier", policy.multiplier()) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder.AddOps(parallel_conf, {loss_scale_var_op_conf, loss_scale_val_op_conf, good_step_counter_var_conf, count_not_finite_stub_op.op_conf(), schedule.op_conf()}); if (!JUST(ctx->HasState("dynamic_loss_scale_state"))) { JUST(ctx->ResetState("dynamic_loss_scale_state", std::make_unique())); } auto state = JUST(ctx->MutableState("dynamic_loss_scale_state")); state->set_loss_scale_val_lbn(loss_scale_val_lbn); state->set_count_not_finite_lbn(count_not_finite_stub_op.output("out", 0)); return Maybe::Ok(); } REGISTER_JOB_PASS("DynamicLossScaleSchedulePass", DynamicLossScaleSchedulePass); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/eliminate_dead_nodes_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class EliminateDeadNodesPass final : public JobPass { public: EliminateDeadNodesPass() = default; ~EliminateDeadNodesPass() override = default; Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; static bool IsNoSideEffect(const OpNode* op_node) { static HashSet no_side_effect_ops = { "constant", "zeros_like", "ones_like", "repeat", "acc", "pack", "unpack", }; static HashSet no_side_effect_system_ops = { OperatorConf::kDeviceTickConf, }; const auto& op_conf = op_node->op().op_conf(); if (!op_conf.has_user_conf()) { return no_side_effect_system_ops.count(op_conf.op_type_case()); } return no_side_effect_ops.count(op_conf.user_conf().op_type_name()); } Maybe EliminateDeadNodesPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { HashSet delete_ops; std::vector delete_op_confs; op_graph.ReverseTopoForEachNode([&](const OpNode* op_node) { if (!IsNoSideEffect(op_node)) { return; } for (const auto* out_edge : op_node->out_edges()) { if (!delete_ops.count(out_edge->dst_node())) { return; } } VLOG(3) << "Eliminate dead node: " << op_node->op().op_name(); delete_ops.insert(op_node); delete_op_confs.emplace_back(op_node->op().op_conf()); }); job_builder->DelOps(delete_op_confs); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("EliminateDeadNodesPass", EliminateDeadNodesPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/fix_pipeline_stage_id_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/symbol_storage.h" namespace oneflow { namespace { class FixPipelineStageIdPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(FixPipelineStageIdPass); FixPipelineStageIdPass() = default; ~FixPipelineStageIdPass() = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain() && ctx.job_desc().job_conf().num_gradient_accumulation_steps() > 1; } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; const Scope& Scope4ScopeSymbolId(int64_t scope_symbol_id) { CHECK(Singleton>::Get()->Has(scope_symbol_id)); return Singleton>::Get()->Get(scope_symbol_id); } const Scope& Scope4OpNode(const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); CHECK(op_conf.has_scope_symbol_id()); return Scope4ScopeSymbolId(op_conf.scope_symbol_id()); } bool OpNodeHasScope(const OpNode* node) { return node->op().op_conf().has_scope_symbol_id(); } int64_t GetStageIdHint(const OpNode* node) { return Scope4OpNode(node).Int64("pipeline_stage_id_hint"); } std::string ParallelDesc2HashString(const ParallelDesc& parallel_desc) { std::string ret = parallel_desc.device_tag() + ",{"; for (int64_t m : parallel_desc.sorted_machine_ids()) { ret += (std::to_string(m) + ":["); for (int64_t d : parallel_desc.sorted_dev_phy_ids(m)) { ret += (std::to_string(d) + ","); } ret += "],"; } ret += "}"; return ret; } Maybe NewScopeWithStageId(int64_t old_scope_symbol_id, int64_t stage_id) { return NewScopeSymbolId( old_scope_symbol_id, [stage_id]( std::shared_ptr new_scope) { // NOLINT(performance-unnecessary-value-param) auto* attr_map = new_scope->mutable_attr_name2attr_value(); (*attr_map)["pipeline_stage_id_hint"].set_at_int64(stage_id); }); } Maybe FixPipelineStageIdPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { int64_t max_stage_id = 0; op_graph.ForEachNode([&](const OpNode* this_node) { if (!OpNodeHasScope(this_node)) { LOG(WARNING) << " op : " << this_node->op().op_conf().DebugString() << " has NOT scope!"; return; } max_stage_id = std::max(max_stage_id, GetStageIdHint(this_node)); }); if (max_stage_id == 0) { return Maybe::Ok(); } const int64_t total_stage_num = max_stage_id + 1; VLOG(3) << "total stage num = " << total_stage_num; HashMap op_name2node; HashMap> placement2op_nodes; std::vector fix_stage_op_confs; // NOTE(chengcheng): group op by placement. 
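// Ops that share a placement are then lifted to the largest pipeline_stage_id_hint observed for
// that placement: each mismatched op gets a new scope that differs only in its stage id hint.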
op_graph.ForEachNode([&](const OpNode* this_node) { if (!OpNodeHasScope(this_node)) { return; } const std::string& op_name = this_node->op().op_name(); op_name2node.emplace(op_name, this_node); std::string placement = ParallelDesc2HashString(this_node->parallel_desc()); placement2op_nodes[placement].emplace_back(this_node); }); for (auto& pair : placement2op_nodes) { int64_t max_stage_id = -1; for (const OpNode* this_node : pair.second) { max_stage_id = std::max(max_stage_id, GetStageIdHint(this_node)); } CHECK_GE_OR_RETURN(max_stage_id, 0); for (const OpNode* this_node : pair.second) { int64_t this_stage_id = GetStageIdHint(this_node); if (this_stage_id != max_stage_id) { VLOG(3) << " In FixPipelineStageIdPass, op_name: " << this_node->op().op_name() << " origin_stage_id = " << this_stage_id << " is different with same placement : " << pair.first << " max_stage_id: " << max_stage_id << " , so change this op to the max stage id.\n"; OperatorConf new_op_conf = this_node->op().op_conf(); int64_t new_scope_symbol_id = JUST(NewScopeWithStageId(new_op_conf.scope_symbol_id(), max_stage_id)); new_op_conf.set_scope_symbol_id(new_scope_symbol_id); fix_stage_op_confs.emplace_back(std::move(new_op_conf)); } } } for (const auto& op : fix_stage_op_confs) { JUST(job_builder->MutOpOnlyOnce(op)); } return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("FixPipelineStageIdPass", FixPipelineStageIdPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/ftrl_optm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/initializer_conf.pb.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/variable_op.h" namespace oneflow { namespace { std::string GenVariableOutputLbn(const OperatorConf& op_conf) { CHECK(op_conf.has_variable_conf()); return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out()); } OperatorConf GenerateFtrlHelperVariableConf(const VariableOp& op, const std::string& name, const float initial_value) { OperatorConf helper_variable_op(op.op_conf()); helper_variable_op.set_name(op.op_name() + "-" + name); helper_variable_op.mutable_variable_conf()->set_out("out"); InitializerConf constant_initializer; constant_initializer.mutable_constant_conf()->set_value(initial_value); *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer; helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id()); return helper_variable_op; } void GenerateFtrlOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); user_op::UserOpConfWrapperBuilder ftrl_update_op_builder(var_op->op_name() + "_optimizer"); float lr_power = 0.0; float initial_accumulator_value = 0.0; float lambda1 = 0.0; float lambda2 = 0.0; float beta = 0.0; const FtrlModelUpdateConf& ftrl_conf = optimizer_conf.ftrl_conf(); lr_power = ftrl_conf.lr_power(); initial_accumulator_value = ftrl_conf.initial_accumulator_value(); lambda1 = ftrl_conf.lambda1(); lambda2 = ftrl_conf.lambda2(); beta = ftrl_conf.beta(); const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn(); OperatorConf accumulator_var( GenerateFtrlHelperVariableConf(*var_op, "accumulate", initial_accumulator_value)); OperatorConf z_var(GenerateFtrlHelperVariableConf(*var_op, "z", 0.0)); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {accumulator_var, z_var}); ftrl_update_op_builder.OpTypeName("ftrl_update") .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) .Input("model_diff", model_diff_lbn) .Input("learning_rate", learning_rate_lbn) .Input("accumulate", GenVariableOutputLbn(accumulator_var)) .Input("z", GenVariableOutputLbn(z_var)) .Attr("lr_power", lr_power) .Attr("lambda1", lambda1) .Attr("lambda2", lambda2) .Attr("beta", beta) .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); if (optimizer_conf.has_lr_scale()) { ftrl_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale()); } SetDynamicLossScaleSkipIf(ctx, &ftrl_update_op_builder); const auto ftrl_update_op = ftrl_update_op_builder.Build(); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {ftrl_update_op.op_conf()}); } } // namespace REGISTER_OPTIMIZER(OptimizerConf::kFtrlConf, &GenerateFtrlOptimizerOpConf); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/hash_container.h" #include "oneflow/core/common/just.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class FuseAddToOutputPass final : public JobPass { public: FuseAddToOutputPass() = default; ~FuseAddToOutputPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().job_conf().enable_fuse_add_to_output(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { const HashMap supported_op_type_name2output_arg( {{"normalization", user_op::OpArg("y", 0)}, {"dropout", user_op::OpArg("out", 0)}, {"matmul", user_op::OpArg("out", 0)}, {"layer_norm_grad", user_op::OpArg("dx", 0)}, {"batch_matmul", user_op::OpArg("out", 0)}, {"fused_bias_add_mask_scale", user_op::OpArg("out", 0)}, {"fused_matmul_bias", user_op::OpArg("out", 0)}, {"broadcast_matmul", user_op::OpArg("out", 0)}, {"broadcast_matmul_grad_b", user_op::OpArg("out", 0)}}); HashSet consumer_op_names; auto IsAddToOutputSupported = [&](const OpNode* node, const LogicalBlobId& lbi) -> bool { const OperatorConf& op_conf = node->op().op_conf(); if (!op_conf.has_user_conf()) { return false; } if (consumer_op_names.count(op_conf.name()) > 0) { return false; } auto it = supported_op_type_name2output_arg.find(op_conf.user_conf().op_type_name()); if (it == supported_op_type_name2output_arg.end()) { return false; } const user_op::UserOpConfWrapper user_op_conf(op_conf); if (GenLogicalBlobId(user_op_conf.output(it->second.name(), it->second.index())) != lbi) { return false; } // add op should be the only consumer int64_t output_consumer_cnt = 0; for (const OpEdge* out_edge : node->out_edges()) { if (std::find(out_edge->lbis().cbegin(), out_edge->lbis().cend(), lbi) != out_edge->lbis().cend()) { output_consumer_cnt += 1; } } if (output_consumer_cnt != 1) { return false; } // already fused if (user_op_conf.has_input("_add_to_output", 0)) { return false; } return true; }; // Save all op's ctrl in op name in a set. 
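// add_n ops that carry ctrl-in edges, or that are themselves a ctrl-in of another op, are skipped
// below so the fusion never drops an ordering constraint.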
HashSet ctrl_in_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) { ctrl_in_op_names.insert(ctrl_in_op_name); } }); auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); std::vector delete_ops; HashSet be_fused_op_names; JUST(op_graph.MaybeForEachNode([&](const OpNode* op_node) -> Maybe { const OperatorConf& op_conf = op_node->op().op_conf(); if (!op_conf.has_user_conf()) { return Maybe::Ok(); } if (!op_conf.ctrl_in_op_name().empty()) { return Maybe::Ok(); } if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) { return Maybe::Ok(); } if (op_conf.user_conf().op_type_name() != "add_n") { return Maybe::Ok(); } if (be_fused_op_names.count(op_conf.name()) > 0) { return Maybe::Ok(); } if (consumer_op_names.count(op_conf.name()) > 0) { return Maybe::Ok(); } const user_op::UserOpConfWrapper user_op_conf(op_conf); if (user_op_conf.input_size("in") != 2) { return Maybe::Ok(); } const LogicalBlobId in_0 = GenLogicalBlobId(user_op_conf.input("in", 0)); const LogicalBlobId in_1 = GenLogicalBlobId(user_op_conf.input("in", 1)); const LogicalBlobId out = GenLogicalBlobId(user_op_conf.output("out", 0)); const OpNode* in_0_node = op_graph.OpNode4OpName(in_0.op_name()); const OpNode* in_1_node = op_graph.OpNode4OpName(in_1.op_name()); const OpNode* add_to_node; const LogicalBlobId* add_to_lbi; const LogicalBlobId* sum_lbi; if ((!IsReachable(in_0.op_name(), in_1.op_name())) && IsAddToOutputSupported(in_0_node, in_0)) { add_to_node = in_0_node; add_to_lbi = &in_1; sum_lbi = &in_0; be_fused_op_names.insert(in_1.op_name()); } else if ((!IsReachable(in_1.op_name(), in_0.op_name())) && IsAddToOutputSupported(in_1_node, in_1)) { add_to_node = in_1_node; add_to_lbi = &in_0; sum_lbi = &in_1; be_fused_op_names.insert(in_0.op_name()); } else { return Maybe::Ok(); } // Make a new_add_to_op to fuse add_n into this op. 
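// The other addend is appended to the fused producer's "_add_to_output" input, every consumer of
// the add_n output is rewired to the producer's own output, and the add_n op is queued for
// deletion.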
if (JUST(job_builder->IsInMutOpTransaction(add_to_node->op().op_name()))) { OperatorConf& new_add_to_op_conf = JUST(job_builder->MutOpTransactionGet(add_to_node->op().op_name())); *(*(new_add_to_op_conf.mutable_user_conf()->mutable_input()))["_add_to_output"] .mutable_s() ->Add() = GenLogicalBlobName(*add_to_lbi); } else { OperatorConf new_add_to_op_conf = add_to_node->op().op_conf(); *(*(new_add_to_op_conf.mutable_user_conf()->mutable_input()))["_add_to_output"] .mutable_s() ->Add() = GenLogicalBlobName(*add_to_lbi); JUST(job_builder->MutOpTransactionMut(new_add_to_op_conf)); } for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); const std::string& consumer_op_name = consumer->op().op_name(); if (consumer_op_names.count(consumer_op_name) == 0) { if (!JUST(job_builder->IsInMutOpTransaction(consumer->op().op_name()))) { consumer_op_names.insert(consumer_op_name); JUST(job_builder->MutOpTransactionMut(consumer->op().op_conf())); } } // Make add_n op's consumer to consume the new_add_to_op for (const std::string& ibn : consumer->op().input_bns()) { if (consumer->op().BnInOp2Lbi(ibn) == out) { OperatorConf& consumer_op_conf = JUST(job_builder->MutOpTransactionGet(consumer_op_name)); const auto& new_val = GenLogicalBlobName(*sum_lbi); const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val); CHECK_EQ(GenLogicalBlobName(out), old_val); } } } // Add the add_n op to removing list delete_ops.emplace_back(op_conf); return Maybe::Ok(); })); JUST(job_builder->MutOpTransactionCommit()); job_builder->DelOps(delete_ops); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("FuseAddToOutputPass", FuseAddToOutputPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/fuse_bce_reduce_mean_fw_bw_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { void UpdateConsumerOpConf(const OpNode* consumer, const LogicalBlobId& out, const std::string& new_out_lbn, HashMap* op_name2op_conf) { const std::string& consumer_op_name = consumer->op().op_name(); if (op_name2op_conf->find(consumer_op_name) == op_name2op_conf->end()) { (*op_name2op_conf)[consumer_op_name] = consumer->op().op_conf(); } for (const std::string& ibn : consumer->op().input_bns()) { if (consumer->op().BnInOp2Lbi(ibn) == out) { OperatorConf& consumer_op_conf = op_name2op_conf->at(consumer_op_name); const auto& new_val = new_out_lbn; const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val); CHECK_EQ(GenLogicalBlobName(out), old_val); } } } class FuseBCEReduceMeanFwBwPass final : public JobPass { public: FuseBCEReduceMeanFwBwPass() = default; ~FuseBCEReduceMeanFwBwPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ParseBooleanFromEnv("ONEFLOW_FUSE_BCE_REDUCE_MEAN_FW_BW", false); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe FuseBCEReduceMeanFwBwPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { // This pass fuse binary_cross_entropy_with_logits_reduce_mean and // binary_cross_entropy_with_logits_reduce_mean_grad. delete the h2f cast to loss, and the // constant_like of dy. const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph); HashMap op_name2op_conf; std::vector delete_ops; op_graph.ForEachNode([&](const OpNode* op_node) { if (!IsUserOpWithTypeName(op_node->op().op_conf(), "binary_cross_entropy_with_logits_reduce_mean")) { return; } if (op_node->out_edges().size() > 2) { return; } bool find_grad_op = false; for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); if (!IsSafeToDelete(consumer)) { return; } if (!(IsUserOpWithTypeName(consumer->op().op_conf(), "cast") || consumer->op().op_conf().has_constant_like_conf() || consumer->op().op_conf().has_output_conf())) { return; } if (consumer->op().op_conf().has_constant_like_conf()) { const OpNode* grad_node = consumer->SoleOutEdge()->dst_node(); if (!IsUserOpWithTypeName(grad_node->op().op_conf(), "binary_cross_entropy_with_logits_reduce_mean_grad")) { return; } find_grad_op = true; if (!IsSafeToDelete(grad_node)) { return; } } } if (!find_grad_op) { return; } const user_op::UserOpConfWrapper bce_op_conf(op_node->op().op_conf()); user_op::UserOpConfWrapperBuilder fused_op_builder(bce_op_conf.op_name()); fused_op_builder.OpTypeName("fused_bce_reduce_mean_fw_bw") .Input("input", bce_op_conf.input("input", 0)) .Input("target", bce_op_conf.input("target", 0)) .Output("out") .Output("dx"); for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); if (IsUserOpWithTypeName(consumer->op().op_conf(), "cast")) { const user_op::UserOpConfWrapper cast_conf(consumer->op().op_conf()); fused_op_builder.Attr("out_dtype", cast_conf.attr("dtype")); // delete cast and update cast consumer's in. 
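// Consumers of the deleted cast are redirected to the fused op's "out_0" output; the
// constant_like/grad branch below is redirected to "dx_0" in the same way.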
delete_ops.push_back(consumer->op().op_conf()); for (const OpEdge* cast_out_edge : consumer->out_edges()) { const OpNode* cast_consumer = cast_out_edge->dst_node(); UpdateConsumerOpConf(cast_consumer, GenLogicalBlobId(cast_conf.output("out", 0)), GenLogicalBlobName(bce_op_conf.op_name(), "out_0"), &op_name2op_conf); } } else if (consumer->op().op_conf().has_constant_like_conf()) { fused_op_builder.Attr( "constant_value", consumer->op().op_conf().constant_like_conf().float_operand()); const OpNode* grad_node = consumer->SoleOutEdge()->dst_node(); // delete constant_like and grad op, update consumer delete_ops.push_back(grad_node->op().op_conf()); delete_ops.push_back(consumer->op().op_conf()); const user_op::UserOpConfWrapper grad_conf(grad_node->op().op_conf()); for (const OpEdge* grad_out_edge : grad_node->out_edges()) { const OpNode* grad_consumer = grad_out_edge->dst_node(); UpdateConsumerOpConf(grad_consumer, GenLogicalBlobId(grad_conf.output("dx", 0)), GenLogicalBlobName(bce_op_conf.op_name(), "dx_0"), &op_name2op_conf); } } else { continue; } } user_op::UserOpConfWrapper fused_op = fused_op_builder.ScopeSymbolId(bce_op_conf.op_conf().scope_symbol_id()).Build(); job_builder->MutOpsOnlyOnce({fused_op.op_conf()}); }); job_builder->DelOps(delete_ops); for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); } return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("FuseBCEReduceMeanFwBwPass", FuseBCEReduceMeanFwBwPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class FuseCastScalePass final : public JobPass { public: FuseCastScalePass() = default; ~FuseCastScalePass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().job_conf().enable_fuse_cast_scale(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph); std::vector delete_ops; op_graph.ForEachNode([&](const OpNode* op_node) { if (!IsUserOpWithTypeName(op_node->op().op_conf(), "cast")) { return; } if (!IsSafeToDelete(op_node)) { return; } if (op_node->out_edges().size() != 1) { return; } OpNode* sole_dst_node = op_node->SoleOutEdge()->dst_node(); if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul")) { if (!IsSafeToDelete(sole_dst_node)) { return; } if (!IsUserOpWithTypeName(sole_dst_node->SoleOutEdge()->dst_node()->op().op_conf(), "scalar_mul_by_tensor")) { return; } } else { if (!IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul_by_tensor")) { return; } } const user_op::UserOpConfWrapper cast_user_conf(op_node->op().op_conf()); if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.input("in", 0))).data_type() != DataType::kFloat16 && op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.input("in", 0))).data_type() != DataType::kBFloat16) { return; } if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.output("out", 0))).data_type() != DataType::kFloat) { return; } if (op_node->parallel_desc().device_type() != DeviceType::kCUDA) { return; } double scale = 1.0; if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul")) { const user_op::UserOpConfWrapper scalar_mul_op_conf(sole_dst_node->op().op_conf()); if (scalar_mul_op_conf.attr("has_int_operand")) { scale = static_cast(scalar_mul_op_conf.attr("int_operand")); } else if (scalar_mul_op_conf.attr("has_float_operand")) { scale = scalar_mul_op_conf.attr("float_operand"); } else { UNIMPLEMENTED(); } delete_ops.emplace_back(sole_dst_node->op().op_conf()); sole_dst_node = sole_dst_node->SoleOutEdge()->dst_node(); } delete_ops.emplace_back(op_node->op().op_conf()); const user_op::UserOpConfWrapper scale_user_conf(sole_dst_node->op().op_conf()); user_op::UserOpConfWrapperBuilder fused_op_builder(sole_dst_node->op().op_name()); fused_op_builder.OpTypeName("fused_cast_scale") .Input("x", cast_user_conf.input("in", 0)) .Input("scale_by_tensor", scale_user_conf.input("scalar", 0)) .Attr("scale", scale) .Output("y"); OperatorConf new_op_conf = sole_dst_node->op().op_conf(); *new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf(); job_builder->MutOpsOnlyOnce({new_op_conf}); }); job_builder->DelOps(delete_ops); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("FuseCastScalePass", FuseCastScalePass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/fuse_consecutive_add_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/cost_util.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { class FuseConsecutiveAddPass final : public JobPass { public: FuseConsecutiveAddPass() = default; ~FuseConsecutiveAddPass() override = default; Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { const OpGraph op_graph(*job); JobBuilder job_builder(job); JUST(Apply(op_graph, &job_builder)); return Maybe::Ok(); } }; Maybe FuseConsecutiveAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph); std::vector delete_ops; op_graph.TopoForEachNode([&](const OpNode* op_node) { if (!IsUserOpWithTypeName(op_node->op().op_conf(), "add_n") || !IsSafeToDelete(op_node) || op_node->out_edges().size() != 1) { return; } OpNode* sole_dst_node = op_node->SoleOutEdge()->dst_node(); if (!IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "add_n") || !IsSafeToDelete(sole_dst_node)) { return; } const std::string this_op_name = op_node->op().op_name(); const auto& GetCurOpConf = [&](const OpNode& cur_op) -> OperatorConf { const std::string& cur_op_name = cur_op.op().op_name(); if (!CHECK_JUST(job_builder->IsInMutOpTransaction(cur_op_name))) { return cur_op.op().op_conf(); } else { return CHECK_JUST(job_builder->MutOpTransactionGet(cur_op_name)); } }; int64_t fused_cnt = 0; auto fused_op_conf = GetCurOpConf(*sole_dst_node); auto in_it = fused_op_conf.mutable_user_conf()->mutable_input()->find("in"); CHECK(in_it != fused_op_conf.mutable_user_conf()->mutable_input()->end()); auto* in_lbns = in_it->second.mutable_s(); auto in_lbn_it = in_lbns->begin(); while (in_lbn_it != in_lbns->end()) { const auto lbi = GenLogicalBlobId(*in_lbn_it); if (lbi.op_name() == this_op_name) { in_lbn_it = in_lbns->erase(in_lbn_it); ++fused_cnt; } else { ++in_lbn_it; } } const auto& this_op_conf = GetCurOpConf(*op_node); auto this_in_it = this_op_conf.user_conf().input().find("in"); CHECK(this_in_it != this_op_conf.user_conf().input().end()); for (int64_t fuse_i = 0; fuse_i < fused_cnt; ++fuse_i) { for (const auto& this_in_lbn : this_in_it->second.s()) { *(in_lbns->Add()) = this_in_lbn; } } CHECK_JUST(job_builder->MutOpTransactionMut(fused_op_conf)); delete_ops.emplace_back(this_op_name); }); if (delete_ops.empty()) { return Maybe::Ok(); } JUST(job_builder->MutOpTransactionCommit()); job_builder->DelOps(delete_ops); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("FuseConsecutiveAddPass", FuseConsecutiveAddPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/fuse_embedding_interaction_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class FuseEmbeddingShuffleInteractionPass final : public JobPass { public: FuseEmbeddingShuffleInteractionPass() = default; ~FuseEmbeddingShuffleInteractionPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { // if enable quantize, not support fuse kernel. bool enable_quantized_comm = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false); bool enable_fuse_embedding_interaction = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION", false); return (!enable_quantized_comm && enable_fuse_embedding_interaction); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe FuseEmbeddingShuffleInteractionPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { op_graph.ForEachNode([&](const OpNode* op_node) { if (!IsUserOpWithTypeName(op_node->op().op_conf(), "embedding_shuffle")) { return; } if (op_node->out_edges().size() > 2) { return; } const user_op::UserOpConfWrapper embedding_shuffle_conf(op_node->op().op_conf()); const std::string& embeddings_lbn = embedding_shuffle_conf.output("embeddings", 0); const std::string& indices_lbn = embedding_shuffle_conf.input("inverse_unique_partition_indices", 0); const std::string& num_unique_matrix_lbn = embedding_shuffle_conf.input("num_unique_matrix", 0); if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(embeddings_lbn)).data_type() != DataType::kFloat16 || embedding_shuffle_conf.attr("embedding_size") % 2 != 0) { // only support half and embedding_size % 2 == 0 fuse, because atomicAdd half is slow. 
return; } if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(indices_lbn)).data_type() != DataType::kUInt32) { // only support indices with uint32_t dtype return; } if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(num_unique_matrix_lbn)).data_type() != DataType::kUInt32) { // only support num_unique with uint32_t dtype return; } for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); if (!consumer->op().op_conf().has_user_conf()) { return; } const user_op::UserOpConfWrapper consumer_op_conf(consumer->op().op_conf()); if (!(consumer_op_conf.op_type_name() == "fused_dot_feature_interaction" || consumer_op_conf.op_type_name() == "fused_dot_feature_interaction_grad")) { return; } if (consumer_op_conf.attr("pooling") != "none") { return; } int input_size = consumer_op_conf.input_size("features"); CHECK_GT(input_size, 0) << input_size; if (consumer_op_conf.input("features", input_size - 1) != embeddings_lbn) { // only support embeddings as last feature return; } user_op::UserOpConfWrapperBuilder fused_op_builder(consumer_op_conf.op_name()); const std::string& op_type_name = consumer_op_conf.op_type_name(); fused_op_builder.OpTypeName(op_type_name) .Input("sparse_feature", embeddings_lbn) .Input("sparse_indices", indices_lbn) .Input("num_valid_sparse_feature", num_unique_matrix_lbn) .Attr("self_interaction", consumer_op_conf.attr("self_interaction")) .Attr("pooling", consumer_op_conf.attr("pooling")); for (int i = 0; i < input_size - 1; ++i) { fused_op_builder.Input("features", consumer_op_conf.input("features", i)); } OperatorConf new_op_conf = consumer->op().op_conf(); if (op_type_name == "fused_dot_feature_interaction") { if (consumer_op_conf.has_input("output_concat", 0)) { fused_op_builder.Input("output_concat", consumer_op_conf.input("output_concat", 0)); } fused_op_builder.Output("out") .Attr("has_output_concat", consumer_op_conf.attr("has_output_concat")) .Attr("output_padding", consumer_op_conf.attr("output_padding")); *new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf(); } else { // fused_dot_feature_interaction_grad fused_op_builder.Input("dy", consumer_op_conf.input("dy", 0)) .Output("features_grad", input_size - 1) .Output("sparse_feature_grad") .Attr("output_concat_grad_dim", consumer_op_conf.attr("output_concat_grad_dim")); if (consumer_op_conf.has_output("output_concat_grad", 0)) { fused_op_builder.Output("output_concat_grad"); } user_op::UserOpConfWrapper fused_dot_feature_interaction_grad_op = fused_op_builder.Build(); *new_op_conf.mutable_user_conf() = fused_dot_feature_interaction_grad_op.op_conf().user_conf(); const LogicalBlobId last_feature_grad_lbi = GenLogicalBlobId(consumer_op_conf.output("features_grad", input_size - 1)); std::string sparse_feature_grad_lbn = fused_dot_feature_interaction_grad_op.output("sparse_feature_grad", 0); for (const OpEdge* out_edge : consumer->out_edges()) { const OpNode* grad_out_node = out_edge->dst_node(); if (out_edge->lbis().size() == 1 && out_edge->lbis().front() == last_feature_grad_lbi) { if (!IsUserOpWithTypeName(grad_out_node->op().op_conf(), "embedding_gradient_shuffle")) { return; } OperatorConf new_embedding_gradient_shuffle_conf = grad_out_node->op().op_conf(); for (const std::string& ibn : grad_out_node->op().input_bns()) { if (grad_out_node->op().BnInOp2Lbi(ibn) == last_feature_grad_lbi) { const auto& new_val = sparse_feature_grad_lbn; const auto& old_val = ReplaceInputLbnInOpCustomizedConf( &new_embedding_gradient_shuffle_conf, ibn, new_val); 
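// The replaced input must previously have consumed the last features_grad output (the
// embeddings slot); it now consumes sparse_feature_grad instead.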
CHECK_EQ(GenLogicalBlobName(last_feature_grad_lbi), old_val); } } auto bool_attr = ::oneflow::AttrValue(); bool_attr.set_at_bool(true); (*(new_embedding_gradient_shuffle_conf.mutable_user_conf() ->mutable_attr()))["skip_first_scatter"] = bool_attr; job_builder->MutOpsOnlyOnce({new_embedding_gradient_shuffle_conf}); } } } job_builder->MutOpsOnlyOnce({new_op_conf}); } auto bool_attr = ::oneflow::AttrValue(); bool_attr.set_at_bool(true); OperatorConf new_embedding_shuffle_conf = op_node->op().op_conf(); (*(new_embedding_shuffle_conf.mutable_user_conf()->mutable_attr()))["skip_last_gather"] = bool_attr; job_builder->MutOpsOnlyOnce({new_embedding_shuffle_conf}); }); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("FuseEmbeddingShuffleInteractionPass", FuseEmbeddingShuffleInteractionPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/fuse_model_update_cast_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class FuseModelUpdateCastOpsPass final : public JobPass { public: FuseModelUpdateCastOpsPass() = default; ~FuseModelUpdateCastOpsPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return (ctx.job_desc().enable_fused_model_update_cast() || ParseBooleanFromEnv("ONEFLOW_FUSE_MODEL_UPDATE_CAST", false)) && ctx.job_desc().enable_auto_mixed_precision(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } LOG(INFO) << "Enable fuse model update cast pass. "; const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe FuseModelUpdateCastOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { op_graph.ForEachNode([&](OpNode* op_node) { const auto& op_conf = op_node->op().op_conf(); if (!op_conf.has_variable_conf()) { return; } LogicalBlobId model_copy_lbi; for (OpEdge* find_cast_edge : op_node->out_edges()) { OpNode* find_cast_node = find_cast_edge->dst_node(); if (!IsUserOpWithTypeName(find_cast_node->op().op_conf(), "cast")) { continue; } const user_op::UserOpConfWrapper cast_user_conf(find_cast_node->op().op_conf()); if (find_cast_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.input("in", 0))) .data_type() != DataType::kFloat) { continue; } if (find_cast_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.output("out", 0))) .data_type() != DataType::kFloat16) { continue; } // Currently only support for cuda, maybe remove this limit. 
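// (The same CUDA-only restriction is checked again below for the matched sgd_update / adam_update node.)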
if (find_cast_node->parallel_desc().device_type() != DeviceType::kCUDA) { continue; } for (OpEdge* find_model_update_edge : op_node->out_edges()) { OpNode* find_model_update_update_node = find_model_update_edge->dst_node(); if (!IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), "sgd_update") && !IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), "adam_update")) { continue; } // Currently only support for cuda, maybe remove this limit. if (find_model_update_update_node->parallel_desc().device_type() != DeviceType::kCUDA) { continue; } const user_op::UserOpConfWrapper model_update_user_conf( find_model_update_update_node->op().op_conf()); // Here we find cast and model_update node, Replace cast as mutable_cast_once, and add // model_copy to model_update node. user_op::UserOpConfWrapperBuilder fused_cast_op_builder(cast_user_conf.op_name()); fused_cast_op_builder.OpTypeName("mutable_cast_once") .Input("in", cast_user_conf.input("in", 0)) .Attr("dtype", cast_user_conf.attr("dtype")) .Output("out"); CHECK(cast_user_conf.op_conf().has_scope_symbol_id()); fused_cast_op_builder.ScopeSymbolId(cast_user_conf.op_conf().scope_symbol_id()); OperatorConf new_cast_op_conf = cast_user_conf.op_conf(); *new_cast_op_conf.mutable_user_conf() = fused_cast_op_builder.Build().op_conf().user_conf(); job_builder->MutOpsOnlyOnce({new_cast_op_conf}); const user_op::UserOpConfWrapper new_cast_user_conf(new_cast_op_conf); model_copy_lbi = GenLogicalBlobId(new_cast_user_conf.output("out", 0)); user_op::UserOpConfWrapperBuilder fused_model_update_op_builder( model_update_user_conf.op_name()); if (IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), "sgd_update")) { fused_model_update_op_builder.OpTypeName("sgd_update") .Input("model", model_update_user_conf.input("model", 0)) .Input("model_diff", model_update_user_conf.input("model_diff", 0)) .Input("learning_rate", model_update_user_conf.input("learning_rate", 0)) .Attr("scale", model_update_user_conf.attr("scale")) .Attr("l1", model_update_user_conf.attr("l1")) .Attr("l2", model_update_user_conf.attr("l2")) .Attr("weight_decay", model_update_user_conf.attr("weight_decay")) .Attr("learning_rate_scale", model_update_user_conf.attr("learning_rate_scale")); } else if (IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), "adam_update")) { fused_model_update_op_builder.OpTypeName("adam_update") .Input("model", model_update_user_conf.input("model", 0)) .Input("model_diff", model_update_user_conf.input("model_diff", 0)) .Input("m", model_update_user_conf.input("m", 0)) .Input("v", model_update_user_conf.input("v", 0)) .Input("learning_rate", model_update_user_conf.input("learning_rate", 0)) .Attr("scale", model_update_user_conf.attr("scale")) .Attr("l1", model_update_user_conf.attr("l1")) .Attr("l2", model_update_user_conf.attr("l2")) .Attr("weight_decay", model_update_user_conf.attr("weight_decay")) .Attr("beta1", model_update_user_conf.attr("beta1")) .Attr("beta2", model_update_user_conf.attr("beta2")) .Attr("epsilon", model_update_user_conf.attr("epsilon")) .Attr("amsgrad", model_update_user_conf.attr("amsgrad")) .Attr("do_bias_correction", model_update_user_conf.attr("do_bias_correction")) .Attr("learning_rate_scale", model_update_user_conf.attr("learning_rate_scale")); ; if (model_update_user_conf.attr("do_bias_correction")) { fused_model_update_op_builder.Input( "bias_correction1", model_update_user_conf.input("bias_correction1", 0)); fused_model_update_op_builder.Input( "bias_correction2", 
model_update_user_conf.input("bias_correction2", 0)); } if (model_update_user_conf.attr("amsgrad")) { fused_model_update_op_builder.Input("max_v", model_update_user_conf.input("max_v", 0)); } } else { UNIMPLEMENTED() << "Need support more optimizers. "; } fused_model_update_op_builder.Input("model_copy", GenLogicalBlobName(model_copy_lbi)); CHECK(model_update_user_conf.op_conf().has_scope_symbol_id()); fused_model_update_op_builder.ScopeSymbolId( model_update_user_conf.op_conf().scope_symbol_id()); OperatorConf new_model_update_op_conf = model_update_user_conf.op_conf(); *new_model_update_op_conf.mutable_user_conf() = fused_model_update_op_builder.Build().op_conf().user_conf(); job_builder->MutOpsOnlyOnce({new_model_update_op_conf}); break; } break; } }); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("FuseModelUpdateCastOpsPass", FuseModelUpdateCastOpsPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/fuse_update_ops_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class FuseUpdateOpsPass final : public JobPass { public: FuseUpdateOpsPass() = default; ~FuseUpdateOpsPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().job_conf().enable_fuse_model_update_ops(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph); std::vector del_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { if (!op_node->op().op_conf().has_user_conf()) { return; } const user_op::UserOpConfWrapper user_op_conf(op_node->op().op_conf()); if (user_op_conf.op_type_name() != "sgd_update" && user_op_conf.op_type_name() != "momentum_update" && user_op_conf.op_type_name() != "adam_update" && user_op_conf.op_type_name() != "rmsprop_update" && user_op_conf.op_type_name() != "lars_update" && user_op_conf.op_type_name() != "adagrad_update" && user_op_conf.op_type_name() != "lamb_update" && user_op_conf.op_type_name() != "ftrl_update" && user_op_conf.op_type_name() != "adadelta_update") { return; } if (user_op_conf.attr("scale") != 1.0 || user_op_conf.attr("l1") != 0.0f || user_op_conf.attr("l2") != 0.0f) { return; } float l1 = 0; float l2 = 0; double scale = 1; bool fused = false; LogicalBlobId model_diff_lbi = GenLogicalBlobId(user_op_conf.input("model_diff", 0)); std::string scale_by_tensor_lbn; [&]() { do { const OpNode* producer = op_graph.OpNode4OpName(model_diff_lbi.op_name()); if (!IsUserOpWithTypeName(producer->op().op_conf(), 
"l1_l2_regularize_gradient")) { break; } if (!IsSafeToDelete(producer)) { return; } const user_op::UserOpConfWrapper l1_l2_regularize_gradient_op_conf( producer->op().op_conf()); if (l1_l2_regularize_gradient_op_conf.input("model", 0) != user_op_conf.input("model", 0)) { return; } l1 = l1_l2_regularize_gradient_op_conf.attr("l1"); l2 = l1_l2_regularize_gradient_op_conf.attr("l2"); model_diff_lbi = GenLogicalBlobId(l1_l2_regularize_gradient_op_conf.input("model_diff", 0)); del_op_names.emplace_back(producer->op().op_name()); fused = true; } while (false); do { const OpNode* producer = op_graph.OpNode4OpName(model_diff_lbi.op_name()); if (!IsUserOpWithTypeName(producer->op().op_conf(), "scalar_mul_by_tensor")) { break; } if (!IsSafeToDelete(producer)) { return; } const user_op::UserOpConfWrapper scalar_mul_by_tensor_op_conf(producer->op().op_conf()); model_diff_lbi = GenLogicalBlobId(scalar_mul_by_tensor_op_conf.input("x", 0)); scale_by_tensor_lbn = scalar_mul_by_tensor_op_conf.input("scalar", 0); del_op_names.emplace_back(producer->op().op_name()); fused = true; } while (false); do { const OpNode* producer = op_graph.OpNode4OpName(model_diff_lbi.op_name()); if (!IsUserOpWithTypeName(producer->op().op_conf(), "scalar_mul")) { break; } if (!IsSafeToDelete(producer)) { return; } const user_op::UserOpConfWrapper scalar_mul_op_conf(producer->op().op_conf()); if (scalar_mul_op_conf.attr("has_int_operand")) { scale = static_cast(scalar_mul_op_conf.attr("int_operand")); } else if (scalar_mul_op_conf.attr("has_float_operand")) { scale = scalar_mul_op_conf.attr("float_operand"); } else { UNIMPLEMENTED(); } model_diff_lbi = GenLogicalBlobId(scalar_mul_op_conf.input("in", 0)); del_op_names.emplace_back(producer->op().op_name()); fused = true; } while (false); do { const OpNode* producer = op_graph.OpNode4OpName(model_diff_lbi.op_name()); if (!IsUserOpWithTypeName(producer->op().op_conf(), "cast")) { break; } if (!IsSafeToDelete(producer)) { return; } const user_op::UserOpConfWrapper cast_op_conf(producer->op().op_conf()); if (producer->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_op_conf.input("in", 0))).data_type() != DataType::kFloat16 || cast_op_conf.attr("dtype") != DataType::kFloat) { return; } model_diff_lbi = GenLogicalBlobId(cast_op_conf.input("in", 0)); del_op_names.emplace_back(producer->op().op_name()); fused = true; } while (false); }(); if (!fused) { return; } const TrainConf& train_conf = job_builder->job().job_conf().train_conf(); user_op::UserOpConfWrapperBuilder fused_op_builder(user_op_conf.op_name()); fused_op_builder.OpTypeName(user_op_conf.op_type_name()) .Input("model", user_op_conf.input("model", 0)) .Input("model_diff", GenLogicalBlobName(model_diff_lbi)) .Input("learning_rate", user_op_conf.input("learning_rate", 0)) .Attr("scale", scale) .Attr("l1", l1) .Attr("l2", l2) .Attr("weight_decay", user_op_conf.attr("weight_decay")) .Attr("learning_rate_scale", user_op_conf.attr("learning_rate_scale")); if (scale_by_tensor_lbn != "") { fused_op_builder.Input("scale_by_tensor", scale_by_tensor_lbn); } if (user_op_conf.has_input("skip_if", 0)) { fused_op_builder.Input("skip_if", user_op_conf.input("skip_if", 0)); } if (user_op_conf.op_type_name() == "sgd_update") { // do nothing } else if (user_op_conf.op_type_name() == "momentum_update") { fused_op_builder.Input("momentum", user_op_conf.input("momentum", 0)) .Attr("beta", user_op_conf.attr("beta")) .Attr("dampening", user_op_conf.attr("dampening")) .Attr("nesterov", user_op_conf.attr("nesterov")) .Attr("maximize", 
user_op_conf.attr("maximize")); } else if (user_op_conf.op_type_name() == "adam_update") { fused_op_builder.Input("m", user_op_conf.input("m", 0)) .Input("v", user_op_conf.input("v", 0)) .Attr("beta1", user_op_conf.attr("beta1")) .Attr("beta2", user_op_conf.attr("beta2")) .Attr("epsilon", user_op_conf.attr("epsilon")) .Attr("amsgrad", user_op_conf.attr("amsgrad")) .Attr("do_bias_correction", user_op_conf.attr("do_bias_correction")); if (user_op_conf.has_input("max_v", 0)) { fused_op_builder.Input("max_v", user_op_conf.input("max_v", 0)); } if (user_op_conf.has_input("bias_correction1", 0)) { fused_op_builder.Input("bias_correction1", user_op_conf.input("bias_correction1", 0)); } if (user_op_conf.has_input("bias_correction2", 0)) { fused_op_builder.Input("bias_correction2", user_op_conf.input("bias_correction2", 0)); } } else if (user_op_conf.op_type_name() == "rmsprop_update") { const bool centered = user_op_conf.attr("centered"); fused_op_builder.Input("mean_square", user_op_conf.input("mean_square", 0.f)) .Attr("centered", user_op_conf.attr("centered")) .Attr("epsilon", user_op_conf.attr("epsilon")) .Attr("decay_rate", user_op_conf.attr("decay_rate")); if (centered) { fused_op_builder.Input("mean_gradient", user_op_conf.input("mean_gradient", 0.f)); } } else if (user_op_conf.op_type_name() == "lars_update") { fused_op_builder.Input("momentum", user_op_conf.input("momentum", 0)) .Attr("momentum_beta", user_op_conf.attr("momentum_beta")) .Attr("epsilon", user_op_conf.attr("epsilon")) .Attr("lars_coefficient", user_op_conf.attr("lars_coefficient")); } else if (user_op_conf.op_type_name() == "adagrad_update") { fused_op_builder.Input("sum", user_op_conf.input("sum", 0)) .Input("train_step", train_conf.train_step_lbn()) .Attr("lr_decay", user_op_conf.attr("lr_decay")) .Attr("epsilon", user_op_conf.attr("epsilon")); } else if (user_op_conf.op_type_name() == "lamb_update") { fused_op_builder.Input("m", user_op_conf.input("m", 0)) .Input("v", user_op_conf.input("v", 0)) .Attr("beta1", user_op_conf.attr("beta1")) .Attr("beta2", user_op_conf.attr("beta2")) .Attr("epsilon", user_op_conf.attr("epsilon")) .Attr("do_bias_correction", user_op_conf.attr("do_bias_correction")); if (user_op_conf.has_input("bias_correction1", 0)) { fused_op_builder.Input("bias_correction1", user_op_conf.input("bias_correction1", 0)); } if (user_op_conf.has_input("bias_correction2", 0)) { fused_op_builder.Input("bias_correction2", user_op_conf.input("bias_correction2", 0)); } } else if (user_op_conf.op_type_name() == "ftrl_update") { fused_op_builder.Input("accumulate", user_op_conf.input("accumulate", 0)) .Input("z", user_op_conf.input("z", 0)) .Attr("lr_power", user_op_conf.attr("lr_power")) .Attr("lambda1", user_op_conf.attr("lambda1")) .Attr("lambda2", user_op_conf.attr("lambda2")) .Attr("beta", user_op_conf.attr("beta")); } else if (user_op_conf.op_type_name() == "adadelta_update") { fused_op_builder.Input("square_avgs", user_op_conf.input("square_avgs", 0)) .Input("acc_deltas", user_op_conf.input("acc_deltas", 0)) .Attr("rho", user_op_conf.attr("rho")) .Attr("epsilon", user_op_conf.attr("epsilon")) .Attr("maximize", user_op_conf.attr("maximize")); } else { UNIMPLEMENTED(); } CHECK(user_op_conf.op_conf().has_scope_symbol_id()); fused_op_builder.ScopeSymbolId(user_op_conf.op_conf().scope_symbol_id()); OperatorConf new_op_conf = user_op_conf.op_conf(); *new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf(); job_builder->MutOpsOnlyOnce({new_op_conf}); }); job_builder->DelOps(del_op_names); 
return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("FuseUpdateOpsPass", FuseUpdateOpsPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/generate_optimizer_op_confs.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/autograd.h" #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/scope.pb.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" namespace oneflow { namespace { class GenerateOptimizerOpConfs final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(GenerateOptimizerOpConfs); GenerateOptimizerOpConfs() = default; ~GenerateOptimizerOpConfs() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); } Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; void FilterCurModelLbi2ModelDiffLbiByName( const ::google::protobuf::RepeatedPtrField& variables, const HashMap& model_lbi2model_diff_lbi, HashMap* cur_model_lbi2model_diff_lbi) { for (const std::string& variable : variables) { const LogicalBlobId& lbi = GenLogicalBlobId(variable + "/out"); if (model_lbi2model_diff_lbi.find(lbi) != model_lbi2model_diff_lbi.end()) { (*cur_model_lbi2model_diff_lbi)[lbi] = model_lbi2model_diff_lbi.at(lbi); } } } Maybe WithCalculationPassScope(const std::string& pass_name, Job* job, const std::function()>& Handler) { HashSet exists_op_names; for (const auto& op_conf : job->net().op()) { CHECK_OR_RETURN(exists_op_names.emplace(op_conf.name()).second); } JUST(Handler()); // using a new JobBuilder to avoid bugs caused by MutOnlyOnce auto new_job_builder = std::make_shared(job); HashMap> scope_id2op_names; const auto& scope_storage = *Singleton>::Get(); for (const auto& op_conf : job->net().op()) { if (exists_op_names.count(op_conf.name()) > 0) { continue; } CHECK_OR_RETURN(op_conf.has_scope_symbol_id()); OF_RETURN_IF_ERROR(scope_storage.MaybeGet(op_conf.scope_symbol_id())) << op_conf.DebugString(); scope_id2op_names[op_conf.scope_symbol_id()].emplace_back(&op_conf); } const auto& GetNewScopeSymbolId = [&](int64_t old_scope_symbol_id) -> Maybe { const auto& old_scope = JUST(scope_storage.MaybeGet(old_scope_symbol_id)); std::shared_ptr new_scope = std::make_shared(old_scope.scope_proto()); new_scope->set_parent_scope_symbol_id(old_scope_symbol_id); new_scope->set_calculation_pass_name(pass_name); std::shared_ptr new_scope_symbol; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { new_scope_symbol = JUST(builder->GetScopeSymbol(*new_scope)); return Maybe::Ok(); })); return JUST(new_scope_symbol->symbol_id()); }; for (const auto& pair : scope_id2op_names) { int64_t new_scope_symbol_id = JUST(GetNewScopeSymbolId(pair.first)); std::vector op_confs(pair.second.size()); for (int 
i = 0; i < pair.second.size(); ++i) { op_confs.at(i).CopyFrom(*pair.second.at(i)); op_confs.at(i).set_scope_symbol_id(new_scope_symbol_id); } new_job_builder->MutOpsOnlyOnce(op_confs); } return new_job_builder; } Maybe GenerateOptimizerOpConfs::Apply(Job* job, JobPassCtx* ctx) const { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const auto& train_conf = job->job_conf().train_conf(); // loss initial gradients HashMap loss_lbi2initial_diff_lbi; CHECK_OR_RETURN(train_conf.loss_lbn_size() == train_conf.loss_grad_lbn_size()) << "loss_lbn and loss_grad_lbn size mismatch"; for (int i = 0; i < train_conf.loss_lbn_size(); ++i) { auto loss_lbi = GenLogicalBlobId(train_conf.loss_lbn(i)); auto loss_grad_lbi = GenLogicalBlobId(train_conf.loss_grad_lbn(i)); loss_lbi2initial_diff_lbi.emplace(loss_lbi, loss_grad_lbi); } // variable gradients HashMap model_lbi2model_diff_lbi; for (const auto& optimizer_conf : train_conf.optimizer_conf()) { CHECK_OR_RETURN(optimizer_conf.variable_op_names_size() == optimizer_conf.variable_grad_lbns_size()) << "variable_op_names and variable_grad_lbns size mismatch"; for (int i = 0; i < optimizer_conf.variable_op_names_size(); ++i) { auto model_lbi = GenLogicalBlobId(optimizer_conf.variable_op_names(i) + "/out"); const auto& model_diff_lbn = optimizer_conf.variable_grad_lbns(i); // variable maybe has no gradient, so skip it if model_diff_lbn is empty if (!model_diff_lbn.empty()) { model_lbi2model_diff_lbi.emplace(model_lbi, GenLogicalBlobId(model_diff_lbn)); } } } const OpGraph op_graph(*job); auto job_builder = std::make_shared(job); const JobBuilder* old_job_builder = job_builder.get(); job_builder = JUST(WithCalculationPassScope(kOptimizerPass, job, [&]() -> Maybe { CHECK(old_job_builder == job_builder.get()); // Check this lambda never been async called AddDiffHalf2FloatCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi); AddDiffStaticShapeCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi); AddDiffParallelCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi); JUST(ScaleModelDiffByLossInstanceNum(op_graph, job_builder.get(), &model_lbi2model_diff_lbi)); JUST(ScaleInitialDiffByLossScale(ctx, op_graph, job_builder.get(), &loss_lbi2initial_diff_lbi)); ScaleModelDiffByLossScale(ctx, op_graph, job_builder.get(), &model_lbi2model_diff_lbi); JUST(CountNotFiniteIfNeeded(ctx, op_graph, job_builder.get(), model_lbi2model_diff_lbi)); for (const auto& optimizer_conf : job->job_conf().train_conf().optimizer_conf()) { HashMap cur_model_lbi2model_diff_lbi; FilterCurModelLbi2ModelDiffLbiByName(optimizer_conf.variable_op_names(), model_lbi2model_diff_lbi, &cur_model_lbi2model_diff_lbi); if (optimizer_conf.has_clip_conf()) { ClipGradient(ctx, op_graph, job_builder.get(), &cur_model_lbi2model_diff_lbi, optimizer_conf.clip_conf()); } RegularizeGradient(op_graph, job_builder.get(), &cur_model_lbi2model_diff_lbi); op_graph.ForEachNode([&](OpNode* op_node) { const VariableOp* var_op = dynamic_cast(&op_node->op()); if (var_op == nullptr || cur_model_lbi2model_diff_lbi.find(var_op->BnInOp2Lbi(var_op->SoleObn())) == cur_model_lbi2model_diff_lbi.end()) { return; } const std::string& model_diff_lbn = GenLogicalBlobName( cur_model_lbi2model_diff_lbi.at(var_op->BnInOp2Lbi(var_op->SoleObn()))); AddOptimizerOp(ctx, *op_node, model_diff_lbn, optimizer_conf, job_builder.get()); }); } return Maybe::Ok(); })); return Maybe::Ok(); } REGISTER_JOB_PASS("GenerateOptimizerOpConfs", GenerateOptimizerOpConfs); } // namespace } // namespace oneflow 
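// WithCalculationPassScope (above) gives every op added by the handler a child scope whose
// calculation_pass_name is kOptimizerPass, so later passes can distinguish the generated optimizer
// ops from the ops of the forward/backward passes.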
================================================ FILE: oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { const Scope& Scope4ScopeSymbolId(int64_t scope_symbol_id) { CHECK(Singleton<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id)); return Singleton<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id); } const Scope& Scope4OpNode(const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); CHECK(op_conf.has_scope_symbol_id()); return Scope4ScopeSymbolId(op_conf.scope_symbol_id()); } bool OpNodeHasScope(const OpNode* node) { return node->op().op_conf().has_scope_symbol_id(); } int64_t GetStageIdHint(const OpNode* node) { return Scope4OpNode(node).Int64("pipeline_stage_id_hint"); } Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) { { // NOTE(chengcheng): Disable group boxing for pipeline parallel, because there is a bad case that // makes forward/backward execution sequential in ZeRO + 3-D Parallel by inserting an additional boxing // identity.
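// A positive pipeline_stage_id_hint on any scoped op indicates pipeline parallelism is in use, in
// which case the whole pass returns immediately below.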
int64_t max_stage_id = 0; op_graph.ForEachNode([&](const OpNode* this_node) { if (!OpNodeHasScope(this_node)) { LOG(WARNING) << " op : " << this_node->op().op_conf().DebugString() << " has NOT scope!"; return; } max_stage_id = std::max(max_stage_id, GetStageIdHint(this_node)); }); if (max_stage_id > 0) { return Maybe::Ok(); } } HashMap, std::vector>>> lbi2consumer_grouped_by_parallel; HashMap op_node2op_conf; op_graph.ForEachNode([&](const OpNode* node) { OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case(); if (IsClassRegistered(op_type_case)) { return; } for (const std::string& ibn : node->op().input_bns()) { const auto& blob_modifier_ = node->op().InputBlobModifier4Ibn(ibn); if (blob_modifier_.has_is_mutable() && blob_modifier_.is_mutable()) { continue; } const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn); const OpNode& producer = node->ProducerOpNode4Lbi(lbi); const auto& logical_shape = node->LogicalBlobDesc4Lbi(lbi).shape(); const NdSbp& producer_nd_sbp = producer.NdSbp4Lbi(lbi); const std::string& producer_lbn = *CHECK_JUST(producer.op().obn4lbi(lbi)); const ParallelDesc& producer_parallel_desc = *CHECK_JUST(producer.op().GetParallelDesc4BnInOp(producer_lbn)).get(); ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; NdSbp reduced_in_nd_sbp; NdSbpDimReduce(producer_parallel_desc, producer_nd_sbp, &reduced_in_parallel_desc, &reduced_in_nd_sbp, logical_shape); const NdSbp& consumer_nd_sbp = node->NdSbp4BnInOp(ibn); const ParallelDesc& consumer_parallel_desc = *CHECK_JUST(node->op().GetParallelDesc4BnInOp(ibn)); ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; NdSbp reduced_out_nd_sbp; NdSbpDimReduce(consumer_parallel_desc, consumer_nd_sbp, &reduced_out_parallel_desc, &reduced_out_nd_sbp, logical_shape); if (reduced_in_parallel_desc == reduced_out_parallel_desc && reduced_in_nd_sbp == reduced_out_nd_sbp) { continue; } lbi2consumer_grouped_by_parallel[lbi][{reduced_out_parallel_desc, reduced_out_nd_sbp}] .push_back({node, ibn}); if (op_node2op_conf.find(node) == op_node2op_conf.end()) { op_node2op_conf[node] = node->op().op_conf(); } } }); for (const auto& lbi7groups : lbi2consumer_grouped_by_parallel) { const LogicalBlobId& lbi = lbi7groups.first; for (const auto& parallel7group : lbi7groups.second) { if (parallel7group.second.size() < 2) { continue; } const ParallelDesc& dst_parallel_desc = parallel7group.first.first; const NdSbp& dst_nd_sbp = parallel7group.first.second; OperatorConf identity_op_conf{}; identity_op_conf.set_name("Sys-Boxing-GroupIdentity-" + lbi.op_name() + "_" + lbi.blob_name() + "-" + NewUniqueId()); IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf(); identity_conf->set_in(GenLogicalBlobName(lbi)); identity_conf->set_out("out"); job_builder->AddOps(dst_parallel_desc.parallel_conf(), {identity_op_conf}); NdSbpSignature identity_nd_sbp_signature; (*identity_nd_sbp_signature.mutable_bn_in_op2nd_sbp())["in"] = dst_nd_sbp; (*identity_nd_sbp_signature.mutable_bn_in_op2nd_sbp())["out"] = dst_nd_sbp; job_builder->AddNdSbpSignature4OpName(identity_op_conf.name(), identity_nd_sbp_signature); LogicalBlobId grouped_lbi; grouped_lbi.set_op_name(identity_op_conf.name()); grouped_lbi.set_blob_name(identity_conf->out()); for (const auto& consumer7ibn : parallel7group.second) { const OpNode* consumer = consumer7ibn.first; const std::string& ibn = consumer7ibn.second; OperatorConf& consumer_op_conf = op_node2op_conf[consumer]; const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, 
GenLogicalBlobName(grouped_lbi)); CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val); } } } for (const auto& op_node7op_conf : op_node2op_conf) { JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second)); } return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_REWRITER_GROUP_BOXING_BY_DST_PARALLEL_H_ #define ONEFLOW_CORE_JOB_REWRITER_GROUP_BOXING_BY_DST_PARALLEL_H_ #include "oneflow/core/graph/op_graph.h" namespace oneflow { class OpGraph; class Job; Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_GROUP_BOXING_BY_DST_PARALLEL_H_ ================================================ FILE: oneflow/core/job_rewriter/indexed_slices_optimizer_rewrite_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { class IndexedSlicesOptimizerRewritePass final : public JobPass { public: IndexedSlicesOptimizerRewritePass() = default; ~IndexedSlicesOptimizerRewritePass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().job_conf().has_indexed_slices_optimizer_conf() && ctx.job_desc().job_conf().indexed_slices_optimizer_conf().enable(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe IndexedSlicesOptimizerRewritePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { const PbRpf& include_op_names = GlobalJobDesc().job_conf().indexed_slices_optimizer_conf().include_op_names().op_name(); const std::set include_op_name_set( {include_op_names.cbegin(), include_op_names.cend()}); op_graph.ForEachNode([&](const OpNode* src_node) { const OperatorConf& src_op_conf = src_node->op().op_conf(); if (src_node->out_edges().size() != 1) { return; } std::string indices_lbn; std::string values_lbn; std::string model_op_name; if (!src_op_conf.has_user_conf()) { return; } const user_op::UserOpConfWrapper src_op(src_op_conf); if (src_op.op_type_name() == "unsorted_segment_sum" && src_op.attr("axis") == 0) { indices_lbn = src_op.input("segment_ids", 0); values_lbn = src_op.input("data", 0); } else if (src_op.op_type_name() == "unsorted_segment_sum_like" && src_op.attr("axis") == 0) { indices_lbn = src_op.input("segment_ids", 0); values_lbn = src_op.input("data", 0); } else { return; } std::vector op_nodes_to_remove; std::vector op_nodes_apply_to_diff; const OpNode* dst_node = src_node->SoleOutEdge()->dst_node(); do { if (dst_node->op().output_bns().empty()) { break; } const OperatorConf& dst_op_conf = dst_node->op().op_conf(); if (dst_op_conf.has_user_conf() && dst_op_conf.user_conf().op_type_name() == "hierarchical_parallel_cast") { if (dst_node->out_edges().size() != 1) { return; } op_nodes_to_remove.emplace_back(dst_node); dst_node = dst_node->SoleOutEdge()->dst_node(); continue; } else if (dst_op_conf.has_user_conf() && dst_op_conf.user_conf().op_type_name() == "scalar_mul") { if (dst_node->out_edges().size() != 1) { return; } op_nodes_apply_to_diff.emplace_back(dst_node); dst_node = dst_node->SoleOutEdge()->dst_node(); continue; } else { return; } } while (true); if (!dst_node->op().op_conf().has_user_conf()) { return; } const user_op::UserOpConfWrapper user_op_conf(dst_node->op().op_conf()); if (user_op_conf.op_type_name() != "sgd_update" && user_op_conf.op_type_name() != "momentum_update" && user_op_conf.op_type_name() != "adam_update") { return; } if (user_op_conf.attr("scale") != 1.0 || user_op_conf.attr("l1") != 0.0f || user_op_conf.attr("l2") != 0.0f || user_op_conf.has_input("scale_by_tensor", 0)) { return; } const LogicalBlobId& model_lbi = GenLogicalBlobId(user_op_conf.input("model", 0)); if (dst_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(user_op_conf.input("model_diff", 0))) .data_type() != dst_node->LogicalBlobDesc4Lbi(model_lbi).data_type()) { return; } model_op_name = model_lbi.op_name(); user_op::UserOpConfWrapperBuilder indexed_slices_op_builder("System-Optimizer-IndexedSlices-" + model_op_name); indexed_slices_op_builder.OpTypeName("indexed_slices_" + user_op_conf.op_type_name()) .Input("model", 
user_op_conf.input("model", 0)) .Input("learning_rate", user_op_conf.input("learning_rate", 0)) .Attr("weight_decay", user_op_conf.attr("weight_decay")) .Attr("learning_rate_scale", user_op_conf.attr("learning_rate_scale")); if (user_op_conf.op_type_name() == "sgd_update") { // do nothing } else if (user_op_conf.op_type_name() == "momentum_update") { indexed_slices_op_builder.Input("momentum", user_op_conf.input("momentum", 0)) .Attr("beta", user_op_conf.attr("beta")) .Attr("dampening", user_op_conf.attr("dampening")) .Attr("nesterov", user_op_conf.attr("nesterov")) .Attr("maximize", user_op_conf.attr("maximize")); } else if (user_op_conf.op_type_name() == "adam_update") { indexed_slices_op_builder.Input("m", user_op_conf.input("m", 0)) .Input("v", user_op_conf.input("v", 0)) .Attr("beta1", user_op_conf.attr("beta1")) .Attr("beta2", user_op_conf.attr("beta2")) .Attr("epsilon", user_op_conf.attr("epsilon")); if (user_op_conf.has_input("max_v", 0)) { indexed_slices_op_builder.Input("max_v", user_op_conf.input("max_v", 0)); } } else { return; } CHECK(!model_op_name.empty()); CHECK(!indices_lbn.empty()); CHECK(!values_lbn.empty()); if (include_op_name_set.find(model_op_name) == include_op_name_set.end()) { return; } for (const OpNode* node : op_nodes_to_remove) { job_builder->DelOps({node->op().op_conf()}); } for (const OpNode* node : op_nodes_apply_to_diff) { OperatorConf new_conf = node->op().op_conf(); if (new_conf.has_user_conf() && new_conf.user_conf().op_type_name() == "scalar_mul") { const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&new_conf, "in_0", values_lbn); CHECK_EQ(GenLogicalBlobName(node->op().BnInOp2Lbi("in_0")), old_val); values_lbn = GenLogicalBlobName(new_conf.name(), "out_0"); job_builder->MutOpsOnlyOnce({new_conf}); } else { UNIMPLEMENTED(); } } indexed_slices_op_builder.Input("model_diff_indices", indices_lbn) .Input("model_diff_values", values_lbn) .ScopeSymbolId(src_op_conf.scope_symbol_id()); job_builder->DelOps({src_op_conf, user_op_conf.op_conf()}); job_builder->AddOps(dst_node->parallel_desc().parallel_conf(), {indexed_slices_op_builder.Build().op_conf()}); }); return Maybe::Ok(); } REGISTER_JOB_PASS("IndexedSlicesOptimizerRewritePass", IndexedSlicesOptimizerRewritePass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/input_autotick.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/autotick.h" namespace oneflow { namespace { class MutInputOpConTickInputHelper final : public MutOpConTickInputHelper { public: MutInputOpConTickInputHelper() : MutOpConTickInputHelper() {} bool VirtualIsTickInputBound() const override { return op_conf().input_conf().has_tick(); } OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override { OperatorConf ret(op_conf()); ret.mutable_input_conf()->set_tick(lbn); return ret; } }; } // namespace REGISTER_AUTO_TICK(OperatorConf::kInputConf, MutInputOpConTickInputHelper); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/nd_sbp_util.h" #if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU) #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/common/env_var/debug_mode.h" namespace oneflow { DEFINE_ENV_INTEGER(ONEFLOW_GRAPH_MAX_NCCL_COMPUTE_STREAM, 8); namespace { class InsertNcclLogicalOpPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(InsertNcclLogicalOpPass); InsertNcclLogicalOpPass() = default; ~InsertNcclLogicalOpPass() = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } bool IsEnabled(const JobPassCtx& ctx) const { return Singleton::Get()->nccl_use_compute_stream(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; const std::string kNcclLogicalOpNamePrefix = "System-NCCL-Logical"; bool IsTickOpConf(const OperatorConf& op_conf) { if (IsClassRegistered(op_conf.op_type_case())) { return true; } if (op_conf.has_user_conf()) { const std::string& user_type_name = op_conf.user_conf().op_type_name(); if (user_type_name == "cast_to_tick" || user_type_name == "acc_ctrl_tick") { return true; } } return false; } bool IsBreakpointOpNode(const OpNode* node) { // NOTE(chengcheng): breakpoint op is special which CANNOT through subgraph such as: // variable, tick, repeat/acc/pack/unpack change timeshape const Operator& op = node->op(); const OperatorConf& op_conf = op.op_conf(); // TODO(chengcheng): filter ops which has 
special type // TODO(chengcheng): get stream by op type if (op_conf.has_variable_conf() /* varialbe */ || IsTickOpConf(op_conf) /* tick */ || op_conf.has_input_conf() || op_conf.has_output_conf() /* io */ || op_conf.has_wait_and_send_ids_conf() || op_conf.has_callback_notify_conf() /* ctrl */ || op_conf.has_image_decoder_random_crop_resize_conf() /* gpu decode */) { return true; } if (op_conf.has_user_conf()) { const std::string& user_type_name = op_conf.user_conf().op_type_name(); if (user_type_name == "repeat" || user_type_name == "pack" || user_type_name == "unpack" || user_type_name == "identity_buffer") { return true; } if (!EnableLogicalChain()) { // NOTE(chengcheng): in old task graph chain version, consider acc as breakpoint node if (user_type_name == "acc") { return true; } } } return false; } bool IsAccOpNode(const OpNode* node) { return node->op().op_conf().has_user_conf() && node->op().op_conf().user_conf().op_type_name() == "acc"; } bool IsRepeatOpNode(const OpNode* node) { return node->op().op_conf().has_user_conf() && node->op().op_conf().user_conf().op_type_name() == "repeat"; } std::shared_ptr GetOpNodeTimeShape(const OpNode* op_node) { return CHECK_JUST(op_node->op().GetOpTimeShape()); } std::shared_ptr GetOpNodeInputTimeShape(const OpNode* op_node) { return CHECK_JUST(op_node->op().GetInputBlobFastestTimeShape()); } std::shared_ptr GetOpNodeFastestTimeShape(const OpNode* op_node) { return CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape()); } bool SharedPtrShapeEqual(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { return (*lhs) == (*rhs); } void FindAllConnectedSubgraphForGpuExecOrder(std::vector>* ret, const OpGraph& op_graph, const std::vector& order) { // NOTE(chengcheng): acc subgraph may greater than fw/bw subgraph. we need use max time shape. std::shared_ptr seed_time_shape = std::make_shared(Shape({1, 1})); op_graph.ForEachNode([&](const OpNode* node) { std::shared_ptr this_time_shape = GetOpNodeFastestTimeShape(node); if (this_time_shape->elem_cnt() > seed_time_shape->elem_cnt()) { seed_time_shape = this_time_shape; } }); VLOG(2) << " seed time shape = " << seed_time_shape->ToString(); HashSet visited; for (const OpNode* seed_node : order) { if (visited.find(seed_node) != visited.end()) { continue; } CHECK(visited.insert(seed_node).second); const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc(); // NOTE(chengcheng): ONLY consider GPU op and parallel num > 1. if (seed_parallel_desc.device_type() == DeviceType::kCPU) { continue; } if (seed_parallel_desc.parallel_num() <= 1) { continue; } // NOTE(chengcheng): using fastest time shape for merge acc into bw subgraph. if (!SharedPtrShapeEqual(GetOpNodeFastestTimeShape(seed_node), seed_time_shape)) { continue; } if (IsBreakpointOpNode(seed_node)) { continue; } // NOTE(chengcheng): // stream name hint maybe set by other job pass like replace embedding. 
// we cannot replace stream name in subgraph if (seed_node->op().op_conf().has_stream_name_hint()) { continue; } HashSet this_subgraph; std::queue queued_nodes; queued_nodes.push(seed_node); while (!queued_nodes.empty()) { const OpNode* cur_node = queued_nodes.front(); queued_nodes.pop(); CHECK(cur_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)); CHECK(this_subgraph.insert(cur_node).second); cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) { if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape)) { CHECK(visited.insert(next_node).second); queued_nodes.push(next_node); } }); } if (this_subgraph.size() > 1) { ret->emplace_back(HashSet()); ret->back().swap(this_subgraph); } } std::sort(ret->begin(), ret->end(), [](const HashSet& lhs, const HashSet& rhs) { return lhs.size() > rhs.size(); }); } bool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp, const SbpParallel& dst_sbp, const std::string& lbn, const int64_t scope_symbol_id, const BlobDesc& logical_blob_desc, const int64_t parallel_num) { auto CanSplitAtDim = [&](int64_t dim) -> bool { if (logical_blob_desc.shape().NumAxes() <= dim) { return false; } return logical_blob_desc.shape().At(dim) % parallel_num == 0; }; if (src_sbp.has_partial_sum_parallel() && dst_sbp.has_broadcast_parallel()) { // P->B : AllReduce *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-P2B-" + NewUniqueId()) .Op("_nccl_logical_all_reduce") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", {SbpToString(src_sbp)}) .Attr>("dst_reduced_nd_sbp", {SbpToString(dst_sbp)}) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if (CanSplitAtDim(0) && (src_sbp.has_partial_sum_parallel() && dst_sbp.has_split_parallel()) && (dst_sbp.split_parallel().axis() == 0)) { // P->S(0) : ReduceScatter *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-P2S-" + NewUniqueId()) .Op("_nccl_logical_reduce_scatter") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", {SbpToString(src_sbp)}) .Attr>("dst_reduced_nd_sbp", {SbpToString(dst_sbp)}) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if (CanSplitAtDim(0) && (src_sbp.has_split_parallel() && dst_sbp.has_broadcast_parallel()) && (src_sbp.split_parallel().axis() == 0)) { // S(0)->B : AllGather *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-S2B-" + NewUniqueId()) .Op("_nccl_logical_all_gather") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", {SbpToString(src_sbp)}) .Attr>("dst_reduced_nd_sbp", {SbpToString(dst_sbp)}) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if (src_sbp.has_split_parallel() && dst_sbp.has_broadcast_parallel() && src_sbp.split_parallel().axis() > 0 && CanSplitAtDim(src_sbp.split_parallel().axis())) { // S(1)->B : AllGather Noncontinuous *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-S2B-" + NewUniqueId()) .Op("_nccl_logical_all_gather_noncontinuous") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", {SbpToString(src_sbp)}) .Attr>("dst_reduced_nd_sbp", {SbpToString(dst_sbp)}) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if (src_sbp.has_split_parallel() && dst_sbp.has_split_parallel() && src_sbp.split_parallel().axis() != dst_sbp.split_parallel().axis() && 
CanSplitAtDim(src_sbp.split_parallel().axis()) && CanSplitAtDim(dst_sbp.split_parallel().axis())) { // S(in)->S(out) : All2All *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-S2S-" + NewUniqueId()) .Op("_nccl_logical_s2s") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", {SbpToString(src_sbp)}) .Attr>("dst_reduced_nd_sbp", {SbpToString(dst_sbp)}) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if (CanSplitAtDim(dst_sbp.split_parallel().axis()) && (src_sbp.has_partial_sum_parallel() && dst_sbp.has_split_parallel()) && (dst_sbp.split_parallel().axis() > 0)) { // P->S(1) : ReduceScatter Noncontinuous *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-P2S-" + NewUniqueId()) .Op("_nccl_logical_reduce_scatter_noncontinuous") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", {SbpToString(src_sbp)}) .Attr>("dst_reduced_nd_sbp", {SbpToString(dst_sbp)}) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if (!dst_sbp.has_partial_sum_parallel()) { *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(Send)2(Recv)-" + NewUniqueId()) .Op("_nccl_logical_send_recv") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", {SbpToString(src_sbp)}) .Attr>("dst_reduced_nd_sbp", {SbpToString(dst_sbp)}) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } return false; } bool TryBuildNcclBy2DHierarchySameDim0(OperatorConf* ret, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const std::shared_ptr& hierarchy, const std::string& lbn, const int64_t scope_symbol_id, const BlobDesc& logical_blob_desc) { CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2); CHECK(src_nd_sbp.sbp_parallel(0) == dst_nd_sbp.sbp_parallel(0)); const SbpParallel& src_dim1_sbp = src_nd_sbp.sbp_parallel(1); const SbpParallel& dst_dim1_sbp = dst_nd_sbp.sbp_parallel(1); // split when dim0 sbp is split parallel DimVector dim_vec = logical_blob_desc.shape().dim_vec(); if (src_nd_sbp.sbp_parallel(0).has_split_parallel()) { const int64_t axis = src_nd_sbp.sbp_parallel(0).split_parallel().axis(); dim_vec.at(axis) /= hierarchy->At(0); } const int64_t num_ranks = hierarchy->At(1); if (src_dim1_sbp.has_partial_sum_parallel() && dst_dim1_sbp.has_broadcast_parallel()) { // (*, P)->(*, B) : AllReduce *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(*P)2(*B)-" + NewUniqueId()) .Op("_nccl_logical_2D_same_dim0_all_reduce") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", NdSbpToStringList(src_nd_sbp)) .Attr>("dst_reduced_nd_sbp", NdSbpToStringList(dst_nd_sbp)) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if ((src_dim1_sbp.has_split_parallel() && dst_dim1_sbp.has_broadcast_parallel()) && (src_dim1_sbp.split_parallel().axis() == 0) && (dim_vec.at(0) % num_ranks == 0)) { // (*, S(0)) -> (*, B) : AllGather *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(*S0)2(*B)-" + NewUniqueId()) .Op("_nccl_logical_2D_same_dim0_all_gather") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", NdSbpToStringList(src_nd_sbp)) .Attr>("dst_reduced_nd_sbp", NdSbpToStringList(dst_nd_sbp)) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if (src_dim1_sbp.has_split_parallel() && dst_dim1_sbp.has_broadcast_parallel() && (src_dim1_sbp.split_parallel().axis() > 0) && (dim_vec.at(src_dim1_sbp.split_parallel().axis()) % num_ranks == 0)) { // (*, S(1)) -> (*, B) : AllGather Noncontinuous 
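// dim_vec already has the dim0 split divided out (see above), so the `% num_ranks` divisibility
// checks apply to the per-dim0-slice shape seen by each dim1 group.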
*ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(*S1)2(*B)-" + NewUniqueId()) .Op("_nccl_logical_2D_same_dim0_all_gather_noncontinuous") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", NdSbpToStringList(src_nd_sbp)) .Attr>("dst_reduced_nd_sbp", NdSbpToStringList(dst_nd_sbp)) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } else if ((src_dim1_sbp.has_split_parallel() && dst_dim1_sbp.has_split_parallel()) && (src_dim1_sbp.split_parallel().axis() != dst_dim1_sbp.split_parallel().axis()) && (dim_vec.at(src_dim1_sbp.split_parallel().axis()) % num_ranks == 0) && (dim_vec.at(dst_dim1_sbp.split_parallel().axis()) % num_ranks == 0)) { // (*, S(src_split_axis)) -> (*, S(dst_split_axis)) : All2All *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(*S)2(*S)-" + NewUniqueId()) .Op("_nccl_logical_2D_same_dim0_all2all") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", NdSbpToStringList(src_nd_sbp)) .Attr>("dst_reduced_nd_sbp", NdSbpToStringList(dst_nd_sbp)) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } return false; } bool TryBuildNcclBy2DHierarchySameDim1(OperatorConf* ret, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const std::shared_ptr& hierarchy, const std::string& lbn, const int64_t scope_symbol_id, const BlobDesc& logical_blob_desc) { CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2); CHECK(src_nd_sbp.sbp_parallel(1) == dst_nd_sbp.sbp_parallel(1)); const SbpParallel& src_dim1_sbp = src_nd_sbp.sbp_parallel(0); const SbpParallel& dst_dim1_sbp = dst_nd_sbp.sbp_parallel(0); if (src_dim1_sbp.has_partial_sum_parallel() && dst_dim1_sbp.has_broadcast_parallel()) { // (P, *) -> (B, *) : AllReduce *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(P*)2(B*)-" + NewUniqueId()) .Op("_nccl_logical_2D_same_dim1_all_reduce") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", NdSbpToStringList(src_nd_sbp)) .Attr>("dst_reduced_nd_sbp", NdSbpToStringList(dst_nd_sbp)) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } return false; } bool TryBuildNcclBy2DHierarchyOthers(OperatorConf* ret, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const std::shared_ptr& hierarchy, const std::string& lbn, const int64_t scope_symbol_id, const BlobDesc& logical_blob_desc) { CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2); // send recv is dealing with same 0-Dim VLOG_IF(3, src_nd_sbp.sbp_parallel(0) == dst_nd_sbp.sbp_parallel(0)) << "send recv is dealing with same 0-Dim, src sbp " << NdSbpToString(src_nd_sbp) << ", dst sbp " << NdSbpToString(dst_nd_sbp); // send recv is dealing with same 1-Dim, such as (B, S0) -> (S0, S0) VLOG_IF(3, ((src_nd_sbp.sbp_parallel(1) == dst_nd_sbp.sbp_parallel(1)) && !(NdSbpAllSameSplitParallel(src_nd_sbp) || NdSbpAllSameSplitParallel(dst_nd_sbp)))) << "send recv is dealing with same 1-Dim, src sbp " << NdSbpToString(src_nd_sbp) << ", dst sbp " << NdSbpToString(dst_nd_sbp); // send recv can not dealing with P in dst_nd_sbp if (NdSbpHasPartialParallel(dst_nd_sbp)) return false; *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(Send)2(Recv)-" + NewUniqueId()) .Op("_nccl_logical_send_recv") .Input("in", lbn) .Output("out") .Attr>("src_reduced_nd_sbp", NdSbpToStringList(src_nd_sbp)) .Attr>("dst_reduced_nd_sbp", NdSbpToStringList(dst_nd_sbp)) .ScopeSymbolId(scope_symbol_id) .Build() .op_conf(); return true; } Maybe 
BuildScopeWithReducedParallelDesc(int64_t old_scope_symbol_id, const ParallelDesc& parallel_desc) { auto* scope_storage = Singleton<symbol::Storage<Scope>>::Get(); CHECK_OR_RETURN(scope_storage->Has(old_scope_symbol_id)); auto old_scope = scope_storage->GetPtr(old_scope_symbol_id); std::shared_ptr<Scope> new_scope; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> { new_scope = JUST(builder->BuildScopeWithNewParallelConf(old_scope, parallel_desc.parallel_conf())); return Maybe<void>::Ok(); })); // NOTE(chengcheng): need to sync the vm to get the scope right now JUST(vm::CurrentRankSync()); CHECK_OR_RETURN(new_scope); return JUST(new_scope->symbol_id()); } bool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const OpNode* dst_node, const LogicalBlobId& lbi, ParallelDesc* src_reduced_parallel_desc, ParallelDesc* dst_reduced_parallel_desc, NdSbp* src_reduced_nd_sbp, NdSbp* dst_reduced_nd_sbp) { if (!src_node->op().op_conf().has_scope_symbol_id()) { return false; /* device_tick */ } const std::string lbn = GenLogicalBlobName(lbi); const BlobDesc& logical_blob_desc = src_node->LogicalBlobDesc4Lbi(lbi); // reduce hierarchy InOutParallelDimReduce(src_node->parallel_desc(), dst_node->parallel_desc(), src_node->NdSbp4Lbi(lbi), dst_node->NdSbp4Lbi(lbi), src_reduced_parallel_desc, dst_reduced_parallel_desc, src_reduced_nd_sbp, dst_reduced_nd_sbp, logical_blob_desc.shape()); CHECK_EQ(src_reduced_parallel_desc->parallel_num(), dst_reduced_parallel_desc->parallel_num()); std::shared_ptr<Shape> src_reduced_hierarchy = src_reduced_parallel_desc->hierarchy(); std::shared_ptr<Shape> dst_reduced_hierarchy = dst_reduced_parallel_desc->hierarchy(); if ((*src_reduced_hierarchy) == (*dst_reduced_hierarchy) && (*src_reduced_nd_sbp) == (*dst_reduced_nd_sbp)) { // one to one return false; } // NOTE(chengcheng): nccl does not support dynamic shape.
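// (A dynamic-shape blob therefore falls back to the generic boxing path instead of a NCCL logical op.)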
if (logical_blob_desc.is_dynamic()) { return false; } CHECK_GT(logical_blob_desc.shape().elem_cnt(), 0) << dst_node->op().op_name() << " consume " << GenLogicalBlobName(lbi) << ", " << *CHECK_JUST(PlacementToString(*src_reduced_parallel_desc)) << " " << NdSbpToString(*src_reduced_nd_sbp) << " -> " << *CHECK_JUST(PlacementToString(*dst_reduced_parallel_desc)) << " " << NdSbpToString(*dst_reduced_nd_sbp); int64_t scope_symbol_id = CHECK_JUST(BuildScopeWithReducedParallelDesc( src_node->op().op_conf().scope_symbol_id(), *src_reduced_parallel_desc)); if (src_reduced_hierarchy->NumAxes() == 1 && dst_reduced_hierarchy->NumAxes() == 1) { return TryBuildNcclBy1DHierarchy(ret, src_reduced_nd_sbp->sbp_parallel(0), dst_reduced_nd_sbp->sbp_parallel(0), lbn, scope_symbol_id, logical_blob_desc, src_reduced_parallel_desc->parallel_num()); } else if (src_reduced_hierarchy->NumAxes() == 2 && (*src_reduced_hierarchy == *dst_reduced_hierarchy)) { bool got_nccl = false; if (src_reduced_nd_sbp->sbp_parallel(0) == dst_reduced_nd_sbp->sbp_parallel(0)) { // TODO(): same dim 0 need to deal with (*, P) -> (*, S) got_nccl = TryBuildNcclBy2DHierarchySameDim0(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, src_reduced_hierarchy, lbn, scope_symbol_id, logical_blob_desc); } else if (src_reduced_nd_sbp->sbp_parallel(1) == dst_reduced_nd_sbp->sbp_parallel(1)) { if (!(NdSbpAllSameSplitParallel(*src_reduced_nd_sbp) || NdSbpAllSameSplitParallel(*dst_reduced_nd_sbp))) { got_nccl = TryBuildNcclBy2DHierarchySameDim1(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, src_reduced_hierarchy, lbn, scope_symbol_id, logical_blob_desc); } } if (!got_nccl) { got_nccl = TryBuildNcclBy2DHierarchyOthers(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, src_reduced_hierarchy, lbn, scope_symbol_id, logical_blob_desc); } VLOG_IF(3, !got_nccl) << "Cannot get nccl logical op for 2D sbp, src nd sbp " << NdSbpToString(*src_reduced_nd_sbp) << ", dst nd sbp " << NdSbpToString(*dst_reduced_nd_sbp) << "."; return got_nccl; } return false; } void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode( HashMap* subgraph_op_name2conf, HashSet* mut_op_names, std::vector* nccl_op_confs, std::vector* nccl_op_parallel_confs, const std::vector& subgraph_ordered_nodes, const HashMap& node2subgraph_order) { for (const OpNode* dst_node : subgraph_ordered_nodes) { const std::string& dst_op_name = dst_node->op().op_name(); for (const OpEdge* op_edge : dst_node->in_edges()) { const OpNode* src_node = op_edge->src_node(); const std::string& src_op_name = src_node->op().op_name(); CHECK(src_node != dst_node); if (src_node->parallel_desc().EqualsIgnoringHierarchy(dst_node->parallel_desc())) { // NOTE(chengcheng): We don't care src node whether in this subgraph, or whether is repeat // op, or whether is breaking op. We ONLY care src node is same placement with dst. // So, we can handle both ZeRO from variable and in GradAcc from repeat and in Pipeline. 
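        // Illustrative effect of the loop below (op names hypothetical): if dst op "relu1"
        // consumed "matmul0/out_0" with a mismatched nd-sbp, after the rewrite it consumes the
        // inserted nccl op's output (named from kNcclLogicalOpNamePrefix plus NewUniqueId())
        // instead, and the nccl op itself is added with the dst node's *reduced* parallel_conf
        // so the compiler's hierarchy check passes.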
for (const LogicalBlobId& lbi : op_edge->lbis()) { OperatorConf nccl_op; ParallelDesc src_reduced_parallel_desc = op_edge->src_node()->parallel_desc(); ParallelDesc dst_reduced_parallel_desc = op_edge->dst_node()->parallel_desc(); NdSbp src_reduced_nd_sbp; NdSbp dst_reduced_nd_sbp; if (!TryBuildNcclLogicalOpConf(&nccl_op, src_node, dst_node, lbi, &src_reduced_parallel_desc, &dst_reduced_parallel_desc, &src_reduced_nd_sbp, &dst_reduced_nd_sbp)) { continue; } mut_op_names->insert(dst_op_name); // insert nccl op user_op::UserOpConfWrapper nccl_op_wrapper(nccl_op); for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) { std::string old_lbn = ReplaceInputLbnInOpCustomizedConf( &subgraph_op_name2conf->at(dst_op_name), ibn, nccl_op_wrapper.output("out", 0)); CHECK(old_lbn == GenLogicalBlobName(lbi)); } // NOTE(chengcheng): Do NOT add ctrl edge for nccl fusion. nccl_op_confs->emplace_back(nccl_op); // NOTE(chengcheng, guoran): set nccl op as dst_node parallel_conf (hierarchy) may check // failed in complier, so need use dst_node reduced_parallel_conf. nccl_op_parallel_confs->emplace_back(dst_reduced_parallel_desc.parallel_conf()); VLOG(2) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name << "] to [" << dst_op_name << "]\n"; } } } } } void GenAfterAccSubgraph(std::vector* ordered_after_acc_subgraph, const HashMap& op_node2global_order, const std::vector& ordered_acc_op_nodes) { std::shared_ptr seed_time_shape = std::make_shared(Shape({1, 1})); const ParallelDesc& seed_parallel_desc = ordered_acc_op_nodes.front()->parallel_desc(); HashSet visited; std::queue queued_nodes; auto SearchToNextNode = [&](const OpNode* cur_node, const OpNode* next_node, const OpEdge* edge) { if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape)) { CHECK(visited.insert(next_node).second); queued_nodes.push(next_node); } }; auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) { return op_node2global_order.at(lhs) < op_node2global_order.at(rhs); }; for (const OpNode* acc_node : ordered_acc_op_nodes) { for (const OpEdge* out_edge : acc_node->out_edges()) { const OpNode* seed_node = out_edge->dst_node(); SearchToNextNode(acc_node, seed_node, out_edge); } } while (!queued_nodes.empty()) { const OpNode* cur_node = queued_nodes.front(); queued_nodes.pop(); ordered_after_acc_subgraph->push_back(cur_node); for (const OpEdge* in_edge : cur_node->in_edges()) { SearchToNextNode(cur_node, in_edge->src_node(), in_edge); } for (const OpEdge* out_edge : cur_node->out_edges()) { SearchToNextNode(cur_node, out_edge->dst_node(), out_edge); } } std::sort(ordered_after_acc_subgraph->begin(), ordered_after_acc_subgraph->end(), CmpOpNodeOrder); } struct InsertNcclSubGraph { std::vector ordered_op_nodes; int64_t begin_op_global_order; int64_t end_op_global_order; const OpNode* begin_op; const OpNode* end_op; }; struct PlacementNcclSubGraghsInfo { std::vector> ordered_subgraph; std::vector ordered_acc_op_nodes; const ParallelDesc* seed_parallel_desc; }; void InitInsertNcclSubGraphInfoFromSet( std::shared_ptr nccl_subgraph_info, const HashSet& subgraph, const HashMap& op_node2global_order, const std::function& CmpOpNodeOrder) { auto* subgraph_ordered_nodes = &nccl_subgraph_info->ordered_op_nodes; subgraph_ordered_nodes->assign(subgraph.begin(), subgraph.end()); std::sort(subgraph_ordered_nodes->begin(), subgraph_ordered_nodes->end(), 
CmpOpNodeOrder); nccl_subgraph_info->begin_op = subgraph_ordered_nodes->front(); nccl_subgraph_info->end_op = subgraph_ordered_nodes->back(); nccl_subgraph_info->begin_op_global_order = op_node2global_order.at(nccl_subgraph_info->begin_op); nccl_subgraph_info->end_op_global_order = op_node2global_order.at(nccl_subgraph_info->end_op); CHECK(nccl_subgraph_info->begin_op != nccl_subgraph_info->end_op); CHECK_LT(nccl_subgraph_info->begin_op_global_order, nccl_subgraph_info->end_op_global_order); } std::string GetStreamIndexName(uint32_t id) { return "NCCL_COMPUTE_" + std::to_string(id); } int64_t InsertNcclLogicalOpsInSubGraph(const OpGraph& op_graph, JobBuilder* job_builder, const std::vector& subgraph_ordered_nodes, int64_t* nccl_compute_stream_id, const int64_t logical_chain_id) { HashMap node2subgraph_order; node2subgraph_order.reserve(subgraph_ordered_nodes.size()); for (int64_t i = 0; i < subgraph_ordered_nodes.size(); ++i) { CHECK(node2subgraph_order.emplace(subgraph_ordered_nodes.at(i), i).second); } VLOG(3) << " ======================================================================== \n" << " Try insert nccl logical ops into Graph: " << job_builder->job().job_conf().job_name() << " , logical_chain: " << logical_chain_id << ". Begin...\n"; HashSet mut_op_names; HashMap subgraph_op_name2conf; for (const OpNode* this_node : subgraph_ordered_nodes) { VLOG(3) << "logical_chain: " << logical_chain_id << " , op: " << this_node->op().op_name(); CHECK( subgraph_op_name2conf.emplace(this_node->op().op_name(), this_node->op().op_conf()).second); } std::vector nccl_op_confs; std::vector nccl_op_parallel_confs; // NOTE(chengcheng): ONLY support insert nccl to dst for memory. InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(&subgraph_op_name2conf, &mut_op_names, &nccl_op_confs, &nccl_op_parallel_confs, subgraph_ordered_nodes, node2subgraph_order); VLOG(3) << " ======================================================================== \n" << " Try insert nccl logical ops into Graph: " << job_builder->job().job_conf().job_name() << " , logical_chain: " << logical_chain_id << ". End.\n"; // NOTE(chengcheng): For NCCL logical correct exec order in pipeline multi-subgraph. if (nccl_op_confs.empty()) { return 0; } const int64_t max_nccl_stream_count = EnvInteger(); if ((*nccl_compute_stream_id) >= max_nccl_stream_count) { return 0; // NOTE(chengcheng): ONLY support kMaxNcclComputeStreamCount insert nccl subgraphs. } std::string stream_index_name = GetStreamIndexName(*nccl_compute_stream_id); // NOTE(chengcheng): ONLY valid subgraph will increase nccl stream id. (*nccl_compute_stream_id)++; // NOTE(chengcheng): set ALL subgraph op and ALL nccl op stream index and logical chain id. 
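  // Illustration (values assumed for the example): with *nccl_compute_stream_id == 0 and
  // logical_chain_id == 7, every op conf in this subgraph and every inserted nccl op gets
  //   stream_name_hint = "NCCL_COMPUTE_0"   (from GetStreamIndexName)
  //   logical_chain_id = 7
  // so the whole chain and its nccl ops are pinned to one dedicated compute stream.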
for (auto& pair : subgraph_op_name2conf) { mut_op_names.insert(pair.first); pair.second.set_stream_name_hint(stream_index_name); pair.second.set_logical_chain_id(logical_chain_id); } for (auto& nccl_op : nccl_op_confs) { nccl_op.set_stream_name_hint(stream_index_name); nccl_op.set_logical_chain_id(logical_chain_id); } std::vector mut_op_confs; mut_op_confs.reserve(mut_op_names.size()); for (const std::string& mut_op_name : mut_op_names) { mut_op_confs.emplace_back(subgraph_op_name2conf.at(mut_op_name)); } job_builder->MutOpsOnlyOnce(mut_op_confs); CHECK_EQ(nccl_op_confs.size(), nccl_op_parallel_confs.size()); for (int64_t i = 0; i < nccl_op_confs.size(); ++i) { CHECK_JUST(job_builder->AddOp(nccl_op_parallel_confs.at(i), nccl_op_confs.at(i))); } VLOG(3) << " In logical chain id: " << logical_chain_id << " insert nccl op num = " << nccl_op_confs.size() << " and origin chain op num = " << subgraph_ordered_nodes.size(); return nccl_op_confs.size() + subgraph_ordered_nodes.size(); } void InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph, JobBuilder* job_builder, const std::vector& ordered_acc_op_nodes, const HashMap& op_node2global_order, const int64_t nccl_compute_stream_id, const int64_t logical_chain_id) { // insert nccl ops after acc std::vector ordered_after_acc_subgraph; GenAfterAccSubgraph(&ordered_after_acc_subgraph, op_node2global_order, ordered_acc_op_nodes); if (ordered_after_acc_subgraph.size() <= 1) { return; } HashMap node2subgraph_order; node2subgraph_order.reserve(ordered_after_acc_subgraph.size()); for (int64_t i = 0; i < ordered_after_acc_subgraph.size(); ++i) { CHECK(node2subgraph_order.emplace(ordered_after_acc_subgraph.at(i), i).second); } std::vector after_acc_nccl_op_confs; std::vector after_acc_nccl_parallel_confs; HashSet mut_op_names; HashMap acc_subgraph_op_name2conf; for (const OpNode* this_node : ordered_after_acc_subgraph) { CHECK(acc_subgraph_op_name2conf.emplace(this_node->op().op_name(), this_node->op().op_conf()) .second); VLOG(3) << "After Acc logical_chain: " << logical_chain_id << " , op: " << this_node->op().op_name(); } InsertNcclLogicalOpsAsCloseAsPossibleToDstNode( &acc_subgraph_op_name2conf, &mut_op_names, &after_acc_nccl_op_confs, &after_acc_nccl_parallel_confs, ordered_after_acc_subgraph, node2subgraph_order); if (after_acc_nccl_op_confs.empty()) { CHECK(after_acc_nccl_parallel_confs.empty()); CHECK(mut_op_names.empty()); } else { std::string stream_index_name = GetStreamIndexName(nccl_compute_stream_id); // set logical chain id and stream name for ops after acc for (auto& pair : acc_subgraph_op_name2conf) { mut_op_names.insert(pair.first); pair.second.set_stream_name_hint(stream_index_name); pair.second.set_logical_chain_id(logical_chain_id); } for (auto& nccl_op : after_acc_nccl_op_confs) { nccl_op.set_stream_name_hint(stream_index_name); nccl_op.set_logical_chain_id(logical_chain_id); } // insert nccl ops after acc std::vector mut_op_confs; mut_op_confs.reserve(mut_op_names.size()); for (const std::string& mut_op_name : mut_op_names) { mut_op_confs.emplace_back(acc_subgraph_op_name2conf.at(mut_op_name)); } job_builder->MutOpsOnlyOnce(mut_op_confs); CHECK_EQ(after_acc_nccl_op_confs.size(), after_acc_nccl_parallel_confs.size()); for (int64_t i = 0; i < after_acc_nccl_op_confs.size(); ++i) { CHECK_JUST( job_builder->AddOp(after_acc_nccl_parallel_confs.at(i), after_acc_nccl_op_confs.at(i))); } } } Maybe InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { std::vector ordered_op_nodes; if 
(ParseBooleanFromEnv("DISABLE_LOGICAL_STRAIGHTEN", false)) { op_graph.TopoForEachNodeWithCtrlEdge( [&](const OpNode* node) { ordered_op_nodes.emplace_back(node); }); } else { auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes); } HashMap op_node2global_order; for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) { op_node2global_order.emplace(ordered_op_nodes[global_order], global_order); } std::vector> subgraph_list; FindAllConnectedSubgraphForGpuExecOrder(&subgraph_list, op_graph, ordered_op_nodes); if (subgraph_list.size() == 0) { return Maybe::Ok(); } // sign subgraph ops logical chain id for merge. int64_t global_logical_chain_id = 0; auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) { return op_node2global_order.at(lhs) < op_node2global_order.at(rhs); }; auto CmpSubGraphOrder = [&](const std::shared_ptr& lhs, const std::shared_ptr& rhs) { int64_t lhs_begin_op_global_order = op_node2global_order.at(lhs->ordered_op_nodes.front()); int64_t rhs_begin_op_global_order = op_node2global_order.at(rhs->ordered_op_nodes.front()); return lhs_begin_op_global_order < rhs_begin_op_global_order; }; HashMap placement2subgraphs; for (const auto& subgraph : subgraph_list) { const OpNode* rand_node = *subgraph.begin(); const ParallelDesc& this_parallel_desc = rand_node->parallel_desc(); std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf()); auto it = placement2subgraphs.find(key); if (it == placement2subgraphs.end()) { it = placement2subgraphs.emplace(key, PlacementNcclSubGraghsInfo()).first; it->second.seed_parallel_desc = &this_parallel_desc; } else { CHECK(this_parallel_desc.EqualsIgnoringHierarchy(*it->second.seed_parallel_desc)); } auto& info = it->second; info.ordered_subgraph.emplace_back(std::make_shared()); InitInsertNcclSubGraphInfoFromSet(info.ordered_subgraph.back(), subgraph, op_node2global_order, CmpOpNodeOrder); } for (auto& pair : placement2subgraphs) { std::sort(pair.second.ordered_subgraph.begin(), pair.second.ordered_subgraph.end(), CmpSubGraphOrder); } for (const OpNode* this_node : ordered_op_nodes) { if (IsAccOpNode(this_node)) { const ParallelDesc& this_parallel_desc = this_node->parallel_desc(); std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf()); auto it = placement2subgraphs.find(key); if (it != placement2subgraphs.end()) { it->second.ordered_acc_op_nodes.emplace_back(this_node); } } } for (auto& pair : placement2subgraphs) { PlacementNcclSubGraghsInfo& info = pair.second; // NOTE(chengcheng): insert nccl ops for each subgraph int64_t stream_offset = 0; int64_t total_op_num = 0; for (int i = 0; i < info.ordered_subgraph.size(); i++) { auto& ordered_op_nodes = info.ordered_subgraph.at(i)->ordered_op_nodes; int64_t this_op_num = InsertNcclLogicalOpsInSubGraph( op_graph, job_builder, ordered_op_nodes, &stream_offset, global_logical_chain_id++); total_op_num += this_op_num; } if (stream_offset >= 2 && total_op_num >= 1000) { LOG(WARNING) << " In Graph: " << job_builder->job().job_conf().job_name() << " Placement: " << pair.first << " the total_op_num = " << total_op_num << " and has " << stream_offset << " different nccl stream which is possible to trigger cuda stream kernel " "launch upper limit." 
<< " So the nccl logical kernel will from async to sync exec, which may affect " "performance."; EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); comm_mgr->SetAsyncLaunchCclLogicalKernel(false); } // NOTE(chengcheng): insert acc for all subgraph with same placement group if (!info.ordered_acc_op_nodes.empty()) { InsertNcclLogicalOpsAfterAcc(op_graph, job_builder, info.ordered_acc_op_nodes, op_node2global_order, stream_offset++, global_logical_chain_id++); } } return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("InsertNcclLogicalOpPass", InsertNcclLogicalOpPass); } // namespace oneflow #endif // WITH_CUDA || WITH_NPU || defined(WITH_MLU) ================================================ FILE: oneflow/core/job_rewriter/insert_pinned_identity_op_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { namespace { class InsertPinnedIdentityOpPass final : public JobPass { public: InsertPinnedIdentityOpPass() = default; ~InsertPinnedIdentityOpPass() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe InsertPinnedIdentityOp(JobBuilder* job_builder, const OpGraph& op_graph, const std::string& lbn) { auto lbi = GenLogicalBlobId(lbn); const OpNode* node = op_graph.OpNode4OpName(lbi.op_name()); auto pinned_identity_op = user_op::UserOpConfWrapperBuilder(lbi.op_name() + "_" + lbi.blob_name() + "_pinned_identity") .Op("pinned_identity") .Input("in", lbn) .Output("out") .ScopeSymbolId(node->op().op_conf().scope_symbol_id()) .Build(); const auto& parallel_conf = node->parallel_desc().parallel_conf(); job_builder->AddOps(parallel_conf, {pinned_identity_op.op_conf()}); node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == lbi) { if (!CHECK_JUST(job_builder->IsInMutOpTransaction(out_node->op().op_name()))) { CHECK_JUST(job_builder->MutOpTransactionMut(out_node->op().op_conf())); } OperatorConf& mut_consumer_op = CHECK_JUST(job_builder->MutOpTransactionGet(out_node->op().op_name())); const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf( &mut_consumer_op, ibn, pinned_identity_op.output("out", 0)); CHECK_EQ(old_lbn, GenLogicalBlobName(lbi)); } } }); return pinned_identity_op.output("out", 0); } Maybe InsertPinnedIdentityOpPass::Apply(Job* job, JobPassCtx* ctx) const { if (!ctx->job_desc().IsTrain()) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); HashMap pinned_lbns; TrainConf* train_conf = job->mutable_job_conf()->mutable_train_conf(); // insert after loss for (int i = 0; i < train_conf->loss_lbn_size(); ++i) { const auto& loss_lbn = train_conf->loss_lbn(i); auto it = pinned_lbns.find(loss_lbn); if (it == pinned_lbns.end()) { const auto& pinned_loss_lbn = JUST(InsertPinnedIdentityOp(&job_builder, op_graph, loss_lbn)); it = 
pinned_lbns.emplace(loss_lbn, *pinned_loss_lbn).first; } train_conf->set_loss_lbn(i, it->second); } // insert after loss initial gradient for (int i = 0; i < train_conf->loss_grad_lbn_size(); ++i) { const auto& loss_grad_lbn = train_conf->loss_grad_lbn(i); auto it = pinned_lbns.find(loss_grad_lbn); if (it == pinned_lbns.end()) { const auto& pinned_loss_grad_lbn = JUST(InsertPinnedIdentityOp(&job_builder, op_graph, loss_grad_lbn)); it = pinned_lbns.emplace(loss_grad_lbn, *pinned_loss_grad_lbn).first; } train_conf->set_loss_grad_lbn(i, it->second); } // insert after variable gradient for (int i = 0; i < train_conf->optimizer_conf_size(); ++i) { auto* optimizer_conf = train_conf->mutable_optimizer_conf(i); for (int j = 0; j < optimizer_conf->variable_grad_lbns_size(); ++j) { const auto& variable_grad_lbn = optimizer_conf->variable_grad_lbns(j); if (variable_grad_lbn.empty()) { continue; } auto it = pinned_lbns.find(variable_grad_lbn); if (it == pinned_lbns.end()) { const auto& pinned_variable_grad_lbn = JUST(InsertPinnedIdentityOp(&job_builder, op_graph, variable_grad_lbn)); it = pinned_lbns.emplace(variable_grad_lbn, *pinned_variable_grad_lbn).first; } optimizer_conf->set_variable_grad_lbns(j, it->second); } } JUST(job_builder.MutOpTransactionCommit()); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("InsertPinnedIdentityOpPass", InsertPinnedIdentityOpPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/job_completer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_completer.h" #include "oneflow/core/framework/placed_nd_sbp.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/autograd.h" #include "oneflow/core/job_rewriter/autotick.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h" #include "oneflow/core/framework/config_def.h" #include "oneflow/core/job_rewriter/boxing_with_middle_nodes.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/common/cost_util.h" #include "oneflow/core/common/buffer_manager.h" namespace oneflow { namespace { Maybe CheckOpGraph(const OpGraph& op_graph) { JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { size_t in_cnt = 0; op_graph.ForEachDataAndCtrlInNode(op_node, [&](OpNode*) { ++in_cnt; }); if (in_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_wait_and_send_ids_conf()); } size_t out_cnt = 0; op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); if (out_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_callback_notify_conf()); } return Maybe::Ok(); })); return Maybe::Ok(); } Maybe CheckAndLogOpGraph(const Job& job) { auto op_graph = std::make_unique(job); // Check op graph. 
JUST(CheckOpGraph(*op_graph)); // Log op graph. if (Singleton::Get()->enable_debug_mode()) { const JobDesc& job_desc = GlobalJobDesc(); TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(job); op_graph->ToDotWithFilePath("optimized_dlnet_" + std::to_string(job_desc.job_id()) + "_op_graph.dot"); } return Maybe::Ok(); } Maybe WithOpGraphAndMutJob(Job* job, const std::function(const OpGraph&, Job*)>& Handler) { OpGraph op_graph(*job); JUST(Handler(op_graph, job)); return Maybe::Ok(); } Maybe WithOpGraphAndMutJobBuilder( Job* job, const std::function(const OpGraph&, JobBuilder*)>& Handler) { OpGraph op_graph(*job); JobBuilder job_builder(job); JUST(Handler(op_graph, &job_builder)); return Maybe::Ok(); } Maybe SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) { auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool { for (const std::string& bn : op.input_bns()) { if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; } } return false; }; auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); HashMap> op_conf2ctrl_in_op_names; JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe::Ok(); } if (op_node->out_edges().size() <= 1) { return Maybe::Ok(); } const Operator& variable_op = op_node->op(); const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn()); const OperatorConf* mutable_consumer = nullptr; std::vector naive_consumers; naive_consumers.reserve(op_node->out_edges().size()); for (OpEdge* edge : op_node->out_edges()) { const auto& op_conf = edge->dst_node()->op().op_conf(); if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) { CHECK_OR_RETURN(mutable_consumer == nullptr); mutable_consumer = &op_conf; } else { naive_consumers.emplace_back(&op_conf); } } if (mutable_consumer == nullptr) { return Maybe::Ok(); } for (const auto* fw_bw_op : naive_consumers) { op_conf2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name()); } return Maybe::Ok(); })); for (const auto& pair : op_conf2ctrl_in_op_names) { OperatorConf mut_mutable_consumer_op_conf(*pair.first); for (const auto& fw_bw_op_name : pair.second) { if (!IsReachable(fw_bw_op_name, mut_mutable_consumer_op_conf.name())) { mut_mutable_consumer_op_conf.add_ctrl_in_op_name(fw_bw_op_name); } } JUST(job_builder->MutOpOnlyOnce(mut_mutable_consumer_op_conf)); } return Maybe::Ok(); } } // namespace Maybe JobCompleter::Complete(Job* job) { const auto& job_name = job->job_conf().job_name(); JobPassCtx job_pass_ctx(GlobalJobDesc()); // NOTE(chengcheng): disable this pass for reduce boxing memory life cycle to memory cost. 
auto compile_tc = std::make_unique>(true, true); if (!Singleton::Get() ->resource() .disable_group_boxing_by_dst_parallel()) { JUST(WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel)); } compile_tc->Count("[GraphCompile]" + job_name + " GroupBoxingByDstParallel", 1, true); if (GlobalProcessCtx::WorldSize() > 1) { JUST(WithOpGraphAndMutJobBuilder(job, &BoxingWithMiddleNodes)); } compile_tc->Count("[GraphCompile]" + job_name + " BoxingWithMiddleNodes", 1, true); JUST(WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp)); compile_tc->Count("[GraphCompile]" + job_name + " SetCtrl", 1, true); // complete tick ops JUST(WithOpGraphAndMutJobBuilder(job, &AutoPrependTick)); compile_tc->Count("[GraphCompile]" + job_name + " AutoPrependTick", 1, true); JUST(WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape)); compile_tc->Count("[GraphCompile]" + job_name + " AddTickForTimeShape", 1, true); JUST(WithOpGraphAndMutJob(job, &MultiClientAutoSourceAndSinkTick)); compile_tc->Count("[GraphCompile]" + job_name + " AutoSourceAndSinkTick", 1, true); JUST(WithOpGraphAndMutJob(job, &MultiClientAutoInterfaceCriticalSectionTick)); compile_tc->Count("[GraphCompile]" + job_name + " CriticalSectionTick", 1, true); JUST(JobPass4Name("SystemOpFillJobNamePass")(job, &job_pass_ctx)); compile_tc->Count("[GraphCompile]" + job_name + " SystemOpFillJobNamePass", 1, true); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); compile_tc->Count("[GraphCompile]" + job_name + " DumpBlobParallelConfPass", 1, true); #if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU) if (Singleton::Get()->nccl_use_compute_stream()) { // NOTE(chengcheng): this pass need as last pass for insert correct op with nccl boxing. JUST(JobPass4Name("InsertNcclLogicalOpPass")(job, &job_pass_ctx)); compile_tc->Count("[GraphCompile]" + job_name + " InsertNcclLogicalOpPass", 1, true); // NOTE(chengcheng): must do this pass after InsertNcclLogicalOpPass for nccl op fusion and // add ctrl stirct order. JUST(JobPass4Name("NcclLogicalOpFusionPass")(job, &job_pass_ctx)); compile_tc->Count("[GraphCompile]" + job_name + " NcclLogicalOpFusionPass", 1, true); JUST(JobPass4Name("NcclLogicalChainStrictOrderPass")(job, &job_pass_ctx)); compile_tc->Count("[GraphCompile]" + job_name + " NcclLogicalChainStrictOrderPass", 1, true); // NOTE(chengcheng): Because insert new logical nccl op, MUST dump time shape, sbp again. JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); compile_tc->Count("[GraphCompile]" + job_name + " DumpBlobParallelConfPass", 1, true); } #endif // WITH_CUDA || WITH_NPU || WITH_MLU JUST(JobPass4Name("LogicalChainPass")(job, &job_pass_ctx)); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); JUST(CheckAndLogOpGraph(*job)); compile_tc->Count("[GraphCompile]" + job_name + " CheckAndLogOpGraph", 1, true); return Maybe::Ok(); } Maybe JobCompleter::UpdateSharedGraphForNewInput( Job* job, const std::function>(const std::string&)>& InputTensor4Name, const std::function(const std::string& shared_op_name)>& NewOp4SharedOpName) { // job is a copy from a shared graph. // The job name has already update in py nn.Graph. const auto& new_job_name = job->job_conf().job_name(); const auto& UpdateInputShape = [&InputTensor4Name](OperatorConf& op_conf) -> Maybe { // Input op needs to be updated with new input tensor. 
if (op_conf.has_input_conf()) { InputOpConf* input_conf = op_conf.mutable_input_conf(); InterfaceBlobConf* blob_conf = input_conf->mutable_blob_conf(); auto input_tensor = *JUST(InputTensor4Name(op_conf.name())); input_tensor->shape()->ToProto(blob_conf->mutable_shape()); blob_conf->set_data_type(input_tensor->dtype()->data_type()); } return Maybe::Ok(); }; const auto& UpdateAttr = [&NewOp4SharedOpName](OperatorConf& op_conf) -> Maybe { // Some op attributes need to be updated with the new traced graph. if (op_conf.has_user_conf()) { for (auto& pair : *op_conf.mutable_user_conf()->mutable_attr()) { const auto* new_op_conf = JUST(NewOp4SharedOpName(op_conf.name())); if (new_op_conf == nullptr) { continue; } CHECK_EQ_OR_RETURN(new_op_conf->user_conf().op_type_name(), op_conf.user_conf().op_type_name()) << " new op " << new_op_conf->DebugString() << " is not corresponding with " << op_conf.DebugString(); auto attr_iter = new_op_conf->user_conf().attr().find(pair.first); CHECK_OR_RETURN(attr_iter != new_op_conf->user_conf().attr().end()) << " There is not attr " << pair.first << " in new op " << new_op_conf->DebugString(); if (pair.second.has_at_shape()) { *pair.second.mutable_at_shape() = attr_iter->second.at_shape(); } else if (pair.second.has_at_double()) { pair.second.set_at_double(attr_iter->second.at_double()); } else if (pair.second.has_at_list_int64()) { pair.second.mutable_at_list_int64()->CopyFrom(attr_iter->second.at_list_int64()); } } } return Maybe::Ok(); }; const auto& UpdateBufferName = [&new_job_name](OperatorConf& op_conf) -> Maybe { // These operators' execution depends on new job name. #define UPDATE_JOB_NAME(op_conf_name) \ if (op_conf.has_##op_conf_name()) { \ op_conf.mutable_##op_conf_name()->set_job_name(new_job_name); \ } UPDATE_JOB_NAME(input_conf); UPDATE_JOB_NAME(output_conf); UPDATE_JOB_NAME(callback_notify_conf); UPDATE_JOB_NAME(wait_and_send_ids_conf); UPDATE_JOB_NAME(return_conf); #undef UPDATE_JOB_NAME // Critical section operators depend job_name related buffer_name. if (op_conf.has_critical_section_wait_tick_conf()) { const auto& buffer_name = op_conf.critical_section_wait_tick_conf().buffer_name(); if (buffer_name.rfind(kInputCriticalSectionWaitBufferNamePrefix, 0) == 0) { op_conf.mutable_critical_section_wait_tick_conf()->set_buffer_name( GetInputCriticalSectionWaitBufferName(new_job_name)); } else if (buffer_name.rfind(kOutputCriticalSectionWaitBufferNamePrefix, 0) == 0) { op_conf.mutable_critical_section_wait_tick_conf()->set_buffer_name( GetOutputCriticalSectionWaitBufferName(new_job_name)); } } if (op_conf.has_critical_section_callback_tick_conf()) { const auto& buffer_name = op_conf.critical_section_callback_tick_conf().buffer_name(); if (buffer_name.rfind(kInputCriticalSectionCallbackBufferNamePrefix, 0) == 0) { op_conf.mutable_critical_section_callback_tick_conf()->set_buffer_name( GetInputCriticalSectionCallbackBufferName(new_job_name)); } else if (buffer_name.rfind(kOutputCriticalSectionCallbackBufferNamePrefix, 0) == 0) { op_conf.mutable_critical_section_callback_tick_conf()->set_buffer_name( GetOutputCriticalSectionCallbackBufferName(new_job_name)); } } return Maybe::Ok(); }; // Update the job for new input. for (auto& op_conf : *job->mutable_net()->mutable_op()) { JUST(UpdateInputShape(op_conf)); JUST(UpdateAttr(op_conf)); JUST(UpdateBufferName(op_conf)); } // Use OpGraph init to infer all LogicalBlobDesc with the new input shape. 
  auto op_graph = std::make_unique<OpGraph>(*job);
  op_graph->DumpLogicalBlobDesc(job);
#ifdef WITH_CUTLASS
  // Warmup cutlass conv with new input shape.
  JobPassCtx job_pass_ctx(GlobalJobDesc());
  JUST(JobPass4Name("CutlassConvTuningWarmupPass")(job, &job_pass_ctx));
#endif  // WITH_CUTLASS
  JUST(CheckAndLogOpGraph(*job));
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/core/job_rewriter/job_completer.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_JOB_REWRITER_JOB_COMPLETER_H_
#define ONEFLOW_CORE_JOB_REWRITER_JOB_COMPLETER_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/framework/tensor.h"

namespace oneflow {

class JobCompleter final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(JobCompleter);
  JobCompleter() = default;
  ~JobCompleter() = default;

  static Maybe<void> Complete(Job* job);
  // The job is copied from a shared graph, it needs to be modified
  // for a new graph with different input.
  static Maybe<void> UpdateSharedGraphForNewInput(
      Job* job,
      const std::function<Maybe<std::shared_ptr<one::Tensor>>(const std::string&)>&
          InputTensor4Name,
      const std::function<Maybe<const OperatorConf*>(const std::string& shared_op_name)>&
          NewOp4SharedOpName);
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_JOB_REWRITER_JOB_COMPLETER_H_


================================================
FILE: oneflow/core/job_rewriter/job_pass.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job_rewriter/job_pass.h"

namespace oneflow {

namespace {

HashMap<std::string, const JobPass*>* PassName2JobPass() {
  static HashMap<std::string, const JobPass*> pass_name2job_pass;
  return &pass_name2job_pass;
}

}  // namespace

void RegisterJobPass(const std::string& pass_name, const JobPass* pass) {
  CHECK(PassName2JobPass()->emplace(pass_name, pass).second);
}

bool HasJobPass(const std::string& pass_name) {
  return PassName2JobPass()->find(pass_name) != PassName2JobPass()->end();
}

const JobPass& JobPass4Name(const std::string& pass_name) {
  const auto& iter = PassName2JobPass()->find(pass_name);
  CHECK(iter != PassName2JobPass()->end()) << "Cannot find job pass: " << pass_name;
  return *iter->second;
}

}  // namespace oneflow


================================================
FILE: oneflow/core/job_rewriter/job_pass.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_JOB_REWRITER_JOB_PASS_H_
#define ONEFLOW_CORE_JOB_REWRITER_JOB_PASS_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job_rewriter/pass_util.h"

namespace oneflow {

class JobPassCtx;

class JobPass {
 public:
  JobPass() = default;
  virtual ~JobPass() = default;

  Maybe<void> operator()(Job* job, JobPassCtx* ctx) const { return Apply(job, ctx); }
  virtual Maybe<void> Apply(Job* job, JobPassCtx* ctx) const = 0;
};

class JobPassState {
 public:
  virtual ~JobPassState() = default;

 protected:
  JobPassState() = default;
};

class JobPassCtx {
 public:
  JobPassCtx(const JobPassCtx&) = delete;
  JobPassCtx(JobPassCtx&&) = delete;
  JobPassCtx(const JobDesc& job_desc) : job_desc_(&job_desc) {}
  ~JobPassCtx() = default;

  const JobDesc& job_desc() const { return *job_desc_; }

  template<typename T>
  Maybe<const T&> GetState(const std::string& key) const {
    const auto& iter = key2state_.find(key);
    CHECK_OR_RETURN(iter != key2state_.end());
    const T* ptr = dynamic_cast<const T*>(iter->second.get());
    const auto& origin_obj = *iter->second;
    CHECK_NOTNULL_OR_RETURN(ptr) << typeid(origin_obj).name();
    return *ptr;
  }

  template<typename T>
  Maybe<T*> MutableState(const std::string& key) {
    const auto& iter = key2state_.find(key);
    CHECK_OR_RETURN(iter != key2state_.end());
    T* ptr = dynamic_cast<T*>(iter->second.get());
    const auto& origin_obj = *iter->second;
    CHECK_NOTNULL_OR_RETURN(ptr) << typeid(origin_obj).name();
    return ptr;
  }

  template<typename T>
  Maybe<bool> HasState(const std::string& key) const {
    const auto& iter = key2state_.find(key);
    return (iter != key2state_.end());
  }

  Maybe<void> ResetState(const std::string& key, std::unique_ptr<JobPassState>&& state) {
    if (!state) {
      key2state_.erase(key);
    } else {
      key2state_.emplace(key, std::move(state));
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> ResetState(const std::string& key) {
    key2state_.erase(key);
    return Maybe<void>::Ok();
  }

 private:
  const JobDesc* job_desc_;
  HashMap<std::string, std::unique_ptr<JobPassState>> key2state_;
};

#define REGISTER_JOB_PASS(pass_name, pass_type) COMMAND(RegisterJobPass(pass_name, new pass_type))

void RegisterJobPass(const std::string& pass_name, const JobPass* pass);
bool HasJobPass(const std::string& pass_name);
const JobPass& JobPass4Name(const std::string& pass_name);

}  // namespace oneflow

#endif  // ONEFLOW_CORE_JOB_REWRITER_JOB_PASS_H_


================================================
FILE: oneflow/core/job_rewriter/lamb_optm.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" namespace oneflow { struct BiasCorrectionFactorCacheKey { float beta = 1.0; ParallelConf parallel_conf; }; bool operator==(const BiasCorrectionFactorCacheKey& lhs, const BiasCorrectionFactorCacheKey& rhs); } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::BiasCorrectionFactorCacheKey& key) const { using namespace oneflow; return Hash(key.beta, key.parallel_conf); } }; } // namespace std namespace oneflow { // Forward declaration for bias correction factor class BiasCorrectionFactorState final : public JobPassState { public: BiasCorrectionFactorState() {} ~BiasCorrectionFactorState() override = default; std::string GetLbn(float beta, std::string bias_correction_name, ParallelConf parallel_conf, const std::function& BiasCorrectionFactorStateOp); private: HashMap key2lbn_; }; namespace { std::string GenVariableOutputLbn(const OperatorConf& op_conf) { CHECK(op_conf.has_variable_conf()); return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out()); } OperatorConf GenerateLAMBHelperVariableOpConf(const VariableOp& op, const std::string& name, const float initial_value) { OperatorConf helper_variable_op(op.op_conf()); helper_variable_op.set_name(op.op_name() + "-" + name); helper_variable_op.mutable_variable_conf()->set_out("out"); InitializerConf constant_initializer; constant_initializer.mutable_constant_conf()->set_value(initial_value); *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer; helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id()); return helper_variable_op; } void SetScalarShapeAndNdSbpConf(const ParallelDesc& parallel_desc, OperatorConf* op_conf) { op_conf->mutable_variable_conf()->mutable_shape()->clear_dim(); op_conf->mutable_variable_conf()->mutable_shape()->add_dim(1); op_conf->mutable_variable_conf()->clear_nd_sbp(); FOR_RANGE(int, i, 0, parallel_desc.hierarchy()->NumAxes()) { *op_conf->mutable_variable_conf()->add_nd_sbp() = "B"; } CHECK_NE(op_conf->name(), std::string("")); } void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); OperatorConf m_var = GenerateLAMBHelperVariableOpConf(*var_op, "m", 0.f); OperatorConf v_var = GenerateLAMBHelperVariableOpConf(*var_op, "v", 0.f); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {m_var, v_var}); user_op::UserOpConfWrapperBuilder lamb_update_op_builder(var_op->op_name() + "_optimizer"); const LambModelUpdateConf& lamb_conf = optimizer_conf.lamb_conf(); float beta1 = lamb_conf.beta1(); float beta2 = lamb_conf.beta2(); float epsilon = lamb_conf.epsilon(); bool do_bias_correction = lamb_conf.do_bias_correction(); const std::string& train_step_lbn = job_builder->job().job_conf().train_conf().train_step_lbn(); const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn(); if (do_bias_correction) { // Reuse adam bias_correction job pass const std::string& job_pass_state_key = "adam_bias_correction_factor"; const bool has_state = CHECK_JUST(ctx->HasState(job_pass_state_key)); if (!has_state) { CHECK_JUST( ctx->ResetState(job_pass_state_key, std::make_unique())); } auto* state = CHECK_JUST(ctx->MutableState(job_pass_state_key)); ParallelConf bias_correction_parallel_conf; const auto& lr_parallel_conf = 
CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(learning_rate_lbn))); const auto& train_step_parallel_conf = CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn))); if (lr_parallel_conf == train_step_parallel_conf) { bias_correction_parallel_conf = lr_parallel_conf; } else { bias_correction_parallel_conf = var_op_node.parallel_desc().parallel_conf(); } auto AddLambBiasCorrectionFactorOp = [&](float beta_val, const std::string& op_name) -> std::string { user_op::UserOpConfWrapperBuilder op_builder(var_op->op_name() + op_name); const auto lamb_bias_correction_factor_op = op_builder.OpTypeName("adam_bias_correction_factor") .Input("train_step", train_step_lbn) .Attr("beta", beta_val) .Output("out") .ScopeSymbolId(var_op->op_conf().scope_symbol_id()) .Build(); job_builder->AddOps(bias_correction_parallel_conf, {lamb_bias_correction_factor_op.op_conf()}); return lamb_bias_correction_factor_op.output("out", 0); }; const std::string bias_correction1_lbn = state->GetLbn(beta1, "lamb_bias_correction_factor1", bias_correction_parallel_conf, AddLambBiasCorrectionFactorOp); const std::string bias_correction2_lbn = state->GetLbn(beta2, "lamb_bias_correction_factor2", bias_correction_parallel_conf, AddLambBiasCorrectionFactorOp); lamb_update_op_builder.OpTypeName("lamb_update") .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) .Input("model_diff", model_diff_lbn) .Input("m", GenVariableOutputLbn(m_var)) .Input("v", GenVariableOutputLbn(v_var)) .Input("learning_rate", learning_rate_lbn) .Input("bias_correction1", bias_correction1_lbn) .Input("bias_correction2", bias_correction2_lbn) .Attr("beta1", beta1) .Attr("beta2", beta2) .Attr("epsilon", epsilon) .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) .Attr("do_bias_correction", true) .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); } else { lamb_update_op_builder.OpTypeName("lamb_update") .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) .Input("model_diff", model_diff_lbn) .Input("m", GenVariableOutputLbn(m_var)) .Input("v", GenVariableOutputLbn(v_var)) .Input("learning_rate", learning_rate_lbn) .Attr("beta1", beta1) .Attr("beta2", beta2) .Attr("epsilon", epsilon) .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) .Attr("do_bias_correction", false) .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); } if (optimizer_conf.has_lr_scale()) { lamb_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale()); } SetDynamicLossScaleSkipIf(ctx, &lamb_update_op_builder); const auto lamb_update_op = lamb_update_op_builder.Build(); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {lamb_update_op.op_conf()}); } } // namespace REGISTER_OPTIMIZER(OptimizerConf::kLambConf, &GenerateOptimizerOpConf); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/lars_optm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); const std::string momentum_var_op_name = var_op->op_name() + "-momentum"; OperatorConf momentum_var(var_op->op_conf()); InitializerConf constant_initializer; constant_initializer.mutable_constant_conf()->set_value(0.f); *(momentum_var.mutable_variable_conf()->mutable_initializer()) = constant_initializer; momentum_var.set_name(momentum_var_op_name); momentum_var.mutable_variable_conf()->set_out("out"); momentum_var.set_scope_symbol_id(var_op->op_conf().scope_symbol_id()); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {momentum_var}); user_op::UserOpConfWrapperBuilder lars_update_op_builder(var_op->op_name() + "_optimizer"); lars_update_op_builder.OpTypeName("lars_update") .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) .Input("model_diff", model_diff_lbn) .Input("learning_rate", optimizer_conf.learning_rate_lbn()) .Input("momentum", GenLogicalBlobName(momentum_var_op_name, momentum_var.variable_conf().out())) .Attr("momentum_beta", optimizer_conf.lars_conf().momentum_beta()) .Attr("epsilon", optimizer_conf.lars_conf().epsilon()) .Attr("lars_coefficient", optimizer_conf.lars_conf().lars_coefficient()) .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); if (optimizer_conf.has_lr_scale()) { lars_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale()); } SetDynamicLossScaleSkipIf(ctx, &lars_update_op_builder); user_op::UserOpConfWrapper lars_update_op = lars_update_op_builder.Build(); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {lars_update_op.op_conf()}); } } // namespace REGISTER_OPTIMIZER(OptimizerConf::kLarsConf, &GenerateOptimizerOpConf); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/logical_chain_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/env_var/env_var.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/common/container_util.h" namespace oneflow { DEFINE_ENV_BOOL(ENABLE_ACC_CHAIN_MERGE, true); namespace { class LogicalChainPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(LogicalChainPass); LogicalChainPass() = default; ~LogicalChainPass() = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } bool IsEnabled(const JobPassCtx& ctx) const { return EnableLogicalChain(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; bool IsTickOpConf(const OperatorConf& op_conf) { if (IsClassRegistered(op_conf.op_type_case())) { return true; } if (op_conf.has_user_conf()) { const std::string& user_type_name = op_conf.user_conf().op_type_name(); if (user_type_name == "cast_to_tick" || user_type_name == "acc_ctrl_tick") { return true; } } return false; } bool IsBreakpointOpNode(const OpNode* node) { // NOTE(chengcheng): breakpoint op is special which CANNOT merge in chain such as: // variable, tick, repeat/acc/pack/unpack change timeshape const Operator& op = node->op(); const OperatorConf& op_conf = op.op_conf(); // TODO(chengcheng): filter ops which has special type // TODO(chengcheng): get stream by op type if (op_conf.has_variable_conf() /* variable */ || IsTickOpConf(op_conf) /* tick */ || op_conf.has_input_conf() || op_conf.has_output_conf() /* io */ || op_conf.has_wait_and_send_ids_conf() || op_conf.has_callback_notify_conf() /* ctrl */ || op_conf.has_image_decoder_random_crop_resize_conf() /* gpu decode */) { return true; } if (op_conf.has_user_conf()) { const std::string& user_type_name = op_conf.user_conf().op_type_name(); if (user_type_name == "repeat" || user_type_name == "unpack" || user_type_name == "identity_buffer" || user_type_name == "copy_h2d" || user_type_name == "copy_d2h") { return true; } } return false; } bool IsAccOrPackOpNode(const OpNode* node) { const auto& op_conf = node->op().op_conf(); return op_conf.has_user_conf() && (op_conf.user_conf().op_type_name() == "acc" || op_conf.user_conf().op_type_name() == "pack"); } bool IsAccOpNode(const OpNode* node) { return node->op().op_conf().has_user_conf() && node->op().op_conf().user_conf().op_type_name() == "acc"; } bool IsRepeatOpNode(const OpNode* node) { return node->op().op_conf().has_user_conf() && node->op().op_conf().user_conf().op_type_name() == "repeat"; } std::shared_ptr GetOpNodeFastestTimeShape(const OpNode* op_node) { return CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape()); } std::shared_ptr GetOpNodeInputTimeShape(const OpNode* op_node) { return CHECK_JUST(op_node->op().GetInputBlobFastestTimeShape()); } bool SharedPtrShapeEqual(const 
std::shared_ptr& lhs, const std::shared_ptr& rhs) { return (*lhs) == (*rhs); } bool IsOpEdge121Connected(const OpNode* src_node, const OpNode* dst_node, const OpEdge* edge) { CHECK(src_node != dst_node && (edge->src_node() == src_node || edge->src_node() == dst_node) && (edge->dst_node() == src_node || edge->dst_node() == dst_node)); if (src_node->parallel_desc().parallel_num() != dst_node->parallel_desc().parallel_num()) { return false; } if (src_node->parallel_desc().parallel_num() == 1) { return true; } for (const auto& lbi : edge->lbis()) { // NOTE(chengcheng): nd_sbp need to be reduction like from [P, P] to [P] Shape src_reduced_hierarchy; Shape dst_reduced_hierarchy; NdSbp src_reduced_nd_sbp; NdSbp dst_reduced_nd_sbp; InOutParallelDimReduce(*src_node->parallel_desc().hierarchy(), *dst_node->parallel_desc().hierarchy(), src_node->NdSbp4Lbi(lbi), dst_node->NdSbp4Lbi(lbi), &src_reduced_hierarchy, &dst_reduced_hierarchy, &src_reduced_nd_sbp, &dst_reduced_nd_sbp, src_node->LogicalBlobDesc4Lbi(lbi).shape()); if (src_reduced_hierarchy != dst_reduced_hierarchy || src_reduced_nd_sbp != dst_reduced_nd_sbp) { // Not one to one return false; } } return true; } void GetLogicalChainsWithTimeShape(std::vector>* ret, const std::vector& order, const std::shared_ptr& seed_time_shape) { HashSet visited; for (const OpNode* seed_node : order) { if (visited.find(seed_node) != visited.end()) { continue; } CHECK(visited.insert(seed_node).second); const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc(); if (seed_node->op().op_conf().has_logical_chain_id()) { continue; } // TODO(chengcheng): support cpu chain. if (seed_parallel_desc.device_type() == DeviceType::kCPU) { continue; } if (!SharedPtrShapeEqual(GetOpNodeFastestTimeShape(seed_node), seed_time_shape)) { continue; } if (IsBreakpointOpNode(seed_node)) { continue; } HashSet this_subgraph; std::queue queued_nodes; queued_nodes.push(seed_node); while (!queued_nodes.empty()) { const OpNode* cur_node = queued_nodes.front(); queued_nodes.pop(); CHECK(cur_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)); CHECK(this_subgraph.insert(cur_node).second); auto SearchToNextNode = [&](const OpNode* cur_node, const OpNode* next_node, const OpEdge* edge) { if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) && (!next_node->op().op_conf().has_logical_chain_id()) /* skip logical chain id */ && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape) && next_node->op().op_conf().stream_name_hint() == seed_node->op().op_conf().stream_name_hint() && IsOpEdge121Connected(cur_node, next_node, edge)) { CHECK(visited.insert(next_node).second); queued_nodes.push(next_node); } }; for (const OpEdge* in_edge : cur_node->in_edges()) { SearchToNextNode(cur_node, in_edge->src_node(), in_edge); } for (const OpEdge* out_edge : cur_node->out_edges()) { SearchToNextNode(cur_node, out_edge->dst_node(), out_edge); } } if (this_subgraph.size() > 1) { ret->emplace_back(HashSet()); ret->back().swap(this_subgraph); } } } struct LogicalChain { int64_t logical_chain_id; std::vector ordered_op_nodes; explicit LogicalChain(int64_t val) : logical_chain_id(val) { CHECK_GE(val, 0); } }; struct PlacementLogicalChainsInfo { std::vector> ordered_logical_chains; std::vector ordered_acc_op_nodes; std::shared_ptr after_acc_logical_chain; const ParallelDesc* seed_parallel_desc; PlacementLogicalChainsInfo() : seed_parallel_desc(nullptr) {} }; void 
InitPlacementLogicalChainsInfoFromSet( const std::shared_ptr& logical_chain, const HashSet& origin_logical_chain, const HashMap& op_node2global_order, const std::function& CmpOpNodeOrder) { auto* logical_chain_ordered_nodes = &logical_chain->ordered_op_nodes; CHECK(logical_chain_ordered_nodes->empty()); logical_chain_ordered_nodes->assign(origin_logical_chain.begin(), origin_logical_chain.end()); std::sort(logical_chain_ordered_nodes->begin(), logical_chain_ordered_nodes->end(), CmpOpNodeOrder); const OpNode* begin_op = logical_chain_ordered_nodes->front(); const OpNode* end_op = logical_chain_ordered_nodes->back(); int64_t begin_op_global_order = op_node2global_order.at(begin_op); int64_t end_op_global_order = op_node2global_order.at(end_op); CHECK(begin_op != end_op); CHECK_LT(begin_op_global_order, end_op_global_order); } void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_logical_chain, const std::vector& ordered_acc_op_nodes, const ParallelDesc& seed_parallel_desc) { // Meta time shape (1, 1) std::shared_ptr meta_time_shape = std::make_shared(Shape({1, 1})); HashSet visited; HashSet after_acc_chain_ops; std::queue queued_nodes; auto SearchToNextNode = [&](const OpNode* cur_node, const OpNode* next_node, const OpEdge* edge) { if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), meta_time_shape) && IsOpEdge121Connected(cur_node, next_node, edge)) { CHECK(visited.insert(next_node).second); queued_nodes.push(next_node); } }; for (const OpNode* acc_node : ordered_acc_op_nodes) { for (const OpEdge* out_edge : acc_node->out_edges()) { const OpNode* seed_node = out_edge->dst_node(); SearchToNextNode(acc_node, seed_node, out_edge); } } while (!queued_nodes.empty()) { const OpNode* cur_node = queued_nodes.front(); queued_nodes.pop(); CHECK(after_acc_chain_ops.insert(cur_node).second); for (const OpEdge* in_edge : cur_node->in_edges()) { SearchToNextNode(cur_node, in_edge->src_node(), in_edge); } for (const OpEdge* out_edge : cur_node->out_edges()) { SearchToNextNode(cur_node, out_edge->dst_node(), out_edge); } } if (after_acc_chain_ops.size() > 1) { for (const OpNode* node : after_acc_chain_ops) { after_acc_logical_chain->ordered_op_nodes.push_back(node); } CHECK_EQ(after_acc_logical_chain->ordered_op_nodes.size(), after_acc_chain_ops.size()); } } void TryMergeAfterAccLogicalChainToMaxLogicalChain( PlacementLogicalChainsInfo* info, HashMap* mut_op_name2conf, JobBuilder* job_builder, const std::function& IsReachable, const std::shared_ptr& seed_time_shape) { if (!EnvBool()) { return; } int64_t max_chain_index = 0; for (int64_t i = 1; i < info->ordered_logical_chains.size(); ++i) { if (info->ordered_logical_chains.at(i)->ordered_op_nodes.size() > info->ordered_logical_chains.at(max_chain_index)->ordered_op_nodes.size()) { max_chain_index = i; } } const int64_t acc_chain_id = info->after_acc_logical_chain->logical_chain_id; auto& acc_chain_order_ops = info->after_acc_logical_chain->ordered_op_nodes; const auto& max_chain = info->ordered_logical_chains.at(max_chain_index); const OpNode* max_chain_src_op = max_chain->ordered_op_nodes.front(); const OpNode* max_chain_sink_op = max_chain->ordered_op_nodes.back(); HashSet max_chain_ops(max_chain->ordered_op_nodes.begin(), max_chain->ordered_op_nodes.end()); const OpNode* acc_chain_src_op = acc_chain_order_ops.front(); const OpNode* acc_chain_sink_op = acc_chain_order_ops.back(); 
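  // Rough picture of the merge attempted below (see the numbered NOTEs): step 1 adds an
  // acc_ctrl_tick edge  max_chain_src -> acc_chain_sink  as a memory lock across acc steps,
  // and step 2 adds a cast_to_tick / acc_tick / device_tick chain from max_chain_sink toward
  // acc_chain_src for strict exec order. The two while-loops first trim the acc chain from both
  // ends (pop_back / erase front) whenever IsReachable shows such an edge would form a cycle;
  // if the chain shrinks to a single op, the merge is abandoned.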
// NOTE(chengcheng): find all nontrivial sink consumer ops HashSet nontrivial_sink_consumers; for (const OpNode* chain_op : max_chain->ordered_op_nodes) { chain_op->ForEachNodeOnOutEdge([&](const OpNode* out_node) { if (max_chain_ops.find(out_node) == max_chain_ops.end() && !IsTickOpConf(out_node->op().op_conf()) && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(out_node), seed_time_shape)) { nontrivial_sink_consumers.insert(out_node); } }); } // NOTE(chengcheng): find last op can insert acc ctrl tick. while ((!acc_chain_sink_op->op().op_conf().has_user_conf()) || IsReachable(acc_chain_sink_op->op().op_name(), max_chain_src_op->op().op_name())) { VLOG(3) << " cannot insert acc ctrl edge between: [" << max_chain_src_op->op().op_name() << "] -> [" << acc_chain_sink_op->op().op_name() << "] , debug info :\n" << max_chain_src_op->op().op_conf().DebugString() << "\n" << acc_chain_sink_op->op().op_conf().DebugString() << "\n"; VLOG(3) << "remove op : " << acc_chain_sink_op->op().op_name() << " from after acc logical chain: " << acc_chain_id; acc_chain_order_ops.pop_back(); if (acc_chain_order_ops.size() > 1) { acc_chain_sink_op = acc_chain_order_ops.back(); } else { acc_chain_sink_op = nullptr; break; } } if (acc_chain_sink_op == nullptr) { return; } // NOTE(chengcheng): find last op can insert acc tick. while (IsReachable(acc_chain_src_op->op().op_name(), max_chain_sink_op->op().op_name())) { VLOG(3) << " cannot insert acc tick edge between: [" << max_chain_sink_op->op().op_name() << "] -> [" << acc_chain_src_op->op().op_name() << "] , debug info :\n" << max_chain_sink_op->op().op_conf().DebugString() << "\n" << acc_chain_src_op->op().op_conf().DebugString() << "\n"; VLOG(3) << "remove op : " << acc_chain_src_op->op().op_name() << " from after acc logical chain: " << acc_chain_id; acc_chain_order_ops.erase(acc_chain_order_ops.begin()); if (acc_chain_order_ops.size() > 1) { acc_chain_src_op = acc_chain_order_ops.front(); } else { acc_chain_src_op = nullptr; break; } } if (acc_chain_src_op == nullptr) { return; } // NOTE(chengcheng): // 1.add acc ctrl tick between max chain src to acc chain sink for memory lock. const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); CHECK_GT(acc_num, 1); const auto& fc_src_obns = max_chain_src_op->op().output_bns(); CHECK(!fc_src_obns.empty()); const std::string& max_chain_src_out_lbn = GenLogicalBlobName(max_chain_src_op->op().BnInOp2Lbi(fc_src_obns.Get(0))); VLOG(3) << " max_chain_src_out_lbn : " << max_chain_src_out_lbn; user_op::UserOpConfWrapper acc_ctrl_tick_op = user_op::UserOpConfWrapperBuilder("Sys-AccCtrlTick4MergeMaxAccChain-" + NewUniqueId()) .OpTypeName("acc_ctrl_tick") .Input("in", max_chain_src_out_lbn) .Output("out") .ScopeSymbolId(max_chain_src_op->op().op_conf().scope_symbol_id()) .Attr("max_acc_num", acc_num) .Build(); OperatorConf& acc_chain_sink_op_conf = CHECK_JUST(MapAt(*mut_op_name2conf, acc_chain_sink_op->op().op_name())); CHECK(acc_chain_sink_op_conf.has_user_conf()); (*acc_chain_sink_op_conf.mutable_user_conf() ->mutable_input())[user_op::kUserSourceOpTickInputArgName] .add_s(acc_ctrl_tick_op.output("out", 0)); CHECK_JUST(job_builder->AddOp(max_chain_src_op->parallel_desc().parallel_conf(), acc_ctrl_tick_op.op_conf())); VLOG(3) << " Insert acc ctrl tick between: [" << max_chain_src_op->op().op_name() << "] -> [" << acc_chain_sink_op->op().op_name() << "]"; // NOTE(chengcheng): // 2.add acc tick between max chain sink to acc chain src for strict exec order. 
const auto& fc_sink_obns = max_chain_sink_op->op().output_bns(); CHECK(!fc_sink_obns.empty()); const std::string max_chain_sink_lbn = GenLogicalBlobName(max_chain_sink_op->op().BnInOp2Lbi(fc_sink_obns.Get(0))); VLOG(3) << " max_chain_sink_lbn : " << max_chain_sink_lbn; user_op::UserOpConfWrapper cast_to_tick_op = user_op::UserOpConfWrapperBuilder("Sys-LogicalChainSink-CastToTick-" + NewUniqueId()) .OpTypeName("cast_to_tick") .Input("in", max_chain_sink_lbn) .Output("out") .ScopeSymbolId(max_chain_sink_op->op().op_conf().scope_symbol_id()) .Build(); CHECK_JUST(job_builder->AddOp(max_chain_sink_op->parallel_desc().parallel_conf(), cast_to_tick_op.op_conf())); std::string acc_tick_output_lbn = cast_to_tick_op.output("out", 0); if (!IsAccOrPackOpNode(max_chain_sink_op)) { // NOTE(chengcheng): Acc Op can be merged in fw/bw chain, if the last op is acc op, // there is no need and CANNOT insert acc tick op. OperatorConf sink_acc_tick_conf; sink_acc_tick_conf.set_name(std::string("Sys-LogicalChainSink-AccTick_") + NewUniqueId()); sink_acc_tick_conf.set_scope_symbol_id(max_chain_sink_op->op().op_conf().scope_symbol_id()); auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); acc_conf->set_one(cast_to_tick_op.output("out", 0)); acc_conf->set_acc("acc"); acc_conf->set_max_acc_num(acc_num); acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), "acc"); VLOG(3) << " insert acc tick op : " << sink_acc_tick_conf.name() << " of last op in fw/bw chain."; CHECK_JUST( job_builder->AddOp(max_chain_sink_op->parallel_desc().parallel_conf(), sink_acc_tick_conf)); } OperatorConf sink_final_tick_conf; sink_final_tick_conf.set_name(std::string("Sys-LogicalChainSink-FinalTick-DeviceTick_") + NewUniqueId()); sink_final_tick_conf.set_scope_symbol_id(max_chain_sink_op->op().op_conf().scope_symbol_id()); auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf(); tick_conf->add_tick(acc_tick_output_lbn); tick_conf->set_out("out"); // NOTE(chengcheng): // 3. Important Tips: If there have nontrivial_sink_consumers, there must insert ctrl // between sink consumer with acc chain for exec order. 
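  // In other words, each consumer of the max chain that sits outside the chain gets its
  // output routed through cast_to_tick (and acc_tick when needed) into the final device
  // tick below, so the acc chain's source op cannot start before those consumers finish.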
for (const OpNode* sink_consumer : nontrivial_sink_consumers) { VLOG(2) << " insert acc tick between nontrivial_sink_consumer: [" << sink_consumer->op().op_name() << "] -> [" << sink_final_tick_conf.name() << "] for mem safe guard."; CHECK(!IsReachable(acc_chain_src_op->op().op_name(), sink_consumer->op().op_name())); const auto& sink_consumer_obns = sink_consumer->op().output_bns(); CHECK(!sink_consumer_obns.empty()); std::string sink_consumer_output_lbn = GenLogicalBlobName(sink_consumer->op().BnInOp2Lbi(sink_consumer_obns.Get(0))); user_op::UserOpConfWrapper sink_consumer_cast_to_tick_op = user_op::UserOpConfWrapperBuilder("Sys-LogicalChainSinkConsumer-CastToTick-" + NewUniqueId()) .OpTypeName("cast_to_tick") .Input("in", sink_consumer_output_lbn) .Output("out") .ScopeSymbolId(sink_consumer->op().op_conf().scope_symbol_id()) .Build(); CHECK_JUST(job_builder->AddOp(sink_consumer->parallel_desc().parallel_conf(), sink_consumer_cast_to_tick_op.op_conf())); std::string sink_consumer_acc_tick_lbn = sink_consumer_cast_to_tick_op.output("out", 0); if (!IsAccOrPackOpNode(sink_consumer)) { OperatorConf sink_consumer_acc_tick_conf; sink_consumer_acc_tick_conf.set_name(std::string("Sys-LogicalChainSinkConsumer-AccTick_") + NewUniqueId()); sink_consumer_acc_tick_conf.set_scope_symbol_id( acc_chain_src_op->op().op_conf().scope_symbol_id()); auto* acc_conf = sink_consumer_acc_tick_conf.mutable_acc_tick_conf(); acc_conf->set_one(sink_consumer_acc_tick_lbn); acc_conf->set_acc("acc"); acc_conf->set_max_acc_num(acc_num); sink_consumer_acc_tick_lbn = GenLogicalBlobName(sink_consumer_acc_tick_conf.name(), "acc"); VLOG(3) << " insert acc tick op : " << sink_consumer_acc_tick_conf.name() << " of nontrivial_sink_consumer in fw/bw chain."; CHECK_JUST(job_builder->AddOp(sink_consumer->parallel_desc().parallel_conf(), sink_consumer_acc_tick_conf)); } tick_conf->add_tick(sink_consumer_acc_tick_lbn); } CHECK_JUST( job_builder->AddOp(max_chain_sink_op->parallel_desc().parallel_conf(), sink_final_tick_conf)); CHECK_JUST(MapAt(*mut_op_name2conf, acc_chain_src_op->op().op_name())) .add_ctrl_in_op_name(sink_final_tick_conf.name()); VLOG(3) << " Insert acc tick between: [" << max_chain_sink_op->op().op_name() << "] -> [" << acc_chain_src_op->op().op_name() << "]"; // NOTE(chengcheng): // 4. 
merge max chain and acc chain MergedLogicalChainIdGroup* group = job_builder->add_logical_chain_groups(); group->add_logical_chain_id_list(max_chain->logical_chain_id); group->add_logical_chain_id_list(acc_chain_id); VLOG(3) << " Merge acc chain : " << acc_chain_id << " to max logical chain : " << max_chain->logical_chain_id; } Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); bool has_acc = acc_num > 1; int64_t max_logical_chain_id = -1; HashMap placement2logical_chains; auto FindOrCreatePlacementLogicalChainsInfo = [&](const OpNode* node) { const ParallelDesc& this_parallel_desc = node->parallel_desc(); std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf()); auto it = placement2logical_chains.find(key); if (it == placement2logical_chains.end()) { it = placement2logical_chains.emplace(key, PlacementLogicalChainsInfo()).first; it->second.seed_parallel_desc = &this_parallel_desc; } return &(it->second); }; std::vector ordered_op_nodes; HashMap op_node2global_order; HashMap mut_op_name2conf; std::shared_ptr seed_time_shape = std::make_shared(Shape({1, 1})); if (ParseBooleanFromEnv("DISABLE_LOGICAL_STRAIGHTEN", false)) { op_graph.TopoForEachNodeWithCtrlEdge( [&](const OpNode* node) { ordered_op_nodes.emplace_back(node); }); } else { auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes); } for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) { const OpNode* node = JUST(VectorAt(ordered_op_nodes, global_order)); op_node2global_order.emplace(node, global_order); std::shared_ptr this_time_shape = GetOpNodeFastestTimeShape(node); if (this_time_shape->elem_cnt() > seed_time_shape->elem_cnt()) { seed_time_shape = this_time_shape; } mut_op_name2conf.emplace(node->op().op_name(), node->op().op_conf()); // NOTE(chengcheng): handle logical chain id set by nccl logical pass if (node->op().op_conf().has_logical_chain_id()) { const int64_t logical_chain_id = node->op().op_conf().logical_chain_id(); max_logical_chain_id = std::max(max_logical_chain_id, logical_chain_id); PlacementLogicalChainsInfo* info = FindOrCreatePlacementLogicalChainsInfo(node); if (has_acc && this_time_shape->elem_cnt() == 1) { // acc logical chain if (info->after_acc_logical_chain.get() == nullptr) { info->after_acc_logical_chain = std::make_shared(logical_chain_id); } info->after_acc_logical_chain->ordered_op_nodes.push_back(node); CHECK_EQ(info->after_acc_logical_chain->logical_chain_id, logical_chain_id); } else { // fw/bw logical chain bool find_chain = false; for (const auto& logical_chain : info->ordered_logical_chains) { if (logical_chain->logical_chain_id == logical_chain_id) { logical_chain->ordered_op_nodes.push_back(node); find_chain = true; break; } } if (!find_chain) { info->ordered_logical_chains.push_back(std::make_shared(logical_chain_id)); info->ordered_logical_chains.back()->ordered_op_nodes.push_back(node); CHECK_EQ(info->ordered_logical_chains.back()->logical_chain_id, logical_chain_id); } } } } VLOG(2) << " seed time shape = " << seed_time_shape->ToString(); std::vector> logical_chains; GetLogicalChainsWithTimeShape(&logical_chains, ordered_op_nodes, seed_time_shape); if (logical_chains.empty() && placement2logical_chains.empty()) { return Maybe::Ok(); } auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) { return op_node2global_order.at(lhs) < op_node2global_order.at(rhs); }; auto CmpLogicalChainOrder = [&](const 
std::shared_ptr& lhs, const std::shared_ptr& rhs) { int64_t lhs_begin_op_global_order = op_node2global_order.at(lhs->ordered_op_nodes.front()); int64_t rhs_begin_op_global_order = op_node2global_order.at(rhs->ordered_op_nodes.front()); return lhs_begin_op_global_order < rhs_begin_op_global_order; }; auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); for (const auto& origin_logical_chain : logical_chains) { const OpNode* rand_node = *origin_logical_chain.begin(); PlacementLogicalChainsInfo* info = FindOrCreatePlacementLogicalChainsInfo(rand_node); info->ordered_logical_chains.emplace_back( std::make_shared(++max_logical_chain_id)); InitPlacementLogicalChainsInfoFromSet(info->ordered_logical_chains.back(), origin_logical_chain, op_node2global_order, CmpOpNodeOrder); } for (auto& pair : placement2logical_chains) { std::sort(pair.second.ordered_logical_chains.begin(), pair.second.ordered_logical_chains.end(), CmpLogicalChainOrder); } for (const OpNode* this_node : ordered_op_nodes) { if (IsAccOpNode(this_node)) { const ParallelDesc& this_parallel_desc = this_node->parallel_desc(); std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf()); auto it = placement2logical_chains.find(key); if (it != placement2logical_chains.end()) { it->second.ordered_acc_op_nodes.emplace_back(this_node); } } } auto InsertCtrlEdgeInChain = [&](const std::vector& ordered_op_nodes) { for (int64_t i = 1; i < ordered_op_nodes.size(); ++i) { const OpNode* this_node = CHECK_JUST(VectorAt(ordered_op_nodes, i)); const OpNode* prev_node = CHECK_JUST(VectorAt(ordered_op_nodes, i - 1)); const std::string& this_op_name = this_node->op().op_name(); const std::string& prev_op_name = prev_node->op().op_name(); if (!IsReachable(prev_op_name, this_op_name)) { CHECK_JUST(MapAt(mut_op_name2conf, this_op_name)).add_ctrl_in_op_name(prev_op_name); } } }; auto InsertLogicalChainId = [&](const std::vector& ordered_op_nodes, const int64_t logical_chain_id) { int64_t order = 0; for (const OpNode* op_node : ordered_op_nodes) { auto& conf = CHECK_JUST(MapAt(mut_op_name2conf, op_node->op().op_name())); conf.set_logical_chain_id(logical_chain_id); conf.set_order_in_logical_chain(order++); } }; HashSet exist_chain_ids; for (auto& pair : placement2logical_chains) { const auto& placement = pair.first; auto& info = pair.second; CHECK_GE(info.ordered_logical_chains.size(), 1); // NOTE(chengcheng): set logical chain id for each op in each logical chain, and insert ctrl // edge for order. for (auto& logical_chain : info.ordered_logical_chains) { CHECK_GE(logical_chain->logical_chain_id, 0); CHECK(exist_chain_ids.insert(logical_chain->logical_chain_id).second); InsertLogicalChainId(logical_chain->ordered_op_nodes, logical_chain->logical_chain_id); InsertCtrlEdgeInChain(logical_chain->ordered_op_nodes); } for (const auto& logical_chain : info.ordered_logical_chains) { VLOG(3) << " In placement: " << placement << " logical_chain_id: " << logical_chain->logical_chain_id << " has op num = " << logical_chain->ordered_op_nodes.size(); for (int i = 0; i < logical_chain->ordered_op_nodes.size(); ++i) { const OpNode* ordered_op = JUST(VectorAt(logical_chain->ordered_op_nodes, i)); VLOG(3) << " ChainId: " << logical_chain->logical_chain_id << " order: " << i << " op_name: " << ordered_op->op().op_name() << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); } } // NOTE(chengcheng): create logical chain after acc, and merge with max logical chain. 
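    // That is: ops downstream of the acc ops whose time shape is (1, 1) are grouped into
    // an "after acc" chain; when it contains more than one op it is ordered, optionally
    // merged into the largest fw/bw chain (TryMergeAfterAccLogicalChainToMaxLogicalChain),
    // and then given its own logical chain id and intra-chain ctrl edges.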
const std::vector& ordered_acc_op_nodes = info.ordered_acc_op_nodes; if (!ordered_acc_op_nodes.empty()) { if (info.after_acc_logical_chain.get() == nullptr) { info.after_acc_logical_chain = std::make_shared(++max_logical_chain_id); CreateAfterAccLogicalChain(info.after_acc_logical_chain, ordered_acc_op_nodes, *info.seed_parallel_desc); } CHECK_GE(info.after_acc_logical_chain->logical_chain_id, 0); CHECK(exist_chain_ids.insert(info.after_acc_logical_chain->logical_chain_id).second); auto& acc_chain_order_ops = info.after_acc_logical_chain->ordered_op_nodes; if (acc_chain_order_ops.size() > 1) { std::sort(acc_chain_order_ops.begin(), acc_chain_order_ops.end(), CmpOpNodeOrder); TryMergeAfterAccLogicalChainToMaxLogicalChain(&info, &mut_op_name2conf, job_builder, IsReachable, seed_time_shape); if (acc_chain_order_ops.size() <= 1) { continue; } VLOG(3) << " In placement: " << placement << " AccLogicalChain: " << info.after_acc_logical_chain->logical_chain_id << " has op num = " << acc_chain_order_ops.size(); for (int i = 0; i < acc_chain_order_ops.size(); ++i) { const OpNode* ordered_op = JUST(VectorAt(acc_chain_order_ops, i)); VLOG(3) << " AfterAccChainId: " << info.after_acc_logical_chain->logical_chain_id << " order: " << i << " op_name: " << ordered_op->op().op_name() << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); } InsertLogicalChainId(acc_chain_order_ops, info.after_acc_logical_chain->logical_chain_id); InsertCtrlEdgeInChain(acc_chain_order_ops); } } } // NOTE(chengcheng): update global order and chain id for ops. for (const auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); } return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("LogicalChainPass", LogicalChainPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/momentum_optm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

#include "oneflow/core/common/str_util.h"
#include "oneflow/core/job_rewriter/optimizer.h"
#include "oneflow/core/framework/framework.h"

namespace oneflow {

namespace {

void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,
                             const std::string& model_diff_lbn,
                             const OptimizerConf& optimizer_conf, JobBuilder* job_builder) {
  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());
  CHECK_NOTNULL(var_op);
  const std::string op_name = var_op->op_name() + "-momentum";
  // Reuse the variable's conf for the momentum buffer; initialize it from a snapshot when
  // one exists, otherwise zero-initialize it.
  OperatorConf momentum_var(var_op->op_conf());
  const bool has_snapshot_path =
      job_builder->job().job_conf().has_default_initialize_with_snapshot_path();
  std::string file_path;
  if (has_snapshot_path) {
    file_path = JoinPath(job_builder->job().job_conf().default_initialize_with_snapshot_path(),
                         op_name, "out");
  }
  if (has_snapshot_path && SnapshotFS()->FileExists(file_path)) {
    VLOG(3) << "file_path: " << file_path;
    momentum_var.mutable_variable_conf()->mutable_initialize_with_snapshot()->set_path(
        JoinPath(job_builder->job().job_conf().default_initialize_with_snapshot_path(), op_name));
    momentum_var.mutable_variable_conf()->mutable_initialize_with_snapshot()->set_key("out");
  } else {
    if (has_snapshot_path) { VLOG(3) << file_path << " not found, will be initialized"; }
    InitializerConf constant_initializer;
    constant_initializer.mutable_constant_conf()->set_value(0.f);
    *(momentum_var.mutable_variable_conf()->mutable_initializer()) = constant_initializer;
  }
  momentum_var.set_name(op_name);
  momentum_var.mutable_variable_conf()->set_out("out");
  momentum_var.set_scope_symbol_id(var_op->op_conf().scope_symbol_id());
  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {momentum_var});

  user_op::UserOpConfWrapperBuilder momentum_update_op_builder(var_op->op_name() + "_optimizer");
  momentum_update_op_builder.OpTypeName("momentum_update")
      .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out")))
      .Input("model_diff", model_diff_lbn)
      .Input("learning_rate", optimizer_conf.learning_rate_lbn())
      .Input("momentum", GenLogicalBlobName(op_name, momentum_var.variable_conf().out()))
      .Attr("beta", optimizer_conf.momentum_conf().beta())
      .Attr("dampening", optimizer_conf.momentum_conf().dampening())
      .Attr("nesterov", optimizer_conf.momentum_conf().nesterov())
      .Attr("maximize", optimizer_conf.momentum_conf().maximize())
      .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());
  if (optimizer_conf.has_lr_scale()) {
    momentum_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale());
  }
  SetDynamicLossScaleSkipIf(ctx, &momentum_update_op_builder);
  user_op::UserOpConfWrapper momentum_update_op = momentum_update_op_builder.Build();
  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {momentum_update_op.op_conf()});
}

}  // namespace

REGISTER_OPTIMIZER(OptimizerConf::kMomentumConf, &GenerateOptimizerOpConf);

}  // namespace oneflow

================================================
FILE: oneflow/core/job_rewriter/multi_tensor_model_update.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { struct SGDOptimizerKey { std::string learning_rate; std::string scale_by_tensor_lbn; std::string skip_if_lbn; double scale; float l1; float l2; float weight_decay; ParallelConf parallel_conf; bool has_model_copy; /* In fuse_model_update_cast pass, not all the cast fp16 model_diff kernel can be fused, it may cause some model diff type is float16, some is float32. So here we need to use model_diff datatype as key to group. */ DataType model_diff_dtype; }; bool operator==(const SGDOptimizerKey& lhs, const SGDOptimizerKey& rhs) { return (lhs.learning_rate == rhs.learning_rate) && (lhs.scale_by_tensor_lbn == rhs.scale_by_tensor_lbn) && (lhs.skip_if_lbn == rhs.skip_if_lbn) && (lhs.scale == rhs.scale) && (lhs.l1 == rhs.l1) && (lhs.l2 == rhs.l2) && (lhs.weight_decay == rhs.weight_decay) && (lhs.parallel_conf == rhs.parallel_conf) && (lhs.has_model_copy == rhs.has_model_copy) && (lhs.model_diff_dtype == rhs.model_diff_dtype); } struct AdamOptimizerKey { std::string learning_rate; std::string scale_by_tensor_lbn; std::string skip_if_lbn; double scale; float l1; float l2; float beta1; float beta2; float epsilon; float weight_decay; bool amsgrad; bool do_bias_correction; ParallelConf parallel_conf; bool has_model_copy; DataType model_diff_dtype; }; bool operator==(const AdamOptimizerKey& lhs, const AdamOptimizerKey& rhs) { return (lhs.learning_rate == rhs.learning_rate) && (lhs.scale_by_tensor_lbn == rhs.scale_by_tensor_lbn) && (lhs.skip_if_lbn == rhs.skip_if_lbn) && (lhs.scale == rhs.scale) && (lhs.l1 == rhs.l1) && (lhs.l2 == rhs.l2) && (lhs.beta1 == rhs.beta1) && (lhs.beta2 == rhs.beta2) && (lhs.epsilon == rhs.epsilon) && (lhs.weight_decay == rhs.weight_decay) && (lhs.amsgrad == rhs.amsgrad) && (lhs.do_bias_correction == rhs.do_bias_correction) && (lhs.parallel_conf == rhs.parallel_conf) && (lhs.has_model_copy == rhs.has_model_copy) && (lhs.model_diff_dtype == rhs.model_diff_dtype); } } // namespace oneflow namespace std { template<> struct hash { size_t operator()(const oneflow::SGDOptimizerKey& key) const { const auto float_hash = std::hash(); const auto double_hash = std::hash(); const auto& string_hash = std::hash(); const auto& parallel_conf_hash = std::hash(); const auto& bool_hash = std::hash(); const auto& dtype_hash = std::hash(); size_t hash = string_hash(key.learning_rate); oneflow::HashCombine(&hash, string_hash(key.scale_by_tensor_lbn)); oneflow::HashCombine(&hash, string_hash(key.skip_if_lbn)); oneflow::HashCombine(&hash, double_hash(key.scale)); oneflow::HashCombine(&hash, float_hash(key.l1)); oneflow::HashCombine(&hash, float_hash(key.l2)); oneflow::HashCombine(&hash, float_hash(key.weight_decay)); oneflow::HashCombine(&hash, parallel_conf_hash(key.parallel_conf)); oneflow::HashCombine(&hash, bool_hash(key.has_model_copy)); oneflow::HashCombine(&hash, dtype_hash(key.model_diff_dtype)); return hash; } }; template<> struct hash { size_t operator()(const oneflow::AdamOptimizerKey& key) const { const auto& float_hash = std::hash(); const auto& 
double_hash = std::hash(); const auto& string_hash = std::hash(); const auto& bool_hash = std::hash(); const auto& parallel_conf_hash = std::hash(); const auto& dtype_hash = std::hash(); size_t hash = string_hash(key.learning_rate); oneflow::HashCombine(&hash, string_hash(key.scale_by_tensor_lbn)); oneflow::HashCombine(&hash, string_hash(key.skip_if_lbn)); oneflow::HashCombine(&hash, double_hash(key.scale)); oneflow::HashCombine(&hash, float_hash(key.l1)); oneflow::HashCombine(&hash, float_hash(key.l2)); oneflow::HashCombine(&hash, float_hash(key.beta1)); oneflow::HashCombine(&hash, float_hash(key.beta2)); oneflow::HashCombine(&hash, float_hash(key.epsilon)); oneflow::HashCombine(&hash, float_hash(key.weight_decay)); oneflow::HashCombine(&hash, bool_hash(key.amsgrad)); oneflow::HashCombine(&hash, bool_hash(key.do_bias_correction)); oneflow::HashCombine(&hash, parallel_conf_hash(key.parallel_conf)); oneflow::HashCombine(&hash, bool_hash(key.has_model_copy)); oneflow::HashCombine(&hash, dtype_hash(key.model_diff_dtype)); return hash; } }; } // namespace std namespace oneflow { namespace { void AddScaleAndSkipLbn(user_op::UserOpConfWrapperBuilder& multi_tensor_model_update_op_builder, const user_op::UserOpConfWrapper& model_update_user_conf) { if (model_update_user_conf.has_input("scale_by_tensor", 0)) { multi_tensor_model_update_op_builder.Input("scale_by_tensor", model_update_user_conf.input("scale_by_tensor", 0)); } if (model_update_user_conf.has_input("skip_if", 0)) { multi_tensor_model_update_op_builder.Input("skip_if", model_update_user_conf.input("skip_if", 0)); } } void AddProcessedVariable(HashSet& processed_variable_list, const user_op::UserOpConfWrapper& model_update_user_conf) { /* Since each variable op will be processed in pass, for example, Adam optimizer has 3 variables: model, m, v. We replace to multi tensor optimizer and processed 3 variables at once, if we don't filter these variables, these variables will be repeated 3 times in multi_tensor_update kernel. Here we use a HashSet to sign if the variable has been processed. */ processed_variable_list.emplace(model_update_user_conf.input("model", 0)); if (model_update_user_conf.op_type_name() == "adam_update") { processed_variable_list.emplace(model_update_user_conf.input("m", 0)); processed_variable_list.emplace(model_update_user_conf.input("v", 0)); } } bool IfVariableProcessed(const HashSet& processed_variable_list, const user_op::UserOpConfWrapper& model_update_user_conf) { if (model_update_user_conf.op_type_name() == "sgd_update") { const auto& processed_model_iter = processed_variable_list.find(model_update_user_conf.input("model", 0)); if (processed_model_iter != processed_variable_list.end()) { return true; } } else if (model_update_user_conf.op_type_name() == "adam_update") { const auto& processed_model_iter = processed_variable_list.find(model_update_user_conf.input("model", 0)); const auto& processed_m_iter = processed_variable_list.find(model_update_user_conf.input("m", 0)); const auto& processed_v_iter = processed_variable_list.find(model_update_user_conf.input("v", 0)); if (processed_model_iter != processed_variable_list.end() && processed_m_iter != processed_variable_list.end() && processed_v_iter != processed_variable_list.end()) { return true; } } else { UNIMPLEMENTED() << "Current Optimizer do not support multi tensor update. 
"; } return false; } class MultiTensorModelUpdatePass final : public JobPass { public: MultiTensorModelUpdatePass() = default; ~MultiTensorModelUpdatePass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().enable_multi_tensor_update() || ParseBooleanFromEnv("ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE", false); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe MultiTensorModelUpdatePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { if (!job_builder->job().job_conf().has_train_conf()) { return Maybe::Ok(); } std::vector delete_ops; ParallelConf parallel_conf{}; HashMap multi_tensor_sgd_update_hashmap; HashMap multi_tensor_adam_update_hashmap; HashSet processed_variable_list{}; op_graph.ForEachNode([&](OpNode* op_node) { const auto& op_conf = op_node->op().op_conf(); if (!op_conf.has_variable_conf()) { return; } LogicalBlobId model_copy_lbi; for (OpEdge* find_model_update_edge : op_node->out_edges()) { OpNode* find_model_update_update_node = find_model_update_edge->dst_node(); if (!IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), "sgd_update") && !IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), "adam_update")) { continue; } const user_op::UserOpConfWrapper model_update_user_conf( find_model_update_update_node->op().op_conf()); // Multi tensor update pass only support for CUDA currently. if (find_model_update_update_node->parallel_desc().device_type() != DeviceType::kCUDA) { continue; } // Multi tensor update pass only support Data Parallel. bool if_data_parallel = true; for (const auto& pair : find_model_update_update_node->sbp_signature().bn_in_op2sbp_parallel()) { if (!pair.second.has_broadcast_parallel()) { if_data_parallel = false; break; } } if (!if_data_parallel) { continue; } // Check the variable has been processed before. 
if (IfVariableProcessed(processed_variable_list, model_update_user_conf)) { continue; } delete_ops.emplace_back(find_model_update_update_node->op().op_conf()); parallel_conf = find_model_update_update_node->parallel_desc().parallel_conf(); std::string scale_by_tensor_lbn = ""; std::string skip_if_lbn = ""; bool has_model_copy = false; if (model_update_user_conf.has_input("scale_by_tensor", 0)) { scale_by_tensor_lbn = model_update_user_conf.input("scale_by_tensor", 0); } if (model_update_user_conf.has_input("skip_if", 0)) { skip_if_lbn = model_update_user_conf.input("skip_if", 0); } if (model_update_user_conf.has_input("model_copy", 0)) { has_model_copy = true; } const BlobDesc& model_diff_blob_desc = op_graph.GetLogicalBlobDesc( GenLogicalBlobId(model_update_user_conf.input("model_diff", 0))); const DataType model_diff_dtype = model_diff_blob_desc.data_type(); if (IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), "sgd_update")) { SGDOptimizerKey key{model_update_user_conf.input("learning_rate", 0), scale_by_tensor_lbn, skip_if_lbn, model_update_user_conf.attr("scale"), model_update_user_conf.attr("l1"), model_update_user_conf.attr("l2"), model_update_user_conf.attr("weight_decay"), parallel_conf, has_model_copy, model_diff_dtype}; const auto& iter = multi_tensor_sgd_update_hashmap.find(key); if (iter != multi_tensor_sgd_update_hashmap.end()) { iter->second.Input("model", model_update_user_conf.input("model", 0)) .Input("model_diff", model_update_user_conf.input("model_diff", 0)); if (has_model_copy) { iter->second.Input("model_copy", model_update_user_conf.input("model_copy", 0)); } } else { user_op::UserOpConfWrapperBuilder multi_tensor_sgd_update_op_builder( "multi_tensor_model_update" + NewUniqueId()); std::string op_type_name = "multi_tensor_sgd_update"; if (has_model_copy) { op_type_name = "multi_tensor_sgd_update_with_cast"; } multi_tensor_sgd_update_op_builder.OpTypeName(op_type_name) .Input("model", model_update_user_conf.input("model", 0)) .Input("model_diff", model_update_user_conf.input("model_diff", 0)) .Input("learning_rate", model_update_user_conf.input("learning_rate", 0)) .Attr("scale", model_update_user_conf.attr("scale")) .Attr("l1", model_update_user_conf.attr("l1")) .Attr("l2", model_update_user_conf.attr("l2")) .Attr("weight_decay", model_update_user_conf.attr("weight_decay")); if (has_model_copy) { multi_tensor_sgd_update_op_builder.Input("model_copy", model_update_user_conf.input("model_copy", 0)); } AddScaleAndSkipLbn(multi_tensor_sgd_update_op_builder, model_update_user_conf); CHECK(model_update_user_conf.op_conf().has_scope_symbol_id()); multi_tensor_sgd_update_op_builder.ScopeSymbolId( model_update_user_conf.op_conf().scope_symbol_id()); multi_tensor_sgd_update_hashmap.emplace(key, multi_tensor_sgd_update_op_builder); } } else if (IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), "adam_update")) { AdamOptimizerKey key{model_update_user_conf.input("learning_rate", 0), scale_by_tensor_lbn, skip_if_lbn, model_update_user_conf.attr("scale"), model_update_user_conf.attr("l1"), model_update_user_conf.attr("l2"), model_update_user_conf.attr("beta1"), model_update_user_conf.attr("beta2"), model_update_user_conf.attr("epsilon"), model_update_user_conf.attr("weight_decay"), model_update_user_conf.attr("amsgrad"), model_update_user_conf.attr("do_bias_correction"), parallel_conf, has_model_copy, model_diff_dtype}; if (key.amsgrad) { UNIMPLEMENTED() << "Multi Tensor Adam update do not support amsgrad = True. 
"; } const auto& iter = multi_tensor_adam_update_hashmap.find(key); if (iter != multi_tensor_adam_update_hashmap.end()) { iter->second.Input("model", model_update_user_conf.input("model", 0)) .Input("model_diff", model_update_user_conf.input("model_diff", 0)) .Input("m", model_update_user_conf.input("m", 0)) .Input("v", model_update_user_conf.input("v", 0)); if (has_model_copy) { iter->second.Input("model_copy", model_update_user_conf.input("model_copy", 0)); } if (model_update_user_conf.attr("do_bias_correction")) { iter->second .Input("bias_correction1", model_update_user_conf.input("bias_correction1", 0)) .Input("bias_correction2", model_update_user_conf.input("bias_correction2", 0)); } } else { user_op::UserOpConfWrapperBuilder multi_tensor_adam_update_op_builder( "multi_tensor_model_update" + NewUniqueId()); std::string op_type_name = "multi_tensor_adam_update"; if (has_model_copy) { op_type_name = "multi_tensor_adam_update_with_cast"; } multi_tensor_adam_update_op_builder.OpTypeName(op_type_name) .Input("model", model_update_user_conf.input("model", 0)) .Input("model_diff", model_update_user_conf.input("model_diff", 0)) .Input("m", model_update_user_conf.input("m", 0)) .Input("v", model_update_user_conf.input("v", 0)) .Input("learning_rate", model_update_user_conf.input("learning_rate", 0)) .Attr("scale", model_update_user_conf.attr("scale")) .Attr("l1", model_update_user_conf.attr("l1")) .Attr("l2", model_update_user_conf.attr("l2")) .Attr("beta1", model_update_user_conf.attr("beta1")) .Attr("beta2", model_update_user_conf.attr("beta2")) .Attr("epsilon", model_update_user_conf.attr("epsilon")) .Attr("weight_decay", model_update_user_conf.attr("weight_decay")) .Attr("amsgrad", model_update_user_conf.attr("amsgrad")) .Attr("do_bias_correction", model_update_user_conf.attr("do_bias_correction")); if (model_update_user_conf.attr("do_bias_correction")) { multi_tensor_adam_update_op_builder .Input("bias_correction1", model_update_user_conf.input("bias_correction1", 0)) .Input("bias_correction2", model_update_user_conf.input("bias_correction2", 0)); } if (has_model_copy) { multi_tensor_adam_update_op_builder.Input( "model_copy", model_update_user_conf.input("model_copy", 0)); } AddScaleAndSkipLbn(multi_tensor_adam_update_op_builder, model_update_user_conf); CHECK(model_update_user_conf.op_conf().has_scope_symbol_id()); multi_tensor_adam_update_op_builder.ScopeSymbolId( model_update_user_conf.op_conf().scope_symbol_id()); multi_tensor_adam_update_hashmap.emplace(key, multi_tensor_adam_update_op_builder); } } else { UNIMPLEMENTED() << "Current Optimizer do not support multi tensor update. "; } AddProcessedVariable(processed_variable_list, model_update_user_conf); break; } }); for (auto& op : multi_tensor_sgd_update_hashmap) { auto multi_tensor_model_update_sgd_op = op.second.Build(); job_builder->AddOps(parallel_conf, {multi_tensor_model_update_sgd_op.op_conf()}); } for (auto& op : multi_tensor_adam_update_hashmap) { auto multi_tensor_model_update_adam_op = op.second.Build(); job_builder->AddOps(parallel_conf, {multi_tensor_model_update_adam_op.op_conf()}); } job_builder->DelOps(delete_ops); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("MultiTensorModelUpdatePass", MultiTensorModelUpdatePass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/nccl_logical_chain_strict_order_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU) #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace { class NcclLogicalChainStrictOrderPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(NcclLogicalChainStrictOrderPass); NcclLogicalChainStrictOrderPass() = default; ~NcclLogicalChainStrictOrderPass() = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } bool IsEnabled(const JobPassCtx& ctx) const { return Singleton::Get()->nccl_use_compute_stream(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; bool IsAccOrPackOpNode(const OpNode* node) { const auto& op_conf = node->op().op_conf(); return op_conf.has_user_conf() && (op_conf.user_conf().op_type_name() == "acc" || op_conf.user_conf().op_type_name() == "pack"); } Maybe InsertCtrlOpBetweenBwChainAndAccChain( HashMap* mut_op_name2conf, JobBuilder* job_builder, const std::vector& ordered_op_nodes, const std::function& IsReachable) { HashMap placement2last_normal_node; HashMap placement2first_after_acc_node; int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) { const OpNode* node = JUST(VectorAt(ordered_op_nodes, global_order)); if (!node->op().op_conf().has_logical_chain_id()) { continue; } const int64_t time_shape_cnt = CHECK_JUST(node->op().GetInputOutputFastestTimeShape())->elem_cnt(); CHECK(time_shape_cnt == acc_num || time_shape_cnt == 1) << " invalid time shape count = " << time_shape_cnt << " which should be : [ " << acc_num << " , 1 ]"; std::string placement_key = GenParallelConfKey(node->parallel_desc().parallel_conf()); if (time_shape_cnt == acc_num) { // for all fw/bw chains in this placement placement2last_normal_node[placement_key] = node; // create or update } else { // acc chain if (placement2first_after_acc_node.find(placement_key) == placement2first_after_acc_node.end()) { CHECK(placement2first_after_acc_node.emplace(placement_key, node).second); } } } for (const auto& pair : placement2last_normal_node) { if (placement2first_after_acc_node.find(pair.first) == placement2first_after_acc_node.end()) { continue; } const OpNode* last_bw_node = pair.second; const OpNode* first_after_acc_node = JUST(MapAt(placement2first_after_acc_node, pair.first)); const std::string& 
last_bw_op_name = last_bw_node->op().op_name(); const std::string& first_after_acc_op_name = first_after_acc_node->op().op_name(); CHECK_OR_RETURN(!IsReachable(first_after_acc_op_name, last_bw_op_name)) << Error::RuntimeError() << " Error! Cycle control edge from first acc chain op: " << first_after_acc_op_name << " to last bw chain sink op: " << last_bw_op_name; const auto& bw_sink_obns = last_bw_node->op().output_bns(); CHECK_OR_RETURN(!bw_sink_obns.empty()); const std::string bw_sink_lbn = GenLogicalBlobName(last_bw_node->op().BnInOp2Lbi(bw_sink_obns.Get(0))); VLOG(3) << " bw_sink_lbn : " << bw_sink_lbn; user_op::UserOpConfWrapper cast_to_tick_op = user_op::UserOpConfWrapperBuilder("Sys-LastNcclChainSink-CastToTick-" + NewUniqueId()) .OpTypeName("cast_to_tick") .Input("in", bw_sink_lbn) .Output("out") .ScopeSymbolId(last_bw_node->op().op_conf().scope_symbol_id()) .Build(); JUST(job_builder->AddOp(last_bw_node->parallel_desc().parallel_conf(), cast_to_tick_op.op_conf())); std::string acc_tick_output_lbn = cast_to_tick_op.output("out", 0); if (!IsAccOrPackOpNode(last_bw_node)) { // NOTE(chengcheng): Acc Op can be merged in fw/bw chain, if the last op is acc op, // there is no need and CANNOT insert acc tick op. OperatorConf sink_acc_tick_conf; sink_acc_tick_conf.set_name(std::string("Sys-LastNcclChainSink-AccTick_") + NewUniqueId()); sink_acc_tick_conf.set_scope_symbol_id(last_bw_node->op().op_conf().scope_symbol_id()); auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); acc_conf->set_one(acc_tick_output_lbn); acc_conf->set_acc("acc"); acc_conf->set_max_acc_num(acc_num); acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), "acc"); VLOG(3) << " insert acc tick op : " << sink_acc_tick_conf.name() << " of last op in fw/bw chain."; JUST(job_builder->AddOp(last_bw_node->parallel_desc().parallel_conf(), sink_acc_tick_conf)); } OperatorConf sink_final_tick_conf; sink_final_tick_conf.set_name(std::string("Sys-LastNcclChainSink-FinalTick-DeviceTick_") + NewUniqueId()); sink_final_tick_conf.set_scope_symbol_id(last_bw_node->op().op_conf().scope_symbol_id()); auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf(); tick_conf->add_tick(acc_tick_output_lbn); tick_conf->set_out("out"); JUST(job_builder->AddOp(last_bw_node->parallel_desc().parallel_conf(), sink_final_tick_conf)); if (mut_op_name2conf->find(first_after_acc_op_name) == mut_op_name2conf->end()) { mut_op_name2conf->emplace(first_after_acc_op_name, first_after_acc_node->op().op_conf()); } JUST(MapAt(*mut_op_name2conf, first_after_acc_op_name)) .add_ctrl_in_op_name(sink_final_tick_conf.name()); VLOG(2) << " In: " << pair.first << " , insert ctrl edge from: [ " << last_bw_op_name << " ] to: [ " << first_after_acc_op_name << " ]"; } return Maybe::Ok(); } Maybe NcclLogicalChainStrictOrderPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { HashMap nccl_chain_id2cur_last_node; HashMap mut_op_name2conf; auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); std::vector ordered_op_nodes; if (ParseBooleanFromEnv("DISABLE_LOGICAL_STRAIGHTEN", false)) { op_graph.TopoForEachNodeWithCtrlEdge( [&](const OpNode* node) { ordered_op_nodes.emplace_back(node); }); } else { auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes); } for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) { const OpNode* node = JUST(VectorAt(ordered_op_nodes, global_order)); if (!node->op().op_conf().has_logical_chain_id()) { continue; } const int64_t logical_chain_id = 
node->op().op_conf().logical_chain_id(); // add ctrl edge for strict order auto it = nccl_chain_id2cur_last_node.find(logical_chain_id); if (it == nccl_chain_id2cur_last_node.end()) { nccl_chain_id2cur_last_node.emplace(logical_chain_id, node); } else { const std::string& this_op_name = node->op().op_name(); const std::string& prev_op_name = it->second->op().op_name(); if (!IsReachable(prev_op_name, this_op_name)) { CHECK(mut_op_name2conf.emplace(this_op_name, node->op().op_conf()).second); JUST(MapAt(mut_op_name2conf, this_op_name)).add_ctrl_in_op_name(prev_op_name); } it->second = node; } } if (job_builder->job().job_conf().num_gradient_accumulation_steps() > 1) { JUST(InsertCtrlOpBetweenBwChainAndAccChain(&mut_op_name2conf, job_builder, ordered_op_nodes, IsReachable)); } for (const auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); } return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("NcclLogicalChainStrictOrderPass", NcclLogicalChainStrictOrderPass); } // namespace oneflow #endif // WITH_CUDA || WITH_NPU || WITH_MLU ================================================ FILE: oneflow/core/job_rewriter/nccl_logical_op_fusion_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU) #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/common/env_var/env_var.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/common/container_util.h" #include "oneflow/user/ops/nccl_logical_util.h" namespace oneflow { // nccl fusion bucket size 500MiB. 
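// (the default below is 5e8 bytes, i.e. roughly 500 MB; it can be overridden via the
//  ONEFLOW_GRAPH_NCCL_LOGICAL_FUSION_BUCKET_SIZE environment variable, e.g.
//  `export ONEFLOW_GRAPH_NCCL_LOGICAL_FUSION_BUCKET_SIZE=1000000000` for ~1 GB buckets)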
DEFINE_ENV_INTEGER(ONEFLOW_GRAPH_NCCL_LOGICAL_FUSION_BUCKET_SIZE, 5e8); namespace { class NcclLogicalOpFusionPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(NcclLogicalOpFusionPass); NcclLogicalOpFusionPass() = default; ~NcclLogicalOpFusionPass() = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } bool IsEnabled(const JobPassCtx& ctx) const { return Singleton::Get()->nccl_use_compute_stream() && EnableNcclLogicalFusion(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; const std::string kNcclLogicalFusionOpNamePrefix = "Sys-NCCL-Logical-Fusion"; bool IsNcclLogicalOpNode(const OpNode* node) { if (node->op().op_conf().has_user_conf()) { const std::string& user_type_name = node->op().op_conf().user_conf().op_type_name(); if (user_type_name == "_nccl_logical_all_reduce" || user_type_name == "_nccl_logical_reduce_scatter" || user_type_name == "_nccl_logical_reduce_scatter_noncontinuous" || user_type_name == "_nccl_logical_all_gather" || user_type_name == "_nccl_logical_all_gather_noncontinuous" || user_type_name == "_nccl_logical_s2s" || user_type_name == "_nccl_logical_2D_same_dim0_all_reduce" || user_type_name == "_nccl_logical_2D_same_dim0_all_gather" || user_type_name == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous" || user_type_name == "_nccl_logical_2D_same_dim0_all2all" || user_type_name == "_nccl_logical_2D_same_dim1_all_reduce" /* || user_type_name == "_nccl_logical_send_recv" */) { // TODO(chengcheng) : support nccl send/recv kernel return true; } } return false; } Maybe ReplaceNcclOpsWithFusionOp(std::vector* nccl_fusion_ops, std::vector* nccl_fusion_op_parallel_confs, std::unordered_set* del_ops, HashMap* mut_op_name2conf, const std::vector& nccl_ops) { if (nccl_ops.size() <= 1) { return Maybe::Ok(); } const int32_t nccl_size = nccl_ops.size(); const OpNode* first_nccl = nccl_ops.front(); const OperatorConf& first_nccl_conf = first_nccl->op().op_conf(); const ParallelDesc& seed_placement = first_nccl->parallel_desc(); const int64_t scope_symbol_id = first_nccl_conf.scope_symbol_id(); std::vector src_nd_sbp_str_list; std::vector dst_nd_sbp_str_list; std::vector nccl_type_list; int64_t logical_chain_id = first_nccl_conf.logical_chain_id(); bool has_stream_name_hint = first_nccl_conf.has_stream_name_hint(); std::string stream_name_hint = first_nccl_conf.stream_name_hint(); user_op::UserOpConfWrapperBuilder fusion_builder = user_op::UserOpConfWrapperBuilder("Sys-NCCL-fusion-" + NewUniqueId()); fusion_builder.OpTypeName("_nccl_logical_fusion"); for (const OpNode* nccl_op : nccl_ops) { fusion_builder.Input("in", GenLogicalBlobName(nccl_op->op().BnInOp2Lbi(nccl_op->op().SoleIbn()))); src_nd_sbp_str_list.push_back( NdSbpToLongString(nccl_op->NdSbp4BnInOp(nccl_op->op().SoleIbn()))); dst_nd_sbp_str_list.push_back( NdSbpToLongString(nccl_op->NdSbp4BnInOp(nccl_op->op().SoleObn()))); nccl_type_list.push_back(nccl_op->op().op_conf().user_conf().op_type_name()); CHECK(seed_placement == nccl_op->parallel_desc()); CHECK_EQ(has_stream_name_hint, nccl_op->op().op_conf().has_stream_name_hint()); CHECK_EQ(stream_name_hint, nccl_op->op().op_conf().stream_name_hint()); // 1. 
update del op VLOG(3) << " Del op: " << nccl_op->op().op_conf().DebugString(); del_ops->insert(nccl_op->op().op_name()); } auto fusion_nccl_op = fusion_builder.Output("out", nccl_size) .Attr>("src_nd_sbp_str_list", src_nd_sbp_str_list) .Attr>("dst_nd_sbp_str_list", dst_nd_sbp_str_list) .Attr>("nccl_type_list", nccl_type_list) .ScopeSymbolId(scope_symbol_id) .Build(); OperatorConf fusion_nccl_op_conf = fusion_nccl_op.op_conf(); fusion_nccl_op_conf.set_logical_chain_id(logical_chain_id); if (has_stream_name_hint) { fusion_nccl_op_conf.set_stream_name_hint(stream_name_hint); } // 2. update fusion op VLOG(3) << " Add fusion op : " << fusion_nccl_op_conf.DebugString() << " \n with placement: " << seed_placement.parallel_conf().DebugString(); nccl_fusion_ops->push_back(fusion_nccl_op_conf); nccl_fusion_op_parallel_confs->push_back(seed_placement.parallel_conf()); for (int32_t i = 0; i < nccl_size; ++i) { std::string output_lbn = fusion_nccl_op.output("out", i); std::string input_lbn = fusion_nccl_op.input("in", i); const OpNode* origin_nccl = JUST(VectorAt(nccl_ops, i)); const OpEdge* origin_edge = origin_nccl->SoleOutEdge(); std::string origin_nccl_input_lbn = GenLogicalBlobName(origin_nccl->op().BnInOp2Lbi(origin_nccl->op().SoleIbn())); std::string origin_nccl_output_lbn = GenLogicalBlobName(origin_nccl->op().BnInOp2Lbi(origin_nccl->op().SoleObn())); CHECK_EQ(input_lbn, origin_nccl_input_lbn); const OpNode* origin_consumer = origin_edge->dst_node(); const std::string& consumer_op_name = origin_consumer->op().op_name(); if (mut_op_name2conf->find(consumer_op_name) == mut_op_name2conf->end()) { mut_op_name2conf->emplace(consumer_op_name, origin_consumer->op().op_conf()); } CHECK_EQ(origin_edge->lbis().size(), 1); const LogicalBlobId& lbi = origin_edge->lbis().front(); VLOG(3) << " input_lbn: " << input_lbn; VLOG(3) << " lbi: " << GenLogicalBlobName(lbi); CHECK_EQ(origin_nccl_output_lbn, GenLogicalBlobName(lbi)); // 3. update consumer op for (const std::string& ibn : JUST(MapAt(origin_edge->lbi2ibns(), lbi))) { std::string old_lbn = ReplaceInputLbnInOpCustomizedConf( &JUST(MapAt(*mut_op_name2conf, consumer_op_name)), ibn, output_lbn); CHECK_EQ(old_lbn, origin_nccl_output_lbn); } VLOG(3) << " Update origin consumer op from: \n [ " << origin_consumer->op().op_conf().DebugString() << " ] \n to \n [ " << JUST(MapAt(*mut_op_name2conf, consumer_op_name)).DebugString() << " ] \n"; } return Maybe::Ok(); } struct NcclFusionBucket { std::vector nccl_ops; int64_t fusion_bucket_size; NcclFusionBucket() : fusion_bucket_size(0) {} }; std::string GenNcclFusionKey(const OpNode* nccl_op) { // NOTE(chengcheng): Chain need same placement but ignore hierarchy, // logical_chain_id + hierarchy_shape can guarantee the same device_mesh. 
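  // The key therefore concatenates the logical_chain_id, the hierarchy (device mesh)
  // string, and the communicator kind derived from the nccl op type name, so only
  // collectives that can share a communicator are bucketed and fused together.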
int64_t logical_chain_id = nccl_op->op().op_conf().logical_chain_id(); const auto& hierarchy = nccl_op->parallel_desc().hierarchy(); std::string fusion_key = "logical_chain_id: " + std::to_string(logical_chain_id) + ", device_mesh: " + hierarchy->ToString() + ", comm: " + GetCommKeyFromNcclType(nccl_op->op().op_conf().user_conf().op_type_name()); return fusion_key; } int64_t GetNcclOpMemSize(const OpNode* nccl_op) { const LogicalBlobId& in_lbi = nccl_op->op().BnInOp2Lbi(nccl_op->op().SoleIbn()); const LogicalBlobId& out_lbi = nccl_op->op().BnInOp2Lbi(nccl_op->op().SoleObn()); const BlobDesc& in_logical_blob_desc = nccl_op->LogicalBlobDesc4Lbi(in_lbi); const BlobDesc& out_logical_blob_desc = nccl_op->LogicalBlobDesc4Lbi(out_lbi); const std::shared_ptr in_local_shape = CHECK_JUST(GetPhysicalShape( in_logical_blob_desc.shape(), nccl_op->NdSbp4Lbi(in_lbi), nccl_op->parallel_desc(), 0)); const std::shared_ptr out_local_shape = CHECK_JUST(GetPhysicalShape( out_logical_blob_desc.shape(), nccl_op->NdSbp4Lbi(out_lbi), nccl_op->parallel_desc(), 0)); int64_t elem_cnt = std::max(in_local_shape->elem_cnt(), out_local_shape->elem_cnt()); return GetCudaAlignedSize(elem_cnt * GetSizeOfDataType(in_logical_blob_desc.data_type())); } void AppendOrCreatFusionBucket(std::vector* buckets, const OpNode* nccl_op, const int64_t bucket_limit) { const int64_t nccl_mem_size = GetNcclOpMemSize(nccl_op); for (auto& fusion_bucket : *buckets) { if (fusion_bucket.fusion_bucket_size + nccl_mem_size < bucket_limit) { fusion_bucket.nccl_ops.push_back(nccl_op); fusion_bucket.fusion_bucket_size += nccl_mem_size; return; } } buckets->push_back(NcclFusionBucket()); buckets->back().nccl_ops.push_back(nccl_op); buckets->back().fusion_bucket_size += nccl_mem_size; } Maybe NcclLogicalOpFusionPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { HashMap op_node2nccl_depth; HashMap> nccl_depth2nccl_ops; auto ConstForEachDataAndCtrlInNode = [&](const OpNode* node, const std::function& Handler) { node->ForEachNodeOnInEdge(Handler); for (const auto& ctrl_in_op_name : node->op().op_conf().ctrl_in_op_name()) { const OpNode* in_node = op_graph.OpNode4OpName(ctrl_in_op_name); CHECK(in_node) << " cannot find ctrl_in_op_name: [" << ctrl_in_op_name << "] of op: [" << node->op().op_name() << "] in OpGraph. 
"; Handler(in_node); } }; std::vector ordered_op_nodes; if (ParseBooleanFromEnv("DISABLE_LOGICAL_STRAIGHTEN", false)) { op_graph.TopoForEachNodeWithCtrlEdge( [&](const OpNode* node) { ordered_op_nodes.emplace_back(node); }); } else { auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes); } for (const OpNode* node : ordered_op_nodes) { int64_t nccl_depth = 0; ConstForEachDataAndCtrlInNode(node, [&](const OpNode* in_node) { auto it = op_node2nccl_depth.find(in_node); CHECK(it != op_node2nccl_depth.end()); // topo search nccl_depth = std::max(nccl_depth, it->second); }); if (IsNcclLogicalOpNode(node)) { nccl_depth++; // ONLY nccl node update depth nccl_depth2nccl_ops[nccl_depth].push_back(node); } CHECK(op_node2nccl_depth.emplace(node, nccl_depth).second); } if (nccl_depth2nccl_ops.empty()) { return Maybe::Ok(); } std::vector nccl_fusion_ops; std::vector nccl_fusion_op_parallel_confs; std::unordered_set del_ops; HashMap mut_op_name2conf; const int64_t bucket_limit = EnvInteger(); VLOG(2) << "bucket_limit = " << bucket_limit; for (const auto& pair : nccl_depth2nccl_ops) { HashMap> fusion_key2nccl_buckets; for (const OpNode* nccl_op : pair.second) { CHECK(nccl_op->op().op_conf().has_logical_chain_id()); std::string fusion_key = GenNcclFusionKey(nccl_op); AppendOrCreatFusionBucket(&fusion_key2nccl_buckets[fusion_key], nccl_op, bucket_limit); } for (const auto& pair : fusion_key2nccl_buckets) { for (const auto& fusion_bucket : pair.second) { JUST(ReplaceNcclOpsWithFusionOp(&nccl_fusion_ops, &nccl_fusion_op_parallel_confs, &del_ops, &mut_op_name2conf, fusion_bucket.nccl_ops)); } } } job_builder->RemoveOpByName(del_ops); for (const auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); } CHECK_EQ(nccl_fusion_ops.size(), nccl_fusion_op_parallel_confs.size()); for (int32_t i = 0; i < nccl_fusion_ops.size(); ++i) { JUST(job_builder->AddOp(JUST(VectorAt(nccl_fusion_op_parallel_confs, i)), JUST(VectorAt(nccl_fusion_ops, i)))); } return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("NcclLogicalOpFusionPass", NcclLogicalOpFusionPass); } // namespace oneflow #endif // WITH_CUDA || WITH_NPU || WITH_MLU ================================================ FILE: oneflow/core/job_rewriter/normalization_exponential_average_auto_tick_rewrite_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { class NormalizationExponentialAverageAutoTickPass final : public JobPass { public: NormalizationExponentialAverageAutoTickPass() = default; ~NormalizationExponentialAverageAutoTickPass() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe NormalizationExponentialAverageAutoTickPass::Apply(Job* job, JobPassCtx* ctx) const { const JobConfigProto& job_conf = ctx->job_desc().job_conf(); if (!job_conf.has_train_conf()) { return Maybe::Ok(); } if ((!job_conf.has_num_gradient_accumulation_steps()) || job_conf.num_gradient_accumulation_steps() <= 1) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](const OpNode* node) -> Maybe { const OperatorConf& op_conf = node->op().op_conf(); if (!op_conf.has_user_conf()) { return Maybe::Ok(); } const user_op::UserOpConfWrapper user_op_conf(op_conf); if (user_op_conf.op_type_name() != "normalization" && user_op_conf.op_type_name() != "normalization_add_relu") { return Maybe::Ok(); } const std::string& x_lbn = user_op_conf.input("x", 0); const std::string& moving_mean_lbn = user_op_conf.input("moving_mean", 0); const std::string& moving_variance_lbn = user_op_conf.input("moving_variance", 0); std::string x_tick_lbn; auto GetXTick = [&]() { if (x_tick_lbn.empty()) { user_op::UserOpConfWrapperBuilder cast_to_tick_builder("System-CastToTick-" + NewUniqueId()); const auto cast_to_tick_op = cast_to_tick_builder.OpTypeName("cast_to_tick") .Input("in", x_lbn) .Output("out") .Build(); job_builder.AddOps(node->parallel_desc().parallel_conf(), {cast_to_tick_op.op_conf()}); x_tick_lbn = cast_to_tick_op.output("out", 0); } return x_tick_lbn; }; auto TrySetTickForNode = [&](const OpNode* var_node) { if (!var_node->in_edges().empty()) { return; } if (!var_node->op().op_conf().has_variable_conf()) { return; } if (var_node->op().op_conf().variable_conf().has_tick()) { return; } OperatorConf new_var_op_conf = var_node->op().op_conf(); new_var_op_conf.mutable_variable_conf()->set_tick(GetXTick()); job_builder.MutOpsOnlyOnce({new_var_op_conf}); }; TrySetTickForNode(op_graph.OpNode4OpName(GenLogicalBlobId(moving_mean_lbn).op_name())); TrySetTickForNode(op_graph.OpNode4OpName(GenLogicalBlobId(moving_variance_lbn).op_name())); return Maybe::Ok(); })); return Maybe::Ok(); } REGISTER_JOB_PASS("NormalizationExponentialAverageAutoTickPass", NormalizationExponentialAverageAutoTickPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/optimizer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h" #include namespace oneflow { void GenerateOptimizerOpConfWrapperStruct::Call(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) const { (*func_)(ctx, var_op_node, model_diff_lbn, optimizer_conf, job_builder); } void AddOptimizerOp(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const auto optimizer_case = optimizer_conf.normal_mdupdt_case(); auto* obj = NewObj(optimizer_case); obj->Call(ctx, var_op_node, model_diff_lbn, optimizer_conf, job_builder); } float GetOptimizerWeightDecayRate(const OptimizerConf& optimizer_conf, const VariableOp& op) { if (optimizer_conf.has_weight_decay_conf()) { const WeightDecayConf& weight_decay_conf = optimizer_conf.weight_decay_conf(); std::function WeightDecayFilter; if (weight_decay_conf.has_includes()) { WeightDecayFilter = [&](const std::string& op_name) { return std::any_of( weight_decay_conf.includes().pattern().cbegin(), weight_decay_conf.includes().pattern().cend(), [&](const std::string& pattern) { return RE2::PartialMatch(op_name, pattern); }); }; } else if (weight_decay_conf.has_excludes()) { WeightDecayFilter = [&](const std::string& op_name) { return !std::any_of( weight_decay_conf.excludes().pattern().cbegin(), weight_decay_conf.excludes().pattern().cend(), [&](const std::string& pattern) { return RE2::PartialMatch(op_name, pattern); }); }; } else { WeightDecayFilter = [&](const std::string& op_name) { return true; }; } if (WeightDecayFilter(op.op_name())) { return weight_decay_conf.weight_decay_rate(); } else { return 0; } } else { return 0; } } void SetDynamicLossScaleSkipIf(JobPassCtx* ctx, user_op::UserOpConfWrapperBuilder* builder) { if (!ctx->job_desc().job_conf().train_conf().has_dynamic_loss_scale_policy()) { return; } builder->Input("skip_if", CHECK_JUST(ctx->GetState("dynamic_loss_scale_state")) .count_not_finite_lbn()); } } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/optimizer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_JOB_REWRITER_OPTIMIZER_H_ #define ONEFLOW_CORE_JOB_REWRITER_OPTIMIZER_H_ #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/operator/variable_op.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/user_op_conf.h" namespace oneflow { void AddOptimizerOp(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder); float GetOptimizerWeightDecayRate(const OptimizerConf& optimizer_conf, const VariableOp& op); void SetDynamicLossScaleSkipIf(JobPassCtx* ctx, user_op::UserOpConfWrapperBuilder* builder); class GenerateOptimizerOpConfWrapperStruct final { public: using Func = std::function; GenerateOptimizerOpConfWrapperStruct(const Func& f) : func_(std::make_unique(f)) {} void Call(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) const; private: const std::unique_ptr func_; }; #define REGISTER_OPTIMIZER(model_update_case, gen_optimizer_conf_func) \ REGISTER_CLASS_CREATOR( \ int32_t, model_update_case, GenerateOptimizerOpConfWrapperStruct, \ ([] { return new GenerateOptimizerOpConfWrapperStruct(gen_optimizer_conf_func); })) } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_OPTIMIZER_H_ ================================================ FILE: oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/core/common/util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { int64_t GetSoleOutBlobSize(const OpNode* node) { const BlobDesc& blob_desc = node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(node->op().SoleObn())); return blob_desc.shape().elem_cnt() * GetSizeOfDataType(blob_desc.data_type()); } class DataParallelNodeSequence final { public: DataParallelNodeSequence(std::vector nodes, int64_t order) : nodes_(std::move(nodes)), order_(order), len_(nodes_.size()) { const OpNode* var_node = nodes_.front(); CHECK(var_node->op().op_conf().has_variable_conf()); model_size_ = GetSoleOutBlobSize(var_node); } ~DataParallelNodeSequence() = default; const OpNode* GetVariableNode() const { return nodes_.front(); } const OpNode* GetLastNode() const { return nodes_.back(); } int64_t order() const { return order_; } const std::vector& nodes() const { return nodes_; } const ParallelDesc& parallel_desc() const { return nodes_.front()->parallel_desc(); } int64_t model_size() const { return model_size_; } int64_t len() const { return len_; } void resize(const int64_t size) { CHECK_LE(size, len_); CHECK_GE(size, 1); nodes_.resize(size); len_ = nodes().size(); } private: std::vector nodes_; int64_t order_; int64_t model_size_; int64_t len_; }; using SequencePtr = std::shared_ptr; ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd, const int64_t parallel_id) { std::string device_name; device_name += std::to_string(CHECK_JUST(pd.MachineId4ParallelId(parallel_id))); device_name += ":"; device_name += std::to_string(CHECK_JUST(pd.DeviceId4ParallelId(parallel_id))); ParallelConf parallel_conf; *parallel_conf.mutable_device_name()->Add() = device_name; parallel_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(pd.device_type()))); return parallel_conf; } Maybe GetDataParallelVariableAndNaiveSuccNode( const OpNode* start, const std::function& IsAllowed, std::vector* out) { // Find sequence like: vairable -> cast_fp32_to_fp16 if (!start->op().op_conf().has_variable_conf()) { return Maybe::Ok(); } const ParallelDesc& pd = start->parallel_desc(); if (pd.parallel_num() == 1) { return Maybe::Ok(); } const OpNode* cur_node = start; while (cur_node != nullptr) { if (cur_node != start) { if (cur_node->parallel_desc() != pd) { break; } if (cur_node->in_edges().size() > 1) { break; } if (cur_node->op().input_bns().size() != 1) { break; } const std::string& sole_ibn = cur_node->op().SoleIbn(); const NdSbp& ibn_nd_sbp = cur_node->NdSbp4BnInOp(sole_ibn); bool has_broadcast = false; FOR_RANGE(int, i, 0, ibn_nd_sbp.sbp_parallel_size()) { if (ibn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }; } if (!has_broadcast) { break; } } if (cur_node->op().output_bns().size() != 1) { break; } const std::string& sole_obn = cur_node->op().SoleObn(); const NdSbp& obn_nd_sbp = cur_node->NdSbp4BnInOp(sole_obn); bool has_broadcast = false; FOR_RANGE(int, i, 0, obn_nd_sbp.sbp_parallel_size()) { if (obn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }; } if (!has_broadcast) 
{ break; } out->emplace_back(cur_node); if (cur_node->out_edges().size() == 1) { cur_node = cur_node->SoleOutEdge()->dst_node(); } else { cur_node = nullptr; } } return Maybe::Ok(); } void SetBroadcastParallel4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::string& ibn) { OpBlobArg op_blob_arg; op_blob_arg.set_op_name(node->op().op_name()); op_blob_arg.set_bn_in_op(ibn); SbpParallel sbp_parallel; sbp_parallel.mutable_broadcast_parallel(); builder->SetSbpParallel4Oba(op_blob_arg, sbp_parallel); } void SetBroadcastParallel4Consumers(JobBuilder* builder, const SequencePtr& sequence) { const OpNode* node = sequence->GetLastNode(); const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn()); node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == lbi) { SetBroadcastParallel4OpNodeIbn(builder, out_node, ibn); } } }); } void SetNdSbp4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::string& ibn, const NdSbp& nd_sbp) { OpBlobArg op_blob_arg; op_blob_arg.set_op_name(node->op().op_name()); op_blob_arg.set_bn_in_op(ibn); builder->SetNdSbp4Oba(op_blob_arg, nd_sbp); } void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const NdSbp& nd_sbp) { const OpNode* node = sequence->GetLastNode(); const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn()); const int64_t shard_restore_level = builder->job().job_conf().optimizer_placement_optimization_shard_restore_level(); // If shard_restore_level == 0, no limit on consumer if (shard_restore_level == 1) { // Input lbn for parallel cast op std::string parallel_cast_input_lbn = GenLogicalBlobName(lbi); // Add parallel cast op to make soft limt on consumer to consume weight with Broadcast SBP. const auto parallel_cast_op = user_op::UserOpConfWrapperBuilder("System-ZeRO-ParallelCast-" + node->op().op_name() + "-" + NewUniqueId()) .Op("hierarchical_parallel_cast") .Input("in", parallel_cast_input_lbn) .Output("out") .Attr>("nd_sbp", NdSbpToStringList(nd_sbp)) .Attr("grad_mode", "identity") // don't do ndsbp cast at backward .Attr>("grad_nd_sbp", std::vector()) .ScopeSymbolId(node->op().op_conf().scope_symbol_id()) .Build(); builder->AddOps(node->parallel_desc().parallel_conf(), {parallel_cast_op.op_conf()}); // Make consumers to consume parallel cast op auto out_lbn = parallel_cast_op.output("out", 0); node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == lbi) { if (!CHECK_JUST(builder->IsInMutOpTransaction(out_node->op().op_name()))) { CHECK_JUST(builder->MutOpTransactionMut(out_node->op().op_conf())); } OperatorConf& mut_consumer_op = CHECK_JUST(builder->MutOpTransactionGet(out_node->op().op_name())); const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, out_lbn); CHECK_EQ(old_lbn, GenLogicalBlobName(lbi)); } } }); } else if (shard_restore_level == 2) { // Hard limt consumer to consume weight as Broadcast. 
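    // Recap of the shard_restore_level values handled in this function (illustrative only,
    // assuming the variable's original NdSbp was all-Broadcast, i.e. {"B"}):
    //   0: no constraint; consumers restore the sharded weight however SBP inference decides.
    //   1: the branch above inserts a "hierarchical_parallel_cast" op whose "nd_sbp" attr is the
    //      original NdSbp (e.g. {"B"}), so consumers read the cast output; a soft constraint.
    //   2: the code below pins each consumer input blob to the original NdSbp directly via
    //      SetNdSbp4OpNodeIbn; a hard constraint with no extra op inserted.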
node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == lbi) { SetNdSbp4OpNodeIbn(builder, out_node, ibn, nd_sbp); } } }); } } std::function MakeGetterOpNode2TopoOrder(const OpGraph& op_graph) { HashMap op_node2topo_order; int64_t node_cnt = 0; op_graph.TopoForEachNode([&](const OpNode* node) { op_node2topo_order[node] = node_cnt; node_cnt += 1; }); return [op_node2topo_order](const OpNode* node) { return op_node2topo_order.at(node); }; } int64_t GetMinConsumerOrder(const OpGraph& op_graph, const OpNode* node, const std::function& OpNode2Order) { int64_t min_consumer_topo_order = op_graph.node_num(); node->ForEachNodeOnOutEdge([&](const OpNode* dst) { min_consumer_topo_order = std::min(min_consumer_topo_order, OpNode2Order(dst)); }); return min_consumer_topo_order; } void ForEachDataParallelNodeSequence(const OpGraph& op_graph, const std::function& IsAllowed, std::function Handler) { auto OpNode2Order = MakeGetterOpNode2TopoOrder(op_graph); op_graph.ForEachNode([&](const OpNode* node) { std::vector nodes; // Find sequence like: vairable -> cast_fp32_to_fp16 CHECK_JUST(GetDataParallelVariableAndNaiveSuccNode(node, IsAllowed, &nodes)); if (nodes.empty()) { return; } const int64_t order = GetMinConsumerOrder(op_graph, nodes.back(), OpNode2Order); Handler(std::make_shared(std::move(nodes), order)); }); } bool SequenceCompSortedByOrderAsc(const SequencePtr& lhs, const SequencePtr& rhs) { return lhs->order() < rhs->order(); } bool SequenceCompSortedByModelSizeDesc(const SequencePtr& lhs, const SequencePtr& rhs) { return lhs->model_size() > rhs->model_size(); } void ForEachParallelSortedNodeSequence( const OpGraph& op_graph, const std::function& IsAllowed, const std::function& Comp, const std::function&&)>& Handler) { HashMap> parallel_desc2sequences; // Find sequence like: vairable -> cast_fp32_to_fp16 ForEachDataParallelNodeSequence(op_graph, IsAllowed, [&](SequencePtr&& sequence) { parallel_desc2sequences[sequence->parallel_desc()].emplace_back(std::move(sequence)); }); for (auto& pair : parallel_desc2sequences) { auto& sequences = pair.second; std::sort(sequences.begin(), sequences.end(), Comp); Handler(pair.first, std::move(sequences)); } } bool IsS0Parallel(const SbpParallel& sbp_parallel) { return sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0; } bool IsS0Parallel(const SbpSignature& signature, const std::string& bn) { return IsS0Parallel(signature.bn_in_op2sbp_parallel().at(bn)); } bool IsNdSbpMatch(const NdSbpSignature& signature, const std::string& bn, const NdSbp& nd_sbp) { return signature.bn_in_op2nd_sbp().at(bn) == nd_sbp; } bool IsNdSbpSupported4Op(const OpNode* node, const NdSbp& nd_sbp) { if (node->op().input_bns().size() != 1 || node->op().output_bns().size() != 1) { return false; } std::vector list; auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe { return Maybe(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn))); }; CHECK_JUST(node->op().GetNdSbpSignatureList(LogicalBlobDesc4Ibn, node->parallel_desc(), &list)); const auto IsInAndOutMatch = [&](const NdSbpSignature& signature) { return IsNdSbpMatch(signature, node->op().SoleIbn(), nd_sbp) && IsNdSbpMatch(signature, node->op().SoleObn(), nd_sbp); }; return std::any_of(list.cbegin(), list.cend(), IsInAndOutMatch); } bool IsS0SignatureSupported(const OpNode* node) { if (node->op().input_bns().size() != 1 || node->op().output_bns().size() != 1) { return false; } SbpSignatureList list; 
auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe { return Maybe(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn))); }; CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, node->parallel_desc().parallel_num(), &list)); const auto IsInOutS0Parallel = [&](const SbpSignature& signature) { return IsS0Parallel(signature, node->op().SoleIbn()) && IsS0Parallel(signature, node->op().SoleObn()); }; return std::any_of(list.sbp_signature().cbegin(), list.sbp_signature().cend(), IsInOutS0Parallel); } void ForEachModelSizeBalancedPartition( const ParallelDesc& parallel_desc, std::vector&& sorted_sequences, const std::function&&)>& Handler) { std::vector sequences = std::move(sorted_sequences); std::vector parallel_id2model_size(parallel_desc.parallel_num(), 0); std::vector> partitions(parallel_desc.parallel_num()); for (auto& sequence : sequences) { const auto it = std::min_element(parallel_id2model_size.cbegin(), parallel_id2model_size.cend()); const int64_t min_parallel_id = std::distance(parallel_id2model_size.cbegin(), it); parallel_id2model_size.at(min_parallel_id) += sequence->model_size(); partitions.at(min_parallel_id).emplace_back(std::move(sequence)); } for (int64_t i = 0; i < parallel_desc.parallel_num(); ++i) { ParallelConf parallel_conf = NonDistributedParallelConf4ParallelId(parallel_desc, i); Handler(parallel_conf, std::move(partitions.at(i))); } } namespace { bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy, int64_t min_size) { if (shape.NumAxes() < 1 || shape.elem_cnt() < 1) { return false; } CHECK_EQ(nd_sbp.sbp_parallel_size(), hierachy.NumAxes()); Shape cur_shape = shape; if (cur_shape.elem_cnt() < min_size) { return false; } FOR_RANGE(int64_t, i, 0, hierachy.NumAxes()) { const auto& sbp = nd_sbp.sbp_parallel(i); if (sbp.has_split_parallel()) { const int64_t dim = sbp.split_parallel().axis(); if (dim >= cur_shape.NumAxes()) { return false; } // Unbalanced split and take the minimum cur_shape.Set(dim, cur_shape.At(dim) / hierachy.At(i)); // Larger then min size. if (cur_shape.elem_cnt() < min_size) { return false; } } } return true; } void GenerateSplitSignature(const NdSbp& var_nd_sbp, const OperatorConf& new_var_op_conf, std::string& new_split_signature, int64_t& split_dim) { if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) { // split last dim split_dim = new_var_op_conf.variable_conf().nd_sbp_size() - 1; // All B, B -> S0 new_split_signature = "S(0)"; } else { // ND sbp, (*, B, S, *) -> (*, S, S, *) // ND sbp, (*, S, B, *) -> (*, S, S, *) FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { if (new_var_op_conf.variable_conf().nd_sbp(j) == "B") { std::vector adjacent_dim{j - 1, j + 1}; for (auto const& dim_to_try : adjacent_dim) { if (dim_to_try >= 0 && dim_to_try < new_var_op_conf.variable_conf().nd_sbp_size()) { SbpParallel sbp; if (ParseSbpParallelFromString(new_var_op_conf.variable_conf().nd_sbp(dim_to_try), &sbp) && sbp.has_split_parallel()) { new_split_signature = new_var_op_conf.variable_conf().nd_sbp(dim_to_try); split_dim = j; } } if (new_split_signature != "") break; } } // Only split one more dim. if (new_split_signature != "") break; } } } void ShardSequence(JobBuilder* builder, const int64_t threshold, const ParallelDesc& pd, std::vector&& sorted_sequences) { // For all sorted sequence, set the variable op in the sequence to S // and add ctrl edge to control the execution order between variable ops. 
// A sequence is a variable op and its cast(fp32 to fp16) op. This is because the forward pass // consume the fp16 variable and the optimizer consume the fp32 variable. std::string prev_allowed_op_name = ""; for (int64_t i = 0; i < sorted_sequences.size(); ++i) { const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode(); OperatorConf new_var_op_conf = var_node->op().op_conf(); const std::string& sole_obn = var_node->op().SoleObn(); const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); const Shape& logical_shape = Shape(new_var_op_conf.variable_conf().shape()); std::string new_split_signature = ""; int64_t split_dim = 0; GenerateSplitSignature(var_nd_sbp, new_var_op_conf, new_split_signature, split_dim); if (new_split_signature != "") { *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(split_dim) = new_split_signature; } else { continue; } bool split_is_allowed = true; { NdSbp new_nd_sbp; std::vector nd_sbp_str_vec; for (const auto& sbp_str : new_var_op_conf.variable_conf().nd_sbp()) { nd_sbp_str_vec.emplace_back(sbp_str); } ParseNdSbpFromStringList(nd_sbp_str_vec, &new_nd_sbp); // check allowed by min shard size and evenly split if (split_is_allowed) { split_is_allowed = IsSplitValid(logical_shape, new_nd_sbp, *pd.hierarchy(), threshold); } if (split_is_allowed) { // resize sequence by new nd sbp limit auto& cur_seq = sorted_sequences.at(i); int64_t max_len = 1; if (cur_seq->len() > 1) { FOR_RANGE(int64_t, node_idx, 1, cur_seq->len()) { if (IsNdSbpSupported4Op(cur_seq->nodes().at(node_idx), new_nd_sbp)) { ++max_len; } else { break; } } } if (max_len < cur_seq->len()) { cur_seq->resize(max_len); } } } if (!split_is_allowed) { VLOG(3) << var_node->op().op_name() << " failed to change from B to S " << " with op conf " << new_var_op_conf.variable_conf().DebugString(); continue; } if (!prev_allowed_op_name.empty()) { new_var_op_conf.add_ctrl_in_op_name(prev_allowed_op_name); } builder->MutOpsOnlyOnce({new_var_op_conf}); // Set consumers to consum this variable op's cast op's output as Broadcast. if (new_split_signature != "") { SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); } prev_allowed_op_name = var_node->op().op_name(); VLOG(3) << var_node->op().op_name() << " succeed to change from B to " << new_split_signature << " on ranks dim " << split_dim << " with op conf " << new_var_op_conf.variable_conf().DebugString(); } } } // namespace Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) { const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold(); const auto IsAllowed = [](const OpNode* n) -> bool { // No need to limit here. 
    return true;
  };
  const auto PlacementSequencesAsSplitParallel =
      [&](const ParallelDesc& pd, std::vector<SequencePtr>&& sorted_sequences) {
        ShardSequence(builder, threshold, pd,
                      std::forward<std::vector<SequencePtr>>(sorted_sequences));
      };
  ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc,
                                    PlacementSequencesAsSplitParallel);
  JUST(builder->MutOpTransactionCommit());
  return Maybe<void>::Ok();
}

Maybe<void> RewriteNonDistributed(const OpGraph& op_graph, JobBuilder* builder) {
  HashMap<ParallelDesc, std::vector<SequencePtr>> new_parallel_desc2sequences;
  const auto RewritePartition = [&](const ParallelDesc& new_parallel_desc,
                                    std::vector<SequencePtr>&& partition) {
    for (auto& sequence : partition) {
      for (const OpNode* op_node : sequence->nodes()) {
        builder->MutParallelConfOnlyOnce(op_node->op().op_name(),
                                         new_parallel_desc.parallel_conf());
      }
      SetBroadcastParallel4Consumers(builder, sequence);
      new_parallel_desc2sequences[new_parallel_desc].emplace_back(std::move(sequence));
    }
  };
  const auto RewriteSequences = [&](const ParallelDesc& pd,
                                    std::vector<SequencePtr>&& sorted_sequences) {
    ForEachModelSizeBalancedPartition(pd, std::move(sorted_sequences), RewritePartition);
  };
  const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold();
  const auto IsAllowed = [threshold](const OpNode* n) -> bool {
    if (n->op().op_conf().has_variable_conf()) {
      const Shape shape(n->op().op_conf().variable_conf().shape());
      const int64_t parallel_num = n->parallel_desc().parallel_num();
      return shape.elem_cnt() >= threshold * parallel_num;
    } else {
      return true;
    }
  };
  ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByModelSizeDesc,
                                    RewriteSequences);
  for (auto& parallel_desc7sequences : new_parallel_desc2sequences) {
    auto& sequences = parallel_desc7sequences.second;
    std::sort(sequences.begin(), sequences.end(), SequenceCompSortedByOrderAsc);
    for (int64_t i = 1; i < sequences.size(); ++i) {
      const OpNode* cur_var_node = sequences.at(i)->GetVariableNode();
      OperatorConf cur_var_conf(cur_var_node->op().op_conf());
      // Chain each variable op to the previous one (index i - 1) with a ctrl edge,
      // following the ascending consumer order.
      const OpNode* prev_var_node = sequences.at(i - 1)->GetVariableNode();
      cur_var_conf.add_ctrl_in_op_name(prev_var_node->op().op_name());
      builder->MutOpsOnlyOnce({cur_var_conf});
    }
  }
  return Maybe<void>::Ok();
}

class OptimizerPlacementOptimizationPass final : public JobPass {
 public:
  OF_DISALLOW_COPY_AND_MOVE(OptimizerPlacementOptimizationPass);
  OptimizerPlacementOptimizationPass() = default;
  ~OptimizerPlacementOptimizationPass() override = default;

  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {
    if (!(ctx->job_desc().IsTrain()
          && ctx->job_desc().job_conf().has_optimizer_placement_optimization_mode()
          && ctx->job_desc().job_conf().optimizer_placement_optimization_mode() != "none")) {
      return Maybe<void>::Ok();
    }
    if (job->job_conf().enable_auto_parallel()
        && job->job_conf().enable_auto_parallel_ignore_user_sbp_config()) {
      LOG(WARNING) << "ZeRO optimization will be ignored when enabling AutoParallel to ignore user "
                      "sbp configuration";
      if (job->job_conf().enable_auto_memory() != oneflow::AutoMemoryStrategy::kHeavyAutoMemory) {
        job->mutable_job_conf()->set_enable_auto_memory(
            ::oneflow::AutoMemoryStrategy::kModerateAutoMemory);
        LOG(WARNING) << "But we turn on moderate auto memory to reduce the memory, which has "
                        "similar effect as the ZeRO optimization";
      }
      return Maybe<void>::Ok();
    }
    const std::string& mode = ctx->job_desc().job_conf().optimizer_placement_optimization_mode();
    const OpGraph op_graph(*job);
    JobBuilder job_builder(job);
    if (mode == "non_distributed") {
      return RewriteNonDistributed(op_graph, &job_builder);
    } else if (mode == "distributed_split") {
return RewriteDistributedSplit(op_graph, &job_builder); } else { return Error::UnimplementedError(); } } }; REGISTER_JOB_PASS("OptimizerPlacementOptimizationPass", OptimizerPlacementOptimizationPass); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/pass_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/pass_util.h" namespace oneflow { bool IsNodeInList(const HashSet& op_list, OpNode* node) { if (node->op().op_conf().has_user_conf() == false) { return false; } const std::string op_type = node->op().op_conf().user_conf().op_type_name(); return IsKeyFound(op_list, op_type); } std::string ReplaceSlashToDash4Lbn(std::string lbn) { std::replace(lbn.begin(), lbn.end(), '/', '-'); return lbn; } void DfsTopoGraphTraversal(const OpGraph& graph, bool reversed, std::function IsCurNodeStartNode, std::function IsCurNodeSatisfied, std::function IsFatherNodeSatisfied, std::function NodeHandler) { auto start_nodes = reversed ? graph.sink_nodes() : graph.source_nodes(); std::function)> NodeOnInEdge = reversed ? &OpNode::ForEachNodeOnOutEdge : &OpNode::ForEachNodeOnInEdge; std::function)> NodeOnOutEdge = reversed ? &OpNode::ForEachNodeOnInEdge : &OpNode::ForEachNodeOnOutEdge; graph.DfsTopoForEachNode(start_nodes, NodeOnInEdge, NodeOnOutEdge, [&](OpNode* node) { if (IsCurNodeStartNode(node)) { NodeHandler(node); return; } if (IsCurNodeSatisfied(node)) { bool is_one_father_of_node_satisfied = false; NodeOnInEdge(node, [&](OpNode* father_node) { if (is_one_father_of_node_satisfied) { return; } if (IsFatherNodeSatisfied(father_node)) { is_one_father_of_node_satisfied = true; } }); if (is_one_father_of_node_satisfied) { NodeHandler(node); } } }); } std::function MakePredicatorIsSafeToDelete(const OpGraph& op_graph) { HashSet ctrl_in_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) { ctrl_in_op_names.insert(ctrl_in_op_name); } }); return [=](const OpNode* op_node) { if (op_node->out_edges().size() > 1) { return false; } if (!op_node->op().op_conf().ctrl_in_op_name().empty()) { return false; } if (ctrl_in_op_names.find(op_node->op().op_conf().name()) != ctrl_in_op_names.end()) { return false; } return true; }; } bool IsUserOpWithTypeName(const OperatorConf& op_conf, const std::string& op_type_name) { return op_conf.has_user_conf() && op_conf.user_conf().op_type_name() == op_type_name; } std::string GenParallelConfKey(const ParallelConf& conf) { std::string ret = conf.device_tag(); for (const auto& name : conf.device_name()) { ret += ("-" + name); } return ret; } } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/pass_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_JOB_REWRITER_PASS_UTIL_H_ #define ONEFLOW_CORE_JOB_REWRITER_PASS_UTIL_H_ #include #include #include "oneflow/core/graph/op_graph.h" namespace oneflow { #define INSERT_CHECK(expr) CHECK(expr.second) #define INSERT_CHECK_OR_RETURN(expr) CHECK_OR_RETURN(expr.second) template bool IsKeyFound(const MapT& m, const KeyT& k) { return m.find(k) != m.end(); } bool IsNodeInList(const HashSet& op_list, OpNode* node); template std::string Container2Str(const ContainerT& container, std::function elem2str) { std::string ret; bool is_first = true; for (const ElemT& elem : container) { if (is_first) { is_first = false; } else { ret += ",\n"; } ret += elem2str(elem); } return ret; } std::string ReplaceSlashToDash4Lbn(std::string lbn); void DfsTopoGraphTraversal(const OpGraph& graph, bool reversed, std::function IsCurNodeStartNode, std::function IsCurNodeSatisfied, std::function IsFatherNodeSatisfied, std::function NodeHandler); // make sure an op_conf can only be udpated once, cuz later update will override before class OpConfCache { std::map _op_confs_to_update; public: OperatorConf GetLatest(const OperatorConf& op_conf) { if (_op_confs_to_update.find(op_conf.name()) != _op_confs_to_update.end()) { return _op_confs_to_update[op_conf.name()]; } return op_conf; } void Put(const OperatorConf& op_conf) { _op_confs_to_update[op_conf.name()] = op_conf; } std::vector op_confs() { std::vector ret; for (const auto& x : _op_confs_to_update) { ret.emplace_back(x.second); } return ret; } }; std::function MakePredicatorIsSafeToDelete(const OpGraph& op_graph); bool IsUserOpWithTypeName(const OperatorConf& op_conf, const std::string& op_type_name); std::string GenParallelConfKey(const ParallelConf& conf); } // namespace oneflow #endif // ONEFLOW_CORE_JOB_REWRITER_PASS_UTIL_H_ ================================================ FILE: oneflow/core/job_rewriter/pipeline_buffer_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job_rewriter/calculation_pass.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { class PipelineBufferPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(PipelineBufferPass); PipelineBufferPass() = default; ~PipelineBufferPass() = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } bool IsEnabled(const JobPassCtx& ctx) const { // Pipeline optimization depends on gradient accumulatioin. return ctx.job_desc().IsTrain() && ctx.job_desc().job_conf().num_gradient_accumulation_steps() > 1; } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; const std::string kBufferOpNamePrefix = "System-Pipeline-Buffer-Op_"; const Scope& Scope4ScopeSymbolId(int64_t scope_symbol_id) { CHECK(Singleton>::Get()->Has(scope_symbol_id)); return Singleton>::Get()->Get(scope_symbol_id); } const Scope& Scope4OpNode(const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); CHECK(op_conf.has_scope_symbol_id()); return Scope4ScopeSymbolId(op_conf.scope_symbol_id()); } bool IsForwardPass(const OpNode* node) { return Scope4OpNode(node).scope_proto().calculation_pass_name() == kForwardPass; } bool IsBackwardPass(const OpNode* node) { return Scope4OpNode(node).scope_proto().calculation_pass_name() == kBackwardPass; } bool OpNodeHasScope(const OpNode* node) { return node->op().op_conf().has_scope_symbol_id(); } bool IsIdentityBufferOrRepeatOpNode(const OpNode* node) { const OperatorConf& op_conf = node->op().op_conf(); if (op_conf.has_user_conf()) { const std::string& op_type_name = op_conf.user_conf().op_type_name(); if (op_type_name == "identity_buffer" || op_type_name == "repeat") { return true; } } return false; } int64_t GetStageIdHint(const OpNode* node) { return Scope4OpNode(node).Int64("pipeline_stage_id_hint"); } void TryInsertOrUseBufferOpToDstNode( const OpEdge* op_edge, const int64_t buffer_size, HashMap* buffer_op_name2op_conf, HashMap* buffer_op_name2parallel_conf, HashMap* mut_op_name2conf) { const OpNode* src_node = op_edge->src_node(); const OpNode* dst_node = op_edge->dst_node(); const int64_t src_stage_id = GetStageIdHint(src_node); const int64_t dst_stage_id = GetStageIdHint(dst_node); const std::string& dst_op_name = dst_node->op().op_name(); const int64_t stage_id = GetStageIdHint(dst_node); for (const LogicalBlobId& lbi : op_edge->lbis()) { std::string lbn = GenLogicalBlobName(lbi); std::string buffer_op_name = kBufferOpNamePrefix + "-" + lbi.op_name() + "-" + lbi.blob_name() + "-stage_id_" + std::to_string(stage_id); auto it = buffer_op_name2op_conf->find(buffer_op_name); if (it == buffer_op_name2op_conf->end()) { it = buffer_op_name2op_conf ->emplace(buffer_op_name, user_op::UserOpConfWrapperBuilder(buffer_op_name) .Op("identity_buffer") .Input("in", lbn) .Output("out") .Attr("buffer_size", buffer_size) .ScopeSymbolId(dst_node->op().op_conf().scope_symbol_id()) .Build() .op_conf()) .first; CHECK(buffer_op_name2parallel_conf ->emplace(buffer_op_name, dst_node->parallel_desc().parallel_conf()) .second); VLOG(3) << "\n Insert buffer op : [" << buffer_op_name << "](buffer_size:" << buffer_size << ") 
\n from [" << src_node->op().op_name() << "] (stage_id:" << std::to_string(src_stage_id) << ") -> [" << dst_node->op().op_name() << "] (stage_id:" << std::to_string(dst_stage_id) << ") \n"; } auto mut_op_it = mut_op_name2conf->find(dst_op_name); if (mut_op_it == mut_op_name2conf->end()) { mut_op_it = mut_op_name2conf->emplace(dst_op_name, dst_node->op().op_conf()).first; } const std::string buffer_out = user_op::UserOpConfWrapper(it->second).output("out", 0); for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) { std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(&(mut_op_it->second), ibn, buffer_out); CHECK_EQ(old_lbn, lbn); } } } void TryInsertOrUseBufferOpBothSrcDst( const OpEdge* op_edge, const int64_t src_buffer_size, const int64_t dst_buffer_size, HashMap* buffer_op_name2op_conf, HashMap* buffer_op_name2parallel_conf, HashMap* mut_op_name2conf) { const OpNode* src_node = op_edge->src_node(); const OpNode* dst_node = op_edge->dst_node(); const ParallelDesc& src_parallel_desc = src_node->parallel_desc(); const ParallelDesc& dst_parallel_desc = dst_node->parallel_desc(); const std::string& src_op_name = src_node->op().op_name(); const std::string& dst_op_name = dst_node->op().op_name(); const int64_t src_stage_id = GetStageIdHint(src_node); const int64_t dst_stage_id = GetStageIdHint(dst_node); CHECK_NE(src_stage_id, dst_stage_id); CHECK_GE(src_buffer_size, 1); CHECK_GE(dst_buffer_size, 1); CHECK(!src_parallel_desc.EqualsIgnoringHierarchy(dst_parallel_desc)) << " Pipeline buffer pass meet ERROR! the src_op: " << src_op_name << " -> dst_op: " << dst_op_name << " with same placement: " << src_parallel_desc.parallel_conf().DebugString() << " , but with different stage id: src_stage_id (" << src_stage_id << ") -> dst_stage_id (" << dst_stage_id << "). 
Please check your stage id config for modules."; for (const LogicalBlobId& lbi : op_edge->lbis()) { std::string lbn = GenLogicalBlobName(lbi); std::string src_buffer_op_name = kBufferOpNamePrefix + "-" + lbi.op_name() + "-" + lbi.blob_name(); std::string dst_buffer_op_name = kBufferOpNamePrefix + "-" + lbi.op_name() + "-" + lbi.blob_name() + "-stage_id_" + std::to_string(dst_stage_id); auto src_buffer_it = buffer_op_name2op_conf->find(src_buffer_op_name); if (src_buffer_it == buffer_op_name2op_conf->end()) { src_buffer_it = buffer_op_name2op_conf ->emplace(src_buffer_op_name, user_op::UserOpConfWrapperBuilder(src_buffer_op_name) .Op("identity_buffer") .Input("in", lbn) .Output("out") .Attr("buffer_size", src_buffer_size) .ScopeSymbolId(src_node->op().op_conf().scope_symbol_id()) .Build() .op_conf()) .first; CHECK(buffer_op_name2parallel_conf ->emplace(src_buffer_op_name, src_parallel_desc.parallel_conf()) .second); } const OperatorConf& src_conf = src_buffer_it->second; const std::string src_buffer_out = user_op::UserOpConfWrapper(src_conf).output("out", 0); auto dst_buffer_it = buffer_op_name2op_conf->find(dst_buffer_op_name); if (dst_buffer_it == buffer_op_name2op_conf->end()) { dst_buffer_it = buffer_op_name2op_conf ->emplace(dst_buffer_op_name, user_op::UserOpConfWrapperBuilder(dst_buffer_op_name) .Op("identity_buffer") .Input("in", src_buffer_out) .Output("out") .Attr("buffer_size", dst_buffer_size) .ScopeSymbolId(dst_node->op().op_conf().scope_symbol_id()) .Build() .op_conf()) .first; CHECK(buffer_op_name2parallel_conf ->emplace(dst_buffer_op_name, dst_parallel_desc.parallel_conf()) .second); } const OperatorConf& dst_conf = dst_buffer_it->second; auto mut_op_it = mut_op_name2conf->find(dst_op_name); if (mut_op_it == mut_op_name2conf->end()) { mut_op_it = mut_op_name2conf->emplace(dst_op_name, dst_node->op().op_conf()).first; } VLOG(3) << "\n Insert buffer op pair : src_buffer = <" << src_buffer_op_name << ">(buffer_size:" << src_buffer_size << ") , dst_buffer = <" << dst_buffer_op_name << ">(buffer_size:" << dst_buffer_size << ") \n from [" << src_node->op().op_name() << "] (stage_id:" << std::to_string(src_stage_id) << ") -> [" << dst_node->op().op_name() << "] (stage_id:" << std::to_string(dst_stage_id) << ") \n"; const std::string dst_buffer_out = user_op::UserOpConfWrapper(dst_conf).output("out", 0); for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) { std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(&(mut_op_it->second), ibn, dst_buffer_out); CHECK_EQ(old_lbn, lbn); } } } Maybe PipelineBufferPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { int64_t max_stage_id = 0; op_graph.ForEachNode([&](const OpNode* this_node) { if (!OpNodeHasScope(this_node)) { LOG(WARNING) << " op : " << this_node->op().op_conf().DebugString() << " has NOT scope!"; return; } max_stage_id = std::max(max_stage_id, GetStageIdHint(this_node)); }); if (max_stage_id == 0) { return Maybe::Ok(); } const int64_t total_stage_num = max_stage_id + 1; VLOG(3) << "total stage num = " << total_stage_num; HashMap buffer_op_name2op_conf; HashMap buffer_op_name2parallel_conf; HashMap mut_op_name2conf; op_graph.ForEachNode([&](const OpNode* this_node) { if (!OpNodeHasScope(this_node)) { return; /* ignore op without scope */ } if (!IsBackwardPass(this_node)) { return; /* ignore fw dst op */ } for (const OpEdge* in_edge : this_node->in_edges()) { const OpNode* src_node = in_edge->src_node(); if (!OpNodeHasScope(src_node)) { continue; /* ignore op without scope */ } const int64_t 
src_stage_id = GetStageIdHint(src_node); const int64_t dst_stage_id = GetStageIdHint(this_node); if (IsForwardPass(src_node) && (!IsIdentityBufferOrRepeatOpNode(src_node))) { if (dst_stage_id == max_stage_id) { continue; /* last stage(loss) does NOT need to insert buffer */ } if (src_stage_id != dst_stage_id) { LOG(WARNING) << " Cross diff stage link From: [" << src_node->op().op_conf().DebugString() << "](stage_id:" << std::to_string(src_stage_id) << ") -> [" << this_node->op().op_conf().DebugString() << "](stage_id:" << std::to_string(dst_stage_id) << "). Make sure to change the tensor's placement before it enter the module " "of a next pipeline stage.\n"; } const int64_t buffer_size = total_stage_num * 2; /* NOTE(chengcheng): max buffer size */ TryInsertOrUseBufferOpToDstNode(in_edge, buffer_size, &buffer_op_name2op_conf, &buffer_op_name2parallel_conf, &mut_op_name2conf); } } for (const std::string& ctrl_in_op_name : this_node->op().op_conf().ctrl_in_op_name()) { const OpNode* src_node = op_graph.OpNode4OpName(ctrl_in_op_name); if (!OpNodeHasScope(src_node)) { continue; /* ignore op without scope */ } if (IsForwardPass(src_node)) { LOG(WARNING) << "CtrlEdge: src_op[FwPass]: " << src_node->op().op_conf().DebugString() << " dst_op[BwPass]: " << this_node->op().op_conf().DebugString() << " connected."; } } }); op_graph.ForEachEdge([&](const OpEdge* edge) { const OpNode* src_node = edge->src_node(); const OpNode* dst_node = edge->dst_node(); if (OpNodeHasScope(src_node) && OpNodeHasScope(dst_node) && IsForwardPass(src_node) && IsForwardPass(dst_node)) { const int64_t src_stage_id = GetStageIdHint(src_node); const int64_t dst_stage_id = GetStageIdHint(dst_node); if (src_node->parallel_desc().device_type() == DeviceType::kCPU && dst_node->parallel_desc().device_type() == DeviceType::kCUDA) { if (src_stage_id == 0 && dst_stage_id == max_stage_id) { TryInsertOrUseBufferOpToDstNode(edge, total_stage_num * 2, &buffer_op_name2op_conf, &buffer_op_name2parallel_conf, &mut_op_name2conf); return; } } if (src_stage_id < dst_stage_id) { /* NOTE(chengcheng): We insert double buffer between src / dst node. * src_buffer_size = 1 because we need free memory as early as possible so we can overlap * CopyD2H with Compute. * dst_buffer_size = dst_stage_id - src_stage_id for pipeline. */ const int64_t dst_buffer_size = dst_stage_id - src_stage_id; TryInsertOrUseBufferOpBothSrcDst(edge, 1, dst_buffer_size, &buffer_op_name2op_conf, &buffer_op_name2parallel_conf, &mut_op_name2conf); } } if (OpNodeHasScope(src_node) && OpNodeHasScope(dst_node) && IsBackwardPass(src_node) && IsBackwardPass(dst_node)) { const int64_t src_stage_id = GetStageIdHint(src_node); const int64_t dst_stage_id = GetStageIdHint(dst_node); // NOTE(chengcheng): Backward ONLY need buffer size 1. 
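      // A minimal worked example of the buffer sizes chosen above, assuming a hypothetical
      // 4-stage pipeline (stage ids 0..3, total_stage_num = 4):
      //   * fw edge stage 0 -> stage 2: src_buffer_size = 1, dst_buffer_size = 2 - 0 = 2, so the
      //     producer's output can be freed early while the consumer runs two micro-batches ahead;
      //   * fw edge from a stage-0 CPU op into the last CUDA stage: a single buffer of size
      //     total_stage_num * 2 = 8 is inserted at the destination;
      //   * bw edge stage 3 -> stage 1: both buffer sizes are 1, decoupling the two stages by a
      //     single micro-batch, as noted above.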
if (src_stage_id > dst_stage_id) { TryInsertOrUseBufferOpBothSrcDst(edge, 1, 1, &buffer_op_name2op_conf, &buffer_op_name2parallel_conf, &mut_op_name2conf); } } }); for (auto& pair : buffer_op_name2op_conf) { CHECK(buffer_op_name2parallel_conf.find(pair.first) != buffer_op_name2parallel_conf.end()); JUST(job_builder->AddOp(buffer_op_name2parallel_conf.at(pair.first), pair.second)); } for (auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); } return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("PipelineBufferPass", PipelineBufferPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/prune_amp_white_identity_op_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { namespace { bool IsAmpIdentityOp(const OperatorConf& op) { return op.has_user_conf() && (op.user_conf().op_type_name() == "amp_white_identity" || op.user_conf().op_type_name() == "amp_black_identity"); } bool NeedDoPass(const Job& job) { return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsAmpIdentityOp); } class PruneAmpWhiteIdentityOpPass final : public JobPass { public: PruneAmpWhiteIdentityOpPass() = default; ~PruneAmpWhiteIdentityOpPass() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe PruneAmpWhiteIdentityOpPass::Apply(Job* job, JobPassCtx* ctx) const { if (!ctx->job_desc().prune_amp_white_identity_ops()) { return Maybe::Ok(); } if (!NeedDoPass(*job)) { return Maybe::Ok(); } const OpGraph op_graph(*job); HashSet ctrl_in_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) { ctrl_in_op_names.insert(ctrl_in_op_name); } }); HashSet del_nodes; op_graph.ForEachNode([&](const OpNode* op_node) { const std::string& op_name = op_node->op().op_name(); const OperatorConf& op_conf = op_node->op().op_conf(); // not amp identity op if (!IsAmpIdentityOp(op_conf)) { return; } // has ctrl in if (!op_conf.ctrl_in_op_name().empty()) { return; } // is ctrl in of another op if (ctrl_in_op_names.find(op_name) != ctrl_in_op_names.end()) { return; } // not sole in if (op_node->in_edges().size() != 1) { return; } del_nodes.insert(op_node); }); HashMap to_update_op_confs; std::vector del_op_names; del_op_names.reserve(del_nodes.size()); for (const OpNode* op_node : del_nodes) { del_op_names.emplace_back(op_node->op().op_name()); // find first node not deleted const OpNode* first = op_node; const OpNode* producer = op_node->SoleInEdge()->src_node(); while (del_nodes.find(producer) != del_nodes.end()) { first = producer; producer = producer->SoleInEdge()->src_node(); } const auto& old_lbi = op_node->op().BnInOp2Lbi(op_node->op().SoleObn()); const auto& new_lbi = first->op().BnInOp2Lbi(first->op().SoleIbn()); for (const OpEdge* out_edge : 
op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); if (del_nodes.find(consumer) == del_nodes.end()) { const Operator& op = consumer->op(); for (const std::string& ibn : op.input_bns()) { if (op.BnInOp2Lbi(ibn) == old_lbi) { auto iter = to_update_op_confs.find(op.op_name()); if (iter == to_update_op_confs.end()) { iter = to_update_op_confs.emplace(op.op_name(), op.op_conf()).first; } OperatorConf& op_conf = iter->second; const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&op_conf, ibn, GenLogicalBlobName(new_lbi)); CHECK_EQ_OR_RETURN(GenLogicalBlobName(old_lbi), old_val); } } } } } JobBuilder job_builder(job); for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); } job_builder.DelOps(del_op_names); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("PruneAmpWhiteIdentityOpPass", PruneAmpWhiteIdentityOpPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/prune_cast_to_static_shape_op_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { namespace { bool IsRelatedOp(const OperatorConf& op) { return op.has_user_conf() && (op.user_conf().op_type_name() == "cast_to_static_shape"); } bool NeedDoPass(const Job& job) { return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsRelatedOp); } class PruneCastToStaticShapeOpsPass final : public JobPass { public: PruneCastToStaticShapeOpsPass() = default; ~PruneCastToStaticShapeOpsPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain() && ctx.job_desc().prune_cast_to_static_shape_ops(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } if (!NeedDoPass(*job)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe PruneCastToStaticShapeOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { HashMap op_name2op_conf; HashSet ctrl_in_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) { ctrl_in_op_names.insert(ctrl_in_op_name); } }); std::vector del_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); if (!op_conf.has_user_conf()) { return; } const std::string& op_type_name = op_conf.user_conf().op_type_name(); if (op_type_name != "cast_to_static_shape") { return; } if (!op_conf.ctrl_in_op_name().empty()) { return; } if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) { return; } if (op_node->in_edges().size() != 1) { return; } const user_op::UserOpConfWrapper user_op_conf(op_conf); const LogicalBlobId& 
cast_in_lbi = GenLogicalBlobId(user_op_conf.input("input", 0)); const LogicalBlobId& cast_out_lbi = GenLogicalBlobId(user_op_conf.output("output", 0)); const OpNode* producer = op_graph.OpNode4OpName(cast_in_lbi.op_name()); const BlobDesc& cast_in_logical_blob_desc = producer->LogicalBlobDesc4Lbi(cast_in_lbi); if (cast_in_logical_blob_desc.is_dynamic()) { return; } for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); const std::string& consumer_op_name = consumer->op().op_name(); if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) { op_name2op_conf[consumer_op_name] = consumer->op().op_conf(); } OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name); for (const std::string& ibn : consumer->op().input_bns()) { if (consumer->op().BnInOp2Lbi(ibn) == cast_out_lbi) { const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, GenLogicalBlobName(cast_in_lbi)); CHECK_EQ(GenLogicalBlobName(cast_out_lbi), old_val); } } } del_op_names.emplace_back(op_conf.name()); }); for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); } job_builder->DelOps(del_op_names); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("PruneCastToStaticShapeOpsPass", PruneCastToStaticShapeOpsPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/prune_depend_op_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include "oneflow/core/common/hash_container.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/graph/node.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/register/logical_blob_id.pb.h" namespace oneflow { namespace { struct UpdatedNodeInfo { const OpNode* node = nullptr; const OpNode* new_src_node = nullptr; const OpNode* depend_node_nearest_src = nullptr; const OpNode* depend_node_nearest_dst = nullptr; std::vector new_in_ctrl_nodes; bool updated = false; }; bool IsDependyOp(const OperatorConf& op) { return op.has_user_conf() && (op.user_conf().op_type_name() == "depend"); } bool NeedDoPass(const Job& job) { return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsDependyOp); } const OpNode* GetNodeFromEdgeByTensorName(const OpNode* op_node, const std::string& target_tensor_name) { CHECK(IsDependyOp(op_node->op().op_conf())); for (const OpEdge* in_edge : op_node->in_edges()) { const OpNode* in_op_node = in_edge->src_node(); const std::string& in_op_node_name = in_op_node->op().op_name(); const HashMap>& lbi2ibns = in_edge->lbi2ibns(); for (const auto& item : lbi2ibns) { const std::string& lbi_op_name = item.first.op_name(); for (const std::string& tensor_name : item.second) { if (in_op_node_name == lbi_op_name && tensor_name == target_tensor_name) { return in_op_node; } } } } return nullptr; } const OpNode* GetNodeFromInputEdge(const OpNode* op_node) { return GetNodeFromEdgeByTensorName(op_node, "in_0"); } const OpNode* GetNodeFromInCtrlEdge(const OpNode* op_node) { return GetNodeFromEdgeByTensorName(op_node, "depend_tensor_0"); } LogicalBlobId GetNewLbi(const OpNode* src_node, const OpNode* depend_node_nearest_src) { CHECK(IsDependyOp(depend_node_nearest_src->op().op_conf())); for (const OpEdge* out_edge : src_node->out_edges()) { const OpNode* dst_node = out_edge->dst_node(); if (dst_node != depend_node_nearest_src) { continue; } CHECK(out_edge->lbis().size() == 1); return out_edge->lbis()[0]; } // should not reach here CHECK(false); return {}; } class PruneDependOpPass final : public JobPass { public: PruneDependOpPass() = default; ~PruneDependOpPass() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe PruneDependOpPass::Apply(Job* job, JobPassCtx* ctx) const { if (!ctx->job_desc().prune_depend_ops()) { return Maybe::Ok(); } if (!NeedDoPass(*job)) { return Maybe::Ok(); } const OpGraph op_graph(*job); HashMap node_info_with_update; std::vector ordered_nodes; // Step 0: topological sort, setup a map for recording modification op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) { UpdatedNodeInfo node_info; node_info.node = node; node_info_with_update.emplace(node->op().op_name(), node_info); ordered_nodes.emplace_back(node); }); // Step 1: process node by topological order // record modification info when meet Depend OP nodes for (const OpNode* cur_node : ordered_nodes) { const std::string& cur_op_name = cur_node->op().op_name(); const OperatorConf& cur_op_conf = cur_node->op().op_conf(); if (!IsDependyOp(cur_op_conf)) { continue; } // record modification info to each dst_node for (const OpEdge* out_edge : cur_node->out_edges()) { const OpNode* dst_node = out_edge->dst_node(); const Operator& dst_op = dst_node->op(); UpdatedNodeInfo& updated_dst_node_info = node_info_with_update.find(dst_op.op_name())->second; UpdatedNodeInfo& updated_cur_node_info = node_info_with_update.find(cur_op_name)->second; 
updated_dst_node_info.updated = true; updated_dst_node_info.depend_node_nearest_dst = cur_node; // Step 1.1: record a new in-ctrl node const OpNode* cur_in_ctrl_node = GetNodeFromInCtrlEdge(cur_node); updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_in_ctrl_node); // Step 1.2: inherit in-ctrl nodes from Depend OP nodes const auto& ori_in_ctrl_op_names = cur_op_conf.ctrl_in_op_name(); for (const std::string& ori_ctrl_in_op_name : ori_in_ctrl_op_names) { updated_dst_node_info.new_in_ctrl_nodes.emplace_back( node_info_with_update[ori_ctrl_in_op_name].node); } if (updated_cur_node_info.updated) { std::vector& inherit_in_ctrl_nodes = updated_cur_node_info.new_in_ctrl_nodes; for (const OpNode* inherit_in_ctrl_node : inherit_in_ctrl_nodes) { updated_dst_node_info.new_in_ctrl_nodes.emplace_back(inherit_in_ctrl_node); } } // Step 1.3 process src nodes const OpNode* cur_src_node = GetNodeFromInputEdge(cur_node); if (IsDependyOp(dst_node->op().op_conf()) && cur_node == GetNodeFromInCtrlEdge(dst_node)) { // "cur_node" and "dst_node" are all Depend OP nodes, and their connection is like this // other_node cur_node // \ / // dst_node // in this case, all src nodes of "cur_node" should be seen as in-ctrl nodes if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) { updated_dst_node_info.new_in_ctrl_nodes.emplace_back(updated_cur_node_info.new_src_node); } updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_src_node); } else { if (!IsDependyOp(cur_src_node->op().op_conf())) { updated_dst_node_info.new_src_node = cur_src_node; updated_dst_node_info.depend_node_nearest_src = cur_node; } else if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) { updated_dst_node_info.new_src_node = updated_cur_node_info.new_src_node; updated_dst_node_info.depend_node_nearest_src = updated_cur_node_info.depend_node_nearest_src; } } } } // Step 2: extract modification info // including new connection and to delete nodes std::vector del_node_names; HashMap to_update_op_confs; for (const auto& node_info : node_info_with_update) { // filter nodes not updated if (!node_info.second.updated) { continue; } const OpNode* cur_node = node_info.second.node; const std::string& cur_op_name = cur_node->op().op_name(); // filter Depnd nodes if (IsDependyOp(cur_node->op().op_conf())) { del_node_names.emplace_back(cur_op_name); continue; } const Operator& cur_op = cur_node->op(); auto iter = to_update_op_confs.find(node_info.first); if (iter == to_update_op_confs.end()) { iter = to_update_op_confs.emplace(node_info.first, cur_op.op_conf()).first; } OperatorConf& cur_op_conf = iter->second; // Step 2.1: connect updated src_node with cur_node (dst_node of Depned OP) const OpNode* src_node = node_info.second.new_src_node; const OpNode* depend_node_nearest_dst = node_info.second.depend_node_nearest_dst; const OpNode* depend_node_nearest_src = node_info.second.depend_node_nearest_src; CHECK(src_node && depend_node_nearest_dst && depend_node_nearest_src); const auto& old_lbi = depend_node_nearest_dst->op().BnInOp2Lbi(depend_node_nearest_dst->op().SoleObn()); const auto new_lbi = GetNewLbi(src_node, depend_node_nearest_src); for (const std::string& ibn : cur_node->op().input_bns()) { if (cur_op.BnInOp2Lbi(ibn) == old_lbi) { const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&cur_op_conf, ibn, GenLogicalBlobName(new_lbi)); CHECK_EQ(GenLogicalBlobName(old_lbi), old_val); VLOG(3) << "Update input edge, Src Node: " << src_node->op().op_name() << "\t->\tDst Node: " << cur_op_name; } } // Step 2.2: 
add in-ctrl OPs const auto& existed_ctrl_in_op_names = cur_op_conf.ctrl_in_op_name(); for (const OpNode* in_ctrl_node : node_info.second.new_in_ctrl_nodes) { // filter Depnd nodes if (IsDependyOp(in_ctrl_node->op().op_conf())) { continue; } CHECK(cur_node != in_ctrl_node); // self-loop found const std::string& new_ctrl_in_op_name = in_ctrl_node->op().op_name(); auto existed_it = std::find(existed_ctrl_in_op_names.begin(), existed_ctrl_in_op_names.end(), new_ctrl_in_op_name); // filter src node or duplicate in-ctrl nodes if (in_ctrl_node != src_node && existed_it == existed_ctrl_in_op_names.end()) { cur_op_conf.add_ctrl_in_op_name(new_ctrl_in_op_name); VLOG(3) << "Add in-ctrl edge, Src Node: " << new_ctrl_in_op_name << "\t->\tDst Node: " << cur_op_name; } } } // Step 3: apply modification to job JobBuilder job_builder(job); for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); } job_builder.DelOps(del_node_names); return Maybe::Ok(); }; } // namespace REGISTER_JOB_PASS("PruneDependOpPass", PruneDependOpPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/prune_parallel_cast_op_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { namespace { bool IsParallelCastOp(const OperatorConf& op_conf) { return op_conf.has_user_conf() && (op_conf.user_conf().op_type_name() == "parallel_cast" || op_conf.user_conf().op_type_name() == "hierarchical_parallel_cast" || op_conf.user_conf().op_type_name() == "hierarchical_parallel_cast_like"); } bool NeedDoPass(const Job& job) { return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsParallelCastOp); } class PruneParallelCastOpsPass final : public JobPass { public: PruneParallelCastOpsPass() = default; ~PruneParallelCastOpsPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().prune_parallel_cast_ops(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } if (!NeedDoPass(*job)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe PruneParallelCastOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { HashMap op_name2op_conf; HashMap op_name2nd_sbp_signature; HashSet ctrl_in_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) { ctrl_in_op_names.insert(ctrl_in_op_name); } }); std::vector del_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); if (!op_conf.ctrl_in_op_name().empty()) { return; } if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) { return; } if (!IsParallelCastOp(op_conf)) { return; } if (op_node->in_edges().size() != 1) { return; } user_op::UserOpConfWrapper conf_wrapper(op_conf); const LogicalBlobId& parallel_cast_in_lbi = GenLogicalBlobId(conf_wrapper.input("in", 0)); const LogicalBlobId& parallel_cast_out_lbi = GenLogicalBlobId(conf_wrapper.output("out", 0)); const OpNode* producer = op_graph.OpNode4OpName(parallel_cast_in_lbi.op_name()); const NdSbp& parallel_cast_nd_sbp = op_node->NdSbp4Lbi(parallel_cast_in_lbi); const NdSbp& producer_nd_sbp = producer->NdSbp4Lbi(parallel_cast_in_lbi); if (op_node->parallel_desc() != producer->parallel_desc()) { return; } if (parallel_cast_nd_sbp != producer_nd_sbp && op_node->out_edges().size() > 1) { return; } for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); if (IsParallelCastOp(consumer->op().op_conf())) { return; } if (consumer->parallel_desc() != op_node->parallel_desc()) { return; } if (consumer->NdSbp4Lbi(parallel_cast_out_lbi) != parallel_cast_nd_sbp) { return; } } op_name2nd_sbp_signature[producer->op().op_name()] = producer->nd_sbp_signature(); for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); const std::string& consumer_op_name = consumer->op().op_name(); op_name2nd_sbp_signature[consumer_op_name] = consumer->nd_sbp_signature(); if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) { op_name2op_conf[consumer_op_name] = consumer->op().op_conf(); } OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name); for (const std::string& ibn : consumer->op().input_bns()) { if (consumer->op().BnInOp2Lbi(ibn) == parallel_cast_out_lbi) { const auto& new_val = GenLogicalBlobName(parallel_cast_in_lbi); const auto& old_val = 
ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val); CHECK_EQ(GenLogicalBlobName(parallel_cast_out_lbi), old_val); } } } del_op_names.emplace_back(op_conf.name()); }); for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); } for (const auto& pair : op_name2nd_sbp_signature) { job_builder->AddNdSbpSignature4OpName(pair.first, pair.second); } job_builder->DelOps(del_op_names); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("PruneParallelCastOpsPass", PruneParallelCastOpsPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/prune_pinned_identity_op_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { namespace { class PrunePinnedIdentityOpPass final : public JobPass { public: PrunePinnedIdentityOpPass() = default; ~PrunePinnedIdentityOpPass() override = default; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; Maybe PrunePinnedIdentityOp(JobBuilder* job_builder, std::vector* outdated_ops, const OpGraph& op_graph, const std::string& lbn) { auto lbi = GenLogicalBlobId(lbn); const OpNode* op_node = op_graph.OpNode4OpName(lbi.op_name()); CHECK_EQ_OR_RETURN(op_node->in_edges().size(), 1); // NOLINT const OperatorConf& op_conf = op_node->op().op_conf(); CHECK_OR_RETURN(op_conf.has_user_conf()); // NOLINT const std::string& op_type_name = op_conf.user_conf().op_type_name(); CHECK_OR_RETURN(op_type_name == "pinned_identity"); // NOLINT // skip prune if the pinned identity has `ctrl_in_op` if (!op_conf.ctrl_in_op_name().empty()) { return lbn; } const user_op::UserOpConfWrapper user_op_conf(op_conf); const LogicalBlobId& in_lbi = GenLogicalBlobId(user_op_conf.input("in", 0)); const LogicalBlobId& out_lbi = GenLogicalBlobId(user_op_conf.output("out", 0)); op_node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == out_lbi) { if (!CHECK_JUST(job_builder->IsInMutOpTransaction(out_node->op().op_name()))) { CHECK_JUST(job_builder->MutOpTransactionMut(out_node->op().op_conf())); } OperatorConf& mut_consumer_op = CHECK_JUST(job_builder->MutOpTransactionGet(out_node->op().op_name())); const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, GenLogicalBlobName(in_lbi)); CHECK_EQ(old_lbn, GenLogicalBlobName(out_lbi)); } } }); outdated_ops->push_back(op_conf.name()); return GenLogicalBlobName(in_lbi); } Maybe PrunePinnedIdentityOpPass::Apply(Job* job, JobPassCtx* ctx) const { if (!job->job_conf().has_train_conf()) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); HashMap pruned_lbns; std::vector outdated_ops; TrainConf* train_conf = job->mutable_job_conf()->mutable_train_conf(); // prune loss pinned identity for (int i = 0; i < train_conf->loss_lbn_size(); 
++i) { const auto& pinned_loss_lbn = train_conf->loss_lbn(i); auto it = pruned_lbns.find(pinned_loss_lbn); if (it == pruned_lbns.end()) { const auto& loss_lbn = JUST(PrunePinnedIdentityOp(&job_builder, &outdated_ops, op_graph, pinned_loss_lbn)); it = pruned_lbns.emplace(pinned_loss_lbn, *loss_lbn).first; } train_conf->set_loss_lbn(i, it->second); } // prune loss initial gradient pinned identity for (int i = 0; i < train_conf->loss_grad_lbn_size(); ++i) { const auto& pinned_loss_grad_lbn = train_conf->loss_grad_lbn(i); auto it = pruned_lbns.find(pinned_loss_grad_lbn); if (it == pruned_lbns.end()) { const auto& loss_grad_lbn = JUST(PrunePinnedIdentityOp(&job_builder, &outdated_ops, op_graph, pinned_loss_grad_lbn)); it = pruned_lbns.emplace(pinned_loss_grad_lbn, *loss_grad_lbn).first; } train_conf->set_loss_grad_lbn(i, it->second); } // prune variable gradient pinned identity for (int i = 0; i < train_conf->optimizer_conf_size(); ++i) { auto* optimizer_conf = train_conf->mutable_optimizer_conf(i); for (int j = 0; j < optimizer_conf->variable_grad_lbns_size(); ++j) { const auto& pinned_variable_grad_lbn = optimizer_conf->variable_grad_lbns(j); if (pinned_variable_grad_lbn.empty()) { continue; } auto it = pruned_lbns.find(pinned_variable_grad_lbn); if (it == pruned_lbns.end()) { const auto& variable_grad_lbn = JUST( PrunePinnedIdentityOp(&job_builder, &outdated_ops, op_graph, pinned_variable_grad_lbn)); it = pruned_lbns.emplace(pinned_variable_grad_lbn, *variable_grad_lbn).first; } optimizer_conf->set_variable_grad_lbns(j, it->second); } } job_builder.DelOps(outdated_ops); JUST(job_builder.MutOpTransactionCommit()); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("PrunePinnedIdentityOpPass", PrunePinnedIdentityOpPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/quantization_aware_training.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/job/job_conf.pb.h" #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/pass_util.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/vm/symbol_storage.h" namespace oneflow { namespace { using OpTypeSet = HashSet; const std::string FAKE_QUANT_SUFFIX = "-fake-quant"; const std::string ZP_SUFFIX = "-fake-quant-zp"; const std::string MOVING_MAX_SUFFIX = "-fake-quant-moving-max"; const std::string MOVING_MIN_SUFFIX = "-fake-quant-moving-min"; const std::string MUL_BIAS_SUFFIX = "-fake-quant-mul-bias"; const std::string OBSERVER_SUFFIX = "-fake-quant-observer"; const std::string TRAIN_STEP_SUFFIX = "-fake-train-step"; Maybe VerifyQATList(const OpTypeSet& op_list) { for (const auto& op_type : op_list) { CHECK_OR_RETURN(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr) << "Cannot find " << op_type << " of QuantAwareTraining list in OpRegistry."; } return Maybe::Ok(); } HashMap scale_map; Maybe GetScaleLbn(const std::string& lbn) { CHECK_OR_RETURN(scale_map.find(lbn) != scale_map.end()); return scale_map[lbn]; } Maybe IsConvBiasEdge(const QatConfig& qat_config, const OpEdge* edge, std::string* conv_input_scale_lbn, std::string* conv_weight_scale_lbn, int64_t* weight_scale_length) { const auto* dst_node = edge->dst_node(); const auto dst_op_type = dst_node->op().op_conf().user_conf().op_type_name(); auto GetInputAndWeightScaleLbnAndWeightScaleLen4ConvNode = [](const QatConfig& qat_config, const OpNode* conv_node, std::string* conv_input_scale_lbn, std::string* conv_weight_scale_lbn, int64_t* weight_scale_length) -> Maybe { *weight_scale_length = 1; for (const OpEdge* in_edge : conv_node->in_edges()) { CHECK_EQ_OR_RETURN(in_edge->lbis().size(), 1); const auto lbi = in_edge->lbis().front(); const auto ibn = in_edge->lbi2ibns().at(lbi); CHECK_EQ_OR_RETURN(ibn.size(), 1); CHECK_OR_RETURN(ibn[0] == "in_0" || ibn[0] == "weight_0"); if (ibn[0] == "in_0") { *conv_input_scale_lbn = *JUST(GetScaleLbn(GenLogicalBlobName(in_edge->lbis()[0]))); } else if (ibn[0] == "weight_0") { if (qat_config.per_channel_weight_quantization()) { *weight_scale_length = conv_node->LogicalBlobDesc4Lbi(lbi).shape().At(0); } *conv_weight_scale_lbn = *JUST(GetScaleLbn(GenLogicalBlobName(in_edge->lbis()[0]))); } } return Maybe::Ok(); }; if (dst_op_type == "conv2d") { CHECK_EQ_OR_RETURN(edge->lbis().size(), 1); const auto lbi = edge->lbis().front(); const auto ibn = edge->lbi2ibns().at(lbi); CHECK_EQ_OR_RETURN(ibn.size(), 1); if (ibn[0] == "bias_0") { JUST(GetInputAndWeightScaleLbnAndWeightScaleLen4ConvNode( qat_config, dst_node, conv_input_scale_lbn, conv_weight_scale_lbn, weight_scale_length)); return true; } } else if (dst_op_type == "bias_add") { // check whether the bias_add corresponds to a conv for (const OpEdge* edge : dst_node->in_edges()) { const auto* src_node = edge->src_node(); if (src_node->op().op_conf().user_conf().op_type_name() == "conv2d") { JUST(GetInputAndWeightScaleLbnAndWeightScaleLen4ConvNode( qat_config, src_node, conv_input_scale_lbn, conv_weight_scale_lbn, weight_scale_length)); return true; } } } return false; } bool IsWeightEdge(const OpEdge* edge) { return edge->src_node()->op().op_conf().has_variable_conf(); } bool IsBnInputEdge(const OpEdge* edge) { // Skip the inputs of bn for now. // In the complete qat pass, bn will be merged into conv. 
  return edge->dst_node()->op().op_conf().user_conf().op_type_name() == "normalization";
}

std::string OpTypeName4OpNode(const OpNode* node) {
  return node->op().op_conf().user_conf().op_type_name();
}

using OpConfMap = HashMap<std::string, OperatorConf>;

template<DataType data_type>
OperatorConf Get1DZeroVariableOpConf(std::string name, const int64_t scope_symbol_id,
                                     const int64_t length, OpConfMap* inserted_ops) {
  OperatorConf variable_op_conf{};
  variable_op_conf.set_name(name);
  variable_op_conf.set_scope_symbol_id(scope_symbol_id);
  VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();
  variable_conf->set_out("out");
  *variable_conf->mutable_shape()->mutable_dim()->Add() = length;
  variable_conf->set_data_type(data_type);
  variable_conf->mutable_initializer()->mutable_constant_conf()->set_value(0);
  (*inserted_ops)[name] = variable_op_conf;
  return variable_op_conf;
}

Maybe<OpNode*> GetInferenceOutputNode(const OpGraph& op_graph, OpNode* node) {
  OpNode* cur_node = node;
  if (node->op().op_conf().user_conf().op_type_name() == "conv2d"
      && node->out_edges().size() == 1) {
    OpNode* next_node = node->SoleOutEdge()->dst_node();
    if (OpTypeName4OpNode(next_node) == "bias_add") {
      cur_node = next_node;
      if (next_node->out_edges().size() == 1) { next_node = next_node->SoleOutEdge()->dst_node(); }
    }
    if (OpTypeName4OpNode(next_node) == "normalization") {
      cur_node = next_node;
      if (next_node->out_edges().size() == 1) { next_node = next_node->SoleOutEdge()->dst_node(); }
    }
    if (OpTypeName4OpNode(next_node) == "relu") { cur_node = next_node; }
  }
  VLOG(3) << "For node: " << node->op().op_name();
  VLOG(3) << "output node is: " << cur_node->op().op_name();
  return cur_node;
}

bool PerLayerQuantizationAttr4Config(const QatConfig& qat_config) {
  return !qat_config.per_channel_weight_quantization();
}

std::string QuantizationSchemeAttr4QatConfig(const QatConfig& qat_config) {
  return qat_config.symmetric() ?
"symmetric" : "affine"; } // TODO: refactor the following 4 methods by registration Maybe QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) { const auto target_backend = qat_config.target_backend(); if (target_backend == "" || target_backend == "tensorrt") { return std::string("google"); } else if (target_backend == "cambricon") { return std::string("cambricon"); } else { UNIMPLEMENTED_THEN_RETURN(); } } Maybe Int8List4QatConfig(const QatConfig& qat_config) { const auto target_backend = qat_config.target_backend(); if (target_backend == "") { return OpTypeSet{"add_n", "matmul", "batch_matmul", "conv2d", "tf_avg_pool_2d", "tf_max_pool_2d"}; } else if (target_backend == "cambricon" || target_backend == "tensorrt") { return OpTypeSet{"conv2d", "matmul"}; } else { UNIMPLEMENTED_THEN_RETURN(); } } Maybe TransparentList4QatConfig(const QatConfig& qat_config) { const auto target_backend = qat_config.target_backend(); if (target_backend == "" || target_backend == "tensorrt") { return OpTypeSet{"reshape"}; } else if (target_backend == "cambricon") { return OpTypeSet{}; } else { UNIMPLEMENTED_THEN_RETURN(); } } Maybe InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) { const auto target_backend = qat_config.target_backend(); if (target_backend == "" || target_backend == "tensorrt") { return true; } else if (target_backend == "cambricon") { return false; } else { UNIMPLEMENTED_THEN_RETURN(); } } user_op::UserOpConfWrapper MultiplyOp(const std::string& name, const std::string& x, const std::string& y, const int64_t scope_symbol_id, OpConfMap* inserted_ops) { auto op_wrapper = user_op::UserOpConfWrapperBuilder(name) .Op("broadcast_mul") .Input("x", x) .Input("y", y) .Output("z") .ScopeSymbolId(scope_symbol_id) .Build(); (*inserted_ops)[name] = op_wrapper.op_conf(); return op_wrapper; } Maybe MinMaxObserver(const std::string& name, const std::string& input, const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) { const auto op_wrapper = user_op::UserOpConfWrapperBuilder(name) .Op("min_max_observer") .Input("in", input) .Output("scale") .Output("zero_point") .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Attr("per_layer_quantization", PerLayerQuantizationAttr4Config(qat_config)) .ScopeSymbolId(scope_symbol_id) .Build(); (*inserted_ops)[name] = op_wrapper.op_conf(); return op_wrapper; } Maybe MovingMinMaxObserver( const std::string& name, const std::string& input, const std::string& train_step_lbn, const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) { const std::string moving_max_name = name + MOVING_MAX_SUFFIX; const std::string moving_min_name = name + MOVING_MIN_SUFFIX; const auto moving_max_var = Get1DZeroVariableOpConf(moving_max_name, scope_symbol_id, 1, inserted_ops); const auto moving_min_var = Get1DZeroVariableOpConf(moving_min_name, scope_symbol_id, 1, inserted_ops); std::string observer_current_train_step = train_step_lbn; if (!GlobalJobDesc().IsTrain()) { const std::string train_step_name = name + TRAIN_STEP_SUFFIX; const auto train_step_var = Get1DZeroVariableOpConf( train_step_name, scope_symbol_id, 1, inserted_ops); observer_current_train_step = GenLogicalBlobName(train_step_var.name(), train_step_var.variable_conf().out()); } const auto op_wrapper = user_op::UserOpConfWrapperBuilder(name) .Op("moving_average_min_max_observer") .Input("in", input) .Input("current_train_step", 
observer_current_train_step) .Input("moving_max", GenLogicalBlobName(moving_max_var.name(), moving_max_var.variable_conf().out())) .Input("moving_min", GenLogicalBlobName(moving_min_var.name(), moving_min_var.variable_conf().out())) .Output("scale") .Output("zero_point") .Attr("training", GlobalJobDesc().IsTrain()) .Attr("stop_update_after_iters", qat_config.moving_min_max_stop_update_after_iters()) .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Attr("momentum", qat_config.moving_min_max_momentum()) .ScopeSymbolId(scope_symbol_id) .Build(); (*inserted_ops)[name] = op_wrapper.op_conf(); return op_wrapper; } Maybe FakeQuantOp(const std::string& name, const std::string& input, const std::string& scale, const std::string& zero_point, const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) { const auto op_wrapper = user_op::UserOpConfWrapperBuilder(name) .Op("fake_quantization") .Input("in", input) .Input("scale", scale) .Input("zero_point", zero_point) .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Output("out") .ScopeSymbolId(scope_symbol_id) .Build(); (*inserted_ops)[name] = op_wrapper.op_conf(); return op_wrapper; } Maybe GetScaleAndZeroPointLbn4Edge(OpEdge* edge, const std::string train_step_lbn, std::string* scale, std::string* zero_point, const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) { std::string lbn = GenLogicalBlobName(edge->lbis().front()); std::string conv_input_scale_lbn; std::string conv_weight_scale_lbn; int64_t weight_scale_length; if (JUST(IsConvBiasEdge(qat_config, edge, &conv_input_scale_lbn, &conv_weight_scale_lbn, &weight_scale_length))) { // mul scale const std::string mul_scale_op_name = ReplaceSlashToDash4Lbn(lbn) + MUL_BIAS_SUFFIX; CHECK_OR_RETURN(inserted_ops->find(mul_scale_op_name) == inserted_ops->end()); const auto mul_scale_op = MultiplyOp(mul_scale_op_name, conv_input_scale_lbn, conv_weight_scale_lbn, scope_symbol_id, inserted_ops); *scale = mul_scale_op.output("z", 0); const std::string zp_var_name = ReplaceSlashToDash4Lbn(lbn) + ZP_SUFFIX; const auto zp_var = Get1DZeroVariableOpConf(zp_var_name, scope_symbol_id, weight_scale_length, inserted_ops); *zero_point = GenLogicalBlobName(zp_var.name(), zp_var.variable_conf().out()); } else { const std::string observer_op_name = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX; if (IsWeightEdge(edge)) { const auto observer_op = JUST(MinMaxObserver(observer_op_name, lbn, qat_config, scope_symbol_id, inserted_ops)); *scale = observer_op->output("scale", 0); *zero_point = observer_op->output("zero_point", 0); } else { CHECK_OR_RETURN(qat_config.has_moving_min_max_stop_update_after_iters()); const auto observer_op = JUST(MovingMinMaxObserver( observer_op_name, lbn, train_step_lbn, qat_config, scope_symbol_id, inserted_ops)); *scale = observer_op->output("scale", 0); *zero_point = observer_op->output("zero_point", 0); } } return Maybe::Ok(); } Maybe ReplaceInputLbn4DstNodeOfEdge(OpEdge* edge, const std::string& new_lbn, OpConfCache* op_conf_cache) { OpNode* dst_node = edge->dst_node(); LogicalBlobId cur_lbi = edge->lbis().front(); CHECK_EQ_OR_RETURN(1, edge->lbi2ibns().at(cur_lbi).size()); const std::string& dst_ibn = edge->lbi2ibns().at(cur_lbi).front(); OperatorConf dst_op_conf = op_conf_cache->GetLatest(dst_node->op().op_conf()); 
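  // Point the consumer's matched input at the new logical blob name and cache the mutated conf,
  // so that later rewrites of the same op start from this updated version.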
ReplaceInputLbnInOpCustomizedConf(&dst_op_conf, dst_ibn, new_lbn); op_conf_cache->Put(dst_op_conf); return Maybe::Ok(); } class QuantAwareTraining final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(QuantAwareTraining); QuantAwareTraining() = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().job_conf().enable_quantization_aware_training(); } Maybe Apply(Job* job, JobPassCtx* ctx) const override; private: Maybe InsertFakeQuantOp(const QatConfig& qat_config, const OpGraph& op_graph, const OpTypeSet& int8_list, const OpTypeSet& transparent_list, bool insert_quant_op_after_int8_ops, HashSet downstream_white, Job* job) const; }; Maybe IsNodeQuantizationEnabled(const OpNode& node) { int64_t scope_symbol_id = node.op().op_conf().scope_symbol_id(); CHECK_OR_RETURN(Singleton>::Get()->Has(scope_symbol_id)); // NOLINT const Scope& scope = Singleton>::Get()->Get(scope_symbol_id); return scope.Bool("quantization_aware_training"); } Maybe QuantAwareTraining::Apply(Job* job, JobPassCtx* ctx) const { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); CHECK_OR_RETURN(GlobalJobDesc().DefaultDataType() == DataType::kFloat); const auto qat_config = ctx->job_desc().job_conf().qat_config(); OpTypeSet int8_list = *JUST(Int8List4QatConfig(qat_config)); OpTypeSet transparent_list = *JUST(TransparentList4QatConfig(qat_config)); // if `insert_quant_op_after_int8_ops` is false, // always insert quant op before int8 ops. // if `insert_quant_op_after_int8_ops` is true, // always insert quant op after int8 ops bool insert_quant_op_after_int8_ops = JUST(InsertQuantOpAfterInt8Ops4QatConfig(qat_config)); JUST(VerifyQATList(int8_list)); JUST(VerifyQATList(transparent_list)); std::function OpName4Node = [](OpNode* const& node) { return node->op().op_name(); }; HashSet white_set; DfsTopoGraphTraversal( op_graph, false, [&int8_list](OpNode* node) { return IsNodeInList(int8_list, node); }, [&](OpNode* node) { return IsNodeInList(transparent_list, node); }, [&](OpNode* node) { return IsKeyFound(white_set, node); }, [&](OpNode* node) { INSERT_CHECK(white_set.insert(node)); if (node->op().op_conf().user_conf().op_type_name() == "conv2d" && node->out_edges().size() == 1) { OpNode* next_node = node->SoleOutEdge()->dst_node(); if (OpTypeName4OpNode(next_node) == "bias_add") { INSERT_CHECK(white_set.insert(next_node)); // TODO(daquexian): mark these special nodes if (next_node->out_edges().size() == 1) { next_node = next_node->SoleOutEdge()->dst_node(); } } if (OpTypeName4OpNode(next_node) == "normalization") { INSERT_CHECK(white_set.insert(next_node)); if (next_node->out_edges().size() == 1) { next_node = next_node->SoleOutEdge()->dst_node(); } } if (OpTypeName4OpNode(next_node) == "relu") { INSERT_CHECK(white_set.insert(next_node)); } } }); VLOG(3) << "white_set include: " << Container2Str, OpNode*>(white_set, OpName4Node); JUST(InsertFakeQuantOp(ctx->job_desc().job_conf().qat_config(), op_graph, int8_list, transparent_list, insert_quant_op_after_int8_ops, white_set, job)); return Maybe::Ok(); } // TODO: remove int8_list, transparent_list and insert_quant_op_after_int8_ops arguments Maybe QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config, const OpGraph& op_graph, const OpTypeSet& int8_list, const OpTypeSet& transparent_list, const bool insert_quant_op_after_int8_ops, HashSet white_set, Job* job) const { JobBuilder job_builder(job); HashSet white_set_edges; auto EdgeName4Edge = [](OpEdge* const& edge) { return std::string("edge of\t") + 
edge->src_node()->op().op_name() + "\tto\t" + edge->dst_node()->op().op_name(); }; auto AddWhiteSetEdge = [&white_set_edges, &EdgeName4Edge](OpEdge* edge) -> Maybe { VLOG(3) << "insert " << EdgeName4Edge(edge); CHECK_EQ_OR_RETURN(edge->lbis().size(), 1); const std::string lbn = GenLogicalBlobName(edge->lbis().front()); scale_map[lbn] = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX + "/scale_0"; VLOG(3) << "set " << lbn << " to " << scale_map[lbn]; INSERT_CHECK_OR_RETURN(white_set_edges.insert(edge)); return Maybe::Ok(); }; auto PropagateScale = [](OpNode* node) -> Maybe { CHECK_EQ_OR_RETURN(node->in_edges().size(), 1); CHECK_EQ_OR_RETURN(node->SoleInEdge()->lbis().size(), 1); for (OpEdge* edge : node->out_edges()) { CHECK_EQ_OR_RETURN(edge->lbis().size(), 1); const std::string node_input_lbn = GenLogicalBlobName(node->SoleInEdge()->lbis().front()); const std::string lbn = GenLogicalBlobName(edge->lbis().front()); if (scale_map.find(node_input_lbn) != scale_map.end()) { scale_map[lbn] = scale_map[node_input_lbn]; } } return Maybe::Ok(); }; { JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe { if (IsKeyFound(white_set, node)) { for (OpEdge* edge : node->in_edges()) { if (IsKeyFound(white_set, edge->src_node())) { continue; } if (JUST(IsNodeQuantizationEnabled(*edge->dst_node()))) { JUST(AddWhiteSetEdge(edge)); } } if (IsNodeInList(int8_list, node)) { if (insert_quant_op_after_int8_ops) { OpNode* inference_node = JUST(GetInferenceOutputNode(op_graph, node)); if (JUST(IsNodeQuantizationEnabled(*inference_node))) { for (OpEdge* edge : inference_node->out_edges()) { JUST(AddWhiteSetEdge(edge)); } } } else { if (JUST(IsNodeQuantizationEnabled(*node))) { for (OpEdge* edge : node->in_edges()) { if (white_set_edges.find(edge) == white_set_edges.end()) { JUST(AddWhiteSetEdge(edge)); } } } } } else if (IsNodeInList(transparent_list, node)) { JUST(PropagateScale(node)); } else { // this is bias_add, relu or bn op in "conv -> bias_add -> bn -> relu" pattern, // do nothing } } return Maybe::Ok(); })); VLOG(3) << "white_set_edges: " << Container2Str, OpEdge*>(white_set_edges, EdgeName4Edge); } // group edges by lbn so that we can use `src_node` when calling `AddOps` HashMap> edges_group_by_lbn; { for (OpEdge* edge : white_set_edges) { CHECK_EQ_OR_RETURN(1, edge->lbis().size()); std::string lbn = GenLogicalBlobName(edge->lbis().front()); edges_group_by_lbn[lbn].emplace_back(edge); } } OpConfCache op_conf_cache; for (auto& pair : edges_group_by_lbn) { const std::string& lbn = pair.first; const OpNode* src_node = pair.second.front()->src_node(); const BlobDesc& blob_desc = src_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(lbn)); if (blob_desc.data_type() != DataType::kFloat) { continue; } OpConfMap inserted_ops; for (OpEdge* edge : pair.second) { if (IsBnInputEdge(edge)) { continue; } std::string scale; std::string zero_point; const int64_t scope_symbol_id = edge->src_node()->op().op_conf().scope_symbol_id(); JUST(GetScaleAndZeroPointLbn4Edge(edge, job->job_conf().train_conf().train_step_lbn(), &scale, &zero_point, qat_config, scope_symbol_id, &inserted_ops)); const std::string fake_quant_op_name = ReplaceSlashToDash4Lbn(lbn) + FAKE_QUANT_SUFFIX; const auto fake_quant_op = JUST(FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point, qat_config, scope_symbol_id, &inserted_ops)); const std::string fake_quant_op_output_name = fake_quant_op->output("out", 0); JUST(ReplaceInputLbn4DstNodeOfEdge(edge, fake_quant_op_output_name, &op_conf_cache)); } for (const auto& pair : inserted_ops) { VLOG(3) << "Insert op: 
" << pair.second.DebugString() << " between " << lbn; job_builder.AddOps(src_node->parallel_desc().parallel_conf(), {pair.second}); } } job_builder.MutOpsOnlyOnce(op_conf_cache.op_confs()); return Maybe::Ok(); } REGISTER_JOB_PASS("QuantAwareTraining", QuantAwareTraining); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h" #include "oneflow/core/job_rewriter/autograd.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/job_rewriter/clip_by_global_norm_job_pass_state.h" #include "oneflow/core/embedding/embedding_manager.h" namespace oneflow { namespace { std::string BuildIdentityOp(JobBuilder* job_builder, const std::string& in_lbn, const ParallelConf& parallel_conf, const user_op::UserOpConfWrapper& embedding_op) { user_op::UserOpConfWrapperBuilder identity_op_builder(embedding_op.op_name() + "_identity_" + NewUniqueId()); user_op::UserOpConfWrapper identity_op = identity_op_builder.OpTypeName("identity") .Input("in", in_lbn) .Output("out") .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id()) .Build(); job_builder->AddOps(parallel_conf, {identity_op.op_conf()}); return identity_op.output("out", 0); } Maybe DynamicLossScaleAddGradient( JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const HashMap& shadow_op_name2grad_lbn, int64_t scope_symbol_id, const ParallelConf& parallel_conf) { if (job_builder->job().job_conf().train_conf().has_dynamic_loss_scale_policy()) { const auto& dynamic_loss_scale_state = JUST(ctx->GetState("dynamic_loss_scale_state")); const LogicalBlobId count_not_finite_lbi = GenLogicalBlobId(dynamic_loss_scale_state.count_not_finite_lbn()); const OpNode* op_node = op_graph.OpNode4OpName(count_not_finite_lbi.op_name()); if (op_node->op().op_conf().has_user_conf() && op_node->op().op_conf().user_conf().op_type_name() == "identity") { const user_op::UserOpConfWrapper identity_op_conf(op_node->op().op_conf()); std::string new_count_not_finite_lbn; if (shadow_op_name2grad_lbn.size() == 1) { const std::string& grad_lbn = shadow_op_name2grad_lbn.begin()->second; const auto count_not_finite_op = user_op::UserOpConfWrapperBuilder("OneEmbedding-DynamicLossScale-CountNotFinite-" + NewUniqueId()) .Op("count_not_finite") .Input("x", grad_lbn) .Output("y") .ScopeSymbolId(op_node->op().op_conf().scope_symbol_id()) .Build(); job_builder->AddOps(parallel_conf, {count_not_finite_op.op_conf()}); new_count_not_finite_lbn = count_not_finite_op.output("y", 0); } else { auto multi_count_not_finite_op_builder = user_op::UserOpConfWrapperBuilder("OneEmbedding-DynamicLossScale-MultiCountNotFinite-" + NewUniqueId()) .Op("multi_count_not_finite") .Output("y") 
.ScopeSymbolId(op_node->op().op_conf().scope_symbol_id()); for (const auto& pair : shadow_op_name2grad_lbn) { multi_count_not_finite_op_builder.Input("x", pair.second); } const auto multi_count_not_finite_op = multi_count_not_finite_op_builder.Build(); job_builder->AddOps(parallel_conf, {multi_count_not_finite_op.op_conf()}); new_count_not_finite_lbn = multi_count_not_finite_op.output("y", 0); } user_op::UserOpConfWrapperBuilder add_op_builder( "OneEmbedding-DynamicLossScale-CountNotFinite-Add_" + NewUniqueId()); const auto add_op = add_op_builder.Op("add_n") .Input("in", identity_op_conf.input("in", 0)) .Input("in", new_count_not_finite_lbn) .Output("out") .ScopeSymbolId(op_node->op().op_conf().scope_symbol_id()) .Build(); job_builder->AddOps(op_node->parallel_desc().parallel_conf(), {add_op.op_conf()}); OperatorConf new_identity_conf = identity_op_conf.op_conf(); const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&new_identity_conf, "in_0", add_op.output("out", 0)); CHECK_EQ_OR_RETURN(identity_op_conf.input("in", 0), old_val); job_builder->MutOpsOnlyOnce({new_identity_conf}); } else { UNIMPLEMENTED_THEN_RETURN(); } } return Maybe::Ok(); } void BuildEmbeddingLookup( JobPassCtx* ctx, JobBuilder* job_builder, const int64_t embedding_size, const int64_t line_size, const std::string& embedding_name, const int64_t seed, bool has_embedding_prefetch, const ParallelConf& parallel_conf, const user_op::UserOpConfWrapper& embedding_op, const std::string& prefetch_num_unique_ids_lbn, const std::string& prefetch_unique_ids_lbn, const std::string& prefetch_unique_table_ids_lbn, const std::string& num_unique_ids_lbn, const std::string& unique_ids_lbn, const std::string& unique_table_ids_lbn, std::string* embedding_lbn, std::string* unique_values_lbn, OperatorConf* embedding_prefetch_op_conf, OperatorConf* embedding_lookup_op_conf) { std::string context_lbn; if (has_embedding_prefetch) { // embedding prefetch op user_op::UserOpConfWrapperBuilder embedding_prefetch_op_builder( embedding_op.op_name() + "_embedding_prefetch" + NewUniqueId()); user_op::UserOpConfWrapper embedding_prefetch_op = embedding_prefetch_op_builder.OpTypeName("embedding_prefetch") .Input("num_unique_ids", prefetch_num_unique_ids_lbn) .Input("unique_ids", prefetch_unique_ids_lbn) .Input("table_ids", prefetch_unique_table_ids_lbn) .Output("context") .Attr("embedding_size", embedding_size) .Attr("line_size", line_size) .Attr("embedding_tables", embedding_op.attr("embedding_tables")) .Attr("embedding_name", embedding_name) .Attr("seed", seed) .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id()) .Build(); *embedding_prefetch_op_conf = embedding_prefetch_op.op_conf(); if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false)) { embedding_prefetch_op_conf->set_stream_name_hint(embedding_name + "_EMBEDDING"); } context_lbn = embedding_prefetch_op.output("context", 0); } // embedding lookup op user_op::UserOpConfWrapperBuilder embedding_lookup_op_builder( embedding_op.op_name() + "_embedding_lookup" + NewUniqueId()); embedding_lookup_op_builder.OpTypeName("embedding_lookup") .Input("num_unique_ids", num_unique_ids_lbn) .Input("unique_ids", unique_ids_lbn) .Input("table_ids", unique_table_ids_lbn) .Output("unique_values") .Attr("dtype", embedding_op.attr("dtype")) .Attr("embedding_size", embedding_size) .Attr("line_size", line_size) .Attr("embedding_tables", embedding_op.attr("embedding_tables")) .Attr("embedding_name", embedding_name) .Attr("seed", seed) 
.ScopeSymbolId(embedding_op.op_conf().scope_symbol_id()); if (has_embedding_prefetch) { embedding_lookup_op_builder.Input("context", context_lbn); } bool has_embeddings_output = (line_size != embedding_size) || ctx->job_desc().enable_auto_mixed_precision(); if (has_embeddings_output) { DataType embeddings_dtype = ctx->job_desc().enable_auto_mixed_precision() ? DataType::kFloat16 : embedding_op.attr("dtype"); embedding_lookup_op_builder.Output("embeddings") .Attr("embeddings_dtype", embeddings_dtype); } user_op::UserOpConfWrapper embedding_lookup_op = embedding_lookup_op_builder.Build(); *embedding_lookup_op_conf = embedding_lookup_op.op_conf(); if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false)) { embedding_lookup_op_conf->set_stream_name_hint(embedding_name + "_EMBEDDING"); } if (has_embeddings_output) { *embedding_lbn = embedding_lookup_op.output("embeddings", 0); } else { *embedding_lbn = embedding_lookup_op.output("unique_values", 0); } *unique_values_lbn = embedding_lookup_op.output("unique_values", 0); } void BuildEmbeddingShuffle(JobBuilder* job_builder, const std::string& embedding_name, int64_t embedding_size, const ParallelConf& parallel_conf, const user_op::UserOpConfWrapper& embedding_op, const std::string& inverse_indices_lbn, const std::string& inner_inverse_unique_partition_indices_lbn, const std::string& num_unique_matrix_lbn, const std::string& embedding_lbn, std::vector* add_ops, std::string* new_embeddings_lbn) { const bool is_train_job = job_builder->job().job_conf().has_train_conf(); user_op::UserOpConfWrapperBuilder embedding_shuffle_op_builder( embedding_op.op_name() + "_embedding_shuffle" + NewUniqueId()); user_op::UserOpConfWrapper embedding_shuffle_op = embedding_shuffle_op_builder.OpTypeName("embedding_shuffle") .Input("cur_rank_embeddings", embedding_lbn) .Input("cur_rank_inverse_indices", inverse_indices_lbn) .Input("inverse_unique_partition_indices", inner_inverse_unique_partition_indices_lbn) .Input("num_unique_matrix", num_unique_matrix_lbn) .Attr("embedding_name", embedding_name) .Attr("embedding_size", embedding_size) .Attr("is_train", is_train_job) .Output("embeddings") .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id()) .Build(); OperatorConf embedding_shuffle_new_op_conf = embedding_shuffle_op.op_conf(); if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false) && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM", true)) { embedding_shuffle_new_op_conf.set_stream_name_hint(embedding_name + "_EMBEDDING"); } add_ops->push_back(embedding_shuffle_new_op_conf); *new_embeddings_lbn = embedding_shuffle_op.output("embeddings", 0); } void BuildEmbeddingGradientShuffle( JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const OpNode* op_node, const std::string& embedding_name, int64_t embedding_size, const bool use_system_gather, const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id, const user_op::UserOpConfWrapper& embedding_op, const std::string& inverse_indices_lbn, const std::string& inner_inverse_unique_partition_indices_lbn, const std::string& num_unique_matrix_lbn, const std::string& update_embedding_grad, const bool has_clip_grad, std::string* cur_rank_unique_embedding_grad_lbn) { std::string update_embedding_grad_lbn = update_embedding_grad; if (ctx->job_desc().enable_auto_mixed_precision() && !ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16", true)) { auto cast_op = 
user_op::UserOpConfWrapperBuilder(embedding_op.op_name() + "_before_grad_shuffle_cast_h2f") .Op("cast") .Input("in", update_embedding_grad_lbn) .Output("out") .Attr("dtype", DataType::kFloat) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); job_builder->AddOps(embedding_parallel_conf, {cast_op.op_conf()}); update_embedding_grad_lbn = cast_op.output("out", 0); } if (use_system_gather) { const int64_t num_segments = op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi("ids_0")).shape().elem_cnt(); user_op::UserOpConfWrapperBuilder unsorted_segment_sum_op_builder(embedding_op.op_name() + "_unsorted_segment_sum"); user_op::UserOpConfWrapper unsorted_segment_sum_op = unsorted_segment_sum_op_builder.OpTypeName("unsorted_segment_sum") .Input("data", update_embedding_grad_lbn) .Input("segment_ids", inverse_indices_lbn) .Output("out") .Attr("num_segments", num_segments) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); job_builder->AddOps(embedding_parallel_conf, {unsorted_segment_sum_op.op_conf()}); *cur_rank_unique_embedding_grad_lbn = unsorted_segment_sum_op.output("out", 0); } else { // embedding_gradient_shuffle op // if no dynamic loss scale or no clip_grad, we think gradient shuffle grad's invalid buffer // need not to be memset. const bool has_dynamic_loss_scale = job_builder->job().job_conf().train_conf().has_dynamic_loss_scale_policy(); const bool only_zero_valid_grad = (!has_clip_grad) && (!has_dynamic_loss_scale); user_op::UserOpConfWrapperBuilder embedding_gradient_shuffle_op_builder( embedding_op.op_name() + "_embedding_gradient_shuffle" + NewUniqueId()); user_op::UserOpConfWrapper embedding_gradient_shuffle_op = embedding_gradient_shuffle_op_builder.OpTypeName("embedding_gradient_shuffle") .Input("cur_rank_inverse_indices", inverse_indices_lbn) .Input("inverse_unique_partition_indices", inner_inverse_unique_partition_indices_lbn) .Input("embedding_grad", update_embedding_grad_lbn) .Input("num_unique_matrix", num_unique_matrix_lbn) .Output("cur_rank_unique_embedding_grad") .Attr("embedding_name", embedding_name) .Attr("embedding_size", embedding_size) .Attr("only_zero_valid_grad", only_zero_valid_grad) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); OperatorConf embedding_gradient_shuffle_new_op_conf = embedding_gradient_shuffle_op.op_conf(); if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false) && ParseBooleanFromEnv( "ONEFLOW_ONE_EMBEDDING_EMBEDDING_GRADIENT_SHUFFLE_INDEPENTENT_STREAM", true)) { embedding_gradient_shuffle_new_op_conf.set_stream_name_hint(embedding_name + "_EMBEDDING"); } job_builder->AddOps(embedding_parallel_conf, {embedding_gradient_shuffle_new_op_conf}); *cur_rank_unique_embedding_grad_lbn = embedding_gradient_shuffle_op.output("cur_rank_unique_embedding_grad", 0); } if (ctx->job_desc().enable_auto_mixed_precision() && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16", true) && (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_NOT_FUSE_CAST_TO_UPDATE", false) || has_clip_grad)) { auto cast_op = user_op::UserOpConfWrapperBuilder(embedding_op.op_name() + "_cast_h2f") .Op("cast") .Input("in", *cur_rank_unique_embedding_grad_lbn) .Output("out") .Attr("dtype", DataType::kFloat) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); *cur_rank_unique_embedding_grad_lbn = cast_op.output("out", 0); job_builder->AddOps(embedding_parallel_conf, {cast_op.op_conf()}); } } double GetLossInstanceNumScaleFactor(const OpGraph& op_graph, JobBuilder* job_builder) { double scale_factor = 1; std::function LossOpNode4OpName; 
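  // Map every loss lbn to its producer node; the factor computed below is 1 / loss element count,
  // adjusted when a loss runs with a time shape different from the source time shape.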
CHECK_JUST(MakeGetterLossOpNode4OpName(op_graph, &LossOpNode4OpName)); const TrainConf& train_conf = job_builder->job().job_conf().train_conf(); HashMap loss_lbi2op_node; CHECK_GT(train_conf.loss_lbn().size(), 0); for (const auto& loss_lbn : train_conf.loss_lbn()) { const auto& lbi = GenLogicalBlobId(loss_lbn); CHECK(loss_lbi2op_node.emplace(lbi, LossOpNode4OpName(lbi.op_name())).second); } const Shape src_time_shape({1, 1}); const int64_t source_time_shape_elem_cnt = src_time_shape.elem_cnt(); bool all_loss_time_shape_eq_src = true; for (const auto& pair : loss_lbi2op_node) { const int64_t time_shape_elem_cnt = CHECK_JUST(pair.second->op().GetOpTimeShape())->elem_cnt(); if (time_shape_elem_cnt != source_time_shape_elem_cnt) { CHECK_EQ(time_shape_elem_cnt % source_time_shape_elem_cnt, 0); all_loss_time_shape_eq_src = false; } } if (all_loss_time_shape_eq_src) { const BlobDesc* blob_desc = nullptr; for (const auto& pair : loss_lbi2op_node) { const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first); if (blob_desc != nullptr) { CHECK(*blob_desc == *cur_blob_desc); } blob_desc = cur_blob_desc; } CHECK(blob_desc != nullptr); scale_factor = 1.0f / static_cast(blob_desc->shape().elem_cnt()); } else { std::unique_ptr blob_desc; for (const auto& pair : loss_lbi2op_node) { const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first); // TODO: support dynamic CHECK(!cur_blob_desc->is_dynamic()); const DataType loss_data_type = cur_blob_desc->data_type(); const int64_t time_shape_elem_cnt = CHECK_JUST(pair.second->op().GetOpTimeShape())->elem_cnt(); // TODO: consider sbp const int64_t loss_elem_cnt = cur_blob_desc->shape().elem_cnt() * time_shape_elem_cnt / source_time_shape_elem_cnt; if (blob_desc) { CHECK_EQ(blob_desc->data_type(), loss_data_type); CHECK_EQ(blob_desc->shape().elem_cnt(), loss_elem_cnt); } else { blob_desc.reset( new BlobDesc(Shape({loss_elem_cnt}), loss_data_type, cur_blob_desc->memory_format())); } } scale_factor = 1.0f / static_cast(blob_desc->shape().elem_cnt()); } return scale_factor; } void BuildIdShuffle(bool use_system_gather, const std::string& embedding_name, const user_op::UserOpConfWrapper& embedding_op, std::vector* add_ops, std::string* prefetch_num_unique_lbn, std::string* prefetch_unique_ids_lbn, std::string* prefetch_unique_table_ids_lbn, std::string* inner_inverse_unique_partition_indices_lbn, std::string* num_unique_ids_lbn, std::string* unique_ids_lbn, std::string* unique_table_ids_lbn, std::string* inverse_indices_lbn, std::string* num_unique_matrix_lbn) { const int32_t num_tables = embedding_op.attr("num_tables"); const int64_t padding_idx = embedding_op.attr("padding_idx"); const int64_t has_padding_idx = embedding_op.attr("has_padding_idx"); bool enable_pipelined_execution = !ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false); if (use_system_gather) { user_op::UserOpConfWrapperBuilder unique_op_builder(embedding_op.op_name() + "_unique_ids_and_tables"); unique_op_builder.OpTypeName("unique_key_value_pair") .Input("keys", embedding_op.input("ids", 0)) .Output("num_unique") .Output("unique_keys") .Output("unique_values") .Output("inverse_indices") .Attr("num_tables", num_tables) .Attr("padding_idx", padding_idx) .Attr("has_padding_idx", has_padding_idx) .Attr("embedding_name", embedding_name) .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id()); if (embedding_op.has_input("table_ids", 0)) { unique_op_builder.Input("values", embedding_op.input("table_ids", 0)); } user_op::UserOpConfWrapper 
unique_op = unique_op_builder.Build(); OperatorConf unique_new_op_conf = unique_op.op_conf(); if (enable_pipelined_execution) { unique_new_op_conf.set_stream_name_hint(embedding_name + "_ID_SHUFFLE"); } add_ops->push_back(unique_new_op_conf); *num_unique_ids_lbn = unique_op.output("num_unique", 0); *unique_ids_lbn = unique_op.output("unique_keys", 0); *unique_table_ids_lbn = unique_op.output("unique_values", 0); *inverse_indices_lbn = unique_op.output("inverse_indices", 0); *prefetch_num_unique_lbn = *num_unique_ids_lbn; *prefetch_unique_ids_lbn = *unique_ids_lbn; *prefetch_unique_table_ids_lbn = *unique_table_ids_lbn; } else { user_op::UserOpConfWrapperBuilder id_shuffle_op_builder(embedding_op.op_name() + "_id_shuffle" + NewUniqueId()); id_shuffle_op_builder.OpTypeName("id_shuffle") .Input("ids", embedding_op.input("ids", 0)) .Output("inverse_unique_partition_indices") .Output("cur_rank_num_unique") .Output("cur_rank_unique_ids") .Output("cur_rank_unique_table_ids") .Output("cur_rank_inverse_indices") .Output("num_unique_matrix") .Attr("num_tables", num_tables) .Attr("padding_idx", padding_idx) .Attr("has_padding_idx", has_padding_idx) .Attr("embedding_name", embedding_name) .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id()); if (embedding_op.has_input("table_ids", 0)) { id_shuffle_op_builder.Input("table_ids", embedding_op.input("table_ids", 0)); } user_op::UserOpConfWrapper id_shuffle_op = id_shuffle_op_builder.Build(); OperatorConf id_shuffle_new_op_conf = id_shuffle_op.op_conf(); if (enable_pipelined_execution) { id_shuffle_new_op_conf.set_stream_name_hint(embedding_name + "_ID_SHUFFLE"); } add_ops->push_back(id_shuffle_new_op_conf); if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT", true)) { // add id_shuffle_copy_out, so the consumer can use light_actor and cuda_graph. 
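      // With pipelined execution enabled, the copy-out op below is pinned to the
      // "<embedding_name>_EMBEDDING" stream, while id_shuffle itself stays on the ID_SHUFFLE stream.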
user_op::UserOpConfWrapperBuilder identity_op_builder( embedding_op.op_name() + "_id_shuffle_copy_out_" + NewUniqueId()); user_op::UserOpConfWrapper identity_op = identity_op_builder.OpTypeName("id_shuffle_copy_out") .Attr("embedding_name", embedding_name) .Input("inverse_unique_partition_indices", id_shuffle_op.output("inverse_unique_partition_indices", 0)) .Input("cur_rank_num_unique", id_shuffle_op.output("cur_rank_num_unique", 0)) .Input("cur_rank_unique_ids", id_shuffle_op.output("cur_rank_unique_ids", 0)) .Input("cur_rank_unique_table_ids", id_shuffle_op.output("cur_rank_unique_table_ids", 0)) .Input("cur_rank_inverse_indices", id_shuffle_op.output("cur_rank_inverse_indices", 0)) .Input("num_unique_matrix", id_shuffle_op.output("num_unique_matrix", 0)) .Output("out_inverse_unique_partition_indices") .Output("out_cur_rank_num_unique") .Output("out_cur_rank_unique_ids") .Output("out_cur_rank_unique_table_ids") .Output("out_cur_rank_inverse_indices") .Output("out_num_unique_matrix") .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id()) .Build(); OperatorConf identity_op_conf = identity_op.op_conf(); if (enable_pipelined_execution) { identity_op_conf.set_stream_name_hint(embedding_name + "_EMBEDDING"); } add_ops->push_back(identity_op_conf); *inner_inverse_unique_partition_indices_lbn = identity_op.output("out_inverse_unique_partition_indices", 0); *num_unique_ids_lbn = identity_op.output("out_cur_rank_num_unique", 0); *unique_ids_lbn = identity_op.output("out_cur_rank_unique_ids", 0); *unique_table_ids_lbn = identity_op.output("out_cur_rank_unique_table_ids", 0); *inverse_indices_lbn = identity_op.output("out_cur_rank_inverse_indices", 0); *num_unique_matrix_lbn = identity_op.output("out_num_unique_matrix", 0); } else { *inner_inverse_unique_partition_indices_lbn = id_shuffle_op.output("inverse_unique_partition_indices", 0); *num_unique_ids_lbn = id_shuffle_op.output("cur_rank_num_unique", 0); *unique_ids_lbn = id_shuffle_op.output("cur_rank_unique_ids", 0); *unique_table_ids_lbn = id_shuffle_op.output("cur_rank_unique_table_ids", 0); *inverse_indices_lbn = id_shuffle_op.output("cur_rank_inverse_indices", 0); *num_unique_matrix_lbn = id_shuffle_op.output("num_unique_matrix", 0); } *prefetch_num_unique_lbn = id_shuffle_op.output("cur_rank_num_unique", 0); *prefetch_unique_ids_lbn = id_shuffle_op.output("cur_rank_unique_ids", 0); *prefetch_unique_table_ids_lbn = id_shuffle_op.output("cur_rank_unique_table_ids", 0); } } void MakeConstantInitializerAttr(const int64_t embedding_size, const int64_t line_size, const std::vector& values, std::string* initializer_attr) { if (embedding_size == line_size) { return; } const int32_t num_states = line_size / embedding_size - 1; CHECK_GT(num_states, 0) << "num_states " << num_states; CHECK(values.size() == 0 || num_states == values.size()) << "must set " << num_states << " optimizer states init value, but get " << values.size(); nlohmann::json initializers; for (int32_t i = 0; i < num_states; ++i) { nlohmann::json initializer; initializer["type"] = "constant"; const float initial_value = values.size() > 0 ? 
values.at(i) : 0.0; initializer["value"] = initial_value; initializers.push_back(initializer); } *initializer_attr = initializers.dump(); } void ScaleGrad(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id, const bool has_clip_grad, const std::string& embedding_grad_lbn, std::string* new_embedding_grad_lbn, std::string* update_skip_if_lbn, std::string* fuse_to_update_down_scale_by_lbn, double* fuse_to_update_scale) { *new_embedding_grad_lbn = embedding_grad_lbn; const TrainConf& train_conf = job_builder->job().job_conf().train_conf(); double scale = GetLossInstanceNumScaleFactor(op_graph, job_builder); if (train_conf.has_dynamic_loss_scale_policy()) { const auto& dynamic_loss_scale_state = CHECK_JUST(ctx->GetState("dynamic_loss_scale_state")); const std::string& loss_scale_val_lbn = dynamic_loss_scale_state.loss_scale_val_lbn(); *update_skip_if_lbn = dynamic_loss_scale_state.count_not_finite_lbn(); if (has_clip_grad) { const LogicalBlobId loss_scale_val_lbi = GenLogicalBlobId(loss_scale_val_lbn); const OpNode* loss_scale_node = op_graph.OpNode4OpName(loss_scale_val_lbi.op_name()); auto inv_scale_op = user_op::UserOpConfWrapperBuilder( "OneEmbedding-DynamicLossScale-Reciprocal-" + NewUniqueId()) .Op("reciprocal") .Input("x", loss_scale_val_lbn) .Output("y") .ScopeSymbolId(loss_scale_node->op().op_conf().scope_symbol_id()) .Build(); job_builder->AddOps(loss_scale_node->parallel_desc().parallel_conf(), {inv_scale_op.op_conf()}); auto scalar_mul_op = user_op::UserOpConfWrapperBuilder( "OneEmbedding-ModelDiffScale-ScalarMul-" + NewUniqueId()) .Op("scalar_mul_by_tensor") .Input("x", *new_embedding_grad_lbn) .Input("scalar", inv_scale_op.output("y", 0)) .Output("y") .ScopeSymbolId(embedding_scope_symbol_id) .Build(); job_builder->AddOps(embedding_parallel_conf, {scalar_mul_op.op_conf()}); *new_embedding_grad_lbn = scalar_mul_op.output("y", 0); } else { *fuse_to_update_down_scale_by_lbn = loss_scale_val_lbn; } } else if (train_conf.has_loss_scale_factor()) { double down_scale_factor = 1.0f / train_conf.loss_scale_factor(); scale *= down_scale_factor; } if (has_clip_grad) { auto scalar_mul_op = user_op::UserOpConfWrapperBuilder("OneEmbedding-ModelDiffScale-ScalarMul-" + NewUniqueId()) .Op("scalar_mul") .Input("in", *new_embedding_grad_lbn) .Output("out") .Attr("has_float_operand", true) .Attr("float_operand", scale) .Attr("has_int_operand", false) .Attr("int_operand", 0) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); job_builder->AddOps(embedding_parallel_conf, {scalar_mul_op.op_conf()}); *new_embedding_grad_lbn = scalar_mul_op.output("out", 0); *fuse_to_update_scale = 1.0; } else { *fuse_to_update_scale = scale; } } bool IsSupportFusedUpdatePut(const bool is_full_cache, const bool enable_auto_mixed_precision, const bool is_sgd, const std::string& down_scale_by_lbn, const std::string& skip_if_lbn, const float l1, const float l2, const float weight_decay) { if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSE_UPDATE_PUT", true)) { return false; } if (!is_full_cache) { return false; } if (!enable_auto_mixed_precision) { return false; } if (!is_sgd) { return false; } if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16", true)) { return false; } if (!down_scale_by_lbn.empty()) { return false; } if (!skip_if_lbn.empty()) { return false; } if (l1 != 0) { return false; } if (l2 != 0) { return false; } if (weight_decay != 0) { return false; } return true; } void 
BuildEmbeddingUpdate( JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id, const bool is_full_cache, const int64_t embedding_size, const int64_t line_size, const float l1, const float l2, const std::string& embedding_name, const OptimizerConf& optimizer_conf, const user_op::UserOpConfWrapper& embedding_op, const std::string& num_unique_ids_lbn, const std::string& unique_ids_lbn, const std::string& unique_values_lbn, const std::string& embedding_grad_lbn, const std::string& learning_rate_lbn, std::string* new_embedding_grad_lbn, std::string* state_initializer, OperatorConf* embedding_update_new_op_conf) { const TrainConf& train_conf = job_builder->job().job_conf().train_conf(); const bool has_clip_grad = optimizer_conf.has_clip_conf(); *new_embedding_grad_lbn = embedding_grad_lbn; std::string update_skip_if_lbn; std::string fuse_to_update_down_scale_by_lbn; double fuse_to_update_scale = 1.0; ScaleGrad(ctx, op_graph, job_builder, embedding_parallel_conf, embedding_scope_symbol_id, has_clip_grad, embedding_grad_lbn, new_embedding_grad_lbn, &update_skip_if_lbn, &fuse_to_update_down_scale_by_lbn, &fuse_to_update_scale); if (IsSupportFusedUpdatePut(is_full_cache, ctx->job_desc().enable_auto_mixed_precision(), optimizer_conf.has_naive_conf(), fuse_to_update_down_scale_by_lbn, update_skip_if_lbn, l1, l2, optimizer_conf.weight_decay_conf().weight_decay_rate())) { user_op::UserOpConfWrapperBuilder fused_embedding_update_put_op_builder( embedding_op.op_name() + "_fused_embedding_update_put" + NewUniqueId()); user_op::UserOpConfWrapper fused_embedding_update_put_op = fused_embedding_update_put_op_builder.OpTypeName("one_embedding_fused_sgd_update_put") .Input("num_unique_ids", num_unique_ids_lbn) .Input("unique_ids", unique_ids_lbn) .Input("unique_embeddings", unique_values_lbn) .Input("embedding_grad", *new_embedding_grad_lbn) .Input("learning_rate", learning_rate_lbn) .Attr("scale", fuse_to_update_scale) .Attr("embedding_name", embedding_name) .Attr("embedding_size", embedding_size) .Attr("line_size", line_size) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); *embedding_update_new_op_conf = fused_embedding_update_put_op.op_conf(); if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false)) { embedding_update_new_op_conf->set_stream_name_hint(embedding_name + "_EMBEDDING"); } return; } auto AddAdamBiasCorrectionFactorOp = [&](float beta_val, const std::string& op_name) -> std::string { user_op::UserOpConfWrapperBuilder op_builder(embedding_op.op_name() + op_name); const auto adam_bias_correction_factor_op = op_builder.OpTypeName("adam_bias_correction_factor") .Input("train_step", train_conf.train_step_lbn()) .Attr("beta", beta_val) .Output("out") .ScopeSymbolId(embedding_scope_symbol_id) .Build(); job_builder->AddOps(embedding_parallel_conf, {adam_bias_correction_factor_op.op_conf()}); return adam_bias_correction_factor_op.output("out", 0); }; user_op::UserOpConfWrapperBuilder embedding_update_op_builder( embedding_op.op_name() + "_embedding_update" + NewUniqueId()); std::vector state_constant_init_values; if (optimizer_conf.has_naive_conf()) { embedding_update_op_builder.OpTypeName("one_embedding_sgd_update"); } else if (optimizer_conf.has_momentum_conf()) { embedding_update_op_builder.OpTypeName("one_embedding_momentum_update") .Attr("beta", optimizer_conf.momentum_conf().beta()); } else if (optimizer_conf.has_adam_conf()) { const AdamModelUpdateConf& adam_conf = 
optimizer_conf.adam_conf(); if (adam_conf.smart_decay()) { CHECK(adam_conf.do_bias_correction()) << "when using smart decay adam, do_bias_correction should be true, but got " << adam_conf.do_bias_correction(); embedding_update_op_builder.OpTypeName("one_embedding_smart_decay_sparse_adam_update") .Input("train_step", train_conf.train_step_lbn()) .Attr("beta1", adam_conf.beta1()) .Attr("beta2", adam_conf.beta2()) .Attr("epsilon", adam_conf.epsilon()) .Attr("do_bias_correction", adam_conf.do_bias_correction()); } else { embedding_update_op_builder.OpTypeName("one_embedding_adam_update") .Attr("beta1", adam_conf.beta1()) .Attr("beta2", adam_conf.beta2()) .Attr("epsilon", adam_conf.epsilon()) .Attr("do_bias_correction", adam_conf.do_bias_correction()); if (adam_conf.do_bias_correction()) { const std::string bias_correction1_lbn = AddAdamBiasCorrectionFactorOp(adam_conf.beta1(), "adam_bias_correction_factor1"); const std::string bias_correction2_lbn = AddAdamBiasCorrectionFactorOp(adam_conf.beta2(), "adam_bias_correction_factor2"); embedding_update_op_builder.Input("bias_correction1", bias_correction1_lbn) .Input("bias_correction2", bias_correction2_lbn); } } } else if (optimizer_conf.has_adagrad_conf()) { const AdagradModelUpdateConf& adagrad_conf = optimizer_conf.adagrad_conf(); state_constant_init_values.push_back(adagrad_conf.initial_accumulator_value()); embedding_update_op_builder.OpTypeName("one_embedding_adagrad_update") .Input("train_step", train_conf.train_step_lbn()) .Attr("lr_decay", adagrad_conf.lr_decay()) .Attr("epsilon", adagrad_conf.epsilon()); } else if (optimizer_conf.has_ftrl_conf()) { const FtrlModelUpdateConf& ftrl_conf = optimizer_conf.ftrl_conf(); state_constant_init_values.push_back(ftrl_conf.initial_accumulator_value()); // For `z`, its init value is 0.0.
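// Together with the accumulator value pushed above and the 0.0 pushed below for `z`, these
// constants initialize the (line_size / embedding_size - 1) optimizer-state slices that
// MakeConstantInitializerAttr serializes into the state_initializer attribute.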
state_constant_init_values.push_back(0.0); embedding_update_op_builder.OpTypeName("one_embedding_ftrl_update") .Attr("lr_power", ftrl_conf.lr_power()) .Attr("lambda1", ftrl_conf.lambda1()) .Attr("lambda2", ftrl_conf.lambda2()) .Attr("beta", ftrl_conf.beta()); } else { UNIMPLEMENTED(); } MakeConstantInitializerAttr(embedding_size, line_size, state_constant_init_values, state_initializer); embedding_update_op_builder.Input("num_unique_ids", num_unique_ids_lbn) .Input("unique_embeddings", unique_values_lbn) .Input("learning_rate", learning_rate_lbn) .Attr("weight_decay", optimizer_conf.weight_decay_conf().weight_decay_rate()) .Attr("l1", l1) .Attr("l2", l2) .Output("updated_unique_embeddings"); if (!update_skip_if_lbn.empty()) { embedding_update_op_builder.Input("skip_if", update_skip_if_lbn); } if (!fuse_to_update_down_scale_by_lbn.empty()) { CHECK(!has_clip_grad); embedding_update_op_builder.Input("down_scale_by_tensor", fuse_to_update_down_scale_by_lbn); } user_op::UserOpConfWrapper embedding_update_op = embedding_update_op_builder.Input("embedding_grad", *new_embedding_grad_lbn) .Attr("scale", fuse_to_update_scale) .Attr("embedding_name", embedding_name) .Attr("embedding_size", embedding_size) .Attr("line_size", line_size) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); *embedding_update_new_op_conf = embedding_update_op.op_conf(); if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false)) { embedding_update_new_op_conf->set_stream_name_hint(embedding_name + "_EMBEDDING"); } user_op::UserOpConfWrapperBuilder embedding_put_op_builder(embedding_op.op_name() + "_embedding_put" + NewUniqueId()); user_op::UserOpConfWrapper embedding_put_op = embedding_put_op_builder.OpTypeName("embedding_put") .Input("num_unique_ids", num_unique_ids_lbn) .Input("unique_ids", unique_ids_lbn) .Input("unique_embeddings", embedding_update_op.output("updated_unique_embeddings", 0)) .Attr("embedding_name", embedding_name) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); OperatorConf embedding_put_new_op_conf = embedding_put_op.op_conf(); if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false)) { embedding_put_new_op_conf.set_stream_name_hint(embedding_name + "_EMBEDDING"); } job_builder->AddOps(embedding_parallel_conf, {embedding_put_new_op_conf}); } void UpdateConsumerOpConf(const OpNode* consumer, const LogicalBlobId& out, const std::string& new_out_lbn, HashMap* op_name2op_conf) { const std::string& consumer_op_name = consumer->op().op_name(); if (op_name2op_conf->find(consumer_op_name) == op_name2op_conf->end()) { (*op_name2op_conf)[consumer_op_name] = consumer->op().op_conf(); } for (const std::string& ibn : consumer->op().input_bns()) { if (consumer->op().BnInOp2Lbi(ibn) == out) { OperatorConf& consumer_op_conf = op_name2op_conf->at(consumer_op_name); const auto& new_val = new_out_lbn; const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val); CHECK_EQ(GenLogicalBlobName(out), old_val); } } } std::string GlobalAbsMaxMin(JobBuilder* job_builder, const HashMap& shadow_op_name2grad_lbn, float p, const std::string& total_norm_lbn, bool max_or_min, const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id, const ParallelConf& parallel_conf, const int64_t scope_symbol_id) { bool has_split = true; std::string multi_reduce_op_type_name = has_split ? (max_or_min ? "local_multi_reduce_max_abs" : "local_multi_reduce_min_abs") : (max_or_min ? 
"multi_reduce_max_abs" : "multi_reduce_min_abs"); std::string multi_reduce_op_name = "OneEmbedding-ClipGradient-GlobalNorm-MultiReduceXimumAbs-" + NewUniqueId(); auto multi_reduce_op_builder = user_op::UserOpConfWrapperBuilder(multi_reduce_op_name) .Op(multi_reduce_op_type_name) .Output("y") .ScopeSymbolId(embedding_scope_symbol_id); for (const auto& pair : shadow_op_name2grad_lbn) { const std::string& grad_lbn = pair.second; multi_reduce_op_builder.Input("x", grad_lbn); } auto multi_reduce_op = multi_reduce_op_builder.Build(); job_builder->AddOps(embedding_parallel_conf, {multi_reduce_op.op_conf()}); std::string embedding_reduce_lbn = multi_reduce_op.output("y", 0); if (has_split) { std::string group_reduce_op_type_name = max_or_min ? "reduce_max" : "reduce_min"; std::string group_reduce_op_name = "OneEmbedding-ClipGradient-GlobalNorm-GroupReduceXimum-" + NewUniqueId(); auto group_reduce_op = user_op::UserOpConfWrapperBuilder(group_reduce_op_name) .Op(group_reduce_op_type_name) .Input("input_tensor", multi_reduce_op.output("y", 0)) .Output("output_tensor") .Attr("axis", std::vector{0}) .Attr("keepdims", false) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); job_builder->AddOps(embedding_parallel_conf, {group_reduce_op.op_conf()}); embedding_reduce_lbn = group_reduce_op.output("output_tensor", 0); } if (!total_norm_lbn.empty()) { auto stack_op_builder = user_op::UserOpConfWrapperBuilder( "OneEmbedding-ClipGradient-GlobalNorm-GlobalStack-" + NewUniqueId()) .Op("stack") .Input("in", embedding_reduce_lbn) .Input("in", total_norm_lbn) .Output("out") .Attr("axis", int64_t(0)) .Attr("max_dim_size", static_cast(2)) .ScopeSymbolId(scope_symbol_id); auto stack_op = stack_op_builder.Build(); job_builder->AddOps(parallel_conf, {stack_op.op_conf()}); std::string reduce_op_type_name = max_or_min ? 
"reduce_max" : "reduce_min"; std::string reduce_op_name = "OneEmbedding-ClipGradient-GlobalNorm-GlobalReduceXimum-" + NewUniqueId(); auto reduce_op = user_op::UserOpConfWrapperBuilder(reduce_op_name) .Op(reduce_op_type_name) .Input("input_tensor", stack_op.output("out", 0)) .Output("output_tensor") .Attr("axis", std::vector{0}) .Attr("keepdims", false) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {reduce_op.op_conf()}); return reduce_op.output("output_tensor", 0); } else { return embedding_reduce_lbn; } } std::string GlobalNorm(JobBuilder* job_builder, const HashMap& shadow_op_name2grad_lbn, float p, const std::string& total_norm_lbn, const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id, const ParallelConf& parallel_conf, const int64_t scope_symbol_id) { auto multi_reduce_sum_op_builder = user_op::UserOpConfWrapperBuilder("OneEmbedding-ClipGradient-GlobalNorm-MultiReduceSumPowAbs-" + NewUniqueId()) .Op("multi_reduce_sum_pow_abs") .Attr("p", static_cast(p)) .Output("y") .ScopeSymbolId(embedding_scope_symbol_id); for (const auto& pair : shadow_op_name2grad_lbn) { const std::string grad_lbn = pair.second; multi_reduce_sum_op_builder.Input("x", grad_lbn); } const auto multi_reduce_sum_op = multi_reduce_sum_op_builder.Build(); job_builder->AddOps(embedding_parallel_conf, {multi_reduce_sum_op.op_conf()}); const std::string& embedding_sum_pow_abs_lbn = multi_reduce_sum_op.output("y", 0); std::string global_pow_in_lbn; if (!total_norm_lbn.empty()) { auto pow_op = user_op::UserOpConfWrapperBuilder( "OneEmbedding-ClipGradient-GlobalNorm-GlobalPow-" + NewUniqueId()) .Op("scalar_pow") .Input("in", total_norm_lbn) .Attr("float_operand", static_cast(p)) .Attr("has_float_operand", true) .Output("out") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {pow_op.op_conf()}); user_op::UserOpConfWrapperBuilder add_op_builder("OneEmbedding-ClipGradient-GlobalNorm-Add-" + NewUniqueId()); const auto add_op = add_op_builder.Op("add_n") .Input("in", embedding_sum_pow_abs_lbn) .Input("in", pow_op.output("out", 0)) .Output("out") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {add_op.op_conf()}); global_pow_in_lbn = add_op.output("out", 0); } else { global_pow_in_lbn = embedding_sum_pow_abs_lbn; } auto global_pow_op = user_op::UserOpConfWrapperBuilder( "OneEmbedding-ClipGradient-GlobalNorm-GlobalPow-" + NewUniqueId()) .Op("scalar_pow") .Input("in", global_pow_in_lbn) .Attr("float_operand", static_cast(1.0 / p)) .Attr("has_float_operand", true) .Output("out") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {global_pow_op.op_conf()}); return global_pow_op.output("out", 0); } std::string GetClampCoeff(JobBuilder* job_builder, const std::string& total_norm_lbn, float max_norm, const ParallelConf& parallel_conf, const int64_t scope_symbol_id) { auto add_eps_ops = user_op::UserOpConfWrapperBuilder( "OneEmbedding-ClipGradient-GlobalNorm-AddEps-" + NewUniqueId()) .Op("scalar_add") .Input("in", total_norm_lbn) .Attr("float_operand", 1e-6) .Attr("has_float_operand", true) .Output("out") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {add_eps_ops.op_conf()}); auto inv_op = user_op::UserOpConfWrapperBuilder("OneEmbedding-ClipGradient-GlobalNorm-Inv-" + NewUniqueId()) .Op("reciprocal_no_nan") .Input("x", add_eps_ops.output("out", 0)) .Output("y") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {inv_op.op_conf()}); auto 
coeff_op = user_op::UserOpConfWrapperBuilder("OneEmbedding-ClipGradient-GlobalNorm-Coeff-" + NewUniqueId()) .Op("scalar_mul") .Input("in", inv_op.output("y", 0)) .Attr("float_operand", static_cast(max_norm)) .Attr("has_float_operand", true) .Output("out") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {coeff_op.op_conf()}); auto clamp_coeff_op = user_op::UserOpConfWrapperBuilder( "OneEmbedding-ClipGradient-GlobalNorm-Clamp-" + NewUniqueId()) .Op("clip_by_scalar_max") .Input("x", coeff_op.output("out", 0)) .Attr("floating_max", 1.0) .Output("y") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(parallel_conf, {clamp_coeff_op.op_conf()}); return clamp_coeff_op.output("y", 0); } void ClipGradByGlobalNorm(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const OptimizerConf& optimizer_conf, const HashMap& shadow_op_name2grad_lbn, const HashMap& grad_lbn2update_op_conf, const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id, HashMap* op_name2op_conf) { const ClipByGlobalNormConf& conf = optimizer_conf.clip_conf().clip_by_global_norm(); double norm_type = conf.norm_type(); auto clip_by_global_norm_pass_state = CHECK_JUST(ctx->MutableState("clip_by_global_norm_state")); const auto NewGlobalNorm = [&](const std::string& total_norm_lbn, const ParallelConf& parallel_conf, const int64_t scope_symbol_id) -> std::string { if (std::isinf(norm_type) && norm_type > 0) { return GlobalAbsMaxMin(job_builder, shadow_op_name2grad_lbn, norm_type, total_norm_lbn, true, embedding_parallel_conf, embedding_scope_symbol_id, parallel_conf, scope_symbol_id); } else if (std::isinf(norm_type) && norm_type < 0) { UNIMPLEMENTED() << "one_embedding gradient's invalid values set to 0, so not support abs_reduce_min."; return GlobalAbsMaxMin(job_builder, shadow_op_name2grad_lbn, norm_type, total_norm_lbn, false, embedding_parallel_conf, embedding_scope_symbol_id, parallel_conf, scope_symbol_id); } else { return GlobalNorm(job_builder, shadow_op_name2grad_lbn, norm_type, total_norm_lbn, embedding_parallel_conf, embedding_scope_symbol_id, parallel_conf, scope_symbol_id); } }; bool has_total_norm_state = false; std::string variable_op_name; for (const auto& var_op_name : optimizer_conf.variable_op_names()) { if (clip_by_global_norm_pass_state->HasTotalNormState(var_op_name)) { has_total_norm_state = true; variable_op_name = var_op_name; break; } } std::string coeff_lbn; if (has_total_norm_state) { // has_total_norm_state means there are some gradients in same optimizer group with // embedding_grads, the total_norm_lbn is the global norm of other gradients, embedding_grads // need to compute global norm with total_norm_lbn and update the consumer of the // total_norm_lbn, no need to compute clamp coff because it has been built in autograd pass. 
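// Below: the embedding gradients' norm is folded into the existing total_norm_lbn, every consumer
// of the old total norm is rewired to the combined value, and the clamp coefficient already
// produced for that optimizer group is reused as-is.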
const std::shared_ptr& total_norm_state = clip_by_global_norm_pass_state->GetTotalNormState(variable_op_name); const LogicalBlobId total_norm_lbi = GenLogicalBlobId(total_norm_state->total_norm_lbn()); std::string new_total_norm_lbn = NewGlobalNorm(total_norm_state->total_norm_lbn(), total_norm_state->parallel_conf(), total_norm_state->scope_symbol_id()); const OpNode* total_norm_lbn_producer = op_graph.OpNode4OpName(total_norm_lbi.op_name()); for (const OpEdge* out_edge : total_norm_lbn_producer->out_edges()) { const OpNode* consumer = out_edge->dst_node(); UpdateConsumerOpConf(consumer, total_norm_lbi, new_total_norm_lbn, op_name2op_conf); } total_norm_state->set_total_norm_lbn(new_total_norm_lbn); coeff_lbn = total_norm_state->coeff_lbn(); } else { // no norm_state means there are no gradients in same optimizer group with embedding_grad, // embedding_grad compute the global norm and clip independently. const std::string& new_total_norm_lbn = NewGlobalNorm("", embedding_parallel_conf, embedding_scope_symbol_id); coeff_lbn = GetClampCoeff(job_builder, new_total_norm_lbn, conf.max_norm(), embedding_parallel_conf, embedding_scope_symbol_id); } for (const auto& pair : shadow_op_name2grad_lbn) { const std::string& grad_lbn = pair.second; const auto& it = grad_lbn2update_op_conf.find(grad_lbn); CHECK(it != grad_lbn2update_op_conf.end()); OperatorConf update_op_conf = it->second; *(*update_op_conf.mutable_user_conf()->mutable_input())["scale_by_tensor"].mutable_s() = StdVec2PbRpf({coeff_lbn}); job_builder->AddOps(embedding_parallel_conf, {update_op_conf}); } } void FilterCurGradLbnAndUpdateOpConfPairs( const ::google::protobuf::RepeatedPtrField& variables, const HashMap& shadow_op_name2grad_lbn, HashMap* cur_shadow_op_name2grad_lbn) { for (const std::string& variable : variables) { const auto& it = shadow_op_name2grad_lbn.find(variable); if (it != shadow_op_name2grad_lbn.end()) { (*cur_shadow_op_name2grad_lbn)[variable] = it->second; } } } void FilterEmbeddingGradients(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const HashMap& shadow_op_name2grad_lbn, const HashMap& grad_lbn2update_op_conf, const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id, HashMap* op_name2op_conf) { for (const auto& optimizer_conf : job_builder->job().job_conf().train_conf().optimizer_conf()) { HashMap cur_shadow_op_name2grad_lbn; FilterCurGradLbnAndUpdateOpConfPairs(optimizer_conf.variable_op_names(), shadow_op_name2grad_lbn, &cur_shadow_op_name2grad_lbn); if (!optimizer_conf.has_clip_conf()) { for (const auto& pair : cur_shadow_op_name2grad_lbn) { const auto& it = grad_lbn2update_op_conf.find(pair.second); CHECK(it != grad_lbn2update_op_conf.end()); job_builder->AddOps(embedding_parallel_conf, {it->second}); } } else { ClipGradByGlobalNorm(ctx, op_graph, job_builder, optimizer_conf, cur_shadow_op_name2grad_lbn, grad_lbn2update_op_conf, embedding_parallel_conf, embedding_scope_symbol_id, op_name2op_conf); } } } bool IsRelatedOp(const OperatorConf& op) { return op.has_user_conf() && (op.user_conf().op_type_name() == "one_embedding_fused_lookup"); } bool NeedDoPass(const Job& job) { return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsRelatedOp); } } // namespace class ReplaceEmbeddingOps final : public JobPass { public: ReplaceEmbeddingOps() = default; ~ReplaceEmbeddingOps() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder, JobPassCtx* ctx) 
const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } if (!NeedDoPass(*job)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder, ctx); } }; Maybe ReplaceEmbeddingOps::Apply(const OpGraph& op_graph, JobBuilder* job_builder, JobPassCtx* ctx) const { ParallelConf embedding_parallel_conf; int64_t embedding_scope_symbol_id = 0; HashMap op_name2op_conf; HashMap shadow_op_name2grad_lbn; HashMap grad_lbn2update_op_conf; op_graph.ForEachNode([&](const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); if (!op_conf.has_user_conf()) { return; } if (!(op_conf.user_conf().op_type_name() == "one_embedding_fused_lookup")) { return; } std::vector add_ops; std::vector delete_op_names; const user_op::UserOpConfWrapper embedding_op(op_node->op().op_conf()); const OpNode* shadow_producer = op_graph.OpNode4OpName(GenLogicalBlobId(embedding_op.input("shadow", 0)).op_name()); std::string shadow_op_name; if (shadow_producer->op().op_conf().has_variable_conf()) { shadow_op_name = shadow_producer->op().op_name(); } else if (shadow_producer->op().op_conf().has_user_conf() && shadow_producer->op().op_conf().user_conf().op_type_name() == "cast") { const user_op::UserOpConfWrapper shadow_cast_op(shadow_producer->op().op_conf()); const OpNode* cast_producer = op_graph.OpNode4OpName(GenLogicalBlobId(shadow_cast_op.input("in", 0)).op_name()); CHECK(cast_producer->op().op_conf().has_variable_conf()) << cast_producer->op().op_name(); shadow_op_name = cast_producer->op().op_name(); delete_op_names.push_back(shadow_cast_op.op_name()); } else { UNIMPLEMENTED() << "shadow must be variable or variable and cast"; } // assume all embeddings have same placement embedding_scope_symbol_id = embedding_op.op_conf().scope_symbol_id(); embedding_parallel_conf = op_node->parallel_desc().parallel_conf(); const std::string& embedding_name = embedding_op.attr("embedding_name"); const int64_t line_size = embedding_op.attr("line_size"); const int64_t embedding_size = embedding_op.attr("embedding_size"); const bool is_full_cache = embedding_op.attr("is_full_cache"); const int64_t seed = embedding_op.attr("seed"); const int64_t parallel_num = op_node->parallel_desc().parallel_num(); const bool use_system_gather = (parallel_num == 1 && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER", true)); std::string new_embeddings_lbn; // prefetch can not exec in advance when it consume id_shuffle_copy_out, because // id_shuffle_copy_out's regster_num is 1. so we set id_shuffle out to // prefetch_num_unique_ids_lbn and prefetch consume them for pipeline. 
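// Outputs filled in by BuildIdShuffle below: the prefetch_* lbns feed embedding_prefetch directly
// (bypassing any copy-out op, per the note above), while the remaining lbns are consumed by the
// lookup, shuffle, and update ops built later in this pass.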
std::string prefetch_num_unique_ids_lbn; std::string prefetch_unique_ids_lbn; std::string prefetch_unique_table_ids_lbn; std::string inner_inverse_unique_partition_indices_lbn; std::string num_unique_ids_lbn; std::string unique_ids_lbn; std::string unique_table_ids_lbn; std::string inverse_indices_lbn; std::string num_unique_matrix_lbn; BuildIdShuffle(use_system_gather, embedding_name, embedding_op, &add_ops, &prefetch_num_unique_ids_lbn, &prefetch_unique_ids_lbn, &prefetch_unique_table_ids_lbn, &inner_inverse_unique_partition_indices_lbn, &num_unique_ids_lbn, &unique_ids_lbn, &unique_table_ids_lbn, &inverse_indices_lbn, &num_unique_matrix_lbn); const bool is_train_job = job_builder->job().job_conf().has_train_conf(); const bool no_optimizer_states = (embedding_size == line_size); const bool has_embedding_prefetch = (!is_full_cache) && (is_train_job || no_optimizer_states); OperatorConf embedding_prefetch_op_conf; OperatorConf embedding_lookup_op_conf; // embedding lookup op std::string embedding_lbn, unique_values_lbn; BuildEmbeddingLookup( ctx, job_builder, embedding_size, line_size, embedding_name, seed, has_embedding_prefetch, embedding_parallel_conf, embedding_op, prefetch_num_unique_ids_lbn, prefetch_unique_ids_lbn, prefetch_unique_table_ids_lbn, num_unique_ids_lbn, unique_ids_lbn, unique_table_ids_lbn, &embedding_lbn, &unique_values_lbn, &embedding_prefetch_op_conf, &embedding_lookup_op_conf); if (use_system_gather) { user_op::UserOpConfWrapperBuilder gather_op_builder(embedding_op.op_name() + "_one_embedding_gather"); user_op::UserOpConfWrapper gather_op = gather_op_builder.OpTypeName("one_embedding_gather") .Input("in", embedding_lbn) .Input("indices", inverse_indices_lbn) .Output("out") .Attr("embedding_size", embedding_size) .Attr("embedding_name", embedding_name) .ScopeSymbolId(embedding_scope_symbol_id) .Build(); add_ops.push_back(gather_op.op_conf()); new_embeddings_lbn = gather_op.output("out", 0); } else { // embedding shuffle op BuildEmbeddingShuffle(job_builder, embedding_name, embedding_size, embedding_parallel_conf, embedding_op, inverse_indices_lbn, inner_inverse_unique_partition_indices_lbn, num_unique_matrix_lbn, embedding_lbn, &add_ops, &new_embeddings_lbn); } delete_op_names.push_back(embedding_op.op_name()); const LogicalBlobId out = GenLogicalBlobId(embedding_op.output("embeddings", 0)); for (const OpEdge* out_edge : op_node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); UpdateConsumerOpConf(consumer, out, new_embeddings_lbn, &op_name2op_conf); } std::string state_initializer; // find update op const OpNode* producer = op_graph.OpNode4OpName(GenLogicalBlobId(embedding_op.input("ids", 0)).op_name()); for (OpEdge* edge : producer->out_edges()) { const OpNode* consumer = edge->dst_node(); if (consumer->op().op_conf().has_user_conf()) { const user_op::UserOpConfWrapper update_op_conf(consumer->op().op_conf()); if (update_op_conf.op_type_name() != "one_embedding_fused_lookup_grad") { continue; } if (update_op_conf.attr("embedding_name") != embedding_op.attr("embedding_name")) { continue; } delete_op_names.push_back(update_op_conf.op_name()); OptimizerConf embedding_optimizer_conf; bool found_embedding_optimizer = false; for (const auto& optimizer_conf : job_builder->job().job_conf().train_conf().optimizer_conf()) { for (const auto& name : optimizer_conf.variable_op_names()) { if (name == shadow_op_name) { embedding_optimizer_conf = optimizer_conf; found_embedding_optimizer = true; break; } } if (found_embedding_optimizer == true) { break; } } 
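// With the owning optimizer group found, the fused lookup grad op is deleted and replaced below by
// a gradient shuffle plus an optimizer-specific embedding update built from that group's conf.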
CHECK_EQ(found_embedding_optimizer, true) << shadow_op_name << " has not found optimizer"; std::string embedding_grad_lbn; BuildEmbeddingGradientShuffle( ctx, op_graph, job_builder, op_node, embedding_name, embedding_size, use_system_gather, embedding_parallel_conf, embedding_scope_symbol_id, embedding_op, inverse_indices_lbn, inner_inverse_unique_partition_indices_lbn, num_unique_matrix_lbn, update_op_conf.input("embedding_grad", 0), embedding_optimizer_conf.has_clip_conf(), &embedding_grad_lbn); const OpNode* shadow_node = op_graph.OpNode4OpName(shadow_op_name); const VariableOpConf& shadow_variable_conf = shadow_node->op().op_conf().variable_conf(); float l1 = 0.0; float l2 = 0.0; if (shadow_variable_conf.has_regularizer()) { const RegularizerConf& regularizer_conf = shadow_variable_conf.regularizer(); if (regularizer_conf.has_l1_l2_conf()) { l1 = regularizer_conf.l1_l2_conf().l1(); l2 = regularizer_conf.l1_l2_conf().l2(); } } const std::string& learning_rate_lbn = embedding_optimizer_conf.learning_rate_lbn(); std::string new_embedding_grad_lbn; OperatorConf embedding_update_op_conf; BuildEmbeddingUpdate(ctx, op_graph, job_builder, embedding_parallel_conf, embedding_scope_symbol_id, is_full_cache, embedding_size, line_size, l1, l2, embedding_name, embedding_optimizer_conf, embedding_op, num_unique_ids_lbn, unique_ids_lbn, unique_values_lbn, embedding_grad_lbn, learning_rate_lbn, &new_embedding_grad_lbn, &state_initializer, &embedding_update_op_conf); shadow_op_name2grad_lbn[shadow_op_name] = new_embedding_grad_lbn; grad_lbn2update_op_conf[new_embedding_grad_lbn] = std::move(embedding_update_op_conf); } } if ((state_initializer.empty()) && !no_optimizer_states) { CHECK(!is_train_job) << "train job must have set state initializer"; MakeConstantInitializerAttr(embedding_size, line_size, {}, &state_initializer); } auto state_initializer_attr = ::oneflow::AttrValue(); state_initializer_attr.set_at_string(state_initializer); if (has_embedding_prefetch) { (*(embedding_prefetch_op_conf.mutable_user_conf()->mutable_attr()))["state_initializer"] = state_initializer_attr; add_ops.push_back(embedding_prefetch_op_conf); } (*(embedding_lookup_op_conf.mutable_user_conf()->mutable_attr()))["state_initializer"] = state_initializer_attr; add_ops.push_back(embedding_lookup_op_conf); job_builder->DelOps(delete_op_names); job_builder->AddOps(embedding_parallel_conf, add_ops); }); if (shadow_op_name2grad_lbn.size() > 0) { FilterEmbeddingGradients(ctx, op_graph, job_builder, shadow_op_name2grad_lbn, grad_lbn2update_op_conf, embedding_parallel_conf, embedding_scope_symbol_id, &op_name2op_conf); JUST(DynamicLossScaleAddGradient(ctx, op_graph, job_builder, shadow_op_name2grad_lbn, embedding_scope_symbol_id, embedding_parallel_conf)); } for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); } return Maybe::Ok(); } REGISTER_JOB_PASS("ReplaceEmbeddingOps", ReplaceEmbeddingOps); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/rmsprop_optm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { std::string GenVariableOutputLbn(const OperatorConf& op_conf) { CHECK(op_conf.has_variable_conf()); return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out()); } OperatorConf GenerateRmspropHelperVariableOpConf(const VariableOp& op, const std::string& name, const float initial_value) { OperatorConf helper_variable_op(op.op_conf()); helper_variable_op.set_name(op.op_name() + "-" + name); helper_variable_op.mutable_variable_conf()->set_out("out"); InitializerConf constant_initializer; constant_initializer.mutable_constant_conf()->set_value(initial_value); *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer; helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id()); return helper_variable_op; } void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); OperatorConf mean_square_var(GenerateRmspropHelperVariableOpConf(*var_op, "mean_square", 0.f)); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {mean_square_var}); user_op::UserOpConfWrapperBuilder rmsprop_update_op_builder(var_op->op_name() + "_optimizer"); const RMSPropModelUpdateConf& rmsprop_conf = optimizer_conf.rmsprop_conf(); bool centered = rmsprop_conf.centered(); rmsprop_update_op_builder.OpTypeName("rmsprop_update") .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) .Input("model_diff", model_diff_lbn) .Input("learning_rate", optimizer_conf.learning_rate_lbn()) .Input("mean_square", GenVariableOutputLbn(mean_square_var)) .Attr("centered", centered) .Attr("epsilon", rmsprop_conf.epsilon()) .Attr("decay_rate", rmsprop_conf.decay_rate()) .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); if (optimizer_conf.has_lr_scale()) { rmsprop_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale()); } SetDynamicLossScaleSkipIf(ctx, &rmsprop_update_op_builder); if (centered) { OperatorConf mean_gradient_var( GenerateRmspropHelperVariableOpConf(*var_op, "mean_gradient", 0.f)); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {mean_gradient_var}); rmsprop_update_op_builder.Input("mean_gradient", GenVariableOutputLbn(mean_gradient_var)); } user_op::UserOpConfWrapper rmsprop_update_op = rmsprop_update_op_builder.Build(); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {rmsprop_update_op.op_conf()}); } } // namespace REGISTER_OPTIMIZER(OptimizerConf::kRmspropConf, &GenerateOptimizerOpConf); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/sequential_one_embedding_shuffle_ops_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class SequentialOneEmbeddingOpsPass final : public JobPass { public: SequentialOneEmbeddingOpsPass() = default; ~SequentialOneEmbeddingOpsPass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe SequentialOneEmbeddingOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { HashMap> stream_name_hint2shuffle_op_names; op_graph.TopoForEachNode([&](const OpNode* op_node) { if (!(IsUserOpWithTypeName(op_node->op().op_conf(), "id_shuffle") || IsUserOpWithTypeName(op_node->op().op_conf(), "embedding_shuffle") || IsUserOpWithTypeName(op_node->op().op_conf(), "embedding_gradient_shuffle"))) { return; } OperatorConf op_conf = op_node->op().op_conf(); std::string stream_name; if (op_conf.has_stream_name_hint()) { stream_name = op_conf.stream_name_hint(); } else { stream_name = "DEFAULT"; } const auto& it = stream_name_hint2shuffle_op_names.find(stream_name); if (it != stream_name_hint2shuffle_op_names.end()) { if (it->second.size() > 0) { std::string pre_shuffle_op_name = it->second.back(); op_conf.add_ctrl_in_op_name(pre_shuffle_op_name); job_builder->MutOpsOnlyOnce({op_conf}); } it->second.push_back(op_conf.name()); } else { std::vector shuffle_ops{op_conf.name()}; CHECK(stream_name_hint2shuffle_op_names.emplace(stream_name, shuffle_ops).second); } }); return Maybe::Ok(); } } // namespace REGISTER_JOB_PASS("SequentialOneEmbeddingOpsPass", SequentialOneEmbeddingOpsPass); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/sgd_optm.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); user_op::UserOpConfWrapperBuilder sgd_update_op_builder(var_op->op_name() + "_optimizer"); sgd_update_op_builder.OpTypeName("sgd_update") .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) .Input("model_diff", model_diff_lbn) .Input("learning_rate", optimizer_conf.learning_rate_lbn()) .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); if (optimizer_conf.has_lr_scale()) { sgd_update_op_builder.Attr("learning_rate_scale", optimizer_conf.lr_scale()); } SetDynamicLossScaleSkipIf(ctx, &sgd_update_op_builder); user_op::UserOpConfWrapper sgd_update_op = sgd_update_op_builder.Build(); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {sgd_update_op.op_conf()}); } } // namespace REGISTER_OPTIMIZER(OptimizerConf::kNaiveConf, &GenerateOptimizerOpConf); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/source_user_op_auto_tick.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/autotick.h" #include "oneflow/core/framework/user_op_registry_manager.h" namespace oneflow { namespace { class MutUserOpConTickInputHelper final : public MutOpConTickInputHelper { public: MutUserOpConTickInputHelper() : MutOpConTickInputHelper() {} bool VirtualIsTickInputBound() const override { return !op_conf().user_conf().input().empty(); } OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override { OperatorConf ret(op_conf()); (*ret.mutable_user_conf()->mutable_input())[user_op::kUserSourceOpTickInputArgName].add_s(lbn); return ret; } }; } // namespace REGISTER_AUTO_TICK(OperatorConf::kUserConf, MutUserOpConTickInputHelper); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/split_sparse_softmax_cross_entropy_op_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { bool NeedDoPass(const Job& job) { return std::any_of(job.net().op().cbegin(), job.net().op().cend(), [&](const OperatorConf& op) { return op.has_user_conf() && op.user_conf().op_type_name() == "sparse_softmax_cross_entropy_ms"; }); } void UpdateProbConsumerOpConf(const std::string& new_prob_lbn, const OpNode* op_node, JobBuilder* job_builder) { for (const OpEdge* edge : op_node->out_edges()) { OpNode* out_node = edge->dst_node(); OperatorConf new_conf = out_node->op().op_conf(); if (new_conf.has_user_conf() && new_conf.user_conf().op_type_name() == "sparse_softmax_cross_entropy_ms_grad") { CHECK_EQ(GenLogicalBlobName(out_node->op().BnInOp2Lbi("prob_0")), ReplaceInputLbnInOpCustomizedConf(&new_conf, "prob_0", new_prob_lbn)); job_builder->MutOpsOnlyOnce({new_conf}); } } } class SplitSparseSoftmaxCrossEntropyOpPass final : public JobPass { public: SplitSparseSoftmaxCrossEntropyOpPass() = default; ~SplitSparseSoftmaxCrossEntropyOpPass() override = default; Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!NeedDoPass(*job)) { return Maybe::Ok(); } const OpGraph op_graph(*job); JobBuilder job_builder(job); return Apply(op_graph, &job_builder); } }; Maybe SplitSparseSoftmaxCrossEntropyOpPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { std::vector to_del_op_names; HashMap consumer_op_name2op_confs; op_graph.ForEachNode([&](const OpNode* node) { const OperatorConf& op_conf = node->op().op_conf(); if (!op_conf.has_user_conf()) { return; } if (op_conf.user_conf().op_type_name() != "sparse_softmax_cross_entropy_ms") { return; } const int64_t scope_symbol_id = node->op().op_conf().scope_symbol_id(); user_op::UserOpConfWrapper cur_op(op_conf); const std::string& op_prediction_blob_name = cur_op.input("prediction", 0); const std::string& op_label_blob_name = cur_op.input("label", 0); const int32_t split_axis = node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi("prediction_0")).shape().NumAxes() - 1; const std::vector axis_vec(1, split_axis); const std::string& op_name = node->op().op_name(); const auto& prediction_nd_sbp = node->NdSbp4BnInOp("prediction_0"); NdSbp stat_distribution_for_consumer; bool has_split_axis_parallel = false; CHECK_EQ(prediction_nd_sbp.sbp_parallel_size(), node->parallel_desc().hierarchy()->NumAxes()); for (int64_t i = 0; i < node->parallel_desc().hierarchy()->NumAxes(); ++i) { const auto& sbp = prediction_nd_sbp.sbp_parallel(i); if (sbp.has_split_parallel() && sbp.split_parallel().axis() == split_axis) { has_split_axis_parallel = true; stat_distribution_for_consumer.add_sbp_parallel()->mutable_broadcast_parallel(); } else { CHECK(!sbp.has_partial_sum_parallel()); *stat_distribution_for_consumer.add_sbp_parallel() = SbpParallel(sbp); } } if (!has_split_axis_parallel) { return; } to_del_op_names.push_back(op_name); auto reduce_max_device_stage_op = user_op::UserOpConfWrapperBuilder(op_name + "-split_softmax_reduce_max_device_stage") .Op("reduce_max_device_stage") .Input("in", op_prediction_blob_name) .Output("out") .Output("mask") .Output("count") .Attr("axis", axis_vec) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {reduce_max_device_stage_op.op_conf()}); NdSbpSignature reduce_max_device_stage_signature; (*reduce_max_device_stage_signature.mutable_bn_in_op2nd_sbp())["in_0"] = NdSbp(prediction_nd_sbp); 
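// The out/mask/count signatures below keep the prediction's split layout; the global-stage op that
// follows is given stat_distribution_for_consumer, where the split on the class axis has been
// replaced by broadcast.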
(*reduce_max_device_stage_signature.mutable_bn_in_op2nd_sbp())["out_0"] = NdSbp(prediction_nd_sbp); (*reduce_max_device_stage_signature.mutable_bn_in_op2nd_sbp())["mask_0"] = NdSbp(prediction_nd_sbp); (*reduce_max_device_stage_signature.mutable_bn_in_op2nd_sbp())["count_0"] = NdSbp(prediction_nd_sbp); job_builder->AddNdSbpSignature4OpName(reduce_max_device_stage_op.op_name(), reduce_max_device_stage_signature); auto reduce_max_global_stage_op = user_op::UserOpConfWrapperBuilder(op_name + "-split_softmax_reduce_max_global_stage") .Op("reduce_max_global_stage") .Input("in", reduce_max_device_stage_op.output("out", 0)) .Input("device_count", reduce_max_device_stage_op.output("count", 0)) .Output("out") .Output("mask") .Attr("axis", axis_vec) .Attr("keepdims", true) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {reduce_max_global_stage_op.op_conf()}); NdSbpSignature reduce_max_global_stage_signature; (*reduce_max_global_stage_signature.mutable_bn_in_op2nd_sbp())["in_0"] = stat_distribution_for_consumer; (*reduce_max_global_stage_signature.mutable_bn_in_op2nd_sbp())["device_count_0"] = stat_distribution_for_consumer; (*reduce_max_global_stage_signature.mutable_bn_in_op2nd_sbp())["out_0"] = stat_distribution_for_consumer; job_builder->AddNdSbpSignature4OpName(reduce_max_global_stage_op.op_name(), reduce_max_global_stage_signature); auto broadcast_sub_max_op = user_op::UserOpConfWrapperBuilder(op_name + "-split_softmax_sub_max") .Op("broadcast_sub") .Input("x", op_prediction_blob_name) .Input("y", reduce_max_global_stage_op.output("out", 0)) .Output("z") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {broadcast_sub_max_op.op_conf()}); auto exp_op = user_op::UserOpConfWrapperBuilder(op_name + "-split_softmax_exp") .Op("exp") .Input("x", broadcast_sub_max_op.output("z", 0)) .Output("y") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {exp_op.op_conf()}); auto reduce_sum_op = user_op::UserOpConfWrapperBuilder(op_name + "-split_softmax_reduce_sum") .Op("reduce_sum") .Input("input_tensor", exp_op.output("y", 0)) .Output("output_tensor") .Attr("axis", axis_vec) .Attr("keepdims", true) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {reduce_sum_op.op_conf()}); std::string reduce_sum_op_out; if (node->parallel_desc().hierarchy()->NumAxes() > 1) { std::vector nd_sbp_conf; for (const auto& sbp_parallel : stat_distribution_for_consumer.sbp_parallel()) { nd_sbp_conf.emplace_back(SbpParallelToString(sbp_parallel)); } auto parallel_cast_sum_op = user_op::UserOpConfWrapperBuilder(op_name + "-split_softmax_reduce_sum_cast_P2B") .Op("hierarchical_parallel_cast") .Input("in", reduce_sum_op.output("output_tensor", 0)) .Output("out") .Attr>("nd_sbp", nd_sbp_conf) .Attr("grad_mode", "auto") .Attr>("grad_nd_sbp", std::vector()) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {parallel_cast_sum_op.op_conf()}); reduce_sum_op_out = parallel_cast_sum_op.output("out", 0); } else { reduce_sum_op_out = reduce_sum_op.output("output_tensor", 0); } auto broadcast_div_op = user_op::UserOpConfWrapperBuilder(op_name + "-split_softmax_div") .Op("broadcast_div") .Input("x", exp_op.output("y", 0)) .Input("y", reduce_sum_op_out) .Output("z") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {broadcast_div_op.op_conf()}); 
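// broadcast_div above yields the softmax probabilities; the ops below recompute the loss as
// log-softmax ((x - max) - log(sum)) fed into an nll op, and the original fused op's prob/out
// outputs are then rewired to these new producers.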
auto log_op = user_op::UserOpConfWrapperBuilder(op_name + "-log") .Op("log") .Input("x", reduce_sum_op_out) .Output("y") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {log_op.op_conf()}); auto broadcast_sub_op = user_op::UserOpConfWrapperBuilder(op_name + "-broadcast_add") .Op("broadcast_sub") .Input("x", broadcast_sub_max_op.output("z", 0)) .Input("y", log_op.output("y", 0)) .Output("z") .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {broadcast_sub_op.op_conf()}); auto nll_op = user_op::UserOpConfWrapperBuilder(op_name + "-nll") .Op("nll") .Input("input", broadcast_sub_op.output("z", 0)) .Input("target", op_label_blob_name) .Output("output") .Output("out_weight") .Attr("ignore_index", -100) .ScopeSymbolId(scope_symbol_id) .Build(); job_builder->AddOps(node->parallel_desc().parallel_conf(), {nll_op.op_conf()}); const std::string& prob_lbn = cur_op.output("prob", 0); const std::string& out_lbn = cur_op.output("out", 0); const std::string& new_prob_lbn = broadcast_div_op.output("z", 0); const std::string& new_out_lbn = nll_op.output("output", 0); for (const OpEdge* out_edge : node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); const std::string& consumer_op_name = consumer->op().op_name(); if (consumer_op_name2op_confs.find(consumer_op_name) == consumer_op_name2op_confs.end()) { consumer_op_name2op_confs[consumer_op_name] = consumer->op().op_conf(); } OperatorConf& consumer_op_conf = consumer_op_name2op_confs[consumer_op_name]; for (const std::string& ibn : consumer->op().input_bns()) { const std::string& input_lbn = GenLogicalBlobName(consumer->op().BnInOp2Lbi(ibn)); if (input_lbn == prob_lbn) { const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_prob_lbn); CHECK_EQ(old_lbn, prob_lbn); } else if (input_lbn == out_lbn) { const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_out_lbn); CHECK_EQ(old_lbn, out_lbn); } else { // does not care } } } }); for (const auto& pair : consumer_op_name2op_confs) { job_builder->MutOpsOnlyOnce({pair.second}); } job_builder->DelOps(to_del_op_names); return Maybe::Ok(); } REGISTER_JOB_PASS("SplitSparseSoftmaxCrossEntropyOpPass", SplitSparseSoftmaxCrossEntropyOpPass); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/system_op_fill_job_name_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { namespace { class SystemOpFillJobNamePass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(SystemOpFillJobNamePass); SystemOpFillJobNamePass() = default; ~SystemOpFillJobNamePass() override = default; bool IsEnabled(const JobPassCtx& ctx) const { return true; } Maybe Apply(Job* job, JobPassCtx* ctx) const override { const std::string& job_name = job->job_conf().job_name(); for (OperatorConf& op_conf : *job->mutable_net()->mutable_op()) { if (op_conf.has_input_conf()) { op_conf.mutable_input_conf()->set_job_name(job_name); } else if (op_conf.has_wait_and_send_ids_conf()) { op_conf.mutable_wait_and_send_ids_conf()->set_job_name(job_name); } else if (op_conf.has_output_conf()) { op_conf.mutable_output_conf()->set_job_name(job_name); } else if (op_conf.has_return_conf()) { op_conf.mutable_return_conf()->set_job_name(job_name); } else if (op_conf.has_callback_notify_conf()) { op_conf.mutable_callback_notify_conf()->set_job_name(job_name); } else { // do nothing } } return Maybe::Ok(); } }; REGISTER_JOB_PASS("SystemOpFillJobNamePass", SystemOpFillJobNamePass); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/tick_autotick.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/autotick.h" namespace oneflow { namespace { class MutTickOpConTickInputHelper final : public MutOpConTickInputHelper { public: MutTickOpConTickInputHelper() : MutOpConTickInputHelper() {} bool VirtualIsTickInputBound() const override { return op_conf().tick_conf().tick_size() > 0; } OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override { OperatorConf ret(op_conf()); ret.mutable_tick_conf()->add_tick(lbn); return ret; } }; } // namespace REGISTER_AUTO_TICK(OperatorConf::kTickConf, MutTickOpConTickInputHelper); } // namespace oneflow ================================================ FILE: oneflow/core/job_rewriter/variable_autotick.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job_rewriter/autotick.h" namespace oneflow { namespace { class MutVariableOpConTickInputHelper final : public MutOpConTickInputHelper { public: MutVariableOpConTickInputHelper() : MutOpConTickInputHelper() {} bool VirtualIsTickInputBound() const override { return op_conf().variable_conf().has_tick(); } OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override { OperatorConf ret(op_conf()); ret.mutable_variable_conf()->set_tick(lbn); return ret; } }; } // namespace REGISTER_AUTO_TICK(OperatorConf::kVariableConf, MutVariableOpConTickInputHelper); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/assign_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" namespace oneflow { class AssignKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(AssignKernel); AssignKernel() = default; ~AssignKernel() override = default; private: bool IsStateless() const override { return false; } void ForwardDataContent(KernelContext* ctx) const override; }; void AssignKernel::ForwardDataContent(KernelContext* ctx) const { const Blob* value = ctx->BnInOp2Blob("value"); Blob* ref = ctx->BnInOp2Blob("ref"); AutoMemcpy(ctx->stream(), ref, value); } REGISTER_KERNEL(OperatorConf::kAssignConf, AssignKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/blob_access_checker_kernel_observer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/kernel/blob_access_checker_kernel_observer.h" #include "oneflow/core/kernel/kernel.h" namespace oneflow { namespace { template void ForEachObnAndIsHeaderInferedBeforeCompute(KernelContext* kernel_ctx, const Kernel* kernel, const HandlerT& Handler) { const auto& modifier_map = kernel->op_attribute().arg_modifier_signature().obn2output_blob_modifier(); for (const std::string& obn : kernel->op_attribute().output_bns()) { Blob* blob = kernel_ctx->BnInOp2Blob(obn); if (blob) { bool is_header_infered_before_compute = modifier_map.at(obn).header_infered_before_compute(); Handler(obn, is_header_infered_before_compute); } } } template void ForEachObnAndIsMutableByConsumer(KernelContext* kernel_ctx, const Kernel* kernel, const HandlerT& Handler) { const auto& modifier_map = kernel->op_attribute().arg_modifier_signature().obn2output_blob_modifier(); for (const std::string& obn : kernel->op_attribute().output_bns()) { Blob* blob = kernel_ctx->BnInOp2Blob(obn); if (blob) { bool is_mutable_by_consumer = modifier_map.at(obn).is_mutable(); Handler(obn, is_mutable_by_consumer); } } } void SetOutputBlobProducerInferAccessChecker(KernelContext* kernel_ctx, const Kernel* kernel) { ForEachObnAndIsHeaderInferedBeforeCompute( kernel_ctx, kernel, [&](const std::string& obn, bool _) { kernel_ctx->BnInOp2Blob(obn)->set_blob_access_checker( Singleton>::Get()); }); } void SetOutputBlobProducerComputeAccessChecker(KernelContext* kernel_ctx, const Kernel* kernel) { ForEachObnAndIsHeaderInferedBeforeCompute( kernel_ctx, kernel, [&](const std::string& obn, bool is_header_infered_before_compute) { const BlobAccessChecker* checker = nullptr; if (is_header_infered_before_compute) { checker = Singleton>::Get(); } else { checker = Singleton>::Get(); } kernel_ctx->BnInOp2Blob(obn)->set_blob_access_checker(checker); }); } void SetOutputBlobConsumerAccessChecker(KernelContext* kernel_ctx, const Kernel* kernel) { ForEachObnAndIsMutableByConsumer( kernel_ctx, kernel, [&](const std::string& obn, bool is_mutable) { const BlobAccessChecker* checker = nullptr; if (is_mutable) { checker = Singleton>::Get(); } else { checker = Singleton>::Get(); } kernel_ctx->BnInOp2Blob(obn)->set_blob_access_checker(checker); }); } } // namespace void BlobAccessCheckerKernelObserver::WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) { SetOutputBlobProducerInferAccessChecker(kernel_ctx, kernel); } void BlobAccessCheckerKernelObserver::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { SetOutputBlobProducerComputeAccessChecker(kernel_ctx, kernel); } void BlobAccessCheckerKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { SetOutputBlobConsumerAccessChecker(kernel_ctx, kernel); } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/blob_access_checker_kernel_observer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_KERNEL_BLOB_ACCESS_CHECKER_KERNEL_OBSERVER_H_ #define ONEFLOW_CORE_KERNEL_BLOB_ACCESS_CHECKER_KERNEL_OBSERVER_H_ #include "oneflow/core/kernel/kernel_observer.h" namespace oneflow { class BlobAccessCheckerKernelObserver final : public KernelObserver { public: OF_DISALLOW_COPY_AND_MOVE(BlobAccessCheckerKernelObserver); BlobAccessCheckerKernelObserver() = default; ~BlobAccessCheckerKernelObserver() override = default; void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override; void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_BLOB_ACCESS_CHECKER_KERNEL_OBSERVER_H_ ================================================ FILE: oneflow/core/kernel/blob_tensor_view.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/blob_tensor_view.h" #include "oneflow/core/register/blob.h" namespace oneflow { namespace user_op { BlobTensorView::BlobTensorView(Blob* blob) : blob_(blob) {} ShapeView BlobTensorView::shape_view() const { return blob_->shape(); } MutShapeView BlobTensorView::mut_shape_view() { return *blob_->mut_shape_view(); } const Stride& BlobTensorView::stride() const { return blob_->stride(); } DataType BlobTensorView::data_type() const { return blob_->data_type(); } MemoryFormat BlobTensorView::memory_format() const { return blob_->memory_format(); } const MemoryCase& BlobTensorView::mem_case() const { return blob_->mem_case(); } const void* BlobTensorView::raw_dptr() const { return blob_->dptr(); } void* BlobTensorView::mut_raw_dptr() { return blob_->mut_dptr(); } void BlobTensorView::Reset(Blob* blob) { blob_ = blob; } Blob* BlobTensorView::blob() const { return blob_; } } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/core/kernel/blob_tensor_view.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
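
BlobTensorView above adapts a Blob to the user_op::Tensor interface without copying data. A short sketch of how such a view is used; blob and another_blob are assumed to come from ctx->BnInOp2Blob(...) inside a kernel.

// Sketch: viewing a Blob through the user_op::Tensor interface.
// `blob` and `another_blob` are assumed to come from a KernelContext.
user_op::BlobTensorView view(blob);
const int64_t elem_cnt = view.shape_view().elem_cnt();  // header info is forwarded from the Blob
const DataType dtype = view.data_type();
void* dptr = view.mut_raw_dptr();                       // same memory as blob->mut_dptr()
view.Reset(another_blob);                               // re-point the view; no data is copied
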
*/ #ifndef ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_ #define ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_ #include "oneflow/core/framework/user_op_tensor.h" namespace oneflow { class Blob; namespace user_op { class BlobTensorView final : public Tensor { public: explicit BlobTensorView(Blob* blob); ~BlobTensorView() = default; ShapeView shape_view() const override; MutShapeView mut_shape_view() override; const Stride& stride() const override; DataType data_type() const override; MemoryFormat memory_format() const override; const MemoryCase& mem_case() const override; const void* raw_dptr() const override; void* mut_raw_dptr() override; void Reset(Blob* blob); Blob* blob() const; private: Blob* blob_; }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_ ================================================ FILE: oneflow/core/kernel/boxing_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/operator/op_conf_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/ep/include/primitive/add.h" namespace oneflow { template class BoxingKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(BoxingKernel); BoxingKernel() = default; ~BoxingKernel() = default; private: void VirtualKernelInit(KernelContext* ctx) override; void ForwardDataContent(KernelContext* ctx) const override; PbRpf ibn_0_; PbRpf obn_0_; }; namespace { PbRpf ConstructPbRpf(const std::string& s) { PbRpf ret; ret.Reserve(1); ret.Add()->assign(s); return ret; } template void CalcSumOfBlobs(KernelContext* ctx, const std::function& BnInOp2Blob, const PbRpf& src_bns, const std::string& dst_bn) { Blob* dst_blob = BnInOp2Blob(dst_bn); std::unique_ptr primitive = ep::primitive::NewPrimitive(DeviceType::kCPU, dst_blob->data_type()); CHECK(primitive); std::vector srcs(src_bns.size()); FOR_RANGE(size_t, i, 0, src_bns.size()) { Blob* src_blob_i = BnInOp2Blob(src_bns.Get(i)); srcs[i] = src_blob_i->dptr(); } primitive->Launch(ctx->stream(), srcs.data(), srcs.size(), dst_blob->mut_dptr(), dst_blob->static_shape().elem_cnt()); } void CopyFromFirstToOtherBlobs(KernelContext* ctx, const std::function& BnInOp2Blob, const PbRpf& bns) { const Blob* blob_0 = BnInOp2Blob(bns.Get(0)); FOR_RANGE(size_t, i, 1, bns.size()) { AutoMemcpy(ctx->stream(), BnInOp2Blob(bns.Get(i)), blob_0); } } class DataContentDesc final { public: OF_DISALLOW_COPY_AND_MOVE(DataContentDesc); DataContentDesc() = delete; ~DataContentDesc() = default; DataContentDesc(std::function BnInOp2Blob, const PbRpf* bns, int32_t axis) { BnInOp2Blob_ = BnInOp2Blob; seg_num_ = BnInOp2Blob(bns->Get(0))->static_shape().Count(0, axis); elem_sum_.assign(bns->size(), 0); FOR_RANGE(size_t, i, 0, elem_sum_.size()) { elem_sum_[i] = 
BnInOp2Blob(bns->Get(i))->static_shape().Count(axis); if (i > 0) { elem_sum_[i] += elem_sum_[i - 1]; } } bns_ = bns; axis_ = axis; } size_t OneElemSize() const { return GetSizeOfDataType(BnInOp2Blob_(bns_->Get(0))->data_type()); } int64_t TotalElemNum() const { return seg_num_ * elem_sum_.back(); } template std::tuple CalcContinuousElemNumStartFrom(int64_t idx) const { std::tuple ret(0, nullptr); int64_t seg_idx = idx / elem_sum_.back(); int64_t idx_in_seg = idx % elem_sum_.back(); auto elem_sum_it = std::upper_bound(elem_sum_.begin(), elem_sum_.end(), idx_in_seg); CHECK(elem_sum_it != elem_sum_.end()); std::get<0>(ret) = *elem_sum_it - idx_in_seg; int64_t bn_idx = elem_sum_it - elem_sum_.begin(); int64_t idx_in_blob = idx_in_seg; if (bn_idx > 0) { idx_in_blob -= elem_sum_[bn_idx - 1]; } Blob* blob = BnInOp2Blob_(bns_->Get(bn_idx)); std::get<1>(ret) = GetDptrT(blob) + (seg_idx * blob->static_shape().Count(axis_) + idx_in_blob) * GetSizeOfDataType(blob->data_type()); return ret; } private: std::function BnInOp2Blob_; int64_t seg_num_; std::vector elem_sum_; const PbRpf* bns_; int32_t axis_; }; static const char* GetConstDptr(Blob* blob) { return blob->dptr(); } static char* GetMutDptr(Blob* blob) { return blob->mut_dptr(); } void ConcatSplitPartDataContent(ep::Stream* stream, const DataContentDesc& in_desc, const DataContentDesc& out_desc, int32_t part_id, int32_t part_num) { size_t one_elem_size = in_desc.OneElemSize(); BalancedSplitter bs(in_desc.TotalElemNum(), part_num); Range range = bs.At(part_id); int64_t in_idx = range.begin(); int64_t in_elem_num = 0; const char* in_ptr = nullptr; int64_t out_idx = range.begin(); int64_t out_elem_num = 0; char* out_ptr = nullptr; while (in_elem_num > 0 || out_elem_num > 0 || in_idx < range.end() || out_idx < range.end()) { if (in_elem_num == 0) { std::tie(in_elem_num, in_ptr) = in_desc.CalcContinuousElemNumStartFrom(in_idx); in_elem_num = std::min(in_elem_num, range.end() - in_idx); if (in_elem_num == 0) { break; } in_idx += in_elem_num; } if (out_elem_num == 0) { std::tie(out_elem_num, out_ptr) = out_desc.CalcContinuousElemNumStartFrom(out_idx); out_elem_num = std::min(out_elem_num, range.end() - out_idx); if (out_elem_num == 0) { break; } out_idx += out_elem_num; } int64_t copy_elem_num = std::min(in_elem_num, out_elem_num); size_t copy_size = copy_elem_num * one_elem_size; Memcpy(stream, out_ptr, in_ptr, copy_size); in_elem_num -= copy_elem_num; out_elem_num -= copy_elem_num; in_ptr += copy_size; out_ptr += copy_size; } CHECK_EQ(in_elem_num, 0); CHECK_EQ(out_elem_num, 0); CHECK_EQ(in_idx, range.end()); CHECK_EQ(out_idx, range.end()); } void ConcatSplitDataContent(ep::Stream* stream, const std::function& BnInOp2Blob, const PbRpf& concat_bns, int32_t concat_axis, const PbRpf& split_bns, int32_t split_axis) { DataContentDesc in_desc(BnInOp2Blob, &concat_bns, concat_axis); DataContentDesc out_desc(BnInOp2Blob, &split_bns, split_axis); CHECK_EQ(in_desc.TotalElemNum(), out_desc.TotalElemNum()); CHECK_EQ(in_desc.OneElemSize(), out_desc.OneElemSize()); static const size_t min_byte_one_part = 128; int32_t part_num = in_desc.TotalElemNum() * in_desc.OneElemSize() / min_byte_one_part; part_num = std::min(part_num, Singleton::Get()->thread_num()); if (part_num >= 2) { BlockingCounter bc(part_num); FOR_RANGE(int32_t, part_id, 0, part_num) { Singleton::Get()->AddWork( [stream, &in_desc, &out_desc, part_id, &part_num, &bc]() { ConcatSplitPartDataContent(stream, in_desc, out_desc, part_id, part_num); bc.Decrease(); }); } bc.WaitForeverUntilCntEqualZero(); } 
else { ConcatSplitPartDataContent(stream, in_desc, out_desc, 0, 1); } } } // namespace template void BoxingKernel::VirtualKernelInit(KernelContext* ctx) { const std::string& ibn_0 = op_attribute().input_bns(0); const std::string& obn_0 = op_attribute().output_bns(0); ibn_0_ = ConstructPbRpf(ibn_0); obn_0_ = ConstructPbRpf(obn_0); } template void BoxingKernel::ForwardDataContent(KernelContext* ctx) const { const BoxingOpConf& boxing_conf = op_conf().boxing_conf(); ep::Stream* stream = ctx->stream(); const auto BnInOp2Blob = [ctx](const std::string& bn) { return ctx->BnInOp2Blob(bn); }; if (boxing_conf.in_box_case() == BoxingOpConf::kConcatBox) { if (boxing_conf.out_box_case() == BoxingOpConf::kSplitBox) { ConcatSplitDataContent(stream, BnInOp2Blob, op_attribute().input_bns(), boxing_conf.concat_box().axis(), op_attribute().output_bns(), boxing_conf.split_box().axis()); } else if (boxing_conf.out_box_case() == BoxingOpConf::kCloneBox) { ConcatSplitDataContent(stream, BnInOp2Blob, op_attribute().input_bns(), boxing_conf.concat_box().axis(), obn_0_, 0); CopyFromFirstToOtherBlobs(ctx, BnInOp2Blob, op_attribute().output_bns()); } else { UNIMPLEMENTED(); } } else if (boxing_conf.in_box_case() == BoxingOpConf::kAddBox) { if (boxing_conf.out_box_case() == BoxingOpConf::kSplitBox) { CalcSumOfBlobs(ctx, BnInOp2Blob, op_attribute().input_bns(), "middle"); ConcatSplitDataContent(stream, BnInOp2Blob, ConstructPbRpf("middle"), 0, op_attribute().output_bns(), boxing_conf.split_box().axis()); } else if (boxing_conf.out_box_case() == BoxingOpConf::kCloneBox) { CalcSumOfBlobs(ctx, BnInOp2Blob, op_attribute().input_bns(), obn_0_.Get(0)); CopyFromFirstToOtherBlobs(ctx, BnInOp2Blob, op_attribute().output_bns()); } else { UNIMPLEMENTED(); } } else { UNIMPLEMENTED(); } } ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kBoxingConf, BoxingKernel, ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/boxing_zeros_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
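
ConcatSplitDataContent above splits the copy into parts, runs them on the global thread pool, and joins with a BlockingCounter. A sketch of that fan-out/join pattern in isolation; the singleton's template argument is stripped in this dump, so ThreadPool is assumed here, and DoPart(), total_elem_num and part_num are illustrative stand-ins.

// Sketch of the thread-pool fan-out/join used by ConcatSplitDataContent above.
// ThreadPool is an assumed singleton type; DoPart() stands in for ConcatSplitPartDataContent.
BalancedSplitter bs(total_elem_num, part_num);  // partitions [0, total_elem_num) into part_num ranges
BlockingCounter bc(part_num);
FOR_RANGE(int32_t, part_id, 0, part_num) {
  Singleton<ThreadPool>::Get()->AddWork([part_id, part_num, &bc]() {
    DoPart(part_id, part_num);  // each worker handles bs.At(part_id)
    bc.Decrease();
  });
}
bc.WaitForeverUntilCntEqualZero();  // block until every part has finished
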
*/ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/kernel_context.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { class BoxingZerosKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(BoxingZerosKernel); BoxingZerosKernel() = default; ~BoxingZerosKernel() override = default; private: void VirtualKernelInit(KernelContext* ctx) override; void ForwardDataContent(KernelContext* ctx) const override; std::unique_ptr primitive_; }; void BoxingZerosKernel::VirtualKernelInit(KernelContext* ctx) { primitive_ = ep::primitive::NewPrimitive(ctx->stream()->device_type()); CHECK(primitive_); } void BoxingZerosKernel::ForwardDataContent(KernelContext* ctx) const { Blob* out = ctx->BnInOp2Blob("out"); primitive_->Launch(ctx->stream(), out->mut_dptr(), 0, out->ByteSizeOfBlobBody()); } REGISTER_KERNEL(OperatorConf::kBoxingZerosConf, BoxingZerosKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/broadcast_to_compatible_with_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { template class BroadcastToCompatibleWithKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastToCompatibleWithKernel); BroadcastToCompatibleWithKernel() = default; ~BroadcastToCompatibleWithKernel() = default; private: void ForwardDataContent(KernelContext* ctx) const override; }; template void BroadcastToCompatibleWithKernel::ForwardDataContent(KernelContext* ctx) const { const Blob* x = ctx->BnInOp2Blob("x"); Blob* y = ctx->BnInOp2Blob("y"); const auto& broadcast_axes = this->kernel_conf().broadcast_to_compatible_with_conf().broadcast_axes(); int64_t num_axes = y->shape().NumAxes(); Shape x_extend_shape = CreateLeftExtendedShape(x->shape(), num_axes); FOR_RANGE(int64_t, i, 0, num_axes) { if (std::find(broadcast_axes.begin(), broadcast_axes.end(), i) == broadcast_axes.end()) { CHECK_EQ(x_extend_shape.At(i), y->shape().At(i)); } else { CHECK_EQ(x_extend_shape.At(i), 1); } } NdarrayUtil::BroadcastTo(ctx->stream(), XpuVarNdarray(y, num_axes), XpuVarNdarray(x, num_axes)); } #define REGISTTER_BROADCAST_TO_COMPATIBLE_WITH_KERNEL(device_type_v, dtype_pair) \ REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE( \ OperatorConf::kBroadcastToCompatibleWithConf, device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \ BroadcastToCompatibleWithKernel) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTTER_BROADCAST_TO_COMPATIBLE_WITH_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ) #if defined(WITH_CUDA) REGISTTER_BROADCAST_TO_COMPATIBLE_WITH_KERNEL(DeviceType::kCUDA, (float16, DataType::kFloat16)) #endif } // namespace oneflow ================================================ FILE: oneflow/core/kernel/callback_notify_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/job_instance.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/buffer_manager.h" namespace oneflow { template class CallbackNotifyKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(CallbackNotifyKernel); CallbackNotifyKernel() = default; ~CallbackNotifyKernel() = default; private: bool IsStateless() const override { return false; } void ForwardDataContent(KernelContext* ctx) const override; }; template void CallbackNotifyKernel::ForwardDataContent(KernelContext* ctx) const { auto* buffer_mgr = Singleton>>::Get(); std::string buffer_name; CHECK(this->op_conf().callback_notify_conf().has_job_name()); buffer_name = GetCallbackNotifierBufferName(this->op_conf().callback_notify_conf().job_name()); std::shared_ptr job_instance; BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->TryReceive(&job_instance); CHECK_NE(buffer_status, kBufferStatusEmpty); if (buffer_status == kBufferStatusSuccess) { job_instance->Finish(); } } ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kCallbackNotifyConf, CallbackNotifyKernel, INT_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/case_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
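
CallbackNotifyKernel above pops a JobInstance from a per-job named buffer and finishes it. A sketch of the buffer-manager access it relies on; the BufferMgr template arguments are assumptions, since this dump strips them, and job_name is assumed to come from the op conf.

// Sketch of the named-buffer receive used by CallbackNotifyKernel above.
// BufferMgr<std::shared_ptr<JobInstance>> is an assumed instantiation.
auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
const std::string buffer_name = GetCallbackNotifierBufferName(job_name);
std::shared_ptr<JobInstance> job_instance;
BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->TryReceive(&job_instance);
CHECK_NE(buffer_status, kBufferStatusEmpty);  // the graph guarantees a sender exists
if (buffer_status == kBufferStatusSuccess) { job_instance->Finish(); }
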
*/ #include "oneflow/core/kernel/case_kernel.h" #include "oneflow/core/operator/operator.h" namespace oneflow { template void CaseKernel::VirtualKernelInit(KernelContext* ctx) { ctx->set_state(std::make_shared()); } template void CaseKernel::ForwardDataContent(KernelContext* ctx) const { CaseStatus* const case_status = CHECK_NOTNULL(dynamic_cast(ctx->state().get())); if (case_status->cmd == kCaseCmdHandleInput) { int64_t cur_selected_id = static_cast(ctx->BnInOp2Blob("in")->dptr()[0]); case_status->select_id2request_cnt[cur_selected_id] += 1; } else if (case_status->cmd == kCaseCmdHandleOutput) { int64_t cur_selected_id = case_status->cur_selected_id; CHECK_GT(case_status->select_id2request_cnt[cur_selected_id], 0); case_status->select_id2request_cnt[cur_selected_id] -= 1; if (case_status->select_id2request_cnt[cur_selected_id] == 0) { case_status->select_id2request_cnt.erase(cur_selected_id); } *(ctx->BnInOp2Blob(GenRepeatedBn("out", cur_selected_id))->mut_dptr()) = cur_selected_id; } else { UNIMPLEMENTED(); } } ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kCaseConf, CaseKernel, INT_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/core/kernel/case_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_CASE_KERNEL_H_ #define ONEFLOW_CORE_KERNEL_CASE_KERNEL_H_ #include "oneflow/core/kernel/kernel.h" namespace oneflow { enum CaseCmd { kCaseCmdInvalid = 0, kCaseCmdHandleInput = 1, kCaseCmdHandleOutput = 2, }; struct CaseStatus final : public KernelState { CaseStatus() : cmd(kCaseCmdInvalid), cur_selected_id(-1) {} ~CaseStatus() = default; CaseCmd cmd; int64_t cur_selected_id; HashMap select_id2request_cnt; }; template class CaseKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(CaseKernel); CaseKernel() = default; ~CaseKernel() override = default; private: void VirtualKernelInit(KernelContext* ctx) override; void ForwardDataContent(KernelContext* ctx) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_CASE_KERNEL_H_ ================================================ FILE: oneflow/core/kernel/chain_kernel_observer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/kernel/chain_kernel_observer.h" #include "oneflow/core/kernel/kernel.h" namespace oneflow { void ChainKernelObserver::WillForward(KernelContext* kernel_ctx, const Kernel* kernel) { for (const auto& observer : kernel_observers_) { observer->WillForward(kernel_ctx, kernel); } } void ChainKernelObserver::DidForward(KernelContext* kernel_ctx, const Kernel* kernel) { for (const auto& observer : kernel_observers_) { observer->DidForward(kernel_ctx, kernel); } } void ChainKernelObserver::WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) { for (const auto& observer : kernel_observers_) { observer->WillForwardHeader(kernel_ctx, kernel); } } void ChainKernelObserver::DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) { for (const auto& observer : kernel_observers_) { observer->DidForwardHeader(kernel_ctx, kernel); } } void ChainKernelObserver::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { for (const auto& observer : kernel_observers_) { observer->WillForwardDataContent(kernel_ctx, kernel); } } void ChainKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { for (const auto& observer : kernel_observers_) { observer->DidForwardDataContent(kernel_ctx, kernel); } } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/chain_kernel_observer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_ #define ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_ #include "oneflow/core/kernel/kernel_observer.h" namespace oneflow { class ChainKernelObserver final : public KernelObserver { public: OF_DISALLOW_COPY_AND_MOVE(ChainKernelObserver); explicit ChainKernelObserver(std::vector> kernel_observers) : kernel_observers_(std::move(kernel_observers)) {} ~ChainKernelObserver() override = default; void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override; void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override; void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override; void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override; void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; private: std::vector> kernel_observers_; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_ ================================================ FILE: oneflow/core/kernel/collective_boxing_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/job/collective_boxing/scheduler.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" #include "oneflow/core/lazy/actor/collective_boxing_actor_context.h" namespace oneflow { using namespace boxing::collective; namespace { CollectiveBoxingActorContext* GetCollectiveBoxingActorContext(KernelContext* kernel_ctx) { auto* actor_context_provider = CHECK_NOTNULL(dynamic_cast(kernel_ctx)); return CHECK_NOTNULL( dynamic_cast(actor_context_provider->GetActorContext())); } class CollectiveBoxingKernelState final : public KernelState { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingKernelState); explicit CollectiveBoxingKernelState(const RankDesc& rank_desc) : request_handle_(Singleton::Get()->CreateRequestHandle(rank_desc)) {} ~CollectiveBoxingKernelState() override { Singleton::Get()->DestroyRequestHandle(request_handle_); } RequestHandle* request_handle() { return request_handle_; } private: RequestHandle* request_handle_ = nullptr; }; class CollectiveBoxingGenericKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericKernel); CollectiveBoxingGenericKernel() = default; ~CollectiveBoxingGenericKernel() override = default; private: void VirtualKernelInit(KernelContext* ctx) override; bool IsKernelLaunchSynchronized() const override { return false; } void ForwardDataContent(KernelContext* ctx) const override; }; void CollectiveBoxingGenericKernel::VirtualKernelInit(KernelContext* ctx) { const RankDesc& rank_desc = this->op_conf().collective_boxing_generic_conf().rank_desc(); ctx->set_state(std::make_shared(rank_desc)); } void CollectiveBoxingGenericKernel::ForwardDataContent(KernelContext* ctx) const { RequestHandle* request_handle = CHECK_NOTNULL(dynamic_cast(ctx->state().get())) ->request_handle(); const void* send_buff = nullptr; void* recv_buff = nullptr; const RankDesc& rank_desc = this->op_conf().collective_boxing_generic_conf().rank_desc(); const DataType data_type = rank_desc.op_desc().data_type(); if (GenericOpHasInput(rank_desc)) { const Blob* in = ctx->BnInOp2Blob("in"); CHECK_EQ(in->data_type(), data_type); CHECK(in->shape() == ShapeView(GenericOpGetInputShape(rank_desc))); send_buff = in->dptr(); } if (GenericOpHasOutput(rank_desc)) { Blob* out = ctx->BnInOp2Blob("out"); CHECK_EQ(out->data_type(), data_type); CHECK(out->shape() == ShapeView(GenericOpGetOutputShape(rank_desc))); recv_buff = out->mut_dptr(); } auto* actor_ctx = GetCollectiveBoxingActorContext(ctx); actor_ctx->Schedule(request_handle, send_buff, recv_buff); } REGISTER_KERNEL(OperatorConf::kCollectiveBoxingGenericConf, CollectiveBoxingGenericKernel); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/kernel/collective_boxing_pack_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/include/primitive/permute.h" namespace oneflow { class CollectiveBoxingPackKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackKernel); CollectiveBoxingPackKernel() = default; ~CollectiveBoxingPackKernel() override = default; private: bool IsStateless() const override { return false; } void ForwardDataContent(KernelContext* ctx) const override; }; void CollectiveBoxingPackKernel::ForwardDataContent(KernelContext* ctx) const { const Blob* in = ctx->BnInOp2Blob("in"); Blob* out = ctx->BnInOp2Blob("out"); const CollectiveBoxingPackOpConf& pack_conf = this->op_conf().collective_boxing_pack_conf(); const int64_t num_ranks = pack_conf.num_ranks(); const Shape logical_shape(pack_conf.logical_shape()); const bool need_transpose = !((pack_conf.dst_sbp_parallel().has_split_parallel() && pack_conf.dst_sbp_parallel().split_parallel().axis() == 0) || pack_conf.dst_sbp_parallel().has_broadcast_parallel() || pack_conf.dst_sbp_parallel().has_partial_sum_parallel()); if (need_transpose) { const int64_t dst_split_axis = pack_conf.dst_sbp_parallel().split_parallel().axis(); DimVector transpose_in_dim_vec = logical_shape.dim_vec(); if (pack_conf.src_sbp_parallel().has_split_parallel()) { const int64_t src_split_axis = pack_conf.src_sbp_parallel().split_parallel().axis(); transpose_in_dim_vec[src_split_axis] = transpose_in_dim_vec.at(src_split_axis) / num_ranks; } CHECK_EQ(transpose_in_dim_vec.at(dst_split_axis) % num_ranks, 0); transpose_in_dim_vec[dst_split_axis] = transpose_in_dim_vec.at(dst_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + dst_split_axis, num_ranks); std::vector perm; perm.emplace_back(dst_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != dst_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), out->mut_dptr()); } else { AutoMemcpy(ctx->stream(), out, in); } } REGISTER_KERNEL(OperatorConf::kCollectiveBoxingPackConf, CollectiveBoxingPackKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/collective_boxing_unpack_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
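
CollectiveBoxingPackKernel above only needs a transpose when the destination split axis is not 0, and performs it with the ep Permute primitive. A sketch of that primitive call; ep::primitive::PermuteFactory is an assumption (template arguments are stripped in this dump), and the dims/perm values are illustrative.

// Sketch: launching the ep Permute primitive as the pack/unpack kernels above do.
// ep::primitive::PermuteFactory is an assumption; dims and perm are illustrative values.
const int64_t dims[3] = {4, 2, 3};
const int32_t perm[3] = {1, 0, 2};
auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(
    ctx->stream()->device_type(), /*max_num_dims=*/3);
CHECK(transpose);
transpose->Launch(ctx->stream(), in->data_type(), /*num_dims=*/3, dims,
                  in->dptr(), perm, out->mut_dptr());
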
*/ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/include/primitive/permute.h" namespace oneflow { class CollectiveBoxingUnpackKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackKernel); CollectiveBoxingUnpackKernel() = default; ~CollectiveBoxingUnpackKernel() override = default; private: bool IsStateless() const override { return false; } void ForwardDataContent(KernelContext* ctx) const override; }; void CollectiveBoxingUnpackKernel::ForwardDataContent(KernelContext* ctx) const { const Blob* in = ctx->BnInOp2Blob("in"); Blob* out = ctx->BnInOp2Blob("out"); const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf(); const int64_t num_ranks = unpack_conf.num_ranks(); const Shape logical_shape(unpack_conf.logical_shape()); // skip 0size tensor boxing if (logical_shape.elem_cnt() == 0) { return; } const bool need_transpose = !((unpack_conf.src_sbp_parallel().has_split_parallel() && unpack_conf.src_sbp_parallel().split_parallel().axis() == 0) || unpack_conf.src_sbp_parallel().has_broadcast_parallel() || unpack_conf.src_sbp_parallel().has_partial_sum_parallel()); if (need_transpose) { const int64_t src_split_axis = unpack_conf.src_sbp_parallel().split_parallel().axis(); DimVector transpose_in_dim_vec = logical_shape.dim_vec(); CHECK_EQ(transpose_in_dim_vec.at(src_split_axis) % num_ranks, 0); transpose_in_dim_vec[src_split_axis] = transpose_in_dim_vec.at(src_split_axis) / num_ranks; if (unpack_conf.dst_sbp_parallel().has_split_parallel()) { const int64_t dst_split_axis = unpack_conf.dst_sbp_parallel().split_parallel().axis(); CHECK_EQ(transpose_in_dim_vec.at(dst_split_axis) % num_ranks, 0); transpose_in_dim_vec[dst_split_axis] = transpose_in_dim_vec.at(dst_split_axis) / num_ranks; } transpose_in_dim_vec.insert(transpose_in_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, transpose_in_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + src_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), out->mut_dptr()); } else { AutoMemcpy(ctx->stream(), out, in); } } REGISTER_KERNEL(OperatorConf::kCollectiveBoxingUnpackConf, CollectiveBoxingUnpackKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/constant_like_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/ep/include/primitive/fill.h" namespace oneflow { class ConstantLikeKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(ConstantLikeKernel); ConstantLikeKernel() : is_init_(false) {} ~ConstantLikeKernel() = default; private: mutable bool is_init_; void ForwardDataContent(KernelContext* ctx) const override { if (is_init_) { return; } Blob* out_blob = ctx->BnInOp2Blob("out"); Scalar value; const auto& conf = this->op_conf().constant_like_conf(); if (conf.has_int_operand()) { value = Scalar(conf.int_operand()); } else if (conf.has_float_operand()) { value = Scalar(conf.float_operand()); } else { UNIMPLEMENTED(); } std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type(), out_blob->data_type()); CHECK(primitive); primitive->Launch(ctx->stream(), out_blob->mut_dptr(), value, out_blob->static_shape().elem_cnt()); is_init_ = true; } }; REGISTER_KERNEL(OperatorConf::kConstantLikeConf, ConstantLikeKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/cpu_check_numerics_kernel_observer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_CPU_CHECK_NUMERICS_KERNEL_OBSERVER_H_ #define ONEFLOW_CORE_KERNEL_CPU_CHECK_NUMERICS_KERNEL_OBSERVER_H_ #include "oneflow/core/kernel/kernel_observer.h" namespace oneflow { class CpuCheckNumericsKernelObserver final : public KernelObserver { public: OF_DISALLOW_COPY_AND_MOVE(CpuCheckNumericsKernelObserver); CpuCheckNumericsKernelObserver() = default; ~CpuCheckNumericsKernelObserver() override = default; void DidForwardDataContent(KernelContext* ctx, const Kernel* kernel) override; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_CPU_CHECK_NUMERICS_KERNEL_OBSERVER_H_ ================================================ FILE: oneflow/core/kernel/cpu_numerics_kernel_observer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/kernel/cpu_check_numerics_kernel_observer.h" #include "oneflow/core/kernel/kernel.h" namespace oneflow { namespace { template bool HasNotFinite(const int64_t elem_cnt, const T* data_ptr) { FOR_RANGE(int64_t, i, 0, elem_cnt) { if (!std::isfinite(data_ptr[i])) { return true; } } return false; } bool HasNotFiniteCpu(ep::Stream* stream, const Blob* blob) { const DataType dtype = blob->data_type(); const int64_t elem_cnt = blob->shape().elem_cnt(); if (dtype == kFloat) { return HasNotFinite(elem_cnt, blob->dptr()); } else if (dtype == kDouble) { return HasNotFinite(elem_cnt, blob->dptr()); } else { return false; } } void DumpBlob(KernelContext* ctx, const std::string& bn) { Blob* blob = ctx->BnInOp2Blob(bn); if (blob != nullptr) { std::ofstream ofs(bn); ofs.write(blob->dptr(), blob->ByteSizeOfBlobBody()); } } void DumpBlobs(KernelContext* ctx, const Kernel* kernel) { for (const auto& obn : kernel->op_attribute().output_bns()) { DumpBlob(ctx, obn); } for (const auto& ibn : kernel->op_attribute().input_bns()) { DumpBlob(ctx, ibn); } } } // namespace void CpuCheckNumericsKernelObserver::DidForwardDataContent(KernelContext* ctx, const Kernel* kernel) { for (const auto& obn : kernel->op_attribute().output_bns()) { Blob* blob = ctx->BnInOp2Blob(obn); if (blob != nullptr) { bool has_not_finite = HasNotFiniteCpu(ctx->stream(), blob); if (has_not_finite && ParseBooleanFromEnv("ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS_DUMP", false)) { DumpBlobs(ctx, kernel); } CHECK(!has_not_finite) << kernel->op_conf().name() << " : " << obn << " has nan or inf"; } } } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/critical_section_callback_tick_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/buffer_manager.h" namespace oneflow { class CriticalSectionCallbackTickKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(CriticalSectionCallbackTickKernel); CriticalSectionCallbackTickKernel() = default; ~CriticalSectionCallbackTickKernel() = default; private: bool IsStateless() const override { return false; } void ForwardDataContent(KernelContext* ctx) const override; }; void CriticalSectionCallbackTickKernel::ForwardDataContent(KernelContext* ctx) const { auto* buffer_mgr = Singleton>>::Get(); CHECK(op_conf().has_critical_section_callback_tick_conf()); const std::string& buffer_name = op_conf().critical_section_callback_tick_conf().buffer_name(); std::shared_ptr critical_section_instance; BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->TryReceive(&critical_section_instance); CHECK_EQ(buffer_status, kBufferStatusSuccess); critical_section_instance->Finish(); } REGISTER_KERNEL(OperatorConf::kCriticalSectionCallbackTickConf, CriticalSectionCallbackTickKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/critical_section_wait_tick_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/buffer_manager.h" namespace oneflow { class CriticalSectionWaitTickKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickKernel); CriticalSectionWaitTickKernel() = default; ~CriticalSectionWaitTickKernel() = default; private: bool IsStateless() const override { return false; } void ForwardDataContent(KernelContext* ctx) const override; }; void CriticalSectionWaitTickKernel::ForwardDataContent(KernelContext* ctx) const { auto* buffer_mgr = Singleton>>::Get(); CHECK(this->op_conf().has_critical_section_wait_tick_conf()); const std::string& buffer_name = this->op_conf().critical_section_wait_tick_conf().buffer_name(); std::shared_ptr critical_section_instance; BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->Pull(&critical_section_instance); CHECK_EQ(buffer_status, kBufferStatusSuccess); } REGISTER_KERNEL(OperatorConf::kCriticalSectionWaitTickConf, CriticalSectionWaitTickKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/cuda_check_numerics_kernel_observer.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/cuda_check_numerics_kernel_observer.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __device__ bool IsNotFinite(T x) { return !isfinite(x); } template<> __device__ bool IsNotFinite(half x) { #if __CUDA_ARCH__ >= 530 return (__hisinf(x) || __hisnan(x)); #else __trap(); return true; #endif } template __global__ void HasNotFiniteGpuKernel(const int64_t n, const T* x, volatile bool* has_not_finite) { if (*has_not_finite) { return; } CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { if (IsNotFinite(x[i])) { *has_not_finite = true; return; } } } template bool HasNotFinite(ep::Stream* stream, const int64_t elem_cnt, const T* data_ptr, bool* has_not_finite_host, bool* has_not_finite_device) { OF_CUDA_CHECK(cudaMemsetAsync(has_not_finite_device, 0, sizeof(bool), stream->As()->cuda_stream())); HasNotFiniteGpuKernel <<As()->cuda_stream()>>>(elem_cnt, data_ptr, has_not_finite_device); OF_CUDA_CHECK(cudaMemcpyAsync(has_not_finite_host, has_not_finite_device, sizeof(bool), cudaMemcpyDefault, stream->As()->cuda_stream())); OF_CUDA_CHECK(cudaStreamSynchronize(stream->As()->cuda_stream())); return *has_not_finite_host; } bool HasNotFiniteGpu(ep::Stream* stream, const Blob* blob, bool* has_not_finite_host, bool* has_not_finite_device) { auto* cuda_stream = stream->As(); const DataType dtype = blob->data_type(); const int64_t elem_cnt = blob->shape().elem_cnt(); if (elem_cnt == 0) { return false; } if (dtype == kFloat) { return HasNotFinite(stream, elem_cnt, blob->dptr(), has_not_finite_host, has_not_finite_device); } else if (dtype == kDouble) { return HasNotFinite(stream, elem_cnt, blob->dptr(), has_not_finite_host, has_not_finite_device); } else if (dtype == kFloat16) { if (cuda_stream->cuda_arch() >= 530) { return HasNotFinite(stream, elem_cnt, blob->dptr(), has_not_finite_host, has_not_finite_device); } else { LOG(FATAL) << "use half need nvcc arch >= 530"; return true; } } else { return false; } } void DumpBlob(KernelContext* ctx, const std::string& bn) { Blob* blob = ctx->BnInOp2Blob(bn); if (blob != nullptr) { std::vector buffer(blob->ByteSizeOfBlobBody()); OF_CUDA_CHECK( cudaMemcpy(buffer.data(), blob->dptr(), blob->ByteSizeOfBlobBody(), cudaMemcpyDefault)); OF_CUDA_CHECK(cudaDeviceSynchronize()); std::ofstream ofs(bn); ofs.write(buffer.data(), blob->ByteSizeOfBlobBody()); } } void DumpBlobs(KernelContext* ctx, const Kernel* kernel) { for (const auto& obn : kernel->op_attribute().output_bns()) { DumpBlob(ctx, obn); } for (const auto& ibn : kernel->op_attribute().input_bns()) { DumpBlob(ctx, ibn); } } } // namespace CudaCheckNumericsKernelObserver::CudaCheckNumericsKernelObserver() : has_not_finite_host_(nullptr), has_not_finite_device_(nullptr) { OF_CUDA_CHECK(cudaGetDevice(&device_id_)); OF_CUDA_CHECK(cudaMallocHost(&has_not_finite_host_, sizeof(bool))); OF_CUDA_CHECK(cudaMalloc(&has_not_finite_device_, sizeof(bool))); } CudaCheckNumericsKernelObserver::~CudaCheckNumericsKernelObserver() { CudaCurrentDeviceGuard guard(device_id_); OF_CUDA_CHECK(cudaFreeHost(has_not_finite_host_)); 
OF_CUDA_CHECK(cudaFree(has_not_finite_device_)); } void CudaCheckNumericsKernelObserver::DidForwardDataContent(KernelContext* ctx, const Kernel* kernel) { for (const auto& obn : kernel->op_attribute().output_bns()) { Blob* blob = ctx->BnInOp2Blob(obn); if (blob != nullptr) { bool has_not_finite = HasNotFiniteGpu(ctx->stream(), blob, has_not_finite_host_, has_not_finite_device_); if (has_not_finite && ParseBooleanFromEnv("ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS_DUMP", false)) { DumpBlobs(ctx, kernel); } CHECK(!has_not_finite) << kernel->op_conf().name() << " : " << obn << " has nan or inf"; } } } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/cuda_check_numerics_kernel_observer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_CUDA_CHECK_NUMERICS_KERNEL_OBSERVER_H_ #define ONEFLOW_CORE_KERNEL_CUDA_CHECK_NUMERICS_KERNEL_OBSERVER_H_ #ifdef WITH_CUDA #include "oneflow/core/kernel/kernel_observer.h" namespace oneflow { class CudaCheckNumericsKernelObserver final : public KernelObserver { public: OF_DISALLOW_COPY_AND_MOVE(CudaCheckNumericsKernelObserver); CudaCheckNumericsKernelObserver(); ~CudaCheckNumericsKernelObserver() override; void DidForwardDataContent(KernelContext* ctx, const Kernel* kernel) override; private: bool* has_not_finite_host_; bool* has_not_finite_device_; int device_id_; }; } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_KERNEL_CUDA_CHECK_NUMERICS_KERNEL_OBSERVER_H_ ================================================ FILE: oneflow/core/kernel/cuda_graph_support.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
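
CudaCheckNumericsKernelObserver above reads a device-side flag back through pinned host memory plus a stream synchronize. A sketch of that allocate/reset/read-back sequence in isolation; cuda_stream is assumed to come from stream->As<ep::CudaStream>()->cuda_stream() as in the file above.

// Sketch: device flag with a pinned host mirror, as the observer above allocates.
// `cuda_stream` is an assumed cudaStream_t obtained from the ep::CudaStream.
bool* flag_host = nullptr;
bool* flag_device = nullptr;
OF_CUDA_CHECK(cudaMallocHost(&flag_host, sizeof(bool)));  // pinned, so async copies are legal
OF_CUDA_CHECK(cudaMalloc(&flag_device, sizeof(bool)));
OF_CUDA_CHECK(cudaMemsetAsync(flag_device, 0, sizeof(bool), cuda_stream));
// ... launch a kernel that may set *flag_device to true ...
OF_CUDA_CHECK(cudaMemcpyAsync(flag_host, flag_device, sizeof(bool), cudaMemcpyDefault, cuda_stream));
OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream));        // makes *flag_host valid to read
const bool has_not_finite = *flag_host;
OF_CUDA_CHECK(cudaFree(flag_device));
OF_CUDA_CHECK(cudaFreeHost(flag_host));
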
*/ #ifndef ONEFLOW_CORE_KERNEL_CUDA_GRAPH_SUPPORT_H_ #define ONEFLOW_CORE_KERNEL_CUDA_GRAPH_SUPPORT_H_ namespace oneflow { namespace user_op { class KernelInitContext; class KernelComputeContext; class OpKernelState; class OpKernelCache; class CudaGraphSupport { public: CudaGraphSupport() = default; virtual ~CudaGraphSupport() = default; virtual bool IsCudaGraphSupported(KernelInitContext* ctx, OpKernelState* state) const { return true; } virtual bool IsReadyForCapture(KernelComputeContext* ctx, OpKernelState* state, const OpKernelCache* cache) const { return true; } }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_CUDA_GRAPH_SUPPORT_H_ ================================================ FILE: oneflow/core/kernel/distribute_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/kernel_context.h" namespace oneflow { class DistributeAddKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(DistributeAddKernel); DistributeAddKernel() = default; ~DistributeAddKernel() = default; private: void ForwardDataContent(KernelContext* ctx) const override; const Blob* GetInBlob(KernelContext* ctx) const; }; void DistributeAddKernel::ForwardDataContent(KernelContext* ctx) const { AutoMemcpy(ctx->stream(), ctx->BnInOp2Blob("out"), GetInBlob(ctx)); } const Blob* DistributeAddKernel::GetInBlob(KernelContext* ctx) const { const Blob* in_blob = nullptr; FOR_RANGE(int, i, 0, this->op_attribute().input_bns().size()) { const Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().input_bns().Get(i)); if (cur_blob != nullptr && cur_blob != in_blob) { CHECK_ISNULL(in_blob); in_blob = cur_blob; } } return in_blob; } REGISTER_KERNEL(OperatorConf::kDistributeAddConf, DistributeAddKernel); class DistributeCloneKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(DistributeCloneKernel); DistributeCloneKernel() = default; ~DistributeCloneKernel() = default; private: void ForwardDataContent(KernelContext* ctx) const override; Blob* GetOutBlob(KernelContext* ctx) const; }; void DistributeCloneKernel::ForwardDataContent(KernelContext* ctx) const { AutoMemcpy(ctx->stream(), GetOutBlob(ctx), ctx->BnInOp2Blob("in")); } Blob* DistributeCloneKernel::GetOutBlob(KernelContext* ctx) const { Blob* out_blob = nullptr; FOR_RANGE(int, i, 0, this->op_attribute().output_bns().size()) { Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().output_bns().Get(i)); if (cur_blob != nullptr && cur_blob != out_blob) { CHECK_ISNULL(out_blob); out_blob = cur_blob; } } return out_blob; } REGISTER_KERNEL(OperatorConf::kDistributeCloneConf, DistributeCloneKernel); class DistributeConcatKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(DistributeConcatKernel); DistributeConcatKernel() = default; ~DistributeConcatKernel() = default; private: void ForwardDataContent(KernelContext* ctx) const override; const Blob* GetInBlob(KernelContext* ctx) 
const; }; void DistributeConcatKernel::ForwardDataContent(KernelContext* ctx) const { AutoMemcpy(ctx->stream(), ctx->BnInOp2Blob("out"), GetInBlob(ctx)); } const Blob* DistributeConcatKernel::GetInBlob(KernelContext* ctx) const { const Blob* in_blob = nullptr; FOR_RANGE(int, i, 0, this->op_attribute().input_bns().size()) { const Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().input_bns().Get(i)); if (cur_blob != nullptr && cur_blob != in_blob) { CHECK_ISNULL(in_blob); in_blob = cur_blob; } } return in_blob; } REGISTER_KERNEL(OperatorConf::kDistributeConcatConf, DistributeConcatKernel); class DistributeSplitKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(DistributeSplitKernel); DistributeSplitKernel() = default; ~DistributeSplitKernel() = default; private: void ForwardDataContent(KernelContext* ctx) const override; void ForwardShape(KernelContext* ctx) const override; Blob* GetOutBlob(KernelContext* ctx) const; }; void DistributeSplitKernel::ForwardDataContent(KernelContext* ctx) const { AutoMemcpy(ctx->stream(), GetOutBlob(ctx), ctx->BnInOp2Blob("in")); } void DistributeSplitKernel::ForwardShape(KernelContext* ctx) const { Blob* out_blob = GetOutBlob(ctx); out_blob->mut_shape_view()->set_shape(ctx->BnInOp2Blob("in")->shape()); } Blob* DistributeSplitKernel::GetOutBlob(KernelContext* ctx) const { Blob* out_blob = nullptr; FOR_RANGE(int, i, 0, this->op_attribute().output_bns().size()) { Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().output_bns().Get(i)); if (cur_blob != nullptr && cur_blob != out_blob) { CHECK_ISNULL(out_blob); out_blob = cur_blob; } } return out_blob; } REGISTER_KERNEL(OperatorConf::kDistributeSplitConf, DistributeSplitKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/dynamic_reshape_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" namespace oneflow { class DynamicReshapeKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(DynamicReshapeKernel); DynamicReshapeKernel() = default; ~DynamicReshapeKernel() override = default; private: void ForwardDataContent(KernelContext* ctx) const override; }; void DynamicReshapeKernel::ForwardDataContent(KernelContext* ctx) const { const Blob* in_blob = ctx->BnInOp2Blob("in"); Blob* out_blob = ctx->BnInOp2Blob("out"); AutoMemcpy(ctx->stream(), out_blob, in_blob); } REGISTER_KERNEL(OperatorConf::kDynamicReshapeConf, DynamicReshapeKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/dynamic_reshape_like_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#include "oneflow/core/kernel/kernel.h"

namespace oneflow {

class DynamicReshapeLikeKernel final : public Kernel {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DynamicReshapeLikeKernel);
  DynamicReshapeLikeKernel() = default;
  ~DynamicReshapeLikeKernel() override = default;

 private:
  void ForwardDataContent(KernelContext* ctx) const override;
};

void DynamicReshapeLikeKernel::ForwardDataContent(KernelContext* ctx) const {
  const Blob* in_blob = ctx->BnInOp2Blob("x");
  Blob* out_blob = ctx->BnInOp2Blob("y");
  AutoMemcpy(ctx->stream(), out_blob, in_blob);
}

REGISTER_KERNEL(OperatorConf::kDynamicReshapeLikeConf, DynamicReshapeLikeKernel);

}  // namespace oneflow


================================================
FILE: oneflow/core/kernel/esac_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#include "oneflow/core/kernel/esac_kernel.h"

namespace oneflow {

template<typename T>
void EsacKernel<T>::VirtualKernelInit(KernelContext* ctx) {
  ctx->set_state(std::make_shared<EsacKernelState>());
}

template<typename T>
void EsacKernel<T>::ForwardDataContent(KernelContext* ctx) const {
  T value =
      static_cast<T>(CHECK_NOTNULL(dynamic_cast<EsacKernelState*>(ctx->state().get()))->value);
  *(ctx->BnInOp2Blob("out")->mut_dptr<T>()) = value;
}

ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kEsacConf, EsacKernel, INT_DATA_TYPE_SEQ)

}  // namespace oneflow


================================================
FILE: oneflow/core/kernel/esac_kernel.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_KERNEL_ESAC_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_ESAC_KERNEL_H_

#include "oneflow/core/kernel/kernel.h"

namespace oneflow {

struct EsacKernelState : public KernelState {
  int64_t value{};
};

template<typename T>
class EsacKernel final : public Kernel {
 public:
  OF_DISALLOW_COPY_AND_MOVE(EsacKernel);
  EsacKernel() = default;
  ~EsacKernel() override = default;

 private:
  void VirtualKernelInit(KernelContext* ctx) override;
  void ForwardDataContent(KernelContext* ctx) const override;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_KERNEL_ESAC_KERNEL_H_


================================================
FILE: oneflow/core/kernel/identity_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/ep/include/primitive/memcpy.h"

namespace oneflow {

class IdentityKernel final : public Kernel {
 public:
  OF_DISALLOW_COPY_AND_MOVE(IdentityKernel);
  IdentityKernel() = default;
  ~IdentityKernel() = default;

 private:
  void ForwardDataContent(KernelContext* ctx) const override;
  void ForwardHeader(KernelContext* ctx) const override;
};

void IdentityKernel::ForwardDataContent(KernelContext* ctx) const {
  const Blob* in_blob = ctx->BnInOp2Blob("in");
  Blob* out_blob = ctx->BnInOp2Blob("out");
  AutoMemcpy(ctx->stream(), out_blob, in_blob);
}

void IdentityKernel::ForwardHeader(KernelContext* ctx) const {
  ctx->BnInOp2Blob("out")->CopyHeaderFrom(ctx->BnInOp2Blob("in"));
}

REGISTER_KERNEL(OperatorConf::kIdentityConf, IdentityKernel);
REGISTER_KERNEL(OperatorConf::kCopyConf, IdentityKernel);
REGISTER_KERNEL(OperatorConf::kCastToLocalConf, IdentityKernel);
REGISTER_KERNEL(OperatorConf::kCastFromLocalConf, IdentityKernel);
REGISTER_KERNEL(OperatorConf::kBoxingIdentityConf, IdentityKernel);

}  // namespace oneflow


================================================
FILE: oneflow/core/kernel/image_decoder_random_crop_resize_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/common/error.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/common/channel.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/user/image/random_crop_generator.h" #include "oneflow/user/image/jpeg_decoder.h" #include #ifdef WITH_CUDA #include #if CUDA_VERSION >= 10020 #define WITH_NVJPEG #include #include #endif // CUDA_VERSION >= 10020 #endif // WITH_CUDA namespace oneflow { namespace { constexpr int kNumChannels = 3; struct Task { const unsigned char* data; size_t length; unsigned char* dst; RandomCropGenerator* crop_generator; }; struct Work { std::shared_ptr> tasks; unsigned char* workspace = nullptr; size_t workspace_size = 0; std::shared_ptr done_counter; std::shared_ptr> task_counter; }; struct ROI { int x; int y; int w; int h; }; class ROIGenerator { public: virtual ~ROIGenerator() = default; virtual void Generate(int width, int height, ROI* roi) const = 0; }; class RandomCropROIGenerator : public ROIGenerator { public: explicit RandomCropROIGenerator(RandomCropGenerator* crop_generator) : crop_generator_(crop_generator) {} ~RandomCropROIGenerator() override = default; void Generate(int width, int height, ROI* roi) const override { CropWindow window; crop_generator_->GenerateCropWindow({height, width}, &window); roi->x = window.anchor.At(1); roi->y = window.anchor.At(0); roi->w = window.shape.At(1); roi->h = window.shape.At(0); } private: RandomCropGenerator* crop_generator_; }; class NoChangeROIGenerator : public ROIGenerator { public: ~NoChangeROIGenerator() override = default; void Generate(int width, int height, ROI* roi) const override { roi->x = 0; roi->y = 0; roi->w = width; roi->h = height; } }; void GenerateRandomCropRoi(RandomCropGenerator* crop_generator, int width, int height, int* roi_x, int* roi_y, int* roi_width, int* roi_height) { CropWindow window; crop_generator->GenerateCropWindow({height, width}, &window); *roi_x = window.anchor.At(1); *roi_y = window.anchor.At(0); *roi_width = window.shape.At(1); *roi_height = window.shape.At(0); } class DecodeHandle { public: DecodeHandle() = default; virtual ~DecodeHandle() = default; virtual void DecodeRandomCropResize(const unsigned char* data, size_t length, RandomCropGenerator* crop_generator, unsigned char* workspace, size_t workspace_size, unsigned char* dst, int target_width, int target_height) = 0; virtual void WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) = 0; virtual void Synchronize() = 0; }; using DecodeHandleFactory = std::function()>; template DecodeHandleFactory CreateDecodeHandleFactory(int target_width, int target_height); class CpuDecodeHandle final : public DecodeHandle { public: OF_DISALLOW_COPY_AND_MOVE(CpuDecodeHandle); CpuDecodeHandle() = default; ~CpuDecodeHandle() override = default; void DecodeRandomCropResize(const unsigned char* data, size_t length, RandomCropGenerator* crop_generator, unsigned char* workspace, size_t workspace_size, unsigned char* dst, int target_width, int target_height) override; void WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) override { // do nothing } void Synchronize() override { // do nothing } }; bool CpuJpegDecodeRandomCropResize(const unsigned char* data, size_t length, RandomCropGenerator* crop_generator, unsigned char* workspace, size_t workspace_size, unsigned char* dst, int target_width, int target_height) { cv::Mat image_mat; if 
(JpegPartialDecodeRandomCropImage(data, length, crop_generator, workspace, workspace_size, &image_mat)) { return false; } cv::Mat dst_mat(target_height, target_width, CV_8UC3, dst, cv::Mat::AUTO_STEP); cv::resize(image_mat, dst_mat, cv::Size(target_width, target_height), 0, 0, cv::INTER_LINEAR); return true; } void OpencvDecodeRandomCropResize(const unsigned char* data, size_t length, RandomCropGenerator* crop_generator, unsigned char* dst, int target_width, int target_height) { cv::Mat image = cv::imdecode(cv::Mat(1, length, CV_8UC1, const_cast(data)), cv::IMREAD_COLOR); cv::Mat cropped; if (crop_generator) { cv::Rect roi; GenerateRandomCropRoi(crop_generator, image.cols, image.rows, &roi.x, &roi.y, &roi.width, &roi.height); image(roi).copyTo(cropped); } else { cropped = image; } cv::Mat resized; cv::resize(cropped, resized, cv::Size(target_width, target_height), 0, 0, cv::INTER_LINEAR); cv::Mat dst_mat(target_height, target_width, CV_8UC3, dst, cv::Mat::AUTO_STEP); cv::cvtColor(resized, dst_mat, cv::COLOR_BGR2RGB); } void CpuDecodeHandle::DecodeRandomCropResize(const unsigned char* data, size_t length, RandomCropGenerator* crop_generator, unsigned char* workspace, size_t workspace_size, unsigned char* dst, int target_width, int target_height) { if (CpuJpegDecodeRandomCropResize(data, length, crop_generator, workspace, workspace_size, dst, target_width, target_height)) { return; } OpencvDecodeRandomCropResize(data, length, crop_generator, dst, target_width, target_height); } template<> DecodeHandleFactory CreateDecodeHandleFactory(int target_width, int target_height) { return []() -> std::shared_ptr { return std::make_shared(); }; } #if defined(WITH_NVJPEG) int GpuDeviceMalloc(void** p, size_t s) { return (int)cudaMalloc(p, s); } int GpuDeviceFree(void* p) { return (int)cudaFree(p); } int GpuPinnedMalloc(void** p, size_t s, unsigned int flags) { return (int)cudaHostAlloc(p, s, flags); } int GpuPinnedFree(void* p) { return (int)cudaFreeHost(p); } void InitNppStreamContext(NppStreamContext* ctx, int dev, cudaStream_t stream) { ctx->hStream = stream; ctx->nCudaDeviceId = dev; OF_CUDA_CHECK( cudaDeviceGetAttribute(&ctx->nMultiProcessorCount, cudaDevAttrMultiProcessorCount, dev)); OF_CUDA_CHECK(cudaDeviceGetAttribute(&ctx->nMaxThreadsPerMultiProcessor, cudaDevAttrMaxThreadsPerMultiProcessor, dev)); OF_CUDA_CHECK( cudaDeviceGetAttribute(&ctx->nMaxThreadsPerBlock, cudaDevAttrMaxThreadsPerBlock, dev)); int smem_per_block = 0; OF_CUDA_CHECK(cudaDeviceGetAttribute(&smem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, dev)); ctx->nSharedMemPerBlock = smem_per_block; OF_CUDA_CHECK(cudaDeviceGetAttribute(&ctx->nCudaDevAttrComputeCapabilityMajor, cudaDevAttrComputeCapabilityMajor, dev)); OF_CUDA_CHECK(cudaDeviceGetAttribute(&ctx->nCudaDevAttrComputeCapabilityMinor, cudaDevAttrComputeCapabilityMinor, dev)); OF_CUDA_CHECK(cudaStreamGetFlags(stream, &ctx->nStreamFlags)); } class GpuDecodeHandle final : public DecodeHandle { public: OF_DISALLOW_COPY_AND_MOVE(GpuDecodeHandle); explicit GpuDecodeHandle(int dev, int target_width, int target_height); ~GpuDecodeHandle() override; void DecodeRandomCropResize(const unsigned char* data, size_t length, RandomCropGenerator* crop_generator, unsigned char* workspace, size_t workspace_size, unsigned char* dst, int target_width, int target_height) override; void WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) override; void Synchronize() override; private: void DecodeRandomCrop(const unsigned char* data, size_t length, ROIGenerator* 
roi_generator, unsigned char* dst, size_t dst_max_length, int* dst_width, int* dst_height); void Decode(const unsigned char* data, size_t length, unsigned char* dst, size_t dst_max_length, int* dst_width, int* dst_height); void CropResize(const unsigned char* src, int src_width, int src_height, ROIGenerator* roi_generator, unsigned char* dst, int dst_width, int dst_height); cudaStream_t cuda_stream_ = nullptr; nvjpegHandle_t jpeg_handle_ = nullptr; nvjpegJpegState_t jpeg_state_ = nullptr; nvjpegJpegState_t hw_jpeg_state_ = nullptr; nvjpegBufferPinned_t jpeg_pinned_buffer_ = nullptr; nvjpegBufferDevice_t jpeg_device_buffer_ = nullptr; nvjpegDecodeParams_t jpeg_decode_params_ = nullptr; nvjpegJpegDecoder_t jpeg_decoder_ = nullptr; nvjpegJpegDecoder_t hw_jpeg_decoder_ = nullptr; nvjpegJpegStream_t jpeg_stream_ = nullptr; NppStreamContext npp_stream_ctx_{}; nvjpegDevAllocator_t dev_allocator_{}; nvjpegPinnedAllocator_t pinned_allocator_{}; CpuDecodeHandle fallback_handle_; unsigned char* fallback_buffer_{}; size_t fallback_buffer_size_; bool warmup_done_; bool use_hardware_acceleration_; }; GpuDecodeHandle::GpuDecodeHandle(int dev, int target_width, int target_height) : warmup_done_(false), use_hardware_acceleration_(false) { OF_CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream_, cudaStreamNonBlocking)); dev_allocator_.dev_malloc = &GpuDeviceMalloc; dev_allocator_.dev_free = &GpuDeviceFree; pinned_allocator_.pinned_malloc = &GpuPinnedMalloc; pinned_allocator_.pinned_free = &GpuPinnedFree; OF_NVJPEG_CHECK(nvjpegCreateEx(NVJPEG_BACKEND_DEFAULT, &dev_allocator_, &pinned_allocator_, 0, &jpeg_handle_)); OF_NVJPEG_CHECK(nvjpegDecoderCreate(jpeg_handle_, NVJPEG_BACKEND_DEFAULT, &jpeg_decoder_)); OF_NVJPEG_CHECK(nvjpegDecoderStateCreate(jpeg_handle_, jpeg_decoder_, &jpeg_state_)); #if NVJPEG_VER_MAJOR >= 11 if (ParseBooleanFromEnv("ONEFLOW_DECODER_ENABLE_NVJPEG_HARDWARE_ACCELERATION", true) && nvjpegDecoderCreate(jpeg_handle_, NVJPEG_BACKEND_HARDWARE, &hw_jpeg_decoder_) == NVJPEG_STATUS_SUCCESS) { OF_NVJPEG_CHECK(nvjpegDecoderStateCreate(jpeg_handle_, hw_jpeg_decoder_, &hw_jpeg_state_)); use_hardware_acceleration_ = true; } else { hw_jpeg_decoder_ = nullptr; hw_jpeg_state_ = nullptr; } #endif OF_NVJPEG_CHECK(nvjpegBufferPinnedCreate(jpeg_handle_, &pinned_allocator_, &jpeg_pinned_buffer_)); OF_NVJPEG_CHECK(nvjpegBufferDeviceCreate(jpeg_handle_, &dev_allocator_, &jpeg_device_buffer_)); OF_NVJPEG_CHECK(nvjpegDecodeParamsCreate(jpeg_handle_, &jpeg_decode_params_)); OF_NVJPEG_CHECK(nvjpegJpegStreamCreate(jpeg_handle_, &jpeg_stream_)); InitNppStreamContext(&npp_stream_ctx_, dev, cuda_stream_); fallback_buffer_size_ = target_width * target_height * kNumChannels; OF_CUDA_CHECK(cudaMallocHost(&fallback_buffer_, fallback_buffer_size_)); } GpuDecodeHandle::~GpuDecodeHandle() { OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); OF_NVJPEG_CHECK(nvjpegJpegStreamDestroy(jpeg_stream_)); OF_NVJPEG_CHECK(nvjpegDecodeParamsDestroy(jpeg_decode_params_)); OF_NVJPEG_CHECK(nvjpegBufferDeviceDestroy(jpeg_device_buffer_)); OF_NVJPEG_CHECK(nvjpegBufferPinnedDestroy(jpeg_pinned_buffer_)); OF_NVJPEG_CHECK(nvjpegJpegStateDestroy(jpeg_state_)); OF_NVJPEG_CHECK(nvjpegDecoderDestroy(jpeg_decoder_)); if (use_hardware_acceleration_) { OF_NVJPEG_CHECK(nvjpegJpegStateDestroy(hw_jpeg_state_)); OF_NVJPEG_CHECK(nvjpegDecoderDestroy(hw_jpeg_decoder_)); } OF_NVJPEG_CHECK(nvjpegDestroy(jpeg_handle_)); OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_)); OF_CUDA_CHECK(cudaFreeHost(fallback_buffer_)); } void GpuDecodeHandle::DecodeRandomCrop(const 
unsigned char* data, size_t length, ROIGenerator* roi_generator, unsigned char* dst, size_t dst_max_length, int* dst_width, int* dst_height) { // https://docs.nvidia.com/cuda/archive/10.2/nvjpeg/index.html#nvjpeg-decoupled-decode-api OF_NVJPEG_CHECK(nvjpegJpegStreamParse(jpeg_handle_, data, length, 0, 0, jpeg_stream_)); unsigned int orig_width = 0; unsigned int orig_height = 0; OF_NVJPEG_CHECK(nvjpegJpegStreamGetFrameDimensions(jpeg_stream_, &orig_width, &orig_height)); ROI roi{}; roi_generator->Generate(static_cast(orig_width), static_cast(orig_height), &roi); CHECK_LE(roi.w * roi.h * kNumChannels, dst_max_length); nvjpegImage_t image; image.channel[0] = dst; image.pitch[0] = roi.w * kNumChannels; OF_NVJPEG_CHECK(nvjpegDecodeParamsSetOutputFormat(jpeg_decode_params_, NVJPEG_OUTPUT_RGBI)); nvjpegJpegDecoder_t jpeg_decoder = nullptr; nvjpegJpegState_t jpeg_state = nullptr; int is_hardware_acceleration_supported = -1; if (use_hardware_acceleration_) { nvjpegDecoderJpegSupported(hw_jpeg_decoder_, jpeg_stream_, jpeg_decode_params_, &is_hardware_acceleration_supported); } if (is_hardware_acceleration_supported == 0) { jpeg_decoder = hw_jpeg_decoder_; jpeg_state = hw_jpeg_state_; } else { jpeg_decoder = jpeg_decoder_; jpeg_state = jpeg_state_; } if (roi.x != 0 || roi.y != 0 || roi.w != orig_width || roi.h != orig_height) { // hardware_acceleration not support nvjpegDecodeParamsSetROI OF_NVJPEG_CHECK(nvjpegDecodeParamsSetROI(jpeg_decode_params_, roi.x, roi.y, roi.w, roi.h)); } else { OF_NVJPEG_CHECK(nvjpegDecodeParamsSetROI(jpeg_decode_params_, 0, 0, -1, -1)); } OF_NVJPEG_CHECK(nvjpegStateAttachPinnedBuffer(jpeg_state, jpeg_pinned_buffer_)); OF_NVJPEG_CHECK(nvjpegStateAttachDeviceBuffer(jpeg_state, jpeg_device_buffer_)); OF_NVJPEG_CHECK(nvjpegDecodeJpegHost(jpeg_handle_, jpeg_decoder, jpeg_state, jpeg_decode_params_, jpeg_stream_)); OF_NVJPEG_CHECK(nvjpegDecodeJpegTransferToDevice(jpeg_handle_, jpeg_decoder, jpeg_state, jpeg_stream_, cuda_stream_)); OF_NVJPEG_CHECK( nvjpegDecodeJpegDevice(jpeg_handle_, jpeg_decoder, jpeg_state, &image, cuda_stream_)); *dst_width = roi.w; *dst_height = roi.h; } void GpuDecodeHandle::Decode(const unsigned char* data, size_t length, unsigned char* dst, size_t dst_max_length, int* dst_width, int* dst_height) { NoChangeROIGenerator no_change_roi_generator; DecodeRandomCrop(data, length, &no_change_roi_generator, dst, dst_max_length, dst_width, dst_height); } void GpuDecodeHandle::CropResize(const unsigned char* src, int src_width, int src_height, ROIGenerator* roi_generator, unsigned char* dst, int dst_width, int dst_height) { ROI roi{}; roi_generator->Generate(static_cast(src_width), static_cast(src_height), &roi); const NppiSize src_size{ .width = src_width, .height = src_height, }; const NppiRect src_rect{ .x = roi.x, .y = roi.y, .width = roi.w, .height = roi.h, }; const NppiSize dst_size{ .width = dst_width, .height = dst_height, }; const NppiRect dst_rect{ .x = 0, .y = 0, .width = dst_width, .height = dst_height, }; NppStatus status = nppiResize_8u_C3R_Ctx(src, src_width * kNumChannels, src_size, src_rect, dst, dst_width * 3, dst_size, dst_rect, NPPI_INTER_LINEAR, npp_stream_ctx_); CHECK_GE(status, NPP_SUCCESS); } void GpuDecodeHandle::DecodeRandomCropResize(const unsigned char* data, size_t length, RandomCropGenerator* crop_generator, unsigned char* workspace, size_t workspace_size, unsigned char* dst, int target_width, int target_height) { int width[NVJPEG_MAX_COMPONENT]; int height[NVJPEG_MAX_COMPONENT]; nvjpegChromaSubsampling_t subsampling{}; int 
num_components = 0; nvjpegStatus_t status = nvjpegGetImageInfo(jpeg_handle_, data, length, &num_components, &subsampling, width, height); if (status != NVJPEG_STATUS_SUCCESS) { CHECK_LE(target_width * target_height * kNumChannels, fallback_buffer_size_); fallback_handle_.DecodeRandomCropResize(data, length, crop_generator, nullptr, 0, fallback_buffer_, target_width, target_height); OF_CUDA_CHECK(cudaMemcpyAsync(dst, fallback_buffer_, target_width * target_height * kNumChannels, cudaMemcpyDefault, cuda_stream_)); return; } NoChangeROIGenerator no_change_roi_generator; RandomCropROIGenerator random_crop_roi_generator(crop_generator); if (use_hardware_acceleration_) { int w = 0; int h = 0; DecodeRandomCrop(data, length, &no_change_roi_generator, workspace, workspace_size, &w, &h); CropResize(workspace, w, h, &random_crop_roi_generator, dst, target_width, target_height); } else { int w = 0; int h = 0; DecodeRandomCrop(data, length, &random_crop_roi_generator, workspace, workspace_size, &w, &h); CropResize(workspace, w, h, &no_change_roi_generator, dst, target_width, target_height); } } void GpuDecodeHandle::WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) { if (warmup_done_) { return; } warmup_size = std::min(static_cast(std::sqrt(workspace_size / kNumChannels)), warmup_size); cv::Mat image = cv::Mat::zeros(cv::Size(warmup_size, warmup_size), CV_8UC3); cv::randu(image, cv::Scalar(0, 0, 0), cv::Scalar(255, 255, 255)); std::vector data; cv::imencode(".jpg", image, data, {}); int decoded_width = 0; int decoded_height = 0; Decode(data.data(), data.size(), workspace, workspace_size, &decoded_width, &decoded_height); Synchronize(); if (use_hardware_acceleration_) { // Note(guoran): hardware acceleration jpeg decoder support baseline decoding only, use // progressive to warmup jpeg decoder. 
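    // Decoding this progressive re-encoding is rejected by the hardware backend, so the warmup
    // also exercises the default (software) decode path before real batches arrive.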
cv::imencode(".jpg", image, data, {cv::IMWRITE_JPEG_PROGRESSIVE, 1}); Decode(data.data(), data.size(), workspace, workspace_size, &decoded_width, &decoded_height); Synchronize(); } warmup_done_ = true; } void GpuDecodeHandle::Synchronize() { OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); } template<> DecodeHandleFactory CreateDecodeHandleFactory(int target_width, int target_height) { int dev = 0; OF_CUDA_CHECK(cudaGetDevice(&dev)); return [dev, target_width, target_height]() -> std::shared_ptr { OF_CUDA_CHECK(cudaSetDevice(dev)); return std::make_shared(dev, target_width, target_height); }; } #endif // defined(WITH_NVJPEG) class Worker final { public: OF_DISALLOW_COPY_AND_MOVE(Worker); Worker(const std::function()>& handle_factory, int target_width, int target_height, int warmup_size) { worker_thread_ = std::thread(&Worker::PollWork, this, handle_factory, target_width, target_height, warmup_size); } ~Worker() { work_queue_.Close(); worker_thread_.join(); } void Enqueue(std::shared_ptr& work) { work_queue_.Send(work); } private: Channel> work_queue_; std::thread worker_thread_; void PollWork(const std::function()>& handle_factory, int target_width, int target_height, int warmup_size) { OF_PROFILER_NAME_THIS_HOST_THREAD("_cuda_img_decode"); std::shared_ptr handle = handle_factory(); std::shared_ptr work; while (true) { ChannelStatus status = work_queue_.Receive(&work); if (status == ChannelStatus::kChannelStatusErrorClosed) { break; } CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess); handle->WarmupOnce(warmup_size, work->workspace, work->workspace_size); while (true) { const int task_id = work->task_counter->fetch_add(1, std::memory_order_relaxed); if (task_id >= work->tasks->size()) { break; } const Task& task = work->tasks->at(task_id); handle->DecodeRandomCropResize(task.data, task.length, task.crop_generator, work->workspace, work->workspace_size, task.dst, target_width, target_height); handle->Synchronize(); } work->done_counter->Decrease(); } } }; } // namespace template class ImageDecoderRandomCropResizeKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(ImageDecoderRandomCropResizeKernel); ImageDecoderRandomCropResizeKernel() = default; ~ImageDecoderRandomCropResizeKernel() override = default; private: void VirtualKernelInit(KernelContext* ctx) override; void ForwardDataContent(KernelContext* ctx) const override; std::vector> random_crop_generators_; std::vector> workers_; }; template void ImageDecoderRandomCropResizeKernel::VirtualKernelInit(KernelContext* ctx) { const ImageDecoderRandomCropResizeOpConf& conf = this->op_conf().image_decoder_random_crop_resize_conf(); const int64_t batch_size = this->kernel_conf().image_decoder_random_crop_resize_conf().batch_size(); random_crop_generators_.resize(batch_size); std::seed_seq seq{this->kernel_conf().image_decoder_random_crop_resize_conf().seed()}; std::vector seeds(batch_size); seq.generate(seeds.begin(), seeds.end()); AspectRatioRange aspect_ratio_range{ conf.random_aspect_ratio_min(), conf.random_aspect_ratio_max(), }; AreaRange area_range{ conf.random_area_min(), conf.random_area_max(), }; for (int64_t i = 0; i < batch_size; ++i) { random_crop_generators_.at(i).reset( new RandomCropGenerator(aspect_ratio_range, area_range, seeds.at(i), conf.num_attempts())); } workers_.resize(conf.num_workers()); for (int64_t i = 0; i < conf.num_workers(); ++i) { workers_.at(i).reset(new Worker( CreateDecodeHandleFactory(conf.target_width(), conf.target_height()), conf.target_width(), conf.target_height(), 
conf.warmup_size())); } } template void ImageDecoderRandomCropResizeKernel::ForwardDataContent(KernelContext* ctx) const { const ImageDecoderRandomCropResizeOpConf& conf = this->op_conf().image_decoder_random_crop_resize_conf(); const Blob* in = ctx->BnInOp2Blob("in"); Blob* out = ctx->BnInOp2Blob("out"); Blob* tmp = ctx->BnInOp2Blob("tmp"); CHECK_EQ(in->data_type(), DataType::kTensorBuffer); CHECK_EQ(out->data_type(), DataType::kUInt8); const ShapeView& in_shape = in->shape(); const int64_t num_in_axes = in_shape.NumAxes(); const ShapeView& out_shape = out->shape(); const int64_t num_out_axes = out_shape.NumAxes(); CHECK_EQ(num_out_axes, num_in_axes + 3); for (int i = 0; i < num_in_axes; ++i) { CHECK_EQ(out_shape.At(i), in_shape.At(i)); } CHECK_EQ(out_shape.At(num_in_axes), conf.target_height()); CHECK_EQ(out_shape.At(num_in_axes + 1), conf.target_width()); CHECK_EQ(out_shape.At(num_in_axes + 2), kNumChannels); CHECK_EQ(tmp->data_type(), DataType::kUInt8); const int64_t batch_size = in_shape.elem_cnt(); const auto* buffers = in->dptr(); auto* out_ptr = out->mut_dptr(); const int64_t out_instance_size = conf.target_height() * conf.target_width() * kNumChannels; auto* workspace_ptr = tmp->mut_dptr(); size_t workspace_size_per_worker = tmp->shape().elem_cnt() / workers_.size(); std::shared_ptr done_counter(new BlockingCounter(workers_.size())); std::shared_ptr> task_counter(new std::atomic(0)); std::shared_ptr> tasks(new std::vector(batch_size)); for (int64_t task_id = 0; task_id < batch_size; ++task_id) { const TensorBuffer* buffer = buffers + task_id; CHECK_EQ(buffer->data_type(), DataType::kUInt8); tasks->at(task_id).data = buffer->data(); tasks->at(task_id).length = buffer->elem_cnt(); tasks->at(task_id).dst = out_ptr + task_id * out_instance_size; tasks->at(task_id).crop_generator = random_crop_generators_.at(task_id).get(); } // Larger images will be processed first, balancing the work time of the workers. std::sort(tasks->begin(), tasks->end(), [](const Task& a, const Task& b) { return b.length < a.length; }); for (int64_t worker_id = 0; worker_id < workers_.size(); ++worker_id) { std::shared_ptr work(new Work()); work->tasks = tasks; work->workspace = workspace_ptr + worker_id * workspace_size_per_worker; work->workspace_size = workspace_size_per_worker; work->done_counter = done_counter; work->task_counter = task_counter; workers_.at(worker_id)->Enqueue(work); } done_counter->WaitForeverUntilCntEqualZero(); } NEW_REGISTER_KERNEL(OperatorConf::kImageDecoderRandomCropResizeConf, ImageDecoderRandomCropResizeKernel) .SetIsMatchedPred([](const KernelConf& conf) -> bool { return conf.op_attribute().op_conf().device_tag() == "cpu"; }); #if defined(WITH_NVJPEG) NEW_REGISTER_KERNEL(OperatorConf::kImageDecoderRandomCropResizeConf, ImageDecoderRandomCropResizeKernel) .SetIsMatchedPred([](const KernelConf& conf) -> bool { return conf.op_attribute().op_conf().device_tag() == "cuda"; }); #endif // defined(WITH_NVJPEG) } // namespace oneflow ================================================ FILE: oneflow/core/kernel/input_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/critical_section_instance.h"
#include "oneflow/core/job/global_for.h"

namespace oneflow {

namespace {

class InputKernel final : public Kernel {
 public:
  OF_DISALLOW_COPY_AND_MOVE(InputKernel);
  InputKernel() = default;
  ~InputKernel() = default;

 private:
  void ForwardDataContent(KernelContext* ctx) const override {
    CHECK(this->op_conf().input_conf().has_job_name());
    const auto& job_name = this->op_conf().input_conf().job_name();
    const auto& op_name = this->op_conf().name();
    auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();
    auto* buffer = buffer_mgr->Get(GetInputBufferName(job_name, op_name));
    std::shared_ptr<CriticalSectionInstance> critical_section_instance;
    BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance);
    CHECK_NE(buffer_status, kBufferStatusEmpty);
    if (buffer_status == kBufferStatusSuccess) {
      critical_section_instance->AccessBlobByOpName(ctx->stream(), ctx->BnInOp2Blob("out"),
                                                    op_name);
    }
  }
  void ForwardHeader(KernelContext* ctx) const override {}
};

}  // namespace

REGISTER_KERNEL(OperatorConf::kInputConf, InputKernel);

}  // namespace oneflow


================================================
FILE: oneflow/core/kernel/kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/runtime_blob_shape_infer_helper.h" #include "oneflow/core/kernel/kernel_observer.h" #include "oneflow/core/vm/sync_vm_mode_guard.h" namespace oneflow { namespace { bool IsAllBlobEmpty(const PbRpf& bns, KernelContext* ctx) { for (const auto& bn : bns) { Blob* blob = ctx->BnInOp2Blob(bn); if (blob && !blob->IsBodyEmpty()) { return false; } } return true; } } // namespace Kernel::Kernel() = default; Kernel::~Kernel() = default; void Kernel::InitBase(const KernelConf& kernel_conf) { if (shape_infer_helper_) { return; } kernel_conf_ = kernel_conf; shape_infer_helper_.reset( new RuntimeBlobShapeInferHelper(this->op_conf(), this->kernel_conf(), this)); } void Kernel::Init(const KernelConf& kernel_conf, KernelContext* ctx) { SyncVmModeGuard guard(SyncVmMode::kEnable); InitBase(kernel_conf); VirtualKernelInit(ctx); } void Kernel::Launch(KernelContext* ctx) const { SyncVmModeGuard guard(SyncVmMode::kEnable); ctx->WillForward(ctx, this); Forward(ctx); ctx->DidForward(ctx, this); } void Kernel::Forward(KernelContext* ctx) const { ctx->WillForwardHeader(ctx, this); ForwardHeader(ctx); ctx->DidForwardHeader(ctx, this); if ((!kernel_conf_.all_blobs_are_static()) && IsAllBlobEmpty(op_attribute().output_bns(), ctx) && IsStateless()) { return; } ctx->WillForwardDataContent(ctx, this); ForwardDataContent(ctx); ctx->DidForwardDataContent(ctx, this); } void Kernel::ForwardHeader(KernelContext* ctx) const { if (!kernel_conf_.all_blobs_are_static()) { ForwardShape(ctx); } } void Kernel::ForwardShape(KernelContext* ctx) const { return shape_infer_helper_->InferShape( [ctx](const std::string& bn) { return ctx->BnInOp2Blob(bn); }); } std::unique_ptr ConstructKernel(const KernelConf& conf, KernelContext* kernel_ctx) { auto op_type = conf.op_attribute().op_conf().op_type_case(); CHECK_NE(op_type, OperatorConf::OpTypeCase::OP_TYPE_NOT_SET) << " ERROR! KernelConf: " << conf.DebugString() << " has NOT set op_type_case"; Kernel* rptr = kernel_registration::CreateKernel(conf); if (rptr == nullptr) { rptr = NewObj(op_type, conf); } CHECK_NOTNULL(rptr); rptr->Init(conf, kernel_ctx); return std::unique_ptr(rptr); } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_KERNEL_KERNEL_H_ #define ONEFLOW_CORE_KERNEL_KERNEL_H_ #include "oneflow/core/kernel/kernel.pb.h" #include "oneflow/core/kernel/kernel_registration.h" #include "oneflow/core/kernel/kernel_context.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { class JobDesc; class RuntimeBlobShapeInferHelper; class Kernel { public: OF_DISALLOW_COPY_AND_MOVE(Kernel); virtual ~Kernel(); void Init(const KernelConf& kernel_conf, KernelContext* ctx); void Launch(KernelContext* ctx) const; const OperatorConf& op_conf() const { return op_attribute().op_conf(); } const OpAttribute& op_attribute() const { return kernel_conf().op_attribute(); } const KernelConf& kernel_conf() const { return kernel_conf_; } /* * return true means all below must be guaranteed when `Launch` function return: * 1) all out blob header has been set (e.g. SyncSetHeadKernel) * 2) all asynchronous task has been queued (e.g. NCCL related kernel) */ virtual bool IsKernelLaunchSynchronized() const { return true; } void SystemForwardHeader(KernelContext* ctx) const { ForwardHeader(ctx); } void SystemForwardDataContent(KernelContext* ctx) const { ForwardDataContent(ctx); } virtual void Forward(KernelContext* ctx) const; protected: Kernel(); void InitBase(const KernelConf&); virtual void VirtualKernelInit(KernelContext* ctx) {} virtual void ForwardHeader(KernelContext* ctx) const; virtual void ForwardShape(KernelContext* ctx) const; // TODO(niuchong) : rename ForwardDataContent to ForwardBody virtual void ForwardDataContent(KernelContext* ctx) const = 0; virtual bool IsStateless() const { return false; } private: std::unique_ptr shape_infer_helper_; KernelConf kernel_conf_; }; #define REGISTER_KERNEL(k, KernelType) \ REGISTER_CLASS_WITH_ARGS(int32_t, k, Kernel, KernelType, const KernelConf&) #define REGISTER_KERNEL_CREATOR(k, f) \ REGISTER_CLASS_CREATOR(int32_t, k, Kernel, f, const KernelConf&) std::unique_ptr ConstructKernel(const KernelConf& kernel_conf, KernelContext* ctx); } // namespace oneflow #define MAKE_KERNEL_CREATOR_ENTRY(kernel_class, device_type, data_type_pair) \ {GetHashKey(device_type, OF_PP_PAIR_SECOND(data_type_pair)), \ []() { return new kernel_class(); }}, #define ADD_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, data_type_seq) \ namespace { \ \ Kernel* OF_PP_CAT(CreateKernel, __LINE__)(const KernelConf& kernel_conf) { \ static const HashMap> creators = { \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (kernel_class), \ DEVICE_TYPE_SEQ, data_type_seq)}; \ DeviceType device_type = \ CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); \ auto key = GetHashKey(device_type, kernel_conf.data_type()); \ auto it = creators.find(key); \ if (it == creators.end()) { \ LOG(FATAL) << "Error! 
Cannot find kernel creator: " << kernel_conf.DebugString() \ << " with device_type = " << device_type \ << ", dtype = " << kernel_conf.data_type(); \ } \ return (it->second)(); \ } \ \ REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__)); \ } #define MAKE_DEVICE_TYPE_KERNEL_CREATOR_ENTRY(kernel_class, device_type) \ {device_type, []() { return new kernel_class(); }}, #define ADD_DEVICE_TYPE_KERNEL_CREATOR(op_type_case, kernel_class) \ namespace { \ \ Kernel* OF_PP_CAT(CreateKernel, __LINE__)(const KernelConf& kernel_conf) { \ static const HashMap> creators = { \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_DEVICE_TYPE_KERNEL_CREATOR_ENTRY, (kernel_class), \ DEVICE_TYPE_SEQ)}; \ DeviceType device_type = \ CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); \ auto it = creators.find(device_type); \ if (it == creators.end()) { \ LOG(FATAL) << "Error! Cannot find kernel creator: " << kernel_conf.DebugString() \ << " with device_type = " << device_type; \ } \ return (it->second)(); \ } \ \ REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__)); \ } #define MAKE_CPU_KERNEL_CREATOR_ENTRY(kernel_class, data_type_pair) \ {OF_PP_PAIR_SECOND(data_type_pair), \ []() { return new kernel_class(); }}, #define ADD_CPU_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, data_type_seq) \ namespace { \ \ Kernel* CreateKernel(const KernelConf& kernel_conf) { \ static const HashMap> creators = { \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_CPU_KERNEL_CREATOR_ENTRY, (kernel_class), \ data_type_seq)}; \ auto it = creators.find(kernel_conf.data_type()); \ if (it == creators.end()) { \ LOG(FATAL) << "Error! Cannot find kernel creator: " << kernel_conf.DebugString() \ << " with dtype = " << kernel_conf.data_type(); \ } \ return (it->second)(); \ } \ \ REGISTER_KERNEL_CREATOR(op_type_case, CreateKernel); \ } #endif // ONEFLOW_CORE_KERNEL_KERNEL_H_ ================================================ FILE: oneflow/core/kernel/kernel.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/data_type.proto"; import "oneflow/core/common/dtype_signature.proto"; import "oneflow/core/operator/op_attribute.proto"; import "oneflow/core/job/placement.proto"; import "oneflow/core/register/blob_desc.proto"; message DecodeRandomKernelConf { required uint32 random_seed = 1; } message ShapeElemCntKernelConf { repeated int32 axis = 1; } message UserKernelConf { map bn_in_op2blob_desc = 1; } message SyncDynamicResizeKernelConf { required DataType size_data_type = 1; } message BroadcastToCompatibleWithKernelConf { repeated int64 broadcast_axes = 1; } message ImageDecoderRandomCropResizeKernelConf { required int64 seed = 1; required int64 batch_size = 2; } message KernelConf { required DataType data_type = 2; required bool all_blobs_are_static = 6; required DTypeSignature dtype_signature = 7; optional ParallelContext parallel_ctx = 8; optional OpAttribute op_attribute = 9; optional string op_attribute_ref = 10; oneof kernel_type { UserKernelConf user_conf = 100; DecodeRandomKernelConf decode_random_conf = 103; SyncDynamicResizeKernelConf sync_dynamic_resize_conf = 360; ShapeElemCntKernelConf shape_elem_cnt_conf = 412; BroadcastToCompatibleWithKernelConf broadcast_to_compatible_with_conf = 428; ImageDecoderRandomCropResizeKernelConf image_decoder_random_crop_resize_conf = 429; } } ================================================ FILE: oneflow/core/kernel/kernel_context.h ================================================ /* 
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_
#define ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_

#include "oneflow/core/kernel/kernel_observer.h"
#include "oneflow/core/ep/include/stream.h"

namespace oneflow {

class Blob;
class JobDesc;

class KernelState {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KernelState);
  KernelState() = default;
  virtual ~KernelState() = default;
};

class KernelContext : public KernelObserver {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KernelContext);
  KernelContext() = default;
  virtual ~KernelContext() = default;

  virtual ep::Stream* stream() const = 0;

  virtual Blob* BnInOp2Blob(const std::string& bn) const = 0;

  virtual const std::shared_ptr<KernelState>& state() const = 0;
  virtual void set_state(std::shared_ptr<KernelState> state) = 0;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_


================================================
FILE: oneflow/core/kernel/kernel_observer.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_H_
#define ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_H_

#include "oneflow/core/common/util.h"

namespace oneflow {

class Kernel;
class KernelContext;
class Blob;

class KernelObserver {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KernelObserver);
  KernelObserver() = default;
  virtual ~KernelObserver() = default;

  virtual void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) {}
  virtual void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) {}

  virtual void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {}
  virtual void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {}

  virtual void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {}
  virtual void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {}
};

class KernelObserverProvider {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KernelObserverProvider);
  KernelObserverProvider() = default;
  virtual ~KernelObserverProvider() = default;
  virtual KernelObserver* GetKernelObserver() = 0;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_H_


================================================
FILE: oneflow/core/kernel/kernel_registration.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#include "oneflow/core/kernel/kernel_registration.h"
#include "oneflow/core/kernel/kernel.h"

namespace oneflow {

namespace kernel_registration {

namespace {

HashMap<OperatorConf::OpTypeCase, std::vector<KernelRegistryVal>>* MutKernelRegistry() {
  static HashMap<OperatorConf::OpTypeCase, std::vector<KernelRegistryVal>> creators;
  return &creators;
}

}  // namespace

KernelRegistrarBuilder& KernelRegistrarBuilder::SetCreateFn(CreateFn fn) {
  registry_val_.func = fn;
  return *this;
}

KernelRegistrarBuilder& KernelRegistrarBuilder::SetIsMatchedPred(IsMatchedPredicator fn) {
  registry_val_.cons.SetIsMatchedPred(fn);
  return *this;
}

void KernelRegistrarBuilder::Finalize(OperatorConf::OpTypeCase* op_type,
                                      KernelRegistryVal* val) const {
  *op_type = op_type_;
  val->func = registry_val_.func;
  val->cons = registry_val_.cons;
}

KernelRegistrar::KernelRegistrar(const KernelRegistrarBuilder& builder) {
  auto* creators = MutKernelRegistry();
  OperatorConf::OpTypeCase op_type;
  KernelRegistryVal val;
  builder.Finalize(&op_type, &val);
  (*creators)[op_type].emplace_back(std::move(val));
}

Kernel* CreateKernel(const KernelConf& kernel_conf) {
  auto op_type = kernel_conf.op_attribute().op_conf().op_type_case();
  auto kernel_registry = MutKernelRegistry();
  if (kernel_registry->find(op_type) == kernel_registry->end()) { return nullptr; }
  const auto& registry_vals = kernel_registry->at(op_type);
  Kernel* ret = nullptr;
  bool is_matched = false;
  for (const KernelRegistryVal& val : registry_vals) {
    if (val.cons.IsMatched(kernel_conf)) {
      CHECK(!is_matched)
          << "There are more than one kernel constraints satisfied by kernel conf of "
          << static_cast<int>(op_type);
      is_matched = true;
      ret = val.func();
    }
  }
  // TODO: print more info when failed
  return ret;
}

}  // namespace kernel_registration

}  // namespace oneflow


================================================
FILE: oneflow/core/kernel/kernel_registration.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_CORE_KERNEL_KERNEL_REGISTRATION_H_
#define ONEFLOW_CORE_KERNEL_KERNEL_REGISTRATION_H_

#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/device_type.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.pb.h"
#include "oneflow/core/operator/op_conf_util.h"

namespace oneflow {

class Kernel;

namespace kernel_registration {

using CreateFn = std::function<Kernel*()>;
using IsMatchedPredicator = std::function<bool(const KernelConf&)>;

class KernelConstraint final {
 public:
  KernelConstraint() = default;
  ~KernelConstraint() = default;

  bool IsMatched(const KernelConf& conf) const { return predicator_(conf); }
  void SetIsMatchedPred(IsMatchedPredicator pred) { predicator_ = pred; }

 private:
  IsMatchedPredicator predicator_;
};

struct KernelRegistryVal final {
  KernelRegistryVal() : func(), cons() {}

  CreateFn func;
  KernelConstraint cons;
};

class KernelRegistrarBuilder final {
 public:
  explicit KernelRegistrarBuilder(OperatorConf::OpTypeCase op_type)
      : op_type_(op_type), registry_val_() {}
  KernelRegistrarBuilder& SetCreateFn(CreateFn fn);
  KernelRegistrarBuilder& SetIsMatchedPred(IsMatchedPredicator fn);

  void Finalize(OperatorConf::OpTypeCase* op_type, KernelRegistryVal* val) const;

 private:
  OperatorConf::OpTypeCase op_type_;
  KernelRegistryVal registry_val_;
};

struct KernelRegistrar final {
  KernelRegistrar(const KernelRegistrarBuilder&);
};

Kernel* CreateKernel(const KernelConf& kernel_conf);

}  // namespace kernel_registration

#define NEW_REGISTER_KERNEL(op_type, ...)                                            \
  static kernel_registration::KernelRegistrar OF_PP_CAT(g_registrar, __COUNTER__) = \
      kernel_registration::KernelRegistrarBuilder(op_type).SetCreateFn(             \
          []() { return new __VA_ARGS__(); })

#define REGISTER_KERNEL_WITH_NOTHING(op_type, ...)                                            \
  NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf&) -> bool { \
    return true;                                                                              \
  });

#define REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, device, dtype, ...)                        \
  NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf& conf) -> bool { \
    return (*CHECK_JUST(DeviceTag4DeviceType(device))                                             \
            == conf.op_attribute().op_conf().device_tag())                                        \
           && (GetDataType<dtype>::value == conf.data_type());                                    \
  });

#define REGISTER_KERNEL_WITH_DEVICE(op_type, device, ...)                                          \
  NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf& conf) -> bool { \
    return (*CHECK_JUST(DeviceTag4DeviceType(device))                                             \
            == conf.op_attribute().op_conf().device_tag());                                       \
  });

#define REGISTER_KERNEL_HELPER_CPU_FLOATING(op_type, kernel)                       \
  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCPU, float, kernel)  \
  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCPU, double, kernel)

#define REGISTER_KERNEL_HELPER_CUDA_FLOATING(op_type, kernel)                       \
  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCUDA, float, kernel)  \
  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCUDA, double, kernel)

#define REGISTER_KERNEL_HELPER_CUDA_HALF(op_type, kernel) \
  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCUDA, float16, kernel)

}  // namespace oneflow

#endif  // ONEFLOW_CORE_KERNEL_KERNEL_REGISTRATION_H_


================================================
FILE: oneflow/core/kernel/kernel_util.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/register/register_manager.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { void AutoMemcpy(ep::Stream* stream, void* dst, const void* src, size_t sz, const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case) { ep::primitive::MemcpyKind kind{}; if (stream->device_type() == DeviceType::kCPU) { CHECK(memory::IsHostMem(src_mem_case)); if (dst_mem_case.device_type() != DeviceType::kMeta) { CHECK(memory::IsHostMem(dst_mem_case)); } kind = ep::primitive::MemcpyKind::kDtoD; } else { if (memory::IsHostMem(src_mem_case)) { CHECK(!memory::IsHostMem(dst_mem_case)); kind = ep::primitive::MemcpyKind::kHtoD; } else if (memory::IsHostMem(dst_mem_case)) { CHECK(!memory::IsHostMem(src_mem_case)); kind = ep::primitive::MemcpyKind::kDtoH; } else { kind = ep::primitive::MemcpyKind::kDtoD; } } std::unique_ptr primitive = ep::primitive::NewPrimitive(stream->device_type(), kind); CHECK(primitive); primitive->Launch(stream, dst, src, sz); } void AutoMemcpy(ep::Stream* stream, Blob* dst, const Blob* src) { const size_t body_bytes = src->ByteSizeOfBlobBody(); CHECK_EQ(dst->ByteSizeOfBlobBody(), body_bytes); AutoMemcpy(stream, dst->mut_dptr(), src->dptr(), body_bytes, dst->mem_case(), src->mem_case()); } void SyncAutoMemcpy(ep::Stream* stream, void* dst, const void* src, size_t sz, const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case) { AutoMemcpy(stream, dst, src, sz, dst_mem_case, src_mem_case); CHECK_JUST(stream->Sync()); } void AutoMemset(ep::Stream* stream, void* dst, const char value, size_t sz, const MemoryCase& /*dst_mem_case*/) { std::unique_ptr primitive = ep::primitive::NewPrimitive(stream->device_type()); primitive->Launch(stream, dst, value, sz); } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/kernel_util.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
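
A minimal usage sketch for the memcpy/memset helpers in kernel_util.cpp above; the blob names and the calling context are illustrative, not taken from the repository:

// Inside some Kernel::ForwardDataContent(KernelContext* ctx):
const Blob* in = ctx->BnInOp2Blob("in");   // source blob
Blob* out = ctx->BnInOp2Blob("out");       // destination blob
// Copies in's body into out; the HtoD/DtoH/DtoD kind is derived from the blobs' MemoryCase.
AutoMemcpy(ctx->stream(), out, in);
// Zero-fills out's body; note that AutoMemset currently ignores its MemoryCase argument.
AutoMemset(ctx->stream(), out->mut_dptr(), 0, out->ByteSizeOfBlobBody(), out->mem_case());
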
*/ #ifndef ONEFLOW_CORE_KERNEL_KERNEL_UTIL_CUH_ #define ONEFLOW_CORE_KERNEL_KERNEL_UTIL_CUH_ #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/device/cuda_pseudo_half.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { template::value>::type* = nullptr> OF_DEVICE_FUNC T MaxWithLogThreshold(T x) { const T threshold = 1e-20; return x > threshold ? x : threshold; } template::value>::type* = nullptr> OF_DEVICE_FUNC T MaxWithLogThreshold(T x) { return x; } #if defined(__CUDACC__) __device__ __forceinline__ half MaxWithLogThreshold(half x) { half threshold = hexp2(__float2half(-14.0)); if (__hgt(x, threshold)) { return x; } return threshold; } #endif template OF_DEVICE_FUNC T SafeLog(T x) { return logf(MaxWithLogThreshold(x)); } #if defined(__CUDACC__) __device__ __forceinline__ half SafeLog(half x) { return hlog(MaxWithLogThreshold(x)); } #endif } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_KERNEL_UTIL_CUH_ ================================================ FILE: oneflow/core/kernel/kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_KERNEL_UTIL_H_ #define ONEFLOW_CORE_KERNEL_KERNEL_UTIL_H_ #include "oneflow/core/common/blas.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/kernel/kernel_context.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { class Blob; class MemoryCase; class StreamContext; void AutoMemcpy(ep::Stream* stream, void* dst, const void* src, size_t sz, const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case); void AutoMemcpy(ep::Stream* stream, Blob* dst, const Blob* src); void SyncAutoMemcpy(ep::Stream* stream, void* dst, const void* src, size_t sz, const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case); void AutoMemset(ep::Stream* stream, void* dst, const char value, size_t sz, const MemoryCase& dst_mem_case); } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_KERNEL_UTIL_H_ ================================================ FILE: oneflow/core/kernel/learning_rate_schedule_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/graph_scope_vars.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" namespace oneflow { class LearningRateScheduleKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(LearningRateScheduleKernel); LearningRateScheduleKernel() = default; ~LearningRateScheduleKernel() override = default; private: void VirtualKernelInit(KernelContext* ctx) override { if (Singleton::Get()->enable_debug_mode()) { pid_t pid = getpid(); log_stream_ = TeePersistentLogStream::Create(std::to_string(pid) + "-train_step2lr.csv"); (*log_stream_) << "train_step, lr\n"; } if (IsOpenGraphVerboseStepLr()) { print_step_lr_ = true; } } void ForwardDataContent(KernelContext* ctx) const override; bool print_step_lr_ = false; std::unique_ptr log_stream_; }; namespace { double GetDecayedLearningRate(const LearningRateDecayConf& conf, double base_lr, int64_t step); double ConstantLearningRate(double base_lr, double factor, int64_t total_step, int64_t cur_step) { CHECK_GE(total_step, 0); CHECK_GE(factor, 0.0); CHECK_LE(factor, 1.0); if (cur_step < total_step) { return base_lr * factor; } return base_lr; } double LinearLearningRate(double base_lr, double start_factor, double end_factor, int64_t total_step, int64_t cur_step) { CHECK_GE(total_step, 0); CHECK_GE(start_factor, 0.0); CHECK_LE(start_factor, 1.0); CHECK_GE(end_factor, 0.0); CHECK_LE(end_factor, 1.0); double multiplier = end_factor; double c_step_f = float(cur_step); double t_step_f = float(total_step); if (cur_step < total_step) { multiplier = start_factor + (end_factor - start_factor) * (c_step_f / t_step_f); } return base_lr * multiplier; } double ExponentialDecayedLearningRate(const ExponentialDecayConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.decay_batches(), 0); double p = static_cast(cur_batch_num) / static_cast(conf.decay_batches()); if (conf.staircase()) { p = std::floor(p); } return lr * std::pow(conf.decay_rate(), p); } double InverseTimeDecayedLearningRate(const InverseTimeDecayConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.decay_batches(), 0); double p = static_cast(cur_batch_num) / static_cast(conf.decay_batches()); if (conf.staircase()) { p = std::floor(p); } return lr / (1.0 + conf.decay_rate() * p); } double NaturalExpDecayedLearningRate(const NaturalExpDecayConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.decay_batches(), 0); double p = static_cast(cur_batch_num) / static_cast(conf.decay_batches()); if (conf.staircase()) { p = std::floor(p); } return lr * std::exp(-conf.decay_rate() * p); } double PiecewiseConstantLearningRate(const PiecewiseConstantConf& conf, double lr, int64_t cur_batch_num) { const PbRf& boundaries = conf.boundaries(); const PbRf& values = conf.values(); CHECK_EQ(boundaries.size() + 1, values.size()); size_t i = 0; for (; i < boundaries.size(); ++i) { if (cur_batch_num <= boundaries[i]) { break; } } return values[i]; } double PolynomialDecayedLearningRate(const PolynomialDecayConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.decay_batches(), 0); double cur_batch = static_cast(cur_batch_num); double decay_batches = static_cast(conf.decay_batches()); if (conf.cycle()) { if (cur_batch_num == 0) { cur_batch = 1.0; } decay_batches = decay_batches * std::ceil(cur_batch / decay_batches); } else { cur_batch = std::min(cur_batch, decay_batches); } return (lr - conf.end_learning_rate()) * 
std::pow(1.0 - (cur_batch / decay_batches), conf.power()) + conf.end_learning_rate(); } double CosineDecayedLearningRate(const CosineDecayConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.decay_batches(), 0); const double PI = std::atan(1.0) * 4.0; double cur_batch = static_cast(cur_batch_num); double decay_batches = static_cast(conf.decay_batches()); cur_batch = std::min(cur_batch, decay_batches); double cosine_decay = 0.5 * (1.0 + std::cos(PI * cur_batch / decay_batches)); double decayed = (1.0 - conf.alpha()) * cosine_decay + conf.alpha(); return lr * decayed; } double CosineAnnealingDecayedLearningRate(const CosineAnnealingDecayConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.t_max(), 0); if (0 == cur_batch_num) { return lr; } const double PI = std::atan(1.0) * 4.0; const double eta_min = conf.eta_min(); CHECK_LT(eta_min, lr); const double t_max_d = static_cast(conf.t_max()); const double cur_batch_num_d = static_cast(cur_batch_num); return eta_min + (((lr - eta_min) * (1 + std::cos(PI * (cur_batch_num_d / t_max_d)))) / 2); } double LinearCosineDecayedLearningRate(const LinearCosineDecayConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.decay_batches(), 0); const double PI = std::atan(1.0) * 4.0; double cur_batch = static_cast(cur_batch_num); double decay_batches = static_cast(conf.decay_batches()); cur_batch = std::min(cur_batch, decay_batches); double linear_decay = (decay_batches - cur_batch) / decay_batches; double cosine_decay = 0.5 * (1.0 + std::cos(PI * 2.0 * conf.num_periods() * cur_batch / decay_batches)); double decayed = (conf.alpha() + linear_decay) * cosine_decay + conf.beta(); return lr * decayed; } double PiecewiseScalingLearningRate(const PiecewiseScalingConf& conf, double lr, int64_t cur_batch_num) { const PbRf& boundaries = conf.boundaries(); const PbRf& scales = conf.scales(); CHECK_EQ(boundaries.size() + 1, scales.size()); size_t i = 0; for (; i < boundaries.size(); ++i) { if (cur_batch_num <= boundaries[i]) { break; } } return scales[i] * lr; } double StepLearningRate(const StepConf& conf, double lr, int64_t cur_batch_num) { const int64_t step_size = conf.step_size(); CHECK_GE(step_size, 1); const double gamma = conf.gamma(); double cur_batch = static_cast(cur_batch_num); double step = static_cast(step_size); size_t i = std::floor(cur_batch / step); return lr * std::pow(gamma, i); } double MultiStepLearningRate(const MultiStepConf& conf, double lr, int64_t cur_batch_num) { const PbRf& milestones = conf.milestones(); CHECK_GE(milestones.size(), 1); const double gamma = conf.gamma(); size_t i = 0; if (cur_batch_num < milestones[milestones.size() - 1]) { for (; i < milestones.size(); ++i) { if (cur_batch_num < milestones[i]) { break; } } } else { i = milestones.size(); } return lr * std::pow(gamma, i); } double CosineAnnealingWarmRestartsLearningRate(const CosineAnnealingWarmRestartsConf& conf, const double base_lr, const int64_t step) { int64_t epoch_steps = conf.t_initial(); int64_t epoch = step / epoch_steps; int64_t step_in_epoch = step - (epoch_steps * epoch); if (conf.t_mult() > 1) { epoch = static_cast(std::floor( std::log(1 - step / conf.t_initial() * (1 - conf.t_mult())) / std::log(conf.t_mult()))); int64_t interval = std::pow(conf.t_mult(), epoch); epoch_steps = interval * conf.t_initial(); step_in_epoch = step - static_cast(std::floor(static_cast(1 - interval) / (1 - conf.t_mult()) * conf.t_initial())); } double lr = conf.eta_min(); if (conf.restart_limit() == 0 || (conf.restart_limit() > 0 && epoch < 
conf.restart_limit())) { double gamma = std::pow(conf.decay_rate(), epoch); lr = lr + 0.5 * (base_lr * gamma - lr) * (1 + std::cos(M_PI * step_in_epoch / epoch_steps)); } return lr; } double SequentialScheduler(const SequentialSchedulerConf& conf, const double base_lr, const int64_t step) { CHECK_GE(conf.schedulers_size(), 1); CHECK_EQ(conf.milestones_size(), conf.schedulers_size() - 1); CHECK_EQ(conf.interval_rescaling_size(), conf.milestones_size()); int64_t cur_step = step; size_t scheduler_idx = 0; for (size_t i = 0; i < conf.milestones_size(); ++i) { if (step < conf.milestones(i)) { break; } else { if (conf.interval_rescaling(i)) { cur_step = step - conf.milestones(i); } scheduler_idx++; } } return GetDecayedLearningRate(conf.schedulers(scheduler_idx), base_lr, cur_step); } double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) { if (conf.has_exponential_conf()) { return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num); } else if (conf.has_inverse_time_conf()) { return InverseTimeDecayedLearningRate(conf.inverse_time_conf(), lr, cur_batch_num); } else if (conf.has_natural_exp_conf()) { return NaturalExpDecayedLearningRate(conf.natural_exp_conf(), lr, cur_batch_num); } else if (conf.has_piecewise_constant_conf()) { return PiecewiseConstantLearningRate(conf.piecewise_constant_conf(), lr, cur_batch_num); } else if (conf.has_polynomial_conf()) { return PolynomialDecayedLearningRate(conf.polynomial_conf(), lr, cur_batch_num); } else if (conf.has_cosine_conf()) { return CosineDecayedLearningRate(conf.cosine_conf(), lr, cur_batch_num); } else if (conf.has_cosine_annealing_conf()) { return CosineAnnealingDecayedLearningRate(conf.cosine_annealing_conf(), lr, cur_batch_num); } else if (conf.has_linear_cosine_conf()) { return LinearCosineDecayedLearningRate(conf.linear_cosine_conf(), lr, cur_batch_num); } else if (conf.has_piecewise_scaling_conf()) { return PiecewiseScalingLearningRate(conf.piecewise_scaling_conf(), lr, cur_batch_num); } else if (conf.has_step_conf()) { return StepLearningRate(conf.step_conf(), lr, cur_batch_num); } else if (conf.has_multi_step_conf()) { return MultiStepLearningRate(conf.multi_step_conf(), lr, cur_batch_num); } else if (conf.has_constant_lr_conf()) { return ConstantLearningRate(lr, conf.constant_lr_conf().factor(), conf.constant_lr_conf().total_iters(), cur_batch_num); } else if (conf.has_linear_lr_conf()) { return LinearLearningRate(lr, conf.linear_lr_conf().start_factor(), conf.linear_lr_conf().end_factor(), conf.linear_lr_conf().total_iters(), cur_batch_num); } else if (conf.has_cosine_annealing_warm_restarts_conf()) { return CosineAnnealingWarmRestartsLearningRate(conf.cosine_annealing_warm_restarts_conf(), lr, cur_batch_num); } else if (conf.has_sequential_scheduler_conf()) { return SequentialScheduler(conf.sequential_scheduler_conf(), lr, cur_batch_num); } else { UNIMPLEMENTED(); } } } // namespace void LearningRateScheduleKernel::ForwardDataContent(KernelContext* ctx) const { const LearningRateScheduleOpConf& conf = this->op_conf().learning_rate_schedule_conf(); const int64_t train_step = *ctx->BnInOp2Blob("train_step")->dptr(); float learning_rate = conf.learning_rate(); if (conf.has_learning_rate_decay()) { learning_rate = GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, train_step); } // NOTE(lixiang): Set verbose=True will print step and lr. 
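// Illustrative example of GetDecayedLearningRate above (made-up numbers): with
// exponential_conf{decay_batches: 1000, decay_rate: 0.9, staircase: false}, base lr 0.1
// and train_step 2500, p = 2500 / 1000 = 2.5 and the scheduled lr is 0.1 * 0.9^2.5 ≈ 0.0768;
// with staircase: true, p is floored to 2 and the lr becomes 0.1 * 0.81 = 0.081.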
if (unlikely(print_step_lr_)) { std::cout << "Last step " << train_step << " adjusting learning rate to " << learning_rate << std::endl; } *ctx->BnInOp2Blob("out")->mut_dptr() = learning_rate; if (Singleton::Get()->enable_debug_mode()) { (*log_stream_) << std::to_string(train_step) << ", " << std::to_string(learning_rate) << "\n"; log_stream_->Flush(); } } REGISTER_KERNEL(OperatorConf::kLearningRateScheduleConf, LearningRateScheduleKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/register/tensor_slice_copier.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" #include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU) namespace oneflow { class CclSendRecvBoxingKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(CclSendRecvBoxingKernel); CclSendRecvBoxingKernel() = default; ~CclSendRecvBoxingKernel() override = default; const std::vector>& in_tensor_slice_copier_vec() const { return in_tensor_slice_copier_vec_; } const std::vector>& out_tensor_slice_copier_vec() const { return out_tensor_slice_copier_vec_; } const std::vector& send_elem_cnts() const { return send_elem_cnts_; } const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } const bool has_input() const { return has_input_; } const bool has_output() const { return has_output_; } ccl::CclComm ccl_comm() const { return GetOrCreate().ccl_comm; } private: struct Comm { explicit Comm(ccl::CclComm comm) : ccl_comm(comm) {} ccl::CclComm ccl_comm; }; void Init() const { ParallelDesc parallel_desc(parallel_conf_); EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ccl::CclComm ccl_comm = comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name_); ccl_comm_.reset(new Comm(ccl_comm)); } const Comm& GetOrCreate() const { if (!ccl_comm_) { Init(); } return *ccl_comm_; } void VirtualKernelInit(KernelContext* ctx) override; void ForwardDataContent(KernelContext* ctx) const override; std::string stream_name_; ParallelConf parallel_conf_; mutable std::unique_ptr ccl_comm_; bool src_nd_sbp_no_partial_parallel_; std::vector> in_tensor_slice_copier_vec_; std::vector> out_tensor_slice_copier_vec_; std::vector send_elem_cnts_; std::vector recv_elem_cnts_; bool has_input_; bool has_output_; }; void CclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { Blob* buf = ctx->BnInOp2Blob("buf"); ccl::CclComm ccl_comm = this->ccl_comm(); const std::vector& send_elem_cnts = this->send_elem_cnts(); 
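// Layout of the "buf" blob used below (one contiguous allocation, sketched for clarity):
//   [ send segment for rank 0 | ... | send segment for rank parallel_num-1 |
//     recv segment for rank 0 | ... | recv segment for rank parallel_num-1 |
//     optional scratch used by the partial-sum (add) path at the end ]
// send_elem_cnts / recv_elem_cnts hold each segment's element count; offsets below are in bytes.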
const std::vector& recv_elem_cnts = this->recv_elem_cnts(); const int64_t parallel_num = this->kernel_conf().parallel_ctx().parallel_num(); const DataType data_type = buf->data_type(); const size_t dtype_size = GetSizeOfDataType(data_type); std::vector send_in_ptr; std::vector recv_out_ptr; std::vector send_offsets; std::vector recv_offsets; char* buf_ptr = buf->mut_dptr(); uint64_t offset = 0; if (this->has_input()) { for (int64_t i = 0; i < parallel_num; ++i) { void* send_ptr = reinterpret_cast(buf_ptr + offset); send_in_ptr.push_back(send_ptr); send_offsets.push_back(offset); offset += send_elem_cnts.at(i) * dtype_size; } } const uint64_t recv_offset = offset; if (this->has_output()) { for (int64_t i = 0; i < parallel_num; ++i) { void* recv_ptr = reinterpret_cast(buf_ptr + offset); recv_out_ptr.push_back(recv_ptr); recv_offsets.push_back(offset - recv_offset); offset += recv_elem_cnts.at(i) * dtype_size; } } if (this->has_input()) { const Blob* in = ctx->BnInOp2Blob("in"); const std::vector>& in_tensor_slice_copier_vec = this->in_tensor_slice_copier_vec(); for (int64_t i = 0; i < parallel_num; ++i) { if (in_tensor_slice_copier_vec.at(i)) { in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr()); } } } if (this->has_input() || this->has_output()) { std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), data_type, data_type, parallel_num); void* send_buf = reinterpret_cast(buf_ptr); void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, this->has_input(), this->has_output()); } if (!this->has_output()) { return; } Blob* out = ctx->BnInOp2Blob("out"); const std::vector>& out_tensor_slice_copier_vec = this->out_tensor_slice_copier_vec(); if (src_nd_sbp_no_partial_parallel_) { for (int64_t i = 0; i < parallel_num; ++i) { if (out_tensor_slice_copier_vec.at(i)) { out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i)); } } } else { std::unique_ptr add_primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type(), out->data_type()); CHECK(add_primitive); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type()); CHECK(memset_primitive); bool is_first_slice = true; for (int64_t i = 0; i < parallel_num; ++i) { if (out_tensor_slice_copier_vec.at(i)) { if (is_first_slice) { is_first_slice = false; if (recv_elem_cnts.at(i) != out->shape().elem_cnt()) { // if not same shape, memset out memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, out->shape().elem_cnt() * dtype_size); } out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i)); } else { if (recv_elem_cnts.at(i) == out->shape().elem_cnt()) { add_primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(), out->shape().elem_cnt()); } else { void* out_buf = reinterpret_cast(buf_ptr + offset); memset_primitive->Launch(ctx->stream(), out_buf, 0, out->shape().elem_cnt() * dtype_size); out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out_buf, recv_out_ptr.at(i)); add_primitive->Launch(ctx->stream(), out->dptr(), out_buf, out->mut_dptr(), out->shape().elem_cnt()); } } } } } } void CclSendRecvBoxingKernel::VirtualKernelInit(KernelContext* ctx) { const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf(); if (this->op_conf().has_stream_name_hint()) { stream_name_ = 
this->op_conf().stream_name_hint(); } else { stream_name_ = EagerCclCommMgr::kDefaultCclStreamName; } parallel_conf_ = conf.parallel_conf(); const int64_t parallel_id = this->kernel_conf().parallel_ctx().parallel_id(); ParallelDesc parallel_desc(parallel_conf_); ParallelDesc src_parallel_desc(conf.src_parallel_conf()); ParallelDesc dst_parallel_desc(conf.dst_parallel_conf()); const NdSbp& src_nd_sbp = conf.src_nd_sbp(); const NdSbp& dst_nd_sbp = conf.dst_nd_sbp(); has_input_ = conf.has_input(); has_output_ = conf.has_output(); src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp); const DataType data_type = this->kernel_conf().data_type(); const DeviceType device_type = parallel_desc.device_type(); const Shape& logical_shape = Shape(conf.logical_shape()); const int64_t parallel_num = parallel_desc.parallel_num(); std::vector src_send_intersections; std::vector dst_recv_intersections; GetRankSendRecvIntersection(parallel_id, parallel_desc, src_parallel_desc, dst_parallel_desc, src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections, &dst_recv_intersections); // if parallel_id exists in src parallel desc, has send int64_t src_parallel_id = GetMappedParallelId(parallel_id, parallel_desc, src_parallel_desc); if (src_parallel_id != -1) { CHECK_EQ(src_send_intersections.size(), parallel_num); send_elem_cnts_.resize(parallel_num); in_tensor_slice_copier_vec_.resize(parallel_num); const TensorSliceView& cur_rank_in_slice = GetTensorSliceView4ParallelId( *src_parallel_desc.hierarchy(), src_nd_sbp, logical_shape, src_parallel_id); for (int64_t i = 0; i < parallel_num; ++i) { const TensorSliceView& intersection = src_send_intersections.at(i); if (!intersection.IsEmpty()) { send_elem_cnts_.at(i) = intersection.shape().elem_cnt(); in_tensor_slice_copier_vec_.at(i).reset( new TensorSliceCopier(intersection, cur_rank_in_slice, data_type, device_type)); } } } else { CHECK_EQ(src_send_intersections.size(), 0); } // if parallel_id exists in src parallel desc, has send int64_t dst_parallel_id = GetMappedParallelId(parallel_id, parallel_desc, dst_parallel_desc); if (dst_parallel_id != -1) { CHECK_EQ(dst_recv_intersections.size(), parallel_num); recv_elem_cnts_.resize(parallel_num); out_tensor_slice_copier_vec_.resize(parallel_num); const TensorSliceView& cur_rank_out_slice = GetTensorSliceView4ParallelId( *dst_parallel_desc.hierarchy(), dst_nd_sbp, logical_shape, dst_parallel_id); for (int64_t i = 0; i < parallel_num; ++i) { const TensorSliceView& intersection = dst_recv_intersections.at(i); if (!intersection.IsEmpty()) { recv_elem_cnts_.at(i) = intersection.shape().elem_cnt(); out_tensor_slice_copier_vec_.at(i).reset( new TensorSliceCopier(cur_rank_out_slice, intersection, data_type, device_type)); } } } else { CHECK_EQ(dst_recv_intersections.size(), 0); } } // TODO: replace all kNcclxxxConf with kCclxxxConf(for multi devices) REGISTER_KERNEL(OperatorConf::kNcclSendRecvBoxingConf, CclSendRecvBoxingKernel); REGISTER_SYSTEM_OP_KERNEL_UNIFIED_CCL_COMM_INIT(OperatorConf::kNcclSendRecvBoxingConf); } // namespace oneflow #endif // WITH_CUDA || WITH_NPU || WITH_MLU ================================================ FILE: oneflow/core/kernel/new_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_NEW_KERNEL_UTIL_H_ #define ONEFLOW_CORE_KERNEL_NEW_KERNEL_UTIL_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memcpy.h" namespace oneflow { namespace ep { class Stream; } template void Memcpy(ep::Stream* stream, void* dst, const void* src, size_t sz) { CHECK_EQ(device_type, stream->device_type()) << "Device type mismatch"; std::unique_ptr primitive = ep::primitive::NewPrimitive(stream->device_type(), ep::primitive::MemcpyKind::kDtoD); CHECK(primitive) << "Can not create Memcpy primitive for device type " << device_type; primitive->Launch(stream, dst, src, sz); } template void Memset(ep::Stream* stream, void* dst, const char value, size_t sz) { CHECK_EQ(device_type, stream->device_type()) << "Device type mismatch"; std::unique_ptr primitive = ep::primitive::NewPrimitive(stream->device_type()); CHECK(primitive) << "Can not create Memset primitive for device type " << device_type; primitive->Launch(stream, dst, value, sz); } } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_NEW_KERNEL_UTIL_H_ ================================================ FILE: oneflow/core/kernel/nop_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" namespace oneflow { class NopKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(NopKernel); NopKernel() = default; ~NopKernel() = default; private: void ForwardDataContent(KernelContext* ctx) const override {} }; REGISTER_KERNEL(OperatorConf::kVariableConf, NopKernel); REGISTER_KERNEL(OperatorConf::kTickConf, NopKernel); REGISTER_KERNEL(OperatorConf::kSinkTickConf, NopKernel); REGISTER_KERNEL(OperatorConf::kAccTickConf, NopKernel); REGISTER_KERNEL(OperatorConf::kCopyCommNetConf, NopKernel); REGISTER_KERNEL(OperatorConf::kDeviceTickConf, NopKernel); REGISTER_KERNEL(OperatorConf::kDstSubsetTickConf, NopKernel); REGISTER_KERNEL(OperatorConf::kSourceTickConf, NopKernel); REGISTER_KERNEL(OperatorConf::kSrcSubsetTickConf, NopKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/output_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/job/global_for.h" namespace oneflow { class OutputKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(OutputKernel); OutputKernel() = default; ~OutputKernel() = default; private: void ForwardDataContent(KernelContext* ctx) const override; void ForwardHeader(KernelContext* ctx) const override; }; void OutputKernel::ForwardDataContent(KernelContext* ctx) const { CHECK(this->op_conf().output_conf().has_job_name()); const auto& job_name = this->op_conf().output_conf().job_name(); const auto& op_name = this->op_conf().name(); auto* buffer_mgr = Singleton>>::Get(); auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name)); std::shared_ptr critical_section_instance; BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance); CHECK_NE(buffer_status, kBufferStatusEmpty); if (buffer_status == kBufferStatusSuccess) { critical_section_instance->AccessBlobByOpName(ctx->stream(), ctx->BnInOp2Blob("in"), op_name); } } void OutputKernel::ForwardHeader(KernelContext* ctx) const { // Do nothing. } REGISTER_KERNEL(OperatorConf::kOutputConf, OutputKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/profiler_kernel_observer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/profiler_kernel_observer.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/profiler/kernel.h" namespace oneflow { void ProfilerKernelObserver::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentStart(kernel_ctx, kernel)); } void ProfilerKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentEnd(kernel_ctx, kernel)); } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/profiler_kernel_observer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_PROFILER_KERNEL_OBSERVER_H_ #define ONEFLOW_CORE_KERNEL_PROFILER_KERNEL_OBSERVER_H_ #include "oneflow/core/kernel/kernel_observer.h" namespace oneflow { class ProfilerKernelObserver final : public KernelObserver { public: OF_DISALLOW_COPY_AND_MOVE(ProfilerKernelObserver); ProfilerKernelObserver() = default; ~ProfilerKernelObserver() override = default; void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_PROFILER_KERNEL_OBSERVER_H_ ================================================ FILE: oneflow/core/kernel/random_generator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/random_generator.h" #include "oneflow/core/common/preprocessor.h" namespace oneflow { template void RandomGenerator::Uniform(const int64_t elem_cnt, T* dptr) { Uniform(elem_cnt, GetZeroVal(), GetOneVal(), dptr); } template void RandomGenerator::Uniform(const int64_t elem_cnt, const T min, const T max, T* dptr) { CHECK_GE(elem_cnt, 0); CHECK(dptr); CHECK_LE(min, max); std::uniform_real_distribution random_distribution(min, std::nextafter(max, GetMaxVal())); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(mt19937_generator_); } } #define INITIATE_CPU_RANDOM_GENERATOR_UNIFORM(T, typeproto) \ template void RandomGenerator::Uniform(const int64_t elem_cnt, T* dptr); \ template void RandomGenerator::Uniform(const int64_t elem_cnt, const T min, \ const T max, T* dptr); OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_RANDOM_GENERATOR_UNIFORM, FLOATING_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/random_generator.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
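
A small usage sketch for the CPU generator implemented in random_generator.cpp above (the seed, buffer size, and null stream are illustrative, and the DeviceType template parameter is assumed from the class declarations in random_generator.h; the CPU specialization ignores its stream argument):

std::vector<float> buf(1024);
RandomGenerator<DeviceType::kCPU> gen(/*seed=*/20200520, /*stream=*/nullptr);
gen.Uniform(static_cast<int64_t>(buf.size()), buf.data());               // uniform in [0, 1]
gen.Uniform(static_cast<int64_t>(buf.size()), -1.0f, 1.0f, buf.data());  // uniform in [-1, 1]
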
*/ #include "oneflow/core/kernel/random_generator.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template void RngUniformGpu(const curandGenerator_t& gen, int64_t n, T* ret); template<> void RngUniformGpu(const curandGenerator_t& gen, int64_t n, float* ret) { OF_CURAND_CHECK(curandGenerateUniform(gen, ret, n)); } template<> void RngUniformGpu(const curandGenerator_t& gen, int64_t n, double* ret) { OF_CURAND_CHECK(curandGenerateUniformDouble(gen, ret, n)); } } // namespace RandomGenerator::RandomGenerator(int64_t seed, ep::Stream* stream) { OF_CURAND_CHECK(curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); OF_CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(curand_generator_, seed)); OF_CURAND_CHECK(curandSetStream(curand_generator_, stream->As()->cuda_stream())); } RandomGenerator::~RandomGenerator() { OF_CURAND_CHECK(curandDestroyGenerator(curand_generator_)); } template void RandomGenerator::Uniform(const int64_t elem_cnt, T* dptr) { RngUniformGpu(curand_generator_, elem_cnt, dptr); } #define INITIATE_CUDA_RANDOM_GENERATOR_UNIFORM(T, typeproto) \ template void RandomGenerator::Uniform(const int64_t elem_cnt, T* dptr); OF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_RANDOM_GENERATOR_UNIFORM, FLOATING_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/random_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_RANDOM_GENERATOR_H_ #define ONEFLOW_CORE_KERNEL_RANDOM_GENERATOR_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/job/resource.pb.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { template class RandomGenerator; template<> class RandomGenerator final { public: OF_DISALLOW_COPY_AND_MOVE(RandomGenerator); RandomGenerator(int64_t seed, ep::Stream* stream) : mt19937_generator_(seed) {} ~RandomGenerator() {} template void Uniform(const int64_t elem_cnt, T* dptr); template void Uniform(const int64_t elem_cnt, const T min, const T max, T* dptr); private: std::mt19937 mt19937_generator_; }; template<> class RandomGenerator final { public: OF_DISALLOW_COPY_AND_MOVE(RandomGenerator); RandomGenerator(int64_t seed, ep::Stream* stream); ~RandomGenerator(); template void Uniform(const int64_t elem_cnt, T* dptr); private: #ifdef WITH_CUDA curandGenerator_t curand_generator_; #endif }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_RANDOM_GENERATOR_H_ ================================================ FILE: oneflow/core/kernel/reentrant_lock_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/reentrant_lock_kernel.h" namespace oneflow { std::string ReentrantLockStatus::kEmptyIbn = "reentrant_lock_status_empty_ibn"; void ReentrantLockStatus::Init(const KernelConf& kernel_conf) { const auto& conf = kernel_conf.op_attribute().op_conf().reentrant_lock_conf(); cur_ibn_ = ""; cur_act_id_ = -1; acquired_lock_to_be_sent_ = false; total_queued_request_lock_num_ = 0; total_acquired_lock_num_ = 0; lock_id2queued_request_act_id_.resize(conf.lock_id2intersecting_lock_ids_size()); lock_id2acquired_num_.resize(conf.lock_id2intersecting_lock_ids_size()); for (const Int64List& ids : conf.lock_id2intersecting_lock_ids()) { lock_id2intersecting_lock_ids_.emplace_back( std::vector(ids.value().begin(), ids.value().end())); } } bool ReentrantLockStatus::TryAcquireLock(int64_t lock_id) { CHECK_EQ(lock_id2queued_request_act_id_.at(lock_id).empty(), false); int64_t act_id = lock_id2queued_request_act_id_.at(lock_id).front(); bool blocked = false; for (int64_t intersect_lock_id : lock_id2intersecting_lock_ids_.at(lock_id)) { if (lock_id2acquired_num_.at(intersect_lock_id) > 0 || (lock_id2queued_request_act_id_.at(intersect_lock_id).empty() == false && lock_id2queued_request_act_id_.at(intersect_lock_id).front() < act_id)) { blocked = true; break; } } if (blocked) { return false; } lock_id2queued_request_act_id_.at(lock_id).pop(); --total_queued_request_lock_num_; ++lock_id2acquired_num_.at(lock_id); ++total_acquired_lock_num_; return true; } void ReentrantLockStatus::RequestLock(int64_t lock_id, std::queue* unlocked_ids) { lock_id2queued_request_act_id_.at(lock_id).push(cur_act_id()); ++total_queued_request_lock_num_; if (TryAcquireLock(lock_id)) { unlocked_ids->push(lock_id); } } void ReentrantLockStatus::ReleaseLock(int64_t lock_id, std::queue* unlocked_ids) { CHECK_GT(lock_id2acquired_num_.at(lock_id), 0); CHECK_GT(total_acquired_lock_num_, 0); --lock_id2acquired_num_.at(lock_id); --total_acquired_lock_num_; size_t unlocked_cnt = 0; do { unlocked_cnt = 0; auto ReleaseRelatedLockId = [&](int64_t related_lock_id) { if (lock_id2queued_request_act_id_.at(related_lock_id).empty()) { return; } if (TryAcquireLock(related_lock_id)) { unlocked_ids->push(related_lock_id); ++unlocked_cnt; } }; ReleaseRelatedLockId(lock_id); for (int64_t id : lock_id2intersecting_lock_ids_.at(lock_id)) { ReleaseRelatedLockId(id); } } while (unlocked_cnt > 0); } template void ReentrantLockKernel::VirtualKernelInit(KernelContext* ctx) { ctx->set_state(std::make_shared()); } template void ReentrantLockKernel::ForwardDataContent(KernelContext* ctx) const { auto* const status = CHECK_NOTNULL(dynamic_cast(ctx->state().get())); if (status->cur_ibn() == "start") { T lock_id = *ctx->BnInOp2Blob("start")->dptr(); status->RequestLock(lock_id, status->mut_cur_unlocked_ids()); } else if (status->cur_ibn() == "end") { status->ReleaseLock(*ctx->BnInOp2Blob("end")->dptr(), status->mut_cur_unlocked_ids()); } else { CHECK_EQ(status->cur_ibn(), ReentrantLockStatus::kEmptyIbn); } if (status->cur_unlocked_ids().size() > 0) { T lock_id = status->cur_unlocked_ids().front(); status->mut_cur_unlocked_ids()->pop(); 
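// The lock id popped above is written to "out" below; the consumer of "out" is expected to
// feed the same id back through "end" once its critical section is done, which triggers
// ReleaseLock and may unlock queued requests.
// Illustrative trace: request lock 1 -> acquired, out = 1; request lock 2 while the
// intersecting lock 1 is still held -> queued; release lock 1 -> TryAcquireLock(2) succeeds
// and 2 is sent on "out".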
*ctx->BnInOp2Blob("out")->mut_dptr() = lock_id; status->set_acquired_lock_to_be_sent(true); } else { status->set_acquired_lock_to_be_sent(false); } } ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kReentrantLockConf, ReentrantLockKernel, INT_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/core/kernel/reentrant_lock_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_REENTRANT_LOCK_KERNEL_H_ #define ONEFLOW_CORE_KERNEL_REENTRANT_LOCK_KERNEL_H_ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/graph/graph.h" namespace oneflow { class ReentrantLockStatus final : public KernelState { public: OF_DISALLOW_COPY_AND_MOVE(ReentrantLockStatus); ReentrantLockStatus() = default; ~ReentrantLockStatus() = default; void Init(const KernelConf& kernel_conf); static std::string kEmptyIbn; // true: success // false: failed void RequestLock(int64_t lock_id, std::queue* unlocked_ids); // return lock_id if any other lock acquired // -1: no other lock acquired void ReleaseLock(int64_t lock_id, std::queue* unlocked_ids); const std::queue& cur_unlocked_ids() const { return cur_unlocked_ids_; } std::queue* mut_cur_unlocked_ids() { return &cur_unlocked_ids_; } // Getters const std::string& cur_ibn() const { return cur_ibn_; } int64_t cur_act_id() const { return cur_act_id_; } bool acquired_lock_to_be_sent() const { return acquired_lock_to_be_sent_; } size_t total_queued_request_lock_num() const { return total_queued_request_lock_num_; } size_t total_acquired_lock_num() const { return total_acquired_lock_num_; } // Setters void set_cur_ibn(const std::string& ibn) { cur_ibn_ = ibn; } void set_cur_act_id(int64_t act_id) { cur_act_id_ = act_id; } void set_acquired_lock_to_be_sent(bool val) { acquired_lock_to_be_sent_ = val; } private: // true: success // false: failed bool TryAcquireLock(int64_t lock_id); std::string cur_ibn_; int64_t cur_act_id_{}; bool acquired_lock_to_be_sent_{}; size_t total_queued_request_lock_num_{}; size_t total_acquired_lock_num_{}; std::vector> lock_id2queued_request_act_id_; std::vector lock_id2acquired_num_; std::vector> lock_id2intersecting_lock_ids_; std::queue cur_unlocked_ids_; }; template class ReentrantLockKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(ReentrantLockKernel); ReentrantLockKernel() = default; ~ReentrantLockKernel() override = default; private: void VirtualKernelInit(KernelContext* ctx) override; void ForwardDataContent(KernelContext* ctx) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_REENTRANT_LOCK_KERNEL_H_ ================================================ FILE: oneflow/core/kernel/return_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/job/global_for.h" namespace oneflow { class ReturnKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(ReturnKernel); ReturnKernel() = default; ~ReturnKernel() = default; private: void ForwardDataContent(KernelContext* ctx) const override; void ForwardHeader(KernelContext* ctx) const override; }; void ReturnKernel::ForwardDataContent(KernelContext* ctx) const { CHECK(this->op_conf().return_conf().has_job_name()); const auto& job_name = this->op_conf().return_conf().job_name(); const auto& op_name = this->op_conf().name(); auto* buffer_mgr = Singleton>>::Get(); auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name)); std::shared_ptr critical_section_instance; BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance); CHECK_NE(buffer_status, kBufferStatusEmpty); if (buffer_status == kBufferStatusSuccess) { critical_section_instance->AccessBlobByOpName(ctx->stream(), ctx->BnInOp2Blob("in"), op_name); } } void ReturnKernel::ForwardHeader(KernelContext* ctx) const { // Do nothing. } REGISTER_KERNEL(OperatorConf::kReturnConf, ReturnKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/runtime_blob_shape_infer_helper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/kernel/runtime_blob_shape_infer_helper.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/common/cached_caller.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" namespace oneflow { RuntimeBlobShapeInferHelper::RuntimeBlobShapeInferHelper(const OperatorConf& op_conf, const KernelConf& kernel_conf, const void* scope) { op_ = CHECK_JUST(ConstructOp(op_conf)); const OpAttribute& op_attribute = kernel_conf.op_attribute(); if (op_attribute.has_parallel_conf_signature() && op_attribute.parallel_conf_signature().has_op_parallel_conf()) { CHECK_JUST(op_->FillOpParallelDesc( ParallelDesc(op_attribute.parallel_conf_signature().op_parallel_conf()))); } if (op_attribute.has_sbp_signature()) { sbp_signature_.reset(new SbpSignature(op_attribute.sbp_signature())); CHECK_JUST(op_->FillSbpSignature(*sbp_signature_)); } op_->ForEachBnInOp([&](const std::string& bn_in_op) { bn_in_op2blob_desc_[bn_in_op].reset(); }); if (op_attribute.has_logical_blob_desc_signature()) { HashMap> bn_in_op2logical_blob_desc; const auto& blob_desc_signature_map = op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc(); for (const auto& pair : blob_desc_signature_map) { bn_in_op2logical_blob_desc[pair.first].reset(new BlobDesc(pair.second)); } auto GetLogicalBlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* { if (bn_in_op2logical_blob_desc.find(bn) != bn_in_op2logical_blob_desc.end()) { return bn_in_op2logical_blob_desc.at(bn).get(); } return nullptr; }; CHECK_JUST(op_->FillLogicalInBlobDesc(GetLogicalBlobDesc4BnInOp)); CHECK_JUST(op_->FillLogicalOutBlobDesc(GetLogicalBlobDesc4BnInOp)); } if (kernel_conf.has_parallel_ctx()) { parallel_ctx_.reset(new ParallelContext(kernel_conf.parallel_ctx())); } op_infer_cache_key_.scope = scope; op_infer_cache_key_.op_conf_sym = op_->GetOpConfWithoutOpNameAndLbn(); op_infer_cache_key_.ibn_idx2shape_sym.resize(op_->input_bns().size()); op_infer_cache_key_.dtype_signature_sym = SymbolOf(kernel_conf.dtype_signature()); } void RuntimeBlobShapeInferHelper::UpdateInputBlobDescs7OpInferCacheKey( std::function BnInOp2Blob) { auto ResetBlobDescAndGetShapeSym = [&](const std::string& ibn) -> Symbol { const Blob* blob = BnInOp2Blob(ibn); if (blob == nullptr) { return Symbol(); } BlobDesc* blob_desc = BlobDesc4BnInOp(ibn, blob->blob_desc()); Shape blob_shape = blob_desc->shape(); blob_shape.LeftOnesExtendedAssign(blob->shape()); blob_desc->set_shape(blob_shape); Stride blob_stride = blob_desc->stride(); blob_stride.CheckNumAxesIdenticalAndAssign(blob->stride()); blob_desc->set_stride(blob_stride); return SymbolOf(blob_desc->shape()); }; const auto& input_bns = op_->input_bns(); FOR_RANGE(int, i, 0, input_bns.size()) { op_infer_cache_key_.ibn_idx2shape_sym.at(i) = ResetBlobDescAndGetShapeSym(input_bns.Get(i)); } } BlobDesc* RuntimeBlobShapeInferHelper::BlobDesc4BnInOp(const std::string& bn_in_op, const BlobDesc& blob_desc) { auto it = bn_in_op2blob_desc_.find(bn_in_op); if (it == bn_in_op2blob_desc_.end()) { return nullptr; } if (!it->second) { it->second.reset(new BlobDesc(blob_desc)); } return it->second.get(); } void RuntimeBlobShapeInferHelper::InferShape( const std::function& BnInOp2Blob) { UpdateInputBlobDescs7OpInferCacheKey(BnInOp2Blob); auto Infer = [&](const OpInferCacheKey& key) -> std::shared_ptr { auto CachedBlobDesc4BnInOp = WithResultCached([&](const std::string& bn_in_op) -> BlobDesc* { const Blob* blob = BnInOp2Blob(bn_in_op); if (blob == nullptr) { return nullptr; } return 
BlobDesc4BnInOp(bn_in_op, blob->blob_desc()); }); CHECK_JUST(op_->InferOutBlobDescsIf(CachedBlobDesc4BnInOp, parallel_ctx_.get())); auto* ret = new OpInferCacheValue(); ret->obn_idx2shape_sym.resize(op_->output_bns().size()); FOR_RANGE(int, i, 0, op_->output_bns().size()) { const auto& obn = op_->output_bns().Get(i); const auto& blob_desc = bn_in_op2blob_desc_.at(obn); ret->obn_idx2shape_sym.at(i).reset(blob_desc->shape()); auto* blob = BnInOp2Blob(obn); if (blob == nullptr) { continue; } CHECK_EQ(blob->data_type(), blob_desc->data_type()); CHECK_EQ(blob->blob_desc().is_dynamic(), blob_desc->is_dynamic()); } return std::shared_ptr(ret); }; size_t cache_size = Singleton::Get()->thread_local_cache_max_size(); const auto& shape_infer_ret = ThreadLocalCachedCall(cache_size, Infer, op_infer_cache_key_); const auto& obn_idx2shape_sym = shape_infer_ret->obn_idx2shape_sym; FOR_RANGE(int, i, 0, op_->output_bns().size()) { const auto& obn = op_->output_bns().Get(i); auto* blob = BnInOp2Blob(obn); if (blob == nullptr) { continue; } if (blob->blob_desc().is_dynamic()) { blob->mut_shape_view()->set_shape(*obn_idx2shape_sym.at(i)); } else { CHECK(*obn_idx2shape_sym.at(i) == blob->static_shape()); } } } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/runtime_blob_shape_infer_helper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_RUNTIME_BLOB_SHAPE_INFER_HELPER_H_ #define ONEFLOW_CORE_KERNEL_RUNTIME_BLOB_SHAPE_INFER_HELPER_H_ #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/op_infer_cache.h" namespace oneflow { class Blob; class BlobDesc; class RuntimeBlobShapeInferHelper final { public: RuntimeBlobShapeInferHelper(const OperatorConf& op_conf, const KernelConf& kernel_conf, const void* scope); ~RuntimeBlobShapeInferHelper() = default; void InferShape(const std::function& BnInOp2Blob); private: void UpdateInputBlobDescs7OpInferCacheKey(std::function BnInOp2Blob); BlobDesc* BlobDesc4BnInOp(const std::string& bn_in_op, const BlobDesc& rt_blob_desc); std::shared_ptr op_; HashSet ibns_; HashMap> bn_in_op2blob_desc_; std::unique_ptr parallel_ctx_; std::unique_ptr sbp_signature_; OpInferCacheKey op_infer_cache_key_; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_RUNTIME_BLOB_SHAPE_INFER_HELPER_H_ ================================================ FILE: oneflow/core/kernel/shape_elem_cnt_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/ep/include/primitive/fill.h"

namespace oneflow {

template<typename T>
class ShapeElemCntKernel final : public Kernel {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ShapeElemCntKernel);
  ShapeElemCntKernel() = default;
  ~ShapeElemCntKernel() override = default;

 private:
  void ForwardDataContent(KernelContext* ctx) const override;
  int32_t GetShapePartialElemCnt(const ShapeView& shape) const;
};

template<typename T>
void ShapeElemCntKernel<T>::ForwardDataContent(KernelContext* ctx) const {
  const T elem_cnt = GetShapePartialElemCnt(ctx->BnInOp2Blob("x")->shape());
  std::unique_ptr<ep::primitive::Fill> fill = ep::primitive::NewPrimitive<ep::primitive::FillFactory>(
      ctx->stream()->device_type(), ctx->BnInOp2Blob("y")->data_type());
  CHECK(fill);
  fill->Launch(ctx->stream(), ctx->BnInOp2Blob("y")->mut_dptr(), elem_cnt, 1);
}

template<typename T>
int32_t ShapeElemCntKernel<T>::GetShapePartialElemCnt(const ShapeView& shape) const {
  int32_t ret = 1;
  for (int32_t axis : this->kernel_conf().shape_elem_cnt_conf().axis()) { ret *= shape.At(axis); }
  return ret;
}

ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kShapeElemCntConf, ShapeElemCntKernel,
                           ARITHMETIC_DATA_TYPE_SEQ);

}  // namespace oneflow


================================================
FILE: oneflow/core/kernel/slice_boxing_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/register/tensor_slice_copier.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { class SliceBoxingKernel : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingKernel); SliceBoxingKernel() = default; ~SliceBoxingKernel() override = default; protected: virtual const SliceBoxingConf& GetCustomizedBoxingConf() const = 0; const std::vector>& tensor_slice_copier_vec() const; private: void VirtualKernelInit(KernelContext* ctx) override; std::vector> tensor_slice_copier_vec_; }; class SliceBoxingCopyKernel final : public SliceBoxingKernel { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingCopyKernel); SliceBoxingCopyKernel() = default; ~SliceBoxingCopyKernel() override = default; private: virtual const SliceBoxingConf& GetCustomizedBoxingConf() const override; void ForwardDataContent(KernelContext* ctx) const override; }; class SliceBoxingAddKernel final : public SliceBoxingKernel { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingAddKernel); SliceBoxingAddKernel() = default; ~SliceBoxingAddKernel() override = default; private: virtual const SliceBoxingConf& GetCustomizedBoxingConf() const override; void ForwardDataContent(KernelContext* ctx) const override; }; void SliceBoxingKernel::VirtualKernelInit(KernelContext* ctx) { const SliceBoxingConf& conf = GetCustomizedBoxingConf(); if (/*is_0size_tensor=*/std::any_of(conf.out_shape().dim().begin(), conf.out_shape().dim().end(), [](int64_t dim) { return dim == 0; })) { return; } const TensorSliceView out_slice(conf.out_slice()); for (const TensorSliceViewProto& in_slice_proto : conf.in_slice()) { const TensorSliceView in_slice(in_slice_proto); tensor_slice_copier_vec_.emplace_back(new TensorSliceCopier( out_slice, in_slice, this->kernel_conf().data_type(), ctx->stream()->device_type())); } } const std::vector>& SliceBoxingKernel::tensor_slice_copier_vec() const { return tensor_slice_copier_vec_; } const SliceBoxingConf& SliceBoxingCopyKernel::GetCustomizedBoxingConf() const { return this->op_conf().slice_boxing_copy_conf().slice_boxing_conf(); } void SliceBoxingCopyKernel::ForwardDataContent(KernelContext* ctx) const { Blob* out = ctx->BnInOp2Blob("out"); if (out->shape_view().elem_cnt() == 0) { return; } FOR_RANGE(int64_t, i, 0, this->op_attribute().input_bns().size()) { const Blob* in_i = ctx->BnInOp2Blob(GenRepeatedBn("in", i)); this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), out, in_i); } } const SliceBoxingConf& SliceBoxingAddKernel::GetCustomizedBoxingConf() const { return this->op_conf().slice_boxing_add_conf().slice_boxing_conf(); } void SliceBoxingAddKernel::ForwardDataContent(KernelContext* ctx) const { Blob* out = ctx->BnInOp2Blob("out"); if (out->shape_view().elem_cnt() == 0) { return; } std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type(), out->data_type()); CHECK(primitive); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type()); CHECK(memset_primitive); FOR_RANGE(int64_t, i, 0, this->op_attribute().input_bns().size()) { const Blob* in_i = ctx->BnInOp2Blob(GenRepeatedBn("in", i)); if (i == 0) { if (in_i->shape().NumAxes() == 0 && out->shape().NumAxes() == 0) { AutoMemcpy(ctx->stream(), out, in_i); } else { this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), out, in_i); } } else { if (in_i->shape() == 
out->shape()) { primitive->Launch(ctx->stream(), out->dptr(), in_i->dptr(), out->mut_dptr(), out->shape().elem_cnt()); } else { Blob* buf = ctx->BnInOp2Blob("buf"); memset_primitive->Launch(ctx->stream(), buf->mut_dptr(), 0, buf->shape().elem_cnt() * GetSizeOfDataType(buf->data_type())); this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), buf, in_i); primitive->Launch(ctx->stream(), out->dptr(), buf->dptr(), out->mut_dptr(), out->shape().elem_cnt()); } } } } REGISTER_KERNEL(OperatorConf::kSliceBoxingCopyConf, SliceBoxingCopyKernel); REGISTER_KERNEL(OperatorConf::kSliceBoxingAddConf, SliceBoxingAddKernel); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/sync_check_kernel_observer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/sync_check_kernel_observer.h" #include "oneflow/core/kernel/kernel.h" namespace oneflow { void SyncCheckKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { CHECK_JUST_MSG(kernel_ctx->stream()->Sync(), kernel->op_conf().name()); } } // namespace oneflow ================================================ FILE: oneflow/core/kernel/sync_check_kernel_observer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_SYNC_CHECK_KERNEL_OBSERVER_H_ #define ONEFLOW_CORE_KERNEL_SYNC_CHECK_KERNEL_OBSERVER_H_ #include "oneflow/core/kernel/kernel_observer.h" namespace oneflow { class SyncCheckKernelObserver final : public KernelObserver { public: OF_DISALLOW_COPY_AND_MOVE(SyncCheckKernelObserver); SyncCheckKernelObserver() = default; ~SyncCheckKernelObserver() override = default; void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_SYNC_CHECK_KERNEL_OBSERVER_H_ ================================================ FILE: oneflow/core/kernel/sync_dynamic_resize_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/register/register_desc.h"
#include "oneflow/core/lazy/actor/actor_context.h"
#include "oneflow/core/memory/memory_case_util.h"
#include <memory>
#include <mutex>
#include <queue>

namespace oneflow {

#ifdef WITH_CUDA

namespace {

class CudaHostMem {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaHostMem);
  CudaHostMem(const size_t size) { OF_CUDA_CHECK(cudaMallocHost(&ptr_, size)); }
  ~CudaHostMem() { OF_CUDA_CHECK(cudaFreeHost(ptr_)); }
  void* Ptr() const { return ptr_; }

 private:
  void* ptr_;
};

}  // namespace

template<typename SizeType>
class SyncDynamicResizeGPUKernel final : public Kernel {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SyncDynamicResizeGPUKernel);
  SyncDynamicResizeGPUKernel() = default;
  ~SyncDynamicResizeGPUKernel() override = default;

 private:
  bool IsKernelLaunchSynchronized() const override { return false; }

  void ForwardDataContent(KernelContext* ctx) const override {
    const SyncDynamicResizeOpConf& conf = this->op_conf().sync_dynamic_resize_conf();
    CHECK_EQ(conf.axis(), 0);
    std::shared_ptr<CudaHostMem> cuda_host_mem_ptr;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (queue_.empty()) {
        cuda_host_mem_ptr.reset(new CudaHostMem(sizeof(SizeType)));
      } else {
        cuda_host_mem_ptr = queue_.front();
        queue_.pop();
      }
    }
    const Blob* in = ctx->BnInOp2Blob("in");
    const Blob* size = ctx->BnInOp2Blob("size");
    Blob* out = ctx->BnInOp2Blob("out");
    AutoMemcpy(ctx->stream(), out->mut_dptr(), in->dptr(), in->ByteSizeOfBlobBody(),
               out->mem_case(), in->mem_case());
    AutoMemcpy(ctx->stream(), cuda_host_mem_ptr->Ptr(), size->dptr(), sizeof(SizeType),
               memory::MakeHostMemCase(), size->mem_case());
    const auto& UpdateShape = [out, cuda_host_mem_ptr, conf, this]() {
      const int64_t new_size = *reinterpret_cast<SizeType*>(cuda_host_mem_ptr->Ptr());
      CHECK_GE(new_size, 0);
      CHECK_LE(new_size, out->shape_view().At(conf.axis()));
      // NOTE(Liang Depeng): `mut_shape_view` should be used here to get the blob's `MutShapeView`
      // pointer. But this callback is called after the `Kernel::Forward` function's execution and
      // the header check has already been set to false at that moment. So we have to choose the
      // `ForceMutShapeView` function with the header checker disabled.
      out->ForceMutShapeView()->Set(conf.axis(), new_size);
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push(cuda_host_mem_ptr);
    };
    if (conf.eager()) {
      CHECK_JUST(ctx->stream()->Sync());
      UpdateShape();
    } else {
      auto* actor_context_provider = CHECK_NOTNULL(dynamic_cast<ActorContextProvider*>(ctx));
      actor_context_provider->GetActorContext()->AddCallback(UpdateShape);
    }
  }

  mutable std::queue<std::shared_ptr<CudaHostMem>> queue_;
  mutable std::mutex mutex_;
};

#define REGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(stype)                                          \
  NEW_REGISTER_KERNEL(OperatorConf::kSyncDynamicResizeConf, SyncDynamicResizeGPUKernel<stype>)  \
      .SetIsMatchedPred([](const KernelConf& kernel_conf) {                                     \
        return (kernel_conf.op_attribute().op_conf().device_tag() == "cuda"                     \
                && GetDataType<stype>::value                                                    \
                       == kernel_conf.sync_dynamic_resize_conf().size_data_type());             \
      })

REGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(int8_t);
REGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(int32_t);
REGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(int64_t);

#endif  // WITH_CUDA

template<typename SizeType>
class SyncDynamicResizeCPUKernel final : public Kernel {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SyncDynamicResizeCPUKernel);
  SyncDynamicResizeCPUKernel() = default;
  ~SyncDynamicResizeCPUKernel() override = default;

 private:
  bool IsKernelLaunchSynchronized() const override { return false; }

  void ForwardDataContent(KernelContext* ctx) const override {
    const SyncDynamicResizeOpConf& conf = this->op_conf().sync_dynamic_resize_conf();
    CHECK_EQ(conf.axis(), 0);
    const Blob* in = ctx->BnInOp2Blob("in");
    const Blob* size = ctx->BnInOp2Blob("size");
    Blob* out = ctx->BnInOp2Blob("out");
    AutoMemcpy(ctx->stream(), out->mut_dptr(), in->dptr(), in->ByteSizeOfBlobBody(),
               out->mem_case(), in->mem_case());
    const SizeType new_size = *size->dptr<SizeType>();
    CHECK_GE(new_size, 0);
    CHECK_LE(new_size, out->shape_view().At(conf.axis()));
    out->mut_shape_view()->Set(conf.axis(), new_size);
  }
};

#define REGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(stype)                                          \
  NEW_REGISTER_KERNEL(OperatorConf::kSyncDynamicResizeConf, SyncDynamicResizeCPUKernel<stype>)  \
      .SetIsMatchedPred([](const KernelConf& kernel_conf) {                                     \
        return (kernel_conf.op_attribute().op_conf().device_tag() == "cpu"                      \
                && GetDataType<stype>::value                                                    \
                       == kernel_conf.sync_dynamic_resize_conf().size_data_type());             \
      })

REGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(int8_t);
REGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(int32_t);
REGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(int64_t);

}  // namespace oneflow


================================================
FILE: oneflow/core/kernel/total_loss_instance_num_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/kernel/kernel.h" namespace oneflow { template class TotalLossInstanceNumKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(TotalLossInstanceNumKernel); TotalLossInstanceNumKernel() = default; ~TotalLossInstanceNumKernel() override = default; private: void ForwardDataContent(KernelContext* ctx) const override; }; template void TotalLossInstanceNumKernel::ForwardDataContent(KernelContext* ctx) const { const auto& input_bns = this->op_attribute().input_bns(); T first_val = ctx->BnInOp2Blob(input_bns.Get(0))->template dptr()[0]; for (const std::string& ibn : input_bns) { CHECK_EQ(ctx->BnInOp2Blob(ibn)->template dptr()[0], first_val); } ctx->BnInOp2Blob("out")->template mut_dptr()[0] = first_val; } ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kTotalLossInstanceNumConf, TotalLossInstanceNumKernel, ARITHMETIC_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/user_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/user_kernel.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/op_kernel_infer_cache.h" #include "oneflow/core/framework/user_op_tensor.h" #include "oneflow/core/kernel/blob_tensor_view.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { bool IsAllBlobEmpty(const PbRpf& bns, const std::function& BnInOp2Blob) { for (const auto& bn : bns) { Blob* blob = BnInOp2Blob(bn); if (blob && !blob->IsBodyEmpty()) { return false; } } return true; } } // namespace using Arg2Tensor = HashMap, std::unique_ptr>; using ArgVec = std::vector>; namespace { void FillTensorDescWithBlob(const Blob* blob, user_op::NaiveTensorDesc* tensor_desc) { BlobDescProto proto; blob->blob_desc().shape().ToProto(proto.mutable_shape()); blob->blob_desc().stride().ToProto(proto.mutable_stride()); proto.set_data_type(blob->blob_desc().data_type()); proto.set_is_dynamic(blob->blob_desc().is_dynamic()); *tensor_desc = proto; Shape tensor_desc_shape = tensor_desc->shape(); tensor_desc_shape.CheckNumAxesIdenticalAndAssign(blob->shape()); tensor_desc->set_shape(tensor_desc_shape); Stride tensor_desc_stride = tensor_desc->stride(); tensor_desc_stride.CheckNumAxesIdenticalAndAssign(blob->stride()); tensor_desc->set_stride(tensor_desc_stride); } } // namespace class UserKernelBaseContext { public: explicit UserKernelBaseContext(const KernelConf& kernel_conf) { CHECK(kernel_conf.has_user_conf()); CHECK(kernel_conf.op_attribute().op_conf().has_user_conf()); auto InitInOrOut = [&](const PbMap& arg_map, ArgVec* arg_vec) { for (auto it = arg_map.begin(); it != arg_map.end(); ++it) { 
for (int32_t i = 0; i < it->second.s_size(); ++i) { arg_vec->emplace_back(std::make_pair(it->first, i)); } } }; InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().input(), &inputs_); InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().output(), &outputs_); device_type_ = CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); parallel_ctx_ = kernel_conf.parallel_ctx(); for (const auto& pair : kernel_conf.user_conf().bn_in_op2blob_desc()) { arg2bn_and_tensor_desc_.emplace( GenUnRepeatedBn(pair.first), std::make_pair(pair.first, user_op::NaiveTensorDesc(pair.second))); } } ~UserKernelBaseContext() = default; DeviceType device_type() const { return device_type_; } const ParallelContext& parallel_ctx() const { return parallel_ctx_; } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2bn_and_tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2bn_and_tensor_desc_.end()) { return nullptr; } return &(it->second.second); } const ArgVec& inputs() const { return inputs_; } const ArgVec& outputs() const { return outputs_; } private: friend class UserKernelInitAndCacheContext; HashMap, std::pair> arg2bn_and_tensor_desc_; ArgVec inputs_; ArgVec outputs_; DeviceType device_type_; ParallelContext parallel_ctx_; }; class UserKernelInitAndCacheContext final : public user_op::KernelInitContext, public user_op::KernelCacheContext { public: explicit UserKernelInitAndCacheContext(ep::Stream* stream, const KernelConf& kernel_conf) : user_op_conf_(kernel_conf.op_attribute().op_conf()), stream_(stream), base_ctx_(UserKernelBaseContext(kernel_conf)), parallel_desc_(kernel_conf.op_attribute().parallel_conf_signature().op_parallel_conf()) { nd_sbp_signature_ = NdSbpSignature(kernel_conf.op_attribute().nd_sbp_signature()); if (kernel_conf.op_attribute().has_sbp_signature()) { sbp_signature_ = SbpSignature(kernel_conf.op_attribute().sbp_signature()); } bool is_dynamic = false; for (const auto& pair : kernel_conf.user_conf().bn_in_op2blob_desc()) { if (pair.second.is_dynamic()) { is_dynamic = true; break; } } if (!is_dynamic || parallel_ctx().parallel_num() == 1) { for (const auto& pair : kernel_conf.op_attribute().logical_blob_desc_signature().bn_in_op2blob_desc()) { arg2logical_tensor_desc_.emplace(GenUnRepeatedBn(pair.first), user_op::NaiveTensorDesc(pair.second)); } } } ~UserKernelInitAndCacheContext() override = default; ep::Stream* stream() override { return stream_; } void UpdateTensorWithCorrBlob(const std::function& BnInOp2Blob) { for (auto& pair : base_ctx_.arg2bn_and_tensor_desc_) { const std::string& bn = pair.second.first; auto& tensor_desc = pair.second.second; Blob* blob = BnInOp2Blob(bn); CHECK(blob != nullptr) << "Blob " << bn << " is not found in cache context."; if (blob->blob_desc().is_dynamic()) { Shape shape; blob->shape().ToShape(&shape); tensor_desc.set_shape(shape); } } } DeviceType device_type() const override { return base_ctx_.device_type(); } const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return base_ctx_.TensorDesc4ArgNameAndIndex(arg_name, index); } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2logical_tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2logical_tensor_desc_.end()) { return nullptr; } 
else { return &(it->second); } } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { CHECK_EQ(parallel_desc_.hierarchy()->NumAxes(), 1); const auto& bn2sbp = sbp_signature_.bn_in_op2sbp_parallel(); std::string bn = GenRepeatedBn(arg_name, index); auto it = bn2sbp.find(bn); CHECK(it != bn2sbp.end()); return it->second; } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { const auto& bn2nd_sbp = nd_sbp_signature_.bn_in_op2nd_sbp(); std::string bn = GenRepeatedBn(arg_name, index); auto it = bn2nd_sbp.find(bn); CHECK(it != bn2nd_sbp.end()); return it->second; } const ArgVec& inputs() const override { return base_ctx_.inputs(); } const ArgVec& outputs() const override { return base_ctx_.outputs(); } const ParallelDesc& parallel_desc() const override { return parallel_desc_; } private: const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return user_op_conf().Attr4Name(attr_name); } user_op::UserOpConfWrapper user_op_conf_; ep::Stream* stream_; UserKernelBaseContext base_ctx_; SbpSignature sbp_signature_; HashMap, user_op::NaiveTensorDesc> arg2logical_tensor_desc_; ParallelDesc parallel_desc_; NdSbpSignature nd_sbp_signature_; }; using UserKernelInitContext = UserKernelInitAndCacheContext; using UserKernelCacheContext = UserKernelInitAndCacheContext; class UserKernelOpInferContext : public user_op::InferContext { public: explicit UserKernelOpInferContext(const KernelConf& kernel_conf) : user_op_conf_(kernel_conf.op_attribute().op_conf()), parallel_ctx_(kernel_conf.parallel_ctx()), nd_sbp_signature_(kernel_conf.op_attribute().nd_sbp_signature()), parallel_desc_(kernel_conf.op_attribute().parallel_conf_signature().op_parallel_conf()) { if (kernel_conf.op_attribute().has_sbp_signature()) { sbp_signature_ = SbpSignature(kernel_conf.op_attribute().sbp_signature()); } auto InitTensorDesc = [&](const PbMap& arg_map, ArgVec* arg_vec) { for (auto it = arg_map.begin(); it != arg_map.end(); ++it) { const std::string& arg_name = it->first; for (int32_t i = 0; i < it->second.s_size(); ++i) { std::pair arg_pair = std::make_pair(arg_name, i); arg_vec->emplace_back(arg_pair); arg2tensor_desc_.emplace(arg_pair, nullptr); } } }; InitTensorDesc(kernel_conf.op_attribute().op_conf().user_conf().input(), &inputs_); InitTensorDesc(kernel_conf.op_attribute().op_conf().user_conf().output(), &outputs_); for (const auto& pair : kernel_conf.op_attribute().logical_blob_desc_signature().bn_in_op2blob_desc()) { arg2logical_tensor_desc_.emplace(GenUnRepeatedBn(pair.first), user_op::NaiveTensorDesc(pair.second)); } } ~UserKernelOpInferContext() override = default; const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2logical_tensor_desc_.find(std::make_pair(arg_name, index)); CHECK(it != arg2logical_tensor_desc_.end()) << "Arg (" << arg_name << "," << index << ") is not found"; return &(it->second); } const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { return *TensorDesc4ArgNameAndIndex(arg_name, index); } const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, int32_t index) const override { return *TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override { return 
MutTensorDesc4ArgNameAndIndex(arg_name, index); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; } return it->second.get(); } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; } return it->second.get(); } const Shape& InputShape(const std::string& arg_name, int32_t index) const override { return Shape4ArgNameAndIndex(arg_name, index); } const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return Shape4ArgNameAndIndex(arg_name, index); } void SetOutputShape(const std::string& arg_name, int32_t index, const Shape& shape) override { SetShape4ArgNameAndIndex(arg_name, index, shape); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->shape(); } void SetShape4ArgNameAndIndex(const std::string& arg_name, int32_t index, const Shape& shape) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_shape(shape); } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { return Stride4ArgNameAndIndex(arg_name, index); } const Stride& OutputStride(const std::string& arg_name, int32_t index) const override { return Stride4ArgNameAndIndex(arg_name, index); } void SetOutputStride(const std::string& arg_name, int32_t index, const Stride& stride) override { return SetStride4ArgNameAndIndex(arg_name, index, stride); } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->stride(); } void SetStride4ArgNameAndIndex(const std::string& arg_name, int32_t index, const Stride& stride) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_stride(stride); } DataType InputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } DataType OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } void SetOutputDType(const std::string& arg_name, int32_t index, DataType data_type) override { return SetDtype4ArgNameAndIndex(arg_name, index, data_type); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->data_type(); } void SetDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index, DataType data_type) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_data_type(data_type); } MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const override { return MemoryFormat4ArgNameAndIndex(arg_name, index); } MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const override { return MemoryFormat4ArgNameAndIndex(arg_name, index); } void SetOutputMemoryFormat(const std::string& arg_name, int32_t index, MemoryFormat memory_format) override { return SetMemoryFormat4ArgNameAndIndex(arg_name, index, memory_format); } MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->memory_format(); } void SetMemoryFormat4ArgNameAndIndex(const 
std::string& arg_name, int32_t index, MemoryFormat memory_format) override { MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_memory_format(memory_format); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } void SetOutputIsDynamic(const std::string& arg_name, int32_t index, bool is_dynamic) override { return SetIsDynamic4ArgNameAndIndex(arg_name, index, is_dynamic); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->is_dynamic(); } void SetIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index, bool is_dynamic) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_is_dynamic(is_dynamic); } const ArgVec& inputs() const override { return inputs_; } const ArgVec& outputs() const override { return outputs_; } const ParallelContext& parallel_ctx() const override { return parallel_ctx_; }; const ParallelDesc& parallel_desc() const override { return parallel_desc_; } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { CHECK_EQ(parallel_desc_.hierarchy()->NumAxes(), 1); const auto& bn2sbp = sbp_signature_.bn_in_op2sbp_parallel(); std::string bn = GenRepeatedBn(arg_name, index); auto it = bn2sbp.find(bn); CHECK(it != bn2sbp.end()); return it->second; } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { const auto& bn2nd_sbp = nd_sbp_signature_.bn_in_op2nd_sbp(); std::string bn = GenRepeatedBn(arg_name, index); auto it = bn2nd_sbp.find(bn); CHECK(it != bn2nd_sbp.end()); return it->second; } void UpdateArg2TensorDesc(const std::function& BnInOp2Blob) { for (auto& pair : arg2tensor_desc_) { const auto& arg_pair = pair.first; std::unique_ptr* arg_tensor_desc_ptr = &pair.second; Blob* blob = BnInOp2Blob(GenRepeatedBn(arg_pair.first, arg_pair.second)); CHECK_NOTNULL(blob); if (*arg_tensor_desc_ptr) { Shape tensor_desc_shape = (*arg_tensor_desc_ptr)->shape(); tensor_desc_shape.CheckNumAxesIdenticalAndAssign(blob->shape()); (*arg_tensor_desc_ptr)->set_shape(tensor_desc_shape); Stride tensor_desc_stride = (*arg_tensor_desc_ptr)->stride(); tensor_desc_stride.CheckNumAxesIdenticalAndAssign(blob->stride()); (*arg_tensor_desc_ptr)->set_stride(tensor_desc_stride); } else { arg_tensor_desc_ptr->reset(new user_op::NaiveTensorDesc()); FillTensorDescWithBlob(blob, arg_tensor_desc_ptr->get()); } } } int64_t parallel_num() const override { return parallel_ctx_.parallel_num(); } const std::string& input(const std::string& arg_name, int32_t index) const override { return user_op_conf().input(arg_name, index); } const std::string& output(const std::string& arg_name, int32_t index) const override { return user_op_conf().output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const override { return user_op_conf().has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const override { return user_op_conf().has_output(arg_name, index); } int32_t input_size(const std::string& arg_name) const override { return user_op_conf().input_size(arg_name); } int32_t output_size(const std::string& arg_name) const override { return user_op_conf().output_size(arg_name); } const std::string& op_name() const override { return 
user_op_conf().op_name(); } const std::string& op_type_name() const override { return user_op_conf().op_type_name(); } const std::string& op_loc() const override { return user_op_conf_.op_conf().loc(); } private: const user_op::UserOpConfWrapper& user_op_conf() const { return user_op_conf_; } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return user_op_conf().Attr4Name(attr_name); } user_op::UserOpConfWrapper user_op_conf_; ArgVec inputs_; ArgVec outputs_; ParallelContext parallel_ctx_; SbpSignature sbp_signature_; NdSbpSignature nd_sbp_signature_; ParallelDesc parallel_desc_; HashMap, std::unique_ptr> arg2tensor_desc_; HashMap, user_op::NaiveTensorDesc> arg2logical_tensor_desc_; }; class UserKernelInferContext final : public user_op::KernelInferContext { public: explicit UserKernelInferContext(ep::Stream* stream, const KernelConf& kernel_conf) : user_op_conf_(kernel_conf.op_attribute().op_conf()), stream_(stream), base_ctx_(UserKernelBaseContext(kernel_conf)), op_infer_ctx_(kernel_conf) { auto InitArg2Blob = [this](const PbMap& arg_map) { for (auto it = arg_map.begin(); it != arg_map.end(); ++it) { const std::string& arg_name = it->first; for (int32_t i = 0; i < it->second.s_size(); ++i) { arg2tensor_.emplace(std::make_pair(arg_name, i), nullptr); } } }; InitArg2Blob(kernel_conf.op_attribute().op_conf().user_conf().input()); InitArg2Blob(kernel_conf.op_attribute().op_conf().user_conf().output()); const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult( kernel_conf.op_attribute().op_conf().user_conf().op_type_name()); CHECK_NOTNULL(op_reg_val); if (op_reg_val->physical_tensor_desc_infer_fn) { tensor_desc_infer_fn_ = op_reg_val->physical_tensor_desc_infer_fn; } else { UNIMPLEMENTED(); } } ~UserKernelInferContext() = default; DeviceType device_type() const override { return base_ctx_.device_type(); } const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return base_ctx_.TensorDesc4ArgNameAndIndex(arg_name, index); } const ArgVec& inputs() const override { return base_ctx_.inputs(); } const ArgVec& outputs() const override { return base_ctx_.outputs(); } ep::Stream* stream() override { return stream_; } user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) override { auto it = arg2tensor_.find(std::make_pair(arg_name, arg_index)); CHECK(it != arg2tensor_.end()) << "Arg (" << arg_name << "," << arg_index << ") is not found"; return it->second.get(); } ShapeView ShapeView4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) override { user_op::Tensor* arg_tensor = Tensor4ArgNameAndIndex(arg_name, arg_index); CHECK(arg_tensor != nullptr) << "Tensor of arg (" << arg_name << "," << arg_index << ") is not found"; return arg_tensor->shape_view(); } MutShapeView MutShapeView4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) override { user_op::Tensor* arg_tensor = Tensor4ArgNameAndIndex(arg_name, arg_index); CHECK(arg_tensor != nullptr) << "Tensor of arg (" << arg_name << "," << arg_index << ") is not found"; return arg_tensor->mut_shape_view(); } user_op::InferContext* MutOpInferContext() override { return &op_infer_ctx_; } const user_op::TensorDescInferFn& GetOpInferFn() const override { return tensor_desc_infer_fn_; } void UpdateArg2Tensor(const std::function& BnInOp2Blob) { for (auto& pair : arg2tensor_) { const auto& arg_pair = 
pair.first; std::unique_ptr* arg_tensor_ptr = &pair.second; Blob* blob = BnInOp2Blob(GenRepeatedBn(arg_pair.first, arg_pair.second)); if (blob == nullptr) { continue; } if (*arg_tensor_ptr) { arg_tensor_ptr->get()->Reset(blob); } else { arg_tensor_ptr->reset(new user_op::BlobTensorView(blob)); } } } private: const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return user_op_conf().Attr4Name(attr_name); } user_op::UserOpConfWrapper user_op_conf_; ep::Stream* stream_; UserKernelBaseContext base_ctx_; UserKernelOpInferContext op_infer_ctx_; user_op::TensorDescInferFn tensor_desc_infer_fn_; HashMap, std::unique_ptr> arg2tensor_; }; namespace { struct BnTensorPair { std::string bn; std::unique_ptr tensor; }; BnTensorPair MakeBnTensorPair(const std::string& bn) { BnTensorPair pair; pair.bn = bn; return pair; } BnTensorPair MakeBnTensorPair(const std::string& bn, std::unique_ptr&& tensor) { BnTensorPair pair; pair.bn = bn; pair.tensor = std::move(tensor); return pair; } } // namespace class UserKernelComputeContext final : public user_op::KernelComputeContext { public: explicit UserKernelComputeContext(ep::Stream* stream, const KernelConf& kernel_conf) : user_op_conf_(kernel_conf.op_attribute().op_conf()), stream_(stream), base_ctx_(kernel_conf) { auto InitInOrOut = [&](const PbMap& arg_map) { for (const auto& it : arg_map) { const std::string& arg_name = it.first; for (int32_t i = 0; i < it.second.s_size(); ++i) { arg2bn_tensor_pair_.emplace(std::make_pair(arg_name, i), MakeBnTensorPair(GenRepeatedBn(arg_name, i))); } } }; InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().input()); InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().output()); arg2bn_tensor_pair_.emplace(std::make_pair("tmp_buffer", 0), MakeBnTensorPair(GenRepeatedBn("tmp_buffer", 0))); } ~UserKernelComputeContext() = default; const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return base_ctx_.TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { auto it = arg2bn_tensor_pair_.find(std::make_pair(arg_name, index)); if (it == arg2bn_tensor_pair_.end()) { return nullptr; } return it->second.tensor.get(); } ep::Stream* stream() override { return stream_; } bool UpdateTensorWithCorrBlob(const std::function& BnInOp2Blob) { bool updated = false; for (auto& pair : arg2bn_tensor_pair_) { std::unique_ptr* arg_tensor_ptr = &pair.second.tensor; Blob* blob = BnInOp2Blob(pair.second.bn); if (blob == nullptr) { if (*arg_tensor_ptr) { arg_tensor_ptr->reset(nullptr); updated = true; } } else { if (*arg_tensor_ptr) { if (arg_tensor_ptr->get()->blob() != blob) { arg_tensor_ptr->get()->Reset(blob); updated = true; } else { if (blob->blob_desc().is_dynamic()) { updated = true; } } } else { arg_tensor_ptr->reset(new user_op::BlobTensorView(blob)); updated = true; } } } return updated; } DeviceType device_type() const override { return base_ctx_.device_type(); } const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); } const ArgVec& inputs() const override { return base_ctx_.inputs(); } const ArgVec& outputs() const override { return base_ctx_.outputs(); } private: const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return user_op_conf().Attr4Name(attr_name); } const user_op::UserOpConfWrapper& 
user_op_conf() const override { return user_op_conf_; } user_op::UserOpConfWrapper user_op_conf_; ep::Stream* stream_; HashMap, BnTensorPair> arg2bn_tensor_pair_; UserKernelBaseContext base_ctx_; }; // kernel registry context used in kernel creation class UserKernelRegContext final : public user_op::KernelRegContext { public: explicit UserKernelRegContext(const KernelConf& kernel_conf) : user_op_conf_(kernel_conf.op_attribute().op_conf()), base_ctx_(UserKernelBaseContext(kernel_conf)) {} ~UserKernelRegContext() = default; DeviceType device_type() const override { return base_ctx_.device_type(); } const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return base_ctx_.TensorDesc4ArgNameAndIndex(arg_name, index); } const ArgVec& inputs() const override { return base_ctx_.inputs(); } const ArgVec& outputs() const override { return base_ctx_.outputs(); } const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return user_op_conf().Attr4Name(attr_name); } private: user_op::UserOpConfWrapper user_op_conf_; UserKernelBaseContext base_ctx_; }; UserKernel::~UserKernel() = default; void UserKernel::InitUserKernel(ep::Stream* stream) { ctx_.reset(new UserKernelComputeContext(stream, kernel_conf())); infer_ctx_.reset(new UserKernelInferContext(stream, kernel_conf())); cache_ctx_.reset(new UserKernelCacheContext(stream, kernel_conf())); infer_cache_.reset(new user_op::OpKernelInferCache(kernel_conf(), this)); { const std::string& op_type_name = kernel_conf().op_attribute().op_conf().user_conf().op_type_name(); const user_op::OpKernelRegistryResult* kernel_reg_val = CHECK_JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult( op_type_name, UserKernelRegContext(kernel_conf()))); CHECK_NOTNULL(kernel_reg_val); kernel_.reset(kernel_reg_val->create_fn()); } } std::shared_ptr UserKernel::CreateOpKernelState(KernelContext* ctx) { UserKernelInitContext init_ctx(ctx->stream(), kernel_conf()); return kernel_->CreateOpKernelState(&init_ctx); } const std::shared_ptr& UserKernel::GetOpKernelState() const { return opkernel_state_; } void UserKernel::ForwardUserKernel(const std::function& BnInOp2Blob, user_op::OpKernelState* opkernel_state) const { const bool updated = ctx_->UpdateTensorWithCorrBlob(BnInOp2Blob); if (updated) { cache_ctx_->UpdateTensorWithCorrBlob(BnInOp2Blob); kernel_->InitOpKernelCacheWithFlags(cache_ctx_.get(), user_op::OpKernelCache::kAttrNotChanged, &opkernel_cache_); } else { // do nothing } #ifdef WITH_CUDA_GRAPHS bool current_scope_capturing = false; if (cuda_graph_exec_) { auto* cuda_stream = dynamic_cast(ctx_->stream()); if (!cuda_stream->IsGraphCapturing()) { if (cuda_graph_exec_->IsInstantiated() && (!updated)) { cuda_stream->LaunchGraph(cuda_graph_exec_.get()); return; } const auto* cuda_graph_support = CHECK_NOTNULL(dynamic_cast(kernel_.get())); if (cuda_graph_support->IsReadyForCapture(ctx_.get(), opkernel_state, opkernel_cache_.get())) { current_scope_capturing = true; cuda_stream->BeginGraphCapture(); } } } #endif // WITH_CUDA_GRAPHS kernel_->Compute(ctx_.get(), opkernel_state, opkernel_cache_.get()); #ifdef WITH_CUDA_GRAPHS if (cuda_graph_exec_ && current_scope_capturing) { auto* cuda_stream = dynamic_cast(ctx_->stream()); cuda_stream->EndGraphCapture(cuda_graph_exec_.get()); 
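// Work issued while the stream is capturing is only recorded into the CUDA graph, not executed,
// so launch the freshly captured graph to actually run this kernel invocation.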
cuda_stream->LaunchGraph(cuda_graph_exec_.get()); } #endif // WITH_CUDA_GRAPHS } bool UserKernel::IsCudaGraphSupported() const { #ifdef WITH_CUDA_GRAPHS return cuda_graph_exec_.get() != nullptr; #else return false; #endif // WITH_CUDA_GRAPHS } bool UserKernel::IsReadyForCudaGraphCapture(KernelContext* ctx) const { const auto* cuda_graph_support = dynamic_cast(kernel_.get()); if (cuda_graph_support == nullptr) { return false; } return cuda_graph_support->IsReadyForCapture(ctx_.get(), opkernel_state_.get(), opkernel_cache_.get()); } void UserKernel::VirtualKernelInit(KernelContext* ctx) { InitUserKernel(ctx->stream()); CHECK(opkernel_state_.get() == nullptr); opkernel_state_ = CreateOpKernelState(ctx); kernel_->InitOpKernelCacheWithFlags(cache_ctx_.get(), user_op::OpKernelCache::kAllMayChanged, &opkernel_cache_); #ifdef WITH_CUDA_GRAPHS if (ParseBooleanFromEnv("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", false) && (!ParseBooleanFromEnv("ONEFLOW_GRAPH_ENABLE_STREAM_ORDERED_MEMORY_ALLOCATION", false))) { UserKernelInitContext init_ctx(ctx->stream(), kernel_conf()); auto* cuda_stream = dynamic_cast(ctx->stream()); const auto* cuda_graph_support = dynamic_cast(kernel_.get()); if (cuda_stream != nullptr) { if (cuda_graph_support != nullptr && cuda_graph_support->IsCudaGraphSupported(&init_ctx, opkernel_state_.get())) { cuda_graph_exec_.reset(new ep::CudaGraphExecutable()); VLOG(3) << "CUDA Graphs Kernel: " << op_conf().name() << " (" << op_conf().user_conf().op_type_name() << ")"; } else { VLOG(3) << "CUDA Graphs not supported: " << op_conf().name() << " (" << op_conf().user_conf().op_type_name() << ")"; } } } #endif // WITH_CUDA_GRAPHS } void UserKernel::ForwardDataContent(KernelContext* ctx) const { const auto BnInOp2Blob = [ctx](const std::string& bn) { return ctx->BnInOp2Blob(bn); }; ForwardUserKernel(BnInOp2Blob, opkernel_state_.get()); } void UserKernel::ForwardShape(KernelContext* ctx) const { const auto BnInOp2Blob = [ctx](const std::string& bn) { return ctx->BnInOp2Blob(bn); }; infer_ctx_->UpdateArg2Tensor(BnInOp2Blob); infer_cache_->UpdateCacheKey(infer_ctx_.get()); if (!infer_cache_->IsCacheHit()) { auto* op_infer_ctx = dynamic_cast(infer_ctx_->MutOpInferContext()); CHECK_NOTNULL(op_infer_ctx); op_infer_ctx->UpdateArg2TensorDesc(BnInOp2Blob); kernel_->InferShape(infer_ctx_.get()); for (const auto& out_arg_pair : infer_ctx_->outputs()) { const Shape& static_shape = infer_ctx_->TensorDesc4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second)->shape(); const ShapeView& shape_view = infer_ctx_->ShapeView4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second); CHECK_LE(shape_view.elem_cnt(), static_shape.elem_cnt()) << "InferShape of OpKernel (op_type_name: " << op_conf().user_conf().op_type_name() << ", op_name: " << op_conf().name() << ") raise error, output arg's (name: " << out_arg_pair.first << ", index: " << out_arg_pair.second << ") runtime shape " << shape_view.ToString() << " surpass the limit of static shape " << static_shape.ToString(); } infer_cache_->UpdateCacheValue(infer_ctx_.get()); } else { std::shared_ptr cache_value_ptr = infer_cache_->GetCacheValue(); FOR_RANGE(int, i, 0, infer_ctx_->outputs().size()) { const auto& out_arg_pair = infer_ctx_->outputs().at(i); MutShapeView mut_shape_view = infer_ctx_->MutShapeView4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second); mut_shape_view.set_shape(*cache_value_ptr->obn_idx2shape_sym.at(i)); } } } bool UserKernel::IsStateless() const { return !kernel_->AlwaysComputeWhenAllOutputsEmpty(); } 
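// UserKernel is the single kernel registered for every user op, so the matcher below accepts any
// KernelConf; the concrete user_op::OpKernel is resolved later in InitUserKernel() via
// UserOpRegistryMgr.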
NEW_REGISTER_KERNEL(OperatorConf::kUserConf, UserKernel).SetIsMatchedPred([](const KernelConf&) { return true; }); } // namespace oneflow ================================================ FILE: oneflow/core/kernel/user_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/op_kernel_infer_cache.h" #include "oneflow/core/framework/user_op_tensor.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/kernel/kernel.h" #ifdef WITH_CUDA #include "oneflow/core/ep/cuda/cuda_stream.h" #endif // WITH_CUDA namespace oneflow { class UserKernelComputeContext; class UserKernelInferContext; class UserKernelInitAndCacheContext; namespace user_op { class OpKernelCache; } class UserKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(UserKernel); UserKernel() = default; ~UserKernel() override; void InitUserKernel(ep::Stream* stream); std::shared_ptr CreateOpKernelState(KernelContext* ctx); const std::shared_ptr& GetOpKernelState() const; void ForwardUserKernel(const std::function& BnInOp2Blob, user_op::OpKernelState* opkernel_state) const; bool IsCudaGraphSupported() const; bool IsReadyForCudaGraphCapture(KernelContext* ctx) const; private: void VirtualKernelInit(KernelContext* ctx) override; void ForwardDataContent(KernelContext* ctx) const override; void ForwardShape(KernelContext* ctx) const override; bool IsStateless() const override; bool IsKernelLaunchSynchronized() const override { return kernel_->IsKernelLaunchSynchronized(); } mutable std::shared_ptr opkernel_cache_; std::shared_ptr opkernel_state_; std::unique_ptr kernel_; std::unique_ptr ctx_; std::unique_ptr cache_ctx_; std::unique_ptr infer_ctx_; std::unique_ptr infer_cache_; #ifdef WITH_CUDA_GRAPHS std::unique_ptr cuda_graph_exec_; #endif // WITH_CUDA_GRAPHS }; } // namespace oneflow ================================================ FILE: oneflow/core/kernel/util/cuda_half_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_KERNEL_UTIL_CUDA_HALF_UTIL_H_ #define ONEFLOW_CORE_KERNEL_UTIL_CUDA_HALF_UTIL_H_ #include "oneflow/core/device/cuda_util.h" namespace oneflow { #define HALF_CHECK_FAILED \ printf("half operations are only supported when CUDA_ARCH >= 530"); \ assert(false) __inline__ __device__ half hone() { return __float2half(1.0); } __inline__ __device__ half hzero() { return __float2half(0.0); } __inline__ half float16_2half(float16 x) { // TODO: Potential loss of accuracy half* ret = reinterpret_cast(&x); return *ret; } __inline__ float16 half2float16(half x) { // TODO: Potential loss of accuracy float16* ret = reinterpret_cast(&x); return *ret; } } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_UTIL_CUDA_HALF_UTIL_H_ ================================================ FILE: oneflow/core/kernel/util/numeric_limits.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // reference: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/NumericLimits.cuh #pragma once #include #include #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" // numeric_limits.cuh is a holder for numeric limits definitions of commonly used // types. This header is very specific to ROCm HIP and may be removed in the future. // The lower_bound and upper_bound constants are same as lowest and max for // integral types, but are -inf and +inf for floating point types. They are // useful in implementing min, max, etc. namespace oneflow { namespace detail { #if defined(__CUDACC__) #define OF_NUMERICS_FUNC static inline __host__ __device__ #else #define OF_NUMERICS_FUNC static inline #endif template struct numeric_limits {}; // WARNING: the following oneflow::numeric_limits definitions are there only to support // HIP compilation for the moment. Use std::numeric_limits if you are not // compiling for ROCm. // from @colesbury: "The functions on numeric_limits aren't marked with // __device__ which is why they don't work with ROCm. CUDA allows them // because they're constexpr." namespace { // ROCm doesn't like INFINITY too. 
constexpr double inf = INFINITY; } // namespace template<> struct numeric_limits { OF_NUMERICS_FUNC bool lowest() { return false; } OF_NUMERICS_FUNC bool max() { return true; } OF_NUMERICS_FUNC bool lower_bound() { return false; } OF_NUMERICS_FUNC bool upper_bound() { return true; } }; template<> struct numeric_limits { OF_NUMERICS_FUNC uint8_t lowest() { return 0; } OF_NUMERICS_FUNC uint8_t max() { return UINT8_MAX; } OF_NUMERICS_FUNC uint8_t lower_bound() { return 0; } OF_NUMERICS_FUNC uint8_t upper_bound() { return UINT8_MAX; } }; template<> struct numeric_limits { OF_NUMERICS_FUNC int8_t lowest() { return INT8_MIN; } OF_NUMERICS_FUNC int8_t max() { return INT8_MAX; } OF_NUMERICS_FUNC int8_t lower_bound() { return INT8_MIN; } OF_NUMERICS_FUNC int8_t upper_bound() { return INT8_MAX; } }; template<> struct numeric_limits { OF_NUMERICS_FUNC int16_t lowest() { return INT16_MIN; } OF_NUMERICS_FUNC int16_t max() { return INT16_MAX; } OF_NUMERICS_FUNC int16_t lower_bound() { return INT16_MIN; } OF_NUMERICS_FUNC int16_t upper_bound() { return INT16_MAX; } }; template<> struct numeric_limits { OF_NUMERICS_FUNC int32_t lowest() { return INT32_MIN; } OF_NUMERICS_FUNC int32_t max() { return INT32_MAX; } OF_NUMERICS_FUNC int32_t lower_bound() { return INT32_MIN; } OF_NUMERICS_FUNC int32_t upper_bound() { return INT32_MAX; } }; template<> struct numeric_limits { #ifdef _MSC_VER OF_NUMERICS_FUNC int64_t lowest() { return _I64_MIN; } OF_NUMERICS_FUNC int64_t max() { return _I64_MAX; } OF_NUMERICS_FUNC int64_t lower_bound() { return _I64_MIN; } OF_NUMERICS_FUNC int64_t upper_bound() { return _I64_MAX; } #else OF_NUMERICS_FUNC int64_t lowest() { return INT64_MIN; } OF_NUMERICS_FUNC int64_t max() { return INT64_MAX; } OF_NUMERICS_FUNC int64_t lower_bound() { return INT64_MIN; } OF_NUMERICS_FUNC int64_t upper_bound() { return INT64_MAX; } #endif }; template<> struct numeric_limits { OF_NUMERICS_FUNC float lowest() { return -FLT_MAX; } OF_NUMERICS_FUNC float max() { return FLT_MAX; } OF_NUMERICS_FUNC float lower_bound() { return -static_cast(inf); } OF_NUMERICS_FUNC float upper_bound() { return static_cast(inf); } }; #if defined(__CUDACC__) static __device__ unsigned short int HALF_LOWEST = 0xfbff; static __device__ unsigned short int HALF_MAX = 0x7bff; static __device__ unsigned short int HALF_LOWER_BOUND = 0xfc00; static __device__ unsigned short int HALF_UPPER_BOUND = 0x7c00; template<> struct numeric_limits { static inline __device__ half lowest() { return *reinterpret_cast(&HALF_LOWEST); } static inline __device__ half max() { return *reinterpret_cast(&HALF_MAX); } static inline __device__ half lower_bound() { return *reinterpret_cast(&HALF_LOWER_BOUND); } static inline __device__ half upper_bound() { return *reinterpret_cast(&HALF_UPPER_BOUND); } }; #if CUDA_VERSION >= 11000 static __device__ unsigned short int NV_BFLOAT16_LOWEST = 0xff7f; static __device__ unsigned short int NV_BFLOAT16_MAX = 0x7f7f; static __device__ unsigned short int NV_BFLOAT16_LOWER_BOUND = 0xff80; static __device__ unsigned short int NV_BFLOAT16_UPPER_BOUND = 0x7f80; template<> struct numeric_limits { static inline __device__ nv_bfloat16 lowest() { return *reinterpret_cast(&NV_BFLOAT16_LOWEST); } static inline __device__ nv_bfloat16 max() { return *reinterpret_cast(&NV_BFLOAT16_MAX); } static inline __device__ nv_bfloat16 lower_bound() { return *reinterpret_cast(&NV_BFLOAT16_LOWER_BOUND); } static inline __device__ nv_bfloat16 upper_bound() { return *reinterpret_cast(&NV_BFLOAT16_UPPER_BOUND); } }; #endif // CUDA_VERSION >= 
11000 #endif // defined(__CUDACC__) template<> struct numeric_limits { OF_NUMERICS_FUNC double lowest() { return -DBL_MAX; } OF_NUMERICS_FUNC double max() { return DBL_MAX; } OF_NUMERICS_FUNC double lower_bound() { return -inf; } OF_NUMERICS_FUNC double upper_bound() { return inf; } }; } // namespace detail } // namespace oneflow ================================================ FILE: oneflow/core/kernel/util/numerics.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // reference: https://github.com/pytorch/pytorch/blob/master/aten/src/THC/THCNumerics.cuh #ifndef ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H #define ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H #pragma once #include #include #include #include #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/util/numeric_limits.cuh" namespace oneflow { namespace detail { template struct numerics {}; template OF_NUMERICS_FUNC T powi(T a, T b) { assert(numerics::ge(b, 0)); T result = 1; while (b) { if (b & 1) { result *= a; } b /= 2; a *= a; } return result; } template<> struct numerics { OF_NUMERICS_FUNC uint8_t min() { return detail::numeric_limits::lowest(); } OF_NUMERICS_FUNC uint8_t max() { return detail::numeric_limits::max(); } OF_NUMERICS_FUNC uint8_t lower_bound() { return detail::numeric_limits::lower_bound(); } OF_NUMERICS_FUNC uint8_t upper_bound() { return detail::numeric_limits::upper_bound(); } OF_NUMERICS_FUNC bool lt(uint8_t a, uint8_t b) { return a < b; } OF_NUMERICS_FUNC bool le(uint8_t a, uint8_t b) { return a <= b; } OF_NUMERICS_FUNC bool gt(uint8_t a, uint8_t b) { return a > b; } OF_NUMERICS_FUNC bool ge(uint8_t a, uint8_t b) { return a >= b; } OF_NUMERICS_FUNC bool eq(uint8_t a, uint8_t b) { return a == b; } OF_NUMERICS_FUNC bool ne(uint8_t a, uint8_t b) { return a != b; } OF_NUMERICS_FUNC uint8_t add(uint8_t a, uint8_t b) { return a + b; } OF_NUMERICS_FUNC uint8_t mul(uint8_t a, uint8_t b) { return a * b; } OF_NUMERICS_FUNC uint8_t sub(uint8_t a, uint8_t b) { return a - b; } OF_NUMERICS_FUNC uint8_t div(uint8_t a, uint8_t b) { return a / b; } OF_NUMERICS_FUNC uint8_t pow(uint8_t a, uint8_t b) { return powi(a, b); } OF_NUMERICS_FUNC bool isnan(uint8_t a) { return false; } OF_NUMERICS_FUNC bool isinf(uint8_t a) { return false; } }; #ifdef _MSC_VER // Suppress warning C4804: '/': unsafe use of type 'bool' in operation #pragma warning(push) #pragma warning(disable : 4804) #endif template<> struct numerics { OF_NUMERICS_FUNC bool min() { return detail::numeric_limits::lowest(); } OF_NUMERICS_FUNC bool max() { return detail::numeric_limits::max(); } OF_NUMERICS_FUNC bool lower_bound() { return detail::numeric_limits::lower_bound(); } OF_NUMERICS_FUNC bool upper_bound() { return detail::numeric_limits::upper_bound(); } OF_NUMERICS_FUNC bool lt(bool a, bool b) { return a < b; } OF_NUMERICS_FUNC bool le(bool a, bool b) { return a <= b; } OF_NUMERICS_FUNC bool gt(bool a, bool b) { return a > b; } 
template<>
struct numerics<uint8_t> {
  OF_NUMERICS_FUNC uint8_t min() { return detail::numeric_limits<uint8_t>::lowest(); }
  OF_NUMERICS_FUNC uint8_t max() { return detail::numeric_limits<uint8_t>::max(); }
  OF_NUMERICS_FUNC uint8_t lower_bound() { return detail::numeric_limits<uint8_t>::lower_bound(); }
  OF_NUMERICS_FUNC uint8_t upper_bound() { return detail::numeric_limits<uint8_t>::upper_bound(); }
  OF_NUMERICS_FUNC bool lt(uint8_t a, uint8_t b) { return a < b; }
  OF_NUMERICS_FUNC bool le(uint8_t a, uint8_t b) { return a <= b; }
  OF_NUMERICS_FUNC bool gt(uint8_t a, uint8_t b) { return a > b; }
  OF_NUMERICS_FUNC bool ge(uint8_t a, uint8_t b) { return a >= b; }
  OF_NUMERICS_FUNC bool eq(uint8_t a, uint8_t b) { return a == b; }
  OF_NUMERICS_FUNC bool ne(uint8_t a, uint8_t b) { return a != b; }
  OF_NUMERICS_FUNC uint8_t add(uint8_t a, uint8_t b) { return a + b; }
  OF_NUMERICS_FUNC uint8_t mul(uint8_t a, uint8_t b) { return a * b; }
  OF_NUMERICS_FUNC uint8_t sub(uint8_t a, uint8_t b) { return a - b; }
  OF_NUMERICS_FUNC uint8_t div(uint8_t a, uint8_t b) { return a / b; }
  OF_NUMERICS_FUNC uint8_t pow(uint8_t a, uint8_t b) { return powi(a, b); }
  OF_NUMERICS_FUNC bool isnan(uint8_t a) { return false; }
  OF_NUMERICS_FUNC bool isinf(uint8_t a) { return false; }
};

#ifdef _MSC_VER
// Suppress warning C4804: '/': unsafe use of type 'bool' in operation
#pragma warning(push)
#pragma warning(disable : 4804)
#endif

template<>
struct numerics<bool> {
  OF_NUMERICS_FUNC bool min() { return detail::numeric_limits<bool>::lowest(); }
  OF_NUMERICS_FUNC bool max() { return detail::numeric_limits<bool>::max(); }
  OF_NUMERICS_FUNC bool lower_bound() { return detail::numeric_limits<bool>::lower_bound(); }
  OF_NUMERICS_FUNC bool upper_bound() { return detail::numeric_limits<bool>::upper_bound(); }
  OF_NUMERICS_FUNC bool lt(bool a, bool b) { return a < b; }
  OF_NUMERICS_FUNC bool le(bool a, bool b) { return a <= b; }
  OF_NUMERICS_FUNC bool gt(bool a, bool b) { return a > b; }
  OF_NUMERICS_FUNC bool ge(bool a, bool b) { return a >= b; }
  OF_NUMERICS_FUNC bool eq(bool a, bool b) { return a == b; }
  OF_NUMERICS_FUNC bool ne(bool a, bool b) { return a != b; }
  OF_NUMERICS_FUNC bool add(bool a, bool b) { return a + b; }
  OF_NUMERICS_FUNC bool mul(bool a, bool b) { return a && b; }
  OF_NUMERICS_FUNC bool sub(bool a, bool b) { return a - b; }
  OF_NUMERICS_FUNC bool div(bool a, bool b) { return a / b; }
  OF_NUMERICS_FUNC bool isnan(bool a) { return false; }
  OF_NUMERICS_FUNC bool isinf(bool a) { return false; }
};

#ifdef _MSC_VER
#pragma warning(pop)
#endif

template<>
struct numerics<int8_t> {
  OF_NUMERICS_FUNC int8_t min() { return detail::numeric_limits<int8_t>::lowest(); }
  OF_NUMERICS_FUNC int8_t max() { return detail::numeric_limits<int8_t>::max(); }
  OF_NUMERICS_FUNC int8_t lower_bound() { return detail::numeric_limits<int8_t>::lower_bound(); }
  OF_NUMERICS_FUNC int8_t upper_bound() { return detail::numeric_limits<int8_t>::upper_bound(); }
  OF_NUMERICS_FUNC bool lt(int8_t a, int8_t b) { return a < b; }
  OF_NUMERICS_FUNC bool le(int8_t a, int8_t b) { return a <= b; }
  OF_NUMERICS_FUNC bool gt(int8_t a, int8_t b) { return a > b; }
  OF_NUMERICS_FUNC bool ge(int8_t a, int8_t b) { return a >= b; }
  OF_NUMERICS_FUNC bool eq(int8_t a, int8_t b) { return a == b; }
  OF_NUMERICS_FUNC bool ne(int8_t a, int8_t b) { return a != b; }
  OF_NUMERICS_FUNC int8_t add(int8_t a, int8_t b) { return a + b; }
  OF_NUMERICS_FUNC int8_t mul(int8_t a, int8_t b) { return a * b; }
  OF_NUMERICS_FUNC int8_t sub(int8_t a, int8_t b) { return a - b; }
  OF_NUMERICS_FUNC int8_t div(int8_t a, int8_t b) { return a / b; }
  OF_NUMERICS_FUNC int8_t pow(int8_t a, int8_t b) { return powi(a, b); }
  OF_NUMERICS_FUNC bool isnan(int8_t a) { return false; }
  OF_NUMERICS_FUNC bool isinf(int8_t a) { return false; }
};

template<>
struct numerics<int16_t> {
  OF_NUMERICS_FUNC int16_t min() { return detail::numeric_limits<int16_t>::lowest(); }
  OF_NUMERICS_FUNC int16_t max() { return detail::numeric_limits<int16_t>::max(); }
  OF_NUMERICS_FUNC int16_t lower_bound() { return detail::numeric_limits<int16_t>::lower_bound(); }
  OF_NUMERICS_FUNC int16_t upper_bound() { return detail::numeric_limits<int16_t>::upper_bound(); }
  OF_NUMERICS_FUNC bool lt(int16_t a, int16_t b) { return a < b; }
  OF_NUMERICS_FUNC bool le(int16_t a, int16_t b) { return a <= b; }
  OF_NUMERICS_FUNC bool gt(int16_t a, int16_t b) { return a > b; }
  OF_NUMERICS_FUNC bool ge(int16_t a, int16_t b) { return a >= b; }
  OF_NUMERICS_FUNC bool eq(int16_t a, int16_t b) { return a == b; }
  OF_NUMERICS_FUNC bool ne(int16_t a, int16_t b) { return a != b; }
  OF_NUMERICS_FUNC int16_t add(int16_t a, int16_t b) { return a + b; }
  OF_NUMERICS_FUNC int16_t mul(int16_t a, int16_t b) { return a * b; }
  OF_NUMERICS_FUNC int16_t sub(int16_t a, int16_t b) { return a - b; }
  OF_NUMERICS_FUNC int16_t div(int16_t a, int16_t b) { return a / b; }
  OF_NUMERICS_FUNC int16_t pow(int16_t a, int16_t b) { return powi(a, b); }
  OF_NUMERICS_FUNC bool isnan(int16_t a) { return false; }
  OF_NUMERICS_FUNC bool isinf(int16_t a) { return false; }
};
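// numerics<int32_t> and numerics<int64_t> below follow the same pattern as the narrower
// integral specializations above: bounds come from detail::numeric_limits<T>, division is
// ordinary truncating integer division, pow() forwards to powi, and isnan()/isinf() are
// unconditionally false for integral types.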
template<>
struct numerics<int32_t> {
  OF_NUMERICS_FUNC int32_t min() { return detail::numeric_limits<int32_t>::lowest(); }
  OF_NUMERICS_FUNC int32_t max() { return detail::numeric_limits<int32_t>::max(); }
  OF_NUMERICS_FUNC int32_t lower_bound() { return detail::numeric_limits<int32_t>::lower_bound(); }
  OF_NUMERICS_FUNC int32_t upper_bound() { return detail::numeric_limits<int32_t>::upper_bound(); }
  OF_NUMERICS_FUNC bool lt(int32_t a, int32_t b) { return a < b; }
  OF_NUMERICS_FUNC bool le(int32_t a, int32_t b) { return a <= b; }
  OF_NUMERICS_FUNC bool gt(int32_t a, int32_t b) { return a > b; }
  OF_NUMERICS_FUNC bool ge(int32_t a, int32_t b) { return a >= b; }
  OF_NUMERICS_FUNC bool eq(int32_t a, int32_t b) { return a == b; }
  OF_NUMERICS_FUNC bool ne(int32_t a, int32_t b) { return a != b; }
  OF_NUMERICS_FUNC int32_t add(int32_t a, int32_t b) { return a + b; }
  OF_NUMERICS_FUNC int32_t mul(int32_t a, int32_t b) { return a * b; }
  OF_NUMERICS_FUNC int32_t sub(int32_t a, int32_t b) { return a - b; }
  OF_NUMERICS_FUNC int32_t div(int32_t a, int32_t b) { return a / b; }
  OF_NUMERICS_FUNC int32_t pow(int32_t a, int32_t b) { return powi(a, b); }
  OF_NUMERICS_FUNC bool isnan(int32_t a) { return false; }
  OF_NUMERICS_FUNC bool isinf(int32_t a) { return false; }
};

template<>
struct numerics<int64_t> {
  OF_NUMERICS_FUNC int64_t min() { return detail::numeric_limits<int64_t>::lowest(); }
  OF_NUMERICS_FUNC int64_t max() { return detail::numeric_limits<int64_t>::max(); }
  OF_NUMERICS_FUNC int64_t lower_bound() { return detail::numeric_limits<int64_t>::lower_bound(); }
  OF_NUMERICS_FUNC int64_t upper_bound() { return detail::numeric_limits<int64_t>::upper_bound(); }
  OF_NUMERICS_FUNC bool lt(int64_t a, int64_t b) { return a < b; }
  OF_NUMERICS_FUNC bool le(int64_t a, int64_t b) { return a <= b; }
  OF_NUMERICS_FUNC bool gt(int64_t a, int64_t b) { return a > b; }
  OF_NUMERICS_FUNC bool ge(int64_t a, int64_t b) { return a >= b; }
  OF_NUMERICS_FUNC bool eq(int64_t a, int64_t b) { return a == b; }
  OF_NUMERICS_FUNC bool ne(int64_t a, int64_t b) { return a != b; }
  OF_NUMERICS_FUNC int64_t add(int64_t a, int64_t b) { return a + b; }
  OF_NUMERICS_FUNC int64_t mul(int64_t a, int64_t b) { return a * b; }
  OF_NUMERICS_FUNC int64_t sub(int64_t a, int64_t b) { return a - b; }
  OF_NUMERICS_FUNC int64_t div(int64_t a, int64_t b) { return a / b; }
  OF_NUMERICS_FUNC int64_t pow(int64_t a, int64_t b) { return powi(a, b); }
  OF_NUMERICS_FUNC bool isnan(int64_t a) { return false; }
  OF_NUMERICS_FUNC bool isinf(int64_t a) { return false; }
};

// DEPRECATED: use math functions from std and cuda math API (if needed)
template<>
struct numerics<float> {
  OF_NUMERICS_FUNC float min() { return detail::numeric_limits<float>::lowest(); }
  OF_NUMERICS_FUNC float max() { return detail::numeric_limits<float>::max(); }
  OF_NUMERICS_FUNC float lower_bound() { return detail::numeric_limits<float>::lower_bound(); }
  OF_NUMERICS_FUNC float upper_bound() { return detail::numeric_limits<float>::upper_bound(); }
  OF_NUMERICS_FUNC bool lt(float a, float b) { return a < b; }
  OF_NUMERICS_FUNC bool le(float a, float b) { return a <= b; }
  OF_NUMERICS_FUNC bool gt(float a, float b) { return a > b; }
  OF_NUMERICS_FUNC bool ge(float a, float b) { return a >= b; }
  OF_NUMERICS_FUNC bool eq(float a, float b) { return a == b; }
  OF_NUMERICS_FUNC bool ne(float a, float b) { return a != b; }
  OF_NUMERICS_FUNC float sqrt(float a) { return sqrtf(a); }
  OF_NUMERICS_FUNC float atan(float a) { return atanf(a); }
  OF_NUMERICS_FUNC float add(float a, float b) { return a + b; }
  OF_NUMERICS_FUNC float div(float a, float b) { return a / b; }
  OF_NUMERICS_FUNC float mul(float a, float b) { return a * b; }
  OF_NUMERICS_FUNC float sub(float a, float b) { return a - b; }
  OF_NUMERICS_FUNC float pow(float a, float b) { return powf(a, b); }
  OF_NUMERICS_FUNC bool isnan(float a) { return ::isnan(a); }
  OF_NUMERICS_FUNC bool isinf(float a) { return ::isinf(a); }
};

#if defined(__CUDACC__)
template<>
struct numerics<half> {
  OF_NUMERICS_FUNC bool isnan(half a) { return ::isnan((float)a); }
};
#endif

template<>
struct numerics<double> {
  OF_NUMERICS_FUNC double min() { return detail::numeric_limits<double>::lowest(); }
  OF_NUMERICS_FUNC double max() { return detail::numeric_limits<double>::max(); }
  OF_NUMERICS_FUNC double lower_bound() { return detail::numeric_limits<double>::lower_bound(); }
  OF_NUMERICS_FUNC double upper_bound() { return detail::numeric_limits<double>::upper_bound(); }
  OF_NUMERICS_FUNC bool lt(double a, double b) { return a < b; }
  OF_NUMERICS_FUNC bool le(double a, double b) { return a <= b; }
  OF_NUMERICS_FUNC bool gt(double a, double b) { return a > b; }
  OF_NUMERICS_FUNC bool ge(double a, double b) { return a >= b; }
  OF_NUMERICS_FUNC bool eq(double a, double b) { return a == b; }
  OF_NUMERICS_FUNC bool ne(double a, double b) { return a != b; }
  OF_NUMERICS_FUNC double sqrt(double a) { return ::sqrt(a); }
  OF_NUMERICS_FUNC double atan(double a) { return ::atan(a); }
  OF_NUMERICS_FUNC double add(double a, double b) { return a + b; }
  OF_NUMERICS_FUNC double div(double a, double b) { return a / b; }
  OF_NUMERICS_FUNC double mul(double a, double b) { return a * b; }
  OF_NUMERICS_FUNC double sub(double a, double b) { return a - b; }
  OF_NUMERICS_FUNC double pow(double a, double b) { return ::pow(a, b); }
  OF_NUMERICS_FUNC bool isnan(double a) { return ::isnan(a); }
  OF_NUMERICS_FUNC bool isinf(double a) { return ::isinf(a); }
};

}  // namespace detail
}  // namespace oneflow

#endif  // ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H
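A brief, hypothetical sketch of how the numerics<T> wrapper is typically consumed from type-generic code (the function below is illustrative and not an existing OneFlow helper; numerics<half> above only provides isnan, so this covers the integral and float/double specializations):

// Illustrative only: clamp x into [lo, hi] through the type-agnostic comparison members,
// so one template covers every element type that specializes lt/gt above.
template<typename T>
inline T ClampViaNumerics(T x, T lo, T hi) {
  if (oneflow::detail::numerics<T>::lt(x, lo)) { return lo; }
  if (oneflow::detail::numerics<T>::gt(x, hi)) { return hi; }
  return x;
}
// e.g. ClampViaNumerics<int32_t>(300, 0, 255) == 255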
================================================
FILE: oneflow/core/kernel/wait_and_send_ids_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/wait_and_send_ids_kernel.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/global_for.h"

namespace oneflow {

template<typename T>
void WaitAndSendIdsKernel<T>::VirtualKernelInit(KernelContext* ctx) {
  ctx->set_state(std::make_shared<WaitAndSendIdsStatus>());
}

template<typename T>
void WaitAndSendIdsKernel<T>::ForwardDataContent(KernelContext* ctx) const {
  auto* status = CHECK_NOTNULL(dynamic_cast<WaitAndSendIdsStatus*>(ctx->state().get()));
  if (status->out_idx_ >= status->out_num_) {
    CHECK(this->op_conf().wait_and_send_ids_conf().has_job_name());
    const auto& job_name = this->op_conf().wait_and_send_ids_conf().job_name();
    auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
    auto* buffer = buffer_mgr->Get(GetSourceTickBufferName(job_name));
    status->in_id_ = 0;
    {
      std::shared_ptr<JobInstance> job_instance;
      status->buffer_status_ = buffer->Pull(&job_instance);
    }
    if (status->buffer_status_ == kBufferStatusErrorClosed) { return; }
    status->out_idx_ = 0;
    status->out_num_ = 1;
  }
  *ctx->BnInOp2Blob("out")->mut_dptr<T>() = 0;
  ++status->out_idx_;
}

ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kWaitAndSendIdsConf, WaitAndSendIdsKernel,
                               INT_DATA_TYPE_SEQ);

}  // namespace oneflow

================================================
FILE: oneflow/core/kernel/wait_and_send_ids_kernel.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_WAIT_AND_SEND_IDS_KERNEL_H_ #define ONEFLOW_CORE_KERNEL_WAIT_AND_SEND_IDS_KERNEL_H_ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" namespace oneflow { struct WaitAndSendIdsStatus final : public KernelState { BufferStatus buffer_status_; int64_t in_id_; int64_t out_idx_; size_t out_num_; }; template class WaitAndSendIdsKernel final : public Kernel { public: OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsKernel); WaitAndSendIdsKernel() = default; ~WaitAndSendIdsKernel() = default; private: void VirtualKernelInit(KernelContext* ctx) override; void ForwardDataContent(KernelContext* ctx) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_WAIT_AND_SEND_IDS_KERNEL_H_ ================================================ FILE: oneflow/core/lazy/actor/acc_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/actor.h" namespace oneflow { class AccActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(AccActor); AccActor() = default; ~AccActor() override = default; private: void Act() override; void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void VirtualActorInit(const TaskProto& proto) override; int32_t acc_cnt_{}; int32_t max_acc_cnt_{}; }; void AccActor::VirtualActorInit(const TaskProto& proto) { const Shape& in_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) .data_regst_time_shape(); const Shape& out_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) .data_regst_time_shape(); CHECK_GE(in_time_shape.elem_cnt(), out_time_shape.elem_cnt()); max_acc_cnt_ = in_time_shape.elem_cnt() / out_time_shape.elem_cnt(); acc_cnt_ = 0; OF_SET_MSG_HANDLER(&AccActor::HandlerNormal); } void AccActor::Act() { if (acc_cnt_ == 0) { Regst* out_regst = GetNaiveCurWriteable("out"); Regst* in_regst = GetNaiveCurReadable("in"); const Blob* in_blob = in_regst->GetMutSoleBlob(); Blob* out_blob = out_regst->GetMutSoleBlob(); const size_t size = in_blob->ByteSizeOfBlobBody(); CHECK_EQ(out_blob->ByteSizeOfBlobBody(), size); AutoMemcpy(actor_ctx()->stream_ctx()->stream(), out_blob->ForceMutDptr(), in_blob->dptr(), size, out_blob->mem_case(), in_blob->mem_case()); } else { AsyncLaunchKernel(); } acc_cnt_ += 1; } void AccActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { if (acc_cnt_ == max_acc_cnt_) { HandleProducedNaiveDataRegstToConsumer(); acc_cnt_ = 0; } } REGISTER_ACTOR(TaskType::kAcc, AccActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/framework/framework.h" namespace oneflow { class AccCtrlTickActor : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(AccCtrlTickActor); AccCtrlTickActor() : acc_cnt_(0), max_acc_num_(0), last_micro_batch_input_output_mutex_(false), consumed_tick_regst_desc_id_(-1), produced_tick_regst_desc_id_(-1){}; virtual ~AccCtrlTickActor() = default; private: // NOTE(chengcheng): Empty rs for naive and inplace regst, all regst is customized. 
std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } std::pair> GetNaiveOrCustomizedProducedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } bool IsCustomizedReadReady() const override { return (!last_micro_batch_input_output_mutex_) && consumed_tick_rs_.IsCurSlotReady(); } bool IsCustomizedWriteReady() const override { return produced_tick_rs_.IsCurSlotReady(); } void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} bool IsCustomizedReadAlwaysUnReadyFromNow() const override { // all Messages are flushed return ReceiveEordMsg(consumed_tick_regst_desc_id_); } void VirtualActorInit(const TaskProto& proto) override; void Act() override; void AsyncSendCustomizedProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; void UpdtStateAsCustomizedProducedRegst(Regst* regst) override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) override; int32_t acc_cnt_; int32_t max_acc_num_; bool last_micro_batch_input_output_mutex_; int64_t consumed_tick_regst_desc_id_; int64_t produced_tick_regst_desc_id_; RegstSlot consumed_tick_rs_; RegstSlot produced_tick_rs_; }; void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { acc_cnt_ = 0; const OperatorConf& op_conf = proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); max_acc_num_ = user_op::UserOpConfWrapper(op_conf).attr("max_acc_num"); // NOTE(chengcheng): check time shape equal max_acc_num const Shape& in_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) .data_regst_time_shape(); const Shape& out_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) .data_regst_time_shape(); CHECK_EQ(in_time_shape.elem_cnt() % out_time_shape.elem_cnt(), 0); CHECK_EQ(in_time_shape.elem_cnt() / out_time_shape.elem_cnt(), max_acc_num_); CHECK_GT(max_acc_num_, 1); // input const auto& consumed_ids = proto.consumed_regst_desc_id(); CHECK_EQ(consumed_ids.size(), 1); auto in_it = consumed_ids.find("in"); CHECK(in_it != consumed_ids.end()); CHECK_EQ(in_it->second.regst_desc_id_size(), 1); consumed_tick_regst_desc_id_ = in_it->second.regst_desc_id(0); consumed_tick_rs_.InsertRegstDescId(consumed_tick_regst_desc_id_); consumed_tick_rs_.InitedDone(); // output CHECK_EQ(proto.produced_regst_desc().size(), 1); const auto& produced_ids = proto.produced_regst_desc(); CHECK_EQ(produced_ids.size(), 1); auto out_it = produced_ids.find("out"); CHECK(out_it != produced_ids.end()); const RegstDescProto& out_regst_desc = out_it->second; produced_tick_regst_desc_id_ = out_regst_desc.regst_desc_id(); produced_tick_rs_.InsertRegstDescId(produced_tick_regst_desc_id_); produced_tick_rs_.InitedDone(); ForEachProducedRegst([&](Regst* regst) { CHECK_EQ(regst->regst_desc_id(), produced_tick_regst_desc_id_); CHECK_EQ(0, produced_tick_rs_.TryPushBackRegst(regst)); }); OF_SET_MSG_HANDLER(&AccCtrlTickActor::HandlerNormal); } void AccCtrlTickActor::Act() { acc_cnt_ += 1; if (acc_cnt_ == max_acc_num_) { CHECK(!last_micro_batch_input_output_mutex_); last_micro_batch_input_output_mutex_ = true; acc_cnt_ = 0; } } void AccCtrlTickActor::AsyncSendCustomizedProducedRegstMsgToConsumer() { if (last_micro_batch_input_output_mutex_) { CHECK(consumed_tick_rs_.IsCurSlotReady()); // inplace consume CHECK(produced_tick_rs_.IsCurSlotReady()); Regst* const tick_regst = produced_tick_rs_.Front(produced_tick_regst_desc_id_); 
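// HandleRegstToConsumer (defined in actor.cpp below) enqueues the regst to every registered
// consumer actor and returns the number of consumers; a produced ctrl tick must reach at least one.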
CHECK_GT(HandleRegstToConsumer(tick_regst), 0); produced_tick_rs_.PopFrontRegsts({produced_tick_regst_desc_id_}); } } void AccCtrlTickActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { if (!last_micro_batch_input_output_mutex_) { Regst* const tick_regst = consumed_tick_rs_.Front(consumed_tick_regst_desc_id_); CHECK_NOTNULL(tick_regst); AsyncSendRegstMsgToProducer(tick_regst); CHECK_EQ(0, consumed_tick_rs_.TryPopFrontRegst(consumed_tick_regst_desc_id_)); } } void AccCtrlTickActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) { CHECK(last_micro_batch_input_output_mutex_); CHECK_EQ(regst->regst_desc_id(), produced_tick_regst_desc_id_); CHECK_EQ(produced_tick_rs_.TryPushBackRegst(regst), 0); Regst* in_regst = consumed_tick_rs_.Front(consumed_tick_regst_desc_id_); CHECK(in_regst); AsyncSendRegstMsgToProducer(in_regst); CHECK_EQ(0, consumed_tick_rs_.TryPopFrontRegst(consumed_tick_regst_desc_id_)); last_micro_batch_input_output_mutex_ = false; } void AccCtrlTickActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { CHECK_EQ(0, consumed_tick_rs_.TryPushBackRegst(msg.regst())); } REGISTER_ACTOR(TaskType::kAccCtrlTick, AccCtrlTickActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/acc_tick_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" namespace oneflow { class AccTickActor : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(AccTickActor); AccTickActor() = default; virtual ~AccTickActor() = default; protected: void VirtualActorInit(const TaskProto& proto) override; private: void Act() override; void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; int32_t acc_cnt_; int32_t max_acc_cnt_; }; void AccTickActor::VirtualActorInit(const TaskProto& proto) { const Shape& in_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) .data_regst_time_shape(); const Shape& out_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) .data_regst_time_shape(); CHECK_EQ(in_time_shape.elem_cnt() % out_time_shape.elem_cnt(), 0); acc_cnt_ = 0; max_acc_cnt_ = in_time_shape.elem_cnt() / out_time_shape.elem_cnt(); OF_SET_MSG_HANDLER(&AccTickActor::HandlerNormal); } void AccTickActor::Act() { acc_cnt_ += 1; } void AccTickActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { if (acc_cnt_ == max_acc_cnt_) { HandleProducedNaiveDataRegstToConsumer(); acc_cnt_ = 0; } } REGISTER_ACTOR(TaskType::kAccTick, AccTickActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/runtime_job_descs.h" #include "oneflow/core/lazy/stream_context/include/stream_context.h" namespace oneflow { namespace { class KernelContextImpl : public KernelContext, public ActorContextProvider { public: OF_DISALLOW_COPY_AND_MOVE(KernelContextImpl); explicit KernelContextImpl(ActorContext* actor_ctx) : actor_ctx_(actor_ctx), stream_ctx_(actor_ctx->stream_ctx()), state_(nullptr), stream_kernel_observer_(nullptr) { auto* kernel_observer_provider = dynamic_cast(stream_ctx_); if (kernel_observer_provider != nullptr) { stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver(); } } ~KernelContextImpl() = default; ep::Stream* stream() const override { return stream_ctx_->stream(); } ActorContext* GetActorContext() const override { return actor_ctx_; } Blob* BnInOp2Blob(const std::string& bn) const override { return bn_in_op2blob_fn_(bn); } const std::shared_ptr& state() const override { return state_; } void set_state(std::shared_ptr state) override { state_ = std::move(state); } void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override; void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override; void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override; void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override; void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override; void UpdateBnInOp2BlobFn(std::function fn) { bn_in_op2blob_fn_ = std::move(fn); } private: ActorContext* actor_ctx_; StreamContext* stream_ctx_; std::function bn_in_op2blob_fn_; std::shared_ptr state_; KernelObserver* stream_kernel_observer_; }; void KernelContextImpl::WillForward(KernelContext* kernel_ctx, const Kernel* kernel) { Singleton::Get()->WillForward(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->WillForward(kernel_ctx, kernel); } } void KernelContextImpl::DidForward(KernelContext* kernel_ctx, const Kernel* kernel) { CHECK_JUST_MSG(kernel_ctx->stream()->GetAsyncError(), kernel->op_conf().name()); Singleton::Get()->DidForward(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->DidForward(kernel_ctx, kernel); } } void KernelContextImpl::WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) { Singleton::Get()->WillForwardHeader(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->WillForwardHeader(kernel_ctx, kernel); } } void KernelContextImpl::DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) { Singleton::Get()->DidForwardHeader(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->DidForwardHeader(kernel_ctx, kernel); } } void KernelContextImpl::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { Singleton::Get()->WillForwardDataContent(kernel_ctx, kernel); if 
(stream_kernel_observer_ != nullptr) { stream_kernel_observer_->WillForwardDataContent(kernel_ctx, kernel); } } void KernelContextImpl::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) { Singleton::Get()->DidForwardDataContent(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->DidForwardDataContent(kernel_ctx, kernel); } } void CheckInplaceRegstDescId(const TaskProto& task_proto) { HashSet consumed_regst_desc_ids; for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (int64_t id : pair.second.regst_desc_id()) { consumed_regst_desc_ids.insert(id); } } for (const auto& pair : task_proto.produced_regst_desc()) { if (pair.second.has_inplace_consumed_regst_desc_id() == false) { continue; } int64_t in_regst_desc_id = pair.second.inplace_consumed_regst_desc_id(); CHECK(consumed_regst_desc_ids.find(in_regst_desc_id) != consumed_regst_desc_ids.end()); } } } // namespace Actor::~Actor() = default; void Actor::Init(const JobDesc* job_desc, ActorContext* actor_ctx) { actor_ctx_ = actor_ctx; const TaskProto& task_proto = actor_ctx->task_proto(); actor_id_ = task_proto.task_id(); thrd_id_ = ThrdId4ActorId(actor_id_); job_id_ = task_proto.job_id(); act_cnt_ = 0; op_name_ = "NULL_OP"; debug_ = EnableActorDebugLog(); for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) { ExecKernel ek; ek.kernel_ctx.reset(new KernelContextImpl(actor_ctx)); ek.kernel = ConstructKernel(node.kernel_conf(), ek.kernel_ctx.get()); exec_kernel_vec_.emplace_back(std::move(ek)); op_name_ = node.kernel_conf().op_attribute().op_conf().name(); } is_kernel_launch_synchronized_ = std::all_of(exec_kernel_vec_.cbegin(), exec_kernel_vec_.cend(), [](const ExecKernel& ek) { return ek.kernel->IsKernelLaunchSynchronized(); }); if (!is_kernel_launch_synchronized_) { CHECK_EQ(exec_kernel_vec_.size(), 1); } remaining_eord_cnt_ = 0; msg_handler_ = nullptr; eord_regst_desc_ids_.clear(); for (const auto& pair : task_proto.produced_regst_desc()) { Singleton::Get()->NewRegsts(pair.second, [this](Regst* regst) { produced_regsts_[regst->regst_desc_id()].emplace_back(regst); }); int64_t regst_desc_id = pair.second.regst_desc_id(); CHECK(name2regst_desc_id_.insert({pair.first, {regst_desc_id}}).second); if (pair.second.regst_desc_type().has_ctrl_regst_desc()) { produced_ctrl_regst_desc_ids_.insert(regst_desc_id); } } for (const auto& pair : produced_regsts_) { for (const auto& regst : pair.second) { produced_regst2reading_cnt_[regst.get()] = 0; } } for (const auto& pair : task_proto.consumed_regst_desc_id()) { CHECK(name2regst_desc_id_.find(pair.first) == name2regst_desc_id_.end()); std::vector& regst_desc_id_vec = name2regst_desc_id_[pair.first]; for (int64_t regst_desc_id : pair.second.regst_desc_id()) { regst_desc_id_vec.emplace_back(regst_desc_id); } remaining_eord_cnt_ += pair.second.regst_desc_id_size(); if (pair.first == "in_ctrl") { consumed_ctrl_regst_desc_ids_.insert(regst_desc_id_vec.begin(), regst_desc_id_vec.end()); } } total_reading_cnt_ = 0; is_inplace_consumed_eord_ = false; CheckInplaceRegstDescId(task_proto); TakeOverInplaceConsumedAndProduced(task_proto.produced_regst_desc()); is_naive_consumed_eord_ = false; TakeOverNaiveConsumed(task_proto.consumed_regst_desc_id()); TakeOverNaiveProduced(task_proto.produced_regst_desc()); InitBnInOp2BlobInfo(task_proto); VirtualActorInit(task_proto); } void Actor::TakeOverInplaceConsumedAndProduced( const PbMap& produced_ids) { for (const auto& pair : produced_ids) { int64_t out_regst_desc_id = 
pair.second.regst_desc_id(); if (pair.second.has_inplace_consumed_regst_desc_id() == false) { continue; } int64_t in_regst_desc_id = pair.second.inplace_consumed_regst_desc_id(); inplace_regst_desc_id_in2out_.insert(std::make_pair(in_regst_desc_id, out_regst_desc_id)); inplace_regst_desc_id_out2in_.insert(std::make_pair(out_regst_desc_id, in_regst_desc_id)); inplace_consumed_rs_.InsertRegstDescId(in_regst_desc_id); inplace_produced_rs_.InsertRegstDescId(out_regst_desc_id); } inplace_consumed_rs_.InitedDone(); inplace_produced_rs_.InitedDone(); for (const auto& pair : produced_regsts_) { if (inplace_produced_rs_.HasRegstDescId(pair.first)) { for (const auto& regst : pair.second) { CHECK_EQ(0, inplace_produced_rs_.TryPushBackRegst(regst.get())); if (regst->consumers_actor_id().size() == 0) { CHECK(inplace_in_ids_with_no_out_consumed_ .emplace(inplace_regst_desc_id_out2in_.at(pair.first)) .second); } } } } } void Actor::TakeOverNaiveConsumed(const PbMap& consumed_ids) { auto res = GetNaiveOrCustomizedConsumedRegstDescName(); bool is_naive_names = res.first == RegstNameType::kNaive; const HashSet& names = res.second; for (const auto& pair : consumed_ids) { bool find_the_name = names.find(pair.first) != names.end(); if (is_naive_names == find_the_name || pair.first == "in_ctrl") { for (int64_t regst_desc_id : pair.second.regst_desc_id()) { if (inplace_consumed_rs_.HasRegstDescId(regst_desc_id)) { continue; } naive_consumed_rs_.InsertRegstDescId(regst_desc_id); } } } naive_consumed_rs_.InitedDone(); } void Actor::TakeOverNaiveProduced(const PbMap& produced_ids) { auto res = GetNaiveOrCustomizedProducedRegstDescName(); bool is_naive_names = res.first == RegstNameType::kNaive; const HashSet& names = res.second; for (const auto& pair : produced_ids) { bool find_the_name = names.find(pair.first) != names.end(); if (inplace_produced_rs_.HasRegstDescId(pair.second.regst_desc_id())) { continue; } if (is_naive_names == find_the_name || pair.first.substr(0, 9) == "out_ctrl_") { naive_produced_rs_.InsertRegstDescId(pair.second.regst_desc_id()); } } naive_produced_rs_.InitedDone(); for (const auto& pair : produced_regsts_) { if (naive_produced_rs_.HasRegstDescId(pair.first) == false) { continue; } for (const auto& regst : pair.second) { CHECK_EQ(0, naive_produced_rs_.TryPushBackRegst(regst.get())); } } } void Actor::InitBnInOp2BlobInfo(const TaskProto& task_proto) { for (int64_t i = 0; i < exec_kernel_vec_.size(); ++i) { ExecKernel& ek = exec_kernel_vec_.at(i); const ExecNodeProto& node = task_proto.exec_sequence().exec_node(i); for (auto& pair : node.kernel_conf().op_attribute().arg_signature().bn_in_op2lbi()) { BlobInfo blob_info; blob_info.lbi = pair.second; const std::string& bn = pair.first; auto regst_desc_id_it = node.bn_in_op2regst_desc_id().find(bn); if (regst_desc_id_it != node.bn_in_op2regst_desc_id().end() && Singleton::Get()->HasRegstDescId(regst_desc_id_it->second)) { const int64_t regst_desc_id = regst_desc_id_it->second; blob_info.regst_desc_id = regst_desc_id; const RtRegstDesc& regst_desc = Singleton::Get()->RegstDesc4RegstDescId(regst_desc_id); blob_info.ordinal = regst_desc.GetOrdinalForLbi(blob_info.lbi); if (naive_produced_rs_.HasRegstDescId(regst_desc_id)) { blob_info.rs = &naive_produced_rs_; } else if (inplace_produced_rs_.HasRegstDescId(regst_desc_id)) { blob_info.rs = &inplace_produced_rs_; } else if (naive_consumed_rs_.HasRegstDescId(regst_desc_id)) { blob_info.rs = &naive_consumed_rs_; } else if (inplace_consumed_rs_.HasRegstDescId(regst_desc_id)) { blob_info.rs = 
&inplace_consumed_rs_; } else { blob_info.rs = nullptr; } } else { blob_info.regst_desc_id = -1; blob_info.ordinal = -1; blob_info.rs = nullptr; } ek.bn_in_op2blob_info.emplace(bn, std::move(blob_info)); } } } void Actor::ForEachProducedRegst(const std::function& Handler) const { for (const auto& pair : produced_regsts_) { for (const auto& regst : pair.second) { Handler(regst.get()); } } } int64_t Actor::Name2SoleRegstDescId(const std::string& name) const { auto find_it = name2regst_desc_id_.find(name); if (find_it != name2regst_desc_id_.end()) { CHECK_EQ(find_it->second.size(), 1); return find_it->second.front(); } return -1; } const std::vector& Actor::Name2RegstDescIds(const std::string& name) const { return name2regst_desc_id_.at(name); } int64_t Actor::ReadingCnt4ProducedRegst(Regst* regst) const { return produced_regst2reading_cnt_.at(regst); } void Actor::IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val) { produced_regst2reading_cnt_.at(regst) += val; } void Actor::ForEachCurNaiveReadableDataRegst(const std::function& func) const { naive_consumed_rs_.ForEachFrontRegst([func](int64_t regst_desc_id, Regst* regst) { if (Singleton::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { return; } if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { func(regst); } }); } bool Actor::ReceiveEordMsg(int64_t regst_desc_id) const { return eord_regst_desc_ids_.find(regst_desc_id) != eord_regst_desc_ids_.end(); } int Actor::HandlerNormal(const ActorMsg& msg) { if (msg.msg_type() == ActorMsgType::kEordMsg) { remaining_eord_cnt_ -= 1; CHECK(eord_regst_desc_ids_.insert(msg.eord_regst_desc_id()).second); if (naive_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) { is_naive_consumed_eord_ = true; } else if (inplace_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) { is_inplace_consumed_eord_ = true; } else { NormalProcessCustomizedEordMsg(msg); } } else if (msg.msg_type() == ActorMsgType::kRegstMsg) { if (msg.SrcMachineId() == GlobalProcessCtx::Rank()) { Regst* regst = msg.regst(); if (naive_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) { CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst)); const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id()); CHECK(rdeq.empty() == false); if (rdeq.front()->regst_desc()->regst_desc_type().has_data_regst_desc()) { NormalProcessNaiveReadableDataRegstMsg(rdeq); } } else if (inplace_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) { CHECK_EQ(0, inplace_consumed_rs_.TryPushBackRegst(regst)); } else if (TryUpdtStateAsProducedRegst(regst) == 0) { // do nothing } else { NormalProcessCustomizedReadableRegstMsg(msg); } } else { if (NormalTryProcessReadableMsgFromOtherMachine(msg) == false) { // process ctrl msg from other rank if (IsConsumedCtrlRegstDescId(msg.regst_desc_id())) { Regst* regst = msg.regst(); CHECK(naive_consumed_rs_.HasRegstDescId(msg.regst_desc_id())); CHECK(Singleton::Get()->HasProducerTaskId4RegstDescId(msg.regst_desc_id())); CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst, msg.regst_desc_id())); const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(msg.regst_desc_id()); CHECK(rdeq.empty() == false); } else { CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst()), 0); } } } if (debug_) { LOG(INFO) << " Actor: " << actor_id_ << " op: " << op_name_ << " in act_cnt: [ " << act_cnt_ << " ] , Recv ActorMsg from: " << msg.src_actor_id() << " to: " << msg.dst_actor_id() << " with regst: " << msg.regst_desc_id(); } ActUntilFail(); } else if (msg.msg_type() == 
ActorMsgType::kCmdMsg) { CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart); ActUntilFail(); } else { UNIMPLEMENTED(); } // handler halts bool has_naive_or_inplace = naive_consumed_rs_.total_regst_desc_cnt() != 0 || inplace_consumed_rs_.total_regst_desc_cnt() != 0; bool naive_or_inplace_eord_and_empty = (is_naive_consumed_eord_ || is_inplace_consumed_eord_) && (naive_consumed_rs_.available_regst_desc_cnt() == 0 && inplace_consumed_rs_.available_regst_desc_cnt() == 0); bool customized_eord = IsCustomizedReadAlwaysUnReadyFromNow(); if ((has_naive_or_inplace && naive_or_inplace_eord_and_empty) || (!has_naive_or_inplace && customized_eord)) { CHECK_EQ(naive_consumed_rs_.available_regst_desc_cnt(), 0); AsyncReturnAllCustomizedReadableRegst(); AsyncSendEORDMsgForAllProducedRegstDesc(); if (remaining_eord_cnt_ == 0 && total_reading_cnt_ == 0) { OF_SET_MSG_HANDLER(nullptr); return 1; } else { OF_SET_MSG_HANDLER(&Actor::HandlerZombie); return 0; } } return 0; } int Actor::HandlerZombie(const ActorMsg& msg) { if (msg.msg_type() == ActorMsgType::kEordMsg) { CHECK_GE(remaining_eord_cnt_, 1); remaining_eord_cnt_ -= 1; } else if (msg.msg_type() == ActorMsgType::kRegstMsg) { if (TryUpdtStateAsProducedRegst(msg.regst()) != 0) { AsyncSendRegstMsgToProducer(msg.regst()); } } else { UNIMPLEMENTED(); } if (remaining_eord_cnt_ == 0 && total_reading_cnt_ == 0) { msg_handler_ = nullptr; return 1; } return 0; } void Actor::ActUntilFail() { if (debug_) { // NOTE(chengcheng): using if(debug_) code hack to minimize debug code cost when debug off. LOG(INFO) << " Actor: " << actor_id_ << " op: " << op_name_ << " Try to act before act_cnt: [ " << act_cnt_ << " ] . And IsReadReady: " << IsReadReady() << " IsWriteReady: " << IsWriteReady(); } while (IsReadReady() && IsWriteReady()) { PrepareProducedNaiveInplaceDataRegst(); if (debug_) { LOG(INFO) << " Actor: " << actor_id_ << " op: " << op_name_ << " Try to act act_cnt: [ " << act_cnt_ << " ] before launch kernel."; } Act(); AsyncSendCustomizedProducedRegstMsgToConsumer(); AsyncSendNaiveProducedRegstMsgToConsumer(); AsyncSendInplaceProducedRegstMsgToConsumer(); AsyncSendCustomizedConsumedRegstMsgToProducer(); AsyncSendNaiveConsumedRegstMsgToProducer(); AsyncRetInplaceConsumedRegstIfNoConsumer(); AsyncSendQueuedMsg(); if (debug_) { LOG(INFO) << " Actor: " << actor_id_ << " op: " << op_name_ << " Finish act act_cnt: [ " << act_cnt_++ << " ]."; } } // NOTE(liujuncheng): return inplace consumed AsyncSendQueuedMsg(); } void Actor::AsyncSendNaiveProducedRegstMsgToConsumer() { VirtualAsyncSendNaiveProducedRegstMsgToConsumer(); AsyncSendProducedCtrlRegstMsgToConsumer(); } void Actor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { HandleProducedNaiveDataRegstToConsumer(); } void Actor::AsyncSendInplaceProducedRegstMsgToConsumer() { VirtualAsyncSendInplaceProducedRegstMsgToConsumer(); } void Actor::AsyncRetInplaceConsumedRegstIfNoConsumer() { tmp_regst_desc_id_vec_.clear(); inplace_consumed_rs_.ForChosenRegstDeq( [&](int64_t regst_desc_id) { return inplace_in_ids_with_no_out_consumed_.find(regst_desc_id) != inplace_in_ids_with_no_out_consumed_.end(); }, [&](const std::deque& deq) { if (!deq.empty()) { Regst* in_regst = deq.front(); CHECK(in_regst); AsyncSendRegstMsgToProducer(in_regst); tmp_regst_desc_id_vec_.emplace_back(in_regst->regst_desc_id()); } }); inplace_consumed_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_); } void Actor::VirtualAsyncSendInplaceProducedRegstMsgToConsumer() { HandleProducedInplaceDataRegstToConsumer(); } void Actor::AsyncSendNaiveConsumedRegstMsgToProducer() { 
VirtualAsyncSendNaiveConsumedRegstMsgToProducer(); AsyncSendConsumedCtrlRegstMsgToProducer(); } void Actor::VirtualAsyncSendNaiveConsumedRegstMsgToProducer() { HandleConsumedNaiveDataRegstToProducer(); } void Actor::AsyncSendConsumedCtrlRegstMsgToProducer() { auto IsChosenRegstDescId = [this](int64_t regst_desc_id) { return IsConsumedCtrlRegstDescId(regst_desc_id) && ConsumedCtrlRegstValid(regst_desc_id); }; tmp_regst_desc_id_vec_.clear(); naive_consumed_rs_.ForChosenRegstDeq(IsChosenRegstDescId, [&](int64_t regst_desc_id, const std::deque& reg_deq) { CHECK(reg_deq.empty() == false); auto producer_task_id = Singleton::Get()->ProducerTaskId4RegstDescId(regst_desc_id); Regst* regst = reg_deq.front(); CHECK_GE(reg_deq.size(), 1); // must access regst before sending it to producer tmp_regst_desc_id_vec_.emplace_back(regst_desc_id); EnqueueAsyncMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, producer_task_id, regst)); }); naive_consumed_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_); } void Actor::AsyncSendProducedCtrlRegstMsgToConsumer() { auto IsChosenRegstDescId = [this](int64_t regst_desc_id) { return IsProducedCtrlRegstDescId(regst_desc_id) && ProducedCtrlRegstValid(regst_desc_id); }; tmp_regst_desc_id_vec_.clear(); naive_produced_rs_.ForChosenFrontRegst(IsChosenRegstDescId, [&](Regst* regst) { CHECK(regst->regst_desc()->regst_desc_type().has_ctrl_regst_desc()); int64_t real_consumer_cnt = HandleRegstToConsumer(regst); if (real_consumer_cnt > 0) { tmp_regst_desc_id_vec_.emplace_back(regst->regst_desc_id()); } }); naive_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_); } int64_t Actor::HandleRegstToConsumer(Regst* regst) { auto regst_reading_cnt_it = produced_regst2reading_cnt_.find(regst); CHECK_EQ(regst_reading_cnt_it->second, 0); int64_t real_consumer_cnt = 0; ActorMsg tpl_msg = ActorMsg::BuildRegstMsgToConsumer(actor_id_, 0, regst); for (int64_t consumer : regst->consumers_actor_id()) { tpl_msg.set_dst_actor_id(consumer); EnqueueAsyncMsg(tpl_msg); real_consumer_cnt += 1; } total_reading_cnt_ += real_consumer_cnt; regst_reading_cnt_it->second += real_consumer_cnt; return real_consumer_cnt; } bool Actor::IsReadReady() const { return naive_consumed_rs_.IsCurSlotReady() && inplace_consumed_rs_.IsCurSlotReady() && IsCustomizedReadReady(); } bool Actor::IsWriteReady() const { return naive_produced_rs_.IsCurSlotReady() && inplace_produced_rs_.IsCurSlotReady() && IsCustomizedWriteReady(); } void Actor::AsyncLaunchKernel(std::function Regst4RegstDescId) { for (const ExecKernel& ek : exec_kernel_vec_) { CHECK_NOTNULL(dynamic_cast(ek.kernel_ctx.get())) ->UpdateBnInOp2BlobFn([&](const std::string& bn_in_op) -> Blob* { const auto blob_info_it = ek.bn_in_op2blob_info.find(bn_in_op); if (blob_info_it == ek.bn_in_op2blob_info.cend()) { return nullptr; } const BlobInfo& info = blob_info_it->second; if (info.regst_desc_id == -1) { return nullptr; } Regst* regst = nullptr; if (info.rs != nullptr) { regst = info.rs->Front(info.regst_desc_id); } else { regst = Regst4RegstDescId(info.regst_desc_id); } if (regst == nullptr) { return nullptr; } if (info.ordinal >= 0) { return regst->GetBlobByOrdinal(info.ordinal); } else { return regst->GetBlobByLbi(info.lbi); } }); ek.kernel->Launch(ek.kernel_ctx.get()); } } void Actor::AsyncLaunchKernel() { AsyncLaunchKernel([](int64_t) -> Regst* { UNIMPLEMENTED(); return nullptr; }); } void Actor::PrepareProducedNaiveInplaceDataRegst() { naive_produced_rs_.ForEachFrontRegst([&](Regst* regst) { if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { if 
(regst->allocation_type() == RegstAllocationType::kStreamOrdered) { CHECK(regst->body_mem_ptr() == nullptr); void* body_ptr = nullptr; CHECK_JUST(actor_ctx_->stream_ctx()->stream()->AllocAsync( &body_ptr, regst->regst_desc()->BodyByteSize4OneRegst())); regst->ResetBodyMemPtr(body_ptr); } else if (regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } } }); inplace_produced_rs_.ForEachFrontRegst([&](Regst* regst) { CHECK(regst->regst_desc()->regst_desc_type().has_data_regst_desc()); const int64_t in_regst_desc_id = inplace_regst_desc_id_out2in_.at(regst->regst_desc_id()); Regst* in_regst = inplace_consumed_rs_.Front(in_regst_desc_id); CHECK(in_regst != nullptr); if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) { CHECK(regst->body_mem_ptr() == nullptr); regst->ResetBodyMemPtr(in_regst->body_mem_ptr()); } else if (regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } }); } void Actor::HandleProducedNaiveDataRegstToConsumer() { tmp_regst_desc_id_vec_.clear(); naive_produced_rs_.ForEachFrontRegst([&](Regst* regst) { if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { int64_t real_consumer_cnt = HandleRegstToConsumer(regst); if (real_consumer_cnt > 0) { tmp_regst_desc_id_vec_.emplace_back(regst->regst_desc_id()); } else { if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) { CHECK_JUST(actor_ctx_->stream_ctx()->stream()->FreeAsync(regst->body_mem_ptr())); regst->ResetBodyMemPtr(nullptr); } else if (regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } } } }); naive_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_); } void Actor::HandleProducedInplaceDataRegstToConsumer() { tmp_regst_desc_id_vec_.clear(); inplace_produced_rs_.ForEachFrontRegst([&](Regst* regst) { CHECK(regst->regst_desc()->regst_desc_type().has_data_regst_desc()); int64_t real_consumer_cnt = HandleRegstToConsumer(regst); if (real_consumer_cnt > 0) { tmp_regst_desc_id_vec_.emplace_back(regst->regst_desc_id()); } else { if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) { regst->ResetBodyMemPtr(nullptr); } else if (regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } } }); inplace_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_); } void Actor::HandleConsumedNaiveDataRegstToProducer() { tmp_regst_desc_id_vec_.clear(); naive_consumed_rs_.ForEachFrontRegst([&](int64_t regst_desc_id, Regst* regst) { if (IsConsumedCtrlRegstDescId(regst_desc_id)) { return; } if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { // must access regst before sending it to producer tmp_regst_desc_id_vec_.emplace_back(regst->regst_desc_id()); EnqueueAsyncMsg( ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst)); } }); naive_consumed_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_); } void Actor::AsyncSendEORDMsgForAllProducedRegstDesc() { for (auto& pair : produced_regsts_) { CHECK(!pair.second.empty()); const RtRegstDesc* regst_desc = pair.second.front()->regst_desc(); AddCallback([regst_desc]() { for (int64_t consumer : regst_desc->consumers_actor_id()) { Singleton::Get()->SendMsg( ActorMsg::BuildEordMsg(consumer, regst_desc->regst_desc_id())); } }); } } void Actor::AsyncSendRegstMsgToProducer(Regst* regst) { AsyncSendRegstMsgToProducer(regst, regst->producer_actor_id()); } void Actor::AsyncSendRegstMsgToProducer(Regst* regst, int64_t producer) { // must 
access regst before sending it to producer int64_t regst_desc_id = regst->regst_desc_id(); EnqueueAsyncMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, producer, regst)); naive_consumed_rs_.TryPopFrontRegst(regst_desc_id); } Regst* Actor::GetSoleProducedRegst4RegstDescId(int64_t regst_desc_id) const { auto it = produced_regsts_.find(regst_desc_id); CHECK(it != produced_regsts_.end()); CHECK_EQ(it->second.size(), 1); return it->second.front().get(); } int Actor::TryUpdtStateAsProducedRegst(Regst* regst) { auto reading_cnt_it = produced_regst2reading_cnt_.find(regst); if (reading_cnt_it == produced_regst2reading_cnt_.end()) { return -1; } CHECK(produced_regsts_.find(regst->regst_desc_id()) != produced_regsts_.end()); CHECK_GE(reading_cnt_it->second, 1); reading_cnt_it->second -= 1; total_reading_cnt_ -= 1; if (debug_) { LOG(INFO) << " Actor: " << actor_id_ << " op: " << op_name_ << " in act_cnt: [ " << act_cnt_ << " ] recv produce_regst: " << regst->regst_desc_id() << " and the total_reading_cnt_ is : " << total_reading_cnt_ << " now."; } if (reading_cnt_it->second != 0) { return 0; } if (inplace_produced_rs_.TryPushBackRegst(regst) == 0) { int64_t in_regst_desc_id = inplace_regst_desc_id_out2in_.at(regst->regst_desc_id()); Regst* in_regst = inplace_consumed_rs_.Front(in_regst_desc_id); CHECK(in_regst); if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) { regst->ResetBodyMemPtr(nullptr); } else if (regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } AsyncSendRegstMsgToProducer(in_regst); CHECK_EQ(0, inplace_consumed_rs_.TryPopFrontRegst(in_regst_desc_id)); } else if (naive_produced_rs_.TryPushBackRegst(regst) == 0) { if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) { CHECK_JUST(actor_ctx_->stream_ctx()->stream()->FreeAsync(regst->body_mem_ptr())); regst->ResetBodyMemPtr(nullptr); } else if (regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } } else { UpdtStateAsCustomizedProducedRegst(regst); } return 0; } void Actor::EnqueueAsyncMsg(const ActorMsg& msg) { if (is_kernel_launch_synchronized_ && thrd_id_ == ThrdId4ActorId(msg.dst_actor_id())) { sync_msg_queue_.emplace_back(msg); } else { async_msg_queue_.emplace_back(msg); } if (debug_ && msg.msg_type() == ActorMsgType::kRegstMsg) { LOG(INFO) << " Actor: " << actor_id_ << " op: " << op_name_ << " in act_cnt: [ " << act_cnt_ << " ] post ActorMsg from: " << msg.src_actor_id() << " to: " << msg.dst_actor_id() << " with regst: " << msg.regst_desc_id(); } } Regst* Actor::GetNaiveOrInplaceCurReadable(int64_t regst_desc_id) const { Regst* regst = naive_consumed_rs_.Front(regst_desc_id); if (regst == nullptr) { regst = inplace_consumed_rs_.Front(regst_desc_id); } return regst; } Regst* Actor::GetNaiveOrInplaceCurWriteable(int64_t regst_desc_id) const { Regst* regst = naive_produced_rs_.Front(regst_desc_id); if (regst == nullptr) { regst = inplace_produced_rs_.Front(regst_desc_id); } return regst; } Regst* Actor::GetNaiveCurReadable(int64_t regst_desc_id) const { return naive_consumed_rs_.Front(regst_desc_id); } Regst* Actor::GetNaiveCurWriteable(int64_t regst_desc_id) const { return naive_produced_rs_.Front(regst_desc_id); } void Actor::AsyncSendQueuedMsg() { if (!sync_msg_queue_.empty()) { Singleton::Get()->SendMsgsWithoutCommNet(sync_msg_queue_.data(), sync_msg_queue_.size(), thrd_id_); sync_msg_queue_.clear(); } if (!async_msg_queue_.empty()) { std::deque msgs; msgs.swap(async_msg_queue_); AddCallback([msgs]() { 
for (const ActorMsg& msg : msgs) { Singleton::Get()->SendMsg(msg); } }); } } void Actor::AddCallback(std::function callback) { actor_ctx_->AddCallback(std::move(callback)); } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/actor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_H_ #define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_H_ #include "oneflow/core/lazy/actor/actor_base.h" #include "oneflow/core/lazy/actor/actor_message_bus.h" #include "oneflow/core/job/task.pb.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/kernel_context.h" #include "oneflow/core/register/register_manager.h" #include "oneflow/core/lazy/actor/register_slot.h" namespace oneflow { class Actor : public ActorBase { public: OF_DISALLOW_COPY_AND_MOVE(Actor); virtual ~Actor(); void Init(const JobDesc* job_desc, ActorContext* actor_ctx) override; // 1: success, and actor finish // 0: success, and actor not finish int ProcessMsg(const ActorMsg& msg) override { return (this->*msg_handler_)(msg); } int64_t machine_id() const { return MachineId4ActorId(actor_id_); } int64_t actor_id() const { return actor_id_; } int64_t job_id() const { return job_id_; } protected: struct BlobInfo { LogicalBlobId lbi; int64_t regst_desc_id; int64_t ordinal; RegstSlot* rs; }; struct ExecKernel { std::unique_ptr kernel; HashMap bn_in_op2blob_info; std::unique_ptr kernel_ctx; }; using MsgHandler = int (Actor::*)(const ActorMsg&); enum class RegstNameType { kNaive = 0, kCustomized }; // Util Actor() = default; bool ReceiveAllEordMsg() const { return remaining_eord_cnt_ == 0; } bool ReceiveEordMsg(int64_t regst_desc_id) const; virtual void VirtualActorInit(const TaskProto&) {} int64_t Name2SoleRegstDescId(const std::string& name) const; const std::vector& Name2RegstDescIds(const std::string& name) const; ActorContext* actor_ctx() const { return actor_ctx_; } const std::vector& exec_kernel_vec() { return exec_kernel_vec_; } void ForEachCurNaiveReadableDataRegst(const std::function&) const; int64_t ReadingCnt4ProducedRegst(Regst* regst) const; void IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val); void IncreaseTotalReadingCnt(int64_t val) { total_reading_cnt_ += val; } // Msg Handler void set_msg_handler(MsgHandler val) { msg_handler_ = val; } #define OF_SET_MSG_HANDLER(val) \ do { \ VLOG(3) << "actor " << actor_id() << " switch to " << #val; \ set_msg_handler(static_cast(val)); \ } while (0) // Common Handlers and related virtual method int HandlerNormal(const ActorMsg& msg); int HandlerZombie(const ActorMsg& msg); virtual bool ConsumedCtrlRegstValid(int64_t regst_desc_id) const { return true; } virtual bool ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; } void AsyncLaunchKernel(std::function Regst4RegstDescId); void AsyncLaunchKernel(); // Util For Derived Actor to Send Msg void EnqueueAsyncMsg(const ActorMsg&); void 
HandleProducedNaiveDataRegstToConsumer(); void PrepareProducedNaiveInplaceDataRegst(); void HandleProducedInplaceDataRegstToConsumer(); void HandleConsumedNaiveDataRegstToProducer(); void AsyncSendRegstMsgToProducer(Regst*); void AsyncSendRegstMsgToProducer(Regst*, int64_t producer); void AsyncSendEORDMsgForAllProducedRegstDesc(); void AsyncSendQueuedMsg(); // Get Regst Regst* GetNaiveCurReadable(int64_t regst_desc_id) const; Regst* GetNaiveCurReadable(const std::string& name) const { return GetNaiveCurReadable(Name2SoleRegstDescId(name)); } Regst* GetNaiveOrInplaceCurReadable(int64_t regst_desc_id) const; Regst* GetNaiveOrInplaceCurReadable(const std::string& name) const { return GetNaiveOrInplaceCurReadable(Name2SoleRegstDescId(name)); } Regst* GetNaiveCurWriteable(int64_t regst_desc_id) const; Regst* GetNaiveCurWriteable(const std::string& name) const { return GetNaiveCurWriteable(Name2SoleRegstDescId(name)); } Regst* GetNaiveOrInplaceCurWriteable(int64_t regst_desc_id) const; Regst* GetNaiveOrInplaceCurWriteable(const std::string& name) const { return GetNaiveOrInplaceCurWriteable(Name2SoleRegstDescId(name)); } Regst* GetSoleProducedRegst4RegstDescId(int64_t regst_desc_id) const; void ForEachProducedRegst(const std::function&) const; int64_t HandleRegstToConsumer(Regst* regst); protected: bool IsConsumedCtrlRegstDescId(int64_t regst_desc_id) { return consumed_ctrl_regst_desc_ids_.find(regst_desc_id) != consumed_ctrl_regst_desc_ids_.end(); } bool IsProducedCtrlRegstDescId(int64_t regst_desc_id) { return produced_ctrl_regst_desc_ids_.find(regst_desc_id) != produced_ctrl_regst_desc_ids_.end(); } // Process Msg virtual void NormalProcessNaiveReadableDataRegstMsg(const std::deque&) {} virtual bool NormalTryProcessReadableMsgFromOtherMachine(const ActorMsg&) { return false; } int TryUpdtStateAsProducedRegst(Regst* regst); // Act void ActUntilFail(); virtual void Act() { UNIMPLEMENTED(); } // Ready bool IsReadReady() const; bool IsWriteReady() const; // Naive, Inplace Or Customized virtual void TakeOverInplaceConsumedAndProduced( const PbMap& produced_ids); void TakeOverNaiveConsumed(const PbMap& consumed_ids); void TakeOverNaiveProduced(const PbMap& produced_ids); void InitBnInOp2BlobInfo(const TaskProto& task_proto); // Send Msgs void AsyncSendNaiveProducedRegstMsgToConsumer(); virtual void VirtualAsyncSendNaiveProducedRegstMsgToConsumer(); virtual void VirtualAsyncSendInplaceProducedRegstMsgToConsumer(); void AsyncSendInplaceProducedRegstMsgToConsumer(); void AsyncSendNaiveConsumedRegstMsgToProducer(); virtual void VirtualAsyncSendNaiveConsumedRegstMsgToProducer(); void AsyncSendConsumedCtrlRegstMsgToProducer(); void AsyncSendProducedCtrlRegstMsgToConsumer(); // Customized Consumed virtual func virtual void ForEachCurCustomizedReadableRegst(std::function) const {} virtual void NormalProcessCustomizedEordMsg(const ActorMsg&) {} virtual void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) { UNIMPLEMENTED(); } virtual bool IsCustomizedReadReady() const { return true; } virtual bool IsCustomizedReadAlwaysUnReadyFromNow() const { return false; } virtual std::pair> GetNaiveOrCustomizedConsumedRegstDescName() { return std::make_pair(RegstNameType::kCustomized, HashSet{}); } virtual void AsyncSendCustomizedProducedRegstMsgToConsumer() {} virtual void AsyncReturnAllCustomizedReadableRegst() {} // Customized Produced virtual func virtual void UpdtStateAsCustomizedProducedRegst(Regst* regst) { UNIMPLEMENTED(); } virtual bool IsCustomizedWriteReady() const { return true; } virtual std::pair> 
  GetNaiveOrCustomizedProducedRegstDescName() {
    return std::make_pair(RegstNameType::kCustomized, HashSet<std::string>{});
  }
  virtual void AsyncSendCustomizedConsumedRegstMsgToProducer() {}
  void AsyncRetInplaceConsumedRegstIfNoConsumer();

  virtual void AddCallback(std::function<void()> callback);

  int64_t actor_id_;
  int64_t thrd_id_;
  int64_t job_id_;

  std::vector<ExecKernel> exec_kernel_vec_;
  HashMap<std::string, std::vector<int64_t>> name2regst_desc_id_;

  MsgHandler msg_handler_;
  ActorContext* actor_ctx_;

  HashSet<int64_t> eord_regst_desc_ids_;
  int64_t remaining_eord_cnt_;

  HashMap<int64_t, std::vector<std::unique_ptr<Regst>>> produced_regsts_;
  HashMap<Regst*, int64_t> produced_regst2reading_cnt_;
  int64_t total_reading_cnt_;

  RegstSlot naive_produced_rs_;
  RegstSlot naive_consumed_rs_;
  bool is_naive_consumed_eord_;

  HashSet<int64_t> produced_ctrl_regst_desc_ids_;
  HashSet<int64_t> consumed_ctrl_regst_desc_ids_;

  RegstSlot inplace_consumed_rs_;
  RegstSlot inplace_produced_rs_;
  bool is_inplace_consumed_eord_;
  HashSet<int64_t> inplace_in_ids_with_no_out_consumed_;
  HashMap<int64_t, int64_t> inplace_regst_desc_id_in2out_;
  HashMap<int64_t, int64_t> inplace_regst_desc_id_out2in_;

  std::deque<ActorMsg> async_msg_queue_;
  std::vector<ActorMsg> sync_msg_queue_;
  bool is_kernel_launch_synchronized_;
  std::vector<int64_t> tmp_regst_desc_id_vec_;

  // for debug
  std::string op_name_;
  bool debug_;
  int64_t act_cnt_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_H_


================================================
FILE: oneflow/core/lazy/actor/actor_base.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/lazy/actor/actor.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/runtime_job_descs.h"

namespace oneflow {

std::unique_ptr<ActorBase> NewActor(ActorContext* actor_ctx) {
  ActorBase* rptr = NewObj<int32_t, ActorBase>(actor_ctx->task_proto().task_type());
  const auto& job_descs = *Singleton<RuntimeJobDescs>::Get();
  rptr->Init(&job_descs.job_desc(actor_ctx->task_proto().job_id()), actor_ctx);
  return std::unique_ptr<ActorBase>(rptr);
}

}  // namespace oneflow


================================================
FILE: oneflow/core/lazy/actor/actor_base.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_BASE_H_
#define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_BASE_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/lazy/actor/actor_context.h"

namespace oneflow {

class JobDesc;
class TaskProto;
class StreamContext;
class ActorMsg;

class ActorBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ActorBase);
  ActorBase() = default;
  virtual ~ActorBase() = default;

  virtual void Init(const JobDesc* job_desc, ActorContext* actor_ctx) = 0;

  // 1: success, and actor finish
  // 0: success, and actor not finish
  virtual int ProcessMsg(const ActorMsg& msg) = 0;
};

std::unique_ptr<ActorBase> NewActor(ActorContext* actor_ctx);

#define REGISTER_ACTOR(task_type, ActorType) \
  REGISTER_CLASS(int32_t, task_type, ActorBase, ActorType)

}  // namespace oneflow

#endif  // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_BASE_H_


================================================
FILE: oneflow/core/lazy/actor/actor_context.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/lazy/actor/actor_context.h"
#include "oneflow/core/lazy/actor/generic_actor_context.h"

namespace oneflow {

std::unique_ptr<ActorContext> NewActorContext(const TaskProto& task_proto,
                                              StreamContext* stream_ctx) {
  ActorContext* ctx = nullptr;
  if (IsClassRegistered<int32_t, ActorContext>(task_proto.task_type())) {
    ctx = NewObj<int32_t, ActorContext>(task_proto.task_type());
  } else {
    ctx = new GenericActorContext();
  }
  ctx->Init(task_proto, stream_ctx);
  return std::unique_ptr<ActorContext>(ctx);
}

}  // namespace oneflow


================================================
FILE: oneflow/core/lazy/actor/actor_context.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_CONTEXT_H_
#define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_CONTEXT_H_

#include "oneflow/core/lazy/stream_context/include/stream_context.h"
#include "oneflow/core/job/task.pb.h"

namespace oneflow {

class ActorContext {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ActorContext);
  ActorContext() = default;
  virtual ~ActorContext() = default;

  virtual void Init(const TaskProto& task_proto, StreamContext* stream_ctx) = 0;
  virtual void AddCallback(std::function<void()> callback) = 0;

  virtual StreamContext* stream_ctx() const = 0;
  virtual const TaskProto& task_proto() const = 0;
};

class ActorContextProvider {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ActorContextProvider);
  ActorContextProvider() = default;
  virtual ~ActorContextProvider() = default;

  virtual ActorContext* GetActorContext() const = 0;
};

std::unique_ptr<ActorContext> NewActorContext(const TaskProto& task_proto,
                                              StreamContext* stream_ctx);

#define REGISTER_ACTOR_CONTEXT(task_type, ActorContextType) \
  REGISTER_CLASS(int32_t, task_type, ActorContext, ActorContextType)

}  // namespace oneflow

#endif  // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_CONTEXT_H_


================================================
FILE: oneflow/core/lazy/actor/actor_message.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/lazy/actor/actor_message.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/id_manager.h" namespace oneflow { namespace { bool IsSoleBlobAndDynamicEmpty(Regst* regst) { if (regst == nullptr) { return false; } if (regst->GetBlobSize() != 1) { return false; } Blob* sole_blob = regst->GetMutSoleBlob(); if (!regst->GetSoleBlob()->IsBodyEmpty()) { return false; } const auto& shape = sole_blob->shape(); for (int i = 0; i < shape.NumAxes(); ++i) { if (shape.At(i) != 0) { return false; } } return true; } } // namespace ActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer, Regst* regst_raw_ptr) { ActorMsg msg{}; msg.src_actor_id_ = producer; msg.dst_actor_id_ = consumer; msg.msg_type_ = ActorMsgType::kRegstMsg; msg.regst_wrapper_.regst = regst_raw_ptr; msg.regst_wrapper_.comm_net_token = nullptr; msg.regst_wrapper_.regst_desc_id = regst_raw_ptr->regst_desc_id(); msg.regst_wrapper_.has_sole_empty_blob = IsSoleBlobAndDynamicEmpty(regst_raw_ptr); msg.regst_wrapper_.is_data_regst_to_consumer = regst_raw_ptr->regst_desc()->regst_desc_type().has_data_regst_desc(); return msg; } ActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t consumer, int64_t producer, Regst* regst_raw_ptr) { ActorMsg msg{}; msg.src_actor_id_ = consumer; msg.dst_actor_id_ = producer; msg.msg_type_ = ActorMsgType::kRegstMsg; msg.regst_wrapper_.regst = regst_raw_ptr; msg.regst_wrapper_.regst_desc_id = -1; msg.regst_wrapper_.comm_net_token = nullptr; // you can NOT access the regst ptr when multi nodes, because the address is in another machine msg.regst_wrapper_.has_sole_empty_blob = false; msg.regst_wrapper_.is_data_regst_to_consumer = false; return msg; } ActorMsg ActorMsg::BuildEordMsg(int64_t consumer, int64_t regst_desc_id) { ActorMsg msg{}; msg.src_actor_id_ = -1; msg.dst_actor_id_ = consumer; msg.msg_type_ = ActorMsgType::kEordMsg; msg.eord_regst_desc_id_ = regst_desc_id; return msg; } ActorMsg ActorMsg::BuildCommandMsg(int64_t dst_actor_id, ActorCmd cmd) { ActorMsg msg{}; msg.src_actor_id_ = -1; msg.dst_actor_id_ = dst_actor_id; msg.msg_type_ = ActorMsgType::kCmdMsg; msg.actor_cmd_ = cmd; return msg; } int64_t ActorMsg::SrcMachineId() const { return MachineId4ActorId(src_actor_id_); } ActorCmd ActorMsg::actor_cmd() const { CHECK_EQ(msg_type_, ActorMsgType::kCmdMsg); return actor_cmd_; } Regst* ActorMsg::regst() const { CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg); return regst_wrapper_.regst; } int64_t ActorMsg::regst_desc_id() const { CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg); // FIXME(liujunchneg): regst_desc_id for remote returned regst if (MachineId4ActorId(src_actor_id_) == GlobalProcessCtx::Rank()) { return regst_wrapper_.regst->regst_desc_id(); } else { return regst_wrapper_.regst_desc_id; } } int64_t ActorMsg::comm_net_sequence_number() const { CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg); return regst_wrapper_.comm_net_sequence_number; } void ActorMsg::set_comm_net_sequence_number(int64_t sequence_number) { CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg); regst_wrapper_.comm_net_sequence_number = sequence_number; } void* ActorMsg::comm_net_token() const { CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg); return regst_wrapper_.comm_net_token; } void ActorMsg::set_comm_net_token(void* token) { CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg); regst_wrapper_.comm_net_token = token; } bool ActorMsg::has_sole_empty_blob() const { CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg); return regst_wrapper_.has_sole_empty_blob; } int64_t 
ActorMsg::eord_regst_desc_id() const { CHECK_EQ(msg_type_, ActorMsgType::kEordMsg); return eord_regst_desc_id_; } bool ActorMsg::IsDataRegstMsgToConsumer() const { return msg_type_ == ActorMsgType::kRegstMsg && regst_wrapper_.is_data_regst_to_consumer; } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/actor_message.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_H_ #define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/register/register.h" namespace oneflow { enum class ActorCmd { kStart = 0, // Source Actor kStopThread, kConstructActor }; enum class ActorMsgType : int8_t { kRegstMsg = 0, kEordMsg, kCmdMsg }; class ActorMsg final { public: ActorMsg() = default; ~ActorMsg() = default; // Build Msg static ActorMsg BuildRegstMsgToConsumer(int64_t producer, int64_t consumer, Regst*); static ActorMsg BuildRegstMsgToProducer(int64_t consumer, int64_t producer, Regst*); static ActorMsg BuildEordMsg(int64_t consumer, int64_t regst_desc_id); static ActorMsg BuildCommandMsg(int64_t dst_actor_id, ActorCmd cmd); // Getters int64_t SrcMachineId() const; int64_t src_actor_id() const { return src_actor_id_; } int64_t dst_actor_id() const { return dst_actor_id_; } ActorMsgType msg_type() const { return msg_type_; } ActorCmd actor_cmd() const; Regst* regst() const; int64_t regst_desc_id() const; void* comm_net_token() const; void set_comm_net_token(void* token); bool has_sole_empty_blob() const; int64_t eord_regst_desc_id() const; bool IsDataRegstMsgToConsumer() const; int64_t comm_net_sequence_number() const; void set_comm_net_sequence_number(int64_t sequence_number); // Serialize template void Serialize(StreamT& out_stream) const { out_stream.Write(this, sizeof(ActorMsg)); } template void Deserialize(StreamT& in_stream) { in_stream.Read(this, sizeof(ActorMsg)); } void set_dst_actor_id(int64_t actor_id) { dst_actor_id_ = actor_id; } private: struct RegstWrapper { Regst* regst; void* comm_net_token; int64_t comm_net_sequence_number; int64_t regst_desc_id; bool has_sole_empty_blob; bool is_data_regst_to_consumer; }; int64_t src_actor_id_; int64_t dst_actor_id_; union { ActorCmd actor_cmd_; RegstWrapper regst_wrapper_; int64_t eord_regst_desc_id_; }; ActorMsgType msg_type_; }; template StreamT& operator<<(StreamT& out_stream, const ActorMsg& msg) { msg.Serialize(out_stream); } template StreamT& operator>>(StreamT& in_stream, const ActorMsg& msg) { msg.Deserialize(in_stream); } } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_H_ ================================================ FILE: oneflow/core/lazy/actor/actor_message_bus.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor_message_bus.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/comm_network/comm_network.h" namespace oneflow { void ActorMsgBus::SendMsg(const ActorMsg& msg) { int64_t dst_machine_id = MachineId4ActorId(msg.dst_actor_id()); if (dst_machine_id == GlobalProcessCtx::Rank()) { SendMsgWithoutCommNet(msg); } else { if (msg.IsDataRegstMsgToConsumer()) { int64_t comm_net_sequence; { std::unique_lock lock( regst_desc_id_dst_actor_id2comm_net_sequence_number_mutex_); int64_t& comm_net_sequence_ref = regst_desc_id_dst_actor_id2comm_net_sequence_number_[std::make_pair( msg.regst_desc_id(), msg.dst_actor_id())]; comm_net_sequence = comm_net_sequence_ref; comm_net_sequence_ref += 1; } ActorMsg new_msg = msg; new_msg.set_comm_net_sequence_number(comm_net_sequence); Singleton::Get()->SendActorMsg(dst_machine_id, new_msg); } else { Singleton::Get()->SendActorMsg(dst_machine_id, msg); } } } void ActorMsgBus::SendMsgWithoutCommNet(const ActorMsg& msg) { CHECK_EQ(MachineId4ActorId(msg.dst_actor_id()), GlobalProcessCtx::Rank()); int64_t thrd_id = ThrdId4ActorId(msg.dst_actor_id()); Singleton::Get()->GetThrd(thrd_id)->EnqueueActorMsg(msg); } void ActorMsgBus::SendMsgsWithoutCommNet(const ActorMsg* msgs, size_t n, int64_t thrd_id) { Singleton::Get()->GetThrd(thrd_id)->EnqueueActorMsg(msgs, msgs + n); } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/actor_message_bus.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_BUS_H_ #define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_BUS_H_ #include "oneflow/core/lazy/actor/actor_message.h" #include "oneflow/core/common/util.h" namespace oneflow { class ActorMsgBus final { public: OF_DISALLOW_COPY_AND_MOVE(ActorMsgBus); ~ActorMsgBus() = default; void SendMsg(const ActorMsg& msg); void SendMsgWithoutCommNet(const ActorMsg& msg); void SendMsgsWithoutCommNet(const ActorMsg* msgs, size_t n, int64_t thrd_id); private: friend class Singleton; ActorMsgBus() = default; HashMap, int64_t> regst_desc_id_dst_actor_id2comm_net_sequence_number_; std::mutex regst_desc_id_dst_actor_id2comm_net_sequence_number_mutex_; }; } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_BUS_H_ ================================================ FILE: oneflow/core/lazy/actor/boxing_zeros_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/naive_actor.h" namespace oneflow { class BoxingZerosActor : public NaiveActor { public: OF_DISALLOW_COPY_AND_MOVE(BoxingZerosActor); BoxingZerosActor() = default; ~BoxingZerosActor() override = default; void VirtualActorInit(const TaskProto& task_proto) override { NaiveActor::VirtualActorInit(task_proto); out_inited_ = false; } private: void Act() override { if (!out_inited_) { NaiveActor::Act(); out_inited_ = true; } } void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override { HandleProducedNaiveDataRegstToConsumer(); } bool out_inited_; }; REGISTER_ACTOR(TaskType::kBoxingZeros, BoxingZerosActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/callback_notify_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/sink_actor.h" namespace oneflow { class CallbackNotifyActor final : public SinkActor { public: OF_DISALLOW_COPY_AND_MOVE(CallbackNotifyActor); CallbackNotifyActor() = default; ~CallbackNotifyActor() = default; private: }; REGISTER_ACTOR(TaskType::kCallbackNotify, CallbackNotifyActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/case_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/kernel/case_kernel.h" #include "oneflow/core/operator/operator.h" namespace oneflow { class CaseActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(CaseActor); CaseActor() : case_status_(nullptr) {} ~CaseActor() override = default; protected: bool IsCustomizedReadReady() const override; bool IsCustomizedWriteReady() const override; bool IsCustomizedReadAlwaysUnReadyFromNow() const override; void UpdtStateAsCustomizedProducedRegst(Regst* regst) override; void AsyncSendCustomizedProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; void ForEachCurCustomizedReadableRegst(std::function) const override; void VirtualActorInit(const TaskProto&) override; bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override; void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } std::pair> GetNaiveOrCustomizedProducedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } private: void Act() override; void TakeOverConsumedRegst(const PbMap& consumed_ids); void TakeOverProducedRegst(const PbMap& produced_ids); bool IsInputOrOutputReady() const; int64_t GetCurSelectId() const; HashMap out_bn_id2regst_desc_id_; int64_t consumed_regst_desc_id_{}; RegstSlot consumed_rs_; HashMap regst_desc_id2produced_rs_; CaseStatus* case_status_; }; void CaseActor::VirtualActorInit(const TaskProto& task_proto) { CHECK_EQ(1, exec_kernel_vec().size()); case_status_ = CHECK_NOTNULL(dynamic_cast(exec_kernel_vec().at(0).kernel_ctx->state().get())); const int32_t output_bns_size = task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().output_bns_size(); FOR_RANGE(int64_t, i, 0, output_bns_size) { const int64_t regst_desc_id = exec_kernel_vec().at(0).bn_in_op2blob_info.at(GenRepeatedBn("out", i)).regst_desc_id; CHECK(out_bn_id2regst_desc_id_.emplace(i, regst_desc_id).second); } TakeOverConsumedRegst(task_proto.consumed_regst_desc_id()); TakeOverProducedRegst(task_proto.produced_regst_desc()); OF_SET_MSG_HANDLER(&CaseActor::HandlerNormal); } void CaseActor::TakeOverConsumedRegst(const PbMap& consumed_ids) { CHECK_EQ(consumed_ids.size(), 1); const auto& pair = *consumed_ids.begin(); CHECK_EQ(pair.second.regst_desc_id_size(), 1); consumed_regst_desc_id_ = pair.second.regst_desc_id(0); consumed_rs_.InsertRegstDescId(consumed_regst_desc_id_); consumed_rs_.InitedDone(); } void CaseActor::TakeOverProducedRegst(const PbMap& produced_ids) { for (const auto& pair : produced_ids) { CHECK(pair.second.regst_desc_type().has_data_regst_desc()); CHECK_EQ(pair.second.has_inplace_consumed_regst_desc_id(), false); const int64_t regst_desc_id = pair.second.regst_desc_id(); regst_desc_id2produced_rs_[regst_desc_id].InsertRegstDescId(regst_desc_id); regst_desc_id2produced_rs_.at(regst_desc_id).InitedDone(); } ForEachProducedRegst([&](Regst* regst) { const int64_t regst_desc_id = regst->regst_desc_id(); CHECK_EQ(0, regst_desc_id2produced_rs_.at(regst_desc_id).TryPushBackRegst(regst)); }); } // twice called for each output // first called: set cur_selected_id // second called: output cur_selected_id void CaseActor::Act() { Regst* const consumed_regst = consumed_rs_.Front(consumed_regst_desc_id_); case_status_->cur_selected_id = GetCurSelectId(); case_status_->cmd = 
(case_status_->cur_selected_id == -1 ? kCaseCmdHandleInput : kCaseCmdHandleOutput); AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* { if (consumed_regst_desc_id_ == regst_desc_id) { return consumed_regst; } return regst_desc_id2produced_rs_.at(regst_desc_id).Front(regst_desc_id); }); } void CaseActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) { const int64_t regst_desc_id = regst->regst_desc_id(); CHECK_EQ(0, regst_desc_id2produced_rs_.at(regst_desc_id).TryPushBackRegst(regst)); } bool CaseActor::IsCustomizedReadReady() const { return IsInputOrOutputReady(); } bool CaseActor::IsCustomizedWriteReady() const { return IsInputOrOutputReady(); } bool CaseActor::IsCustomizedReadAlwaysUnReadyFromNow() const { return ReceiveEordMsg(consumed_regst_desc_id_) && case_status_->select_id2request_cnt.size() == 0; } bool CaseActor::IsInputOrOutputReady() const { if (GetCurSelectId() != -1) { return true; } return consumed_rs_.IsCurSlotReady(); } int64_t CaseActor::GetCurSelectId() const { for (const auto& pair : case_status_->select_id2request_cnt) { CHECK_GT(pair.second, 0); const int64_t regst_desc_id = out_bn_id2regst_desc_id_.at(pair.first); if (regst_desc_id2produced_rs_.at(regst_desc_id).IsCurSlotReady()) { return pair.first; } } return -1; } void CaseActor::ForEachCurCustomizedReadableRegst(std::function Handler) const { Handler(consumed_rs_.Front(consumed_regst_desc_id_)); } void CaseActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { if (case_status_->cmd != kCaseCmdHandleInput) { return; } Regst* const cur_regst = consumed_rs_.Front(consumed_regst_desc_id_); CHECK_NOTNULL(cur_regst); AsyncSendRegstMsgToProducer(cur_regst); CHECK_EQ(0, consumed_rs_.TryPopFrontRegst(consumed_regst_desc_id_)); } void CaseActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst())); } void CaseActor::AsyncSendCustomizedProducedRegstMsgToConsumer() { if (case_status_->cmd != kCaseCmdHandleOutput) { return; } const int64_t regst_desc_id = out_bn_id2regst_desc_id_.at(case_status_->cur_selected_id); Regst* const regst = regst_desc_id2produced_rs_.at(regst_desc_id).Front(regst_desc_id); CHECK_GT(HandleRegstToConsumer(regst), 0); regst_desc_id2produced_rs_.at(regst_desc_id).PopFrontRegsts({regst_desc_id}); } bool CaseActor::ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; } REGISTER_ACTOR(kCase, CaseActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/collective_boxing_actor_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/collective_boxing_actor_context.h" #include "oneflow/core/job/collective_boxing/scheduler.h" namespace oneflow { using namespace boxing::collective; void CollectiveBoxingActorContext::Init(const TaskProto& task_proto, StreamContext* stream_ctx) { stream_ctx_ = stream_ctx; task_proto_ = task_proto; scheduled_count_ = 0; completed_count_ = 0; } void CollectiveBoxingActorContext::AddCallback(std::function callback) { std::lock_guard lock(mutex_); if (scheduled_count_ == completed_count_) { callback(); } else { callbacks_.emplace_back(std::make_pair(scheduled_count_ - 1, std::move(callback))); } } void CollectiveBoxingActorContext::Schedule(RequestHandle* handle, const void* send_buff, void* recv_buff) { std::lock_guard lock(mutex_); auto request = std::make_shared(); request->send_buff = send_buff; request->recv_buff = recv_buff; const size_t schedule_id = scheduled_count_; request->callback = [schedule_id, this](const Maybe& status) { CHECK(status.IsOk()); this->SetCompleted(schedule_id); }; Singleton::Get()->Schedule(handle, request); scheduled_count_ += 1; } void CollectiveBoxingActorContext::SetCompleted(size_t schedule_id) { std::lock_guard lock(mutex_); CHECK_EQ(schedule_id, completed_count_); while (!callbacks_.empty() && callbacks_.front().first == schedule_id) { callbacks_.front().second(); callbacks_.pop_front(); } completed_count_ += 1; } StreamContext* CollectiveBoxingActorContext::stream_ctx() const { return stream_ctx_; } const TaskProto& CollectiveBoxingActorContext::task_proto() const { return task_proto_; } REGISTER_ACTOR_CONTEXT(TaskType::kCollectiveBoxingGeneric, CollectiveBoxingActorContext); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/collective_boxing_actor_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_LAZY_ACTOR_COLLECTIVE_BOXING_ACTOR_CONTEXT_H_ #define ONEFLOW_CORE_LAZY_ACTOR_COLLECTIVE_BOXING_ACTOR_CONTEXT_H_ #include "oneflow/core/lazy/actor/actor_context.h" #include "oneflow/core/job/collective_boxing/scheduler.h" namespace oneflow { class CollectiveBoxingActorContext : public ActorContext { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingActorContext); CollectiveBoxingActorContext() = default; ~CollectiveBoxingActorContext() override = default; void Init(const TaskProto& task_proto, StreamContext* stream_ctx) override; void AddCallback(std::function callback) override; void Schedule(boxing::collective::RequestHandle* handle, const void* send_buff, void* recv_buff); void SetCompleted(size_t schedule_id); StreamContext* stream_ctx() const override; const TaskProto& task_proto() const override; private: StreamContext* stream_ctx_{}; TaskProto task_proto_{}; size_t scheduled_count_{}; size_t completed_count_{}; std::mutex mutex_; std::deque>> callbacks_; }; } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_COLLECTIVE_BOXING_ACTOR_CONTEXT_H_ ================================================ FILE: oneflow/core/lazy/actor/copy_comm_net_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/register/register.h" namespace oneflow { class CopyCommNetActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(CopyCommNetActor); CopyCommNetActor() = default; ~CopyCommNetActor(); private: struct RegstCtx { void* comm_net_token; Regst* regst_raw_ptr; int64_t producer; bool has_sole_empty_blob; }; void VirtualActorInit(const TaskProto&) override; std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } void ForEachCurCustomizedReadableRegst(std::function) const override; void NormalProcessCustomizedEordMsg(const ActorMsg&) override { is_in_eord_ = true; } bool NormalTryProcessReadableMsgFromOtherMachine(const ActorMsg&) override; void Act() override; void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; bool IsCustomizedReadReady() const override; bool IsCustomizedReadAlwaysUnReadyFromNow() const override; void AsyncReturnAllCustomizedReadableRegst() override; void AddCallback(std::function callback) override; bool is_in_eord_; HashMap sequence_number2regst_ctx_; void* actor_read_id_; int64_t next_sequence_number_; int64_t in_regst_desc_id_; }; CopyCommNetActor::~CopyCommNetActor() { Singleton::Get()->DeleteActorReadId(actor_read_id_); } void CopyCommNetActor::VirtualActorInit(const TaskProto& task_proto) { is_in_eord_ = false; next_sequence_number_ = 0; in_regst_desc_id_ = Name2SoleRegstDescId("copy_in"); actor_read_id_ = Singleton::Get()->NewActorReadId(); OF_SET_MSG_HANDLER(&CopyCommNetActor::HandlerNormal); } void CopyCommNetActor::ForEachCurCustomizedReadableRegst( std::function handler) const { handler(sequence_number2regst_ctx_.at(next_sequence_number_).regst_raw_ptr); } bool CopyCommNetActor::NormalTryProcessReadableMsgFromOtherMachine(const ActorMsg& msg) { RegstCtx regst_ctx; regst_ctx.comm_net_token = msg.comm_net_token(); regst_ctx.regst_raw_ptr = msg.regst(); regst_ctx.producer = msg.src_actor_id(); regst_ctx.has_sole_empty_blob = msg.has_sole_empty_blob(); CHECK(sequence_number2regst_ctx_.emplace(msg.comm_net_sequence_number(), regst_ctx).second); return true; } void CopyCommNetActor::Act() { // readable auto readable_it = sequence_number2regst_ctx_.find(next_sequence_number_); void* readable_token = readable_it->second.comm_net_token; int64_t src_actor_id = readable_it->second.producer; int64_t src_machine_id = MachineId4ActorId(src_actor_id); // writeable Regst* writeable_regst = GetNaiveCurWriteable("copy_out"); if (readable_it->second.has_sole_empty_blob) { // pass if regst dynamic body is emtpy Blob* data_blob = writeable_regst->GetMutSoleBlob(); Shape empty_shape = data_blob->static_shape(); for (int i = 0; i < empty_shape.NumAxes(); ++i) { empty_shape.Set(i, 0); } data_blob->mut_shape_view()->set_shape(empty_shape); } else { void* writeable_token = writeable_regst->comm_net_token(); // Async Singleton::Get()->Read(actor_read_id_, src_machine_id, readable_token, writeable_token); } } void CopyCommNetActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { HandleProducedNaiveDataRegstToConsumer(); } void CopyCommNetActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { auto readable_it = sequence_number2regst_ctx_.find(next_sequence_number_); EnqueueAsyncMsg(ActorMsg::BuildRegstMsgToProducer(actor_id(), readable_it->second.producer, readable_it->second.regst_raw_ptr)); 
sequence_number2regst_ctx_.erase(readable_it); next_sequence_number_ += 1; } bool CopyCommNetActor::IsCustomizedReadReady() const { return sequence_number2regst_ctx_.find(next_sequence_number_) != sequence_number2regst_ctx_.end(); } bool CopyCommNetActor::IsCustomizedReadAlwaysUnReadyFromNow() const { return is_in_eord_ && sequence_number2regst_ctx_.empty(); } void CopyCommNetActor::AsyncReturnAllCustomizedReadableRegst() { CHECK(sequence_number2regst_ctx_.empty()); } void CopyCommNetActor::AddCallback(std::function callback) { Singleton::Get()->AddReadCallBack(actor_read_id_, callback); } REGISTER_ACTOR(TaskType::kCopyCommNet, CopyCommNetActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/esac_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/kernel/esac_kernel.h" namespace oneflow { class EsacActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(EsacActor); EsacActor() = default; ~EsacActor() override = default; protected: void VirtualActorInit(const TaskProto&) override; int64_t InBnId4RegstDescId(int64_t id) const { return regst_desc_id2in_bn_id_.at(id); } bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override; private: void Act() override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override; void ForEachCurCustomizedReadableRegst(std::function) const override; bool IsCustomizedReadReady() const override; void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return ReceiveAllEordMsg() && consumed_rs_.available_regst_desc_cnt() == 0; } void AsyncReturnAllCustomizedReadableRegst() override; std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; int64_t GetCurProcessedRegstDescId() const; RegstSlot consumed_rs_; int64_t cur_processed_regst_desc_id_{}; HashMap regst_desc_id2in_bn_id_; }; void EsacActor::VirtualActorInit(const TaskProto& task_proto) { CHECK_EQ(1, exec_kernel_vec().size()); const int32_t input_bns_size = task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().input_bns_size(); FOR_RANGE(int64_t, i, 0, input_bns_size) { const int64_t regst_desc_id = exec_kernel_vec().at(0).bn_in_op2blob_info.at(GenRepeatedBn("in", i)).regst_desc_id; CHECK(regst_desc_id2in_bn_id_.emplace(regst_desc_id, i).second); } for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (const int64_t regst_desc_id : pair.second.regst_desc_id()) { consumed_rs_.InsertRegstDescId(regst_desc_id); } } consumed_rs_.InitedDone(); cur_processed_regst_desc_id_ = -1; OF_SET_MSG_HANDLER(&EsacActor::HandlerNormal); } void 
EsacActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst())); } bool EsacActor::IsCustomizedReadReady() const { return -1 != GetCurProcessedRegstDescId(); } void EsacActor::ForEachCurCustomizedReadableRegst(std::function handler) const { handler(consumed_rs_.Front(cur_processed_regst_desc_id_)); } void EsacActor::Act() { cur_processed_regst_desc_id_ = GetCurProcessedRegstDescId(); Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_); CHECK(cur_regst); int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id_); CHECK_EQ(exec_kernel_vec().size(), 1); CHECK_NOTNULL(dynamic_cast(exec_kernel_vec().at(0).kernel_ctx->state().get())) ->value = in_bn_id; AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* { if (cur_processed_regst_desc_id_ != regst_desc_id) { return nullptr; } return cur_regst; }); } void EsacActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { HandleProducedNaiveDataRegstToConsumer(); } void EsacActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_); CHECK(cur_regst); AsyncSendRegstMsgToProducer(cur_regst); CHECK_EQ(0, consumed_rs_.TryPopFrontRegst(cur_processed_regst_desc_id_)); cur_processed_regst_desc_id_ = -1; } void EsacActor::AsyncReturnAllCustomizedReadableRegst() { CHECK_EQ(-1, cur_processed_regst_desc_id_); CHECK_EQ(0, consumed_rs_.available_regst_desc_cnt()); } bool EsacActor::ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; } int64_t EsacActor::GetCurProcessedRegstDescId() const { int64_t cur_processed_regst_desc_id = -1; consumed_rs_.ForChosenRegstDeq( [&cur_processed_regst_desc_id](int64_t) { return cur_processed_regst_desc_id == -1; }, [&cur_processed_regst_desc_id](const std::deque& reg_deq) { if (reg_deq.empty()) { return; } cur_processed_regst_desc_id = reg_deq.front()->regst_desc_id(); }); return cur_processed_regst_desc_id; } REGISTER_ACTOR(kEsac, EsacActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/generic_actor_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/generic_actor_context.h" namespace oneflow { void GenericActorContext::Init(const TaskProto& task_proto, StreamContext* stream_ctx) { stream_ctx_ = stream_ctx; task_proto_ = task_proto; } void GenericActorContext::AddCallback(std::function callback) { CHECK_JUST(stream_ctx_->AddCallback(std::move(callback))); } StreamContext* GenericActorContext::stream_ctx() const { return stream_ctx_; } const TaskProto& GenericActorContext::task_proto() const { return task_proto_; } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/generic_actor_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_GENERIC_ACTOR_CONTEXT_H_ #define ONEFLOW_CORE_LAZY_ACTOR_GENERIC_ACTOR_CONTEXT_H_ #include "oneflow/core/lazy/actor/actor_context.h" namespace oneflow { class GenericActorContext : public ActorContext { public: OF_DISALLOW_COPY_AND_MOVE(GenericActorContext); GenericActorContext() = default; ~GenericActorContext() override = default; void Init(const TaskProto& task_proto, StreamContext* stream_ctx) override; void AddCallback(std::function callback) override; StreamContext* stream_ctx() const override; const TaskProto& task_proto() const override; private: StreamContext* stream_ctx_{}; TaskProto task_proto_{}; }; } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_GENERIC_ACTOR_CONTEXT_H_ ================================================ FILE: oneflow/core/lazy/actor/input_wise_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/input_wise_actor.h" namespace oneflow { void InputWiseActor::Init(const TaskProto& task_proto) { CHECK_EQ(1, exec_kernel_vec().size()); const auto& input_bns = task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().input_bns(); HashMap ibn2in_bn_id; for (int64_t i = 0; i < input_bns.size(); ++i) { CHECK(ibn2in_bn_id.emplace(input_bns.Get(i), i).second); } for (const auto& pair : exec_kernel_vec().at(0).bn_in_op2blob_info) { auto it = ibn2in_bn_id.find(pair.first); if (it != ibn2in_bn_id.end()) { CHECK(regst_desc_id2in_bn_id_.emplace(pair.second.regst_desc_id, it->second).second); } } for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (int64_t regst_desc_id : pair.second.regst_desc_id()) { consumed_rs_.InsertRegstDescId(regst_desc_id); CHECK(regst_desc_id2is_processed_.emplace(regst_desc_id, false).second); } } consumed_rs_.InitedDone(); cur_processed_regst_desc_id_ = -1; processed_regst_desc_id_cnt_ = 0; OF_SET_MSG_HANDLER(&InputWiseActor::HandlerNormal); } void InputWiseActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst())); } bool InputWiseActor::IsCustomizedReadReady() const { return -1 != GetCurProcessedRegstDescId(); } void InputWiseActor::ForEachCurCustomizedReadableRegst( std::function handler) const { handler(consumed_rs_.Front(cur_processed_regst_desc_id_)); } void InputWiseActor::Act() { cur_processed_regst_desc_id_ = GetCurProcessedRegstDescId(); Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_); CHECK(cur_regst); AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* { if (cur_processed_regst_desc_id_ != regst_desc_id) { return nullptr; } return cur_regst; }); processed_regst_desc_id_cnt_ += 1; regst_desc_id2is_processed_.at(cur_processed_regst_desc_id_) = true; } void InputWiseActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { if (processed_regst_desc_id_cnt_ == regst_desc_id2is_processed_.size()) { HandleProducedNaiveDataRegstToConsumer(); for (auto& pair : regst_desc_id2is_processed_) { CHECK(pair.second); pair.second = false; } processed_regst_desc_id_cnt_ = 0; } } void InputWiseActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_); CHECK(cur_regst); AsyncSendRegstMsgToProducer(cur_regst); CHECK_EQ(0, consumed_rs_.TryPopFrontRegst(cur_processed_regst_desc_id_)); cur_processed_regst_desc_id_ = -1; } void InputWiseActor::AsyncReturnAllCustomizedReadableRegst() { CHECK_EQ(-1, cur_processed_regst_desc_id_); CHECK_EQ(0, processed_regst_desc_id_cnt_); CHECK_EQ(0, consumed_rs_.available_regst_desc_cnt()); } bool InputWiseActor::ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; } int64_t InputWiseActor::GetCurProcessedRegstDescId() const { int64_t cur_processed_regst_desc_id = -1; consumed_rs_.ForChosenRegstDeq( [cur_processed_regst_desc_id](int64_t) { return cur_processed_regst_desc_id == -1; }, [this, &cur_processed_regst_desc_id](const std::deque& reg_deq) { if (reg_deq.empty()) { return; } int64_t regst_desc_id = reg_deq.front()->regst_desc_id(); if (regst_desc_id2is_processed_.at(regst_desc_id) == false) { cur_processed_regst_desc_id = regst_desc_id; } }); return cur_processed_regst_desc_id; } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/input_wise_actor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_INPUT_WISE_ACTOR_H_ #define ONEFLOW_CORE_LAZY_ACTOR_INPUT_WISE_ACTOR_H_ #include "oneflow/core/lazy/actor/actor.h" namespace oneflow { class InputWiseActor : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(InputWiseActor); InputWiseActor() = default; ~InputWiseActor() = default; using Actor::Init; protected: void Init(const TaskProto&); int64_t cur_processed_regst_desc_id() const { return cur_processed_regst_desc_id_; } int64_t processed_regst_desc_id_cnt() const { return processed_regst_desc_id_cnt_; } int64_t RegstDescNum() const { return consumed_rs_.total_regst_desc_cnt(); } int64_t InBnId4RegstDescId(int64_t id) const { return regst_desc_id2in_bn_id_.at(id); } bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override; private: void Act() override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override; void ForEachCurCustomizedReadableRegst(std::function) const override; bool IsCustomizedReadReady() const override; void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return ReceiveAllEordMsg() && consumed_rs_.available_regst_desc_cnt() == 0; } void AsyncReturnAllCustomizedReadableRegst() override; std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; int64_t GetCurProcessedRegstDescId() const; RegstSlot consumed_rs_; HashMap regst_desc_id2is_processed_; int64_t processed_regst_desc_id_cnt_; int64_t cur_processed_regst_desc_id_; HashMap regst_desc_id2in_bn_id_; }; } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_INPUT_WISE_ACTOR_H_ ================================================ FILE: oneflow/core/lazy/actor/light_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/actor_base.h" #include "oneflow/core/register/register.h" #include "oneflow/core/kernel/kernel_context.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/register/register_manager.h" #include "oneflow/core/lazy/actor/actor_message.h" #include "oneflow/core/lazy/actor/actor_message_bus.h" #include "oneflow/core/thread/thread.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/job/runtime_job_descs.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/kernel/user_kernel.h" #include "oneflow/core/lazy/stream_context/include/stream_context.h" #ifdef WITH_CUDA #include "oneflow/core/ep/cuda/cuda_stream.h" #endif // WITH_CUDA namespace oneflow { namespace { enum RegstType : int8_t { kInvalid = 0, kProduced, kConsumed, }; template struct ProducedRegstState { IndexType reading_cnt; IndexType max_reading_cnt; }; struct ConsumedRegstState { bool ready; bool eord; }; template struct RegstState { Regst* regst; RegstType regst_type; union { ProducedRegstState produced; ConsumedRegstState consumed; }; }; struct KernelInfo { std::unique_ptr kernel; HashMap bn_in_op2blob; std::shared_ptr state; }; struct DebugInfo { int64_t actor_id; std::string op_name; int64_t act_cnt; DebugInfo() : actor_id(-1), op_name(""), act_cnt(-1) {} }; template struct ArrayBaseIndex { ArrayBaseIndex() { std::memset(this, 0, sizeof(*this)); } inline IndexType Size() const { return size; } void Reserve(IndexType new_size) { CHECK_LE(new_size, max_size); } inline IndexType Lookup(int64_t v) const { for (IndexType i = 0; i < size; ++i) { if (arr[i] == v) { return i; } } CHECK(false); return -1; } bool Contains(int64_t v) const { for (IndexType i = 0; i < size; ++i) { if (arr[i] == v) { return true; } } return false; } IndexType Add(int64_t v) { CHECK_LT(size, max_size); const IndexType index = size; size += 1; arr[index] = v; return index; } void GetValues(std::vector* values) const { values->resize(size); for (IndexType i = 0; i < size; ++i) { values->at(i) = arr[i]; } } std::array arr; IndexType size; }; template struct MapBaseIndex { inline IndexType Size() const { return index_map.size(); } void Reserve(IndexType size) { index_map.reserve(size); } inline IndexType Lookup(int64_t v) { auto it = index_map.find(v); CHECK(it != index_map.end()); return it->second; } bool Contains(int64_t v) { return index_map.count(v) > 0; } IndexType Add(int64_t v) { const IndexType index = index_map.size(); CHECK(index_map.emplace(v, index).second); return index; } void GetValues(std::vector* values) const { values->resize(index_map.size()); for (const auto& pair : index_map) { values->at(pair.second) = pair.first; } } HashMap index_map; }; template struct ArrayBaseStateContainer { ArrayBaseStateContainer() { std::memset(this, 0, sizeof(*this)); } void Resize(IndexType new_size) { CHECK_LE(new_size, max_size); size = new_size; } inline IndexType Size() const { return size; } inline RegstState& Get(IndexType index) { CHECK_LT(index, size); return arr[index]; } std::array, max_size> arr; IndexType size; }; template struct VectorBaseStateContainer { void Resize(IndexType new_size) { vec.resize(new_size); } inline IndexType Size() const { return static_cast(vec.size()); } inline RegstState& Get(IndexType index) { return vec.at(index); } std::vector> vec; }; bool IsInplaceRegstDesc(const RegstDescProto& regst_desc) { return regst_desc.has_inplace_consumed_regst_desc_id() && 
regst_desc.consumer_task_id_size() > 0; } size_t GetRegstDescCount(const TaskProto& task) { size_t regst_cnt = task.produced_regst_desc().size(); for (const auto& pair : task.consumed_regst_desc_id()) { regst_cnt += pair.second.regst_desc_id_size(); } return regst_cnt; } size_t GetConsumerCount(const TaskProto& task) { size_t consumer_cnt = 0; for (const auto& pair : task.produced_regst_desc()) { consumer_cnt += pair.second.consumer_task_id_size(); } return consumer_cnt; } bool NeedExecKernelWhenInplace(const TaskProto& task) { int64_t data_regst_cnt = 0; for (const auto& pair : task.produced_regst_desc()) { if (pair.second.regst_desc_type().has_data_regst_desc()) { if (data_regst_cnt != 0) { return true; } data_regst_cnt += 1; const DataRegstDesc& regst_desc = pair.second.regst_desc_type().data_regst_desc(); if (regst_desc.lbi2blob_desc().size() != 1) { return true; } if (regst_desc.lbi2blob_desc().begin()->blob_desc().is_dynamic()) { return true; } } } if (data_regst_cnt != 1) { return true; } if (task.exec_sequence().exec_node().size() != 1) { return true; } const OperatorConf& op_conf = task.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); if (!op_conf.has_user_conf()) { return true; } const std::string& op_type = op_conf.user_conf().op_type_name(); const bool is_const_inplace_op_type = (op_type == "expand_dims") || (op_type == "squeeze") || (op_type == "reshape") || (op_type == "reshape_like") || (op_type == "transpose"); if (!is_const_inplace_op_type) { return true; } return false; } #ifdef WITH_CUDA_GRAPHS bool IsCUDAGraphSupported(const Kernel* kernel) { auto* user_kernel = dynamic_cast(kernel); return (user_kernel != nullptr && user_kernel->IsCudaGraphSupported()); } #endif // WITH_CUDA_GRAPHS template class LightActor : public ActorBase, public KernelContext, public ActorContextProvider { public: OF_DISALLOW_COPY_AND_MOVE(LightActor); explicit LightActor(ActorContext* actor_ctx) : thread_(nullptr), actor_ctx_(actor_ctx), stream_ctx_(actor_ctx->stream_ctx()), stream_kernel_observer_(nullptr) { auto* kernel_observer_provider = dynamic_cast(stream_ctx_); if (kernel_observer_provider != nullptr) { stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver(); } } ~LightActor() override { for (IndexType i = 0; i < index2state_.Size(); ++i) { auto& state = index2state_.Get(i); if (state.regst_type == RegstType::kProduced) { delete state.regst; } } } void Init(const JobDesc* job_desc, ActorContext* actor_ctx) override { const TaskProto& task_proto = actor_ctx->task_proto(); CHECK_EQ(task_proto.exec_sequence().exec_node_size(), 1); if (debug) { debug_info_[0].reset(new DebugInfo()); debug_info_[0]->op_name = task_proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf().name(); debug_info_[0]->actor_id = task_proto.task_id(); debug_info_[0]->act_cnt = 0; } if (exec_kernel) { kernel_info_[0].reset(new KernelInfo()); const KernelConf& kernel_conf = task_proto.exec_sequence().exec_node(0).kernel_conf(); kernel_info_[0]->kernel = ConstructKernel(kernel_conf, this); #ifdef WITH_CUDA_GRAPHS auto* cuda_stream = dynamic_cast(actor_ctx->stream_ctx()->stream()); if (cuda_stream != nullptr && kernel_conf.all_blobs_are_static() && IsCUDAGraphSupported(kernel_info_[0]->kernel.get())) { cuda_graph_exec_[0].reset(new ep::CudaGraphExecutable()); } #endif } const int64_t thrd_id = ThrdId4ActorId(task_proto.task_id()); thread_ = Singleton::Get()->GetThrd(thrd_id); total_reading_cnt_ = 0; max_total_reading_cnt_ = 0; remaining_eord_cnt_ = 0; ready_consumed_ = 0; 
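    // The readiness bookkeeping is rebuilt from the task proto below: each consumed regst desc
    // id adds one expected EORD and one slot to max_ready_consumed_, and each produced regst
    // desc adds its consumer count to max_total_reading_cnt_, which lets ProcessMsg() decide
    // read/write readiness with plain counter comparisons.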
max_ready_consumed_ = 0; const IndexType regst_cnt = GetRegstDescCount(task_proto); regst_desc_id_index_.Reserve(regst_cnt); index2state_.Resize(regst_cnt); IndexType inplace_produced_index = -1; IndexType inplace_consumed_index = -1; int64_t inplace_consumed_regst_desc_id = -1; for (const auto& pair : task_proto.produced_regst_desc()) { const RegstDescProto& regst_desc = pair.second; if (IsInplaceRegstDesc(regst_desc)) { CHECK_EQ(inplace_consumed_regst_desc_id, -1); inplace_consumed_regst_desc_id = regst_desc.inplace_consumed_regst_desc_id(); } } for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (int64_t regst_desc_id : pair.second.regst_desc_id()) { const IndexType index = regst_desc_id_index_.Add(regst_desc_id); auto& state = index2state_.Get(index); state.regst_type = RegstType::kConsumed; state.consumed.ready = false; state.consumed.eord = false; remaining_eord_cnt_ += 1; max_ready_consumed_ += 1; if (regst_desc_id == inplace_consumed_regst_desc_id) { inplace_consumed_index = index; } } } for (const auto& pair : task_proto.produced_regst_desc()) { const RegstDescProto& regst_desc = pair.second; const IndexType index = regst_desc_id_index_.Add(regst_desc.regst_desc_id()); auto& state = index2state_.Get(index); Singleton::Get()->NewRegsts(regst_desc, [&state](Regst* regst) { CHECK(state.regst == nullptr); state.regst = regst; }); state.produced.max_reading_cnt = regst_desc.consumer_task_id_size(); state.regst_type = RegstType::kProduced; state.produced.reading_cnt = 0; max_total_reading_cnt_ += state.produced.max_reading_cnt; if (IsInplaceRegstDesc(regst_desc)) { CHECK_EQ(inplace_produced_index, -1); inplace_produced_index = index; } } if (inplace) { CHECK_NE(inplace_produced_index, -1); CHECK_NE(inplace_consumed_index, -1); inplace_produced_index_[0] = inplace_produced_index; inplace_consumed_index_[0] = inplace_consumed_index; } else { CHECK_EQ(inplace_produced_index, -1); CHECK_EQ(inplace_consumed_index, -1); } } int ProcessMsg(const ActorMsg& msg) override { HandleActorMsg(msg); if (debug) { LOG(INFO) << " Actor: " << debug_info_[0]->actor_id << " op: " << debug_info_[0]->op_name << " in act_cnt: [ " << debug_info_[0]->act_cnt << " ] IsWriteReady: " << (total_reading_cnt_ == 0) << " IsReadReady: " << (ready_consumed_ == max_ready_consumed_) << " \n details: { total_reading_cnt = " << static_cast(total_reading_cnt_) << " (expect: 0) , ready_consumed_ = " << static_cast(ready_consumed_) << " (except: " << static_cast(max_ready_consumed_) << ") }"; } if (total_reading_cnt_ != 0) { return 0; } if (ready_consumed_ == max_ready_consumed_) { ActOnce(); return 0; } if (OF_PREDICT_FALSE(ready_consumed_ == 0 && remaining_eord_cnt_ == 0)) { SendEORDMsg(); return 1; } return 0; } private: void InitBnInOp2Blob() { if (exec_kernel) { const ExecNodeProto& node = actor_ctx_->task_proto().exec_sequence().exec_node(0); for (auto& pair : node.kernel_conf().op_attribute().arg_signature().bn_in_op2lbi()) { const std::string& bn = pair.first; auto regst_desc_id_it = node.bn_in_op2regst_desc_id().find(bn); if (regst_desc_id_it == node.bn_in_op2regst_desc_id().end()) { CHECK(kernel_info_[0]->bn_in_op2blob.emplace(bn, nullptr).second); continue; } if (!regst_desc_id_index_.Contains(regst_desc_id_it->second)) { CHECK(kernel_info_[0]->bn_in_op2blob.emplace(bn, nullptr).second); continue; } Regst* regst = index2state_.Get(regst_desc_id_index_.Lookup(regst_desc_id_it->second)).regst; if (regst == nullptr) { LOG(WARNING) << "null regst found, op:" << 
node.kernel_conf().op_attribute().op_conf().name(); CHECK(kernel_info_[0]->bn_in_op2blob.emplace(bn, nullptr).second); continue; } Blob* blob = regst->GetBlobByLbi(pair.second); if (!blob) { LOG(WARNING) << "null blob found, op: " << node.kernel_conf().op_attribute().op_conf().name(); } CHECK(kernel_info_[0]->bn_in_op2blob.emplace(bn, blob).second); } } } void InitActMsg() { const bool is_kernel_launch_synchronized = (!exec_kernel) || kernel_info_[0]->kernel->IsKernelLaunchSynchronized(); const int64_t actor_id = actor_ctx_->task_proto().task_id(); const int64_t thrd_id = ThrdId4ActorId(actor_id); auto IsSyncMsg = [&](const ActorMsg& msg) { return is_kernel_launch_synchronized && thrd_id == ThrdId4ActorId(msg.dst_actor_id()); }; auto EnqueueActorMsg = [&](const ActorMsg& msg) { if (IsSyncMsg(msg)) { sync_post_act_msgs_.emplace_back(msg); } else { async_post_act_msgs_.emplace_back(msg); } }; std::vector index2regst_desc_id; regst_desc_id_index_.GetValues(&index2regst_desc_id); for (IndexType i = 0; i < index2state_.Size(); ++i) { const auto& state = index2state_.Get(i); if (state.regst_type == RegstType::kProduced) { for (int64_t consumer : state.regst->consumers_actor_id()) { EnqueueActorMsg(ActorMsg::BuildRegstMsgToConsumer(actor_id, consumer, state.regst)); } } else if (state.regst_type == RegstType::kConsumed) { const int64_t regst_desc_id = index2regst_desc_id.at(i); int64_t producer = -1; if (Singleton::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { producer = Singleton::Get()->ProducerTaskId4RegstDescId(regst_desc_id); } else { producer = state.regst->producer_actor_id(); } ActorMsg msg = ActorMsg::BuildRegstMsgToProducer(actor_id, producer, state.regst); if (inplace && i == inplace_consumed_index_[0]) { if (IsSyncMsg(msg)) { return_inplace_consumed_fn_[0] = [this, msg]() { thread_->EnqueueActorMsg(msg); }; } else { return_inplace_consumed_fn_[0] = [this, msg]() { actor_ctx_->AddCallback([msg] { Singleton::Get()->SendMsg(msg); }); }; } } else { EnqueueActorMsg(msg); } } else { UNIMPLEMENTED(); } } } inline void ResetState() { total_reading_cnt_ = max_total_reading_cnt_; ready_consumed_ = 0; for (IndexType i = 0; i < index2state_.Size(); ++i) { auto& state = index2state_.Get(i); if (state.regst_type == RegstType::kProduced) { state.produced.reading_cnt = state.produced.max_reading_cnt; if (dynamic_allocation && state.produced.max_reading_cnt == 0 && state.regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { if (state.regst->allocation_type() == RegstAllocationType::kStreamOrdered) { if (inplace && i == inplace_produced_index_[0]) { // do nothing } else { CHECK_JUST( actor_ctx_->stream_ctx()->stream()->FreeAsync(state.regst->body_mem_ptr())); } state.regst->ResetBodyMemPtr(nullptr); } else if (state.regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } } } else if (state.regst_type == RegstType::kConsumed) { state.consumed.ready = false; } else { UNIMPLEMENTED(); } } } inline void HandleActorMsg(const ActorMsg& msg) { if (OF_PREDICT_TRUE(msg.msg_type() == ActorMsgType::kRegstMsg)) { HandleRegstMsg(msg); } else if (msg.msg_type() == ActorMsgType::kEordMsg) { HandleEordMsg(msg); } else if (msg.msg_type() == ActorMsgType::kCmdMsg) { CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart); } else { UNIMPLEMENTED() << msg.msg_type() << " " << actor_ctx_->task_proto().task_id(); } } void HandleEordMsg(const ActorMsg& msg) { const IndexType index = regst_desc_id_index_.Lookup(msg.eord_regst_desc_id()); auto& state = 
index2state_.Get(index); CHECK_EQ(state.regst_type, RegstType::kConsumed); CHECK_EQ(state.consumed.eord, false); state.consumed.eord = true; CHECK_GT(remaining_eord_cnt_, 0); remaining_eord_cnt_ -= 1; } inline void HandleRegstMsg(const ActorMsg& msg) { int64_t regst_desc_id = msg.regst_desc_id(); if (regst_desc_id == -1) { regst_desc_id = msg.regst()->regst_desc_id(); } if (debug) { LOG(INFO) << " Actor: " << debug_info_[0]->actor_id << " op: " << debug_info_[0]->op_name << " in act_cnt: [ " << debug_info_[0]->act_cnt << " ] , Recv ActorMsg from: " << msg.src_actor_id() << " to: " << msg.dst_actor_id() << " with regst: " << regst_desc_id; } const IndexType index = regst_desc_id_index_.Lookup(regst_desc_id); auto& state = index2state_.Get(index); if (state.regst_type == RegstType::kProduced) { CHECK_GT(state.produced.reading_cnt, 0); state.produced.reading_cnt -= 1; CHECK_GT(total_reading_cnt_, 0); total_reading_cnt_ -= 1; if (dynamic_allocation && state.produced.reading_cnt == 0 && state.regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { if (state.regst->allocation_type() == RegstAllocationType::kStreamOrdered) { if (inplace && index == inplace_produced_index_[0]) { // do nothing } else { CHECK_JUST(actor_ctx_->stream_ctx()->stream()->FreeAsync(state.regst->body_mem_ptr())); } state.regst->ResetBodyMemPtr(nullptr); } else if (state.regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } } if (inplace && index == inplace_produced_index_[0] && state.produced.reading_cnt == 0) { return_inplace_consumed_fn_[0](); } } else if (state.regst_type == RegstType::kConsumed) { CHECK_EQ(state.consumed.ready, false); CHECK_EQ(state.consumed.eord, false); if (state.regst == nullptr) { state.regst = msg.regst(); } else { CHECK(state.regst == msg.regst()); } ready_consumed_ += 1; } else { UNIMPLEMENTED(); } } inline void ActOnce() { if (OF_PREDICT_FALSE(sync_post_act_msgs_.empty() && async_post_act_msgs_.empty())) { InitBnInOp2Blob(); InitActMsg(); } for (IndexType i = 0; i < index2state_.Size(); ++i) { auto& state = index2state_.Get(i); if (dynamic_allocation && state.regst_type == RegstType::kProduced && state.regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { if (state.regst->allocation_type() == RegstAllocationType::kStreamOrdered) { CHECK(state.regst->body_mem_ptr() == nullptr); void* body_ptr = nullptr; if (inplace && i == inplace_produced_index_[0]) { body_ptr = index2state_.Get(inplace_consumed_index_[0]).regst->body_mem_ptr(); } else { CHECK_JUST(actor_ctx_->stream_ctx()->stream()->AllocAsync( &body_ptr, state.regst->regst_desc()->BodyByteSize4OneRegst())); } state.regst->ResetBodyMemPtr(body_ptr); } else if (state.regst->allocation_type() == RegstAllocationType::kStatic) { // do nothing } else { UNIMPLEMENTED(); } } } if (debug) { LOG(INFO) << " Actor: " << debug_info_[0]->actor_id << " op: " << debug_info_[0]->op_name << " Try to act act_cnt: [ " << debug_info_[0]->act_cnt << " ] before launch kernel."; } if (exec_kernel) { LaunchKernel(); } ResetState(); thread_->EnqueueActorMsg(sync_post_act_msgs_.cbegin(), sync_post_act_msgs_.cend()); if (!async_post_act_msgs_.empty()) { actor_ctx_->AddCallback([this]() { for (const auto& msg : async_post_act_msgs_) { Singleton::Get()->SendMsg(msg); } }); } if (debug) { for (const auto& msg : sync_post_act_msgs_) { LOG(INFO) << " Actor: " << debug_info_[0]->actor_id << " op: " << debug_info_[0]->op_name << " in act_cnt: [ " << debug_info_[0]->act_cnt << " ] Sync post ActorMsg from: " << 
msg.src_actor_id() << " to: " << msg.dst_actor_id() << " with regst: " << msg.regst_desc_id(); } for (const auto& msg : async_post_act_msgs_) { LOG(INFO) << " Actor: " << debug_info_[0]->actor_id << " op: " << debug_info_[0]->op_name << " in act_cnt: [ " << debug_info_[0]->act_cnt << " ] Async post ActorMsg from: " << msg.src_actor_id() << " to: " << msg.dst_actor_id() << " with regst: " << msg.regst_desc_id(); } LOG(INFO) << " Actor: " << debug_info_[0]->actor_id << " op: " << debug_info_[0]->op_name << " Finish act act_cnt: [ " << debug_info_[0]->act_cnt++ << " ]."; } } inline void LaunchKernel() { #ifdef WITH_CUDA_GRAPHS bool is_capturing = false; if (cuda_graph_exec_[0]) { auto* cuda_stream = stream_ctx_->stream()->As(); if (cuda_graph_exec_[0]->IsInstantiated()) { cuda_stream->LaunchGraph(cuda_graph_exec_[0].get()); return; } auto* user_kernel = CHECK_NOTNULL(dynamic_cast(kernel_info_[0]->kernel.get())); if (user_kernel->IsReadyForCudaGraphCapture(this)) { is_capturing = true; cuda_stream->BeginGraphCapture(); } } #endif kernel_info_[0]->kernel->Launch(this); #ifdef WITH_CUDA_GRAPHS if (cuda_graph_exec_[0] && is_capturing) { auto* cuda_stream = stream_ctx_->stream()->As(); cuda_stream->EndGraphCapture(cuda_graph_exec_[0].get()); cuda_stream->LaunchGraph(cuda_graph_exec_[0].get()); } #endif } void SendEORDMsg() { for (IndexType i = 0; i < index2state_.Size(); ++i) { auto& state = index2state_.Get(i); if (state.regst_type != RegstType::kProduced) { continue; } const RtRegstDesc* regst_desc = state.regst->regst_desc(); actor_ctx_->AddCallback([regst_desc]() { for (int64_t consumer : regst_desc->consumers_actor_id()) { Singleton::Get()->SendMsg( ActorMsg::BuildEordMsg(consumer, regst_desc->regst_desc_id())); } }); } } ep::Stream* stream() const override { return stream_ctx_->stream(); } ActorContext* GetActorContext() const override { return actor_ctx_; } Blob* BnInOp2Blob(const std::string& bn) const override { if (exec_kernel) { auto it = kernel_info_[0]->bn_in_op2blob.find(bn); if (it == kernel_info_[0]->bn_in_op2blob.end()) { return nullptr; } else { return it->second; } } else { return nullptr; } } const std::shared_ptr& state() const override { if (exec_kernel) { return kernel_info_[0]->state; } else { static const std::shared_ptr null_state; return null_state; } } void set_state(std::shared_ptr state) override { CHECK(exec_kernel); kernel_info_[0]->state = std::move(state); } void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override { Singleton::Get()->WillForward(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->WillForward(kernel_ctx, kernel); } } void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override { CHECK_JUST_MSG(kernel_ctx->stream()->GetAsyncError(), kernel->op_conf().name()); Singleton::Get()->DidForward(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->DidForward(kernel_ctx, kernel); } } void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override { Singleton::Get()->WillForwardHeader(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->WillForwardHeader(kernel_ctx, kernel); } } void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override { Singleton::Get()->DidForwardHeader(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->DidForwardHeader(kernel_ctx, kernel); } } void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override { 
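// Forward to the global kernel observer first, then to the stream-provided
// observer when the stream context supplies one.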
Singleton::Get()->WillForwardDataContent(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->WillForwardDataContent(kernel_ctx, kernel); } } void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override { Singleton::Get()->DidForwardDataContent(kernel_ctx, kernel); if (stream_kernel_observer_ != nullptr) { stream_kernel_observer_->DidForwardDataContent(kernel_ctx, kernel); } } RegstIndex regst_desc_id_index_; StateContainer index2state_; IndexType total_reading_cnt_; IndexType ready_consumed_; IndexType max_total_reading_cnt_; IndexType max_ready_consumed_; IndexType remaining_eord_cnt_; IndexType inplace_produced_index_[inplace]; IndexType inplace_consumed_index_[inplace]; std::function return_inplace_consumed_fn_[inplace]; Thread* thread_; std::unique_ptr kernel_info_[exec_kernel]; #ifdef WITH_CUDA_GRAPHS std::unique_ptr cuda_graph_exec_[exec_kernel]; #endif ActorContext* actor_ctx_; StreamContext* stream_ctx_; std::vector sync_post_act_msgs_; std::vector async_post_act_msgs_; KernelObserver* stream_kernel_observer_; // for debug std::unique_ptr debug_info_[debug]; }; template ActorBase* DispatchNewLightActorDebug(ActorContext* actor_ctx) { const bool debug = EnableActorDebugLog(); if (debug) { return new LightActor(actor_ctx); } else { return new LightActor(actor_ctx); } } template ActorBase* DispatchNewLightActorDynamicAlloc(ActorContext* actor_ctx) { const bool dynamic_allocation = ParseBooleanFromEnv("ONEFLOW_GRAPH_ENABLE_STREAM_ORDERED_MEMORY_ALLOCATION", false); if (dynamic_allocation) { return DispatchNewLightActorDebug(actor_ctx); } else { return DispatchNewLightActorDebug(actor_ctx); } } template ActorBase* DispatchNewLightActorMaxSize(ActorContext* actor_ctx) { const size_t regst_desc_count = GetRegstDescCount(actor_ctx->task_proto()); if (regst_desc_count <= 2) { return DispatchNewLightActorDynamicAlloc, ArrayBaseStateContainer>(actor_ctx); } else if (regst_desc_count <= 4) { return DispatchNewLightActorDynamicAlloc, ArrayBaseStateContainer>(actor_ctx); } else if (regst_desc_count <= 8) { return DispatchNewLightActorDynamicAlloc, ArrayBaseStateContainer>(actor_ctx); } else { return DispatchNewLightActorDynamicAlloc, VectorBaseStateContainer>(actor_ctx); } } template ActorBase* DispatchNewLightActorIndexType(ActorContext* actor_ctx) { size_t size = std::max(GetRegstDescCount(actor_ctx->task_proto()), GetConsumerCount(actor_ctx->task_proto())); if (size <= static_cast(std::numeric_limits::max())) { return DispatchNewLightActorMaxSize(actor_ctx); } else if (size <= static_cast(std::numeric_limits::max())) { return DispatchNewLightActorMaxSize(actor_ctx); } else { return nullptr; } } template ActorBase* DispatchNewLightActorInplace(ActorContext* actor_ctx) { const auto& produced_regst_desc = actor_ctx->task_proto().produced_regst_desc(); const size_t inplace_produced_regst_cnt = std::count_if(produced_regst_desc.cbegin(), produced_regst_desc.cend(), [](const PbMapPair& pair) { return pair.second.has_inplace_consumed_regst_desc_id(); }); if (inplace_produced_regst_cnt > 1) { return nullptr; } bool inplace = false; for (const auto& pair : produced_regst_desc) { const RegstDescProto& regst_desc = pair.second; if (IsInplaceRegstDesc(regst_desc)) { CHECK_EQ(inplace, false); inplace = true; } } if (inplace) { if (kernel_exec && NeedExecKernelWhenInplace(actor_ctx->task_proto())) { return DispatchNewLightActorIndexType<1, 1>(actor_ctx); } else { return DispatchNewLightActorIndexType<0, 1>(actor_ctx); } } else { return 
DispatchNewLightActorIndexType(actor_ctx); } } ActorBase* NewLightActorWithKernel(ActorContext* actor_ctx) { return DispatchNewLightActorInplace<1>(actor_ctx); } ActorBase* NewLightActorWithoutKernel(ActorContext* actor_ctx) { return DispatchNewLightActorInplace<0>(actor_ctx); } ActorBase* TryNewLightActorWithoutInit(ActorContext* actor_ctx) { const TaskProto& task_proto = actor_ctx->task_proto(); if (!task_proto.all_register_num_eq_one_hint()) { return nullptr; } if (task_proto.exec_sequence().exec_node_size() != 1) { return nullptr; } if (task_proto.task_type() == TaskType::kNormalForward) { const OperatorConf& op_conf = task_proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); if (op_conf.has_variable_conf()) { return NewLightActorWithoutKernel(actor_ctx); } else { return NewLightActorWithKernel(actor_ctx); } } else if (task_proto.task_type() == TaskType::kCopyHd) { return NewLightActorWithKernel(actor_ctx); } else if (task_proto.task_type() == TaskType::kTick) { return NewLightActorWithoutKernel(actor_ctx); } else if (task_proto.task_type() == TaskType::kCollectiveBoxingGeneric) { return NewLightActorWithKernel(actor_ctx); } else { return nullptr; } } } // namespace std::unique_ptr TryNewLightActor(ActorContext* actor_ctx) { ActorBase* actor = TryNewLightActorWithoutInit(actor_ctx); if (actor != nullptr) { const auto& job_descs = *Singleton::Get(); actor->Init(&job_descs.job_desc(actor_ctx->task_proto().job_id()), actor_ctx); } return std::unique_ptr(actor); } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/light_actor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_LIGHT_ACTOR_H_ #define ONEFLOW_CORE_LAZY_ACTOR_LIGHT_ACTOR_H_ #include "oneflow/core/lazy/actor/actor_base.h" namespace oneflow { std::unique_ptr TryNewLightActor(ActorContext* ctx); } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_LIGHT_ACTOR_H_ ================================================ FILE: oneflow/core/lazy/actor/naive_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/naive_actor.h" namespace oneflow { void NaiveActor::Act() { AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* { return nullptr; }); } void NaiveActor::VirtualActorInit(const TaskProto&) { OF_SET_MSG_HANDLER(&NaiveActor::HandlerNormal); } REGISTER_ACTOR(TaskType::kNormalForward, NaiveActor); REGISTER_ACTOR(TaskType::kDistributeConcat, NaiveActor); REGISTER_ACTOR(TaskType::kDistributeSplit, NaiveActor); REGISTER_ACTOR(TaskType::kSliceBoxing, NaiveActor); REGISTER_ACTOR(TaskType::kBoxingIdentity, NaiveActor); REGISTER_ACTOR(TaskType::kCollectiveBoxingPack, NaiveActor); REGISTER_ACTOR(TaskType::kCollectiveBoxingUnpack, NaiveActor); REGISTER_ACTOR(TaskType::kNcclSendRecvBoxing, NaiveActor); REGISTER_ACTOR(TaskType::kDecodeH2D, NaiveActor); REGISTER_ACTOR(TaskType::kCriticalSectionWaitTick, NaiveActor); REGISTER_ACTOR(TaskType::kCopyHd, NaiveActor); REGISTER_ACTOR(TaskType::kCollectiveBoxingGeneric, NaiveActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/naive_actor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_NAIVE_ACTOR_H_ #define ONEFLOW_CORE_LAZY_ACTOR_NAIVE_ACTOR_H_ #include "oneflow/core/lazy/actor/actor.h" namespace oneflow { class NaiveActor : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(NaiveActor); NaiveActor() = default; ~NaiveActor() override = default; void VirtualActorInit(const TaskProto&) override; protected: void Act() override; }; } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_NAIVE_ACTOR_H_ ================================================ FILE: oneflow/core/lazy/actor/pack_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/kernel/user_kernel.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { class PackActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(PackActor); PackActor() = default; ~PackActor() = default; private: void VirtualActorInit(const TaskProto& proto) override; void Act() override; void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void VirtualAsyncSendNaiveConsumedRegstMsgToProducer() override; size_t total_pack_num_; size_t act_num_cnt_; }; void PackActor::VirtualActorInit(const TaskProto& proto) { const Shape& in_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) .data_regst_time_shape(); total_pack_num_ = in_time_shape.At(in_time_shape.NumAxes() - 1); act_num_cnt_ = 0; OF_SET_MSG_HANDLER(&PackActor::HandlerNormal); } void PackActor::Act() { CHECK_GE(exec_kernel_vec().size(), 1); auto user_kernel = dynamic_cast(exec_kernel_vec().at(0).kernel.get()); CHECK_NOTNULL(user_kernel); auto state = dynamic_cast>*>( user_kernel->GetOpKernelState().get()); CHECK_NOTNULL(state); state->Mutable()->first = act_num_cnt_; state->Mutable()->second = total_pack_num_; AsyncLaunchKernel(); act_num_cnt_ += 1; } void PackActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { if (act_num_cnt_ == total_pack_num_) { HandleProducedNaiveDataRegstToConsumer(); } } void PackActor::VirtualAsyncSendNaiveConsumedRegstMsgToProducer() { HandleConsumedNaiveDataRegstToProducer(); if (act_num_cnt_ == total_pack_num_) { act_num_cnt_ = 0; } } REGISTER_ACTOR(TaskType::kPack, PackActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/reentrant_lock_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/kernel/reentrant_lock_kernel.h" namespace oneflow { class ReentrantLockActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(ReentrantLockActor); ReentrantLockActor() : reentrant_lock_status_(nullptr) {} ~ReentrantLockActor() override = default; protected: void VirtualActorInit(const TaskProto&) override; private: void Act() override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override; void ForEachCurCustomizedReadableRegst(std::function) const override; bool IsCustomizedReadReady() const override; void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} bool IsCustomizedReadAlwaysUnReadyFromNow() const override; void AsyncReturnAllCustomizedReadableRegst() override; std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; int64_t GetCurProcessedRegstDescId() const; const std::string& Ibn4RegstDescId(int64_t id) const; RegstSlot consumed_rs_; int64_t cur_processed_regst_desc_id_{}; HashMap regst_desc_id2ibn_; ReentrantLockStatus* reentrant_lock_status_; int64_t eord_regst_desc_id_{}; int64_t act_id_{}; }; void ReentrantLockActor::VirtualActorInit(const TaskProto& task_proto) { CHECK_EQ(1, exec_kernel_vec().size()); reentrant_lock_status_ = CHECK_NOTNULL( dynamic_cast(exec_kernel_vec().at(0).kernel_ctx->state().get())); act_id_ = 0; const auto& kernel_conf = task_proto.exec_sequence().exec_node().Get(0).kernel_conf(); const auto& ibns = kernel_conf.op_attribute().input_bns(); for (const auto& ibn : ibns) { int64_t regst_desc_id = exec_kernel_vec().at(0).bn_in_op2blob_info.at(ibn).regst_desc_id; if (ibn == "start") { eord_regst_desc_id_ = regst_desc_id; } CHECK(regst_desc_id2ibn_.emplace(regst_desc_id, ibn).second); } for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (const int64_t regst_desc_id : pair.second.regst_desc_id()) { consumed_rs_.InsertRegstDescId(regst_desc_id); } } consumed_rs_.InitedDone(); cur_processed_regst_desc_id_ = -1; reentrant_lock_status_->Init(kernel_conf); OF_SET_MSG_HANDLER(&ReentrantLockActor::HandlerNormal); } void ReentrantLockActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst())); } bool ReentrantLockActor::IsCustomizedReadReady() const { return reentrant_lock_status_->cur_unlocked_ids().size() > 0 || -1 != GetCurProcessedRegstDescId(); } void ReentrantLockActor::ForEachCurCustomizedReadableRegst( std::function handler) const { handler(consumed_rs_.Front(cur_processed_regst_desc_id_)); } const std::string& ReentrantLockActor::Ibn4RegstDescId(int64_t id) const { const auto& iter = regst_desc_id2ibn_.find(id); if (iter == regst_desc_id2ibn_.end()) { return ReentrantLockStatus::kEmptyIbn; } return regst_desc_id2ibn_.at(id); } void ReentrantLockActor::Act() { cur_processed_regst_desc_id_ = GetCurProcessedRegstDescId(); Regst* const cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_); reentrant_lock_status_->set_cur_ibn(Ibn4RegstDescId(cur_processed_regst_desc_id_)); reentrant_lock_status_->set_cur_act_id(act_id_); act_id_ += 1; AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* { if (cur_processed_regst_desc_id_ != regst_desc_id) { return nullptr; } return cur_regst; }); } bool ReentrantLockActor::IsCustomizedReadAlwaysUnReadyFromNow() const { return 
ReceiveEordMsg(eord_regst_desc_id_) && reentrant_lock_status_->total_queued_request_lock_num() == 0 && reentrant_lock_status_->total_acquired_lock_num() == 0; } void ReentrantLockActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { if (reentrant_lock_status_->acquired_lock_to_be_sent() == false) { return; } HandleProducedNaiveDataRegstToConsumer(); } void ReentrantLockActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { Regst* const cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_); if (cur_regst == nullptr) { return; } AsyncSendRegstMsgToProducer(cur_regst); CHECK_EQ(0, consumed_rs_.TryPopFrontRegst(cur_processed_regst_desc_id_)); cur_processed_regst_desc_id_ = -1; } void ReentrantLockActor::AsyncReturnAllCustomizedReadableRegst() { CHECK_EQ(-1, cur_processed_regst_desc_id_); CHECK_EQ(0, consumed_rs_.available_regst_desc_cnt()); } int64_t ReentrantLockActor::GetCurProcessedRegstDescId() const { int64_t cur_processed_regst_desc_id = -1; consumed_rs_.ForChosenRegstDeq( [&cur_processed_regst_desc_id](int64_t) { return cur_processed_regst_desc_id == -1; }, [&cur_processed_regst_desc_id](const std::deque& reg_deq) { if (reg_deq.empty()) { return; } cur_processed_regst_desc_id = reg_deq.front()->regst_desc_id(); }); return cur_processed_regst_desc_id; } REGISTER_ACTOR(kReentrantLock, ReentrantLockActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/register_slot.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/register_slot.h" namespace oneflow { int64_t RegstSlot::GetReadyRegstSize(int64_t regst_desc_id) const { CHECK(is_inited_); auto it = regst_desc_id2regsts_.find(regst_desc_id); if (it == regst_desc_id2regsts_.end()) { return -1; } return it->second.size(); } bool RegstSlot::HasRegstDescId(int64_t regst_desc_id) const { CHECK(is_inited_); return regst_desc_id2regsts_.find(regst_desc_id) != regst_desc_id2regsts_.end(); } const std::deque& RegstSlot::RegstDeq4RegstDescId(int64_t regst_desc_id) const { CHECK(is_inited_); return regst_desc_id2regsts_.at(regst_desc_id); } int RegstSlot::TryPushBackRegst(Regst* regst) { return TryPushBackRegst(regst, regst->regst_desc_id()); } int RegstSlot::TryPushBackRegst(Regst* regst, int64_t regst_desc_id) { CHECK(is_inited_); auto it = regst_desc_id2regsts_.find(regst_desc_id); if (it == regst_desc_id2regsts_.end()) { return -1; } if (it->second.empty()) { available_regst_desc_cnt_ += 1; } it->second.emplace_back(regst); return 0; } int RegstSlot::TryPopFrontRegst(int64_t regst_desc_id) { CHECK(is_inited_); auto it = regst_desc_id2regsts_.find(regst_desc_id); if (it == regst_desc_id2regsts_.end()) { return -1; } CHECK(it->second.empty() == false); it->second.pop_front(); if (it->second.empty()) { available_regst_desc_cnt_ -= 1; } return 0; } void RegstSlot::PopFrontRegsts(const std::vector& regst_desc_ids) { CHECK(is_inited_); for (int64_t regst_desc_id : regst_desc_ids) { CHECK_EQ(0, TryPopFrontRegst(regst_desc_id)); } } void RegstSlot::InsertRegstDescId(int64_t regst_desc_id) { CHECK(is_inited_ == false); CHECK(regst_desc_id2regsts_.emplace(regst_desc_id, std::deque()).second); } Regst* RegstSlot::Front(int64_t regst_desc_id) const { CHECK(is_inited_); auto it = regst_desc_id2regsts_.find(regst_desc_id); if (it == regst_desc_id2regsts_.end()) { return nullptr; } if (it->second.empty()) { return nullptr; } return it->second.front(); } Regst* RegstSlot::SoleFront() const { CHECK(is_inited_); CHECK_EQ(1, total_regst_desc_cnt()); auto it = regst_desc_id2regsts_.begin(); if (it->second.empty()) { return nullptr; } return it->second.front(); } Regst* RegstSlot::FirstFront() const { CHECK(is_inited_); CHECK_GE(total_regst_desc_cnt(), 1); auto it = regst_desc_id2regsts_.begin(); if (it->second.empty()) { return nullptr; } return it->second.front(); } void RegstSlot::InitedDone() { CHECK(is_inited_ == false); is_inited_ = true; } void RegstSlot::ForChosenFrontRegst(const std::function& IsChosenRegstDescId, const std::function& Handler) const { for (const auto& kv : regst_desc_id2regsts_) { if (IsChosenRegstDescId(kv.first)) { CHECK(kv.second.empty() == false); Handler(kv.second.front()); } } } void RegstSlot::ForChosenFrontRegst( const std::function& IsChosenRegstDescId, const std::function& Handler) const { for (const auto& kv : regst_desc_id2regsts_) { if (IsChosenRegstDescId(kv.first)) { CHECK(kv.second.empty() == false); Handler(kv.first, kv.second.front()); } } } void RegstSlot::ForChosenRegstDeq( const std::function& IsChosenRegstDescId, const std::function&)>& Handler) const { for (const auto& kv : regst_desc_id2regsts_) { if (IsChosenRegstDescId(kv.first)) { Handler(kv.second); } } } void RegstSlot::ForChosenRegstDeq( const std::function& IsChosenRegstDescId, const std::function&)>& Handler) const { for (const auto& kv : regst_desc_id2regsts_) { if (IsChosenRegstDescId(kv.first)) { Handler(kv.first, kv.second); } } } void RegstSlot::ForEachFrontRegst(const std::function& Handler) const { ForChosenFrontRegst([](int64_t) 
{ return true; }, Handler); } void RegstSlot::ForEachFrontRegst( const std::function& Handler) const { ForChosenFrontRegst([](int64_t) { return true; }, Handler); } void RegstSlot::ForEachRegstDeq( const std::function&)>& Handler) const { ForChosenRegstDeq([](int64_t) { return true; }, Handler); } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/register_slot.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_REGISTER_SLOT_H_ #define ONEFLOW_CORE_LAZY_ACTOR_REGISTER_SLOT_H_ #include "oneflow/core/register/register_manager.h" namespace oneflow { class RegstSlot final { public: OF_DISALLOW_COPY_AND_MOVE(RegstSlot); RegstSlot() : regst_desc_id2regsts_(), available_regst_desc_cnt_(0), is_inited_(false) {} ~RegstSlot() = default; bool is_inited() const { return is_inited_; } size_t total_regst_desc_cnt() const { return regst_desc_id2regsts_.size(); } size_t available_regst_desc_cnt() const { return available_regst_desc_cnt_; } int64_t GetReadyRegstSize(int64_t regst_desc_id) const; bool IsCurSlotReady() const { return available_regst_desc_cnt() == total_regst_desc_cnt(); } bool HasRegstDescId(int64_t regst_desc_id) const; const std::deque& RegstDeq4RegstDescId(int64_t regst_desc_id) const; void ForEachFrontRegst(const std::function&) const; void ForEachFrontRegst(const std::function&) const; void ForEachRegstDeq(const std::function&)>&) const; void ForChosenFrontRegst(const std::function&, const std::function&) const; void ForChosenFrontRegst(const std::function&, const std::function&) const; void ForChosenRegstDeq(const std::function&, const std::function&)>&) const; void ForChosenRegstDeq( const std::function&, const std::function&)>&) const; Regst* Front(int64_t regst_desc_id) const; Regst* SoleFront() const; Regst* FirstFront() const; // 0: success, -1: cannot find regst_desc_id int TryPushBackRegst(Regst* regst); int TryPushBackRegst(Regst* regst, int64_t regst_desc_id); int TryPopFrontRegst(int64_t regst_desc_id); void PopFrontRegsts(const std::vector& regst_desc_ids); void InitedDone(); void InsertRegstDescId(int64_t regst_desc_id); private: HashMap> regst_desc_id2regsts_; size_t available_regst_desc_cnt_; bool is_inited_; }; } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_REGISTER_SLOT_H_ ================================================ FILE: oneflow/core/lazy/actor/repeat_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/framework/framework.h" namespace oneflow { class RepeatActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(RepeatActor); RepeatActor() : repeat_count_(0), repeat_num_(0), wait_all_regst_return_(false), consumed_var_regst_desc_id_(-1), produced_repeat_var_regst_desc_id_(-1){}; ~RepeatActor() override = default; private: // NOTE(chengcheng): Empty rs for naive and inplace regst, all regst is customized. std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } std::pair> GetNaiveOrCustomizedProducedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } void TakeOverInplaceConsumedAndProduced( const PbMap& produced_ids) override { // NOTE(chengcheng): all regst is customized. inplace_consumed_rs_.InitedDone(); inplace_produced_rs_.InitedDone(); } bool IsCustomizedReadReady() const override { return (!wait_all_regst_return_) && consumed_var_rs_.IsCurSlotReady(); } bool IsCustomizedWriteReady() const override { return (!wait_all_regst_return_) && produced_repeat_var_rs_.IsCurSlotReady(); } void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} bool IsCustomizedReadAlwaysUnReadyFromNow() const override { // all Messages are flushed return ReceiveEordMsg(consumed_var_regst_desc_id_); } void VirtualActorInit(const TaskProto& proto) override; void Act() override; void AsyncSendCustomizedProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; void UpdtStateAsCustomizedProducedRegst(Regst* regst) override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) override; int32_t repeat_count_; int32_t repeat_num_; bool wait_all_regst_return_; int64_t consumed_var_regst_desc_id_; int64_t produced_repeat_var_regst_desc_id_; RegstSlot consumed_var_rs_; RegstSlot produced_repeat_var_rs_; }; void RepeatActor::VirtualActorInit(const TaskProto& proto) { repeat_count_ = 0; const OperatorConf& op_conf = proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); repeat_num_ = user_op::UserOpConfWrapper(op_conf).attr("repeat_num"); const Shape& in_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) .data_regst_time_shape(); const Shape& out_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) .data_regst_time_shape(); CHECK_GE(out_time_shape.NumAxes(), 1); CHECK_EQ(in_time_shape.NumAxes() + 1, out_time_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, in_time_shape.NumAxes()) { CHECK_EQ(in_time_shape.At(i), out_time_shape.At(i)); } CHECK_EQ(repeat_num_, out_time_shape.At(out_time_shape.NumAxes() - 1)); // input const auto& consumed_ids = proto.consumed_regst_desc_id(); auto in_it = consumed_ids.find("in"); CHECK(in_it != consumed_ids.end()); CHECK_EQ(in_it->second.regst_desc_id_size(), 1); consumed_var_regst_desc_id_ = in_it->second.regst_desc_id(0); consumed_var_rs_.InsertRegstDescId(consumed_var_regst_desc_id_); consumed_var_rs_.InitedDone(); // output const auto& produced_ids = proto.produced_regst_desc(); auto out_it = produced_ids.find("out"); CHECK(out_it != produced_ids.end()); const RegstDescProto& out_regst_desc = out_it->second; CHECK(!out_regst_desc.enable_reuse_mem()); CHECK_EQ(out_regst_desc.register_num(), 1); // check inplace CHECK_EQ(out_regst_desc.inplace_consumed_regst_desc_id(), 
consumed_var_regst_desc_id_); produced_repeat_var_regst_desc_id_ = out_regst_desc.regst_desc_id(); produced_repeat_var_rs_.InsertRegstDescId(produced_repeat_var_regst_desc_id_); produced_repeat_var_rs_.InitedDone(); // NOTE(chengcheng): repeat actor may has output ctrl regst. ctrl regst also need hack regst num. for (const auto& pair : proto.produced_regst_desc()) { const RegstDescProto& regst_desc = pair.second; int64_t regst_desc_id = regst_desc.regst_desc_id(); // This iter begins from 1 because first ctrl regst was already inserted in // TakeOverNaiveProduced for (int64_t i = 1; i < repeat_num_; ++i) { Singleton::Get()->NewRegsts(regst_desc, [this, regst_desc_id](Regst* regst) { produced_regsts_[regst_desc_id].emplace_back(regst); produced_regst2reading_cnt_[regst] = 0; if (regst_desc_id != produced_repeat_var_regst_desc_id_) { CHECK_EQ(0, naive_produced_rs_.TryPushBackRegst(regst)); } }); } } ForEachProducedRegst([&](Regst* regst) { if (regst->regst_desc_id() == produced_repeat_var_regst_desc_id_) { CHECK_EQ(0, produced_repeat_var_rs_.TryPushBackRegst(regst)); } }); for (const auto& pair : proto.produced_regst_desc()) { const RegstDescProto& regst_desc = pair.second; int64_t regst_desc_id = regst_desc.regst_desc_id(); if (regst_desc_id == produced_repeat_var_regst_desc_id_) { CHECK_EQ(produced_repeat_var_rs_.GetReadyRegstSize(regst_desc_id), repeat_num_); } else { CHECK_EQ(naive_produced_rs_.GetReadyRegstSize(regst_desc_id), repeat_num_); } } OF_SET_MSG_HANDLER(&RepeatActor::HandlerNormal); } void RepeatActor::Act() { repeat_count_ += 1; if (repeat_count_ == repeat_num_) { wait_all_regst_return_ = true; repeat_count_ = 0; } Regst* out_regst = produced_repeat_var_rs_.Front(produced_repeat_var_regst_desc_id_); Regst* in_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_); CHECK(out_regst && in_regst); CHECK(out_regst->body_mem_ptr() == in_regst->body_mem_ptr()); CHECK(out_regst->header_mem_ptr() == in_regst->header_mem_ptr()); CHECK_EQ(out_regst->regst_desc()->MainByteSize4OneRegst(), in_regst->regst_desc()->MainByteSize4OneRegst()); CHECK_EQ(out_regst->regst_desc()->SeparatedHeaderByteSize4OneRegst(), in_regst->regst_desc()->SeparatedHeaderByteSize4OneRegst()); } void RepeatActor::AsyncSendCustomizedProducedRegstMsgToConsumer() { CHECK(produced_repeat_var_rs_.IsCurSlotReady()); Regst* const repeat_var_regst = produced_repeat_var_rs_.Front(produced_repeat_var_regst_desc_id_); CHECK_GT(HandleRegstToConsumer(repeat_var_regst), 0); produced_repeat_var_rs_.PopFrontRegsts({produced_repeat_var_regst_desc_id_}); } void RepeatActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { // NOTE(chengcheng): do nothing. consumed var regst will return in inplace done. 
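// The consumed "in" regst is inplace-shared with the produced "out" regst (checked in
// VirtualActorInit and Act), so it is only returned to its producer in
// UpdtStateAsCustomizedProducedRegst, once all repeat_num_ produced regsts have come back.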
} void RepeatActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) { CHECK_EQ(regst->regst_desc_id(), produced_repeat_var_regst_desc_id_); CHECK_EQ(produced_repeat_var_rs_.TryPushBackRegst(regst), 0); if (wait_all_regst_return_ && produced_repeat_var_rs_.GetReadyRegstSize(produced_repeat_var_regst_desc_id_) == repeat_num_) { Regst* in_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_); CHECK(in_regst); AsyncSendRegstMsgToProducer(in_regst); CHECK_EQ(0, consumed_var_rs_.TryPopFrontRegst(consumed_var_regst_desc_id_)); wait_all_regst_return_ = false; } } void RepeatActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { CHECK_EQ(0, consumed_var_rs_.TryPushBackRegst(msg.regst())); } REGISTER_ACTOR(TaskType::kRepeat, RepeatActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/sink_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/sink_actor.h" namespace oneflow { void SinkActor::VirtualActorInit(const TaskProto& proto) { OF_SET_MSG_HANDLER(&SinkActor::HandlerNormal); VirtualSinkActorInit(proto); } void SinkActor::Act() { AsyncLaunchKernel(); } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/sink_actor.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_ACTOR_SINK_ACTOR_H_ #define ONEFLOW_CORE_LAZY_ACTOR_SINK_ACTOR_H_ #include "oneflow/core/lazy/actor/actor.h" namespace oneflow { class SinkActor : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(SinkActor); SinkActor() = default; virtual ~SinkActor() = default; protected: virtual void VirtualSinkActorInit(const TaskProto&) {} private: void VirtualActorInit(const TaskProto&) override; void Act() override; }; } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_ACTOR_SINK_ACTOR_H_ ================================================ FILE: oneflow/core/lazy/actor/source_tick_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/job/runtime_context.h" #include "oneflow/core/record/record.pb.h" namespace oneflow { class SourceTickActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(SourceTickActor); SourceTickActor() = default; ~SourceTickActor() = default; private: void VirtualActorInit(const TaskProto&) override; void Act() override; std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } bool IsCustomizedReadReady() const override; bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return !IsCustomizedReadReady(); } int HandlerWaitToStart(const ActorMsg&); }; void SourceTickActor::VirtualActorInit(const TaskProto& task_proto) { OF_SET_MSG_HANDLER(&SourceTickActor::HandlerWaitToStart); } void SourceTickActor::Act() {} bool SourceTickActor::IsCustomizedReadReady() const { // NOTE(chengcheng): SourceTickActor CANNOT be used and need delete in the future return true; } int SourceTickActor::HandlerWaitToStart(const ActorMsg& msg) { CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart); OF_SET_MSG_HANDLER(&SourceTickActor::HandlerNormal); return ProcessMsg(msg); } REGISTER_ACTOR(kSourceTick, SourceTickActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/ssp_variable_proxy_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/framework/user_op_conf.h" namespace oneflow { class SspVariableProxyActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(SspVariableProxyActor); SspVariableProxyActor() = default; ~SspVariableProxyActor() override = default; protected: std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } std::pair> GetNaiveOrCustomizedProducedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } bool IsCustomizedReadReady() const override { return consumed_var_rs_.IsCurSlotReady(); } bool IsCustomizedWriteReady() const override { int64_t cur_staleness = (received_var_piece_id_ - ack_msg_returned_ref_piece_id_); return ((cur_staleness <= staleness() /* bounded staleness */) && (produced_value_rs_.IsCurSlotReady() /* able to send messages to consumers of output `value` */)) || (produced_ref_rs_.IsCurSlotReady() /* able to send or to flush messages to consumers of output `ref` */); } void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} bool IsCustomizedReadAlwaysUnReadyFromNow() const override { // all Messages are flushed return ReceiveEordMsg(consumed_var_regst_desc_id_) && (received_var_piece_id_ <= ack_msg_returned_value_piece_id_ + 1 /* there is no need to wait the last piece */) && (received_var_piece_id_ == ack_msg_returned_ref_piece_id_); } void UpdtStateAsCustomizedProducedRegst(Regst* regst) override { if (regst->regst_desc_id() == produced_value_regst_desc_id_) { ++ack_msg_returned_value_piece_id_; CHECK_EQ(regst, GetRingBufferValueRegst(ack_msg_returned_value_piece_id_)); CHECK_EQ(0, produced_value_rs_.TryPushBackRegst(regst)); if (ack_msg_returned_ref_piece_id_ == ack_msg_returned_value_piece_id_ /* All mutable consumers to ref regst has done their job */) { // The updated ref regst are not synced into value regst yet. SyncRefRegstIntoValueRegst(ack_msg_returned_value_piece_id_); } else if (ack_msg_returned_ref_piece_id_ > ack_msg_returned_value_piece_id_) { // The ACK of ref resgt can just be slightly earlier than the one of value regst. // `slightly` means `ack_msg_returned_ref_piece_id_ == ack_msg_returned_value_piece_id_` UNIMPLEMENTED(); } else { // Do nothing. The ref data is not updated yet. } } else if (regst->regst_desc_id() == produced_ref_regst_desc_id_) { ++ack_msg_returned_ref_piece_id_; CHECK_EQ(regst, ref_regst_); if (ack_msg_returned_value_piece_id_ >= ack_msg_returned_ref_piece_id_ /* All const consumers to value regst has done their job */) { SyncRefRegstIntoValueRegst(ack_msg_returned_ref_piece_id_); } else { // Do nothing. 
The ACK of value regst will do the sync work } } else { UNIMPLEMENTED(); } } void AsyncSendCustomizedProducedRegstMsgToConsumer() override { if (consumed_var_rs_.IsCurSlotReady() && produced_value_rs_.IsCurSlotReady()) { Regst* const value_regst = produced_value_rs_.Front(produced_value_regst_desc_id_); if (value_regst->consumers_actor_id().empty()) { ++ack_msg_returned_value_piece_id_; } else { CHECK_EQ(value_regst, GetRingBufferValueRegst(received_var_piece_id_)); CHECK_GT(HandleRegstToConsumer(value_regst), 0); produced_value_rs_.PopFrontRegsts({produced_value_regst_desc_id_}); } } if ((ack_msg_returned_ref_piece_id_ < received_var_piece_id_) && produced_ref_rs_.IsCurSlotReady()) { Regst* const ref_regst = produced_ref_rs_.Front(produced_ref_regst_desc_id_); if (ref_regst->consumers_actor_id().empty()) { ++ack_msg_returned_ref_piece_id_; } else { CHECK_GT(HandleRegstToConsumer(ref_regst), 0); produced_ref_rs_.PopFrontRegsts({produced_ref_regst_desc_id_}); } } } void AsyncSendCustomizedConsumedRegstMsgToProducer() override { Regst* const var_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_); CHECK_NOTNULL(var_regst); AsyncSendRegstMsgToProducer(var_regst); CHECK_EQ(0, consumed_var_rs_.TryPopFrontRegst(consumed_var_regst_desc_id_)); } void ForEachCurCustomizedReadableRegst(std::function Handler) const override { Handler(consumed_var_rs_.Front(consumed_var_regst_desc_id_)); } void TakeOverInplaceConsumedAndProduced( const PbMap& produced_ids) override { inplace_consumed_rs_.InitedDone(); inplace_produced_rs_.InitedDone(); } void VirtualActorInit(const TaskProto& task_proto) override { CheckInplaceBetweenVarAndRef(task_proto); TakeOverVarRegst(task_proto.consumed_regst_desc_id()); TakeOverRefRegst(task_proto.produced_regst_desc()); TakeOverValueRegst(task_proto.produced_regst_desc()); OF_SET_MSG_HANDLER(&SspVariableProxyActor::HandlerNormal); } bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override { return true; } void NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) override { if (var_regst_ == nullptr) { var_regst_ = msg.regst(); } else { CHECK_EQ(var_regst_, msg.regst()); } CHECK_EQ(0, consumed_var_rs_.TryPushBackRegst(var_regst_)); ++received_var_piece_id_; } private: void Act() override { if (received_var_piece_id_ == 0) { // Initialize all value regsts for (int64_t piece_id = 0; piece_id < staleness(); ++piece_id) { CopyRefToValue(GetRingBufferValueRegst(piece_id)); } } else { // Do nothing, value regsts are updated in UpdtStateAsCustomizedProducedRegst } } void CheckInplaceBetweenVarAndRef(const TaskProto& task_proto) const { int64_t var_id = task_proto.consumed_regst_desc_id().at("var").regst_desc_id(0); const auto& ref_regst_desc_proto = task_proto.produced_regst_desc().at("ref"); CHECK_EQ(ref_regst_desc_proto.inplace_consumed_regst_desc_id(), var_id); } void TakeOverVarRegst(const PbMap& consumed_ids) { received_var_piece_id_ = -1; consumed_var_regst_desc_id_ = consumed_ids.at("var").regst_desc_id(0); consumed_var_rs_.InsertRegstDescId(consumed_var_regst_desc_id_); consumed_var_rs_.InitedDone(); var_regst_ = nullptr; } void TakeOverRefRegst(const PbMap& produced_ids) { ack_msg_returned_ref_piece_id_ = -1; produced_ref_regst_desc_id_ = produced_ids.at("ref").regst_desc_id(); produced_ref_rs_.InsertRegstDescId(produced_ref_regst_desc_id_); produced_ref_rs_.InitedDone(); ref_regst_ = nullptr; ForEachProducedRegst([&](Regst* regst) { if (regst->regst_desc_id() != produced_ref_regst_desc_id_) { return; } CHECK(ref_regst_ == nullptr) << "regst_num of 
ref_regst must equal 1"; CHECK_EQ(0, produced_ref_rs_.TryPushBackRegst(regst)); ref_regst_ = regst; }); } void TakeOverValueRegst(const PbMap& produced_ids) { ack_msg_returned_value_piece_id_ = -1; produced_value_regst_desc_id_ = produced_ids.at("value").regst_desc_id(); produced_value_rs_.InsertRegstDescId(produced_value_regst_desc_id_); produced_value_rs_.InitedDone(); ForEachProducedRegst([&](Regst* regst) { if (regst->regst_desc_id() != produced_value_regst_desc_id_) { return; } CHECK_EQ(0, produced_value_rs_.TryPushBackRegst(regst)); value_regst_ring_buffer_.push_back(regst); }); } void SyncRefRegstIntoValueRegst(int64_t released_piece_id) { CopyRefToValue(GetRingBufferValueRegst(released_piece_id)); CHECK_EQ(0, produced_ref_rs_.TryPushBackRegst(ref_regst_)); } void CopyRefToValue(Regst* value_regst) { AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* { if (regst_desc_id == consumed_var_regst_desc_id_) { return var_regst_; } else if (regst_desc_id == produced_ref_regst_desc_id_) { return ref_regst_; } else if (regst_desc_id == produced_value_regst_desc_id_) { return value_regst; } else { UNIMPLEMENTED(); } }); } Regst* GetRingBufferValueRegst(int64_t value_piece_id) const { return value_regst_ring_buffer_.at(value_piece_id % staleness()); } size_t staleness() const { return value_regst_ring_buffer_.size(); } // input var int64_t received_var_piece_id_; int64_t consumed_var_regst_desc_id_; RegstSlot consumed_var_rs_; Regst* var_regst_; // output ref // consumers has used the ref regst int64_t ack_msg_returned_ref_piece_id_; int64_t produced_ref_regst_desc_id_; RegstSlot produced_ref_rs_; Regst* ref_regst_; // output value // consumers has used the value regst int64_t ack_msg_returned_value_piece_id_; int64_t produced_value_regst_desc_id_; RegstSlot produced_value_rs_; std::vector value_regst_ring_buffer_; }; REGISTER_ACTOR(TaskType::kSspVariableProxy, SspVariableProxyActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/tick_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/naive_actor.h" namespace oneflow { class TickActor final : public NaiveActor { public: OF_DISALLOW_COPY_AND_MOVE(TickActor); TickActor() = default; ~TickActor() = default; private: void Act() override {} }; REGISTER_ACTOR(kTick, TickActor); REGISTER_ACTOR(kDeviceTick, TickActor); REGISTER_ACTOR(kSrcSubsetTick, TickActor); REGISTER_ACTOR(kDstSubsetTick, TickActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/unpack_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/kernel/user_kernel.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { class UnpackActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(UnpackActor); UnpackActor() = default; ~UnpackActor() override = default; private: void VirtualActorInit(const TaskProto& proto) override; void Act() override; void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void VirtualAsyncSendNaiveConsumedRegstMsgToProducer() override; bool ConsumedCtrlRegstValid(int64_t regst_desc_id) const override; size_t total_unpack_num_; size_t act_num_cnt_; }; void UnpackActor::VirtualActorInit(const TaskProto& proto) { const Shape& out_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) .data_regst_time_shape(); total_unpack_num_ = out_time_shape.At(out_time_shape.NumAxes() - 1); act_num_cnt_ = 0; OF_SET_MSG_HANDLER(&UnpackActor::HandlerNormal); } void UnpackActor::Act() { CHECK_GE(exec_kernel_vec().size(), 1); auto user_kernel = dynamic_cast(exec_kernel_vec().at(0).kernel.get()); CHECK_NOTNULL(user_kernel); auto state = dynamic_cast>*>( user_kernel->GetOpKernelState().get()); CHECK_NOTNULL(state); state->Mutable()->first = act_num_cnt_; state->Mutable()->second = total_unpack_num_; AsyncLaunchKernel(); act_num_cnt_ += 1; } void UnpackActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { HandleProducedNaiveDataRegstToConsumer(); } void UnpackActor::VirtualAsyncSendNaiveConsumedRegstMsgToProducer() { if (act_num_cnt_ == total_unpack_num_) { HandleConsumedNaiveDataRegstToProducer(); act_num_cnt_ = 0; } } bool UnpackActor::ConsumedCtrlRegstValid(int64_t regst_desc_id) const { return act_num_cnt_ == 0; } REGISTER_ACTOR(TaskType::kUnpack, UnpackActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/actor/wait_and_send_ids_actor.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
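Illustrative sketch (not part of the repository): the WaitAndSendIdsActor defined below boots in a dedicated HandlerWaitToStart state; on the first kStart command it swaps its message handler to HandlerNormal and re-dispatches the same message. The handler-swapping idea, reduced to plain standard C++ with hypothetical names, looks like this:

#include <cassert>
#include <iostream>

struct Msg { bool is_start; };  // hypothetical stand-in for ActorMsg

class MiniActor {
 public:
  MiniActor() : handler_(&MiniActor::HandlerWaitToStart) {}
  int ProcessMsg(const Msg& msg) { return (this->*handler_)(msg); }

 private:
  int HandlerWaitToStart(const Msg& msg) {
    assert(msg.is_start);                  // analogous to CHECK_EQ(msg.actor_cmd(), kStart)
    handler_ = &MiniActor::HandlerNormal;  // analogous to OF_SET_MSG_HANDLER(...HandlerNormal)
    return ProcessMsg(msg);                // hand the very same message to the new handler
  }
  int HandlerNormal(const Msg&) {
    std::cout << "normal processing\n";
    return 0;
  }
  int (MiniActor::*handler_)(const Msg&);
};

int main() {
  MiniActor a;
  Msg start{};
  start.is_start = true;
  a.ProcessMsg(start);  // switches state, then processes the message exactly once
  return 0;
}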
*/ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/kernel/wait_and_send_ids_kernel.h" #include "oneflow/core/job/runtime_context.h" #include "oneflow/core/record/record.pb.h" namespace oneflow { class WaitAndSendIdsActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsActor); WaitAndSendIdsActor() : wait_and_send_ids_status_(nullptr) {} ~WaitAndSendIdsActor() = default; private: void VirtualActorInit(const TaskProto&) override; void Act() override; std::pair> GetNaiveOrCustomizedConsumedRegstDescName() override { return std::make_pair(RegstNameType::kNaive, HashSet{}); } void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; bool IsCustomizedReadReady() const override; bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return !IsCustomizedReadReady(); } int HandlerWaitToStart(const ActorMsg&); WaitAndSendIdsStatus* wait_and_send_ids_status_; }; void WaitAndSendIdsActor::VirtualActorInit(const TaskProto& task_proto) { CHECK_EQ(exec_kernel_vec().size(), 1); wait_and_send_ids_status_ = CHECK_NOTNULL( dynamic_cast(exec_kernel_vec().at(0).kernel_ctx->state().get())); wait_and_send_ids_status_->buffer_status_ = kBufferStatusSuccess; wait_and_send_ids_status_->in_id_ = 0; wait_and_send_ids_status_->out_idx_ = 0; wait_and_send_ids_status_->out_num_ = 0; OF_SET_MSG_HANDLER(&WaitAndSendIdsActor::HandlerWaitToStart); } void WaitAndSendIdsActor::Act() { CHECK_LE(wait_and_send_ids_status_->out_idx_, wait_and_send_ids_status_->out_num_); AsyncLaunchKernel(); } void WaitAndSendIdsActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { if (wait_and_send_ids_status_->buffer_status_ == kBufferStatusSuccess) { HandleProducedNaiveDataRegstToConsumer(); } } bool WaitAndSendIdsActor::IsCustomizedReadReady() const { return wait_and_send_ids_status_->buffer_status_ == kBufferStatusSuccess; } int WaitAndSendIdsActor::HandlerWaitToStart(const ActorMsg& msg) { CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart); OF_SET_MSG_HANDLER(&WaitAndSendIdsActor::HandlerNormal); return ProcessMsg(msg); } REGISTER_ACTOR(kWaitAndSendIds, WaitAndSendIdsActor); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/stream_context/common/generic_stream_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/stream_context/include/generic_stream_context.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/ep/include/active_device_guard.h" namespace oneflow { GenericStreamContext::GenericStreamContext(const StreamId& stream_id) : stream_(nullptr) { device_ = std::dynamic_pointer_cast(Singleton::Get()->GetDevice( stream_id.device_type(), stream_id.device_index())); CHECK(device_); ep::ActiveDeviceGuard guard(device_.get()); stream_ = dynamic_cast(device_->CreateStream()); CHECK(stream_ != nullptr); poller_thread_ = std::thread([this]() { CHECK_JUST(stream_->OnExecutionContextSetup()); std::pair> cb_event; while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) { CHECK_JUST(cb_event.first->Sync()); cb_event.second(); device_->DestroyEvent(cb_event.first); } CHECK_JUST(stream_->OnExecutionContextTeardown()); }); } GenericStreamContext::~GenericStreamContext() { ep::ActiveDeviceGuard guard(device_.get()); cb_event_chan_.Close(); poller_thread_.join(); device_->DestroyStream(stream_); } Maybe GenericStreamContext::AddCallback(std::function callback) { ep::Event* event = device_->CreateEvent(); stream_->RecordEvent(event); cb_event_chan_.Send(std::make_pair(event, std::move(callback))); return Maybe::Ok(); } ep::Stream* GenericStreamContext::stream() { return stream_; } } // namespace oneflow ================================================ FILE: oneflow/core/lazy/stream_context/cpu/cpu_stream_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/stream_context/include/stream_context.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/device/event_record.h" #include "oneflow/core/kernel/chain_kernel_observer.h" #include "oneflow/core/kernel/cpu_check_numerics_kernel_observer.h" #include "oneflow/core/graph/stream_id.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { class CpuStreamContext : public StreamContext, public KernelObserverProvider { public: OF_DISALLOW_COPY_AND_MOVE(CpuStreamContext); CpuStreamContext(); ~CpuStreamContext() override; ep::Stream* stream() override; Maybe AddCallback(std::function callback) override; KernelObserver* GetKernelObserver() override; DeviceType device_type() const override { return DeviceType::kCPU; } private: std::shared_ptr device_; ep::Stream* stream_; std::unique_ptr kernel_observer_; }; CpuStreamContext::CpuStreamContext() : stream_(nullptr) { device_ = Singleton::Get()->GetDevice(DeviceType::kCPU, 0); stream_ = device_->CreateStream(); // NOLINT std::vector> kernel_observers; if (ParseBooleanFromEnv("ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS", false)) { kernel_observers.emplace_back(new CpuCheckNumericsKernelObserver()); } kernel_observer_.reset(new ChainKernelObserver(kernel_observers)); } CpuStreamContext::~CpuStreamContext() { device_->DestroyStream(stream_); } ep::Stream* CpuStreamContext::stream() { return stream_; } Maybe CpuStreamContext::AddCallback(std::function callback) { callback(); return Maybe::Ok(); } KernelObserver* CpuStreamContext::GetKernelObserver() { return kernel_observer_.get(); } REGISTER_STREAM_CONTEXT_CREATOR_WITH_STREAM_ID(DeviceType::kCPU, ([](const StreamId& stream_id) -> StreamContext* { return new CpuStreamContext(); })); } // namespace oneflow ================================================ FILE: oneflow/core/lazy/stream_context/cuda/cuda_stream_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/lazy/stream_context/include/stream_context.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/kernel/chain_kernel_observer.h" #include "oneflow/core/kernel/cuda_check_numerics_kernel_observer.h" #include "oneflow/core/graph/stream_id.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/common/channel.h" #ifdef WITH_CUDA #include namespace oneflow { namespace { class CudaStreamContext : public StreamContext, public KernelObserverProvider { public: OF_DISALLOW_COPY_AND_MOVE(CudaStreamContext); explicit CudaStreamContext(int device_index); ~CudaStreamContext() override; Maybe AddCallback(std::function callback) override; DeviceType device_type() const override { return DeviceType::kCUDA; } KernelObserver* GetKernelObserver() override; ep::Stream* stream() override; private: ep::CudaStream* stream_; Channel>> cb_event_chan_; std::thread poller_thread_; int device_index_; std::unique_ptr kernel_observer_; std::shared_ptr device_; }; CudaStreamContext::CudaStreamContext(int device_index) : stream_(nullptr), device_index_(device_index) { CudaCurrentDeviceGuard guard(device_index_); device_ = std::dynamic_pointer_cast( Singleton::Get()->GetDevice(DeviceType::kCUDA, device_index)); CHECK(device_); stream_ = dynamic_cast(device_->CreateStream()); CHECK(stream_ != nullptr); std::vector> kernel_observers; if (ParseBooleanFromEnv("ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS", false)) { LOG(WARNING) << "Environment variable ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS has been set " "to a truthy " "value, it will impact performance"; kernel_observers.emplace_back(new CudaCheckNumericsKernelObserver()); } kernel_observer_.reset(new ChainKernelObserver(kernel_observers)); poller_thread_ = std::thread([this]() { CHECK_JUST(stream_->OnExecutionContextSetup()); OF_PROFILER_NAME_THIS_HOST_THREAD("_cuda" + std::to_string(device_index_) + " Poller : (" + std::to_string(device_index_) + ")"); std::pair> cb_event; while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) { CHECK_JUST(cb_event.first->Sync()); cb_event.second(); device_->DestroyEvent(cb_event.first); } CHECK_JUST(stream_->OnExecutionContextTeardown()); }); } CudaStreamContext::~CudaStreamContext() { CudaCurrentDeviceGuard guard(device_index_); cb_event_chan_.Close(); poller_thread_.join(); device_->DestroyStream(stream_); } Maybe CudaStreamContext::AddCallback(std::function callback) { ep::Event* event = device_->CreateEvent(); stream_->RecordEvent(event); cb_event_chan_.Send(std::make_pair(event, std::move(callback))); return Maybe::Ok(); } KernelObserver* CudaStreamContext::GetKernelObserver() { return kernel_observer_.get(); } ep::Stream* CudaStreamContext::stream() { return stream_; } REGISTER_STREAM_CONTEXT_CREATOR_WITH_STREAM_ID( DeviceType::kCUDA, ([](const StreamId& stream_id) -> StreamContext* { CHECK_EQ(stream_id.device_type(), DeviceType::kCUDA); return new CudaStreamContext(stream_id.device_index()); })); } // namespace } // namespace oneflow #endif ================================================ FILE: oneflow/core/lazy/stream_context/include/generic_stream_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_STREAM_CONTEXT_GENERIC_STREAM_CONTEXT_H_ #define ONEFLOW_CORE_LAZY_STREAM_CONTEXT_GENERIC_STREAM_CONTEXT_H_ #include "oneflow/core/lazy/stream_context/include/stream_context.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/graph/stream_id.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ep/include/device.h" #include "oneflow/core/common/channel.h" namespace oneflow { class GenericStreamContext : public StreamContext { public: OF_DISALLOW_COPY_AND_MOVE(GenericStreamContext); explicit GenericStreamContext(const StreamId& stream_id); ~GenericStreamContext() override; Maybe AddCallback(std::function callback) override; DeviceType device_type() const override { return stream_->device_type(); } ep::Stream* stream() override; private: ep::Stream* stream_; Channel>> cb_event_chan_; std::thread poller_thread_; std::shared_ptr device_; }; } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_STREAM_CONTEXT_GENERIC_STREAM_CONTEXT_H_ ================================================ FILE: oneflow/core/lazy/stream_context/include/stream_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_LAZY_STREAM_CONTEXT_STREAM_CONTEXT_H_ #define ONEFLOW_CORE_LAZY_STREAM_CONTEXT_STREAM_CONTEXT_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { class StreamContext { public: OF_DISALLOW_COPY_AND_MOVE(StreamContext); StreamContext() = default; virtual ~StreamContext() = default; virtual ep::Stream* stream() = 0; virtual Maybe AddCallback(std::function callback) = 0; virtual DeviceType device_type() const = 0; }; #define REGISTER_STREAM_CONTEXT_CREATOR_WITH_STREAM_ID(device, creator) \ REGISTER_CLASS_CREATOR(int, device, StreamContext, creator, const StreamId&) } // namespace oneflow #endif // ONEFLOW_CORE_LAZY_STREAM_CONTEXT_STREAM_CONTEXT_H_ ================================================ FILE: oneflow/core/memory/chunk_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/memory/chunk_manager.h" #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/control/global_process_ctx.h" namespace oneflow { void ChunkMgr::GetChunkProtosByMemZoneUniqueId(int64_t mem_zone_uid, std::vector* chunks) const { std::unique_lock guard(mutex_); chunks->clear(); auto chunk_ids_it = mzuid2chunk_ids_.find(mem_zone_uid); if (chunk_ids_it != mzuid2chunk_ids_.end()) { const auto& chunk_ids = chunk_ids_it->second; chunks->reserve(chunk_ids.size()); for (int64_t chunk_id : chunk_ids) { auto chunk_it = chunk_id2chunk_proto_.find(chunk_id); CHECK(chunk_it != chunk_id2chunk_proto_.end()); chunks->emplace_back(chunk_it->second.get()); } } } void ChunkMgr::AddChunkProto(const ChunkProto& chunk) { std::unique_lock guard(mutex_); const int64_t mem_zone_uid = memory::GetUniqueMemCaseId(chunk.machine_id(), chunk.mem_case()); CHECK( chunk_id2chunk_proto_.emplace(chunk.chunk_id(), std::make_unique(chunk)).second); auto chunk_ids_it = mzuid2chunk_ids_.find(mem_zone_uid); if (chunk_ids_it == mzuid2chunk_ids_.end()) { chunk_ids_it = mzuid2chunk_ids_.emplace(mem_zone_uid, HashSet()).first; } CHECK(chunk_ids_it->second.insert(chunk.chunk_id()).second); } char* ChunkMgr::FindOrCreateChunk(const ChunkProto& chunk) { std::unique_lock guard(mutex_); CHECK_EQ(GlobalProcessCtx::Rank(), chunk.machine_id()); auto it = chunk_id2chunk_.find(chunk.chunk_id()); if (it == chunk_id2chunk_.end()) { char* chunk_ptr = Singleton::Get()->Allocate(chunk.mem_case(), chunk.mem_size()); it = chunk_id2chunk_.emplace(chunk.chunk_id(), ChunkWithPtr(chunk_ptr, chunk)).first; } else { const ChunkProto& store_proto = it->second.chunk_proto; CHECK_EQ(chunk.chunk_id(), store_proto.chunk_id()); CHECK_EQ(chunk.machine_id(), store_proto.machine_id()); CHECK(chunk.mem_case() == store_proto.mem_case()); CHECK_EQ(chunk.mem_size(), store_proto.mem_size()); } return it->second.ptr; } } // namespace oneflow ================================================ FILE: oneflow/core/memory/chunk_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_MEMORY_CHUNK_MANAGER_H_ #define ONEFLOW_CORE_MEMORY_CHUNK_MANAGER_H_ #include #include "oneflow/core/job/id_manager.h" #include "oneflow/core/memory/memory_block.pb.h" #include "oneflow/core/memory/memory_allocator.h" namespace oneflow { class ChunkMgr final { public: OF_DISALLOW_COPY_AND_MOVE(ChunkMgr); ChunkMgr() = default; ~ChunkMgr() = default; // Compiler void GetChunkProtosByMemZoneUniqueId(int64_t mem_zone_uid, std::vector* chunks) const; void AddChunkProto(const ChunkProto& chunk); // Runtime char* FindOrCreateChunk(const ChunkProto& chunk); private: // for master compiler in PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan HashMap> mzuid2chunk_ids_; HashMap> chunk_id2chunk_proto_; struct ChunkWithPtr { char* ptr; ChunkProto chunk_proto; ChunkWithPtr(char* p, const ChunkProto& c_p) : ptr(p), chunk_proto(c_p) {} }; // for runtime HashMap chunk_id2chunk_; mutable std::mutex mutex_; }; } // namespace oneflow #endif // ONEFLOW_CORE_MEMORY_CHUNK_MANAGER_H_ ================================================ FILE: oneflow/core/memory/memory_allocator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/record/record.pb.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace { std::shared_ptr GetAllocationDevice(const MemoryCase& mem_case) { auto device = Singleton::Get()->GetDevice(mem_case.device_type(), mem_case.device_id()); CHECK(device); return device; } ep::AllocationOptions GetAllocationOptions(const MemoryCase& mem_case) { ep::AllocationOptions options{}; if (mem_case.has_pinned_device_type() && mem_case.has_pinned_device_id()) { options.SetPinnedDevice(mem_case.pinned_device_type(), mem_case.pinned_device_id()); } return options; } } // namespace void* MemoryAllocatorImpl::Allocate(const MemoryCase& mem_case, size_t size) { void* ptr = nullptr; std::shared_ptr device = GetAllocationDevice(mem_case); ep::AllocationOptions options = GetAllocationOptions(mem_case); CHECK_JUST(device->Alloc(options, &ptr, size)); return ptr; } void MemoryAllocatorImpl::Deallocate(void* ptr, const MemoryCase& mem_case) { std::shared_ptr device = GetAllocationDevice(mem_case); ep::AllocationOptions options = GetAllocationOptions(mem_case); device->Free(options, ptr); } void* MemoryAllocatorImpl::AllocateUnPinnedHostMem(size_t size) { void* ptr = aligned_alloc(kHostAlignSize, size); CHECK_NOTNULL(ptr); return ptr; } void MemoryAllocatorImpl::DeallocateUnPinnedHostMem(void* ptr) { free(ptr); // NOLINT } MemoryAllocator::~MemoryAllocator() { for (const std::function& deleter : deleters_) { deleter(); } } char* MemoryAllocator::Allocate(const MemoryCase& mem_case, std::size_t size) { 
char* dptr = static_cast(MemoryAllocatorImpl::Allocate(mem_case, size)); deleters_.push_front(std::bind(&MemoryAllocator::Deallocate, this, dptr, mem_case)); return dptr; } void MemoryAllocator::Deallocate(char* dptr, const MemoryCase& mem_case) { MemoryAllocatorImpl::Deallocate(static_cast(dptr), mem_case); } void InitNonPODTypeBlobIfNeed(MemoryAllocator* allocator, Blob* blob_ptr) { const BlobDesc& blob_desc = blob_ptr->blob_desc(); if (blob_desc.data_type() == kOFRecord) { int64_t elem_cnt = blob_desc.shape().elem_cnt(); FOR_RANGE(int64_t, idx, 0, elem_cnt) { allocator->PlacementNew(&blob_ptr->mut_dptr()[idx]); } } if (blob_desc.data_type() == kTensorBuffer) { int64_t elem_cnt = blob_desc.shape().elem_cnt(); FOR_RANGE(int64_t, idx, 0, elem_cnt) { allocator->PlacementNew(&blob_ptr->mut_dptr()[idx]); } } } void InitNonPODTypeEagerBlobObjectIfNeed(MemoryAllocator* allocator, vm::EagerBlobObject* eager_blob_object_ptr) { if (eager_blob_object_ptr->data_type() == kOFRecord) { int64_t elem_cnt = eager_blob_object_ptr->shape().elem_cnt(); FOR_RANGE(int64_t, idx, 0, elem_cnt) { allocator->PlacementNew(&eager_blob_object_ptr->mut_dptr()[idx]); } } if (eager_blob_object_ptr->data_type() == kTensorBuffer) { int64_t elem_cnt = eager_blob_object_ptr->shape().elem_cnt(); FOR_RANGE(int64_t, idx, 0, elem_cnt) { allocator->PlacementNew(&eager_blob_object_ptr->mut_dptr()[idx]); } } } } // namespace oneflow ================================================ FILE: oneflow/core/memory/memory_allocator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
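Illustrative sketch (not part of the repository): the MemoryAllocator declared below registers a deleter for every allocation and for every PlacementNew'd object, pushing each onto the front of a list so that its destructor releases resources newest-first. A minimal standalone version of that idiom:

#include <functional>
#include <iostream>
#include <list>

// Reverse-order cleanup via a front-pushed deleter list.
class ScopedDeleters {
 public:
  ~ScopedDeleters() {
    for (const std::function<void()>& deleter : deleters_) { deleter(); }  // newest first
  }
  void Register(std::function<void()> deleter) { deleters_.push_front(std::move(deleter)); }

 private:
  std::list<std::function<void()>> deleters_;
};

int main() {
  ScopedDeleters cleanup;
  cleanup.Register([] { std::cout << "free A\n"; });
  cleanup.Register([] { std::cout << "free B\n"; });
  // Prints "free B" then "free A": the most recently acquired resource is released first.
  return 0;
}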
*/ #ifndef ONEFLOW_CORE_MEMORY_MEMORY_ALLOCATOR_H_ #define ONEFLOW_CORE_MEMORY_MEMORY_ALLOCATOR_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/memory/memory_case_util.h" namespace oneflow { namespace vm { class EagerBlobObject; } class MemoryAllocator final { public: OF_DISALLOW_COPY_AND_MOVE(MemoryAllocator); MemoryAllocator() = default; ~MemoryAllocator(); char* Allocate(const MemoryCase& mem_case, std::size_t size); template T* PlacementNew(T* mem_ptr); private: void Deallocate(char* dptr, const MemoryCase& mem_case); std::mutex deleters_mutex_; std::list> deleters_; }; class Blob; void InitNonPODTypeBlobIfNeed(MemoryAllocator* allocator, Blob* blob_ptr); void InitNonPODTypeEagerBlobObjectIfNeed(MemoryAllocator* allocator, vm::EagerBlobObject* eager_blob_object_ptr); template T* MemoryAllocator::PlacementNew(T* mem_ptr) { T* obj = new (mem_ptr) T(); { std::unique_lock lock(deleters_mutex_); deleters_.push_front([obj] { obj->~T(); }); } CHECK_EQ(mem_ptr, obj); return obj; } struct MemoryAllocatorImpl final { static void* Allocate(const MemoryCase& mem_case, size_t size); static void Deallocate(void* ptr, const MemoryCase& mem_case); static void* AllocateUnPinnedHostMem(size_t size); static void DeallocateUnPinnedHostMem(void* ptr); }; } // namespace oneflow #endif // ONEFLOW_CORE_MEMORY_MEMORY_ALLOCATOR_H_ ================================================ FILE: oneflow/core/memory/memory_block.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/memory/memory_case.proto"; message MemBlockProto { required int64 mem_block_id = 1; repeated int64 job_id = 2; required int64 machine_id = 3; required MemoryCase mem_case = 4; required bool enable_reuse_mem = 5; optional int64 chunk_id = 6 [default = -1]; optional int64 chunk_offset = 7 [default = -1]; required int64 mem_size = 8; // NOTE(chengcheng): thrd id hint is used by packed separated block group order. optional int64 thrd_id_hint = 9 [default = -1]; // NOTE(chengcheng): mark this block memory is shared with EagerParameter. optional string variable_op_name = 10 [default = ""]; optional bool is_separated_header = 11 [default = false]; } message ChunkProto { required int64 chunk_id = 1; repeated int64 job_id = 2; required int64 machine_id = 3; required MemoryCase mem_case = 4; required int64 mem_size = 5; } message MemBlockAndChunkList { repeated MemBlockProto mem_block = 1; repeated ChunkProto chunk = 2; } ================================================ FILE: oneflow/core/memory/memory_case.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/device_type.proto"; message MemoryCase { required DeviceType device_type = 1; required int64 device_id = 2; optional DeviceType pinned_device_type = 3; optional int64 pinned_device_id = 4; } ================================================ FILE: oneflow/core/memory/memory_case_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/memory/memory_case_util.h" #include namespace oneflow { namespace memory { bool EqualsIgnorePinnedDevice(const MemoryCase& a, const MemoryCase& b) { if (a.device_type() != b.device_type()) { return false; } if (a.device_id() != b.device_id()) { return false; } return true; } void GetPinnedHostMemoryCase(const MemoryCase& mem_case, MemoryCase* ret) { ret->set_device_type(DeviceType::kCPU); ret->set_device_id(0); if (!IsHostMem(mem_case)) { ret->set_pinned_device_type(mem_case.device_type()); ret->set_pinned_device_id(mem_case.device_id()); } } MemoryCase GetPinnedHostMemoryCase(const MemoryCase& mem_case) { MemoryCase ret; GetPinnedHostMemoryCase(mem_case, &ret); return ret; } // clang-format off // MemCaseId encoding (bits) // | reserved | node_index | device_type | device_index | reserved | pinned_device_type | pinned_device_index | // | --- 1 -- | --- 19 --- | ---- 5 ---- | ----- 7 ---- | -- 20 -- | ------- 5 -------- | ------- 7 --------- | // | ---------------------- 32 ------------------------ | ---------------------- 32 ------------------------- | // clang-format on namespace { constexpr size_t kDeviceIndexBits = 7; constexpr size_t kDeviceTypeBits = 5; constexpr size_t kDeviceTypeShift = kDeviceIndexBits; constexpr size_t kNodeIndexShift = kDeviceTypeShift + kDeviceTypeBits; constexpr size_t kPinnedDeviceShift = 32; } // namespace int64_t GetMemCaseId(const MemoryCase& mem_case) { uint32_t high = 0; high |= static_cast(mem_case.device_id()); high |= static_cast(mem_case.device_type()) << kDeviceTypeShift; uint32_t low = 0; if (mem_case.has_pinned_device_id()) { low |= static_cast(mem_case.pinned_device_id()); } if (mem_case.has_pinned_device_type()) { low |= static_cast(mem_case.pinned_device_type()) << kDeviceTypeShift; } int64_t id = 0; id |= static_cast(high) << kPinnedDeviceShift; id |= static_cast(low); return id; } int64_t GetUniqueMemCaseId(int64_t machine_id, const MemoryCase& mem_case) { int64_t id = 0; id |= (machine_id << kNodeIndexShift << kPinnedDeviceShift); id |= GetMemCaseId(mem_case); return id; } std::shared_ptr MakeMemCaseShared(const DeviceType device_type, const int64_t device_id) { auto mem_case_ptr = std::make_shared(); mem_case_ptr->set_device_type(device_type); // We consider that there is only one cpu physical device. // As non-cpu devices, a logical device map to a physical device, // however as cpu devices, all logical devices map to a single physical device. if (device_type == DeviceType::kCPU) { mem_case_ptr->set_device_id(0); } else { mem_case_ptr->set_device_id(device_id); } return mem_case_ptr; } MemoryCase MakeHostMemCase() { MemoryCase mem_case; mem_case.set_device_type(DeviceType::kCPU); mem_case.set_device_id(0); return mem_case; } bool IsHostMem(const MemoryCase& mem_case) { return mem_case.device_type() == DeviceType::kCPU; } } // namespace memory bool operator==(const MemoryCase& lhs, const MemoryCase& rhs) { return google::protobuf::util::MessageDifferencer::Equals(lhs, rhs); } } // namespace oneflow ================================================ FILE: oneflow/core/memory/memory_case_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_MEMORY_MEMORY_CASE_UTIL_H_ #define ONEFLOW_CORE_MEMORY_MEMORY_CASE_UTIL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/memory/memory_case.pb.h" namespace oneflow { namespace memory { bool EqualsIgnorePinnedDevice(const MemoryCase& a, const MemoryCase& b); void GetPinnedHostMemoryCase(const MemoryCase& mem_case, MemoryCase* ret); MemoryCase GetPinnedHostMemoryCase(const MemoryCase& mem_case); int64_t GetMemCaseId(const MemoryCase& mem_case); int64_t GetUniqueMemCaseId(int64_t machine_id, const MemoryCase& mem_case); std::shared_ptr MakeMemCaseShared(const DeviceType device_type, const int64_t device_id); MemoryCase MakeHostMemCase(); bool IsHostMem(const MemoryCase& mem_case); } // namespace memory bool operator==(const MemoryCase& lhs, const MemoryCase& rhs); } // namespace oneflow #endif // ONEFLOW_CORE_MEMORY_MEMORY_CASE_UTIL_H_ ================================================ FILE: oneflow/core/memory/memory_zone.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
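Illustrative sketch (not part of the repository): memory_zone.cpp below packs a MemZoneId's rank, device type, and device index into a single int64_t with shifts and masks, and decodes it back the same way. A standalone round-trip using hypothetical field widths (the real widths come from MemZoneId/DeviceId):

#include <cassert>
#include <cstdint>

// Hypothetical field widths; they mirror the layout idea, not the exact repository constants.
constexpr int kDeviceIndexBits = 7;
constexpr int kDeviceTypeBits = 5;
constexpr int kDeviceTypeShift = kDeviceIndexBits;
constexpr int kRankShift = kDeviceTypeShift + kDeviceTypeBits;

int64_t Encode(int64_t rank, int64_t device_type, int64_t device_index) {
  return (rank << kRankShift) | (device_type << kDeviceTypeShift) | device_index;
}

void Decode(int64_t id, int64_t* rank, int64_t* device_type, int64_t* device_index) {
  *device_index = id & ((int64_t{1} << kDeviceIndexBits) - 1);
  *device_type = (id >> kDeviceTypeShift) & ((int64_t{1} << kDeviceTypeBits) - 1);
  *rank = id >> kRankShift;
}

int main() {
  int64_t rank = 0, device_type = 0, device_index = 0;
  int64_t id = Encode(3, 2, 5);  // rank=3, device_type=2, device_index=5
  Decode(id, &rank, &device_type, &device_index);
  assert(rank == 3 && device_type == 2 && device_index == 5);  // lossless round-trip
  return 0;
}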
*/ #include "oneflow/core/memory/memory_zone.h" namespace oneflow { namespace { constexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits; constexpr size_t kMemZoneIdRankShift = kMemZoneIdDeviceTypeShift + MemZoneId::kDeviceTypeBits; constexpr int64_t kMemZoneIdRankInt64Mask = ((int64_t{1} << MemZoneId::kRankBits) - 1) << kMemZoneIdRankShift; constexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1) << kMemZoneIdDeviceTypeShift; constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1; } // namespace const MemZoneId kInvalidMemZoneId = MemZoneId{0, DeviceType::kInvalidDevice, 0}; MemZoneId GetNodeCPUMemZoneId(MemZoneId::rank_t node_index) { return MemZoneId{node_index, DeviceType::kCPU, 0}; } int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) { int64_t id = static_cast(mem_zone_id.device_index()); id |= static_cast(mem_zone_id.device_type()) << kMemZoneIdDeviceTypeShift; id |= static_cast(mem_zone_id.rank()) << kMemZoneIdRankShift; return id; } MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) { int64_t rank = (mem_zone_id & kMemZoneIdRankInt64Mask) >> kMemZoneIdRankShift; int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift; int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask; return MemZoneId(static_cast(rank), static_cast(device_type), static_cast(device_index)); } } // namespace oneflow ================================================ FILE: oneflow/core/memory/memory_zone.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ #define ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ #include "oneflow/core/device/device_id.h" namespace oneflow { using MemZoneId = DeviceId; int64_t EncodeMemZoneIdToInt64(const MemZoneId&); MemZoneId DecodeMemZoneIdFromInt64(int64_t); MemZoneId GetNodeCPUMemZoneId(MemZoneId::rank_t node_index); extern const MemZoneId kInvalidMemZoneId; } // namespace oneflow #endif // ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ ================================================ FILE: oneflow/core/ndarray/binary_func.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_ #define ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_ #include #include #include #include #if defined(__CUDACC__) #include #endif #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/common/util.h" namespace oneflow { #define ARITHMETIC_BINARY_FUNC_NAME_SEQ (Add)(Sub)(Mul)(Div)(Min)(Max)(FloorMod)(FMod)(Pow) #define LOGICAL_BINARY_FUNC_NAME_SEQ (EQ)(NE)(GT)(GE)(LT)(LE)(AND)(OR)(XOR) #define PREPEND_PREFIX_BINARY_FUNC(name) OF_PP_CAT(BinaryFunc, name) #define ARITHMETIC_BINARY_FUNC_SEQ \ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ) #define LOGICAL_BINARY_FUNC_SEQ \ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, LOGICAL_BINARY_FUNC_NAME_SEQ) #define REDUCE_BINARY_FUNC_NAME_SEQ (Sum)(Max)(Min)(Prod)(Any)(All) #define ARITHMETIC_REDUCE_BINARY_FUNC_NAME_SEQ (Sum)(Max)(Min)(Prod) #define LOGICAL_REDUCE_BINARY_FUNC_NAME_SEQ (Any)(All) #define REDUCE_BINARY_FUNC_SEQ \ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, REDUCE_BINARY_FUNC_NAME_SEQ) #define REDUCE_COMPLEX_BINARY_FUNC_SEQ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, (Sum)) #define ARITHMETIC_REDUCE_BINARY_FUNC_SEQ \ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, ARITHMETIC_REDUCE_BINARY_FUNC_NAME_SEQ) #define LOGICAL_REDUCE_BINARY_FUNC_SEQ \ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, LOGICAL_REDUCE_BINARY_FUNC_NAME_SEQ) #define NANSUM_REDUCE_BINARY_FUNC_SEQ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, (NanSum)) #define NO_HALF_UTIL_FOUND \ printf("cuda arch must >= 530"); \ assert(false); \ return __float2half(0.0) template class BinaryFunc, typename T> struct BinaryFuncTrait final { typedef typename std::remove_const::Invoke(std::declval(), std::declval()))>::type return_type; }; #define SPECIALIZE_CONST_TYPE_BINARY_FUNC(func_struct) \ template \ struct func_struct final { \ static OF_DEVICE_FUNC const typename BinaryFuncTrait::return_type Invoke( \ const T x, const T y) { \ return func_struct::Invoke(x, y); \ } \ } template struct BinaryFuncNanSum final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { #if defined(__CUDACC__) if (isnan(x)) return isnan(y) ? T{0} : y; return isnan(y) ? x : x + y; #else if (std::isnan(x)) return std::isnan(y) ? T{0} : y; return std::isnan(y) ? 
x : x + y; #endif } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncNanSum); template struct BinaryFuncAdd final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x + y; } }; template struct BinaryFuncSum final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return BinaryFuncAdd::Invoke(x, y); } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncAdd); template struct BinaryFuncSub final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x - y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncSub); template struct BinaryFuncMul final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x * y; } }; template<> struct BinaryFuncMul final { static OF_DEVICE_FUNC bool Invoke(const bool x, const bool y) { return x && y; } }; template struct BinaryFuncProd final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return BinaryFuncMul::Invoke(x, y); } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncMul); template struct BinaryFuncDiv final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x / y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncDiv); template struct BinaryFuncFloorMod final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { #if defined(__CUDACC__) T trunc_mod = x % y; return (trunc_mod != T(0)) && ((y < T(0)) != (trunc_mod < T(0))) ? trunc_mod + y : trunc_mod; #else T trunc_mod = x % y; return (trunc_mod != T(0)) && ((y < T(0)) != (trunc_mod < T(0))) ? trunc_mod + y : trunc_mod; #endif } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncFloorMod); template<> struct BinaryFuncFloorMod final { static OF_DEVICE_FUNC uint8_t Invoke(const uint8_t x, const uint8_t y) { #if defined(__CUDACC__) uint8_t trunc_mod = x % y; return trunc_mod; #else uint8_t trunc_mod = x % y; return trunc_mod; #endif } }; template struct BinaryFuncFMod final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { #if defined(__CUDACC__) T trunc_mod = x % y; return trunc_mod; #else T trunc_mod = x % y; return trunc_mod; #endif } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncFMod); template struct BinaryFuncPow final { static OF_DEVICE_FUNC const T Invoke(const T x, const T y) { #if defined(__CUDACC__) return powf(x, y); #else return std::pow(x, y); #endif } }; template<> struct BinaryFuncPow final { static OF_DEVICE_FUNC bool Invoke(const bool x, const bool y) { #if defined(__CUDACC__) return static_cast(powf(static_cast(x), static_cast(y))); #else return static_cast(std::pow(static_cast(x), static_cast(y))); #endif } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncPow); template<> struct BinaryFuncPow final { static inline const float16 Invoke(const float16 x, const float16 y) { return static_cast(std::pow(static_cast(x), static_cast(y))); } }; #if defined(__CUDACC__) template<> struct BinaryFuncPow final { static OF_DEVICE_FUNC double Invoke(const double x, const double y) { return pow(x, y); } }; template<> struct BinaryFuncPow final { static __device__ __forceinline__ float Invoke(const float x, const float y) { return powf(x, y); } }; template<> struct BinaryFuncPow final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __float2half(powf(__half2float(x), __half2float(y))); #else NO_HALF_UTIL_FOUND; #endif } }; #endif // defined(__CUDACC__) template struct BinaryFuncFloorDiv final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { #if defined(__CUDACC__) return floor(fdividef(x, y)); #else return std::floor(x / y); #endif } }; template struct 
BinaryFuncMax final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x > y ? x : y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncMax); template struct BinaryFuncMin final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x < y ? x : y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncMin); template struct BinaryFuncEQ final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x == y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncEQ); template struct BinaryFuncNE final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x != y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncNE); template struct BinaryFuncGT final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x > y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncGT); template struct BinaryFuncGE final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x >= y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncGE); template struct BinaryFuncLT final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x < y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncLT); template struct BinaryFuncLE final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x <= y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncLE); template struct BinaryFuncAND final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x && y; } }; template struct BinaryFuncAll final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return BinaryFuncAND::Invoke(x, y); } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncAND); template struct BinaryFuncOR final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x || y; } }; template struct BinaryFuncAny final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return BinaryFuncOR::Invoke(x, y); } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncOR); template struct BinaryFuncXOR final { static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return (!x) != (!y); } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncXOR); template struct BinaryFuncBitwiseAnd final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x & y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncBitwiseAnd); template struct BinaryFuncBitwiseOr final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x | y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncBitwiseOr); template struct BinaryFuncBitwiseXor final { static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x ^ y; } }; SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncBitwiseXor); #if defined(__CUDACC__) template<> struct BinaryFuncAdd final { static __device__ __forceinline__ half Invoke(const half x, const half y) { return __hadd(x, y); } }; template<> struct BinaryFuncNanSum final { static __device__ __forceinline__ half Invoke(const half x, const half y) { if (isnan(__half2float(x))) return isnan(__half2float(y)) ? half(0.0) : y; return isnan(__half2float(y)) ? 
__hadd(x, y) : x; } }; template<> struct BinaryFuncSub final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hsub(x, y); #else NO_HALF_UTIL_FOUND; #endif } }; template<> struct BinaryFuncMul final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hmul(x, y); #else NO_HALF_UTIL_FOUND; #endif } }; template<> struct BinaryFuncDiv final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if __CUDA_ARCH__ >= 530 return __hdiv(x, y); #else NO_HALF_UTIL_FOUND; #endif } }; template<> struct BinaryFuncMax final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hgt(x, y) ? x : y; #else NO_HALF_UTIL_FOUND; #endif } }; template<> struct BinaryFuncMin final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hlt(x, y) ? x : y; #else NO_HALF_UTIL_FOUND; #endif } }; template<> struct BinaryFuncAdd final { static __device__ __forceinline__ cuComplex Invoke(const cuComplex x, const cuComplex y) { return cuComplex{x.x + y.x, x.y + y.y}; } }; template<> struct BinaryFuncSub final { static __device__ __forceinline__ cuComplex Invoke(const cuComplex x, const cuComplex y) { return cuComplex{x.x - y.x, x.y - y.y}; } }; template<> struct BinaryFuncMul final { static __device__ __forceinline__ cuComplex Invoke(const cuComplex x, const cuComplex y) { return cuCmulf(x, y); } }; template<> struct BinaryFuncAdd final { static __device__ __forceinline__ cuDoubleComplex Invoke(const cuDoubleComplex x, const cuDoubleComplex y) { return cuDoubleComplex{x.x + y.x, x.y + y.y}; } }; template<> struct BinaryFuncSub final { static __device__ __forceinline__ cuDoubleComplex Invoke(const cuDoubleComplex x, const cuDoubleComplex y) { return cuDoubleComplex{x.x - y.x, x.y - y.y}; } }; template<> struct BinaryFuncMul final { static __device__ __forceinline__ cuDoubleComplex Invoke(const cuDoubleComplex x, const cuDoubleComplex y) { return cuCmul(x, y); } }; #endif // defined(__CUDACC__) #if defined(__CUDACC__) template<> struct BinaryFuncFloorMod final { static __device__ __forceinline__ float Invoke(const float x, const float y) { const float trunc_mod = fmodf(x, y); return (trunc_mod != 0) && ((y < 0) != (trunc_mod < 0)) ? trunc_mod + y : trunc_mod; } }; template<> struct BinaryFuncFloorMod final { static __device__ __forceinline__ double Invoke(const double x, const double y) { const double trunc_mod = fmod(x, y); return (trunc_mod != 0) && ((y < 0) != (trunc_mod < 0)) ? trunc_mod + y : trunc_mod; } }; template<> struct BinaryFuncFloorMod final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if __CUDA_ARCH__ >= 530 const half trunc_mod = __float2half(fmodf(__half2float(x), __half2float(y))); return __hne(trunc_mod, GetZeroVal()) && __hlt(y, GetZeroVal()) != __hlt(trunc_mod, half(0)) ? trunc_mod + y : trunc_mod; #else NO_HALF_UTIL_FOUND; #endif } }; #else template<> struct BinaryFuncFloorMod final { static inline float Invoke(const float x, const float y) { const float trunc_mod = std::fmod(x, y); return (trunc_mod != 0) && ((y < 0) != (trunc_mod < 0)) ? 
trunc_mod + y : trunc_mod; } }; template<> struct BinaryFuncFloorMod final { static inline double Invoke(const double x, const double y) { const double trunc_mod = std::fmod(x, y); return (trunc_mod != 0) && ((y < 0) != (trunc_mod < 0)) ? trunc_mod + y : trunc_mod; } }; template<> struct BinaryFuncFloorMod final { static inline float16 Invoke(const float16 x, const float16 y) { const float trunc_mod = std::fmod(static_cast(x), static_cast(y)); return (trunc_mod != float(0)) && ((y < float(0)) != (trunc_mod < float(0))) ? static_cast(trunc_mod + y) : static_cast(trunc_mod); } }; #endif // defined(__CUDACC__) #if defined(__CUDACC__) template<> struct BinaryFuncFMod final { static __device__ __forceinline__ float Invoke(const float x, const float y) { const float trunc_mod = fmodf(x, y); return trunc_mod; } }; template<> struct BinaryFuncFMod final { static __device__ __forceinline__ double Invoke(const double x, const double y) { const double trunc_mod = fmod(x, y); return trunc_mod; } }; template<> struct BinaryFuncFMod final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if __CUDA_ARCH__ >= 530 const half trunc_mod = __float2half(fmodf(__half2float(x), __half2float(y))); return trunc_mod; #else NO_HALF_UTIL_FOUND; #endif } }; #else template<> struct BinaryFuncFMod final { static inline float Invoke(const float x, const float y) { const float trunc_mod = std::fmod(x, y); return trunc_mod; } }; template<> struct BinaryFuncFMod final { static inline double Invoke(const double x, const double y) { const double trunc_mod = std::fmod(x, y); return trunc_mod; } }; template<> struct BinaryFuncFMod final { static inline float16 Invoke(const float16 x, const float16 y) { const float trunc_mod = std::fmod(static_cast(x), static_cast(y)); return static_cast(trunc_mod); } }; #endif // defined(__CUDACC__) #if defined(__CUDACC__) template<> struct BinaryFuncFloorDiv final { static __device__ __forceinline__ uint8_t Invoke(uint8_t x, uint8_t y) { return x / y; } }; template<> struct BinaryFuncFloorDiv final { static __device__ __forceinline__ int8_t Invoke(int8_t x, int8_t y) { return x / y; } }; template<> struct BinaryFuncFloorDiv final { static __device__ __forceinline__ int32_t Invoke(int32_t x, int32_t y) { return x / y; } }; template<> struct BinaryFuncFloorDiv final { static __device__ __forceinline__ int64_t Invoke(int64_t x, int64_t y) { return x / y; } }; template<> struct BinaryFuncFloorDiv final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if __CUDA_ARCH__ >= 530 return __float2half(floor(fdividef(__half2float(x), __half2float(y)))); #else NO_HALF_UTIL_FOUND; #endif } }; #else template<> struct BinaryFuncFloorDiv final { static inline float16 Invoke(float16 x, float16 y) { return static_cast(std::floor(static_cast(x) / static_cast(y))); } }; #endif // defined(__CUDACC__) template class binary_func> struct UnitOfBinaryFunc; #define SPECIALIZE_UNIT_OF_BINARY_FUNC(binary_func, get_val) \ template \ struct UnitOfBinaryFunc final { \ static OF_DEVICE_FUNC T Val() { return get_val(); } \ }; SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAdd, GetZeroVal); SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncNanSum, GetZeroVal); SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncSum, GetZeroVal); SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMul, GetOneVal); SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncProd, GetOneVal); SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMax, GetMinVal); SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMin, GetMaxVal); 
SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAny, GetZeroVal); SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAll, GetOneVal); #undef SPECIALIZE_UNIT_OF_BINARY_FUNC /* These placeholder specializations are used for `GetBinaryBroadcastSbpSignature` in oneflow/user/ops/math_binary_broadcast_ops.cpp */ #define SPECIALIZE_FOR_SBP(binary_func) \ template \ struct binary_func final {}; SPECIALIZE_FOR_SBP(BinaryFuncIEN); SPECIALIZE_FOR_SBP(BinaryFuncINN); SPECIALIZE_FOR_SBP(BinaryFuncZeta); #undef SPECIALIZE_FOR_SBP } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_ ================================================ FILE: oneflow/core/ndarray/cpu_concat_var_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_CPU_CONCAT_VAR_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_CPU_CONCAT_VAR_NDARRAY_H_ #include "oneflow/core/ndarray/cpu_ndarray.h" #include "oneflow/core/ndarray/cpu_var_ndarray.h" #include "oneflow/core/common/range.h" namespace oneflow { template class CpuConcatVarNdarray : public CpuNdarray { public: static const bool immutable = false; static_assert(CONCAT_AXES >= 0 && CONCAT_AXES < NDIMS, "CONCAT_AXES should be a valid dim"); CpuConcatVarNdarray(const std::vector>& var_ndarrays) : CpuNdarray(CalcConcatenatedShape(var_ndarrays)), var_ndarrays_(var_ndarrays), dim_ranges_(CalcDimRanges(var_ndarrays)), contiguous_lens_(CalcContiguousLens(var_ndarrays)) {} ~CpuConcatVarNdarray() = default; template void CopyFrom(const XT& ndarray) { CpuNdarrayCopy(this, ndarray); } void GetMutPtrAndContiguousSize(int64_t offset, T** ptr, size_t* size) const { int64_t dim[NDIMS] = {0}; this->xpu_shape().template Offset2Coordinate(offset, dim); int32_t var_index = 0; this->GetVarNdarrayIndexAndInputDim(dim[CONCAT_AXES], &var_index, &dim[CONCAT_AXES]); int64_t input_offset = this->var_ndarray(var_index).xpu_shape().template Coordinate2Offset(dim); this->GetMutPtrAndMinContiguousSize(var_index, input_offset, ptr, size); } protected: ALWAYS_INLINE void GetVarNdarrayIndexAndInputDim(int64_t output_dim, int32_t* var_index, int64_t* input_dim) const { *var_index = CpuVarNdarrayIndex4OutputDim(output_dim); *input_dim = output_dim - dim_ranges_[*var_index].begin(); } ALWAYS_INLINE const CpuVarNdarray var_ndarray(int32_t var_index) const { return var_ndarrays_[var_index]; } ALWAYS_INLINE void GetMutPtrAndMinContiguousSize(int32_t var_index, int64_t var_offset, T** ptr, size_t* size) const { size_t var_contiguous_size = 0; var_ndarray(var_index).GetMutPtrAndContiguousSize(var_offset, ptr, &var_contiguous_size); *size = std::min(var_contiguous_size, static_cast(contiguous_lens_[var_index] - var_offset % contiguous_lens_[var_index])); } private: ALWAYS_INLINE int32_t CpuVarNdarrayIndex4OutputDim(int64_t output_dim) const { // TODO change to bianry search FOR_RANGE(int32_t, i, 0, dim_ranges_.size()) { if (output_dim >= dim_ranges_[i].begin() && output_dim < dim_ranges_[i].end()) { return i; } } UNIMPLEMENTED(); } 
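// A hedged sketch of the binary search the TODO above refers to: dim_ranges_ is
// built by CalcDimRanges below as sorted, contiguous, non-overlapping ranges, so
// the linear scan could be replaced by a lookup like the one below. The name
// BinarySearchVarNdarrayIndex4OutputDim is hypothetical; this helper is not called
// anywhere in the class and only illustrates the intended optimization.
ALWAYS_INLINE int32_t BinarySearchVarNdarrayIndex4OutputDim(int64_t output_dim) const {
  int32_t lo = 0;
  int32_t hi = static_cast<int32_t>(dim_ranges_.size());
  while (lo < hi) {
    const int32_t mid = lo + (hi - lo) / 2;
    if (output_dim >= dim_ranges_[mid].end()) {
      lo = mid + 1;
    } else if (output_dim < dim_ranges_[mid].begin()) {
      hi = mid;
    } else {
      return mid;
    }
  }
  UNIMPLEMENTED();
}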
XpuShape CalcConcatenatedShape(const std::vector>& var_ndarrays) const { CheckInputShape(var_ndarrays); XpuShape xpu_shape(var_ndarrays[0].xpu_shape()); int64_t axes_dim_num = 0; FOR_RANGE(int32_t, i, 0, var_ndarrays.size()) { axes_dim_num += var_ndarrays[i].xpu_shape().At(CONCAT_AXES); } xpu_shape.Set(CONCAT_AXES, axes_dim_num); return xpu_shape; } void CheckInputShape(const std::vector>& var_ndarrays) const { FOR_RANGE(int32_t, i, 1, var_ndarrays.size()) { FOR_RANGE(int32_t, j, 0, NDIMS) { if (j == CONCAT_AXES) { continue; } CHECK_EQ(var_ndarrays[0].xpu_shape().At(j), var_ndarrays[i].xpu_shape().At(j)); } } } std::vector CalcDimRanges(const std::vector>& var_ndarrays) const { int64_t axes_dim_num = 0; std::vector ret; FOR_RANGE(int32_t, i, 0, var_ndarrays.size()) { ret.emplace_back( Range(axes_dim_num, axes_dim_num + var_ndarrays[i].xpu_shape().At(CONCAT_AXES))); axes_dim_num += var_ndarrays[i].xpu_shape().At(CONCAT_AXES); } return ret; } std::vector CalcContiguousLens( const std::vector>& var_ndarrays) const { std::vector ret(var_ndarrays.size(), 0); FOR_RANGE(int32_t, i, 0, var_ndarrays.size()) { ret[i] = var_ndarrays[i].xpu_shape().Count(CONCAT_AXES); } return ret; } const std::vector> var_ndarrays_; const std::vector dim_ranges_; const std::vector contiguous_lens_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_CPU_CONCAT_VAR_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/cpu_concat_var_ndarray_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ndarray/cpu_ndarray_builder.h" #include namespace oneflow { namespace test { TEST(CpuConcatVarNdarray, two_elem_concat) { std::vector x0_data{0}; std::vector x1_data{1}; std::vector buffer{-1, -1}; std::vector expected{0, 1}; CpuNdarrayBuilder ndarray; auto x0 = ndarray.Var(Shape{1LL}, x0_data.data()); auto x1 = ndarray.Var(Shape{1LL}, x1_data.data()); ndarray.Var(Shape{2LL}, buffer.data()).CopyFrom(ndarray.Concatenate({x0, x1})); ASSERT_EQ(memcmp(buffer.data(), expected.data(), sizeof(int32_t) * 2), 0); } TEST(CpuConcatVarNdarray, two_elem_concat_assign) { std::vector x0_data{-1}; std::vector x1_data{-1}; std::vector buffer{0, 1}; CpuNdarrayBuilder ndarray; auto x0 = ndarray.Var(Shape{1LL}, x0_data.data()); auto x1 = ndarray.Var(Shape{1LL}, x1_data.data()); ndarray.Concatenate({x0, x1}).CopyFrom(ndarray.Var(Shape{2LL}, buffer.data())); ASSERT_EQ(x0_data[0], 0); ASSERT_EQ(x1_data[0], 1); } TEST(CpuConcatVarNdarray, 2d_concat) { // clang-format off std::vector x0_data{ 0, 1, 2, 5, 6, 7, }; std::vector x1_data{ 3, 4, 8, 9, }; std::vector expected{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, }; std::vector buffer(10, -1); // clang-format on CpuNdarrayBuilder ndarray; auto x0 = ndarray.Var(Shape{2LL, 3LL}, x0_data.data()); auto x1 = ndarray.Var(Shape{2LL, 2LL}, x1_data.data()); ndarray.Var(Shape{2LL, 5LL}, buffer.data()).CopyFrom(ndarray.template Concatenate<1>({x0, x1})); ASSERT_EQ(memcmp(buffer.data(), expected.data(), sizeof(int32_t) * 10), 0); } TEST(CpuConcatVarNdarray, 2d_concat_assign) { // clang-format off std::vector x_data{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, }; std::vector y0_buffer(6, -1); std::vector y1_buffer(4, -1); std::vector y0_expected{ 0, 1, 2, 5, 6, 7, }; std::vector y1_expected{ 3, 4, 8, 9, }; // clang-format on CpuNdarrayBuilder ndarray; auto x = ndarray.Var(Shape{2LL, 5LL}, x_data.data()); auto y0 = ndarray.Var(Shape{2LL, 3LL}, y0_buffer.data()); auto y1 = ndarray.Var(Shape{2LL, 2LL}, y1_buffer.data()); ndarray.template Concatenate<1>({y0, y1}).CopyFrom(x); ASSERT_EQ(memcmp(y0_buffer.data(), y0_expected.data(), sizeof(int32_t) * 6), 0); ASSERT_EQ(memcmp(y1_buffer.data(), y1_expected.data(), sizeof(int32_t) * 4), 0); } TEST(CpuConcatVarNdarray, 3d_concat) { // clang-format off std::vector x0_data{ 0, 1, 2, 5, 6, 7, 10,11,12, 15,16,17 }; std::vector x1_data{ 3, 4, 8, 9, 13,14, 18,19, }; std::vector expected{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14, 15,16,17,18,19, }; std::vector buffer(20, -1); // clang-format on CpuNdarrayBuilder ndarray; auto x0 = ndarray.Var(Shape{2LL, 2LL, 3LL}, x0_data.data()); auto x1 = ndarray.Var(Shape{2LL, 2LL, 2LL}, x1_data.data()); ndarray.Var(Shape{2LL, 2LL, 5LL}, buffer.data()) .CopyFrom(ndarray.template Concatenate<2>({x0, x1})); ASSERT_EQ(memcmp(buffer.data(), expected.data(), sizeof(int32_t) * 20), 0); } TEST(CpuConcatVarNdarray, 3d_concat_assign) { // clang-format off std::vector x_data{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14, 15,16,17,18,19, }; std::vector y0_expected{ 0, 1, 2, 5, 6, 7, 10,11,12, 15,16,17 }; std::vector y1_expected{ 3, 4, 8, 9, 13,14, 18,19, }; std::vector y0_buffer(2*2*3, -1); std::vector y1_buffer(2*2*2, -1); // clang-format on CpuNdarrayBuilder ndarray; auto x = ndarray.Var(Shape{2LL, 2LL, 5LL}, x_data.data()); auto y0 = ndarray.Var(Shape{2LL, 2LL, 3LL}, y0_buffer.data()); auto y1 = ndarray.Var(Shape{2LL, 2LL, 2LL}, y1_buffer.data()); ndarray.template Concatenate<2>({y0, y1}).CopyFrom(x); ASSERT_EQ(memcmp(y0_buffer.data(), y0_expected.data(), sizeof(int32_t) * y0_expected.size()), 0); 
ASSERT_EQ(memcmp(y1_buffer.data(), y1_expected.data(), sizeof(int32_t) * y1_expected.size()), 0); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/cpu_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_H_ #include #include "oneflow/core/common/shape.h" #include "oneflow/core/common/util.h" #include "oneflow/core/ndarray/xpu_shape.h" namespace oneflow { template class CpuNdarray { public: using dtype = T; static const int ndims = NDIMS; ALWAYS_INLINE const XpuShape& xpu_shape() const { return xpu_shape_; } protected: explicit CpuNdarray(const Shape& shape) : xpu_shape_(shape) {} explicit CpuNdarray(const XpuShape& xpu_shape) : xpu_shape_(xpu_shape) {} virtual ~CpuNdarray() = default; private: XpuShape xpu_shape_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/cpu_ndarray_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_HELPER_H_ #define ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_HELPER_H_ #include "oneflow/core/ndarray/cpu_var_ndarray.h" #include "oneflow/core/ndarray/cpu_slice_var_ndarray.h" #include "oneflow/core/ndarray/cpu_concat_var_ndarray.h" namespace oneflow { template class CpuNdarrayBuilder final { public: OF_DISALLOW_COPY_AND_MOVE(CpuNdarrayBuilder); CpuNdarrayBuilder() = default; ~CpuNdarrayBuilder() = default; template CpuVarNdarray Var(const Shape& shape, T* ptr) const { return CpuVarNdarray(shape, ptr); } template CpuVarNdarray Var(const ShapeView& shape_view, T* ptr) const { return CpuVarNdarray(shape_view, ptr); } template CpuConcatVarNdarray Concatenate( const std::vector>& var_ndarrays) const { return CpuConcatVarNdarray(var_ndarrays); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_HELPER_H_ ================================================ FILE: oneflow/core/ndarray/cpu_ndarray_copy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_COPY_H_ #define ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_COPY_H_ #include "oneflow/core/ndarray/cpu_ndarray.h" namespace oneflow { template void CpuNdarrayCopy(YT* y_ndarray, const XT& x_ndarray) { CHECK_EQ(y_ndarray->xpu_shape().ElemNum(), x_ndarray.xpu_shape().ElemNum()); T* dst_ptr = nullptr; size_t dst_size = 0; T* src_ptr = nullptr; size_t src_size = 0; int64_t cur_index = 0; size_t total_elem_cnt = y_ndarray->xpu_shape().ElemNum(); while (cur_index < total_elem_cnt) { if (dst_size == 0) { y_ndarray->GetMutPtrAndContiguousSize(cur_index, &dst_ptr, &dst_size); } if (src_size == 0) { x_ndarray.GetMutPtrAndContiguousSize(cur_index, &src_ptr, &src_size); } if (src_size == 0) { break; } size_t cp_size = std::min(dst_size, src_size); if (cp_size == 1) { *dst_ptr = *src_ptr; } else { memcpy(dst_ptr, src_ptr, sizeof(T) * cp_size); } dst_ptr += cp_size; src_ptr += cp_size; dst_size -= cp_size; src_size -= cp_size; cur_index += cp_size; } CHECK_EQ(dst_size, 0); CHECK_EQ(src_size, 0); CHECK_EQ(cur_index, total_elem_cnt); } } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_COPY_H_ ================================================ FILE: oneflow/core/ndarray/cpu_slice_var_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_CPU_SLICE_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_CPU_SLICE_NDARRAY_H_ #include "oneflow/core/ndarray/slice.h" #include "oneflow/core/ndarray/cpu_ndarray.h" #include "oneflow/core/ndarray/cpu_ndarray_copy.h" namespace oneflow { template class CpuSliceVarNdarray : public CpuNdarray { public: CpuSliceVarNdarray(XT&& x, std::array&& slices) : CpuNdarray( BoundedSlices2Shape(BoundSlices(x, std::move(slices)))), x_(x), slices_(std::move(slices)) { SetContiguousLength(slices); } virtual ~CpuSliceVarNdarray() = default; CpuSliceVarNdarray> operator()(Slice&& slice0) { static_assert(XT::ndims == 1, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0}); } CpuSliceVarNdarray> operator()(Slice&& slice0, Slice&& slice1) { static_assert(XT::ndims == 2, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0, slice1}); } CpuSliceVarNdarray> operator()(Slice&& slice0, Slice&& slice1, Slice&& slice2) { static_assert(XT::ndims == 3, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0, slice1, slice2}); } CpuSliceVarNdarray> operator()(Slice&& slice0, Slice&& slice1, Slice&& slice2, Slice&& slice3) { static_assert(XT::ndims == 4, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0, slice1, slice2, slice3}); } CpuSliceVarNdarray> operator()(Slice&& slice0, Slice&& slice1, Slice&& slice2, Slice&& slice3, Slice&& slice4) { static_assert(XT::ndims == 5, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0, slice1, slice2, slice3, slice4}); } template void CopyFrom(const AT& ndarray) { CpuNdarrayCopy(this, ndarray); } ALWAYS_INLINE void GetMutPtrAndContiguousSize(int64_t offset, typename XT::dtype** ptr, size_t* size) const { int64_t dim[XT::ndims] = {0}; this->xpu_shape().template Offset2Coordinate(offset, dim); for (int i = 0; i < XT::ndims; ++i) { dim[i] = this->slice(i).Get(dim[i]); } size_t x_offset = this->x().xpu_shape().template Coordinate2Offset(dim); this->GetMutPtrAndMinContiguousSize(offset, x_offset, ptr, size); } protected: ALWAYS_INLINE const XT& x() const { return x_; } ALWAYS_INLINE const Slice& slice(int32_t dim) const { return slices_[dim]; } ALWAYS_INLINE void GetMutPtrAndMinContiguousSize(int64_t offset, int64_t x_offset, typename XT::dtype** ptr, size_t* size) const { size_t x_contiguous_size; this->x().GetMutPtrAndContiguousSize(x_offset, ptr, &x_contiguous_size); size_t slice_contiguous_size = (contiguous_len_ - offset % contiguous_len_); *size = std::min(x_contiguous_size, slice_contiguous_size); } private: static std::array&& BoundSlices(const XT& x, std::array&& slices) { FOR_RANGE(int32_t, i, 0, XT::ndims) { slices[i].Bound(x.xpu_shape().At(i)); } return std::move(slices); } static Shape BoundedSlices2Shape(const std::array& bounded_slices) { DimVector dim_vec; for (const Slice& slice : bounded_slices) { CHECK_GT(slice.Size(), 0); dim_vec.emplace_back(slice.Size()); } return Shape(dim_vec); } void SetContiguousLength(const std::array& bounded_slices) { contiguous_len_ = 1; for (int i = XT::ndims - 1; i >= 0; --i) { if (bounded_slices[i].IsContiguous()) { contiguous_len_ *= bounded_slices[i].Size(); } if (!(bounded_slices[i].IsContiguous() && bounded_slices[i].IsCoveringAll())) { break; } } } const XT& x_; std::array slices_; size_t contiguous_len_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_CPU_SLICE_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/cpu_slice_var_ndarray_test.cpp 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/cpu_ndarray_builder.h" #include namespace oneflow { namespace test { TEST(CpuSliceVarNdarray, one_elem_assign) { std::vector data({1}); std::vector buffer({0}); CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{1LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{1LL}, buffer.data()); buffer_ndarray(0).CopyFrom(data_ndarray(0)); ASSERT_EQ(data[0], buffer[0]); } TEST(CpuSliceVarNdarray, one_elem_assign_slice_on_slice) { std::vector data({1}); std::vector buffer({0}); CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{1LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{1LL}, buffer.data()); buffer_ndarray(0)(0).CopyFrom(data_ndarray(0)(0)); ASSERT_EQ(data[0], buffer[0]); } TEST(CpuSliceVarNdarray, 1d_assign) { std::vector data({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); std::vector buffer(10, 0); CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{10LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{10LL}, buffer.data()); buffer_ndarray({}).CopyFrom(data_ndarray({})); ASSERT_EQ(memcmp(data.data(), buffer.data(), sizeof(int32_t) * 10), 0); } TEST(CpuSliceVarNdarray, 1d_slice_assign) { std::vector data({1, 2, 3, 4, 5, 6, 7, 8}); std::vector buffer(10, 100); std::vector expected({100, 1, 2, 3, 4, 5, 6, 7, 8, 100}); CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{static_cast(data.size())}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{10LL}, buffer.data()); ASSERT_EQ(buffer_ndarray({1, -1}).xpu_shape(), XpuShape(Shape({8}))); buffer_ndarray({1, -1}).CopyFrom(data_ndarray({})); ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * 10), 0); } TEST(CpuSliceVarNdarray, 1d_slice) { std::vector data({100, 1, 2, 3, 4, 5, 6, 7, 8, 100}); std::vector buffer(8, 100); std::vector expected({1, 2, 3, 4, 5, 6, 7, 8}); CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{static_cast(data.size())}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{static_cast(buffer.size())}, buffer.data()); buffer_ndarray({}).CopyFrom(data_ndarray({1, -1})); ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0); } TEST(CpuSliceVarNdarray, 2d_slice) { // clang-format off std::vector data({ 100, 100, 100, 100, 100, 0, 1, 100, 100, 2, 3, 100, 100, 100, 100, 100, }); // clang-format on std::vector buffer(4, 100); std::vector expected({0, 1, 2, 3}); CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{4LL, 4LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{2LL, 2LL}, buffer.data()); buffer_ndarray({}, {}).CopyFrom(data_ndarray({1, -1}, {1, -1})); ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0); } TEST(CpuSliceVarNdarray, 2d_slice_assign) { std::vector data({0, 1, 2, 3}); std::vector buffer(16, 100); // clang-format off std::vector expected({ 100, 100, 100, 100, 
100, 0, 1, 100, 100, 2, 3, 100, 100, 100, 100, 100, }); // clang-format on CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{2LL, 2LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{4LL, 4LL}, buffer.data()); buffer_ndarray({1, -1}, {1, -1}).CopyFrom(data_ndarray({}, {})); ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0); } TEST(CpuSliceVarNdarray, 2d_slice_reverse) { // clang-format off std::vector data({ 100, 100, 100, 100, 100, 0, 1, 100, 100, 2, 3, 100, 100, 100, 100, 100, }); std::vector buffer(16, 100); std::vector expected({ 100, 100, 100, 100, 100, 2, 3, 100, 100, 0, 1, 100, 100, 100, 100, 100, }); // clang-format on CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{4LL, 4LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{4LL, 4LL}, buffer.data()); buffer_ndarray({1, -1}, {1, -1}).CopyFrom(data_ndarray({-2, 0, -1}, {1, -1})); ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0); } TEST(CpuSliceVarNdarray, 3d_slice) { // clang-format off std::vector data({ 100, 100, 100, 100, 100, 0, 1, 100, 100, 2, 3, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 4, 5, 100, 100, 6, 7, 100, 100, 100, 100, 100, }); std::vector buffer(8, -1); std::vector expected({ 0, 1, 2, 3, 4, 5, 6, 7 }); // clang-format on CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{2LL, 4LL, 4LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{2LL, 2LL, 2LL}, buffer.data()); buffer_ndarray.CopyFrom(data_ndarray({}, {1, -1}, {1, -1})); ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0); } TEST(CpuSliceVarNdarray, 3d_slice_assign) { // clang-format off std::vector data({ 0, 1, 2, 3, 4, 5, 6, 7 }); std::vector buffer(32, 100); std::vector expected({ 100, 100, 100, 100, 100, 0, 1, 100, 100, 2, 3, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 4, 5, 100, 100, 6, 7, 100, 100, 100, 100, 100, }); // clang-format on CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{2LL, 2LL, 2LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{2LL, 4LL, 4LL}, buffer.data()); buffer_ndarray({}, {1, -1}, {1, -1}).CopyFrom(data_ndarray); ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/cpu_var_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_CPU_VAR_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_CPU_VAR_NDARRAY_H_ #include "oneflow/core/ndarray/cpu_ndarray.h" #include "oneflow/core/ndarray/cpu_ndarray_copy.h" namespace oneflow { class Slice; template class CpuSliceVarNdarray; template class CpuVarNdarray : public CpuNdarray { public: CpuVarNdarray(const CpuVarNdarray&) = default; CpuVarNdarray(const Shape& shape, T* ptr) : CpuNdarray(shape), ptr_(ptr), len_(shape.elem_cnt()) { CHECK_GT(len_, 0); } CpuVarNdarray(const ShapeView& shape_view, T* ptr) : CpuNdarray(XpuShape(shape_view)), ptr_(ptr), len_(shape_view.elem_cnt()) { CHECK_GT(len_, 0); } virtual ~CpuVarNdarray() = default; CpuSliceVarNdarray> operator()(Slice&& slice0) { static_assert(NDIMS == 1, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0}); } CpuSliceVarNdarray> operator()(Slice&& slice0, Slice&& slice1) { static_assert(NDIMS == 2, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0, slice1}); } CpuSliceVarNdarray> operator()(Slice&& slice0, Slice&& slice1, Slice&& slice2) { static_assert(NDIMS == 3, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0, slice1, slice2}); } CpuSliceVarNdarray> operator()(Slice&& slice0, Slice&& slice1, Slice&& slice2, Slice&& slice3) { static_assert(NDIMS == 4, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0, slice1, slice2, slice3}); } CpuSliceVarNdarray> operator()(Slice&& slice0, Slice&& slice1, Slice&& slice2, Slice&& slice3, Slice&& slice4) { static_assert(NDIMS == 5, "NDIMS error"); return CpuSliceVarNdarray>(std::move(*this), {slice0, slice1, slice2, slice3, slice4}); } template void CopyFrom(const XT& ndarray) { CpuNdarrayCopy(this, ndarray); } ALWAYS_INLINE void GetMutPtrAndContiguousSize(int64_t offset, T** ptr, size_t* size) const { *ptr = ptr_ + offset; *size = len_ - offset; } protected: ALWAYS_INLINE T* ptr() const { return ptr_; } ALWAYS_INLINE size_t len() const { return len_; } private: T* const ptr_; size_t len_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_CPU_VAR_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/cpu_var_ndarray_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ndarray/cpu_ndarray_builder.h" #include namespace oneflow { namespace test { TEST(CpuVarNdarray, one_elem_assign) { std::vector data({1}); std::vector buffer({0}); CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{1LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{1LL}, buffer.data()); buffer_ndarray.CopyFrom(data_ndarray); ASSERT_EQ(data[0], buffer[0]); } TEST(CpuVarNdarray, 1d_assign) { std::vector data({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); std::vector buffer(10, 0); CpuNdarrayBuilder ndarray; auto&& data_ndarray = ndarray.Var(Shape{10LL}, data.data()); auto&& buffer_ndarray = ndarray.Var(Shape{10LL}, buffer.data()); buffer_ndarray.CopyFrom(data_ndarray); ASSERT_EQ(memcmp(data.data(), buffer.data(), sizeof(int32_t) * 10), 0); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_binary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/ndarray/ndarray_apply_binary_core.h" namespace oneflow { template class binary_func, typename Enable = void> struct NdarrayApplyBinary; template class binary_func> struct NdarrayApplyBinary< device_type, T, binary_func, typename std::enable_if::type>::value>::type> final { static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { NdarrayApplyBinaryCoreWrapper::Apply(stream, y, a, b); } static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { NdarrayApplyBinaryCoreWrapper::InplaceApply(stream, y, x); } }; template class binary_func> struct NdarrayApplyBinary< device_type, T, binary_func, typename std::enable_if::type>::value>::type> final { using NewT = typename DevDType::type; static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { return NdarrayApplyBinary::Apply( stream, reinterpret_cast&>(y), reinterpret_cast&>(a), reinterpret_cast&>(b)); } static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { return NdarrayApplyBinary::InplaceApply( stream, reinterpret_cast&>(y), reinterpret_cast&>(x)); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_apply_binary_core.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/ndarray_apply_binary_core.h" #include "oneflow/core/ndarray/binary_func.h" namespace oneflow { template class binary_func> struct NdarrayApplyBinaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { NdarrayApplyBinaryCore::Apply(y.shape().ElemNum(), y.ptr(), a.ptr(), b.ptr()); } static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { NdarrayApplyBinaryCore::InplaceApply(y.shape().ElemNum(), y.ptr(), x.ptr()); } }; } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_binary_core.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/ndarray_apply_binary_core.h" #include "oneflow/core/ndarray/binary_func.h" namespace oneflow { namespace { template class binary_func> __global__ void NdarrayApplyBinaryApplyGpu(size_t n, typename BinaryFuncTrait::return_type* y, const T* a, const T* b) { NdarrayApplyBinaryCore::Apply(n, y, a, b); } template class binary_func> __global__ void NdarrayApplyBinaryInplaceApplyGpu(size_t n, T* y, const T* x) { NdarrayApplyBinaryCore::InplaceApply(n, y, x); } } // namespace template class binary_func> struct NdarrayApplyBinaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } RUN_CUDA_KERNEL((NdarrayApplyBinaryApplyGpu), stream, n, n, y.host_ptr(), a.host_ptr(), b.host_ptr()); } static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } RUN_CUDA_KERNEL((NdarrayApplyBinaryInplaceApplyGpu), stream, n, n, y.host_ptr(), x.host_ptr()); } }; } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_binary_core.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_CORE_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_CORE_H_ #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ndarray/xpu_binary_func_ndarray.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/ndarray/binary_func.h" namespace oneflow { template class binary_func> struct NdarrayApplyBinaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b); static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x); }; template class binary_func> struct NdarrayApplyBinaryCore final { OF_DEVICE_FUNC static void Apply(size_t n, typename BinaryFuncTrait::return_type* y, const T* a, const T* b) { XPU_1D_KERNEL_LOOP_BEGIN(i, n); y[i] = binary_func::Invoke(a[i], b[i]); XPU_1D_KERNEL_LOOP_END(); } OF_DEVICE_FUNC static void InplaceApply(size_t n, T* y, const T* x) { XPU_1D_KERNEL_LOOP_BEGIN(i, n); y[i] = binary_func::Invoke(y[i], x[i]); XPU_1D_KERNEL_LOOP_END(); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_CORE_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_apply_broadcast_binary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_ #include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h" #include "oneflow/core/ndarray/ndarray_apply_binary.h" #include "oneflow/core/common/util.h" namespace oneflow { template class binary_func, typename Enable = void> struct NdarrayApplyBroadcastBinary; template class binary_func> struct NdarrayApplyBroadcastBinary< device_type, T, binary_func, typename std::enable_if::type>::value>::type> final { using RetT = typename BinaryFuncTrait::return_type; static void Apply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { if (a.shape() == b.shape()) { return NdarrayApplyBinary::Apply(stream, y, a, b); } if (TryInplaceApply::value>(stream, y, a, b)) { return; } CheckBroadcastable(y, a, b); DimVector simplified_y_dim; DimVector simplified_a_dim; DimVector simplified_b_dim; SimplifyBroadcastShapes(y.shape(), a.shape(), b.shape(), &simplified_y_dim, &simplified_a_dim, &simplified_b_dim); return SwitchApply(SwitchCase(simplified_y_dim.size()), stream, XpuVarNdarray(Shape(simplified_y_dim), y.ptr()), XpuVarNdarray(Shape(simplified_a_dim), a.ptr()), XpuVarNdarray(Shape(simplified_b_dim), b.ptr())); } template static typename std::enable_if::type TryInplaceApply( ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { bool is_inplace = (y.shape() == a.shape() && y.ptr() == a.ptr()); if (is_inplace) { InplaceApply(stream, y, b); } return is_inplace; } template static typename std::enable_if::type TryInplaceApply( ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { return false; } static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { if (y.shape() == x.shape()) { return NdarrayApplyBinary::InplaceApply(stream, y, x); } CheckBroadcastable(y, reinterpret_cast&>(y), x); DimVector simplified_y_dim; DimVector simplified_x_dim; SimplifyBroadcastShapes(y.shape(), x.shape(), &simplified_y_dim, &simplified_x_dim); return SwitchInplaceApply(SwitchCase(simplified_y_dim.size()), stream, XpuVarNdarray(Shape(simplified_y_dim), y.ptr()), XpuVarNdarray(Shape(simplified_x_dim), x.ptr())); } private: #define MAKE_NDARRAY_BROADCAST_BINARY(func_name, NDIMS) \ NdarrayApplyBroadcastBinaryCoreWrapper::func_name DEFINE_STATIC_SWITCH_FUNC(void, Apply, MAKE_NDARRAY_BROADCAST_BINARY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ)) #undef MAKE_NDARRAY_BROADCAST_BINARY #define MAKE_NDARRAY_INPLACE_BROADCAST_BINARY(func_name, NDIMS) \ NdarrayApplyBroadcastInplaceBinaryCoreWrapper::func_name DEFINE_STATIC_SWITCH_FUNC(void, InplaceApply, MAKE_NDARRAY_INPLACE_BROADCAST_BINARY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ)) #undef MAKE_NDARRAY_INPLACE_BROADCAST_BINARY static void CheckBroadcastable( const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { CHECK_EQ(y.shape().NumAxes(), a.shape().NumAxes()); CHECK_EQ(y.shape().NumAxes(), b.shape().NumAxes()); for (int i = 0; i < y.shape().NumAxes(); ++i) { CHECK_EQ(y.shape().At(i), (a.shape().At(i) == 0 || b.shape().At(i) == 0) ? 
0 : std::max(a.shape().At(i), b.shape().At(i))); if (a.shape().At(i) != b.shape().At(i)) { CHECK(a.shape().At(i) == 1 || b.shape().At(i) == 1); } } } }; template class binary_func> struct NdarrayApplyBroadcastBinary< device_type, T, binary_func, typename std::enable_if::type>::value>::type> final { using NewT = typename DevDType::type; static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { return NdarrayApplyBroadcastBinary::Apply( stream, reinterpret_cast&>(y), reinterpret_cast&>(a), reinterpret_cast&>(b)); } static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { return NdarrayApplyBroadcastBinary::InplaceApply( stream, reinterpret_cast&>(y), reinterpret_cast&>(x)); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h" namespace oneflow { template class binary_func> struct NdarrayApplyBroadcastBinaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { NdarrayApplyBroadcastBinaryCore::Apply(y, a, b); } }; template class binary_func> struct NdarrayApplyBroadcastInplaceBinaryCoreWrapper final { static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { NdarrayApplyBroadcastBinaryCore::InplaceApply(y, x); } }; } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h" namespace oneflow { namespace { template struct XY2XFunctor final { __host__ __device__ XY2XFunctor(Index dim_y) : dim_y_(dim_y) {} __host__ __device__ Index operator()(Index idx) const { return idx / dim_y_; } Index dim_y_; }; template struct XY2YFunctor final { __host__ __device__ XY2YFunctor(Index dim_y) : dim_y_(dim_y) {} __host__ __device__ Index operator()(Index idx) const { return idx % dim_y_; } Index dim_y_; }; template struct XYZ2XZFunctor final { __host__ __device__ XYZ2XZFunctor(Index dim_y, Index dim_z) : dim_yz_(dim_y * dim_z), dim_z_(dim_z) {} __host__ __device__ Index operator()(Index idx) const { const Index x = idx / dim_yz_; const Index z = (idx % dim_yz_) % dim_z_; return x * dim_z_ + z; } Index dim_yz_; Index dim_z_; }; template struct XYZ2YFunctor final { __host__ __device__ XYZ2YFunctor(Index dim_y, Index dim_z) : dim_yz_(dim_y * dim_z), dim_z_(dim_z) {} __host__ __device__ Index operator()(Index idx) const { return (idx % dim_yz_) / dim_z_; } Index dim_yz_; Index dim_z_; }; template class binary_func, typename OffsetFunctor> __global__ void PartialBroadcastGpu(K n, typename BinaryFuncTrait::return_type* y, const T* a, const T* b, OffsetFunctor offset_functor) { CUDA_1D_KERNEL_LOOP_T(K, i, n) { y[i] = binary_func::Invoke(a[i], b[offset_functor(i)]); } } template class binary_func> __global__ void GpuBroadcastBinaryFunc( const XpuVarNdarray::return_type> y, const XpuVarNdarray a, const XpuVarNdarray b) { NdarrayApplyBroadcastBinaryCore::Apply(y, a, b); } template class binary_func> __global__ void GpuInplaceBroadcastBinaryFunc(const XpuVarNdarray y, const XpuVarNdarray x) { NdarrayApplyBroadcastBinaryCore::InplaceApply(y, x); } } // namespace template class binary_func> struct NdarrayApplyBroadcastBinaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } if (IsKernelSafeInt32(n) && PartialBroadcast(stream, y, a, b)) { return; } if (!IsKernelSafeInt32(n) && PartialBroadcast(stream, y, a, b)) { return; } RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc), stream, n, y, a, b); } template static bool PartialBroadcast( ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { size_t n = y.host_shape().HostElemNum(); if (y.host_shape() == a.host_shape()) { if (y.host_shape().NumAxes() == 2) { const K y_dim0 = y.host_shape().At(0); const K y_dim1 = y.host_shape().At(1); const K b_dim0 = b.host_shape().At(0); const K b_dim1 = b.host_shape().At(1); if (b_dim0 == y_dim0 && b_dim1 == 1) { XY2XFunctor xy2x(y_dim1); RUN_CUDA_KERNEL((PartialBroadcastGpu>), stream, n, n, y.host_ptr(), a.host_ptr(), b.host_ptr(), xy2x); return true; } if (b_dim0 == 1 && b_dim1 == y_dim1) { XY2YFunctor xy2y(y_dim1); RUN_CUDA_KERNEL((PartialBroadcastGpu>), stream, n, n, y.host_ptr(), a.host_ptr(), b.host_ptr(), xy2y); return true; } } if (y.host_shape().NumAxes() == 3) { const K y_dim0 = y.host_shape().At(0); const K y_dim1 = y.host_shape().At(1); const K y_dim2 = y.host_shape().At(2); const K b_dim0 = b.host_shape().At(0); const K b_dim1 = b.host_shape().At(1); const K b_dim2 = b.host_shape().At(2); if (b_dim0 == y_dim0 && b_dim1 == 1 && b_dim2 == y_dim2) { XYZ2XZFunctor xyz2xz(y_dim1, y_dim2); RUN_CUDA_KERNEL((PartialBroadcastGpu>), stream, n, n, y.host_ptr(), a.host_ptr(), b.host_ptr(), xyz2xz); return true; } if (b_dim0 == 1 && 
b_dim1 == y_dim1 && b_dim2 == 1) { XYZ2YFunctor xyz2y(y_dim1, y_dim2); RUN_CUDA_KERNEL((PartialBroadcastGpu>), stream, n, n, y.host_ptr(), a.host_ptr(), b.host_ptr(), xyz2y); return true; } } } return false; } }; template class binary_func> struct NdarrayApplyBroadcastInplaceBinaryCoreWrapper final { static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); XpuVarNdarray a(y.host_shape(), y.host_ptr()); using NBB = NdarrayApplyBroadcastBinaryCoreWrapper; if (n == 0) { return; } if (IsKernelSafeInt32(n) && NBB::template PartialBroadcast(stream, y, a, x)) { return; } if (!IsKernelSafeInt32(n) && NBB::template PartialBroadcast(stream, y, a, x)) { return; } RUN_CUDA_KERNEL((GpuInplaceBroadcastBinaryFunc), stream, n, y, x); } }; } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_ #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ndarray/xpu_broadcast_ndarray.h" #include "oneflow/core/ndarray/xpu_binary_func_ndarray.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template class binary_func> struct NdarrayApplyBroadcastBinaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b); }; template class binary_func> struct NdarrayApplyBroadcastInplaceBinaryCoreWrapper final { static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x); }; template class binary_func> struct NdarrayApplyBroadcastBinaryCore final { OF_DEVICE_FUNC static void Apply( const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { const auto& ret = a.Broadcast(y.shape()).template BinaryFunc(b.Broadcast(y.shape())); y.template Assign(ret); } OF_DEVICE_FUNC static void InplaceApply(const XpuVarNdarray& y, const XpuVarNdarray& x) { y.template BinaryAssign(x.Broadcast(y.shape())); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_apply_broadcast_unary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_ #include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h" #include "oneflow/core/common/util.h" namespace oneflow { template class unary_func, typename Enable = void> struct NdarrayApplyBroadcastUnary; template class unary_func> struct NdarrayApplyBroadcastUnary< device_type, T, unary_func, typename std::enable_if::type>::value>::type> final { static void Apply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { CheckBroadcastable(y, x); DimVector simplified_y_dim; DimVector simplified_x_dim; SimplifyBroadcastShapes(y.shape(), x.shape(), &simplified_y_dim, &simplified_x_dim); SwitchApply(SwitchCase(simplified_y_dim.size()), stream, XpuVarNdarray(Shape(simplified_y_dim), y.ptr()), XpuVarNdarray(Shape(simplified_x_dim), x.ptr())); } private: #define DEFINE_NDARRAY_BROADCAST_UNARY(func_name, NDIMS) \ NdarrayApplyBroadcastUnaryCoreWrapper::func_name DEFINE_STATIC_SWITCH_FUNC(void, Apply, DEFINE_NDARRAY_BROADCAST_UNARY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ)) #undef DEFINE_NDARRAY_BROADCAST_UNARY static void CheckBroadcastable(const XpuVarNdarray& y, const XpuVarNdarray& x) { CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes()); for (int i = 0; i < y.shape().NumAxes(); ++i) { CHECK(x.shape().At(i) == 1 || x.shape().At(i) == y.shape().At(i)); } } }; template class unary_func> struct NdarrayApplyBroadcastUnary< device_type, T, unary_func, typename std::enable_if::type>::value>::type> final { static void Apply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { using NewT = typename DevDType::type; return NdarrayApplyBroadcastUnary::Apply( stream, reinterpret_cast&>(y), reinterpret_cast&>(x)); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h" namespace oneflow { template class unary_func> struct NdarrayApplyBroadcastUnaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { NdarrayApplyBroadcastUnaryCore::Apply(y, x); } }; #define INSTANTIATE_BROADCAST_UNARY_FUNC(dtype_pair, NDIMS, unary_func) \ template struct NdarrayApplyBroadcastUnaryCoreWrapper< \ DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, unary_func>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_UNARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ, DIM_SEQ, ARITHMETIC_UNARY_FUNC_SEQ) } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" namespace oneflow { namespace { template class unary_func> __global__ void GpuBroadcastUnaryFunc(const XpuVarNdarray y, const XpuVarNdarray x) { NdarrayApplyBroadcastUnaryCore::Apply(y, x); } } // namespace template class unary_func> struct NdarrayApplyBroadcastUnaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc), stream, n, y, x); } }; #define INSTANTIATE_BROADCAST_UNARY_FUNC(dtype_pair, NDIMS, unary_func) \ template struct NdarrayApplyBroadcastUnaryCoreWrapper< \ DeviceType::kCUDA, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, unary_func>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_UNARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, DIM_SEQ, ARITHMETIC_UNARY_FUNC_SEQ) } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_ #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ndarray/xpu_broadcast_ndarray.h" #include "oneflow/core/ndarray/xpu_unary_func_ndarray.h" #include "oneflow/core/ndarray/unary_func.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template class unary_func> struct NdarrayApplyBroadcastUnaryCoreWrapper final { static void Apply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x); }; template class unary_func> struct NdarrayApplyBroadcastUnaryCore final { OF_DEVICE_FUNC static void Apply(const XpuVarNdarray& y, const XpuVarNdarray& x) { y.template Assign(x.Broadcast(y.shape()).template UnaryFunc()); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_apply_unary.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/ndarray/ndarray_apply_unary_core.h" namespace oneflow { template class unary_func, typename Enable = void> struct NdarrayApplyUnary; template class unary_func> struct NdarrayApplyUnary< device_type, T, unary_func, typename std::enable_if::type>::value>::type> final { static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y) { NdarrayApplyUnaryCoreWrapper::InplaceApply(stream, y); } }; template class unary_func> struct NdarrayApplyUnary< device_type, T, unary_func, typename std::enable_if::type>::value>::type> final { static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y) { using NewT = typename DevDType::type; return NdarrayApplyUnary::InplaceApply( stream, reinterpret_cast&>(y)); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_apply_unary_core.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ndarray/ndarray_apply_unary_core.h" #include "oneflow/core/ndarray/unary_func.h" namespace oneflow { template class unary_func> struct NdarrayApplyUnaryCoreWrapper final { static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y) { NdarrayApplyUnaryCore::InplaceApply(y.ptr(), y.shape().ElemNum()); } }; } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_unary_core.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/ndarray_apply_unary_core.h" #include "oneflow/core/ndarray/unary_func.h" namespace oneflow { namespace { template class unary_func> __global__ void NdarrayApplyUnaryInplaceApplyGpu(T* ptr, size_t n) { NdarrayApplyUnaryCore::InplaceApply(ptr, n); } } // namespace template class unary_func> struct NdarrayApplyUnaryCoreWrapper final { static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } RUN_CUDA_KERNEL((NdarrayApplyUnaryInplaceApplyGpu), stream, n, y.host_ptr(), n); } }; } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_apply_unary_core.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_ #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ndarray/xpu_unary_func_ndarray.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ndarray/xpu_util.h" namespace oneflow { template class unary_func> struct NdarrayApplyUnaryCoreWrapper final { static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y); }; template class unary_func> struct NdarrayApplyUnaryCore final { OF_DEVICE_FUNC static void InplaceApply(T* y, size_t n) { XPU_1D_KERNEL_LOOP(i, n) { y[i] = unary_func::Invoke(y[i]); } } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_assign_core.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/ndarray_assign_core.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct NdarrayAssignCoreWrapper final { static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuReducedNdarray& reduced) { NdarrayAssignCore::Assign(y, reduced); } static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { NdarrayAssignCore::Assign(y, x); } }; #define INSTANTIATE_NDARRAY_ASSIGN(ret_dtype_pair, dtype_pair, NDIMS) \ template struct NdarrayAssignCoreWrapper; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, DIM_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, COMPLEX_DATA_TYPE_SEQ, COMPLEX_DATA_TYPE_SEQ, DIM_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_assign_core.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/ndarray/ndarray_assign_core.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" namespace oneflow { namespace { template __global__ void NdarrayAssignReducedGpu(XpuVarNdarray y, const XpuReducedNdarray reduced) { NdarrayAssignCore::Assign(y, reduced); } template __global__ void NdarrayAssignGpu(XpuVarNdarray y, const XpuVarNdarray x) { NdarrayAssignCore::Assign(y, x); } } // namespace template struct NdarrayAssignCoreWrapper final { static void Assign(ep::Stream* ctx, const XpuVarNdarray& y, const XpuReducedNdarray& reduced) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } RUN_CUDA_KERNEL((NdarrayAssignReducedGpu), ctx, n, y, reduced); } static void Assign(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } RUN_CUDA_KERNEL((NdarrayAssignGpu), ctx, n, y, x); } }; #define INSTANTIATE_NDARRAY_ASSIGN(ret_dtype_pair, dtype_pair, NDIMS) \ template struct NdarrayAssignCoreWrapper; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, DIM_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, HALF_DATA_TYPE_SEQ, HALF_DATA_TYPE_SEQ, DIM_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, CUDA_PRIMITIVE_COMPLEX64_TYPE_SEQ, CUDA_PRIMITIVE_COMPLEX64_TYPE_SEQ, DIM_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, CUDA_PRIMITIVE_COMPLEX128_TYPE_SEQ, CUDA_PRIMITIVE_COMPLEX128_TYPE_SEQ, DIM_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_assign_core.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_ #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ndarray/xpu_reduced_ndarray.h" namespace oneflow { template struct NdarrayAssignCoreWrapper final { static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuReducedNdarray& reduced); static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x); }; template struct NdarrayAssignCore final { OF_DEVICE_FUNC static void Assign(const XpuVarNdarray& y, const XpuReducedNdarray& reduced) { y.template Assign(reduced); } OF_DEVICE_FUNC static void Assign(const XpuVarNdarray& y, const XpuVarNdarray& x) { y.template Assign(x); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_reduce.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/ndarray/ndarray_reduce_impl.h" namespace oneflow { template class binary_func, typename Enable = void> struct NdarrayReduce; template class binary_func> struct NdarrayReduce< device_type, T, binary_func, typename std::enable_if::type>::value>::type> final { using RetT = typename BinaryFuncTrait::return_type; static void Reduce(ep::Stream* stream, const XpuVarNdarray& origin_y, const XpuVarNdarray& origin_x, const XpuVarNdarray& tmp_storage) { DimVector simplified_x_dim; DimVector simplified_y_dim; TrySimplifyDims(origin_x.shape(), origin_y.shape(), &simplified_x_dim, &simplified_y_dim); XpuVarNdarray y(Shape(simplified_y_dim), origin_y.ptr()); XpuVarNdarray x(Shape(simplified_x_dim), origin_x.ptr()); CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes()); if (NdarrayNoReduce::Matched(y, x)) { NdarrayNoReduce::Reduce(stream, y, x, tmp_storage); } else if (NdarrayScalarReduce::Matched(y, x)) { NdarrayScalarReduce::Reduce(stream, y, x, tmp_storage); } else if (NdarrayMatrixRowReduce::Matched(y, x)) { NdarrayMatrixRowReduce::Reduce(stream, y, x, tmp_storage); } else if (NdarrayMatrixColReduce::Matched(y, x)) { NdarrayMatrixColReduce::Reduce(stream, y, x, tmp_storage); } else if (NdarrayXYZCubeXZReduce::Matched(y, x)) { NdarrayXYZCubeXZReduce::Reduce(stream, y, x, tmp_storage); } else { NdarrayDefaultReduce::Reduce(stream, y, x, tmp_storage); } } static void TrySimplifyDims(const XpuShape& x, const XpuShape& y, DimVector* simplified_x, DimVector* simplified_y) { CHECK_EQ(y.NumAxes(), x.NumAxes()); CHECK(y.At(0) == 1 || y.At(0) == x.At(0)); CHECK(simplified_x->empty()); CHECK(simplified_y->empty()); simplified_x->emplace_back(x.At(0)); simplified_y->emplace_back(y.At(0)); bool prev_axis_is_reduced = (y.At(0) == 1); FOR_RANGE(int, i, 1, x.NumAxes()) { const int64_t x_dim = x.At(i); const int64_t y_dim = y.At(i); const bool cur_axis_is_reduced = (y_dim == 1); CHECK(cur_axis_is_reduced || y_dim == x_dim); if (cur_axis_is_reduced == prev_axis_is_reduced) { simplified_x->back() *= x_dim; simplified_y->back() *= y_dim; } else { simplified_x->emplace_back(x_dim); simplified_y->emplace_back(y_dim); } prev_axis_is_reduced = cur_axis_is_reduced; } } }; template class binary_func> struct NdarrayReduce< device_type, T, binary_func, typename std::enable_if::type>::value>::type> final { static void Reduce(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { using NewT = typename DevDType::type; return NdarrayReduce::Reduce( stream, reinterpret_cast&>(y), reinterpret_cast&>(x), reinterpret_cast&>(tmp_storage)); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_reduce_impl.cpp 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/ndarray/ndarray_reduce_impl.h" #include "oneflow/core/ndarray/binary_func.h" namespace oneflow { #define SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(struct_name) \ template class binary_func> \ struct struct_name final { \ using RetT = typename BinaryFuncTrait::return_type; \ static bool Matched(const XpuVarNdarray& y, const XpuVarNdarray& x) { \ return false; \ } \ static void Reduce(ep::Stream* stream, const XpuVarNdarray& y, \ const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { \ UNIMPLEMENTED(); \ } \ } SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayScalarReduce); SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayMatrixRowReduce); SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayMatrixColReduce); SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayXYZCubeXZReduce); #undef SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL #define INSTANTIATE_NDARRAY_REDUCE_IMPL(dtype, binary_func) \ template struct NdarrayScalarReduce; \ template struct NdarrayMatrixRowReduce; \ template struct NdarrayMatrixColReduce; \ template struct NdarrayXYZCubeXZReduce; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, REDUCE_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, FLOATING_DATA_TYPE_SEQ, NANSUM_REDUCE_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, COMPLEX_DATA_TYPE_SEQ, REDUCE_BINARY_FUNC_SEQ); template class binary_func> struct NdarrayReduceCoreWrapper final { static void ReduceAxis(ep::Stream* stream, const XpuReducedNdarray& dst_reduced, const XpuReducedNdarray& x, int axis) { NdarrayReduceCore::ReduceAxis(dst_reduced, x, axis); } }; #define INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER(dtype_pair, NDIMS, binary_func) \ template struct NdarrayReduceCoreWrapper; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, DIM_SEQ, REDUCE_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, COMPLEX_DATA_TYPE_SEQ, DIM_SEQ, REDUCE_COMPLEX_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, FLOATING_DATA_TYPE_SEQ, DIM_SEQ, NANSUM_REDUCE_BINARY_FUNC_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_reduce_impl.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
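// ---------------------------------------------------------------------------
// Illustrative sketch (editorial addition, not a repository file): the
// TrySimplifyDims step in NdarrayReduce merges runs of adjacent axes that are
// either all reduced (extent 1 in y) or all kept, so the specialized reducers
// see as few dimensions as possible. SimplifyDims and the std::vector types
// below are stand-ins for the XpuShape/DimVector code above.
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>
#include <vector>

void SimplifyDims(const std::vector<int64_t>& x, const std::vector<int64_t>& y,
                  std::vector<int64_t>* sx, std::vector<int64_t>* sy) {
  assert(!x.empty() && x.size() == y.size());
  sx->push_back(x[0]);
  sy->push_back(y[0]);
  bool prev_reduced = (y[0] == 1);
  for (size_t i = 1; i < x.size(); ++i) {
    const bool cur_reduced = (y[i] == 1);
    if (cur_reduced == prev_reduced) {
      // Same reduce/keep status as the previous axis: fold it in.
      sx->back() *= x[i];
      sy->back() *= y[i];
    } else {
      // Status flips: start a new simplified axis.
      sx->push_back(x[i]);
      sy->push_back(y[i]);
    }
    prev_reduced = cur_reduced;
  }
}

// Example: reducing axes 0 and 3 of a (2, 3, 4, 5) tensor.
// x = {2, 3, 4, 5}, y = {1, 3, 4, 1}  ->  sx = {2, 12, 5}, sy = {1, 12, 1},
// which then matches the (reduce, keep, reduce) NdarrayXYZCubeXZReduce path.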
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/kernel/util/numerics.cuh" #include "oneflow/core/ndarray/ndarray_reduce_impl.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/permutation_iterator.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h" namespace cub { struct Prod { template __host__ __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a * b; } }; struct Any { template __host__ __device__ __forceinline__ T operator()(const T& a, const U& b) const { return a || b; } }; struct All { template __host__ __device__ __forceinline__ T operator()(const T& a, const U& b) const { return a && b; } }; struct NanSum { template __host__ __device__ __forceinline__ T operator()(const T& a, const T& b) const { if (oneflow::detail::numerics::isnan(a)) return oneflow::detail::numerics::isnan(b) ? T{0} : b; return oneflow::detail::numerics::isnan(b) ? a : a + b; } }; } // namespace cub __host__ __device__ __forceinline__ cuComplex operator+(const cuComplex& a, const cuComplex& b) { return cuComplex{a.x + b.x, a.y + b.y}; } __host__ __device__ __forceinline__ cuDoubleComplex operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) { return cuDoubleComplex{a.x + b.x, a.y + b.y}; } namespace oneflow { namespace { template class R, typename T, typename K, typename RetT> __global__ void MatrixColReduceBy1ThreadPerColumn(K num_elems, K num_cols, const T* in, RetT* out) { CUDA_1D_KERNEL_LOOP_T(K, j, num_cols) { K index = j; T sum = in[index]; for (index += num_cols; index < num_elems; index += num_cols) { sum = R::Invoke(sum, in[index]); } out[j] = sum; } } template struct WithAlign2 { union { T value; int32_t padding; }; }; template class R, typename T, typename K, typename RetT> __global__ void MatrixColReduceByWarpBlock(K num_elems, K num_cols, const T* in, RetT* out) { const K thread_col = threadIdx.x % kCudaWarpSize; const K thread_row = threadIdx.x / kCudaWarpSize; const K thread_dim_row = blockDim.x / kCudaWarpSize; const K num_valid_threads = thread_dim_row * num_cols; // ASSERT: always <= num_elems const K col = blockIdx.x * kCudaWarpSize + thread_col; __shared__ WithAlign2 partial_values[kCudaWarpSize * kCudaWarpSize]; if (col < num_cols) { K index = thread_row * num_cols + col; T val = in[index]; for (index += num_valid_threads; index < num_elems; index += num_valid_threads) { val = R::Invoke(val, in[index]); } partial_values[threadIdx.x].value = val; } __syncthreads(); if (col < num_cols && thread_row == 0) { int index = thread_col; T val = partial_values[index].value; for (index += kCudaWarpSize; index < blockDim.x; index += kCudaWarpSize) { val = R::Invoke(val, partial_values[index].value); } out[col] = val; } } template class R, typename T, typename K, typename RetT> void MatrixColReduceBy1BlockLayer(ep::Stream* stream, K num_elems, K num_cols, const T* in, RetT* out) { CHECK_LE(num_cols, kCudaMaxBlocksNum * kCudaWarpSize); const K num_rows = num_elems / num_cols; CHECK_GT(num_rows, 0); if (num_rows < kCudaWarpSize) { 
RUN_CUDA_KERNEL((MatrixColReduceBy1ThreadPerColumn), stream, num_cols, num_elems, num_cols, in, out); } else { const int num_blocks = (num_cols + kCudaWarpSize - 1) / kCudaWarpSize; const int num_threads = kCudaWarpSize * kCudaWarpSize; auto Reduce = &MatrixColReduceByWarpBlock; Reduce<<As()->cuda_stream()>>>( num_elems, num_cols, in, out); } } const static int32_t kNumRows4OneBlockLayer = kCudaWarpSize * kCudaWarpSize; const static int32_t kNumCols4OneBlockLayer = kCudaMaxBlocksNum * kCudaWarpSize / 2; template class R, typename T, typename K> void MatrixColReduceK(ep::Stream* stream, K num_rows, K num_cols, const T* in, typename BinaryFuncTrait::return_type* out, T* tmp) { K num_elems = num_rows * num_cols; if (num_rows < kNumRows4OneBlockLayer || num_cols > kNumCols4OneBlockLayer) { MatrixColReduceBy1BlockLayer::return_type>( stream, num_elems, num_cols, in, out); } else { int scale_shift = 1; for (; true; ++scale_shift) { if ((num_rows >> scale_shift) < kNumRows4OneBlockLayer) { break; } if ((num_cols << scale_shift) > kNumCols4OneBlockLayer) { break; } } MatrixColReduceBy1BlockLayer(stream, num_elems, (num_cols << scale_shift), in, tmp); // recursively calls MatrixColReduceK(...) log32(num_rows) times at most MatrixColReduceK(stream, (1 << scale_shift), num_cols, tmp, out, tmp); } } template class R, typename T> void MatrixColReduce(ep::Stream* stream, int64_t num_rows, int64_t num_cols, const T* in, typename BinaryFuncTrait::return_type* out, T* tmp) { if (IsKernelSafeInt32(num_rows * num_cols)) { return MatrixColReduceK(stream, num_rows, num_cols, in, out, tmp); } else { return MatrixColReduceK(stream, num_rows, num_cols, in, out, tmp); } } } // namespace template class binary_func> struct CubFunctor4BianryFunc; #define SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(func_name) \ template \ struct CubFunctor4BianryFunc final { \ using type = cub::func_name; \ }; OF_PP_FOR_EACH_ATOMIC(SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC, REDUCE_BINARY_FUNC_NAME_SEQ(NanSum)); #undef SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC struct RowOffsetFunctor final { OF_DEVICE_FUNC explicit RowOffsetFunctor(int32_t num_cols) : num_cols_(num_cols) {} OF_DEVICE_FUNC int32_t operator()(const int32_t& x) const { return x * num_cols_; } int32_t num_cols_; }; template class binary_func> struct NdarrayScalarReduce final { using RetT = typename BinaryFuncTrait::return_type; static bool Matched(const XpuVarNdarray& y, const XpuVarNdarray& x) { return y.shape().ElemNum() == 1; } static void Reduce(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { CHECK(Matched(y, x)); size_t x_size = x.shape().ElemNum(); size_t tmp_storage_bytes = 0; auto DoReduce = [&](T* tmp_storage_ptr) { int retcode = cub::DeviceReduce::Reduce( tmp_storage_ptr, tmp_storage_bytes, x.ptr(), y.ptr(), x_size, typename CubFunctor4BianryFunc::type(), UnitOfBinaryFunc::Val(), stream->As()->cuda_stream()); CHECK_EQ(retcode, 0) << "cub::DeviceSegmentedReduce::Reduce error"; }; DoReduce(nullptr); CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes); DoReduce(tmp_storage.ptr()); } }; template class binary_func> struct NdarrayMatrixRowReduce final { using RetT = typename BinaryFuncTrait::return_type; static bool Matched(const XpuVarNdarray& y, const XpuVarNdarray& x) { if (y.shape().ElemNum() > GetMaxVal()) { return false; } if (x.shape().NumAxes() != 2) { return false; } if (y.shape().NumAxes() != 2) { return false; } return x.shape().At(0) == y.shape().At(0) && y.shape().At(1) == 1; } static void 
Reduce(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { CHECK(Matched(y, x)); int32_t num_rows = y.shape().ElemNum(); int32_t num_cols = x.shape().ElemNum() / y.shape().ElemNum(); RowOffsetFunctor get_row_offset(num_cols); cub::CountingInputIterator counting_intput_it(0); cub::TransformInputIterator> transform_input_iter(counting_intput_it, get_row_offset); size_t tmp_storage_bytes = 0; auto DoReduce = [&](T* tmp_storage_ptr) { int retcode = cub::DeviceSegmentedReduce::Reduce( tmp_storage_ptr, tmp_storage_bytes, x.ptr(), y.ptr(), num_rows, transform_input_iter, transform_input_iter + 1, typename CubFunctor4BianryFunc::type(), UnitOfBinaryFunc::Val(), stream->As()->cuda_stream()); CHECK_EQ(retcode, 0) << "cub::DeviceSegmentedReduce::Reduce error"; }; DoReduce(nullptr); CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes); DoReduce(tmp_storage.ptr()); } }; template class binary_func> struct NdarrayMatrixColReduce final { using RetT = typename BinaryFuncTrait::return_type; static bool Matched(const XpuVarNdarray& y, const XpuVarNdarray& x) { if (y.shape().ElemNum() > GetMaxVal()) { return false; } if (x.shape().NumAxes() != 2) { return false; } if (y.shape().NumAxes() != 2) { return false; } return y.shape().At(0) == 1 && x.shape().At(1) == y.shape().At(1); } struct XY2YXFunctor final { __host__ __device__ XY2YXFunctor(int32_t dim_x, int32_t dim_y) : dim_x_(dim_x), dim_y_(dim_y) {} __host__ __device__ int32_t operator()(const int32_t& idx) const { const int32_t y = idx / dim_x_; const int32_t x = idx % dim_x_; return x * dim_y_ + y; } int32_t dim_x_; int32_t dim_y_; }; static void Reduce(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { CHECK(Matched(y, x)); int64_t num_rows = x.shape().At(0); int64_t num_cols = x.shape().At(1); if (num_cols < kNumCols4OneBlockLayer) { return MatrixColReduce(stream, num_rows, num_cols, x.host_ptr(), y.host_ptr(), tmp_storage.host_ptr()); } RowOffsetFunctor get_row_offset(num_rows); cub::CountingInputIterator counting_intput_it(0); cub::TransformInputIterator> transform_input_iter(counting_intput_it, get_row_offset); XY2YXFunctor xy2yx(x.shape().At(0), x.shape().At(1)); using XY2YxIndexIter = cub::TransformInputIterator>; XY2YxIndexIter xy2yx_iter(counting_intput_it, xy2yx); PermutationIterator x_iter(x.ptr(), xy2yx_iter); size_t tmp_storage_bytes = 0; auto DoReduce = [&](T* tmp_storage_ptr) { int retcode = cub::DeviceSegmentedReduce::Reduce( tmp_storage_ptr, tmp_storage_bytes, x_iter, y.ptr(), num_cols, transform_input_iter, transform_input_iter + 1, typename CubFunctor4BianryFunc::type(), UnitOfBinaryFunc::Val(), stream->As()->cuda_stream()); CHECK_EQ(retcode, 0) << "cub::DeviceSegmentedReduce::Reduce error"; }; DoReduce(nullptr); CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes); DoReduce(tmp_storage.ptr()); } }; template class binary_func> struct NdarrayXYZCubeXZReduce final { using RetT = typename BinaryFuncTrait::return_type; static bool Matched(const XpuVarNdarray& y, const XpuVarNdarray& x) { if (y.shape().ElemNum() > GetMaxVal()) { return false; } if (x.shape().NumAxes() != 3) { return false; } if (y.shape().NumAxes() != 3) { return false; } return y.shape().At(0) == 1 && x.shape().At(1) == y.shape().At(1) && y.shape().At(2) == 1; } struct XYZ2YxzFunctor final { __host__ __device__ XYZ2YxzFunctor(int32_t dim_x, int32_t dim_y, int32_t dim_z) : dim_z_(dim_z), dim_xz_(dim_x * dim_z), dim_yz_(dim_y * dim_z) {} 
__host__ __device__ int32_t operator()(const int32_t& idx) const { const int32_t y = idx / dim_xz_; const int32_t xz_idx = idx % dim_xz_; const int32_t x = xz_idx / dim_z_; const int32_t z = xz_idx % dim_z_; return x * dim_yz_ + y * dim_z_ + z; } int32_t dim_z_; int32_t dim_xz_; int32_t dim_yz_; }; static void Reduce(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { CHECK(Matched(y, x)); int32_t num_rows = y.shape().ElemNum(); int32_t num_cols = x.shape().ElemNum() / y.shape().ElemNum(); RowOffsetFunctor get_row_offset(num_cols); cub::CountingInputIterator counting_intput_it(0); cub::TransformInputIterator> transform_input_iter(counting_intput_it, get_row_offset); XYZ2YxzFunctor xyz2yxz(x.shape().At(0), x.shape().At(1), x.shape().At(2)); using XYZ2YxzIndexIter = cub::TransformInputIterator>; XYZ2YxzIndexIter xyz2yxz_iter(counting_intput_it, xyz2yxz); PermutationIterator x_iter(x.ptr(), xyz2yxz_iter); size_t tmp_storage_bytes = 0; auto DoReduce = [&](T* tmp_storage_ptr) { int retcode = cub::DeviceSegmentedReduce::Reduce( tmp_storage_ptr, tmp_storage_bytes, x_iter, y.ptr(), num_rows, transform_input_iter, transform_input_iter + 1, typename CubFunctor4BianryFunc::type(), UnitOfBinaryFunc::Val(), stream->As()->cuda_stream()); CHECK_EQ(retcode, 0) << "cub::DeviceSegmentedReduce::Reduce error"; }; DoReduce(nullptr); CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes); DoReduce(tmp_storage.ptr()); } }; namespace { template class binary_func> __global__ void NdarrayReduceGpuInplaceReduceAxis(const XpuReducedNdarray dst_reduced, const XpuReducedNdarray x, int axis) { NdarrayReduceCore::ReduceAxis(dst_reduced, x, axis); } } // namespace template class binary_func> struct NdarrayReduceCoreWrapper final { static void ReduceAxis(ep::Stream* stream, const XpuReducedNdarray& dst_reduced, const XpuReducedNdarray& x, int axis) { size_t n = x.host_shape().HostElemNum(); RUN_CUDA_KERNEL((NdarrayReduceGpuInplaceReduceAxis), stream, n, dst_reduced, x, axis); } }; #define INSTANTIATE_NDARRAY_REDUCE_IMPL(dtype, binary_func) \ template struct NdarrayScalarReduce; \ template struct NdarrayMatrixRowReduce; \ template struct NdarrayMatrixColReduce; \ template struct NdarrayXYZCubeXZReduce; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, ARITHMETIC_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, ARITHMETIC_REDUCE_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, FLOATING_DATA_TYPE_SEQ, NANSUM_REDUCE_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, LOGICAL_REDUCE_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, REDUCE_COMPLEX_BINARY_FUNC_SEQ); #define INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER(dtype_pair, NDIMS, binary_func) \ template struct NdarrayReduceCoreWrapper; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, ARITHMETIC_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, DIM_SEQ, ARITHMETIC_REDUCE_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, FLOATING_DATA_TYPE_SEQ, DIM_SEQ, NANSUM_REDUCE_BINARY_FUNC_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, DIM_SEQ, LOGICAL_REDUCE_BINARY_FUNC_SEQ); 
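// ---------------------------------------------------------------------------
// Illustrative sketch (editorial addition, not a repository file): the
// XYZ2YxzFunctor above maps a linear index taken in (Y, X, Z) order back to
// an offset in the source (X, Y, Z) layout, so each kept Y slice becomes one
// contiguous segment for cub::DeviceSegmentedReduce via the
// PermutationIterator. A host-only mirror of that index arithmetic:
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>

struct Xyz2YxzIndex {
  Xyz2YxzIndex(int32_t dim_x, int32_t dim_y, int32_t dim_z)
      : dim_z_(dim_z), dim_xz_(dim_x * dim_z), dim_yz_(dim_y * dim_z) {}
  int32_t operator()(int32_t idx) const {
    const int32_t y = idx / dim_xz_;      // which Y slice of the permuted view
    const int32_t xz = idx % dim_xz_;     // position inside that slice
    const int32_t x = xz / dim_z_;
    const int32_t z = xz % dim_z_;
    return x * dim_yz_ + y * dim_z_ + z;  // offset in the original (X, Y, Z) layout
  }
  int32_t dim_z_;
  int32_t dim_xz_;
  int32_t dim_yz_;
};

int main() {
  // Shape (X=2, Y=3, Z=4): element 14 of the permuted view is (y=1, x=1, z=2),
  // which lives at source offset 1*12 + 1*4 + 2 = 18.
  Xyz2YxzIndex to_src(/*dim_x=*/2, /*dim_y=*/3, /*dim_z=*/4);
  assert(to_src(14) == 18);
  return 0;
}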
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, DIM_SEQ, REDUCE_COMPLEX_BINARY_FUNC_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/ndarray_reduce_impl.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_ #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/ndarray/xpu_ndarray_assign.h" #include "oneflow/core/ndarray/binary_func.h" namespace oneflow { #define DECLARE_NDARRAY_REDUCE_IMPL(struct_name) \ template class binary_func> \ struct struct_name final { \ static bool Matched( \ const XpuVarNdarray::return_type>& y, \ const XpuVarNdarray& x); \ static void Reduce( \ ep::Stream* ctx, \ const XpuVarNdarray::return_type>& y, \ const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage); \ } DECLARE_NDARRAY_REDUCE_IMPL(NdarrayScalarReduce); DECLARE_NDARRAY_REDUCE_IMPL(NdarrayMatrixRowReduce); DECLARE_NDARRAY_REDUCE_IMPL(NdarrayMatrixColReduce); DECLARE_NDARRAY_REDUCE_IMPL(NdarrayXYZCubeXZReduce); #undef DECLARE_NDARRAY_REDUCE_IMPL template class binary_func, typename Enable = void> struct NdarrayNoReduce; template class binary_func> struct NdarrayNoReduce::return_type>::value>::type> final { using RetT = typename BinaryFuncTrait::return_type; static bool Matched(const XpuVarNdarray& y, const XpuVarNdarray& x) { return x.shape() == y.shape(); } static void Reduce(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { if (std::is_same, BinaryFuncNanSum>()) { XpuNdarrayAssign::AssignNanSum(ctx, y, x); } else { XpuNdarrayAssign::Assign(ctx, y, x); } } }; template class binary_func> struct NdarrayNoReduce::return_type>::value>::type> final { using RetT = typename BinaryFuncTrait::return_type; static bool Matched(const XpuVarNdarray& y, const XpuVarNdarray& x) { return x.shape() == y.shape(); } static void Reduce(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { return SwitchReduce(SwitchCase(y.shape().NumAxes()), ctx, y, x, tmp_storage); } private: #define DEFINE_NDARRAY_REDUCE(func_name, NDIMS) func_name DEFINE_STATIC_SWITCH_FUNC(void, Reduce, DEFINE_NDARRAY_REDUCE, MAKE_NDIM_CTRV_SEQ(DIM_SEQ)) #undef DEFINE_NDARRAY_REDUCE template static void Reduce(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { XpuNdarrayAssign::template Assign(ctx, y, x); } }; template class binary_func> struct NdarrayReduceCoreWrapper final { static void ReduceAxis(ep::Stream* ctx, const XpuReducedNdarray& dst_reduced, const XpuReducedNdarray& x, int axis); }; template class binary_func> struct NdarrayDefaultReduce final { using RetT = typename BinaryFuncTrait::return_type; static void 
Reduce(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { return SwitchReduce(SwitchCase(y.shape().NumAxes()), ctx, y, x, tmp_storage); } private: #define DEFINE_NDARRAY_REDUCE(func_name, NDIMS) func_name DEFINE_STATIC_SWITCH_FUNC(void, Reduce, DEFINE_NDARRAY_REDUCE, MAKE_NDIM_CTRV_SEQ(DIM_SEQ)) #undef DEFINE_NDARRAY_REDUCE template static void Reduce(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x, const XpuVarNdarray& tmp_storage) { XpuVarNdarray storage(x.shape(), tmp_storage.ptr()); XpuShape cur_shape(x.shape()); CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes()); CHECK(x.shape() != y.shape()); XpuNdarrayAssign::Assign(ctx, storage, x); for (int i = 0; i < x.shape().NumAxes(); ++i) { if (y.shape().At(i) == x.shape().At(i)) { continue; } CHECK_EQ(y.shape().At(i), 1); CHECK_GT(x.shape().At(i), y.shape().At(i)); InplaceReduceAxis(ctx, i, storage, &cur_shape); } XpuReducedNdarray reduced(y.shape(), storage); XpuNdarrayAssign::template Assign(ctx, y, reduced); } template static void InplaceReduceAxis(ep::Stream* ctx, int axis, const XpuVarNdarray& implace, XpuShape* cur_shape) { int64_t target_elem_num = cur_shape->ElemNum() / cur_shape->At(axis); while (cur_shape->At(axis) > 1) { int64_t shrink = 8 + std::sqrt(target_elem_num); XpuReducedNdarray from(*cur_shape, implace); int64_t new_dim_value = (cur_shape->At(axis) + (shrink - 1)) / shrink; cur_shape->Set(axis, new_dim_value); XpuReducedNdarray to(*cur_shape, implace); NdarrayReduceCoreWrapper::ReduceAxis(ctx, to, from, axis); } } }; template class binary_func> struct NdarrayReduceCore final { template OF_DEVICE_FUNC static void ReduceAxis(const XpuReducedNdarray& dst_reduced, const X& x, int axis) { size_t n = dst_reduced.shape().ElemNum(); int64_t dst_dim_val = dst_reduced.shape().At(axis); XPU_1D_KERNEL_LOOP_BEGIN(i, n); T* dst_reduced_ptr = dst_reduced.template Mut(i); int64_t coord[NDIMS]; dst_reduced.shape().template Offset2Coordinate(i, coord); T reduced = UnitOfBinaryFunc::Val(); while (coord[axis] < x.shape().At(axis)) { reduced = binary_func::Invoke(reduced, x.template Get(coord)); coord[axis] += dst_dim_val; } *dst_reduced_ptr = reduced; XPU_1D_KERNEL_LOOP_END(); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_ ================================================ FILE: oneflow/core/ndarray/ndarray_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
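// ---------------------------------------------------------------------------
// Illustrative sketch (editorial addition, not a repository file): the
// InplaceReduceAxis loop above shrinks the reduced axis by a factor of about
// 8 + sqrt(remaining elements) per pass, reusing one temp buffer until the
// axis collapses to 1. A toy mirror that just counts the passes:
// ---------------------------------------------------------------------------
#include <cmath>
#include <cstdint>
#include <cstdio>

int CountReducePasses(int64_t dim, int64_t target_elem_num) {
  const int64_t shrink =
      8 + static_cast<int64_t>(std::sqrt(static_cast<double>(target_elem_num)));
  int passes = 0;
  while (dim > 1) {
    dim = (dim + shrink - 1) / shrink;  // ceil-divide, like the kernel wrapper call
    ++passes;
  }
  return passes;
}

int main() {
  // Reducing an axis of length 1000000 while the kept axes hold 1024 elements:
  // shrink = 8 + 32 = 40, so the axis collapses in 4 passes
  // (1000000 -> 25000 -> 625 -> 16 -> 1).
  std::printf("%d\n", CountReducePasses(1000000, 1024));
  return 0;
}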
*/ #ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_ #define ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ndarray/xpu_var_ndarray_builder.h" #include "oneflow/core/ndarray/ndarray_reduce.h" #include "oneflow/core/ndarray/ndarray_apply_unary.h" #include "oneflow/core/ndarray/ndarray_apply_binary.h" #include "oneflow/core/ndarray/ndarray_apply_broadcast_unary.h" #include "oneflow/core/ndarray/ndarray_apply_broadcast_binary.h" #include "oneflow/core/ndarray/xpu_reduced_ndarray.h" #include "oneflow/core/ndarray/xpu_ndarray_assign.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/common/util.h" namespace oneflow { template struct NdarrayUtil final { static XpuVarNdarrayBuilder GetValNdarrayBuilder() { return XpuVarNdarrayBuilder(); } static XpuVarNdarrayBuilder GetVarNdarrayBuilder() { return XpuVarNdarrayBuilder(); } static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { return XpuNdarrayAssign::Assign(stream, y, x); } static void BroadcastTo(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { return BroadcastIdentity(stream, y, x); } #define DEFINE_UNARY_FUNC(func_name) \ static void func_name( \ ep::Stream* stream, \ const XpuVarNdarray::return_type>& y, \ const XpuVarNdarray& x) { \ return ApplyUnary(stream, y, x); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_UNARY_FUNC, ARITHMETIC_UNARY_FUNC_NAME_SEQ) #undef DEFINE_UNARY_FUNC #define DEFINE_ARITHMETIC_BINARY_FUNC(func_name) \ static void func_name( \ ep::Stream* stream, \ const XpuVarNdarray::return_type>& y, \ const XpuVarNdarray& a, const XpuVarNdarray& b) { \ return ApplyBinary(stream, y, a, b); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_ARITHMETIC_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ) #undef DEFINE_ARITHMETIC_BINARY_FUNC #define DEFINE_LOGICAL_BINARY_FUNC(func_name) \ static void func_name( \ ep::Stream* stream, \ const XpuVarNdarray::return_type>& y, \ const XpuVarNdarray& a, const XpuVarNdarray& b) { \ return ApplyBinary(stream, y, a, b); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_LOGICAL_BINARY_FUNC, LOGICAL_BINARY_FUNC_NAME_SEQ) #undef DEFINE_LOGICAL_BINARY_FUNC #define DEFINE_BROADCAST_UNARY_FUNC(func_name) \ static void Broadcast##func_name( \ ep::Stream* stream, \ const XpuVarNdarray::return_type>& y, \ const XpuVarNdarray& x) { \ return BroadcastApplyUnary(stream, y, x); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_BROADCAST_UNARY_FUNC, ARITHMETIC_UNARY_FUNC_NAME_SEQ) #undef DEFINE_BROADCAST_UNARY_FUNC #define DEFINE_BROADCAST_ARITHMETIC_BINARY_FUNC(func_name) \ static void Broadcast##func_name( \ ep::Stream* stream, \ const XpuVarNdarray::return_type>& y, \ const XpuVarNdarray& a, const XpuVarNdarray& b) { \ return BroadcastApplyBinary(stream, y, a, b); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_BROADCAST_ARITHMETIC_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ) #undef DEFINE_BROADCAST_ARITHMETIC_BINARY_FUNC #define DEFINE_BROADCAST_LOGICAL_BINARY_FUNC(func_name) \ static void Broadcast##func_name( \ ep::Stream* stream, \ const XpuVarNdarray::return_type>& y, \ const XpuVarNdarray& a, const XpuVarNdarray& b) { \ return BroadcastApplyBinary(stream, y, a, b); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_BROADCAST_LOGICAL_BINARY_FUNC, LOGICAL_BINARY_FUNC_NAME_SEQ) #undef DEFINE_BROADCAST_LOGICAL_BINARY_FUNC #define DEFINE_INPLACE_UNARY_FUNC(func_name) \ static void Inplace##func_name(ep::Stream* stream, const XpuVarNdarray& y) { \ InplaceApply(stream, y); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_INPLACE_UNARY_FUNC, 
ARITHMETIC_UNARY_FUNC_NAME_SEQ) #undef DEFINE_INPLACE_UNARY_FUNC #define DEFINE_INPLACE_BINARY_FUNC(func_name) \ static void Inplace##func_name(ep::Stream* stream, const XpuVarNdarray& y, \ const XpuVarNdarray& x) { \ InplaceApply(stream, y, x); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_INPLACE_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ) #undef DEFINE_INPLACE_BINARY_FUNC #define DEFINE_INPLACE_BROADCAST_BINARY_FUNC(func_name) \ static void InplaceBroadcast##func_name(ep::Stream* stream, const XpuVarNdarray& y, \ const XpuVarNdarray& x) { \ return InplaceBroadcastApply(stream, y, x); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_INPLACE_BROADCAST_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ) #undef DEFINE_INPLACE_BROADCAST_BINARY_FUNC #define DEFINE_REDUCE_FUNC(func_name) \ static void Reduce##func_name(ep::Stream* stream, const XpuVarNdarray& y, \ const XpuVarNdarray& x, \ const XpuVarNdarray& tmp_storage) { \ return NdarrayReduce::Reduce(stream, y, x, \ tmp_storage); \ } OF_PP_FOR_EACH_ATOMIC(DEFINE_REDUCE_FUNC, REDUCE_BINARY_FUNC_NAME_SEQ) #undef DEFINE_REDUCE_FUNC private: template class unary_func> static void BroadcastApplyUnary( ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& x) { CHECK_EQ(x.shape().NumAxes(), y.shape().NumAxes()); return NdarrayApplyBroadcastUnary::Apply(stream, y, x); } template class binary_func> static void BroadcastApplyBinary( ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { CHECK_EQ(a.shape().NumAxes(), y.shape().NumAxes()); CHECK_EQ(b.shape().NumAxes(), y.shape().NumAxes()); return NdarrayApplyBroadcastBinary::Apply(stream, y, a, b); } template class binary_func> static void InplaceBroadcastApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { static_assert(std::is_same::return_type>::value, "T must be same with BinaryFuncTrait::return_type"); CHECK_EQ(x.shape().NumAxes(), y.shape().NumAxes()); return NdarrayApplyBroadcastBinary::InplaceApply(stream, y, x); } template class unary_func> static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y) { static_assert(std::is_same::return_type>::value, "T must be same with UnaryFuncTrait::return_type"); return NdarrayApplyUnary::InplaceApply(stream, y); } template class binary_func> static void InplaceApply(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { static_assert(std::is_same::return_type>::value, "T must be same with BinaryFuncTrait::return_type"); return NdarrayApplyBinary::InplaceApply(stream, y, x); } template class unary_func> static void ApplyUnary( ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& x) { return NdarrayApplyUnary::Apply(stream, y, x); } template class binary_func> static void ApplyBinary( ep::Stream* stream, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { if (a.host_ptr() == y.host_ptr()) { CHECK(a.host_shape() == y.host_shape()); return NdarrayApplyBinary::InplaceApply(stream, y, b); } else { return NdarrayApplyBinary::Apply(stream, y, a, b); } } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_ ================================================ FILE: oneflow/core/ndarray/slice.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
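// ---------------------------------------------------------------------------
// Illustrative sketch (editorial addition, not a repository file): the private
// ApplyBinary helper in NdarrayUtil above switches to the in-place kernel
// whenever the output buffer aliases the first operand. The same dispatch
// idea on plain std::vector buffers:
// ---------------------------------------------------------------------------
#include <cassert>
#include <functional>
#include <vector>

template<typename T, typename BinOp>
void ApplyBinarySketch(std::vector<T>* y, const std::vector<T>& a, const std::vector<T>& b,
                       BinOp op) {
  assert(a.size() == b.size() && y->size() == b.size());
  if (y->data() == a.data()) {
    // Output aliases operand a: update in place, reading only b.
    for (size_t i = 0; i < b.size(); ++i) { (*y)[i] = op((*y)[i], b[i]); }
  } else {
    for (size_t i = 0; i < a.size(); ++i) { (*y)[i] = op(a[i], b[i]); }
  }
}

int main() {
  std::vector<int> a{1, 2, 3};
  const std::vector<int> b{10, 20, 30};
  ApplyBinarySketch(&a, a, b, std::plus<int>());  // in-place path: a becomes {11, 22, 33}
  assert(a[2] == 33);
  return 0;
}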
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/slice.h" namespace oneflow { Slice::Slice(const std::initializer_list& l) { DimVector vec(l); value_capacity_ = 0; if (vec.size() == 0) { start_ = kStart; end_ = kEnd; stride_ = 1; } else if (vec.size() == 1) { start_ = vec[0]; end_ = kEnd; stride_ = 1; } else if (vec.size() == 2) { start_ = vec[0]; end_ = vec[1]; stride_ = 1; } else if (vec.size() == 3) { start_ = vec[0]; end_ = vec[1]; stride_ = vec[2]; } else { UNIMPLEMENTED(); } } bool Slice::IsBounded() const { CHECK_NE(stride_, 0); if (value_capacity_ == 0) { return false; } return (start_ >= 0) && (start_ <= value_capacity_ - (stride_ < 0)) && (end_ >= 0 - (stride_ < 0)) && (end_ <= value_capacity_); } const Slice& Slice::Bound(size_t value_capacity) { CHECK_GT(value_capacity, 0); if (value_capacity_ == value_capacity) { return *this; } CHECK_EQ(value_capacity_, 0); value_capacity_ = value_capacity; if (start_ != kStart && start_ < 0) { start_ += value_capacity_; } if (end_ != kStart && end_ < 0) { end_ += value_capacity_; } if (start_ == kStart) { start_ = 0; } if (end_ == kEnd) { end_ = value_capacity_; } if (start_ == kEnd) { start_ = value_capacity_ - (stride_ < 0); } if (end_ == kStart) { end_ = 0 - (stride_ < 0); } CHECK_NE(stride_, 0); CHECK_GE(start_, 0); CHECK_LE(start_, value_capacity_); CHECK_GE(end_, 0); CHECK_LE(end_, value_capacity_); return *this; } size_t Slice::Size() const { CHECK(IsBounded()); if (stride_ > 0 && start_ >= end_) { return 0; } if (stride_ < 0 && start_ <= end_) { return 0; } return ((end_ - start_) + (stride_ - ((stride_ > 0) - (stride_ < 0)))) / stride_; } bool Slice::IsContiguous() const { CHECK(IsBounded()); return stride_ == 1; } bool Slice::IsCoveringAll() const { CHECK(IsBounded()); return start_ == 0 && end_ == value_capacity_; } } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/slice.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_SLICE_H_ #define ONEFLOW_CORE_NDARRAY_SLICE_H_ #include "oneflow/core/ndarray/cpu_ndarray.h" namespace oneflow { class Slice final { public: static const int64_t kStart = LLONG_MIN; static const int64_t kEnd = LLONG_MAX; Slice(const Slice&) = default; Slice(int64_t index) : start_(index), end_(index + 1), stride_(1), value_capacity_(0) {} Slice(const std::initializer_list& l); ~Slice() = default; const Slice& Bound(size_t value_capacity); ALWAYS_INLINE int64_t Get(int64_t index) const { return start_ + index * stride_; } bool IsBounded() const; size_t Size() const; bool IsContiguous() const; bool IsCoveringAll() const; private: int64_t start_; int64_t end_; int64_t stride_; size_t value_capacity_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_SLICE_H_ ================================================ FILE: oneflow/core/ndarray/slice_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ndarray/slice.h" #include namespace oneflow { namespace test { TEST(Slice, size) { Slice slice({-2, 0, -1}); slice.Bound(4); ASSERT_EQ(slice.Size(), 2); } TEST(Slice, contiguous) { Slice slice({0, -1, 1}); slice.Bound(4); ASSERT_TRUE(slice.IsContiguous()); ASSERT_FALSE(slice.IsCoveringAll()); } TEST(Slice, is_covering_all) { Slice slice({}); slice.Bound(4); ASSERT_TRUE(slice.IsCoveringAll()); ASSERT_TRUE(slice.IsContiguous()); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/unary_func.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
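// ---------------------------------------------------------------------------
// Illustrative sketch (editorial addition, not a repository file): Slice::Size
// above is a ceil-division that also works for negative strides:
//   size = ((end - start) + (stride - sign(stride))) / stride.
// A standalone check of the cases exercised by slice_test.cpp:
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>

int64_t SliceSize(int64_t start, int64_t end, int64_t stride) {
  assert(stride != 0);
  if (stride > 0 && start >= end) { return 0; }
  if (stride < 0 && start <= end) { return 0; }
  const int64_t sign = (stride > 0) - (stride < 0);
  return ((end - start) + (stride - sign)) / stride;
}

int main() {
  // Slice({-2, 0, -1}).Bound(4) resolves to start=2, end=0, stride=-1,
  // i.e. indices {2, 1}: two elements, as TEST(Slice, size) expects.
  assert(SliceSize(2, 0, -1) == 2);
  // Slice({0, -1, 1}).Bound(4) resolves to start=0, end=3, stride=1: three elements.
  assert(SliceSize(0, 3, 1) == 3);
  return 0;
}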
*/ #ifndef ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_ #define ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_ #if defined(__CUDACC__) #include #endif #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/common/util.h" namespace oneflow { #define ARITHMETIC_UNARY_FUNC_NAME_SEQ (Identity)(Negative)(Exp) #define PREPEND_PREFIX_UNARY_FUNC(name) OF_PP_CAT(UnaryFunc, name) #define ARITHMETIC_UNARY_FUNC_SEQ \ OF_PP_SEQ_MAP(PREPEND_PREFIX_UNARY_FUNC, ARITHMETIC_UNARY_FUNC_NAME_SEQ) template class UnaryFunc, typename T> struct UnaryFuncTrait final { typedef typename std::remove_const::Invoke(*(const T*)nullptr))>::type return_type; }; #define SPECIALIZE_CONST_TYPE_UNARY_FUNC(func_struct) \ template \ struct func_struct final { \ static OF_DEVICE_FUNC const T Invoke(const T x) { return func_struct::Invoke(x); } \ } template struct UnaryFuncIdentity final { static OF_DEVICE_FUNC const T Invoke(const T x) { return x; } }; template struct UnaryFuncNegative final { static OF_DEVICE_FUNC const T Invoke(const T x) { return -x; } }; SPECIALIZE_CONST_TYPE_UNARY_FUNC(UnaryFuncNegative); template struct UnaryFuncExp final { static OF_DEVICE_FUNC const T Invoke(const T x) { #if defined(__CUDA_ARCH__) if (std::is_same::value) { return static_cast(exp(static_cast(x))); } else { return static_cast(exp(static_cast(x))); } #else return std::exp(x); #endif // defined(__CUDA_ARCH__) } }; template<> struct UnaryFuncExp final { static OF_DEVICE_FUNC bool Invoke(const bool x) { #if defined(__CUDA_ARCH__) return static_cast(exp(static_cast(x))); #else return static_cast(std::exp(static_cast(x))); #endif // defined(__CUDA_ARCH__) } }; SPECIALIZE_CONST_TYPE_UNARY_FUNC(UnaryFuncExp); template<> struct UnaryFuncExp final { static OF_DEVICE_FUNC const float16 Invoke(const float16 x) { #if defined(__CUDA_ARCH__) half res = static_cast(exp(static_cast(*reinterpret_cast(&x)))); return *reinterpret_cast(&res); #else return float16(std::exp(static_cast(x))); #endif // defined(__CUDA_ARCH__) } }; #define NO_HALF_UTIL_FOUND \ printf("cuda arch must >= 530"); \ assert(false); \ return __float2half(0.0) #if defined(__CUDACC__) template<> struct UnaryFuncNegative final { static __device__ __forceinline__ const half Invoke(const half x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hneg(x); #else NO_HALF_UTIL_FOUND; #endif } }; template<> struct UnaryFuncExp final { static __device__ __forceinline__ const half Invoke(const half x) { return __float2half(std::exp(__half2float(x))); } }; template<> struct UnaryFuncNegative final { static __device__ __forceinline__ const cuComplex Invoke(const cuComplex x) { return cuComplex{-x.x, -x.y}; } }; template<> struct UnaryFuncExp final { static __device__ __forceinline__ const cuComplex Invoke(const cuComplex x) { return cuComplex{exp(x.x) * cos(x.y), exp(x.x) * sin(x.y)}; } }; template<> struct UnaryFuncNegative final { static __device__ __forceinline__ const cuDoubleComplex Invoke(const cuDoubleComplex x) { return cuDoubleComplex{-x.x, -x.y}; } }; template<> struct UnaryFuncExp final { static __device__ __forceinline__ const cuDoubleComplex Invoke(const cuDoubleComplex x) { return cuDoubleComplex{exp(x.x) * cos(x.y), exp(x.x) * sin(x.y)}; } }; #endif template struct UnaryFuncLogicalNot final { static OF_DEVICE_FUNC bool Invoke(const T x) { return !x; } }; SPECIALIZE_CONST_TYPE_UNARY_FUNC(UnaryFuncLogicalNot); } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_ ================================================ FILE: oneflow/core/ndarray/xpu_binary_func_ndarray.h 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_ #include "oneflow/core/ndarray/binary_func.h" namespace oneflow { template class binary_func, typename A, typename B> class XpuBinaryFuncNdarray final { public: OF_DEVICE_FUNC XpuBinaryFuncNdarray(const A& a, const B& b) : a_(a), b_(b) {} template OF_DEVICE_FUNC typename BinaryFuncTrait::return_type Get(int64_t offset) const { return binary_func::Invoke(a_.template Get(offset), b_.template Get(offset)); } private: const A a_; const B b_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/xpu_broadcast_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
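// ---------------------------------------------------------------------------
// Illustrative sketch (editorial addition, not a repository file): the
// UnaryFuncTrait / BinaryFuncTrait used throughout these headers recover a
// functor's result type with decltype, so predicates returning bool compose
// with arithmetic functors returning T. FuncExp and FuncLogicalNot below are
// stand-ins for the real OneFlow functors:
// ---------------------------------------------------------------------------
#include <cmath>
#include <type_traits>
#include <utility>

template<template<typename> class UnaryFunc, typename T>
struct UnaryFuncTraitSketch {
  using return_type =
      typename std::remove_const<decltype(UnaryFunc<T>::Invoke(std::declval<const T>()))>::type;
};

template<typename T>
struct FuncExp {
  static T Invoke(const T x) { return static_cast<T>(std::exp(static_cast<double>(x))); }
};

template<typename T>
struct FuncLogicalNot {
  static bool Invoke(const T x) { return !x; }
};

static_assert(std::is_same<UnaryFuncTraitSketch<FuncExp, float>::return_type, float>::value,
              "arithmetic functors keep the element type");
static_assert(std::is_same<UnaryFuncTraitSketch<FuncLogicalNot, int>::return_type, bool>::value,
              "logical functors switch the result type to bool");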
*/ #ifndef ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_ #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ndarray/xpu_ndarray_base.h" namespace oneflow { template struct XpuBroadcastNdarrayUtil; template class XpuBroadcastNdarray final : public XpuNdarrayBase, T> { public: OF_DEVICE_FUNC XpuBroadcastNdarray(const XpuShape& shape, const XpuVarNdarray& var) : shape_(shape), var_(var) {} ~XpuBroadcastNdarray() = default; template OF_DEVICE_FUNC T Get(int64_t offset) const { int64_t coord[NDIMS]; shape_.template Offset2Coordinate(offset, coord); XpuBroadcastNdarrayUtil::SrcCoordinate(var_.shape(), coord); return var_.template Get(coord); } OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; } OF_DEVICE_FUNC const XpuVarNdarray& var() const { return var_; } private: const XpuShape shape_; const XpuVarNdarray var_; }; #define IMPLACE_SET_SRC_COORD(i) coord[i] %= src_shape.At(i); #define SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(n) \ template \ struct XpuBroadcastNdarrayUtil final { \ OF_DEVICE_FUNC static void SrcCoordinate(const XpuShape& src_shape, int64_t coord[n + 1]) { \ OF_PP_FOR_EACH_TUPLE(IMPLACE_SET_SRC_COORD, GET_SEQ(n)); \ } \ } SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(0); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(1); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(2); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(3); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(4); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(5); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(6); #undef SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL #undef IMPLACE_SET_SRC_COORD } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/xpu_ndarray_assign.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
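// ---------------------------------------------------------------------------
// Illustrative sketch (editorial addition, not a repository file): the
// XpuBroadcastNdarrayUtil::SrcCoordinate specializations above reduce each
// destination coordinate modulo the source extent, which maps any coordinate
// on a size-1 source axis back to 0. The same rule on plain arrays:
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>

void SrcCoordinateSketch(const int64_t* src_dims, int64_t* coord, int num_axes) {
  for (int i = 0; i < num_axes; ++i) { coord[i] %= src_dims[i]; }
}

int main() {
  // Broadcasting a (1, 4) source to a (3, 4) destination: destination
  // coordinate (2, 3) reads source coordinate (0, 3).
  const int64_t src_dims[2] = {1, 4};
  int64_t coord[2] = {2, 3};
  SrcCoordinateSketch(src_dims, coord, 2);
  assert(coord[0] == 0 && coord[1] == 3);
  return 0;
}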
*/ #include "oneflow/core/ndarray/ndarray_assign_core.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { namespace { template __global__ void NdarrayAssignReducedGpu(XpuVarNdarray y, const XpuReducedNdarray reduced) { NdarrayAssignCore::Assign(y, reduced); } template __global__ void NdarrayAssignGpu(XpuVarNdarray y, const XpuVarNdarray x) { NdarrayAssignCore::Assign(y, x); } } // namespace template struct NdarrayAssignCoreWrapper final { static void Assign(ep::Stream* stream, XpuVarNdarray* y, const XpuReducedNdarray& reduced) { size_t n = y->host_shape().HostElemNum(); RUN_CUDA_KERNEL((NdarrayAssignReducedGpu), stream, n, *y, reduced); } static void Assign(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } RUN_CUDA_KERNEL((NdarrayAssignGpu), ctx, n, y, x); } }; #define INSTANTIATE_NDARRAY_ASSIGN(ret_dtype_pair, dtype_pair, NDIMS) \ template struct NdarrayAssignCoreWrapper; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, DIM_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, HALF_DATA_TYPE_SEQ, HALF_DATA_TYPE_SEQ, DIM_SEQ); } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/xpu_ndarray_assign.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_ #define ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_ #include "oneflow/core/ndarray/ndarray_assign_core.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/primitive/elementwise_unary.h" #include "oneflow/core/ep/include/primitive/unary_op.h" namespace oneflow { template struct XpuNdarrayAssign; template struct XpuNdarrayAssign< device_type, T, typename std::enable_if::type>::value>::type> final { template static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuReducedNdarray& reduced) { NdarrayAssignCoreWrapper::Assign(stream, y, reduced); } template static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { NdarrayAssignCoreWrapper::Assign(stream, y, x); } static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { CHECK(y.shape() == x.shape()); if (x.ptr() == y.ptr()) { return; } Memcpy(stream, y.ptr(), x.ptr(), y.shape().ElemNum() * sizeof(T)); } static void AssignNanSum(ep::Stream* stream, const XpuVarNdarray& y, const XpuVarNdarray& x) { CHECK(y.shape() == x.shape()); // NOLINT CHECK_EQ(device_type, stream->device_type()) << "Device type mismatch"; std::unique_ptr primitive = ep::primitive::NewPrimitive( device_type, ep::primitive::UnaryOp::kNanAssign, GetDataType(), GetDataType()); CHECK(primitive) << "Can not create NanSum primitive for device type " << device_type; primitive->Launch(stream, x.ptr(), y.ptr(), y.shape().ElemNum()); } }; template struct XpuNdarrayAssign< device_type, T, typename std::enable_if::type>::value>::type> final { using NewT = typename DevDType::type; template static void Assign(ep::Stream* stream, const XpuVarNdarray& y, const XpuReducedNdarray& reduced) { XpuNdarrayAssign::Assign( stream, reinterpret_cast&>(y), reinterpret_cast&>(reduced)); } static void Assign(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { XpuNdarrayAssign::Assign( ctx, reinterpret_cast&>(y), reinterpret_cast&>(x)); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_ ================================================ FILE: oneflow/core/ndarray/xpu_ndarray_base.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_ #define ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_ #include "oneflow/core/ndarray/xpu_shape.h" namespace oneflow { template class unary_func, typename X> class XpuUnaryFuncNdarray; template class binary_func, typename A, typename B> class XpuBinaryFuncNdarray; template class XpuBroadcastNdarray; template class XpuTransposeNdarray; template class XpuReshapeNdarray; template class XpuNdarrayBase { public: XpuNdarrayBase() = default; ~XpuNdarrayBase() = default; template class unary_func> OF_DEVICE_FUNC XpuUnaryFuncNdarray UnaryFunc() const { return XpuUnaryFuncNdarray(*static_cast(this)); } template class binary_func, typename X> OF_DEVICE_FUNC XpuBinaryFuncNdarray BinaryFunc(const X& x) const { return XpuBinaryFuncNdarray(*static_cast(this), x); } OF_DEVICE_FUNC XpuBroadcastNdarray Broadcast(const XpuShape& shape) const { return XpuBroadcastNdarray(shape, *static_cast(this)); } template OF_DEVICE_FUNC XpuTransposeNdarray Transpose( const int64_t perm[NDIMS]) const { return XpuTransposeNdarray(*static_cast(this), perm); } template OF_DEVICE_FUNC XpuReshapeNdarray Reshape(const int64_t shape[NDIMS]) { return XpuReshapeNdarray(*static_cast(this), shape); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_ ================================================ FILE: oneflow/core/ndarray/xpu_reduced_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_ #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/ndarray/unary_func.h" namespace oneflow { template> class XpuReducedNdarray final { public: OF_DEVICE_FUNC XpuReducedNdarray(const XpuShape& shape, const X& data) : shape_(shape), data_(data) {} OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; } const XpuShape& host_shape() const { return shape_; } OF_DEVICE_FUNC const X& data() const { return data_; } template OF_DEVICE_FUNC T Get(int64_t offset) const { int64_t coord[NDIMS]; shape_.template Offset2Coordinate(offset, coord); return Get(coord); } template OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const { return data_.template Get(coord); } template OF_DEVICE_FUNC T* Mut(int64_t offset) const { int64_t coord[NDIMS]; shape_.template Offset2Coordinate(offset, coord); return Mut(coord); } template OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const { return data_.template Mut(coord); } private: XpuShape shape_; X data_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/xpu_reshape_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_ namespace oneflow { template> class XpuReshapeNdarray final { public: OF_DEVICE_FUNC XpuReshapeNdarray(const X& x, const int64_t dim[NDIMS]) : x_(x), shape_(dim, NDIMS) {} template OF_DEVICE_FUNC T Get(int64_t offset) const { return x_.template Get(offset); } template OF_DEVICE_FUNC T* Mut(int64_t offset) const { return x_.template Mut(offset); } template OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const { return Get(Coord2Offset(coord)); } template OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const { return Get(Coord2Offset(coord)); } private: OF_DEVICE_FUNC int64_t Coord2Offset(const int64_t coord[NDIMS]) const { return XpuShapeUtil::Coord2Offset(shape_, coord); } const X& x_; XpuShape shape_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/xpu_shape.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
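[Editor's note] The next file, xpu_shape.cpp, defines SimplifyBroadcastShapes, which fuses adjacent axes that keep the same broadcast pattern so the kernels have to handle as few dimensions as possible; a self-contained sketch of the same merging rule (function and variable names here are illustrative, not the repository API):

// Consecutive axes on which a and b keep the same "is broadcast" flag are
// fused into one larger axis; axes where both a and b are 1 are dropped.
#include <cstdint>
#include <vector>

void Simplify(const std::vector<int64_t>& y, const std::vector<int64_t>& a,
              const std::vector<int64_t>& b, std::vector<int64_t>* sy,
              std::vector<int64_t>* sa, std::vector<int64_t>* sb) {
  sy->assign(1, y[0]);
  sa->assign(1, a[0]);
  sb->assign(1, b[0]);
  bool a_prev = (a[0] == 1);
  bool b_prev = (b[0] == 1);
  for (size_t i = 1; i < y.size(); ++i) {
    if (a[i] == 1 && b[i] == 1) { continue; }  // degenerate axis, drop it
    const bool a_cur = (a[i] == 1);
    const bool b_cur = (b[i] == 1);
    if (a_cur == a_prev && b_cur == b_prev) {  // same pattern: fuse into back()
      sy->back() *= y[i];
      sa->back() *= a[i];
      sb->back() *= b[i];
    } else {                                   // pattern changed: open a new axis
      sy->push_back(y[i]);
      sa->push_back(a[i]);
      sb->push_back(b[i]);
    }
    a_prev = a_cur;
    b_prev = b_cur;
  }
}
// e.g. y = (4, 5, 6), a = (4, 5, 6), b = (1, 1, 6) simplifies to
// y = (20, 6), a = (20, 6), b = (1, 6).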
*/ #include "oneflow/core/ndarray/xpu_shape.h" namespace oneflow { XpuShape::XpuShape(const int64_t dim[], int num_axes) { num_axes_ = num_axes; int i = 0; for (; i < num_axes_; ++i) { dim_[i] = dim[i]; } UpdateDimElemNumAndElemNum(); for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) { dim_[i] = 1; dim_elem_num_[i] = 1; } } XpuShape::XpuShape(const Shape& shape) { num_axes_ = shape.NumAxes(); int i = 0; for (; i < num_axes_; ++i) { dim_[i] = shape.At(i); } UpdateDimElemNumAndElemNum(); for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) { dim_[i] = 1; dim_elem_num_[i] = 1; } } XpuShape::XpuShape(const ShapeView& shape) { num_axes_ = shape.NumAxes(); int i = 0; for (; i < num_axes_; ++i) { dim_[i] = shape.At(i); } UpdateDimElemNumAndElemNum(); for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) { dim_[i] = 1; dim_elem_num_[i] = 1; } } XpuShape::XpuShape(const ShapeView& shape, int ndims_left_extend_to) { if (shape.NumAxes() == 1 && ndims_left_extend_to == 0) { num_axes_ = 0; int i = 0; dim_[i] = shape.At(i); UpdateDimElemNumAndElemNum(); for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) { dim_[i] = 1; } } else { CHECK_LE(shape.NumAxes(), ndims_left_extend_to); num_axes_ = ndims_left_extend_to; size_t left_ones_num = num_axes_ - shape.NumAxes(); int i = 0; for (; i < left_ones_num; ++i) { dim_[i] = 1; } for (; i < num_axes_; ++i) { dim_[i] = shape.At(i - left_ones_num); } UpdateDimElemNumAndElemNum(); for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) { dim_[i] = 1; dim_elem_num_[i] = 1; } } } bool XpuShape::operator==(const XpuShape& rhs) const { if (num_axes_ != rhs.num_axes_) { return false; } if (elem_num_ != rhs.elem_num_) { return false; } for (int i = 0; i < num_axes_; ++i) { if (dim_[i] != rhs.dim_[i]) { return false; } if (dim_elem_num_[i] != rhs.dim_elem_num_[i]) { return false; } } return true; } void SimplifyBroadcastShapes(const XpuShape& y, const XpuShape& b, DimVector* simplified_y, DimVector* simplified_b) { DimVector simplified_a; SimplifyBroadcastShapes(y, y, b, simplified_y, &simplified_a, simplified_b); } void SimplifyBroadcastShapes(const XpuShape& y, const XpuShape& a, const XpuShape& b, DimVector* simplified_y, DimVector* simplified_a, DimVector* simplified_b) { CHECK_EQ(y.NumAxes(), a.NumAxes()); CHECK_EQ(b.NumAxes(), a.NumAxes()); CHECK(simplified_y->empty()); CHECK(simplified_a->empty()); CHECK(simplified_b->empty()); simplified_y->emplace_back(y.At(0)); simplified_a->emplace_back(a.At(0)); simplified_b->emplace_back(b.At(0)); bool a_prev_axis_is_broadcast = (a.At(0) == 1); bool b_prev_axis_is_broadcast = (b.At(0) == 1); FOR_RANGE(int, i, 1, y.NumAxes()) { const int64_t y_dim = y.At(i); const int64_t a_dim = a.At(i); const int64_t b_dim = b.At(i); if ((a_dim == 1) && (b_dim == 1)) { continue; } const bool a_cur_axis_is_broadcast = (a_dim == 1); const bool b_cur_axis_is_broadcast = (b_dim == 1); if (a_prev_axis_is_broadcast == a_cur_axis_is_broadcast && b_prev_axis_is_broadcast == b_cur_axis_is_broadcast) { simplified_y->back() *= y_dim; simplified_a->back() *= a_dim; simplified_b->back() *= b_dim; } else { simplified_y->emplace_back(y_dim); simplified_a->emplace_back(a_dim); simplified_b->emplace_back(b_dim); } a_prev_axis_is_broadcast = a_cur_axis_is_broadcast; b_prev_axis_is_broadcast = b_cur_axis_is_broadcast; } } } // namespace oneflow ================================================ FILE: oneflow/core/ndarray/xpu_shape.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_ #define ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_ #include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ndarray/xpu_util.h" namespace oneflow { template struct XpuShapeUtil; class XpuShape final { public: explicit XpuShape(const Shape& shape); explicit XpuShape(const ShapeView& shape); explicit XpuShape(const ShapeView& shape, int ndims_left_extend_to); OF_DEVICE_FUNC XpuShape(const int64_t dim[], int num_axes); XpuShape(const XpuShape&) = default; OF_DEVICE_FUNC int64_t At(int64_t dim) const { return dim_[dim]; } OF_DEVICE_FUNC int64_t DimElemNum(int64_t dim) const { return dim_elem_num_[dim]; } OF_DEVICE_FUNC int64_t Count(int64_t dim) const { return At(dim) * DimElemNum(dim); } OF_DEVICE_FUNC size_t ElemNum() const { return elem_num_; } OF_DEVICE_FUNC size_t NumAxes() const { return num_axes_; } size_t HostElemNum() const { return elem_num_; } bool operator==(const XpuShape&) const; bool operator!=(const XpuShape& rhs) const { return !(*this == rhs); } OF_DEVICE_FUNC void Set(int64_t axis, int64_t value) { dim_[axis] = value; UpdateDimElemNumAndElemNum(); } template OF_DEVICE_FUNC int64_t Coordinate2Offset(const int64_t coord[NDIMS]) const { return XpuShapeUtil::Coordinate2Offset(*this, coord); } template OF_DEVICE_FUNC void Offset2Coordinate(int64_t offset, int64_t coord[NDIMS]) const { XpuShapeUtil::Offset2Coordinate(*this, offset, coord); } OF_DEVICE_FUNC void UpdateDimElemNumAndElemNum() { elem_num_ = 1; for (int i = num_axes_ - 1; i >= 0; --i) { dim_elem_num_[i] = elem_num_; elem_num_ *= dim_[i]; } } std::string ToString() const { return ShapeView(dim_, num_axes_).ToString(); } private: size_t num_axes_; size_t elem_num_; int64_t dim_[OF_PP_SEQ_SIZE(DIM_SEQ)]; int64_t dim_elem_num_[OF_PP_SEQ_SIZE(DIM_SEQ)]; }; template<> struct XpuShapeUtil<1> final { OF_DEVICE_FUNC static int64_t Coordinate2Offset(const XpuShape& shape, const int64_t coord[1]) { return coord[0]; } OF_DEVICE_FUNC static void Offset2Coordinate(const XpuShape& shape, int64_t offset, int64_t coord[1]) { coord[0] = offset; } }; #define COORD_MUL_STRIDE(i) coord[i] * shape.DimElemNum(i) + #define EXTRACT_COORD(i) \ coord[i] = offset / shape.DimElemNum(i); \ offset %= shape.DimElemNum(i); #define SPECIALIZE_XPU_SHAPE_UTIL(n) \ template<> \ struct XpuShapeUtil final { \ OF_DEVICE_FUNC static int64_t Coordinate2Offset(const XpuShape& shape, \ const int64_t coord[n + 2]) { \ return OF_PP_FOR_EACH_TUPLE(COORD_MUL_STRIDE, GET_SEQ(n)) coord[n + 1]; \ } \ OF_DEVICE_FUNC static void Offset2Coordinate(const XpuShape& shape, int64_t offset, \ int64_t coord[n + 2]) { \ OF_PP_FOR_EACH_TUPLE(EXTRACT_COORD, GET_SEQ(n)); \ coord[n + 1] = offset; \ } \ }; SPECIALIZE_XPU_SHAPE_UTIL(0); SPECIALIZE_XPU_SHAPE_UTIL(1); SPECIALIZE_XPU_SHAPE_UTIL(2); SPECIALIZE_XPU_SHAPE_UTIL(3); SPECIALIZE_XPU_SHAPE_UTIL(4); SPECIALIZE_XPU_SHAPE_UTIL(5); #undef 
SPECIALIZE_XPU_SHAPE_UTIL #undef EXTRACT_COORD #undef COORD_MUL_STRIDE void SimplifyBroadcastShapes(const XpuShape& y, const XpuShape& b, DimVector* simplified_y, DimVector* simplified_b); void SimplifyBroadcastShapes(const XpuShape& y, const XpuShape& a, const XpuShape& b, DimVector* simplified_y, DimVector* simplified_a, DimVector* simplified_b); } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_ ================================================ FILE: oneflow/core/ndarray/xpu_transpose_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template> class XpuTransposeNdarray final { public: OF_DEVICE_FUNC XpuTransposeNdarray(const X& x, const int64_t perm[NDIMS]) : x_(x), shape_(x.shape()) { for (int i = 0; i < NDIMS; ++i) { perm_[i] = perm[i]; shape_.Set(i, x.shape().At(perm[i])); } } template::type> OF_DEVICE_FUNC T Get(int64_t offset) const { int64_t coord[NDIMS]; Offset2Coord(offset, coord); return Get(coord); } template::type> OF_DEVICE_FUNC T* Mut(int64_t offset) const { int64_t coord[NDIMS]; Offset2Coord(offset, coord); return Mut(coord); } template::type> OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const { int64_t permuted_coord[NDIMS]; PermuteCoord(coord, permuted_coord); return x_.template Get(permuted_coord); } template::type> OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const { int64_t permuted_coord[NDIMS]; PermuteCoord(coord, permuted_coord); return x_.template Mut(permuted_coord); } private: OF_DEVICE_FUNC void Offset2Coord(int64_t offset, int64_t coord[NDIMS]) const { shape_.template Offset2Coordinate(offset, coord); } OF_DEVICE_FUNC void PermuteCoord(const int64_t coord[NDIMS], int64_t permuted_coord[NDIMS]) const { for (int i = 0; i < NDIMS; ++i) { permuted_coord[perm_[i]] = coord[i]; } } const X& x_; XpuShape shape_; int64_t perm_[NDIMS]; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/xpu_unary_func_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
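[Editor's note] The XpuShapeUtil specializations in xpu_shape.h above are the usual row-major stride arithmetic, unrolled per NDIMS with the preprocessor; DimElemNum(i) is the stride of axis i. The same computation written as plain loops (a sketch with illustrative names, not the repository code):

// Row-major offset <-> coordinate conversion, the loop form of what the
// SPECIALIZE_XPU_SHAPE_UTIL macros unroll.
#include <cstdint>
#include <vector>

int64_t Coordinate2Offset(const std::vector<int64_t>& dim_elem_num,
                          const std::vector<int64_t>& coord) {
  int64_t offset = 0;
  for (size_t i = 0; i < coord.size(); ++i) { offset += coord[i] * dim_elem_num[i]; }
  return offset;
}

void Offset2Coordinate(const std::vector<int64_t>& dim_elem_num, int64_t offset,
                       std::vector<int64_t>* coord) {
  coord->resize(dim_elem_num.size());
  for (size_t i = 0; i < dim_elem_num.size(); ++i) {
    (*coord)[i] = offset / dim_elem_num[i];  // quotient is the coordinate on axis i
    offset %= dim_elem_num[i];               // remainder is handled by the inner axes
  }
}
// For shape (2, 3, 4) the strides are (12, 4, 1); offset 17 maps to
// coordinate (1, 1, 1) and back.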
*/ #ifndef ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_ #define ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_ namespace oneflow { template class unary_func, typename X> class XpuUnaryFuncNdarray final { public: OF_DEVICE_FUNC XpuUnaryFuncNdarray(const X& x) : x_(x) {} template OF_DEVICE_FUNC T Get(int64_t offset) const { return unary_func::Invoke(x_.template Get(offset)); } private: const X& x_; }; } // namespace oneflow #endif // ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/xpu_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_ #define ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { #if defined(__CUDACC__) #define XPU_1D_KERNEL_LOOP_BEGIN(i, n) CUDA_1D_KERNEL_LOOP(i, n) { #define XPU_1D_KERNEL_LOOP_END() } #else #define XPU_1D_KERNEL_LOOP_BEGIN(i, n) MultiThreadLoop(n, [&](size_t i) { #define XPU_1D_KERNEL_LOOP_END() \ }); #endif #if defined(__CUDACC__) #define XPU_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP(i, n) #else #define XPU_1D_KERNEL_LOOP(i, n) FOR_RANGE(int64_t, i, 0, n) #endif #if defined(__CUDACC__) #define XPU_BLOAD_THREAD_2D_KERNEL_LOOP(i, j, m, n) \ for (int64_t i = blockIdx.x; i < (m); i += gridDim.x) \ for (int64_t j = threadIdx.x; j < (n); j += blockDim.x) #else #define XPU_BLOAD_THREAD_2D_KERNEL_LOOP(i, j, m, n) \ for (int64_t i = 0; i < (m); ++i) \ for (int64_t j = 0; j < (n); ++j) #endif #if defined(__CUDACC__) #define OF_GLOBAL_FUNC __global__ #else #define OF_GLOBAL_FUNC #endif #define GET_SEQ(n) OF_PP_CAT(OF_PP_CAT(GET_SEQ_, n), ) #define GET_SEQ_0 OF_PP_MAKE_TUPLE_SEQ(0) #define GET_SEQ_1 GET_SEQ_0 OF_PP_MAKE_TUPLE_SEQ(1) #define GET_SEQ_2 GET_SEQ_1 OF_PP_MAKE_TUPLE_SEQ(2) #define GET_SEQ_3 GET_SEQ_2 OF_PP_MAKE_TUPLE_SEQ(3) #define GET_SEQ_4 GET_SEQ_3 OF_PP_MAKE_TUPLE_SEQ(4) #define GET_SEQ_5 GET_SEQ_4 OF_PP_MAKE_TUPLE_SEQ(5) #define GET_SEQ_6 GET_SEQ_5 OF_PP_MAKE_TUPLE_SEQ(6) } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_ ================================================ FILE: oneflow/core/ndarray/xpu_var_ndarray.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_ #define ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_ #include "oneflow/core/ndarray/xpu_shape.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/common/util.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/ndarray/xpu_ndarray_base.h" namespace oneflow { template class XpuVarNdarray final : public XpuNdarrayBase, T> { public: XpuVarNdarray(const Blob* blob, int ndims_left_extend_to) : shape_(blob->shape(), ndims_left_extend_to), ptr_(blob->dptr::type>()) {} XpuVarNdarray(Blob* blob, int ndims_left_extend_to) : shape_(blob->shape(), ndims_left_extend_to), ptr_(blob->mut_dptr()) {} XpuVarNdarray(const Shape& shape, T* ptr) : shape_(shape), ptr_(ptr) {} XpuVarNdarray(const ShapeView& shape, T* ptr) : shape_(shape), ptr_(ptr) {} XpuVarNdarray(const ShapeView& shape, T* ptr, int ndims_left_extend_to) : shape_(shape, ndims_left_extend_to), ptr_(ptr) {} ~XpuVarNdarray() = default; ALWAYS_INLINE XpuVarNdarray(const XpuVarNdarray&) = default; OF_DEVICE_FUNC ALWAYS_INLINE XpuVarNdarray(const XpuShape& shape, T* ptr) : shape_(shape), ptr_(ptr) {} const XpuShape& host_shape() const { return shape_; } T* host_ptr() const { return ptr_; } OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; } OF_DEVICE_FUNC T* ptr() const { return ptr_; } template OF_DEVICE_FUNC T Get(int64_t offset) const { return ptr_[offset]; } template OF_DEVICE_FUNC T Get(int64_t coord[NDIMS]) const { return ptr_[shape().template Coordinate2Offset(coord)]; } template OF_DEVICE_FUNC T* Mut(int64_t offset) const { return ptr_ + offset; } template OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const { return ptr_ + shape().template Coordinate2Offset(coord); } template OF_DEVICE_FUNC void Assign(const X& x) const { size_t n = shape_.ElemNum(); XPU_1D_KERNEL_LOOP_BEGIN(i, n); ptr_[i] = x.template Get(i); XPU_1D_KERNEL_LOOP_END(); } template class binary_func, int NDIMS, typename X> OF_DEVICE_FUNC void BinaryAssign(const X& x) const { size_t n = shape_.ElemNum(); XPU_1D_KERNEL_LOOP_BEGIN(i, n); T* ptr_i = ptr_ + i; *ptr_i = binary_func::Invoke(*ptr_i, x.template Get(i)); XPU_1D_KERNEL_LOOP_END(); } private: XpuShape shape_; T* ptr_; }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_ ================================================ FILE: oneflow/core/ndarray/xpu_var_ndarray_builder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
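[Editor's note] XpuVarNdarray above is a non-owning view (an XpuShape plus a raw pointer) whose Assign/BinaryAssign methods walk the flattened element range and pull each value out of an arbitrary expression object via Get(i). A minimal host-only sketch of that pattern, without the XPU loop macros (all type and function names here are illustrative, not the repository API):

// The destination owns nothing; it walks [0, elem_num) and asks the source
// expression for element i.  Any object with a Get(int64_t) method can be a
// source, which is how the lazy unary/binary/broadcast wrappers compose.
#include <cstdint>
#include <iostream>
#include <vector>

template<typename T>
struct VarView {                 // non-owning view over a flat buffer
  int64_t elem_num;
  T* ptr;
  T Get(int64_t i) const { return ptr[i]; }
  template<typename Expr>
  void Assign(const Expr& x) const {
    for (int64_t i = 0; i < elem_num; ++i) { ptr[i] = x.Get(i); }
  }
};

template<typename A, typename B>
struct AddExpr {                 // lazy elementwise sum, evaluated inside Assign
  A a;
  B b;
  double Get(int64_t i) const { return a.Get(i) + b.Get(i); }
};

int main() {
  std::vector<double> x = {1, 2, 3}, y = {10, 20, 30}, out(3);
  VarView<double> vx{3, x.data()}, vy{3, y.data()}, vout{3, out.data()};
  vout.Assign(AddExpr<VarView<double>, VarView<double>>{vx, vy});
  for (double v : out) { std::cout << v << " "; }  // prints "11 22 33"
  return 0;
}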
*/ #ifndef ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_BUILDER_H_ #define ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_BUILDER_H_ #include "oneflow/core/ndarray/xpu_var_ndarray.h" namespace oneflow { template class XpuVarNdarrayBuilder final { public: XpuVarNdarrayBuilder() = default; XpuVarNdarrayBuilder(const XpuVarNdarrayBuilder&) = default; ~XpuVarNdarrayBuilder() = default; XpuVarNdarray operator()(const Shape& shape, T* ptr) const { return XpuVarNdarray(shape, ptr); } template typename std::enable_if::value, XpuVarNdarray
>::type operator()( Blob* blob, int ndims_extend_to) const { return XpuVarNdarray
(blob, ndims_extend_to); } template typename std::enable_if::value, XpuVarNdarray>::type operator()(const Blob* blob, int ndims_extend_to) const { return XpuVarNdarray(blob, ndims_extend_to); } }; } // namespace oneflow #endif // ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_BUILDER_H_ ================================================ FILE: oneflow/core/operator/acc_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/acc_tick_op.h" namespace oneflow { namespace { Maybe InferBlobDescs(const std::function& GetBlobDesc4BnInOp) { *GetBlobDesc4BnInOp("acc") = *GetBlobDesc4BnInOp("one"); GetBlobDesc4BnInOp("acc")->set_shape(Shape({1LL})); return Maybe::Ok(); } } // namespace Maybe AccTickOp::InitFromOpConf() { CHECK(op_conf().has_acc_tick_conf()); EnrollInputBn("one", false); EnrollOutputBn("acc", false); return Maybe::Ok(); } Maybe AccTickOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe AccTickOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe AccTickOp::InferOpTimeShape( const std::function(const std::string&)>& GetTimeShape4BnInOp, std::shared_ptr* time_shape) const { const int32_t max_acc_num = op_conf().acc_tick_conf().max_acc_num(); std::shared_ptr in_shape = JUST(GetTimeShape4BnInOp("one")); CHECK_EQ_OR_RETURN(in_shape->elem_cnt() % max_acc_num, 0); DimVector in_dim_vec = in_shape->dim_vec(); std::shared_ptr op_time_shape; if (in_dim_vec.back() == max_acc_num) { in_dim_vec.pop_back(); op_time_shape.reset(new Shape(in_dim_vec)); } else if (in_dim_vec.back() % max_acc_num == 0) { in_dim_vec.back() /= max_acc_num; op_time_shape.reset(new Shape(in_dim_vec)); } else { op_time_shape.reset(new Shape({in_shape->elem_cnt() / max_acc_num})); } *time_shape = op_time_shape; return Maybe::Ok(); } Maybe AccTickOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_OP(OperatorConf::kAccTickConf, AccTickOp); REGISTER_TICK_TOCK_OP(OperatorConf::kAccTickConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/acc_tick_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
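[Editor's note] AccTickOp::InferOpTimeShape above folds the input time shape by max_acc_num: if the innermost time dimension equals max_acc_num it is dropped, if it is a multiple it is divided, and otherwise the whole shape collapses to a single dimension of elem_cnt / max_acc_num. A standalone sketch of that arithmetic (the helper name FoldTimeShape is illustrative, not the repository API):

// Sketch of the time-shape folding rule used by AccTickOp::InferOpTimeShape.
#include <cstdint>
#include <vector>

std::vector<int64_t> FoldTimeShape(std::vector<int64_t> dims, int64_t max_acc_num) {
  int64_t elem_cnt = 1;
  for (int64_t d : dims) { elem_cnt *= d; }
  // The op requires elem_cnt % max_acc_num == 0 (CHECK_EQ_OR_RETURN above).
  if (!dims.empty() && dims.back() == max_acc_num) {
    dims.pop_back();                         // e.g. (8, 4) with max_acc_num 4 -> (8)
  } else if (!dims.empty() && dims.back() % max_acc_num == 0) {
    dims.back() /= max_acc_num;              // e.g. (8, 12) with max_acc_num 4 -> (8, 3)
  } else {
    dims.assign(1, elem_cnt / max_acc_num);  // fall back to a flat time shape
  }
  return dims;
}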
*/
#ifndef ONEFLOW_CORE_OPERATOR_ACC_TICK_OP_H_
#define ONEFLOW_CORE_OPERATOR_ACC_TICK_OP_H_

#include "oneflow/core/operator/operator.h"

namespace oneflow {

class AccTickOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccTickOp);
  AccTickOp() = default;
  ~AccTickOp() = default;

  Maybe<void> InitFromOpConf() override;
  Maybe<void> InferLogicalOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
      const ParallelDesc& parallel_desc) const override;
  Maybe<void> InferOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
      const ParallelContext* parallel_ctx) const override;
  Maybe<void> InferOpTimeShape(
      const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,
      std::shared_ptr<const Shape>* time_shape) const override;
  Maybe<void> GetSbpSignatures(
      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
      SbpSignatureList* sbp_sig_list) const override;

 private:
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_OPERATOR_ACC_TICK_OP_H_

================================================ FILE: oneflow/core/operator/arg_modifier_signature.proto ================================================
syntax = "proto2";

package oneflow;

message InputBlobModifier {
  optional bool is_mutable = 1 [default = false];
  optional bool requires_grad = 3 [default = false];
}

message OutputBlobModifier {
  optional bool is_mutable = 1 [default = false];
  optional bool requires_grad = 2 [default = false];
  optional bool header_infered_before_compute = 3 [default = true];
  oneof inplace_type {
    string mutable_inplace_ibn = 20;
    string const_inplace_ibn = 21;
  }
}

message ArgModifierSignature {
  map<string, InputBlobModifier> ibn2input_blob_modifier = 1;
  map<string, OutputBlobModifier> obn2output_blob_modifier = 2;
}

================================================ FILE: oneflow/core/operator/assign_op.cpp ================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/operator/operator.h" namespace oneflow { class AssignOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(AssignOp); AssignOp() = default; ~AssignOp() override = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; }; Maybe AssignOp::InitFromOpConf() { CHECK(op_conf().has_assign_conf()); EnrollInputBn("ref")->set_is_mutable(true); EnrollInputBn("value"); return Maybe::Ok(); } std::string DebugString(const BlobDesc& blob_desc) { BlobDescProto blob_desc_proto; blob_desc.ToProto(&blob_desc_proto); return blob_desc_proto.DebugString(); } namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { CHECK_OR_RETURN(*BlobDesc4BnInOp("ref") == *BlobDesc4BnInOp("value")) << "\nref_blob_desc: " << DebugString(*BlobDesc4BnInOp("ref")) << "\nvalue_blob_desc: " << DebugString(*BlobDesc4BnInOp("value")); return Maybe::Ok(); } } // namespace Maybe AssignOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe AssignOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe AssignOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { SbpSignatureBuilder() .Split(input_bns(), 0) .MakeSplitSignatureListBuilder( JUST(LogicalBlobDesc4Ibn(input_bns().Get(0))).shape().NumAxes()) .Build(sbp_sig_list); return Maybe::Ok(); } REGISTER_OP(OperatorConf::kAssignConf, AssignOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/boxing_identity_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { class BoxingIdentityOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(BoxingIdentityOp); BoxingIdentityOp() = default; ~BoxingIdentityOp() override = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { UNIMPLEMENTED_THEN_RETURN(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; Maybe BoxingIdentityOp::InitFromOpConf() { EnrollInputBn("in", false); EnrollOutputBn("out", false); return Maybe::Ok(); } LogicalBlobId BoxingIdentityOp::lbi4ibn(const std::string& input_bn) const { return this->op_conf().boxing_identity_conf().lbi(); } LogicalBlobId BoxingIdentityOp::lbi4obn(const std::string& output_bn) const { return this->op_conf().boxing_identity_conf().lbi(); } Maybe BoxingIdentityOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); return Maybe::Ok(); } REGISTER_OP(OperatorConf::kBoxingIdentityConf, BoxingIdentityOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/boxing_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/boxing_op.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { namespace { void EraseEmptyBnInVec(const std::function& GetBlobDesc4BnInOp, PbRpf* bns) { size_t idx_available = 0; for (size_t i = 0; i < bns->size(); ++i) { if (GetBlobDesc4BnInOp((*bns)[i])) { if (i != idx_available) { (*bns)[idx_available] = (*bns)[i]; } ++idx_available; } } bns->erase(bns->begin() + idx_available, bns->end()); } } // namespace void BoxingOp::VirtualGenKernelConf( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { OpAttribute* op_attribute = kernel_conf->mutable_op_attribute(); EraseEmptyBnInVec(GetBlobDesc4BnInOp, op_attribute->mutable_input_bns()); EraseEmptyBnInVec(GetBlobDesc4BnInOp, op_attribute->mutable_output_bns()); } Maybe BoxingOp::InitFromOpConf() { CHECK(op_conf().has_boxing_conf()); const BoxingOpConf& boxing_conf = op_conf().boxing_conf(); for (int32_t i = 0; i < boxing_conf.in_num(); ++i) { EnrollInputBn("in_" + std::to_string(i), false); } if (boxing_conf.in_box_case() == BoxingOpConf::kAddBox && boxing_conf.out_box_case() == BoxingOpConf::kSplitBox) { EnrollTmpBn("middle"); } for (int32_t i = 0; i < boxing_conf.out_num(); ++i) { EnrollOutputBn("out_" + std::to_string(i), false); } return Maybe::Ok(); } LogicalBlobId BoxingOp::lbi4ibn(const std::string& input_bn) const { return op_conf().boxing_conf().lbi(); } LogicalBlobId BoxingOp::lbi4obn(const std::string& output_bn) const { return op_conf().boxing_conf().lbi(); } Symbol BoxingOp::GetOpConfWithoutOpNameAndLbn() const { OperatorConf op_conf(this->op_conf()); op_conf.set_name("undefined-op-name"); CHECK(op_conf.has_boxing_conf()); auto* boxing_conf = op_conf.mutable_boxing_conf(); LogicalBlobId empty_logical_blob_id; *boxing_conf->mutable_lbi() = empty_logical_blob_id; return SymbolOf(op_conf); } Maybe BoxingOp::InferBlobDescs( const std::function& BlobDesc4BnInOp, bool is_logical) const { const BoxingOpConf& conf = op_conf().boxing_conf(); BlobDesc* first_in_blob = BlobDesc4BnInOp(input_bns().Get(0)); if (conf.in_box_case() == BoxingOpConf::kAddBox) { const Shape& first_in_blob_shape = first_in_blob->shape(); for (const std::string& ibn : input_bns()) { CHECK_EQ_OR_RETURN(first_in_blob_shape, BlobDesc4BnInOp(ibn)->shape()); } } DimVector data_tmp_blob_shape_vec = BlobDesc4BnInOp(input_bns().Get(0))->shape().dim_vec(); JUST(InferTmpBlobDesc(BlobDesc4BnInOp, &data_tmp_blob_shape_vec, is_logical)); if (conf.out_box_case() == BoxingOpConf::kSplitBox) { const BoxSplitConf& split_conf = conf.split_box(); CHECK_GE_OR_RETURN(split_conf.axis(), 0); CHECK_LT_OR_RETURN(split_conf.axis(), data_tmp_blob_shape_vec.size()); FOR_RANGE(size_t, i, 0, output_bns().size()) { BlobDesc* out_blob_desc = BlobDesc4BnInOp(output_bns().Get(i)); *out_blob_desc = *first_in_blob; CHECK_GT_OR_RETURN(split_conf.part_num(i), 0); data_tmp_blob_shape_vec[split_conf.axis()] = split_conf.part_num(i); out_blob_desc->set_shape(Shape(data_tmp_blob_shape_vec)); } } else if (conf.out_box_case() == BoxingOpConf::kCloneBox) { for (const std::string& obn : output_bns()) { BlobDesc* out_blob_desc = BlobDesc4BnInOp(obn); *out_blob_desc = *first_in_blob; out_blob_desc->set_shape(Shape(data_tmp_blob_shape_vec)); } } else { UNIMPLEMENTED_THEN_RETURN(); } return Maybe::Ok(); } Maybe BoxingOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp, true); } Maybe BoxingOp::InferOutBlobDescs( const 
std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp, false); } Maybe BoxingOp::InferTmpBlobDesc( std::function GetBlobDesc4BnInOp, DimVector* data_tmp_vec_ptr, bool is_logical) const { const BoxingOpConf& conf = op_conf().boxing_conf(); if (conf.in_box_case() == BoxingOpConf::kConcatBox) { int32_t concat_axis = conf.concat_box().axis(); CHECK_GE_OR_RETURN(concat_axis, 0); FOR_RANGE(size_t, ib_idx, 1, input_bns().size()) { const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(ib_idx)); const DimVector& in_blob_shape_vec = in_blob_desc->shape().dim_vec(); CHECK_LT_OR_RETURN(concat_axis, in_blob_shape_vec.size()); FOR_RANGE(size_t, i, 0, in_blob_shape_vec.size()) { if (i == concat_axis) { (*data_tmp_vec_ptr)[i] += in_blob_shape_vec[i]; } else { CHECK_EQ_OR_RETURN((*data_tmp_vec_ptr)[i], in_blob_shape_vec[i]); } } } } CHECK_NE_OR_RETURN(conf.out_box_case(), BoxingOpConf::OUT_BOX_NOT_SET); if (conf.in_box_case() == BoxingOpConf::kAddBox && conf.out_box_case() == BoxingOpConf::kSplitBox) { if (!is_logical) { BlobDesc* data_tmp_blob_desc = GetBlobDesc4BnInOp(SoleTbn()); data_tmp_blob_desc->set_shape(Shape(*data_tmp_vec_ptr)); data_tmp_blob_desc->set_data_type(GetBlobDesc4BnInOp(input_bns().Get(0))->data_type()); } } return Maybe::Ok(); } Maybe BoxingOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { auto* bn2sbp = sbp_signature->mutable_bn_in_op2sbp_parallel(); const SbpParallel& sbp_parallel = JUST(SbpInferHint4Ibn(input_bns().Get(0)))->sbp_parallel(); FOR_RANGE(int32_t, i, 0, input_bns().size()) { CHECK_OR_RETURN(sbp_parallel == JUST(SbpInferHint4Ibn(input_bns().Get(i)))->sbp_parallel()); } (*bn2sbp)[input_bns().Get(0)] = sbp_parallel; (*bn2sbp)[output_bns().Get(0)] = sbp_parallel; return Maybe::Ok(); } REGISTER_OP(OperatorConf::kBoxingConf, BoxingOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/boxing_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
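[Editor's note] BoxingOp::InferTmpBlobDesc above builds the shape of the intermediate blob for a concat-box by summing the inputs along the concat axis and requiring every other axis to match; a split-box then carves that axis back up by part_num(i). A small standalone sketch of the concat-axis accumulation (names are illustrative, not the repository API):

// Start from the first input's shape, add the other inputs' extents on the
// concat axis, and insist that all remaining axes agree.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> ConcatShape(const std::vector<std::vector<int64_t>>& in_shapes,
                                 int32_t concat_axis) {
  std::vector<int64_t> out = in_shapes.at(0);
  for (size_t n = 1; n < in_shapes.size(); ++n) {
    const std::vector<int64_t>& cur = in_shapes[n];
    assert(cur.size() == out.size());
    for (size_t i = 0; i < cur.size(); ++i) {
      if (static_cast<int32_t>(i) == concat_axis) {
        out[i] += cur[i];          // concat axis: extents add up
      } else {
        assert(out[i] == cur[i]);  // all other axes must match
      }
    }
  }
  return out;
}
// e.g. inputs (2, 3) and (2, 5) concatenated on axis 1 give a tmp shape (2, 8).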
*/ #ifndef ONEFLOW_CORE_OPERATOR_BOXING_OP_H_ #define ONEFLOW_CORE_OPERATOR_BOXING_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class BoxingOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(BoxingOp); BoxingOp() = default; ~BoxingOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; protected: void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override; void AddLbi2OutputIndex(const LogicalBlobId& lbi, int32_t output_index) override {} private: Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp, bool is_logical) const; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; Maybe InferTmpBlobDesc(std::function GetBlobDesc4BnInOp, DimVector* data_tmp_vec_ptr, bool is_logical) const; Symbol GetOpConfWithoutOpNameAndLbn() const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_BOXING_OP_H_ ================================================ FILE: oneflow/core/operator/boxing_zeros_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { class BoxingZerosOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(BoxingZerosOp); BoxingZerosOp() = default; ~BoxingZerosOp() override = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { UNIMPLEMENTED_THEN_RETURN(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; Maybe BoxingZerosOp::InitFromOpConf() { EnrollOutputBn("out", false); return Maybe::Ok(); } LogicalBlobId BoxingZerosOp::lbi4ibn(const std::string& input_bn) const { return this->op_conf().boxing_zeros_conf().lbi(); } LogicalBlobId BoxingZerosOp::lbi4obn(const std::string& output_bn) const { return this->op_conf().boxing_zeros_conf().lbi(); } Maybe BoxingZerosOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BoxingZerosOpConf& conf = this->op_conf().boxing_zeros_conf(); BlobDesc* out = GetBlobDesc4BnInOp("out"); out->set_data_type(conf.data_type()); out->set_shape(Shape(conf.shape())); return Maybe::Ok(); } REGISTER_OP(OperatorConf::kBoxingZerosConf, BoxingZerosOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/broadcast_to_compatible_with_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/shape_view.h" namespace oneflow { namespace { Maybe GetBroadcastShape(const Shape& a_shape, const Shape& b_shape, Shape* broadcast_shape) { Shape max_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes())); Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), max_shape.NumAxes()); Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), max_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) { CHECK_OR_RETURN(a_extend_shape.At(i) == 1 || b_extend_shape.At(i) == 1 || a_extend_shape.At(i) == b_extend_shape.At(i)) << "shape " << a_shape.ToString() << " and shape " << b_shape.ToString() << " are not broadcastable"; max_shape.Set(i, std::max(a_extend_shape.At(i), b_extend_shape.At(i))); } *broadcast_shape = max_shape; return Maybe::Ok(); } Maybe InferBlobDescs(const OperatorConf& op_conf, const std::function& BlobDesc4BnInOp) { int64_t num_compatibles = op_conf.broadcast_to_compatible_with_conf().compatible_size(); const BlobDesc* x_desc = BlobDesc4BnInOp("x"); Shape broadcasted_shape(x_desc->shape()); FOR_RANGE(int64_t, i, 0, num_compatibles) { const BlobDesc* compatible_i = BlobDesc4BnInOp(GenRepeatedBn("compatible", i)); JUST(GetBroadcastShape(broadcasted_shape, compatible_i->shape(), &broadcasted_shape)); } BlobDesc* y_desc = BlobDesc4BnInOp("y"); y_desc->CopyFrom(*x_desc); y_desc->set_shape(broadcasted_shape); return Maybe::Ok(); } } // namespace class BroadcastToCompatibleWithOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastToCompatibleWithOp); BroadcastToCompatibleWithOp() = default; ~BroadcastToCompatibleWithOp() override = default; Maybe InitFromOpConf() override { CHECK(op_conf().has_broadcast_to_compatible_with_conf()); EnrollInputBn("x"); EnrollRepeatedInputBn("compatible", false); EnrollOutputBn("y"); return Maybe::Ok(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { return InferBlobDescs(op_conf(), BlobDesc4BnInOp); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } private: void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override { auto* conf = kernel_conf->mutable_broadcast_to_compatible_with_conf(); const BlobDesc* x_desc = GetBlobDesc4BnInOp("x"); const BlobDesc* y_desc = GetBlobDesc4BnInOp("y"); Shape x_extend_shape = CreateLeftExtendedShape(ShapeView(x_desc->shape()), y_desc->shape().NumAxes()); FOR_RANGE(int64_t, i, 0, y_desc->shape().NumAxes()) { if (x_extend_shape.At(i) == 1 && y_desc->shape().At(i) != 1) conf->mutable_broadcast_axes()->Add(i); } } Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override { Shape broadcasted_shape{1}; for (const std::string& ibn : input_bns()) { const Shape& input_shape = JUST(LogicalBlobDesc4Ibn(ibn)).shape(); JUST(GetBroadcastShape(broadcasted_shape, input_shape, &broadcasted_shape)); } const int64_t broadcast_num_axes = broadcasted_shape.NumAxes(); HashMap ibn2extend_shape; for (const std::string& ibn : input_bns()) { const Shape& input_shape = JUST(LogicalBlobDesc4Ibn(ibn)).shape(); CHECK_OR_RETURN( ibn2extend_shape .emplace(ibn, CreateLeftExtendedShape(ShapeView(input_shape), broadcast_num_axes)) .second); } FOR_RANGE(int64_t, i, 0, 
broadcast_num_axes) { if (broadcasted_shape.At(i) == 1) { continue; } SbpSignature sbp_sig; for (const auto& pair : ibn2extend_shape) { if (pair.second.At(i) == 1) { (*sbp_sig.mutable_bn_in_op2sbp_parallel())[pair.first].mutable_broadcast_parallel(); } else { (*sbp_sig.mutable_bn_in_op2sbp_parallel())[pair.first].mutable_split_parallel()->set_axis( i - (broadcast_num_axes - pair.second.NumAxes())); } } (*sbp_sig.mutable_bn_in_op2sbp_parallel())["y"].mutable_split_parallel()->set_axis(i); *sbp_sig_list->mutable_sbp_signature()->Add() = sbp_sig; } PbRpf compatible_bns; int64_t num_compatibles = op_conf().broadcast_to_compatible_with_conf().compatible_size(); FOR_RANGE(int64_t, i, 0, num_compatibles) { *compatible_bns.Add() = GenRepeatedBn("compatible", i); } SbpSignatureBuilder() .PartialSum("x") .Broadcast(compatible_bns) .PartialSum("y") .Build(sbp_sig_list->mutable_sbp_signature()->Add()); SbpSignatureBuilder() .Broadcast("x") .PartialSum(compatible_bns) .Broadcast("y") .Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } }; REGISTER_OP(OperatorConf::kBroadcastToCompatibleWithConf, BroadcastToCompatibleWithOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/callback_notify_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/callback_notify_op.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { Maybe CallbackNotifyOp::InitFromOpConf() { CHECK(op_conf().has_callback_notify_conf()); EnrollInputBn("in", false); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { CHECK_OR_RETURN(BlobDesc4BnInOp("in")->shape() == Shape({1})); CHECK_OR_RETURN(IsIntegralDataType(BlobDesc4BnInOp("in")->data_type())); return Maybe::Ok(); } } // namespace Maybe CallbackNotifyOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); return InferBlobDescs(BlobDesc4BnInOp); } Maybe CallbackNotifyOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1); return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe CallbackNotifyOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kCallbackNotifyConf, CallbackNotifyOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/callback_notify_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
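[Editor's note] GetBroadcastShape in broadcast_to_compatible_with_op.cpp above applies the usual NumPy-style rule: left-extend both shapes with ones to a common rank, then take the axis-wise maximum, rejecting any axis on which neither extent is 1 and they differ. A self-contained sketch of that rule (the helper name BroadcastShape is illustrative, not the repository function):

// Shapes are left-extended with 1s to a common rank, then merged axis by axis.
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> BroadcastShape(std::vector<int64_t> a, std::vector<int64_t> b) {
  const size_t rank = std::max(a.size(), b.size());
  a.insert(a.begin(), rank - a.size(), 1);  // left-extend with ones
  b.insert(b.begin(), rank - b.size(), 1);
  std::vector<int64_t> out(rank);
  for (size_t i = 0; i < rank; ++i) {
    if (a[i] != 1 && b[i] != 1 && a[i] != b[i]) {
      throw std::runtime_error("shapes are not broadcastable");
    }
    out[i] = std::max(a[i], b[i]);  // the non-1 extent (or the shared extent) wins
  }
  return out;
}
// e.g. (3, 1, 5) and (4, 5) broadcast to (3, 4, 5).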
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_CALLBACK_NOTIFY_OP_H_ #define ONEFLOW_CORE_OPERATOR_CALLBACK_NOTIFY_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class CallbackNotifyOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CallbackNotifyOp); CallbackNotifyOp() = default; ~CallbackNotifyOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_CALLBACK_NOTIFY_OP_H_ ================================================ FILE: oneflow/core/operator/case_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/case_op.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { Maybe CaseOp::InitFromOpConf() { EnrollInputBn("in", false); EnrollRepeatedOutputBn("out", false); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const Operator& op, const std::function& BlobDesc4BnInOp) { const BlobDesc* in = BlobDesc4BnInOp("in"); CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), 1); const DataType data_type = in->data_type(); CHECK_OR_RETURN(IsIntegralDataType(data_type)); for (const std::string& obn : op.output_bns()) { BlobDesc* out = BlobDesc4BnInOp(obn); out->set_shape(Shape({1})); out->set_data_type(data_type); } return Maybe::Ok(); } } // namespace Maybe CaseOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(*this, BlobDesc4BnInOp); } Maybe CaseOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(*this, GetBlobDesc4BnInOp); } Maybe CaseOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kCaseConf, CaseOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/case_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_CASE_OP_H_ #define ONEFLOW_CORE_OPERATOR_CASE_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class CaseOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CaseOp); CaseOp() = default; ~CaseOp() override = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_CASE_OP_H_ ================================================ FILE: oneflow/core/operator/collective_boxing_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" namespace oneflow { using namespace boxing::collective; class CollectiveBoxingGenericOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericOp); CollectiveBoxingGenericOp() = default; ~CollectiveBoxingGenericOp() override = default; private: Maybe InitFromOpConf() override { CHECK(op_conf().has_collective_boxing_generic_conf()); const RankDesc& rank_desc = op_conf().collective_boxing_generic_conf().rank_desc(); if (GenericOpHasInput(rank_desc)) { EnrollInputBn("in", false); } if (GenericOpHasOutput(rank_desc)) { EnrollOutputBn("out", false); } return Maybe::Ok(); } LogicalBlobId lbi4ibn(const std::string& input_bn) const override { return this->op_conf().collective_boxing_generic_conf().lbi(); } LogicalBlobId lbi4obn(const std::string& output_bn) const override { return this->op_conf().collective_boxing_generic_conf().lbi(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { UNIMPLEMENTED_THEN_RETURN(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { const RankDesc& rank_desc = op_conf().collective_boxing_generic_conf().rank_desc(); const DataType data_type = rank_desc.op_desc().data_type(); if (GenericOpHasInput(rank_desc)) { const BlobDesc* in = GetBlobDesc4BnInOp("in"); CHECK_OR_RETURN(!in->is_dynamic()); CHECK_EQ_OR_RETURN(in->data_type(), data_type); CHECK_EQ_OR_RETURN(in->shape(), GenericOpGetInputShape(rank_desc)); } if (GenericOpHasOutput(rank_desc)) { BlobDesc* out = GetBlobDesc4BnInOp("out"); out->set_data_type(data_type); out->set_shape(GenericOpGetOutputShape(rank_desc)); } return Maybe::Ok(); } }; REGISTER_OP(OperatorConf::kCollectiveBoxingGenericConf, CollectiveBoxingGenericOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/collective_boxing_pack_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { class CollectiveBoxingPackOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackOp); CollectiveBoxingPackOp() = default; ~CollectiveBoxingPackOp() override = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { UNIMPLEMENTED_THEN_RETURN(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; Maybe CollectiveBoxingPackOp::InitFromOpConf() { EnrollInputBn("in", false); EnrollOutputBn("out", false); return Maybe::Ok(); } LogicalBlobId CollectiveBoxingPackOp::lbi4ibn(const std::string& input_bn) const { return this->op_conf().collective_boxing_pack_conf().lbi(); } LogicalBlobId CollectiveBoxingPackOp::lbi4obn(const std::string& output_bn) const { return this->op_conf().collective_boxing_pack_conf().lbi(); } Maybe CollectiveBoxingPackOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); *CHECK_NOTNULL(out_blob_desc) = *CHECK_NOTNULL(in_blob_desc); // NOLINT out_blob_desc->set_shape(Shape({in_blob_desc->shape().elem_cnt()})); return Maybe::Ok(); } REGISTER_OP(OperatorConf::kCollectiveBoxingPackConf, CollectiveBoxingPackOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/collective_boxing_unpack_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { class CollectiveBoxingUnpackOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackOp); CollectiveBoxingUnpackOp() = default; ~CollectiveBoxingUnpackOp() override = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { UNIMPLEMENTED_THEN_RETURN(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; Maybe CollectiveBoxingUnpackOp::InitFromOpConf() { EnrollInputBn("in", false); EnrollOutputBn("out", false); return Maybe::Ok(); } LogicalBlobId CollectiveBoxingUnpackOp::lbi4ibn(const std::string& input_bn) const { return this->op_conf().collective_boxing_unpack_conf().lbi(); } LogicalBlobId CollectiveBoxingUnpackOp::lbi4obn(const std::string& output_bn) const { return this->op_conf().collective_boxing_unpack_conf().lbi(); } Maybe CollectiveBoxingUnpackOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf(); const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); *out_blob_desc = *in_blob_desc; Shape out_shape(unpack_conf.logical_shape()); if (unpack_conf.dst_sbp_parallel().has_split_parallel()) { const int64_t dst_split_axis = unpack_conf.dst_sbp_parallel().split_parallel().axis(); out_shape.Set(dst_split_axis, out_shape.At(dst_split_axis) / unpack_conf.num_ranks()); } CHECK_EQ_OR_RETURN(out_shape.elem_cnt(), in_blob_desc->shape().elem_cnt()); out_blob_desc->set_shape(out_shape); return Maybe::Ok(); } REGISTER_OP(OperatorConf::kCollectiveBoxingUnpackConf, CollectiveBoxingUnpackOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/constant_like_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { Maybe InferBlobDescs(const OperatorConf& op_conf, const std::function& BlobDesc4BnInOp) { const ConstantLikeOpConf& conf = op_conf.constant_like_conf(); BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); *out_blob_desc = *BlobDesc4BnInOp("like"); if (conf.has_data_type()) { out_blob_desc->set_data_type(conf.data_type()); } return Maybe::Ok(); } } // namespace class ConstantLikeOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(ConstantLikeOp); ConstantLikeOp() = default; ~ConstantLikeOp() = default; Maybe InitFromOpConf() override { CHECK(op_conf().has_constant_like_conf()); EnrollInputBn("like", false); EnrollOutputBn("out", false); return Maybe::Ok(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { return InferBlobDescs(op_conf(), BlobDesc4BnInOp); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split("like", 0) .Split("out", 0) .MakeSplitSignatureListBuilder(JUST(LogicalBlobDesc4Ibn("like")).shape().NumAxes()) .Build(sbp_sig_list); SbpSignatureBuilder().PartialSum("like").Broadcast("out").Build( sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } }; REGISTER_OP(OperatorConf::kConstantLikeConf, ConstantLikeOp); REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kConstantLikeConf, 1); } // namespace oneflow ================================================ FILE: oneflow/core/operator/copy_comm_net_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/copy_comm_net_op.h" namespace oneflow { Maybe CopyCommNetOp::InitFromOpConf() { EnrollInputBn("in", false); EnrollOutputBn("out", false); return Maybe::Ok(); } LogicalBlobId CopyCommNetOp::lbi4obn(const std::string& output_bn) const { return this->op_conf().copy_comm_net_conf().lbi(); } LogicalBlobId CopyCommNetOp::lbi4ibn(const std::string& input_bn) const { return this->op_conf().copy_comm_net_conf().lbi(); } REGISTER_OP(OperatorConf::kCopyCommNetConf, CopyCommNetOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/copy_comm_net_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_COPY_COMM_NET_OP_H_ #define ONEFLOW_CORE_OPERATOR_COPY_COMM_NET_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class CopyCommNetOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CopyCommNetOp); CopyCommNetOp() = default; ~CopyCommNetOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { UNIMPLEMENTED_THEN_RETURN(); } private: LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_COPY_COMM_NET_OP_H_ ================================================ FILE: oneflow/core/operator/critical_section_callback_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { BlobDesc* blob_desc = BlobDesc4BnInOp("out"); blob_desc->set_shape(Shape({1})); blob_desc->set_data_type(DataType::kInt8); return Maybe::Ok(); } } // namespace class CriticalSectionCallbackTickOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CriticalSectionCallbackTickOp); CriticalSectionCallbackTickOp() = default; ~CriticalSectionCallbackTickOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; }; Maybe CriticalSectionCallbackTickOp::InitFromOpConf() { CHECK(op_conf().has_critical_section_callback_tick_conf()); EnrollRepeatedInputBn("tick", false); EnrollOutputBn("out", false); return Maybe::Ok(); } Maybe CriticalSectionCallbackTickOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe CriticalSectionCallbackTickOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe CriticalSectionCallbackTickOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kCriticalSectionCallbackTickConf, 128); REGISTER_OP(OperatorConf::kCriticalSectionCallbackTickConf, CriticalSectionCallbackTickOp); REGISTER_TICK_TOCK_OP(OperatorConf::kCriticalSectionCallbackTickConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/critical_section_wait_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { BlobDesc* blob_desc = BlobDesc4BnInOp("out"); blob_desc->set_shape(Shape({1})); blob_desc->set_data_type(DataType::kInt8); return Maybe::Ok(); } } // namespace class CriticalSectionWaitTickOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickOp); CriticalSectionWaitTickOp() = default; ~CriticalSectionWaitTickOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; }; Maybe CriticalSectionWaitTickOp::InitFromOpConf() { CHECK_OR_RETURN(op_conf().has_critical_section_wait_tick_conf()); EnrollRepeatedInputBn("tick", false); EnrollOutputBn("out", false); return Maybe::Ok(); } Maybe CriticalSectionWaitTickOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe CriticalSectionWaitTickOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe CriticalSectionWaitTickOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kCriticalSectionWaitTickConf, 2); REGISTER_OP(OperatorConf::kCriticalSectionWaitTickConf, CriticalSectionWaitTickOp); REGISTER_TICK_TOCK_OP(OperatorConf::kCriticalSectionWaitTickConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/cwise_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/cwise_op.h" namespace oneflow { Maybe CWiseOp::InitFromOpConf() { EnrollRepeatedInputBn("in"); EnrollOutputBn("out")->set_mutable_inplace_ibn("in_0"); VirtualInitFromOpConf(); return Maybe::Ok(); } Maybe CWiseOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { const BlobDesc* in_0_blob_desc = BlobDesc4BnInOp(input_bns().Get(0)); for (size_t i = 1; i < input_bns().size(); ++i) { const auto* blob_desc = BlobDesc4BnInOp(input_bns().Get(i)); CHECK_OR_RETURN(*in_0_blob_desc == *blob_desc); } *BlobDesc4BnInOp("out") = *in_0_blob_desc; return VirtualInferBlobDescs(BlobDesc4BnInOp, nullptr); } Maybe CWiseOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* in_0_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(0)); for (size_t i = 1; i < input_bns().size(); ++i) { const auto* blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i)); CHECK_OR_RETURN(*in_0_blob_desc == *blob_desc); } *GetBlobDesc4BnInOp("out") = *in_0_blob_desc; return VirtualInferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx); } } // namespace oneflow ================================================ FILE: oneflow/core/operator/cwise_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_CWISE_OP_H_ #define ONEFLOW_CORE_OPERATOR_CWISE_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class CWiseOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(CWiseOp); CWiseOp() = default; virtual ~CWiseOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; protected: virtual void VirtualInitFromOpConf() { UNIMPLEMENTED(); } virtual Maybe VirtualInferBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return Maybe::Ok(); } }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_CWISE_OP_H_ ================================================ FILE: oneflow/core/operator/decode_random_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_OPERATOR_DECODE_RANDOM_OP_H_
#define ONEFLOW_CORE_OPERATOR_DECODE_RANDOM_OP_H_

#include "oneflow/core/operator/operator.h"

namespace oneflow {

class DecodeRandomOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DecodeRandomOp);
  DecodeRandomOp() = default;
  ~DecodeRandomOp() = default;

  Maybe<void> InitFromOpConf() override;
  Maybe<void> InferLogicalOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
      const ParallelDesc& parallel_desc) const override;
  Maybe<void> InferOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
      const ParallelContext* parallel_ctx) const override;

 private:
  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;
  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                            const ParallelContext* parallel_ctx,
                            KernelConf* kernel_conf) const override;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_OPERATOR_DECODE_RANDOM_OP_H_

================================================
FILE: oneflow/core/operator/device_tick_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
*/
#include "oneflow/core/operator/device_tick_op.h"
#include "oneflow/core/job/sbp_signature_builder.h"

namespace oneflow {

Maybe<void> DeviceTickOp::InitFromOpConf() {
  CHECK(op_conf().has_device_tick_conf());
  EnrollRepeatedInputBn("tick", false);
  EnrollOutputBn("out", false);
  return Maybe<void>::Ok();
}

namespace {

Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
  BlobDesc* blob_desc = BlobDesc4BnInOp("out");
  blob_desc->set_shape(Shape({1}));
  blob_desc->set_data_type(DataType::kInt8);
  return Maybe<void>::Ok();
}

}  // namespace

Maybe<void> DeviceTickOp::InferLogicalOutBlobDescs(
    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
    const ParallelDesc& parallel_desc) const {
  return InferBlobDescs(BlobDesc4BnInOp);
}

Maybe<void> DeviceTickOp::InferOutBlobDescs(
    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
    const ParallelContext* parallel_ctx) const {
  return InferBlobDescs(GetBlobDesc4BnInOp);
}

Maybe<void> DeviceTickOp::GetSbpSignatures(
    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
    SbpSignatureList* sbp_sig_list) const {
  return Maybe<void>::Ok();
}

Maybe<void> DeviceTickOp::InferOpTimeShape(
    const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,
    std::shared_ptr<const Shape>* time_shape) const {
  std::shared_ptr<const Shape> in_time_shape;
  for (const auto& bn : input_bns()) {
    std::shared_ptr<const Shape> ts = JUST(GetTimeShape4BnInOp(bn));
    if (!in_time_shape) {
      in_time_shape = ts;
    } else {
      CHECK_OR_RETURN(*in_time_shape == *ts);
    }
  }
  if (this->op_conf().device_tick_conf().has_time_shape()) {
    if (!in_time_shape) {
      in_time_shape.reset(new Shape(this->op_conf().device_tick_conf().time_shape()));
    } else {
      CHECK_OR_RETURN(in_time_shape->elem_cnt()
                      == Shape(this->op_conf().device_tick_conf().time_shape()).elem_cnt());
    }
  }
  if (in_time_shape) {
    *time_shape = in_time_shape;
  } else {
    *time_shape = std::make_shared<const Shape>(Shape({1, 1}));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP(OperatorConf::kDeviceTickConf, DeviceTickOp);
REGISTER_TICK_TOCK_OP(OperatorConf::kDeviceTickConf);

}  // namespace oneflow
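// Illustrative sketch (an assumption for exposition, not a file in this repository): the
// tick-style operators in this directory (DeviceTickOp above, CriticalSectionWaitTickOp,
// DstSubsetTickOp, ...) all follow the same minimal shape -- enroll blob names in
// InitFromOpConf(), emit a one-element kInt8 "out" blob during shape inference, and bind the
// class to its OperatorConf case with REGISTER_OP. The names MyTickLikeOp and kMyTickLikeConf
// below are hypothetical placeholders and do not exist in oneflow.
class MyTickLikeOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MyTickLikeOp);
  MyTickLikeOp() = default;
  ~MyTickLikeOp() override = default;

  Maybe<void> InitFromOpConf() override {
    EnrollRepeatedInputBn("tick", false);  // ordering-only inputs: no payload, no gradient
    EnrollOutputBn("out", false);
    return Maybe<void>::Ok();
  }

  Maybe<void> InferLogicalOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
      const ParallelDesc& parallel_desc) const override {
    BlobDesc* out = BlobDesc4BnInOp("out");
    out->set_shape(Shape({1}));  // a tick carries a control dependency, not data
    out->set_data_type(DataType::kInt8);
    return Maybe<void>::Ok();
  }

  Maybe<void> InferOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
      const ParallelContext* parallel_ctx) const override {
    BlobDesc* out = GetBlobDesc4BnInOp("out");
    out->set_shape(Shape({1}));
    out->set_data_type(DataType::kInt8);
    return Maybe<void>::Ok();
  }
};
// REGISTER_OP(OperatorConf::kMyTickLikeConf, MyTickLikeOp);  // kMyTickLikeConf is hypothetical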
================================================ FILE: oneflow/core/operator/device_tick_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_DEVICE_TICK_OP_H_ #define ONEFLOW_CORE_OPERATOR_DEVICE_TICK_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class DeviceTickOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(DeviceTickOp); DeviceTickOp() = default; ~DeviceTickOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; Maybe InferOpTimeShape( const std::function(const std::string&)>& GetTimeShape4BnInOp, std::shared_ptr* time_shape) const override; private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_DEVICE_TICK_OP_H_ ================================================ FILE: oneflow/core/operator/distribute_add_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/job/scope.h" namespace oneflow { class DistributeAddOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(DistributeAddOp); DistributeAddOp() = default; ~DistributeAddOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe InferBlobParallelDesc() override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; }; Maybe DistributeAddOp::InitFromOpConf() { CHECK(op_conf().has_distribute_add_conf()); EnrollRepeatedInputBn("in"); EnrollOutputBn("out"); return Maybe::Ok(); } Maybe DistributeAddOp::InferBlobParallelDesc() { HashMap> bn2parallel_desc; const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); FOR_RANGE(int, i, 0, input_bns().size()) { bn2parallel_desc[input_bns().Get(i)] = std::make_shared(op_parallel_desc->GetParallelIdOnlyParallelConf(i)); } bn2parallel_desc["out"] = op_parallel_desc; JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { auto it = bn2parallel_desc.find(bn); CHECK_OR_RETURN(it != bn2parallel_desc.end()); return it->second; })); return Maybe::Ok(); } Maybe DistributeAddOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { const BlobDesc* in_0 = BlobDesc4BnInOp(input_bns().Get(0)); FOR_RANGE(int, i, 1, output_bns().size()) { const BlobDesc* in_i = BlobDesc4BnInOp(input_bns().Get(i)); CHECK_OR_RETURN(*in_i == *in_0); } *BlobDesc4BnInOp("out") = *in_0; return Maybe::Ok(); } Maybe DistributeAddOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* first_blob_desc = nullptr; FOR_RANGE(int, i, 0, input_bns().size()) { first_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i)); if (first_blob_desc != nullptr) { break; } } CHECK_NOTNULL(first_blob_desc); *GetBlobDesc4BnInOp("out") = *first_blob_desc; return Maybe::Ok(); } Maybe DistributeAddOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), input_bns().size()); const auto& first_in_hint = *JUST(SbpInferHint4Ibn(input_bns().Get(0))); FOR_RANGE(int, i, 0, input_bns().size()) { const auto& in_sbp_infer_hint = *JUST(SbpInferHint4Ibn(input_bns().Get(i))); CHECK_EQ_OR_RETURN(1, in_sbp_infer_hint.parallel_desc().parallel_num()); CHECK_EQ_OR_RETURN(first_in_hint.logical_blob_desc().shape(), in_sbp_infer_hint.logical_blob_desc().shape()); } auto* bn2sbp = sbp_signature->mutable_bn_in_op2sbp_parallel(); for (const auto& ibn : input_bns()) { (*bn2sbp)[ibn].mutable_partial_sum_parallel(); } (*bn2sbp)["out"].mutable_partial_sum_parallel(); return Maybe::Ok(); } REGISTER_OP(OperatorConf::kDistributeAddConf, DistributeAddOp); REGISTER_DISABLE_INPUT_BOXING_GROUP(OperatorConf::kDistributeAddConf); } // namespace oneflow ================================================ FILE: 
oneflow/core/operator/distribute_clone_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/job/scope.h" namespace oneflow { class DistributeCloneOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(DistributeCloneOp); DistributeCloneOp() = default; ~DistributeCloneOp() = default; Maybe InitFromOpConf() override; private: Maybe InferBlobParallelDesc() override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; }; Maybe DistributeCloneOp::InitFromOpConf() { CHECK(op_conf().has_distribute_clone_conf()); EnrollInputBn("in"); EnrollRepeatedOutputBnWithSetter("out", [&](OutputBlobModifier* ob_modifier) { ob_modifier->set_is_mutable(op_conf().distribute_clone_conf().is_variable_ref()); }); return Maybe::Ok(); } Maybe DistributeCloneOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { const auto& in_blob_desc = *BlobDesc4BnInOp("in"); FOR_RANGE(int, i, 0, output_bns().size()) { BlobDesc* blob_desc = BlobDesc4BnInOp(output_bns().Get(i)); *blob_desc = in_blob_desc; } return Maybe::Ok(); } Maybe DistributeCloneOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const auto& in_blob_desc = *GetBlobDesc4BnInOp("in"); if (parallel_ctx->parallel_num() > 1) { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), output_bns().size()); auto* out_blob_desc = GetBlobDesc4BnInOp(output_bns().Get(parallel_ctx->parallel_id())); *out_blob_desc = in_blob_desc; return Maybe::Ok(); } FOR_RANGE(int, i, 0, output_bns().size()) { BlobDesc* blob_desc = GetBlobDesc4BnInOp(output_bns().Get(i)); if (blob_desc != nullptr) { *blob_desc = in_blob_desc; } } return Maybe::Ok(); } Maybe DistributeCloneOp::InferBlobParallelDesc() { HashMap> bn2parallel_desc; const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); bn2parallel_desc["in"] = op_parallel_desc; FOR_RANGE(int, i, 0, output_bns().size()) { bn2parallel_desc[output_bns().Get(i)] = std::make_shared(op_parallel_desc->GetParallelIdOnlyParallelConf(i)); } JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { auto it = bn2parallel_desc.find(bn); CHECK_OR_RETURN(it != bn2parallel_desc.end()); return it->second; })); return Maybe::Ok(); } Maybe DistributeCloneOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, 
std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), output_bns().size()); const SbpInferHint& in_hint = *JUST(SbpInferHint4Ibn("in")); CHECK_OR_RETURN(in_hint.parallel_desc() == parallel_desc); SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_signature); auto* bn2sbp = sbp_signature->mutable_bn_in_op2sbp_parallel(); (*bn2sbp)["in"].mutable_broadcast_parallel(); return Maybe::Ok(); } REGISTER_OP(OperatorConf::kDistributeCloneConf, DistributeCloneOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/distribute_concat_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/job/scope.h" namespace oneflow { class DistributeConcatOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(DistributeConcatOp); DistributeConcatOp() = default; ~DistributeConcatOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe InferBlobParallelDesc() override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; int32_t FixAxis(const int32_t axis, const int64_t num_axes) const; }; Maybe DistributeConcatOp::InitFromOpConf() { CHECK(op_conf().has_distribute_concat_conf()); EnrollRepeatedInputBn("in"); EnrollOutputBn("out"); return Maybe::Ok(); } Maybe DistributeConcatOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { const auto& conf = op_conf().distribute_concat_conf(); BlobDesc* out = BlobDesc4BnInOp("out"); *out = *BlobDesc4BnInOp(input_bns().Get(0)); const int32_t concat_axis = FixAxis(conf.axis(), out->shape().NumAxes()); int64_t concat_dim_size = out->shape().At(concat_axis); for (size_t i = 1; i < input_bns().size(); ++i) { const BlobDesc* in_i = BlobDesc4BnInOp(input_bns().Get(i)); for (int64_t j = 0; j < in_i->shape().NumAxes(); ++j) { if (j == concat_axis) { concat_dim_size += in_i->shape().At(j); } else { CHECK_EQ_OR_RETURN(out->shape().At(j), in_i->shape().At(j)); } } CHECK_EQ_OR_RETURN(in_i->data_type(), out->data_type()); } Shape output = out->shape(); output.Set(concat_axis, concat_dim_size); out->set_shape(output); out->set_is_dynamic(false); return Maybe::Ok(); } Maybe 
DistributeConcatOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { if (parallel_ctx->parallel_num() > 1) { const auto* in_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(parallel_ctx->parallel_id())); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); *out_blob_desc = *in_blob_desc; out_blob_desc->set_is_dynamic(false); return Maybe::Ok(); } const auto& conf = op_conf().distribute_concat_conf(); const BlobDesc* first_blob_desc = nullptr; int first_blob_desc_idx = -1; FOR_RANGE(int, i, 0, input_bns().size()) { first_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i)); if (first_blob_desc != nullptr) { first_blob_desc_idx = i; break; } } CHECK_NOTNULL(first_blob_desc); DimVector out_dim_vec = first_blob_desc->shape().dim_vec(); int32_t concat_axis = FixAxis(conf.axis(), out_dim_vec.size()); for (size_t i = 0; i < input_bns().size(); ++i) { const BlobDesc* in_i_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i)); if (in_i_blob_desc == nullptr) { continue; } if (first_blob_desc_idx == i) { continue; } for (int64_t j = 0; j < in_i_blob_desc->shape().NumAxes(); ++j) { if (j == concat_axis) { out_dim_vec[j] += in_i_blob_desc->shape().At(j); } else { CHECK_EQ_OR_RETURN(out_dim_vec[j], in_i_blob_desc->shape().At(j)); } } CHECK_EQ_OR_RETURN(in_i_blob_desc->data_type(), first_blob_desc->data_type()); } BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); *out_blob_desc = *first_blob_desc; out_blob_desc->set_shape(Shape(out_dim_vec)); out_blob_desc->set_is_dynamic(false); return Maybe::Ok(); } Maybe DistributeConcatOp::InferBlobParallelDesc() { HashMap> bn2parallel_desc; const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); FOR_RANGE(int, i, 0, input_bns().size()) { bn2parallel_desc[input_bns().Get(i)] = std::make_shared(op_parallel_desc->GetParallelIdOnlyParallelConf(i)); } bn2parallel_desc["out"] = op_parallel_desc; JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { auto it = bn2parallel_desc.find(bn); CHECK_OR_RETURN(it != bn2parallel_desc.end()); return it->second; })); return Maybe::Ok(); } Maybe DistributeConcatOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), input_bns().size()); auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe { const SbpInferHint* sbp_infer_hint = JUST(SbpInferHint4Ibn(ibn)); return Maybe(sbp_infer_hint->logical_blob_desc()); }; { // check parallel_num and dimention const auto& conf = op_conf().distribute_concat_conf(); const int64_t num_axes = JUST(LogicalBlobDesc4Ibn(input_bns().Get(0))).shape().NumAxes(); const int32_t axis = FixAxis(conf.axis(), num_axes); int64_t dim = 0; FOR_RANGE(int, i, 0, input_bns().size()) { const auto& in_parallel_desc = JUST(SbpInferHint4Ibn(input_bns().Get(i)))->parallel_desc(); CHECK_EQ_OR_RETURN(1, in_parallel_desc.parallel_num()); dim += JUST(LogicalBlobDesc4Ibn(input_bns().Get(i))).shape().At(axis); } BalancedSplitter bs(dim, parallel_desc.parallel_num()); FOR_RANGE(int, i, 0, input_bns().size()) { CHECK_EQ_OR_RETURN(JUST(LogicalBlobDesc4Ibn(input_bns().Get(i))).shape().At(axis), bs.At(i).size()); } } SbpSignatureList sbp_sig_list; JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, &sbp_sig_list)); *sbp_signature = sbp_sig_list.sbp_signature().Get(0); return Maybe::Ok(); } Maybe DistributeConcatOp::GetSbpSignatures( const 
std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { const auto& conf = op_conf().distribute_concat_conf(); const int64_t num_axes = JUST(LogicalBlobDesc4Ibn(input_bns().Get(0))).shape().NumAxes(); const int32_t axis = FixAxis(conf.axis(), num_axes); SbpSignatureBuilder() .Broadcast(input_bns()) .Split(output_bns(), axis) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } int32_t DistributeConcatOp::FixAxis(const int32_t axis, const int64_t num_axes) const { int32_t ret = axis; if (axis < 0) { ret += num_axes; } CHECK_GE(axis, 0); CHECK_LT(axis, num_axes); return ret; } REGISTER_OP(OperatorConf::kDistributeConcatConf, DistributeConcatOp); REGISTER_DISABLE_INPUT_BOXING_GROUP(OperatorConf::kDistributeConcatConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/distribute_split_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/job/scope.h" namespace oneflow { class DistributeSplitOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(DistributeSplitOp); DistributeSplitOp() = default; ~DistributeSplitOp() = default; Maybe InitFromOpConf() override; private: Maybe InferBlobParallelDesc() override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; int32_t FixAxis(const int32_t axis, const int64_t num_axes) const; }; Maybe DistributeSplitOp::InitFromOpConf() { CHECK(op_conf().has_distribute_split_conf()); EnrollInputBn("in"); EnrollRepeatedOutputBnWithSetter("out", [&](OutputBlobModifier* ob_modifier) { ob_modifier->set_header_infered_before_compute(false); ob_modifier->set_is_mutable(op_conf().distribute_split_conf().is_variable_ref()); }); return Maybe::Ok(); } Maybe DistributeSplitOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { const auto& in_blob_desc = *BlobDesc4BnInOp("in"); CHECK_EQ(parallel_desc.parallel_num(), output_bns().size()); const auto& conf = op_conf().distribute_split_conf(); const int32_t split_axis = FixAxis(conf.axis(), in_blob_desc.shape().NumAxes()); BalancedSplitter bs(in_blob_desc.shape().At(split_axis), parallel_desc.parallel_num()); FOR_RANGE(int, i, 0, parallel_desc.parallel_num()) { BlobDesc* 
out_blob_desc = BlobDesc4BnInOp(output_bns().Get(i)); *out_blob_desc = in_blob_desc; Shape output = out_blob_desc->shape(); output.Set(split_axis, bs.At(i).size()); out_blob_desc->set_shape(output); } return Maybe::Ok(); } Maybe DistributeSplitOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const auto& in_blob_desc = *GetBlobDesc4BnInOp("in"); if (parallel_ctx->parallel_num() > 1) { CHECK_EQ(parallel_ctx->parallel_num(), output_bns().size()); auto* out_blob_desc = GetBlobDesc4BnInOp(output_bns().Get(parallel_ctx->parallel_id())); *out_blob_desc = in_blob_desc; return Maybe::Ok(); } const auto& conf = op_conf().distribute_split_conf(); int32_t split_axis = FixAxis(conf.axis(), in_blob_desc.shape().NumAxes()); std::vector out_blob_descs; out_blob_descs.reserve(output_bns().size()); FOR_RANGE(int, i, 0, output_bns().size()) { BlobDesc* blob_desc = GetBlobDesc4BnInOp(output_bns().Get(i)); if (blob_desc != nullptr) { out_blob_descs.emplace_back(blob_desc); } } BalancedSplitter bs(in_blob_desc.shape().At(split_axis), out_blob_descs.size()); FOR_RANGE(int, i, 0, out_blob_descs.size()) { *out_blob_descs.at(i) = in_blob_desc; Shape output = out_blob_descs.at(i)->shape(); // NOLINT output.Set(split_axis, bs.At(i).size()); out_blob_descs.at(i)->set_shape(output); // NOLINT } return Maybe::Ok(); } Maybe DistributeSplitOp::InferBlobParallelDesc() { HashMap> bn2parallel_desc; const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); bn2parallel_desc["in"] = op_parallel_desc; FOR_RANGE(int, i, 0, output_bns().size()) { bn2parallel_desc[output_bns().Get(i)] = std::make_shared(op_parallel_desc->GetParallelIdOnlyParallelConf(i)); } JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { auto it = bn2parallel_desc.find(bn); CHECK_OR_RETURN(it != bn2parallel_desc.end()); return it->second; })); return Maybe::Ok(); } Maybe DistributeSplitOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), output_bns().size()); auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe { const SbpInferHint* sbp_infer_hint = JUST(SbpInferHint4Ibn(ibn)); return Maybe(sbp_infer_hint->logical_blob_desc()); }; SbpSignatureList sbp_sig_list; JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, &sbp_sig_list)); *sbp_signature = sbp_sig_list.sbp_signature().Get(0); return Maybe::Ok(); } Maybe DistributeSplitOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { const auto& conf = op_conf().distribute_split_conf(); const int64_t num_axes = JUST(LogicalBlobDesc4Ibn("in")).shape().NumAxes(); const int32_t axis = FixAxis(conf.axis(), num_axes); SbpSignatureBuilder() .Split(input_bns(), axis) .Broadcast(output_bns()) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } int32_t DistributeSplitOp::FixAxis(const int32_t axis, const int64_t num_axes) const { int32_t ret = axis; if (axis < 0) { ret += num_axes; } CHECK_GE(axis, 0); CHECK_LT(axis, num_axes); return ret; } REGISTER_OP(OperatorConf::kDistributeSplitConf, DistributeSplitOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/dst_subset_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
*/
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/job/sbp_signature_builder.h"

namespace oneflow {

namespace {

Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
  BlobDesc* blob_desc = BlobDesc4BnInOp("out");
  blob_desc->set_shape(Shape({1}));
  blob_desc->set_data_type(DataType::kInt8);
  return Maybe<void>::Ok();
}

}  // namespace

class DstSubsetTickOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DstSubsetTickOp);
  DstSubsetTickOp() = default;
  ~DstSubsetTickOp() = default;

  Maybe<void> InitFromOpConf() override;
  Maybe<void> InferLogicalOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
      const ParallelDesc& parallel_desc) const override;
  Maybe<void> InferOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
      const ParallelContext* parallel_ctx) const override;

 private:
  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;
};

Maybe<void> DstSubsetTickOp::InitFromOpConf() {
  CHECK(op_conf().has_dst_subset_tick_conf());
  EnrollRepeatedInputBn("in", false);
  EnrollOutputBn("out", false);
  return Maybe<void>::Ok();
}

Maybe<void> DstSubsetTickOp::InferLogicalOutBlobDescs(
    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
    const ParallelDesc& parallel_desc) const {
  return InferBlobDescs(BlobDesc4BnInOp);
}

Maybe<void> DstSubsetTickOp::InferOutBlobDescs(
    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
    const ParallelContext* parallel_ctx) const {
  return InferBlobDescs(GetBlobDesc4BnInOp);
}

Maybe<void> DstSubsetTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
  SbpSignatureBuilder()
      .Broadcast(input_bns())
      .Broadcast(output_bns())
      .Build(sbp_sig_list->mutable_sbp_signature()->Add());
  return Maybe<void>::Ok();
}

REGISTER_CPU_OP(OperatorConf::kDstSubsetTickConf, DstSubsetTickOp);
REGISTER_TICK_TOCK_OP(OperatorConf::kDstSubsetTickConf);

}  // namespace oneflow

================================================
FILE: oneflow/core/operator/dynamic_reshape_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
*/ #include "oneflow/core/operator/operator.h" namespace oneflow { class DynamicReshapeOp final : public Operator { public: Maybe InitFromOpConf() override { CHECK(op_conf().has_dynamic_reshape_conf()); EnrollInputBn("in"); EnrollOutputBn("out")->set_const_inplace_ibn("in"); return Maybe::Ok(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { const DynamicReshapeOpConf& conf = op_conf().dynamic_reshape_conf(); const BlobDesc* in = BlobDesc4BnInOp("in"); BlobDesc* out = BlobDesc4BnInOp("out"); *out = *in; DimVector out_dim_vec(conf.shape().dim().begin(), conf.shape().dim().end()); int32_t inferred_axis = -1; int32_t product = 1; for (int32_t i = 0; i < out_dim_vec.size(); ++i) { if (out_dim_vec.at(i) == -1) { CHECK_EQ_OR_RETURN(-1, inferred_axis); inferred_axis = i; } else { CHECK_GT_OR_RETURN(out_dim_vec.at(i), 0); product *= out_dim_vec.at(i); } } if (inferred_axis >= 0) { CHECK_GE_OR_RETURN(product, 1); CHECK_EQ_OR_RETURN(in->shape().elem_cnt() % product, 0); out_dim_vec.at(inferred_axis) = in->shape().elem_cnt() / product; } out->set_shape(Shape(out_dim_vec)); CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), out->shape().elem_cnt()); return Maybe::Ok(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { const auto* sbp_signature = JUST(this->sbp_signature()); const DynamicReshapeOpConf& conf = op_conf().dynamic_reshape_conf(); const BlobDesc* in = GetBlobDesc4BnInOp("in"); BlobDesc* out = GetBlobDesc4BnInOp("out"); *out = *in; DimVector out_dim_vec(conf.shape().dim().begin(), conf.shape().dim().end()); if (parallel_ctx->parallel_num() > 1) { // global strategy // ONLY support sbp: S(0); and -1 must at axis 0 const auto& out_sbp_it = sbp_signature->bn_in_op2sbp_parallel().find("out"); CHECK_OR_RETURN(out_sbp_it != sbp_signature->bn_in_op2sbp_parallel().end()); const SbpParallel& out_sbp = out_sbp_it->second; const auto& in_sbp_it = sbp_signature->bn_in_op2sbp_parallel().find("in"); CHECK_OR_RETURN(in_sbp_it != sbp_signature->bn_in_op2sbp_parallel().end()); const SbpParallel& in_sbp = in_sbp_it->second; if (out_sbp.has_split_parallel()) { CHECK_EQ_OR_RETURN(out_sbp.split_parallel().axis(), 0); CHECK_EQ_OR_RETURN(out_dim_vec.at(0), -1); CHECK_OR_RETURN(in_sbp.has_split_parallel()); CHECK_EQ_OR_RETURN(in_sbp.split_parallel().axis(), 0); } } int32_t inferred_axis = -1; int32_t product = 1; for (int32_t i = 0; i < out_dim_vec.size(); ++i) { if (out_dim_vec.at(i) == -1) { CHECK_EQ_OR_RETURN(-1, inferred_axis); inferred_axis = i; } else { CHECK_GT_OR_RETURN(out_dim_vec.at(i), 0); product *= out_dim_vec.at(i); } } if (inferred_axis >= 0) { CHECK_GE_OR_RETURN(product, 1); CHECK_EQ_OR_RETURN(in->shape().elem_cnt() % product, 0); out_dim_vec.at(inferred_axis) = in->shape().elem_cnt() / product; } out->set_shape(Shape(out_dim_vec)); CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), out->shape().elem_cnt()); return Maybe::Ok(); } private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } }; REGISTER_OP(OperatorConf::kDynamicReshapeConf, DynamicReshapeOp); class DynamicReshapeLikeOp final : public Operator { public: Maybe InitFromOpConf() override { CHECK(op_conf().has_dynamic_reshape_like_conf()); EnrollInputBn("x"); EnrollOutputBn("y"); 
EnrollInputBn("like", false); return Maybe::Ok(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { CHECK_EQ_OR_RETURN(BlobDesc4BnInOp("x")->shape().elem_cnt(), BlobDesc4BnInOp("like")->shape().elem_cnt()); BlobDesc4BnInOp("y")->CopyFrom(*BlobDesc4BnInOp("like")); return Maybe::Ok(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { CHECK_EQ_OR_RETURN(GetBlobDesc4BnInOp("x")->shape().elem_cnt(), GetBlobDesc4BnInOp("like")->shape().elem_cnt()); GetBlobDesc4BnInOp("y")->CopyFrom(*GetBlobDesc4BnInOp("like")); return Maybe::Ok(); } private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } }; REGISTER_OP(OperatorConf::kDynamicReshapeLikeConf, DynamicReshapeLikeOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/esac_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/esac_op.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { Maybe EsacOp::InitFromOpConf() { EnrollRepeatedInputBn("in", false); EnrollOutputBn("out", false); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const OperatorConf& op_conf, const std::function& BlobDesc4BnInOp) { BlobDesc* out = BlobDesc4BnInOp("out"); out->set_shape(Shape({1})); const DataType data_type = op_conf.esac_conf().data_type(); CHECK_OR_RETURN(IsIntegralDataType(data_type)); out->set_data_type(data_type); return Maybe::Ok(); } } // namespace Maybe EsacOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(op_conf(), BlobDesc4BnInOp); } Maybe EsacOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } Maybe EsacOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kEsacConf, EsacOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/esac_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_OPERATOR_ESAC_OP_H_
#define ONEFLOW_CORE_OPERATOR_ESAC_OP_H_

#include "oneflow/core/operator/operator.h"

namespace oneflow {

class EsacOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(EsacOp);
  EsacOp() = default;
  ~EsacOp() override = default;

  Maybe<void> InitFromOpConf() override;
  Maybe<void> InferLogicalOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
      const ParallelDesc& parallel_desc) const override;
  Maybe<void> InferOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
      const ParallelContext* parallel_ctx) const override;

 private:
  Maybe<void> GetSbpSignatures(
      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
      SbpSignatureList* sbp_sig_list) const override;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_OPERATOR_ESAC_OP_H_

================================================
FILE: oneflow/core/operator/identity_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/job/local_sig_infer_hint.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { *BlobDesc4BnInOp("out") = *BlobDesc4BnInOp("in"); return Maybe::Ok(); } } // namespace template class IdentityOpTpl final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(IdentityOpTpl); IdentityOpTpl() = default; ~IdentityOpTpl() override = default; Maybe InitFromOpConf() override { EnrollInputBn("in"); EnrollOutputBn("out")->set_const_inplace_ibn("in"); return Maybe::Ok(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { return InferBlobDescs(BlobDesc4BnInOp); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { return InferBlobDescs(GetBlobDesc4BnInOp); } private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override { const auto bns = StdVec2PbRpf({"in", "out"}); SbpSignatureBuilder().PartialSum(bns).Build(sbp_sig_list->mutable_sbp_signature()->Add()); const int64_t num_axes = JUST(LogicalBlobDesc4Ibn("in")).shape().NumAxes(); SbpSignatureBuilder().Split(bns, 0).MakeSplitSignatureListBuilder(num_axes).Build(sbp_sig_list); return Maybe::Ok(); } }; struct IdentityOp {}; REGISTER_OP(OperatorConf::kIdentityConf, IdentityOpTpl); struct CopyOp {}; REGISTER_OP(OperatorConf::kCopyConf, IdentityOpTpl); class LocalCastOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(LocalCastOp); LocalCastOp() = default; virtual ~LocalCastOp() override = default; Maybe InitFromOpConf() override { EnrollInputBn("in"); EnrollOutputBn("out")->set_const_inplace_ibn("in"); return Maybe::Ok(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { return InferBlobDescs(BlobDesc4BnInOp); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { return InferBlobDescs(GetBlobDesc4BnInOp); } private: }; namespace { class CastToLocalOp : public LocalCastOp { public: OF_DISALLOW_COPY_AND_MOVE(CastToLocalOp); CastToLocalOp() = default; virtual ~CastToLocalOp() override = default; private: Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { BlobDesc* out = BlobDesc4BnInOp("out"); *out = *BlobDesc4BnInOp("in"); const SbpParallel& conf_sbp = SbpParallel(op_conf().cast_to_local_conf().sbp_parallel()); if (conf_sbp.has_split_parallel()) { const int64_t axis = conf_sbp.split_parallel().axis(); CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, out->shape().NumAxes()); const int64_t dim_value = out->shape().At(axis); const int64_t parallel_num = parallel_desc.parallel_num(); CHECK_EQ_OR_RETURN(dim_value % parallel_num, 0); Shape output = out->shape(); output.Set(axis, dim_value / parallel_num); out->set_shape(output); } return Maybe::Ok(); } Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override { CHECK_NE_OR_RETURN(op_conf().cast_to_local_conf().sbp_parallel().parallel_type_case(), SbpParallel::PARALLEL_TYPE_NOT_SET) << "attribute 
sbp_parallel not set."; const auto& ibn_hint = *JUST(SbpInferHint4Ibn("in")); CHECK_EQ_OR_RETURN(ibn_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num()); auto* map = sbp_signature->mutable_bn_in_op2sbp_parallel(); const SbpParallel& conf_sbp = SbpParallel(op_conf().cast_to_local_conf().sbp_parallel()); CHECK_OR_RETURN(ibn_hint.sbp_parallel() == conf_sbp); (*map)["in"] = ibn_hint.sbp_parallel(); (*map)["out"] = conf_sbp; return Maybe::Ok(); } Maybe InferLocalSignature( std::function(const std::string&)> LocalSigInferHint4Ibn, bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) override { const auto& in_infer_hint = *JUST(LocalSigInferHint4Ibn("in")); CHECK_OR_RETURN(!in_infer_hint.is_local_parallel_view()) << "error use of CastToLocalOp. `in' shouldn't be a local blob"; CHECK_EQ_OR_RETURN(in_infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num()); MutOptLocalParallel("in")->clear_local_parallel(); MutOptLocalParallel("out")->mutable_local_parallel(); return Maybe::Ok(); } }; REGISTER_OP(OperatorConf::kCastToLocalConf, CastToLocalOp); } // namespace namespace { class CastFromLocalOp : public LocalCastOp { public: OF_DISALLOW_COPY_AND_MOVE(CastFromLocalOp); CastFromLocalOp() = default; virtual ~CastFromLocalOp() override = default; private: Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { BlobDesc* out = BlobDesc4BnInOp("out"); *out = *BlobDesc4BnInOp("in"); const SbpParallel& conf_sbp = SbpParallel(op_conf().cast_from_local_conf().sbp_parallel()); if (conf_sbp.has_split_parallel()) { const int64_t axis = conf_sbp.split_parallel().axis(); CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, out->shape().NumAxes()); Shape output = out->shape(); output.Set(axis, out->shape().At(axis) * parallel_desc.parallel_num()); out->set_shape(output); } return Maybe::Ok(); } Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override { CHECK_NE_OR_RETURN(op_conf().cast_from_local_conf().sbp_parallel().parallel_type_case(), SbpParallel::PARALLEL_TYPE_NOT_SET) << "attribute sbp_parallel not set."; const auto& ibn_hint = *JUST(SbpInferHint4Ibn("in")); CHECK_EQ_OR_RETURN(ibn_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num()); auto* map = sbp_signature->mutable_bn_in_op2sbp_parallel(); (*map)["in"] = ibn_hint.sbp_parallel(); (*map)["out"] = SbpParallel(op_conf().cast_from_local_conf().sbp_parallel()); return Maybe::Ok(); } Maybe InferLocalSignature( std::function(const std::string&)> LocalSigInferHint4Ibn, bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) override { const auto& in_infer_hint = *JUST(LocalSigInferHint4Ibn("in")); CHECK_OR_RETURN(in_infer_hint.is_local_parallel_view()) << "error use of CastFromLocalOp. `in' should be a local blob"; CHECK_EQ_OR_RETURN(in_infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num()); MutOptLocalParallel("in")->mutable_local_parallel(); MutOptLocalParallel("out")->clear_local_parallel(); return Maybe::Ok(); } }; REGISTER_OP(OperatorConf::kCastFromLocalConf, CastFromLocalOp); } // namespace } // namespace oneflow ================================================ FILE: oneflow/core/operator/image_decoder_random_crop_resize_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/operator.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/job/scope.h" #ifdef WITH_CUDA #include #endif // WITH_CUDA namespace oneflow { namespace { Maybe InferBlobDescs(const OperatorConf& op_conf, const std::function& BlobDesc4BnInOp) { const ImageDecoderRandomCropResizeOpConf& conf = op_conf.image_decoder_random_crop_resize_conf(); const BlobDesc* in = BlobDesc4BnInOp("in"); BlobDesc* out = BlobDesc4BnInOp("out"); CHECK_EQ_OR_RETURN(in->data_type(), DataType::kTensorBuffer); *out = *in; out->set_data_type(DataType::kUInt8); DimVector out_dim_vec = in->shape().dim_vec(); out_dim_vec.emplace_back(conf.target_height()); out_dim_vec.emplace_back(conf.target_width()); out_dim_vec.emplace_back(3); out->set_shape(Shape(out_dim_vec)); return Maybe::Ok(); } } // namespace class ImageDecoderRandomCropResizeOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(ImageDecoderRandomCropResizeOp); ImageDecoderRandomCropResizeOp() = default; ~ImageDecoderRandomCropResizeOp() override = default; private: Maybe InitFromOpConf() override { EnrollInputBn("in", false); EnrollOutputBn("out", false); EnrollTmpBn("tmp"); return Maybe::Ok(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { return InferBlobDescs(this->op_conf(), BlobDesc4BnInOp); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { return InferBlobDescs(this->op_conf(), GetBlobDesc4BnInOp); } Maybe InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override { const ImageDecoderRandomCropResizeOpConf& conf = this->op_conf().image_decoder_random_crop_resize_conf(); BlobDesc* tmp = GetBlobDesc4BnInOp("tmp"); tmp->set_data_type(DataType::kUInt8); tmp->set_shape(Shape({conf.max_num_pixels() * 3 * conf.num_workers()})); return Maybe::Ok(); } Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split("in", 0) .Split("out", 0) .MakeSplitSignatureListBuilder(JUST(LogicalBlobDesc4Ibn("in")).shape().NumAxes()) .Build(sbp_sig_list); return Maybe::Ok(); } void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override { const ImageDecoderRandomCropResizeOpConf& conf = this->op_conf().image_decoder_random_crop_resize_conf(); int64_t seed; if (conf.has_seed()) { seed = conf.seed(); } else { std::random_device rd; seed = rd(); } std::seed_seq seq{seed}; std::vector seeds(parallel_ctx->parallel_num()); seq.generate(seeds.begin(), seeds.end()); kernel_conf->mutable_image_decoder_random_crop_resize_conf()->set_seed( seeds.at(parallel_ctx->parallel_id())); kernel_conf->mutable_image_decoder_random_crop_resize_conf()->set_batch_size( 
GetBlobDesc4BnInOp("in")->shape().elem_cnt()); } Maybe InferBlobParallelDesc() override { HashMap> bn2parallel_desc; const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); bn2parallel_desc["out"] = op_parallel_desc; if (device_type() == DeviceType::kCPU) { bn2parallel_desc["in"] = op_parallel_desc; } else if (device_type() == DeviceType::kCUDA) { std::shared_ptr in_parallel_desc = std::make_shared(*op_parallel_desc); in_parallel_desc->set_device_type(DeviceType::kCPU); bn2parallel_desc["in"] = in_parallel_desc; } else { UNIMPLEMENTED_THEN_RETURN(); } JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { auto it = bn2parallel_desc.find(bn); CHECK_OR_RETURN(it != bn2parallel_desc.end()); return it->second; })); return Maybe::Ok(); } }; #if defined(WITH_CUDA) && CUDA_VERSION >= 10020 REGISTER_OP(OperatorConf::kImageDecoderRandomCropResizeConf, ImageDecoderRandomCropResizeOp); #else REGISTER_CPU_OP(OperatorConf::kImageDecoderRandomCropResizeConf, ImageDecoderRandomCropResizeOp); #endif } // namespace oneflow ================================================ FILE: oneflow/core/operator/input_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/operator/input_op.h" #include "oneflow/core/operator/interface_op_util.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { namespace { Maybe InferInputOpNdSbpSignature(NdSbpSignature* nd_sbp_signature, const ParallelDesc& parallel_desc, const OperatorConf& op_conf) { const auto& parallel_hierarchy = parallel_desc.hierarchy(); const InterfaceBlobConf& blob_conf = op_conf.input_conf().blob_conf(); if (op_conf.input_conf().has_tick()) { NdSbp& tick_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())["tick"]; tick_nd_sbp.clear_sbp_parallel(); FOR_RANGE(int64_t, i, 0, parallel_hierarchy->NumAxes()) { tick_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); } } NdSbp& out_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())["out"]; JUST(InterfaceOpUtil::ParseNdSbpFromBlobConf(blob_conf, parallel_desc, &out_nd_sbp)); return Maybe::Ok(); } } // namespace Maybe InputOp::InitFromOpConf() { CHECK(op_conf().has_input_conf()); if (op_conf().input_conf().has_tick()) { EnrollInputBn("tick", false); } OutputBlobModifier* modifier = EnrollOutputBn("out", false); modifier->set_is_mutable(true); modifier->set_header_infered_before_compute(false); return Maybe::Ok(); } Maybe InputOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); JUST(InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().input_conf().blob_conf(), out_blob_desc, parallel_desc)); return Maybe::Ok(); } Maybe InputOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); 
JUST(InterfaceOpUtil::InferOutBlobDesc(op_conf().input_conf().blob_conf(), out_blob_desc, parallel_ctx, *JUST(GetOpParallelDesc()))); return Maybe::Ok(); } Maybe InputOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { JUST(InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(), output_bns(), sbp_signature)); return Maybe::Ok(); } Maybe InputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { JUST(InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(), output_bns(), sbp_sig_list->mutable_sbp_signature()->Add())); return Maybe::Ok(); } Maybe InputOp::GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { NdSbpSignature nd_sbp_signature; JUST(InferInputOpNdSbpSignature(&nd_sbp_signature, parallel_desc, op_conf())); nd_sbp_sig_list->emplace_back(nd_sbp_signature); return Maybe::Ok(); } Maybe InputOp::InferNdSbpSignature( NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const { JUST(InferInputOpNdSbpSignature(nd_sbp_signature, parallel_desc, op_conf())); return Maybe::Ok(); } Symbol InputOp::GetOpConfWithoutOpNameAndLbn() const { return SymbolOf(this->op_conf()); } REGISTER_OP(OperatorConf::kInputConf, InputOp); REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kInputConf, 1); REGISTER_INTERFACE_OP(OperatorConf::kInputConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/input_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_OPERATOR_INPUT_OP_H_ #define ONEFLOW_CORE_OPERATOR_INPUT_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class InputOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(InputOp); InputOp() : Operator() {} ~InputOp() = default; Maybe InitFromOpConf() override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const override; Symbol GetOpConfWithoutOpNameAndLbn() const override; Maybe InferNdSbpSignature(NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_INPUT_OP_H_ ================================================ FILE: oneflow/core/operator/interface_blob_conf.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/data_type.proto"; import "oneflow/core/job/sbp_parallel.proto"; message InterfaceBlobConf { optional ShapeProto shape = 1; optional DataType data_type = 2; optional bool is_dynamic = 3; optional NdSbp nd_sbp = 4; } ================================================ FILE: oneflow/core/operator/interface_op_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/interface_op_util.h" #include "oneflow/core/common/balanced_splitter.h" namespace oneflow { namespace { void CheckShape(const Shape& shape) { FOR_RANGE(int, i, 1, shape.NumAxes()) { CHECK_GE(shape.At(i), 0); } } Maybe GetSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, const PbRpf& output_bns, SbpSignature* sbp_signature, bool is_for_input_op) { if (!blob_conf.has_nd_sbp()) { SbpSignatureBuilder().Broadcast(input_bns).Broadcast(output_bns).Build(sbp_signature); return Maybe::Ok(); } CHECK_EQ_OR_RETURN(blob_conf.nd_sbp().sbp_parallel_size(), 1); const auto& sbp_parallel = blob_conf.nd_sbp().sbp_parallel(0); if (sbp_parallel.has_split_parallel()) { int64_t num_axes = blob_conf.shape().dim_size(); int64_t split_axis = sbp_parallel.split_parallel().axis(); CHECK_GE_OR_RETURN(split_axis, 0); CHECK_LT_OR_RETURN(split_axis, num_axes); SbpSignatureBuilder sbp_signature_builder; if (is_for_input_op) { // broadcast tick args for InputOp sbp_signature_builder.Broadcast(input_bns); } else { sbp_signature_builder.Split(input_bns, split_axis); } sbp_signature_builder.Split(output_bns, split_axis).Build(sbp_signature); } else { SbpSignatureBuilder().Broadcast(input_bns).Broadcast(output_bns).Build(sbp_signature); } return Maybe::Ok(); } } // namespace Maybe InterfaceOpUtil::InferOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, const ParallelContext* parallel_ctx, const ParallelDesc& parallel_desc) { NdSbp nd_sbp; JUST(ParseNdSbpFromBlobConf(blob_conf, parallel_desc, &nd_sbp)); out_blob_desc->set_shape( *JUST(GetPhysicalShape(Shape(blob_conf.shape()), nd_sbp, parallel_desc, *parallel_ctx))); out_blob_desc->set_data_type(blob_conf.data_type()); out_blob_desc->set_is_dynamic(blob_conf.is_dynamic()); return Maybe::Ok(); } Maybe InterfaceOpUtil::InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, const ParallelDesc& parallel_desc) { CHECK_OR_RETURN(blob_conf.has_shape()); out_blob_desc->set_shape(Shape(blob_conf.shape())); CheckShape(out_blob_desc->shape()); if (out_blob_desc->shape().NumAxes() > 0) { CHECK_GT(out_blob_desc->shape().At(0), 0); } CHECK_OR_RETURN(blob_conf.has_data_type()); out_blob_desc->set_data_type(blob_conf.data_type()); CHECK_OR_RETURN(blob_conf.has_is_dynamic()); out_blob_desc->set_is_dynamic(blob_conf.is_dynamic()); return Maybe::Ok(); } Maybe InterfaceOpUtil::GetInputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, const PbRpf& output_bns, SbpSignature* sbp_signature) { JUST(GetSbpSignature(blob_conf, input_bns, output_bns, sbp_signature, true)); return Maybe::Ok(); } Maybe InterfaceOpUtil::GetOutputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, const PbRpf& output_bns, SbpSignature* sbp_signature) { JUST(GetSbpSignature(blob_conf, input_bns, output_bns, sbp_signature, false)); return Maybe::Ok(); } Maybe InterfaceOpUtil::InitBlobConf(InterfaceBlobConf* blob_conf, const ParallelBlobConf& parallel_blob_conf) { BlobDesc blob_desc(parallel_blob_conf.logical_blob_desc_conf()); blob_desc.shape().ToProto(blob_conf->mutable_shape()); blob_conf->set_data_type(blob_desc.data_type()); blob_conf->set_is_dynamic(blob_desc.is_dynamic()); *blob_conf->mutable_nd_sbp() = parallel_blob_conf.nd_sbp(); return Maybe::Ok(); } Maybe InterfaceOpUtil::ParseNdSbpFromBlobConf(const InterfaceBlobConf& blob_conf, const ParallelDesc& parallel_desc, NdSbp* nd_sbp) { const int64_t num_axes = parallel_desc.hierarchy()->NumAxes(); if 
(blob_conf.has_nd_sbp()) { *nd_sbp = NdSbp(blob_conf.nd_sbp()); } else { nd_sbp->clear_sbp_parallel(); FOR_RANGE(int64_t, i, 0, num_axes) { nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); } } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/core/operator/interface_op_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_INTERFACE_OP_UTIL_H_ #define ONEFLOW_CORE_OPERATOR_INTERFACE_OP_UTIL_H_ #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { struct InterfaceOpUtil final { static Maybe InferOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, const ParallelContext* parallel_ctx, const ParallelDesc& parallel_desc); static Maybe InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, const ParallelDesc& parallel_desc); static Maybe GetInputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, const PbRpf& output_bns, SbpSignature* sbp_signature); static Maybe GetOutputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, const PbRpf& output_bns, SbpSignature* sbp_signature); static Maybe InitBlobConf(InterfaceBlobConf* blob_conf, const ParallelBlobConf& parallel_blob_conf); static Maybe ParseNdSbpFromBlobConf(const InterfaceBlobConf& blob_conf, const ParallelDesc& parallel_desc, NdSbp* nd_sbp); }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_INTERFACE_OP_UTIL_H_ ================================================ FILE: oneflow/core/operator/learning_rate_schedule_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" namespace oneflow { class LearningRateScheduleOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(LearningRateScheduleOp); LearningRateScheduleOp() = default; ~LearningRateScheduleOp() override = default; Maybe InitFromOpConf() override; virtual Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; }; Maybe LearningRateScheduleOp::InitFromOpConf() { CHECK(op_conf().has_learning_rate_schedule_conf()); EnrollInputBn("train_step"); EnrollOutputBn("out"); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { const BlobDesc* train_step = BlobDesc4BnInOp("train_step"); CHECK_EQ(train_step->shape().elem_cnt(), 1); CHECK_EQ(train_step->data_type(), DataType::kInt64); BlobDesc* out = BlobDesc4BnInOp("out"); out->set_shape(Shape({1})); out->set_data_type(DataType::kFloat); return Maybe::Ok(); } } // namespace Maybe LearningRateScheduleOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe LearningRateScheduleOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe LearningRateScheduleOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kLearningRateScheduleConf, LearningRateScheduleOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/nccl_send_recv_boxing_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" namespace oneflow { class NcclSendRecvBoxingOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingOp); NcclSendRecvBoxingOp() = default; ~NcclSendRecvBoxingOp() override = default; Maybe InitFromOpConf() override; Maybe InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { UNIMPLEMENTED_THEN_RETURN(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; Maybe NcclSendRecvBoxingOp::InitFromOpConf() { const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf(); if (conf.has_input()) { EnrollInputBn("in", false); } if (conf.has_output()) { EnrollOutputBn("out", false); } EnrollTmpBn("buf"); return Maybe::Ok(); } Maybe NcclSendRecvBoxingOp::InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const { BlobDesc* buf = GetBlobDesc4BnInOp("buf"); const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf(); const NdSbp& src_nd_sbp = conf.src_nd_sbp(); const NdSbp& dst_nd_sbp = conf.dst_nd_sbp(); ParallelDesc parallel_desc(conf.parallel_conf()); ParallelDesc in_parallel_desc(conf.src_parallel_conf()); ParallelDesc out_parallel_desc(conf.dst_parallel_conf()); const int64_t parallel_num = parallel_desc.parallel_num(); const int64_t parallel_id = parallel_ctx->parallel_id(); const Shape& logical_shape = Shape(conf.logical_shape()); std::vector src_send_intersections; std::vector dst_recv_intersections; GetRankSendRecvIntersection(parallel_id, parallel_desc, in_parallel_desc, out_parallel_desc, src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections, &dst_recv_intersections); int64_t buf_count = 0; if (conf.has_input()) { const BlobDesc* in = GetBlobDesc4BnInOp("in"); buf->set_data_type(in->data_type()); CHECK_EQ(src_send_intersections.size(), parallel_num); for (int64_t i = 0; i < parallel_num; ++i) { const TensorSliceView& intersection = JUST(VectorAt(src_send_intersections, i)); if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); } } } if (conf.has_output()) { const BlobDesc* out = GetBlobDesc4BnInOp("out"); buf->set_data_type(out->data_type()); for (int64_t i = 0; i < parallel_num; ++i) { const TensorSliceView& intersection = JUST(VectorAt(dst_recv_intersections, i)); if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); } } if (NdSbpHasPartialParallel(src_nd_sbp)) { // Note: when src_nd_sbp has partial_sum, need a out_size buffer to copy and add to out. 
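// Rough sizing sketch (illustrative numbers): for a logical shape of (4, 6) on a 2-rank
// placement going P -> S(0), each rank's recv intersections are two (2, 6) pieces
// (2 * 12 = 24 elements), and because the source is partial-sum an extra
// out->shape().elem_cnt() = 2 * 6 = 12 elements is reserved just below for the copy-and-add
// step, so the recv side of "buf" holds about 36 elements of the output data type.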
buf_count += out->shape().elem_cnt(); } } buf->set_shape(Shape({buf_count})); return Maybe::Ok(); } LogicalBlobId NcclSendRecvBoxingOp::lbi4ibn(const std::string& input_bn) const { return this->op_conf().nccl_send_recv_boxing_conf().lbi(); } LogicalBlobId NcclSendRecvBoxingOp::lbi4obn(const std::string& output_bn) const { return this->op_conf().nccl_send_recv_boxing_conf().lbi(); } Maybe NcclSendRecvBoxingOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf(); const Shape& logical_shape = Shape(conf.logical_shape()); const ParallelDesc& parallel_desc = ParallelDesc(conf.parallel_conf()); const int64_t machine_id = JUST(parallel_desc.MachineId4ParallelId(parallel_ctx->parallel_id())); const int64_t device_index = JUST(parallel_desc.DeviceId4ParallelId(parallel_ctx->parallel_id())); if (conf.has_input()) { const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); const NdSbp& src_nd_sbp = conf.src_nd_sbp(); const ParallelDesc& src_parallel_desc = ParallelDesc(conf.src_parallel_conf()); int64_t src_parallel_id = JUST(src_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); std::shared_ptr in_shape = JUST(GetPhysicalShape(logical_shape, src_nd_sbp, src_parallel_desc, src_parallel_id)); CHECK_EQ_OR_RETURN(*in_shape, in_blob_desc->shape()) << "Non-matching shape of blobs for pieces of nccl send recv"; } if (conf.has_output()) { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); const NdSbp& dst_nd_sbp = conf.dst_nd_sbp(); const ParallelDesc& dst_parallel_desc = ParallelDesc(conf.dst_parallel_conf()); int64_t dst_parallel_id = JUST(dst_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); std::shared_ptr out_shape = JUST(GetPhysicalShape(logical_shape, dst_nd_sbp, dst_parallel_desc, dst_parallel_id)); out_blob_desc->set_shape(*out_shape); out_blob_desc->set_data_type(conf.data_type()); } return Maybe::Ok(); } REGISTER_OP(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/nccl_send_recv_boxing_op_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" namespace oneflow { namespace { // Go through all the ranks while transfer between two nd sbps with no PartialSum under the same // placement. // NOTE: We need to make sure no partial sums in the sbps of the producer and consumer. 
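// Illustrative traversal example: with an in hierarchy of [2, 2] and in_nd_sbp = [B, S(0)],
// a consumer whose corresponding index in the producer hierarchy is (1, 0) keeps axis 0 fixed
// (broadcast) and enumerates axis 1 (split), so visit() is called for producer offsets 2 and 3
// only; a pure-split sbp such as [S(0), S(1)] would instead visit all four producer ranks.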
void DfsTraverseRanks4NdSbp( int32_t depth, std::vector& in_parallel_ids, const std::vector& out_parallel_ids, const Shape& in_parallel_hierarchy, const NdIndexOffsetHelper& in_hierarchy_index_helper, const NdSbp& in_nd_sbp, const std::function& visit) { if (depth >= in_parallel_hierarchy.NumAxes()) { visit(in_hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(), in_parallel_hierarchy.NumAxes())); return; } if (in_nd_sbp.sbp_parallel(depth).has_broadcast_parallel()) { // If Broadcast in the sbp of the producer, only visit those ranks with the same id as the // current rank along the depth-dimension. in_parallel_ids[depth] = out_parallel_ids[depth]; DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, in_parallel_hierarchy, in_hierarchy_index_helper, in_nd_sbp, visit); } else { // If Split or PartialSum, go through all the ranks along the depth-dimension. for (int64_t i = 0; i < in_parallel_hierarchy.dim_vec().at(depth); i++) { in_parallel_ids[depth] = i; DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, in_parallel_hierarchy, in_hierarchy_index_helper, in_nd_sbp, visit); } } } bool NdSbpNoPartialParallel(const NdSbp& nd_sbp) { CHECK_GT(nd_sbp.sbp_parallel_size(), 0); FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) { if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return false; } } return true; } } // namespace int64_t GetMappedParallelId(const int64_t from_parallel_id, const ParallelDesc& from_parallel_desc, const ParallelDesc& to_parallel_desc) { const int64_t machine_id = CHECK_JUST(from_parallel_desc.MachineId4ParallelId(from_parallel_id)); const int64_t device_index = CHECK_JUST(from_parallel_desc.DeviceId4ParallelId(from_parallel_id)); if (to_parallel_desc.Containing(machine_id, device_index)) { return CHECK_JUST(to_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); } else { return -1; } } void GetRankSendRecvIntersection(int64_t parallel_id, const ParallelDesc& parallel_desc, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& logical_shape, std::vector* send_intersections, std::vector* recv_intersections) { const int64_t parallel_num = parallel_desc.parallel_num(); CHECK_LT(parallel_id, parallel_num); const std::vector& in_slices = GetTensorSliceView(*in_parallel_desc.hierarchy(), in_nd_sbp, logical_shape); const std::vector& out_slices = GetTensorSliceView(*out_parallel_desc.hierarchy(), out_nd_sbp, logical_shape); const auto& in_parallel_hierarchy = in_parallel_desc.hierarchy(); int32_t in_hierarchy_dimension = in_parallel_hierarchy->NumAxes(); const NdIndexOffsetHelper in_hierarchy_index_helper( in_parallel_hierarchy->dim_vec().data(), in_hierarchy_dimension); const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); const int64_t in_parallel_num = in_parallel_desc.parallel_num(); const int64_t out_parallel_num = out_parallel_desc.parallel_num(); // cur rank recv from // cur rank has output if (out_parallel_desc.Containing(machine_id, device_index)) { recv_intersections->resize(parallel_num); int64_t out_id = CHECK_JUST(out_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); const TensorSliceView& cur_rank_out_slice = out_slices.at(out_id); const auto& add_to_recv_intersections = [&](int32_t send_id) { const TensorSliceView& in_slice = in_slices.at(send_id); const TensorSliceView& 
intersection = cur_rank_out_slice.Intersect(in_slice); if (intersection.IsEmpty()) { return; } const int64_t merged_id = GetMappedParallelId(send_id, in_parallel_desc, parallel_desc); recv_intersections->at(merged_id) = intersection; }; int64_t corresponding_in_id = 0; // For example [[0, 1], [2, 3]] -> [[1, 3], [5, 6]] if (in_parallel_desc.Containing(machine_id, device_index)) { // 1 and 3 are in [[0, 1], [2, 3]], use the same id in the producer parallel description // The id of 1 is (0, 1), the id of 3 is (1, 1) corresponding_in_id = CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); } else { // 5 and 7 are not in [[0, 1], [2, 3]] // Then the id does not matter corresponding_in_id = out_id % in_parallel_num; } std::vector in_parallel_ids(in_hierarchy_dimension); // The corresponding parallel id of a consumer rank in the producer parallel description std::vector out_parallel_ids(in_hierarchy_dimension); in_hierarchy_index_helper.OffsetToNdIndex(corresponding_in_id, out_parallel_ids.data(), in_hierarchy_dimension); DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *in_parallel_hierarchy, in_hierarchy_index_helper, in_nd_sbp, add_to_recv_intersections); } // cur rank send to if (in_parallel_desc.Containing(machine_id, device_index)) { send_intersections->resize(parallel_num); int64_t in_id = CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); const TensorSliceView& cur_rank_in_slice = in_slices.at(in_id); for (int64_t recv_i = 0; recv_i < out_parallel_num; ++recv_i) { const auto& add_to_send_intersections = [&](int32_t send_id) { if (send_id != in_id) { return; } const TensorSliceView& out_slice = out_slices.at(recv_i); const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice); if (intersection.IsEmpty()) { return; } const int64_t merged_id = GetMappedParallelId(recv_i, out_parallel_desc, parallel_desc); send_intersections->at(merged_id) = intersection; }; int64_t out_device_id = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(recv_i)); int64_t out_machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(recv_i)); int64_t corresponding_in_id = 0; // For example [[0, 1], [2, 3]] -> [[1, 3], [5, 6]] if (in_parallel_desc.Containing(out_machine_id, out_device_id)) { // 1 and 3 are in [[0, 1], [2, 3]], use the same id in the producer parallel description // The id of 1 is (0, 1), the id of 3 is (1, 1) corresponding_in_id = CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(out_machine_id, out_device_id)); } else { // 5 and 7 are not in [[0, 1], [2, 3]] // Then the id does not matter corresponding_in_id = recv_i % in_parallel_num; } std::vector in_parallel_ids(in_hierarchy_dimension); // The corresponding parallel id of a consumer rank in the producer parallel description std::vector out_parallel_ids(in_hierarchy_dimension); in_hierarchy_index_helper.OffsetToNdIndex(corresponding_in_id, out_parallel_ids.data(), in_hierarchy_dimension); DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *in_parallel_hierarchy, in_hierarchy_index_helper, in_nd_sbp, add_to_send_intersections); } } } } // namespace oneflow ================================================ FILE: oneflow/core/operator/nccl_send_recv_boxing_op_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { int64_t GetMappedParallelId(const int64_t from_parallel_id, const ParallelDesc& from_parallel_desc, const ParallelDesc& to_parallel_desc); void GetRankSendRecvIntersection(int64_t parallel_id, const ParallelDesc& parallel_desc, const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, const Shape& logical_shape, std::vector* send_intersections, std::vector* recv_intersections); } // namespace oneflow ================================================ FILE: oneflow/core/operator/op_attribute.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/register/logical_blob_id.proto"; import "oneflow/core/register/blob_desc.proto"; import "oneflow/core/operator/op_conf.proto"; import "oneflow/core/operator/arg_modifier_signature.proto"; import "oneflow/core/job/sbp_parallel.proto"; import "oneflow/core/job/local_parallel.proto"; import "oneflow/core/job/blob_lifetime_signature.proto"; import "oneflow/core/job/parallel_signature.proto"; import "oneflow/core/job/parallel_conf_signature.proto"; message OpAttribute { repeated string input_bns = 1; repeated string output_bns = 2; repeated string tmp_bns = 3; required OperatorConf op_conf = 50; // inter-node signature required ArgSignature arg_signature = 100; required ArgModifierSignature arg_modifier_signature = 101; optional BlobLastUsedSignature blob_last_used_signature = 102; optional BlobBackwardUsedSignature blob_backward_used_signature = 103; // op node signature optional SbpSignature sbp_signature = 104; optional LocalSignature local_signature = 105; optional BlobDescSignature logical_blob_desc_signature = 106; optional ParallelConfSignature parallel_conf_signature = 109; optional NdSbpSignature nd_sbp_signature = 110; } message OpAttributeList { repeated OpAttribute op_attribute = 1; } ================================================ FILE: oneflow/core/operator/op_conf.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/data_type.proto"; import "oneflow/core/common/device_type.proto"; import "oneflow/core/record/record.proto"; import "oneflow/core/job/resource.proto"; import "oneflow/core/register/logical_blob_id.proto"; import "oneflow/core/register/tensor_slice_view.proto"; import "oneflow/core/framework/user_op_conf.proto"; import "oneflow/core/job/sbp_parallel.proto"; import "oneflow/core/graph/boxing/collective_boxing.proto"; import "oneflow/core/job/initializer_conf.proto"; import "oneflow/core/job/regularizer_conf.proto"; import "oneflow/core/job/placement.proto"; import "oneflow/core/job/learning_rate_schedule_conf.proto"; import "oneflow/core/operator/interface_blob_conf.proto"; import "oneflow/core/register/blob_desc.proto"; enum ActivationType { kNone = 0; kTanH = 1; kSigmoid = 2; kRelu = 3; } message DistributeConcatOpConf { repeated string in = 1; required string out = 2; required int32 axis = 3; } message 
DistributeSplitOpConf { required string in = 1; repeated string out = 2; required int32 axis = 3; optional bool is_variable_ref = 4 [default = false]; } message DistributeCloneOpConf { required string in = 1; repeated string out = 2; optional bool is_variable_ref = 3 [default = false]; } message DistributeAddOpConf { repeated string in = 1; required string out = 2; } message CopyCommNetOpConf { required LogicalBlobId lbi = 2; } message BoxConcatConf { required int32 axis = 1; } message BoxAddConf { } message BoxSplitConf { required int32 axis = 1; repeated int32 part_num = 2; } message BoxCloneConf { } message BoxingOpConf { required LogicalBlobId lbi = 1; required int32 in_num = 2; required int32 out_num = 3; oneof in_box { BoxConcatConf concat_box = 4; BoxAddConf add_box = 5; } oneof out_box { BoxSplitConf split_box = 6; BoxCloneConf clone_box = 7; } } message DynamicReshapeOpConf { required string in = 1; required string out = 2; required ShapeProto shape = 3; } message DynamicReshapeLikeOpConf { required string x = 1; required string y = 2; required string like = 3; } message FeedInputOpConf { // NOTE(chengcheng): define in/out key as UserOp ibn/obn. required string in_0 = 1; required string out_0 = 2; } message FeedVariableOpConf { required string in_0 = 1; required string out_0 = 2; } message FetchOutputOpConf { required string in_0 = 1; required string out_0 = 2; } message InputOpConf { optional string tick = 1; required string out = 2; required InterfaceBlobConf blob_conf = 3; optional string job_name = 4; } message ReturnOpConf { required string in = 1; required string out = 2; optional string job_name = 3; } message OutputOpConf { required string in = 1; required string out = 2; required InterfaceBlobConf blob_conf = 3; optional string job_name = 4; } message VariableOpConf { optional string tick = 1; required string out = 2; required ShapeProto shape = 3; optional DataType data_type = 4; oneof initialize { InitializerConf initializer = 5; InitializeWithSnapshotConf initialize_with_snapshot = 6; } optional string model_name = 7 [default = "weight"]; optional int64 random_seed = 9; optional RegularizerConf regularizer = 10; optional bool trainable = 11 [default = true]; repeated string nd_sbp = 12; } message TickOpConf { repeated string tick = 1; required string out = 2; } message CriticalSectionWaitTickOpConf { repeated string tick = 1; required string out = 2; required string buffer_name = 3; } message CriticalSectionCallbackTickOpConf { repeated string tick = 1; required string out = 2; required string buffer_name = 3; } message DeviceTickOpConf { repeated string tick = 1; required string out = 2; optional ShapeProto time_shape = 3; } message WaitAndSendIdsOpConf { required string out = 1; required string wait_buffer_name = 2; repeated Int64List id_list = 3; required DataType data_type = 4 [default = kInt32]; optional string job_name = 5; } message CallbackNotifyOpConf { required string in = 1; repeated string callback_buffer_name = 2; optional string job_name = 3; } message ReentrantLockOpConf { required string start = 1; optional string end = 2; required string out = 3; repeated Int64List lock_id2intersecting_lock_ids = 4; } message SrcSubsetTickOpConf { repeated string in = 1; required string out = 2; } message DstSubsetTickOpConf { repeated string in = 1; required string out = 2; } message SourceTickOpConf { required string out = 1; } message SinkTickOpConf { repeated string tick = 1; required string out = 2; } message TotalLossInstanceNumOpConf { repeated string in = 1; 
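// Illustrative text-format sketch (hypothetical values) of the InterfaceBlobConf used by
// InputOpConf / OutputOpConf above; when nd_sbp is omitted,
// InterfaceOpUtil::ParseNdSbpFromBlobConf falls back to broadcast on every hierarchy axis:
//
//   blob_conf {
//     shape { dim: 8 dim: 5 }
//     data_type: kFloat
//     is_dynamic: false
//     nd_sbp { sbp_parallel { split_parallel { axis: 0 } } }
//   }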
required string out = 2; } message ShapeElemCntAxisConf { repeated int32 axis = 1; } message ShapeElemCntRangeAxisConf { // closed interval: [begin_axis, end_axis] optional int32 begin_axis = 1 [default = 0]; optional int32 end_axis = 2 [default = -1]; } message ShapeElemCntOpConf { required string x = 1; required string y = 2; optional DataType data_type = 3 [default = kInt32]; oneof axis_conf { ShapeElemCntAxisConf exclude_axis_conf = 4; ShapeElemCntAxisConf include_axis_conf = 5; ShapeElemCntRangeAxisConf range_axis_conf = 6; } } message AccTickOpConf { // in required string one = 1; // out required string acc = 2; optional int32 max_acc_num = 3 [default = 1]; } message IdentityOpConf { required string in = 1; required string out = 2; } message CopyOpConf { required string in = 1; required string out = 2; } message CastToLocalOpConf { required string in = 1; required string out = 2; required SbpParallel sbp_parallel = 3; } message CastFromLocalOpConf { required string in = 1; required string out = 2; required SbpParallel sbp_parallel = 3; } message CaseOpConf { required string in = 1; repeated string out = 2; } message EsacOpConf { repeated string in = 1; required string out = 2; optional DataType data_type = 3 [default=kInt32]; } message AssignOpConf { required string ref = 1; required string value = 2; } message LearningRateScheduleOpConf { required string train_step = 1; required string out = 2; required float learning_rate = 3; optional LearningRateDecayConf learning_rate_decay = 4; } message SliceBoxingConf { required LogicalBlobId lbi = 1; repeated TensorSliceViewProto in_slice = 2; required TensorSliceViewProto out_slice = 3; optional ShapeProto out_shape = 4; } message SliceBoxingCopyOpConf { required SliceBoxingConf slice_boxing_conf = 1; } message SliceBoxingAddOpConf { required SliceBoxingConf slice_boxing_conf = 1; } message ConstantLikeOpConf { required string like = 1; required string out = 2; optional DataType data_type = 3; oneof scalar_operand { int64 int_operand = 4; double float_operand = 5; } } message SyncDynamicResizeOpConf { required string in = 1; required string size = 2; required string out = 3; required int64 axis = 4; optional bool eager = 5 [default = false]; } message BroadcastToCompatibleWithOpConf { required string x = 1; repeated string compatible = 2; required string y = 3; } message CollectiveBoxingGenericOpConf { required LogicalBlobId lbi = 1; required boxing.collective.RankDesc rank_desc = 2; } message BoxingIdentityOpConf { required LogicalBlobId lbi = 1; } message CollectiveBoxingPackOpConf { required LogicalBlobId lbi = 1; required SbpParallel src_sbp_parallel = 2; required SbpParallel dst_sbp_parallel = 3; required int64 num_ranks = 4; required ShapeProto logical_shape = 5; } message CollectiveBoxingUnpackOpConf { required LogicalBlobId lbi = 1; required SbpParallel src_sbp_parallel = 2; required SbpParallel dst_sbp_parallel = 3; required int64 num_ranks = 4; required ShapeProto logical_shape = 5; } message ImageDecoderRandomCropResizeOpConf { required string in = 1; required string out = 2; required int64 target_width = 3; required int64 target_height = 4; optional int64 num_workers = 5 [default = 3]; optional int64 max_num_pixels = 6 [default = 67108864]; optional int64 warmup_size = 7 [default = 6400]; optional int64 seed = 8; optional int64 num_attempts = 9 [default = 10]; optional float random_area_min = 10 [default = 0.08]; optional float random_area_max = 11 [default = 1.0]; optional float random_aspect_ratio_min = 12 [default = 0.75]; 
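// The random_area / random_aspect_ratio bounds around this field follow the common
// Inception-style random-resized-crop recipe: crop area sampled from [0.08, 1.0] of the
// decoded image and aspect ratio from roughly [3/4, 4/3]; num_attempts caps the rejection
// sampling used to find an acceptable crop window.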
  optional float random_aspect_ratio_max = 13 [default = 1.333333];
}

message BoxingZerosOpConf {
  required LogicalBlobId lbi = 1;
  required ShapeProto shape = 2;
  required DataType data_type = 3;
}

message NcclSendRecvBoxingOpConf {
  required LogicalBlobId lbi = 1;
  required NdSbp src_nd_sbp = 2;
  required NdSbp dst_nd_sbp = 3;
  required ParallelConf parallel_conf = 4;
  required ParallelConf src_parallel_conf = 5;
  required ParallelConf dst_parallel_conf = 6;
  required ShapeProto logical_shape = 7;
  required DataType data_type = 8;
  required bool has_input = 9;
  required bool has_output = 10;
}

message OperatorConf {
  required string name = 1;
  optional string device_tag = 4 [default = "invalid_device"];
  repeated string ctrl_in_op_name = 7;
  optional int64 scope_symbol_id = 8;
  optional string stream_name_hint = 9;
  optional string pass_tag = 10;
  optional string loc = 11 [default = ""];
  optional int64 logical_chain_id = 12 [default = -1];
  optional int64 order_in_logical_chain = 13 [default = -1];
  optional string calculation_pass_name = 14 [default = "forward_pass"];
  oneof op_type {
    // system op
    CopyCommNetOpConf copy_comm_net_conf = 106;
    BoxingOpConf boxing_conf = 108;
    VariableOpConf variable_conf = 122;
    TickOpConf tick_conf = 124;
    CriticalSectionWaitTickOpConf critical_section_wait_tick_conf = 125;
    CriticalSectionCallbackTickOpConf critical_section_callback_tick_conf = 126;
    TotalLossInstanceNumOpConf total_loss_instance_num_conf = 131;
    ShapeElemCntOpConf shape_elem_cnt_conf = 132;
    SrcSubsetTickOpConf src_subset_tick_conf = 133;
    DstSubsetTickOpConf dst_subset_tick_conf = 134;
    SourceTickOpConf source_tick_conf = 135;
    SinkTickOpConf sink_tick_conf = 136;
    InputOpConf input_conf = 137;
    OutputOpConf output_conf = 138;
    WaitAndSendIdsOpConf wait_and_send_ids_conf = 139;
    ReentrantLockOpConf reentrant_lock_conf = 140;
    CallbackNotifyOpConf callback_notify_conf = 141;
    AccTickOpConf acc_tick_conf = 144;
    ReturnOpConf return_conf = 146;
    DistributeConcatOpConf distribute_concat_conf = 155;
    DistributeSplitOpConf distribute_split_conf = 156;
    DistributeCloneOpConf distribute_clone_conf = 157;
    DistributeAddOpConf distribute_add_conf = 158;
    DeviceTickOpConf device_tick_conf = 159;
    SliceBoxingCopyOpConf slice_boxing_copy_conf = 166;
    SliceBoxingAddOpConf slice_boxing_add_conf = 167;
    CollectiveBoxingGenericOpConf collective_boxing_generic_conf = 170;
    BoxingIdentityOpConf boxing_identity_conf = 171;
    CollectiveBoxingPackOpConf collective_boxing_pack_conf = 174;
    CollectiveBoxingUnpackOpConf collective_boxing_unpack_conf = 175;
    BoxingZerosOpConf boxing_zeros_conf = 176;
    NcclSendRecvBoxingOpConf nccl_send_recv_boxing_conf = 177;
    UserOpConf user_conf = 199;

    // domain op
    DynamicReshapeOpConf dynamic_reshape_conf = 203;
    DynamicReshapeLikeOpConf dynamic_reshape_like_conf = 287;
    IdentityOpConf identity_conf = 290;
    CaseOpConf case_conf = 291;
    EsacOpConf esac_conf = 292;
    AssignOpConf assign_conf = 296;
    LearningRateScheduleOpConf learning_rate_schedule_conf = 298;
    ConstantLikeOpConf constant_like_conf = 339;
    SyncDynamicResizeOpConf sync_dynamic_resize_conf = 340;
    CopyOpConf copy_conf = 343;
    CastToLocalOpConf cast_to_local_conf = 344;
    CastFromLocalOpConf cast_from_local_conf = 345;
    ImageDecoderRandomCropResizeOpConf image_decoder_random_crop_resize_conf = 349;

    // math op
    BroadcastToCompatibleWithOpConf broadcast_to_compatible_with_conf = 525;

    // NOTE(chengcheng): Lazy 1.0 system ops.
    // Feed EagerTensor to interface op.
    // Note that FeedxxOp just for build CustomOpExpr, and has NO operator impl.
    FeedInputOpConf feed_input_conf = 600;
    FeedVariableOpConf feed_variable_conf = 601;
    // Fetch EagerTensor from output op
    FetchOutputOpConf fetch_output_conf = 602;
  }
}

message OpNameRelations {
  map<string, string> src_op_name2dst_op_name = 1;
}

message OpNameGroups {
  message OpNameGroup {
    repeated string op_name = 1;
  }
  repeated OpNameGroup op_name_group = 2;
}

================================================
FILE: oneflow/core/operator/op_conf_symbol.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/operator/op_conf_symbol.h"

namespace oneflow {

OperatorConfSymbol::OperatorConfSymbol(int64_t symbol_id, const OperatorConf& op_conf)
    : symbol_id_(symbol_id), op_conf_(op_conf) {}

}  // namespace oneflow

================================================
FILE: oneflow/core/operator/op_conf_symbol.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_OPERATOR_OP_CONF_SYMBOL_H_
#define ONEFLOW_CORE_OPERATOR_OP_CONF_SYMBOL_H_

#include <memory>

#include "oneflow/core/common/optional.h"
#include "oneflow/core/operator/op_conf.pb.h"

namespace oneflow {

class OperatorConfSymbol final {
 public:
  OperatorConfSymbol(const OperatorConfSymbol&) = delete;
  OperatorConfSymbol(OperatorConfSymbol&&) = delete;
  OperatorConfSymbol(int64_t symbol_id, const OperatorConf& op_conf);
  ~OperatorConfSymbol() = default;

  const OperatorConf& op_conf() const { return op_conf_; }
  const OperatorConf& data() const { return op_conf_; }
  const Optional<int64_t>& symbol_id() const { return symbol_id_; }

 private:
  Optional<int64_t> symbol_id_;
  OperatorConf op_conf_;
  std::shared_ptr<OperatorConf> data_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_OPERATOR_OP_CONF_SYMBOL_H_

================================================
FILE: oneflow/core/operator/op_conf_util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_OPERATOR_OP_CONF_UTIL_H_
#define ONEFLOW_CORE_OPERATOR_OP_CONF_UTIL_H_

#include "oneflow/core/operator/op_conf.pb.h"

namespace std {

template<>
struct hash<::oneflow::OperatorConf::OpTypeCase> {
  std::size_t operator()(const ::oneflow::OperatorConf::OpTypeCase& op_type) const {
    return std::hash<std::size_t>()(static_cast<std::size_t>(op_type));
  }
};

}  // namespace std

#endif  // ONEFLOW_CORE_OPERATOR_OP_CONF_UTIL_H_

================================================
FILE: oneflow/core/operator/op_infer_cache.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_OPERATOR_OP_INFER_CACHE_H_
#define ONEFLOW_CORE_OPERATOR_OP_INFER_CACHE_H_

#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/dtype_signature.h"
#include "oneflow/core/common/symbol.h"

namespace oneflow {

struct OpInferCacheKey final {
  const void* scope;
  Symbol<OperatorConf> op_conf_sym;
  Symbol<DTypeSignature> dtype_signature_sym;
  std::vector<Symbol<Shape>> ibn_idx2shape_sym;
};

struct OpInferCacheValue final {
  std::vector<Symbol<Shape>> obn_idx2shape_sym;
};

inline bool operator==(const OpInferCacheKey& lhs, const OpInferCacheKey& rhs) {
  return lhs.scope == rhs.scope && lhs.op_conf_sym == rhs.op_conf_sym
         && lhs.dtype_signature_sym == rhs.dtype_signature_sym
         && lhs.ibn_idx2shape_sym == rhs.ibn_idx2shape_sym;
}

inline bool operator!=(const OpInferCacheKey& lhs, const OpInferCacheKey& rhs) {
  return !(lhs == rhs);
}

}  // namespace oneflow

namespace std {

template<>
struct hash<oneflow::OpInferCacheKey> final {
  size_t operator()(const oneflow::OpInferCacheKey& op_infer_cache_key) const {
    using namespace oneflow;
    size_t ibn_idx2shape_sym_hash_value = 0;
    for (const auto& shape_sym : op_infer_cache_key.ibn_idx2shape_sym) {
      AddHash(&ibn_idx2shape_sym_hash_value, shape_sym);
    }
    return Hash(op_infer_cache_key.scope, op_infer_cache_key.op_conf_sym,
                ibn_idx2shape_sym_hash_value, op_infer_cache_key.dtype_signature_sym);
  }
};

}  // namespace std

#endif  // ONEFLOW_CORE_OPERATOR_OP_INFER_CACHE_H_

================================================
FILE: oneflow/core/operator/op_node_signature.proto
================================================
syntax = "proto2";
package oneflow;

import "oneflow/core/job/sbp_parallel.proto";
import "oneflow/core/job/local_parallel.proto";
import "oneflow/core/register/blob_desc.proto";
import "oneflow/core/job/parallel_signature.proto";

message OpNodeSignature {
  optional SbpSignature sbp_signature = 1;
  optional LocalSignature local_signature = 2;
  optional BlobDescSignature logical_blob_desc_signature = 3;
  optional ParallelSignature parallel_signature = 5;
}

================================================
FILE: oneflow/core/operator/operator.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/auto_parallel/algorithm_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/job/local_sig_infer_hint.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/op_node_signature.pb.h" #include "oneflow/core/job/nd_sbp_infer_hint.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/framework/placement_sbp_util.h" namespace oneflow { namespace { DataType GetDataTypeFromBnInOpVec( std::function GetBlobDesc4BnInOp, const PbRpf& bn_in_ops) { for (const std::string& bn_in_op : bn_in_ops) { const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op); if (blob_desc) { return blob_desc->data_type(); } } return DataType::kInvalidDataType; } Maybe CheckAndConstructOp(std::shared_ptr op_conf) { Operator* rptr = NewObj(op_conf->op_type_case(), *op_conf); DeviceType device_type = JUST(DeviceType4DeviceTag(op_conf->device_tag())); if (IsCpuOnly(*op_conf)) { CHECK_EQ_OR_RETURN(device_type, DeviceType::kCPU); } JUST(rptr->Init(op_conf)); return std::shared_ptr(rptr); } } // namespace Operator::Operator() : device_type_(DeviceType::kInvalidDevice) {} Maybe Operator::Init(const OperatorConf& op_conf) { return Init(std::make_shared(op_conf)); } Maybe Operator::Init(std::shared_ptr op_conf) { op_conf_ = std::move(op_conf); device_type_ = JUST(DeviceType4DeviceTag(op_conf_->device_tag())); JUST(InitFromOpConf()); input_output_bns_.Reserve(input_bns().size() + output_bns().size()); for (const auto& bn : input_bns()) { *input_output_bns_.Add() = bn; } for (const auto& bn : output_bns()) { *input_output_bns_.Add() = bn; } return Maybe::Ok(); } const LogicalBlobId& Operator::BnInOp2Lbi(const std::string& bn_in_op) const { return arg_signature_.bn_in_op2lbi().at(bn_in_op); } const OperatorConf& Operator::op_conf() const { CHECK(op_conf_); return *op_conf_; } std::shared_ptr Operator::shared_op_conf() const { return op_conf_; } DeviceType Operator::device_type() const { return device_type_; } const std::string& Operator::SoleIbn() const { CHECK_EQ(input_bns().size(), 1) << ", op_name " << op_name(); return input_bns().Get(0); } const std::string& Operator::SoleObn() const { CHECK_EQ(output_bns().size(), 1) << ", op_name " << op_name(); return output_bns().Get(0); } const std::string& Operator::SoleTbn() const { CHECK_EQ(tmp_bns().size(), 1); return tmp_bns().Get(0); } Maybe Operator::obn4lbi(const LogicalBlobId& lbi) const { const auto& it = lbi2output_index_.find(lbi); CHECK_OR_RETURN(it != lbi2output_index_.end()) << "no logical blob id found. 
lbn: " << lbi.op_name() << "/" << lbi.blob_name(); return &output_bns().Get(it->second); } const PbRpf& Operator::input_bns() const { return input_bns_; } const PbRpf& Operator::output_bns() const { return output_bns_; } const PbRpf& Operator::tmp_bns() const { return tmp_bns_; } const PbRpf& Operator::input_output_bns() const { return input_output_bns_; } Maybe Operator::InferParallelSignatureIf() { JUST(InferBlobParallelDesc()); return Maybe::Ok(); } Maybe Operator::GetParallelDesc4BnInOp(const std::string& bn) const { CHECK_OR_RETURN(bn2parallel_desc_); auto it = bn2parallel_desc_->find(bn); CHECK_OR_RETURN(it != bn2parallel_desc_->end()); return it->second; } Maybe Operator::FillBlobParallelDesc( const std::function(const std::string&)>& ParallelDesc4Bn) { CHECK_OR_RETURN(!bn2parallel_desc_); bn2parallel_desc_.reset(new HashMap>); for (const auto& bn : input_output_bns()) { auto blob_parallel_desc = JUST(ParallelDesc4Bn(bn)); CHECK(bn2parallel_desc_->emplace(bn, blob_parallel_desc).second); } return Maybe::Ok(); } Maybe Operator::InferBlobParallelDesc() { JUST(FillBlobParallelDesc( [&](const std::string& bn) -> Maybe { return GetOpParallelDesc(); })); return Maybe::Ok(); } Maybe Operator::FillOpParallelDesc(const ParallelDesc& parallel_desc) { return FillOpParallelDesc(std::make_shared(parallel_desc)); } Maybe Operator::FillOpParallelDesc(std::shared_ptr parallel_desc) { CHECK_OR_RETURN(!op_parallel_desc_); op_parallel_desc_ = std::move(parallel_desc); return Maybe::Ok(); } Maybe Operator::GetOpParallelDesc() const { CHECK_OR_RETURN(op_parallel_desc_); return op_parallel_desc_; } namespace { Maybe FillLogicalBlobDesc( const std::function(int32_t)>& BlobDesc4Index, const PbRpf& bns, std::unique_ptr>>* index2logical_blob_desc_ptr) { CHECK_OR_RETURN(!(*index2logical_blob_desc_ptr)); index2logical_blob_desc_ptr->reset(new std::vector>()); (*index2logical_blob_desc_ptr)->reserve(bns.size()); for (int32_t i = 0; i < bns.size(); ++i) { (*index2logical_blob_desc_ptr)->emplace_back(JUST(BlobDesc4Index(i))); } return Maybe::Ok(); } Maybe FillLogicalBlobDesc( const std::function& BlobDesc4BnInOp, const PbRpf& bns, std::unique_ptr>>* index2logical_blob_desc_ptr) { CHECK_OR_RETURN(!(*index2logical_blob_desc_ptr)); index2logical_blob_desc_ptr->reset(new std::vector>()); (*index2logical_blob_desc_ptr)->reserve(bns.size()); for (const auto& bn : bns) { const BlobDesc& blob_desc = BlobDesc4BnInOp(bn); (*index2logical_blob_desc_ptr)->emplace_back(std::make_shared(blob_desc)); } return Maybe::Ok(); } Maybe FillLogicalBlobDesc( const std::function& BlobDesc4BnInOp, const PbRpf& bns, std::unique_ptr>>* index2logical_blob_desc_ptr) { JUST(FillLogicalBlobDesc( [&](const std::string& bn) -> const BlobDesc& { const BlobDesc* blob_desc = BlobDesc4BnInOp(bn); CHECK_NOTNULL(blob_desc); return *blob_desc; }, bns, index2logical_blob_desc_ptr)); return Maybe::Ok(); } Maybe GetLogicalBlobDesc( const std::unique_ptr>>& index2logical_blob_desc, int32_t index) { CHECK_OR_RETURN(index2logical_blob_desc); CHECK_LT_OR_RETURN(index, index2logical_blob_desc->size()); return index2logical_blob_desc->at(index); } Maybe FillLogicalBlobDescSignature( const PbRpf& bns, const std::unique_ptr>>& index2logical_blob_desc, PbMap* bn_in_op2blob_desc) { CHECK_OR_RETURN(index2logical_blob_desc); CHECK_EQ_OR_RETURN(bns.size(), index2logical_blob_desc->size()); for (int32_t i = 0; i < bns.size(); ++i) { index2logical_blob_desc->at(i)->ToProto(&(*bn_in_op2blob_desc)[bns.Get(i)]); } return Maybe::Ok(); } Maybe SupportNonContiguous(const 
Operator* op) { const auto& op_conf = op->op_conf(); if (op_conf.has_user_conf()) { const auto* registry = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf.user_conf().op_type_name()); CHECK_NOTNULL_OR_RETURN(registry) << "The op(operation) " << op_conf.user_conf().op_type_name() << " is not found. Please check whether it has been registered correctly."; return registry->non_contiguous_supported; } return false; } } // namespace Maybe Operator::FillLogicalInBlobDesc( const std::function& BlobDesc4BnInOp) { JUST(FillLogicalBlobDesc(BlobDesc4BnInOp, input_bns(), &input_index2logical_blob_desc_)); return Maybe::Ok(); } Maybe Operator::FillLogicalInBlobDesc( const std::function& BlobDesc4BnInOp) { JUST(FillLogicalBlobDesc(BlobDesc4BnInOp, input_bns(), &input_index2logical_blob_desc_)); return Maybe::Ok(); } Maybe Operator::FillLogicalInBlobDesc( const std::function(int32_t)>& BlobDesc4InputIndex) { JUST(FillLogicalBlobDesc(BlobDesc4InputIndex, input_bns(), &input_index2logical_blob_desc_)); return Maybe::Ok(); } Maybe Operator::GetLogicalBlobDesc4Ibn(const std::string& ibn) const { return GetLogicalBlobDesc4InputIndex(JUST(GetInputIndex(ibn))); } Maybe Operator::GetLogicalBlobDesc4InputIndex(int32_t index) const { return GetLogicalBlobDesc(input_index2logical_blob_desc_, index); } Maybe Operator::FillLogicalOutBlobDesc( const std::function& BlobDesc4BnInOp) { JUST(FillLogicalBlobDesc(BlobDesc4BnInOp, output_bns(), &output_index2logical_blob_desc_)); return Maybe::Ok(); } Maybe Operator::FillLogicalOutBlobDesc( const std::function& BlobDesc4BnInOp) { JUST(FillLogicalBlobDesc(BlobDesc4BnInOp, output_bns(), &output_index2logical_blob_desc_)); return Maybe::Ok(); } Maybe Operator::GetLogicalBlobDesc4Obn(const std::string& obn) const { return GetLogicalBlobDesc4OutputIndex(JUST(GetOutputIndex(obn))); } Maybe Operator::GetLogicalBlobDesc4OutputIndex(int32_t index) const { return GetLogicalBlobDesc(output_index2logical_blob_desc_, index); } Maybe Operator::GetLogicalBlobDescPtr4OutputIndex(int32_t index) const { CHECK_OR_RETURN(output_index2logical_blob_desc_); CHECK_LT_OR_RETURN(index, output_index2logical_blob_desc_->size()); CHECK_OR_RETURN(output_index2logical_blob_desc_->at(index)); return output_index2logical_blob_desc_->at(index).get(); } Maybe Operator::GetLogicalBlobDesc4BnInOp(const std::string& bn) const { const auto& it = bn2index_pair_.find(bn); CHECK_OR_RETURN(it != bn2index_pair_.end()); if (it->second.first == BlobNameTag::kInputBlobName) { return GetLogicalBlobDesc4InputIndex(it->second.second); } else if (it->second.first == BlobNameTag::kOutputBlobName) { return GetLogicalBlobDesc4OutputIndex(it->second.second); } else { UNIMPLEMENTED_THEN_RETURN(); } } Maybe Operator::InferLogicalOutBlobDescsIf() { CHECK_OR_RETURN(input_index2logical_blob_desc_); CHECK_OR_RETURN(!output_index2logical_blob_desc_); std::vector> output_logical_blob_desc_vec; output_logical_blob_desc_vec.resize(output_bns().size()); for (auto& blob_desc : output_logical_blob_desc_vec) { blob_desc.reset(new BlobDesc(DataType::kInvalidDataType, MemoryFormat::kContiguous)); } std::vector> in_logical_blob_desc_vec; in_logical_blob_desc_vec.resize(input_bns().size()); auto BlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* { const auto& it = bn2index_pair_.find(bn); CHECK(it != bn2index_pair_.end()); if (it->second.first == BlobNameTag::kInputBlobName) { auto& ptr = in_logical_blob_desc_vec.at(it->second.second); if (!ptr) { ptr.reset(new 
BlobDesc(*input_index2logical_blob_desc_->at(it->second.second))); } return ptr.get(); } else if (it->second.first == BlobNameTag::kOutputBlobName) { return output_logical_blob_desc_vec.at(it->second.second).get(); } else { UNIMPLEMENTED(); return nullptr; } }; JUST(InferLogicalOutBlobDescs(BlobDesc4BnInOp, *JUST(GetOpParallelDesc()))); output_index2logical_blob_desc_.reset(new std::vector>()); output_index2logical_blob_desc_->resize(output_bns().size()); for (int32_t i = 0; i < output_bns().size(); ++i) { auto& out_blob_desc = output_logical_blob_desc_vec[i]; // initialize stride by shape if the op does not support non-contiguous if (!JUST(SupportNonContiguous(this))) { out_blob_desc->set_stride(Stride(out_blob_desc->shape())); } CHECK_EQ_OR_RETURN(out_blob_desc->stride().size(), out_blob_desc->shape().size()) << Error::RuntimeError() << "stride and shape size mismatch since stride is " << out_blob_desc->stride().ToString() << " but shape is " << out_blob_desc->shape().ToString(); (*output_index2logical_blob_desc_)[i] = out_blob_desc; } return Maybe::Ok(); } Maybe Operator::InferBlobDescsIf( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const { JUST(InferOutBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx)); JUST(InferInternalBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, job_desc)); return Maybe::Ok(); } Maybe Operator::InferOutBlobDescsIf( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { JUST(InferOutBlobDescs(GetBlobDesc4BnInOp, parallel_ctx)); for (const auto& bn : output_bns()) { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(bn); // initialize stride by shape if the op does not support non-contiguous if (!JUST(SupportNonContiguous(this))) { out_blob_desc->set_stride(Stride(out_blob_desc->shape())); } CHECK_EQ_OR_RETURN(out_blob_desc->stride().size(), out_blob_desc->shape().size()) << Error::RuntimeError() << "stride and shape size mismatch since stride is " << out_blob_desc->stride().ToString() << " but shape is " << out_blob_desc->shape().ToString(); } return Maybe::Ok(); } Maybe Operator::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { if (parallel_ctx->parallel_num() == 1) { JUST(InferLogicalOutBlobDescs(GetBlobDesc4BnInOp, *JUST(GetOpParallelDesc()))); } else { const auto& nd_sbp_signature = JUST(this->nd_sbp_signature()); const auto& parallel_desc = JUST(this->GetOpParallelDesc()); for (const auto& bn : input_bns()) { const auto& nd_sbp = nd_sbp_signature->bn_in_op2nd_sbp().at(bn); std::shared_ptr in_logical = JUST(GetLogicalBlobDesc4Ibn(bn)); CHECK_OR_RETURN( *JUST(GetPhysicalShape(in_logical->shape(), nd_sbp, *parallel_desc, *parallel_ctx)) == GetBlobDesc4BnInOp(bn)->shape()); } for (const auto& bn : output_bns()) { BlobDesc* desc = GetBlobDesc4BnInOp(bn); *desc = *JUST(GetLogicalBlobDesc4Obn(bn)); const auto& nd_sbp = nd_sbp_signature->bn_in_op2nd_sbp().at(bn); desc->set_shape( *JUST(GetPhysicalShape(desc->shape(), nd_sbp, *parallel_desc, *parallel_ctx))); } } return Maybe::Ok(); } Maybe Operator::InferInternalBlobDescsIf( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const { return InferInternalBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, job_desc); } Maybe Operator::InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const { return Maybe::Ok(); } Maybe Operator::InferInplaceObn2IbnIf( HashMap* 
mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferInplaceObn2Ibn(mut_inplace_obn2ibn, con_inplace_obn2ibn, GetBlobDesc4BnInOp, parallel_ctx); } Maybe Operator::InferInplaceObn2Ibn( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { for (const std::string& obn : output_bns()) { const auto& obn_modifier = OutputBlobModifier4Obn(obn); if (obn_modifier.has_mutable_inplace_ibn()) { mut_inplace_obn2ibn->emplace(obn, obn_modifier.mutable_inplace_ibn()); } else if (obn_modifier.has_const_inplace_ibn()) { con_inplace_obn2ibn->emplace(obn, obn_modifier.const_inplace_ibn()); } } return Maybe::Ok(); } Maybe Operator::FillInputBlobTimeShape( const std::function(int32_t)>& GetTimeShape4InputIndex) { CHECK_OR_RETURN(!input_index2time_shape_); input_index2time_shape_.reset(new std::vector>()); input_index2time_shape_->reserve(input_bns().size()); for (int32_t i = 0; i < input_bns().size(); ++i) { std::shared_ptr time_shape = JUST(GetTimeShape4InputIndex(i)); if ((!input_blob_fastest_time_shape_) || input_blob_fastest_time_shape_->elem_cnt() < time_shape->elem_cnt()) { input_blob_fastest_time_shape_ = time_shape; } input_index2time_shape_->emplace_back(time_shape); } return Maybe::Ok(); } Maybe Operator::InferOpTimeShapeIf() { CHECK_OR_RETURN(!op_time_shape_); CHECK_OR_RETURN(input_index2time_shape_); auto GetTimeShape4BnInOp = [&](const std::string& ibn) -> Maybe { const auto& it = bn2index_pair_.find(ibn); CHECK_OR_RETURN(it != bn2index_pair_.end()); CHECK_EQ_OR_RETURN(it->second.first, kInputBlobName); return input_index2time_shape_->at(it->second.second); }; JUST(InferOpTimeShape(GetTimeShape4BnInOp, &op_time_shape_)); if (input_blob_fastest_time_shape_ && input_blob_fastest_time_shape_->elem_cnt() > op_time_shape_->elem_cnt()) { input_output_fastest_time_shape_ = input_blob_fastest_time_shape_; } else { input_output_fastest_time_shape_ = op_time_shape_; } return Maybe::Ok(); } Maybe Operator::InferOpTimeShape( const std::function(const std::string&)>& GetTimeShape4BnInOp, std::shared_ptr* time_shape) const { if (!input_bns().empty()) { std::shared_ptr first_time_shape = input_index2time_shape_->at(0); for (int64_t i = 1; i < input_bns().size(); ++i) { CHECK_EQ_OR_RETURN(*input_index2time_shape_->at(i), *first_time_shape); } *time_shape = first_time_shape; } else { *time_shape = std::make_shared(Shape({1, 1})); } return Maybe::Ok(); } Maybe Operator::GetOpTimeShape() const { CHECK_OR_RETURN(op_time_shape_); return op_time_shape_; } Maybe Operator::GetInputBlobFastestTimeShape() const { return input_blob_fastest_time_shape_; } Maybe Operator::GetInputOutputFastestTimeShape() const { return input_output_fastest_time_shape_; } Maybe Operator::GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const { JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list)); SbpSignatureBuilder() .Broadcast(input_bns()) .Broadcast(output_bns()) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } Maybe Operator::EnumerateNdSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { return Maybe::Ok(); } Maybe Operator::GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const 
ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { // Get 1D sbp signature list HashMap hierarchy_value2sbp_sig_list; // hierarchy value is the value at the dimension corresponding to the current SBP // For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4] // Suppose we have nd_sbp = (S0, B) // The hierarchy value corresponding to S0 is 2 // The hierarchy value corresponding to B is 4. for (int32_t hierarchy_value : *parallel_desc.hierarchy()) { if (hierarchy_value2sbp_sig_list.find(hierarchy_value) == hierarchy_value2sbp_sig_list.end()) { auto* sbp_sig_list = &hierarchy_value2sbp_sig_list[hierarchy_value]; JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list)); CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0) << op_name() << " gets no sbp signature from GetSbpSignaturesIf function for hierarchy value: " << hierarchy_value; } } int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes(); NdSbpSignature nd_sbp_sig; SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.begin()->second.sbp_signature(0), &nd_sbp_sig); ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension); // ND sbp signature list would be direct product of 1D sbp signatures CHECK_OR_RETURN(nd_sbp_sig_list->empty()); DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(), hierarchy_value2sbp_sig_list, nd_sbp_sig_list); JUST(EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list)); return Maybe::Ok(); } Maybe Operator::GetValidNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list, bool check_output) const { JUST(GetNdSbpSignatureList(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list)); // Leave those valid Nd SBPs JUST(FilterNdSbpSignatureListByLogicalShape(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list, check_output)); CHECK_OR_RETURN(nd_sbp_sig_list->size() > 0) << "Empty sbp signature after filtering for " << op_name(); return Maybe::Ok(); } Operator::DumpNdSbpSignatureForOpConfFn Operator::GetDumpNdSbpSignatureForOpConfFn() const { return [](const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf) -> Maybe { return Maybe::Ok(); }; } void Operator::ForEachBnInOp(const std::function& Handler) const { for (const std::string& bn_in_op : input_bns()) { Handler(bn_in_op); } for (const std::string& bn_in_op : output_bns()) { Handler(bn_in_op); } for (const std::string& bn_in_op : tmp_bns()) { Handler(bn_in_op); } } Maybe Operator::FillSbpSignature(const SbpSignature& sbp_signature) { NdSbpSignature nd_sbp_signature; SbpSignatureToNdSbpSignature(sbp_signature, &nd_sbp_signature); JUST(FillNdSbpSignature(nd_sbp_signature)); return Maybe::Ok(); } Maybe Operator::FillNdSbpSignature(const NdSbpSignature& signature) { CHECK_OR_RETURN(!nd_sbp_signature_); CHECK_OR_RETURN(!sbp_signature_); nd_sbp_signature_.reset(new NdSbpSignature(signature)); CHECK_OR_RETURN(op_parallel_desc_); if (op_parallel_desc_->hierarchy()->NumAxes() == 1) { SbpSignature sbp_signature; NdSbpSignatureToSbpSignature(signature, &sbp_signature); sbp_signature_.reset(new SbpSignature(sbp_signature)); } return Maybe::Ok(); } Maybe Operator::InferSbpSignatureIf( const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, const std::function(const std::string&)>& SbpInferHint4Ibn, const ParallelDesc& parallel_desc) { SbpSignature signature; JUST(InferSbpSignature(&signature, sbp_sig_conf, CalcOrderValue4SbpSig, SbpInferHint4Ibn, parallel_desc)); 
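// A minimal standalone sketch (not OneFlow API) of the "direct product" idea used by
// GetNdSbpSignatureList above: each dimension of the parallel hierarchy independently picks one
// 1D SBP from the candidates enumerated for that hierarchy value, so a hierarchy [2, 4] with 1D
// candidates {S(0), B} per dimension yields (S(0), S(0)), (S(0), B), (B, S(0)) and (B, B).
// All names below are illustrative only.
#include <iostream>
#include <string>
#include <vector>

// Recursively extend a partial nd-SBP by one hierarchy dimension at a time.
void EnumerateNdSbp(const std::vector<std::vector<std::string>>& candidates_per_dim, size_t dim,
                    std::vector<std::string>* partial,
                    std::vector<std::vector<std::string>>* out) {
  if (dim == candidates_per_dim.size()) {
    out->push_back(*partial);
    return;
  }
  for (const std::string& sbp : candidates_per_dim[dim]) {
    partial->push_back(sbp);
    EnumerateNdSbp(candidates_per_dim, dim + 1, partial, out);
    partial->pop_back();
  }
}

int main() {
  // Two-level hierarchy [2, 4]; each level offers the same 1D candidates in this toy example.
  std::vector<std::vector<std::string>> candidates_per_dim = {{"S(0)", "B"}, {"S(0)", "B"}};
  std::vector<std::vector<std::string>> nd_sbp_list;
  std::vector<std::string> partial;
  EnumerateNdSbp(candidates_per_dim, 0, &partial, &nd_sbp_list);
  for (const auto& sig : nd_sbp_list) {
    std::cout << "(" << sig[0] << ", " << sig[1] << ")\n";  // prints the 4 combinations
  }
  return 0;
}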
JUST(FillSbpSignature(signature)); return Maybe::Ok(); } Maybe Operator::InferSbpSignature( SbpSignature* infered_sbp_signature, const SbpSignature& sbp_sig_conf, const HashMap& ibn2sbp_infer_hint) const { auto SbpInferHint4Ibn = [&](const std::string& ibn) -> Maybe { auto it = ibn2sbp_infer_hint.find(ibn); if (it == ibn2sbp_infer_hint.end()) { return Error::CheckFailedError() << "cannot find corresponding SbpInferHint for input_blob_name : " << ibn; } return &(it->second); }; std::function CalcOrderValue4SbpSig; auto OrderValue4SourceDefaultSplit0 = [&](const std::string& bn, const SbpParallel& sbp_parallel) -> int32_t { return -1 * (sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0); }; auto OrderValue4SbpHint = [&](const std::string& ibn, const SbpParallel& sbp_parallel) -> int32_t { const auto* hint = CHECK_JUST(SbpInferHint4Ibn(ibn)); // NOTE(chengcheng): one to one connect. return -10 * (hint->sbp_parallel() == sbp_parallel && hint->parallel_desc().parallel_num() == op_parallel_desc_->parallel_num()); }; if (sbp_sig_conf.bn_in_op2sbp_parallel().empty()) { CalcOrderValue4SbpSig = [&](const SbpSignature& sbp_signature) -> int32_t { int32_t order_value = 0; if (input_bns().size() > 0) { // NOTE(chengcheng): non-source op only ordered by input sbp match. for (const auto& ibn : input_bns()) { const auto& sbp_parallel_it = sbp_signature.bn_in_op2sbp_parallel().find(ibn); CHECK(sbp_parallel_it != sbp_signature.bn_in_op2sbp_parallel().end()); order_value += OrderValue4SbpHint(ibn, sbp_parallel_it->second); } } else { // NOTE(chengcheng): source op default split(0) // ONLY data source op will consider order here. variable op sbp is set by user. for (const auto& obn : output_bns()) { const auto& sbp_parallel_it = sbp_signature.bn_in_op2sbp_parallel().find(obn); CHECK(sbp_parallel_it != sbp_signature.bn_in_op2sbp_parallel().end()); order_value += OrderValue4SourceDefaultSplit0(obn, sbp_parallel_it->second); } } return order_value; }; } else { CalcOrderValue4SbpSig = [](const SbpSignature&) -> int32_t { return 0; }; } JUST(InferSbpSignature(infered_sbp_signature, sbp_sig_conf, CalcOrderValue4SbpSig, SbpInferHint4Ibn, *op_parallel_desc_)); return Maybe::Ok(); } Maybe Operator::FilterAndCheckValidSbpSignatureListByLogicalShape( const SbpSignatureList& total_sbp_sig_list, const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, SbpSignatureList* valid_sbp_sig_list) const { auto GetOpDebugShapeStr = [&]() -> std::string { std::string ret = ""; if (op_conf().has_user_conf()) { ret += ("op_type_name = " + op_conf().user_conf().op_type_name() + ", "); } for (const auto& ibn : input_bns()) { ret += (" ibn:(" + ibn + ") lbn:(" + GenLogicalBlobName(BnInOp2Lbi(ibn)) + ") logical_shape = " + CHECK_JUST(LogicalBlobDesc4Ibn(ibn)).shape().DebugStr() + ", "); } return ret; }; for (const auto& sbp_signature : total_sbp_sig_list.sbp_signature()) { bool is_valid = true; for (const auto& ibn : input_bns()) { const auto& sbp_parallel_it = sbp_signature.bn_in_op2sbp_parallel().find(ibn); CHECK_OR_RETURN(sbp_parallel_it != sbp_signature.bn_in_op2sbp_parallel().end()); const SbpParallel& sbp_parallel = sbp_parallel_it->second; const Shape& logical_shape = JUST(LogicalBlobDesc4Ibn(ibn)).shape(); // NOTE(chengcheng): disable split when logical shape cannot split at this axis if (sbp_parallel.has_split_parallel()) { const int64_t axis = sbp_parallel.split_parallel().axis(); CHECK_OR_RETURN(axis >= 0 && axis < logical_shape.NumAxes()) << "The sbp sign is 
ERROR because of the split axis >= shape num axes. In op: [" << op_name() << "] ibn: (" << ibn << ") the split axis is = " << axis << " . And the logical_shape = " << logical_shape.DebugStr() << ". This Op debug str = {" << GetOpDebugShapeStr() << "}"; if (logical_shape.At(axis) < parallel_desc.parallel_num()) { // NOTE(chengcheng): cannot split at this axis! is_valid = false; break; } } } if (is_valid) { *valid_sbp_sig_list->mutable_sbp_signature()->Add() = sbp_signature; } } return Maybe::Ok(); } Maybe Operator::FilterNdSbpSignatureListByLogicalShape( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list, bool check_output) const { auto FilterSbp4Blobs = [&](const PbRpf& bns, const NdSbpSignature& nd_sbp_sig) -> Maybe { // {in_0 : (S(6), B), in_1 : (S(0), S(1)), out : (B, S(1))} // look through input blob name in_0 and in_1 for (const auto& ibn : bns) { const auto& nd_sbp_it = nd_sbp_sig.bn_in_op2nd_sbp().find(ibn); // Find an unexpected blob name CHECK_OR_RETURN(nd_sbp_it != nd_sbp_sig.bn_in_op2nd_sbp().end()); const auto& nd_sbp = nd_sbp_it->second; Shape logical_shape = JUST(LogicalBlobDesc4Ibn(ibn)).shape(); const auto& parallel_hierarchy = parallel_desc.hierarchy(); // Treat 1D sbp and nD sbp differently. Please refer to // JobBuildAndInferCtx::CheckOpBlobSplitability // for more details. if (JUST(FilterNdSbpByLogicalShape(nd_sbp, logical_shape, *parallel_hierarchy))) { return true; } } return false; }; // Go down from the tail to the head, since we might drop the tail. for (int32_t sbp_id = nd_sbp_sig_list->size() - 1; sbp_id >= 0; sbp_id--) { if (JUST(FilterSbp4Blobs(input_bns(), JUST(VectorAt(*nd_sbp_sig_list, sbp_id)))) || (check_output && JUST(FilterSbp4Blobs(output_bns(), JUST(VectorAt(*nd_sbp_sig_list, sbp_id)))))) { // Remove the Nd SBP candidate (*nd_sbp_sig_list)[sbp_id] = JUST(VectorAt(*nd_sbp_sig_list, nd_sbp_sig_list->size() - 1)); nd_sbp_sig_list->pop_back(); } } return Maybe::Ok(); } Maybe Operator::GreedilyFindMinCopyCostNdSbp( NdSbpSignature* nd_sbp_signature, const std::function(const std::string&)>& NdSbpInferHint4Ibn, const std::vector& nd_sbp_sig_list) const { int32_t select_sbp_idx = -1; double min_copy_cost = GetValidMaxCopyCost(); // We notice that we have a lot of inquiries asking for the cost. // If the candidate list only have one entry, select it to reduce the inquiries. // Normally, we support all the sbp combination for boxing. Therefore, we do not need to worry // about the case that we can not transfer to this sbp signature. Even if we do not support such // transfer, a report would be sent in boxing_with_middle_nodes.cpp. 
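// A minimal standalone sketch (not OneFlow API) of the shape-based filtering performed by
// FilterAndCheckValidSbpSignatureListByLogicalShape and FilterNdSbpSignatureListByLogicalShape
// above: S(axis) is only kept when the axis exists in the logical shape and that dimension is
// large enough to be split across all ranks. For a logical shape [3, 1024] on 4 devices, S(0)
// is dropped (3 < 4) while S(1) survives. The helper name is illustrative only.
#include <cstdint>
#include <iostream>
#include <vector>

// Returns true when splitting `logical_shape` at `axis` over `parallel_num` ranks is allowed.
bool SplitIsValid(const std::vector<int64_t>& logical_shape, int64_t axis, int64_t parallel_num) {
  if (axis < 0 || axis >= static_cast<int64_t>(logical_shape.size())) { return false; }
  return logical_shape[axis] >= parallel_num;
}

int main() {
  const std::vector<int64_t> shape = {3, 1024};
  std::cout << std::boolalpha;
  std::cout << "S(0) valid: " << SplitIsValid(shape, 0, 4) << "\n";  // false, 3 < 4
  std::cout << "S(1) valid: " << SplitIsValid(shape, 1, 4) << "\n";  // true, 1024 >= 4
  return 0;
}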
if (nd_sbp_sig_list.size() == 1) { select_sbp_idx = 0; } else { std::vector requires_same_sbp(input_bns().size()); for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { const auto& ibn = input_bns().at(ibn_id); const auto& blob_modifier_ = InputBlobModifier4Ibn(ibn); requires_same_sbp[ibn_id] = (blob_modifier_.has_is_mutable() && blob_modifier_.is_mutable()) || NotSupportBoxingDataType( JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc().data_type()); } // SBP_INFER_RULE_TAG = 1, pick the sbp signature which matches all the producers // or has the lowest cost // SBP_INFER_RULE_TAG = 2, pick the sbp signature which matches as much as possible // SBP_INFER_RULE_TAG = 3, pick the sbp signature which has the lowest cost static int32_t infer_rule = ParseIntegerFromEnv("SBP_INFER_RULE_TAG", 1); for (int32_t i = 0; i < nd_sbp_sig_list.size(); ++i) { double total_copy_cost = 0.0; double sum_priority_ratio = 0.0; // The initial ratio do not need to be a large one. // Since any copy cost less than infinity would reset the min_sum_priority_ratio. double min_sum_priority_ratio = 0.0; bool same_sbp_before_reduce = true; for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { const auto& ibn = input_bns().at(ibn_id); const auto& producer_infer_hint4ibn = JUST(NdSbpInferHint4Ibn(ibn)); same_sbp_before_reduce &= producer_infer_hint4ibn->nd_sbp() == JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn); // Skip the computation of priority ratio if SBP_INFER_RULE_TAG = 3 if (infer_rule != SbpInferRuleTag::kMinCost) { double priority_ratio = ComputeSbpInferPriority( producer_infer_hint4ibn->nd_sbp(), JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn), producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), requires_same_sbp[ibn_id], producer_infer_hint4ibn->logical_blob_desc().shape()); sum_priority_ratio += priority_ratio; // We do not accept any blob which has a priority ratio greater than 1 if (priority_ratio > 1.5) { total_copy_cost = GetMaxVal(); break; } // If SBP_INFER_RULE_TAG = 2 and the input blob has a matched sbp, // skip the computation of the transfer cost if (infer_rule == SbpInferRuleTag::kMatchAMAP && priority_ratio == 0.0) { continue; } } // Compute the cost and add them up total_copy_cost += JUST(ComputeCopyCostBetweenNdSbp( producer_infer_hint4ibn->nd_sbp(), JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn), producer_infer_hint4ibn->logical_blob_desc(), producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), requires_same_sbp[ibn_id])); // Reduce inquiries when the current cost is larger than the minimum cost // For SBP_INFER_RULE_TAG = 1, do not prune it since the all-matched case // might have larger cost. 
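// A minimal standalone sketch (not OneFlow API) of the greedy selection loop surrounding this
// point: every candidate signature accumulates a copy cost per input, a running minimum is
// kept, and the inner accumulation can stop early once the partial sum already exceeds the best
// total so far. The cost numbers below are made up purely for illustration.
#include <iostream>
#include <limits>
#include <vector>

// Pick the index of the candidate whose per-input costs sum to the smallest total.
int SelectMinCostCandidate(const std::vector<std::vector<double>>& cost_per_input_per_candidate) {
  int select_idx = -1;
  double min_total = std::numeric_limits<double>::max();
  for (int i = 0; i < static_cast<int>(cost_per_input_per_candidate.size()); ++i) {
    double total = 0.0;
    for (double c : cost_per_input_per_candidate[i]) {
      total += c;
      if (total > min_total) { break; }  // early pruning, as in the loop above
    }
    if (total < min_total) {
      min_total = total;
      select_idx = i;
    }
  }
  return select_idx;
}

int main() {
  // Three candidate signatures, two inputs each; candidate 1 is cheapest overall.
  std::vector<std::vector<double>> costs = {{4.0, 2.0}, {0.0, 1.0}, {3.0, 3.0}};
  std::cout << "selected candidate: " << SelectMinCostCandidate(costs) << "\n";  // prints 1
  return 0;
}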
if (infer_rule != SbpInferRuleTag::kAllMatch && total_copy_cost > min_copy_cost) { break; } } // For SBP_INFER_RULE_TAG = 1, select the all-matched case if found if (infer_rule == SbpInferRuleTag::kAllMatch && same_sbp_before_reduce && sum_priority_ratio == 0.0) { select_sbp_idx = i; break; } // Otherwise, select the case with the lowest cost if (total_copy_cost < min_copy_cost * kFloatDeviationMinus // Strict less than || (total_copy_cost <= min_copy_cost * kFloatDeviationPlus // Loose equal && sum_priority_ratio < min_sum_priority_ratio)) { select_sbp_idx = i; min_copy_cost = total_copy_cost; min_sum_priority_ratio = sum_priority_ratio; // NOLINT(clang-analyzer-deadcode.DeadStores) } } // Can't find any available sbp if (select_sbp_idx == -1) { std::ostringstream err; err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl; err << "candidate nd sbp signature are: " << *JUST(NdSbpSignatureListToString(nd_sbp_sig_list, input_bns(), output_bns())); err << ", but inputs sbp are:"; for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { const auto& ibn = input_bns().at(ibn_id); const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); err << " " << ibn << ": " << NdSbpToString(nd_sbp); if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } err << ";"; } return Error::RuntimeError() << err.str(); } } nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx)); return Maybe::Ok(); } Maybe Operator::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { // get op sbp signatures auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe { const SbpInferHint* sbp_infer_hint = JUST(SbpInferHint4Ibn(ibn)); return Maybe(sbp_infer_hint->logical_blob_desc()); }; SbpSignatureList valid_sbp_sig_list; { SbpSignatureList sbp_sig_candidates; // For 1d sbp, hierarchy value = parallel num JUST( GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), &sbp_sig_candidates)); // filter sbp signatures by logical shape JUST(FilterAndCheckValidSbpSignatureListByLogicalShape(sbp_sig_candidates, LogicalBlobDesc4Ibn, parallel_desc, &valid_sbp_sig_list)); } // filter sbp signatures by sbp signature conf SbpSignatureList filtered_sbp_sigs_by_conf; FilterSbpSignatureList(valid_sbp_sig_list, sbp_sig_conf, &filtered_sbp_sigs_by_conf); CHECK_GT_OR_RETURN(filtered_sbp_sigs_by_conf.sbp_signature_size(), 0) << op_name() << " has no maching sbp after flitering valid sbp list " << valid_sbp_sig_list.DebugString() << " with sbp hint " << sbp_sig_conf.DebugString(); if (filtered_sbp_sigs_by_conf.sbp_signature_size() == 1) { *sbp_signature = *filtered_sbp_sigs_by_conf.sbp_signature().begin(); return Maybe::Ok(); } // sort sbp signatures by copy cost, then return the one with least cost HashMap ibn2producer_sbp_parallel; for (const auto& ibn : input_bns()) { ibn2producer_sbp_parallel[ibn] = &(JUST(SbpInferHint4Ibn(ibn))->sbp_parallel()); } std::vector sorted_sbp_signatures; SortSbpSignatureListByCopyCost(filtered_sbp_sigs_by_conf, input_bns(), SbpInferHint4Ibn, CalcOrderValue4SbpSig, &sorted_sbp_signatures); *sbp_signature = *sorted_sbp_signatures.at(0); return Maybe::Ok(); } Maybe Operator::InferNdSbpSignatureIf( const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) { NdSbpSignature nd_sbp_signature; 
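// A minimal standalone sketch (not OneFlow API) of the removal idiom used when filtering
// nd_sbp_sig_list both above and below this point: walk the vector from the tail, and delete a
// rejected candidate by overwriting it with the last element and popping the back, which avoids
// shifting the remaining entries. The predicate here is illustrative only.
#include <iostream>
#include <vector>

// Remove every odd value; survivor order is not preserved, which is fine for candidate lists.
void RemoveOdd(std::vector<int>* v) {
  for (int i = static_cast<int>(v->size()) - 1; i >= 0; --i) {
    if ((*v)[i] % 2 != 0) {
      (*v)[i] = v->back();
      v->pop_back();
    }
  }
}

int main() {
  std::vector<int> v = {1, 2, 3, 4, 5, 6};
  RemoveOdd(&v);
  for (int x : v) { std::cout << x << " "; }  // prints the surviving even values: 4 2 6
  std::cout << "\n";
  return 0;
}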
JUST(InferNdSbpSignature(&nd_sbp_signature, nd_sbp_constraints, parallel_desc, NdSbpInferHint4Ibn)); JUST(FillNdSbpSignature(nd_sbp_signature)); return Maybe::Ok(); } Maybe Operator::InferNdSbpSignature( NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const { const auto& parallel_hierarchy = parallel_desc.hierarchy(); CHECK_GT(parallel_hierarchy->NumAxes(), 0); if (parallel_hierarchy->NumAxes() == 1) { // Infer 1d sbp HashMap ibn2sbp_infer_hint; for (const auto& ibn : input_bns()) { const NdSbpInferHint* hint = JUST(NdSbpInferHint4Ibn(ibn)); ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(&hint->parallel_desc(), &hint->logical_blob_desc(), &hint->nd_sbp().sbp_parallel(0))); } SbpSignature sbp_constraints; NdSbpSignatureToSbpSignature(nd_sbp_constraints, &sbp_constraints); SbpSignature sbp_signature; JUST(InferSbpSignature(&sbp_signature, sbp_constraints, ibn2sbp_infer_hint)); SbpSignatureToNdSbpSignature(sbp_signature, nd_sbp_signature); return Maybe::Ok(); } else { // Infer nd sbp const auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe { return JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc(); }; std::vector nd_sbp_sig_list; JUST(GetValidNdSbpSignatureList(LogicalBlobDesc4Ibn, parallel_desc, &nd_sbp_sig_list, /*check_output=*/false)); // Filter nd_sbp according to `nd_sbp_constraints` for (int32_t i = nd_sbp_sig_list.size() - 1; i >= 0; --i) { // If any blob do not match nd_sbp_constraints, the candidate nd_sbp will be deleted. if (/*not_match=*/std::any_of(input_bns().begin(), input_bns().end(), [&](const auto& ibn) { const auto nd_sbp_constraints_it = nd_sbp_constraints.bn_in_op2nd_sbp().find(ibn); if (nd_sbp_constraints_it != nd_sbp_constraints.bn_in_op2nd_sbp().end()) { return nd_sbp_sig_list.at(i).bn_in_op2nd_sbp().at(ibn) != nd_sbp_constraints_it->second; } return false; })) { nd_sbp_sig_list.at(i) = nd_sbp_sig_list.back(); nd_sbp_sig_list.pop_back(); } } CHECK_OR_RETURN(!nd_sbp_sig_list.empty()) << "Empty sbp signature after filtering for " << op_name(); JUST(GreedilyFindMinCopyCostNdSbp(nd_sbp_signature, NdSbpInferHint4Ibn, nd_sbp_sig_list)); return Maybe::Ok(); } } Maybe Operator::InferLocalSignatureIf( std::function(const std::string&)> LocalSigInferHint4Ibn, bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) { return InferLocalSignature(std::move(LocalSigInferHint4Ibn), is_local_parallel_view_conf, parallel_desc); } // Compute time complexity for given blob description and sbp signature. // Use function to replace the HashMap from logical blob id to blob description pointer. 
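// A minimal standalone sketch (not OneFlow API) of the cost model implemented by
// Operator::GetComputeComplexity below: each blob contributes its logical element count divided
// by every hierarchy dimension along which it is split, and an unsplittable axis makes the whole
// signature prohibitively expensive. Shapes, hierarchy and the helper name are illustrative.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// sbp_per_dim[i] is the split axis used at hierarchy dimension i, or -1 for broadcast/partial.
double BlobComplexity(const std::vector<int64_t>& shape, const std::vector<int64_t>& hierarchy,
                      const std::vector<int64_t>& sbp_per_dim) {
  double cost = 1.0;
  for (int64_t dim : shape) { cost *= static_cast<double>(dim); }
  for (size_t i = 0; i < hierarchy.size(); ++i) {
    const int64_t axis = sbp_per_dim[i];
    if (axis < 0) { continue; }  // broadcast: every rank holds the full blob
    if (axis >= static_cast<int64_t>(shape.size()) || shape[axis] < hierarchy[i]) {
      return std::numeric_limits<float>::max();  // split not possible, reject this signature
    }
    cost /= static_cast<double>(hierarchy[i]);
  }
  return cost;
}

int main() {
  // Logical blob [64, 1024] on a [2, 4] hierarchy with nd-SBP (S(0), S(1)):
  // 64 * 1024 / (2 * 4) = 8192 elements per rank.
  std::cout << BlobComplexity({64, 1024}, {2, 4}, {0, 1}) << "\n";  // prints 8192
  return 0;
}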
Maybe Operator::GetComputeComplexity( NdSbpSignature* sbp_signature, std::function logical_blob_desc4bn, const ParallelDesc& parallel_desc) const { const auto& sbp_bn_in_op2nd_sbp = sbp_signature->bn_in_op2nd_sbp(); double complexity = 0; const auto& parallel_hierarchy = *parallel_desc.hierarchy(); auto ComputeComplexity4Blobs = [&](const PbRpf& bns) -> Maybe { for (const auto& bn : bns) { const BlobDesc& logical_blob_desc = logical_blob_desc4bn(bn); const NdSbp& nd_sbp = sbp_bn_in_op2nd_sbp.at(bn); CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), parallel_hierarchy.NumAxes()) << "At this moment, the dimension of nd SBP should be equal to the depth of hierarchy in " << "parallel description."; double total_cost = logical_blob_desc.shape().elem_cnt(); for (int32_t sbp_dim = 0; sbp_dim < nd_sbp.sbp_parallel_size(); sbp_dim++) { const auto& sbp = nd_sbp.sbp_parallel(sbp_dim); if (sbp.has_split_parallel()) { const int64_t axis = sbp.split_parallel().axis(); if (axis >= logical_blob_desc.shape().NumAxes() || logical_blob_desc.shape().At(axis) < parallel_hierarchy.At(sbp_dim)) { complexity = GetMaxVal(); return Maybe::Ok(); } else { total_cost /= parallel_hierarchy.At(sbp_dim); } } } complexity += total_cost; } return Maybe::Ok(); }; JUST(ComputeComplexity4Blobs(input_bns())); JUST(ComputeComplexity4Blobs(output_bns())); return complexity; } std::string DebugString4LocalHint( std::function(const std::string&)> LocalSigInferHint4Ibn, const Operator& op) { std::string ret; for (const auto& ibn : op.input_bns()) { const auto& infer_hint = *CHECK_JUST(LocalSigInferHint4Ibn(ibn)); bool is_local = infer_hint.is_local_parallel_view(); ret += "arg: " + ibn + ", is_local: " + (is_local ? "true" : "false") + "\n"; } return ret; } Maybe Operator::InferLocalSignature( std::function(const std::string&)> LocalSigInferHint4Ibn, // NOLINT bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) { HashSet is_local_parallel_view_values; for (const auto& ibn : input_bns()) { const auto& infer_hint = *JUST(LocalSigInferHint4Ibn(ibn)); is_local_parallel_view_values.insert(infer_hint.is_local_parallel_view()); } CHECK_LE_OR_RETURN(is_local_parallel_view_values.size(), 1) << "mixed parallel_views are disallowed." 
<< "\n=========== is_mirrrored_conf ===========\n" << DebugString4LocalHint(LocalSigInferHint4Ibn, *this) << "\n=========== op_cnf ===========\n" << op_conf().DebugString(); if (is_local_parallel_view_values.size() == 1) { is_local_parallel_view_conf = *is_local_parallel_view_values.begin(); } if (is_local_parallel_view_conf) { for (const auto& ibn : input_bns()) { const auto& infer_hint = *JUST(LocalSigInferHint4Ibn(ibn)); CHECK_EQ_OR_RETURN(infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num()); } } const auto SetIsLocalParallel = [&](const std::string& bn_in_op) { if (is_local_parallel_view_conf) { MutOptLocalParallel(bn_in_op)->mutable_local_parallel(); } else { MutOptLocalParallel(bn_in_op)->clear_local_parallel(); } }; for (const auto& ibn : input_bns()) { SetIsLocalParallel(ibn); } for (const auto& obn : output_bns()) { SetIsLocalParallel(obn); } return Maybe::Ok(); } Maybe Operator::sbp_signature() const { CHECK_OR_RETURN(sbp_signature_) << "sbp signature not infered"; return sbp_signature_.get(); } Maybe Operator::nd_sbp_signature() const { CHECK_OR_RETURN(nd_sbp_signature_) << "parallel distribution signature not infered"; return nd_sbp_signature_.get(); } BlobLastUsedSignature* Operator::mut_blob_last_used_signature() { if (!blob_last_used_signature_) { blob_last_used_signature_.reset(new BlobLastUsedSignature()); } return blob_last_used_signature_.get(); } BlobBackwardUsedSignature* Operator::mut_blob_backward_used_signature() { if (!blob_backward_used_signature_) { blob_backward_used_signature_.reset(new BlobBackwardUsedSignature()); } return blob_backward_used_signature_.get(); } Maybe Operator::SbpParallel4BnInOp(const std::string& bn_in_op) const { CHECK_OR_RETURN(sbp_signature_) << "sbp signature not infered"; const auto& map = sbp_signature_->bn_in_op2sbp_parallel(); const auto& iter = map.find(bn_in_op); CHECK_OR_RETURN(iter != map.end()) << "blob_name " << bn_in_op << " not found in sbp signature"; return &iter->second; } Maybe Operator::NdSbp4BnInOp(const std::string& bn_in_op) const { CHECK_OR_RETURN(nd_sbp_signature_) << "parallel distribution signature not infered"; const auto& map = nd_sbp_signature_->bn_in_op2nd_sbp(); const auto& iter = map.find(bn_in_op); CHECK_OR_RETURN(iter != map.end()) << "op_name " << op_name() << " blob_name " << bn_in_op << " not found in parallel distribution"; return &iter->second; } Maybe Operator::OptLocalParallel4BnInOp( const std::string& bn_in_op) const { CHECK_OR_RETURN(local_signature_) << "local signature not infered"; const auto& map = local_signature_->bn_in_op2opt_local_parallel(); const auto& iter = map.find(bn_in_op); CHECK_OR_RETURN(iter != map.end()) << "blob_name " << bn_in_op << " not found in local signature"; return &iter->second; } OptLocalParallel* Operator::MutOptLocalParallel(const std::string& bn_in_op) { if (!local_signature_) { local_signature_.reset(new LocalSignature()); } auto* map = local_signature_->mutable_bn_in_op2opt_local_parallel(); return &(*map)[bn_in_op]; } namespace { bool HasBlobDescWithField(std::function GetBlobDesc4BnInOp, const PbRpf& bn_in_ops, std::function Predicator4BlobDesc) { for (const std::string& bn_in_op : bn_in_ops) { const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op); if (blob_desc && Predicator4BlobDesc(blob_desc)) { return true; } } return false; } } // namespace void Operator::GenKernelConf( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { auto* dtype_signature = 
kernel_conf->mutable_dtype_signature(); for (const std::string& ibn : input_bns()) { const BlobDesc* blob_desc = GetBlobDesc4BnInOp(ibn); if (blob_desc == nullptr) { continue; } (*dtype_signature->mutable_name2dtype())[ibn] = blob_desc->data_type(); } CHECK_JUST(ToOpAttribute(kernel_conf->mutable_op_attribute())); kernel_conf->set_all_blobs_are_static( !HasBlobDescWithField(GetBlobDesc4BnInOp, output_bns(), [](const BlobDesc* blob_desc) { return blob_desc->is_dynamic(); })); { DataType data_type = GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, output_bns()); if (data_type == DataType::kInvalidDataType) { data_type = GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, input_bns()); } kernel_conf->set_data_type(data_type); } if (parallel_ctx != nullptr) { *(kernel_conf->mutable_parallel_ctx()) = *parallel_ctx; } VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf); } void Operator::VirtualGenKernelConf( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {} void Operator::AddLbi2OutputIndex(const LogicalBlobId& lbi, int32_t output_index) { CHECK(lbi2output_index_.emplace(lbi, output_index).second); } std::string Operator::Bn2ConfName(const std::string& bn) const { return GetStrValInPbFdOrPbRpf(GetCustomizedConf(), bn); } LogicalBlobId Operator::lbi4ibn(const std::string& input_bn) const { return GenLogicalBlobId(Bn2ConfName(input_bn)); } LogicalBlobId Operator::lbi4obn(const std::string& output_bn) const { LogicalBlobId ret; ret.set_op_name(op_name()); ret.set_blob_name(Bn2ConfName(output_bn)); return ret; } LogicalBlobId Operator::tbn2lbi(const std::string& tmp_bn) const { LogicalBlobId ret; ret.set_op_name(op_name()); ret.set_blob_name(tmp_bn); return ret; } void Operator::EnrollTmpBn(const std::string& tbn) { *tmp_bns_.Add() = tbn; CHECK(mut_bn_in_op2lbi()->insert({tbn, tbn2lbi(tbn)}).second); } InputBlobModifier* Operator::EnrollInputBn(const std::string& ibn, bool has_diff) { LogicalBlobId lbi = lbi4ibn(ibn); auto* map = arg_modifier_signature_.mutable_ibn2input_blob_modifier(); const auto& pair = map->insert({ibn, InputBlobModifier()}); CHECK(pair.second); const int32_t input_index = input_bns_.size(); CHECK( bn2index_pair_.emplace(ibn, std::make_pair(BlobNameTag::kInputBlobName, input_index)).second); *input_bns_.Add() = ibn; CHECK(mut_bn_in_op2lbi()->insert({ibn, lbi}).second); auto* ret = &pair.first->second; ret->set_requires_grad(has_diff); return ret; } const InputBlobModifier& Operator::InputBlobModifier4Ibn(const std::string& ibn) const { return arg_modifier_signature_.ibn2input_blob_modifier().at(ibn); } const OutputBlobModifier& Operator::OutputBlobModifier4Obn(const std::string& obn) const { return arg_modifier_signature_.obn2output_blob_modifier().at(obn); } InputBlobModifier* Operator::MutInputBlobModifier4Ibn(const std::string& ibn) { auto* map = arg_modifier_signature_.mutable_ibn2input_blob_modifier(); return &map->at(ibn); } OutputBlobModifier* Operator::MutOutputBlobModifier4Obn(const std::string& obn) { auto* map = arg_modifier_signature_.mutable_obn2output_blob_modifier(); return &map->at(obn); } void Operator::EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num, bool has_diff) { FOR_RANGE(int32_t, i, 0, num) { EnrollInputBn(GenRepeatedBn(ibn_prefix, i), has_diff); } } void Operator::EnrollRepeatedInputBn(const std::string& ibn_prefix, bool has_diff) { EnrollRepeatedInputBn(ibn_prefix, GetPbRpfFromCustomizedConf(ibn_prefix).size(), has_diff); } void Operator::EnrollRepeatedInputBn(const 
std::string& ibn_prefix, int32_t num) { EnrollRepeatedInputBn(ibn_prefix, num, true); } void Operator::EnrollRepeatedInputBn(const std::string& ibn_prefix) { EnrollRepeatedInputBn(ibn_prefix, true); } OutputBlobModifier* Operator::EnrollOutputBn(const std::string& obn, bool has_diff) { LogicalBlobId lbi = lbi4obn(obn); auto* map = arg_modifier_signature_.mutable_obn2output_blob_modifier(); const auto& pair = map->insert({obn, OutputBlobModifier()}); CHECK(pair.second); auto* ret = &pair.first->second; const int32_t output_index = output_bns_.size(); CHECK(bn2index_pair_.emplace(obn, std::make_pair(BlobNameTag::kOutputBlobName, output_index)) .second); AddLbi2OutputIndex(lbi, output_index); *output_bns_.Add() = obn; CHECK(mut_bn_in_op2lbi()->insert({obn, lbi}).second); ret->set_requires_grad(has_diff); return ret; } void Operator::EnrollRepeatedOutputBnWithSetter( const std::string& obn_prefix, int32_t num, bool has_diff, const std::function& ModifierSetter) { FOR_RANGE(int32_t, i, 0, num) { ModifierSetter(EnrollOutputBn(GenRepeatedBn(obn_prefix, i), has_diff)); } } void Operator::EnrollRepeatedOutputBnWithSetter( const std::string& obn_prefix, bool has_diff, const std::function& ModifierSetter) { EnrollRepeatedOutputBnWithSetter(obn_prefix, GetPbRpfFromCustomizedConf(obn_prefix).size(), has_diff, ModifierSetter); } void Operator::EnrollRepeatedOutputBnWithSetter( const std::string& obn_prefix, int32_t num, const std::function& ModifierSetter) { EnrollRepeatedOutputBnWithSetter(obn_prefix, num, true, ModifierSetter); } void Operator::EnrollRepeatedOutputBnWithSetter( const std::string& obn_prefix, const std::function& ModifierSetter) { EnrollRepeatedOutputBnWithSetter(obn_prefix, true, ModifierSetter); } void Operator::EnrollRepeatedOutputBn(const std::string& obn_prefix, int32_t num, bool has_diff) { FOR_RANGE(int32_t, i, 0, num) { EnrollOutputBn(GenRepeatedBn(obn_prefix, i), has_diff); } } void Operator::EnrollRepeatedOutputBn(const std::string& obn_prefix, bool has_diff) { EnrollRepeatedOutputBn(obn_prefix, GetPbRpfFromCustomizedConf(obn_prefix).size(), has_diff); } void Operator::EnrollRepeatedOutputBn(const std::string& obn_prefix, int32_t num) { EnrollRepeatedOutputBn(obn_prefix, num, true); } void Operator::EnrollRepeatedOutputBn(const std::string& obn_prefix) { EnrollRepeatedOutputBn(obn_prefix, true); } std::string GenRepeatedBn(const std::string& bn_prefix, int32_t idx) { CHECK_GE(idx, 0); return bn_prefix + "_" + std::to_string(idx); } std::pair GenUnRepeatedBn(const std::string& bn) { return GetFieldNameAndIndex4StrVal(bn); } bool IsCpuOnly(const std::string& user_op_type_name) { auto* registration_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_type_name); CHECK(registration_val != nullptr) << "user_op_type_name: " << user_op_type_name; return registration_val->cpu_only_supported; } bool IsCpuOnly(const OperatorConf& op_conf) { OperatorConf::OpTypeCase op_type_case = op_conf.op_type_case(); using CpuOnly = OnlyCpuSupportPredicator; auto* ptr = NewObj(op_type_case); CHECK(ptr != nullptr) << "op_conf\n" << op_conf.DebugString(); if (*std::unique_ptr(ptr)) { return true; } if (!op_conf.has_user_conf()) { return false; } return IsCpuOnly(op_conf.user_conf().op_type_name()); } Maybe ConstructOp(const OperatorConf& op_conf, DeviceType device_type) { std::shared_ptr dev_op_conf = std::make_shared(op_conf); dev_op_conf->set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type))); auto op = JUST(CheckAndConstructOp(dev_op_conf)); return op; } Maybe 
ConstructOp(const OperatorConf& op_conf) { if (IsCpuOnly(op_conf)) { return JUST(ConstructOp(op_conf, DeviceType::kCPU)); } return CheckAndConstructOp(std::make_shared(op_conf)); } Symbol Operator::GetOpConfWithoutOpNameAndLbn() const { OperatorConf op_conf(this->op_conf()); op_conf.set_name("undefined-op-name"); PbMessage* op_type_conf = MutableMessageInPbMessage(&op_conf, op_conf.op_type_case()); for (const auto& ibn : input_bns()) { if (!HasStrFieldInPbFdOrPbRpf(*op_type_conf, ibn)) { continue; } ReplaceInputLbnInOpCustomizedConf(&op_conf, ibn, "undefined-op-name/undefined-ibn"); } return SymbolOf(op_conf); } std::shared_ptr Operator::GetOpAttributeWithoutOpNameAndLbn() const { auto op_attribute = std::make_shared(); CHECK_JUST(ToOpAttribute(op_attribute.get())); op_attribute->mutable_sbp_signature(); *op_attribute->mutable_op_conf() = *GetOpConfWithoutOpNameAndLbn(); return op_attribute; } Maybe Operator::GetInputIndex(const std::string& ibn) const { auto it = bn2index_pair_.find(ibn); CHECK_OR_RETURN(it != bn2index_pair_.end()); CHECK_EQ_OR_RETURN(it->second.first, BlobNameTag::kInputBlobName); return it->second.second; } Maybe Operator::GetOutputIndex(const std::string& obn) const { auto it = bn2index_pair_.find(obn); CHECK_OR_RETURN(it != bn2index_pair_.end()); CHECK_EQ_OR_RETURN(it->second.first, BlobNameTag::kOutputBlobName); return it->second.second; } Maybe Operator::GetOutputIndex(const LogicalBlobId& lbi) const { auto it = lbi2output_index_.find(lbi); CHECK_OR_RETURN(it != lbi2output_index_.end()); return it->second; } Maybe Operator::ToOpAttribute(OpAttribute* op_attribute) const { *op_attribute->mutable_input_bns() = input_bns_; *op_attribute->mutable_output_bns() = output_bns_; *op_attribute->mutable_tmp_bns() = tmp_bns_; *op_attribute->mutable_op_conf() = op_conf(); *op_attribute->mutable_arg_signature() = arg_signature_; *op_attribute->mutable_arg_modifier_signature() = arg_modifier_signature_; if (blob_last_used_signature_) { *op_attribute->mutable_blob_last_used_signature() = *blob_last_used_signature_; } else { op_attribute->clear_blob_last_used_signature(); } if (blob_backward_used_signature_) { *op_attribute->mutable_blob_backward_used_signature() = *blob_backward_used_signature_; } else { op_attribute->clear_blob_backward_used_signature(); } if (sbp_signature_) { *op_attribute->mutable_sbp_signature() = *sbp_signature_; } else { op_attribute->clear_sbp_signature(); } if (nd_sbp_signature_) { *op_attribute->mutable_nd_sbp_signature() = *nd_sbp_signature_; } else { op_attribute->clear_nd_sbp_signature(); } if (local_signature_) { *op_attribute->mutable_local_signature() = *local_signature_; } else { op_attribute->clear_local_signature(); } if (input_index2logical_blob_desc_) { JUST(FillLogicalBlobDescSignature( input_bns(), input_index2logical_blob_desc_, op_attribute->mutable_logical_blob_desc_signature()->mutable_bn_in_op2blob_desc())); } if (output_index2logical_blob_desc_) { JUST(FillLogicalBlobDescSignature( output_bns(), output_index2logical_blob_desc_, op_attribute->mutable_logical_blob_desc_signature()->mutable_bn_in_op2blob_desc())); } if (op_parallel_desc_) { *op_attribute->mutable_parallel_conf_signature()->mutable_op_parallel_conf() = op_parallel_desc_->parallel_conf(); } if (bn2parallel_desc_) { auto* map = op_attribute->mutable_parallel_conf_signature()->mutable_bn_in_op2parallel_conf(); for (const auto& pair : *bn2parallel_desc_) { const bool has_same_parallel_conf_as_op = *op_parallel_desc_ == *pair.second; if (!has_same_parallel_conf_as_op) { 
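// Note: only blob-level parallel confs that differ from the op-level parallel conf are
// written into bn_in_op2parallel_conf below; blobs that share the op's placement are omitted.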
(*map)[pair.first] = pair.second->parallel_conf(); } } } return Maybe::Ok(); } LogicalBlobId GenLogicalBlobId(const std::string& lbn) { LogicalBlobId lbi; size_t pos = lbn.find('/'); CHECK_NE(pos, std::string::npos) << "lbn: " << lbn; lbi.set_op_name(lbn.substr(0, pos)); std::string blob_name_with_hit = lbn.substr(pos + 1); size_t vbar_pos = blob_name_with_hit.rfind('|'); std::string blob_name_with_split_hit = blob_name_with_hit.substr(0, vbar_pos); size_t split_pos = blob_name_with_split_hit.rfind(':'); lbi.set_blob_name(blob_name_with_split_hit.substr(0, split_pos)); return lbi; } Maybe GetSbpParallelInLbnOrNothing(const std::string& lbn, SbpParallel* sbp) { size_t vbar_pos = lbn.rfind('|'); std::string lbn_with_split_hint = lbn.substr(0, vbar_pos); size_t pos = lbn_with_split_hint.rfind(':'); CHECK_NE(pos, lbn_with_split_hint.length() - 1); if (pos == std::string::npos) { return false; } std::string split_hint = lbn_with_split_hint.substr(pos + 1); if (split_hint[0] == 'S') { std::string axis_str = split_hint.substr(1); CHECK_OR_RETURN(IsStrInt(axis_str)); sbp->mutable_split_parallel()->set_axis(oneflow_cast(axis_str)); } else if (split_hint[0] == 'B') { sbp->mutable_broadcast_parallel(); } else { return Error::CheckFailedError() << "split hint only support 'S' or 'B', but get:" << split_hint[0]; } return true; } Maybe ParseDisableBoxingFlag(const std::string& lbn_with_hint, bool* disable_boxing) { size_t pos = lbn_with_hint.rfind('|'); if (pos == std::string::npos) { return false; } CHECK_NE(pos, lbn_with_hint.length() - 1); std::string disable_boxing_str = lbn_with_hint.substr(pos + 1); CHECK_OR_RETURN(IsStrInt(disable_boxing_str)); *disable_boxing = oneflow_cast(disable_boxing_str); return true; } std::string GetInputLbnInOpCustomizedConf(const OperatorConf& op_conf, const std::string& fd_name_may_have_idx) { const PbMessage& msg = GetMessageInPbMessage(op_conf, op_conf.op_type_case()); const PbMessage* msg_ptr = &msg; const UserOpConf* user_conf = dynamic_cast(msg_ptr); if (user_conf) { std::pair pair = GetFieldNameAndIndex4StrVal(fd_name_may_have_idx); if (user_conf->input().find(pair.first) != user_conf->input().end()) { return user_conf->input().at(pair.first).s(pair.second); } else { LOG(WARNING) << "cannot find input arg val in user op conf. (arg_name = " << pair.first << ", id = " << std::to_string(pair.second) << ")"; return ""; } } else { return GetStrValInPbFdOrPbRpf(msg, fd_name_may_have_idx); } } // return old value std::string ReplaceInputLbnInOpTypeConf(PbMessage* msg, const std::string& fd_name_may_have_idx, const std::string& new_val) { UserOpConf* user_conf = dynamic_cast(msg); std::string old_val; if (user_conf) { std::pair pair = GetFieldNameAndIndex4StrVal(fd_name_may_have_idx); CHECK(user_conf->input().find(pair.first) != user_conf->input().end()) << "cannot find input arg val in user op conf. 
(arg_name = " << pair.first << ", id = " << std::to_string(pair.second) << ")\n" << " new lbn = " << new_val; old_val = user_conf->input().at(pair.first).s(pair.second); (*(user_conf->mutable_input()))[pair.first].set_s(pair.second, new_val); } else { old_val = ReplaceStrValInPbFdOrPbRpf(msg, fd_name_may_have_idx, new_val); } return old_val; } std::string ReplaceInputLbnInOpCustomizedConf(OperatorConf* op_conf, const std::string& fd_name_may_have_idx, const std::string& new_val) { PbMessage* op_type_conf = MutableMessageInPbMessage(op_conf, op_conf->op_type_case()); return ReplaceInputLbnInOpTypeConf(op_type_conf, fd_name_may_have_idx, new_val); } bool operator==(const OperatorConf& lhs, const OperatorConf& rhs) { return PbMd().Equals(lhs, rhs); } namespace { Maybe InferOpOutSbpParallel( Operator* op, const OpNodeSignature& upstream_signature, const std::function& ConstBlobDesc4Ibn, const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc) { const auto& SbpParallel4Ibn = [&](const std::string& ibn) -> const SbpParallel* { const auto& map = upstream_signature.sbp_signature().bn_in_op2sbp_parallel(); return &map.at(ibn); }; HashMap ibn2sbp_infer_hint; for (const std::string& ibn : op->input_bns()) { const ParallelDesc* pd = ¶llel_desc; const BlobDesc* logical_blob_desc = &ConstBlobDesc4Ibn(ibn); const SbpParallel* sbp_parallel = SbpParallel4Ibn(ibn); ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(pd, logical_blob_desc, sbp_parallel)); } SbpSignature sbp_signature; JUST(op->InferSbpSignature(&sbp_signature, sbp_sig_conf, ibn2sbp_infer_hint)); JUST(op->FillSbpSignature(sbp_signature)); return Maybe::Ok(); } Maybe InferLocalSignature(Operator* op, const OpNodeSignature& upstream_signature, bool is_local, const ParallelDesc& parallel_desc) { HashMap ibn2local_sig_infer_hint; for (const std::string& ibn : op->input_bns()) { const auto& map = upstream_signature.local_signature().bn_in_op2opt_local_parallel(); const auto& opt_local_parallel = map.at(ibn); ibn2local_sig_infer_hint.emplace( ibn, LocalSigInferHint(¶llel_desc, opt_local_parallel.has_local_parallel())); } const auto& LocalSigInferHint4Ibn = [&](const std::string& ibn) -> Maybe { const auto& iter = ibn2local_sig_infer_hint.find(ibn); CHECK_OR_RETURN(iter != ibn2local_sig_infer_hint.end()) << "input blob not found. 
ibn: " << ibn; return &iter->second; }; JUST(op->InferLocalSignatureIf(LocalSigInferHint4Ibn, is_local, parallel_desc)); return Maybe::Ok(); } Maybe CheckOpInputSignature(const Operator& op, const OpNodeSignature& upstream_signature) { for (const auto& ibn : op.input_bns()) { { CHECK_OR_RETURN(upstream_signature.has_logical_blob_desc_signature()); const auto& map = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc(); CHECK_OR_RETURN(map.find(ibn) != map.end()); } { CHECK_OR_RETURN(upstream_signature.has_sbp_signature()); const auto& map = upstream_signature.sbp_signature().bn_in_op2sbp_parallel(); CHECK_OR_RETURN(map.find(ibn) != map.end()); // NOLINT } { CHECK_OR_RETURN(upstream_signature.has_local_signature()); // NOLINT const auto& map = upstream_signature.local_signature().bn_in_op2opt_local_parallel(); CHECK_OR_RETURN(map.find(ibn) != map.end()); } } return Maybe::Ok(); } } // namespace Maybe ConstructAndInferOp(const OperatorConf& op_conf, const OpNodeSignature& upstream_signature, const Scope& scope) { const auto& parallel_desc = *JUST(scope.GetParallelDesc(op_conf)); bool is_local = scope.opt_local_parallel_conf().has_local_parallel(); const auto& op = JUST(ConstructOp(op_conf)); JUST(CheckOpInputSignature(*op, upstream_signature)); JUST(op->FillOpParallelDesc(parallel_desc)); HashMap> bn_in_op2blob_desc; for (const auto& ibn : op->input_bns()) { const auto& map = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc(); bn_in_op2blob_desc[ibn].reset(new BlobDesc(map.at(ibn))); } const auto& ConstBlobDesc4Ibn = [&](const std::string& ibn) -> const BlobDesc& { return *bn_in_op2blob_desc.at(ibn); }; JUST(op->FillLogicalInBlobDesc(ConstBlobDesc4Ibn)); // infer is_local JUST(InferLocalSignature(op.get(), upstream_signature, is_local, parallel_desc)); SbpSignature sbp_sig_conf; // iner sbp JUST(InferOpOutSbpParallel(op.get(), upstream_signature, ConstBlobDesc4Ibn, sbp_sig_conf, parallel_desc)); // infer logical blob_desc JUST(op->InferLogicalOutBlobDescsIf()); return op; } namespace { template Maybe Get1dHierarchyPhysicalShape(const Shape& logical_shape, const SbpParallelT& sbp_parallel, const int64_t parallel_num, const int64_t parallel_id) { std::shared_ptr physical = std::make_shared(logical_shape); if (sbp_parallel.has_split_parallel()) { const int64_t axis = sbp_parallel.split_parallel().axis(); if (logical_shape.At(axis) > 0) { CHECK_GE_OR_RETURN(logical_shape.At(axis), parallel_num); const BalancedSplitter bs(logical_shape.At(axis), parallel_num); physical->Set(axis, bs.At(parallel_id).size()); } } else if (sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel()) { // do nothing } else { UNIMPLEMENTED(); } return physical; } Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const int64_t parallel_id) { const auto& parallel_hierarchy = *parallel_desc.hierarchy(); std::shared_ptr physical = std::make_shared(logical_shape); Stride hierarch_stride(parallel_hierarchy); FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) { const auto& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { const int64_t split_axis = sbp_parallel.split_parallel().axis(); // Both the lazy and eager mode support unbalanced splitting now if (physical->At(split_axis) > 0) { CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) << Error::RuntimeError() << "Expected size at split axis (" << split_axis << ") of logical shape must be be greater 
than or equal to parallel num, but got " "logical_shape: " << logical_shape.ToString() << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc))) << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp)); const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); } } } return physical; } } // namespace Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, int64_t parallel_id) { CHECK_GE_OR_RETURN(parallel_id, 0); CHECK_LT_OR_RETURN(parallel_id, parallel_desc.hierarchy()->elem_cnt()); CHECK_EQ_OR_RETURN(parallel_desc.hierarchy()->NumAxes(), nd_sbp.sbp_parallel_size()); if (parallel_desc.hierarchy()->NumAxes() == 1) { return Get1dHierarchyPhysicalShape(logical_shape, nd_sbp.sbp_parallel(0), parallel_desc.hierarchy()->elem_cnt(), parallel_id); } else { return GetNdHierarchyPhysicalShape(logical_shape, nd_sbp, parallel_desc, parallel_id); } } Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const ParallelContext& parallel_ctx) { return GetPhysicalShape(logical_shape, nd_sbp, parallel_desc, parallel_ctx.parallel_id()); } Maybe GetLogicalShape(const Shape& physical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) { const auto& parallel_hierarchy = *parallel_desc.hierarchy(); CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), nd_sbp.sbp_parallel_size()); std::shared_ptr logical_shape = std::make_shared(physical_shape); for (int i = parallel_hierarchy.NumAxes() - 1; i >= 0; --i) { const auto& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { const int64_t split_axis = sbp_parallel.split_parallel().axis(); logical_shape->Set(split_axis, logical_shape->At(split_axis) * parallel_hierarchy.At(i)); } } return logical_shape; } } // namespace oneflow ================================================ FILE: oneflow/core/operator/operator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_OPERATOR_OPERATOR_H_ #define ONEFLOW_CORE_OPERATOR_OPERATOR_H_ #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/local_parallel.pb.h" #include "oneflow/core/operator/op_conf_util.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/kernel/kernel.pb.h" #include "oneflow/core/job/nd_sbp_infer_hint.h" namespace oneflow { class LocalSigInferHint; class OpNodeSignature; class Scope; class Operator { public: OF_DISALLOW_COPY_AND_MOVE(Operator); Operator(); virtual ~Operator() = default; // Maybe Init(const OperatorConf& op_conf); Maybe Init(std::shared_ptr op_conf); virtual Maybe InitFromOpConf() = 0; // bn_in_op <-> lbi const LogicalBlobId& BnInOp2Lbi(const std::string& bn_in_op) const; // Getters const std::string& op_name() const { return op_conf().name(); } const std::string& op_loc() const { return op_conf().loc(); } DeviceType device_type() const; const OperatorConf& op_conf() const; std::shared_ptr shared_op_conf() const; const PbMessage& GetCustomizedConf() const { return GetMessageInPbMessage(op_conf(), op_conf().op_type_case()); } template T GetValFromCustomizedConf(const std::string& field_name) const { return GetValFromPbMessage(GetCustomizedConf(), field_name); } template const PbRpf& GetPbRpfFromCustomizedConf(const std::string& field_name) const { return GetPbRpfFromPbMessage(GetCustomizedConf(), field_name); } const std::string& SoleIbn() const; const std::string& SoleObn() const; const std::string& SoleTbn() const; Maybe obn4lbi(const LogicalBlobId& lbi) const; const PbRpf& input_bns() const; const PbRpf& output_bns() const; const PbRpf& tmp_bns() const; const PbRpf& input_output_bns() const; Maybe FillOpParallelDesc(const ParallelDesc& parallel_desc); Maybe FillOpParallelDesc(std::shared_ptr parallel_desc); Maybe GetOpParallelDesc() const; Maybe InferParallelSignatureIf(); Maybe GetParallelDesc4BnInOp(const std::string& bn) const; Maybe FillLogicalInBlobDesc( const std::function& BlobDesc4BnInOp); Maybe FillLogicalInBlobDesc( const std::function& BlobDesc4BnInOp); Maybe FillLogicalInBlobDesc( const std::function(int32_t)>& BlobDesc4InputIndex); Maybe GetLogicalBlobDesc4Ibn(const std::string& ibn) const; Maybe GetLogicalBlobDesc4InputIndex(int32_t index) const; Maybe FillLogicalOutBlobDesc( const std::function& BlobDesc4BnInOp); Maybe FillLogicalOutBlobDesc( const std::function& BlobDesc4BnInOp); Maybe GetLogicalBlobDesc4Obn(const std::string& obn) const; Maybe GetLogicalBlobDesc4OutputIndex(int32_t index) const; Maybe GetLogicalBlobDescPtr4OutputIndex(int32_t index) const; Maybe GetLogicalBlobDesc4BnInOp(const std::string& bn) const; Maybe InferLogicalOutBlobDescsIf(); virtual Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const = 0; // Read: shape of input_blobs // Write: shape of output_blobs Maybe InferBlobDescsIf( const std::function& GetBlobDesc4BnInOp, const ParallelContext*, const JobDesc* job_desc) const; Maybe InferOutBlobDescsIf(std::function GetBlobDesc4BnInOp, const ParallelContext*) const; Maybe InferInternalBlobDescsIf( const std::function& 
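// Illustrative sketch of how the Fill*/Infer* entry points declared here are typically
// driven, mirroring ConstructAndInferOp in operator.cpp (the "..." arguments stand for the
// hint callbacks shown in the declarations; only the call order matters here):
//
//   const auto& op = JUST(ConstructOp(op_conf));
//   JUST(op->FillOpParallelDesc(parallel_desc));         // 1. placement
//   JUST(op->FillLogicalInBlobDesc(ConstBlobDesc4Ibn));  // 2. input blob descs
//   JUST(op->InferLocalSignatureIf(...));                // 3. local signature
//   JUST(op->InferSbpSignature(...));                    // 4. SBP signature
//   JUST(op->InferLogicalOutBlobDescsIf());              // 5. output blob descs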
GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const; Maybe InferInplaceObn2IbnIf( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const; Maybe FillInputBlobTimeShape( const std::function(int32_t)>& GetTimeShape4InputIndex); Maybe InferOpTimeShapeIf(); virtual Maybe InferOpTimeShape( const std::function(const std::string&)>& GetTimeShape4BnInOp, std::shared_ptr* time_shape) const; Maybe GetOpTimeShape() const; Maybe GetInputBlobFastestTimeShape() const; Maybe GetInputOutputFastestTimeShape() const; Maybe InferSbpSignature(SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const HashMap& ibn2sbp_infer_hint) const; Maybe FillSbpSignature(const SbpSignature& sbp_signature); Maybe FillNdSbpSignature(const NdSbpSignature& signature); Maybe InferSbpSignatureIf( const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, const std::function(const std::string&)>& SbpInferHint4Ibn, const ParallelDesc& parallel_desc); Maybe InferNdSbpSignatureIf( const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn); // The function that how to dump nd_sbp for op_conf using DumpNdSbpSignatureForOpConfFn = std::function(const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf)>; virtual DumpNdSbpSignatureForOpConfFn GetDumpNdSbpSignatureForOpConfFn() const; // Infer blob's LocalSignature Maybe InferLocalSignatureIf( std::function(const std::string&)> LocalSigInferHint4Ibn, bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc); void GenKernelConf(const std::function& GetBlobDesc4BnInOp, const ParallelContext*, KernelConf*) const; const InputBlobModifier& InputBlobModifier4Ibn(const std::string& ibn) const; const OutputBlobModifier& OutputBlobModifier4Obn(const std::string& obn) const; Maybe SbpParallel4BnInOp(const std::string& bn_in_op) const; Maybe NdSbp4BnInOp(const std::string& bn_in_op) const; Maybe OptLocalParallel4BnInOp(const std::string& bn_in_op) const; Maybe GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const; virtual Maybe EnumerateNdSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const; virtual Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const; virtual Maybe GetComputeComplexity( NdSbpSignature* sbp_signature, std::function logical_blob_desc4bn, const ParallelDesc& parallel_desc) const; // TODO: Will infer blob shape before inferring sbp and delete the check_output later Maybe GetValidNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list, bool check_output) const; void ForEachBnInOp(const std::function&) const; virtual Symbol GetOpConfWithoutOpNameAndLbn() const; std::shared_ptr GetOpAttributeWithoutOpNameAndLbn() const; Maybe sbp_signature() const; Maybe nd_sbp_signature() const; BlobLastUsedSignature* mut_blob_last_used_signature(); BlobBackwardUsedSignature* mut_blob_backward_used_signature(); Maybe GetInputIndex(const std::string& ibn) const; Maybe GetOutputIndex(const std::string& obn) const; Maybe GetOutputIndex(const LogicalBlobId& lbi) const; Maybe 
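// Illustrative sketch of what the SBP-related queries above describe. An SbpParallel is one
// of S(axis) (split), B (broadcast) or P (partial-sum), and GetPhysicalShape in operator.cpp
// maps a logical shape to the per-rank physical shape accordingly (numbers are made up):
//
//   logical (8, 6), S(0), 4 ranks -> physical (2, 6) on every rank
//   logical (8, 6), B,    4 ranks -> physical (8, 6) on every rank (replicated)
//   logical (8, 6), P,    4 ranks -> physical (8, 6) on every rank, each holding a
//                                    partial sum of the logical tensor
//   logical (7, 6), S(0), 4 ranks -> uneven split decided by BalancedSplitter, here
//                                    (2, 6) on the first three ranks and (1, 6) on the last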
ToOpAttribute(OpAttribute* op_attribute) const; protected: Maybe FillBlobParallelDesc( const std::function(const std::string&)>& ParallelDesc4Bn); virtual Maybe InferBlobParallelDesc(); virtual Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const; virtual Maybe InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const; virtual Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const { return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list); } virtual Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return GetSbpSignatures(sbp_sig_list); } virtual Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const; virtual Maybe InferNdSbpSignature( NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const; virtual Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { OF_UNIMPLEMENTED() << " GetSbpSignatures unimplemented, op name: " << op_name(); } virtual Maybe InferLocalSignature( std::function(const std::string&)> LocalSigInferHint4Ibn, bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc); virtual Maybe InferInplaceObn2Ibn( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const; virtual void VirtualGenKernelConf( std::function GetBlobDesc4BnInOp, const ParallelContext*, KernelConf*) const; virtual void AddLbi2OutputIndex(const LogicalBlobId& lbi, int32_t output_index); virtual LogicalBlobId lbi4ibn(const std::string& input_bn) const; virtual LogicalBlobId lbi4obn(const std::string& output_bn) const; // enroll data blobs void EnrollTmpBn(const std::string& dtbn); void EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num, bool has_diff); void EnrollRepeatedInputBn(const std::string& ibn_prefix, bool has_diff); void EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num); void EnrollRepeatedInputBn(const std::string& ibn_prefix); void EnrollRepeatedOutputBn(const std::string& obn_prefix, int32_t num, bool has_diff); void EnrollRepeatedOutputBn(const std::string& obn_prefix, bool has_diff); void EnrollRepeatedOutputBn(const std::string& obn_prefix, int32_t num); void EnrollRepeatedOutputBn(const std::string& obn_prefix); void EnrollRepeatedOutputBnWithSetter( const std::string& obn_prefix, int32_t num, bool has_diff, const std::function& ModifierSetter); void EnrollRepeatedOutputBnWithSetter( const std::string& obn_prefix, bool has_diff, const std::function& ModifierSetter); void EnrollRepeatedOutputBnWithSetter( const std::string& obn_prefix, int32_t num, const std::function& ModifierSetter); void EnrollRepeatedOutputBnWithSetter( const std::string& obn_prefix, const std::function& ModifierSetter); InputBlobModifier* EnrollInputBn(const std::string& ibn, bool has_diff); InputBlobModifier* EnrollInputBn(const std::string& ibn) { return EnrollInputBn(ibn, true); } OutputBlobModifier* EnrollOutputBn(const std::string& obn, bool has_diff); OutputBlobModifier* EnrollOutputBn(const std::string& 
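// Illustrative sketch of how a concrete operator overrides the protected GetSbpSignatures
// hook above ("ElemwiseOp" is hypothetical; the builder pattern matches OutputOp and
// ReturnOp later in this directory):
//
//   Maybe<void> ElemwiseOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
//     SbpSignature* sbp = sbp_sig_list->mutable_sbp_signature()->Add();
//     SbpSignatureBuilder().Split("in", 0).Split("out", 0).Build(sbp);
//     return Maybe<void>::Ok();
//   }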
obn) { return EnrollOutputBn(obn, true); } InputBlobModifier* MutInputBlobModifier4Ibn(const std::string& ibn); OutputBlobModifier* MutOutputBlobModifier4Obn(const std::string& obn); OptLocalParallel* MutOptLocalParallel(const std::string& bn_in_op); private: enum BlobNameTag { kInputBlobName, kOutputBlobName, }; Maybe FilterAndCheckValidSbpSignatureListByLogicalShape( const SbpSignatureList& total_sbp_sig_list, const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, SbpSignatureList* valid_sbp_sig_list) const; // TODO(wyg): 1d and nd sbp use this function to filter and check Maybe FilterNdSbpSignatureListByLogicalShape( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list, bool check_output) const; Maybe GreedilyFindMinCopyCostNdSbp( NdSbpSignature* nd_sbp_signature, const std::function(const std::string&)>& NdSbpInferHint4Ibn, const std::vector& nd_sbp_sig_list) const; LogicalBlobId tbn2lbi(const std::string& data_tmp_bn) const; std::string Bn2ConfName(const std::string& bn) const; PbMap* mut_bn_in_op2lbi() { return arg_signature_.mutable_bn_in_op2lbi(); } std::shared_ptr op_conf_; std::shared_ptr op_parallel_desc_; std::unique_ptr>> bn2parallel_desc_; std::unique_ptr>> input_index2logical_blob_desc_; std::unique_ptr>> output_index2logical_blob_desc_; std::unique_ptr>> input_index2time_shape_; std::shared_ptr input_blob_fastest_time_shape_; std::shared_ptr input_output_fastest_time_shape_; std::shared_ptr op_time_shape_; std::shared_ptr sbp_signature_; std::shared_ptr nd_sbp_signature_; PbRpf input_bns_; PbRpf output_bns_; PbRpf tmp_bns_; PbRpf input_output_bns_; DeviceType device_type_; ArgSignature arg_signature_; ArgModifierSignature arg_modifier_signature_; std::unique_ptr blob_last_used_signature_; std::unique_ptr blob_backward_used_signature_; std::unique_ptr local_signature_; HashMap> bn2index_pair_; HashMap lbi2output_index_; }; std::string GenRepeatedBn(const std::string& bn_prefix, int32_t idx); std::pair GenUnRepeatedBn(const std::string& bn); bool IsCpuOnly(const std::string& user_op_type_name); bool IsCpuOnly(const OperatorConf& op_conf); struct OnlyCpuSupportPredicator { OnlyCpuSupportPredicator(bool only_cpu) : only_cpu_(only_cpu) {} operator bool() { return only_cpu_; } private: bool only_cpu_; }; struct RuntimeRegstNum4OpSameOutputBlob final { RuntimeRegstNum4OpSameOutputBlob(size_t num) : num_(num) {} RuntimeRegstNum4OpSameOutputBlob(std::function get_num) : get_num_(new std::function(get_num)) {} operator size_t() { if (!get_num_) { return num_; } return (*this->get_num_)(); } private: size_t num_; std::unique_ptr> get_num_; }; #define REGISTER_OP(op_type_case, OpType) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, OnlyCpuSupportPredicator, \ ([] { return new OnlyCpuSupportPredicator(false); })); \ REGISTER_CLASS_WITH_ARGS(int32_t, op_type_case, Operator, OpType, const OperatorConf&) #define REGISTER_CPU_OP(op_type_case, OpType) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, OnlyCpuSupportPredicator, \ ([] { return new OnlyCpuSupportPredicator(true); })); \ REGISTER_CLASS_WITH_ARGS(int32_t, op_type_case, Operator, OpType, const OperatorConf&) #define REGISTER_OP_CREATOR(op_type_case, creator) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, OnlyCpuSupportPredicator, \ ([] { return new OnlyCpuSupportPredicator(false); })); \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, Operator, creator, const OperatorConf&) #define 
REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(op_type_case, num) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, RuntimeRegstNum4OpSameOutputBlob, \ ([] { return new RuntimeRegstNum4OpSameOutputBlob(num); })) #define REGISTER_USER_OP_SAME_OUTPUT_BLOB_REGST_NUM(op_type_name, num) \ REGISTER_CLASS_CREATOR(std::string, op_type_name, RuntimeRegstNum4OpSameOutputBlob, \ ([] { return new RuntimeRegstNum4OpSameOutputBlob(num); })) #define REGISTER_USER_OP_SAME_OUTPUT_BLOB_REGST_NUM_WITH_FUNC(op_type_name, func) \ REGISTER_CLASS_CREATOR(std::string, op_type_name, RuntimeRegstNum4OpSameOutputBlob, \ ([] { return new RuntimeRegstNum4OpSameOutputBlob(func); })); struct IsInterfaceOpConf4OpTypeCase final {}; #define REGISTER_INTERFACE_OP(op_type_case) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, IsInterfaceOpConf4OpTypeCase, \ ([] { return new IsInterfaceOpConf4OpTypeCase(); })) struct DisableInputBoxingGroup final {}; #define REGISTER_DISABLE_INPUT_BOXING_GROUP(op_type_case) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, DisableInputBoxingGroup, \ ([] { return new DisableInputBoxingGroup(); })) struct IsTickTockOpTypeCase final {}; #define REGISTER_TICK_TOCK_OP(op_type_case) \ REGISTER_CLASS_CREATOR(int32_t, op_type_case, IsTickTockOpTypeCase, \ ([] { return new IsTickTockOpTypeCase; })) Maybe ConstructOp(const OperatorConf& op_conf); Maybe ConstructOp(const OperatorConf& op_conf, DeviceType device_type); inline OpBlobArg GenOpBlobArg(const std::string& op_name, const std::string& bn_in_op) { OpBlobArg oba; oba.set_op_name(op_name); oba.set_bn_in_op(bn_in_op); return oba; } LogicalBlobId GenLogicalBlobId(const std::string& lbn); inline std::string GenLogicalBlobName(const std::string& op_name, const std::string& blob_name) { return op_name + "/" + blob_name; } inline std::string GenLogicalBlobName(const LogicalBlobId& lbi) { CHECK_EQ(lbi.has_op_name(), true); CHECK_EQ(lbi.has_blob_name(), true); return GenLogicalBlobName(lbi.op_name(), lbi.blob_name()); } Maybe GetSbpParallelInLbnOrNothing(const std::string& lbn, SbpParallel* sbp); Maybe ParseDisableBoxingFlag(const std::string& lbn_with_hint, bool* disable_boxing); std::string GetInputLbnInOpCustomizedConf(const OperatorConf& op_conf, const std::string& fd_name_may_have_idx); // return old value std::string ReplaceInputLbnInOpCustomizedConf(OperatorConf* op_conf, const std::string& fd_name_may_have_idx, const std::string& new_val); bool operator==(const OperatorConf& lhs, const OperatorConf& rhs); Maybe ConstructAndInferOp(const OperatorConf& op_conf, const OpNodeSignature& upstream_signature, const Scope& scope); Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const ParallelContext& parallel_ctx); Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, int64_t parallel_id); Maybe GetLogicalShape(const Shape& physical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc); } // namespace oneflow namespace std { template<> struct hash final { size_t operator()(const oneflow::OperatorConf& op_conf) const { std::string serialized; op_conf.SerializeToString(&serialized); return std::hash()(serialized); } }; } // namespace std #endif // ONEFLOW_CORE_OPERATOR_OPERATOR_H_ ================================================ FILE: oneflow/core/operator/operator_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/operator_util.h" #include "oneflow/core/framework/user_op_conf.h" namespace oneflow { size_t DhwOffset(const std::string& data_format) { if (data_format == "channels_first") { return 2; } else if (data_format == "channels_last") { return 1; } else { UNIMPLEMENTED(); } } std::vector Get3DVecInOpConf(const PbRf& field_vals, int32_t NDims) { std::vector vec; vec.reserve(3); FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - NDims); if (index < 0) { vec.emplace_back(1); } else { vec.emplace_back(field_vals.Get(index)); } } return vec; } int64_t GetInDim(const ShapeView& shape, const std::string& data_format, int32_t dim, int32_t NDims) { int64_t offset = 0; if (data_format == "channels_last") { offset = 1; } else if (data_format == "channels_first") { offset = 2; } else { UNIMPLEMENTED(); } int64_t index = offset + static_cast(dim) - static_cast(3 - NDims); if (index < offset) { return 1; } else { return shape.At(index); } } void GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, const std::string& padding_type, bool ceil_mode, int64_t* output_size, int32_t* padding_before, int32_t* padding_after) { CHECK_GT(stride, 0); CHECK_GE(dilation_rate, 1); int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; if (padding_type == "customized") { if (output_size) { *output_size = (input_size + *padding_before + *padding_after - effective_filter_size + stride + (ceil_mode ? 
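// Worked example with made-up numbers: input_size = 10, filter_size = 3, dilation_rate = 1,
// stride = 2, *padding_before = *padding_after = 1.
// effective_filter_size = (3 - 1) * 1 + 1 = 3, so with ceil_mode == false
// *output_size = (10 + 1 + 1 - 3 + 2 + 0) / 2 = 5, and with ceil_mode == true
// *output_size = (10 + 1 + 1 - 3 + 2 + 1) / 2 = 6.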
stride - 1 : 0)) / stride; CHECK_GE((*output_size), 0); } } else if (padding_type == "valid") { if (output_size) { *output_size = (input_size - effective_filter_size + stride) / stride; } if (padding_before) { *padding_before = 0; } if (padding_after) { *padding_after = 0; } } else { int64_t tmp_output_size = (input_size + stride - 1) / stride; if (output_size) { *output_size = tmp_output_size; } const int32_t padding_needed = std::max( 0, static_cast((tmp_output_size - 1) * stride + effective_filter_size - input_size)); const int32_t padding_small = padding_needed / 2; const int32_t padding_large = padding_needed - padding_needed / 2; if (padding_type == "same_upper") { if (padding_before) { *padding_before = padding_small; } if (padding_after) { *padding_after = padding_large; } } else if (padding_type == "same_lower") { if (padding_before) { *padding_before = padding_large; } if (padding_after) { *padding_after = padding_small; } } else { UNIMPLEMENTED(); } } if (output_size) { CHECK_GE((*output_size), 0); } } void GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, const std::string& padding_type, int64_t* output_size, int32_t* padding_before, int32_t* padding_after) { CHECK_GT(stride, 0); CHECK_GE(dilation_rate, 1); int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; if (padding_type == "valid") { if (output_size) { *output_size = (input_size - effective_filter_size + stride) / stride; } if (padding_before) { *padding_before = 0; } if (padding_after) { *padding_after = 0; } } else if (padding_type == "same") { int64_t tmp_output_size = (input_size + stride - 1) / stride; if (output_size) { *output_size = tmp_output_size; } const int32_t padding_needed = std::max( 0, static_cast((tmp_output_size - 1) * stride + effective_filter_size - input_size)); // For odd values of total padding, add more padding at the 'right' // side of the given dimension. 
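// Worked example with made-up numbers: input_size = 10, filter_size = 3, dilation_rate = 1,
// stride = 2, padding_type = "same".
//   tmp_output_size = (10 + 2 - 1) / 2 = 5
//   padding_needed  = max(0, (5 - 1) * 2 + 3 - 10) = 1
//   padding_before  = 1 / 2 = 0, padding_after = 1 - 0 = 1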
if (padding_before) { *padding_before = padding_needed / 2; } if (padding_after) { *padding_after = padding_needed - padding_needed / 2; } } else { UNIMPLEMENTED(); } if (output_size) { CHECK_GE((*output_size), 0); } } void GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t stride, const std::string& padding_type, int64_t* output_size, int32_t* padding_before, int32_t* padding_after) { GetWindowedOutputSize(input_size, filter_size, 1, stride, padding_type, output_size, padding_before, padding_after); } void GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t stride, const std::string& padding_type, int64_t* output_size, int32_t* padding_size) { GetWindowedOutputSize(input_size, filter_size, stride, padding_type, output_size, padding_size, nullptr); } void Get3DOutputSize(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::string& padding_type, DimVector* out, std::vector* padding) { Get3DOutputSize(in, pool_size, strides, padding_type, out, padding, nullptr, nullptr); } void Get3DOutputSize(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::string& padding_type, DimVector* out, std::vector* padding_before, std::vector* padding_after) { Get3DOutputSize(in, pool_size, strides, padding_type, out, padding_before, padding_after, nullptr); } void Get3DOutputSize(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::string& padding_type, DimVector* out, std::vector* padding_before, std::vector* padding_after, std::vector* dilation_rate) { CHECK(out); out->clear(); out->resize(3); if (padding_before) { padding_before->clear(); padding_before->resize(3); } if (padding_after) { padding_after->clear(); padding_after->resize(3); } FOR_RANGE(size_t, i, 0, 3) { int64_t* out_ptr = &(*out).at(i); int32_t* padding_before_ptr = padding_before ? (&(*padding_before).at(i)) : nullptr; int32_t* padding_after_ptr = padding_after ? 
(&(*padding_after).at(i)) : nullptr; if (dilation_rate) { GetWindowedOutputSize(in.at(i), pool_size.at(i), dilation_rate->at(i), strides.at(i), padding_type, out_ptr, padding_before_ptr, padding_after_ptr); } else { GetWindowedOutputSize(in.at(i), pool_size.at(i), strides.at(i), padding_type, out_ptr, padding_before_ptr, padding_after_ptr); } } } void Get3DOutputSize(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::string& padding_type, const bool ceil_mode, std::vector* dilation_rate, DimVector* out, std::vector* padding_before, std::vector* padding_after) { CHECK(out); out->clear(); out->resize(3); FOR_RANGE(size_t, i, 0, 3) { int64_t* out_ptr = &(*out).at(i); if (dilation_rate) { GetWindowedOutputSize(in.at(i), pool_size.at(i), dilation_rate->at(i), strides.at(i), padding_type, ceil_mode, out_ptr, &(padding_before->at(i)), &(padding_after->at(i))); } else { GetWindowedOutputSize(in.at(i), pool_size.at(i), 1, strides.at(i), padding_type, ceil_mode, out_ptr, &(padding_before->at(i)), &(padding_after->at(i))); } } } void GetConvOutAndPad(const ShapeView& in_blob_shape, const PbMessage& conv_conf, DimVector* out, std::vector* pad_small_side, std::vector* pad_large_side) { int32_t opkernel_dim = in_blob_shape.NumAxes() - 2; if (out) { out->assign(opkernel_dim, 0); } if (pad_small_side) { pad_small_side->assign(opkernel_dim, 0); } if (pad_large_side) { pad_large_side->assign(opkernel_dim, 0); } const auto& data_format = GetValFromPbMessage(conv_conf, "data_format"); const std::string& padding = GetValFromPbMessage(conv_conf, "padding"); const PbRf& dilation_rate = GetPbRfFromPbMessage(conv_conf, "dilation_rate"); const auto& strides = GetPbRfFromPbMessage(conv_conf, "strides"); const PbRf& kernel_size = GetPbRfFromPbMessage(conv_conf, "kernel_size"); FOR_RANGE(int32_t, i, 0, opkernel_dim) { GetWindowedOutputSize(in_blob_shape.At(DhwOffset(data_format) + i), kernel_size.Get(i), dilation_rate.Get(i), strides.Get(i), padding, out ? &(out->at(i)) : nullptr, pad_small_side ? &(pad_small_side->at(i)) : nullptr, pad_large_side ? &(pad_large_side->at(i)) : nullptr); } } void GetConvOutAndPad(const ShapeView& in_blob_shape, const user_op::UserOpConfWrapper& conv_conf, DimVector* out, std::vector* pad_small_side, std::vector* pad_large_side) { int32_t opkernel_dim = in_blob_shape.NumAxes() - 2; if (out) { out->assign(opkernel_dim, 0); } if (pad_small_side) { pad_small_side->assign(opkernel_dim, 0); } if (pad_large_side) { pad_large_side->assign(opkernel_dim, 0); } const auto& data_format = conv_conf.attr("data_format"); const auto& padding = conv_conf.attr("padding"); const auto& strides = conv_conf.attr>("strides"); const auto& dilation_rate = conv_conf.attr>("dilation_rate"); const auto& kernel_size = conv_conf.attr>("kernel_size"); FOR_RANGE(int32_t, i, 0, opkernel_dim) { GetWindowedOutputSize(in_blob_shape.At(DhwOffset(data_format) + i), kernel_size.at(i), dilation_rate.at(i), strides.at(i), padding, out ? &(out->at(i)) : nullptr, pad_small_side ? &(pad_small_side->at(i)) : nullptr, pad_large_side ? &(pad_large_side->at(i)) : nullptr); } } } // namespace oneflow ================================================ FILE: oneflow/core/operator/operator_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_OPERATOR_UTIL_H_ #define ONEFLOW_CORE_OPERATOR_OPERATOR_UTIL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { namespace user_op { class UserOpConfWrapper; } size_t DhwOffset(const std::string& data_format); std::vector Get3DVecInOpConf(const PbRf& field_vals, int32_t NDims); int64_t GetInDim(const ShapeView& shape, const std::string& data_format, int32_t dim, int32_t NDim); void GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t stride, const std::string& padding_type, int64_t* output_size, int32_t* padding_before, int32_t* padding_after); void GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t stride, const std::string& padding_type, int64_t* output_size, int32_t* padding_size); void GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, const std::string& padding_type, int64_t* output_size, int32_t* padding_before, int32_t* padding_after); void GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, const std::string& padding_type, bool ceil_mode, int64_t* output_size, int32_t* padding_before, int32_t* padding_after); void Get3DOutputSize(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::string& padding_type, DimVector* out, std::vector* padding); void Get3DOutputSize(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::string& padding_type, DimVector* out, std::vector* padding_before, std::vector* padding_after); void Get3DOutputSize(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::string& padding_type, DimVector* out, std::vector* padding_before, std::vector* padding_after, std::vector* dilation_rate); void Get3DOutputSize(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::string& padding_type, const bool ceil_mode, std::vector* dilation_rate, DimVector* out, std::vector* padding_before, std::vector* padding_after); void GetConvOutAndPad(const ShapeView& in_blob_shape, const PbMessage& conv_conf, DimVector* out, std::vector* pad_small_side, std::vector* pad_large_side); void GetConvOutAndPad(const ShapeView& in_blob_shape, const user_op::UserOpConfWrapper& conv_conf, DimVector* out, std::vector* pad_small_side, std::vector* pad_large_side); } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_OPERATOR_UTIL_H_ ================================================ FILE: oneflow/core/operator/output_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/interface_op_util.h" #include "oneflow/core/operator/output_op.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/job/env_desc.h" namespace oneflow { namespace { Maybe InferOutputOpNdSbpSignature(NdSbpSignature* nd_sbp_signature, const ParallelDesc& parallel_desc, const OperatorConf& op_conf) { const InterfaceBlobConf& blob_conf = op_conf.output_conf().blob_conf(); NdSbp& in_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())["in"]; NdSbp& out_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())["out"]; JUST(InterfaceOpUtil::ParseNdSbpFromBlobConf(blob_conf, parallel_desc, &in_nd_sbp)); JUST(InterfaceOpUtil::ParseNdSbpFromBlobConf(blob_conf, parallel_desc, &out_nd_sbp)); return Maybe::Ok(); } } // anonymous namespace Maybe OutputOp::InitFromOpConf() { CHECK(op_conf().has_output_conf()); EnrollInputBn("in"); EnrollOutputBn("out")->set_is_mutable(true); return Maybe::Ok(); } Maybe OutputOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); *out_blob_desc = *BlobDesc4BnInOp("in"); return Maybe::Ok(); } Maybe OutputOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); // NOTE(chengcheng): // In multi-client, in blob shape maybe changed and NOT equal with output_conf.blob_conf, // and the output op actually is return op (used in single-client) with NO blob conf. 
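// i.e. the "out" interface blob simply mirrors whatever shape and data type the incoming
// blob has at runtime instead of being re-checked against the statically configured blob_conf.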
*out_blob_desc = *in_blob_desc; return Maybe::Ok(); } Maybe OutputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { SbpSignature* sbp = sbp_sig_list->mutable_sbp_signature()->Add(); CHECK_EQ_OR_RETURN(JUST(GetOpParallelDesc())->hierarchy()->NumAxes(), 1) << "Only support 1d sbp now."; // Get sbp from BlobConf const InterfaceBlobConf& blob_conf = op_conf().output_conf().blob_conf(); // TODO: make sure blob_conf must set nd_sbp CHECK_OR_RETURN(blob_conf.has_nd_sbp()); const SbpParallel& sbp_parallel = SbpParallel(blob_conf.nd_sbp().sbp_parallel(0)); if (sbp_parallel.has_broadcast_parallel()) { SbpSignatureBuilder().Broadcast("in").Broadcast("out").Build(sbp); } else if (sbp_parallel.has_partial_sum_parallel()) { SbpSignatureBuilder().PartialSum("in").PartialSum("out").Build(sbp); } else if (sbp_parallel.has_split_parallel()) { int64_t split_axis = sbp_parallel.split_parallel().axis(); SbpSignatureBuilder().Split("in", split_axis).Split("out", split_axis).Build(sbp); } else { UNIMPLEMENTED_THEN_RETURN(); } return Maybe::Ok(); } Maybe OutputOp::GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { NdSbpSignature nd_sbp_signature; JUST(InferOutputOpNdSbpSignature(&nd_sbp_signature, parallel_desc, op_conf())); nd_sbp_sig_list->emplace_back(nd_sbp_signature); return Maybe::Ok(); } Maybe OutputOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { JUST(InterfaceOpUtil::GetOutputLikeOpSbpSignature(op_conf().output_conf().blob_conf(), input_bns(), output_bns(), sbp_signature)); return Maybe::Ok(); } Maybe OutputOp::InferNdSbpSignature( NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const { JUST(InferOutputOpNdSbpSignature(nd_sbp_signature, parallel_desc, op_conf())); return Maybe::Ok(); } Symbol OutputOp::GetOpConfWithoutOpNameAndLbn() const { return SymbolOf(this->op_conf()); } REGISTER_OP(OperatorConf::kOutputConf, OutputOp); REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kOutputConf, 1); REGISTER_INTERFACE_OP(OperatorConf::kOutputConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/output_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_OPERATOR_OUTPUT_OP_H_ #define ONEFLOW_CORE_OPERATOR_OUTPUT_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class OutputOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(OutputOp); OutputOp() = default; ~OutputOp() override = default; Maybe InitFromOpConf() override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; Symbol GetOpConfWithoutOpNameAndLbn() const override; Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const override; Maybe InferNdSbpSignature(NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_OUTPUT_OP_H_ ================================================ FILE: oneflow/core/operator/reduce_sbp_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/reduce_sbp_util.h" namespace oneflow { bool ReduceSbpUtil::IsReduceAxisSplitted(const SbpInferHint& ibn_hint, const HashSet& reduced_axes) { if (ibn_hint.sbp_parallel().has_split_parallel() == false) { return false; } if (reduced_axes.empty()) { return true; } return reduced_axes.find(ibn_hint.sbp_parallel().split_parallel().axis()) != reduced_axes.end(); } std::function ReduceSbpUtil::MakePredicatorIsReducedAxis(const PbRf& axes, int32_t num_axes) { HashSet axes_set = {axes.begin(), axes.end()}; return MakePredicatorIsReducedAxis(axes_set, num_axes); } std::function ReduceSbpUtil::MakePredicatorIsReducedAxis( const HashSet& axes, int32_t num_axes) { auto axis_set = std::make_shared>(axes); return [axis_set](int32_t axis) -> bool { return axis_set->find(axis) != axis_set->end(); }; } void ReduceSbpUtil::GetRegularAxes(int64_t num_axes, const std::vector& reduce_axes, HashSet* axes) { for (auto axis : reduce_axes) { axes->insert(ShiftNegativeAxis(axis, num_axes)); } } } // namespace oneflow ================================================ FILE: oneflow/core/operator/reduce_sbp_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_REDUCE_SBP_UTIL_H_ #define ONEFLOW_CORE_OPERATOR_REDUCE_SBP_UTIL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/job/sbp_infer_hint.h" #include "oneflow/core/operator/operator.h" namespace oneflow { struct ReduceSbpUtil final { static bool IsReduceAxisSplitted(const SbpInferHint& ibn_hint, const HashSet& reduced_axes); static std::function MakePredicatorIsReducedAxis(const HashSet& axes, int32_t num_axes); static std::function MakePredicatorIsReducedAxis(const PbRf& axes, int32_t num_axes); static void GetRegularAxes(int64_t num_axes, const std::vector& reduce_axes, HashSet* axes); }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_REDUCE_SBP_UTIL_H_ ================================================ FILE: oneflow/core/operator/reentrant_lock_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/reentrant_lock_op.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { Maybe ReentrantLockOp::InitFromOpConf() { EnrollInputBn("start", false); if (op_conf().reentrant_lock_conf().has_end()) { EnrollInputBn("end", false); } EnrollOutputBn("out", false); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { const BlobDesc* start = BlobDesc4BnInOp("start"); const DataType data_type = start->data_type(); CHECK_OR_RETURN(IsIntegralDataType(data_type)); BlobDesc* out = BlobDesc4BnInOp("out"); out->set_shape(Shape({1})); out->set_data_type(data_type); return Maybe::Ok(); } } // namespace Maybe ReentrantLockOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); return InferBlobDescs(BlobDesc4BnInOp); } Maybe ReentrantLockOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1); return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe ReentrantLockOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kReentrantLockConf, ReentrantLockOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/reentrant_lock_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_OPERATOR_REENTRANT_LOCK_OP_H_
#define ONEFLOW_CORE_OPERATOR_REENTRANT_LOCK_OP_H_

#include "oneflow/core/operator/operator.h"

namespace oneflow {

class ReentrantLockOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReentrantLockOp);
  ReentrantLockOp() = default;
  ~ReentrantLockOp() override = default;

  Maybe<void> InitFromOpConf() override;
  Maybe<void> InferLogicalOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
      const ParallelDesc& parallel_desc) const override;
  Maybe<void> InferOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
      const ParallelContext* parallel_ctx) const override;

 private:
  Maybe<void> GetSbpSignatures(
      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
      SbpSignatureList* sbp_sig_list) const override;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_OPERATOR_REENTRANT_LOCK_OP_H_


================================================
FILE: oneflow/core/operator/return_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/operator/return_op.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/operator/interface_op_util.h" namespace oneflow { Maybe ReturnOp::InitFromOpConf() { CHECK(op_conf().has_return_conf()); EnrollInputBn("in"); EnrollOutputBn("out")->set_is_mutable(true); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { *BlobDesc4BnInOp("out") = *BlobDesc4BnInOp("in"); return Maybe::Ok(); } } // namespace Maybe ReturnOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe ReturnOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe ReturnOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { const auto& in_sbp_infer_hint = *JUST(SbpInferHint4Ibn("in")); CHECK_EQ_OR_RETURN(in_sbp_infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num()); if (in_sbp_infer_hint.sbp_parallel().has_partial_sum_parallel()) { SbpSignatureBuilder().Broadcast(input_bns()).Broadcast(output_bns()).Build(sbp_signature); } else { auto* bn2sbp = sbp_signature->mutable_bn_in_op2sbp_parallel(); (*bn2sbp)["in"] = in_sbp_infer_hint.sbp_parallel(); (*bn2sbp)["out"] = in_sbp_infer_hint.sbp_parallel(); } return Maybe::Ok(); } Symbol ReturnOp::GetOpConfWithoutOpNameAndLbn() const { return SymbolOf(this->op_conf()); } REGISTER_OP(OperatorConf::kReturnConf, ReturnOp); REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kReturnConf, 1); REGISTER_INTERFACE_OP(OperatorConf::kReturnConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/return_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_OPERATOR_RETURN_OP_H_
#define ONEFLOW_CORE_OPERATOR_RETURN_OP_H_

#include "oneflow/core/operator/operator.h"

namespace oneflow {

class ReturnOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReturnOp);
  ReturnOp() = default;
  ~ReturnOp() override = default;

  Maybe<void> InitFromOpConf() override;
  Maybe<void> InferLogicalOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
      const ParallelDesc& parallel_desc) const override;
  Maybe<void> InferOutBlobDescs(
      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
      const ParallelContext* parallel_ctx) const override;

 private:
  Maybe<void> InferSbpSignature(
      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,
      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
      const ParallelDesc& parallel_desc) const override;
  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_OPERATOR_RETURN_OP_H_


================================================
FILE: oneflow/core/operator/scalar_op_base.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/operator/scalar_op_base.h"
#include "oneflow/core/job/sbp_signature_builder.h"

namespace oneflow {

Maybe<void> ScalarOpBase::InitFromOpConf() {
  EnrollInputBn("in");
  EnrollInputBn("scalar");
  EnrollOutputBn("out")->set_mutable_inplace_ibn("in");
  return Maybe<void>::Ok();
}

Maybe<void> ScalarOpBase::InferOutBlobDescs(
    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
    const ParallelContext* parallel_ctx) const {
  const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
  const BlobDesc* scalar_blob_desc = GetBlobDesc4BnInOp("scalar");
  CHECK_EQ_OR_RETURN(in_blob_desc->data_type(), scalar_blob_desc->data_type());
  CHECK_EQ_OR_RETURN(scalar_blob_desc->shape().elem_cnt(), 1);
  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
  *out_blob_desc = *in_blob_desc;
  return Maybe<void>::Ok();
}

Maybe<void> ScalarOpBase::GetSbpSignatures(
    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
    SbpSignatureList* sbp_sig_list) const {
  const Shape& in_shape = JUST(LogicalBlobDesc4Ibn("in")).shape();
  // The scalar operand is always broadcast; "in" and "out" may be split along any axis.
  FOR_RANGE(int64_t, i, 0, in_shape.NumAxes()) {
    SbpSignatureBuilder().Split("in", i).Broadcast("scalar").Split("out", i).Build(
        sbp_sig_list->mutable_sbp_signature()->Add());
  }
  JUST(VirtualGetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list));
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/core/operator/scalar_op_base.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_SCALAR_OP_BASE_H_ #define ONEFLOW_CORE_OPERATOR_SCALAR_OP_BASE_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class ScalarOpBase : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(ScalarOpBase); ScalarOpBase() = default; ~ScalarOpBase() override = default; Maybe InitFromOpConf() override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; protected: virtual Maybe VirtualGetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_SCALAR_OP_BASE_H_ ================================================ FILE: oneflow/core/operator/shape_elem_cnt_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/shape_elem_cnt_op.h" #include "oneflow/core/operator/reduce_sbp_util.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { namespace { HashSet GetInclusiveAxes(const ShapeElemCntOpConf& conf, int32_t num_axes) { HashSet ret; if (conf.has_exclude_axis_conf()) { HashSet exclude_axes(conf.exclude_axis_conf().axis().begin(), conf.exclude_axis_conf().axis().end()); FOR_RANGE(int32_t, i, 0, num_axes) { if (exclude_axes.find(i) == exclude_axes.end() && exclude_axes.find(i - num_axes) == exclude_axes.end()) { ret.insert(i); } } } else if (conf.has_include_axis_conf()) { for (int32_t axis : conf.include_axis_conf().axis()) { if (axis < 0) { axis += num_axes; } CHECK_GE(axis, 0); CHECK_LT(axis, num_axes); ret.insert(axis); } } else if (conf.has_range_axis_conf()) { TODO(); } else { UNIMPLEMENTED(); } return ret; } } // namespace Maybe ShapeElemCntOp::InitFromOpConf() { EnrollInputBn("x", false); EnrollOutputBn("y", false); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const OperatorConf& op_conf, const std::function& BlobDesc4BnInOp) { BlobDesc4BnInOp("y")->set_data_type(op_conf.shape_elem_cnt_conf().data_type()); BlobDesc4BnInOp("y")->set_shape(Shape({})); return Maybe::Ok(); } } // namespace Maybe ShapeElemCntOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(op_conf(), BlobDesc4BnInOp); } Maybe ShapeElemCntOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } void ShapeElemCntOp::VirtualGenKernelConf( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { int32_t num_axes = GetBlobDesc4BnInOp("x")->shape().NumAxes(); const HashSet& 
inclusive_axis = GetInclusiveAxes(op_conf().shape_elem_cnt_conf(), num_axes); *kernel_conf->mutable_shape_elem_cnt_conf()->mutable_axis() = {inclusive_axis.begin(), inclusive_axis.end()}; } Maybe ShapeElemCntOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { int32_t num_axes = JUST(LogicalBlobDesc4Ibn("x")).shape().NumAxes(); const auto& inclusive_axes = GetInclusiveAxes(op_conf().shape_elem_cnt_conf(), num_axes); auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(inclusive_axes, num_axes); FOR_RANGE(int64_t, i, 0, num_axes) { if (IsReducedAxis(i)) { SbpSignatureBuilder() .Split(input_bns(), i) .PartialSum(output_bns()) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); } else { SbpSignatureBuilder() .Split(input_bns(), i) .Broadcast(output_bns()) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); } } if (num_axes == 0) { SbpSignatureBuilder() .PartialSum(input_bns()) .PartialSum(output_bns()) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); } return Maybe::Ok(); } REGISTER_OP(OperatorConf::kShapeElemCntConf, ShapeElemCntOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/shape_elem_cnt_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_SHAPE_ELEM_CNT_H_ #define ONEFLOW_CORE_OPERATOR_SHAPE_ELEM_CNT_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class ShapeElemCntOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(ShapeElemCntOp); ShapeElemCntOp() = default; ~ShapeElemCntOp() override = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, const ParallelContext*, KernelConf*) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_SHAPE_ELEM_CNT_H_ ================================================ FILE: oneflow/core/operator/sink_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/sink_tick_op.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { Maybe SinkTickOp::InitFromOpConf() { CHECK(op_conf().has_sink_tick_conf()); EnrollRepeatedInputBn("tick", false); EnrollOutputBn("out", false); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { BlobDesc* blob_desc = BlobDesc4BnInOp("out"); blob_desc->set_shape(Shape({1})); blob_desc->set_data_type(DataType::kInt8); return Maybe::Ok(); } } // namespace Maybe SinkTickOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe SinkTickOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe SinkTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { SbpSignatureBuilder() .Broadcast(input_bns()) .Broadcast(output_bns()) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kSinkTickConf, SinkTickOp); REGISTER_TICK_TOCK_OP(OperatorConf::kSinkTickConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/sink_tick_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_SINK_TICK_OP_H_ #define ONEFLOW_CORE_OPERATOR_SINK_TICK_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class SinkTickOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(SinkTickOp); SinkTickOp() = default; ~SinkTickOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_SINK_TICK_OP_H_ ================================================ FILE: oneflow/core/operator/slice_boxing_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { class SliceBoxingOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingOp); SliceBoxingOp() = default; ~SliceBoxingOp() override = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { UNIMPLEMENTED_THEN_RETURN(); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; protected: virtual const SliceBoxingConf& GetCustomizedBoxingConf() const = 0; virtual void VirtualInitFromOpConf(){}; private: LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; class SliceBoxingCopyOp final : public SliceBoxingOp { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingCopyOp); SliceBoxingCopyOp() = default; ~SliceBoxingCopyOp() override = default; private: const SliceBoxingConf& GetCustomizedBoxingConf() const override { return op_conf().slice_boxing_copy_conf().slice_boxing_conf(); } Symbol GetOpConfWithoutOpNameAndLbn() const override; }; class SliceBoxingAddOp final : public SliceBoxingOp { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingAddOp); SliceBoxingAddOp() = default; ~SliceBoxingAddOp() override = default; private: const SliceBoxingConf& GetCustomizedBoxingConf() const override { return op_conf().slice_boxing_add_conf().slice_boxing_conf(); } void VirtualInitFromOpConf() override; Maybe InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override; Symbol GetOpConfWithoutOpNameAndLbn() const override; }; Maybe SliceBoxingOp::InitFromOpConf() { EnrollRepeatedInputBn("in", GetCustomizedBoxingConf().in_slice_size(), false); EnrollOutputBn("out"); VirtualInitFromOpConf(); return Maybe::Ok(); } LogicalBlobId SliceBoxingOp::lbi4ibn(const std::string& input_bn) const { return GetCustomizedBoxingConf().lbi(); } LogicalBlobId SliceBoxingOp::lbi4obn(const std::string& output_bn) const { return GetCustomizedBoxingConf().lbi(); } Maybe SliceBoxingOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const SliceBoxingConf& slice_boxing_conf = GetCustomizedBoxingConf(); const PbRpf& in_slice_proto = slice_boxing_conf.in_slice(); const TensorSliceViewProto& out_slice_proto = slice_boxing_conf.out_slice(); const BlobDesc* in_0 = GetBlobDesc4BnInOp(GenRepeatedBn("in", 0)); const DataType data_type = in_0->data_type(); FOR_RANGE(int64_t, i, 1, input_bns().size()) { const BlobDesc* in_i = GetBlobDesc4BnInOp(GenRepeatedBn("in", i)); CHECK_EQ(in_i->data_type(), data_type); } FOR_RANGE(int64_t, i, 0, input_bns().size()) { const BlobDesc* in_i = GetBlobDesc4BnInOp(GenRepeatedBn("in", i)); const TensorSliceView in_i_slice(in_slice_proto.Get(i)); CHECK_EQ(in_i->shape().elem_cnt(), in_i_slice.shape().elem_cnt()); } const TensorSliceView out_slice(out_slice_proto); BlobDesc* out = GetBlobDesc4BnInOp("out"); out->set_data_type(data_type); if (slice_boxing_conf.has_out_shape()) { const Shape out_shape(slice_boxing_conf.out_shape()); CHECK_EQ(out_shape.elem_cnt(), out_slice.shape().elem_cnt()); out->set_shape(out_shape); } else { out->set_shape(out_slice.shape()); } return Maybe::Ok(); } Symbol SliceBoxingCopyOp::GetOpConfWithoutOpNameAndLbn() const { OperatorConf 
op_conf(this->op_conf()); op_conf.set_name("undefined-op-name"); CHECK(op_conf.has_slice_boxing_copy_conf()); auto* boxing_conf = op_conf.mutable_slice_boxing_copy_conf(); LogicalBlobId empty_logical_blob_id{}; *boxing_conf->mutable_slice_boxing_conf()->mutable_lbi() = empty_logical_blob_id; return SymbolOf(op_conf); } void SliceBoxingAddOp::VirtualInitFromOpConf() { EnrollTmpBn("buf"); } Maybe SliceBoxingAddOp::InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const { *GetBlobDesc4BnInOp("buf") = *GetBlobDesc4BnInOp("out"); return Maybe::Ok(); } Symbol SliceBoxingAddOp::GetOpConfWithoutOpNameAndLbn() const { OperatorConf op_conf(this->op_conf()); op_conf.set_name("undefined-op-name"); CHECK(op_conf.has_slice_boxing_add_conf()); auto* boxing_conf = op_conf.mutable_slice_boxing_add_conf(); LogicalBlobId empty_logical_blob_id{}; *boxing_conf->mutable_slice_boxing_conf()->mutable_lbi() = empty_logical_blob_id; return SymbolOf(op_conf); } REGISTER_OP(OperatorConf::kSliceBoxingCopyConf, SliceBoxingCopyOp); REGISTER_OP(OperatorConf::kSliceBoxingAddConf, SliceBoxingAddOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/source_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/source_tick_op.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { Maybe SourceTickOp::InitFromOpConf() { CHECK(op_conf().has_source_tick_conf()); CHECK(op_conf().ctrl_in_op_name().empty()); EnrollOutputBn("out", false); return Maybe::Ok(); } Maybe SourceTickOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { BlobDesc* blob_desc = BlobDesc4BnInOp("out"); blob_desc->set_shape(Shape({1})); blob_desc->set_data_type(DataType::kInt8); return Maybe::Ok(); } Maybe SourceTickOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1); BlobDesc* blob_desc = GetBlobDesc4BnInOp("out"); blob_desc->set_shape(Shape({1})); blob_desc->set_data_type(DataType::kInt8); return Maybe::Ok(); } Maybe SourceTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kSourceTickConf, SourceTickOp); REGISTER_TICK_TOCK_OP(OperatorConf::kSourceTickConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/source_tick_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_SOURCE_TICK_OP_H_ #define ONEFLOW_CORE_OPERATOR_SOURCE_TICK_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class SourceTickOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(SourceTickOp); SourceTickOp() = default; ~SourceTickOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_SOURCE_TICK_OP_H_ ================================================ FILE: oneflow/core/operator/src_subset_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { class SrcSubsetTickOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(SrcSubsetTickOp); SrcSubsetTickOp() = default; ~SrcSubsetTickOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; }; Maybe SrcSubsetTickOp::InitFromOpConf() { CHECK(op_conf().has_src_subset_tick_conf()); EnrollRepeatedInputBn("in", false); EnrollOutputBn("out", false); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { BlobDesc* blob_desc = BlobDesc4BnInOp("out"); blob_desc->set_shape(Shape({1})); blob_desc->set_data_type(DataType::kInt8); return Maybe::Ok(); } } // namespace Maybe SrcSubsetTickOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe SrcSubsetTickOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe SrcSubsetTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { SbpSignatureBuilder() .Broadcast(input_bns()) .Broadcast(output_bns()) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kSrcSubsetTickConf, SrcSubsetTickOp); REGISTER_TICK_TOCK_OP(OperatorConf::kSrcSubsetTickConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/sync_dynamic_resize_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { Maybe InferBlobDescs(const OperatorConf& op_conf, const std::function& BlobDesc4BnInOp) { const SyncDynamicResizeOpConf& conf = op_conf.sync_dynamic_resize_conf(); CHECK_EQ_OR_RETURN(conf.axis(), 0); const BlobDesc* in = BlobDesc4BnInOp("in"); const BlobDesc* size = BlobDesc4BnInOp("size"); CHECK_EQ_OR_RETURN(size->shape().elem_cnt(), 1); CHECK_OR_RETURN(IsIntegralDataType(size->data_type())); BlobDesc* out = BlobDesc4BnInOp("out"); *out = *in; out->set_is_dynamic(true); return Maybe::Ok(); } } // namespace class SyncDynamicResizeOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(SyncDynamicResizeOp); SyncDynamicResizeOp() = default; ~SyncDynamicResizeOp() override = default; Maybe InitFromOpConf() override { EnrollInputBn("in"); EnrollInputBn("size", false); EnrollOutputBn("out")->set_header_infered_before_compute(false); return Maybe::Ok(); } Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override { return InferBlobDescs(op_conf(), BlobDesc4BnInOp); } Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override { return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override { return Maybe::Ok(); } void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override { kernel_conf->mutable_sync_dynamic_resize_conf()->set_size_data_type( GetBlobDesc4BnInOp("size")->data_type()); } }; REGISTER_OP(OperatorConf::kSyncDynamicResizeConf, SyncDynamicResizeOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/tick_op.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { namespace { Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { BlobDesc* blob_desc = BlobDesc4BnInOp("out"); blob_desc->set_shape(Shape({1})); blob_desc->set_data_type(DataType::kInt8); return Maybe::Ok(); } } // namespace Maybe TickOp::InitFromOpConf() { CHECK(op_conf().has_tick_conf()); EnrollRepeatedInputBn("tick", false); EnrollOutputBn("out", false); return Maybe::Ok(); } Maybe TickOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { return InferBlobDescs(BlobDesc4BnInOp); } Maybe TickOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe TickOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { return Maybe::Ok(); } REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kTickConf, 2); REGISTER_OP(OperatorConf::kTickConf, TickOp); REGISTER_TICK_TOCK_OP(OperatorConf::kTickConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/tick_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_TICK_OP_H_ #define ONEFLOW_CORE_OPERATOR_TICK_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class TickOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(TickOp); TickOp() = default; ~TickOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_TICK_OP_H_ ================================================ FILE: oneflow/core/operator/total_loss_instance_num_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/total_loss_instance_num_op.h" namespace oneflow { void TotalLossInstanceNumOp::VirtualInitFromOpConf() { CHECK(op_conf().has_total_loss_instance_num_conf()); } Maybe TotalLossInstanceNumOp::VirtualInferBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { for (const std::string& ibn : input_bns()) { CHECK_OR_RETURN(*GetBlobDesc4BnInOp(ibn) == *GetBlobDesc4BnInOp(input_bns().Get(0))); } return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kTotalLossInstanceNumConf, TotalLossInstanceNumOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/total_loss_instance_num_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_TOTAL_LOSS_INSTANCE_NUM_OP_H_ #define ONEFLOW_CORE_OPERATOR_TOTAL_LOSS_INSTANCE_NUM_OP_H_ #include "oneflow/core/operator/cwise_op.h" namespace oneflow { class TotalLossInstanceNumOp final : public CWiseOp { public: OF_DISALLOW_COPY_AND_MOVE(TotalLossInstanceNumOp); TotalLossInstanceNumOp() = default; ~TotalLossInstanceNumOp() = default; void VirtualInitFromOpConf() override; Maybe VirtualInferBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_TOTAL_LOSS_INSTANCE_NUM_OP_H_ ================================================ FILE: oneflow/core/operator/user_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/sbp_context.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/operator/user_op.h" #include "oneflow/core/framework/infer_output_blob_time_shape_fn_context.h" #include "oneflow/core/framework/infer_nd_sbp_fn_context.h" #include "oneflow/core/framework/compute_complexity_fn_context.h" #include "oneflow/core/framework/get_nd_sbp_signature_list_context.h" namespace oneflow { namespace { BlobDesc* FindValidBlobDescOfBnsInOp( std::function GetBlobDesc4BnInOp, const PbRpf& bn_in_ops) { BlobDesc* valid = nullptr; for (const std::string& bn_in_op : bn_in_ops) { BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op); if (blob_desc) { const bool is_dynamic = blob_desc->is_dynamic(); if (valid == nullptr || is_dynamic) { valid = blob_desc; if (is_dynamic) { break; } } } } return valid; } user_op::NaiveTensorDesc GenTensorDescFromBlobDesc(const BlobDesc* blob_desc) { user_op::NaiveTensorDesc tensor_desc; tensor_desc.set_shape(blob_desc->shape()); tensor_desc.set_stride(blob_desc->stride()); tensor_desc.set_data_type(blob_desc->data_type()); tensor_desc.set_memory_format(blob_desc->memory_format()); tensor_desc.set_is_dynamic(blob_desc->is_dynamic()); return tensor_desc; } } // namespace // kernel registry context used in infer functions of user op class UserOpKernelRegContext final : public user_op::KernelRegContext { public: using ArgVec = std::vector>; explicit UserOpKernelRegContext(const UserOp* user_op, std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) : user_op_conf_(user_op->op_conf()) { const auto& op_conf = user_op->op_conf(); CHECK(op_conf.has_user_conf()); device_type_ = CHECK_JUST(DeviceType4DeviceTag(op_conf.device_tag())); parallel_ctx_ = parallel_ctx; auto InitInOrOut = [&](const PbMap& arg_map, ArgVec* arg_vec) { for (auto it = arg_map.begin(); it != arg_map.end(); ++it) { for (int32_t i = 0; i < it->second.s_size(); ++i) { arg_vec->emplace_back(std::make_pair(it->first, i)); } } }; InitInOrOut(op_conf.user_conf().input(), &inputs_); InitInOrOut(op_conf.user_conf().output(), &outputs_); { #define INSERT_TO_ARG2TENSOR_DESC(prefix) \ for (const auto& bn : user_op->prefix##_bns()) { \ const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn); \ if (!blob_desc) { continue; } \ arg2tensor_desc_.emplace(GenUnRepeatedBn(bn), GenTensorDescFromBlobDesc(blob_desc)); \ } INSERT_TO_ARG2TENSOR_DESC(input) INSERT_TO_ARG2TENSOR_DESC(output) INSERT_TO_ARG2TENSOR_DESC(tmp) #undef INSERT_TO_ARG2TENSOR_DESC } } ~UserOpKernelRegContext() = default; DeviceType device_type() const override { return device_type_; } const ParallelContext& parallel_ctx() const override { return *parallel_ctx_; } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; } return &(it->second); } const ArgVec& inputs() const override { return inputs_; } const ArgVec& outputs() const override { return outputs_; } const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return user_op_conf().Attr4Name(attr_name); } private: const user_op::UserOpConfWrapper user_op_conf_; ArgVec inputs_; ArgVec outputs_; DeviceType device_type_; const ParallelContext* parallel_ctx_; HashMap, 
user_op::NaiveTensorDesc> arg2tensor_desc_; }; class UserOpInferContext final : public user_op::InferContext { public: using ArgVec = std::vector>; UserOpInferContext(const UserOp* op, const ParallelContext* parallel_ctx, const JobDesc* job_desc, const std::function& GetBlobDesc4BnInOp) : op_(op), parallel_ctx_(parallel_ctx), job_desc_(job_desc) { bn2logical_tensor_desc_.reset(new HashMap()); auto InitTensorDesc = [&](const ArgVec& arg_vec, const PbRpf& bns) { CHECK_EQ(arg_vec.size(), bns.size()); for (int32_t i = 0; i < arg_vec.size(); ++i) { const auto& bn_i = bns.Get(i); BlobDesc* blob = GetBlobDesc4BnInOp(bns.Get(i)); CHECK(blob != nullptr) << bn_i; arg2tensor_desc_.emplace(arg_vec.at(i), GenTensorDescFromBlobDesc(blob)); } }; InitTensorDesc(op->inputs(), op->input_bns()); InitTensorDesc(op->outputs(), op->output_bns()); } ~UserOpInferContext() override = default; const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { return *TensorDesc4ArgNameAndIndex(arg_name, index); } const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, int32_t index) const override { return *TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; } return &it->second; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return &(it->second); } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { const std::string bn = GenRepeatedBn(arg_name, index); const auto it = bn2logical_tensor_desc_->find(bn); if (it != bn2logical_tensor_desc_->end()) { return &it->second; } else { std::shared_ptr blob_desc = CHECK_JUST(op_->GetLogicalBlobDesc4BnInOp(bn)); bn2logical_tensor_desc_->emplace(bn, GenTensorDescFromBlobDesc(blob_desc.get())); return &(bn2logical_tensor_desc_->emplace(bn, GenTensorDescFromBlobDesc(blob_desc.get())) .first->second); } } const Shape& InputShape(const std::string& arg_name, int32_t index) const override { return Shape4ArgNameAndIndex(arg_name, index); } const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return Shape4ArgNameAndIndex(arg_name, index); } void SetOutputShape(const std::string& arg_name, int32_t index, const Shape& shape) override { SetShape4ArgNameAndIndex(arg_name, index, shape); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { thread_local static Shape non_shape; return non_shape; }; return it->second.shape(); } void SetShape4ArgNameAndIndex(const std::string& arg_name, int32_t index, const Shape& shape) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return; }; return it->second.set_shape(shape); } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { return Stride4ArgNameAndIndex(arg_name, index); } const Stride& OutputStride(const std::string& 
arg_name, int32_t index) const override { return Stride4ArgNameAndIndex(arg_name, index); } void SetOutputStride(const std::string& arg_name, int32_t index, const Stride& stride) override { return SetStride4ArgNameAndIndex(arg_name, index, stride); } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { thread_local static Stride non_stride; return non_stride; }; return it->second.stride(); } void SetStride4ArgNameAndIndex(const std::string& arg_name, int32_t index, const Stride& stride) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return; }; return it->second.set_stride(stride); } DataType InputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } DataType OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } void SetOutputDType(const std::string& arg_name, int32_t index, DataType data_type) override { return SetDtype4ArgNameAndIndex(arg_name, index, data_type); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return DataType::kInvalidDataType; }; return it->second.data_type(); } void SetDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index, DataType data_type) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return; }; return it->second.set_data_type(data_type); } MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const override { return MemoryFormat4ArgNameAndIndex(arg_name, index); } MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const override { return MemoryFormat4ArgNameAndIndex(arg_name, index); } void SetOutputMemoryFormat(const std::string& arg_name, int32_t index, MemoryFormat memory_format) override { return SetMemoryFormat4ArgNameAndIndex(arg_name, index, memory_format); } MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return MemoryFormat::kContiguous; }; return it->second.memory_format(); } void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index, MemoryFormat memory_format) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return; }; return it->second.set_memory_format(memory_format); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } void SetOutputIsDynamic(const std::string& arg_name, int32_t index, bool is_dynamic) override { return SetIsDynamic4ArgNameAndIndex(arg_name, index, is_dynamic); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return false; }; return it->second.is_dynamic(); } void SetIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index, 
bool is_dynamic) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return; }; return it->second.set_is_dynamic(is_dynamic); } const ArgVec& inputs() const override { return op_->inputs(); } const ArgVec& outputs() const override { return op_->outputs(); } const ParallelContext& parallel_ctx() const override { return *parallel_ctx_; }; const ParallelDesc& parallel_desc() const override { return *CHECK_JUST(op_->GetOpParallelDesc()); }; const JobDesc* job_desc() const override { CHECK_NOTNULL(job_desc_); return job_desc_; } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { CHECK_EQ(CHECK_JUST(op_->GetOpParallelDesc())->hierarchy()->NumAxes(), 1); const auto& bn2sbp = CHECK_JUST(op_->sbp_signature())->bn_in_op2sbp_parallel(); std::string bn = GenRepeatedBn(arg_name, index); auto it = bn2sbp.find(bn); CHECK(it != bn2sbp.end()); return it->second; } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { const auto& bn2nd_sbp = CHECK_JUST(op_->nd_sbp_signature())->bn_in_op2nd_sbp(); std::string bn = GenRepeatedBn(arg_name, index); auto it = bn2nd_sbp.find(bn); CHECK(it != bn2nd_sbp.end()); return it->second; } int64_t parallel_num() const override { return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); } const std::string& input(const std::string& arg_name, int32_t index) const override { return user_op_conf().input(arg_name, index); } const std::string& output(const std::string& arg_name, int32_t index) const override { return user_op_conf().output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const override { return user_op_conf().has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const override { return user_op_conf().has_output(arg_name, index); } int32_t input_size(const std::string& arg_name) const override { return user_op_conf().input_size(arg_name); } int32_t output_size(const std::string& arg_name) const override { return user_op_conf().output_size(arg_name); } const std::string& op_name() const override { return user_op_conf().op_name(); } const std::string& op_type_name() const override { return user_op_conf().op_type_name(); } const std::string& op_loc() const override { return op_->op_loc(); } private: const user_op::UserOpConfWrapper& user_op_conf() const { return op_->user_op_conf(); } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return user_op_conf().Attr4Name(attr_name); } const UserOp* op_; const ParallelContext* parallel_ctx_; const JobDesc* job_desc_; HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; std::unique_ptr> bn2logical_tensor_desc_; }; class UserOpSbpContext : public user_op::SbpContext { public: using ArgVec = std::vector>; UserOpSbpContext(const UserOp* op, SbpSignatureList* sbp_sig_list, std::function(const std::string&)> LogicalBlobDesc4Ibn, int32_t hierarchy_value) : op_(op), sbp_sig_list_(sbp_sig_list), hierarchy_value_(hierarchy_value) { const auto& user_op_conf = op->op_conf().user_conf(); for (auto it = user_op_conf.input().begin(); it != user_op_conf.input().end(); ++it) { const std::string& arg_name = it->first; for (int32_t i = 0; i < it->second.s_size(); ++i) { const BlobDesc* blob = &CHECK_JUST(LogicalBlobDesc4Ibn(GenRepeatedBn(arg_name, i))); arg2tensor_desc_.emplace(std::make_pair(arg_name, i), GenTensorDescFromBlobDesc(blob)); } } } ~UserOpSbpContext() override = default; 
const user_op::TensorDesc& LogicalTensorDesc4InputArgNameAndIndex( const std::string& input_arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(input_arg_name, index)); CHECK(it != arg2tensor_desc_.end()) << "Cannot find input_arg_name : " << input_arg_name << " input_arg_index : " << index; return it->second; } const ArgVec& inputs() const override { return op_->inputs(); } const ArgVec& outputs() const override { return op_->outputs(); } const user_op::UserOpConfWrapper& user_op_conf() const override { return op_->user_op_conf(); } user_op::UserOpSbpSignatureBuilder NewBuilder() override { return user_op::UserOpSbpSignatureBuilder(sbp_sig_list_); } DeviceType device_type() const override { return op_->device_type(); } int64_t parallel_num() const override { return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); } int64_t hierarchy_value() const override { return hierarchy_value_; } private: const UserOp* op_; SbpSignatureList* sbp_sig_list_; HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; int32_t hierarchy_value_; }; class UserOpInferSbpSignatureFnContext : public user_op::InferSbpSignatureFnContext { public: using ArgVec = std::vector>; UserOpInferSbpSignatureFnContext( const UserOp* op, SbpSignature* signature, const SbpSignature& sbp_signature_conf, std::function(const std::string&)> SbpInferHint4Ibn) : op_(op), signature_(signature), sbp_signature_conf_(sbp_signature_conf), sbp_infer_hint4ibn_fn_(std::move(SbpInferHint4Ibn)) { const auto& user_op_conf = op->op_conf().user_conf(); for (const auto& it : user_op_conf.input()) { const std::string& arg_name = it.first; for (int32_t i = 0; i < it.second.s_size(); ++i) { auto hint = CHECK_JUST(sbp_infer_hint4ibn_fn_(GenRepeatedBn(arg_name, i))); arg2tensor_desc_.emplace(std::make_pair(arg_name, i), GenTensorDescFromBlobDesc(&hint->logical_blob_desc())); arg2sbp_parallel_hint_.emplace(std::make_pair(arg_name, i), hint->sbp_parallel()); } } } ~UserOpInferSbpSignatureFnContext() override = default; const user_op::TensorDesc& LogicalTensorDesc4InputArgNameAndIndex( const std::string& input_arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(input_arg_name, index)); CHECK(it != arg2tensor_desc_.end()) << "Cannot find input_arg_name : " << input_arg_name << " input_arg_index : " << index; return it->second; } const ArgVec& inputs() const override { return op_->inputs(); } const ArgVec& outputs() const override { return op_->outputs(); } SbpSignature* mutable_sbp_signature() override { return signature_; } const SbpSignature& sbp_signature_conf() const override { return sbp_signature_conf_; } const SbpParallel& SbpParallelHint4InputArgNameAndIndex(const std::string& input_arg_name, int32_t index) const override { auto it = arg2sbp_parallel_hint_.find(std::make_pair(input_arg_name, index)); CHECK(it != arg2sbp_parallel_hint_.end()) << "Cannot find input_arg_name : " << input_arg_name << " input_arg_index : " << index; return it->second; } const user_op::UserOpConfWrapper& user_op_conf() const override { return op_->user_op_conf(); } DeviceType device_type() const override { return op_->device_type(); } int64_t parallel_num() const override { return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); } private: const UserOp* op_; HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; HashMap, SbpParallel> arg2sbp_parallel_hint_; SbpSignature* signature_; SbpSignature sbp_signature_conf_; std::function(const std::string&)> sbp_infer_hint4ibn_fn_; }; class 
UserOpInferOutputBlobTimeShapeFnContext : public user_op::InferOutputBlobTimeShapeFnContext { public: UserOpInferOutputBlobTimeShapeFnContext( const UserOp* op, const std::function(const std::string&)>& GetTimeShape4BnInOp, Shape* output_blob_time_shape) : op_(op), output_blob_time_shape_(output_blob_time_shape) { for (const auto& it : op->op_conf().user_conf().input()) { const std::string& arg_name = it.first; for (int32_t i = 0; i < it.second.s_size(); ++i) { std::string ibn = GenRepeatedBn(arg_name, i); arg2time_shape_.emplace(std::make_pair(arg_name, i), *CHECK_JUST(GetTimeShape4BnInOp(ibn))); } } } ~UserOpInferOutputBlobTimeShapeFnContext() override = default; const Shape& TimeShape4InputArgNameAndIndex(const std::string& arg_name, int32_t index) override { return arg2time_shape_.at(std::make_pair(arg_name, index)); } const user_op::UserOpConfWrapper& user_op_conf() const override { return op_->user_op_conf(); } Shape* mut_output_blob_time_shape() override { return output_blob_time_shape_; }; private: const UserOp* op_; HashMap, Shape> arg2time_shape_; Shape* output_blob_time_shape_; }; class UserOpInferNdSbpFnContext : public user_op::InferNdSbpFnContext { public: using ArgVec = std::vector>; UserOpInferNdSbpFnContext( const UserOp* op, NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, std::function(const std::string&)> NdSbpInferHint4Ibn) : op_(op), nd_sbp_signature_(nd_sbp_signature), nd_sbp_constraints_(nd_sbp_constraints), nd_sbp_infer_hint4ibn_fn_(std::move(NdSbpInferHint4Ibn)) { const auto& user_op_conf = op->op_conf().user_conf(); for (const auto& it : user_op_conf.input()) { const std::string& arg_name = it.first; for (int32_t i = 0; i < it.second.s_size(); ++i) { auto hint = CHECK_JUST(nd_sbp_infer_hint4ibn_fn_(GenRepeatedBn(arg_name, i))); CHECK(arg2tensor_desc_ .emplace(std::make_pair(arg_name, i), GenTensorDescFromBlobDesc(&hint->logical_blob_desc())) .second); } } } ~UserOpInferNdSbpFnContext() override = default; const user_op::TensorDesc& LogicalTensorDesc4InputArgNameAndIndex( const std::string& input_arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(input_arg_name, index)); CHECK(it != arg2tensor_desc_.end()) << "Cannot find input_arg_name : " << input_arg_name << " input_arg_index : " << index; return it->second; } const NdSbpSignature& nd_sbp_constraints() const override { return nd_sbp_constraints_; } NdSbp* NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return &(*nd_sbp_signature_->mutable_bn_in_op2nd_sbp())[GenRepeatedBn(arg_name, index)]; } const NdSbp& NdSbpHint4InputArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto hint = CHECK_JUST(nd_sbp_infer_hint4ibn_fn_(GenRepeatedBn(arg_name, index))); return hint->nd_sbp(); } const user_op::UserOpConfWrapper& user_op_conf() const override { return op_->user_op_conf(); } int64_t parallel_num() const override { return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); } const Shape& parallel_hierarchy() override { return *(CHECK_JUST(op_->GetOpParallelDesc())->hierarchy()); } const ArgVec& inputs() const override { return op_->inputs(); } const ArgVec& outputs() const override { return op_->outputs(); } private: const UserOp* op_; HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; NdSbpSignature* nd_sbp_signature_; NdSbpSignature nd_sbp_constraints_; std::function(const std::string&)> nd_sbp_infer_hint4ibn_fn_; }; // Store information for computing computation cost // TODO: Maybe this class 
could simplify class UserOpComputeComplexityFnContext : public user_op::ComputeComplexityFnContext { public: using ArgVec = std::vector>; UserOpComputeComplexityFnContext( const OperatorConf& op_conf, const ParallelDesc& parallel_desc, const NdSbpSignature* sbp_signature, std::function logical_blob_desc4bn) : user_op::ComputeComplexityFnContext(user_op::UserOpConfWrapper(op_conf)), parallel_desc_(parallel_desc), sbp_signature_(sbp_signature) { auto InitInOrOut = [&](const PbMap& arg_map, ArgVec* arg_vec) { for (auto it = arg_map.begin(); it != arg_map.end(); ++it) { const std::string& arg_name = it->first; for (int32_t i = 0; i < it->second.s_size(); ++i) { const BlobDesc& blob = logical_blob_desc4bn(GenRepeatedBn(arg_name, i)); auto key = std::make_pair(arg_name, i); arg2tensor_desc_.emplace(key, GenTensorDescFromBlobDesc(&blob)); arg_vec->emplace_back(std::make_pair(arg_name, i)); } } }; InitInOrOut(op_conf.user_conf().input(), &inputs_); InitInOrOut(op_conf.user_conf().output(), &outputs_); } ~UserOpComputeComplexityFnContext() override = default; const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return &(it->second); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { thread_local static Shape non_shape; return non_shape; }; return it->second.shape(); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return DataType::kInvalidDataType; }; return it->second.data_type(); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return false; }; return it->second.is_dynamic(); } const NdSbp NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { const auto& bn2sbp = sbp_signature_->bn_in_op2nd_sbp(); std::string bn = GenRepeatedBn(arg_name, index); CHECK(bn2sbp.find(bn) != bn2sbp.end()); return sbp_signature_->bn_in_op2nd_sbp().at(bn); } const ArgVec& inputs() const override { return inputs_; } const ArgVec& outputs() const override { return outputs_; } const ParallelDesc& parallel_desc() const override { return parallel_desc_; }; const NdSbpSignature* GetNdSbpSignature() const override { return sbp_signature_; } private: ArgVec inputs_; ArgVec outputs_; const ParallelDesc parallel_desc_; const NdSbpSignature* sbp_signature_; HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; }; class UserOpGetNdSbpSignatureListContext : public user_op::GetNdSbpSignatureListContext { public: UserOpGetNdSbpSignatureListContext( const UserOp* op, std::function(const std::string&)> LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) : user_op::GetNdSbpSignatureListContext(user_op::UserOpConfWrapper(op->user_op_conf())), op_(op), logical_blob_desc4ibn_(std::move(LogicalBlobDesc4Ibn)), parallel_desc_(parallel_desc), nd_sbp_sig_list_(nd_sbp_sig_list) {} ~UserOpGetNdSbpSignatureListContext() override = default; std::vector* MutNdSbpSignatureList() override { return nd_sbp_sig_list_; } void AddNdSbpSignature(NdSbpSignature& nd_sbp_sig) override { 
nd_sbp_sig_list_->emplace_back(nd_sbp_sig); } const Shape& parallel_hierarchy() override { return *(CHECK_JUST(op_->GetOpParallelDesc())->hierarchy()); } const Shape& BlobShape4InputArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return CHECK_JUST(logical_blob_desc4ibn_(GenRepeatedBn(arg_name, index))).shape(); } private: const UserOp* op_; std::function(const std::string&)> logical_blob_desc4ibn_; const ParallelDesc parallel_desc_; std::vector* nd_sbp_sig_list_; }; Maybe UserOp::InitFromOpConf() { CHECK_OR_RETURN(op_conf().has_user_conf()); for (const auto& pair : op_conf().user_conf().input()) { EnrollRepeatedInputBn(pair.first, pair.second.s_size()); for (int32_t i = 0; i < pair.second.s_size(); ++i) { inputs_.emplace_back(std::make_pair(pair.first, i)); } } for (const auto& pair : op_conf().user_conf().output()) { EnrollRepeatedOutputBn(pair.first, pair.second.s_size()); for (int32_t i = 0; i < pair.second.s_size(); ++i) { outputs_.emplace_back(std::make_pair(pair.first, i)); } } EnrollTmpBn(GenRepeatedBn("tmp_buffer", 0)); user_op_conf_.reset(new user_op::UserOpConfWrapper(shared_op_conf())); val_ = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf().user_conf().op_type_name()); if (val_ != nullptr) { if (val_->input_arg_modify_fn) { user_op::GetInputArgModifier GetInputArgModifierFn = [&](const std::string& in_arg_name, int32_t in_arg_index) -> user_op::InputArgModifier* { std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index); if (std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end()) { return MutInputBlobModifier4Ibn(ibn); } return nullptr; }; JUST(val_->input_arg_modify_fn(GetInputArgModifierFn, *user_op_conf_)); } if (val_->output_arg_modify_fn) { user_op::GetOutputArgModifier GetOutputArgModifierFn = [&](const std::string& out_arg_name, int32_t out_arg_index) -> user_op::OutputArgModifier* { std::string obn = GenRepeatedBn(out_arg_name, out_arg_index); if (std::find(output_bns().begin(), output_bns().end(), obn) != output_bns().end()) { return MutOutputBlobModifier4Obn(obn); } return nullptr; }; JUST(val_->output_arg_modify_fn(GetOutputArgModifierFn, *user_op_conf_)); } } return Maybe::Ok(); } Maybe UserOp::InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const { // tmp buffer size must be inferred after out shape/dtype UserOpInferContext infer_ctx(this, parallel_ctx, job_desc, GetBlobDesc4BnInOp); const user_op::OpKernelRegistryResult* kernel_reg_val = JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult( op_conf().user_conf().op_type_name(), UserOpKernelRegContext(this, GetBlobDesc4BnInOp, parallel_ctx))); CHECK_OR_RETURN(kernel_reg_val != nullptr) << "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in kernel registry !"; size_t tmp_size = kernel_reg_val->infer_tmp_size_fn(&infer_ctx); if (tmp_size > 0) { BlobDesc* tmp_buffer_blob = GetBlobDesc4BnInOp(GenRepeatedBn("tmp_buffer", 0)); CHECK_NOTNULL_OR_RETURN(tmp_buffer_blob); tmp_buffer_blob->set_data_type(DataType::kChar); tmp_buffer_blob->set_memory_format(MemoryFormat::kContiguous); tmp_buffer_blob->set_shape(Shape({static_cast(tmp_size)})); tmp_buffer_blob->set_stride(Stride({static_cast(tmp_size)})); } return Maybe::Ok(); } Maybe UserOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { CHECK_OR_RETURN(val_ != nullptr) << "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in 
op registry!"; // default method set output blob desc (such as Dtype, is_dynamic) // set out blob desc attr as first input blob desc (if has) BlobDesc* first_in_blob_desc = FindValidBlobDescOfBnsInOp(BlobDesc4BnInOp, input_bns()); if (first_in_blob_desc) { for (const std::string& obn : output_bns()) { BlobDesc4BnInOp(obn)->CopyFrom(*first_in_blob_desc); } } UserOpInferContext infer_ctx(this, nullptr, nullptr, BlobDesc4BnInOp); CHECK_OR_RETURN(val_->data_type_infer_fn) << "No InferDataType function for " << val_->op_type_name; JUST(val_->data_type_infer_fn(&infer_ctx)); JUST(val_->logical_tensor_desc_infer_fn(&infer_ctx)); for (const auto& pair : infer_ctx.outputs()) { BlobDesc* out_blob_desc = BlobDesc4BnInOp(GenRepeatedBn(pair.first, pair.second)); const user_op::TensorDesc& tensor_desc = infer_ctx.OutputTensorDesc(pair.first, pair.second); out_blob_desc->set_data_type(tensor_desc.data_type()); out_blob_desc->set_memory_format(tensor_desc.memory_format()); out_blob_desc->set_shape(tensor_desc.shape()); if (val_->non_contiguous_supported) { out_blob_desc->set_stride(tensor_desc.stride()); } else { out_blob_desc->set_stride(Stride(out_blob_desc->shape())); } CHECK_EQ_OR_RETURN(out_blob_desc->stride().size(), out_blob_desc->shape().size()) << Error::RuntimeError() << "stride and shape size mismatch since stride is " << out_blob_desc->stride().ToString() << " but shape is " << out_blob_desc->shape().ToString(); out_blob_desc->set_is_dynamic(tensor_desc.is_dynamic()); } return Maybe::Ok(); } Maybe UserOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { CHECK_OR_RETURN(val_ != nullptr) << "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in op registry!"; if (!val_->physical_tensor_desc_infer_fn) { return Operator::InferOutBlobDescs(GetBlobDesc4BnInOp, parallel_ctx); } else { // default method set output blob desc (such as Dtype, is_dynamic, is_tensor_list) // set out blob desc attr as first input blob desc (if has) BlobDesc* first_in_blob_desc = FindValidBlobDescOfBnsInOp(GetBlobDesc4BnInOp, input_bns()); if (first_in_blob_desc) { for (const std::string& obn : output_bns()) { GetBlobDesc4BnInOp(obn)->CopyFrom(*first_in_blob_desc); } } UserOpInferContext infer_ctx(this, parallel_ctx, nullptr, GetBlobDesc4BnInOp); CHECK_OR_RETURN(val_->data_type_infer_fn) << "No InferDataType function for " << val_->op_type_name; JUST(val_->data_type_infer_fn(&infer_ctx)); JUST(val_->physical_tensor_desc_infer_fn(&infer_ctx)); for (const auto& pair : infer_ctx.outputs()) { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(GenRepeatedBn(pair.first, pair.second)); out_blob_desc->set_data_type(infer_ctx.OutputDType(pair.first, pair.second)); out_blob_desc->set_memory_format(infer_ctx.OutputMemoryFormat(pair.first, pair.second)); out_blob_desc->set_shape(infer_ctx.OutputShape(pair.first, pair.second)); if (val_->non_contiguous_supported) { out_blob_desc->set_stride(infer_ctx.OutputStride(pair.first, pair.second)); } else { out_blob_desc->set_stride(Stride(out_blob_desc->shape())); } CHECK_EQ_OR_RETURN(out_blob_desc->stride().size(), out_blob_desc->shape().size()) << Error::RuntimeError() << "stride and shape size mismatch since stride is " << out_blob_desc->stride().ToString() << " but shape is " << out_blob_desc->shape().ToString(); out_blob_desc->set_is_dynamic(infer_ctx.OutputIsDynamic(pair.first, pair.second)); } return Maybe::Ok(); } } Maybe UserOp::InferInplaceObn2Ibn( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, const 
std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { UserOpInferContext infer_ctx(this, parallel_ctx, nullptr, GetBlobDesc4BnInOp); const user_op::OpKernelRegistryResult* kernel_reg_val = JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult( op_conf().user_conf().op_type_name(), UserOpKernelRegContext(this, GetBlobDesc4BnInOp, parallel_ctx))); CHECK_OR_RETURN(kernel_reg_val != nullptr) << "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in kernel registry !"; HashSet bn_in_op_unique_check; user_op::AddInplaceArgPair AddInplaceArgPairFn = [&](const std::string& out_arg_name, int32_t out_arg_index, const std::string& in_arg_name, int32_t in_arg_index, bool is_mutable) -> Maybe { std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index); std::string obn = GenRepeatedBn(out_arg_name, out_arg_index); if (is_mutable) { mut_inplace_obn2ibn->emplace(obn, ibn); } else { con_inplace_obn2ibn->emplace(obn, ibn); } CHECK_OR_RETURN(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end()) << "Cannot find input_arg_name : " << in_arg_name << " input_arg_index : " << in_arg_index << " in op_name: " << op_conf().name(); CHECK_OR_RETURN(std::find(output_bns().begin(), output_bns().end(), obn) != output_bns().end()) << "Cannot find output_arg_name : " << out_arg_name << " output_arg_index : " << out_arg_index << " in op_name: " << op_conf().name(); std::string repeated_ibn_err_msg = "Cannot repeated set inplace proposal for same intput arg : " + in_arg_name + " index : " + std::to_string(in_arg_index) + " in op_name: " + op_conf().name(); std::string repeated_obn_err_msg = "Cannot repeated set inplace proposal for same output arg : " + out_arg_name + " index : " + std::to_string(out_arg_index) + " in op_name: " + op_conf().name(); CHECK_OR_RETURN(bn_in_op_unique_check.insert(ibn).second) << repeated_ibn_err_msg; CHECK_OR_RETURN(bn_in_op_unique_check.insert(obn).second) << repeated_obn_err_msg; return Maybe::Ok(); }; JUST(kernel_reg_val->inplace_proposal_fn(infer_ctx, AddInplaceArgPairFn)); return Maybe::Ok(); } LogicalBlobId UserOp::lbi4ibn(const std::string& input_bn) const { auto pair = GenUnRepeatedBn(input_bn); return GenLogicalBlobId(op_conf().user_conf().input().at(pair.first).s(pair.second)); } LogicalBlobId UserOp::lbi4obn(const std::string& output_bn) const { // TODO: remove this workaround and use different lbi for input and output const bool is_copy_hd = op_conf().user_conf().op_type_name() == "copy_d2h" || op_conf().user_conf().op_type_name() == "copy_h2d"; if (is_copy_hd) { return GenLogicalBlobId(op_conf().user_conf().input().at("in").s(0)); } auto pair = GenUnRepeatedBn(output_bn); auto ret = GenLogicalBlobId(op_conf().user_conf().output().at(pair.first).s(pair.second)); CHECK_EQ(ret.op_name(), op_conf().name()); CHECK_EQ(ret.blob_name(), output_bn); return ret; } Maybe UserOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { if (val_->sbp_signature_infer_fn) { UserOpInferSbpSignatureFnContext ctx(this, sbp_signature, sbp_sig_conf, SbpInferHint4Ibn); return val_->sbp_signature_infer_fn(&ctx); } else { return Operator::InferSbpSignature(sbp_signature, sbp_sig_conf, CalcOrderValue4SbpSig, SbpInferHint4Ibn, parallel_desc); } } Maybe UserOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t hierarchy_value, 
SbpSignatureList* sbp_sig_list) const { CHECK_OR_RETURN(val_ != nullptr) << "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in op registry!"; UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, hierarchy_value); JUST(val_->get_sbp_fn(&sbp_ctx)); // Add Broadcast for source user op tick input if (val_->op_def.input_size() == 1 && input_bns().size() == 1 && val_->op_def.input(0).name() == user_op::kUserSourceOpTickInputArgName) { std::string tick_bn = GenRepeatedBn(user_op::kUserSourceOpTickInputArgName, 0); CHECK_OR_RETURN(input_bns().Get(0) == tick_bn) << "user op_name: " << op_conf().name() << " op_type_name: " << op_conf().user_conf().op_type_name() << " set ERROR input arg name : " << input_bns().Get(0) << " because NO input in op def"; for (auto& sbp_sig : *sbp_sig_list->mutable_sbp_signature()) { auto* bn2sbp = sbp_sig.mutable_bn_in_op2sbp_parallel(); if (bn2sbp->find(tick_bn) == bn2sbp->end()) { (*bn2sbp)[tick_bn].mutable_broadcast_parallel(); } } } // Check valid for (const auto& sbp_sig : sbp_sig_list->sbp_signature()) { const auto& bn2sbp = sbp_sig.bn_in_op2sbp_parallel(); for (const auto& ibn : input_bns()) { auto pair = GenUnRepeatedBn(ibn); CHECK_OR_RETURN(bn2sbp.find(ibn) != bn2sbp.end()) << "In op_name: " << op_conf().name() << " op_type_name: " << op_conf().user_conf().op_type_name() << ", input_arg_name : " << pair.first << " input_arg_index : " << pair.second << " have NOT set sbp signature"; } for (const auto& obn : output_bns()) { auto pair = GenUnRepeatedBn(obn); CHECK_OR_RETURN(bn2sbp.find(obn) != bn2sbp.end()) << "In op_name: " << op_conf().name() << " op_type_name: " << op_conf().user_conf().op_type_name() << ", output_arg_name : " << pair.first << " output_arg_index : " << pair.second << " have NOT set sbp signature"; } } return Maybe::Ok(); } Maybe UserOp::GetComputeComplexity( NdSbpSignature* sbp_signature, std::function logical_blob_desc4bn, const ParallelDesc& parallel_desc) const { if (val_->compute_complexity_fn) { UserOpComputeComplexityFnContext user_op_compute_complexity_fn_context( op_conf(), parallel_desc, sbp_signature, logical_blob_desc4bn); return val_->compute_complexity_fn(&user_op_compute_complexity_fn_context); } else { return Operator::GetComputeComplexity(sbp_signature, logical_blob_desc4bn, parallel_desc); } } Operator::DumpNdSbpSignatureForOpConfFn UserOp::GetDumpNdSbpSignatureForOpConfFn() const { if (val_->dump_nd_sbp_signature_for_op_conf_fn) { return val_->dump_nd_sbp_signature_for_op_conf_fn; } else { return Operator::GetDumpNdSbpSignatureForOpConfFn(); } } Maybe UserOp::InferOpTimeShape( const std::function(const std::string&)>& GetTimeShape4BnInOp, std::shared_ptr* time_shape) const { if (val_->output_blob_time_shape_infer_fn) { std::shared_ptr op_time_shape(new Shape()); UserOpInferOutputBlobTimeShapeFnContext infer_output_blob_time_shape_fn_ctx( this, GetTimeShape4BnInOp, op_time_shape.get()); *time_shape = op_time_shape; return val_->output_blob_time_shape_infer_fn(&infer_output_blob_time_shape_fn_ctx); } else { return Operator::InferOpTimeShape(GetTimeShape4BnInOp, time_shape); } } namespace { bool IgnoreInferNdSbpFnWhenFlatHierarchy(const std::string& op_type_name) { return (op_type_name == "reshape" || op_type_name == "reshape_like"); } } // namespace Maybe UserOp::InferNdSbpSignature( NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const { if (val_->nd_sbp_infer_fn && 
(parallel_desc.hierarchy()->NumAxes() > 1 || !IgnoreInferNdSbpFnWhenFlatHierarchy(this->user_op_conf().op_type_name()))) { UserOpInferNdSbpFnContext ctx(this, nd_sbp_signature, nd_sbp_constraints, NdSbpInferHint4Ibn); JUST(val_->nd_sbp_infer_fn(&ctx)); } else { JUST(Operator::InferNdSbpSignature(nd_sbp_signature, nd_sbp_constraints, parallel_desc, NdSbpInferHint4Ibn)); } std::string tick_bn = GenRepeatedBn(user_op::kUserSourceOpTickInputArgName, 0); if (std::find(input_bns().begin(), input_bns().end(), tick_bn) != input_bns().end()) { auto* map = nd_sbp_signature->mutable_bn_in_op2nd_sbp(); if (map->count(tick_bn) == 0) { auto* sbp_list = (*map)[tick_bn].mutable_sbp_parallel(); for (int i = 0; i < parallel_desc.hierarchy()->NumAxes(); ++i) { sbp_list->Add()->mutable_broadcast_parallel(); } } } return Maybe::Ok(); } Maybe UserOp::EnumerateNdSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { if (val_->enumerate_nd_sbp_signatures_fn) { NdSbpSignature empty_sbp_signature; UserOpGetNdSbpSignatureListContext user_op_get_nd_sbp_list_context( this, LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list); return val_->enumerate_nd_sbp_signatures_fn(&user_op_get_nd_sbp_list_context); } else { return Operator::EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list); } } Maybe UserOp::GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { if (val_->get_nd_sbp_list_fn) { NdSbpSignature empty_sbp_signature; UserOpGetNdSbpSignatureListContext user_op_get_nd_sbp_list_context( this, LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list); return val_->get_nd_sbp_list_fn(&user_op_get_nd_sbp_list_context); } else { JUST(Operator::GetNdSbpSignatureList(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list)); } return Maybe::Ok(); } Symbol UserOp::GetOpConfWithoutOpNameAndLbn() const { OperatorConf op_conf(this->op_conf()); op_conf.set_name("undefined-op-name"); UserOpConf* user_op_conf = op_conf.mutable_user_conf(); for (auto& pair : *user_op_conf->mutable_input()) { for (auto& str : *pair.second.mutable_s()) { str = "undefined-op-name/undefined-ibn"; } } for (auto& pair : *user_op_conf->mutable_output()) { std::string prefix = "undefined-op-name/"; prefix += pair.first; prefix += "_"; int i = 0; for (auto& str : *pair.second.mutable_s()) { str = prefix + std::to_string(i++); } } return SymbolOf(op_conf); } void UserOp::VirtualGenKernelConf( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { auto user_conf = kernel_conf->mutable_user_conf(); ForEachBnInOp([&](const std::string& bn) { const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn); if (blob_desc) { blob_desc->ToProto(&(*user_conf->mutable_bn_in_op2blob_desc())[bn]); } }); } const user_op::UserOpConfWrapper& UserOp::user_op_conf() const { CHECK(user_op_conf_); return *user_op_conf_; } REGISTER_OP(OperatorConf::kUserConf, UserOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/user_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_USER_OP_H_ #define ONEFLOW_CORE_OPERATOR_USER_OP_H_ #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/operator/operator.h" namespace oneflow { class UserOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(UserOp); UserOp() = default; ~UserOp() = default; using ArgVec = std::vector>; Maybe InitFromOpConf() override; Maybe InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferInplaceObn2Ibn( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; Maybe GetComputeComplexity( NdSbpSignature* sbp_signature, std::function logical_blob_desc4bn, const ParallelDesc& parallel_desc) const override; Operator::DumpNdSbpSignatureForOpConfFn GetDumpNdSbpSignatureForOpConfFn() const override; Symbol GetOpConfWithoutOpNameAndLbn() const override; const user_op::UserOpConfWrapper& user_op_conf() const; const ArgVec& inputs() const { return inputs_; } const ArgVec& outputs() const { return outputs_; } private: LogicalBlobId lbi4ibn(const std::string& input_bn) const override; LogicalBlobId lbi4obn(const std::string& output_bn) const override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const override; Maybe EnumerateNdSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const override; Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const override; Maybe InferOpTimeShape( const std::function(const std::string&)>& GetTimeShape4BnInOp, std::shared_ptr* time_shape) const override; Maybe InferNdSbpSignature(NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const override; void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override; const user_op::OpRegistryResult* val_; std::unique_ptr user_op_conf_; ArgVec inputs_; ArgVec outputs_; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_USER_OP_H_ ================================================ FILE: oneflow/core/operator/variable_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/variable_op.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { namespace { Maybe ParseNdSbpFromConf(const VariableOpConf& conf, const ParallelDesc& parallel_desc, NdSbp* nd_sbp) { const bool has_nd_sbp_conf = (conf.nd_sbp_size() != 0); const int64_t num_axes = parallel_desc.hierarchy()->NumAxes(); if (has_nd_sbp_conf) { CHECK_EQ(conf.nd_sbp_size(), num_axes); } nd_sbp->clear_sbp_parallel(); FOR_RANGE(int64_t, i, 0, num_axes) { if (has_nd_sbp_conf) { CHECK_OR_RETURN(ParseSbpParallelFromString(conf.nd_sbp(i), nd_sbp->add_sbp_parallel())); } else { nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); } } return Maybe::Ok(); } } // namespace Maybe VariableOp::InitFromOpConf() { CHECK(op_conf().has_variable_conf()); if (op_conf().variable_conf().has_tick()) { EnrollInputBn("tick", false); } bool is_trainable = op_conf().variable_conf().trainable(); EnrollOutputBn("out", is_trainable)->set_is_mutable(true); return Maybe::Ok(); } Maybe VariableOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { const VariableOpConf& variable_conf = op_conf().variable_conf(); BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); out_blob_desc->set_shape(Shape(variable_conf.shape())); CHECK_OR_RETURN(variable_conf.has_data_type()); out_blob_desc->set_data_type(variable_conf.data_type()); return Maybe::Ok(); } Maybe VariableOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const VariableOpConf& variable_conf = op_conf().variable_conf(); const ParallelDesc& parallel_desc = *JUST(GetOpParallelDesc()); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); CHECK_OR_RETURN(variable_conf.has_data_type()); out_blob_desc->set_data_type(variable_conf.data_type()); NdSbp nd_sbp; JUST(ParseNdSbpFromConf(variable_conf, parallel_desc, &nd_sbp)); out_blob_desc->set_shape( *JUST(GetPhysicalShape(Shape(variable_conf.shape()), nd_sbp, parallel_desc, *parallel_ctx))); return Maybe::Ok(); } Maybe VariableOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { int64_t num_axes = op_conf().variable_conf().shape().dim_size(); for (int i = 0; i < num_axes; ++i) { SbpSignatureBuilder() .Broadcast(input_bns()) .Split(output_bns(), i) .Build(sbp_sig_list->mutable_sbp_signature()->Add()); } return Maybe::Ok(); } Maybe VariableOp::InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { CHECK_EQ_OR_RETURN(parallel_desc.hierarchy()->NumAxes(), 1); SbpSignatureBuilder sbp_sig_builder; if (op_conf().variable_conf().nd_sbp_size() != 0) { CHECK_EQ_OR_RETURN(op_conf().variable_conf().nd_sbp_size(), 1); SbpParallel sbp_parallel; CHECK_OR_RETURN(ParseSbpParallelFromString(op_conf().variable_conf().nd_sbp(0), &sbp_parallel)); if 
(sbp_parallel.has_split_parallel()) { sbp_sig_builder.Split(output_bns(), sbp_parallel.split_parallel().axis()); } else { sbp_sig_builder.Broadcast(output_bns()); } } else { sbp_sig_builder.Broadcast(output_bns()); } sbp_sig_builder.Broadcast(input_bns()).Build(sbp_signature); return Maybe::Ok(); } Symbol VariableOp::GetOpConfWithoutOpNameAndLbn() const { return SymbolOf(this->op_conf()); } Maybe VariableOp::InferNdSbpSignature( NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const { const auto& parallel_hierarchy = parallel_desc.hierarchy(); const VariableOpConf& conf = this->op_conf().variable_conf(); NdSbp& out_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())["out"]; JUST(ParseNdSbpFromConf(conf, parallel_desc, &out_nd_sbp)); if (conf.has_tick()) { NdSbp& tick_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())["tick"]; for (int64_t i = 0; i < parallel_hierarchy->NumAxes(); ++i) { tick_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); } } return Maybe::Ok(); } Operator::DumpNdSbpSignatureForOpConfFn VariableOp::GetDumpNdSbpSignatureForOpConfFn() const { return [](const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf) -> Maybe { CHECK_OR_RETURN(op_conf->has_variable_conf()) << "VariableOp don't set variable op_conf"; op_conf->mutable_variable_conf()->clear_nd_sbp(); const auto& nd_sbp = nd_sbp_sig.bn_in_op2nd_sbp().at("out"); for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) { op_conf->mutable_variable_conf()->mutable_nd_sbp()->Add(SbpParallelToString(sbp_parallel)); } return Maybe::Ok(); }; } REGISTER_OP(OperatorConf::kVariableConf, VariableOp); REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kVariableConf, 1); REGISTER_INTERFACE_OP(OperatorConf::kVariableConf); } // namespace oneflow ================================================ FILE: oneflow/core/operator/variable_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_ #define ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class VariableOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(VariableOp); VariableOp() : Operator() {} ~VariableOp() = default; Maybe InitFromOpConf() override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; Symbol GetOpConfWithoutOpNameAndLbn() const override; Maybe InferNdSbpSignature(NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc, std::function(const std::string&)> NdSbpInferHint4Ibn) const override; DumpNdSbpSignatureForOpConfFn GetDumpNdSbpSignatureForOpConfFn() const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_ ================================================ FILE: oneflow/core/operator/wait_and_send_ids_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/operator/wait_and_send_ids_op.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace oneflow { Maybe WaitAndSendIdsOp::InitFromOpConf() { CHECK(op_conf().has_wait_and_send_ids_conf()); EnrollOutputBn("out", false); return Maybe::Ok(); } namespace { Maybe InferBlobDescs(const OperatorConf& op_conf, const std::function& BlobDesc4BnInOp) { BlobDesc4BnInOp("out")->set_shape(Shape({1})); BlobDesc4BnInOp("out")->set_data_type(op_conf.wait_and_send_ids_conf().data_type()); return Maybe::Ok(); } } // namespace Maybe WaitAndSendIdsOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); return InferBlobDescs(op_conf(), BlobDesc4BnInOp); } Maybe WaitAndSendIdsOp::InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1); return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } Maybe WaitAndSendIdsOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } REGISTER_CPU_OP(OperatorConf::kWaitAndSendIdsConf, WaitAndSendIdsOp); } // namespace oneflow ================================================ FILE: oneflow/core/operator/wait_and_send_ids_op.h ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_OPERATOR_WAIT_AND_SEND_IDS_OP_H_ #define ONEFLOW_CORE_OPERATOR_WAIT_AND_SEND_IDS_OP_H_ #include "oneflow/core/operator/operator.h" namespace oneflow { class WaitAndSendIdsOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsOp); WaitAndSendIdsOp() = default; ~WaitAndSendIdsOp() = default; Maybe InitFromOpConf() override; Maybe InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; }; } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_WAIT_AND_SEND_IDS_OP_H_ ================================================ FILE: oneflow/core/persistence/binary_in_stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_H_ #define ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_H_ #include "oneflow/core/common/util.h" namespace oneflow { class BinaryInStream { public: OF_DISALLOW_COPY_AND_MOVE(BinaryInStream); virtual ~BinaryInStream() = default; // 0: success // -1: eof virtual int32_t Read(char* s, size_t n) = 0; virtual uint64_t file_size() const = 0; virtual uint64_t cur_file_pos() const = 0; virtual void set_cur_file_pos(uint64_t val) = 0; virtual bool IsEof() const = 0; protected: BinaryInStream() = default; }; } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_H_ ================================================ FILE: oneflow/core/persistence/binary_in_stream_with_local_copy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/persistence/binary_in_stream_with_local_copy.h" #include "oneflow/core/persistence/binary_in_stream_without_local_copy.h" #include "oneflow/core/common/str_util.h" namespace oneflow { BinaryInStreamWithLocalCopy::BinaryInStreamWithLocalCopy(fs::FileSystem* fs, const std::string& file_path) : once_read_(false) { VLOG(3) << "New BinaryInStreamWithLocalCopy " << file_path; in_stream_.reset(new BinaryInStreamWithoutLocalCopy(fs, file_path)); local_copy_path_ = JoinPath(FLAGS_log_dir, "global_fs_buffer", file_path); out_stream_.reset(new PersistentOutStream(LocalFS(), local_copy_path_)); read_mthd_ = &BinaryInStreamWithLocalCopy::ReadAndWriteToLocal; } int32_t BinaryInStreamWithLocalCopy::ReadAndWriteToLocal(char* s, size_t n) { if (Restart()) { CopyToLocalFinish(); return Read(s, n); } else { int32_t ret = in_stream_->Read(s, n); CHECK_EQ(ret, 0); out_stream_->Write(s, n); once_read_ = true; return 0; } } bool BinaryInStreamWithLocalCopy::Restart() { return in_stream_->cur_file_pos() == 0 && once_read_; } void BinaryInStreamWithLocalCopy::CopyToLocalFinish() { out_stream_.reset(); in_stream_.reset(new BinaryInStreamWithoutLocalCopy(LocalFS(), local_copy_path_)); read_mthd_ = &BinaryInStreamWithLocalCopy::ReadFromLocal; } } // namespace oneflow ================================================ FILE: oneflow/core/persistence/binary_in_stream_with_local_copy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITH_LOCAL_COPY_H_ #define ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITH_LOCAL_COPY_H_ #include "oneflow/core/persistence/binary_in_stream.h" #include "oneflow/core/persistence/persistent_out_stream.h" namespace oneflow { class BinaryInStreamWithLocalCopy final : public BinaryInStream { public: OF_DISALLOW_COPY_AND_MOVE(BinaryInStreamWithLocalCopy); BinaryInStreamWithLocalCopy() = delete; ~BinaryInStreamWithLocalCopy() = default; BinaryInStreamWithLocalCopy(fs::FileSystem* fs, const std::string& file_path); int32_t Read(char* s, size_t n) override { return (this->*read_mthd_)(s, n); } uint64_t file_size() const override { return in_stream_->file_size(); } uint64_t cur_file_pos() const override { return in_stream_->cur_file_pos(); } void set_cur_file_pos(uint64_t val) override { in_stream_->set_cur_file_pos(val); } bool IsEof() const override { return in_stream_->IsEof(); } private: int32_t ReadAndWriteToLocal(char* s, size_t n); int32_t ReadFromLocal(char* s, size_t n) { return in_stream_->Read(s, n); } bool Restart(); void CopyToLocalFinish(); bool once_read_; std::unique_ptr in_stream_; std::string local_copy_path_; std::unique_ptr out_stream_; int32_t (BinaryInStreamWithLocalCopy::*read_mthd_)(char*, size_t); }; } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITH_LOCAL_COPY_H_ ================================================ FILE: oneflow/core/persistence/binary_in_stream_without_local_copy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/persistence/binary_in_stream_without_local_copy.h" #include "oneflow/core/job/job_desc.h" #include namespace oneflow { int32_t BinaryInStreamWithoutLocalCopy::Read(char* s, size_t n) { if (IsEof()) return -1; CHECK_LE(cur_file_pos_ + n, file_size_); file_->Read(cur_file_pos_, n, s); cur_file_pos_ += n; return 0; } BinaryInStreamWithoutLocalCopy::BinaryInStreamWithoutLocalCopy(fs::FileSystem* fs, const std::string& file_path) : cur_file_pos_(0) { fs->NewRandomAccessFile(file_path, &file_); file_size_ = fs->GetFileSize(file_path); } } // namespace oneflow ================================================ FILE: oneflow/core/persistence/binary_in_stream_without_local_copy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITHOUT_LOCAL_COPY_H_ #define ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITHOUT_LOCAL_COPY_H_ #include "oneflow/core/persistence/file_system.h" #include "oneflow/core/persistence/binary_in_stream.h" namespace oneflow { class BinaryInStreamWithoutLocalCopy final : public BinaryInStream { public: OF_DISALLOW_COPY_AND_MOVE(BinaryInStreamWithoutLocalCopy); BinaryInStreamWithoutLocalCopy() = delete; virtual ~BinaryInStreamWithoutLocalCopy() = default; BinaryInStreamWithoutLocalCopy(fs::FileSystem*, const std::string& file_path); int32_t Read(char* s, size_t n) override; uint64_t file_size() const override { return file_size_; } uint64_t cur_file_pos() const override { return cur_file_pos_; } void set_cur_file_pos(uint64_t val) override { cur_file_pos_ = val; } bool IsEof() const override { return cur_file_pos_ == file_size_; } private: std::unique_ptr file_; uint64_t file_size_; uint64_t cur_file_pos_; }; } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITHOUT_LOCAL_COPY_H_ ================================================ FILE: oneflow/core/persistence/file_system.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/persistence/file_system.h" #include #include "oneflow/core/common/str_util.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/persistence/hadoop/hadoop_file_system.h" #include "oneflow/core/persistence/posix/posix_file_system.h" #include "oneflow/core/job/job_set.pb.h" namespace oneflow { namespace fs { std::string FileSystem::SplitRecursiveDir(const std::string& dirname, std::vector& sub_dirs) { std::string remaining_dir = dirname; while (!remaining_dir.empty()) { bool status = FileExists(remaining_dir); if (status) { break; } // Basename returns "" for / ending dirs. if (remaining_dir[remaining_dir.length() - 1] != '/') { sub_dirs.emplace_back(Basename(remaining_dir)); } remaining_dir = Dirname(remaining_dir); } // sub_dirs contains all the dirs to be created but in reverse order. std::reverse(sub_dirs.begin(), sub_dirs.end()); return remaining_dir; } void FileSystem::CreateDirIfNotExist(const std::string& dirname) { if (IsDirectory(dirname)) { return; } CreateDir(dirname); } void FileSystem::RecursivelyCreateDirIfNotExist(const std::string& dirname) { if (IsDirectory(dirname)) { return; } // sub_dirs contains all the dirs to be created but in reverse order. std::vector sub_dirs; std::string remaining_dir = SplitRecursiveDir(dirname, sub_dirs); // Now create the directories. 
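// Hypothetical walk-through (paths are illustrative, not from the source): if dirname is
// "/data/run/logs" and only "/data" already exists, SplitRecursiveDir returns "/data" and
// fills sub_dirs with {"run", "logs"}, so the loop below creates "/data/run" and then
// "/data/run/logs" in that order.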
std::string built_path = remaining_dir; for (const std::string& sub_dir : sub_dirs) { built_path = JoinPath(built_path, sub_dir); CreateDirIfNotExist(built_path); } } bool FileSystem::IsDirEmpty(const std::string& dirname) { return ListDir(dirname).empty(); } std::string FileSystem::TranslateName(const std::string& name) const { return CleanPath(name); } void FileSystem::MakeEmptyDir(const std::string& dirname) { if (IsDirectory(dirname)) { RecursivelyDeleteDir(dirname); } RecursivelyCreateDir(dirname); } void FileSystem::RecursivelyDeleteDir(const std::string& dirname) { CHECK(FileExists(dirname)); std::deque dir_q; // Queue for the BFS std::vector dir_list; // List of all dirs discovered dir_q.emplace_back(dirname); // ret : Status to be returned. // Do a BFS on the directory to discover all the sub-directories. Remove all // children that are files along the way. Then cleanup and remove the // directories in reverse order.; while (!dir_q.empty()) { std::string dir = dir_q.front(); dir_q.pop_front(); dir_list.emplace_back(dir); // GetChildren might fail if we don't have appropriate permissions. std::vector children = ListDir(dir); for (const std::string& child : children) { const std::string child_path = JoinPath(dir, child); // If the child is a directory add it to the queue, otherwise delete it. if (IsDirectory(child_path)) { dir_q.emplace_back(child_path); } else { // Delete file might fail because of permissions issues or might be // unimplemented. DelFile(child_path); } } } // Now reverse the list of directories and delete them. The BFS ensures that // we can delete the directories in this order. std::reverse(dir_list.begin(), dir_list.end()); for (const std::string& dir : dir_list) { // Delete dir might fail because of permissions issues or might be // unimplemented. DeleteDir(dir); } } void FileSystem::RecursivelyCreateDir(const std::string& dirname) { // sub_dirs contains all the dirs to be created but in reverse order. std::vector sub_dirs; std::string remaining_dir = SplitRecursiveDir(dirname, sub_dirs); // Now create the directories. 
std::string built_path = remaining_dir; for (const std::string& sub_dir : sub_dirs) { built_path = JoinPath(built_path, sub_dir); CreateDir(built_path); } } } // namespace fs void CreateLocalFS(std::unique_ptr& fs) { #ifdef OF_PLATFORM_POSIX fs.reset(new fs::PosixFileSystem); #else OF_UNIMPLEMENTED(); #endif } void CreateHadoopFS(std::unique_ptr& fs, const std::string& namenode) { fs.reset(new fs::HadoopFileSystem(namenode)); } void CreateFileSystemFromEnv(std::unique_ptr& fs, const std::string& env_prefix) { CHECK(!fs); auto fs_type_env = env_prefix + "_TYPE"; const char* fs_type = std::getenv(fs_type_env.c_str()); std::string fs_type_str; if (fs_type) { fs_type_str = ToLower(fs_type); } else { // local file system by default fs_type_str = "local"; } if (fs_type_str == "local") { CreateLocalFS(fs); } else if (fs_type_str == "hdfs") { auto hdfs_nn_env = env_prefix + "_HDFS_NAMENODE"; const char* hdfs_namenode = std::getenv(hdfs_nn_env.c_str()); if (hdfs_namenode == nullptr) { LOG(FATAL) << "env " << hdfs_nn_env << " must be set when " << fs_type_env << " be set to hdfs"; } CreateHadoopFS(fs, hdfs_namenode); } else { LOG(FATAL) << "invalid value " << fs_type << " of env " << fs_type_env; } } fs::FileSystem* DataFS() { static std::unique_ptr data_fs; static std::mutex data_fs_mutex; { std::lock_guard lock(data_fs_mutex); if (!data_fs) { CreateFileSystemFromEnv(data_fs, "ONEFLOW_DATA_FILE_SYSTEM"); } } return data_fs.get(); } fs::FileSystem* SnapshotFS() { static std::unique_ptr snapshot_fs; static std::mutex snapshot_fs_mutex; { std::lock_guard lock(snapshot_fs_mutex); if (!snapshot_fs) { CreateFileSystemFromEnv(snapshot_fs, "ONEFLOW_SNAPSHOT_FILE_SYSTEM"); } } return snapshot_fs.get(); } fs::FileSystem* LocalFS() { static std::unique_ptr local_fs; static std::mutex local_fs_mutex; { std::lock_guard lock(local_fs_mutex); if (!local_fs) { CreateLocalFS(local_fs); } } return local_fs.get(); } } // namespace oneflow ================================================ FILE: oneflow/core/persistence/file_system.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PERSISTENCE_FILE_SYSTEM_H_ #define ONEFLOW_CORE_PERSISTENCE_FILE_SYSTEM_H_ #include "oneflow/core/common/platform.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace fs { // A file abstraction for randomly reading the contents of a file. class RandomAccessFile { public: OF_DISALLOW_COPY_AND_MOVE(RandomAccessFile); RandomAccessFile() = default; virtual ~RandomAccessFile() = default; // Reads `n` bytes from the file starting at `offset`. // Sets `*result` to the data that was read. // // Safe for concurrent use by multiple threads. virtual void Read(uint64_t offset, size_t n, char* result) const = 0; private: }; // A file abstraction for sequential writing. // // The implementation must provide buffering since callers may append // small fragments at a time to the file. 
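// Typical call sequence, mirroring TestFileOperation in file_system_test.cpp further below:
// obtain the file via FileSystem::NewWritableFile, call Append one or more times, then Flush,
// and finally Close; Close releases the underlying resources, so the object should not be
// written to afterwards.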
class WritableFile { public: OF_DISALLOW_COPY_AND_MOVE(WritableFile); WritableFile() = default; virtual ~WritableFile() = default; // Append 'data' to the file. virtual void Append(const char* data, size_t n) = 0; // Close the file. // // Flush() and de-allocate resources associated with this file virtual void Close() = 0; // Flushes the file and optionally syncs contents to filesystem. // // This should flush any local buffers whose contents have not been // delivered to the filesystem. // // If the process terminates after a successful flush, the contents // may still be persisted, since the underlying filesystem may // eventually flush the contents. If the OS or machine crashes // after a successful flush, the contents may or may not be // persisted, depending on the implementation. virtual void Flush() = 0; private: }; class FileSystem { public: virtual ~FileSystem() = default; // Creates a brand new random access read-only file with the // specified name. // // On success, stores a pointer to the new file in // *result. On failure stores NULL in *result. // // The returned file may be concurrently accessed by multiple threads. // // The ownership of the returned RandomAccessFile is passed to the caller // and the object should be deleted when is not used. virtual void NewRandomAccessFile(const std::string& fname, std::unique_ptr* result) = 0; // Creates an object that writes to a new file with the specified // name. // // Deletes any existing file with the same name and creates a // new file. On success, stores a pointer to the new file in // *result. On failure stores NULL in *result. // // The returned file will only be accessed by one thread at a time. // // The ownership of the returned WritableFile is passed to the caller // and the object should be deleted when is not used. virtual void NewWritableFile(const std::string& fname, std::unique_ptr* result) = 0; // Creates an object that either appends to an existing file, or // writes to a new file (if the file does not exist to begin with). // // On success, stores a pointer to the new file in *result. // On failure stores NULL in *result. // // The returned file will only be accessed by one thread at a time. // // The ownership of the returned WritableFile is passed to the caller // and the object should be deleted when is not used. virtual void NewAppendableFile(const std::string& fname, std::unique_ptr* result) = 0; // Returns true if the named path exists and false otherwise. virtual bool FileExists(const std::string& fname) = 0; // Returns the immediate children in the `dir` // // The returned paths are relative to 'dir'. virtual std::vector ListDir(const std::string& dir) = 0; // Deletes the named file. // Using DelFile to avoid Windows macro virtual void DelFile(const std::string& fname) = 0; // Creates the specified directory. virtual void CreateDir(const std::string& dirname) = 0; virtual void CreateDirIfNotExist(const std::string& dirname); virtual void RecursivelyCreateDir(const std::string& dirname); void RecursivelyCreateDirIfNotExist(const std::string& dirname); // Empty bool IsDirEmpty(const std::string& dirname); void MakeEmptyDir(const std::string& dirname); // Deletes the specified directory. virtual void DeleteDir(const std::string& dirname) = 0; // Deletes the specified directory and all subdirectories and files // underneath it. undeleted_files and undeleted_dirs stores the number of // files and directories that weren't deleted. 
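// The default implementation in file_system.cpp (above) walks the tree with a BFS, deleting
// regular files as they are discovered and then removing the directories in reverse
// (deepest-first) order, so every directory is empty by the time it is deleted.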
virtual void RecursivelyDeleteDir(const std::string& dirname); // Returns the size of `fname`. virtual uint64_t GetFileSize(const std::string& fname) = 0; // Overwrites the target if it exists. virtual void RenameFile(const std::string& old_name, const std::string& new_name) = 0; // Translate an URI to a filename for the FileSystem implementation. // // The implementation in this class cleans up the path, removing // duplicate /'s, resolving .. and . (more details in // str_util.h CleanPath). virtual std::string TranslateName(const std::string& name) const; // Returns whether the given path is a directory or not. virtual bool IsDirectory(const std::string& fname) = 0; protected: FileSystem() = default; private: std::string SplitRecursiveDir(const std::string& dirname, std::vector& sub_dirs); }; } // namespace fs fs::FileSystem* LocalFS(); fs::FileSystem* DataFS(); fs::FileSystem* SnapshotFS(); } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_FILE_SYSTEM_H_ ================================================ FILE: oneflow/core/persistence/file_system_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/process_state.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/persistence/posix/posix_file_system.h" namespace oneflow { namespace fs { void TestFileOperation(FileSystem* file_system) { std::string current_dir = GetCwd(); StringReplace(¤t_dir, '\\', '/'); std::string file_name = JoinPath(current_dir, "/tmp_test_file_asdfasdf"); // write std::unique_ptr writable_file; file_system->NewWritableFile(file_name, &writable_file); std::string write_content = "oneflow-file-system-test"; writable_file->Append(write_content.substr(0, 10).c_str(), 10); writable_file->Flush(); writable_file->Append(write_content.substr(10, 14).c_str(), 14); writable_file->Close(); // write append std::string append_content = "append-text"; std::unique_ptr appendable_file; file_system->NewAppendableFile(file_name, &appendable_file); appendable_file->Append(append_content.c_str(), 11); appendable_file->Flush(); appendable_file->Close(); // rename std::string new_file_name = file_name + "_new"; file_system->RenameFile(file_name, new_file_name); file_system->RenameFile(new_file_name, file_name); // read std::unique_ptr random_access_file; file_system->NewRandomAccessFile(file_name, &random_access_file); uint64_t file_size = file_system->GetFileSize(file_name); ASSERT_EQ(file_size, 35); char* read_array = new char[file_size]; random_access_file->Read(0, file_size, read_array); std::string read_content(read_array, file_size); ASSERT_EQ(write_content + append_content, read_content); file_system->DelFile(file_name); delete[] read_array; } void TestDirOperation(FileSystem* file_system) { std::string current_dir = GetCwd(); StringReplace(¤t_dir, '\\', '/'); std::string test_root_path = JoinPath(current_dir, "/tmp_test_dir_asdfasdf"); if (file_system->IsDirectory(test_root_path)) { 
    ASSERT_TRUE(file_system->ListDir(test_root_path).empty());
  } else {
    file_system->CreateDir(test_root_path);
  }
  std::string file_name = JoinPath(test_root_path, "/direct_file_");
  std::string content = "test_file";
  std::unique_ptr<WritableFile> file_a;
  std::unique_ptr<WritableFile> file_b;
  file_system->NewWritableFile(file_name + "_a", &file_a);
  file_a->Append(content.c_str(), 9);
  file_a->Close();
  file_system->NewWritableFile(file_name + "_b", &file_b);
  file_b->Append(content.c_str(), 9);
  file_b->Close();
  std::string child_dir = JoinPath(test_root_path, "/direct_dir");
  file_system->CreateDir(child_dir);
  ASSERT_EQ(file_system->ListDir(test_root_path).size(), 3);
  file_system->DeleteDir(child_dir);
  ASSERT_TRUE(!file_system->IsDirectory(child_dir));
  file_system->RecursivelyDeleteDir(test_root_path);
  ASSERT_TRUE(!file_system->IsDirectory(test_root_path));
}

void TestMultiThreadsDirOperation(FileSystem* file_system) {
  std::string current_dir = GetCwd();
  StringReplace(&current_dir, '\\', '/');
  std::string test_root_path = JoinPath(current_dir, "tmp_multithread_test_dir");
  std::vector<std::thread> thread_vector;
  for (int i = 0; i < 10; i++) {
    thread_vector.emplace_back(
        std::thread([&]() { file_system->RecursivelyCreateDirIfNotExist(test_root_path); }));
  }
  for (int i = 0; i < 10; i++) { thread_vector[i].join(); }
  ASSERT_TRUE(file_system->IsDirectory(test_root_path));
}

void TestFileSystem(FileSystem* file_system) {
  TestFileOperation(file_system);
  TestDirOperation(file_system);
  TestMultiThreadsDirOperation(file_system);
}

}  // namespace fs

TEST(file_system, write_and_read) {
#ifdef OF_PLATFORM_POSIX
  fs::FileSystem* file_system = new fs::PosixFileSystem();
  fs::TestFileSystem(file_system);
#endif
}

}  // namespace oneflow

================================================
FILE: oneflow/core/persistence/hadoop/hadoop_file_system.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/persistence/hadoop/hadoop_file_system.h" #include #include "oneflow/core/common/str_util.h" #ifdef OF_PLATFORM_POSIX #include #endif // OF_PLATFORM_POSIX #define FS_RETURN_FALSE_IF_FALSE(val) \ if (!val) { \ PLOG(WARNING); \ return false; \ } namespace oneflow { namespace fs { namespace internal { #ifdef OF_PLATFORM_POSIX bool GetSymbolFromLibrary(void* handle, const char* symbol_name, void** symbol) { *symbol = dlsym(handle, symbol_name); if (!*symbol) { PLOG(WARNING) << dlerror(); return false; } return true; } bool LoadLibrary(const char* library_filename, void** handle) { *handle = dlopen(library_filename, RTLD_NOW | RTLD_LOCAL); if (!*handle) { PLOG(WARNING) << dlerror(); return false; } return true; } #endif // OF_PLATFORM_POSIX } // namespace internal template bool BindFunc(void* handle, const char* name, std::function* func) { void* symbol_ptr = nullptr; FS_RETURN_FALSE_IF_FALSE(internal::GetSymbolFromLibrary(handle, name, &symbol_ptr)); *func = reinterpret_cast(symbol_ptr); return true; } void LibHDFS::LoadAndBind() { auto TryLoadAndBind = [this](const char* name, void** handle) -> bool { FS_RETURN_FALSE_IF_FALSE(internal::LoadLibrary(name, handle)); #define BIND_HDFS_FUNC(function) FS_RETURN_FALSE_IF_FALSE(BindFunc(*handle, #function, &function)); BIND_HDFS_FUNC(hdfsBuilderConnect); BIND_HDFS_FUNC(hdfsNewBuilder); BIND_HDFS_FUNC(hdfsBuilderSetNameNode); BIND_HDFS_FUNC(hdfsConfGetStr); BIND_HDFS_FUNC(hdfsBuilderSetKerbTicketCachePath); BIND_HDFS_FUNC(hdfsCloseFile); BIND_HDFS_FUNC(hdfsPread); BIND_HDFS_FUNC(hdfsWrite); BIND_HDFS_FUNC(hdfsHFlush); BIND_HDFS_FUNC(hdfsHSync); BIND_HDFS_FUNC(hdfsOpenFile); BIND_HDFS_FUNC(hdfsExists); BIND_HDFS_FUNC(hdfsListDirectory); BIND_HDFS_FUNC(hdfsFreeFileInfo); BIND_HDFS_FUNC(hdfsDelete); BIND_HDFS_FUNC(hdfsCreateDirectory); BIND_HDFS_FUNC(hdfsGetPathInfo); BIND_HDFS_FUNC(hdfsRename); #undef BIND_HDFS_FUNC return true; }; // libhdfs.so won't be in the standard locations. Use the path as specified // in the libhdfs documentation. const char* kLibHdfsDso = "libhdfs.so"; char* hdfs_home = getenv("HADOOP_HOME"); if (hdfs_home == nullptr) { PLOG(WARNING) << "Environment variable HADOOP_HOME not set"; status_ = false; return; } std::string path = JoinPath(hdfs_home, "lib", "native", kLibHdfsDso); status_ = TryLoadAndBind(path.c_str(), &handle_); if (!status_) { // try load libhdfs.so using dynamic loader's search path in case // libhdfs.so is installed in non-standard location status_ = TryLoadAndBind(kLibHdfsDso, &handle_); } } HadoopFileSystem::HadoopFileSystem(const std::string& namenode) : namenode_(namenode), hdfs_(LibHDFS::Load()) {} bool HadoopFileSystem::Connect(hdfsFS* fs) { FS_RETURN_FALSE_IF_FALSE(hdfs_->status()); hdfsBuilder* builder = hdfs_->hdfsNewBuilder(); hdfs_->hdfsBuilderSetNameNode(builder, namenode_.c_str()); // KERB_TICKET_CACHE_PATH will be deleted in the future, Because KRB5CCNAME // is the build in environment variable of Kerberos, so // KERB_TICKET_CACHE_PATH and related code are unnecessary. char* ticket_cache_path = getenv("KERB_TICKET_CACHE_PATH"); if (ticket_cache_path != nullptr) { hdfs_->hdfsBuilderSetKerbTicketCachePath(builder, ticket_cache_path); } *fs = hdfs_->hdfsBuilderConnect(builder); if (*fs == nullptr) { PLOG(WARNING) << " HDFS connect failed. 
NOT FOUND"; return false; } return true; } class HDFSRandomAccessFile : public RandomAccessFile { public: HDFSRandomAccessFile(const std::string& filename, const std::string& hdfs_filename, LibHDFS* hdfs, hdfsFS fs, hdfsFile file) : filename_(filename), hdfs_filename_(hdfs_filename), hdfs_(hdfs), fs_(fs), file_(file) {} ~HDFSRandomAccessFile() override { if (file_ != nullptr) { std::unique_lock lock(mu_); hdfs_->hdfsCloseFile(fs_, file_); } } void Read(uint64_t offset, size_t n, char* result) const override { char* dst = result; bool eof_retried = false; while (n > 0) { // We lock inside the loop rather than outside so we don't block other // concurrent readers. std::unique_lock lock(mu_); tSize r = hdfs_->hdfsPread(fs_, file_, static_cast(offset), dst, static_cast(n)); if (r > 0) { dst += r; n -= r; offset += r; } else if (!eof_retried && r == 0) { // Always reopen the file upon reaching EOF to see if there's more data. // If writers are streaming contents while others are concurrently // reading, HDFS requires that we reopen the file to see updated // contents. PCHECK(file_ == nullptr || hdfs_->hdfsCloseFile(fs_, file_) == 0) << filename_; file_ = hdfs_->hdfsOpenFile(fs_, hdfs_filename_.c_str(), O_RDONLY, 0, 0, 0); PCHECK(file_ != nullptr) << filename_; eof_retried = true; } else if (eof_retried && r == 0) { PLOG(FATAL) << "Read less bytes than requested"; return; } else if (errno == EINTR || errno == EAGAIN) { // hdfsPread may return EINTR too. Just retry. } else { PLOG(FATAL) << filename_; return; } } } private: std::string filename_; std::string hdfs_filename_; LibHDFS* hdfs_; hdfsFS fs_; mutable std::mutex mu_; mutable hdfsFile file_; }; void HadoopFileSystem::NewRandomAccessFile(const std::string& fname, std::unique_ptr* result) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); hdfsFile file = hdfs_->hdfsOpenFile(fs, TranslateName(fname).c_str(), O_RDONLY, 0, 0, 0); PCHECK(file != nullptr) << fname; result->reset(new HDFSRandomAccessFile(fname, TranslateName(fname), hdfs_, fs, file)); CHECK_NOTNULL(result->get()); } class HDFSWritableFile : public WritableFile { public: HDFSWritableFile(const std::string& fname, LibHDFS* hdfs, hdfsFS fs, hdfsFile file) : filename_(fname), hdfs_(hdfs), fs_(fs), file_(file) {} ~HDFSWritableFile() override { if (file_ != nullptr) { Close(); } } void Append(const char* data, size_t n) override { PCHECK(hdfs_->hdfsWrite(fs_, file_, data, static_cast(n)) != -1) << filename_; } void Close() override { int32_t result = hdfs_->hdfsCloseFile(fs_, file_); hdfs_ = nullptr; fs_ = nullptr; file_ = nullptr; PCHECK(result == 0) << filename_; } void Flush() override { PCHECK(hdfs_->hdfsHFlush(fs_, file_) == 0) << filename_; } private: std::string filename_; LibHDFS* hdfs_; hdfsFS fs_; hdfsFile file_; }; void HadoopFileSystem::NewWritableFile(const std::string& fname, std::unique_ptr* result) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); hdfsFile file = hdfs_->hdfsOpenFile(fs, TranslateName(fname).c_str(), O_WRONLY, 0, 0, 0); PCHECK(file != nullptr) << fname; result->reset(new HDFSWritableFile(fname, hdfs_, fs, file)); CHECK_NOTNULL(result->get()); } void HadoopFileSystem::NewAppendableFile(const std::string& fname, std::unique_ptr* result) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); hdfsFile file = hdfs_->hdfsOpenFile(fs, TranslateName(fname).c_str(), O_WRONLY | O_APPEND, 0, 0, 0); PCHECK(file != nullptr) << fname; result->reset(new HDFSWritableFile(fname, hdfs_, fs, file)); CHECK_NOTNULL(result->get()); } bool HadoopFileSystem::FileExists(const std::string& fname) { hdfsFS 
fs = nullptr; CHECK(Connect(&fs)); if (hdfs_->hdfsExists(fs, TranslateName(fname).c_str()) == 0) { return true; } return false; } std::vector HadoopFileSystem::ListDir(const std::string& dir) { std::vector result; hdfsFS fs = nullptr; CHECK(Connect(&fs)); // hdfsListDirectory returns nullptr if the directory is empty. Do a separate // check to verify the directory exists first. CHECK(IsDirectory(dir)) << "directory not found, path: " << dir; int entries = 0; hdfsFileInfo* info = hdfs_->hdfsListDirectory(fs, TranslateName(dir).c_str(), &entries); if (info == nullptr) { // Assume it's an empty directory. return result; } for (int i = 0; i < entries; i++) { result.emplace_back(Basename(info[i].mName)); } hdfs_->hdfsFreeFileInfo(info, entries); return result; } void HadoopFileSystem::DelFile(const std::string& fname) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); PCHECK(hdfs_->hdfsDelete(fs, TranslateName(fname).c_str(), /*recursive=*/0) == 0) << fname; } void HadoopFileSystem::CreateDir(const std::string& dir) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); PCHECK(hdfs_->hdfsCreateDirectory(fs, TranslateName(dir).c_str()) == 0) << dir; } void HadoopFileSystem::DeleteDir(const std::string& dir) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); // Count the number of entries in the directory, and only delete if it's // non-empty. This is consistent with the interface, but note that there's // a race condition where a file may be added after this check, in which // case the directory will still be deleted. int entries = 0; hdfsFileInfo* info = hdfs_->hdfsListDirectory(fs, TranslateName(dir).c_str(), &entries); if (info != nullptr) { hdfs_->hdfsFreeFileInfo(info, entries); } // Due to HDFS bug HDFS-8407, we can't distinguish between an error and empty // folder, expscially for Kerberos enable setup, EAGAIN is quite common // when the call is actually successful. Check again by Stat. 
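  // In short: a nullptr listing is trusted to mean "empty" only when errno stayed 0;
  // otherwise IsDirectory() (which goes through hdfsGetPathInfo) is consulted to make
  // sure the directory actually exists before the non-empty check below fires.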
if (info == nullptr && errno != 0) { CHECK(IsDirectory(dir)) << "directory not found, path: " << dir; } PCHECK(entries == 0) << dir << "Cannot delete a non-empty directory."; PCHECK(hdfs_->hdfsDelete(fs, TranslateName(dir).c_str(), /*recursive=*/1) == 0) << dir; } void HadoopFileSystem::RecursivelyDeleteDir(const std::string& dirname) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); PCHECK(hdfs_->hdfsDelete(fs, TranslateName(dirname).c_str(), /*recursive=*/1) == 0) << dirname; } uint64_t HadoopFileSystem::GetFileSize(const std::string& fname) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); hdfsFileInfo* info = hdfs_->hdfsGetPathInfo(fs, TranslateName(fname).c_str()); PCHECK(info != nullptr) << fname; uint64_t ret = info->mSize; hdfs_->hdfsFreeFileInfo(info, 1); return ret; } void HadoopFileSystem::RenameFile(const std::string& old_name, const std::string& new_name) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); PCHECK(hdfs_->hdfsExists(fs, TranslateName(new_name).c_str()) != 0 || hdfs_->hdfsDelete(fs, TranslateName(new_name).c_str(), /*recursive=*/0) == 0) << new_name; PCHECK(hdfs_->hdfsRename(fs, TranslateName(old_name).c_str(), TranslateName(new_name).c_str()) == 0) << old_name; } bool HadoopFileSystem::IsDirectory(const std::string& fname) { hdfsFS fs = nullptr; CHECK(Connect(&fs)); hdfsFileInfo* info = hdfs_->hdfsGetPathInfo(fs, TranslateName(fname).c_str()); if (info == nullptr || info->mKind != kObjectKindDirectory) { return false; } hdfs_->hdfsFreeFileInfo(info, 1); return true; } } // namespace fs } // namespace oneflow ================================================ FILE: oneflow/core/persistence/hadoop/hadoop_file_system.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PERSISTENCE_HADOOP_HADOOP_FILE_SYSTEM_H_ #define ONEFLOW_CORE_PERSISTENCE_HADOOP_HADOOP_FILE_SYSTEM_H_ #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/persistence/file_system.h" #include "oneflow/core/persistence/hadoop/hdfs.h" extern "C" { struct hdfs_internal; typedef hdfs_internal* hdfsFS; } namespace oneflow { namespace fs { class LibHDFS { public: static LibHDFS* Load() { static LibHDFS* lib = []() -> LibHDFS* { LibHDFS* lib = new LibHDFS; lib->LoadAndBind(); return lib; }(); return lib; } // The status, if any, from failure to load. 
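  // (Each std::function member below is filled in by LoadAndBind(), which dlopen()s
  // libhdfs.so from $HADOOP_HOME/lib/native or from the default loader search path and
  // dlsym()s the corresponding symbol; any failure leaves status_ == false.)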
// true is OK // false is non-OK bool status() { return status_; } std::function hdfsBuilderConnect; std::function hdfsNewBuilder; std::function hdfsBuilderSetNameNode; std::function hdfsConfGetStr; std::function hdfsBuilderSetKerbTicketCachePath; std::function hdfsCloseFile; std::function hdfsPread; std::function hdfsWrite; std::function hdfsHFlush; std::function hdfsHSync; std::function hdfsOpenFile; std::function hdfsExists; std::function hdfsListDirectory; std::function hdfsFreeFileInfo; std::function hdfsDelete; std::function hdfsCreateDirectory; std::function hdfsGetPathInfo; std::function hdfsRename; private: void LoadAndBind(); bool status_; void* handle_ = nullptr; }; class HadoopFileSystem final : public FileSystem { public: OF_DISALLOW_COPY_AND_MOVE(HadoopFileSystem); HadoopFileSystem() = delete; ~HadoopFileSystem() = default; HadoopFileSystem(const std::string&); void NewRandomAccessFile(const std::string& fname, std::unique_ptr* result) override; void NewWritableFile(const std::string& fname, std::unique_ptr* result) override; void NewAppendableFile(const std::string& fname, std::unique_ptr* result) override; bool FileExists(const std::string& fname) override; std::vector ListDir(const std::string& dir) override; void DelFile(const std::string& fname) override; void CreateDir(const std::string& dirname) override; void DeleteDir(const std::string& dirname) override; void RecursivelyDeleteDir(const std::string& dirname) override; uint64_t GetFileSize(const std::string& fname) override; void RenameFile(const std::string& old_name, const std::string& new_name) override; bool IsDirectory(const std::string& fname) override; private: bool Connect(hdfsFS* fs); std::string namenode_; LibHDFS* hdfs_; }; } // namespace fs } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_HADOOP_HADOOP_FILE_SYSTEM_H_ ================================================ FILE: oneflow/core/persistence/hadoop/hdfs.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PERSISTENCE_HADOOP_HDFS_H_ #define ONEFLOW_CORE_PERSISTENCE_HADOOP_HDFS_H_ #include /* for EINTERNAL, etc. */ #include /* for O_RDONLY, O_WRONLY */ #include /* for uint64_t, etc. */ #include /* for time_t */ /* * Support export of DLL symbols during libhdfs build, and import of DLL symbols * during client application build. A client application may optionally define * symbol LIBHDFS_DLL_IMPORT in its build. This is not strictly required, but * the compiler can produce more efficient code with it. 
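 *
 * Note: OneFlow resolves these symbols at runtime through LibHDFS::LoadAndBind() in
 * hadoop_file_system.cpp rather than relying on link-time binding, so unless
 * LIBHDFS_DLL_EXPORT or LIBHDFS_DLL_IMPORT is defined explicitly, LIBHDFS_EXTERNAL
 * expands to nothing here.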
*/ #ifdef WIN32 #ifdef LIBHDFS_DLL_EXPORT #define LIBHDFS_EXTERNAL __declspec(dllexport) #elif LIBHDFS_DLL_IMPORT #define LIBHDFS_EXTERNAL __declspec(dllimport) #else #define LIBHDFS_EXTERNAL #endif #else #ifdef LIBHDFS_DLL_EXPORT #define LIBHDFS_EXTERNAL __attribute__((visibility("default"))) #elif LIBHDFS_DLL_IMPORT #define LIBHDFS_EXTERNAL __attribute__((visibility("default"))) #else #define LIBHDFS_EXTERNAL #endif #endif #ifndef O_RDONLY #define O_RDONLY 1 #endif #ifndef O_WRONLY #define O_WRONLY 2 #endif #ifndef EINTERNAL #define EINTERNAL 255 #endif #define ELASTIC_BYTE_BUFFER_POOL_CLASS "org/apache/hadoop/io/ElasticByteBufferPool" /** All APIs set errno to meaningful values */ #ifdef __cplusplus extern "C" { #endif /** * Some utility decls used in libhdfs. */ struct hdfsBuilder; typedef int32_t tSize; /// size of data for read/write io ops typedef time_t tTime; /// time type in seconds typedef int64_t tOffset; /// offset within the file typedef uint16_t tPort; /// port typedef enum tObjectKind { kObjectKindFile = 'F', kObjectKindDirectory = 'D', } tObjectKind; /** * The C reflection of org.apache.org.hadoop.FileSystem . */ struct hdfs_internal; typedef struct hdfs_internal* hdfsFS; struct hdfsFile_internal; typedef struct hdfsFile_internal* hdfsFile; struct hadoopRzOptions; struct hadoopRzBuffer; /** * Determine if a file is open for read. * * @param file The HDFS file * @return 1 if the file is open for read; 0 otherwise */ LIBHDFS_EXTERNAL int hdfsFileIsOpenForRead(hdfsFile file); /** * Determine if a file is open for write. * * @param file The HDFS file * @return 1 if the file is open for write; 0 otherwise */ LIBHDFS_EXTERNAL int hdfsFileIsOpenForWrite(hdfsFile file); struct hdfsReadStatistics { uint64_t totalBytesRead; uint64_t totalLocalBytesRead; uint64_t totalShortCircuitBytesRead; uint64_t totalZeroCopyBytesRead; }; /** * Get read statistics about a file. This is only applicable to files * opened for reading. * * @param file The HDFS file * @param stats (out parameter) on a successful return, the read * statistics. Unchanged otherwise. You must free the * returned statistics with hdfsFileFreeReadStatistics. * @return 0 if the statistics were successfully returned, * -1 otherwise. On a failure, please check errno against * ENOTSUP. webhdfs, LocalFilesystem, and so forth may * not support read statistics. */ LIBHDFS_EXTERNAL int hdfsFileGetReadStatistics(hdfsFile file, struct hdfsReadStatistics** stats); /** * @param stats HDFS read statistics for a file. * * @return the number of remote bytes read. */ LIBHDFS_EXTERNAL int64_t hdfsReadStatisticsGetRemoteBytesRead(const struct hdfsReadStatistics* stats); /** * Clear the read statistics for a file. * * @param file The file to clear the read statistics of. * * @return 0 on success; the error code otherwise. * EINVAL: the file is not open for reading. * ENOTSUP: the file does not support clearing the read * statistics. * Errno will also be set to this code on failure. */ LIBHDFS_EXTERNAL int hdfsFileClearReadStatistics(hdfsFile file); /** * Free some HDFS read statistics. * * @param stats The HDFS read statistics to free. */ LIBHDFS_EXTERNAL void hdfsFileFreeReadStatistics(struct hdfsReadStatistics* stats); /** * hdfsConnectAsUser - Connect to a hdfs file system as a specific user * Connect to the hdfs. * @param nn The NameNode. See hdfsBuilderSetNameNode for details. * @param port The port on which the server is listening. * @param user the user name (this is hadoop domain user). 
Or NULL is equivalent * to hhdfsConnect(host, port) * @return Returns a handle to the filesystem or NULL on error. * @deprecated Use hdfsBuilderConnect instead. */ LIBHDFS_EXTERNAL hdfsFS hdfsConnectAsUser(const char* nn, tPort port, const char* user); /** * hdfsConnect - Connect to a hdfs file system. * Connect to the hdfs. * @param nn The NameNode. See hdfsBuilderSetNameNode for details. * @param port The port on which the server is listening. * @return Returns a handle to the filesystem or NULL on error. * @deprecated Use hdfsBuilderConnect instead. */ LIBHDFS_EXTERNAL hdfsFS hdfsConnect(const char* nn, tPort port); /** * hdfsConnect - Connect to an hdfs file system. * * Forces a new instance to be created * * @param nn The NameNode. See hdfsBuilderSetNameNode for details. * @param port The port on which the server is listening. * @param user The user name to use when connecting * @return Returns a handle to the filesystem or NULL on error. * @deprecated Use hdfsBuilderConnect instead. */ LIBHDFS_EXTERNAL hdfsFS hdfsConnectAsUserNewInstance(const char* nn, tPort port, const char* user); /** * hdfsConnect - Connect to an hdfs file system. * * Forces a new instance to be created * * @param nn The NameNode. See hdfsBuilderSetNameNode for details. * @param port The port on which the server is listening. * @return Returns a handle to the filesystem or NULL on error. * @deprecated Use hdfsBuilderConnect instead. */ LIBHDFS_EXTERNAL hdfsFS hdfsConnectNewInstance(const char* nn, tPort port); /** * Connect to HDFS using the parameters defined by the builder. * * The HDFS builder will be freed, whether or not the connection was * successful. * * Every successful call to hdfsBuilderConnect should be matched with a call * to hdfsDisconnect, when the hdfsFS is no longer needed. * * @param bld The HDFS builder * @return Returns a handle to the filesystem, or NULL on error. */ LIBHDFS_EXTERNAL hdfsFS hdfsBuilderConnect(struct hdfsBuilder* bld); /** * Create an HDFS builder. * * @return The HDFS builder, or NULL on error. */ LIBHDFS_EXTERNAL struct hdfsBuilder* hdfsNewBuilder(void); /** * Force the builder to always create a new instance of the FileSystem, * rather than possibly finding one in the cache. * * @param bld The HDFS builder */ LIBHDFS_EXTERNAL void hdfsBuilderSetForceNewInstance(struct hdfsBuilder* bld); /** * Set the HDFS NameNode to connect to. * * @param bld The HDFS builder * @param nn The NameNode to use. * * If the string given is 'default', the default NameNode * configuration will be used (from the XML configuration files) * * If NULL is given, a LocalFileSystem will be created. * * If the string starts with a protocol type such as file:// or * hdfs://, this protocol type will be used. If not, the * hdfs:// protocol type will be used. * * You may specify a NameNode port in the usual way by * passing a string of the format hdfs://:. * Alternately, you may set the port with * hdfsBuilderSetNameNodePort. However, you must not pass the * port in two different ways. */ LIBHDFS_EXTERNAL void hdfsBuilderSetNameNode(struct hdfsBuilder* bld, const char* nn); /** * Set the port of the HDFS NameNode to connect to. * * @param bld The HDFS builder * @param port The port. */ LIBHDFS_EXTERNAL void hdfsBuilderSetNameNodePort(struct hdfsBuilder* bld, tPort port); /** * Set the username to use when connecting to the HDFS cluster. * * @param bld The HDFS builder * @param userName The user name. The string will be shallow-copied. 
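 *
 * A minimal, illustrative connection sketch using only builder functions declared in
 * this header (the namenode URI and user name below are made-up values):
 *
 *   struct hdfsBuilder* bld = hdfsNewBuilder();
 *   hdfsBuilderSetNameNode(bld, "hdfs://namenode:9000");  // or "default"
 *   hdfsBuilderSetUserName(bld, "alice");                  // optional
 *   hdfsFS fs = hdfsBuilderConnect(bld);                   // frees bld on return
 *   if (fs == NULL) { return -1; }                         // handle connection failure
 *   ...
 *   hdfsDisconnect(fs);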
*/ LIBHDFS_EXTERNAL void hdfsBuilderSetUserName(struct hdfsBuilder* bld, const char* userName); /** * Set the path to the Kerberos ticket cache to use when connecting to * the HDFS cluster. * * @param bld The HDFS builder * @param kerbTicketCachePath The Kerberos ticket cache path. The string * will be shallow-copied. */ LIBHDFS_EXTERNAL void hdfsBuilderSetKerbTicketCachePath(struct hdfsBuilder* bld, const char* kerbTicketCachePath); /** * Free an HDFS builder. * * It is normally not necessary to call this function since * hdfsBuilderConnect frees the builder. * * @param bld The HDFS builder */ LIBHDFS_EXTERNAL void hdfsFreeBuilder(struct hdfsBuilder* bld); /** * Set a configuration string for an HdfsBuilder. * * @param key The key to set. * @param val The value, or NULL to set no value. * This will be shallow-copied. You are responsible for * ensuring that it remains valid until the builder is * freed. * * @return 0 on success; nonzero error code otherwise. */ LIBHDFS_EXTERNAL int hdfsBuilderConfSetStr(struct hdfsBuilder* bld, const char* key, const char* val); /** * Get a configuration string. * * @param key The key to find * @param val (out param) The value. This will be set to NULL if the * key isn't found. You must free this string with * hdfsConfStrFree. * * @return 0 on success; nonzero error code otherwise. * Failure to find the key is not an error. */ LIBHDFS_EXTERNAL int hdfsConfGetStr(const char* key, char** val); /** * Get a configuration integer. * * @param key The key to find * @param val (out param) The value. This will NOT be changed if the * key isn't found. * * @return 0 on success; nonzero error code otherwise. * Failure to find the key is not an error. */ LIBHDFS_EXTERNAL int hdfsConfGetInt(const char* key, int32_t* val); /** * Free a configuration string found with hdfsConfGetStr. * * @param val A configuration string obtained from hdfsConfGetStr */ LIBHDFS_EXTERNAL void hdfsConfStrFree(char* val); /** * hdfsDisconnect - Disconnect from the hdfs file system. * Disconnect from hdfs. * @param fs The configured filesystem handle. * @return Returns 0 on success, -1 on error. * Even if there is an error, the resources associated with the * hdfsFS will be freed. */ LIBHDFS_EXTERNAL int hdfsDisconnect(hdfsFS fs); /** * hdfsOpenFile - Open a hdfs file in given mode. * @param fs The configured filesystem handle. * @param path The full path to the file. * @param flags - an | of bits/fcntl.h file flags - supported flags are * O_RDONLY, O_WRONLY (meaning create or overwrite i.e., implies O_TRUNCAT), * O_WRONLY|O_APPEND. Other flags are generally ignored other than (O_RDWR || * (O_EXCL & O_CREAT)) which return NULL and set errno equal ENOTSUP. * @param bufferSize Size of buffer for read/write - pass 0 if you want * to use the default configured values. * @param replication Block replication - pass 0 if you want to use * the default configured values. * @param blocksize Size of block - pass 0 if you want to use the * default configured values. * @return Returns the handle to the open file or NULL on error. */ LIBHDFS_EXTERNAL hdfsFile hdfsOpenFile(hdfsFS fs, const char* path, int flags, int bufferSize, short replication, tSize blocksize); /** * hdfsTruncateFile - Truncate a hdfs file to given length. * @param fs The configured filesystem handle. * @param path The full path to the file. 
* @param newlength The size the file is to be truncated to * @return 1 if the file has been truncated to the desired newlength * and is immediately available to be reused for write operations * such as append. * 0 if a background process of adjusting the length of the last * block has been started, and clients should wait for it to * complete before proceeding with further file updates. * -1 on error. */ int hdfsTruncateFile(hdfsFS fs, const char* path, tOffset newlength); /** * hdfsUnbufferFile - Reduce the buffering done on a file. * * @param file The file to unbuffer. * @return 0 on success * ENOTSUP if the file does not support unbuffering * Errno will also be set to this value. */ LIBHDFS_EXTERNAL int hdfsUnbufferFile(hdfsFile file); /** * hdfsCloseFile - Close an open file. * @param fs The configured filesystem handle. * @param file The file handle. * @return Returns 0 on success, -1 on error. * On error, errno will be set appropriately. * If the hdfs file was valid, the memory associated with it will * be freed at the end of this call, even if there was an I/O * error. */ LIBHDFS_EXTERNAL int hdfsCloseFile(hdfsFS fs, hdfsFile file); /** * hdfsExists - Checks if a given path exsits on the filesystem * @param fs The configured filesystem handle. * @param path The path to look for * @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsExists(hdfsFS fs, const char* path); /** * hdfsSeek - Seek to given offset in file. * This works only for files opened in read-only mode. * @param fs The configured filesystem handle. * @param file The file handle. * @param desiredPos Offset into the file to seek into. * @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsSeek(hdfsFS fs, hdfsFile file, tOffset desiredPos); /** * hdfsTell - Get the current offset in the file, in bytes. * @param fs The configured filesystem handle. * @param file The file handle. * @return Current offset, -1 on error. */ LIBHDFS_EXTERNAL tOffset hdfsTell(hdfsFS fs, hdfsFile file); /** * hdfsRead - Read data from an open file. * @param fs The configured filesystem handle. * @param file The file handle. * @param buffer The buffer to copy read bytes into. * @param length The length of the buffer. * @return On success, a positive number indicating how many bytes * were read. * On end-of-file, 0. * On error, -1. Errno will be set to the error code. * Just like the POSIX read function, hdfsRead will return -1 * and set errno to EINTR if data is temporarily unavailable, * but we are not yet at the end of the file. */ LIBHDFS_EXTERNAL tSize hdfsRead(hdfsFS fs, hdfsFile file, void* buffer, tSize length); /** * hdfsPread - Positional read of data from an open file. * @param fs The configured filesystem handle. * @param file The file handle. * @param position Position from which to read * @param buffer The buffer to copy read bytes into. * @param length The length of the buffer. * @return See hdfsRead */ LIBHDFS_EXTERNAL tSize hdfsPread(hdfsFS fs, hdfsFile file, tOffset position, void* buffer, tSize length); /** * hdfsWrite - Write data into an open file. * @param fs The configured filesystem handle. * @param file The file handle. * @param buffer The data. * @param length The no. of bytes to write. * @return Returns the number of bytes written, -1 on error. */ LIBHDFS_EXTERNAL tSize hdfsWrite(hdfsFS fs, hdfsFile file, const void* buffer, tSize length); /** * hdfsWrite - Flush the data. * @param fs The configured filesystem handle. * @param file The file handle. 
* @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsFlush(hdfsFS fs, hdfsFile file); /** * hdfsHFlush - Flush out the data in client's user buffer. After the * return of this call, new readers will see the data. * @param fs configured filesystem handle * @param file file handle * @return 0 on success, -1 on error and sets errno */ LIBHDFS_EXTERNAL int hdfsHFlush(hdfsFS fs, hdfsFile file); /** * hdfsHSync - Similar to posix fsync, Flush out the data in client's * user buffer. all the way to the disk device (but the disk may have * it in its cache). * @param fs configured filesystem handle * @param file file handle * @return 0 on success, -1 on error and sets errno */ LIBHDFS_EXTERNAL int hdfsHSync(hdfsFS fs, hdfsFile file); /** * hdfsAvailable - Number of bytes that can be read from this * input stream without blocking. * @param fs The configured filesystem handle. * @param file The file handle. * @return Returns available bytes; -1 on error. */ LIBHDFS_EXTERNAL int hdfsAvailable(hdfsFS fs, hdfsFile file); /** * hdfsCopy - Copy file from one filesystem to another. * @param srcFS The handle to source filesystem. * @param src The path of source file. * @param dstFS The handle to destination filesystem. * @param dst The path of destination file. * @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsCopy(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst); /** * hdfsMove - Move file from one filesystem to another. * @param srcFS The handle to source filesystem. * @param src The path of source file. * @param dstFS The handle to destination filesystem. * @param dst The path of destination file. * @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsMove(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst); /** * hdfsDelete - Delete file. * @param fs The configured filesystem handle. * @param path The path of the file. * @param recursive if path is a directory and set to * non-zero, the directory is deleted else throws an exception. In * case of a file the recursive argument is irrelevant. * @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsDelete(hdfsFS fs, const char* path, int recursive); /** * hdfsRename - Rename file. * @param fs The configured filesystem handle. * @param oldPath The path of the source file. * @param newPath The path of the destination file. * @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsRename(hdfsFS fs, const char* oldPath, const char* newPath); /** * hdfsGetWorkingDirectory - Get the current working directory for * the given filesystem. * @param fs The configured filesystem handle. * @param buffer The user-buffer to copy path of cwd into. * @param bufferSize The length of user-buffer. * @return Returns buffer, NULL on error. */ LIBHDFS_EXTERNAL char* hdfsGetWorkingDirectory(hdfsFS fs, char* buffer, size_t bufferSize); /** * hdfsSetWorkingDirectory - Set the working directory. All relative * paths will be resolved relative to it. * @param fs The configured filesystem handle. * @param path The path of the new 'cwd'. * @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsSetWorkingDirectory(hdfsFS fs, const char* path); /** * hdfsCreateDirectory - Make the given file and all non-existent * parents into directories. * @param fs The configured filesystem handle. * @param path The path of the directory. * @return Returns 0 on success, -1 on error. 
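 *
 * Illustrative use, assuming `fs` was obtained via hdfsBuilderConnect and the path is a
 * made-up example (non-existent parents are created as well, similar to `mkdir -p`):
 *
 *   if (hdfsCreateDirectory(fs, "/user/oneflow/snapshots") != 0) {
 *     // errno describes the failure
 *   }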
*/ LIBHDFS_EXTERNAL int hdfsCreateDirectory(hdfsFS fs, const char* path); /** * hdfsSetReplication - Set the replication of the specified * file to the supplied value * @param fs The configured filesystem handle. * @param path The path of the file. * @return Returns 0 on success, -1 on error. */ LIBHDFS_EXTERNAL int hdfsSetReplication(hdfsFS fs, const char* path, int16_t replication); /** * hdfsFileInfo - Information about a file/directory. */ typedef struct { tObjectKind mKind; /* file or directory */ char* mName; /* the name of the file */ tTime mLastMod; /* the last modification time for the file in seconds */ tOffset mSize; /* the size of the file in bytes */ short mReplication; /* the count of replicas */ tOffset mBlockSize; /* the block size for the file */ char* mOwner; /* the owner of the file */ char* mGroup; /* the group associated with the file */ short mPermissions; /* the permissions associated with the file */ tTime mLastAccess; /* the last access time for the file in seconds */ } hdfsFileInfo; /** * hdfsListDirectory - Get list of files/directories for a given * directory-path. hdfsFreeFileInfo should be called to deallocate memory. * @param fs The configured filesystem handle. * @param path The path of the directory. * @param numEntries Set to the number of files/directories in path. * @return Returns a dynamically-allocated array of hdfsFileInfo * objects; NULL on error. */ LIBHDFS_EXTERNAL hdfsFileInfo* hdfsListDirectory(hdfsFS fs, const char* path, int* numEntries); /** * hdfsGetPathInfo - Get information about a path as a (dynamically * allocated) single hdfsFileInfo struct. hdfsFreeFileInfo should be * called when the pointer is no longer needed. * @param fs The configured filesystem handle. * @param path The path of the file. * @return Returns a dynamically-allocated hdfsFileInfo object; * NULL on error. */ LIBHDFS_EXTERNAL hdfsFileInfo* hdfsGetPathInfo(hdfsFS fs, const char* path); /** * hdfsFreeFileInfo - Free up the hdfsFileInfo array (including fields) * @param hdfsFileInfo The array of dynamically-allocated hdfsFileInfo * objects. * @param numEntries The size of the array. */ LIBHDFS_EXTERNAL void hdfsFreeFileInfo(hdfsFileInfo* hdfsFileInfo, int numEntries); /** * hdfsFileIsEncrypted: determine if a file is encrypted based on its * hdfsFileInfo. * @return -1 if there was an error (errno will be set), 0 if the file is * not encrypted, 1 if the file is encrypted. */ LIBHDFS_EXTERNAL int hdfsFileIsEncrypted(hdfsFileInfo* hdfsFileInfo); /** * hdfsGetHosts - Get hostnames where a particular block (determined by * pos & blocksize) of a file is stored. The last element in the array * is NULL. Due to replication, a single block could be present on * multiple hosts. * @param fs The configured filesystem handle. * @param path The path of the file. * @param start The start of the block. * @param length The length of the block. * @return Returns a dynamically-allocated 2-d array of blocks-hosts; * NULL on error. */ LIBHDFS_EXTERNAL char*** hdfsGetHosts(hdfsFS fs, const char* path, tOffset start, tOffset length); /** * hdfsFreeHosts - Free up the structure returned by hdfsGetHosts * @param hdfsFileInfo The array of dynamically-allocated hdfsFileInfo * objects. * @param numEntries The size of the array. */ LIBHDFS_EXTERNAL void hdfsFreeHosts(char*** blockHosts); /** * hdfsGetDefaultBlockSize - Get the default blocksize. * * @param fs The configured filesystem handle. * @deprecated Use hdfsGetDefaultBlockSizeAtPath instead. 
* * @return Returns the default blocksize, or -1 on error. */ LIBHDFS_EXTERNAL tOffset hdfsGetDefaultBlockSize(hdfsFS fs); /** * hdfsGetDefaultBlockSizeAtPath - Get the default blocksize at the * filesystem indicated by a given path. * * @param fs The configured filesystem handle. * @param path The given path will be used to locate the actual * filesystem. The full path does not have to exist. * * @return Returns the default blocksize, or -1 on error. */ LIBHDFS_EXTERNAL tOffset hdfsGetDefaultBlockSizeAtPath(hdfsFS fs, const char* path); /** * hdfsGetCapacity - Return the raw capacity of the filesystem. * @param fs The configured filesystem handle. * @return Returns the raw-capacity; -1 on error. */ LIBHDFS_EXTERNAL tOffset hdfsGetCapacity(hdfsFS fs); /** * hdfsGetUsed - Return the total raw size of all files in the filesystem. * @param fs The configured filesystem handle. * @return Returns the total-size; -1 on error. */ LIBHDFS_EXTERNAL tOffset hdfsGetUsed(hdfsFS fs); /** * Change the user and/or group of a file or directory. * * @param fs The configured filesystem handle. * @param path the path to the file or directory * @param owner User string. Set to NULL for 'no change' * @param group Group string. Set to NULL for 'no change' * @return 0 on success else -1 */ LIBHDFS_EXTERNAL int hdfsChown(hdfsFS fs, const char* path, const char* owner, const char* group); /** * hdfsChmod * @param fs The configured filesystem handle. * @param path the path to the file or directory * @param mode the bitmask to set it to * @return 0 on success else -1 */ LIBHDFS_EXTERNAL int hdfsChmod(hdfsFS fs, const char* path, short mode); /** * hdfsUtime * @param fs The configured filesystem handle. * @param path the path to the file or directory * @param mtime new modification time or -1 for no change * @param atime new access time or -1 for no change * @return 0 on success else -1 */ LIBHDFS_EXTERNAL int hdfsUtime(hdfsFS fs, const char* path, tTime mtime, tTime atime); /** * Allocate a zero-copy options structure. * * You must free all options structures allocated with this function using * hadoopRzOptionsFree. * * @return A zero-copy options structure, or NULL if one could * not be allocated. If NULL is returned, errno will * contain the error number. */ LIBHDFS_EXTERNAL struct hadoopRzOptions* hadoopRzOptionsAlloc(void); /** * Determine whether we should skip checksums in read0. * * @param opts The options structure. * @param skip Nonzero to skip checksums sometimes; zero to always * check them. * * @return 0 on success; -1 plus errno on failure. */ LIBHDFS_EXTERNAL int hadoopRzOptionsSetSkipChecksum(struct hadoopRzOptions* opts, int skip); /** * Set the ByteBufferPool to use with read0. * * @param opts The options structure. * @param className If this is NULL, we will not use any * ByteBufferPool. If this is non-NULL, it will be * treated as the name of the pool class to use. * For example, you can use * ELASTIC_BYTE_BUFFER_POOL_CLASS. * * @return 0 if the ByteBufferPool class was found and * instantiated; * -1 plus errno otherwise. */ LIBHDFS_EXTERNAL int hadoopRzOptionsSetByteBufferPool(struct hadoopRzOptions* opts, const char* className); /** * Free a hadoopRzOptionsFree structure. * * @param opts The options structure to free. * Any associated ByteBufferPool will also be freed. */ LIBHDFS_EXTERNAL void hadoopRzOptionsFree(struct hadoopRzOptions* opts); /** * Perform a byte buffer read. * If possible, this will be a zero-copy (mmap) read. * * @param file The file to read from. 
* @param opts An options structure created by hadoopRzOptionsAlloc. * @param maxLength The maximum length to read. We may read fewer bytes * than this length. * * @return On success, we will return a new hadoopRzBuffer. * This buffer will continue to be valid and readable * until it is released by readZeroBufferFree. Failure to * release a buffer will lead to a memory leak. * You can access the data within the hadoopRzBuffer with * hadoopRzBufferGet. If you have reached EOF, the data * within the hadoopRzBuffer will be NULL. You must still * free hadoopRzBuffer instances containing NULL. * * On failure, we will return NULL plus an errno code. * errno = EOPNOTSUPP indicates that we could not do a * zero-copy read, and there was no ByteBufferPool * supplied. */ LIBHDFS_EXTERNAL struct hadoopRzBuffer* hadoopReadZero(hdfsFile file, struct hadoopRzOptions* opts, int32_t maxLength); /** * Determine the length of the buffer returned from readZero. * * @param buffer a buffer returned from readZero. * @return the length of the buffer. */ LIBHDFS_EXTERNAL int32_t hadoopRzBufferLength(const struct hadoopRzBuffer* buffer); /** * Get a pointer to the raw buffer returned from readZero. * * To find out how many bytes this buffer contains, call * hadoopRzBufferLength. * * @param buffer a buffer returned from readZero. * @return a pointer to the start of the buffer. This will be * NULL when end-of-file has been reached. */ LIBHDFS_EXTERNAL const void* hadoopRzBufferGet(const struct hadoopRzBuffer* buffer); /** * Release a buffer obtained through readZero. * * @param file The hdfs stream that created this buffer. This must be * the same stream you called hadoopReadZero on. * @param buffer The buffer to release. */ LIBHDFS_EXTERNAL void hadoopRzBufferFree(hdfsFile file, struct hadoopRzBuffer* buffer); #ifdef __cplusplus } #endif #undef LIBHDFS_EXTERNAL #endif /*ONEFLOW_CORE_PERSISTENCE_HADOOP_HDFS_H_*/ /** * vim: ts=4: sw=4: et */ ================================================ FILE: oneflow/core/persistence/persistent_in_stream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/persistence/persistent_in_stream.h" #include "oneflow/core/persistence/binary_in_stream_with_local_copy.h" #include "oneflow/core/persistence/binary_in_stream_without_local_copy.h" #include "oneflow/core/job/job_set.pb.h" #include #include "oneflow/core/common/constant.h" namespace oneflow { namespace { constexpr size_t kDefaultBufferSize = 32 * 1024; // 32KB size_t GetBufferSize() { const char* buf_size_str = std::getenv("ONEFLOW_PERSISTENT_IN_STREAM_BUFFER_SIZE_BYTES"); if (buf_size_str) { int buf_size = atoi(buf_size_str); if (buf_size > 0) { return buf_size; } else { LOG(WARNING) << "invalid env ONEFLOW_PERSISTENT_IN_STREAM_BUFFER_SIZE_BYTES " << buf_size_str << ", default size " << kDefaultBufferSize << " is set"; return kDefaultBufferSize; } } return kDefaultBufferSize; } } // namespace PersistentInStream::PersistentInStream(fs::FileSystem* fs, const std::vector& file_paths, uint64_t offset, bool cyclic, bool with_local_copy) : PersistentInStream(kInvalidSessionId, fs, file_paths, offset, cyclic, with_local_copy) {} PersistentInStream::PersistentInStream(int64_t session_id, fs::FileSystem* fs, const std::vector& file_paths, uint64_t offset, bool cyclic, bool with_local_copy) { if (with_local_copy) { CHECK_EQ(offset, 0); } std::vector> streams; for (auto& file_path : file_paths) { if (with_local_copy) { streams.emplace_back(new BinaryInStreamWithLocalCopy(fs, file_path)); } else { streams.emplace_back(new BinaryInStreamWithoutLocalCopy(fs, file_path)); } } if (cyclic) { stream_scanner_.reset(new CyclicStreamScanner(fs, streams, offset)); } else { stream_scanner_.reset(new AcyclicStreamScanner(fs, streams, offset)); } buffer_.resize(GetBufferSize() + 1); cur_buf_begin_ = buffer_.data(); cur_buf_end_ = buffer_.data(); *cur_buf_end_ = '\0'; } PersistentInStream::PersistentInStream(fs::FileSystem* fs, const std::vector& file_paths, bool cyclic, bool with_local_copy) : PersistentInStream(fs, file_paths, 0, cyclic, with_local_copy) {} PersistentInStream::PersistentInStream(fs::FileSystem* fs, const std::string& file_path, uint64_t offset, bool cyclic, bool with_local_copy) : PersistentInStream(fs, std::vector({file_path}), offset, cyclic, with_local_copy) {} PersistentInStream::PersistentInStream(fs::FileSystem* fs, const std::string& file_path, uint64_t offset) : PersistentInStream(fs, file_path, offset, false, false) {} PersistentInStream::PersistentInStream(fs::FileSystem* fs, const std::string& file_path) : PersistentInStream(fs, file_path, 0, false, false) {} PersistentInStream::PersistentInStream(int64_t session_id, fs::FileSystem* fs, const std::string& file_path) : PersistentInStream(session_id, fs, std::vector({file_path}), 0, false, false) {} int32_t PersistentInStream::ReadLine(std::string* l) { if (IsEof()) { return -1; } l->clear(); while (*cur_buf_begin_ != '\n') { if (cur_buf_begin_ == cur_buf_end_) { UpdateBuffer(); if (cur_buf_begin_ == cur_buf_end_) { return 0; } else { continue; } } l->push_back(*cur_buf_begin_++); } ++cur_buf_begin_; return 0; } int32_t PersistentInStream::ReadFully(char* s, size_t n) { if (IsEof()) { return -1; } while (n) { if (cur_buf_begin_ == cur_buf_end_) { UpdateBuffer(); } CHECK_LT(cur_buf_begin_, cur_buf_end_); int64_t copy_size = std::min(cur_buf_end_ - cur_buf_begin_, n); std::memcpy(s, cur_buf_begin_, static_cast(copy_size)); s += copy_size; cur_buf_begin_ += copy_size; n -= copy_size; } return 0; } void PersistentInStream::UpdateBuffer() { CHECK_EQ(cur_buf_begin_, cur_buf_end_); uint64_t n = 
stream_scanner_->UpdateBuffer(&buffer_); cur_buf_begin_ = buffer_.data(); cur_buf_end_ = buffer_.data() + n; *cur_buf_end_ = '\0'; } bool PersistentInStream::IsEof() const { return cur_buf_begin_ == cur_buf_end_ && stream_scanner_->IsEof(); } } // namespace oneflow ================================================ FILE: oneflow/core/persistence/persistent_in_stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PERSISTENCE_PERSISTENT_IN_STREAM_H_ #define ONEFLOW_CORE_PERSISTENCE_PERSISTENT_IN_STREAM_H_ #include "oneflow/core/persistence/file_system.h" #include "oneflow/core/persistence/stream_scanner.h" namespace oneflow { class PersistentInStream { public: OF_DISALLOW_COPY_AND_MOVE(PersistentInStream); virtual ~PersistentInStream() {} PersistentInStream(fs::FileSystem* fs, const std::vector& file_paths, uint64_t offset, bool cyclic, bool with_local_copy); PersistentInStream(fs::FileSystem* fs, const std::vector& file_paths, bool cyclic, bool with_local_copy); PersistentInStream(fs::FileSystem* fs, const std::string& file_path, uint64_t offset, bool cyclic, bool with_local_copy); PersistentInStream(fs::FileSystem* fs, const std::string& file_path, uint64_t offset); PersistentInStream(fs::FileSystem* fs, const std::string& file_path); PersistentInStream(int64_t session_id, fs::FileSystem* fs, const std::string& file_path); PersistentInStream(int64_t session_id, fs::FileSystem* fs, const std::vector& file_paths, uint64_t offset, bool cyclic, bool with_local_copy); // 0: success // -1: eof int32_t ReadLine(std::string* l); int32_t ReadFully(char* s, size_t n); private: bool IsEof() const; void UpdateBuffer(); std::unique_ptr stream_scanner_; std::vector buffer_; char* cur_buf_begin_; char* cur_buf_end_; }; } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_PERSISTENT_IN_STREAM_H_ ================================================ FILE: oneflow/core/persistence/persistent_out_stream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/persistence/persistent_out_stream.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" namespace oneflow { PersistentOutStream::PersistentOutStream(fs::FileSystem* fs, const std::string& file_path) { std::string file_dir = Dirname(file_path); OfCallOnce(GlobalProcessCtx::LogDirEntry() + "/" + file_dir, fs, &fs::FileSystem::RecursivelyCreateDirIfNotExist, file_dir); fs->NewWritableFile(file_path, &file_); } PersistentOutStream::~PersistentOutStream() { file_->Close(); } PersistentOutStream& PersistentOutStream::Write(const char* s, size_t n) { file_->Append(s, n); return *this; } void PersistentOutStream::Flush() { file_->Flush(); } } // namespace oneflow ================================================ FILE: oneflow/core/persistence/persistent_out_stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PERSISTENCE_PERSISTENT_OUT_STREAM_H_ #define ONEFLOW_CORE_PERSISTENCE_PERSISTENT_OUT_STREAM_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/persistence/file_system.h" namespace oneflow { class PersistentOutStream final { public: OF_DISALLOW_COPY_AND_MOVE(PersistentOutStream); PersistentOutStream() = delete; ~PersistentOutStream(); PersistentOutStream(fs::FileSystem*, const std::string& file_path); // Write block of data // Inserts the first n characters of the array pointed by s into the stream. PersistentOutStream& Write(const char* s, size_t n); void Flush(); private: std::unique_ptr file_; }; template typename std::enable_if::value, PersistentOutStream&>::type operator<<( PersistentOutStream& out_stream, const T& x) { const char* x_ptr = reinterpret_cast(&x); size_t n = sizeof(x); out_stream.Write(x_ptr, n); return out_stream; } inline PersistentOutStream& operator<<(PersistentOutStream& out_stream, const std::string& s) { out_stream.Write(s.c_str(), s.size()); return out_stream; } template PersistentOutStream& operator<<(PersistentOutStream& out_stream, const char (&s)[n]) { out_stream.Write(s, strlen(s)); return out_stream; } } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCY_PERSISTENT_OUT_STREAM_H ================================================ FILE: oneflow/core/persistence/posix/posix_file_system.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/persistence/posix/posix_file_system.h" #ifdef OF_PLATFORM_POSIX #include #include #include #include #include #include #include #include #include #include namespace oneflow { namespace fs { class PosixRandomAccessFile : public RandomAccessFile { private: std::string fname_; int fd_; public: PosixRandomAccessFile(const std::string& fname, int fd) : fname_(fname), fd_(fd) {} ~PosixRandomAccessFile() override { close(fd_); } void Read(uint64_t offset, size_t n, char* result) const override { char* dst = result; while (n > 0) { ssize_t r = pread(fd_, dst, n, static_cast(offset)); if (r > 0) { dst += r; n -= r; offset += r; } else if (r == 0) { PLOG(FATAL) << "Read EOF"; return; } else if (errno == EINTR || errno == EAGAIN) { // Retry } else { PLOG(FATAL) << "Fail to read file " << fname_; return; } } } }; class PosixWritableFile : public WritableFile { private: std::string fname_; FILE* file_; public: PosixWritableFile(const std::string& fname, FILE* file) : fname_(fname), file_(file) {} ~PosixWritableFile() override { if (file_ != nullptr) { fclose(file_); } } void Append(const char* data, size_t n) override { PCHECK(fwrite(data, sizeof(char), n, file_) == n) << "Fail to append to file " << fname_ << ", errno is " << errno; } void Close() override { Flush(); PCHECK(fclose(file_) == 0) << "Fail to close file " << fname_ << ", errno is " << errno; file_ = nullptr; } void Flush() override { PCHECK(fflush(file_) == 0) << "Fail to flush file " << fname_ << ", errno is " << errno; } }; void PosixFileSystem::NewRandomAccessFile(const std::string& fname, std::unique_ptr* result) { std::string translated_fname = TranslateName(fname); int fd = open(translated_fname.c_str(), O_RDONLY); PCHECK(fd >= 0) << "Fail to open file " << fname << ", errno is " << errno; result->reset(new PosixRandomAccessFile(fname, fd)); CHECK_NOTNULL(result->get()); } void PosixFileSystem::NewWritableFile(const std::string& fname, std::unique_ptr* result) { std::string translated_fname = TranslateName(fname); FILE* f = fopen(translated_fname.c_str(), "w"); PCHECK(f != nullptr) << "Fail to open file " << fname << ", errno is " << errno; result->reset(new PosixWritableFile(translated_fname, f)); CHECK_NOTNULL(result->get()); } void PosixFileSystem::NewAppendableFile(const std::string& fname, std::unique_ptr* result) { std::string translated_name = TranslateName(fname); FILE* f = fopen(translated_name.c_str(), "a"); PCHECK(f != nullptr) << "Fail to open file " << fname << ", errno is " << errno; result->reset(new PosixWritableFile(translated_name, f)); CHECK_NOTNULL(result->get()); } bool PosixFileSystem::FileExists(const std::string& fname) { if (access(TranslateName(fname).c_str(), F_OK) == 0) { return true; } return false; } std::vector PosixFileSystem::ListDir(const std::string& dir) { std::string translated_dir = TranslateName(dir); std::vector result; DIR* d = opendir(translated_dir.c_str()); PCHECK(d != nullptr) << "Fail to open dir " << dir << ", errno is " << errno; struct dirent* entry; while ((entry = readdir(d)) != nullptr) { if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) { continue; } result.emplace_back(entry->d_name); } closedir(d); return result; } void PosixFileSystem::DelFile(const std::string& fname) { PCHECK(unlink(TranslateName(fname).c_str()) == 0) << "Fail to delete file " << fname << ", errno is " << errno; } void PosixFileSystem::CreateDir(const std::string& dirname) { PCHECK(mkdir(TranslateName(dirname).c_str(), 0755) == 0) << "Fail to create dir " << 
dirname << ", errno is " << errno; } void PosixFileSystem::CreateDirIfNotExist(const std::string& dirname) { int ret = mkdir(TranslateName(dirname).c_str(), 0755); PCHECK(ret == 0 || (errno == EEXIST && IsDirectory(dirname))) << "Fail to create dir " << dirname << ", errno is " << errno; } void PosixFileSystem::DeleteDir(const std::string& dirname) { PCHECK(rmdir(TranslateName(dirname).c_str()) == 0) << "Fail to delete dir " << dirname << ", errno is " << errno; } uint64_t PosixFileSystem::GetFileSize(const std::string& fname) { struct stat sbuf; PCHECK(stat(TranslateName(fname).c_str(), &sbuf) == 0) << "Fail to load statistics of " << fname << ", errno is " << errno; return sbuf.st_size; } void PosixFileSystem::RenameFile(const std::string& old_name, const std::string& new_name) { PCHECK(rename(TranslateName(old_name).c_str(), TranslateName(new_name).c_str()) == 0) << "Fail to rename file from " << old_name << " to " << new_name << ", errno is " << errno; } bool PosixFileSystem::IsDirectory(const std::string& fname) { struct stat sbuf; if (stat(TranslateName(fname).c_str(), &sbuf) == 0 && S_ISDIR(sbuf.st_mode)) { return true; } return false; } } // namespace fs } // namespace oneflow #endif // OF_PLATFORM_POSIX ================================================ FILE: oneflow/core/persistence/posix/posix_file_system.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PERSISTENCE_POSIX_POSIX_FILE_SYSTEM_H_ #define ONEFLOW_CORE_PERSISTENCE_POSIX_POSIX_FILE_SYSTEM_H_ #include "oneflow/core/persistence/file_system.h" #ifdef OF_PLATFORM_POSIX namespace oneflow { namespace fs { class PosixFileSystem final : public FileSystem { public: OF_DISALLOW_COPY_AND_MOVE(PosixFileSystem); PosixFileSystem() = default; ~PosixFileSystem() = default; void NewRandomAccessFile(const std::string& fname, std::unique_ptr* result) override; void NewWritableFile(const std::string& fname, std::unique_ptr* result) override; void NewAppendableFile(const std::string& fname, std::unique_ptr* result) override; bool FileExists(const std::string& fname) override; std::vector ListDir(const std::string& dir) override; void DelFile(const std::string& fname) override; void CreateDir(const std::string& dirname) override; void CreateDirIfNotExist(const std::string& dirname) override; void DeleteDir(const std::string& dirname) override; uint64_t GetFileSize(const std::string& fname) override; void RenameFile(const std::string& old_name, const std::string& new_name) override; bool IsDirectory(const std::string& fname) override; private: }; } // namespace fs } // namespace oneflow #endif // OF_PLATFORM_POSIX #endif // ONEFLOW_CORE_PERSISTENCE_POSIX_POSIX_FILE_SYSTEM_H_ ================================================ FILE: oneflow/core/persistence/stream_scanner.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
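// Illustrative usage sketch (not part of the repository sources): a small walk through the
// PosixFileSystem API shown above, assuming a POSIX build (OF_PLATFORM_POSIX) and that /tmp
// is writable. All failures abort via the PCHECK/CHECK macros inside the implementation.
namespace oneflow {

inline void PosixFileSystemExample() {
  fs::PosixFileSystem posix_fs;
  posix_fs.CreateDirIfNotExist("/tmp/oneflow_demo");

  std::unique_ptr<fs::WritableFile> file;
  posix_fs.NewWritableFile("/tmp/oneflow_demo/hello.txt", &file);
  const std::string msg = "hello oneflow\n";
  file->Append(msg.data(), msg.size());
  file->Close();

  CHECK(posix_fs.FileExists("/tmp/oneflow_demo/hello.txt"));
  LOG(INFO) << "size: " << posix_fs.GetFileSize("/tmp/oneflow_demo/hello.txt");
  for (const std::string& name : posix_fs.ListDir("/tmp/oneflow_demo")) { LOG(INFO) << name; }
}

}  // namespace oneflow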
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/persistence/stream_scanner.h" #include "oneflow/core/persistence/binary_in_stream_without_local_copy.h" #include "oneflow/core/persistence/binary_in_stream_with_local_copy.h" namespace oneflow { StreamScanner::StreamScanner(fs::FileSystem* fs, const std::vector>& streams, uint64_t offset) : whole_file_offset_(offset) { stream_num_ = streams.size(); whole_file_size_ = 0; int64_t idx = 0; for (auto& stream : streams) { AddStream(fs, stream, idx); ++idx; } CHECK_LE(whole_file_offset_, whole_file_size_); whole_file_pos_ = whole_file_offset_; } void StreamScanner::AddStream(fs::FileSystem* fs, const std::shared_ptr& stream, int64_t idx) { uint64_t cur_file_size = stream->file_size(); if (whole_file_offset_ < whole_file_size_) { stream->set_cur_file_pos(0); } else if (whole_file_size_ <= whole_file_offset_ && whole_file_offset_ < whole_file_size_ + cur_file_size) { stream->set_cur_file_pos(whole_file_offset_ - whole_file_size_); cur_stream_id_ = idx; } else if (whole_file_offset_ >= whole_file_size_ + cur_file_size) { stream->set_cur_file_pos(0); // works for both cyclic and acyclic cases } streams_.emplace_back(stream); whole_file_size_ += cur_file_size; } bool StreamScanner::IsEof() const { return whole_file_pos_ == whole_file_size_; } uint64_t StreamScanner::UpdateBuffer(std::vector* buffer) { if (cur_stream_id_ == stream_num_) return 0; uint64_t n = std::min(buffer->size() - 1, streams_[cur_stream_id_]->file_size() - streams_[cur_stream_id_]->cur_file_pos()); if (n == 0) { return 0; } streams_[cur_stream_id_]->Read(buffer->data(), n); AddNForCurFilePos(n); return n; } void AcyclicStreamScanner::AddNForCurFilePos(uint64_t n) { whole_file_pos_ += n; if (streams_[cur_stream_id_]->IsEof()) { ++cur_stream_id_; } } void CyclicStreamScanner::AddNForCurFilePos(uint64_t n) { whole_file_pos_ = (whole_file_pos_ + n) % whole_file_size_; if (streams_[cur_stream_id_]->IsEof()) { streams_[cur_stream_id_]->set_cur_file_pos(0); ++cur_stream_id_; if (cur_stream_id_ == stream_num_) { CHECK_EQ(whole_file_pos_, 0); cur_stream_id_ = 0; } } } } // namespace oneflow ================================================ FILE: oneflow/core/persistence/stream_scanner.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_PERSISTENCE_STREAM_SCANNER_H_ #define ONEFLOW_CORE_PERSISTENCE_STREAM_SCANNER_H_ #include #include #include "oneflow/core/persistence/binary_in_stream.h" #include "oneflow/core/persistence/file_system.h" namespace oneflow { class StreamScanner { public: OF_DISALLOW_COPY_AND_MOVE(StreamScanner); virtual ~StreamScanner() {} StreamScanner(fs::FileSystem* fs, const std::vector>& streams, uint64_t offset); bool IsEof() const; uint64_t UpdateBuffer(std::vector* buffer); protected: virtual void AddNForCurFilePos(uint64_t n) = 0; std::vector> streams_; uint64_t whole_file_size_; uint64_t whole_file_pos_; int32_t cur_stream_id_; int32_t stream_num_; uint64_t whole_file_offset_; private: void AddStream(fs::FileSystem* fs, const std::shared_ptr& stream, int64_t idx); }; class CyclicStreamScanner final : public StreamScanner { public: OF_DISALLOW_COPY_AND_MOVE(CyclicStreamScanner); CyclicStreamScanner(fs::FileSystem* fs, const std::vector>& streams, uint64_t offset) : StreamScanner(fs, streams, offset) {} protected: void AddNForCurFilePos(uint64_t n) override; }; class AcyclicStreamScanner final : public StreamScanner { public: OF_DISALLOW_COPY_AND_MOVE(AcyclicStreamScanner); AcyclicStreamScanner(fs::FileSystem* fs, const std::vector>& streams, uint64_t offset) : StreamScanner(fs, streams, offset) {} protected: void AddNForCurFilePos(uint64_t n) override; }; } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_STREAM_SCANNER_H_ ================================================ FILE: oneflow/core/persistence/tee_persistent_log_stream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
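// Illustrative usage sketch (not part of the repository sources): how a caller typically drives
// the scanners declared above. The caller owns a byte buffer and calls UpdateBuffer() until
// IsEof(); construction of the BinaryInStream objects is elided and assumed to be done with the
// BinaryInStreamWithoutLocalCopy / WithLocalCopy helpers included by stream_scanner.cpp.
namespace oneflow {

inline void StreamScannerExample(fs::FileSystem* fs,
                                 const std::vector<std::shared_ptr<BinaryInStream>>& streams) {
  AcyclicStreamScanner scanner(fs, streams, /*offset=*/0);
  std::vector<char> buffer(64 * 1024);  // UpdateBuffer reads at most buffer.size() - 1 bytes
  while (!scanner.IsEof()) {
    uint64_t n = scanner.UpdateBuffer(&buffer);
    if (n == 0) { break; }
    // consume buffer[0, n) here
  }
}

}  // namespace oneflow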
*/ #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/common/str_util.h" #include namespace oneflow { TeePersistentLogStream::TeePersistentLogStream(const std::string& path) { destinations_.emplace_back(LocalFS(), FLAGS_log_dir); branches_.reserve(destinations_.size()); for (const auto& destination : destinations_) { branches_.emplace_back(std::make_unique( destination.mut_file_system(), JoinPath(destination.base_dir(), path))); } } TeePersistentLogStream::~TeePersistentLogStream() { Flush(); } std::unique_ptr TeePersistentLogStream::Create(const std::string& path) { auto stream_ptr = new TeePersistentLogStream(path); return std::unique_ptr(stream_ptr); } void TeePersistentLogStream::Flush() { for (const auto& branch : branches_) { branch->Flush(); } }; void TeePersistentLogStream::Write(const char* s, size_t n) { for (const auto& branch : branches_) { branch->Write(s, n); } }; void TeePersistentLogStream::Write(const std::string& str) { this->Write(str.data(), str.size()); } void TeePersistentLogStream::Write(const PbMessage& proto) { std::string output; google::protobuf::TextFormat::PrintToString(proto, &output); this->Write(output); } } // namespace oneflow ================================================ FILE: oneflow/core/persistence/tee_persistent_log_stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_ #define ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_ #include "oneflow/core/common/protobuf.h" #include "oneflow/core/persistence/persistent_out_stream.h" namespace oneflow { class LogStreamDestination final { public: LogStreamDestination(fs::FileSystem* file_system, const std::string& base_dir) : file_system_(file_system), base_dir_(base_dir) {} ~LogStreamDestination() = default; fs::FileSystem* mut_file_system() const { return file_system_; }; const std::string& base_dir() const { return base_dir_; }; private: fs::FileSystem* file_system_; std::string base_dir_; }; class TeePersistentLogStream final { public: OF_DISALLOW_COPY_AND_MOVE(TeePersistentLogStream); ~TeePersistentLogStream(); void Write(const char* s, size_t n); void Write(const std::string& str); void Write(const PbMessage& proto); static std::unique_ptr Create(const std::string& path); void Flush(); private: explicit TeePersistentLogStream(const std::string& path); std::vector destinations_; std::vector> branches_; }; inline TeePersistentLogStream& operator<<(TeePersistentLogStream& log_stream, const std::string& s) { log_stream.Write(s.c_str(), s.size()); return log_stream; } inline std::unique_ptr& operator<<( std::unique_ptr& log_stream, const std::string& s) { log_stream->Write(s.c_str(), s.size()); return log_stream; } } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_ ================================================ FILE: oneflow/core/platform/include/ibv.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
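// Illustrative usage sketch (not part of the repository sources): TeePersistentLogStream fans
// writes out to every destination, currently the local file system rooted at FLAGS_log_dir.
// This sketch assumes glog's log_dir flag has been set and that a protobuf message is in scope.
namespace oneflow {

inline void TeeLogStreamExample(const PbMessage& job_conf) {
  auto log_stream = TeePersistentLogStream::Create("job/demo_job_conf.prototxt");
  log_stream->Write(job_conf);          // serialized with protobuf TextFormat
  log_stream << std::string("done\n");  // operator<< on the unique_ptr also works
  log_stream->Flush();                  // flushed again by the destructor
}

}  // namespace oneflow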
*/ #if defined(WITH_RDMA) #ifndef ONEFLOW_CORE_PLATFORM_INCLUDE_IBV_H_ #define ONEFLOW_CORE_PLATFORM_INCLUDE_IBV_H_ #include "oneflow/core/platform/include/wrapper.h" #include namespace oneflow { namespace ibv { // has to add extern otherwise it fails to compile at changes meaning of functions extern "C" typedef struct IBV { #define IBV_APIS(_) \ _(ibv_free_device_list) \ _(ibv_destroy_qp) \ _(ibv_query_gid) \ _(ibv_fork_init) \ _(ibv_open_device) \ _(ibv_destroy_cq) \ _(ibv_alloc_pd) \ _(ibv_modify_qp) \ _(ibv_dealloc_pd) \ _(ibv_get_device_list) \ _(ibv_close_device) \ _(ibv_create_qp) \ _(ibv_dereg_mr) \ _(ibv_create_cq) \ _(ibv_query_device) \ _(ibv_get_device_name) #define DECLARE_ONE(name) decltype(&name) name; IBV_APIS(DECLARE_ONE) #undef DECLARE_ONE // for a function is not only a function but also a macro, // it requires an alternative name struct ibv_mr* (*ibv_reg_mr_wrap)(struct ibv_pd* pd, void* addr, size_t length, int access); int (*ibv_query_port_wrap)(struct ibv_context* context, uint8_t port_num, struct ibv_port_attr* port_attr); } IBV; bool IsAvailable(); extern IBV wrapper; } // namespace ibv } // namespace oneflow #endif // ONEFLOW_CORE_PLATFORM_INCLUDE_IBV_H_ #endif // WITH_RDMA ================================================ FILE: oneflow/core/platform/include/pthread_fork.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PLATFORM_INCLUDE_PTHREAD_FORK_H_ #define ONEFLOW_CORE_PLATFORM_INCLUDE_PTHREAD_FORK_H_ namespace oneflow { namespace pthread_fork { bool IsForkedSubProcess(); extern const char* kOfCudaNotSupportInForkedSubProcess; } // namespace pthread_fork } // namespace oneflow #endif // ONEFLOW_CORE_PLATFORM_INCLUDE_PTHREAD_FORK_H_ ================================================ FILE: oneflow/core/platform/include/wrapper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
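// Illustrative usage sketch (not part of the repository sources): the IBV struct above is a
// table of lazily loaded libibverbs entry points. Callers are expected to test IsAvailable()
// first (it only checks that libibverbs could be dlopen'ed) and then go through the global
// `wrapper` instead of calling libibverbs directly. Requires a WITH_RDMA build.
namespace oneflow {

inline void RdmaProbeExample() {
  if (!ibv::IsAvailable()) { return; }  // libibverbs.so not found; skip RDMA
  int num_devices = 0;
  ibv_device** devices = ibv::wrapper.ibv_get_device_list(&num_devices);
  for (int i = 0; i < num_devices; ++i) {
    LOG(INFO) << "RDMA device: " << ibv::wrapper.ibv_get_device_name(devices[i]);
  }
  ibv::wrapper.ibv_free_device_list(devices);
}

}  // namespace oneflow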
*/ #ifndef ONEFLOW_CORE_PLATFORM_INCLUDE_WRAPPER_H_ #define ONEFLOW_CORE_PLATFORM_INCLUDE_WRAPPER_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace platform { class DynamicLibrary { public: OF_DISALLOW_COPY_AND_MOVE(DynamicLibrary); ~DynamicLibrary(); static std::unique_ptr Load(const std::vector& names); void* LoadSym(const char* name); #ifdef __linux__ std::string AbsolutePath(); #endif // __linux__ private: DynamicLibrary(void* handle) : handle_(handle){}; void* handle_ = nullptr; }; } // namespace platform } // namespace oneflow #endif // ONEFLOW_CORE_PLATFORM_INCLUDE_WRAPPER_H_ ================================================ FILE: oneflow/core/platform/lib/ibv_wrapper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #if defined(WITH_RDMA) #include "oneflow/core/platform/include/ibv.h" namespace oneflow { namespace ibv { std::vector GetLibPaths() { const char* custom_path = std::getenv("ONEFLOW_LIBIBVERBS_PATH"); if (custom_path == nullptr) { return {"libibverbs.so.1", "libibverbs.so"}; } else { return {custom_path}; } } platform::DynamicLibrary* GetIBVLibraryPtr() { static std::unique_ptr lib = platform::DynamicLibrary::Load(GetLibPaths()); return lib.get(); } platform::DynamicLibrary& GetIBVLibrary() { platform::DynamicLibrary* lib = GetIBVLibraryPtr(); CHECK(lib != nullptr) << "fail to find libibverbs"; return *lib; } template FUNC LoadSymbol(const char* name, FUNC* save) { auto fn = reinterpret_cast(GetIBVLibrary().LoadSym(name)); if (!fn) { std::cerr << "Can't load libibverbs symbol " << name << "\n"; abort(); }; *save = fn; return fn; } bool IsAvailable() { return GetIBVLibraryPtr() != nullptr; } namespace _stubs { void ibv_free_device_list(struct ibv_device** list) { return LoadSymbol(__func__, &wrapper.ibv_free_device_list)(list); } struct ibv_mr* ibv_reg_mr_wrap(struct ibv_pd* pd, void* addr, size_t length, int access) { return LoadSymbol("ibv_reg_mr", &wrapper.ibv_reg_mr_wrap)(pd, addr, length, access); } int ibv_destroy_qp(struct ibv_qp* qp) { return LoadSymbol(__func__, &wrapper.ibv_destroy_qp)(qp); } int ibv_query_gid(struct ibv_context* context, uint8_t port_num, int index, union ibv_gid* gid) { return LoadSymbol(__func__, &wrapper.ibv_query_gid)(context, port_num, index, gid); } int ibv_fork_init(void) { return LoadSymbol(__func__, &wrapper.ibv_fork_init)(); } int ibv_query_port_wrap(struct ibv_context* context, uint8_t port_num, struct ibv_port_attr* port_attr) { return LoadSymbol("ibv_query_port", &wrapper.ibv_query_port_wrap)(context, port_num, port_attr); } struct ibv_context* ibv_open_device(struct ibv_device* device) { return LoadSymbol(__func__, &wrapper.ibv_open_device)(device); } int ibv_destroy_cq(struct ibv_cq* cq) { return LoadSymbol(__func__, &wrapper.ibv_destroy_cq)(cq); } struct ibv_pd* ibv_alloc_pd(struct ibv_context* context) { return LoadSymbol(__func__, &wrapper.ibv_alloc_pd)(context); } int ibv_modify_qp(struct ibv_qp* qp, struct ibv_qp_attr* attr, int 
attr_mask) { return LoadSymbol(__func__, &wrapper.ibv_modify_qp)(qp, attr, attr_mask); } int ibv_dealloc_pd(struct ibv_pd* pd) { return LoadSymbol(__func__, &wrapper.ibv_dealloc_pd)(pd); } struct ibv_device** ibv_get_device_list(int* num_devices) { return LoadSymbol(__func__, &wrapper.ibv_get_device_list)(num_devices); } int ibv_close_device(struct ibv_context* context) { return LoadSymbol(__func__, &wrapper.ibv_close_device)(context); } struct ibv_qp* ibv_create_qp(struct ibv_pd* pd, struct ibv_qp_init_attr* qp_init_attr) { return LoadSymbol(__func__, &wrapper.ibv_create_qp)(pd, qp_init_attr); } int ibv_dereg_mr(struct ibv_mr* mr) { return LoadSymbol(__func__, &wrapper.ibv_dereg_mr)(mr); } struct ibv_cq* ibv_create_cq(struct ibv_context* context, int cqe, void* cq_context, struct ibv_comp_channel* channel, int comp_vector) { return LoadSymbol(__func__, &wrapper.ibv_create_cq)(context, cqe, cq_context, channel, comp_vector); } int ibv_query_device(struct ibv_context* context, struct ibv_device_attr* device_attr) { return LoadSymbol(__func__, &wrapper.ibv_query_device)(context, device_attr); } const char* ibv_get_device_name(struct ibv_device* device) { return LoadSymbol(__func__, &wrapper.ibv_get_device_name)(device); } } // namespace _stubs IBV wrapper = { #define _REFERENCE_MEMBER(name) _stubs::name, IBV_APIS(_REFERENCE_MEMBER) #undef _REFERENCE_MEMBER _stubs::ibv_reg_mr_wrap, _stubs::ibv_query_port_wrap}; } // namespace ibv } // namespace oneflow #endif // WITH_RDMA ================================================ FILE: oneflow/core/platform/lib/pthread_fork.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/platform/include/pthread_fork.h" #include "oneflow/core/common/util.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/sync_vm_mode_guard.h" namespace oneflow { namespace pthread_fork { static bool is_fork = false; bool IsForkedSubProcess() { return is_fork; } static void SetIsForkedSubProcess() { is_fork = true; } namespace { void CurrentRankVmSync() { if (SyncVmModeGuard::IsCurrentSyncVmMode()) { return; } // Instructions in forked subprocesses are not dispatched to vm, // so no need to sync vm in these processes. if (!is_fork && Singleton::Get() != nullptr) { CHECK_JUST(vm::CurrentRankSync()); } } } // namespace void RegisterForkCallback() { pthread_atfork(&CurrentRankVmSync, nullptr, &SetIsForkedSubProcess); } COMMAND(RegisterForkCallback()); const char* kOfCudaNotSupportInForkedSubProcess = "Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you " "must add 'multiprocessing.set_start_method(\"spawn\")' in '__main__' if you are using " "Python's multiprocessing"; } // namespace pthread_fork } // namespace oneflow ================================================ FILE: oneflow/core/platform/lib/wrapper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/platform/include/wrapper.h" #include #ifdef __linux__ #include #endif // __linux__ namespace oneflow { namespace platform { namespace { void* OpenSymbol(void* handle, const char* name) { void* ret = dlsym(handle, name); if (!ret) { std::cerr << "Error in dlopen or dlsym: " << dlerror() << "\n"; abort(); } return ret; } } // namespace // original implementation is from pytorch: // https://github.com/pytorch/pytorch/blob/259d19a7335b32c4a27a018034551ca6ae997f6b/aten/src/ATen/DynamicLibrary.cpp std::unique_ptr DynamicLibrary::Load(const std::vector& names) { for (const std::string& name : names) { void* handle = dlopen(name.c_str(), RTLD_LOCAL | RTLD_NOW); if (handle != nullptr) { DynamicLibrary* lib = new DynamicLibrary(handle); return std::unique_ptr(lib); } } return std::unique_ptr(); } void* DynamicLibrary::LoadSym(const char* name) { return OpenSymbol(handle_, name); } #ifdef __linux__ std::string DynamicLibrary::AbsolutePath() { struct link_map* map; dlinfo(handle_, RTLD_DI_LINKMAP, &map); return map->l_name; } #endif // __linux__ DynamicLibrary::~DynamicLibrary() { dlclose(handle_); } } // namespace platform } // namespace oneflow ================================================ FILE: oneflow/core/profiler/event.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
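// Illustrative usage sketch (not part of the repository sources): DynamicLibrary is a thin RAII
// wrapper over dlopen/dlsym. The sketch loads libm and resolves cos(); the library names are
// ordinary examples, not something the repository depends on. Note that LoadSym aborts the
// process if the symbol cannot be resolved.
namespace oneflow {

inline void DynamicLibraryExample() {
  auto lib = platform::DynamicLibrary::Load({"libm.so.6", "libm.so"});
  if (!lib) { return; }  // none of the candidate names could be dlopen'ed
  using CosFn = double (*)(double);
  auto cos_fn = reinterpret_cast<CosFn>(lib->LoadSym("cos"));
  LOG(INFO) << "cos(0) = " << cos_fn(0.0);
}  // the DynamicLibrary destructor dlclose()s the handle here

}  // namespace oneflow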
*/ #include "fmt/core.h" #include "fmt/format.h" #include "oneflow/core/profiler/event.h" #include "oneflow/core/profiler/util.h" using json = nlohmann::json; namespace oneflow { namespace profiler { nlohmann::json IEvent::ToJson() { return json{{"name", name_}, {"time", GetDuration()}}; } void IEvent::SetStartedAt(double t) { started_at_ = t; } void IEvent::SetFinishedAt(double t) { finished_at_ = t; } void IEvent::Start() { SetStartedAt(GetTimeNow()); } void IEvent::Finish() { SetFinishedAt(GetTimeNow()); } bool IEvent::IsChildOf(const IEvent* e) { if (!e) { return false; } if (this == e) { return false; } return GetStartedAt() >= e->GetStartedAt() && GetFinishedAt() <= e->GetFinishedAt(); } const std::string& IEvent::GetName() const { return name_; } nlohmann::json CustomEvent::ToJson() { auto j = IEvent::ToJson(); j["type"] = EventType::kCustom; j["custom_type"] = type_; return j; } std::shared_ptr CustomEvent::Create(const std::string& name, CustomEventType type) { return std::shared_ptr(new CustomEvent(name, type)); } nlohmann::json KernelEvent::ToJson() { auto j = IEvent::ToJson(); j["type"] = EventType::kOneflowKernel; for (const auto& desc : description_) { j["description"][desc.first] = {desc.second.first, desc.second.second}; } #if defined(WITH_CUDA) j["memory_size"] = memory_size_; if (!children_.empty()) { j["children"] = children_; } #endif // WITH_CUDA return j; } std::shared_ptr KernelEvent::Create(const std::string& name, const Description& description) { return std::shared_ptr(new KernelEvent(name, description)); } } // namespace profiler } // namespace oneflow ================================================ FILE: oneflow/core/profiler/event.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
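// Illustrative usage sketch (not part of the repository sources): events can also be driven by
// hand, which is a handy way to see the JSON that ProfileManager::DumpResultsJson() eventually
// emits. Timestamps come from GetTimeNow() in oneflow/core/profiler/util.h.
namespace oneflow {
namespace profiler {

inline void CustomEventExample() {
  std::shared_ptr<CustomEvent> evt = CustomEvent::Create("my_region");
  evt->Start();
  // ... region being measured ...
  evt->Finish();
  LOG(INFO) << evt->ToJson().dump();  // e.g. {"name":"my_region","time":...,"type":...}
}

}  // namespace profiler
}  // namespace oneflow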
*/ #ifndef ONEFLOW_CORE_PROFILER_EVENT_H_ #define ONEFLOW_CORE_PROFILER_EVENT_H_ #include #include #include #include "nlohmann/json.hpp" #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape_view.h" namespace oneflow { namespace profiler { class ProfileManager; enum class EventType { kCustom, // has three kinds kOneflowKernel // OneFlow cpu/cuda kernel }; enum class CustomEventType { kDefault, // for record_function kCudaKernel, // cuda kernel kCudaRuntime // something like cudaLaunchKernel }; enum class EventTimeUnit { kNS, kUS }; class IEvent { public: OF_DISALLOW_COPY_AND_MOVE(IEvent); IEvent() = delete; IEvent(const std::string& name, EventTimeUnit time_unit) : name_(name), time_unit_(time_unit) {} virtual nlohmann::json ToJson(); virtual ~IEvent() = default; virtual void Start(); virtual void Finish(); bool IsChildOf(const IEvent* e); const std::string& GetName() const; template const T GetDuration(EventTimeUnit time_unit = EventTimeUnit::kUS) const; template const T GetStartedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const; template const T GetFinishedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const; protected: virtual void SetStartedAt(double t); virtual void SetFinishedAt(double t); std::string name_; EventTimeUnit time_unit_; double started_at_ = 0; double finished_at_ = 0; }; inline double ConvertTime(double time_, EventTimeUnit src_time_unit, EventTimeUnit dst_time_unit) { if (src_time_unit == EventTimeUnit::kNS && dst_time_unit == EventTimeUnit::kUS) { return time_ / 1000; } if (src_time_unit == EventTimeUnit::kUS && dst_time_unit == EventTimeUnit::kNS) { return time_ * 1000; } return time_; } template<> const inline double IEvent::GetStartedAt(EventTimeUnit time_unit) const { return ConvertTime(started_at_, time_unit_, time_unit); } template<> const inline time_t IEvent::GetStartedAt(EventTimeUnit time_unit) const { return static_cast(GetStartedAt(time_unit)); } template<> const inline double IEvent::GetFinishedAt(EventTimeUnit time_unit) const { return ConvertTime(finished_at_, time_unit_, time_unit); } template<> const inline time_t IEvent::GetFinishedAt(EventTimeUnit time_unit) const { return static_cast(GetFinishedAt(time_unit)); } template<> const inline double IEvent::GetDuration(EventTimeUnit time_unit) const { return GetFinishedAt(time_unit) - GetStartedAt(time_unit); } template<> const inline time_t IEvent::GetDuration(EventTimeUnit time_unit) const { return static_cast(GetDuration(time_unit)); } class CustomEvent final : public IEvent { public: friend class ProfileManager; nlohmann::json ToJson() override; static std::shared_ptr Create(const std::string& name, CustomEventType type = CustomEventType::kDefault); private: CustomEventType type_; CustomEvent(const std::string& custom_name, CustomEventType type) : IEvent(custom_name, type == CustomEventType::kDefault ? 
EventTimeUnit::kNS : EventTimeUnit::kUS), type_(type) {} }; class KernelEvent final : public IEvent { public: using Description = std::map>; nlohmann::json ToJson() override; static std::shared_ptr Create(const std::string& name, const Description& description); #if defined(WITH_CUDA) void SetMemorySize(int64_t memory_size) { memory_size_ = memory_size; } void AddChildEvent(const std::shared_ptr& e) { children_.emplace(e); } bool AddChildEventIfSo(const std::shared_ptr& e) { if (e->IsChildOf(dynamic_cast(this))) { children_.emplace(e); return true; } return false; } bool HasChildEvent(const std::shared_ptr& e) { return children_.count(e); } void WalkAmongChildren(const std::function& e)>& f) const { for (const auto& x : children_) { f(x); } } #endif // WITH_CUDA private: KernelEvent(const std::string& kernel_name, const Description& description) : IEvent(kernel_name, EventTimeUnit::kNS), description_(description) {} #if defined(WITH_CUDA) int64_t memory_size_ = -1; std::set> children_; #endif // WITH_CUDA const Description description_; }; } // namespace profiler } // namespace oneflow namespace nlohmann { inline void to_json(json& j, const std::shared_ptr<::oneflow::profiler::IEvent>& event) { j = event->ToJson(); } } // namespace nlohmann #endif // ONEFLOW_CORE_PROFILER_EVENT_H_ ================================================ FILE: oneflow/core/profiler/event_recorder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/profiler/event_recorder.h" #include "oneflow/core/profiler/profile_manager.h" #include "oneflow/core/common/shape_view.h" namespace oneflow { namespace profiler { Maybe EventRecorder::RegisterEventToProfileManager(const std::shared_ptr& event) { auto* pmgr = JUST(SingletonMaybe()); pmgr->events_.push(event_); return Maybe::Ok(); } std::shared_ptr EventRecorder::CreateCustomEventRecorder(const std::string& name) { return std::make_shared(CustomEvent::Create(name)); } Maybe EventRecorder::CreateKernelEventRecorder( const std::string& name, #if defined(WITH_CUDA) const std::function& memory_size_getter, #endif const DescriptionGetter& input_shapes_getter, const DescriptionGetter& attrs_getter) { auto pmgr = Singleton::Get(); if (pmgr) { const auto description_getter = [pmgr, input_shapes_getter, attrs_getter]() { KernelEvent::Description desc; if (pmgr->record_shapes_) { desc["input_shapes"] = input_shapes_getter(); } if (pmgr->record_attrs_) { desc["attrs"] = attrs_getter(); } return desc; }; #if defined(WITH_CUDA) if (pmgr->use_cpu_ || pmgr->use_cuda_) { auto event = KernelEvent::Create(name, description_getter()); if (pmgr->use_cuda_) { if (pmgr->record_bandwidth_) { event->SetMemorySize(memory_size_getter()); } } return std::make_shared(event); } #else if (pmgr->use_cpu_) { return std::make_shared(KernelEvent::Create(name, description_getter())); } #endif // WITH_CUDA } std::shared_ptr null_recorder; return null_recorder; } } // namespace profiler } // namespace oneflow ================================================ FILE: oneflow/core/profiler/event_recorder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_ #define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/profiler/event.h" namespace oneflow { namespace profiler { class EventRecorder { public: using DescriptionGetter = std::function()>; OF_DISALLOW_COPY_AND_MOVE(EventRecorder); explicit EventRecorder(const std::shared_ptr& event) : event_(event) { CHECK_JUST(RegisterEventToProfileManager(event)); event_->Start(); } Maybe RegisterEventToProfileManager(const std::shared_ptr& event); ~EventRecorder() { if (event_) { event_->Finish(); event_.reset(); } } static std::shared_ptr CreateCustomEventRecorder(const std::string& name); static Maybe CreateKernelEventRecorder( const std::string& name, #if defined(WITH_CUDA) const std::function& memory_size_getter, #endif const DescriptionGetter& input_shapes_getter, const DescriptionGetter& attrs_getter); private: std::shared_ptr event_; }; } // namespace profiler } // namespace oneflow #endif // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_ ================================================ FILE: oneflow/core/profiler/kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
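// Illustrative usage sketch (not part of the repository sources): EventRecorder ties an event's
// lifetime to a C++ scope. The constructor registers the event with the ProfileManager
// singleton and calls Start(); the destructor calls Finish(). This assumes a ProfileManager has
// already been created (see EnableProfiler in profiler.cpp), otherwise registration CHECK-fails.
namespace oneflow {
namespace profiler {

inline void EventRecorderExample() {
  {
    auto recorder = EventRecorder::CreateCustomEventRecorder("scoped_region");
    // ... work measured while `recorder` is alive ...
  }  // recorder destroyed here; the event's finish time is taken now
}

}  // namespace profiler
}  // namespace oneflow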
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/profiler/kernel.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/lazy/actor/actor_context.h" namespace oneflow { namespace profiler { namespace { bool profile_cuda_memory_bandwidth = false; bool profile_kernel_forward_range = false; void Init() { profile_cuda_memory_bandwidth = ParseBooleanFromEnv("ONEFLOW_PROFILER_KERNEL_PROFILE_CUDA_MEMORY_BANDWIDTH", false); profile_kernel_forward_range = ParseBooleanFromEnv("ONEFLOW_PROFILER_KERNEL_PROFILE_KERNEL_FORWARD_RANGE", false); } COMMAND(Init()); #if defined(WITH_CUDA) thread_local cudaEvent_t cuda_memory_bandwidth_profile_start_event = nullptr; thread_local cudaEvent_t cuda_memory_bandwidth_profile_end_event = nullptr; #endif // WITH_CUDA } // namespace void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* kernel) { #if defined(WITH_CUDA) if (profile_cuda_memory_bandwidth) { auto* actor_context_provider = dynamic_cast(kernel_ctx); auto* cuda_stream = dynamic_cast(kernel_ctx->stream()); if (cuda_stream != nullptr && actor_context_provider != nullptr) { CHECK(cuda_memory_bandwidth_profile_start_event == nullptr); CHECK(cuda_memory_bandwidth_profile_end_event == nullptr); OF_CUDA_CHECK(cudaEventCreate(&cuda_memory_bandwidth_profile_start_event)); OF_CUDA_CHECK(cudaEventCreate(&cuda_memory_bandwidth_profile_end_event)); OF_CUDA_CHECK( cudaEventRecord(cuda_memory_bandwidth_profile_start_event, cuda_stream->cuda_stream())); } } if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); } #endif // WITH_CUDA } void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* kernel) { #if defined(WITH_CUDA) if (profile_kernel_forward_range) { OF_PROFILER_RANGE_POP(); } // The memory bandwidth profiler only works in lazy mode. 
if (profile_cuda_memory_bandwidth) { auto* cuda_stream = dynamic_cast(kernel_ctx->stream()); auto* actor_context_provider = dynamic_cast(kernel_ctx); if (cuda_stream != nullptr && actor_context_provider != nullptr) { cudaEvent_t start_event = cuda_memory_bandwidth_profile_start_event; cudaEvent_t end_event = cuda_memory_bandwidth_profile_end_event; cuda_memory_bandwidth_profile_start_event = nullptr; cuda_memory_bandwidth_profile_end_event = nullptr; CHECK_NOTNULL(start_event); CHECK_NOTNULL(end_event); OF_CUDA_CHECK(cudaEventRecord(end_event, cuda_stream->cuda_stream())); int64_t memory_size = 0; for (const auto& bn : kernel->op_attribute().input_bns()) { const Blob* blob = kernel_ctx->BnInOp2Blob(bn); if (blob) { memory_size += blob->ByteSizeOfBlobBody(); } } for (const auto& bn : kernel->op_attribute().output_bns()) { const Blob* blob = kernel_ctx->BnInOp2Blob(bn); if (blob) { memory_size += blob->ByteSizeOfBlobBody(); } } const std::string op_name = kernel->op_conf().name(); actor_context_provider->GetActorContext()->AddCallback( [start_event, end_event, memory_size, op_name]() { float elapsed_ms = 0; OF_CUDA_CHECK(cudaEventElapsedTime(&elapsed_ms, start_event, end_event)); OF_CUDA_CHECK(cudaEventDestroy(start_event)); OF_CUDA_CHECK(cudaEventDestroy(end_event)); double bandwidth = static_cast(memory_size) / (1024.0 * 1024.0 * 1024.0) / (elapsed_ms / 1000); LOG(INFO) << "PROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: " << op_name << " elapsed(ms): " << elapsed_ms << " memory_size(Byte): " << memory_size << " bandwidth(GB/s): " << bandwidth; }); } } #endif // WITH_CUDA } } // namespace profiler } // namespace oneflow ================================================ FILE: oneflow/core/profiler/kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PROFILER_KERNEL_H_ #define ONEFLOW_CORE_PROFILER_KERNEL_H_ #include "oneflow/core/common/util.h" namespace oneflow { class Kernel; class KernelContext; class Blob; namespace profiler { void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* kernel); void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* kernel); } // namespace profiler } // namespace oneflow #endif // ONEFLOW_CORE_PROFILER_KERNEL_H_ ================================================ FILE: oneflow/core/profiler/kineto_shim.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
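// Illustrative note (not part of the repository sources): the lazy-mode bandwidth profiler
// above is switched on with ONEFLOW_PROFILER_KERNEL_PROFILE_CUDA_MEMORY_BANDWIDTH=1 and logs
//   bandwidth(GB/s) = memory_size / 2^30 / (elapsed_ms / 1000)
// so moving 1 GiB of blob data in 2 ms is reported as 500 GB/s. A sanity-check of the same
// arithmetic:
namespace oneflow {

inline double KernelBandwidthGBps(int64_t memory_size_bytes, float elapsed_ms) {
  return static_cast<double>(memory_size_bytes) / (1024.0 * 1024.0 * 1024.0)
         / (elapsed_ms / 1000.0);
}

}  // namespace oneflow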
See the License for the specific language governing permissions and limitations under the License. */ #if defined(WITH_CUDA) #include "oneflow/core/profiler/kineto_shim.h" #include "libkineto.h" namespace oneflow { namespace profiler { namespace { const std::set cpuTypes{ libkineto::ActivityType::CPU_OP, libkineto::ActivityType::CPU_INSTANT_EVENT, libkineto::ActivityType::USER_ANNOTATION, libkineto::ActivityType::EXTERNAL_CORRELATION, libkineto::ActivityType::CUDA_RUNTIME, // something like cudaLaunchKernel libkineto::ActivityType::PYTHON_FUNCTION, }; const std::set cudaTypes = { libkineto::ActivityType::GPU_MEMCPY, libkineto::ActivityType::GPU_MEMSET, libkineto::ActivityType::CONCURRENT_KERNEL, // cuda kernel // CUDA_RUNTIME appears in both cpuTypes and cudaTypes. libkineto::ActivityType::CUDA_RUNTIME, // something like cudaLaunchKernel }; } // namespace ActivityTraceWrapper::ActivityTraceWrapper(std::unique_ptr trace) : trace_(std::move(trace)), saved_{false} {} ActivityTraceWrapper::operator bool() const { return trace_ != nullptr; } void ActivityTraceWrapper::save(const std::string& path) { // TORCH_CHECK(!saved_, "Trace is already saved."); // TORCH_CHECK(trace_ != nullptr, "Missing trace.") trace_->save(path); saved_ = true; } void PrepareTrace(const bool cpuOnly, const ActivitySet& activities) { if (!libkineto::api().isProfilerRegistered()) { libkineto_init(/*cpuOnly=*/cpuOnly, /*logOnError=*/true); libkineto::api().suppressLogMessages(); } if (!libkineto::api().isProfilerInitialized()) { libkineto::api().initProfilerIfRegistered(); } std::set k_activities; if (activities.count(ActivityType::CPU)) { k_activities.insert(cpuTypes.begin(), cpuTypes.end()); } if (activities.count(ActivityType::CUDA)) { k_activities.insert(cudaTypes.begin(), cudaTypes.end()); } libkineto::api().activityProfiler().prepareTrace(k_activities); } void StartTrace() { libkineto::api().activityProfiler().startTrace(); } ActivityTraceWrapper StopTrace() { return ActivityTraceWrapper{libkineto::api().activityProfiler().stopTrace()}; } } // namespace profiler } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/profiler/kineto_shim.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_ #define ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_ #if defined(WITH_CUDA) #include #include #include namespace libkineto { enum class ActivityType; class ActivityTraceInterface; } // namespace libkineto namespace oneflow { namespace profiler { enum class ActivityType { CPU = 0, CUDA, }; using interface_trace_t = libkineto::ActivityTraceInterface; struct ActivityTraceWrapper { explicit ActivityTraceWrapper(std::unique_ptr trace); ActivityTraceWrapper() = default; ActivityTraceWrapper(ActivityTraceWrapper&&) = default; ActivityTraceWrapper(const ActivityTraceWrapper&) = delete; explicit operator bool() const; void save(const std::string& path); const std::unique_ptr& get() { return trace_; } private: std::unique_ptr trace_; bool saved_ = false; // Kineto's save is destructive }; using ActivitySet = std::set; void PrepareTrace(const bool cpuOnly, const ActivitySet& activities); void StartTrace(); ActivityTraceWrapper StopTrace(); } // namespace profiler } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_ ================================================ FILE: oneflow/core/profiler/profile_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
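// Illustrative usage sketch (not part of the repository sources): the shim above is the whole
// surface OneFlow needs from Kineto, namely prepare, start, stop, save. Assumes a WITH_CUDA
// build, since the shim is compiled out otherwise.
namespace oneflow {
namespace profiler {

inline void KinetoShimExample() {
  PrepareTrace(/*cpuOnly=*/false, {ActivityType::CPU, ActivityType::CUDA});
  StartTrace();
  // ... launch the kernels to be traced ...
  ActivityTraceWrapper trace = StopTrace();
  if (trace) { trace.save("/tmp/oneflow_kineto_trace.json"); }  // save() is destructive
}

}  // namespace profiler
}  // namespace oneflow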
*/ #include #include #include "fmt/core.h" #include "nlohmann/json.hpp" #include "oneflow/core/profiler/kineto_shim.h" #include "oneflow/core/profiler/profile_manager.h" #include "oneflow/core/profiler/event.h" #if defined(WITH_CUDA) #include #endif // WITH_CUDA using json = nlohmann::json; namespace oneflow { namespace profiler { std::string ProfileManager::RegisterEventRecorder( const std::shared_ptr& event_recorder, const std::string& name) { std::string recorder_key = GetNextEventRecorderKey(name); event_recorders_.emplace(recorder_key, event_recorder); return recorder_key; } void ProfileManager::UnregisterEventRecorder(const std::string& event_recorder_key) { if (event_recorders_.find(event_recorder_key) != event_recorders_.end()) { event_recorders_.erase(event_recorder_key); } } std::string ProfileManager::DumpResultsJson() { const json j = ExportEvents(); return j.dump(); } std::vector> ProfileManager::ExportEvents() { #if defined(WITH_CUDA) auto trace = StopTrace(); const auto& kineto_events = *(trace.get()->activities()); std::set> custom_events; std::unordered_map, int64_t> corr_ids; const std::vector> type_pairs = { {libkineto::ActivityType::CUDA_RUNTIME, CustomEventType::kCudaRuntime}, {libkineto::ActivityType::CONCURRENT_KERNEL, CustomEventType::kCudaKernel}}; for (const auto& evt_ptr : kineto_events) { if (evt_ptr == nullptr) { continue; } const auto& activity = *evt_ptr; for (auto& pair : type_pairs) { if (activity.type() == pair.first) { auto custom_event = CustomEvent::Create(activity.name(), pair.second); custom_event->SetStartedAt(static_cast(activity.timestamp())); custom_event->SetFinishedAt(static_cast(activity.timestamp()) + activity.duration()); custom_events.emplace(custom_event); corr_ids[custom_event] = activity.correlationId(); } } } #endif // WITH_CUDA std::vector> events; while (!events_.empty()) { auto evt = events_.front(); events_.pop(); #if defined(WITH_CUDA) auto evt_kernel = std::dynamic_pointer_cast(evt); if (evt_kernel) { std::set current_corr_ids; if (!custom_events.empty()) { for (const auto& x : custom_events) { if (evt_kernel->AddChildEventIfSo(x)) { current_corr_ids.insert(corr_ids[x]); } } for (const auto& x : custom_events) { if (!evt_kernel->HasChildEvent(x) && current_corr_ids.count(corr_ids[x])) { evt_kernel->AddChildEvent(x); } } evt_kernel->WalkAmongChildren( [&custom_events](const std::shared_ptr& child) { custom_events.erase(child); }); } } #endif // WITH_CUDA events.emplace_back(evt); } return events; } std::string ProfileManager::GetNextEventRecorderKey(const std::string& name) { if (event_recorders_last_id_.find(name) == event_recorders_last_id_.end()) { event_recorders_last_id_[name] = 0; } else { event_recorders_last_id_[name]++; } return fmt::format("{}.{}", name, event_recorders_last_id_[name]); } } // namespace profiler } // namespace oneflow ================================================ FILE: oneflow/core/profiler/profile_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PROFILER_PROFILE_MANAGER_H_ #define ONEFLOW_CORE_PROFILER_PROFILE_MANAGER_H_ #include #include #include #include #include "oneflow/core/profiler/kineto_shim.h" namespace oneflow { namespace profiler { class IEvent; class EventRecorder; class ProfileManager { public: friend class EventRecorder; ProfileManager(bool use_cpu, bool use_cuda, bool record_shapes, bool record_attrs, bool record_bandwidth) : use_cpu_(use_cpu), use_cuda_(use_cuda), record_shapes_(record_shapes), record_attrs_(record_attrs), record_bandwidth_(record_bandwidth) { #if defined(WITH_CUDA) std::set activities{}; if (use_cpu) { activities.insert(ActivityType::CPU); } if (use_cuda) { activities.insert(ActivityType::CUDA); } PrepareTrace(/*cpuOnly*/ false, activities); StartTrace(); #endif // WITH_CUDA } std::string RegisterEventRecorder(const std::shared_ptr& event_recorder, const std::string& name); void UnregisterEventRecorder(const std::string& event_recorder_key); std::string DumpResultsJson(); private: bool use_cpu_; bool use_cuda_; bool record_shapes_; bool record_attrs_; bool record_bandwidth_; std::queue> events_; std::unordered_map> event_recorders_; // To prevent releasing EventRecorders of the same name. std::unordered_map event_recorders_last_id_; std::string GetNextEventRecorderKey(const std::string& name); std::vector> ExportEvents(); }; } // namespace profiler } // namespace oneflow #endif // ONEFLOW_CORE_PROFILER_PROFILE_MANAGER_H_ ================================================ FILE: oneflow/core/profiler/profiler.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/profiler/profile_manager.h" #include "oneflow/core/profiler/kineto_shim.h" #include "oneflow/core/profiler/event_recorder.h" #include "oneflow/core/vm/vm_util.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" #include #include #include #include #endif // WITH_CUDA namespace oneflow { namespace profiler { void NameThisHostThread(const std::string& name) { #ifdef WITH_CUDA static thread_local std::unique_ptr thread_name_prefix; if (!thread_name_prefix) { thread_name_prefix.reset( new std::string(GetStringFromEnv("ONEFLOW_PROFILER_HOST_THREAD_NAME_PREFIX", ""))); } const std::string name_with_prefix = *thread_name_prefix + name; nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str()); #endif // WITH_CUDA } void RangePush(const std::string& name) { #ifdef OF_ENABLE_PROFILER nvtxRangePushA(name.c_str()); #endif // OF_ENABLE_PROFILER } void RangePop() { #ifdef OF_ENABLE_PROFILER nvtxRangePop(); #endif // OF_ENABLE_PROFILER } RangeGuard::RangeGuard(const std::string& name) { #ifdef OF_ENABLE_PROFILER RangePush(name); #endif // OF_ENABLE_PROFILER } RangeGuard::~RangeGuard() { #ifdef OF_ENABLE_PROFILER RangePop(); #endif // OF_ENABLE_PROFILER } void LogHostMemoryUsage(const std::string& name) { #ifdef OF_ENABLE_PROFILER int64_t vm_pages; int64_t rss_pages; std::ifstream ifs("/proc/self/statm"); ifs >> vm_pages >> rss_pages; ifs.close(); const int64_t page_size = sysconf(_SC_PAGE_SIZE); LOG(INFO) << "HostMemoryUsage: " << name << " VM " << vm_pages * page_size << " RSS " << rss_pages * page_size; #endif // OF_ENABLE_PROFILER } void ProfilerStart() { #ifdef OF_ENABLE_PROFILER OF_CUDA_CHECK(cudaProfilerStart()); #endif // OF_ENABLE_PROFILER } void ProfilerStop() { #ifdef OF_ENABLE_PROFILER OF_CUDA_CHECK(cudaProfilerStop()); #endif // OF_ENABLE_PROFILER } void EnableProfiler(bool use_cpu, bool use_cuda, bool record_shapes, bool record_attrs, bool record_bandwidth) { CHECK_JUST(vm::ClusterSync()); if (Singleton::Get() == nullptr) { Singleton::New(use_cpu, use_cuda, record_shapes, record_attrs, record_bandwidth); } } // DisableProfilerAndReturnResult will return a json of profile results. Maybe DisableProfilerAndReturnResult() { JUST(vm::ClusterSync()); #if defined(WITH_CUDA) OF_CUDA_CHECK(cudaDeviceSynchronize()); #endif // WITH_CUDA auto* pmgr = JUST(SingletonMaybe()); std::string results = pmgr->DumpResultsJson(); Singleton::Delete(); return results; } Maybe StartRecord(const std::string& name) { auto* pmgr = JUST(SingletonMaybe()); JUST(vm::ClusterSync()); return pmgr->RegisterEventRecorder(profiler::EventRecorder::CreateCustomEventRecorder(name), name); } Maybe EndRecord(const std::string& event_recorder_key) { auto* pmgr = JUST(SingletonMaybe()); JUST(vm::ClusterSync()); pmgr->UnregisterEventRecorder(event_recorder_key); return Maybe::Ok(); } } // namespace profiler } // namespace oneflow ================================================ FILE: oneflow/core/profiler/profiler.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
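// Illustrative usage sketch (not part of the repository sources): the end-to-end flow of the
// C++ profiling entry points defined above, run inside an initialized OneFlow runtime. Every
// step that returns Maybe<...> is unwrapped with CHECK_JUST here.
namespace oneflow {
namespace profiler {

inline std::string ProfileSessionExample() {
  EnableProfiler(/*use_cpu=*/true, /*use_cuda=*/false, /*record_shapes=*/true,
                 /*record_attrs=*/false, /*record_bandwidth=*/false);
  std::string key = CHECK_JUST(StartRecord("demo_step"));
  // ... run the ops to be profiled ...
  CHECK_JUST(EndRecord(key));
  return CHECK_JUST(DisableProfilerAndReturnResult());  // JSON string of all recorded events
}

}  // namespace profiler
}  // namespace oneflow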
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_PROFILER_PROFILER_H_ #define ONEFLOW_CORE_PROFILER_PROFILER_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace profiler { void NameThisHostThread(const std::string& name); void RangePush(const std::string& name); void RangePop(); void LogHostMemoryUsage(const std::string& name); void ProfilerStart(); void ProfilerStop(); class RangeGuardCtx; class RangeGuard final { public: OF_DISALLOW_COPY_AND_MOVE(RangeGuard); explicit RangeGuard(const std::string& name); ~RangeGuard(); private: std::shared_ptr ctx_; }; #define OF_PROFILER_NAME_THIS_HOST_THREAD(name) ::oneflow::profiler::NameThisHostThread(name) #ifdef OF_ENABLE_PROFILER #define OF_PROFILER_ONLY_CODE(...) __VA_ARGS__ #define OF_PROFILER_RANGE_PUSH(name) ::oneflow::profiler::RangePush(name) #define OF_PROFILER_RANGE_POP() ::oneflow::profiler::RangePop() #define OF_PROFILER_RANGE_GUARD(name) \ ::oneflow::profiler::RangeGuard OF_PP_CAT(_of_profiler_range_guard_, __COUNTER__)(name) #define OF_PROFILER_LOG_HOST_MEMORY_USAGE(name) ::oneflow::profiler::LogHostMemoryUsage(name) #else #define OF_PROFILER_ONLY_CODE(...) #define OF_PROFILER_RANGE_PUSH(name) #define OF_PROFILER_RANGE_POP() #define OF_PROFILER_RANGE_GUARD(name) #define OF_PROFILER_LOG_HOST_MEMORY_USAGE(name) #endif void EnableProfiler(bool use_cpu, bool use_cuda, bool record_shapes, bool record_attrs, bool record_bandwidth); // DisableProfilerAndReturnResult will return a json of profile results. Maybe DisableProfilerAndReturnResult(); Maybe StartRecord(const std::string& name); Maybe EndRecord(const std::string& event_recorder_key); } // namespace profiler } // namespace oneflow #endif // ONEFLOW_CORE_PROFILER_PROFILER_H_ ================================================ FILE: oneflow/core/profiler/util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
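// Illustrative usage sketch (not part of the repository sources): the range macros above expand
// to NVTX calls only when OF_ENABLE_PROFILER is defined and to nothing otherwise, so they can
// be left in hot paths unconditionally.
namespace oneflow {

inline void RangeMacroExample() {
  OF_PROFILER_NAME_THIS_HOST_THREAD("worker-0");
  OF_PROFILER_RANGE_GUARD("RangeMacroExample");  // popped automatically at end of scope
  OF_PROFILER_RANGE_PUSH("inner");
  // ... inner work ...
  OF_PROFILER_RANGE_POP();
}

}  // namespace oneflow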

================================================
FILE: oneflow/core/profiler/util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_UTIL_H_
#define ONEFLOW_CORE_PROFILER_UTIL_H_

#include <cstdint>
#include <ctime>

namespace oneflow {

namespace profiler {

using time_t = int64_t;

inline time_t GetTimeNow(bool allow_monotonic = false) {
  struct timespec t {};
  auto mode = CLOCK_REALTIME;
  if (allow_monotonic) { mode = CLOCK_MONOTONIC; }
  clock_gettime(mode, &t);
  return static_cast<time_t>(t.tv_sec) * 1000000000 + static_cast<time_t>(t.tv_nsec);
}

}  // namespace profiler

}  // namespace oneflow

#endif  // ONEFLOW_CORE_PROFILER_UTIL_H_


================================================
FILE: oneflow/core/record/coco.proto
================================================
syntax = "proto2";
package oneflow;

import "oneflow/core/record/record.proto";

message PolygonList {
  repeated FloatList polygons = 1;
}


================================================
FILE: oneflow/core/record/record.proto
================================================
syntax = "proto2";
package oneflow;

message BytesList {
  repeated bytes value = 1;
}

message FloatList {
  repeated float value = 1 [packed = true];
}

message DoubleList {
  repeated double value = 1 [packed = true];
}

message Int32List {
  repeated int32 value = 1 [packed = true];
}

message Int64List {
  repeated int64 value = 1 [packed = true];
}

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    DoubleList double_list = 3;
    Int32List int32_list = 4;
    Int64List int64_list = 5;
  }
}

message OFRecord {
  map<string, Feature> feature = 1;
}


================================================
FILE: oneflow/core/register/blob.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/register/blob.h"
#include "oneflow/core/kernel/kernel_util.h"

namespace oneflow {

Blob::Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr) {
  Init(mem_case, blob_desc, header_ptr, nullptr, 0);
}

Blob::Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr,
           char* body_ptr) {
  Init(mem_case, blob_desc, header_ptr, body_ptr, 0);
}

Blob::Blob(const MemoryCase& mem_case,  // NOLINT,Blob::Blob(...)
{ // NOLINT const BlobDesc* blob_desc, char* header_ptr, char* body_ptr, const int64_t offset) { Init(mem_case, blob_desc, header_ptr, body_ptr, offset); } void Blob::Init(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr, char* body_ptr, const int64_t offset) { mem_case_ = mem_case; blob_desc_ = blob_desc; storage_offset_ = offset; dptr_ = body_ptr; header_ptr_ = header_ptr; this->blob_access_checker_ = Singleton>::Get(); int64_t* shape_ptr = reinterpret_cast(header_ptr); shape_view_.reset(new ShapeView(shape_ptr, static_shape().NumAxes())); if (blob_desc->is_dynamic()) { mut_shape_view_.reset(new MutShapeView(shape_ptr, static_shape().NumAxes())); } MutShapeView(shape_ptr, static_shape().NumAxes()).set_shape(static_shape()); } void Blob::CopyHeaderFrom(const Blob* rhs) { size_t header_size = blob_desc().ByteSizeOfBlobHeader(); CHECK_EQ(header_size, rhs->blob_desc().ByteSizeOfBlobHeader()); if (this == rhs || header_size == 0) { return; } std::memcpy(header_ptr_, rhs->header_ptr(), header_size); } char* Blob::mut_contiguous_header_ptr() { // check header and body is continuous CHECK_EQ(header_ptr() + blob_desc_->AlignedByteSizeOfBlobHeader(), dptr()); return header_ptr_; } } // namespace oneflow ================================================ FILE: oneflow/core/register/blob.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_REGISTER_BLOB_H_ #define ONEFLOW_CORE_REGISTER_BLOB_H_ #include "oneflow/core/job/resource.pb.h" #include "oneflow/core/memory/memory_case.pb.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/symbol.h" namespace oneflow { class BlobAccessChecker { public: virtual void CheckHeaderMutable() const = 0; virtual void CheckBodyMutable() const = 0; }; template class BlobAccessCheckerIf final : public BlobAccessChecker { public: void CheckHeaderMutable() const override { CHECK(is_header_mutable) << "header mutable check not passed, blob's shape is not mutable at this moment!"; } void CheckBodyMutable() const override { CHECK(is_body_mutable) << "body mutable check not passed, blob's data is not mutable at this moment!"; } }; class Blob final { public: OF_DISALLOW_COPY_AND_MOVE(Blob); Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr); Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr, char* body_ptr); Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr, char* body_ptr, const int64_t offset); virtual ~Blob() = default; DataType data_type() const { return blob_desc_->data_type(); } MemoryFormat memory_format() const { return blob_desc_->memory_format(); } const char* header_ptr() const { return header_ptr_; } [[deprecated( "\"mut_header_ptr\" will be removed in Bolb. Please avoid to use this method whenever " "possible. 
Almost all methods of `mut_header_ptr` are also in `Blob`.")]] char* mut_header_ptr() { return header_ptr_; } char* mut_contiguous_header_ptr(); const BlobDesc& blob_desc() const { return *blob_desc_; } const BlobDesc* blob_desc_ptr() const { return blob_desc_; } template const T* dptr() const { CheckDataType(data_type()); return reinterpret_cast(static_cast(dptr_) + storage_offset_ * GetSizeOfDataType(data_type())); } template T* mut_dptr() { this->blob_access_checker()->CheckBodyMutable(); CheckDataType(data_type()); return reinterpret_cast(static_cast(dptr_) + storage_offset_ * GetSizeOfDataType(data_type())); } template T* ForceMutDptr() { CheckDataType(data_type()); return reinterpret_cast(static_cast(dptr_) + storage_offset_ * GetSizeOfDataType(data_type())); } template const T* raw_dptr() const { CheckDataType(data_type()); return static_cast(dptr_); } template T* mut_raw_dptr() { this->blob_access_checker()->CheckBodyMutable(); CheckDataType(data_type()); return static_cast(dptr_); } // shape const Shape& static_shape() const { return blob_desc_->shape(); } const ShapeView& shape_view() const { return *shape_view_; } const ShapeView& shape() const { return *shape_view_; } MutShapeView* mut_shape_view() { this->blob_access_checker()->CheckHeaderMutable(); return mut_shape_view_.get(); } MutShapeView* ForceMutShapeView() { return mut_shape_view_.get(); } // stride const Stride& stride() const { return blob_desc_->stride(); } void reset_dptr(char* dptr) { dptr_ = dptr; } void CopyHeaderFrom(const Blob* rhs); bool IsBodyEmpty() const { return shape().elem_cnt() == 0; } size_t AlignedTotalByteSize() const { return blob_desc_->AlignedTotalByteSize(); } const MemoryCase& mem_case() const { return mem_case_; } size_t ByteSizeOfBlobBody() const { return blob_desc_->ByteSizeOfBlobBody(); } size_t AlignedByteSizeOfBlobBody() const { return blob_desc_->AlignedByteSizeOfBlobBody(); } void set_blob_access_checker(const BlobAccessChecker* blob_access_checker) { this->blob_access_checker_ = blob_access_checker; } const BlobAccessChecker* blob_access_checker() { return this->blob_access_checker_; } private: void Init(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr, char* body_ptr, const int64_t offset); const BlobAccessChecker* blob_access_checker_; MemoryCase mem_case_; const BlobDesc* blob_desc_; void* dptr_; char* header_ptr_; int64_t storage_offset_; std::unique_ptr shape_view_; std::unique_ptr mut_shape_view_; }; #define INIT_GLOBAL_BLOB_MUTABLE_CHECKER(is_header_mutable, is_body_mutable) \ COMMAND(Singleton>::SetAllocated( \ new BlobAccessCheckerIf())) INIT_GLOBAL_BLOB_MUTABLE_CHECKER(false, false); INIT_GLOBAL_BLOB_MUTABLE_CHECKER(false, true); INIT_GLOBAL_BLOB_MUTABLE_CHECKER(true, false); INIT_GLOBAL_BLOB_MUTABLE_CHECKER(true, true); } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_BLOB_H_ ================================================ FILE: oneflow/core/register/blob_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/register/blob_desc.h" namespace oneflow { bool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs) { return lhs.lbi() < rhs.lbi(); } BlobDesc::BlobDesc(const Shape& shape, DataType dtype, MemoryFormat memory_format, bool is_dynamic) : shape_(SymbolOf(shape)), stride_(SymbolOf(Stride(shape))), data_type_(dtype), memory_format_(memory_format), is_dynamic_(is_dynamic) {} BlobDesc::BlobDesc(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format, bool is_dynamic) : shape_(SymbolOf(shape)), stride_(SymbolOf(stride)), data_type_(dtype), memory_format_(memory_format), is_dynamic_(is_dynamic) {} BlobDesc::BlobDesc(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format, bool is_dynamic) : shape_(shape), stride_(stride), data_type_(dtype), memory_format_(memory_format), is_dynamic_(is_dynamic) {} BlobDesc::BlobDesc(const Shape& shape, DataType dtype, MemoryFormat memory_format) : BlobDesc(shape, Stride(shape), dtype, memory_format, false) {} BlobDesc::BlobDesc(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format) : BlobDesc(shape, stride, dtype, memory_format, false) {} BlobDesc::BlobDesc(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format) : BlobDesc(shape, stride, dtype, memory_format, false) {} BlobDesc::BlobDesc(DataType dtype, MemoryFormat memory_format) : BlobDesc(Shape(), Stride(), dtype, memory_format, false) {} BlobDesc::BlobDesc(const BlobDescProto& proto) : shape_(SymbolOf(Shape(proto.shape()))), stride_(SymbolOf(Stride(proto.stride()))), data_type_(proto.data_type()), memory_format_(proto.memory_format()), is_dynamic_(proto.is_dynamic()) {} BlobDesc::BlobDesc(const BlobDesc& other) : shape_(other.shape_), stride_(other.stride_), data_type_(other.data_type()), memory_format_(other.memory_format()), is_dynamic_(other.is_dynamic()) {} void BlobDesc::ToProto(BlobDescProto* proto) const { shape().ToProto(proto->mutable_shape()); stride().ToProto(proto->mutable_stride()); proto->set_data_type(data_type_); proto->set_memory_format(memory_format_); proto->set_is_dynamic(is_dynamic_); } BlobDesc& BlobDesc::operator=(const BlobDesc& rhs) { this->CopyFrom(rhs); return *this; } void BlobDesc::CopyFrom(const BlobDesc& other) { set_shape(other.shape()); set_stride(other.stride()); set_data_type(other.data_type()); set_memory_format(other.memory_format()); set_is_dynamic(other.is_dynamic()); } void BlobDesc::set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; } bool BlobDesc::operator==(const BlobDesc& rhs) const { return (shape() == rhs.shape()) && (stride() == rhs.stride()) && (data_type() == rhs.data_type()) && (memory_format() == rhs.memory_format()) && (is_dynamic() == rhs.is_dynamic()); } size_t BlobDesc::ByteSizeOfBlobHeader() const { return shape().is_initialized() ? shape().NumAxes() * sizeof(int64_t) : 0; } size_t BlobDesc::AlignedByteSizeOfBlobHeader() const { return shape().is_initialized() ? RoundUp(shape().NumAxes() * sizeof(int64_t), kBlobHeaderAlignSize) : RoundUp(0, kBlobHeaderAlignSize); } size_t BlobDesc::ByteSizeOfBlobBody() const { return shape().is_initialized() ? 
shape().elem_cnt() * GetSizeOfDataType(data_type()) : 0; } size_t BlobDesc::AlignedByteSizeOfBlobBody() const { return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize); } size_t BlobDesc::AlignedTotalByteSize() const { return AlignedByteSizeOfBlobHeader() + AlignedByteSizeOfBlobBody(); } } // namespace oneflow ================================================ FILE: oneflow/core/register/blob_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_REGISTER_BLOB_DESC_H_ #define ONEFLOW_CORE_REGISTER_BLOB_DESC_H_ #include #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/memory_format.pb.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/register/blob_desc.pb.h" #include "oneflow/core/register/register_desc.pb.h" namespace oneflow { class BlobDesc final { public: BlobDesc() = delete; ~BlobDesc() = default; // NOTE(chengcheng): Cannot using std::make_shared in header file, because it will cause // Segmentation fault with unknown reason. BlobDesc(const Shape& shape, DataType dtype, MemoryFormat memory_format, bool is_dynamic); BlobDesc(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format, bool is_dynamic); BlobDesc(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format, bool is_dynamic); BlobDesc(const Shape& shape, DataType dtype, MemoryFormat memory_format); BlobDesc(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format); BlobDesc(Symbol shape, Symbol stride, DataType dtype, MemoryFormat memory_format); explicit BlobDesc(DataType dtype, MemoryFormat memory_format); explicit BlobDesc(const BlobDescProto& proto); explicit BlobDesc(const BlobDesc&); BlobDesc& operator=(const BlobDesc&); const Shape& shape() const { CHECK(shape_.operator bool()); return *shape_; } const Stride& stride() const { CHECK(stride_.operator bool()); return *stride_; } const std::shared_ptr& shape_ptr() const { return shape_.shared_from_symbol(); } const std::shared_ptr& stride_ptr() const { return stride_.shared_from_symbol(); } void set_shape(const Shape& shape) { this->shape_ = SymbolOf(shape); } void set_stride(const Stride& stride) { this->stride_ = SymbolOf(stride); } DataType data_type() const { return data_type_; } void set_data_type(DataType data_type) { data_type_ = data_type; } MemoryFormat memory_format() const { return memory_format_; } void set_memory_format(MemoryFormat memory_format) { memory_format_ = memory_format; } bool is_dynamic() const { return is_dynamic_; } void set_is_dynamic(bool is_dynamic); bool operator==(const BlobDesc&) const; void ToProto(BlobDescProto*) const; void CopyFrom(const BlobDesc&); size_t ByteSizeOfBlobHeader() const; size_t ByteSizeOfBlobBody() const; size_t AlignedByteSizeOfBlobHeader() const; size_t 
AlignedByteSizeOfBlobBody() const; size_t AlignedTotalByteSize() const; private: Symbol shape_; Symbol stride_; DataType data_type_; MemoryFormat memory_format_; bool is_dynamic_; }; bool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs); } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_BLOB_DESC_H_ ================================================ FILE: oneflow/core/register/blob_desc.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/sequential.proto"; import "oneflow/core/common/data_type.proto"; import "oneflow/core/common/memory_format.proto"; message BlobDescProto { required ShapeProto shape = 1; required Int64ListProto stride = 2; required DataType data_type = 3; required bool is_dynamic = 4; required MemoryFormat memory_format = 5; } message BlobDescSignature { map bn_in_op2blob_desc = 1; } ================================================ FILE: oneflow/core/register/logical_blob_id.proto ================================================ syntax = "proto2"; package oneflow; message LogicalBlobId { optional string op_name = 1; optional string blob_name = 2; } message LogicalBlobIdPair { required LogicalBlobId first = 1; required LogicalBlobId second = 2; } message LogicalBlobIdPairs { repeated LogicalBlobIdPair pair = 1; } message LogicalBlobIdGroups { message LogicalBlobIdGroup { repeated LogicalBlobId lbi = 1; } repeated LogicalBlobIdGroup lbi_group = 2; } message ArgSignature { map bn_in_op2lbi = 1; } ================================================ FILE: oneflow/core/register/op_blob_arg.proto ================================================ syntax = "proto2"; package oneflow; message OpBlobArg { required string op_name = 1; // blob name in op required string bn_in_op = 2; } message OpBlobArgPair { required OpBlobArg first = 1; required OpBlobArg second = 2; } message OpBlobArgPairs { repeated OpBlobArgPair pair = 1; } message OpBlobArgList { repeated OpBlobArg oba = 1; } ================================================ FILE: oneflow/core/register/op_blob_arg_info.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_REGISTER_OP_BLOB_ARG_INFO_H_ #define ONEFLOW_CORE_REGISTER_OP_BLOB_ARG_INFO_H_ #include "oneflow/core/register/op_blob_arg.pb.h" namespace oneflow { struct InplaceObasInfo { OpBlobArgList mut_in_obas; OpBlobArgPairs mut_inplace_oba_pairs; OpBlobArgPairs con_inplace_oba_pairs; }; } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_OP_BLOB_ARG_INFO_H_ ================================================ FILE: oneflow/core/register/register.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/register/register.h" #include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/memory/memory_allocator.h" namespace oneflow { const std::vector& Regst::consumers_actor_id() const { return regst_desc_->consumers_actor_id(); } Regst::Regst(const RtRegstDesc* regst_desc, RegstAllocationType allocation_type) : regst_desc_(regst_desc), header_mem_ptr_(nullptr), body_mem_ptr_(nullptr), comm_net_token_(nullptr), allocation_type_(allocation_type) { sorted_blob_vec_.resize(regst_desc->lbi_num()); } Regst::~Regst() { if (comm_net_token_ != nullptr) { Singleton::Get()->UnRegisterMemory(comm_net_token_); } } void Regst::Init(void* header_mem_ptr) { CHECK(header_mem_ptr_ == nullptr); header_mem_ptr_ = header_mem_ptr; regst_desc_->ForEachBlobDescOffsetInOnRegst([&](int64_t ordinal, const LogicalBlobId& lbi, const BlobDesc* blob_desc, int64_t body_offset, int64_t header_offset) { sorted_blob_vec_.at(ordinal).reset( new Blob(regst_desc_->mem_case(), blob_desc, reinterpret_cast(header_mem_ptr_) + header_offset)); }); } void Regst::ResetBodyMemPtr(void* body_mem_ptr) { if (body_mem_ptr_ == body_mem_ptr) { return; } body_mem_ptr_ = body_mem_ptr; if (body_mem_ptr_ == nullptr) { for (auto& blob : sorted_blob_vec_) { blob->reset_dptr(nullptr); } } else { regst_desc_->ForEachBlobDescOffsetInOnRegst([&](int64_t ordinal, const LogicalBlobId& lbi, const BlobDesc* blob_desc, int64_t body_offset, int64_t header_offset) { sorted_blob_vec_.at(ordinal)->reset_dptr(reinterpret_cast(body_mem_ptr_) + body_offset); InitNonPODTypeBlobIfNeed(Singleton::Get(), sorted_blob_vec_.at(ordinal).get()); }); } } Blob* Regst::GetBlobByOrdinal(int64_t ordinal) { return sorted_blob_vec_.at(ordinal).get(); } Blob* Regst::GetBlobByLbi(const LogicalBlobId& lbi) { const int64_t ordinal = regst_desc_->GetOrdinalForLbi(lbi); if (ordinal >= 0) { return sorted_blob_vec_.at(ordinal).get(); } else { return nullptr; } } void Regst::SetBlobByOrdinal(int64_t ordinal, std::unique_ptr&& blob) { CHECK(!sorted_blob_vec_.at(ordinal)); sorted_blob_vec_.at(ordinal).swap(blob); } Blob* Regst::GetMutSoleBlob() { CHECK_EQ(GetBlobSize(), 1); return sorted_blob_vec_.front().get(); } const Blob* Regst::GetSoleBlob() const { CHECK_EQ(GetBlobSize(), 1); return sorted_blob_vec_.front().get(); } void* Regst::comm_net_token() { void* token = comm_net_token_.load(std::memory_order_relaxed); if (token != nullptr) { return token; } { std::lock_guard lock(comm_net_token_mutex_); token = comm_net_token_; if (token != nullptr) { return token; } CHECK(body_mem_ptr_ != nullptr); CHECK(header_mem_ptr_ != nullptr); CHECK(reinterpret_cast(header_mem_ptr_) + regst_desc_->HeaderByteSize4OneRegst() == body_mem_ptr_); token = Singleton::Get()->RegisterMemory(header_mem_ptr_, this->regst_desc()->MainByteSize4OneRegst()); comm_net_token_ = token; return token; } } } // namespace oneflow ================================================ FILE: oneflow/core/register/register.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_REGISTER_REGISTER_H_ #define ONEFLOW_CORE_REGISTER_REGISTER_H_ #include "oneflow/core/register/blob.h" #include "oneflow/core/register/runtime_register_desc.h" namespace oneflow { enum class RegstAllocationType { kInvalid = 0, kStatic = 1, kStreamOrdered = 2, }; class Regst final { public: OF_DISALLOW_COPY_AND_MOVE(Regst); ~Regst(); // Getters int64_t regst_desc_id() const { CHECK(regst_desc_ != nullptr); return regst_desc_->regst_desc_id(); } void Init(void* header_mem_ptr); void ResetBodyMemPtr(void* body_mem_ptr); int64_t producer_actor_id() const { return regst_desc_->producer_actor_id(); } const std::vector& consumers_actor_id() const; const RtRegstDesc* regst_desc() const { return regst_desc_; } Blob* GetBlobByOrdinal(int64_t ordinal); Blob* GetBlobByLbi(const LogicalBlobId& lbi); const Blob* GetSoleBlob() const; Blob* GetMutSoleBlob(); int64_t GetBlobSize() const { return static_cast(sorted_blob_vec_.size()); } void* comm_net_token(); void* header_mem_ptr() const { return header_mem_ptr_; } void* body_mem_ptr() const { return body_mem_ptr_; } RegstAllocationType allocation_type() const { return allocation_type_; } private: friend class RegstMgr; Regst(const RtRegstDesc* regst_desc, RegstAllocationType allocation_type); void SetBlobByOrdinal(int64_t ordinal, std::unique_ptr&& blob); const RtRegstDesc* regst_desc_; std::vector> sorted_blob_vec_; void* header_mem_ptr_; void* body_mem_ptr_; std::atomic comm_net_token_; std::mutex comm_net_token_mutex_; RegstAllocationType allocation_type_; }; } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_REGISTER_H_ ================================================ FILE: oneflow/core/register/register_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/register/register_desc.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/graph/copy_task_node.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/register/runtime_register_desc.h" #include "oneflow/core/memory/memory_case_util.h" namespace oneflow { RegstDesc::RegstDesc() { regst_desc_id_ = Singleton::Get()->NewRegstDescId(); // NOLINT producer_ = nullptr; min_register_num_ = 1; max_register_num_ = kMaxRegisterNum; enable_reuse_mem_ = false; mem_block_id_ = -1; mem_block_offset_ = -1; hint_inplace_consumed_regst_desc_id_ = -1; force_inplace_consumed_regst_desc_id_ = -1; } int64_t RegstDesc::mem_block_offset() const { CHECK_GE(mem_block_offset_, 0); return mem_block_offset_; } void RegstDesc::AddConsumer(const TaskNode* new_consumer) { CHECK(consumers_.insert(new_consumer).second); } void RegstDesc::DeleteConsumer(const TaskNode* consumer) { CHECK_EQ(consumers_.erase(consumer), 1); } void RegstDesc::UpdtMinRegstNumIfNeed(int32_t val) { CHECK_LE(val, max_register_num_); min_register_num_ = std::max(min_register_num_, val); } void RegstDesc::UpdtMaxRegstNumIfNeed(int32_t val) { CHECK_GE(val, min_register_num_); max_register_num_ = std::min(max_register_num_, val); } void RegstDesc::CopyBlobDescFrom(const RegstDesc* rhs) { CHECK(lbi2blob_desc_.empty()); for (const auto& pair : rhs->lbi2blob_desc_) { const LogicalBlobId& lbi = pair.first; AddLbi(lbi); } CopyBlobDescWithoutAddLbi(rhs); } void RegstDesc::CopyMemBlockInfoFrom(const RegstDesc* rhs) { enable_reuse_mem_ = rhs->enable_reuse_mem_; mem_block_id_ = rhs->mem_block_id_; mem_block_offset_ = rhs->mem_block_offset_; } void RegstDesc::CopyBlobDescWithoutAddLbi(const RegstDesc* rhs) { for (const auto& pair : lbi2blob_desc_) { auto rhs_it = rhs->lbi2blob_desc_.find(pair.first); if (rhs_it != rhs->lbi2blob_desc_.end()) { *(pair.second) = *(rhs_it->second); } } } BlobDesc* RegstDesc::AddLbi(const LogicalBlobId& lbi) { CHECK(lbi2blob_desc_.find(lbi) == lbi2blob_desc_.end()); BlobDesc* blob_desc = new BlobDesc(GlobalJobDesc().DefaultDataType(), MemoryFormat::kContiguous); lbi2blob_desc_[lbi].reset(blob_desc); return blob_desc; } const BlobDesc* RegstDesc::GetBlobDesc(const LogicalBlobId& lbi) const { return const_cast(this)->MutBlobDesc(lbi); } bool RegstDesc::HasLbi(const LogicalBlobId& lbi) const { return lbi2blob_desc_.find(lbi) != lbi2blob_desc_.end(); } BlobDesc* RegstDesc::MutBlobDesc(const LogicalBlobId& lbi) { auto it = lbi2blob_desc_.find(lbi); if (it != lbi2blob_desc_.end()) { return it->second.get(); } else { return nullptr; } } const BlobDesc* RegstDesc::SoleBlobDesc() const { CHECK_EQ(1, lbi2blob_desc_.size()); return (*lbi2blob_desc_.begin()).second.get(); } BlobDesc* RegstDesc::MutSoleBlobDesc() { return const_cast(SoleBlobDesc()); } void RegstDesc::ForEachLbi(std::function func) const { for (const auto& p : lbi2blob_desc_) { func(p.first); } } void RegstDesc::EraseUninitializedShapeBlob() { EraseIf>( &lbi2blob_desc_, [](HashMap>::iterator it) { return !it->second->shape().is_initialized(); }); } void RegstDesc::InitFromProtoExceptConsumers(const RegstDescProto& proto) { regst_desc_id_ = proto.regst_desc_id(); CHECK_EQ(proto.producer_task_id(), producer_->task_id()); regst_desc_type_ = proto.regst_desc_type(); if (regst_desc_type_.has_data_regst_desc()) { const DataRegstDesc& data_regst_desc_proto = proto.regst_desc_type().data_regst_desc(); for (const auto& pair : data_regst_desc_proto.lbi2blob_desc()) { *AddLbi(pair.lbi()) = BlobDesc(pair.blob_desc()); } 
CHECK(!data_regst_desc_proto.has_time_shape()); } else if (regst_desc_type_.has_ctrl_regst_desc()) { // do nothing } else { UNIMPLEMENTED(); } min_register_num_ = proto.min_register_num(); max_register_num_ = proto.max_register_num(); min_register_num_ = proto.register_num(); mem_case_ = proto.mem_case(); enable_reuse_mem_ = proto.enable_reuse_mem(); mem_block_id_ = proto.mem_block_id(); mem_block_offset_ = proto.mem_block_offset(); hint_inplace_consumed_regst_desc_id_ = proto.hint_inplace_consumed_regst_desc_id(); force_inplace_consumed_regst_desc_id_ = proto.force_inplace_consumed_regst_desc_id(); } void RegstDesc::ToProto(RegstDescProto* ret, bool check) const { ret->set_regst_desc_id(regst_desc_id_); ret->set_producer_task_id(producer_->task_id()); for (const TaskNode* consumer : consumers_) { ret->add_consumer_task_id(consumer->task_id()); } *(ret->mutable_regst_desc_type()) = regst_desc_type_; if (regst_desc_type_.has_data_regst_desc()) { DataRegstDesc* data_regst_desc_proto = ret->mutable_regst_desc_type()->mutable_data_regst_desc(); for (const auto& pair : lbi2blob_desc_) { LbiBlobDescPair* pb_pair = data_regst_desc_proto->mutable_lbi2blob_desc()->Add(); *(pb_pair->mutable_lbi()) = pair.first; pair.second->ToProto(pb_pair->mutable_blob_desc()); } if (check) { CHECK(data_regst_time_shape_); } if (data_regst_time_shape_) { data_regst_time_shape_->ToProto(data_regst_desc_proto->mutable_time_shape()); } } else if (regst_desc_type_.has_ctrl_regst_desc()) { // do nothing } else { UNIMPLEMENTED(); } ret->set_min_register_num(min_register_num_); ret->set_max_register_num(max_register_num_); ret->set_register_num(min_register_num_); *(ret->mutable_mem_case()) = mem_case_; ret->set_enable_reuse_mem(enable_reuse_mem_); ret->set_mem_block_id(mem_block_id_); ret->set_mem_block_offset(mem_block_offset_); if (check) { CHECK(hint_inplace_consumed_regst_desc_id_ == -1 || force_inplace_consumed_regst_desc_id_ == -1) << "They are oneof fields"; } if (hint_inplace_consumed_regst_desc_id_ != -1) { ret->set_hint_inplace_consumed_regst_desc_id(hint_inplace_consumed_regst_desc_id_); } else if (force_inplace_consumed_regst_desc_id_ != -1) { ret->set_force_inplace_consumed_regst_desc_id(force_inplace_consumed_regst_desc_id_); } else { // do nothing } } bool RegstDesc::HasSameMemSize(const RegstDesc* rhs) { return SoleBlobDesc()->AlignedTotalByteSize() == rhs->SoleBlobDesc()->AlignedTotalByteSize(); } bool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) { if (rhs->lbi2blob_desc_.size() != lbi2blob_desc_.size()) { return false; } for (const auto& pair : rhs->lbi2blob_desc_) { auto iter = lbi2blob_desc_.find(pair.first); if (iter == lbi2blob_desc_.end()) { return false; } if (!(*(pair.second.get()) == *(iter->second.get()))) { return false; } } return true; } void InitCtrlRegstDesc(int64_t producer_task_id, RegstDescProto* ctrl_regst_proto) { CHECK_NOTNULL(ctrl_regst_proto); ctrl_regst_proto->set_regst_desc_id(Singleton::Get()->NewRegstDescId()); ctrl_regst_proto->set_producer_task_id(producer_task_id); ctrl_regst_proto->set_min_register_num(1); ctrl_regst_proto->set_max_register_num(1); ctrl_regst_proto->set_register_num(1); ctrl_regst_proto->mutable_regst_desc_type()->mutable_ctrl_regst_desc(); *ctrl_regst_proto->mutable_mem_case() = memory::MakeHostMemCase(); ctrl_regst_proto->set_enable_reuse_mem(false); ctrl_regst_proto->set_mem_block_id(-1); ctrl_regst_proto->set_mem_block_offset(-1); } } // namespace oneflow ================================================ FILE: oneflow/core/register/register_desc.h 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_ #define ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_ #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/register/register_desc.pb.h" namespace oneflow { const int32_t kMaxRegisterNum = std::numeric_limits::max(); void InitCtrlRegstDesc(int64_t producer_task_id, RegstDescProto* ctrl_regst_proto); class TaskNode; class RegstDesc final { public: OF_DISALLOW_COPY_AND_MOVE(RegstDesc); RegstDesc(); ~RegstDesc() = default; // regst_desc_id int64_t regst_desc_id() const { return regst_desc_id_; } // producer_, consumers_ const TaskNode* producer() const { return producer_; } void set_producer(const TaskNode* val) { producer_ = val; } const HashSet& consumers() const { return consumers_; } void AddConsumer(const TaskNode*); void DeleteConsumer(const TaskNode*); // min_register_num_, max_register_num_ int32_t min_register_num() const { return min_register_num_; } void UpdtMinRegstNumIfNeed(int32_t val); int32_t max_register_num() const { return max_register_num_; } void UpdtMaxRegstNumIfNeed(int32_t val); // lbi2blob_desc_ void CopyBlobDescFrom(const RegstDesc*); void CopyBlobDescWithoutAddLbi(const RegstDesc*); BlobDesc* AddLbi(const LogicalBlobId&); const BlobDesc* GetBlobDesc(const LogicalBlobId& lbi) const; bool HasLbi(const LogicalBlobId& lbi) const; BlobDesc* MutBlobDesc(const LogicalBlobId& lbi); const BlobDesc* SoleBlobDesc() const; BlobDesc* MutSoleBlobDesc(); void ForEachLbi(std::function func) const; size_t NumOfLbi() const { return lbi2blob_desc_.size(); } // mem const MemoryCase& mem_case() const { return mem_case_; } MemoryCase* mut_mem_case() { return &mem_case_; } bool enable_reuse_mem() const { return enable_reuse_mem_; } void set_enable_reuse_mem(bool enable_reuse_mem) { enable_reuse_mem_ = enable_reuse_mem; } int64_t mem_block_offset() const; void set_mem_block_offset(int64_t val) { mem_block_offset_ = val; } void set_hint_inplace_consumed_regst_desc_id(int64_t val) { CHECK_EQ(force_inplace_consumed_regst_desc_id_, -1); hint_inplace_consumed_regst_desc_id_ = val; } bool has_force_inplace_consumed_regst_desc_id() { return force_inplace_consumed_regst_desc_id_ != -1; } void set_force_inplace_consumed_regst_desc_id(int64_t val) { CHECK_EQ(hint_inplace_consumed_regst_desc_id_, -1); force_inplace_consumed_regst_desc_id_ = val; } int32_t mem_block_id() const { return mem_block_id_; } void set_mem_block_id(int32_t val) { mem_block_id_ = val; } bool HasSetMemBlockId() { return mem_block_id_ != -1; } void CopyMemBlockInfoFrom(const RegstDesc*); const std::shared_ptr& data_regst_time_shape() const { CHECK(regst_desc_type_.has_data_regst_desc()); CHECK(data_regst_time_shape_); return data_regst_time_shape_; } std::shared_ptr* mut_data_regst_time_shape() { CHECK(regst_desc_type_.has_data_regst_desc()); return &data_regst_time_shape_; } RegstDescTypeProto* mut_regst_desc_type() { return ®st_desc_type_; } const 
RegstDescTypeProto& regst_desc_type() const { return regst_desc_type_; } bool HasSameMemSize(const RegstDesc*); // util void EraseUninitializedShapeBlob(); void InitFromProtoExceptConsumers(const RegstDescProto& proto); void ToProto(RegstDescProto* proto) const { ToProto(proto, /*check*/ true); } void ToProto(RegstDescProto*, bool check) const; bool HasSameBlobDescs(const RegstDesc*); private: int64_t regst_desc_id_; const TaskNode* producer_; HashSet consumers_; int32_t min_register_num_; int32_t max_register_num_; HashMap> lbi2blob_desc_; MemoryCase mem_case_; RegstDescTypeProto regst_desc_type_; bool enable_reuse_mem_; int32_t mem_block_id_; int64_t mem_block_offset_; int64_t hint_inplace_consumed_regst_desc_id_; int64_t force_inplace_consumed_regst_desc_id_; std::shared_ptr data_regst_time_shape_; }; } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_ ================================================ FILE: oneflow/core/register/register_desc.proto ================================================ syntax = "proto2"; package oneflow; import "oneflow/core/register/blob_desc.proto"; import "oneflow/core/register/logical_blob_id.proto"; import "oneflow/core/memory/memory_case.proto"; import "oneflow/core/common/shape.proto"; message LbiBlobDescPair { required LogicalBlobId lbi = 1; required BlobDescProto blob_desc = 2; } message DataRegstDesc { repeated LbiBlobDescPair lbi2blob_desc = 1; optional ShapeProto time_shape = 3; } message CtrlRegstDesc { } message RegstDescTypeProto { oneof type { DataRegstDesc data_regst_desc = 1; CtrlRegstDesc ctrl_regst_desc = 3; } } message RegstDescProto { required int64 regst_desc_id = 1; required int64 producer_task_id = 2; repeated int64 consumer_task_id = 3; required int32 min_register_num = 4; required int32 max_register_num = 5; required int32 register_num = 6; required MemoryCase mem_case = 7; required RegstDescTypeProto regst_desc_type = 8; required bool enable_reuse_mem = 9; required int64 mem_block_id = 10; required int64 mem_block_offset = 11; optional int64 separated_header_mem_block_id = 12 [default = -1]; optional int64 inplace_consumed_regst_desc_id = 13 [default = -1]; oneof inplace_info_type { int64 hint_inplace_consumed_regst_desc_id = 14 [default = -1]; int64 force_inplace_consumed_regst_desc_id = 15 [default = -1]; } // NOTE(chengcheng): mark this regst memory is shared with EagerParameter. optional string variable_op_name = 16 [default = ""]; // NOTE(chengcheng): for mem block debug. optional int64 mem_block_total_actor_count = 20 [default = -1]; optional int64 alloc_before_actor = 21 [default = -1]; optional int64 free_after_actor = 22 [default = -1]; } ================================================ FILE: oneflow/core/register/register_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/register/register_manager.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/memory/memory_case.pb.h" #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/memory/chunk_manager.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace { struct PackedChunkInfo { MemoryCase mem_case; int64_t size; std::vector blocks; PackedChunkInfo(const MemoryCase& mem) { mem_case = mem; size = 0; } }; std::shared_ptr GetDeviceByMemoryCase(const MemoryCase& mem_case) { return Singleton::Get()->GetDevice(mem_case.device_type(), mem_case.device_id()); } void InitDataRegst(Regst* regst, char* main_mem_ptr, char* separated_header_mem_ptr) { auto* rt_regst_desc = regst->regst_desc(); size_t separated_header_mem_size = rt_regst_desc->SeparatedHeaderByteSize4OneRegst(); char* cur_body_pointer = nullptr; char* cur_header_pointer = nullptr; if (separated_header_mem_size > 0) { MemoryCase host_mem_case = memory::MakeHostMemCase(); if (separated_header_mem_ptr == nullptr) { separated_header_mem_ptr = Singleton::Get()->Allocate(host_mem_case, separated_header_mem_size); } cur_header_pointer = separated_header_mem_ptr; cur_body_pointer = main_mem_ptr; } else { CHECK(separated_header_mem_ptr == nullptr); cur_header_pointer = main_mem_ptr; if (main_mem_ptr == nullptr) { cur_body_pointer = nullptr; } else { cur_body_pointer = main_mem_ptr + rt_regst_desc->GetSoleBlobDesc()->AlignedByteSizeOfBlobHeader(); } } if (regst->allocation_type() == RegstAllocationType::kStatic) { CHECK(cur_body_pointer != nullptr || rt_regst_desc->TotalBodyByteSize4AllRegst() == 0); } else if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) { CHECK(cur_body_pointer == nullptr); } else { UNIMPLEMENTED(); } regst->Init(cur_header_pointer); regst->ResetBodyMemPtr(cur_body_pointer); } } // namespace RegstMgr::RegstMgr() : stream_ordered_memory_allocation_enabled_(false) { stream_ordered_memory_allocation_enabled_ = ParseBooleanFromEnv("ONEFLOW_GRAPH_ENABLE_STREAM_ORDERED_MEMORY_ALLOCATION", false); } bool RegstMgr::IsStreamOrderedMemoryAllocationCase(const MemoryCase& mem_case) const { if (!stream_ordered_memory_allocation_enabled_) { return false; } const auto& device = GetDeviceByMemoryCase(mem_case); return device->IsStreamOrderedMemoryAllocationSupported(); } void RegstMgr::AddPlan( const Plan& plan, const HashMap& variable_op_name2eager_blob_object) { int64_t this_machine_id = GlobalProcessCtx::Rank(); HashMap chunk_id2ptr; for (const ChunkProto& chunk : plan.block_chunk_list().chunk()) { if (chunk.machine_id() != this_machine_id) { continue; } if (chunk.mem_size() == 0) { continue; } if (IsStreamOrderedMemoryAllocationCase(chunk.mem_case())) { continue; } char* chunk_ptr = Singleton::Get()->FindOrCreateChunk(chunk); CHECK(chunk_id2ptr.emplace(chunk.chunk_id(), chunk_ptr).second); } HashSet all_block_ids; HashMap zone_id2packed_chunk; for (const MemBlockProto& mem_block : plan.block_chunk_list().mem_block()) { if (mem_block.machine_id() != this_machine_id) { continue; } if (mem_block.mem_size() == 0) { continue; } const int64_t mem_block_id = mem_block.mem_block_id(); CHECK(all_block_ids.insert(mem_block_id).second); if (mem_block.has_chunk_id()) { if 
(IsStreamOrderedMemoryAllocationCase(mem_block.mem_case())) { CHECK(mem_block.enable_reuse_mem()); CHECK(stream_ordered_allocation_mem_block_ids_.emplace(mem_block_id).second); continue; } CHECK(mem_block.has_chunk_offset()); CHECK(chunk_id2ptr.find(mem_block.chunk_id()) != chunk_id2ptr.end()); char* mem_block_ptr = chunk_id2ptr.at(mem_block.chunk_id()) + mem_block.chunk_offset(); CHECK(mem_block_id2ptr_.emplace(mem_block_id, mem_block_ptr).second) << " duplicated mem_block_id " << mem_block_id; CHECK(!mem_block.has_variable_op_name()); } else if (mem_block.has_variable_op_name()) { // NOTE(chengcheng): bind mem_block_ptr to variable blob header_ptr and body_ptr CHECK(!mem_block.enable_reuse_mem()); const std::string& var_name = mem_block.variable_op_name(); CHECK(!var_name.empty()); auto it = variable_op_name2eager_blob_object.find(var_name); CHECK(it != variable_op_name2eager_blob_object.end()) << " CANNOT find variable op name: " << var_name; CHECK(mem_block.has_is_separated_header()); vm::EagerBlobObject* var_blob = it->second; CHECK(var_blob) << " variable op name: " << var_name << " in rank: " << this_machine_id << " CANNNOT NULL."; if (mem_block.is_separated_header()) { CHECK_GE(var_blob->AlignedByteSizeOfBlobHeader(), mem_block.mem_size()); CHECK_GE(mem_block.mem_size(), var_blob->ByteSizeOfBlobHeader()); CHECK(mem_block_id2ptr_.emplace(mem_block_id, var_blob->mut_header_ptr()).second); CHECK(memory::IsHostMem(mem_block.mem_case())); } else { CHECK_GE(var_blob->AlignedByteSizeOfBlobBody(), mem_block.mem_size()); CHECK_GE(mem_block.mem_size(), var_blob->ByteSizeOfBlobBody()); CHECK(mem_block_id2ptr_.emplace(mem_block_id, var_blob->mut_dptr()).second); // NOTE(chengcheng): // CPU eager var tensor mem case is host_mem WITHOUT cuda pinned, but Lazy Complier // will set variable op output blob mem_case with cuda pinned memory if this output // blob has GPU op consume. We can JUST ignore this diff because it ONLY has little // perf loss but correct. // And this problem is NOT tensor.to("cuda") or tensor.to_global(). CHECK(memory::EqualsIgnorePinnedDevice(mem_block.mem_case(), var_blob->mem_case())) << " variable op name: " << var_name << " in rank: " << this_machine_id << " bind eager tensor failed. 
The eager var tensor mem_case is : " << var_blob->mem_case().DebugString() << " but graph expected_mem block mem_case is : " << mem_block.mem_case().DebugString(); } } else { int64_t zone_id = memory::GetMemCaseId(mem_block.mem_case()); if (zone_id2packed_chunk.find(zone_id) == zone_id2packed_chunk.end()) { zone_id2packed_chunk.emplace(zone_id, PackedChunkInfo(mem_block.mem_case())); } PackedChunkInfo* packed_chunk = &(zone_id2packed_chunk.at(zone_id)); packed_chunk->blocks.emplace_back(&mem_block); packed_chunk->size += mem_block.mem_size(); CHECK(packed_chunk->mem_case == mem_block.mem_case()); } } for (auto& pair : zone_id2packed_chunk) { PackedChunkInfo* packed_chunk = &pair.second; char* ptr = Singleton::Get()->Allocate(packed_chunk->mem_case, packed_chunk->size); // sort blocks as thrd id std::vector* blocks = &(packed_chunk->blocks); std::sort(blocks->begin(), blocks->end(), [](const MemBlockProto* lhs, const MemBlockProto* rhs) { if (lhs->thrd_id_hint() == rhs->thrd_id_hint()) { return lhs->mem_block_id() < rhs->mem_block_id(); } return lhs->thrd_id_hint() < rhs->thrd_id_hint(); }); int64_t offset = 0; for (const MemBlockProto* block : packed_chunk->blocks) { CHECK(mem_block_id2ptr_.emplace(block->mem_block_id(), ptr + offset).second); offset += block->mem_size(); } CHECK_EQ(offset, packed_chunk->size); } for (int64_t mem_block_id : all_block_ids) { if (mem_block_id2ptr_.find(mem_block_id) != mem_block_id2ptr_.end()) { CHECK(stream_ordered_allocation_mem_block_ids_.find(mem_block_id) == stream_ordered_allocation_mem_block_ids_.end()); } else { CHECK(stream_ordered_allocation_mem_block_ids_.find(mem_block_id) != stream_ordered_allocation_mem_block_ids_.end()); } } for (const TaskProto& task : plan.task()) { if (task.machine_id() != this_machine_id) { continue; } for (const auto& pair : task.produced_regst_desc()) { const RegstDescProto& regst_desc = pair.second; const int64_t regst_desc_id = regst_desc.regst_desc_id(); CHECK(regst_desc_id2rt_regst_desc_ .emplace(regst_desc_id, std::make_unique(regst_desc)) .second); } } for (const auto& pair : plan.ctrl_regst_desc_info().ctrl_regst_desc_id2producer_task_id()) { CHECK(ctrl_regst_desc_id2producer_task_id_.emplace(pair.first, pair.second).second); } } void RegstMgr::AddPlan(const Plan& plan) { HashMap variable_op_name2eager_blob_object; AddPlan(plan, variable_op_name2eager_blob_object); } void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto, std::function OneRegstDone) { const int64_t regst_desc_id = regst_desc_proto.regst_desc_id(); const RegstDescTypeProto& regst_desc_type = regst_desc_proto.regst_desc_type(); const RtRegstDesc* rt_regst_desc = regst_desc_id2rt_regst_desc_.at(regst_desc_id).get(); char* main_mem_ptr = nullptr; char* separated_header_mem_ptr = nullptr; int64_t mem_block_id = regst_desc_proto.mem_block_id(); int64_t header_block_id = regst_desc_proto.separated_header_mem_block_id(); if (mem_block_id != -1 && mem_block_id2ptr_.find(mem_block_id) != mem_block_id2ptr_.end()) { main_mem_ptr = mem_block_id2ptr_.at(mem_block_id) + regst_desc_proto.mem_block_offset(); } if (header_block_id != -1 && mem_block_id2ptr_.find(header_block_id) != mem_block_id2ptr_.end()) { separated_header_mem_ptr = mem_block_id2ptr_.at(header_block_id); } RegstAllocationType allocation_type = stream_ordered_allocation_mem_block_ids_.find(mem_block_id) == stream_ordered_allocation_mem_block_ids_.end() ? 
RegstAllocationType::kStatic : RegstAllocationType::kStreamOrdered; for (int64_t i = 0; i < rt_regst_desc->register_num(); ++i) { Regst* regst = new Regst(rt_regst_desc, allocation_type); if (regst_desc_type.has_data_regst_desc()) { InitDataRegst(regst, main_mem_ptr, separated_header_mem_ptr); if (main_mem_ptr != nullptr) { main_mem_ptr += rt_regst_desc->MainByteSize4OneRegst(); } if (separated_header_mem_ptr != nullptr) { separated_header_mem_ptr += rt_regst_desc->SeparatedHeaderByteSize4OneRegst(); } } else if (regst_desc_type.has_ctrl_regst_desc()) { // do nothing } else { UNIMPLEMENTED(); } OneRegstDone(regst); } } const RtRegstDesc& RegstMgr::RegstDesc4RegstDescId(int64_t regst_desc_id) const { const auto& it = regst_desc_id2rt_regst_desc_.find(regst_desc_id); CHECK(it != regst_desc_id2rt_regst_desc_.end()); return *it->second; } bool RegstMgr::HasRegstDescId(int64_t regst_desc_id) const { return regst_desc_id2rt_regst_desc_.find(regst_desc_id) != regst_desc_id2rt_regst_desc_.end(); } int64_t RegstMgr::ProducerTaskId4RegstDescId(int64_t regst_desc_id) const { const auto& it = ctrl_regst_desc_id2producer_task_id_.find(regst_desc_id); CHECK(it != ctrl_regst_desc_id2producer_task_id_.end()); return it->second; } bool RegstMgr::HasProducerTaskId4RegstDescId(int64_t regst_desc_id) const { return ctrl_regst_desc_id2producer_task_id_.find(regst_desc_id) != ctrl_regst_desc_id2producer_task_id_.end(); } } // namespace oneflow ================================================ FILE: oneflow/core/register/register_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_REGISTER_REGISTER_MANAGER_H_ #define ONEFLOW_CORE_REGISTER_REGISTER_MANAGER_H_ #include #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/runtime_context.h" #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/register/logical_blob_id.pb.h" #include "oneflow/core/register/register.h" #include "oneflow/core/record/record.pb.h" namespace oneflow { namespace vm { class EagerBlobObject; } class RegstMgr final { public: OF_DISALLOW_COPY_AND_MOVE(RegstMgr); RegstMgr(); ~RegstMgr() = default; void AddPlan( const Plan& plan, const HashMap& variable_op_name2eager_blob_object); void AddPlan(const Plan& plan); void NewRegsts(const RegstDescProto& regst_desc_proto, std::function OneRegstDone); const RtRegstDesc& RegstDesc4RegstDescId(int64_t regst_desc_id) const; bool HasRegstDescId(int64_t regst_desc_id) const; int64_t ProducerTaskId4RegstDescId(int64_t regst_desc_id) const; bool HasProducerTaskId4RegstDescId(int64_t regst_desc_id) const; private: bool IsStreamOrderedMemoryAllocationCase(const MemoryCase& mem_case) const; HashMap> regst_desc_id2rt_regst_desc_; HashMap mem_block_id2ptr_; HashSet stream_ordered_allocation_mem_block_ids_; HashMap ctrl_regst_desc_id2producer_task_id_; std::mutex mutex_; bool stream_ordered_memory_allocation_enabled_; }; } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_REGISTER_MANAGER_H_ ================================================ FILE: oneflow/core/register/runtime_register_desc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/register/runtime_register_desc.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { RtRegstDesc::RtRegstDesc(const RegstDescProto& proto) : one_regst_header_size_(0), one_regst_body_size_(0) { regst_desc_id_ = proto.regst_desc_id(); producer_actor_id_ = proto.producer_task_id(); consumers_actor_id_ = PbRf2StdVec(proto.consumer_task_id()); register_num_ = proto.register_num(); mem_case_ = proto.mem_case(); regst_desc_type_ = proto.regst_desc_type(); if (proto.regst_desc_type().has_data_regst_desc()) { const DataRegstDesc& data_regst_desc = proto.regst_desc_type().data_regst_desc(); std::vector lbi_pairs( {data_regst_desc.lbi2blob_desc().cbegin(), data_regst_desc.lbi2blob_desc().cend()}); std::sort(lbi_pairs.begin(), lbi_pairs.end(), &CompareLbiBlobDescPair); CHECK_EQ(lbi_pairs.size(), 1); sorted_blob_desc_vec_.reserve(lbi_pairs.size()); sorted_lbi_vec_.reserve(lbi_pairs.size()); for (int64_t i = 0; i < lbi_pairs.size(); ++i) { const LbiBlobDescPair& pair = lbi_pairs.at(i); sorted_blob_desc_vec_.emplace_back(std::make_unique(pair.blob_desc())); sorted_lbi_vec_.emplace_back(pair.lbi()); lbi2blob_desc_ordinal_.emplace(pair.lbi(), i); } CHECK(data_regst_desc.has_time_shape()); data_regst_time_shape_.reset(new Shape(data_regst_desc.time_shape())); } else { sorted_blob_desc_vec_.emplace_back( std::make_unique(BlobDesc(DataType::kChar, MemoryFormat::kContiguous))); } for (const auto& blob_desc_ : sorted_blob_desc_vec_) { one_regst_header_size_ += blob_desc_->AlignedByteSizeOfBlobHeader(); one_regst_body_size_ += blob_desc_->AlignedByteSizeOfBlobBody(); } if ((!memory::IsHostMem(proto.mem_case())) || (proto.has_variable_op_name() && !proto.variable_op_name().empty())) { // NOTE(chengcheng): When this regst is shared with EagerBlobObject, header is ALWAYS separated. 
has_separated_header_ = true; } else { has_separated_header_ = false; } } int64_t RtRegstDesc::GetOrdinalForLbi(const LogicalBlobId& lbi) const { auto it = lbi2blob_desc_ordinal_.find(lbi); if (it != lbi2blob_desc_ordinal_.cend()) { return it->second; } else { return -1; } } const BlobDesc* RtRegstDesc::GetBlobDescFromLbi(const LogicalBlobId& lbi) const { auto it = lbi2blob_desc_ordinal_.find(lbi); if (it == lbi2blob_desc_ordinal_.end()) { return nullptr; } else { return GetBlobDescByOrdinal(it->second); } } const BlobDesc* RtRegstDesc::GetBlobDescByOrdinal(int64_t ordinal) const { return sorted_blob_desc_vec_.at(ordinal).get(); } const LogicalBlobId& RtRegstDesc::GetLbiByOrdinal(int64_t ordinal) const { return sorted_lbi_vec_.at(ordinal); } const BlobDesc* RtRegstDesc::GetSoleBlobDesc() const { CHECK_EQ(sorted_blob_desc_vec_.size(), 1); return sorted_blob_desc_vec_.at(0).get(); } size_t RtRegstDesc::TotalByteSize4AllRegst() const { return (one_regst_header_size_ + one_regst_body_size_) * register_num_; } size_t RtRegstDesc::TotalMainByteSize4AllRegst() const { return MainByteSize4OneRegst() * register_num_; } size_t RtRegstDesc::TotalBodyByteSize4AllRegst() const { return BodyByteSize4OneRegst() * register_num_; } size_t RtRegstDesc::MainByteSize4OneRegst() const { if (has_separated_header_) { return one_regst_body_size_; } else { return one_regst_body_size_ + one_regst_header_size_; } } size_t RtRegstDesc::BodyByteSize4OneRegst() const { return one_regst_body_size_; } size_t RtRegstDesc::HeaderByteSize4OneRegst() const { return one_regst_header_size_; } size_t RtRegstDesc::TotalSeparatedHeaderByteSize4AllRegst() const { return SeparatedHeaderByteSize4OneRegst() * register_num_; } size_t RtRegstDesc::SeparatedHeaderByteSize4OneRegst() const { if (has_separated_header_) { // NOTE(chengcheng): Header size need to be aligned for XRT memory allocate return one_regst_header_size_; } else { return 0; } } const Shape& RtRegstDesc::data_regst_time_shape() const { CHECK(regst_desc_type_.has_data_regst_desc()); CHECK(data_regst_time_shape_); return *data_regst_time_shape_; } void RtRegstDesc::ForEachBlobDescOffsetInOnRegst( const std::function& Handler) const { int64_t cur_body_offset = 0; int64_t cur_header_offset = 0; for (int64_t i = 0; i < sorted_blob_desc_vec_.size(); ++i) { const BlobDesc* blob_desc = sorted_blob_desc_vec_.at(i).get(); const LogicalBlobId& lbi = sorted_lbi_vec_.at(i); Handler(i, lbi, blob_desc, cur_body_offset, cur_header_offset); cur_body_offset += blob_desc->AlignedByteSizeOfBlobBody(); cur_header_offset += blob_desc->AlignedByteSizeOfBlobHeader(); } } } // namespace oneflow ================================================ FILE: oneflow/core/register/runtime_register_desc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_
#define ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_

#include "oneflow/core/memory/memory_case.pb.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/register/register_desc.pb.h"

namespace oneflow {

class RtRegstDesc {
 public:
  OF_DISALLOW_COPY_AND_MOVE(RtRegstDesc);
  RtRegstDesc() = delete;
  ~RtRegstDesc() = default;
  RtRegstDesc(const RegstDescProto& regst_desc_proto);

  int64_t regst_desc_id() const { return regst_desc_id_; }
  int64_t producer_actor_id() const { return producer_actor_id_; }
  const std::vector<int64_t>& consumers_actor_id() const { return consumers_actor_id_; }
  int64_t register_num() const { return register_num_; }
  const MemoryCase& mem_case() const { return mem_case_; }
  const RegstDescTypeProto& regst_desc_type() const { return regst_desc_type_; }
  int64_t lbi_num() const { return sorted_lbi_vec_.size(); }

  int64_t GetOrdinalForLbi(const LogicalBlobId& lbi) const;
  const BlobDesc* GetBlobDescFromLbi(const LogicalBlobId& lbi) const;
  const BlobDesc* GetBlobDescByOrdinal(int64_t ordinal) const;
  const BlobDesc* GetSoleBlobDesc() const;
  const LogicalBlobId& GetLbiByOrdinal(int64_t ordinal) const;

  size_t TotalByteSize4AllRegst() const;
  size_t TotalMainByteSize4AllRegst() const;
  size_t TotalBodyByteSize4AllRegst() const;
  size_t TotalSeparatedHeaderByteSize4AllRegst() const;
  size_t SeparatedHeaderByteSize4OneRegst() const;
  size_t MainByteSize4OneRegst() const;
  size_t BodyByteSize4OneRegst() const;
  size_t HeaderByteSize4OneRegst() const;

  const Shape& data_regst_time_shape() const;

  void ForEachBlobDescOffsetInOnRegst(
      const std::function<void(int64_t, const LogicalBlobId&, const BlobDesc*, int64_t, int64_t)>&
          Handler) const;

 private:
  int64_t regst_desc_id_;
  int64_t producer_actor_id_;
  std::vector<int64_t> consumers_actor_id_;
  int64_t register_num_;
  RegstDescTypeProto regst_desc_type_;
  MemoryCase mem_case_;
  HashMap<LogicalBlobId, int64_t> lbi2blob_desc_ordinal_;
  std::unique_ptr<Shape> data_regst_time_shape_;
  std::vector<std::unique_ptr<BlobDesc>> sorted_blob_desc_vec_;
  std::vector<LogicalBlobId> sorted_lbi_vec_;
  bool has_separated_header_;
  size_t one_regst_header_size_;
  size_t one_regst_body_size_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_

================================================
FILE: oneflow/core/register/tensor_slice_copier.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/register/tensor_slice_copier.h" namespace oneflow { namespace { TensorSliceView GetRawTenserSliceView(const TensorSliceView& view, DataType data_type) { const size_t size_of_data_type = GetSizeOfDataType(data_type); if (size_of_data_type == 1) { return view; } else { std::vector range_vec = view.range_vec(); if (!view.IsEmpty()) { range_vec.back().mut_begin() = range_vec.back().begin() * size_of_data_type; range_vec.back().mut_end() = range_vec.back().end() * size_of_data_type; } return TensorSliceView(range_vec); } } } // namespace TensorSliceCopier::TensorSliceCopier(const TensorSliceView& dst_view, const TensorSliceView& src_view, const TensorSliceView& copy_view, const DataType data_type, const DeviceType device_type) : dst_view_(dst_view), src_view_(src_view), extent_(copy_view.shape()), data_type_(data_type) { copy_nd_primitive_ = ep::primitive::NewPrimitive( device_type, dst_view_.shape().NumAxes()); CHECK(dst_view.Contains(copy_view)); CHECK(src_view.Contains(copy_view)); dst_pos_ = copy_view.OffsetTo(dst_view); src_pos_ = copy_view.OffsetTo(src_view); } TensorSliceCopier::TensorSliceCopier(const TensorSliceView& dst_view, const TensorSliceView& src_view, const DataType data_type, const DeviceType device_type) : TensorSliceCopier(dst_view, src_view, dst_view.Intersect(src_view), data_type, device_type) {} void TensorSliceCopier::Copy(ep::Stream* stream, void* dst, const void* src) const { copy_nd_primitive_->Launch(stream, data_type_, dst_view_.shape().NumAxes(), dst, dst_view_.shape().dim_vec().data(), dst_pos_.dim_vec().data(), src, src_view_.shape().dim_vec().data(), src_pos_.dim_vec().data(), extent_.dim_vec().data()); } void TensorSliceCopier::Copy(ep::Stream* stream, Blob* dst_blob, const Blob* src_blob) const { CHECK_EQ(dst_blob->data_type(), data_type_); CHECK_EQ(src_blob->data_type(), data_type_); CHECK_EQ(dst_view_.shape().elem_cnt(), dst_blob->shape().elem_cnt()); CHECK_EQ(src_view_.shape().elem_cnt(), src_blob->shape().elem_cnt()); Copy(stream, dst_blob->mut_dptr(), src_blob->dptr()); } } // namespace oneflow ================================================ FILE: oneflow/core/register/tensor_slice_copier.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_REGISTER_TENSOR_SLICE_COPIER_H_ #define ONEFLOW_CORE_REGISTER_TENSOR_SLICE_COPIER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/register/blob.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" namespace oneflow { class TensorSliceCopier final { public: OF_DISALLOW_COPY_AND_MOVE(TensorSliceCopier); TensorSliceCopier(const TensorSliceView& dst_view, const TensorSliceView& src_view, const TensorSliceView& copy_view, DataType data_type, DeviceType device_type); TensorSliceCopier(const TensorSliceView& dst_view, const TensorSliceView& src_view, DataType data_type, DeviceType device_type); virtual ~TensorSliceCopier() = default; void Copy(ep::Stream* stream, void* dst, const void* src) const; void Copy(ep::Stream* stream, Blob* dst_blob, const Blob* src_blob) const; private: const TensorSliceView dst_view_; const TensorSliceView src_view_; NdIndex dst_pos_; NdIndex src_pos_; Shape extent_; const DataType data_type_; std::unique_ptr copy_nd_primitive_; }; } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_TENSOR_SLICE_COPIER_H_ ================================================ FILE: oneflow/core/register/tensor_slice_view.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { TensorSliceView::TensorSliceView(const std::initializer_list& ranges) : range_vec_(ranges) { UpdateShape(); } TensorSliceView::TensorSliceView(const std::vector& ranges) : range_vec_(ranges) { UpdateShape(); } TensorSliceView::TensorSliceView(const TensorSliceViewProto& proto) { range_vec_.resize(proto.dim_size()); std::transform(proto.dim().cbegin(), proto.dim().cend(), range_vec_.begin(), [](const RangeProto& rp) { return Range(rp); }); UpdateShape(); } TensorSliceView::TensorSliceView(const Shape& shape) { range_vec_.resize(shape.dim_vec().size()); std::transform(shape.dim_vec().cbegin(), shape.dim_vec().cend(), range_vec_.begin(), [](const int64_t dim_size) { return Range(0, dim_size); }); UpdateShape(); } TensorSliceView& TensorSliceView::operator=(const TensorSliceView& other) { range_vec_ = other.range_vec_; UpdateShape(); return *this; } bool TensorSliceView::operator==(const TensorSliceView& rhs) const { return range_vec_ == rhs.range_vec_; } bool TensorSliceView::operator!=(const TensorSliceView& rhs) const { return !(*this == rhs); } void TensorSliceView::UpdateShape() { DimVector dim_vec(range_vec_.size()); std::transform(range_vec_.cbegin(), range_vec_.cend(), dim_vec.begin(), [](const Range& range) { return range.size(); }); shape_ = Shape(dim_vec); } bool TensorSliceView::IsEmpty() const { return range_vec_.empty(); } bool TensorSliceView::Contains(const TensorSliceView& other) const { if (other.IsEmpty()) { return true; } CHECK_EQ(NumAxes(), other.NumAxes()); FOR_RANGE(int64_t, i, 0, NumAxes()) { if (range_vec_.at(i).begin() > other.range_vec_.at(i).begin() || range_vec_.at(i).end() < other.range_vec_.at(i).end()) { return false; } } return true; } TensorSliceView TensorSliceView::Intersect(const TensorSliceView& other) const { if (IsEmpty() || other.IsEmpty()) { return TensorSliceView(); } CHECK_EQ(other.range_vec_.size(), range_vec_.size()); std::vector intersection_vec; intersection_vec.reserve(range_vec_.size()); const Range zero(0, 0); FOR_RANGE(int64_t, i, 0, range_vec_.size()) { const Range intersection = FindIntersectant(range_vec_.at(i), other.range_vec_.at(i)); if (intersection == zero) { return TensorSliceView(); } else { intersection_vec.emplace_back(intersection); } } return TensorSliceView(intersection_vec); } const Range& TensorSliceView::At(int64_t index) const { return range_vec_.at(index); } const Shape& TensorSliceView::shape() const { return shape_; } const std::vector& TensorSliceView::range_vec() const { return range_vec_; } size_t TensorSliceView::NumAxes() const { return range_vec_.size(); } NdIndex TensorSliceView::OffsetTo(const TensorSliceView& other) const { CHECK_EQ(other.NumAxes(), NumAxes()); DimVector indices_vec(range_vec_.size()); std::transform(range_vec_.cbegin(), range_vec_.cend(), other.range_vec_.cbegin(), indices_vec.begin(), [](const Range& lhs, const Range& rhs) { return lhs.begin() - rhs.begin(); }); return NdIndex(indices_vec); } void TensorSliceView::ToProto(TensorSliceViewProto* proto) const { for (const Range& range : range_vec_) { range.ToProto(proto->mutable_dim()->Add()); } } TensorSliceView TensorSliceView::Concatenate(std::vector& slices, int64_t axis) { CHECK_GT(slices.size(), 0); const int64_t num_axes = slices.front().shape().NumAxes(); FOR_RANGE(int64_t, i, 1, slices.size()) { CHECK_EQ(slices.at(i).NumAxes(), num_axes); } CHECK_GE(axis, 0); CHECK_LT(axis, num_axes); FOR_RANGE(int64_t, i, 0, num_axes) { if (i == axis) { 
      CHECK(std::adjacent_find(slices.cbegin(), slices.cend() - 1,
                               [&](const TensorSliceView& lhs, const TensorSliceView& rhs) {
                                 return lhs.At(i).end() != rhs.At(i).begin();
                               })
            == slices.cend() - 1);
    } else {
      const Range& dim_range = slices.front().At(i);
      CHECK(std::all_of(slices.cbegin() + 1, slices.cend(),
                        [&](const TensorSliceView& view) { return view.At(i) == dim_range; }));
    }
  }
  std::vector<Range> range_vec = slices.front().range_vec();
  range_vec.at(axis).mut_end() = slices.back().At(axis).end();
  return TensorSliceView(range_vec);
}

}  // namespace oneflow

================================================
FILE: oneflow/core/register/tensor_slice_view.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_REGISTER_TENSOR_SLICE_VIEW_H_
#define ONEFLOW_CORE_REGISTER_TENSOR_SLICE_VIEW_H_

#include "oneflow/core/common/range.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/nd_index.h"
#include "oneflow/core/register/tensor_slice_view.pb.h"

namespace oneflow {

class TensorSliceView final {
 public:
  TensorSliceView() = default;
  TensorSliceView(const std::initializer_list<Range>& ranges);
  explicit TensorSliceView(const std::vector<Range>& ranges);
  explicit TensorSliceView(const TensorSliceViewProto& proto);
  explicit TensorSliceView(const Shape& shape);
  TensorSliceView& operator=(const TensorSliceView& other);
  bool operator==(const TensorSliceView& rhs) const;
  bool operator!=(const TensorSliceView& rhs) const;

  bool IsEmpty() const;
  TensorSliceView Intersect(const TensorSliceView& other) const;
  bool Contains(const TensorSliceView& other) const;
  const Range& At(int64_t index) const;
  const Shape& shape() const;
  const std::vector<Range>& range_vec() const;
  size_t NumAxes() const;
  NdIndex OffsetTo(const TensorSliceView& other) const;
  void ToProto(TensorSliceViewProto* proto) const;

  static TensorSliceView Concatenate(std::vector<TensorSliceView>& slices, int64_t axis);

 private:
  std::vector<Range> range_vec_;
  Shape shape_;

  void UpdateShape();
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_REGISTER_TENSOR_SLICE_VIEW_H_

================================================
FILE: oneflow/core/register/tensor_slice_view.proto
================================================
syntax = "proto2";
package oneflow;

import "oneflow/core/common/range.proto";

message TensorSliceViewProto {
  repeated RangeProto dim = 1;
}

================================================
FILE: oneflow/core/rpc/include/base.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_RPC_INCLUDE_BASE_CTRL_ #define ONEFLOW_CORE_RPC_INCLUDE_BASE_CTRL_ #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/util.h" #include "oneflow/core/control/control.pb.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" namespace oneflow { #define CTRL_METHOD_SEQ \ OF_PP_MAKE_TUPLE_SEQ(LoadServer) \ OF_PP_MAKE_TUPLE_SEQ(Barrier) \ OF_PP_MAKE_TUPLE_SEQ(TryLock) \ OF_PP_MAKE_TUPLE_SEQ(NotifyDone) \ OF_PP_MAKE_TUPLE_SEQ(WaitUntilDone) \ OF_PP_MAKE_TUPLE_SEQ(PushKV) \ OF_PP_MAKE_TUPLE_SEQ(ClearKV) \ OF_PP_MAKE_TUPLE_SEQ(PullKV) \ OF_PP_MAKE_TUPLE_SEQ(Clear) \ OF_PP_MAKE_TUPLE_SEQ(IncreaseCount) \ OF_PP_MAKE_TUPLE_SEQ(EraseCount) #define CatRequest(method) method##Request, #define CatReqponse(method) method##Response, #define CatEnum(method) k##method, #define CatName(method) "/oneflow.CtrlService/" OF_PP_STRINGIZE(method), #define MAKE_META_DATA() \ enum class CtrlMethod { OF_PP_FOR_EACH_TUPLE(CatEnum, CTRL_METHOD_SEQ) }; \ static const char* g_method_name[] = {OF_PP_FOR_EACH_TUPLE(CatName, CTRL_METHOD_SEQ)}; \ using CtrlRequestTuple = std::tuple; \ using CtrlResponseTuple = std::tuple; MAKE_META_DATA() constexpr const size_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ); template using CtrlRequest = typename std::tuple_element(ctrl_method), CtrlRequestTuple>::type; template using CtrlResponse = typename std::tuple_element(ctrl_method), CtrlResponseTuple>::type; inline const char* GetMethodName(CtrlMethod method) { return g_method_name[static_cast(method)]; } class CtrlClient { public: explicit CtrlClient(const ProcessCtx& process_ctx); CtrlClient() = default; virtual ~CtrlClient() = default; virtual void Barrier(const std::string& barrier_name) = 0; virtual void Barrier(const std::string& barrier_name, int32_t barrier_num) = 0; virtual TryLockResult TryLock(const std::string& name) = 0; virtual void NotifyDone(const std::string& name) = 0; virtual void WaitUntilDone(const std::string& name) = 0; virtual void PushKV(const std::string& k, std::function VSetter) = 0; virtual void PushKV(const std::string& k, const std::string& v) = 0; virtual void PushKV(const std::string& k, const PbMessage& msg) = 0; virtual void PushMasterKV(const std::string& k, const PbMessage& msg) = 0; template typename std::enable_if::value>::type PushKVT(const std::string& k, T v) { PushKV(k, std::to_string(v)); } virtual void ClearKV(const std::string& k) = 0; virtual void ClearMasterKV(const std::string& k) = 0; virtual void PullKV(const std::string& k, std::function VGetter) = 0; virtual void PullKV(const std::string& k, std::string* v) = 0; virtual void PullKV(const std::string& k, PbMessage* msg) = 0; virtual void PullMasterKV(const std::string& k, PbMessage* msg) = 0; template typename std::enable_if::value>::type PullKVT(const std::string& k, T* v) { std::string v_str; PullKV(k, &v_str); *v = oneflow_cast(v_str); } virtual void Clear() = 0; virtual int32_t IncreaseCount(const std::string& k, int32_t v) = 0; int32_t IncreaseCount(const std::string& k) { return IncreaseCount(k, 1); } virtual void EraseCount(const std::string& k) = 0; }; #define FILE_LINE_STR __FILE__ ":" OF_PP_STRINGIZE(__LINE__) #define OF_ENV_BARRIER() oneflow::Singleton::Get()->Barrier(FILE_LINE_STR) #define OF_SESSION_BARRIER() \ oneflow::Singleton::Get()->Barrier( \ FILE_LINE_STR, Singleton::Get()->process_ranks().size()) static void OfCallOnce(const 
std::string& name, std::function f) { TryLockResult lock_ret = Singleton::Get()->TryLock(name); if (lock_ret == TryLockResult::kLocked) { f(); Singleton::Get()->NotifyDone(name); } else if (lock_ret == TryLockResult::kDone) { } else if (lock_ret == TryLockResult::kDoing) { Singleton::Get()->WaitUntilDone(name); } else { UNIMPLEMENTED(); } } template static void OfCallOnce(const std::string& name, Self self, F f, Arg&& arg, Args&&... args) { std::function fn = std::bind(f, self, std::forward(arg), std::forward(args)...); OfCallOnce(name, std::move(fn)); } template static void OfCallOnce(const std::string& name, Self self, F f) { std::function fn = std::bind(f, self, name); OfCallOnce(name, std::move(fn)); } template static void OfCallOnce(const std::string& name, F f, Arg&& arg, Args&&... args) { std::function fn = std::bind(f, std::forward(arg), std::forward(args)...); OfCallOnce(name, std::move(fn)); } class RpcManager { public: RpcManager() = default; virtual ~RpcManager() = default; virtual Maybe Bootstrap() = 0; virtual Maybe CreateServer() = 0; virtual Maybe CreateClient() = 0; }; } // namespace oneflow #endif // ONEFLOW_CORE_RPC_INCLUDE_BASE_CTRL_ ================================================ FILE: oneflow/core/rpc/include/ctrl.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_RPC_INCLUDE_CTRL_ #define ONEFLOW_CORE_RPC_INCLUDE_CTRL_ #ifdef RPC_BACKEND_GRPC #include "oneflow/core/rpc/include/grpc.h" #endif // RPC_BACKEND_GRPC #ifdef RPC_BACKEND_LOCAL #include "oneflow/core/rpc/include/local.h" #endif // RPC_BACKEND_LOCAL #endif // ONEFLOW_CORE_RPC_INCLUDE_CTRL_ ================================================ FILE: oneflow/core/rpc/include/global_process_ctx.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_RPC_INCLUDE_GLOBAL_PROCESS_CTX_ #define ONEFLOW_CORE_RPC_INCLUDE_GLOBAL_PROCESS_CTX_ #include namespace oneflow { struct GlobalProcessCtx { static void GetMachineIdAndDeviceId(int64_t rank, int64_t* machine_id, int64_t* device_id); static void GetCurrentMachineIdAndDeviceId(int64_t* machine_id, int64_t* device_id); static int64_t Rank(); static int64_t LocalRank(); static int64_t LocalRank(int64_t rank); static int64_t NodeId(int64_t process_id); static int64_t NodeSize(); static int64_t ThisNodeId(); static int64_t NumOfProcessPerNode(); static bool IsThisProcessMaster(); static size_t WorldSize(); static std::string LogDirEntry(); }; } // namespace oneflow #endif // ONEFLOW_CORE_RPC_INCLUDE_GLOBAL_PROCESS_CTX_ ================================================ FILE: oneflow/core/rpc/include/grpc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_RPC_INCLUDE_GRPC_H_ #define ONEFLOW_CORE_RPC_INCLUDE_GRPC_H_ #include "oneflow/core/control/rpc_client.h" #include "oneflow/core/rpc/include/base.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" namespace oneflow { class GrpcCtrlClient final : public CtrlClient { public: OF_DISALLOW_COPY_AND_MOVE(GrpcCtrlClient); explicit GrpcCtrlClient(const ProcessCtx& process_ctx); ~GrpcCtrlClient() override; void Barrier(const std::string& barrier_name) override; void Barrier(const std::string& barrier_name, int32_t barrier_num) override; TryLockResult TryLock(const std::string& name) override; void NotifyDone(const std::string& name) override; void WaitUntilDone(const std::string& name) override; void PushKV(const std::string& k, std::function VSetter) override; void PushKV(const std::string& k, const std::string& v) override; void PushKV(const std::string& k, const PbMessage& msg) override; void PushMasterKV(const std::string& k, const PbMessage& msg) override; void ClearKV(const std::string& k) override; void ClearMasterKV(const std::string& k) override; void PullKV(const std::string& k, std::function VGetter) override; void PullKV(const std::string& k, std::string* v) override; void PullKV(const std::string& k, PbMessage* msg) override; void PullMasterKV(const std::string& k, PbMessage* msg) override; void Clear() override; int32_t IncreaseCount(const std::string& k, int32_t v) override; void EraseCount(const std::string& k) override; void StopHeartbeat(); private: const ProcessCtx& process_ctx() const { return process_ctx_; } ProcessCtx process_ctx_; bool need_heartbeat_thread_stop_; std::mutex need_heartbeat_thread_stop_mtx_; std::condition_variable need_heartbeat_thread_stop_cv_; std::thread heartbeat_thread_; RpcClient rpc_client_; }; class GrpcRpcManager : public RpcManager { public: GrpcRpcManager() = default; ~GrpcRpcManager() override; Maybe Bootstrap() override; Maybe CreateServer() override; Maybe CreateClient() override; }; } // namespace oneflow #endif // ONEFLOW_CORE_RPC_INCLUDE_GRPC_H_ 
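// Usage sketch (illustrative only, not part of the repository): how the CtrlClient primitives
// declared in oneflow/core/rpc/include/base.h are typically exercised once an RpcManager
// (e.g. GrpcRpcManager::CreateClient()) has installed the Singleton<CtrlClient>. The key
// strings, the function name, and the lambda body below are made up for the example.
namespace oneflow {
inline void CtrlClientUsageSketch() {
  auto* ctrl = Singleton<CtrlClient>::Get();
  // Publish a value under a key; any process may pull it later.
  ctrl->PushKV("example/key", std::string("value"));
  std::string v;
  ctrl->PullKV("example/key", &v);  // blocks until the key has been pushed
  // Run a closure exactly once across processes: the TryLock winner executes it and calls
  // NotifyDone; everyone else either sees kDone or waits via WaitUntilDone.
  OfCallOnce("example/init_once", []() { /* one-time initialization */ });
  OF_ENV_BARRIER();  // rendezvous of all processes in the environment
}
}  // namespace oneflow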
================================================ FILE: oneflow/core/rpc/include/local.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_RPC_INCLUDE_LOCAL_H_ #define ONEFLOW_CORE_RPC_INCLUDE_LOCAL_H_ #include #include #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" #include "oneflow/core/rpc/include/base.h" namespace oneflow { class LocalCtrlClient : public CtrlClient { public: OF_DISALLOW_COPY_AND_MOVE(LocalCtrlClient); explicit LocalCtrlClient(const ProcessCtx& process_ctx); ~LocalCtrlClient() override = default; void Barrier(const std::string& barrier_name) override; void Barrier(const std::string& barrier_name, int32_t barrier_num) override; TryLockResult TryLock(const std::string& name) override; void NotifyDone(const std::string& name) override; void WaitUntilDone(const std::string& name) override; void PushKV(const std::string& k, std::function VSetter) override; void PushKV(const std::string& k, const std::string& v) override; void PushKV(const std::string& k, const PbMessage& msg) override; void PushMasterKV(const std::string& k, const PbMessage& msg) override; void ClearKV(const std::string& k) override; void ClearMasterKV(const std::string& k) override; void PullKV(const std::string& k, std::function VGetter) override; void PullKV(const std::string& k, std::string* v) override; void PullKV(const std::string& k, PbMessage* msg) override; void PullMasterKV(const std::string& k, PbMessage* msg) override; void Clear() override; int32_t IncreaseCount(const std::string& k, int32_t v) override; void EraseCount(const std::string& k) override; HashSet done_names_; HashSet doing_names_; std::mutex done_names_mtx_; std::condition_variable done_names_cv_; HashMap kv_; std::mutex kv_mtx_; std::condition_variable kv_cv_; HashMap counter_; std::mutex counter_mtx_; HashMap> barrier_counter_; std::mutex barrier_counter_mtx_; }; class LocalRpcManager : public RpcManager { public: LocalRpcManager() = default; ~LocalRpcManager() override; Maybe Bootstrap() override; Maybe CreateServer() override { return Maybe::Ok(); } Maybe CreateClient() override; }; class DryRunRpcManager : public RpcManager { public: DryRunRpcManager() = default; ~DryRunRpcManager() override; Maybe Bootstrap() override; Maybe CreateServer() override { return Maybe::Ok(); } Maybe CreateClient() override; }; } // namespace oneflow #endif // ONEFLOW_CORE_RPC_INCLUDE_LOCAL_H_ ================================================ FILE: oneflow/core/rpc/include/manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_RPC_INCLUDE_MANAGER_H_ #define ONEFLOW_CORE_RPC_INCLUDE_MANAGER_H_ #ifdef RPC_BACKEND_GRPC #include "oneflow/core/rpc/include/grpc.h" #endif // RPC_BACKEND_GRPC #ifdef RPC_BACKEND_LOCAL #include "oneflow/core/rpc/include/local.h" #endif // RPC_BACKEND_LOCAL #endif // ONEFLOW_CORE_RPC_INCLUDE_MANAGER_H_ ================================================ FILE: oneflow/core/rpc/lib/global_process_ctx.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { void GlobalProcessCtx::GetMachineIdAndDeviceId(int64_t rank, int64_t* machine_id, int64_t* device_id) { *machine_id = rank; *device_id = rank % NumOfProcessPerNode(); } void GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(int64_t* machine_id, int64_t* device_id) { *machine_id = Rank(); *device_id = LocalRank(); } int64_t GlobalProcessCtx::Rank() { CHECK_NOTNULL(Singleton::Get()); return Singleton::Get()->rank(); } int64_t GlobalProcessCtx::LocalRank() { char* local_rank_env = std::getenv("LOCAL_RANK"); if (!local_rank_env) { static int64_t local_rank = Rank() % NumOfProcessPerNode(); return local_rank; } CHECK(IsStrInt(local_rank_env)); static int64_t local_rank = std::stol(local_rank_env); return local_rank; } int64_t GlobalProcessCtx::NodeSize() { CHECK_NOTNULL(Singleton::Get()); return Singleton::Get()->node_size(); } int64_t GlobalProcessCtx::ThisNodeId() { CHECK_NOTNULL(Singleton::Get()); return NodeId(Rank()); } int64_t GlobalProcessCtx::NodeId(int64_t process_id) { CHECK_NOTNULL(Singleton::Get()); return process_id / NumOfProcessPerNode(); } int64_t GlobalProcessCtx::NumOfProcessPerNode() { CHECK_NOTNULL(Singleton::Get()); CHECK_EQ(WorldSize() % NodeSize(), 0); return int64_t(WorldSize() / NodeSize()); } bool GlobalProcessCtx::IsThisProcessMaster() { CHECK_NOTNULL(Singleton::Get()); return Singleton::Get()->rank() == 0; } size_t GlobalProcessCtx::WorldSize() { CHECK_NOTNULL(Singleton::Get()); return Singleton::Get()->ctrl_addr().size(); } std::string GlobalProcessCtx::LogDirEntry() { CHECK_NOTNULL(Singleton::Get()); const auto& process_ctx = *Singleton::Get(); const auto& addr = process_ctx.ctrl_addr(process_ctx.rank()); CHECK(addr.has_host()); return addr.host() + "-" + std::to_string(addr.port()) + "-" + std::to_string(process_ctx.rank()); } /* static */ int64_t GlobalProcessCtx::LocalRank(int64_t rank) { return rank % NumOfProcessPerNode(); } } // 
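// Worked example (hypothetical cluster of 8 processes spread over 2 nodes, LOCAL_RANK unset):
// WorldSize() == 8 and NodeSize() == 2 give NumOfProcessPerNode() == 4, so rank 6 has
// LocalRank() == 6 % 4 == 2 and NodeId(6) == 6 / 4 == 1, while
// GetMachineIdAndDeviceId(6, &machine_id, &device_id) yields machine_id == 6 and device_id == 2.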
namespace oneflow ================================================ FILE: oneflow/core/rpc/lib/grpc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef RPC_BACKEND_GRPC #include "oneflow/core/control/ctrl_bootstrap.h" #include "oneflow/core/control/ctrl_server.h" #include "oneflow/core/rpc/include/grpc.h" namespace oneflow { namespace { Maybe GetCtrlPort(const EnvDesc& env_desc) { int port = 0; if (env_desc.has_bootstrap_conf_ctrl_port()) { port = env_desc.bootstrap_conf_ctrl_port(); } return port; } } // namespace Maybe GrpcRpcManager::Bootstrap() { std::shared_ptr ctrl_bootstrap; auto& env_desc = *Singleton::Get(); if (env_desc.has_ctrl_bootstrap_conf()) { ctrl_bootstrap.reset(new RankInfoCtrlBootstrap(env_desc.bootstrap_conf())); } else { ctrl_bootstrap.reset(new HostListCtrlBootstrap(env_desc)); } JUST(ctrl_bootstrap->InitProcessCtx(Singleton::Get()->port(), Singleton::Get())); return Maybe::Ok(); } Maybe GrpcRpcManager::CreateServer() { Singleton::New(JUST(GetCtrlPort(*Singleton::Get()))); return Maybe::Ok(); } Maybe GrpcRpcManager::CreateClient() { auto* client = new GrpcCtrlClient(*Singleton::Get()); Singleton::SetAllocated(client); return Maybe::Ok(); } GrpcRpcManager::~GrpcRpcManager() { auto* grpc_client = dynamic_cast(Singleton::Get()); CHECK_NOTNULL(grpc_client); grpc_client->StopHeartbeat(); OF_ENV_BARRIER(); Singleton::Delete(); CHECK_NOTNULL(Singleton::Get()); Singleton::Delete(); } } // namespace oneflow #endif // RPC_BACKEND_GRPC ================================================ FILE: oneflow/core/rpc/lib/local.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef RPC_BACKEND_LOCAL #include "glog/logging.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/rpc/include/local.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" namespace oneflow { LocalCtrlClient::LocalCtrlClient(const ProcessCtx& process_ctx) { CHECK(process_ctx.ctrl_addr_size() == 1); CHECK(process_ctx.node_size() == 1); } void LocalCtrlClient::Barrier(const std::string& barrier_name) { Barrier(barrier_name, Singleton::Get()->TotalMachineNum()); } void LocalCtrlClient::Barrier(const std::string& barrier_name, int32_t barrier_num) { std::shared_ptr counter; bool is_first = false; { std::unique_lock lck(barrier_counter_mtx_); auto it = barrier_counter_.find(barrier_name); if (it == barrier_counter_.end()) { is_first = true; counter = std::make_shared(barrier_num); CHECK(barrier_counter_.emplace(barrier_name, counter).second); } else { counter = it->second; } } counter->Decrease(); counter->WaitForeverUntilCntEqualZero(); if (is_first) { std::unique_lock lck(barrier_counter_mtx_); CHECK_EQ(barrier_counter_.erase(barrier_name), 1); } } TryLockResult LocalCtrlClient::TryLock(const std::string& name) { std::unique_lock lck(done_names_mtx_); if (done_names_.find(name) != done_names_.end()) { return TryLockResult::kDone; } else if (doing_names_.find(name) != doing_names_.end()) { return TryLockResult::kDoing; } else { doing_names_.insert(name); return TryLockResult::kLocked; } } void LocalCtrlClient::NotifyDone(const std::string& name) { std::unique_lock lck(done_names_mtx_); done_names_.insert(name); CHECK_EQ(doing_names_.erase(name), 1); done_names_cv_.notify_all(); } void LocalCtrlClient::WaitUntilDone(const std::string& name) { std::unique_lock lck(done_names_mtx_); VLOG(3) << "waiting for name: " << name; done_names_cv_.wait(lck); CHECK(done_names_.find(name) != done_names_.end()); } void LocalCtrlClient::PushKV(const std::string& k, std::function VSetter) { std::unique_lock lck(kv_mtx_); VSetter(&kv_[k]); kv_cv_.notify_all(); } void LocalCtrlClient::PushKV(const std::string& k, const std::string& v) { PushKV(k, [&](std::string* o) { *o = v; }); } void LocalCtrlClient::PushKV(const std::string& k, const PbMessage& msg) { PushKV(k, [&](std::string* o) { msg.SerializeToString(o); }); } void LocalCtrlClient::PushMasterKV(const std::string& k, const PbMessage& msg) { PushKV(k, [&](std::string* o) { msg.SerializeToString(o); }); } void LocalCtrlClient::ClearKV(const std::string& k) { std::unique_lock lck(kv_mtx_); kv_.erase(k); } void LocalCtrlClient::ClearMasterKV(const std::string& k) { ClearKV(k); } void LocalCtrlClient::PullKV(const std::string& k, std::function VGetter) { std::unique_lock lck(kv_mtx_); while (true) { auto it = kv_.find(k); if (it == kv_.end()) { VLOG(3) << "waiting for key: " << k; kv_cv_.wait(lck); } else { VGetter(it->second); break; } } } void LocalCtrlClient::PullKV(const std::string& k, std::string* v) { PullKV(k, [&](const std::string& i) { *v = i; }); } void LocalCtrlClient::PullKV(const std::string& k, PbMessage* msg) { PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); }); } void LocalCtrlClient::PullMasterKV(const std::string& k, PbMessage* msg) { PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); }); } void LocalCtrlClient::Clear() { { std::unique_lock lck(done_names_mtx_); done_names_.clear(); done_names_cv_.notify_all(); } { std::unique_lock lck(kv_mtx_); kv_.clear(); kv_cv_.notify_all(); } } int32_t LocalCtrlClient::IncreaseCount(const std::string& k, int32_t v) { 
std::unique_lock lck(counter_mtx_); auto it = counter_.find(k); if (it == counter_.end()) { counter_[k] = 1; return 1; } else { const int32_t new_val = it->second + 1; counter_[k] = new_val; return new_val; } } void LocalCtrlClient::EraseCount(const std::string& k) { std::unique_lock lck(counter_mtx_); counter_.erase(k); } class DryRunCtrlClient : public CtrlClient { public: OF_DISALLOW_COPY_AND_MOVE(DryRunCtrlClient); explicit DryRunCtrlClient(const ProcessCtx& process_ctx) : local_ctrl_client_{std::unique_ptr(new LocalCtrlClient(process_ctx))} { CHECK(process_ctx.ctrl_addr_size() == 1); CHECK(process_ctx.node_size() == 1); } ~DryRunCtrlClient() override = default; void Barrier(const std::string& barrier_name) override { Barrier(barrier_name, Singleton::Get()->TotalMachineNum()); } void Barrier(const std::string& barrier_name, int32_t barrier_num) override { VLOG(3) << "skipping barrier in dry run, barrier name: " << barrier_name << ", barrier num: " << barrier_num; } TryLockResult TryLock(const std::string& name) override { return local_ctrl_client_->TryLock(name); } void NotifyDone(const std::string& name) override { local_ctrl_client_->NotifyDone(name); } void WaitUntilDone(const std::string& name) override { local_ctrl_client_->WaitUntilDone(name); } void PushKV(const std::string& k, std::function VSetter) override { local_ctrl_client_->PushKV(k, VSetter); } void PushKV(const std::string& k, const std::string& v) override { local_ctrl_client_->PushKV(k, v); } void PushKV(const std::string& k, const PbMessage& msg) override { local_ctrl_client_->PushKV(k, msg); } void PushMasterKV(const std::string& k, const PbMessage& msg) override { local_ctrl_client_->PushMasterKV(k, msg); } void ClearKV(const std::string& k) override { local_ctrl_client_->ClearKV(k); } void ClearMasterKV(const std::string& k) override { local_ctrl_client_->ClearMasterKV(k); } void PullKV(const std::string& k, std::function VGetter) override { local_ctrl_client_->PullKV(k, VGetter); } void PullKV(const std::string& k, std::string* v) override { local_ctrl_client_->PullKV(k, v); } void PullKV(const std::string& k, PbMessage* msg) override { local_ctrl_client_->PullKV(k, msg); } void PullMasterKV(const std::string& k, PbMessage* msg) override { local_ctrl_client_->PullMasterKV(k, msg); } void Clear() override { local_ctrl_client_->Clear(); } int32_t IncreaseCount(const std::string& k, int32_t v) override { return local_ctrl_client_->IncreaseCount(k, v); } void EraseCount(const std::string& k) override { local_ctrl_client_->EraseCount(k); } private: std::unique_ptr local_ctrl_client_; }; void SetLocalProcessCtx(oneflow::ProcessCtx* ctx) { Address* addr = ctx->add_ctrl_addr(); addr->set_host("localhost"); ctx->set_rank(0); ctx->set_node_size(1); } Maybe LocalRpcManager::Bootstrap() { SetLocalProcessCtx(Singleton::Get()); return Maybe::Ok(); } Maybe LocalRpcManager::CreateClient() { auto* client = new LocalCtrlClient(*Singleton::Get()); Singleton::SetAllocated(client); return Maybe::Ok(); } LocalRpcManager::~LocalRpcManager() { Singleton::Delete(); } Maybe DryRunRpcManager::Bootstrap() { SetLocalProcessCtx(Singleton::Get()); return Maybe::Ok(); } Maybe DryRunRpcManager::CreateClient() { auto* client = new DryRunCtrlClient(*Singleton::Get()); Singleton::SetAllocated(client); return Maybe::Ok(); } DryRunRpcManager::~DryRunRpcManager() { Singleton::Delete(); } } // namespace oneflow #endif // RPC_BACKEND_LOCAL ================================================ FILE: oneflow/core/summary/event.proto 
================================================ syntax = "proto2"; package oneflow.summary; import "oneflow/core/summary/summary.proto"; message Event { required double wall_time = 1; optional int64 step = 2; oneof what { string file_version = 3; bytes graph_def = 4; Summary summary = 5; bytes meta_graph_def = 9; } } ================================================ FILE: oneflow/core/summary/graph.proto ================================================ syntax = "proto2"; package oneflow.summary; import "oneflow/core/framework/user_op_attr.proto"; message GraphDef { repeated NodeDef node = 1; required int32 version = 2 [deprecated = true]; } message NodeDef { required string name = 1; required string op = 2; repeated string input = 3; optional string device = 4; map attr = 5; } ================================================ FILE: oneflow/core/summary/plugin_data.proto ================================================ syntax = "proto2"; package oneflow.summary; import "google/protobuf/struct.proto"; message HParamsPluginData { required int32 version = 1; oneof data { SessionStartInfo session_start_info = 3; } } message SessionStartInfo { map hparams = 1; required string group_name = 4; required double start_time_secs = 5; map metrics = 6; } ================================================ FILE: oneflow/core/summary/projector.proto ================================================ syntax = "proto2"; package oneflow.summary; message MetaData { enum ProjectorType { EMBEDDING = 0; EXCEPTION = 1; } required ProjectorType type = 1; //Metadata specific information optional string content = 2; } message Tensor { message TensorShape { message Dim { required int64 size = 1; optional string name = 2; } repeated Dim dim = 1; } required string dtype = 1; required TensorShape shape = 2; optional bytes content = 3; } message Sample{ enum SampleType { IMAGE = 0; AUDIO = 1; TEXT = 2; } required string name = 1; required SampleType type = 2; required Tensor X = 3; } message Projector { required string tag = 1; optional int64 step = 2; required double WALL_TIME = 3; required Tensor value = 4; optional Tensor label = 5; } message SummaryProjector { required MetaData metadata = 6; optional Sample sample = 2; repeated Projector projector = 1; } ================================================ FILE: oneflow/core/summary/summary.proto ================================================ syntax = "proto2"; package oneflow.summary; import "oneflow/core/summary/tensor.proto"; message SummaryMetadata { message PluginData { required string plugin_name = 1; optional bytes content = 2; } required PluginData plugin_data = 1; optional string display_name = 2; optional string summary_description = 3; }; message HistogramProto { required double min = 1; required double max = 2; required double num = 3; required double sum = 4; required double sum_squares = 5; repeated double bucket_limit = 6 [packed = true]; repeated double bucket = 7 [packed = true]; }; message Image { required int32 height = 1; required int32 width = 2; required int32 colorspace = 3; required bytes encoded_image_string = 4; }; message Summary { message Value { optional string node_name = 7; required string tag = 1; optional SummaryMetadata metadata = 9; oneof value { float simple_value = 2; bytes obsolete_old_style_histogram = 3; Image image = 4; HistogramProto histo = 5; //Audio audio = 6; TensorProto tensor = 8; } } repeated Value value = 1; } ================================================ FILE: oneflow/core/summary/tensor.proto 
================================================ syntax = "proto2"; package oneflow.summary; message TensorProto { required TensorDataType dtype = 1; required TensorShapeProto tensor_shape = 2; optional int32 version_number = 3; optional bytes tensor_content = 4; repeated float float_val = 5 [packed = true]; repeated double double_val = 6 [packed = true]; repeated int32 int_val = 7 [packed = true]; repeated bytes string_val = 8; repeated int64 int64_val = 9 [packed = true]; repeated bool bool_val = 10 [packed = true]; repeated uint32 uint32_val = 11 [packed = true]; repeated uint64 uint64_val = 12 [packed = true]; repeated int32 half_val = 13 [packed = true]; }; message TensorShapeProto { message Dim { required int64 size = 1; optional string name = 2; }; repeated Dim dim = 2; }; enum TensorDataType { DT_INVALID = 0; DT_FLOAT = 1; DT_DOUBLE = 2; DT_INT32 = 3; DT_UINT8 = 4; DT_INT16 = 5; DT_INT8 = 6; DT_STRING = 7; DT_INT64 = 8; DT_UINT16 = 9; DT_HALF = 10; DT_UINT32 = 11; DT_UINT64 = 12; } ================================================ FILE: oneflow/core/thread/is_main_thread_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/thread/thread_manager.h" namespace oneflow { namespace test { TEST(IsMainThread, IsMainThread) { EXPECT_TRUE(IsMainThread()); auto non_main_thread = std::thread([&]() { EXPECT_FALSE(IsMainThread()); }); non_main_thread.join(); } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/core/thread/thread.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/thread/thread.h" #include "oneflow/core/job/runtime_context.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/lazy/actor/light_actor.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/lazy/stream_context/include/stream_context.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/lazy/stream_context/include/generic_stream_context.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { Thread::Thread(const StreamId& stream_id) : thrd_id_(EncodeStreamIdToInt64(stream_id)) { local_msg_queue_enabled_ = ParseBooleanFromEnv("ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE", true); light_actor_enabled_ = ParseBooleanFromEnv("ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR", true); if (IsClassRegistered(stream_id.device_id().device_type(), stream_id)) { stream_ctx_.reset(NewObj( stream_id.device_id().device_type(), stream_id)); } else { stream_ctx_.reset(new GenericStreamContext(stream_id)); } actor_thread_ = std::thread([this, stream_id]() { LazyMode::Guard guard(true); OF_PROFILER_NAME_THIS_HOST_THREAD("_" + ToString(stream_id.device_id().device_type()) + std::to_string(stream_id.device_id().device_index()) + "_actor"); CHECK_JUST(stream_ctx_->stream()->OnExecutionContextSetup()); PollMsgChannel(); CHECK_JUST(stream_ctx_->stream()->OnExecutionContextTeardown()); }); } Thread::~Thread() { actor_thread_.join(); CHECK(id2task_.empty()); msg_channel_.Close(); } void Thread::AddTask(const TaskProto& task) { std::unique_lock lck(id2task_mtx_); CHECK(id2task_.emplace(task.task_id(), task).second); } void Thread::PollMsgChannel() { while (true) { if (local_msg_queue_.empty()) { CHECK_EQ(msg_channel_.ReceiveMany(&local_msg_queue_), kChannelStatusSuccess); } ActorMsg msg = std::move(local_msg_queue_.front()); local_msg_queue_.pop(); if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.actor_cmd() == ActorCmd::kStopThread) { CHECK(id2actor_ptr_.empty()) << " RuntimeError! 
Thread: " << thrd_id_ << " NOT empty when stop with actor num: " << id2actor_ptr_.size(); break; } else if (msg.actor_cmd() == ActorCmd::kConstructActor) { ConstructActor(msg.dst_actor_id()); continue; } else { // do nothing } } int64_t actor_id = msg.dst_actor_id(); auto actor_it = id2actor_ptr_.find(actor_id); CHECK(actor_it != id2actor_ptr_.end()); int process_msg_ret = actor_it->second.second->ProcessMsg(msg); if (process_msg_ret == 1) { VLOG(3) << "thread " << thrd_id_ << " deconstruct actor " << actor_id; auto job_id_it = id2job_id_.find(actor_id); const int64_t job_id = job_id_it->second; id2job_id_.erase(job_id_it); id2actor_ptr_.erase(actor_it); Singleton::Get()->DecreaseCounter(GetRunningActorCountKeyByJobId(job_id)); } else { CHECK_EQ(process_msg_ret, 0); } } } void Thread::ConstructActor(int64_t actor_id) { std::unique_lock lck(id2task_mtx_); auto task_it = id2task_.find(actor_id); const TaskProto& task = task_it->second; std::unique_ptr actor_ctx = NewActorContext(task, stream_ctx_.get()); CHECK(actor_ctx); std::unique_ptr actor_ptr; if (light_actor_enabled_) { actor_ptr = TryNewLightActor(actor_ctx.get()); } if (!actor_ptr) { actor_ptr = NewActor(actor_ctx.get()); VLOG(3) << "Thread " << thrd_id_ << " construct Actor " << TaskType_Name(task.task_type()) << " " << actor_id; } else { VLOG(3) << "Thread " << thrd_id_ << " construct LightActor " << TaskType_Name(task.task_type()) << " " << actor_id; } CHECK(id2actor_ptr_.emplace(actor_id, std::make_pair(std::move(actor_ctx), std::move(actor_ptr))) .second); CHECK(id2job_id_.emplace(actor_id, task.job_id()).second); id2task_.erase(task_it); Singleton::Get()->DecreaseCounter("constructing_actor_cnt"); } } // namespace oneflow ================================================ FILE: oneflow/core/thread/thread.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_THREAD_THREAD_H_ #define ONEFLOW_CORE_THREAD_THREAD_H_ #include "oneflow/core/lazy/actor/actor_message_bus.h" #include "oneflow/core/common/channel.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/task.pb.h" #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/lazy/actor/actor_context.h" namespace oneflow { class StreamContext; class Thread { public: OF_DISALLOW_COPY_AND_MOVE(Thread); explicit Thread(const StreamId& stream_id); virtual ~Thread(); void AddTask(const TaskProto&); Channel* GetMsgChannelPtr() { return &msg_channel_; } inline void EnqueueActorMsg(const ActorMsg& msg) { if (UseLocalMsgQueue()) { local_msg_queue_.push(msg); } else { msg_channel_.Send(msg); } } template inline void EnqueueActorMsg(InputIt first, InputIt last) { if (UseLocalMsgQueue()) { for (auto it = first; it != last; ++it) { local_msg_queue_.push(*it); } } else { for (auto it = first; it != last; ++it) { msg_channel_.Send(*it); } } } protected: void PollMsgChannel(); private: void ConstructActor(int64_t actor_id); inline bool UseLocalMsgQueue() const { return local_msg_queue_enabled_ && std::this_thread::get_id() == actor_thread_.get_id(); } HashMap id2task_; std::mutex id2task_mtx_; std::thread actor_thread_; Channel msg_channel_; HashMap, std::unique_ptr>> id2actor_ptr_; HashMap id2job_id_; std::queue local_msg_queue_; bool local_msg_queue_enabled_; int64_t thrd_id_; bool light_actor_enabled_; std::unique_ptr stream_ctx_; }; } // namespace oneflow #endif // ONEFLOW_CORE_THREAD_THREAD_H_ ================================================ FILE: oneflow/core/thread/thread_global_id.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/thread/thread_global_id.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/core/common/container_util.h" namespace oneflow { namespace { int64_t* MutThreadLocalUniqueGlobalId() { static thread_local int64_t global_id = kThreadGlobalIdMain; return &global_id; } } // namespace int64_t GetThisThreadGlobalId() { return *MutThreadLocalUniqueGlobalId(); } ThreadGlobalIdGuard::ThreadGlobalIdGuard(int64_t thread_global_id) : old_thread_global_id_(GetThisThreadGlobalId()) { if (old_thread_global_id_ != kThreadGlobalIdMain) { CHECK_EQ(old_thread_global_id_, thread_global_id) << "nested ThreadGlobalIdGuard disabled. old thread_global_id: " << old_thread_global_id_ << ", new thread_global_id:" << thread_global_id; } *MutThreadLocalUniqueGlobalId() = thread_global_id; } ThreadGlobalIdGuard::~ThreadGlobalIdGuard() { *MutThreadLocalUniqueGlobalId() = old_thread_global_id_; } } // namespace oneflow ================================================ FILE: oneflow/core/thread/thread_global_id.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_THREAD_GLOBAL_UNIQUE_ID_H_ #define ONEFLOW_CORE_THREAD_GLOBAL_UNIQUE_ID_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" namespace oneflow { const static int kThreadGlobalIdDefaultWorker = 0; const static int kThreadGlobalIdMain = 7; int64_t GetThisThreadGlobalId(); class ThreadGlobalIdGuard final { public: explicit ThreadGlobalIdGuard(int64_t thread_global_id); ~ThreadGlobalIdGuard(); private: int64_t old_thread_global_id_; }; } // namespace oneflow #endif // ONEFLOW_CORE_THREAD_GLOBAL_UNIQUE_ID_H_ ================================================ FILE: oneflow/core/thread/thread_manager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/control/global_process_ctx.h" namespace oneflow { ThreadMgr::~ThreadMgr() { for (auto& thread_pair : threads_) { ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread); thread_pair.second->GetMsgChannelPtr()->Send(msg); thread_pair.second.reset(); VLOG(1) << " Actor thread: " << thread_pair.first << " finished when process exits."; } } Thread* ThreadMgr::GetThrd(int64_t thrd_id) { auto iter = threads_.find(thrd_id); CHECK(iter != threads_.end()) << " Thread: " << thrd_id << " not found"; return iter->second.get(); } void ThreadMgr::AddThreads(const HashSet& thread_ids) { const int64_t this_rank = GlobalProcessCtx::Rank(); for (int64_t thrd_id : thread_ids) { const auto& it = threads_.find(thrd_id); if (it != threads_.end()) { // NOTE(chengcheng): check thread is not null. CHECK(it->second) << " RuntimeError! Thread: " << thrd_id << " in manager must be NOT null."; VLOG(1) << " Actor thread: " << thrd_id << " reused."; continue; } StreamId stream_id = DecodeStreamIdFromInt64(thrd_id); if (stream_id.rank() != this_rank) { continue; } Thread* thread = new Thread(stream_id); CHECK_NOTNULL(thread); threads_[thrd_id].reset(thread); VLOG(1) << " Actor thread: " << thrd_id << " created."; } } void ThreadMgr::DeleteThreads(const HashSet& thread_ids) { std::unique_lock lock(mutex4del_threads_); for (int64_t thrd_id : thread_ids) { const auto& it = threads_.find(thrd_id); CHECK((it != threads_.end()) && (it->second)) << " RuntimeError! 
Actor thread: " << thrd_id << " non-existent but want to delete"; auto& thread = it->second; ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread); thread->GetMsgChannelPtr()->Send(msg); thread.reset(); VLOG(1) << " Actor thread: " << thrd_id << " finished when the graph is destructed."; threads_.erase(it); } } } // namespace oneflow ================================================ FILE: oneflow/core/thread/thread_manager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_ #define ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_ #include #include "oneflow/core/common/channel.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/thread/thread.h" #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/platform/include/pthread_fork.h" namespace oneflow { class Plan; class ThreadMgr final { public: OF_DISALLOW_COPY_AND_MOVE(ThreadMgr); ThreadMgr() = default; ~ThreadMgr(); void AddThreads(const HashSet& thread_ids); void DeleteThreads(const HashSet& thread_ids); Thread* GetThrd(int64_t thrd_id); private: friend class Singleton; HashMap> threads_; std::mutex mutex4del_threads_; }; // Use limit_thread_num to config the max thread num. // limit_thread_num == -1 means no limit, use the max avaliable thread num of the ThreadPool. // limit_thread_num == 0 means use the current thread. template void MultiThreadLoop(size_t work_num, const DoEachT& DoEachWork, int64_t limit_thread_num = -1) { if (work_num == 0) { return; } if (unlikely(pthread_fork::IsForkedSubProcess() || Singleton::Get() == nullptr || limit_thread_num == 0)) { FOR_RANGE(size_t, i, 0, work_num) { DoEachWork(i); } return; } size_t thread_num = Singleton::Get()->thread_num(); if (limit_thread_num > 0) { thread_num = std::min(thread_num, static_cast(limit_thread_num)); } thread_num = std::min(work_num, thread_num); BalancedSplitter bs(work_num, thread_num); BlockingCounter bc(thread_num); FOR_RANGE(size_t, range_id, 0, thread_num) { Singleton::Get()->AddWork([&bc, &bs, range_id, DoEachWork] { size_t start = bs.At(range_id).begin(); size_t end = bs.At(range_id).end(); FOR_RANGE(size_t, i, start, end) { DoEachWork(i); } bc.Decrease(); }); } // busy loop wait. 
bc.WaitForeverUntilCntEqualZero(); } inline bool* MutIsMainThread() { thread_local bool is_main_thread = false; return &is_main_thread; } inline bool IsMainThread() { return *MutIsMainThread(); } inline void SetIsMainThread(bool is_main_thread) { *MutIsMainThread() = is_main_thread; } COMMAND(SetIsMainThread(true)); } // namespace oneflow #endif // ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_ ================================================ FILE: oneflow/core/thread/thread_pool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/vm/sync_vm_mode_guard.h" namespace oneflow { ThreadPool::ThreadPool(int32_t thread_num) : work_chans_(thread_num), threads_(thread_num), work_cnt_(0) { FOR_RANGE(int32_t, i, 0, thread_num) { Channel>* chan = &(work_chans_.at(i)); threads_[i] = std::thread([chan]() { SyncVmModeGuard guard(SyncVmMode::kEnable); std::function work; while (chan->Receive(&work) == kChannelStatusSuccess) { work(); } }); } } ThreadPool::~ThreadPool() { FOR_RANGE(int32_t, i, 0, work_chans_.size()) { work_chans_.at(i).Close(); threads_.at(i).join(); } } void ThreadPool::AddWork(const std::function& work) { const size_t cur_chan_idx = work_cnt_.fetch_add(1, std::memory_order_relaxed) % work_chans_.size(); work_chans_.at(cur_chan_idx).Send(work); } } // namespace oneflow ================================================ FILE: oneflow/core/thread/thread_pool.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_THREAD_THREAD_POOL_H_ #define ONEFLOW_CORE_THREAD_THREAD_POOL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/channel.h" namespace oneflow { class ThreadPool final { public: OF_DISALLOW_COPY_AND_MOVE(ThreadPool); ThreadPool() = delete; ThreadPool(int32_t thread_num); ~ThreadPool(); int32_t thread_num() const { return threads_.size(); } void AddWork(const std::function& work); private: std::vector>> work_chans_; std::vector threads_; std::atomic work_cnt_; }; } // namespace oneflow #endif // ONEFLOW_CORE_THREAD_THREAD_POOL_H_ ================================================ FILE: oneflow/core/thread/thread_runtime.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_THREAD_THREAD_RUNTIME_H_ #define ONEFLOW_CORE_THREAD_THREAD_RUNTIME_H_ #include #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/thread/thread.h" #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/platform/include/pthread_fork.h" #ifdef WITH_TBB #include #include #include #endif #ifdef WITH_OMP #include #endif namespace oneflow { namespace thread { namespace { using CallableT = std::function; void SeqFor(int64_t begin, int64_t end, const CallableT& func) { func(begin, end); } size_t DivUp(size_t x, size_t y) { return (x + y - 1) / y; } } // namespace class RuntimeBase { public: void ParallelFor(int64_t begin, int64_t end, const CallableT& func, size_t num_threads, size_t grain_size) { if (begin >= end) { return; } if (num_threads == 1) { return SeqFor(begin, end, func); } ParallelForImpl(begin, end, func, num_threads, grain_size); } private: virtual void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads, size_t grain_size) = 0; }; class SeqRuntime final : public RuntimeBase { private: void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads, size_t grain_size) override { return SeqFor(begin, end, func); } }; class OfRuntime final : public RuntimeBase { private: void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads, size_t grain_size) override { if (unlikely(pthread_fork::IsForkedSubProcess()) || Singleton::Get() == nullptr) { return SeqFor(begin, end, func); } const size_t num_elements = end - begin; num_threads = std::min(num_elements, num_threads); BalancedSplitter bs(num_elements, num_threads); BlockingCounter bc(num_threads); FOR_RANGE(size_t, range_id, 0, num_threads) { Singleton::Get()->AddWork([&bc, &bs, range_id, func] { const size_t begin_ = bs.At(range_id).begin(); const size_t end_ = bs.At(range_id).end(); SeqFor(begin_, end_, func); bc.Decrease(); }); } // buzy loop wait. 
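// As in MultiThreadLoop (thread_manager.h), the calling thread blocks here until every chunk
// dispatched to the global ThreadPool has run func over its balanced sub-range.
// Illustrative use, assuming RuntimeFactory::Create yields a shared_ptr-like runtime handle:
//   auto runtime = CHECK_JUST(thread::RuntimeFactory::Create("OF"));
//   runtime->ParallelFor(0, n,
//                        [&](int64_t b, int64_t e) { for (int64_t i = b; i < e; ++i) { out[i] = f(i); } },
//                        /*num_threads=*/4, /*grain_size=*/1 << 10);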
bc.WaitForeverUntilCntEqualZero(); } }; #if WITH_TBB class TbbRuntime final : public RuntimeBase { private: void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads, size_t grain_size) override { tbb::global_control global_thread_limit(tbb::global_control::max_allowed_parallelism, num_threads); const size_t chunk_size = std::max(DivUp((end - begin), num_threads), grain_size); tbb::parallel_for( tbb::blocked_range(begin, end, chunk_size), [&func](const tbb::blocked_range& r) { SeqFor(r.begin(), r.end(), func); }, tbb::static_partitioner{}); } }; #endif #if WITH_OMP class OmpRuntime final : public RuntimeBase { private: void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads, size_t grain_size) override { num_threads = std::min(DivUp((end - begin), grain_size), num_threads); #pragma omp parallel num_threads(num_threads) { int64_t omp_num_thread = omp_get_num_threads(); int64_t chunk_size = DivUp((end - begin), omp_num_thread); int64_t omp_tid = omp_get_thread_num(); int64_t thread_begin_index = begin + omp_tid * chunk_size; int64_t thread_end_index = std::min(end, chunk_size + thread_begin_index); if (thread_begin_index < end) { SeqFor(thread_begin_index, thread_end_index, func); } } } }; #endif } // namespace thread } // namespace oneflow #endif // ONEFLOW_CORE_THREAD_THREAD_RUNTIME_H_ ================================================ FILE: oneflow/core/thread/thread_runtime_factory.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/core/thread/thread_runtime_factory.h" #include "oneflow/core/thread/thread_runtime.h" namespace oneflow { namespace thread { namespace { template std::shared_ptr CreateRuntime() { return std::shared_ptr(std::make_shared()); } } // namespace Maybe RuntimeFactory::Create(RuntimeType type) { if (type == RuntimeType::kOf) { return CreateRuntime(); } const auto format_error_msg = [](const auto& name, const auto& option) { return fmt::format("{} is not enabled, you should compile oneflow with " "`-DCPU_THREADING_RUNTIMES={}`", name, option); }; if (type == RuntimeType::kTbb) { if (!IsTbbEnabled()) { return Error::RuntimeError() << format_error_msg("OneTBB", "TBB"); } #ifdef WITH_TBB return CreateRuntime(); #endif } if (type == RuntimeType::kOmp) { if (!IsOmpEnabled()) { return Error::RuntimeError() << format_error_msg("OpenMP", "OMP"); } #ifdef WITH_OMP return CreateRuntime(); #endif } return CreateRuntime(); } Maybe RuntimeFactory::Create(const std::string& type) { std::unordered_map types{ {"SEQ", RuntimeType::kSeq}, {"OF", RuntimeType::kOf}, {"TBB", RuntimeType::kTbb}, {"OMP", RuntimeType::kOmp}, }; if (types.find(type) == types.end()) { return Error::RuntimeError() << fmt::format("Not supportted cpu threading runtime: {}", type); } return Create(types[type]); } } // namespace thread } // namespace oneflow ================================================ FILE: oneflow/core/thread/thread_runtime_factory.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_THREAD_THREAD_EXECUTOR_FACTORY_H_ #define ONEFLOW_CORE_THREAD_THREAD_EXECUTOR_FACTORY_H_ #include "oneflow/core/thread/thread_runtime.h" namespace oneflow { namespace thread { constexpr bool IsTbbEnabled() { #ifdef WITH_TBB return true; #else return false; #endif } constexpr bool IsOmpEnabled() { #ifdef WITH_OMP return true; #else return false; #endif } enum class RuntimeType { kSeq, kOf, kTbb, kOmp, }; class RuntimeFactory { public: static Maybe Create(RuntimeType type); static Maybe Create(const std::string& type); }; } // namespace thread } // namespace oneflow #endif // ONEFLOW_CORE_THREAD_THREAD_EXECUTOR_FACTORY_H_ ================================================ FILE: oneflow/core/transport/transport.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef __linux__ #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/transport/transport.h" namespace oneflow { Transport::Transport() { comm_net_ = Singleton::Get(); // NOLINT this_machine_id_ = GlobalProcessCtx::Rank(); CHECK(comm_net_ != nullptr); // maybe need new read id for each dst machine id, maybe need 2 * machine num read ids read_id_ = comm_net_->NewActorReadId(); msg_poller_ = std::thread([this]() { PollMsgChannel(); }); } Transport::~Transport() { msg_channel_.Close(); msg_poller_.join(); comm_net_->DeleteActorReadId(read_id_); } void Transport::EnqueueTransportMsg(const TransportMsg& msg) { CHECK_EQ(msg_channel_.Send(msg), kChannelStatusSuccess); } void Transport::PollMsgChannel() { TransportMsg msg; while (true) { ChannelStatus stat = msg_channel_.Receive(&msg); if (stat != kChannelStatusSuccess) { CHECK_EQ(stat, kChannelStatusErrorClosed); break; } switch (msg.type) { case TransportMsgType::kSend: { HandlerAchievedTransportSendMsgFromSrcMachine(msg); break; } case TransportMsgType::kAck: { HandlerAchievedTransportAckMsgFromDstMachine(msg); break; } default: UNIMPLEMENTED(); break; } } } void Transport::HandlerAchievedTransportSendMsgFromSrcMachine(const TransportMsg& msg) { // This machine is dst machine, and receive Send msg from source machine // Maybe we need create TransportStatus, // or we need update TransportStatus and DoRead(). CHECK_EQ(msg.type, TransportMsgType::kSend); CHECK(msg.src_mem_token != nullptr); CHECK(msg.dst_mem_token == nullptr); uint64_t token = msg.token; CHECK(token != -1); // There are two ways to trigger the creation of TransportStatus: // 1. The time (T_A) when the dst machine receives SendMsg from src machine // 2. The time (T_B) when method Receive() called by the dst machine. // Because of T_ A and t_ B are both protected by the lock(status_mutex_), so the creation of // TransportStatus will NOT trigger at the same time. // // T_ A maybe earlier than t_ B, maybe later. // // In either case, the earlier one is responsible for creating the TransportStatus, and the later // one is responsible for checking the TransportStatus and then calling the DoRead() operation. // prepare transport status for this token. // store callback. TransportStatus* stat = nullptr; // if recv_before_send is true, it means the Receive() method has been called before this handler bool recv_before_send = false; { std::unique_lock lock(status_mutex_); auto it = token2status_.find(token); if (it == token2status_.end()) { token2status_.emplace(token, TransportStatus(token)); stat = &(token2status_.at(token)); // init stat // These three members must be initialized in the block protected by lock // to prevent multi-threaded access bugs stat->size = msg.size; stat->src_machine_id = msg.src_machine_id; stat->dst_machine_id = msg.dst_machine_id; } else { recv_before_send = true; stat = &(it->second); CHECK_GE(stat->size, msg.size); // NOTE(chengcheng): Recv size may larger than Send size. stat->size = msg.size; // NOTE(chengcheng): msg.size always is smaller one. 
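// The receiver is allowed to post a buffer larger than what the sender actually sends, so the
// status keeps the sender's size here: this is exactly the byte count that DoRead() later
// registers for stat->dst_ptr and reads over comm_net_.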
} stat->is_send_ready = true; CHECK(stat->src_mem_token == nullptr); // src_mem_token MUST init in the block protected by lock stat->src_mem_token = msg.src_mem_token; } if (recv_before_send) { // it means the local machine has call Transport::Receive() before this handler // check status CHECK_EQ(stat->src_machine_id, msg.src_machine_id); CHECK_EQ(stat->dst_machine_id, msg.dst_machine_id); // the recv is ready, and the send is ready too, so call DoRead(); DoRead(token); } } void Transport::HandlerAchievedTransportAckMsgFromDstMachine(const TransportMsg& msg) { // This machine is src machine, and receive Ack msg from dst machine. The Send/Receive pair of // this token is all done. So we can call callback function and erase TransportStatus. CHECK_EQ(msg.type, TransportMsgType::kAck); CHECK(msg.src_mem_token != nullptr); CHECK(msg.dst_mem_token != nullptr); uint64_t token = msg.token; CHECK(token != -1); std::function callback; // get status from map { std::unique_lock lock(status_mutex_); auto it = token2status_.find(token); CHECK(it != token2status_.end()); TransportStatus* stat = &(it->second); // check msg == stat CHECK_EQ(stat->src_mem_token, msg.src_mem_token); CHECK_EQ(stat->size, msg.size); CHECK_EQ(stat->src_machine_id, msg.src_machine_id); CHECK_EQ(stat->dst_machine_id, msg.dst_machine_id); CHECK(stat->callback != nullptr); callback = stat->callback; // Recovery status token2status_.erase(it); } // UnRegisterMemory comm_net_->UnRegisterMemory(msg.src_mem_token); // Do Send callback callback(); } void Transport::Send(uint64_t token, int64_t dst_machine_id, const void* ptr, std::size_t size, std::function callback) { void* mut_ptr = const_cast(ptr); // handler for send to local machine if (dst_machine_id == this_machine_id_) { SendToLocalMachine(token, mut_ptr, size, callback); return; } // prepare transport status for this token. // store callback. TransportStatus* stat = nullptr; { std::unique_lock lock(status_mutex_); CHECK(token2status_.find(token) == token2status_.end()); // this token must be first add to status token2status_.emplace(token, TransportStatus(token)); stat = &(token2status_.at(token)); } stat->callback = callback; stat->is_send_ready = true; stat->is_recv_ready = false; stat->src_mem_token = comm_net_->RegisterMemory(mut_ptr, size); stat->dst_mem_token = nullptr; stat->size = size; stat->src_machine_id = this_machine_id_; stat->dst_machine_id = dst_machine_id; // Send msg to dst machine TransportMsg msg; msg.token = token; msg.src_machine_id = stat->src_machine_id; msg.dst_machine_id = stat->dst_machine_id; msg.size = size; msg.src_mem_token = stat->src_mem_token; msg.dst_mem_token = stat->dst_mem_token; msg.type = TransportMsgType::kSend; comm_net_->SendTransportMsg(msg.dst_machine_id, msg); } void Transport::Receive(uint64_t token, int64_t src_machine_id, void* ptr, std::size_t max_size, std::function callback) { // handler for receive from local machine if (src_machine_id == this_machine_id_) { RecvFromLocalMachine(token, ptr, max_size, callback); return; } // prepare transport status for this token. // store callback. TransportStatus* stat = nullptr; // if recv_before_send is true, it means the SendMsg has been handled before this Receive called. 
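// (The flag below is spelled send_before_recv: it is set when the kSend message from the source
//  machine was already handled, in which case this Receive() call is the one that triggers
//  DoRead().)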
bool send_before_recv = false; { std::unique_lock lock(status_mutex_); auto it = token2status_.find(token); if (it == token2status_.end()) { token2status_.emplace(token, TransportStatus(token)); stat = &(token2status_.at(token)); // init stat // These three members must be initialized in the block protected by lock // to prevent multi-threaded access bugs stat->size = max_size; stat->src_machine_id = src_machine_id; stat->dst_machine_id = this_machine_id_; } else { send_before_recv = true; stat = &(it->second); } stat->callback = callback; stat->is_recv_ready = true; // NOTE(chengcheng): Store dst_ptr so that we can create dst_mem_token in DoRead() stat->dst_ptr = ptr; } if (send_before_recv) { // it means the source machine has send message to this machine // check status CHECK_LE(stat->size, max_size); // NOTE(chengcheng): Receive max_size may larger than Send size. CHECK_EQ(stat->src_machine_id, src_machine_id); CHECK_EQ(stat->dst_machine_id, this_machine_id_); // the recv is ready, and the send is ready too, so call DoRead(); DoRead(token); } } void Transport::DoRead(uint64_t token) { TransportStatus* stat = nullptr; { std::unique_lock lock(status_mutex_); auto it = token2status_.find(token); CHECK(it != token2status_.end()); stat = &(it->second); // dst_mem_token MUST init in the block protected by lock CHECK(stat->dst_mem_token == nullptr); // NOTE(chengcheng): ONLY at this time, the stat->size is the real size assigned by Send stat->dst_mem_token = comm_net_->RegisterMemory(stat->dst_ptr, stat->size); } CHECK(stat->is_send_ready && stat->is_recv_ready); CHECK(stat->src_mem_token != nullptr); CHECK(stat->dst_mem_token != nullptr); CHECK(stat->src_machine_id != -1); CHECK(stat->dst_machine_id != -1); CHECK(stat->size != -1); CHECK(stat->callback); comm_net_->Read(read_id_, stat->src_machine_id, stat->src_mem_token, stat->dst_mem_token); comm_net_->AddReadCallBack(read_id_, [stat, this]() { // Send ack message to source machine TransportMsg msg; msg.token = stat->token; msg.src_machine_id = stat->src_machine_id; msg.dst_machine_id = stat->dst_machine_id; msg.size = stat->size; msg.src_mem_token = stat->src_mem_token; msg.dst_mem_token = stat->dst_mem_token; msg.type = TransportMsgType::kAck; comm_net_->SendTransportMsg(msg.src_machine_id, msg); // UnRegisterMemory comm_net_->UnRegisterMemory(msg.dst_mem_token); // Do Receive callback stat->callback(); // Recovery status { std::unique_lock lock(status_mutex_); auto it = token2status_.find(stat->token); CHECK(it != token2status_.end()); token2status_.erase(it); } }); } void Transport::SendToLocalMachine(uint64_t token, void* ptr, std::size_t size, std::function callback) { bool need_do_copy = false; bool need_do_callback = false; std::function receive_callback; void* dst_ptr = nullptr; { std::unique_lock lock(local_copy_lock_); auto it = token2local_copy_status_.find(token); if (it == token2local_copy_status_.end()) { // init local copy status token2local_copy_status_.emplace(token, CopyStatusOnLocalMachine(token, ptr, size, callback)); } else { need_do_callback = true; receive_callback = std::move(it->second.callback); dst_ptr = it->second.ptr; CHECK(size <= it->second.size); // NOTE(chengcheng): Recv size may larger than Send size. 
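// Local fast path: both endpoints live on this rank, so no comm_net memory registration or
// network read is involved. Whichever of Send()/Receive() arrives second does the memcpy below
// (only if the two buffers differ) and then fires both callbacks outside the lock.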
if (ptr != dst_ptr) { need_do_copy = true; } // erase local copy status token2local_copy_status_.erase(it); } } if (need_do_copy) { memcpy(dst_ptr, ptr, size); } if (need_do_callback) { callback(); receive_callback(); } } void Transport::RecvFromLocalMachine(uint64_t token, void* ptr, std::size_t max_size, std::function callback) { bool need_do_copy = false; bool need_do_callback = false; std::function send_callback; void* src_ptr = nullptr; std::size_t size = -1; { std::unique_lock lock(local_copy_lock_); auto it = token2local_copy_status_.find(token); if (it == token2local_copy_status_.end()) { // init local copy status token2local_copy_status_.emplace(token, CopyStatusOnLocalMachine(token, ptr, max_size, callback)); } else { need_do_callback = true; send_callback = std::move(it->second.callback); src_ptr = it->second.ptr; size = it->second.size; CHECK(max_size >= size); // NOTE(chengcheng): Recv size may larger than Send size. if (ptr != src_ptr) { need_do_copy = true; } // erase local copy status token2local_copy_status_.erase(it); } } if (need_do_copy) { memcpy(ptr, src_ptr, size); } if (need_do_callback) { callback(); send_callback(); } } } // namespace oneflow #endif // __linux__ ================================================ FILE: oneflow/core/transport/transport.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef __linux__ #ifndef ONEFLOW_CORE_TRANSPORT_TRANSPORT_H_ #define ONEFLOW_CORE_TRANSPORT_TRANSPORT_H_ #include "oneflow/core/common/channel.h" #include "oneflow/core/comm_network/epoll/epoll_comm_network.h" #include "oneflow/core/transport/transport_message.h" namespace oneflow { // Transport supports sending and receiving data between two machines, which is identified by // a unique token. // // Suppose machine A wants to send a piece of data to machine B. Singleton both need // created on machine A and machine B respectively. // // Machin A need call: // Singleton::Get()->Send(token, B, data_ptr_A, data_size_A, callback_after_send); // Machin B need call: // Singleton::Get()->Receive(token, A, data_ptr_B, data_size_B, // callback_after_receive); // // data_size_A <= data_size_B // // Both call: Send()/Receive() will be executed asynchronously. // // When the data transmission is completed, the callbacks of the two machines callback_after_send() // and callback_after_receive() will be executed on their respective machines. // // Transport supports send and receive data on local machine. 
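// A minimal sketch of one transfer following the contract above (token, machine ids and buffer
// sizes are placeholders; data_size_A <= data_size_B must hold):
//
//   // on machine A (sender):
//   Singleton<Transport>::Get()->Send(token, /*dst_machine_id=*/B, data_ptr_A, data_size_A,
//                                     []() { /* data_ptr_A may be reused now */ });
//   // on machine B (receiver):
//   Singleton<Transport>::Get()->Receive(token, /*src_machine_id=*/A, data_ptr_B, data_size_B,
//                                        []() { /* data_ptr_B now holds data_size_A bytes */ });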
// class Transport { public: OF_DISALLOW_COPY_AND_MOVE(Transport); virtual ~Transport(); void Send(uint64_t token, int64_t dst_machine_id, const void* ptr, std::size_t size, std::function callback); void Receive(uint64_t token, int64_t src_machine_id, void* ptr, std::size_t max_size, std::function callback); void EnqueueTransportMsg(const TransportMsg& msg); private: void PollMsgChannel(); void HandlerAchievedTransportSendMsgFromSrcMachine(const TransportMsg& msg); void HandlerAchievedTransportAckMsgFromDstMachine(const TransportMsg& msg); void DoRead(uint64_t token); void SendToLocalMachine(uint64_t token, void* ptr, std::size_t size, std::function callback); void RecvFromLocalMachine(uint64_t token, void* ptr, std::size_t max_size, std::function callback); // TODO(chengcheng) // Singleton has a dependency on Singleton which should be initialized first. friend class Singleton; Transport(); // TransportStatus stores all the information that Transport needs in a Send / Receive process. // // At the sender (source machine), the TransportStatus stores the callback from the Send(). // At the receiver (destination machine), the TransportStatus stores the callback from Receive(). // // In the process of one transmission between two machines, the TransportStatus will be created, // changed and finally deleted by sending and receiving messages for many times. struct TransportStatus { const uint64_t token; std::function callback; bool is_send_ready; bool is_recv_ready; void* src_mem_token; void* dst_mem_token; // NOTE(chengcheng): must store dst_ptr in status when Receive max_size > Send size void* dst_ptr; std::size_t size; int64_t src_machine_id; int64_t dst_machine_id; TransportStatus(uint64_t tk) : token(tk), callback(nullptr), is_send_ready(false), is_recv_ready(false), src_mem_token(nullptr), dst_mem_token(nullptr), size(-1), src_machine_id(-1), dst_machine_id(-1) {} }; // CopyStatusOnLocalMachine is a stored state to support local data transfer. // // This state stores only the most necessary information. // // When Send() is called first, it stores the token, pointer, size and callback of the sender. // In this way, when Receive() is called, copy and two callbacks can be executed. // // When Receive() is called first, it stores the token, pointer, size and callback of the // receiver. In this way, when Send() is called, copy and two callbacks can be executed. struct CopyStatusOnLocalMachine { const uint64_t token; void* ptr; std::size_t size; std::function callback; CopyStatusOnLocalMachine(uint64_t tk, void* p, std::size_t s, std::function cb) : token(tk), ptr(p), size(s), callback(std::move(cb)) {} }; // Store the TransportStatus for each token (Send/Receive pair). // The map token2status_ should be protected by status_mutex_ when you want to change it. std::mutex status_mutex_; HashMap token2status_; // for local copy std::mutex local_copy_lock_; HashMap token2local_copy_status_; int64_t this_machine_id_; void* read_id_; EpollCommNet* comm_net_; Channel msg_channel_; std::thread msg_poller_; }; } // namespace oneflow #endif // ONEFLOW_CORE_TRANSPORT_TRANSPORT_H_ #endif // __linux__ ================================================ FILE: oneflow/core/transport/transport_message.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_TRANSPORT_TRANSPORT_MESSAGE_H_ #define ONEFLOW_CORE_TRANSPORT_TRANSPORT_MESSAGE_H_ #include "oneflow/core/common/platform.h" #include "oneflow/core/common/util.h" #ifdef __linux__ namespace oneflow { enum class TransportMsgType { kInvalid = 0, kSend = 1, // send msg from local to remote transport kAck = 2, // this token transmission task is down }; struct TransportMsg { uint64_t token; void* src_mem_token; void* dst_mem_token; std::size_t size; int64_t src_machine_id; int64_t dst_machine_id; TransportMsgType type; }; } // namespace oneflow #endif // __linux__ #endif // ONEFLOW_CORE_TRANSPORT_TRANSPORT_MESSAGE_H_ ================================================ FILE: oneflow/core/vm/access_blob_arg_cb_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_ACCESS_BLOB_ARG_CB_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_ACCESS_BLOB_ARG_CB_INSTRUCTION_POLICY_H_ #include #include #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/vm/instruction_policy_util.h" #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/framework/tensor_storage.h" #include "oneflow/core/intrusive/list.h" #include "oneflow/core/common/util.h" #include "oneflow/core/vm/op_call_instruction_policy.h" #include "oneflow/core/vm/stream_policy.h" namespace oneflow { namespace vm { class AccessBlobArgCbInstructionPolicy final : public InstructionPolicy { public: AccessBlobArgCbInstructionPolicy( const std::shared_ptr& eager_blob_object, const std::function&)>& callback, const std::string& modifier) : eager_blob_object_(eager_blob_object), callback_(callback), modifier_(modifier), input_dependences_(), output_dependences_() { ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_)); ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); stream_sequential_dependence_ = nullptr; } ~AccessBlobArgCbInstructionPolicy() = default; const std::shared_ptr& eager_blob_object() const { return eager_blob_object_; } const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } void ForEachConstDependence(const std::function& DoEach) const { if (modifier_ == "const") { 
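// "const" means the callback only reads the blob, so its compute dependence object is recorded
// as an input dependence (see the constructor above) and the instruction is ordered after any
// pending writes to that blob.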
DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object())); } } void ForEachMutDependence(const std::function& DoEach) const { if (modifier_ == "mut") { DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object())); } } void ForEachMut2Dependence(const std::function& DoEach) const { if (modifier_ == "mut2") { DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object())); } } std::string DebugName(const Instruction& instruction) const override { return "AccessBlobByCallback"; } Maybe Prepare(Instruction* instruction) override { return Maybe::Ok(); } void Compute(Instruction* instruction) override { StreamPolicy* stream_policy = instruction->mut_stream_policy(); auto rematable_storage = std::dynamic_pointer_cast(eager_blob_object()->tensor_storage()); if (rematable_storage && !rematable_storage->is_in_memory()) { OpCallInstructionPolicy tmp_op = rematable_storage->compute_op(); CHECK_JUST(Recompute(&tmp_op, instruction->mut_stream())); } callback_(stream_policy->stream(), eager_blob_object()); if (rematable_storage && (modifier_ == "mut" || modifier_ == "mut2")) { rematable_storage->set_eviction_disabled(true); } } private: std::shared_ptr eager_blob_object_; std::function&)> callback_; const std::string modifier_; DependenceVector input_dependences_; DependenceVector output_dependences_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_ACCESS_BLOB_ARG_CB_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/allocate_tensor_instruction_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/allocate_tensor_instruction_policy.h" namespace oneflow { namespace vm { AllocateTensorInstructionPolicy::AllocateTensorInstructionPolicy( const EagerBlobObjectList& eager_blob_objects, vm::Stream* vm_stream) : eager_blob_objects_(eager_blob_objects) { stream_sequential_dependence_ = vm_stream->schedule_local_dep_object().get(); for (const auto& eager_blob_object : eager_blob_objects) { output_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object())); } } std::string AllocateTensorInstructionPolicy::DebugName(const vm::Instruction& instruction) const { return "AllocateTensor"; } void AllocateTensorInstructionPolicy::Compute(Instruction* instruction) { Allocator* allocator = instruction->mut_stream()->mut_stream_policy()->mut_allocator(); for (const auto& eager_blob_object : eager_blob_objects_) { CHECK_JUST(eager_blob_object->TryAllocateBlobBodyMemory(allocator)); } } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/allocate_tensor_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_ALLOCATE_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_ALLOCATE_INSTRUCTION_POLICY_H_ #include #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/vm/stream.h" namespace oneflow { namespace vm { class AllocateTensorInstructionPolicy final : public InstructionPolicy { public: AllocateTensorInstructionPolicy(const EagerBlobObjectList& eager_blob_objects, vm::Stream* vm_stream); AllocateTensorInstructionPolicy(const AllocateTensorInstructionPolicy&) = delete; AllocateTensorInstructionPolicy(AllocateTensorInstructionPolicy&&) = delete; ~AllocateTensorInstructionPolicy() override = default; const DependenceVector& input_dependences() const override { static thread_local DependenceVector input_dependences{}; return input_dependences; } const DependenceVector& output_dependences() const override { return output_dependences_; } InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; } std::string DebugName(const vm::Instruction& instruction) const override; private: Maybe Prepare(Instruction* instruction) override { return Maybe::Ok(); } void Compute(Instruction* instruction) override; EagerBlobObjectList eager_blob_objects_; DependenceVector output_dependences_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_ALLOCATE_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/allocator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_ALLOCATOR_H_ #define ONEFLOW_CORE_VM_ALLOCATOR_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/throw.h" namespace oneflow { namespace vm { class Allocator { public: virtual ~Allocator() = default; virtual Maybe Allocate(char** mem_ptr, std::size_t size) = 0; virtual void Deallocate(char* mem_ptr, std::size_t size) = 0; virtual void DeviceReset() = 0; protected: Allocator() = default; }; class UnimplementedAllocator final : public Allocator { public: explicit UnimplementedAllocator(const std::string& debug_str) : debug_str_(debug_str) {} virtual ~UnimplementedAllocator() = default; Maybe Allocate(char** mem_ptr, std::size_t size) override { UNIMPLEMENTED_THEN_RETURN() << debug_str_; } void Deallocate(char* mem_ptr, std::size_t size) override { LOG(FATAL) << debug_str_; } void DeviceReset() override { LOG(FATAL) << debug_str_; } private: std::string debug_str_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_ALLOCATOR_H_ ================================================ FILE: oneflow/core/vm/barrier_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_POLICY_H_ #include "oneflow/core/vm/instruction_policy.h" namespace oneflow { namespace vm { class BarrierInstructionPolicy final : public InstructionPolicy { public: BarrierInstructionPolicy(const std::function& callback) : callback_(callback) { stream_sequential_dependence_ = nullptr; } ~BarrierInstructionPolicy() override = default; const DependenceVector& input_dependences() const override { static DependenceVector dependences{}; return dependences; } const DependenceVector& output_dependences() const override { static DependenceVector dependences{}; return dependences; } bool IsBarrier() const override { return true; } std::string DebugName(const vm::Instruction& instruction) const override { return "Barrier"; } Maybe Prepare(Instruction* instruction) override { return Maybe::Ok(); } void Compute(Instruction* instruction) override { return callback_(); } private: std::function callback_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/bin_allocator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_BIN_ALLOCATOR_H_ #define ONEFLOW_CORE_VM_BIN_ALLOCATOR_H_ #include #include "oneflow/core/vm/allocator.h" #include "oneflow/core/vm/caching_allocator.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { template class BinAllocator final : public CachingAllocator { public: explicit BinAllocator(size_t alignment, std::unique_ptr&& backend); ~BinAllocator(); Maybe Allocate(char** mem_ptr, std::size_t size) override; void Deallocate(char* mem_ptr, std::size_t size) override; void DeviceReset() override { typename ThreadLock::RAIIGuard guard(thread_lock_); backend_->DeviceReset(); } void Shrink() override { typename ThreadLock::RAIIGuard guard(thread_lock_); DeallocateFreeBlockForGarbageCollection(); } private: static constexpr int32_t kInvalidBinNum = -1; static constexpr int32_t kBinNumSize = 20; // Piece is the basic memory unit of BinAllocator. // A Piece is either is free(is_free = true) or in used(is_free = false). // If the Piece is_free = true, the pointer to the piece will be stored in the Bin structure of // the corresponding BinSize. Pieces are stored in a linked list. The Piece's prev and next are // continuous with the current Piece in physical memory. struct Piece { size_t size = 0; char* ptr = nullptr; bool is_free = false; Piece* prev = nullptr; Piece* next = nullptr; int32_t bin_num = kInvalidBinNum; }; // Bin is a structure that stores a set of pieces which is free and has similar size, and // these Pieces are arger than the size of bin // // BinAllocator has a set of Bin structures according to the binary multiple increasing relation, // which is used to quickly index and find the free Piece of appropriate size when Allocate() // // The size of the smallest bin is 512 (512 is the smallest unit Allocated by BinAllocator, // and the memory size of all Allocated will be multiples of 512, 512 is kCudaMemAllocAlignSize). // The size of each Bin is twice the size of the previous Bin, like // BinNum: Bin0, Bin1, Bin2, Bin3, ..., Bin19 // BinSize: 512, 1024, 2048, 4096, ... , 512MB struct Bin { size_t size = 0; struct PieceCmp { bool operator()(const Piece* lhs, const Piece* rhs) const { if (lhs->size != rhs->size) { return lhs->size < rhs->size; } return lhs->ptr < rhs->ptr; } }; std::set pieces; }; // Block is large physical memory that is actually allocated. // There maybe many consecutive disjoint Pieces distributed on the Block memory struct Block { size_t size = 0; char* ptr = nullptr; Piece* start_piece = nullptr; Block(Piece* p) : size(p->size), ptr(p->ptr), start_piece(p) {} }; size_t BinSize4BinNum(int32_t bin_num) { return kCudaMemAllocAlignSize << bin_num; } int32_t BinNum4BinSize(size_t size) { uint64_t value = std::max(size, kCudaMemAllocAlignSize) >> 9; return std::min(kBinNumSize - 1, static_cast(63 ^ __builtin_clzll(value))); } // Try find free Piece which size is larger than aligned_size in Bins. 
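// (Bin lookup example: with kCudaMemAllocAlignSize == 512, BinNum4BinSize(600) == 0 -- the 512B
//  bin -- and BinNum4BinSize(5000) == 3 -- the 4096B bin; FindPiece starts at that bin and walks
//  towards larger bins until it finds a free Piece with piece->size >= aligned_size.)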
// Return nullptr when find failure Piece* FindPiece(size_t aligned_size); // Insert the free Piece to the appropriate Bin which bin size is smaller than piece void InsertPiece2Bin(Piece* piece); // Create new empty Piece or recycle a Piece from recycle_piece_list_ Piece* AllocatePiece(); // Delete a Piece and move in the linked list recycle_piece_list_ void DeallocatePiece(Piece* piece); // Insert a {piece->ptr, piece} pair into the ptr2piece_ map for search Piece when call // Deallocate() void MarkPiece(Piece* piece); // Erase the {piece->ptr, piece} pair from ptr2piece_ because the ptr is useless // Usually call before DeallocatePiece() void UnMarkPiece(Piece* piece); void MergeNeighbourFreePiece(Piece* lhs, Piece* rhs); void RemovePieceFromBin(Piece* piece); Maybe AllocateBlockToExtendTotalMem(size_t aligned_size); bool DeallocateFreeBlockForGarbageCollection(); const size_t alignment_; const std::unique_ptr backend_; ThreadLock thread_lock_; size_t total_memory_bytes_; HashMap mem_ptr2block_; std::vector bins_; std::vector> pieces_; HashMap ptr2piece_; Piece* recycle_piece_list_; }; namespace { inline size_t MemAlignedBytes(size_t bytes, size_t alignment) { return RoundUp(bytes, alignment); } inline bool IsAlignedSize(size_t size, size_t alignment) { return size % alignment == 0; } static const size_t kPieceSplitThreshold = 128 << 20; // 128MiB } // namespace template BinAllocator::BinAllocator(size_t alignment, std::unique_ptr&& backend) : CachingAllocator(), alignment_(alignment), backend_(std::move(backend)), total_memory_bytes_(0), recycle_piece_list_(nullptr) { CHECK_GE(alignment, 1); CHECK_EQ(1 << static_cast(std::log2(alignment)), alignment); bins_.resize(kBinNumSize); for (int i = 0; i < kBinNumSize; ++i) { size_t bin_size = BinSize4BinNum(i); bins_.at(i).size = bin_size; CHECK_EQ(BinNum4BinSize(bin_size), i); CHECK_EQ(BinNum4BinSize(bin_size + alignment_ - 1), i); CHECK_EQ(BinNum4BinSize(bin_size * 2 - 1), i); CHECK_EQ(BinNum4BinSize(bin_size * 2), i == (kBinNumSize - 1) ? 
i : i + 1); } } template BinAllocator::~BinAllocator() { if (total_memory_bytes_ == 0) { CHECK_EQ(mem_ptr2block_.size(), 0); return; } for (auto& pair : mem_ptr2block_) { backend_->Deallocate(pair.first, pair.second.size); } } template void BinAllocator::InsertPiece2Bin(Piece* piece) { CHECK(piece->is_free && piece->bin_num == kInvalidBinNum); int32_t bin_num = BinNum4BinSize(piece->size); piece->bin_num = bin_num; CHECK(bins_.at(bin_num).pieces.insert(piece).second); } template void BinAllocator::RemovePieceFromBin(Piece* piece) { CHECK(piece->is_free); CHECK_NE(piece->bin_num, kInvalidBinNum); CHECK_GT(bins_.at(piece->bin_num).pieces.erase(piece), 0); piece->bin_num = kInvalidBinNum; } template typename BinAllocator::Piece* BinAllocator::AllocatePiece() { if (recycle_piece_list_) { Piece* ret = recycle_piece_list_; recycle_piece_list_ = recycle_piece_list_->next; return ret; } else { pieces_.emplace_back(new Piece()); return pieces_.at(pieces_.size() - 1).get(); } } template void BinAllocator::DeallocatePiece(Piece* piece) { piece->ptr = nullptr; piece->size = 0; piece->bin_num = kInvalidBinNum; piece->is_free = true; piece->prev = nullptr; piece->next = recycle_piece_list_; recycle_piece_list_ = piece; } template void BinAllocator::MarkPiece(Piece* piece) { CHECK_NOTNULL(piece->ptr); CHECK(ptr2piece_.emplace(piece->ptr, piece).second); } template void BinAllocator::UnMarkPiece(Piece* piece) { CHECK_NOTNULL(piece->ptr); auto it = ptr2piece_.find(piece->ptr); CHECK(it != ptr2piece_.end()); ptr2piece_.erase(it); } template typename BinAllocator::Piece* BinAllocator::FindPiece(size_t aligned_size) { CHECK(IsAlignedSize(aligned_size, alignment_)); for (int32_t bin_num = BinNum4BinSize(aligned_size); bin_num < kBinNumSize; ++bin_num) { Bin* bin = &bins_.at(bin_num); for (auto it = bin->pieces.begin(); it != bin->pieces.end(); ++it) { Piece* piece = *it; CHECK(piece->is_free); CHECK_NOTNULL(piece->ptr); CHECK_EQ(piece->bin_num, bin_num); CHECK(IsAlignedSize(piece->size, alignment_)); if (piece->size >= aligned_size) { bin->pieces.erase(it); piece->bin_num = kInvalidBinNum; piece->is_free = false; if (piece->size >= aligned_size * 2 || piece->size - aligned_size >= kPieceSplitThreshold) { Piece* new_piece = AllocatePiece(); new_piece->ptr = piece->ptr + aligned_size; new_piece->size = piece->size - aligned_size; piece->size = aligned_size; Piece* next_p = piece->next; piece->next = new_piece; new_piece->prev = piece; new_piece->next = next_p; if (next_p != nullptr) { next_p->prev = new_piece; } new_piece->is_free = true; new_piece->bin_num = kInvalidBinNum; CHECK(IsAlignedSize(piece->size, alignment_)); CHECK(IsAlignedSize(new_piece->size, alignment_)); InsertPiece2Bin(new_piece); MarkPiece(new_piece); } return piece; } } } return nullptr; } template void BinAllocator::MergeNeighbourFreePiece(Piece* lhs, Piece* rhs) { CHECK(lhs->is_free); CHECK(rhs->is_free); CHECK(lhs->next == rhs); CHECK(lhs == rhs->prev); CHECK(lhs->ptr + lhs->size == rhs->ptr); lhs->size += rhs->size; lhs->next = rhs->next; if (rhs->next != nullptr) { rhs->next->prev = lhs; } UnMarkPiece(rhs); DeallocatePiece(rhs); } template Maybe BinAllocator::AllocateBlockToExtendTotalMem(size_t aligned_size) { CHECK_OR_RETURN(IsAlignedSize(aligned_size, alignment_)) << "not aligned"; size_t allocate_bytes = aligned_size; if (allocate_bytes < 1048576) { // Allocate 2MB if `allocate_bytes` is less than 1MB allocate_bytes = 2097152; } else if (allocate_bytes < 10485760) { // Allocate 20MB if `allocate_bytes` is between 1MB and 10MB 
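// (Small requests are deliberately over-allocated so that later allocations can be carved out of
//  the same Block instead of going back to the backend allocator.)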
allocate_bytes = 20971520; } else { // Round up to 2MB if `allocate_bytes` is larger than 10MB allocate_bytes = RoundUp(allocate_bytes, 2097152); } const size_t final_allocate_bytes = MemAlignedBytes(allocate_bytes, alignment_); if (final_allocate_bytes < aligned_size) { return false; } char* mem_ptr = nullptr; JUST(backend_->Allocate(&mem_ptr, final_allocate_bytes)); if (mem_ptr == nullptr) { return false; } // extend sucess total_memory_bytes_ += final_allocate_bytes; Piece* piece = AllocatePiece(); piece->size = final_allocate_bytes; piece->ptr = mem_ptr; piece->prev = nullptr; piece->next = nullptr; piece->is_free = true; piece->bin_num = kInvalidBinNum; InsertPiece2Bin(piece); MarkPiece(piece); CHECK_OR_RETURN(mem_ptr2block_.emplace(mem_ptr, Block(piece)).second) << "existed mem_ptr"; return true; } template bool BinAllocator::DeallocateFreeBlockForGarbageCollection() { size_t total_free_bytes = 0; HashSet free_block_ptrs; for (const auto& pair : mem_ptr2block_) { const Block& block = pair.second; bool all_free = true; Piece* p = block.start_piece; while (p != nullptr) { if (!(p->is_free)) { all_free = false; break; } p = p->next; } if (all_free) { total_free_bytes += block.size; free_block_ptrs.insert(pair.first); } } total_memory_bytes_ -= total_free_bytes; if (total_free_bytes > 0) { VLOG(3) << "BinAllocator try deallocate free block for garbage collection. " << " deallocate free bytes : " << total_free_bytes; for (char* ptr : free_block_ptrs) { auto it = mem_ptr2block_.find(ptr); CHECK(it != mem_ptr2block_.end()); const Block& block = it->second; // delete all Piece on Block size_t piece_size_sum = 0; Piece* p = block.start_piece; CHECK_EQ(block.ptr, block.start_piece->ptr); CHECK_EQ(block.ptr, ptr); while (p != nullptr) { Piece* next_p = p->next; piece_size_sum += p->size; RemovePieceFromBin(p); UnMarkPiece(p); DeallocatePiece(p); p = next_p; } CHECK_EQ(block.size, piece_size_sum); mem_ptr2block_.erase(it); backend_->Deallocate(ptr, block.size); } } return total_free_bytes > 0; } template Maybe BinAllocator::Allocate(char** mem_ptr, std::size_t size) { typename ThreadLock::RAIIGuard guard(thread_lock_); if (size == 0) { *mem_ptr = nullptr; return Maybe::Ok(); } size_t aligned_size = MemAlignedBytes(size, alignment_); Piece* piece = FindPiece(aligned_size); if (piece == nullptr) { if (JUST(AllocateBlockToExtendTotalMem(aligned_size))) { piece = FindPiece(aligned_size); } } CHECK_NOTNULL_OR_RETURN(piece) << Error::OutOfMemoryError() << "Error! : Out of memory when allocate size : " << size << ".\n The total_memory_bytes allocated by this BinAllocator is : " << total_memory_bytes_; if (piece == nullptr) { backend_->DeviceReset(); LOG(FATAL) << "Error! : Out of memory when allocate size : " << size << ".\n The total_memory_bytes allocated by this BinAllocator is : " << total_memory_bytes_; } CHECK_NOTNULL_OR_RETURN(piece->ptr) << "invalid piece null ptr"; CHECK_OR_RETURN(ptr2piece_.find(piece->ptr) != ptr2piece_.end()) << "piece is not found"; *mem_ptr = piece->ptr; return Maybe::Ok(); } template void BinAllocator::Deallocate(char* mem_ptr, std::size_t size) { if (mem_ptr == nullptr) { return; } typename ThreadLock::RAIIGuard guard(thread_lock_); auto it = ptr2piece_.find(mem_ptr); CHECK(it != ptr2piece_.end()) << "Error! : Try deallocate mem_ptr non-existent. 
mem ptr = " << mem_ptr << " size = " << size; Piece* piece = it->second; CHECK_NOTNULL(piece); CHECK_EQ(piece->ptr, mem_ptr); CHECK(!piece->is_free); piece->is_free = true; Piece* last_piece_insert_to_bin = piece; Piece* next_p = piece->next; Piece* prev_p = piece->prev; if (next_p != nullptr && next_p->is_free) { CHECK_EQ(next_p->ptr, piece->ptr + piece->size); RemovePieceFromBin(next_p); MergeNeighbourFreePiece(piece, next_p); } if (prev_p != nullptr && prev_p->is_free) { CHECK_EQ(piece->ptr, prev_p->ptr + prev_p->size); RemovePieceFromBin(prev_p); MergeNeighbourFreePiece(prev_p, piece); last_piece_insert_to_bin = prev_p; } InsertPiece2Bin(last_piece_insert_to_bin); } } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_BIN_ALLOCATOR_H_ ================================================ FILE: oneflow/core/vm/bin_allocator_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #ifdef WITH_CUDA #include "gtest/gtest.h" #include "oneflow/core/vm/bin_allocator.h" #include "oneflow/core/vm/thread_safe_guard.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace vm { class CudaBackendAllocator final : public CachingAllocator { public: explicit CudaBackendAllocator(int64_t device_id) : device_id_(device_id) {} ~CudaBackendAllocator() override = default; Maybe Allocate(char** mem_ptr, std::size_t size) override; void Deallocate(char* mem_ptr, std::size_t size) override; void DeviceReset() override; void Shrink() override{}; private: int64_t device_id_; }; Maybe CudaBackendAllocator::Allocate(char** mem_ptr, std::size_t size) { cudaSetDevice(device_id_); if (cudaMalloc(mem_ptr, size) != cudaSuccess) { *mem_ptr = nullptr; } return Maybe::Ok(); } void CudaBackendAllocator::Deallocate(char* mem_ptr, std::size_t size) { cudaSetDevice(device_id_); OF_CUDA_CHECK(cudaFree(mem_ptr)); } void CudaBackendAllocator::DeviceReset() { cudaSetDevice(device_id_); // NOTE(chengcheng): In some corner case on ubuntu, cuda memory not released even if OOM. // So there need release all cuda memory allocated by this process before core dump. LOG(WARNING) << "OOM error is detected, process will exit. 
And it will start to reset CUDA " << "device for releasing device memory."; OF_CUDA_CHECK(cudaDeviceReset()); } TEST(CudaBinAllocator, cuda_allocator) { int gpu_num = -1; cudaGetDeviceCount(&gpu_num); if (gpu_num <= 0) { LOG(INFO) << "CudaBinAllocator Test: Skip because of non GPU device."; return; } ASSERT_TRUE(cudaSuccess == cudaSetDevice(0)); size_t free_bytes = -1; size_t total_bytes = -1; const size_t remain_bytes = 50 * 1048576; ASSERT_TRUE(cudaSuccess == cudaMemGetInfo(&free_bytes, &total_bytes)); if (free_bytes <= remain_bytes || free_bytes - remain_bytes < remain_bytes) { LOG(INFO) << "CudaBinAllocator Test: Skip because of allocator mem bytes less than 50MiB in GPU 0"; return; } std::unique_ptr allo(new BinAllocator( kCudaMemAllocAlignSize, std::make_unique(0))); Allocator* a = allo.get(); std::vector ptrs; for (int i = 0; i < 512; ++i) { char* ptr = nullptr; CHECK_JUST(a->Allocate(&ptr, 1)); ASSERT_TRUE(ptr != nullptr); ptrs.emplace_back(ptr); } std::sort(ptrs.begin(), ptrs.end()); for (int i = 0; i < 512; ++i) { if (i > 0) { ASSERT_TRUE(ptrs.at(i) != ptrs.at(i - 1)); ASSERT_TRUE(std::abs(ptrs.at(i) - ptrs.at(i - 1)) >= kCudaMemAllocAlignSize); } a->Deallocate(ptrs.at(i), 1); } ptrs.clear(); for (int i = 0; i < 2048; ++i) { char* ptr = nullptr; CHECK_JUST(a->Allocate(&ptr, 10000)); ASSERT_TRUE(ptr != nullptr); ptrs.emplace_back(ptr); } std::sort(ptrs.begin(), ptrs.end()); for (int i = 0; i < 2048; ++i) { if (i > 0) { ASSERT_TRUE(ptrs.at(i) != ptrs.at(i - 1)); ASSERT_TRUE(std::abs(ptrs.at(i) - ptrs.at(i - 1)) >= kCudaMemAllocAlignSize); } a->Deallocate(ptrs.at(i), 10000); } char* data_ptr_1 = nullptr; CHECK_JUST(a->Allocate(&data_ptr_1, 2048 * sizeof(float))); char* data_ptr_2 = nullptr; CHECK_JUST(a->Allocate(&data_ptr_2, 4096 * sizeof(double))); ASSERT_TRUE(data_ptr_1 != data_ptr_2); if (data_ptr_1 < data_ptr_2) { ASSERT_TRUE(data_ptr_1 + 2048 * sizeof(float) <= data_ptr_2); } else { ASSERT_TRUE(data_ptr_2 + 4096 * sizeof(double) <= data_ptr_1); } a->Deallocate(data_ptr_2, 4096 * sizeof(double)); a->Deallocate(data_ptr_1, 2048 * sizeof(float)); } } // namespace vm } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/core/vm/caching_allocator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_CACHING_ALLOCATOR_H_ #define ONEFLOW_CORE_VM_CACHING_ALLOCATOR_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/vm/allocator.h" namespace oneflow { namespace vm { class CachingAllocator : public Allocator { public: virtual ~CachingAllocator() = default; virtual void Shrink() = 0; protected: CachingAllocator() = default; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_CACHING_ALLOCATOR_H_ ================================================ FILE: oneflow/core/vm/control_stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
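// [Editorial sketch] The BinAllocator/CachingAllocator pair above caches freed
// memory in free lists grouped by size instead of returning every chunk to the
// backend, and only releases cached memory when asked to shrink. Below is a
// minimal stand-alone illustration of that idea over a plain malloc backend;
// the names (ToyCachingAllocator, bins_) are invented for illustration and are
// not part of OneFlow's API.
#include <cstdlib>
#include <map>
#include <vector>

class ToyCachingAllocator {
 public:
  void* Allocate(std::size_t size) {
    size = RoundUpPow2(size);
    auto& bin = bins_[size];
    if (!bin.empty()) {            // reuse a cached chunk of the same size class
      void* p = bin.back();
      bin.pop_back();
      return p;
    }
    return std::malloc(size);      // otherwise extend via the backend
  }
  void Deallocate(void* ptr, std::size_t size) {
    bins_[RoundUpPow2(size)].push_back(ptr);  // cache instead of freeing
  }
  void Shrink() {                  // analogous to garbage-collecting all-free blocks
    for (auto& pair : bins_) {
      for (void* p : pair.second) { std::free(p); }
      pair.second.clear();
    }
  }
 private:
  static std::size_t RoundUpPow2(std::size_t n) {
    std::size_t r = 1;
    while (r < n) { r <<= 1; }
    return r;
  }
  std::map<std::size_t, std::vector<void*>> bins_;  // size class -> free list
};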
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_CONTROL_STREAM_POLICY_H_ #define ONEFLOW_CORE_VM_CONTROL_STREAM_POLICY_H_ #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/naive_instruction_status_querier.h" #include "oneflow/core/vm/stream_policy.h" #include "oneflow/core/vm/vm_object.h" namespace oneflow { namespace vm { class ControlStreamPolicy final : public StreamPolicy { public: ControlStreamPolicy() = default; ~ControlStreamPolicy() = default; vm::Allocator* mut_allocator() override { return (vm::Allocator*)nullptr; } DeviceType device_type() const override { PRINT_BUG_PROMPT_AND_ABORT(); return DeviceType::kInvalidDevice; } ep::Stream* stream() override { PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } void InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const override { static_assert(sizeof(NaiveInstrStatusQuerier) < kInstructionStatusBufferBytes, ""); NaiveInstrStatusQuerier::PlacementNew(status_buffer->mut_buffer()); } void DeleteInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const override { auto* ptr = NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer()); ptr->~NaiveInstrStatusQuerier(); } bool QueryInstructionStatusLaunched(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override { return NaiveInstrStatusQuerier::Cast(status_buffer.buffer())->launched(); } bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override { return NaiveInstrStatusQuerier::Cast(status_buffer.buffer())->done(); } void Run(Instruction* instruction) const override { instruction->Compute(); auto* status_buffer = instruction->mut_status_buffer(); NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer())->set_done(); } bool OnSchedulerThread(StreamType) const override { return true; } bool SupportingTransportInstructions() const override { return false; } }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_CONTROL_STREAM_POLICY_H_ ================================================ FILE: oneflow/core/vm/critical_section_instruction_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
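// [Editorial sketch] ControlStreamPolicy above placement-constructs a small
// querier object directly inside the instruction's raw status buffer and later
// destroys it explicitly (InitInstructionStatus / DeleteInstructionStatus).
// A minimal stand-alone version of that pattern, assuming a 64-byte buffer;
// ToyQuerier is an illustrative name only, and the reinterpret_cast mirrors the
// original MutCast helpers.
#include <atomic>
#include <new>

struct ToyQuerier {
  std::atomic<bool> done{false};
  static ToyQuerier* PlacementNew(char* mem) { return new (mem) ToyQuerier(); }
  static ToyQuerier* MutCast(char* mem) { return reinterpret_cast<ToyQuerier*>(mem); }
};

int main() {
  alignas(ToyQuerier) char status_buffer[64];
  static_assert(sizeof(ToyQuerier) <= sizeof(status_buffer), "status buffer too small");
  ToyQuerier::PlacementNew(status_buffer);                       // InitInstructionStatus
  ToyQuerier::MutCast(status_buffer)->done.store(true);          // Run / set_done
  bool done = ToyQuerier::MutCast(status_buffer)->done.load();   // QueryInstructionStatusDone
  ToyQuerier::MutCast(status_buffer)->~ToyQuerier();             // DeleteInstructionStatus
  return done ? 0 : 1;
}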
*/ #include "oneflow/core/vm/critical_section_instruction_policy.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/just.h" #include "oneflow/core/device/ep_based_event_record.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/vm_object.h" namespace oneflow { namespace vm { void CriticalSectionBeginInstructionPolicy::ForEachDependence( const std::function& DoEach) const { for (const auto& eager_blob_object : *eager_blob_objects_) { DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object())); } } void CriticalSectionBeginInstructionPolicy::ForEachMutDependence( const std::function& DoEach) const { DoEach(vm_stream_->schedule_local_dep_object().get()); } void CriticalSectionBeginInstructionPolicy::FinishInvalidInterfaceEventRecords() { for (const auto& op_name : interfaces_op_names()) { size_t index = CHECK_JUST(MapAt(op_name2interface_index_, op_name)); if (!interfaces_valid().at(index)) { const auto& iter = op_name2end_event_record_->find(op_name); CHECK(iter != op_name2end_event_record_->end()); iter->second->Init(std::make_shared()); } } } void CriticalSectionBeginInstructionPolicy::Finish() { for (const auto& pair : *op_name2end_event_record_) { pair.second->TryInit(std::make_shared()); } } void InputCriticalSectionBeginInstructionPolicy::AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) { int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name)); CHECK(interfaces_valid().at(i)); const auto& eager_blob_object = eager_blob_objects_->at(i); { size_t header_size = blob->blob_desc().ByteSizeOfBlobHeader(); CHECK_EQ(header_size, eager_blob_object->shape().NumAxes() * sizeof(int64_t)); CHECK_EQ(blob->static_shape(), eager_blob_object->shape()); } const auto& end_event_record = op_name2end_event_record_->at(op_name); if (eager_blob_object->dptr() == nullptr) { end_event_record->Init(std::make_shared()); } else { { const size_t body_bytes = blob->ByteSizeOfBlobBody(); CHECK_EQ(eager_blob_object->ByteSizeOfBlobBody(), body_bytes); AutoMemcpy(stream, blob->mut_dptr(), eager_blob_object->dptr(), body_bytes, blob->mem_case(), eager_blob_object->mem_case()); } end_event_record->Init(EpBasedEventRecord::MakeEventRecord(stream)); } } void OutputCriticalSectionBeginInstructionPolicy::AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) { int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name)); CHECK(interfaces_valid().at(i)); auto& eager_blob_object = eager_blob_objects_->at(i); CHECK_EQ(blob->static_shape(), eager_blob_object->shape()); const auto& end_event_record = op_name2end_event_record_->at(op_name); if (eager_blob_object->dptr() == nullptr) { end_event_record->Init(std::make_shared()); } else { { const size_t body_bytes = blob->ByteSizeOfBlobBody(); CHECK_EQ(eager_blob_object->ByteSizeOfBlobBody(), body_bytes); AutoMemcpy(stream, eager_blob_object->mut_dptr(), blob->dptr(), body_bytes, eager_blob_object->mem_case(), blob->mem_case()); } end_event_record->Init(EpBasedEventRecord::MakeEventRecord(stream)); } } void CriticalSectionEndInstructionPolicy::ForEachDependence( const std::function& DoEach) const { DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object())); } void CriticalSectionEndInstructionPolicy::ForEachMutDependence( const std::function& DoEach) const { DoEach(vm_stream_->schedule_local_dep_object().get()); } } // namespace vm } // namespace oneflow 
================================================ FILE: oneflow/core/vm/critical_section_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_CRITICAL_SECTION_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_CRITICAL_SECTION_INSTRUCTION_POLICY_H_ #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/device/event_record.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/vm/critical_section_status_querier.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/vm/instruction_policy_util.h" #include "oneflow/core/vm/stream.h" namespace oneflow { namespace vm { class CriticalSectionBeginInstructionPolicy : public InstructionPolicy, public std::enable_shared_from_this { public: CriticalSectionBeginInstructionPolicy(const CriticalSectionBeginInstructionPolicy&) = delete; CriticalSectionBeginInstructionPolicy(CriticalSectionBeginInstructionPolicy&&) = delete; CriticalSectionBeginInstructionPolicy& operator=(const CriticalSectionBeginInstructionPolicy&) = delete; CriticalSectionBeginInstructionPolicy& operator=(CriticalSectionBeginInstructionPolicy&&) = delete; virtual ~CriticalSectionBeginInstructionPolicy() = default; explicit CriticalSectionBeginInstructionPolicy( const std::shared_ptr& nn_graph, const EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& op_name2end_event_record, Stream* vm_stream) : nn_graph_(nn_graph), eager_blob_objects_(eager_blob_objects), op_name2end_event_record_(op_name2end_event_record), vm_stream_(vm_stream) {} std::string DebugName(const Instruction& instruction) const override { return "CriticalSectionBegin"; } Maybe Prepare(Instruction* instruction) override { return Maybe::Ok(); } void Compute(vm::Instruction* instruction) override { OF_PROFILER_RANGE_GUARD("CriticalSectionBegin"); { const auto& critical_section_instance = MakeCriticalSectionInstance(); const auto& job_name = critical_section_instance->job_name(); auto* buffer_mgr = Singleton>>::Get(); for (int i = 0; i < interfaces_op_names().size(); ++i) { if (interfaces_valid().at(i)) { const std::string& interface_op_name = interfaces_op_names().at(i); const auto& buffer_name = GetInterfaceBufferName(job_name, interface_op_name); buffer_mgr->Get(buffer_name)->Push(critical_section_instance); } } const auto& callback_buffer_name = GetInterfaceCriticalSectionCallbackBufferName(job_name); buffer_mgr->Get(callback_buffer_name)->Push(critical_section_instance); const auto& wait_buffer_name = GetInterfaceCriticalSectionWaitBufferName(job_name); buffer_mgr->Get(wait_buffer_name)->Push(critical_section_instance); } { auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer(); auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data); 
status_querier->SetLaunched(std::make_shared()); } } const std::shared_ptr& nn_graph() const { return nn_graph_; } const EagerBlobObjectListPtr& eager_blob_objects() const { return eager_blob_objects_; } void ForEachDependence(const std::function&) const; void ForEachMutDependence(const std::function&) const; virtual const std::vector& interfaces_op_names() const = 0; virtual const std::vector& interfaces_valid() const = 0; virtual std::string GetInterfaceBufferName(const std::string& job_name, const std::string& op_name) const = 0; virtual std::string GetInterfaceCriticalSectionCallbackBufferName( const std::string& job_name) const = 0; virtual std::string GetInterfaceCriticalSectionWaitBufferName( const std::string& job_name) const = 0; virtual void AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) = 0; void FinishInvalidInterfaceEventRecords(); void Finish(); protected: std::shared_ptr nn_graph_; EagerBlobObjectListPtr eager_blob_objects_; std::shared_ptr>> op_name2end_event_record_; HashMap op_name2interface_index_; Stream* vm_stream_; private: class NaiveCriticalSectionInstance final : public CriticalSectionInstance { public: NaiveCriticalSectionInstance(const std::shared_ptr& critical_section_begin_instruction_policy, const std::string& job_name) : CriticalSectionInstance(), critical_section_begin_instruction_policy_(critical_section_begin_instruction_policy), job_name_(job_name) {} ~NaiveCriticalSectionInstance() override = default; const std::string& job_name() const override { return job_name_; } void AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) const override { critical_section_begin_instruction_policy_->AccessBlobByOpName(stream, blob, op_name); } void Finish() const override { critical_section_begin_instruction_policy_->Finish(); } private: std::shared_ptr critical_section_begin_instruction_policy_; std::string job_name_; }; std::shared_ptr MakeCriticalSectionInstance() { return std::make_shared(this->shared_from_this(), nn_graph_->job_name()); } }; class InputCriticalSectionBeginInstructionPolicy final : public CriticalSectionBeginInstructionPolicy { public: InputCriticalSectionBeginInstructionPolicy( const std::shared_ptr& nn_graph, const EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& op_name2end_event_record, Stream* vm_stream) : CriticalSectionBeginInstructionPolicy(nn_graph, eager_blob_objects, op_name2end_event_record, vm_stream), input_dependences_(), output_dependences_() { ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_)); ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); CHECK_EQ(nn_graph->inputs_op_names().size(), eager_blob_objects->size()); CHECK_EQ(nn_graph->inputs_op_names().size(), nn_graph->inputs_valid().size()); for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) { CHECK(op_name2interface_index_.emplace(nn_graph->inputs_op_names().at(i), i).second); } } ~InputCriticalSectionBeginInstructionPolicy() override = default; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } // for inputs void ForEachConstDependence(const std::function& DoEach) const { ForEachDependence(DoEach); } // for outputs const std::vector& interfaces_op_names() const override { return nn_graph_->inputs_op_names(); } const 
std::vector& interfaces_valid() const override { return nn_graph_->inputs_valid(); } std::string GetInterfaceBufferName(const std::string& job_name, const std::string& op_name) const override { return GetInputBufferName(job_name, op_name); } std::string GetInterfaceCriticalSectionCallbackBufferName( const std::string& job_name) const override { return GetInputCriticalSectionCallbackBufferName(job_name); } std::string GetInterfaceCriticalSectionWaitBufferName( const std::string& job_name) const override { return GetInputCriticalSectionWaitBufferName(job_name); } void AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) override; void ForEachMut2Dependence(const std::function&) const {} private: DependenceVector input_dependences_; DependenceVector output_dependences_; }; class OutputCriticalSectionBeginInstructionPolicy final : public CriticalSectionBeginInstructionPolicy { public: OutputCriticalSectionBeginInstructionPolicy( const std::shared_ptr& nn_graph, const EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& op_name2end_event_record, Stream* vm_stream) : CriticalSectionBeginInstructionPolicy(nn_graph, eager_blob_objects, op_name2end_event_record, vm_stream), input_dependences_(), output_dependences_() { ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_)); ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); CHECK_EQ(nn_graph->outputs_op_names().size(), eager_blob_objects->size()); CHECK_EQ(nn_graph->outputs_op_names().size(), nn_graph->outputs_valid().size()); for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) { CHECK(op_name2interface_index_.emplace(nn_graph->outputs_op_names().at(i), i).second); } } ~OutputCriticalSectionBeginInstructionPolicy() override = default; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } // for inputs void ForEachConstDependence(const std::function&) const {} // for outputs void ForEachMut2Dependence(const std::function& DoEach) const { ForEachDependence(DoEach); } const std::vector& interfaces_op_names() const override { return nn_graph_->outputs_op_names(); } const std::vector& interfaces_valid() const override { return nn_graph_->outputs_valid(); } std::string GetInterfaceBufferName(const std::string& job_name, const std::string& op_name) const override { return GetOutputBufferName(job_name, op_name); } std::string GetInterfaceCriticalSectionCallbackBufferName( const std::string& job_name) const override { return GetOutputCriticalSectionCallbackBufferName(job_name); } std::string GetInterfaceCriticalSectionWaitBufferName( const std::string& job_name) const override { return GetOutputCriticalSectionWaitBufferName(job_name); } void AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) override; private: DependenceVector input_dependences_; DependenceVector output_dependences_; }; class CriticalSectionEndInstructionPolicy : public InstructionPolicy { public: CriticalSectionEndInstructionPolicy(const CriticalSectionEndInstructionPolicy&) = delete; CriticalSectionEndInstructionPolicy(CriticalSectionEndInstructionPolicy&&) = delete; CriticalSectionEndInstructionPolicy& operator=(const CriticalSectionEndInstructionPolicy&) = delete; CriticalSectionEndInstructionPolicy& operator=(CriticalSectionEndInstructionPolicy&&) = 
delete; CriticalSectionEndInstructionPolicy(const std::shared_ptr& eager_blob_object, const std::shared_ptr& event_record, vm::Stream* vm_stream) : eager_blob_object_(eager_blob_object), event_record_(event_record), vm_stream_(vm_stream) {} virtual ~CriticalSectionEndInstructionPolicy() = default; std::string DebugName(const Instruction& instruction) const override { return "CriticalSectionEnd"; } Maybe Prepare(Instruction* instruction) override { return Maybe::Ok(); } void Compute(Instruction* instruction) override { auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer(); auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data); status_querier->SetLaunched(event_record()); } const std::shared_ptr& event_record() const { return event_record_; } void ForEachDependence(const std::function&) const; void ForEachMutDependence(const std::function&) const; private: std::shared_ptr eager_blob_object_; std::shared_ptr event_record_; vm::Stream* vm_stream_; }; class InputCriticalSectionEndInstructionPolicy final : public CriticalSectionEndInstructionPolicy { public: InputCriticalSectionEndInstructionPolicy( const std::shared_ptr& eager_blob_object, const std::shared_ptr& event_record, vm::Stream* vm_stream) : CriticalSectionEndInstructionPolicy(eager_blob_object, event_record, vm_stream), input_dependences_(), output_dependences_() { ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_)); ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); } ~InputCriticalSectionEndInstructionPolicy() override = default; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } void ForEachConstDependence(const std::function& DoEach) const { ForEachDependence(DoEach); } void ForEachMut2Dependence(const std::function&) const {} private: DependenceVector input_dependences_; DependenceVector output_dependences_; }; class OutputCriticalSectionEndInstructionPolicy final : public CriticalSectionEndInstructionPolicy { public: OutputCriticalSectionEndInstructionPolicy( const std::shared_ptr& eager_blob_object, const std::shared_ptr& event_record, vm::Stream* vm_stream) : CriticalSectionEndInstructionPolicy(eager_blob_object, event_record, vm_stream), input_dependences_(), output_dependences_() { ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_)); ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_)); } ~OutputCriticalSectionEndInstructionPolicy() override = default; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } // for inputs void ForEachConstDependence(const std::function&) const {} // for outputs void ForEachMut2Dependence(const std::function& DoEach) const { ForEachDependence(DoEach); } private: DependenceVector input_dependences_; DependenceVector output_dependences_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_CRITICAL_SECTION_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/critical_section_status_querier.h ================================================ /* Copyright 2020 The OneFlow 
Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_ #define ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_ #include #include #include "oneflow/core/device/event_record.h" namespace oneflow { namespace vm { class CriticalSectionStatusQuerier final { public: ~CriticalSectionStatusQuerier() = default; bool QueryLaunched() const { return launched_; } bool QueryDone() const { return launched_ && event_record_->QueryDone(); } void SetLaunched(const std::shared_ptr& event_record) { // No lock needed. This function will be called only one time. // In most cases, errors will be successfully detected by CHECK // even though run in different threads. CHECK(!launched_); event_record_ = event_record; launched_ = true; } static const CriticalSectionStatusQuerier* Cast(const char* mem_ptr) { return reinterpret_cast(mem_ptr); } static CriticalSectionStatusQuerier* MutCast(char* mem_ptr) { return reinterpret_cast(mem_ptr); } static CriticalSectionStatusQuerier* PlacementNew(char* mem_ptr) { return new (mem_ptr) CriticalSectionStatusQuerier(); } private: explicit CriticalSectionStatusQuerier() : launched_(false) {} std::atomic launched_; std::shared_ptr event_record_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_ ================================================ FILE: oneflow/core/vm/critical_section_stream_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/vm/critical_section_stream_policy.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/critical_section_status_querier.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { void CriticalSectionStreamPolicy::InitInstructionStatus( const Stream& stream, InstructionStatusBuffer* status_buffer) const { static_assert(sizeof(CriticalSectionStatusQuerier) < kInstructionStatusBufferBytes, ""); CriticalSectionStatusQuerier::PlacementNew(status_buffer->mut_buffer()); } void CriticalSectionStreamPolicy::DeleteInstructionStatus( const Stream& stream, InstructionStatusBuffer* status_buffer) const { auto* ptr = CriticalSectionStatusQuerier::MutCast(status_buffer->mut_buffer()); ptr->~CriticalSectionStatusQuerier(); } bool CriticalSectionStreamPolicy::QueryInstructionStatusLaunched( const Stream& stream, const InstructionStatusBuffer& status_buffer) const { return CriticalSectionStatusQuerier::Cast(status_buffer.buffer())->QueryLaunched(); } bool CriticalSectionStreamPolicy::QueryInstructionStatusDone( const Stream& stream, const InstructionStatusBuffer& status_buffer) const { return CriticalSectionStatusQuerier::Cast(status_buffer.buffer())->QueryDone(); } void CriticalSectionStreamPolicy::Run(Instruction* instruction) const { instruction->Compute(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/critical_section_stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_POLICY_H_ #define ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_POLICY_H_ #include "oneflow/core/vm/stream_policy.h" #include "oneflow/core/vm/instruction.h" namespace oneflow { namespace vm { class CriticalSectionStreamPolicy final : public StreamPolicy { public: CriticalSectionStreamPolicy() = default; virtual ~CriticalSectionStreamPolicy() = default; vm::Allocator* mut_allocator() override { return (vm::Allocator*)nullptr; } DeviceType device_type() const override { PRINT_BUG_PROMPT_AND_ABORT(); return DeviceType::kInvalidDevice; } ep::Stream* stream() override { PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } void InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const override; void DeleteInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const override; bool QueryInstructionStatusLaunched(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; void Run(Instruction* instruction) const override; bool SupportingTransportInstructions() const override { return false; } }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_POLICY_H_ ================================================ FILE: oneflow/core/vm/ep_backend_allocator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/ep_backend_allocator.h" #include "oneflow/core/ep/include/device.h" namespace oneflow { namespace vm { Maybe EpBackendAllocator::Allocate(char** mem_ptr, std::size_t size) { return ep_device_->Alloc(allocation_options_, reinterpret_cast(mem_ptr), size); } void EpBackendAllocator::Deallocate(char* mem_ptr, std::size_t size) { ep_device_->Free(allocation_options_, mem_ptr); } void EpBackendAllocator::DeviceReset() { if (ep_device_->device_type() != DeviceType::kCPU) { // NOTE(chengcheng): In some corner case on ubuntu, cuda memory not released even if OOM. // So there need release all cuda memory allocated by this process before core dump. LOG(WARNING) << "OOM error is detected, process will exit. And it will start to reset " << "device for releasing device memory."; ep_device_->Reset(); } } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/ep_backend_allocator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_CUDA_BACKEND_ALLOCATOR_H_ #define ONEFLOW_CORE_VM_CUDA_BACKEND_ALLOCATOR_H_ #include #include "oneflow/core/vm/allocator.h" #include "oneflow/core/ep/include/allocation_options.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace ep { class Device; } namespace vm { class EpBackendAllocator final : public Allocator { public: explicit EpBackendAllocator(const std::shared_ptr& ep_device, const ep::AllocationOptions& allocation_options) : ep_device_(ep_device), allocation_options_(allocation_options) {} ~EpBackendAllocator() override = default; Maybe Allocate(char** mem_ptr, std::size_t size) override; void Deallocate(char* mem_ptr, std::size_t size) override; void DeviceReset() override; private: std::shared_ptr ep_device_; ep::AllocationOptions allocation_options_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_CUDA_BACKEND_ALLOCATOR_H_ ================================================ FILE: oneflow/core/vm/ep_backend_host_allocator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/ep_backend_host_allocator.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/include/device.h" namespace oneflow { namespace vm { Maybe EpBackendHostAllocator::Allocate(char** mem_ptr, std::size_t size) { JUST(ep_device_->AllocPinned(allocation_options_, reinterpret_cast(mem_ptr), size)); return Maybe::Ok(); } void EpBackendHostAllocator::Deallocate(char* mem_ptr, std::size_t size) { ep_device_->FreePinned(allocation_options_, mem_ptr); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/ep_backend_host_allocator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
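// [Editorial sketch] EpBackendAllocator and EpBackendHostAllocator above differ
// only in which ep::Device call backs the same small set of virtuals
// (Allocate / Deallocate / DeviceReset). A stripped-down version of that shape,
// with error handling reduced to a bool and a plain malloc backend;
// SimpleBackendAllocator and HostBackendAllocator are illustrative names, not
// OneFlow API.
#include <cstdlib>

class SimpleBackendAllocator {
 public:
  virtual ~SimpleBackendAllocator() = default;
  virtual bool Allocate(char** mem_ptr, std::size_t size) = 0;
  virtual void Deallocate(char* mem_ptr, std::size_t size) = 0;
  virtual void DeviceReset() {}  // only device backends need a real reset
};

class HostBackendAllocator final : public SimpleBackendAllocator {
 public:
  bool Allocate(char** mem_ptr, std::size_t size) override {
    *mem_ptr = static_cast<char*>(std::malloc(size));
    return *mem_ptr != nullptr;
  }
  void Deallocate(char* mem_ptr, std::size_t /*size*/) override { std::free(mem_ptr); }
};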
*/ #ifndef ONEFLOW_CORE_VM_CUDA_BACKEND_HOST_ALLOCATOR_H_ #define ONEFLOW_CORE_VM_CUDA_BACKEND_HOST_ALLOCATOR_H_ #include #include "oneflow/core/vm/allocator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/ep/include/allocation_options.h" namespace oneflow { namespace ep { class Device; } namespace vm { class EpBackendHostAllocator final : public Allocator { public: explicit EpBackendHostAllocator(const std::shared_ptr& ep_device, const ep::AllocationOptions& allocation_options) : ep_device_(ep_device), allocation_options_(allocation_options) {} ~EpBackendHostAllocator() override = default; Maybe Allocate(char** mem_ptr, std::size_t size) override; void Deallocate(char* mem_ptr, std::size_t size) override; void DeviceReset() override {} private: std::shared_ptr ep_device_; ep::AllocationOptions allocation_options_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_CUDA_BACKEND_HOST_ALLOCATOR_H_ ================================================ FILE: oneflow/core/vm/ep_d2h_stream_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/ep_d2h_stream_policy.h" #include #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" #include "oneflow/core/vm/ep_backend_host_allocator.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { namespace { std::unique_ptr> CreateEpBackendHostAllocator(Symbol device) { DeviceType device_type = device->enum_type(); size_t device_index = device->device_id(); auto ep_device = Singleton::Get()->GetDevice(device_type, device_index); auto ep_backend_allocator = std::make_unique(ep_device, ep::AllocationOptions{}); return std::make_unique>(ep::kMaxAlignmentRequirement, std::move(ep_backend_allocator)); } } // namespace EpD2HStreamPolicy::EpD2HStreamPolicy(Symbol device) : EpStreamPolicyBase(device, CreateEpBackendHostAllocator(device)) {} void EpD2HStreamPolicy::InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const { static_assert(sizeof(EpOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, ""); EpStreamPolicyBase* ep_stream_policy_base = dynamic_cast(const_cast(stream).mut_stream_policy()); CHECK_NOTNULL(ep_stream_policy_base); auto* ep_event_provider = ep_stream_policy_base->ep_event_provider(); auto* data_ptr = status_buffer->mut_buffer(); const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent(); EpOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, ep_event); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/ep_d2h_stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_EP_D2H_STREAM_POLICY_H_ #define ONEFLOW_CORE_VM_EP_D2H_STREAM_POLICY_H_ #include "oneflow/core/vm/ep_stream_policy_base.h" namespace oneflow { namespace vm { class EpD2HStreamPolicy final : public EpStreamPolicyBase { public: EpD2HStreamPolicy(Symbol device); ~EpD2HStreamPolicy() override = default; void InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const override; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_EP_D2H_STREAM_POLICY_H_ ================================================ FILE: oneflow/core/vm/ep_event.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/ep_event.h" namespace oneflow { EpEvent::EpEvent(ep::Device* device) : device_(device), event_(nullptr) { device_->SetAsActiveDevice(); event_ = device_->CreateEvent(); // NOLINT } EpEvent::~EpEvent() { device_->SetAsActiveDevice(); device_->DestroyEvent(event_); } bool EpEvent::Query() const { device_->SetAsActiveDevice(); return CHECK_JUST(event_->QueryDone()); } } // namespace oneflow ================================================ FILE: oneflow/core/vm/ep_event.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
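// [Editorial sketch] EpEvent above is a small RAII wrapper: the constructor
// activates the device and creates the event, the destructor activates the
// device again and destroys it, and Query() reports completion. The same shape
// with a fabricated C-style handle API (toy_create/toy_destroy/toy_query) so
// the ownership rules are visible in isolation; none of these names are real.
struct ToyHandle { bool done = true; };
inline ToyHandle* toy_create() { return new ToyHandle(); }
inline void toy_destroy(ToyHandle* h) { delete h; }
inline bool toy_query(const ToyHandle* h) { return h->done; }

class ToyEventRAII {
 public:
  ToyEventRAII() : handle_(toy_create()) {}
  ~ToyEventRAII() { toy_destroy(handle_); }
  ToyEventRAII(const ToyEventRAII&) = delete;
  ToyEventRAII& operator=(const ToyEventRAII&) = delete;
  bool Query() const { return toy_query(handle_); }
 private:
  ToyHandle* handle_;
};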
*/ #ifndef ONEFLOW_CORE_VM_EP_EVENT_H_ #define ONEFLOW_CORE_VM_EP_EVENT_H_ #include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/include/event.h" #include "oneflow/core/common/single_thread_obj_pool.h" namespace oneflow { class EpEvent final { public: EpEvent(const EpEvent&) = delete; EpEvent(EpEvent&&) = delete; EpEvent(ep::Device* device); ~EpEvent(); bool Query() const; ep::Device* mut_device() { return device_; } ep::Event* mut_event() { return event_; } private: ep::Device* device_; ep::Event* event_; }; class EpEventProvider { public: EpEventProvider(const EpEventProvider&) = delete; EpEventProvider(EpEventProvider&&) = delete; virtual ~EpEventProvider() = default; virtual std::shared_ptr GetReusedEpEvent() = 0; protected: EpEventProvider() = default; }; class SingleThreadEpEventProvider final : public EpEventProvider { public: SingleThreadEpEventProvider(const SingleThreadEpEventProvider&) = delete; SingleThreadEpEventProvider(SingleThreadEpEventProvider&&) = delete; explicit SingleThreadEpEventProvider(ep::Device* device) : EpEventProvider(), events_(new SingleThreadPoolType()), device_(device) {} ~SingleThreadEpEventProvider() = default; std::shared_ptr GetReusedEpEvent() override { return events_->make_shared(device_); } private: using SingleThreadPoolType = obj_pool::SingleThreadObjPool; std::shared_ptr events_; ep::Device* device_; }; } // namespace oneflow #endif // ONEFLOW_CORE_VM_EP_EVENT_H_ ================================================ FILE: oneflow/core/vm/ep_optional_event_record_status_querier.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" namespace oneflow { namespace vm { void EpOptionalEventRecordStatusQuerier::SetLaunched(ep::Stream* stream) { CHECK(!launched_); if (ep_event_) { ep_event_->mut_device()->SetAsActiveDevice(); stream->RecordEvent(ep_event_->mut_event()); } launched_ = true; } EpOptionalEventRecordStatusQuerier::~EpOptionalEventRecordStatusQuerier() {} } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/ep_optional_event_record_status_querier.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
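// [Editorial sketch] SingleThreadEpEventProvider above hands out EpEvents from
// an object pool so that events are recycled rather than created and destroyed
// per instruction. A minimal single-threaded pool with the same flavor, using a
// shared_ptr custom deleter to return objects to the free list; ToyEvent and
// ToyEventPool are illustrative names only.
#include <memory>
#include <vector>

struct ToyEvent { bool recorded = false; };

class ToyEventPool {
 public:
  std::shared_ptr<ToyEvent> GetReusedEvent() {
    ToyEvent* raw = nullptr;
    if (!free_list_.empty()) {
      raw = free_list_.back().release();
      free_list_.pop_back();
    } else {
      raw = new ToyEvent();
    }
    raw->recorded = false;  // reset state before handing the event out again
    // The deleter returns the event to the pool; the pool must outlive every
    // shared_ptr it hands out (as the provider owned by the stream does).
    return std::shared_ptr<ToyEvent>(raw, [this](ToyEvent* e) { free_list_.emplace_back(e); });
  }
 private:
  std::vector<std::unique_ptr<ToyEvent>> free_list_;
};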
*/ #ifndef ONEFLOW_CORE_VM_EP_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_ #define ONEFLOW_CORE_VM_EP_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_ #include #include "oneflow/core/vm/ep_event.h" namespace oneflow { namespace vm { class EpOptionalEventRecordStatusQuerier { public: OF_DISALLOW_COPY_AND_MOVE(EpOptionalEventRecordStatusQuerier); ~EpOptionalEventRecordStatusQuerier(); bool launched() const { return launched_; } bool done() const { return launched_ && (ep_event_ == nullptr || ep_event_->Query()); } void SetLaunched(ep::Stream* stream); void reset_ep_event(const std::shared_ptr& ep_event) { ep_event_ = ep_event; } static const EpOptionalEventRecordStatusQuerier* Cast(const char* mem_ptr) { return reinterpret_cast(mem_ptr); } static EpOptionalEventRecordStatusQuerier* MutCast(char* mem_ptr) { return reinterpret_cast(mem_ptr); } static EpOptionalEventRecordStatusQuerier* PlacementNew( char* mem_ptr, const std::shared_ptr& ep_event) { return new (mem_ptr) EpOptionalEventRecordStatusQuerier(ep_event); } private: explicit EpOptionalEventRecordStatusQuerier(const std::shared_ptr& ep_event) : launched_(false), ep_event_(ep_event) {} std::atomic launched_; std::shared_ptr ep_event_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_EP_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_ ================================================ FILE: oneflow/core/vm/ep_record_event_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_EP_RECORD_EVENT_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_EP_RECORD_EVENT_INSTRUCTION_POLICY_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/vm/ep_stream_policy_base.h" #include "oneflow/core/vm/stream.h" namespace oneflow { namespace vm { class EpRecordEventInstructionPolicy final : public InstructionPolicy { public: EpRecordEventInstructionPolicy( small_vector>&& compute_local_dep_objects, const std::string& modifier) : compute_local_dep_objects_(std::move(compute_local_dep_objects)), modifier_(modifier), input_dependences_(), output_dependences_() { ForEachConstDependence([&](auto* dep) { input_dependences_.emplace_back(dep); }); ForEachMutDependence([&](auto* dep) { output_dependences_.emplace_back(dep); }); ForEachMut2Dependence([&](auto* dep) { output_dependences_.emplace_back(dep); }); } ~EpRecordEventInstructionPolicy() override = default; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } template void ForEachConstDependence(const DoEachT& DoEach) const { if (modifier_ == "const") { for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); } } } template void ForEachMutDependence(const DoEachT& DoEach) const { if (modifier_ == "mut") { for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); } } } template void ForEachMut2Dependence(const DoEachT& DoEach) const { if (modifier_ == "mut2") { for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); } } } InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAsTailOnly; } void InitInstructionStatus(Instruction* instruction) override { auto* status_buffer = instruction->mut_status_buffer(); auto* stream = instruction->mut_stream(); instruction->stream_policy().InitInstructionStatus(*stream, status_buffer); EpStreamPolicyBase* ep_stream_policy_base = dynamic_cast(stream->mut_stream_policy()); CHECK_NOTNULL(ep_stream_policy_base); auto* ep_event_provider = ep_stream_policy_base->ep_event_provider(); const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent(); auto* data_ptr = status_buffer->mut_buffer(); EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(ep_event); } Maybe Prepare(vm::Instruction* instruction) override { return Maybe::Ok(); } std::string DebugName(const vm::Instruction&) const override { return "RecordEvent"; } void Compute(vm::Instruction* instruction) override {} private: small_vector> compute_local_dep_objects_; const std::string modifier_; DependenceVector input_dependences_; DependenceVector output_dependences_; }; } // namespace vm struct GetRecordEventInstructionPolicy : public StreamTypeVisitor { template static Maybe VisitCompute(DeviceType device_type, Args&&... args) { return std::shared_ptr( new vm::EpRecordEventInstructionPolicy(std::forward(args)...)); } template static Maybe VisitHost2Device(DeviceType device_type, Args&&... args) { return std::shared_ptr( new vm::EpRecordEventInstructionPolicy(std::forward(args)...)); } template static Maybe VisitDevice2Host(DeviceType device_type, Args&&... args) { return std::shared_ptr( new vm::EpRecordEventInstructionPolicy(std::forward(args)...)); } template static Maybe VisitCcl(DeviceType device_type, Args&&... 
args) { return std::shared_ptr( new vm::EpRecordEventInstructionPolicy(std::forward(args)...)); } template static Maybe VisitBarrier(DeviceType device_type, Args&&... args) { UNIMPLEMENTED_THEN_RETURN() << "EpRecordEvent instruction not supported in Barrier stream"; } template static Maybe VisitCriticalSection(DeviceType device_type, Args&&... args) { UNIMPLEMENTED_THEN_RETURN() << "EpRecordEvent instruction not supported in CriticalSection stream"; } template static Maybe VisitLazyJobLauncher(DeviceType device_type, Args&&... args) { UNIMPLEMENTED_THEN_RETURN() << "EpRecordEvent instruction not supported in LaunchLazyJob stream"; } template static Maybe VisitPinnedCompute(DeviceType device_type, Args&&... args) { return std::shared_ptr( new vm::EpRecordEventInstructionPolicy(std::forward(args)...)); } }; } // namespace oneflow #endif // ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_ ================================================ FILE: oneflow/core/vm/ep_stream_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/ep_stream_policy.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/stream_type.h" #include "oneflow/core/vm/remat/allocator.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" #include "oneflow/core/vm/ep_backend_allocator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/vm/remat/util.h" namespace oneflow { namespace vm { namespace { std::unique_ptr CreateEpBackendDeviceAllocator(Symbol device) { DeviceType device_type = device->enum_type(); size_t device_index = device->device_id(); if (device->rematable()) { return std::make_unique( Singleton::Get()->CreateOrGetAllocator(device_type, device_index)); } else { auto ep_device = Singleton::Get()->GetDevice(device_type, device_index); auto ep_backend_allocator = std::make_unique(ep_device, ep::AllocationOptions{}); return std::make_unique>(ep::kMaxAlignmentRequirement, std::move(ep_backend_allocator)); } } } // namespace EpStreamPolicy::EpStreamPolicy(Symbol device) : EpStreamPolicyBase(device, CreateEpBackendDeviceAllocator(device)) {} void EpStreamPolicy::InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const { static_assert(sizeof(EpOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, ""); auto* data_ptr = status_buffer->mut_buffer(); EpOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, nullptr); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/ep_stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
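// [Editorial sketch] EpRecordEventInstructionPolicy above routes the same list
// of local dependence objects into either its input set ("const") or its output
// set ("mut", "mut2") based on a modifier string, and the scheduler then orders
// instructions by those sets. A tiny stand-alone version of that routing, with
// int standing in for a dependence object; all names here are illustrative.
#include <string>
#include <vector>

struct ToyDependences {
  std::vector<int> inputs;   // read-only dependences
  std::vector<int> outputs;  // mutated dependences
};

inline ToyDependences RouteByModifier(const std::vector<int>& deps, const std::string& modifier) {
  ToyDependences result;
  if (modifier == "const") {
    result.inputs = deps;                   // ForEachConstDependence
  } else if (modifier == "mut" || modifier == "mut2") {
    result.outputs = deps;                  // ForEachMutDependence / ForEachMut2Dependence
  }
  return result;
}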
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_EP_STREAM_POLICY_H_ #define ONEFLOW_CORE_VM_EP_STREAM_POLICY_H_ #include "oneflow/core/vm/ep_stream_policy_base.h" namespace oneflow { namespace vm { class EpStreamPolicy final : public EpStreamPolicyBase { public: EpStreamPolicy(Symbol device); ~EpStreamPolicy() override = default; void InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const override; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_EP_STREAM_POLICY_H_ ================================================ FILE: oneflow/core/vm/ep_stream_policy_base.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/ep_stream_policy_base.h" #include #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" #include "oneflow/core/vm/ep_backend_host_allocator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace vm { void EpStreamPolicyBase::DeleteInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const { auto* ptr = EpOptionalEventRecordStatusQuerier::MutCast(status_buffer->mut_buffer()); ptr->~EpOptionalEventRecordStatusQuerier(); } bool EpStreamPolicyBase::QueryInstructionStatusLaunched( const Stream& stream, const InstructionStatusBuffer& status_buffer) const { return EpOptionalEventRecordStatusQuerier::Cast(status_buffer.buffer())->launched(); } bool EpStreamPolicyBase::QueryInstructionStatusDone( const Stream& stream, const InstructionStatusBuffer& status_buffer) const { return EpOptionalEventRecordStatusQuerier::Cast(status_buffer.buffer())->done(); } void EpStreamPolicyBase::Run(Instruction* instruction) const { OF_PROFILER_RANGE_GUARD("S:" + instruction->DebugName()); auto* stream = instruction->mut_stream(); EpStreamPolicyBase* ep_stream_policy_base = dynamic_cast(stream->mut_stream_policy()); CHECK_NOTNULL(ep_stream_policy_base); auto* ep_device = ep_stream_policy_base->GetOrCreateEpDevice(); ep_device->SetAsActiveDevice(); instruction->Compute(); char* data_ptr = instruction->mut_status_buffer()->mut_buffer(); EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->SetLaunched( stream->mut_stream_policy()->stream()); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/ep_stream_policy_base.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_EP_STREAM_POLICY_BASE_H_ #define ONEFLOW_CORE_VM_EP_STREAM_POLICY_BASE_H_ #include "oneflow/core/vm/stream_policy.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/vm/ep_event.h" #include "oneflow/core/vm/bin_allocator.h" #include "oneflow/core/vm/thread_safe_guard.h" #include "oneflow/core/ep/include/device_manager_registry.h" namespace oneflow { namespace vm { class EpStreamPolicyBase : public StreamPolicy { public: EpStreamPolicyBase(Symbol device, std::unique_ptr&& backend_allocator) : device_(device), ep_event_provier_(), ep_stream_(nullptr), ep_allocator_(std::move(backend_allocator)) {} virtual ~EpStreamPolicyBase() override { if (ep_stream_ != nullptr) { CHECK(ep_device_); ep_device_->DestroyStream(ep_stream_); } } ep::Stream* stream() override { return GetOrCreateEpStream(); } vm::Allocator* mut_allocator() override { return ep_allocator_.get(); } DeviceType device_type() const override { return device_->enum_type(); } EpEventProvider* ep_event_provider() { if (unlikely(ep_event_provier_ == nullptr)) { ep_event_provier_.reset(new SingleThreadEpEventProvider(GetOrCreateEpDevice())); } return ep_event_provier_.get(); } ep::Device* GetOrCreateEpDevice() const { if (unlikely(ep_device_ == nullptr)) { ep_device_ = Singleton::Get()->GetDevice(device_->enum_type(), device_->device_id()); CHECK(ep_device_); } return ep_device_.get(); } bool SupportingTransportInstructions() const override { return true; } void DeleteInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const override; bool QueryInstructionStatusLaunched(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; void Run(Instruction* instruction) const override; private: ep::Stream* GetOrCreateEpStream() const { if (unlikely(ep_stream_ == nullptr)) { ep_stream_ = GetOrCreateEpDevice()->CreateStream(); CHECK(ep_stream_ != nullptr); } return ep_stream_; } Symbol device_; std::unique_ptr ep_event_provier_; mutable std::shared_ptr ep_device_; mutable ep::Stream* ep_stream_; std::unique_ptr ep_allocator_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_EP_STREAM_POLICY_BASE_H_ ================================================ FILE: oneflow/core/vm/event_recorded_ep_stream_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/event_recorded_ep_stream_policy.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" #include "oneflow/core/vm/ep_backend_allocator.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { /*static*/ std::unique_ptr> EventRecordedEpStreamPolicy::CreateEpBackendDeviceAllocator(Symbol device) { DeviceType device_type = device->enum_type(); size_t device_index = device->device_id(); auto ep_device = Singleton::Get()->GetDevice(device_type, device_index); auto ep_backend_allocator = std::make_unique(ep_device, ep::AllocationOptions{}); return std::make_unique>(ep::kMaxAlignmentRequirement, std::move(ep_backend_allocator)); } EventRecordedEpStreamPolicy::EventRecordedEpStreamPolicy(Symbol device, std::unique_ptr&& allocator) : EpStreamPolicyBase(device, std::move(allocator)) {} void EventRecordedEpStreamPolicy::InitInstructionStatus( const Stream& stream, InstructionStatusBuffer* status_buffer) const { static_assert(sizeof(EpOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, ""); EpStreamPolicyBase* ep_stream_policy_base = dynamic_cast(const_cast(stream).mut_stream_policy()); CHECK_NOTNULL(ep_stream_policy_base); auto* ep_event_provider = ep_stream_policy_base->ep_event_provider(); auto* data_ptr = status_buffer->mut_buffer(); const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent(); EpOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, ep_event); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/event_recorded_ep_stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_EVENT_RECORDED_EP_STREAM_POLICY_H_ #define ONEFLOW_CORE_VM_EVENT_RECORDED_EP_STREAM_POLICY_H_ #include "oneflow/core/vm/ep_stream_policy_base.h" namespace oneflow { namespace vm { class EventRecordedEpStreamPolicy final : public EpStreamPolicyBase { public: EventRecordedEpStreamPolicy(Symbol device, std::unique_ptr&& allocator); ~EventRecordedEpStreamPolicy() override = default; void InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const override; static std::unique_ptr> CreateEpBackendDeviceAllocator( Symbol device); }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_EVENT_RECORDED_EP_STREAM_POLICY_H_ ================================================ FILE: oneflow/core/vm/fuse_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_FUSE_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_FUSE_INSTRUCTION_POLICY_H_ #include #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/instruction_policy_util.h" #include "oneflow/core/vm/vm_object.h" namespace oneflow { namespace vm { class FuseInstructionPolicy final : public InstructionPolicy { public: explicit FuseInstructionPolicy(InstructionList&& instruction_list) : instruction_list_(), input_dependences_(), output_dependences_() { instruction_list.MoveTo(&instruction_list_); auto ReadOnlyDepsInserter = InstructionPolicyUtil::SetInserter(&input_dependences_); auto WritableDepsInserter = InstructionPolicyUtil::SetInserter(&output_dependences_); auto* last_instruction = instruction_list_.Last(); INTRUSIVE_UNSAFE_FOR_EACH_PTR(instruction, &instruction_list_) { if (instruction == last_instruction) { CHECK(instruction->instruction_policy().fuse_type() == kEnableInstructionFuseAsTailOnly || instruction->instruction_policy().fuse_type() == kEnableInstructionFuseAtAnyPosition); } else { CHECK(instruction->instruction_policy().fuse_type() == kEnableInstructionFuseAtAnyPosition); } if (unlikely(stream_sequential_dependence_ == nullptr)) { stream_sequential_dependence_ = instruction->instruction_policy().stream_sequential_dependence(); } else { CHECK_EQ(stream_sequential_dependence_, instruction->instruction_policy().stream_sequential_dependence()); } for (auto* dep : instruction->instruction_policy().input_dependences()) { ReadOnlyDepsInserter(dep); } for (auto* dep : instruction->instruction_policy().output_dependences()) { WritableDepsInserter(dep); } } } ~FuseInstructionPolicy() override = default; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } InstructionList* mut_instruction_list() { return &instruction_list_; } private: Maybe Prepare(Instruction* instruction) override { INTRUSIVE_UNSAFE_FOR_EACH_PTR(instruction, mut_instruction_list()) { JUST(instruction->Prepare()); } return Maybe::Ok(); } void Compute(Instruction* instruction) override { OF_PROFILER_RANGE_GUARD("F:" + instruction->DebugName()); INTRUSIVE_UNSAFE_FOR_EACH_PTR(instruction, mut_instruction_list()) { instruction->Compute(); } } void InitInstructionStatus(Instruction* instruction) override { auto* last_instruction = CHECK_NOTNULL(mut_instruction_list()->Last()); last_instruction->mut_instruction_policy()->InitInstructionStatusIf(instruction); } std::string DebugName(const Instruction&) const override { return "Fuse"; } InstructionList instruction_list_; DependenceVector input_dependences_; DependenceVector output_dependences_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_FUSE_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/global_sync_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_GLOBAL_SYNC_INSTRUCTION_POLICY_H_
#define ONEFLOW_CORE_VM_GLOBAL_SYNC_INSTRUCTION_POLICY_H_

#include "oneflow/core/rpc/include/base.h"
#include "oneflow/core/vm/instruction_policy.h"

namespace oneflow {
namespace vm {

class GlobalSyncInstructionPolicy final : public InstructionPolicy {
 public:
  GlobalSyncInstructionPolicy() = default;
  ~GlobalSyncInstructionPolicy() override = default;

  const DependenceVector& input_dependences() const override {
    static DependenceVector dependences{};
    return dependences;
  }
  const DependenceVector& output_dependences() const override {
    static DependenceVector dependences{};
    return dependences;
  }
  bool IsBarrier() const override { return true; }
  std::string DebugName(const vm::Instruction& instruction) const override { return "GlobalSync"; }
  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }
  void Compute(Instruction* instruction) override { OF_ENV_BARRIER(); }
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_GLOBAL_SYNC_INSTRUCTION_POLICY_H_


================================================
FILE: oneflow/core/vm/instruction.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/vm/instruction.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/virtual_machine_engine.h" #include "oneflow/core/framework/stream_get_stream_type_name.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/extension/stack/foreign_stack_getter.h" #include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace vm { std::string Instruction::DebugName() const { std::string instr_name = instruction_policy().DebugName(*this); return instr_name + ":s_" + GetStreamTypeName::Visit(stream().stream_type()); } void Instruction::__Init__(Stream* stream, std::shared_ptr&& instruction_policy) { stream_ = stream; instruction_policy_ = std::move(instruction_policy); if (IsMainThread()) { if (auto* stack_getter = Singleton::Get()) { foreign_frame_ = stack_getter->GetCurrentFrame(); } } } void Instruction::InitStatus() { instruction_policy_->InitInstructionStatusIf(this); } Maybe Instruction::Prepare() { ForeignFrameThreadLocalGuard guard(foreign_frame_); return instruction_policy_->PrepareIf(this); } void Instruction::Compute() { ForeignFrameThreadLocalGuard guard(foreign_frame_); instruction_policy_->ComputeIf(this); } void Instruction::DeleteStatusAndCheckEdges() { OF_PROFILER_RANGE_GUARD("Instruction::DeleteStatusAndCheckEdges"); instruction_policy_->DeleteInstructionStatusIf(this); INTRUSIVE_FOR_EACH_PTR(edge, mut_in_edges()) { Instruction* in_instruction = edge->mut_src_instruction(); LOG(FATAL) << "unerased edge: " << in_instruction->DebugName() << " -> " << this->DebugName(); } INTRUSIVE_FOR_EACH_PTR(edge, mut_out_edges()) { Instruction* out_instruction = edge->mut_dst_instruction(); LOG(FATAL) << "unerased edge: " << this->DebugName() << " -> " << out_instruction->DebugName(); } } bool Instruction::Launched() const { return stream_policy().QueryInstructionStatusLaunched(stream(), status_buffer()); } bool Instruction::Done() const { return stream_policy().QueryInstructionStatusDone(stream(), status_buffer()); } StreamPolicy* Instruction::mut_stream_policy() { return mut_stream()->mut_stream_policy(); } const StreamPolicy& Instruction::stream_policy() const { return stream().stream_policy(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/instruction.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_VPU_INSTRUCTION__H_ #define ONEFLOW_CORE_VM_VPU_INSTRUCTION__H_ #include #include #include #include "oneflow/core/common/symbol.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/intrusive/object_pool.h" #include "oneflow/core/vm/vm_object.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/vm/stream_policy.h" #include "oneflow/extension/stack/foreign_stack_getter.h" namespace oneflow { class Stream; namespace vm { static const int kInstructionStatusBufferBytes = 64; class InstructionStatusBuffer final { public: InstructionStatusBuffer() = default; ~InstructionStatusBuffer() = default; const char* buffer() const { return &buffer_[0]; } char* mut_buffer() { return &buffer_[0]; } private: char buffer_[kInstructionStatusBufferBytes]; }; class Instruction; class InstructionEdge final : public intrusive::Base, public intrusive::EnableObjectPool { public: InstructionEdge() : intrusive_ref_(), src_instruction_(), dst_instruction_(), in_edge_hook_(), out_edge_hook_() {} void __Init__() { clear_src_instruction(); clear_dst_instruction(); } // Getters bool has_src_instruction() const { return src_instruction_ != nullptr; } bool has_dst_instruction() const { return dst_instruction_ != nullptr; } const Instruction& src_instruction() const { return *src_instruction_; } const Instruction& dst_instruction() const { return *dst_instruction_; } // Setters void set_src_instruction(Instruction* val) { src_instruction_ = val; } void set_dst_instruction(Instruction* val) { dst_instruction_ = val; } void clear_src_instruction() { src_instruction_ = nullptr; } void clear_dst_instruction() { dst_instruction_ = nullptr; } Instruction* mut_src_instruction() { return src_instruction_; } Instruction* mut_dst_instruction() { return dst_instruction_; } // methods void __Init__(Instruction* src_instruction, Instruction* dst_instruction) { __Init__(); set_src_instruction(src_instruction); set_dst_instruction(dst_instruction); } intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } private: intrusive::Ref intrusive_ref_; // fields Instruction* src_instruction_; Instruction* dst_instruction_; public: // list hooks intrusive::ListHook in_edge_hook_; intrusive::ListHook out_edge_hook_; }; class Stream; class Instruction final : public intrusive::Base { public: // types using InEdgeList = intrusive::List; using OutEdgeList = intrusive::List; using DependenceAccessList = intrusive::List; void __Init__(Stream* stream, std::shared_ptr&& instruction_policy); // Getters const Stream& stream() const { return *stream_; } const InstructionStatusBuffer& status_buffer() const { return status_buffer_; } const intrusive::ListHook& main_instruction_hook() const { return main_instruction_hook_; } const InstructionPolicy& instruction_policy() const { return *instruction_policy_; } std::string DebugName() const; const intrusive::ListHook& dispatched_instruction_hook() const { return dispatched_instruction_hook_; } const intrusive::ListHook& lively_instruction_hook() const { return lively_instruction_hook_; } const intrusive::ListHook& worker_pending_instruction_hook() const { return worker_pending_instruction_hook_; } const intrusive::ListHook& barrier_instruction_hook() const { return barrier_instruction_hook_; } const InEdgeList& in_edges() const { return in_edges_; } const OutEdgeList& out_edges() const { return out_edges_; } const DependenceAccessList& access_list() const { return access_list_; } Maybe Prepare(); void Compute(); // Setters Stream* 
mut_stream() { return stream_; } InstructionStatusBuffer* mut_status_buffer() { return &status_buffer_; } InstructionPolicy* mut_instruction_policy() { return instruction_policy_.get(); } InEdgeList* mut_in_edges() { return &in_edges_; } OutEdgeList* mut_out_edges() { return &out_edges_; } DependenceAccessList* mut_access_list() { return &access_list_; } // methods void InitStatus(); void DeleteStatusAndCheckEdges(); bool Launched() const; bool Done() const; StreamPolicy* mut_stream_policy(); const StreamPolicy& stream_policy() const; std::shared_ptr foreign_frame() const { return foreign_frame_; } intrusive::Ref::RefCntType ref_cnt() const { return intrusive_ref_.ref_cnt(); } // used for instructions building, pending to scheduler, constructing DAG, pending to callback // thread and so on. // lifetime of barrier instructions: // // |<-----main_instruction_hook_----->| // |<-----------lively_instruction_hook_---------------->| // |<---------barrier_instruction_hook_--------->| // // // lifetime of non-barrier instructions: // // |<-----main_instruction_hook_----->| // |<-----------lively_instruction_hook_---------------->| // |<-------dispatched_instruction_hook_-------->| // |<--worker_pending_instruction_hook_-->| // // intrusive::ListHook main_instruction_hook_; // dispatched to Stream intrusive::ListHook dispatched_instruction_hook_; // valid during vm processing intrusive::ListHook lively_instruction_hook_; // pending to ThreadCtx intrusive::ListHook worker_pending_instruction_hook_; // for barrier instruction intrusive::ListHook barrier_instruction_hook_; private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } Instruction() : main_instruction_hook_(), dispatched_instruction_hook_(), lively_instruction_hook_(), worker_pending_instruction_hook_(), barrier_instruction_hook_(), access_list_(), in_edges_(), out_edges_(), intrusive_ref_(), stream_(), instruction_policy_(), status_buffer_() {} // lists DependenceAccessList access_list_; InEdgeList in_edges_; OutEdgeList out_edges_; // fields intrusive::Ref intrusive_ref_; Stream* stream_; std::shared_ptr instruction_policy_; InstructionStatusBuffer status_buffer_; std::shared_ptr foreign_frame_; }; using InstructionList = intrusive::List; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_VPU_INSTRUCTION__H_ ================================================ FILE: oneflow/core/vm/instruction_fuse_type.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_INSTRUCTION_FUSE_TYPE_H_ #define ONEFLOW_CORE_VM_INSTRUCTION_FUSE_TYPE_H_ namespace oneflow { namespace vm { enum InstructionFuseType { kInvalidInstructionFuseType = 0, kDisableInstructionFuse, kEnableInstructionFuseAtAnyPosition, kEnableInstructionFuseAsTailOnly, }; } } // namespace oneflow #endif // ONEFLOW_CORE_VM_INSTRUCTION_FUSE_TYPE_H_ ================================================ FILE: oneflow/core/vm/instruction_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { void InstructionPolicy::InitInstructionStatus(Instruction* instruction) { instruction->stream_policy().InitInstructionStatus(instruction->stream(), instruction->mut_status_buffer()); } void InstructionPolicy::DeleteInstructionStatus(Instruction* instruction) { instruction->stream_policy().DeleteInstructionStatus(instruction->stream(), instruction->mut_status_buffer()); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_INSTRUCTION_POLICY_H_ #include #include #include #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/vm/instruction_fuse_type.h" #include "oneflow/core/vm/vm_object.h" namespace oneflow { namespace vm { class EagerBlobObject; class Stream; class InstructionPolicy { public: virtual ~InstructionPolicy() = default; // Same stream. 
virtual bool Prescheduleable(const vm::Stream* src, const vm::Stream* dst) const { return src == dst; } virtual const DependenceVector& input_dependences() const = 0; virtual const DependenceVector& output_dependences() const = 0; virtual Dependence* stream_sequential_dependence() const { return stream_sequential_dependence_; } virtual bool IsBarrier() const { return false; } virtual InstructionFuseType fuse_type() const { return kDisableInstructionFuse; } virtual std::string DebugName(const Instruction&) const = 0; Maybe PrepareIf(Instruction* instruction) { OF_PROFILER_RANGE_GUARD(std::string("Prepare:") + DebugName(*instruction)); return Prepare(instruction); } void ComputeIf(Instruction* instruction) { OF_PROFILER_RANGE_GUARD(std::string("Compute:") + DebugName(*instruction)); Compute(instruction); } void InitInstructionStatusIf(Instruction* instruction) { InitInstructionStatus(instruction); } void DeleteInstructionStatusIf(Instruction* instruction) { DeleteInstructionStatus(instruction); } protected: InstructionPolicy() : stream_sequential_dependence_(nullptr) {} Dependence* stream_sequential_dependence_; private: // Usually for Allocating and deallocating tensors. virtual Maybe Prepare(Instruction* instruction) = 0; virtual void Compute(Instruction* instruction) = 0; virtual void InitInstructionStatus(Instruction* instruction); virtual void DeleteInstructionStatus(Instruction* instruction); }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/instruction_policy_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_INSTRUCTION_POLICY_UTIL_H_ #define ONEFLOW_CORE_VM_INSTRUCTION_POLICY_UTIL_H_ #include #include #include "oneflow/core/vm/vm_object.h" namespace oneflow { namespace vm { struct InstructionPolicyUtil { static std::function SetInserter(DependenceVector* dependences) { auto existed = std::make_shared>(dependences->begin(), dependences->end()); return [dependences, existed](Dependence* object) { if (existed->insert(object).second) { dependences->push_back(object); } }; } }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_INSTRUCTION_POLICY_UTIL_H_ ================================================ FILE: oneflow/core/vm/lazy_job_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_LAZY_JOB_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_LAZY_JOB_INSTRUCTION_POLICY_H_ #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/common/of_unused.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/job/job_instance.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/vm/instruction_policy_util.h" #include "oneflow/core/vm/naive_instruction_status_querier.h" #include "oneflow/core/vm/lazy_job_stream_policy.h" #include "oneflow/core/vm/virtual_machine.h" #include namespace oneflow { class LazyJobInstance final : public JobInstance { public: LazyJobInstance(const LazyJobInstance&) = delete; LazyJobInstance(LazyJobInstance&&) = delete; ~LazyJobInstance() override = default; LazyJobInstance(const std::string& job_name, const std::function& finish_cb) : job_name_(job_name), finish_cb_(finish_cb) {} std::string job_name() const override { return job_name_; } void Finish() const override { finish_cb_(); } private: const std::string job_name_; const std::function finish_cb_; }; namespace vm { class LaunchLazyJobInstructionPolicy final : public InstructionPolicy { // NOLINT public: LaunchLazyJobInstructionPolicy(const LaunchLazyJobInstructionPolicy&) = delete; LaunchLazyJobInstructionPolicy(LaunchLazyJobInstructionPolicy&&) = delete; ~LaunchLazyJobInstructionPolicy() = default; LaunchLazyJobInstructionPolicy(const std::shared_ptr& nn_graph, const EagerBlobObjectListPtr& param_blob_objects) : nn_graph_(nn_graph), param_blob_objects_(param_blob_objects), input_dependences_(), output_dependences_() { robin_hood::unordered_flat_map unique_map; ForEachConstDependence([&](Dependence* compute) { if (unique_map.emplace(compute, true).second) { input_dependences_.emplace_back(compute); } }); unique_map.clear(); output_dependences_.reserve(param_blob_objects_->size()); unique_map.reserve(param_blob_objects_->size()); ForEachMutDependence([&](Dependence* compute) { if (unique_map.emplace(compute, true).second) { output_dependences_.emplace_back(compute); } }); ForEachMut2Dependence([&](Dependence* compute) { if (unique_map.emplace(compute, true).second) { output_dependences_.emplace_back(compute); } }); } const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } void ForEachConstDependence(const std::function&) const {} void ForEachMutDependence(const std::function& DoEach) const { for (const auto& eager_blob_object : *param_blob_objects_) { DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object())); } DoEach(CHECK_JUST(SingletonMaybe()) ->FindOrCreateTransportLocalDepObject() .Mutable()); } void ForEachMut2Dependence(const std::function&) const {} std::string DebugName(const Instruction&) const override { return "LaunchLazyJob"; } Maybe Prepare(Instruction* instruction) override { return Maybe::Ok(); } void Compute(Instruction* instruction) override { VLOG(3) << " VM try launch Graph: " << nn_graph_->job_name() << " in run_cnt: " << nn_graph_->run_cnt() << " START."; auto* lazy_job_stream_policy = GetLazyJobStreamPolicy(instruction); { OF_PROFILER_RANGE_GUARD("WaitUntilQueueEmptyIfFrontNNGraphNotEquals"); lazy_job_stream_policy->WaitUntilQueueEmptyIfFrontNNGraphNotEquals(nn_graph_); VLOG(3) << " 
VM launch Graph: " << nn_graph_->job_name() << " in run_cnt: " << nn_graph_->run_cnt() << " WaitUntilQueueEmptyIfFrontNNGraphNotEquals."; } { OF_PROFILER_RANGE_GUARD("Send all buffers to BufferMgr"); const auto& job_instance = MakeJobInstance(instruction); const auto& job_name = job_instance->job_name(); auto* buffer_mgr = Singleton>>::Get(); buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Push(job_instance); VLOG(3) << " VM Push CallbackNotifier to Graph: " << nn_graph_->job_name() << " in run_cnt: " << nn_graph_->run_cnt(); buffer_mgr->Get(GetSourceTickBufferName(job_name))->Push(job_instance); VLOG(3) << " VM Push SourceTick to Graph: " << nn_graph_->job_name() << " in run_cnt: " << nn_graph_->run_cnt(); } OF_PROFILER_RANGE_GUARD("EnqueueNNGraph"); lazy_job_stream_policy->EnqueueNNGraph(nn_graph_); VLOG(3) << " VM Enqueue Graph: " << nn_graph_->job_name() << " run_cnt: " << nn_graph_->run_cnt() << " END."; nn_graph_->NextRunCnt(); } private: LazyJobStreamPolicy* GetLazyJobStreamPolicy(Instruction* instruction) const { StreamPolicy* stream_policy = instruction->mut_stream()->mut_stream_policy(); LazyJobStreamPolicy* lazy_job_stream_policy = dynamic_cast(stream_policy); CHECK_NOTNULL(lazy_job_stream_policy); return lazy_job_stream_policy; } std::shared_ptr MakeJobInstance(Instruction* instruction) const { const auto& FinishCb = [this, instruction]() { auto* lazy_job_stream_policy = GetLazyJobStreamPolicy(instruction); lazy_job_stream_policy->DequeueNNGraph(); auto* status_buffer = instruction->mut_status_buffer(); NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer())->set_done(); }; return std::make_shared(nn_graph_->job_name(), FinishCb); } std::shared_ptr nn_graph_; EagerBlobObjectListPtr param_blob_objects_; DependenceVector input_dependences_; DependenceVector output_dependences_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_LAZY_JOB_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/lazy_job_stream_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/vm/lazy_job_stream_policy.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/naive_instruction_status_querier.h" #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { void LazyJobStreamPolicy::WaitUntilQueueEmptyIfFrontNNGraphNotEquals( const std::shared_ptr& nn_graph) { std::unique_lock lock(mutex_); if (queue_.empty()) { return; } const auto& last_nn_graph = queue_.front().lock(); if (!last_nn_graph) { return; } if (last_nn_graph == nn_graph) { return; } cond_.wait(lock, [this]() { return queue_.empty(); }); } void LazyJobStreamPolicy::EnqueueNNGraph(const std::shared_ptr& nn_graph) { std::unique_lock lock(mutex_); queue_.emplace(nn_graph); } void LazyJobStreamPolicy::DequeueNNGraph() { std::unique_lock lock(mutex_); queue_.pop(); cond_.notify_all(); } void LazyJobStreamPolicy::InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const { static_assert(sizeof(NaiveInstrStatusQuerier) < kInstructionStatusBufferBytes, ""); NaiveInstrStatusQuerier::PlacementNew(status_buffer->mut_buffer()); } void LazyJobStreamPolicy::DeleteInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const { auto* ptr = NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer()); ptr->~NaiveInstrStatusQuerier(); } bool LazyJobStreamPolicy::QueryInstructionStatusLaunched( const Stream& stream, const InstructionStatusBuffer& status_buffer) const { return NaiveInstrStatusQuerier::Cast(status_buffer.buffer())->launched(); } bool LazyJobStreamPolicy::QueryInstructionStatusDone( const Stream& stream, const InstructionStatusBuffer& status_buffer) const { return NaiveInstrStatusQuerier::Cast(status_buffer.buffer())->done(); } void LazyJobStreamPolicy::Run(Instruction* instruction) const { instruction->Compute(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/lazy_job_stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================
FILE: oneflow/core/vm/lazy_job_stream_policy.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_LAZY_JOB_STREAM_POLICY_H_
#define ONEFLOW_CORE_VM_LAZY_JOB_STREAM_POLICY_H_

#include "oneflow/core/vm/stream_policy.h"
#include "oneflow/core/vm/instruction.h"

namespace oneflow {

class NNGraphIf;

namespace vm {

class LazyJobStreamPolicy final : public StreamPolicy {
 public:
  LazyJobStreamPolicy() = default;
  virtual ~LazyJobStreamPolicy() = default;

  vm::Allocator* mut_allocator() override { return (vm::Allocator*)nullptr; }

  DeviceType device_type() const override {
    UNIMPLEMENTED();
    return DeviceType::kInvalidDevice;
  }

  ep::Stream* stream() override {
    UNIMPLEMENTED();
    return nullptr;
  }

  void WaitUntilQueueEmptyIfFrontNNGraphNotEquals(const std::shared_ptr<NNGraphIf>& nn_graph);

  void EnqueueNNGraph(const std::shared_ptr<NNGraphIf>& nn_graph);

  void DequeueNNGraph();

  void InitInstructionStatus(const Stream& stream,
                             InstructionStatusBuffer* status_buffer) const override;
  void DeleteInstructionStatus(const Stream& stream,
                               InstructionStatusBuffer* status_buffer) const override;
  bool QueryInstructionStatusLaunched(const Stream& stream,
                                      const InstructionStatusBuffer& status_buffer) const override;
  bool QueryInstructionStatusDone(const Stream& stream,
                                  const InstructionStatusBuffer& status_buffer) const override;
  void Run(Instruction* instruction) const override;

  bool SupportingTransportInstructions() const override { return false; }

 private:
  std::queue<std::weak_ptr<NNGraphIf>> queue_;
  std::mutex mutex_;
  std::condition_variable cond_;
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_LAZY_JOB_STREAM_POLICY_H_


================================================
FILE: oneflow/core/vm/naive_instruction_status_querier.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_NAIVE_VM_INSTRUCTION_STATUS_QUERIER_H_
#define ONEFLOW_CORE_VM_NAIVE_VM_INSTRUCTION_STATUS_QUERIER_H_

#include <atomic>

namespace oneflow {
namespace vm {

class NaiveInstrStatusQuerier {
 public:
  ~NaiveInstrStatusQuerier() = default;

  bool launched() const { return done_; }
  bool done() const { return done_; }
  void set_done() { done_ = true; }

  static const NaiveInstrStatusQuerier* Cast(const char* mem_ptr) {
    return reinterpret_cast<const NaiveInstrStatusQuerier*>(mem_ptr);
  }
  static NaiveInstrStatusQuerier* MutCast(char* mem_ptr) {
    return reinterpret_cast<NaiveInstrStatusQuerier*>(mem_ptr);
  }
  static NaiveInstrStatusQuerier* PlacementNew(char* mem_ptr) {
    return new (mem_ptr) NaiveInstrStatusQuerier();
  }

 private:
  NaiveInstrStatusQuerier() : done_(false) {}
  std::atomic<bool> done_;
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_NAIVE_VM_INSTRUCTION_STATUS_QUERIER_H_


================================================
FILE: oneflow/core/vm/op_call_instruction_policy.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/op_call_instruction_policy.h" #include #include #include "oneflow/core/common/env_var/vm.h" #include "oneflow/core/vm/allocator.h" #include "oneflow/core/vm/remat/allocator.h" #include "oneflow/core/vm/remat/disjoint_set.h" #include "oneflow/core/vm/remat/env.h" #include "oneflow/core/vm/remat/util.h" #include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/framework/stream_get_stream_type_name.h" #include "oneflow/core/vm/stream_get_allocator_stream_type.h" #include "oneflow/core/profiler/profiler.h" #include "fmt/core.h" namespace oneflow { namespace vm { struct OpCallInstructionUtil final { static inline Maybe Prepare(OpCallInstructionPolicy* op_call_instruction_policy, Instruction* instruction) { VLOG_REMAT(1) << "prepare " << op_call_instruction_policy->opkernel().op_type_name() << std::endl; if (unlikely(op_call_instruction_policy->need_temp_storage())) { InferTempStorageSize(op_call_instruction_policy); } return Maybe::Ok(); } static inline Maybe Compute(OpCallInstructionPolicy* op_call_instruction_policy, vm::Stream* vm_stream, bool first, bool recompute) { Allocator* allocator = vm_stream->mut_stream_policy()->mut_allocator(); const auto [remat_helper, inputs_rematable, outputs_rematable] = InitRematInfo(op_call_instruction_policy, vm_stream); const auto& current_op_type_name = op_call_instruction_policy->opkernel().op_type_name(); ThreadLocalGuard current_op_type_name_guard({current_op_type_name}); if (inputs_rematable || outputs_rematable) { VLOG_REMAT(2) << "set current op type name to " << current_op_type_name << std::endl; VLOG_REMAT(2) << "op: " << op_call_instruction_policy->opkernel().op_type_name() << std::endl; VLOG_REMAT(2) << "input_rematable: " << inputs_rematable << ", output_rematable: " << outputs_rematable << std::endl; } if (inputs_rematable) { JUST(remat_helper->RematInputs(vm_stream, first, ComputeFnForRemat)); } JUST(AllocateOutputBlobsMemory(op_call_instruction_policy, allocator, vm_stream)); if (unlikely(op_call_instruction_policy->need_temp_storage())) { JUST(TryAllocateTempStorage(op_call_instruction_policy, allocator)); } ep::Stream* stream = vm_stream->mut_stream_policy()->stream(); user_op::OpKernelState* state = nullptr; user_op::OpKernelCache* cache = nullptr; if (op_call_instruction_policy->user_opkernel()->has_state_or_cache()) { TryInitOpKernelStateAndCache(op_call_instruction_policy, stream, &state, &cache); } OpKernelCompute(op_call_instruction_policy, stream, state, cache); if (unlikely(op_call_instruction_policy->need_temp_storage())) { DeallocateTempStorage(op_call_instruction_policy, allocator); } if (inputs_rematable) { JUST(remat_helper->EagerlyEvictRemattedTensors(first)); } if (inputs_rematable || outputs_rematable) { JUST(remat_helper->UpdateRematInfo(first, recompute, inputs_rematable, outputs_rematable)); } return Maybe::Ok(); } private: static inline void InferTempStorageSize(OpCallInstructionPolicy* op_call_instruction_policy) { auto* 
tmp_tensor = op_call_instruction_policy->mut_call_ctx()->mut_tmp_tensor(); size_t temp_size = op_call_instruction_policy->opkernel().InferTmpSize( op_call_instruction_policy->mut_call_ctx(), op_call_instruction_policy->user_opkernel()); tmp_tensor->set_tmp_buffer_size(temp_size); } static inline void TryInitOpKernelStateAndCache( OpCallInstructionPolicy* op_call_instruction_policy, ep::Stream* stream, user_op::OpKernelState** state, user_op::OpKernelCache** cache) { OF_PROFILER_RANGE_GUARD("TryInitOpKernelStateAndCache"); if (likely(op_call_instruction_policy->op_interp_ctx().state)) { *state = op_call_instruction_policy->op_interp_ctx().state.get(); // set state to nullptr so that state initialization in TryInitOpKernelStateAndCache will be // skipped. state = nullptr; } op_call_instruction_policy->mut_opkernel()->TryInitOpKernelStateAndCache( op_call_instruction_policy->mut_call_ctx(), stream, op_call_instruction_policy->user_opkernel(), state, cache); } // Returns true if allocation happened. static inline Maybe AllocateOutputBlobsMemory( OpCallInstructionPolicy* op_call_instruction_policy, Allocator* allocator, const vm::Stream* vm_stream) { OF_PROFILER_RANGE_GUARD("AllocateOutputBlobsMemory"); StreamType stream_type = vm_stream->stream_type(); StreamType allocator_stream_type = JUST(GetAllocatorStreamType::Visit(stream_type)); for (const auto& blob_object : op_call_instruction_policy->outputs()) { if (JUST(blob_object->TryAllocateBlobBodyMemory(allocator))) { CHECK_OR_RETURN(stream_type == allocator_stream_type) << "no allocator supported on stream type " << GetStreamTypeName::Visit(stream_type); if (auto* dtr_allocator = dynamic_cast(allocator)) { dtr_allocator->allocator->LinkStorageAndPtr( dynamic_cast(blob_object->tensor_storage().get()), static_cast(blob_object->dptr())); } } } return Maybe::Ok(); } static inline Maybe TryAllocateTempStorage( OpCallInstructionPolicy* op_call_instruction_policy, Allocator* allocator) { OF_PROFILER_RANGE_GUARD("TryAllocateTempStorage"); auto* tmp_tensor = op_call_instruction_policy->mut_call_ctx()->mut_tmp_tensor(); size_t byte_size = tmp_tensor->tmp_buffer_size(); if (byte_size > 0) { char* mem_ptr = nullptr; JUST(allocator->Allocate(&mem_ptr, byte_size)); tmp_tensor->set_tmp_buffer_ptr(mem_ptr); } return Maybe::Ok(); } static inline void DeallocateTempStorage(OpCallInstructionPolicy* op_call_instruction_policy, Allocator* allocator) { auto* tmp_tensor = op_call_instruction_policy->mut_call_ctx()->mut_tmp_tensor(); allocator->Deallocate(tmp_tensor->mut_tmp_buffer_ptr(), tmp_tensor->tmp_buffer_size()); tmp_tensor->set_tmp_buffer_ptr(nullptr); } static inline void OpKernelCompute(OpCallInstructionPolicy* op_call_instruction_policy, ep::Stream* stream, user_op::OpKernelState* state, user_op::OpKernelCache* cache) { auto* user_kernel = op_call_instruction_policy->user_opkernel(); op_call_instruction_policy->mut_opkernel()->Compute(op_call_instruction_policy->mut_call_ctx(), stream, user_kernel, state, cache); } static inline Maybe ComputeFnForRemat(OpCallInstructionPolicy* op_call_instruction_policy, vm::Stream* vm_stream) { return Compute(op_call_instruction_policy, vm_stream, false, true); } static inline std::tuple, bool, bool> InitRematInfo( OpCallInstructionPolicy* op_call_instruction_policy, vm::Stream* vm_stream) { bool inputs_rematable = false; bool outputs_rematable = false; if (op_call_instruction_policy->opkernel().op_type_name() == "copy") { inputs_rematable = 
op_call_instruction_policy->inputs()[0]->tensor_storage()->device()->rematable(); outputs_rematable = op_call_instruction_policy->outputs()[0]->tensor_storage()->device()->rematable(); } else { inputs_rematable = vm_stream->device()->rematable(); outputs_rematable = vm_stream->device()->rematable(); } std::unique_ptr remat_helper; if (inputs_rematable || outputs_rematable) { remat_helper = std::make_unique(*op_call_instruction_policy, inputs_rematable, outputs_rematable); } return std::make_tuple(std::move(remat_helper), inputs_rematable, outputs_rematable); } }; OpCallInstructionPolicy::OpCallInstructionPolicy( Stream* vm_stream, const std::shared_ptr& opkernel, EagerBlobObjectList&& inputs, EagerBlobObjectList&& outputs, const std::shared_ptr& global_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode) : vm_stream_(vm_stream), call_ctx_(ComposedAttrMap(op_interp_ctx.attrs, opkernel->base_attrs()), std::move(inputs), std::move(outputs), global_tensor_infer_result, op_interp_ctx, opkernel->mem_case()), opkernel_(opkernel), user_opkernel_(nullptr), infer_tmp_size_fn_(nullptr), need_temp_storage_(false), dev_vm_dep_object_consume_mode_(dev_vm_dep_object_consume_mode), input_dependences_(), output_dependences_() { ForEachConstDependence([&](auto* dep) { input_dependences_.emplace_back(dep); }); ForEachMutDependence([&](auto* dep) { output_dependences_.emplace_back(dep); }); ForEachMut2Dependence([&](auto* dep) { output_dependences_.emplace_back(dep); }); InitStreamSequentialDependence(); } Maybe OpCallInstructionPolicy::Init() { return mut_opkernel()->ChooseOpKernel(&call_ctx_, &user_opkernel_, &need_temp_storage_); } OpCallInstructionPolicy::OpCallInstructionPolicy(const DtrOpCallInstructionPolicy& policy) : vm_stream_(policy.vm_stream_), call_ctx_(policy.dtr_call_ctx_), opkernel_(policy.opkernel_), user_opkernel_(policy.user_opkernel_), infer_tmp_size_fn_(policy.infer_tmp_size_fn_), need_temp_storage_(policy.need_temp_storage_), dev_vm_dep_object_consume_mode_(policy.dev_vm_dep_object_consume_mode_), input_dependences_(policy.input_dependences_), output_dependences_(policy.output_dependences_) {} template void OpCallInstructionPolicy::ForEachConstDependence(const DoEachT& DoEach) const { const auto& input_list = inputs(); for (int64_t index : opkernel().input_tuple_indexes4const_ibns()) { const auto& input = input_list.at(index); DoEach(CHECK_JUST(input->compute_local_dep_object())); } } void OpCallInstructionPolicy::InitStreamSequentialDependence() { auto* device_schedule_dep_object = vm_stream_->schedule_local_dep_object().get(); if (IsCommNetStream::Visit(vm_stream_->stream_type())) { // Sequantialize nccl instructions to avoid deadlock stream_sequential_dependence_ = device_schedule_dep_object; } else { // Sequantialize instructions to avoid explosive memory allocation of source ops if (dev_vm_dep_object_consume_mode() == one::DevVmDepObjectConsumeMode::MUTABLE) { stream_sequential_dependence_ = device_schedule_dep_object; } else if (opkernel().input_tuple_indexes4const_ibns().empty() && opkernel().input_tuple_indexes4mut_ibns().empty()) { stream_sequential_dependence_ = device_schedule_dep_object; } } } template void OpCallInstructionPolicy::ForEachMutDependence(const DoEachT& DoEach) const { for (const auto& transport_dependence : vm_stream_->transport_dependences()) { DoEach(transport_dependence.get()); } const auto& input_list = inputs(); for (int64_t index : 
opkernel().input_tuple_indexes4mut_ibns()) { const auto& input = input_list.at(index); DoEach(CHECK_JUST(input->compute_local_dep_object())); } const auto& output_list = outputs(); for (int64_t index : opkernel().output_tuple_indexes4mut_obns()) { const auto& output = output_list.at(index); DoEach(CHECK_JUST(output->compute_local_dep_object())); } } template void OpCallInstructionPolicy::ForEachMut2Dependence(const DoEachT& DoEach) const { const auto& output_list = outputs(); for (int64_t index : opkernel().output_tuple_indexes4mut2_obns()) { const auto& output = output_list.at(index); DoEach(CHECK_JUST(output->compute_local_dep_object())); } } Maybe OpCallInstructionPolicy::Prepare(vm::Instruction* instruction) { return OpCallInstructionUtil::Prepare(this, instruction); } void OpCallInstructionPolicy::Compute(vm::Instruction* instruction) { CHECK_JUST_MSG(OpCallInstructionUtil::Compute(this, instruction->mut_stream(), true, false), instruction->DebugName()); } std::string OpCallInstructionPolicy::DebugName(const vm::Instruction& instruction) const { return opkernel().op_type_name() + ":OpCall"; } Maybe Recompute(OpCallInstructionPolicy* op_call_instruction_policy, vm::Stream* vm_stream) { VLOG_REMAT(1) << "recompute " << op_call_instruction_policy->opkernel().op_type_name() << " manually"; return OpCallInstructionUtil::Compute(op_call_instruction_policy, vm_stream, true, true); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/op_call_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_OP_CALL_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_OP_CALL_INSTRUCTION_POLICY_H_ #include #include "oneflow/core/eager/call_context.h" #include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h" #include "oneflow/core/framework/user_op_kernel_registry.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/vm/stream.h" #include "oneflow/user/kernels/stateful_opkernel.h" namespace oneflow { namespace user_op { class OpKernel; } // namespace user_op namespace vm { class DtrOpCallInstructionPolicy; class OpCallInstructionPolicy final : public InstructionPolicy { public: OpCallInstructionPolicy(const OpCallInstructionPolicy& other) = default; OpCallInstructionPolicy(OpCallInstructionPolicy&& other) = default; OpCallInstructionPolicy& operator=(const OpCallInstructionPolicy& other) = delete; OpCallInstructionPolicy& operator=(OpCallInstructionPolicy&& other) = delete; ~OpCallInstructionPolicy() override = default; template static Maybe New(Args&&... 
args) { auto* ptr = new OpCallInstructionPolicy(std::forward(args)...); JUST(ptr->Init()); return std::shared_ptr(ptr); } const one::StatefulOpKernel& opkernel() const { return *opkernel_; } const EagerBlobObjectList& inputs() const { return call_ctx_.inputs(); } const EagerBlobObjectList& outputs() const { return call_ctx_.outputs(); } EagerBlobObjectList& mut_inputs() { return call_ctx_.mut_inputs(); } EagerBlobObjectList& mut_outputs() { return call_ctx_.mut_outputs(); } const ComposedAttrMap& composed_attrs() const { return call_ctx_.composed_attrs(); } const one::OpExprInterpContext& op_interp_ctx() const { return call_ctx_.op_interp_ctx(); } const one::DevVmDepObjectConsumeMode& dev_vm_dep_object_consume_mode() const { return dev_vm_dep_object_consume_mode_; } one::StatefulOpKernel* mut_opkernel() { return opkernel_.get(); } template Maybe ForEachOutputTensor(const DoEachT& DoEach) { for (const auto& output : outputs()) { JUST(DoEach(output.get())); } return Maybe::Ok(); } const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } template void ForEachConstDependence(const DoEachT& DoEach) const; template void ForEachMutDependence(const DoEachT& DoEach) const; template void ForEachMut2Dependence(const DoEachT& DoEach) const; bool need_temp_storage() const { return need_temp_storage_; } const user_op::OpKernel* user_opkernel() const { return user_opkernel_; } const user_op::InferTmpSizeFn& infer_tmp_size_fn() const { return *infer_tmp_size_fn_; } const std::shared_ptr& global_tensor_infer_result() const { return call_ctx_.global_tensor_infer_result(); } const eager::CallContext& call_ctx() const { return call_ctx_; } eager::CallContext* mut_call_ctx() { return &call_ctx_; } Stream* vm_stream() const { return vm_stream_; } InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; } std::string DebugName(const vm::Instruction& instruction) const override; explicit OpCallInstructionPolicy(const DtrOpCallInstructionPolicy& policy); private: OpCallInstructionPolicy( Stream* vm_stream, const std::shared_ptr& opkernel, EagerBlobObjectList&& inputs, EagerBlobObjectList&& outputs, const std::shared_ptr& global_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode); Maybe Init(); void InitStreamSequentialDependence(); Maybe Prepare(Instruction* instruction) override; void Compute(Instruction* instruction) override; Stream* vm_stream_; eager::CallContext call_ctx_; std::shared_ptr opkernel_; const user_op::OpKernel* user_opkernel_; const user_op::InferTmpSizeFn* infer_tmp_size_fn_; bool need_temp_storage_; const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode_; DependenceVector input_dependences_; DependenceVector output_dependences_; friend class DtrOpCallInstructionPolicy; }; class DtrOpCallInstructionPolicy { Stream* vm_stream_; eager::DtrCallContext dtr_call_ctx_; std::shared_ptr opkernel_; const user_op::OpKernel* user_opkernel_; const user_op::InferTmpSizeFn* infer_tmp_size_fn_; bool need_temp_storage_; const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode_; DependenceVector input_dependences_; DependenceVector output_dependences_; public: explicit DtrOpCallInstructionPolicy(const OpCallInstructionPolicy& op) : vm_stream_(op.vm_stream()), dtr_call_ctx_(op.call_ctx()), opkernel_(op.opkernel_), user_opkernel_(op.user_opkernel_), 
infer_tmp_size_fn_(op.infer_tmp_size_fn_), need_temp_storage_(op.need_temp_storage()), dev_vm_dep_object_consume_mode_(op.dev_vm_dep_object_consume_mode()), input_dependences_(op.input_dependences()), output_dependences_(op.output_dependences()) {} friend class OpCallInstructionPolicy; EagerBlobObjectList& mut_inputs() { return dtr_call_ctx_.mut_inputs(); } WeakEagerBlobObjectList& mut_outputs() { return dtr_call_ctx_.mut_outputs(); } const one::StatefulOpKernel& opkernel() const { return *opkernel_; } }; Maybe Recompute(OpCallInstructionPolicy* op_call_instruction_policy, vm::Stream* vm_stream); } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_OP_CALL_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/pinned_ep_stream_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/pinned_ep_stream_policy.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/stream_type.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" #include "oneflow/core/vm/ep_backend_host_allocator.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { namespace { std::unique_ptr> CreatePinedEpBackendHostAllocator( Symbol device) { // TODO:(zhaoluyang) empty/cast/copy op support pin_memory_device DeviceType device_type = device->enum_type(); size_t device_index = device->device_id(); auto ep_device = Singleton::Get()->GetDevice(device_type, device_index); ep::AllocationOptions options{}; options.SetPinnedDevice(device_type, device_index); auto ep_backend_allocator = std::make_unique(ep_device, options); return std::make_unique>(ep::kMaxAlignmentRequirement, std::move(ep_backend_allocator)); } } // namespace PinnedEpStreamPolicy::PinnedEpStreamPolicy(Symbol device) : EpStreamPolicyBase(device, CreatePinedEpBackendHostAllocator(device)) {} void PinnedEpStreamPolicy::InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const { static_assert(sizeof(EpOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, ""); auto* data_ptr = status_buffer->mut_buffer(); EpOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, nullptr); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/pinned_ep_stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_PINNED_EP_STREAM_POLICY_H_
#define ONEFLOW_CORE_VM_PINNED_EP_STREAM_POLICY_H_

#include "oneflow/core/vm/ep_stream_policy_base.h"

namespace oneflow {
namespace vm {

class PinnedEpStreamPolicy final : public EpStreamPolicyBase {
 public:
  PinnedEpStreamPolicy(Symbol<Device> device);
  ~PinnedEpStreamPolicy() override = default;

  void InitInstructionStatus(const Stream& stream,
                             InstructionStatusBuffer* status_buffer) const override;
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_PINNED_EP_STREAM_POLICY_H_


================================================
FILE: oneflow/core/vm/probe.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_PROBE_H_
#define ONEFLOW_CORE_VM_PROBE_H_

#include "oneflow/core/intrusive/intrusive.h"

namespace oneflow {
namespace vm {

template<typename ProbeFunction>
class Probe final : public intrusive::Base {
 public:
  Probe(const Probe&) = delete;
  Probe(Probe&&) = delete;
  Probe() = default;
  ~Probe() = default;

  void __Init__(const ProbeFunction& probe_function) { probe_function_ = probe_function; }

  const ProbeFunction& probe_function() const { return probe_function_; }

 private:
  friend class intrusive::Ref;
  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
  // fields
  intrusive::Ref intrusive_ref_;
  ProbeFunction probe_function_;

 public:
  // hooks
  intrusive::ListHook probe_hook_;
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_PROBE_H_


================================================
FILE: oneflow/core/vm/ref_cnt_instruction_status_querier.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_REF_CNT_VM_INSTRUCTION_STATUS_QUERIER_H_
#define ONEFLOW_CORE_VM_REF_CNT_VM_INSTRUCTION_STATUS_QUERIER_H_

#include <atomic>
#include <memory>

namespace oneflow {
namespace vm {

class RefCntInstrStatusQuerier {
 public:
  ~RefCntInstrStatusQuerier() = default;

  bool done() const { return launched_ && *ref_cnt_ == 0; }
  void SetRefCntAndSetLaunched(const std::shared_ptr<std::atomic<int64_t>>& ref_cnt) {
    // No lock needed. This function will be called only one time.
    // In most cases, errors will be successfully detected by CHECK
    // even though run in different threads.
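    // done() becomes true only after both conditions hold: this querier has been marked as
    // launched here, and the shared reference count handed in by the caller has dropped to zero.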
    CHECK(!launched_);
    ref_cnt_ = ref_cnt;
    launched_ = true;
  }

  static const RefCntInstrStatusQuerier* Cast(const char* mem_ptr) {
    return reinterpret_cast<const RefCntInstrStatusQuerier*>(mem_ptr);
  }
  static RefCntInstrStatusQuerier* MutCast(char* mem_ptr) {
    return reinterpret_cast<RefCntInstrStatusQuerier*>(mem_ptr);
  }
  static RefCntInstrStatusQuerier* PlacementNew(char* mem_ptr) {
    return new (mem_ptr) RefCntInstrStatusQuerier();
  }

 private:
  RefCntInstrStatusQuerier() : launched_(false), ref_cnt_() {}

  std::atomic<bool> launched_;
  std::shared_ptr<std::atomic<int64_t>> ref_cnt_;
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_REF_CNT_VM_INSTRUCTION_STATUS_QUERIER_H_


================================================
FILE: oneflow/core/vm/release_tensor_instruction_policy.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_RELEASE_TENSOR_INSTRUCTION_POLICY_H_
#define ONEFLOW_CORE_VM_RELEASE_TENSOR_INSTRUCTION_POLICY_H_

#include <memory>
#include
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/intrusive/intrusive.h"
#include "oneflow/core/vm/ep_optional_event_record_status_querier.h"
#include "oneflow/core/eager/local_dep_object.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/eager/tensor_storage.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/framework/stream_need_soft_sync.h"

namespace oneflow {
namespace vm {

class EagerBlobObject;

class ReleaseTensorInstructionPolicy : public InstructionPolicy {
 public:
  ReleaseTensorInstructionPolicy(const std::shared_ptr<EagerBlobObject>& eager_blob_object,
                                 const Optional<vm::Stream*>& stream)
      : eager_blob_object_(eager_blob_object), output_dependences_() {
    output_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object()));
    if (stream.has_value()) {
      stream_sequential_dependence_ = CHECK_JUST(stream)->schedule_local_dep_object().get();
    }
  }
  ~ReleaseTensorInstructionPolicy() override = default;

  const std::shared_ptr<EagerBlobObject>& eager_blob_object() const { return eager_blob_object_; }

  const DependenceVector& input_dependences() const override {
    static thread_local DependenceVector empty{};
    return empty;
  }
  const DependenceVector& output_dependences() const override { return output_dependences_; }
  Dependence* stream_sequential_dependence() const override {
    return stream_sequential_dependence_;
  }

 protected:
  void Release(const std::shared_ptr<EagerBlobObject>& eager_blob_object) const {
    CHECK_JUST(eager_blob_object->DeallocateBlobDataPtr());
  }

 private:
  void InitInstructionStatus(Instruction* instruction) override {
    auto* status_buffer = instruction->mut_status_buffer();
    auto* stream = instruction->mut_stream();
    instruction->stream_policy().InitInstructionStatus(*stream, status_buffer);
    auto* data_ptr = status_buffer->mut_buffer();
    EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(nullptr);
  }

  std::shared_ptr<EagerBlobObject>
      eager_blob_object_;
  DependenceVector output_dependences_;
};

class FastReleaseTensorInstructionPolicy final : public ReleaseTensorInstructionPolicy {
 public:
  using ReleaseTensorInstructionPolicy::ReleaseTensorInstructionPolicy;

  bool Prescheduleable(const vm::Stream* src, const vm::Stream* dst) const override {
    return false;
  }

 private:
  std::string DebugName(const vm::Instruction& instruction) const override {
    return "FastReleaseTensor";
  }
  Maybe<void> Prepare(vm::Instruction* instruction) override {
    DataType data_type = eager_blob_object()->data_type();
    CHECK_OR_RETURN(IsTriviallyCopyableDataType(data_type));
    if (eager_blob_object()->tensor_storage()->is_allocated_in_vm()) {
      Release(eager_blob_object());
    }
    return Maybe<void>::Ok();
  }
  void Compute(vm::Instruction* instruction) override {
    if (!eager_blob_object()->tensor_storage()->is_allocated_in_vm()) {
      Release(eager_blob_object());
    }
  }
};

class SlowReleaseTensorInstructionPolicy final : public ReleaseTensorInstructionPolicy {
 public:
  using ReleaseTensorInstructionPolicy::ReleaseTensorInstructionPolicy;

 private:
  std::string DebugName(const vm::Instruction& instruction) const override {
    return "SlowReleaseTensor";
  }
  Maybe<void> Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }
  void Compute(vm::Instruction* instruction) override { Release(eager_blob_object()); }
};

struct MakeReleaseTensorInstructionPolicy
    : public StreamTypeVisitor<MakeReleaseTensorInstructionPolicy> {
  static Maybe<vm::InstructionPolicy> VisitCompute(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    return Make(eager_blob_object, stream);
  }
  static Maybe<vm::InstructionPolicy> VisitHost2Device(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    return Make(eager_blob_object, stream);
  }
  static Maybe<vm::InstructionPolicy> VisitDevice2Host(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    return Make(eager_blob_object, stream);
  }
  static Maybe<vm::InstructionPolicy> VisitCcl(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    return Make(eager_blob_object, stream);
  }
  static Maybe<vm::InstructionPolicy> VisitBarrier(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    UNIMPLEMENTED_THEN_RETURN() << "ReleaseTensor instruction not supported in Barrier stream";
  }
  static Maybe<vm::InstructionPolicy> VisitCriticalSection(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    UNIMPLEMENTED_THEN_RETURN()
        << "ReleaseTensor instruction not supported in CriticalSection stream";
  }
  static Maybe<vm::InstructionPolicy> VisitLazyJobLauncher(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    UNIMPLEMENTED_THEN_RETURN()
        << "ReleaseTensor instruction not supported in LaunchLazyJob stream";
  }
  static Maybe<vm::InstructionPolicy> VisitPinnedCompute(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    return VisitCompute(eager_blob_object, stream);
  }

 private:
  static Maybe<vm::InstructionPolicy> Make(
      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
      const Optional<vm::Stream*>& stream) {
    DataType data_type = eager_blob_object->data_type();
    if (!IsTriviallyCopyableDataType(data_type)) {
      return std::shared_ptr<vm::InstructionPolicy>(
          new vm::SlowReleaseTensorInstructionPolicy(eager_blob_object, stream));
    }
    Symbol<::oneflow::Stream> last_used_stream = JUST(eager_blob_object->last_used_stream());
    DeviceType device_type = last_used_stream->device()->enum_type();
    if (NeedSoftSync::Visit(last_used_stream->stream_type(), device_type)) {
      return std::shared_ptr<vm::InstructionPolicy>(
          new vm::SlowReleaseTensorInstructionPolicy(eager_blob_object, stream));
    } else {
      return std::shared_ptr<vm::InstructionPolicy>(
          new vm::FastReleaseTensorInstructionPolicy(eager_blob_object, stream));
    }
  }
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_RELEASE_TENSOR_INSTRUCTION_POLICY_H_
================================================ FILE: oneflow/core/vm/remat/allocator.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "nlohmann/json.hpp" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/common/thread_local_guard.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/profiler/util.h" #include "oneflow/core/common/env_var/remat.h" #include "oneflow/core/vm/ep_backend_allocator.h" #include "oneflow/core/vm/remat/allocator.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/vm/remat/env.h" #include "oneflow/core/vm/remat/util.h" #include "oneflow/core/vm/thread_safe_guard.h" #include "oneflow/core/vm/remat/disjoint_set.h" #include namespace oneflow { namespace vm { namespace { inline size_t CudaMemAlignedBytes(size_t bytes) { return RoundUp(bytes, kCudaMemAllocAlignSize); } inline bool IsAlignedSize(size_t size) { return size % kCudaMemAllocAlignSize == 0; } inline double bytes2Mb(size_t bytes) { return bytes * 1. / 1024 / 1024; } static constexpr size_t kSmallPieceThreshold = 10 * 1024; // 10 KB inline bool ShouldBeHeldBySmallPiece(size_t size) { return Singleton::Get()->is_small_pieces_optimization_enabled() && size <= kSmallPieceThreshold; } std::vector GroupNumToIndexes(size_t group_num) { switch (group_num) { case 1: return {0}; case 2: return {0, 1}; case 3: return {0, 1, 2}; case 4: return {0, 1, 3, 2}; case 6: return {3, 1, 0, 5, 4, 2}; } UNIMPLEMENTED(); } } // namespace RematEpAllocator::RematEpAllocator(size_t alignment, std::unique_ptr&& backend) : Allocator(), alignment_(alignment), backend_(std::move(backend)), memory_size_(0), recycle_piece_list_(nullptr), normal_group_num_(EnvInteger()), group_indexes_(GroupNumToIndexes(normal_group_num_)), cur_group_index_id_(normal_group_num_ > 1 ? 
1 : 0), cur_group_index_id_high_cost_(0), enable_left_and_right_(normal_group_num_ > 1) { free_pieces_overlapping_with_group_.resize(normal_group_num_ + 1); } RematEpAllocator::~RematEpAllocator() { if (memory_ != nullptr) { backend_->Deallocate(static_cast(memory_), memory_size_); } } RematEpAllocator::offset_t RematEpAllocator::get_offset(const char* mem_ptr) const { return mem_ptr - (char*)memory_; } void RematEpAllocator::LinkStorageAndPtr(RematableTensorStorage* storage, const char* mem_ptr) { Piece* piece = ptr2piece_.at(mem_ptr); piece->tensor = storage; CHECK_NOTNULL(piece->tensor); VLOG(1) << "tensor " << piece->tensor->id() << " is allocated at " << get_offset(mem_ptr) << ", left: " << piece->is_left; } Maybe RematEpAllocator::InSmallMemoryArea(void* ptr) { CHECK_NOTNULL_OR_RETURN(small_piece_area_ptr_); CHECK_GE_OR_RETURN(ptr, memory_); CHECK_LT_OR_RETURN(ptr, (char*)memory_ + memory_size_); // compare pointer by raw < or > is undefined behavior return std::greater_equal<>{}(ptr, small_piece_area_ptr_); } RematEpAllocator::Piece* RematEpAllocator::AllocatePiece() { if (recycle_piece_list_) { Piece* ret = recycle_piece_list_; recycle_piece_list_ = recycle_piece_list_->next; return ret; } else { pieces_.emplace_back(new Piece()); return pieces_.at(pieces_.size() - 1).get(); } } void RematEpAllocator::DeallocatePiece(Piece* piece) { piece->ptr = nullptr; piece->size = 0; CHECK(piece->is_free); piece->prev = nullptr; piece->next = recycle_piece_list_; piece->is_left = true; recycle_piece_list_ = piece; } void RematEpAllocator::InsertPiece2PtrMap(Piece* piece) { VLOG(2) << "insert piece, offset " << get_offset(piece->ptr); CHECK_NOTNULL(piece->ptr); CHECK(ptr2piece_.emplace(piece->ptr, piece).second); } void RematEpAllocator::ErasePieceFromPtrMap(Piece* piece) { VLOG(2) << "erase piece, offset " << get_offset(piece->ptr); CHECK_NOTNULL(piece->ptr); auto it = ptr2piece_.find(piece->ptr); CHECK(it != ptr2piece_.end()); ptr2piece_.erase(it); } double get_cost(const vm::RematableTensorStorage* storage) { if (storage == nullptr) { return 0.; } double cost = CHECK_JUST(storage->cost(0)); CHECK(!std::isnan(cost)); return cost; } double get_cost(const vm::RematableTensorStorage* storage, size_t size) { if (storage == nullptr) { return 0.; } double cost = CHECK_JUST(storage->cost(size)); CHECK(!std::isnan(cost)); return cost; } void RematEpAllocator::CheckPieces() { auto it = ptr2piece_.cbegin(); for (int i = 0; i < ptr2piece_.size(); ++i) { Piece* piece = it->second; if (piece->tensor == nullptr) { CHECK(piece->is_free); } if (piece->is_free) { CHECK_ISNULL(piece->tensor); } if (i != 0) { CHECK_EQ(piece->prev->next, piece); CHECK_EQ(piece->prev->ptr + piece->prev->size, piece->ptr); auto it2 = it; --it2; CHECK_EQ(piece->prev, it2->second); } if (i != ptr2piece_.size() - 1) { CHECK_EQ(piece->next->prev, piece); CHECK_EQ(piece->ptr + piece->size, piece->next->ptr); auto it2 = it; ++it2; CHECK_EQ(piece->next, it2->second); } it++; } } void RematEpAllocator::DisplayAllPieces() { std::cout << "ops: " << Singleton::Get()->ops.size() << std::endl; for (const auto& pair : ptr2piece_) { Piece* piece = pair.second; std::stringstream ss; ss << "piece " << piece << ", " << (void*)piece->ptr << ", " << piece->size << ", "; if (piece->tensor) { ss << "ebo: " << piece->tensor << ", id: " << piece->tensor->id() << ", cost: " << (piece->tensor->is_eviction_disabled() ? 
"disabled" : std::to_string(get_cost(piece->tensor))) << ", pinned: " << piece->tensor->num_pinned() << ", evictable: " << piece->tensor->is_evictable() << ", compute op: " << piece->tensor->compute_op_type_name(); } else { ss << "no tensor"; } std::cout << ss.str() << std::endl; } } void RematEpAllocator::Display() { double total_free_piece_bytes = 0.; for (const auto& free_list : free_pieces_overlapping_with_group_) { for (auto it = free_list.begin(); it != free_list.end(); ++it) { Piece* piece = *it; CHECK(piece->is_free); CHECK_NOTNULL(piece->ptr); CHECK(IsAlignedSize(piece->size)); std::cout << "memory: " << piece->size * 1. / 1024 / 1024 << "MB" << std::endl; total_free_piece_bytes += piece->size; } } std::cout << "total_free_piece_bytes: " << bytes2Mb(total_free_piece_bytes) << "MB" << ", total allocate bytes: " << bytes2Mb(total_allocate_bytes_) << "MB" << ", total deallocate bytes: " << bytes2Mb(total_deallocate_bytes_) << "MB" << std::endl; } // 开启了 left-right 之后,才能开启 op guided RematEpAllocator::offset_t RematEpAllocator::FindProperPositionInGroup(Piece* piece, size_t group_idx, size_t request_size) const { const offset_t grp_left_bound = group_boundaries_[group_idx].first; const offset_t grp_right_bound = group_boundaries_[group_idx].second; const offset_t piece_left_bound = get_offset(piece->ptr); const offset_t piece_right_bound = piece_left_bound + piece->size; const bool is_right = enable_left_and_right_ && (group_idx % 2 == 1) && group_idx != normal_group_num_; #define PNT3(var) VLOG(3) << OF_PP_STRINGIZE(var) << ": " << var << std::endl PNT3(group_idx); PNT3(grp_left_bound); PNT3(grp_right_bound); PNT3(piece_left_bound); PNT3(piece_right_bound); PNT3(is_right); PNT3(request_size); if (is_right) { if (grp_right_bound < piece_right_bound) { if (grp_right_bound - request_size > piece_left_bound) { return grp_right_bound - request_size; } } // half of tensor in group if (piece_right_bound - request_size / 2 < grp_right_bound) { return piece_right_bound - request_size; } } else { if (grp_left_bound > piece_left_bound) { if (grp_left_bound + request_size < piece_right_bound) { return grp_left_bound; } } // half of tensor in group if (piece_left_bound + request_size / 2 > grp_left_bound) { return piece_left_bound; } } return SIZE_MAX; } void RematEpAllocator::InsertToFreeList(Piece* piece) { const offset_t piece_left = get_offset(piece->ptr); const offset_t piece_right = piece_left + piece->size; VLOG(3) << "piece_left: " << piece_left << ", right: " << piece_right << std::endl; for (size_t i = 0; i < group_boundaries_.size(); i++) { VLOG(3) << "g left: " << group_boundaries_[i].first << ", right: " << group_boundaries_[i].second << std::endl; if ((piece_left >= group_boundaries_[i].first && piece_left < group_boundaries_[i].second) || (piece_right > group_boundaries_[i].first && piece_right <= group_boundaries_[i].second)) { VLOG(3) << "overlap" << std::endl; free_pieces_overlapping_with_group_[i].insert(piece); } } } void RematEpAllocator::EraseFromFreeList(Piece* piece) { VLOG(3) << "erase " << get_offset(piece->ptr); // NOTE: very strange bug: // std::map::erase(Key) returns 2 instead of 0 or 1, which conflicts with documentation. 
for (auto& free_list : free_pieces_overlapping_with_group_) { for (auto it = free_list.begin(); it != free_list.end(); it++) { if ((*it)->ptr == piece->ptr) { free_list.erase(it); break; } } } } auto RematEpAllocator::AllocateMemoryInPiece(Piece* piece, offset_t offset_in_piece, size_t size) -> Piece* { auto SplitPiece = [this](Piece* piece, offset_t offset_in_piece) -> Piece* { // offset_in_piece must be less (not equal) than piece->size so that // new_piece has size CHECK_LE(offset_in_piece, piece->size); Piece* new_piece = AllocatePiece(); new_piece->ptr = piece->ptr + offset_in_piece; VLOG(2) << get_offset(piece->ptr); new_piece->size = piece->size - offset_in_piece; piece->size = offset_in_piece; Piece* next_p = piece->next; piece->next = new_piece; new_piece->prev = piece; new_piece->next = next_p; if (next_p != nullptr) { next_p->prev = new_piece; } InsertPiece2PtrMap(new_piece); CHECK(IsAlignedSize(piece->size)); CHECK(IsAlignedSize(new_piece->size)); return new_piece; }; auto SplitPiece3 = [&SplitPiece]( Piece* piece, offset_t offset1_in_piece, offset_t offset2_in_piece) -> std::tuple { Piece* piece1 = nullptr; Piece* piece2 = nullptr; Piece* piece3 = nullptr; bool has_piece3 = offset2_in_piece != piece->size; if (offset1_in_piece > 0) { piece1 = piece; piece2 = SplitPiece(piece, offset1_in_piece); } else { piece1 = nullptr; piece2 = piece; } if (has_piece3) { piece3 = SplitPiece(piece2, offset2_in_piece - offset1_in_piece); } return {piece1, piece2, piece3}; }; auto pieces = SplitPiece3(piece, offset_in_piece, offset_in_piece + size); EraseFromFreeList(piece); Piece *piece1 = std::get<0>(pieces), *piece2 = std::get<1>(pieces), *piece3 = std::get<2>(pieces); if (piece1 != nullptr) { // piece1 is already free InsertToFreeList(piece1); } // piece2->is_free = false; if (piece3 != nullptr) { piece3->is_free = true; InsertToFreeList(piece3); } return piece2; } size_t RematEpAllocator::iterate_group_index(bool high) const { if (normal_group_num_ == 1) { return 0; } auto is_high_group = [](size_t idx) -> bool { return (idx / 2) % 2 == (idx % 2); }; if (high) { size_t index; // NOLINT do { cur_group_index_id_high_cost_ = (cur_group_index_id_high_cost_ + 1) % normal_group_num_; index = group_indexes_[cur_group_index_id_high_cost_]; } while (!is_high_group(index)); return index; } else { size_t index; // NOLINT do { cur_group_index_id_ = (cur_group_index_id_ + 1) % normal_group_num_; index = group_indexes_[cur_group_index_id_]; } while (is_high_group(index)); return index; } } size_t RematEpAllocator::group_index(bool high) const { if (high) { return group_indexes_[cur_group_index_id_high_cost_]; } else { return group_indexes_[cur_group_index_id_]; } } void RematEpAllocator::InitMemory() { memory_size_ = Singleton::Get()->budget_in_bytes(); CHECK_JUST(backend_->Allocate(&memory_, memory_size_)); LOG(INFO) << "memory_: " << (void*)memory_ << ", size: " << memory_size_; const size_t small_piece_area_size = Singleton::Get()->is_small_pieces_optimization_enabled() ? 1024 * kSmallPieceThreshold : 0; const size_t normal_area_size = memory_size_ - small_piece_area_size; small_piece_area_ptr_ = memory_ + normal_area_size; if (enable_left_and_right_) { CHECK_EQ(normal_group_num_ % 2, 0); } const size_t effective_normal_group_num = enable_left_and_right_ ? 
normal_group_num_ / 2 : normal_group_num_; const std::vector boundary_tmp = [&]() { const size_t mem_per_group = normal_area_size / effective_normal_group_num; std::vector boundary_tmp; for (size_t i = 0, b = 0; i < effective_normal_group_num; i++, b += mem_per_group) { boundary_tmp.push_back(b); } boundary_tmp.push_back(normal_area_size); return boundary_tmp; }(); for (size_t i = 0; i < effective_normal_group_num; i++) { group_boundaries_.emplace_back(boundary_tmp[i], boundary_tmp[i + 1]); if (enable_left_and_right_) { group_boundaries_.emplace_back(boundary_tmp[i], boundary_tmp[i + 1]); } } if (normal_area_size != memory_size_) { group_boundaries_.emplace_back(normal_area_size, memory_size_); } Piece* piece = AllocatePiece(); piece->size = memory_size_; piece->ptr = memory_; piece->prev = nullptr; piece->next = nullptr; piece->is_free = true; piece->tensor = nullptr; InsertToFreeList(piece); InsertPiece2PtrMap(piece); } Maybe RematEpAllocator::FindPiece(size_t aligned_size, bool after_eviction) { CHECK_OR_RETURN(IsAlignedSize(aligned_size)); if (memory_ == nullptr) { InitMemory(); } // NOLINTNEXTLINE const bool is_high_op = [&]() { std::vector high_compute_cost_names{"conv2d", "conv_data_grad", "conv_filter_grad", "add_n", "matmul", "batch_matmul"}; const auto current_op_type_name = CHECK_JUST(ThreadLocalGuard::Current())->value; PNT3(current_op_type_name); if (std::find(high_compute_cost_names.cbegin(), high_compute_cost_names.cend(), current_op_type_name) != high_compute_cost_names.cend()) { return true; } return false; }(); size_t group_idx = [&]() -> size_t { if (ShouldBeHeldBySmallPiece(aligned_size)) { return normal_group_num_; } // if (after_eviction) { return true; } return group_index(is_high_op); }(); PNT3(aligned_size); size_t iterate_num = 0; do { const auto& free_pieces = free_pieces_overlapping_with_group_[group_idx]; PNT3(group_idx); PNT3(free_pieces.size()); for (auto it = free_pieces.begin(); it != free_pieces.end(); ++it) { Piece* piece = *it; CHECK_OR_RETURN(piece->is_free); CHECK_NOTNULL(piece->ptr); CHECK_OR_RETURN(IsAlignedSize(piece->size)); PNT3(get_offset(piece->ptr)); PNT3(piece->size); if (piece->size >= aligned_size) { const offset_t offset_in_memory = FindProperPositionInGroup(piece, group_idx, aligned_size); PNT3(offset_in_memory); if (offset_in_memory != SIZE_MAX) { const offset_t offset_in_piece = offset_in_memory - get_offset(piece->ptr); auto ret = AllocateMemoryInPiece(piece, offset_in_piece, aligned_size); CheckPieces(); return ret; } } } // update group_idx only if this group fails // multiple outputs of a single op places in the same group group_idx = iterate_group_index(is_high_op); iterate_num++; } while (!ShouldBeHeldBySmallPiece(aligned_size) && iterate_num < normal_group_num_); return nullptr; } void RematEpAllocator::MergeNeighbourFreePiece(Piece* lhs, Piece* rhs) { CHECK(lhs->is_free); CHECK(rhs->is_free); CHECK(lhs->next == rhs); CHECK(lhs == rhs->prev); CHECK(lhs->ptr + lhs->size == rhs->ptr); lhs->size += rhs->size; lhs->next = rhs->next; if (rhs->next != nullptr) { rhs->next->prev = lhs; } ErasePieceFromPtrMap(rhs); DeallocatePiece(rhs); } Maybe RematEpAllocator::EvictAndFindPieceLoop(size_t required_size, bool consider_neighbor) { VLOG(2) << "required size: " << required_size; auto GetSizeIncludingNeighborhood = [](auto it, auto begin, auto end) -> size_t { size_t size = it->second->size; if (it != begin) { for (auto t = std::prev(it); t->second->tensor == nullptr; t--) { size += t->second->size; if (t == begin) { break; } } } if (it != 
end) { for (auto t = std::next(it); t != end && t->second->tensor == nullptr; t++) { size += t->second->size; } } return size; }; while (true) { double min_cost = std::numeric_limits::max(); vm::RematableTensorStorage* min_tensor = nullptr; for (auto it = ptr2piece_.begin(); it != ptr2piece_.end() && !JUST(InSmallMemoryArea(it->second->ptr)); it++) { auto* tensor = it->second->tensor; if (tensor != nullptr && !tensor->is_pinned() && tensor->is_evictable()) { auto cur_op_cost = consider_neighbor ? get_cost( tensor, GetSizeIncludingNeighborhood(it, ptr2piece_.begin(), ptr2piece_.end())) : get_cost(tensor); if (cur_op_cost < min_cost) { min_cost = cur_op_cost; min_tensor = tensor; } } } if (min_tensor) { min_tensor->Evict(false); Piece* piece = JUST(FindPiece(required_size, true)); if (piece != nullptr) { return piece; } } else { return Error::RuntimeError() << "Cannot find a piece to evict"; } } } Maybe RematEpAllocator::EvictAndFindPieceOnce(size_t required_size) { VLOG(2) << "required size: " << required_size; auto start = ptr2piece_.begin(); auto end = ptr2piece_.begin(); size_t total_size = 0; double cost_except_size = 0; double min_cost = std::numeric_limits::max(); auto min_start = start; auto min_end = start; std::vector costs; costs.reserve(ptr2piece_.size()); size_t start_i = 0; size_t end_i = 0; while (end != ptr2piece_.end() && !JUST(InSmallMemoryArea(end->second->ptr))) { if (total_size < required_size) { auto* end_tensor = end->second->tensor; if (end_tensor != nullptr && (end_tensor->is_pinned() || !end_tensor->is_evictable())) { VLOG(2) << "skip tensor: " << end_tensor << ", size: " << end_tensor->blob_bytes() << ", compute op " << end_tensor->compute_op_type_name() << ", num_pinned: " << end_tensor->num_pinned() << ", is_evictable: " << end_tensor->is_evictable(); end++; costs.push_back(0); end_i++; start = end; start_i = end_i; total_size = 0; cost_except_size = 0; continue; } total_size += end->second->size; auto cur_op_cost = get_cost(end_tensor); costs.push_back(cur_op_cost); cost_except_size += cur_op_cost; VLOG(2) << "move end, include op: " << (end_tensor != nullptr ? end_tensor->compute_op_type_name() : "no tensor") << ", size: " << end->second->size << ", total_size: " << total_size << ", total cost: " << cost_except_size << ", cur op cost: " << cur_op_cost; end++; end_i++; } else { auto* start_tensor = start->second->tensor; // const auto* start_tensor = start->second->tensor; total_size -= start->second->size; // start_tensor is back in the pool, update_after_pesudo_compute double cur_op_cost = 0; cur_op_cost = costs[start_i]; cost_except_size -= cur_op_cost; VLOG(2) << "move start, exclude op: " << (start_tensor != nullptr ? 
start_tensor->compute_op_type_name() : "no tensor") << ", size: " << start->second->size << ", total_size: " << total_size << ", total cost: " << cost_except_size << ", cur op cost: " << cur_op_cost; start++; start_i++; } double cost = cost_except_size; if (total_size >= required_size && cost < min_cost) { min_cost = cost; min_start = start; min_end = end; VLOG(2) << "record, min_cost: " << min_cost; } } // CHECK(min_end != start); // collect piece ptrs into a new container, because evict() will devalidate the iterators std::vector pieces_to_be_evicted; for (auto it = min_start; it != min_end; ++it) { Piece* piece = it->second; pieces_to_be_evicted.push_back(piece); } if (IsInDebugMode()) { for (auto* piece : pieces_to_be_evicted) { LOG(INFO) << "release dptr: " << get_offset(piece->ptr) << ", size: " << piece->size << ", cost: " << get_cost(piece->tensor) << ", compute op: " << (piece->tensor != nullptr ? piece->tensor->compute_op_type_name() : "no") << ", id: " << (piece->tensor != nullptr ? std::to_string(piece->tensor->id()) : "no"); } } size_t evict_size = 0; for (auto* piece : pieces_to_be_evicted) { evict_size += piece->size; // NOTE: evict will trigger the merge and deallocation of neighbour free pieces, // e.g. two contiguous pieces relu, no_tensor, after relu evict, no_tensor will be deallocated. // currently deallocation only set tensor to nullptr, not real free, // so no bug occurs. It is tricky and fragile. if (piece->tensor != nullptr) { CHECK_OR_RETURN(!ShouldBeHeldBySmallPiece(piece->size)); piece->tensor->Evict(false); } } VLOG(2) << "evict size: " << evict_size; if (!pieces_to_be_evicted.empty()) { return CHECK_NOTNULL(JUST(FindPiece(required_size, true))); } return nullptr; } Maybe RematEpAllocator::Allocate(char** mem_ptr, std::size_t size) { if (size == 0) { *mem_ptr = nullptr; return Maybe::Ok(); } ReentrantThreadSafeLock::RAIIGuard guard(thread_lock_); size_t aligned_size = CudaMemAlignedBytes(size); Piece* piece = JUST(FindPiece(aligned_size, false)); if (piece == nullptr) { if (first_time) { if (EnvBool()) { DisplayAllPieces(); } first_time = false; } const auto started_at = profiler::GetTimeNow(); const size_t evict_num1 = Singleton::Get()->forced_eviction_num(); if (EnvBool()) { piece = JUST(EvictAndFindPieceLoop(aligned_size, true)); } else if (EnvBool()) { piece = JUST(EvictAndFindPieceLoop(aligned_size, false)); } else { piece = JUST(EvictAndFindPieceOnce(aligned_size)); } const size_t evict_num2 = Singleton::Get()->forced_eviction_num(); const auto duration = profiler::GetTimeNow() - started_at; search_free_mem_cost_.emplace_back(size, evict_num2 - evict_num1, duration); if (EnvBool()) { size_t free_mem = 0; for (const auto& pair : ptr2piece_) { Piece* piece = pair.second; if (piece->is_free) { CHECK_ISNULL_OR_RETURN(piece->tensor); free_mem += piece->size; } } remat::append_memory_frag_info_and_get(free_mem, memory_size_); } } if (piece == nullptr) { DisplayAllPieces(); } CHECK_OR_RETURN(piece != nullptr) << "Error! 
: Out of memory when allocate size : " << size; CHECK_NOTNULL(piece->ptr); CHECK_OR_RETURN(ptr2piece_.find(piece->ptr) != ptr2piece_.end()); LOG(INFO) << "allocate offset: " << get_offset(piece->ptr) << ", size: " << piece->size << std::endl; *mem_ptr = piece->ptr; total_allocate_bytes_ += size; piece->is_free = false; return Maybe::Ok(); } void RematEpAllocator::Deallocate(char* mem_ptr, std::size_t size) { if (mem_ptr == nullptr) { return; } ReentrantThreadSafeLock::RAIIGuard guard(thread_lock_); auto it = ptr2piece_.find(mem_ptr); CHECK(it != ptr2piece_.end()) << "Error! : Try deallocate mem_ptr non-existent. mem ptr = " << mem_ptr << " size = " << size; Piece* piece = it->second; CHECK_NOTNULL(piece); CHECK_EQ(piece->ptr, mem_ptr); CHECK(!piece->is_free); if (auto* tensor = piece->tensor) { CHECK_JUST(remat::DisjointSet::update_after_release(tensor)); } piece->is_free = true; piece->tensor = nullptr; piece->is_left = true; Piece* last_piece_insert_to_free_list = piece; Piece* next_p = piece->next; Piece* prev_p = piece->prev; VLOG(2) << "deallocate offset: " << get_offset(piece->ptr) << ", size: " << piece->size << ", prev: " << prev_p << ", next: " << next_p; if (next_p != nullptr && next_p->is_free) { CHECK_EQ(next_p->ptr, piece->ptr + piece->size); EraseFromFreeList(next_p); VLOG(2) << "merge with next_p"; MergeNeighbourFreePiece(piece, next_p); } if (prev_p != nullptr && prev_p->is_free) { CHECK_EQ(piece->ptr, prev_p->ptr + prev_p->size); EraseFromFreeList(prev_p); VLOG(2) << "merge with prev_p"; MergeNeighbourFreePiece(prev_p, piece); last_piece_insert_to_free_list = prev_p; } InsertToFreeList(last_piece_insert_to_free_list); total_deallocate_bytes_ += size; CheckPieces(); } size_t RematEpAllocator::allocated_memory() { CHECK_GE(total_allocate_bytes_, total_deallocate_bytes_); return total_allocate_bytes_ - total_deallocate_bytes_; } void RematEpAllocator::DeviceReset() { ReentrantThreadSafeLock::RAIIGuard guard(thread_lock_); backend_->DeviceReset(); } nlohmann::json RematEpAllocator::DumpSearchFreeMemCost() { return {{"overhead", search_free_mem_cost_}}; } } // namespace vm vm::RematEpAllocator* remat::AllocatorManager::CreateOrGetAllocator(DeviceType device_type, size_t device_index) { auto key = std::make_pair(device_type, device_index); auto it = allocators_.find(key); if (it == allocators_.end()) { auto ep_device = Singleton::Get()->GetDevice(device_type, device_index); auto ep_backend_allocator = std::make_unique(ep_device, ep::AllocationOptions{}); auto allocator = std::make_unique(ep::kMaxAlignmentRequirement, std::move(ep_backend_allocator)); allocators_.emplace(key, std::move(allocator)); return allocators_.at(key).get(); } else { return it->second.get(); } } } // namespace oneflow ================================================ FILE: oneflow/core/vm/remat/allocator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_DTR_EP_ALLOCATOR_H_ #define ONEFLOW_CORE_VM_DTR_EP_ALLOCATOR_H_ #include #include "oneflow/core/common/env_var/remat.h" #include "oneflow/core/ep/include/device.h" #include "oneflow/core/vm/allocator.h" #include "oneflow/core/common/util.h" #include "nlohmann/json.hpp" #include "oneflow/core/vm/thread_safe_guard.h" namespace oneflow { namespace vm { class EagerBlobObject; class RematableTensorStorage; class RematEpAllocator final : public Allocator { public: OF_DISALLOW_COPY_AND_MOVE(RematEpAllocator); explicit RematEpAllocator(size_t alignment, std::unique_ptr&& backend); ~RematEpAllocator() override; void DeviceReset() override; Maybe Allocate(char** mem_ptr, std::size_t size) override; void Deallocate(char* mem_ptr, std::size_t size) override; void LinkStorageAndPtr(RematableTensorStorage* storage, const char* mem_ptr); void CheckPieces(); void DisplayAllPieces(); nlohmann::json DumpSearchFreeMemCost(); size_t allocated_memory(); void set_left(bool is_left) { left = is_left; } bool left = true; size_t iterate_group_index(bool high) const; bool first_time = true; private: const size_t alignment_; const std::unique_ptr backend_; ReentrantThreadSafeLock thread_lock_; using offset_t = size_t; offset_t get_offset(const char* mem_ptr) const; // Piece is the basic memory unit of CudaAllocator. // A Piece is either is free(is_free = true) or in used(is_free = false). // Pieces are stored in a linked list. The Piece's prev and next are // continuous with the current Piece in physical memory. struct Piece { size_t size = 0; char* ptr = nullptr; bool is_free = true; Piece* prev = nullptr; Piece* next = nullptr; vm::RematableTensorStorage* tensor = nullptr; bool is_left = true; }; Maybe InSmallMemoryArea(void* ptr); offset_t FindProperPositionInGroup(Piece* piece, size_t group_idx, size_t request_size) const; Piece* AllocateMemoryInPiece(Piece* piece, offset_t offset_in_piece, size_t size); void InsertToFreeList(Piece* piece); void EraseFromFreeList(Piece* piece); void InitMemory(); // Try find free Piece which size is larger than aligned_size // Return nullptr when find failure Maybe FindPiece(size_t aligned_size, bool after_eviction); void Display(); // Create new empty Piece or recycle a Piece from recycle_piece_list_ Piece* AllocatePiece(); // Delete a Piece and move in the linked list recycle_piece_list_ void DeallocatePiece(Piece* piece); // Insert a {piece->ptr, piece} pair into the ptr2piece_ map for search Piece when call // Deallocate() void InsertPiece2PtrMap(Piece* piece); // Erase the {piece->ptr, piece} pair from ptr2piece_ because the ptr is useless // Usually call before DeallocatePiece() void ErasePieceFromPtrMap(Piece* piece); void MergeNeighbourFreePiece(Piece* lhs, Piece* rhs); Maybe EvictAndFindPieceOnce(size_t required_size); Maybe EvictAndFindPieceLoop(size_t required_size, bool consider_neighbor); char* memory_ = nullptr; size_t memory_size_; void* small_piece_area_ptr_ = nullptr; // hold the lifetime of Piece std::vector> pieces_; struct PieceCmp { bool operator()(const Piece* lhs, const Piece* rhs) const { if (lhs->size != rhs->size) { return lhs->size < rhs->size; } // compare pointer by raw < or > is undefined behavior return std::less<>{}(lhs->ptr, rhs->ptr); } }; std::vector> free_pieces_overlapping_with_group_; // std::map is sorted by key, so we can find contiguous memory by it std::map ptr2piece_; std::vector> search_free_mem_cost_; Piece* recycle_piece_list_; size_t total_allocate_bytes_ = 0; size_t total_deallocate_bytes_ = 0; // ----- 
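  // Memory-group bookkeeping: the arena is split into normal_group_num_ groups whose
  // [left, right) offsets are stored in group_boundaries_; cur_group_index_id_ and
  // cur_group_index_id_high_cost_ round-robin through group_indexes_ for cheap and
  // high-compute-cost ops respectively, and enable_left_and_right_ makes paired groups
  // fill from opposite ends of the same boundary range.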
  size_t normal_group_num_;
  std::vector<size_t> group_indexes_;
  mutable size_t cur_group_index_id_;
  mutable size_t cur_group_index_id_high_cost_;
  bool enable_left_and_right_;
  std::vector<std::pair<offset_t, offset_t>> group_boundaries_;
  size_t group_index(bool high) const;
};

class DtrEpAllocatorProxy final : public Allocator {
 public:
  explicit DtrEpAllocatorProxy(vm::RematEpAllocator* allocator) : allocator(allocator) {}
  void DeviceReset() override { allocator->DeviceReset(); }
  Maybe<void> Allocate(char** mem_ptr, std::size_t size) override {
    return allocator->Allocate(mem_ptr, size);
  }
  void Deallocate(char* mem_ptr, std::size_t size) override {
    allocator->Deallocate(mem_ptr, size);
  }
  vm::RematEpAllocator* const allocator;
};

}  // namespace vm

namespace remat {

class AllocatorManager {
 public:
  vm::RematEpAllocator* CreateOrGetAllocator(DeviceType device_type, size_t device_index);

 private:
  std::unordered_map<std::pair<DeviceType, size_t>, std::unique_ptr<vm::RematEpAllocator>>
      allocators_;
};

}  // namespace remat
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_DTR_EP_ALLOCATOR_H_


================================================
FILE: oneflow/core/vm/remat/disjoint_set.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/vm/remat/disjoint_set.h"

#include "oneflow/core/vm/op_call_instruction_policy.h"
#include "oneflow/core/eager/tensor_storage.h"
#include "oneflow/core/vm/remat/allocator.h"

namespace oneflow {
namespace remat {

void DisjointSet::merge(std::shared_ptr<DisjNode>& x, std::shared_ptr<DisjNode>& y) {
  auto parent_x = find_father(x);
  auto parent_y = find_father(y);
  if (parent_x.get() == parent_y.get()) { return; }

  parent_y->set_compute_time(parent_y->compute_time() + parent_x->compute_time());
  parent_x->set_parent(parent_y);
}

std::shared_ptr<DisjNode> DisjointSet::find_father(std::shared_ptr<DisjNode>& x) {
  if (x->is_root()) {
    return x;
  } else {
    auto fa = x->parent();
    auto y = find_father(fa);
    x->set_parent(y);
    return y;
  }
}

void DisjointSet::update_after_compute(vm::RematableTensorStorage* obj) {
  auto fa = find_father(obj->node);
  fa->set_compute_time(fa->compute_time() - obj->node->compute_time());
  obj->node->reset(obj->compute_time());
}

Maybe<void> DisjointSet::update_after_release(vm::RematableTensorStorage* obj) {
  CHECK_NOTNULL_OR_RETURN(obj);
  if (obj->is_eviction_disabled()) { return Maybe<void>::Ok(); }

  const auto merge_nodes = [&obj](const auto& eager_blob_objects) {
    for (int i = 0; i < eager_blob_objects.size(); ++i) {
      if (auto storage = std::dynamic_pointer_cast<vm::RematableTensorStorage>(
              eager_blob_objects[i]->tensor_storage());
          storage && !storage->is_in_memory()) {
        merge(storage->node, obj->node);
      }
    }
  };
  auto operand = obj->compute_op();
  merge_nodes(operand.inputs());
  merge_nodes(operand.outputs());
  return Maybe<void>::Ok();
}

}  // namespace remat
}  // namespace oneflow


================================================
FILE: oneflow/core/vm/remat/disjoint_set.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once

#include <memory>

#include "oneflow/core/common/maybe.h"

namespace oneflow {

namespace vm {
class RematableTensorStorage;
}

namespace remat {

class DisjNode {
 public:
  explicit DisjNode(double time) : compute_time_(time), parent_(nullptr), cnt_(1) {}

  bool is_root() { return !bool(parent_); }

  void set_parent(std::shared_ptr<DisjNode>& parent) { parent_ = parent; }
  void set_compute_time(double new_time) { compute_time_ = new_time; }
  void set_cnt(int cnt) { cnt_ = cnt; }
  void add_cnt() { cnt_++; }
  void reduce_cnt() { cnt_--; }

  double compute_time() { return compute_time_; }
  std::shared_ptr<DisjNode> parent() { return parent_; }
  int cnt() { return cnt_; }

  void reset(double t) {
    compute_time_ = t;
    parent_.reset();
  }

 private:
  double compute_time_;
  std::shared_ptr<DisjNode> parent_;
  int cnt_;
};

class DisjointSet {
 public:
  static void merge(std::shared_ptr<DisjNode>& x, std::shared_ptr<DisjNode>& y);
  static std::shared_ptr<DisjNode> find_father(std::shared_ptr<DisjNode>& x);
  static void update_after_compute(vm::RematableTensorStorage* obj);
  static Maybe<void> update_after_release(vm::RematableTensorStorage* obj);
};

}  // namespace remat
}  // namespace oneflow


================================================
FILE: oneflow/core/vm/remat/env.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/vm/remat/env.h" #include "nlohmann/json.hpp" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/vm/op_call_instruction_policy.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace remat { vm::OpCallInstructionPolicy Env::update_tensor_with_storage( vm::RematableTensorStorage* storage, const vm::OpCallInstructionPolicy& current_compute_op) { // TODO: set disjnode properly auto new_storage = std::make_shared(storage->device()); std::unordered_map> old2new; auto update = [&new_storage, &old2new](std::shared_ptr& old) { auto it = old2new.find(old.get()); if (it != old2new.end()) { old = it->second; } else { auto local_tensor_meta = old->tensor_meta(); const auto& eager_blob_object = std::make_shared( std::make_shared(old->mem_case()), local_tensor_meta, old->mut_tensor_meta(), local_tensor_meta->dtype(), local_tensor_meta->memory_format(), new_storage); eager_blob_object->set_storage_offset(old->storage_offset()); old2new.emplace(old.get(), eager_blob_object); old = eager_blob_object; } }; auto update_output = [&old2new, &new_storage](std::weak_ptr& old) { auto it = old2new.find(CHECK_NOTNULL(old.lock()).get()); if (it != old2new.end()) { old = it->second; } else { auto old_locked = old.lock(); auto local_tensor_meta = old_locked->tensor_meta(); const auto& eager_blob_object = std::make_shared( std::make_shared(old_locked->mem_case()), local_tensor_meta, old_locked->mut_tensor_meta(), local_tensor_meta->dtype(), local_tensor_meta->memory_format(), new_storage); eager_blob_object->set_storage_offset(old_locked->storage_offset()); old2new.emplace(old_locked.get(), eager_blob_object); old = eager_blob_object; } }; for (int i = ops.size() - 1; i >= 0; i--) { auto& op = ops[i]; for (int j = 0; j < op->mut_inputs().size(); j++) { auto& x = op->mut_inputs()[j]; if (x == nullptr) { LOG(INFO) << "No." << j << " input of " << op->opkernel().op_type_name() << " is nullptr" << std::endl; continue; } if (x->tensor_storage().get() == storage) { vm::EagerBlobObject* old_ptr = x.get(); update(x); VLOG(1) << "update input of " << op->opkernel().op_type_name() << " from " << old_ptr << " (storage " << storage << ") to " << x.get() << " (storage " << new_storage.get() << "), op addr " << op << std::endl; } } for (int j = 0; j < op->mut_outputs().size(); j++) { auto& y = op->mut_outputs()[j]; if (y.lock() == nullptr) { LOG(INFO) << "No." 
<< j << " output of " << op->opkernel().op_type_name() << " is nullptr" << std::endl; continue; } if (CHECK_NOTNULL(y.lock())->tensor_storage().get() == storage) { vm::EagerBlobObject* old_ptr = y.lock().get(); update_output(y); VLOG(1) << "update output of " << op->opkernel().op_type_name() << " from " << old_ptr << " (storage " << storage << ") to " << y.lock().get() << " (storage " << new_storage.get() << "), op addr " << op << std::endl; } } } vm::OpCallInstructionPolicy new_compute_op = current_compute_op; // only update inputs for (auto& x : new_compute_op.mut_inputs()) { if (x->tensor_storage().get() == storage) { vm::EagerBlobObject* old_ptr = x.get(); update(x); VLOG(1) << "update input of " << new_compute_op.opkernel().op_type_name() << " from " << old_ptr << " to " << x.get() << std::endl; } } VLOG(1) << "update_tensor_with_storage: storage " << storage->id(); // set compute_op_ and compute_time_ new_storage->set_compute_op(storage->dtr_compute_op(), storage->compute_time()); // set blob_bytes_ new_storage->set_blob_dptr(nullptr, storage->blob_bytes()); // set is_initialized_ new_storage->set_initialized(); // set last_access_time_ new_storage->Access(); storage->clear_compute_op(); return new_compute_op; } void Env::add_eviction_num(bool eager_eviction) { if (eager_eviction) { eager_eviction_num_++; } else { forced_eviction_num_++; } } Env::~Env() { LOG(INFO) << "forced eviction num: " << forced_eviction_num_; LOG(INFO) << "eager eviction num: " << eager_eviction_num_; LOG(INFO) << "recomputation num: " << recomputation_num_; LOG(INFO) << "duration: " << time_now_; const char* prefix = std::getenv("ONEFLOW_REMAT_SUMMARY_FILE_PREFIX"); if (prefix != nullptr && GlobalProcessCtx::LocalRank() == 0) { using json = nlohmann::json; json cpp_summary{{"forced eviction", forced_eviction_num_}, {"eager eviction", eager_eviction_num_}, {"recomputation", recomputation_num_}, {"dataset time", time_now_}}; json full_json; // std::fstream has strange default append semantic { std::ifstream fs(std::string(prefix) + ".json"); if (fs.is_open()) { fs >> full_json; } } full_json.merge_patch(cpp_summary); { std::ofstream fs(std::string(prefix) + ".json"); fs << full_json; } } } } // namespace remat } // namespace oneflow ================================================ FILE: oneflow/core/vm/remat/env.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once #include "oneflow/core/common/env_var/remat.h" #include "oneflow/core/common/util.h" #define VLOG_REMAT(verbose_level) \ if (Singleton::Get()->log_enabled()) VLOG(verbose_level) namespace oneflow { namespace vm { class RematableTensorStorage; class OpCallInstructionPolicy; class DtrOpCallInstructionPolicy; } // namespace vm namespace remat { class Env { public: Env() = default; ~Env(); OF_DISALLOW_COPY_AND_MOVE(Env); double time_now() { return time_now_; } void add_time(double time) { time_now_ += time; } void remove_compute_op(vm::DtrOpCallInstructionPolicy* op) { ops.erase(std::remove(ops.begin(), ops.end(), op), ops.end()); } vm::OpCallInstructionPolicy update_tensor_with_storage( vm::RematableTensorStorage* storage, const vm::OpCallInstructionPolicy& current_compute_op); std::vector ops; void add_eviction_num(bool eager_eviction); int eager_eviction_num() const { return eager_eviction_num_; } int forced_eviction_num() const { return forced_eviction_num_; } void add_recomputation_num() { recomputation_num_++; } int recomputation_num() const { return recomputation_num_; } void clear_stats() { time_now_ = 0; eager_eviction_num_ = 0; forced_eviction_num_ = 0; recomputation_num_ = 0; } std::set need_eager_eviction_storages; void set_budget_in_bytes(size_t budget_in_bytes) { budget_in_bytes_ = budget_in_bytes; } size_t budget_in_bytes() const { return budget_in_bytes_; } void set_small_pieces_optimization(bool enabled) { small_pieces_optimization_ = enabled; } bool is_small_pieces_optimization_enabled() const { return small_pieces_optimization_; } bool log_enabled() const { return EnvBool(); } private: double time_now_ = 0; int eager_eviction_num_ = 0; int forced_eviction_num_ = 0; int recomputation_num_ = 0; size_t budget_in_bytes_ = 0; bool small_pieces_optimization_ = true; }; struct CurrentOpTypeName { std::string value; }; } // namespace remat } // namespace oneflow ================================================ FILE: oneflow/core/vm/remat/util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/remat/util.h" #include #include "nlohmann/json.hpp" #include "oneflow/core/common/env_var/remat.h" #include "oneflow/core/common/env_var/vm.h" #include "oneflow/core/eager/tensor_storage.h" #include "oneflow/core/framework/compute_complexity_fn_context.h" #include "oneflow/core/vm/op_call_instruction_policy.h" #include "oneflow/core/vm/remat/env.h" #include "oneflow/core/vm/remat/disjoint_set.h" #include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/framework/user_op_registry_manager.h" namespace oneflow { namespace remat { double append_memory_frag_info_and_get(size_t free_mem, size_t threshold) { static size_t num = 0; // maintain a summation of memory frag rate static double memory_frag_rate_sum = 0; if (threshold > 0) { memory_frag_rate_sum += (1. 
* free_mem / threshold); num++; } return memory_frag_rate_sum / num; } namespace { std::string SortKey(const std::string& key) { const auto shape_finish_at = key.rfind(")"); if (shape_finish_at == std::string::npos || shape_finish_at + 2 == key.size()) { return key; } const auto name_and_shape = key.substr(0, shape_finish_at + 1); auto attrs = key.substr(shape_finish_at + 2); if (attrs.substr(attrs.size() - 2) == ", ") { attrs = attrs.substr(0, attrs.size() - 2); } const auto need_find_next = [](const std::string& s, size_t index) { const size_t final_pos = index + 2; if (final_pos >= s.size()) { return false; } if (s.at(index + 1) != ' ') { return true; } if (!(s.at(final_pos) >= 'a' && s.at(final_pos) <= 'z')) { return true; } return false; }; const auto split = [&need_find_next](const std::string& s, std::vector& tokens, const std::string& delimiters) { std::string::size_type lastPos = s.find_first_not_of(delimiters, 0); std::string::size_type pos = s.find_first_of(delimiters, lastPos); while (std::string::npos != pos && need_find_next(s, pos)) { pos = s.find_first_of(delimiters, pos + 1); } while (std::string::npos != pos || std::string::npos != lastPos) { tokens.push_back(s.substr(lastPos, pos - lastPos)); lastPos = s.find_first_not_of(delimiters, pos); pos = s.find_first_of(delimiters, lastPos); while (std::string::npos != pos && need_find_next(s, pos)) { pos = s.find_first_of(delimiters, pos + 1); } } }; std::vector attrs_splited; split(attrs, attrs_splited, ", "); std::sort(attrs_splited.begin(), attrs_splited.end()); return fmt::format("{} {}, ", name_and_shape, fmt::join(attrs_splited, ", ")); } using json = nlohmann::json; json LoadTimeDataset() { json j; if (const char* c = std::getenv("ONEFLOW_REMAT_OP_TIME_DATASET")) { std::ifstream i(c); i >> j; i.close(); } json new_j; for (json::iterator iter = j.begin(); iter != j.end(); ++iter) { new_j[SortKey(iter.key())] = iter.value(); } return new_j; } Maybe GetDatasetComputeTime(const json& j, const vm::OpCallInstructionPolicy& operand) { const std::vector zero_time_list{ "empty", "identity", "constant", "copy", "zero_like", "expand_dims", "flatten", "reduce_sum", "reshape", "reshape_like", "squeeze", "transpose", "nll", "nll_grad", "uniform", "uniform_int", "fill_", "slice_update", "normal", // ddp "eager_ccl_broadcast", "eager_ccl_all_reduce", "eager_ccl_touch", "scalar_mul", // "adaptive_avg_pool2d", // "adaptive_avg_pool2d_grad" }; for (const auto& x : zero_time_list) { if (operand.opkernel().op_type_name() == x) { return 0; } } const std::string op_type_str = operand.opkernel().op_type_name(); const std::string input_shape_str = [&]() { std::stringstream ss; for (size_t i = 0; i < operand.inputs().size(); i++) { ss << operand.inputs().at(i)->shape(); if (i != operand.inputs().size() - 1) { ss << ", "; } } return ss.str(); }(); const std::string attr_str = operand.composed_attrs().ToString(); std::string key = op_type_str + " " + input_shape_str + " " + attr_str; key = SortKey(key); CHECK_OR_RETURN(j.contains(key)) << "key " << key << " not found"; CHECK_OR_RETURN(j[key].is_number_float()) << "key " << key << " is not float, but " << j[key]; return j[key].get(); } static Maybe GetComputeComplexityEstimatedBySize( const vm::OpCallInstructionPolicy& operand) { const auto& inputs = operand.inputs(); const auto& outputs = operand.outputs(); size_t estimated_compute_time = 0; for (const auto& input : inputs) { estimated_compute_time += input->shape().elem_cnt(); } for (const auto& output : outputs) { estimated_compute_time += 
output->shape().elem_cnt(); } return estimated_compute_time; } int32_t TryGetTensorTupleIndex(const std::unordered_map>& arg_name2bn_index2tensor_tuple_index, const std::string& arg_name, const int32_t arg_index) { auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name); if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); } return -1; } class SingleDeviceOpComputeComplexityFnContext : public user_op::ComputeComplexityFnContext { public: using ArgVec = std::vector>; SingleDeviceOpComputeComplexityFnContext(const OperatorConf& op_conf, const vm::EagerBlobObjectList& inputs, const vm::EagerBlobObjectList& outputs, const ArgTuple* input_arg_tuple, const ArgTuple* output_arg_tuple) : user_op::ComputeComplexityFnContext(user_op::UserOpConfWrapper(op_conf)), input_tensors_(inputs), output_tensors_(outputs), input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {} ~SingleDeviceOpComputeComplexityFnContext() override = default; #define RETURN_IF_FOUND(inputs, outputs, post_action) \ int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), \ arg_name, index); \ if (i >= 0) { return (inputs).at(i) post_action; } \ i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \ index); \ if (i >= 0) { return (outputs).at(i) post_action; } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { RETURN_IF_FOUND(input_tensors_, output_tensors_, ->tensor_meta().shared_from_symbol().get()); return nullptr; } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { RETURN_IF_FOUND(input_tensors_, output_tensors_, ->shape()) UNIMPLEMENTED_THEN_THROW(); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { RETURN_IF_FOUND(input_tensors_, output_tensors_, ->data_type()) UNIMPLEMENTED_THEN_THROW(); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return false; } const NdSbp NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { static NdSbp nd_sbp = []() { NdSbp nd_sbp; nd_sbp.add_sbp_parallel()->broadcast_parallel(); return nd_sbp; }(); return nd_sbp; } const ArgVec& inputs() const override { UNIMPLEMENTED_THEN_THROW(); } const ArgVec& outputs() const override { UNIMPLEMENTED_THEN_THROW(); } const ParallelDesc& parallel_desc() const override { static ParallelDesc parallel_desc = []() { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-0"); return ParallelDesc(parallel_conf); }(); return parallel_desc; } const NdSbpSignature* GetNdSbpSignature() const override { UNIMPLEMENTED_THEN_THROW(); } private: const vm::EagerBlobObjectList& input_tensors_; const vm::EagerBlobObjectList& output_tensors_; const ArgTuple* input_arg_tuple_; const ArgTuple* output_arg_tuple_; }; Maybe GetComputeComplexity(const vm::OpCallInstructionPolicy& operand) { const auto& op_conf = operand.opkernel().op_conf(); auto registry = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf.user_conf().op_type_name()); if (registry->compute_complexity_fn) { SingleDeviceOpComputeComplexityFnContext ctx(op_conf, operand.inputs(), operand.outputs(), operand.opkernel().input_arg_tuple(), operand.opkernel().output_arg_tuple()); return registry->compute_complexity_fn(&ctx); } else { return GetComputeComplexityEstimatedBySize(operand); } } } // namespace Maybe 
GetComputeTime(const vm::OpCallInstructionPolicy& operand) { const static json time_dataset = LoadTimeDataset(); if (!time_dataset.empty()) { return GetDatasetComputeTime(time_dataset, operand); } return GetComputeComplexity(operand); } } // namespace remat namespace vm { RematHelper::RematHelper(const OpCallInstructionPolicy& op_call_instruction_policy_) : op_call_instruction_policy_(op_call_instruction_policy_) { const auto save_eager_blob_object_storages = [](const auto& eager_blob_objects, auto& storage_conatiner) { storage_conatiner.reserve(eager_blob_objects.size()); for (const auto& x : eager_blob_objects) { storage_conatiner.emplace_back( std::dynamic_pointer_cast(x->tensor_storage())); } }; save_eager_blob_object_storages(op_call_instruction_policy_.inputs(), input_storages_); save_eager_blob_object_storages(op_call_instruction_policy_.outputs(), output_storages_); } RematHelper::RematHelper(const OpCallInstructionPolicy& op_call_instruction_policy, bool inputs_rematable, bool outputs_rematable) : RematHelper(op_call_instruction_policy) { if (outputs_rematable) { storage_is_initialized_.reserve(output_storages_.size()); for (auto& storage : output_storages_) { storage_is_initialized_.push_back(storage->is_initialized()); } if (!inputs_rematable) { for (auto& storage : output_storages_) { VLOG_REMAT(1) << "set storage " << storage->id() << " unevictable" << std::endl; storage->set_eviction_disabled(true); } } } } Maybe RematHelper::_IncReferenceNumOfRecomputedTensor( int& pinned_num, std::set& visited_ops) { VLOG_REMAT(1) << "op is " << op_call_instruction_policy_.opkernel().op_type_name(); for (int i = 0; i < input_storages_.size(); i++) { auto& storage = input_storages_[i]; storage->Pin(); VLOG_REMAT(1) << "No." << i << " input is in memory? " << storage->is_in_memory(); if (!storage->is_in_memory()) { OpCallInstructionPolicy tmp_op = storage->compute_op(); if (!storage->is_needed_by_backward()) { Singleton::Get()->need_eager_eviction_storages.insert(storage.get()); } if (visited_ops.find(storage->dtr_compute_op().get()) == visited_ops.end()) { visited_ops.insert(storage->dtr_compute_op().get()); RematHelper new_helper(tmp_op); JUST(new_helper._IncReferenceNumOfRecomputedTensor(pinned_num, visited_ops)); } } else { pinned_num++; } } VLOG_REMAT(1) << "op " << op_call_instruction_policy_.opkernel().op_type_name() << " end"; return Maybe::Ok(); } Maybe RematHelper::IncReferenceNumOfRecomputedTensor() { int pinned_num = 0; std::set visited_ops; JUST(_IncReferenceNumOfRecomputedTensor(pinned_num, visited_ops)); return pinned_num; } Maybe RematHelper::RematInputs( vm::Stream* vm_stream, bool first, const std::function(OpCallInstructionPolicy*, vm::Stream*)>& compute_fn) { CHECK_OR_RETURN(!ThreadLocalEnvBool()); if (first) { JUST(IncReferenceNumOfRecomputedTensor()); } VLOG_REMAT(1) << "compute " << op_call_instruction_policy_.opkernel().op_type_name() << std::endl; VLOG_REMAT(1) << "input num " << op_call_instruction_policy_.inputs().size() << std::endl; for (int i = 0; i < input_storages_.size(); i++) { auto& storage = input_storages_[i]; if (!storage->is_in_memory()) { VLOG_REMAT(1) << "recompute No." << i << " input by " << storage->compute_op_type_name() << ". 
Storage id: " << storage->id(); OpCallInstructionPolicy tmp_op = storage->compute_op(); JUST(compute_fn(&tmp_op, vm_stream)); } } return Maybe::Ok(); } Maybe RematHelper::EagerlyEvictRemattedTensors(bool first) { auto& need_eager_eviction_storages = Singleton::Get()->need_eager_eviction_storages; for (auto& storage : input_storages_) { storage->Unpin(); if (storage->num_pinned() == 0 && need_eager_eviction_storages.count(storage.get()) > 0) { need_eager_eviction_storages.erase(storage.get()); storage->Evict(true); } } if (first) { if (!need_eager_eviction_storages.empty()) { for (const auto& storage : need_eager_eviction_storages) { VLOG_REMAT(1) << "not empty, storage id: " << storage->id(); } } CHECK_OR_RETURN(need_eager_eviction_storages.empty()); } return Maybe::Ok(); } Maybe RematHelper::UpdateRematInfo(bool first, bool recompute, bool include_input, bool include_output) { if (include_output) { const std::unique_ptr compute_op = [&]() { auto compute_op = std::make_unique(op_call_instruction_policy_); for (int i = 0; i < output_storages_.size(); i++) { const auto& storage = output_storages_[i]; VLOG_REMAT(1) << "output " << i << " storage id: " << storage->id(); if (storage->is_eviction_disabled()) { continue; } if (storage_is_initialized_[i] && !recompute) { VLOG_REMAT(1) << "storage->is_initialized(), op is " << storage->compute_op_type_name() << std::endl; compute_op = std::make_unique( Singleton::Get()->update_tensor_with_storage( storage.get(), op_call_instruction_policy_)); } } return compute_op; }(); std::shared_ptr dtr_compute_op = std::make_shared(*compute_op); double compute_time = JUST(remat::GetComputeTime(*compute_op)); for (auto& storage : output_storages_) { storage->Pin(); if (!recompute && !storage->is_eviction_disabled()) { storage->set_compute_op(dtr_compute_op, compute_time); } storage->Unpin(); storage->Access(); remat::DisjointSet::update_after_compute(storage.get()); } } if (include_input) { for (int i : op_call_instruction_policy_.opkernel().input_tuple_indexes4mut_ibns()) { input_storages_[i]->set_eviction_disabled(true); } for (auto& storage : input_storages_) { storage->Access(); } } if (recompute) { Singleton::Get()->add_recomputation_num(); } Singleton::Get()->add_time(JUST(remat::GetComputeTime(op_call_instruction_policy_))); VLOG_REMAT(1) << "end compute " << op_call_instruction_policy_.opkernel().op_type_name() << std::endl; return Maybe::Ok(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/remat/util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once #include #include "oneflow/core/common/maybe.h" namespace oneflow { namespace vm { class OpCallInstructionPolicy; } namespace remat { double append_memory_frag_info_and_get(size_t free_mem, size_t threshold); Maybe GetComputeTime(const vm::OpCallInstructionPolicy& operand); } // namespace remat namespace vm { class RematableTensorStorage; class Stream; class DtrOpCallInstructionPolicy; // This class is mainly for holding RematableTensorStorage vector so that we do not // need to generate them every time. class RematHelper { public: explicit RematHelper(const OpCallInstructionPolicy& op_call_instruction_policy); RematHelper(const OpCallInstructionPolicy& op_call_instruction_policy, bool inputs_rematable, bool outputs_rematable); Maybe RematInputs( vm::Stream* vm_stream, bool first, const std::function(OpCallInstructionPolicy*, vm::Stream*)>& compute_fn); Maybe EagerlyEvictRemattedTensors(bool first); Maybe UpdateRematInfo(bool first, bool recompute, bool include_input, bool include_output); private: Maybe IncReferenceNumOfRecomputedTensor(); Maybe _IncReferenceNumOfRecomputedTensor( int& pinned_num, std::set& visited_ops); const OpCallInstructionPolicy& op_call_instruction_policy_; std::vector> input_storages_; std::vector> output_storages_; std::vector storage_is_initialized_; }; } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/stream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/vm/stream_create_stream_policy.h" #include "oneflow/core/framework/stream_on_independent_thread.h" namespace oneflow { namespace vm { void Stream::__Init__(ThreadCtx* thread_ctx, Symbol device, StreamType stream_type, const intrusive::shared_ptr& schedule_local_dep_object, const std::vector>& transport_dependences) { set_thread_ctx(thread_ctx); device_ = device; stream_type_ = stream_type; stream_policy_ = CHECK_JUST(CreateStreamPolicy::Visit(stream_type, device)); schedule_local_dep_object_ = schedule_local_dep_object; transport_dependences_ = transport_dependences; on_scheduler_thread_ = stream_policy_->OnSchedulerThread(stream_type); } int64_t Stream::device_id() const { return device_->device_id(); } char* Stream::CheckSizeAndGetTmpSmallPinnedMemPtr(size_t size) { static constexpr int kSmallSize = 512; CHECK_LE(size, kSmallSize); if (!static_cast(small_pinned_mem_ptr_)) { auto* ep_device = stream_policy_->stream()->device(); void* mem_ptr = nullptr; CHECK_JUST(ep_device->AllocPinned(ep::AllocationOptions{}, &mem_ptr, kSmallSize)); std::function Deleter = [ep_device](char* ptr) { ep_device->FreePinned(ep::AllocationOptions{}, ptr); }; char* ptr = reinterpret_cast(mem_ptr); small_pinned_mem_ptr_ = decltype(small_pinned_mem_ptr_)(ptr, Deleter); } return small_pinned_mem_ptr_.get(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/stream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_STREAM_H_ #define ONEFLOW_CORE_VM_STREAM_H_ #include "oneflow/core/vm/instruction.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/stream_type.h" #include "oneflow/core/vm/stream_policy.h" namespace oneflow { class Device; namespace vm { class ThreadCtx; class MirroredObject; class Dependence; class Stream final : public intrusive::Base { public: // types using DispatchedInstructionList = intrusive::List; // Getters const StreamPolicy& stream_policy() const { return *stream_policy_; } const ThreadCtx& thread_ctx() const { return *thread_ctx_; } bool has_thread_ctx() const { return thread_ctx_ != nullptr; } const intrusive::ListHook& active_stream_hook() const { return active_stream_hook_; } const DispatchedInstructionList& running_instruction_list() const { return running_instruction_list_; } // Setters StreamPolicy* mut_stream_policy() { return stream_policy_.get(); } ThreadCtx* mut_thread_ctx() { return thread_ctx_; } void set_thread_ctx(ThreadCtx* val) { thread_ctx_ = val; } void clear_thread_ctx() { thread_ctx_ = nullptr; } DispatchedInstructionList* mut_running_instruction_list() { return &running_instruction_list_; } // methods void __Init__(ThreadCtx* thread_ctx, Symbol device, StreamType stream_type, const intrusive::shared_ptr& schedule_local_dep_object, const std::vector>& transport_dependences); int64_t device_id() const; Symbol device() const { return device_; } StreamType stream_type() const { return stream_type_; } bool on_scheduler_thread() const { return on_scheduler_thread_; } const intrusive::shared_ptr& schedule_local_dep_object() const { return schedule_local_dep_object_; } const std::vector>& transport_dependences() const { return transport_dependences_; } char* CheckSizeAndGetTmpSmallPinnedMemPtr(size_t size); private: void MoveToFreeList(intrusive::shared_ptr&& instruction); void MoveFromZombieListToFreeList(); friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } Stream() : intrusive_ref_(), thread_ctx_(), device_(), stream_type_(StreamType::kInvalid), stream_policy_(), on_scheduler_thread_(false), small_pinned_mem_ptr_(), running_instruction_list_(), active_stream_hook_(), thread_ctx_stream_hook_() {} intrusive::Ref intrusive_ref_; // fields ThreadCtx* thread_ctx_; Symbol device_; StreamType stream_type_; std::shared_ptr stream_policy_; bool on_scheduler_thread_; std::unique_ptr> small_pinned_mem_ptr_; // lists DispatchedInstructionList running_instruction_list_; intrusive::shared_ptr schedule_local_dep_object_; std::vector> transport_dependences_; public: // list hooks intrusive::ListHook active_stream_hook_; intrusive::ListHook thread_ctx_stream_hook_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_STREAM_H_ ================================================ FILE: oneflow/core/vm/stream_create_stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_STREAM_CREATE_STREAM_POLICY_H_ #define ONEFLOW_CORE_VM_STREAM_CREATE_STREAM_POLICY_H_ #include "oneflow/core/common/symbol.h" #include "oneflow/core/common/stream_type.h" #include "oneflow/core/vm/control_stream_policy.h" #include "oneflow/core/vm/event_recorded_ep_stream_policy.h" #include "oneflow/core/vm/critical_section_stream_policy.h" #include "oneflow/core/vm/ep_d2h_stream_policy.h" #include "oneflow/core/vm/ep_stream_policy.h" #include "oneflow/core/vm/pinned_ep_stream_policy.h" #include "oneflow/core/vm/lazy_job_stream_policy.h" namespace oneflow { class Device; struct CreateStreamPolicy final : public StreamTypeVisitor { static Maybe VisitCompute(Symbol device) { return std::shared_ptr(new vm::EpStreamPolicy(device)); } static Maybe VisitHost2Device(Symbol device) { std::unique_ptr allocator{}; if (device->enum_type() == DeviceType::kCPU) { allocator = vm::EventRecordedEpStreamPolicy::CreateEpBackendDeviceAllocator(device); } else { allocator = std::make_unique("allocator is not supported on h2d stream."); } return std::shared_ptr( new vm::EventRecordedEpStreamPolicy(device, std::move(allocator))); } static Maybe VisitDevice2Host(Symbol device) { return std::shared_ptr(new vm::EpD2HStreamPolicy(device)); } static Maybe VisitCcl(Symbol device) { auto allocator = vm::EventRecordedEpStreamPolicy::CreateEpBackendDeviceAllocator(device); return std::shared_ptr( new vm::EventRecordedEpStreamPolicy(device, std::move(allocator))); } static Maybe VisitBarrier(Symbol device) { return std::shared_ptr(new vm::ControlStreamPolicy()); } static Maybe VisitCriticalSection(Symbol device) { return std::shared_ptr(new vm::CriticalSectionStreamPolicy()); } static Maybe VisitLazyJobLauncher(Symbol device) { return std::shared_ptr(new vm::LazyJobStreamPolicy()); } static Maybe VisitPinnedCompute(Symbol device) { return std::shared_ptr(new vm::PinnedEpStreamPolicy(device)); } }; } // namespace oneflow #endif // ONEFLOW_CORE_VM_STREAM_CREATE_STREAM_POLICY_H_ ================================================ FILE: oneflow/core/vm/stream_get_allocator_stream_type.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
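CreateStreamPolicy above is a per-stream-type factory: each Visit* overload returns the StreamPolicy implementation for one StreamType. A reduced sketch of that factory idea using an invented enum and policy types rather than OneFlow's visitor machinery.

#include <memory>
#include <stdexcept>

enum class ToyStreamType { kCompute, kHost2Device, kDevice2Host, kBarrier };

struct ToyPolicy { virtual ~ToyPolicy() = default; };
struct ComputePolicy final : ToyPolicy {};
struct CopyPolicy final : ToyPolicy {};
struct ControlPolicy final : ToyPolicy {};

// One factory entry point; each case plays the role of a Visit* overload.
std::shared_ptr<ToyPolicy> CreatePolicy(ToyStreamType type) {
  switch (type) {
    case ToyStreamType::kCompute: return std::make_shared<ComputePolicy>();
    case ToyStreamType::kHost2Device:
    case ToyStreamType::kDevice2Host: return std::make_shared<CopyPolicy>();
    case ToyStreamType::kBarrier: return std::make_shared<ControlPolicy>();
  }
  throw std::logic_error("unhandled stream type");
}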
*/ #ifndef ONEFLOW_CORE_VM_STREAM_GET_ALLOCATOR_STREAM_TYPE_H_ #define ONEFLOW_CORE_VM_STREAM_GET_ALLOCATOR_STREAM_TYPE_H_ #include "oneflow/core/common/stream_type.h" namespace oneflow { struct GetAllocatorStreamType final : public StreamTypeVisitor { static Maybe VisitCompute() { return StreamType::kCompute; } static Maybe VisitHost2Device() { return StreamType::kCompute; } static Maybe VisitCcl() { return StreamType::kCompute; } static Maybe VisitPinnedCompute() { return StreamType::kPinnedCompute; } static Maybe VisitDevice2Host() { return StreamType::kDevice2Host; } static Maybe VisitBarrier() { UNIMPLEMENTED_THEN_RETURN() << "no allocator supported on 'barrier' stream_type."; } static Maybe VisitCriticalSection() { UNIMPLEMENTED_THEN_RETURN() << "no allocator supported on 'critical_section' stream_type."; } static Maybe VisitLazyJobLauncher() { UNIMPLEMENTED_THEN_RETURN() << "no allocator supported on 'lazy_job_launcher' stream_type."; } }; } // namespace oneflow #endif // ONEFLOW_CORE_VM_STREAM_GET_ALLOCATOR_STREAM_TYPE_H_ ================================================ FILE: oneflow/core/vm/stream_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/stream_policy.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/framework/stream_on_independent_thread.h" #include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/common/env_var/vm.h" #include "oneflow/core/thread/thread_global_id.h" namespace oneflow { namespace vm { bool StreamPolicy::OnSchedulerThread(StreamType stream_type) const { if (StreamOnIndependentThread::Visit(stream_type)) { return false; } return !ThreadLocalEnvBool(); } void StreamPolicy::RunIf(Instruction* instruction) const { if (IsCommNetStream::Visit(instruction->stream().stream_type()) && ThreadLocalEnvBool()) { ThreadGlobalIdGuard guard{kThreadGlobalIdDefaultWorker}; Run(instruction); } else { Run(instruction); } } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/stream_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_STREAM_POLICY_H_ #define ONEFLOW_CORE_VM_STREAM_POLICY_H_ #include #include #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/resource.pb.h" #include "oneflow/core/common/stream_type.h" #include "oneflow/core/common/symbol.h" namespace oneflow { class EpEventProvider; namespace ep { class Device; class Stream; } // namespace ep namespace vm { class Allocator; class Stream; class InstructionStatusBuffer; class Instruction; class StreamPolicy { public: virtual ~StreamPolicy() = default; virtual ep::Stream* stream() = 0; virtual vm::Allocator* mut_allocator() = 0; virtual DeviceType device_type() const = 0; virtual void InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const = 0; virtual void DeleteInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const = 0; virtual bool QueryInstructionStatusLaunched( const Stream& stream, const InstructionStatusBuffer& status_buffer) const = 0; virtual bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const = 0; virtual bool OnSchedulerThread(StreamType stream_type) const; virtual bool SupportingTransportInstructions() const = 0; void RunIf(Instruction* instruction) const; protected: StreamPolicy() = default; private: virtual void Run(Instruction* instruction) const = 0; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_STREAM_POLICY_H_ ================================================ FILE: oneflow/core/vm/stream_record_event_instruction_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
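StreamPolicy above follows the non-virtual-interface idiom: the public RunIf wrapper owns the shared dispatch logic (e.g. choosing the worker thread context) and forwards to the private virtual Run. A minimal sketch of that idiom with invented types.

#include <iostream>

class ToyPolicy {
 public:
  virtual ~ToyPolicy() = default;
  // Public, non-virtual entry point: shared pre/post logic lives here.
  void RunIf(int instruction_id) const {
    // e.g. pick a thread context or install a guard before dispatching
    Run(instruction_id);
  }

 private:
  virtual void Run(int instruction_id) const = 0;  // per-policy behavior
};

class PrintPolicy final : public ToyPolicy {
 private:
  void Run(int instruction_id) const override {
    std::cout << "run instruction " << instruction_id << "\n";
  }
};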
*/ #include "oneflow/core/vm/stream_record_event_instruction_policy.h" #include "oneflow/core/vm/ep_event.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/ep/cuda/cuda_event.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/core/vm/ep_stream_policy_base.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" namespace oneflow { namespace vm { StreamRecordEventInstructionPolicy::StreamRecordEventInstructionPolicy( const small_vector>& dependences) : dependences_(dependences), input_dependences_(), output_dependences_() { for (const auto& dep : dependences_) { output_dependences_.push_back(dep.get()); } } void StreamRecordEventInstructionPolicy::InitInstructionStatus(Instruction* instruction) { auto* stream = instruction->mut_stream(); { auto* ep_stream_policy_base = CHECK_NOTNULL(dynamic_cast(instruction->mut_stream_policy())); ep_stream_policy_base->InitInstructionStatus(*stream, instruction->mut_status_buffer()); auto* ep_event_provider = ep_stream_policy_base->ep_event_provider(); const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent(); mut_ep_event() = ep_event; } { auto* status_buffer = instruction->mut_status_buffer(); instruction->stream_policy().InitInstructionStatus(*stream, status_buffer); auto* data_ptr = status_buffer->mut_buffer(); EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(nullptr); } } void StreamRecordEventInstructionPolicy::Compute(vm::Instruction* instruction) { const auto& ep_event = mut_ep_event(); // Record event. auto* stream_policy = dynamic_cast(instruction->mut_stream()->mut_stream_policy()); CHECK_NOTNULL(stream_policy)->stream()->RecordEvent(ep_event->mut_event()); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/stream_record_event_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_STREAM_RECORD_EVENT_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_STREAM_RECORD_EVENT_INSTRUCTION_POLICY_H_ #include #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/common/small_vector.h" namespace oneflow { class EpEvent; namespace vm { class Stream; class StreamRecordEventInstructionPolicy final : public vm::InstructionPolicy { public: StreamRecordEventInstructionPolicy( const small_vector>& dependences); ~StreamRecordEventInstructionPolicy() = default; std::string DebugName(const vm::Instruction&) const override { return "StreamRecordEvent"; } void InitInstructionStatus(Instruction* instruction) override; Maybe Prepare(vm::Instruction* instruction) override { return Maybe::Ok(); } void Compute(vm::Instruction* instruction) override; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } std::shared_ptr& mut_ep_event() { return ep_event_; } private: small_vector> dependences_; DependenceVector input_dependences_; DependenceVector output_dependences_; std::shared_ptr ep_event_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_STREAM_RECORD_EVENT_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/stream_wait_event_instruction_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/stream_wait_event_instruction_policy.h" #include "oneflow/core/vm/ep_event.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/ep_stream_policy_base.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" namespace oneflow { namespace vm { StreamWaitEventInstructionPolicy::StreamWaitEventInstructionPolicy( const small_vector>& dependences, const std::shared_ptr& stream_record_event_instruction_policy) : dependences_(dependences), input_dependences_(), output_dependences_(), stream_record_event_instruction_policy_(stream_record_event_instruction_policy) { for (const auto& dep : dependences_) { output_dependences_.push_back(dep.get()); } } void StreamWaitEventInstructionPolicy::DeleteInstructionStatus(Instruction* instruction) { auto* stream = instruction->mut_stream(); instruction->stream_policy().DeleteInstructionStatus(*stream, instruction->mut_status_buffer()); stream_record_event_instruction_policy_->mut_ep_event().reset(); } void StreamWaitEventInstructionPolicy::Compute(vm::Instruction* instruction) { const auto& ep_event = stream_record_event_instruction_policy_->mut_ep_event(); // Wait event. 
auto* ep_stream_policy_base = dynamic_cast(instruction->mut_stream()->mut_stream_policy()); CHECK_NOTNULL(ep_stream_policy_base); auto* ep_stream = ep_stream_policy_base->stream(); CHECK_EQ(ep_event->mut_device(), ep_stream->device()) << "only support waiting events from same device"; ep_event->mut_device()->SetAsActiveDevice(); ep_stream->WaitEvent(ep_event->mut_event()); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/stream_wait_event_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_STREAM_WAIT_EVENT_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_STREAM_WAIT_EVENT_INSTRUCTION_POLICY_H_ #include #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/vm/stream_record_event_instruction_policy.h" namespace oneflow { namespace vm { class Stream; class StreamWaitEventInstructionPolicy final : public vm::InstructionPolicy { public: StreamWaitEventInstructionPolicy( const small_vector>& dependences, const std::shared_ptr& stream_record_event_instruction_policy); ~StreamWaitEventInstructionPolicy() = default; std::string DebugName(const vm::Instruction&) const override { return "StreamWaitEvent"; } void DeleteInstructionStatus(Instruction* instruction) override; Maybe Prepare(vm::Instruction* instruction) override { return Maybe::Ok(); } void Compute(vm::Instruction* instruction) override; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } private: small_vector> dependences_; DependenceVector input_dependences_; DependenceVector output_dependences_; std::shared_ptr stream_record_event_instruction_policy_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_STREAM_WAIT_EVENT_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/stream_wait_instruction_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
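The record/wait instruction pair above lets pending work on one stream gate work on another through a device event: the record policy records the event after the producer's work, and the wait policy blocks the consumer stream on it. A conceptual stand-in using std::promise/std::shared_future in place of the EpEvent; it only illustrates the ordering contract, not how device events are implemented.

#include <chrono>
#include <future>
#include <iostream>
#include <thread>

int main() {
  std::promise<void> event;  // plays the role of the recorded event
  std::shared_future<void> recorded = event.get_future().share();

  std::thread producer([&] {  // "record" after the producer's work
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
    event.set_value();
  });
  std::thread consumer([&] {  // "wait" before the consumer's work
    recorded.wait();
    std::cout << "dependent work may start\n";
  });
  producer.join();
  consumer.join();
}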
*/ #include "oneflow/core/vm/stream_wait_instruction_policy.h" #include "oneflow/core/vm/ep_event.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/ep_stream_policy_base.h" #include "oneflow/core/vm/ep_optional_event_record_status_querier.h" namespace oneflow { namespace vm { StreamWaitInstructionPolicy::StreamWaitInstructionPolicy( small_vector>&& dependences, vm::Stream* from_vm_stream, vm::Stream* to_vm_stream) : dependences_(std::move(dependences)), input_dependences_(), output_dependences_(), from_vm_stream_(from_vm_stream) { for (const auto& dep : dependences_) { output_dependences_.push_back(dep.get()); } stream_sequential_dependence_ = to_vm_stream->schedule_local_dep_object().get(); } bool StreamWaitInstructionPolicy::Prescheduleable(const Stream* src, const Stream* dst) const { return &src->thread_ctx() == &dst->thread_ctx(); } void StreamWaitInstructionPolicy::InitInstructionStatus(Instruction* instruction) { auto* stream = instruction->mut_stream(); auto* ep_stream_policy_base = CHECK_NOTNULL(dynamic_cast(instruction->mut_stream_policy())); ep_stream_policy_base->InitInstructionStatus(*stream, instruction->mut_status_buffer()); auto* ep_event_provider = ep_stream_policy_base->ep_event_provider(); const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent(); mut_ep_event() = ep_event; } void StreamWaitInstructionPolicy::DeleteInstructionStatus(Instruction* instruction) { auto* stream = instruction->mut_stream(); instruction->stream_policy().DeleteInstructionStatus(*stream, instruction->mut_status_buffer()); mut_ep_event().reset(); } void StreamWaitInstructionPolicy::Compute(vm::Instruction* instruction) { const auto& ep_event = mut_ep_event(); { // Record event. auto* from_naive_stream_policy = dynamic_cast(mut_from_vm_stream()->mut_stream_policy()); CHECK_NOTNULL(from_naive_stream_policy); auto* from_stream = from_naive_stream_policy->stream(); from_stream->RecordEvent(ep_event->mut_event()); } { // Wait event. auto* to_ep_stream_policy_base = dynamic_cast(instruction->mut_stream()->mut_stream_policy()); CHECK_NOTNULL(to_ep_stream_policy_base); auto* to_ep_stream = to_ep_stream_policy_base->stream(); CHECK_EQ(ep_event->mut_device(), to_ep_stream->device()) << "only support waiting events from same device"; ep_event->mut_device()->SetAsActiveDevice(); to_ep_stream->WaitEvent(ep_event->mut_event()); } } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/stream_wait_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_ #include #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/common/small_vector.h" namespace oneflow { class EpEvent; namespace vm { class Stream; class StreamWaitInstructionPolicy final : public vm::InstructionPolicy { public: StreamWaitInstructionPolicy(small_vector>&& dependences, vm::Stream* from_vm_stream, vm::Stream* to_vm_stream); ~StreamWaitInstructionPolicy() = default; std::string DebugName(const vm::Instruction&) const override { return "StreamWait"; } bool Prescheduleable(const Stream* src, const Stream* dst) const override; void InitInstructionStatus(Instruction* instruction) override; void DeleteInstructionStatus(Instruction* instruction) override; Maybe Prepare(vm::Instruction* instruction) override { return Maybe::Ok(); } void Compute(vm::Instruction* instruction) override; const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } private: vm::Stream* mut_from_vm_stream() { return from_vm_stream_; } std::shared_ptr& mut_ep_event() { return ep_event_; } small_vector> dependences_; DependenceVector input_dependences_; DependenceVector output_dependences_; vm::Stream* from_vm_stream_; std::shared_ptr ep_event_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/symbol_storage.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/operator/op_conf_symbol.h" namespace oneflow { namespace symbol { namespace detail { template<> Maybe NewSymbol( int64_t symbol_id, const typename ConstructArgType4Symbol::type& data) { return ParallelDesc::New(symbol_id, data); } template<> Maybe NewSymbol(int64_t symbol_id, const typename ConstructArgType4Symbol::type& data) { return JobDesc::New(symbol_id, data); } template<> Maybe NewSymbol(int64_t symbol_id, const typename ConstructArgType4Symbol::type& data) { return Scope::New(symbol_id, data); } template<> Maybe NewSymbol( int64_t symbol_id, const typename ConstructArgType4Symbol::type& data) { return std::make_shared(symbol_id, data); } } // namespace detail } // namespace symbol } // namespace oneflow ================================================ FILE: oneflow/core/vm/symbol_storage.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
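symbol_storage.cpp above keeps a generic NewSymbol factory template and specializes it for symbol types whose constructors need the symbol id. A generic sketch of that primary-template-plus-specialization pattern with invented data types.

#include <memory>
#include <string>

struct PlainData { std::string value; };
struct IdAwareData {
  IdAwareData(long id, std::string value) : id(id), value(std::move(value)) {}
  long id;
  std::string value;
};

// Primary template: most types ignore the id.
template<typename T>
std::shared_ptr<T> NewSymbol(long /*symbol_id*/, const std::string& value) {
  return std::make_shared<T>(T{value});
}

// Specialization: this type wants the id forwarded into its constructor.
template<>
std::shared_ptr<IdAwareData> NewSymbol<IdAwareData>(long symbol_id, const std::string& value) {
  return std::make_shared<IdAwareData>(symbol_id, value);
}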
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_STORAGE_H_ #define ONEFLOW_CORE_VM_STORAGE_H_ #include #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/container_util.h" namespace oneflow { class OperatorConfSymbol; class OperatorConf; class ParallelDesc; class ParallelConf; class JobDesc; class JobConfigProto; class Scope; class ScopeProto; namespace symbol { template struct ConstructArgType4Symbol final { using type = T; }; template<> struct ConstructArgType4Symbol final { using type = OperatorConf; }; template<> struct ConstructArgType4Symbol final { using type = ParallelConf; }; template<> struct ConstructArgType4Symbol final { using type = JobConfigProto; }; template<> struct ConstructArgType4Symbol final { using type = ScopeProto; }; namespace detail { template Maybe NewSymbol(int64_t symbol_id, const typename ConstructArgType4Symbol::type& data) { return std::make_shared(data); } template<> Maybe NewSymbol( int64_t symbol_id, const typename ConstructArgType4Symbol::type& data); template<> Maybe NewSymbol( int64_t symbol_id, const typename ConstructArgType4Symbol::type& data); template<> Maybe NewSymbol(int64_t symbol_id, const typename ConstructArgType4Symbol::type& data); template<> Maybe NewSymbol(int64_t symbol_id, const typename ConstructArgType4Symbol::type& data); } // namespace detail template class Storage final { public: Storage(const Storage&) = delete; Storage(Storage&&) = delete; Storage() = default; ~Storage() = default; bool Has(int64_t symbol_id) const { std::unique_lock lock(mutex_); return symbol_id2symbol_.find(symbol_id) != symbol_id2symbol_.end(); } bool Has(const typename ConstructArgType4Symbol::type& symbol_data) const { std::unique_lock lock(mutex_); const auto& iter = data2symbol_id_.find(symbol_data); return iter != data2symbol_id_.end(); } Maybe MaybeGet(int64_t symbol_id) const { return *JUST(MaybeGetPtr(symbol_id)); } Maybe MaybeGet(const typename ConstructArgType4Symbol::type& data) const { return *JUST(MaybeGetPtr(data)); } const T& Get(int64_t symbol_id) const { return *GetPtr(symbol_id); } const T& Get(const typename ConstructArgType4Symbol::type& data) const { return *GetPtr(data); } Maybe MaybeGetPtr(int64_t symbol_id) const { std::unique_lock lock(mutex_); const auto& iter = symbol_id2symbol_.find(symbol_id); CHECK_OR_RETURN(iter != symbol_id2symbol_.end()) << "symbol_id: " << symbol_id; return iter->second; } Maybe MaybeGetPtr(const typename ConstructArgType4Symbol::type& data) const { std::unique_lock lock(mutex_); const auto& iter = data2symbol_id_.find(data); CHECK_OR_RETURN(iter != data2symbol_id_.end()); return JUST(MapAt(symbol_id2symbol_, iter->second)); } const std::shared_ptr& GetPtr(int64_t symbol_id) const { std::unique_lock lock(mutex_); const auto& iter = symbol_id2symbol_.find(symbol_id); CHECK(iter != symbol_id2symbol_.end()) << "symbol_id: " << symbol_id; return iter->second; } const std::shared_ptr& GetPtr(const typename ConstructArgType4Symbol::type& data) const { std::unique_lock 
lock(mutex_); const auto& iter = data2symbol_id_.find(data); CHECK(iter != data2symbol_id_.end()); return CHECK_JUST(MapAt(symbol_id2symbol_, iter->second)); } Maybe Add(int64_t symbol_id, const typename ConstructArgType4Symbol::type& data) { CHECK_GT_OR_RETURN(symbol_id, 0); const auto& ptr = JUST(detail::NewSymbol(symbol_id, data)); std::unique_lock lock(mutex_); CHECK_OR_RETURN(symbol_id2symbol_.emplace(symbol_id, ptr).second); data2symbol_id_[data] = symbol_id; return Maybe::Ok(); } Maybe TryAdd(int64_t symbol_id, const typename ConstructArgType4Symbol::type& data) { CHECK_GT_OR_RETURN(symbol_id, 0); const auto& ptr = JUST(detail::NewSymbol(symbol_id, data)); std::unique_lock lock(mutex_); const auto& iter = symbol_id2symbol_.find(symbol_id); if (iter != symbol_id2symbol_.end()) { CHECK_OR_RETURN(data2symbol_id_.find(data) != data2symbol_id_.end()); return Maybe::Ok(); } CHECK_OR_RETURN(symbol_id2symbol_.emplace(symbol_id, ptr).second); data2symbol_id_[data] = symbol_id; return Maybe::Ok(); } Maybe FindOrCreate(const typename ConstructArgType4Symbol::type& symbol_data, const std::function()>& Create) { int64_t symbol_id = JUST(Create()); const auto& ptr = JUST(detail::NewSymbol(symbol_id, symbol_data)); std::unique_lock lock(mutex_); const auto& iter = data2symbol_id_.find(symbol_data); if (iter != data2symbol_id_.end()) { return JUST(MapAt(symbol_id2symbol_, iter->second)); } CHECK_OR_RETURN(symbol_id2symbol_.emplace(symbol_id, ptr).second); data2symbol_id_[symbol_data] = symbol_id; return JUST(MapAt(symbol_id2symbol_, symbol_id)); } void Clear(int64_t symbol_id) { std::unique_lock lock(mutex_); auto iter = symbol_id2symbol_.find(symbol_id); if (iter != symbol_id2symbol_.end()) { data2symbol_id_.erase(iter->second->data()); symbol_id2symbol_.erase(symbol_id); } } void ClearAll() { std::unique_lock lock(mutex_); symbol_id2symbol_.clear(); data2symbol_id_.clear(); } private: mutable std::mutex mutex_; HashMap> symbol_id2symbol_; HashMap::type, int64_t> data2symbol_id_; }; } // namespace symbol } // namespace oneflow #endif // ONEFLOW_CORE_VM_STORAGE_H_ ================================================ FILE: oneflow/core/vm/sync_access_instruction_policy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
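The Storage template above guards two maps (symbol id to symbol object, constructor data to id) with a single mutex so lookups by either key stay consistent. A heavily reduced standalone sketch of that bidirectional, mutex-protected registry with std::string as the payload type.

#include <map>
#include <memory>
#include <mutex>
#include <string>

class ToySymbolStorage {
 public:
  bool Add(long id, const std::string& data) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (!id2symbol_.emplace(id, std::make_shared<std::string>(data)).second) { return false; }
    data2id_[data] = id;
    return true;
  }
  std::shared_ptr<std::string> Get(long id) const {
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = id2symbol_.find(id);
    return it == id2symbol_.end() ? nullptr : it->second;
  }
  void Clear(long id) {
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = id2symbol_.find(id);
    if (it != id2symbol_.end()) {
      data2id_.erase(*it->second);  // drop the reverse entry first
      id2symbol_.erase(it);
    }
  }

 private:
  mutable std::mutex mutex_;
  std::map<long, std::shared_ptr<std::string>> id2symbol_;
  std::map<std::string, long> data2id_;
};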
*/ #include "oneflow/core/vm/sync_access_instruction_policy.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { namespace vm { SyncAccessInstructionPolicy::SyncAccessInstructionPolicy() : host_mem_case_(memory::MakeHostMemCase()), btb_(), mem_ptr_(nullptr), bytes_(0), eager_blob_object_(nullptr) { ResetBase(nullptr, 0, nullptr); } void SyncAccessInstructionPolicy::ResetBase(char* mem_ptr, size_t bytes, EagerBlobObject* eager_blob_object) { btb_.Reset(); mem_ptr_ = mem_ptr; bytes_ = bytes; eager_blob_object_ = eager_blob_object; } namespace { void FastCopy(char* dst, const char* src, size_t bytes) { switch (bytes) { case 1: { *dst = *src; return; } case 2: { *reinterpret_cast(dst) = *reinterpret_cast(src); return; } case 4: { *reinterpret_cast(dst) = *reinterpret_cast(src); return; } case 8: { *reinterpret_cast(dst) = *reinterpret_cast(src); return; } case 16: { using Bit128 = std::pair; *reinterpret_cast(dst) = *reinterpret_cast(src); return; } default: UNIMPLEMENTED() << "FastCopy on bytes " << bytes << " not supported."; } } } // namespace void SyncReadInstructionPolicy::Compute(Instruction* instruction) { StreamPolicy* stream_policy = instruction->mut_stream_policy(); char* pinned_buffer = instruction->mut_stream()->CheckSizeAndGetTmpSmallPinnedMemPtr(bytes_); mut_btb()->mut_notifier()->Notify(); SyncAutoMemcpy(stream_policy->stream(), pinned_buffer, eager_blob_object_->mut_dptr(), bytes_, host_mem_case_, eager_blob_object_->mem_case()); FastCopy(mem_ptr_, pinned_buffer, bytes_); mut_btb()->mut_spin_counter()->Decrease(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/sync_access_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_SYNC_ACCESS_INSTRUCTION_POLICY_H_ #define ONEFLOW_CORE_VM_SYNC_ACCESS_INSTRUCTION_POLICY_H_ #include #include #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/instruction_policy.h" #include "oneflow/core/vm/instruction_policy_util.h" #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/tensor_storage.h" #include "oneflow/core/common/blocking_then_busy.h" #include "oneflow/core/common/util.h" #include "oneflow/core/vm/stream_policy.h" #include "oneflow/core/memory/memory_case_util.h" namespace oneflow { namespace vm { class SyncAccessInstructionPolicy : public InstructionPolicy { public: SyncAccessInstructionPolicy(); virtual ~SyncAccessInstructionPolicy() = default; Maybe Prepare(Instruction* instruction) override { return Maybe::Ok(); } BlockingThenBusy* mut_btb() { return &btb_; } protected: void ResetBase(char* mem_ptr, size_t bytes, EagerBlobObject* eager_blob_object); const MemoryCase host_mem_case_; BlockingThenBusy btb_; char* mem_ptr_; size_t bytes_; EagerBlobObject* eager_blob_object_; }; class SyncReadInstructionPolicy final : public SyncAccessInstructionPolicy { public: SyncReadInstructionPolicy() = default; ~SyncReadInstructionPolicy() = default; const DependenceVector& input_dependences() const override { CHECK_EQ(input_dependences_.size(), 1); return input_dependences_; } const DependenceVector& output_dependences() const override { static thread_local DependenceVector empty{}; return empty; } std::string DebugName(const Instruction& instruction) const override { return "SyncRead"; } void Reset(char* mem_ptr, size_t bytes, EagerBlobObject* eager_blob_object) { ResetBase(mem_ptr, bytes, eager_blob_object); if (likely(input_dependences_.size())) { input_dependences_.clear(); } input_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object())); } void Compute(Instruction* instruction) override; private: DependenceVector input_dependences_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_SYNC_ACCESS_INSTRUCTION_POLICY_H_ ================================================ FILE: oneflow/core/vm/sync_vm_mode_guard.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_SYNC_VM_MODE_GUARD_H_ #define ONEFLOW_CORE_VM_SYNC_VM_MODE_GUARD_H_ #include "oneflow/core/common/thread_local_guard.h" namespace oneflow { enum class SyncVmMode { kInvalid = 0, kEnable = 1, kDisable = 2, }; class SyncVmModeGuard final : public ThreadLocalGuard { public: using ThreadLocalGuard::ThreadLocalGuard; ~SyncVmModeGuard() = default; static bool IsCurrentSyncVmMode() { const auto& opt_sync_mode = Current(); return opt_sync_mode.has_value() && CHECK_JUST(opt_sync_mode) == SyncVmMode::kEnable; } }; } // namespace oneflow #endif // ONEFLOW_CORE_VM_SYNC_VM_MODE_GUARD_H_ ================================================ FILE: oneflow/core/vm/thread_ctx.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { ThreadCtx::ThreadCtx() : intrusive_ref_(), stream_list_(), worker_pending_instruction_mutex_(), worker_pending_instruction_list_(&worker_pending_instruction_mutex_), notifier_(), transport_dependence_(intrusive::make_shared()), thread_ctx_hook_() {} size_t ThreadCtx::TryReceiveAndRun() { intrusive::List tmp_list; mut_worker_pending_instruction_list()->MoveTo(&tmp_list); size_t size = tmp_list.size(); INTRUSIVE_FOR_EACH(instruction, &tmp_list) { tmp_list.Erase(instruction.Mutable()); const StreamPolicy& stream_policy = instruction->stream().stream_policy(); stream_policy.RunIf(instruction.Mutable()); } return size; } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/thread_ctx.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_THREAD__H_ #define ONEFLOW_CORE_VM_THREAD__H_ #include #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/intrusive/mutexed_list.h" #include "oneflow/core/common/notifier.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/vm_object.h" namespace oneflow { namespace vm { using WorkerPendingInstructionMutexedList = intrusive::MutexedList; class ThreadCtx final : public intrusive::Base { public: // types using StreamList = intrusive::List; // Getters const StreamList& stream_list() const { return stream_list_; } // Setters StreamList* mut_stream_list() { return &stream_list_; } WorkerPendingInstructionMutexedList* mut_worker_pending_instruction_list() { return &worker_pending_instruction_list_; } // methods size_t TryReceiveAndRun(); Notifier* mut_notifier() { return ¬ifier_; } const intrusive::shared_ptr& transport_dependence() const { return transport_dependence_; }; private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } ThreadCtx(); intrusive::Ref intrusive_ref_; // lists StreamList stream_list_; std::mutex worker_pending_instruction_mutex_; WorkerPendingInstructionMutexedList worker_pending_instruction_list_; Notifier notifier_; intrusive::shared_ptr transport_dependence_; public: // list hooks intrusive::ListHook thread_ctx_hook_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_THREAD__H_ ================================================ FILE: oneflow/core/vm/thread_safe_guard.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_ #define ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_ #include #include #include #include #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { class ThreadSafeLock final { public: ThreadSafeLock() = default; ~ThreadSafeLock() = default; OF_DISALLOW_COPY_AND_MOVE(ThreadSafeLock); class RAIIGuard final { public: explicit RAIIGuard(ThreadSafeLock& lock) : guard_(lock.mutex4guard) {} ~RAIIGuard() = default; OF_DISALLOW_COPY_AND_MOVE(RAIIGuard); private: std::unique_lock guard_; }; private: std::mutex mutex4guard; }; class ReentrantThreadSafeLock final { public: ReentrantThreadSafeLock() = default; ~ReentrantThreadSafeLock() = default; OF_DISALLOW_COPY_AND_MOVE(ReentrantThreadSafeLock); class RAIIGuard final { public: explicit RAIIGuard(ReentrantThreadSafeLock& lock) : guard_(lock.mutex4guard) {} ~RAIIGuard() = default; OF_DISALLOW_COPY_AND_MOVE(RAIIGuard); private: std::unique_lock guard_; }; private: std::recursive_mutex mutex4guard; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_ ================================================ FILE: oneflow/core/vm/touch_tensors_instruction_policy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_TOUCH_TENSORS_INSTRUCTION_POLICY_H_
#define ONEFLOW_CORE_VM_TOUCH_TENSORS_INSTRUCTION_POLICY_H_

#include "oneflow/core/vm/instruction_policy.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/vm/instruction_policy_util.h"

namespace oneflow {
namespace vm {

// Keeps a set of eager blob objects alive and ordered after their producers without
// doing any computation: the instruction only declares read dependences on the tensors.
class TouchTensorsInstructionPolicy final : public InstructionPolicy {
 public:
  explicit TouchTensorsInstructionPolicy(const vm::EagerBlobObjectList& eager_blob_objects)
      : eager_blob_objects_(eager_blob_objects) {
    const auto& Insert = InstructionPolicyUtil::SetInserter(&input_dependences_);
    for (const auto& eager_blob_object : eager_blob_objects_) {
      Insert(CHECK_JUST(eager_blob_object->compute_local_dep_object()));
    }
  }
  ~TouchTensorsInstructionPolicy() = default;

  const DependenceVector& input_dependences() const override { return input_dependences_; }
  const DependenceVector& output_dependences() const override {
    static DependenceVector empty{};
    return empty;
  }

  std::string DebugName(const vm::Instruction& instruction) const override {
    return "TouchTensors";
  }
  Maybe<void> Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }
  void Compute(vm::Instruction* instruction) override {}

 private:
  vm::EagerBlobObjectList eager_blob_objects_;
  DependenceVector input_dependences_;
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_VM_TOUCH_TENSORS_INSTRUCTION_POLICY_H_

================================================
FILE: oneflow/core/vm/virtual_machine.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include #include #include #include "oneflow/core/vm/sync_vm_mode_guard.h" #include "oneflow/core/vm/barrier_instruction_policy.h" #include "oneflow/core/vm/caching_allocator.h" #include "oneflow/core/vm/global_sync_instruction_policy.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/allocator.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/foreign_lock_helper.h" #include "oneflow/core/thread/thread_global_id.h" #include "oneflow/core/framework/transport_token.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/stream_on_independent_thread.h" #include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/platform/include/pthread_fork.h" #include "oneflow/core/common/env_var/env_var.h" #include "oneflow/core/common/env_var/vm.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/stream_get_stream_type_name.h" #include "oneflow/core/framework/stream_mgr.h" namespace oneflow { namespace { template int MicrosecondsFrom(const T& start) { return std::chrono::duration_cast(std::chrono::steady_clock::now() - start) .count(); } Maybe ForEachThreadCtx(vm::VirtualMachineEngine* engine, const std::function(vm::ThreadCtx*)>& DoEach) { INTRUSIVE_UNSAFE_FOR_EACH_PTR(thread_ctx, engine->mut_thread_ctx_list()) { JUST(DoEach(thread_ctx)); } return Maybe::Ok(); } void GetSchedulerThreadInitializer(std::function* Initializer) { *Initializer = [&]() { OF_PROFILER_NAME_THIS_HOST_THREAD("_VM::Scheduler"); }; } void WorkerLoop(vm::ThreadCtx* thread_ctx, const std::function& Initializer) { SyncVmModeGuard guard(SyncVmMode::kEnable); Initializer(thread_ctx); constexpr static size_t kExpireMicroseconds = 200; while (thread_ctx->mut_notifier()->WaitAndClearNotifiedCnt() == kNotifierStatusSuccess) { std::chrono::time_point start{}; do { while (thread_ctx->TryReceiveAndRun()) { start = std::chrono::steady_clock::now(); } std::this_thread::yield(); } while (MicrosecondsFrom(start) < kExpireMicroseconds); } } } // namespace VirtualMachine::VirtualMachine() : multi_thread_(ThreadLocalEnvBool()), threads_closed_(false), scheduler_stopped_(false) { // Class VirtualMachineEngine only cares the basic logical of vm, while class VirtualMachine // manages threads and condition variables. // In order to notify threads in VirtualMachineEngine, a notify callback lambda should be take as // an argument for VirtualMachineEngine's constructor. 
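// Rough sketch of the notification flow wired up below, assuming the default
// multi-threaded mode (`multi_thread_` is true):
//
//   VirtualMachine::Receive(instruction_list)
//     -> engine_->Receive(...)            // append to the engine's pending list
//     -> pending_notifier_.Notify()       // wake the scheduler thread
//   VirtualMachine::ScheduleLoop()
//     -> engine_->Schedule(schedule_ctx)  // dispatch onto vm::Stream worker threads
//     -> schedule_ctx.OnWorkerLoadPending(thread_ctx)
//        -> thread_ctx->mut_notifier()->Notify()
//   WorkerLoop(thread_ctx)
//     -> thread_ctx->TryReceiveAndRun()   // StreamPolicy::RunIf per instruction
//
// In single-thread mode the same engine is driven inline through
// SingleThreadScheduleCtx; in a forked subprocess, CPU instructions are prepared
// and computed directly inside Receive().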
engine_ = intrusive::make_shared(); OF_PROFILER_NAME_THIS_HOST_THREAD("_Main"); if (multi_thread_) { std::function SchedulerInitializer; GetSchedulerThreadInitializer(&SchedulerInitializer); schedule_thread_ = std::thread(&VirtualMachine::ScheduleLoop, this, SchedulerInitializer); } transport_dependence_.Reset(); } namespace { Maybe> GetBarrierStream() { auto device = JUST(Device::New("cpu")); return Stream::New(device, StreamType::kBarrier); } void MakeBarrierInstructions(vm::InstructionList* list, const std::function& BarrierCallback) { auto* vm = Singleton::Get(); { auto stream = CHECK_JUST(GetBarrierStream()); auto instruction = intrusive::make_shared( CHECK_JUST(vm->GetVmStream(stream)), std::make_shared()); list->EmplaceBack(std::move(instruction)); } { auto stream = CHECK_JUST(GetBarrierStream()); auto instruction = intrusive::make_shared( CHECK_JUST(vm->GetVmStream(stream)), std::make_shared(BarrierCallback)); list->EmplaceBack(std::move(instruction)); } } } // namespace void VirtualMachine::ControlSync() { auto bc = std::make_shared(1); vm::InstructionList list; MakeBarrierInstructions(&list, [bc] { bc->Decrease(); }); CHECK_JUST(Receive(&list)); CHECK_JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); } Maybe VirtualMachine::CloseVMThreads() { CHECK_OR_RETURN(!threads_closed_) << "vm threads closed"; ControlSync(); pending_notifier_.Close(); if (multi_thread_) { schedule_thread_.join(); } else { // For technical reasons, worker threads are always created even in single thread mode JUST(CloseWorkerThreads()); } threads_closed_ = true; return Maybe::Ok(); } namespace { class SingleThreadScheduleCtx : public vm::ScheduleCtx { public: SingleThreadScheduleCtx() = default; ~SingleThreadScheduleCtx() = default; void OnWorkerLoadPending(vm::ThreadCtx* thread_ctx) const override { while (thread_ctx->TryReceiveAndRun() > 0) {} } }; void ScheduleUntilVMEmpty(vm::VirtualMachineEngine* vm, const vm::ScheduleCtx& schedule_ctx) { do { vm->Schedule(schedule_ctx); } while (!(vm->SchedulerEmpty())); } } // namespace Maybe VirtualMachine::BlockingRunProbeFunc( const std::function& prob_func) { JUST(Singleton::Get()->WithScopedRelease([&, this]() -> Maybe { auto bc = std::make_shared(1); engine_->InsertProbe([bc, prob_func](vm::VirtualMachineEngine* engine) { if (!prob_func(engine)) { return false; } bc->Decrease(); return true; }); if (threads_closed_ || !multi_thread_) { ScheduleUntilVMEmpty(engine_.Mutable(), SingleThreadScheduleCtx()); } else { pending_notifier_.Notify(); } JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return Maybe::Ok(); })); return Maybe::Ok(); } Maybe VirtualMachine::ShrinkAllMem() { auto try_shrink_men = [](vm::VirtualMachineEngine* engine) -> bool { if (engine->mut_active_stream_list()->size()) { return false; } INTRUSIVE_FOR_EACH_PTR(thread_ctx, engine->mut_thread_ctx_list()) { INTRUSIVE_FOR_EACH_PTR(stream, thread_ctx->mut_stream_list()) { vm::Allocator* allocator = stream->mut_stream_policy()->mut_allocator(); if (allocator) { auto* cache = dynamic_cast(allocator); if (cache != nullptr) { cache->Shrink(); } } } } return true; }; return BlockingRunProbeFunc(try_shrink_men); } VirtualMachine::~VirtualMachine() { if (!threads_closed_) { CHECK_JUST(CloseVMThreads()); } RunMainThreadPendingTasks(); CHECK(engine_->SchedulerEmpty()); engine_.Reset(); } std::function()> VirtualMachine::GetPredicatorNoMoreInstructionsFinished() { auto last_total_erased = std::make_shared(0); auto* vm = 
Singleton::Get(); if (vm != nullptr) { *last_total_erased = vm->engine_->total_erased_instruction_cnt(); } return [last_total_erased]() -> Maybe { auto* vm = Singleton::Get(); CHECK_NOTNULL_OR_RETURN(vm) << "virtual machine not initialized."; CHECK_OR_RETURN(!vm->NoMoreErasedInstructions(last_total_erased.get())) << "blocking instructions\n" << vm->GetBlockingDebugString(); return false; }; } bool VirtualMachine::NoMoreErasedInstructions(size_t* last_total_erased_instruction_cnt) const { size_t cnt = engine_->total_erased_instruction_cnt(); bool no_more_erased = (*last_total_erased_instruction_cnt == cnt); *last_total_erased_instruction_cnt = cnt; return no_more_erased; } std::string VirtualMachine::GetBlockingDebugString() { size_t limit = EnvInteger(); return engine_->GetLivelyInstructionListDebugString(limit); } void VirtualMachine::RunMainThreadPendingTasks() { std::unique_lock lock(main_thread_pending_tasks_mutex_); for (const auto& main_thread_pending_task : main_thread_pending_tasks_) { main_thread_pending_task(); } main_thread_pending_tasks_.clear(); } Maybe VirtualMachine::Receive(vm::InstructionList* instruction_list) { SyncVmModeGuard guard(SyncVmMode::kEnable); RunMainThreadPendingTasks(); if (unlikely(pthread_fork::IsForkedSubProcess())) { INTRUSIVE_FOR_EACH_PTR(instruction, instruction_list) { const auto& device = instruction->stream().device(); CHECK_OR_RETURN(device->enum_type() == DeviceType::kCPU) << pthread_fork::kOfCudaNotSupportInForkedSubProcess; JUST(instruction->Prepare()); instruction->Compute(); } instruction_list->Clear(); } else if (unlikely(threads_closed_ || !multi_thread_)) { JUST(RunInCurrentThread(instruction_list)); } else { const int64_t kHighWaterMark = GetInstructionHighWaterMark(); if (engine_->flying_instruction_cnt() > kHighWaterMark) { JUST(Singleton::Get()->WithScopedRelease([&, this]() -> Maybe { auto bc = std::make_shared(1); engine_->InsertProbe([bc](vm::VirtualMachineEngine* engine) { const int64_t kLowWaterMark = GetInstructionLowWaterMark(); if (engine->flying_instruction_cnt() > kLowWaterMark) { return false; } bc->Decrease(); return true; }); pending_notifier_.Notify(); JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return Maybe::Ok(); })); } if (JUST(engine_->Receive(instruction_list))) { // old scheduler_pending_instruction_list is empty. pending_notifier_.Notify(); } } return Maybe::Ok(); } Maybe VirtualMachine::NotifyOrRunScheduler() { if (unlikely(pthread_fork::IsForkedSubProcess() || threads_closed_ || !multi_thread_)) { ScheduleUntilVMEmpty(engine_.Mutable(), SingleThreadScheduleCtx()); } else { pending_notifier_.Notify(); } return Maybe::Ok(); } Maybe VirtualMachine::CloseWorkerThreads() { JUST(ForEachThreadCtx(engine_.Mutable(), [&](vm::ThreadCtx* thread_ctx) -> Maybe { thread_ctx->mut_notifier()->Close(); return Maybe::Ok(); })); { std::unique_lock lock(worker_threads_mutex_); for (const auto& worker_thread : worker_threads_) { worker_thread->join(); } } return Maybe::Ok(); } Maybe VirtualMachine::RunInCurrentThread(vm::InstructionList* instr_list) { CHECK_OR_RETURN(engine_->SchedulerEmpty()) << "vm scheduler not empty. 
May be a fatal error occured"; JUST(engine_->Receive(instr_list)); ScheduleUntilVMEmpty(engine_.Mutable(), SingleThreadScheduleCtx()); return Maybe::Ok(); } namespace { class MultiThreadScheduleCtx : public vm::ScheduleCtx { public: MultiThreadScheduleCtx() = default; ~MultiThreadScheduleCtx() = default; void OnWorkerLoadPending(vm::ThreadCtx* thread_ctx) const override { thread_ctx->mut_notifier()->Notify(); } }; } // namespace void VirtualMachine::ScheduleLoop(const std::function& Initializer) { SyncVmModeGuard guard(SyncVmMode::kEnable); Initializer(); MultiThreadScheduleCtx schedule_ctx{}; while (pending_notifier_.WaitAndClearNotifiedCnt() == kNotifierStatusSuccess) { OF_PROFILER_RANGE_GUARD("VirtualMachine::ScheduleLoop"); auto start = std::chrono::steady_clock::now(); static constexpr int kWorkingMicroseconds = 1000; // Every time this thread wakes up, engine_ is scheduled for about `kWorkingMicroseconds`. // The cost of os thread switching is about 5-10 microseconds. Doing more scheduling in // a single waiting up can reach higher performance. do { // Use SchedulerThreadUnsafeEmpty to avoid acquiring mutex lock. // It's safe to use SchedulerThreadUnsafeEmpty here. pending_notifier_.notified_cnt_ will be // greater than zero when inconsistency between // engine_->pending_instruction_list.list_head_.list_head_.container_ and // engine_->pending_instruction_list.list_head_.list_head_.size_ occured. hence the pending // instructions // will get handled in the next iteration. // VirtualMachine::Receive may be less effiencient if the thread safe version // `engine_->SchedulerEmpty()` // used // here, because VirtualMachine::ScheduleLoop is more likely to get the mutex lock. do { const size_t total_inserted = engine_->total_inserted_instruction_cnt(); const size_t total_erased = engine_->total_erased_instruction_cnt(); engine_->Schedule(schedule_ctx); if (ThreadLocalEnvBool() && total_inserted == engine_->total_inserted_instruction_cnt() && total_erased == engine_->total_erased_instruction_cnt()) { // nothing handled. 
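// Both counters are monotonic, so equal values before and after Schedule() mean this
// pass made no progress; yielding here (guarded by the env toggle above) keeps the
// busy loop from pinning a core while instructions are still in flight elsewhere.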
std::this_thread::yield(); } } while (!engine_->SchedulerThreadUnsafeEmpty()); } while (MicrosecondsFrom(start) < kWorkingMicroseconds); } ScheduleUntilVMEmpty(engine_.Mutable(), schedule_ctx); CHECK_JUST(CloseWorkerThreads()); scheduler_stopped_ = true; } intrusive::shared_ptr VirtualMachine::FindOrCreateScheduleDependence( Symbol stream) { std::unique_lock lock(stream_and_thread_ctx_mutex_); intrusive::shared_ptr* ptr = &stream2dependence_[stream]; if (!*ptr) { *ptr = intrusive::make_shared(); } return *ptr; } intrusive::shared_ptr VirtualMachine::FindOrCreateTransportLocalDepObject() { std::unique_lock lock(stream_and_thread_ctx_mutex_); if (!transport_dependence_) { transport_dependence_ = intrusive::make_shared(); } return transport_dependence_; } Maybe VirtualMachine::CreateStream(Symbol stream) { std::unique_lock lock(stream_and_thread_ctx_mutex_); vm::ThreadCtx* thread_ctx = JUST(FindOrCreateThreadCtx(stream->device(), stream->stream_type(), stream->thread_uid())); return JUST(CreateStream(thread_ctx, stream)); } Maybe VirtualMachine::GetVmStream(Symbol stream) { if (stream->unique_stream_id() >= unique_stream_id2vm_stream_.size()) { std::unique_lock lock(stream_and_thread_ctx_mutex_); if (stream->unique_stream_id() >= unique_stream_id2vm_stream_.size()) { auto* stream_mgr = JUST(SingletonMaybe()); for (int i = unique_stream_id2vm_stream_.size(); i <= stream->unique_stream_id(); ++i) { Symbol cur_stream = JUST(stream_mgr->GetStreamSymbol(i)); CHECK_EQ_OR_RETURN(cur_stream->unique_stream_id(), i) << "invalid Stream::unique_stream_id()"; unique_stream_id2vm_stream_.SetOrAdd(cur_stream->unique_stream_id(), JUST(CreateStream(cur_stream))); } } } return JUST(VectorAt(unique_stream_id2vm_stream_, stream->unique_stream_id())); } Maybe VirtualMachine::FindOrCreateThreadCtx(Symbol device, StreamType stream_type, size_t thread_uid) { std::unique_lock lock(stream_and_thread_ctx_mutex_); vm::ThreadCtx** thread_ctx_ptr = nullptr; if (StreamOnIndependentThread::Visit(stream_type)) { auto key = std::make_pair(device->enum_type(), stream_type); thread_ctx_ptr = &devcie_type_stream_type_2independent_thread_ctx_[key]; } else { thread_ctx_ptr = &thread_uid2shared_thread_ctx_[thread_uid]; } if (*thread_ctx_ptr == nullptr) { *thread_ctx_ptr = JUST(CreateThreadCtx(device, stream_type, thread_uid)); } return *thread_ctx_ptr; } Maybe VirtualMachine::CreateThreadCtx(Symbol device, StreamType stream_type, size_t thread_uid) { std::unique_lock lock(stream_and_thread_ctx_mutex_); // thread_ctx_ptr may be used after timout. 
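// The new ThreadCtx must be linked into engine_->mut_thread_ctx_list(), which is only
// touched on the scheduler thread, so construction is funneled through an engine probe
// and the result handed back via a shared_ptr: if the wait on the BlockingCounter below
// ever gave up, the probe could still run later and must not write through a dangling
// stack pointer.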
auto thread_ctx_ptr = std::make_shared(nullptr); { auto bc = std::make_shared(1); engine_->InsertProbe([thread_ctx_ptr, bc](vm::VirtualMachineEngine* engine) { auto thread_ctx = intrusive::make_shared(); engine->mut_thread_ctx_list()->PushBack(thread_ctx.Mutable()); *thread_ctx_ptr = thread_ctx.Mutable(); bc->Decrease(); return true; }); JUST(NotifyOrRunScheduler()); JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); } auto* thread_ctx = *thread_ctx_ptr; { const std::string thread_tag = [&] { std::string device_tag = *CHECK_JUST(DeviceTag4DeviceType(device->enum_type())); if (StreamOnIndependentThread::Visit(stream_type)) { return device_tag + GetStreamTypeName::Visit(stream_type); } else { return std::to_string(thread_uid); } }(); const auto& WorkerInitializer = [thread_tag](vm::ThreadCtx* thread_ctx) { OF_PROFILER_NAME_THIS_HOST_THREAD("_VM::Worker_" + thread_tag); }; auto thread = std::make_unique(&WorkerLoop, thread_ctx, WorkerInitializer); { std::unique_lock lock(worker_threads_mutex_); worker_threads_.push_back(std::move(thread)); } } return thread_ctx; } Maybe VirtualMachine::CreateStream(vm::ThreadCtx* thread_ctx, Symbol stream) { std::unique_lock lock(stream_and_thread_ctx_mutex_); intrusive::shared_ptr schedule_dependence = FindOrCreateScheduleDependence(stream); std::vector> transport_dependences{}; if (IsCommNetStream::Visit(stream->stream_type())) { transport_dependences.push_back(FindOrCreateTransportLocalDepObject()); } auto vm_stream = intrusive::make_shared(thread_ctx, stream->device(), stream->stream_type(), schedule_dependence, transport_dependences); auto bc = std::make_shared(1); engine_->InsertProbe([&vm_stream, thread_ctx, bc](vm::VirtualMachineEngine* engine) { thread_ctx->mut_stream_list()->PushBack(vm_stream.Mutable()); bc->Decrease(); return true; }); JUST(NotifyOrRunScheduler()); JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return vm_stream.Mutable(); } } // namespace oneflow ================================================ FILE: oneflow/core/vm/virtual_machine.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_VIRTUAL_MACHINE_H_ #define ONEFLOW_CORE_VM_VIRTUAL_MACHINE_H_ #include #include "oneflow/core/common/notifier.h" #include "oneflow/core/vm/virtual_machine_engine.h" #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/common/stream_type.h" #include "oneflow/core/common/steady_vector.h" namespace oneflow { class InstructionsBuilder; class Device; class VirtualMachine final { public: VirtualMachine(const VirtualMachine&) = delete; VirtualMachine(VirtualMachine&&) = delete; VirtualMachine(); ~VirtualMachine(); static std::function()> GetPredicatorNoMoreInstructionsFinished(); intrusive::shared_ptr FindOrCreateTransportLocalDepObject(); std::string GetBlockingDebugString(); Maybe Receive(vm::InstructionList* instr_list); Maybe CloseVMThreads(); // Never called in vm work threads. // VM sync must be called to ensure all working instructions are finished. Maybe ShrinkAllMem(); Maybe GetVmStream(Symbol stream); size_t flying_instruction_cnt() const { return engine().flying_instruction_cnt(); } void add_main_thread_pending_task(std::function task) { std::unique_lock lock(main_thread_pending_tasks_mutex_); main_thread_pending_tasks_.push_back(std::move(task)); } private: friend class InstructionsBuilder; void ScheduleLoop(const std::function& Initializer); intrusive::shared_ptr FindOrCreateScheduleDependence(Symbol stream); bool NoMoreErasedInstructions(size_t* last_total_erased_instruction_cnt) const; const vm::VirtualMachineEngine& engine() const { return *engine_; } vm::VirtualMachineEngine* mut_engine() { return engine_.Mutable(); } void ControlSync(); Maybe FindOrCreateThreadCtx(Symbol device, StreamType stream_type, size_t thread_uid); Maybe CreateThreadCtx(Symbol device, StreamType stream_type, size_t thread_uid); Maybe CreateStream(Symbol stream); Maybe CreateStream(vm::ThreadCtx* thread_ctx, Symbol stream); Maybe RunInCurrentThread(vm::InstructionList* instr_list); Maybe BlockingRunProbeFunc(const std::function& prob_func); Maybe NotifyOrRunScheduler(); Maybe CloseWorkerThreads(); void RunMainThreadPendingTasks(); bool multi_thread_; bool threads_closed_; bool scheduler_stopped_; intrusive::shared_ptr engine_; // for asynchronized execution std::mutex worker_threads_mutex_; std::list> worker_threads_; // for vm::Stream and vm::ThreadCtx std::recursive_mutex stream_and_thread_ctx_mutex_; HashMap thread_uid2shared_thread_ctx_; HashMap, vm::ThreadCtx*> devcie_type_stream_type_2independent_thread_ctx_; HashMap, intrusive::shared_ptr> stream2dependence_; intrusive::shared_ptr transport_dependence_; SteadyVector unique_stream_id2vm_stream_; std::thread schedule_thread_; Notifier pending_notifier_; std::mutex main_thread_pending_tasks_mutex_; std::vector> main_thread_pending_tasks_; }; } // namespace oneflow #endif // ONEFLOW_CORE_VM_VIRTUAL_MACHINE_H_ ================================================ FILE: oneflow/core/vm/virtual_machine_engine.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/virtual_machine_engine.h" #include "oneflow/core/common/env_var/vm.h" #include "oneflow/core/vm/caching_allocator.h" #include "oneflow/core/vm/fuse_instruction_policy.h" #include "oneflow/core/vm/release_tensor_instruction_policy.h" #include "oneflow/core/vm/allocator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/platform/include/pthread_fork.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/foreign_lock_helper.h" #include "oneflow/extension/stack/foreign_stack_getter.h" namespace oneflow { namespace vm { void VirtualMachineEngine::ReleaseInstruction(Instruction* instruction) { OF_PROFILER_RANGE_GUARD("R:" + instruction->DebugName()); auto* access_list = instruction->mut_access_list(); INTRUSIVE_FOR_EACH(access, access_list) { CHECK_GT(access->ref_cnt(), 1); access_list->Erase(access.Mutable()); auto* dependence = access->mut_dependence(); if (unlikely(!access->rw_mutexed_object_access_hook().empty())) { dependence->mut_access_list()->Erase(access.Mutable()); } } auto* out_edges = instruction->mut_out_edges(); INTRUSIVE_FOR_EACH_PTR(out_edge, out_edges) { Instruction* out_instruction = out_edge->mut_dst_instruction(); // Edges are erased only if the instruction is completed. out_edges->Erase(out_edge); out_instruction->mut_in_edges()->Erase(out_edge); if (Dispatchable(out_instruction)) { OF_PROFILER_RANGE_GUARD("E:" + out_instruction->DebugName()); mut_ready_instruction_list()->PushBack(out_instruction); } } } // Handle pending instructions, and try schedule them to ready list. 
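// Pipeline, roughly: instructions received from frontend threads land in the mutexed
// pending_instruction_list_; Schedule() moves them in bulk into the scheduler-private
// local_pending_instruction_list_; HandleLocalPending() then fuses compatible neighbors
// (FetchAndTryFusePendingInstructions), pushes every instruction onto the lively list,
// routes barriers to barrier_instruction_list_, and puts instructions whose dependences
// are already satisfied onto ready_instruction_list_.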
void VirtualMachineEngine::HandleLocalPending() { OF_PROFILER_RANGE_GUARD("HandleLocalPending"); InstructionList pending_instructions; FetchAndTryFusePendingInstructions(&pending_instructions); INTRUSIVE_FOR_EACH_PTR(instruction, &pending_instructions) { const auto& instruction_policy = instruction->instruction_policy(); instruction->InitStatus(); LivelyInstructionListPushBack(instruction); if (unlikely(instruction_policy.IsBarrier())) { mut_barrier_instruction_list()->PushBack(instruction); } else { ConsumeDependences(instruction); if (likely(Dispatchable(instruction))) { mut_ready_instruction_list()->PushBack(instruction); } } } } namespace { bool FusableBetween(InstructionFuseType fuse_type, Instruction* instruction, Instruction* prev_instruction) { if (unlikely(instruction->instruction_policy().fuse_type() != fuse_type)) { return false; } auto* stream = instruction->mut_stream(); if (unlikely(stream == nullptr)) { return false; } auto* sequential_dep = instruction->instruction_policy().stream_sequential_dependence(); if (unlikely(sequential_dep == nullptr)) { return false; } if (unlikely(prev_instruction == nullptr)) { return true; } if (unlikely(stream != prev_instruction->mut_stream())) { return false; } if (unlikely(sequential_dep != prev_instruction->instruction_policy().stream_sequential_dependence())) { return false; } return true; } } // namespace void VirtualMachineEngine::MakeAndAppendFusedInstruction( InstructionList&& fused_instruction_list, InstructionList* /*out*/ pending_instructions) { if (unlikely(fused_instruction_list.size() == 0)) { return; } if (unlikely(fused_instruction_list.size() == 1)) { fused_instruction_list.MoveTo(pending_instructions); return; } auto* begin = fused_instruction_list.Begin(); auto instruction = intrusive::make_shared( begin->mut_stream(), std::make_shared(std::move(fused_instruction_list))); pending_instructions->EmplaceBack(std::move(instruction)); } void VirtualMachineEngine::FetchAndTryFusePendingInstructions( InstructionList* /*out*/ pending_instructions) { size_t window_size = ThreadLocalEnvInteger(); InstructionList fused_instruction_list; INTRUSIVE_FOR_EACH_PTR(instruction, mut_local_pending_instruction_list()) { if (window_size-- <= 0) { break; } auto* fuse_begin = fused_instruction_list.Begin(); if (likely(FusableBetween(kEnableInstructionFuseAtAnyPosition, instruction, fuse_begin))) { // fuse mut_local_pending_instruction_list()->MoveToDstBack(instruction, &fused_instruction_list); } else if (likely(FusableBetween(kEnableInstructionFuseAsTailOnly, instruction, fuse_begin))) { // fuse mut_local_pending_instruction_list()->MoveToDstBack(instruction, &fused_instruction_list); MakeAndAppendFusedInstruction(std::move(fused_instruction_list), pending_instructions); } else { // no fuse MakeAndAppendFusedInstruction(std::move(fused_instruction_list), pending_instructions); mut_local_pending_instruction_list()->MoveToDstBack(instruction, pending_instructions); } } MakeAndAppendFusedInstruction(std::move(fused_instruction_list), pending_instructions); } std::string VirtualMachineEngine::GetLivelyInstructionListDebugString(int64_t debug_cnt) { std::stringstream ss; INTRUSIVE_UNSAFE_FOR_EACH_PTR(instruction, mut_lively_instruction_list()) { if (--debug_cnt <= 0) { break; } ss << instruction->DebugName() << " ptr: " << instruction << " dispatched:" << (instruction->dispatched_instruction_hook().empty() ? "0" : "1") << " launched:" << (instruction->Launched() ? "1" : "0") << " done:" << (instruction->Done() ? 
"1" : "0"); INTRUSIVE_UNSAFE_FOR_EACH_PTR(edge, instruction->mut_in_edges()) { ss << " dep-ptr:" << &edge->src_instruction(); } ss << "\n"; } return ss.str(); } void VirtualMachineEngine::LivelyInstructionListPushBack(Instruction* instruction) { ++total_inserted_instruction_cnt_; mut_lively_instruction_list()->PushBack(instruction); } void VirtualMachineEngine::InsertProbe( const std::function& ProbeFunction) { probe_list_.EmplaceBack(intrusive::make_shared(ProbeFunction)); } void VirtualMachineEngine::HandleLocalProbe() { OF_PROFILER_RANGE_GUARD("HandleLocalProbe"); if (unlikely(local_probe_list_.size())) { OF_PROFILER_RANGE_PUSH("HandleLocalProbe"); INTRUSIVE_FOR_EACH_PTR(probe, &local_probe_list_) { if (probe->probe_function()(this)) { local_probe_list_.Erase(probe); } } OF_PROFILER_RANGE_POP(); } } intrusive::shared_ptr VirtualMachineEngine::LivelyInstructionListErase( Instruction* instruction) { ++total_erased_instruction_cnt_; return mut_lively_instruction_list()->Erase(instruction); } // Collect ready instructions onto ready_instruction_list_ void VirtualMachineEngine::ReleaseFinishedInstructions(const ScheduleCtx& schedule_ctx) { INTRUSIVE_FOR_EACH_PTR(stream, mut_active_stream_list()) { while (true) { auto* instruction_ptr = stream->mut_running_instruction_list()->Begin(); if (instruction_ptr == nullptr) { break; } if (!(instruction_ptr->in_edges().empty() && instruction_ptr->Done())) { break; } ReleaseInstruction(instruction_ptr); // Prevent destructing instruction_ptr. intrusive::shared_ptr instruction = stream->mut_running_instruction_list()->Erase(instruction_ptr); LivelyInstructionListErase(instruction_ptr); instruction_ptr->DeleteStatusAndCheckEdges(); } if (stream->running_instruction_list().empty()) { mut_active_stream_list()->Erase(stream); } } } DependenceAccess* VirtualMachineEngine::AccessDependence(OperandAccessType access_type, Dependence* dependence, Instruction* instruction) { auto access = access_pool_.make_shared(instruction, dependence, access_type); auto* ptr = access.Mutable(); instruction->mut_access_list()->PushBack(ptr); dependence->mut_access_list()->EmplaceBack(std::move(access)); return ptr; } void VirtualMachineEngine::TryConnectInstruction(Instruction* src_instruction, Instruction* dst_instruction) { if (unlikely(src_instruction == dst_instruction)) { return; } if (likely(EdgeDispatchable(src_instruction, dst_instruction))) { return; } auto edge = instruction_edge_pool_.make_shared(src_instruction, dst_instruction); src_instruction->mut_out_edges()->PushBack(edge.Mutable()); dst_instruction->mut_in_edges()->PushBack(edge.Mutable()); } void VirtualMachineEngine::ConnectInstructionsByWrite(DependenceAccess* dst_access) { CHECK(dst_access->is_mut_operand()); auto* dependence = dst_access->mut_dependence(); auto* dst_instruction = dst_access->mut_instruction(); auto* access_list = dependence->mut_access_list(); if (likely(access_list->Begin() == dst_access)) { return; } INTRUSIVE_FOR_EACH_PTR(src_access, access_list) { if (unlikely(src_access == dst_access)) { break; } TryConnectInstruction(src_access->mut_instruction(), dst_instruction); access_list->Erase(src_access); } } void VirtualMachineEngine::ConnectInstructionsByRead(DependenceAccess* dst_access) { CHECK(dst_access->is_const_operand()); auto* dependence = dst_access->mut_dependence(); auto* dst_instruction = dst_access->mut_instruction(); auto* first = dependence->mut_access_list()->Begin(); if (first->is_mut_operand()) { TryConnectInstruction(first->mut_instruction(), dst_instruction); } else if 
(first->is_const_operand()) { // do nothing } else { UNIMPLEMENTED(); } } void VirtualMachineEngine::ConsumeDependences(Instruction* instruction) { const auto& instruction_policy = instruction->instruction_policy(); auto* stream_sequential_dep = instruction_policy.stream_sequential_dependence(); if (likely(stream_sequential_dep != nullptr)) { ConnectInstructionsByWrite( AccessDependence(kMutableOperandAccess, stream_sequential_dep, instruction)); } // Connect instructions by write before connecting by read. for (auto* dependence : instruction_policy.output_dependences()) { ConnectInstructionsByWrite(AccessDependence(kMutableOperandAccess, dependence, instruction)); } for (auto* dependence : instruction_policy.input_dependences()) { ConnectInstructionsByRead(AccessDependence(kConstOperandAccess, dependence, instruction)); } } bool VirtualMachineEngine::EdgeDispatchable(const Instruction* src, const Instruction* dst) const { return dst->instruction_policy().Prescheduleable(&src->stream(), &dst->stream()) && !src->dispatched_instruction_hook().empty() /* dispatched */; } bool VirtualMachineEngine::Dispatchable(Instruction* instruction) const { if (unlikely(!instruction->dispatched_instruction_hook().empty())) { return false; } INTRUSIVE_UNSAFE_FOR_EACH_PTR(edge, instruction->mut_in_edges()) { const auto* src_instruction = &edge->src_instruction(); if (unlikely(!EdgeDispatchable(src_instruction, instruction))) { return false; } } return true; } // Dispatch ready instructions and put prescheduled instructions onto ready_instruction_list_. void VirtualMachineEngine::DispatchAndPrescheduleInstructions(const ScheduleCtx& schedule_ctx) { OF_PROFILER_RANGE_GUARD("DispatchAndPrescheduleInstructions"); ReadyInstructionList tmp_ready_instruction_list; mut_ready_instruction_list()->MoveTo(&tmp_ready_instruction_list); INTRUSIVE_FOR_EACH(instruction, &tmp_ready_instruction_list) { // Erases `instruction` from tmp_ready_instruction_list before dispatching, because // `instruction.dispatched_instruction_hook_` are used in DispatchInstruction. 
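// After an instruction is dispatched, its out-edges are scanned and any consumer that
// has become Dispatchable() (every remaining in-edge satisfies EdgeDispatchable, i.e.
// the producer is already dispatched and the streams allow prescheduling) is pushed
// onto ready_instruction_list_ without waiting for the producer to finish.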
tmp_ready_instruction_list.Erase(instruction.Mutable()); OF_PROFILER_RANGE_GUARD("D:" + instruction->DebugName()); DispatchInstruction(instruction.Mutable(), schedule_ctx); // preschedule instructions INTRUSIVE_UNSAFE_FOR_EACH_PTR(edge, instruction->mut_out_edges()) { auto* out_instruction = edge->mut_dst_instruction(); if (Dispatchable(out_instruction)) { OF_PROFILER_RANGE_GUARD("P:" + out_instruction->DebugName()); mut_ready_instruction_list()->PushBack(out_instruction); } } } } namespace { std::string DebugDeviceReset(vm::Stream* stream) { stream->mut_stream_policy()->mut_allocator()->DeviceReset(); return "reset device"; } } // namespace void VirtualMachineEngine::DispatchInstruction(Instruction* instruction, const ScheduleCtx& schedule_ctx) { ForeignFrameThreadLocalGuard guard(instruction->foreign_frame()); auto* stream = instruction->mut_stream(); // Prepare { const auto& ret = TRY(instruction->Prepare()); if (unlikely(!ret.IsOk())) { if (ret.error()->has_out_of_memory_error()) { CHECK_JUST_MSG(ret, std::stringstream() << DebugDeviceReset(stream)); } else { CHECK_JUST(ret); } } } stream->mut_running_instruction_list()->PushBack(instruction); if (stream->active_stream_hook().empty()) { mut_active_stream_list()->PushBack(stream); } // Compute if (OnSchedulerThread(*stream)) { stream->stream_policy().RunIf(instruction); } else { stream->mut_thread_ctx()->mut_worker_pending_instruction_list()->PushBack(instruction); schedule_ctx.OnWorkerLoadPending(stream->mut_thread_ctx()); } } // Returns true if old scheduler_pending_instruction_list is empty Maybe VirtualMachineEngine::Receive(InstructionList* compute_instruction_list) { OF_PROFILER_RANGE_GUARD("vm:Receive"); #ifdef OF_ENABLE_PROFILER INTRUSIVE_UNSAFE_FOR_EACH_PTR(compute_instruction, compute_instruction_list) { OF_PROFILER_RANGE_GUARD(compute_instruction->DebugName()); } #endif bool old_list_empty = mut_pending_instruction_list()->MoveFrom(compute_instruction_list); return old_list_empty; } bool VirtualMachineEngine::OnSchedulerThread(const Stream& stream) { return stream.on_scheduler_thread() || pthread_fork::IsForkedSubProcess(); } // Barrier instructions are run after all previous lively instructions. // // `instruction.lively_instruction_hook_` is linked to `vm.lively_instruction_list_` for all // instructions. `instruction.barrier_instruction_list_` is linked to `vm.barrier_instruction_list_` // only for barrier instructions. // // // e.g. case0: waiting other instructions done. // // +---------------------------+ +---------------------------+ +---------------------------+ // | virtual_machine | | instruction0 | | instruction1 | // +---------------------------+ +---------------------------+ +---------------------------+ // | ... | | ... | | ... | // |---------------------------| |---------------------------| |---------------------------| // | lively_instruction_list_ |<->| lively_instruction_hook_ |<->| lively_instruction_hook_ | // |---------------------------| |---------------------------| |---------------------------| // | ... | | ... | | ... | // |---------------------------| |---------------------------| |---------------------------| // | barrier_instruction_list_ |<+ | barrier_instruction_hook_ | +>| barrier_instruction_hook_ | // |---------------------------| | |---------------------------| | |---------------------------| // | ... | | | ... | | | ... 
| // +---------------------------+ | +---------------------------+ | +---------------------------+ // | | // +-------------------------------+ // // `instruction1` is a barrier instruction with barrier_instruction_hook_ linked, while // instruction0 is not. From the `virtual_machine`'s view, `barrier_instruction_list_.Begin() != // lively_instruction_list_.Begin()`, so it's not the time to run barrier instruction // `barrier_instruction_list_.Begin()`. // // // e.g. case1: run barrier instructions. // // +---------------------------+ +---------------------------+ +---------------------------+ // | virtual_machine | | instruction0 | | instruction1 | // +---------------------------+ +---------------------------+ +---------------------------+ // | ... | | ... | | ... | // |---------------------------| |---------------------------| |---------------------------| // | lively_instruction_list_ |<->| lively_instruction_hook_ |<->| lively_instruction_hook_ | // |---------------------------| |---------------------------| |---------------------------| // | ... | | ... | | ... | // |---------------------------| |---------------------------| |---------------------------| // | barrier_instruction_list_ |<->| barrier_instruction_hook_ | | barrier_instruction_hook_ | // |---------------------------| |---------------------------| |---------------------------| // | ... | | ... | | ... | // +---------------------------+ +---------------------------+ +---------------------------+ // // `instruction0` is a barrier instruction with barrier_instruction_hook_ linked. // From the `virtual_machine`'s view, `barrier_instruction_list_.Begin() == // lively_instruction_list_.Begin()`, so it's the time to run barrier instruction // `barrier_instruction_list_.Begin()`. // // // With the introduction of barrier_instruction_list_/barrier_instruction_hook_, the function // VirtualMachineEngine::Schedule can achive higher performance. For the most cases, barrier // instructions are scarcely received by vm, there is no need for vm to run // VirtualMachineEngine::TryRunBarrierInstruction every time VirtualMachineEngine::Schedule run. On // the other hand, `barrier_instruction_hook_.size() == 0` is more lightweight than // `lively_instruction_list_.Begin()?->instruction_policy().IsBarrier()` // void VirtualMachineEngine::TryRunBarrierInstruction(const ScheduleCtx& schedule_ctx) { auto* sequnential_instruction = mut_barrier_instruction_list()->Begin(); CHECK_NOTNULL(sequnential_instruction); if (likely(sequnential_instruction != mut_lively_instruction_list()->Begin())) { return; } // All instructions before `sequnential_instruction` are handled now, it's time to handle // `sequnential_instruction`. OF_PROFILER_RANGE_GUARD("TryRunBarrierInstruction"); const auto& instruction_policy = sequnential_instruction->instruction_policy(); CHECK(instruction_policy.IsBarrier()); CHECK(OnSchedulerThread(sequnential_instruction->stream())); const StreamPolicy& stream_policy = sequnential_instruction->stream().stream_policy(); stream_policy.RunIf(sequnential_instruction); mut_barrier_instruction_list()->Erase(sequnential_instruction); LivelyInstructionListErase(sequnential_instruction); } void VirtualMachineEngine::Schedule(const ScheduleCtx& schedule_ctx) { // Release finished instructions and try to schedule out instructions in DAG onto ready list. if (unlikely(mut_active_stream_list()->size())) { ReleaseFinishedInstructions(schedule_ctx); } // Try run the first barrier instruction. 
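// Cheap fast path: the barrier list is usually empty, so the common case costs a single
// intrusive-list size check per Schedule() pass; only when a barrier is actually queued
// does TryRunBarrierInstruction() compare it against the head of the lively list to
// decide whether every earlier instruction has already been erased.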
if (unlikely(mut_barrier_instruction_list()->size())) { TryRunBarrierInstruction(schedule_ctx); } // Handle pending instructions, and try schedule them to ready list. // Use thread_unsafe_size to avoid acquiring mutex lock. // The inconsistency between pending_instruction_list.list_head_.list_head_.container_ and // pending_instruction_list.list_head_.list_head_.size_ is not a fatal error because // VirtualMachineEngine::Schedule is always in a busy loop. All instructions will get handled // eventually. // VirtualMachineEngine::Receive may be less effiencient if the thread safe version // `pending_instruction_list().size()` used here, because VirtualMachineEngine::Schedule is more // likely to get the mutex lock. if (unlikely(local_pending_instruction_list().size())) { HandleLocalPending(); } else if (unlikely(pending_instruction_list().thread_unsafe_size())) { // MoveTo is under a lock. mut_pending_instruction_list()->MoveTo(mut_local_pending_instruction_list()); if (local_pending_instruction_list().size()) { HandleLocalPending(); } } // dispatch ready instructions and try to schedule out instructions in DAG onto ready list. if (unlikely(mut_ready_instruction_list()->size())) { DispatchAndPrescheduleInstructions(schedule_ctx); } // handle scheduler probes if (unlikely(local_probe_list_.size())) { HandleLocalProbe(); } else if (unlikely(probe_list_.thread_unsafe_size())) { probe_list_.MoveTo(&local_probe_list_); if (local_probe_list_.size()) { HandleLocalProbe(); } } } bool VirtualMachineEngine::SchedulerThreadUnsafeEmpty() const { return pending_instruction_list().thread_unsafe_size() == 0 && local_pending_instruction_list().empty() && lively_instruction_list_.empty() && active_stream_list().empty() && probe_list_.thread_unsafe_size() == 0 && local_probe_list_.empty(); } bool VirtualMachineEngine::SchedulerEmpty() const { // hook and size will be check in pending_instruction_list().empty(). return pending_instruction_list().empty() && probe_list_.empty() && SchedulerThreadUnsafeEmpty(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/virtual_machine_engine.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_VIRTUAL_MACHINE_ENGINE_H_ #define ONEFLOW_CORE_VM_VIRTUAL_MACHINE_ENGINE_H_ #include #include "oneflow/core/common/maybe.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/vm_object.h" #include "oneflow/core/common/range.h" #include "oneflow/core/intrusive/mutexed_list.h" #include "oneflow/core/intrusive/object_pool.h" #include "oneflow/core/vm/probe.h" namespace oneflow { namespace vm { class ThreadCtx; class ScheduleCtx { public: ScheduleCtx() = default; virtual ~ScheduleCtx() = default; virtual void OnWorkerLoadPending(vm::ThreadCtx* thread_ctx) const = 0; }; using ReadyInstructionList = intrusive::List; class VirtualMachineEngine final : public intrusive::Base { public: // types using ActiveStreamList = intrusive::List; using ThreadCtxList = intrusive::List; using InstructionList = intrusive::List; using LivelyInstructionList = intrusive::List; using BarrierInstructionList = intrusive::List; using InstructionMutexedList = intrusive::MutexedList; // Getters std::size_t flying_instruction_cnt() const { return pending_instruction_list().thread_unsafe_size() + local_pending_instruction_list().size() + (total_inserted_instruction_cnt() - total_erased_instruction_cnt()); } size_t total_inserted_instruction_cnt() const { return total_inserted_instruction_cnt_; } size_t total_erased_instruction_cnt() const { return total_erased_instruction_cnt_; } void InsertProbe(const std::function& ProbeFunction); const ActiveStreamList& active_stream_list() const { return active_stream_list_; } const ThreadCtxList& thread_ctx_list() const { return thread_ctx_list_; } const LivelyInstructionList& lively_instruction_list() const { return lively_instruction_list_; } const BarrierInstructionList& barrier_instruction_list() const { return barrier_instruction_list_; } const InstructionMutexedList& pending_instruction_list() const { return pending_instruction_list_; } const InstructionList& local_pending_instruction_list() const { return local_pending_instruction_list_; } // Setters ActiveStreamList* mut_active_stream_list() { return &active_stream_list_; } ThreadCtxList* mut_thread_ctx_list() { return &thread_ctx_list_; } LivelyInstructionList* mut_lively_instruction_list() { return &lively_instruction_list_; } BarrierInstructionList* mut_barrier_instruction_list() { return &barrier_instruction_list_; } InstructionMutexedList* mut_pending_instruction_list() { return &pending_instruction_list_; } InstructionList* mut_local_pending_instruction_list() { return &local_pending_instruction_list_; } // Returns true if old scheduler_pending_instruction_list is empty Maybe Receive(InstructionList* instr_list); void Schedule(const ScheduleCtx& schedule_ctx); bool SchedulerThreadUnsafeEmpty() const; bool SchedulerEmpty() const; std::string GetLivelyInstructionListDebugString(int64_t debug_cnt); void MoveToGarbageListAndNotifyGC(const ScheduleCtx& schedule_ctx); private: ReadyInstructionList* mut_ready_instruction_list() { return &ready_instruction_list_; } void ReleaseFinishedInstructions(const ScheduleCtx& schedule_ctx); void HandleLocalPending(); void FetchAndTryFusePendingInstructions(InstructionList* /*out*/ pending_instructions); void MakeAndAppendFusedInstruction(InstructionList&& fused_instruction_list, InstructionList* /*out*/ pending_instructions); void TryRunBarrierInstruction(const ScheduleCtx& schedule_ctx); void DispatchAndPrescheduleInstructions(const ScheduleCtx& schedule_ctx); bool 
OnSchedulerThread(const vm::Stream& stream); void ReleaseInstruction(Instruction* instruction); void TryConnectInstruction(Instruction* src_instruction, Instruction* dst_instruction); void ConnectInstructionsByWrite(DependenceAccess* dst_access); void ConnectInstructionsByRead(DependenceAccess* dst_access); DependenceAccess* AccessDependence(OperandAccessType access_type, Dependence* dependence, Instruction* instrution); void ConsumeDependences(Instruction* instruction); void DispatchInstruction(Instruction* instruction, const ScheduleCtx& schedule_ctx); bool EdgeDispatchable(const Instruction* src, const Instruction* dst) const; bool Dispatchable(Instruction* instruction) const; void TryDispatchReadyInstructions(); void LivelyInstructionListPushBack(Instruction* instruction); intrusive::shared_ptr LivelyInstructionListErase(Instruction* instruction); void HandleLocalProbe(); friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } VirtualMachineEngine() : intrusive_ref_(), active_stream_list_(), thread_ctx_list_(), pending_instruction_mutex_(), pending_instruction_list_(&pending_instruction_mutex_), local_pending_instruction_list_(), ready_instruction_list_(), lively_instruction_list_(), total_inserted_instruction_cnt_(0), total_erased_instruction_cnt_(0), probe_mutex_(), probe_list_(&probe_mutex_), local_probe_list_(), barrier_instruction_list_() {} intrusive::Ref intrusive_ref_; // lists or maps // Do not change the order of the following fields ActiveStreamList active_stream_list_; ThreadCtxList thread_ctx_list_; std::mutex pending_instruction_mutex_; InstructionMutexedList pending_instruction_list_; // local_pending_instruction_list_ should be consider as the cache of pending_instruction_list_. InstructionList local_pending_instruction_list_; ReadyInstructionList ready_instruction_list_; LivelyInstructionList lively_instruction_list_; size_t total_inserted_instruction_cnt_; size_t total_erased_instruction_cnt_; using VmProbe = Probe>; std::mutex probe_mutex_; intrusive::MutexedList probe_list_; intrusive::List local_probe_list_; BarrierInstructionList barrier_instruction_list_; DependenceAccess::object_pool_type access_pool_; InstructionEdge::object_pool_type instruction_edge_pool_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_VIRTUAL_MACHINE_ENGINE_H_ ================================================ FILE: oneflow/core/vm/virtual_machine_scope.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/vm/virtual_machine_scope.h" #include "oneflow/core/vm/virtual_machine_engine.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/control/global_process_ctx.h" namespace oneflow { namespace vm { VirtualMachineScope::VirtualMachineScope(const Resource& resource) { Singleton::New(); } VirtualMachineScope::~VirtualMachineScope() { Singleton::Delete(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/virtual_machine_scope.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/resource.pb.h" namespace oneflow { namespace vm { class VirtualMachineScope { public: VirtualMachineScope(const Resource& resource); ~VirtualMachineScope(); }; } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/vm_object.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/vm_object.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/common/util.h" namespace oneflow { namespace vm { void DependenceAccess::__Init__() { clear_instruction(); clear_dependence(); } void DependenceAccess::__Init__(Instruction* instruction, Dependence* dependence, OperandAccessType access_type) { __Init__(); set_instruction(instruction); set_dependence(dependence); set_access_type(access_type); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/vm_object.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_VM_OBJECT_H_ #define ONEFLOW_CORE_VM_VM_OBJECT_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/intrusive/object_pool.h" namespace oneflow { namespace vm { class Instruction; class Dependence; using DependenceVector = std::vector; enum OperandAccessType { kConstOperandAccess = 0, kMutableOperandAccess, }; class DependenceAccess final : public intrusive::Base, public intrusive::EnableObjectPool { public: void __Init__(); // Getters OperandAccessType access_type() const { return access_type_; } bool has_instruction() const { return instruction_ != nullptr; } bool has_dependence() const { return dependence_ != nullptr; } const Instruction& instruction() const { return *instruction_; } const Dependence& dependence() const { return *dependence_; } const intrusive::ListHook& rw_mutexed_object_access_hook() const { return rw_mutexed_object_access_hook_; } // Setters void set_access_type(OperandAccessType val) { access_type_ = val; } void set_instruction(Instruction* val) { instruction_ = val; } void set_dependence(Dependence* val) { dependence_ = val; } void clear_instruction() { instruction_ = nullptr; } void clear_dependence() { dependence_ = nullptr; } Instruction* mut_instruction() { return instruction_; } Dependence* mut_dependence() { return dependence_; } // methods void __Init__(Instruction* instruction, Dependence* dependence, OperandAccessType access_type); bool is_const_operand() const { return kConstOperandAccess == access_type(); } bool is_mut_operand() const { return kMutableOperandAccess == access_type(); } intrusive::Ref::RefCntType ref_cnt() const { return intrusive_ref_.ref_cnt(); } intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } // NOLINT private: friend class intrusive::Ref; DependenceAccess() : intrusive_ref_(), access_type_(), instruction_(), dependence_(), instruction_access_hook_(), rw_mutexed_object_access_hook_() {} intrusive::Ref intrusive_ref_; // fields OperandAccessType access_type_; Instruction* instruction_; Dependence* dependence_; public: // list hooks intrusive::ListHook instruction_access_hook_; intrusive::ListHook rw_mutexed_object_access_hook_; }; // NOLINT class Dependence final : public intrusive::Base { public: // types using DependenceAccessList = intrusive::List; // Setters DependenceAccessList* mut_access_list() { return &access_list_; } // methods void __Init__() {} intrusive::Ref::RefCntType ref_cnt() const { return intrusive_ref_.ref_cnt(); } private: friend class intrusive::Ref; intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } Dependence() : intrusive_ref_(), access_list_() {} intrusive::Ref intrusive_ref_; // list hooks DependenceAccessList access_list_; }; } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_VM_OBJECT_H_ ================================================ FILE: oneflow/core/vm/vm_sync.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_CORE_VM_SYNC_H_ #define ONEFLOW_CORE_VM_SYNC_H_ #include "oneflow/core/common/maybe.h" namespace oneflow { namespace vm { Maybe ClusterSync(); Maybe CurrentRankSync(); } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_SYNC_H_ ================================================ FILE: oneflow/core/vm/vm_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/cluster_instruction.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace vm { Maybe Run(vm::InstructionList* instruction_list) { auto* virtual_machine = JUST(SingletonMaybe()); JUST(virtual_machine->Receive(instruction_list)); return Maybe::Ok(); } Maybe ClusterSync() { auto bc = std::make_shared(1); JUST(PhysicalRun([bc](InstructionsBuilder* builder) -> Maybe { JUST(builder->GlobalSync()); JUST(builder->Barrier([bc]() { bc->Decrease(); })); return Maybe::Ok(); })); JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return Maybe::Ok(); } Maybe CurrentRankSync() { auto bc = std::make_shared(1); JUST(PhysicalRun([bc](InstructionsBuilder* builder) -> Maybe { JUST(builder->Barrier([bc]() { bc->Decrease(); })); return Maybe::Ok(); })); JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); return Maybe::Ok(); } } // namespace vm } // namespace oneflow ================================================ FILE: oneflow/core/vm/vm_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
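Editor's note: both synchronization helpers in vm_util.cpp follow the same pattern: enqueue a barrier through PhysicalRun and block on a BlockingCounter until it runs. A minimal caller-side sketch, assuming the usual Maybe/JUST error handling (the wrapper name FlushVm is hypothetical):

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/vm/vm_util.h"

namespace oneflow {
// Wait until every instruction the VM has already received is finished.
// ClusterSync() waits across all ranks; CurrentRankSync() only waits for this rank.
Maybe<void> FlushVm() {
  JUST(vm::ClusterSync());
  return Maybe<void>::Ok();
}
}  // namespace oneflow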
*/ #ifndef ONEFLOW_CORE_VM_H_ #define ONEFLOW_CORE_VM_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/vm_sync.h" namespace oneflow { namespace vm { class Instruction; Maybe Run(vm::InstructionList* instruction_list); Maybe ClusterSync(); Maybe CurrentRankSync(); } // namespace vm } // namespace oneflow #endif // ONEFLOW_CORE_VM_H_ ================================================ FILE: oneflow/extension/python/numpy.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/stride.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/common/registry_error.h" #include "oneflow/extension/python/numpy_internal.h" namespace py = pybind11; namespace oneflow { namespace numpy { NumPyArrayInternal::NumPyArrayInternal(PyObject* obj, const std::function& deleter) : obj_((PyArrayObject*)obj), deleter_(deleter) { CHECK_OR_THROW(PyArray_Check(obj)) << "The object is not a numpy array."; CHECK_OR_THROW(PyArray_ISCONTIGUOUS(obj_)) << "Contiguous array is expected."; size_ = PyArray_SIZE(obj_); data_ = PyArray_DATA(obj_); } NumPyArrayInternal::~NumPyArrayInternal() { if (deleter_) { deleter_(); } } Maybe OFDataTypeToNumpyType(DataType of_data_type) { switch (of_data_type) { case DataType::kBool: return NPY_BOOL; case DataType::kFloat: return NPY_FLOAT32; case DataType::kDouble: return NPY_FLOAT64; case DataType::kInt8: return NPY_INT8; case DataType::kInt16: return NPY_INT16; case DataType::kChar: return NPY_INT8; case DataType::kInt32: return NPY_INT32; case DataType::kInt64: return NPY_INT64; case DataType::kUInt8: return NPY_UINT8; case DataType::kFloat16: return NPY_FLOAT16; case DataType::kComplex64: return NPY_COMPLEX64; case DataType::kComplex128: return NPY_COMPLEX128; default: return Error::InvalidValueError() << "OneFlow data type " << DataType_Name(of_data_type) << " is not valid to Numpy data type."; } } Maybe NumpyTypeToOFDataType(int np_type) { switch (np_type) { case NPY_BOOL: return DataType::kBool; case NPY_FLOAT32: return DataType::kFloat; case NPY_FLOAT64: return DataType::kDouble; case NPY_INT8: return DataType::kInt8; case NPY_INT16: return DataType::kInt16; case NPY_INT32: return DataType::kInt32; case NPY_INT64: case NPY_LONGLONG: return DataType::kInt64; case NPY_UINT8: return DataType::kUInt8; case NPY_FLOAT16: return DataType::kFloat16; case NPY_COMPLEX64: return DataType::kComplex64; case NPY_COMPLEX128: return DataType::kComplex128; default: return Error::InvalidValueError() << "Numpy data type " << std::to_string(np_type) << " is not valid to OneFlow data type."; } } Maybe GetOFDataTypeFromNpArray(PyArrayObject* array) { int np_array_type = PyArray_TYPE(array); return NumpyTypeToOFDataType(np_array_type); } std::vector OFShapeToNumpyShape(const DimVector& fixed_vec) { size_t ndim = fixed_vec.size(); auto result = std::vector(ndim); for (int i = 0; i < ndim; 
i++) { result[i] = fixed_vec.at(i); } return result; } // NumPy strides use bytes. OneFlow strides use element counts. std::vector OFStrideToNumpyStride(const Stride& stride, const DataType data_type) { size_t ndim = stride.size(); auto result = std::vector(ndim); int byte_per_elem = GetSizeOfDataType(data_type); for (int i = 0; i < ndim; i++) { result[i] = stride.at(i) * byte_per_elem; } return result; } bool PyArrayCheckLongScalar(PyObject* obj) { return PyArray_CheckScalar(obj) && PyDataType_ISINTEGER(PyArray_DescrFromScalar(obj)); } bool PyArrayCheckFloatScalar(PyObject* obj) { return PyArray_CheckScalar(obj) && PyDataType_ISFLOAT(PyArray_DescrFromScalar(obj)); } bool PyArrayCheckBoolScalar(PyObject* obj) { return PyArray_CheckScalar(obj) && PyDataType_ISBOOL(PyArray_DescrFromScalar(obj)); } bool PyArrayCheckComplexScalar(PyObject* obj) { return PyArray_CheckScalar(obj) && PyDataType_ISCOMPLEX(PyArray_DescrFromScalar(obj)); } // Executing any numpy c api before _import_array() results in segfault // NOTE: this InitNumpyCAPI() works because of `PY_ARRAY_UNIQUE_SYMBOL` // defined in numpy_internal.h // Reference: // https://numpy.org/doc/stable/reference/c-api/array.html#importing-the-api Maybe InitNumpyCAPI() { CHECK_ISNULL_OR_RETURN(PyArray_API); CHECK_EQ_OR_RETURN(_import_array(), 0) << ". Unable to import Numpy array, try to upgrade Numpy version!"; return Maybe::Ok(); } } // namespace numpy } // namespace oneflow ================================================ FILE: oneflow/extension/python/numpy.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_EXTENSION_PYTHON_NUMPY_H_ #define ONEFLOW_EXTENSION_PYTHON_NUMPY_H_ #define NO_IMPORT_ARRAY #include "oneflow/extension/python/numpy_internal.h" namespace oneflow { class NumPyArrayPtr final { public: NumPyArrayPtr(PyObject* obj) : internal_(std::make_shared(obj, []() -> void {})) {} NumPyArrayPtr(PyObject* obj, const std::function& deleter) : internal_(std::make_shared(obj, deleter)) {} void* data() const { return internal_->data(); } size_t size() const { return internal_->size(); } private: std::shared_ptr internal_; }; } // namespace oneflow #endif // ONEFLOW_EXTENSION_PYTHON_NUMPY_H_ ================================================ FILE: oneflow/extension/python/numpy_internal.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
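Editor's note: as the comment in OFStrideToNumpyStride says, NumPy strides are in bytes while OneFlow strides count elements, so the conversion is a per-dimension multiply by the element size. A small worked example, assuming a contiguous kFloat (4-byte) tensor of shape (2, 3):

// OneFlow element strides for a contiguous (2, 3) tensor: {3, 1}
// GetSizeOfDataType(DataType::kFloat) == 4 bytes per element
// OFStrideToNumpyStride(...) therefore returns byte strides {3 * 4, 1 * 4} == {12, 4}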
*/ // ************************ // // NOTE: Do NOT include this file (numpy_internal.h) directly. // Include numpy.h instead. // // ************************ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/common/shape_vec.h" // PyArrayObject cannot be forward declared, or a compile error will occur // https://numpy.org/doc/stable/reference/c-api/array.html?highlight=array%20api#importing-the-api #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #define PY_ARRAY_UNIQUE_SYMBOL oneflow_ARRAY_API #include namespace oneflow { class Stride; namespace numpy { class NumPyArrayInternal final { public: NumPyArrayInternal(PyObject* obj, const std::function& deleter); ~NumPyArrayInternal(); void* data() const { return data_; } size_t size() const { return size_; } private: PyArrayObject* obj_; void* data_; size_t size_; std::function deleter_; }; Maybe OFDataTypeToNumpyType(DataType of_data_type); Maybe NumpyTypeToOFDataType(int np_array_type); Maybe GetOFDataTypeFromNpArray(PyArrayObject* array); std::vector OFShapeToNumpyShape(const DimVector& fixed_vec); std::vector OFStrideToNumpyStride(const Stride& stride, const DataType data_type); bool PyArrayCheckLongScalar(PyObject* obj); bool PyArrayCheckFloatScalar(PyObject* obj); bool PyArrayCheckBoolScalar(PyObject* obj); bool PyArrayCheckComplexScalar(PyObject* obj); Maybe InitNumpyCAPI(); } // namespace numpy } // namespace oneflow ================================================ FILE: oneflow/extension/python/py_compute.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
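Editor's note: NumPyArrayPtr (declared in numpy.h above) is a thin shared wrapper over NumPyArrayInternal, and its optional deleter runs when the last copy goes away. A hedged sketch of tying the wrapper's lifetime to the underlying PyObject (caller must hold the GIL; names are illustrative):

#include <Python.h>
#include "oneflow/extension/python/numpy.h"

size_t ContiguousElemCount(PyObject* array_obj) {
  Py_INCREF(array_obj);  // keep the numpy array alive while the wrapper exists
  // NumPyArrayInternal checks that array_obj is a contiguous numpy array.
  oneflow::NumPyArrayPtr array(array_obj, [array_obj]() { Py_DECREF(array_obj); });
  return array.size();  // element count of the contiguous buffer
}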
*/ #include "oneflow/extension/python/py_compute.h" #define PY_SSIZE_T_CLEAN #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/user_op_tensor.h" #include "oneflow/core/framework/util.h" #include "oneflow/extension/python/numpy.h" namespace oneflow { namespace pyext { namespace { static PyObject* py_kernels_dic = nullptr; #define TENSOR_MEM_CAST(dtype) static_cast(const_cast(tensor->dptr())) void* TensorToMem(const user_op::Tensor* tensor) { switch (tensor->data_type()) { case DataType::kFloat: return TENSOR_MEM_CAST(float); case DataType::kDouble: return TENSOR_MEM_CAST(double); case DataType::kBool: return TENSOR_MEM_CAST(bool); case DataType::kInt8: return TENSOR_MEM_CAST(int8_t); case DataType::kInt32: return TENSOR_MEM_CAST(int32_t); case DataType::kInt64: return TENSOR_MEM_CAST(int64_t); case DataType::kUInt8: return TENSOR_MEM_CAST(uint8_t); case DataType::kFloat16: return TENSOR_MEM_CAST(float16); default: LOG(FATAL) << "OneFlow data type " << DataType_Name(tensor->data_type()) << " is not supported yet."; return nullptr; } } void TensorToNumpy(const user_op::Tensor* tensor, PyObject** arg_ptr) { if (tensor == nullptr) { Py_INCREF(Py_None); *arg_ptr = Py_None; return; } int type_num = CHECK_JUST(numpy::OFDataTypeToNumpyType(tensor->data_type())); VLOG(3) << "Tensor data type " << DataType_Name(tensor->data_type()) << " Numpy type " << type_num; int dim_size = tensor->shape_view().NumAxes(); npy_intp dims[dim_size]; FOR_RANGE(size_t, i, 0, dim_size) { dims[i] = tensor->shape_view().At(i); } void* data = TensorToMem(tensor); auto* np_array = reinterpret_cast(PyArray_SimpleNewFromData(dim_size, dims, type_num, data)); // Numpy will not release the data PyArray_CLEARFLAGS(np_array, NPY_ARRAY_OWNDATA); *arg_ptr = reinterpret_cast(np_array); } #define TENSOR_MEM_ASSIGN(dtype) \ do { \ dtype* array_data = static_cast(array_data_ptr); \ FOR_RANGE(int64_t, i, 0, size) { tensor->mut_dptr()[i] = array_data[i]; } \ } while (0) void MemToTensor(void* array_data_ptr, const size_t size, user_op::Tensor* tensor) { switch (tensor->data_type()) { case DataType::kFloat: TENSOR_MEM_ASSIGN(float); break; case DataType::kDouble: TENSOR_MEM_ASSIGN(double); break; case DataType::kBool: TENSOR_MEM_ASSIGN(bool); break; case DataType::kInt8: TENSOR_MEM_ASSIGN(int8_t); break; case DataType::kInt32: TENSOR_MEM_ASSIGN(int32_t); break; case DataType::kInt64: TENSOR_MEM_ASSIGN(int64_t); break; case DataType::kUInt8: TENSOR_MEM_ASSIGN(uint8_t); break; case DataType::kFloat16: TENSOR_MEM_ASSIGN(float16); break; default: LOG(FATAL) << "OneFlow data type " << DataType_Name(tensor->data_type()) << " is not supported yet."; } } void NumpyToTensor(PyObject* arg, user_op::Tensor* tensor) { PyObject* ro_array = PyArray_FromAny(arg, nullptr, 0, 0, NPY_ARRAY_CARRAY_RO, nullptr); // PyArray_FromAny has increased the reference count Py_DECREF(ro_array); PyArrayObject* array = reinterpret_cast(ro_array); DataType of_data_type = CHECK_JUST(numpy::GetOFDataTypeFromNpArray(array)); CHECK_EQ(of_data_type, tensor->data_type()) << "Numpy to OneFlow data type " << DataType_Name(of_data_type) << " is not equal to OneFlow tensor data type " << DataType_Name(tensor->data_type()); int64_t array_elem_cnt = 1; FOR_RANGE(int, i, 0, PyArray_NDIM(array)) { array_elem_cnt *= PyArray_SHAPE(array)[i]; } CHECK_EQ(array_elem_cnt, tensor->shape_view().elem_cnt()) << "Numpy array element count " << array_elem_cnt << " is not equal to OneFlow tensor element count " << 
tensor->shape_view().elem_cnt(); void* array_data_ptr = PyArray_DATA(array); MemToTensor(array_data_ptr, array_elem_cnt, tensor); } void MakePyInputs(const UserOpDef& op_def, user_op::KernelComputeContext* ctx, PyObject** py_inputs) { const size_t kernel_in_num = ctx->inputs().size(); const size_t def_in_num = op_def.input_size(); CHECK_EQ(kernel_in_num, def_in_num) << "kernel input num " << kernel_in_num << " not equal to definition input num " << def_in_num; PyObject* py_list = PyList_New(def_in_num); CHECK(py_list); FOR_RANGE(size_t, i, 0, def_in_num) { PyObject* arg = nullptr; const std::string& arg_name = op_def.input(i).name(); VLOG(3) << "input arg_name " << arg_name; // do not support multi input in one symbolic arg name int32_t index = 0; TensorToNumpy(ctx->Tensor4ArgNameAndIndex(arg_name, index), &arg); arg = PyArray_Return(reinterpret_cast(arg)); PyList_SetItem(py_list, i, arg); } *py_inputs = Py_BuildValue("(N)", py_list); CHECK(*py_inputs); } void GetPyOutputs(const UserOpDef& op_def, user_op::KernelComputeContext* ctx, PyObject* py_outputs) { const size_t kernel_out_num = ctx->outputs().size(); const size_t def_out_num = op_def.output_size(); CHECK_EQ(kernel_out_num, def_out_num) << "kernel output num " << kernel_out_num << " not equal to definition output num " << def_out_num; if (PyList_Check(py_outputs)) { FOR_RANGE(size_t, i, 0, def_out_num) { const std::string& arg_name = op_def.output(i).name(); VLOG(3) << "output arg_name " << arg_name; int32_t index = 0; NumpyToTensor(PyList_GetItem(py_outputs, i), ctx->Tensor4ArgNameAndIndex(arg_name, index)); } } else if (PyArray_Check(py_outputs)) { const std::string& arg_name = ctx->outputs().at(0).first; VLOG(3) << "output arg_name " << arg_name; int32_t index = 0; NumpyToTensor(py_outputs, ctx->Tensor4ArgNameAndIndex(arg_name, index)); } else { LOG(FATAL) << "Unexpeted PyObject was returned: " << Py_TYPE(py_outputs)->tp_name; } } } // namespace void PyRegisterKernels(PyObject* py_kernels) { if (py_kernels_dic == nullptr) { py_kernels_dic = py_kernels; Py_INCREF(py_kernels_dic); } else { LOG(FATAL) << "RegisterPyKernels should only be call once."; } } void PyCompute(user_op::KernelComputeContext* ctx, const std::string& py_func_name) { const std::string& op_type_name = ctx->op_type_name(); const user_op::OpRegistryResult* val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name); CHECK(val) << "Op op_type_name " << op_type_name << " has no definition."; const UserOpDef& op_def = val->op_def; // get GIL PyGILState_STATE py_gil_st; py_gil_st = PyGILState_Ensure(); // prepare for numpy c api if (PyArray_API == nullptr) { _import_array(); } PyObject *py_str, *py_module, *py_func; PyObject *py_inputs, *py_outputs; // get python kernel static const std::string forward_suffix = "_forward"; static const std::string backward_suffix = "_backward"; std::string op_module_name = op_type_name; if (op_type_name.size() > forward_suffix.size() && op_type_name.rfind(forward_suffix) == (op_type_name.size() - forward_suffix.size())) { op_module_name = op_type_name.substr(0, op_type_name.size() - forward_suffix.size()); } if (op_type_name.size() > backward_suffix.size() && op_type_name.rfind(backward_suffix) == (op_type_name.size() - backward_suffix.size())) { op_module_name = op_type_name.substr(0, op_type_name.size() - backward_suffix.size()); } py_str = PyUnicode_DecodeFSDefault(op_module_name.c_str()); CHECK(py_kernels_dic) << "py_kernels_dic should not be nullptr."; py_module = PyDict_GetItem(py_kernels_dic, py_str); 
CHECK(py_module) << op_module_name << " has no python kernel."; Py_DECREF(py_str); // get func py_func = PyObject_GetAttrString(py_module, py_func_name.c_str()); if (py_func == nullptr || !PyCallable_Check(py_func)) { Py_DECREF(py_module); PyErr_Print(); } // get numpy input MakePyInputs(op_def, ctx, &py_inputs); // call func py_outputs = PyEval_CallObject(py_func, py_inputs); Py_DECREF(py_inputs); // get numpy output GetPyOutputs(op_def, ctx, py_outputs); Py_XDECREF(py_func); Py_DECREF(py_outputs); // release GIL PyGILState_Release(py_gil_st); } } // namespace pyext } // namespace oneflow ================================================ FILE: oneflow/extension/python/py_compute.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_EXTENSION_PYTHON_PY_COMPUTE_H_ #define ONEFLOW_EXTENSION_PYTHON_PY_COMPUTE_H_ #include #undef _PyGC_FINALIZED #include "oneflow/core/framework/framework.h" namespace oneflow { namespace pyext { void PyRegisterKernels(PyObject* py_kernels); void PyCompute(user_op::KernelComputeContext* ctx, const std::string& py_func_name); } // namespace pyext } // namespace oneflow #endif // ONEFLOW_EXTENSION_PYTHON_PY_COMPUTE_H_ ================================================ FILE: oneflow/extension/python/py_kernel_caller.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/extension/python/py_kernel_caller.h" #include "oneflow/extension/python/py_compute.h" namespace oneflow { void PyForwardKernel::Compute(user_op::KernelComputeContext* ctx) const { ::oneflow::pyext::PyCompute(ctx, "forward"); } void PyBackwardKernel::Compute(user_op::KernelComputeContext* ctx) const { ::oneflow::pyext::PyCompute(ctx, "backward"); } } // namespace oneflow ================================================ FILE: oneflow/extension/python/py_kernel_caller.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
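Editor's note: PyCompute above resolves the Python kernel by stripping a _forward/_backward suffix from the op type name, using the remainder as the key into the registered kernel dict, and then fetching the forward or backward attribute on that module. A small sketch of the name mapping (the helper name StripKernelSuffix is hypothetical):

#include <string>

// Mirrors the suffix handling in PyCompute:
//   "my_relu_forward"  -> "my_relu"
//   "my_relu_backward" -> "my_relu"
//   anything else      -> unchanged
std::string StripKernelSuffix(const std::string& op_type_name) {
  for (const std::string& suffix : {std::string("_forward"), std::string("_backward")}) {
    if (op_type_name.size() > suffix.size()
        && op_type_name.rfind(suffix) == op_type_name.size() - suffix.size()) {
      return op_type_name.substr(0, op_type_name.size() - suffix.size());
    }
  }
  return op_type_name;
}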
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_CALLER_H_ #define ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_CALLER_H_ #include "oneflow/core/framework/framework.h" namespace oneflow { class PyForwardKernel final : public user_op::OpKernel { public: PyForwardKernel() = default; ~PyForwardKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class PyBackwardKernel final : public user_op::OpKernel { public: PyBackwardKernel() = default; ~PyBackwardKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace oneflow #endif // ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_CALLER_H_ ================================================ FILE: oneflow/extension/python/py_kernel_registry.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/extension/python/py_kernel_registry.h" #include "oneflow/extension/python/py_compute.h" #include "oneflow/extension/python/py_kernel_caller.h" namespace oneflow { namespace pyext { Maybe RegisterPyKernelCaller(const std::string& op_module_name) { // register python op kernel auto reg = user_op::UserOpRegistryMgr::Get() .CheckAndGetOpKernelRegistry(op_module_name + "_forward") .SetCreateFn() .SetIsMatchedHob(((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDeviceSubTag() == "py"))); JUST(user_op::UserOpRegistryMgr::Get().Register(JUST(reg.Finish()).GetResult())); // register python grad op kernel auto grad_reg = user_op::UserOpRegistryMgr::Get() .CheckAndGetOpKernelRegistry(op_module_name + "_backward") .SetCreateFn() .SetIsMatchedHob(((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDeviceSubTag() == "py"))); JUST(user_op::UserOpRegistryMgr::Get().Register(JUST(grad_reg.Finish()).GetResult())); return Maybe::Ok(); } void RegisterPyKernels(PyObject* py_kernels) { PyRegisterKernels(py_kernels); } } // namespace pyext } // namespace oneflow ================================================ FILE: oneflow/extension/python/py_kernel_registry.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
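Editor's note: RegisterPyKernelCaller above registers a pair of CPU kernels (PyForwardKernel / PyBackwardKernel) under the "py" device sub-tag for the _forward and _backward variants of an op. A hedged usage sketch (the op name custom_relu and the wrapper function are illustrative, and the corresponding user op definitions are assumed to exist):

#include "oneflow/core/common/maybe.h"
#include "oneflow/extension/python/py_kernel_registry.h"

oneflow::Maybe<void> InstallCustomReluKernels() {
  // Registers kernels for "custom_relu_forward" and "custom_relu_backward".
  JUST(oneflow::pyext::RegisterPyKernelCaller("custom_relu"));
  return oneflow::Maybe<void>::Ok();
}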
*/ #ifndef ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_REGISTRY_H_ #define ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_REGISTRY_H_ #include #include #undef _PyGC_FINALIZED #include "oneflow/core/common/maybe.h" namespace oneflow { namespace pyext { Maybe RegisterPyKernelCaller(const std::string& op_module_name); void RegisterPyKernels(PyObject* py_kernels); } // namespace pyext } // namespace oneflow #endif // ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_REGISTRY_H_ ================================================ FILE: oneflow/extension/stack/foreign_stack_getter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_EXTENSION_STACK_STACK_GETTER_H_ #define ONEFLOW_EXTENSION_STACK_STACK_GETTER_H_ #include #include #include "oneflow/core/common/thread_local_guard.h" namespace oneflow { class Frame { public: virtual ~Frame() = default; }; using ForeignFrameThreadLocalGuard = ThreadLocalGuard>; class ForeignStackGetter { public: virtual ~ForeignStackGetter() = default; virtual std::shared_ptr GetCurrentFrame() const = 0; virtual std::string GetFormattedStack(std::shared_ptr frame) const = 0; }; } // namespace oneflow #endif // ONEFLOW_EXTENSION_STACK_STACK_GETTER_H_ ================================================ FILE: oneflow/extension/stack/python/custom_eval_frame.c ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
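Editor's note: ForeignStackGetter above is the abstract hook used to capture and format a host-language call stack; PyStackGetter further below is the real Python implementation. A minimal, hypothetical stub that satisfies the interface could look like this:

#include <memory>
#include <string>
#include "oneflow/extension/stack/foreign_stack_getter.h"

namespace oneflow {
// Hypothetical no-op getter: records nothing and formats an empty stack.
class NullStackGetter final : public ForeignStackGetter {
 public:
  std::shared_ptr<Frame> GetCurrentFrame() const override { return nullptr; }
  std::string GetFormattedStack(std::shared_ptr<Frame> frame) const override {
    return "  <no foreign stack recorded>\n";
  }
};
}  // namespace oneflow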
*/ // see https://bugs.python.org/issue23644 for why this file is written // as .c instead of .cpp #include "oneflow/extension/stack/python/custom_eval_frame.h" #define PY_SSIZE_T_CLEAN #include #undef _PyGC_FINALIZED #include #include // see https://bugs.python.org/issue35886 #if PY_VERSION_HEX >= 0x03080000 #define Py_BUILD_CORE #include "internal/pycore_pystate.h" #undef Py_BUILD_CORE #endif inline static void EnableCustomEvalFrame(PyThreadState* tstate, _PyFrameEvalFunction eval_func) { #if PY_VERSION_HEX >= 0x03090000 if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) != eval_func) { _PyInterpreterState_SetEvalFrameFunc(tstate->interp, eval_func); } #else if (tstate->interp->eval_frame != eval_func) { tstate->interp->eval_frame = eval_func; } #endif } void EnableCustomEvalFrameForCurrentThread(PyFrameEvalFunc eval_func) { return EnableCustomEvalFrame(PyThreadState_GET(), eval_func); } ================================================ FILE: oneflow/extension/stack/python/custom_eval_frame.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_API_PYTHON_CUSTOM_EVAL_FRAME_H_ #define ONEFLOW_API_PYTHON_CUSTOM_EVAL_FRAME_H_ #ifdef __cplusplus extern "C" { #endif #include #undef _PyGC_FINALIZED #if PY_VERSION_HEX >= 0x03090000 typedef PyObject* (*PyFrameEvalFunc)(struct _ts*, struct _frame*, int); #else typedef PyObject* (*PyFrameEvalFunc)(struct _frame*, int); #endif void EnableCustomEvalFrameForCurrentThread(PyFrameEvalFunc eval_func); #ifdef __cplusplus } #endif #endif // ONEFLOW_API_PYTHON_CUSTOM_EVAL_FRAME_H_ ================================================ FILE: oneflow/extension/stack/python/stack_getter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
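Editor's note: EnableCustomEvalFrameForCurrentThread above swaps the interpreter's frame-evaluation hook for the current thread. A minimal pass-through hook, sketched for CPython 3.9/3.10 only (on 3.11+ the frame argument type changes, as stack_getter.cpp below handles):

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "oneflow/extension/stack/python/custom_eval_frame.h"

// Forwards every frame to the default interpreter loop; a real hook would
// record the frame first (see PyStackGetter::RecordAndEvalFrame below).
static PyObject* PassThroughEvalFrame(PyThreadState* tstate, struct _frame* frame, int throw_flag) {
  return _PyEval_EvalFrameDefault(tstate, frame, throw_flag);
}

// Somewhere during interpreter setup:
//   EnableCustomEvalFrameForCurrentThread(&PassThroughEvalFrame);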
*/ #include "oneflow/extension/stack/python/stack_getter.h" #include #include "fmt/core.h" #include "fmt/color.h" #include "fmt/ostream.h" #include "pybind11/pybind11.h" #if PY_VERSION_HEX >= 0x030b0000 #ifndef Py_BUILD_CORE #define Py_BUILD_CORE 1 #endif #include "internal/pycore_frame.h" #endif #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/framework/shut_down_util.h" #include "oneflow/core/common/foreign_lock_helper.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/job/graph_scope_vars.h" #include "oneflow/extension/stack/foreign_stack_getter.h" #include "oneflow/extension/stack/python/custom_eval_frame.h" namespace py = pybind11; namespace oneflow { namespace { std::string PyUnicodeToStdString(const PyObject* py_str) { return PyBytes_AsString(PyUnicode_AsEncodedString(const_cast(py_str), "utf-8", "~E~")); } #if PY_VERSION_HEX < 0x03090000 PyCodeObject* PyFrame_GetCode(PyFrameObject* frame) { assert(frame != NULL); PyCodeObject* code = frame->f_code; assert(code != NULL); Py_INCREF(code); return code; } #endif } // namespace class PyFrame final : public Frame { public: // There is no need to increase the reference count of these cpython objects // because they must be alive during the lifetime of `PyFrame`. PyFrame(PyFrameObject* frame, std::shared_ptr back) : cpython_frame(frame), lineno(0), back(std::move(back)) { PyCodeObject* code = PyFrame_GetCode(frame); filename = code->co_filename; funcname = code->co_name; Py_DECREF(code); } ~PyFrame() = default; OF_DISALLOW_COPY_AND_MOVE(PyFrame); PyObject* filename; PyObject* funcname; PyFrameObject* cpython_frame; int lineno; std::shared_ptr back; }; class PyStackGetter final : public ForeignStackGetter { public: PyStackGetter() { auto* frame = PyEval_GetFrame(); // Get the first frame. It assumes `import oneflow` is called in global scope, while (frame->f_back != nullptr) { frame = frame->f_back; } current_frame_ = std::make_shared(frame, nullptr); } // indended to be called in main thread. std::shared_ptr GetCurrentFrame() const override { if (IsShuttingDown() || !current_frame_) { return nullptr; } // See `RecordAndEvalFrame` for documentation. current_frame_->lineno = PyFrame_GetLineNumber(current_frame_->cpython_frame); return current_frame_; } // bad path, performance is not a concern. 
std::string GetFormattedStack(std::shared_ptr frame) const override { if (frame == nullptr) { return " \n"; } std::string buffer; const auto* py_frame = dynamic_cast(frame.get()); py::gil_scoped_acquire acquire; while (py_frame != nullptr) { const auto& lineno = py_frame->lineno; const std::string line_text = [&]() -> std::string { std::string line_text; std::ifstream ifs(PyUnicodeToStdString(py_frame->filename)); if (!ifs.is_open()) { return ""; } for (int j = 0; j < lineno; ++j) { std::getline(ifs, line_text); } line_text.erase(line_text.find_last_not_of(' ') + 1); // suffixing spaces line_text.erase(0, line_text.find_first_not_of(' ')); // prefixing spaces return line_text; }(); // immitate python's stack trace format fmt::format_to(std::back_inserter(buffer), " File \"{}\", line {}, in {}\n {}\n", PyUnicodeToStdString(py_frame->filename), lineno, PyUnicodeToStdString(py_frame->funcname), line_text); py_frame = py_frame->back.get(); } return buffer; }; #if PY_VERSION_HEX >= 0x03090000 PyObject* RecordAndEvalFrame(PyThreadState* tstate, PyFrameObject* frame, #else PyObject* RecordAndEvalFrame(PyFrameObject* frame, #endif int throw_flag) { // Example: // >> def f(): # Line 1 // >> pass # Line 2 // >> f() # Line 3 // // When we call f(), `RecordAndEvalFrame` is triggered and the `frame` // argument is the frame of function `f`, which is Line 1 at that time. It is not // useful to us, but we can adjust it in `GetCurrentFrame` method. // PushFrame(frame); #if PY_VERSION_HEX >= 0x03090000 if (tstate == NULL) { tstate = PyThreadState_GET(); } #if PY_VERSION_HEX >= 0x030b0000 PyObject* ret = _PyEval_EvalFrameDefault(tstate, frame->f_frame, throw_flag); #else PyObject* ret = _PyEval_EvalFrameDefault(tstate, frame, throw_flag); #endif #else PyObject* ret = _PyEval_EvalFrameDefault(frame, throw_flag); #endif PopFrame(); return ret; } private: std::shared_ptr current_frame_; void PushFrame(PyFrameObject* frame) { if (auto* f = frame->f_back) { current_frame_->lineno = PyFrame_GetLineNumber(f); } current_frame_ = std::make_shared(frame, current_frame_); } void PopFrame() { CHECK_NOTNULL(current_frame_); current_frame_ = current_frame_->back; } }; #if PY_VERSION_HEX >= 0x03090000 PyObject* RecordAndEvalFrame(PyThreadState* tstate, PyFrameObject* frame, #else PyObject* RecordAndEvalFrame(PyFrameObject* frame, #endif int throw_flag) { using namespace oneflow; return dynamic_cast(Singleton::Get()) #if PY_VERSION_HEX >= 0x03090000 ->RecordAndEvalFrame(tstate, frame, throw_flag); #else ->RecordAndEvalFrame(frame, throw_flag); #endif } void RegisterPyStackGetter() { if (!IsPythonStackGetterEnabled()) { return; } Singleton::Delete(); Singleton::SetAllocated(new PyStackGetter()); EnableCustomEvalFrameForCurrentThread(&RecordAndEvalFrame); } namespace { // get a formatted stack frame representation std::string get_python_frame_str_repr(PyFrameObject* frame) { if (frame == NULL) return ""; std::string buffer; PyCodeObject* code = PyFrame_GetCode(frame); std::string file_name = PyUnicodeToStdString(code->co_filename); std::string code_name = PyUnicodeToStdString(code->co_name); Py_DECREF(code); int line_number = PyFrame_GetLineNumber(frame); fmt::format_to(std::back_inserter(buffer), "File \"{}\", line {}, in {}", file_name, line_number, code_name); std::string line_text; const bool debug_mode = GetGraphDebugMode() || IsInDebugMode(); if (debug_mode) { const auto& GetCurSrc = [&file_name, line_number]() -> std::string { std::string line_text; std::ifstream ifs(file_name); if (!ifs.is_open()) { return ""; } for 
(int j = 0; j < line_number; ++j) { std::getline(ifs, line_text); } line_text.erase(line_text.find_last_not_of(' ') + 1); // suffixing spaces line_text.erase(0, line_text.find_first_not_of(' ')); // prefixing spaces return line_text; }; line_text = GetCurSrc(); buffer += ", source < " + line_text + " >; "; } else { buffer += "; "; } return buffer; } bool check_if_python_file_should_be_filtered(const std::string& path) { const auto& paths_to_be_kept = GetPythonPathsToBeKeptForDebugging(); for (int i = 0; i < paths_to_be_kept.size(); ++i) { const std::string& path_to_be_kept = paths_to_be_kept[i]; if (path.size() > path_to_be_kept.size()) { if (path.substr(0, path_to_be_kept.size()) == path_to_be_kept) { return false; } } } const auto& paths_to_be_filtered = GetPythonPathsToBeFilteredForDebugging(); for (int i = 0; i < paths_to_be_filtered.size(); ++i) { const std::string& path_to_be_filtered = paths_to_be_filtered[i]; if (path.size() > path_to_be_filtered.size()) { if (path.substr(0, path_to_be_filtered.size()) == path_to_be_filtered) { return true; } } } return false; } bool check_if_frame_should_be_filtered(PyFrameObject* frame) { std::string frame_file_name = PyUnicodeToStdString(PyFrame_GetCode(frame)->co_filename); return check_if_python_file_should_be_filtered(frame_file_name); } bool check_if_should_skip_this_frame(PyFrameObject* frame) { const bool only_user_py_stack = GetGraphDebugOnlyUserPyStack(); if (only_user_py_stack) { return check_if_frame_should_be_filtered(frame); } return false; } int32_t get_cur_stack_depth() { int32_t current_stack_depth = 0; PyFrameObject* f = PyEval_GetFrame(); while (f) { if (check_if_should_skip_this_frame(f)) { f = f->f_back; continue; } current_stack_depth++; f = f->f_back; } return current_stack_depth; } std::string get_cur_frame_stack_str() { const int32_t max_stack_depth = GetGraphDebugMaxPyStackDepth(); std::string cur_f_str; PyFrameObject* cur_frame = PyEval_GetFrame(); int i = 0; while (i < max_stack_depth) { if (cur_frame == NULL) break; if (check_if_should_skip_this_frame(cur_frame)) { cur_frame = cur_frame->f_back; continue; } cur_f_str += get_python_frame_str_repr(cur_frame); cur_frame = cur_frame->f_back; i++; } // show how may stack frames remain to be shown in debug mode const bool debug_mode = GetGraphDebugMode() || IsInDebugMode(); if (debug_mode) { const int32_t current_stack_depth = get_cur_stack_depth(); if (current_stack_depth > max_stack_depth) { cur_f_str += "... " + std::to_string(current_stack_depth - max_stack_depth) + " more"; } } else { if (cur_frame != NULL) { cur_f_str += " ... more"; } } return cur_f_str; } } // namespace PythonFrameGuard::PythonFrameGuard() { if (OF_PREDICT_FALSE(LazyMode::is_enabled())) { prev_frame_str_ = DispatchFrame::get_str(); DispatchFrame::set_str(get_cur_frame_stack_str()); } } PythonFrameGuard::~PythonFrameGuard() { if (OF_PREDICT_FALSE(LazyMode::is_enabled())) { DispatchFrame::set_str(prev_frame_str_); } } } // namespace oneflow ================================================ FILE: oneflow/extension/stack/python/stack_getter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
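Editor's note: PythonFrameGuard above is only active under LazyMode: construction snapshots the current Python call stack into DispatchFrame, and destruction restores the previous string. A hedged usage sketch around a lazy-mode op dispatch (the DispatchOp call is a placeholder):

#include "oneflow/extension/stack/python/stack_getter.h"

void DispatchWithPythonFrame() {
  oneflow::PythonFrameGuard frame_guard;  // records the Python caller stack while in lazy mode
  // DispatchOp(...);  // placeholder: graph building that wants the Python stack attached
}  // destructor restores the previous DispatchFrame string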
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_EXTENSION_STACK_PYTHON_STACK_GETTER #define ONEFLOW_EXTENSION_STACK_PYTHON_STACK_GETTER #include namespace oneflow { void RegisterPyStackGetter(); class PythonFrameGuard { public: PythonFrameGuard(); ~PythonFrameGuard(); private: std::string prev_frame_str_; }; } // namespace oneflow #endif // ONEFLOW_EXTENSION_STACK_PYTHON_STACK_GETTER ================================================ FILE: oneflow/extension/stack/stacktrace.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* * backward.hpp * Copyright 2013 Google Inc. All Rights Reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ #ifndef H_6B9572DA_A64B_49E6_B234_051480991C89 #define H_6B9572DA_A64B_49E6_B234_051480991C89 #ifndef __cplusplus #error "It's not going to compile without a C++ compiler..." #endif #if defined(BACKWARD_CXX11) #elif defined(BACKWARD_CXX98) #else #if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1800) #define BACKWARD_CXX11 #define BACKWARD_ATLEAST_CXX11 #define BACKWARD_ATLEAST_CXX98 #if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) #define BACKWARD_ATLEAST_CXX17 #endif #else #define BACKWARD_CXX98 #define BACKWARD_ATLEAST_CXX98 #endif #endif // You can define one of the following (or leave it to the auto-detection): // // #define BACKWARD_SYSTEM_LINUX // - specialization for linux // // #define BACKWARD_SYSTEM_DARWIN // - specialization for Mac OS X 10.5 and later. 
// // #define BACKWARD_SYSTEM_WINDOWS // - specialization for Windows (Clang 9 and MSVC2017) // // #define BACKWARD_SYSTEM_UNKNOWN // - placebo implementation, does nothing. // #if defined(BACKWARD_SYSTEM_LINUX) #elif defined(BACKWARD_SYSTEM_DARWIN) #elif defined(BACKWARD_SYSTEM_UNKNOWN) #elif defined(BACKWARD_SYSTEM_WINDOWS) #else #if defined(__linux) || defined(__linux__) #define BACKWARD_SYSTEM_LINUX #elif defined(__APPLE__) #define BACKWARD_SYSTEM_DARWIN #elif defined(_WIN32) #define BACKWARD_SYSTEM_WINDOWS #else #define BACKWARD_SYSTEM_UNKNOWN #endif #endif #define NOINLINE __attribute__((noinline)) #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #if defined(BACKWARD_SYSTEM_LINUX) // On linux, backtrace can back-trace or "walk" the stack using the following // libraries: // // #define BACKWARD_HAS_UNWIND 1 // - unwind comes from libgcc, but I saw an equivalent inside clang itself. // - with unwind, the stacktrace is as accurate as it can possibly be, since // this is used by the C++ runtime in gcc/clang for stack unwinding on // exception. // - normally libgcc is already linked to your program by default. // // #define BACKWARD_HAS_LIBUNWIND 1 // - libunwind provides, in some cases, a more accurate stacktrace as it knows // to decode signal handler frames and lets us edit the context registers when // unwinding, allowing stack traces over bad function references. // // #define BACKWARD_HAS_BACKTRACE == 1 // - backtrace seems to be a little bit more portable than libunwind, but on // linux, it uses unwind anyway, but abstract away a tiny information that is // sadly really important in order to get perfectly accurate stack traces. // - backtrace is part of the (e)glib library. // // The default is: // #define BACKWARD_HAS_UNWIND == 1 // // Note that only one of the define should be set to 1 at a time. // #if BACKWARD_HAS_UNWIND == 1 #elif BACKWARD_HAS_LIBUNWIND == 1 #elif BACKWARD_HAS_BACKTRACE == 1 #else #undef BACKWARD_HAS_UNWIND #define BACKWARD_HAS_UNWIND 1 #undef BACKWARD_HAS_LIBUNWIND #define BACKWARD_HAS_LIBUNWIND 0 #undef BACKWARD_HAS_BACKTRACE #define BACKWARD_HAS_BACKTRACE 0 #endif // On linux, backward can extract detailed information about a stack trace // using one of the following libraries: // // #define BACKWARD_HAS_DW 1 // - libdw gives you the most juicy details out of your stack traces: // - object filename // - function name // - source filename // - line and column numbers // - source code snippet (assuming the file is accessible) // - variable names (if not optimized out) // - variable values (not supported by backward-cpp) // - You need to link with the lib "dw": // - apt-get install libdw-dev // - g++/clang++ -ldw ... // // #define BACKWARD_HAS_BFD 1 // - With libbfd, you get a fair amount of details: // - object filename // - function name // - source filename // - line numbers // - source code snippet (assuming the file is accessible) // - You need to link with the lib "bfd": // - apt-get install binutils-dev // - g++/clang++ -lbfd ... 
// // #define BACKWARD_HAS_DWARF 1 // - libdwarf gives you the most juicy details out of your stack traces: // - object filename // - function name // - source filename // - line and column numbers // - source code snippet (assuming the file is accessible) // - variable names (if not optimized out) // - variable values (not supported by backward-cpp) // - You need to link with the lib "dwarf": // - apt-get install libdwarf-dev // - g++/clang++ -ldwarf ... // // #define BACKWARD_HAS_BACKTRACE_SYMBOL 1 // - backtrace provides minimal details for a stack trace: // - object filename // - function name // - backtrace is part of the (e)glib library. // // The default is: // #define BACKWARD_HAS_BACKTRACE_SYMBOL == 1 // // Note that only one of the define should be set to 1 at a time. // #if BACKWARD_HAS_DW == 1 #elif BACKWARD_HAS_BFD == 1 #elif BACKWARD_HAS_DWARF == 1 #elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1 #else #undef BACKWARD_HAS_DW #define BACKWARD_HAS_DW 0 #undef BACKWARD_HAS_BFD #define BACKWARD_HAS_BFD 0 #undef BACKWARD_HAS_DWARF #define BACKWARD_HAS_DWARF 0 #undef BACKWARD_HAS_BACKTRACE_SYMBOL #define BACKWARD_HAS_BACKTRACE_SYMBOL 1 #endif #include #include #ifdef __ANDROID__ // Old Android API levels define _Unwind_Ptr in both link.h and // unwind.h Rename the one in link.h as we are not going to be using // it #define _Unwind_Ptr _Unwind_Ptr_Custom #include #undef _Unwind_Ptr #else #include #endif #if defined(__ppc__) || defined(__powerpc) || defined(__powerpc__) || defined(__POWERPC__) // Linux kernel header required for the struct pt_regs definition // to access the NIP (Next Instruction Pointer) register value #include #endif #include #include #include #include #ifndef _GNU_SOURCE #define _GNU_SOURCE #include #undef _GNU_SOURCE #else #include #endif #if BACKWARD_HAS_BFD == 1 // NOTE: defining PACKAGE{,_VERSION} is required before including // bfd.h on some platforms, see also: // https://sourceware.org/bugzilla/show_bug.cgi?id=14243 #ifndef PACKAGE #define PACKAGE #endif #ifndef PACKAGE_VERSION #define PACKAGE_VERSION #endif #include #endif #if BACKWARD_HAS_DW == 1 #include #include #include #endif #if BACKWARD_HAS_DWARF == 1 #include #include #include #include #include #endif #if (BACKWARD_HAS_BACKTRACE == 1) || (BACKWARD_HAS_BACKTRACE_SYMBOL == 1) // then we shall rely on backtrace #include #endif #endif // defined(BACKWARD_SYSTEM_LINUX) #if defined(BACKWARD_SYSTEM_DARWIN) // On Darwin, backtrace can back-trace or "walk" the stack using the following // libraries: // // #define BACKWARD_HAS_UNWIND 1 // - unwind comes from libgcc, but I saw an equivalent inside clang itself. // - with unwind, the stacktrace is as accurate as it can possibly be, since // this is used by the C++ runtime in gcc/clang for stack unwinding on // exception. // - normally libgcc is already linked to your program by default. // // #define BACKWARD_HAS_LIBUNWIND 1 // - libunwind comes from clang, which implements an API compatible version. // - libunwind provides, in some cases, a more accurate stacktrace as it knows // to decode signal handler frames and lets us edit the context registers when // unwinding, allowing stack traces over bad function references. // // #define BACKWARD_HAS_BACKTRACE == 1 // - backtrace is available by default, though it does not produce as much // information as another library might. // // The default is: // #define BACKWARD_HAS_UNWIND == 1 // // Note that only one of the define should be set to 1 at a time. 
// #if BACKWARD_HAS_UNWIND == 1 #elif BACKWARD_HAS_BACKTRACE == 1 #elif BACKWARD_HAS_LIBUNWIND == 1 #else #undef BACKWARD_HAS_UNWIND #define BACKWARD_HAS_UNWIND 1 #undef BACKWARD_HAS_BACKTRACE #define BACKWARD_HAS_BACKTRACE 0 #undef BACKWARD_HAS_LIBUNWIND #define BACKWARD_HAS_LIBUNWIND 0 #endif // On Darwin, backward can extract detailed information about a stack trace // using one of the following libraries: // // #define BACKWARD_HAS_BACKTRACE_SYMBOL 1 // - backtrace provides minimal details for a stack trace: // - object filename // - function name // // The default is: // #define BACKWARD_HAS_BACKTRACE_SYMBOL == 1 // #if BACKWARD_HAS_BACKTRACE_SYMBOL == 1 #else #undef BACKWARD_HAS_BACKTRACE_SYMBOL #define BACKWARD_HAS_BACKTRACE_SYMBOL 1 #endif #include #include #include #include #include #include #if (BACKWARD_HAS_BACKTRACE == 1) || (BACKWARD_HAS_BACKTRACE_SYMBOL == 1) #include #endif #endif // defined(BACKWARD_SYSTEM_DARWIN) #if defined(BACKWARD_SYSTEM_WINDOWS) #include #include #include #include #ifdef _WIN64 typedef SSIZE_T ssize_t; #else typedef int ssize_t; #endif #ifndef NOMINMAX #define NOMINMAX #endif #include #include #include #include #ifndef __clang__ #undef NOINLINE #define NOINLINE __declspec(noinline) #endif #ifdef _MSC_VER #pragma comment(lib, "psapi.lib") #pragma comment(lib, "dbghelp.lib") #endif // Comment / packing is from stackoverflow: // https://stackoverflow.com/questions/6205981/windows-c-stack-trace-from-a-running-app/28276227#28276227 // Some versions of imagehlp.dll lack the proper packing directives themselves // so we need to do it. #pragma pack(push, before_imagehlp, 8) #include #pragma pack(pop, before_imagehlp) // TODO maybe these should be undefined somewhere else? #undef BACKWARD_HAS_UNWIND #undef BACKWARD_HAS_BACKTRACE #if BACKWARD_HAS_PDB_SYMBOL == 1 #else #undef BACKWARD_HAS_PDB_SYMBOL #define BACKWARD_HAS_PDB_SYMBOL 1 #endif #endif #if BACKWARD_HAS_UNWIND == 1 #include // while gcc's unwind.h defines something like that: // extern _Unwind_Ptr _Unwind_GetIP (struct _Unwind_Context *); // extern _Unwind_Ptr _Unwind_GetIPInfo (struct _Unwind_Context *, int *); // // clang's unwind.h defines something like this: // uintptr_t _Unwind_GetIP(struct _Unwind_Context* __context); // // Even if the _Unwind_GetIPInfo can be linked to, it is not declared, worse we // cannot just redeclare it because clang's unwind.h doesn't define _Unwind_Ptr // anyway. // // Luckily we can play on the fact that the guard macros have a different name: #ifdef __CLANG_UNWIND_H // In fact, this function still comes from libgcc (on my different linux boxes, // clang links against libgcc). 
#include extern "C" uintptr_t _Unwind_GetIPInfo(_Unwind_Context*, int*); #endif #endif // BACKWARD_HAS_UNWIND == 1 #if BACKWARD_HAS_LIBUNWIND == 1 #define UNW_LOCAL_ONLY #include #endif // BACKWARD_HAS_LIBUNWIND == 1 #ifdef BACKWARD_ATLEAST_CXX11 #include #include // for std::swap namespace backward { namespace details { template struct hashtable { typedef std::unordered_map type; }; using std::move; } // namespace details } // namespace backward #else // NOT BACKWARD_ATLEAST_CXX11 #define nullptr NULL #define override #include namespace backward { namespace details { template struct hashtable { typedef std::map type; }; template const T& move(const T& v) { return v; } template T& move(T& v) { return v; } } // namespace details } // namespace backward #endif // BACKWARD_ATLEAST_CXX11 namespace backward { namespace details { #if defined(BACKWARD_SYSTEM_WINDOWS) const char kBackwardPathDelimiter[] = ";"; #else const char kBackwardPathDelimiter[] = ":"; #endif } // namespace details } // namespace backward namespace backward { namespace system_tag { struct linux_tag; // seems that I cannot call that "linux" because the name // is already defined... so I am adding _tag everywhere. struct darwin_tag; struct windows_tag; struct unknown_tag; #if defined(BACKWARD_SYSTEM_LINUX) typedef linux_tag current_tag; #elif defined(BACKWARD_SYSTEM_DARWIN) typedef darwin_tag current_tag; #elif defined(BACKWARD_SYSTEM_WINDOWS) typedef windows_tag current_tag; #elif defined(BACKWARD_SYSTEM_UNKNOWN) typedef unknown_tag current_tag; #else #error "May I please get my system defines?" #endif } // namespace system_tag namespace trace_resolver_tag { #if defined(BACKWARD_SYSTEM_LINUX) struct libdw; struct libbfd; struct libdwarf; struct backtrace_symbol; #if BACKWARD_HAS_DW == 1 typedef libdw current; #elif BACKWARD_HAS_BFD == 1 typedef libbfd current; #elif BACKWARD_HAS_DWARF == 1 typedef libdwarf current; #elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1 typedef backtrace_symbol current; #else #error "You shall not pass, until you know what you want." #endif #elif defined(BACKWARD_SYSTEM_DARWIN) struct backtrace_symbol; #if BACKWARD_HAS_BACKTRACE_SYMBOL == 1 typedef backtrace_symbol current; #else #error "You shall not pass, until you know what you want." #endif #elif defined(BACKWARD_SYSTEM_WINDOWS) struct pdb_symbol; #if BACKWARD_HAS_PDB_SYMBOL == 1 typedef pdb_symbol current; #else #error "You shall not pass, until you know what you want." #endif #endif } // namespace trace_resolver_tag namespace details { template struct rm_ptr { typedef T type; }; template struct rm_ptr { typedef T type; }; template struct rm_ptr { typedef const T type; }; template struct deleter { template void operator()(U& ptr) const { (*F)(ptr); } }; template struct default_delete { void operator()(T& ptr) const { delete ptr; } }; template> class handle { struct dummy; T _val; bool _empty; #ifdef BACKWARD_ATLEAST_CXX11 handle(const handle&) = delete; handle& operator=(const handle&) = delete; #endif public: ~handle() { if (!_empty) { Deleter()(_val); } } explicit handle() : _val(), _empty(true) {} explicit handle(T val) : _val(val), _empty(false) { if (!_val) _empty = true; } #ifdef BACKWARD_ATLEAST_CXX11 handle(handle&& from) : _empty(true) { swap(from); } handle& operator=(handle&& from) { swap(from); return *this; } #else explicit handle(const handle& from) : _empty(true) { // some sort of poor man's move semantic. swap(const_cast(from)); } handle& operator=(const handle& from) { // some sort of poor man's move semantic. 
swap(const_cast(from)); return *this; } #endif void reset(T new_val) { handle tmp(new_val); swap(tmp); } void update(T new_val) { _val = new_val; _empty = !static_cast(new_val); } operator const dummy*() const { if (_empty) { return nullptr; } return reinterpret_cast(_val); } T get() { return _val; } T release() { _empty = true; return _val; } void swap(handle& b) { using std::swap; swap(b._val, _val); // can throw, we are safe here. swap(b._empty, _empty); // should not throw: if you cannot swap two // bools without throwing... It's a lost cause anyway! } T& operator->() { return _val; } const T& operator->() const { return _val; } typedef typename rm_ptr::type& ref_t; typedef const typename rm_ptr::type& const_ref_t; ref_t operator*() { return *_val; } const_ref_t operator*() const { return *_val; } ref_t operator[](size_t idx) { return _val[idx]; } // Watch out, we've got a badass over here T* operator&() { _empty = false; return &_val; } }; namespace { // how many args to keep in template params // e.g. {std::vector, 1} means std::vector -> std::vector static std::unordered_map class2keepsize{ {"std::vector", 1}, {"Maybe", 1}, }; class SignatureType { public: SignatureType(const std::string& name, const std::vector& args, const std::string& specifier) : name(name), args(args), specifier(specifier){}; std::string name; std::vector args; std::string specifier; using pss = std::pair; size_t get_keep_size(const std::string& name) { auto it = class2keepsize.find(name); if (it == class2keepsize.end()) { return 0; } else { return it->second; } } std::string to_string() { std::string str_args, str_specifer; if (args.empty()) { str_args = ""; } else { str_args = "<"; size_t keep_size = get_keep_size(name); if (keep_size == 0) { keep_size = args.size(); } for (int i = 0; i < keep_size; i++) { SignatureType type = args[i]; str_args += type.to_string() + ((i != (keep_size - 1)) ? 
", " : ""); }; str_args += "> "; } return name + str_args + specifier; } static std::pair, std::string> parse_args(std::string s) { std::vector args; while (s[0] != '>') { s = s.substr(1, s.size() - 1); auto type_and_rest = parse_type(s); s = type_and_rest.second; args.push_back(type_and_rest.first); } return {args, s.substr(1, s.size() - 1)}; } static pss parse_spaces(const std::string& inp) { size_t pos = inp.find_first_not_of(" "); if (pos == 0) { return {"", inp}; } else { return {inp.substr(0, pos), inp.substr(pos, inp.size() - pos)}; } } static pss parse_type_specifier(const std::string& inp) { static std::vector specifier_list{ "const&", "const", "volatile", }; for (const auto& specifier : specifier_list) { if (inp.rfind(specifier, 0) == 0) { return {specifier, inp.substr(specifier.size(), inp.size() - specifier.size())}; } } return {"", inp}; } static pss parse_simple_type_id(const std::string& inp) { auto rest = parse_spaces(inp).second; std::smatch found; std::regex_search(rest, found, std::regex("^((\\w|:|\\*|&)+)")); std::string name = found[0]; return {name, rest.substr(name.size(), rest.size() - name.size())}; } static std::pair parse_type_id(const std::string& inp) { auto rest = parse_spaces(inp).second; std::smatch found; std::regex_search(rest, found, std::regex("^((\\w|:|\\*|&)+)")); std::string name = found[0]; rest = rest.substr(name.size(), rest.size() - name.size()); rest = parse_spaces(rest).second; auto spec_and_rest = parse_type_specifier(rest); auto type_spec = spec_and_rest.first; rest = spec_and_rest.second; return {SignatureType(name, {}, type_spec), rest}; } static std::pair parse_type(const std::string& inp) { auto name_and_rest = parse_simple_type_id(inp); auto type_name = name_and_rest.first; auto rest = name_and_rest.second; std::vector args; if (rest[0] == '<') { auto args_and_rest = parse_args(rest); args = args_and_rest.first; rest = args_and_rest.second; } rest = parse_spaces(rest).second; auto specifier_and_rest = parse_type_specifier(rest); auto type_spec = specifier_and_rest.first; rest = specifier_and_rest.second; return {SignatureType(type_name, args, type_spec), rest}; } }; std::string replace_each(const std::string& signature, const std::string& src, const std::string& dst) { std::string result; std::string::size_type substr_begin = 0; for (std::string::size_type pos = 0; signature.npos != (pos = signature.find(src.data(), pos, src.length()));) { result.insert(result.end(), signature.begin() + substr_begin, signature.begin() + pos); result += dst; substr_begin = pos + src.length(); pos = substr_begin; } result.insert(result.end(), signature.begin() + substr_begin, signature.end()); return result; } std::string replace(const std::string& signature) { static std::vector> replace_pairs = { {"oneflow::one::", ""}, {"oneflow::", ""}, {"std::__cxx11::basic_string, std::allocator >", "std::string"}, }; std::string result = signature; for (const auto& p : replace_pairs) { result = replace_each(result, p.first, p.second); } return result; } std::string simplify_type(const std::string& inp, const std::string& type_name) { std::string result; std::string::size_type begin = 0; std::string::size_type pos = 0; for (; inp.npos != (pos = inp.find(type_name.data(), pos, type_name.length()));) { result.insert(result.end(), inp.begin() + begin, inp.begin() + pos); auto type_and_rest = SignatureType::parse_type(inp.substr(pos, inp.size() - pos)); result += type_and_rest.first.to_string(); begin = inp.size() - type_and_rest.second.size(); pos = begin; } 
result.insert(result.end(), inp.begin() + begin, inp.end()); return result; } std::string simplify(const std::string& inp) { std::string result = replace(inp); for (const auto& type_pair : class2keepsize) { auto type_name = type_pair.first; result = simplify_type(result, type_name); } return result; } } // namespace // Default demangler implementation (do nothing). template struct demangler_impl { static std::string demangle(const char* funcname) { return funcname; } }; #if defined(BACKWARD_SYSTEM_LINUX) || defined(BACKWARD_SYSTEM_DARWIN) template<> struct demangler_impl { demangler_impl() : _demangle_buffer_length(0) {} std::string demangle(const char* funcname) { using namespace details; char* result = abi::__cxa_demangle(funcname, _demangle_buffer.get(), &_demangle_buffer_length, nullptr); if (result) { _demangle_buffer.update(result); // Modify: simplify func signature return simplify(result); // return result; } return funcname; } private: details::handle _demangle_buffer; size_t _demangle_buffer_length; }; #endif // BACKWARD_SYSTEM_LINUX || BACKWARD_SYSTEM_DARWIN struct demangler : public demangler_impl {}; // Split a string on the platform's PATH delimiter. Example: if delimiter // is ":" then: // "" --> [] // ":" --> ["",""] // "::" --> ["","",""] // "/a/b/c" --> ["/a/b/c"] // "/a/b/c:/d/e/f" --> ["/a/b/c","/d/e/f"] // etc. inline std::vector split_source_prefixes(const std::string& s) { std::vector out; size_t last = 0; size_t next = 0; size_t delimiter_size = sizeof(kBackwardPathDelimiter) - 1; while ((next = s.find(kBackwardPathDelimiter, last)) != std::string::npos) { out.push_back(s.substr(last, next - last)); last = next + delimiter_size; } if (last <= s.length()) { out.push_back(s.substr(last)); } return out; } } // namespace details /*************** A TRACE ***************/ struct Trace { void* addr; size_t idx; Trace() : addr(nullptr), idx(0) {} explicit Trace(void* _addr, size_t _idx) : addr(_addr), idx(_idx) {} }; struct ResolvedTrace : public Trace { struct SourceLoc { std::string function; std::string filename; unsigned line; unsigned col; SourceLoc() : line(0), col(0) {} bool operator==(const SourceLoc& b) const { return function == b.function && filename == b.filename && line == b.line && col == b.col; } bool operator!=(const SourceLoc& b) const { return !(*this == b); } }; // In which binary object this trace is located. std::string object_filename; // The function in the object that contain the trace. This is not the same // as source.function which can be an function inlined in object_function. std::string object_function; // The source location of this trace. It is possible for filename to be // empty and for line/col to be invalid (value 0) if this information // couldn't be deduced, for example if there is no debug information in the // binary object. SourceLoc source; // An optionals list of "inliners". All the successive sources location // from where the source location of the trace (the attribute right above) // is inlined. It is especially useful when you compiled with optimization. typedef std::vector source_locs_t; source_locs_t inliners; ResolvedTrace() : Trace() {} ResolvedTrace(const Trace& mini_trace) : Trace(mini_trace) {} }; /*************** STACK TRACE ***************/ // default implemention. 
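//
// Note on the demangler above (an added illustration; "MyOp" is a made-up
// name): the OneFlow-specific simplify() pass shortens demangled signatures
// before they are printed. Roughly, a frame such as
//
//   oneflow::one::MyOp(std::vector<int, std::allocator<int> > const&)
//
// comes out as something like
//
//   MyOp(std::vector<int> const&)
//
// because replace() strips the oneflow:: / oneflow::one:: prefixes and the
// verbose basic_string spelling, while class2keepsize keeps only the first
// template argument of std::vector and Maybe.
// (The unspecialized StackTraceImpl below is the do-nothing fallback used
// when the system is unknown.)
//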
template class StackTraceImpl { public: size_t size() const { return 0; } Trace operator[](size_t) const { return Trace(); } size_t load_here(size_t = 0) { return 0; } size_t load_from(void*, size_t = 0, void* = nullptr, void* = nullptr) { return 0; } size_t thread_id() const { return 0; } void skip_n_firsts(size_t) {} }; class StackTraceImplBase { public: StackTraceImplBase() : _thread_id(0), _skip(0), _context(nullptr), _error_addr(nullptr) {} size_t thread_id() const { return _thread_id; } void skip_n_firsts(size_t n) { _skip = n; } protected: void load_thread_info() { #ifdef BACKWARD_SYSTEM_LINUX #ifndef __ANDROID__ _thread_id = static_cast(syscall(SYS_gettid)); #else _thread_id = static_cast(gettid()); #endif if (_thread_id == static_cast(getpid())) { // If the thread is the main one, let's hide that. // I like to keep little secret sometimes. _thread_id = 0; } #elif defined(BACKWARD_SYSTEM_DARWIN) _thread_id = reinterpret_cast(pthread_self()); if (pthread_main_np() == 1) { // If the thread is the main one, let's hide that. _thread_id = 0; } #endif } void set_context(void* context) { _context = context; } void* context() const { return _context; } void set_error_addr(void* error_addr) { _error_addr = error_addr; } void* error_addr() const { return _error_addr; } size_t skip_n_firsts() const { return _skip; } private: size_t _thread_id; size_t _skip; void* _context; void* _error_addr; }; class StackTraceImplHolder : public StackTraceImplBase { public: size_t size() const { return (_stacktrace.size() >= skip_n_firsts()) ? _stacktrace.size() - skip_n_firsts() : 0; } Trace operator[](size_t idx) const { if (idx >= size()) { return Trace(); } return Trace(_stacktrace[idx + skip_n_firsts()], idx); } void* const* begin() const { if (size()) { return &_stacktrace[skip_n_firsts()]; } return nullptr; } protected: std::vector _stacktrace; }; #if BACKWARD_HAS_UNWIND == 1 namespace details { template class Unwinder { public: size_t operator()(F& f, size_t depth) { _f = &f; _index = -1; _depth = depth; _Unwind_Backtrace(&this->backtrace_trampoline, this); if (_index == -1) { // _Unwind_Backtrace has failed to obtain any backtraces return 0; } else { return static_cast(_index); } } private: F* _f; ssize_t _index; size_t _depth; static _Unwind_Reason_Code backtrace_trampoline(_Unwind_Context* ctx, void* self) { return (static_cast(self))->backtrace(ctx); } _Unwind_Reason_Code backtrace(_Unwind_Context* ctx) { if (_index >= 0 && static_cast(_index) >= _depth) return _URC_END_OF_STACK; int ip_before_instruction = 0; uintptr_t ip = _Unwind_GetIPInfo(ctx, &ip_before_instruction); if (!ip_before_instruction) { // calculating 0-1 for unsigned, looks like a possible bug to sanitizers, // so let's do it explicitly: if (ip == 0) { ip = std::numeric_limits::max(); // set it to 0xffff... (as // from casting 0-1) } else { ip -= 1; // else just normally decrement it (no overflow/underflow will // happen) } } if (_index >= 0) { // ignore first frame. 
(*_f)(static_cast(_index), reinterpret_cast(ip)); } _index += 1; return _URC_NO_REASON; } }; template size_t unwind(F f, size_t depth) { Unwinder unwinder; return unwinder(f, depth); } } // namespace details template<> class StackTraceImpl : public StackTraceImplHolder { public: NOINLINE size_t load_here(size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) { load_thread_info(); set_context(context); set_error_addr(error_addr); if (depth == 0) { return 0; } _stacktrace.resize(depth); size_t trace_cnt = details::unwind(callback(*this), depth); _stacktrace.resize(trace_cnt); skip_n_firsts(0); return size(); } size_t load_from(void* addr, size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) { load_here(depth + 8, context, error_addr); for (size_t i = 0; i < _stacktrace.size(); ++i) { if (_stacktrace[i] == addr) { skip_n_firsts(i); break; } } _stacktrace.resize(std::min(_stacktrace.size(), skip_n_firsts() + depth)); return size(); } private: struct callback { StackTraceImpl& self; callback(StackTraceImpl& _self) : self(_self) {} void operator()(size_t idx, void* addr) { self._stacktrace[idx] = addr; } }; }; #elif BACKWARD_HAS_LIBUNWIND == 1 template<> class StackTraceImpl : public StackTraceImplHolder { public: __attribute__((noinline)) size_t load_here(size_t depth = 32, void* _context = nullptr, void* _error_addr = nullptr) { set_context(_context); set_error_addr(_error_addr); load_thread_info(); if (depth == 0) { return 0; } _stacktrace.resize(depth + 1); int result = 0; unw_context_t ctx; size_t index = 0; // Add the tail call. If the Instruction Pointer is the crash address it // means we got a bad function pointer dereference, so we "unwind" the // bad pointer manually by using the return address pointed to by the // Stack Pointer as the Instruction Pointer and letting libunwind do // the rest if (context()) { ucontext_t* uctx = reinterpret_cast(context()); // x86_64 #ifdef REG_RIP if (uctx->uc_mcontext.gregs[REG_RIP] == reinterpret_cast(error_addr())) { uctx->uc_mcontext.gregs[REG_RIP] = *reinterpret_cast(uctx->uc_mcontext.gregs[REG_RSP]); } _stacktrace[index] = reinterpret_cast(uctx->uc_mcontext.gregs[REG_RIP]); ++index; ctx = *reinterpret_cast(uctx); // x86_32 #elif defined(REG_EIP) if (uctx->uc_mcontext.gregs[REG_EIP] == reinterpret_cast(error_addr())) { uctx->uc_mcontext.gregs[REG_EIP] = *reinterpret_cast(uctx->uc_mcontext.gregs[REG_ESP]); } _stacktrace[index] = reinterpret_cast(uctx->uc_mcontext.gregs[REG_EIP]); ++index; ctx = *reinterpret_cast(uctx); #elif defined(__arm__) // libunwind uses its own context type for ARM unwinding. 
// Copy the registers from the signal handler's context so we can // unwind unw_getcontext(&ctx); ctx.regs[UNW_ARM_R0] = uctx->uc_mcontext.arm_r0; ctx.regs[UNW_ARM_R1] = uctx->uc_mcontext.arm_r1; ctx.regs[UNW_ARM_R2] = uctx->uc_mcontext.arm_r2; ctx.regs[UNW_ARM_R3] = uctx->uc_mcontext.arm_r3; ctx.regs[UNW_ARM_R4] = uctx->uc_mcontext.arm_r4; ctx.regs[UNW_ARM_R5] = uctx->uc_mcontext.arm_r5; ctx.regs[UNW_ARM_R6] = uctx->uc_mcontext.arm_r6; ctx.regs[UNW_ARM_R7] = uctx->uc_mcontext.arm_r7; ctx.regs[UNW_ARM_R8] = uctx->uc_mcontext.arm_r8; ctx.regs[UNW_ARM_R9] = uctx->uc_mcontext.arm_r9; ctx.regs[UNW_ARM_R10] = uctx->uc_mcontext.arm_r10; ctx.regs[UNW_ARM_R11] = uctx->uc_mcontext.arm_fp; ctx.regs[UNW_ARM_R12] = uctx->uc_mcontext.arm_ip; ctx.regs[UNW_ARM_R13] = uctx->uc_mcontext.arm_sp; ctx.regs[UNW_ARM_R14] = uctx->uc_mcontext.arm_lr; ctx.regs[UNW_ARM_R15] = uctx->uc_mcontext.arm_pc; // If we have crashed in the PC use the LR instead, as this was // a bad function dereference if (reinterpret_cast(error_addr()) == uctx->uc_mcontext.arm_pc) { ctx.regs[UNW_ARM_R15] = uctx->uc_mcontext.arm_lr - sizeof(unsigned long); } _stacktrace[index] = reinterpret_cast(ctx.regs[UNW_ARM_R15]); ++index; #elif defined(__APPLE__) && defined(__x86_64__) unw_getcontext(&ctx); // OS X's implementation of libunwind uses its own context object // so we need to convert the passed context to libunwind's format // (information about the data layout taken from unw_getcontext.s // in Apple's libunwind source ctx.data[0] = uctx->uc_mcontext->__ss.__rax; ctx.data[1] = uctx->uc_mcontext->__ss.__rbx; ctx.data[2] = uctx->uc_mcontext->__ss.__rcx; ctx.data[3] = uctx->uc_mcontext->__ss.__rdx; ctx.data[4] = uctx->uc_mcontext->__ss.__rdi; ctx.data[5] = uctx->uc_mcontext->__ss.__rsi; ctx.data[6] = uctx->uc_mcontext->__ss.__rbp; ctx.data[7] = uctx->uc_mcontext->__ss.__rsp; ctx.data[8] = uctx->uc_mcontext->__ss.__r8; ctx.data[9] = uctx->uc_mcontext->__ss.__r9; ctx.data[10] = uctx->uc_mcontext->__ss.__r10; ctx.data[11] = uctx->uc_mcontext->__ss.__r11; ctx.data[12] = uctx->uc_mcontext->__ss.__r12; ctx.data[13] = uctx->uc_mcontext->__ss.__r13; ctx.data[14] = uctx->uc_mcontext->__ss.__r14; ctx.data[15] = uctx->uc_mcontext->__ss.__r15; ctx.data[16] = uctx->uc_mcontext->__ss.__rip; // If the IP is the same as the crash address we have a bad function // dereference The caller's address is pointed to by %rsp, so we // dereference that value and set it to be the next frame's IP. 
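// (Added illustration, not part of the original source: the failure mode
// handled here is a call through a bogus function pointer, e.g.
//     void (*fp)() = reinterpret_cast<void (*)()>(0x1234);
//     fp();   // faults with IP == 0x1234, which maps to no code
// The caller's return address is still on the stack at *SP, so substituting
// it for the IP lets the trace begin at the real call site.)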
if (uctx->uc_mcontext->__ss.__rip == reinterpret_cast<__uint64_t>(error_addr())) { ctx.data[16] = *reinterpret_cast<__uint64_t*>(uctx->uc_mcontext->__ss.__rsp); } _stacktrace[index] = reinterpret_cast(ctx.data[16]); ++index; #elif defined(__APPLE__) unw_getcontext(&ctx) // TODO: Convert the ucontext_t to libunwind's unw_context_t like // we do in 64 bits if (ctx.uc_mcontext->__ss.__eip == reinterpret_cast(error_addr())) { ctx.uc_mcontext->__ss.__eip = ctx.uc_mcontext->__ss.__esp; } _stacktrace[index] = reinterpret_cast(ctx.uc_mcontext->__ss.__eip); ++index; #endif } unw_cursor_t cursor; if (context()) { #if defined(UNW_INIT_SIGNAL_FRAME) result = unw_init_local2(&cursor, &ctx, UNW_INIT_SIGNAL_FRAME); #else result = unw_init_local(&cursor, &ctx); #endif } else { unw_getcontext(&ctx); ; result = unw_init_local(&cursor, &ctx); } if (result != 0) return 1; unw_word_t ip = 0; while (index <= depth && unw_step(&cursor) > 0) { result = unw_get_reg(&cursor, UNW_REG_IP, &ip); if (result == 0) { _stacktrace[index] = reinterpret_cast(--ip); ++index; } } --index; _stacktrace.resize(index + 1); skip_n_firsts(0); return size(); } size_t load_from(void* addr, size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) { load_here(depth + 8, context, error_addr); for (size_t i = 0; i < _stacktrace.size(); ++i) { if (_stacktrace[i] == addr) { skip_n_firsts(i); _stacktrace[i] = (void*)((uintptr_t)_stacktrace[i]); break; } } _stacktrace.resize(std::min(_stacktrace.size(), skip_n_firsts() + depth)); return size(); } }; #elif defined(BACKWARD_HAS_BACKTRACE) template<> class StackTraceImpl : public StackTraceImplHolder { public: NOINLINE size_t load_here(size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) { set_context(context); set_error_addr(error_addr); load_thread_info(); if (depth == 0) { return 0; } _stacktrace.resize(depth + 1); size_t trace_cnt = backtrace(&_stacktrace[0], _stacktrace.size()); _stacktrace.resize(trace_cnt); skip_n_firsts(1); return size(); } size_t load_from(void* addr, size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) { load_here(depth + 8, context, error_addr); for (size_t i = 0; i < _stacktrace.size(); ++i) { if (_stacktrace[i] == addr) { skip_n_firsts(i); _stacktrace[i] = (void*)((uintptr_t)_stacktrace[i] + 1); break; } } _stacktrace.resize(std::min(_stacktrace.size(), skip_n_firsts() + depth)); return size(); } }; #elif defined(BACKWARD_SYSTEM_WINDOWS) template<> class StackTraceImpl : public StackTraceImplHolder { public: // We have to load the machine type from the image info // So we first initialize the resolver, and it tells us this info void set_machine_type(DWORD machine_type) { machine_type_ = machine_type; } void set_context(CONTEXT* ctx) { ctx_ = ctx; } void set_thread_handle(HANDLE handle) { thd_ = handle; } NOINLINE size_t load_here(size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) { set_context(static_cast(context)); set_error_addr(error_addr); CONTEXT localCtx; // used when no context is provided if (depth == 0) { return 0; } if (!ctx_) { ctx_ = &localCtx; RtlCaptureContext(ctx_); } if (!thd_) { thd_ = GetCurrentThread(); } HANDLE process = GetCurrentProcess(); STACKFRAME64 s; memset(&s, 0, sizeof(STACKFRAME64)); // TODO: 32 bit context capture s.AddrStack.Mode = AddrModeFlat; s.AddrFrame.Mode = AddrModeFlat; s.AddrPC.Mode = AddrModeFlat; #ifdef _M_X64 s.AddrPC.Offset = ctx_->Rip; s.AddrStack.Offset = ctx_->Rsp; s.AddrFrame.Offset = ctx_->Rbp; #else s.AddrPC.Offset = ctx_->Eip; 
s.AddrStack.Offset = ctx_->Esp; s.AddrFrame.Offset = ctx_->Ebp; #endif if (!machine_type_) { #ifdef _M_X64 machine_type_ = IMAGE_FILE_MACHINE_AMD64; #else machine_type_ = IMAGE_FILE_MACHINE_I386; #endif } for (;;) { // NOTE: this only works if PDBs are already loaded! SetLastError(0); if (!StackWalk64(machine_type_, process, thd_, &s, ctx_, NULL, SymFunctionTableAccess64, SymGetModuleBase64, NULL)) break; if (s.AddrReturn.Offset == 0) break; _stacktrace.push_back(reinterpret_cast(s.AddrPC.Offset)); if (size() >= depth) break; } return size(); } size_t load_from(void* addr, size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) { load_here(depth + 8, context, error_addr); for (size_t i = 0; i < _stacktrace.size(); ++i) { if (_stacktrace[i] == addr) { skip_n_firsts(i); break; } } _stacktrace.resize(std::min(_stacktrace.size(), skip_n_firsts() + depth)); return size(); } private: DWORD machine_type_ = 0; HANDLE thd_ = 0; CONTEXT* ctx_ = nullptr; }; #endif class StackTrace : public StackTraceImpl {}; /*************** TRACE RESOLVER ***************/ class TraceResolverImplBase { public: virtual ~TraceResolverImplBase() {} virtual void load_addresses(void* const* addresses, int address_count) { (void)addresses; (void)address_count; } template void load_stacktrace(ST& st) { load_addresses(st.begin(), static_cast(st.size())); } virtual ResolvedTrace resolve(ResolvedTrace t) { return t; } protected: std::string demangle(const char* funcname) { return _demangler.demangle(funcname); } private: details::demangler _demangler; }; template class TraceResolverImpl; #ifdef BACKWARD_SYSTEM_UNKNOWN template<> class TraceResolverImpl : public TraceResolverImplBase {}; #endif #ifdef BACKWARD_SYSTEM_LINUX class TraceResolverLinuxBase : public TraceResolverImplBase { public: TraceResolverLinuxBase() : argv0_(get_argv0()), exec_path_(read_symlink("/proc/self/exe")) {} std::string resolve_exec_path(Dl_info& symbol_info) const { // mutates symbol_info.dli_fname to be filename to open and returns filename // to display if (symbol_info.dli_fname == argv0_) { // dladdr returns argv[0] in dli_fname for symbols contained in // the main executable, which is not a valid path if the // executable was found by a search of the PATH environment // variable; In that case, we actually open /proc/self/exe, which // is always the actual executable (even if it was deleted/replaced!) // but display the path that /proc/self/exe links to. // However, this right away reduces probability of successful symbol // resolution, because libbfd may try to find *.debug files in the // same dir, in case symbols are stripped. As a result, it may try // to find a file /proc/self/.debug, which obviously does // not exist. /proc/self/exe is a last resort. First load attempt // should go for the original executable file path. 
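// (Added example with hypothetical names: a binary launched as a bare
// "worker" found through $PATH makes dladdr() report dli_fname == "worker",
// which is not an openable path from an arbitrary working directory; we
// therefore open "/proc/self/exe" but keep displaying its symlink target,
// e.g. "/usr/local/bin/worker".)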
symbol_info.dli_fname = "/proc/self/exe"; return exec_path_; } else { return symbol_info.dli_fname; } } private: std::string argv0_; std::string exec_path_; static std::string get_argv0() { std::string argv0; std::ifstream ifs("/proc/self/cmdline"); std::getline(ifs, argv0, '\0'); return argv0; } static std::string read_symlink(std::string const& symlink_path) { std::string path; path.resize(100); while (true) { ssize_t len = ::readlink(symlink_path.c_str(), &*path.begin(), path.size()); if (len < 0) { return ""; } if (static_cast(len) == path.size()) { path.resize(path.size() * 2); } else { path.resize(static_cast(len)); break; } } return path; } }; template class TraceResolverLinuxImpl; #if BACKWARD_HAS_BACKTRACE_SYMBOL == 1 template<> class TraceResolverLinuxImpl : public TraceResolverLinuxBase { public: void load_addresses(void* const* addresses, int address_count) override { if (address_count == 0) { return; } _symbols.reset(backtrace_symbols(addresses, address_count)); } ResolvedTrace resolve(ResolvedTrace trace) override { char* filename = _symbols[trace.idx]; char* funcname = filename; while (*funcname && *funcname != '(') { funcname += 1; } trace.object_filename.assign(filename, funcname); // ok even if funcname is the ending // \0 (then we assign entire string) if (*funcname) { // if it's not end of string (e.g. from last frame ip==0) funcname += 1; char* funcname_end = funcname; while (*funcname_end && *funcname_end != ')' && *funcname_end != '+') { funcname_end += 1; } *funcname_end = '\0'; trace.object_function = this->demangle(funcname); trace.source.function = trace.object_function; // we cannot do better. } return trace; } private: details::handle _symbols; }; #endif // BACKWARD_HAS_BACKTRACE_SYMBOL == 1 #if BACKWARD_HAS_BFD == 1 template<> class TraceResolverLinuxImpl : public TraceResolverLinuxBase { public: TraceResolverLinuxImpl() : _bfd_loaded(false) {} ResolvedTrace resolve(ResolvedTrace trace) override { Dl_info symbol_info; // trace.addr is a virtual address in memory pointing to some code. // Let's try to find from which loaded object it comes from. // The loaded object can be yourself btw. if (!dladdr(trace.addr, &symbol_info)) { return trace; // dat broken trace... } // Now we get in symbol_info: // .dli_fname: // pathname of the shared object that contains the address. // .dli_fbase: // where the object is loaded in memory. // .dli_sname: // the name of the nearest symbol to trace.addr, we expect a // function name. // .dli_saddr: // the exact address corresponding to .dli_sname. if (symbol_info.dli_sname) { trace.object_function = demangle(symbol_info.dli_sname); } if (!symbol_info.dli_fname) { return trace; } trace.object_filename = resolve_exec_path(symbol_info); bfd_fileobject* fobj; // Before rushing to resolution need to ensure the executable // file still can be used. For that compare inode numbers of // what is stored by the executable's file path, and in the // dli_fname, which not necessarily equals to the executable. // It can be a shared library, or /proc/self/exe, and in the // latter case has drawbacks. See the exec path resolution for // details. In short - the dli object should be used only as // the last resort. // If inode numbers are equal, it is known dli_fname and the // executable file are the same. This is guaranteed by Linux, // because if the executable file is changed/deleted, it will // be done in a new inode. The old file will be preserved in // /proc/self/exe, and may even have inode 0. 
The latter can // happen if the inode was actually reused, and the file was // kept only in the main memory. // struct stat obj_stat; struct stat dli_stat; if (stat(trace.object_filename.c_str(), &obj_stat) == 0 && stat(symbol_info.dli_fname, &dli_stat) == 0 && obj_stat.st_ino == dli_stat.st_ino) { // The executable file, and the shared object containing the // address are the same file. Safe to use the original path. // this is preferable. Libbfd will search for stripped debug // symbols in the same directory. fobj = load_object_with_bfd(trace.object_filename); } else { // The original object file was *deleted*! The only hope is // that the debug symbols are either inside the shared // object file, or are in the same directory, and this is // not /proc/self/exe. fobj = nullptr; } if (fobj == nullptr || !fobj->handle) { fobj = load_object_with_bfd(symbol_info.dli_fname); if (!fobj->handle) { return trace; } } find_sym_result* details_selected; // to be filled. // trace.addr is the next instruction to be executed after returning // from the nested stack frame. In C++ this usually relate to the next // statement right after the function call that leaded to a new stack // frame. This is not usually what you want to see when printing out a // stacktrace... find_sym_result details_call_site = find_symbol_details(fobj, trace.addr, symbol_info.dli_fbase); details_selected = &details_call_site; #if BACKWARD_HAS_UNWIND == 0 // ...this is why we also try to resolve the symbol that is right // before the return address. If we are lucky enough, we will get the // line of the function that was called. But if the code is optimized, // we might get something absolutely not related since the compiler // can reschedule the return address with inline functions and // tail-call optimization (among other things that I don't even know // or cannot even dream about with my tiny limited brain). find_sym_result details_adjusted_call_site = find_symbol_details(fobj, (void*)(uintptr_t(trace.addr) - 1), symbol_info.dli_fbase); // In debug mode, we should always get the right thing(TM). if (details_call_site.found && details_adjusted_call_site.found) { // Ok, we assume that details_adjusted_call_site is a better estimation. details_selected = &details_adjusted_call_site; trace.addr = (void*)(uintptr_t(trace.addr) - 1); } if (details_selected == &details_call_site && details_call_site.found) { // we have to re-resolve the symbol in order to reset some // internal state in BFD... so we can call backtrace_inliners // thereafter... details_call_site = find_symbol_details(fobj, trace.addr, symbol_info.dli_fbase); } #endif // BACKWARD_HAS_UNWIND if (details_selected->found) { if (details_selected->filename) { trace.source.filename = details_selected->filename; } trace.source.line = details_selected->line; if (details_selected->funcname) { // this time we get the name of the function where the code is // located, instead of the function were the address is // located. In short, if the code was inlined, we get the // function corresponding to the code. Else we already got in // trace.function. trace.source.function = demangle(details_selected->funcname); if (!symbol_info.dli_sname) { // for the case dladdr failed to find the symbol name of // the function, we might as well try to put something // here. trace.object_function = trace.source.function; } } // Maybe the source of the trace got inlined inside the function // (trace.source.function). 
Let's see if we can get all the inlined // calls along the way up to the initial call site. trace.inliners = backtrace_inliners(fobj, *details_selected); #if 0 if (trace.inliners.size() == 0) { // Maybe the trace was not inlined... or maybe it was and we // are lacking the debug information. Let's try to make the // world better and see if we can get the line number of the // function (trace.source.function) now. // // We will get the location of where the function start (to be // exact: the first instruction that really start the // function), not where the name of the function is defined. // This can be quite far away from the name of the function // btw. // // If the source of the function is the same as the source of // the trace, we cannot say if the trace was really inlined or // not. However, if the filename of the source is different // between the function and the trace... we can declare it as // an inliner. This is not 100% accurate, but better than // nothing. if (symbol_info.dli_saddr) { find_sym_result details = find_symbol_details(fobj, symbol_info.dli_saddr, symbol_info.dli_fbase); if (details.found) { ResolvedTrace::SourceLoc diy_inliner; diy_inliner.line = details.line; if (details.filename) { diy_inliner.filename = details.filename; } if (details.funcname) { diy_inliner.function = demangle(details.funcname); } else { diy_inliner.function = trace.source.function; } if (diy_inliner != trace.source) { trace.inliners.push_back(diy_inliner); } } } } #endif } return trace; } private: bool _bfd_loaded; typedef details::handle> bfd_handle_t; typedef details::handle bfd_symtab_t; struct bfd_fileobject { bfd_handle_t handle; bfd_vma base_addr; bfd_symtab_t symtab; bfd_symtab_t dynamic_symtab; }; typedef details::hashtable::type fobj_bfd_map_t; fobj_bfd_map_t _fobj_bfd_map; bfd_fileobject* load_object_with_bfd(const std::string& filename_object) { using namespace details; if (!_bfd_loaded) { using namespace details; bfd_init(); _bfd_loaded = true; } fobj_bfd_map_t::iterator it = _fobj_bfd_map.find(filename_object); if (it != _fobj_bfd_map.end()) { return &it->second; } // this new object is empty for now. bfd_fileobject* r = &_fobj_bfd_map[filename_object]; // we do the work temporary in this one; bfd_handle_t bfd_handle; int fd = open(filename_object.c_str(), O_RDONLY); bfd_handle.reset(bfd_fdopenr(filename_object.c_str(), "default", fd)); if (!bfd_handle) { close(fd); return r; } if (!bfd_check_format(bfd_handle.get(), bfd_object)) { return r; // not an object? You lose. } if ((bfd_get_file_flags(bfd_handle.get()) & HAS_SYMS) == 0) { return r; // that's what happen when you forget to compile in debug. } ssize_t symtab_storage_size = bfd_get_symtab_upper_bound(bfd_handle.get()); ssize_t dyn_symtab_storage_size = bfd_get_dynamic_symtab_upper_bound(bfd_handle.get()); if (symtab_storage_size <= 0 && dyn_symtab_storage_size <= 0) { return r; // weird, is the file is corrupted? } bfd_symtab_t symtab, dynamic_symtab; ssize_t symcount = 0, dyn_symcount = 0; if (symtab_storage_size > 0) { symtab.reset(static_cast(malloc(static_cast(symtab_storage_size)))); symcount = bfd_canonicalize_symtab(bfd_handle.get(), symtab.get()); } if (dyn_symtab_storage_size > 0) { dynamic_symtab.reset( static_cast(malloc(static_cast(dyn_symtab_storage_size)))); dyn_symcount = bfd_canonicalize_dynamic_symtab(bfd_handle.get(), dynamic_symtab.get()); } if (symcount <= 0 && dyn_symcount <= 0) { return r; // damned, that's a stripped file that you got there! 
} r->handle = move(bfd_handle); r->symtab = move(symtab); r->dynamic_symtab = move(dynamic_symtab); return r; } struct find_sym_result { bool found; const char* filename; const char* funcname; unsigned int line; }; struct find_sym_context { TraceResolverLinuxImpl* self; bfd_fileobject* fobj; void* addr; void* base_addr; find_sym_result result; }; find_sym_result find_symbol_details(bfd_fileobject* fobj, void* addr, void* base_addr) { find_sym_context context; context.self = this; context.fobj = fobj; context.addr = addr; context.base_addr = base_addr; context.result.found = false; bfd_map_over_sections(fobj->handle.get(), &find_in_section_trampoline, static_cast(&context)); return context.result; } static void find_in_section_trampoline(bfd*, asection* section, void* data) { find_sym_context* context = static_cast(data); context->self->find_in_section(reinterpret_cast(context->addr), reinterpret_cast(context->base_addr), context->fobj, section, context->result); } void find_in_section(bfd_vma addr, bfd_vma base_addr, bfd_fileobject* fobj, asection* section, find_sym_result& result) { if (result.found) return; #ifdef bfd_get_section_flags if ((bfd_get_section_flags(fobj->handle.get(), section) & SEC_ALLOC) == 0) #else if ((bfd_section_flags(section) & SEC_ALLOC) == 0) #endif return; // a debug section is never loaded automatically. #ifdef bfd_get_section_vma bfd_vma sec_addr = bfd_get_section_vma(fobj->handle.get(), section); #else bfd_vma sec_addr = bfd_section_vma(section); #endif #ifdef bfd_get_section_size bfd_size_type size = bfd_get_section_size(section); #else bfd_size_type size = bfd_section_size(section); #endif // are we in the boundaries of the section? if (addr < sec_addr || addr >= sec_addr + size) { addr -= base_addr; // oops, a relocated object, lets try again... if (addr < sec_addr || addr >= sec_addr + size) { return; } } #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wzero-as-null-pointer-constant" #endif if (!result.found && fobj->symtab) { result.found = bfd_find_nearest_line(fobj->handle.get(), section, fobj->symtab.get(), addr - sec_addr, &result.filename, &result.funcname, &result.line); } if (!result.found && fobj->dynamic_symtab) { result.found = bfd_find_nearest_line(fobj->handle.get(), section, fobj->dynamic_symtab.get(), addr - sec_addr, &result.filename, &result.funcname, &result.line); } #if defined(__clang__) #pragma clang diagnostic pop #endif } ResolvedTrace::source_locs_t backtrace_inliners(bfd_fileobject* fobj, find_sym_result previous_result) { // This function can be called ONLY after a SUCCESSFUL call to // find_symbol_details. The state is global to the bfd_handle. 
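// (Added note, inferred from the call sites above: bfd_find_nearest_line()
// leaves per-bfd iteration state behind, and bfd_find_inliner_info() below
// walks that state one inlined frame at a time until no entry is left. This
// is why resolve() re-runs find_symbol_details() just before calling this
// helper when the adjusted call-site lookup was not the one selected.)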
ResolvedTrace::source_locs_t results; while (previous_result.found) { find_sym_result result; result.found = bfd_find_inliner_info(fobj->handle.get(), &result.filename, &result.funcname, &result.line); if (result.found) /* and not ( cstrings_eq(previous_result.filename, result.filename) and cstrings_eq(previous_result.funcname, result.funcname) and result.line == previous_result.line )) */ { ResolvedTrace::SourceLoc src_loc; src_loc.line = result.line; if (result.filename) { src_loc.filename = result.filename; } if (result.funcname) { src_loc.function = demangle(result.funcname); } results.push_back(src_loc); } previous_result = result; } return results; } bool cstrings_eq(const char* a, const char* b) { if (!a || !b) { return false; } return strcmp(a, b) == 0; } }; #endif // BACKWARD_HAS_BFD == 1 #if BACKWARD_HAS_DW == 1 template<> class TraceResolverLinuxImpl : public TraceResolverLinuxBase { public: TraceResolverLinuxImpl() : _dwfl_handle_initialized(false) {} ResolvedTrace resolve(ResolvedTrace trace) override { using namespace details; Dwarf_Addr trace_addr = reinterpret_cast(trace.addr); if (!_dwfl_handle_initialized) { // initialize dwfl... _dwfl_cb.reset(new Dwfl_Callbacks); _dwfl_cb->find_elf = &dwfl_linux_proc_find_elf; _dwfl_cb->find_debuginfo = &dwfl_standard_find_debuginfo; _dwfl_cb->debuginfo_path = 0; _dwfl_handle.reset(dwfl_begin(_dwfl_cb.get())); _dwfl_handle_initialized = true; if (!_dwfl_handle) { return trace; } // ...from the current process. dwfl_report_begin(_dwfl_handle.get()); int r = dwfl_linux_proc_report(_dwfl_handle.get(), getpid()); dwfl_report_end(_dwfl_handle.get(), NULL, NULL); if (r < 0) { return trace; } } if (!_dwfl_handle) { return trace; } // find the module (binary object) that contains the trace's address. // This is not using any debug information, but the addresses ranges of // all the currently loaded binary object. Dwfl_Module* mod = dwfl_addrmodule(_dwfl_handle.get(), trace_addr); if (mod) { // now that we found it, lets get the name of it, this will be the // full path to the running binary or one of the loaded library. const char* module_name = dwfl_module_info(mod, 0, 0, 0, 0, 0, 0, 0); if (module_name) { trace.object_filename = module_name; } // We also look after the name of the symbol, equal or before this // address. This is found by walking the symtab. We should get the // symbol corresponding to the function (mangled) containing the // address. If the code corresponding to the address was inlined, // this is the name of the out-most inliner function. const char* sym_name = dwfl_module_addrname(mod, trace_addr); if (sym_name) { trace.object_function = demangle(sym_name); } } // now let's get serious, and find out the source location (file and // line number) of the address. // This function will look in .debug_aranges for the address and map it // to the location of the compilation unit DIE in .debug_info and // return it. Dwarf_Addr mod_bias = 0; Dwarf_Die* cudie = dwfl_module_addrdie(mod, trace_addr, &mod_bias); #if 1 if (!cudie) { // Sadly clang does not generate the section .debug_aranges, thus // dwfl_module_addrdie will fail early. Clang doesn't either set // the lowpc/highpc/range info for every compilation unit. // // So in order to save the world: // for every compilation unit, we will iterate over every single // DIEs. Normally functions should have a lowpc/highpc/range, which // we will use to infer the compilation unit. // note that this is probably badly inefficient. 
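// (In other words: for objects lacking .debug_aranges -- typically
// clang-built ones -- the loop below linearly scans every compilation unit
// and its DIE tree for a function whose pc range covers the trace address,
// instead of the direct dwfl_module_addrdie lookup that succeeds for
// gcc-built objects.)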
while ((cudie = dwfl_module_nextcu(mod, cudie, &mod_bias))) { Dwarf_Die die_mem; Dwarf_Die* fundie = find_fundie_by_pc(cudie, trace_addr - mod_bias, &die_mem); if (fundie) { break; } } } #endif //#define BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE #ifdef BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE if (!cudie) { // If it's still not enough, lets dive deeper in the shit, and try // to save the world again: for every compilation unit, we will // load the corresponding .debug_line section, and see if we can // find our address in it. Dwarf_Addr cfi_bias; Dwarf_CFI* cfi_cache = dwfl_module_eh_cfi(mod, &cfi_bias); Dwarf_Addr bias; while ((cudie = dwfl_module_nextcu(mod, cudie, &bias))) { if (dwarf_getsrc_die(cudie, trace_addr - bias)) { // ...but if we get a match, it might be a false positive // because our (address - bias) might as well be valid in a // different compilation unit. So we throw our last card on // the table and lookup for the address into the .eh_frame // section. handle frame; dwarf_cfi_addrframe(cfi_cache, trace_addr - cfi_bias, &frame); if (frame) { break; } } } } #endif if (!cudie) { return trace; // this time we lost the game :/ } // Now that we have a compilation unit DIE, this function will be able // to load the corresponding section in .debug_line (if not already // loaded) and hopefully find the source location mapped to our // address. Dwarf_Line* srcloc = dwarf_getsrc_die(cudie, trace_addr - mod_bias); if (srcloc) { const char* srcfile = dwarf_linesrc(srcloc, 0, 0); if (srcfile) { trace.source.filename = srcfile; } int line = 0, col = 0; dwarf_lineno(srcloc, &line); dwarf_linecol(srcloc, &col); trace.source.line = static_cast(line); trace.source.col = static_cast(col); } deep_first_search_by_pc(cudie, trace_addr - mod_bias, inliners_search_cb(trace)); if (trace.source.function.size() == 0) { // fallback. trace.source.function = trace.object_function; } return trace; } private: typedef details::handle> dwfl_handle_t; details::handle> _dwfl_cb; dwfl_handle_t _dwfl_handle; bool _dwfl_handle_initialized; // defined here because in C++98, template function cannot take locally // defined types... grrr. struct inliners_search_cb { void operator()(Dwarf_Die* die) { switch (dwarf_tag(die)) { const char* name; case DW_TAG_subprogram: if ((name = dwarf_diename(die))) { trace.source.function = name; } break; case DW_TAG_inlined_subroutine: ResolvedTrace::SourceLoc sloc; Dwarf_Attribute attr_mem; if ((name = dwarf_diename(die))) { sloc.function = name; } if ((name = die_call_file(die))) { sloc.filename = name; } Dwarf_Word line = 0, col = 0; dwarf_formudata(dwarf_attr(die, DW_AT_call_line, &attr_mem), &line); dwarf_formudata(dwarf_attr(die, DW_AT_call_column, &attr_mem), &col); sloc.line = static_cast(line); sloc.col = static_cast(col); trace.inliners.push_back(sloc); break; }; } ResolvedTrace& trace; inliners_search_cb(ResolvedTrace& t) : trace(t) {} }; static bool die_has_pc(Dwarf_Die* die, Dwarf_Addr pc) { Dwarf_Addr low, high; // continuous range if (dwarf_hasattr(die, DW_AT_low_pc) && dwarf_hasattr(die, DW_AT_high_pc)) { if (dwarf_lowpc(die, &low) != 0) { return false; } if (dwarf_highpc(die, &high) != 0) { Dwarf_Attribute attr_mem; Dwarf_Attribute* attr = dwarf_attr(die, DW_AT_high_pc, &attr_mem); Dwarf_Word value; if (dwarf_formudata(attr, &value) != 0) { return false; } high = low + value; } return pc >= low && pc < high; } // non-continuous range. 
Dwarf_Addr base; ptrdiff_t offset = 0; while ((offset = dwarf_ranges(die, offset, &base, &low, &high)) > 0) { if (pc >= low && pc < high) { return true; } } return false; } static Dwarf_Die* find_fundie_by_pc(Dwarf_Die* parent_die, Dwarf_Addr pc, Dwarf_Die* result) { if (dwarf_child(parent_die, result) != 0) { return 0; } Dwarf_Die* die = result; do { switch (dwarf_tag(die)) { case DW_TAG_subprogram: case DW_TAG_inlined_subroutine: if (die_has_pc(die, pc)) { return result; } }; bool declaration = false; Dwarf_Attribute attr_mem; dwarf_formflag(dwarf_attr(die, DW_AT_declaration, &attr_mem), &declaration); if (!declaration) { // let's be curious and look deeper in the tree, // function are not necessarily at the first level, but // might be nested inside a namespace, structure etc. Dwarf_Die die_mem; Dwarf_Die* indie = find_fundie_by_pc(die, pc, &die_mem); if (indie) { *result = die_mem; return result; } } } while (dwarf_siblingof(die, result) == 0); return 0; } template static bool deep_first_search_by_pc(Dwarf_Die* parent_die, Dwarf_Addr pc, CB cb) { Dwarf_Die die_mem; if (dwarf_child(parent_die, &die_mem) != 0) { return false; } bool branch_has_pc = false; Dwarf_Die* die = &die_mem; do { bool declaration = false; Dwarf_Attribute attr_mem; dwarf_formflag(dwarf_attr(die, DW_AT_declaration, &attr_mem), &declaration); if (!declaration) { // let's be curious and look deeper in the tree, function are // not necessarily at the first level, but might be nested // inside a namespace, structure, a function, an inlined // function etc. branch_has_pc = deep_first_search_by_pc(die, pc, cb); } if (!branch_has_pc) { branch_has_pc = die_has_pc(die, pc); } if (branch_has_pc) { cb(die); } } while (dwarf_siblingof(die, &die_mem) == 0); return branch_has_pc; } static const char* die_call_file(Dwarf_Die* die) { Dwarf_Attribute attr_mem; Dwarf_Word file_idx = 0; dwarf_formudata(dwarf_attr(die, DW_AT_call_file, &attr_mem), &file_idx); if (file_idx == 0) { return 0; } Dwarf_Die die_mem; Dwarf_Die* cudie = dwarf_diecu(die, &die_mem, 0, 0); if (!cudie) { return 0; } Dwarf_Files* files = 0; size_t nfiles; dwarf_getsrcfiles(cudie, &files, &nfiles); if (!files) { return 0; } return dwarf_filesrc(files, file_idx, 0, 0); } }; #endif // BACKWARD_HAS_DW == 1 #if BACKWARD_HAS_DWARF == 1 template<> class TraceResolverLinuxImpl : public TraceResolverLinuxBase { public: TraceResolverLinuxImpl() : _dwarf_loaded(false) {} ResolvedTrace resolve(ResolvedTrace trace) override { // trace.addr is a virtual address in memory pointing to some code. // Let's try to find from which loaded object it comes from. // The loaded object can be yourself btw. Dl_info symbol_info; int dladdr_result = 0; #if defined(__GLIBC__) link_map* link_map; // We request the link map so we can get information about offsets dladdr_result = dladdr1(trace.addr, &symbol_info, reinterpret_cast(&link_map), RTLD_DL_LINKMAP); #else // Android doesn't have dladdr1. Don't use the linker map. dladdr_result = dladdr(trace.addr, &symbol_info); #endif if (!dladdr_result) { return trace; // dat broken trace... } // Now we get in symbol_info: // .dli_fname: // pathname of the shared object that contains the address. // .dli_fbase: // where the object is loaded in memory. // .dli_sname: // the name of the nearest symbol to trace.addr, we expect a // function name. // .dli_saddr: // the exact address corresponding to .dli_sname. 
// // And in link_map: // .l_addr: // difference between the address in the ELF file and the address // in memory // l_name: // absolute pathname where the object was found if (symbol_info.dli_sname) { trace.object_function = demangle(symbol_info.dli_sname); } if (!symbol_info.dli_fname) { return trace; } trace.object_filename = resolve_exec_path(symbol_info); dwarf_fileobject& fobj = load_object_with_dwarf(symbol_info.dli_fname); if (!fobj.dwarf_handle) { return trace; // sad, we couldn't load the object :( } #if defined(__GLIBC__) // Convert the address to a module relative one by looking at // the module's loading address in the link map Dwarf_Addr address = reinterpret_cast(trace.addr) - reinterpret_cast(link_map->l_addr); #else Dwarf_Addr address = reinterpret_cast(trace.addr); #endif if (trace.object_function.empty()) { symbol_cache_t::iterator it = fobj.symbol_cache.lower_bound(address); if (it != fobj.symbol_cache.end()) { if (it->first != address) { if (it != fobj.symbol_cache.begin()) { --it; } } trace.object_function = demangle(it->second.c_str()); } } // Get the Compilation Unit DIE for the address Dwarf_Die die = find_die(fobj, address); if (!die) { return trace; // this time we lost the game :/ } // libdwarf doesn't give us direct access to its objects, it always // allocates a copy for the caller. We keep that copy alive in a cache // and we deallocate it later when it's no longer required. die_cache_entry& die_object = get_die_cache(fobj, die); if (die_object.isEmpty()) return trace; // We have no line section for this DIE die_linemap_t::iterator it = die_object.line_section.lower_bound(address); if (it != die_object.line_section.end()) { if (it->first != address) { if (it == die_object.line_section.begin()) { // If we are on the first item of the line section // but the address does not match it means that // the address is below the range of the DIE. Give up. return trace; } else { --it; } } } else { return trace; // We didn't find the address. } // Get the Dwarf_Line that the address points to and call libdwarf // to get source file, line and column info. 
Dwarf_Line line = die_object.line_buffer[it->second]; Dwarf_Error error = DW_DLE_NE; char* filename; if (dwarf_linesrc(line, &filename, &error) == DW_DLV_OK) { trace.source.filename = std::string(filename); dwarf_dealloc(fobj.dwarf_handle.get(), filename, DW_DLA_STRING); } Dwarf_Unsigned number = 0; if (dwarf_lineno(line, &number, &error) == DW_DLV_OK) { trace.source.line = number; } else { trace.source.line = 0; } if (dwarf_lineoff_b(line, &number, &error) == DW_DLV_OK) { trace.source.col = number; } else { trace.source.col = 0; } std::vector namespace_stack; deep_first_search_by_pc(fobj, die, address, namespace_stack, inliners_search_cb(trace, fobj, die)); dwarf_dealloc(fobj.dwarf_handle.get(), die, DW_DLA_DIE); return trace; } public: static int close_dwarf(Dwarf_Debug dwarf) { return dwarf_finish(dwarf, NULL); } private: bool _dwarf_loaded; typedef details::handle> dwarf_file_t; typedef details::handle> dwarf_elf_t; typedef details::handle> dwarf_handle_t; typedef std::map die_linemap_t; typedef std::map die_specmap_t; struct die_cache_entry { die_specmap_t spec_section; die_linemap_t line_section; Dwarf_Line* line_buffer; Dwarf_Signed line_count; Dwarf_Line_Context line_context; inline bool isEmpty() { return line_buffer == NULL || line_count == 0 || line_context == NULL || line_section.empty(); } die_cache_entry() : line_buffer(0), line_count(0), line_context(0) {} ~die_cache_entry() { if (line_context) { dwarf_srclines_dealloc_b(line_context); } } }; typedef std::map die_cache_t; typedef std::map symbol_cache_t; struct dwarf_fileobject { dwarf_file_t file_handle; dwarf_elf_t elf_handle; dwarf_handle_t dwarf_handle; symbol_cache_t symbol_cache; // Die cache die_cache_t die_cache; die_cache_entry* current_cu; }; typedef details::hashtable::type fobj_dwarf_map_t; fobj_dwarf_map_t _fobj_dwarf_map; static bool cstrings_eq(const char* a, const char* b) { if (!a || !b) { return false; } return strcmp(a, b) == 0; } dwarf_fileobject& load_object_with_dwarf(const std::string& filename_object) { if (!_dwarf_loaded) { // Set the ELF library operating version // If that fails there's nothing we can do _dwarf_loaded = elf_version(EV_CURRENT) != EV_NONE; } fobj_dwarf_map_t::iterator it = _fobj_dwarf_map.find(filename_object); if (it != _fobj_dwarf_map.end()) { return it->second; } // this new object is empty for now dwarf_fileobject& r = _fobj_dwarf_map[filename_object]; dwarf_file_t file_handle; file_handle.reset(open(filename_object.c_str(), O_RDONLY)); if (file_handle.get() < 0) { return r; } // Try to get an ELF handle. We need to read the ELF sections // because we want to see if there is a .gnu_debuglink section // that points to a split debug file dwarf_elf_t elf_handle; elf_handle.reset(elf_begin(file_handle.get(), ELF_C_READ, NULL)); if (!elf_handle) { return r; } const char* e_ident = elf_getident(elf_handle.get(), 0); if (!e_ident) { return r; } // Get the number of sections // We use the new APIs as elf_getshnum is deprecated size_t shdrnum = 0; if (elf_getshdrnum(elf_handle.get(), &shdrnum) == -1) { return r; } // Get the index to the string section size_t shdrstrndx = 0; if (elf_getshdrstrndx(elf_handle.get(), &shdrstrndx) == -1) { return r; } std::string debuglink; // Iterate through the ELF sections to try to get a gnu_debuglink // note and also to cache the symbol table. 
// We go the preprocessor way to avoid having to create templated // classes or using gelf (which might throw a compiler error if 64 bit // is not supported #define ELF_GET_DATA(ARCH) \ Elf_Scn* elf_section = 0; \ Elf_Data* elf_data = 0; \ Elf##ARCH##_Shdr* section_header = 0; \ Elf_Scn* symbol_section = 0; \ size_t symbol_count = 0; \ size_t symbol_strings = 0; \ Elf##ARCH##_Sym* symbol = 0; \ const char* section_name = 0; \ \ while ((elf_section = elf_nextscn(elf_handle.get(), elf_section)) != NULL) { \ section_header = elf##ARCH##_getshdr(elf_section); \ if (section_header == NULL) { return r; } \ \ if ((section_name = elf_strptr(elf_handle.get(), shdrstrndx, section_header->sh_name)) \ == NULL) { \ return r; \ } \ \ if (cstrings_eq(section_name, ".gnu_debuglink")) { \ elf_data = elf_getdata(elf_section, NULL); \ if (elf_data && elf_data->d_size > 0) { \ debuglink = std::string(reinterpret_cast(elf_data->d_buf)); \ } \ } \ \ switch (section_header->sh_type) { \ case SHT_SYMTAB: \ symbol_section = elf_section; \ symbol_count = section_header->sh_size / section_header->sh_entsize; \ symbol_strings = section_header->sh_link; \ break; \ \ /* We use .dynsyms as a last resort, we prefer .symtab */ \ case SHT_DYNSYM: \ if (!symbol_section) { \ symbol_section = elf_section; \ symbol_count = section_header->sh_size / section_header->sh_entsize; \ symbol_strings = section_header->sh_link; \ } \ break; \ } \ } \ \ if (symbol_section && symbol_count && symbol_strings) { \ elf_data = elf_getdata(symbol_section, NULL); \ symbol = reinterpret_cast(elf_data->d_buf); \ for (size_t i = 0; i < symbol_count; ++i) { \ int type = ELF##ARCH##_ST_TYPE(symbol->st_info); \ if (type == STT_FUNC && symbol->st_value > 0) { \ r.symbol_cache[symbol->st_value] = \ std::string(elf_strptr(elf_handle.get(), symbol_strings, symbol->st_name)); \ } \ ++symbol; \ } \ } if (e_ident[EI_CLASS] == ELFCLASS32) { ELF_GET_DATA(32) } else if (e_ident[EI_CLASS] == ELFCLASS64) { // libelf might have been built without 64 bit support #if __LIBELF64 ELF_GET_DATA(64) #endif } if (!debuglink.empty()) { // We have a debuglink section! Open an elf instance on that // file instead. If we can't open the file, then return // the elf handle we had already opened. dwarf_file_t debuglink_file; debuglink_file.reset(open(debuglink.c_str(), O_RDONLY)); if (debuglink_file.get() > 0) { dwarf_elf_t debuglink_elf; debuglink_elf.reset(elf_begin(debuglink_file.get(), ELF_C_READ, NULL)); // If we have a valid elf handle, return the new elf handle // and file handle and discard the original ones if (debuglink_elf) { elf_handle = move(debuglink_elf); file_handle = move(debuglink_file); } } } // Ok, we have a valid ELF handle, let's try to get debug symbols Dwarf_Debug dwarf_debug; Dwarf_Error error = DW_DLE_NE; dwarf_handle_t dwarf_handle; int dwarf_result = dwarf_elf_init(elf_handle.get(), DW_DLC_READ, NULL, NULL, &dwarf_debug, &error); // We don't do any special handling for DW_DLV_NO_ENTRY specially. // If we get an error, or the file doesn't have debug information // we just return. 
if (dwarf_result != DW_DLV_OK) { return r; } dwarf_handle.reset(dwarf_debug); r.file_handle = move(file_handle); r.elf_handle = move(elf_handle); r.dwarf_handle = move(dwarf_handle); return r; } die_cache_entry& get_die_cache(dwarf_fileobject& fobj, Dwarf_Die die) { Dwarf_Error error = DW_DLE_NE; // Get the die offset, we use it as the cache key Dwarf_Off die_offset; if (dwarf_dieoffset(die, &die_offset, &error) != DW_DLV_OK) { die_offset = 0; } die_cache_t::iterator it = fobj.die_cache.find(die_offset); if (it != fobj.die_cache.end()) { fobj.current_cu = &it->second; return it->second; } die_cache_entry& de = fobj.die_cache[die_offset]; fobj.current_cu = &de; Dwarf_Addr line_addr; Dwarf_Small table_count; // The addresses in the line section are not fully sorted (they might // be sorted by block of code belonging to the same file), which makes // it necessary to do so before searching is possible. // // As libdwarf allocates a copy of everything, let's get the contents // of the line section and keep it around. We also create a map of // program counter to line table indices so we can search by address // and get the line buffer index. // // To make things more difficult, the same address can span more than // one line, so we need to keep the index pointing to the first line // by using insert instead of the map's [ operator. // Get the line context for the DIE if (dwarf_srclines_b(die, 0, &table_count, &de.line_context, &error) == DW_DLV_OK) { // Get the source lines for this line context, to be deallocated // later if (dwarf_srclines_from_linecontext(de.line_context, &de.line_buffer, &de.line_count, &error) == DW_DLV_OK) { // Add all the addresses to our map for (int i = 0; i < de.line_count; i++) { if (dwarf_lineaddr(de.line_buffer[i], &line_addr, &error) != DW_DLV_OK) { line_addr = 0; } de.line_section.insert(std::pair(line_addr, i)); } } } // For each CU, cache the function DIEs that contain the // DW_AT_specification attribute. When building with -g3 the function // DIEs are separated in declaration and specification, with the // declaration containing only the name and parameters and the // specification the low/high pc and other compiler attributes. // // We cache those specifications so we don't skip over the declarations, // because they have no pc, and we can do namespace resolution for // DWARF function names. 
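// Illustrative sketch of the specification cache described above (grounded in
// the code that follows; the DIE offsets are made up). With -g3 a function can
// appear twice in a CU:
//
//   <0x0040> DW_TAG_subprogram   declaration: name and parameters, no pc
//   <0x00a0> DW_TAG_subprogram   definition: DW_AT_specification -> <0x0040>,
//                                DW_AT_low_pc / DW_AT_high_pc
//
// so spec_section records { 0x0040 -> 0x00a0 }, letting get_spec_die() jump
// from the pc-less declaration to the DIE that actually carries the pc range.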
Dwarf_Debug dwarf = fobj.dwarf_handle.get(); Dwarf_Die current_die = 0; if (dwarf_child(die, ¤t_die, &error) == DW_DLV_OK) { for (;;) { Dwarf_Die sibling_die = 0; Dwarf_Half tag_value; dwarf_tag(current_die, &tag_value, &error); if (tag_value == DW_TAG_subprogram || tag_value == DW_TAG_inlined_subroutine) { Dwarf_Bool has_attr = 0; if (dwarf_hasattr(current_die, DW_AT_specification, &has_attr, &error) == DW_DLV_OK) { if (has_attr) { Dwarf_Attribute attr_mem; if (dwarf_attr(current_die, DW_AT_specification, &attr_mem, &error) == DW_DLV_OK) { Dwarf_Off spec_offset = 0; if (dwarf_formref(attr_mem, &spec_offset, &error) == DW_DLV_OK) { Dwarf_Off spec_die_offset; if (dwarf_dieoffset(current_die, &spec_die_offset, &error) == DW_DLV_OK) { de.spec_section[spec_offset] = spec_die_offset; } } } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } } } int result = dwarf_siblingof(dwarf, current_die, &sibling_die, &error); if (result == DW_DLV_ERROR) { break; } else if (result == DW_DLV_NO_ENTRY) { break; } if (current_die != die) { dwarf_dealloc(dwarf, current_die, DW_DLA_DIE); current_die = 0; } current_die = sibling_die; } } return de; } static Dwarf_Die get_referenced_die(Dwarf_Debug dwarf, Dwarf_Die die, Dwarf_Half attr, bool global) { Dwarf_Error error = DW_DLE_NE; Dwarf_Attribute attr_mem; Dwarf_Die found_die = NULL; if (dwarf_attr(die, attr, &attr_mem, &error) == DW_DLV_OK) { Dwarf_Off offset; int result = 0; if (global) { result = dwarf_global_formref(attr_mem, &offset, &error); } else { result = dwarf_formref(attr_mem, &offset, &error); } if (result == DW_DLV_OK) { if (dwarf_offdie(dwarf, offset, &found_die, &error) != DW_DLV_OK) { found_die = NULL; } } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } return found_die; } static std::string get_referenced_die_name(Dwarf_Debug dwarf, Dwarf_Die die, Dwarf_Half attr, bool global) { Dwarf_Error error = DW_DLE_NE; std::string value; Dwarf_Die found_die = get_referenced_die(dwarf, die, attr, global); if (found_die) { char* name; if (dwarf_diename(found_die, &name, &error) == DW_DLV_OK) { if (name) { value = std::string(name); } dwarf_dealloc(dwarf, name, DW_DLA_STRING); } dwarf_dealloc(dwarf, found_die, DW_DLA_DIE); } return value; } // Returns a spec DIE linked to the passed one. The caller should // deallocate the DIE static Dwarf_Die get_spec_die(dwarf_fileobject& fobj, Dwarf_Die die) { Dwarf_Debug dwarf = fobj.dwarf_handle.get(); Dwarf_Error error = DW_DLE_NE; Dwarf_Off die_offset; if (fobj.current_cu && dwarf_die_CU_offset(die, &die_offset, &error) == DW_DLV_OK) { die_specmap_t::iterator it = fobj.current_cu->spec_section.find(die_offset); // If we have a DIE that completes the current one, check if // that one has the pc we are looking for if (it != fobj.current_cu->spec_section.end()) { Dwarf_Die spec_die = 0; if (dwarf_offdie(dwarf, it->second, &spec_die, &error) == DW_DLV_OK) { return spec_die; } } } // Maybe we have an abstract origin DIE with the function information? return get_referenced_die(fobj.dwarf_handle.get(), die, DW_AT_abstract_origin, true); } static bool die_has_pc(dwarf_fileobject& fobj, Dwarf_Die die, Dwarf_Addr pc) { Dwarf_Addr low_pc = 0, high_pc = 0; Dwarf_Half high_pc_form = 0; Dwarf_Form_Class return_class; Dwarf_Error error = DW_DLE_NE; Dwarf_Debug dwarf = fobj.dwarf_handle.get(); bool has_lowpc = false; bool has_highpc = false; bool has_ranges = false; if (dwarf_lowpc(die, &low_pc, &error) == DW_DLV_OK) { // If we have a low_pc check if there is a high pc. 
// If we don't have a high pc this might mean we have a base // address for the ranges list or just an address. has_lowpc = true; if (dwarf_highpc_b(die, &high_pc, &high_pc_form, &return_class, &error) == DW_DLV_OK) { // We do have a high pc. In DWARF 4+ this is an offset from the // low pc, but in earlier versions it's an absolute address. has_highpc = true; // In DWARF 2/3 this would be a DW_FORM_CLASS_ADDRESS if (return_class == DW_FORM_CLASS_CONSTANT) { high_pc = low_pc + high_pc; } // We have low and high pc, check if our address // is in that range return pc >= low_pc && pc < high_pc; } } else { // Reset the low_pc, in case dwarf_lowpc failing set it to some // undefined value. low_pc = 0; } // Check if DW_AT_ranges is present and search for the PC in the // returned ranges list. We always add the low_pc, as it not set it will // be 0, in case we had a DW_AT_low_pc and DW_AT_ranges pair bool result = false; Dwarf_Attribute attr; if (dwarf_attr(die, DW_AT_ranges, &attr, &error) == DW_DLV_OK) { Dwarf_Off offset; if (dwarf_global_formref(attr, &offset, &error) == DW_DLV_OK) { Dwarf_Ranges* ranges; Dwarf_Signed ranges_count = 0; Dwarf_Unsigned byte_count = 0; if (dwarf_get_ranges_a(dwarf, offset, die, &ranges, &ranges_count, &byte_count, &error) == DW_DLV_OK) { has_ranges = ranges_count != 0; for (int i = 0; i < ranges_count; i++) { if (ranges[i].dwr_addr1 != 0 && pc >= ranges[i].dwr_addr1 + low_pc && pc < ranges[i].dwr_addr2 + low_pc) { result = true; break; } } dwarf_ranges_dealloc(dwarf, ranges, ranges_count); } } } // Last attempt. We might have a single address set as low_pc. if (!result && low_pc != 0 && pc == low_pc) { result = true; } // If we don't have lowpc, highpc and ranges maybe this DIE is a // declaration that relies on a DW_AT_specification DIE that happens // later. Use the specification cache we filled when we loaded this CU. 
if (!result && (!has_lowpc && !has_highpc && !has_ranges)) { Dwarf_Die spec_die = get_spec_die(fobj, die); if (spec_die) { result = die_has_pc(fobj, spec_die, pc); dwarf_dealloc(dwarf, spec_die, DW_DLA_DIE); } } return result; } static void get_type(Dwarf_Debug dwarf, Dwarf_Die die, std::string& type) { Dwarf_Error error = DW_DLE_NE; Dwarf_Die child = 0; if (dwarf_child(die, &child, &error) == DW_DLV_OK) { get_type(dwarf, child, type); } if (child) { type.insert(0, "::"); dwarf_dealloc(dwarf, child, DW_DLA_DIE); } char* name; if (dwarf_diename(die, &name, &error) == DW_DLV_OK) { type.insert(0, std::string(name)); dwarf_dealloc(dwarf, name, DW_DLA_STRING); } else { type.insert(0, ""); } } static std::string get_type_by_signature(Dwarf_Debug dwarf, Dwarf_Die die) { Dwarf_Error error = DW_DLE_NE; Dwarf_Sig8 signature; Dwarf_Bool has_attr = 0; if (dwarf_hasattr(die, DW_AT_signature, &has_attr, &error) == DW_DLV_OK) { if (has_attr) { Dwarf_Attribute attr_mem; if (dwarf_attr(die, DW_AT_signature, &attr_mem, &error) == DW_DLV_OK) { if (dwarf_formsig8(attr_mem, &signature, &error) != DW_DLV_OK) { return std::string(""); } } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } } Dwarf_Unsigned next_cu_header; Dwarf_Sig8 tu_signature; std::string result; bool found = false; while (dwarf_next_cu_header_d(dwarf, 0, 0, 0, 0, 0, 0, 0, &tu_signature, 0, &next_cu_header, 0, &error) == DW_DLV_OK) { if (strncmp(signature.signature, tu_signature.signature, 8) == 0) { Dwarf_Die type_cu_die = 0; if (dwarf_siblingof_b(dwarf, 0, 0, &type_cu_die, &error) == DW_DLV_OK) { Dwarf_Die child_die = 0; if (dwarf_child(type_cu_die, &child_die, &error) == DW_DLV_OK) { get_type(dwarf, child_die, result); found = !result.empty(); dwarf_dealloc(dwarf, child_die, DW_DLA_DIE); } dwarf_dealloc(dwarf, type_cu_die, DW_DLA_DIE); } } } if (found) { while (dwarf_next_cu_header_d(dwarf, 0, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error) == DW_DLV_OK) { // Reset the cu header state. Unfortunately, libdwarf's // next_cu_header API keeps its own iterator per Dwarf_Debug // that can't be reset. We need to keep fetching elements until // the end. } } else { // If we couldn't resolve the type just print out the signature std::ostringstream string_stream; string_stream << "<0x" << std::hex << std::setfill('0'); for (int i = 0; i < 8; ++i) { string_stream << std::setw(2) << std::hex << (int)(unsigned char)(signature.signature[i]); } string_stream << ">"; result = string_stream.str(); } return result; } struct type_context_t { bool is_const; bool is_typedef; bool has_type; bool has_name; std::string text; type_context_t() : is_const(false), is_typedef(false), has_type(false), has_name(false) {} }; // Types are resolved from right to left: we get the variable name first // and then all specifiers (like const or pointer) in a chain of DW_AT_type // DIEs. Call this function recursively until we get a complete type // string. 
static void set_parameter_string(dwarf_fileobject& fobj, Dwarf_Die die, type_context_t& context) { char* name; Dwarf_Error error = DW_DLE_NE; // typedefs contain also the base type, so we skip it and only // print the typedef name if (!context.is_typedef) { if (dwarf_diename(die, &name, &error) == DW_DLV_OK) { if (!context.text.empty()) { context.text.insert(0, " "); } context.text.insert(0, std::string(name)); dwarf_dealloc(fobj.dwarf_handle.get(), name, DW_DLA_STRING); } } else { context.is_typedef = false; context.has_type = true; if (context.is_const) { context.text.insert(0, "const "); context.is_const = false; } } bool next_type_is_const = false; bool is_keyword = true; Dwarf_Half tag = 0; Dwarf_Bool has_attr = 0; if (dwarf_tag(die, &tag, &error) == DW_DLV_OK) { switch (tag) { case DW_TAG_structure_type: case DW_TAG_union_type: case DW_TAG_class_type: case DW_TAG_enumeration_type: context.has_type = true; if (dwarf_hasattr(die, DW_AT_signature, &has_attr, &error) == DW_DLV_OK) { // If we have a signature it means the type is defined // in .debug_types, so we need to load the DIE pointed // at by the signature and resolve it if (has_attr) { std::string type = get_type_by_signature(fobj.dwarf_handle.get(), die); if (context.is_const) type.insert(0, "const "); if (!context.text.empty()) context.text.insert(0, " "); context.text.insert(0, type); } // Treat enums like typedefs, and skip printing its // base type context.is_typedef = (tag == DW_TAG_enumeration_type); } break; case DW_TAG_const_type: next_type_is_const = true; break; case DW_TAG_pointer_type: context.text.insert(0, "*"); break; case DW_TAG_reference_type: context.text.insert(0, "&"); break; case DW_TAG_restrict_type: context.text.insert(0, "restrict "); break; case DW_TAG_rvalue_reference_type: context.text.insert(0, "&&"); break; case DW_TAG_volatile_type: context.text.insert(0, "volatile "); break; case DW_TAG_typedef: // Propagate the const-ness to the next type // as typedefs are linked to its base type next_type_is_const = context.is_const; context.is_typedef = true; context.has_type = true; break; case DW_TAG_base_type: context.has_type = true; break; case DW_TAG_formal_parameter: context.has_name = true; break; default: is_keyword = false; break; } } if (!is_keyword && context.is_const) { context.text.insert(0, "const "); } context.is_const = next_type_is_const; Dwarf_Die ref = get_referenced_die(fobj.dwarf_handle.get(), die, DW_AT_type, true); if (ref) { set_parameter_string(fobj, ref, context); dwarf_dealloc(fobj.dwarf_handle.get(), ref, DW_DLA_DIE); } if (!context.has_type && context.has_name) { context.text.insert(0, "void "); context.has_type = true; } } // Resolve the function return type and parameters static void set_function_parameters(std::string& function_name, std::vector& ns, dwarf_fileobject& fobj, Dwarf_Die die) { Dwarf_Debug dwarf = fobj.dwarf_handle.get(); Dwarf_Error error = DW_DLE_NE; Dwarf_Die current_die = 0; std::string parameters; bool has_spec = true; // Check if we have a spec DIE. If we do we use it as it contains // more information, like parameter names. Dwarf_Die spec_die = get_spec_die(fobj, die); if (!spec_die) { has_spec = false; spec_die = die; } std::vector::const_iterator it = ns.begin(); std::string ns_name; for (it = ns.begin(); it < ns.end(); ++it) { ns_name.append(*it).append("::"); } if (!ns_name.empty()) { function_name.insert(0, ns_name); } // See if we have a function return type. 
It can be either on the // current die or in its spec one (usually true for inlined functions) std::string return_type = get_referenced_die_name(dwarf, die, DW_AT_type, true); if (return_type.empty()) { return_type = get_referenced_die_name(dwarf, spec_die, DW_AT_type, true); } if (!return_type.empty()) { return_type.append(" "); function_name.insert(0, return_type); } if (dwarf_child(spec_die, ¤t_die, &error) == DW_DLV_OK) { for (;;) { Dwarf_Die sibling_die = 0; Dwarf_Half tag_value; dwarf_tag(current_die, &tag_value, &error); if (tag_value == DW_TAG_formal_parameter) { // Ignore artificial (ie, compiler generated) parameters bool is_artificial = false; Dwarf_Attribute attr_mem; if (dwarf_attr(current_die, DW_AT_artificial, &attr_mem, &error) == DW_DLV_OK) { Dwarf_Bool flag = 0; if (dwarf_formflag(attr_mem, &flag, &error) == DW_DLV_OK) { is_artificial = flag != 0; } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } if (!is_artificial) { type_context_t context; set_parameter_string(fobj, current_die, context); if (parameters.empty()) { parameters.append("("); } else { parameters.append(", "); } parameters.append(context.text); } } int result = dwarf_siblingof(dwarf, current_die, &sibling_die, &error); if (result == DW_DLV_ERROR) { break; } else if (result == DW_DLV_NO_ENTRY) { break; } if (current_die != die) { dwarf_dealloc(dwarf, current_die, DW_DLA_DIE); current_die = 0; } current_die = sibling_die; } } if (parameters.empty()) parameters = "("; parameters.append(")"); // If we got a spec DIE we need to deallocate it if (has_spec) dwarf_dealloc(dwarf, spec_die, DW_DLA_DIE); function_name.append(parameters); } // defined here because in C++98, template function cannot take locally // defined types... grrr. struct inliners_search_cb { void operator()(Dwarf_Die die, std::vector& ns) { Dwarf_Error error = DW_DLE_NE; Dwarf_Half tag_value; Dwarf_Attribute attr_mem; Dwarf_Debug dwarf = fobj.dwarf_handle.get(); dwarf_tag(die, &tag_value, &error); switch (tag_value) { char* name; case DW_TAG_subprogram: if (!trace.source.function.empty()) break; if (dwarf_diename(die, &name, &error) == DW_DLV_OK) { trace.source.function = std::string(name); dwarf_dealloc(dwarf, name, DW_DLA_STRING); } else { // We don't have a function name in this DIE. // Check if there is a referenced non-defining // declaration. trace.source.function = get_referenced_die_name(dwarf, die, DW_AT_abstract_origin, true); if (trace.source.function.empty()) { trace.source.function = get_referenced_die_name(dwarf, die, DW_AT_specification, true); } } // Append the function parameters, if available set_function_parameters(trace.source.function, ns, fobj, die); // If the object function name is empty, it's possible that // there is no dynamic symbol table (maybe the executable // was stripped or not built with -rdynamic). See if we have // a DWARF linkage name to use instead. We try both // linkage_name and MIPS_linkage_name because the MIPS tag // was the unofficial one until it was adopted in DWARF4. 
// Old gcc versions generate MIPS_linkage_name if (trace.object_function.empty()) { details::demangler demangler; if (dwarf_attr(die, DW_AT_linkage_name, &attr_mem, &error) != DW_DLV_OK) { if (dwarf_attr(die, DW_AT_MIPS_linkage_name, &attr_mem, &error) != DW_DLV_OK) { break; } } char* linkage; if (dwarf_formstring(attr_mem, &linkage, &error) == DW_DLV_OK) { trace.object_function = demangler.demangle(linkage); dwarf_dealloc(dwarf, linkage, DW_DLA_STRING); } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } break; case DW_TAG_inlined_subroutine: ResolvedTrace::SourceLoc sloc; if (dwarf_diename(die, &name, &error) == DW_DLV_OK) { sloc.function = std::string(name); dwarf_dealloc(dwarf, name, DW_DLA_STRING); } else { // We don't have a name for this inlined DIE, it could // be that there is an abstract origin instead. // Get the DW_AT_abstract_origin value, which is a // reference to the source DIE and try to get its name sloc.function = get_referenced_die_name(dwarf, die, DW_AT_abstract_origin, true); } set_function_parameters(sloc.function, ns, fobj, die); std::string file = die_call_file(dwarf, die, cu_die); if (!file.empty()) sloc.filename = file; Dwarf_Unsigned number = 0; if (dwarf_attr(die, DW_AT_call_line, &attr_mem, &error) == DW_DLV_OK) { if (dwarf_formudata(attr_mem, &number, &error) == DW_DLV_OK) { sloc.line = number; } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } if (dwarf_attr(die, DW_AT_call_column, &attr_mem, &error) == DW_DLV_OK) { if (dwarf_formudata(attr_mem, &number, &error) == DW_DLV_OK) { sloc.col = number; } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } trace.inliners.push_back(sloc); break; }; } ResolvedTrace& trace; dwarf_fileobject& fobj; Dwarf_Die cu_die; inliners_search_cb(ResolvedTrace& t, dwarf_fileobject& f, Dwarf_Die c) : trace(t), fobj(f), cu_die(c) {} }; static Dwarf_Die find_fundie_by_pc(dwarf_fileobject& fobj, Dwarf_Die parent_die, Dwarf_Addr pc, Dwarf_Die result) { Dwarf_Die current_die = 0; Dwarf_Error error = DW_DLE_NE; Dwarf_Debug dwarf = fobj.dwarf_handle.get(); if (dwarf_child(parent_die, ¤t_die, &error) != DW_DLV_OK) { return NULL; } for (;;) { Dwarf_Die sibling_die = 0; Dwarf_Half tag_value; dwarf_tag(current_die, &tag_value, &error); switch (tag_value) { case DW_TAG_subprogram: case DW_TAG_inlined_subroutine: if (die_has_pc(fobj, current_die, pc)) { return current_die; } }; bool declaration = false; Dwarf_Attribute attr_mem; if (dwarf_attr(current_die, DW_AT_declaration, &attr_mem, &error) == DW_DLV_OK) { Dwarf_Bool flag = 0; if (dwarf_formflag(attr_mem, &flag, &error) == DW_DLV_OK) { declaration = flag != 0; } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } if (!declaration) { // let's be curious and look deeper in the tree, functions are // not necessarily at the first level, but might be nested // inside a namespace, structure, a function, an inlined // function etc. 
Dwarf_Die die_mem = 0; Dwarf_Die indie = find_fundie_by_pc(fobj, current_die, pc, die_mem); if (indie) { result = die_mem; return result; } } int res = dwarf_siblingof(dwarf, current_die, &sibling_die, &error); if (res == DW_DLV_ERROR) { return NULL; } else if (res == DW_DLV_NO_ENTRY) { break; } if (current_die != parent_die) { dwarf_dealloc(dwarf, current_die, DW_DLA_DIE); current_die = 0; } current_die = sibling_die; } return NULL; } template static bool deep_first_search_by_pc(dwarf_fileobject& fobj, Dwarf_Die parent_die, Dwarf_Addr pc, std::vector& ns, CB cb) { Dwarf_Die current_die = 0; Dwarf_Debug dwarf = fobj.dwarf_handle.get(); Dwarf_Error error = DW_DLE_NE; if (dwarf_child(parent_die, ¤t_die, &error) != DW_DLV_OK) { return false; } bool branch_has_pc = false; bool has_namespace = false; for (;;) { Dwarf_Die sibling_die = 0; Dwarf_Half tag; if (dwarf_tag(current_die, &tag, &error) == DW_DLV_OK) { if (tag == DW_TAG_namespace || tag == DW_TAG_class_type) { char* ns_name = NULL; if (dwarf_diename(current_die, &ns_name, &error) == DW_DLV_OK) { if (ns_name) { ns.push_back(std::string(ns_name)); } else { ns.push_back(""); } dwarf_dealloc(dwarf, ns_name, DW_DLA_STRING); } else { ns.push_back(""); } has_namespace = true; } } bool declaration = false; Dwarf_Attribute attr_mem; if (tag != DW_TAG_class_type && dwarf_attr(current_die, DW_AT_declaration, &attr_mem, &error) == DW_DLV_OK) { Dwarf_Bool flag = 0; if (dwarf_formflag(attr_mem, &flag, &error) == DW_DLV_OK) { declaration = flag != 0; } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); } if (!declaration) { // let's be curious and look deeper in the tree, function are // not necessarily at the first level, but might be nested // inside a namespace, structure, a function, an inlined // function etc. branch_has_pc = deep_first_search_by_pc(fobj, current_die, pc, ns, cb); } if (!branch_has_pc) { branch_has_pc = die_has_pc(fobj, current_die, pc); } if (branch_has_pc) { cb(current_die, ns); } int result = dwarf_siblingof(dwarf, current_die, &sibling_die, &error); if (result == DW_DLV_ERROR) { return false; } else if (result == DW_DLV_NO_ENTRY) { break; } if (current_die != parent_die) { dwarf_dealloc(dwarf, current_die, DW_DLA_DIE); current_die = 0; } if (has_namespace) { has_namespace = false; ns.pop_back(); } current_die = sibling_die; } if (has_namespace) { ns.pop_back(); } return branch_has_pc; } static std::string die_call_file(Dwarf_Debug dwarf, Dwarf_Die die, Dwarf_Die cu_die) { Dwarf_Attribute attr_mem; Dwarf_Error error = DW_DLE_NE; Dwarf_Unsigned file_index; std::string file; if (dwarf_attr(die, DW_AT_call_file, &attr_mem, &error) == DW_DLV_OK) { if (dwarf_formudata(attr_mem, &file_index, &error) != DW_DLV_OK) { file_index = 0; } dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR); if (file_index == 0) { return file; } char** srcfiles = 0; Dwarf_Signed file_count = 0; if (dwarf_srcfiles(cu_die, &srcfiles, &file_count, &error) == DW_DLV_OK) { if (file_count > 0 && file_index <= static_cast(file_count)) { file = std::string(srcfiles[file_index - 1]); } // Deallocate all strings! for (int i = 0; i < file_count; ++i) { dwarf_dealloc(dwarf, srcfiles[i], DW_DLA_STRING); } dwarf_dealloc(dwarf, srcfiles, DW_DLA_LIST); } } return file; } Dwarf_Die find_die(dwarf_fileobject& fobj, Dwarf_Addr addr) { // Let's get to work! 
First see if we have a debug_aranges section so // we can speed up the search Dwarf_Debug dwarf = fobj.dwarf_handle.get(); Dwarf_Error error = DW_DLE_NE; Dwarf_Arange* aranges; Dwarf_Signed arange_count; Dwarf_Die returnDie; bool found = false; if (dwarf_get_aranges(dwarf, &aranges, &arange_count, &error) != DW_DLV_OK) { aranges = NULL; } if (aranges) { // We have aranges. Get the one where our address is. Dwarf_Arange arange; if (dwarf_get_arange(aranges, arange_count, addr, &arange, &error) == DW_DLV_OK) { // We found our address. Get the compilation-unit DIE offset // represented by the given address range. Dwarf_Off cu_die_offset; if (dwarf_get_cu_die_offset(arange, &cu_die_offset, &error) == DW_DLV_OK) { // Get the DIE at the offset returned by the aranges search. // We set is_info to 1 to specify that the offset is from // the .debug_info section (and not .debug_types) int dwarf_result = dwarf_offdie_b(dwarf, cu_die_offset, 1, &returnDie, &error); found = dwarf_result == DW_DLV_OK; } dwarf_dealloc(dwarf, arange, DW_DLA_ARANGE); } } if (found) return returnDie; // The caller is responsible for freeing the die // The search for aranges failed. Try to find our address by scanning // all compilation units. Dwarf_Unsigned next_cu_header; Dwarf_Half tag = 0; returnDie = 0; while (!found && dwarf_next_cu_header_d(dwarf, 1, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error) == DW_DLV_OK) { if (returnDie) dwarf_dealloc(dwarf, returnDie, DW_DLA_DIE); if (dwarf_siblingof(dwarf, 0, &returnDie, &error) == DW_DLV_OK) { if ((dwarf_tag(returnDie, &tag, &error) == DW_DLV_OK) && tag == DW_TAG_compile_unit) { if (die_has_pc(fobj, returnDie, addr)) { found = true; } } } } if (found) { while (dwarf_next_cu_header_d(dwarf, 1, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error) == DW_DLV_OK) { // Reset the cu header state. Libdwarf's next_cu_header API // keeps its own iterator per Dwarf_Debug that can't be reset. // We need to keep fetching elements until the end. } } if (found) return returnDie; // We couldn't find any compilation units with ranges or a high/low pc. // Try again by looking at all DIEs in all compilation units. Dwarf_Die cudie; while (dwarf_next_cu_header_d(dwarf, 1, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error) == DW_DLV_OK) { if (dwarf_siblingof(dwarf, 0, &cudie, &error) == DW_DLV_OK) { Dwarf_Die die_mem = 0; Dwarf_Die resultDie = find_fundie_by_pc(fobj, cudie, addr, die_mem); if (resultDie) { found = true; break; } } } if (found) { while (dwarf_next_cu_header_d(dwarf, 1, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error) == DW_DLV_OK) { // Reset the cu header state. Libdwarf's next_cu_header API // keeps its own iterator per Dwarf_Debug that can't be reset. // We need to keep fetching elements until the end. } } if (found) return cudie; // We failed. 
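// Illustrative note (not in the original source): the fast path above only
// applies when the binary actually contains a .debug_aranges section, which
// compilers do not always emit; that is why the slower per-CU scans exist as
// fallbacks. Whether a given binary has aranges can be checked with common
// tooling, e.g.:
//   readelf --debug-dump=aranges <binary>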
return NULL; } }; #endif // BACKWARD_HAS_DWARF == 1 template<> class TraceResolverImpl : public TraceResolverLinuxImpl {}; #endif // BACKWARD_SYSTEM_LINUX #ifdef BACKWARD_SYSTEM_DARWIN template class TraceResolverDarwinImpl; template<> class TraceResolverDarwinImpl : public TraceResolverImplBase { public: void load_addresses(void* const* addresses, int address_count) override { if (address_count == 0) { return; } _symbols.reset(backtrace_symbols(addresses, address_count)); } ResolvedTrace resolve(ResolvedTrace trace) override { // parse: // + char* filename = _symbols[trace.idx]; // skip " " while (*filename && *filename != ' ') filename++; while (*filename == ' ') filename++; // find start of from end ( may contain a space) char* p = filename + strlen(filename) - 1; // skip to start of " + " while (p > filename && *p != ' ') p--; while (p > filename && *p == ' ') p--; while (p > filename && *p != ' ') p--; while (p > filename && *p == ' ') p--; char* funcname_end = p + 1; // skip to start of "" while (p > filename && *p != ' ') p--; char* funcname = p + 1; // skip to start of " " while (p > filename && *p == ' ') p--; while (p > filename && *p != ' ') p--; while (p > filename && *p == ' ') p--; // skip "", handling the case where it contains a char* filename_end = p + 1; if (p == filename) { // something went wrong, give up filename_end = filename + strlen(filename); funcname = filename_end; } trace.object_filename.assign(filename, filename_end); // ok even if filename_end is the ending // \0 (then we assign entire string) if (*funcname) { // if it's not end of string *funcname_end = '\0'; trace.object_function = this->demangle(funcname); trace.object_function += " "; trace.object_function += (funcname_end + 1); trace.source.function = trace.object_function; // we cannot do better. 
} return trace; } private: details::handle _symbols; }; template<> class TraceResolverImpl : public TraceResolverDarwinImpl {}; #endif // BACKWARD_SYSTEM_DARWIN #ifdef BACKWARD_SYSTEM_WINDOWS // Load all symbol info // Based on: // https://stackoverflow.com/questions/6205981/windows-c-stack-trace-from-a-running-app/28276227#28276227 struct module_data { std::string image_name; std::string module_name; void* base_address; DWORD load_size; }; class get_mod_info { HANDLE process; static const int buffer_length = 4096; public: get_mod_info(HANDLE h) : process(h) {} module_data operator()(HMODULE module) { module_data ret; char temp[buffer_length]; MODULEINFO mi; GetModuleInformation(process, module, &mi, sizeof(mi)); ret.base_address = mi.lpBaseOfDll; ret.load_size = mi.SizeOfImage; GetModuleFileNameExA(process, module, temp, sizeof(temp)); ret.image_name = temp; GetModuleBaseNameA(process, module, temp, sizeof(temp)); ret.module_name = temp; std::vector img(ret.image_name.begin(), ret.image_name.end()); std::vector mod(ret.module_name.begin(), ret.module_name.end()); SymLoadModule64(process, 0, &img[0], &mod[0], (DWORD64)ret.base_address, ret.load_size); return ret; } }; template<> class TraceResolverImpl : public TraceResolverImplBase { public: TraceResolverImpl() { HANDLE process = GetCurrentProcess(); std::vector modules; DWORD cbNeeded; std::vector module_handles(1); SymInitialize(process, NULL, false); DWORD symOptions = SymGetOptions(); symOptions |= SYMOPT_LOAD_LINES | SYMOPT_UNDNAME; SymSetOptions(symOptions); EnumProcessModules(process, &module_handles[0], static_cast(module_handles.size() * sizeof(HMODULE)), &cbNeeded); module_handles.resize(cbNeeded / sizeof(HMODULE)); EnumProcessModules(process, &module_handles[0], static_cast(module_handles.size() * sizeof(HMODULE)), &cbNeeded); std::transform(module_handles.begin(), module_handles.end(), std::back_inserter(modules), get_mod_info(process)); void* base = modules[0].base_address; IMAGE_NT_HEADERS* h = ImageNtHeader(base); image_type = h->FileHeader.Machine; } static const int max_sym_len = 255; struct symbol_t { SYMBOL_INFO sym; char buffer[max_sym_len]; } sym; DWORD64 displacement; ResolvedTrace resolve(ResolvedTrace t) override { HANDLE process = GetCurrentProcess(); char name[256]; memset(&sym, 0, sizeof(sym)); sym.sym.SizeOfStruct = sizeof(SYMBOL_INFO); sym.sym.MaxNameLen = max_sym_len; if (!SymFromAddr(process, (ULONG64)t.addr, &displacement, &sym.sym)) { // TODO: error handling everywhere char* lpMsgBuf; DWORD dw = GetLastError(); if (FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, dw, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (char*)&lpMsgBuf, 0, NULL)) { std::fprintf(stderr, "%s\n", lpMsgBuf); LocalFree(lpMsgBuf); } // abort(); } UnDecorateSymbolName(sym.sym.Name, (PSTR)name, 256, UNDNAME_COMPLETE); DWORD offset = 0; IMAGEHLP_LINE line; if (SymGetLineFromAddr(process, (ULONG64)t.addr, &offset, &line)) { t.object_filename = line.FileName; t.source.filename = line.FileName; t.source.line = line.LineNumber; t.source.col = offset; } t.source.function = name; t.object_filename = ""; t.object_function = name; return t; } DWORD machine_type() const { return image_type; } private: DWORD image_type; }; #endif class TraceResolver : public TraceResolverImpl {}; /*************** CODE SNIPPET ***************/ class SourceFile { public: typedef std::vector> lines_t; SourceFile() {} SourceFile(const std::string& path) { // 1. 
If BACKWARD_CXX_SOURCE_PREFIXES is set then assume it contains // a colon-separated list of path prefixes. Try prepending each // to the given path until a valid file is found. const std::vector& prefixes = get_paths_from_env_variable(); for (size_t i = 0; i < prefixes.size(); ++i) { // Double slashes (//) should not be a problem. std::string new_path = prefixes[i] + '/' + path; _file.reset(new std::ifstream(new_path.c_str())); if (is_open()) break; } // 2. If no valid file found then fallback to opening the path as-is. if (!_file || !is_open()) { _file.reset(new std::ifstream(path.c_str())); } } bool is_open() const { return _file->is_open(); } lines_t& get_lines(unsigned line_start, unsigned line_count, lines_t& lines) { using namespace std; // This function make uses of the dumbest algo ever: // 1) seek(0) // 2) read lines one by one and discard until line_start // 3) read line one by one until line_start + line_count // // If you are getting snippets many time from the same file, it is // somewhat a waste of CPU, feel free to benchmark and propose a // better solution ;) _file->clear(); _file->seekg(0); string line; unsigned line_idx; for (line_idx = 1; line_idx < line_start; ++line_idx) { std::getline(*_file, line); if (!*_file) { return lines; } } // think of it like a lambda in C++98 ;) // but look, I will reuse it two times! // What a good boy am I. struct isspace { bool operator()(char c) { return std::isspace(c); } }; bool started = false; for (; line_idx < line_start + line_count; ++line_idx) { getline(*_file, line); if (!*_file) { return lines; } if (!started) { if (std::find_if(line.begin(), line.end(), not_isspace()) == line.end()) continue; started = true; } lines.push_back(make_pair(line_idx, line)); } lines.erase(std::find_if(lines.rbegin(), lines.rend(), not_isempty()).base(), lines.end()); return lines; } lines_t get_lines(unsigned line_start, unsigned line_count) { lines_t lines; return get_lines(line_start, line_count, lines); } // there is no find_if_not in C++98, lets do something crappy to // workaround. struct not_isspace { bool operator()(char c) { return !std::isspace(c); } }; // and define this one here because C++98 is not happy with local defined // struct passed to template functions, fuuuu. struct not_isempty { bool operator()(const lines_t::value_type& p) { return !(std::find_if(p.second.begin(), p.second.end(), not_isspace()) == p.second.end()); } }; void swap(SourceFile& b) { _file.swap(b._file); } #ifdef BACKWARD_ATLEAST_CXX11 SourceFile(SourceFile&& from) : _file(nullptr) { swap(from); } SourceFile& operator=(SourceFile&& from) { swap(from); return *this; } #else explicit SourceFile(const SourceFile& from) { // some sort of poor man's move semantic. swap(const_cast(from)); } SourceFile& operator=(const SourceFile& from) { // some sort of poor man's move semantic. 
swap(const_cast(from)); return *this; } #endif // Allow adding to paths gotten from BACKWARD_CXX_SOURCE_PREFIXES after loading the // library; this can be useful when the library is loaded when the locations are unknown // Warning: Because this edits the static paths variable, it is *not* intrinsiclly thread safe static void add_paths_to_env_variable_impl(const std::string& to_add) { get_mutable_paths_from_env_variable().push_back(to_add); } private: details::handle> _file; static std::vector get_paths_from_env_variable_impl() { std::vector paths; const char* prefixes_str = std::getenv("BACKWARD_CXX_SOURCE_PREFIXES"); if (prefixes_str && prefixes_str[0]) { paths = details::split_source_prefixes(prefixes_str); } return paths; } static std::vector& get_mutable_paths_from_env_variable() { static volatile std::vector paths = get_paths_from_env_variable_impl(); return const_cast&>(paths); } static const std::vector& get_paths_from_env_variable() { return get_mutable_paths_from_env_variable(); } #ifdef BACKWARD_ATLEAST_CXX11 SourceFile(const SourceFile&) = delete; SourceFile& operator=(const SourceFile&) = delete; #endif }; class SnippetFactory { public: typedef SourceFile::lines_t lines_t; lines_t get_snippet(const std::string& filename, unsigned line_start, unsigned context_size) { SourceFile& src_file = get_src_file(filename); unsigned start = line_start - context_size / 2; return src_file.get_lines(start, context_size); } lines_t get_combined_snippet(const std::string& filename_a, unsigned line_a, const std::string& filename_b, unsigned line_b, unsigned context_size) { SourceFile& src_file_a = get_src_file(filename_a); SourceFile& src_file_b = get_src_file(filename_b); lines_t lines = src_file_a.get_lines(line_a - context_size / 4, context_size / 2); src_file_b.get_lines(line_b - context_size / 4, context_size / 2, lines); return lines; } lines_t get_coalesced_snippet(const std::string& filename, unsigned line_a, unsigned line_b, unsigned context_size) { SourceFile& src_file = get_src_file(filename); using std::max; using std::min; unsigned a = min(line_a, line_b); unsigned b = max(line_a, line_b); if ((b - a) < (context_size / 3)) { return src_file.get_lines((a + b - context_size + 1) / 2, context_size); } lines_t lines = src_file.get_lines(a - context_size / 4, context_size / 2); src_file.get_lines(b - context_size / 4, context_size / 2, lines); return lines; } private: typedef details::hashtable::type src_files_t; src_files_t _src_files; SourceFile& get_src_file(const std::string& filename) { src_files_t::iterator it = _src_files.find(filename); if (it != _src_files.end()) { return it->second; } SourceFile& new_src_file = _src_files[filename]; new_src_file = SourceFile(filename); return new_src_file; } }; /*************** PRINTER ***************/ namespace ColorMode { enum type { automatic, never, always }; } class cfile_streambuf : public std::streambuf { public: cfile_streambuf(FILE* _sink) : sink(_sink) {} int_type underflow() override { return traits_type::eof(); } int_type overflow(int_type ch) override { if (traits_type::not_eof(ch) && fputc(ch, sink) != EOF) { return ch; } return traits_type::eof(); } std::streamsize xsputn(const char_type* s, std::streamsize count) override { return static_cast(fwrite(s, sizeof *s, static_cast(count), sink)); } #ifdef BACKWARD_ATLEAST_CXX11 public: cfile_streambuf(const cfile_streambuf&) = delete; cfile_streambuf& operator=(const cfile_streambuf&) = delete; #else private: cfile_streambuf(const cfile_streambuf&); cfile_streambuf& operator=(const 
cfile_streambuf&); #endif private: FILE* sink; std::vector buffer; }; #ifdef BACKWARD_SYSTEM_LINUX namespace Color { enum type { yellow = 33, purple = 35, reset = 39 }; } // namespace Color class Colorize { public: Colorize(std::ostream& os) : _os(os), _reset(false), _enabled(false) {} void activate(ColorMode::type mode) { _enabled = mode == ColorMode::always; } void activate(ColorMode::type mode, FILE* fp) { activate(mode, fileno(fp)); } void set_color(Color::type ccode) { if (!_enabled) return; // I assume that the terminal can handle basic colors. Seriously I // don't want to deal with all the termcap shit. _os << "\033[" << static_cast(ccode) << "m"; _reset = (ccode != Color::reset); } ~Colorize() { if (_reset) { set_color(Color::reset); } } private: void activate(ColorMode::type mode, int fd) { activate(mode == ColorMode::automatic && isatty(fd) ? ColorMode::always : mode); } std::ostream& _os; bool _reset; bool _enabled; }; #else // ndef BACKWARD_SYSTEM_LINUX namespace Color { enum type { yellow = 0, purple = 0, reset = 0 }; } // namespace Color class Colorize { public: Colorize(std::ostream&) {} void activate(ColorMode::type) {} void activate(ColorMode::type, FILE*) {} void set_color(Color::type) {} }; #endif // BACKWARD_SYSTEM_LINUX class Printer { public: bool snippet; ColorMode::type color_mode; bool address; bool object; int inliner_context_size; int trace_context_size; bool reverse; Printer() : snippet(true), color_mode(ColorMode::automatic), address(false), object(false), // Modify: Show one line by default // inliner_context_size(5), // trace_context_size(7), inliner_context_size(1), trace_context_size(1), reverse(true) {} template FILE* print(ST& st, FILE* fp = stderr) { cfile_streambuf obuf(fp); std::ostream os(&obuf); Colorize colorize(os); colorize.activate(color_mode, fp); print_stacktrace(st, os, colorize); return fp; } template std::ostream& print(ST& st, std::ostream& os) { Colorize colorize(os); colorize.activate(color_mode); print_stacktrace(st, os, colorize); return os; } template FILE* print(IT begin, IT end, FILE* fp = stderr, size_t thread_id = 0) { cfile_streambuf obuf(fp); std::ostream os(&obuf); Colorize colorize(os); colorize.activate(color_mode, fp); print_stacktrace(begin, end, os, thread_id, colorize); return fp; } template std::ostream& print(IT begin, IT end, std::ostream& os, size_t thread_id = 0) { Colorize colorize(os); colorize.activate(color_mode); print_stacktrace(begin, end, os, thread_id, colorize); return os; } // Modify: skip stacks in python object file static inline bool is_oneflow_file(const std::string& filename) { return std::string(std::filesystem::path(filename).filename()).find("oneflow") != std::string::npos; } TraceResolver const& resolver() const { return _resolver; } private: TraceResolver _resolver; SnippetFactory _snippets; template void print_stacktrace(ST& st, std::ostream& os, Colorize& colorize) { print_header(os, st.thread_id()); _resolver.load_stacktrace(st); if (reverse) { for (size_t trace_idx = st.size(); trace_idx > 0; --trace_idx) { print_trace(os, _resolver.resolve(st[trace_idx - 1]), colorize); } } else { for (size_t trace_idx = 0; trace_idx < st.size(); ++trace_idx) { print_trace(os, _resolver.resolve(st[trace_idx]), colorize); } } // Modify: Add a new line before Python stack os << std::endl; } template void print_stacktrace(IT begin, IT end, std::ostream& os, size_t thread_id, Colorize& colorize) { print_header(os, thread_id); for (; begin != end; ++begin) { print_trace(os, *begin, colorize); } } void 
print_header(std::ostream& os, size_t thread_id) { os << "Stack trace (most recent call last)"; if (thread_id) { os << " in thread " << thread_id; } os << ":\n"; } void print_trace(std::ostream& os, const ResolvedTrace& trace, Colorize& colorize) { // Modify: skip stacks in python object file if (!is_oneflow_file(trace.object_filename)) { return; } // Modify: symbol '#', trace idx and indent are not necessary // os << "#" << std::left << std::setw(2) << trace.idx << std::right; // bool already_indented = true; if (!trace.source.filename.size() || object) { os << " Object \"" << trace.object_filename << "\", at " << trace.addr << ", in " << trace.object_function << "\n"; // Modify: Extra indent is not necessary // already_indented = false; } for (size_t inliner_idx = trace.inliners.size(); inliner_idx > 0; --inliner_idx) { // Modify: Extra indent is not necessary // if (!already_indented) { os << " "; } const ResolvedTrace::SourceLoc& inliner_loc = trace.inliners[inliner_idx - 1]; print_source_loc(os, " | ", inliner_loc); if (snippet) { // Modify: Symbol '|' is not necessary // print_snippet(os, " | ", inliner_loc, colorize, Color::purple, inliner_context_size); print_snippet(os, " ", inliner_loc, colorize, Color::purple, inliner_context_size); } // Modify: Extra indent is not necessary // already_indented = false; } if (trace.source.filename.size()) { // Modify: Extra indent is not necessary // if (!already_indented) { os << " "; } // Modify: Adjust the indent // print_source_loc(os, " ", trace.source, trace.addr); print_source_loc(os, " ", trace.source, trace.addr); if (snippet) { // Modify: Adjust the indent // print_snippet(os, " ", trace.source, colorize, Color::yellow, trace_context_size); print_snippet(os, " ", trace.source, colorize, Color::yellow, trace_context_size); } } } void print_snippet(std::ostream& os, const char* indent, const ResolvedTrace::SourceLoc& source_loc, Colorize& colorize, Color::type color_code, int context_size) { using namespace std; typedef SnippetFactory::lines_t lines_t; lines_t lines = _snippets.get_snippet(source_loc.filename, source_loc.line, static_cast(context_size)); for (lines_t::const_iterator it = lines.begin(); it != lines.end(); ++it) { if (it->first == source_loc.line) { colorize.set_color(color_code); // Modify: Remove symbol '>' if there is only one line to show // os << indent << ">"; // } else { // os << indent << " "; // } // os << std::setw(4) << it->first << ": " << it->second << "\n"; if (lines.size() > 1) { os << indent << ">"; } else { os << indent << " "; } } else { os << indent << " "; } const auto pos = it->second.find_first_not_of(" \t"); os << std::setw(4) << it->second.substr(pos, it->second.size() - pos) << "\n"; if (it->first == source_loc.line) { colorize.set_color(Color::reset); } } } void print_source_loc(std::ostream& os, const char* indent, const ResolvedTrace::SourceLoc& source_loc, void* addr = nullptr) { // Modify: Remove indent and replace 'Source' to 'File' // os << indent << "Source \"" << source_loc.filename << "\", line " << source_loc.line << ", in // " os << " File \"" << source_loc.filename << "\", line " << source_loc.line << ", in " << source_loc.function; if (address && addr != nullptr) { os << " [" << addr << "]"; } os << "\n"; } }; /*************** SIGNALS HANDLING ***************/ #if defined(BACKWARD_SYSTEM_LINUX) || defined(BACKWARD_SYSTEM_DARWIN) class SignalHandling { public: static std::vector make_default_signals() { const int posix_signals[] = { // Signals for which the default action is "Core". 
SIGABRT, // Abort signal from abort(3) SIGBUS, // Bus error (bad memory access) SIGFPE, // Floating point exception SIGILL, // Illegal Instruction SIGIOT, // IOT trap. A synonym for SIGABRT SIGQUIT, // Quit from keyboard SIGSEGV, // Invalid memory reference SIGSYS, // Bad argument to routine (SVr4) SIGTRAP, // Trace/breakpoint trap SIGXCPU, // CPU time limit exceeded (4.2BSD) SIGXFSZ, // File size limit exceeded (4.2BSD) #if defined(BACKWARD_SYSTEM_DARWIN) SIGEMT, // emulation instruction executed #endif }; return std::vector(posix_signals, posix_signals + sizeof posix_signals / sizeof posix_signals[0]); } SignalHandling(const std::vector& posix_signals = make_default_signals()) : _loaded(false) { bool success = true; const size_t stack_size = 1024 * 1024 * 8; _stack_content.reset(static_cast(malloc(stack_size))); if (_stack_content) { stack_t ss; ss.ss_sp = _stack_content.get(); ss.ss_size = stack_size; ss.ss_flags = 0; if (sigaltstack(&ss, nullptr) < 0) { success = false; } } else { success = false; } for (size_t i = 0; i < posix_signals.size(); ++i) { struct sigaction action; memset(&action, 0, sizeof action); action.sa_flags = static_cast(SA_SIGINFO | SA_ONSTACK | SA_NODEFER | SA_RESETHAND); sigfillset(&action.sa_mask); sigdelset(&action.sa_mask, posix_signals[i]); #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wdisabled-macro-expansion" #endif action.sa_sigaction = &sig_handler; #if defined(__clang__) #pragma clang diagnostic pop #endif int r = sigaction(posix_signals[i], &action, nullptr); if (r < 0) success = false; } _loaded = success; } bool loaded() const { return _loaded; } static void handleSignal(int, siginfo_t* info, void* _ctx) { ucontext_t* uctx = static_cast(_ctx); StackTrace st; void* error_addr = nullptr; #ifdef REG_RIP // x86_64 error_addr = reinterpret_cast(uctx->uc_mcontext.gregs[REG_RIP]); #elif defined(REG_EIP) // x86_32 error_addr = reinterpret_cast(uctx->uc_mcontext.gregs[REG_EIP]); #elif defined(__arm__) error_addr = reinterpret_cast(uctx->uc_mcontext.arm_pc); #elif defined(__aarch64__) #if defined(__APPLE__) error_addr = reinterpret_cast(uctx->uc_mcontext->__ss.__pc); #else error_addr = reinterpret_cast(uctx->uc_mcontext.pc); #endif #elif defined(__mips__) error_addr = reinterpret_cast(reinterpret_cast(&uctx->uc_mcontext)->sc_pc); #elif defined(__ppc__) || defined(__powerpc) || defined(__powerpc__) || defined(__POWERPC__) error_addr = reinterpret_cast(uctx->uc_mcontext.regs->nip); #elif defined(__riscv) error_addr = reinterpret_cast(uctx->uc_mcontext.__gregs[REG_PC]); #elif defined(__s390x__) error_addr = reinterpret_cast(uctx->uc_mcontext.psw.addr); #elif defined(__APPLE__) && defined(__x86_64__) error_addr = reinterpret_cast(uctx->uc_mcontext->__ss.__rip); #elif defined(__APPLE__) error_addr = reinterpret_cast(uctx->uc_mcontext->__ss.__eip); #else #warning ":/ sorry, ain't know no nothing none not of your architecture!" 
#endif if (error_addr) { st.load_from(error_addr, 32, reinterpret_cast(uctx), info->si_addr); } else { st.load_here(32, reinterpret_cast(uctx), info->si_addr); } Printer printer; // Modify: Hide the address in stack when seg fault // printer.address = true; printer.print(st, stderr); #if (defined(_XOPEN_SOURCE) && _XOPEN_SOURCE >= 700) \ || (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200809L) psiginfo(info, nullptr); #else (void)info; #endif } private: details::handle _stack_content; bool _loaded; #ifdef __GNUC__ __attribute__((noreturn)) #endif static void sig_handler(int signo, siginfo_t* info, void* _ctx) { handleSignal(signo, info, _ctx); // try to forward the signal. raise(info->si_signo); // terminate the process immediately. puts("watf? exit"); _exit(EXIT_FAILURE); } }; #endif // BACKWARD_SYSTEM_LINUX || BACKWARD_SYSTEM_DARWIN #ifdef BACKWARD_SYSTEM_WINDOWS class SignalHandling { public: SignalHandling(const std::vector& = std::vector()) : reporter_thread_([]() { /* We handle crashes in a utility thread: backward structures and some Windows functions called here need stack space, which we do not have when we encounter a stack overflow. To support reporting stack traces during a stack overflow, we create a utility thread at startup, which waits until a crash happens or the program exits normally. */ { std::unique_lock lk(mtx()); cv().wait(lk, [] { return crashed() != crash_status::running; }); } if (crashed() == crash_status::crashed) { handle_stacktrace(skip_recs()); } { std::unique_lock lk(mtx()); crashed() = crash_status::ending; } cv().notify_one(); }) { SetUnhandledExceptionFilter(crash_handler); signal(SIGABRT, signal_handler); _set_abort_behavior(0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); std::set_terminate(&terminator); #ifndef BACKWARD_ATLEAST_CXX17 std::set_unexpected(&terminator); #endif _set_purecall_handler(&terminator); _set_invalid_parameter_handler(&invalid_parameter_handler); } bool loaded() const { return true; } ~SignalHandling() { { std::unique_lock lk(mtx()); crashed() = crash_status::normal_exit; } cv().notify_one(); reporter_thread_.join(); } private: static CONTEXT* ctx() { static CONTEXT data; return &data; } enum class crash_status { running, crashed, normal_exit, ending }; static crash_status& crashed() { static crash_status data; return data; } static std::mutex& mtx() { static std::mutex data; return data; } static std::condition_variable& cv() { static std::condition_variable data; return data; } static HANDLE& thread_handle() { static HANDLE handle; return handle; } std::thread reporter_thread_; // TODO: how not to hardcode these? 
static const constexpr int signal_skip_recs = #ifdef __clang__ // With clang, RtlCaptureContext also captures the stack frame of the // current function Below that, there are 3 internal Windows functions 4 #else // With MSVC cl, RtlCaptureContext misses the stack frame of the current // function The first entries during StackWalk are the 3 internal Windows // functions 3 #endif ; static int& skip_recs() { static int data; return data; } static inline void terminator() { crash_handler(signal_skip_recs); abort(); } static inline void signal_handler(int) { crash_handler(signal_skip_recs); abort(); } static inline void __cdecl invalid_parameter_handler(const wchar_t*, const wchar_t*, const wchar_t*, unsigned int, uintptr_t) { crash_handler(signal_skip_recs); abort(); } NOINLINE static LONG WINAPI crash_handler(EXCEPTION_POINTERS* info) { // The exception info supplies a trace from exactly where the issue was, // no need to skip records crash_handler(0, info->ContextRecord); return EXCEPTION_CONTINUE_SEARCH; } NOINLINE static void crash_handler(int skip, CONTEXT* ct = nullptr) { if (ct == nullptr) { RtlCaptureContext(ctx()); } else { memcpy(ctx(), ct, sizeof(CONTEXT)); } DuplicateHandle(GetCurrentProcess(), GetCurrentThread(), GetCurrentProcess(), &thread_handle(), 0, FALSE, DUPLICATE_SAME_ACCESS); skip_recs() = skip; { std::unique_lock lk(mtx()); crashed() = crash_status::crashed; } cv().notify_one(); { std::unique_lock lk(mtx()); cv().wait(lk, [] { return crashed() != crash_status::crashed; }); } } static void handle_stacktrace(int skip_frames = 0) { // printer creates the TraceResolver, which can supply us a machine type // for stack walking. Without this, StackTrace can only guess using some // macros. // StackTrace also requires that the PDBs are already loaded, which is done // in the constructor of TraceResolver Printer printer; StackTrace st; st.set_machine_type(printer.resolver().machine_type()); st.set_thread_handle(thread_handle()); st.load_here(32 + skip_frames, ctx()); st.skip_n_firsts(skip_frames); printer.address = true; printer.print(st, std::cerr); } }; #endif // BACKWARD_SYSTEM_WINDOWS #ifdef BACKWARD_SYSTEM_UNKNOWN class SignalHandling { public: SignalHandling(const std::vector& = std::vector()) {} bool init() { return false; } bool loaded() { return false; } }; #endif // BACKWARD_SYSTEM_UNKNOWN } // namespace backward #endif /* H_GUARD */ ================================================ FILE: oneflow/ir/.gitignore ================================================ /build* lit.site.cfg.py ================================================ FILE: oneflow/ir/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.13.4) include(FetchContent) # prevent LLVM_DEFINITIONS has a TRUE in it unset(result CACHE) set(CMAKE_INSTALL_MESSAGE LAZY) if(POLICY CMP0068) cmake_policy(SET CMP0068 NEW) set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) endif() if(POLICY CMP0075) cmake_policy(SET CMP0075 NEW) endif() if(POLICY CMP0077) cmake_policy(SET CMP0077 NEW) endif() if(POLICY CMP0116) cmake_policy(SET CMP0116 OLD) endif() project(oneflow-dialect LANGUAGES CXX C) # https://github.com/llvm/llvm-project/issues/55010 set(LLVM_ABI_BREAKING_CHECKS "FORCE_OFF" CACHE STRING "") if(LLVM_PROVIDER STREQUAL "in-tree") include(llvm-in-tree.cmake) elseif(LLVM_PROVIDER STREQUAL "install") include(install-llvm.cmake) else() message(FATAL_ERROR "LLVM_PROVIDER should be in-tree or install, but got: ${LLVM_PROVIDER}") endif() set_property(GLOBAL PROPERTY LLVM_INSTALL_DIR 
${LLVM_INSTALL_DIR}) set(MLIR_TABLEGEN_EXE mlir-tblgen) set(MLIR_PDLL_TABLEGEN_EXE mlir-pdll) include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS}) set(LLVM_INCLUDE_DIRS ${LLVM_INCLUDE_DIRS} PARENT_SCOPE) set(MLIR_INCLUDE_DIRS ${MLIR_INCLUDE_DIRS} PARENT_SCOPE) set(ONEFLOW_MLIR_SOURCE_INCLUDE_DIRS ${PROJECT_SOURCE_DIR}/include PARENT_SCOPE) set(ONEFLOW_MLIR_BINARY_INCLUDE_DIRS ${PROJECT_BINARY_DIR}/include PARENT_SCOPE) include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) if(LLVM_PROVIDER STREQUAL "in-tree") add_subdirectory(${CMAKE_SOURCE_DIR}/tools/oneflow-tblgen ${PROJECT_BINARY_DIR}/oneflow-tblgen) endif() function(update_rpath) set_property(TARGET ${ARGV0} APPEND PROPERTY BUILD_RPATH "${LLVM_LIBRARY_DIR}") set_property(TARGET ${ARGV0} APPEND PROPERTY BUILD_RPATH "${ONEFLOW_BUILD_ROOT_DIR}") set_property(TARGET ${ARGV0} APPEND PROPERTY INSTALL_RPATH "${LLVM_LIBRARY_DIR}") set_property(TARGET ${ARGV0} APPEND PROPERTY INSTALL_RPATH "${ONEFLOW_BUILD_ROOT_DIR}") endfunction(update_rpath) function(oneflow_add_mlir_library) add_mlir_library(${ARGV}) set_compile_options_to_oneflow_target(${ARGV0}) update_rpath(${ARGV0}) endfunction() function(oneflow_add_mlir_dialect_library) add_mlir_dialect_library(${ARGV}) set_compile_options_to_oneflow_target(${ARGV0}) update_rpath(${ARGV0}) endfunction() function(oneflow_add_llvm_tool) add_llvm_tool(${ARGV}) llvm_update_compile_flags(oneflow-runner) set_compile_options_to_oneflow_target(${ARGV0}) update_rpath(${ARGV0}) endfunction() find_package(Threads REQUIRED) set(LLVM_PTHREAD_LIB ${CMAKE_THREAD_LIBS_INIT}) set(LLVM_RUNTIME_OUTPUT_INTDIR ${PROJECT_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${PROJECT_BINARY_DIR}/lib) if(WITH_MLIR) add_subdirectory(include) add_subdirectory(lib) add_subdirectory(test) add_subdirectory(oneflow-translate) add_subdirectory(oneflow-runtime) add_subdirectory(oneflow-extension) add_subdirectory(oneflow-opt) add_subdirectory(oneflow-runner) add_subdirectory(oneflow-lite) endif(WITH_MLIR) if(BUILD_PYTHON) foreach(llvm_include_dir ${LLVM_INCLUDE_DIRS}) if(llvm_include_dir MATCHES "/include$") list(APPEND LLVM_INSTALL_INCLUDE_DIRS "${llvm_include_dir}//") else() list(APPEND LLVM_INSTALL_INCLUDE_DIRS "${llvm_include_dir}") endif() endforeach() install( DIRECTORY ${LLVM_INSTALL_INCLUDE_DIRS} DESTINATION ${ONEFLOW_INCLUDE_DIR} COMPONENT oneflow_py_include EXCLUDE_FROM_ALL FILES_MATCHING PATTERN llvm/ADT/ArrayRef.h PATTERN llvm/ADT/Hashing.h PATTERN llvm/ADT/iterator.h PATTERN llvm/ADT/None.h PATTERN llvm/ADT/SmallVector.h PATTERN llvm/ADT/STLExtras.h PATTERN llvm/ADT/STLFunctionalExtras.h PATTERN llvm/ADT/DenseMapInfo.h PATTERN llvm/ADT/identity.h PATTERN llvm/ADT/iterator_range.h PATTERN llvm/ADT/Optional.h PATTERN llvm/ADT/STLArrayExtras.h PATTERN llvm/ADT/STLForwardCompat.h PATTERN llvm/ADT/StringRef.h PATTERN llvm/ADT/bit.h PATTERN llvm/Config/abi-breaking.h PATTERN llvm/Config/llvm-config.h PATTERN llvm/Support/Compiler.h PATTERN llvm/Support/DataTypes.h PATTERN llvm/Support/ErrorHandling.h PATTERN llvm/Support/SwapByteOrder.h PATTERN llvm/Support/type_traits.h PATTERN llvm-c/DataTypes.h) endif() ================================================ FILE: oneflow/ir/README.md ================================================ # OneFlow IR OneFlow IR, a MLIR dialect ## Code style Inevitably, developers maintaining OneFlow IR would face these challenges: - Debugging 
components related to IR and the compiler can be complicated and peculiar. - IR subsystems should follow the latest changes of OneFlow and MLIR closely. To address these problems, within the IR source code directory, there are some rules that must be enforced for all the optimizers, importers, exporters, and runners: - separate library, include, target - MLIR-related code should follow the style and paradigm of MLIR and LLVM closely - ensure every component can be independently compiled and tested - there should be one `CMakeLists.txt` in every sub-directory - don't link anything from OneFlow unless it is necessary for the feature ## Major components - ### oneflow-translate Everything related to MLIR-OneFlow translation. [read more](oneflow-translate/README.md) - ### oneflow-opt Optimizations on the OneFlow MLIR dialect. A CLI to optimize .mlir files. [read more](oneflow-opt/README.md) - ### OneFlow dialect In the `include` and `lib` directories, there are definitions of the MLIR OneFlow dialect and its operators. - ### OneFlow Kernel Memory (OKM) dialect In the `include` and `lib` directories, there are definitions of the MLIR OKM dialect and its operators. OKM is a dialect that lets OneFlow use the MLIR memref style and use-def flow to optimize memory usage. - ### OneFlow Kernel Launch (OKL) dialect In the `include` and `lib` directories, there are definitions of the MLIR OKL dialect and its operators. OKL is a dialect that supports launching OneFlow kernel ops as an LLVM dialect callee. ## Parallel Signature - There is a parallel signature for OneFlow ops in MLIR. It is implemented as an MLIR dialect attribute. Some examples: - 1D SBP ```mlir %100 = "oneflow.relu"(%99) {parallel = #sbp.parallel<[#sbp.S<0>] -> [#sbp.S<0>]>, ... ``` - multiple inputs and outputs 1D SBP ```mlir %102 = "oneflow.add_n2"(%101, %97) {parallel = #sbp.parallel<[#sbp.S<0>, #sbp.S<0>] -> [#sbp.S<0>]>, ... ``` - 2D SBP `matmul` ```mlir %120 = "oneflow.matmul"(%119, %output_105) {parallel = #sbp.parallel<[[#sbp.S<0>, #sbp.P], #sbp.S<0>] -> [#sbp.S<0>]>, ... ``` - To avoid confusion and potential parsing errors, use the term "parallel" instead of "sbp"; conceptually and in the documentation they are the same. ### Principle - In IR, the signature should be orthogonal to device placement information, although in some passes they might be related to each other. ## Development - To run all the regression tests, use the command below. The `-j3` option for [`LIT`](https://llvm.org/docs/CommandGuide/lit.html) is to prevent OOM on GPU. 
```bash LIT_OPTS="-j3" cmake --build build -t c1 -j24 ``` ================================================ FILE: oneflow/ir/include/CMakeLists.txt ================================================ add_subdirectory(OneFlow) add_subdirectory(Transform) ================================================ FILE: oneflow/ir/include/OneFlow/CMakeLists.txt ================================================ # set(ONEFLOW_USER_OP_GEN_TD_PATH "${PROJECT_BINARY_DIR}/include/OneFlow") set(ONEFLOW_USER_OP_GEN_TD_PATH "${PROJECT_SOURCE_DIR}/include/OneFlow") set(LLVM_TARGET_DEFINITIONS OneFlowEnums.td) mlir_tablegen(OneFlowEnums.h.inc -gen-enum-decls) mlir_tablegen(OneFlowEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIROneFlowEnumsIncGen) set(LLVM_TARGET_DEFINITIONS OneFlowPatterns.td) set(ONEFLOW_OP_GROUPS_USED_IN_PATTERNS "SCALAR;UNARY;FUSED;MISC;BINARY;IDEMPOTENT;NORMALIZATION;MATMUL;BROADCAST;CONV;PADDING") foreach(OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS_USED_IN_PATTERNS) list(APPEND LLVM_TABLEGEN_FLAGS "-DGET_ONEFLOW_${OP_GROUP_NAME}_OP_DEFINITIONS") endforeach() mlir_tablegen(OneFlowPatterns.cpp.inc -gen-rewriters) add_public_tablegen_target(MLIROneFlowPatternsIncGen) # NOTE: seperate conversion and opt with --name if(WITH_MLIR_CUDA_CODEGEN) list(APPEND LLVM_TABLEGEN_FLAGS "-DWITH_MLIR_CUDA_CODEGEN") endif() set(LLVM_TARGET_DEFINITIONS OneFlowPasses.td) mlir_tablegen(OneFlowPasses.h.inc -gen-pass-decls) add_public_tablegen_target(MLIROneFlowPassIncGen) set(LLVM_TABLEGEN_FLAGS "") add_mlir_interface(OneFlowInterfaces) set(LLVM_TARGET_DEFINITIONS OneFlowOpGetGen.td) set(ONEFLOW_OP_GROUPS "ASSIGN;BINARY;BROADCAST;CONV;CROSS_ENTROPY;CUDA;DATASET;DETECTION;EAGER;FUSED;IDEMPOTENT;IDENTITY;IMAGE;INDICES;INVOLUTION;LOSS;MATH;MATMUL;MISC;NCCL;NORMALIZATION;OPTIMIZER;PADDING;PARALLEL_CAST;POOL;QUANTIZATION;REDUCE;RESHAPE;SCALAR;SOFTMAX;SUMMARY;TENSOR_BUFFER;TEST;TRIGONOMETRIC;UNARY;UPSAMPLE;ONE_EMBEDDING;LINEAR_ALGEBRA;SYSTEM;MLIR_JIT" ) foreach(OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS) message(STATUS "Enable OneFlow MLIR op group: ${OP_GROUP_NAME}") set(ONE_LLVM_TABLEGEN_FLAGS "-DGET_ONEFLOW_${OP_GROUP_NAME}_OP_DEFINITIONS") list(APPEND FULL_LLVM_TABLEGEN_FLAGS "${ONE_LLVM_TABLEGEN_FLAGS}") set(LLVM_TABLEGEN_FLAGS "${ONE_LLVM_TABLEGEN_FLAGS}") string(TOLOWER "${OP_GROUP_NAME}" OP_GROUP_NAME_LOWER) set(CPP_INC_FILE "OneFlow.${OP_GROUP_NAME_LOWER}_ops.cpp.inc") set(HEADER_INC_FILE "OneFlow.${OP_GROUP_NAME_LOWER}_ops.h.inc") mlir_tablegen(${CPP_INC_FILE} -gen-op-defs) mlir_tablegen(${HEADER_INC_FILE} -gen-op-decls) endforeach() add_public_tablegen_target(MLIROneFlowOpGroupDefsIncGen) set(LLVM_TABLEGEN_FLAGS "${FULL_LLVM_TABLEGEN_FLAGS}") mlir_tablegen(OneFlow.gen_ops.h.inc -gen-op-decls) add_public_tablegen_target(MLIROneFlowOpGroupDeclsIncGen) set(LLVM_TARGET_DEFINITIONS SBP/SBPOps.td) mlir_tablegen(SBPDialect.h.inc -gen-dialect-decls) mlir_tablegen(SBPDialect.cpp.inc -gen-dialect-defs) mlir_tablegen(SBPAttributes.h.inc -gen-attrdef-decls) mlir_tablegen(SBPAttributes.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(MLIRSBPIncGen) set(LLVM_TARGET_DEFINITIONS OKL/OKLOps.td) mlir_tablegen(OKLDialect.h.inc -gen-dialect-decls -dialect=okl) mlir_tablegen(OKLDialect.cpp.inc -gen-dialect-defs -dialect=okl) mlir_tablegen(OKLOps.h.inc -gen-op-decls) mlir_tablegen(OKLOps.cpp.inc -gen-op-defs) mlir_tablegen(OKLTypes.h.inc -gen-typedef-decls) mlir_tablegen(OKLTypes.cpp.inc -gen-typedef-defs) mlir_tablegen(OKLPasses.h.inc -gen-pass-decls) mlir_tablegen(OKLEnums.h.inc -gen-enum-decls) 
mlir_tablegen(OKLEnums.cpp.inc -gen-enum-defs) mlir_tablegen(OKLAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=okl) mlir_tablegen(OKLAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=okl) add_public_tablegen_target(MLIROKLIncGen) set(LLVM_TARGET_DEFINITIONS OKM/OKMOps.td) mlir_tablegen(OKMDialect.h.inc -gen-dialect-decls -dialect=okm) mlir_tablegen(OKMDialect.cpp.inc -gen-dialect-defs -dialect=okm) mlir_tablegen(OKMOps.h.inc -gen-op-decls) mlir_tablegen(OKMOps.cpp.inc -gen-op-defs) mlir_tablegen(OKMPasses.h.inc -gen-pass-decls) mlir_tablegen(OKMAttributes.h.inc -gen-attrdef-decls) mlir_tablegen(OKMAttributes.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(MLIROKMIncGen) set(LLVM_TABLEGEN_FLAGS "") add_mlir_dialect( OneFlowOps oneflow DEPENDS MLIRSBPIncGen MLIROneFlowEnumsIncGen MLIROneFlowPatternsIncGen MLIROneFlowPassIncGen MLIROneFlowInterfacesIncGen MLIROneFlowOpGroupDefsIncGen MLIROneFlowOpGroupDeclsIncGen) ================================================ FILE: oneflow/ir/include/OneFlow/Conversion/NVVMToCubin.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_NVVMTOCUBIN_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_NVVMTOCUBIN_H_ #ifdef WITH_MLIR_CUDA_CODEGEN #include "mlir/Pass/Pass.h" namespace mlir { namespace gpu { inline std::string getCubinAnnotation() { return "gpu.binary"; } } // namespace gpu namespace oneflow { const char* getArchVersion(); std::unique_ptr createNVVMToCubinPass(); void InitializeLLVMNVPTXBackend(); } // namespace oneflow } // namespace mlir #endif // WITH_MLIR_CUDA_CODEGEN #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_NVVMTOCUBIN_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Conversion/OneFlowToTosa.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_ONEFLOWTOTOSA_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_ONEFLOWTOTOSA_H_ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createLowerOneFlowToTosaPass(); std::unique_ptr createLowerOneFlowToLinalgPass(); std::unique_ptr createConvertToSignlessForTosaPass(); std::unique_ptr createCastOneFlowOpsToSignlessPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_ONEFLOWTOTOSA_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Extension.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_EXTENSION_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_EXTENSION_H_ #include #include namespace oneflow { using SharedLibs = std::unordered_set; SharedLibs* MutSharedLibPaths(); const SharedLibs* SharedLibPaths(); } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_EXTENSION_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Conversion/Conversion.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_CONVERSION_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_CONVERSION_H_ #include "OneFlow/OKL/Conversion/OKLToLLVM.h" #include "mlir/IR/BuiltinOps.h" namespace mlir { namespace okl { // convert okl dialect to llvm dialect LogicalResult LowerOKLComputeToLLVM(ModuleOp module); } // namespace okl } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_CONVERSION_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Conversion/OKLToLLVM.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_OKLTOLLVM_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_OKLTOLLVM_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace okl { // lower !okl.launcher_ctx to !llvm.ptr std::unique_ptr createLowerLauncherToLLVMPtrPass(); // lower okl ops to llvm.call @{callee in liboneflow.so} std::unique_ptr createLowerOKLToLLVMCallPass(); // tag {okl.cuda_graph_support} according to its wrapped ops std::unique_ptr createTagCudaGraphSupportPass(); namespace cuda_graph_support { static const auto TAG_NAME = "cuda_graph_support"; } // namespace cuda_graph_support } // namespace okl } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_OKLTOLLVM_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/ComputeContext.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_COMPUTECONTEXT_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_COMPUTECONTEXT_H_ #include "mlir/IR/BuiltinAttributes.h" #include "OneFlow/OKL/Kernel/RegContext.h" #include "OneFlow/OKL/Kernel/TmpBufferManager.h" namespace oneflow { namespace okl { class ComputeContext final : public user_op::KernelComputeContext { public: ComputeContext(RegContext const* reg_ctx, user_op::KernelComputeContext* comp_ctx) : reg_ctx_(reg_ctx), comp_ctx_(comp_ctx), tmp_buffer_(comp_ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)) {} ~ComputeContext() = default; const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return reg_ctx_->TensorDesc4ArgNameAndIndex(arg_name, index); } ep::Stream* stream() override { return comp_ctx_->stream(); } DeviceType device_type() const override { return reg_ctx_->device_type(); } const ParallelContext& parallel_ctx() const override { return comp_ctx_->parallel_ctx(); } const ArgVec& inputs() const override { return reg_ctx_->inputs(); } const ArgVec& outputs() const override { return reg_ctx_->outputs(); } const user_op::UserOpConfWrapper& user_op_conf() const override { return reg_ctx_->user_op_conf(); } user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override; private: RegContext const* reg_ctx_; KernelComputeContext* comp_ctx_; TmpBufferManager tmp_buffer_; std::unordered_map tensor_{}; user_op::Tensor* CreateTensorWithArgNameAndIndex(const std::string& arg_name, int32_t index); const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return user_op_conf().Attr4Name(attr_name); } }; } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_COMPUTECONTEXT_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/InferContext.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_INFERCONTEXT_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_INFERCONTEXT_H_ #include "oneflow/core/kernel/kernel_context.h" #include "oneflow/core/kernel/user_kernel.h" #include "OneFlow/OKL/Kernel/RegContext.h" namespace oneflow { namespace okl { class InferContext final : public user_op::InferContext { public: explicit InferContext(RegContext const* reg_ctx); const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { return *LogicalTensorDesc4ArgNameAndIndex(arg_name, index); } const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, int32_t index) const override { return *LogicalTensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* MutOutputTensorDesc(const std::string&, int32_t) override { TODO(); } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override; const Shape& InputShape(const std::string& arg_name, int32_t index) const override; const Shape& OutputShape(const std::string&, int32_t) const override { TODO(); } void SetOutputShape(const std::string&, int32_t, const Shape&) override { TODO(); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override; void SetShape4ArgNameAndIndex(const std::string&, int32_t, const Shape&) override { TODO(); } const Stride& InputStride(const std::string&, int32_t) const override { TODO(); } const Stride& OutputStride(const std::string&, int32_t) const override { TODO(); } void SetOutputStride(const std::string&, int32_t, const Stride&) override { TODO(); } const Stride& Stride4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); } void SetStride4ArgNameAndIndex(const std::string&, int32_t, const Stride&) override { TODO(); } DataType InputDType(const std::string&, int32_t) const override { TODO(); } DataType OutputDType(const std::string&, int32_t) const override { TODO(); } void SetOutputDType(const std::string&, int32_t, DataType) override { TODO(); } DataType Dtype4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); } void SetDtype4ArgNameAndIndex(const std::string&, int32_t, DataType) override { TODO(); } MemoryFormat InputMemoryFormat(const std::string&, int32_t) const override { TODO(); } MemoryFormat OutputMemoryFormat(const std::string&, int32_t) const override { TODO(); } void SetOutputMemoryFormat(const std::string&, int32_t, MemoryFormat) override { TODO(); } MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); } void SetMemoryFormat4ArgNameAndIndex(const std::string&, int32_t, MemoryFormat) override { TODO(); } const std::vector>& inputs() const override { return reg_ctx_->inputs(); } const std::vector>& outputs() const override { return reg_ctx_->outputs(); } const std::string& input(const std::string& arg_name, int32_t index) const override { return reg_ctx_->user_op_conf().input(arg_name, index); } const std::string& 
output(const std::string& arg_name, int32_t index) const override { return reg_ctx_->user_op_conf().output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const override { return reg_ctx_->user_op_conf().has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const override { return reg_ctx_->user_op_conf().has_input(arg_name, index); } int32_t input_size(const std::string& arg_name) const override { return reg_ctx_->user_op_conf().input_size(arg_name); } int32_t output_size(const std::string& arg_name) const override { return reg_ctx_->user_op_conf().output_size(arg_name); } const std::string& op_name() const override { return reg_ctx_->user_op_conf().op_name(); } const std::string& op_type_name() const override { return reg_ctx_->user_op_conf().op_type_name(); } const std::string& op_loc() const override { TODO(); } const ParallelContext& parallel_ctx() const override { TODO(); } const ParallelDesc& parallel_desc() const override { TODO(); } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); } const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); } bool InputIsDynamic(const std::string&, int32_t) const override { TODO(); } bool OutputIsDynamic(const std::string&, int32_t) const override { TODO(); } void SetOutputIsDynamic(const std::string&, int32_t, bool) override { TODO(); } bool IsDynamic4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); } void SetIsDynamic4ArgNameAndIndex(const std::string&, int32_t, bool) override { TODO(); } int64_t parallel_num() const override { TODO(); } private: const std::shared_ptr& Attr4Name( const std::string& attr_name) const override; RegContext const* reg_ctx_; }; } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_INFERCONTEXT_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/InitContext.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_CACHECONTEXT_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_CACHECONTEXT_H_ #include "OneFlow/OKL/Kernel/RegContext.h" #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/op_kernel.h" namespace oneflow { namespace okl { class InitContext final : public user_op::KernelCacheContext, public user_op::KernelInitContext { public: InitContext(RegContext const* reg_ctx, user_op::KernelComputeContext* compute_ctx) : reg_ctx_(reg_ctx), compute_ctx_(compute_ctx) {} DeviceType device_type() const override { return reg_ctx_->device_type(); } const ParallelContext& parallel_ctx() const override { return compute_ctx_->parallel_ctx(); } ep::Stream* stream() override { return compute_ctx_->stream(); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return reg_ctx_->TensorDesc4ArgNameAndIndex(arg_name, index); } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return reg_ctx_->TensorDesc4ArgNameAndIndex(arg_name, index); } const ParallelDesc& parallel_desc() const override { TODO(); } const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); } const std::vector>& inputs() const override { return reg_ctx_->inputs(); } const std::vector>& outputs() const override { return reg_ctx_->outputs(); } private: RegContext const* reg_ctx_; user_op::KernelComputeContext* compute_ctx_; const user_op::UserOpConfWrapper& user_op_conf() const override { return reg_ctx_->user_op_conf(); } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return reg_ctx_->Attr4Name(attr_name); } }; } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_CACHECONTEXT_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/JITEngine.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITENGINE_H_ #define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITENGINE_H_ #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/IR/BuiltinOps.h" #include "oneflow/core/framework/op_kernel.h" #include "OneFlow/OKL/Kernel/LauncherContext.h" extern "C" { void okl_llvm_func(void* launcher, int64_t index); } // extern "C" namespace oneflow { namespace okl { using LLVMLaunchArgs = std::tuple; class JITEngine { public: explicit JITEngine(mlir::ModuleOp module); void Run(const std::string& name, LauncherContext* launcher) const { auto error = engine_->invoke(name, launcher); CHECK(!error) << "fail to invoke jit engine, error: " << llvm::toString(std::move(error)); } private: std::unique_ptr engine_; }; namespace llvm_func { #define C_FUNC_NAME(func) #func const auto LLVM_FUNC = C_FUNC_NAME(okl_llvm_func); #undef C_FUNC_NAME } // namespace llvm_func } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITENGINE_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/JITOpInfer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_ #define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_ #include "oneflow/core/framework/infer_util.h" namespace oneflow { namespace ir { namespace jit { Maybe InferTensorDesc(user_op::InferContext* ctx); Maybe SetTensorDataType(user_op::InferContext* ctx); } // namespace jit } // namespace ir } // namespace oneflow #endif // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/LauncherContext.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_LAUNCHER_CONTEXT_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_LAUNCHER_CONTEXT_H_ #include "oneflow/core/framework/op_kernel.h" #include "OneFlow/OKL/OKLOps.h" #include "OneFlow/OKL/Kernel/RegContext.h" #include "OneFlow/OKL/Kernel/WrapperContext.h" #include "mlir/IR/Operation.h" namespace oneflow { namespace okl { class LauncherContext final { public: // compile the mlir to ctx explicit LauncherContext(mlir::ModuleOp module); // infer ctx with okl info bool Infer() { return inferred_; } bool Infer(user_op::KernelComputeContext* compute_context); // launch kernel with index void Launch(int index); private: bool inferred_ = false; std::vector compile_ctx_vec_; std::vector run_ctx_vec_; }; } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_LAUNCHER_CONTEXT_H_ 
================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/LauncherState.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_OP_KERNEL_STATE_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_OP_KERNEL_STATE_H_ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OKL/Kernel/JITEngine.h" #include "OneFlow/OKL/Kernel/LauncherContext.h" #include "OneFlow/OKL/Conversion/Conversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" namespace oneflow { namespace okl { inline mlir::DialectRegistry GetRegistry() { mlir::DialectRegistry registry; registry.insert(); mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); return registry; } class LauncherState final : public user_op::OpKernelState { public: explicit LauncherState(user_op::KernelInitContext* ctx); ~LauncherState() = default; void DoCompute(user_op::KernelComputeContext* ctx); bool IsCudaGraphSupported(user_op::KernelInitContext* ctx); private: // manage module(compile) mlir::MLIRContext mlir_ctx_; mlir::OwningOpRef module_; // manage context LauncherContext launcher_context_; // manage engine(runtime) JITEngine engine_; }; } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_OP_KERNEL_STATE_H_ 
================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/README.md ================================================ ## Context-related concepts and their lifetimes ### LauncherState LauncherState is a subclass of OpKernelState; it is created when the okl kernel initializes its kernel state. ``` c++ LauncherState final : public user_op::OpKernelState ``` Each LauncherState owns a LauncherContext that manages the execution context and a JIT Engine that manages the runtime engine. It manages the resources of a single okl kernel: - maintainer of the LauncherContext, responsible for updating the context information; - owner of the JIT Engine ### LauncherContext LauncherContext manages the context of a single okl kernel. It maintains the context resources of an ordered set of oneflow ops; the context resources of each oneflow op are maintained as a whole by a dedicated WrapperContext. A LauncherContext therefore holds a series of compile-time WrapperContexts and run-time WrapperContexts covering the contexts of the different phases. These ctxs correspond one-to-one to the oneflow ops. ``` class LauncherContext final { bool inferred_ = false; std::vector compile_ctx_vec_; std::vector run_ctx_vec_; }; ``` ### WrapperContext(op, ctx): The manager of a single oneflow op wrapped by okl. What exists at compile time cannot be changed after initialization; at run time a lazily evaluated infer procedure is required. 1. before inference - reg_ctx_(op) - device - inputs/outputs - kernel - user config 2. after inference - init_ctx_(reg_ctx_, ctx) - state_(reg_ctx, init_ctx_) - cache_(reg_ctx, init_ctx_) - compute_ctx_(ctx) ``` class CompileTimeWrapperContext { std::shared_ptr reg_ctx_; }; class RunTimeWrapperContext : public CompileTimeWrapperContext { std::shared_ptr compute_ctx_; std::shared_ptr init_ctx_; std::shared_ptr kernel_state_; std::shared_ptr kernel_cache_; }; ``` CompileTimeWrapperContext keeps the context information obtained from the IR and wraps it into the reg ctx, which later serves as one of the inputs for deriving the RunTimeWrapperContext. RunTimeWrapperContext combines the information from the CompileTimeWrapperContext with the comp ctx, tmp buffer, and other resources created by the okl kernel to form the actual context environment needed for the run-time computation of a single op. Through the created init ctx, it creates the kernel state, kernel cache, and other resources used by the kernel's compute. 
================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/RegContext.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_REGCONTEXT_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_REGCONTEXT_H_ #include "oneflow/core/framework/user_op_kernel_registry.h" #include "OneFlow/UserOpReflection.h" #include "mlir/IR/Operation.h" namespace oneflow { namespace okl { // this context should support querying information about the kernel from representation in MLIR using ArgVec = std::vector>; class RegContext final : public user_op::KernelRegContext { public: explicit RegContext(mlir::Operation* op); ~RegContext() = default; // override user_op KernelRegContext DeviceType device_type() const override; const ParallelContext& parallel_ctx() const override; const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override; const ArgVec& inputs() const override; const ArgVec& outputs() const override; const user_op::UserOpConfWrapper& user_op_conf() const override; const std::shared_ptr& Attr4Name( const std::string& attr_name) const override; const size_t GetTmpBufferSize() const; ::mlir::Operation* GetOp() const { return op_; }; const user_op::OpKernel* GetKernel() const { return kernel_; }; private: ::mlir::Operation* op_; DeviceType device_type_ = DeviceType::kInvalidDevice; std::unordered_map arg2tensor_desc_{}; ArgVec inputs_; ArgVec outputs_; user_op::UserOpConfWrapper conf_wrapper_; const user_op::OpKernelRegistryResult* reg_res_ = nullptr; const user_op::OpKernel* kernel_ = nullptr; }; } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_REGCONTEXT_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/TmpBufferManager.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_TMP_BUFFER_MANAGER_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_TMP_BUFFER_MANAGER_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/user_op_tensor.h" #include namespace oneflow { namespace okl { class TmpBufferManager { class PoolToTensor final : public oneflow::user_op::Tensor { public: explicit PoolToTensor(user_op::Tensor* tensor, const user_op::TensorDesc* tensor_desc, int64_t offset) : tensor_(tensor), raw_dptr_(reinterpret_cast(tensor_->mut_raw_dptr()) + offset), tensor_desc_(tensor_desc) {} ShapeView shape_view() const override { return tensor_desc_->shape(); } const Stride& stride() const override { return tensor_desc_->stride(); } DataType data_type() const override { return tensor_desc_->data_type(); } MemoryFormat memory_format() const override { return tensor_desc_->memory_format(); } MutShapeView mut_shape_view() override { TODO(); } const MemoryCase& mem_case() const override { return tensor_->mem_case(); } const void* raw_dptr() const override { return raw_dptr_; } void* mut_raw_dptr() override { return raw_dptr_; } private: user_op::Tensor* tensor_; void* raw_dptr_; const user_op::TensorDesc* tensor_desc_; }; class PoolToBuffer final : public oneflow::user_op::Tensor { public: explicit PoolToBuffer(user_op::Tensor* tensor, int64_t size, int64_t offset) : tensor_(tensor), raw_dptr_(reinterpret_cast(tensor_->mut_raw_dptr()) + offset), shape_({size}) {} ShapeView shape_view() const override { return shape_; } const Stride& stride() const override { return tensor_->stride(); } DataType data_type() const override { return tensor_->data_type(); } MemoryFormat memory_format() const override { return tensor_->memory_format(); } MutShapeView mut_shape_view() override { return shape_; } const MemoryCase& mem_case() const override { return tensor_->mem_case(); } const void* raw_dptr() const override { return raw_dptr_; } void* mut_raw_dptr() override { return raw_dptr_; } private: user_op::Tensor* tensor_; void* raw_dptr_; Shape shape_; }; public: static size_t InferTmpSize(user_op::InferContext* ctx); explicit TmpBufferManager(user_op::Tensor* tensor) : tensor_(tensor) {} user_op::Tensor* GetPoolTensor(const user_op::TensorDesc* tensor_desc, int64_t offset) { CHECK_LE(offset + tensor_desc->shape().elem_cnt() * GetSizeOfDataType(tensor_desc->data_type()), tensor_->shape_view().elem_cnt()); auto res = tensor_map_.insert({tensor_desc, PoolToTensor(tensor_, tensor_desc, offset)}).first; return &res->second; } user_op::Tensor* GetPoolBuffer(int64_t size, int64_t offset) { auto res = buffer_map_.insert({{size, offset}, PoolToBuffer(tensor_, size, offset)}).first; return &res->second; } private: std::unordered_map tensor_map_{}; std::unordered_map, PoolToBuffer> buffer_map_{}; user_op::Tensor* tensor_; }; } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_TMP_BUFFER_MANAGER_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/Kernel/WrapperContext.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_WRAPPERCONTEXT_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_WRAPPERCONTEXT_H_ #include "mlir/IR/BuiltinAttributes.h" #include "OneFlow/OKL/Kernel/InitContext.h" #include "OneFlow/OKL/Kernel/RegContext.h" #include "OneFlow/OKL/Kernel/ComputeContext.h" #include "oneflow/core/framework/op_kernel.h" namespace oneflow { namespace okl { class CompileTimeWrapperContext { public: explicit CompileTimeWrapperContext(mlir::Operation* op) : reg_ctx_(std::make_shared(op)) {} CompileTimeWrapperContext(CompileTimeWrapperContext&&) = default; RegContext const* GetRegContext() const { return reg_ctx_.get(); } private: std::shared_ptr reg_ctx_; }; class RunTimeWrapperContext { public: RunTimeWrapperContext(mlir::Operation* op, user_op::KernelComputeContext* ctx) : compile_time_wrapper_ctx_(op), compute_ctx_(std::make_unique(GetRegContext(), ctx)), init_ctx_(std::make_unique(GetRegContext(), ctx)), kernel_state_(GetRegContext()->GetKernel()->CreateOpKernelState(init_ctx_.get())), kernel_cache_(GetRegContext()->GetKernel()->InitOpKernelCache(init_ctx_.get())) {} void Run() { GetRegContext()->GetKernel()->Compute(compute_ctx_.get(), kernel_state_.get(), kernel_cache_.get()); } RegContext const* GetRegContext() const { return compile_time_wrapper_ctx_.GetRegContext(); } private: CompileTimeWrapperContext compile_time_wrapper_ctx_; std::unique_ptr compute_ctx_; std::unique_ptr init_ctx_; std::shared_ptr kernel_state_; std::shared_ptr kernel_cache_; }; } // namespace okl } // namespace oneflow #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_WRAPPERCONTEXT_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLAttributes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES_H_ #define ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES_H_ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "OneFlow/OKLEnums.h.inc" #define GET_ATTRDEF_CLASSES #include "OneFlow/OKLAttributes.h.inc" #endif // ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLAttributes.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES #define ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES include "OneFlow/OKL/OKLDialect.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" #endif // ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLBase.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_OKL_OKLBASE #define ONEFLOW_IR_INCLUDE_OKL_OKLBASE include "OneFlow/OKL/OKLDialect.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" class OKL_Op traits = []> : Op; class OKL_Type traits = []> : TypeDef { let mnemonic = typeMnemonic; } #endif // ONEFLOW_IR_INCLUDE_OKL_OKLBASE ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLDialect.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT_H_ #define ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT_H_ #include "OneFlow/Passes.h" #include "mlir/IR/Dialect.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "OneFlow/OKLDialect.h.inc" #include "OneFlow/OKL/OKLOps.h" #endif // ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLDialect.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT #define ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT include "mlir/IR/OpBase.td" def OKL_Dialect : Dialect { let name = "okl"; let summary = "OneFlow Kernel Launch Dialect."; let description = [{ This dialect is the IR of abstract represent of OneFlow Kernel Launch Op. }]; let cppNamespace = "::mlir::okl"; let dependentDialects = [ "func::FuncDialect" ]; let useDefaultTypePrinterParser = 1; } #endif // ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLOps.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLOPS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLOPS_H_ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Builders.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "OneFlow/OKL/OKLTypes.h" #include "OneFlow/OKL/OKLAttributes.h" namespace mlir { namespace func { class FuncOp; } // namespace func } // namespace mlir #define GET_OP_CLASSES #include "OneFlow/OKLOps.h.inc" #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLOPS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLOps.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_OKL_OKLOPS #define ONEFLOW_IR_INCLUDE_OKL_OKLOPS include "OneFlow/OKL/OKLDialect.td" include "OneFlow/OKL/OKLBase.td" include "OneFlow/OKL/OKLTypes.td" include "mlir/Pass/PassBase.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/IR/OpBase.td" include "mlir/IR/EnumAttr.td" def GetTensorFromArgOp : OKL_Op<"get_tensor_from_arg"> { let summary = "get tensor as arguments from operands of context"; let description = [{ get tensor value from launcher context as arguments. 
}]; let arguments = (ins LauncherContextType:$launcher_ctx, I32Attr:$index ); let results = (outs AnyTensor); } def GetTensorFromRetOp : OKL_Op<"get_tensor_from_ret"> { let summary = "get tensor as arguments from results of context"; let description = [{ get tensor value from launcher context as arguments. }]; let arguments = (ins LauncherContextType:$launcher_ctx, I32Attr:$index ); let results = (outs AnyTensor); } def GetTensorAsRetOp : OKL_Op<"get_tensor_as_ret"> { let summary = "get tensor as outcomes from results of context"; let description = [{ get tensor value from launcher context as outcomes. }]; let arguments = (ins LauncherContextType:$launcher_ctx, AnyTensor:$tensor, I32Attr:$index ); let results = (outs AnyTensor); } def PoolToTensorOp : OKL_Op<"pool_to_tensor"> { let arguments = (ins LauncherContextType:$launcher_ctx, I64Attr:$offset ); let results = (outs AnyTensor); } def PoolToBufferOp : OKL_Op<"pool_to_buffer"> { let arguments = (ins LauncherContextType:$launcher_ctx, I64Attr:$offset ); let results = (outs AnyTensor); } def TensorToPoolOp : OKL_Op<"tensor_to_pool"> { let arguments = (ins LauncherContextType:$launcher_ctx, AnyTensor:$tensor, I64Attr:$offset ); let results = (outs AnyTensor); } def WrapperKernelOp : OKL_Op<"wrapper_kernel"> { let summary = "build reg context operation"; let description = [{ this context is generated from module op and used on kernel/run_ctx build phase. each wrapped op has their own reg_ctx with their own attrs. }]; let arguments = (ins I32Attr:$index ); let regions = (region AnyRegion:$body); } def ReturnOp : OKL_Op<"return", [HasParent<"WrapperKernelOp">, Terminator]> { let summary = "return operation"; let description = [{ return oneflow ops in reg context ``` }]; let arguments = (ins Variadic:$operands); let builders = [ OpBuilder<(ins), [{ build($_builder, $_state, llvm::None); }]>]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } def LowerLauncherToLLVMPtrPass : Pass<"lower-launcher-to-llvm-ptr", "ModuleOp"> { let summary = "convert okl dialect func to llvm dialect"; let constructor = "mlir::okl::createLowerLauncherToLLVMPtrPass()"; } def LowerOKLToLLVMCallPass : Pass<"lower-okl-to-llvm-call", "ModuleOp"> { let summary = "convert okl dialect ops to llvm dialect llvm.call"; let constructor = "mlir::okl::createLowerOKLToLLVMCallPass()"; } def TagCudaGraphSupportPass : Pass<"tag-cuda-graph-support", "ModuleOp"> { let summary = "tag cuda graph support according to its wrapped ops"; let constructor = "mlir::okl::createTagCudaGraphSupportPass()"; } #endif // ONEFLOW_IR_INCLUDE_OKL_OKLOPS ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLTypes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES_H_ #include "mlir/IR/Types.h" #define GET_TYPEDEF_CLASSES #include "OneFlow/OKLTypes.h.inc" #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKL/OKLTypes.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES include "OneFlow/OKL/OKLBase.td" include "mlir/IR/AttrTypeBase.td" def LauncherContextType : OKL_Type<"LauncherContext", "launcher_ctx">; #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES ================================================ FILE: oneflow/ir/include/OneFlow/OKL/passes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_PASSES_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_PASSES_H_ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "OneFlow/OKL/Conversion/OKLToLLVM.h" namespace mlir { namespace okl { #define GEN_PASS_CLASSES #define GEN_PASS_REGISTRATION #include "OneFlow/OKLPasses.h.inc" } // namespace okl } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_PASSES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKM/Conversion/Conversion.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_CONVERSION_CONVERSION_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_CONVERSION_CONVERSION_H_ #include "mlir/IR/BuiltinOps.h" namespace mlir { namespace okm { LogicalResult LowerWrapOpsToOKL(ModuleOp module); } } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_CONVERSION_CONVERSION_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKM/OKMAttributes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES_H_ #include "mlir/IR/BuiltinAttributes.h" #define GET_ATTRDEF_CLASSES #include "OneFlow/OKMAttributes.h.inc" #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OKM/OKMAttributes.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES include "OneFlow/OKM/OKMBase.td" #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES ================================================ FILE: oneflow/ir/include/OneFlow/OKM/OKMBase.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMBASE #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMBASE include "OneFlow/OKM/OKMDialect.td" include "mlir/IR/AttrTypeBase.td" include "mlir/Pass/PassBase.td" class OKM_Op traits = []> : Op; class OKM_Attr traits = []> : AttrDef { let mnemonic = attrMnemonic; } #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMBASE ================================================ FILE: oneflow/ir/include/OneFlow/OKM/OKMDialect.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT_H_
#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT_H_

#include "mlir/IR/Dialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "OneFlow/OKMDialect.h.inc"
#include "OneFlow/OKM/OKMOps.h"

#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT_H_


================================================
FILE: oneflow/ir/include/OneFlow/OKM/OKMDialect.td
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT
#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT

include "mlir/IR/OpBase.td"

def OKM_Dialect : Dialect {
  let name = "okm";
  let summary = "OneFlow Kernel Memory Dialect.";
  let description = [{
    This dialect is the abstract IR of the OneFlow Kernel Launch Op.
  }];
  let cppNamespace = "::mlir::okm";
  let dependentDialects = [
    "func::FuncDialect",
    "memref::MemRefDialect"
  ];
}

#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT


================================================
FILE: oneflow/ir/include/OneFlow/OKM/OKMOps.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS_H_
#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS_H_

#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Builders.h"

#include "OneFlow/OKM/OKMAttributes.h"

namespace mlir {
namespace func {
class FuncOp;
} // namespace func
} // namespace mlir

#define GET_OP_CLASSES
#include "OneFlow/OKMOps.h.inc"

#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS_H_


================================================
FILE: oneflow/ir/include/OneFlow/OKM/OKMOps.td
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS
#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS

include "OneFlow/OKM/OKMAttributes.td"
include "OneFlow/OKM/OKMPasses.td"
include "mlir/IR/OpBase.td"

def ArgToTensorOp : OKM_Op<"arg_to_tensor"> {
  let arguments = (ins
    I32Attr:$index
  );
  let results = (outs AnyTensor);
}

def ArgToMemrefOp : OKM_Op<"arg_to_memref"> {
  let arguments = (ins
    I32Attr:$index
  );
  let results = (outs AnyMemRef);
}

def RetToMemrefOp : OKM_Op<"ret_to_memref"> {
  let arguments = (ins
    I32Attr:$index
  );
  let results = (outs AnyMemRef);
}

def AllocMemrefOp : OKM_Op<"alloc_memref"> {
  let results = (outs AnyMemRef);
}

def PlanMemrefOp : OKM_Op<"plan_memref"> {
  let results = (outs AnyMemRef);
}

def TensorToRetOp : OKM_Op<"tensor_to_ret"> {
  let arguments = (ins
    AnyTensor:$tensor,
    I32Attr:$index
  );
  let results = (outs AnyTensor);
}

def MemrefToRetOp : OKM_Op<"memref_to_ret"> {
  let arguments = (ins
    AnyMemRef:$tensor,
    I32Attr:$index
  );
  let results = (outs AnyMemRef);
}

def WrapperOp : OKM_Op<"wrapper_kernel"> {
  let arguments = (ins
    Variadic<AnyType>:$operands
  );
  let results = (outs Variadic<AnyType>);
  let regions = (region AnyRegion:$body);
}

def ReturnOp : OKM_Op<"return", [HasParent<"WrapperOp">, Terminator]> {
  let arguments = (ins Variadic<AnyType>:$operands);
  let builders = [
    OpBuilder<(ins), [{ build($_builder, $_state, llvm::None); }]>];
  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS


================================================
FILE: oneflow/ir/include/OneFlow/OKM/OKMPasses.td
================================================
#ifndef ONEFLOW_OKM_PASSES
#define ONEFLOW_OKM_PASSES

include "OneFlow/OKM/OKMBase.td"

def ExtractOKMTensorPass : Pass<"extract-okm-tensor", "ModuleOp"> {
  let summary = "extract okm tensors from args and rets";
  let constructor = "mlir::okm::createExtractOKMTensorPass()";
}

def WrapOKMKernelPass : Pass<"wrap-okm-kernel", "ModuleOp"> {
  let summary = "wrap kernel in okm";
  let constructor = "mlir::okm::createWrapOKMKernelPass()";
}

def OptOKMMemrefPass : Pass<"opt-okm-memref", "ModuleOp"> {
  let summary = "optimize okm memref";
  let constructor = "mlir::okm::createOptOKMMemrefPass()";
}

def ConvertOKMToOKLPass : Pass<"convert-okm-to-okl", "ModuleOp"> {
  let summary = "convert okm to okl";
  let constructor = "mlir::okm::createConvertOKMToOKLPass()";
}

#endif // ONEFLOW_OKM_PASSES


================================================
FILE: oneflow/ir/include/OneFlow/OKM/passes.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_PASSES_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_PASSES_H_ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" namespace mlir { namespace okm { namespace func_name { extern const std::string GRAPH_NAME; extern const std::string MEM_GRAPH_NAME; extern const std::string WRAP_GRAPH_NAME; extern const std::string OPT_GRAPH_NAME; extern const std::string OKL_GRAPH_NAME; extern const std::string OKL_POOL_SIZE_TAG; } // namespace func_name std::unique_ptr createExtractOKMTensorPass(); std::unique_ptr createWrapOKMKernelPass(); std::unique_ptr createOptOKMMemrefPass(); std::unique_ptr createConvertOKMToOKLPass(); #define GEN_PASS_CLASSES #define GEN_PASS_REGISTRATION #include "OneFlow/OKMPasses.h.inc" } // namespace okm } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_PASSES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowBase.td ================================================ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWBASE_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWBASE_H_ include "OneFlow/OneFlowDialect.td" include "OneFlow/OneFlowInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "/mlir/Interfaces/InferTypeOpInterface.td" def OneFlow_InvalidElement: TypeDef { let mnemonic = "invalid_element"; } def OneFlow_CharElement: TypeDef { let mnemonic = "char_element"; } def OneFlow_TensorBufferElement: TypeDef { let mnemonic = "tensor_buffer_element"; } def OneFlow_OFRecordElement: TypeDef { let mnemonic = "of_record_element"; } def OneFlow_OFRecordTensor : TensorOf<[OneFlow_OFRecordElement]>; def OneFlow_TensorBufferTensor : TensorOf<[OneFlow_TensorBufferElement]>; def OneFlow_Tensor : TensorOf<[AnyType]>; def SI32ArrayAttr : TypedArrayAttrBase {} def SI64ArrayAttr : TypedArrayAttrBase {} def ShapeAttr : TypedArrayAttrBase {} def DTArrayAttr : TypedArrayAttrBase {} def ShapeArrayAttr : TypedArrayAttrBase {} def ComplexDoubleAttr : TypedArrayAttrBase {} def BytesAttr : StringBasedAttr()">, "bytes attribute">; def OneFlow_IsOpConfCompatible : NativeOpTrait<"IsOpConfCompatible">; def OneFlow_IsImportCompatible : NativeOpTrait<"IsImportCompatible">; def OneFlow_AlternativeOp : NativeOpTrait<"IsAlternative">; def OneFlow_TensorSource : NativeOpTrait<"TensorSource">; def OneFlow_OnlyExistsInIR : NativeOpTrait<"OnlyExistsInIR">; def OneFlow_ElementwiseOp : NativeOpTrait<"IsElementwise">; class OneFlow_IROp traits = []> : Op {} class OneFlow_BaseOp traits = []> : Op { dag op_conf_attrs = (ins StrAttr:$op_name, StrAttr:$device_tag, StrArrayAttr:$device_name, // TODO: change device_name to dict and parse the literal fmt like "0:0-0" OptionalAttr:$scope_symbol_id, OptionalAttr:$hierarchy ); dag attrs = (ins); dag trait_attrs = (ins); dag user_op_attrs = (ins); dag input = (ins Optional:$UserSourceOpTickInput ); dag output = (outs); dag ctrl_input = (ins); dag ctrl_output = (outs); let arguments = !con( input, ctrl_input, op_conf_attrs, trait_attrs, user_op_attrs, attrs ); let results = !con( output, ctrl_output ); int same_output_regst_num = -1; bit has_check_fn = 0; bit has_logical_tensor_desc_infer_fn = 0; bit has_physical_tensor_desc_infer_fn = 0; bit has_get_sbp_fn = 0; bit has_sbp_signature_infer_fn = 0; bit has_data_type_infer_fn = 0; bit has_device_and_stream_infer_fn = 0; bit has_input_arg_modify_fn = 0; bit has_output_arg_modify_fn = 0; bit has_output_blob_time_shape_infer_fn = 0; bit has_nd_sbp_infer_fn = 0; bit has_compute_complexity_fn 
= 0; bit has_get_nd_sbp_fn = 0; bit has_enumerate_nd_sbp_signatures_fn = 0; bit has_dump_nd_sbp_signature_for_op_conf_fn = 0; } class OneFlow_Op traits = []> : OneFlow_BaseOp])> { let ctrl_input = (ins Variadic:$ctrl_inputs); let ctrl_output = (outs Optional:$ctrl_output); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes, DenseI32ArrayAttr:$result_segment_sizes ); } class OneFlow_UserBaseOp traits = [OneFlow_AlternativeOp]> : OneFlow_BaseOp { let summary = ""; let user_op_attrs = (ins StrAttr:$op_type_name, // NOTE: vector types must have positive constant sizes, so we can't use I32ElementsAttr I32ArrayAttr:$input_sizes, I32ArrayAttr:$output_sizes ); } // Why don't we merge ctrl in/out and data in/out into operand_segment/result_segment_sizes? // 1. We only need to erase operand_segment/result_segment_sizes when we are creating a concrete user op // 2. Isolating data and ctrl make debug easier and produced IR more human-readable class OneFlow_UserBaseWithCtrlOp traits = []> : OneFlow_UserBaseOp])> { let summary = ""; let ctrl_input = (ins Variadic:$ctrl_inputs); let ctrl_output = (outs Optional:$ctrl_output); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes, DenseI32ArrayAttr:$result_segment_sizes ); } class OneFlow_ConvolutionBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow convolution operation"; let description = [{ "The convolution operator consumes an input tensor and a filter, and" "computes the output." }]; let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$weight, Optional:$bias, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$out); let attrs = (ins DefaultValuedAttr:$filters, SI32ArrayAttr:$padding_before, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate, DefaultValuedAttr:$groups, DefaultValuedAttr:$tuning_cache ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } class OneFlow_TFPoolBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow pooling operation, align with TensorFlow"; let input = (ins AnyType:$x); let output = (outs AnyType:$y); let attrs = (ins StrAttr:$padding, SI32ArrayAttr:$padding_before, SI32ArrayAttr:$padding_after, StrAttr:$data_format, SI32ArrayAttr:$pool_size, SI32ArrayAttr:$strides, BoolAttr:$ceil_mode ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } class OneFlow_TFPoolGradBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow pooling grad operation, align with TensorFlow"; let input = (ins AnyType:$x, AnyType:$y, AnyType:$dy ); let output = (outs AnyType:$dx); let attrs = (ins StrAttr:$padding, SI32ArrayAttr:$padding_before, SI32ArrayAttr:$padding_after, StrAttr:$data_format, SI32ArrayAttr:$pool_size, SI32ArrayAttr:$strides, BoolAttr:$ceil_mode ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } class OneFlow_MaxPoolBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow Max Pooling operation"; let input = (ins AnyType:$x ); let output = (outs AnyType:$y, AnyType:$indice ); let attrs = (ins SI32ArrayAttr:$padding, StrAttr:$data_format, 
SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$stride, SI32ArrayAttr:$dilation, DefaultValuedAttr:$return_indices, DefaultValuedAttr:$ceil_mode ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } class OneFlow_MaxUnpoolBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow Max Unpooling operation"; let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64ArrayAttr:$kernel_size, SI64ArrayAttr:$stride, SI64ArrayAttr:$padding, DefaultValuedAttr:$has_output_size, ShapeAttr:$output_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } class OneFlow_AvgPoolBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow Average Pooling operation"; let input = (ins AnyType:$x ); let output = (outs AnyType:$y ); let attrs = (ins SI32ArrayAttr:$padding, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$stride, DefaultValuedAttr:$ceil_mode, DefaultValuedAttr:$count_include_pad, DefaultValuedAttr:$divisor_override ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } class OneFlow_MaxPoolGradBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow Max Pooling Grad operation"; let input = (ins AnyType:$x, AnyType:$indice, AnyType:$dy ); let output = (outs AnyType:$dx ); let attrs = (ins SI32ArrayAttr:$padding, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$stride, SI32ArrayAttr:$dilation, DefaultValuedAttr:$return_indices, DefaultValuedAttr:$ceil_mode ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } class OneFlow_AvgPoolGradBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow Average Pooling Grad operation"; let input = (ins AnyType:$x, AnyType:$dy ); let output = (outs AnyType:$dx ); let attrs = (ins SI32ArrayAttr:$padding, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$stride, DefaultValuedAttr:$ceil_mode, DefaultValuedAttr:$count_include_pad, DefaultValuedAttr:$divisor_override ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } class OneFlow_MaxUnpoolGradBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow Max Unpooling Grad operation"; let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$indices, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } class OneFlow_AdaptivePoolBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow adaptive pool operation"; let input = (ins AnyType:$x ); let output = (outs AnyType:$y); let attrs = (ins StrAttr:$data_format, SI64ArrayAttr:$output_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } class OneFlow_AdaptivePoolGradBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow adaptive pool operation"; let input = (ins AnyType:$x, AnyType:$dy 
); let output = (outs AnyType:$dx); let attrs = (ins StrAttr:$data_format, SI64ArrayAttr:$output_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } class OneFlow_UnaryBaseOp traits = []> : OneFlow_BaseOp { let summary = ""; let input = (ins AnyType:$x); let output = (outs AnyType:$y); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } class OneFlow_AdaptiveMaxPoolBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow adaptive max pool operation"; let input = (ins AnyType:$x ); let output = (outs AnyType:$y, AnyType:$index ); let attrs = (ins StrAttr:$data_format, SI64ArrayAttr:$output_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } class OneFlow_AdaptiveMaxPoolGradBaseOp traits = []> : OneFlow_BaseOp])> { let summary = "OneFlow adaptive max pool grad operation"; let input = (ins AnyType:$dy, AnyType:$x, AnyType:$index ); let output = (outs AnyType:$dx); let attrs = (ins StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Idempotent : NativeOpTrait<"IsIdempotentOfIdenticalPlacement">; class OneFlow_IdempotentBaseOp traits = []> : OneFlow_UnaryBaseOp {} def OneFlow_Involution : NativeOpTrait<"IsInvolutionOfIdenticalPlacement">; class OneFlow_InvolutionBaseOp traits = []> : OneFlow_UnaryBaseOp {} #define GET_ONEFLOW_BASE_OP_DEFINITIONS include "OneFlow/OneFlowUserOps.td" #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWBASE_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowDataTypeConversion.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWDATATYPECONVERSION_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWDATATYPECONVERSION_H_ #include "mlir/IR/Builders.h" #include "OneFlow/OneFlowSupport.h" namespace mlir { namespace oneflow { Type getTypeFromOneFlowDataType(MLIRContext* context, ::oneflow::DataType dt); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWDATATYPECONVERSION_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowDialect.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_ONEFLOWDIALECT_H #define ONEFLOW_ONEFLOWDIALECT_H #include "mlir/IR/Dialect.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "OneFlow/SBP/SBPDialect.h" #include "OneFlow/OneFlowOpsDialect.h.inc" #endif // ONEFLOW_ONEFLOWDIALECT_H ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowDialect.td ================================================ #ifndef ONEFLOW_DIALECT #define ONEFLOW_DIALECT include "mlir/IR/OpBase.td" def OneFlow_Dialect : Dialect { let name = "oneflow"; let summary = "OneFlow MLIR dialect."; let description = [{ This dialect is the IR of OneFlow. }]; let cppNamespace = "::mlir::oneflow"; let dependentDialects = [ "sbp::SBPDialect", "func::FuncDialect" ]; let hasConstantMaterializer = 1; let useDefaultTypePrinterParser = 1; } #endif // ONEFLOW_DIALECT ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowEnums.td ================================================ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWENUMS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWENUMS_H_ include "mlir/IR/OpBase.td" include "mlir/IR/EnumAttr.td" def OneFlow_InvalidDataType : I32EnumAttrCase<"DT_InvalidDataType", 0>; def OneFlow_Char : I32EnumAttrCase<"DT_Char", 1>; def OneFlow_Float : I32EnumAttrCase<"DT_Float", 2>; def OneFlow_Double : I32EnumAttrCase<"DT_Double", 3>; def OneFlow_Int8 : I32EnumAttrCase<"DT_Int8", 4>; def OneFlow_Int32 : I32EnumAttrCase<"DT_Int32", 5>; def OneFlow_Int64 : I32EnumAttrCase<"DT_Int64", 6>; def OneFlow_UInt8 : I32EnumAttrCase<"DT_UInt8", 7>; def OneFlow_OFRecord : I32EnumAttrCase<"DT_OFRecord", 8>; def OneFlow_Float16 : I32EnumAttrCase<"DT_Float16", 9>; def OneFlow_TensorBuffer: I32EnumAttrCase<"DT_TensorBuffer", 10>; def OneFlow_BFloat16: I32EnumAttrCase<"DT_BFloat16", 11>; def OneFlow_Bool: I32EnumAttrCase<"DT_Bool", 12>; def OneFlow_DataType: I32EnumAttr<"DataType", "OneFlow Data Type enum", [ OneFlow_InvalidDataType, OneFlow_Char, OneFlow_Float, OneFlow_Double, OneFlow_Int8, OneFlow_Int32, OneFlow_Int64, OneFlow_UInt8, OneFlow_OFRecord, OneFlow_Float16, OneFlow_TensorBuffer, OneFlow_BFloat16, OneFlow_Bool ] > { let cppNamespace = "::mlir::oneflow"; let stringToSymbolFnName = "ConvertToEnum"; let symbolToStringFnName = "ConvertToString"; } def OneFlow_Contiguous : I32EnumAttrCase<"MF_Contiguous", 0>; def OneFlow_ChannelsLast : I32EnumAttrCase<"MF_ChannelsLast", 1>; def OneFlow_Preserve : I32EnumAttrCase<"MF_Preserve", 2>; def OneFlow_MemoryFormatCount : I32EnumAttrCase<"MF_MemoryFormatCount", 3>; def OneFlow_MemoryFormat: I32EnumAttr<"MemoryFormat", "OneFlow Memory Format enum", [ OneFlow_Contiguous, OneFlow_ChannelsLast, OneFlow_Preserve, OneFlow_MemoryFormatCount ] > { let cppNamespace = "::mlir::oneflow"; let stringToSymbolFnName = "ConvertToMemoryFormat"; let symbolToStringFnName = "ConvertMemoryFormatToString"; } #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWENUMS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowInterfaces.td ================================================ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWINTERFACES_H_ #define 
ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWINTERFACES_H_ include "mlir/IR/OpBase.td" def UserOpCompatibleInterface : OpInterface<"UserOpCompatible"> { let description = [{ Interface to getting the hard-coded bn }]; let methods = [ StaticInterfaceMethod<"", "const std::vector*", "inputKeys", (ins), [{ static std::vector val(mlir::oneflow::support::GetInputKeys(ConcreteOp::getOperationName().split('.').second.str())); return &val; }]>, StaticInterfaceMethod<"", "const std::vector*", "outputKeys", (ins), [{ static std::vector val(mlir::oneflow::support::GetOutputKeys(ConcreteOp::getOperationName().split('.').second.str())); return &val; }]>, InterfaceMethod<"", "std::pair", "getODSOperandIndexAndLength", (ins "unsigned":$index), [{ return $_op.getODSOperandIndexAndLength(index); }]>, InterfaceMethod<"", "std::pair", "getODSResultIndexAndLength", (ins "unsigned":$index), [{ return $_op.getODSResultIndexAndLength(index); }]> ]; let cppNamespace = "::mlir::oneflow"; } def AlternativeOpTypeNameInterface : OpInterface<"HasAlternativeOpTypeName"> { let description = [{ Interface to getting control edges }]; let methods = [ StaticInterfaceMethod<"", "std::string", "getOriginalOpTypeName", (ins) >, StaticInterfaceMethod<"", "const std::vector*", "inputKeys", (ins), [{ static std::vector val(mlir::oneflow::support::GetInputKeys(ConcreteOp::getOriginalOpTypeName())); return &val; }]>, StaticInterfaceMethod<"", "const std::vector*", "outputKeys", (ins), [{ static std::vector val(mlir::oneflow::support::GetOutputKeys(ConcreteOp::getOriginalOpTypeName())); return &val; }]>, ]; let cppNamespace = "::mlir::oneflow"; } def ControlEdgeCompatibleInterface : OpInterface<"ControlEdgeCompatible"> { let description = [{ Interface to getting control edges }]; let methods = [ InterfaceMethod<"", "::mlir::OperandRange", "dataInputOperands", (ins) >, InterfaceMethod<"", "::mlir::OperandRange", "ctrlInputOperands", (ins) >, InterfaceMethod<"", "::mlir::ResultRange", "dataOutputResults", (ins) >, InterfaceMethod<"", "::mlir::Value", "ctrlOutputResult", (ins) > ]; let cppNamespace = "::mlir::oneflow"; } def NoGrad : OpInterface<"NoGrad"> { let description = [{ }]; let cppNamespace = "::mlir::oneflow"; } def SupportNonContiguous : OpInterface<"SupportNonContiguous"> { let description = [{ }]; let cppNamespace = "::mlir::oneflow"; } def CpuOnly : OpInterface<"CpuOnly"> { let description = [{ }]; let cppNamespace = "::mlir::oneflow"; } def NCHWCompatibleInterface : OpInterface<"NCHWCompatible"> { let description = [{ Interface of NCHW compatibility }]; let methods = [ InterfaceMethod<"", "bool", "IsNCHW", (ins) >, InterfaceMethod<"Create NHWC op and return the new op's results to be transposed", "llvm::SmallVector", "NchwToNhwc", (ins "llvm::SmallVector": $transposed_inputs, "PatternRewriter&": $rewriter) >, InterfaceMethod<"", "llvm::DenseSet", "OperandsToTranspose", (ins) >, InterfaceMethod<"", "llvm::DenseSet", "ResultsToTranspose", (ins) >, ]; let cppNamespace = "::mlir::oneflow"; } def BiasAddCompatibleInterface : OpInterface<"BiasAddCompatible"> { let description = [{ Interface of ops used as bias add }]; let methods = [ InterfaceMethod<"", "bool", "isLastDim", (ins) >, InterfaceMethod<"", "mlir::Value", "biasAddGetBias", (ins) >, InterfaceMethod<"", "mlir::Value", "biasAddGetOut", (ins) >, ]; let cppNamespace = "::mlir::oneflow"; } def MatMulCompatibleInterface : OpInterface<"MatMulCompatible"> { let description = [{ Interface of ops used as matmul }]; let methods = [ InterfaceMethod<"is this a transpose_a=false, 
transpose_b=true matmul", "bool", "isLinear", (ins) >, InterfaceMethod<"", "mlir::Value", "matMulGetX", (ins) >, InterfaceMethod<"", "mlir::Value", "matMulGetW", (ins) >, InterfaceMethod<"", "mlir::Value", "matMulGetY", (ins) >, ]; let cppNamespace = "::mlir::oneflow"; } #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWINTERFACES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowOpGetGen.td ================================================ include "OneFlow/OneFlowDialect.td" include "OneFlow/OneFlowEnums.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "OneFlow/OneFlowBase.td" ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowOpTraits.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "oneflow/core/operator/op_conf.pb.h" namespace mlir { namespace OpTrait { namespace impl { OpFoldResult foldIdempotentOfIdenticalPlacement(Operation* op); OpFoldResult foldInvolutionOfIdenticalPlacement(Operation* op); LogicalResult VerifyIsOpConfCompatible(Operation* op); LogicalResult VerifyIsImportCompatible(Operation* op); // trait IsOpConfCompatible LogicalResult saveAttrToOpConf(Operation* op, ::oneflow::OperatorConf* op_conf); LogicalResult saveAttrsToNamedAttrList(Operation* op, NamedAttrList& named_attr_list); StringAttr getOpName(Operation* op); StringAttr getDeviceTag(Operation* op); ArrayAttr getDeviceName(Operation* op); IntegerAttr getScopeSymbolID(Operation* op); ArrayAttr getHierarchy(Operation* op); } // namespace impl template class IsOpConfCompatible : public TraitBase { public: static StringRef getOpNameAttr() { return "op_name"; } static StringRef getDeviceTagAttr() { return "device_tag"; } static StringRef getDeviceNameAttr() { return "device_name"; } static StringRef getScopeSymbolIDAttr() { return "scope_symbol_id"; } static StringRef getHierarchyAttr() { return "hierarchy"; } static LogicalResult verifyTrait(Operation* op) { return impl::VerifyIsOpConfCompatible(op); } static LogicalResult dump_attr(Operation* op, ::oneflow::OperatorConf* op_conf) { return impl::saveAttrToOpConf(op, op_conf); } static LogicalResult saveToNamedAttrList(Operation* op, NamedAttrList& named_attr_list) { return impl::saveAttrsToNamedAttrList(op, named_attr_list); } static StringAttr getOpName(Operation* op) { return impl::getOpName(op); } static StringAttr getDeviceTag(Operation* op) { return impl::getDeviceTag(op); } static ArrayAttr getDeviceName(Operation* op) { return impl::getDeviceName(op); } static IntegerAttr getScopeSymbolID(Operation* op) { return impl::getScopeSymbolID(op); } static ArrayAttr getHierarchy(Operation* op) { return impl::getHierarchy(op); } }; template class IsImportCompatible : public TraitBase { public: static StringRef getOutputLBNsAttr() { 
return "output_lbns"; } static LogicalResult verifyTrait(Operation* op) { return impl::VerifyIsImportCompatible(op); } }; template class IsIdempotentOfIdenticalPlacement : public TraitBase { public: static LogicalResult verifyTrait(Operation* op) { static_assert(ConcreteType::template hasTrait(), "expected operation to produce one result"); static_assert(ConcreteType::template hasTrait(), "expected operation to take one operand"); static_assert(ConcreteType::template hasTrait(), "expected operation to preserve type"); static_assert(ConcreteType::template hasTrait(), "expected operation to be op conf compatible"); return impl::verifyIsIdempotent(op); } static OpFoldResult foldTrait(Operation* op, ArrayRef operands) { return impl::foldIdempotentOfIdenticalPlacement(op); } }; template class IsInvolutionOfIdenticalPlacement : public TraitBase { public: static LogicalResult verifyTrait(Operation* op) { static_assert(ConcreteType::template hasTrait(), "expected operation to produce one result"); static_assert(ConcreteType::template hasTrait(), "expected operation to take one operand"); static_assert(ConcreteType::template hasTrait(), "expected operation to preserve type"); static_assert(ConcreteType::template hasTrait(), "expected operation to be op conf compatible"); return impl::verifyIsInvolution(op); } static OpFoldResult foldTrait(Operation* op, ArrayRef operands) { return impl::foldInvolutionOfIdenticalPlacement(op); } }; template class IsAlternative : public TraitBase { public: static StringRef getOpTypeNameAttr() { return "op_type_name"; } static LogicalResult verifyTrait(Operation* op) { if (op->hasAttrOfType(getOpTypeNameAttr())) { return success(); } else { return op->emitError("expected operation to have attribute: " + getOpTypeNameAttr()); } } }; template class TensorSource : public TraitBase { public: static StringRef getShapeAttrName() { return "shape"; } static StringRef getDataTypeAttrName() { return "data_type"; } static StringRef getIsDynamicAttrName() { return "is_dynamic"; } static StringRef getNdSbpAttrName() { return "nd_sbp"; } static StringRef getSbpAttrName() { return "parallel"; } static LogicalResult verifyTrait(Operation* op) { if (!op->hasAttrOfType(getShapeAttrName())) { return op->emitError("expected operation to have attribute: " + getShapeAttrName()); } if (!op->hasAttrOfType(getDataTypeAttrName())) { return op->emitError("expected operation to have attribute: " + getDataTypeAttrName()); } return success(); } }; template class OnlyExistsInIR : public TraitBase {}; template class IsElementwise : public TraitBase {}; } // namespace OpTrait } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowOps.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPS_H_ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Builders.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/IR/PatternMatch.h" #include "OneFlow/OneFlowSupport.h" #include "OneFlow/OneFlowInterfaces.h.inc" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/SBP/SBPAttributes.h" namespace mlir { namespace func { class FuncOp; } // namespace func } // namespace mlir #define GET_OP_CLASSES #include "OneFlow/OneFlowOps.h.inc" #define GET_OP_CLASSES #include "OneFlow/OneFlow.gen_ops.h.inc" namespace mlir { namespace oneflow { template inline std::string GetOpTypeName(T op) { std::string op_type_name = op->getName().stripDialect().str(); if (op->template hasTrait()) { op_type_name = op->template getAttrOfType(OpTrait::IsAlternative::getOpTypeNameAttr()) .str(); } if (auto alternative_name = dyn_cast(op)) { op_type_name = alternative_name.getOriginalOpTypeName(); } if (auto user_op = dyn_cast(op)) { op_type_name = user_op.getOpTypeName().str(); } return op_type_name; } ResultRange GetDataOutputResults(Operation* op); OperandRange GetDataInputOperands(Operation* op); llvm::Optional GetCtrlIntputOperands(Operation* op); llvm::Optional GetCtrlOutputResult(Operation* op); ArrayAttr getSI32ArrayAttr(::mlir::PatternRewriter& rewriter, ArrayRef values); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowOps.td ================================================ #ifndef ONEFLOW_OPS #define ONEFLOW_OPS include "OneFlow/OneFlowDialect.td" include "OneFlow/OneFlowEnums.td" include "OneFlow/OneFlowInterfaces.td" include "OneFlow/OneFlowBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Pass/PassBase.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" include "OneFlow/SBP/SBPOps.td" #ifndef REMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS def OneFlow_UserOp : OneFlow_UserBaseWithCtrlOp<"user", [OneFlow_IsImportCompatible]> { let summary = ""; let input = (ins Variadic:$data_input); let output = (outs Variadic:$data_output); let attrs = (ins StrArrayAttr:$output_lbns ); let hasCanonicalizer = 1; } def OneFlow_ConfOp : OneFlow_BaseOp<"conf", [OneFlow_IsImportCompatible]> { let summary = "This op is mainly used by create its adaptor in importing/exporting"; } def OneFlow_SystemOp : OneFlow_Op<"system", [OneFlow_IsImportCompatible]> { let summary = ""; let input = (ins Variadic:$data_input); let output = (outs Variadic:$data_output); let attrs = (ins StrArrayAttr:$input_bns, StrArrayAttr:$output_lbns, I32Attr:$op_type_case ); let hasCanonicalizer = 1; } def F32ElementsAttr : FloatElementsAttr<32>; def OneFlow_FrozenVariableOp : OneFlow_IROp<"variable_ir", [ConstantLike, NoMemoryEffect]> { let summary = "Auxiliary variable op for constant folding, only exists in IR."; let arguments = (ins F32ElementsAttr:$value, StrAttr:$op_name, OptionalAttr:$data_type, StrAttr:$device_tag, StrArrayAttr:$device_name, // TODO: change device_name to dict and parse the literal 
fmt like "0:0-0" OptionalAttr:$scope_symbol_id, OptionalAttr:$hierarchy, StrArrayAttr:$nd_sbp ); let results = (outs AnyType:$output ); let hasFolder = 1; } def OneFlow_Add2Op : OneFlow_BaseOp<"add_n2", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = ""; let input = (ins AnyType:$in0, AnyType:$in1 ); let output = (outs AnyType:$out); } class OneFlow_ConcreteSystemOp traits = []> : OneFlow_BaseOp])> { let input = (ins); let output = (ins); let ctrl_input = (ins Variadic:$ctrl_inputs); let ctrl_output = (outs Optional:$ctrl_output); dag required_attrs = (ins StrArrayAttr:$output_lbns); dag custom_attrs = (ins); let attrs = !con( required_attrs, custom_attrs ); let hasCanonicalizer = 1; } def OneFlow_VariableOp : OneFlow_ConcreteSystemOp<"variable", [OneFlow_TensorSource]> { let summary = ""; let input = (ins); let output = (outs AnyType:$output); let custom_attrs = (ins ShapeAttr:$shape, OptionalAttr:$data_type, OptionalAttr:$model_name, OptionalAttr:$l1_regularization, OptionalAttr:$l2_regularization, OptionalAttr:$trainable, OptionalAttr:$float_initializer, OptionalAttr:$integer_initializer, OptionalAttr:$parallel ); } def OneFlow_InputOp : OneFlow_ConcreteSystemOp<"input", [OneFlow_TensorSource]> { let summary = ""; let input = (ins AnyType:$input); let output = (outs AnyType:$output); let custom_attrs = (ins OptionalAttr:$shape, OptionalAttr:$data_type, OptionalAttr:$is_dynamic, OptionalAttr:$nd_sbp, OptionalAttr:$job_name ); let builders = [ OpBuilder<(ins "::oneflow::OperatorConf":$op_conf )> ]; } def OneFlow_OutputOp : OneFlow_ConcreteSystemOp<"output", [OneFlow_TensorSource]> { let summary = ""; let input = (ins AnyType:$input); let output = (outs AnyType:$output); let custom_attrs = (ins OptionalAttr:$shape, OptionalAttr:$data_type, OptionalAttr:$is_dynamic, OptionalAttr:$nd_sbp, OptionalAttr:$job_name ); } def OneFlow_Job : Op { let regions = (region AnyRegion:$body); let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, OptionalAttr:$sym_visibility, OptionalAttr:$arg_attrs, OptionalAttr:$res_attrs ); let builders = [OpBuilder<(ins "StringRef":$sym_name, "FunctionType":$function_type, CArg<"ArrayRef", "{}">:$attrs )>]; let extraClassDeclaration = [{ bool isDeclaration() { return isExternal(); } ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } ArrayRef getResultTypes() { return getFunctionType().getResults(); } LogicalResult verifyType() { auto type = getFunctionTypeAttr().getValue(); if (!type.isa()) return emitOpError("requires '" + getFunctionTypeAttrName().str() + "' attribute of function type"); return success(); } }]; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } def OneFlow_ReturnOp : Op, MemRefsNormalizable, ReturnLike, Terminator]> { let summary = "return operation"; let description = [{ The "return" operation represents a return operation within a Job. The operation takes an optional tensor operand and produces no results. The operand type must match the signature of the job function that contains the operation. For example: ```mlir job @foo() -> tensor<2xf64> { ... 
oneflow.return %0 : tensor<2xf64> } ``` }]; let arguments = (ins Variadic:$operands); let builders = [ OpBuilder<(ins), [{ build($_builder, $_state, llvm::None); }]>]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } def OneFlow_NormalizationInferenceOp : OneFlow_NormalizationBaseOp<"normalization_infer", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$y ); } #endif // REMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS #endif // ONEFLOW_OPS ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowPDLLPatterns.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWPDLLPATTERNS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWPDLLPATTERNS_H_ #include "mlir/IR/PatternMatch.h" namespace mlir { namespace oneflow { void populateAllocEliminationPatterns(RewritePatternSet& patterns); void populateForwardOpPatterns(RewritePatternSet& patterns); void populateNormalizationOpPatterns(RewritePatternSet& patterns); void populateFuseConv2DBatchNormPattern(RewritePatternSet& patterns); void populateFuseOpsWithBackwardImplPattern(RewritePatternSet& patterns); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWPDLLPATTERNS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowPasses.td ================================================ #ifndef ONEFLOW_PASSES #define ONEFLOW_PASSES include "OneFlow/OneFlowOps.td" #ifdef WITH_MLIR_CUDA_CODEGEN def NVVMToCubinPass : InterfacePass<"nvvm-to-cubin", "SymbolOpInterface"> { let summary = "convert nvvm ir to cubin"; let constructor = "mlir::oneflow::createNVVMToCubinPass()"; let options = [ Option<"triple", "triple", "StringRef", "\"nvptx64-nvidia-cuda\"", "Target triple">, Option<"chip", "chip", "StringRef", "mlir::oneflow::getArchVersion()", "Target architecture">, Option<"features", "features", "StringRef", "\"+ptx60\"", "Target features">, ]; } #endif // WITH_MLIR_CUDA_CODEGEN def TestOneFlowTraitFolderPass : Pass<"test-oneflow-trait-folder", "func::FuncOp"> { let constructor = "mlir::oneflow::createTestOneFlowTraitFolderPass()"; } def LowerOneFlowToTosaPass : Pass<"lower-oneflow-to-tosa", "ModuleOp"> { let summary = "lower oneflow dialect to tosa dialect"; let constructor = "mlir::oneflow::createLowerOneFlowToTosaPass()"; let dependentDialects = ["tosa::TosaDialect", "memref::MemRefDialect", "mlir::func::FuncDialect"]; let options = [ Option<"variableAsConstant", "variable-as-constant", "int", "0", "convert variable op as const op of tosa">, Option<"fullyConvert", "full", "bool", /*default=*/"true", "Fully convert operations and make OneFlow dialect illegal target">, Option<"lowerJob", "lower-job", "bool", /*default=*/"true", "Convert oneflow.job to func.func">, ]; } def LowerOneFlowToLinalgPass : Pass<"lower-oneflow-to-linalg", "ModuleOp"> { let 
summary = "lower oneflow dialect to Linalg dialect"; let constructor = "mlir::oneflow::createLowerOneFlowToLinalgPass()"; let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect", "mlir::func::FuncDialect"]; } def FuncToOneFlowJobPass : Pass<"func-to-ofjob", "ModuleOp"> { let summary = "convert func Ops to oneflow Ops"; let constructor = "mlir::oneflow::createFuncToOneFlowJobPass()"; let dependentDialects = ["mlir::func::FuncDialect"]; } def OneFlowJobToFuncPass : Pass<"ofjob-to-func", "ModuleOp"> { let summary = "convert oneflow Ops to func Ops"; let constructor = "mlir::oneflow::createOneFlowJobToFuncPass()"; let dependentDialects = ["mlir::func::FuncDialect"]; } def CastOneFlowOpsToSignlessPass : Pass<"cast-ofops-to-signless", "ModuleOp"> { let summary = "cast oneflow ops to singless"; let constructor = "mlir::oneflow::createCastOneFlowOpsToSignlessPass()"; let dependentDialects = ["mlir::func::FuncDialect", "mlir::BuiltinDialect"]; } def BufferHostRegisterPass : Pass<"buffer-host-register", "func::FuncOp"> { let summary = ""; let constructor = "mlir::oneflow::createBufferHostRegisterPass()"; let dependentDialects = ["gpu::GPUDialect"]; } def GpuCopyArgPass : Pass<"gpu-copy-arg", "func::FuncOp"> { let summary = ""; let constructor = "mlir::oneflow::createGpuCopyArgPass()"; let dependentDialects = ["memref::MemRefDialect", "gpu::GPUDialect"]; } def OutlineJitFunctionPass : InterfacePass<"outline-jit-function", "FunctionOpInterface"> { let summary = "move ops could be jitted to jit function"; let constructor = "mlir::oneflow::createOutlineJitFunctionPass()"; let dependentDialects = ["pdl_interp::PDLInterpDialect", "pdl::PDLDialect", "LLVM::LLVMDialect"]; let options = [ Option<"compileToLLVM", "compile-to-llvm", "bool", /*default=*/"true", "Convert to llvm dialect in this pass">, ]; } def AggregateComputeOpsPass : Pass<"aggregate-compute-ops", "ModuleOp"> { let summary = "aggregate compute ops together"; let constructor = "mlir::oneflow::createAggregateComputeOpsPass()"; } def WrapOpsToKernelLaunchPass : Pass<"wrap-ops-to-kernel-launch", "ModuleOp"> { let summary = "wrap user ops with a single kernel launch op in OneFlow Job"; let constructor = "mlir::oneflow::createWrapOpsToKernelLaunchPass()"; } def EliminateAllocOpsPass : Pass<"eliminate-alloc-ops", "ModuleOp"> { let summary = "eliminate memref.alloc and memref.copy which target is a block argument"; let constructor = "mlir::oneflow::createEliminateAllocOpsPass()"; let dependentDialects = ["pdl_interp::PDLInterpDialect", "pdl::PDLDialect"]; } def AppendOneFlowStreamPass : Pass<"append-ofstream", "ModuleOp"> { let summary = "append oneflow stream to gpu function arguments"; let constructor = "mlir::oneflow::createAppendOneFlowStreamPass()"; } def MgpuToOneFlowStreamPass : Pass<"mgpu-to-ofstream", "ModuleOp"> { let summary = "convert mlir abi about mgpu to oneflow stream, this pass should be invoked after append-ofstream pass"; let constructor = "mlir::oneflow::createMgpuToOneFlowStreamPass()"; } def FuseIntoExistingOpPass : Pass<"fuse-into-existing-op", "ModuleOp"> { let summary = ""; let constructor = "mlir::oneflow::createFuseIntoExistingOpPass()"; let dependentDialects = ["pdl_interp::PDLInterpDialect", "pdl::PDLDialect"]; } def InsertOneFlowMemPoolPass : Pass<"insert-ofmempool", "ModuleOp"> { let summary = "insert oneflow tmp buffer as a memory pool in mlir codegen"; let constructor = "mlir::oneflow::createInsertOneFlowMemPoolPass()"; } def FoldAllocToSubviewPass : Pass<"fold-alloc-to-subview", "func::FuncOp"> { let 
summary = "fold dispersed memref.alloc ops with memory optimize algo to a single memref.alloc op and memref.subview ops"; let constructor = "mlir::oneflow::createFoldAllocToSubviewPass()"; } def AutoNhwcPass : Pass<"auto-nhwc", "ModuleOp"> { let summary = ""; let constructor = "mlir::oneflow::createAutoNhwcPass()"; } def PreConvertInferenceOpPass : Pass<"pre-convert-inference-op", "ModuleOp"> { let summary = "Convert variable op to variable ir op for constant folding."; let constructor = "mlir::oneflow::createPreConvertInferenceOpPass()"; } def ConvertInferenceOpPass : Pass<"convert-inference-op", "ModuleOp"> { let summary = "Convert ops to their inference version and rewrite them with a more performant equivalent in inference workflow."; let constructor = "mlir::oneflow::createConvertInferenceOpPass()"; let dependentDialects = ["pdl_interp::PDLInterpDialect", "pdl::PDLDialect"]; } def PostConvertInferenceOpPass : Pass<"post-convert-inference-op", "ModuleOp"> { let summary = "Convert variable ir op to variable op after contant folding."; let constructor = "mlir::oneflow::createPostConvertInferenceOpPass()"; } def ConvertToSignlessForTosaPass : Pass<"convert-to-signless-for-tosa", "ModuleOp"> { let summary = "convert func type to unsigned before lowering to tosa"; let description = [{ In oneflow, int typed tensor is explicit signed. Convert them before lowering to TOSA. }]; let constructor = "mlir::oneflow::createConvertToSignlessForTosaPass()"; let dependentDialects = ["func::FuncDialect"]; } def CSEWithAttributesIgnored : Pass<"cse-with-attributes-ignored", "ModuleOp"> { let summary = "ignore oneflow attributes to have cse work"; let description = [{ cse and ignore oneflow attributes like op name, symbol id, etc. }]; let constructor = "mlir::oneflow::createCSEWithAttributesIgnored()"; let dependentDialects = []; } def CSEPutAttributes : Pass<"cse-put-attributes", "ModuleOp"> { let summary = "cse and ignore oneflow attributes"; let description = [{ put back oneflow attributes like op name, symbol id, etc. }]; let constructor = "mlir::oneflow::createCSEPutAttributes()"; let dependentDialects = []; } def GroupMatMul : Pass<"group-matmul", "ModuleOp"> { let summary = "group matmul together"; let description = [{ group matmul ops together and use cudnn batched matmul }]; let constructor = "mlir::oneflow::createGroupMatMul()"; let dependentDialects = []; } def FuseForwardOps : Pass<"fuse-forward-only-ops", "ModuleOp"> { let summary = "fuse forward ops"; let description = [{ fuse forward ops. Usually they are actions after an op. }]; let constructor = "mlir::oneflow::createFuseForwardOps()"; let dependentDialects = []; } def FuseOpsWithBackwardImpl : Pass<"fuse-ops-with-backward-impl", "ModuleOp"> { let summary = "fuse ops with backward impl"; let description = [{ fuse ops with backward impl. }]; let constructor = "mlir::oneflow::createFuseOpsWithBackwardImpl()"; let dependentDialects = ["pdl_interp::PDLInterpDialect", "pdl::PDLDialect"]; } def FuseNormalizationOps : Pass<"fuse-normalization-ops", "ModuleOp"> { let summary = "fuse forward ops"; let description = [{ fuse forward ops. Usually they are actions after an op. }]; let constructor = "mlir::oneflow::createFuseNormalizationOps()"; let dependentDialects = ["pdl_interp::PDLInterpDialect", "pdl::PDLDialect"]; } #endif // ONEFLOW_PASSES ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowPatternUtils.h ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/PatternMatch.h" namespace mlir { namespace oneflow { namespace rewrites { mlir::IntegerAttr GetDefaultSeed(::mlir::PatternRewriter& rewriter); void populateRewrites(RewritePatternSet& patterns); } // namespace rewrites namespace constraints { void populateConstraints(RewritePatternSet& patterns); } // namespace constraints } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowPatterns.td ================================================ #ifndef ONEFLOW_PATTERNS #define ONEFLOW_PATTERNS include "mlir/IR/PatternBase.td" include "OneFlow/OneFlowOps.td" include "mlir/Dialect/MemRef/IR/MemRefOps.td" include "mlir/Dialect/GPU/IR/GPUOps.td" def GetFirstValue : NativeCodeCall<"*$0.begin()">; def IsAddToOutputNone: Constraint, "">; def IsTraingTrue: Constraint, "">; def IsArg: Constraint()">, "">; def getResultTypes : NativeCodeCall<"$0.getResultTypes()">; def CreateGPUMemcpyOpFromMemrefCopy : NativeCodeCall<"::mlir::oneflow::CreateGPUMemcpyOpFromMemrefCopy($_builder, $0)">; def ReplaceCopyWithGPUPattern : Pat< ( CopyOp:$results $src, $dst ), ( CreateGPUMemcpyOpFromMemrefCopy $results ), [(IsArg $dst)] >; #endif // ONEFLOW_PATTERNS ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowSupport.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWSUPPORT_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWSUPPORT_H_ #include #include #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "OneFlow/OneFlowEnums.h.inc" #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/tensor.h" // This include is not necessary now, but it is here for testing the namespace collision #include "oneflow/core/framework/user_op_registry_manager.h" namespace mlir { namespace oneflow { namespace support { const ::oneflow::UserOpDef& getUserOpDef(const std::string& op_type_name); static const std::vector* inputKeys() { static std::vector val({"in"}); return &val; } std::vector GetInputKeys(const std::string& op_type_name); std::vector GetOutputKeys(const std::string& op_type_name); mlir::DenseElementsAttr TensorToDenseElementsAttr( const std::shared_ptr<::oneflow::one::Tensor>& tensor, MLIRContext* ctx); std::shared_ptr<::oneflow::one::Tensor> DenseElementsAttrToTensor( const mlir::Attribute& attr, const mlir::Attribute& device_tag, const mlir::Attribute& device_name); void DenseElementsAttrToTensor(const mlir::Attribute& attr, const mlir::Attribute& device_tag, const mlir::Attribute& device_name, std::shared_ptr<::oneflow::one::Tensor>& tensor); FailureOr<::oneflow::DataType> FromMLIRTypeToOFDataType(Type mlir_type); FailureOr<::oneflow::DataType> FromMLIRDataTypeToOFDataType(::mlir::oneflow::DataType data_type); FailureOr<::oneflow::DataType> FromMLIRAttrToOFDataType(Attribute attr); } // namespace support } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWSUPPORT_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowTypes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWTYPES_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWTYPES_H_ #include "mlir/IR/Types.h" #define GET_TYPEDEF_CLASSES #include "OneFlow/OneFlowOpsTypes.h.inc" #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWTYPES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowUserOps.td ================================================ #ifdef GET_ONEFLOW_ASSIGN_OP_DEFINITIONS def OneFlow_AssignUserOp : OneFlow_BaseOp<"assign", [NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$ref, OneFlow_Tensor:$value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_AssignIfOp : OneFlow_BaseOp<"assign_if", [NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$ref, OneFlow_Tensor:$value, OneFlow_Tensor:$condition ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_AssignIfNotOp : OneFlow_BaseOp<"assign_if_not", [NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$ref, OneFlow_Tensor:$value, OneFlow_Tensor:$condition ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } #endif // GET_ONEFLOW_ASSIGN_OP_DEFINITIONS #ifdef GET_ONEFLOW_BASE_OP_DEFINITIONS class OneFlow_NormalizationBaseOp traits = []> : OneFlow_BaseOp])> { let input = (ins OneFlow_Tensor:$x, Optional:$moving_mean, Optional:$moving_variance, OneFlow_Tensor:$gamma, OneFlow_Tensor:$beta, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$y, Optional:$mean, Optional:$inv_variance ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$training, DefaultValuedAttr:$momentum ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes, DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } #endif // GET_ONEFLOW_BASE_OP_DEFINITIONS #ifdef GET_ONEFLOW_BINARY_OP_DEFINITIONS def OneFlow_BiasAddOp : OneFlow_BaseOp<"bias_add", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CastLikeOp : OneFlow_BaseOp<"cast_like", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$dtype_like ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_CeluGradOp : OneFlow_BaseOp<"celu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let 
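// The recurring `has_logical_tensor_desc_infer_fn`, `has_physical_tensor_desc_infer_fn`,
// `has_get_sbp_fn` and `has_data_type_infer_fn` bits are not upstream ODS fields; they appear to
// be OneFlow-specific fields on the op base classes. Setting a bit to 1 presumably tells the
// op-schema generator to declare the corresponding shape/SBP/data-type inference hook, which the
// op's C++ implementation under oneflow/user/ops/ then defines.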
has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DiagGradOp : OneFlow_BaseOp<"diag_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$diagonal ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DiagonalGradOp : OneFlow_BaseOp<"diagonal_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$offset ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DotOp : OneFlow_BaseOp<"dot", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DropoutGradOp : OneFlow_BaseOp<"dropout_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$scale ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ElementwiseMaximumOp : OneFlow_BaseOp<"elementwise_maximum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ElementwiseMinimumOp : OneFlow_BaseOp<"elementwise_minimum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EluGradOp : OneFlow_BaseOp<"elu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FloordivOp : OneFlow_BaseOp<"floordiv", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LerpOp : OneFlow_BaseOp<"lerp", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$start, OneFlow_Tensor:$end, OneFlow_Tensor:$weight ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LerpGradOp : OneFlow_BaseOp<"lerp_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$start, 
OneFlow_Tensor:$end, OneFlow_Tensor:$weight, OneFlow_Tensor:$out_diff ); let output = (outs OneFlow_Tensor:$start_diff, OneFlow_Tensor:$end_diff, OneFlow_Tensor:$weight_diff ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TruncdivOp : OneFlow_BaseOp<"truncdiv", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_GeluGradOp : OneFlow_BaseOp<"gelu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FastGeluGradOp : OneFlow_BaseOp<"fast_gelu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_QuickGeluGradOp : OneFlow_BaseOp<"quick_gelu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SquareReLUGradOp : OneFlow_BaseOp<"square_relu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_GridSampleOp : OneFlow_BaseOp<"grid_sample", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$grid ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins StrAttr:$interpolation_mode, StrAttr:$padding_mode, DefaultValuedAttr:$align_corners ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_HardsigmoidGradOp : OneFlow_BaseOp<"hardsigmoid_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_HardShrinkGradOp : OneFlow_BaseOp<"hardshrink_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$lambd ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_HardswishGradOp : OneFlow_BaseOp<"hardswish_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let 
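// The `*_grad` ops in this block follow a uniform backward-op convention: each takes the forward
// op's input ($x) or saved output ($y) together with the incoming gradient ($dy), produces the
// input gradient ($dx), and repeats any attribute (alpha, lambd, diagonal, ...) that the forward
// computation needs in order to be replayed.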
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_L1L2RegularizeGradientOp : OneFlow_BaseOp<"l1_l2_regularize_gradient", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$l1, DefaultValuedAttr:$l2 ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LeakyReluGradOp : OneFlow_BaseOp<"leaky_relu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MaskedFillOp : OneFlow_BaseOp<"masked_fill", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$has_bool_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand, DefaultValuedAttr:$bool_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MishGradOp : OneFlow_BaseOp<"mish_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_NarrowGradOp : OneFlow_BaseOp<"narrow_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$like ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$dim, DefaultValuedAttr:$start, DefaultValuedAttr:$length ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_PowOp : OneFlow_BaseOp<"pow", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FracOp : OneFlow_BaseOp<"frac", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_PreluOp : OneFlow_BaseOp<"prelu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$alpha ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReluGradOp : OneFlow_BaseOp<"relu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs 
OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SeluGradOp : OneFlow_BaseOp<"selu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SiluGradOp : OneFlow_BaseOp<"silu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ThresholdGradOp : OneFlow_BaseOp<"threshold_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$threshold_val ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SoftShrinkGradOp : OneFlow_BaseOp<"softshrink_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TfPreluOp : OneFlow_BaseOp<"tf_prelu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$alpha ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UnfoldTensorGradOp : OneFlow_BaseOp<"unfold_tensor_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$dimension, DefaultValuedAttr:$size, DefaultValuedAttr:$step ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_XdivyOp : OneFlow_BaseOp<"xdivy", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_XlogyOp : OneFlow_BaseOp<"xlogy", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastZetaOp : OneFlow_BaseOp<"broadcast_zeta", [NoGrad,NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_BINARY_OP_DEFINITIONS #ifdef 
GET_ONEFLOW_BROADCAST_OP_DEFINITIONS def OneFlow_BroadcastAddOp : OneFlow_BaseOp<"broadcast_add", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastDivOp : OneFlow_BaseOp<"broadcast_div", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let hasFolder = 1; } def OneFlow_BroadcastDivGradOp : OneFlow_BaseOp<"broadcast_div_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$z, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dy ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastEqualOp : OneFlow_BaseOp<"broadcast_equal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastFloorModOp : OneFlow_BaseOp<"broadcast_floor_mod", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastFmodOp : OneFlow_BaseOp<"broadcast_fmod", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastGreaterOp : OneFlow_BaseOp<"broadcast_greater", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadCastInplaceGreaterOp : OneFlow_BaseOp<"broadcast_inplace_greater", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastGreaterEqualOp : OneFlow_BaseOp<"broadcast_greater_equal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastLessOp : OneFlow_BaseOp<"broadcast_less", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let 
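// Trait notes: `NoMemoryEffect` is the standard MLIR trait marking an op as free of side effects
// (so it is eligible for CSE and dead-code elimination), while `NoGrad` appears to be a
// OneFlow-defined marker for ops that register no backward function. `let hasFolder = 1`
// (as on broadcast_div above) is plain ODS: it declares a fold() hook on the generated op class,
// typically used for constant folding.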
output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastLessEqualOp : OneFlow_BaseOp<"broadcast_less_equal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastIsCloseEqualNanOp : OneFlow_BaseOp<"broadcast_isclose_eq_nan", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let attrs = (ins DefaultValuedAttr:$atol, DefaultValuedAttr:$rtol ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastIsCloseNotEqualNanOp : OneFlow_BaseOp<"broadcast_isclose_neq_nan", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let attrs = (ins DefaultValuedAttr:$atol, DefaultValuedAttr:$rtol ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastLikeOp : OneFlow_BaseOp<"broadcast_like", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$like ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI32ArrayAttr:$broadcast_axes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_BroadcastLogicalAndOp : OneFlow_BaseOp<"broadcast_logical_and", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastLogicalOrOp : OneFlow_BaseOp<"broadcast_logical_or", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastLogicalXorOp : OneFlow_BaseOp<"broadcast_logical_xor", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastMaximumOp : OneFlow_BaseOp<"broadcast_maximum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastMinimumOp : OneFlow_BaseOp<"broadcast_minimum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs 
OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastMulOp : OneFlow_BaseOp<"broadcast_mul", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let hasFolder = 1; } def OneFlow_BroadcastNotEqualOp : OneFlow_BaseOp<"broadcast_not_equal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastPowOp : OneFlow_BaseOp<"broadcast_pow", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastSubOp : OneFlow_BaseOp<"broadcast_sub", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let hasFolder = 1; } def OneFlow_BitwiseNotOp : OneFlow_BaseOp<"bitwise_not", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastBitwiseAndOp : OneFlow_BaseOp<"broadcast_bitwise_and", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastBitwiseOrOp : OneFlow_BaseOp<"broadcast_bitwise_or", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BroadcastBitwiseXorOp : OneFlow_BaseOp<"broadcast_bitwise_xor", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_BROADCAST_OP_DEFINITIONS #ifdef GET_ONEFLOW_CONV_OP_DEFINITIONS def OneFlow_Conv1DOp : OneFlow_ConvolutionBaseOp<"conv1d", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {} def OneFlow_Conv2DOp : OneFlow_ConvolutionBaseOp<"conv2d", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> {} def OneFlow_Conv3DOp : OneFlow_ConvolutionBaseOp<"conv3d", [NoMemoryEffect, AttrSizedOperandSegments, 
DeclareOpInterfaceMethods]> {} def OneFlow_ConvBiasGradOp : OneFlow_BaseOp<"conv_bias_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$bias_diff ); let attrs = (ins StrAttr:$data_format, DefaultValuedAttr:$num_spatial_dims ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_ConvDataGradOp : OneFlow_BaseOp<"conv_data_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$filter, OneFlow_Tensor:$x_like, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$num_spatial_dims, SI32ArrayAttr:$padding_before, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate, DefaultValuedAttr:$groups ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_ConvFilterGradOp : OneFlow_BaseOp<"conv_filter_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$filter_diff ); let attrs = (ins DefaultValuedAttr:$num_spatial_dims, SI32ArrayAttr:$padding_before, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate, DefaultValuedAttr:$groups ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_Deconv1DOp : OneFlow_BaseOp<"deconv1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$weight ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$filters, SI32ArrayAttr:$padding_before, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$output_padding, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate, DefaultValuedAttr:$groups ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Deconv2DOp : OneFlow_BaseOp<"deconv2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$weight ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$filters, SI32ArrayAttr:$padding_before, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$output_padding, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate, DefaultValuedAttr:$groups ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Deconv3DOp : OneFlow_BaseOp<"deconv3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$weight ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$filters, SI32ArrayAttr:$padding_before, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$output_padding, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate, DefaultValuedAttr:$groups ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let 
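// For orientation, a new user op added to this file would follow the same shape as the
// surrounding definitions. A minimal hypothetical sketch (the op name, interface argument, and
// attribute types below are illustrative assumptions, not definitions taken from this file):
//
//   def OneFlow_MyScaleOp : OneFlow_BaseOp<"my_scale", [NoMemoryEffect,
//       DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
//     let input = (ins OneFlow_Tensor:$in);
//     let output = (outs OneFlow_Tensor:$out);
//     let attrs = (ins DefaultValuedAttr<F64Attr, "1.">:$scale);
//     let has_logical_tensor_desc_infer_fn = 1;
//     let has_physical_tensor_desc_infer_fn = 1;
//     let has_get_sbp_fn = 1;
//     let has_data_type_infer_fn = 1;
//   }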
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DeformConv2dOp : OneFlow_BaseOp<"deform_conv2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$offset, OneFlow_Tensor:$weight, Optional:$bias, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$stride_h, DefaultValuedAttr:$stride_w, DefaultValuedAttr:$pad_h, DefaultValuedAttr:$pad_w, DefaultValuedAttr:$dilation_h, DefaultValuedAttr:$dilation_w, DefaultValuedAttr:$groups, DefaultValuedAttr:$offset_groups, DefaultValuedAttr:$use_mask ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DeformConv2dInputGradOp : OneFlow_BaseOp<"deform_conv2d_input_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$output_grad, OneFlow_Tensor:$input, OneFlow_Tensor:$offset, OneFlow_Tensor:$weight, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$input_grad, //OneFlow_Tensor:$weight_grad, OneFlow_Tensor:$offset_grad, OneFlow_Tensor:$mask_grad ); let attrs = (ins DefaultValuedAttr:$stride_h, DefaultValuedAttr:$stride_w, DefaultValuedAttr:$pad_h, DefaultValuedAttr:$pad_w, DefaultValuedAttr:$dilation_h, DefaultValuedAttr:$dilation_w, DefaultValuedAttr:$groups, DefaultValuedAttr:$offset_groups, DefaultValuedAttr:$use_mask ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DeformConv2dParamGradOp : OneFlow_BaseOp<"deform_conv2d_param_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$output_grad, OneFlow_Tensor:$input, OneFlow_Tensor:$offset, OneFlow_Tensor:$weight, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$weight_grad ); let attrs = (ins DefaultValuedAttr:$stride_h, DefaultValuedAttr:$stride_w, DefaultValuedAttr:$pad_h, DefaultValuedAttr:$pad_w, DefaultValuedAttr:$dilation_h, DefaultValuedAttr:$dilation_w, DefaultValuedAttr:$groups, DefaultValuedAttr:$offset_groups, DefaultValuedAttr:$use_mask ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_CONV_OP_DEFINITIONS #ifdef GET_ONEFLOW_CROSS_ENTROPY_OP_DEFINITIONS def OneFlow_BinaryCrossEntropyOp : OneFlow_BaseOp<"binary_cross_entropy", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, Optional:$weight ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_BinaryCrossEntropyGradOp : OneFlow_BaseOp<"binary_cross_entropy_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, Optional:$weight, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BinaryCrossEntropyWithLogitsOp : OneFlow_BaseOp<"binary_cross_entropy_with_logits", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, Optional:$weight, 
Optional:$pos_weight ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_pos_weight ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_BinaryCrossEntropyWithLogitsGradOp : OneFlow_BaseOp<"binary_cross_entropy_with_logits_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, Optional:$weight, Optional:$pos_weight, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$has_pos_weight ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BinaryCrossEntropyWithLogitsReduceMeanOp : OneFlow_BaseOp<"binary_cross_entropy_with_logits_reduce_mean", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_BinaryCrossEntropyWithLogitsReduceMeanGradOp : OneFlow_BaseOp<"binary_cross_entropy_with_logits_reduce_mean_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedBCEReduceMeanFwBwOp : OneFlow_BaseOp<"fused_bce_reduce_mean_fw_bw", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$dx ); let attrs = (ins OneFlow_DataType:$out_dtype, DefaultValuedAttr:$constant_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SigmoidCrossEntropyOp : OneFlow_BaseOp<"sigmoid_cross_entropy", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$loss ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SigmoidCrossEntropyGradOp : OneFlow_BaseOp<"sigmoid_cross_entropy_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$loss_diff, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$prediction_diff ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SparseCrossEntropyOp : OneFlow_BaseOp<"sparse_cross_entropy", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let 
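// Ops that mix required and Optional inputs (e.g. binary_cross_entropy_with_logits with its
// optional $weight and $pos_weight) rely on the standard MLIR mechanism for this: the
// AttrSizedOperandSegments trait plus a DenseI32ArrayAttr:$operand_segment_sizes trait attribute
// recording how many SSA operands belong to each declared input, so optional operands can be
// absent without making operand grouping ambiguous.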
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SparseCrossEntropyGradOp : OneFlow_BaseOp<"sparse_cross_entropy_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$label, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$prediction_diff ); let attrs = (ins DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SparseCrossEntropyMsOp : OneFlow_BaseOp<"sparse_cross_entropy_ms", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SparseCrossEntropyMsGradOp : OneFlow_BaseOp<"sparse_cross_entropy_ms_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$label, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$prediction_diff ); let attrs = (ins DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_CROSS_ENTROPY_OP_DEFINITIONS #ifdef GET_ONEFLOW_CUDA_OP_DEFINITIONS def OneFlow_NvtxEndOp : OneFlow_BaseOp<"nvtx_end", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$mark_prefix ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_NvtxStartOp : OneFlow_BaseOp<"nvtx_start", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$mark_prefix ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_CUDA_OP_DEFINITIONS #ifdef GET_ONEFLOW_DATASET_OP_DEFINITIONS def OneFlow_COCOReaderOp : OneFlow_BaseOp<"COCOReader", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_TensorBufferTensor:$image, OneFlow_Tensor:$image_id, OneFlow_Tensor:$image_size, OneFlow_TensorBufferTensor:$gt_bbox, OneFlow_TensorBufferTensor:$gt_label, OneFlow_TensorBufferTensor:$gt_segm, OneFlow_TensorBufferTensor:$gt_segm_index ); let attrs = (ins DefaultValuedAttr:$session_id, StrAttr:$annotation_file, StrAttr:$image_dir, DefaultValuedAttr:$batch_size, DefaultValuedAttr:$shuffle_after_epoch, DefaultValuedAttr:$random_seed, DefaultValuedAttr:$group_by_ratio, DefaultValuedAttr:$remove_images_without_annotations, DefaultValuedAttr:$stride_partition, StrArrayAttr:$nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_arg_modify_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_OFRecordReaderOp : OneFlow_BaseOp<"OFRecordReader", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins 
StrAttr:$data_dir, DefaultValuedAttr:$data_part_num, DefaultValuedAttr:$batch_size, DefaultValuedAttr:$part_name_prefix, DefaultValuedAttr:$part_name_suffix_length, DefaultValuedAttr:$random_shuffle, DefaultValuedAttr:$seed, DefaultValuedAttr:$shuffle_buffer_size, DefaultValuedAttr:$shuffle_after_epoch, StrArrayAttr:$nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_arg_modify_fn = 1; let has_nd_sbp_infer_fn = 1; let has_get_nd_sbp_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_CtcGreedyDecoderOp : OneFlow_BaseOp<"ctc_greedy_decoder", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$log_probs, OneFlow_Tensor:$input_lengths ); let output = (outs OneFlow_Tensor:$decoded, OneFlow_Tensor:$neg_sum_logits ); let attrs = (ins DefaultValuedAttr:$merge_repeated ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MegatronGptMmapDataLoaderOp : OneFlow_BaseOp<"megatron_gpt_mmap_data_loader", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins Optional:$iteration ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$data_file_prefix, DefaultValuedAttr:$seq_length, DefaultValuedAttr:$label_length, DefaultValuedAttr:$num_samples, DefaultValuedAttr:$batch_size, OneFlow_DataType:$dtype, SI64ArrayAttr:$split_sizes, DefaultValuedAttr:$split_index, DefaultValuedAttr:$shuffle, DefaultValuedAttr:$random_seed, StrArrayAttr:$nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_OfrecordBytesDecoderOp : OneFlow_BaseOp<"ofrecord_bytes_decoder", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$name ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_OfrecordImageClassificationReaderOp : OneFlow_BaseOp<"ofrecord_image_classification_reader", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$image, OneFlow_Tensor:$label ); let attrs = (ins StrAttr:$data_dir, DefaultValuedAttr:$data_part_num, DefaultValuedAttr:$batch_size, DefaultValuedAttr:$part_name_prefix, DefaultValuedAttr:$part_name_suffix_length, DefaultValuedAttr:$random_shuffle, DefaultValuedAttr:$seed, DefaultValuedAttr:$shuffle_buffer_size, DefaultValuedAttr:$shuffle_after_epoch, DefaultValuedAttr:$color_space, DefaultValuedAttr:$image_feature_name, DefaultValuedAttr:$label_feature_name, DefaultValuedAttr:$decode_buffer_size_per_thread, DefaultValuedAttr:$num_decode_threads_per_machine ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_arg_modify_fn = 1; } def OneFlow_OfrecordImageDecoderOp : OneFlow_BaseOp<"ofrecord_image_decoder", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$name, DefaultValuedAttr:$color_space ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn 
= 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_OfrecordImageDecoderRandomCropOp : OneFlow_BaseOp<"ofrecord_image_decoder_random_crop", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$name, DefaultValuedAttr:$color_space, DefaultValuedAttr:$num_attempts, DefaultValuedAttr:$seed, DefaultValuedAttr:$has_seed, F32ArrayAttr:$random_area, F32ArrayAttr:$random_aspect_ratio ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_OfrecordRawDecoderOp : OneFlow_BaseOp<"ofrecord_raw_decoder", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$name, ShapeAttr:$shape, OneFlow_DataType:$data_type, DefaultValuedAttr:$dim1_varying_length, DefaultValuedAttr:$truncate ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_RawReaderOp : OneFlow_BaseOp<"raw_reader", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$files, OneFlow_DataType:$data_type, ShapeAttr:$shape, SI64Attr:$batch_size, SI64Attr:$shuffle_block_size, DefaultValuedAttr:$random_shuffle, DefaultValuedAttr:$seed, StrArrayAttr:$nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_data_type_infer_fn = 1; let has_get_sbp_fn = 1; let has_nd_sbp_infer_fn = 1; } #endif // GET_ONEFLOW_DATASET_OP_DEFINITIONS #ifdef GET_ONEFLOW_DETECTION_OP_DEFINITIONS def OneFlow_InTopKOp : OneFlow_BaseOp<"in_top_k", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$targets, OneFlow_Tensor:$predictions ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$k ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_NmsOp : OneFlow_BaseOp<"nms", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$iou_threshold, DefaultValuedAttr:$keep_n ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ObjectBboxFlipOp : OneFlow_BaseOp<"object_bbox_flip", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$bbox, OneFlow_Tensor:$image_size, OneFlow_Tensor:$flip_code ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ObjectBboxScaleOp : OneFlow_BaseOp<"object_bbox_scale", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$bbox, OneFlow_Tensor:$scale ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ObjectSegmentationPolygonFlipOp : 
OneFlow_BaseOp<"object_segmentation_polygon_flip", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$poly, OneFlow_Tensor:$image_size, OneFlow_Tensor:$flip_code ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ObjectSegmentationPolygonScaleOp : OneFlow_BaseOp<"object_segmentation_polygon_scale", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$poly, OneFlow_Tensor:$scale ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ObjectSegmentationPolygonToMaskOp : OneFlow_BaseOp<"object_segmentation_polygon_to_mask", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$poly, OneFlow_Tensor:$poly_index, OneFlow_Tensor:$image_size ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RoiAlignOp : OneFlow_BaseOp<"roi_align", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$rois ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$pooled_h, DefaultValuedAttr:$pooled_w, DefaultValuedAttr:$spatial_scale, DefaultValuedAttr:$sampling_ratio, DefaultValuedAttr:$aligned ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_RoiAlignGradOp : OneFlow_BaseOp<"roi_align_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x_like, OneFlow_Tensor:$rois ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$pooled_h, DefaultValuedAttr:$pooled_w, DefaultValuedAttr:$spatial_scale, DefaultValuedAttr:$sampling_ratio, DefaultValuedAttr:$aligned ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TopKOp : OneFlow_BaseOp<"top_k", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$k, DefaultValuedAttr:$sorted ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_DETECTION_OP_DEFINITIONS #ifdef GET_ONEFLOW_EAGER_OP_DEFINITIONS def OneFlow_EagerBToSOp : OneFlow_BaseOp<"eager_b_to_s", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$out_split_axis, StrAttr:$in_parallel_conf, StrAttr:$out_parallel_conf, ShapeAttr:$shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerNaiveSToSOp : OneFlow_BaseOp<"eager_naive_s_to_s", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs 
OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$in_split_axis, DefaultValuedAttr:$out_split_axis, StrAttr:$in_parallel_conf, StrAttr:$out_parallel_conf, ShapeAttr:$shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerCclAllGatherOp : OneFlow_BaseOp<"eager_ccl_all_gather", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$parallel_conf, ShapeAttr:$output_shape, OneFlow_DataType:$output_dtype ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerCclAllReduceOp : OneFlow_BaseOp<"eager_ccl_all_reduce", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$parallel_conf ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; } def OneFlow_EagerCclBroadcastOp : OneFlow_BaseOp<"eager_ccl_broadcast", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$parallel_conf, ShapeArrayAttr:$shape_list, DefaultValuedAttr:$root, DefaultValuedAttr:$async_launch ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; } def OneFlow_EagerCclTouchOp : OneFlow_BaseOp<"eager_ccl_touch", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$in ); let attrs = (ins DefaultValuedAttr:$async_launch ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; } def OneFlow_EagerCclReduceOp : OneFlow_BaseOp<"eager_ccl_reduce", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$parallel_conf, DefaultValuedAttr:$root ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; } def OneFlow_EagerCclReduceScatterOp : OneFlow_BaseOp<"eager_ccl_reduce_scatter", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$parallel_conf, ShapeAttr:$output_shape, OneFlow_DataType:$output_dtype, DefaultValuedAttr:$op_type ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerCclS2SOp : OneFlow_BaseOp<"eager_ccl_s2s", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$in_split_axis, DefaultValuedAttr:$out_split_axis, 
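// The eager_* ops in this section implement eager-mode redistribution between SBP layouts
// (s/b/p = split/broadcast/partial-sum, hence names like eager_s_to_b and eager_p_to_s) and the
// eager collective primitives (eager_ccl_all_gather, eager_ccl_all_reduce, ...). The
// $parallel_conf / $in_parallel_conf / $out_parallel_conf string attributes appear to carry
// serialized placement configurations, and most of these ops also declare device/stream and
// nd_sbp inference hooks.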
StrAttr:$parallel_conf ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerPToBOp : OneFlow_BaseOp<"eager_p_to_b", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$in_parallel_conf, StrAttr:$out_parallel_conf, ShapeAttr:$shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerPToSOp : OneFlow_BaseOp<"eager_p_to_s", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$out_split_axis, StrAttr:$in_parallel_conf, StrAttr:$out_parallel_conf, ShapeAttr:$shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerSToBOp : OneFlow_BaseOp<"eager_s_to_b", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$in_split_axis, StrAttr:$in_parallel_conf, StrAttr:$out_parallel_conf, ShapeAttr:$shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerSToPOp : OneFlow_BaseOp<"eager_s_to_p", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$in_split_axis, StrAttr:$in_parallel_conf, StrAttr:$out_parallel_conf, ShapeAttr:$shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_EagerSymmetricSToPOp : OneFlow_BaseOp<"eager_symmetric_s_to_p", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$in_split_axis, StrAttr:$parallel_conf ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } #endif // GET_ONEFLOW_EAGER_OP_DEFINITIONS #ifdef GET_ONEFLOW_FUSED_OP_DEFINITIONS def OneFlow_FusedLstmCellOp : OneFlow_BaseOp<"fused_lstm_cell", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_gates, OneFlow_Tensor:$hidden_gates, OneFlow_Tensor:$cx, Optional:$input_bias, Optional:$hidden_bias ); let output = (outs OneFlow_Tensor:$hy, OneFlow_Tensor:$cy, OneFlow_Tensor:$workspace ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedLstmCellGradOp : OneFlow_BaseOp<"fused_lstm_cell_grad", [NoMemoryEffect, 
AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$grad_hy, OneFlow_Tensor:$grad_cy, OneFlow_Tensor:$cx, OneFlow_Tensor:$cy, OneFlow_Tensor:$workspace ); let output = (outs OneFlow_Tensor:$grad_gates, Optional:$grad_cx, Optional:$grad_bias ); let trait_attrs = (ins DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGruCellOp : OneFlow_BaseOp<"fused_gru_cell", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_gates, OneFlow_Tensor:$hidden_gates, OneFlow_Tensor:$hx, Optional:$input_bias, Optional:$hidden_bias ); let output = (outs OneFlow_Tensor:$hy, OneFlow_Tensor:$workspace ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGruCellGradOp : OneFlow_BaseOp<"fused_gru_cell_grad", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$grad_hy, OneFlow_Tensor:$workspace ); let output = (outs OneFlow_Tensor:$grad_input_gates, OneFlow_Tensor:$grad_hidden_gates, Optional:$grad_hx, Optional:$grad_input_bias, Optional:$grad_hidden_bias ); let trait_attrs = (ins DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CudnnFusedNormalizationAddReluOp : OneFlow_BaseOp<"cudnn_fused_normalization_add_relu", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$addend, Optional:$moving_mean, Optional:$moving_variance, OneFlow_Tensor:$gamma, OneFlow_Tensor:$beta ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$reserve_space, Optional:$mean, Optional:$inv_variance ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$momentum ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes, DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_CudnnFusedNormalizationAddReluGradOp : OneFlow_BaseOp<"cudnn_fused_normalization_add_relu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance, OneFlow_Tensor:$gamma, OneFlow_Tensor:$beta, OneFlow_Tensor:$reserve_space, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$gamma_diff, OneFlow_Tensor:$beta_diff, OneFlow_Tensor:$dx, Optional:$addend_diff ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$epsilon ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedBiasAddGeluOp : OneFlow_BaseOp<"fused_bias_add_gelu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let 
has_data_type_infer_fn = 1; } def OneFlow_FusedBiasAddGeluGradOp : OneFlow_BaseOp<"fused_bias_add_gelu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedBiasAddMaskScaleOp : OneFlow_BaseOp<"fused_bias_add_mask_scale", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b, OneFlow_Tensor:$mask, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$scale, DefaultValuedAttr:$seed ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_FusedCastScaleOp : OneFlow_BaseOp<"fused_cast_scale", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$scale_by_tensor ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$scale ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedScaleMaskSoftmaxOp : OneFlow_BaseOp<"fused_scale_mask_softmax", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$scale_value, DefaultValuedAttr:$mask_fill_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_FusedScaleMaskSoftmaxDropoutOp : OneFlow_BaseOp<"fused_scale_mask_softmax_dropout", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$mask, OneFlow_Tensor:$dropout_mask ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$softmax_y ); let attrs = (ins DefaultValuedAttr:$scale_value, DefaultValuedAttr:$mask_fill_value, DefaultValuedAttr:$dropout_scale_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_FusedScaleMaskSoftmaxDropoutGradOp : OneFlow_BaseOp<"fused_scale_mask_softmax_dropout_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$softmax_y, OneFlow_Tensor:$dy, OneFlow_Tensor:$mask, OneFlow_Tensor:$dropout_mask ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$scale_value, DefaultValuedAttr:$dropout_scale_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedBiasAddScaleMaskSoftmaxDropoutOp : OneFlow_BaseOp<"fused_bias_add_scale_mask_softmax_dropout", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$bias, OneFlow_Tensor:$mask, OneFlow_Tensor:$dropout_mask ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$softmax_y ); let attrs = (ins DefaultValuedAttr:$scale_value, DefaultValuedAttr:$mask_fill_value, DefaultValuedAttr:$dropout_scale_value ); let 
has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_FusedScaleMaskSoftmaxGradOp : OneFlow_BaseOp<"fused_scale_mask_softmax_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$scale_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedScaleTrilOp : OneFlow_BaseOp<"fused_scale_tril", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$diagonal, DefaultValuedAttr:$floating_fill_value, DefaultValuedAttr:$integer_fill_value, DefaultValuedAttr:$is_floating_fill_value, DefaultValuedAttr:$floating_scale_value, DefaultValuedAttr:$integer_scale_value, DefaultValuedAttr:$is_floating_scale_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedSelfAttentionQueryMulKeyAndValueOp : OneFlow_BaseOp<"fused_self_attention_query_mul_key_and_value", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$hidden_states ); let output = (outs OneFlow_Tensor:$query_mul_key, OneFlow_Tensor:$value ); let attrs = (ins DefaultValuedAttr:$head_size, DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedSelfAttentionQueryMulKeyAndValueGradOp : OneFlow_BaseOp<"fused_self_attention_query_mul_key_and_value_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$query_mul_key_grad, OneFlow_Tensor:$value_grad, OneFlow_Tensor:$hidden_states ); let output = (outs OneFlow_Tensor:$hidden_states_grad ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedTrilScaleSoftmaxMaskScaleOp : OneFlow_BaseOp<"fused_tril_scale_softmax_mask_scale", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$softmax_y ); let attrs = (ins DefaultValuedAttr:$diagonal, DefaultValuedAttr:$tril_fill_value, DefaultValuedAttr:$tril_scale_value, DefaultValuedAttr:$mask_scale_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_FusedTrilScaleSoftmaxMaskScaleGradOp : OneFlow_BaseOp<"fused_tril_scale_softmax_mask_scale_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$softmax_y, OneFlow_Tensor:$dy, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$diagonal, DefaultValuedAttr:$tril_scale_value, DefaultValuedAttr:$mask_scale_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_NormalizationAddReluGradOp : OneFlow_BaseOp<"normalization_add_relu_grad", 
[NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance, OneFlow_Tensor:$gamma, OneFlow_Tensor:$beta, OneFlow_Tensor:$reserve_space, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$gamma_diff, OneFlow_Tensor:$beta_diff, OneFlow_Tensor:$dx, Optional:$addend_diff ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$epsilon ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedDotFeatureInteractionOp : OneFlow_BaseOp<"fused_dot_feature_interaction", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$features, Optional:$output_concat, Optional:$num_valid_sparse_feature, Optional:$sparse_feature, Optional:$sparse_indices ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$self_interaction, DefaultValuedAttr:$has_output_concat, DefaultValuedAttr:$output_padding, DefaultValuedAttr:$pooling ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedDotFeatureInteractionGradOp : OneFlow_BaseOp<"fused_dot_feature_interaction_grad", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, Variadic:$features, Optional:$num_valid_sparse_feature, Optional:$sparse_feature, Optional:$sparse_indices ); let output = (outs Variadic:$features_grad, Optional:$output_concat_grad, Optional:$sparse_feature_grad ); let attrs = (ins DefaultValuedAttr:$self_interaction, DefaultValuedAttr:$output_concat_grad_dim, DefaultValuedAttr:$pooling ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedCrossFeatureInteractionOp : OneFlow_BaseOp<"fused_cross_feature_interaction", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$weight, OneFlow_Tensor:$bias, OneFlow_Tensor:$x0 ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$matmul_result ); let attrs = (ins StrAttr:$interaction_mode ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedCrossFeatureInteractionV1GradOp : OneFlow_BaseOp<"fused_cross_feature_interaction_v1_grad", [NoMemoryEffect, DeclareOpInterfaceMethods, NoGrad]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$weight, OneFlow_Tensor:$x0, OneFlow_Tensor:$x, OneFlow_Tensor:$matmul_result ); let output = (outs OneFlow_Tensor:$dx0, OneFlow_Tensor:$dw, OneFlow_Tensor:$dx, OneFlow_Tensor:$dbias ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedCrossFeatureInteractionV2GradOp : OneFlow_BaseOp<"fused_cross_feature_interaction_v2_grad", [NoMemoryEffect, DeclareOpInterfaceMethods, NoGrad]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$weight, OneFlow_Tensor:$bias, OneFlow_Tensor:$x0, OneFlow_Tensor:$x, OneFlow_Tensor:$matmul_result ); let output = (outs OneFlow_Tensor:$dx0, OneFlow_Tensor:$dw, OneFlow_Tensor:$dx, OneFlow_Tensor:$dbias ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; 
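// Editorial note (assumption, not present in the original file): the recurring
// has_logical_tensor_desc_infer_fn / has_physical_tensor_desc_infer_fn /
// has_get_sbp_fn / has_data_type_infer_fn / has_nd_sbp_infer_fn bits that every
// op definition sets appear to be schema-generation switches rather than data:
// setting a flag to 1 seems to tell the op-schema generator that this user op
// supplies the corresponding inference hook (logical/physical shape, SBP
// signature, data type, nd-SBP) in its generated C++ interface.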
let has_data_type_infer_fn = 1; } def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$query, OneFlow_Tensor:$key, OneFlow_Tensor:$value, Optional:$alibi_slopes_ ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$softmax_lse, OneFlow_Tensor:$rng_state ); let attrs = (ins DefaultValuedAttr:$p_dropout, DefaultValuedAttr:$softmax_scale, DefaultValuedAttr:$is_causal, SI32Attr:$window_size_left, SI32Attr:$window_size_right, DefaultValuedAttr:$seed ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScaledDotProductFlashAttentionGradOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention_grad", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$grad_out, OneFlow_Tensor:$query, OneFlow_Tensor:$key, OneFlow_Tensor:$value, OneFlow_Tensor:$out, OneFlow_Tensor:$softmax_lse, OneFlow_Tensor:$rng_state, Optional:$alibi_slopes_ ); let output = (outs OneFlow_Tensor:$grad_q, OneFlow_Tensor:$grad_k, OneFlow_Tensor:$grad_v ); let attrs = (ins DefaultValuedAttr:$p_dropout, DefaultValuedAttr:$softmax_scale, DefaultValuedAttr:$is_causal, SI32Attr:$window_size_left, SI32Attr:$window_size_right ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$query, OneFlow_Tensor:$key, OneFlow_Tensor:$value, Optional:$attn_bias, Optional:$query_seq_start, Optional:$key_seq_start, Optional:$key_seq_len ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins SI64Attr:$query_head_size, DefaultValuedAttr:$query_max_seq_len, DefaultValuedAttr:$key_max_seq_len, F64Attr:$scale, DefaultValuedAttr:$causal_diagonal_offset, DefaultValuedAttr:$attn_mask_type, StrAttr:$query_layout, StrAttr:$key_layout, StrAttr:$value_layout, StrAttr:$output_layout ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedAttentionConcatPastKeyValueOp : OneFlow_BaseOp<"fused_attention_concat_past_key_value", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$key, OneFlow_Tensor:$value, Optional:$past_key, Optional:$past_value ); let output = (outs OneFlow_Tensor:$output_key, OneFlow_Tensor:$output_value ); let attrs = (ins StrAttr:$past_key_layout, StrAttr:$past_value_layout, StrAttr:$key_layout, StrAttr:$value_layout, SI64Attr:$key_head_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedFastGeluMulOp : OneFlow_BaseOp<"fused_fast_gelu_mul", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$multiplier ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedFastGeluMulGradOp : OneFlow_BaseOp<"fused_fast_gelu_mul_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$out_diff, 
OneFlow_Tensor:$in, OneFlow_Tensor:$multiplier ); let output = (outs OneFlow_Tensor:$in_diff, OneFlow_Tensor:$multiplier_diff ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetBounddingBoxesCoordOp : OneFlow_BaseOp<"fused_get_boundding_boxes_coord", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x1, OneFlow_Tensor:$y1, OneFlow_Tensor:$w1, OneFlow_Tensor:$h1, OneFlow_Tensor:$x2, OneFlow_Tensor:$y2, OneFlow_Tensor:$w2, OneFlow_Tensor:$h2 ); let output = (outs OneFlow_Tensor:$b1_x1, OneFlow_Tensor:$b1_x2, OneFlow_Tensor:$b1_y1, OneFlow_Tensor:$b1_y2, OneFlow_Tensor:$b2_x1, OneFlow_Tensor:$b2_x2, OneFlow_Tensor:$b2_y1, OneFlow_Tensor:$b2_y2 ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetBounddingBoxesCoordGradOp : OneFlow_BaseOp<"fused_get_boundding_boxes_coord_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$b1_x1_diff, OneFlow_Tensor:$b1_x2_diff, OneFlow_Tensor:$b1_y1_diff, OneFlow_Tensor:$b1_y2_diff, OneFlow_Tensor:$b2_x1_diff, OneFlow_Tensor:$b2_x2_diff, OneFlow_Tensor:$b2_y1_diff, OneFlow_Tensor:$b2_y2_diff ); let output = (outs OneFlow_Tensor:$x1_diff, OneFlow_Tensor:$y1_diff, OneFlow_Tensor:$w1_diff, OneFlow_Tensor:$h1_diff, OneFlow_Tensor:$x2_diff, OneFlow_Tensor:$y2_diff, OneFlow_Tensor:$w2_diff, OneFlow_Tensor:$h2_diff ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetCiouResultOp : OneFlow_BaseOp<"fused_get_ciou_result", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$v, OneFlow_Tensor:$iou, OneFlow_Tensor:$rho2, OneFlow_Tensor:$c2 ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$alpha ); let attrs = (ins F32Attr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetCiouResultGradOp : OneFlow_BaseOp<"fused_get_ciou_result_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$alpha, OneFlow_Tensor:$rho2, OneFlow_Tensor:$c2 ); let output = (outs OneFlow_Tensor:$dv, OneFlow_Tensor:$diou, OneFlow_Tensor:$drho2, OneFlow_Tensor:$dc2 ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetIouOp : OneFlow_BaseOp<"fused_get_iou", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$w1, OneFlow_Tensor:$h1, OneFlow_Tensor:$w2, OneFlow_Tensor:$h2, OneFlow_Tensor:$inter ); let output = (outs OneFlow_Tensor:$iou ); let attrs = (ins F32Attr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetIouGradOp : OneFlow_BaseOp<"fused_get_iou_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$diou, OneFlow_Tensor:$w1, OneFlow_Tensor:$h1, OneFlow_Tensor:$w2, OneFlow_Tensor:$h2, OneFlow_Tensor:$inter ); let output = (outs OneFlow_Tensor:$dw1, OneFlow_Tensor:$dh1, OneFlow_Tensor:$dinter ); let attrs = (ins F32Attr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedCenterOp : OneFlow_BaseOp<"fused_get_center_dist", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$b1_x1, OneFlow_Tensor:$b1_x2, OneFlow_Tensor:$b2_x1, OneFlow_Tensor:$b2_x2, OneFlow_Tensor:$b1_y1, OneFlow_Tensor:$b1_y2, OneFlow_Tensor:$b2_y1, OneFlow_Tensor:$b2_y2 ); let output = (outs OneFlow_Tensor:$rho2 ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedCenterGradOp : OneFlow_BaseOp<"fused_get_center_dist_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$b1_x1, OneFlow_Tensor:$b1_x2, OneFlow_Tensor:$b2_x1, OneFlow_Tensor:$b2_x2, OneFlow_Tensor:$b1_y1, OneFlow_Tensor:$b1_y2, OneFlow_Tensor:$b2_y1, OneFlow_Tensor:$b2_y2, OneFlow_Tensor:$rho2_diff ); let output = (outs OneFlow_Tensor:$b1_x1_diff, OneFlow_Tensor:$b1_x2_diff, OneFlow_Tensor:$b2_x1_diff, OneFlow_Tensor:$b2_x2_diff, OneFlow_Tensor:$b1_y1_diff, OneFlow_Tensor:$b1_y2_diff, OneFlow_Tensor:$b2_y1_diff, OneFlow_Tensor:$b2_y2_diff ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetCiouDiagonalAngleOp : OneFlow_BaseOp<"fused_get_ciou_diagonal_angle", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$w1, OneFlow_Tensor:$h1, OneFlow_Tensor:$w2, OneFlow_Tensor:$h2 ); let output = (outs OneFlow_Tensor:$v ); let attrs = (ins DefaultValuedAttr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetCiouDiagonalAngleGradOp : OneFlow_BaseOp<"fused_get_ciou_diagonal_angle_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$w1, OneFlow_Tensor:$h1, OneFlow_Tensor:$w2, OneFlow_Tensor:$h2, OneFlow_Tensor:$v_diff ); let output = (outs OneFlow_Tensor:$w1_diff, OneFlow_Tensor:$h1_diff, OneFlow_Tensor:$w2_diff, OneFlow_Tensor:$h2_diff ); let attrs = (ins DefaultValuedAttr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetIntersectionAreaOp : OneFlow_BaseOp<"fused_get_intersection_area", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$b1_x1, OneFlow_Tensor:$b1_x2, OneFlow_Tensor:$b2_x1, OneFlow_Tensor:$b2_x2, OneFlow_Tensor:$b1_y1, OneFlow_Tensor:$b1_y2, OneFlow_Tensor:$b2_y1, OneFlow_Tensor:$b2_y2 ); let output = (outs OneFlow_Tensor:$inter ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetIntersectionAreaGradOp : OneFlow_BaseOp<"fused_get_intersection_area_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$b1_x1, OneFlow_Tensor:$b1_x2, OneFlow_Tensor:$b2_x1, OneFlow_Tensor:$b2_x2, OneFlow_Tensor:$b1_y1, OneFlow_Tensor:$b1_y2, OneFlow_Tensor:$b2_y1, OneFlow_Tensor:$b2_y2, OneFlow_Tensor:$inter_diff ); let output = (outs OneFlow_Tensor:$b1_x1_diff, OneFlow_Tensor:$b1_x2_diff, OneFlow_Tensor:$b2_x1_diff, OneFlow_Tensor:$b2_x2_diff, OneFlow_Tensor:$b1_y1_diff, OneFlow_Tensor:$b1_y2_diff, OneFlow_Tensor:$b2_y1_diff, OneFlow_Tensor:$b2_y2_diff ); let 
has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetConvexDiagonalSquaredOp : OneFlow_BaseOp<"fused_get_convex_diagonal_squared", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$b1_x1, OneFlow_Tensor:$b1_x2, OneFlow_Tensor:$b2_x1, OneFlow_Tensor:$b2_x2, OneFlow_Tensor:$b1_y1, OneFlow_Tensor:$b1_y2, OneFlow_Tensor:$b2_y1, OneFlow_Tensor:$b2_y2 ); let output = (outs OneFlow_Tensor:$c2 ); let attrs = (ins DefaultValuedAttr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGetConvexDiagonalSquaredGradOp : OneFlow_BaseOp<"fused_get_convex_diagonal_squared_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$c2_diff, OneFlow_Tensor:$b1_x1, OneFlow_Tensor:$b1_x2, OneFlow_Tensor:$b2_x1, OneFlow_Tensor:$b2_x2, OneFlow_Tensor:$b1_y1, OneFlow_Tensor:$b1_y2, OneFlow_Tensor:$b2_y1, OneFlow_Tensor:$b2_y2 ); let output = (outs OneFlow_Tensor:$b1_x1_diff, OneFlow_Tensor:$b1_x2_diff, OneFlow_Tensor:$b2_x1_diff, OneFlow_Tensor:$b2_x2_diff, OneFlow_Tensor:$b1_y1_diff, OneFlow_Tensor:$b1_y2_diff, OneFlow_Tensor:$b2_y1_diff, OneFlow_Tensor:$b2_y2_diff ); let attrs = (ins DefaultValuedAttr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedScaleMaskBiasSoftmaxGradOp : OneFlow_BaseOp<"fused_scale_mask_bias_softmax_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$scale ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedScaleMaskBiasSoftmaxOp : OneFlow_BaseOp<"fused_scale_mask_bias_softmax", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$mask, Optional:$bias ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$scale, DefaultValuedAttr:$inplace ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedCodegeexQkvReshapeOp : OneFlow_BaseOp<"fused_codegeex_qkv_reshape", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$query, OneFlow_Tensor:$key, OneFlow_Tensor:$value ); let output = (outs OneFlow_Tensor:$new_query, OneFlow_Tensor:$new_key, OneFlow_Tensor:$new_value ); let attrs = (ins DefaultValuedAttr:$num_attention_heads ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedClipGradOp : OneFlow_BaseOp<"fused_clip_grad", [NoGrad, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$model_diff ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$max_norm, DefaultValuedAttr:$norm_type ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_NonContiguousBinaryOp : OneFlow_BaseOp<"noncontiguous_binary_op", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods]> { let input = (ins 
OneFlow_Tensor:$lhs, OneFlow_Tensor:$rhs ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$op, DefaultValuedAttr:$inplace ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_NonContiguousBinaryOpGrad : OneFlow_BaseOp<"noncontiguous_binary_op_grad", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$lhs, OneFlow_Tensor:$rhs ); let output = (outs OneFlow_Tensor:$dlhs, OneFlow_Tensor:$drhs ); let attrs = (ins DefaultValuedAttr:$op, DefaultValuedAttr:$inplace ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_FUSED_OP_DEFINITIONS #ifdef GET_ONEFLOW_IDEMPOTENT_OP_DEFINITIONS def OneFlow_AbsOp : OneFlow_IdempotentBaseOp<"abs", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_CeilOp : OneFlow_IdempotentBaseOp<"ceil", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_FloorOp : OneFlow_IdempotentBaseOp<"floor", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_OnesLikeOp : OneFlow_IdempotentBaseOp<"ones_like", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let same_output_regst_num = 1; let has_nd_sbp_infer_fn = 1; let input = (ins AnyType:$like); let output = (outs AnyType:$out); } def OneFlow_ReluOp : OneFlow_IdempotentBaseOp<"relu", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> {} def OneFlow_RintOp : OneFlow_IdempotentBaseOp<"rint", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_RoundOp : OneFlow_IdempotentBaseOp<"round", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_SignOp : OneFlow_IdempotentBaseOp<"sign", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} #endif // GET_ONEFLOW_IDEMPOTENT_OP_DEFINITIONS #ifdef GET_ONEFLOW_IDENTITY_OP_DEFINITIONS def OneFlow_AmpWhiteIdentityOp : OneFlow_BaseOp<"amp_white_identity", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AmpBlackIdentityOp : OneFlow_BaseOp<"amp_black_identity", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_IdentityOp : OneFlow_BaseOp<"identity", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_IdentityBufferOp : OneFlow_BaseOp<"identity_buffer", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$buffer_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TupleIdentityOp : OneFlow_BaseOp<"tuple_identity", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$in ); let output = (outs Variadic:$out ); let has_check_fn = 1; let 
has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_sbp_signature_infer_fn = 1; } def OneFlow_PinnedIdentityOp : OneFlow_BaseOp<"pinned_identity", [DeclareOpInterfaceMethods]> { let summary = "mark defining op of operand can't be erased"; let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_IDENTITY_OP_DEFINITIONS #ifdef GET_ONEFLOW_IMAGE_OP_DEFINITIONS def OneFlow_ImageBatchAlignOp : OneFlow_BaseOp<"image_batch_align", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins ShapeAttr:$shape, OneFlow_DataType:$data_type, DefaultValuedAttr:$alignment, DefaultValuedAttr:$dynamic_out ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_arg_modify_fn = 1; } def OneFlow_ImageDecodeOp : OneFlow_BaseOp<"image_decode", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$color_space, OneFlow_DataType:$data_type ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ImageFlipOp : OneFlow_BaseOp<"image_flip", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$flip_code ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ImageRandomCropOp : OneFlow_BaseOp<"image_random_crop", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$num_attempts, DefaultValuedAttr:$seed, DefaultValuedAttr:$has_seed, F32ArrayAttr:$random_area, F32ArrayAttr:$random_aspect_ratio ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ImageResizeKeepAspectRatioOp : OneFlow_BaseOp<"image_resize_keep_aspect_ratio", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$size, OneFlow_Tensor:$scale ); let attrs = (ins DefaultValuedAttr:$target_size, DefaultValuedAttr:$min_size, DefaultValuedAttr:$max_size, DefaultValuedAttr:$resize_longer, DefaultValuedAttr:$interpolation_type ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ImageResizeToFixedOp : OneFlow_BaseOp<"image_resize_to_fixed", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$scale ); let attrs = (ins DefaultValuedAttr:$target_width, DefaultValuedAttr:$target_height, DefaultValuedAttr:$channels, 
OneFlow_DataType:$data_type, DefaultValuedAttr:$interpolation_type ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_IMAGE_OP_DEFINITIONS #ifdef GET_ONEFLOW_INDICES_OP_DEFINITIONS def OneFlow_ArgSortOp : OneFlow_BaseOp<"arg_sort", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$direction ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ArgmaxOp : OneFlow_BaseOp<"argmax", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ArgwhereOp : OneFlow_BaseOp<"argwhere", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$output, OneFlow_Tensor:$output_size ); let attrs = (ins OneFlow_DataType:$dtype ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BatchGatherOp : OneFlow_BaseOp<"batch_gather", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DimGatherOp : OneFlow_BaseOp<"dim_gather", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$index ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DimScatterAddOp : OneFlow_BaseOp<"dim_scatter_add", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$index, OneFlow_Tensor:$src ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DimScatterAddLikeOp : OneFlow_BaseOp<"dim_scatter_add_like", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$like, OneFlow_Tensor:$index, OneFlow_Tensor:$src ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DimScatterAddScalarOp : OneFlow_BaseOp<"dim_scatter_add_scalar", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$index ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$src_scalar, DefaultValuedAttr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let 
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DimScatterMulOp : OneFlow_BaseOp<"dim_scatter_mul", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$index, OneFlow_Tensor:$src ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DimScatterMulScalarOp : OneFlow_BaseOp<"dim_scatter_mul_scalar", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$index ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$src_scalar, DefaultValuedAttr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DimScatterUpdateOp : OneFlow_BaseOp<"dim_scatter_update", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$index, OneFlow_Tensor:$src ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DimScatterUpdateScalarOp : OneFlow_BaseOp<"dim_scatter_update_scalar", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$index ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$src_scalar, DefaultValuedAttr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_EmbeddingRenormOp : OneFlow_BaseOp<"embedding_renorm", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$max_norm, DefaultValuedAttr:$norm_type ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EmbeddingOp : OneFlow_BaseOp<"embedding", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$weight, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$padding_idx, DefaultValuedAttr:$scale_grad_by_freq ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_FusedApplyRotaryEmbOp : OneFlow_BaseOp<"fused_apply_rotary_emb", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$cos, Optional:$sin, Optional:$position_ids ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$x_layout, DefaultValuedAttr:$output_layout, DefaultValuedAttr:$mode, DefaultValuedAttr:$tensor_index, DefaultValuedAttr:$base, DefaultValuedAttr:$k_size, DefaultValuedAttr:$rotary_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let 
has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EmbeddingGradOp : OneFlow_BaseOp<"embedding_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$weight, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$padding_idx, DefaultValuedAttr:$scale_grad_by_freq ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_GatherOp : OneFlow_BaseOp<"gather", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_GatherNdOp : OneFlow_BaseOp<"gather_nd", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$params, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_GenerateRandomBatchPermutationIndicesOp : OneFlow_BaseOp<"generate_random_batch_permutation_indices", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$seed ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ImageTargetResizeOp : OneFlow_BaseOp<"image_target_resize", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$size, OneFlow_Tensor:$scale ); let attrs = (ins DefaultValuedAttr:$target_size, DefaultValuedAttr:$max_size ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SliceOp : OneFlow_BaseOp<"slice", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64ArrayAttr:$start, SI64ArrayAttr:$stop, SI64ArrayAttr:$step ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SliceUpdateOp : OneFlow_BaseOp<"slice_update", [SupportNonContiguous, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$ref, OneFlow_Tensor:$value ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64ArrayAttr:$start, SI64ArrayAttr:$stop, SI64ArrayAttr:$step ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SliceGradOp : OneFlow_BaseOp<"slice_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins ShapeAttr:$like_shape, SI64ArrayAttr:$start, SI64ArrayAttr:$stop, SI64ArrayAttr:$step ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let 
has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ScatterNdOp : OneFlow_BaseOp<"scatter_nd", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$indices, OneFlow_Tensor:$updates ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins ShapeAttr:$shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ScatterNdLikeOp : OneFlow_BaseOp<"scatter_nd_like", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$like, OneFlow_Tensor:$indices, OneFlow_Tensor:$updates ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TensorScatterNdAddOp : OneFlow_BaseOp<"tensor_scatter_nd_add", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$params, OneFlow_Tensor:$updates, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_TensorScatterNdUpdateOp : OneFlow_BaseOp<"tensor_scatter_nd_update", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$params, OneFlow_Tensor:$updates, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_UnsortedBatchSegmentSumOp : OneFlow_BaseOp<"unsorted_batch_segment_sum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$data, OneFlow_Tensor:$segment_ids ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$num_segments ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_UnsortedSegmentSumOp : OneFlow_BaseOp<"unsorted_segment_sum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$data, OneFlow_Tensor:$segment_ids ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$num_segments ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_UnsortedSegmentSumLikeOp : OneFlow_BaseOp<"unsorted_segment_sum_like", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$data, OneFlow_Tensor:$segment_ids, OneFlow_Tensor:$like ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_WhereOp : OneFlow_BaseOp<"where", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$condition, OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let 
has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MedianOp : OneFlow_BaseOp<"median", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$output ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MedianWithIndicesOp : OneFlow_BaseOp<"median_with_indices", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$values, OneFlow_Tensor:$indices ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SearchSortedOp : OneFlow_BaseOp<"searchsorted", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$sorted_sequence, OneFlow_Tensor:$values ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$out_int32, DefaultValuedAttr:$right ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SearchSortedScalarOp : OneFlow_BaseOp<"searchsorted_scalar", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$sorted_sequence ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$out_int32, DefaultValuedAttr:$right, DefaultValuedAttr:$values ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ModeOp: OneFlow_BaseOp<"mode", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$values, OneFlow_Tensor:$indices ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_INDICES_OP_DEFINITIONS #ifdef GET_ONEFLOW_INVOLUTION_OP_DEFINITIONS def OneFlow_NegativeOp : OneFlow_InvolutionBaseOp<"negative", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_ReciprocalOp : OneFlow_InvolutionBaseOp<"reciprocal", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} #endif // GET_ONEFLOW_INVOLUTION_OP_DEFINITIONS #ifdef GET_ONEFLOW_LOSS_OP_DEFINITIONS def OneFlow_CombinedMarginLossOp : OneFlow_BaseOp<"combined_margin_loss", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$theta ); let attrs = (ins DefaultValuedAttr:$m1, DefaultValuedAttr:$m2, DefaultValuedAttr:$m3, DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_CombinedMarginLossGradOp : OneFlow_BaseOp<"combined_margin_loss_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$label, OneFlow_Tensor:$theta ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$m1, DefaultValuedAttr:$m2, DefaultValuedAttr:$m3, DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CtcLossOp : 
OneFlow_BaseOp<"ctc_loss", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$log_probs, OneFlow_Tensor:$targets, OneFlow_Tensor:$input_lengths, OneFlow_Tensor:$target_lengths ); let output = (outs OneFlow_Tensor:$loss, OneFlow_Tensor:$alpha ); let attrs = (ins DefaultValuedAttr:$max_target_length, DefaultValuedAttr:$blank, DefaultValuedAttr:$zero_infinity ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CtcLossGradOp : OneFlow_BaseOp<"ctc_loss_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$grad_out, OneFlow_Tensor:$log_probs, OneFlow_Tensor:$targets, OneFlow_Tensor:$input_lengths, OneFlow_Tensor:$target_lengths, OneFlow_Tensor:$loss, OneFlow_Tensor:$alpha ); let output = (outs OneFlow_Tensor:$grad ); let attrs = (ins DefaultValuedAttr:$max_target_length, DefaultValuedAttr:$blank, DefaultValuedAttr:$zero_infinity ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DynamicLossScaleScheduleOp : OneFlow_BaseOp<"dynamic_loss_scale_schedule", [DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$count_not_finite, OneFlow_Tensor:$loss_scale, OneFlow_Tensor:$good_step_counter ); let attrs = (ins DefaultValuedAttr:$increment_period, DefaultValuedAttr:$multiplier ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_KlDivLossOp : OneFlow_BaseOp<"kl_div_loss", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$log_target ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_KlDivLossGradOp : OneFlow_BaseOp<"kl_div_loss_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$log_target ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SmoothL1LossOp : OneFlow_BaseOp<"smooth_l1_loss", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$beta ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SmoothL1LossGradOp : OneFlow_BaseOp<"smooth_l1_loss_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$beta ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_LOSS_OP_DEFINITIONS #ifdef GET_ONEFLOW_MATH_OP_DEFINITIONS def OneFlow_AbsGradOp : OneFlow_BaseOp<"abs_grad", [NoMemoryEffect, 
DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ErfOp : OneFlow_BaseOp<"erf", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ErfGradOp : OneFlow_BaseOp<"erf_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ExpOp : OneFlow_BaseOp<"exp", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ExpGradOp : OneFlow_BaseOp<"exp_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Exp2Op : OneFlow_BaseOp<"exp2", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Exp2GradOp : OneFlow_BaseOp<"exp2_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Expm1Op : OneFlow_BaseOp<"expm1", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Expm1GradOp : OneFlow_BaseOp<"expm1_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FloordivXGradOp : OneFlow_BaseOp<"floordiv_x_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FloordivYGradOp : OneFlow_BaseOp<"floordiv_y_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dy ); let has_logical_tensor_desc_infer_fn = 1; let 
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TruncdivXGradOp : OneFlow_BaseOp<"truncdiv_x_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TruncdivYGradOp : OneFlow_BaseOp<"truncdiv_y_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dy ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LgammaOp : OneFlow_BaseOp<"lgamma", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LgammaGradOp : OneFlow_BaseOp<"lgamma_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DigammaOp : OneFlow_BaseOp<"digamma", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DigammaGradOp : OneFlow_BaseOp<"digamma_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TrigammaOp : OneFlow_BaseOp<"trigamma", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LogOp : OneFlow_BaseOp<"log", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Log1pOp : OneFlow_BaseOp<"log1p", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Log1pGradOp : OneFlow_BaseOp<"log1p_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Log2GradOp : OneFlow_BaseOp<"log2_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> 
{ let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Log10GradOp : OneFlow_BaseOp<"log10_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LogGradOp : OneFlow_BaseOp<"log_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LogSigmoidOp : OneFlow_BaseOp<"log_sigmoid", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LogSigmoidGradOp : OneFlow_BaseOp<"log_sigmoid_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReciprocalGradOp : OneFlow_BaseOp<"reciprocal_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReciprocalNoNanOp : OneFlow_BaseOp<"reciprocal_no_nan", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReciprocalNoNanGradOp : OneFlow_BaseOp<"reciprocal_no_nan_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RsqrtOp : OneFlow_BaseOp<"rsqrt", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RsqrtGradOp : OneFlow_BaseOp<"rsqrt_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SigmoidOp : OneFlow_BaseOp<"sigmoid", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let 
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SigmoidGradOp : OneFlow_BaseOp<"sigmoid_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SoftplusOp : OneFlow_BaseOp<"softplus", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$beta, DefaultValuedAttr:$threshold ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SoftplusGradOp : OneFlow_BaseOp<"softplus_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$beta, DefaultValuedAttr:$threshold ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SoftsignGradOp : OneFlow_BaseOp<"softsign_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_VarOp : OneFlow_BaseOp<"var", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins SI32ArrayAttr:$dim, DefaultValuedAttr:$unbiased, DefaultValuedAttr:$keepdim, OneFlow_DataType:$dtype ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SqrtOp : OneFlow_BaseOp<"sqrt", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let hasFolder = 1; } def OneFlow_SqrtGradOp : OneFlow_BaseOp<"sqrt_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SquareOp : OneFlow_BaseOp<"square", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SquareGradOp : OneFlow_BaseOp<"square_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_XlogyXGradOp : OneFlow_BaseOp<"xlogy_x_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, 
OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_XlogyYGradOp : OneFlow_BaseOp<"xlogy_y_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dy ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CumsumOp : OneFlow_BaseOp<"cumsum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64Attr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CumProdOp : OneFlow_BaseOp<"cumprod", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64Attr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CumProdGradOp : OneFlow_BaseOp<"cumprod_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$output, OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins SI64Attr:$dim ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ErfInvOp : OneFlow_BaseOp<"erfinv", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FftC2COp : OneFlow_BaseOp<"fft_c2c", [SupportNonContiguous, NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins SI64ArrayAttr:$dims, BoolAttr:$forward, SI32Attr:$norm_mode, F64Attr:$norm_fct ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FftR2COp : OneFlow_BaseOp<"fft_r2c", [SupportNonContiguous, NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins SI64ArrayAttr:$dims, SI32Attr:$norm_mode, F64Attr:$norm_fct, BoolAttr:$onesided ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FftC2ROp : OneFlow_BaseOp<"fft_c2r", [SupportNonContiguous, NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins SI64ArrayAttr:$dims, SI32Attr:$norm_mode, F64Attr:$norm_fct, SI64Attr:$last_dim_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_StftOp : OneFlow_BaseOp<"stft", [SupportNonContiguous, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, Optional:$window ); let output = (outs OneFlow_Tensor:$out ); let 
attrs = (ins DefaultValuedAttr:$n_fft, DefaultValuedAttr:$hop_length, DefaultValuedAttr:$win_length, DefaultValuedAttr:$center, DefaultValuedAttr:$pad_mode, DefaultValuedAttr:$normalized, DefaultValuedAttr:$onesided, DefaultValuedAttr:$return_complex ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } #endif // GET_ONEFLOW_MATH_OP_DEFINITIONS #ifdef GET_ONEFLOW_MATMUL_OP_DEFINITIONS def OneFlow_BatchMatmulOp : OneFlow_BaseOp<"batch_matmul", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$transpose_a, DefaultValuedAttr:$transpose_b, DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_BroadcastMatmulOp : OneFlow_BaseOp<"broadcast_matmul", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$transpose_a, DefaultValuedAttr:$transpose_b, DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_BroadcastMatmulGradBOp : OneFlow_BaseOp<"broadcast_matmul_grad_b", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_DistributedPartialFcSampleOp : OneFlow_BaseOp<"distributed_partial_fc_sample", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$weight, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$mapped_label, OneFlow_Tensor:$sampled_label, OneFlow_Tensor:$sampled_weight ); let attrs = (ins DefaultValuedAttr:$num_sample, DefaultValuedAttr:$seed ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_DistributedPartialFcSampleDisableBoxingOp : OneFlow_BaseOp<"distributed_partial_fc_sample_disable_boxing", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$sampled_weight_diff, OneFlow_Tensor:$sampled_label ); let output = (outs OneFlow_Tensor:$boxing_disabled_sampled_weight_diff, OneFlow_Tensor:$boxing_disabled_sampled_label ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ErfcOp : OneFlow_BaseOp<"erfc", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ErfcGradOp : OneFlow_BaseOp<"erfc_grad", [NoMemoryEffect, 
DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MatmulOp : OneFlow_BaseOp<"matmul", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$transpose_a, DefaultValuedAttr:$transpose_b, DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_MatrixVectorProductOp : OneFlow_BaseOp<"matrix_vector_product", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MatrixVectorProductGradAOp : OneFlow_BaseOp<"matrix_vector_product_grad_a", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$b ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MatrixVectorProductGradBOp : OneFlow_BaseOp<"matrix_vector_product_grad_b", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$a ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_VectorMatrixProductOp : OneFlow_BaseOp<"vector_matrix_product", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_VectorMatrixProductGradAOp : OneFlow_BaseOp<"vector_matrix_product_grad_a", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$b ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_VectorMatrixProductGradBOp : OneFlow_BaseOp<"vector_matrix_product_grad_b", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$a ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CublasFusedMLPOp : OneFlow_BaseOp<"cublas_fused_mlp", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Variadic:$weights, Variadic:$biases ); let output = (outs OneFlow_Tensor:$out, Variadic:$cublas_aux, Variadic:$hidden ); let attrs = (ins DefaultValuedAttr:$skip_final_activation ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let 
has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CublasFusedMLPGradOp : OneFlow_BaseOp<"cublas_fused_mlp_grad", [NoMemoryEffect, NoGrad, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, Variadic:$weights, Variadic:$cublas_aux, Variadic:$hidden ); let output = (outs OneFlow_Tensor:$d_x, Variadic:$d_biases, Variadic:$d_weights ); let attrs = (ins F32ArrayAttr:$alpha_list ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CublasBiasAddReluMatmulGradOp : OneFlow_BaseOp<"cublas_bias_add_relu_matmul_grad", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$weight, OneFlow_Tensor:$aux ); let output = (outs OneFlow_Tensor:$d_grad, OneFlow_Tensor:$d_bias ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CublasMatmulBiasAddGradOp : OneFlow_BaseOp<"cublas_matmul_bias_add_grad", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$w_grad, OneFlow_Tensor:$b_grad ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedMatmulBiasOp : OneFlow_BaseOp<"fused_matmul_bias", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$weight, OneFlow_Tensor:$bias, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$alpha, DefaultValuedAttr:$beta ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_data_type_infer_fn = 1; let has_get_sbp_fn = 1; } def OneFlow_FusedMatmulBiasAddReluDropoutOp : OneFlow_BaseOp<"fused_matmul_bias_add_relu_dropout", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Variadic:$weights, Variadic:$biases ); let output = (outs OneFlow_Tensor:$out, Variadic:$cublas_aux, Variadic:$hidden ); let attrs = (ins DefaultValuedAttr:$skip_final_activation, DefaultValuedAttr:$seed, F32ArrayAttr:$dropout_rate_list ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedReluDropoutGradOp : OneFlow_BaseOp<"fused_relu_dropout_grad", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$mask ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$scale ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGluOp : OneFlow_BaseOp<"fused_glu", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$w, Optional:$b, Optional:$v, Optional:$c ); let attrs = (ins DefaultValuedAttr:$activation, DefaultValuedAttr:$has_bias, DefaultValuedAttr:$is_split ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let output = (outs OneFlow_Tensor:$y, 
OneFlow_Tensor:$matmul_wx, Optional:$matmul_vx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedGluWithoutLinearGradOp : OneFlow_BaseOp<"fused_glu_without_linear_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$matmul_wx, Optional:$matmul_vx ); let attrs = (ins DefaultValuedAttr:$activation ); let output = (outs OneFlow_Tensor:$d_matmul_wx, Optional:$d_matmul_vx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_GroupedMatmulBiasOp : OneFlow_BaseOp<"grouped_matmul_bias", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$xs, Variadic:$weights, Variadic:$biases ); let output = (outs Variadic:$ys ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_MATMUL_OP_DEFINITIONS #ifdef GET_ONEFLOW_MISC_OP_DEFINITIONS def OneFlow_CategoricalOrdinalEncodeOp : OneFlow_BaseOp<"CategoricalOrdinalEncode", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$table, OneFlow_Tensor:$size, OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$hash_precomputed ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_AddNOp : OneFlow_BaseOp<"add_n", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$in ); let output = (outs OneFlow_Tensor:$out ); let hasCanonicalizer = 1; let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ArangeOp : OneFlow_BaseOp<"arange", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$integer_start, DefaultValuedAttr:$integer_delta, DefaultValuedAttr:$integer_limit, DefaultValuedAttr:$float_start, DefaultValuedAttr:$float_delta, DefaultValuedAttr:$float_limit, OneFlow_DataType:$dtype, StrArrayAttr:$nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_BinCountOp : OneFlow_BaseOp<"bincount", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, Optional:$weight ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$size ); let has_data_type_infer_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; } def OneFlow_CoinFlipOp : OneFlow_BaseOp<"coin_flip", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$probability, DefaultValuedAttr:$batch_size, DefaultValuedAttr:$seed, DefaultValuedAttr:$has_seed, StrArrayAttr:$nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_ConcatOp : OneFlow_BaseOp<"cat", 
[NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$max_dim_size ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TensorConstantOp : OneFlow_BaseOp<"tensor_constant", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins OneFlow_DataType:$dtype, ShapeAttr:$shape, StrArrayAttr:$nd_sbp ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_ConstantOp : OneFlow_BaseOp<"constant", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins ComplexDoubleAttr:$complex_value, DefaultValuedAttr:$floating_value, DefaultValuedAttr:$integer_value, DefaultValuedAttr:$is_floating_value, DefaultValuedAttr:$is_complex_value, OneFlow_DataType:$dtype, ShapeAttr:$shape, StrArrayAttr:$nd_sbp ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_DropoutOp : OneFlow_BaseOp<"dropout", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$mask ); let attrs = (ins DefaultValuedAttr:$rate, DefaultValuedAttr:$seed ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ElementwiseMaximumBackwardOp : OneFlow_BaseOp<"elementwise_maximum_backward", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dz, OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs Optional:$dx, Optional:$dy ); let trait_attrs = (ins DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ElementwiseMinimumBackwardOp : OneFlow_BaseOp<"elementwise_minimum_backward", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dz, OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs Optional:$dx, Optional:$dy ); let trait_attrs = (ins DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EmptyOp : OneFlow_BaseOp<"empty", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins OneFlow_DataType:$dtype, ShapeAttr:$shape, StrArrayAttr:$nd_sbp, DefaultValuedAttr:$pin_memory, StrAttr:$device_type, DefaultValuedAttr:$device_id ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; let has_device_and_stream_infer_fn = 1; } def OneFlow_EyeOp : OneFlow_BaseOp<"eye", 
[NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$rows, DefaultValuedAttr:$cols, OneFlow_DataType:$dtype, StrArrayAttr:$nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_GridSampleGradOp : OneFlow_BaseOp<"grid_sample_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$doutput, OneFlow_Tensor:$input, OneFlow_Tensor:$grid ); let output = (outs OneFlow_Tensor:$dinput, OneFlow_Tensor:$dgrid ); let attrs = (ins StrAttr:$interpolation_mode, StrAttr:$padding_mode, DefaultValuedAttr:$align_corners ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MultiCountNotFiniteOp : OneFlow_BaseOp<"multi_count_not_finite", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$x ); let output = (outs OneFlow_Tensor:$y ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MultiSquareSumOp : OneFlow_BaseOp<"multi_square_sum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$x ); let output = (outs OneFlow_Tensor:$y ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MultiReduceSumPowAbsOp : OneFlow_BaseOp<"multi_reduce_sum_pow_abs", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$p ); let has_logical_tensor_desc_infer_fn = 1; let has_data_type_infer_fn = 1; let has_get_sbp_fn = 1; } def OneFlow_MultiReduceMaxAbsOp : OneFlow_BaseOp<"multi_reduce_max_abs", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_data_type_infer_fn = 1; let has_get_sbp_fn = 1; } def OneFlow_MultiReduceMinAbsOp : OneFlow_BaseOp<"multi_reduce_min_abs", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_data_type_infer_fn = 1; let has_get_sbp_fn = 1; } def OneFlow_LocalMultiReduceMaxAbsOp : OneFlow_BaseOp<"local_multi_reduce_max_abs", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_data_type_infer_fn = 1; let has_get_sbp_fn = 1; } def OneFlow_LocalMultiReduceMinAbsOp : OneFlow_BaseOp<"local_multi_reduce_min_abs", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_data_type_infer_fn = 1; let has_get_sbp_fn = 1; } def OneFlow_NLLOp : OneFlow_BaseOp<"nll", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, Optional:$weight ); let output = (outs OneFlow_Tensor:$output, OneFlow_Tensor:$out_weight ); let attrs = (ins DefaultValuedAttr:$ignore_index ); let 
has_data_type_infer_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_NLLGradOp : OneFlow_BaseOp<"nll_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$out_grad, OneFlow_Tensor:$input, OneFlow_Tensor:$target, Optional:$weight ); let output = (outs OneFlow_Tensor:$in_grad ); let attrs = (ins DefaultValuedAttr:$ignore_index ); let has_data_type_infer_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; } def OneFlow_PowXGradOp : OneFlow_BaseOp<"pow_x_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_PowYGradOp : OneFlow_BaseOp<"pow_y_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dy ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_PreluGradOp : OneFlow_BaseOp<"prelu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$alpha ); let output = (outs OneFlow_Tensor:$dx, OneFlow_Tensor:$alpha_diff ); let attrs = (ins DefaultValuedAttr:$alpha_requires_grad ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RandpermOp : OneFlow_BaseOp<"randperm", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$n, DefaultValuedAttr:$seed, StrArrayAttr:$nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_RecvOp : OneFlow_BaseOp<"recv", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$src_process_id, OneFlow_DataType:$dtype, ShapeAttr:$shape, StrAttr:$device_type, DefaultValuedAttr:$device_id ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; } def OneFlow_SendOp : OneFlow_BaseOp<"send", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let attrs = (ins DefaultValuedAttr:$dst_process_id ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; } def OneFlow_SplitLikeOp : OneFlow_BaseOp<"split_like", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, Variadic:$like ); let output = (outs Variadic:$out ); let attrs = (ins DefaultValuedAttr:$axis ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SspVariableProxyOp : OneFlow_BaseOp<"ssp_variable_proxy", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins 
OneFlow_Tensor:$var ); let output = (outs OneFlow_Tensor:$ref, OneFlow_Tensor:$value ); let attrs = (ins DefaultValuedAttr:$buffer_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_arg_modify_fn = 1; } def OneFlow_TfPreluGradOp : OneFlow_BaseOp<"tf_prelu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$alpha ); let output = (outs OneFlow_Tensor:$dx, OneFlow_Tensor:$alpha_diff ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UniformOp : OneFlow_BaseOp<"uniform", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$from, DefaultValuedAttr:$to, DefaultValuedAttr:$seed, OneFlow_DataType:$dtype, ShapeAttr:$shape, StrArrayAttr:$nd_sbp ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; let has_dump_nd_sbp_signature_for_op_conf_fn = 1; } def OneFlow_UniformIntOp : OneFlow_BaseOp<"uniform_int", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$from, DefaultValuedAttr:$to, DefaultValuedAttr:$seed, OneFlow_DataType:$dtype, ShapeAttr:$shape, StrArrayAttr:$nd_sbp ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_ExponentialOp : OneFlow_BaseOp<"exponential", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$seed, DefaultValuedAttr:$lambd, OneFlow_DataType:$dtype, ShapeAttr:$out_shape, StrArrayAttr:$nd_sbp ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_MultinomialWithReplacementOp : OneFlow_BaseOp<"multinomial_with_replacement", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$prefix_sum ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$seed, DefaultValuedAttr:$num_samples, DefaultValuedAttr:$replacement ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UniqueOp : OneFlow_BaseOp<"unique", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$idx, OneFlow_Tensor:$num_unique ); let attrs = (ins OneFlow_DataType:$out_idx, DefaultValuedAttr:$sorted ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UniqueWithCountsOp : OneFlow_BaseOp<"unique_with_counts", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$idx, OneFlow_Tensor:$num_unique, OneFlow_Tensor:$count ); let attrs = (ins OneFlow_DataType:$out_idx, 
DefaultValuedAttr:$sorted ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_XdivyXGradOp : OneFlow_BaseOp<"xdivy_x_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_XdivyYGradOp : OneFlow_BaseOp<"xdivy_y_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dy ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_StackOp : OneFlow_BaseOp<"stack", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$max_dim_size ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_StackGradOp : OneFlow_BaseOp<"stack_grad", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, Variadic:$like ); let output = (outs Variadic:$out ); let attrs = (ins DefaultValuedAttr:$axis ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_FusedWeightedSumOp : OneFlow_BaseOp<"fused_weighted_sum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins F32ArrayAttr:$weights, DefaultValuedAttr:$alpha ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DependOp : OneFlow_BaseOp<"depend", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$depend_tensor ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_MISC_OP_DEFINITIONS #ifdef GET_ONEFLOW_NCCL_OP_DEFINITIONS def OneFlow__ncclLogical_2DSameDim0All2allOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all2all", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogical_2DSameDim0AllGatherOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_gather", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let 
has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogical_2DSameDim0AllGatherNoncontinuousOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_gather_noncontinuous", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogical_2DSameDim0AllReduceOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_reduce", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogical_2DSameDim1AllReduceOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim1_all_reduce", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogicalAllGatherOp : OneFlow_BaseOp<"_nccl_logical_all_gather", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogicalAllGatherNoncontinuousOp : OneFlow_BaseOp<"_nccl_logical_all_gather_noncontinuous", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogicalAllReduceOp : OneFlow_BaseOp<"_nccl_logical_all_reduce", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogicalReduceScatterOp : OneFlow_BaseOp<"_nccl_logical_reduce_scatter", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def 
OneFlow__ncclLogicalReduceScatterNoncontinuousOp : OneFlow_BaseOp<"_nccl_logical_reduce_scatter_noncontinuous", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogicalS2sOp : OneFlow_BaseOp<"_nccl_logical_s2s", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogicalSendRecvOp : OneFlow_BaseOp<"_nccl_logical_send_recv", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$src_reduced_nd_sbp, StrArrayAttr:$dst_reduced_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow__ncclLogicalFusionOp : OneFlow_BaseOp<"_nccl_logical_fusion", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$in ); let output = (outs Variadic:$out ); let attrs = (ins StrArrayAttr:$src_nd_sbp_str_list, StrArrayAttr:$dst_nd_sbp_str_list, StrArrayAttr:$nccl_type_list ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } #endif // GET_ONEFLOW_NCCL_OP_DEFINITIONS #ifdef GET_ONEFLOW_NORMALIZATION_OP_DEFINITIONS def OneFlow_NormalizationAddReluOp : OneFlow_BaseOp<"normalization_add_relu", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$addend, Optional:$moving_mean, Optional:$moving_variance, OneFlow_Tensor:$gamma, OneFlow_Tensor:$beta ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$reserve_space, Optional:$mean, Optional:$inv_variance ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$training, DefaultValuedAttr:$momentum ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes, DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_BatchNormStatsOp : OneFlow_BaseOp<"batch_norm_stats", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$mean, OneFlow_Tensor:$invstd ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BatchNormGatherStatsWithCountsOp : OneFlow_BaseOp<"batch_norm_gather_stats_with_counts", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, 
OneFlow_Tensor:$mean, OneFlow_Tensor:$invstd, OneFlow_Tensor:$counts, Optional:$running_mean, Optional:$running_var ); let output = (outs OneFlow_Tensor:$global_mean, OneFlow_Tensor:$global_invstd ); let attrs = (ins DefaultValuedAttr:$eps, DefaultValuedAttr:$momentum ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_BatchNormElemtOp : OneFlow_BaseOp<"batch_norm_elemt", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$weight, OneFlow_Tensor:$bias, OneFlow_Tensor:$mean, OneFlow_Tensor:$invstd ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$eps ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BatchNormBackwardReduceOp : OneFlow_BaseOp<"batch_norm_backward_reduce", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$grad_out, OneFlow_Tensor:$input, OneFlow_Tensor:$mean, OneFlow_Tensor:$invstd ); let output = (outs OneFlow_Tensor:$sum_dy, OneFlow_Tensor:$sum_dy_xmu, OneFlow_Tensor:$grad_weight, OneFlow_Tensor:$grad_bias ); let attrs = (ins DefaultValuedAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BatchNormBackwardElemtOp : OneFlow_BaseOp<"batch_norm_backward_elemt", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$grad_out, OneFlow_Tensor:$input, OneFlow_Tensor:$mean, OneFlow_Tensor:$invstd, OneFlow_Tensor:$weight, OneFlow_Tensor:$sum_dy, OneFlow_Tensor:$sum_dy_xmu, OneFlow_Tensor:$count ); let output = (outs OneFlow_Tensor:$grad_in ); let attrs = (ins DefaultValuedAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CropMirrorNormalizeFromTensorbufferOp : OneFlow_BaseOp<"crop_mirror_normalize_from_tensorbuffer", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, Optional:$mirror ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$color_space, DefaultValuedAttr:$output_layout, F32ArrayAttr:$mean, F32ArrayAttr:$std, DefaultValuedAttr:$crop_h, DefaultValuedAttr:$crop_w, DefaultValuedAttr:$crop_pos_x, DefaultValuedAttr:$crop_pos_y, OneFlow_DataType:$output_dtype ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CropMirrorNormalizeFromUint8Op : OneFlow_BaseOp<"crop_mirror_normalize_from_uint8", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, Optional:$mirror ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$color_space, DefaultValuedAttr:$output_layout, F32ArrayAttr:$mean, F32ArrayAttr:$std, DefaultValuedAttr:$crop_h, DefaultValuedAttr:$crop_w, DefaultValuedAttr:$crop_pos_x, DefaultValuedAttr:$crop_pos_y, OneFlow_DataType:$output_dtype ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ImageNormalizeOp : 
OneFlow_BaseOp<"image_normalize", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins F32ArrayAttr:$std, F32ArrayAttr:$mean ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_L2NormalizeOp : OneFlow_BaseOp<"l2_normalize", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$square_x_sum ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$epsilon ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_L2NormalizeGradOp : OneFlow_BaseOp<"l2_normalize_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$y, OneFlow_Tensor:$square_x_sum ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$epsilon ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LayerNormOp : OneFlow_BaseOp<"layer_norm", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$beta, Optional:$gamma ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance ); let attrs = (ins DefaultValuedAttr:$center, DefaultValuedAttr:$scale, DefaultValuedAttr:$begin_norm_axis, DefaultValuedAttr:$begin_params_axis, DefaultValuedAttr:$epsilon ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SkipLayerNormOp : OneFlow_BaseOp<"skip_layer_norm", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$gamma, Optional:$beta, Optional:$bias, Optional:$skip ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance ); let attrs = (ins DefaultValuedAttr:$epsilon, DefaultValuedAttr:$alpha ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance, Optional:$gamma, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$begin_norm_axis, DefaultValuedAttr:$epsilon ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FuseLayerNormGradOp : OneFlow_BaseOp<"fuse_layer_norm_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance, Optional:$gamma, Optional:$_add_to_output ); let output = (outs OneFlow_Tensor:$dx, OneFlow_Tensor:$gamma_diff, 
OneFlow_Tensor:$beta_diff ); let attrs = (ins DefaultValuedAttr:$begin_norm_axis, DefaultValuedAttr:$begin_params_axis, DefaultValuedAttr:$epsilon ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes, DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<"layer_norm_param_grad", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance ); let output = (outs Optional:$beta_diff, Optional:$gamma_diff ); let attrs = (ins DefaultValuedAttr:$begin_params_axis ); let trait_attrs = (ins DenseI32ArrayAttr:$result_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_NormalOp : OneFlow_BaseOp<"normal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$mean, DefaultValuedAttr:$std, DefaultValuedAttr:$seed, OneFlow_DataType:$dtype, ShapeAttr:$shape, StrArrayAttr:$nd_sbp ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_NormalizationOp : OneFlow_NormalizationBaseOp<"normalization", [AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; } def OneFlow_NormalizationGradOp : OneFlow_BaseOp<"normalization_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance, OneFlow_Tensor:$gamma ); let output = (outs OneFlow_Tensor:$gamma_diff, OneFlow_Tensor:$beta_diff, OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$axis, DefaultValuedAttr:$epsilon ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_GroupNormOp : OneFlow_BaseOp<"group_norm", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$beta, Optional:$gamma ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance ); let attrs = (ins DefaultValuedAttr:$affine, DefaultValuedAttr:$num_groups, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$data_format, DefaultValuedAttr:$activation ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_GroupNormGradOp : OneFlow_BaseOp<"group_norm_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance, Optional:$gamma ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$num_groups, DefaultValuedAttr:$epsilon ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_GroupNormParamGradOp : OneFlow_BaseOp<"group_norm_param_grad", 
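The layer_norm / skip_layer_norm / group_norm definitions above (each emitting y together with the saved mean and inv_variance used by the matching *_grad ops) are the kernels reached by the ordinary normalization modules in the Python API. A minimal sketch of that path, assuming the standard oneflow.nn module names:

import oneflow as flow

x = flow.randn(2, 8, 16)                               # (batch, channels, features)
ln = flow.nn.LayerNorm(16)                             # dispatches to the layer_norm user op
y = ln(x)                                              # mean / inv_variance are kept for the backward

gn = flow.nn.GroupNorm(num_groups=4, num_channels=8)   # dispatches to group_norm
z = gn(flow.randn(2, 8, 4, 4))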
[NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$mean, OneFlow_Tensor:$inv_variance ); let output = (outs OneFlow_Tensor:$dgamma, OneFlow_Tensor:$dbeta ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RmsNormOp : OneFlow_BaseOp<"rms_norm", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$weight ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$inv_rms ); let attrs = (ins ShapeAttr:$normalized_shape, DefaultValuedAttr:$epsilon ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RmsNormParamGradOp : OneFlow_BaseOp<"rms_norm_param_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$inv_rms ); let output = (outs OneFlow_Tensor:$weight_grad ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RmsNormGradOp : OneFlow_BaseOp<"rms_norm_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, OneFlow_Tensor:$inv_rms, Optional:$weight ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SkipRmsNormOp : OneFlow_BaseOp<"skip_rms_norm", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, Optional:$weight, Optional:$bias, Optional:$skip ); let output = (outs OneFlow_Tensor:$y, OneFlow_Tensor:$inv_rms ); let attrs = (ins ShapeAttr:$normalized_shape, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_NORMALIZATION_OP_DEFINITIONS #ifdef GET_ONEFLOW_OPTIMIZER_OP_DEFINITIONS def OneFlow_AdagradUpdateOp : OneFlow_BaseOp<"adagrad_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if, Optional:$train_step, OneFlow_Tensor:$sum ); let attrs = (ins DefaultValuedAttr:$train_step_val, DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$lr_decay, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$epsilon ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_AdamBiasCorrectionFactorOp : OneFlow_BaseOp<"adam_bias_correction_factor", [NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$train_step ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$beta ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AdamUpdateOp : OneFlow_BaseOp<"adam_update", [NoGrad, 
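The "normalization" op (defined above via OneFlow_NormalizationBaseOp, with a canonicalizer) together with normalization_grad is plain batch normalization. In eager mode it is reached through the usual module, which also maintains the running statistics the op updates. A minimal sketch, assuming the standard module name:

import oneflow as flow

bn = flow.nn.BatchNorm2d(8)          # forward runs the "normalization" op, backward runs normalization_grad
bn.train()
y = bn(flow.randn(4, 8, 16, 16))     # updates bn.running_mean / bn.running_var in place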
AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, Optional:$model_copy, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if, Optional:$bias_correction1, Optional:$bias_correction2, OneFlow_Tensor:$m, OneFlow_Tensor:$v, Optional:$max_v ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$bias_correction1_val, DefaultValuedAttr:$bias_correction2_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$beta1, DefaultValuedAttr:$beta2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$amsgrad, DefaultValuedAttr:$do_bias_correction ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_IndexedSlicesAdamUpdateOp : OneFlow_BaseOp<"indexed_slices_adam_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff_indices, OneFlow_Tensor:$model_diff_values, OneFlow_Tensor:$learning_rate, Optional:$bias_correction1, Optional:$bias_correction2, OneFlow_Tensor:$m, OneFlow_Tensor:$v, Optional:$max_v ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$beta1, DefaultValuedAttr:$beta2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$amsgrad, DefaultValuedAttr:$do_bias_correction ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_IndexedSlicesMomentumUpdateOp : OneFlow_BaseOp<"indexed_slices_momentum_update", [NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff_indices, OneFlow_Tensor:$model_diff_values, OneFlow_Tensor:$learning_rate, OneFlow_Tensor:$momentum ); let attrs = (ins DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$beta, DefaultValuedAttr:$dampening, DefaultValuedAttr:$nesterov, DefaultValuedAttr:$maximize, DefaultValuedAttr:$weight_decay ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_IndexedSlicesSgdUpdateOp : OneFlow_BaseOp<"indexed_slices_sgd_update", [NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff_indices, OneFlow_Tensor:$model_diff_values, OneFlow_Tensor:$learning_rate ); let attrs = (ins DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$weight_decay ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_LambUpdateOp : OneFlow_BaseOp<"lamb_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if, Optional:$bias_correction1, Optional:$bias_correction2, OneFlow_Tensor:$m, OneFlow_Tensor:$v ); let attrs = (ins 
DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$bias_correction1_val, DefaultValuedAttr:$bias_correction2_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$beta1, DefaultValuedAttr:$beta2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$do_bias_correction ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_LarsUpdateOp : OneFlow_BaseOp<"lars_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, OneFlow_Tensor:$learning_rate, OneFlow_Tensor:$momentum, Optional:$scale_by_tensor, Optional:$skip_if ); let attrs = (ins DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$momentum_beta, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$lars_coefficient, DefaultValuedAttr:$weight_decay ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MomentumUpdateOp : OneFlow_BaseOp<"momentum_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, OneFlow_Tensor:$momentum, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$beta, DefaultValuedAttr:$dampening, DefaultValuedAttr:$nesterov, DefaultValuedAttr:$maximize, DefaultValuedAttr:$weight_decay ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_RmspropUpdateOp : OneFlow_BaseOp<"rmsprop_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if, OneFlow_Tensor:$mean_square, Optional:$mean_gradient ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$centered, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$decay_rate, DefaultValuedAttr:$weight_decay ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SgdUpdateOp : OneFlow_BaseOp<"sgd_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, Optional:$model_copy, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, 
DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_FtrlUpdateOp : OneFlow_BaseOp<"ftrl_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, Optional:$learning_rate, Optional:$skip_if, OneFlow_Tensor:$accumulate, OneFlow_Tensor:$z ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$lr_power, DefaultValuedAttr:$lambda1, DefaultValuedAttr:$lambda2, DefaultValuedAttr:$beta ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_AdadeltaUpdateOp : OneFlow_BaseOp<"adadelta_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, Optional:$learning_rate, Optional:$skip_if, OneFlow_Tensor:$square_avgs, OneFlow_Tensor:$acc_deltas ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$rho, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$maximize ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MultiTensorSgdUpdateOp : OneFlow_BaseOp<"multi_tensor_sgd_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$model, Variadic:$model_diff, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MultiTensorMomentumUpdateOp : OneFlow_BaseOp<"multi_tensor_momentum_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$model, Variadic:$model_diff, Variadic:$momentum_buf, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$momentum, DefaultValuedAttr:$dampening, DefaultValuedAttr:$nesterov, DefaultValuedAttr:$maximize ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let 
has_input_arg_modify_fn = 1; } def OneFlow_MultiTensorAdamUpdateOp : OneFlow_BaseOp<"multi_tensor_adam_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$model, Variadic:$model_diff, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if, Optional:$bias_correction1, Optional:$bias_correction2, Variadic:$m, Variadic:$v ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$bias_correction1_val, DefaultValuedAttr:$bias_correction2_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$beta1, DefaultValuedAttr:$beta2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$amsgrad, DefaultValuedAttr:$do_bias_correction ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MultiTensorSgdUpdateWithCastOp : OneFlow_BaseOp<"multi_tensor_sgd_update_with_cast", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$model, Variadic:$model_diff, Variadic:$model_copy, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MultiTensorMomentumUpdateWithCastOp : OneFlow_BaseOp<"multi_tensor_momentum_update_with_cast", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$model, Variadic:$model_diff, Variadic:$model_copy, Variadic:$momentum_buf, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$momentum, DefaultValuedAttr:$dampening, DefaultValuedAttr:$nesterov, DefaultValuedAttr:$maximize ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MultiTensorAdamUpdateWithCastOp : OneFlow_BaseOp<"multi_tensor_adam_update_with_cast", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$model, Variadic:$model_diff, Variadic:$model_copy, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$skip_if, Optional:$bias_correction1, Optional:$bias_correction2, Variadic:$m, Variadic:$v ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$learning_rate_scale, DefaultValuedAttr:$bias_correction1_val, DefaultValuedAttr:$bias_correction2_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$beta1, DefaultValuedAttr:$beta2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$amsgrad, DefaultValuedAttr:$do_bias_correction 
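The *_update definitions in this optimizer block are in-place parameter-update kernels: the optional inputs (learning_rate, scale_by_tensor, skip_if, bias_correction1/2, ...) are tensors the runtime may feed instead of the corresponding scalar attrs, and the multi_tensor_* variants batch many parameters into one launch. From Python they are reached through oneflow.optim; a minimal sketch, assuming the usual optimizer class names:

import oneflow as flow

model = flow.nn.Linear(4, 2)
opt = flow.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)

loss = model(flow.randn(8, 4)).sum()
loss.backward()
opt.step()        # runs adam_update per parameter; m and v are the optimizer state tensors
opt.zero_grad()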
); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MultiTensorYoloV5WeightUpdateOp : OneFlow_BaseOp<"multi_tensor_yolov5_weight_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins Variadic:$model, Variadic:$model_update ); let attrs = (ins DefaultValuedAttr:$d ); let trait_attrs = (ins DenseI32ArrayAttr:$operand_segment_sizes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } #endif // GET_ONEFLOW_OPTIMIZER_OP_DEFINITIONS #ifdef GET_ONEFLOW_PADDING_OP_DEFINITIONS def OneFlow_PadOp : OneFlow_BaseOp<"pad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64ArrayAttr:$padding_before, SI64ArrayAttr:$padding_after, SI64ArrayAttr:$padding, DefaultValuedAttr:$floating_constant_value, DefaultValuedAttr:$integral_constant_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReplicationPad1DOp : OneFlow_BaseOp<"replication_pad1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64ArrayAttr:$padding ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ReplicationPad1DGradOp : OneFlow_BaseOp<"replication_pad1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins SI64ArrayAttr:$padding ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReflectionPad1DOp : OneFlow_BaseOp<"reflection_pad1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64ArrayAttr:$padding ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ReflectionPad1DGradOp : OneFlow_BaseOp<"reflection_pad1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins SI64ArrayAttr:$padding ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReflectionPad2DOp : OneFlow_BaseOp<"reflection_pad2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64ArrayAttr:$padding ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ReflectionPad2DGradOp : OneFlow_BaseOp<"reflection_pad2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy ); let output = (outs 
OneFlow_Tensor:$dx ); let attrs = (ins SI64ArrayAttr:$padding ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReplicationPad2DOp : OneFlow_BaseOp<"replication_pad2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI64ArrayAttr:$padding ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ReplicationPad2DGradOp : OneFlow_BaseOp<"replication_pad2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins SI64ArrayAttr:$padding ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SamePaddingOp : OneFlow_BaseOp<"same_padding", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins StrAttr:$padding, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SamePaddingGradOp : OneFlow_BaseOp<"same_padding_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x_like, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins StrAttr:$padding, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_PADDING_OP_DEFINITIONS #ifdef GET_ONEFLOW_PARALLEL_CAST_OP_DEFINITIONS def OneFlow_HierarchicalParallelCastOp : OneFlow_BaseOp<"hierarchical_parallel_cast", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrArrayAttr:$nd_sbp, StrAttr:$grad_mode, StrArrayAttr:$grad_nd_sbp ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; let has_get_nd_sbp_fn = 1; } def OneFlow_HierarchicalParallelCastLikeOp : OneFlow_BaseOp<"hierarchical_parallel_cast_like", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$like ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_ParallelCastOp : OneFlow_BaseOp<"parallel_cast", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$sbp_parallel, StrAttr:$grad_sbp_parallel ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_sbp_signature_infer_fn = 1; } #endif // GET_ONEFLOW_PARALLEL_CAST_OP_DEFINITIONS #ifdef GET_ONEFLOW_POOL_OP_DEFINITIONS def OneFlow_AdaptiveAvgPool1DOp 
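The padding block above mirrors the torch-style padding API: constant padding goes through the generic "pad" op, while the reflection/replication variants each get a dedicated 1-D/2-D op with a matching *_grad kernel, and same_padding handles the TF-style "SAME" convolution padding. A minimal sketch of the eager entry points, assuming the standard module and functional names:

import oneflow as flow
import oneflow.nn.functional as F

x = flow.randn(1, 3, 8, 8)
y = flow.nn.ReflectionPad2d((1, 1, 2, 2))(x)              # reflection_pad2d, padding = [1, 1, 2, 2]
z = F.pad(x, (2, 2, 2, 2), mode="constant", value=0.0)    # generic pad op with a constant fill value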
: OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveAvgPool1DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveAvgPool2DOp : OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveAvgPool2DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveAvgPool3DOp : OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveAvgPool3DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveMaxPool1DOp : OneFlow_AdaptiveMaxPoolBaseOp<"adaptive_max_pool1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveMaxPool1DGradOp : OneFlow_AdaptiveMaxPoolGradBaseOp<"adaptive_max_pool1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveMaxPool2DOp : OneFlow_AdaptiveMaxPoolBaseOp<"adaptive_max_pool2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveMaxPool2DGradOp : OneFlow_AdaptiveMaxPoolGradBaseOp<"adaptive_max_pool2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveMaxPool3DOp : OneFlow_AdaptiveMaxPoolBaseOp<"adaptive_max_pool3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AdaptiveMaxPool3DGradOp : OneFlow_AdaptiveMaxPoolGradBaseOp<"adaptive_max_pool3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AvgPool1DOp : OneFlow_AvgPoolBaseOp<"avg_pool_1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AvgPool1DGradOp : OneFlow_AvgPoolGradBaseOp<"avg_pool_1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AvgPool2DOp : OneFlow_AvgPoolBaseOp<"avg_pool_2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AvgPool2DGradOp : OneFlow_AvgPoolGradBaseOp<"avg_pool_2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AvgPool3DOp : OneFlow_AvgPoolBaseOp<"avg_pool_3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_AvgPool3DGradOp : OneFlow_AvgPoolGradBaseOp<"avg_pool_3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxPool1DOp : OneFlow_MaxPoolBaseOp<"max_pool_1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxPool1DGradOp : OneFlow_MaxPoolGradBaseOp<"max_pool_1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxPool2DOp : OneFlow_MaxPoolBaseOp<"max_pool_2d", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> {} def OneFlow_MaxPool2DGradOp : OneFlow_MaxPoolGradBaseOp<"max_pool_2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxPool3DOp : OneFlow_MaxPoolBaseOp<"max_pool_3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxPool3DGradOp : OneFlow_MaxPoolGradBaseOp<"max_pool_3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxUnpool1DOp : OneFlow_MaxUnpoolBaseOp<"max_unpool_1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxUnpool2DOp : OneFlow_MaxUnpoolBaseOp<"max_unpool_2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxUnpool3DOp : OneFlow_MaxUnpoolBaseOp<"max_unpool_3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxUnpool1DGradOp : OneFlow_MaxUnpoolGradBaseOp<"max_unpool_1d_grad", [NoMemoryEffect, 
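The pool section consists almost entirely of thin wrappers over shared base classes (OneFlow_AdaptivePoolBaseOp, OneFlow_AvgPoolBaseOp, OneFlow_MaxPoolBaseOp, the unpool and tf_* variants), so each def only fixes the op name and trait list. On the Python side these are the familiar pooling modules; a minimal sketch, assuming the standard names:

import oneflow as flow

x = flow.randn(1, 3, 32, 32)
p1 = flow.nn.MaxPool2d(kernel_size=2, stride=2)(x)   # max_pool_2d (max_pool_2d_grad in the backward)
p2 = flow.nn.AvgPool2d(kernel_size=2)(x)             # avg_pool_2d
p3 = flow.nn.AdaptiveAvgPool2d((1, 1))(x)            # adaptive_avg_pool2d -> shape (1, 3, 1, 1)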
DeclareOpInterfaceMethods]> {} def OneFlow_MaxUnpool2DGradOp : OneFlow_MaxUnpoolGradBaseOp<"max_unpool_2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_MaxUnpool3DGradOp : OneFlow_MaxUnpoolGradBaseOp<"max_unpool_3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfAvgPool1DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfAvgPool1DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfAvgPool2DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfAvgPool2DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfAvgPool3DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfAvgPool3DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfMaxPool1DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfMaxPool1DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfMaxPool2DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfMaxPool2DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfMaxPool3DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} def OneFlow_TfMaxPool3DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {} #endif // GET_ONEFLOW_POOL_OP_DEFINITIONS #ifdef GET_ONEFLOW_QUANTIZATION_OP_DEFINITIONS def OneFlow_FakeQuantizationOp : OneFlow_BaseOp<"fake_quantization", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$scale, OneFlow_Tensor:$zero_point ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$quantization_formula, DefaultValuedAttr:$quantization_bit, DefaultValuedAttr:$quantization_scheme ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MinMaxObserverOp : OneFlow_BaseOp<"min_max_observer", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$scale, OneFlow_Tensor:$zero_point ); let attrs = (ins DefaultValuedAttr:$quantization_formula, DefaultValuedAttr:$quantization_bit, DefaultValuedAttr:$quantization_scheme, DefaultValuedAttr:$per_layer_quantization ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_MovingAverageMinMaxObserverOp : OneFlow_BaseOp<"moving_average_min_max_observer", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$current_train_step, OneFlow_Tensor:$moving_max, OneFlow_Tensor:$moving_min ); let output = (outs OneFlow_Tensor:$scale, OneFlow_Tensor:$zero_point ); let attrs = (ins DefaultValuedAttr:$training, DefaultValuedAttr:$quantization_formula, DefaultValuedAttr:$stop_update_after_iters, DefaultValuedAttr:$quantization_bit, 
DefaultValuedAttr:$quantization_scheme, DefaultValuedAttr:$momentum ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_QuantizationOp : OneFlow_BaseOp<"quantization", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$scale, OneFlow_Tensor:$zero_point ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$quantization_formula, DefaultValuedAttr:$quantization_bit, DefaultValuedAttr:$quantization_scheme ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_GroupwiseDequantizeOp : OneFlow_BaseOp<"groupwise_dequantize", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$scale, Optional:$zero ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$num_bits, DefaultValuedAttr:$symmetric, SI64Attr:$group_dim, SI64Attr:$group_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FusedLinearWithGroupwiseQuantizedWeightOp : OneFlow_BaseOp<"fused_linear_with_groupwise_quantized_weight", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$w, OneFlow_Tensor:$w_scale, Optional:$w_zero, Optional:$b ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$num_bits, DefaultValuedAttr:$symmetric, SI64Attr:$group_dim, SI64Attr:$group_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_QUANTIZATION_OP_DEFINITIONS #ifdef GET_ONEFLOW_REDUCE_OP_DEFINITIONS def OneFlow_IndexedSlicesReduceSumOp : OneFlow_BaseOp<"indexed_slices_reduce_sum", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x_indices, OneFlow_Tensor:$x_values ); let output = (outs OneFlow_Tensor:$y_indices, OneFlow_Tensor:$y_values, OneFlow_Tensor:$num_unique ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceAllOp : OneFlow_BaseOp<"reduce_all", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_tensor ); let output = (outs OneFlow_Tensor:$output_tensor ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceAnyOp : OneFlow_BaseOp<"reduce_any", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_tensor ); let output = (outs OneFlow_Tensor:$output_tensor ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceMaxOp : OneFlow_BaseOp<"reduce_max", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_tensor ); let output = (outs OneFlow_Tensor:$output_tensor ); let attrs = 
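The quantization block pairs observer ops (min_max_observer, moving_average_min_max_observer) that produce a scale and zero_point with the fake_quantization / quantization ops that consume them, plus groupwise-dequantize ops for quantized-weight inference. A rough sketch of quantization-aware-training usage, assuming oneflow.nn exposes the observer and fake-quant modules with the keyword arguments used in OneFlow's quantization docs (treat the exact signatures as an assumption):

import oneflow as flow

weight = flow.randn(16, 8)
# assumed module names/kwargs; they mirror the attrs of min_max_observer above
observer = flow.nn.MinMaxObserver(quantization_formula="google",
                                  quantization_bit=8,
                                  quantization_scheme="symmetric",
                                  per_layer_quantization=True)
scale, zero_point = observer(weight)                   # min_max_observer op
fake_quant = flow.nn.FakeQuantization(quantization_formula="google",
                                      quantization_bit=8,
                                      quantization_scheme="symmetric")
w_q = fake_quant(weight, scale, zero_point)            # fake_quantization: quantize-dequantize pass-through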
(ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceMaxDeviceStageOp : OneFlow_BaseOp<"reduce_max_device_stage", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$mask, OneFlow_Tensor:$count ); let attrs = (ins SI32ArrayAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceMaxDeviceStageGradOp : OneFlow_BaseOp<"reduce_max_device_stage_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$out_diff, OneFlow_Tensor:$mask, OneFlow_Tensor:$count ); let output = (outs OneFlow_Tensor:$in_diff ); let attrs = (ins SI32ArrayAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceMaxGlobalStageOp : OneFlow_BaseOp<"reduce_max_global_stage", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$device_count ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$mask ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ReduceMaxGlobalStageGradOp : OneFlow_BaseOp<"reduce_max_global_stage_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$out_diff, OneFlow_Tensor:$mask, OneFlow_Tensor:$device_count ); let output = (outs OneFlow_Tensor:$in_diff ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceMinOp : OneFlow_BaseOp<"reduce_min", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_tensor ); let output = (outs OneFlow_Tensor:$output_tensor ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceMinDeviceStageOp : OneFlow_BaseOp<"reduce_min_device_stage", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$mask, OneFlow_Tensor:$count ); let attrs = (ins SI32ArrayAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceMinDeviceStageGradOp : OneFlow_BaseOp<"reduce_min_device_stage_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$out_diff, OneFlow_Tensor:$mask, OneFlow_Tensor:$count ); let output = (outs OneFlow_Tensor:$in_diff ); let attrs = (ins SI32ArrayAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceMinGlobalStageOp : OneFlow_BaseOp<"reduce_min_global_stage", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, 
OneFlow_Tensor:$device_count ); let output = (outs OneFlow_Tensor:$out, OneFlow_Tensor:$mask ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_ReduceMinGlobalStageGradOp : OneFlow_BaseOp<"reduce_min_global_stage_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$out_diff, OneFlow_Tensor:$mask, OneFlow_Tensor:$device_count ); let output = (outs OneFlow_Tensor:$in_diff ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceProdOp : OneFlow_BaseOp<"reduce_prod", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_tensor ); let output = (outs OneFlow_Tensor:$output_tensor ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceSumOp : OneFlow_BaseOp<"reduce_sum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_tensor ); let output = (outs OneFlow_Tensor:$output_tensor ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceNanSumOp : OneFlow_BaseOp<"reduce_nansum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input_tensor ); let output = (outs OneFlow_Tensor:$output_tensor ); let attrs = (ins SI32ArrayAttr:$axis, DefaultValuedAttr:$keepdims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ReduceSumLikeOp : OneFlow_BaseOp<"reduce_sum_like", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$like ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI32ArrayAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } #endif // GET_ONEFLOW_REDUCE_OP_DEFINITIONS #ifdef GET_ONEFLOW_RESHAPE_OP_DEFINITIONS def OneFlow_ReshapeOp : OneFlow_BaseOp<"reshape", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins ShapeAttr:$shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_enumerate_nd_sbp_signatures_fn = 1; let has_data_type_infer_fn = 1; let hasFolder = 1; } def OneFlow_ReshapeLikeOp : OneFlow_BaseOp<"reshape_like", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$like ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } #endif // GET_ONEFLOW_RESHAPE_OP_DEFINITIONS #ifdef GET_ONEFLOW_SCALAR_OP_DEFINITIONS def OneFlow_ClipByScalarOp : OneFlow_BaseOp<"clip_by_scalar", 
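The reduce block declares one op per reduction kind, each taking an axis list and a keepdims flag; the device-stage/global-stage pairs exist so min/max reductions can be split across devices and then combined. The reshape ops that follow carry the target shape as a ShapeAttr. A minimal sketch of the eager entry points, assuming the standard functional names:

import oneflow as flow

x = flow.randn(2, 3, 4)
s = flow.sum(x, dim=(1, 2), keepdim=True)   # reduce_sum, axis = [1, 2], keepdims = True
a = flow.all(x > 0)                          # reduce_all over every axis
r = flow.reshape(x, (6, 4))                  # reshape, shape attr = (6, 4)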
[NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$floating_min, DefaultValuedAttr:$integral_min, DefaultValuedAttr:$floating_max, DefaultValuedAttr:$integral_max ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ClipByScalarGradOp : OneFlow_BaseOp<"clip_by_scalar_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$floating_min, DefaultValuedAttr:$integral_min, DefaultValuedAttr:$floating_max, DefaultValuedAttr:$integral_max ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ClipByScalarMaxOp : OneFlow_BaseOp<"clip_by_scalar_max", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$floating_max, DefaultValuedAttr:$integral_max ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ClipByScalarMaxGradOp : OneFlow_BaseOp<"clip_by_scalar_max_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$floating_max, DefaultValuedAttr:$integral_max ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ClipByScalarMinOp : OneFlow_BaseOp<"clip_by_scalar_min", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$floating_min, DefaultValuedAttr:$integral_min ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ClipByScalarMinGradOp : OneFlow_BaseOp<"clip_by_scalar_min_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$floating_min, DefaultValuedAttr:$integral_min ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarAddOp : OneFlow_BaseOp<"scalar_add", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let hasFolder = 1; } def OneFlow_ScalarAddByTensorOp : OneFlow_BaseOp<"scalar_add_by_tensor", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$scalar ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } // 
host_scalar_add_by_tensor op just for test host memory input def OneFlow_HostScalarAddByTensorOp : OneFlow_BaseOp<"host_scalar_add_by_tensor", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$scalar ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarDivByTensorOp : OneFlow_BaseOp<"scalar_div_by_tensor", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$scalar ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarFloordivOp : OneFlow_BaseOp<"scalar_floordiv", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarTruncdivOp : OneFlow_BaseOp<"scalar_truncdiv", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarFmodOp : OneFlow_BaseOp<"scalar_fmod", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalAndOp : OneFlow_BaseOp<"scalar_logical_and", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalEqualOp : OneFlow_BaseOp<"scalar_logical_equal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalGreaterOp : OneFlow_BaseOp<"scalar_logical_greater", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins 
DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalGreaterEqualOp : OneFlow_BaseOp<"scalar_logical_greater_equal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalInplaceGreaterOp : OneFlow_BaseOp<"scalar_logical_inplace_greater", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalLessOp : OneFlow_BaseOp<"scalar_logical_less", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalLessEqualOp : OneFlow_BaseOp<"scalar_logical_less_equal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalNotEqualOp : OneFlow_BaseOp<"scalar_logical_not_equal", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalOrOp : OneFlow_BaseOp<"scalar_logical_or", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLogicalXorOp : OneFlow_BaseOp<"scalar_logical_xor", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins 
OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarMulOp : OneFlow_BaseOp<"scalar_mul", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarMulByTensorOp : OneFlow_BaseOp<"scalar_mul_by_tensor", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$scalar ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarDivOp : OneFlow_BaseOp<"scalar_div", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarPowOp : OneFlow_BaseOp<"scalar_pow", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarPowGradOp : OneFlow_BaseOp<"scalar_pow_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarReversePowOp : OneFlow_BaseOp<"scalar_reverse_pow", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarReversePowGradOp : OneFlow_BaseOp<"scalar_reverse_pow_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, 
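Every op in this scalar block carries the same four attrs (has_int_operand / has_float_operand plus an integer and a floating operand slot), which lets a Python scalar on either side of an arithmetic or comparison operator be folded into the op as an attribute instead of being materialized as a tensor; clip_by_scalar, defined earlier in the block, works the same way for clamping bounds. A minimal sketch, assuming the usual operator overloads:

import oneflow as flow

x = flow.randn(3, 3)
y = x * 2.0 + 1.0                        # scalar_mul then scalar_add (float_operand = 2.0 / 1.0)
p = x ** 2                               # scalar_pow with the integer operand set
m = x > 0.5                              # scalar_logical_greater
c = flow.clamp(x, min=-1.0, max=1.0)     # clip_by_scalar (floating_min / floating_max)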
DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarSubByTensorOp : OneFlow_BaseOp<"scalar_sub_by_tensor", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$scalar ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLerpOp : OneFlow_BaseOp<"scalar_lerp", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$start, OneFlow_Tensor:$end ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarLerpGradOp : OneFlow_BaseOp<"scalar_lerp_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$start, OneFlow_Tensor:$end, OneFlow_Tensor:$out_diff ); let output = (outs OneFlow_Tensor:$start_diff, OneFlow_Tensor:$end_diff ); let attrs = (ins DefaultValuedAttr:$has_int_operand, DefaultValuedAttr:$has_float_operand, DefaultValuedAttr:$int_operand, DefaultValuedAttr:$float_operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarBitwiseAndOp : OneFlow_BaseOp<"scalar_bitwise_and", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarBitwiseOrOp : OneFlow_BaseOp<"scalar_bitwise_or", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ScalarBitwiseXorOp : OneFlow_BaseOp<"scalar_bitwise_xor", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$operand ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_SCALAR_OP_DEFINITIONS #ifdef GET_ONEFLOW_SOFTMAX_OP_DEFINITIONS def OneFlow_LogSoftmaxOp : OneFlow_BaseOp<"log_softmax", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$prob ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LogSoftmaxGradOp : OneFlow_BaseOp<"log_softmax_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prob, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let 
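// End of the scalar op group; the GET_ONEFLOW_SOFTMAX_OP_DEFINITIONS block that opens
// above collects the softmax family. The pattern is uniform: each forward op maps its
// input to a probability tensor ($prob or $out), and the matching *_grad op consumes the
// saved forward result together with $dy to produce $dx.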
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SoftmaxOp : OneFlow_BaseOp<"softmax", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_compute_complexity_fn = 1; } def OneFlow_SoftmaxCrossEntropyOp : OneFlow_BaseOp<"softmax_cross_entropy", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$prob, OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SoftmaxCrossEntropyGradOp : OneFlow_BaseOp<"softmax_cross_entropy_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$label, OneFlow_Tensor:$prob ); let output = (outs OneFlow_Tensor:$prediction_diff ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SoftmaxGradOp : OneFlow_BaseOp<"softmax_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SparseSoftmaxCrossEntropyOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$prob, OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SparseSoftmaxCrossEntropyGradOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$label, OneFlow_Tensor:$dy, OneFlow_Tensor:$prob ); let output = (outs OneFlow_Tensor:$prediction_diff ); let attrs = (ins DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SparseSoftmaxCrossEntropyMsOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_ms", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$prediction, OneFlow_Tensor:$label ); let output = (outs OneFlow_Tensor:$prob, OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_SparseSoftmaxCrossEntropyMsGradOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_ms_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$label, OneFlow_Tensor:$dy, OneFlow_Tensor:$prob ); let output = (outs OneFlow_Tensor:$prediction_diff ); let attrs = (ins DefaultValuedAttr:$depth ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let 
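// sparse_softmax_cross_entropy_ms / *_ms_grad appear (from the "ms" suffix and the shared
// $depth attribute) to be the variants used when the prediction is split along the class
// dimension across devices, with $depth carrying the total number of classes; this
// reading is inferred from the signatures, not stated in the definitions themselves.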
has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_SOFTMAX_OP_DEFINITIONS #ifdef GET_ONEFLOW_SUMMARY_OP_DEFINITIONS def OneFlow_CreateSummaryWriterOp : OneFlow_BaseOp<"create_summary_writer", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let attrs = (ins StrAttr:$logdir ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FlushSummaryWriterOp : OneFlow_BaseOp<"flush_summary_writer", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SummaryWriteHistogramOp : OneFlow_BaseOp<"summary_write_histogram", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$step, OneFlow_Tensor:$tag ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SummaryWriteImageOp : OneFlow_BaseOp<"summary_write_image", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$step, OneFlow_Tensor:$tag ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SummaryWritePbOp : OneFlow_BaseOp<"summary_write_pb", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$step ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SummaryWriteScalarOp : OneFlow_BaseOp<"summary_write_scalar", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$step, OneFlow_Tensor:$tag ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_SUMMARY_OP_DEFINITIONS #ifdef GET_ONEFLOW_TENSOR_BUFFER_OP_DEFINITIONS def OneFlow_GenTensorBufferOp : OneFlow_BaseOp<"gen_tensor_buffer", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let output = (outs OneFlow_Tensor:$out ); let attrs = (ins ShapeAttr:$shape, ShapeArrayAttr:$shape_list, F32ArrayAttr:$value_list, OneFlow_DataType:$data_type, DefaultValuedAttr:$dynamic_out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TensorBufferToListOfTensorsOp : OneFlow_BaseOp<"tensor_buffer_to_list_of_tensors", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs Variadic:$out ); let attrs = (ins ShapeAttr:$out_shape, OneFlow_DataType:$out_dtype, DefaultValuedAttr:$dynamic_out ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_arg_modify_fn = 1; } def OneFlow_TensorBufferToListOfTensorsV2Op : OneFlow_BaseOp<"tensor_buffer_to_list_of_tensors_v2", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs Variadic:$out ); let attrs = (ins ShapeArrayAttr:$out_shapes, DTArrayAttr:$out_dtypes, 
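// The summary_* ops above are CpuOnly, NoGrad logging ops: create_summary_writer opens a
// writer for $logdir, flush_summary_writer flushes it, and the summary_write_* ops record
// a scalar, histogram, image or protobuf payload keyed by the $step (and $tag) inputs.
// The tensor-buffer ops that follow convert between ordinary tensors and tensor-buffer
// storage; the v2 form accepts per-output shapes and dtypes as array attributes.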
DefaultValuedAttr:$dynamic_out ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_arg_modify_fn = 1; } def OneFlow_TensorBufferToTensorOp : OneFlow_BaseOp<"tensor_buffer_to_tensor", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins ShapeAttr:$instance_shape, OneFlow_DataType:$dtype ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TensorToTensorBufferOp : OneFlow_BaseOp<"tensor_to_tensor_buffer", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$instance_dims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_TENSOR_BUFFER_OP_DEFINITIONS #ifdef GET_ONEFLOW_TEST_OP_DEFINITIONS def OneFlow_ThrowErrorOp : OneFlow_BaseOp<"throw_error", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_TEST_OP_DEFINITIONS #ifdef GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS def OneFlow_AcosOp : OneFlow_BaseOp<"acos", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AcosGradOp : OneFlow_BaseOp<"acos_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AcoshOp : OneFlow_BaseOp<"acosh", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AcoshGradOp : OneFlow_BaseOp<"acosh_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AsinOp : OneFlow_BaseOp<"asin", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AsinGradOp : OneFlow_BaseOp<"asin_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def 
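// GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS follows a single template: each elementwise
// forward op maps $x to $y, and the paired *_grad op takes ($x, $dy) -- or the saved
// output, as tanh_grad does below -- and produces $dx. Only the inference hooks differ
// from op to op, and here they are all enabled.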
OneFlow_AsinhOp : OneFlow_BaseOp<"asinh", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AsinhGradOp : OneFlow_BaseOp<"asinh_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AtanOp : OneFlow_BaseOp<"atan", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Atan2Op : OneFlow_BaseOp<"atan2", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y ); let output = (outs OneFlow_Tensor:$z ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Atan2XGradOp : OneFlow_BaseOp<"atan2_x_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Atan2YGradOp : OneFlow_BaseOp<"atan2_y_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$y, OneFlow_Tensor:$dz ); let output = (outs OneFlow_Tensor:$dy ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AtanGradOp : OneFlow_BaseOp<"atan_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AtanhOp : OneFlow_BaseOp<"atanh", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AtanhGradOp : OneFlow_BaseOp<"atanh_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CosOp : OneFlow_BaseOp<"cos", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CosGradOp : OneFlow_BaseOp<"cos_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let 
has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CoshOp : OneFlow_BaseOp<"cosh", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CoshGradOp : OneFlow_BaseOp<"cosh_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_HardtanhOp : OneFlow_BaseOp<"hardtanh", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$min_val, DefaultValuedAttr:$max_val ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_HardtanhGradOp : OneFlow_BaseOp<"hardtanh_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$min_val, DefaultValuedAttr:$max_val ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SinOp : OneFlow_BaseOp<"sin", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SinGradOp : OneFlow_BaseOp<"sin_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SinhOp : OneFlow_BaseOp<"sinh", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SinhGradOp : OneFlow_BaseOp<"sinh_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TanOp : OneFlow_BaseOp<"tan", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TanGradOp : OneFlow_BaseOp<"tan_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let 
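// hardtanh clamps its input to [$min_val, $max_val]; note that hardtanh_grad takes the
// saved output $y rather than the input, which is enough to decide where the gradient is
// zero. The remaining ops in this group are plain elementwise trigonometric and
// hyperbolic functions with the usual (x, dy) -> dx gradients.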
has_data_type_infer_fn = 1; } def OneFlow_TanhOp : OneFlow_BaseOp<"tanh", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TanhGradOp : OneFlow_BaseOp<"tanh_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$y, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_NotEqualZeroOp : OneFlow_BaseOp<"not_equal_zero", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS #ifdef GET_ONEFLOW_UNARY_OP_DEFINITIONS def OneFlow_AccOp : OneFlow_BaseOp<"acc", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$max_acc_num ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_blob_time_shape_infer_fn = 1; } def OneFlow_AccCtrlTickOp : OneFlow_BaseOp<"acc_ctrl_tick", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$max_acc_num ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; let has_output_blob_time_shape_infer_fn = 1; } def OneFlow_AffineGridOp : OneFlow_BaseOp<"affine_grid", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$theta ); let output = (outs OneFlow_Tensor:$grid ); let attrs = (ins ShapeAttr:$size, DefaultValuedAttr:$align_corners ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AffineGridGradOp : OneFlow_BaseOp<"affine_grid_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dgrid ); let output = (outs OneFlow_Tensor:$dtheta ); let attrs = (ins ShapeAttr:$size, DefaultValuedAttr:$align_corners ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_BernoulliOp : OneFlow_BaseOp<"bernoulli", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins OneFlow_DataType:$dtype, DefaultValuedAttr:$seed, DefaultValuedAttr:$p ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CastOp : OneFlow_BaseOp<"cast", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins OneFlow_DataType:$dtype, 
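// Start of the GET_ONEFLOW_UNARY_OP_DEFINITIONS block. acc accumulates $max_acc_num
// successive inputs into one output (hence has_output_blob_time_shape_infer_fn), which is
// typically what gradient accumulation builds on; acc_ctrl_tick is its NoGrad
// control-edge counterpart. cast converts dtype and, like copy below, also participates
// in device/stream inference.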
DefaultValuedAttr:$pin_memory ); let has_device_and_stream_infer_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MutableCastOnceOp : OneFlow_BaseOp<"mutable_cast_once", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins OneFlow_DataType:$dtype ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let same_output_regst_num = 1; } def OneFlow_CastToStaticShapeOp : OneFlow_BaseOp<"cast_to_static_shape", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$output ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CastToTickOp : OneFlow_BaseOp<"cast_to_tick", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_CeluOp : OneFlow_BaseOp<"celu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CopyOp : OneFlow_BaseOp<"copy", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$device_type, DefaultValuedAttr:$device_id, DefaultValuedAttr:$pin_memory ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_device_and_stream_infer_fn = 1; } def OneFlow_CountNotFiniteOp : OneFlow_BaseOp<"count_not_finite", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DiagOp : OneFlow_BaseOp<"diag", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$diagonal ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DiagonalOp : OneFlow_BaseOp<"diagonal", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$offset ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EluOp : OneFlow_BaseOp<"elu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let 
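// copy is the cross-device transfer op: $device_type / $device_id name the destination,
// has_device_and_stream_infer_fn lets the compiler place it, and $pin_memory covers
// staging through pinned host memory. count_not_finite reports how many elements of $x
// are NaN/Inf -- presumably for the dynamic loss-scaling machinery, though that use is an
// inference from the name rather than something stated here.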
has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ExpandOp : OneFlow_BaseOp<"expand", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins ShapeAttr:$expand_shape ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ExpandDimsOp : OneFlow_BaseOp<"expand_dims", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$axis ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FlipOp : OneFlow_BaseOp<"flip", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins SI32ArrayAttr:$dims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FoldOp : OneFlow_BaseOp<"fold", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins StrAttr:$data_format, SI32ArrayAttr:$output_size, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$strides, SI32ArrayAttr:$padding, SI32ArrayAttr:$dilation_rate ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_GeluOp : OneFlow_BaseOp<"gelu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FastGeluOp : OneFlow_BaseOp<"fast_gelu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_QuickGeluOp : OneFlow_BaseOp<"quick_gelu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SquareReLUOp : OneFlow_BaseOp<"square_relu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_HardsigmoidOp : OneFlow_BaseOp<"hardsigmoid", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_HardShrinkOp : OneFlow_BaseOp<"hardshrink", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$lambd ); let has_logical_tensor_desc_infer_fn = 1; let 
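// fold combines sliding local blocks back into a spatial tensor (the col2im-style inverse
// of the unfold op defined further below), parameterised by kernel_size / strides /
// padding / dilation_rate plus an explicit output_size. gelu and its fast_gelu /
// quick_gelu approximations share the same single-tensor signature.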
has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_HardswishOp : OneFlow_BaseOp<"hardswish", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LeakyReluOp : OneFlow_BaseOp<"leaky_relu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RReluOp : OneFlow_BaseOp<"rrelu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$output, OneFlow_Tensor:$noise_data ); let attrs = (ins DefaultValuedAttr:$seed, DefaultValuedAttr:$lower, DefaultValuedAttr:$upper, DefaultValuedAttr:$training ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Log2Op : OneFlow_BaseOp<"log2", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_Log10Op : OneFlow_BaseOp<"log10", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LogicalNotOp : OneFlow_BaseOp<"logical_not", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_MishOp : OneFlow_BaseOp<"mish", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_NarrowOp : OneFlow_BaseOp<"narrow", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$dim, DefaultValuedAttr:$start, DefaultValuedAttr:$length ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneHotOp : OneFlow_BaseOp<"one_hot", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$depth, DefaultValuedAttr:$floating_on_value, DefaultValuedAttr:$integer_on_value, DefaultValuedAttr:$floating_off_value, DefaultValuedAttr:$integer_off_value, OneFlow_DataType:$dtype ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def 
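// one_hot expands integer $indices into a one-hot $out of the requested $dtype; the
// separate floating_/integer_ on/off value attributes let the same op emit either a
// floating-point or an integral result. It is NoGrad because indices are not
// differentiable, and has_input_arg_modify_fn gives it a hook to adjust the index
// argument (for example its requires_grad setting -- an assumption, not stated here).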
OneFlow_PackOp : OneFlow_BaseOp<"pack", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$pack_num ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_blob_time_shape_infer_fn = 1; } def OneFlow_RandomMaskLikeOp : OneFlow_BaseOp<"random_mask_like", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$like ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$rate, DefaultValuedAttr:$seed ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let hasCanonicalizer = 1; } def OneFlow_RepeatOp : OneFlow_BaseOp<"repeat", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$repeat_num ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_blob_time_shape_infer_fn = 1; } def OneFlow_Repeat_InterLeaveOp : OneFlow_BaseOp<"repeat_interleave", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$cumsum ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$repeat_num ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RollOp : OneFlow_BaseOp<"roll", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins SI32ArrayAttr:$shifts, SI32ArrayAttr:$dims ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SeluOp : OneFlow_BaseOp<"selu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SiluOp : OneFlow_BaseOp<"silu", [NoMemoryEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SoftShrinkOp: OneFlow_BaseOp<"softshrink", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SoftsignOp : OneFlow_BaseOp<"softsign", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SortOp : OneFlow_BaseOp<"sort", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins 
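// pack, repeat and (below) unpack all declare has_output_blob_time_shape_infer_fn: they
// change how many times a value is produced per iteration rather than just its shape,
// which pipelined execution and gradient accumulation rely on. random_mask_like
// additionally registers a canonicalization pattern (hasCanonicalizer = 1).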
OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins StrAttr:$direction ); let has_check_fn = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SquareSumOp : OneFlow_BaseOp<"square_sum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SqrtSquareSumOp : OneFlow_BaseOp<"sqrt_square_sum", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_SqueezeOp : OneFlow_BaseOp<"squeeze", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins SI32ArrayAttr:$axes ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ThresholdOp : OneFlow_BaseOp<"threshold", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$threshold_val, DefaultValuedAttr:$value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TransposeOp : OneFlow_BaseOp<"transpose", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins SI32ArrayAttr:$perm ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let hasFolder = 1; } def OneFlow_AsStridedOp : OneFlow_BaseOp<"as_strided", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins SI64ArrayAttr:$size, SI64ArrayAttr:$stride, DefaultValuedAttr:$storage_offset ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_IndexAddOp : OneFlow_BaseOp<"index_add", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$index, OneFlow_Tensor:$source ); let output = (outs OneFlow_Tensor:$output ); let attrs = (ins SI64Attr: $dim, F32Attr: $alpha ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_AsStridedGradOp : OneFlow_BaseOp<"as_strided_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$input ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins SI64ArrayAttr:$size, SI64ArrayAttr:$stride, DefaultValuedAttr:$storage_offset ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TrilOp : OneFlow_BaseOp<"tril", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs 
OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$diagonal, DefaultValuedAttr:$floating_fill_value, DefaultValuedAttr:$integer_fill_value, DefaultValuedAttr:$is_floating_fill_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TriuOp : OneFlow_BaseOp<"triu", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$diagonal ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TruncOp : OneFlow_BaseOp<"trunc", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_TruncGradOp : OneFlow_BaseOp<"trunc_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x, OneFlow_Tensor:$dy ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UnfoldOp : OneFlow_BaseOp<"unfold", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$padding, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UnfoldTensorOp : OneFlow_BaseOp<"unfold_tensor", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$dimension, DefaultValuedAttr:$size, DefaultValuedAttr:$step ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UnpackOp : OneFlow_BaseOp<"unpack", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$unpack_num ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_output_blob_time_shape_infer_fn = 1; } def OneFlow_ZeroLikeOp : OneFlow_BaseOp<"zero_like", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$like ); let output = (outs OneFlow_Tensor:$out ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_nd_sbp_infer_fn = 1; } def OneFlow_ToContiguousOp : OneFlow_BaseOp<"to_contiguous", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ConvertMemoryFormatOp : OneFlow_BaseOp<"convert_memory_format", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins 
OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins OneFlow_MemoryFormat:$memory_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_IsNanOp : OneFlow_BaseOp<"isnan", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_IsInfOp : OneFlow_BaseOp<"isinf", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_IsFiniteOp : OneFlow_BaseOp<"isfinite", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RealOp : OneFlow_BaseOp<"real", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RealGradOp : OneFlow_BaseOp<"real_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dout ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ImagOp : OneFlow_BaseOp<"imag", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ImagGradOp : OneFlow_BaseOp<"imag_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dout ); let output = (outs OneFlow_Tensor:$dx ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_ConjPhysicalOp : OneFlow_BaseOp<"conj_physical", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_UNARY_OP_DEFINITIONS #ifdef GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS def OneFlow_UpsampleBicubic2DOp : OneFlow_BaseOp<"upsample_bicubic_2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleBicubic2DGradOp : OneFlow_BaseOp<"upsample_bicubic_2d_grad", [NoMemoryEffect, 
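// GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS: every interpolation mode (nearest / linear /
// bilinear / bicubic / trilinear, in 1-D, 2-D or 3-D) pairs a forward op (x -> y) with a
// grad op that takes ($dy, $x), the original input being needed to recover the source
// spatial size. Scale attributes, an explicit output_size array and data_format are
// common to all of them; align_corners appears on the interpolating modes but not on the
// nearest-neighbor ones.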
DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleBilinear2DOp : OneFlow_BaseOp<"upsample_bilinear_2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleBilinear2DGradOp : OneFlow_BaseOp<"upsample_bilinear_2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleLinear1DOp : OneFlow_BaseOp<"upsample_linear_1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$scale_factor, DefaultValuedAttr:$align_corners, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleLinear1DGradOp : OneFlow_BaseOp<"upsample_linear_1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$scale_factor, DefaultValuedAttr:$align_corners, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleNearest1DOp : OneFlow_BaseOp<"upsample_nearest_1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$scale_factor, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleNearest1DGradOp : OneFlow_BaseOp<"upsample_nearest_1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$scale_factor, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleNearest2DOp : OneFlow_BaseOp<"upsample_nearest_2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = 
(outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleNearest2DGradOp : OneFlow_BaseOp<"upsample_nearest_2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleNearest3DOp : OneFlow_BaseOp<"upsample_nearest_3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$depth_scale, DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleNearest3DGradOp : OneFlow_BaseOp<"upsample_nearest_3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$depth_scale, DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleTrilinear3DOp : OneFlow_BaseOp<"upsample_trilinear_3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let attrs = (ins DefaultValuedAttr:$depth_scale, DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UpsampleTrilinear3DGradOp : OneFlow_BaseOp<"upsample_trilinear_3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$dx ); let attrs = (ins DefaultValuedAttr:$depth_scale, DefaultValuedAttr:$height_scale, DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS #ifdef GET_ONEFLOW_ONE_EMBEDDING_OP_DEFINITIONS def OneFlow_OneEmbeddingFusedLookupOp : OneFlow_BaseOp<"one_embedding_fused_lookup", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$shadow, OneFlow_Tensor:$ids, Optional:$table_ids ); let output = (outs OneFlow_Tensor:$embeddings ); let attrs = (ins OneFlow_DataType:$dtype, StrAttr:$embedding_name, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, DefaultValuedAttr:$is_full_cache, DefaultValuedAttr:$num_tables, 
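// GET_ONEFLOW_ONE_EMBEDDING_OP_DEFINITIONS: the fused lookup takes $shadow and $ids (plus
// optional $table_ids) and returns the gathered $embeddings in one shot. embedding_size
// is the width of the returned vectors, while line_size is the width of a stored row --
// larger than embedding_size when optimizer state is kept inline with the embedding, an
// assumption based on how the update ops below reuse the same two attributes.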
DefaultValuedAttr:$padding_idx, DefaultValuedAttr:$has_padding_idx, StrAttr:$embedding_tables, DefaultValuedAttr:$seed ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } def OneFlow_OneEmbeddingFusedLookupGradOp : OneFlow_BaseOp<"one_embedding_fused_lookup_grad", [DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$ids, OneFlow_Tensor:$embedding_grad ); let attrs = (ins StrAttr:$embedding_name, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_UniqueKeyValuePairOp : OneFlow_BaseOp<"unique_key_value_pair", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$keys, Optional:$values ); let output = (outs OneFlow_Tensor:$num_unique, OneFlow_Tensor:$unique_keys, OneFlow_Tensor:$unique_values, OneFlow_Tensor:$inverse_indices ); let attrs = (ins DefaultValuedAttr:$num_tables, DefaultValuedAttr:$padding_idx, DefaultValuedAttr:$has_padding_idx, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_IdShuffleOp : OneFlow_BaseOp<"id_shuffle", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$ids, Optional:$table_ids ); let output = (outs OneFlow_Tensor:$num_unique_matrix, OneFlow_Tensor:$inverse_unique_partition_indices, OneFlow_Tensor:$cur_rank_num_unique, OneFlow_Tensor:$cur_rank_unique_ids, OneFlow_Tensor:$cur_rank_unique_table_ids, OneFlow_Tensor:$cur_rank_inverse_indices ); let attrs = (ins DefaultValuedAttr:$num_tables, DefaultValuedAttr:$padding_idx, DefaultValuedAttr:$has_padding_idx, StrAttr:$embedding_name ); let same_output_regst_num = 2; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_IdShuffleCopyOutOp : OneFlow_BaseOp<"id_shuffle_copy_out", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_matrix, OneFlow_Tensor:$inverse_unique_partition_indices, OneFlow_Tensor:$cur_rank_num_unique, OneFlow_Tensor:$cur_rank_unique_ids, OneFlow_Tensor:$cur_rank_unique_table_ids, OneFlow_Tensor:$cur_rank_inverse_indices ); let output = (outs OneFlow_Tensor:$out_num_unique_matrix, OneFlow_Tensor:$out_inverse_unique_partition_indices, OneFlow_Tensor:$out_cur_rank_num_unique, OneFlow_Tensor:$out_cur_rank_unique_ids, OneFlow_Tensor:$out_cur_rank_unique_table_ids, OneFlow_Tensor:$out_cur_rank_inverse_indices ); let attrs = (ins StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneEmbeddingGatherOp : OneFlow_BaseOp<"one_embedding_gather", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in, OneFlow_Tensor:$indices ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EmbeddingShuffleOp : 
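// id_shuffle deduplicates and partitions the incoming ids, emitting num_unique_matrix,
// the per-rank unique ids / table ids and the inverse index tensors; embedding_shuffle
// and embedding_gradient_shuffle (next) consume exactly those tensors to exchange
// embedding rows and their gradients between ranks, so the three ops form one pipeline.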
OneFlow_BaseOp<"embedding_shuffle", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$cur_rank_embeddings, OneFlow_Tensor:$num_unique_matrix, OneFlow_Tensor:$cur_rank_inverse_indices, OneFlow_Tensor:$inverse_unique_partition_indices ); let output = (outs OneFlow_Tensor:$embeddings ); let attrs = (ins DefaultValuedAttr:$embedding_size, DefaultValuedAttr:$skip_last_gather, DefaultValuedAttr:$is_train, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EmbeddingGradientShuffleOp : OneFlow_BaseOp<"embedding_gradient_shuffle", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$embedding_grad, OneFlow_Tensor:$num_unique_matrix, OneFlow_Tensor:$cur_rank_inverse_indices, OneFlow_Tensor:$inverse_unique_partition_indices ); let output = (outs OneFlow_Tensor:$cur_rank_unique_embedding_grad ); let attrs = (ins DefaultValuedAttr:$embedding_size, DefaultValuedAttr:$only_zero_valid_grad, DefaultValuedAttr:$skip_first_scatter, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EmbeddingPrefetchOp : OneFlow_BaseOp<"embedding_prefetch", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_ids, OneFlow_Tensor:$table_ids ); let output = (outs OneFlow_Tensor:$context //no practical sense, control lookup run after prefetch. ); let attrs = (ins DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name, StrAttr:$embedding_tables, StrAttr:$state_initializer, DefaultValuedAttr:$seed ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EmbeddingLookupOp : OneFlow_BaseOp<"embedding_lookup", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_ids, OneFlow_Tensor:$table_ids, Optional:$context ); let output = (outs OneFlow_Tensor:$unique_values, Optional:$embeddings ); let attrs = (ins OneFlow_DataType:$dtype, OneFlow_DataType:$embeddings_dtype, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name, StrAttr:$embedding_tables, StrAttr:$state_initializer, DefaultValuedAttr:$seed ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneEmbeddingFusedSgdUpdatePutOp : OneFlow_BaseOp<"one_embedding_fused_sgd_update_put", [DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_ids, OneFlow_Tensor:$unique_embeddings, OneFlow_Tensor:$embedding_grad, OneFlow_Tensor:$learning_rate ); let attrs = (ins DefaultValuedAttr:$scale, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneEmbeddingSgdUpdateOp : OneFlow_BaseOp<"one_embedding_sgd_update", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins 
OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_embeddings, OneFlow_Tensor:$embedding_grad, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$down_scale_by_tensor, Optional:$skip_if ); let output = (outs OneFlow_Tensor:$updated_unique_embeddings ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneEmbeddingMomentumUpdateOp : OneFlow_BaseOp<"one_embedding_momentum_update", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_embeddings, OneFlow_Tensor:$embedding_grad, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$down_scale_by_tensor, Optional:$skip_if ); let output = (outs OneFlow_Tensor:$updated_unique_embeddings ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$beta, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneEmbeddingAdamUpdateOp : OneFlow_BaseOp<"one_embedding_adam_update", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_embeddings, OneFlow_Tensor:$embedding_grad, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$down_scale_by_tensor, Optional:$skip_if, Optional:$bias_correction1, Optional:$bias_correction2 ); let output = (outs OneFlow_Tensor:$updated_unique_embeddings ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$bias_correction1_val, DefaultValuedAttr:$bias_correction2_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$beta1, DefaultValuedAttr:$beta2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$do_bias_correction, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneEmbeddingSmartDecaySparseAdamUpdateOp : OneFlow_BaseOp<"one_embedding_smart_decay_sparse_adam_update", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_embeddings, OneFlow_Tensor:$embedding_grad, Optional:$train_step, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$down_scale_by_tensor, Optional:$skip_if ); let output = (outs OneFlow_Tensor:$updated_unique_embeddings ); let attrs = (ins DefaultValuedAttr:$train_step_val, DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$beta1, DefaultValuedAttr:$beta2, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$do_bias_correction, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, 
StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneEmbeddingAdagradUpdateOp : OneFlow_BaseOp<"one_embedding_adagrad_update", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_embeddings, OneFlow_Tensor:$embedding_grad, Optional:$train_step, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$down_scale_by_tensor, Optional:$skip_if ); let output = (outs OneFlow_Tensor:$updated_unique_embeddings ); let attrs = (ins DefaultValuedAttr:$train_step_val, DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$lr_decay, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_EmbeddingPutOp : OneFlow_BaseOp<"embedding_put", [DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_ids, OneFlow_Tensor:$unique_embeddings ); let attrs = (ins DefaultValuedAttr:$line_size, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_OneEmbeddingFtrlUpdateOp : OneFlow_BaseOp<"one_embedding_ftrl_update", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, OneFlow_Tensor:$unique_embeddings, OneFlow_Tensor:$embedding_grad, Optional:$learning_rate, Optional:$scale_by_tensor, Optional:$down_scale_by_tensor, Optional:$skip_if ); let output = (outs OneFlow_Tensor:$updated_unique_embeddings ); let attrs = (ins DefaultValuedAttr:$learning_rate_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$lr_power, DefaultValuedAttr:$lambda1, DefaultValuedAttr:$lambda2, DefaultValuedAttr:$beta, DefaultValuedAttr:$line_size, DefaultValuedAttr:$embedding_size, StrAttr:$embedding_name ); let same_output_regst_num = 1; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_RocAucScoreOp : OneFlow_BaseOp<"roc_auc_score", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$label, OneFlow_Tensor:$pred ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FillOp : OneFlow_BaseOp<"fill_", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let attrs = (ins DefaultValuedAttr:$floating_value, DefaultValuedAttr:$integral_value, DefaultValuedAttr:$is_floating_value ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_FillTensorOp : OneFlow_BaseOp<"fill_tensor_", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods]> { let input = 
(ins OneFlow_Tensor:$in, OneFlow_Tensor:$value ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_ONE_EMBEDDING_OP_DEFINITIONS #ifdef GET_ONEFLOW_LINEAR_ALGEBRA_OP_DEFINITIONS def OneFlow_InvOp : OneFlow_BaseOp<"inv", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LinalgCrossOp : OneFlow_BaseOp<"linalg_cross", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$other ); let attrs = (ins SI64Attr:$dim ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_DetOp : OneFlow_BaseOp<"det", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$y ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_LUDecompositionOp : OneFlow_BaseOp<"lu_decomposition", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x ); let output = (outs OneFlow_Tensor:$LU, OneFlow_Tensor:$pivot ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_LINEAR_ALGEBRA_OP_DEFINITIONS #ifdef GET_ONEFLOW_SYSTEM_OP_DEFINITIONS def OneFlow_CopyH2DOp : OneFlow_BaseOp<"copy_h2d", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } def OneFlow_CopyD2HOp : OneFlow_BaseOp<"copy_d2h", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); let output = (outs OneFlow_Tensor:$out ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #endif // GET_ONEFLOW_SYSTEM_OP_DEFINITIONS include "mlir/Interfaces/CallInterfaces.td" class OneFlow_JITLikeOp : OneFlow_BaseOp]> { let input = (ins Variadic:$in); let output = (outs Variadic:$out); let attrs = (ins FlatSymbolRefAttr:$callee, BytesAttr:$mlir_assembly ); let builders = [ OpBuilder<(ins "func::FuncOp":$callee, "NamedAttrList":$attributes, CArg<"ValueRange", "{}">:$in), [{ $_state.addOperands(in); $_state.addAttributes(attributes); $_state.addAttribute("callee", SymbolRefAttr::get(callee)); $_state.addTypes(callee.getFunctionType().getResults()); }]> ]; let extraClassDeclaration = [{ operand_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } operand_iterator arg_operand_begin() { return operand_begin(); } operand_iterator arg_operand_end() { return operand_end(); } CallInterfaceCallable getCallableForCallee() { return (*this)->getAttrOfType("callee"); } void setCalleeFromCallable(CallInterfaceCallable callee) { (*this)->setAttr("callee", callee.get()); } }]; let assemblyFormat = [{ $callee `(` $in `)` attr-dict `:` 
functional-type($in, results) }]; let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; } #ifdef GET_ONEFLOW_MLIR_JIT_OP_DEFINITIONS def OneFlow_MlirJitOp : OneFlow_JITLikeOp<"mlir_jit"> {} def OneFlow_KernelLaunchOp : OneFlow_JITLikeOp<"kernel_launch"> {} #endif // GET_ONEFLOW_MLIR_JIT_OP_DEFINITIONS ================================================ FILE: oneflow/ir/include/OneFlow/OneFlowUtils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWUTILS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWUTILS_H_ #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringExtras.h" namespace mlir { namespace oneflow { void CheckEnableIRPrinting(mlir::PassManager& pm); // sanitize identifier to make the special name allowed as a legal token StringRef SanitizeIdentifier(StringRef name, SmallString<16>& buffer, StringRef allowedPunctChars = "$._", bool allowTrailingDigit = true); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWUTILS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Passes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_PASSES_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_PASSES_H_ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Pass/Pass.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "OneFlow/Conversion/OneFlowToTosa.h" #include "OneFlow/Transform/OneFlowMemPool.h" #include "OneFlow/Transform/BufferHostRegister.h" #include "OneFlow/Transform/ConvertInferenceOp.h" #include "OneFlow/Transform/OutlineAndFuse.h" #include "OneFlow/Transform/AutoNhwc.h" #include "OneFlow/Transform/AggregateOps.h" #include "OneFlow/Transform/FuncOps.h" #include "OneFlow/Transform/CSEWithAttributesIgnored.h" #include "OneFlow/Transform/OneFlowStream.h" #include "OneFlow/Transform/EliminateAllocOps.h" #include "OneFlow/Transform/TraitFolder.h" #ifdef WITH_MLIR_CUDA_CODEGEN #include "OneFlow/Conversion/NVVMToCubin.h" #endif // WITH_MLIR_CUDA_CODEGEN namespace mlir { namespace oneflow { #define GEN_PASS_CLASSES #define GEN_PASS_REGISTRATION #include "OneFlow/OneFlowPasses.h.inc" LogicalResult LowerModuleToLLVM(mlir::MLIRContext* context, ModuleOp module); #ifdef WITH_MLIR_CUDA_CODEGEN LogicalResult LowerModuleToCUDALLVM(mlir::MLIRContext* context, ModuleOp module); #endif // WITH_MLIR_CUDA_CODEGEN void populateWrapOpsToKernelLaunchPatterns(::mlir::RewritePatternSet& patterns, const std::string& mode); void populateFuserForExistingOp(::mlir::RewritePatternSet& patterns); void populateGpuHelperPatterns(::mlir::RewritePatternSet& patterns); void populateAutoNhwcPatterns(::mlir::RewritePatternSet& patterns); void populatePreConvertInferenceOp(::mlir::RewritePatternSet& patterns); void populateConvertInferenceOp(::mlir::RewritePatternSet& patterns); void populatePostConvertInferenceOp(::mlir::RewritePatternSet& patterns); namespace okl_func { const auto OKL_FUNC = "_mlir_okl_subgraph"; } // namespace okl_func } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_PASSES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/SBP/SBPAttributes.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_SBP_SBPATTRIBUTES_H_ #define ONEFLOW_IR_INCLUDE_SBP_SBPATTRIBUTES_H_ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #define GET_ATTRDEF_CLASSES #include "OneFlow/SBPAttributes.h.inc" #endif // ONEFLOW_IR_INCLUDE_SBP_SBPATTRIBUTES_H_ ================================================ FILE: oneflow/ir/include/OneFlow/SBP/SBPBase.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_SBP_SBPBASE #define ONEFLOW_IR_INCLUDE_SBP_SBPBASE include "OneFlow/SBP/SBPDialect.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" class SBP_Attr traits = []> : AttrDef { let mnemonic = attrMnemonic; } def SBP_SplitAttr : SBP_Attr<"Split", "S"> { let summary = "Signature S"; let description = [{ signature split, representing a tensor sharded along `axis` }]; let parameters = (ins "int":$axis); let assemblyFormat = "`<` $axis `>`"; } def SBP_BroadcastAttr : SBP_Attr<"Broadcast", "B"> { let summary = "Signature B"; let description = [{ signature broadcast, representing a tensor to be duplicated }]; } def SBP_PartialSumAttr : SBP_Attr<"PartialSum", "P"> { let summary = "Signature P"; let description = [{ signature partial sum, representing a sharded tensor that will be reduced lazily }]; } def SBP_AnyAttr : SBP_Attr<"Any", "Any"> { let summary = "Signature Any"; let description = [{ signature any, representing a tensor with any one of the SBP signatures }]; } def SBP_ParallelSignatureAttr : SBP_Attr<"ParallelSignature", "parallel"> { let summary = "Parallel signature of OneFlow Op, a.k.a. SBP"; let description = [{ To represent a signature, pass two lists separated by an arrow, corresponding to the data input and data output tensors. For example: ``` #sbp.parallel<[#sbp.S<0>] -> [#sbp.S<0>]> ``` A one-level nested list is used to represent a 2D parallelism signature. For example: ``` #sbp.parallel<[[#sbp.S<0>, #sbp.P]] -> [#sbp.S<0>]> ``` }]; let parameters = (ins "ArrayAttr":$inputs, "ArrayAttr":$outputs); let assemblyFormat = "`<` custom($inputs) ` ` `->` ` ` custom($outputs) `>`"; } #endif // ONEFLOW_IR_INCLUDE_SBP_SBPBASE ================================================ FILE: oneflow/ir/include/OneFlow/SBP/SBPDialect.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT_H_ #define ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT_H_ #include "mlir/IR/Dialect.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "OneFlow/SBPDialect.h.inc" #endif // ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT_H_ ================================================ FILE: oneflow/ir/include/OneFlow/SBP/SBPDialect.td ================================================ /* Copyright 2020 The OneFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT #define ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT include "mlir/IR/OpBase.td" def SBP_Dialect : Dialect { let name = "sbp"; let summary = "S(split)B(broadcast)P(partial sum) MLIR dialect."; let description = [{ This dialect is the IR of S(split)B(broadcast)P(partial sum). }]; let cppNamespace = "::mlir::sbp"; let dependentDialects = [ "func::FuncDialect" ]; let extraClassDeclaration = [{ void registerAttributes(); }]; let useDefaultAttributePrinterParser = 1; } #endif // ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT ================================================ FILE: oneflow/ir/include/OneFlow/SBP/SBPImporter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_SBP_SBPIMPORTER_H_ #define ONEFLOW_IR_INCLUDE_SBP_SBPIMPORTER_H_ #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "OneFlow/OneFlowOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include #include namespace mlir { namespace oneflow { class SBPTranslation { public: static mlir::LogicalResult PrintSbpAttrToString(mlir::Attribute sbp_attr, std::string& sbp); static mlir::Attribute ConvertSBPToString(mlir::Builder& builder, mlir::sbp::ParallelSignatureAttr& parallel); static mlir::Attribute ConvertNdSbpToPsig(mlir::Builder& builder, const std::vector& nd_sbp, const int nd_size); }; } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_SBP_SBPIMPORTER_H_ ================================================ FILE: oneflow/ir/include/OneFlow/SBP/SBPOps.td ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_SBP_SBPOPS #define ONEFLOW_IR_INCLUDE_SBP_SBPOPS include "OneFlow/SBP/SBPDialect.td" include "OneFlow/SBP/SBPBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Pass/PassBase.td" include "mlir/IR/OpBase.td" #endif // ONEFLOW_IR_INCLUDE_SBP_SBPOPS ================================================ FILE: oneflow/ir/include/OneFlow/Transform/AggregateOps.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AGGREGATE_COMPUTE_OPS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AGGREGATE_COMPUTE_OPS_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createAggregateComputeOpsPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AGGREGATE_COMPUTE_OPS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/AutoNhwc.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AUTONHWC_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AUTONHWC_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createAutoNhwcPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AUTONHWC_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/BufferHostRegister.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_BUFFERHOSTREGISTER_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_BUFFERHOSTREGISTER_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createBufferHostRegisterPass(); std::unique_ptr createGpuCopyArgPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_BUFFERHOSTREGISTER_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/CSEWithAttributesIgnored.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CSEWITHATTRIBUTESIGNORED_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CSEWITHATTRIBUTESIGNORED_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { struct CSEState { llvm::DenseMap scopeSymbolIDs; llvm::DenseMap opNames; }; std::unique_ptr createCSEWithAttributesIgnored(); std::unique_ptr createCSEPutAttributes(); std::pair, std::unique_ptr> createCSEPasses( std::shared_ptr state); void registerCSEPasses(std::shared_ptr state); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CSEWITHATTRIBUTESIGNORED_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/ConvertInferenceOp.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CONVERTINFERENCE_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CONVERTINFERENCE_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createPreConvertInferenceOpPass(); std::unique_ptr createConvertInferenceOpPass(); std::unique_ptr createPostConvertInferenceOpPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CONVERTINFERENCE_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/EliminateAllocOps.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ELIMINATE_ALLOC_OPS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ELIMINATE_ALLOC_OPS_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createEliminateAllocOpsPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ELIMINATE_ALLOC_OPS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/FuncOps.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_FUNCOPS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_FUNCOPS_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createOneFlowJobToFuncPass(); std::unique_ptr createFuncToOneFlowJobPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_FUNCOPS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/OneFlow MLIR CodeGen ABI.md ================================================ The final parameter list of the LLVM function generated by MLIR is: - buffer-pool related information - input 1 related information ... input n related information - output 1 related information ... output n related information - stream related information. Based on the ABI above, the related passes are designed: - append-ofstream - insert-ofmempool (a hypothetical signature sketch illustrating this parameter ordering is inserted after the OneFlowMemPool.h license header below) ================================================ FILE: oneflow/ir/include/OneFlow/Transform/OneFlowMemPool.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
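A minimal, hypothetical sketch (not taken from the repository) of the parameter ordering described in OneFlow MLIR CodeGen ABI.md above. The function name and the single-opaque-pointer-per-entry simplification are illustrative assumptions; the real arguments for each entry are produced by the insert-ofmempool and append-ofstream passes and may expand into several values.

// Hypothetical view of the generated entry point's argument order.
// All names are placeholders; each "info" entry may lower to multiple
// arguments (e.g. memref descriptor fields) in the actual code generation.
extern "C" void jit_subgraph(void* mempool_info,                           // buffer-pool info (insert-ofmempool)
                             void* in_0_info, /* ... */ void* in_n_info,   // per-input info
                             void* out_0_info, /* ... */ void* out_n_info, // per-output info
                             void* stream_info);                           // stream info (append-ofstream)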
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_MEMPOOL_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_MEMPOOL_H_ #include "mlir/Pass/Pass.h" #include "mlir/Dialect/Func/IR/FuncOps.h" namespace mlir { namespace oneflow { namespace codegen { namespace mempool { inline const std::string MEMPOOL_ATTR_NAME = "oneflow.mempool"; } // namespace mempool } // namespace codegen void applyFoldAlloc(func::FuncOp op); std::unique_ptr createFoldAllocToSubviewPass(); std::unique_ptr createInsertOneFlowMemPoolPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_MEMPOOL_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/OneFlowStream.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_STREAM_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_STREAM_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createAppendOneFlowStreamPass(); std::unique_ptr createMgpuToOneFlowStreamPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_STREAM_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/OutlineAndFuse.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_OUTLINEJITFUNCTION_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_OUTLINEJITFUNCTION_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { namespace wrap_mode { inline const std::string SIMPLE = "simple"; inline const std::string CUDA_GRAPH = "cuda_graph"; } // namespace wrap_mode namespace jit { inline const std::string RAW_GRAPH = "oneflow.raw_graph"; } std::unique_ptr createWrapOpsToKernelLaunchPass(); std::unique_ptr createOutlineJitFunctionPass(); std::unique_ptr createFuseIntoExistingOpPass(); std::unique_ptr createGroupMatMul(); std::unique_ptr createFuseForwardOps(); std::unique_ptr createFuseOpsWithBackwardImpl(); std::unique_ptr createFuseNormalizationOps(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_OUTLINEJITFUNCTION_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/TraitFolder.h ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRAIT_FOLDER_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRAIT_FOLDER_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { std::unique_ptr createTestOneFlowTraitFolderPass(); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRAIT_FOLDER_H_ ================================================ FILE: oneflow/ir/include/OneFlow/Transform/TransposeHelpers.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRANSPOSEHELPERS_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRANSPOSEHELPERS_H_ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" #include "OneFlow/OneFlowOps.h" namespace mlir { namespace oneflow { RankedTensorType getNHWCType(RankedTensorType t); RankedTensorType getNHWCType(Type t); RankedTensorType getNHWCType(Value v); RankedTensorType getNCHWType(RankedTensorType t); RankedTensorType getNCHWType(Type t); RankedTensorType getNCHWType(Value v); llvm::SmallVector getNHWCResultTypes(NCHWCompatible op); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRANSPOSEHELPERS_H_ ================================================ FILE: oneflow/ir/include/OneFlow/UserOpConversion.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPCONVERSION_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPCONVERSION_H_ #include "OneFlow/OneFlowOps.h" namespace mlir { namespace oneflow { namespace user_op { ::oneflow::ShapeProto getAttrAsShape(mlir::Attribute& attr); ::oneflow::Int64ListProto getAttrAsStride(mlir::Attribute& attr); ::oneflow::AttrType queryAttrType(const std::string& op_type_name, const std::string& attr_name); LogicalResult saveAttrDictionaryToOpConf(DictionaryAttr attributes, ::oneflow::OperatorConf* op_conf); LogicalResult ConvertUserOpAttributes(llvm::StringRef op_type_name, ValueRange operands, DictionaryAttr attributes, ::oneflow::OperatorConf& op_conf); LogicalResult ConvertUserOpAttributes(Operation* op, ::oneflow::OperatorConf& op_conf); LogicalResult ConvertUserOpAttributes( Operation* op, ::oneflow::OperatorConf& op_conf, bool is_mapping_size /* the input and output size should be mapped after building kernel and provide information for the next query*/ = false); LogicalResult ConvertUserOpInputs(llvm::StringRef op_type_name, ValueRange operands, DictionaryAttr attributes, ::oneflow::UserOpConf* user_conf); ::oneflow::ParallelConf getParallelConfFromAttrDictionary(DictionaryAttr attributes); ::oneflow::ParallelConf getParallelConfFromAttrs(Attribute device_name_attr, Attribute device_tag_attr); ::oneflow::DeviceType getDeviceTypeFromAttrDictionary(DictionaryAttr attributes); } // namespace user_op } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPCONVERSION_H_ ================================================ FILE: oneflow/ir/include/OneFlow/UserOpReflection.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPRELFECTION_H_ #define ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPRELFECTION_H_ #include "OneFlow/OneFlowOps.h" namespace mlir { namespace oneflow { namespace user_op { template class Trait> LogicalResult GetFilteredSegmentKeyAndSizes(Operation* op, std::vector& keys, std::vector& sizes); template class Trait> LogicalResult GetFilteredSegmentKeyAndSizes(llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes, std::vector& keys, std::vector& sizes); struct Source { enum { INPUT, OUTPUT, BUFFER, INVALID, } type; int offset; }; Source GetOpSourceByName(Operation* op, const std::string& to_find); using ArgID = std::pair; template class Trait> class ArgIds { public: explicit ArgIds(Operation* op); ArgIds(llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes); std::vector::const_iterator begin() const { return ids_.begin(); } std::vector::const_iterator end() const { return ids_.end(); } private: std::vector ids_; }; llvm::Optional GetOutputLbn(OpResult result); } // namespace user_op } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPRELFECTION_H_ ================================================ FILE: oneflow/ir/include/Transform/CMakeLists.txt ================================================ set(LLVM_TARGET_DEFINITIONS TransformDialectExtension.td) mlir_tablegen(TransformDialectExtension.h.inc -gen-op-decls) mlir_tablegen(TransformDialectExtension.cpp.inc -gen-op-defs) mlir_tablegen(TransformDialectExtensionTypes.h.inc -gen-typedef-decls -typedefs-dialect=transform) mlir_tablegen(TransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=transform) add_public_tablegen_target(MLIROneFlowTransformDialectExtensionIncGen) ================================================ FILE: oneflow/ir/include/Transform/TransformDialectExtension.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_ #define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" namespace mlir { class DialectRegistry; namespace oneflow { namespace transform_dialect { /// Registers the test extension to the Transform dialect. 
void registerTransformDialectExtension(::mlir::DialectRegistry& registry); void registerTransformDialectEraseSchedulePass(); void registerTransformDialectInterpreterPass(); struct ApplyPatternsOpPatterns { bool canonicalization = false; bool cse = false; }; } // namespace transform_dialect } // namespace oneflow } // namespace mlir #define GET_TYPEDEF_CLASSES #include "Transform/TransformDialectExtensionTypes.h.inc" #define GET_OP_CLASSES #include "Transform/TransformDialectExtension.h.inc" #endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_ ================================================ FILE: oneflow/ir/include/Transform/TransformDialectExtension.td ================================================ #ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_ #define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" include "mlir/Dialect/Transform/IR/MatchInterfaces.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" class ProduceNoneProto traits = []> : Op { let arguments = (ins TransformHandleTypeInterface:$target); let results = (outs); let assemblyFormat = "$target attr-dict `:` functional-type($target, results)"; let cppNamespace = "mlir::oneflow::transform_dialect"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); void getEffects( SmallVectorImpl &effects) { ::mlir::transform::onlyReadsHandle(getTarget(), effects); ::mlir::transform::modifiesPayload(effects); } }]; } def CSEOp : ProduceNoneProto<"oneflow.cse"> { let description = [{ cse in transform dialect. }]; } def CanonicalizationOp : ProduceNoneProto<"oneflow.canonicalization"> { let description = [{ canonicalization in transform dialect. }]; } def ExplicitLinalgOutcomeOp : ProduceNoneProto<"oneflow.explicit_linalg_outcome"> { let description = [{ fold unit-extent dimensions in operands/results of linalg ops on tensors via rank-reducing slice in transform dialect. }]; } def EliminateCopyOp : ProduceNoneProto<"oneflow.eliminate_copy"> { let description = [{ eliminate memref.copy if its target equals to source or comes from block arguments. }]; } def FoldAllocOp : ProduceNoneProto<"oneflow.fold_alloc"> { let description = [{ fold memref.alloc to a single one and subview on it. }]; } def ResultsToOutParamsOp : ProduceNoneProto<"oneflow.results_to_out_params"> { let description = [{ move results to out params. }]; } #endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_ ================================================ FILE: oneflow/ir/include/Transform/TransformStateExtension.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
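As a usage illustration only (hypothetical, not repository code), the snippet below shows how the registration hooks declared in TransformDialectExtension.h above could be wired into an mlir-opt-style driver; the wrapper function name is an assumption.

#include "Transform/TransformDialectExtension.h"
#include "mlir/IR/DialectRegistry.h"

// Hypothetical helper: attach the OneFlow extension ops declared above
// (oneflow.cse, oneflow.canonicalization, oneflow.fold_alloc, ...) to the
// Transform dialect and expose the interpreter / erase-schedule passes.
inline void registerOneFlowTransformDialectExtension(mlir::DialectRegistry& registry) {
  mlir::oneflow::transform_dialect::registerTransformDialectExtension(registry);
  mlir::oneflow::transform_dialect::registerTransformDialectInterpreterPass();
  mlir::oneflow::transform_dialect::registerTransformDialectEraseSchedulePass();
}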
*/ #ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_ #define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" namespace mlir { namespace oneflow { namespace transform_dialect { class TransformStateExtension : public ::mlir::transform::TransformState::Extension { public: TransformStateExtension(::mlir::transform::TransformState& state, StringAttr message) : Extension(state), message(message) {} StringRef getMessage() const { return message.getValue(); } LogicalResult updateMapping(Operation* previous, Operation* updated); private: StringAttr message; }; } // namespace transform_dialect } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_ ================================================ FILE: oneflow/ir/install-llvm.cmake ================================================ message("-- LLVM_MONO_REPO_URL: " ${LLVM_MONO_REPO_URL}) message("-- LLVM_MONO_REPO_MD5: " ${LLVM_MONO_REPO_MD5}) FetchContent_Declare(llvm_monorepo) FetchContent_GetProperties(llvm_monorepo) if(NOT llvm_monorepo_POPULATED) FetchContent_Populate(llvm_monorepo URL ${LLVM_MONO_REPO_URL} URL_HASH MD5=${LLVM_MONO_REPO_MD5}) set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm) execute_process( COMMAND "${CMAKE_COMMAND}" ${llvm_monorepo_SOURCE_DIR}/llvm -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} # this is required in newer version of LLVM -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER} -DCMAKE_EXE_LINKER_FLAGS_INIT=${CMAKE_EXE_LINKER_FLAGS_INIT} -DCMAKE_MODULE_LINKER_FLAGS_INIT=${CMAKE_MODULE_LINKER_FLAGS_INIT} -DCMAKE_SHARED_LINKER_FLAGS_INIT=${CMAKE_SHARED_LINKER_FLAGS_INIT} -DCMAKE_INSTALL_PREFIX=${LLVM_INSTALL_DIR} -DCMAKE_INSTALL_MESSAGE=${CMAKE_INSTALL_MESSAGE} -DLLVM_ENABLE_RTTI=ON # turn this on to make it compatible with protobuf -DLLVM_ENABLE_EH=ON # turn this on to make it compatible with half (the library) -DLLVM_BUILD_EXAMPLES=OFF -DLLVM_BUILD_TOOLS=OFF -DLLVM_INCLUDE_EXAMPLES=OFF -DLLVM_INCLUDE_TESTS=OFF -DLLVM_INCLUDE_BENCHMARKS=OFF -DLLVM_TARGETS_TO_BUILD=host\;NVPTX -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_APPEND_VC_REV=OFF -DLLVM_ENABLE_ZLIB=OFF -DLLVM_INSTALL_UTILS=ON -DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS} -DLLVM_ENABLE_OCAMLDOC=OFF -DLLVM_ENABLE_BINDINGS=OFF -DLLVM_ENABLE_TERMINFO=OFF # Disable terminfo in llvm so that oneflow doesn't need to link against it -DMLIR_ENABLE_CUDA_RUNNER=${WITH_MLIR_CUDA_CODEGEN} -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} -DINJA_URL=${INJA_URL} -DINJA_URL_HASH=${INJA_URL_HASH} -DJSON_URL=${JSON_URL} -DJSON_URL_HASH=${JSON_URL_HASH} -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} -DLLVM_EXTERNAL_PROJECTS=OneFlowTableGen -DLLVM_EXTERNAL_ONEFLOWTABLEGEN_SOURCE_DIR=${CMAKE_SOURCE_DIR}/tools/oneflow-tblgen -G ${CMAKE_GENERATOR} WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} RESULT_VARIABLE ret) if(ret EQUAL "1") message(FATAL_ERROR "Bad exit status") endif() include(ProcessorCount) ProcessorCount(PROC_NUM) if(WITH_MLIR) set(INSTALL_ALL "install") endif() execute_process( COMMAND "${CMAKE_COMMAND}" --build . 
-j${PROC_NUM} --target ${INSTALL_ALL} install-oneflow-tblgen install-mlir-headers WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} RESULT_VARIABLE ret) if(ret EQUAL "1") message(FATAL_ERROR "Bad exit status") endif() endif() set(LLVM_INCLUDE_DIRS ${llvm_monorepo_SOURCE_DIR}/llvm/include;${llvm_monorepo_BINARY_DIR}/include) if(WITH_MLIR) set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm) set(MLIR_DIR ${LLVM_INSTALL_DIR}/lib/cmake/mlir) find_package(MLIR REQUIRED CONFIG) message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") set(MLIR_BINARY_DIR ${llvm_monorepo_BINARY_DIR}) list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") include(TableGen) include(AddLLVM) include(AddMLIR) include(HandleLLVMOptions) set(LLVM_EXTERNAL_LIT "${llvm_monorepo_BINARY_DIR}/bin/llvm-lit" CACHE STRING "" FORCE) endif() ================================================ FILE: oneflow/ir/lib/CMakeLists.txt ================================================ add_subdirectory(OneFlow) add_subdirectory(Transform) ================================================ FILE: oneflow/ir/lib/OneFlow/CMakeLists.txt ================================================ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) message(STATUS "MLIR_DIALECT_LIBS: ${dialect_libs}") if(WITH_MLIR_CUDA_CODEGEN) set(MLIR_GPU_LIBS MLIRGPUToNVVMTransforms MLIRNVVMToLLVMIRTranslation) endif(WITH_MLIR_CUDA_CODEGEN) set(ONEFLOW_OP_GROUPS "ASSIGN;BINARY;BROADCAST;CONV;CROSS_ENTROPY;CUDA;DATASET;DETECTION;EAGER;FUSED;IDEMPOTENT;IDENTITY;IMAGE;INDICES;INVOLUTION;LOSS;MATH;MATMUL;MISC;NCCL;NORMALIZATION;OPTIMIZER;PADDING;PARALLEL_CAST;POOL;QUANTIZATION;REDUCE;RESHAPE;SCALAR;SOFTMAX;SUMMARY;TENSOR_BUFFER;TEST;TRIGONOMETRIC;UNARY;UPSAMPLE;ONE_EMBEDDING;LINEAR_ALGEBRA;SYSTEM;MLIR_JIT" ) foreach(OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS) string(TOLOWER "${OP_GROUP_NAME}" OP_GROUP_NAME_LOWER) set(CPP_FILE "OneFlow.${OP_GROUP_NAME_LOWER}_ops.cpp") list(APPEND GROUPED_OP_CPP_FILES "${CPP_FILE}") configure_file(OneFlowOpGetGen.cpp.in ${CPP_FILE} @ONLY) endforeach() add_subdirectory(PDLL) oneflow_add_mlir_dialect_library( MLIROneFlow OKM/OKMDialect.cpp OKM/passes.cpp OKM/Conversion/Conversion.cpp OKL/OKLDialect.cpp OKL/OKLOps.cpp OKL/OKLTypes.cpp OKL/Conversion/OKLToLLVM.cpp OKL/Conversion/CudaGraphSupport.cpp OKL/Conversion/Conversion.cpp OKL/Kernel/InferContext.cpp OKL/Kernel/KernelLaunchOp.cpp OKL/Kernel/LauncherState.cpp OKL/Kernel/LauncherContext.cpp OKL/Kernel/ComputeContext.cpp OKL/Kernel/RegContext.cpp OKL/Kernel/TmpBufferManager.cpp OKL/Kernel/JITOpInfer.cpp OKL/Kernel/JITEngine.cpp SBP/SBPDialect.cpp SBP/SBPAttributes.cpp SBP/SBPImporter.cpp OneFlowDialect.cpp OneFlowTypes.cpp OneFlowInferReturnTypes.cpp OneFlowOps.cpp OneFlowOpTraits.cpp OneFlowSupport.cpp OneFlowUtils.cpp OneFlowDataTypeConversion.cpp UserOpReflection.cpp UserOpConversion.cpp OneFlowOpFolders.cpp Conversion/OneFlowToTosa.cpp Conversion/OneFlowToLinalg.cpp Conversion/NVVMToCubin.cpp Transform/BufferHostRegister.cpp Transform/OutlineAndFuse.cpp Transform/JITPasses.cpp Transform/AutoNhwc.cpp Transform/ConvertInferenceOp.cpp Transform/AggregateOps.cpp Transform/EliminateAllocOps.cpp Transform/FuncOps.cpp Transform/CSEWithAttributesIgnored.cpp Transform/GroupMatMulOps.cpp Transform/AutoNHWCOps.cpp Transform/OneFlowMemPool.cpp Transform/OneFlowStream.cpp Transform/TraitFolder.cpp TransposeHelpers.cpp Passes.cpp OneFlowCanonicalizers.cpp OneFlowRewrites.cpp ${GROUPED_OP_CPP_FILES} 
ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/OneFlow DEPENDS MLIROneFlowOpsIncGen prepare_oneflow_third_party LINK_LIBS PUBLIC ${dialect_libs} MLIRTosaToLinalg MLIRTosaToTensor MLIRMemRefToLLVM MLIRLinalgToLLVM MLIRSCFToGPU MLIRReconcileUnrealizedCasts ${MLIR_GPU_LIBS} MLIRIR MLIRBytecodeWriter MLIROneFlowPDLLPatterns MLIRExecutionEngine oneflow) if(WITH_MLIR_CUDA_CODEGEN) find_library(CUDA_DRIVER_LIBRARY cuda) target_link_libraries(MLIROneFlow PRIVATE ${CUDA_DRIVER_LIBRARY}) include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) endif(WITH_MLIR_CUDA_CODEGEN) ================================================ FILE: oneflow/ir/lib/OneFlow/Conversion/NVVMToCubin.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_MLIR_CUDA_CODEGEN #include "oneflow/core/common/util.h" #include "OneFlow/Passes.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "llvm/ADT/StringSet.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/Linker/Linker.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/FileSystem.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/Internalize.h" #include "llvm/Transforms/Scalar/DCE.h" #include "llvm/Transforms/Vectorize/LoopVectorize.h" #include "llvm/Transforms/Vectorize/SLPVectorizer.h" #include #include static void emitCudaError(const llvm::Twine& expr, const char* buffer, CUresult result, mlir::Location loc) { const char* error; cuGetErrorString(result, &error); emitError(loc, expr.concat(" failed with error code ") .concat(llvm::Twine{error}) .concat("[") .concat(buffer) .concat("]")); } #define RETURN_ON_CUDA_ERROR(expr) \ do { \ if (auto status = (expr)) { \ emitCudaError(#expr, jitErrorBuffer, status, loc); \ return {}; \ } \ } while (false) namespace mlir { namespace oneflow { const char* getArchVersion() { static std::string version; if (!version.empty()) return version.c_str(); cudaDeviceProp prop{}; cudaError_t err = cudaGetDeviceProperties(&prop, 0); if (err != cudaSuccess) { printf("%s\n", cudaGetErrorString(err)); exit(1); } version = "sm_" + std::to_string(prop.major) + std::to_string(prop.minor); return version.c_str(); } namespace { const std::string& getLibDevice() { static std::string p; if 
(!p.empty()) return p; const auto toolkit_env_name = "CUDA_TOOLKIT_ROOT_DIR"; p = ::oneflow::GetStringFromEnv(toolkit_env_name, "/usr/local/cuda/") + "nvvm/libdevice/libdevice.10.bc"; if (llvm::sys::fs::exists(p)) return p; LOG(FATAL) << "Could not find file: " << p << ". Please check you cuda toolkit directory and set " << toolkit_env_name << " correctly as an environment variable"; } LogicalResult linkLibdevice(llvm::Module& llvmModule, llvm::LLVMContext& llvmContext) { // Note: infer libdevice path from environment variable auto libDevice = getLibDevice(); // Note: load raw data from file std::string errorMessage; auto libDeviceBuf = openInputFile(libDevice, &errorMessage); if (!libDeviceBuf) LOG(FATAL) << "Open File error when link libdevice: " << errorMessage; // Note: load module from raw data auto moduleOrErr = llvm::getOwningLazyBitcodeModule(std::move(libDeviceBuf), llvmContext); if (!moduleOrErr) LOG(FATAL) << "Failed to load: " << libDevice << "\n"; std::unique_ptr libDeviceModule = std::move(moduleOrErr.get()); // Note: link libdevice with module if (llvm::Linker::linkModules(llvmModule, std::move(libDeviceModule), llvm::Linker::Flags::LinkOnlyNeeded, [](llvm::Module& M, const llvm::StringSet<>& GS) { llvm::internalizeModule(M, [&GS](const llvm::GlobalValue& GV) { return !GV.hasName() || (GS.count(GV.getName()) == 0); }); })) { LOG(FATAL) << "failed to link libdevice module\n"; } return success(); } std::optional translateToISA(llvm::Module& llvmModule, llvm::TargetMachine& targetMachine) { llvmModule.setDataLayout(targetMachine.createDataLayout()); // TODO(yuhao): optimizeLlvm std::string targetISA; llvm::raw_string_ostream stream(targetISA); { // Note: Drop pstream after this to prevent the ISA from being stuck buffering llvm::buffer_ostream pstream(stream); llvm::legacy::PassManager codegenPasses; if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr, llvm::CGFT_AssemblyFile)) return std::nullopt; codegenPasses.run(llvmModule); } return stream.str(); } class NVVMToCubinPass : public NVVMToCubinPassBase { std::unique_ptr translateToLLVMIR(llvm::LLVMContext& llvmContext) { return translateModuleToLLVMIR(getOperation(), llvmContext, "LLVMDialectModule"); } public: std::unique_ptr createTargetMachine(); std::unique_ptr> serializeISA(const std::string& isa); void runOnOperation() override; void getDependentDialects(::mlir::DialectRegistry& registry) const override { registerLLVMDialectTranslation(registry); registerNVVMDialectTranslation(registry); registerGPUDialectTranslation(registry); registerLLVMDialectTranslation(registry); } }; std::unique_ptr NVVMToCubinPass::createTargetMachine() { Location loc = getOperation().getLoc(); std::string error; const llvm::Target* target = ::llvm::TargetRegistry::lookupTarget(triple.str(), error); if (!target) { emitError(loc, Twine("failed to lookup target: ") + error); return {}; } llvm::TargetMachine* machine = target->createTargetMachine(triple.str(), chip.str(), features.str(), {}, {}); if (!machine) { emitError(loc, "failed to create target machine"); return {}; } return std::unique_ptr{machine}; } std::unique_ptr> NVVMToCubinPass::serializeISA(const std::string& isa) { Location loc = getOperation().getLoc(); char jitErrorBuffer[4096] = {0}; RETURN_ON_CUDA_ERROR(cuInit(0)); // Note: Linking requires a device context. 
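  // The calls below follow the usual CUDA driver-API JIT flow: acquire a device and context,
  // open a CUlinkState whose error log goes to jitErrorBuffer, hand the PTX produced by
  // translateToISA() to cuLinkAddData, and copy the linked cubin image out of cuLinkComplete.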
CUdevice device; RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0)); CUcontext context; RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device)); CUlinkState linkState; CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; void* jitOptionsVals[] = {jitErrorBuffer, reinterpret_cast(sizeof(jitErrorBuffer))}; RETURN_ON_CUDA_ERROR(cuLinkCreate(2, /* number of jit options */ jitOptions, /* jit options */ jitOptionsVals, /* jit option values */ &linkState)); auto kernelName = getOperation().getName().str(); RETURN_ON_CUDA_ERROR(cuLinkAddData(linkState, CUjitInputType::CU_JIT_INPUT_PTX, const_cast(static_cast(isa.c_str())), isa.length(), kernelName.c_str(), 0, /* number of jit options */ nullptr, /* jit options */ nullptr /* jit option values */ )); void* cubinData; size_t cubinSize; RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize)); char* cubinAsChar = static_cast(cubinData); auto result = std::make_unique>(cubinAsChar, cubinAsChar + cubinSize); RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState)); RETURN_ON_CUDA_ERROR(cuCtxDestroy(context)); return result; } void NVVMToCubinPass::runOnOperation() { llvm::LLVMContext llvmContext; std::unique_ptr llvmModule = translateToLLVMIR(llvmContext); if (!llvmModule) return signalPassFailure(); if (failed(linkLibdevice(*llvmModule, llvmContext))) { return signalPassFailure(); } // Note: Lower the LLVM IR module to target ISA. std::unique_ptr targetMachine = createTargetMachine(); if (!targetMachine) return signalPassFailure(); std::optional maybeTargetISA = translateToISA(*llvmModule, *targetMachine); if (!maybeTargetISA.has_value()) return signalPassFailure(); std::string targetISA = std::move(*maybeTargetISA); // Note: Serialize the target ISA. std::unique_ptr> blob = serializeISA(targetISA); if (!blob) return signalPassFailure(); // Note: Add the blob as module attribute. auto attr = StringAttr::get(&getContext(), StringRef(blob->data(), blob->size())); getOperation()->setAttr(gpu::getCubinAnnotation(), attr); } } // namespace std::unique_ptr createNVVMToCubinPass() { return std::make_unique(); } void InitializeLLVMNVPTXBackend() { LLVMInitializeNVPTXTarget(); LLVMInitializeNVPTXTargetInfo(); LLVMInitializeNVPTXTargetMC(); LLVMInitializeNVPTXAsmPrinter(); } } // namespace oneflow } // namespace mlir #endif // WITH_MLIR_CUDA_CODEGEN ================================================ FILE: oneflow/ir/lib/OneFlow/Conversion/OneFlowToLinalg.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OneFlowOps.h" #include "OneFlow/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace oneflow { namespace { std::tuple, SmallVector> computeIteratorTypesAndIndexingMaps(int64_t inputRank, int64_t dim, OpBuilder& builder, bool allParallel = false) { SmallVector<::mlir::utils::IteratorType> iteratorTypes(inputRank, ::mlir::utils::IteratorType::parallel); if (!allParallel) iteratorTypes[dim] = ::mlir::utils::IteratorType::reduction; auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, builder.getContext()); SmallVector affineExprs; for (int i = 0; i < inputRank; i++) { if (i != dim) affineExprs.push_back(mlir::getAffineDimExpr(i, builder.getContext())); } auto reductionMap = AffineMap::get(inputRank, 0, affineExprs, builder.getContext()); SmallVector indexingMaps{identityMap, reductionMap}; return std::make_tuple(iteratorTypes, indexingMaps); } template static Value reduce(Value input, Value output, int64_t dim, Location loc, OpBuilder& builder) { auto inputType = input.getType().cast(); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputShape.size(); auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(inputRank, dim, builder); auto genericOp = builder.create( loc, output.getType(), input, output, indexingMaps, iteratorTypes, [&](OpBuilder& b, Location loc, ValueRange args) { Value result = b.create(loc, args[0], args[1]); b.create(loc, result); }); return genericOp.getResult(0); } static Value subtractAndExp(Value input, Value max, Value output, int64_t dim, Location loc, OpBuilder& builder) { auto inputType = input.getType().cast(); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputShape.size(); auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(inputRank, dim, builder, true); indexingMaps.push_back(indexingMaps[0]); auto genericOp = builder.create( loc, input.getType(), ValueRange{input, max}, output, indexingMaps, iteratorTypes, [&](OpBuilder& b, Location loc, ValueRange args) { Value diff = b.create(loc, args[0], args[1]); Value result = b.create(loc, diff); b.create(loc, result); }); return genericOp.getResult(0); } static Value computeSoftmax(Value numerator, Value denominator, Value output, int64_t dim, Location loc, OpBuilder& builder) { auto inputType = numerator.getType().cast(); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputShape.size(); auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(inputRank, dim, builder, true); indexingMaps.push_back(indexingMaps[0]); auto genericOp = builder.create( loc, numerator.getType(), ValueRange{numerator, denominator}, output, indexingMaps, iteratorTypes, [&](OpBuilder& b, Location loc, ValueRange args) { Value result = b.create(loc, args[0], args[1]); b.create(loc, result); }); return genericOp.getResult(0); } /// Given an N-dimensional tensor x, this op converts /// softmax(x) to the following sequence of operations: /// /// 1. Compute the max of x along dimension d. This results /// in a N-1 dimensional tensor m. /// m = max(x, dim = d) /// /// 2. Subtract m from x and exponentiate. This results in /// a N dimensional tensor z. /// z = exp(x - m) /// /// 3. Compute the sum of z along dimension d. This results in /// a N-1 dimensional tensor l. /// l = sum(z, dim = d) /// /// 4. Divide z and l. 
This gives the N-dimensional softmax. /// softmax = z / l /// // Implementation above is from IREE. // https://github.com/google/iree/blob/b339919814f10589f779b39c3ab7c6575716dab6/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/DecomposeSoftmax.cpp SmallVector createDimValues(OpBuilder& b, Location loc, Value rankedTensor) { auto tensorTy = rankedTensor.getType().cast(); SmallVector dims; for (const auto& en : llvm::enumerate(tensorTy.getShape())) { if (ShapedType::isDynamic(en.value())) { dims.push_back(b.createOrFold(loc, rankedTensor, en.index())); } else { dims.push_back(b.getIndexAttr(en.value())); } } return dims; } struct SoftmaxOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SoftmaxOp softmaxOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(softmaxOp); Location loc = softmaxOp.getLoc(); Value input = softmaxOp.getIn(); ShapedType inputType = input.getType().cast(); Type elementType = inputType.getElementType(); int64_t reductionDim = inputType.getRank() - 1; SmallVector dims = createDimValues(rewriter, loc, input); Value outputNd = rewriter.create(loc, dims, elementType); dims.erase(dims.begin() + reductionDim); // Compute max along dim Value output = rewriter.create(loc, dims, elementType); Value largeNegative = rewriter.create(loc, rewriter.getFloatAttr(elementType, -1.0e30)); Value negativeInit = rewriter.create(loc, Value{largeNegative}, output).result(); Value max = reduce(input, negativeInit, reductionDim, loc, rewriter); // Subtract max from input and exponentiate Value numerator = subtractAndExp(input, max, outputNd, reductionDim, loc, rewriter); // Compute sum along dim Value zero = rewriter.create(loc, rewriter.getZeroAttr(elementType)); Value zeroInit = rewriter.create(loc, Value{zero}, output).result(); Value denominator = reduce(numerator, zeroInit, reductionDim, loc, rewriter); // Compute softmax Value result = computeSoftmax(numerator, denominator, outputNd, reductionDim, loc, rewriter); rewriter.replaceOp(softmaxOp, {result}); return success(); } }; struct OneFlowLoweringToLinalgPass : public LowerOneFlowToLinalgPassBase { void runOnOperation() { MLIRContext* context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); RewritePatternSet patterns(context); patterns.add(context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); } } }; } // namespace std::unique_ptr createLowerOneFlowToLinalgPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Conversion/OneFlowToTosa.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OneFlowOps.h" #include #include #include #include "OneFlow/OneFlowDialect.h" #include "OneFlow/Passes.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/variable_tensor_mgr.h" #include namespace mlir { namespace oneflow { Type convertToSignless(MLIRContext* context, Type type) { if (auto ranked_tensor = type.dyn_cast()) { if (auto intTy = ranked_tensor.getElementType().dyn_cast()) { if (!intTy.isSignless()) { return RankedTensorType::get( ranked_tensor.getShape(), IntegerType::get(context, intTy.getWidth(), mlir::IntegerType::SignednessSemantics::Signless)); } } } return type; } FunctionType convertToSignlessFuncType(MLIRContext* context, FunctionType funcType) { llvm::SmallVector inputs; llvm::SmallVector results; for (auto arg : funcType.getInputs()) { inputs.push_back(convertToSignless(context, arg)); } for (auto res : funcType.getResults()) { results.push_back(convertToSignless(context, res)); } return FunctionType::get(context, inputs, results); } bool isSignLessTensorOrOther(Type type) { if (auto ranked_tensor = type.dyn_cast()) { if (auto intTy = ranked_tensor.getElementType().dyn_cast()) { if (intTy.isUnsigned()) { return false; } if (intTy.isSigned()) { return false; } } } return true; } bool allSignless(mlir::TypeRange types) { for (auto type : types) { if (!isSignLessTensorOrOther(type)) { return false; } } return true; } bool allSignless(FunctionType funcType) { for (auto arg : funcType.getInputs()) { if (!isSignLessTensorOrOther(arg)) { return false; } } for (auto res : funcType.getResults()) { if (!isSignLessTensorOrOther(res)) { return false; } } return true; } Value CreateTransposeValue(Location& loc, ConversionPatternRewriter& rewriter, Value input, ArrayRef perms) { int perms_size = perms.size(); auto transpose_perms = rewriter.create( loc, RankedTensorType::get({perms_size}, rewriter.getI32Type()), rewriter.getI32TensorAttr(perms)); const auto shape_type = input.getType().cast(); std::vector ranked_type; for (const auto& index : perms) ranked_type.push_back(shape_type.getDimSize(index)); return rewriter.create( loc, RankedTensorType::get(ranked_type, shape_type.getElementType()), input, transpose_perms); }; RankedTensorType CreateTransposeType(ShapedType output, ArrayRef perms) { std::vector ranked_type; for (auto index : perms) ranked_type.push_back(output.getDimSize(index)); return RankedTensorType::get(ranked_type, output.getElementType()); }; Value CreateBNOp(Location loc, ConversionPatternRewriter& 
rewriter, Type output_type, Value x, Value mean, Value variance, Value epsilon, Value gamma, Value beta) { // sub_op = sub(input, mean) auto sub_op0 = rewriter.create(loc, output_type, x, mean); // add_op0 = add(var, epsilon) auto add_op0 = rewriter.create(loc, variance.getType(), variance, epsilon); // rsqrt_op = rsqrt(add_op0) auto rsqrt_op = rewriter.create(loc, variance.getType(), add_op0); // op4 = mul(sub_op, rsqrt_op) auto mul_op0 = rewriter.create(loc, output_type, sub_op0, rsqrt_op, 0); // op5 = mul(mul_op0, gamma) auto mul_op1 = rewriter.create(loc, output_type, mul_op0, gamma, 0); // op6 = add(mul_op1, beta) Value batch_norm = rewriter.create(loc, output_type, mul_op1, beta); return batch_norm; }; struct ScalarMulByTensorOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ScalarMulByTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { Value scalar = op.getScalar(); rewriter.replaceOpWithNewOp( op, /* output */ op->getResultTypes().front().cast(), /* input1 */ op.getX(), /* input2 */ scalar, /* shift */ rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); return success(); } }; struct JobLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Job op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto func_type = convertToSignlessFuncType(op->getContext(), op.getFunctionType()); auto func = rewriter.create(op.getLoc(), op.getName(), func_type); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); } }; struct ReturnOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp(op, /* operands */ op.getOperands()); return success(); } }; struct InputOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(InputOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // TODO: more choices to passing data between tosa and oneflow const auto newValues = op.getInput(); const auto is_block_arg = newValues.dyn_cast() != nullptr; if (!is_block_arg) { return op->emitError("input is not block arg"); } rewriter.replaceOp(op, newValues); return success(); } }; struct OutputOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(OutputOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // TODO: more choices to passing data between tosa and oneflow const auto newValues = op.getInput(); rewriter.replaceOp(op, newValues); return success(); } }; struct VariableOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(VariableOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { const auto mgr = ::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get(); if (!mgr) { return op->emitError("global variable tensor manager miss"); } const auto tensor = CHECK_JUST(mgr->Get(op.getOpName().str())); if (!tensor) { return op->emitError("tensor is null"); } const auto value = support::TensorToDenseElementsAttr(tensor, rewriter.getContext()); const auto 
output = op.getOutput().getType(); rewriter.replaceOpWithNewOp(op, output, value); return success(); } }; struct VariableOpToConstLowering final : public OpConversionPattern { public: VariableOpToConstLowering(TypeConverter& typeConverter, MLIRContext* context, int const_val) : OpConversionPattern(typeConverter, context), const_val_(const_val){}; using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(VariableOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { const auto output = op.getOutput().getType(); const auto type = output.cast().getElementType(); // TODO: more control about this scope with flag if (type.isa()) { const auto float_attr = rewriter.getFloatAttr(type, const_val_); auto value = DenseElementsAttr::get(output.cast(), float_attr); rewriter.replaceOpWithNewOp(op, output, value); } else if (auto integerType = type.dyn_cast()) { const auto int_attr = rewriter.getIntegerAttr(type, APInt(type.cast().getWidth(), const_val_)); auto value = DenseElementsAttr::get(output.cast(), int_attr); rewriter.replaceOpWithNewOp(op, output, value); } else { return op->emitError( "OneFlow variable op lower to TOSA const op only support integer and float value now"); } return success(); } private: int const_val_; }; struct CastOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CastOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto output = op.getOut().getType(); auto input = op.getIn(); rewriter.replaceOpWithNewOp(op, output, input); return success(); } }; struct ReluOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReluOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { const auto output = op.getY().getType(); auto input = op.getX(); auto ranked_output = llvm::dyn_cast_or_null(output); auto value = DenseElementsAttr::get(output.cast(), rewriter.getZeroAttr(ranked_output ? 
ranked_output.getElementType() : rewriter.getI64Type())); tosa::ConstOp zeros = rewriter.create(op.getLoc(), output, value); rewriter.replaceOpWithNewOp(op, output, input, zeros); return success(); } }; struct BroadcastAddOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(BroadcastAddOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { const auto output = op.getZ().getType(); auto input1 = op.getX(); auto input2 = op.getY(); rewriter.replaceOpWithNewOp(op, output, input1, input2); return success(); } }; struct Add2OpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Add2Op op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { const auto output = op.getOut().getType(); auto input1 = op.getIn0(); auto input2 = op.getIn1(); rewriter.replaceOpWithNewOp(op, output, input1, input2); return success(); } }; struct AvgPool2DOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AvgPool2DOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto get_pair_int64_from_array = [](ArrayAttr arr) -> std::pair { return {arr.getValue()[0].cast().getSInt(), arr.getValue()[1].cast().getSInt()}; }; auto stride_pairs = get_pair_int64_from_array(op.getStride()); auto pad_pairs = get_pair_int64_from_array(op.getPadding()); auto kernel_pairs = get_pair_int64_from_array(op.getKernelSize()); auto loc = op.getLoc(); auto perms = {0, 2, 3, 1}; const auto kernel = rewriter.getDenseI64ArrayAttr({kernel_pairs.first, kernel_pairs.second}); const auto stride = rewriter.getDenseI64ArrayAttr({stride_pairs.first, stride_pairs.second}); const auto pad = rewriter.getDenseI64ArrayAttr( {pad_pairs.first, pad_pairs.second, pad_pairs.first, pad_pairs.second}); auto input = CreateTransposeValue(loc, rewriter, op.getX(), perms); auto output = CreateTransposeType(op.getY().getType().cast(), perms); auto avg_pool2d = rewriter.create(loc, output, input, kernel, stride, pad); auto out = CreateTransposeValue(loc, rewriter, avg_pool2d, {0, 3, 1, 2}); rewriter.replaceOp(op, {out}); return success(); } }; struct MaxPool2DOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(MaxPool2DOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto get_pair_int64_from_array = [](ArrayAttr arr) -> std::pair { return {arr.getValue()[0].cast().getSInt(), arr.getValue()[1].cast().getSInt()}; }; // TODO: support return indice if (op.getReturnIndices()) { return op->emitError("not support return indices now"); } auto stride_pairs = get_pair_int64_from_array(op.getStride()); auto kernel_pairs = get_pair_int64_from_array(op.getKernelSize()); auto pad_pairs = get_pair_int64_from_array(op.getPadding()); auto loc = op.getLoc(); const auto kernel = rewriter.getDenseI64ArrayAttr({kernel_pairs.first, kernel_pairs.second}); const auto stride = rewriter.getDenseI64ArrayAttr({stride_pairs.first, stride_pairs.second}); const auto pad = rewriter.getDenseI64ArrayAttr( {pad_pairs.first, pad_pairs.second, pad_pairs.first, pad_pairs.second}); auto input = op.getX(); auto out_type = op.getY().getType().cast(); Value y; if (op.IsNCHW()) { auto perms = {0, 2, 3, 1}; auto reverse_perms = {0, 3, 1, 2}; input = CreateTransposeValue(loc, rewriter, input, perms); 
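      // tosa.max_pool2d only accepts NHWC operands, so for NCHW graphs the input is
      // transposed with {0, 2, 3, 1} above and the pooled result is transposed back
      // with {0, 3, 1, 2} below.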
out_type = CreateTransposeType(out_type, perms); auto max_pool2d = rewriter.create(loc, out_type, input, kernel, stride, pad); y = CreateTransposeValue(loc, rewriter, max_pool2d, reverse_perms); } else { y = rewriter.create(loc, out_type, input, kernel, stride, pad); } auto indice_output = convertToSignless(op->getContext(), op.getIndice().getType()); auto value = DenseElementsAttr::get(indice_output.cast(), rewriter.getZeroAttr(rewriter.getI64Type())); tosa::ConstOp indice = rewriter.create(loc, indice_output, value); rewriter.replaceOp(op, {y, indice}); return success(); } }; struct ReshapeOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto output = op.getOut().getType(); auto input = op.getIn(); llvm::SmallVector new_shape; for (const auto& dim_attr : op.getShape()) { new_shape.push_back(dim_attr.cast().getSInt()); } rewriter.replaceOpWithNewOp(op, output, input, rewriter.getDenseI64ArrayAttr(new_shape)); return success(); } }; // transpose the last two dims of the tensor. Reshape it to 3D if it is 2D. Value transposeAndReshapeIfRequired(Location loc, ConversionPatternRewriter& rewriter, Value matrix, bool transpose) { auto shape_type = matrix.getType().cast(); CHECK(shape_type.getRank() == 2 || shape_type.getRank() == 3); if (transpose) { if (shape_type.getRank() == 2) { matrix = CreateTransposeValue(loc, rewriter, matrix, {1, 0}); shape_type = matrix.getType().cast(); llvm::SmallVector reshape_dims{1, shape_type.getDimSize(0), shape_type.getDimSize(1)}; auto reshape_type = RankedTensorType::get(reshape_dims, shape_type.getElementType()); return rewriter.create(loc, reshape_type, matrix, rewriter.getDenseI64ArrayAttr(reshape_dims)); } else if (shape_type.getRank() == 3) { return CreateTransposeValue(loc, rewriter, matrix, {0, 2, 1}); } else { return Value{}; } } else if (shape_type.getRank() == 2) { llvm::SmallVector reshape_dims{1, shape_type.getDimSize(0), shape_type.getDimSize(1)}; auto reshape_type = RankedTensorType::get(reshape_dims, shape_type.getElementType()); return rewriter.create(loc, reshape_type, matrix, rewriter.getDenseI64ArrayAttr(reshape_dims)); } return matrix; } // Reshape: 2D -> 3D -> tosa.matmul -> 3D -> 2D struct MatmulOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(MatmulOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto a = transposeAndReshapeIfRequired(op->getLoc(), rewriter, op.getA(), op.getTransposeA()); auto b = transposeAndReshapeIfRequired(op->getLoc(), rewriter, op.getB(), op.getTransposeB()); const auto out_shape_type = op.getOut().getType().cast(); const auto out_reshape_type = RankedTensorType::get({1, out_shape_type.getDimSize(0), out_shape_type.getDimSize(1)}, out_shape_type.getElementType()); auto matmul = rewriter.create(op.getLoc(), out_reshape_type, a, b); const auto new_shape = rewriter.getDenseI64ArrayAttr({out_shape_type.getDimSize(0), out_shape_type.getDimSize(1)}); rewriter.replaceOpWithNewOp(op, out_shape_type, matmul, new_shape); return success(); } }; struct BatchMatmulOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(BatchMatmulOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto a = transposeAndReshapeIfRequired(op->getLoc(), 
rewriter, op.getA(), op.getTransposeA()); auto b = transposeAndReshapeIfRequired(op->getLoc(), rewriter, op.getB(), op.getTransposeB()); rewriter.replaceOpWithNewOp(op, op.getOut().getType(), a, b); return success(); } }; struct NormalizationInferenceOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(NormalizationInferenceOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto loc = op->getLoc(); const auto epsilon_type = RankedTensorType::get({}, rewriter.getF32Type()); auto epsilon = rewriter.create( loc, epsilon_type, DenseElementsAttr::get(epsilon_type, op.getEpsilon())); auto mean = op.getMovingMean(); auto variance = op.getMovingVariance(); auto gamma = op.getGamma(); auto beta = op.getBeta(); auto output_type = op.getY().getType(); Value x = op.getX(); if (op.IsNCHW()) { const auto perms = {0, 2, 3, 1}; x = CreateTransposeValue(loc, rewriter, x, perms); output_type = CreateTransposeType(output_type, perms); } auto batch_norm = oneflow::CreateBNOp(loc, rewriter, output_type, x, mean, variance, epsilon, gamma, beta); if (op.IsNCHW()) { const auto reverse_perms = {0, 3, 1, 2}; batch_norm = CreateTransposeValue(loc, rewriter, batch_norm, reverse_perms); } rewriter.replaceOp(op, {batch_norm}); return success(); } }; struct NormalizationOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(NormalizationOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto loc = op->getLoc(); const auto epsilon_type = RankedTensorType::get({}, rewriter.getF32Type()); auto epsilon = rewriter.create( loc, epsilon_type, DenseElementsAttr::get(epsilon_type, op.getEpsilon())); auto mean = op.getMovingMean(); auto variance = op.getMovingVariance(); auto gamma = op.getGamma(); auto beta = op.getBeta(); auto output_type = op.getY().getType(); Value x = op.getX(); if (op.IsNCHW()) { const auto perms = {0, 2, 3, 1}; x = CreateTransposeValue(loc, rewriter, x, perms); output_type = CreateTransposeType(output_type, perms); } auto batch_norm = oneflow::CreateBNOp(loc, rewriter, output_type, x, mean, variance, epsilon, gamma, beta); if (op.IsNCHW()) { const auto reverse_perms = {0, 3, 1, 2}; batch_norm = CreateTransposeValue(loc, rewriter, batch_norm, reverse_perms); } auto moving_mean = op.getMovingMean(); auto moving_variance = op.getMovingVariance(); rewriter.replaceOp(op, {batch_norm, moving_mean, moving_variance}); return success(); } }; struct Conv2DOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Conv2DOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto get_pair_int64_from_array = [](ArrayAttr arr) -> std::pair { return {arr.getValue()[0].cast().getSInt(), arr.getValue()[1].cast().getSInt()}; }; auto stride_pairs = get_pair_int64_from_array(op.getStrides()); auto pad_pairs = get_pair_int64_from_array(op.getPaddingBeforeAttr()); auto dilation_pairs = get_pair_int64_from_array(op.getDilationRate()); const auto pad = rewriter.getDenseI64ArrayAttr( {pad_pairs.first, pad_pairs.second, pad_pairs.first, pad_pairs.second}); const auto stride = rewriter.getDenseI64ArrayAttr({stride_pairs.first, stride_pairs.second}); const auto dilation = rewriter.getDenseI64ArrayAttr({dilation_pairs.first, dilation_pairs.second}); auto bias = op.getBias(); auto loc = op.getLoc(); if (!bias) { const auto 
output_shape = op.getOut().getType().cast(); // support nhwc const auto output_channels = output_shape.getDimSize(op.IsNCHW() ? 1 : 3); const auto bias_elem_type = output_shape.getElementType(); const auto type = RankedTensorType::get(output_channels, bias_elem_type); bias = rewriter.create( op.getLoc(), type, DenseElementsAttr::get(type, rewriter.getZeroAttr(bias_elem_type))); } Value in = op.getIn(); Value weight = op.getWeight(); auto out_type = op.getOut().getType().cast(); if (out_type.getRank() != 4) { LOG(FATAL) << "Failed to lowering oneflow op"; op->dump(); } // support nhwc if (op.IsNCHW()) { const auto perms = {0, 2, 3, 1}; const auto reverse_perms = {0, 3, 1, 2}; in = CreateTransposeValue(loc, rewriter, in, perms); weight = CreateTransposeValue(loc, rewriter, weight, perms); out_type = CreateTransposeType(out_type, perms); auto conv2d = rewriter.create(loc, out_type, in, weight, bias, pad, stride, dilation); auto res = CreateTransposeValue(loc, rewriter, conv2d, reverse_perms); rewriter.replaceOp(op, {res}); } else { rewriter.replaceOpWithNewOp(op, out_type, in, weight, bias, pad, stride, dilation); } return success(); } }; struct TransposeOpLowering final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { llvm::SmallVector perms{}; for (auto dim : op.getPerm().getAsValueRange()) { perms.push_back(dim.getSExtValue()); } llvm::SmallVector perms_shape(op.getPerm().size(), 1); auto perms_op = rewriter.create( op->getLoc(), RankedTensorType::get(perms_shape, rewriter.getI32Type()), rewriter.getI32TensorAttr(perms)); rewriter.replaceOpWithNewOp(op, op.getOutput().getType(), op.getInput(), perms_op.getOutput()); return success(); } }; struct CastInputConversion final : public OpRewritePattern { public: explicit CastInputConversion(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(InputOp op, mlir::PatternRewriter& rewriter) const override { auto outType = op.getOutput().getType(); if (isSignLessTensorOrOther(outType)) { return failure(); } if (op->hasOneUse()) { if (auto cast = llvm::dyn_cast(op.getOutput().use_begin()->getOwner())) { if (isSignLessTensorOrOther(cast.getResult(0).getType())) { return failure(); } } } InputOp cloned = rewriter.create(op->getLoc(), op.getResultTypes(), op->getOperands(), op->getAttrs()); rewriter.replaceOpWithNewOp( op, convertToSignless(getContext(), op.getOutput().getType()), cloned.getOutput()); return success(); } }; struct CastVariableConversion final : public OpRewritePattern { public: explicit CastVariableConversion(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(VariableOp op, mlir::PatternRewriter& rewriter) const override { auto outType = op.getOutput().getType(); if (isSignLessTensorOrOther(outType)) { return failure(); } if (op->hasOneUse()) { if (auto cast = llvm::dyn_cast(op.getOutput().use_begin()->getOwner())) { if (isSignLessTensorOrOther(cast.getResult(0).getType())) { return failure(); } } } if (op.getOutput().getUses().empty()) { return failure(); } VariableOp cloned = rewriter.create(op->getLoc(), op.getResultTypes(), op->getOperands(), op->getAttrs()); rewriter.replaceOpWithNewOp( op, convertToSignless(getContext(), op.getOutput().getType()), cloned.getOutput()); return success(); } }; namespace { class CastOneFlowOpsToSignlessPass : public 
CastOneFlowOpsToSignlessPassBase { void getDependentDialects(::mlir::DialectRegistry& registry) const override { registry.insert(); } void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(&getContext()); patterns.add(op->getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; struct OneFlowLoweringToTosaPass : public LowerOneFlowToTosaPassBase { void runOnOperation() override; }; struct ConvertToSignlessForTosaPass : public ConvertToSignlessForTosaPassBase { void runOnOperation() override; }; } // namespace std::unique_ptr createLowerOneFlowToTosaPass() { return std::make_unique(); } std::unique_ptr createConvertToSignlessForTosaPass() { return std::make_unique(); } void OneFlowLoweringToTosaPass::runOnOperation() { MLIRContext* context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); if (fullyConvert) { target.addIllegalDialect(); } TypeConverter typeConverter; typeConverter.addConversion([context](Type type) { return convertToSignless(context, type); }); typeConverter.addSourceMaterialization( [&](OpBuilder& builder, Type resultType, ValueRange inputs, Location loc) -> Optional { CHECK_EQ(inputs.size(), 1) << "expect to materialize a single value"; return builder.create(loc, resultType, inputs).getResult(0); }); typeConverter.addTargetMaterialization( [&](OpBuilder& builder, Type resultType, ValueRange inputs, Location loc) -> Optional { CHECK_EQ(inputs.size(), 1) << "expect to materialize a single value"; return builder.create(loc, resultType, inputs).getResult(0); }); RewritePatternSet patterns(context); // check if the pass is triggered by python based on the presence of variable tensor manger if (fullyConvert) { if (::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()) { patterns.add(typeConverter, context); } else { patterns.add(typeConverter, context, this->variableAsConstant); } } patterns.add( typeConverter, context); if (lowerJob) { patterns.add(typeConverter, context); } if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); LOG(ERROR) << "Failed to lower OneFlow to Tosa"; getOperation()->dump(); } } struct ConvertReturnToSignlessPattern : public OpRewritePattern { explicit ConvertReturnToSignlessPattern(::mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} ::mlir::LogicalResult matchAndRewrite(func::ReturnOp op, ::mlir::PatternRewriter& rewriter) const override { // make sure result not converted if (allSignless(op.getOperandTypes())) { return failure(); } llvm::SmallVector results; for (auto res : op->getOperandTypes()) { results.push_back(convertToSignless(op->getContext(), res)); } auto uc = rewriter.create(op->getLoc(), results, op.getOperands()); rewriter.replaceOpWithNewOp(op, op->getResultTypes(), uc->getResults(), op->getAttrs()); return success(); } }; struct ConvertFuncToSignlessPattern : public OpRewritePattern { explicit ConvertFuncToSignlessPattern(::mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} ::mlir::LogicalResult matchAndRewrite(func::FuncOp op, ::mlir::PatternRewriter& rewriter) const override { if (allSignless(op.getFunctionType())) { return failure(); } auto ft = convertToSignlessFuncType(op->getContext(), op.getFunctionType()); auto func = rewriter.create(op.getLoc(), op.getName(), ft); IRMapping bvm; op.getRegion().cloneInto(&func.getRegion(), bvm); for (auto& block : func.getBody().getBlocks()) { for (auto arg : block.getArguments()) { auto new_type = 
convertToSignless(op.getContext(), arg.getType()); arg.setType(new_type); for (auto* use : arg.getUsers()) { if (auto input = llvm::dyn_cast_or_null(use)) { input.getOutput().setType(new_type); } } } } rewriter.eraseOp(op); RewritePatternSet patterns(func->getContext()); patterns.add(func->getContext()); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); return success(); } }; void ConvertToSignlessForTosaPass::runOnOperation() { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(op->getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } std::unique_ptr createCastOneFlowOpsToSignlessPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Conversion/Conversion.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OKL/Conversion/Conversion.h" #include "OneFlow/OKL/Conversion/OKLToLLVM.h" #include "OneFlow/Passes.h" #include "OneFlow/Transform/OutlineAndFuse.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Pass/PassManager.h" #include "oneflow/ir/include/OneFlow/OneFlowUtils.h" namespace mlir { namespace okl { LogicalResult LowerOKLComputeToLLVM(ModuleOp module) { PassManager pm(module->getContext()); pm.addPass(createLowerLauncherToLLVMPtrPass()); // lower-launcher-to-llvm-ptr pm.addPass(createLowerOKLToLLVMCallPass()); // lower-okl-to-llvm-call pm.addPass(createConvertFuncToLLVMPass()); // convert-func-to-llvm pm.addPass(createReconcileUnrealizedCastsPass()); // reconcile-unrealized-casts oneflow::CheckEnableIRPrinting(pm); return pm.run(module); } } // namespace okl } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Conversion/CudaGraphSupport.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OKL/Kernel/JITEngine.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "OneFlow/OKL/Kernel/RegContext.h" #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OKL/OKLOps.h" #include "OneFlow/OKL/OKLTypes.h" #include "OneFlow/OKL/passes.h" #include "OneFlow/OKM/passes.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OperationSupport.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "llvm/Support/raw_ostream.h" namespace mlir { namespace okl { struct TagCudaGraphSupportPattern final : public mlir::OpRewritePattern { static mlir::Operation* FindOneFlowOp(mlir::Operation* op) { mlir::Operation* reg_op = nullptr; for (auto& op_it : op->getRegion(0).front().getOperations()) { if (op_it.getDialect()->getNamespace() != "oneflow") { continue; } reg_op = &op_it; break; } return reg_op; } static LogicalResult CheckChild(func::FuncOp func) { using namespace ::oneflow::user_op; for (auto& op : func->getRegion(0).front()) { if (auto reg_ctx_op = llvm::dyn_cast_or_null(&op)) { // iter reg context op const auto reg_op = FindOneFlowOp(&op); if (!reg_op) { func->emitError("Failed to find reg_op in okl.build_reg_context_op"); return failure(); } // generate kernel from oneflow.{compute op} ::oneflow::okl::RegContext reg_ctx(reg_op); auto* kernel = const_cast(reg_ctx.GetKernel()); // check whether cuda graph support is base class if (const auto* cuda_graph_support = dynamic_cast(kernel)) { // TODO: more check continue; } return failure(); } } return success(); } public: explicit TagCudaGraphSupportPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { const auto tag_name = mlir::okl::cuda_graph_support::TAG_NAME; // check whether this op is okl init context function op if (!op.getSymName().startswith(mlir::okm::func_name::OKL_GRAPH_NAME)) { return failure(); } // check whether this op has been taged before if (op->getAttr(tag_name).dyn_cast_or_null() != nullptr) { return success(); } // check whether its childern is all cuda graph supported const auto outcome = succeeded(CheckChild(op)); // set cuda graph support tag on init_context and compute function ops op->setAttr(tag_name, rewriter.getBoolAttr(outcome)); return success(); } }; namespace { struct TagCudaGraphSupportPass : public TagCudaGraphSupportPassBase { void runOnOperation() override; void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } }; } // namespace std::unique_ptr createTagCudaGraphSupportPass() { return std::make_unique(); } void TagCudaGraphSupportPass::runOnOperation() { MLIRContext* context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } } // namespace okl } // namespace mlir ================================================ FILE: 
oneflow/ir/lib/OneFlow/OKL/Conversion/OKLToLLVM.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "OneFlow/OKL/Kernel/JITEngine.h"
#include "OneFlow/OKL/OKLDialect.h"
#include "OneFlow/OKL/OKLOps.h"
#include "OneFlow/OKL/OKLTypes.h"
#include "OneFlow/OKL/passes.h"
#include "OneFlow/OneFlowDialect.h"
#include "OneFlow/Passes.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
namespace okl {

template<typename T>
ModuleOp GetModuleOpFromJobBodyOp(T op) {
  auto parent_func_op = op->template getParentOfType<func::FuncOp>();
  if (!parent_func_op) { return nullptr; }
  return parent_func_op->template getParentOfType<ModuleOp>();
}

// use this func to union the ptr type in this conversion phase.
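// Every okl.launcher_ctx value is erased to the same opaque !llvm.ptr<i8>, matching the
// C ABI of the runtime entry point okl_llvm_func(void* launcher, int64_t index) that the
// generated llvm.call sites target.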
LLVM::LLVMPointerType GetPtrType(::mlir::PatternRewriter& rewriter) { return LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8)); } struct WrapperKernelOpLowering final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename WrapperKernelOp::Adaptor; static LLVM::LLVMFuncOp DeclareLaunchFunc(::mlir::PatternRewriter& rewriter, ModuleOp* module) { LLVM::LLVMFuncOp func; const auto func_name = ::oneflow::okl::llvm_func::LLVM_FUNC; if (!(func = module->lookupSymbol(func_name))) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module->getBody()); auto void_type = LLVM::LLVMVoidType::get(rewriter.getContext()); auto func_type = LLVM::LLVMFunctionType::get( void_type, {GetPtrType(rewriter), rewriter.getI64Type()}, false); func = rewriter.create(rewriter.getUnknownLoc(), func_name, func_type, LLVM::Linkage::External); func->setAttr("llvm.emit_c_interface", mlir::UnitAttr::get(rewriter.getContext())); } return func; } LogicalResult matchAndRewrite(WrapperKernelOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto module = GetModuleOpFromJobBodyOp(op); if (!module) { LOG(FATAL) << "Failed to lowering llvm call because of op is not in a module"; }; auto launch_func = DeclareLaunchFunc(rewriter, &module); auto launcher_ctx = op->getParentOfType().getBody().getArgument(0); auto index_op = rewriter.create(op->getLoc(), rewriter.getI64Type(), rewriter.getIndexAttr(op.getIndex())); auto new_op = rewriter.create(op->getLoc(), launch_func, ValueRange{launcher_ctx, index_op}); rewriter.replaceOp(op, new_op.getResults()); return success(); } }; // erase type of okl.launcher_ctx and get opaque ptr // llvm.ptr -> okl.launcher_ctx } struct RewriteFunctionArgsPattern final : public mlir::OpRewritePattern { static LogicalResult ConvertLauncherToLLVMPtr(func::FuncOp op, mlir::PatternRewriter& rewriter) { auto func_type = rewriter.getFunctionType({GetPtrType(rewriter)}, {}); auto func = rewriter.create(op.getLoc(), op.getSymName(), func_type); func->setAttr("llvm.emit_c_interface", mlir::UnitAttr::get(rewriter.getContext())); IRMapping bvm; op.getRegion().cloneInto(&func.getRegion(), bvm); auto& block = func.getBody().getBlocks().front(); auto launcher_ctx = block.getArgument(0); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); auto cast_op = rewriter.create(op->getLoc(), launcher_ctx.getType(), launcher_ctx); launcher_ctx.setType(GetPtrType(rewriter)); launcher_ctx.replaceAllUsesExcept(cast_op->getResult(0), {cast_op}); rewriter.setInsertionPointToEnd(&block); rewriter.replaceOpWithNewOp(&block.back(), ValueRange()); rewriter.eraseOp(op); return success(); } public: explicit RewriteFunctionArgsPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { if (op.getNumArguments() == 1 && op.getArgumentTypes().begin()->isa()) { return ConvertLauncherToLLVMPtr(op, rewriter); } return success(); } }; namespace { struct LowerLauncherToLLVMPtrPass : public LowerLauncherToLLVMPtrPassBase { void runOnOperation() override; void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); registry.insert(); } }; struct LowerOKLToLLVMCallPass : public LowerOKLToLLVMCallPassBase { void runOnOperation() override; void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); registry.insert(); } 
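  // runOnOperation() below marks the okl dialect illegal, maps okl.launcher_ctx to the
  // opaque LLVM pointer type through the TypeConverter, and rewrites wrapper kernel ops
  // into llvm.call via WrapperKernelOpLowering.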
}; } // namespace std::unique_ptr createLowerOKLToLLVMCallPass() { return std::make_unique(); } std::unique_ptr createLowerLauncherToLLVMPtrPass() { return std::make_unique(); } void LowerLauncherToLLVMPtrPass::runOnOperation() { MLIRContext* context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } void LowerOKLToLLVMCallPass::runOnOperation() { MLIRContext* context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); target.addIllegalDialect(); auto llvm_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); TypeConverter typeConverter; typeConverter.addConversion([&](mlir::okl::LauncherContextType type) { return llvm_ptr_type; }); RewritePatternSet patterns(context); patterns.add(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); getOperation()->emitError("Failed to lower OKL to LLVM Call"); } } } // namespace okl } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/ComputeContext.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OKL/Kernel/ComputeContext.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "OneFlow/OKL/OKLOps.h" #include "oneflow/core/common/shape_view.h" namespace oneflow { namespace okl { user_op::Tensor* ComputeContext::CreateTensorWithArgNameAndIndex(const std::string& arg_name, int32_t index) { auto op = reg_ctx_->GetOp(); auto source = mlir::oneflow::user_op::GetOpSourceByName(op, arg_name); if (source.type == mlir::oneflow::user_op::Source::OUTPUT) { if (op->getNumResults() <= index + source.offset) { return nullptr; } mlir::Value val = op->getResult(index + source.offset); auto use = *val.getUsers().begin(); if (auto ret_op = llvm::dyn_cast_or_null(use)) { return comp_ctx_->Tensor4ArgNameAndIndex("out", ret_op.getIndex()); } if (auto pool_op = llvm::dyn_cast_or_null(use)) { return tmp_buffer_.GetPoolTensor(TensorDesc4ArgNameAndIndex(arg_name, index), pool_op.getOffset()); } op->emitError("Failed to find " + std::to_string(index) + "in outputs"); exit(1); } if (source.type == mlir::oneflow::user_op::Source::INPUT) { if (op->getNumOperands() <= index + source.offset) { return nullptr; } mlir::Value val = op->getOperand(index + source.offset); auto define_op = val.getDefiningOp(); return llvm::TypeSwitch<::mlir::Operation*, user_op::Tensor*>(define_op) .Case([&](mlir::okl::GetTensorFromArgOp elem) { return comp_ctx_->Tensor4ArgNameAndIndex("in", elem.getIndex()); }) .Case([&](mlir::okl::GetTensorFromRetOp elem) { return comp_ctx_->Tensor4ArgNameAndIndex("out", elem.getIndex()); }) .Case([&](mlir::okl::PoolToTensorOp elem) { return tmp_buffer_.GetPoolTensor(TensorDesc4ArgNameAndIndex(arg_name, index), elem.getOffset()); }) .Default([&](::mlir::Operation* op) { op->dump(); LOG(FATAL) << "Signature: " << arg_name << " Not supported"; return nullptr; }); } if (source.type == mlir::oneflow::user_op::Source::BUFFER) { auto wrap = op->getParentOfType(); for (auto& op : wrap.getBody().front()) { if (auto pool_to_buffer = llvm::dyn_cast_or_null(op)) { return tmp_buffer_.GetPoolBuffer(pool_to_buffer.getType().getShape()[0], pool_to_buffer.getOffset()); } } } op->emitError("Failed to check source type"); exit(1); } user_op::Tensor* ComputeContext::Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = tensor_.find({arg_name, index}); if (it != tensor_.end()) return it->second; user_op::Tensor* res = CreateTensorWithArgNameAndIndex(arg_name, index); tensor_[{arg_name, index}] = res; return res; } } // namespace okl } // namespace oneflow ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/InferContext.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OKL/Kernel/InferContext.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" #include "llvm/Support/Casting.h" namespace oneflow { namespace okl { using namespace user_op; InferContext::InferContext(const RegContext* reg_ctx) : reg_ctx_(reg_ctx) {} const TensorDesc* InferContext::LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { return reg_ctx_->TensorDesc4ArgNameAndIndex(arg_name, index); } const Shape& InferContext::InputShape(const std::string& arg_name, int32_t index) const { return Shape4ArgNameAndIndex(arg_name, index); } const Shape& InferContext::Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { return LogicalTensorDesc4ArgNameAndIndex(arg_name, index)->shape(); } const std::shared_ptr& InferContext::Attr4Name(const std::string& attr_name) const { return reg_ctx_->user_op_conf().Attr4Name(attr_name); } } // namespace okl } // namespace oneflow ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/JITEngine.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/Extension.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Operation.h" #include "OneFlow/OKL/Kernel/JITEngine.h" #include "OneFlow/OKL/Kernel/ComputeContext.h" #include "oneflow/core/device/cuda_util.h" extern "C" { void okl_llvm_func(void* launcher, int64_t index) { static_cast>(launcher)->Launch( index); } } // extern "C" namespace oneflow { SharedLibs* MutSharedLibPaths() { static SharedLibs libs = {}; return &libs; } const SharedLibs* SharedLibPaths() { return MutSharedLibPaths(); } } // namespace oneflow oneflow::okl::JITEngine::JITEngine(mlir::ModuleOp module) { llvm::SmallVector ext_libs( {oneflow::SharedLibPaths()->begin(), oneflow::SharedLibPaths()->end()}); mlir::ExecutionEngineOptions jitOptions; jitOptions.transformer = {}; jitOptions.jitCodeGenOptLevel = llvm::None; jitOptions.sharedLibPaths = ext_libs; auto jit_or_error = mlir::ExecutionEngine::create(module, jitOptions); CHECK(!!jit_or_error) << "failed to create JIT exe engine, " << llvm::toString((jit_or_error).takeError()); jit_or_error->swap(engine_); } ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/JITOpInfer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/Passes.h" #include "OneFlow/OneFlowSupport.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/nn_util.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Types.h" #include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" namespace oneflow { namespace ir { namespace jit { static Maybe GetFunctionType(user_op::InferContext* ctx, mlir::OwningOpRef& module) { mlir::func::FuncOp funcOp = mlir::SymbolTable::lookupNearestSymbolFrom( module.get(), mlir::SymbolRefAttr::get(module->getContext(), ctx->op_name())); CHECK_OR_RETURN(funcOp) << "Fail to find funcOp of symbol " << ctx->op_name(); const auto funcType = funcOp.getFunctionType(); CHECK_EQ_OR_RETURN(funcType.getNumInputs(), ctx->input_size("in")) << "input size mismatch with mlir assembly"; CHECK_EQ_OR_RETURN(funcType.getNumResults(), ctx->output_size("out")) << "output size mismatch with mlir assembly"; int32_t arg_i = 0; for (mlir::Type arg_type : funcType.getInputs()) { if (auto rankedTensorType = arg_type.dyn_cast()) { CHECK_EQ_OR_RETURN( (Shape{rankedTensorType.getShape().begin(), rankedTensorType.getShape().end()}), ctx->InputShape("in", arg_i)) << "arg #" << arg_i; const auto data_type = mlir::oneflow::support::FromMLIRTypeToOFDataType(rankedTensorType.getElementType()); if (mlir::failed(data_type)) { exit(1); } CHECK_EQ_OR_RETURN(data_type.value(), ctx->InputDType("in", arg_i)) << "arg #" << arg_i; arg_i += 1; } else { std::string arg_type_str = ""; llvm::raw_string_ostream os(arg_type_str); arg_type.print(os); THROW(RuntimeError) << "Unsupported arg type " << arg_type_str; } } return funcType; } Maybe SetTensorDataType(user_op::InferContext* ctx) { auto mlir_assembly = ctx->Attr>("mlir_assembly"); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); mlir::MLIRContext context(registry); context.loadDialect(); context.loadDialect(); mlir::OwningOpRef module = mlir::parseSourceString( llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), &context); if (!module) { LOG(ERROR) << "Fail to load mlir assembly"; exit(1); } if ((*module)->hasAttr(mlir::oneflow::jit::RAW_GRAPH)) { auto raw_graph = (*module)->getAttr(mlir::oneflow::jit::RAW_GRAPH).cast(); if (raw_graph) module = mlir::parseSourceString(raw_graph.strref(), module->getContext()); } auto funcType = *JUST(GetFunctionType(ctx, module)); int32_t res_i = 0; for (mlir::Type res_type : funcType.getResults()) { if (auto rankedTensorType = res_type.dyn_cast()) { const auto data_type = mlir::oneflow::support::FromMLIRTypeToOFDataType(rankedTensorType.getElementType()); if (mlir::failed(data_type)) { exit(1); } ctx->SetDtype4ArgNameAndIndex("out", res_i, data_type.value()); res_i += 1; } else { std::string res_type_str = ""; llvm::raw_string_ostream os(res_type_str); res_type.print(os); THROW(RuntimeError) << "Unsupported arg type " << res_type_str; } } return 
Maybe::Ok(); } Maybe InferTensorDesc(user_op::InferContext* ctx) { auto mlir_assembly = ctx->Attr>("mlir_assembly"); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); mlir::MLIRContext context(registry); context.loadDialect(); context.loadDialect(); mlir::OwningOpRef module = mlir::parseSourceString( llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), &context); if (!module) { LOG(ERROR) << "Fail to load mlir assembly"; exit(1); } if ((*module)->hasAttr(mlir::oneflow::jit::RAW_GRAPH)) { auto raw_graph = (*module)->getAttr(mlir::oneflow::jit::RAW_GRAPH).cast(); if (raw_graph) module = mlir::parseSourceString(raw_graph.strref(), module->getContext()); } auto funcType = *JUST(GetFunctionType(ctx, module)); int32_t res_i = 0; for (mlir::Type res_type : funcType.getResults()) { if (auto rankedTensorType = res_type.dyn_cast()) { ctx->SetOutputShape( "out", res_i, Shape{rankedTensorType.getShape().begin(), rankedTensorType.getShape().end()}); const auto data_type = mlir::oneflow::support::FromMLIRTypeToOFDataType(rankedTensorType.getElementType()); if (mlir::failed(data_type)) { exit(1); } ctx->SetOutputDType("out", res_i, data_type.value()); llvm::SmallVector strides; int64_t _; auto mem_type = mlir::MemRefType::get(rankedTensorType.getShape(), rankedTensorType.getElementType()); if (failed(mlir::getStridesAndOffset(mem_type, strides, _))) { LOG(FATAL) << "Fail to get stride from memory type"; } ctx->SetOutputStride("out", res_i, Stride(strides.begin(), strides.end())); res_i += 1; } else { std::string res_type_str = ""; llvm::raw_string_ostream os(res_type_str); res_type.print(os); THROW(RuntimeError) << "Unsupported arg type " << res_type_str; } } return Maybe::Ok(); } } // namespace jit } // namespace ir } // namespace oneflow ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/KernelLaunchOp.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OKL/Conversion/Conversion.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/UserOpReflection.h" #include "OneFlow/Passes.h" #include "OneFlow/Extension.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/kernel/blob_tensor_view.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/framework/op_generated.h" #include "OneFlow/OKL/Kernel/JITOpInfer.h" #include "OneFlow/OKL/Kernel/JITEngine.h" #include "OneFlow/OKL/Kernel/LauncherState.h" #include "OneFlow/OKL/Kernel/TmpBufferManager.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Parser/Parser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "llvm/Support/Error.h" #include "llvm/Support/TargetSelect.h" #include #include #include #include namespace oneflow { Maybe KernelLaunchOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return ir::jit::InferTensorDesc(ctx); } Maybe KernelLaunchOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ir::jit::InferTensorDesc(ctx); } Maybe KernelLaunchOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } Maybe KernelLaunchOp::InferDataType(user_op::InferContext* ctx) { return ir::jit::SetTensorDataType(ctx); } namespace { using namespace oneflow::okl; template class KernelLaunchKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: KernelLaunchKernel() = default; ~KernelLaunchKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { // use ctx to create module, reg_ctx and fn; std::shared_ptr res(new LauncherState(ctx)); return res; } bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, user_op::OpKernelState* state) const override { return dynamic_cast(state)->IsCudaGraphSupported(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* okl_state = dynamic_cast(state); okl_state->DoCompute(ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_KERNEL_LAUNCH_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("kernel_launch") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ return oneflow::okl::TmpBufferManager::InferTmpSize(ctx); \ }) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ return Maybe::Ok(); \ }); REGISTER_KERNEL_LAUNCH_CPU_KERNEL(float) REGISTER_KERNEL_LAUNCH_CPU_KERNEL(double) REGISTER_KERNEL_LAUNCH_CPU_KERNEL(int32_t) REGISTER_KERNEL_LAUNCH_CPU_KERNEL(int64_t) #undef REGISTER_KERNEL_LAUNCH_CPU_KERNEL #define REGISTER_KERNEL_LAUNCH_GPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("kernel_launch") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == 
DeviceType::kCUDA)                                                        \
                       && (user_op::HobDataType("out", 0) == GetDataType<dtype>::value))          \
      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                          \
        return oneflow::okl::TmpBufferManager::InferTmpSize(ctx);                                  \
      })                                                                                           \
      .SetInplaceProposalFn([](const user_op::InferContext&,                                       \
                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {    \
        return Maybe<void>::Ok();                                                                  \
      });

REGISTER_KERNEL_LAUNCH_GPU_KERNEL(float)
REGISTER_KERNEL_LAUNCH_GPU_KERNEL(double)
REGISTER_KERNEL_LAUNCH_GPU_KERNEL(int8_t)
REGISTER_KERNEL_LAUNCH_GPU_KERNEL(int32_t)
REGISTER_KERNEL_LAUNCH_GPU_KERNEL(int64_t)
#if CUDA_VERSION >= 11000
REGISTER_KERNEL_LAUNCH_GPU_KERNEL(half)
REGISTER_KERNEL_LAUNCH_GPU_KERNEL(nv_bfloat16)
#endif
#undef REGISTER_KERNEL_LAUNCH_GPU_KERNEL

}  // namespace
}  // namespace oneflow

================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/LauncherContext.cpp ================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "OneFlow/OKL/Kernel/WrapperContext.h"
#include "OneFlow/OKM/passes.h"
#include "OneFlow/Passes.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "oneflow/core/framework/op_kernel.h"
#include "OneFlow/OKL/OKLOps.h"
#include "OneFlow/OKL/Kernel/RegContext.h"
#include "OneFlow/OKL/Kernel/ComputeContext.h"
#include "OneFlow/OKL/Kernel/LauncherContext.h"
#include "llvm/ADT/TypeSwitch.h"

namespace oneflow {
namespace okl {

LauncherContext::LauncherContext(mlir::ModuleOp module) {
  // Find the okl_subgraph function produced by the OKM-to-OKL conversion.
  mlir::Operation* func = nullptr;
  module->walk([&](mlir::func::FuncOp op) {
    if (op.getSymName().startswith(mlir::okm::func_name::OKL_GRAPH_NAME)) { func = op; }
  });
  if (!func) { LOG(FATAL) << "Failed to find okl_func in mlir ir"; }

  auto& ops = func->getRegion(0).front();
  for (auto& op : ops) {
    llvm::TypeSwitch<mlir::Operation*>(&op)
        .Case([&](mlir::okl::WrapperKernelOp elem) {
          // Each wrapper kernel region holds exactly one oneflow op to build a compile ctx from.
          mlir::Operation* reg_op = nullptr;
          for (auto& op_it : op.getRegion(0).front().getOperations()) {
            if (op_it.getDialect()->getNamespace() == "oneflow") {
              reg_op = &op_it;
              break;
            }
          }
          if (!reg_op) { LOG(FATAL) << "Failed to find reg_op in okl.build_reg_context_op"; }
          compile_ctx_vec_.emplace_back(reg_op);
        })
        .Case([&](mlir::func::ReturnOp elem) {})
        .Default([&](mlir::Operation* elem) {
          elem->dump();
          LOG(FATAL) << "Failed to parse this op in okl init context";
        });
  }
}

bool LauncherContext::Infer(user_op::KernelComputeContext* compute_context) {
  // If this context has been inferred before, it won't be rebuilt later.
  if (inferred_) { return inferred_; }
  for (auto& elem : compile_ctx_vec_) {
    run_ctx_vec_.emplace_back(elem.GetRegContext()->GetOp(), compute_context);
  }
  inferred_ = compile_ctx_vec_.size() == run_ctx_vec_.size();
  return inferred_;
}

void LauncherContext::Launch(int index) {
  if (!inferred_) { LOG(FATAL) << "Launch() called before the contexts were inferred"; }
  run_ctx_vec_[index].Run();
}

}  // namespace okl
}  // namespace oneflow
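The dispatch path that ties LauncherContext to the JIT-compiled code is easy to lose in the file listing: LauncherState::DoCompute first calls LauncherContext::Infer and then runs the generated @okl_func, whose body is nothing but llvm.call @okl_llvm_func(%ctx, index) ops; the extern "C" okl_llvm_func symbol (defined in JITEngine.cpp above) casts the opaque pointer back to the launcher context and launches the indexed wrapper context. Below is a minimal, self-contained sketch of that flow; MiniLauncherContext, mini_okl_llvm_func, and mini_okl_func are hypothetical stand-ins for illustration, not the real OneFlow classes.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Stand-in for oneflow::okl::LauncherContext: an ordered list of per-op run contexts.
class MiniLauncherContext {
 public:
  void AddKernel(std::function<void()> fn) { run_ctx_vec_.push_back(std::move(fn)); }
  void Launch(std::int64_t index) { run_ctx_vec_.at(index)(); }

 private:
  std::vector<std::function<void()>> run_ctx_vec_;
};

// Stand-in for the extern "C" okl_llvm_func symbol that the generated llvm.call ops target:
// the opaque pointer introduced by lower-launcher-to-llvm-ptr is cast back to the launcher
// context and the kernel at the given index is launched.
extern "C" void mini_okl_llvm_func(void* launcher, std::int64_t index) {
  static_cast<MiniLauncherContext*>(launcher)->Launch(index);
}

// Stand-in for the JIT-compiled @okl_func: one call per okl.wrapper_kernel region.
void mini_okl_func(void* launcher_ctx) {
  mini_okl_llvm_func(launcher_ctx, 0);  // {index = 0}, e.g. the relu wrapper
  mini_okl_llvm_func(launcher_ctx, 1);  // {index = 1}, e.g. the tanh wrapper
}

int main() {
  MiniLauncherContext ctx;
  ctx.AddKernel([] { std::cout << "run relu kernel\n"; });
  ctx.AddKernel([] { std::cout << "run tanh kernel\n"; });
  mini_okl_func(&ctx);  // roughly what LauncherState::DoCompute triggers via the JIT engine
}
```

In the real pipeline the okl_llvm_func symbol is presumably resolved by the MLIR ExecutionEngine as an external symbol of the host process, so the generated LLVM IR needs nothing beyond the opaque context pointer and the per-kernel index.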
================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/LauncherState.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OKL/Conversion/Conversion.h" #include "OneFlow/OKM/Conversion/Conversion.h" #include "OneFlow/Passes.h" #include "OneFlow/OKM/passes.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/DialectRegistry.h" #include "oneflow/core/framework/op_kernel.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OKL/OKLDialect.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Parser/Parser.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "OneFlow/OKL/Kernel/JITEngine.h" #include "OneFlow/OKL/Kernel/LauncherContext.h" #include "OneFlow/OKL/Kernel/LauncherState.h" namespace oneflow { namespace okl { namespace { mlir::OwningOpRef GetModule(user_op::KernelInitContext* ctx, mlir::MLIRContext* mlir_ctx) { auto mlir_assembly = ctx->Attr>("mlir_assembly"); mlir::OwningOpRef module = mlir::parseSourceString( llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), mlir_ctx); if (!module) { LOG(FATAL) << "Fail to load mlir assembly"; } // lower oneflow wrap ops into okl dialect if (failed(mlir::okm::LowerWrapOpsToOKL(*module))) { LOG(FATAL) << "Fail lowering kernel launch Module to okm and okl ir"; } return module; } JITEngine GetEngine(mlir::ModuleOp module) { if (failed(mlir::okl::LowerOKLComputeToLLVM(module))) { LOG(FATAL) << "Fail lowering okl compute Module to llvm ir"; } return JITEngine(module); } } // namespace LauncherState::LauncherState(user_op::KernelInitContext* ctx) : mlir_ctx_(GetRegistry()), module_(GetModule(ctx, &mlir_ctx_)), launcher_context_(module_->clone()), engine_(GetEngine(module_->clone())) {} bool LauncherState::IsCudaGraphSupported(user_op::KernelInitContext* ctx) { const auto tag_name = mlir::okl::cuda_graph_support::TAG_NAME; if (const auto func = module_->lookupSymbol(mlir::okm::func_name::OKL_GRAPH_NAME)) { if (const auto is_supported = func->getAttr(tag_name).dyn_cast_or_null()) { return is_supported.getValue(); } } return false; } void LauncherState::DoCompute(user_op::KernelComputeContext* ctx) { launcher_context_.Infer(ctx); engine_.Run(mlir::okm::func_name::OKL_GRAPH_NAME, &launcher_context_); } } // namespace okl } // namespace oneflow ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/RegContext.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/UserOpConversion.h" #include "OneFlow/UserOpReflection.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/kernel/blob_tensor_view.h" #include "oneflow/core/memory/memory_case.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "OneFlow/OKL/Kernel/InferContext.h" #include "OneFlow/OKL/Kernel/RegContext.h" #include "oneflow/core/framework/user_op_kernel_registry.h" #include "oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" namespace oneflow { namespace okl { static user_op::UserOpConfWrapper GetConfWrapper(mlir::Operation* op, bool is_mapping_size = false) { OperatorConf op_conf; if (mlir::failed(mlir::oneflow::user_op::ConvertUserOpAttributes(op, op_conf, is_mapping_size))) { op->emitError("fail to convert user op attributes"); exit(1); } auto conf_wrapper_ = user_op::UserOpConfWrapper(std::make_shared(op_conf)); return conf_wrapper_; } RegContext::RegContext(mlir::Operation* op) : op_(op), conf_wrapper_(GetConfWrapper(op, true)) { const auto handle_operands_or_results = [&op, this](const auto& arg_ids, const auto& get_operand_or_result, ArgVec& arg_vec) { for (const auto& obj_id : ::llvm::enumerate(arg_ids)) { user_op::NaiveTensorDesc tensor_desc{}; auto obj = get_operand_or_result(op, obj_id.index()); if (auto rankedTensorType = obj.getType().template dyn_cast()) { tensor_desc.set_shape( Shape{rankedTensorType.getShape().begin(), rankedTensorType.getShape().end()}); const auto data_type = mlir::oneflow::support::FromMLIRTypeToOFDataType(rankedTensorType.getElementType()); if (mlir::failed(data_type)) { exit(1); } tensor_desc.set_data_type(data_type.value()); llvm::SmallVector strides; int64_t _; auto mem_type = mlir::MemRefType::get(rankedTensorType.getShape(), rankedTensorType.getElementType()); if (failed(mlir::getStridesAndOffset(mem_type, strides, _))) { LOG(FATAL) << "Fail to get stride from memory type"; } tensor_desc.set_stride(Stride(strides.begin(), strides.end())); // TODO: set is_dynamic } else { LOG(FATAL) << "Unranked tensor type not supported"; } CHECK(arg2tensor_desc_.emplace(obj_id.value(), tensor_desc).second) << "duplicate key"; arg_vec.push_back(obj_id.value()); } }; handle_operands_or_results( ::mlir::oneflow::user_op::ArgIds(op), [](auto& x, size_t index) { return x->getOperand(index); }, inputs_); handle_operands_or_results( ::mlir::oneflow::user_op::ArgIds(op), [](auto& x, size_t index) { return x->getResult(index); }, outputs_); auto dev_tag = mlir::OpTrait::IsOpConfCompatible::getDeviceTag(op); if (dev_tag == "cpu") { device_type_ = DeviceType::kCPU; } else if (dev_tag == "cuda") { device_type_ = DeviceType::kCUDA; } else { LOG(FATAL) << "Unsupported device tag: " << dev_tag.str(); } auto op_name = GetOp()->getName().stripDialect().str(); if (const auto op_type_name = GetOp()->getAttr("op_type_name").dyn_cast_or_null()) { op_name = op_type_name.str(); } reg_res_ = 
CHECK_JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_name, *this)); kernel_ = reg_res_->create_fn(); conf_wrapper_ = GetConfWrapper(op_, true); } DeviceType RegContext::device_type() const { return device_type_; } const ParallelContext& RegContext::parallel_ctx() const { TODO() << "create parallel_ctx from op in mlir"; ParallelContext* parallel_ctx = nullptr; return *parallel_ctx; } const user_op::TensorDesc* RegContext::TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; } return &(it->second); } const ArgVec& RegContext::inputs() const { return inputs_; } const ArgVec& RegContext::outputs() const { return outputs_; } // TODO: more information is needed const user_op::UserOpConfWrapper& RegContext::user_op_conf() const { return conf_wrapper_; } const std::shared_ptr& RegContext::Attr4Name( const std::string& attr_name) const { return user_op_conf().Attr4Name(attr_name); } const size_t RegContext::GetTmpBufferSize() const { if (reg_res_->need_temp_storage) { InferContext infer_ctx(this); return reg_res_->infer_tmp_size_fn(&infer_ctx); } return 0; } } // namespace okl } // namespace oneflow ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/Kernel/TmpBufferManager.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OKL/Kernel/TmpBufferManager.h" #include "OneFlow/OKL/Kernel/LauncherState.h" #include "OneFlow/OKL/OKLOps.h" #include "OneFlow/OKM/Conversion/Conversion.h" #include "OneFlow/OKM/passes.h" #include "OneFlow/Passes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" #include "llvm/Support/Casting.h" namespace oneflow { namespace okl { size_t TmpBufferManager::InferTmpSize(user_op::InferContext* ctx) { using namespace user_op; mlir::MLIRContext mlir_ctx(GetRegistry()); auto mlir_assembly = ctx->Attr>("mlir_assembly"); mlir::OwningOpRef module = mlir::parseSourceString( llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), &mlir_ctx); if (!module) { LOG(FATAL) << "Fail to load mlir assembly"; } if (failed(mlir::okm::LowerWrapOpsToOKL(*module))) { LOG(ERROR) << "Fail lowering kernel launch Module to okl ir"; exit(1); } size_t pool_size = 0; module->walk([&](mlir::func::FuncOp op) { if (op.getSymName().startswith(mlir::okm::func_name::OKL_GRAPH_NAME)) { if (auto pool_size_attr = op->getAttrOfType(mlir::okm::func_name::OKL_POOL_SIZE_TAG)) { pool_size = pool_size_attr.getInt(); } } }); return pool_size; } } // namespace okl } // namespace oneflow ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/OKLDialect.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OKL/OKLOps.h" #include "OneFlow/OKL/OKLTypes.h" #include "OneFlow/OKL/OKLAttributes.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/Passes.h" #include "mlir/IR/BuiltinAttributes.h" #include "OneFlow/OKLDialect.cpp.inc" #include "mlir/IR/Dialect.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Dialect.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Support/LogicalResult.h" #define GET_ATTRDEF_CLASSES #include "OneFlow/OKLAttributes.cpp.inc" namespace mlir { namespace okl { void OKLDialect::initialize() { addOperations< #define GET_OP_LIST #include "OneFlow/OKLOps.cpp.inc" >(); addTypes< #define GET_TYPEDEF_LIST #include "OneFlow/OKLTypes.cpp.inc" >(); addAttributes< #define GET_ATTRDEF_LIST #include "OneFlow/OKLAttributes.cpp.inc" >(); } } // namespace okl } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/OKLOps.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OKL/OKLTypes.h" #include "OneFlow/OKL/OKLOps.h" #include "OneFlow/OKL/OKLAttributes.h" #include "OneFlow/OneFlowOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Dialect.h" #include "OneFlow/OKLEnums.cpp.inc" #define GET_OP_CLASSES #include "OneFlow/OKLOps.cpp.inc" ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/OKLTypes.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OKL/OKLTypes.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" #define GET_TYPEDEF_CLASSES #include "OneFlow/OKLTypes.cpp.inc" ================================================ FILE: oneflow/ir/lib/OneFlow/OKL/README-OriginVersion.md ================================================ # 初版OKL设计文档 oneflow kernel launch dialect 将oneflow kernel引入mlir执行。 ## 编译期 ### 1. FromGraphToMLIR - GraphToJob - JobToOneFlowDialect ### 2. OneFlowDialectToOKLDialect 通过三个Pass将OneFlow转换成okl的ir形式。 - extract-kernel-launch-tensor - trim-return-to-void - lower-to-okl ``` mlir module { func.func @wrap0(%arg0: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) attributes {llvm.emit_c_interface} { %0 = "oneflow.relu"(%arg0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> return %0, %1 : tensor<2xf32>, tensor<2xf32> } } ``` -extract-kernel-launch-tensor 将tensor的输入流转换为ctx中获取 ``` mlir module { func.func @wrap0(%arg0: !okl.launcher_ctx) -> (tensor<2xf32>, tensor<2xf32>) { %0 = "okl.get_tensor_from_arg"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "oneflow.tanh"(%1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %3 = "okl.get_tensor_as_ret"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> %4 = "okl.get_tensor_as_ret"(%arg0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> return %3, %4 : tensor<2xf32>, tensor<2xf32> } } ``` -trim-return-to-void 将tensor的输出流删除掉 ```mlir module { func.func @wrap0(%arg0: !okl.launcher_ctx) { %0 = "okl.get_tensor_from_arg"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "oneflow.tanh"(%1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %3 = "okl.get_tensor_as_ret"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> %4 = "okl.get_tensor_as_ret"(%arg0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> return } } ``` -lower-to-okl 将oneflow kernel op用okl wrapper_kernel封装起来,并通过okl op编译推导对应tensor流的信息。 ```mlir module { func.func @okl_func(%arg0: !okl.launcher_ctx) { "okl.wrapper_kernel"() ({ %0 = "okl.get_tensor_from_arg"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "okl.get_tensor_as_ret"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 0 : i32} : () -> () "okl.wrapper_kernel"() ({ %0 = "okl.get_tensor_from_ret"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag 
= "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "okl.get_tensor_as_ret"(%arg0, %1) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 1 : i32} : () -> () return } } ``` ### 3. OKLDialectToLLVMDialect 通过四个Pass将OKL的IR转换为LLVM的IR形式作为运行时的输入 - lower-launcher-to-llvm-ptr - lower-okl-to-llvm-call - reconcile-unrealized-casts - convert-func-to-llvm -lower-launcher-to-llvm-ptr 将ctx转换成一个llvm.ptr,通过llvm.ptr表示ctx的传递。 ```mlir module { func.func @okl_func(%arg0: !llvm.ptr) attributes {llvm.emit_c_interface} { %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr to !okl.launcher_ctx "okl.wrapper_kernel"() ({ %1 = "okl.get_tensor_from_arg"(%0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %2 = "oneflow.relu"(%1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %3 = "okl.get_tensor_as_ret"(%0, %2) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 0 : i32} : () -> () "okl.wrapper_kernel"() ({ %1 = "okl.get_tensor_from_ret"(%0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %2 = "oneflow.tanh"(%1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %3 = "okl.get_tensor_as_ret"(%0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 1 : i32} : () -> () return } } ``` -lower-okl-to-llvm-call 将okl的wrapper_kernel转换成llvm的call调用。 ```mlir module { llvm.func @okl_llvm_func(!llvm.ptr, i64) attributes {llvm.emit_c_interface} func.func @okl_func(%arg0: !llvm.ptr) attributes {llvm.emit_c_interface} { %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr to !okl.launcher_ctx %1 = llvm.mlir.constant(0 : index) : i64 llvm.call @okl_llvm_func(%arg0, %1) : (!llvm.ptr, i64) -> () %2 = llvm.mlir.constant(1 : index) : i64 llvm.call @okl_llvm_func(%arg0, %2) : (!llvm.ptr, i64) -> () return } } ``` -reconcile-unrealized-casts -convert-func-to-llvm 转换成可以直接运行的llvm IR ```mlir module attributes {llvm.data_layout = ""} { llvm.func @okl_llvm_func(!llvm.ptr, i64) attributes {llvm.emit_c_interface} llvm.func @okl_func(%arg0: !llvm.ptr) attributes {llvm.emit_c_interface} { %0 = llvm.mlir.constant(0 : index) : i64 llvm.call @okl_llvm_func(%arg0, %0) : (!llvm.ptr, i64) -> () %1 = llvm.mlir.constant(1 : index) : i64 llvm.call @okl_llvm_func(%arg0, %1) : (!llvm.ptr, i64) -> () llvm.return } llvm.func @_mlir_ciface_okl_func(%arg0: !llvm.ptr) attributes {llvm.emit_c_interface} { llvm.call @okl_func(%arg0) : (!llvm.ptr) -> () llvm.return } } ``` ## 运行时 OKLDialect IR不仅作为编译期最后一阶段的输出,同时作为运行时初始化时期资源的输入来初始化运行时的各种ctx,从而为计算期的计算做准备。 一个 OKL 的 kernel 包含了一整个子图。因此 OKL 的 kernel 需要管理子图的若干有序子 op 的 ctx 资源。这些通过 LauncherState 来初始化创建,LauncherState 中含有 LauncherContext 用来统一管理子图的所有子 Op 的资源。 LauncherContext含有若干有序的CompileTimeWrapperContext一一对应其子Op未Infer前的ctx,以及若干RunTimeWrapperContext一一对应其子Op在Infer后的ctx。 下面为这两种Ctx所持有的资源。 ``` class CompileTimeWrapperContext { std::shared_ptr reg_ctx_; }; class RunTimeWrapperContext : public CompileTimeWrapperContext { std::shared_ptr compute_ctx_; std::shared_ptr init_ctx_; std::shared_ptr kernel_state_; std::shared_ptr kernel_cache_; }; ``` CompileTimeWrapperContext 主要是reg_ctx,以作为infer推导的必须输入。 RunTimeWrapperContext 包含所有子op运行时计算需要用的的资源,主要有compute_ctx以及state和cache。 ================================================ FILE: 
oneflow/ir/lib/OneFlow/OKM/Conversion/Conversion.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OKL/Conversion/Conversion.h" #include "OneFlow/OKM/passes.h" #include "OneFlow/OneFlowUtils.h" namespace mlir { namespace okm { LogicalResult LowerWrapOpsToOKL(ModuleOp module) { PassManager pm(module->getContext()); pm.addPass(createExtractOKMTensorPass()); pm.addPass(createWrapOKMKernelPass()); pm.addPass(createOptOKMMemrefPass()); pm.addPass(createConvertOKMToOKLPass()); pm.addPass(okl::createTagCudaGraphSupportPass()); oneflow::CheckEnableIRPrinting(pm); return pm.run(module); } } // namespace okm } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OKM/OKMDialect.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OKM/OKMDialect.h" #include "OneFlow/OKM/OKMOps.h" #include "OneFlow/OKM/OKMAttributes.h" #include "OneFlow/OKM/passes.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/Passes.h" #include "mlir/IR/BuiltinAttributes.h" #include "OneFlow/OKMDialect.cpp.inc" #include "mlir/IR/Dialect.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Dialect.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Support/LogicalResult.h" #define GET_ATTRDEF_CLASSES #include "OneFlow/OKMAttributes.cpp.inc" #define GET_OP_CLASSES #include "OneFlow/OKMOps.cpp.inc" namespace mlir { namespace okm { void OKMDialect::initialize() { addOperations< #define GET_OP_LIST #include "OneFlow/OKMOps.cpp.inc" >(); addAttributes< #define GET_ATTRDEF_LIST #include "OneFlow/OKMAttributes.cpp.inc" >(); } } // namespace okm } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OKM/passes.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/intra_job_mem_sharing_util.h" #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OKL/OKLOps.h" #include "OneFlow/OKL/Kernel/RegContext.h" #include "OneFlow/OKM/OKMDialect.h" #include "OneFlow/OKM/OKMOps.h" #include "OneFlow/OKM/passes.h" #include "OneFlow/OneFlowDialect.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir-c/BuiltinTypes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Region.h" #include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { namespace okm { namespace func_name { const std::string GRAPH_NAME = "_mlir_oneflow_subgraph"; const std::string MEM_GRAPH_NAME = "okm_subgraph"; const std::string WRAP_GRAPH_NAME = "okm_wrap_subgraph"; const std::string OPT_GRAPH_NAME = "okm_alloc_subgraph"; const std::string OKL_GRAPH_NAME = "okl_subgraph"; const std::string OKL_POOL_SIZE_TAG = "pool_size"; } // namespace func_name struct ExtractOKMTensorPattern : public mlir::OpRewritePattern { static void ExtractArgTensors(func::FuncOp op, mlir::PatternRewriter& rewriter) { auto& body = op.getBody(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&body.front()); for (const auto& arg : llvm::enumerate(op.getBody().getArguments())) { auto tensor = rewriter.create(op->getLoc(), arg.value().getType(), arg.index()); arg.value().replaceAllUsesWith(tensor); } } static void ExtractRetTensors(func::FuncOp op, mlir::PatternRewriter& rewriter) { auto& return_op = op.getBody().front().back(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(&return_op); llvm::SmallVector returns; for (const auto& ret_val : llvm::enumerate(return_op.getOperands())) { auto new_ret = rewriter.create(op->getLoc(), ret_val.value().getType(), ret_val.value(), ret_val.index()); returns.push_back(new_ret); } rewriter.replaceOpWithNewOp(&return_op, ValueRange{returns}); } explicit ExtractOKMTensorPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { const auto sym_name = op.getSymName(); if (sym_name.startswith(func_name::GRAPH_NAME)) { // rename function const auto index = sym_name.substr(func_name::GRAPH_NAME.size()); const auto rename = func_name::MEM_GRAPH_NAME + index; op.setSymNameAttr(rewriter.getStringAttr(rename)); // extract tensors ExtractArgTensors(op, rewriter); ExtractRetTensors(op, rewriter); return success(); } return failure(); } }; class ExtractOKMTensorPass : public ExtractOKMTensorPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); registry.insert(); } void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(patterns.getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; std::unique_ptr 
createExtractOKMTensorPass() { return std::make_unique(); } struct WrapOKMKernelPattern : public mlir::OpRewritePattern { static Value AllocOrMapOutTensor(Value res, mlir::PatternRewriter& rewriter) { if (auto type = res.getType().dyn_cast_or_null()) { int ret_index = -1; for (auto use : res.getUsers()) { if (auto to_ret = llvm::dyn_cast_or_null(use)) { ret_index = to_ret.getIndex(); break; } } auto mem_type = MemRefType::get(type.getShape(), type.getElementType()); auto out = (ret_index == -1) ? rewriter.create(rewriter.getUnknownLoc(), mem_type) : rewriter.create(rewriter.getUnknownLoc(), mem_type, ret_index); return out->getResult(0); } return nullptr; } static void CreateWrapOp(Operation* op, mlir::PatternRewriter& rewriter, IRMapping& mapper, const llvm::SmallVector& mem_outs_types, const llvm::SmallVector& map_ins) { auto wrapper_op = rewriter.create(op->getLoc(), mem_outs_types, ValueRange(map_ins)); for (auto elem : llvm::zip(op->getResults(), wrapper_op->getResults())) { mapper.map(std::get<0>(elem), std::get<1>(elem)); } auto& wrap_block = wrapper_op.getBody().emplaceBlock(); OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(&wrap_block); ImplicitLocOpBuilder nb(rewriter.getUnknownLoc(), rewriter); IRMapping wrap_mapper; for (auto in : llvm::zip(op->getOperands(), wrapper_op.getOperands())) { auto to_tensor = rewriter.create(rewriter.getUnknownLoc(), std::get<1>(in)); wrap_mapper.map(std::get<0>(in), to_tensor); } auto new_op = nb.clone(*op, wrap_mapper); SmallVector outs; for (auto out : new_op->getResults()) { if (auto type = out.getType().dyn_cast_or_null()) { auto mem_type = MemRefType::get(type.getShape(), type.getElementType()); auto to_memref = rewriter.create(rewriter.getUnknownLoc(), mem_type, out); outs.push_back(to_memref); } else { llvm::errs() << "Fail to identify op type in wrap okm kernel"; exit(1); } } rewriter.create(rewriter.getUnknownLoc(), ValueRange(outs)); } static void HandleOneFlowOp(Operation* op, mlir::PatternRewriter& rewriter, IRMapping& mapper) { // record outs type llvm::SmallVector mem_outs_types; for (auto it : op->getResultTypes()) { if (auto type = it.dyn_cast_or_null()) { auto mem_type = MemRefType::get(type.getShape(), type.getElementType()); mem_outs_types.push_back(mem_type); } else { llvm::errs() << "Fail to identify op type in wrap okm kernel"; exit(1); } } llvm::SmallVector map_ins; // record ins for (auto in : op->getOperands()) { auto mirror = mapper.lookup(in); if (auto wrap_op = llvm::dyn_cast_or_null(mirror.getDefiningOp())) { int idx = 0; for (auto res : wrap_op->getResults()) { if (mirror == res) { break; } ++idx; } Operation* oneflow_op = nullptr; auto& ops = wrap_op.getBody().front(); for (auto& op : ops) { if (oneflow::OneFlowDialect::getDialectNamespace().equals( op.getDialect()->getNamespace())) { oneflow_op = &op; } } if (!oneflow_op) { LOG(FATAL) << "Fail to find oneflow op in wrap op"; } mirror = wrap_op->getOperand(oneflow_op->getNumOperands() + idx).getDefiningOp()->getResult(0); } map_ins.push_back(mirror); } // append alloc outs after ins for (auto out : op->getResults()) { if (auto new_out = AllocOrMapOutTensor(out, rewriter)) { map_ins.push_back(new_out); } else { llvm::errs() << "Fail to alloc or map op in wrap okm kernel"; exit(1); } } if (int64_t buffer_size = ::oneflow::okl::RegContext(op).GetTmpBufferSize()) { auto type = MemRefType::get({buffer_size}, rewriter.getI8Type()); auto tmp_buffer = rewriter.create(rewriter.getUnknownLoc(), type)->getResult(0); 
map_ins.push_back(tmp_buffer); } CreateWrapOp(op, rewriter, mapper, mem_outs_types, map_ins); } static func::FuncOp WrapOps(func::FuncOp func, mlir::PatternRewriter& rewriter, const std::string& func_name) { OpBuilder::InsertionGuard insertGuard(rewriter); auto func_type = rewriter.getFunctionType({}, {}); rewriter.setInsertionPoint(func); auto wrap_func = rewriter.create(rewriter.getUnknownLoc(), func_name, func_type); auto& block = wrap_func.getBody().emplaceBlock(); rewriter.setInsertionPointToStart(&block); auto& ops = func.getBody().front(); IRMapping mapper; for (auto& op : ops) { llvm::TypeSwitch(&op) .Case([&](ArgToTensorOp op) { auto mem_type = MemRefType::get(op.getType().getShape(), op.getType().getElementType()); auto mem_op = rewriter.create(op->getLoc(), mem_type, op.getIndex()); mapper.map(Value(op), mem_op); }) .Case([&](TensorToRetOp op) { auto mem_type = MemRefType::get(op.getType().getShape(), op.getType().getElementType()); rewriter.create(op->getLoc(), mem_type, mapper.lookup(op.getTensor()), op.getIndex()); }) .Default([&](Operation* op) { if (oneflow::OneFlowDialect::getDialectNamespace().equals( op->getDialect()->getNamespace())) { HandleOneFlowOp(op, rewriter, mapper); } }); } rewriter.create(rewriter.getUnknownLoc()); return wrap_func; } explicit WrapOKMKernelPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { const auto sym_name = op.getSymName(); if (sym_name.startswith(func_name::MEM_GRAPH_NAME)) { // rename function const auto index = sym_name.substr(func_name::MEM_GRAPH_NAME.size()).str(); const std::string rename = func_name::WRAP_GRAPH_NAME + index; // wrap kernels WrapOps(op, rewriter, rename); rewriter.eraseOp(op); } return success(); } }; class WrapOKMKernelPass : public WrapOKMKernelPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); registry.insert(); registry.insert(); } void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(patterns.getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; std::unique_ptr createWrapOKMKernelPass() { return std::make_unique(); } namespace { void MemSizeFirst(func::FuncOp func, mlir::PatternRewriter& rewriter) { OpBuilder::InsertionGuard insertGuard(rewriter); auto& ops = func.getBody().front(); rewriter.setInsertionPointToStart(&ops); auto mem_type = MemRefType::get({0}, rewriter.getI8Type()); auto global_buffer = rewriter.create(rewriter.getUnknownLoc(), mem_type); ::oneflow::HashMap op2lifetime; int32_t idx = 0; for (auto& op : ops) { if (auto wrap_op = llvm::dyn_cast_or_null(op)) { op2lifetime[&op] = idx++; } } ::oneflow::HashMap val2size; ::oneflow::HashMap> val2lifetime; for (auto& op : ops) { if (auto alloc_op = llvm::dyn_cast_or_null(op)) { // get size MemRefType type = op.getResult(0).getType().dyn_cast(); int64_t size = type.getElementTypeBitWidth() / 8; for (int64_t i : type.getShape()) { size *= i; } int align = ::oneflow::kBlobBodyAlignSize; size = (size / align + ((size % align) != 0)) * align; val2size[&op] = size; // get life time int min = INT_MAX, max = 0; for (auto use : op.getUsers()) { if (auto wrap_op = llvm::dyn_cast_or_null(use)) { auto op_val = op2lifetime[use]; min = std::min(min, op_val); max = std::max(max, op_val + 1); } } val2lifetime[&op] = {min, max}; } } ::oneflow::MemBlockResultInfo res; 
::oneflow::MemReusedMemSizeFirstAlgo(false, val2lifetime, val2size, &res); auto val2offset = res.regst_desc2offset; for (auto [op, offset] : val2offset) { if (auto plan_op = llvm::dyn_cast_or_null(op)) { rewriter.setInsertionPoint(plan_op); auto off_set = rewriter.create(rewriter.getUnknownLoc(), offset); auto type = plan_op->getResult(0).getType(); rewriter.replaceOpWithNewOp(plan_op, type, global_buffer, off_set, ValueRange{}); } } mem_type = MemRefType::get({static_cast(res.mem_block_size)}, rewriter.getI8Type()); rewriter.setInsertionPoint(global_buffer); rewriter.replaceOpWithNewOp(global_buffer, mem_type); } } // namespace struct OptOKMMemrefPattern : public mlir::OpRewritePattern { explicit OptOKMMemrefPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { const auto sym_name = op.getSymName(); if (sym_name.startswith(func_name::WRAP_GRAPH_NAME)) { const auto index = sym_name.substr(func_name::WRAP_GRAPH_NAME.size()).str(); const std::string rename = func_name::OPT_GRAPH_NAME + index; op.setSymNameAttr(rewriter.getStringAttr(rename)); MemSizeFirst(op, rewriter); } return success(); } }; class OptOKMMemrefPass : public OptOKMMemrefPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); registry.insert(); registry.insert(); registry.insert(); } void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(patterns.getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; std::unique_ptr createOptOKMMemrefPass() { return std::make_unique(); } struct ConvertOKMToOKLPattern : public mlir::OpRewritePattern { static void ConvertOpToOKL(mlir::Operation& it, func::FuncOp& wrap_func, WrapperOp wrap_mem_op, mlir::PatternRewriter& rewriter, int& index) { auto wrap_okl_op = rewriter.create(rewriter.getUnknownLoc(), index++); wrap_okl_op.getBody().emplaceBlock(); OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(&wrap_okl_op.getBody().front()); IRMapping mapper; auto ins_num = it.getNumOperands(); auto outs_num = it.getNumResults() + ins_num; for (int idx = 0; idx < ins_num; ++idx) { auto val = llvm::TypeSwitch(wrap_mem_op->getOperand(idx).getDefiningOp()) .Case([&](ArgToMemrefOp op) { return rewriter.create( rewriter.getUnknownLoc(), memref::getTensorTypeFromMemRefType(op->getResult(0).getType()), wrap_func.getArgument(0), op.getIndex()); }) .Case([&](RetToMemrefOp op) { return rewriter.create( rewriter.getUnknownLoc(), memref::getTensorTypeFromMemRefType(op->getResult(0).getType()), wrap_func.getArgument(0), op.getIndex()); }) .Case([&](memref::ViewOp op) { auto offset = rewriter.getI64IntegerAttr( llvm::dyn_cast(op.getByteShift().getDefiningOp()) .value()); return rewriter.create( rewriter.getUnknownLoc(), memref::getTensorTypeFromMemRefType(op->getResult(0).getType()), wrap_func.getArgument(0), offset); }) .Default([&](Operation*) { return Value{}; }); mapper.map(it.getOperand(idx), val); } ImplicitLocOpBuilder new_block(rewriter.getUnknownLoc(), rewriter); auto new_op = new_block.clone(it, mapper); for (int idx = ins_num; idx < outs_num; ++idx) { llvm::TypeSwitch(wrap_mem_op->getOperand(idx).getDefiningOp()) .Case([&](RetToMemrefOp op) { return rewriter.create( rewriter.getUnknownLoc(), memref::getTensorTypeFromMemRefType(op->getResult(0).getType()), wrap_func.getArgument(0), new_op->getResult(idx - 
ins_num), op.getIndex()); }) .Case([&](memref::ViewOp op) { auto offset = rewriter.getI64IntegerAttr( llvm::dyn_cast(op.getByteShift().getDefiningOp()).value()); return rewriter.create( rewriter.getUnknownLoc(), memref::getTensorTypeFromMemRefType(op->getResult(0).getType()), wrap_func.getArgument(0), new_op->getResult(idx - ins_num), offset); }) .Default([&](Operation*) { return Value{}; }); } if (outs_num + 1 == wrap_mem_op->getNumOperands()) { auto op = llvm::dyn_cast(wrap_mem_op->getOperand(outs_num).getDefiningOp()); auto offset = rewriter.getI64IntegerAttr( llvm::dyn_cast(op.getByteShift().getDefiningOp()).value()); rewriter.create( rewriter.getUnknownLoc(), memref::getTensorTypeFromMemRefType(op->getResult(0).getType()), wrap_func.getArgument(0), offset); } rewriter.create(rewriter.getUnknownLoc()); } static func::FuncOp BuildOKLGraph(func::FuncOp func, mlir::PatternRewriter& rewriter, const std::string& func_name) { OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPoint(func); auto func_type = rewriter.getFunctionType( {mlir::okl::LauncherContextType::get(rewriter.getContext())}, TypeRange{}); auto wrap_func = rewriter.create(rewriter.getUnknownLoc(), func_name, func_type); auto& block = wrap_func.getBody().emplaceBlock(); wrap_func.getBody().addArguments(mlir::okl::LauncherContextType::get(rewriter.getContext()), rewriter.getUnknownLoc()); rewriter.setInsertionPointToStart(&block); llvm::SmallVector raw_ops; for (auto& op : func.getBody().front()) { raw_ops.push_back(&op); } auto index = 0; for (auto op : raw_ops) { if (auto alloc_op = llvm::dyn_cast_or_null(op)) { if (auto mem_type = alloc_op->getResult(0).getType().dyn_cast_or_null()) { wrap_func->setAttr(func_name::OKL_POOL_SIZE_TAG, rewriter.getI64IntegerAttr(mem_type.getShape().front())); } } if (auto wrap_mem_op = llvm::dyn_cast_or_null(op)) { auto& wrap_ops = wrap_mem_op.getBody().front(); for (auto& it : wrap_ops) { if (oneflow::OneFlowDialect::getDialectNamespace().equals( it.getDialect()->getNamespace())) { ConvertOpToOKL(it, wrap_func, wrap_mem_op, rewriter, index); } } } } rewriter.setInsertionPointToEnd(&block); rewriter.create(rewriter.getUnknownLoc()); return wrap_func; } explicit ConvertOKMToOKLPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { const auto sym_name = op.getSymName(); if (sym_name.startswith(func_name::OPT_GRAPH_NAME)) { const auto index = sym_name.substr(func_name::OPT_GRAPH_NAME.size()).str(); const std::string rename = func_name::OKL_GRAPH_NAME; BuildOKLGraph(op, rewriter, rename); rewriter.eraseOp(op); } return success(); } }; class ConvertOKMToOKLPass : public ConvertOKMToOKLPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); registry.insert(); registry.insert(); registry.insert(); registry.insert(); } void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(patterns.getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; std::unique_ptr createConvertOKMToOKLPass() { return std::make_unique(); } } // namespace okm } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowCanonicalizers.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/random_generator.h"
#include "OneFlow/OneFlowOps.h"
#include "OneFlow/OneFlowPatternUtils.h"

namespace mlir {
namespace oneflow {

namespace {

struct PutSeed : public OpRewritePattern<oneflow::RandomMaskLikeOp> {
  explicit PutSeed(MLIRContext* context)
      : OpRewritePattern<oneflow::RandomMaskLikeOp>(context, /*benefit=*/1) {}
  LogicalResult matchAndRewrite(oneflow::RandomMaskLikeOp op,
                                PatternRewriter& rewriter) const override {
    if (op->hasAttr(op.getSeedAttrName())) {
      return failure();
    } else {
      op->setAttr(op.getSeedAttrName(), rewrites::GetDefaultSeed(rewriter));
      return success();
    }
  }
};

}  // namespace

void RandomMaskLikeOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                    MLIRContext* context) {
  results.insert<PutSeed>(context);
}

}  // namespace oneflow
}  // namespace mlir

================================================
FILE: oneflow/ir/lib/OneFlow/OneFlowDataTypeConversion.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "OneFlow/OneFlowDataTypeConversion.h" #include "OneFlow/OneFlowTypes.h" namespace mlir { namespace oneflow { Type getTypeFromOneFlowDataType(MLIRContext* context, ::oneflow::DataType dt) { if (dt == ::oneflow::DataType::kInvalidDataType) { return InvalidElementType::get(context); } if (dt == ::oneflow::DataType::kChar) { return CharElementType::get(context); } if (dt == ::oneflow::DataType::kFloat16) { return FloatType::getF16(context); } if (dt == ::oneflow::DataType::kFloat) { return FloatType::getF32(context); } if (dt == ::oneflow::DataType::kDouble) { return FloatType::getF64(context); } if (dt == ::oneflow::DataType::kInt8) { return IntegerType::get(context, 8, IntegerType::Signed); } if (dt == ::oneflow::DataType::kInt32) { return IntegerType::get(context, 32, IntegerType::Signed); } if (dt == ::oneflow::DataType::kInt64) { return IntegerType::get(context, 64, IntegerType::Signed); } if (dt == ::oneflow::DataType::kOFRecord) { return OFRecordElementType::get(context); } if (dt == ::oneflow::DataType::kTensorBuffer) { return TensorBufferElementType::get(context); } if (dt == ::oneflow::DataType::kBool) { return IntegerType::get(context, 8, IntegerType::Signed); } if (dt == ::oneflow::DataType::kUInt8) { return IntegerType::get(context, 8, IntegerType::Unsigned); } if (dt == ::oneflow::DataType::kUInt16) { return IntegerType::get(context, 16, IntegerType::Unsigned); } if (dt == ::oneflow::DataType::kUInt32) { return IntegerType::get(context, 32, IntegerType::Unsigned); } if (dt == ::oneflow::DataType::kUInt64) { return IntegerType::get(context, 64, IntegerType::Unsigned); } if (dt == ::oneflow::DataType::kUInt128) { return IntegerType::get(context, 128, IntegerType::Unsigned); } llvm::errs() << "unsupported oneflow data type: " << dt << "\n"; return Type(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowDialect.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowTypes.h" #include "OneFlow/OneFlowOpsDialect.cpp.inc" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/TypeRange.h" namespace mlir { namespace oneflow { void OneFlowDialect::initialize() { addOperations< #define GET_OP_LIST #include "OneFlow/OneFlowOps.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.assign_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.binary_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.broadcast_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.conv_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.cross_entropy_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.cuda_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.dataset_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.detection_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.eager_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.fused_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.idempotent_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.identity_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.image_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.indices_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.involution_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.loss_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.math_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.matmul_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.misc_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.nccl_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.normalization_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.optimizer_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.padding_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.parallel_cast_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.pool_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.quantization_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.reduce_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.reshape_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.scalar_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.softmax_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.summary_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.tensor_buffer_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.trigonometric_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.unary_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.upsample_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.one_embedding_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.linear_algebra_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.system_ops.cpp.inc" , #define GET_OP_LIST #include "OneFlow/OneFlow.mlir_jit_ops.cpp.inc" >(); addTypes< #define GET_TYPEDEF_LIST #include "OneFlow/OneFlowOpsTypes.cpp.inc" >(); } mlir::Operation* OneFlowDialect::materializeConstant(mlir::OpBuilder& builder, mlir::Attribute value, mlir::Type type, mlir::Location loc) { return builder.create(loc, type, ValueRange(), value.cast().getValue()); } } // namespace oneflow } // namespace mlir ================================================ FILE: 
oneflow/ir/lib/OneFlow/OneFlowInferReturnTypes.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowOps.h" #include "OneFlow/UserOpConversion.h" #include "OneFlow/UserOpReflection.h" #include "OneFlow/OneFlowDataTypeConversion.h" #include "mlir/Support/LogicalResult.h" namespace mlir { namespace oneflow { namespace { std::unique_ptr<::oneflow::BlobDesc> getBlobDescFromTensorType(TensorType tensor_type) { auto data_type = mlir::oneflow::support::FromMLIRTypeToOFDataType(tensor_type.getElementType()); if (mlir::succeeded(data_type)) { auto shape_from_mlir = new ::oneflow::Shape(llvm::SmallVector( {tensor_type.getShape().begin(), tensor_type.getShape().end()})); return std::make_unique<::oneflow::BlobDesc>(*shape_from_mlir, data_type.value(), ::oneflow::MemoryFormat::kContiguous); } tensor_type.dump(); LOG(FATAL) << "fail to get BlobDesc from TensorType"; } Type getTensorTypeFromBlobDesc(MLIRContext* context, const ::oneflow::BlobDesc* blob_desc) { if (auto type = getTypeFromOneFlowDataType(context, blob_desc->data_type())) { return RankedTensorType::get( llvm::SmallVector( {blob_desc->shape().dim_vec().begin(), blob_desc->shape().dim_vec().end()}), type); } else { return Type{}; } } static auto MagicalOpName = "INFER_MAGICAL"; LogicalResult ConvertUserOp(llvm::StringRef op_type_name, ::oneflow::OperatorConf& op_conf, ValueRange operands, DictionaryAttr attributes) { oneflow::ConfOpAdaptor conf_op_adaptor(operands, attributes); op_conf.set_name(MagicalOpName); CHECK( user_op::ConvertUserOpInputs(op_type_name, operands, attributes, op_conf.mutable_user_conf()) .succeeded()); if (!succeeded(user_op::ConvertUserOpAttributes(op_type_name, operands, attributes, op_conf))) { return failure(); } return success(); } size_t getResultSize(DictionaryAttr attributes) { const StringRef attr_name = OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(); const DenseI32ArrayAttr& size_attr = attributes.get(attr_name).dyn_cast_or_null(); CHECK(size_attr) << "Attr " << attr_name.str() << " is not found or not DenseI32ArrayAttr"; auto size = 0; for (auto s : size_attr.asArrayRef()) { size += s; } return size; } ::mlir::LogicalResult inferReturnTypesWithOpTypeName( llvm::StringRef op_type_name, ::mlir::MLIRContext* context, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) { ::oneflow::OperatorConf op_conf{}; CHECK(ConvertUserOp(op_type_name, op_conf, operands, attributes).succeeded()); std::unordered_map> lbi2logical_blob_desc_; auto operand_ids = user_op::ArgIds(op_type_name, operands.size(), attributes); auto operand_index = 0; for (const auto& idOperand : llvm::zip(operand_ids, operands)) { const auto& arg_name = std::get<0>(idOperand).first; const auto& arg_id = std::get<0>(idOperand).second; const auto operand = std::get<1>(idOperand); auto blob_desc = getBlobDescFromTensorType(operand.getType().cast()); 
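// The surrounding loop, together with the statements that follow, reuses OneFlow's own
// logical shape inference: every MLIR operand type is converted into a BlobDesc keyed by
// its repeated blob name (GenRepeatedBn typically yields names such as "in_0"),
// placeholder BlobDescs are registered for the outputs, and the Operator built from the
// OperatorConf then runs InferLogicalOutBlobDescs under the parallel description taken
// from the attribute dictionary. The inferred output BlobDescs are finally converted back
// into RankedTensorTypes and appended to inferredReturnTypes.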
auto bn = ::oneflow::GenRepeatedBn(arg_name, arg_id); lbi2logical_blob_desc_.emplace(bn, std::move(blob_desc)); operand_index += 1; } auto result_ids = user_op::ArgIds( op_type_name, getResultSize(attributes), attributes); for (const auto& result_id : result_ids) { const auto& arg_name = result_id.first; const auto& arg_id = result_id.second; const auto bn = ::oneflow::GenRepeatedBn(arg_name, arg_id); auto blob_desc = std::make_unique<::oneflow::BlobDesc>(::oneflow::kInvalidDataType, ::oneflow::MemoryFormat::kContiguous); lbi2logical_blob_desc_.emplace(bn, std::move(blob_desc)); (*op_conf.mutable_user_conf()->mutable_output())[arg_name].add_s( ::oneflow::GenLogicalBlobName(op_conf.name(), bn)); } auto op = CHECK_JUST(ConstructOp(op_conf, user_op::getDeviceTypeFromAttrDictionary(attributes))); auto GetLogicalBlobDesc4BnInOp = [&](const std::string& bn) -> ::oneflow::BlobDesc* { auto it = lbi2logical_blob_desc_.find(bn); if (it == lbi2logical_blob_desc_.end()) { LOG(FATAL) << "fail to find blob name in op: " << bn; } return it->second.get(); }; ::oneflow::ParallelConf parallel_conf = user_op::getParallelConfFromAttrDictionary(attributes); ::oneflow::ParallelDesc parallel_desc{parallel_conf}; CHECK_JUST(op->FillOpParallelDesc(parallel_desc)); CHECK_JUST(op->InferLogicalOutBlobDescs(GetLogicalBlobDesc4BnInOp, parallel_desc)); for (const auto& result_id : result_ids) { const auto& arg_name = result_id.first; const auto& arg_id = result_id.second; const auto bn = ::oneflow::GenRepeatedBn(arg_name, arg_id); const auto* desc = lbi2logical_blob_desc_.at(bn).get(); if (auto t = getTensorTypeFromBlobDesc(context, desc)) { inferredReturnTypes.push_back(t); } } return success(); } } // namespace ::mlir::LogicalResult NormalizationAddReluOp::refineReturnTypes( ::mlir::MLIRContext* context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) { return success(); } ::mlir::LogicalResult NormalizationAddReluOp::inferReturnTypes( ::mlir::MLIRContext* context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) { return inferReturnTypesWithOpTypeName("normalization_add_relu", context, operands, attributes, regions, inferredReturnTypes); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/framework/variable_tensor_mgr.h" namespace mlir { namespace oneflow { namespace { namespace functional = ::oneflow::one::functional; using TensorPtr = std::shared_ptr<::oneflow::one::Tensor>; using MaybeTensor = ::oneflow::Maybe<::oneflow::one::Tensor>; StringAttr GenNewVariableOpName(MLIRContext* ctx, const std::string& key = "") { if (key == "") { return StringAttr::get(ctx, "variable_" + ::oneflow::NewUniqueId()); } return StringAttr::get(ctx, "variable_" + key + "_" + ::oneflow::NewUniqueId()); } bool MLIRDataTypesAreSame(const std::vector& data_types) { if (data_types.empty() || data_types.size() == 1) { return true; } bool result = true; const auto first_data_type = data_types[0]; for (size_t i = 1; i < data_types.size(); ++i) { result &= (first_data_type == data_types[i]); } return result; } bool DictionaryAttrsHaveSameDataType(const std::vector& attrs) { std::vector data_types; for (const auto& attr : attrs) { data_types.push_back(attr.get(OpTrait::TensorSource::getDataTypeAttrName()) .cast() .getValue()); } return MLIRDataTypesAreSame(data_types); } OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef operands, const std::function& f) { ::oneflow::LazyMode::Guard guard{false}; if (!operands.front()) { return {}; } // Important! const auto attr_dict = operands.front().cast(); auto attrs = NamedAttrList(attr_dict); const auto tensor = support::DenseElementsAttrToTensor( attr_dict.get("value"), attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()), attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceNameAttr())); const auto result = f(tensor).GetPtrOrThrow(); attrs.set("value", support::TensorToDenseElementsAttr(result, ctx)); attrs.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), GenNewVariableOpName(ctx)); attrs.set(OpTrait::TensorSource::getDataTypeAttrName(), attr_dict.get(OpTrait::TensorSource::getDataTypeAttrName())); return attrs.getDictionary(ctx); } OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef operands, const std::function& f) { ::oneflow::LazyMode::Guard guard{false}; if (!(operands.front() && operands.back())) { return {}; } // Important! auto lhs_attr_dict = operands.front().cast(); auto rhs_attr_dict = operands.back().cast(); if (!DictionaryAttrsHaveSameDataType({lhs_attr_dict, rhs_attr_dict})) { llvm::errs() << "Input tensors should have same data type in binary operation of constant folding." 
<< "\n"; return nullptr; } auto attrs = NamedAttrList(lhs_attr_dict); const auto lhs_tensor = support::DenseElementsAttrToTensor( lhs_attr_dict.get("value"), lhs_attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()), lhs_attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceNameAttr())); const auto rhs_tensor = support::DenseElementsAttrToTensor( rhs_attr_dict.get("value"), rhs_attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()), rhs_attr_dict.get(OpTrait::IsOpConfCompatible::getDeviceNameAttr())); const auto result = f(lhs_tensor, rhs_tensor).GetPtrOrThrow(); attrs.set("value", support::TensorToDenseElementsAttr(result, ctx)); attrs.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), GenNewVariableOpName(ctx)); attrs.set(OpTrait::TensorSource::getDataTypeAttrName(), lhs_attr_dict.get(OpTrait::TensorSource::getDataTypeAttrName())); return attrs.getDictionary(ctx); } } // namespace OpFoldResult FrozenVariableOp::fold(FoldAdaptor adaptor) { NamedAttrList attrs; attrs.set(getValueAttrName(), getValueAttr()); attrs.set(getOpNameAttrName(), getOpNameAttr()); attrs.set(getDataTypeAttrName(), getDataTypeAttr()); attrs.set(getDeviceTagAttrName(), getDeviceTagAttr()); attrs.set(getDeviceNameAttrName(), getDeviceNameAttr()); attrs.set(getScopeSymbolIdAttrName(), getScopeSymbolIdAttr()); attrs.set(getHierarchyAttrName(), getHierarchyAttr()); attrs.set(getNdSbpAttrName(), getNdSbpAttr()); return DictionaryAttr::get(getContext(), attrs); } OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); return UnaryFold(getContext(), operands, [this](const auto& tensor) { std::vector perm_; for (auto& x : getPerm().getValue()) { perm_.emplace_back(x.cast().getSInt()); } return functional::Transpose(tensor, perm_); }); } OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); return UnaryFold(getContext(), operands, [this](const auto& tensor) { std::vector shape_vec; for (auto& x : getShape().getValue()) { shape_vec.emplace_back(x.cast().getValue().getSExtValue()); } return functional::Reshape( tensor, ::oneflow::Shape(::oneflow::DimVector(shape_vec.begin(), shape_vec.end()))); }); } OpFoldResult ScalarAddOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); return UnaryFold(getContext(), operands, [this](const auto& tensor) -> MaybeTensor { if (getHasIntOperand()) { return functional::ScalarAdd(tensor, getIntOperand(), 1, false); } if (getHasFloatOperand()) { return functional::ScalarAdd(tensor, getFloatOperand().convertToDouble(), 1, false); } emitError("Scalar op must has a int operand or a float operand."); return TensorPtr(); }); } OpFoldResult SqrtOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); return UnaryFold(getContext(), operands, functional::Sqrt); } OpFoldResult BroadcastMulOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); return BinaryFold(getContext(), operands, functional::Mul); } OpFoldResult BroadcastDivOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); return BinaryFold(getContext(), operands, functional::Div); } OpFoldResult BroadcastSubOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); return BinaryFold(getContext(), operands, [](const auto& lhs, const auto& rhs) -> MaybeTensor { return functional::Sub(lhs, rhs, /*alpha=*/1.0, false); }); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowOpGetGen.cpp.in 
================================================ #include #include #include "llvm/ADT/STLExtras.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/StringSet.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowSupport.h" #include "OneFlow/OneFlowInterfaces.h.inc" #include "OneFlow/OneFlowTypes.h" #define GET_OP_CLASSES #include "OneFlow/OneFlow.@OP_GROUP_NAME_LOWER@_ops.h.inc" #define GET_OP_CLASSES #include "OneFlow/OneFlow.@OP_GROUP_NAME_LOWER@_ops.cpp.inc" ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowOpTraits.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowOps.h" #include "OneFlow/UserOpConversion.h" namespace mlir { namespace OpTrait { namespace { // TODO: merge all ctrl input and output when folding op bool HaveIdenticalPlacement(mlir::Operation* a, mlir::Operation* b) { const bool has_identical_dev_tag = IsOpConfCompatible::getDeviceTag(a) == IsOpConfCompatible::getDeviceTag(b); const bool has_identical_dev_name = IsOpConfCompatible::getDeviceName(a) == IsOpConfCompatible::getDeviceName(b); return has_identical_dev_tag && has_identical_dev_name; } } // namespace namespace impl { OpFoldResult foldIdempotentOfIdenticalPlacement(Operation* op) { auto* argument_op = op->getOperand(0).getDefiningOp(); if (argument_op && op->getName() == argument_op->getName() && HaveIdenticalPlacement(op, argument_op)) { return op->getOperand(0); } return {}; } OpFoldResult foldInvolutionOfIdenticalPlacement(Operation* op) { auto* argument_op = op->getOperand(0).getDefiningOp(); if (argument_op && op->getName() == argument_op->getName() && HaveIdenticalPlacement(op, argument_op)) { return argument_op->getOperand(0); } return {}; } LogicalResult VerifyIsOpConfCompatible(Operation* op) { for (auto attr : { IsOpConfCompatible::getOpNameAttr(), IsOpConfCompatible::getDeviceTagAttr(), }) { if (!op->hasAttrOfType(attr)) { return op->emitError("expected operation to have attribute: " + attr); } } if (!op->hasAttrOfType(IsOpConfCompatible::getDeviceNameAttr())) { return op->emitError("expected operation to have attribute: " + IsOpConfCompatible::getDeviceNameAttr()); } return success(); } LogicalResult VerifyIsImportCompatible(Operation* op) { if (auto output_lbns = op->getAttrOfType(IsImportCompatible::getOutputLBNsAttr())) { if (auto cec = dyn_cast(op)) { if (cec.dataOutputResults().size() != output_lbns.size()) { return op->emitError("expected number of data output results to be " + std::to_string(output_lbns.size()) + " but got " + std::to_string(cec.dataOutputResults().size())); } } else { return op->emitError("expected to support ControlEdgeCompatible"); } } else { return op->emitError("expected operation to have attribute: " + IsImportCompatible::getOutputLBNsAttr()); } return success(); } LogicalResult 
saveAttrToOpConf(Operation* op, ::oneflow::OperatorConf* op_conf) { return oneflow::user_op::saveAttrDictionaryToOpConf(op->getAttrDictionary(), op_conf); } StringAttr getOpName(Operation* op) { assert(op->hasTrait()); return op->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()); } StringAttr getDeviceTag(Operation* op) { assert(op->hasTrait()); return op->getAttrOfType(IsOpConfCompatible::getDeviceTagAttr()); } ArrayAttr getDeviceName(Operation* op) { assert(op->hasTrait()); return op->getAttrOfType(IsOpConfCompatible::getDeviceNameAttr()); } IntegerAttr getScopeSymbolID(Operation* op) { assert(op->hasTrait()); return op->getAttrOfType(IsOpConfCompatible::getScopeSymbolIDAttr()); } ArrayAttr getHierarchy(Operation* op) { assert(op->hasTrait()); return op->getAttrOfType(IsOpConfCompatible::getHierarchyAttr()); } LogicalResult saveAttrsToNamedAttrList(Operation* op, NamedAttrList& attributes) { attributes.set(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), OpTrait::IsOpConfCompatible::getDeviceTag(op)); attributes.set(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), OpTrait::IsOpConfCompatible::getDeviceName(op)); if (auto hierarchy = OpTrait::IsOpConfCompatible::getHierarchy(op)) { attributes.set(OpTrait::IsOpConfCompatible::getHierarchyAttr(), hierarchy); } attributes.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), OpTrait::IsOpConfCompatible::getOpName(op)); if (auto scope_symbol_id = OpTrait::IsOpConfCompatible::getScopeSymbolID(op)) { attributes.set(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr(), scope_symbol_id); } return success(); } } // namespace impl } // namespace OpTrait } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowOps.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowSupport.h" #include "OneFlow/SBP/SBPAttributes.h" #include "OneFlow/Transform/TransposeHelpers.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OperationSupport.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/vm/vm_util.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include #include #include #include namespace mlir { namespace oneflow { OperandRange UserOp::dataInputOperands() { return getDataInput(); } OperandRange UserOp::ctrlInputOperands() { return getCtrlInputs(); } ResultRange UserOp::dataOutputResults() { return getDataOutput(); } Value UserOp::ctrlOutputResult() { return getCtrlOutput(); } OperandRange SystemOp::dataInputOperands() { return getDataInput(); } OperandRange SystemOp::ctrlInputOperands() { return getCtrlInputs(); } ResultRange SystemOp::dataOutputResults() { return getDataOutput(); } Value SystemOp::ctrlOutputResult() { return getCtrlOutput(); } OperandRange VariableOp::dataInputOperands() { return {operand_begin(), operand_begin()}; } OperandRange VariableOp::ctrlInputOperands() { return getCtrlInputs(); } ResultRange VariableOp::dataOutputResults() { return getOutput().dyn_cast(); } Value VariableOp::ctrlOutputResult() { return getCtrlOutput(); } OperandRange InputOp::dataInputOperands() { return getODSOperands(0); } OperandRange InputOp::ctrlInputOperands() { return getCtrlInputs(); } ResultRange InputOp::dataOutputResults() { return getOutput().dyn_cast(); } Value InputOp::ctrlOutputResult() { return getCtrlOutput(); } OperandRange OutputOp::dataInputOperands() { return getODSOperands(0); } OperandRange OutputOp::ctrlInputOperands() { return getCtrlInputs(); } ResultRange OutputOp::dataOutputResults() { return getOutput().dyn_cast(); } Value OutputOp::ctrlOutputResult() { return getCtrlOutput(); } static ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result) { mlir::DenseElementsAttr value; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseAttribute(value, "value", result.attributes)) { return failure(); } result.addTypes(value.getType()); return success(); } ArrayAttr getSI32ArrayAttr(::mlir::PatternRewriter& rewriter, ArrayRef values) { auto attrs = llvm::to_vector<8>(llvm::map_range( values, [&](int32_t v) -> Attribute { return rewriter.getSI32IntegerAttr(v); })); return rewriter.getArrayAttr(attrs); } namespace { LogicalResult TrimRedundantCtrl(Operation* op, PatternRewriter& rewriter) { auto ctrl_out = GetCtrlOutputResult(op); auto data_outputs = GetDataOutputResults(op); if (ctrl_out && ctrl_out.value().use_empty()) { const int32_t num_data_outputs = data_outputs.size(); 
NamedAttrList attributes(op->getAttrs()); if (op->hasTrait()) { attributes.erase(OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); attributes.push_back( rewriter.getNamedAttr(OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(), rewriter.getDenseI32ArrayAttr({num_data_outputs, 0}))); } OperationState state(op->getLoc(), op->getName(), op->getOperands(), data_outputs.getTypes(), attributes); auto created = rewriter.create(state); for (auto data_output : data_outputs) { data_output.replaceAllUsesWith(created->getOpResult(data_output.getResultNumber())); } op->erase(); return success(); } return failure(); } bool IsCtrlOutTrimmed(UserOp& op) { return !op.getCtrlOutput(); } bool IsCtrlInAbsent(UserOp& op) { if (!op->hasAttrOfType( OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr())) op.dump(); return op.getCtrlInputs().empty(); } } // namespace template static void getValuesFromIntArrayAttribute(ArrayAttr attr, SmallVector& arrayValues) { for (Attribute val : attr.getValue()) { arrayValues.push_back(val.cast().getValue().getSExtValue()); } } struct ConcreteUserOps : public OpRewritePattern { explicit ConcreteUserOps(MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} LogicalResult matchAndRewrite(UserOp op, PatternRewriter& rewriter) const override { if (succeeded(TrimRedundantCtrl(op, rewriter))) { return success(); } // In principle, a concrete user op has no ctrl input/output. Some benefits: // 1. simplify things // 2. make conversion and code gen more doable // 3. enable the reuse of established MLIR infra like built-in traits if (IsCtrlOutTrimmed(op) && IsCtrlInAbsent(op)) { NamedAttrList attributes(op->getAttrDictionary()); attributes.erase(op.getInputSizesAttrName()); attributes.erase(op.getOutputSizesAttrName()); attributes.erase(op.getOutputLbnsAttrName()); attributes.erase(OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()); attributes.erase(OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); llvm::SmallVector input_sizes, output_sizes; getValuesFromIntArrayAttribute(op.getInputSizes(), input_sizes); getValuesFromIntArrayAttribute(op.getOutputSizes(), output_sizes); if (!input_sizes.empty()) { attributes.push_back(rewriter.getNamedAttr( OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), rewriter.getDenseI32ArrayAttr(input_sizes))); } if (!output_sizes.empty()) { attributes.push_back(rewriter.getNamedAttr( OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(), rewriter.getDenseI32ArrayAttr(output_sizes))); } OperationState state(op->getLoc(), OneFlowDialect::getDialectNamespace().str() + "." 
+ op.getOpTypeName().str()); state.addAttributes(attributes); state.addOperands(op.getODSOperands(0) /* data in */); state.addTypes(op.getODSResults(0 /* data out */).getTypes()); if (auto created = rewriter.create(state)) { if (created->hasTrait() == false) { created->removeAttr(OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()); } if (created->hasTrait() == false) { created->removeAttr(OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); } if (created->hasTrait() == false) { created->removeAttr(OpTrait::IsAlternative::getOpTypeNameAttr()); } rewriter.replaceOp(op, created->getResults()); } else { op->emitError("Fail to convert opaque user op to concrete op when creating: " + op.getOpTypeName()); op->dump(); return failure(); } } return success(); } }; void UserOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert(context); } struct ConcreteSystemOps : public OpRewritePattern { explicit ConcreteSystemOps(MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} LogicalResult matchAndRewrite(oneflow::SystemOp op, PatternRewriter& rewriter) const override { return TrimRedundantCtrl(op, rewriter); } }; void SystemOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert(context); } struct ConvertAddOpWithArity : public OpRewritePattern { explicit ConvertAddOpWithArity(MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} LogicalResult matchAndRewrite(AddNOp op, PatternRewriter& rewriter) const override { const auto arity = op.getIn().size(); if (arity == 2) { NamedAttrList attributes = op->getAttrs(); attributes.set(OpTrait::IsAlternative::getOpTypeNameAttr(), rewriter.getStringAttr("add_n")); if (auto created_op = rewriter.replaceOpWithNewOp(op, op->getResultTypes(), op.getOperands(), attributes)) { return success(); } else { op->emitError("Fail to convert add op with arity: ") << arity; op->dump(); return failure(); } } return failure(); } }; void AddNOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert(context); } template struct ConcreteSystemOpPattern : public OpRewritePattern { explicit ConcreteSystemOpPattern(MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} LogicalResult matchAndRewrite(OpType op, PatternRewriter& rewriter) const override { if (op.getCtrlOutput() && op.getCtrlOutput().use_empty()) { NamedAttrList attributes(op->getAttrDictionary()); if (auto created = rewriter.create(op->getLoc(), op.getOutput().getType(), op->getOperands(), attributes)) { op.getOutput().replaceAllUsesWith( created->getResult(op.getOutput().template cast().getResultNumber())); op->erase(); return success(); } } return failure(); } }; void VariableOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert>(context); } void InputOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert>(context); } void OutputOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert>(context); } std::string Add2Op::getOriginalOpTypeName() { return "add_n"; } std::string NormalizationInferenceOp::getOriginalOpTypeName() { return "normalization"; } void Job::build(OpBuilder& builder, OperationState& state, StringRef name, FunctionType type, llvm::ArrayRef attrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); state.addAttribute(Job::getFunctionTypeAttrName(state.name), 
TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); } ParseResult Job::parse(OpAsmParser& parser, OperationState& result) { auto buildFuncType = [](Builder& builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, std::string&) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void Job::print(OpAsmPrinter& p) { function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); } LogicalResult Job::verify() { // If this function is external there is nothing to do. if (isExternal()) return success(); // Verify that the argument list of the function and the arg list of the entry // block line up. The trait already verified that the number of arguments is // the same between the signature and the block. auto fnInputTypes = getFunctionType().getInputs(); Block& entryBlock = front(); for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) return emitOpError("type of entry block argument #") << i << '(' << entryBlock.getArgument(i).getType() << ") must match the type of the corresponding argument in " << "function signature(" << fnInputTypes[i] << ')'; return success(); } LogicalResult ReturnOp::verify() { auto job = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto& results = job.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@" << job.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of return operand " << i << " (" << getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << job.getName(); return success(); } struct NormalizationInferencePattern : public OpRewritePattern { explicit NormalizationInferencePattern(MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} LogicalResult matchAndRewrite(oneflow::NormalizationOp op, PatternRewriter& rewriter) const override { if (op.getMean() || op.getInvVariance()) return failure(); if (auto created_op = rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getOperands(), op->getAttrs())) { return success(); } op.emitError("Failed to create inference bn op"); return failure(); } }; void NormalizationOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert(context); } ResultRange GetDataOutputResults(Operation* op) { if (auto cec = dyn_cast(op)) { return cec.dataOutputResults(); } else { return op->getResults(); } } OperandRange GetDataInputOperands(Operation* op) { if (auto cec = dyn_cast(op)) { return cec.dataInputOperands(); } else { return op->getOperands(); } } llvm::Optional GetCtrlIntputOperands(Operation* op) { if (auto cec = dyn_cast(op)) { return cec.ctrlInputOperands(); } else { return llvm::None; } } llvm::Optional GetCtrlOutputResult(Operation* op) { if (auto cec = dyn_cast(op)) { if (auto ctrl_out = cec.ctrlOutputResult()) { return ctrl_out.cast(); } } return llvm::None; } } // namespace oneflow } // 
namespace mlir #include "OneFlow/OneFlowEnums.cpp.inc" #define GET_OP_CLASSES #include "OneFlow/OneFlowOps.cpp.inc" #include "OneFlow/OneFlowInterfaces.cpp.inc" ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowRewrites.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ //===- TestPDLByteCode.cpp - Test PDLL functionality ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "OneFlow/UserOpConversion.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "OneFlow/OneFlowPDLLPatterns.h" #include "OneFlow/OneFlowOps.h" #include "oneflow/core/framework/random_generator.h" #include "OneFlow/OneFlowUtils.h" #include "mlir/IR/IRMapping.h" #include "mlir/Dialect/Func/IR/FuncOps.h" using namespace mlir; #include "oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.h.inc" namespace mlir { namespace oneflow { namespace { static std::atomic uniqID{0}; std::string getUniqName(llvm::StringRef name) { uniqID += 1; return name.str() + "-mlir-gen-" + std::to_string(uniqID); } static Operation* CopyUserOpAttrs(PatternRewriter& rewriter, Operation* src, Operation* dst) { dst->setAttr(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), OpTrait::IsOpConfCompatible::getDeviceTag(src)); dst->setAttr(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), OpTrait::IsOpConfCompatible::getDeviceName(src)); if (auto hierarchy = OpTrait::IsOpConfCompatible::getHierarchy(src)) { dst->setAttr(OpTrait::IsOpConfCompatible::getHierarchyAttr(), hierarchy); } if (auto scope_symbol_id = OpTrait::IsOpConfCompatible::getScopeSymbolID(src)) { dst->setAttr(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr(), scope_symbol_id); } dst->setAttr( OpTrait::IsOpConfCompatible::getOpNameAttr(), rewriter.getStringAttr(getUniqName(OpTrait::IsOpConfCompatible::getOpName(src).str()))); return dst; } static Operation* BuildFusedBiasAddMaskScaleOpWithRate(PatternRewriter& rewriter, Value a, Value b, Value mask, Attribute axis, Attribute rate, Operation* dropout) { auto dropout_op = llvm::dyn_cast(dropout); assert(dropout_op); SmallVector operands; operands.push_back(a); operands.push_back(b); operands.push_back(mask); NamedAttrList attributes; attributes.set("axis", axis); float scale = 1.0f; float rate_float = rate.cast().getValueAsDouble(); if (rate_float < 1.0f) { scale = 1.0f / (1.0f - rate_float); } attributes.set("scale", rewriter.getF32FloatAttr(scale)); return 
rewriter.create( dropout_op->getLoc(), dropout_op.getOut().getType(), operands, attributes); } static Operation* CreateConv2dAndErasePad(PatternRewriter& rewriter, Value x, Value weight, Attribute padding_before, Attribute data_format, Operation* conv) { auto conv_op = llvm::dyn_cast(conv); assert(conv_op); SmallVector operands; operands.push_back(x); operands.push_back(weight); NamedAttrList attributes = conv_op->getAttrs(); llvm::SmallVector padding_before_array; attributes.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), rewriter.getStringAttr(OpTrait::IsOpConfCompatible::getOpName(conv).str() + "-fuse-conv")); if (data_format.cast().str() == "channels_first") { for (auto val : padding_before.cast().getValue().take_back(2)) { padding_before_array.push_back(val.cast().getValue().getSExtValue()); } } else { padding_before_array.push_back(padding_before.cast() .getValue()[1] .cast() .getValue() .getSExtValue()); padding_before_array.push_back(padding_before.cast() .getValue()[2] .cast() .getValue() .getSExtValue()); } attributes.set(conv_op.getPaddingBeforeAttrName(), getSI32ArrayAttr(rewriter, padding_before_array)); return rewriter.create(conv_op->getLoc(), conv_op.getOut().getType(), operands, attributes); } IntegerAttr getSI64IntegerAttr(::mlir::PatternRewriter& rewriter, int64_t value) { return IntegerAttr::get(rewriter.getIntegerType(64, /*isSigned=*/true), APInt(64, value, /*isSigned=*/true)); } static Attribute GetHeadSizeFromTranpose(PatternRewriter& rewriter, Operation* transpose) { auto transpose_op = llvm::dyn_cast(transpose); CHECK(transpose_op); return getSI64IntegerAttr(rewriter, transpose_op.getOutput().getType().cast().getDimSize(3)); } NamedAttrList GetUserOpCommonAttrs(MLIRContext* ctx, const std::string& op_name) { NamedAttrList attrs; attrs.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), StringAttr::get(ctx, op_name)); attrs.set(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), StringAttr::get(ctx, "cpu")); attrs.set(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), ArrayAttr::get(ctx, llvm::to_vector<8>(llvm::map_range(ArrayRef({"@0:0"}), [&](StringRef v) -> Attribute { return StringAttr::get(ctx, v); })))); return attrs; } static Operation* CreateConv2DBatchNorm(PatternRewriter& rewriter, Attribute epsilon, Operation* conv, Operation* bn) { auto conv_op = llvm::dyn_cast(conv); auto bn_op = llvm::dyn_cast(bn); auto ctx = rewriter.getContext(); NamedAttrList attributes = conv_op->getAttrs(); attributes.set(OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), rewriter.getDenseI32ArrayAttr({1, 1, 1, 0})); SmallVector operands; operands.push_back(conv_op.getIn()); // deal with weight auto add_op_attrs = GetUserOpCommonAttrs(ctx, "scalar_add"); add_op_attrs.set("has_float_operand", BoolAttr::get(ctx, true)); double epsilon_attr = epsilon.cast().getValueAsDouble(); add_op_attrs.set("float_operand", rewriter.getF64FloatAttr(epsilon_attr)); auto add_op = rewriter.create( conv_op->getLoc(), conv_op.getOut().getType(), SmallVector({bn_op.getMovingVariance()}), add_op_attrs); auto sqrt_op = rewriter.create(conv_op->getLoc(), conv_op.getOut().getType(), SmallVector({add_op.getOut()}), GetUserOpCommonAttrs(ctx, "sqrt")); auto div_op = rewriter.create( conv_op->getLoc(), conv_op.getOut().getType(), SmallVector({bn_op.getGamma(), sqrt_op.getY()}), GetUserOpCommonAttrs(ctx, "div")); auto bn_gamma_variable_op = llvm::dyn_cast(bn_op.getGamma().getDefiningOp()); CHECK(bn_gamma_variable_op) << "Gamma of batchnorm should be a FrozenVariableOp."; auto bn_gamma_shape = 
bn_gamma_variable_op.getValue().getType().cast().getShape(); auto conv_weight_variable_op = llvm::dyn_cast(conv_op.getWeight().getDefiningOp()); CHECK(conv_weight_variable_op) << "Weight of conv2d should be a FrozenVariableOp."; auto conv_weight_shape = conv_weight_variable_op.getValue().getType().cast().getShape(); std::vector bn_gamma_new_shape({bn_gamma_shape.front()}); for (int i = 1; i < conv_weight_shape.size(); ++i) { bn_gamma_new_shape.emplace_back(1); } auto reshape_op_attrs = GetUserOpCommonAttrs(ctx, "reshape"); reshape_op_attrs.set( "shape", ArrayAttr::get(ctx, llvm::to_vector<8>(llvm::map_range( ArrayRef(bn_gamma_new_shape), [&](int64_t v) -> Attribute { return getSI64IntegerAttr(rewriter, v); })))); auto reshape_op = rewriter.create(conv_op->getLoc(), conv_op.getOut().getType(), SmallVector({div_op.getZ()}), reshape_op_attrs); auto mul_op = rewriter.create( conv_op->getLoc(), conv_op.getOut().getType(), SmallVector({conv_op.getWeight(), reshape_op.getOut()}), GetUserOpCommonAttrs(ctx, "multiply")); operands.push_back(mul_op.getZ()); // deal with bias CHECK(!conv_op.getBias()) << "Fusing conv2d and batch_norm only supports conv2d without bias now."; auto mul_op_bias = rewriter.create( conv_op->getLoc(), conv_op.getOut().getType(), SmallVector({bn_op.getMovingMean(), div_op.getZ()}), GetUserOpCommonAttrs(ctx, "multiply_bias")); auto sub_op_bias = rewriter.create( conv_op->getLoc(), conv_op.getOut().getType(), SmallVector({bn_op.getBeta(), mul_op_bias.getZ()}), GetUserOpCommonAttrs(ctx, "sub_bias")); operands.push_back(sub_op_bias.getZ()); auto new_conv_op = rewriter.create( conv_op->getLoc(), conv_op.getOut().getType(), operands, attributes); return new_conv_op; } static LogicalResult IsPaddingCouldBeAssimilatedIntoConv(PatternRewriter& rewriter, Attribute padding_before, Attribute padding_after, Attribute data_format) { if (padding_before.cast().size() == 4 && padding_after.cast().size() == 4) { if (padding_before.cast().getValue().equals( padding_after.cast().getValue())) { if (data_format.cast().str() == "channels_first") { return success(padding_before.cast() .getValue()[0] .cast() .getValue() .getSExtValue() == 0 && padding_before.cast() .getValue()[1] .cast() .getValue() .getSExtValue() == 0); } if (data_format.cast().str() == "channels_last") { return success(padding_before.cast() .getValue()[0] .cast() .getValue() .getSExtValue() == 0 && padding_before.cast() .getValue()[3] .cast() .getValue() .getSExtValue() == 0); } } } return failure(); } static LogicalResult IsNotNestedInJit(PatternRewriter& rewriter, Operation* mul) { return success(mul->getParentOfType()); } static LogicalResult IsScalarTensor(PatternRewriter& rewriter, Value value) { if (auto tensor = value.getType().dyn_cast()) { return success(tensor.getNumElements() == 1); } return failure(); } static float mha_scale_max_diff = 1e-5; static LogicalResult IsScalarEqualSqrtDim(PatternRewriter& rewriter, Value query_reshape, Attribute scalar_div_operand) { auto query_reshape_shape = query_reshape.getType().dyn_cast(); double scalar_div_operand_attr = scalar_div_operand.cast().getValueAsDouble(); return success( std::abs(std::sqrt(query_reshape_shape.getShape().back()) - scalar_div_operand_attr) < mha_scale_max_diff); } static LogicalResult IsScalarEqualSqrtDimReciprocal(PatternRewriter& rewriter, Value query_reshape, Attribute scalar_div_operand) { auto query_reshape_shape = query_reshape.getType().dyn_cast(); double scalar_div_operand_attr = scalar_div_operand.cast().getValueAsDouble(); return success( 
std::abs(std::sqrt(query_reshape_shape.getShape().back()) - (1 / scalar_div_operand_attr)) < mha_scale_max_diff); } static Attribute GetReciprocal(PatternRewriter& rewriter, Attribute a) { return rewriter.getF64FloatAttr(1 / a.cast().getValueAsDouble()); } } // namespace namespace rewrites { void populateRewrites(RewritePatternSet& patterns) { patterns.getPDLPatterns().registerRewriteFunction("BuildFusedBiasAddMaskScaleOpWithRate", BuildFusedBiasAddMaskScaleOpWithRate); patterns.getPDLPatterns().registerRewriteFunction("CopyUserOpAttrs", CopyUserOpAttrs); patterns.getPDLPatterns().registerRewriteFunction("GetHeadSizeFromTranpose", GetHeadSizeFromTranpose); patterns.getPDLPatterns().registerRewriteFunction("CreateConv2dAndErasePad", CreateConv2dAndErasePad); patterns.getPDLPatterns().registerRewriteFunction("CreateConv2DBatchNorm", CreateConv2DBatchNorm); patterns.getPDLPatterns().registerRewriteFunction("GetReciprocal", GetReciprocal); } mlir::IntegerAttr GetDefaultSeed(::mlir::PatternRewriter& rewriter) { const auto gen = CHECK_JUST(::oneflow::one::DefaultAutoGenerator()); return getSI64IntegerAttr(rewriter, (int64_t)gen->current_seed()); } } // namespace rewrites namespace constraints { void populateConstraints(RewritePatternSet& patterns) { auto& pdll_patterns = patterns.getPDLPatterns(); #define PDLL_REGISTER(NAME) pdll_patterns.registerConstraintFunction(#NAME, NAME); PDLL_REGISTER(IsPaddingCouldBeAssimilatedIntoConv); PDLL_REGISTER(IsNotNestedInJit); PDLL_REGISTER(IsScalarTensor); PDLL_REGISTER(IsScalarEqualSqrtDim); PDLL_REGISTER(IsScalarEqualSqrtDimReciprocal); #undef PDLL_REGISTER } } // namespace constraints } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowSupport.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OneFlowTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "oneflow/ir/include/OneFlow/OneFlowSupport.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/common/data_type.h" #include #include namespace mlir { namespace oneflow { namespace support { std::vector GetInputKeys(const std::string& op_type_name) { std::vector ret{}; for (auto& arg : getUserOpDef(op_type_name).input()) { ret.push_back(arg.name()); } return ret; } std::vector GetOutputKeys(const std::string& op_type_name) { std::vector ret{}; for (auto& arg : getUserOpDef(op_type_name).output()) { ret.push_back(arg.name()); } return ret; } namespace { ::oneflow::Symbol<::oneflow::Device> MakeDevice(const mlir::Attribute& device_tag_attr, const mlir::Attribute& device_name_attr) { const auto device_tag = device_tag_attr.cast().str(); const auto device_name = device_name_attr.cast().getValue().front().cast().str(); const std::string device_info = device_tag == "gpu" ? "cuda" : device_tag + device_name.substr(device_name.rfind(":")); return ::oneflow::Device::ParseAndNew(device_info).GetOrThrow(); } template mlir::DenseElementsAttr __TensorToDenseElementsAttr( const std::shared_ptr<::oneflow::one::Tensor>& tensor, const MLIR_T& mlir_type) { ::oneflow::LazyMode::Guard guard{false}; const auto tensor_ = ::oneflow::one::functional::ToContiguous(tensor).GetPtrOrThrow(); auto shape = tensor_->shape(); std::vector shape_vec(shape->dim_vec().begin(), shape->dim_vec().end()); std::vector data(shape->elem_cnt()); const auto& callback = [&](::oneflow::ep::Stream* stream, const std::shared_ptr<::oneflow::vm::EagerBlobObject>& eager_blob_object) { ::oneflow::AutoMemcpy(stream, data.data(), eager_blob_object->dptr(), data.size() * sizeof(T), ::oneflow::memory::MakeHostMemCase(), eager_blob_object->mem_case()); }; ::oneflow::one::SyncAccessTensorWithTimeOut(tensor_, callback, "const").GetOrThrow(); return mlir::DenseElementsAttr::get(mlir::RankedTensorType::get(shape_vec, mlir_type), llvm::makeArrayRef(data)); } template std::shared_ptr<::oneflow::one::Tensor> __DenseElementsAttrToTensor( const mlir::DenseElementsAttr dense_attr, const mlir::Attribute& device_tag_attr, const mlir::Attribute& device_name_attr, const ::oneflow::DataType& dtype) { const auto dense_type = dense_attr.getType().cast(); std::vector shape = dense_type.getShape().vec(); const auto device = MakeDevice(device_tag_attr, device_name_attr); std::shared_ptr<::oneflow::one::Tensor> tensor = ::oneflow::one::functional::Empty( ::oneflow::Shape(::oneflow::DimVector(shape.begin(), shape.end())), ::oneflow::DType::Get(dtype).GetOrThrow(), device, /*requires_grad=*/false, /*pin_memory=*/false) .GetPtrOrThrow(); std::vector data(dense_attr.getValues().begin(), dense_attr.getValues().end()); const auto& callback = [&](::oneflow::ep::Stream* stream, const std::shared_ptr<::oneflow::vm::EagerBlobObject>& eager_blob_object) { ::oneflow::AutoMemcpy(stream, eager_blob_object->mut_dptr(), data.data(), tensor->shape()->elem_cnt() 
* sizeof(T),
                              eager_blob_object->mem_case(), ::oneflow::memory::MakeHostMemCase());
      };
  ::oneflow::one::SyncAccessTensorWithTimeOut(tensor, callback, "mut").GetOrThrow();
  return tensor;
}

template<typename T>
void __DenseElementsAttrToTensor(const mlir::DenseElementsAttr dense_attr,
                                 const mlir::Attribute& device_tag_attr,
                                 const mlir::Attribute& device_name_attr,
                                 const ::oneflow::DataType& dtype,
                                 std::shared_ptr<::oneflow::one::Tensor>& tensor) {
  const auto dense_type = dense_attr.getType().cast<mlir::RankedTensorType>();
  std::vector<int64_t> shape = dense_type.getShape().vec();
  int ndim = shape.size();
  CHECK_EQ(tensor->shape()->size(), ndim);
  for (int i = 0; i < ndim; ++i) { CHECK_EQ(tensor->shape()->at(i), shape[i]); }
  const auto device = MakeDevice(device_tag_attr, device_name_attr);
  CHECK(CHECK_JUST(tensor->device()) == device);
  std::vector<T> data;
  std::vector<::oneflow::float16> fp16_data;
  void* dptr = nullptr;
  const size_t tensor_size =
      tensor->shape()->elem_cnt() * ::oneflow::GetSizeOfDataType(tensor->dtype()->data_type());
  CHECK_EQ(::oneflow::GetDataType<T>::value, dtype);
  if (tensor->dtype()->data_type() == ::oneflow::DataType::kFloat16) {
    for (const T elem : dense_attr.getValues<T>()) {
      fp16_data.push_back(static_cast<::oneflow::float16>(elem));
    }
    CHECK_EQ(fp16_data.size() * sizeof(::oneflow::float16), tensor_size);
    dptr = fp16_data.data();
  } else if (tensor->dtype()->data_type() == dtype) {
    for (const T elem : dense_attr.getValues<T>()) { data.push_back(elem); }
    CHECK_EQ(data.size() * sizeof(T), tensor_size);
    dptr = data.data();
  } else {
    UNIMPLEMENTED();
  }
  const auto& callback =
      [=](::oneflow::ep::Stream* stream,
          const std::shared_ptr<::oneflow::vm::EagerBlobObject>& eager_blob_object) {
        ::oneflow::AutoMemcpy(stream, eager_blob_object->mut_dptr(), dptr, tensor_size,
                              eager_blob_object->mem_case(), ::oneflow::memory::MakeHostMemCase());
      };
  ::oneflow::one::SyncAccessTensorWithTimeOut(tensor, callback, "mut").GetOrThrow();
}

}  // namespace

mlir::DenseElementsAttr TensorToDenseElementsAttr(
    const std::shared_ptr<::oneflow::one::Tensor>& tensor, MLIRContext* ctx) {
  const auto dtype = tensor->dtype()->data_type();
  if (dtype == ::oneflow::DataType::kFloat) {
    return __TensorToDenseElementsAttr<float>(tensor, mlir::FloatType::getF32(ctx));
  } else if (dtype == ::oneflow::DataType::kInt64) {
    auto mlir_type =
        mlir::IntegerType::get(ctx, 64, mlir::IntegerType::SignednessSemantics::Signed);
    return __TensorToDenseElementsAttr<int64_t>(tensor, mlir_type);
  }
  llvm::errs()
      << "Converting oneflow::Tensor to mlir::DenseElementsAttr only supports float32 and int64 now."
      << "\n";
  exit(EXIT_FAILURE);
}

std::shared_ptr<::oneflow::one::Tensor> DenseElementsAttrToTensor(
    const mlir::Attribute& dense_attr, const mlir::Attribute& device_tag_attr,
    const mlir::Attribute& device_name_attr) {
  ::oneflow::LazyMode::Guard guard{false};
  const auto dense_attr_ = dense_attr.cast<mlir::DenseElementsAttr>();
  const auto dense_element_type = dense_attr_.getElementType();
  if (dense_element_type.isF32()) {
    return __DenseElementsAttrToTensor<float>(dense_attr_, device_tag_attr, device_name_attr,
                                              ::oneflow::DataType::kFloat);
  }
  llvm::errs()
      << "Converting mlir::DenseElementsAttr to oneflow::Tensor only supports float32 now."
<< "\n"; exit(EXIT_FAILURE); } void DenseElementsAttrToTensor(const mlir::Attribute& dense_attr, const mlir::Attribute& device_tag_attr, const mlir::Attribute& device_name_attr, std::shared_ptr<::oneflow::one::Tensor>& tensor) { ::oneflow::LazyMode::Guard guard{false}; const auto dense_attr_ = dense_attr.cast(); const auto dense_element_type = dense_attr_.getElementType(); if (dense_element_type.isF32()) { __DenseElementsAttrToTensor(dense_attr_, device_tag_attr, device_name_attr, ::oneflow::DataType::kFloat, tensor); } else { llvm::errs() << "Converting mlir::DenseElementsAttr to oneflow::Tensor only support float32 " "and int64 now." << "\n"; exit(EXIT_FAILURE); } } FailureOr<::oneflow::DataType> FromMLIRTypeToOFDataType(Type mlir_type) { if (mlir_type.dyn_cast()) { return ::oneflow::DataType::kInvalidDataType; } if (mlir_type.dyn_cast()) { return ::oneflow::DataType::kChar; } if (mlir_type.dyn_cast()) { return ::oneflow::DataType::kOFRecord; } if (mlir_type.dyn_cast()) { return ::oneflow::DataType::kTensorBuffer; } if (mlir_type.isF16()) { return ::oneflow::DataType::kFloat16; } if (mlir_type.isF32()) { return ::oneflow::DataType::kFloat; } if (mlir_type.isF64()) { return ::oneflow::DataType::kDouble; } if (mlir_type.isSignlessInteger(8)) { return ::oneflow::DataType::kBool; } if (mlir_type.isSignlessInteger(16)) { return ::oneflow::DataType::kUInt16; } if (mlir_type.isSignlessInteger(32)) { return ::oneflow::DataType::kUInt32; } if (mlir_type.isSignlessInteger(64)) { return ::oneflow::DataType::kUInt64; } if (mlir_type.isSignlessInteger(128)) { return ::oneflow::DataType::kUInt128; } if (mlir_type.isSignedInteger(8)) { return ::oneflow::DataType::kInt8; } if (mlir_type.isSignedInteger(16)) { return ::oneflow::DataType::kInt16; } if (mlir_type.isSignedInteger(32)) { return ::oneflow::DataType::kInt32; } if (mlir_type.isSignedInteger(64)) { return ::oneflow::DataType::kInt64; } if (mlir_type.isSignedInteger(128)) { return ::oneflow::DataType::kInt128; } llvm::errs() << "Unsupported data type: " << mlir_type << "\n"; return failure(); } FailureOr<::oneflow::DataType> FromMLIRDataTypeToOFDataType(::mlir::oneflow::DataType data_type) { switch (data_type) { case ::mlir::oneflow::DataType::DT_InvalidDataType: return ::oneflow::DataType::kInvalidDataType; #define DEFINE_ONE_CASE(datatype) \ case ::mlir::oneflow::DataType::DT_##datatype: return ::oneflow::DataType::k##datatype; DEFINE_ONE_CASE(Char) DEFINE_ONE_CASE(Float) DEFINE_ONE_CASE(Double) DEFINE_ONE_CASE(Int8) DEFINE_ONE_CASE(Int32) DEFINE_ONE_CASE(Int64) DEFINE_ONE_CASE(UInt8) DEFINE_ONE_CASE(OFRecord) DEFINE_ONE_CASE(Float16) DEFINE_ONE_CASE(TensorBuffer) DEFINE_ONE_CASE(Bool) #undef DEFINE_ONE_CASE default: { return failure(); } } return failure(); } FailureOr<::oneflow::DataType> FromMLIRAttrToOFDataType(Attribute attr) { const auto data_type_attr = attr.dyn_cast(); return FromMLIRDataTypeToOFDataType(data_type_attr.getValue()); } const ::oneflow::UserOpDef& getUserOpDef(const std::string& op_type_name) { const ::oneflow::user_op::OpRegistryResult* val = ::oneflow::user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name); CHECK(val) << " Cannot find op_type_name: " << op_type_name; return val->op_def; } } // namespace support } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowTypes.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowTypes.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" #define GET_TYPEDEF_CLASSES #include "OneFlow/OneFlowOpsTypes.cpp.inc" ================================================ FILE: oneflow/ir/lib/OneFlow/OneFlowUtils.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowUtils.h" #include "oneflow/core/common/util.h" namespace mlir { namespace oneflow { void CheckEnableIRPrinting(mlir::PassManager& pm) { bool enable_ir_printing = ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_ENABLE_IR_PRINTING", false); pm.getContext()->disableMultithreading(enable_ir_printing); if (enable_ir_printing) { pm.enableIRPrinting(); } } StringRef SanitizeIdentifier(StringRef name, SmallString<16>& buffer, StringRef allowedPunctChars, bool allowTrailingDigit) { assert(!name.empty() && "Shouldn't have an empty name here"); auto copyNameToBuffer = [&] { for (char ch : name) { if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch)) buffer.push_back(ch); else if (ch == ' ') buffer.push_back('_'); else buffer.append(llvm::utohexstr((unsigned char)ch)); } }; // Check to see if this name is valid. If it starts with a digit, then it // could conflict with the autogenerated numeric ID's, so add an underscore // prefix to avoid problems. if (isdigit(name[0])) { buffer.push_back('_'); copyNameToBuffer(); return buffer; } // If the name ends with a trailing digit, add a '_' to avoid potential // conflicts with autogenerated ID's. if (!allowTrailingDigit && isdigit(name.back())) { copyNameToBuffer(); buffer.push_back('_'); return buffer; } // Check to see that the name consists of only valid identifier characters. for (char ch : name) { if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) { copyNameToBuffer(); return buffer; } } // If there are no invalid characters, return the original name. return name; } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "OneFlow/OneFlowPDLLPatterns.h" #include "mlir/IR/Value.h" using namespace mlir; #include "oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.h.inc" namespace mlir { namespace oneflow { void populateAllocEliminationPatterns(RewritePatternSet& patterns) { populateGeneratedPDLLPatterns(patterns); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.pdll ================================================ #include "OneFlow/OneFlowOps.td" Constraint IsFuncArguments(value: Value) [{ return success(llvm::dyn_cast(value)); }]; Pattern { arg: Value; let alloc = op(); let copy = op(alloc.0, arg); IsFuncArguments(arg); rewrite alloc with { erase copy; replace alloc with arg; }; } ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/CMakeLists.txt ================================================ add_mlir_pdll_library(MLIROneFlowPDLLAllocEliminaionPatternsIncGen AllocEliminationPatterns.pdll AllocEliminationPatterns.h.inc) add_mlir_pdll_library(MLIROneFlowPDLLForwardOpPatternsIncGen ForwardOpPatterns.pdll ForwardOpPatterns.h.inc) add_mlir_pdll_library(MLIROneFlowPDLLNormalizationPatternsIncGen NormalizationPatterns.pdll NormalizationPatterns.h.inc) add_mlir_pdll_library(MLIROneFlowPDLLFuseConv2DBatchNormPatternIncGen FuseConv2DBatchNormPattern.pdll FuseConv2DBatchNormPattern.h.inc) add_mlir_pdll_library(MLIROneFlowPDLLFuseOpsWithBackwardImplPatternsIncGen FuseOpsWithBackwardImplPattern.pdll FuseOpsWithBackwardImplPattern.h.inc) oneflow_add_mlir_dialect_library( MLIROneFlowPDLLPatterns AllocEliminationPatterns.cpp ForwardOpPatterns.cpp NormalizationPatterns.cpp FuseConv2DBatchNormPattern.cpp FuseOpsWithBackwardImplPattern.cpp DEPENDS MLIROneFlowPDLLAllocEliminaionPatternsIncGen MLIROneFlowPDLLForwardOpPatternsIncGen MLIROneFlowPDLLNormalizationPatternsIncGen MLIROneFlowPDLLFuseConv2DBatchNormPatternIncGen MLIROneFlowPDLLFuseOpsWithBackwardImplPatternsIncGen) ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "OneFlow/OneFlowPDLLPatterns.h" using namespace mlir; #include "oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.h.inc" namespace mlir { namespace oneflow { void populateForwardOpPatterns(RewritePatternSet& patterns) { populateGeneratedPDLLPatterns(patterns); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.pdll ================================================ #include "OneFlow/OneFlowOps.td" #include "OneFlowPDLLUtils.pdll" Pattern { let rate: Attr; let device_name: Attr; let device_tag: Attr; let axis: Attr; let dropout = op( op(a: Value, b: Value) {axis = axis, device_name = device_name, device_tag = device_tag}) {rate = rate, device_name = device_name, device_tag = device_tag} -> (out: Type, mask: Type); rewrite dropout with { let random_mask_like = CopyUserOpAttrs(dropout, op(a){rate = rate} -> (mask)); let fused_bias_add_mask_scale = CopyUserOpAttrs(dropout, BuildFusedBiasAddMaskScaleOpWithRate(a, b, random_mask_like.0, axis, rate, dropout)); replace dropout with (fused_bias_add_mask_scale.0, random_mask_like.0); }; } Pattern { let device_name: Attr; let device_tag: Attr; let axis: Attr; let gelu = op( op(a: Value, b: Value) {axis = axis, device_name = device_name, device_tag = device_tag}) {device_name = device_name, device_tag = device_tag} -> (out: Type); rewrite gelu with{ replace gelu with CopyUserOpAttrs(gelu, op(a, b){axis = axis} -> (out)); }; } Pattern { let device_name: Attr; let device_tag = attr<"\"cuda\"">; let scalar_div_operand: Attr; let out_shape: Attr; let query: Value; let key: Value; let value: Value; let query_reshape = op(query) {device_name = device_name, device_tag = device_tag}; let key_reshape = op(key) {device_name = device_name, device_tag = device_tag}; let value_reshape = op(value) {device_name = device_name, device_tag = device_tag}; let query_transpose = op(query_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 1 : si32, 3 : si32]">}; let key_transpose = op(key_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 3 : si32, 1 : si32]">}; let value_transpose = op(value_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 1 : si32, 3 : si32]">}; let scores = op(query_transpose.0, key_transpose.0) {alpha = attr<"1.000000e+00 : f64">, device_name = device_name, device_tag = device_tag, transpose_a = attr<"false">, transpose_b = attr<"false">}; let scores_scaled = op(scores.0) {device_name = device_name, device_tag = device_tag, float_operand = scalar_div_operand, has_float_operand = attr<"true">}; let attn = op(scores_scaled.0) {device_name = device_name, device_tag = device_tag}; let out = op(attn.0, value_transpose.0) {alpha = attr<"1.000000e+00 : f64">, device_name = device_name, device_tag = device_tag, transpose_a = attr<"false">, transpose_b = attr<"false">}; let out_transpose = op(out.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 1 : si32, 3 : si32]">}; let out_reshape = op(out_transpose.0) {device_name = device_name, device_tag = device_tag, shape = out_shape} -> (out_t: Type); IsScalarEqualSqrtDim(query_reshape.0, 
scalar_div_operand); rewrite out_reshape with{ replace out_reshape with CopyUserOpAttrs(out, op(query, key, value) { attn_mask_type = attr<"\"none\"">, query_max_seq_len = attr<"0 : si64">, key_max_seq_len = attr<"0 : si64">, causal_diagonal_offset = attr<"0 : si64">, query_head_size = GetHeadSizeFromTranpose(query_transpose), query_layout = attr<"\"BM(HK)\"">, key_layout = attr<"\"BM(HK)\"">, value_layout = attr<"\"BM(HK)\"">, output_layout = attr<"\"BM(HK)\"">, operand_segment_sizes = attr<"array">, scale = GetReciprocal(scalar_div_operand) } -> (out_t)); }; } Pattern { let device_name: Attr; let device_tag = attr<"\"cuda\"">; let batch_matmul_alpha: Attr; let out_shape: Attr; let query: Value; let key: Value; let value: Value; let value_reshape = op(value) {device_name = device_name, device_tag = device_tag}; let key_reshape = op(key) {device_name = device_name, device_tag = device_tag}; let query_reshape = op(query) {device_name = device_name, device_tag = device_tag}; let value_permute = op(value_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 1 : si32, 3 : si32]">}; let key_permute = op(key_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 1 : si32, 3 : si32]">}; let query_permute = op(query_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 1 : si32, 3 : si32]">}; let value_reshape_to_batch = op(value_permute.0) {device_name = device_name, device_tag = device_tag}; let key_reshape_to_batch = op(key_permute.0) {device_name = device_name, device_tag = device_tag}; let query_reshape_to_batch = op(query_permute.0) {device_name = device_name, device_tag = device_tag}; let key_transpose = op(key_reshape_to_batch.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 1 : si32]">}; let scores_scaled = op(query_reshape_to_batch.0, key_transpose.0) {alpha = batch_matmul_alpha, device_name = device_name, device_tag = device_tag, transpose_a = attr<"false">, transpose_b = attr<"false">}; let attn = op(scores_scaled.0) {device_name = device_name, device_tag = device_tag}; let out = op(attn.0, value_reshape_to_batch.0) {alpha = attr<"1.000000e+00 : f64">, device_name = device_name, device_tag = device_tag, transpose_a = attr<"false">, transpose_b = attr<"false">}; let out_reshape_before = op(out.0) {device_name = device_name, device_tag = device_tag}; let out_transpose = op(out_reshape_before.0) {device_name = device_name, device_tag = device_tag, perm = attr<"[0 : si32, 2 : si32, 1 : si32, 3 : si32]">}; let out_reshape = op(out_transpose.0) {device_name = device_name, device_tag = device_tag, shape = out_shape} -> (out_t: Type); IsScalarEqualSqrtDimReciprocal(query_reshape.0, batch_matmul_alpha); rewrite out_reshape with{ replace out_reshape with CopyUserOpAttrs(out, op(query, key, value) { attn_mask_type = attr<"\"none\"">, query_max_seq_len = attr<"0 : si64">, key_max_seq_len = attr<"0 : si64">, causal_diagonal_offset = attr<"0 : si64">, query_head_size = GetHeadSizeFromTranpose(query_permute), query_layout = attr<"\"BM(HK)\"">, key_layout = attr<"\"BM(HK)\"">, value_layout = attr<"\"BM(HK)\"">, output_layout = attr<"\"BM(HK)\"">, operand_segment_sizes = attr<"array">, scale = batch_matmul_alpha } -> (out_t)); }; } Pattern { let device_name: Attr; let device_tag: Attr; let padding_before: Attr; let padding_after: Attr; let data_format: Attr; let conv = op( op(x: Value){device_name = device_name, device_tag = 
device_tag, padding_before = padding_before, padding_after = padding_after}, weight: Value) {device_name = device_name, device_tag = device_tag, data_format = data_format}; IsPaddingCouldBeAssimilatedIntoConv(padding_before, padding_after, data_format); rewrite conv with{ let conv2d_and_erase_pad = CreateConv2dAndErasePad(x, weight, padding_before, data_format, conv); replace conv with CopyUserOpAttrs(conv, conv2d_and_erase_pad); }; } Pattern { let valueType: Type; let x: Value; let cast = op(x) -> (valueType); replace cast with x; } Pattern { let device_name: Attr; let has_float_operand: Attr; let int_operand: Attr; let float_operand: Attr; let diagonal: Attr; let floating_fill_value: Attr; let integer_fill_value: Attr; let is_floating_fill_value: Attr; let tril = op( op(x: Value) {device_name = device_name, device_tag = attr<"\"cuda\"">, has_float_operand = has_float_operand, int_operand = int_operand, float_operand = float_operand}) {device_name = device_name, device_tag = attr<"\"cuda\"">, diagonal = diagonal, floating_fill_value = floating_fill_value, integer_fill_value =integer_fill_value, is_floating_fill_value = is_floating_fill_value} -> (out: Type); replace tril with CopyUserOpAttrs(tril, CreatScaleTrilOp(x, diagonal, floating_fill_value, integer_fill_value, is_floating_fill_value, float_operand ,int_operand, has_float_operand, out)); } Pattern { let device_name: Attr; let has_float_operand: Attr; let int_operand: Attr; let float_operand: Attr; let diagonal: Attr; let floating_fill_value: Attr; let integer_fill_value: Attr; let is_floating_fill_value: Attr; let scalar = op( op(x: Value) {device_name = device_name, device_tag = attr<"\"cuda\"">, diagonal = diagonal, floating_fill_value = floating_fill_value, integer_fill_value =integer_fill_value, is_floating_fill_value = is_floating_fill_value }) {device_name = device_name, device_tag = attr<"\"cuda\"">, has_float_operand = has_float_operand, int_operand = int_operand, float_operand = float_operand} -> (out: Type); replace scalar with CopyUserOpAttrs(scalar, CreatScaleTrilOp(x, diagonal, floating_fill_value, integer_fill_value, is_floating_fill_value, float_operand ,int_operand, has_float_operand, out)); } Pattern { let device_name: Attr; let device_tag: Attr; let broadcast_mul = op(x: Value, y: Value){device_name = device_name, device_tag = device_tag}-> (out: Type); IsScalarTensor(y); rewrite broadcast_mul with{ let scalar_mul = op(x, y) {device_name = device_name, device_tag = device_tag} -> (out); replace broadcast_mul with CopyUserOpAttrs(broadcast_mul, scalar_mul); }; } ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/FuseConv2DBatchNormPattern.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "OneFlow/OneFlowPDLLPatterns.h" using namespace mlir; #include "oneflow/ir/lib/OneFlow/PDLL/FuseConv2DBatchNormPattern.h.inc" namespace mlir { namespace oneflow { void populateFuseConv2DBatchNormPattern(RewritePatternSet& patterns) { populateGeneratedPDLLPatterns(patterns); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/FuseConv2DBatchNormPattern.pdll ================================================ #include "OneFlowPDLLUtils.pdll" Pattern { let device_name: Attr; let device_tag: Attr; let epsilon: Attr; let moving_mean: Value; let moving_variance: Value; let beta: Value; let weight = op; let gamma = op; let conv = op(x: Value, weight.0){device_name = device_name, device_tag = device_tag}; let normalization = op(conv, moving_mean, moving_variance, gamma.0, beta) {device_name = device_name, device_tag = device_tag, epsilon = epsilon} -> (y: Type); rewrite normalization with{ let conv2d_bn = CreateConv2DBatchNorm(epsilon, conv, normalization); replace normalization with CopyUserOpAttrs(normalization, conv2d_bn); }; } ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/FuseOpsWithBackwardImplPattern.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "OneFlow/OneFlowPDLLPatterns.h" using namespace mlir; #include "oneflow/ir/lib/OneFlow/PDLL/FuseOpsWithBackwardImplPattern.h.inc" namespace mlir { namespace oneflow { void populateFuseOpsWithBackwardImplPattern(RewritePatternSet& patterns) { populateGeneratedPDLLPatterns(patterns); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/FuseOpsWithBackwardImplPattern.pdll ================================================ #include "OneFlowPDLLUtils.pdll" Pattern { let device_name: Attr; let device_tag: Attr; let matmul_wx = op(x: Value, w: Value){device_name = device_name, device_tag = device_tag, alpha = attr<"1.000000e+00 : f64">}; let matmul_wx_add = op(matmul_wx.0, b: Value){device_name = device_name, device_tag = device_tag} -> (matmul_wx_out: Type); let hidden_states = op(matmul_wx_add.0){device_name = device_name, device_tag = device_tag}; let gate = op(matmul_wx_add.0){device_name = device_name, device_tag = device_tag}; let gate_activate = op(gate.0){device_name = device_name, device_tag = device_tag}; let gelu_out = op(hidden_states.0,gate_activate.0){device_name = device_name, device_tag = device_tag} -> (out: Type); rewrite gelu_out with{ let fused_gelu_out = op(x, w, b){activation = attr<"\"gelu\"">, operand_segment_sizes = attr<"array">, device_name = device_name, device_tag = device_tag, has_bias = attr<"true">, is_split = attr<"false">}-> (out, matmul_wx_out); CopyUserOpAttrs(gelu_out, fused_gelu_out); replace gelu_out with fused_gelu_out.0; replace matmul_wx_add with fused_gelu_out.1; }; } Pattern { let device_name: Attr; let device_tag: Attr; let matmul_wx_add = op(x: Value, w: Value, b: Value){device_name = device_name, device_tag = device_tag, alpha = attr<"1.000000e+00 : f64">} -> (matmul_wx_out: Type); let hidden_states = op(matmul_wx_add.0){device_name = device_name, device_tag = device_tag}; let gate = op(matmul_wx_add.0){device_name = device_name, device_tag = device_tag}; let gate_activate = op(gate.0){device_name = device_name, device_tag = device_tag}; let gelu_out = op(hidden_states.0,gate_activate.0){device_name = device_name, device_tag = device_tag}-> (out: Type); rewrite gelu_out with{ let fused_gelu_out = op(x, w, b){activation = attr<"\"gelu\"">, operand_segment_sizes = attr<"array">, device_name = device_name, device_tag = device_tag}-> (out, matmul_wx_out); CopyUserOpAttrs(gelu_out, fused_gelu_out); replace gelu_out with fused_gelu_out.0; replace matmul_wx_add with fused_gelu_out.1; }; } ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/NormalizationPatterns.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "OneFlow/OneFlowPDLLPatterns.h" using namespace mlir; #include "oneflow/ir/lib/OneFlow/PDLL/NormalizationPatterns.h.inc" namespace mlir { namespace oneflow { void populateNormalizationOpPatterns(RewritePatternSet& patterns) { populateGeneratedPDLLPatterns(patterns); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/NormalizationPatterns.pdll ================================================ #include "OneFlowPDLLUtils.pdll" Pattern { let device_name: Attr; let device_tag: Attr; let axis: Attr; let epsilon: Attr; let training = attr<"true">; let momentum: Attr; let x: Value; let moving_mean: Value; let moving_variance: Value; let gamma: Value; let beta: Value; let addend: Value; let normalization = op(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type, mean: Type, inv_variance: Type); let relu = op( op(normalization.0, addend) {device_name = device_name, device_tag = device_tag}) {device_name = device_name, device_tag = device_tag} -> (out: Type); rewrite relu with{ let fused_bn = CopyUserOpAttrs(normalization, op(x, addend, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, result_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag}); replace relu with fused_bn.0; }; } Pattern { let device_name: Attr; let device_tag: Attr; let axis: Attr; let epsilon: Attr; let training = attr<"true">; let momentum: Attr; let x: Value; let moving_mean: Value; let moving_variance: Value; let gamma: Value; let beta: Value; let addend: Value; let normalization = op(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type); let relu = op( op(normalization.0, addend) {device_name = device_name, device_tag = device_tag}) {device_name = device_name, device_tag = device_tag} -> (out: Type); rewrite relu with{ let fused_bn = CopyUserOpAttrs(normalization, op(x, addend, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, result_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag}); replace relu with fused_bn.0; }; } Pattern { let device_name: Attr; let device_tag: Attr; let axis: Attr; let epsilon: Attr; let training = attr<"false">; let momentum: Attr; let x: Value; let moving_mean: Value; let moving_variance: Value; let gamma: Value; let beta: Value; let addend: Value; let normalization = op(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type, mean: Type, inv_variance: Type); let relu = op( op(normalization.0, addend) {device_name = device_name, device_tag = device_tag}) {device_name = device_name, device_tag = device_tag} -> 
(out: Type); rewrite relu with{ let fused_bn = CopyUserOpAttrs(normalization, op(x, addend, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, result_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag}); replace relu with fused_bn.0; }; } Pattern { let device_name: Attr; let device_tag: Attr; let axis: Attr; let epsilon: Attr; let training = attr<"false">; let momentum: Attr; let x: Value; let moving_mean: Value; let moving_variance: Value; let gamma: Value; let beta: Value; let addend: Value; let normalization = op(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type); let relu = op( op(normalization.0, addend) {device_name = device_name, device_tag = device_tag}) {device_name = device_name, device_tag = device_tag} -> (out: Type); rewrite relu with{ let fused_bn = CopyUserOpAttrs(normalization, op(x, addend, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, result_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag}); replace relu with fused_bn.0; }; } Pattern { let device_name: Attr; let device_tag: Attr; let axis: Attr; let epsilon: Attr; let training = attr<"false">; let momentum: Attr; let x: Value; let moving_mean: Value; let moving_variance: Value; let gamma: Value; let beta: Value; let normalization = op(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type, mean: Type, inv_variance: Type); let relu = op(normalization.0) {device_name = device_name, device_tag = device_tag} -> (out: Type); rewrite relu with{ let fused_bn = CopyUserOpAttrs(normalization, op(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, result_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag}); replace relu with fused_bn.0; }; } Pattern { let device_name: Attr; let device_tag: Attr; let axis: Attr; let epsilon: Attr; let training = attr<"false">; let momentum: Attr; let x: Value; let moving_mean: Value; let moving_variance: Value; let gamma: Value; let beta: Value; let normalization = op(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type); let relu = op(normalization.0) {device_name = device_name, device_tag = device_tag} -> (out: Type); rewrite relu with{ let fused_bn = CopyUserOpAttrs(normalization, op(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<"array">, result_segment_sizes = attr<"array">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag}); replace relu with fused_bn.0; }; } ================================================ FILE: oneflow/ir/lib/OneFlow/PDLL/OneFlowPDLLUtils.pdll ================================================ Rewrite BuildFusedBiasAddMaskScaleOpWithRate(a: Value, b: Value, 
mask: Value, axis: Attr, rate: Attr, dropout: Op) -> Op; Rewrite CopyUserOpAttrs(src: Op, dst: Op) -> Op; Rewrite GetHeadSizeFromTranpose(transpose: Op) -> Attr; Rewrite CreateConv2dAndErasePad(x: Value, weight: Value, padding_before: Attr, data_format: Attr, conv: Op) -> Op; Rewrite CreatScaleTrilOp(x: Value, diagonal: Attr, floating_fill_value: Attr, integer_fill_value: Attr, is_floating_fill_value: Attr, float_operand: Attr, int_operand: Attr, has_float_operand: Attr, out: Type) -> Op { let floating_scale_value = float_operand; let integer_scale_value = int_operand; let is_floating_scale_value = has_float_operand; let scale_tril_op = op(x){diagonal = diagonal, floating_fill_value = floating_fill_value, integer_fill_value = integer_fill_value, is_floating_fill_value = is_floating_fill_value, floating_scale_value = floating_scale_value, integer_scale_value = integer_scale_value, is_floating_scale_value = is_floating_scale_value} -> (out); return scale_tril_op; } Rewrite CreateConv2DBatchNorm(epsilon: Attr, conv: Op, bn: Op) -> Op; Constraint IsPaddingCouldBeAssimilatedIntoConv(padding_before: Attr, padding_after: Attr, data_format:Attr); Constraint IsNotNestedInJit(mul: Op); Constraint IsScalarTensor(value: Value); Constraint IsScalarEqualSqrtDim(query_reshape: Value, scalar_div_operand: Attr); Constraint IsScalarEqualSqrtDimReciprocal(query_reshape: Value, scalar_div_operand: Attr); Rewrite GetReciprocal(a: Attr) -> Attr; ================================================ FILE: oneflow/ir/lib/OneFlow/Passes.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "OneFlow/Transform/OneFlowMemPool.h" #include "OneFlow/Transform/EliminateAllocOps.h" #include "OneFlow/Transform/OneFlowStream.h" #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" #include "oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/variable_tensor_mgr.h" #include "oneflow/core/operator/variable_op.h" #include "oneflow/core/framework/sbp_context.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/framework/variable_tensor_mgr.h" #include "oneflow/core/operator/variable_op.h" #include "oneflow/core/framework/sbp_context.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowUtils.h" #include "OneFlow/Passes.h" #include "OneFlow/OneFlowUtils.h" #include "OneFlow/OneFlowPatternUtils.h" #include "OneFlow/OneFlowSupport.h" #include "OneFlow/SBP/SBPImporter.h" #include "OneFlow/SBP/SBPAttributes.h" #include "OneFlow/OKL/OKLOps.h" #include "OneFlow/OKL/OKLTypes.h" #include "OneFlow/OKL/Kernel/RegContext.h" #include "OneFlow/OKM/Conversion/Conversion.h" #include "OneFlow/Transform/TransposeHelpers.h" #include "OneFlow/Transform/OutlineAndFuse.h" #include "OneFlow/OneFlowPDLLPatterns.h" #include "OneFlow/OKL/passes.h" #include "OneFlow/OKL/OKLAttributes.h" #include "OneFlow/OKM/passes.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/SymbolTable.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include 
"mlir/IR/IRMapping.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ErrorHandling.h" #include #include #include #include #include #ifdef WITH_MLIR_CUDA_CODEGEN #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" #endif // WITH_MLIR_CUDA_CODEGEN #ifdef WITH_CUDA // enable with_cuda_graphs #include "oneflow/core/ep/cuda/cuda_stream.h" #endif // WITH_CUDA namespace mlir { namespace oneflow { LLVM::LLVMPointerType GetPtr(::mlir::PatternRewriter& rewriter) { return LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8)); } template LogicalResult DumpAssembly(::mlir::PatternRewriter& rewriter, T op, StringRef func_name) { // TODO: now we only need one JIT engine auto parent_func_op = op->template getParentOfType(); if (!parent_func_op) { return failure(); } auto parent_module_op = parent_func_op->template getParentOfType(); if (!parent_module_op) { return failure(); } SymbolTable symbol_table(parent_module_op); std::string mlir; llvm::raw_string_ostream os_mlir(mlir); if (auto found = symbol_table.lookup(func_name)) { found->print(os_mlir); } else { parent_module_op->dump(); return op.emitError("symbol of jit function not found: " + op.getOpName()); } op->setAttr("mlir_assembly", rewriter.getStringAttr(mlir)); return success(); } LLVM::LLVMFuncOp DeclareKernelLaunchCInterface(::mlir::PatternRewriter& rewriter, mlir::Location loc, ModuleOp* module, StringRef c_api_callee, Type llvm_ptr_type) { LLVM::LLVMFuncOp func; if (!(func = module->lookupSymbol(c_api_callee))) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module->getBody()); auto void_type = LLVM::LLVMVoidType::get(rewriter.getContext()); auto func_type = LLVM::LLVMFunctionType::get(void_type, {llvm_ptr_type, llvm_ptr_type}, false); func = rewriter.create(loc, c_api_callee, func_type, LLVM::Linkage::External); func->setAttr("llvm.emit_c_interface", mlir::UnitAttr::get(rewriter.getContext())); } return func; } LLVM::GlobalOp DeclareOrGetGlobalString(::mlir::PatternRewriter& rewriter, mlir::Location loc, ModuleOp* module, StringRef func_name) { LLVM::GlobalOp global; StringRef variable = rewriter.getStringAttr(func_name + "_var"); if (!(global = module->lookupSymbol(variable))) { OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module->getBody()); auto type = LLVM::LLVMArrayType::get(IntegerType::get(rewriter.getContext(), 8), func_name.size()); global = rewriter.create(loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, variable, rewriter.getStringAttr(func_name), /*alignment=*/0); } return global; } template ModuleOp GetModuleOpFromJobBodyOp(Operation* op) { auto parent_func_op = op->getParentOfType(); if (!parent_func_op) { return nullptr; } return parent_func_op->template getParentOfType(); } func::FuncOp 
InsertKernelOFFuncOp(::mlir::PatternRewriter& rewriter, Operation* op, const std::string& func_name) { auto loc = op->getLoc(); auto module = GetModuleOpFromJobBodyOp(op); if (!module) { emitError(loc) << "null ModuleOp " << *op; return nullptr; } IRMapping mapping; OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto func_type = rewriter.getFunctionType(TypeRange(op->getOperandTypes()), TypeRange(op->getResultTypes())); func::FuncOp func = rewriter.create(loc, func_name, func_type); func->setAttr("compiled", rewriter.getStringAttr("true")); func.getBody().emplaceBlock(); for (auto& arg : func_type.getInputs()) { func.getBody().addArguments(arg, loc); } for (auto argument_pair : llvm::zip(ValueRange(op->getOperands()), func.getBody().getArguments())) { mapping.map(std::get<0>(argument_pair), std::get<1>(argument_pair)); } rewriter.setInsertionPointToStart(&func.getBody().front()); ImplicitLocOpBuilder new_block(loc, rewriter); new_block.clone(*op, mapping); SmallVector<::mlir::Value, 4> mapped_results; for (auto result : ValueRange(op->getResults())) { mapped_results.push_back(mapping.lookup(result)); } rewriter.create(loc, mapped_results); return func; } ::llvm::SmallVector<::mlir::Value, 4> CreateGPUMemcpyOpFromMemrefCopy( ::mlir::PatternRewriter& rewriter, ::mlir::memref::CopyOp copyOp) { // NOTE: to get lowered to LLVM, it has to be async ::mlir::ValueRange empty_async_dependencies{}; auto token = rewriter.getType(); auto t0 = rewriter.create(copyOp->getLoc(), token, empty_async_dependencies) .getAsyncToken(); auto t2 = rewriter .create(copyOp->getLoc(), /*optional asyncToken*/ token, /*asyncDependencies*/ llvm::SmallVector({t0}), /*dst*/ copyOp.getTarget(), /*src*/ copyOp.getSource()) .getResults(); rewriter.create(copyOp->getLoc(), llvm::None, t2); return {}; } bool HasZeroPadding(mlir::ArrayAttr padding) { for (auto val : padding.getValue()) { if (val.cast().getValue().getSExtValue() != 0) return false; } return true; } NamedAttrList GetUserOpCommonAttrs(MLIRContext* ctx, const std::string& op_name) { NamedAttrList attrs; attrs.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), StringAttr::get(ctx, op_name)); attrs.set(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), StringAttr::get(ctx, "cpu")); attrs.set(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), ArrayAttr::get(ctx, llvm::to_vector<8>(llvm::map_range(ArrayRef({"@0:0"}), [&](StringRef v) -> Attribute { return StringAttr::get(ctx, v); })))); return attrs; } struct ReplaceVariablePattern : public ::mlir::RewritePattern { explicit ReplaceVariablePattern(::mlir::MLIRContext* context) : ::mlir::RewritePattern("oneflow.variable", 1, context, {"oneflow.variable_ir"}) {} ::mlir::LogicalResult matchAndRewrite(::mlir::Operation* op0, ::mlir::PatternRewriter& rewriter) const override { auto op = ::llvm::dyn_cast(op0); if (!op) return failure(); NamedAttrList attrs; if (op.getOpName().str().find("FreeEagerTensor") != std::string::npos) { return failure(); } attrs.set(StringAttr::get(getContext(), "value"), support::TensorToDenseElementsAttr( CHECK_JUST(::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()->Get( op.getOpName().str(), ::oneflow::DType::Float())), rewriter.getContext())); attrs.set(op.getOpNameAttrName(), op.getOpNameAttr()); attrs.set(op.getDataTypeAttrName(), op.getDataTypeAttr()); attrs.set(op.getDeviceTagAttrName(), op.getDeviceTagAttr()); attrs.set(op.getDeviceNameAttrName(), op.getDeviceNameAttr()); attrs.set(op.getScopeSymbolIdAttrName(), op.getScopeSymbolIdAttr()); 
    attrs.set(op.getHierarchyAttrName(), op.getHierarchyAttr());
    auto name = FrozenVariableOp::getNdSbpAttrName(
        OperationName(FrozenVariableOp::getOperationName(), rewriter.getContext()));
    auto parallel_attr = op.getParallelAttr();
    attrs.set(name, SBPTranslation::ConvertSBPToString(rewriter, parallel_attr));
    auto op_new = rewriter.create<FrozenVariableOp>(op->getLoc(), op.getOutput().getType(),
                                                    ValueRange(), attrs);
    rewriter.replaceOp(op0, op_new->getResults());
    return ::mlir::success();
  }
};

// Inverse of ReplaceVariablePattern: rewrites a frozen oneflow.variable_ir op back into a plain
// oneflow.variable op and pushes its dense constant value back into the VariableTensorMgr.
struct ReplaceVariableIrPattern : public ::mlir::RewritePattern {
  explicit ReplaceVariableIrPattern(::mlir::MLIRContext* context)
      : ::mlir::RewritePattern("oneflow.variable_ir", 1, context, {"oneflow.variable"}) {}
  ::mlir::LogicalResult matchAndRewrite(::mlir::Operation* op0,
                                        ::mlir::PatternRewriter& rewriter) const override {
    auto op = ::llvm::dyn_cast<FrozenVariableOp>(op0);
    if (!op) return failure();
    NamedAttrList attrs;
    const auto tensor_attr = op.getValue();
    attrs.set(StringAttr::get(getContext(), "shape"),
              rewriter.getArrayAttr(llvm::to_vector<8>(llvm::map_range(
                  tensor_attr.getType().cast<mlir::RankedTensorType>().getShape(),
                  [&](int64_t v) -> Attribute {
                    return IntegerAttr::get(rewriter.getIntegerType(64, /*isSigned=*/true),
                                            APInt(64, v, /*isSigned=*/true));
                  }))));
    attrs.set(StringAttr::get(getContext(), "data_type"),
              oneflow::DataTypeAttr::get(getContext(), oneflow::DataType::DT_Float));
    auto output_lbns_attr = rewriter.getStrArrayAttr({op.getOpName().str() + "/out"});
    attrs.set(OpTrait::IsImportCompatible<void>::getOutputLBNsAttr(), output_lbns_attr);
    attrs.set(op.getOpNameAttrName(), op.getOpNameAttr());
    attrs.set(op.getDataTypeAttrName(), op.getDataTypeAttr());
    attrs.set(op.getDeviceTagAttrName(), op.getDeviceTagAttr());
    attrs.set(op.getDeviceNameAttrName(), op.getDeviceNameAttr());
    attrs.set(op.getScopeSymbolIdAttrName(), op.getScopeSymbolIdAttr());
    attrs.set(op.getHierarchyAttrName(), op.getHierarchyAttr());
    auto name = VariableOp::getParallelAttrName(
        OperationName(VariableOp::getOperationName(), rewriter.getContext()));
    auto nd_size = op.getHierarchy()->size();
    ArrayAttr nd_sbp = op.getNdSbp();
    std::vector<std::string> nd_sbp_str;
    std::for_each(nd_sbp.begin(), nd_sbp.end(), [&](Attribute elem) {
      if (auto sbp_str_attr = elem.dyn_cast<StringAttr>()) {
        nd_sbp_str.push_back(sbp_str_attr.str());
      }
    });
    attrs.set(name, SBPTranslation::ConvertNdSbpToPsig(rewriter, nd_sbp_str, nd_size));
    auto op_new = rewriter.create<VariableOp>(op->getLoc(), op.getOutput().getType(), ValueRange(),
                                              attrs);
    const std::string tensor_name = op.getOpNameAttr().str();
    const auto data_type = support::FromMLIRAttrToOFDataType(op.getDataTypeAttr());
    if (failed(data_type)) {
      op0->emitError(::llvm::formatv("unsupported data type: {0}",
                                     ConvertToString(op.getDataTypeAttr().getValue())));
      return ::mlir::failure();
    }
    auto var_tensor = CHECK_JUST(
        ::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()->Get(op.getOpName().str()));
    if (var_tensor) {
      support::DenseElementsAttrToTensor(tensor_attr, op.getDeviceTagAttr(), op.getDeviceNameAttr(),
                                         var_tensor);
    } else {
      CHECK_JUST(::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()->Set(
          tensor_name,
          // tensor_name can't be replaced by op.getOpNameAttr().str() directly when compiling
          // with gcc, and I have no idea why. It works when compiling with clang. Maybe temporary
          // objects are released earlier when using gcc.
          support::DenseElementsAttrToTensor(tensor_attr, op.getDeviceTagAttr(),
                                             op.getDeviceNameAttr()),
          CHECK_JUST(::oneflow::DType::Get(data_type.value()))));
    }
    // replaceOp may deallocate `op0` (and also `op`), so we should not use `op` after this call.
rewriter.replaceOp(op0, op_new->getResults()); return ::mlir::success(); } }; LogicalResult InitTransposeAttributes(Operation* op, NamedAttrList& transpose_attributes, PatternRewriter& rewriter) { if (op->hasTrait()) { return OpTrait::IsOpConfCompatible::saveToNamedAttrList(op, transpose_attributes); } else { op->emitError("must be a op of trait IsOpConfCompatible!"); return failure(); } } bool IsAddToOutputNone(ValueRange value) { return (int)value.size() > 0 ? false : true; } llvm::SmallVector getChannelLastTransposePerm() { return {0, 2, 3, 1}; } llvm::SmallVector getChannelFirstTransposePerm() { return {0, 3, 1, 2}; } llvm::SmallVector getInputOperandTransposeOp(NCHWCompatible op, Value val, NamedAttrList transpose_attributes, int num_transposed_operand, PatternRewriter& rewriter) { std::string transpose_name = OpTrait::IsOpConfCompatible::getOpName(op).str() + "_transpose_input_" + std::to_string(num_transposed_operand); transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible::getOpNameAttr()), rewriter.getStringAttr(transpose_name)); SmallVector input_operands; input_operands.push_back(val); auto res = rewriter .create(op.getLoc(), getNHWCType(val.getType()), input_operands, transpose_attributes) ->getResults(); return res; } TransposeOp getResultTransposeOp(NCHWCompatible op, Value val, NamedAttrList transpose_attributes, int num_transposed_result, PatternRewriter& rewriter) { std::string transpose_name = OpTrait::IsOpConfCompatible::getOpName(op).str() + "_transpose_output_" + std::to_string(num_transposed_result); transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible::getOpNameAttr()), rewriter.getStringAttr(transpose_name)); SmallVector operands; operands.push_back(val); TransposeOp transpose_op = rewriter.create( op.getLoc(), getNCHWType(val.getType()), operands, transpose_attributes); return transpose_op; } bool IsInsertTransposeOpBefore(NCHWCompatible op, PatternRewriter& rewriter) { bool insert_transpose_op_flag = false; for (mlir::Value operand : op->getOperands()) { TransposeOp transposeInputOp = operand.getDefiningOp(); if (!transposeInputOp) continue; const auto perm = transposeInputOp.getPermAttr(); if (perm.size() == 4 && perm[0] == rewriter.getSI32IntegerAttr(0) && perm[1] == rewriter.getSI32IntegerAttr(3) && perm[2] == rewriter.getSI32IntegerAttr(1) && perm[3] == rewriter.getSI32IntegerAttr(2)) { insert_transpose_op_flag = true; break; } } return insert_transpose_op_flag; } } // namespace oneflow } // namespace mlir #include "OneFlow/OneFlowPatterns.cpp.inc" namespace mlir { namespace oneflow { template struct FusedConsecutiveAddPattern : public OpRewritePattern { explicit FusedConsecutiveAddPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} public: LogicalResult matchAndRewrite(Op op, PatternRewriter& rewriter) const override; }; template LogicalResult TryFusedConsecutiveAdd(Op op, const SmallVector& opOperands, PatternRewriter& rewriter) { for (mlir::Value operand : opOperands) { if (!operand.getDefiningOp() && !operand.getDefiningOp()) { continue; } // check if the operand has only one user LogicalResult checkResult = [&]() { for (const auto& use : operand.getUses()) { if (use.getOwner() != op) { return failure(); } } return success(); }(); if (failed(checkResult)) { continue; } SmallVector operands; SmallVector inputOpOperands; mlir::Value inputOpResult; if (AddNOp addInputOp = operand.getDefiningOp()) { inputOpOperands = addInputOp.getIn(); inputOpResult = addInputOp.getOut(); } else if (Add2Op addInputOp 
= operand.getDefiningOp<Add2Op>()) {
      inputOpOperands = {addInputOp.getIn0(), addInputOp.getIn1()};
      inputOpResult = addInputOp.getOut();
    }
    for (mlir::Value operand : opOperands) {
      if (operand != inputOpResult) {
        operands.push_back(operand);
      } else {
        operands.insert(operands.end(), inputOpOperands.begin(), inputOpOperands.end());
      }
    }
    auto new_op =
        rewriter.create<AddNOp>(op->getLoc(), op->getResultTypes(), operands, op->getAttrs());
    rewriter.replaceOp(op, new_op.getOut());
    return success();
  }
  return failure();
}

template<>
LogicalResult FusedConsecutiveAddPattern<AddNOp>::matchAndRewrite(AddNOp op,
                                                                  PatternRewriter& rewriter) const {
  return TryFusedConsecutiveAdd(op, op.getIn(), rewriter);
}

template<>
LogicalResult FusedConsecutiveAddPattern<Add2Op>::matchAndRewrite(Add2Op op,
                                                                  PatternRewriter& rewriter) const {
  return TryFusedConsecutiveAdd(op, {op.getIn0(), op.getIn1()}, rewriter);
}

struct AutoNhwcPattern : public OpInterfaceRewritePattern<NCHWCompatible> {
  explicit AutoNhwcPattern(mlir::MLIRContext* context)
      : OpInterfaceRewritePattern<NCHWCompatible>(context, /*benefit=*/1) {}

 public:
  LogicalResult matchAndRewrite(NCHWCompatible op, PatternRewriter& rewriter) const override {
    if (op->hasTrait<OpTrait::IsOpConfCompatible>()) {
      for (mlir::Value operand : op.OperandsToTranspose()) {
        if (operand.getType().cast<ShapedType>().getShape().size() != 4) { return failure(); }
      }
      const auto device_name = OpTrait::IsOpConfCompatible<void>::getDeviceTag(op)
                                   .cast<StringAttr>()
                                   .getValue()
                                   .str();
      if (device_name == "cpu") { return failure(); }
    }
    llvm::SmallVector<int32_t> perm = getChannelLastTransposePerm();
    llvm::SmallVector<int32_t> result_perm = getChannelFirstTransposePerm();
    NamedAttrList transpose_attributes;
    if (InitTransposeAttributes(op, transpose_attributes, rewriter).succeeded()) {
      transpose_attributes.append(llvm::StringRef("perm"), getSI32ArrayAttr(rewriter, perm));
    } else {
      return failure();
    }
    // When an op has no notion of data_format but its producer is a transpose, we greedily
    // rewrite this op to NHWC as well, seeking more opportunities to eliminate transpose pairs.
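    // Editorial sketch (not part of the original pass; op names and attribute syntax are
    // schematic): the greedy branch targets IR of the form
    //   %nchw = transpose(%nhwc) {perm = [0, 3, 1, 2]}   // output transpose from a prior NHWC rewrite
    //   %y    = relu(%nchw)                              // layout-agnostic op, IsNCHW() == false
    // Rewriting the layout-agnostic op to NHWC wraps it in a new (0, 2, 3, 1) input transpose that
    // directly follows the existing (0, 3, 1, 2) transpose, and that pair is then removed by
    // AutoNhwcEliminateRedundantTransposePattern below.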
const bool greedily_transpose_flag = !op.IsNCHW() && IsInsertTransposeOpBefore(op, rewriter); if (op.IsNCHW() || greedily_transpose_flag) { // create transpose op for input operand SmallVector tranposed_operands; llvm::DenseSet operand_transpose = op.OperandsToTranspose(); int num_transposed_operand = 0; for (Value operand : op->getOperands()) { if (operand_transpose.find(operand) != operand_transpose.end()) { SmallVector input_res = getInputOperandTransposeOp( op, operand, transpose_attributes, num_transposed_operand, rewriter); tranposed_operands.push_back(input_res[0]); num_transposed_operand += 1; } } // create NHWC op SmallVector created_results = op.NchwToNhwc(tranposed_operands, rewriter); // create transpose op for results int num_transposed_result = 0; transpose_attributes.set(llvm::StringRef("perm"), getSI32ArrayAttr(rewriter, result_perm)); llvm::DenseSet transpose_result = op.ResultsToTranspose(); for (Value result : op->getOpResults()) { if (transpose_result.find(result) != transpose_result.end()) { if (auto result_transpose_op = getResultTransposeOp(op, created_results[num_transposed_result], transpose_attributes, num_transposed_result, rewriter)) { result.replaceAllUsesWith(result_transpose_op); num_transposed_result += 1; } else { return failure(); } } } } return success(); } }; bool IsRedundantTransposeMatch(ArrayAttr pre, ArrayAttr afe, mlir::PatternRewriter& rewriter) { const auto prePerm = pre.getValue().vec(); const auto afePerm = afe.getValue().vec(); if (prePerm.size() == 4 && afePerm.size() == 4) { // handle nchw->nhwc->nchw: (0, 2, 3, 1) -> (0, 3, 1, 2) if (prePerm[0] == afePerm[0] && prePerm[1] == afePerm[3] && prePerm[2] == afePerm[1] && prePerm[3] == afePerm[2] && prePerm[0] == rewriter.getSI32IntegerAttr(0) && prePerm[1] == rewriter.getSI32IntegerAttr(2) && prePerm[2] == rewriter.getSI32IntegerAttr(3) && prePerm[3] == rewriter.getSI32IntegerAttr(1)) return true; // handle nhwc->nchw->nhwc: (0, 3, 1, 2) -> (0, 2, 3, 1) if (prePerm[0] == afePerm[0] && prePerm[1] == afePerm[2] && prePerm[2] == afePerm[3] && prePerm[3] == afePerm[1] && prePerm[0] == rewriter.getSI32IntegerAttr(0) && prePerm[1] == rewriter.getSI32IntegerAttr(3) && prePerm[2] == rewriter.getSI32IntegerAttr(1) && prePerm[3] == rewriter.getSI32IntegerAttr(2)) return true; } return false; } struct AutoNhwcEliminateRedundantTransposePattern : public mlir::OpRewritePattern { explicit AutoNhwcEliminateRedundantTransposePattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter& rewriter) const override { mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = transposeInput.getDefiningOp(); if (!transposeInputOp || !IsRedundantTransposeMatch(op.getPermAttr(), transposeInputOp.getPermAttr(), rewriter)) { return failure(); } rewriter.replaceOp(op, {transposeInputOp.getOperand()}); return success(); } }; struct LowerToOKLPattern : public mlir::OpRewritePattern { static LogicalResult LowerToOKLOp(::mlir::PatternRewriter& rewriter, Operation* op, func::FuncOp okl_func, int index) { auto op_type_name = op->getAttr("op_name").dyn_cast(); auto raw_func = op->getParentOfType(); if (!op_type_name) { return failure(); } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(&okl_func.getBody().back()); auto loc = op->getLoc(); auto wrap_kernel = rewriter.create(loc, index); wrap_kernel.getBody().emplaceBlock(); rewriter.setInsertionPointToEnd(&wrap_kernel.getBody().back()); 
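    // Editorial sketch (op mnemonics and indices are schematic assumptions, not exact dialect
    // syntax): each wrapped kernel gets its own wrapper-kernel region whose body re-materializes
    // operands from the launcher context and publishes results back, roughly
    //   wrapper_kernel {index = <kernel index>} {
    //     %in  = get_tensor_from_arg(%launcher_ctx) {index = ...}
    //     %out = <cloned oneflow op>(%in)
    //     get_tensor_as_ret(%launcher_ctx, %out) {index = ...}
    //     return
    //   }
    // Operands produced by earlier kernels are re-fetched from the launcher context by index
    // instead of being cloned, which is what the GetTensorAsRetOp lookup below implements.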
IRMapping mapping; // map launcher_ctx from wrap func to block mapping.map(raw_func.getArgument(0), okl_func.getArgument(0)); ImplicitLocOpBuilder new_block(loc, rewriter); for (auto arg : op->getOperands()) { auto define_op = arg.getDefiningOp(); if (define_op->getName().getStringRef() == okl::GetTensorFromArgOp::getOperationName()) { new_block.clone(*define_op, mapping); } else { auto find = false; for (auto use : arg.getUsers()) { if (use->getName().getStringRef() == okl::GetTensorAsRetOp::getOperationName()) { find = true; auto index = use->getAttr("index").cast().getInt(); auto source = rewriter.create(op->getLoc(), arg.getType(), okl_func.getArgument(0), index); mapping.map(arg, source->getResult(0)); break; } } if (!find) { op->emitError("Fail to find operand source"); } } } new_block.clone(*op, mapping); for (auto ret : op->getResults()) { auto find = false; for (auto use : ret.getUsers()) { if (use->getName().getStringRef() == okl::GetTensorAsRetOp::getOperationName()) { find = true; new_block.clone(*use, mapping); break; } } if (!find) { op->emitError("Fail to find result source"); } } rewriter.create(loc); return success(); } explicit LowerToOKLPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { ModuleOp module = op->getParentOfType(); if (!module) { LOG(FATAL) << "Not found module"; } if (module.lookupSymbol(okl_func::OKL_FUNC)) { return success(); } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(op); auto& block = op.getBody().front(); auto loc = op->getLoc(); auto func_type = rewriter.getFunctionType( {mlir::okl::LauncherContextType::get(rewriter.getContext())}, TypeRange{}); auto okl_func = rewriter.create(loc, okl_func::OKL_FUNC, func_type); okl_func.getBody().emplaceBlock(); okl_func.getBody().addArguments(mlir::okl::LauncherContextType::get(rewriter.getContext()), loc); auto index = 0; for (auto& op : block) { if (!op.hasAttr("op_name")) { if (op.getDialect()->getNamespace() == "okl") { continue; } if (isa(op)) { break; } op.emitError("Failed to parse this op in kernel launch wrap func."); } if (failed(LowerToOKLOp(rewriter, &op, okl_func, index))) { index += 1; op.emitError("Failed to lowering OneFlow op to okl dialect."); return failure(); } index += 1; } rewriter.setInsertionPointToEnd(&okl_func.getBody().back()); rewriter.create(loc); rewriter.eraseOp(op); return success(); } }; // {func, ins, outs_mapping} std::tuple, std::vector>> CreateWrapFuncAndReturnWithIns(mlir::Location loc, std::vector& wrap_ops, mlir::PatternRewriter& rewriter, int& name_index) { auto getProto = [&]() -> std::tuple, std::vector, std::vector>> { std::vector whole_ins, whole_outs, ins, outs; std::vector> outs_mapping; for (auto op : wrap_ops) { auto operands = op->getOperands(); auto results = op->getResults(); for (auto it = operands.begin(); it != operands.end(); ++it) { whole_ins.push_back(*it); } std::vector map; auto add_res = [&](mlir::OpResult res) { map.push_back(outs.size()); outs.push_back(res); }; for (auto it = results.begin(); it != results.end(); ++it) { whole_outs.push_back(*it); for (auto user : (*it).getUsers()) { if (std::find(wrap_ops.begin(), wrap_ops.end(), user) == wrap_ops.end()) { add_res(*it); break; } } } outs_mapping.push_back(map); } for (auto in : whole_ins) { if (std::find(whole_outs.begin(), whole_outs.end(), in) == whole_outs.end()) { ins.push_back(in); } } return {ins, outs, outs_mapping}; }; auto [ins, 
outs, map] = getProto(); auto func_type = rewriter.getFunctionType(TypeRange(ValueRange(ArrayRef(ins))), TypeRange(ValueRange(ArrayRef(outs)))); auto func_name = okm::func_name::GRAPH_NAME + std::to_string(name_index++); auto module = GetModuleOpFromJobBodyOp(wrap_ops[0]); if (!module) { LOG(FATAL) << "Fail to find parent ModuleOp"; } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto function = rewriter.create(loc, func_name, func_type); function->setAttr("llvm.emit_c_interface", mlir::UnitAttr::get(rewriter.getContext())); function.getBody().emplaceBlock(); for (auto arg : ins) { function.getBody().addArgument(arg.getType(), loc); } IRMapping mapping; for (auto args_pair : llvm::zip(ins, function.getBody().getArguments())) { mapping.map(std::get<0>(args_pair), std::get<1>(args_pair)); } rewriter.setInsertionPointToStart(&function.getBody().front()); ImplicitLocOpBuilder new_block(loc, rewriter); for (auto op : wrap_ops) { new_block.clone(*op, mapping); } SmallVector<::mlir::Value, 4> mapped_results; for (auto result : outs) { mapped_results.push_back(mapping.lookup(result)); } rewriter.create(loc, mapped_results); return {function, ins, map}; }; KernelLaunchOp ConsumeOpsToFunc(std::vector& wrap_ops, mlir::PatternRewriter& rewriter, int& name_index) { if (wrap_ops.size() < 2) { wrap_ops.clear(); return nullptr; } auto loc = wrap_ops.front()->getLoc(); OpBuilder::InsertionGuard guard(rewriter); auto [wrap_func, wrap_ins, map] = CreateWrapFuncAndReturnWithIns(loc, wrap_ops, rewriter, name_index); auto func_name = wrap_func.getSymNameAttr(); std::vector attrs; for (auto attr : wrap_ops[0]->getAttrs()) { auto attr_list = {"scope_symbol_id", "device_tag", "device_name"}; if (std::find(attr_list.begin(), attr_list.end(), attr.getName()) != attr_list.end()) { attrs.push_back(attr); } } attrs.emplace_back(rewriter.getStringAttr("op_name"), func_name); rewriter.setInsertionPointAfter(wrap_ops.back()); auto func = rewriter.create(wrap_ops[0]->getLoc(), wrap_func, ArrayRef(attrs), wrap_ins); if (failed(DumpAssembly(rewriter, func, func_name))) { LOG(FATAL) << "Fail to dumping asm to kernel launch op."; } for (auto it : llvm::zip(map, wrap_ops)) { auto op = std::get<1>(it); auto list = std::get<0>(it); if (!list.size()) { op->dropAllUses(); rewriter.eraseOp(op); continue; } std::vector vals; for (auto idx : list) { vals.push_back(func->getResult(idx)); } if (op->getNumResults() == vals.size()) { rewriter.replaceOp(op, vals); } else { // if op has multi results but only some of them used outside, we need tackle with // mapper manually. 
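        // Worked example (editorial): suppose wrap_ops = {A, B}, where A has results {a0, a1} and
        // only a1 escapes the wrapped ops, while both results of B escape.
        // CreateWrapFuncAndReturnWithIns then yields outs = {a1, b0, b1} and map = {{0}, {1, 2}}.
        // For A, list has one entry but A->getNumResults() == 2, so we land in this branch and
        // only the escaping result a1 is rewired to func->getResult(0); a0 is used solely inside
        // the wrapped region, so A can be erased without remapping it.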
int idx = 0; auto results = op->getResults(); for (auto it = results.begin(); it != results.end(); ++it) { for (auto user : (*it).getUsers()) { if (std::find(wrap_ops.begin(), wrap_ops.end(), user) == wrap_ops.end()) { (*it).replaceAllUsesWith(func->getResult(list[idx])); idx += 1; break; } } } rewriter.eraseOp(op); } } wrap_ops.clear(); return func; } struct ExtractKernelLaunchTensorPattern : public mlir::OpRewritePattern { static func::FuncOp ExtractArgTensors(func::FuncOp op, mlir::PatternRewriter& rewriter) { auto launcher_ctx_type = okl::LauncherContextType::get(rewriter.getContext()); auto return_types = op.getBody().front().back().getOperandTypes(); auto func_type = rewriter.getFunctionType({launcher_ctx_type}, return_types); auto func = rewriter.create(op.getLoc(), op.getName(), func_type); auto& body = func.getBody(); body.emplaceBlock(); body.addArgument(launcher_ctx_type, op->getLoc()); auto launcher_ctx = body.getArgument(0); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&body.front()); IRMapping mapping; for (const auto& arg : llvm::enumerate(op.getBody().getArguments())) { auto tensor = rewriter.create(func->getLoc(), arg.value().getType(), launcher_ctx, arg.index()); mapping.map(arg.value(), tensor); } ImplicitLocOpBuilder new_block(func->getLoc(), rewriter); for (auto& op : op.getBody().front().getOperations()) { new_block.clone(op, mapping); } rewriter.eraseOp(op); return func; } static func::FuncOp ExtractRetTensors(func::FuncOp op, mlir::PatternRewriter& rewriter) { auto& block = op.getBody().front(); auto launcher_ctx = op.getArgument(0); auto& return_op = block.back(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(&return_op); std::vector returns; for (const auto& ret_val : llvm::enumerate(return_op.getOperands())) { auto new_ret = rewriter.create( op->getLoc(), ret_val.value().getType(), launcher_ctx, ret_val.value(), ret_val.index()); returns.push_back(new_ret); } rewriter.replaceOpWithNewOp(&return_op, ValueRange{returns}); return op; } explicit ExtractKernelLaunchTensorPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { if (op.getBody().getNumArguments()) { // skip if already converted if (op.getBody().getArgument(0).getType().isa()) { return success(); } } op = ExtractArgTensors(op, rewriter); op = ExtractRetTensors(op, rewriter); return success(); } }; struct TrimReturnAsVoidPattern : public mlir::OpRewritePattern { explicit TrimReturnAsVoidPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { if (op.getBody().front().back().getNumOperands() == 0) { return success(); } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); auto func_type = rewriter.getFunctionType(op.getFunctionType().getInputs(), TypeRange{}); auto func = rewriter.create(op.getLoc(), op.getName(), func_type); IRMapping bvm; op.getRegion().cloneInto(&func.getRegion(), bvm); auto& old_ret = func.getBody().front().back(); rewriter.setInsertionPoint(&old_ret); rewriter.replaceOpWithNewOp(&old_ret); rewriter.eraseOp(op); return success(); } }; struct KernelLaunchPattern : public mlir::OpRewritePattern { explicit KernelLaunchPattern(mlir::MLIRContext* context, bool trim = false) : OpRewritePattern(context, /*benefit=*/0) {} // if the pre-packed ops is continuous with 
the current op, this current op will be packed with // pre-packed ops together. virtual bool IsConsecutive(std::vector&, mlir::Operation*) const { return true; }; virtual bool IsPackagable(mlir::Operation* op) const { return GetModuleOpFromJobBodyOp(&(*op)) && op->getAttr("op_name") && dyn_cast(op) && op->getName().getStringRef() != KernelLaunchOp::getOperationName(); } mlir::LogicalResult matchAndRewrite(oneflow::Job op, mlir::PatternRewriter& rewriter) const override { auto& ops = op->getRegion(0).front(); if (ops.empty()) { return success(); } int name_index = 0; std::vector current_wrap_ops; for (auto op_it = ops.begin(); op_it != ops.end(); ++op_it) { auto current_op = &(*op_it); if (!IsPackagable(current_op)) { ConsumeOpsToFunc(current_wrap_ops, rewriter, name_index); continue; } if (!IsConsecutive(current_wrap_ops, current_op)) { ConsumeOpsToFunc(current_wrap_ops, rewriter, name_index); } current_wrap_ops.push_back(current_op); } if (!current_wrap_ops.empty()) { ConsumeOpsToFunc(current_wrap_ops, rewriter, name_index); } return success(); } }; struct KernelLaunchSimplePattern : public KernelLaunchPattern { explicit KernelLaunchSimplePattern(mlir::MLIRContext* context) : KernelLaunchPattern(context) {} bool IsSameDevice(std::vector& ops, mlir::Operation* op) const { if (ops.empty()) { return true; } auto device_tag = op->getAttr("device_tag").dyn_cast_or_null(); auto device_name = op->getAttr("device_name").dyn_cast_or_null(); auto cmp_device_tag = ops.front()->getAttr("device_tag").dyn_cast_or_null(); auto cmp_device_name = ops.front()->getAttr("device_name").dyn_cast_or_null(); if (!device_tag || !device_name || !cmp_device_tag || !cmp_device_name) { return false; } auto same_device_tag = device_tag.str() == cmp_device_tag.str(); auto same_device_name = std::equal(device_name.begin(), device_name.end(), cmp_device_name.begin(), [](const Attribute a, const Attribute b) { auto a_str = a.dyn_cast_or_null(); auto b_str = b.dyn_cast_or_null(); if (!a_str || !b_str) { return false; } return a_str.str() == b_str.str(); }); return same_device_tag && same_device_name; } bool IsConsecutive(std::vector& ops, mlir::Operation* op) const override { if (ops.empty()) { return true; } return IsSameDevice(ops, op); } }; struct KernelLaunchWithCudaGraphPattern : public KernelLaunchSimplePattern { explicit KernelLaunchWithCudaGraphPattern(mlir::MLIRContext* context) : KernelLaunchSimplePattern(context) {} bool IsOpCudaGraphSupport(mlir::Operation* op) const { ::oneflow::okl::RegContext reg_ctx(op); auto* kernel = const_cast<::oneflow::user_op::OpKernel*>(reg_ctx.GetKernel()); return dynamic_cast<::oneflow::user_op::CudaGraphSupport*>(kernel); } bool IsSameCudaGraphSupport(std::vector& ops, mlir::Operation* op) const { if (ops.empty()) { return true; } auto cuda_support = IsOpCudaGraphSupport(op); return cuda_support == IsOpCudaGraphSupport(ops.front()); } bool IsConsecutive(std::vector& ops, mlir::Operation* op) const override { if (ops.empty()) { return true; } return IsSameDevice(ops, op) && IsSameCudaGraphSupport(ops, op); } }; void AddLoweringToLinalgMemRefPasses(PassManager& pm) { pm.addPass(createConvertToSignlessForTosaPass()); pm.addNestedPass(LLVM::createRequestCWrappersPass()); pm.addPass(createLowerOneFlowToTosaPass()); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); pm.addPass(createCSEPass()); pm.addNestedPass(tosa::createTosaToLinalg()); pm.addNestedPass(tosa::createTosaToTensor()); pm.addNestedPass(createLinalgElementwiseOpFusionPass()); // TODO: more optimization pass // 
Note: OneShot bufferization with result extract realization. pm.addPass(bufferization::createEmptyTensorEliminationPass()); pm.addPass(bufferization::createEmptyTensorToAllocTensorPass()); auto oneshot_bufferize = bufferization::createOneShotBufferizePass(); CHECK( oneshot_bufferize ->initializeOptions("create-deallocs=0 bufferize-function-boundaries allow-return-allocs") .succeeded()); pm.addPass(std::move(oneshot_bufferize)); pm.addPass(bufferization::createBufferResultsToOutParamsPass()); pm.addPass(mlir::oneflow::createEliminateAllocOpsPass()); pm.addPass(createCanonicalizerPass()); } LogicalResult LowerModuleToLLVM(mlir::MLIRContext* context, ModuleOp module) { mlir::PassManager pm(context); mlir::oneflow::CheckEnableIRPrinting(pm); AddLoweringToLinalgMemRefPasses(pm); pm.addNestedPass(createConvertLinalgToLoopsPass()); pm.addNestedPass(createConvertSCFToCFPass()); pm.addNestedPass(createFoldAllocToSubviewPass()); pm.addPass(createInsertOneFlowMemPoolPass()); pm.addPass(createAppendOneFlowStreamPass()); pm.addPass(memref::createExpandOpsPass()); pm.addPass(memref::createExpandStridedMetadataPass()); pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addPass(createLowerAffinePass()); pm.addPass(createConvertLinalgToLLVMPass()); pm.addPass(createConvertFuncToLLVMPass()); pm.addPass(createReconcileUnrealizedCastsPass()); return pm.run(module); } #ifdef WITH_MLIR_CUDA_CODEGEN void AddLoweringLinalgOnBufferToGpuWithStdPasses(PassManager& pm) { pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); pm.addNestedPass(createGpuMapParallelLoopsPass()); pm.addNestedPass(createParallelLoopToGpuPass()); pm.addNestedPass(createGpuLauchSinkIndexComputationsPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); pm.addNestedPass(createFoldAllocToSubviewPass()); pm.addPass(createInsertOneFlowMemPoolPass()); pm.addNestedPass(createConvertLinalgToLoopsPass()); pm.addNestedPass(createConvertSCFToCFPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); pm.addNestedPass(createGpuCopyArgPass()); } void AddAdheringCubinToGpuModulePasses(PassManager& pm) { pm.addNestedPass(createLowerAffinePass()); pm.addNestedPass(createStripDebugInfoPass()); pm.addNestedPass(createLowerGpuOpsToNVVMOpsPass()); pm.addNestedPass(createNVVMToCubinPass()); } void AddLoweringGpuToLLVMPasses(PassManager& pm) { pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addPass(createLowerAffinePass()); pm.addPass(createAppendOneFlowStreamPass()); pm.addPass(createGpuToLLVMConversionPass()); pm.addPass(createMgpuToOneFlowStreamPass()); pm.addPass(createReconcileUnrealizedCastsPass()); } LogicalResult LowerModuleToCUDALLVM(mlir::MLIRContext* context, ModuleOp module) { InitializeLLVMNVPTXBackend(); mlir::PassManager pm(context); mlir::oneflow::CheckEnableIRPrinting(pm); AddLoweringToLinalgMemRefPasses(pm); AddLoweringLinalgOnBufferToGpuWithStdPasses(pm); pm.addPass(memref::createExpandOpsPass()); pm.addPass(memref::createExpandStridedMetadataPass()); pm.addPass(createGpuKernelOutliningPass()); AddAdheringCubinToGpuModulePasses(pm); AddLoweringGpuToLLVMPasses(pm); return pm.run(module); } #endif // WITH_MLIR_CUDA_CODEGEN void populateWrapOpsToKernelLaunchPatterns(::mlir::RewritePatternSet& patterns, const std::string& mode) { if (mode == wrap_mode::SIMPLE) { patterns.add(patterns.getContext()); } else if (mode == wrap_mode::CUDA_GRAPH) { #ifdef WITH_CUDA_GRAPHS patterns.add(patterns.getContext()); #else patterns.add(patterns.getContext()); #endif } else { 
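    // Editorial note: `mode` is the option string of the wrap-ops-to-kernel-launch pass.
    // wrap_mode::SIMPLE groups consecutive ops that share device_tag/device_name, while
    // wrap_mode::CUDA_GRAPH additionally starts a new group whenever CUDA-graph support changes
    // between neighbouring ops (and appears to fall back to the simple grouping when
    // WITH_CUDA_GRAPHS is not defined); any other mode string is rejected here.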
LOG(FATAL) << "Found an unsupported mode in wrap-ops-to-kernel-launch pass"; } } void populateFuserForExistingOp(::mlir::RewritePatternSet& patterns) { populateForwardOpPatterns(patterns); rewrites::populateRewrites(patterns); constraints::populateConstraints(patterns); populateNormalizationOpPatterns(patterns); patterns.add>(patterns.getContext()); patterns.add>(patterns.getContext()); } void populateAutoNhwcPatterns(::mlir::RewritePatternSet& patterns) { bool enable_nhwc = ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_PREFER_NHWC", false); if (enable_nhwc) { patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); } } void populateGpuHelperPatterns(::mlir::RewritePatternSet& patterns) { patterns.add(patterns.getContext()); } void populatePreConvertInferenceOp(::mlir::RewritePatternSet& patterns) { patterns.add(patterns.getContext()); } void populateConvertInferenceOp(::mlir::RewritePatternSet& patterns) { populateFuseConv2DBatchNormPattern(patterns); } void populatePostConvertInferenceOp(::mlir::RewritePatternSet& patterns) { patterns.add(patterns.getContext()); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/SBP/SBPAttributes.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/SBP/SBPDialect.h" #include "OneFlow/SBP/SBPAttributes.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Support/LogicalResult.h" using namespace mlir; LogicalResult parseSBP(AsmParser& parser, ArrayAttr& args) { if (failed(parser.parseLSquare())) { return failure(); } if (succeeded(parser.parseOptionalRSquare())) { args = parser.getBuilder().getArrayAttr({}); return success(); } llvm::SmallVector res; llvm::SmallVector nd_list; auto parserListElem = [&](llvm::SmallVector& list) { auto loc = parser.getCurrentLocation(); if (failed(parser.parseAttribute(list.emplace_back()))) { parser.emitError(loc, "failed to parse an attribute here"); return failure(); } if (list.back().dyn_cast() || list.back().dyn_cast() || list.back().dyn_cast() || list.back().dyn_cast()) { return success(); } parser.emitError(loc, "failed to parse a sbp attribute here"); return failure(); }; auto parserList = [&]() { nd_list.clear(); if (parser.parseCommaSeparatedList([&]() { return parserListElem(nd_list); }) || parser.parseRSquare()) { return failure(); } res.emplace_back(parser.getBuilder().getArrayAttr(nd_list)); return success(); }; if (parser.parseCommaSeparatedList([&]() { if (succeeded(parser.parseOptionalLSquare())) { return parserList(); } return parserListElem(res); }) || parser.parseRSquare()) { return failure(); } args = parser.getBuilder().getArrayAttr(res); return success(); } void printSBP(AsmPrinter& printer, ArrayAttr args) { printer << args; } #define GET_ATTRDEF_CLASSES #include "OneFlow/SBPAttributes.cpp.inc" namespace mlir { namespace sbp { void SBPDialect::registerAttributes() { addAttributes< #define GET_ATTRDEF_LIST #include "OneFlow/SBPAttributes.cpp.inc" >(); } } // namespace sbp } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/SBP/SBPDialect.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/SBP/SBPDialect.h" #include "mlir/IR/BuiltinAttributes.h" #include "OneFlow/SBPDialect.cpp.inc" #include "mlir/IR/Dialect.h" #include "mlir/IR/TypeRange.h" namespace mlir { namespace sbp { void SBPDialect::initialize() { registerAttributes(); } } // namespace sbp } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/SBP/SBPImporter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/SBP/SBPImporter.h" #include #include namespace mlir { namespace oneflow { mlir::LogicalResult SBPTranslation::PrintSbpAttrToString(mlir::Attribute sbp_attr, std::string& sbp) { if (auto sbp_s_attr = sbp_attr.dyn_cast()) { sbp = "S(" + std::to_string(sbp_s_attr.getAxis()) + ")"; } else if (auto sbp_b_attr = sbp_attr.dyn_cast()) { sbp = "B"; } else if (auto sbp_p_attr = sbp_attr.dyn_cast()) { sbp = "P"; } else if (auto sbp_p_attr = sbp_attr.dyn_cast()) { sbp = ""; } else { return mlir::failure(); } return mlir::success(); } mlir::Attribute SBPTranslation::ConvertSBPToString(mlir::Builder& builder, mlir::sbp::ParallelSignatureAttr& parallel) { std::vector list; for (auto output : parallel.getOutputs()) { if (auto nd_outputs = output.dyn_cast()) { for (auto nd_output : nd_outputs) { std::string sbp; if (failed(SBPTranslation::PrintSbpAttrToString(nd_output, sbp))) return {}; list.push_back(sbp); } } else { std::string sbp; if (failed(SBPTranslation::PrintSbpAttrToString(output, sbp))) return {}; list.push_back(sbp); } } return builder.getStrArrayAttr( makeArrayRef(llvm::SmallVector(list.begin(), list.end()))); } mlir::Attribute SBPTranslation::ConvertNdSbpToPsig(mlir::Builder& builder, const std::vector& nd_sbp, const int nd_size) { auto ctx = builder.getContext(); std::vector outputs_vec; for (const auto& sbp_data : nd_sbp) { mlir::Attribute attr; if (sbp_data == "") { attr = mlir::sbp::AnyAttr::get(ctx); } else { ::oneflow::SbpParallel sbp; ParseSbpParallelFromString(sbp_data, &sbp); if (sbp.has_split_parallel()) { attr = mlir::sbp::SplitAttr::get(ctx, sbp.split_parallel().axis()); } else if (sbp.has_broadcast_parallel()) { attr = mlir::sbp::BroadcastAttr::get(ctx); } else if (sbp.has_partial_sum_parallel()) { attr = mlir::sbp::PartialSumAttr::get(ctx); } else { llvm::errs() << "Unsupported sbp type from nd_sbp: "; for (const auto& sbp_data : nd_sbp) { llvm::errs() << sbp_data << " "; } llvm::errs() << "\n"; exit(EXIT_FAILURE); } } outputs_vec.push_back(attr); } auto inputs = builder.getArrayAttr({}); mlir::ArrayAttr outputs; std::vector outputs_vec_nd; for (auto iter = outputs_vec.begin(); iter < outputs_vec.end(); iter += nd_size) { outputs_vec_nd.emplace_back( builder.getArrayAttr(std::vector(iter, iter + nd_size))); } outputs = builder.getArrayAttr(outputs_vec_nd); return mlir::sbp::ParallelSignatureAttr::get(ctx, inputs, outputs); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/AggregateOps.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/Passes.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include #include namespace mlir { namespace oneflow { struct AggregateComputeOpsPattern : public mlir::OpRewritePattern { explicit AggregateComputeOpsPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(OutputOp op, mlir::PatternRewriter& rewriter) const override { if (op->getNumResults() != 1) { return failure(); } if (llvm::isa(op->getNextNode())) { return failure(); } // oneflow.output only have a single result for (auto user : op->getResult(0).getUsers()) { if (!llvm::isa(user)) { return failure(); } rewriter.setInsertionPoint(user); } auto new_val = rewriter.clone(*op)->getResults(); rewriter.replaceOp(op, new_val); return success(); }; }; namespace { class AggregateComputeOpsPass : public AggregateComputeOpsPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(patterns.getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; } // namespace std::unique_ptr createAggregateComputeOpsPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/AutoNHWCOps.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OneFlowOps.h" #include "OneFlow/Transform/TransposeHelpers.h" namespace mlir { namespace oneflow { bool Conv2DOp::IsNCHW() { return this->getDataFormat().str() == "channels_first"; } llvm::DenseSet Conv2DOp::OperandsToTranspose() { if (this->get_addToOutput()) { return {this->getIn(), this->getWeight(), this->get_addToOutput()}; } else { return {this->getIn(), this->getWeight()}; } } llvm::DenseSet Conv2DOp::ResultsToTranspose() { return {this->getOut()}; } llvm::SmallVector Conv2DOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto conv_op = *this; SmallVector operands; operands.push_back(value[0]); operands.push_back(value[1]); if (conv_op.getBias()) operands.push_back(conv_op.getBias()); if (this->get_addToOutput()) { operands.push_back(value[2]); } NamedAttrList attributes = conv_op->getAttrs(); attributes.set(conv_op.getDataFormatAttrName(), rewriter.getStringAttr("channels_last")); auto res = rewriter .create(conv_op.getLoc(), getNHWCResultTypes(conv_op), operands, attributes) ->getResults(); llvm::SmallVector results; results.push_back(res[0]); return results; } bool BiasAddOp::IsNCHW() { return this->getAxisAttr().getValue().getSExtValue() == 1; } llvm::DenseSet BiasAddOp::OperandsToTranspose() { return {this->getA()}; } llvm::DenseSet BiasAddOp::ResultsToTranspose() { return {this->getOut()}; } llvm::SmallVector BiasAddOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto bias_add_op = *this; SmallVector operands; operands.push_back(value[0]); operands.push_back(bias_add_op.getB()); NamedAttrList attributes = bias_add_op->getAttrs(); attributes.set(bias_add_op.getAxisAttrName(), rewriter.getSI32IntegerAttr(3)); auto res = rewriter .create(bias_add_op.getLoc(), getNHWCResultTypes(bias_add_op), operands, attributes) ->getResults(); llvm::SmallVector results; results.push_back(res[0]); return results; } bool BroadcastAddOp::IsNCHW() { return false; } llvm::DenseSet BroadcastAddOp::OperandsToTranspose() { return {this->getX(), this->getY()}; } llvm::DenseSet BroadcastAddOp::ResultsToTranspose() { return {this->getZ()}; } llvm::SmallVector BroadcastAddOp::NchwToNhwc(llvm::SmallVector values, PatternRewriter& rewriter) { auto broadcast_op = *this; NamedAttrList attributes = broadcast_op->getAttrs(); auto res = rewriter .create( broadcast_op.getLoc(), getNHWCResultTypes(broadcast_op), values, attributes) .getZ(); llvm::SmallVector results; results.push_back(res); return results; } bool NormalizationOp::IsNCHW() { return this->getAxisAttr().getValue().getSExtValue() == 1; } bool NormalizationInferenceOp::IsNCHW() { return this->getAxisAttr().getValue().getSExtValue() == 1; } llvm::DenseSet NormalizationOp::OperandsToTranspose() { return {this->getX()}; } llvm::DenseSet NormalizationInferenceOp::OperandsToTranspose() { return {this->getX()}; } llvm::DenseSet NormalizationOp::ResultsToTranspose() { return {this->getY()}; } llvm::DenseSet NormalizationInferenceOp::ResultsToTranspose() { return {this->getY()}; } llvm::SmallVector NormalizationOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto normalization_op = *this; SmallVector operands; operands.push_back(value[0]); if (normalization_op.getMovingMean()) operands.push_back(normalization_op.getMovingMean()); if (normalization_op.getMovingVariance()) operands.push_back(normalization_op.getMovingVariance()); operands.push_back(normalization_op.getGamma()); operands.push_back(normalization_op.getBeta()); if (normalization_op.get_addToOutput()) 
operands.push_back(normalization_op.get_addToOutput()); NamedAttrList attributes = normalization_op->getAttrs(); attributes.set(normalization_op.getAxisAttrName(), rewriter.getSI32IntegerAttr(3)); auto res = rewriter .create( normalization_op.getLoc(), getNHWCResultTypes(normalization_op), operands, attributes) ->getResults(); llvm::SmallVector results; results.push_back(res[0]); return results; } llvm::SmallVector NormalizationInferenceOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto normalization_op = *this; SmallVector operands; operands.push_back(value[0]); if (normalization_op.getMovingMean()) operands.push_back(normalization_op.getMovingMean()); if (normalization_op.getMovingVariance()) operands.push_back(normalization_op.getMovingVariance()); operands.push_back(normalization_op.getGamma()); operands.push_back(normalization_op.getBeta()); if (normalization_op.get_addToOutput()) operands.push_back(normalization_op.get_addToOutput()); NamedAttrList attributes = normalization_op->getAttrs(); attributes.set(normalization_op.getAxisAttrName(), rewriter.getSI32IntegerAttr(3)); auto res = rewriter .create( normalization_op.getLoc(), getNHWCResultTypes(normalization_op), operands, attributes) ->getResults(); llvm::SmallVector results; results.push_back(res[0]); return results; } bool MaxPool2DOp::IsNCHW() { return this->getDataFormat().str() == "channels_first"; } llvm::DenseSet MaxPool2DOp::OperandsToTranspose() { return {this->getX()}; } llvm::DenseSet MaxPool2DOp::ResultsToTranspose() { return {this->getY(), this->getIndice()}; } llvm::SmallVector MaxPool2DOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto max_pool_2d_op = *this; SmallVector operands; operands.push_back(value[0]); NamedAttrList attributes = max_pool_2d_op->getAttrs(); attributes.set(max_pool_2d_op.getDataFormatAttrName(), rewriter.getStringAttr("channels_last")); auto res = rewriter .create(max_pool_2d_op.getLoc(), getNHWCResultTypes(max_pool_2d_op), operands, attributes) ->getResults(); llvm::SmallVector results; results.push_back(res[0]); results.push_back(res[1]); return results; } bool ReluOp::IsNCHW() { return false; } llvm::DenseSet ReluOp::OperandsToTranspose() { return {this->getX()}; } llvm::DenseSet ReluOp::ResultsToTranspose() { return {this->getY()}; } llvm::SmallVector ReluOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto relu_op = *this; SmallVector operands{value[0]}; auto res = rewriter .create(relu_op.getLoc(), getNHWCResultTypes(relu_op), operands, relu_op->getAttrs()) ->getResults(); return {res[0]}; } bool ScalarDivOp::IsNCHW() { return false; } llvm::DenseSet ScalarDivOp::OperandsToTranspose() { return {this->getIn()}; } llvm::DenseSet ScalarDivOp::ResultsToTranspose() { return {this->getOut()}; } llvm::SmallVector ScalarDivOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto elementwise_op = *this; SmallVector operands{value[0]}; auto res = rewriter .create(elementwise_op.getLoc(), getNHWCResultTypes(elementwise_op), operands, elementwise_op->getAttrs()) ->getResults(); return {res[0]}; } bool SiluOp::IsNCHW() { return false; } llvm::DenseSet SiluOp::OperandsToTranspose() { return {this->getIn()}; } llvm::DenseSet SiluOp::ResultsToTranspose() { return {this->getOut()}; } llvm::SmallVector SiluOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto elementwise_op = *this; SmallVector operands{value[0]}; auto res = rewriter .create(elementwise_op.getLoc(), getNHWCResultTypes(elementwise_op), 
operands, elementwise_op->getAttrs()) ->getResults(); return {res[0]}; } bool CastOp::IsNCHW() { return false; } llvm::DenseSet CastOp::OperandsToTranspose() { return {this->getIn()}; } llvm::DenseSet CastOp::ResultsToTranspose() { return {this->getOut()}; } llvm::SmallVector CastOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto elementwise_op = *this; SmallVector operands{value[0]}; auto res = rewriter .create(elementwise_op.getLoc(), getNHWCResultTypes(elementwise_op), operands, elementwise_op->getAttrs()) ->getResults(); return {res[0]}; } bool Add2Op::IsNCHW() { return false; } llvm::DenseSet Add2Op::OperandsToTranspose() { return {this->getIn0(), this->getIn1()}; } llvm::DenseSet Add2Op::ResultsToTranspose() { return {this->getOut()}; } llvm::SmallVector Add2Op::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto add2_op = *this; SmallVector operands{value[0], value[1]}; auto res = rewriter .create(add2_op.getLoc(), getNHWCResultTypes(add2_op), operands, add2_op->getAttrs()) ->getResults(); return {res[0]}; } bool ConcatOp::IsNCHW() { return this->getAxisAttr().getValue().getSExtValue() == 1; } llvm::DenseSet ConcatOp::OperandsToTranspose() { llvm::DenseSet operands; for (auto operand : this->getIn()) { operands.insert(operand); } return operands; } llvm::DenseSet ConcatOp::ResultsToTranspose() { return {this->getOut()}; } llvm::SmallVector ConcatOp::NchwToNhwc(llvm::SmallVector values, PatternRewriter& rewriter) { auto elementwise_op = *this; NamedAttrList attributes = elementwise_op->getAttrs(); attributes.set(elementwise_op.getAxisAttrName(), IntegerAttr::get(rewriter.getIntegerType(64, /*isSigned=*/true), APInt(64, 3, /*isSigned=*/true))); auto out = rewriter .create(elementwise_op.getLoc(), getNHWCResultTypes(elementwise_op), values, attributes) .getOut(); return {out}; } bool GroupNormOp::IsNCHW() { return this->getDataFormat().str() == "channels_first"; } llvm::DenseSet GroupNormOp::OperandsToTranspose() { return {this->getX()}; } llvm::DenseSet GroupNormOp::ResultsToTranspose() { return {this->getY()}; } llvm::SmallVector GroupNormOp::NchwToNhwc(llvm::SmallVector value, PatternRewriter& rewriter) { auto group_norm_op = *this; SmallVector operands; operands.push_back(value[0]); if (this->getAffine()) { operands.push_back(this->getBeta()); operands.push_back(this->getGamma()); } NamedAttrList attributes = group_norm_op->getAttrs(); attributes.set(group_norm_op.getDataFormatAttrName(), rewriter.getStringAttr("channels_last")); auto res = rewriter .create(group_norm_op.getLoc(), getNHWCResultTypes(group_norm_op), operands, attributes) ->getResults(); llvm::SmallVector results; results.push_back(res[0]); results.push_back(res[1]); results.push_back(res[2]); return results; } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/AutoNhwc.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "OneFlow/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace oneflow { namespace { class AutoNhwcPass : public AutoNhwcPassBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); oneflow::populateAutoNhwcPatterns(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; } // namespace std::unique_ptr createAutoNhwcPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/BufferHostRegister.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "OneFlow/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace oneflow { namespace { class BufferHostRegisterPass : public BufferHostRegisterPassBase { void runOnOperation() override { getOperation()->walk([&](memref::AllocOp alloc) { auto ranked_type = alloc.getResult().getType().cast(); Type unranked_type = UnrankedMemRefType::get(ranked_type.getElementType(), ranked_type.getMemorySpace()); OpBuilder builder(alloc); builder.setInsertionPointAfter(alloc); Value casted = builder.create(alloc->getLoc(), unranked_type, alloc); builder.create(alloc->getLoc(), casted); }); } }; class GpuCopyArgPass : public GpuCopyArgPassBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); oneflow::populateGpuHelperPatterns(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; } // namespace std::unique_ptr createBufferHostRegisterPass() { return std::make_unique(); } std::unique_ptr createGpuCopyArgPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/CSEWithAttributesIgnored.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "OneFlow/OneFlowOps.h" #include "OneFlow/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace oneflow { namespace { static const auto MAGIC_OP_NAME = "ONEFLOW_ERASE_MAGIC"; static const auto MAGIC_SCOPE_SYMBOL_ID = 77777; struct EraseAttributes : public mlir::OpInterfaceRewritePattern { explicit EraseAttributes(mlir::MLIRContext* context, std::shared_ptr state) : OpInterfaceRewritePattern(context, /*benefit=*/1), state_{state} {} mlir::LogicalResult matchAndRewrite(UserOpCompatible op, mlir::PatternRewriter& rewriter) const override { if (op->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) .getValue() .str() != MAGIC_OP_NAME) { if (state_) { state_->opNames[op] = op->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()); state_->scopeSymbolIDs[op] = op->getAttrOfType( OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr()); } op->setAttr(OpTrait::IsOpConfCompatible::getOpNameAttr(), rewriter.getStringAttr(MAGIC_OP_NAME)); op->setAttr(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr(), rewriter.getI64IntegerAttr(MAGIC_SCOPE_SYMBOL_ID)); return success(); } else { return failure(); } } private: std::shared_ptr state_; }; struct PutAttributes : public mlir::OpInterfaceRewritePattern { explicit PutAttributes(mlir::MLIRContext* context, std::shared_ptr state) : OpInterfaceRewritePattern(context, /*benefit=*/1), state_{state} {} mlir::LogicalResult matchAndRewrite(UserOpCompatible op, mlir::PatternRewriter& rewriter) const override { if (op->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) .getValue() .str() == MAGIC_OP_NAME) { if (state_) { op->setAttr(OpTrait::IsOpConfCompatible::getOpNameAttr(), state_->opNames[op]); op->setAttr(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr(), state_->scopeSymbolIDs[op]); } return success(); } else { return failure(); } } private: std::shared_ptr state_; }; class CSEWithAttributesIgnored : public CSEWithAttributesIgnoredBase { public: explicit CSEWithAttributesIgnored() {} explicit CSEWithAttributesIgnored(std::shared_ptr state) : state_(state) {} void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(op->getContext(), state_); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } private: std::shared_ptr state_; }; class CSEPutAttributes : public CSEPutAttributesBase { public: explicit CSEPutAttributes() {} explicit CSEPutAttributes(std::shared_ptr state) { state_ = state; } void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(op->getContext(), state_); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } private: std::shared_ptr state_; }; } // namespace std::unique_ptr createCSEWithAttributesIgnored() { return std::make_unique(); } std::unique_ptr createCSEPutAttributes() { return std::make_unique(); } std::pair, std::unique_ptr> createCSEPasses( std::shared_ptr state) { return std::make_pair(std::make_unique(state), std::make_unique(state)); } void registerCSEPasses(std::shared_ptr state) { ::mlir::registerPass([state]() -> std::unique_ptr<::mlir::Pass> { return std::make_unique(state); }); ::mlir::registerPass([state]() -> std::unique_ptr<::mlir::Pass> { return std::make_unique(state); }); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/ConvertInferenceOp.cpp 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "OneFlow/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "OneFlow/OneFlowPatternUtils.h" namespace mlir { namespace oneflow { namespace { class PreConvertInferenceOpPass : public PreConvertInferenceOpPassBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); oneflow::populatePreConvertInferenceOp(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; class ConvertInferenceOpPass : public ConvertInferenceOpPassBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); oneflow::populateConvertInferenceOp(patterns); oneflow::rewrites::populateRewrites(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; class PostConvertInferenceOpPass : public PostConvertInferenceOpPassBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); oneflow::populatePostConvertInferenceOp(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; } // namespace std::unique_ptr createPreConvertInferenceOpPass() { return std::make_unique(); } std::unique_ptr createConvertInferenceOpPass() { return std::make_unique(); } std::unique_ptr createPostConvertInferenceOpPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/EliminateAllocOps.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OneFlowPDLLPatterns.h" #include "OneFlow/Passes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace oneflow { namespace { class EliminateAllocOpsPass : public EliminateAllocOpsPassBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); mlir::oneflow::populateAllocEliminationPatterns(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; } // namespace std::unique_ptr createEliminateAllocOpsPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/FuncOps.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace func { struct FuncConversionToOneFlow final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto func = rewriter.create(op.getLoc(), op.getName(), op.getFunctionType()); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); } }; struct ReturnConversionToOneFlow final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp(op, /* operands */ op.getOperands()); return success(); } }; } // namespace func namespace oneflow { struct JobConversionToFunc final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Job op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto func = rewriter.create(op.getLoc(), op.getName(), op.getFunctionType()); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); } }; struct ReturnConversionToFunc final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp(op, /* operands */ op.getOperands()); return success(); } }; namespace { class OneFlowJobToFuncPass : public OneFlowJobToFuncPassBase { void runOnOperation() override { Operation* op = getOperation(); ConversionTarget target(getContext()); target.addLegalDialect(); RewritePatternSet patterns(&getContext()); patterns.add(op->getContext()); if 
(failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); LOG(ERROR) << "Failed to ofjob to func"; getOperation()->dump(); } } }; class FuncToOneFlowJobPass : public FuncToOneFlowJobPassBase { void getDependentDialects(::mlir::DialectRegistry& registry) const override { registry.insert(); } void runOnOperation() override { Operation* op = getOperation(); ConversionTarget target(getContext()); target.addLegalDialect(); RewritePatternSet patterns(&getContext()); patterns.add(op->getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); LOG(ERROR) << "Failed to func to ofjob"; getOperation()->dump(); } } }; } // namespace std::unique_ptr createOneFlowJobToFuncPass() { return std::make_unique(); } std::unique_ptr createFuncToOneFlowJobPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/GroupMatMulOps.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowOps.h" namespace mlir { namespace oneflow { template bool isLinearMatmulOp(OpTy op) { const bool isAlphaOne = op.getAlpha().convertToDouble() == 1.0; const bool isLinear = op.getTransposeA() == false && op.getTransposeB() == true; const bool hasNoAddToOutput = !op.get_addToOutput(); const bool isCUDA = op.getDeviceTag() == "cuda"; return isAlphaOne && isLinear && hasNoAddToOutput && isCUDA; } bool MatmulOp::isLinear() { return isLinearMatmulOp(*this); } Value MatmulOp::matMulGetX() { return getA(); } Value MatmulOp::matMulGetW() { return getB(); } Value MatmulOp::matMulGetY() { return getOut(); } bool BroadcastMatmulOp::isLinear() { return isLinearMatmulOp(*this); } Value BroadcastMatmulOp::matMulGetX() { return getA(); } Value BroadcastMatmulOp::matMulGetW() { return getB(); } Value BroadcastMatmulOp::matMulGetY() { return getOut(); } bool BiasAddOp::isLastDim() { return getAxis() == -1 || getAxis() == getOut().getType().cast().getRank() - 1; } Value BiasAddOp::biasAddGetBias() { return getB(); } Value BiasAddOp::biasAddGetOut() { return getOut(); } Value BroadcastAddOp::biasAddGetBias() { return getY(); } Value BroadcastAddOp::biasAddGetOut() { return getZ(); } bool BroadcastAddOp::isLastDim() { return true; } Value FusedMatmulBiasOp::matMulGetX() { return getX(); } Value FusedMatmulBiasOp::matMulGetW() { return getWeight(); } Value FusedMatmulBiasOp::matMulGetY() { return getOut(); } namespace { bool shouldGroupFusedMatmulBiasOp(FusedMatmulBiasOp& op) { return !op.get_addToOutput() && op.getDeviceTag() == "cuda" && op.getAlpha().convertToDouble() == 1.0; } } // namespace bool FusedMatmulBiasOp::isLinear() { return shouldGroupFusedMatmulBiasOp(*this); } bool FusedMatmulBiasOp::isLastDim() { return shouldGroupFusedMatmulBiasOp(*this); } Value FusedMatmulBiasOp::biasAddGetBias() { return getBias(); } Value 
FusedMatmulBiasOp::biasAddGetOut() { return getOut(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/JITPasses.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowUtils.h" #include "OneFlow/Passes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Bytecode/BytecodeWriter.h" namespace mlir { namespace oneflow { namespace { // general lowering path: // 1. outline linalg ops to a func.func and an oneflow.jit op // 2. bufferize the func.func and update oneflow.jit op's tmp buffer size // 1. collect ops to outline // 2. create func.func jit ops to call // 3. replace the usages with jit ops' results // entries: non-oneflow ops which have operands are from oneflow ops // exits: result consumed by oneflow ops // NOTE: we assume all arg values are produced by an oneflow op and won't be an argument NamedAttrList GetJitOpAttributes(Builder& rewriter, StringRef op_name, int32_t input_size, int32_t output_size, Operation* op) { NamedAttrList attributes; attributes.set(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), OpTrait::IsOpConfCompatible::getDeviceTag(op)); attributes.set(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), OpTrait::IsOpConfCompatible::getDeviceName(op)); if (auto hierarchy = OpTrait::IsOpConfCompatible::getHierarchy(op)) { attributes.set(OpTrait::IsOpConfCompatible::getHierarchyAttr(), hierarchy); } attributes.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), rewriter.getStringAttr(op_name)); if (auto scope_symbol_id = OpTrait::IsOpConfCompatible::getScopeSymbolID(op)) { attributes.set(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr(), scope_symbol_id); } return attributes; } bool isOneFlowOp(Operation* op) { return llvm::dyn_cast(op->getDialect()); } class Outliner { private: OpBuilder& builder; Block* body; llvm::DenseSet& visitedOps; std::queue worklist{}; void cloneOpsToNewBody(Operation* op, bool defer = false) { if (visitedOps.contains(op)) { return; } for (auto operand : op->getOperands()) { if (!mapping.lookupOrNull(operand)) { if (auto defOp = operand.getDefiningOp()) { if (isOneFlowOp(defOp)) { entries.insert(operand); auto arg = body->addArgument(operand.getType(), operand.getLoc()); mapping.map(operand, arg); mappingReversed.map(arg, operand); } else { cloneOpsToNewBody(defOp, true); } } } } ImplicitLocOpBuilder nb(op->getLoc(), builder); nb.clone(*op, mapping); visitedOps.insert(op); for (auto& use : op->getUses()) { auto owner = use.getOwner(); if (isOneFlowOp(owner)) { exits.insert(use.get()); } else { if (defer) { worklist.push(owner); } else { cloneOpsToNewBody(owner); } } } if (!defer) { while (!worklist.empty()) { 
auto op = worklist.front(); worklist.pop(); cloneOpsToNewBody(op); } } } public: Outliner(OpBuilder& builder, Block* body, Operation* op, llvm::DenseSet& visitedOps) : builder{builder}, body{body}, visitedOps{visitedOps} { cloneOpsToNewBody(op); } IRMapping mapping{}; IRMapping mappingReversed{}; llvm::DenseSet entries{}, exits{}; }; static std::string JITOpNamePrefix = "JITOpGenerated"; int64_t getCountJITFunction() { static std::atomic_int64_t countJITFunction = 0; return countJITFunction.fetch_add(1); } namespace { std::function getLowerFunction( const StringAttr& device_tag) { auto device_tag_str = device_tag.str(); #ifdef WITH_MLIR_CUDA_CODEGEN if (device_tag_str == "cuda") { return [](mlir::MLIRContext* mlir_ctx, mlir::ModuleOp module) { CHECK(mlir::succeeded(mlir::oneflow::LowerModuleToCUDALLVM(mlir_ctx, module))) << "fail to lower OneFlow to CUDA LLVM"; }; } #endif // WITH_MLIR_CUDA_CODEGEN if (device_tag_str == "cpu") { return [](mlir::MLIRContext* mlir_ctx, mlir::ModuleOp module) { CHECK(mlir::succeeded(mlir::oneflow::LowerModuleToLLVM(mlir_ctx, module))) << "fail to lower OneFlow to LLVM"; }; } LOG(FATAL) << "Fail to match lowering function with device tag name: " << device_tag_str; } std::string convertFuncToByte(func::FuncOp& func) { std::string byte; llvm::raw_string_ostream os_byte(byte); mlir::writeBytecodeToFile(func, os_byte); return byte; } std::string lowerFuncToLLVMByte(const std::string& raw_byte, const StringAttr& device_tag) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); registry.insert(); mlir::MLIRContext mlir_ctx(registry); mlir::OwningOpRef module = ::mlir::parseSourceString(raw_byte, &mlir_ctx); mlir::registerLLVMDialectTranslation(registry); if (::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_STDOUT", false)) { module->print(llvm::outs()); } getLowerFunction(device_tag)(&mlir_ctx, *module); if (::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_STDOUT", false)) { module->print(llvm::outs()); } (*module)->setAttr(jit::RAW_GRAPH, StringAttr::get(&mlir_ctx, raw_byte)); std::string byte; llvm::raw_string_ostream os_byte(byte); mlir::writeBytecodeToFile(*module, os_byte); return byte; } } // namespace class OutlineJitFunctionPass : public OutlineJitFunctionPassBase { void runOnOperation() override { llvm::DenseSet entryOps, visitedOps; FunctionOpInterface job = getOperation(); auto& operations = job.getFunctionBody().front().getOperations(); for (auto& op : operations) { if (llvm::dyn_cast(op.getDialect())) { for (auto result : op.getResults()) { for (auto user : result.getUsers()) { if (!isOneFlowOp(user)) { entryOps.insert(user); } } } } } OpBuilder builder{&getContext()}; for (auto entryOp : entryOps) { if (visitedOps.contains(entryOp)) { continue; } OpBuilder::InsertionGuard guard(builder); auto block = new Block(); builder.setInsertionPointToStart(block); auto outliner = Outliner(builder, block, entryOp, visitedOps); SmallVector<::mlir::Value, 4> entries, exits, mappedExits; SmallVector argumentTypes, resultTypes; for (Value exit : outliner.exits) { exits.push_back(exit); mappedExits.push_back(outliner.mapping.lookup(exit)); resultTypes.push_back(exit.getType()); } builder.setInsertionPointToEnd(block); builder.create(entryOp->getLoc(), mappedExits); for (auto argument : block->getArguments()) { if (auto found = outliner.mappingReversed.lookupOrNull(argument)) { entries.push_back(found); argumentTypes.push_back(argument.getType()); } else { job->emitError() << "fail to outline, entry not found for argument #" << argument.getArgNumber(); 
signalPassFailure(); } } auto funcType = builder.getFunctionType(argumentTypes, resultTypes); if (auto mod = job->getParentOfType()) { auto name = JITOpNamePrefix + std::to_string(getCountJITFunction()); SmallString<16> tempBuffer; name = SanitizeIdentifier(name, tempBuffer); builder.setInsertionPointToStart(&mod.getRegion().front()); auto function = builder.create(entryOp->getLoc(), name, funcType); function.getBody().push_front(block); if (auto lastOp = exits.back().getDefiningOp()) { builder.setInsertionPointAfter(lastOp); NamedAttrList attributes = GetJitOpAttributes(builder, name, argumentTypes.size(), resultTypes.size(), entryOp->getOperand(0).getDefiningOp()); std::string byte = compileToLLVM.getValue() ? lowerFuncToLLVMByte( convertFuncToByte(function), attributes.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()) .cast()) : convertFuncToByte(function); auto jitOp = builder.create(entryOp->getLoc(), function, attributes, entries); jitOp->setAttr("mlir_assembly", builder.getStringAttr(byte)); for (const auto& old : llvm::enumerate(exits)) { old.value().replaceAllUsesWith(jitOp->getResult(old.index())); } } else { job->emitError() << "fail to outline, nowhere to replace"; signalPassFailure(); } } else { job->emitError() << "fail to outline"; signalPassFailure(); } } } }; } // namespace std::unique_ptr createOutlineJitFunctionPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/OneFlowMemPool.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
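A hedged sketch (not part of the sources) of how the JIT outlining pass defined in JITPasses.cpp above could be added to a pipeline; whether it is nested under a function-like op or run at module scope is an assumption here, as is the helper name.

#include "OneFlow/Passes.h"
#include "mlir/Pass/PassManager.h"

// Outlines contiguous non-OneFlow ops into a func.func plus an oneflow jit op;
// when the pass's compile-to-LLVM option is set, the outlined function is also
// lowered and stored as bytecode in the jit op's "mlir_assembly" attribute.
void addJitOutliningPass(mlir::PassManager& pm) {  // hypothetical helper
  pm.addPass(mlir::oneflow::createOutlineJitFunctionPass());
}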
*/ #include "OneFlow/Passes.h" #include "OneFlow/Transform/OneFlowMemPool.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "oneflow/core/common/hash_container.h" #include "oneflow/core/job/intra_job_mem_sharing_util.h" #include #include #include #include #include namespace mlir { namespace oneflow { namespace { Type getMemPoolElemType(MLIRContext* ctx) { return IntegerType::get(ctx, 8); } const int align_size_ = ::oneflow::kBlobBodyAlignSize; struct AllocOpInfo { Operation* val_ = nullptr; int32_t start_lifetime_ = 0; int32_t end_lifetime_ = 0; size_t size_ = 0; }; template std::vector getAllocInfoList(T op) { std::vector list; // collect all memref.alloc ops and gpu.launch_func ops op->walk([&](memref::AllocOp alloc) { list.push_back(alloc); }); std::vector ret; for (auto alloc : list) { // compute size MemRefType type = alloc->getResult(0).getType().dyn_cast(); size_t size = type.getElementTypeBitWidth() / 8; for (int64_t i : type.getShape()) { size *= i; } size = (size / align_size_ + ((size % align_size_) != 0)) * align_size_; // compute lifetime // TODO: support lifetime analysis int32_t start_lifetime = 0; int32_t end_lifetime = INT_MAX; ret.push_back({alloc, start_lifetime, end_lifetime, size}); } return ret; } void replaceAllocwithSubview(func::FuncOp func, OpBuilder& builder, const ::oneflow::MemBlockResultInfo& ret) { // create the uni memref.alloc op builder.setInsertionPointToStart(&func.getBody().front()); auto output_type = MemRefType::get({static_cast(ret.mem_block_size)}, getMemPoolElemType(func->getContext())); Value mempool = builder.create(func->getLoc(), output_type); // replace alloc with subview for (auto [op, offset] : ret.regst_desc2offset) { MemRefType type = op->getResult(0).getType().cast(); Value byte_shift = builder.create(op->getLoc(), offset); Value new_op = builder.create(op->getLoc(), type, mempool, byte_shift, ValueRange{}); op->replaceAllUsesWith(ValueRange{new_op}); op->erase(); } } bool isMemPool(Operation* op) { auto alloc = dyn_cast(op); if (!alloc) return false; MemRefType type = alloc->getOpResult(0).getType().cast(); if (!type) return false; return type.getRank() == 1 && type.getElementType() == getMemPoolElemType(op->getContext()); } struct InsertOneFlowMemPoolPattern final : public OpRewritePattern { // GetAllocOpSize(funop) -> std::pair getAllocOp(func::FuncOp func) const { memref::AllocOp ret; auto& ops = func.getBody().front(); for (auto& op : ops) { if (auto alloc = llvm::dyn_cast_or_null(op)) { if (ret) return {false, ret}; ret = alloc; } } return {true, ret}; } MemRefType getNullMemType(mlir::PatternRewriter& rewriter) const { return MemRefType::get({1}, getMemPoolElemType(rewriter.getContext())); } public: explicit InsertOneFlowMemPoolPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { auto module = op->getParentOfType(); if (module && module->getAttr(codegen::mempool::MEMPOOL_ATTR_NAME)) return success(); auto 
[is_legal, alloc_op] = getAllocOp(op); if (!is_legal) { LOG(FATAL) << "you should run -fold-memref-alloc before insert-ofmem-pool pass"; return failure(); } auto type = alloc_op ? alloc_op->getResult(0).getType().dyn_cast_or_null() : getNullMemType(rewriter); if (type.getRank() != 1 || type.getElementType() != getMemPoolElemType(op->getContext())) { LOG(FATAL) << "the alloc op fail to matching memref"; return failure(); } llvm::SmallVector new_operand_types; new_operand_types.push_back(type); for (auto type : op.getFunctionType().getInputs()) { new_operand_types.push_back(type); } auto function_type = rewriter.getFunctionType(new_operand_types, op.getFunctionType().getResults()); auto func = rewriter.create(op.getLoc(), op.getName(), function_type); for (auto pair : op->getDialectAttrs()) { func->setAttr(pair.getName(), pair.getValue()); } op.getBody().insertArgument(unsigned(0), type, op->getLoc()); if (alloc_op) rewriter.replaceOp(alloc_op, {op.getArgument(0)}); IRMapping bvm; op.getRegion().cloneInto(&func.getRegion(), bvm); rewriter.eraseOp(op); module->setAttr(codegen::mempool::MEMPOOL_ATTR_NAME, rewriter.getI64IntegerAttr(type.getDimSize(0))); return success(); } }; class InsertOneFlowMemPoolPass : public InsertOneFlowMemPoolPassBase { void runOnOperation() override { Operation* op = getOperation(); auto ctx = op->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; class FoldAllocToSubviewPass : public FoldAllocToSubviewPassBase { void runOnOperation() override { func::FuncOp op = getOperation(); applyFoldAlloc(op); } }; } // namespace void applyFoldAlloc(func::FuncOp op) { std::vector list; // TODO-1: support cpu memory fold // TODO-2: support multiple gpu.launch op->walk([&](gpu::LaunchOp launchOp) { list = getAllocInfoList(launchOp); }); op->walk([&](scf::ForallOp launchOp) { list = getAllocInfoList(launchOp); }); { std::vector body_list; body_list = getAllocInfoList(op); list.insert(list.end(), body_list.begin(), body_list.end()); } auto ctx = op->getContext(); OpBuilder builder(ctx); // Note: no malloc op should be folded. if (!list.size()) { return; } // Note: the single malloc op with out type of memref means it has been folded. if (list.size() == 1 && oneflow::isMemPool(list.front().val_)) { return; } ::oneflow::HashMap> val2lifetime; ::oneflow::HashMap val2size; for (const auto& info : list) { val2lifetime[info.val_] = {info.start_lifetime_, info.end_lifetime_}; val2size[info.val_] = info.size_; } ::oneflow::MemBlockResultInfo ret; ::oneflow::MemReusedMemSizeFirstAlgo(false, val2lifetime, val2size, &ret); oneflow::replaceAllocwithSubview(op, builder, ret); } std::unique_ptr createInsertOneFlowMemPoolPass() { return std::make_unique(); } std::unique_ptr createFoldAllocToSubviewPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/OneFlowStream.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
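Two hedged sketches (not part of the sources) related to OneFlowMemPool.cpp above: the size round-up used when collecting memref.alloc info, and the pass ordering implied by the error message in InsertOneFlowMemPoolPattern (fold allocations into one pool before inserting the pool argument). Helper names and the exact nesting are assumptions.

#include <cstddef>

#include "OneFlow/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

// Mirrors the alignment computation in getAllocInfoList: round `size` up to a
// multiple of `align` (::oneflow::kBlobBodyAlignSize in the pass).
inline size_t RoundUpToAlign(size_t size, size_t align) {
  return (size / align + ((size % align) != 0)) * align;
}

// Hypothetical helper: fold individual memref.alloc ops into one pool first,
// then expose the pool as the function's leading argument.
void addMemPoolPasses(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::func::FuncOp>(mlir::oneflow::createFoldAllocToSubviewPass());
  pm.addPass(mlir::oneflow::createInsertOneFlowMemPoolPass());
}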
See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowPDLLPatterns.h" #include "OneFlow/Passes.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include #include namespace mlir { namespace oneflow { namespace { struct MgpuToOneFlowStreamPattern final : public OpRewritePattern { public: explicit MgpuToOneFlowStreamPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(LLVM::CallOp op, mlir::PatternRewriter& rewriter) const override { auto ptr_type = LLVM::LLVMPointerType::get(rewriter.getContext()); auto func = op->getParentOfType(); auto callee = op.getCallee(); if (!func || !callee) return failure(); Value stream = func.getArguments().back(); if (stream.getType() != ptr_type) { LOG(ERROR) << "failed to find stream in llvm.func block arguments"; return failure(); } DenseMap, std::function>> oneflow_abi = { {"mgpuStreamCreate", {[](LLVM::CallOp& op, Value& stream) { return true; }, [](mlir::PatternRewriter& rewriter, LLVM::CallOp& op, Value& stream) { rewriter.replaceOp(op, {stream}); }}}, {"mgpuLaunchKernel", {[](LLVM::CallOp& op, Value& stream) { unsigned idx = op->getNumOperands(); return op.getOperand(idx - 3) != stream; }, [](mlir::PatternRewriter& rewriter, LLVM::CallOp& op, Value& stream) { unsigned idx = op->getNumOperands(); auto target = op.getOperand(idx - 3).getDefiningOp(); rewriter.replaceOp(target, {stream}); }}}, // this sync operation is created by gpu-to-llvm-pass from gpu.launch_func op. 
{"mgpuStreamSynchronize", {[](LLVM::CallOp& op, Value& stream) { return true; }, [](mlir::PatternRewriter& rewriter, LLVM::CallOp& op, Value& stream) { rewriter.eraseOp(op); }}}, {"mgpuStreamDestroy", {[](LLVM::CallOp& op, Value& stream) { return true; }, [](mlir::PatternRewriter& rewriter, LLVM::CallOp& op, Value& stream) { rewriter.eraseOp(op); }}}, }; auto out = oneflow_abi.find(callee.value().str()); if (out != oneflow_abi.end() && out->getSecond().first(op, stream)) { out->getSecond().second(rewriter, op, stream); } return success(); } }; struct AppendOneFlowStreamPattern final : public OpRewritePattern { public: explicit AppendOneFlowStreamPattern(mlir::MLIRContext* context) : OpRewritePattern(context, /*benefit=*/0) {} mlir::LogicalResult matchAndRewrite(func::FuncOp op, mlir::PatternRewriter& rewriter) const override { auto ptr_type = LLVM::LLVMPointerType::get(rewriter.getContext()); if (llvm::dyn_cast(op.getFunctionType().getInputs().back())) return success(); llvm::SmallVector new_operand_type; for (auto type : op.getFunctionType().getInputs()) { new_operand_type.push_back(type); } new_operand_type.push_back(ptr_type); auto function_type = rewriter.getFunctionType(new_operand_type, op.getFunctionType().getResults()); auto func = rewriter.create(op.getLoc(), op.getName(), function_type); for (auto pair : op->getDialectAttrs()) { func->setAttr(pair.getName(), pair.getValue()); } op.getBody().addArgument(ptr_type, func->getLoc()); IRMapping bvm; op.getRegion().cloneInto(&func.getRegion(), bvm); rewriter.eraseOp(op); return success(); } }; class AppendOneFlowStreamPass : public AppendOneFlowStreamPassBase { void runOnOperation() override { Operation* op = getOperation(); auto ctx = op->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; class MgpuToOneFlowStreamPass : public MgpuToOneFlowStreamPassBase { void runOnOperation() override { Operation* op = getOperation(); auto ctx = op->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; } // namespace std::unique_ptr createAppendOneFlowStreamPass() { return std::make_unique(); } std::unique_ptr createMgpuToOneFlowStreamPass() { return std::make_unique(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/Transform/OutlineAndFuse.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
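A hedged sketch (not part of the sources) of the ordering suggested by OneFlowStream.cpp above: the stream pointer is appended as the last function argument first, so that the mgpu* call rewriting can then pick it up from the back of the argument list. The helper name is an assumption.

#include "OneFlow/Passes.h"
#include "mlir/Pass/PassManager.h"

void addStreamPasses(mlir::PassManager& pm) {  // hypothetical helper
  pm.addPass(mlir::oneflow::createAppendOneFlowStreamPass());
  pm.addPass(mlir::oneflow::createMgpuToOneFlowStreamPass());
}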
*/ #include "OneFlow/Transform/OutlineAndFuse.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/Passes.h" #include "OneFlow/OneFlowPDLLPatterns.h" #include "OneFlow/OneFlowPatternUtils.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include #include namespace mlir { namespace oneflow { namespace { class WrapOpsToKernelLaunchPass : public WrapOpsToKernelLaunchPassBase { public: WrapOpsToKernelLaunchPass() = default; WrapOpsToKernelLaunchPass(const WrapOpsToKernelLaunchPass& other) : WrapOpsToKernelLaunchPassBase(other) {} void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); populateWrapOpsToKernelLaunchPatterns(patterns, wrap_ops_mode_.c_str()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } private: Option wrap_ops_mode_{*this, "mode", llvm::cl::desc("the mode of this pass to wrap ops"), llvm::cl::init(wrap_mode::SIMPLE)}; }; class FuseIntoExistingOpPass : public FuseIntoExistingOpPassBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); populateFuserForExistingOp(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; namespace { BiasAddCompatible getBiasAddCompatibleOp(MatMulCompatible op) { BiasAddCompatible bias_add; auto self_bias_op = dyn_cast(op.getOperation()); if (self_bias_op) /* matmul itself is also bias add op */ { bias_add = self_bias_op; } else /* there is bias add op */ { for (auto u : op.matMulGetY().getUsers()) { if (auto b = dyn_cast(u)) { bias_add = b; break; } } } if (bias_add && bias_add.isLastDim()) { return bias_add; } else { return BiasAddCompatible{}; } } } // namespace struct GroupMatMulPattern : public mlir::OpInterfaceRewritePattern { explicit GroupMatMulPattern(mlir::MLIRContext* context) : OpInterfaceRewritePattern(context, /*benefit=*/1) {} mlir::LogicalResult matchAndRewrite(MatMulCompatible op, mlir::PatternRewriter& rewriter) const override { if (!op.isLinear()) { return failure(); } auto bias_add = getBiasAddCompatibleOp(op); llvm::SmallVector all_matmuls{}; llvm::SmallVector all_bias_adds{}; for (auto xUser : op.matMulGetX().getUsers()) { if (auto matmul = dyn_cast(xUser)) { if (!matmul.isLinear()) { continue; } auto each_bias_add = getBiasAddCompatibleOp(matmul); if (each_bias_add) { all_bias_adds.push_back(each_bias_add); } if (!!bias_add == !!each_bias_add) { all_matmuls.push_back(matmul); } } } // all_matmuls has only self, means no other matmul can be grouped if (all_matmuls.size() == 1) { return failure(); } llvm::SmallVector operands{}; for (auto matmul : all_matmuls) { operands.push_back(matmul.matMulGetX()); } for (auto matmul : all_matmuls) { operands.push_back(matmul.matMulGetW()); } for (auto bias_adds : all_bias_adds) { operands.push_back(bias_adds.biasAddGetBias()); } llvm::SmallVector results{}; for (auto matmul : all_matmuls) { results.push_back(matmul.matMulGetY().getType()); } NamedAttrList attributes{}; attributes.set(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), OpTrait::IsOpConfCompatible::getDeviceTag(op)); attributes.set(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), 
OpTrait::IsOpConfCompatible::getDeviceName(op)); if (auto hierarchy = OpTrait::IsOpConfCompatible::getHierarchy(op)) { attributes.set(OpTrait::IsOpConfCompatible::getHierarchyAttr(), hierarchy); } if (auto scope_symbol_id = OpTrait::IsOpConfCompatible::getScopeSymbolID(op)) { attributes.set(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr(), scope_symbol_id); } attributes.set(OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), rewriter.getDenseI32ArrayAttr({static_cast(all_matmuls.size()), static_cast(all_matmuls.size()), static_cast(all_bias_adds.size())})); attributes.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), rewriter.getStringAttr( "grouped_matmul_" + OpTrait::IsOpConfCompatible::getOpName(op).str())); auto grouped_matmul = rewriter.create(op->getLoc(), results, operands, attributes); if (all_bias_adds.empty()) { for (const auto& matmul : llvm::enumerate(all_matmuls)) { matmul.value().matMulGetY().replaceAllUsesWith(grouped_matmul.getYs()[matmul.index()]); } } else { CHECK(all_bias_adds.size() == all_matmuls.size()); for (const auto& bias_add : llvm::enumerate(all_bias_adds)) { bias_add.value().biasAddGetOut().replaceAllUsesWith( grouped_matmul.getYs()[bias_add.index()]); } } return success(); } }; class GroupMatMulPass : public GroupMatMulBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(op->getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; struct GroupNormActivationPattern : public OpRewritePattern { explicit GroupNormActivationPattern(MLIRContext* context) : OpRewritePattern(context, /*benefit=*/1) {} LogicalResult matchAndRewrite(oneflow::GroupNormOp op, PatternRewriter& rewriter) const override { if (op.getActivation() == "none") { llvm::SmallVector act_ops{}; for (auto& u : op.getY().getUses()) { if (auto act_op = dyn_cast(u.getOwner())) { act_ops.push_back(act_op); } } NamedAttrList attributes(op->getAttrs()); attributes.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), rewriter.getStringAttr(OpTrait::IsOpConfCompatible::getOpName(op).str() + "_with_activation")); attributes.set("activation", rewriter.getStringAttr("silu")); auto gn_with_act = rewriter.create(op->getLoc(), op->getResultTypes(), op.getOperands(), attributes); for (auto act : act_ops) { if (auto op = dyn_cast(act)) { op.getOut().replaceAllUsesWith(gn_with_act.getY()); } } return success(); } return failure(); } }; class FuseForwardOpsPass : public FuseForwardOpsBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.add(op->getContext()); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; class FuseOpsWithBackwardImplPass : public FuseOpsWithBackwardImplBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); populateFuseOpsWithBackwardImplPattern(patterns); rewrites::populateRewrites(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; class FuseNormalizationOpsPass : public FuseNormalizationOpsBase { void runOnOperation() override { Operation* op = getOperation(); RewritePatternSet patterns(op->getContext()); populateNormalizationOpPatterns(patterns); rewrites::populateRewrites(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; } // namespace std::unique_ptr createWrapOpsToKernelLaunchPass() { return std::make_unique(); } std::unique_ptr createFuseIntoExistingOpPass() { return 
std::make_unique<FuseIntoExistingOpPass>();
}

std::unique_ptr<Pass> createGroupMatMul() { return std::make_unique<GroupMatMulPass>(); }

std::unique_ptr<Pass> createFuseForwardOps() { return std::make_unique<FuseForwardOpsPass>(); }

std::unique_ptr<Pass> createFuseOpsWithBackwardImpl() {
  return std::make_unique<FuseOpsWithBackwardImplPass>();
}

std::unique_ptr<Pass> createFuseNormalizationOps() {
  return std::make_unique<FuseNormalizationOpsPass>();
}

} // namespace oneflow
} // namespace mlir

================================================
FILE: oneflow/ir/lib/OneFlow/Transform/TraitFolder.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "OneFlow/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>

namespace mlir {
namespace oneflow {

namespace {

class TestOneFlowTraitFolderPass : public TestOneFlowTraitFolderPassBase<TestOneFlowTraitFolderPass> {
  void runOnOperation() override {
    if (failed(applyPatternsAndFoldGreedily(getOperation(), RewritePatternSet(&getContext())))) {
      exit(1);
    }
  }
};

} // namespace

std::unique_ptr<Pass> createTestOneFlowTraitFolderPass() {
  return std::make_unique<TestOneFlowTraitFolderPass>();
}

} // namespace oneflow
} // namespace mlir

================================================
FILE: oneflow/ir/lib/OneFlow/TransposeHelpers.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
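A hedged sketch (not part of the sources) of how the fusion and grouping passes defined in OutlineAndFuse.cpp above might be combined; the helper name and the ordering are assumptions, while the factories are the ones declared above.

#include "OneFlow/Passes.h"
#include "mlir/Pass/PassManager.h"

void addFusionPasses(mlir::PassManager& pm) {  // hypothetical helper
  pm.addPass(mlir::oneflow::createFuseForwardOps());
  pm.addPass(mlir::oneflow::createFuseIntoExistingOpPass());
  pm.addPass(mlir::oneflow::createGroupMatMul());
}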
*/ #include namespace mlir { namespace oneflow { RankedTensorType getNHWCType(RankedTensorType t) { return RankedTensorType::get({t.getShape()[0], t.getShape()[2], t.getShape()[3], t.getShape()[1]}, t.getElementType()); } RankedTensorType getNHWCType(Type t) { return getNHWCType(t.cast()); } RankedTensorType getNHWCType(Value v) { return getNHWCType(v.getType()); } RankedTensorType getNCHWType(RankedTensorType t) { return RankedTensorType::get({t.getShape()[0], t.getShape()[3], t.getShape()[1], t.getShape()[2]}, t.getElementType()); } RankedTensorType getNCHWType(Type t) { return getNCHWType(t.cast()); } RankedTensorType getNCHWType(Value v) { return getNCHWType(v.getType()); } llvm::SmallVector getNHWCResultTypes(NCHWCompatible op) { llvm::SmallVector result_types; llvm::DenseSet transpose_result = op.ResultsToTranspose(); for (Value result : op->getOpResults()) { Type t = result.getType(); if (transpose_result.find(result) != transpose_result.end()) { result_types.push_back(getNHWCType(t)); } else { result_types.push_back(t); } } return result_types; } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/UserOpConversion.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // this file should contains functions to get operands and results with user op name and index #include "OneFlow/UserOpConversion.h" #include "OneFlow/UserOpReflection.h" #include "oneflow/core/framework/user_op_def.h" namespace mlir { namespace oneflow { namespace user_op { LogicalResult saveAttrDictionaryToOpConf(DictionaryAttr attributes, ::oneflow::OperatorConf* op_conf) { if (auto scope_symbol_id = attributes.get(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr()) .dyn_cast_or_null()) { op_conf->set_scope_symbol_id(scope_symbol_id.getInt()); } if (auto op_name = attributes.get(OpTrait::IsOpConfCompatible::getOpNameAttr()) .dyn_cast_or_null()) { op_conf->set_name(op_name.str()); } auto device_tag = attributes.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()) .dyn_cast_or_null(); CHECK(device_tag) << "attr absent: " << OpTrait::IsOpConfCompatible::getDeviceTagAttr().str(); op_conf->set_device_tag(device_tag.str()); return success(); } LogicalResult doConvertUserOpAttributes(llvm::StringRef op_type_name, DictionaryAttr attributes, ::oneflow::OperatorConf& op_conf) { auto user_conf = op_conf.mutable_user_conf(); op_conf.mutable_user_conf()->set_op_type_name(op_type_name.str()); CHECK(saveAttrDictionaryToOpConf(attributes, &op_conf).succeeded()); for (auto id_attr : attributes) { auto id = id_attr.getName(); // mlir only attrs // TODO: prefix special attributes with "oneflow.". 
For example: `oneflow.op_type_name = "add"` if (id.strref().equals("callee") || id.strref().equals(OpTrait::IsOpConfCompatible::getDeviceNameAttr()) || id.strref().equals(OpTrait::IsOpConfCompatible::getHierarchyAttr()) || id.strref().equals(OpTrait::IsImportCompatible::getOutputLBNsAttr()) || id.strref().equals(OpTrait::IsAlternative::getOpTypeNameAttr()) || id.strref().equals( mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()) || id.strref().equals( mlir::OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr())) { continue; } else if (id.strref().equals("input_sizes") || id.strref().equals("output_sizes")) { continue; } // convert op conf attributes else if (id.strref().equals(OpTrait::IsOpConfCompatible::getOpNameAttr())) { continue; } else if (id.strref().equals(OpTrait::IsOpConfCompatible::getDeviceTagAttr())) { continue; } else if (id.strref().equals(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr())) { continue; } // convert user conf attributes else { auto attr_name = id.str(); Attribute attr = id_attr.getValue(); auto user_attr = ::oneflow::AttrValue(); const ::oneflow::AttrType attr_type = queryAttrType(op_type_name.str(), attr_name); if (attr_type == ::oneflow::kAtInt32) { user_attr.set_at_int32(attr.dyn_cast().getSInt()); } else if (attr_type == ::oneflow::kAtInt64) { user_attr.set_at_int64(attr.dyn_cast().getSInt()); } else if (attr_type == ::oneflow::kAtBool) { user_attr.set_at_bool(attr.dyn_cast().getValue()); } else if (attr_type == ::oneflow::kAtFloat) { user_attr.set_at_float(attr.dyn_cast().getValue().convertToFloat()); } else if (attr_type == ::oneflow::kAtDouble) { user_attr.set_at_double(attr.dyn_cast().getValue().convertToDouble()); } else if (attr_type == ::oneflow::kAtString) { user_attr.set_at_string(attr.dyn_cast().getValue().str()); } else if (attr_type == ::oneflow::kAtShape) { *user_attr.mutable_at_shape() = getAttrAsShape(attr); } else if (attr_type == ::oneflow::kAtStride) { *user_attr.mutable_at_stride() = getAttrAsStride(attr); } else if (attr_type == ::oneflow::kAtDataType) { const auto dt = support::FromMLIRAttrToOFDataType(attr); if (succeeded(dt)) { user_attr.set_at_data_type(dt.value()); } else { LOG(FATAL) << "fail to convert op attr to data type, key: " + id.str(); return failure(); } } else if (attr_type == ::oneflow::kAtListInt32) { user_attr.mutable_at_list_int32(); auto ref = attr.dyn_cast(); for (auto v : ref.getValue()) { user_attr.mutable_at_list_int32()->add_val(v.dyn_cast().getSInt()); } } else if (attr_type == ::oneflow::kAtListInt64) { user_attr.mutable_at_list_int64(); auto ref = attr.dyn_cast(); for (auto v : ref.getValue()) { user_attr.mutable_at_list_int64()->add_val(v.dyn_cast().getSInt()); } } else if (attr_type == ::oneflow::kAtListFloat) { user_attr.mutable_at_list_float(); auto ref = attr.dyn_cast(); for (auto v : ref.getValue()) { user_attr.mutable_at_list_float()->add_val( v.dyn_cast().getValue().convertToFloat()); } } else if (attr_type == ::oneflow::kAtListDataType) { for (auto v : attr.dyn_cast().getValue()) { const auto dt = support::FromMLIRAttrToOFDataType(attr); if (succeeded(dt)) { user_attr.mutable_at_list_data_type()->add_val(dt.value()); } else { LOG(FATAL) << "fail to convert op attr to data type, key: " + id.str(); return failure(); } } } else if (attr_type == ::oneflow::kAtListShape) { for (auto shape_attr : attr.dyn_cast().getValue()) { ::oneflow::ShapeProto* shape_ptr = user_attr.mutable_at_list_shape()->add_val(); *shape_ptr = getAttrAsShape(shape_attr); } } else if (attr_type == 
::oneflow::kAtListStride) { for (auto stride_attr : attr.dyn_cast().getValue()) { ::oneflow::Int64ListProto* stride_ptr = user_attr.mutable_at_list_stride()->add_val(); *stride_ptr = getAttrAsStride(stride_attr); } } else if (attr_type == ::oneflow::kAtListString) { // attr like nd_sbp requires the existence of list even it is empty user_attr.mutable_at_list_string(); for (auto s : attr.dyn_cast().getValue()) { user_attr.mutable_at_list_string()->add_val(s.dyn_cast().getValue().str()); } } else if (attr_type == ::oneflow::kAtComplexDouble) { // TODO(lml): use arrayattr to represent complex number is not safe, need improve. user_attr.mutable_at_complex_double(); auto ref = attr.dyn_cast(); user_attr.mutable_at_complex_double()->set_real( ref.getValue()[0].dyn_cast().getValue().convertToDouble()); user_attr.mutable_at_complex_double()->set_imag( ref.getValue()[1].dyn_cast().getValue().convertToDouble()); } else { return failure(); } (*user_conf->mutable_attr())[id.str()] = user_attr; } } return success(); } LogicalResult ConvertUserOpAttributes(llvm::StringRef op_type_name, ValueRange operands, DictionaryAttr attributes, ::oneflow::OperatorConf& op_conf) { { std::vector keys{}; std::vector sizes{}; if (failed(user_op::GetFilteredSegmentKeyAndSizes( op_type_name, operands.size(), attributes, keys, sizes))) { LOG(FATAL) << "fail to get filtered segment key and sizes"; return failure(); } for (const auto& s : keys) { op_conf.mutable_user_conf()->add_input_order(s); } } return doConvertUserOpAttributes(op_type_name, attributes, op_conf); } LogicalResult ConvertUserOpAttributes(Operation* op, ::oneflow::OperatorConf& op_conf) { std::string op_type_name = GetOpTypeName(op); { std::vector keys{}; std::vector sizes{}; if (failed(user_op::GetFilteredSegmentKeyAndSizes(op, keys, sizes))) { op->emitError("fail to convert user op input order"); return failure(); } for (const auto& s : keys) { op_conf.mutable_user_conf()->add_input_order(s); } } { std::vector keys{}; std::vector sizes{}; if (failed(user_op::GetFilteredSegmentKeyAndSizes(op, keys, sizes))) { op->emitError("fail to convert user op output order"); return failure(); } for (const auto& s : keys) { op_conf.mutable_user_conf()->add_output_order(s); } } return doConvertUserOpAttributes(op_type_name, op->getAttrDictionary(), op_conf); } LogicalResult ConvertUserOpAttributes(Operation* op, ::oneflow::OperatorConf& op_conf, bool is_mapping_size) { auto user_conf = op_conf.mutable_user_conf(); std::string op_type_name = GetOpTypeName(op); op_conf.mutable_user_conf()->set_op_type_name(op_type_name); if (op->hasTrait()) { if (OpTrait::IsOpConfCompatible::dump_attr(op, &op_conf).failed()) { return op->emitError("fail to save attr to op_conf"); } } auto writeAttrToShape = [](mlir::Attribute& attr, ::oneflow::ShapeProto* shape) { for (auto v : attr.dyn_cast().getValue()) { shape->add_dim(v.dyn_cast().getSInt()); } }; auto writeAttrToStride = [](mlir::Attribute& attr, ::oneflow::Int64ListProto* stride) { for (auto v : attr.dyn_cast().getValue()) { stride->add_dim(v.dyn_cast().getSInt()); } }; for (auto id_attr : op->getAttrDictionary()) { auto id = id_attr.getName(); // mlir only attrs // TODO: prefix special attributes with "oneflow.". 
For example: `oneflow.op_type_name = "add"` if (id.strref().equals("callee") || id.strref().equals(OpTrait::IsOpConfCompatible::getDeviceNameAttr()) || id.strref().equals(OpTrait::IsOpConfCompatible::getHierarchyAttr()) || id.strref().equals(OpTrait::IsImportCompatible::getOutputLBNsAttr()) || id.strref().equals(OpTrait::IsAlternative::getOpTypeNameAttr()) || id.strref().equals( mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()) || id.strref().equals( mlir::OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr())) { continue; } else if (id.strref().equals("input_sizes") || id.strref().equals("output_sizes")) { continue; } // convert op conf attributes else if (id.strref().equals(OpTrait::IsOpConfCompatible::getOpNameAttr())) { continue; } else if (id.strref().equals(OpTrait::IsOpConfCompatible::getDeviceTagAttr())) { continue; } else if (id.strref().equals(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr())) { continue; } // convert user conf attributes else { auto attr_name = id.str(); Attribute attr = id_attr.getValue(); auto user_attr = ::oneflow::AttrValue(); const ::oneflow::AttrType attr_type = user_op::queryAttrType(op_type_name, attr_name); if (attr_type == ::oneflow::kAtInt32) { user_attr.set_at_int32(attr.dyn_cast().getSInt()); } else if (attr_type == ::oneflow::kAtInt64) { user_attr.set_at_int64(attr.dyn_cast().getSInt()); } else if (attr_type == ::oneflow::kAtBool) { user_attr.set_at_bool(attr.dyn_cast().getValue()); } else if (attr_type == ::oneflow::kAtFloat) { user_attr.set_at_float(attr.dyn_cast().getValue().convertToFloat()); } else if (attr_type == ::oneflow::kAtDouble) { user_attr.set_at_double(attr.dyn_cast().getValue().convertToDouble()); } else if (attr_type == ::oneflow::kAtString) { user_attr.set_at_string(attr.dyn_cast().getValue().str()); } else if (attr_type == ::oneflow::kAtShape) { writeAttrToShape(attr, user_attr.mutable_at_shape()); } else if (attr_type == ::oneflow::kAtStride) { writeAttrToStride(attr, user_attr.mutable_at_stride()); } else if (attr_type == ::oneflow::kAtDataType) { const auto dt = support::FromMLIRAttrToOFDataType(attr); if (succeeded(dt)) { user_attr.set_at_data_type(dt.value()); } else { op->emitError() << "fail to convert op attr to data type, key: " + id.str(); return failure(); } } else if (attr_type == ::oneflow::kAtListInt32) { user_attr.mutable_at_list_int32(); auto ref = attr.dyn_cast(); for (auto v : ref.getValue()) { user_attr.mutable_at_list_int32()->add_val(v.dyn_cast().getSInt()); } } else if (attr_type == ::oneflow::kAtListInt64) { user_attr.mutable_at_list_int64(); auto ref = attr.dyn_cast(); for (auto v : ref.getValue()) { user_attr.mutable_at_list_int64()->add_val(v.dyn_cast().getSInt()); } } else if (attr_type == ::oneflow::kAtListFloat) { user_attr.mutable_at_list_float(); auto ref = attr.dyn_cast(); for (auto v : ref.getValue()) { user_attr.mutable_at_list_float()->add_val( v.dyn_cast().getValue().convertToFloat()); } } else if (attr_type == ::oneflow::kAtListDataType) { for (auto v : attr.dyn_cast().getValue()) { const auto dt = support::FromMLIRAttrToOFDataType(attr); if (succeeded(dt)) { user_attr.mutable_at_list_data_type()->add_val(dt.value()); } else { op->emitError() << "fail to convert op attr to data type, key: " + id.str(); return failure(); } } } else if (attr_type == ::oneflow::kAtListShape) { for (auto shape_attr : attr.dyn_cast().getValue()) { ::oneflow::ShapeProto* shape_ptr = user_attr.mutable_at_list_shape()->add_val(); writeAttrToShape(shape_attr, shape_ptr); } } else if 
(attr_type == ::oneflow::kAtListStride) { for (auto stride_attr : attr.dyn_cast().getValue()) { ::oneflow::Int64ListProto* stride_ptr = user_attr.mutable_at_list_stride()->add_val(); writeAttrToStride(stride_attr, stride_ptr); } } else if (attr_type == ::oneflow::kAtListString) { // attr like nd_sbp requires the existence of list even it is empty user_attr.mutable_at_list_string(); for (auto s : attr.dyn_cast().getValue()) { user_attr.mutable_at_list_string()->add_val(s.dyn_cast().getValue().str()); } } else if (attr_type == ::oneflow::kAtComplexDouble) { // TODO(lml): use arrayattr to represent complex number is not safe, need improve. user_attr.mutable_at_complex_double(); auto ref = attr.dyn_cast(); user_attr.mutable_at_complex_double()->set_real( ref.getValue()[0].dyn_cast().getValue().convertToDouble()); user_attr.mutable_at_complex_double()->set_imag( ref.getValue()[1].dyn_cast().getValue().convertToDouble()); } else if (attr_type == ::oneflow::kAtBytes) { auto value = attr.dyn_cast().getValue().str(); // The trailing null character also needs to be saved. user_attr.mutable_at_bytes()->assign(value.data(), value.size() + 1); } else { op->emitError() << "fail to convert op attr of name: " + attr_name; return failure(); } (*user_conf->mutable_attr())[id.str()] = user_attr; } } { std::vector keys{}; std::vector sizes{}; if (failed(user_op::GetFilteredSegmentKeyAndSizes(op, keys, sizes))) { op->emitError("fail to convert user op input order"); return failure(); } for (const auto& s : keys) { op_conf.mutable_user_conf()->add_input_order(s); } if (is_mapping_size) { for (const auto it : llvm::zip(keys, sizes)) { auto key = std::get<0>(it).c_str(); auto size = std::get<1>(it); auto tar = op_conf.mutable_user_conf()->mutable_input(); auto val = ::oneflow::UserOpConf_ListString::default_instance(); tar->insert({key, val}); for (int i = 0; i < size; ++i) { tar->at(key).add_s(); } } } } { std::vector keys{}; std::vector sizes{}; if (failed(user_op::GetFilteredSegmentKeyAndSizes(op, keys, sizes))) { op->emitError("fail to convert user op output order"); return failure(); } for (const auto& s : keys) { op_conf.mutable_user_conf()->add_output_order(s); } if (is_mapping_size) { for (const auto it : llvm::zip(keys, sizes)) { auto key = std::get<0>(it).c_str(); auto size = std::get<1>(it); auto tar = op_conf.mutable_user_conf()->mutable_output(); auto val = ::oneflow::UserOpConf_ListString::default_instance(); tar->insert({key, val}); for (int i = 0; i < size; ++i) { tar->at(key).add_s(); } } } } return success(); } LogicalResult ConvertUserOpInputs(llvm::StringRef op_type_name, ValueRange operands, DictionaryAttr attributes, ::oneflow::UserOpConf* user_conf) { std::vector keys{}; std::vector sizes{}; CHECK(user_op::GetFilteredSegmentKeyAndSizes( op_type_name, operands.size(), attributes, keys, sizes) .succeeded()); int32_t input_idx = 0; for (auto tuple : llvm::zip(keys, sizes)) { auto input_key = std::get<0>(tuple); auto input_size = std::get<1>(tuple); for (int32_t i = 0; i < input_size; i++) { auto input_s_ptr = (*user_conf->mutable_input())[input_key].mutable_s()->Add(); if (auto result = operands[input_idx].dyn_cast()) { *(input_s_ptr) = GetOutputLbn(result).value(); } else if (auto argument = operands[input_idx].dyn_cast()) { *(input_s_ptr) = "BlockArgument/" + std::to_string(argument.getArgNumber()); } else { LOG(FATAL) << "fail to convert MLIR result to protobuf, op_type_name: " + op_type_name.str(); return failure(); } input_idx += 1; } } return success(); } ::oneflow::ShapeProto 
getAttrAsShape(mlir::Attribute& attr) { ::oneflow::ShapeProto shape{}; for (auto v : attr.dyn_cast().getValue()) { shape.add_dim(v.dyn_cast().getSInt()); } return shape; } ::oneflow::Int64ListProto getAttrAsStride(mlir::Attribute& attr) { ::oneflow::Int64ListProto stride{}; for (auto v : attr.dyn_cast().getValue()) { stride.add_dim(v.dyn_cast().getSInt()); } return stride; } ::oneflow::ParallelConf getParallelConfFromAttrDictionary(DictionaryAttr attributes) { ::oneflow::ParallelConf parallel_conf{}; auto device_tag = attributes.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()) .dyn_cast_or_null(); CHECK(device_tag) << "attr absent: " << OpTrait::IsOpConfCompatible::getDeviceTagAttr().str(); parallel_conf.set_device_tag(device_tag.str()); auto device_name = attributes.get(OpTrait::IsOpConfCompatible::getDeviceNameAttr()) .dyn_cast_or_null(); CHECK(device_name) << "attr absent: " << OpTrait::IsOpConfCompatible::getDeviceNameAttr().str(); for (auto s : device_name.getValue()) { parallel_conf.add_device_name(s.cast().str()); } if (auto hierarchy = attributes.get(OpTrait::IsOpConfCompatible::getHierarchyAttr()) .dyn_cast_or_null()) { for (auto dim : hierarchy.getValue()) { parallel_conf.mutable_hierarchy()->add_dim(dim.template dyn_cast().getInt()); } } return parallel_conf; } ::oneflow::ParallelConf getParallelConfFromAttrs(Attribute device_name_attr, Attribute device_tag_attr) { ::oneflow::ParallelConf parallel_conf{}; auto device_tag = device_tag_attr.dyn_cast_or_null(); CHECK(device_tag) << "attr absent: " << OpTrait::IsOpConfCompatible::getDeviceTagAttr().str(); parallel_conf.set_device_tag(device_tag.str()); auto device_name = device_name_attr.dyn_cast_or_null(); CHECK(device_name) << "attr absent: " << OpTrait::IsOpConfCompatible::getDeviceNameAttr().str(); for (auto s : device_name.getValue()) { parallel_conf.add_device_name(s.cast().str()); } return parallel_conf; } ::oneflow::DeviceType getDeviceTypeFromAttrDictionary(DictionaryAttr attributes) { ::oneflow::ParallelConf parallel_conf{}; auto device_tag = attributes.get(OpTrait::IsOpConfCompatible::getDeviceTagAttr()) .dyn_cast_or_null(); CHECK(device_tag) << "attr absent: " << OpTrait::IsOpConfCompatible::getDeviceTagAttr().str(); if (device_tag.str() == "cpu") { return ::oneflow::DeviceType::kCPU; } else if (device_tag.str() == "cuda") { return ::oneflow::DeviceType::kCUDA; } else if (device_tag.str() == "mlu") { return ::oneflow::DeviceType::kMLU; } else if (device_tag.str() == "npu") { return ::oneflow::DeviceType::kNPU; } else if (device_tag.str() == "xpu") { return ::oneflow::DeviceType::kXPU; } else { LOG(FATAL) << "unsupported device tag: " << device_tag.str(); return ::oneflow::DeviceType::kInvalidDevice; } } ::oneflow::AttrType queryAttrType(const std::string& op_type_name, const std::string& attr_name) { ::oneflow::user_op::UserOpDefWrapper op_def(support::getUserOpDef(op_type_name)); CHECK(op_def.IsAttrName(attr_name)) << attr_name << " not a attr name for op: " << op_type_name; return op_def.GetAttrType(attr_name); } } // namespace user_op } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/OneFlow/UserOpReflection.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
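A minimal sketch (not part of the sources) of calling the conversion entry point from UserOpConversion.cpp above to export a user op back into an ::oneflow::OperatorConf; ExportToOpConf is a hypothetical wrapper, not an existing API.

#include "OneFlow/UserOpConversion.h"

::oneflow::OperatorConf ExportToOpConf(mlir::Operation* op) {  // hypothetical helper
  ::oneflow::OperatorConf op_conf;
  if (mlir::failed(mlir::oneflow::user_op::ConvertUserOpAttributes(op, op_conf))) {
    op->emitError("fail to convert user op attributes");
  }
  return op_conf;
}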
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // this file should contains functions to get operands and results with user op name and index #include "OneFlow/UserOpReflection.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" namespace mlir { namespace oneflow { namespace user_op { template class Trait> const std::vector* GetFullKeys(UserOpCompatible& uc, Operation* op); template class Trait> std::vector GetFullKeys(UserOp op); template class Trait> std::vector GetFullKeys(::llvm::StringRef op_type_name); template<> const std::vector* GetFullKeys(UserOpCompatible& uc, Operation* op) { if (auto alternative_name = dyn_cast(op)) { return alternative_name.inputKeys(); } return uc.inputKeys(); } template<> const std::vector* GetFullKeys(UserOpCompatible& uc, Operation* op) { if (auto alternative_name = dyn_cast(op)) { return alternative_name.outputKeys(); } return uc.outputKeys(); } template<> std::vector GetFullKeys(UserOp op) { return mlir::oneflow::support::GetInputKeys(op.getOpTypeName().str()); } template<> std::vector GetFullKeys(UserOp op) { return mlir::oneflow::support::GetOutputKeys(op.getOpTypeName().str()); } template<> std::vector GetFullKeys( ::llvm::StringRef op_type_name) { return mlir::oneflow::support::GetInputKeys(op_type_name.str()); } template<> std::vector GetFullKeys( ::llvm::StringRef op_type_name) { return mlir::oneflow::support::GetOutputKeys(op_type_name.str()); } template class Trait> std::pair getODSIndexAndLength(UserOpCompatible& op, unsigned index); template<> std::pair getODSIndexAndLength( UserOpCompatible& op, unsigned index) { return op.getODSOperandIndexAndLength(index); } template<> std::pair getODSIndexAndLength( UserOpCompatible& op, unsigned index) { return op.getODSResultIndexAndLength(index); } template class Trait> StringRef GetSegmentSizeAttr(); template<> StringRef GetSegmentSizeAttr() { return OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(); } template<> StringRef GetSegmentSizeAttr() { return OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(); } template class Trait> int32_t GetSingleSegmentSize(Operation*); template<> int32_t GetSingleSegmentSize(Operation* op) { return op->getNumOperands(); } template<> int32_t GetSingleSegmentSize(Operation* op) { return op->getNumResults(); } template class Trait> ArrayAttr GetUserOpArgSizes(UserOp); template<> ArrayAttr GetUserOpArgSizes(UserOp op) { return op.getInputSizes(); } template<> ArrayAttr GetUserOpArgSizes(UserOp op) { return op.getOutputSizes(); } template class Trait> LogicalResult GetUserOpFilteredSegmentKeyAndSizes(UserOp op, std::vector& keys, std::vector& sizes) { auto full_keys = GetFullKeys(op); for (const auto& key_size_tuple : llvm::zip(full_keys, GetUserOpArgSizes(op).getValue())) { const std::string& key = std::get<0>(key_size_tuple); const int32_t size = std::get<1>(key_size_tuple).template cast().getValue().getSExtValue(); if (size > 0) { keys.push_back(key); sizes.push_back(size); } } return success(); } Source GetOpSourceByName(Operation* op, const std::string& to_find) { if (auto user_op = dyn_cast(op)) { auto found = [&](std::vector keys, bool find_in_results /*or in operands*/ = false) -> int { 
auto offset = 0; for (const auto& key : llvm::enumerate(keys)) { if (key.value() == to_find) { return offset; } offset += find_in_results ? user_op.getODSResultIndexAndLength(key.index()).second : user_op.getODSOperandIndexAndLength(key.index()).second; } return -1; }; if (auto alternative_name = dyn_cast(op)) { if (auto offset = found(*alternative_name.inputKeys()); offset != -1) { return {Source::INPUT, offset}; } if (auto offset = found(*alternative_name.outputKeys(), true); offset != -1) { return {Source::OUTPUT, offset}; } } if (to_find == "tmp_buffer") { return {Source::BUFFER, 0}; } if (auto offset = found(*user_op.inputKeys()); offset != -1) { return {Source::INPUT, offset}; } if (auto offset = found(*user_op.outputKeys(), true); offset != -1) { return {Source::OUTPUT, offset}; } op->emitError(to_find + " not found in this op"); return {Source::INVALID, -1}; } op->emitError("Not support op which is not user op"); return {Source::INVALID, -1}; } template class Trait> LogicalResult GetFilteredSegmentKeyAndSizes(Operation* op, std::vector& keys, std::vector& sizes) { if (auto user_op = dyn_cast(op)) { return GetUserOpFilteredSegmentKeyAndSizes(user_op, keys, sizes); } const std::vector* full_keys = nullptr; std::vector full_sizes{}; auto uc = dyn_cast(op); if (!uc) { op->emitError("interface UserOpCompatible not supported"); return failure(); } full_keys = GetFullKeys(uc, op); if (op->hasTrait()) { const StringRef attr_name = GetSegmentSizeAttr(); const DenseI32ArrayAttr& size_attr = op->getAttrOfType(attr_name); if (!size_attr) return failure(); auto segment_sizes = size_attr.asArrayRef(); if (full_keys->size() != segment_sizes.size()) { op->emitError() << "fail to convert op inputs, attr_name: " << attr_name << ", full_keys: " << full_keys->size() << ", segment_sizes: " << segment_sizes.size() << ", name: " << op->getName(); op->dump(); return failure(); }; full_sizes = {segment_sizes.begin(), segment_sizes.end()}; } else { if (full_keys->size() == 1) { full_sizes.push_back(GetSingleSegmentSize(op)); } else { for (const auto& key : llvm::enumerate(*full_keys)) { full_sizes.push_back(getODSIndexAndLength(uc, key.index()).second); } } } for (const auto& key_size_tuple : llvm::zip(*full_keys, full_sizes)) { const std::string& key = std::get<0>(key_size_tuple); const int32_t size = std::get<1>(key_size_tuple); if (size > 0) { keys.push_back(key); sizes.push_back(size); } } return success(); } template class Trait> LogicalResult GetFilteredSegmentKeyAndSizes(llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes, std::vector& keys, std::vector& sizes) { const std::vector full_keys = GetFullKeys(op_type_name); std::vector full_sizes{}; const StringRef attr_name = GetSegmentSizeAttr(); if (auto size_attr = attributes.get(attr_name).dyn_cast_or_null()) { if (!size_attr) return failure(); auto segment_sizes = size_attr.asArrayRef(); if (full_keys.size() != segment_sizes.size()) { LOG(FATAL) << "fail to convert op inputs, attr_name: " << attr_name.str() << ", full_keys: " << full_keys.size() << ", segment_sizes: " << segment_sizes.size(); return failure(); }; full_sizes = {segment_sizes.begin(), segment_sizes.end()}; } else { if (full_keys.size() == 1) { full_sizes.push_back(valueSize); } else { LOG(FATAL) << "set attr: " << attr_name.str(); } } for (const auto& key_size_tuple : llvm::zip(full_keys, full_sizes)) { const std::string& key = std::get<0>(key_size_tuple); const int32_t size = std::get<1>(key_size_tuple); if (size > 0) { keys.push_back(key); 
sizes.push_back(size); } } return success(); } template LogicalResult GetFilteredSegmentKeyAndSizes( Operation* op, std::vector& keys, std::vector& sizes); template LogicalResult GetFilteredSegmentKeyAndSizes( Operation* op, std::vector& keys, std::vector& sizes); template LogicalResult GetFilteredSegmentKeyAndSizes( llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes, std::vector& keys, std::vector& sizes); template LogicalResult GetFilteredSegmentKeyAndSizes( llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes, std::vector& keys, std::vector& sizes); template class Trait> ArgIds::ArgIds(Operation* op) { std::vector keys; std::vector sizes; if (failed(GetFilteredSegmentKeyAndSizes(op, keys, sizes))) { op->emitError("fail to get filtered segment key and sizes"); exit(1); } for (int i = 0; i < keys.size(); i += 1) { auto& key = keys[i]; for (size_t j = 0; j < sizes[i]; j += 1) { ArgID id{key, j}; ids_.push_back(id); } } } template class Trait> ArgIds::ArgIds(llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes) { std::vector keys{}; std::vector sizes{}; CHECK(user_op::GetFilteredSegmentKeyAndSizes(op_type_name, valueSize, attributes, keys, sizes) .succeeded()); for (int i = 0; i < keys.size(); i += 1) { auto& key = keys[i]; for (size_t j = 0; j < sizes[i]; j += 1) { ArgID id{key, j}; ids_.push_back(id); } } } template oneflow::user_op::ArgIds::ArgIds(Operation*); template oneflow::user_op::ArgIds::ArgIds(Operation*); template oneflow::user_op::ArgIds::ArgIds( llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes); template oneflow::user_op::ArgIds::ArgIds( llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes); llvm::Optional GetOutputLbn(OpResult result) { const auto def_op = result.getDefiningOp(); if (def_op->hasTrait()) { return def_op ->getAttrOfType( OpTrait::IsImportCompatible::getOutputLBNsAttr())[result.getResultNumber()] .dyn_cast() .getValue() .str(); } else { std::vector def_op_keys{}; std::vector def_op_sizes{}; if (failed(user_op::GetFilteredSegmentKeyAndSizes( def_op, def_op_keys, def_op_sizes))) { def_op->emitError("fail to get output lbn"); return llvm::None; } const auto result_number = result.getResultNumber(); uint32_t size_sum = 0; for (const auto& name_size_tuple : llvm::zip(def_op_keys, def_op_sizes)) { auto name = std::get<0>(name_size_tuple); auto size = std::get<1>(name_size_tuple); if ((size_sum + size) > result_number) { const uint32_t bn_i = result_number - size_sum; return OpTrait::IsOpConfCompatible::getOpName(def_op).str() + "/" + name + "_" + std::to_string(bn_i); } size_sum += size; } } return llvm::None; } } // namespace user_op } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/Transform/CMakeLists.txt ================================================ add_mlir_library( MLIROneFlowTransformDialect TransformDialectExtension.cpp TransformDialectInterpreter.cpp TransformStateExtension.cpp EXCLUDE_FROM_LIBMLIR DEPENDS MLIROneFlowTransformDialectExtensionIncGen LINK_LIBS PUBLIC MLIRIR MLIRPass MLIRPDLDialect MLIRTransformDialect MLIRTransformDialectTransforms) ================================================ FILE: oneflow/ir/lib/Transform/TransformDialectExtension.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
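As a stand-alone illustration of the naming scheme implemented above (the op and key names below are made up), filtered segment keys and sizes expand into (key, index) pairs, and GetOutputLbn formats an output logical blob name as "<op_name>/<key>_<index>":

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Filtered segment keys and their sizes, as produced by GetFilteredSegmentKeyAndSizes.
  std::vector<std::string> keys{"in", "scalar"};
  std::vector<int32_t> sizes{2, 1};

  // ArgIds expands them into (key, index) pairs: ("in", 0), ("in", 1), ("scalar", 0).
  std::vector<std::pair<std::string, size_t>> ids;
  for (size_t i = 0; i < keys.size(); ++i) {
    for (size_t j = 0; j < static_cast<size_t>(sizes[i]); ++j) { ids.emplace_back(keys[i], j); }
  }

  // GetOutputLbn-style formatting: "<op_name>/<key>_<index>".
  const std::string op_name = "my_add_op";
  for (const auto& id : ids) { std::cout << op_name << "/" << id.first << "_" << id.second << "\n"; }
  return 0;
}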
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/Transform/OneFlowMemPool.h" #include "OneFlow/OneFlowPDLLPatterns.h" #include "Transform/TransformDialectExtension.h" #include "Transform/TransformStateExtension.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" using namespace mlir; using namespace mlir::oneflow; using namespace mlir::transform; namespace { struct MemrefCopyOpFoldPatterns final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::CopyOp op, PatternRewriter& rewriter) const override { if (op.getSource() == op.getTarget()) rewriter.eraseOp(op); return success(); } }; } // namespace DiagnosedSilenceableFailure transform_dialect::EliminateCopyOp::applyToOne( Operation* target, transform::ApplyToEachResultList& results, transform::TransformState& state) { MLIRContext* ctx = target->getContext(); RewritePatternSet patterns(ctx); patterns.add(patterns.getContext()); mlir::oneflow::populateAllocEliminationPatterns(patterns); SmallVector ops; GreedyRewriteConfig config; target->walk([&](Operation* nestedOp) { if (target != nestedOp) ops.push_back(nestedOp); }); LogicalResult result = applyOpPatternsAndFold(ops, std::move(patterns), config); if (failed(result)) { return DiagnosedSilenceableFailure::definiteFailure(); } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform_dialect::ExplicitLinalgOutcomeOp::applyToOne( Operation* target, transform::ApplyToEachResultList& results, transform::TransformState& state) { MLIRContext* ctx = target->getContext(); RewritePatternSet patterns(ctx); linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns); SmallVector ops; GreedyRewriteConfig config; target->walk([&](Operation* nestedOp) { if (target != nestedOp) ops.push_back(nestedOp); }); LogicalResult result = applyOpPatternsAndFold(ops, std::move(patterns), config); if (failed(result)) { return DiagnosedSilenceableFailure::definiteFailure(); } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform_dialect::CanonicalizationOp::applyToOne( Operation* target, transform::ApplyToEachResultList& results, transform::TransformState& state) { MLIRContext* ctx = target->getContext(); RewritePatternSet patterns(ctx); for (Dialect* dialect : ctx->getLoadedDialects()) dialect->getCanonicalizationPatterns(patterns); for (RegisteredOperationName op : ctx->getRegisteredOperations()) op.getCanonicalizationPatterns(patterns, ctx); SmallVector ops; 
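// The statements that follow collect every op nested under `target` (excluding `target`
// itself) and hand them to applyOpPatternsAndFold, which greedily applies the gathered
// canonicalization patterns until no more of them match.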
GreedyRewriteConfig config; target->walk([&](Operation* nestedOp) { if (target != nestedOp) ops.push_back(nestedOp); }); LogicalResult result = applyOpPatternsAndFold(ops, std::move(patterns), config); if (failed(result)) { return DiagnosedSilenceableFailure::definiteFailure(); } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform_dialect::FoldAllocOp::applyToOne( Operation* target, transform::ApplyToEachResultList& results, transform::TransformState& state) { if (auto func = llvm::dyn_cast(target)) { applyFoldAlloc(func); } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform_dialect::ResultsToOutParamsOp::applyToOne( Operation* target, transform::ApplyToEachResultList& results, transform::TransformState& state) { if (auto module = llvm::dyn_cast(target)) { if (failed(bufferization::promoteBufferResultsToOutParams(module, {}))) { return DiagnosedSilenceableFailure::definiteFailure(); } } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform_dialect::CSEOp::applyToOne(Operation* target, ApplyToEachResultList& results, transform::TransformState& state) { auto context = target->getContext(); mlir::PassManager pm(context); pm.addPass(createCSEPass()); if (failed(pm.run(target))) return mlir::emitDefiniteFailure(target, "greedy patterns failed"); return DiagnosedSilenceableFailure::success(); } namespace { class OneFlowTransformDialectExtension : public transform::TransformDialectExtension { public: using Base::Base; void init() { declareDependentDialect(); registerTransformOps< #define GET_OP_LIST #include "Transform/TransformDialectExtension.cpp.inc" >(); registerTypes< #define GET_TYPEDEF_LIST #include "Transform/TransformDialectExtensionTypes.cpp.inc" >(); } }; } // namespace // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. LLVM_ATTRIBUTE_UNUSED static OptionalParseResult generatedTypeParser(AsmParser& parser, StringRef* mnemonic, Type& value); LLVM_ATTRIBUTE_UNUSED static LogicalResult generatedTypePrinter(Type def, AsmPrinter& printer); #define GET_TYPEDEF_CLASSES #include "Transform/TransformDialectExtensionTypes.cpp.inc" #define GET_OP_CLASSES #include "Transform/TransformDialectExtension.cpp.inc" void mlir::oneflow::transform_dialect::registerTransformDialectExtension( DialectRegistry& registry) { registry.addExtensions(); } ================================================ FILE: oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
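For comparison with MemrefCopyOpFoldPatterns above, whose template argument is lost in this extract, a self-contained pattern of the same shape might look like the sketch below (the struct name is illustrative, not from the repository); it erases memref.copy ops whose source and target are the same buffer:

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct EraseSelfCopyPattern final : public mlir::OpRewritePattern<mlir::memref::CopyOp> {
  using OpRewritePattern::OpRewritePattern;
  mlir::LogicalResult matchAndRewrite(mlir::memref::CopyOp op,
                                      mlir::PatternRewriter& rewriter) const override {
    // Only fire on self-copies; everything else is left to other patterns.
    if (op.getSource() != op.getTarget()) { return mlir::failure(); }
    rewriter.eraseOp(op);
    return mlir::success();
  }
};
}  // namespace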
*/ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; namespace { /// Simple pass that applies transform dialect ops directly contained in a /// module. template class OpPassWrapper : public PassWrapper> {}; class TransformDialectInterpreterPass : public transform::TransformInterpreterPassBase { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformDialectInterpreterPass) TransformDialectInterpreterPass() = default; TransformDialectInterpreterPass(const TransformDialectInterpreterPass& pass) : TransformInterpreterPassBase(pass) {} StringRef getArgument() const override { return "oneflow-transform-dialect-interpreter"; } StringRef getDescription() const override { return "apply transform dialect operations one by one"; } void findOperationsByName(Operation* root, StringRef name, SmallVectorImpl& operations) { root->walk([&](Operation* op) { if (op->getName().getStringRef() == name) { operations.push_back(op); } }); } void createParameterMapping(MLIRContext& context, ArrayRef values, RaggedArray& result) { SmallVector storage = llvm::to_vector(llvm::map_range(values, [&](int v) { Builder b(&context); return transform::MappedValue(b.getI64IntegerAttr(v)); })); result.push_back(std::move(storage)); } void createOpResultMapping(Operation* root, StringRef name, RaggedArray& extraMapping) { SmallVector operations; findOperationsByName(root, name, operations); SmallVector results; for (Operation* op : operations) llvm::append_range(results, op->getResults()); extraMapping.push_back(results); } unsigned numberOfSetOptions(const Option& ops, const ListOption& params, const Option& values) { unsigned numSetValues = 0; numSetValues += !ops.empty(); numSetValues += !params.empty(); numSetValues += !values.empty(); return numSetValues; } void runOnOperation() override { unsigned firstSetOptions = numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, bindFirstExtraToResultsOfOps); unsigned secondSetOptions = numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams, bindSecondExtraToResultsOfOps); auto loc = UnknownLoc::get(&getContext()); if (firstSetOptions > 1) { emitError(loc) << "cannot bind the first extra top-level argument to " "multiple entities"; return signalPassFailure(); } if (secondSetOptions > 1) { emitError(loc) << "cannot bind the second extra top-level argument to " "multiple entities"; return signalPassFailure(); } if (firstSetOptions == 0 && secondSetOptions != 0) { emitError(loc) << "cannot bind the second extra top-level argument " "without bindings the first"; } RaggedArray extraMapping; if (!bindFirstExtraToOps.empty()) { SmallVector operations; findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), operations); extraMapping.push_back(operations); } else if (!bindFirstExtraToParams.empty()) { createParameterMapping(getContext(), bindFirstExtraToParams, extraMapping); } else if (!bindFirstExtraToResultsOfOps.empty()) { createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, extraMapping); } if (!bindSecondExtraToOps.empty()) { SmallVector operations; findOperationsByName(getOperation(), bindSecondExtraToOps, operations); extraMapping.push_back(operations); } else if (!bindSecondExtraToParams.empty()) { createParameterMapping(getContext(), bindSecondExtraToParams, extraMapping); } else if (!bindSecondExtraToResultsOfOps.empty()) { 
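// This branch binds the second extra top-level argument to the results of every payload
// op whose name matches bind-second-extra-to-results-of-ops, mirroring the handling of
// the first extra argument above.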
createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, extraMapping); } options = options.enableExpensiveChecks(enableExpensiveChecks); if (failed(transform::detail::interpreterBaseRunOnOperationImpl( getOperation(), getArgument(), getSharedTransformModule(), getTransformLibraryModule(), extraMapping, options, transformFileName, transformLibraryFileName, debugPayloadRootTag, debugTransformRootTag, getBinaryName()))) return signalPassFailure(); } Option enableExpensiveChecks{ *this, "enable-expensive-checks", llvm::cl::init(false), llvm::cl::desc("perform expensive checks to better report errors in the " "transform IR")}; Option bindFirstExtraToOps{ *this, "bind-first-extra-to-ops", llvm::cl::desc("bind the first extra argument of the top-level op to " "payload operations of the given kind")}; ListOption bindFirstExtraToParams{ *this, "bind-first-extra-to-params", llvm::cl::desc("bind the first extra argument of the top-level op to " "the given integer parameters")}; Option bindFirstExtraToResultsOfOps{ *this, "bind-first-extra-to-results-of-ops", llvm::cl::desc("bind the first extra argument of the top-level op to " "results of payload operations of the given kind")}; Option bindSecondExtraToOps{ *this, "bind-second-extra-to-ops", llvm::cl::desc("bind the second extra argument of the top-level op to " "payload operations of the given kind")}; ListOption bindSecondExtraToParams{ *this, "bind-second-extra-to-params", llvm::cl::desc("bind the second extra argument of the top-level op to " "the given integer parameters")}; Option bindSecondExtraToResultsOfOps{ *this, "bind-second-extra-to-results-of-ops", llvm::cl::desc("bind the second extra argument of the top-level op to " "results of payload operations of the given kind")}; Option transformFileName{ *this, "transform-file-name", llvm::cl::init(""), llvm::cl::desc("Optional filename containing a transform dialect specification to " "apply. If left empty, the IR is assumed to contain one top-level " "transform dialect operation somewhere in the module.")}; Option debugPayloadRootTag{ *this, "debug-payload-root-tag", llvm::cl::init(""), llvm::cl::desc("Select the operation with 'transform.target_tag' attribute having " "the given value as payload IR root. If empty select the pass anchor " "operation as the payload IR root.")}; Option debugTransformRootTag{ *this, "debug-transform-root-tag", llvm::cl::init(""), llvm::cl::desc("Select the operation with 'transform.target_tag' attribute having " "the given value as container IR for top-level transform ops. This " "allows user control on what transformation to apply. 
If empty, " "select the container of the top-level transform op.")}; Option transformLibraryFileName{ *this, "transform-library-file-name", llvm::cl::init(""), llvm::cl::desc("Optional name of the file containing transform dialect symbol " "definitions to be injected into the transform module.")}; }; struct TransformDialectEraseSchedulePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformDialectEraseSchedulePass) StringRef getArgument() const final { return "oneflow-transform-dialect-erase-schedule"; } StringRef getDescription() const final { return "erase transform dialect schedule from the IR"; } void runOnOperation() override { getOperation()->walk([&](Operation* nestedOp) { if (isa(nestedOp)) { nestedOp->erase(); return WalkResult::skip(); } return WalkResult::advance(); }); } }; } // namespace namespace mlir { namespace oneflow { namespace transform_dialect { /// Registers the test pass for erasing transform dialect ops. void registerTransformDialectEraseSchedulePass() { PassRegistration reg; } /// Registers the test pass for applying transform dialect ops. void registerTransformDialectInterpreterPass() { PassRegistration reg; } } // namespace transform_dialect } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/lib/Transform/TransformStateExtension.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "Transform/TransformStateExtension.h" using namespace mlir; LogicalResult mlir::oneflow::transform_dialect::TransformStateExtension::updateMapping( Operation* previous, Operation* updated) { // Update value handles. The new ops should have at least as many results as // the replacement op. Fewer results are acceptable, if those results are not // mapped to any handle. for (auto r = updated->getNumResults(); r < previous->getNumResults(); ++r) { SmallVector handles; (void)getTransformState().getHandlesForPayloadValue(previous->getResult(r), handles); if (!handles.empty()) return emitError(previous->getLoc()) << "cannot replace an op with another op producing fewer results " "while tracking handles"; } for (auto [oldValue, newValue] : llvm::zip(previous->getResults(), updated->getResults())) if (failed(replacePayloadValue(oldValue, newValue))) return failure(); // Update op handle. 
return replacePayloadOp(previous, updated); } ================================================ FILE: oneflow/ir/llvm-in-tree.cmake ================================================ include(FetchContent) message("-- LLVM_MONO_REPO_URL: " ${LLVM_MONO_REPO_URL}) message("-- LLVM_MONO_REPO_MD5: " ${LLVM_MONO_REPO_MD5}) FetchContent_Declare(llvm_monorepo) FetchContent_GetProperties(llvm_monorepo) set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm) if(NOT llvm_monorepo_POPULATED) FetchContent_Populate(llvm_monorepo URL ${LLVM_MONO_REPO_URL} URL_HASH MD5=${LLVM_MONO_REPO_MD5}) endif() set(CMAKE_INSTALL_PREFIX ${LLVM_INSTALL_DIR} CACHE STRING "" FORCE) set(LLVM_ENABLE_RTTI ON CACHE BOOL "turn this on to make it compatible with protobuf") set(LLVM_ENABLE_EH ON CACHE BOOL "turn this on to make it compatible with half (the library)") set(LLVM_ENABLE_TERMINFO OFF CACHE BOOL "disable terminfo in llvm so that oneflow doesn't need to link against it") set(LLVM_BUILD_EXAMPLES OFF CACHE BOOL "") set(LLVM_BUILD_TOOLS OFF CACHE BOOL "") set(LLVM_INCLUDE_EXAMPLES OFF CACHE BOOL "") set(LLVM_INCLUDE_TESTS OFF CACHE BOOL "" FORCE) set(MLIR_INCLUDE_TESTS OFF CACHE BOOL "" FORCE) set(LLVM_INCLUDE_BENCHMARKS OFF CACHE BOOL "") set(LLVM_TARGETS_TO_BUILD host;NVPTX CACHE STRING "") set(LLVM_ENABLE_ASSERTIONS ON CACHE BOOL "") set(LLVM_ENABLE_PROJECTS mlir CACHE STRING "") set(LLVM_APPEND_VC_REV OFF CACHE BOOL "") set(LLVM_ENABLE_ZLIB OFF CACHE BOOL "") set(LLVM_INSTALL_UTILS ON CACHE BOOL "") set(LLVM_ENABLE_OCAMLDOC OFF CACHE BOOL "") set(LLVM_ENABLE_BINDINGS OFF CACHE BOOL "") set(LLVM_OPTIMIZED_TABLEGEN ON CACHE BOOL "" FORCE) set(MLIR_ENABLE_CUDA_RUNNER ${WITH_MLIR_CUDA_CODEGEN} CACHE BOOL "" FORCE) set(LLVM_MAIN_SRC_DIR ${llvm_monorepo_SOURCE_DIR}/llvm) set(LLVM_BINARY_DIR ${llvm_monorepo_BINARY_DIR}) set(LLVM_TOOLS_BINARY_DIR ${llvm_monorepo_BINARY_DIR}/bin CACHE STRING "" FORCE) set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include) set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include) set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") set(llvm_monorepo_BINARY_DIR ${llvm_monorepo_BINARY_DIR}) install(TARGETS oneflow of_protoobj of_functional_obj EXPORT oneflow DESTINATION lib) install(EXPORT oneflow DESTINATION lib/oneflow) add_subdirectory(${llvm_monorepo_SOURCE_DIR}/llvm ${llvm_monorepo_BINARY_DIR}) set(LLVM_INCLUDE_DIRS ${LLVM_MAIN_SRC_DIR}/include;${llvm_monorepo_BINARY_DIR}/include) set(LLVM_EXTERNAL_LIT "${llvm_monorepo_BINARY_DIR}/bin/llvm-lit" CACHE STRING "" FORCE) set(LTDL_SHLIB_EXT ${CMAKE_SHARED_LIBRARY_SUFFIX}) set(LLVM_LIBRARY_DIR "${llvm_monorepo_BINARY_DIR}/lib") ================================================ FILE: oneflow/ir/oneflow-extension/CMakeLists.txt ================================================ include_directories(${PROJECT_SOURCE_DIR}/oneflow-extension/include) oneflow_add_mlir_library( MLIROneFlowExtension mlir_jit_op.cpp mlir_jit_op_kernel.cpp ir_pass.cpp lr_jit.cpp mlir_gen.cpp DEPENDS LINK_LIBS PUBLIC MLIRIR ${dialect_libs} ${translation_libs} MLIRIR MLIRParser MLIRPass MLIRSPIRVDialect MLIRTranslateLib MLIRSupport MLIROneFlow oneflow MLIRExecutionEngine MLIROneFlowTranslation MLIROneFlowRuntime) mlir_check_all_link_libraries(MLIROneFlowExtension) add_custom_target(mex DEPENDS MLIROneFlowExtension) ================================================ FILE: oneflow/ir/oneflow-extension/README.md ================================================ # OneFlow extension of MLIR features] ## 
KernelLaunchOp ### Stage 1 - 1:1 conversion from user op to kernel launch op ### Stage 2 - multi user op merged into one single kernel launch op ### Stage 3 - oneflow-opt and similar non-python execution environment - multi-gpu/multi-node compilation support (in the beginning it is all single-node with broadcast SBP signature) ### relationship with MlirJitOp - the graph of a MlirJitOp might contain one or multiple kernel launch op - an op inside the graph of MlirJitOp could be optionally lowered to a kernel launch op ================================================ FILE: oneflow/ir/oneflow-extension/include/CMakeLists.txt ================================================ add_subdirectory(OneFlow) ================================================ FILE: oneflow/ir/oneflow-extension/include/OneFlow/CMakeLists.txt ================================================ ================================================ FILE: oneflow/ir/oneflow-extension/include/OneFlow/JITOpInfer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_ #define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_ #include "oneflow/core/framework/infer_util.h" namespace oneflow { namespace ir { namespace jit { Maybe InferTensorDesc(user_op::InferContext* ctx); Maybe SetTensorDataType(user_op::InferContext* ctx); } // namespace jit } // namespace ir } // namespace oneflow #endif // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_ ================================================ FILE: oneflow/ir/oneflow-extension/include/OneFlow/OneFlowLRJITRegistry.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ONEFLOW_LRJIT_REGISTRY_H_ #define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ONEFLOW_LRJIT_REGISTRY_H_ #include "oneflow/core/common/just.h" #include "oneflow/core/common/singleton.h" #include "oneflow/core/common/util.h" #include "oneflow/ir/oneflow-extension/include/PyAst/Ast.h" #include #include #include #include #include #include namespace mlir { class ExecutionEngine; } typedef std::pair, std::function> LRJITRegistry_Store_; class LRJITRegistry final { public: OF_DISALLOW_COPY_AND_MOVE(LRJITRegistry); ~LRJITRegistry() = default; void Register(const std::string& function_id, pyast::FunctionDef& ast, bool is_dump); std::function LookUp(const std::string& function_id); private: friend class oneflow::Singleton; LRJITRegistry() = default; std::unordered_map functionId2engine_; }; #endif // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ONEFLOW_LRJIT_REGISTRY_H_ ================================================ FILE: oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ROUNDTRIP_H_ #define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ROUNDTRIP_H_ #include "oneflow/core/job_rewriter/job_pass.h" namespace oneflow { enum IRPassType : int32_t { kBeforeAD = 0, kAfterAD = 1 }; template class IRRoundTrip final : public JobPass { public: IRRoundTrip() = default; ~IRRoundTrip() override = default; bool IsEnabled(const JobPassCtx& ctx) const; Maybe Apply(Job* job, JobPassCtx* ctx) const override; }; } // namespace oneflow #endif // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ROUNDTRIP_H_ ================================================ FILE: oneflow/ir/oneflow-extension/include/PyAst/Ast.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
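A hedged usage sketch for LRJITRegistry above: the std::function signature is stripped in this extract, so float(double, double) is inferred from the engine-invoking lambda in lr_jit.cpp further below, and the function id, AST variable, and Singleton accessor are illustrative.

// Not from the repository; illustrates the Register/LookUp flow only.
void ExampleUseOfLRJITRegistry(pyast::FunctionDef& lr_ast) {
  auto* registry = oneflow::Singleton<LRJITRegistry>::Get();  // assumes the singleton was constructed elsewhere
  registry->Register("my_lr_schedule", lr_ast, /*is_dump=*/false);
  // LookUp returns the JIT-compiled callable, or nullptr for an unknown id.
  auto get_lr = registry->LookUp("my_lr_schedule");
  if (get_lr) {
    float lr = get_lr(/*base_lr=*/0.1, /*step=*/100);
    (void)lr;
  }
}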
*/ #ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_H_ #define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_H_ #include #include #include #include namespace pyast { using namespace std; typedef string identifier; class arg { identifier id; public: explicit arg(const identifier& arg) : id(arg) {} identifier get_arg() { return id; } static shared_ptr arg_(const identifier& arg_) { return make_shared(arg_); } }; class arguments { vector> args; public: explicit arguments(vector> args) : args(std::move(args)) {} vector> get_args() { return args; } static shared_ptr arguments_(vector> args) { return make_shared(args); } }; class stmt { public: enum StmtKind { kFunctionDef, kReturn, kAssign, kIf, kRaise, kAssert, kExpr, }; explicit stmt(StmtKind kind) : kind(kind) {} virtual ~stmt() = default; StmtKind get_kind() const { return kind; } private: StmtKind kind; }; class expr { public: enum ExprKind { kBoolOp, kBinOp, kLambda, kCompare, kCall, kNum, kConstant, kAttribute, kName, }; explicit expr(ExprKind kind) : kind(kind) {} virtual ~expr() = default; ExprKind get_kind() const { return kind; } private: ExprKind kind; }; class FunctionDef : public stmt { identifier name; shared_ptr args; vector> body; public: FunctionDef(identifier name, shared_ptr args, vector> body) : stmt(kFunctionDef), name(std::move(name)), args(std::move(args)), body(std::move(body)) {} static shared_ptr FunctionDef_(identifier name, shared_ptr args, vector> body) { return make_shared(name, args, body); } identifier get_name() { return name; } shared_ptr get_args() { return args; } vector> get_body() { return body; } static bool classof(const stmt* c) { return c->get_kind() == kFunctionDef; } }; class Return : public stmt { shared_ptr value; public: explicit Return(shared_ptr value) : stmt(kReturn), value(std::move(value)) {} static shared_ptr Return_(shared_ptr value) { return make_shared(value); } shared_ptr get_value() { return value; } static bool classof(const stmt* c) { return c->get_kind() == kReturn; } }; class Assign : public stmt { vector> targets; shared_ptr value; public: Assign(vector> targets, shared_ptr value) : stmt(kAssign), targets(std::move(targets)), value(std::move(value)) {} static shared_ptr Assign_(vector> targets, shared_ptr value) { return make_shared(targets, value); } shared_ptr get_value() { return value; } vector> get_targets() { return targets; } static bool classof(const stmt* c) { return c->get_kind() == kAssign; } }; class If : public stmt { shared_ptr test; vector> body; vector> orelse; public: If(shared_ptr test, vector> body, vector> orelse) : stmt(kIf), test(std::move(test)), body(std::move(body)), orelse(orelse) {} static shared_ptr If_(shared_ptr test, vector> body, vector> orelse) { return make_shared(test, body, orelse); } shared_ptr get_test() { return test; } vector> get_body() { return body; } vector> get_orelse() { return orelse; } static bool classof(const stmt* c) { return c->get_kind() == kIf; } }; class Raise : public stmt { shared_ptr exc; shared_ptr cause; public: Raise(shared_ptr exc, shared_ptr cause) : stmt(kRaise), exc(std::move(exc)), cause(std::move(cause)) {} static shared_ptr Raise_(shared_ptr exc, shared_ptr cause) { return make_shared(exc, cause); } shared_ptr get_exc() { return exc; } shared_ptr get_cause() { return cause; } static bool classof(const stmt* c) { return c->get_kind() == kRaise; } }; class Assert : public stmt { shared_ptr test; shared_ptr msg; public: Assert(shared_ptr test, shared_ptr msg) : stmt(kAssert), test(std::move(test)), 
msg(std::move(msg)) {} static shared_ptr Assert_(shared_ptr test, shared_ptr msg) { return make_shared(test, msg); } shared_ptr get_test() { return test; } shared_ptr get_msg() { return msg; } static bool classof(const stmt* c) { return c->get_kind() == kAssert; } }; class Expr : public stmt { shared_ptr value; public: explicit Expr(shared_ptr value) : stmt(kExpr), value(std::move(value)) {} static shared_ptr Expr_(shared_ptr value) { return make_shared(value); } shared_ptr get_value() { return value; } static bool classof(const stmt* c) { return c->get_kind() == kExpr; } }; class BoolOp : public expr { public: enum boolop_t { kAnd = 1, kOr, }; BoolOp(boolop_t op, vector> values) : expr(kBoolOp), op(op), values(std::move(values)) {} static shared_ptr BoolOp_(boolop_t op, vector> values) { return make_shared(op, values); } boolop_t get_op() { return op; } vector> get_values() { return values; } static bool classof(const expr* c) { return c->get_kind() == kBoolOp; } private: boolop_t op; vector> values; }; class BinOp : public expr { public: enum operator_t { kAdd = 1, kSub, kMult, kDiv, kPow, }; BinOp(shared_ptr left, operator_t op, shared_ptr right) : expr(kBinOp), left(std::move(left)), right(std::move(right)), op(std::move(op)) {} BinOp(shared_ptr left, int op, shared_ptr right) : expr(kBinOp), left(std::move(left)), right(std::move(right)), op(int2op(op)) {} static shared_ptr BinOp_(shared_ptr left, int op, shared_ptr right) { return make_shared(left, op, right); } static operator_t int2op(int op) { return operator_t(op); } operator_t get_op() { return op; } shared_ptr get_left() { return left; } shared_ptr get_right() { return right; } static bool classof(const expr* c) { return c->get_kind() == kBinOp; } private: shared_ptr left; shared_ptr right; operator_t op; }; class Lambda : public expr { shared_ptr args; shared_ptr body; public: Lambda(shared_ptr args, shared_ptr body) : expr(kLambda), args(std::move(args)), body(std::move(body)) {} static shared_ptr Lambda_(shared_ptr args, shared_ptr body) { return make_shared(args, body); } shared_ptr get_args() { return args; } shared_ptr get_body() { return body; } static bool classof(const expr* c) { return c->get_kind() == kLambda; } }; class Compare : public expr { public: enum cmpop_t { kEq = 1, kNotEq, kLt, kLtE, kGt, kGtE, }; Compare(shared_ptr left, vector ops, vector> comparators) : expr(kCompare), left(std::move(left)), ops(std::move(ops)), comparators(std::move(comparators)) {} Compare(shared_ptr left, const vector& ops, vector> comparators) : expr(kCompare), left(std::move(left)), ops(int2op(ops)), comparators(std::move(comparators)) {} static shared_ptr Compare_(shared_ptr left, vector ops, vector> comparators) { return make_shared(left, ops, comparators); } static vector int2op(const vector& op) { vector res; for (auto i : op) res.emplace_back(cmpop_t(i)); return res; } vector get_ops() { return ops; } shared_ptr get_left() { return left; } vector> get_comparators() { return comparators; } static bool classof(const expr* c) { return c->get_kind() == kCompare; } private: shared_ptr left; vector ops; vector> comparators; }; class Call : public expr { shared_ptr func; vector> args; public: Call(shared_ptr func, vector> args) : expr(kCall), func(std::move(func)), args(std::move(args)) {} static shared_ptr Call_(shared_ptr func, vector> args) { return make_shared(func, args); } shared_ptr get_func() { return func; } vector> get_args() { return args; } static bool classof(const expr* c) { return c->get_kind() == kCall; } }; class Num 
: public expr { double value; public: explicit Num(double value) : expr(kNum), value(value) {} static shared_ptr Num_(double value) { return make_shared(value); } double get_value() { return value; } static bool classof(const expr* c) { return c->get_kind() == kNum; } }; class Constant : public expr { double value; public: explicit Constant(double value) : expr(kConstant), value(value) {} static shared_ptr Constant_(double value) { return make_shared(value); } double get_value() { return value; } static bool classof(const expr* c) { return c->get_kind() == kConstant; } }; class Attribute : public expr { shared_ptr value; identifier attr; public: Attribute(shared_ptr value, const identifier& attr) : expr(kAttribute), value(std::move(value)), attr(attr) {} static shared_ptr Attribute_(shared_ptr value, const identifier& attr) { return make_shared(value, attr); } shared_ptr get_value() { return value; } identifier get_attr() { return attr; } static bool classof(const expr* c) { return c->get_kind() == kAttribute; } }; class Name : public expr { identifier id; public: explicit Name(const identifier& id) : expr(kName), id(id) {} static shared_ptr Name_(const identifier& id) { return make_shared(id); } identifier get_id() { return id; } static bool classof(const expr* c) { return c->get_kind() == kName; } }; } // namespace pyast #endif // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_H_ ================================================ FILE: oneflow/ir/oneflow-extension/include/PyAst/AstMlirGen.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
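The template arguments of these AST factory helpers are stripped in this extract, so a short reconstruction may help: each node exposes a static maker (arg_, arguments_, Name_, Constant_, BinOp_, Return_, FunctionDef_, ...) returning a shared_ptr to the node. The sketch below is illustrative, not from the repository, and builds the AST of a Python function equivalent to def get_lr(base_lr, step): return base_lr * 0.95:

#include <memory>
#include <vector>
#include "PyAst/Ast.h"

std::shared_ptr<pyast::FunctionDef> BuildExampleLRFunction() {
  using namespace pyast;
  // def get_lr(base_lr, step):
  auto params = arguments::arguments_({arg::arg_("base_lr"), arg::arg_("step")});
  //     return base_lr * 0.95
  auto scaled = BinOp::BinOp_(Name::Name_("base_lr"), BinOp::kMult, Constant::Constant_(0.95));
  std::vector<std::shared_ptr<stmt>> body{Return::Return_(scaled)};
  return FunctionDef::FunctionDef_("get_lr", params, body);
}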
*/ #ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_MLIR_GEN_H_ #define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_MLIR_GEN_H_ #include "OneFlow/OneFlowLRJITRegistry.h" #include "PyAst/Ast.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Builders.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/TargetSelect.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include #include #include #include class BuilderWithSymbolTable { protected: mlir::OpBuilder builder_; mlir::ModuleOp theModule_; std::map symbolTable_; mlir::Block* symbolTableForDeclareBlock_{}; explicit BuilderWithSymbolTable(mlir::MLIRContext& context) : builder_(&context) {} virtual ~BuilderWithSymbolTable() = default; mlir::LogicalResult Declare(const std::string& var, mlir::Value value); mlir::Value LoopUp(const std::string& var); mlir::Location Loc(const std::string& file_name = "unknown", int line = 0, int col = 0); void Dump(); }; class MLIRGenImpl : public BuilderWithSymbolTable { public: explicit MLIRGenImpl(mlir::MLIRContext& context) : BuilderWithSymbolTable(context) {} mlir::ModuleOp GenModule(pyast::FunctionDef* func); mlir::Value MlirGen(pyast::Compare* expr); mlir::Value MlirGen(pyast::BinOp* expr); mlir::Value MlirGen(pyast::Call* expr); mlir::Value MlirGen(pyast::Constant* expr); mlir::Value MlirGen(pyast::Name* expr); mlir::Value MlirGen(pyast::expr* expr); void MlirGen(pyast::If* stmt); void MlirGen(pyast::Assign* stmt); void MlirGen(pyast::Return* stmt); void MlirGen(pyast::stmt* stmt); }; #endif // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_MLIR_GEN_H_ ================================================ FILE: oneflow/ir/oneflow-extension/ir_pass.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
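A hedged sketch of how the generator declared above is driven; it mirrors GenModule/GenFunc in mlir_gen.cpp and lr_jit.cpp further below, with the dialect-loading step elided and the function name chosen for illustration:

#include "PyAst/AstMlirGen.h"

mlir::OwningOpRef<mlir::ModuleOp> LowerAstToModule(mlir::MLIRContext& context,
                                                   pyast::FunctionDef& ast) {
  // The real pipeline first loads the dialects it needs (func, arith, math, control flow,
  // scf, memref, ...) on the context; see GenFunc in lr_jit.cpp below.
  MLIRGenImpl mlir_gen(context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir_gen.GenModule(&ast);
  return module;
}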
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/graph/op_graph.h" #include "OneFlow/OneFlowRoundTrip.h" #include "oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h" #include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/job/job_ir.h" #include "oneflow/core/common/env_var/debug_mode.h" namespace oneflow { namespace { template std::string IRPassTypeName(); template<> std::string IRPassTypeName() { return "before_ad"; } template<> std::string IRPassTypeName() { return "after_ad"; } template bool IsLastIRPassForIRPassType(); template<> bool IsLastIRPassForIRPassType() { return false; } template<> bool IsLastIRPassForIRPassType() { return true; } template class RoundTripOneFlowJobWrapper : public mlir::oneflow::RoundTripOneFlowJobWrapperInterface { public: explicit RoundTripOneFlowJobWrapper(::oneflow::Job* job) : job_(job), op_graph_(*job), job_builder_(job), is_updated_(false) {} const Job* job() const override { return job_; } bool IsLastIRPass() const override { return IsLastIRPassForIRPassType(); } void UpdateJob(::oneflow::Job* new_job) override { CHECK(is_updated_ == false); job_->Swap(new_job); is_updated_ = true; } void DumpLog(const std::string& filename, const std::string& content) override { if (IsInDebugMode()) { TeePersistentLogStream::Create(JoinPath(LogDir(), filename))->Write(content); } } const ::oneflow::ParallelConf& ParallelConf4OpName(const std::string& op_name) const override { return job_builder_.ParallelConf4OpName(op_name).GetOrThrow(); } const ::oneflow::OperatorConf& OpConf4OpName(const std::string& op_name) const override { return job_builder_.OpConf4OpName(op_name).GetOrThrow(); } std::pair, std::vector> InputBns4OpName( const std::string& op_name) const override { auto node = op_graph_.OpNode4OpName(op_name); std::vector input_bns{}; std::vector input_lbns{}; for (auto e : node->in_edges()) { for (const auto& lbi_ibn_pair : e->lbi2ibns()) { for (const auto& ibn : lbi_ibn_pair.second) { input_bns.push_back(ibn); input_lbns.push_back(GenLogicalBlobName(lbi_ibn_pair.first)); } } } return std::make_pair(input_bns, input_lbns); } std::vector OutputLbns4OpName(const std::string& op_name) const override { std::unordered_set ret{}; auto node = op_graph_.OpNode4OpName(op_name); for (auto e : node->out_edges()) { for (const auto& lbi : e->lbis()) { ret.insert(GenLogicalBlobName(lbi)); } } return {ret.begin(), ret.end()}; } std::string ReplaceInputLbnInOpCustomizedConf(::oneflow::OperatorConf* op_conf, const std::string& ibn, const std::string& new_val) const override { return ::oneflow::ReplaceInputLbnInOpCustomizedConf(op_conf, ibn, new_val); } void QueryLogicalBlob( const std::string& lbn, std::function cb) const override { LogicalBlobId lbi = GenLogicalBlobId(lbn); auto& blob_desc = op_graph_.GetLogicalBlobDesc(lbi); cb(blob_desc.shape().dim_vec().begin(), blob_desc.shape().dim_vec().end(), blob_desc.data_type()); } 
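// TopoForEachOpConf (next) walks the OpGraph in topological order, following control
// edges as well, and hands each node's OperatorConf to the supplied handler, so producers
// are always visited before their consumers.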
void TopoForEachOpConf( std::function Handler) const override { op_graph_.TopoForEachNodeWithCtrlEdge( [&](OpNode* op_node) { Handler(&op_node->op().op_conf()); }); } std::string LogDir() { return JoinPath("ir_pass", IRPassTypeName(), job_->job_conf().job_name()); } private: Job* job_; const OpGraph op_graph_; JobBuilder job_builder_; bool is_updated_; }; } // namespace template bool IRRoundTrip::IsEnabled(const JobPassCtx& ctx) const { return ParseBooleanFromEnv("ONEFLOW_MLIR_ENABLE_ROUND_TRIP", false); } void SortJob(Job& job) { auto* ops = job.mutable_net()->mutable_op(); std::sort(ops->begin(), ops->end(), [](const oneflow::OperatorConf& l, const oneflow::OperatorConf& r) { return l.name() < r.name(); }); } template Maybe IRRoundTrip::Apply(Job* job, JobPassCtx* ctx) const { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); Job job_before{}; job_before.CopyFrom(*job); RoundTripOneFlowJobWrapper w(job); SortJob(job_before); if (IsInDebugMode()) { TeePersistentLogStream::Create(JoinPath(w.LogDir(), "job_before_ir_round_trip.prototxt")) ->Write(job_before); } mlir::oneflow::RoundTripOneFlowJob(w, [](::oneflow::Job* job, std::string& reason) { // TODO: It is not clear how to define if extra boxing is introduced TODO(); return true; }); if (IsInDebugMode()) { Job job_after{}; job_after.CopyFrom(*job); SortJob(job_after); TeePersistentLogStream::Create(JoinPath(w.LogDir(), "job_after_ir_round_trip.prototxt")) ->Write(job_after); } return Maybe::Ok(); } template class IRRoundTrip; template class IRRoundTrip; Maybe ConvertJobToTosaIR(Job* job) { RoundTripOneFlowJobWrapper job_wrapper(job); return ::mlir::oneflow::ConvertJobToTosaIR(job_wrapper); } Maybe SaveJobToIR(Job* job, const std::string& path) { // TODO: check path is valid dir if (IsInDebugMode()) { TeePersistentLogStream::Create("saved_job")->Write(*job); } RoundTripOneFlowJobWrapper job_wrapper(job); ::mlir::oneflow::SaveJobToIR(job_wrapper, path); return Maybe::Ok(); } Maybe ConvertJobToIR(Job* job) { if (IsInDebugMode()) { TeePersistentLogStream::Create("saved_job")->Write(*job); } RoundTripOneFlowJobWrapper job_wrapper(job); return ::mlir::oneflow::ConvertJobToIR(job_wrapper); } Maybe LoadJobFromIR(Job* job, const std::string& path) { job->Clear(); RoundTripOneFlowJobWrapper job_wrapper(job); ::mlir::oneflow::LoadJobFromIR(job_wrapper, path); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/ir/oneflow-extension/lr_jit.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
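Two practical notes on the passes above, as a hedged sketch: the round trip only runs when ONEFLOW_MLIR_ENABLE_ROUND_TRIP parses as true (see IsEnabled), and SaveJobToIR / LoadJobFromIR form a serialize-then-restore pair over the same directory. The helper below is illustrative, not from the repository:

#include <string>
#include "oneflow/core/common/just.h"  // JUST, as already included by the headers above
#include "oneflow/core/job/job_ir.h"   // assumed to declare SaveJobToIR / LoadJobFromIR

// Dump a job as MLIR into `dir`, then clear it and rebuild it from the saved IR.
oneflow::Maybe<void> RoundTripJobThroughDisk(oneflow::Job* job, const std::string& dir) {
  JUST(oneflow::SaveJobToIR(job, dir));
  JUST(oneflow::LoadJobFromIR(job, dir));
  return oneflow::Maybe<void>::Ok();
}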
*/ #include "PyAst/Ast.h" #include "PyAst/AstMlirGen.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/InitAllDialects.h" #include "mlir/IR/Builders.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Transforms/Passes.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/raw_ostream.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/StringRef.h" #include #include #include #include #include using llvm::ArrayRef; using llvm::ScopedHashTableScope; using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; static struct LLVMInitializer { LLVMInitializer() { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); } } initializer; static mlir::LogicalResult lowerToLLVMDialect(mlir::ModuleOp module) { mlir::PassManager pm(module.getContext()); pm.addNestedPass(mlir::LLVM::createRequestCWrappersPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); pm.addPass(mlir::createConvertFuncToLLVMPass()); pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertControlFlowToLLVMPass()); pm.addPass(mlir::createConvertMathToLLVMPass()); pm.addPass(mlir::arith::createArithExpandOpsPass()); pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createReconcileUnrealizedCastsPass()); return pm.run(module); } // generate a simple mlir module for test static mlir::OwningOpRef GenModuleForTest(mlir::MLIRContext& context) { std::string moduleStr = R"mlir( func.func @get_lr(%arg0 : f32, %arg1 : i32) -> f32 attributes { llvm.emit_c_interface } { return %arg0 : f32 } )mlir"; mlir::OwningOpRef module = mlir::parseSourceString(moduleStr, &context); return module; } // generate a module op from a function def python ast static mlir::OwningOpRef 
GenModule(mlir::MLIRContext& context, pyast::FunctionDef& ast) { using namespace pyast; MLIRGenImpl mlir_gen(context); mlir::OwningOpRef module = mlir_gen.GenModule(&ast); // module->dump(); return module; } // generate store of lr jit registry from a function def python ast static LRJITRegistry_Store_ GenFunc(pyast::FunctionDef& ast, bool is_dump) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); mlir::registerLLVMDialectTranslation(registry); mlir::registerBuiltinDialectTranslation(registry); mlir::MLIRContext context(registry); context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); auto module = GenModule(context, ast); if (is_dump) { module->dump(); } // auto module = genModuleForTest(context); CHECK(!!module) << "failed to parse module"; CHECK(succeeded(lowerToLLVMDialect(*module))) << "failed to lower to llvm dialect"; auto jit_or_err = mlir::ExecutionEngine::create(*module); CHECK(jit_or_err) << "failed to create JIT exe engine, " << llvm::toString(jit_or_err.takeError()); std::shared_ptr engine = cantFail(std::move(jit_or_err)); std::weak_ptr engine_ = engine; auto func = [engine_](double base_lr, double step) { float res = 0; if (!engine_.expired()) { auto engine = engine_.lock(); auto&& out = mlir::ExecutionEngine::result(res); auto base_lr_jit = static_cast(base_lr); auto step_jit = static_cast(step); auto err = engine->invoke("get_lr", base_lr_jit, step_jit, out); } return res; }; return {engine, func}; } void LRJITRegistry::Register(const std::string& function_id, pyast::FunctionDef& ast, bool is_dump) { auto jit = GenFunc(ast, is_dump); functionId2engine_[function_id] = jit; } std::function LRJITRegistry::LookUp(const std::string& function_id) { auto iter = functionId2engine_.find(function_id); if (iter != functionId2engine_.end()) { return iter->second.second; } llvm::errs() << "function '" << function_id << "' not be registered before lookup."; return nullptr; }; ================================================ FILE: oneflow/ir/oneflow-extension/mlir_gen.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "PyAst/AstMlirGen.h" // declare any scope variables in the front of function block to ensure the enough lifetime. 
mlir::LogicalResult BuilderWithSymbolTable::Declare(const std::string& var, mlir::Value value) { auto iter = symbolTable_.find(var); if (iter != symbolTable_.end()) { builder_.create(Loc(), value, iter->second); return mlir::failure(); } auto history_block = builder_.getInsertionBlock(); auto history_point = builder_.getInsertionPoint(); builder_.setInsertionPointToStart(symbolTableForDeclareBlock_); auto single_type = mlir::Float32Type::getF32(builder_.getContext()); auto type = mlir::MemRefType::get({}, single_type); auto key = builder_.create(Loc(), type); builder_.setInsertionPoint(history_block, history_point); builder_.create(Loc(), value, key); symbolTable_[var] = key; return mlir::success(); } // look up memref of the special symbol with variable name mlir::Value BuilderWithSymbolTable::LoopUp(const std::string& var) { if (symbolTable_.count(var) == 1) { return symbolTable_[var]; } theModule_->emitError("error: unknown variable '" + var + "'"); return nullptr; } // generate a location of mlir for ops mlir::Location BuilderWithSymbolTable::Loc(const std::string& file_name, int line, int col) { return mlir::FileLineColLoc::get(builder_.getStringAttr(file_name), line, col); } // dump the current whole module up void BuilderWithSymbolTable::Dump() { theModule_.dump(); } // generate a module op for lr jit registry from a ast mlir::ModuleOp MLIRGenImpl::GenModule(pyast::FunctionDef* func) { theModule_ = mlir::ModuleOp::create(Loc()); if (failed(verify(theModule_))) { theModule_.emitError("module verification error"); return nullptr; } builder_.setInsertionPointToEnd(theModule_.getBody()); auto args = func->get_args()->get_args(); auto type = mlir::Float32Type::getF32(builder_.getContext()); llvm::SmallVector arg_types(args.size(), type); llvm::SmallVector res_types(1, type); auto func_type = builder_.getFunctionType(arg_types, res_types); auto function = mlir::func::FuncOp::create(Loc(), func->get_name(), func_type); auto* entry_block = function.addEntryBlock(); symbolTableForDeclareBlock_ = entry_block; theModule_.push_back(function); builder_.setInsertionPointToStart(entry_block); for (const auto nameValue : llvm::zip(args, entry_block->getArguments())) { if (failed(Declare(std::get<0>(nameValue)->get_arg(), std::get<1>(nameValue)))) { return nullptr; } } builder_.setInsertionPointToStart(entry_block); for (auto& stmt : func->get_body()) { MlirGen(stmt.get()); } return theModule_; } // use llvm rtti to dispatch respective code gen tasks of stmt void MLIRGenImpl::MlirGen(pyast::stmt* stmt) { llvm::TypeSwitch(stmt) .Case([&](auto* node) { MlirGen(node); }) .Default([&](auto* node) { theModule_->emitError("StmtKind not support yet"); }); } // use llvm rtti to dispatch respective code gen tasks of expr mlir::Value MLIRGenImpl::MlirGen(pyast::expr* expr) { mlir::Value res; llvm::TypeSwitch(expr) .Case( [&](auto* node) { res = MlirGen(node); }) .Default([&](auto* node) { theModule_->emitError("ExprKind not support yet"); }); return res; } void MLIRGenImpl::MlirGen(pyast::If* expr) { auto test = MlirGen(expr->get_test().get()); if (test.getType().isF32()) { auto eq = mlir::arith::CmpFPredicate::ONE; auto zero_attr = builder_.getF32FloatAttr(0); auto zero = builder_.create(Loc(), zero_attr); test = builder_.create(Loc(), eq, test, zero); } mlir::Block* then_block = builder_.createBlock(builder_.getBlock()->getParent()); mlir::Block* else_block = builder_.createBlock(builder_.getBlock()->getParent()); mlir::Block* after_block = builder_.createBlock(builder_.getBlock()->getParent()); 
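// Sketch of the control flow being assembled here (illustration only, not in the original
// source); the three blocks created above map onto cf ops roughly as:
//
//   cf.cond_br %test, ^then, ^else
//   ^then:  ...lowered body...    cf.br ^after
//   ^else:  ...lowered orelse...  cf.br ^after
//   ^after: // insertion point continues here
//
// The unconditional branches are only appended when a block does not already end in its own
// branch terminator (see the checks below).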
builder_.setInsertionPointAfterValue(test); builder_.create(Loc(), test, then_block, llvm::None, else_block, llvm::None); builder_.setInsertionPointToStart(then_block); for (auto& expr : expr->get_body()) { MlirGen(expr.get()); } if (then_block->empty() || !llvm::dyn_cast(then_block->back())) { builder_.create(Loc(), after_block); } builder_.setInsertionPointToStart(else_block); for (auto& expr : expr->get_orelse()) { MlirGen(expr.get()); } if (else_block->empty() || !llvm::dyn_cast(else_block->back())) { builder_.create(Loc(), after_block); } builder_.setInsertionPointToStart(after_block); } mlir::Value MLIRGenImpl::MlirGen(pyast::Compare* expr) { if (expr->get_comparators().size() != 1 || expr->get_ops().size() != 1) { theModule_->emitError("compare only support once compare now"); } mlir::arith::CmpFPredicate op = mlir::arith::CmpFPredicate::OEQ; switch (expr->get_ops()[0]) { case pyast::Compare::kEq: op = mlir::arith::CmpFPredicate::OEQ; break; case pyast::Compare::kNotEq: op = mlir::arith::CmpFPredicate::ONE; break; case pyast::Compare::kLt: op = mlir::arith::CmpFPredicate::OLT; break; case pyast::Compare::kLtE: op = mlir::arith::CmpFPredicate::OLE; break; case pyast::Compare::kGt: op = mlir::arith::CmpFPredicate::OGT; break; case pyast::Compare::kGtE: op = mlir::arith::CmpFPredicate::OGE; break; default: theModule_->emitError("compare_ not support op now"); } auto lhs = MlirGen(expr->get_left().get()); auto rhs = MlirGen(expr->get_comparators()[0].get()); auto res = builder_.create(Loc(), op, lhs, rhs); return res; } mlir::Value MLIRGenImpl::MlirGen(pyast::BinOp* expr) { auto lhs = MlirGen(expr->get_left().get()); auto rhs = MlirGen(expr->get_right().get()); mlir::Value res; switch (expr->get_op()) { case pyast::BinOp::kAdd: res = builder_.create(Loc(), lhs, rhs); break; case pyast::BinOp::kSub: res = builder_.create(Loc(), lhs, rhs); break; case pyast::BinOp::kDiv: res = builder_.create(Loc(), lhs, rhs); break; case pyast::BinOp::kMult: res = builder_.create(Loc(), lhs, rhs); break; case pyast::BinOp::kPow: res = builder_.create(Loc(), lhs, rhs); break; default: break; } return res; } mlir::Value MLIRGenImpl::MlirGen(pyast::Call* expr) { mlir::Value res; if (expr->get_func()->get_kind() == pyast::expr::kAttribute) { auto func_ = expr->get_func().get(); auto func = *dynamic_cast(func_); auto func_value = func.get_value(); if (func_value->get_kind() != pyast::expr::kName || dynamic_cast(func_value.get())->get_id() != "math") { theModule_->emitError("only support call func is python math lib"); } if (expr->get_args().size() != 1) { theModule_->emitError("attribute node only support call func with one param"); } auto value = MlirGen(expr->get_args()[0].get()); auto attr = func.get_attr(); if (attr == "floor") { res = builder_.create(Loc(), value); } else if (attr == "cos") { res = builder_.create(Loc(), value); } else if (attr == "ceil") { res = builder_.create(Loc(), value); } else { theModule_->emitError(attr + " not support yet"); } } else if (expr->get_func()->get_kind() == pyast::expr::kName) { auto func_ = expr->get_func().get(); auto func = *dynamic_cast(func_); if (expr->get_args().size() != 2) { theModule_->emitError("name node only support call func with two param"); } auto left = MlirGen(expr->get_args()[0].get()); auto right = MlirGen(expr->get_args()[1].get()); auto attr = func.get_id(); if (attr == "max") { res = builder_.create(Loc(), left, right); } else if (attr == "min") { res = builder_.create(Loc(), left, right); } else { theModule_->emitError(attr + " not support 
yet"); } } else { theModule_->emitError("only support call func is attribute and name node"); } return res; } mlir::Value MLIRGenImpl::MlirGen(pyast::Constant* expr) { float value = expr->get_value(); auto constant = builder_.create(Loc(), builder_.getF32FloatAttr(value)); return constant; } mlir::Value MLIRGenImpl::MlirGen(pyast::Name* expr) { auto key = LoopUp(expr->get_id()); builder_.setInsertionPointToEnd(builder_.getInsertionBlock()); auto value = builder_.create(Loc(), key); return value; } void MLIRGenImpl::MlirGen(pyast::Assign* stmt) { auto value = MlirGen(stmt->get_value().get()); for (auto& target : stmt->get_targets()) { if (target->get_kind() != pyast::expr::kName) { theModule_->emitError("only support assign to name node"); } auto name = dynamic_cast(target.get())->get_id(); Declare(name, value); } } void MLIRGenImpl::MlirGen(pyast::Return* stmt) { auto value = MlirGen(stmt->get_value().get()); builder_.create(Loc(), mlir::ValueRange({value})); } ================================================ FILE: oneflow/ir/oneflow-extension/mlir_jit_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowSupport.h" #include "llvm/Support/raw_ostream.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/nn_util.h" #include "OneFlow/OKL/Kernel/JITOpInfer.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Types.h" #include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" namespace oneflow { namespace { Maybe GetSbpFn(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } } // namespace Maybe MlirJitOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return ir::jit::InferTensorDesc(ctx); } Maybe MlirJitOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ir::jit::InferTensorDesc(ctx); } Maybe MlirJitOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpFn(ctx); } Maybe MlirJitOp::InferDataType(user_op::InferContext* ctx) { return ir::jit::SetTensorDataType(ctx); ; } } // namespace oneflow ================================================ FILE: oneflow/ir/oneflow-extension/mlir_jit_op_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OKL/Kernel/LauncherState.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/ir/include/OneFlow/Passes.h" #include "oneflow/ir/include/OneFlow/Extension.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Parser/Parser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "llvm/Support/TargetSelect.h" namespace oneflow { namespace { using OpaqueMemRefDescriptor = std::shared_ptr; template OpaqueMemRefDescriptor CreateMemRefDescriptor(user_op::Tensor* tensor) { using MemRefType = StridedMemRefType; auto desc = new MemRefType(); *desc = mlir::detail::makeStridedMemRefDescriptor( tensor->dptr(), tensor->dptr(), {tensor->shape_view().ptr(), tensor->shape_view().ptr() + tensor->shape_view().NumAxes()}, {tensor->shape_view().ptr(), tensor->shape_view().ptr() + tensor->shape_view().NumAxes()}); auto deleter = [](void const* data) { auto p = static_cast(data); delete p; }; return OpaqueMemRefDescriptor(desc, deleter); } template OpaqueMemRefDescriptor CreateMutMemRefDescriptor(user_op::Tensor* tensor) { using MemRefType = StridedMemRefType; auto desc = new MemRefType(); *desc = mlir::detail::makeStridedMemRefDescriptor( tensor->mut_dptr(), tensor->mut_dptr(), {tensor->shape_view().ptr(), tensor->shape_view().ptr() + tensor->shape_view().NumAxes()}, {tensor->shape_view().ptr(), tensor->shape_view().ptr() + tensor->shape_view().NumAxes()}); auto deleter = [](void const* data) { auto p = static_cast(data); delete p; }; return OpaqueMemRefDescriptor(desc, deleter); } #define MAKE_STRIDED_MEM_REF_SWITCH_ENTRY(func_name, N, T) func_name DEFINE_STATIC_SWITCH_FUNC(OpaqueMemRefDescriptor, CreateMemRefDescriptor, MAKE_STRIDED_MEM_REF_SWITCH_ENTRY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ), MAKE_DATA_TYPE_CTRV_SEQ(ARITHMETIC_DATA_TYPE_SEQ)); DEFINE_STATIC_SWITCH_FUNC(OpaqueMemRefDescriptor, CreateMutMemRefDescriptor, MAKE_STRIDED_MEM_REF_SWITCH_ENTRY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ), MAKE_DATA_TYPE_CTRV_SEQ(ARITHMETIC_DATA_TYPE_SEQ)); #undef MAKE_STRIDED_MEM_REF_SWITCH_ENTRY std::string GetMLIRCInterface(const std::string& func_name) { return std::string("_mlir_ciface_") + func_name; } llvm::SmallVector GetMLIRCInterfaceArgs( user_op::KernelComputeContext* ctx) { llvm::SmallVector args{}; auto tensor = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); args.push_back(SwitchCreateMemRefDescriptor(SwitchCase(1, kInt8), tensor)); for (auto& pair : ctx->inputs()) { auto tensor = ctx->Tensor4ArgNameAndIndex(pair.first, pair.second); auto ref = SwitchCreateMemRefDescriptor( SwitchCase(tensor->shape_view().NumAxes(), tensor->data_type()), tensor); 
args.push_back(ref); } for (auto& pair : ctx->outputs()) { auto tensor = ctx->Tensor4ArgNameAndIndex(pair.first, pair.second); auto ref = SwitchCreateMutMemRefDescriptor( SwitchCase(tensor->shape_view().NumAxes(), tensor->data_type()), tensor); args.push_back(ref); } return args; } mlir::DialectRegistry getDialectRegistry() { mlir::DialectRegistry registry; registry .insert(); mlir::registerLLVMDialectTranslation(registry); mlir::registerBuiltinDialectTranslation(registry); return registry; } void WithMlirContext( user_op::KernelComputeContext* ctx, const llvm::SmallVector& ext_libs, const std::function(mlir::MLIRContext* mlir_ctx)>& parse, void* stream) { mlir::MLIRContext mlir_ctx(getDialectRegistry()); mlir::OwningOpRef module = parse(&mlir_ctx); CHECK(module) << "fail to parse MLIR, op: " << ctx->op_name(); if (ParseBooleanFromEnv("ONEFLOW_MLIR_STDOUT", false)) { module->print(llvm::outs()); } mlir::ExecutionEngineOptions jitOptions; jitOptions.transformer = {}; jitOptions.jitCodeGenOptLevel = std::nullopt; jitOptions.sharedLibPaths = ext_libs; auto jit_or_error = mlir::ExecutionEngine::create(*module, jitOptions); CHECK(!!jit_or_error) << "failed to create JIT exe engine, " << llvm::toString(jit_or_error.takeError()); auto jit = std::move(jit_or_error.get()); llvm::SmallVector args /* args must outlive JIT invocation */ = GetMLIRCInterfaceArgs(ctx); llvm::SmallVector packed_args{}; for (auto& arg /* arg must be a reference*/ : args) { packed_args.push_back(&arg); } packed_args.push_back(&stream); auto error = jit->invokePacked(GetMLIRCInterface(ctx->op_name()), packed_args); CHECK(!error) << "fail to invoke jit engine, error: " << llvm::toString(std::move(error)); } size_t inferOneFlowMemPoolSize(user_op::InferContext* ctx) { using namespace user_op; mlir::MLIRContext mlir_ctx(oneflow::okl::GetRegistry()); auto mlir_assembly = ctx->Attr>("mlir_assembly"); auto mlir = mlir::parseSourceString( llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), &mlir_ctx); auto module = mlir.get(); if (auto mempool = module->getAttr(mlir::oneflow::codegen::mempool::MEMPOOL_ATTR_NAME) .cast()) { return mempool.getInt(); } // Note: we should ensure the tmp buffer should be fetched in the mlir jit op in case of null // object error. 
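// Illustrative module header this function looks for (a hedged example, not from the original
// source; the real attribute key is whatever MEMPOOL_ATTR_NAME expands to, "mempool" here is
// hypothetical):
//
//   module attributes {mempool = 4096 : i64} { ... }
//
// Falling back to a 1-byte pool below keeps the "tmp_buffer" tensor materialized even when the
// codegen passes did not attach a mempool size.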
return 1; } template class MlirJitCpuKernel final : public user_op::OpKernel { public: MlirJitCpuKernel() = default; ~MlirJitCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { llvm::SmallVector ext_libs( {SharedLibPaths()->begin(), SharedLibPaths()->end()}); WithMlirContext( ctx, ext_libs, [&ctx](mlir::MLIRContext* mlir_ctx) { auto mlir_assembly = ctx->Attr>("mlir_assembly"); return mlir::parseSourceString( llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), mlir_ctx); }, nullptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MLIR_JIT_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("mlir_jit") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { return inferOneFlowMemPoolSize(ctx); }) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ return Maybe::Ok(); \ }); REGISTER_MLIR_JIT_CPU_KERNEL(float) REGISTER_MLIR_JIT_CPU_KERNEL(double) REGISTER_MLIR_JIT_CPU_KERNEL(int32_t) REGISTER_MLIR_JIT_CPU_KERNEL(int64_t) #undef REGISTER_MLIR_JIT_CPU_KERNEL #ifdef WITH_MLIR_CUDA_CODEGEN template class MlirJitGpuKernel final : public user_op::OpKernel { public: MlirJitGpuKernel() = default; ~MlirJitGpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { llvm::SmallVector ext_libs( {SharedLibPaths()->begin(), SharedLibPaths()->end()}); WithMlirContext( ctx, ext_libs, [&ctx](mlir::MLIRContext* mlir_ctx) { auto mlir_assembly = ctx->Attr>("mlir_assembly"); return mlir::parseSourceString( llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), mlir_ctx); }, #ifdef WITH_CUDA ctx->stream()->As()->cuda_stream()); #else nullptr); #endif // WITH_CUDA } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MLIR_JIT_GPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("mlir_jit") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { return inferOneFlowMemPoolSize(ctx); }) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ return Maybe::Ok(); \ }); REGISTER_MLIR_JIT_GPU_KERNEL(float) REGISTER_MLIR_JIT_GPU_KERNEL(double) REGISTER_MLIR_JIT_GPU_KERNEL(int32_t) REGISTER_MLIR_JIT_GPU_KERNEL(int64_t) #undef REGISTER_MLIR_JIT_GPU_KERNEL #endif // WITH_MLIR_CUDA_CODEGEN } // namespace } // namespace oneflow ================================================ FILE: oneflow/ir/oneflow-lite/CMakeLists.txt ================================================ include_directories(${PROJECT_BINARY_DIR}/oneflow-lite) include_directories(${PROJECT_SOURCE_DIR}/oneflow-lite) include_directories(${PROJECT_SOURCE_DIR}/oneflow-lite/include) include_directories(${PROJECT_BINARY_DIR}/oneflow-lite/include) add_subdirectory(schemas) add_subdirectory(lib) set(LLVM_LINK_COMPONENTS Support) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_llvm_executable( oneflow-lite-compile OneFlowLiteCompileMain.cpp DEPENDS MLIROneFlow lite_schemas OneFlowLiteConversion flatcc-runtime) set(_origin_prefix "\$ORIGIN") if(APPLE) set(_origin_prefix "@loader_path") endif() set_target_properties( oneflow-lite-compile PROPERTIES BUILD_WITH_INSTALL_RPATH OFF 
BUILD_RPATH "${_origin_prefix}" INSTALL_RPATH "${_origin_prefix}") llvm_update_compile_flags(oneflow-lite-compile) target_link_libraries(oneflow-lite-compile PRIVATE OneFlowLiteConversion ${dialect_libs} flatcc-runtime) ================================================ FILE: oneflow/ir/oneflow-lite/OneFlowLiteCompileMain.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "llvm/ADT/SmallString.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/Path.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/Pass/PassManager.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/ToolUtilities.h" #include "mlir/Transforms/Passes.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/ConvertToLiteExecutable.h" namespace mlir { namespace oneflow { namespace lite { LogicalResult Compile(int argc, char** argv) { llvm::InitLLVM y(argc, argv); static llvm::cl::OptionCategory mainOptions("OneFlowLite Compile Main Options"); llvm::cl::opt inputFiledir(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::Required, llvm::cl::cat(mainOptions)); llvm::cl::opt outputFilename("o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), llvm::cl::init("-"), llvm::cl::cat(mainOptions)); llvm::cl::list targets("targets", llvm::cl::desc("Target backends for executable compilation"), llvm::cl::ZeroOrMore, llvm::cl::cat(mainOptions)); llvm::cl::ParseCommandLineOptions(argc, argv, "OneFlowLite compile\n"); llvm::SmallString<128> inputFilename = StringRef(inputFiledir + "/model.mlir"); llvm::sys::path::native(inputFilename); mlir::MLIRContext context; context.getOrLoadDialect(); context.loadDialect(); OwningOpRef module = parseSourceFile(inputFilename, &context); ConvertOptions options; options.checkpointDir = inputFiledir; if (targets.empty()) { options.target = "host"; } else { if (targets.size() > 1) { llvm::errs() << "Support only one target currently.\n"; return failure(); } options.target = targets[0]; } llvm::errs() << "Enable compilation for target: " << options.target << "\n"; llvm::SmallVector executable; if (failed(ConvertToLiteExecutable(&context, module.get(), options, &executable))) { return failure(); } std::string errorMessage; auto output = mlir::openOutputFile(outputFilename, &errorMessage); if (!output) { llvm::errs() << errorMessage << "\n"; return failure(); } output->os().write(reinterpret_cast(executable.data()), executable.size()); output->keep(); return success(); } } // namespace lite } // namespace oneflow } // namespace mlir int main(int argc, char** argv) { if (mlir::failed(mlir::oneflow::lite::Compile(argc, argv))) { return 1; } return 0; } ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/ConvertToLiteExecutable.h 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_CONVERTTOLITEEXECUTABLE_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_CONVERTTOLITEEXECUTABLE_H_ #include "llvm/ADT/SmallString.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Support/LLVM.h" #include "OneFlow/FlatbufferUtils.h" namespace mlir { namespace oneflow { namespace lite { typedef struct ConvertOptions { llvm::SmallString<128> target; llvm::SmallString<128> checkpointDir; } ConvertOptions; LogicalResult ConvertToLiteExecutable(MLIRContext* context, ModuleOp module, ConvertOptions options, llvm::SmallVector* executable); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_CONVERTTOLITEEXECUTABLE_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/FlatbufferUtils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Copyright 2020 The IREE Authors // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_FLATBUFFERUTILS_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_FLATBUFFERUTILS_H_ #include #include #include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcast-qual" #include "flatcc/flatcc_builder.h" #include "flatcc/flatcc_json_printer.h" #include "flatcc/reflection/reflection_builder.h" #pragma GCC diagnostic pop namespace mlir { namespace oneflow { namespace lite { // RAII wrapper for flatcc_builder_t; pass to functions requiring a builder. // // Usage: // FlatbufferBuilder builder; // // NOTE: FlatBuffers are built bottoms-up so we first generate our [uint8]: // auto dataRef = builder.streamUint8Vec(...); // // ... and then start the table that references it: // my_type_start_as_root(builder); // my_type_uint8_vec_field_add(builder, dataRef); // my_type_end_as_root(builder); // // ... 
and finally capture the results as an mlir::Attribute. // auto attr = builder.getBufferAttr(mlirContext); class FlatbufferBuilder { public: FlatbufferBuilder(); ~FlatbufferBuilder(); operator flatcc_builder_t*() { return &builder; } // Creates a string with the given string contents (including zeros). flatbuffers_string_ref_t createString(StringRef value) { if (value.empty()) return 0; return flatbuffers_string_create(*this, value.data(), value.size()); } // Creates a string vector containing all strings in the given range. template flatbuffers_string_vec_ref_t createStringVec(RangeTy&& Range) { auto stringRefs = llvm::to_vector<8>(llvm::map_range(Range, [&](StringRef value) { return flatbuffers_string_create(*this, value.data(), value.size()); })); if (stringRefs.empty()) return 0; return flatbuffers_string_vec_create(*this, stringRefs.data(), stringRefs.size()); } // Creates an offset vector with the given values. The source values will not // be modified. flatbuffers_vec_ref_t createOffsetVec(ArrayRef values) { if (values.empty()) return 0; return flatcc_builder_create_offset_vector(*this, values.data(), values.size()); } // Creates an offset vector with the given values. // Unlike createOffsetVec this will destroy the input values array during // serialization but be much faster. flatbuffers_vec_ref_t createOffsetVecDestructive(SmallVectorImpl& values) { if (values.empty()) return 0; return flatcc_builder_create_offset_vector_direct(*this, values.data(), values.size()); } // Creates an [int32] vec with the contents of the given range. template flatbuffers_int32_vec_ref_t createInt32Vec(RangeTy&& Range) { if (std::empty(Range)) return 0; flatbuffers_int32_vec_start(*this); for (int32_t v : Range) { flatbuffers_int32_vec_push_create(*this, v); } return flatbuffers_int32_vec_end(*this); } // Creates an [int64] vec with the contents of the given range. template flatbuffers_int64_vec_ref_t createInt64Vec(RangeTy&& Range) { if (std::empty(Range)) return 0; flatbuffers_int64_vec_start(*this); for (int64_t v : Range) { flatbuffers_int64_vec_push_create(*this, v); } return flatbuffers_int64_vec_end(*this); } // Provides a raw_ostream that |fn| can use to directly stream into a [uint8] // in the FlatBuffer builder. // // Usage: // auto ref = builder.streamUint8Vec([&](llvm::raw_ostream &stream) { // stream << "foo"; // return true; // }); // ... // my_type_uint8_vec_field_add(builder, ref); // use vec reference // ... flatbuffers_uint8_vec_ref_t streamUint8Vec(std::function fn, size_t alignment = 16); // Captures the current contents of the flatbuffer builder and returns them // as a shaped `vector` dense attr. The builder is left unmodified. DenseIntElementsAttr getBufferAttr(MLIRContext* context); // Copies the current contents of the flatbuffer builder to the target output // stream. The builder is left unmodified. // // This is reduces a significant large allocation that can happen when trying // to stitch together all of the pages that were allocated in the emitter as // the FlatBuffer was constructed; here we can just walk over each page and // write it out in order without any allocations. LogicalResult copyToStream(llvm::raw_ostream& output); using print_json_fn_t = int (*)(flatcc_json_printer_t* ctx, const char* buf, size_t bufsiz); // Prints the FlatBuffer in its canonical JSON format to the given stream. // The builder is left unmodified. // // |pretty| enables newlines and indentation; somewhat useful for lit testing // (as large byte buffers end up with a byte per line!). 
// // |includeDefaults| will force all values, including those that would not // be serialized to the binary format due to the default value (0, etc) being // omitted. // // NOTE: JSON representations will also differ structurally from the binary // format as reused tables are printed wherever they are used as opposed to // referencing the same bytes; meaning that this can't be used to verify that // we are correctly memoizing strings/structures/etc. LogicalResult printJsonToStream(bool pretty, bool includeDefaults, print_json_fn_t printJsonFn, llvm::raw_ostream& output); private: flatcc_builder_t builder; }; // Allows streaming bytes directly into a FlatBuffer `[uint8]` field. // The ostream runs in buffered mode and routes all writes into pages // allocated by the FlatBuffer builder as we grow the output. // // Usage: // flatbuffers_uint8_vec_start(builder); // raw_flatbuffer_uint8_vec_ostream stream(builder); // stream << "foo"; // stream.flush(); // *********** IMPORTANT *********** // flatbuffers_uint8_vec_ref_t ref = flatbuffers_uint8_vec_end(builder); class raw_flatbuffer_uint8_vec_ostream : public llvm::raw_ostream { public: explicit raw_flatbuffer_uint8_vec_ostream(flatcc_builder_t* builder) : raw_ostream(/*unbuffered=*/true), builder(builder) {} ~raw_flatbuffer_uint8_vec_ostream() override { flush(); } private: void write_impl(const char* Ptr, size_t Size) override { flatbuffers_uint8_vec_append(builder, reinterpret_cast(Ptr), Size); pos += Size; } uint64_t current_pos() const override { return pos - GetNumBytesInBuffer(); } flatcc_builder_t* builder; uint64_t pos = 0; }; } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_FLATBUFFERUTILS_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/OneFlowLiteUtils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_ONEFLOWLITEUTILS_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_ONEFLOWLITEUTILS_H_ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/FlatbufferUtils.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" namespace mlir { namespace oneflow { namespace lite { Operation* getEntryJobOp(ModuleOp module); Operation* getEntryJobOp(Operation* op); StringAttr getValueDevice(Value value); Optional getLiteStringElementType(Type type); Optional getLiteStringElementType(::mlir::oneflow::DataType type); Optional<::oneflow::AttrType> getUserOpAttrType(StringRef opName, StringRef attrName); void serializeI32Attr(FlatbufferBuilder& builder, Attribute attribute); void serializeI64Attr(FlatbufferBuilder& builder, Attribute attribute); void serializeBoolAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeF32Attr(FlatbufferBuilder& builder, Attribute attribute); void serializeF64Attr(FlatbufferBuilder& builder, Attribute attribute); void serializeStringAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeShapeAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeStrideAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeDataTypeAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeI32sAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeI64sAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeF32sAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeDataTypesAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeShapesAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeStridesAttr(FlatbufferBuilder& builder, Attribute attribute); void serializeStringsAttr(FlatbufferBuilder& builder, Attribute attribute); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_ONEFLOWLITEUTILS_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/Transform/FoldVariable.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_FOLDVARIABLE_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_FOLDVARIABLE_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { namespace lite { std::unique_ptr createLiteFoldVariablePass(); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_FOLDVARIABLE_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/Transform/InferPlacement.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INFERPLACEMENT_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INFERPLACEMENT_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { namespace lite { std::unique_ptr createLiteInferPlacementPass(StringRef target); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INFERPLACEMENT_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/Transform/InsertTransferOp.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INSERTTRANSFEROP_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INSERTTRANSFEROP_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { namespace lite { std::unique_ptr createLiteInsertTransferOpPass(); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INSERTTRANSFEROP_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/Transform/Lowering/LoweringAscend.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCEND_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCEND_H_ #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowLiteUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" namespace mlir { namespace oneflow { namespace lite { LogicalResult loweringAscend(OpBuilder& builder, Operation* callee, StringRef checkpointDir, llvm::SmallVector* loweringData); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCEND_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/Transform/Lowering/LoweringAscendUtils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCENDUTILS_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCENDUTILS_H_ #include #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowLiteUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" // huawei ascend sdk headers #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-qualifiers" #include "op_proto/built-in/inc/all_ops.h" #pragma GCC diagnostic pop namespace mlir { namespace oneflow { namespace lite { inline ge::Shape convertAscendShape(ArrayRef shape) { return ge::Shape(std::vector{shape.begin(), shape.end()}); } inline Optional convertAscendElementType(Type type) { assert(type.isIntOrFloat()); if (type.isF16()) { return ge::DT_FLOAT16; } else if (type.isF32()) { return ge::DT_FLOAT; } else if (type.isF64()) { return ge::DT_DOUBLE; } else if (type.isSignedInteger()) { int bitwidth = type.getIntOrFloatBitWidth(); if (bitwidth == 8) { return ge::DT_INT8; } else if (bitwidth == 16) { return ge::DT_INT16; } else if (bitwidth == 32) { return ge::DT_INT32; } else if (bitwidth == 64) { return ge::DT_INT64; } else { return llvm::None; } } else if (type.isUnsignedInteger()) { int bitwidth = type.getIntOrFloatBitWidth(); if (bitwidth == 8) { return ge::DT_UINT8; } else if (bitwidth == 16) { return ge::DT_UINT16; } else if (bitwidth == 32) { return ge::DT_UINT32; } else if (bitwidth == 64) { return ge::DT_UINT64; } else { return llvm::None; } } else { return llvm::None; } } inline Optional convertAscendElementType(::mlir::oneflow::DataType type) { switch (type) { case ::mlir::oneflow::DataType::DT_Bool: return ge::DT_BOOL; case ::mlir::oneflow::DataType::DT_Char: return ge::DT_UINT8; case ::mlir::oneflow::DataType::DT_Float16: return ge::DT_FLOAT16; case ::mlir::oneflow::DataType::DT_Float: return ge::DT_FLOAT; case ::mlir::oneflow::DataType::DT_Double: return 
ge::DT_DOUBLE; case ::mlir::oneflow::DataType::DT_Int8: return ge::DT_INT8; case ::mlir::oneflow::DataType::DT_Int32: return ge::DT_INT32; case ::mlir::oneflow::DataType::DT_Int64: return ge::DT_INT64; case ::mlir::oneflow::DataType::DT_UInt8: return ge::DT_UINT8; default: { return llvm::None; } } } inline ge::TensorDesc convertAscendType(Type type) { auto tensorType = type.cast(); assert(tensorType && "type should be tensor type"); auto elementType = convertAscendElementType(tensorType.getElementType()); if (!elementType) { llvm::errs() << "element type " << tensorType.getElementType() << " is not supported\n"; exit(1); } return ge::TensorDesc(convertAscendShape(tensorType.getShape()), ge::FORMAT_NCHW, elementType.value()); } inline ge::TensorDesc convertAscendType(::mlir::oneflow::DataType type, ArrayRef shape) { auto elementType = convertAscendElementType(type); if (!elementType) { llvm::errs() << "element type " << static_cast(type) << " is not supported\n"; exit(1); } return ge::TensorDesc(convertAscendShape(shape), ge::FORMAT_NCHW, elementType.value()); } inline ge::TensorDesc convertAscendType(Attribute type, Attribute shape) { SmallVector shapeArray; for (auto v : shape.dyn_cast().getValue()) { shapeArray.push_back(v.dyn_cast().getSInt()); } return convertAscendType(type.dyn_cast().getValue(), shapeArray); } inline ge::Operator::OpListInt convertPaddings(ArrayAttr paddings) { assert(paddings.size() == 2 || paddings.size() == 4); if (paddings.size() == 2) { int s0 = paddings[0].dyn_cast().getSInt(); int s1 = paddings[1].dyn_cast().getSInt(); return ge::Operator::OpListInt({s0, s0, s1, s1}); } else { int s0 = paddings[0].dyn_cast().getSInt(); int s1 = paddings[1].dyn_cast().getSInt(); int s2 = paddings[2].dyn_cast().getSInt(); int s3 = paddings[3].dyn_cast().getSInt(); return ge::Operator::OpListInt({s0, s1, s2, s3}); } } inline ge::Operator::OpListInt convertStrides(ArrayAttr strides) { assert(strides.size() == 2); int s0 = strides[0].dyn_cast().getSInt(); int s1 = strides[1].dyn_cast().getSInt(); return ge::Operator::OpListInt({1, 1, s0, s1}); } inline ge::Operator::OpListInt convertDilations(ArrayAttr dilations) { return convertStrides(dilations); } inline ge::Operator::OpListInt convertKernelSize(ArrayAttr kernel_size) { return convertStrides(kernel_size); } inline StringRef convertDataFormat(StringRef dataFormat) { if (dataFormat == "nchw" || dataFormat == "NCHW" || dataFormat == "channels_first") { return StringRef("NCHW"); } else if (dataFormat == "nhwc" || dataFormat == "NHWC" || dataFormat == "channels_last") { return StringRef("NHWC"); } else { llvm::errs() << "unsupport data format " << dataFormat << "\n"; exit(1); } } } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCENDUTILS_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/Transform/LoweringLaunchJob.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERINGLAUNCHJOB_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERINGLAUNCHJOB_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { namespace lite { std::unique_ptr createLiteLoweringLaunchJobPass(StringRef checkpointDir); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERINGLAUNCHJOB_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/Transform/MemoryPlanning.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_MEMORYPLANNING_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_MEMORYPLANNING_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { namespace lite { struct LiteBufferSegment { StringRef device; size_t size; size_t alignment; }; class LiteBufferStrategy { public: LiteBufferStrategy() = default; const llvm::SmallVector& getSegments() const { return segments; } llvm::SmallVector& getSegments() { return segments; } int getValueSegmentId(Value value) const; size_t getValueSegmentOffset(Value value) const; LogicalResult insertValue(Value value, int segmentId, size_t segmentOffset); private: llvm::SmallVector segments; struct ValueSegmentInfo { int segmentId; size_t segmentOffset; }; llvm::DenseMap valueSegmentInfos; }; std::unique_ptr createLiteMemoryPlanningPass(LiteBufferStrategy* strategy); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_MEMORYPLANNING_H_ ================================================ FILE: oneflow/ir/oneflow-lite/include/OneFlow/Transform/PartitionLaunchJob.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_PARTITIONLAUNCHJOB_H_ #define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_PARTITIONLAUNCHJOB_H_ #include "mlir/Pass/Pass.h" namespace mlir { namespace oneflow { namespace lite { std::unique_ptr createLitePartitionLaunchJobPass(); } // namespace lite } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_PARTITIONLAUNCHJOB_H_ ================================================ FILE: oneflow/ir/oneflow-lite/lib/CMakeLists.txt ================================================ add_subdirectory(OneFlow) ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/CMakeLists.txt ================================================ set(LITE_LOWERING_SRCS "") set(LITE_LOWERING_LIBS "") if(LITE_USE_ASCEND_NPU) include(cmake/FindAscendSdk.cmake) include_directories(${ASCEND_INCLUDE_DIR}) include_directories(${ASCEND_INCLUDE_DIR}/../../opp) add_definitions(-DLITE_USE_ASCEND_NPU=1) list(APPEND LITE_LOWERING_SRCS Transform/Lowering/LoweringAscend.cpp) list(APPEND LITE_LOWERING_LIBS ${ASCEND_LIBRARIES}) endif() oneflow_add_mlir_library( OneFlowLiteConversion ConvertToLiteExecutable.cpp FlatbufferUtils.cpp OneFlowLiteUtils.cpp Transform/FoldVariable.cpp Transform/InferPlacement.cpp Transform/InsertTransferOp.cpp Transform/MemoryPlanning.cpp Transform/PartitionLaunchJob.cpp Transform/LoweringLaunchJob.cpp ${LITE_LOWERING_SRCS} DEPENDS MLIRIR MLIRParser MLIRPass MLIRSPIRVDialect MLIRTranslateLib MLIRSupport MLIROneFlow MLIROneFlowExtension flatcc-runtime LINK_LIBS MLIRIR ${dialect_libs} ${translation_libs} MLIRParser MLIRPass MLIRSPIRVDialect MLIRTranslateLib MLIRSupport MLIROneFlow oneflow $ MLIROneFlowExtension ${LITE_LOWERING_LIBS} $) ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/ConvertToLiteExecutable.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/ConvertToLiteExecutable.h" #include "OneFlow/OneFlowDialect.h" // undefine fallthrough to fix the conflicit of flatcc and fmt #if defined(fallthrough) #undef fallthrough #endif #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/Passes.h" #include "OneFlow/OneFlowUtils.h" #include "OneFlow/OneFlowLiteUtils.h" #include "OneFlow/Transform/FoldVariable.h" #include "OneFlow/Transform/InferPlacement.h" #include "OneFlow/Transform/InsertTransferOp.h" #include "OneFlow/Transform/LoweringLaunchJob.h" #include "OneFlow/Transform/MemoryPlanning.h" #include "OneFlow/Transform/PartitionLaunchJob.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/ToolUtilities.h" #include "mlir/Transforms/Passes.h" #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcast-qual" #include "schemas/executable_generated.h" #pragma GCC diagnostic pop namespace mlir { namespace oneflow { namespace lite { static flatbuffers_vec_ref_t createLiteOpAttrs(FlatbufferBuilder& builder, Operation* op) { assert((llvm::dyn_cast(op) || llvm::dyn_cast(op)) && "the argument op is not a valid user op"); llvm::SmallVector attrDefs; for (auto kv : op->getAttrDictionary()) { auto attrName = kv.getName(); Optional<::oneflow::AttrType> attrType = getUserOpAttrType(GetOpTypeName(op), attrName.strref()); if (!attrType) { continue; } auto attrValue = kv.getValue(); StringRef strAttrType; FlatbufferBuilder attrBuilder; if (attrType.value() == ::oneflow::kAtInt32) { strAttrType = "i32"; serializeI32Attr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtInt64) { strAttrType = "i64"; serializeI64Attr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtBool) { strAttrType = "bool"; serializeBoolAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtFloat) { strAttrType = "f32"; serializeF32Attr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtDouble) { strAttrType = "f64"; serializeF64Attr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtString) { strAttrType = "str"; serializeStringAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtShape) { strAttrType = "shape"; serializeShapeAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtStride) { strAttrType = "stride"; serializeStrideAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtDataType) { strAttrType = "dtype"; serializeDataTypeAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtListInt32) { strAttrType = "i32s"; serializeI32sAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtListInt64) { strAttrType = "i64s"; serializeI64sAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtListFloat) { strAttrType = "f32s"; serializeF32sAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtListDataType) { strAttrType = "dtypes"; serializeDataTypesAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtListShape) { strAttrType = "shapes"; serializeShapesAttr(attrBuilder, attrValue); } else if (attrType.value() == ::oneflow::kAtListStride) { strAttrType = "strides"; serializeStridesAttr(attrBuilder, attrValue); } else 
if (attrType.value() == ::oneflow::kAtListString) { strAttrType = "strs"; serializeStringsAttr(attrBuilder, attrValue); } else { llvm::errs() << "error attribute type: " << attrType.value() << "\n"; exit(1); } oneflow_lite_AttrDef_start(builder); oneflow_lite_AttrDef_type_add(builder, builder.createString(strAttrType)); oneflow_lite_AttrDef_key_add(builder, builder.createString(attrName.strref())); oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& stream) { if (failed(attrBuilder.copyToStream(stream))) { return false; } return true; })); attrDefs.push_back(oneflow_lite_AttrDef_end(builder)); } return builder.createOffsetVecDestructive(attrDefs); } static flatbuffers_vec_ref_t createLiteVariableOpAttrs(FlatbufferBuilder& builder, VariableOp op, StringRef checkpointDir) { llvm::SmallVector attrDefs; { oneflow_lite_AttrDef_start(builder); oneflow_lite_AttrDef_type_add(builder, builder.createString("dtype")); oneflow_lite_AttrDef_key_add(builder, builder.createString("dtype")); FlatbufferBuilder attrBuilder; serializeDataTypeAttr(attrBuilder, op.getDataTypeAttr()); oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& stream) { if (failed(attrBuilder.copyToStream(stream))) { return false; } return true; })); attrDefs.push_back(oneflow_lite_AttrDef_end(builder)); } { oneflow_lite_AttrDef_start(builder); oneflow_lite_AttrDef_type_add(builder, builder.createString("shape")); oneflow_lite_AttrDef_key_add(builder, builder.createString("shape")); FlatbufferBuilder attrBuilder; serializeShapeAttr(attrBuilder, op.getShapeAttr()); oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& stream) { if (failed(attrBuilder.copyToStream(stream))) { return false; } return true; })); attrDefs.push_back(oneflow_lite_AttrDef_end(builder)); } // serialize weight data oneflow_lite_AttrDef_start(builder); oneflow_lite_AttrDef_type_add(builder, builder.createString("u8")); oneflow_lite_AttrDef_key_add(builder, builder.createString("value")); llvm::SmallString<128> inputFilename; llvm::sys::path::native(checkpointDir + "/" + op.getOpName() + "/out", inputFilename); std::string errorMessage; auto input = mlir::openInputFile(inputFilename, &errorMessage); if (!input) { llvm::errs() << errorMessage << "\n"; exit(1); } oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& stream) { stream << input->getBuffer(); stream.flush(); return true; })); attrDefs.push_back(oneflow_lite_AttrDef_end(builder)); return builder.createOffsetVecDestructive(attrDefs); } static oneflow_lite_OpDef_ref_t createLiteVariableOpDef( FlatbufferBuilder& builder, VariableOp op, llvm::DenseMap& valueOrdering, const llvm::DenseMap& deviceOrdering, StringRef checkpointDir) { oneflow_lite_OpDef_start(builder); oneflow_lite_OpDef_name_add(builder, builder.createString("constant")); oneflow_lite_OpDef_inputs_add(builder, 0); auto index = valueOrdering.try_emplace(op.getOutput(), valueOrdering.size()).first->second; oneflow_lite_OpDef_outputs_add(builder, builder.createInt32Vec(llvm::SmallVector{index})); oneflow_lite_OpDef_attrs_add(builder, createLiteVariableOpAttrs(builder, op, checkpointDir)); auto it = deviceOrdering.find(op.getDeviceTag()); assert(it != deviceOrdering.end()); oneflow_lite_OpDef_device_add(builder, it->second); return oneflow_lite_OpDef_end(builder); } static oneflow_lite_OpDef_ref_t createLiteOpDef( FlatbufferBuilder& builder, Operation* op, llvm::DenseMap& valueOrdering, const llvm::DenseMap& 
deviceOrdering) { llvm::SmallVector inputOrdering; for (const auto& operand : op->getOperands()) { auto it = valueOrdering.find(operand); if (it == valueOrdering.end()) { it = valueOrdering.try_emplace(operand, valueOrdering.size()).first; } inputOrdering.push_back(it->second); } llvm::SmallVector outputOrdering; for (const auto& result : op->getResults()) { auto it = valueOrdering.find(result); if (it == valueOrdering.end()) { it = valueOrdering.try_emplace(result, valueOrdering.size()).first; } outputOrdering.push_back(it->second); } oneflow_lite_OpDef_start(builder); oneflow_lite_OpDef_name_add(builder, builder.createString(GetOpTypeName(op))); oneflow_lite_OpDef_inputs_add(builder, builder.createInt32Vec(inputOrdering)); oneflow_lite_OpDef_outputs_add(builder, builder.createInt32Vec(outputOrdering)); oneflow_lite_OpDef_attrs_add(builder, createLiteOpAttrs(builder, op)); auto device = op->getAttrOfType(OpTrait::IsOpConfCompatible::getDeviceTagAttr()); auto it = deviceOrdering.find(device.getValue()); assert(it != deviceOrdering.end()); oneflow_lite_OpDef_device_add(builder, it->second); return oneflow_lite_OpDef_end(builder); } static oneflow_lite_TensorDef_ref_t createLiteTensorDef(FlatbufferBuilder& builder, Value value, int segmentId, size_t segmentOffset) { TensorType type = value.getType().cast(); oneflow_lite_TensorDef_start(builder); auto elemType = getLiteStringElementType(type.getElementType()); if (!elemType) { llvm::errs() << "error tensor element type: " << type.getElementType() << "\n"; exit(1); } oneflow_lite_TensorDef_type_add(builder, builder.createString(elemType.value())); oneflow_lite_TensorDef_layout_add(builder, builder.createString("default")); oneflow_lite_TensorDef_sizes_add(builder, builder.createInt64Vec(type.getShape())); oneflow_lite_TensorDef_strides_add(builder, builder.createInt64Vec(llvm::SmallVector{})); oneflow_lite_TensorDef_segment_id_add(builder, segmentId); oneflow_lite_TensorDef_segment_offset_add(builder, segmentOffset); return oneflow_lite_TensorDef_end(builder); } static oneflow_lite_BufferSegmentDef_ref_t createLiteBufferSegmentDef( FlatbufferBuilder& builder, const LiteBufferSegment& segment, const llvm::DenseMap& deviceOrdering) { auto it = deviceOrdering.find(segment.device); assert(it != deviceOrdering.end()); oneflow_lite_BufferSegmentDef_start(builder); oneflow_lite_BufferSegmentDef_size_add(builder, segment.size); oneflow_lite_BufferSegmentDef_device_add(builder, it->second); oneflow_lite_BufferSegmentDef_alignment_add(builder, static_cast(segment.alignment)); return oneflow_lite_BufferSegmentDef_end(builder); } LogicalResult ConvertToLiteExecutable(MLIRContext* context, ModuleOp module, ConvertOptions options, llvm::SmallVector* executable) { mlir::PassManager pm(context); pm.addPass(createCanonicalizerPass()); pm.addPass(createLiteFoldVariablePass()); pm.addPass(createLiteInferPlacementPass(options.target)); pm.addPass(createLiteInsertTransferOpPass()); pm.addPass(createLitePartitionLaunchJobPass()); pm.addPass(createLiteLoweringLaunchJobPass(options.checkpointDir)); pm.addPass(createCanonicalizerPass()); LiteBufferStrategy bufferStrategy; pm.addPass(createLiteMemoryPlanningPass(&bufferStrategy)); if (mlir::failed(pm.run(module))) { llvm::errs() << "Failed to run oneflow lite compilation passes.\n"; return failure(); } // llvm::errs() << *module << "\n"; Operation* entryJobOp = getEntryJobOp(module); if (!entryJobOp) { llvm::errs() << "Job not found in module: " << *module; return failure(); } auto funcName = 
entryJobOp->getAttrOfType("sym_name"); llvm::SmallVector devices; llvm::DenseMap deviceOrdering; for (const auto& segment : bufferStrategy.getSegments()) { int ordering = deviceOrdering.size(); if (deviceOrdering.try_emplace(segment.device, ordering).second) { devices.push_back(segment.device); } } FlatbufferBuilder builder; oneflow_lite_ExecutableDef_start_as_root(builder); oneflow_lite_ExecutableDef_version_add(builder, 0); oneflow_lite_ExecutableDef_name_add(builder, builder.createString(funcName.getValue())); oneflow_lite_ExecutableDef_devices_add(builder, builder.createStringVec(devices)); llvm::DenseMap valueOrdering; llvm::SmallVector inputValueOrdering, outputValueOrdering; llvm::SmallVector inputValueNames, outputValueNames; llvm::SmallVector opDefs; entryJobOp->walk([&](Operation* op) { if (!op->hasTrait()) { return; } if (auto inputOp = llvm::dyn_cast(op)) { auto it = valueOrdering.try_emplace(inputOp.getOutput(), valueOrdering.size()).first; inputValueOrdering.push_back(it->second); inputValueNames.push_back( op->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) .getValue()); } else if (auto outputOp = llvm::dyn_cast(op)) { auto it = valueOrdering.try_emplace(outputOp.getInput(), valueOrdering.size()).first; outputValueOrdering.push_back(it->second); outputValueNames.push_back( op->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) .getValue()); } else if (auto variableOp = llvm::dyn_cast(op)) { opDefs.push_back(createLiteVariableOpDef(builder, variableOp, valueOrdering, deviceOrdering, options.checkpointDir)); } else { opDefs.push_back(createLiteOpDef(builder, op, valueOrdering, deviceOrdering)); } }); oneflow_lite_ExecutableDef_ops_add(builder, builder.createOffsetVecDestructive(opDefs)); llvm::SmallVector orderedValues(valueOrdering.size()); for (auto it : valueOrdering) { orderedValues[it.second] = it.first; } llvm::SmallVector tensorDefs; for (auto value : orderedValues) { int segmentId = bufferStrategy.getValueSegmentId(value); size_t segmentOffset = bufferStrategy.getValueSegmentOffset(value); tensorDefs.push_back(createLiteTensorDef(builder, value, segmentId, segmentOffset)); } oneflow_lite_ExecutableDef_operands_add(builder, builder.createOffsetVecDestructive(tensorDefs)); oneflow_lite_ExecutableDef_inputs_add(builder, builder.createInt32Vec(inputValueOrdering)); oneflow_lite_ExecutableDef_outputs_add(builder, builder.createInt32Vec(outputValueOrdering)); oneflow_lite_ExecutableDef_input_names_add(builder, builder.createStringVec(inputValueNames)); oneflow_lite_ExecutableDef_output_names_add(builder, builder.createStringVec(outputValueNames)); llvm::SmallVector segmentDefs; for (const auto& segment : bufferStrategy.getSegments()) { segmentDefs.push_back(createLiteBufferSegmentDef(builder, segment, deviceOrdering)); } oneflow_lite_ExecutableDef_segments_add(builder, builder.createOffsetVecDestructive(segmentDefs)); oneflow_lite_ExecutableDef_end_as_root(builder); size_t packedSize = flatcc_builder_get_buffer_size(builder); executable->resize(packedSize); flatcc_builder_copy_buffer(builder, executable->data(), packedSize); return success(); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/FlatbufferUtils.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Copyright 2020 The IREE Authors // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "OneFlow/FlatbufferUtils.h" #include #include #include #include #include "mlir/IR/BuiltinTypes.h" namespace mlir { namespace oneflow { namespace lite { // Combines all pages of the FlatBuffer builder into a single contiguous byte // buffer and returns the result. // // NOTE: this is a alloc/copy. We need to have a single contiguous buffer to // pass into the elements factory function and the data we have in the // builder is paged. If we end up with a custom attribute type for this that // does not support storage uniquing then we can directly allocate and copy // the pages into the buffer without the extra copy. static SmallVector cloneBufferIntoContiguousBytes(FlatbufferBuilder& fbb) { size_t packedSize = flatcc_builder_get_buffer_size(fbb); SmallVector packedData(packedSize); void* result = flatcc_builder_copy_buffer(fbb, packedData.data(), packedData.size()); assert(result && "flatcc_emitter_t impl failed (non-default?)"); (void)result; return packedData; } FlatbufferBuilder::FlatbufferBuilder() { flatcc_builder_init(&builder); } FlatbufferBuilder::~FlatbufferBuilder() { flatcc_builder_clear(&builder); } flatbuffers_uint8_vec_ref_t FlatbufferBuilder::streamUint8Vec( std::function fn, size_t alignment) { flatcc_builder_start_vector(*this, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); raw_flatbuffer_uint8_vec_ostream stream(*this); if (!fn(stream)) { return 0; } stream.flush(); return flatbuffers_uint8_vec_end(*this); } DenseIntElementsAttr FlatbufferBuilder::getBufferAttr(MLIRContext* context) { // We require direct access to the FlatBuffer bytes so we can pass them to // the attribute constructor (which needs to inspect them all for uniquing). auto bufferData = cloneBufferIntoContiguousBytes(*this); // NOTE: ew. OpaqueAttr may be better? It does equality checks but won't try // to unique and would let us get a mutable buffer out. return DenseIntElementsAttr::get( VectorType::get({static_cast(bufferData.size())}, IntegerType::get(context, 8)), std::move(bufferData)); } LogicalResult FlatbufferBuilder::copyToStream(llvm::raw_ostream& output) { // NOTE: expected to be the default emitter. auto* E = reinterpret_cast(flatcc_builder_get_emit_context(*this)); if (!E->front) { return failure(); } if (E->front == E->back) { output.write(reinterpret_cast(E->front_cursor), E->used); return success(); } size_t len = FLATCC_EMITTER_PAGE_SIZE - E->front_left; output.write(reinterpret_cast(E->front_cursor), len); flatcc_emitter_page_t* p = E->front->next; while (p != E->back) { output.write(reinterpret_cast(p->page), FLATCC_EMITTER_PAGE_SIZE); p = p->next; } output.write(reinterpret_cast(p->page), FLATCC_EMITTER_PAGE_SIZE - E->back_left); return success(); } LogicalResult FlatbufferBuilder::printJsonToStream(bool pretty, bool includeDefaults, print_json_fn_t printJsonFn, llvm::raw_ostream& output) { // The printer requires direct access to the FlatBuffer bytes so clone here. 
auto bufferData = cloneBufferIntoContiguousBytes(*this); auto moduleData = ArrayRef(bufferData.data(), bufferData.size()) .drop_front(sizeof(flatbuffers_uoffset_t)); flatcc_json_printer_t printer; flatcc_json_printer_init_dynamic_buffer(&printer, /*buffer_size=*/0); flatcc_json_printer_set_indent(&printer, pretty ? 2 : 0); flatcc_json_printer_set_skip_default(&printer, !includeDefaults); flatcc_json_printer_set_force_default(&printer, includeDefaults); // Print into the dynamically-resizing buffer. May fail if OOM. int rv = printJsonFn(&printer, reinterpret_cast(moduleData.data()), moduleData.size()); if (rv == -1) { flatcc_json_printer_clear(&printer); return failure(); } // Take the buffer from the printer; note that it is 0 terminated and can be // used directly as a cstr if needed. size_t outputSize = 0; char* outputBytes = reinterpret_cast(flatcc_json_printer_finalize_dynamic_buffer(&printer, &outputSize)); output.write(outputBytes, outputSize); free(outputBytes); return success(); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/OneFlowLiteUtils.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
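Overview: the serialize*Attr helpers defined in this file each write a single MLIR attribute
into its own small standalone FlatBuffer. ConvertToLiteExecutable.cpp then embeds that nested
buffer verbatim as the uint8 "value" of an AttrDef in the outer executable, mirroring
createLiteOpAttrs. A minimal sketch of the nesting pattern (the attribute name "axis" is
purely illustrative):

  FlatbufferBuilder attrBuilder;
  serializeI32Attr(attrBuilder, op->getAttr("axis"));            // nested buffer holding an I32Def
  oneflow_lite_AttrDef_start(builder);
  oneflow_lite_AttrDef_type_add(builder, builder.createString("i32"));
  oneflow_lite_AttrDef_key_add(builder, builder.createString("axis"));
  oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& s) {
    return succeeded(attrBuilder.copyToStream(s));               // copy the nested bytes in place
  }));
  oneflow_lite_AttrDef_end(builder);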
*/ #include "OneFlow/OneFlowLiteUtils.h" #include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowUtils.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcast-qual" #include "schemas/executable_generated.h" #include "schemas/attributes/bool_generated.h" #include "schemas/attributes/f32_generated.h" #include "schemas/attributes/f32s_generated.h" #include "schemas/attributes/f64_generated.h" #include "schemas/attributes/i32_generated.h" #include "schemas/attributes/i32s_generated.h" #include "schemas/attributes/i64_generated.h" #include "schemas/attributes/i64s_generated.h" #include "schemas/attributes/shape_generated.h" #include "schemas/attributes/shapes_generated.h" #include "schemas/attributes/str_generated.h" #include "schemas/attributes/strs_generated.h" #pragma GCC diagnostic pop namespace mlir { namespace oneflow { namespace lite { Operation* getEntryJobOp(ModuleOp module) { return getEntryJobOp(module.getOperation()); } Operation* getEntryJobOp(Operation* op) { Operation* entry = nullptr; op->walk([&](oneflow::Job job) -> WalkResult { entry = job.getOperation(); return WalkResult::advance(); }); return entry; } StringAttr getValueDevice(Value value) { StringAttr device; Operation* op = value.getDefiningOp(); if (auto copyOp = dyn_cast(op)) { device = copyOp.getDeviceTypeAttr(); } else { device = value.getDefiningOp()->getAttrOfType( OpTrait::IsOpConfCompatible::getDeviceTagAttr()); } return device; } Optional getLiteStringElementType(Type type) { assert(type.isIntOrFloat()); if (type.isF16()) { return StringRef("f16"); } else if (type.isBF16()) { return StringRef("bf16"); } else if (type.isF32()) { return StringRef("f32"); } else if (type.isF64()) { return StringRef("f64"); } else if (type.isSignedInteger()) { int bitwidth = type.getIntOrFloatBitWidth(); return StringRef("i" + llvm::Twine(bitwidth).str()); } else if (type.isUnsignedInteger()) { int bitwidth = type.getIntOrFloatBitWidth(); return StringRef("u" + llvm::Twine(bitwidth).str()); } else { return llvm::None; } } Optional getLiteStringElementType(::mlir::oneflow::DataType type) { switch (type) { case ::mlir::oneflow::DataType::DT_Bool: return StringRef("bool"); case ::mlir::oneflow::DataType::DT_Char: return StringRef("char"); case ::mlir::oneflow::DataType::DT_Float16: return StringRef("f16"); case ::mlir::oneflow::DataType::DT_Float: return StringRef("f32"); case ::mlir::oneflow::DataType::DT_Double: return StringRef("f64"); case ::mlir::oneflow::DataType::DT_Int8: return StringRef("i8"); case ::mlir::oneflow::DataType::DT_Int32: return StringRef("i32"); case ::mlir::oneflow::DataType::DT_Int64: return StringRef("i64"); case ::mlir::oneflow::DataType::DT_UInt8: return StringRef("u8"); default: { return llvm::None; } } } Optional<::oneflow::AttrType> getUserOpAttrType(StringRef opName, StringRef attrName) { const ::oneflow::user_op::OpRegistryResult* val = ::oneflow::user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(opName.str()); if (!val) { llvm::errs() << "unregistered user op: " << opName << "\n"; exit(1); } ::oneflow::user_op::UserOpDefWrapper op_def(val->op_def); if (!op_def.IsAttrName(attrName.str())) { return llvm::None; } return op_def.GetAttrType(attrName.str()); } void serializeI32Attr(FlatbufferBuilder& builder, Attribute attribute) { 
oneflow_lite_I32Def_start_as_root(builder); oneflow_lite_I32Def_value_add(builder, attribute.dyn_cast().getSInt()); oneflow_lite_I32Def_end_as_root(builder); } void serializeI64Attr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_I64Def_start_as_root(builder); oneflow_lite_I64Def_value_add(builder, attribute.dyn_cast().getSInt()); oneflow_lite_I64Def_end_as_root(builder); } void serializeBoolAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_BoolDef_start_as_root(builder); oneflow_lite_BoolDef_value_add(builder, attribute.dyn_cast().getValue()); oneflow_lite_BoolDef_end_as_root(builder); } void serializeF32Attr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_F32Def_start_as_root(builder); oneflow_lite_F32Def_value_add(builder, attribute.dyn_cast().getValue().convertToFloat()); oneflow_lite_F32Def_end_as_root(builder); } void serializeF64Attr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_F64Def_start_as_root(builder); oneflow_lite_F64Def_value_add(builder, attribute.dyn_cast().getValue().convertToDouble()); oneflow_lite_F64Def_end_as_root(builder); } void serializeStringAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_StringDef_start_as_root(builder); oneflow_lite_StringDef_value_add( builder, builder.createString(attribute.dyn_cast().getValue())); oneflow_lite_StringDef_end_as_root(builder); } void serializeShapeAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_ShapeDef_start_as_root(builder); SmallVector shape; for (auto v : attribute.dyn_cast().getValue()) { shape.push_back(v.dyn_cast().getSInt()); } oneflow_lite_ShapeDef_value_add(builder, builder.createInt64Vec(shape)); oneflow_lite_ShapeDef_end_as_root(builder); } void serializeStrideAttr(FlatbufferBuilder& builder, Attribute attribute) { serializeShapeAttr(builder, attribute); } void serializeDataTypeAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_StringDef_start_as_root(builder); auto dtype = getLiteStringElementType(attribute.dyn_cast().getValue()); if (!dtype) { llvm::errs() << "error data type: " << attribute << "\n"; exit(1); } oneflow_lite_StringDef_value_add(builder, builder.createString(dtype.value())); oneflow_lite_StringDef_end_as_root(builder); } void serializeI32sAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_I32sDef_start_as_root(builder); SmallVector vec; for (auto v : attribute.dyn_cast().getValue()) { vec.push_back(v.dyn_cast().getSInt()); } oneflow_lite_I32sDef_value_add(builder, builder.createInt32Vec(vec)); oneflow_lite_I32sDef_end_as_root(builder); } void serializeI64sAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_I64sDef_start_as_root(builder); SmallVector vec; for (auto v : attribute.dyn_cast().getValue()) { vec.push_back(v.dyn_cast().getSInt()); } oneflow_lite_I64sDef_value_add(builder, builder.createInt64Vec(vec)); oneflow_lite_I64sDef_end_as_root(builder); } void serializeF32sAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_F32sDef_start_as_root(builder); flatbuffers_float_vec_start(builder); for (auto v : attribute.dyn_cast().getValue()) { flatbuffers_float_vec_push_create(builder, v.dyn_cast().getValue().convertToFloat()); } oneflow_lite_F32sDef_value_add(builder, flatbuffers_float_vec_end(builder)); oneflow_lite_F32sDef_end_as_root(builder); } void serializeDataTypesAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_StringsDef_start_as_root(builder); llvm::SmallVector dtypes; for (auto v : 
attribute.dyn_cast().getValue()) { auto dtype = getLiteStringElementType(v.dyn_cast().getValue()); if (!dtype) { llvm::errs() << "error data type: " << v << "\n"; exit(1); } dtypes.push_back(dtype.value()); } oneflow_lite_StringsDef_value_add(builder, builder.createStringVec(dtypes)); oneflow_lite_StringsDef_end_as_root(builder); } void serializeShapesAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_ShapesDef_start_as_root(builder); SmallVector shapeDefs; for (auto v : attribute.dyn_cast().getValue()) { oneflow_lite_ShapeDef_start(builder); SmallVector vec; for (auto p : v.dyn_cast().getValue()) { vec.push_back(p.dyn_cast().getSInt()); } oneflow_lite_ShapeDef_value_add(builder, builder.createInt64Vec(vec)); shapeDefs.push_back(oneflow_lite_ShapeDef_end(builder)); } oneflow_lite_ShapesDef_value_add(builder, builder.createOffsetVecDestructive(shapeDefs)); oneflow_lite_ShapesDef_end_as_root(builder); } void serializeStridesAttr(FlatbufferBuilder& builder, Attribute attribute) { return serializeShapesAttr(builder, attribute); } void serializeStringsAttr(FlatbufferBuilder& builder, Attribute attribute) { oneflow_lite_StringsDef_start_as_root(builder); SmallVector stringDefs; for (auto v : attribute.dyn_cast().getValue()) { oneflow_lite_StringDef_start(builder); oneflow_lite_StringDef_value_add(builder, builder.createString(v.dyn_cast().getValue())); stringDefs.push_back(oneflow_lite_StringDef_end(builder)); } oneflow_lite_StringsDef_value_add(builder, builder.createOffsetVecDestructive(stringDefs)); oneflow_lite_StringsDef_end_as_root(builder); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/Transform/FoldVariable.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/Transform/FoldVariable.h" namespace mlir { namespace oneflow { namespace lite { struct FoldVariablePass : public PassWrapper> { void runOnOperation() override { // TODO } }; std::unique_ptr createLiteFoldVariablePass() { return std::unique_ptr(new FoldVariablePass); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/Transform/InferPlacement.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
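Note: FoldVariablePass (defined just above) is currently a stub; its runOnOperation body is a
TODO. The InferPlacementPass defined in this file simply rewrites the device_tag attribute of
every op-conf-compatible op to the requested target when CanScheduleOnTarget allows it, and to
"host" otherwise.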
*/ #include "OneFlow/Transform/InferPlacement.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" namespace mlir { namespace oneflow { namespace lite { static bool CanScheduleOnTarget(Operation* op, StringRef target) { if (!op->hasTrait()) { return false; } if (llvm::dyn_cast(op) || llvm::dyn_cast(op)) { return false; } // TODO() return true; } struct InferPlacementPass : public PassWrapper> { StringRef target_; explicit InferPlacementPass(StringRef target) : target_(target) {} void runOnOperation() override; }; void InferPlacementPass::runOnOperation() { getOperation().walk([&](Operation* op) { if (!op->hasTrait()) { return; } auto target = [&]() -> StringRef { if (CanScheduleOnTarget(op, target_)) { return target_; } return StringRef("host"); }(); OpBuilder builder(&getContext()); op->setAttr(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), builder.getStringAttr(target)); }); } std::unique_ptr createLiteInferPlacementPass(StringRef target) { return std::unique_ptr(new InferPlacementPass(target)); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/Transform/InsertTransferOp.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/Transform/InsertTransferOp.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" namespace mlir { namespace oneflow { namespace lite { struct InsertTransferOpPass : public PassWrapper> { void runOnOperation() override; StringAttr InferTargetDevice(StringAttr from, StringAttr to) const; }; StringAttr InsertTransferOpPass::InferTargetDevice(StringAttr from, StringAttr to) const { auto IsHostDevice = [](StringAttr device) { return device == "host" || device == "cpu" || device == "x86" || device == "arm"; }; return IsHostDevice(from) ? 
to : from; } void InsertTransferOpPass::runOnOperation() { auto opNameAttrkey = OpTrait::IsOpConfCompatible::getOpNameAttr(); auto deviceTagAttrKey = OpTrait::IsOpConfCompatible::getDeviceTagAttr(); auto deviceNameAttrKey = OpTrait::IsOpConfCompatible::getDeviceNameAttr(); OpBuilder builder(&getContext()); getOperation().walk([&](Operation* op) { if (!op->hasTrait()) { return; } auto device = op->getAttrOfType(deviceTagAttrKey); for (Value result : op->getResults()) { llvm::DenseMap> operandsToReplace; for (auto& use : result.getUses()) { if (!use.getOwner()->hasTrait()) { continue; } auto use_device = use.getOwner()->getAttrOfType(deviceTagAttrKey); if (use_device != device) { operandsToReplace[use_device].push_back(&use); } } for (const auto& it : operandsToReplace) { NamedAttrList attrs; attrs.set(opNameAttrkey, builder.getStringAttr("copy")); attrs.set(deviceTagAttrKey, InferTargetDevice(device, it.first)); attrs.set(deviceNameAttrKey, builder.getArrayAttr(llvm::to_vector<8>(llvm::map_range( ArrayRef({"@0:0"}), [&](StringRef v) -> Attribute { return builder.getStringAttr(v); })))); attrs.set(builder.getStringAttr("device_type"), it.first); builder.setInsertionPointAfter(op); SmallVector operands{result}; auto copy_op = builder.create(op->getLoc(), op->getResultTypes(), operands, attrs); for (OpOperand* operand : it.second) { operand->getOwner()->setOperand(operand->getOperandNumber(), copy_op.getOut()); } } } }); } std::unique_ptr createLiteInsertTransferOpPass() { return std::unique_ptr(new InsertTransferOpPass()); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/Transform/Lowering/LoweringAscend.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
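Overview: this file lowers the outlined "<device>.launch" function into a Huawei GE (Ascend)
graph op by op, then serializes the graph through a temporary file in
AscendCompiler::serializeToBuffer. The SET_INPUT(op, name, value) macro below token-pastes the
input name; for example, with w being any AscendValue (e.g. getValue(op.getWeight())),

  SET_INPUT(conv2DOp, filter, w);

expands to roughly

  conv2DOp->set_input_filter_by_name(*(w.getOperation()), w.getComponentNameAndIndex().data());

so the AscendValue's component name/index ("y", "y0", ...) selects which output of the upstream
GE operator feeds this input. Note that getComponentNameAndIndex() builds a temporary
std::string when a component index is present and returns a StringRef into it, so that
reference points at storage that is gone by the time the caller uses it.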
*/ #include "OneFlow/Transform/Lowering/LoweringAscend.h" #include "OneFlow/Transform/Lowering/LoweringAscendUtils.h" #include #include #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" #include "llvm/Support/ToolOutputFile.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowLiteUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/ToolUtilities.h" // huawei ascend sdk headers #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-qualifiers" #include "op_proto/built-in/inc/all_ops.h" #pragma GCC diagnostic pop namespace mlir { namespace oneflow { namespace lite { class AscendValue { public: AscendValue() = default; AscendValue(const std::shared_ptr& op, const ge::TensorDesc& type, StringRef componentName) : op_(op), type_(type), componentName_(componentName), componentIndex_(-1) {} AscendValue(const std::shared_ptr& op, const ge::TensorDesc& type, StringRef componentName, int componentIndex) : op_(op), type_(type), componentName_(componentName), componentIndex_(componentIndex) {} AscendValue(const AscendValue&) = default; const std::shared_ptr& getOperation() const { return op_; } const ge::TensorDesc& getType() const { return type_; } StringRef getComponentName() const { return componentName_; } int getComponentIndex() const { return componentIndex_; } StringRef getComponentNameAndIndex() const { if (componentIndex_ < 0) { return componentName_; } auto name = componentName_ + llvm::Twine(componentIndex_); return StringRef(name.str()); } void setOperation(const std::shared_ptr& op) { op_ = op; } void setType(const ge::TensorDesc& type) { type_ = type; } void setComponentName(StringRef componentName) { componentName_ = componentName; } void setComponentIndex(int componentIndex) { componentIndex_ = componentIndex; } private: std::shared_ptr op_; ge::TensorDesc type_; StringRef componentName_; int componentIndex_; }; class AscendCompiler { public: AscendCompiler() = default; void addInputs(llvm::SmallVector& operands); void lowerOp(VariableOp op, StringRef checkpointDir); void lowerOp(Conv2DOp op); void lowerOp(NormalizationInferenceOp op); void lowerOp(ReluOp op); void lowerOp(MaxPool2DOp op); void lowerOp(AvgPool2DOp op); void lowerOp(Add2Op op); void lowerOp(AdaptiveAvgPool2DOp op); void lowerOp(MatmulOp op); void lowerOp(BroadcastAddOp op); void lowerOp(ReshapeOp op); void lowerOp(func::ReturnOp op); void serializeToBuffer(llvm::SmallVector* data); private: AscendValue getValue(Value value) const { auto it = ascendVals.find(value); assert(it != ascendVals.end()); return it->second; } template std::shared_ptr createOp(llvm::Twine opName) { auto op = std::make_shared(opName.str()); ascendOps.push_back(op); return op; } llvm::SmallVector inputs; llvm::SmallVector results; llvm::SmallVector, 4> ascendOps; llvm::DenseMap ascendVals; }; void AscendCompiler::serializeToBuffer(llvm::SmallVector* data) { std::vector ins; std::vector> outs; for (auto in : inputs) { ins.push_back(*(in.getOperation())); } for (auto out : results) { outs.push_back(std::make_pair(*(out.getOperation()), out.getComponentNameAndIndex().data())); } ge::Graph graph("ascend-graph"); graph.SetInputs(ins).SetOutputs(outs); if (!graph.IsValid()) { llvm::errs() << "ascend graph is invalid\n"; exit(1); } const char* outputFilename = ".__TMP__ascend_graph"; graph.SaveToFile(outputFilename); std::string 
errorMessage; auto f = mlir::openInputFile(outputFilename, &errorMessage); if (!f) { llvm::errs() << errorMessage << "\n"; exit(1); } data->resize(f->getBufferSize()); memcpy(data->data(), f->getBufferStart(), data->size()); // clean temp file if (0 != remove(outputFilename)) { llvm::errs() << "faile to clean temp file\n"; exit(1); } } void AscendCompiler::addInputs(llvm::SmallVector& operands) { for (auto operand : llvm::enumerate(operands)) { llvm::Twine opName = "input_" + llvm::Twine(operand.index()); auto inputOp = createOp(opName.str()); auto ascendType = convertAscendType(operand.value().getType()); inputOp->update_input_desc_x(ascendType); inputOp->update_output_desc_y(ascendType); inputs.push_back(AscendValue(inputOp, ascendType, "y")); ascendVals[operand.value()] = inputs.back(); } } void AscendCompiler::lowerOp(VariableOp op, StringRef checkpointDir) { auto ascendType = convertAscendType(op.data_typeAttr(), op.shapeAttr()); llvm::SmallString<128> inputFilename; llvm::sys::path::native(checkpointDir + "/" + op.getOpName() + "/out", inputFilename); std::string errorMessage; auto input = mlir::openInputFile(inputFilename, &errorMessage); if (!input) { llvm::errs() << errorMessage << "\n"; exit(1); } auto constantOp = createOp(op.getOpName()); auto tensor = std::make_shared(); tensor->SetTensorDesc(ascendType); tensor->SetData(reinterpret_cast(input->getBufferStart()), input->getBufferSize()); constantOp->set_attr_value(*tensor); ascendVals[op.getOutput()] = AscendValue(constantOp, ascendType, "y"); } #define SET_INPUT(op, name, value) \ op->set_input_##name##_by_name(*(value.getOperation()), value.getComponentNameAndIndex().data()) void AscendCompiler::lowerOp(Conv2DOp op) { auto conv2DOp = createOp(op.getOpName()); conv2DOp->set_attr_pads(convertPaddings(op.padding_before())); conv2DOp->set_attr_dilations(convertDilations(op.getDilationRate())); conv2DOp->set_attr_strides(convertStrides(op.getStrides())); conv2DOp->set_attr_groups(op.getGroups()); conv2DOp->set_attr_data_format(convertDataFormat(op.data_format()).data()); SET_INPUT(conv2DOp, x, getValue(op.getIn())); SET_INPUT(conv2DOp, filter, getValue(op.getWeight())); if (op.getBias()) { SET_INPUT(conv2DOp, bias, getValue(op.getBias())); } auto outType = convertAscendType(op.getOut().getType()); conv2DOp->update_output_desc_y(outType); auto output = AscendValue(conv2DOp, outType, "y"); if (op._add_to_output()) { auto addOp = createOp(op.getOpName() + "_add_to_output"); SET_INPUT(addOp, x1, output); SET_INPUT(addOp, x2, getValue(op._add_to_output())); addOp->update_output_desc_y(outType); output = AscendValue(addOp, outType, "y"); } ascendVals[op.getOut()] = output; } void AscendCompiler::lowerOp(NormalizationInferenceOp op) { auto batchNormOp = createOp(op.getOpName()); batchNormOp->set_attr_epsilon(op.getEpsilon().convertToFloat()); SET_INPUT(batchNormOp, x, getValue(op.getX())); SET_INPUT(batchNormOp, mean, getValue(op.getMovingMean())); SET_INPUT(batchNormOp, variance, getValue(op.getMovingVariance())); SET_INPUT(batchNormOp, scale, getValue(op.getGamma())); SET_INPUT(batchNormOp, offset, getValue(op.getBeta())); auto outType = convertAscendType(op.getY().getType()); batchNormOp->update_output_desc_y(outType); auto output = AscendValue(batchNormOp, outType, "y"); if (op._add_to_output()) { auto addOp = createOp(op.getOpName() + "_add_to_output"); SET_INPUT(addOp, x1, output); SET_INPUT(addOp, x2, getValue(op._add_to_output())); addOp->update_output_desc_y(outType); output = AscendValue(addOp, outType, "y"); } 
ascendVals[op.getY()] = output; } void AscendCompiler::lowerOp(ReluOp op) { auto reluOp = createOp(op.getOpName()); SET_INPUT(reluOp, x, getValue(op.getX())); auto outType = convertAscendType(op.getY().getType()); reluOp->update_output_desc_y(outType); ascendVals[op.getY()] = AscendValue(reluOp, outType, "y"); } void AscendCompiler::lowerOp(MaxPool2DOp op) { auto maxPoolOp = createOp(op.getOpName()); maxPoolOp->set_attr_ksize(convertKernelSize(op.getKernelSize())); maxPoolOp->set_attr_pads(convertPaddings(op.getPadding())); maxPoolOp->set_attr_strides(convertStrides(op.getStride())); maxPoolOp->set_attr_ceil_mode(op.ceil_mode()); maxPoolOp->set_attr_padding_mode("CALCULATED"); maxPoolOp->set_attr_global_pooling(false); SET_INPUT(maxPoolOp, x, getValue(op.getX())); auto outType = convertAscendType(op.getY().getType()); maxPoolOp->update_output_desc_y(outType); ascendVals[op.getY()] = AscendValue(maxPoolOp, outType, "y"); } void AscendCompiler::lowerOp(AvgPool2DOp op) { auto avgPoolOp = createOp(op.getOpName()); avgPoolOp->set_attr_ksize(convertKernelSize(op.getKernelSize())); avgPoolOp->set_attr_pads(convertPaddings(op.getPadding())); avgPoolOp->set_attr_strides(convertStrides(op.getStride())); avgPoolOp->set_attr_ceil_mode(op.ceil_mode()); avgPoolOp->set_attr_padding_mode("CALCULATED"); avgPoolOp->set_attr_global_pooling(false); avgPoolOp->set_attr_exclusive(!op.count_include_pad()); SET_INPUT(avgPoolOp, x, getValue(op.getX())); auto outType = convertAscendType(op.getY().getType()); avgPoolOp->update_output_desc_y(outType); ascendVals[op.getY()] = AscendValue(avgPoolOp, outType, "y"); } void AscendCompiler::lowerOp(Add2Op op) { auto addOp = createOp(op.getOpName()); SET_INPUT(addOp, x1, getValue(op.getIn0())); SET_INPUT(addOp, x2, getValue(op.getIn1())); auto outType = convertAscendType(op.getOut().getType()); addOp->update_output_desc_y(outType); ascendVals[op.getOut()] = AscendValue(addOp, outType, "y"); } void AscendCompiler::lowerOp(AdaptiveAvgPool2DOp op) { auto adaptiveAvgPoolOp = createOp(op.getOpName()); ArrayAttr output_size = op.output_size(); assert(output_size.size() == 2); int64_t s0 = output_size[0].dyn_cast().getSInt(); int64_t s1 = output_size[1].dyn_cast().getSInt(); adaptiveAvgPoolOp->set_attr_output_size(ge::Operator::OpListInt({s0, s1})); SET_INPUT(adaptiveAvgPoolOp, x, getValue(op.getX())); auto outType = convertAscendType(op.getY().getType()); adaptiveAvgPoolOp->update_output_desc_y(outType); ascendVals[op.getY()] = AscendValue(adaptiveAvgPoolOp, outType, "y"); } void AscendCompiler::lowerOp(MatmulOp op) { auto matmulOp = createOp(op.getOpName()); matmulOp->set_attr_transpose_x1(op.getTransposeA()); matmulOp->set_attr_transpose_x2(op.getTransposeB()); SET_INPUT(matmulOp, x1, getValue(op.getA())); SET_INPUT(matmulOp, x2, getValue(op.getB())); auto outType = convertAscendType(op.getOut().getType()); matmulOp->update_output_desc_y(outType); auto output = AscendValue(matmulOp, outType, "y"); if (op._add_to_output()) { auto addOp = createOp(op.getOpName() + "_add_to_output"); SET_INPUT(addOp, x1, output); SET_INPUT(addOp, x2, getValue(op._add_to_output())); addOp->update_output_desc_y(outType); output = AscendValue(addOp, outType, "y"); } ascendVals[op.getOut()] = output; } void AscendCompiler::lowerOp(BroadcastAddOp op) { auto addOp = createOp(op.getOpName()); SET_INPUT(addOp, x1, getValue(op.getX())); SET_INPUT(addOp, x2, getValue(op.getY())); auto outType = convertAscendType(op.getZ().getType()); addOp->update_output_desc_y(outType); ascendVals[op.getZ()] = 
AscendValue(addOp, outType, "y"); } void AscendCompiler::lowerOp(ReshapeOp op) { llvm::SmallVector shape; for (auto v : op.getShape()) { shape.push_back(v.dyn_cast().getSInt()); } auto constantOp = createOp(op.getOpName() + "_shape"); auto shapeType = ge::TensorDesc(ge::Shape(std::vector{static_cast(shape.size())}), ge::FORMAT_NCHW, ge::DT_INT64); auto tensor = std::make_shared(); tensor->SetTensorDesc(shapeType); tensor->SetData(reinterpret_cast(shape.data()), shape.size() * sizeof(int64_t)); constantOp->set_attr_value(*tensor); auto reshapeOp = createOp(op.getOpName()); SET_INPUT(reshapeOp, x, getValue(op.getIn())); SET_INPUT(reshapeOp, shape, (AscendValue(constantOp, shapeType, "y"))); auto outType = convertAscendType(op.getOut().getType()); reshapeOp->update_output_desc_y(outType); ascendVals[op.getOut()] = AscendValue(reshapeOp, outType, "y"); } void AscendCompiler::lowerOp(func::ReturnOp op) { for (auto operand : op.getOperands()) { results.push_back(getValue(operand)); } } #undef SET_INPUT LogicalResult loweringAscend(OpBuilder& builder, Operation* callee, StringRef checkpointDir, llvm::SmallVector* loweringData) { AscendCompiler compiler; llvm::SmallVector inputs; auto func = dyn_cast(callee); for (auto argument : func.getArguments()) { inputs.push_back(argument); } compiler.addInputs(inputs); func.getBody().walk([&](Operation* op) { if (auto x = dyn_cast(op)) { compiler.lowerOp(x, checkpointDir); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else if (auto x = dyn_cast(op)) { compiler.lowerOp(x); } else { llvm::errs() << "could not lowerring " << op->getName() << " for backend ascend\n"; exit(1); } }); compiler.serializeToBuffer(loweringData); return success(); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/Transform/LoweringLaunchJob.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
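Overview: for every launch op in the entry job, this pass looks up the outlined callee
function, lowers it for the backend named by the op's device_tag, and stores the resulting
byte blob in the op's "mlir_assembly" string attribute. Only the "ascend" backend is wired up
here (behind LITE_USE_ASCEND_NPU); a sketch of how another backend could be dispatched in
loweringLaunchJob (loweringTensorRT is hypothetical, not part of this tree):

  if (backend == "tensorrt") {
    return loweringTensorRT(builder, callee, checkpointDir, loweringData);
  }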
*/ #include "OneFlow/Transform/LoweringLaunchJob.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowLiteUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #ifdef LITE_USE_ASCEND_NPU #include "OneFlow/Transform/Lowering/LoweringAscend.h" #endif // LITE_USE_ASCEND_NPU namespace mlir { namespace oneflow { namespace lite { struct LoweringLaunchJobPass : public PassWrapper> { StringRef checkpointDir; explicit LoweringLaunchJobPass(StringRef checkpointDir) : checkpointDir(checkpointDir) {} void runOnOperation() override; LogicalResult loweringLaunchJob(OpBuilder& builder, Operation* callee, StringRef backend, llvm::SmallVector* loweringData); }; LogicalResult LoweringLaunchJobPass::loweringLaunchJob( OpBuilder& builder, Operation* callee, StringRef backend, llvm::SmallVector* loweringData) { if (backend == "ascend") { #ifdef LITE_USE_ASCEND_NPU return loweringAscend(builder, callee, checkpointDir, loweringData); #else llvm::errs() << "please recompile with LITE_USE_ASCEND_NPU=ON\n"; return failure(); #endif // LITE_USE_ASCEND_NPU } else { llvm::errs() << "lowering for backend " << backend << " is not supported yet\n"; return failure(); } return success(); } void LoweringLaunchJobPass::runOnOperation() { SmallVector launchOps; Operation* entryJobOp = getEntryJobOp(getOperation()); entryJobOp->walk([&](Operation* op) { if (dyn_cast(op)) { launchOps.push_back(op); } }); SymbolTable symbolTable(getOperation()); OpBuilder builder(&getContext()); // TODO(): register backend converters for (Operation* op : launchOps) { auto launchOp = dyn_cast(op); Operation* callee = symbolTable.lookup(launchOp.getCallee()); if (!callee) { llvm::errs() << "can not find a callee named " << launchOp.getCallee() << "\n"; return signalPassFailure(); } llvm::SmallVector loweringData; if (failed(loweringLaunchJob(builder, callee, launchOp.getDeviceTag(), &loweringData))) { llvm::errs() << "failed to lowerring job " << launchOp.getCallee() << "\n"; } op->setAttr("mlir_assembly", builder.getStringAttr(StringRef(reinterpret_cast(loweringData.data()), loweringData.size()))); } } std::unique_ptr createLiteLoweringLaunchJobPass(StringRef checkpointDir) { return std::unique_ptr(new LoweringLaunchJobPass(checkpointDir)); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/Transform/MemoryPlanning.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/Transform/MemoryPlanning.h" #include #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowLiteUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/SetVector.h" namespace mlir { namespace oneflow { namespace lite { int LiteBufferStrategy::getValueSegmentId(Value value) const { auto it = valueSegmentInfos.find(value); if (it == valueSegmentInfos.end()) { return -1; } return it->second.segmentId; } size_t LiteBufferStrategy::getValueSegmentOffset(Value value) const { auto it = valueSegmentInfos.find(value); if (it == valueSegmentInfos.end()) { return -1; } return it->second.segmentOffset; } LogicalResult LiteBufferStrategy::insertValue(Value value, int segmentId, size_t segmentOffset) { if (segments.size() < segmentId) { llvm::errs() << "segmentId is out of boundary.\n"; return failure(); } valueSegmentInfos[value] = ValueSegmentInfo{segmentId, segmentOffset}; return success(); } class ValueLiveness { public: ValueLiveness() = default; void addValue(Value value, size_t liveStart, size_t liveEnd) { liveness[value] = LiveRange{liveStart, liveEnd}; } bool isLivenessOverlap(Value lhs, Value rhs) { LiveRange lhs_liveness = liveness[lhs]; LiveRange rhs_liveness = liveness[rhs]; return lhs_liveness.liveEnd < rhs_liveness.liveStart || lhs_liveness.liveStart > rhs_liveness.liveEnd; } private: struct LiveRange { size_t liveStart; size_t liveEnd; }; llvm::DenseMap liveness; }; struct MemoryPlanningPass : public PassWrapper> { Operation* entryJobOp; ValueLiveness valueLiveness; llvm::SmallVector sortedValues; LiteBufferStrategy* bufferStrategy; explicit MemoryPlanningPass(LiteBufferStrategy* strategy) : bufferStrategy(strategy) {} void runOnOperation() override { entryJobOp = getEntryJobOp(getOperation()); if (!entryJobOp) { llvm::errs() << "Job not found in module: " << *getOperation(); exit(1); } computeValueLiveness(); computeValueSizeAndSort(); doMemoryPlanning(); } void computeValueLiveness(); void computeValueSizeAndSort(); void doMemoryPlanning(); bool canShareMemoryWithBlock(Value value, llvm::SmallVector block); }; void MemoryPlanningPass::computeValueLiveness() { llvm::SmallVector opList; llvm::DenseMap opOrdering; llvm::DenseMap liveEnds; // Compute value liveness entryJobOp->walk([&](Operation* op) { if (!op->hasTrait() || llvm::dyn_cast(op)) { return; } opOrdering[op] = opOrdering.size(); opList.push_back(op); }); for (Operation* op : llvm::reverse(opList)) { size_t ordering = opOrdering[op]; for (Value operand : op->getOperands()) { if (liveEnds.find(operand) == liveEnds.end()) { liveEnds[operand] = ordering; } } for (Value result : op->getResults()) { size_t liveEnd = opOrdering.size(); const auto& it = liveEnds.find(result); if (it != liveEnds.end()) { liveEnd = it->second; } valueLiveness.addValue(result, ordering, liveEnd); } } } static bool isDynamicTensorType(TensorType value) { for (auto dim : value.getShape()) { if (dim == -1) { return true; } } return false; } /// Returns the bitwidth of a scalar or vector type. 
static size_t getTensorBitSize(TensorType value) { auto type = value.getElementType(); assert(type.isIntOrFloat()); if (isDynamicTensorType(value)) { return 0; } int64_t num = 1; for (auto dim : value.getShape()) { num *= dim; } return num * type.getIntOrFloatBitWidth(); } void MemoryPlanningPass::computeValueSizeAndSort() { llvm::SetVector> valueList; entryJobOp->walk([&](Operation* op) { if (!op->hasTrait() || llvm::dyn_cast(op) || llvm::dyn_cast(op)) { return; } valueList.insert(op->getOperands().begin(), op->getOperands().end()); valueList.insert(op->getResults().begin(), op->getResults().end()); }); sortedValues = valueList.takeVector(); llvm::sort(sortedValues.begin(), sortedValues.end(), [](Value lhs, Value rhs) { assert(lhs.getType().isa()); assert(rhs.getType().isa()); return getTensorBitSize(lhs.getType().cast()) > getTensorBitSize(rhs.getType().cast()); }); } bool MemoryPlanningPass::canShareMemoryWithBlock(Value value, llvm::SmallVector block) { if (isDynamicTensorType(value.getType().cast())) { return false; } auto device = getValueDevice(value); for (auto v : block) { if (device != getValueDevice(v)) { return false; } if (valueLiveness.isLivenessOverlap(value, v)) { return false; } } return true; } void MemoryPlanningPass::doMemoryPlanning() { if (sortedValues.empty()) { return; } llvm::SmallVector, 4> memoryBlocks; for (auto value : sortedValues) { bool shared = false; for (auto& block : memoryBlocks) { if (canShareMemoryWithBlock(value, block)) { block.push_back(value); shared = true; } } if (!shared) { memoryBlocks.push_back(llvm::SmallVector{value}); } } llvm::SmallVector& segments = bufferStrategy->getSegments(); for (auto& block : memoryBlocks) { auto device = getValueDevice(block.front()); int segmentId = segments.size(); size_t blockSize = 0; size_t alignment = 512; for (auto value : block) { size_t valueSize = getTensorBitSize(value.getType().cast()); if (valueSize > blockSize) { blockSize = valueSize; } } blockSize = (blockSize + 7) / 8; // convert to bytes blockSize = (blockSize + alignment - 1) / alignment * alignment; // alignas 512 bytes segments.push_back(LiteBufferSegment{device.getValue(), blockSize, alignment}); for (auto value : block) { auto result = bufferStrategy->insertValue(value, segmentId, /*segmentOffset*/ 0); assert(succeeded(result) && "failed to insert value to buffer strategy"); } } } std::unique_ptr createLiteMemoryPlanningPass(LiteBufferStrategy* strategy) { return std::unique_ptr(new MemoryPlanningPass(strategy)); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/Transform/PartitionLaunchJob.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
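Overview: this pass collects all ops whose device_tag requires offloading (see needPartition:
"tensorrt" or "ascend"), clones them into a new func.func named "<device_tag>.launch", and
replaces them in the original job with a single launch op whose "mlir_assembly" attribute
starts out empty. LoweringLaunchJobPass later overwrites that attribute with the serialized
backend blob produced for the callee.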
*/ #include "OneFlow/Transform/PartitionLaunchJob.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" namespace mlir { namespace oneflow { namespace lite { struct PartitionLaunchJobPass : public PassWrapper> { void runOnOperation() override; bool needPartition(StringRef device) const { return device == "tensorrt" || device == "ascend"; } func::FuncOp addCallableFunc(OpBuilder& builder, StringRef callee_name, const llvm::SmallVector& operands, const llvm::SmallVector& results, const llvm::SmallVector& block); }; func::FuncOp PartitionLaunchJobPass::addCallableFunc( OpBuilder& builder, StringRef callee_name, const llvm::SmallVector& operands, const llvm::SmallVector& results, const llvm::SmallVector& block) { llvm::SmallVector operand_types, result_types; for (auto operand : operands) { operand_types.push_back(operand.getType()); } for (auto result : results) { result_types.push_back(result.getType()); } auto parentFuncOp = block[0]->getParentOfType(); auto parentModuleOp = parentFuncOp->getParentOfType(); Block::iterator insertPt(parentFuncOp->getNextNode()); builder.setInsertionPointToStart(parentModuleOp.getBody()); auto funcType = builder.getFunctionType(operand_types, result_types); auto funcOp = builder.create(block[0]->getLoc(), callee_name, funcType); auto* entryBlock = funcOp.addEntryBlock(); IRMapping mapping; for (auto operand : llvm::enumerate(operands)) { mapping.map(operand.value(), entryBlock->getArgument(operand.index())); } builder.setInsertionPointToStart(entryBlock); for (Operation* op : block) { builder.insert(op->clone(mapping)); for (auto result : llvm::enumerate(op->getResults())) { mapping.map(result.value(), entryBlock->back().getResult(result.index())); } } llvm::SmallVector mappingResults; for (auto result : results) { mappingResults.push_back(mapping.lookup(result)); } builder.create(block[0]->getLoc(), mappingResults); return funcOp; } void PartitionLaunchJobPass::runOnOperation() { // TODO(): refactor llvm::DenseMap>> partitionOps; getOperation().walk([&](Operation* op) { if (!op->hasTrait()) { return; } if (dyn_cast(op)) { return; } auto device = op->getAttrOfType(OpTrait::IsOpConfCompatible::getDeviceTagAttr()); if (!needPartition(device.getValue())) { return; } partitionOps[device.getValue()].insert(op); }); for (auto it : partitionOps) { if (it.second.empty()) { continue; } llvm::DenseMap inputVals, resultVals; for (Operation* op : it.second) { for (Value operand : op->getOperands()) { if (!it.second.count(operand.getDefiningOp())) { inputVals.try_emplace(operand, inputVals.size()); } } for (Value result : op->getResults()) { for (auto& use : result.getUses()) { if (!it.second.count(use.getOwner())) { resultVals.try_emplace(result, resultVals.size()); break; } } } } auto block = it.second.takeVector(); // TODO(): check job is acyclic or not llvm::SmallVector operands(inputVals.size()); llvm::SmallVector results(resultVals.size()); for (auto in : inputVals) { operands[in.second] = in.first; } for (auto out : resultVals) { results[out.second] = out.first; } OpBuilder builder(&getContext()); auto callableFunc = addCallableFunc(builder, it.first.str() + ".launch", operands, results, block); Operation* firstOp = block[0]; NamedAttrList attributes; attributes.set(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), OpTrait::IsOpConfCompatible::getDeviceTag(firstOp)); 
attributes.set(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), OpTrait::IsOpConfCompatible::getDeviceName(firstOp)); if (auto hierarchy = OpTrait::IsOpConfCompatible::getHierarchy(firstOp)) { attributes.set(OpTrait::IsOpConfCompatible::getHierarchyAttr(), hierarchy); } attributes.set(OpTrait::IsOpConfCompatible::getOpNameAttr(), builder.getStringAttr(it.first.str() + ".launch")); if (auto scope_symbol_id = OpTrait::IsOpConfCompatible::getScopeSymbolID((firstOp))) { attributes.set(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr(), scope_symbol_id); } builder.setInsertionPointAfter(firstOp); auto launchOp = builder.create(firstOp->getLoc(), callableFunc, attributes, operands); launchOp->setAttr("mlir_assembly", builder.getStringAttr("")); for (auto result : llvm::enumerate(results)) { result.value().replaceAllUsesWith(launchOp->getOperand(result.index())); } for (Operation* op : block) { op->dropAllUses(); op->erase(); } } } std::unique_ptr createLitePartitionLaunchJobPass() { return std::unique_ptr(new PartitionLaunchJobPass()); } } // namespace lite } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-lite/lib/OneFlow/cmake/FindAscendSdk.cmake ================================================ find_path(ASCEND_INCLUDE_DIR graph/graph.h PATHS ${ASCEND_HOME_PATH} ${ASCEND_HOME_PATH}/include $ENV{ASCEND_HOME_PATH} $ENV{ASCEND_HOME_PATH}/include) find_library( ASCEND_GRAPH_LIBRARY NAMES graph PATHS ${ASCEND_HOME_PATH} ${ASCEND_HOME_PATH}/lib64 $ENV{ASCEND_HOME_PATH} $ENV{ASCEND_HOME_PATH}/lib64) if(NOT ASCEND_INCLUDE_DIR OR NOT ASCEND_GRAPH_LIBRARY) message( FATAL_ERROR "Ascend Sdk was not found. You can set ASCEND_HOME_PATH to specify the search path." ) endif() add_library(ascend_graph SHARED IMPORTED GLOBAL) set_property(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${ASCEND_GRAPH_LIBRARY}) set(ASCEND_LIBRARIES ascend_graph) ================================================ FILE: oneflow/ir/oneflow-lite/schemas/CMakeLists.txt ================================================ include(install_flatcc.cmake) add_subdirectory(attributes) file(GLOB LITE_SCHEMA_FILES *.fbs) flatcc_generate(SCHEMA_SRCS ${LITE_SCHEMA_FILES}) add_custom_target(lite_schema_gen DEPENDS ${SCHEMA_SRCS} flatcc-runtime) add_library(lite_schemas INTERFACE) add_dependencies(lite_schemas lite_schema_gen lite_attribute_schema_gen) ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/CMakeLists.txt ================================================ file(GLOB LITE_ATTRIBUTE_SCHEMA_FILES *.fbs) flatcc_generate(ATTRIBUTE_SCHEMA_SRCS ${LITE_ATTRIBUTE_SCHEMA_FILES}) add_custom_target(lite_attribute_schema_gen DEPENDS ${ATTRIBUTE_SCHEMA_SRCS} flatcc-runtime) ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/bool.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
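// Each attribute schema in this directory defines one small root table. flatcc generates a
// C builder/reader API from it whose functions are prefixed with the namespace and the table
// name; those are exactly the calls used by OneFlowLiteUtils.cpp, e.g. for this table
// (a minimal sketch, assuming a FlatbufferBuilder named builder):
//
//   oneflow_lite_BoolDef_start_as_root(builder);
//   oneflow_lite_BoolDef_value_add(builder, true);
//   oneflow_lite_BoolDef_end_as_root(builder);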
namespace oneflow_lite; table BoolDef { value:bool; } root_type BoolDef; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/f32.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. namespace oneflow_lite; table F32Def { value:float; } root_type F32Def; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/f32s.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. namespace oneflow_lite; table F32sDef { value:[float]; } root_type F32sDef; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/f64.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. namespace oneflow_lite; table F64Def { value:double; } root_type F64Def; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/i32.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
namespace oneflow_lite; table I32Def { value:int; } root_type I32Def; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/i32s.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. namespace oneflow_lite; table I32sDef { value:[int]; } root_type I32sDef; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/i64.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. namespace oneflow_lite; table I64Def { value:long; } root_type I64Def; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/i64s.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. namespace oneflow_lite; table I64sDef { value:[long]; } root_type I64sDef; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/shape.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
namespace oneflow_lite; table ShapeDef { value:[long]; } root_type ShapeDef; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/shapes.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. include "shape.fbs"; namespace oneflow_lite; table ShapesDef { value:[ShapeDef]; } root_type ShapesDef; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/str.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. namespace oneflow_lite; table StringDef { value:string; } root_type StringDef; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/attributes/strs.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. namespace oneflow_lite; table StringsDef { value:[string]; } root_type StringsDef; ================================================ FILE: oneflow/ir/oneflow-lite/schemas/executable.fbs ================================================ // Copyright 2020 The OneFlow Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
namespace oneflow_lite;

// A buffer segment can be regarded as a device memory block
table BufferSegmentDef {
  size:long;
  // Device the segment belongs to
  device:int;
  alignment:int;
}

table TensorDef {
  // Type should be one of the primary data types
  // i8,i16,i32,i64,u8,u16,u32,u64,f8,f16,bf16,f32,f64,bool
  type:string;
  layout:string;
  sizes:[long];
  strides:[long];
  // Memory planning information about this tensor
  segment_id:int;
  segment_offset:long;
}

table ParameterDef {
  // Type should be one of the primary data types
  // i8,i16,i32,i64,u8,u16,u32,u64,f8,f16,bf16,f32,f64,bool
  type:string;
  sizes:[long];
  buffer:[byte];
}

table AttrDef {
  // Type should be one of the primary data types
  // i8,i16,i32,i64,u8,u16,u32,u64,f8,f16,bf16,f32,f64,bool,str,param,etc.
  // or a list of them i8s,i16s,i32s,i64s,u8s,u16s,u32s,u64s,
  // f8s,f16s,bf16s,f32s,f64s,bools,strs
  type:string;
  key:string;
  value:[byte];
}

table OpFunctionDef {
  name:string;
  // Code generated by AOT codegen.
  body:[byte];
  // Signature of the function call. Signature can be empty to use the
  // default function signature
  // "(t0, t1, ..., tN, r0, r1, ..., rM) -> (tN+1, tN+2, ..., tN+T)"
  // in which t means Tensor and r means op attributes
  signature:string;
}

table OpDef {
  // The operator type name, such as "conv2d", "softmax"
  name:string;
  // Input operand indices
  inputs:[int];
  // Output operand indices
  outputs:[int];
  // Attributes the operator has
  attrs:[AttrDef];
  // Device that executes the operator
  device:int;
}

table ExecutableDef {
  version:int;
  name:string;
  // Devices used in this executable
  devices:[string];
  ops:[OpDef];
  operands:[TensorDef];
  inputs:[int];
  outputs:[int];
  input_names:[string];
  output_names:[string];
  segments:[BufferSegmentDef];
  // Functions will be registered in the global function table and take
  // precedence even if implementations of those operators are available
  // in the runtime library
  functions:[OpFunctionDef];
}

root_type ExecutableDef;

================================================
FILE: oneflow/ir/oneflow-lite/schemas/install_flatcc.cmake
================================================
include(ExternalProject)
include(FetchContent)

set(FLATCC_URL https://github.com/dvidelabs/flatcc/archive/refs/tags/v0.6.1.tar.gz)
use_mirror(VARIABLE FLATCC_URL URL ${FLATCC_URL})
message(STATUS "Download flatcc from url: ${FLATCC_URL}")

#FetchContent_Declare(flatcc URL ${FLATCC_URL})
#FetchContent_MakeAvailable(flatcc)
FetchContent_Populate(flatcc URL ${FLATCC_URL} SOURCE_DIR flatcc)

set(FLATCC_ROOT ${CMAKE_CURRENT_BINARY_DIR}/flatcc)
set(FLATCC_SRCS
    "${FLATCC_ROOT}/src/runtime/builder.c"
    "${FLATCC_ROOT}/src/runtime/verifier.c"
    "${FLATCC_ROOT}/src/runtime/emitter.c"
    "${FLATCC_ROOT}/src/runtime/json_parser.c"
    "${FLATCC_ROOT}/src/runtime/json_printer.c"
    "${FLATCC_ROOT}/src/runtime/refmap.c"
    "${FLATCC_ROOT}/config/config.h")
set(FLATCC_INCLUDE_DIR ${FLATCC_ROOT}/include)
add_library(flatcc-runtime STATIC ${FLATCC_SRCS})
target_include_directories(flatcc-runtime SYSTEM PUBLIC ${FLATCC_INCLUDE_DIR})

add_executable(
  flatcc-cli
  "${FLATCC_ROOT}/src/cli/flatcc_cli.c"
  "${FLATCC_ROOT}/external/hash/cmetrohash64.c"
  "${FLATCC_ROOT}/external/hash/str_set.c"
  "${FLATCC_ROOT}/external/hash/ptr_set.c"
  "${FLATCC_ROOT}/src/compiler/hash_tables/symbol_table.c"
  "${FLATCC_ROOT}/src/compiler/hash_tables/scope_table.c"
  "${FLATCC_ROOT}/src/compiler/hash_tables/name_table.c"
  "${FLATCC_ROOT}/src/compiler/hash_tables/schema_table.c"
  "${FLATCC_ROOT}/src/compiler/hash_tables/value_set.c"
  "${FLATCC_ROOT}/src/compiler/fileio.c"
  "${FLATCC_ROOT}/src/compiler/parser.c"
"${FLATCC_ROOT}/src/compiler/semantics.c" "${FLATCC_ROOT}/src/compiler/coerce.c" "${FLATCC_ROOT}/src/compiler/codegen_schema.c" "${FLATCC_ROOT}/src/compiler/flatcc.c" "${FLATCC_ROOT}/src/compiler/codegen_c.c" "${FLATCC_ROOT}/src/compiler/codegen_c_reader.c" "${FLATCC_ROOT}/src/compiler/codegen_c_sort.c" "${FLATCC_ROOT}/src/compiler/codegen_c_builder.c" "${FLATCC_ROOT}/src/compiler/codegen_c_verifier.c" "${FLATCC_ROOT}/src/compiler/codegen_c_sorter.c" "${FLATCC_ROOT}/src/compiler/codegen_c_json_parser.c" "${FLATCC_ROOT}/src/compiler/codegen_c_json_printer.c" "${FLATCC_ROOT}/src/runtime/builder.c" "${FLATCC_ROOT}/src/runtime/emitter.c" "${FLATCC_ROOT}/src/runtime/refmap.c") target_include_directories(flatcc-cli PRIVATE "${FLATCC_ROOT}/external" "${FLATCC_ROOT}/include" "${FLATCC_ROOT}/config") #set(FLATCC_EXE ${CMAKE_CURRENT_BINARY_DIR}/flatcc-cli PARENT_SCOPE) set(FLATCC_EXE ${CMAKE_CURRENT_BINARY_DIR}/flatcc-cli) function(FLATCC_GENERATE SRCS) set(${SRCS}) foreach(FIL ${ARGN}) get_filename_component(ABS_FIL ${FIL} ABSOLUTE) get_filename_component(FIL_WE ${FIL} NAME_WE) list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_generated.h") add_custom_command( OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_generated.h" COMMAND ${FLATCC_EXE} ARGS --builder --verifier --outfile=${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_generated.h -a ${ABS_FIL} DEPENDS ${ABS_FIL} ${FLATCC_EXE} COMMENT "Running flatcc compiler on ${FIL}" VERBATIM) set(${SRCS} ${${SRCS}} PARENT_SCOPE) endforeach() endfunction() ================================================ FILE: oneflow/ir/oneflow-opt/CMakeLists.txt ================================================ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) add_llvm_executable(oneflow-opt oneflow-opt.cpp) set(_origin_prefix "\$ORIGIN") if(APPLE) set(_origin_prefix "@loader_path") endif() set_target_properties( oneflow-opt PROPERTIES BUILD_WITH_INSTALL_RPATH OFF BUILD_RPATH "${_origin_prefix}" INSTALL_RPATH "${_origin_prefix}") llvm_update_compile_flags(oneflow-opt) target_link_libraries( oneflow-opt PRIVATE MLIROneFlow ${dialect_libs} ${conversion_libs} MLIROptLib $ MLIROneFlowExtension MLIROneFlowTransformDialect) mlir_check_all_link_libraries(oneflow-opt) ================================================ FILE: oneflow/ir/oneflow-opt/README.md ================================================ # OneFlow MLIR Optimizer This module includes a CLI optimize a `.mlir` file. ================================================ FILE: oneflow/ir/oneflow-opt/oneflow-opt.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/control/ctrl_bootstrap.pb.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/Passes.h" #include "OneFlow/SBP/SBPDialect.h" #include "OneFlow/OKL/OKLDialect.h" #include "OneFlow/OKM/OKMDialect.h" #include "OneFlow/OKL/passes.h" #include "OneFlow/OKM/passes.h" #include "Transform/TransformDialectExtension.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" const auto global_cse_state = std::make_shared(); int32_t main(int32_t argc, char** argv) { ::oneflow::Singleton<::oneflow::ProcessCtx>::New(); mlir::registerAllPasses(); mlir::oneflow::registerCSEPasses(global_cse_state); mlir::oneflow::registerPasses(); mlir::okm::registerPasses(); mlir::okl::registerPasses(); mlir::oneflow::transform_dialect::registerTransformDialectEraseSchedulePass(); mlir::oneflow::transform_dialect::registerTransformDialectInterpreterPass(); mlir::DialectRegistry registry; // Note: register all mlir dialect and their extension. mlir::registerAllDialects(registry); mlir::oneflow::transform_dialect::registerTransformDialectExtension(registry); registry.insert(); registry.insert(); registry.insert(); registry.insert(); return failed(mlir::MlirOptMain(argc, argv, "OneFlow optimizer driver\n", registry)); } ================================================ FILE: oneflow/ir/oneflow-runner/CMakeLists.txt ================================================ set(LLVM_LINK_COMPONENTS Core Support nativecodegen native) oneflow_add_llvm_tool(oneflow-runner oneflow-runner.cpp) set(_origin_prefix "\$ORIGIN") if(APPLE) set(_origin_prefix "@loader_path") endif() set_target_properties( oneflow-runner PROPERTIES BUILD_WITH_INSTALL_RPATH OFF BUILD_RPATH "${_origin_prefix}" INSTALL_RPATH "${_origin_prefix}") target_link_libraries( oneflow-runner PRIVATE MLIRAnalysis MLIRExecutionEngine MLIRIR MLIRJitRunner MLIRLLVMIRTransforms MLIRLLVMToLLVMIRTranslation MLIRToLLVMIRTranslationRegistration MLIRParser MLIRTargetLLVMIRExport MLIRSupport MLIROneFlow glog::glog) ================================================ FILE: oneflow/ir/oneflow-runner/oneflow-runner.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ //===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Main entry point to a command line utility that executes an MLIR file on the // CPU by translating MLIR to LLVM IR before JIT-compiling and executing the // latter. 
// //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/ExecutionEngine/JitRunner.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/Dialect.h" #include "mlir/Target/LLVMIR/Dialect/All.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" #include "OneFlow/OneFlowDialect.h" int main(int argc, char** argv) { llvm::InitLLVM y(argc, argv); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); // llvm::InitializeNativeTargetAsmParser(); // link fails mlir::DialectRegistry registry; mlir::registerAllToLLVMIRTranslations(registry); registry.insert(); return mlir::JitRunnerMain(argc, argv, registry); } ================================================ FILE: oneflow/ir/oneflow-runtime/CMakeLists.txt ================================================ add_subdirectory(lib) ================================================ FILE: oneflow/ir/oneflow-runtime/lib/CMakeLists.txt ================================================ oneflow_add_mlir_library(MLIROneFlowRuntime Runtime.cpp) if(WITH_MLIR_CUDA_CODEGEN) set(MLIR_RUNTIME_GPU_LIBS mlir_cuda_runtime) endif(WITH_MLIR_CUDA_CODEGEN) target_link_libraries(MLIROneFlowRuntime PUBLIC -Wl,--no-as-needed ${MLIR_RUNTIME_GPU_LIBS} mlir_c_runner_utils -Wl,--as-needed) ================================================ FILE: oneflow/ir/oneflow-runtime/lib/Runtime.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
// This file is added to avoid a CMake error

================================================
FILE: oneflow/ir/oneflow-translate/CMakeLists.txt
================================================
set(LLVM_LINK_COMPONENTS Support)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
set(LLVM_ENABLE_RTTI ON) # turn this on to make it compatible with protobuf
include_directories(${PROJECT_SOURCE_DIR}/oneflow-translate/include)
include_directories(${PROJECT_BINARY_DIR}/oneflow-translate/include)
add_subdirectory(include)
add_subdirectory(lib)
add_llvm_executable(oneflow-translate oneflow-translate.cpp DEPENDS MLIROneFlow MLIROneFlowTranslation)
set(_origin_prefix "\$ORIGIN")
if(APPLE)
  set(_origin_prefix "@loader_path")
endif()
set_target_properties(
  oneflow-translate PROPERTIES BUILD_WITH_INSTALL_RPATH OFF BUILD_RPATH "${_origin_prefix}"
                               INSTALL_RPATH "${_origin_prefix}")
llvm_update_compile_flags(oneflow-translate)
target_link_libraries(oneflow-translate PRIVATE ${dialect_libs} ${translation_libs}
                      PUBLIC MLIRTranslateLib MLIROneFlowTranslation)
mlir_check_link_libraries(oneflow-translate)

================================================
FILE: oneflow/ir/oneflow-translate/README.md
================================================
# OneFlow Translate

## Import OneFlow Job to MLIR and dump a new Job

```
job -> module
sub graph -> function
```

### Pipeline

- Lower case: OneFlow, upper case: MLIR
- [something]: a step, which could be a rewrite or another kind of optimization

```
user op -> OPAQUE USER OP -> CONCRETE OP -> [OPTIMIZATION] -> user op
system op -> OPAQUE SYSTEM OP -> system op
```

### About blob name

- MLIR importers and exporters should take care of blob names so other components don't have to touch them.

### About SBP signature

- There should be a sharding op to store SBP information.
- Reusing built-in tensor types is practical and makes it easy to reuse pass interfaces.
- Implementing a tensor type with SBP actually works against MLIR, because passes in MLIR work better with operations.

### Basic principles for a legit rewrite

1. The source op of a control edge shouldn't be erased
2. Erasing or creating ops shouldn't introduce boxing
3. Results' shapes should stay identical

### Information not included in OpConf

- There is information in the job that is not included in `OpConf`:

```protobuf
message JobHelperConf {
  map tag2lbi_relations = 1;
  ...
}
message JobParallelViewConf {
  ...
}
```

- Create callbacks wrapping `JobBuilder` that MLIR can call to update the job helper confs when it is erasing/building operations.

================================================
FILE: oneflow/ir/oneflow-translate/include/CMakeLists.txt
================================================
add_subdirectory(OneFlow)

================================================
FILE: oneflow/ir/oneflow-translate/include/OneFlow/CMakeLists.txt
================================================

================================================
FILE: oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_IR_ONEFLOW_TRANSLATE_INCLUDE_ONEFLOW_MLIRONEFLOWTRANSLATION_H_ #define ONEFLOW_IR_ONEFLOW_TRANSLATE_INCLUDE_ONEFLOW_MLIRONEFLOWTRANSLATION_H_ #include "oneflow/core/framework/user_op_def.pb.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "OneFlow/SBP/SBPImporter.h" #include "OneFlow/OneFlowOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include #include using UserOpArgs = const ::google::protobuf::Map&; using UserOpArgDefs = const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>&; namespace mlir { namespace oneflow { // TODO: wrap in a helper namespace LogicalResult IsAttrBelong2Op(const std::string& op_type_name, const std::string& attr_name); LogicalResult ConvertUserOpInputs(Operation* op, StringRef op_name, ::oneflow::UserOpConf* user_conf); LogicalResult ConvertUserOpOutputs(Operation* op, StringRef op_name, ::oneflow::UserOpConf* user_conf); LogicalResult ConvertCtrlInputs(Operation* op, ::oneflow::OperatorConf& op_conf); llvm::Optional GetDataTypeAttr(MLIRContext* context, ::oneflow::DataType oneflow_value); LogicalResult ConvertVariableOpConf(VariableOp op, ::oneflow::OperatorConf* op_conf); LogicalResult ConvertInputOpConf(InputOp op, ::oneflow::OperatorConf* op_conf); LogicalResult ConvertOutputOpConf(OutputOp op, ::oneflow::OperatorConf* op_conf); LogicalResult ParseNdSbpFromAttr(ArrayAttr nd_sbp_attr, ::oneflow::NdSbp* nd_sbp); Attribute ConvertNdSbpToAttr(Builder& builder, const ::oneflow::NdSbp& nd_sbp); class Importer { public: Importer(MLIRContext* context, ModuleOp module) : builder_(context), context_(context), module_(module), unknown_loc_(FileLineColLoc::get(context, "unknown_loc", 0, 0)) {} virtual ~Importer() = default; LogicalResult namedAttributesFromUserOp(const ::oneflow::OperatorConf& op, std::vector& attr_vec); virtual LogicalResult AppendDataInOperand(const std::string& lbn, std::vector<::mlir::Value>& operand_vec) { return failure(); } virtual LogicalResult AppendDataInOperand(const std::string& key, const int32_t index, const std::string& lbn, std::vector<::mlir::Value>& operand_vec) { return AppendDataInOperand(lbn, operand_vec); } virtual LogicalResult AppendCtrlInOperand(const ::oneflow::OperatorConf& op, std::vector<::mlir::Value>& operand_vec) = 0; LogicalResult AppendCtrlOutType(llvm::SmallVector& out_types); LogicalResult AddOpConf(const ::oneflow::OperatorConf& op, std::vector& attr_vec); LogicalResult AddUserOpInputOutputSegments(const ::oneflow::OperatorConf& op, std::vector& attr_vec); virtual LogicalResult AddDeviceName(const ::oneflow::OperatorConf& op, std::vector& attr_vec) = 0; LogicalResult AddOperandSegmentSizes(int32_t input_lbns_size, int32_t ctrl_in_size, std::vector& attr_vec); LogicalResult AddResultSegmentSizes(int32_t output_lbns_size, std::vector& attr_vec); virtual LogicalResult InsertOpResults(const ::oneflow::OperatorConf& op, Operation*) = 0; LogicalResult ProcessUserOp(const ::oneflow::OperatorConf& op); virtual LogicalResult ProcessSystemOp(const 
::oneflow::OperatorConf& op) = 0; IntegerAttr getSI64IntegerAttr(int64_t value) { return IntegerAttr::get(GetBuilder().getIntegerType(64, /*isSigned=*/true), APInt(64, value, /*isSigned=*/true)); } ArrayAttr getSI32ArrayAttr(ArrayRef values) { auto attrs = llvm::to_vector<8>(llvm::map_range( values, [this](int32_t v) -> Attribute { return GetBuilder().getSI32IntegerAttr(v); })); return GetBuilder().getArrayAttr(attrs); } ArrayAttr getSI64ArrayAttr(ArrayRef values) { auto attrs = llvm::to_vector<8>( llvm::map_range(values, [this](int64_t v) -> Attribute { return getSI64IntegerAttr(v); })); return GetBuilder().getArrayAttr(attrs); } ArrayAttr GetAttrFromShape(const ::oneflow::ShapeProto& shape); ArrayAttr GetAttrFromStride(const ::oneflow::Int64ListProto& stride); OpBuilder& GetBuilder() { return builder_; } MLIRContext* GetMLIRContext() { return context_; } ModuleOp& GetModule() { return module_; } Location& GetRootLocation() { return unknown_loc_; } virtual Type GetTensorTypeOfLbn(const std::string& lbn) = 0; void SetOpStateLoc(const ::oneflow::OperatorConf&, OperationState&); private: OpBuilder builder_; MLIRContext* context_; ModuleOp module_; Location unknown_loc_; }; class RoundTripOneFlowJobWrapperInterface { public: virtual ~RoundTripOneFlowJobWrapperInterface() {} virtual const ::oneflow::Job* job() const = 0; virtual void UpdateJob(::oneflow::Job* new_job) = 0; virtual void DumpLog(const std::string& filename, const std::string& content) = 0; virtual const ::oneflow::ParallelConf& ParallelConf4OpName(const std::string& op_name) const = 0; virtual const ::oneflow::OperatorConf& OpConf4OpName(const std::string& op_name) const = 0; virtual std::pair, std::vector> InputBns4OpName( const std::string& op_name) const = 0; virtual std::vector OutputLbns4OpName(const std::string& op_name) const = 0; virtual std::string ReplaceInputLbnInOpCustomizedConf(::oneflow::OperatorConf* op_conf, const std::string& ibn, const std::string& new_val) const = 0; virtual void QueryLogicalBlob( const std::string& lbn, std::function cb) const = 0; virtual void TopoForEachOpConf( std::function Handler) const = 0; virtual bool IsLastIRPass() const = 0; }; void RoundTripOneFlowJob( RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::function& is_legit_job); void registerFromOneFlowJobTranslation(); std::string ConvertJobToTosaIR(RoundTripOneFlowJobWrapperInterface& job_wrapper); void SaveJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path); std::string ConvertJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper); void LoadJobFromIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path); } // namespace oneflow } // namespace mlir #endif // ONEFLOW_IR_ONEFLOW_TRANSLATE_INCLUDE_ONEFLOW_MLIRONEFLOWTRANSLATION_H_ ================================================ FILE: oneflow/ir/oneflow-translate/lib/CMakeLists.txt ================================================ add_subdirectory(OneFlow) ================================================ FILE: oneflow/ir/oneflow-translate/lib/OneFlow/CMakeLists.txt ================================================ oneflow_add_mlir_library( MLIROneFlowTranslation MLIROneFlowTranslation.cpp Importer.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/oneflow-translate/include/OneFlow DEPENDS oneflow_deps LINK_LIBS PUBLIC MLIRIR ${dialect_libs} ${translation_libs} MLIRIR MLIRParser MLIRPass MLIRSPIRVDialect MLIRTranslateLib MLIRSupport MLIROneFlow MLIRTosaToTensor oneflow) if(BUILD_SHARED_LIBS) 
get_filename_component(ONEFLOW_BUILD_ROOT_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../../.. ABSOLUTE) get_property(TRANSLATE_INSTALL_RPATH TARGET MLIROneFlowTranslation PROPERTY INSTALL_RPATH) list(APPEND TRANSLATE_INSTALL_RPATH ${PROTOBUF_LIBRARY_DIR}) list(APPEND TRANSLATE_INSTALL_RPATH ${ONEFLOW_BUILD_ROOT_DIR}) set_target_properties(MLIROneFlowTranslation PROPERTIES INSTALL_RPATH "${TRANSLATE_INSTALL_RPATH}") endif() mlir_check_link_libraries(MLIROneFlowTranslation) ================================================ FILE: oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/UserOpConversion.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/user_op_conf.pb.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/SBP/SBPDialect.h" #include "OneFlow/SBP/SBPAttributes.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/UserOpReflection.h" #include "OneFlow/OneFlowTypes.h" #include "OneFlow/OneFlowSupport.h" #include "OneFlow/Passes.h" #include "OneFlow/MLIROneFlowTranslation.h" #include "OneFlow/OneFlowSupport.h" #include "OneFlow/OneFlowDataTypeConversion.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/UseDefLists.h" #include "mlir/IR/Value.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm-c/Core.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include #include "oneflow/core/framework/sbp_context.h" #include "oneflow/core/job/sbp_signature_builder.h" namespace mlir { namespace oneflow { using PbMessage = google::protobuf::Message; namespace { using SizeVec = SmallVector; SizeVec GetSizesFromArgs(UserOpArgs args, UserOpArgDefs arg_defs) { SizeVec sizes{}; llvm::StringSet<> names({}); for (const auto& arg : args) { names.insert(arg.first); } for (const auto& arg_def : arg_defs) { int32_t size = 0; if (names.contains(arg_def.name())) { size = args.at(arg_def.name()).s_size(); } sizes.push_back(size); } return sizes; } std::vector GetOutputLbns(const ::oneflow::OperatorConf& op, UserOpArgDefs arg_defs) { SizeVec 
sizes{}; llvm::StringSet<> names_appeared({}); std::vector output_lbn_vec{}; const auto& op_name = op.name(); for (const auto& arg : op.user_conf().output()) { names_appeared.insert(arg.first); } for (const auto& arg_def : arg_defs) { const auto& key = arg_def.name(); const auto& it = op.user_conf().output().find(key); if (it == op.user_conf().output().end()) { continue; } auto result_size = it->second.s_size(); if (result_size == 0) { continue; } for (int32_t i = 0; i < result_size; i++) { const auto output_lbn = op_name + "/" + key + "_" + std::to_string(i); output_lbn_vec.push_back(output_lbn); } } return output_lbn_vec; } } // namespace LogicalResult IsAttrBelong2Op(const std::string& op_type_name, const std::string& attr_name) { ::oneflow::user_op::UserOpDefWrapper op_def(support::getUserOpDef(op_type_name)); return success(op_def.IsAttrName(attr_name)); } LogicalResult Importer::AddUserOpInputOutputSegments(const ::oneflow::OperatorConf& op, std::vector& attr_vec) { if (op.has_user_conf() == false) return failure(); const auto& user_conf = op.user_conf(); const ::oneflow::UserOpDef& op_def = support::getUserOpDef(op.user_conf().op_type_name()); const auto UserOpOperationName = OperationName(UserOp::getOperationName(), GetMLIRContext()); attr_vec.push_back(GetBuilder().getNamedAttr( oneflow::UserOp::getInputSizesAttrName(UserOpOperationName), GetBuilder().getI32ArrayAttr(GetSizesFromArgs(user_conf.input(), op_def.input())))); attr_vec.push_back(GetBuilder().getNamedAttr( oneflow::UserOp::getOutputSizesAttrName(UserOpOperationName), GetBuilder().getI32ArrayAttr(GetSizesFromArgs(user_conf.output(), op_def.output())))); auto output_lbns = GetOutputLbns(op, op_def.output()); attr_vec.push_back(GetBuilder().getNamedAttr( OpTrait::IsImportCompatible::getOutputLBNsAttr(), GetBuilder().getStrArrayAttr( SmallVector({output_lbns.begin(), output_lbns.end()})))); return success(); } llvm::Optional GetDataTypeAttr(MLIRContext* context, ::oneflow::DataType oneflow_value) { switch (oneflow_value) { case ::oneflow::DataType::kInvalidDataType: return oneflow::DataTypeAttr::get(context, mlir::oneflow::DataType::DT_InvalidDataType); break; #define DEFINE_ONE_ELIF(datatype) \ case ::oneflow::DataType::k##datatype: \ return oneflow::DataTypeAttr::get(context, mlir::oneflow::DataType::DT_##datatype); \ break; DEFINE_ONE_ELIF(Char) DEFINE_ONE_ELIF(Float) DEFINE_ONE_ELIF(Double) DEFINE_ONE_ELIF(Int8) DEFINE_ONE_ELIF(Int32) DEFINE_ONE_ELIF(Int64) DEFINE_ONE_ELIF(UInt8) DEFINE_ONE_ELIF(OFRecord) DEFINE_ONE_ELIF(Float16) DEFINE_ONE_ELIF(TensorBuffer) DEFINE_ONE_ELIF(BFloat16) DEFINE_ONE_ELIF(Bool) #undef DEFINE_ONE_ELIF default: llvm::errs() << "unsupported data type: " << oneflow_value << "\n"; return llvm::None; } } ArrayAttr Importer::GetAttrFromShape(const ::oneflow::ShapeProto& shape) { return GetBuilder().getArrayAttr(llvm::to_vector<8>(llvm::map_range( shape.dim(), [this](int64_t v) -> Attribute { return getSI64IntegerAttr(v); }))); } ArrayAttr Importer::GetAttrFromStride(const ::oneflow::Int64ListProto& stride) { return GetBuilder().getArrayAttr(llvm::to_vector<8>(llvm::map_range( stride.dim(), [this](int64_t v) -> Attribute { return getSI64IntegerAttr(v); }))); } LogicalResult Importer::namedAttributesFromUserOp(const ::oneflow::OperatorConf& op, std::vector& attr_vec) { if (op.has_user_conf() == false) { GetModule().emitError("Not a user op. 
op name: " + op.name()); return failure(); } for (const google::protobuf::MapPair, ::oneflow::AttrValue>& attr : op.user_conf().attr()) { const std::string& name = attr.first; const ::oneflow::AttrValue& value = attr.second; if (value.has_at_int32()) { mlir::NamedAttribute kv = GetBuilder().getNamedAttr(name, GetBuilder().getSI32IntegerAttr(value.at_int32())); attr_vec.emplace_back(kv); } else if (value.has_at_int64()) { mlir::NamedAttribute kv = GetBuilder().getNamedAttr(name, getSI64IntegerAttr(value.at_int64())); attr_vec.emplace_back(kv); } #define DEFINE_ONE_ELIF(at_key, get_attr) \ else if (value.has_##at_key()) { \ mlir::NamedAttribute kv = \ GetBuilder().getNamedAttr(name, GetBuilder().get_attr(value.at_key())); \ attr_vec.emplace_back(kv); \ } DEFINE_ONE_ELIF(at_bool, getBoolAttr) DEFINE_ONE_ELIF(at_float, getF32FloatAttr) DEFINE_ONE_ELIF(at_double, getF64FloatAttr) DEFINE_ONE_ELIF(at_string, getStringAttr) #undef DEFINE_ONE_ELIF else if (value.has_at_shape()) { attr_vec.emplace_back(GetBuilder().getNamedAttr(name, GetAttrFromShape(value.at_shape()))); } else if (value.has_at_stride()) { attr_vec.emplace_back(GetBuilder().getNamedAttr(name, GetAttrFromStride(value.at_stride()))); } #define DEFINE_ONE_ELIF(at_key, get_attr, field) \ else if (value.has_##at_key()) { \ mlir::NamedAttribute kv = GetBuilder().getNamedAttr( \ name, get_attr({value.at_key().field().begin(), value.at_key().field().end()})); \ attr_vec.emplace_back(kv); \ } DEFINE_ONE_ELIF(at_list_int32, getSI32ArrayAttr, val) DEFINE_ONE_ELIF(at_list_int64, getSI64ArrayAttr, val) DEFINE_ONE_ELIF(at_list_float, GetBuilder().getF32ArrayAttr, val) #undef DEFINE_ONE_ELIF else if (value.has_at_list_string()) { std::vector r_vec = {value.at_list_string().val().begin(), value.at_list_string().val().end()}; mlir::NamedAttribute kv = GetBuilder().getNamedAttr(name, GetBuilder().getStrArrayAttr(r_vec)); attr_vec.emplace_back(kv); } else if (value.has_at_data_type()) { if (auto dt_attr = GetDataTypeAttr(GetMLIRContext(), value.at_data_type())) { mlir::NamedAttribute kv = GetBuilder().getNamedAttr(name, dt_attr.value()); attr_vec.emplace_back(kv); } else { GetModule().emitError("fail to convert op attr, key: " + name); return failure(); } } else if (value.has_at_list_data_type()) { auto dt_attr_list = llvm::map_range(value.at_list_data_type().val(), [&](auto t) -> mlir::Attribute { auto dt = GetDataTypeAttr(GetMLIRContext(), static_cast<::oneflow::DataType>(t)); CHECK(dt) << "fail to convert op attr, key: " + name; return dt.value(); }); attr_vec.emplace_back(GetBuilder().getNamedAttr( name, GetBuilder().getArrayAttr(llvm::to_vector<8>(dt_attr_list)))); } else if (value.has_at_list_shape()) { auto dense_attr_list = llvm::map_range(value.at_list_shape().val(), [&](const ::oneflow::ShapeProto& s) { return GetAttrFromShape(s); }); std::vector dense_attr_vector{dense_attr_list.begin(), dense_attr_list.end()}; attr_vec.emplace_back( GetBuilder().getNamedAttr(name, GetBuilder().getArrayAttr(dense_attr_vector))); } else if (value.has_at_list_stride()) { auto dense_attr_list = llvm::map_range(value.at_list_stride().val(), [&](const ::oneflow::Int64ListProto& s) { return GetAttrFromStride(s); }); std::vector dense_attr_vector{dense_attr_list.begin(), dense_attr_list.end()}; attr_vec.emplace_back( GetBuilder().getNamedAttr(name, GetBuilder().getArrayAttr(dense_attr_vector))); } else if (value.has_at_complex_double()) { std::vector dense_attr_vector{ GetBuilder().getF64FloatAttr(value.at_complex_double().real()), 
GetBuilder().getF64FloatAttr(value.at_complex_double().imag())}; attr_vec.emplace_back( GetBuilder().getNamedAttr(name, GetBuilder().getArrayAttr(dense_attr_vector))); } else { GetModule().emitError("can't handle user op attr: " + name + ", op name: " + op.name() + ", op type name: " + op.user_conf().op_type_name()); return failure(); } } if (failed(AddUserOpInputOutputSegments(op, attr_vec))) { GetModule().emitError("fail to add input output segments: " + op.name()); return failure(); } return success(); } LogicalResult Importer::AddOperandSegmentSizes(int32_t input_lbns_size, int32_t ctrl_in_size, std::vector& attr_vec) { attr_vec.push_back(GetBuilder().getNamedAttr( mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), GetBuilder().getDenseI32ArrayAttr({input_lbns_size, ctrl_in_size}))); return success(); } LogicalResult Importer::AddResultSegmentSizes(int32_t output_lbns_size, std::vector& attr_vec) { attr_vec.push_back(GetBuilder().getNamedAttr( mlir::OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(), GetBuilder().getDenseI32ArrayAttr( {output_lbns_size, 1} /* {data_out_size, ctrl_out_size} */))); return success(); } LogicalResult Importer::AppendCtrlOutType(llvm::SmallVector& out_types) { out_types.append({RankedTensorType::get({}, GetBuilder().getI1Type())}); return success(); } LogicalResult Importer::AddOpConf(const ::oneflow::OperatorConf& op, std::vector& attr_vec) { attr_vec.push_back(GetBuilder().getNamedAttr(OpTrait::IsOpConfCompatible::getOpNameAttr(), GetBuilder().getStringAttr(op.name()))); if (op.has_device_tag()) { attr_vec.push_back( GetBuilder().getNamedAttr(OpTrait::IsOpConfCompatible::getDeviceTagAttr(), GetBuilder().getStringAttr(op.device_tag()))); } attr_vec.push_back( GetBuilder().getNamedAttr(OpTrait::IsOpConfCompatible::getScopeSymbolIDAttr(), GetBuilder().getI64IntegerAttr(op.scope_symbol_id()))); return success(); } LogicalResult ParseNdSbpFromAttr(::llvm::ArrayRef nd_sbp_attr, ::oneflow::NdSbp* nd_sbp) { for (const auto& sbp_attr : nd_sbp_attr) { auto sbp_str_attr = sbp_attr.dyn_cast(); if (!sbp_str_attr) { llvm::errs() << "nd_sbp attr is not a StrArrayAttr"; return failure(); } auto sbp_strref = sbp_str_attr.getValue(); if (sbp_strref.startswith("S")) { if (!(sbp_strref.substr(1, 1) == "(" && sbp_strref.endswith(")"))) { llvm::errs() << "invalid sbp S(x) string value: " << sbp_strref; return failure(); } auto split_axis = std::stoi(sbp_strref.substr(2, 1).str()); nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(split_axis); } else if (sbp_strref == "B") { nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); } else if (sbp_strref == "P") { nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel(); } else { llvm::errs() << "unsupported nd_sbp string value: " << sbp_strref; return failure(); } } return success(); } Attribute ConvertNdSbpToAttr(Builder& builder, const ::oneflow::NdSbp& nd_sbp) { llvm::SmallVector sbp_strs; for (const auto& sbp : nd_sbp.sbp_parallel()) { if (sbp.has_split_parallel()) { sbp_strs.emplace_back("S(" + std::to_string(sbp.split_parallel().axis()) + ")"); } else if (sbp.has_broadcast_parallel()) { sbp_strs.emplace_back("B"); } else if (sbp.has_partial_sum_parallel()) { sbp_strs.emplace_back("P"); } else { llvm::errs() << "unsupported sbp: " << nd_sbp.DebugString(); exit(EXIT_FAILURE); } } return builder.getStrArrayAttr( makeArrayRef(llvm::SmallVector(sbp_strs.begin(), sbp_strs.end()))); } LogicalResult ValidateUserOpConf(const ::oneflow::OperatorConf& op_conf, UserOpArgs args, 
UserOpArgDefs arg_defs) { for (const auto& input_arg : args) { const bool found = std::find_if(arg_defs.begin(), arg_defs.end(), [&](const ::oneflow::UserOpDef_ArgDef& arg_def) { return input_arg.first == arg_def.name(); }) != arg_defs.end(); if (!found) { llvm::errs() << "fail to validate user op conf, arg def of arg not found: " << input_arg.first << ", op: \n" << op_conf.DebugString() << "\n"; return failure(); } } return success(); } LogicalResult Importer::ProcessUserOp(const ::oneflow::OperatorConf& op) { if (op.has_user_conf() == false) { GetModule().emitError("Not a user op. op name: " + op.name()); return failure(); } std::vector attr_vec; if (failed(AddOpConf(op, attr_vec))) { return failure(); } if (failed(AddDeviceName(op, attr_vec))) { return failure(); } attr_vec.push_back( GetBuilder().getNamedAttr(OpTrait::IsAlternative::getOpTypeNameAttr(), GetBuilder().getStringAttr(op.user_conf().op_type_name()))); std::vector<::mlir::Value> operand_vec; if (failed(namedAttributesFromUserOp(op, attr_vec))) { return failure(); } const auto& op_def = support::getUserOpDef(op.user_conf().op_type_name()); if (failed(ValidateUserOpConf(op, op.user_conf().input(), op_def.input()))) { return failure(); } if (failed(ValidateUserOpConf(op, op.user_conf().output(), op_def.output()))) { return failure(); } for (const auto& arg_def : op_def.input()) { const auto& key = arg_def.name(); auto it = op.user_conf().input().find(key); if (it == op.user_conf().input().end()) { continue; } int32_t index = 0; for (const std::string& lbn : it->second.s()) { if (failed(AppendDataInOperand(key, index, lbn, operand_vec))) { return failure(); } index += 1; } } if (failed(AppendCtrlInOperand(op, operand_vec))) { return failure(); } ::mlir::ValueRange operands(operand_vec); Operation* created_op = nullptr; auto out_types = llvm::SmallVector(); for (const auto& arg_def : op_def.output()) { const auto& key = arg_def.name(); auto it = op.user_conf().output().find(key); if (it == op.user_conf().output().end()) { continue; } for (const auto& output_lbn : it->second.s()) { out_types.push_back(GetTensorTypeOfLbn(output_lbn)); } } if (failed(AppendCtrlOutType(out_types))) { return failure(); } OperationState state(FileLineColLoc::get(GetMLIRContext(), op.name(), 0, 0), UserOp::getOperationName()); uint32_t data_input_size = 0; uint32_t data_output_size = 0; for (const auto& input : op.user_conf().input()) { data_input_size += input.second.s().size(); } for (const auto& output : op.user_conf().output()) { data_output_size += output.second.s().size(); } if (failed(AddOperandSegmentSizes(data_input_size, op.ctrl_in_op_name_size(), attr_vec))) { return failure(); } if (failed(AddResultSegmentSizes(data_output_size, attr_vec))) { return failure(); } ArrayRef named_attributes(attr_vec); state.addAttributes(named_attributes); state.addOperands(operands); state.addTypes(out_types); SetOpStateLoc(op, state); created_op = GetBuilder().create(state); if (created_op == nullptr) { GetModule()->emitError("fail to create " + op.user_conf().op_type_name() + " op, name: " + op.name()); return failure(); } if (failed(InsertOpResults(op, created_op))) { return failure(); } return success(); } // namespace LogicalResult ConvertCtrlInputs(Operation* op, ::oneflow::OperatorConf& op_conf) { if (op->isRegistered() && !llvm::dyn_cast(op)) return success(); if (auto ctrl_ins = GetCtrlIntputOperands(op)) { for (auto ctrl_in : ctrl_ins.value()) { op_conf.add_ctrl_in_op_name( OpTrait::IsOpConfCompatible::getOpName(ctrl_in.getDefiningOp()).str()); } } 
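`ParseNdSbpFromAttr` and `ConvertNdSbpToAttr` above round-trip SBP descriptions through short strings: `S(axis)` for split, `B` for broadcast, `P` for partial sum. A standalone sketch of the same convention without the MLIR attribute and protobuf types (`Sbp`, `parseSbp`, and `printSbp` are made-up names for illustration; note that the importer above only reads a single-digit split axis):

```cpp
// Round-trips the SBP string convention used by ParseNdSbpFromAttr /
// ConvertNdSbpToAttr: "S(axis)", "B", "P". Sbp/parseSbp/printSbp are
// illustrative stand-ins for the NdSbp proto and MLIR attributes.
#include <iostream>
#include <optional>
#include <string>

struct Sbp {
  enum Kind { kSplit, kBroadcast, kPartialSum } kind;
  int split_axis = -1;  // only meaningful for kSplit
};

std::optional<Sbp> parseSbp(const std::string& s) {
  if (s == "B") { return Sbp{Sbp::kBroadcast}; }
  if (s == "P") { return Sbp{Sbp::kPartialSum}; }
  // Accepts "S(" digits ")"; the importer above is stricter and reads one digit.
  if (s.size() >= 4 && s.rfind("S(", 0) == 0 && s.back() == ')') {
    return Sbp{Sbp::kSplit, std::stoi(s.substr(2, s.size() - 3))};
  }
  return std::nullopt;  // mirrors the failure() path for malformed strings
}

std::string printSbp(const Sbp& sbp) {
  switch (sbp.kind) {
    case Sbp::kSplit: return "S(" + std::to_string(sbp.split_axis) + ")";
    case Sbp::kBroadcast: return "B";
    case Sbp::kPartialSum: return "P";
  }
  return "";
}

int main() {
  for (const char* s : {"S(0)", "B", "P", "S(2)"}) {
    auto sbp = parseSbp(s);
    std::cout << s << " -> " << (sbp ? printSbp(*sbp) : "invalid") << "\n";
  }
  return 0;
}
```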
return success(); } LogicalResult ConvertUserOpInputs(Operation* op, StringRef op_name, ::oneflow::UserOpConf* user_conf) { std::vector keys{}; std::vector sizes{}; if (failed(user_op::GetFilteredSegmentKeyAndSizes(op, keys, sizes))) { op->emitError("fail to convert user op inputs"); return failure(); } int32_t input_idx = 0; for (auto tuple : llvm::zip(keys, sizes)) { auto input_key = std::get<0>(tuple); auto input_size = std::get<1>(tuple); if (input_size <= 0) return op->emitError("input_size <= 0, op: " + op->getName().getStringRef()); for (int32_t i = 0; i < input_size; i++) { if (auto result = GetDataInputOperands(op)[input_idx].dyn_cast()) { auto input_s_ptr = (*user_conf->mutable_input())[input_key].mutable_s()->Add(); *(input_s_ptr) = user_op::GetOutputLbn(result).value(); input_idx += 1; } else { op->emitError() << "fail to convert MLIR result to protobuf, name: " + op_name; op->dump(); return failure(); } } } return success(); } LogicalResult ConvertUserOpOutputs(Operation* op, StringRef op_name, ::oneflow::UserOpConf* user_conf) { std::vector keys{}; std::vector sizes{}; if (failed(user_op::GetFilteredSegmentKeyAndSizes(op, keys, sizes))) { op->emitError("fail to convert user op outputs"); return failure(); } for (auto tuple : llvm::zip(keys, sizes)) { auto name = std::get<0>(tuple); auto result_size = std::get<1>(tuple); if (result_size == 0) continue; for (int32_t i = 0; i < result_size; i++) { auto out_s_ptr = (*user_conf->mutable_output())[name].mutable_s()->Add(); *(out_s_ptr) = op_name.str() + "/" + name + "_" + std::to_string(i); } } return success(); } LogicalResult ConvertDT(::mlir::oneflow::DataType data_type_mlir, ::oneflow::DataType& data_type) { switch (data_type_mlir) { case oneflow::DataType::DT_InvalidDataType: data_type = ::oneflow::DataType::kInvalidDataType; break; #define DEFINE_ONE_CASE(datatype) \ case oneflow::DataType::DT_##datatype: data_type = ::oneflow::DataType::k##datatype; break; DEFINE_ONE_CASE(Char) DEFINE_ONE_CASE(Float) DEFINE_ONE_CASE(Double) DEFINE_ONE_CASE(Int8) DEFINE_ONE_CASE(Int32) DEFINE_ONE_CASE(Int64) DEFINE_ONE_CASE(UInt8) DEFINE_ONE_CASE(OFRecord) DEFINE_ONE_CASE(Float16) DEFINE_ONE_CASE(TensorBuffer) DEFINE_ONE_CASE(Bool) #undef DEFINE_ONE_CASE default: return failure(); } return success(); } LogicalResult ConvertDTFromAttr(Attribute attr, ::oneflow::DataType& data_type) { auto dt_attr = attr.dyn_cast(); return ConvertDT(dt_attr.getValue(), data_type); } void Importer::SetOpStateLoc(const ::oneflow::OperatorConf& op_conf, OperationState& state) { if (op_conf.has_loc()) { state.location = (FileLineColLoc::get(GetMLIRContext(), op_conf.loc(), 0, 0)); } } LogicalResult ConvertVariableOpConf(VariableOp op, ::oneflow::OperatorConf* op_conf) { op_conf->set_name(op.getOpName().str()); op_conf->set_device_tag(op.getDeviceTag().str()); if (auto scope_symbol_id = op.getScopeSymbolId()) { op_conf->set_scope_symbol_id(scope_symbol_id.value()); } // TODO: process stream_name_hint auto* var_op_conf = op_conf->mutable_variable_conf(); var_op_conf->set_out("out"); if (auto shape_attr = op->getAttrOfType(OpTrait::TensorSource::getShapeAttrName())) { *var_op_conf->mutable_shape() = user_op::getAttrAsShape(shape_attr); } if (op->hasAttr(OpTrait::TensorSource::getDataTypeAttrName())) { if (auto dt_mlir = op.getDataType()) { const auto dt = support::FromMLIRDataTypeToOFDataType(dt_mlir.value()); if (failed(dt)) { return failure(); } var_op_conf->set_data_type(dt.value()); } } if (auto model_name = op.getModelNameAttr()) { 
var_op_conf->set_model_name(model_name.getValue().str()); } if (auto l1_regularization = op.getL1RegularizationAttr()) { LOG(ERROR) << op_conf->name(); var_op_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l1( l1_regularization.getValue().convertToFloat()); } if (auto l2_regularization = op.getL2RegularizationAttr()) { var_op_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l2( l2_regularization.getValue().convertToFloat()); } if (auto trainable = op.getTrainableAttr()) { var_op_conf->set_trainable(trainable.getValue()); } for (auto output : op.getParallel()->getOutputs()) { if (auto nd_outputs = output.dyn_cast()) { for (auto nd_output : nd_outputs) { std::string sbp{}; if (failed(SBPTranslation::PrintSbpAttrToString(nd_output, sbp))) return failure(); var_op_conf->add_nd_sbp(sbp); } } else { std::string sbp{}; if (failed(SBPTranslation::PrintSbpAttrToString(output, sbp))) return failure(); var_op_conf->add_nd_sbp(sbp); } } // all operands are ctrl_inputs for (const auto& operand : op->getOperands()) { op_conf->add_ctrl_in_op_name( OpTrait::IsOpConfCompatible::getOpName(operand.getDefiningOp()).str()); } if (auto floatInit = op.getFloatInitializer()) { var_op_conf->mutable_initializer()->mutable_constant_conf()->set_value( floatInit.value().convertToFloat()); } else if (auto integerInit = op.getIntegerInitializer()) { var_op_conf->mutable_initializer()->mutable_constant_int_conf()->set_value(integerInit.value()); } else { // empty initializer var_op_conf->mutable_initializer()->mutable_empty_conf(); } return success(); } LogicalResult ConvertInputOpConf(InputOp op, ::oneflow::OperatorConf* op_conf) { op_conf->set_name(op.getOpName().str()); op_conf->set_device_tag(op.getDeviceTag().str()); if (auto scope_symbol_id = op.getScopeSymbolId()) { op_conf->set_scope_symbol_id(scope_symbol_id.value()); } // TODO: process stream_name_hint auto* input_op_conf = op_conf->mutable_input_conf(); input_op_conf->set_out("out"); if (auto shape_attr = op->getAttrOfType(OpTrait::TensorSource::getShapeAttrName())) { *input_op_conf->mutable_blob_conf()->mutable_shape() = user_op::getAttrAsShape(shape_attr); } if (op->hasAttr(OpTrait::TensorSource::getDataTypeAttrName())) { if (auto dt_mlir = op.getDataType()) { const auto dt = support::FromMLIRDataTypeToOFDataType(dt_mlir.value()); if (failed(dt)) { return failure(); } input_op_conf->mutable_blob_conf()->set_data_type(dt.value()); } } if (op->hasAttr(OpTrait::TensorSource::getIsDynamicAttrName())) { input_op_conf->mutable_blob_conf()->set_is_dynamic(op.getIsDynamic().value()); } if (op->hasAttr(OpTrait::TensorSource::getNdSbpAttrName())) { if (failed(ParseNdSbpFromAttr(op.getNdSbp()->getValue(), input_op_conf->mutable_blob_conf()->mutable_nd_sbp()))) { return failure(); } } if (op->hasAttr("job_name")) { input_op_conf->set_job_name(op.getJobName().value().str()); } // operand 0 is block argument, others are ctrl_inputs for (size_t i = 1; i < op->getNumOperands(); ++i) { op_conf->add_ctrl_in_op_name( OpTrait::IsOpConfCompatible::getOpName(op->getOperand(i).getDefiningOp()).str()); } return success(); } LogicalResult ConvertOutputOpConf(OutputOp op, ::oneflow::OperatorConf* op_conf) { op_conf->set_name(op.getOpName().str()); op_conf->set_device_tag(op.getDeviceTag().str()); if (auto scope_symbol_id = op.getScopeSymbolId()) { op_conf->set_scope_symbol_id(scope_symbol_id.value()); } // TODO: process stream_name_hint auto* output_op_conf = op_conf->mutable_output_conf(); output_op_conf->set_out("out"); if (auto shape_attr = 
op->getAttrOfType(OpTrait::TensorSource::getShapeAttrName())) { *output_op_conf->mutable_blob_conf()->mutable_shape() = user_op::getAttrAsShape(shape_attr); } if (op->hasAttr(OpTrait::TensorSource::getDataTypeAttrName())) { if (auto dt_mlir = op.getDataType()) { const auto dt = support::FromMLIRDataTypeToOFDataType(dt_mlir.value()); if (failed(dt)) { return failure(); } output_op_conf->mutable_blob_conf()->set_data_type(dt.value()); } } if (op->hasAttr(OpTrait::TensorSource::getIsDynamicAttrName())) { output_op_conf->mutable_blob_conf()->set_is_dynamic(op.getIsDynamic().value()); } if (op->hasAttr(OpTrait::TensorSource::getNdSbpAttrName())) { if (failed(ParseNdSbpFromAttr(op.getNdSbp()->getValue(), output_op_conf->mutable_blob_conf()->mutable_nd_sbp()))) { return failure(); } } if (op->hasAttr("job_name")) { output_op_conf->set_job_name(op.getJobName().value().str()); } if (op->getNumOperands() == 0) { op->emitError("output op must have at least one input."); return failure(); } auto result = op->getOperand(0).dyn_cast(); auto output_lbn = user_op::GetOutputLbn(result).value(); output_op_conf->set_in(output_lbn); for (size_t i = 1; i < op->getNumOperands(); ++i) { op_conf->add_ctrl_in_op_name( OpTrait::IsOpConfCompatible::getOpName(op->getOperand(i).getDefiningOp()).str()); } return success(); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-translate/lib/OneFlow/MLIROneFlowTranslation.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "OneFlow/Conversion/OneFlowToTosa.h" #include "OneFlow/OneFlowDataTypeConversion.h" #include "OneFlow/Transform/FuncOps.h" #include "OneFlow/UserOpReflection.h" #include "OneFlow/Transform/AggregateOps.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" #include "mlir/Dialect/Linalg/Passes.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/user_op_conf.pb.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/interface_blob_conf.pb.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "OneFlow/OneFlowDialect.h" #include "OneFlow/OneFlowOps.h" #include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/Passes.h" #include "OneFlow/MLIROneFlowTranslation.h" #include "OneFlow/OneFlowUtils.h" #include "OneFlow/UserOpConversion.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/UseDefLists.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Parser/Parser.h" #include "llvm-c/Core.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include namespace mlir { namespace oneflow { using PbMessage = google::protobuf::Message; class JobImporter : Importer { public: JobImporter(RoundTripOneFlowJobWrapperInterface& job_wrapper, MLIRContext* context, ModuleOp module) : Importer(context, module), job_(job_wrapper.job()), job_wrapper_(job_wrapper) {} virtual ~JobImporter() = default; LogicalResult AppendDataInOperand(const std::string& lbn, std::vector<::mlir::Value>& operand_vec) override; LogicalResult AppendCtrlInOperand(const ::oneflow::OperatorConf& op, std::vector<::mlir::Value>& operand_vec) override; LogicalResult AddDeviceName(const ::oneflow::OperatorConf& op, std::vector& attr_vec) override; LogicalResult InsertOpResults(const ::oneflow::OperatorConf& op, Operation*) override; LogicalResult ProcessJob(); LogicalResult ProcessSystemOp(const ::oneflow::OperatorConf& op) override; LogicalResult ProcessVariableOp(const ::oneflow::OperatorConf& op); LogicalResult ProcessInputOp(const ::oneflow::OperatorConf& op_conf, Block* entry_block, size_t& input_count); LogicalResult ProcessOutputOp(const ::oneflow::OperatorConf& op_conf); LogicalResult TryToUpdateJob(); LogicalResult ConvertUserOp(Operation* op, ::oneflow::Job& job); LogicalResult ConvertSystemOp(Operation* op, ::oneflow::Job& job); LogicalResult ConvertVariableOp(VariableOp op, ::oneflow::Job& job); LogicalResult ConvertInputOp(InputOp op, ::oneflow::Job& job); LogicalResult ConvertOutputOp(OutputOp op, ::oneflow::Job& job); Type GetTensorTypeOfLbn(const std::string& lbn) override; Type GetInterfaceBlobConfType(const ::oneflow::InterfaceBlobConf& blob_conf); 
private: std::unordered_map lbn2result_; std::unordered_map op_name2ctrl_result_; const ::oneflow::Job* job_; RoundTripOneFlowJobWrapperInterface& job_wrapper_; }; LogicalResult JobImporter::AppendCtrlInOperand(const ::oneflow::OperatorConf& op, std::vector<::mlir::Value>& operand_vec) { for (auto& ctrl_in_op_name : op.ctrl_in_op_name()) { auto it = op_name2ctrl_result_.find(ctrl_in_op_name); if (it == op_name2ctrl_result_.end()) { GetModule().emitError("ctrl edge result of this op not found: " + ctrl_in_op_name + ". op being controlled: " + op.name()); return failure(); } else { operand_vec.push_back(it->second); } } return success(); } LogicalResult JobImporter::AppendDataInOperand(const std::string& lbn, std::vector<::mlir::Value>& operand_vec) { auto it = lbn2result_.find(lbn); if (it == lbn2result_.end()) { GetModule().emitError("IR result not found for: " + lbn); return failure(); } else { operand_vec.push_back(it->second); return success(); } } LogicalResult JobImporter::InsertOpResults(const ::oneflow::OperatorConf& op, Operation* created_op) { auto output_lbns = created_op->getAttrOfType(OpTrait::IsImportCompatible::getOutputLBNsAttr()); auto data_results = GetDataOutputResults(created_op); if (output_lbns.size() != data_results.size()) { output_lbns.dump(); llvm::errs() << "output_lbns size: " << output_lbns.size() << " != data_results size: " << data_results.size() << "\n" << op.DebugString(); created_op->getAttrDictionary().dump(); created_op->dump(); return failure(); } for (const auto& data_out : llvm::enumerate(data_results)) { auto data_out_index = data_out.index(); lbn2result_.insert({output_lbns[data_out_index].dyn_cast().getValue().str(), data_out.value().dyn_cast()}); } if (auto ctrl_out = GetCtrlOutputResult(created_op)) { op_name2ctrl_result_.insert( {created_op->getAttrOfType(OpTrait::IsOpConfCompatible::getOpNameAttr()) .getValue() .str(), ctrl_out->dyn_cast()}); } return success(); } LogicalResult JobImporter::AddDeviceName(const ::oneflow::OperatorConf& op, std::vector& attr_vec) { const ::oneflow::ParallelConf& pc = job_wrapper_.ParallelConf4OpName(op.name()); std::vector device_vec = {pc.device_name().begin(), pc.device_name().end()}; attr_vec.push_back( GetBuilder().getNamedAttr(OpTrait::IsOpConfCompatible::getDeviceNameAttr(), GetBuilder().getStrArrayAttr(device_vec))); if (pc.has_hierarchy()) { attr_vec.push_back(GetBuilder().getNamedAttr( OpTrait::IsOpConfCompatible::getHierarchyAttr(), GetBuilder().getI64ArrayAttr({pc.hierarchy().dim().begin(), pc.hierarchy().dim().end()}))); } return success(); } Type JobImporter::GetTensorTypeOfLbn(const std::string& lbn) { Type ret{}; job_wrapper_.QueryLogicalBlob( lbn, [this, &ret, &lbn](const int64_t* shape_begin, const int64_t* shape_end, ::oneflow::DataType dt) { if (auto t = getTypeFromOneFlowDataType(GetMLIRContext(), dt)) { ret = RankedTensorType::get(ArrayRef(shape_begin, shape_end), t); } else { llvm::errs() << "fail to get data tensor type for: " << lbn << "\n"; } }); return ret; } LogicalResult JobImporter::ProcessSystemOp(const ::oneflow::OperatorConf& op) { if (op.has_user_conf()) { GetModule().emitError("Not a sys op. 
op name: " + op.name()); return failure(); } if (op.has_variable_conf()) { return ProcessVariableOp(op); } auto input_bns_lbns = job_wrapper_.InputBns4OpName(op.name()); auto input_bns = input_bns_lbns.first; auto input_lbns = input_bns_lbns.second; auto output_lbns = job_wrapper_.OutputLbns4OpName(op.name()); job_wrapper_.OutputLbns4OpName(op.name()); std::vector attr_vec; if (failed(AddOpConf(op, attr_vec))) { return failure(); } if (failed(AddDeviceName(op, attr_vec))) { return failure(); } attr_vec.push_back(GetBuilder().getNamedAttr( "input_bns", GetBuilder().getStrArrayAttr( std::vector({input_bns.begin(), input_bns.end()})))); attr_vec.push_back(GetBuilder().getNamedAttr( OpTrait::IsImportCompatible::getOutputLBNsAttr(), GetBuilder().getStrArrayAttr( std::vector({output_lbns.begin(), output_lbns.end()})))); OperationState state(FileLineColLoc::get(GetMLIRContext(), op.name(), 0, 0), SystemOp::getOperationName()); attr_vec.push_back( GetBuilder().getNamedAttr("op_type_case", GetBuilder().getI32IntegerAttr(op.op_type_case()))); if (failed(AddOperandSegmentSizes(static_cast(input_lbns.size()), op.ctrl_in_op_name_size(), attr_vec))) { return failure(); } if (failed(AddResultSegmentSizes(output_lbns.size(), attr_vec))) { return failure(); } state.addAttributes(attr_vec); std::vector<::mlir::Value> operand_vec; for (const auto& input_lbn : input_lbns) { if (failed(AppendDataInOperand(input_lbn, operand_vec))) { return failure(); } } if (failed(AppendCtrlInOperand(op, operand_vec))) { return failure(); } auto out_types = llvm::SmallVector(); for (const auto& output_lbn : output_lbns) { out_types.push_back(GetTensorTypeOfLbn(output_lbn)); } if (failed(AppendCtrlOutType(out_types))) { return failure(); } state.addOperands(operand_vec); state.addTypes(out_types); if (auto created_op = GetBuilder().create(state)) { if (failed(InsertOpResults(op, created_op))) { return failure(); } } else { GetModule()->emitError("fail to create op, name: " + op.name()); return failure(); } return success(); } LogicalResult JobImporter::ProcessVariableOp(const ::oneflow::OperatorConf& op_conf) { if (!op_conf.has_variable_conf()) { GetModule().emitError("Not a variable op. op name: " + op_conf.name()); return failure(); } if (op_conf.variable_conf().has_tick()) { GetModule().emitError("variable op has tick input. 
op name: " + op_conf.name()); return failure(); } OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0), "oneflow.variable"); // attrs std::vector attr_vec; if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); } if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); } // attr output_lbns auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + "/out"}); attr_vec.emplace_back(GetBuilder().getNamedAttr( OpTrait::IsImportCompatible::getOutputLBNsAttr(), output_lbns_attr)); // attr shape auto shape_attr = GetAttrFromShape(op_conf.variable_conf().shape()); auto shape_named_attr = GetBuilder().getNamedAttr(OpTrait::TensorSource::getShapeAttrName(), shape_attr); attr_vec.emplace_back(shape_named_attr); // attr data_type if (op_conf.variable_conf().has_data_type()) { attr_vec.emplace_back(GetBuilder().getNamedAttr( OpTrait::TensorSource::getDataTypeAttrName(), GetDataTypeAttr(GetMLIRContext(), op_conf.variable_conf().data_type()).value())); } // attr model_name if (op_conf.variable_conf().has_model_name()) { const std::string& model_name = op_conf.variable_conf().model_name(); attr_vec.emplace_back( GetBuilder().getNamedAttr("model_name", GetBuilder().getStringAttr(model_name))); } // attr l1 l2 regularization if (op_conf.variable_conf().has_regularizer() && op_conf.variable_conf().regularizer().has_l1_l2_conf()) { if (op_conf.variable_conf().regularizer().l1_l2_conf().has_l1()) { float l1_regularization = op_conf.variable_conf().regularizer().l1_l2_conf().l1(); attr_vec.emplace_back(GetBuilder().getNamedAttr( "l1_regularization", GetBuilder().getF32FloatAttr(l1_regularization))); } if (op_conf.variable_conf().regularizer().l1_l2_conf().has_l2()) { float l2_regularization = op_conf.variable_conf().regularizer().l1_l2_conf().l2(); attr_vec.emplace_back(GetBuilder().getNamedAttr( "l2_regularization", GetBuilder().getF32FloatAttr(l2_regularization))); } } // attr trainable if (op_conf.variable_conf().has_trainable()) { bool trainable = op_conf.variable_conf().trainable(); attr_vec.emplace_back( GetBuilder().getNamedAttr("trainable", GetBuilder().getBoolAttr(trainable))); } if (op_conf.variable_conf().has_initializer()) { if (op_conf.variable_conf().initializer().has_constant_conf()) { const mlir::Attribute const_initialize_attr = GetBuilder().getF32FloatAttr( op_conf.variable_conf().initializer().constant_conf().value()); attr_vec.emplace_back(GetBuilder().getNamedAttr("float_initializer", const_initialize_attr)); } else if (op_conf.variable_conf().initializer().has_constant_int_conf()) { const mlir::Attribute const_initialize_attr = getSI64IntegerAttr(op_conf.variable_conf().initializer().constant_int_conf().value()); attr_vec.emplace_back( GetBuilder().getNamedAttr("integer_initializer", const_initialize_attr)); } } // attr parallel auto conf = this->job_wrapper_.ParallelConf4OpName(op_conf.name()); auto nd_size = conf.hierarchy().dim().size(); auto nd_sbp = op_conf.variable_conf().nd_sbp(); auto parallel = mlir::oneflow::SBPTranslation::ConvertNdSbpToPsig( GetBuilder(), std::vector(nd_sbp.begin(), nd_sbp.end()), nd_size); attr_vec.emplace_back( GetBuilder().getNamedAttr(OpTrait::TensorSource::getSbpAttrName(), parallel)); // add attrs state.addAttributes(attr_vec); // operands std::vector<::mlir::Value> operand_vec; if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); } state.addOperands(operand_vec); // result types llvm::SmallVector out_types; auto output_lbn = op_conf.name() + "/out"; 
out_types.push_back(GetTensorTypeOfLbn(output_lbn)); if (failed(AppendCtrlOutType(out_types))) { return failure(); } state.addTypes(out_types); SetOpStateLoc(op_conf, state); // create op auto op = GetBuilder().create(state); if (!op) { GetModule()->emitError("fail to create op, name: " + op_conf.name()); return failure(); } // record result if (op->getNumResults() != 2) { op->emitError("variable op should has two results (out and ctrl_output), but got " + std::to_string(op->getNumResults()) + "\n"); return failure(); } if (!lbn2result_.emplace(output_lbn, op->getResult(0)).second) { op->emitError("lbn already exists, lbn: ") << output_lbn; return failure(); } if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) { op->emitError("ctrl output already exists, op_name: ") << op_conf.name(); return failure(); } return success(); } LogicalResult JobImporter::ProcessInputOp(const ::oneflow::OperatorConf& op_conf, Block* entry_block, size_t& input_count) { if (!op_conf.has_input_conf()) { GetModule().emitError("Not a input op. op name: " + op_conf.name()); return failure(); } if (op_conf.input_conf().has_tick()) { GetModule().emitError("input op has tick input. op name: " + op_conf.name()); return failure(); } OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0), "oneflow.input"); // attrs std::vector attr_vec; if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); } if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); } // attr output_lbns auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + "/out"}); attr_vec.emplace_back(GetBuilder().getNamedAttr( OpTrait::IsImportCompatible::getOutputLBNsAttr(), output_lbns_attr)); // attr shape if (op_conf.input_conf().blob_conf().has_shape()) { auto shape_attr = GetAttrFromShape(op_conf.input_conf().blob_conf().shape()); attr_vec.emplace_back( GetBuilder().getNamedAttr(OpTrait::TensorSource::getShapeAttrName(), shape_attr)); } // attr data_type if (op_conf.input_conf().blob_conf().has_data_type()) { attr_vec.emplace_back(GetBuilder().getNamedAttr( OpTrait::TensorSource::getDataTypeAttrName(), GetDataTypeAttr(GetMLIRContext(), op_conf.input_conf().blob_conf().data_type()).value())); } // attr is_dynamic if (op_conf.input_conf().blob_conf().has_is_dynamic()) { bool is_dynamic = op_conf.input_conf().blob_conf().is_dynamic(); attr_vec.emplace_back(GetBuilder().getNamedAttr( OpTrait::TensorSource::getIsDynamicAttrName(), GetBuilder().getBoolAttr(is_dynamic))); } // attr nd_sbp if (op_conf.input_conf().blob_conf().has_nd_sbp()) { auto nd_sbp_attr = ConvertNdSbpToAttr(GetBuilder(), op_conf.input_conf().blob_conf().nd_sbp()); attr_vec.emplace_back( GetBuilder().getNamedAttr(OpTrait::TensorSource::getNdSbpAttrName(), nd_sbp_attr)); } // attr job_name if (op_conf.input_conf().has_job_name()) { const std::string& job_name = op_conf.input_conf().job_name(); attr_vec.emplace_back( GetBuilder().getNamedAttr("job_name", GetBuilder().getStringAttr(job_name))); } // add attrs state.addAttributes(attr_vec); // operands std::vector<::mlir::Value> operand_vec; operand_vec.emplace_back(entry_block->getArgument(input_count++)); if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); } state.addOperands(operand_vec); // result types llvm::SmallVector out_types; auto output_lbn = op_conf.name() + "/out"; out_types.push_back(GetTensorTypeOfLbn(output_lbn)); if (failed(AppendCtrlOutType(out_types))) { return failure(); } state.addTypes(out_types); // create op auto op = 
GetBuilder().create(state); if (!op) { GetModule()->emitError("fail to create op, name: " + op_conf.name()); return failure(); } // record result if (op->getNumResults() != 2) { op->emitError("input op should have two results (out and ctrl_output), but got " + std::to_string(op->getNumResults()) + "\n"); return failure(); } if (!lbn2result_.emplace(output_lbn, op->getResult(0)).second) { op->emitError("lbn already exists, lbn: ") << output_lbn; return failure(); } if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) { op->emitError("ctrl output already exists, op_name: ") << op_conf.name(); return failure(); } return success(); } LogicalResult JobImporter::ProcessOutputOp(const ::oneflow::OperatorConf& op_conf) { if (!op_conf.has_output_conf()) { GetModule().emitError("Not an output op. op name: " + op_conf.name()); return failure(); } OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0), "oneflow.output"); // attrs std::vector attr_vec; if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); } if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); } // attr output_lbns auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + "/out"}); attr_vec.emplace_back(GetBuilder().getNamedAttr( OpTrait::IsImportCompatible::getOutputLBNsAttr(), output_lbns_attr)); // attr shape if (op_conf.output_conf().blob_conf().has_shape()) { auto shape_attr = GetAttrFromShape(op_conf.output_conf().blob_conf().shape()); attr_vec.emplace_back( GetBuilder().getNamedAttr(OpTrait::TensorSource::getShapeAttrName(), shape_attr)); } // attr data_type if (op_conf.output_conf().blob_conf().has_data_type()) { attr_vec.emplace_back(GetBuilder().getNamedAttr( OpTrait::TensorSource::getDataTypeAttrName(), GetDataTypeAttr(GetMLIRContext(), op_conf.output_conf().blob_conf().data_type()).value())); } // attr is_dynamic if (op_conf.output_conf().blob_conf().has_is_dynamic()) { bool is_dynamic = op_conf.output_conf().blob_conf().is_dynamic(); attr_vec.emplace_back(GetBuilder().getNamedAttr( OpTrait::TensorSource::getIsDynamicAttrName(), GetBuilder().getBoolAttr(is_dynamic))); } // attr nd_sbp if (op_conf.output_conf().blob_conf().has_nd_sbp()) { auto nd_sbp_attr = ConvertNdSbpToAttr(GetBuilder(), op_conf.output_conf().blob_conf().nd_sbp()); attr_vec.emplace_back( GetBuilder().getNamedAttr(OpTrait::TensorSource::getNdSbpAttrName(), nd_sbp_attr)); } // attr job_name if (op_conf.output_conf().has_job_name()) { const std::string& job_name = op_conf.output_conf().job_name(); attr_vec.emplace_back( GetBuilder().getNamedAttr("job_name", GetBuilder().getStringAttr(job_name))); } // add attrs state.addAttributes(attr_vec); // operands std::vector<::mlir::Value> operand_vec; auto input_bns_lbns = job_wrapper_.InputBns4OpName(op_conf.name()); if (input_bns_lbns.second.size() != 1) { GetModule()->emitError("output op should have only one input, op_name: " + op_conf.name()); return failure(); } if (failed(AppendDataInOperand(input_bns_lbns.second[0], operand_vec))) { return failure(); } if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); } state.addOperands(operand_vec); // result types llvm::SmallVector out_types; auto output_lbn = op_conf.name() + "/out"; out_types.push_back(GetTensorTypeOfLbn(output_lbn)); if (failed(AppendCtrlOutType(out_types))) { return failure(); } state.addTypes(out_types); // create op auto op = GetBuilder().create(state); if (!op) { GetModule()->emitError("fail to create op, name: " + op_conf.name()); return
failure(); } // record result if (op->getNumResults() != 2) { op->emitError("output_conf op should has two results (out and ctrl_output), but got " + std::to_string(op->getNumResults()) + "\n"); return failure(); } if (!lbn2result_.emplace(output_lbn, op->getResult(0)).second) { op->emitError("lbn already exists, lbn: ") << output_lbn; return failure(); } if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) { op->emitError("ctrl output already exists, op_name: ") << op_conf.name(); return failure(); } return success(); } LogicalResult JobImporter::ProcessJob() { llvm::SmallVector input_types; llvm::SmallVector result_types; llvm::SmallVector results; bool is_succeeded = true; job_wrapper_.TopoForEachOpConf([&](const ::oneflow::OperatorConf* op_conf) { if (op_conf->has_input_conf()) { auto type = GetInterfaceBlobConfType(op_conf->input_conf().blob_conf()); if (type) { input_types.emplace_back(type); } else { GetModule()->emitError("fail to collect func arg types for job:\n" + op_conf->DebugString()); is_succeeded = false; } } }); if (!is_succeeded) { return failure(); } auto func_type = GetBuilder().getFunctionType(input_types, std::nullopt); auto job_op = GetBuilder().create(GetRootLocation(), job_->job_conf().job_name(), func_type); auto* entryBlock = job_op.addEntryBlock(); GetBuilder().setInsertionPointToStart(entryBlock); is_succeeded = true; size_t input_count = 0; job_wrapper_.TopoForEachOpConf([&](const ::oneflow::OperatorConf* op_conf) { if (is_succeeded == false) { return; } if (op_conf->has_user_conf()) { is_succeeded = succeeded(ProcessUserOp(*op_conf)); } else if (op_conf->has_input_conf()) { is_succeeded = succeeded(ProcessInputOp(*op_conf, entryBlock, input_count)); } else if (op_conf->has_output_conf()) { is_succeeded = succeeded(ProcessOutputOp(*op_conf)); if (is_succeeded) { auto result = entryBlock->back().getResult(0); results.emplace_back(result); result_types.emplace_back(result.getType()); } } else { is_succeeded = succeeded(ProcessSystemOp(*op_conf)); } }); if (is_succeeded == false) { return failure(); } mlir::oneflow::ReturnOp return_op; if (!entryBlock->empty()) { return_op = dyn_cast(entryBlock->back()); } if (!return_op) { GetBuilder().create(GetRootLocation(), results); } func_type = GetBuilder().getFunctionType(input_types, result_types); job_op.setFunctionTypeAttr(TypeAttr::get(func_type)); GetModule().push_back(job_op); return success(); } template void UpdatePlacement(OpType* op, AdaptorType& adaptor, ::oneflow::Job& job) { auto* pg = job.mutable_placement()->add_placement_group(); pg->mutable_op_set()->add_op_name(adaptor.getOpName().str()); pg->mutable_parallel_conf()->set_device_tag(adaptor.getDeviceTag().str()); for (auto p : adaptor.getDeviceName()) { pg->mutable_parallel_conf()->add_device_name( p.template dyn_cast().getValue().str()); } if (::llvm::Optional hierarchy = adaptor.getHierarchy()) { for (auto dim : hierarchy->getValue()) { pg->mutable_parallel_conf()->mutable_hierarchy()->add_dim( dim.template dyn_cast().getInt()); } } } LogicalResult JobImporter::TryToUpdateJob() { auto new_job = ::oneflow::Job(); new_job.CopyFrom(*job_); new_job.clear_net(); new_job.mutable_placement()->clear_placement_group(); Operation* job_op = nullptr; llvm::SmallVector outputs; auto find_first_job = [&](oneflow::Job job) -> WalkResult { job_op = job.getOperation(); new_job.mutable_job_conf()->set_job_name(job.getSymName().str()); return WalkResult::interrupt(); }; GetModule().getOperation()->walk(find_first_job); if (!job_op) { 
GetModule()->emitError("job not found. module op: ") << *GetModule(); return failure(); } auto ConvertOp = [&](Operation* op) -> WalkResult { if (op->hasTrait()) { if (llvm::dyn_cast(op)) { if (failed(ConvertUserOp(op, new_job))) { op->emitError("failed to convert generic UserOp: ") << *op; return WalkResult::interrupt(); } } else if (llvm::dyn_cast(op)) { if (failed(ConvertSystemOp(op, new_job))) { op->emitError("failed to convert SystemOp: ") << *op; return WalkResult::interrupt(); } } else if (auto variable_op = llvm::dyn_cast(op)) { if (failed(ConvertVariableOp(variable_op, new_job))) { op->emitError("failed to process VariableOp: ") << *op; return WalkResult::interrupt(); } } else if (llvm::dyn_cast(op) || llvm::dyn_cast(op)) { // do nothing and advance } else { if (!dyn_cast(op)) { op->emitError("op is not UserOpCompatible ") << *op; return WalkResult::interrupt(); } if (failed(ConvertUserOp(op, new_job))) { op->emitError("failed to process UserOp: ") << *op; return WalkResult::interrupt(); } } } else if (llvm::dyn_cast(op)) { // do nothing and advance } else if (op->hasTrait()) { // do nothing and advance } else if (auto return_op = llvm::dyn_cast(op)) { for (auto operand : return_op->getOperands()) { outputs.emplace_back(operand); } } else { op->emitError("unexcepted op: ") << *op; return WalkResult::interrupt(); } return WalkResult::advance(); }; if (job_op->walk(ConvertOp).wasInterrupted()) { return failure(); } // add input op auto arguments = llvm::dyn_cast(job_op).getBody().front().getArguments(); for (BlockArgument argument : arguments) { for (auto& use : argument.getUses()) { Operation* owner = use.getOwner(); if (auto input_op = dyn_cast(owner)) { if (failed(ConvertInputOp(input_op, new_job))) { return failure(); } } else { return failure(); } } } // add output op for (auto output : outputs) { Operation* owner = output.getDefiningOp(); if (auto output_op = dyn_cast(owner)) { if (failed(ConvertOutputOp(output_op, new_job))) { return failure(); } } else { return failure(); } } job_wrapper_.UpdateJob(&new_job); return success(); } LogicalResult JobImporter::ConvertUserOp(Operation* op, ::oneflow::Job& job) { oneflow::ConfOpAdaptor conf_op_adaptor(op->getOperands(), op->getAttrDictionary()); UpdatePlacement(op, conf_op_adaptor, job); StringRef op_name = conf_op_adaptor.getOpName(); auto* op_conf = job.mutable_net()->add_op(); auto* user_conf = op_conf->mutable_user_conf(); if (!succeeded(ConvertUserOpInputs(op, op_name, user_conf))) { op->emitError("fail to convert user op inputs"); return failure(); } if (!succeeded(ConvertUserOpOutputs(op, op_name, user_conf))) { op->emitError("fail to convert user op outputs"); return failure(); } if (!succeeded(user_op::ConvertUserOpAttributes(op, *op_conf, false))) { op->emitError("fail to convert user op attributes"); return failure(); } if (!succeeded(ConvertCtrlInputs(op, *op_conf))) { op->emitError("fail to convert user op control inputs"); return failure(); } return success(); } LogicalResult JobImporter::ConvertSystemOp(Operation* op, ::oneflow::Job& job) { oneflow::SystemOpAdaptor system_op_adaptor(op->getOperands(), op->getAttrDictionary()); UpdatePlacement(op, system_op_adaptor, job); auto op_name = system_op_adaptor.getOpName().str(); ::oneflow::OperatorConf op_conf = job_wrapper_.OpConf4OpName(op_name); for (const auto& ibn : llvm::enumerate(op->getAttrOfType("input_bns"))) { auto result = GetDataInputOperands(op)[ibn.index()].dyn_cast(); std::string new_val = user_op::GetOutputLbn(result).value(); 
job_wrapper_.ReplaceInputLbnInOpCustomizedConf( &op_conf, ibn.value().dyn_cast().getValue().str(), new_val); } if (failed(ConvertCtrlInputs(op, op_conf))) { return failure(); } *(job.mutable_net()->add_op()) = op_conf; return success(); } LogicalResult JobImporter::ConvertVariableOp(VariableOp op, ::oneflow::Job& job) { oneflow::VariableOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary()); UpdatePlacement(&op, op_adaptor, job); auto* op_conf = job.mutable_net()->add_op(); return ConvertVariableOpConf(op, op_conf); } LogicalResult JobImporter::ConvertInputOp(InputOp op, ::oneflow::Job& job) { oneflow::InputOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary()); UpdatePlacement(&op, op_adaptor, job); auto* op_conf = job.mutable_net()->add_op(); return ConvertInputOpConf(op, op_conf); } LogicalResult JobImporter::ConvertOutputOp(OutputOp op, ::oneflow::Job& job) { oneflow::OutputOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary()); UpdatePlacement(&op, op_adaptor, job); auto* op_conf = job.mutable_net()->add_op(); return ConvertOutputOpConf(op, op_conf); } Type JobImporter::GetInterfaceBlobConfType(const ::oneflow::InterfaceBlobConf& blob_conf) { if (!blob_conf.has_data_type()) { return Type{}; } if (!blob_conf.has_shape()) { return Type{}; } if (auto data_type = getTypeFromOneFlowDataType(GetMLIRContext(), blob_conf.data_type())) { return RankedTensorType::get({blob_conf.shape().dim().begin(), blob_conf.shape().dim().end()}, data_type); } else { return Type{}; } } void DumpMLIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, ModuleOp module, const std::string& name) { std::string mlir; llvm::raw_string_ostream os_mlir(mlir); module->print(os_mlir); job_wrapper.DumpLog(name + ".mlir", mlir); } LogicalResult ApplyRoundTripPatterns(RoundTripOneFlowJobWrapperInterface& job_wrapper, MLIRContext* context, OwningOpRef& module) { mlir::PassManager pm(context); if (::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_ENABLE_TIMING", false)) { pm.enableTiming(); } mlir::oneflow::CheckEnableIRPrinting(pm); // this canonicalizer should create concrete ops and create fuse opportunities pm.addPass(createCanonicalizerPass()); // we must do auto nhwc and eliminate redundant transpose ops first, to avoid inserting redundant // transpose ops due to fuse patterns like normalization_add_relu.
pm.addPass(oneflow::createAutoNhwcPass()); if (::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_CSE", false)) { auto cse_state = std::make_shared(); auto passes = createCSEPasses(cse_state); pm.addPass(std::move(passes.first)); pm.addPass(createCSEPass()); pm.addPass(std::move(passes.second)); } if (job_wrapper.IsLastIRPass() && ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_FUSE_FORWARD_OPS", false)) { pm.addPass(oneflow::createFuseForwardOps()); pm.addPass(oneflow::createFuseIntoExistingOpPass()); } if (job_wrapper.IsLastIRPass() && ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS", false)) { pm.addPass(oneflow::createOneFlowJobToFuncPass()); pm.addPass(oneflow::createCastOneFlowOpsToSignlessPass()); auto toTosa = oneflow::createLowerOneFlowToTosaPass(); CHECK(toTosa->initializeOptions("full=0 lower-job=0").succeeded()); pm.addPass(std::move(toTosa)); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); pm.addPass(oneflow::createLowerOneFlowToLinalgPass()); pm.addPass(tosa::createTosaToTensor()); pm.addNestedPass(tosa::createTosaToLinalgNamed()); pm.addNestedPass(tosa::createTosaToLinalg()); pm.addPass(createLinalgElementwiseOpFusionPass()); pm.addPass(oneflow::createFuncToOneFlowJobPass()); pm.addNestedPass(oneflow::createOutlineJitFunctionPass()); pm.addPass(createCanonicalizerPass()); } if (!job_wrapper.IsLastIRPass() && ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL", false)) { pm.addPass(oneflow::createFuseOpsWithBackwardImpl()); } // TODO: support backward or put it in a env flag if (job_wrapper.IsLastIRPass() && ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_GROUP_MATMUL", false)) { pm.addPass(oneflow::createGroupMatMul()); } if (!job_wrapper.IsLastIRPass() && ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", false)) { pm.addPass(oneflow::createPreConvertInferenceOpPass()); pm.addPass(oneflow::createConvertInferenceOpPass()); pm.addPass(oneflow::createPostConvertInferenceOpPass()); } if (!job_wrapper.IsLastIRPass() && ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_FUSE_NORMALIZATION_OPS", false)) { pm.addPass(oneflow::createFuseNormalizationOps()); } if (job_wrapper.IsLastIRPass() && ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH", false)) { pm.addPass(createAggregateComputeOpsPass()); auto wrap_pass = createWrapOpsToKernelLaunchPass(); std::string options = "mode=" + (::oneflow::ParseBooleanFromEnv("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", false) ? 
wrap_mode::CUDA_GRAPH : wrap_mode::SIMPLE); (void)wrap_pass->initializeOptions(options); pm.addPass(std::move(wrap_pass)); } pm.addPass(createCanonicalizerPass()); if (::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_PRINT_STATS", false)) { pm.addPass(createPrintOpStatsPass()); } std::string graphviz; llvm::raw_string_ostream os_graphviz(graphviz); const bool shouldPrintGraphviz = ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_PRINT_OP_GRAPH", false); if (shouldPrintGraphviz) { pm.addPass(createPrintOpGraphPass(os_graphviz)); } if (mlir::failed(pm.run(*module))) { module->emitError("Failed to run round-trip passes"); return failure(); } if (shouldPrintGraphviz) { job_wrapper.DumpLog("RoundTripOneFlowJob.optimized.mlir.dot", graphviz); } if (::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_DUMPMLIR", false)) { DumpMLIR(job_wrapper, module.get(), "RoundTripOneFlowJob.optimized"); } return success(); } OwningOpRef TranslateOneFlowJobToModule(llvm::StringRef str, MLIRContext* context) { std::string cpp_str = str.str(); ::oneflow::Job job; google::protobuf::TextFormat::ParseFromString(cpp_str, &job); context->loadDialect(); context->loadDialect(); OwningOpRef module( ModuleOp::create(FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0))); return module; } void RoundTripOneFlowJob( RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::function& is_legit_job) { const ::oneflow::Job* job = job_wrapper.job(); mlir::MLIRContext context; context.getOrLoadDialect(); context.loadDialect(); OwningOpRef module( ModuleOp::create(FileLineColLoc::get(&context, "", /*line=*/0, /*column=*/0))); JobImporter imp(job_wrapper, &context, module.get()); // TODO: Add flag in job desc to decide whether to run mlir optimizer if (succeeded(imp.ProcessJob())) { if (::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_DUMPMLIR", false)) { DumpMLIR(job_wrapper, module.get(), "RoundTripOneFlowJob.imported"); } if (failed(ApplyRoundTripPatterns(job_wrapper, &context, module))) { exit(EXIT_FAILURE); } if (::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_STDOUT", false) && job_wrapper.IsLastIRPass()) { // for FileCheck module->print(llvm::outs()); } // TODO: Add flag in oneflow to define if failure in MLIR is allowed if (failed(imp.TryToUpdateJob())) { llvm::errs() << "fail to update job with IR, job will stay intact, job_name: " << job->job_conf().job_name() << "\n"; exit(EXIT_FAILURE); } } else { llvm::errs() << "fail to convert job to IR, job_name: " << job->job_conf().job_name() << "\n"; exit(EXIT_FAILURE); } } std::string ConvertJobToTosaIR(RoundTripOneFlowJobWrapperInterface& job_wrapper) { const ::oneflow::Job* job = job_wrapper.job(); mlir::MLIRContext context; context.getOrLoadDialect(); context.loadDialect(); OwningOpRef module( ModuleOp::create(FileLineColLoc::get(&context, "", /*line=*/0, /*column=*/0))); JobImporter imp(job_wrapper, &context, module.get()); if (succeeded(imp.ProcessJob())) { mlir::PassManager pm(&context); pm.addPass(createCanonicalizerPass()); pm.addPass(createConvertToSignlessForTosaPass()); pm.addPass(createLowerOneFlowToTosaPass()); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); if (mlir::failed(pm.run(*module))) { module->emitError("Failed to run oneflow-to-tosa pass"); exit(EXIT_FAILURE); } std::string mlir; llvm::raw_string_ostream os_mlir(mlir); module->print(os_mlir); return mlir; } else { const auto& job_name = job->job_conf().job_name(); llvm::errs() << "fail to convert job to IR, job_name: " << job_name << "\n"; exit(EXIT_FAILURE); } } std::string 
ConvertJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper) { const ::oneflow::Job* job = job_wrapper.job(); mlir::MLIRContext context; context.getOrLoadDialect(); context.loadDialect(); OwningOpRef module( ModuleOp::create(FileLineColLoc::get(&context, "", /*line=*/0, /*column=*/0))); JobImporter imp(job_wrapper, &context, module.get()); if (succeeded(imp.ProcessJob())) { mlir::PassManager pm(&context); pm.addPass(createCanonicalizerPass()); if (mlir::failed(pm.run(*module))) { module->emitError("Failed to run canonicalizer pass"); exit(EXIT_FAILURE); } std::string mlir; llvm::raw_string_ostream os_mlir(mlir); module->print(os_mlir); return mlir; } else { const auto& job_name = job->job_conf().job_name(); llvm::errs() << "Failed to convert Job to IR, job_name: " << job_name << "\n"; exit(EXIT_FAILURE); } } void SaveJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path) { const ::oneflow::Job* job = job_wrapper.job(); mlir::MLIRContext context; context.getOrLoadDialect(); context.loadDialect(); OwningOpRef module( ModuleOp::create(FileLineColLoc::get(&context, "", /*line=*/0, /*column=*/0))); JobImporter imp(job_wrapper, &context, module.get()); if (succeeded(imp.ProcessJob())) { mlir::PassManager pm(&context); pm.addPass(createCanonicalizerPass()); if (mlir::failed(pm.run(*module))) { module->emitError("Failed to run canonicalizer pass"); exit(EXIT_FAILURE); } std::string mlir; llvm::raw_string_ostream os_mlir(mlir); module->print(os_mlir); std::string filename = path + "/model.mlir"; std::ofstream fs(filename, std::ios::trunc); if (!fs.is_open()) { llvm::errs() << "fail to open file " << filename; exit(EXIT_FAILURE); } fs << mlir; fs.close(); } else { const auto& job_name = job->job_conf().job_name(); llvm::errs() << "fail to convert job to IR, job_name: " << job_name << "\n"; exit(EXIT_FAILURE); } } void LoadJobFromIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path) { MLIRContext context; context.getOrLoadDialect(); context.loadDialect(); OwningOpRef module = parseSourceFile(path, &context); if (!module) { llvm::errs() << "fail to parse file: " << path << "\n"; exit(EXIT_FAILURE); } JobImporter imp(job_wrapper, &context, module.get()); if (failed(imp.TryToUpdateJob())) { llvm::errs() << "fail to load job from IR"; exit(EXIT_FAILURE); } } void registerFromOneFlowJobTranslation() { TranslateToMLIRRegistration fromOneFlowJob("import-oneflow-job", "import oneflow from job", [](llvm::StringRef str, MLIRContext* context) { return TranslateOneFlowJobToModule(str, context); }); } } // namespace oneflow } // namespace mlir ================================================ FILE: oneflow/ir/oneflow-translate/oneflow-translate.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/MLIROneFlowTranslation.h" #include "mlir/InitAllTranslations.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "mlir/Support/LogicalResult.h" int32_t main(int32_t argc, char** argv) { mlir::registerAllTranslations(); mlir::oneflow::registerFromOneFlowJobTranslation(); return failed(mlir::mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool")); } ================================================ FILE: oneflow/ir/test/CMakeLists.txt ================================================ llvm_canonicalize_cmake_booleans(WITH_MLIR_CUDA_CODEGEN BUILD_CUDA) message(STATUS "LLVM_TOOLS_BINARY_DIR (used as LLVM_TOOLS_DIR): ${LLVM_TOOLS_BINARY_DIR}") message(STATUS "LLVM_EXTERNAL_LIT: ${LLVM_EXTERNAL_LIT}") configure_lit_site_cfg( ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py MAIN_CONFIG ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py) set(ONEFLOW_TEST_DEPENDS FileCheck count not oneflow-opt oneflow-translate) add_lit_testsuite( check-oneflow "Running the OneFlow MLIR regression tests from: ${CMAKE_CURRENT_SOURCE_DIR}" ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${ONEFLOW_TEST_DEPENDS}) set_target_properties(check-oneflow PROPERTIES FOLDER "Tests") if(LLVM_PROVIDER STREQUAL "in-tree") add_dependencies(check-oneflow mlir-cpu-runner) endif() add_dependencies(check-oneflow oneflow_internal) add_dependencies(check-oneflow oneflow-runner) add_lit_testsuites(ONEFLOW ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${ONEFLOW_TEST_DEPENDS}) add_custom_target(c1 DEPENDS check-oneflow) ================================================ FILE: oneflow/ir/test/Frontend/lit.local.cfg ================================================ if not config.WITH_ONEFLOW_IREE: config.unsupported = True ================================================ FILE: oneflow/ir/test/Frontend/oneflow_to_iree.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -split-input-file \ // RUN: -auto-nhwc \ // RUN: -lower-oneflow-to-tosa \ // RUN: -tosa-make-broadcastable \ // RUN: -verify-diagnostics -o - | FileCheck %s // CHECK-NOT: oneflow oneflow.job @test_func(%arg0: tensor<1xf32>) -> tensor<1xf32> { oneflow.return %arg0 : tensor<1xf32> } oneflow.job @test_input(%arg0: tensor<1xf32>) -> tensor<1xf32> { %res = "oneflow.input"(%arg0) { data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018427412479 : i64, shape = [1 : si64] } : (tensor<1xf32>) -> tensor<1xf32> oneflow.return %res : tensor<1xf32> } oneflow.job @test_output(%arg0: tensor<1xf32>) -> tensor<1xf32> { %res = "oneflow.output"(%arg0) { data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018427412479 : i64, shape = [1 : si64] } : (tensor<1xf32>) -> tensor<1xf32> oneflow.return %res : tensor<1xf32> } oneflow.job @test_variable() -> tensor<64x3x7x7xf32> { %res = "oneflow.variable"() { data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], parallel = #sbp.parallel<[] -> [#sbp.B]>, op_name = "fw.model.conv1.weight", output_lbns = ["fw.model.conv1.weight/out"], scope_symbol_id = 4611686018427432959 : i64, shape = [64 : si64, 3 : si64, 7 : si64, 7 : si64] } : () -> tensor<64x3x7x7xf32> oneflow.return %res : tensor<64x3x7x7xf32> } oneflow.job @test_add_n2(%arg0: tensor<1x7x7xf32>, 
%arg1: tensor<1x7x7xf32>) -> tensor<1x7x7xf32> { %res = "oneflow.add_n2"(%arg0, %arg1) { device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "", op_type_name = "add_n", output_lbns = [""], scope_symbol_id = 4611686018431205375 : i64 } : (tensor<1x7x7xf32>, tensor<1x7x7xf32>) -> tensor<1x7x7xf32> oneflow.return %res: tensor<1x7x7xf32> } oneflow.job @test_broadcast_add(%arg0: tensor<1x1000xf32>, %arg1: tensor<1000xf32>) -> tensor<1x1000xf32> { %res = "oneflow.broadcast_add"(%arg0, %arg1) { device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018431234047 : i64 } : (tensor<1x1000xf32>, tensor<1000xf32>) -> tensor<1x1000xf32> oneflow.return %res : tensor<1x1000xf32> } oneflow.job @test_max_pool_2d(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> { %y, %indice = "oneflow.max_pool_2d"(%arg0) { ceil_mode = false, data_format = "channels_first", device_name = ["@0:0"], device_tag = "cpu", dilation = [1 : si32, 1 : si32], hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = "", output_lbns = ["", ""], padding = [1 : si32, 1 : si32], return_indices = false, scope_symbol_id = 4611686018427502591 : i64, stride = [2 : si32, 2 : si32] } : (tensor<1x64x112x112xf32>) -> (tensor<1x64x56x56xf32>, tensor<1x64x56x56xi64>) oneflow.return %y : tensor<1x64x56x56xf32> } oneflow.job @test_avg_pool_2d(%arg0: tensor<1x2048x7x7xf32>) -> tensor<1x2048x1x1xf32> { %res = "oneflow.avg_pool_2d"(%arg0) { ceil_mode = false, count_include_pad = true, data_format = "channels_first", device_name = ["@0:0"], device_tag = "cpu", divisor_override = 0 : si32, hierarchy = [1], kernel_size = [7 : si32, 7 : si32], op_name = "model.avgpool-avg_pool_2d-172", output_lbns = ["model.avgpool-avg_pool_2d-172/y_0"], padding = [0 : si32, 0 : si32], scope_symbol_id = 4611686018430775295 : i64, stride = [7 : si32, 7 : si32] } : (tensor<1x2048x7x7xf32>) -> tensor<1x2048x1x1xf32> oneflow.return %res: tensor<1x2048x1x1xf32> } oneflow.job @test_conv2d(%arg0: tensor<1x3x224x224xf32>, %arg1: tensor<5x3x1x1xf32>) -> tensor<1x5x224x224xf32> { %res = "oneflow.conv2d"(%arg0, %arg1) { data_format = "channels_first", device_name = ["@0:0"], device_tag = "cpu", dilation_rate = [1 : si32, 1 : si32], filters = 512 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [1 : si32, 1 : si32], op_name = "", operand_segment_sizes = array, output_lbns = [""], padding_before = [0 : si32, 0 : si32], scope_symbol_id = 4611686018431012863 : i64, strides = [1 : si32, 1 : si32] } : (tensor<1x3x224x224xf32>, tensor<5x3x1x1xf32>) -> tensor<1x5x224x224xf32> oneflow.return %res : tensor<1x5x224x224xf32> } oneflow.job @test_matmul(%arg0: tensor<1x2048xf32>, %arg1: tensor<1000x2048xf32>) ->tensor<1x1000xf32> { %res = "oneflow.matmul"(%arg0, %arg1) { alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018431234047 : i64, transpose_a = false, transpose_b = true } : (tensor<1x2048xf32>, tensor<1000x2048xf32>) -> tensor<1x1000xf32> oneflow.return %res : tensor<1x1000xf32> } oneflow.job @test_relu(%arg0: tensor<1xf32>) -> tensor<1xf32> { %res = "oneflow.relu"(%arg0) { device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018427424767 : i64 } : (tensor<1xf32>) -> tensor<1xf32> oneflow.return %res : tensor<1xf32> } oneflow.job @test_bn( %x: tensor<1x64x112x112xf32>, %moving_mean: tensor<64xf32>, %moving_variance: 
tensor<64xf32>, %gamma: tensor<64xf32>, %beta: tensor<64xf32>) -> tensor<1x64x112x112xf32> { %y, %mean, %inv_variance = "oneflow.normalization"(%x, %moving_mean, %moving_variance, %gamma, %beta) { axis = 1 : si32, device_name = ["@0:0"], device_tag = "cpu", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = "", operand_segment_sizes = array, output_lbns = ["", "", ""], result_segment_sizes = array, scope_symbol_id = 4611686018427453439 : i64, training = true } : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>) oneflow.return %y: tensor<1x64x112x112xf32> } oneflow.job @test_bn_infer( %x: tensor<1x64x112x112xf32>, %moving_mean: tensor<64xf32>, %moving_variance: tensor<64xf32>, %gamma: tensor<64xf32>, %beta: tensor<64xf32>) -> tensor<1x64x112x112xf32> { %y = "oneflow.normalization_infer"(%x, %moving_mean, %moving_variance, %gamma, %beta) { axis = 1 : si32, device_name = ["@0:0"], device_tag = "cpu", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = "", operand_segment_sizes = array, output_lbns = ["", "", ""], result_segment_sizes = array, scope_symbol_id = 4611686018427453439 : i64, training = true } : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32> oneflow.return %y: tensor<1x64x112x112xf32> } ================================================ FILE: oneflow/ir/test/Frontend/tosa_to_elf.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" \ // RUN: | oneflow-opt -cse \ // RUN: --linalg-fuse-elementwise-ops -empty-tensor-to-alloc-tensor -linalg-bufferize \ // RUN: -tensor-bufferize -func-bufferize -buffer-results-to-out-params \ // RUN: -convert-linalg-to-loops -convert-math-to-libm -convert-math-to-llvm -convert-scf-to-cf -convert-linalg-to-llvm \ // RUN: -convert-func-to-llvm -finalize-memref-to-llvm -reconcile-unrealized-casts --print-after-all \ // RUN: | oneflow-translate -mlir-to-llvmir builtin.module { func.func @Graph_0(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "tosa.cast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %1 = "tosa.tanh"(%0) : (tensor<2xf32>) -> tensor<2xf32> %2 = "tosa.cast"(%1) : (tensor<2xf32>) -> tensor<2xf32> func.return %2 : tensor<2xf32> } } ================================================ FILE: oneflow/ir/test/GPU/lit.local.cfg ================================================ if not config.WITH_MLIR_CUDA_CODEGEN: config.unsupported = True ================================================ FILE: oneflow/ir/test/GPU/nvvm_to_cubin.mlir ================================================ // RUN: oneflow-opt %s -pass-pipeline="builtin.module(gpu.module(nvvm-to-cubin))" | FileCheck %s // CHECK: .text.__nv_logf // CHECK-SAME: .text.__nv_expf module attributes {gpu.container_module, oneflow.mempool = 1 : i64} { func.func @JITOpGenerated0(%arg0: memref<1xi8>, %arg1: memref<5xi64>, %arg2: memref<1xf32>, %arg3: memref<5xf32>) attributes {llvm.emit_c_interface} { return } gpu.module @JITOpGenerated0_kernel { llvm.func @__nv_logf(f32) -> f32 llvm.func @__nv_expf(f32) -> f32 llvm.func @JITOpGenerated0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: !llvm.ptr, %arg13: !llvm.ptr, %arg14: i64, %arg15: i64, %arg16: i64, 
%arg17: !llvm.ptr, %arg18: !llvm.ptr, %arg19: i64, %arg20: i64, %arg21: i64, %arg22: i64, %arg23: i64, %arg24: !llvm.ptr, %arg25: !llvm.ptr, %arg26: i64, %arg27: i64, %arg28: i64, %arg29: i64, %arg30: i64, %arg31: !llvm.ptr, %arg32: !llvm.ptr, %arg33: i64, %arg34: i64, %arg35: i64, %arg36: i64, %arg37: i64, %arg38: !llvm.ptr, %arg39: !llvm.ptr, %arg40: i64, %arg41: i64, %arg42: i64, %arg43: i64, %arg44: i64, %arg45: !llvm.ptr, %arg46: !llvm.ptr, %arg47: i64, %arg48: i64, %arg49: i64, %arg50: i64, %arg51: i64) attributes {gpu.kernel, nvvm.kernel} { %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %3 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %4 = llvm.insertvalue %arg5, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %5 = llvm.insertvalue %arg6, %4[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %6 = llvm.insertvalue %arg7, %5[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %7 = llvm.insertvalue %arg8, %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %8 = llvm.insertvalue %arg12, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %9 = llvm.insertvalue %arg13, %8[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %10 = llvm.insertvalue %arg17, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %11 = llvm.insertvalue %arg18, %10[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %12 = llvm.insertvalue %arg19, %11[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %13 = llvm.insertvalue %arg20, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %14 = llvm.insertvalue %arg24, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %15 = llvm.insertvalue %arg25, %14[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %16 = llvm.insertvalue %arg26, %15[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %17 = llvm.insertvalue %arg27, %16[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %18 = llvm.insertvalue %arg31, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %19 = llvm.insertvalue %arg32, %18[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %20 = llvm.insertvalue %arg33, %19[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %21 = llvm.insertvalue %arg34, %20[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %22 = llvm.insertvalue %arg38, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %23 = llvm.insertvalue %arg39, %22[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %24 = llvm.insertvalue %arg40, %23[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %25 = llvm.insertvalue %arg41, %24[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %26 = llvm.insertvalue %arg45, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %27 = llvm.insertvalue %arg46, %26[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %28 = llvm.insertvalue %arg47, %27[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %29 = llvm.insertvalue %arg48, %28[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, 
array<2 x i64>)> %30 = llvm.mlir.constant(0 : index) : i64 %31 = llvm.mlir.constant(4000 : index) : i64 %32 = llvm.mlir.constant(1000 : index) : i64 %33 = llvm.mlir.constant(-1 : index) : i64 %34 = nvvm.read.ptx.sreg.ctaid.x : i32 %35 = llvm.sext %34 : i32 to i64 %36 = nvvm.read.ptx.sreg.ntid.x : i32 %37 = llvm.sext %36 : i32 to i64 %38 = nvvm.read.ptx.sreg.tid.x : i32 %39 = llvm.sext %38 : i32 to i64 %40 = llvm.mul %37, %35 : i64 %41 = llvm.add %39, %40 : i64 %42 = llvm.icmp "slt" %41, %31 : i64 llvm.cond_br %42, ^bb1, ^bb2 ^bb1: // pred: ^bb0 %43 = llvm.srem %41, %32 : i64 %44 = llvm.icmp "slt" %43, %30 : i64 %45 = llvm.add %43, %32 : i64 %46 = llvm.select %44, %45, %43 : i1, i64 %47 = llvm.icmp "slt" %41, %30 : i64 %48 = llvm.sub %33, %41 : i64 %49 = llvm.select %47, %48, %41 : i1, i64 %50 = llvm.sdiv %49, %32 : i64 %51 = llvm.sub %33, %50 : i64 %52 = llvm.select %47, %51, %50 : i1, i64 %53 = llvm.mul %52, %32 : i64 %54 = llvm.add %53, %46 : i64 %55 = llvm.getelementptr %arg18[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f16 %56 = llvm.load %55 : !llvm.ptr -> f16 %57 = llvm.getelementptr %arg6[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f16 %58 = llvm.load %57 : !llvm.ptr -> f16 %59 = llvm.getelementptr %arg1[%52] : (!llvm.ptr, i64) -> !llvm.ptr, f16 %60 = llvm.load %59 : !llvm.ptr -> f16 %61 = llvm.getelementptr %arg13[%52] : (!llvm.ptr, i64) -> !llvm.ptr, f16 %62 = llvm.load %61 : !llvm.ptr -> f16 %63 = llvm.getelementptr %arg25[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %64 = llvm.load %63 : !llvm.ptr -> f32 %65 = llvm.fpext %60 : f16 to f32 %66 = llvm.call @__nv_logf(%65) : (f32) -> f32 %67 = llvm.fptrunc %66 : f32 to f16 %68 = llvm.fsub %58, %67 : f16 %69 = llvm.fpext %68 : f16 to f32 %70 = llvm.call @__nv_expf(%69) : (f32) -> f32 %71 = llvm.fptrunc %70 : f32 to f16 %72 = llvm.fmul %71, %62 : f16 %73 = llvm.fsub %56, %72 : f16 %74 = llvm.fmul %69, %64 : f32 %75 = llvm.fpext %73 : f16 to f32 %76 = llvm.getelementptr %arg32[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f16 llvm.store %73, %76 : f16, !llvm.ptr %77 = llvm.getelementptr %arg39[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f32 llvm.store %74, %77 : f32, !llvm.ptr %78 = llvm.getelementptr %arg46[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f32 llvm.store %75, %78 : f32, !llvm.ptr llvm.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 llvm.return } } } ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/lit.local.cfg ================================================ if not config.BUILD_CUDA: config.unsupported = True ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_batchnorm_relu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np import os import oneflow as flow import oneflow.unittest def do_nhwc_bacth_norm(test_case, with_cuda): x = flow.randn(2, 3, 4, 5) bn = flow.nn.BatchNorm2d(3) if with_cuda: x = x.cuda() bn.to("cuda") eager_batch_norm_res = flow.relu(bn(x)) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.m = bn def build(self, x): return flow.relu(self.m(x)) graph_to_run = GraphToRun() lazy_batch_norm_res = graph_to_run(x) test_case.assertTrue( np.allclose( eager_batch_norm_res.numpy(), lazy_batch_norm_res.numpy(), rtol=1e-5, atol=1e-5, ) ) @flow.unittest.skip_unless_1n1d() class TestNhwcConv(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" def test_nhwc_conv_graph(test_case): import oneflow.sysconfig if oneflow.sysconfig.with_cuda(): do_nhwc_bacth_norm(test_case, True) # do_nhwc_bacth_norm(test_case, False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_bias_add.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np import os import oneflow as flow import oneflow.unittest def do_nhwc_bias_add(test_case, with_cuda): a = flow.randn(2, 3, 4, 5) b = flow.randn(3) if with_cuda: a = a.cuda() b = b.cuda() eager_bias_add_res = flow._C.bias_add(a, b, axis=1) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() def build(self, a, b): return flow._C.bias_add(a, b, axis=1) graph_to_run = GraphToRun() lazy_bias_add_res = graph_to_run(a, b) test_case.assertTrue( np.allclose( eager_bias_add_res.numpy(), lazy_bias_add_res.numpy(), rtol=1e-5, atol=1e-5 ) ) @flow.unittest.skip_unless_1n1d() class TestNhwcBiasAdd(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" def test_nhwc_bias_add_graph(test_case): import oneflow.sysconfig if oneflow.sysconfig.with_cuda(): do_nhwc_bias_add(test_case, True) do_nhwc_bias_add(test_case, False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_conv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np import os import oneflow as flow import oneflow.unittest def do_nhwc_conv(test_case, with_cuda, with_bias): x = flow.randn(2, 3, 4, 5) conv = flow.nn.Conv2d(3, 4, 2, 1, bias=with_bias) if with_cuda: x = x.cuda() conv.to("cuda") eager_conv_x = conv(x) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.conv = conv def build(self, x): return self.conv(x) graph_to_run = GraphToRun() lazy_conv_x = graph_to_run(x) test_case.assertTrue( np.allclose(eager_conv_x.numpy(), lazy_conv_x.numpy(), rtol=1e-5, atol=1e-5) ) @flow.unittest.skip_unless_1n1d() class TestNhwcConv(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" def test_nhwc_conv_graph(test_case): do_nhwc_conv(test_case, True, True) do_nhwc_conv(test_case, False, True) do_nhwc_conv(test_case, True, False) do_nhwc_conv(test_case, False, False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_conv2d_maxpool2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np import os import oneflow as flow import oneflow.unittest def do_nhwc_conv_maxpool(test_case, with_cuda, with_bias): x = flow.randn(2, 3, 4, 5) conv = flow.nn.Conv2d(3, 4, 2, 1, bias=with_bias) maxpool_2d = flow.nn.MaxPool2d( kernel_size=3, padding=1, stride=2, return_indices=False ) if with_cuda: x = x.cuda() conv.to("cuda") maxpool_2d.to("cuda") eager_x = maxpool_2d(conv(x)) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.conv = conv def build(self, x): return maxpool_2d(self.conv(x)) graph_to_run = GraphToRun() lazy_x = graph_to_run(x) test_case.assertTrue( np.allclose(eager_x.numpy(), lazy_x.numpy(), rtol=1e-5, atol=1e-5) ) @flow.unittest.skip_unless_1n1d() class TestNhwcConvMaxPool(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" def test_nhwc_conv_graph(test_case): do_nhwc_conv_maxpool(test_case, True, True) do_nhwc_conv_maxpool(test_case, True, False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_conv_relu_add.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np import os import oneflow as flow import oneflow.unittest def do_nhwc_conv(test_case, with_cuda, with_bias): x = flow.randn(2, 3, 4, 5) conv = flow.nn.Conv2d(3, 4, 2, 1, bias=with_bias) if with_cuda: x = x.cuda() conv.to("cuda") eager_conv_x = flow.relu(conv(x)) + flow.relu(conv(x)) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.conv = conv def build(self, x): return flow.relu(self.conv(x)) + flow.relu(self.conv(x)) graph_to_run = GraphToRun() lazy_conv_x = graph_to_run(x) print(eager_conv_x.numpy().flatten()[:10]) print(lazy_conv_x.numpy().flatten()[:10]) test_case.assertTrue( np.allclose(eager_conv_x.numpy(), lazy_conv_x.numpy(), rtol=1e-5, atol=1e-5) ) @flow.unittest.skip_unless_1n1d() class TestNhwcConv(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" def test_nhwc_conv_graph(test_case): do_nhwc_conv(test_case, True, True) do_nhwc_conv(test_case, False, True) do_nhwc_conv(test_case, True, False) do_nhwc_conv(test_case, False, False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_lenet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np import os import oneflow as flow import oneflow.unittest import oneflow.nn as nn import oneflow.nn.functional as F class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): out = F.relu(self.conv1(x)) out = F.max_pool2d(out, 2) out = F.relu(self.conv2(out)) out = F.max_pool2d(out, 2) out = out.view(out.size(0), -1) out = F.relu(self.fc1(out)) out = F.relu(self.fc2(out)) out = self.fc3(out) return out def do_lenet(test_case, with_cuda): x = flow.randn(2, 3, 32, 32) lenet = LeNet() if with_cuda: x = x.cuda() lenet.to("cuda") eager_res = lenet(x) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.lenet = lenet def build(self, x): return self.lenet(x) graph_to_run = GraphToRun() lazy_res = graph_to_run(x) test_case.assertTrue( np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-5, atol=1e-5) ) @flow.unittest.skip_unless_1n1d() class TestLeNet(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" def test_nhwc_lenet_graph(test_case): do_lenet(test_case, True) do_lenet(test_case, False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_maxpool_2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np import os import oneflow as flow import oneflow.unittest def do_nhwc_maxpool_2d(test_case, with_cuda, with_return_induces): x = flow.randn(1, 4, 4, 4) maxpool_2d = flow.nn.MaxPool2d( kernel_size=3, padding=1, stride=3, return_indices=with_return_induces ) if with_cuda: x = x.cuda() maxpool_2d.to("cuda") eager_maxpool_2d_res = maxpool_2d(x) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.m = maxpool_2d def build(self, x): return self.m(x) graph_to_run = GraphToRun() lazy_maxpool_2d_res = graph_to_run(x) if with_return_induces: test_case.assertTrue( np.allclose( eager_maxpool_2d_res[0].numpy(), lazy_maxpool_2d_res[0].numpy(), rtol=1e-5, atol=1e-5, ) ) else: test_case.assertTrue( np.allclose( eager_maxpool_2d_res.numpy(), lazy_maxpool_2d_res.numpy(), rtol=1e-5, atol=1e-5, ) ) @flow.unittest.skip_unless_1n1d() class TestNhwcMaxPool2d(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" def test_nhwc_maxpool_2d_graph(test_case): do_nhwc_maxpool_2d(test_case, True, True) do_nhwc_maxpool_2d(test_case, True, False) do_nhwc_maxpool_2d(test_case, False, True) do_nhwc_maxpool_2d(test_case, False, False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_resnet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np from typing import Type, Any, Callable, Union, List, Optional import os import oneflow as flow import oneflow.unittest from oneflow import Tensor import oneflow.nn as nn __all__ = [ "ResNet", "resnet50", ] def conv3x3( in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 ) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation, ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion: int = 1 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Bottleneck(nn.Module): # Bottleneck in flowvision places the stride for downsampling at 3x3 convolution(self.conv2) # while original implementation places the stride at the first 1x1 convolution(self.conv1) # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
expansion: int = 4 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( "replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation) ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer( block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] ) self.layer3 = self._make_layer( block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] ) self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] def _make_layer( self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False, ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append( block( self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, ) ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block( self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, ) ) return nn.Sequential(*layers) def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = flow.flatten(x, 1) x = self.fc(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def _resnet( arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], pretrained: bool, progress: bool, **kwargs: Any ) -> ResNet: model = ResNet(block, layers, **kwargs) return model def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def do_resnet(test_case): x = flow.randn(2, 3, 224, 224) resnet = resnet50() x = x.cuda() resnet.to("cuda") eager_res = resnet(x) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.resnet = resnet def build(self, x): return self.resnet(x) graph_to_run = GraphToRun() lazy_res = graph_to_run(x) test_case.assertTrue( # TODO(yuhao): High precision loss np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-4, atol=1e-1) ) @flow.unittest.skip_unless_1n1d() class TestResNet(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" def test_nhwc_resnet_graph(test_case): do_resnet(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_transpose_eliminate.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.transpose import unittest import numpy as np import os import oneflow as flow import oneflow.unittest def do_eliminate_transpose(test_case, with_cuda): x = flow.randn(2, 3, 4, 5) if with_cuda: x = x.cuda() eager_res = flow.permute(flow.permute(x, (0, 2, 3, 1)), (0, 3, 1, 2)) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): return flow.permute(flow.permute(x, (0, 2, 3, 1)), (0, 3, 1, 2)) graph_to_run = GraphToRun() lazy_res = graph_to_run(x) test_case.assertTrue( np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-5, atol=1e-5) ) @flow.unittest.skip_unless_1n1d() class TestNhwcEliminateTranspose(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" def test_eliminate_transpose(test_case): do_eliminate_transpose(test_case, True) do_eliminate_transpose(test_case, False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/auto_nhwc/test_resnet101_benchmark.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import unittest import numpy as np import time import datetime from typing import Type, Any, Callable, Union, List, Optional import os os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" import oneflow as flow import oneflow.unittest from oneflow import Tensor import oneflow.nn as nn __all__ = [ "ResNet", "resnet50", ] def conv3x3( in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 ) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation, ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion: int = 1 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Bottleneck(nn.Module): # Bottleneck in flowvision places the stride for downsampling at 3x3 convolution(self.conv2) # while original implementation places the stride at the first 1x1 convolution(self.conv1) # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
expansion: int = 4 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( "replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation) ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer( block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] ) self.layer3 = self._make_layer( block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] ) self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] def _make_layer( self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False, ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append( block( self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, ) ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block( self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, ) ) return nn.Sequential(*layers) def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = flow.flatten(x, 1) x = self.fc(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def _resnet( arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], pretrained: bool, progress: bool, **kwargs: Any ) -> ResNet: model = ResNet(block, layers, **kwargs) return model def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: """ Constructs the ResNet-101 model. .. note:: `Deep Residual Learning for Image Recognition `_. Args: pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` For example: .. 
code-block:: python >>> import flowvision >>> resnet101 = flowvision.models.resnet101(pretrained=False, progress=True) """ return _resnet( "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs ) def bench(forward: Callable, x, n=1000): # warm up for _ in range(5): output = forward(x) res = output.numpy() flow._oneflow_internal.profiler.RangePush("eval begin") start_time = time.time() for _ in range(n): flow._oneflow_internal.profiler.RangePush("forward") output = forward(x) flow._oneflow_internal.profiler.RangePop() flow._oneflow_internal.profiler.RangePush("numpy") res = output.numpy() flow._oneflow_internal.profiler.RangePop() flow._oneflow_internal.profiler.RangePop() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print(total_time_str) class ResNetEvalGraph(nn.Graph): def __init__(self, model): super().__init__() self.model = model self.config.enable_amp(True) def build(self, x): y_pred = self.model(x) return y_pred def main(): np.random.seed(42) device = oneflow.device("cuda") model = resnet101() model.eval() model.to(device) batch_size = 64 x = oneflow.randn(batch_size, 3, 224, 224).to(oneflow.device("cuda")) model_graph = ResNetEvalGraph(model) bench(model_graph, x, n=10) if __name__ == "__main__": main() ================================================ FILE: oneflow/ir/test/OneFlow/conversion/lower_to_tosa.mlir ================================================ // RUN: oneflow-opt \ // RUN: -lower-oneflow-to-tosa \ // RUN: -tosa-make-broadcastable \ // RUN: --print-after-all %s module { func.func @Cast_1__FUSE__ScalarMulByTensor_2(%arg0: tensor<96x96xi64>, %arg1: tensor<1xf32>) -> tensor<96x96xf32> { %0 = "oneflow.cast"(%arg0) {device_name = ["0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_1", op_type_name = "cast", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32> %1 = "oneflow.scalar_mul_by_tensor"(%0, %arg1) {device_name = ["0:0"], device_tag = "cpu", hierarchy = [1], op_name = "ScalarMulByTensor_2", op_type_name = "scalar_mul_by_tensor", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xf32>, tensor<1xf32>) -> tensor<96x96xf32> return %1 : tensor<96x96xf32> } } ================================================ FILE: oneflow/ir/test/OneFlow/conversion/lower_to_tosa_signed.mlir ================================================ // RUN: oneflow-opt -convert-to-signless-for-tosa --mlir-print-ir-before-all --mlir-print-ir-after-all \ // RUN: -lower-oneflow-to-tosa \ // RUN: -tosa-make-broadcastable \ // RUN: -reconcile-unrealized-casts --print-after-all %s module { func.func @test(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xsi64> { %1, %indice = "oneflow.max_pool_2d"(%arg0) { ceil_mode = false, data_format = "channels_first", device_name = ["@0:0"], device_tag = "cpu", dilation = [1 : si32, 1 : si32], hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = "model.maxpool-max_pool_2d-3", padding = [1 : si32, 1 : si32], return_indices = false, scope_symbol_id = 49 : i64, stride = [2 : si32, 2 : si32] } : (tensor<1x64x112x112xf32>) -> (tensor<1x64x56x56xf32>, tensor<1x64x56x56xsi64>) return %indice : tensor<1x64x56x56xsi64> } } ================================================ FILE: oneflow/ir/test/OneFlow/conversion/oneflow_to_tosa.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -split-input-file \ // RUN: -auto-nhwc \ // RUN: -lower-oneflow-to-tosa \ // RUN: -verify-diagnostics -o - \ 
// RUN: | FileCheck %s // CHECK-LABEL: test_func // CHECK: return [[V0:%.+]] : tensor<1xf32> oneflow.job @test_func(%arg0: tensor<1xf32>) -> tensor<1xf32> { oneflow.return %arg0 : tensor<1xf32> } // CHECK-LABEL: test_input // CHECK: return [[V0:%.+]] : tensor<1xf32> oneflow.job @test_input(%arg0: tensor<1xf32>) -> tensor<1xf32> { %res = "oneflow.input"(%arg0) { data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018427412479 : i64, shape = [1 : si64] } : (tensor<1xf32>) -> tensor<1xf32> oneflow.return %res : tensor<1xf32> } // CHECK-LABEL: test_output // CHECK: return [[V0:%.+]] : tensor<1xf32> oneflow.job @test_output(%arg0: tensor<1xf32>) -> tensor<1xf32> { %res = "oneflow.output"(%arg0) { data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018427412479 : i64, shape = [1 : si64] } : (tensor<1xf32>) -> tensor<1xf32> oneflow.return %res : tensor<1xf32> } // CHECK-LABEL: test_variable // CHECK: [[V0:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<0.000000e+00> : tensor<64x3x7x7xf32>} // CHECK: return [[V0]] : tensor<64x3x7x7xf32> oneflow.job @test_variable() -> tensor<64x3x7x7xf32> { %res = "oneflow.variable"() { data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], parallel = #sbp.parallel<[] -> [#sbp.B]>, op_name = "fw.model.conv1.weight", output_lbns = ["fw.model.conv1.weight/out"], scope_symbol_id = 4611686018427432959 : i64, shape = [64 : si64, 3 : si64, 7 : si64, 7 : si64] } : () -> tensor<64x3x7x7xf32> oneflow.return %res : tensor<64x3x7x7xf32> } // CHECK-LABEL: test_add_n2 // CHECK: [[V0:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<1x7x7xf32>, tensor<1x7x7xf32>) -> tensor<1x7x7xf32> // CHECK: return [[V0]] : tensor<1x7x7xf32> oneflow.job @test_add_n2(%arg0: tensor<1x7x7xf32>, %arg1: tensor<1x7x7xf32>) -> tensor<1x7x7xf32> { %res = "oneflow.add_n2"(%arg0, %arg1) { device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "", op_type_name = "add_n", output_lbns = [""], scope_symbol_id = 4611686018431205375 : i64 } : (tensor<1x7x7xf32>, tensor<1x7x7xf32>) -> tensor<1x7x7xf32> oneflow.return %res: tensor<1x7x7xf32> } //CHECK-LABEL: test_broadcast_add //CHECK: [[V0:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<1x1000xf32>, tensor<1000xf32>) -> tensor<1x1000xf32> //CHECK: return [[V0]] : tensor<1x1000xf32> oneflow.job @test_broadcast_add(%arg0: tensor<1x1000xf32>, %arg1: tensor<1000xf32>) -> tensor<1x1000xf32> { %res = "oneflow.broadcast_add"(%arg0, %arg1) { device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018431234047 : i64 } : (tensor<1x1000xf32>, tensor<1000xf32>) -> tensor<1x1000xf32> oneflow.return %res : tensor<1x1000xf32> } // CHECK-LABEL: test_max_pool_2d // CHECK: [[V0:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} // CHECK: [[V1:%.+]] = "tosa.transpose"(%arg0, [[V0]]) : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32> // CHECK: [[V2:%.+]] = "tosa.max_pool2d"([[V1]]) // CHECK-SAME {kernel = array, pad = array, stride = array} // CHECK: [[V3:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} // CHECK: [[V4:%.+]] = "tosa.transpose"([[V2]], [[V3]]) : (tensor<1x56x56x64xf32>, tensor<4xi32>) -> tensor<1x64x56x56xf32> // CHECK: [[V5:%.+]] = 
"tosa.const"() // CHECK-SAME: {value = dense<0> : tensor<1x64x56x56xi64>} // CHECK: return [[V4]] : tensor<1x64x56x56xf32> oneflow.job @test_max_pool_2d(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> { %y, %indice = "oneflow.max_pool_2d"(%arg0) { ceil_mode = false, data_format = "channels_first", device_name = ["@0:0"], device_tag = "cpu", dilation = [1 : si32, 1 : si32], hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = "", output_lbns = ["", ""], padding = [1 : si32, 1 : si32], return_indices = false, scope_symbol_id = 4611686018427502591 : i64, stride = [2 : si32, 2 : si32] } : (tensor<1x64x112x112xf32>) -> (tensor<1x64x56x56xf32>, tensor<1x64x56x56xi64>) oneflow.return %y : tensor<1x64x56x56xf32> } // CHECK-LABEL: test_avg_pool_2d // CHECK: [[V0:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} // CHECK: [[V1:%.+]] = "tosa.transpose"(%arg0, [[V0]]) : (tensor<1x2048x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x2048xf32> // CHECK: [[V2:%.+]] = "tosa.avg_pool2d"([[V1]]) // CHECK-SAME: {kernel = array, pad = array, stride = array} // CHECK: [[V3:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} // CHECK: [[V4:%.+]] = "tosa.transpose"([[V2]], [[V3]]) : (tensor<1x1x1x2048xf32>, tensor<4xi32>) -> tensor<1x2048x1x1xf32> // CHECK: return [[V4]] : tensor<1x2048x1x1xf32> oneflow.job @test_avg_pool_2d(%arg0: tensor<1x2048x7x7xf32>) -> tensor<1x2048x1x1xf32> { %res = "oneflow.avg_pool_2d"(%arg0) { ceil_mode = false, count_include_pad = true, data_format = "channels_first", device_name = ["@0:0"], device_tag = "cpu", divisor_override = 0 : si32, hierarchy = [1], kernel_size = [7 : si32, 7 : si32], op_name = "model.avgpool-avg_pool_2d-172", output_lbns = ["model.avgpool-avg_pool_2d-172/y_0"], padding = [0 : si32, 0 : si32], scope_symbol_id = 4611686018430775295 : i64, stride = [7 : si32, 7 : si32] } : (tensor<1x2048x7x7xf32>) -> tensor<1x2048x1x1xf32> oneflow.return %res: tensor<1x2048x1x1xf32> } // CHECK-LABEL: test_conv2d // CHECK: [[V0:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<0.000000e+00> : tensor<5xf32>} // CHECK: [[V1:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} // CHECK: [[V2:%.+]] = "tosa.transpose"(%arg0, [[V1]]) : (tensor<1x3x224x224xf32>, tensor<4xi32>) -> tensor<1x224x224x3xf32> // CHECK: [[V3:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} // CHECK: [[V4:%.+]] = "tosa.transpose"(%arg1, [[V3]]) : (tensor<5x3x1x1xf32>, tensor<4xi32>) -> tensor<5x1x1x3xf32> // CHECK: [[V5:%.+]] = "tosa.conv2d"([[V2]], [[V4]], [[V0]]) // CHECK-SAME: {dilation = array, pad = array, stride = array} // CHECK: [[V6:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} // CHECK: [[V7:%.+]] = "tosa.transpose"([[V5]], [[V6]]) : (tensor<1x224x224x5xf32>, tensor<4xi32>) -> tensor<1x5x224x224xf32> // CHECK: return [[V7]] : tensor<1x5x224x224xf32> oneflow.job @test_conv2d(%arg0: tensor<1x3x224x224xf32>, %arg1: tensor<5x3x1x1xf32>) -> tensor<1x5x224x224xf32> { %res = "oneflow.conv2d"(%arg0, %arg1) { data_format = "channels_first", device_name = ["@0:0"], device_tag = "cpu", dilation_rate = [1 : si32, 1 : si32], filters = 512 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [1 : si32, 1 : si32], op_name = "", operand_segment_sizes = array, output_lbns = [""], padding_before = [0 : si32, 0 : si32], scope_symbol_id = 4611686018431012863 : i64, strides = [1 : si32, 1 : si32] } : (tensor<1x3x224x224xf32>, 
tensor<5x3x1x1xf32>) -> tensor<1x5x224x224xf32> oneflow.return %res : tensor<1x5x224x224xf32> } // CHECK-LABEL: test_matmul // CHECK: [[V0:%.+]] = "tosa.reshape"(%arg0) // CHECK: [[V1:%.+]] = "tosa.const"() // CHECK-SAME: {value = dense<[1, 0]> : tensor<2xi32>} // CHECK: [[V2:%.+]] = "tosa.transpose"(%arg1, [[V1]]) : (tensor<1000x2048xf32>, tensor<2xi32>) -> tensor<2048x1000xf32> // CHECK: [[V3:%.+]] = "tosa.reshape"([[V2]]) // CHECK: [[V4:%.+]] = "tosa.matmul"([[V0]], [[V3]]) : (tensor<1x1x2048xf32>, tensor<1x2048x1000xf32>) -> tensor<1x1x1000xf32> // CHECK: [[V5:%.+]] = "tosa.reshape"([[V4]]) // CHECK: return [[V5]] : tensor<1x1000xf32> oneflow.job @test_matmul(%arg0: tensor<1x2048xf32>, %arg1: tensor<1000x2048xf32>) ->tensor<1x1000xf32> { %res = "oneflow.matmul"(%arg0, %arg1) { alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018431234047 : i64, transpose_a = false, transpose_b = true } : (tensor<1x2048xf32>, tensor<1000x2048xf32>) -> tensor<1x1000xf32> oneflow.return %res : tensor<1x1000xf32> } // CHECK-LABEL: test_relu // CHECK: [[V0:%.+]] = "tosa.maximum"([[V1:%.+]], [[V2:%.+]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: return [[V0]] : tensor<1xf32> oneflow.job @test_relu(%arg0: tensor<1xf32>) -> tensor<1xf32> { %res = "oneflow.relu"(%arg0) { device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "", output_lbns = [""], scope_symbol_id = 4611686018427424767 : i64 } : (tensor<1xf32>) -> tensor<1xf32> oneflow.return %res : tensor<1xf32> } // CHECK-LABEL: test_bn // CHECK: "tosa.sub" // CHECK: "tosa.add" // CHECK: "tosa.rsqrt" // CHECK: "tosa.mul" // CHECK: "tosa.mul" // CHECK: "tosa.add" oneflow.job @test_bn( %x: tensor<1x64x112x112xf32>, %moving_mean: tensor<64xf32>, %moving_variance: tensor<64xf32>, %gamma: tensor<64xf32>, %beta: tensor<64xf32>) -> tensor<1x64x112x112xf32> { %y, %mean, %inv_variance = "oneflow.normalization"(%x, %moving_mean, %moving_variance, %gamma, %beta) { axis = 1 : si32, device_name = ["@0:0"], device_tag = "cpu", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = "", operand_segment_sizes = array, output_lbns = ["", "", ""], result_segment_sizes = array, scope_symbol_id = 4611686018427453439 : i64, training = true } : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>) oneflow.return %y: tensor<1x64x112x112xf32> } // CHECK-LABEL: test_bn_infer // CHECK: "tosa.sub" // CHECK: "tosa.add" // CHECK: "tosa.rsqrt" // CHECK: "tosa.mul" // CHECK: "tosa.mul" // CHECK: "tosa.add" oneflow.job @test_bn_infer( %x: tensor<1x64x112x112xf32>, %moving_mean: tensor<64xf32>, %moving_variance: tensor<64xf32>, %gamma: tensor<64xf32>, %beta: tensor<64xf32>) -> tensor<1x64x112x112xf32> { %y = "oneflow.normalization_infer"(%x, %moving_mean, %moving_variance, %gamma, %beta) { axis = 1 : si32, device_name = ["@0:0"], device_tag = "cpu", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = "", operand_segment_sizes = array, output_lbns = ["", "", ""], result_segment_sizes = array, scope_symbol_id = 4611686018427453439 : i64, training = true } : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32> oneflow.return %y: tensor<1x64x112x112xf32> } ================================================ FILE: 
oneflow/ir/test/OneFlow/cse.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -cse-with-attributes-ignored -cse -cse-put-attributes -canonicalize | FileCheck %s module { func.func @Cast_1__FUSE__ScalarMulByTensor_2(%arg0: tensor<96x96xi64>) -> tensor<96x96xf32> { %0 = "oneflow.cast"(%arg0) {device_name = ["0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_1", op_type_name = "cast", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32> %1 = "oneflow.cast"(%arg0) {device_name = ["0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_2", op_type_name = "cast", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32> %2 = "oneflow.add_n"(%0, %1) {device_name = ["0:0"], device_tag = "cpu", hierarchy = [1], op_name = "ScalarMulByTensor_2", op_type_name = "add_n", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xf32>, tensor<96x96xf32>) -> tensor<96x96xf32> // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.cast" // CHECK: "oneflow.add_n2"(%[[OUT]], %[[OUT]]) // CHECK: op_name = "ScalarMulByTensor_2" return %2 : tensor<96x96xf32> } func.func @f2(%input: tensor<2x64x64x320xf16>, %w: tensor<320x320x3x3xf16>, %bias: tensor<320xf16>) -> (tensor<2x64x64x320xf16>, tensor<2x64x64x320xf16>) { %transpose_w = "oneflow.transpose"(%w) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.resnets.0.conv1-conv2d-31_transpose_input_1", perm = [0 : si32, 2 : si32, 3 : si32, 1 : si32], scope_symbol_id = 163 : i64} : (tensor<320x320x3x3xf16>) -> tensor<320x3x3x320xf16> %conv2d = "oneflow.conv2d"(%input, %transpose_w, %bias) {data_format = "channels_last", device_name = ["@0:0"], device_tag = "cuda", dilation_rate = [1 : si32, 1 : si32], filters = 320 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = "unet.down_blocks.0.resnets.0.conv1-conv2d-31", operand_segment_sizes = array, padding_before = [1 : si32, 1 : si32], scope_symbol_id = 163 : i64, strides = [1 : si32, 1 : si32]} : (tensor<2x64x64x320xf16>, tensor<320x3x3x320xf16>, tensor<320xf16>) -> tensor<2x64x64x320xf16> %transpose_w1 = "oneflow.transpose"(%w) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.resnets.0.conv1-conv2d-31_transpose_input_2", perm = [0 : si32, 2 : si32, 3 : si32, 1 : si32], scope_symbol_id = 163 : i64} : (tensor<320x320x3x3xf16>) -> tensor<320x3x3x320xf16> %conv2d_2 = "oneflow.conv2d"(%input, %transpose_w1, %bias) {data_format = "channels_last", device_name = ["@0:0"], device_tag = "cuda", dilation_rate = [1 : si32, 1 : si32], filters = 320 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = "unet.down_blocks.0.resnets.0.conv1-conv2d-31", operand_segment_sizes = array, padding_before = [1 : si32, 1 : si32], scope_symbol_id = 163 : i64, strides = [1 : si32, 1 : si32]} : (tensor<2x64x64x320xf16>, tensor<320x3x3x320xf16>, tensor<320xf16>) -> tensor<2x64x64x320xf16> return %conv2d, %conv2d_2 : tensor<2x64x64x320xf16>, tensor<2x64x64x320xf16> // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.conv2d" // CHECK: scope_symbol_id = 163 : i64 // CHECK: return %[[OUT]], %[[OUT]] } } ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/gpu_copy_arg.mlir ================================================ // RUN: oneflow-opt %s -lower-oneflow-to-tosa -tosa-make-broadcastable \ // RUN: | oneflow-opt 
-pass-pipeline="builtin.module(func.func(tosa-to-linalg))" \ // RUN: | oneflow-opt -cse --linalg-fuse-elementwise-ops -linalg-bufferize -convert-linalg-to-parallel-loops -gpu-map-parallel-loops \ // RUN: -convert-parallel-loops-to-gpu -gpu-kernel-outlining -buffer-host-register -canonicalize \ // RUN: | oneflow-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,lower-affine,convert-gpu-to-nvvm,gpu-to-cubin))' \ // RUN: | oneflow-opt --func-bufferize -buffer-results-to-out-params -gpu-copy-arg func.func @Cast_289__FUSE__ScalarMulByTensor_290(%arg0: tensor<3x3xi64>, %arg1: tensor<1xf32>) -> tensor<3x3xf32> { %0 = "oneflow.cast"(%arg0) {device_name = ["@0:0"], device_tag = "cuda", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_289", output_lbns = ["Cast_289/out_0"], scope_symbol_id = 4611686018427478014 : i64} : (tensor<3x3xi64>) -> tensor<3x3xf32> %1 = "oneflow.scalar_mul_by_tensor"(%0, %arg1) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "ScalarMulByTensor_290", output_lbns = ["ScalarMulByTensor_290/y_0"], scope_symbol_id = 4611686018427478014 : i64} : (tensor<3x3xf32>, tensor<1xf32>) -> tensor<3x3xf32> return %1 : tensor<3x3xf32> } // CHECK: gpu.memcpy %arg2 ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/lit.local.cfg ================================================ if not config.WITH_MLIR_CUDA_CODEGEN: config.unsupported = True ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/test_append_oneflow_stream.mlir ================================================ // RUN: oneflow-opt %s -append-ofstream | FileCheck %s // CHECK: func.func @JITOpGenerated0(%arg0: memref<1xf32>, %arg1: memref<5xi64>, %arg2: memref<5xf32>, %arg3: !llvm.ptr) attributes {llvm.emit_c_interface} module attributes {gpu.container_module} { func.func @JITOpGenerated0(%arg0: memref<1xf32>, %arg1: memref<5xi64>, %arg2: memref<5xf32>) attributes {llvm.emit_c_interface} { %c5 = arith.constant 5 : index %c1 = arith.constant 1 : index %collapse_shape = memref.collapse_shape %arg0 [] : memref<1xf32> into memref gpu.launch_func @JITOpGenerated0_kernel::@JITOpGenerated0_kernel blocks in (%c5, %c1, %c1) threads in (%c1, %c1, %c1) args(%arg1 : memref<5xi64>, %collapse_shape : memref, %arg2 : memref<5xf32>) return } gpu.module @JITOpGenerated0_kernel attributes {gpu.binary = ""} { llvm.func @JITOpGenerated0_kernel(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: !llvm.ptr, %arg9: !llvm.ptr, %arg10: i64, %arg11: i64, %arg12: i64) attributes {gpu.kernel, gpu.known_block_size = array, nvvm.kernel} { %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr, ptr, i64)> %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr, ptr, i64)> %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr, ptr, i64)> %10 = 
llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %11 = llvm.insertvalue %arg8, %10[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %12 = llvm.insertvalue %arg9, %11[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %13 = llvm.insertvalue %arg10, %12[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %14 = llvm.insertvalue %arg11, %13[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %15 = llvm.insertvalue %arg12, %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %16 = nvvm.read.ptx.sreg.ctaid.x : i32 %17 = llvm.sext %16 : i32 to i64 %18 = llvm.extractvalue %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %19 = llvm.getelementptr %18[%17] : (!llvm.ptr, i64) -> !llvm.ptr %20 = llvm.load %19 : !llvm.ptr %21 = llvm.extractvalue %9[1] : !llvm.struct<(ptr, ptr, i64)> %22 = llvm.load %21 : !llvm.ptr %23 = llvm.sitofp %20 : i64 to f32 %24 = llvm.fmul %23, %22 : f32 %25 = llvm.extractvalue %15[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %26 = llvm.getelementptr %25[%17] : (!llvm.ptr, i64) -> !llvm.ptr llvm.store %24, %26 : !llvm.ptr llvm.return } } } ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/test_cast_ops_to_signless.mlir ================================================ // RUN: oneflow-opt %s -cast-ofops-to-signless | FileCheck %s // CHECK: unrealized_conversion_cast func.func @Cast_289__FUSE__ScalarMulByTensor_290() -> tensor<512x2048x1x1xf32> { %output_299 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "resnet.layer4.2.conv1.weight", output_lbns = ["resnet.layer4.2.conv1.weight/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 1995 : i64, shape = [512 : si64, 2048 : si64, 1 : si64, 1 : si64]} : () -> tensor<512x2048x1x1xsi64> %0 = "oneflow.cast"(%output_299) {device_name = ["0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_1", op_type_name = "cast", scope_symbol_id = 4611686018427416574 : i64} : (tensor<512x2048x1x1xsi64>) -> tensor<512x2048x1x1xf32> func.return %0 : tensor<512x2048x1x1xf32> } ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/test_fold_alloc_to_subview.mlir ================================================ // RUN: oneflow-opt %s -fold-alloc-to-subview #map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)> #map1 = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> module attributes {gpu.container_module} { func.func @JITOpGenerated0(%arg0: memref<1xf32>, %arg1: memref<5xi64>, %arg2: memref<5xf32>) attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c5 = arith.constant 5 : index %c1 = arith.constant 1 : index %collapse_shape = memref.collapse_shape %arg0 [] : memref<1xf32> into memref %alloc = memref.alloc() {alignment = 64 : i64} : memref<5xf32> // CHECK-NOT: %alloc = memref.alloc() {alignment = 64 : i64} : memref<5xf32> // CHECK: memref.alloc() : memref<512xi8> // CHECK: memref.view %c1_0 = arith.constant 1 : index %0 = affine.apply #map(%c5)[%c0, %c1] gpu.launch_func @JITOpGenerated0_kernel::@JITOpGenerated0_kernel blocks in (%0, %c1_0, %c1_0) threads in (%c1_0, %c1_0, %c1_0) args(%arg1 : memref<5xi64>, %alloc : memref<5xf32>) %c1_2 = arith.constant 1 : index %1 = affine.apply #map(%c5)[%c0, %c1] gpu.launch_func @JITOpGenerated0_kernel_0::@JITOpGenerated0_kernel blocks in (%1, %c1_2, %c1_2) 
threads in (%c1_2, %c1_2, %c1_2) args(%alloc : memref<5xf32>, %collapse_shape : memref, %arg2 : memref<5xf32>) return } gpu.module @JITOpGenerated0_kernel { gpu.func @JITOpGenerated0_kernel(%arg0: memref<5xi64>, %arg1: memref<5xf32>) kernel attributes {gpu.known_block_size = array} { %0 = gpu.block_id x %1 = gpu.block_id y %2 = gpu.block_id z %3 = gpu.thread_id x %4 = gpu.thread_id y %5 = gpu.thread_id z %6 = gpu.grid_dim x %7 = gpu.grid_dim y %8 = gpu.grid_dim z %9 = gpu.block_dim x %10 = gpu.block_dim y %11 = gpu.block_dim z cf.br ^bb1 ^bb1: // pred: ^bb0 %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %12 = affine.apply #map1(%0)[%c1, %c0] %13 = memref.load %arg0[%12] : memref<5xi64> %14 = arith.sitofp %13 : i64 to f32 memref.store %14, %arg1[%12] : memref<5xf32> gpu.return } } gpu.module @JITOpGenerated0_kernel_0 { gpu.func @JITOpGenerated0_kernel(%arg0: memref<5xf32>, %arg1: memref, %arg2: memref<5xf32>) kernel attributes {gpu.known_block_size = array} { %0 = gpu.block_id x %1 = gpu.block_id y %2 = gpu.block_id z %3 = gpu.thread_id x %4 = gpu.thread_id y %5 = gpu.thread_id z %6 = gpu.grid_dim x %7 = gpu.grid_dim y %8 = gpu.grid_dim z %9 = gpu.block_dim x %10 = gpu.block_dim y %11 = gpu.block_dim z cf.br ^bb1 ^bb1: // pred: ^bb0 %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %12 = affine.apply #map1(%0)[%c1, %c0] %13 = memref.load %arg0[%12] : memref<5xf32> %14 = memref.load %arg1[] : memref %15 = arith.mulf %13, %14 : f32 memref.store %15, %arg2[%12] : memref<5xf32> gpu.return } } } ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/test_fuser_cast_scale.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: jit import unittest import numpy as np import os import oneflow as flow import oneflow.unittest class CastModule(flow.nn.Module): def __init__(self): super().__init__() def forward(self, x, scale): # TODO: also support scale as a scalar, for instance: scale = 7.7 return x.to(dtype=flow.float32) * scale def do_relu_graph(test_case, data, with_cuda): x = flow.tensor(data, dtype=flow.int64) scale = flow.tensor([7.7], dtype=flow.float32) if with_cuda: x = x.cuda() scale = scale.cuda() module_to_run = CastModule() y_eager = module_to_run(x, scale) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.fw = module_to_run def build(self, x, scale): return self.fw(x, scale) graph_to_run = GraphToRun() y_lazy = graph_to_run(x, scale) test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy())) @flow.unittest.skip_unless_1n1d() class TestFuseCastScale(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" def test_relu_graph(test_case): import oneflow.sysconfig if oneflow.sysconfig.with_cuda(): do_relu_graph(test_case, np.array([2.0, 1.0, 0.0, -1.0, -2.0]), True) do_relu_graph( test_case, np.array([[2.0, 1.0, 0.0, -1.0, -2.0], [2.0, 1.0, 0.0, -1.0, -2.0]]), False, ) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/test_gpu_all_reduce.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: | oneflow-opt -gpu-kernel-outlining \ // RUN: | oneflow-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin))' \ // RUN: | oneflow-opt -gpu-to-llvm \ // RUN: | oneflow-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ // RUN: --entry-point-result=void \ // RUN: | FileCheck %s func.func @main() { %data = memref.alloc() : memref<2x6xi32> %sum = memref.alloc() : memref<2xi32> %cst0 = arith.constant 0 : i32 %cst1 = arith.constant 1 : i32 %cst2 = arith.constant 2 : i32 %cst4 = arith.constant 4 : i32 %cst8 = arith.constant 8 : i32 %cst16 = arith.constant 16 : i32 %cst3 = arith.constant 3 : i32 %cst6 = arith.constant 6 : i32 %cst7 = arith.constant 7 : i32 %cst10 = arith.constant 10 : i32 %cst11 = arith.constant 11 : i32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index %c4 = arith.constant 4 : index %c5 = arith.constant 5 : index %c6 = arith.constant 6 : index %cast_data = memref.cast %data : memref<2x6xi32> to memref<*xi32> gpu.host_register %cast_data : memref<*xi32> %cast_sum = memref.cast %sum : memref<2xi32> to memref<*xi32> gpu.host_register %cast_sum : memref<*xi32> memref.store %cst0, %data[%c0, %c0] : memref<2x6xi32> memref.store %cst1, %data[%c0, %c1] : memref<2x6xi32> memref.store %cst2, %data[%c0, %c2] : memref<2x6xi32> memref.store %cst4, %data[%c0, %c3] : memref<2x6xi32> memref.store %cst8, %data[%c0, %c4] : memref<2x6xi32> memref.store %cst16, %data[%c0, %c5] : memref<2x6xi32> memref.store %cst2, %data[%c1, %c0] : memref<2x6xi32> memref.store %cst3, %data[%c1, %c1] : memref<2x6xi32> memref.store %cst6, %data[%c1, %c2] : memref<2x6xi32> memref.store %cst7, %data[%c1, %c3] : memref<2x6xi32> memref.store %cst10, %data[%c1, %c4] : memref<2x6xi32> 
memref.store %cst11, %data[%c1, %c5] : memref<2x6xi32> // MAX gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1) threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) { %val = memref.load %data[%bx, %tx] : memref<2x6xi32> %reduced = gpu.all_reduce max %val uniform {} : (i32) -> (i32) memref.store %reduced, %sum[%bx] : memref<2xi32> gpu.terminator } call @printMemrefI32(%cast_sum) : (memref<*xi32>) -> () // CHECK: [16, 11] return } func.func private @printMemrefI32(memref<*xi32>) ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/test_insert_ofmempool.mlir ================================================ // RUN: oneflow-opt %s -insert-ofmempool | FileCheck %s #map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> module attributes {gpu.container_module} { // CHECK: func.func @JITOpGenerated0(%[[ARG0:[a-zA-Z0-9_]+]]: memref<512xi8> func.func @JITOpGenerated0(%arg0: memref<1xf32>, %arg1: memref<5xi64>, %arg2: memref<5xf32>) attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c5 = arith.constant 5 : index %c0 = arith.constant 0 : index // CHECK-NOT: memref.alloc() : memref<512xi8> %alloc = memref.alloc() : memref<512xi8> // CHECK: memref.view %[[ARG0]] %view = memref.view %alloc[%c0][] : memref<512xi8> to memref<5xf32> %collapse_shape = memref.collapse_shape %arg0 [] : memref<1xf32> into memref gpu.launch_func @JITOpGenerated0_kernel::@JITOpGenerated0_kernel blocks in (%c5, %c1, %c1) threads in (%c1, %c1, %c1) args(%arg1 : memref<5xi64>, %view : memref<5xf32>) gpu.launch_func @JITOpGenerated0_kernel_0::@JITOpGenerated0_kernel blocks in (%c5, %c1, %c1) threads in (%c1, %c1, %c1) args(%view : memref<5xf32>, %collapse_shape : memref, %arg2 : memref<5xf32>) return } gpu.module @JITOpGenerated0_kernel { gpu.func @JITOpGenerated0_kernel(%arg0: memref<5xi64>, %arg1: memref<5xf32>) kernel attributes {gpu.known_block_size = array} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %0 = gpu.block_id x cf.br ^bb1 ^bb1: // pred: ^bb0 %1 = affine.apply #map(%0)[%c1, %c0] %2 = memref.load %arg0[%1] : memref<5xi64> %3 = arith.sitofp %2 : i64 to f32 memref.store %3, %arg1[%1] : memref<5xf32> gpu.return } } gpu.module @JITOpGenerated0_kernel_0 { gpu.func @JITOpGenerated0_kernel(%arg0: memref<5xf32>, %arg1: memref, %arg2: memref<5xf32>) kernel attributes {gpu.known_block_size = array} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %0 = gpu.block_id x cf.br ^bb1 ^bb1: // pred: ^bb0 %1 = affine.apply #map(%0)[%c1, %c0] %2 = memref.load %arg0[%1] : memref<5xf32> %3 = memref.load %arg1[] : memref %4 = arith.mulf %2, %3 : f32 memref.store %4, %arg2[%1] : memref<5xf32> gpu.return } } } ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/test_matmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: jit import unittest import numpy as np import os import oneflow as flow import oneflow.unittest class MatMulModule(flow.nn.Module): def __init__(self): super().__init__() self.w = flow.nn.Parameter(flow.Tensor(5, 10)) self.b = flow.nn.Parameter(flow.Tensor(10)) def forward(self, x): return flow.matmul(x, self.w) + self.b def do_matmul_graph(test_case, with_cuda=False): x = flow.randn(2, 5) module_to_run = MatMulModule() if with_cuda: x = x.cuda() module_to_run = module_to_run.to("cuda") y_eager = module_to_run(x) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.fw = module_to_run def build(self, x): return self.fw(x) graph_to_run = GraphToRun() y_lazy = graph_to_run(x) test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy())) @flow.unittest.skip_unless_1n1d() class TestFuseCastScale(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" def test_relu_graph(test_case): import oneflow.sysconfig if oneflow.sysconfig.with_cuda(): do_matmul_graph(test_case, True) do_matmul_graph(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/test_mgpu_to_oneflow_stream.mlir ================================================ // RUN: oneflow-opt %s -mgpu-to-ofstream module attributes {gpu.container_module} { llvm.mlir.global internal constant @JITOpGenerated0_kernel_JITOpGenerated0_kernel_kernel_name("JITOpGenerated0_kernel\00") {addr_space = 0 : i32} llvm.mlir.global internal constant 
@JITOpGenerated0_kernel_gpubin_cst("\7FELF\02\01\013\07\00\00\00\00\00\00\00\02\00\BE\00u\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\0A\00\00\00\00\00\00V\05V\00@\00\00\00\00\00@\00\0C\00\01\00\00.shstrtab\00.strtab\00.symtab\00.symtab_shndx\00.nv.info\00.text.JITOpGenerated0_kernel\00.nv.info.JITOpGenerated0_kernel\00.nv.shared.JITOpGenerated0_kernel\00.nv.constant0.JITOpGenerated0_kernel\00.rel.nv.constant0.JITOpGenerated0_kernel\00.debug_frame\00.rel.debug_frame\00.rela.debug_frame\00.nv.callgraph\00.nv.prototype\00.nv.rel.action\00\00.shstrtab\00.strtab\00.symtab\00.symtab_shndx\00.nv.info\00JITOpGenerated0_kernel\00.text.JITOpGenerated0_kernel\00.nv.info.JITOpGenerated0_kernel\00.nv.shared.JITOpGenerated0_kernel\00.rel.nv.constant0.JITOpGenerated0_kernel\00.nv.constant0.JITOpGenerated0_kernel\00_param\00.debug_frame\00.rel.debug_frame\00.rela.debug_frame\00.nv.callgraph\00.nv.prototype\00.nv.rel.action\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00I\00\00\00\03\00\0B\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\D1\00\00\00\03\00\0A\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\FD\00\00\00\03\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00-\01\00\00\03\00\07\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00I\01\00\00\03\00\08\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\002\00\00\00\12\10\0B\00\00\00\00\00\00\00\00\00\00\02\00\00\00\00\00\00\FF\FF\FF\FF(\00\00\00\00\00\00\00\FF\FF\FF\FF\FF\FF\FF\FF\03\00\04|\FF\FF\FF\FF\0F\0C\81\80\80(\00\08\FF\81\80(\08\81\80\80(\00\00\00\00\00\00\00\FF\FF\FF\FF0\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\01\00\00\00\00\00\00\04\04\00\00\00\04<\00\00\00\0C\81\80\80(\00\04\FC\FF\FF?\00\00\00\04\11\08\00\06\00\00\00\00\00\00\00\04/\08\00\06\00\00\00\0E\00\00\00\04\12\08\00\06\00\00\00\00\00\00\00\04\1C\04\00\F0\00\00\00\03\1B\FF\00\04\17\0C\00\00\00\00\00\00\00\00\00\00\F0!\00\04\17\0C\00\00\00\00\00\01\00\08\00\00\F0!\00\04\17\0C\00\00\00\00\00\02\00\10\00\00\F0!\00\04\17\0C\00\00\00\00\00\03\00\18\00\00\F0!\00\04\17\0C\00\00\00\00\00\04\00 \00\00\F0!\00\04\17\0C\00\00\00\00\00\05\00(\00\00\F0!\00\04\17\0C\00\00\00\00\00\06\000\00\00\F0!\00\04\17\0C\00\00\00\00\00\07\008\00\00\F0!\00\04\17\0C\00\00\00\00\00\08\00@\00\00\F0!\00\04\17\0C\00\00\00\00\00\09\00H\00\00\F0!\00\04\17\0C\00\00\00\00\00\0A\00P\00\00\F0!\00\04\17\0C\00\00\00\00\00\0B\00X\00\00\F0!\00\04\17\0C\00\00\00\00\00\0C\00`\00\00\F0!\00\03\19h\00\04\0A\08\00\02\00\00\00`\01h\00\015\00\00\047\04\00u\00\00\00\00\00\00\00\FF\FF\FF\FF\00\00\00\00\FE\FF\FF\FF\00\00\00\00\FD\FF\FF\FF\00\00\00\00K\00\00\00\00\00\00\00\00\02\02\08\10\0A/\22\00\00\00\08\00\00\00\00\00\00\08\08\00\00\00\00\00\00\10\08\00\00\00\00\00\00\18\08\00\00\00\00\00\00 \08\00\00\00\00\00\00(\08\00\00\00\00\00\000\08\00\00\00\00\00\008\08\00\00\00\00\01\00\00\08\00\00\00\00\01\00\08\08\00\00\00\00\01\00\10\08\00\00\00\00\01\00\18\08\00\00\00\00\01\00 \08\00\00\00\00\01\00(\08\00\00\00\00\01\000\08\00\00\00\00\01\008\08\00\00\00\00\02\00\00\08\00\00\00\00\02\00\08\08\00\00\00\00\02\00\10\08\00\00\00\00\02\00\18\08\00\00\00\00\02\00 
\08\00\00\00\00\02\00(\08\00\00\00\00\02\000\08\00\00\00\00\02\008\08\00\00\00\00\00\00\00\14,\00\00\00\09\00\00\0C\00\00\00\00H\00\00\00\00\00\00\00\02\00\00\00\06\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\02z\01\00\00\0A\00\00\00\0F\00\00\00\C4\0F\00\19y\06\00\00\00\00\00\00%\00\00\00\22\0E\00\02x\03\00\08\00\00\00\00\0F\00\00\00\E2\0F\00\B9z\04\00\00F\00\00\00\0A\00\00\00\E2\0F\00\02z\04\00\00d\00\00\00\0F\00\00\00\E4\0F\00\02z\05\00\00e\00\00\00\0F\00\00\00\CA\0F\00\80y\04\04\04\00\00\00\00\19\10\0C\00\A2\0E\00%v\02\06\00Z\00\00\03\02\8E\07\00\CA\1F\00\80y\08\02\04\00\00\00\00\19\10\0C\00\E8\0E\00\80y\09\02\04\04\00\00\00\19\10\0C\00\E2\0E\00\02x\07\00\04\00\00\00\00\0F\00\00\00\CA\0F\00%v\06\06\00j\00\00\07\02\8E\07\00\E2\0F\00\12s\09\00\08\00\00\00\00\140\00\00\A4\8E\00 
r\0B\04\09\00\00\00\00\00@\00\00\CAO\00\85y\00\06\0B\00\00\00\04\19\10\0C\00\E2\0F\00My\00\00\00\00\00\00\00\00\80\03\00\EA\0F\00Gy\00\00\F0\FF\FF\FF\FF\FF\83\03\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00@\00\00\00\00\00\00\00:\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\0B\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00z\01\00\00\00\00\00\00X\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\13\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\D8\02\00\00\00\00\00\00\A8\00\00\00\00\00\00\00\02\00\00\00\06\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00\DF\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\80\03\00\00\00\00\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00)\00\00\00\00\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\03\00\00\00\00\00\00$\00\00\00\00\00\00\00\03\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00O\00\00\00\00\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\14\04\00\00\00\00\00\00\F8\00\00\00\00\00\00\00\03\00\00\00\0B\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\0F\01\00\00\01\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\0C\05\00\00\00\00\00\00\18\00\00\00\00\00\00\00\03\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00+\01\00\00\0B\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00(\05\00\00\00\00\00\00\E0\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\EC\00\00\00\09\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\08\06\00\00\00\00\00\00\10\00\00\00\00\00\00\00\03\00\00\00\04\00\00\00\08\00\00\00\00\00\00\00\10\00\00\00\00\00\00\00\91\00\00\00\01\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\18\06\00\00\00\00\00\00\C8\01\00\00\00\00\00\00\00\00\00\00\0B\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\002\00\00\00\01\00\00\00\06\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\02\00\00\00\00\00\00\03\00\00\00\06\00\00\0E\80\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00") {addr_space = 0 : i32} llvm.func @JITOpGenerated0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: !llvm.ptr, %arg11: !llvm.ptr, %arg12: i64, %arg13: i64, %arg14: i64, %arg15: !llvm.ptr) attributes {llvm.emit_c_interface} { %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %1 
= llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %6 = builtin.unrealized_conversion_cast %5 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<1xf32> %7 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %8 = llvm.insertvalue %arg5, %7[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %9 = llvm.insertvalue %arg6, %8[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %10 = llvm.insertvalue %arg7, %9[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %11 = llvm.insertvalue %arg8, %10[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %12 = llvm.insertvalue %arg9, %11[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %13 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %14 = llvm.insertvalue %arg10, %13[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %15 = llvm.insertvalue %arg11, %14[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %16 = llvm.insertvalue %arg12, %15[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %17 = llvm.insertvalue %arg13, %16[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %18 = llvm.insertvalue %arg14, %17[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %19 = llvm.mlir.constant(5 : index) : i64 %20 = llvm.mlir.constant(1 : index) : i64 %collapse_shape = memref.collapse_shape %6 [] : memref<1xf32> into memref %21 = builtin.unrealized_conversion_cast %collapse_shape : memref to !llvm.struct<(ptr, ptr, i64)> %22 = llvm.mlir.addressof @JITOpGenerated0_kernel_gpubin_cst : !llvm.ptr> %23 = llvm.getelementptr %22[0, 0] : (!llvm.ptr>) -> !llvm.ptr %24 = llvm.call @mgpuModuleLoad(%23) : (!llvm.ptr) -> !llvm.ptr %25 = llvm.mlir.addressof @JITOpGenerated0_kernel_JITOpGenerated0_kernel_kernel_name : !llvm.ptr> %26 = llvm.getelementptr %25[0, 0] : (!llvm.ptr>) -> !llvm.ptr %27 = llvm.call @mgpuModuleGetFunction(%24, %26) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr %28 = llvm.mlir.constant(0 : i32) : i32 // CHECK-NOT: mgpuStreamCreate %29 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr %30 = llvm.extractvalue %12[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %31 = llvm.extractvalue %12[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %32 = llvm.extractvalue %12[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %33 = llvm.extractvalue %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %34 = llvm.extractvalue %12[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %35 = llvm.extractvalue %21[0] : !llvm.struct<(ptr, ptr, i64)> %36 = llvm.extractvalue %21[1] : !llvm.struct<(ptr, ptr, i64)> %37 = llvm.extractvalue %21[2] : !llvm.struct<(ptr, ptr, i64)> %38 = llvm.extractvalue %18[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %39 = llvm.extractvalue %18[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %40 = llvm.extractvalue %18[2] : 
!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %41 = llvm.extractvalue %18[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %42 = llvm.extractvalue %18[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %43 = llvm.mlir.constant(1 : i32) : i32 %44 = llvm.alloca %43 x !llvm.struct<"", (ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)> : (i32) -> !llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>> %45 = llvm.mlir.constant(13 : i32) : i32 %46 = llvm.alloca %45 x !llvm.ptr : (i32) -> !llvm.ptr> %47 = llvm.getelementptr %44[0, 0] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr> llvm.store %30, %47 : !llvm.ptr> %48 = llvm.getelementptr %46[0] : (!llvm.ptr>) -> !llvm.ptr> %49 = llvm.bitcast %47 : !llvm.ptr> to !llvm.ptr llvm.store %49, %48 : !llvm.ptr> %50 = llvm.getelementptr %44[0, 1] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr> llvm.store %31, %50 : !llvm.ptr> %51 = llvm.getelementptr %46[1] : (!llvm.ptr>) -> !llvm.ptr> %52 = llvm.bitcast %50 : !llvm.ptr> to !llvm.ptr llvm.store %52, %51 : !llvm.ptr> %53 = llvm.getelementptr %44[0, 2] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr llvm.store %32, %53 : !llvm.ptr %54 = llvm.getelementptr %46[2] : (!llvm.ptr>) -> !llvm.ptr> %55 = llvm.bitcast %53 : !llvm.ptr to !llvm.ptr llvm.store %55, %54 : !llvm.ptr> %56 = llvm.getelementptr %44[0, 3] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr llvm.store %33, %56 : !llvm.ptr %57 = llvm.getelementptr %46[3] : (!llvm.ptr>) -> !llvm.ptr> %58 = llvm.bitcast %56 : !llvm.ptr to !llvm.ptr llvm.store %58, %57 : !llvm.ptr> %59 = llvm.getelementptr %44[0, 4] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr llvm.store %34, %59 : !llvm.ptr %60 = llvm.getelementptr %46[4] : (!llvm.ptr>) -> !llvm.ptr> %61 = llvm.bitcast %59 : !llvm.ptr to !llvm.ptr llvm.store %61, %60 : !llvm.ptr> %62 = llvm.getelementptr %44[0, 5] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr> llvm.store %35, %62 : !llvm.ptr> %63 = llvm.getelementptr %46[5] : (!llvm.ptr>) -> !llvm.ptr> %64 = llvm.bitcast %62 : !llvm.ptr> to !llvm.ptr llvm.store %64, %63 : !llvm.ptr> %65 = llvm.getelementptr %44[0, 6] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr> llvm.store %36, %65 : !llvm.ptr> %66 = llvm.getelementptr %46[6] : (!llvm.ptr>) -> !llvm.ptr> %67 = llvm.bitcast %65 : !llvm.ptr> to !llvm.ptr llvm.store %67, %66 : !llvm.ptr> %68 = llvm.getelementptr %44[0, 7] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr llvm.store %37, %68 : !llvm.ptr %69 = llvm.getelementptr %46[7] : (!llvm.ptr>) -> !llvm.ptr> %70 = llvm.bitcast %68 : !llvm.ptr to !llvm.ptr llvm.store %70, %69 : !llvm.ptr> %71 = llvm.getelementptr %44[0, 8] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr> llvm.store %38, %71 : !llvm.ptr> %72 = llvm.getelementptr %46[8] : (!llvm.ptr>) -> !llvm.ptr> %73 = llvm.bitcast %71 : !llvm.ptr> to !llvm.ptr llvm.store %73, %72 : !llvm.ptr> %74 = llvm.getelementptr %44[0, 9] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr> llvm.store %39, %74 : !llvm.ptr> %75 = llvm.getelementptr %46[9] : (!llvm.ptr>) -> !llvm.ptr> %76 = llvm.bitcast %74 : !llvm.ptr> to 
!llvm.ptr llvm.store %76, %75 : !llvm.ptr> %77 = llvm.getelementptr %44[0, 10] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr llvm.store %40, %77 : !llvm.ptr %78 = llvm.getelementptr %46[10] : (!llvm.ptr>) -> !llvm.ptr> %79 = llvm.bitcast %77 : !llvm.ptr to !llvm.ptr llvm.store %79, %78 : !llvm.ptr> %80 = llvm.getelementptr %44[0, 11] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr llvm.store %41, %80 : !llvm.ptr %81 = llvm.getelementptr %46[11] : (!llvm.ptr>) -> !llvm.ptr> %82 = llvm.bitcast %80 : !llvm.ptr to !llvm.ptr llvm.store %82, %81 : !llvm.ptr> %83 = llvm.getelementptr %44[0, 12] : (!llvm.ptr, ptr, i64, i64, i64, ptr, ptr, i64, ptr, ptr, i64, i64, i64)>>) -> !llvm.ptr llvm.store %42, %83 : !llvm.ptr %84 = llvm.getelementptr %46[12] : (!llvm.ptr>) -> !llvm.ptr> %85 = llvm.bitcast %83 : !llvm.ptr to !llvm.ptr llvm.store %85, %84 : !llvm.ptr> %86 = llvm.mlir.null : !llvm.ptr> // CHECK-NOT: mgpuLaunchKernel(%18, %4, %3, %3, %3, %3, %3, %2, %arg15, %23, %62) llvm.call @mgpuLaunchKernel(%27, %19, %20, %20, %20, %20, %20, %28, %29, %46, %86) : (!llvm.ptr, i64, i64, i64, i64, i64, i64, i32, !llvm.ptr, !llvm.ptr>, !llvm.ptr>) -> () llvm.call @mgpuStreamSynchronize(%29) : (!llvm.ptr) -> () // CHECK-NOT: mgpuStreamDestroy llvm.call @mgpuStreamDestroy(%29) : (!llvm.ptr) -> () llvm.call @mgpuModuleUnload(%24) : (!llvm.ptr) -> () llvm.return } llvm.func @_mlir_ciface_JITOpGenerated0(%arg0: !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>, %arg1: !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>, %arg2: !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>, %arg3: !llvm.ptr) attributes {llvm.emit_c_interface} { %0 = llvm.load %arg0 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %6 = llvm.load %arg1 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> %7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %9 = llvm.extractvalue %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %12 = llvm.load %arg2 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> %13 = llvm.extractvalue %12[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %14 = llvm.extractvalue %12[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %15 = llvm.extractvalue %12[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %16 = llvm.extractvalue %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %17 = llvm.extractvalue %12[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> llvm.call @JITOpGenerated0(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11, %13, %14, %15, %16, %17, %arg3) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, 
i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr) -> () llvm.return } llvm.func @mgpuModuleLoad(!llvm.ptr) -> !llvm.ptr llvm.func @mgpuModuleGetFunction(!llvm.ptr, !llvm.ptr) -> !llvm.ptr llvm.func @mgpuStreamCreate() -> !llvm.ptr llvm.func @mgpuLaunchKernel(!llvm.ptr, i64, i64, i64, i64, i64, i64, i32, !llvm.ptr, !llvm.ptr>, !llvm.ptr>) llvm.func @mgpuStreamSynchronize(!llvm.ptr) llvm.func @mgpuStreamDestroy(!llvm.ptr) llvm.func @mgpuModuleUnload(!llvm.ptr) } ================================================ FILE: oneflow/ir/test/OneFlow/cuda_code_gen/tosa_to_linalg.mlir ================================================ // RUN: oneflow-opt %s -ofjob-to-func --tosa-make-broadcastable \ // RUN: | oneflow-opt -pass-pipeline="builtin.module(oneflow.job(tosa-to-linalg))" \ // RUN: | oneflow-opt -func-to-ofjob oneflow.job @GraphToRun_1(%arg0: tensor<2x5xi64>, %arg1: tensor<1xf32>) -> tensor<2x5xf32> { %2 = "tosa.cast"(%arg0) : (tensor<2x5xi64>) -> tensor<2x5xf32> %3 = "tosa.mul"(%2, %arg1) {shift = 0 : i32} : (tensor<2x5xf32>, tensor<1xf32>) -> tensor<2x5xf32> oneflow.return %3 : tensor<2x5xf32> } ================================================ FILE: oneflow/ir/test/OneFlow/folding/test_conv_bn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.normalization import os import unittest import numpy as np import oneflow as flow import oneflow.unittest import oneflow.nn as nn from flowvision.models.resnet import resnet50 def _test_fuse_conv_bn(test_case): data = flow.randn(1, 3, 224, 224) model = resnet50(pretrained=False, progress=True) model.eval() eager_res = model(data) class Resnet50Graph(nn.Graph): def __init__(self): super().__init__() self.model = model def build(self, *input): return self.model(*input) graph = Resnet50Graph() lazy_res = graph(data) test_case.assertTrue( np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-2, atol=1e-2) ) @flow.unittest.skip_unless_1n1d() class TestFuseConvBn(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" def test_fuse_conv_bn(test_case): _test_fuse_conv_bn(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/folding/test_simple_multiply.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.broadcast_mul import os import unittest import numpy as np import oneflow as flow import oneflow.unittest import oneflow.nn as nn class MultiplyModel(nn.Module): def __init__(self, dtype=flow.float32): super().__init__() self.dtype = dtype self.x = nn.Parameter(flow.tensor([2, 2], dtype=self.dtype), False) self.y = nn.Parameter(flow.tensor([3, 3], dtype=self.dtype), False) def forward(self): return self.x * self.y class MultiplyModelComplex(MultiplyModel): def __init__(self, dtype=flow.float32): super().__init__(dtype) self.z = nn.Parameter(flow.tensor([4, 5], dtype=self.dtype), False) def forward(self): return self.x * self.y * self.z class MultiplyModelWithInput(MultiplyModel): def __init__(self, dtype=flow.float32): super().__init__(dtype) def forward(self, a: flow.Tensor, b: flow.Tensor): z = self.x * self.y return a + b + z def _test_fold_multiply(test_case, module, with_cuda, *args, dtype=oneflow.float32): model = module(dtype) if with_cuda: model.to("cuda") model.eval() eager_res = model(*args) class MultiplyGraph(nn.Graph): def __init__(self): super().__init__() self.model = model def build(self, *args): return self.model(*args) graph = MultiplyGraph() lazy_res = graph(*args) test_case.assertTrue( np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-5, atol=1e-5) ) test_case.assertTrue(eager_res.dtype == dtype and lazy_res.dtype == dtype) @flow.unittest.skip_unless_1n1d() class TestFoldMultiply(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" def test_fold_multiply(test_case): _test_fold_multiply(test_case, MultiplyModel, with_cuda=False) _test_fold_multiply( test_case, MultiplyModel, with_cuda=False, dtype=flow.float16 ) @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "only test cpu cases") def test_fold_multiply_cuda(test_case): _test_fold_multiply(test_case, MultiplyModel, with_cuda=True) _test_fold_multiply( test_case, MultiplyModel, with_cuda=True, dtype=flow.float16 ) def test_fold_multiply_complex(test_case): _test_fold_multiply(test_case, MultiplyModelComplex, with_cuda=False) _test_fold_multiply( test_case, MultiplyModelComplex, with_cuda=False, dtype=flow.float16 ) @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "only test cpu cases") def test_fold_multiply_complex_cuda(test_case): _test_fold_multiply(test_case, MultiplyModelComplex, with_cuda=True) _test_fold_multiply( test_case, MultiplyModelComplex, with_cuda=True, dtype=flow.float16 ) def test_fold_multiply_with_input(test_case): a = flow.tensor([3, 7], dtype=flow.float32) b = flow.tensor([9, -1], dtype=flow.float32) a_fp16 = flow.tensor([3, 7], dtype=flow.float16) b_fp16 = flow.tensor([9, -1], dtype=flow.float16) _test_fold_multiply(test_case, MultiplyModelWithInput, False, a, b) _test_fold_multiply( test_case, MultiplyModelWithInput, False, a_fp16, b_fp16, dtype=flow.float16 ) @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "only test cpu cases") def test_fold_multiply_with_input_cuda(test_case): a = flow.tensor([3, 7], dtype=flow.float32, device="cuda") b = flow.tensor([9, -1], dtype=flow.float32, device="cuda") a_fp16 = flow.tensor([3, 7], dtype=flow.float16, device="cuda") b_fp16 = flow.tensor([9, -1], dtype=flow.float16, device="cuda") _test_fold_multiply(test_case, 
MultiplyModelWithInput, True, a, b) _test_fold_multiply( test_case, MultiplyModelWithInput, True, a_fp16, b_fp16, dtype=flow.float16 ) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/fuse/fuse_forward_ops.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -fuse-forward-only-ops -fuse-into-existing-op -fuse-normalization-ops -convert-inference-op -fuse-ops-with-backward-impl -canonicalize | FileCheck %s module { func.func @Cast_1__FUSE__ScalarMulByTensor_2(%685: tensor<2x64x64x320xf16>, %output_574: tensor<320xf16>, %output_573: tensor<320xf16>) -> tensor<2x64x64x320xf16> { %y_958, %mean_959, %inv_variance_960 = "oneflow.group_norm"(%685, %output_574, %output_573) {activation = "none", affine = true, data_format = "channels_last", device_name = ["@0:0"], device_tag = "cuda", epsilon = 1.000000e-05 : f64, hierarchy = [1], num_groups = 32 : si32, op_name = "unet.up_blocks.3.resnets.0.norm2-group_norm-877", operand_segment_sizes = array, scope_symbol_id = 5517 : i64} : (tensor<2x64x64x320xf16>, tensor<320xf16>, tensor<320xf16>) -> (tensor<2x64x64x320xf16>, tensor<2x32xf32>, tensor<2x32xf32>) %686 = "oneflow.silu"(%y_958) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.up_blocks.3.resnets.0.nonlinearity-silu-878", scope_symbol_id = 5466 : i64} : (tensor<2x64x64x320xf16>) -> tensor<2x64x64x320xf16> // CHECK: activation = "silu" // CHECK-NOT: oneflow.silu return %686 : tensor<2x64x64x320xf16> } func.func @GraphToRun_bias_add_and_dropout_0(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<5xf32>) -> (tensor<2x3x4x5xf32>, tensor<2x3x4x5xi8>) { %0 = "oneflow.bias_add"(%arg0, %arg1) {axis = 3 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "bias_add-0", scope_symbol_id = 12 : i64} : (tensor<2x3x4x5xf32>, tensor<5xf32>) -> tensor<2x3x4x5xf32> %out, %mask = "oneflow.dropout"(%0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "dropout-dropout-1", rate = 0.750000e+00 : f32, scope_symbol_id = 22 : i64} : (tensor<2x3x4x5xf32>) -> (tensor<2x3x4x5xf32>, tensor<2x3x4x5xi8>) // CHECK: func.func @GraphToRun_bias_add_and_dropout_0(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>, %[[B:[a-zA-Z0-9_]+]]: tensor<5xf32>) -> (tensor<2x3x4x5xf32>, tensor<2x3x4x5xi8>) // CHECK: %[[MASK:[a-zA-Z0-9_]+]] = "oneflow.random_mask_like"(%[[A]]) // CHECK: "oneflow.fused_bias_add_mask_scale"(%[[A]], %[[B]], %[[MASK]]) // CHECK: scale = 4.000000e+00 return %out, %mask : tensor<2x3x4x5xf32>, tensor<2x3x4x5xi8> } func.func @GraphToRun_bias_add_and_gelu_0(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<5xf32>) -> tensor<2x3x4x5xf32> { %0 = "oneflow.bias_add"(%arg0, %arg1) {axis = 3 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "bias_add-0", scope_symbol_id = 12 : i64} : (tensor<2x3x4x5xf32>, tensor<5xf32>) -> tensor<2x3x4x5xf32> %out = "oneflow.gelu"(%0) {axis = 3 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu-gelu-1", scope_symbol_id = 22 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> // CHECK: func.func @GraphToRun_bias_add_and_gelu_0(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>, %[[B:[a-zA-Z0-9_]+]]: tensor<5xf32>) -> tensor<2x3x4x5xf32> // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = "oneflow.fused_bias_add_gelu"(%[[A]], %[[B]]) {axis = 3 : si32 // CHECK: return %[[OUT0]] return %out : tensor<2x3x4x5xf32> } func.func @fuse_mha(%query: tensor<2x4096x320xf16>, %key: tensor<2x4096x320xf16>, 
%value: tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16> { %query_reshape = "oneflow.reshape"(%query) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-1", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16> %key_reshape = "oneflow.reshape"(%key) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-3", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16> %value_reshape = "oneflow.reshape"(%value) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-5", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16> %query_transpose = "oneflow.transpose"(%query_reshape) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-2", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 12 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16> %key_transpose = "oneflow.transpose"(%key_reshape) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-4", perm = [0 : si32, 2 : si32, 3 : si32, 1 : si32], scope_symbol_id = 12 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x40x4096xf16> %value_transpose = "oneflow.transpose"(%value_reshape) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-6", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 12 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16> %scores = "oneflow.batch_matmul"(%query_transpose, %key_transpose) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "batch_matmul-7", scope_symbol_id = 12 : i64, transpose_a = false, transpose_b = false} : (tensor<2x8x4096x40xf16>, tensor<2x8x40x4096xf16>) -> tensor<2x8x4096x4096xf16> %scores_scaled = "oneflow.scalar_div"(%scores) {device_name = ["@0:0"], device_tag = "cuda", float_operand = 6.324555320336759 : f64, has_float_operand = true, has_int_operand = false, hierarchy = [1], int_operand = 0 : si64, op_name = "scalar_div-8", scope_symbol_id = 12 : i64} : (tensor<2x8x4096x4096xf16>) -> tensor<2x8x4096x4096xf16> %attn = "oneflow.softmax"(%scores_scaled) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "softmax-9", scope_symbol_id = 12 : i64} : (tensor<2x8x4096x4096xf16>) -> tensor<2x8x4096x4096xf16> %out = "oneflow.batch_matmul"(%attn, %value_transpose) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "batch_matmul-10", scope_symbol_id = 12 : i64, transpose_a = false, transpose_b = false} : (tensor<2x8x4096x4096xf16>, tensor<2x8x4096x40xf16>) -> tensor<2x8x4096x40xf16> %out_transpose = "oneflow.transpose"(%out) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-11", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 12 : i64} : (tensor<2x8x4096x40xf16>) -> tensor<2x4096x8x40xf16> %out_reshape = "oneflow.reshape"(%out_transpose) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-12", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 320 : si64]} : (tensor<2x4096x8x40xf16>) -> tensor<2x4096x320xf16> // CHECK: func.func @fuse_mha(%[[QUERY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[KEY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, 
%[[VALUE:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>) // CHECK: "oneflow.fused_multi_head_attention_inference"(%[[QUERY]], %[[KEY]], %[[VALUE]]) {attn_mask_type = "none", causal_diagonal_offset = 0 : si64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], key_layout = "BM(HK)", key_max_seq_len = 0 : si64, op_name = [[OP_NAME:".*"]], operand_segment_sizes = array, output_layout = "BM(HK)", query_head_size = 40 : si64, query_layout = "BM(HK)", query_max_seq_len = 0 : si64, scale = 0.15811388300841897 : f64, scope_symbol_id = 12 : i64, value_layout = "BM(HK)"} : (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16> return %out_reshape : tensor<2x4096x320xf16> } func.func @fuse_mha2(%query: tensor<2x4096x320xf16>, %key: tensor<2x4096x320xf16>, %value: tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16> { %value_reshape = "oneflow.reshape"(%value) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-124", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16> %key_reshape = "oneflow.reshape"(%key) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-121", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16> %query_reshape = "oneflow.reshape"(%query) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-116", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16> %value_permute = "oneflow.transpose"(%value_reshape) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-125", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16> %key_permute = "oneflow.transpose"(%key_reshape) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-122", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16> %query_permute = "oneflow.transpose"(%query_reshape) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-117", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16> %value_reshape_to_batch = "oneflow.reshape"(%value_permute) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-126", scope_symbol_id = 661 : i64, shape = [16 : si64, 4096 : si64, 40 : si64]} : (tensor<2x8x4096x40xf16>) -> tensor<16x4096x40xf16> %key_reshape_to_batch = "oneflow.reshape"(%key_permute) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-123", scope_symbol_id = 661 : i64, shape = [16 : si64, 4096 : si64, 40 : si64]} : (tensor<2x8x4096x40xf16>) -> tensor<16x4096x40xf16> %query_reshape_to_batch = 
"oneflow.reshape"(%query_permute) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-118", scope_symbol_id = 661 : i64, shape = [16 : si64, 4096 : si64, 40 : si64]} : (tensor<2x8x4096x40xf16>) -> tensor<16x4096x40xf16> %key_transpose = "oneflow.transpose"(%key_reshape_to_batch) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-128", perm = [0 : si32, 2 : si32, 1 : si32], scope_symbol_id = 661 : i64} : (tensor<16x4096x40xf16>) -> tensor<16x40x4096xf16> %scores_scaled = "oneflow.batch_matmul"(%query_reshape_to_batch, %key_transpose) {alpha = 0.15811388300841897 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-batch_matmul-129", scope_symbol_id = 661 : i64, transpose_a = false, transpose_b = false} : (tensor<16x4096x40xf16>, tensor<16x40x4096xf16>) -> tensor<16x4096x4096xf16> %attn = "oneflow.softmax"(%scores_scaled) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-softmax-130", scope_symbol_id = 661 : i64} : (tensor<16x4096x4096xf16>) -> tensor<16x4096x4096xf16> %309 = "oneflow.batch_matmul"(%attn, %value_reshape_to_batch) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-batch_matmul-131", scope_symbol_id = 661 : i64, transpose_a = false, transpose_b = false} : (tensor<16x4096x4096xf16>, tensor<16x4096x40xf16>) -> tensor<16x4096x40xf16> %310 = "oneflow.reshape"(%309) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-132", scope_symbol_id = 661 : i64, shape = [2 : si64, 8 : si64, 4096 : si64, 40 : si64]} : (tensor<16x4096x40xf16>) -> tensor<2x8x4096x40xf16> %311 = "oneflow.transpose"(%310) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-133", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x8x4096x40xf16>) -> tensor<2x4096x8x40xf16> %out_reshape_to_heads = "oneflow.reshape"(%311) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-134", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 320 : si64]} : (tensor<2x4096x8x40xf16>) -> tensor<2x4096x320xf16> // CHECK: func.func @fuse_mha2(%[[QUERY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[KEY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[VALUE:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>) // CHECK: oneflow.fused_multi_head_attention_inference"(%[[QUERY]], %[[KEY]], %[[VALUE]]) {attn_mask_type = "none", causal_diagonal_offset = 0 : si64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], key_layout = "BM(HK)", key_max_seq_len = 0 : si64, op_name = [[OP_NAME:".*"]], operand_segment_sizes = array, output_layout = "BM(HK)", query_head_size = 40 : si64, query_layout = "BM(HK)", query_max_seq_len = 0 : si64, scale = 0.15811388300841897 : f64, scope_symbol_id = 661 : i64, value_layout = "BM(HK)"} : (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16> return %out_reshape_to_heads : tensor<2x4096x320xf16> } func.func 
@GraphToRun_pad_and_conv2d_0(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x5x6xf32> { %output = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "conv.weight", output_lbns = ["conv.weight/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 73 : i64, shape = [3 : si64, 3 : si64, 2 : si64, 2 : si64]} : () -> tensor<3x3x2x2xf32> %output_0 = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_2_input.0.0_2", output_lbns = ["_GraphToRun_2_input.0.0_2/out"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> %0 = "oneflow.pad"(%output_0) {device_name = ["@0:0"], device_tag = "cpu", floating_constant_value = 0.000000e+00 : f64, hierarchy = [1], integral_constant_value = 0 : si64, op_name = "pad-0", padding = [1 : si64, 1 : si64, 1 : si64, 1 : si64], padding_after = [0 : si64, 0 : si64, 1 : si64, 1 : si64], padding_before = [0 : si64, 0 : si64, 1 : si64, 1 : si64], scope_symbol_id = 65 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x6x7xf32> %1 = "oneflow.conv2d"(%0, %output) {data_format = "channels_first", device_name = ["@0:0"], device_tag = "cpu", dilation_rate = [1 : si32, 1 : si32], filters = 3 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [2 : si32, 2 : si32], op_name = "conv-conv2d-1", operand_segment_sizes = array, padding_before = [0 : si32, 0 : si32], scope_symbol_id = 76 : i64, strides = [1 : si32, 1 : si32]} : (tensor<2x3x6x7xf32>, tensor<3x3x2x2xf32>) -> tensor<2x3x5x6xf32> %output_1 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_2_output.0.0_2", output_lbns = ["_GraphToRun_2_output.0.0_2/out"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 5 : si64, 6 : si64]} : (tensor<2x3x5x6xf32>) -> tensor<2x3x5x6xf32> // CHECK: func.func @GraphToRun_pad_and_conv2d_0(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>) -> tensor<2x3x5x6xf32> { // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.variable"() // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = "oneflow.input"(%[[A]]) // CHECK-NOT: oneflow.pad // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = "oneflow.conv2d"(%[[OUT0]], %[[OUT]]) // CHECK: %[[OUT2:[a-zA-Z0-9_]+]] = "oneflow.output" return %output_1 : tensor<2x3x5x6xf32> } func.func @GraphToRun_same_dtype_cast_0(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> { %output_0 = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_3_input.0.0_2", output_lbns = ["_GraphToRun_3_input.0.0_2/out"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> %0 = "oneflow.cast"(%output_0) {device_name = ["0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_1", op_type_name = "cast", scope_symbol_id = 65 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> %output_1 = "oneflow.output"(%0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_3_output.0.0_2", output_lbns = ["_GraphToRun_3_output.0.0_2/out"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> // CHECK: func.func 
@GraphToRun_same_dtype_cast_0(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> { // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = "oneflow.input"(%[[A]]) // CHECK-NOT: oneflow.cast // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.output"(%[[OUT0]]) // CHECK:return %[[OUT]] : tensor<2x3x4x5xf32> return %output_1 : tensor<2x3x4x5xf32> } func.func @GraphToRun_same_dtype_cast_1(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xi32> { %output_0 = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_4_input.0.0_2", output_lbns = ["_GraphToRun_4_input.0.0_2/out"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> %0 = "oneflow.cast"(%output_0) {device_name = ["0:0"], device_tag = "cpu", dtype = 5 : i32, hierarchy = [1], op_name = "Cast_1", op_type_name = "cast", scope_symbol_id = 65 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xi32> %output_1 = "oneflow.output"(%0) {data_type = 5 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_4_output.0.0_2", output_lbns = ["_GraphToRun_4_output.0.0_2/out"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xi32>) -> tensor<2x3x4x5xi32> // CHECK: func.func @GraphToRun_same_dtype_cast_1(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xi32> { // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = "oneflow.input"(%[[A]]) // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = "oneflow.cast"(%[[OUT0]]) // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.output"(%[[OUT1]]) // CHECK:return %[[OUT]] : tensor<2x3x4x5xi32> return %output_1 : tensor<2x3x4x5xi32> } func.func @GraphToRun_scale_tril_0() -> tensor<5x5xf32> { %output = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "FreeEagerTensor-1", output_lbns = ["FreeEagerTensor-1/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 12 : i64, shape = [5 : si64, 5 : si64], trainable = false} : () -> tensor<5x5xf32> %0 = "oneflow.scalar_mul"(%output) {device_name = ["@0:0"], device_tag = "cuda", float_operand = -2.300000e+00 : f64, has_float_operand = true, has_int_operand = false, hierarchy = [1], int_operand = 0 : si64, op_name = "scalar_mul-0", scope_symbol_id = 12 : i64} : (tensor<5x5xf32>) -> tensor<5x5xf32> %1 = "oneflow.tril"(%0) {device_name = ["@0:0"], device_tag = "cuda", diagonal = -1 : si64, floating_fill_value = 0.000000e+00 : f64, hierarchy = [1], integer_fill_value = 0 : si64, is_floating_fill_value = false, op_name = "tril-2", scope_symbol_id = 12 : i64} : (tensor<5x5xf32>) -> tensor<5x5xf32> %output_0 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_TestFuseScaleTril_0_output.0.0_2", output_lbns = ["_TestFuseScaleTril_0_output.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [5 : si64, 5 : si64]} : (tensor<5x5xf32>) -> tensor<5x5xf32> // CHECK: func.func @GraphToRun_scale_tril_0() -> tensor<5x5xf32> { // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = "oneflow.variable"() // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = "oneflow.fused_scale_tril"(%[[OUT0]]) // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.output"(%[[OUT1]]) // CHECK:return %[[OUT]] return %output_0 : tensor<5x5xf32> } func.func @GraphToRun_scale_tril_1() -> tensor<5x5xf32> { %output = "oneflow.variable"() {data_type = 
2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "FreeEagerTensor-1", output_lbns = ["FreeEagerTensor-1/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 66 : i64, shape = [5 : si64, 5 : si64], trainable = false} : () -> tensor<5x5xf32> %0 = "oneflow.tril"(%output) {device_name = ["@0:0"], device_tag = "cuda", diagonal = -1 : si64, floating_fill_value = 0.000000e+00 : f64, hierarchy = [1], integer_fill_value = 0 : si64, is_floating_fill_value = false, op_name = "tril-0", scope_symbol_id = 66 : i64} : (tensor<5x5xf32>) -> tensor<5x5xf32> %1 = "oneflow.scalar_mul"(%0) {device_name = ["@0:0"], device_tag = "cuda", float_operand = 2.000000e+00 : f64, has_float_operand = true, has_int_operand = false, hierarchy = [1], int_operand = 0 : si64, op_name = "scalar_mul-2", scope_symbol_id = 66 : i64} : (tensor<5x5xf32>) -> tensor<5x5xf32> %output_0 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_TestFuseTrilScale_1_output.0.0_2", output_lbns = ["_TestFuseTrilScale_1_output.0.0_2/out"], scope_symbol_id = 66 : i64, shape = [5 : si64, 5 : si64]} : (tensor<5x5xf32>) -> tensor<5x5xf32> // CHECK: func.func @GraphToRun_scale_tril_1() -> tensor<5x5xf32> { // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = "oneflow.variable"() // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = "oneflow.fused_scale_tril"(%[[OUT0]]) // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.output"(%[[OUT1]]) // CHECK:return %[[OUT]] return %output_0 : tensor<5x5xf32> } func.func @GraphToRun_normalization_1(%x: tensor<2x3x224x224xf32>, %moving_mean: tensor<3xf32>, %moving_variance: tensor<3xf32>, %gamma: tensor<3xf32>, %beta: tensor<3xf32>, %addend: tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32> { %y, %mean, %inv_variance = "oneflow.normalization"(%x, %moving_mean, %moving_variance, %gamma, %beta) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cpu", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = "normalization-2", operand_segment_sizes = array, result_segment_sizes = array, scope_symbol_id = 12 : i64, training = true} : (tensor<2x3x224x224xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x224x224xf32>, tensor<3xf32>, tensor<3xf32>) %0 = "oneflow.add_n2"(%y, %addend) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "add_n-7", op_type_name = "add_n", scope_symbol_id = 12 : i64} : (tensor<2x3x224x224xf32>, tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32> %1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-8", scope_symbol_id = 12 : i64} : (tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32> // CHECK: func.func @GraphToRun_normalization_1(%[[X:[a-zA-Z0-9_]+]]: tensor<2x3x224x224xf32>, %[[MOVING_MEAN:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[MOVING_VARIANCE:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[GAMMA:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[BETA:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[ADDEND:[a-zA-Z0-9_]+]]: tensor<2x3x224x224xf32>) // CHECK: %[[Y:[a-zA-Z0-9_]+]], %[[reserve_space:[a-zA-Z0-9_]+]], %[[mean:[a-zA-Z0-9_]+]], %[[inv_variance:[a-zA-Z0-9_]+]] = "oneflow.normalization_add_relu"(%[[X]], %[[ADDEND]], %[[MOVING_MEAN]], %[[MOVING_VARIANCE]], %[[GAMMA]], %[[BETA]]) // CHECK: return %[[Y]] return %1 : tensor<2x3x224x224xf32> } func.func @GraphToRun_normalization_2(%x: tensor<2x3x224x224xf32>, %moving_mean: tensor<3xf32>, %moving_variance: tensor<3xf32>, %gamma: tensor<3xf32>, 
%beta: tensor<3xf32>, %addend: tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32> { %y = "oneflow.normalization_infer"(%x, %moving_mean, %moving_variance, %gamma, %beta) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cpu", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = "normalization-2", operand_segment_sizes = array, result_segment_sizes = array, scope_symbol_id = 12 : i64, training = true} : (tensor<2x3x224x224xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x224x224xf32>) %0 = "oneflow.add_n2"(%y, %addend) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "add_n-7", op_type_name = "add_n", scope_symbol_id = 12 : i64} : (tensor<2x3x224x224xf32>, tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32> %1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-8", scope_symbol_id = 12 : i64} : (tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32> // CHECK: func.func @GraphToRun_normalization_2(%[[X:[a-zA-Z0-9_]+]]: tensor<2x3x224x224xf32>, %[[MOVING_MEAN:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[MOVING_VARIANCE:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[GAMMA:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[BETA:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[ADDEND:[a-zA-Z0-9_]+]]: tensor<2x3x224x224xf32>) // CHECK: %[[Y:[a-zA-Z0-9_]+]], %[[reserve_space:[a-zA-Z0-9_]+]], %[[mean:[a-zA-Z0-9_]+]], %[[inv_variance:[a-zA-Z0-9_]+]] = "oneflow.normalization_add_relu"(%[[X]], %[[ADDEND]], %[[MOVING_MEAN]], %[[MOVING_VARIANCE]], %[[GAMMA]], %[[BETA]]) // CHECK: return %[[Y]] return %1 : tensor<2x3x224x224xf32> } func.func @GraphToRun_conv_bn_1(%arg0: tensor<1x3x224x224xf32>, %moving_mean: tensor<64xf32>, %moving_variance: tensor<64xf32>, %beta: tensor<64xf32>) -> tensor<1x64x112x112xf32> { %output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_conv_bn_1_input.0.0_2", output_lbns = ["_conv_bn_1_input.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [1 : si64, 3 : si64, 224 : si64, 224 : si64]} : (tensor<1x3x224x224xf32>) -> tensor<1x3x224x224xf32> %0 = "oneflow.variable_ir"() {value = dense<1.0> : tensor<64x3x7x7xf32> ,data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "model.conv1.weight", output_lbns = ["model.conv1.weight/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 18 : i64, shape = [64 : si64, 3 : si64, 7 : si64, 7 : si64], nd_sbp = ["B"]} : () -> tensor<64x3x7x7xf32> %gamma = "oneflow.variable_ir"() {value = dense<1.0> : tensor<64xf32> ,data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "model.bn.gamma", output_lbns = ["model.bn.gamma/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 18 : i64, shape = [64 : si64], nd_sbp = ["B"]} : () -> tensor<64xf32> %1 = "oneflow.conv2d"(%output, %0) {data_format = "channels_first", device_name = ["@0:0"], device_tag = "cuda", dilation_rate = [1 : si32, 1 : si32], filters = 64 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [7 : si32, 7 : si32], op_name = "model.conv1-conv2d-0", operand_segment_sizes = array, padding_before = [3 : si32, 3 : si32], scope_symbol_id = 21 : i64, strides = [2 : si32, 2 : si32]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<1x64x112x112xf32> %2 = "oneflow.normalization_infer"(%1, %moving_mean, %moving_variance, %gamma, %beta) {axis = 1 : si32, device_name = ["@0:0"], 
device_tag = "cuda", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = "model.bn1-normalization-1", operand_segment_sizes = array, result_segment_sizes = array, scope_symbol_id = 41 : i64, training = false} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32> // CHECK: func.func @GraphToRun_conv_bn_1(%[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x3x224x224xf32>, %[[MOVING_MEAN:[a-zA-Z0-9_]+]]: tensor<64xf32>, %[[MOVING_VARIANCE:[a-zA-Z0-9_]+]]: tensor<64xf32>, %[[BETA:[a-zA-Z0-9_]+]]: tensor<64xf32>) // CHECK: %[[GAMMA:[a-zA-Z0-9_]+]] = "oneflow.variable_ir"() // CHECK: %[[WEIGHT:[a-zA-Z0-9_]+]] = "oneflow.variable_ir"() // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.input"(%[[ARG_0]]) // CHECK: %[[OUT2:[a-zA-Z0-9_]+]] = "oneflow.scalar_add"(%[[MOVING_VARIANCE]]) // CHECK: %[[OUT3:[a-zA-Z0-9_]+]] = "oneflow.sqrt"(%[[OUT2]]) // CHECK: %[[OUT4:[a-zA-Z0-9_]+]] = "oneflow.broadcast_div"(%[[GAMMA]], %[[OUT3]]) // CHECK: %[[OUT5:[a-zA-Z0-9_]+]] = "oneflow.reshape"(%[[OUT4]]) // CHECK: %[[OUT6:[a-zA-Z0-9_]+]] = "oneflow.broadcast_mul"(%[[WEIGHT]], %[[OUT5]]) // CHECK: %[[OUT7:[a-zA-Z0-9_]+]] = "oneflow.broadcast_mul"(%[[MOVING_MEAN]], %[[OUT4]]) // CHECK: %[[OUT8:[a-zA-Z0-9_]+]] = "oneflow.broadcast_sub"(%[[BETA]], %[[OUT7]]) // CHECK: %[[OUT9:[a-zA-Z0-9_]+]] = "oneflow.conv2d"(%[[OUT]], %[[OUT6]], %[[OUT8]]) // CHECK: return %[[OUT9]] return %2 : tensor<1x64x112x112xf32> } func.func @GraphToRun_broadcastmul_to_scalarmul_1(%arg0: tensor<64x3x7x7xf32>, %arg1: tensor<1xf32>) -> tensor<64x3x7x7xf32> { %output = "oneflow.broadcast_mul"(%arg0, %arg1) {device_name = ["@0:0"], device_tag = "cuda", op_name = "multiply"} : (tensor<64x3x7x7xf32>, tensor<1xf32>) -> tensor<64x3x7x7xf32> // CHECK: func.func @GraphToRun_broadcastmul_to_scalarmul_1(%[[ARG_0:[a-zA-Z0-9_]+]]: tensor<64x3x7x7xf32>, %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<1xf32>) // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.scalar_mul_by_tensor"(%[[ARG_0]], %[[ARG_1]] return %output : tensor<64x3x7x7xf32> } func.func @GraphToRun_fused_gelu_1(%arg0: tensor<2x2304x640xf32>) -> tensor<2x2304x5120xf32> { %output = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod.proj.weight", output_lbns = ["gelu_mod.proj.weight/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 18 : i64, shape = [10240 : si64, 640 : si64]} : () -> tensor<10240x640xf32> %output_0 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod.proj.bias", output_lbns = ["gelu_mod.proj.bias/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 25 : i64, shape = [10240 : si64]} : () -> tensor<10240xf32> %output_1 = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_input.0.0_2", output_lbns = ["_GraphToRun_0_input.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [2 : si64, 2304 : si64, 640 : si64]} : (tensor<2x2304x640xf32>) -> tensor<2x2304x640xf32> %matmul_wx = "oneflow.broadcast_matmul"(%output_1, %output) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod.proj-broadcast_matmul-0", scope_symbol_id = 21 : i64, transpose_a = false, transpose_b = true} : (tensor<2x2304x640xf32>, tensor<10240x640xf32>) -> tensor<2x2304x10240xf32> %matmul_wx_add = 
"oneflow.broadcast_add"(%matmul_wx, %output_0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod.proj-broadcast_add-1", scope_symbol_id = 21 : i64} : (tensor<2x2304x10240xf32>, tensor<10240xf32>) -> tensor<2x2304x10240xf32> %hidden_states = "oneflow.narrow"(%matmul_wx_add) {device_name = ["@0:0"], device_tag = "cuda", dim = 2 : si64, hierarchy = [1], length = 5120 : si64, op_name = "gelu_mod-narrow-2", scope_symbol_id = 31 : i64, start = 0 : si64} : (tensor<2x2304x10240xf32>) -> tensor<2x2304x5120xf32> %gate = "oneflow.narrow"(%matmul_wx_add) {device_name = ["@0:0"], device_tag = "cuda", dim = 2 : si64, hierarchy = [1], length = 5120 : si64, op_name = "gelu_mod-narrow-3", scope_symbol_id = 31 : i64, start = 5120 : si64} : (tensor<2x2304x10240xf32>) -> tensor<2x2304x5120xf32> %gate_activate = "oneflow.gelu"(%gate) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod-gelu-4", scope_symbol_id = 31 : i64} : (tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32> %y = "oneflow.broadcast_mul"(%hidden_states, %gate_activate) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod-broadcast_mul-5", scope_symbol_id = 31 : i64} : (tensor<2x2304x5120xf32>, tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32> %output_2 = "oneflow.output"(%y) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_output.0.0_2", output_lbns = ["_GraphToRun_0_output.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [2 : si64, 2304 : si64, 5120 : si64]} : (tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32> // CHECK: func.func @GraphToRun_fused_gelu_1(%[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x2304x640xf32>) -> tensor<2x2304x5120xf32> { // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.variable"() // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = "oneflow.variable"() // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = "oneflow.input"(%[[ARG_0]]) // CHECK: %[[Y:[a-zA-Z0-9_]+]], %[[MATMUL:[a-zA-Z0-9_]+]] = "oneflow.fused_glu"(%[[OUT1]], %[[OUT]], %[[OUT0]]) // CHECK: %[[OUT2:[a-zA-Z0-9_]+]] = "oneflow.output"(%[[Y]]) // CHECK: return %[[OUT2]] return %output_2 : tensor<2x2304x5120xf32> } func.func @GraphToRun_fused_gelu_2(%arg0: tensor<2x2304x640xf32>) -> tensor<2x2304x5120xf32> { %output = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod.proj.weight", output_lbns = ["gelu_mod.proj.weight/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 18 : i64, shape = [10240 : si64, 640 : si64]} : () -> tensor<10240x640xf32> %output_0 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod.proj.bias", output_lbns = ["gelu_mod.proj.bias/out"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 25 : i64, shape = [10240 : si64]} : () -> tensor<10240xf32> %output_1 = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_input.0.0_2", output_lbns = ["_GraphToRun_0_input.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [2 : si64, 2304 : si64, 640 : si64]} : (tensor<2x2304x640xf32>) -> tensor<2x2304x640xf32> %matmul_wx_add = "oneflow.fused_matmul_bias"(%output_1, %output, %output_0) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod.proj-broadcast_add-1", scope_symbol_id 
= 21 : i64} : (tensor<2x2304x640xf32>, tensor<10240x640xf32>, tensor<10240xf32>) -> tensor<2x2304x10240xf32> %hidden_states = "oneflow.narrow"(%matmul_wx_add) {device_name = ["@0:0"], device_tag = "cuda", dim = 2 : si64, hierarchy = [1], length = 5120 : si64, op_name = "gelu_mod-narrow-2", scope_symbol_id = 31 : i64, start = 0 : si64} : (tensor<2x2304x10240xf32>) -> tensor<2x2304x5120xf32> %gate = "oneflow.narrow"(%matmul_wx_add) {device_name = ["@0:0"], device_tag = "cuda", dim = 2 : si64, hierarchy = [1], length = 5120 : si64, op_name = "gelu_mod-narrow-3", scope_symbol_id = 31 : i64, start = 5120 : si64} : (tensor<2x2304x10240xf32>) -> tensor<2x2304x5120xf32> %gate_activate = "oneflow.gelu"(%gate) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod-gelu-4", scope_symbol_id = 31 : i64} : (tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32> %y = "oneflow.broadcast_mul"(%hidden_states, %gate_activate) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "gelu_mod-broadcast_mul-5", scope_symbol_id = 31 : i64} : (tensor<2x2304x5120xf32>, tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32> %output_2 = "oneflow.output"(%y) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_output.0.0_2", output_lbns = ["_GraphToRun_0_output.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [2 : si64, 2304 : si64, 5120 : si64]} : (tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32> // CHECK: func.func @GraphToRun_fused_gelu_2(%[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x2304x640xf32>) -> tensor<2x2304x5120xf32> { // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.variable"() // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = "oneflow.variable"() // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = "oneflow.input"(%[[ARG_0]]) // CHECK: %[[Y:[a-zA-Z0-9_]+]], %[[MATMUL:[a-zA-Z0-9_]+]] = "oneflow.fused_glu"(%[[OUT1]], %[[OUT]], %[[OUT0]]) // CHECK: %[[OUT2:[a-zA-Z0-9_]+]] = "oneflow.output"(%[[Y]]) // CHECK: return %[[OUT2]] return %output_2 : tensor<2x2304x5120xf32> } } ================================================ FILE: oneflow/ir/test/OneFlow/fuse/test_cast_optimal_pass.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.cast import os import unittest import numpy as np import oneflow as flow import oneflow.unittest def _cast_optimal_pass(test_case, dtype): a = flow.tensor([2, 3], dtype=dtype) eager_b = flow.cast(a, dtype=dtype) class CastOpOptimalPass(flow.nn.Graph): def __init__(self): super().__init__() self.cast = flow.cast def build(self, x): return self.cast(x, dtype=dtype) lazy_b = CastOpOptimalPass()(a) test_case.assertEqual(eager_b.dtype, lazy_b.dtype) @flow.unittest.skip_unless_1n1d() class TestCastOpOptimalPass(flow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_STDOUT"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_TIMING"] = "1" os.environ["ONEFLOW_MLIR_PRINT_STATS"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_IR_PRINTING"] = "1" def test_case_optimal_pass(test_case): for dtype in [flow.float32, flow.float64, flow.int32, flow.int64]: _cast_optimal_pass(test_case, dtype) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/fuse/test_fuse_pad_conv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.pad import unittest import numpy as np import os import oneflow as flow import oneflow.unittest import oneflow.sysconfig def do_pad_conv_graph(test_case, with_cuda, with_bias, with_nchw=True): if with_nchw: x = flow.randn(2, 3, 4, 5) else: x = flow.randn(2, 4, 5, 3) conv = flow.nn.Conv2d(3, 3, 2, 1, bias=with_bias) if with_cuda: x = x.cuda() conv.to("cuda") if with_nchw: pad_x = flow.nn.functional.pad(x, (1, 1, 1, 1)) else: pad_x = flow.nn.functional.pad(x, (0, 0, 1, 1, 1, 1)) eager_conv_x = conv(pad_x) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.conv = conv def build(self, x): if with_nchw: pad_x = flow.nn.functional.pad(x, (1, 1, 1, 1)) else: pad_x = flow.nn.functional.pad(x, (0, 0, 1, 1, 1, 1)) return self.conv(pad_x) graph_to_run = GraphToRun() lazy_conv_x = graph_to_run(x) test_case.assertTrue(np.array_equal(eager_conv_x.numpy(), lazy_conv_x.numpy())) @flow.unittest.skip_unless_1n1d() class TestFusePadConv(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_STDOUT"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_TIMING"] = "1" os.environ["ONEFLOW_MLIR_PRINT_STATS"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_IR_PRINTING"] = "1" @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "needs -DBUILD_CUDA=ON") def test_pad_conv_graph_cuda(test_case): do_pad_conv_graph(test_case, True, True) do_pad_conv_graph(test_case, True, False) do_pad_conv_graph(test_case, True, False, True) def test_pad_conv_graph_cpu(test_case): do_pad_conv_graph(test_case, False, True) do_pad_conv_graph(test_case, False, False) do_pad_conv_graph(test_case, False, False, True) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/group_matmul.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -group-matmul | FileCheck %s module { // CHECK-LABEL: func.func func.func @no_bias(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>) { %1 = "oneflow.matmul"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %2 = "oneflow.matmul"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> return %1, %2 : tensor<2x1280xf16>, tensor<2x1280xf16> // CHECK: @no_bias(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>) // CHECK: %[[OUT:[a-zA-Z0-9_]+]]:2 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1]]) // CHECK: return %[[OUT]]#1, %[[OUT]]#0 } // CHECK-LABEL: func.func func.func @with_bias(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>, %bias1: tensor<1280xf16>, %bias2: tensor<1280xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>) { %1 = 
"oneflow.matmul"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %r1 = "oneflow.bias_add"(%1, %bias1) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16> %2 = "oneflow.matmul"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %r2 = "oneflow.bias_add"(%2, %bias2) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16> return %r1, %r2 : tensor<2x1280xf16>, tensor<2x1280xf16> // CHECK: @with_bias(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<1280xf16>, %[[BIAS2:[a-zA-Z0-9_]+]]: tensor<1280xf16>) // CHECK: %[[OUT:[a-zA-Z0-9_]+]]:2 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1:[a-zA-Z0-9_]+]], %[[BIAS2:[a-zA-Z0-9_]+]], %[[BIAS1:[a-zA-Z0-9_]+]]) // CHECK: return %[[OUT]]#1, %[[OUT]]#0 } // CHECK-LABEL: func.func func.func @with_broadcast_add(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>, %bias1: tensor<1280xf16>, %bias2: tensor<1280xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>) { %1 = "oneflow.matmul"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %r1 = "oneflow.broadcast_add"(%1, %bias1) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16> %2 = "oneflow.matmul"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %r2 = "oneflow.broadcast_add"(%2, %bias2) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16> return %r1, %r2 : tensor<2x1280xf16>, tensor<2x1280xf16> // CHECK: @with_broadcast_add(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<1280xf16>, %[[BIAS2:[a-zA-Z0-9_]+]]: tensor<1280xf16>) // CHECK: %[[OUT:[a-zA-Z0-9_]+]]:2 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1:[a-zA-Z0-9_]+]], %[[BIAS2:[a-zA-Z0-9_]+]], %[[BIAS1:[a-zA-Z0-9_]+]]) // CHECK: return 
%[[OUT]]#1, %[[OUT]]#0 } // CHECK-LABEL: func.func func.func @mixed(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>, %bias1: tensor<1280xf16>, %bias2: tensor<1280xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>) { %1 = "oneflow.matmul"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %r1 = "oneflow.bias_add"(%1, %bias1) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16> %2 = "oneflow.matmul"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %r2 = "oneflow.bias_add"(%2, %bias2) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16> %m1 = "oneflow.matmul"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %m2 = "oneflow.matmul"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> return %r1, %r2, %m1, %m2: tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16> // CHECK: @mixed(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<1280xf16>, %[[BIAS2:[a-zA-Z0-9_]+]]: tensor<1280xf16>) // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:2 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1:[a-zA-Z0-9_]+]], %[[BIAS2:[a-zA-Z0-9_]+]], %[[BIAS1:[a-zA-Z0-9_]+]]) // CHECK: %[[OUT1:[a-zA-Z0-9_]+]]:2 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1]]) // CHECK: return %[[OUT0]]#1, %[[OUT0]]#0, %[[OUT1]]#1, %[[OUT1]]#0 } // CHECK-LABEL: func.func func.func @left_alone(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>, %bias1: tensor<1280xf16>, %bias2: tensor<1280xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>) { %1 = "oneflow.matmul"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %r1 = "oneflow.bias_add"(%1, %bias1) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : 
(tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16> %2 = "oneflow.matmul"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> %r2 = "oneflow.bias_add"(%2, %bias2) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16> %m1 = "oneflow.matmul"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16> return %r1, %r2, %m1: tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16> // CHECK: @left_alone(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<1280xf16>, %[[BIAS2:[a-zA-Z0-9_]+]]: tensor<1280xf16>) // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:2 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1:[a-zA-Z0-9_]+]], %[[BIAS2:[a-zA-Z0-9_]+]], %[[BIAS1:[a-zA-Z0-9_]+]]) // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = "oneflow.matmul"(%arg0, %arg1) // CHECK: return %[[OUT0]]#1, %[[OUT0]]#0, %[[OUT1]] } func.func @f_broadcast_matmul(%x: tensor<2x4096x320xf16>, %w1: tensor<320x320xf16>, %w2: tensor<320x320xf16>, %w3: tensor<320x320xf16>) -> (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) { %matmul1 = "oneflow.broadcast_matmul"(%x, %w1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q-broadcast_matmul-16315", scope_symbol_id = 5497 : i64, transpose_a = false, transpose_b = true} : (tensor<2x4096x320xf16>, tensor<320x320xf16>) -> tensor<2x4096x320xf16> %matmul2 = "oneflow.broadcast_matmul"(%x, %w2) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k-broadcast_matmul-16316", scope_symbol_id = 5505 : i64, transpose_a = false, transpose_b = true} : (tensor<2x4096x320xf16>, tensor<320x320xf16>) -> tensor<2x4096x320xf16> %matmul3 = "oneflow.broadcast_matmul"(%x, %w3) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v-broadcast_matmul-16317", scope_symbol_id = 5513 : i64, transpose_a = false, transpose_b = true} : (tensor<2x4096x320xf16>, tensor<320x320xf16>) -> tensor<2x4096x320xf16> return %matmul1, %matmul2, %matmul3 : tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16> // CHECK: @f_broadcast_matmul(%[[X:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<320x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<320x320xf16>, %[[WEIGHT3:[a-zA-Z0-9_]+]]: tensor<320x320xf16>) // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:3 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[X]], %[[WEIGHT3]], %[[WEIGHT2]], %[[WEIGHT1]]) // CHECK: return %[[OUT0]]#2, %[[OUT0]]#1, %[[OUT0]]#0 } func.func @test_fused_matmul_bias_graph(%x: tensor<8x9xf64>, %w: tensor<10x9xf64>, %bias: 
tensor<10xf64>) -> (tensor<8x10xf64>, tensor<8x10xf64>) { %y0 = "oneflow.fused_matmul_bias"(%x, %w, %bias) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "fused_matmul_bias-0", scope_symbol_id = 12 : i64} : (tensor<8x9xf64>, tensor<10x9xf64>, tensor<10xf64>) -> tensor<8x10xf64> %y1 = "oneflow.fused_matmul_bias"(%x, %w, %bias) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "fused_matmul_bias-0", scope_symbol_id = 12 : i64} : (tensor<8x9xf64>, tensor<10x9xf64>, tensor<10xf64>) -> tensor<8x10xf64> return %y0, %y1 : tensor<8x10xf64>, tensor<8x10xf64> // CHECK: @test_fused_matmul_bias_graph(%[[X:[a-zA-Z0-9_]+]]: tensor<8x9xf64>, %[[W:[a-zA-Z0-9_]+]]: tensor<10x9xf64>, %[[BIAS:[a-zA-Z0-9_]+]]: tensor<10xf64>) // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:2 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[W]], %[[W]], %[[BIAS]], %[[BIAS]]) // CHECK: return %[[OUT0]]#1, %[[OUT0]]#0 } func.func @test_fused_matmul_bias_graph_mixed(%x: tensor<8x9xf64>, %w: tensor<10x9xf64>, %bias: tensor<10xf64>, %w1: tensor<10x9xf64>, %bias1: tensor<10xf64>) -> (tensor<8x10xf64>, tensor<8x10xf64>, tensor<8x10xf64>) { %y0 = "oneflow.fused_matmul_bias"(%x, %w, %bias) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "fused_matmul_bias-0", scope_symbol_id = 12 : i64} : (tensor<8x9xf64>, tensor<10x9xf64>, tensor<10xf64>) -> tensor<8x10xf64> %y1 = "oneflow.fused_matmul_bias"(%x, %w, %bias) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "fused_matmul_bias-0", scope_symbol_id = 12 : i64} : (tensor<8x9xf64>, tensor<10x9xf64>, tensor<10xf64>) -> tensor<8x10xf64> %matmul = "oneflow.matmul"(%x, %w1) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-matmul-20", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<8x9xf64>, tensor<10x9xf64>) -> tensor<8x10xf64> %bias_add = "oneflow.bias_add"(%matmul, %bias1) {axis = 1 : si32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.time_embedding.linear_1-bias_add-21", scope_symbol_id = 90 : i64} : (tensor<8x10xf64>, tensor<10xf64>) -> tensor<8x10xf64> return %y0, %y1, %bias_add : tensor<8x10xf64>, tensor<8x10xf64>, tensor<8x10xf64> // CHECK: @test_fused_matmul_bias_graph_mixed(%[[X:[a-zA-Z0-9_]+]]: tensor<8x9xf64>, %[[W:[a-zA-Z0-9_]+]]: tensor<10x9xf64>, %[[BIAS:[a-zA-Z0-9_]+]]: tensor<10xf64>, %[[W1:[a-zA-Z0-9_]+]]: tensor<10x9xf64>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<10xf64>) // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:3 = "oneflow.grouped_matmul_bias"(%[[X]], %[[X]], %[[X]], %[[W1]], %[[W]], %[[W]], %[[BIAS1]], %[[BIAS]], %[[BIAS]]) // CHECK: return %[[OUT0]]#2, %[[OUT0]]#1, %[[OUT0]]#0 } } ================================================ FILE: oneflow/ir/test/OneFlow/jit_outline_func.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -ofjob-to-func \ // RUN: -convert-to-signless-for-tosa \ // RUN: -lower-oneflow-to-tosa="full=0 lower-job=0" \ // RUN: --tosa-make-broadcastable \ // RUN: -lower-oneflow-to-linalg \ // RUN: -tosa-to-tensor \ // RUN: | oneflow-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg))" \ // RUN: | oneflow-opt -linalg-fuse-elementwise-ops \ // RUN: -func-to-ofjob \ // RUN: | oneflow-opt -pass-pipeline="builtin.module(oneflow.job(outline-jit-function{compile-to-llvm=0}))" \ // RUN: | oneflow-opt -canonicalize \ // RUN: | FileCheck --dump-input=always %s // CHECK: linalg.generic // 
CHECK: oneflow.mlir_jit // CHECK-NOT: oneflow.softmax oneflow.job @GraphToRun_11(%arg0: tensor<2x256x1280xf16>, %arg1: tensor<2x77x1280xf16>, %arg2: tensor<2x77x1280xf16>) -> tensor<2x256x1280xf16> { %output = "oneflow.input"(%arg0) {data_type = 9 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_11_input.0.0_2", output_lbns = ["_GraphToRun_11_input.0.0_2/out"], scope_symbol_id = 681 : i64, shape = [2 : si64, 256 : si64, 1280 : si64]} : (tensor<2x256x1280xf16>) -> tensor<2x256x1280xf16> %output_0 = "oneflow.input"(%arg1) {data_type = 9 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_11_input.0.1_3", output_lbns = ["_GraphToRun_11_input.0.1_3/out"], scope_symbol_id = 681 : i64, shape = [2 : si64, 77 : si64, 1280 : si64]} : (tensor<2x77x1280xf16>) -> tensor<2x77x1280xf16> %output_1 = "oneflow.input"(%arg2) {data_type = 9 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_11_input.0.2_4", output_lbns = ["_GraphToRun_11_input.0.2_4/out"], scope_symbol_id = 681 : i64, shape = [2 : si64, 77 : si64, 1280 : si64]} : (tensor<2x77x1280xf16>) -> tensor<2x77x1280xf16> %0 = "oneflow.reshape"(%output) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-0", scope_symbol_id = 681 : i64, shape = [2 : si64, 256 : si64, 8 : si64, 160 : si64]} : (tensor<2x256x1280xf16>) -> tensor<2x256x8x160xf16> %1 = "oneflow.reshape"(%output_0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-2", scope_symbol_id = 681 : i64, shape = [2 : si64, 77 : si64, 8 : si64, 160 : si64]} : (tensor<2x77x1280xf16>) -> tensor<2x77x8x160xf16> %2 = "oneflow.reshape"(%output_1) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-4", scope_symbol_id = 681 : i64, shape = [2 : si64, 77 : si64, 8 : si64, 160 : si64]} : (tensor<2x77x1280xf16>) -> tensor<2x77x8x160xf16> %3 = "oneflow.transpose"(%0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-1", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 681 : i64} : (tensor<2x256x8x160xf16>) -> tensor<2x8x256x160xf16> %4 = "oneflow.transpose"(%1) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-3", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 681 : i64} : (tensor<2x77x8x160xf16>) -> tensor<2x8x77x160xf16> %5 = "oneflow.transpose"(%2) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-5", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 681 : i64} : (tensor<2x77x8x160xf16>) -> tensor<2x8x77x160xf16> %6 = "oneflow.reshape"(%3) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-6", scope_symbol_id = 681 : i64, shape = [16 : si64, 256 : si64, 160 : si64]} : (tensor<2x8x256x160xf16>) -> tensor<16x256x160xf16> %7 = "oneflow.reshape"(%4) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-7", scope_symbol_id = 681 : i64, shape = [16 : si64, 77 : si64, 160 : si64]} : (tensor<2x8x77x160xf16>) -> tensor<16x77x160xf16> %8 = "oneflow.reshape"(%5) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-9", scope_symbol_id = 681 : i64, shape = [16 : si64, 77 : si64, 160 : si64]} : (tensor<2x8x77x160xf16>) -> tensor<16x77x160xf16> %9 = 
"oneflow.transpose"(%7) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-8", perm = [0 : si32, 2 : si32, 1 : si32], scope_symbol_id = 681 : i64} : (tensor<16x77x160xf16>) -> tensor<16x160x77xf16> %10 = "oneflow.batch_matmul"(%6, %9) {alpha = 0.079056941504209485 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "batch_matmul-11", scope_symbol_id = 681 : i64, transpose_a = false, transpose_b = false} : (tensor<16x256x160xf16>, tensor<16x160x77xf16>) -> tensor<16x256x77xf16> %11 = "oneflow.softmax"(%10) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "softmax-12", scope_symbol_id = 681 : i64} : (tensor<16x256x77xf16>) -> tensor<16x256x77xf16> %12 = "oneflow.batch_matmul"(%11, %8) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "batch_matmul-13", scope_symbol_id = 681 : i64, transpose_a = false, transpose_b = false} : (tensor<16x256x77xf16>, tensor<16x77x160xf16>) -> tensor<16x256x160xf16> %13 = "oneflow.reshape"(%12) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-14", scope_symbol_id = 681 : i64, shape = [2 : si64, 8 : si64, 256 : si64, 160 : si64]} : (tensor<16x256x160xf16>) -> tensor<2x8x256x160xf16> %14 = "oneflow.transpose"(%13) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-15", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 681 : i64} : (tensor<2x8x256x160xf16>) -> tensor<2x256x8x160xf16> %15 = "oneflow.reshape"(%14) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-16", scope_symbol_id = 681 : i64, shape = [2 : si64, 256 : si64, 1280 : si64]} : (tensor<2x256x8x160xf16>) -> tensor<2x256x1280xf16> %output_2 = "oneflow.output"(%15) {data_type = 9 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_11_output.0.0_2", output_lbns = ["_GraphToRun_11_output.0.0_2/out"], scope_symbol_id = 681 : i64, shape = [2 : si64, 256 : si64, 1280 : si64]} : (tensor<2x256x1280xf16>) -> tensor<2x256x1280xf16> oneflow.return %output_2 : tensor<2x256x1280xf16> } // CHECK: oneflow.mlir_jit // CHECK-NOT: oneflow.cast oneflow.job @GraphToRun_1(%arg0: tensor<2x5xsi64>, %arg1: tensor<1xf32>) -> tensor<2x5xf32> { %output = "oneflow.input"(%arg0) {data_type = 6 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_input.0.0_2", output_lbns = ["_GraphToRun_1_input.0.0_2/out"], scope_symbol_id = 34 : i64, shape = [2 : si64, 5 : si64]} : (tensor<2x5xsi64>) -> tensor<2x5xsi64> %output_0 = "oneflow.input"(%arg1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_input.0.1_3", output_lbns = ["_GraphToRun_1_input.0.1_3/out"], scope_symbol_id = 34 : i64, shape = [1 : si64]} : (tensor<1xf32>) -> tensor<1xf32> %0 = "oneflow.cast"(%output) {device_name = ["@0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "fw-cast-0", pin_memory = false, scope_symbol_id = 41 : i64} : (tensor<2x5xsi64>) -> tensor<2x5xf32> %1 = "oneflow.scalar_mul_by_tensor"(%0, %output_0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "fw-broadcast_mul-1-mlir-gen-2", scope_symbol_id = 41 : i64} : (tensor<2x5xf32>, tensor<1xf32>) -> tensor<2x5xf32> %output_1 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = 
["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_output.0.0_2", output_lbns = ["_GraphToRun_1_output.0.0_2/out"], scope_symbol_id = 34 : i64, shape = [2 : si64, 5 : si64]} : (tensor<2x5xf32>) -> tensor<2x5xf32> oneflow.return %output_1 : tensor<2x5xf32> } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OKLPass/lower_launcher_to_llvm_ptr.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -lower-launcher-to-llvm-ptr \ // RUN: | FileCheck %s // CHECK: func.func @okl_subgraph(%[[ARG:[a-zA-Z0-9_]+]]: !llvm.ptr) attributes {llvm.emit_c_interface} { // CHECK: %[[ARG0:[a-zA-Z0-9_]+]] = builtin.unrealized_conversion_cast %[[ARG]] : !llvm.ptr to !okl.launcher_ctx // CHECK: "okl.get_tensor_from_arg"(%[[ARG0]]) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> // CHECK: "okl.get_tensor_as_ret"(%[[ARG0]], %[[ARG3:[a-zA-Z0-9_]+]]) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> module { func.func @okl_subgraph(%arg0: !okl.launcher_ctx) attributes {cuda_graph_support = false, pool_size = 1024 : i64} { "okl.wrapper_kernel"() ({ %0 = "okl.get_tensor_from_arg"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "okl.tensor_to_pool"(%arg0, %1) {offset = 0 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 0 : i32} : () -> () "okl.wrapper_kernel"() ({ %0 = "okl.pool_to_tensor"(%arg0) {offset = 0 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "okl.tensor_to_pool"(%arg0, %1) {offset = 512 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 1 : i32} : () -> () "okl.wrapper_kernel"() ({ %0 = "okl.pool_to_tensor"(%arg0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.arg_sort"(%0) {device_name = ["@0:0"], device_tag = "cpu", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %2 = "okl.get_tensor_as_ret"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xsi32>) -> tensor<2xsi32> okl.return }) {index = 2 : i32} : () -> () "okl.wrapper_kernel"() ({ %0 = "okl.pool_to_tensor"(%arg0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "okl.get_tensor_from_ret"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xsi32> %2 = "oneflow.dim_gather"(%0, %1) {device_name = ["@0:0"], device_tag = "cpu", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %3 = "okl.get_tensor_as_ret"(%arg0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 3 : i32} : () -> () return } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OKLPass/lower_okl_to_llvm_call.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -lower-okl-to-llvm-call \ // RUN: | FileCheck %s // CHECK-COUNT-4: llvm.call @okl_llvm_func module { func.func @okl_subgraph(%arg0: !llvm.ptr) attributes {llvm.emit_c_interface} { %0 = 
builtin.unrealized_conversion_cast %arg0 : !llvm.ptr to !okl.launcher_ctx "okl.wrapper_kernel"() ({ %1 = "okl.get_tensor_from_arg"(%0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %2 = "oneflow.relu"(%1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %3 = "okl.tensor_to_pool"(%0, %2) {offset = 0 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 0 : i32} : () -> () "okl.wrapper_kernel"() ({ %1 = "okl.pool_to_tensor"(%0) {offset = 0 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %2 = "oneflow.tanh"(%1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %3 = "okl.tensor_to_pool"(%0, %2) {offset = 512 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 1 : i32} : () -> () "okl.wrapper_kernel"() ({ %1 = "okl.pool_to_tensor"(%0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %2 = "oneflow.arg_sort"(%1) {device_name = ["@0:0"], device_tag = "cpu", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %3 = "okl.get_tensor_as_ret"(%0, %2) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xsi32>) -> tensor<2xsi32> okl.return }) {index = 2 : i32} : () -> () "okl.wrapper_kernel"() ({ %1 = "okl.pool_to_tensor"(%0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %2 = "okl.get_tensor_from_ret"(%0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xsi32> %3 = "oneflow.dim_gather"(%1, %2) {device_name = ["@0:0"], device_tag = "cpu", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %4 = "okl.get_tensor_as_ret"(%0, %3) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 3 : i32} : () -> () return } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OKLPass/tag_cuda_graph_support.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -tag-cuda-graph-support \ // RUN: | FileCheck %s // CHECK: func.func @okl_subgraph(%[[ARG0:[a-zA-Z0-9_]+]]: !okl.launcher_ctx) attributes {cuda_graph_support = false, pool_size = 1024 : i64} module { func.func @okl_subgraph(%arg0: !okl.launcher_ctx) attributes {pool_size = 1024 : i64} { "okl.wrapper_kernel"() ({ %0 = "okl.get_tensor_from_arg"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "okl.tensor_to_pool"(%arg0, %1) {offset = 0 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 0 : i32} : () -> () "okl.wrapper_kernel"() ({ %0 = "okl.pool_to_tensor"(%arg0) {offset = 0 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "okl.tensor_to_pool"(%arg0, %1) {offset = 512 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 1 : i32} : () -> () "okl.wrapper_kernel"() ({ %0 = "okl.pool_to_tensor"(%arg0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "oneflow.arg_sort"(%0) {device_name = 
["@0:0"], device_tag = "cpu", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %2 = "okl.get_tensor_as_ret"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xsi32>) -> tensor<2xsi32> okl.return }) {index = 2 : i32} : () -> () "okl.wrapper_kernel"() ({ %0 = "okl.pool_to_tensor"(%arg0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32> %1 = "okl.get_tensor_from_ret"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xsi32> %2 = "oneflow.dim_gather"(%0, %1) {device_name = ["@0:0"], device_tag = "cpu", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %3 = "okl.get_tensor_as_ret"(%arg0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32> okl.return }) {index = 3 : i32} : () -> () return } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OKMPass/extract_okm_tensor.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -extract-okm-tensor \ // RUN: | FileCheck %s // CHECK: "okm.arg_to_tensor"() {index = 0 : i32} : () -> tensor<2xf32> // CHECK: "okm.tensor_to_ret"(%[[ARG0:[a-zA-Z0-9_]+]]) {index = 0 : i32} : (tensor<2xsi32>) -> tensor<2xsi32> // CHECK: "okm.tensor_to_ret"(%[[ARG1:[a-zA-Z0-9_]+]]) {index = 1 : i32} : (tensor<2xf32>) -> tensor<2xf32> module { func.func @_mlir_oneflow_subgraph0(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) attributes {llvm.emit_c_interface} { %0 = "oneflow.relu"(%arg0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "oneflow.arg_sort"(%1) {device_name = ["@0:0"], device_tag = "cpu", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %3 = "oneflow.dim_gather"(%1, %2) {device_name = ["@0:0"], device_tag = "cpu", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> return %2, %3 : tensor<2xsi32>, tensor<2xf32> } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OKMPass/okm_to_okl.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -convert-okm-to-okl \ // RUN: | FileCheck %s // CHECK: func.func @okl_subgraph(%arg0: !okl.launcher_ctx) attributes {pool_size = 1024 : i64} { // CHECK-COUNT-4: "okl.wrapper_kernel"() module { func.func @okm_alloc_subgraph0() { %c512 = arith.constant 512 : index %c0 = arith.constant 0 : index %0 = "okm.alloc_memref"() : () -> memref<1024xi8> %1 = "okm.arg_to_memref"() {index = 0 : i32} : () -> memref<2xf32> %2 = memref.view %0[%c0][] : memref<1024xi8> to memref<2xf32> %3 = "okm.wrapper_kernel"(%1, %2) ({ %12 = bufferization.to_tensor %1 : memref<2xf32> %13 = "oneflow.relu"(%12) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %14 = bufferization.to_memref %13 : memref<2xf32> okm.return %14 : memref<2xf32> }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32> %4 = memref.view %0[%c512][] : memref<1024xi8> to memref<2xf32> %5 = 
"okm.wrapper_kernel"(%2, %4) ({ %12 = bufferization.to_tensor %2 : memref<2xf32> %13 = "oneflow.tanh"(%12) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %14 = bufferization.to_memref %13 : memref<2xf32> okm.return %14 : memref<2xf32> }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32> %6 = "okm.ret_to_memref"() {index = 0 : i32} : () -> memref<2xsi32> %7 = "okm.wrapper_kernel"(%4, %6) ({ %12 = bufferization.to_tensor %4 : memref<2xf32> %13 = "oneflow.arg_sort"(%12) {device_name = ["@0:0"], device_tag = "cpu", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %14 = bufferization.to_memref %13 : memref<2xsi32> okm.return %14 : memref<2xsi32> }) : (memref<2xf32>, memref<2xsi32>) -> memref<2xsi32> %8 = "okm.ret_to_memref"() {index = 1 : i32} : () -> memref<2xf32> %9 = "okm.wrapper_kernel"(%4, %6, %8) ({ %12 = bufferization.to_tensor %4 : memref<2xf32> %13 = bufferization.to_tensor %6 : memref<2xsi32> %14 = "oneflow.dim_gather"(%12, %13) {device_name = ["@0:0"], device_tag = "cpu", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %15 = bufferization.to_memref %14 : memref<2xf32> okm.return %15 : memref<2xf32> }) : (memref<2xf32>, memref<2xsi32>, memref<2xf32>) -> memref<2xf32> %10 = "okm.memref_to_ret"(%7) {index = 0 : i32} : (memref<2xsi32>) -> memref<2xsi32> %11 = "okm.memref_to_ret"(%9) {index = 1 : i32} : (memref<2xf32>) -> memref<2xf32> return } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OKMPass/opt_okm_memref.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -opt-okm-memref \ // RUN: | FileCheck %s // CHECK: func.func @okm_alloc_subgraph // CHECK: "okm.alloc_memref"() // CHECK: memref.view module { func.func @okm_wrap_subgraph0() { %0 = "okm.arg_to_memref"() {index = 0 : i32} : () -> memref<2xf32> %1 = "okm.plan_memref"() : () -> memref<2xf32> %2 = "okm.wrapper_kernel"(%0, %1) ({ %11 = bufferization.to_tensor %0 : memref<2xf32> %12 = "oneflow.relu"(%11) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %13 = bufferization.to_memref %12 : memref<2xf32> okm.return %13 : memref<2xf32> }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32> %3 = "okm.plan_memref"() : () -> memref<2xf32> %4 = "okm.wrapper_kernel"(%1, %3) ({ %11 = bufferization.to_tensor %1 : memref<2xf32> %12 = "oneflow.tanh"(%11) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %13 = bufferization.to_memref %12 : memref<2xf32> okm.return %13 : memref<2xf32> }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32> %5 = "okm.ret_to_memref"() {index = 0 : i32} : () -> memref<2xsi32> %6 = "okm.wrapper_kernel"(%3, %5) ({ %11 = bufferization.to_tensor %3 : memref<2xf32> %12 = "oneflow.arg_sort"(%11) {device_name = ["@0:0"], device_tag = "cpu", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %13 = bufferization.to_memref %12 : memref<2xsi32> okm.return %13 : memref<2xsi32> }) : (memref<2xf32>, memref<2xsi32>) -> memref<2xsi32> %7 = "okm.ret_to_memref"() {index = 1 : i32} : () -> memref<2xf32> %8 = 
"okm.wrapper_kernel"(%3, %5, %7) ({ %11 = bufferization.to_tensor %3 : memref<2xf32> %12 = bufferization.to_tensor %5 : memref<2xsi32> %13 = "oneflow.dim_gather"(%11, %12) {device_name = ["@0:0"], device_tag = "cpu", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %14 = bufferization.to_memref %13 : memref<2xf32> okm.return %14 : memref<2xf32> }) : (memref<2xf32>, memref<2xsi32>, memref<2xf32>) -> memref<2xf32> %9 = "okm.memref_to_ret"(%6) {index = 0 : i32} : (memref<2xsi32>) -> memref<2xsi32> %10 = "okm.memref_to_ret"(%8) {index = 1 : i32} : (memref<2xf32>) -> memref<2xf32> return } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OKMPass/wrap_okm_kernel.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -wrap-okm-kernel \ // RUN: | FileCheck %s // CHECK: module { // CHECK: func.func @okm_wrap_subgraph0() { // CHECK: %[[ARG0:[a-zA-Z0-9_]+]] = "okm.arg_to_memref"() {index = 0 : i32} : () -> memref<2xf32> // CHECK: %[[ARG1:[a-zA-Z0-9_]+]] = "okm.plan_memref"() : () -> memref<2xf32> // CHECK: %[[ARG2:[a-zA-Z0-9_]+]] = "okm.wrapper_kernel"(%[[ARG0]], %[[ARG1]]) ({ // CHECK: %[[ARG11:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG0]] : memref<2xf32> // CHECK: %[[ARG12:[a-zA-Z0-9_]+]] = "oneflow.relu"(%[[ARG11]]) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[ARG13:[a-zA-Z0-9_]+]] = bufferization.to_memref %[[ARG12]] : memref<2xf32> // CHECK: okm.return %[[ARG13:[a-zA-Z0-9_]+]] : memref<2xf32> // CHECK: }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32> // CHECK: %[[ARG3:[a-zA-Z0-9_]+]] = "okm.plan_memref"() : () -> memref<2xf32> // CHECK: %[[ARG4:[a-zA-Z0-9_]+]] = "okm.wrapper_kernel"(%[[ARG1]], %[[ARG3]]) ({ // CHECK: %[[ARG11:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG1]] : memref<2xf32> // CHECK: %[[ARG12:[a-zA-Z0-9_]+]] = "oneflow.tanh"(%[[ARG11]]) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[ARG13:[a-zA-Z0-9_]+]] = bufferization.to_memref %[[ARG12:[a-zA-Z0-9_]+]] : memref<2xf32> // CHECK: okm.return %[[ARG13:[a-zA-Z0-9_]+]] : memref<2xf32> // CHECK: }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32> // CHECK: %[[ARG5:[a-zA-Z0-9_]+]] = "okm.ret_to_memref"() {index = 0 : i32} : () -> memref<2xsi32> // CHECK: %[[ARG6:[a-zA-Z0-9_]+]] = "okm.wrapper_kernel"(%[[ARG3]], %[[ARG5]]) ({ // CHECK: %[[ARG11:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG3]] : memref<2xf32> // CHECK: %[[ARG12:[a-zA-Z0-9_]+]] = "oneflow.arg_sort"(%[[ARG11]]) {device_name = ["@0:0"], device_tag = "cpu", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> // CHECK: %[[ARG13:[a-zA-Z0-9_]+]] = bufferization.to_memref %[[ARG12]] : memref<2xsi32> // CHECK: okm.return %[[ARG13:[a-zA-Z0-9_]+]] : memref<2xsi32> // CHECK: }) : (memref<2xf32>, memref<2xsi32>) -> memref<2xsi32> // CHECK: %[[ARG7:[a-zA-Z0-9_]+]] = "okm.ret_to_memref"() {index = 1 : i32} : () -> memref<2xf32> // CHECK: %[[ARG8:[a-zA-Z0-9_]+]] = "okm.wrapper_kernel"(%[[ARG3]], %[[ARG5]], %7) ({ // CHECK: %[[ARG11:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG3]] : memref<2xf32> // CHECK: %[[ARG12:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG5]] : memref<2xsi32> // CHECK: 
%[[ARG13:[a-zA-Z0-9_]+]] = "oneflow.dim_gather"(%[[ARG11]], %[[ARG12]]) {device_name = ["@0:0"], device_tag = "cpu", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> // CHECK: %[[ARG14:[a-zA-Z0-9_]+]] = bufferization.to_memref %[[ARG13]] : memref<2xf32> // CHECK: okm.return %[[ARG14]] : memref<2xf32> // CHECK: }) : (memref<2xf32>, memref<2xsi32>, memref<2xf32>) -> memref<2xf32> // CHECK: %[[ARG9:[a-zA-Z0-9_]+]] = "okm.memref_to_ret"(%[[ARG6]]) {index = 0 : i32} : (memref<2xsi32>) -> memref<2xsi32> // CHECK: %[[ARG10:[a-zA-Z0-9_]+]] = "okm.memref_to_ret"(%[[ARG8]]) {index = 1 : i32} : (memref<2xf32>) -> memref<2xf32> // CHECK: return // CHECK: } // CHECK: } module { func.func @okm_subgraph0(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) attributes {llvm.emit_c_interface} { %0 = "okm.arg_to_tensor"() {index = 0 : i32} : () -> tensor<2xf32> %1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "oneflow.tanh"(%1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %3 = "oneflow.arg_sort"(%2) {device_name = ["@0:0"], device_tag = "cpu", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %4 = "oneflow.dim_gather"(%2, %3) {device_name = ["@0:0"], device_tag = "cpu", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %5 = "okm.tensor_to_ret"(%3) {index = 0 : i32} : (tensor<2xsi32>) -> tensor<2xsi32> %6 = "okm.tensor_to_ret"(%4) {index = 1 : i32} : (tensor<2xf32>) -> tensor<2xf32> return %5, %6 : tensor<2xsi32>, tensor<2xf32> } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OneFlowPass/aggregate_compute_ops.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -aggregate-compute-ops \ // RUN: | FileCheck %s // CHECK: %[[ARG0:[a-zA-Z0-9_]+]] = "oneflow.arg_sort" // CHECK: %[[ARG1:[a-zA-Z0-9_]+]] = "oneflow.dim_gather" // CHECK: "oneflow.output"(%[[ARG0]]) module { oneflow.job @GraphToRun_1(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) { %output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_input.0.0_2", output_lbns = ["_GraphToRun_1_input.0.0_2/out"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32> %0 = "oneflow.relu"(%output) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32> %1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "oneflow.arg_sort"(%1) {device_name = ["@0:0"], device_tag = "cuda", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %output_0 = "oneflow.output"(%2) {data_type = 5 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_output.0.0.1_4", output_lbns = ["_GraphToRun_1_output.0.0.1_4/out"], 
scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xsi32>) -> tensor<2xsi32> %3 = "oneflow.dim_gather"(%1, %2) {device_name = ["@0:0"], device_tag = "cuda", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 30 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %output_1 = "oneflow.output"(%3) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_output.0.0.0_3", output_lbns = ["_GraphToRun_1_output.0.0.0_3/out"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32> oneflow.return %output_0, %output_1 : tensor<2xsi32>, tensor<2xf32> } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OneFlowPass/wrap_ops_to_kernel_launch/cuda_graph.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -wrap-ops-to-kernel-launch="mode=cuda_graph" \ // RUN: | FileCheck %s // CHECK: func.func @_mlir_oneflow_subgraph1 // CHECK: func.func @_mlir_oneflow_subgraph0 module { oneflow.job @GraphToRun_0(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) { %output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_input.0.0_2", output_lbns = ["_GraphToRun_0_input.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32> %0 = "oneflow.relu"(%output) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "oneflow.arg_sort"(%1) {device_name = ["@0:0"], device_tag = "cuda", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %3 = "oneflow.dim_gather"(%1, %2) {device_name = ["@0:0"], device_tag = "cuda", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %output_0 = "oneflow.output"(%3) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_output.0.0.0_3", output_lbns = ["_GraphToRun_0_output.0.0.0_3/out"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32> %output_1 = "oneflow.output"(%2) {data_type = 5 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_output.0.0.1_4", output_lbns = ["_GraphToRun_0_output.0.0.1_4/out"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xsi32>) -> tensor<2xsi32> oneflow.return %output_1, %output_0 : tensor<2xsi32>, tensor<2xf32> } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OneFlowPass/wrap_ops_to_kernel_launch/lit.local.cfg ================================================ if not config.BUILD_CUDA: config.unsupported = True ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/OneFlowPass/wrap_ops_to_kernel_launch/simple.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -wrap-ops-to-kernel-launch="mode=simple" \ // RUN: | FileCheck %s // 
CHECK-NOT: func.func @_mlir_oneflow_subgraph1 // CHECK: func.func @_mlir_oneflow_subgraph0 module { oneflow.job @GraphToRun_0(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) { %output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_input.0.0_2", output_lbns = ["_GraphToRun_0_input.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32> %0 = "oneflow.relu"(%output) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32> %2 = "oneflow.arg_sort"(%1) {device_name = ["@0:0"], device_tag = "cuda", direction = "ASCENDING", hierarchy = [1], op_name = "arg_sort-2", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32> %3 = "oneflow.dim_gather"(%1, %2) {device_name = ["@0:0"], device_tag = "cuda", dim = 0 : si32, hierarchy = [1], op_name = "dim_gather-3", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32> %output_0 = "oneflow.output"(%3) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_output.0.0.0_3", output_lbns = ["_GraphToRun_0_output.0.0.0_3/out"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32> %output_1 = "oneflow.output"(%2) {data_type = 5 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_output.0.0.1_4", output_lbns = ["_GraphToRun_0_output.0.0.1_4/out"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xsi32>) -> tensor<2xsi32> oneflow.return %output_1, %output_0 : tensor<2xsi32>, tensor<2xf32> } } ================================================ FILE: oneflow/ir/test/OneFlow/kernel_launch/test_resnet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s import os import sys sys.path.append(os.path.abspath(os.path.dirname(__file__))) sys.path.append(os.path.abspath(os.path.dirname(__file__)) + "/..") import unittest import numpy as np import oneflow as flow import oneflow.unittest from networks.resnet50 import resnet50 def _test_okl_resnet(test_case): x = flow.randn(2, 3, 224, 224) resnet = resnet50() x = x.cuda() resnet.to("cuda") eager_res = resnet(x) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.resnet = resnet def build(self, x): return self.resnet(x) graph_to_run = GraphToRun() lazy_res = graph_to_run(x) test_case.assertTrue( np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-4, atol=1e-4) ) @flow.unittest.skip_unless_1n1d() class TestOKLResNet(flow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_STDOUT"] = "1" os.environ["ONEFLOW_MLIR_CSE"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH"] = "1" os.environ["ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH"] = "1" @unittest.skipUnless(flow.sysconfig.with_cuda(), "only test cpu cases") def test_okl_resnet(test_case): _test_okl_resnet(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/networks/__init__.py ================================================ ================================================ FILE: oneflow/ir/test/OneFlow/networks/resnet50.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow as flow import oneflow.nn as nn from oneflow import Tensor from typing import Type, Any, Callable, Union, List, Optional def conv3x3( in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 ) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation, ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion: int = 1 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU() self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Bottleneck(nn.Module): expansion: int = 4 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU() self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = 
[False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( "replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation) ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer( block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] ) self.layer3 = self._make_layer( block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] ) self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) self.avgpool = nn.AvgPool2d((7, 7)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] def _make_layer( self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False, ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append( block( self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, ) ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block( self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, ) ) return nn.Sequential(*layers) def _forward_impl(self, x: Tensor) -> Tensor: x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = flow.flatten(x, 1) x = self.fc(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def _resnet( arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], **kwargs: Any ) -> ResNet: model = ResNet(block, layers, **kwargs) return model def resnet50(**kwargs: Any) -> ResNet: r"""ResNet-5 `"Deep Residual Learning for Image Recognition" `_. 
""" return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], **kwargs) ================================================ FILE: oneflow/ir/test/OneFlow/oneflow-opt.mlir ================================================ // RUN: oneflow-opt --show-dialects | FileCheck %s // CHECK: Available Dialects: // CHECK: oneflow ================================================ FILE: oneflow/ir/test/OneFlow/oneflow-translate.mlir ================================================ // RUN: oneflow-translate --help | FileCheck %s // CHECK: --import-oneflow-job ================================================ FILE: oneflow/ir/test/OneFlow/psig/error_parse.mlir ================================================ // RUN: not oneflow-opt %s \ // RUN: -split-input-file \ // RUN: -verify-diagnostics -o - 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s // CHECK_ERROR_1: unexpected error: failed to parse a sbp attribute here module { oneflow.job @test_err(){ %output_0 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0", "@1:1"], device_tag = "cuda", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [[[]], "S(0)", #sbp.P]>, op_name = "net-FreeEagerTensor-2", output_lbns = ["net-FreeEagerTensor-2/out"], scope_symbol_id = 14 : i64, shape = [5 : si64, 8 : si64], trainable = false} : () -> tensor<5x8xf32> oneflow.return } } ================================================ FILE: oneflow/ir/test/OneFlow/psig/sbp_parse.mlir ================================================ // RUN: oneflow-opt %s \ // RUN: -split-input-file \ // RUN: -verify-diagnostics -o - | FileCheck %s // CHECK-LABEL: test_single module { oneflow.job @test_single(){ // CHECK: parallel = #sbp.parallel<[] -> [#sbp.B, #sbp.S<0>]> %output = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0", "@1:1"], device_tag = "cuda", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [#sbp.B, #sbp.S<0>]>, op_name = "net-FreeEagerTensor-1", output_lbns = ["net-FreeEagerTensor-1/out"], scope_symbol_id = 14 : i64, shape = [4 : si64, 5 : si64], trainable = false} : () -> tensor<4x5xf32> // CHECK: parallel = #sbp.parallel<[] -> [#sbp.B, #sbp.P]> %output_0 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0", "@1:1"], device_tag = "cuda", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [#sbp.B, #sbp.P]>, op_name = "net-FreeEagerTensor-2", output_lbns = ["net-FreeEagerTensor-2/out"], scope_symbol_id = 14 : i64, shape = [5 : si64, 8 : si64], trainable = false} : () -> tensor<5x8xf32> oneflow.return } } // CHECK-LABEL: test_nd module { oneflow.job @test_nd(){ // CHECK: #sbp.B, #sbp.S<0> %output = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0", "@1:1"], device_tag = "cuda", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [[#sbp.B, #sbp.S<0>]]>, op_name = "net-FreeEagerTensor-1", output_lbns = ["net-FreeEagerTensor-1/out"], scope_symbol_id = 14 : i64, shape = [4 : si64, 5 : si64], trainable = false} : () -> tensor<4x5xf32> // CHECK: [#sbp.B, #sbp.P] %output_0 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0", "@1:1"], device_tag = "cuda", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [[#sbp.B, #sbp.P]]>, op_name = "net-FreeEagerTensor-2", output_lbns = ["net-FreeEagerTensor-2/out"], scope_symbol_id = 14 : i64, shape = [5 : si64, 8 : si64], trainable = false} : () -> tensor<5x8xf32> oneflow.return } } ================================================ FILE: oneflow/ir/test/OneFlow/psig/test_2nd_basic_parse.py ================================================ """ Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.distributed.launch --nproc_per_node 2 %s | FileCheck %s # CHECK: [#sbp.B, #sbp.S<0>] # CHECK: [#sbp.B, #sbp.S<0>] import oneflow as flow import unittest import oneflow.unittest import os from google.protobuf import text_format def _test_nd_basic_parse(test_case): class ModuleToRun(flow.nn.Module): def __init__(self): super().__init__() P0 = flow.placement("cpu", ranks=[[0], [1]]) a0_sbp = (flow.sbp.broadcast, flow.sbp.split(0)) b0_sbp = (flow.sbp.broadcast, flow.sbp.split(0)) self.A0 = flow.randn(4, 5, placement=P0, sbp=a0_sbp) self.B0 = flow.randn(5, 8, placement=P0, sbp=b0_sbp) def forward(self): return flow.matmul(self.A0, self.B0) net = ModuleToRun() class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.net = net def build(self): return self.net() graph_to_run = GraphToRun() lazy_output = graph_to_run() serialized_job = graph_to_run._forward_job_proto.SerializeToString() mlir = flow._oneflow_internal.nn.graph.ConvertJobToIR(serialized_job) print(mlir) @flow.unittest.skip_unless_1n1d() class TestBasicParse(flow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" def test_nd_basic_parse(test_case): _test_nd_basic_parse(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/psig/test_basic_parse.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 %s | FileCheck %s # CHECK: [#sbp.B] # CHECK: [#sbp.S<0>] import oneflow as flow import unittest import oneflow.unittest import os from google.protobuf import text_format def _test_1nd_basic_parse(test_case): class ModuleToRun(flow.nn.Module): def __init__(self): super().__init__() P0 = flow.placement("cpu", ranks=[0]) a0_sbp = flow.sbp.broadcast b0_sbp = flow.sbp.split(0) self.A0 = flow.randn(4, 5, placement=P0, sbp=a0_sbp) self.B0 = flow.randn(5, 8, placement=P0, sbp=b0_sbp) def forward(self): return flow.matmul(self.A0, self.B0) net = ModuleToRun() class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.net = net def build(self): return self.net() graph_to_run = GraphToRun() lazy_output = graph_to_run() serialized_job = graph_to_run._forward_job_proto.SerializeToString() mlir = flow._oneflow_internal.nn.graph.ConvertJobToIR(serialized_job) print(mlir) @flow.unittest.skip_unless_1n1d() class TestBasicParse(flow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" def test_1nd_basic_parse(test_case): _test_1nd_basic_parse(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/traits.mlir ================================================ // RUN: oneflow-opt -test-oneflow-trait-folder %s | FileCheck %s // CHECK-LABEL: func.func @testSingleIdempotent // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @testSingleIdempotent(%arg0 : tensor) -> tensor { // CHECK: [[IDEMPOTENT:%.+]] = "oneflow.relu"([[ARG0]]) %0 = "oneflow.relu"(%arg0) {device_tag = "cuda", op_name = "Relu_1", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: return [[IDEMPOTENT]] return %0: tensor } // CHECK-LABEL: func.func @testDoubleIdempotent // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @testDoubleIdempotent(%arg0: tensor) -> tensor { // CHECK: [[IDEMPOTENT:%.+]] = "oneflow.relu"([[ARG0]]) %0 = "oneflow.relu"(%arg0) {device_tag = "cuda", op_name = "Relu_1", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor %1 = "oneflow.relu"(%0) {device_tag = "cuda", op_name = "Relu_2", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: return [[IDEMPOTENT]] return %1: tensor } // CHECK-LABEL: func.func @testTripleIdempotent // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @testTripleIdempotent(%arg0: tensor) -> tensor { // CHECK: [[IDEMPOTENT:%.+]] = "oneflow.relu"([[ARG0]]) %0 = "oneflow.relu"(%arg0) {device_tag = "cuda", op_name = "Relu_1", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor %1 = "oneflow.relu"(%0) {device_tag = "cuda", op_name = "Relu_2", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor %2 = "oneflow.relu"(%1) {device_tag = "cuda", op_name = "Relu_3", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: return [[IDEMPOTENT]] return %2: tensor } // CHECK-LABEL: func.func @testDoubleInvolution // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @testDoubleInvolution(%arg0: tensor) -> tensor { %0 = "oneflow.negative"(%arg0) {device_tag = "cuda", op_name = "Relu_1", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> 
tensor %1 = "oneflow.negative"(%0) {device_tag = "cuda", op_name = "Relu_2", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: return [[ARG0]] return %1: tensor } // CHECK-LABEL: func.func @testTripleInvolution // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @testTripleInvolution(%arg0: tensor) -> tensor { // CHECK: [[INVOLUTION:%.+]] = "oneflow.negative"([[ARG0]]) %0 = "oneflow.negative"(%arg0) {device_tag = "cuda", op_name = "Relu_1", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor %1 = "oneflow.negative"(%0) {device_tag = "cuda", op_name = "Relu_2", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor %2 = "oneflow.negative"(%1) {device_tag = "cuda", op_name = "Relu_3", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: return [[INVOLUTION]] return %2: tensor } // CHECK-LABEL: func.func @testFailedInvolutionFoldDueToDifferentPlacement // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @testFailedInvolutionFoldDueToDifferentPlacement(%arg0: tensor) -> tensor { %0 = "oneflow.negative"(%arg0) {device_tag = "cuda", op_name = "Relu_1", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor %1 = "oneflow.negative"(%0) {device_tag = "cuda", op_name = "Relu_2", op_type_name = "relu", device_name = ["1:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: [[INVOLUTION:%.+]] = "oneflow.negative"(%1) %2 = "oneflow.negative"(%1) {device_tag = "cuda", op_name = "Relu_3", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: return [[INVOLUTION]] return %2: tensor } // CHECK-LABEL: func.func @testFailedInvolutionFoldDueToDifferentDevice // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @testFailedInvolutionFoldDueToDifferentDevice(%arg0: tensor) -> tensor { %0 = "oneflow.negative"(%arg0) {device_tag = "cuda", op_name = "Relu_1", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor %1 = "oneflow.negative"(%0) {device_tag = "cpu", op_name = "Relu_2", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: [[INVOLUTION:%.+]] = "oneflow.negative"(%1) %2 = "oneflow.negative"(%1) {device_tag = "cuda", op_name = "Relu_3", op_type_name = "relu", device_name = ["0:0-0"], scope_symbol_id = 4611686018427420670 : i64} : (tensor) -> tensor // CHECK: return [[INVOLUTION]] return %2: tensor } ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/lit.local.cfg ================================================ if not config.BUILD_CUDA: config.unsupported = True ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_conv_bn_auto_nhwc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK: oneflow.transpose import os import unittest import numpy as np import oneflow as flow import oneflow.unittest import oneflow.nn as nn from flowvision.models.resnet import resnet50 def _test_fuse_conv_bn(test_case, with_cuda): data = flow.randn(1, 3, 224, 224) if with_cuda: data = data.to("cuda") model = resnet50(pretrained=False, progress=True) if with_cuda: model.to("cuda") model.eval() eager_res = model(data) class Resnet50Graph(nn.Graph): def __init__(self): super().__init__() self.model = model def build(self, *input): return self.model(*input) graph = Resnet50Graph() lazy_res = graph(data) test_case.assertTrue( np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-2, atol=1e-2) ) @flow.unittest.skip_unless_1n1d() class TestFuseConvBn(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_FUSE_NORMALIZATION_OPS"] = "1" os.environ["ONEFLOW_MLIR_PRINT_STATS"] = "1" @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "needs -DBUILD_CUDA=ON") def test_fuse_conv_bn_cuda(test_case): _test_fuse_conv_bn(test_case, True) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_fuse_bias_add_dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.bias_add import unittest import numpy as np import os import oneflow as flow import oneflow.unittest import oneflow.sysconfig def do_bias_add_dropout_graph(test_case, with_cuda, prob): x = flow.randn(2, 3, 4, 5) bias = flow.randn(5) dropout = flow.nn.Dropout(p=prob) if with_cuda: x = x.cuda() bias = bias.to("cuda") dropout.to("cuda") eager_res = dropout(flow._C.bias_add(x, bias, axis=3)) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.dropout = dropout def build(self, x, bias): return self.dropout(flow._C.bias_add(x, bias, axis=3)) graph_to_run = GraphToRun() lazy_res = graph_to_run(x, bias) if prob == 1.0: test_case.assertTrue(np.array_equal(eager_res.numpy(), lazy_res.numpy())) else: test_case.assertTrue(lazy_res.sum().item() != 0.0) @flow.unittest.skip_unless_1n1d() @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "needs -DBUILD_CUDA=ON") class TestBiasAddDropout(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_STDOUT"] = "1" def test_bias_add_dropout_graph(test_case): do_bias_add_dropout_graph(test_case, True, 1.0) do_bias_add_dropout_graph(test_case, True, 0.5) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_fuse_bias_add_gelu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.bias_add import unittest import numpy as np import os import oneflow as flow import oneflow.unittest import oneflow.sysconfig def do_bias_add_gelu_graph(test_case, with_cuda): x = flow.randn(2, 3, 4, 5) bias = flow.randn(5) gelu = flow.nn.GELU() if with_cuda: x = x.cuda() bias = bias.to("cuda") gelu.to("cuda") eager_res = gelu(flow._C.bias_add(x, bias, axis=3)) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.gelu = gelu def build(self, x, bias): return self.gelu(flow._C.bias_add(x, bias, axis=3)) graph_to_run = GraphToRun() lazy_res = graph_to_run(x, bias) test_case.assertTrue(np.array_equal(eager_res.numpy(), lazy_res.numpy())) @flow.unittest.skip_unless_1n1d() @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "needs -DBUILD_CUDA=ON") class TestBiasAddGelu(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_STDOUT"] = "1" def test_bias_add_gelu_graph(test_case): do_bias_add_gelu_graph(test_case, True) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_fuse_bn_add_relu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: "oneflow.normalization" import unittest import numpy as np import os import oneflow as flow import oneflow.unittest import oneflow.sysconfig def do_normalization_add_relu_graph(test_case, with_cuda): def get_bn(fused=True): if fused: return flow.nn.FusedBatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to( "cuda" ) else: return flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to( "cuda" ) class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() self.m = get_bn() def build(self, x, addend): return self.m(x, addend=addend) class GraphToRunWithOpt(flow.nn.Graph): def __init__(self): super().__init__() self.m = get_bn(fused=False) def build(self, x, addend): return flow.relu(self.m(x) + addend) graph_to_run = GraphToRun() graph_to_run_opt = GraphToRunWithOpt() x = flow.Tensor(np.random.randn(4, 2, 8, 3)).to("cuda") addend = flow.Tensor(np.random.randn(4, 2, 8, 3)).to("cuda") eager_res = flow.relu(get_bn(fused=False)(x) + addend) eager_res_fuse = get_bn()(x, addend=addend) lazy_res = graph_to_run(x, addend) lazy_res_opt = graph_to_run_opt(x, addend) test_case.assertTrue(np.array_equal(eager_res.numpy(), eager_res_fuse.numpy())) test_case.assertTrue(np.array_equal(eager_res.numpy(), lazy_res.numpy())) test_case.assertTrue(np.array_equal(eager_res.numpy(), lazy_res_opt.numpy())) @flow.unittest.skip_unless_1n1d() @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "needs -DBUILD_CUDA=ON") class TestNormalizationAddRelu(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_NORMALIZATION_OPS"] = "1" os.environ["ONEFLOW_MLIR_PRINT_STATS"] = "1" def test_normalization_add_relu_graph(test_case): do_normalization_add_relu_graph(test_case, True) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_fuse_gelu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.broadcast_matmul # CHECK-NOT: oneflow.fused_matmul_bias # CHECK-NOT: oneflow.narrow # CHECK: "oneflow.fused_glu" import unittest import numpy as np import os import oneflow as flow import oneflow.nn as nn import oneflow.nn.functional as F import oneflow.unittest import oneflow.sysconfig class GEGLU(nn.Module): r""" A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. 
""" def __init__( self, dim_in: int, dim_out: int, ): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * F.gelu(gate) class GraphToRun(flow.nn.Graph): def __init__(self, gelu_mod): super().__init__() self.gelu_mod = gelu_mod def build(self, hidden_states): return self.gelu_mod(hidden_states) def do_fused_gelu_graph(test_case, dev, fuse_linear=False): if fuse_linear: os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" else: os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "0" gelu_mod = GEGLU(640, 5120).to(dev) hidden_states = flow.randn(2, 2304, 640).to(dev) eager_res = gelu_mod(hidden_states) graph_to_run = GraphToRun(gelu_mod) lazy_res = graph_to_run(hidden_states) test_case.assertTrue(np.allclose(eager_res.numpy(), lazy_res.numpy())) @flow.unittest.skip_unless_1n1d() @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "needs -DBUILD_CUDA=ON") class TestFusedGelu(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_STDOUT"] = "1" os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" def test_fused_gelu_graph(test_case): do_fused_gelu_graph(test_case, "cuda", fuse_linear=True) do_fused_gelu_graph(test_case, "cuda", fuse_linear=False) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_fuse_scale_tril.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.tril import os import unittest import numpy as np import oneflow as flow from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict def _test_fused_scale_tril( test_case, shape, diagonal=0, scale=1.0, ): x = np.random.rand(*shape) # Different dtype will result in insert of cast op causing pass to fail. 
tensor_x = flow.tensor(x, device="cuda", dtype=flow.float32) eager_out = flow.tril(tensor_x, diagonal) * scale class TestFuseScaleTril(flow.nn.Graph): def __init__(self): super().__init__() def build(self): return flow.tril(tensor_x * scale, diagonal) lazy_out_0 = TestFuseScaleTril()() test_case.assertTrue(np.allclose(eager_out.numpy(), lazy_out_0.numpy())) class TestFuseTrilScale(flow.nn.Graph): def __init__(self): super().__init__() def build(self): return flow.tril(tensor_x, diagonal) * scale lazy_out_1 = TestFuseTrilScale()() test_case.assertTrue(np.allclose(eager_out.numpy(), lazy_out_1.numpy())) @flow.unittest.skip_unless_1n1d() class FusedScaleTrilTestCase(flow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_STDOUT"] = "1" def test_fused_scale_tril(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(5, 5), (4, 6)] arg_dict["diagonal"] = [-1, 0] arg_dict["scale"] = [-2.3, 2.0] for kwargs in GenArgDict(arg_dict): _test_fused_scale_tril(test_case, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_fused_matmul_bias.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.bias_add # CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:5 = "oneflow.grouped_matmul_bias" import unittest import numpy as np import os import oneflow as flow import oneflow.unittest import oneflow.sysconfig def _matmul_bias0(x, weight, bias): return flow._C.bias_add( flow._C.matmul(x, weight, transpose_b=True), bias, axis=len(x.shape) - 1 ) def _matmul_bias1(x, w, bias): return flow._C.fused_matmul_bias(x, w, bias) def do_fused_matmul_bias_graph(test_case, dev): x = np.random.uniform(low=-1, high=1, size=(8, 9)) w = np.random.uniform(low=-1, high=1, size=(10, 9)) bias = np.random.uniform(low=-1, high=1, size=(10)) x = flow.from_numpy(x).to(dev).to(flow.float32) w = flow.from_numpy(w).to(dev).to(flow.float32) bias = flow.from_numpy(bias).to(dev).to(flow.float32) eager_res = _matmul_bias0(x, w, bias) * 5 class GraphToRun(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x, w, bias): return ( _matmul_bias0(x, w, bias) + _matmul_bias1(x, w, bias) + _matmul_bias0(x, w, bias) + _matmul_bias1(x, w, bias) + _matmul_bias0(x, w, bias) ) graph_to_run = GraphToRun() lazy_res = graph_to_run(x, w, bias) test_case.assertTrue(np.allclose(eager_res.numpy(), lazy_res.numpy())) @flow.unittest.skip_unless_1n1d() @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "needs -DBUILD_CUDA=ON") class TestGroupMatMulBias(oneflow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" os.environ["ONEFLOW_MLIR_STDOUT"] = "1" os.environ["ONEFLOW_MLIR_CSE"] = "0" def test_fused_matmul_bias_graph(test_case): do_fused_matmul_bias_graph(test_case, "cuda") if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_fused_multi_head_attention_inference.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s # CHECK-NOT: oneflow.softmax # CHECK-NOT: oneflow.batch_matmul import unittest import numpy as np import math import os import oneflow as flow import oneflow.unittest import oneflow.sysconfig def _ref(query, key, value, num_heads, causal=False): query = query.view(query.shape[0], query.shape[1], num_heads, -1).permute( 0, 2, 1, 3 ) key = key.view(key.shape[0], key.shape[1], num_heads, -1).permute(0, 2, 3, 1) value = value.view(value.shape[0], value.shape[1], num_heads, -1).permute( 0, 2, 1, 3 ) scores = flow.matmul(query, key) / math.sqrt(query.shape[-1]) if causal: causal_mask = flow.triu( flow.ones( scores.shape[-2], scores.shape[-1], dtype=flow.bool, device="cuda" ), 1, ) scores = flow.masked_fill(scores, causal_mask, float("-inf")) attn = flow.softmax(scores, dim=-1) out = flow.matmul(attn, value) out = out.permute(0, 2, 1, 3) out = out.reshape(out.shape[0], out.shape[1], -1) return out def _ref2(query, key, value, num_heads, causal=False): query = query.view(query.shape[0], query.shape[1], num_heads, -1).permute( 0, 2, 1, 3 ) key = key.view(key.shape[0], key.shape[1], num_heads, -1).permute(0, 2, 1, 3) value = value.view(value.shape[0], value.shape[1], num_heads, -1).permute( 0, 2, 1, 3 ) query = query.reshape(-1, query.shape[2], query.shape[3]) key = key.reshape(-1, key.shape[2], key.shape[3]).permute(0, 2, 1) value = value.reshape(-1, value.shape[2], value.shape[3]) scale = 1 / math.sqrt(query.shape[-1]) scores = flow.baddbmm( flow.empty( query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device, ), query, key, beta=0, alpha=scale, ) if causal: causal_mask = flow.triu( flow.ones( scores.shape[-2], scores.shape[-1], dtype=flow.bool, device="cuda" ), 1, ) scores = flow.masked_fill(scores, causal_mask, float("-inf")) attn = flow.softmax(scores, dim=-1) out = flow.matmul(attn, value) out = out.reshape(-1, num_heads, out.shape[1], out.shape[2]) out = out.permute(0, 2, 1, 3) out = out.reshape(out.shape[0], out.shape[1], -1) return out def _fused_mha(query, key, value, num_heads, causal=False): return flow._C.fused_multi_head_attention_inference( query, key, value, num_heads, causal=causal ) class GraphToRun(flow.nn.Graph): def __init__(self, ref=None, num_heads=None, causal=False): super().__init__() self.ref = ref self.causal = causal self.num_heads = num_heads def build(self, query, key, value): return self.ref(query, key, value, self.num_heads, self.causal) def _test_fused_multi_head_attention_inference( test_case, batch_size, num_heads, query_seq_len, kv_seq_len, query_head_size, value_head_size, dtype, graph_builder, ref, causal=False, ): query = flow.randn( (batch_size, query_seq_len, num_heads * query_head_size), device="cuda", dtype=flow.float, ).to(dtype) key = flow.randn( (batch_size, kv_seq_len, num_heads * query_head_size), device="cuda", dtype=flow.float, ).to(dtype) value = flow.randn( (batch_size, kv_seq_len, num_heads * value_head_size), device="cuda", dtype=flow.float, ).to(dtype) g = graph_builder(ref=ref, num_heads=num_heads, causal=causal) ref_out = ref(query, key, value, num_heads, causal).numpy() fused_out = _fused_mha(query, key, value, num_heads, causal).numpy() g_out = g(query, key, value).numpy() test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) test_case.assertTrue(np.allclose(ref_out, g_out, atol=1e-2, rtol=1e-2)) @flow.unittest.skip_unless_1n1d() @unittest.skipUnless(oneflow.sysconfig.with_cuda(), "needs 
-DBUILD_CUDA=ON") # TODO: skip for GTX1080 in CI @unittest.skipUnless( flow.cuda.get_device_capability()[0] >= 7, "needs CUDA compatibility >= 7" ) class TestFusedMultiHeadAttentionInference(flow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_STDOUT"] = "1" os.environ["ONEFLOW_MLIR_CSE"] = "0" def test_multi_head_attention_inference(test_case): # test_case,batch_size, num_heads,query_seq_len, kv_seq_len,query_head_size,value_head_size,dtype for ref in [_ref, _ref2]: _test_fused_multi_head_attention_inference( test_case, 2, 8, 4096, 4096, 40, 40, flow.float16, GraphToRun, ref ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 4096, 77, 40, 40, flow.float16, GraphToRun, ref ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 1024, 1024, 80, 80, flow.float16, GraphToRun, ref ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 1024, 77, 80, 80, flow.float16, GraphToRun, ref ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 256, 256, 160, 160, flow.float16, GraphToRun, ref ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 256, 77, 160, 160, flow.float16, GraphToRun, ref ) if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/OneFlow/with_cuda/test_graph_save_and_load.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s import os import sys sys.path.append(os.path.abspath(os.path.dirname(__file__))) sys.path.append(os.path.abspath(os.path.dirname(__file__)) + "/..") import unittest import oneflow as flow import oneflow.unittest from oneflow.core.job import job_pb2 as job_pb from networks.resnet50 import resnet50 class InferGraph(flow.nn.Graph): def __init__(self, placement_arg=None): super().__init__() model = resnet50() if placement_arg is not None: if "placement" in placement_arg: model.to_global(**placement_arg) else: model.to(**placement_arg) self.model = model def build(self, image): logits = self.model(image.to("cuda")) pred = logits.softmax() return pred @unittest.skipIf(not flow.sysconfig.with_mlir(), "only test with mlir") @flow.unittest.skip_unless_1n1d() class GraphSaveTestCase(flow.unittest.MLIRTestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" def test_save_and_load(self): placement_arg = { "placement": flow.placement("cuda", ranks=[0]), "sbp": flow.sbp.broadcast, } graph = InferGraph(placement_arg) image_placeholder = flow.empty( (1, 3, 224, 224), dtype=flow.float32, placement=flow.placement("cpu", ranks=[0]), sbp=flow.sbp.broadcast, ) graph._compile(image_placeholder) saved_path = os.path.join("saved_model", graph.name) if not os.path.exists(saved_path): os.makedirs(saved_path) flow.save(graph, saved_path) saved_ir_path = os.path.join(saved_path, "model.mlir") serialized_job = oneflow._oneflow_internal.nn.graph.LoadSerializedJobFromIR( saved_ir_path ) job = job_pb.Job() job.ParseFromString(serialized_job) # TODO: run loaded job as graph and original graph, compare the result if __name__ == "__main__": unittest.main() ================================================ FILE: oneflow/ir/test/Transform/lit.local.cfg ================================================ if not config.WITH_MLIR_CUDA_CODEGEN: config.unsupported = True ================================================ FILE: oneflow/ir/test/Transform/matmul.mlir ================================================ // RUN: oneflow-opt %s --insert-ofmempool --convert-linalg-to-loops --convert-scf-to-cf --canonicalize --cse --memref-expand --gpu-kernel-outlining \ // RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(expand-strided-metadata,lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))' module { func.func @JITOpGenerated0(%arg0: memref<5x10xf32, strided<[?, ?], offset: ?>>, %arg1: memref<2x5xf32, strided<[?, ?], offset: ?>>, %arg2: memref<2x10xf32>) attributes {llvm.emit_c_interface} { %alloc = memref.alloc() : memref<512xi8> %c0 = arith.constant 0 : index %view = memref.view %alloc[%c0][] : memref<512xi8> to memref<1x2x10xf32> %c10 = arith.constant 10 : index %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index %c0_0 = arith.constant 0 : index %c5 = arith.constant 5 : index %cst = arith.constant 0.000000e+00 : f32 %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] : memref<5x10xf32, strided<[?, ?], offset: ?>> into memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>> %expand_shape_1 = memref.expand_shape %arg1 [[0, 1], [2]] : memref<2x5xf32, strided<[?, ?], offset: ?>> into memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>> gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c2, %arg11 = %c10) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { memref.store %cst, %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32> gpu.terminator } {SCFToGPU_visited} gpu.launch 
blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c2, %arg11 = %c10) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { scf.for %arg15 = %c0_0 to %c5 step %c1 { %0 = memref.load %expand_shape_1[%c0_0, %arg4, %arg15] : memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>> %1 = memref.load %expand_shape[%c0_0, %arg15, %arg5] : memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>> %2 = memref.load %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32> %3 = arith.mulf %0, %1 : f32 %4 = arith.addf %2, %3 : f32 memref.store %4, %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32> } gpu.terminator } {SCFToGPU_visited} %collapse_shape = memref.collapse_shape %view [[0, 1], [2]] : memref<1x2x10xf32> into memref<2x10xf32> memref.copy %collapse_shape, %arg2 : memref<2x10xf32> to memref<2x10xf32> return } } ================================================ FILE: oneflow/ir/test/Transform/softmax.mlir ================================================ // RUN: oneflow-opt %s --pass-pipeline="builtin.module(oneflow-transform-dialect-interpreter{transform-file-name=%p/softmax_codegen_spec_no_vectorize.mlir})" \ // RUN: | oneflow-opt --insert-ofmempool --convert-linalg-to-loops --convert-scf-to-cf --canonicalize --cse --memref-expand --gpu-kernel-outlining \ // RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(expand-strided-metadata,lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))' !tmp_tensor_t = tensor<16x128xf32> !in_tensor_t = tensor<16x128x128xf32> !out_tensor_t = tensor<16x128x128xf32> func.func @softmax() -> !out_tensor_t { %cst_0 = arith.constant 0.0 : f32 %cst_1 = arith.constant 1.0 : f32 %cst_min = arith.constant -3.40282347E+38 : f32 %input = arith.constant dense<5.000000e+00> : !out_tensor_t %input_max_empty = tensor.empty() : !tmp_tensor_t %input_max_filled = linalg.fill ins(%cst_min : f32) outs(%input_max_empty : !tmp_tensor_t) -> !tmp_tensor_t %input_max = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%input : !in_tensor_t) outs(%input_max_filled : !tmp_tensor_t) { ^bb0(%arg0: f32, %arg1: f32): %max = arith.maxf %arg0, %arg1 : f32 linalg.yield %max : f32 } -> !tmp_tensor_t // This has been fused manually to avoid the fusion on tensors pass and reduce noise atm. 
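// Descriptive note (added): the first linalg.generic below computes exp(input - input_max)
// and, in the same reduction loop, accumulates the per-row sum of those exponentials,
// yielding the exponentials into %exps and the running sums into %exps_sum in one pass.
// The final linalg.generic then divides %exps by %exps_sum to produce the normalized
// softmax output returned from this function.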
%exps_empty = tensor.empty() : !out_tensor_t %exps_sum_empty = tensor.empty() : !tmp_tensor_t %exps_sum_filled = linalg.fill ins(%cst_0 : f32) outs(%exps_sum_empty : !tmp_tensor_t) -> !tmp_tensor_t %exps, %exps_sum = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%input, %input_max : !in_tensor_t, !tmp_tensor_t) outs(%exps_empty, %exps_sum_filled : !out_tensor_t, !tmp_tensor_t) { ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32): %sub = arith.subf %arg0, %arg1 : f32 %exp = math.exp %sub : f32 %add = arith.addf %exp, %arg3 : f32 linalg.yield %exp, %add : f32, f32 } -> (!out_tensor_t, !tmp_tensor_t) %res_empty = tensor.empty() : !out_tensor_t %res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%exps, %exps_sum : !out_tensor_t, !tmp_tensor_t) outs(%res_empty : !out_tensor_t) { ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // %10 = arith.divf %cst_1, %arg1 : f32 // %11 = arith.mulf %arg0, %10 : f32 %div = arith.divf %arg0, %arg1 : f32 linalg.yield %div : f32 } -> !out_tensor_t return %res: !out_tensor_t } ================================================ FILE: oneflow/ir/test/Transform/softmax_codegen_spec.mlir ================================================ // RUN: oneflow-opt %s transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): // Note: step 1, tiling and fusing linalg ops in block level. %ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]} in %module_op : (!pdl.operation) -> !pdl.operation %match_0, %match_1, %match_2, %match_3, %match_end = transform.split_handle %ops : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %forall, %_ = transform.structured.tile_to_forall_op %match_end tile_sizes [1, 4] ( mapping = [#gpu.block, #gpu.block] ) transform.structured.fuse_into_containing_op %match_3 into %forall transform.structured.fuse_into_containing_op %match_2 into %forall transform.structured.fuse_into_containing_op %match_1 into %forall transform.structured.fuse_into_containing_op %match_0 into %forall transform.oneflow.canonicalization %module_op : (!pdl.operation) -> () transform.oneflow.cse %module_op : (!pdl.operation) -> () // Note: step 2, tiling and fusing linalg ops in thread level. 
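// Descriptive note (added): the earlier handles are consumed by the block-level tiling and
// fusion above, so the fill/generic ops are matched again here. The two reduction generics
// (row max and fused exp/exp-sum) are tiled to an scf.forall with tile_sizes [1, 1] on gpu
// threads, while the three fully parallel ops are distributed over num_threads [1, 4, 32],
// consistent with the block_dims = [32, 4, 1] used when mapping nested foralls to threads
// in the final step.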
%ops_1 = transform.structured.match ops{["linalg.fill", "linalg.generic"]} in %module_op : (!pdl.operation) -> !pdl.operation %match_0_0, %match_0_1, %match_0_2, %match_0_3, %match_0_end = transform.split_handle %ops_1 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %reduction_linalg_ops = transform.merge_handles %match_0_1, %match_0_3 : !pdl.operation transform.structured.tile_to_forall_op %reduction_linalg_ops tile_sizes [1, 1] ( mapping = [#gpu.thread, #gpu.thread] ) %parallel_linalg_ops = transform.merge_handles %match_0_0, %match_0_2, %match_0_end : !pdl.operation transform.structured.tile_to_forall_op %parallel_linalg_ops num_threads [1, 4, 32] ( mapping = [#gpu.thread, #gpu.thread, #gpu.thread] ) // Note: step 3,vectorize transform.oneflow.canonicalization %module_op : (!pdl.operation) -> () transform.oneflow.cse %module_op : (!pdl.operation) -> () %to_vectorize = transform.structured.match ops{["func.func"]} in %module_op : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %to_vectorize // Note: step 4, bufferize transform.oneflow.explicit_linalg_outcome %module_op : (!pdl.operation) -> () transform.bufferization.eliminate_empty_tensors %module_op %empty = transform.structured.match ops{["tensor.empty"]} in %module_op : (!pdl.operation) -> !pdl.operation %empty_id = transform.cast %empty : !pdl.operation to !transform.op<"tensor.empty"> transform.bufferization.empty_tensor_to_alloc_tensor %empty_id : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor"> %bufferized_module_op = transform.bufferization.one_shot_bufferize %module_op {create_deallocs = false, bufferize_function_boundaries = true, allow_return_allocs = true} : (!pdl.operation) -> !pdl.operation // Note: step 5, post bufferize function-type-related transform transform.oneflow.canonicalization %bufferized_module_op : (!pdl.operation) -> () transform.oneflow.cse %bufferized_module_op : (!pdl.operation) -> () transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> () %func = transform.structured.match ops{["func.func"]} in %bufferized_module_op : (!pdl.operation) -> !pdl.operation transform.structured.hoist_redundant_tensor_subsets %func : (!pdl.operation) -> () // Note: step 6, post bufferize memory-buffer-pool transform transform.oneflow.results_to_out_params %bufferized_module_op : (!pdl.operation) -> () transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> () transform.oneflow.fold_alloc %func : (!pdl.operation) -> () // Note: step 7, mapping scf to gpu %gpu_launch_op = transform.gpu.map_forall_to_blocks %bufferized_module_op { generate_gpu_launch } transform.gpu.map_nested_forall_to_threads %gpu_launch_op block_dims = [32, 4, 1] } ================================================ FILE: oneflow/ir/test/Transform/softmax_codegen_spec_no_vectorize.mlir ================================================ // RUN: oneflow-opt %s transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): // Note: step 1, tiling and fusing linalg ops in block level. 
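// Descriptive note (added): the five matched ops are the two linalg.fill initializers, the
// max reduction, the fused exp/exp-sum reduction, and the final normalization generic.
// The normalization generic is tiled to an scf.forall over gpu blocks with tile_sizes [1, 4],
// and its producers are then fused into that forall, starting from the op closest to the
// consumer and working back toward the first fill.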
%ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]} in %module_op : (!pdl.operation) -> !pdl.operation %match_0, %match_1, %match_2, %match_3, %match_end = transform.split_handle %ops : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %forall, %_ = transform.structured.tile_to_forall_op %match_end tile_sizes [1, 4] ( mapping = [#gpu.block, #gpu.block] ) transform.structured.fuse_into_containing_op %match_3 into %forall transform.structured.fuse_into_containing_op %match_2 into %forall transform.structured.fuse_into_containing_op %match_1 into %forall transform.structured.fuse_into_containing_op %match_0 into %forall transform.oneflow.canonicalization %module_op : (!pdl.operation) -> () transform.oneflow.cse %module_op : (!pdl.operation) -> () // Note: step 2, tiling and fusing linalg ops in thread level. %ops_1 = transform.structured.match ops{["linalg.fill", "linalg.generic"]} in %module_op : (!pdl.operation) -> !pdl.operation %match_0_0, %match_0_1, %match_0_2, %match_0_3, %match_0_end = transform.split_handle %ops_1 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %reduction_linalg_ops = transform.merge_handles %match_0_1, %match_0_3 : !pdl.operation transform.structured.tile_to_forall_op %reduction_linalg_ops tile_sizes [1, 1] ( mapping = [#gpu.thread, #gpu.thread] ) %parallel_linalg_ops = transform.merge_handles %match_0_0, %match_0_2, %match_0_end : !pdl.operation transform.structured.tile_to_forall_op %parallel_linalg_ops num_threads [1, 4, 32] ( mapping = [#gpu.thread, #gpu.thread, #gpu.thread] ) transform.oneflow.canonicalization %module_op : (!pdl.operation) -> () transform.oneflow.cse %module_op : (!pdl.operation) -> () // Note: step 3, bufferize transform.oneflow.explicit_linalg_outcome %module_op : (!pdl.operation) -> () transform.bufferization.eliminate_empty_tensors %module_op %empty = transform.structured.match ops{["tensor.empty"]} in %module_op : (!pdl.operation) -> !pdl.operation %empty_id = transform.cast %empty : !pdl.operation to !transform.op<"tensor.empty"> transform.bufferization.empty_tensor_to_alloc_tensor %empty_id : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor"> %bufferized_module_op = transform.bufferization.one_shot_bufferize %module_op {create_deallocs = false, bufferize_function_boundaries = true, allow_return_allocs = true} : (!pdl.operation) -> !pdl.operation // Note: step 4, post bufferize function-type-related transform transform.oneflow.canonicalization %bufferized_module_op : (!pdl.operation) -> () transform.oneflow.cse %bufferized_module_op : (!pdl.operation) -> () transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> () %func = transform.structured.match ops{["func.func"]} in %bufferized_module_op : (!pdl.operation) -> !pdl.operation transform.structured.hoist_redundant_tensor_subsets %func : (!pdl.operation) -> () // Note: step 5, post bufferize memory-buffer-pool transform transform.oneflow.results_to_out_params %bufferized_module_op : (!pdl.operation) -> () transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> () transform.oneflow.fold_alloc %func : (!pdl.operation) -> () // Note: step 6, mapping scf to gpu %gpu_launch_op = transform.gpu.map_forall_to_blocks %bufferized_module_op { generate_gpu_launch } transform.gpu.map_nested_forall_to_threads %gpu_launch_op block_dims = [32, 4, 1] } ================================================ FILE: 
oneflow/ir/test/Transform/test_dialect.mlir ================================================ // RUN: oneflow-opt --oneflow-transform-dialect-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s // Test One-Shot Bufferize. transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation %1 = transform.bufferization.one_shot_bufferize %0 : (!pdl.operation) -> !pdl.operation } // CHECK-LABEL: func @test_function( // CHECK-SAME: %[[A:.*]]: tensor func.func @test_function(%A : tensor, %v : vector<4xf32>) -> (tensor) { %c0 = arith.constant 0 : index // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]] // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) // CHECK: memref.copy %[[A_memref]], %[[alloc]] // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]] %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor // CHECK: memref.dealloc %[[alloc]] // CHECK: return %[[res_tensor]] return %0 : tensor } ================================================ FILE: oneflow/ir/test/lit.cfg.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # -*- Python -*- import os import platform import re import subprocess import tempfile import lit.formats import lit.util from lit.llvm import llvm_config from lit.llvm.subst import ToolSubst from lit.llvm.subst import FindTool # Configuration file for the 'lit' test runner. # name: The name of this test suite. config.name = "ONEFLOW" config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. config.suffixes = [".mlir", ".py"] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.join(config.oneflow_obj_root, "test") config.substitutions.append(("%PATH%", config.environment["PATH"])) config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. config.excludes = [ "Inputs", "Examples", "CMakeLists.txt", "README.txt", "LICENSE.txt", "networks", "test_fuse_cast_scale.mlir.py", "test_util.py", "test_mlir_opt.mlir.py", "lit.cfg.py", "saved_model", ] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. 
config.test_exec_root = os.path.join(config.oneflow_obj_root, "test") config.oneflow_tools_dir = os.path.join(config.oneflow_ir_obj_root, "bin") # Tweak the PATH to include the tools dir. llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) # TODO: these two should be unnecessary llvm_config.with_environment( "LD_LIBRARY_PATH", os.path.join(config.oneflow_obj_root, "third_party_install/protobuf/lib"), append_path=True, ) llvm_config.with_environment( "LD_LIBRARY_PATH", os.path.join(config.oneflow_obj_root, "_deps/glog-build"), append_path=True, ) llvm_config.with_environment("ONEFLOW_MLIR_STDOUT", "1") llvm_config.with_environment("ONEFLOW_MLIR_ENABLE_ROUND_TRIP", "1") llvm_config.with_environment("ONEFLOW_MLIR_CSE", "1") llvm_config.with_environment("ONEFLOW_MLIR_FUSE_FORWARD_OPS", "1") llvm_config.with_environment( "PYTHONPATH", os.path.join(config.oneflow_src_root, "python"), append_path=True, ) # Searches for a runtime library with the given name and returns a tool # substitution of the same name and the found path. # Correctly handles the platforms shared library directory and naming conventions. def add_runtime(name): path = "" for prefix in ["", "lib"]: path = os.path.join( config.llvm_shlib_dir, f"{prefix}{name}{config.llvm_shlib_ext}" ) if os.path.isfile(path): break return ToolSubst(f"%{name}", path) tool_dirs = [config.oneflow_tools_dir, config.llvm_tools_dir] tools = [ "oneflow-opt", "oneflow-translate", "oneflow-runner", add_runtime("mlir_runner_utils"), ] if config.WITH_MLIR_CUDA_CODEGEN: tools.extend([add_runtime("mlir_cuda_runtime")]) tools.extend( [ ToolSubst("%with_cuda", config.BUILD_CUDA, unresolved="ignore"), ToolSubst("%linalg_test_lib_dir", config.llvm_lib_dir, unresolved="ignore"), ToolSubst("%test_exec_root", config.test_exec_root, unresolved="ignore"), ] ) llvm_config.add_tool_substitutions(tools, tool_dirs) try: from iree import runtime as ireert from iree.compiler import compile_str config.WITH_ONEFLOW_IREE = True except ImportError: config.WITH_ONEFLOW_IREE = False ================================================ FILE: oneflow/ir/test/lit.site.cfg.py.in ================================================ @LIT_SITE_CFG_IN_HEADER@ import sys config.host_triple = "@LLVM_HOST_TRIPLE@" config.target_triple = "@TARGET_TRIPLE@" config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@" config.llvm_shlib_dir = "@LLVM_LIBRARY_DIR@" config.llvm_shlib_ext = "@SHLIBEXT@" config.llvm_exe_ext = "@EXEEXT@" config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.python_executable = "@PYTHON_EXECUTABLE@" config.gold_executable = "@GOLD_EXECUTABLE@" config.ld64_executable = "@LD64_EXECUTABLE@" config.enable_shared = @ENABLE_SHARED@ config.enable_assertions = @ENABLE_ASSERTIONS@ config.targets_to_build = "@TARGETS_TO_BUILD@" config.native_target = "@LLVM_NATIVE_ARCH@" config.llvm_bindings = "@LLVM_BINDINGS@".split(' ') config.host_os = "@HOST_OS@" config.host_cc = "@HOST_CC@" config.host_cxx = "@HOST_CXX@" # Note: ldflags can contain double-quoted paths, so must use single quotes here. 
config.host_ldflags = '@HOST_LDFLAGS@' config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' config.host_arch = "@HOST_ARCH@" config.oneflow_src_root = "@CMAKE_SOURCE_DIR@" config.oneflow_obj_root = "@CMAKE_BINARY_DIR@" config.oneflow_ir_obj_root = "@PROJECT_BINARY_DIR@" config.WITH_MLIR_CUDA_CODEGEN = @WITH_MLIR_CUDA_CODEGEN@ config.BUILD_CUDA = @BUILD_CUDA@ # Support substitution of the tools_dir with user parameters. This is # used when we can't determine the tool dir at configuration time. try: config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params except KeyError: e = sys.exc_info()[1] key, = e.args lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) import lit.llvm lit.llvm.initialize(lit_config, config) # Let the main config do the real work. lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/oneflow/ir/test/lit.cfg.py") ================================================ FILE: oneflow/maybe/config.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_MAYBE_CONFIG_H_ #define ONEFLOW_MAYBE_CONFIG_H_ #include // pre-define it if you use a logging library like glog #ifndef OF_MAYBE_ASSERT #define OF_MAYBE_ASSERT(_cond_) assert(_cond_) #endif // ASSERT_EQ is different from ASSERT in logging / testing framework // pre-define it if you use a logging library like glog #ifndef OF_MAYBE_ASSERT_EQ #define OF_MAYBE_ASSERT_EQ(_lhs_, _rhs_) OF_MAYBE_ASSERT(_lhs_ == _rhs_) #endif #if __GNUC__ >= 7 #define OF_MAYBE_HAS_IS_AGGREGATE // in old versions of clang, __has_builtin(__is_aggregate) returns false #elif __clang__ #if !__is_identifier(__is_aggregate) #define OF_MAYBE_HAS_IS_AGGREGATE #endif #else #if __has_builtin(__is_aggregate) #define OF_MAYBE_HAS_IS_AGGREGATE #endif #endif #ifdef OF_MAYBE_HAS_IS_AGGREGATE #define OF_MAYBE_IS_AGGREGATE(...) (__is_aggregate(__VA_ARGS__)) #else // decay to POD checking if no such builtin (because implementing __is_aggregate need reflection) #define OF_MAYBE_IS_AGGREGATE(...) \ (std::is_standard_layout<__VA_ARGS__>::value && std::is_trivial<__VA_ARGS__>::value) #endif // `__builtin_expect` exists at least since GCC 4 / Clang 3 #define OF_MAYBE_EXPECT_FALSE(x) (__builtin_expect((x), 0)) #if __has_cpp_attribute(nodiscard) #define OF_MAYBE_NODISCARD_FUNC [[nodiscard]] #define OF_MAYBE_NODISCARD_TYPE [[nodiscard]] #elif __has_attribute(warn_unused_result) #define OF_MAYBE_NODISCARD_FUNC \ __attribute__((warn_unused_result)) // or [[gnu::warn_unused_result]] #define OF_MAYBE_NODISCARD_TYPE #endif #endif // ONEFLOW_MAYBE_CONFIG_H_ ================================================ FILE: oneflow/maybe/error.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_MAYBE_ERROR_H_ #define ONEFLOW_MAYBE_ERROR_H_ #include #include #include #include #include #include #include #include #include #include "utility.h" #include "type_traits.h" namespace oneflow { namespace maybe { namespace details { template struct ErrorStackFromContainerBase { private: using Derived = D; auto& Stack() { return static_cast(this)->GetStack(); } const auto& Stack() const { return static_cast(this)->GetStack(); } public: std::size_t StackSize() const { return Stack().size(); } template void PushStack(Args&&... args) { auto& s = Stack(); s.emplace(s.end(), std::forward(args)...); } template const typename T::StackType::value_type& StackElem(std::size_t index) const { return Stack()[index]; } auto StackBegin() const { return Stack().begin(); } auto StackEnd() const { return Stack().end(); } }; } // namespace details template struct StackedErrorTraits { StackedErrorTraits() = delete; using ErrorType = typename T::ErrorType; using StackEntryType = typename T::StackEntryType; template< typename U, std::enable_if_t< std::is_same>::value && std::is_same().Error())>>::value, int> = 0> static decltype(auto) Error(U&& se) { return se.Error(); } static std::size_t StackSize(const T& se) { return se.StackSize(); } static ConstRefExceptVoid StackElem(const T& se, std::size_t index) { return se.StackElem(index); } template>::value, int> = 0> static void PushStack(U&& se, Args&&... args) { se.PushStack(std::forward(args)...); } template>::value, int> = 0> static std::string Dump(U&& se) { return se.Dump(); } template>::value, int> = 0> [[noreturn]] static void Abort(U&& se) { se.Abort(); } }; template struct StackedErrorTraits> { StackedErrorTraits() = delete; using PointedTraits = StackedErrorTraits; using ValueType = std::unique_ptr; using ErrorType = typename PointedTraits::ErrorType; using StackEntryType = typename PointedTraits::StackEntryType; template>::value, int> = 0> static decltype(auto) Error(U&& se) { return PointedTraits::Error(*se); } static std::size_t StackSize(const ValueType& se) { return PointedTraits::StackSize(*se); } static ConstRefExceptVoid StackElem(const T& se, std::size_t index) { return PointedTraits::StackElem(*se, index); } template>::value, int> = 0> static void PushStack(U&& se, Args&&... args) { PointedTraits::PushStack(*se, std::forward(args)...); } template>::value, int> = 0> static std::string Dump(U&& se) { return PointedTraits::Dump(*se); } template>::value, int> = 0> [[noreturn]] static void Abort(U&& se) { PointedTraits::Abort(*se); } }; // simple implementation for some customization points namespace simple { template struct MessageFormatTrait; template<> struct MessageFormatTrait { template static std::string Format(Code&& code, Args&&... 
args) { if (sizeof...(args) > 0) { std::stringstream res; res << code << ": "; ((res << args), ...); return res.str(); } else { return code; } } }; template<> struct MessageFormatTrait { template static std::string_view Format(Code&& code) { return code; } }; template> struct ErrorStackEntry { std::string_view filename; std::size_t lineno; std::string_view function; Message message; template ErrorStackEntry(std::string_view filename, std::size_t lineno, std::string_view function, Args&&... args) : filename(filename), lineno(lineno), function(function), message(MessageFormatTraits::Format(std::forward(args)...)) {} }; template struct StackedError : details::ErrorStackFromContainerBase> { public: using ErrorType = E; using StackMessage = M; using StackEntryType = ErrorStackEntry; using StackType = std::vector; using BaseType = details::ErrorStackFromContainerBase>; static_assert(!std::is_reference::value, "the underlying value type cannot be reference"); StackedError(ErrorType error) // NOLINT(google-explicit-constructor) : error_(std::move(error)) {} ErrorType& Error() { return error_; } const ErrorType& Error() const { return error_; } std::string Dump() { std::stringstream res; res << "error occurred: " << error_ << std::endl; for (const auto& elem : stack_) { res << "from " << elem.function << " in " << elem.filename << ":" << elem.lineno << ": " << elem.message << std::endl; } return res.str(); } [[noreturn]] void Abort() { std::cerr << "error occurred: " << error_ << std::endl; for (const auto& elem : stack_) { std::cerr << "from " << elem.function << " in " << elem.filename << ":" << elem.lineno << ": " << elem.message << std::endl; } std::abort(); } private: ErrorType error_; StackType stack_; StackType& GetStack() { return stack_; } const StackType& GetStack() const { return stack_; } friend BaseType; }; template struct NoStackError { using ErrorType = E; using StackEntryType = void; static_assert(!std::is_reference::value, "the underlying value type cannot be reference"); NoStackError(ErrorType error) // NOLINT(google-explicit-constructor) : error_(std::move(error)) {} ErrorType& Error() { return error_; } const ErrorType& Error() const { return error_; } std::size_t StackSize() const { return 0; } void StackElem(std::size_t) const {} template void PushStack(Args&&... args) {} std::string Dump() { std::stringstream res; res << error_ << std::endl; return res.str(); } [[noreturn]] void Abort() { std::cerr << error_ << std::endl; std::abort(); } private: ErrorType error_; }; } // namespace simple } // namespace maybe } // namespace oneflow #endif // ONEFLOW_MAYBE_ERROR_H_ ================================================ FILE: oneflow/maybe/error_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include #include "oneflow/maybe/error.h" using namespace oneflow::maybe; using namespace oneflow::maybe::simple; using namespace std::string_literals; namespace oneflow { namespace maybe { // test if StackedErrorTraits can be applied to some simple types template struct StackedErrorTraits>; template struct StackedErrorTraits>; } // namespace maybe } // namespace oneflow TEST(StackedError, SimpleStackedError) { StackedError a(std::make_error_code(std::errc::timed_out)); ASSERT_EQ(a.Error(), std::errc::timed_out); ASSERT_EQ(a.StackSize(), 0); const auto& ec = a.Error(); ASSERT_DEATH(a.Abort(), // NOLINT(cppcoreguidelines-avoid-goto) ec.category().name() + ":"s + std::to_string(ec.value())); [&a] { a.PushStack(__FILE__, __LINE__, __PRETTY_FUNCTION__, "hello"); }(); struct SomeType { explicit SomeType(decltype(a)& a) { a.PushStack(__FILE__, __LINE__, __PRETTY_FUNCTION__, "hi"); } } x(a); ASSERT_EQ(a.StackSize(), 2); ASSERT_DEATH(a.Abort(), // NOLINT(cppcoreguidelines-avoid-goto) "(lambda|operator\\(\\)).*hello.*\n.*SomeType::SomeType.*hi"); ASSERT_EQ(a.StackElem(0).message, "hello"); ASSERT_EQ(a.StackElem(1).message, "hi"); } TEST(StackedError, SimpleNoStackError) { NoStackError a(std::make_error_code(std::errc::address_in_use)); ASSERT_EQ(a.Error(), std::errc::address_in_use); ASSERT_EQ(a.StackSize(), 0); const auto& ec = a.Error(); ASSERT_DEATH(a.Abort(), // NOLINT(cppcoreguidelines-avoid-goto) ec.category().name() + ":"s + std::to_string(ec.value())); a.PushStack(__FILE__, __LINE__, __PRETTY_FUNCTION__, "hello"); ASSERT_EQ(a.StackSize(), 0); } ================================================ FILE: oneflow/maybe/just.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_MAYBE_JUST_H_ #define ONEFLOW_MAYBE_JUST_H_ #include #include #include "oneflow/maybe/error.h" #include "oneflow/maybe/type_traits.h" namespace oneflow { namespace maybe { template struct Maybe; template struct IsMaybe : std::false_type {}; template struct IsMaybe> : std::true_type {}; template struct Optional; template struct IsOptional : std::false_type {}; template struct IsOptional> : std::true_type {}; // user should provide which error will be returned while an optional has no value // and is used in JUST or CHECK_JUST; // if not provided, then JUST(_MSG) and CHECK_JUST(_MSG) cannot be used for Optional // i.e. ```c++ // template struct JustConfig> { // static SomeError ValueNotFoundError(auto&&) { ... } // }; // ``` // or some other optional types, i.e. std::shared_ptr ```c++ // template struct JustConfig> { // // define which error will be returned while it is empty // static SomeError ValueNotFoundError(auto&&) { ... } // // define how to get the underlying value // static decltype(auto) Value(auto&&) { ... 
} // }; // ``` template struct JustTraits; namespace details { struct JustPrivateScope { template static decltype(auto) Value(T&& v) { return std::forward(v).Value(); } template>::value, int> = 0> static decltype(auto) StackedError(T&& v) { return std::forward(v).StackedError(); } template>::value, int> = 0> static decltype(auto) StackedError(T&& v) { return JustTraits>::ValueNotFoundError(std::forward(v)); } }; template typename std::remove_const::type>::type&& RemoveRValConst( T&& v) noexcept { static_assert(std::is_rvalue_reference::value, "rvalue is expected here"); return const_cast::type>::type&&>(v); } template decltype(auto) JustPushStackAndReturn(T&& v, Args&&... args) { StackedErrorTraits>::PushStack(std::forward(v), std::forward(args)...); return std::forward(v); } template [[noreturn]] void JustPushStackAndAbort(T&& v, Args&&... args) { using Traits = StackedErrorTraits>; Traits::PushStack(std::forward(v), std::forward(args)...); Traits::Abort(std::forward(v)); } template::value || IsOptional::value, int> = 0> auto JustGetValue(T&& v) -> RemoveRValRef(v)))> { return JustPrivateScope::Value(std::forward(v)); } template::value && !IsOptional::value, int> = 0> auto JustGetValue(T&& v) -> RemoveRValRef>::Value(std::forward(v)))> { return JustTraits>::Value(std::forward(v)); } } // namespace details } // namespace maybe } // namespace oneflow // macros begin #define JUST_STACK_CHECK_I(...) __VA_ARGS__ #define JUST_TO_STR_I(...) #__VA_ARGS__ #if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__) #define JUST(...) \ ::oneflow::maybe::details::JustGetValue(::oneflow::maybe::details::RemoveRValConst(({ \ auto&& _just_value_to_check_ = JUST_STACK_CHECK_I(__VA_ARGS__); \ if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) { \ return ::oneflow::maybe::details::JustPushStackAndReturn( \ ::oneflow::maybe::details::JustPrivateScope::StackedError( \ std::forward(_just_value_to_check_)), \ __FILE__, __LINE__, __PRETTY_FUNCTION__, JUST_TO_STR_I(__VA_ARGS__)); \ } \ std::forward(_just_value_to_check_); \ }))) #define CHECK_JUST(...) \ ::oneflow::maybe::details::JustGetValue([&](const auto& _just_function_name_) { \ auto&& _just_value_to_check_ = JUST_STACK_CHECK_I(__VA_ARGS__); \ if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) { \ ::oneflow::maybe::details::JustPushStackAndAbort( \ ::oneflow::maybe::details::JustPrivateScope::StackedError( \ std::forward(_just_value_to_check_)), \ __FILE__, __LINE__, _just_function_name_, JUST_TO_STR_I(__VA_ARGS__)); \ } \ return std::forward(_just_value_to_check_); \ }(__PRETTY_FUNCTION__)) #define JUST_MSG(_just_expr_, ...) \ ::oneflow::maybe::details::JustGetValue(::oneflow::maybe::details::RemoveRValConst(({ \ auto&& _just_value_to_check_ = (_just_expr_); \ if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) { \ return ::oneflow::maybe::details::JustPushStackAndReturn( \ ::oneflow::maybe::details::JustPrivateScope::StackedError( \ std::forward(_just_value_to_check_)), \ __FILE__, __LINE__, __PRETTY_FUNCTION__, JUST_TO_STR_I(_just_expr_), __VA_ARGS__); \ } \ std::forward(_just_value_to_check_); \ }))) #define CHECK_JUST_MSG(_just_expr_, ...) 
\ ::oneflow::maybe::details::JustGetValue([&](const auto& _just_function_name_) { \ auto&& _just_value_to_check_ = (_just_expr_); \ if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) { \ ::oneflow::maybe::details::JustPushStackAndAbort( \ ::oneflow::maybe::details::JustPrivateScope::StackedError( \ std::forward(_just_value_to_check_)), \ __FILE__, __LINE__, _just_function_name_, JUST_TO_STR_I(_just_expr_), __VA_ARGS__); \ } \ return std::forward(_just_value_to_check_); \ }(__PRETTY_FUNCTION__)) #define OPT_JUST(...) \ ::oneflow::maybe::details::JustGetValue(::oneflow::maybe::details::RemoveRValConst(({ \ auto&& _just_value_to_check_ = JUST_STACK_CHECK_I(__VA_ARGS__); \ if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) { return NullOpt; } \ std::forward(_just_value_to_check_); \ }))) #else #error "statement expression is not supported, please implement try-catch version of JUST" #endif // defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__) // macros end #endif // ONEFLOW_MAYBE_JUST_H_ ================================================ FILE: oneflow/maybe/just_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include "oneflow/maybe/maybe.h" #include "oneflow/maybe/optional.h" using namespace oneflow::maybe; TEST(Just, MaybeBasic) { using Error = simple::StackedError; using MaybeInt = Maybe; auto f = [](int x) -> MaybeInt { if (x > 10 || x < 0) { return Error{"not in range"}; } return x + 10; }; auto g = [&f](int x) -> MaybeInt { if (x == 15) { return Error{"invalid value"}; } return JUST(f(x)) * 2; }; auto h = [&g](int x) -> MaybeInt { return JUST(g(x)) + 2; }; ASSERT_EQ(CHECK_JUST(h(0)), 22); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(h(11)), R"(not in range.*(lambda|operator\(\)).*f\(x\).*(lambda|operator\(\)).*g\(x\).*TestBody.*h\(11\))"); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(h(15)), R"(invalid value.*(lambda|operator\(\)).*g\(x\).*TestBody.*h\(15\))"); ASSERT_EQ(details::JustPrivateScope::StackedError(h(12)).StackSize(), 2); ASSERT_EQ(details::JustPrivateScope::StackedError(h(15)).StackSize(), 1); } TEST(Just, MaybeVoid) { using Error = simple::StackedError; using MaybeVoid = Maybe; auto f = [](int& x) -> MaybeVoid { if (x > 10 || x < 0) { return Error{"not in range"}; } x = x + 5; return Ok; }; auto g = [&f](int& x) -> MaybeVoid { if (x == 15) { return Error{"invalid value"}; } JUST(f(x)); JUST(f(x)); return Ok; }; auto h = [&g](int& x) -> MaybeVoid { JUST(g(x)); x = x + 2; return Ok; }; int x = 0; CHECK_JUST(h(x)); ASSERT_EQ(x, 12); x = 11; ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(h(x)), R"(not in range.*(lambda|operator\(\)).*f\(x\).*(lambda|operator\(\)).*g\(x\).*TestBody.*h\(x\))"); ASSERT_EQ(x, 11); x = 8; ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(h(x)), R"(not in range.*(lambda|operator\(\)).*f\(x\).*(lambda|operator\(\)).*g\(x\).*TestBody.*h\(x\))"); [[maybe_unused]] 
auto _ = h(x); // NOLINT ASSERT_EQ(x, 13); } TEST(Just, MaybeRef) { using Error = simple::StackedError; using MaybeRef = Maybe; int k = 100; auto f = [&k](const int& x) -> MaybeRef { if (x > 10 || x < 0) { return Error{"not in range"}; } if (x < 5) return x; return k; }; auto g = [&f](const int& x) -> MaybeRef { if (x == 2) { return Error{"invalid value"}; } return JUST(f(x)); }; int x = 1; ASSERT_EQ(CHECK_JUST(g(x)), 1); const int& y = CHECK_JUST(g(5)); ASSERT_EQ(y, 100); k = 200; ASSERT_EQ(y, 200); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(g(11)), R"(not in range.*(lambda|operator\(\)).*f\(x\).*TestBody.*g\(11\))"); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(g(2)), R"(invalid value.*TestBody.*g\(2\))"); } TEST(Just, MaybeErrorPtr) { using E = simple::StackedError; using Error = std::unique_ptr; using MaybeInt = Maybe; auto f = [](int x) -> MaybeInt { if (x > 10 || x < 0) { return std::make_unique("not in range"); } return x + 10; }; auto g = [&f](int x) -> MaybeInt { if (x == 15) { return std::make_unique("invalid value"); } return JUST(f(x)) * 2; }; auto h = [&g](int x) -> MaybeInt { return JUST(g(x)) + 2; }; ASSERT_EQ(CHECK_JUST(h(0)), 22); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(h(11)), R"(not in range.*(lambda|operator\(\)).*f\(x\).*(lambda|operator\(\)).*g\(x\).*TestBody.*h\(11\))"); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(h(15)), R"(invalid value.*(lambda|operator\(\)).*g\(x\).*TestBody.*h\(15\))"); ASSERT_EQ(details::JustPrivateScope::StackedError(h(12))->StackSize(), 2); ASSERT_EQ(details::JustPrivateScope::StackedError(h(15))->StackSize(), 1); } namespace oneflow { namespace maybe { template struct JustTraits { template static simple::StackedError ValueNotFoundError(U&&) { return {"not found"}; } template static decltype(auto) Value(U&& v) { return *v; } }; } // namespace maybe } // namespace oneflow TEST(Just, Optional) { using Error = simple::StackedError; using MaybeInt = Maybe; Optional a, b(1), c(2); auto f = [](const Optional& x) -> MaybeInt { if (x == 1) return Error("hello"); return JUST(x) + 1; }; ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(f(a)), R"(not found.*(lambda|operator\(\)).*x.*TestBody.*f\(a\))"); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(f(b)), R"(hello.*TestBody.*f\(b\))"); ASSERT_EQ(CHECK_JUST(f(c)), 3); } TEST(Just, Ptr) { using Error = simple::StackedError; using MaybeInt = Maybe; std::shared_ptr a, b(std::make_shared(1)), c(std::make_shared(2)); auto f = [](const std::shared_ptr& x) -> MaybeInt { if (JUST(x) == 1) return Error("hello"); return JUST(x) + 1; }; ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(f(a)), R"(not found.*(lambda|operator\(\)).*x.*TestBody.*f\(a\))"); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(f(b)), R"(hello.*TestBody.*f\(b\))"); ASSERT_EQ(CHECK_JUST(f(c)), 3); } TEST(Just, WithMsg) { struct UniqueInt { int x; void drop() { x = -333; } explicit UniqueInt(int x) : x{x} {} UniqueInt(const UniqueInt& i) = delete; UniqueInt(UniqueInt&& i) noexcept : x{i.x} { i.drop(); } // NOLINT UniqueInt& operator=(const UniqueInt& i) = delete; UniqueInt& operator=(UniqueInt&& i) noexcept { x = i.x; i.drop(); return *this; } ~UniqueInt() { drop(); } }; using Error = simple::StackedError; using MaybeInt = Maybe; auto f = [](UniqueInt x) -> MaybeInt { if (x.x > 10) { return Error{"input value " + std::to_string(x.x)}; } return UniqueInt{233}; }; auto g = [](UniqueInt x) { int y 
= x.x; return UniqueInt{y * y - 5 * y + 3}; }; auto h = [&](UniqueInt x) -> MaybeInt { int n = x.x; auto y = g(std::move(x)); return JUST_MSG(f(std::move(y)), "input value g(", n, ")"); }; auto i = [&](float x) -> MaybeInt { UniqueInt y{int(x)}; return JUST_MSG(h(std::move(y)), "input value int(", x, ")"); }; auto data = CHECK_JUST(i(1)); ASSERT_EQ(data.x, 233); auto err = details::JustPrivateScope::StackedError(i(10.123)); ASSERT_EQ(err.Error(), "input value 53"); ASSERT_EQ(err.StackElem(0).message, "f(std::move(y)): input value g(10)"); ASSERT_EQ(err.StackElem(1).message, "h(std::move(y)): input value int(10.123)"); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto) ASSERT_EXIT(CHECK_JUST(i(10.234)), testing::KilledBySignal(SIGABRT), R"(input value 53)"); } TEST(Just, JustOpt) { auto f = [](int x) -> Optional { if (x > 10) return NullOpt; return x + 1; }; auto g = [&f](int x) -> Optional { return OPT_JUST(f(x)) * 2; }; ASSERT_EQ(CHECK_JUST(g(2)), 6); ASSERT_FALSE(g(11)); auto h = [&](int x) -> Optional { if (x == 10) return NullOpt; return OPT_JUST(g(x)) + OPT_JUST(f(x + 2)); }; ASSERT_FALSE(h(10)); ASSERT_FALSE(h(9)); ASSERT_EQ(h(8), 29); } TEST(Just, NoStack) { using Error = simple::NoStackError; using MaybeInt = Maybe; auto f = [](int x) -> MaybeInt { if (x > 10 || x < 0) { return Error{"not in range"}; } return x + 10; }; auto g = [&f](int x) -> MaybeInt { if (x == 15) { return Error{"invalid value"}; } return JUST(f(x)) * 2; }; auto h = [&g](int x) -> MaybeInt { return JUST(g(x)) + 2; }; ASSERT_EQ(CHECK_JUST(h(0)), 22); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(h(11)), R"(not in range)"); ASSERT_DEATH( // NOLINT(cppcoreguidelines-avoid-goto) CHECK_JUST(h(15)), R"(invalid value)"); ASSERT_EQ(details::JustPrivateScope::StackedError(h(12)).StackSize(), 0); ASSERT_EQ(details::JustPrivateScope::StackedError(h(15)).StackSize(), 0); } ================================================ FILE: oneflow/maybe/maybe.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_MAYBE_MAYBE_H_ #define ONEFLOW_MAYBE_MAYBE_H_ #include #include #include "oneflow/maybe/just.h" #include "oneflow/maybe/variant.h" #include "oneflow/maybe/optional.h" #include "oneflow/maybe/error.h" #include "oneflow/maybe/config.h" namespace oneflow { namespace maybe { struct InPlaceOkType { explicit constexpr InPlaceOkType() = default; }; constexpr InPlaceOkType Ok{}; struct InPlaceErrorType { explicit constexpr InPlaceErrorType() = default; }; constexpr InPlaceErrorType InPlaceError{}; namespace details { template struct MaybeStorage : Variant { using Base = Variant; MaybeStorage(const T& v) : Base(v) {} // NOLINT(google-explicit-constructor) MaybeStorage(T&& v) : Base(std::move(v)) {} // NOLINT(google-explicit-constructor) template explicit MaybeStorage(InPlaceOkType, Args&&... args) : Base(InPlaceType, std::forward(args)...) {} template explicit MaybeStorage(InPlaceErrorType, Args&&... 
args) : Base(InPlaceType, std::forward(args)...) {} MaybeStorage(const E& err) : Base(err) {} // NOLINT(google-explicit-constructor) MaybeStorage(E&& err) : Base(std::move(err)) {} // NOLINT(google-explicit-constructor) decltype(auto) Value() & { return this->Base::template Value(); } decltype(auto) Value() const& { return this->Base::template Value(); } decltype(auto) Value() && { return std::move(*this).Base::template Value(); } decltype(auto) Error() & { return this->Base::template Value(); } decltype(auto) Error() const& { return this->Base::template Value(); } decltype(auto) Error() && { return std::move(*this).Base::template Value(); } bool IsOk() const { return this->template Is(); } }; template struct MaybeStorage::value>> : Variant*, E> { static_assert(std::is_lvalue_reference::value, "rvalue reference is not allowed here"); using PointedType = std::remove_reference_t; using UnderlyingType = PointedType*; using Base = Variant; MaybeStorage(T v) : Base(&v) {} // NOLINT(google-explicit-constructor) MaybeStorage(const E& err) : Base(err) {} // NOLINT(google-explicit-constructor) MaybeStorage(E&& err) : Base(std::move(err)) {} // NOLINT(google-explicit-constructor) template explicit MaybeStorage(InPlaceErrorType, Args&&... args) : Base(InPlaceType, std::forward(args)...) {} PointedType& Value() { return *this->Base::template Value(); } const PointedType& Value() const { return *this->Base::template Value(); } decltype(auto) Error() & { return this->Base::template Value(); } decltype(auto) Error() const& { return this->Base::template Value(); } decltype(auto) Error() && { return std::move(*this).Base::template Value(); } bool IsOk() const { return this->template Is(); } }; template struct MaybeStorage : Optional { using Base = Optional; MaybeStorage(InPlaceOkType) : Base(NullOpt) {} // NOLINT(google-explicit-constructor) MaybeStorage(const E& err) : Base(err) {} // NOLINT(google-explicit-constructor) MaybeStorage(E&& err) : Base(std::move(err)) {} // NOLINT(google-explicit-constructor) template explicit MaybeStorage(InPlaceErrorType, Args&&... args) : Base(InPlace, std::forward(args)...) 
{} void Value() const {} decltype(auto) Error() & { return this->Base::Value(); } decltype(auto) Error() const& { return this->Base::Value(); } decltype(auto) Error() && { return std::move(*this).Base::Value(); } bool IsOk() const { return !this->HasValue(); } }; struct MaybePrivateScope { template static decltype(auto) Value(T&& m) { return std::forward(m).Value(); } template static decltype(auto) StackedError(T&& m) { return std::forward(m).StackedError(); } template static auto Map(T&& maybe, F&& f) -> Maybe(f)(std::forward(maybe).Value())), typename RemoveCVRef::StackedErrorType> { if (maybe) { return std::forward(f)(std::forward(maybe).Value()); } return std::forward(maybe).StackedError(); } template()(std::declval().Value()))>> static auto Bind(T&& maybe, F&& f) -> std::enable_if_t::value, U> { if (maybe) { return std::forward(f)(std::forward(maybe).Value()); } return std::forward(maybe).StackedError(); } }; } // namespace details // A type which can be either a value typed T, or a stacked error typed E template struct OF_MAYBE_NODISCARD_TYPE Maybe : private details::MaybeStorage { static_assert(!std::is_reference::value, "error type cannot be reference"); static_assert(!(std::is_const::value || std::is_volatile::value), "error type cannot be cv-qualified"); // E must be a stacked error, which implies StackedErrorTraits must exist using ErrorTraits = StackedErrorTraits; using StackedErrorType = E; using ValueType = T; using ErrorType = typename ErrorTraits::ErrorType; private: using Base = details::MaybeStorage; friend struct details::MaybePrivateScope; friend struct details::JustPrivateScope; protected: decltype(auto) Value() & { return Base::Value(); } decltype(auto) Value() const& { return Base::Value(); } decltype(auto) Value() && { return std::move(*this).Base::Value(); } decltype(auto) StackedError() & { return Base::Error(); } decltype(auto) StackedError() const& { return Base::Error(); } decltype(auto) StackedError() && { return std::move(*this).Base::Error(); } decltype(auto) Error() & { return ErrorTraits::Error(StackedError()); } decltype(auto) Error() const& { return ErrorTraits::Error(StackedError()); } decltype(auto) Error() && { return ErrorTraits::Error(std::move(*this).StackedError()); } public: using Base::Base; OF_MAYBE_NODISCARD_FUNC bool IsOk() const { return Base::IsOk(); } OF_MAYBE_NODISCARD_FUNC bool IsErr() const { return !Base::IsOk(); } explicit operator bool() const { return IsOk(); } OF_MAYBE_NODISCARD_FUNC decltype(auto) GetStackedError() & { OF_MAYBE_ASSERT(IsErr()); return StackedError(); } OF_MAYBE_NODISCARD_FUNC decltype(auto) GetStackedError() const& { OF_MAYBE_ASSERT(IsErr()); return StackedError(); } OF_MAYBE_NODISCARD_FUNC decltype(auto) GetStackedError() && { OF_MAYBE_ASSERT(IsErr()); return std::move(*this).StackedError(); } OF_MAYBE_NODISCARD_FUNC decltype(auto) GetError() & { OF_MAYBE_ASSERT(IsErr()); return Error(); } OF_MAYBE_NODISCARD_FUNC decltype(auto) GetError() const& { OF_MAYBE_ASSERT(IsErr()); return Error(); } OF_MAYBE_NODISCARD_FUNC decltype(auto) GetError() && { OF_MAYBE_ASSERT(IsErr()); return std::move(*this).Error(); } template OF_MAYBE_NODISCARD_FUNC auto Map(F&& f) const& { return details::MaybePrivateScope::Map(*this, std::forward(f)); } template OF_MAYBE_NODISCARD_FUNC auto Map(F&& f) && { return details::MaybePrivateScope::Map(std::move(*this), std::forward(f)); } template OF_MAYBE_NODISCARD_FUNC auto Bind(F&& f) const& { return details::MaybePrivateScope::Bind(*this, std::forward(f)); } template OF_MAYBE_NODISCARD_FUNC 
auto Bind(F&& f) && { return details::MaybePrivateScope::Bind(std::move(*this), std::forward(f)); } }; } // namespace maybe } // namespace oneflow #endif // ONEFLOW_MAYBE_MAYBE_H_ ================================================ FILE: oneflow/maybe/maybe_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/maybe/error.h" #include "oneflow/maybe/maybe.h" using namespace oneflow::maybe; TEST(Maybe, Basic) { using Error = simple::StackedError; Maybe a{1}, b{a}, c{Error(2)}, d{c}; ASSERT_TRUE(a); ASSERT_TRUE(b); ASSERT_FALSE(c); ASSERT_FALSE(d); ASSERT_EQ(details::MaybePrivateScope::Value(a), 1); ASSERT_EQ(details::MaybePrivateScope::Value(b), 1); a = 2; ASSERT_EQ(details::MaybePrivateScope::Value(a), 2); ASSERT_EQ(details::MaybePrivateScope::Value(b), 1); ASSERT_EQ(details::MaybePrivateScope::StackedError(c).Error(), 2); ASSERT_EQ(details::MaybePrivateScope::StackedError(d).Error(), 2); a = c; ASSERT_EQ(details::MaybePrivateScope::StackedError(a).Error(), 2); } TEST(Maybe, NonPOD) { using Error = simple::StackedError; Maybe, Error> a{Ok, new int{1}}, b{a}, c{Error("test")}, d{c}; ASSERT_TRUE(a); ASSERT_TRUE(b); ASSERT_FALSE(c); ASSERT_FALSE(d); ASSERT_EQ(details::MaybePrivateScope::Value(a).use_count(), 2); { Maybe, Error> x(a); ASSERT_EQ(details::MaybePrivateScope::Value(x).use_count(), 3); x = c; ASSERT_FALSE(x); x = a; ASSERT_EQ(details::MaybePrivateScope::Value(x).use_count(), 3); } ASSERT_EQ(details::MaybePrivateScope::Value(a).use_count(), 2); ASSERT_EQ(*details::MaybePrivateScope::Value(a), 1); *details::MaybePrivateScope::Value(a) = 2; ASSERT_EQ(*details::MaybePrivateScope::Value(a), 2); ASSERT_EQ(details::MaybePrivateScope::StackedError(c).Error(), "test"); ASSERT_EQ(details::MaybePrivateScope::StackedError(c).StackSize(), 0); } TEST(Maybe, Reference) { using Error = simple::StackedError; const int& n = 1; Maybe a{n}, b{a}, c{Error("test")}, d{c}; ASSERT_TRUE(a); ASSERT_TRUE(b); ASSERT_FALSE(c); ASSERT_FALSE(d); ASSERT_EQ(details::MaybePrivateScope::Value(a), 1); int k = 2; a = k; ASSERT_EQ(details::MaybePrivateScope::Value(a), 2); k = 3; ASSERT_EQ(details::MaybePrivateScope::Value(a), 3); int x = 1; Maybe e{x}, f{e}, g{Error("test")}, h{g}; ASSERT_TRUE(a); ASSERT_TRUE(b); ASSERT_FALSE(c); ASSERT_FALSE(d); ASSERT_EQ(details::MaybePrivateScope::Value(e), 1); e = k; ASSERT_EQ(details::MaybePrivateScope::Value(e), 3); details::MaybePrivateScope::Value(e) = 4; ASSERT_EQ(k, 4); } TEST(Maybe, Void) { using Error = simple::StackedError; Maybe a{Ok}, b{a}, c{Error("test")}, d{c}; ASSERT_TRUE(a); ASSERT_TRUE(b); ASSERT_FALSE(c); ASSERT_FALSE(d); ASSERT_EQ(details::MaybePrivateScope::StackedError(c).Error(), "test"); c = Error("hello"); ASSERT_EQ(details::MaybePrivateScope::StackedError(c).Error(), "hello"); a = c; ASSERT_EQ(details::MaybePrivateScope::StackedError(a).Error(), "hello"); } TEST(Maybe, PtrError) { using PointedError = simple::StackedError; using Error = std::unique_ptr; Maybe 
a{1}, c{InPlaceError, new PointedError("test")}; ASSERT_TRUE(a); ASSERT_FALSE(c); ASSERT_EQ(details::MaybePrivateScope::StackedError(c)->Error(), "test"); } TEST(Maybe, NoStack) { using Error = simple::NoStackError; Maybe a{1}, b{a}, c{InPlaceError, "hello"}, d{c}; ASSERT_TRUE(a); ASSERT_TRUE(b); ASSERT_FALSE(c); ASSERT_FALSE(d); a = c; ASSERT_FALSE(a); } TEST(Maybe, Monadic) { using Error = simple::NoStackError; Maybe a{1}, b{InPlaceError, "hello"}; auto x2 = [](int x) { return x * 2; }; auto x2e2 = [](int x) -> Maybe { if (x == 4) return Error("test"); return x * 2; }; ASSERT_EQ(CHECK_JUST(a.Map(x2).Map(x2)), 4); ASSERT_FALSE(b.Map(x2).Map(x2)); a = 1; ASSERT_EQ(CHECK_JUST(a.Bind(x2e2).Bind(x2e2)), 4); a = 2; ASSERT_EQ(CHECK_JUST(a.Bind(x2e2)), 4); ASSERT_EQ(a.Bind(x2e2).Bind(x2e2).GetError(), "test"); a = 4; ASSERT_EQ(a.Bind(x2e2).GetError(), "test"); ASSERT_EQ(a.Bind(x2e2).GetError(), "test"); } ================================================ FILE: oneflow/maybe/optional.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_MAYBE_OPTIONAL_H_ #define ONEFLOW_MAYBE_OPTIONAL_H_ #include #include #include "oneflow/maybe/just.h" #include "oneflow/maybe/utility.h" #include "oneflow/maybe/type_traits.h" namespace oneflow { namespace maybe { template struct Optional; namespace details { // OptionalStorage is specialized for 2 cases: // 1. for scalar types, we optimize all construction, destruction and value check // 2. for reference types, we store a pointer to the referenced value template struct OptionalStorage { private: bool has_; alignas(T) unsigned char value_[sizeof(T)]; using Type = std::remove_const_t; public: OptionalStorage() = default; ~OptionalStorage() = default; OptionalStorage(const OptionalStorage&) = delete; OptionalStorage& operator=(const OptionalStorage&) = delete; void Init() { has_ = false; } T& Value() & { return *reinterpret_cast(value_); } Type&& Value() && { return std::move(*const_cast(reinterpret_cast(value_))); } const T& Value() const& { return *reinterpret_cast(value_); } bool HasValue() const { return has_; } void Reset() { if (has_) { has_ = false; Value().~T(); } } void Destory() { if (has_) { Value().~T(); } } template, int> = 0> void Construct(Args&&... args) { new (value_) Type{std::forward(args)...}; has_ = true; } template, int> = 0> void Construct(Args&&... args) { new (value_) Type(std::forward(args)...); has_ = true; } template::value, int> = 0> T& Emplace(Args&&... args) { if (!has_) { Construct(std::forward(args)...); return Value(); } else { return Value() = Type(std::forward(args)...); } } template::value, int> = 0> T& Emplace(Args&&... 
args) { Destory(); Construct(std::forward(args)...); return Value(); } template void CopyConstruct(OS&& s) { has_ = s.has_; if (has_) { new (value_) Type(std::forward(s).Value()); } } template void Copy(OS&& s) { if (s.has_) { Emplace(std::forward(s).Value()); } else { Reset(); } } }; template // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct OptionalStorage::value>> { private: using Type = std::remove_const_t; bool has_; Type value_; public: OptionalStorage() = default; ~OptionalStorage() = default; OptionalStorage(const OptionalStorage&) = delete; OptionalStorage& operator=(const OptionalStorage&) = delete; void Init() { has_ = false; value_ = {}; } T& Value() & { return value_; } Type&& Value() && { return std::move(const_cast(value_)); } const T& Value() const& { return value_; } bool HasValue() const { return has_; } void Reset() { has_ = false; } void Destory() {} template void Construct(const U& v) { value_ = v; has_ = true; } template T& Emplace(const U& v) { Construct(v); return Value(); } void CopyConstruct(const OptionalStorage& s) { has_ = s.has_; value_ = s.value_; } void Copy(const OptionalStorage& s) { CopyConstruct(s); } }; template struct OptionalStorage::value>> { static_assert(std::is_lvalue_reference::value, "rvalue reference is not allowed here"); using Type = std::remove_reference_t; private: Type* value_; public: OptionalStorage() = default; ~OptionalStorage() = default; OptionalStorage(const OptionalStorage&) = delete; OptionalStorage& operator=(const OptionalStorage&) = delete; void Init() { value_ = nullptr; } T Value() { return *value_; } const Type& Value() const { return *value_; } bool HasValue() const { return value_ != nullptr; } void Reset() { value_ = nullptr; } void Destory() {} void Construct(T v) { value_ = &v; } T Emplace(T v) { Construct(v); return Value(); } void CopyConstruct(const OptionalStorage& s) { value_ = s.value_; } void Copy(const OptionalStorage& s) { CopyConstruct(s); } }; struct OptionalPrivateScope { template static decltype(auto) Value(T&& opt) { return std::forward(opt).Value(); } template static auto Map(T&& opt, F&& f) -> Optional(f)(std::forward(opt).Value()))> { if (opt.HasValue()) { return std::forward(f)(std::forward(opt).Value()); } return NullOpt; } template()(std::declval().Value()))>> static auto Bind(T&& opt, F&& f) -> std::enable_if_t::value, U> { if (opt.HasValue()) { return std::forward(f)(std::forward(opt).Value()); } return NullOpt; } template()()), void>::value, int> = 0> static auto OrElse(T&& opt, F&& f) -> std::decay_t { if (!opt.HasValue()) { std::forward(f)(); return NullOpt; } return std::forward(opt); } template()()), std::decay_t>::value, int> = 0> static auto OrElse(T&& opt, F&& f) -> std::decay_t { if (!opt.HasValue()) { return std::forward(f)(); } return std::forward(opt); } }; } // namespace details // unlike Variant, type arguments can be cv qualified or lvalue referenced // this Optional DO NOT guarantee exception safety template struct OF_MAYBE_NODISCARD_TYPE Optional { protected: details::OptionalStorage storage_; using Type = std::remove_const_t; decltype(auto) Value() & { return storage_.Value(); } decltype(auto) Value() && { return std::move(storage_).Value(); } decltype(auto) Value() const& { return storage_.Value(); } // we DO NOT export Value method, then leave these methods accessable for the JUST macro friend struct details::OptionalPrivateScope; friend struct details::JustPrivateScope; public: static_assert(!std::is_same, NullOptType>::value, "NullOptType is not allowed in 
Optional"); using ValueType = T; explicit Optional() { storage_.Init(); }; Optional(NullOptType) { storage_.Init(); } // NOLINT(google-explicit-constructor) Optional(const T& v) { storage_.Construct(v); } // NOLINT(google-explicit-constructor) template::value, int> = 0> Optional(Type&& v) { // NOLINT(google-explicit-constructor) storage_.Construct(std::move(v)); } Optional(const Optional& opt) { storage_.CopyConstruct(opt.storage_); } Optional(Optional&& opt) noexcept { storage_.CopyConstruct(std::move(opt.storage_)); } template explicit Optional(InPlaceT, Args&&... args) { storage_.Construct(std::forward(args)...); } ~Optional() { storage_.Destory(); } Optional& operator=(NullOptType) { storage_.Reset(); return *this; } Optional& operator=(const T& v) { storage_.Emplace(v); return *this; } template::value, int> = 0> Optional& operator=(Type&& v) { storage_.Emplace(std::move(v)); return *this; } template decltype(auto) Emplace(Args&&... args) { return storage_.Emplace(std::forward(args)...); } Optional& operator=(const Optional& opt) { storage_.Copy(opt.storage_); return *this; } Optional& operator=(Optional&& opt) noexcept { storage_.Copy(std::move(opt.storage_)); return *this; } OF_MAYBE_NODISCARD_FUNC bool HasValue() const { return storage_.HasValue(); } explicit operator bool() const { return HasValue(); } bool operator==(const Optional& opt) const { if (HasValue()) { if (opt.HasValue()) { return Value() == opt.Value(); } else { return false; } } else { return !opt.HasValue(); } } bool operator!=(const Optional& opt) const { return !operator==(opt); } bool operator<(const Optional& opt) const { if (HasValue()) { if (opt.HasValue()) { return Value() < opt.Value(); } else { return false; } } else { return opt.HasValue(); } } bool operator>=(const Optional& opt) const { return !operator<(opt); } bool operator>(const Optional& opt) const { if (HasValue()) { if (opt.HasValue()) { return Value() > opt.Value(); } else { return true; } } else { return false; } } bool operator<=(const Optional& opt) const { return !operator>(opt); } friend bool operator==(const Optional& opt, NullOptType) { return !opt.HasValue(); } friend bool operator!=(const Optional& opt, NullOptType) { return opt.HasValue(); } friend bool operator==(NullOptType, const Optional& opt) { return !opt.HasValue(); } friend bool operator!=(NullOptType, const Optional& opt) { return opt.HasValue(); } friend bool operator<(const Optional& opt, NullOptType) { return false; } friend bool operator>(const Optional& opt, NullOptType) { return opt.HasValue(); } friend bool operator<=(const Optional& opt, NullOptType) { return !opt.HasValue(); } friend bool operator>=(const Optional& opt, NullOptType) { return true; } friend bool operator<(NullOptType, const Optional& opt) { return opt > NullOpt; } friend bool operator>(NullOptType, const Optional& opt) { return opt < NullOpt; } friend bool operator<=(NullOptType, const Optional& opt) { return opt >= NullOpt; } friend bool operator>=(NullOptType, const Optional& opt) { return opt <= NullOpt; } friend bool operator==(const Optional& opt, const T& v) { if (opt.HasValue()) { return opt.Value() == v; } else { return false; } } friend bool operator!=(const Optional& opt, const T& v) { return !(opt == v); } friend bool operator==(const T& v, const Optional& opt) { return opt == v; } friend bool operator!=(const T& v, const Optional& opt) { return !(opt == v); } friend bool operator<(const Optional& opt, const T& v) { if (opt.HasValue()) { return opt.Value() < v; } else { return true; } } friend 
bool operator>=(const Optional& opt, const T& v) { return !(opt < v); } friend bool operator>(const T& v, const Optional& opt) { return opt < v; } friend bool operator<=(const T& v, const Optional& opt) { return !(opt < v); } friend bool operator>(const Optional& opt, const T& v) { if (opt.HasValue()) { return opt.Value() > v; } else { return false; } } friend bool operator<=(const Optional& opt, const T& v) { return !(opt > v); } friend bool operator<(const T& v, const Optional& opt) { return opt > v; } friend bool operator>=(const T& v, const Optional& opt) { return !(opt > v); } decltype(auto) ValueOr(const T& v) const& { if (HasValue()) { return Value(); } else { return v; } } template::value, int> = 0> auto ValueOr(T&& v) const& { if (HasValue()) { return Value(); } else { return std::move(v); } } template::value, int> = 0> auto ValueOr(const T& v) && { if (HasValue()) { return std::move(*this).Value(); } else { return v; } } template::value, int> = 0> decltype(auto) ValueOr(T&& v) && { if (HasValue()) { return std::move(*this).Value(); } else { return std::move(v); } } void Reset() { storage_.Reset(); } template OF_MAYBE_NODISCARD_FUNC auto Map(F&& f) const& { return details::OptionalPrivateScope::Map(*this, std::forward(f)); } template OF_MAYBE_NODISCARD_FUNC auto Map(F&& f) && { return details::OptionalPrivateScope::Map(std::move(*this), std::forward(f)); } template OF_MAYBE_NODISCARD_FUNC auto Bind(F&& f) const& { return details::OptionalPrivateScope::Bind(*this, std::forward(f)); } template OF_MAYBE_NODISCARD_FUNC auto Bind(F&& f) && { return details::OptionalPrivateScope::Bind(std::move(*this), std::forward(f)); } template OF_MAYBE_NODISCARD_FUNC auto OrElse(F&& f) const& { return details::OptionalPrivateScope::OrElse(*this, std::forward(f)); } template OF_MAYBE_NODISCARD_FUNC auto OrElse(F&& f) && { return details::OptionalPrivateScope::OrElse(std::move(*this), std::forward(f)); } }; } // namespace maybe } // namespace oneflow namespace std { template struct hash> { size_t operator()(const oneflow::maybe::Optional& v) const noexcept { if (v.HasValue()) { return hashImpl(oneflow::maybe::details::OptionalPrivateScope::Value(v)); } else { return oneflow::maybe::NullOptHash; } } template::value, int> = 0> static std::size_t hashImpl(const T& v) { return std::hash>()(v); } template::value, int> = 0> static std::size_t hashImpl(const std::remove_reference_t& v) { return std::hash*>()(&v); } }; } // namespace std #endif // ONEFLOW_MAYBE_OPTIONAL_H_ ================================================ FILE: oneflow/maybe/optional_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/maybe/optional.h" using namespace oneflow::maybe; using Private = details::OptionalPrivateScope; TEST(Optional, Scalar) { Optional a, b(1), c(a), d(b), e(NullOpt), bb(InPlace, 1); static_assert(std::is_same::value, ""); ASSERT_TRUE(!a.HasValue()); ASSERT_TRUE(b.HasValue()); ASSERT_EQ(b.ValueOr(0), 1); ASSERT_TRUE(!c.HasValue()); ASSERT_EQ(c.ValueOr(233), 233); ASSERT_TRUE(d.HasValue()); ASSERT_EQ(d.ValueOr(0), 1); ASSERT_TRUE(!e.HasValue()); a = b; ASSERT_TRUE(a.HasValue()); ASSERT_EQ(a.ValueOr(0), 1); a = NullOpt; ASSERT_TRUE(!a.HasValue()); a = 222; ASSERT_TRUE(a.HasValue()); ASSERT_TRUE(a); ASSERT_EQ(a.ValueOr(1), 222); Private::Value(a) = 2333; ASSERT_EQ(a.ValueOr(1), 2333); Optional f, g(1); ASSERT_TRUE(!f.HasValue()); ASSERT_TRUE(g.HasValue()); ASSERT_EQ(g.ValueOr(2), 1); static_assert(std::is_same::value, ""); f = 1; ASSERT_TRUE(f.HasValue()); ASSERT_EQ(f.ValueOr(2), 1); ASSERT_EQ(Private::Value(f), 1); int x = 2; ASSERT_EQ(f.ValueOr(x), 1); ASSERT_EQ(f.Emplace(2), 2); ASSERT_EQ(Private::Value(f), 2); f.Reset(); ASSERT_TRUE(!f); } TEST(Optional, NonScalar) { auto x = std::make_shared(233); ASSERT_EQ(x.use_count(), 1); Optional> a, b(x), aa(a), aaa(InPlace, std::make_shared(244)); ASSERT_EQ(x.use_count(), 2); ASSERT_EQ(*Private::Value(b), 233); static_assert(std::is_same&>::value, ""); ASSERT_TRUE(!a.HasValue()); ASSERT_TRUE(!aa.HasValue()); Optional> c(a), d(b); ASSERT_TRUE(!c.HasValue()); ASSERT_EQ(x.use_count(), 3); ASSERT_EQ(b, d); a = x; ASSERT_EQ(x.use_count(), 4); a = NullOpt; ASSERT_EQ(x.use_count(), 3); a = b; ASSERT_EQ(x.use_count(), 4); ASSERT_EQ(a, b); { Optional> e(a); // NOLINT ASSERT_EQ(x.use_count(), 5); Optional> f; f = e; ASSERT_EQ(x.use_count(), 6); } ASSERT_EQ(x.use_count(), 4); *Private::Value(a) = 234; ASSERT_EQ(*x, 234); Optional> g(std::move(a)); ASSERT_EQ(x.use_count(), 4); { Optional> h; ASSERT_TRUE(!h.HasValue()); h = std::move(b); ASSERT_EQ(x.use_count(), 4); } ASSERT_EQ(x.use_count(), 3); Optional> i(x); ASSERT_EQ(x.use_count(), 4); static_assert(std::is_same&>::value, ""); i = NullOpt; ASSERT_EQ(x.use_count(), 3); i.Emplace(x); ASSERT_EQ(x.use_count(), 4); i.Reset(); ASSERT_EQ(x.use_count(), 3); i.Emplace(std::move(x)); ASSERT_EQ(Private::Value(i).use_count(), 3); struct A { int id; std::string name; }; Optional a1, a2{InPlace, 233, "oneflow"}; ASSERT_FALSE(a1); ASSERT_TRUE(a2); ASSERT_EQ(a1, NullOpt); ASSERT_EQ(Private::Value(a2).id, 233); ASSERT_EQ(Private::Value(a2).name, "oneflow"); } TEST(Optional, Reference) { int x = 233; Optional a, b(x), c(a), d(b); ASSERT_TRUE(!a); ASSERT_TRUE(b); ASSERT_TRUE(!c); ASSERT_TRUE(d); ASSERT_EQ(Private::Value(b), 233); ASSERT_EQ(Private::Value(d), 233); static_assert(std::is_same::value, ""); a = x; ASSERT_TRUE(a); ASSERT_EQ(Private::Value(a), 233); a = NullOpt; ASSERT_TRUE(!a); a = b; ASSERT_TRUE(a); ASSERT_EQ(Private::Value(a), 233); Private::Value(a) = 234; ASSERT_EQ(x, 234); Optional e, f(x), g(e), h(f); ASSERT_TRUE(!e); ASSERT_TRUE(f); ASSERT_TRUE(!g); ASSERT_TRUE(h); ASSERT_NE(NullOpt, h); ASSERT_EQ(Private::Value(f), 234); ASSERT_EQ(Private::Value(h), 234); static_assert(std::is_same::value, ""); e = x; ASSERT_TRUE(e); ASSERT_EQ(e, x); ASSERT_EQ(e, 234); ASSERT_EQ(Private::Value(e), 234); e = NullOpt; ASSERT_TRUE(!e); ASSERT_EQ(e, NullOpt); } TEST(Optional, Hash) { Optional a, b(123); ASSERT_EQ(std::hash()(a), NullOptHash); ASSERT_EQ(std::hash()(b), std::hash()(123)); auto si = std::make_shared(123); Optional> c, d(si); ASSERT_EQ(std::hash()(c), NullOptHash); 
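// Added illustration (not part of the original test suite): std::hash<Optional<T>>
// returns NullOptHash for a disengaged optional and otherwise forwards to the hash
// of the stored value; for reference optionals it hashes the address of the
// referenced object instead (see the two hashImpl overloads in optional.h above).
Optional<int> hash_demo(42);
ASSERT_EQ(std::hash<Optional<int>>()(hash_demo), std::hash<int>()(42));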
ASSERT_EQ(std::hash()(d), std::hash()(si)); int x = 233; Optional e, f(x); ASSERT_EQ(std::hash()(e), NullOptHash); ASSERT_EQ(std::hash()(f), std::hash()(&x)); Optional g; ASSERT_EQ(std::hash()(g), NullOptHash); } TEST(Optional, Compare) { Optional a, b, c(-1), d(0), e(1), f(1); ASSERT_EQ(a, b); ASSERT_EQ(e, f); ASSERT_NE(a, d); ASSERT_LT(b, c); ASSERT_LE(b, c); ASSERT_LE(c, c); ASSERT_LT(c, d); ASSERT_LT(d, e); ASSERT_GT(e, d); ASSERT_GT(d, c); ASSERT_GT(c, b); ASSERT_GE(c, b); ASSERT_GE(a, b); int x = 0, y = 1, z = -1; ASSERT_NE(a, x); ASSERT_EQ(d, x); ASSERT_NE(x, c); ASSERT_EQ(z, c); ASSERT_LT(a, x); ASSERT_LT(c, x); ASSERT_LT(d, y); ASSERT_LT(z, f); ASSERT_LE(a, x); ASSERT_LE(d, x); ASSERT_GT(x, a); ASSERT_GT(x, c); ASSERT_GT(y, d); ASSERT_GT(f, z); ASSERT_GE(x, a); ASSERT_GE(x, d); std::set> s{2, NullOpt, -1, 3, NullOpt, 2}; ASSERT_EQ(s.size(), 4); auto iter = s.begin(); ASSERT_EQ(*(iter++), NullOpt); ASSERT_EQ(*(iter++), -1); ASSERT_EQ(*(iter++), 2); ASSERT_EQ(*(iter++), 3); } TEST(Optional, Monadic) { Optional a(1), b, c(2); ASSERT_EQ(a.Map([](int x) { return x + 1; }), c); ASSERT_EQ(b.Map([](int x) { return x + 1; }), b); ASSERT_EQ(a.Map([](int x) { return std::string(x + 1, 'a'); }).Map([](const auto& x) { return (int)x.size(); }), c); ASSERT_EQ(a.Bind([](int x) -> Optional { if (x < 10) { return x * 1.1; } else { return NullOpt; } }) .Map([](float x) { return x - 1; }) .Map([](float x) { return std::abs(x - 0.1) < 0.001; }), Optional(true)); int x = 0; [[maybe_unused]] auto _ = b.OrElse([&] { x++; }).OrElse([&] { x *= 2; }); ASSERT_EQ(x, 2); ASSERT_EQ(b.OrElse([] { return Optional(3); }).Map([](int x) { return x - 1; }), c); } ================================================ FILE: oneflow/maybe/type_traits.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_MAYBE_TYPE_TRAITS_H_ #define ONEFLOW_MAYBE_TYPE_TRAITS_H_ #include #include #include #include #include "config.h" namespace oneflow { namespace maybe { // in this file, xxxS represents struct of xxx // for implementant aspect, xxx is an alias of xxxS::type or xxxS::value template using BoolConstant = std::integral_constant; template using IndexConstant = std::integral_constant; constexpr std::size_t NPos = -1; template struct ConjS : std::true_type {}; template struct ConjS : B1 {}; template struct ConjS : std::conditional_t, B1> {}; template constexpr bool Conj = ConjS::value; template struct DisjS : std::false_type {}; template struct DisjS : B1 {}; template struct DisjS : std::conditional_t> {}; template constexpr bool Disj = DisjS::value; template struct NegS : BoolConstant {}; template constexpr bool Neg = NegS::value; struct TypeNotFound; // return TypeNotFound while out of range template struct TypeGetS; template struct TypeGetS : TypeGetS {}; template struct TypeGetS<0, T1, Tn...> { using type = T1; }; template struct TypeGetS { using type = TypeNotFound; }; template using TypeGet = typename TypeGetS::type; // return NPos (-1) while not found template struct IndexGetFromS; template struct IndexGetFromS : IndexGetFromS {}; template struct IndexGetFromS : IndexConstant {}; template struct IndexGetFromS : IndexConstant {}; template constexpr auto IndexGet = IndexGetFromS<0, T, Ts...>::value; template constexpr auto TypeIn = IndexGet != NPos; template using TypeInS = BoolConstant>; template struct RemoveCVRefS { using type = std::remove_cv_t>; }; template using RemoveCVRef = typename RemoveCVRefS::type; template struct IsDifferentTypesS : BoolConstant && IsDifferentTypesS::value> {}; template struct IsDifferentTypesS : std::true_type {}; template constexpr auto IsDifferentTypes = IsDifferentTypesS::value; template struct ConstRefExceptVoidS { using type = const T&; }; template<> struct ConstRefExceptVoidS { using type = void; }; template using ConstRefExceptVoid = typename ConstRefExceptVoidS::type; template using RemoveRValRef = std::conditional_t::value, std::remove_reference_t, T>; template constexpr bool IsAggregate = OF_MAYBE_IS_AGGREGATE(T); } // namespace maybe } // namespace oneflow #endif // ONEFLOW_MAYBE_TYPE_TRAITS_H_ ================================================ FILE: oneflow/maybe/type_traits_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/maybe/type_traits.h" using namespace oneflow::maybe; TEST(TypeTraits, Basics) { static_assert(Conj, ""); static_assert(!Conj, ""); static_assert(!Conj, ""); static_assert(!Conj, ""); static_assert(!Conj, ""); static_assert(Disj, ""); static_assert(Disj, ""); static_assert(Disj, ""); static_assert(!Disj, ""); static_assert(Disj, ""); static_assert(!Disj, ""); static_assert(std::is_same, int>::value, ""); static_assert(std::is_same, int>::value, ""); static_assert(std::is_same, float>::value, ""); static_assert(std::is_same, bool>::value, ""); static_assert(std::is_same, TypeNotFound>::value, ""); static_assert(std::is_same, float>::value, ""); static_assert(std::is_same, float>::value, ""); static_assert(std::is_same, TypeNotFound>::value, ""); static_assert(IndexGet == 0, ""); static_assert(IndexGet == NPos, ""); static_assert(IndexGet == 0, ""); static_assert(IndexGet == 1, ""); static_assert(IndexGet == 3, ""); static_assert(IndexGet == NPos, ""); static_assert(!TypeIn, ""); static_assert(TypeIn, ""); static_assert(TypeIn, ""); static_assert(!TypeIn, ""); static_assert(TypeIn, ""); static_assert(TypeIn, ""); static_assert(IsDifferentTypes, ""); static_assert(!IsDifferentTypes, ""); static_assert(IsDifferentTypes, ""); static_assert(!IsDifferentTypes, ""); } ================================================ FILE: oneflow/maybe/utility.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_MAYBE_UTILITY_H_ #define ONEFLOW_MAYBE_UTILITY_H_ #include #include namespace oneflow { namespace maybe { // unlike std::nullopt in c++17, the NullOptType is used in both Variant and Optional, // so it is more like both std::nullopt and std::monostate (in c++17), // the advantage of this unification is a more unifed experience, // i.e. 
`return NullOpt` can be used in both Variant and Optional context struct NullOptType { explicit constexpr NullOptType() = default; bool operator==(NullOptType) const { return true; } bool operator!=(NullOptType) const { return false; } bool operator<(NullOptType) const { return false; } bool operator>(NullOptType) const { return false; } bool operator<=(NullOptType) const { return true; } bool operator>=(NullOptType) const { return true; } }; constexpr const std::size_t NullOptHash = -3333; constexpr NullOptType NullOpt{}; struct InPlaceT { explicit constexpr InPlaceT() = default; }; constexpr InPlaceT InPlace; template struct InPlaceTypeT { explicit constexpr InPlaceTypeT() = default; }; template constexpr InPlaceTypeT InPlaceType; template struct InPlaceIndexT { explicit constexpr InPlaceIndexT() = default; }; template constexpr InPlaceIndexT InPlaceIndex; template constexpr void HashCombine(std::size_t& seed, const T& v) { std::hash hasher; seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } } // namespace maybe } // namespace oneflow namespace std { template<> struct hash { size_t operator()(oneflow::maybe::NullOptType) const noexcept { return oneflow::maybe::NullOptHash; } }; } // namespace std #endif // ONEFLOW_MAYBE_UTILITY_H_ ================================================ FILE: oneflow/maybe/utility_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/maybe/utility.h" using namespace oneflow::maybe; TEST(Utility, NullOpt) { NullOptType a, b(NullOpt), c(a); // NOLINT a = NullOpt; a = b; ASSERT_EQ(a, NullOptType{}); ASSERT_EQ(std::hash()(a), std::hash()(NullOpt)); ASSERT_EQ(NullOpt, a); ASSERT_GE(NullOpt, a); ASSERT_LE(NullOpt, a); ASSERT_FALSE(NullOpt < a); ASSERT_FALSE(NullOpt > a); } ================================================ FILE: oneflow/maybe/variant.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_MAYBE_VARIANT_H_ #define ONEFLOW_MAYBE_VARIANT_H_ #include #include #include #include #include #include #include #include "oneflow/maybe/utility.h" #include "oneflow/maybe/type_traits.h" namespace oneflow { namespace maybe { template struct Variant; namespace details { // there are generally two ways to implement visit (like std::visit in c++17) // 1. O(N) or O(log N), to iterate for all types or do a binary search on type index recursively // 2. 
O(1), to store an static (storage duration) array of function pointers for every (Variant, F) // where N = Variant::Num, and normally (in most cases) within the range [2, 5] // the 2nd method is required in std::visit(f, x...) while sizeof...(x) == 1 // but weakness of the 2nd method is that compilers usually cannot efficiently optimize these // function pointers (compared to trivial recursion, which is easy to do optimization, and also // friendly to CPU cache) here we implement visit via the first method: // 1. for 2 <= N < 4, we use the O(N) algorithm (TrivialRecursiveVisitImpl) for better optimization // 2. for N >= 4, we use the O(log N) algorithm (BinarySearchVisitImpl) for less recursion rounds struct VariantPrivateScope { template static R TrivialRecursiveVisitImpl(F&& f, V&& v, InPlaceIndexT::Num - 1>) { // assume v.Index() == N - 1 now return static_cast( std::forward(f)(std::forward(v).template Value::Num - 1>())); } template::Num - 1), int> = 0> static R TrivialRecursiveVisitImpl(F&& f, V&& v, InPlaceIndexT) { if (v.Index() == I) { return static_cast(std::forward(f)(std::forward(v).template Value())); } return TrivialRecursiveVisitImpl(std::forward(f), std::forward(v), InPlaceIndex); } template::Num), int> = 0> static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT, InPlaceIndexT) { return static_cast(std::forward(f)(std::forward(v).template Value())); } template::Num), int> = 0> static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT, InPlaceIndexT) { constexpr std::size_t M = (I + I + 1) / 2; constexpr std::size_t N = (M == I) ? I + 1 : I; if (v.Index() == M) { return static_cast(std::forward(f)(std::forward(v).template Value())); } else { return static_cast(std::forward(f)(std::forward(v).template Value())); } } template::Num), int> = 0> static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT, InPlaceIndexT) { constexpr std::size_t M = (L + U) / 2; if (v.Index() < M) { return BinarySearchVisitImpl(std::forward(f), std::forward(v), InPlaceIndex, InPlaceIndex); } else if (v.Index() > M) { return BinarySearchVisitImpl(std::forward(f), std::forward(v), InPlaceIndex, InPlaceIndex); } else { return static_cast(std::forward(f)(std::forward(v).template Value())); } } template::Num<4, int> = 0> static R VisitImpl(F&& f, V&& v) { return TrivialRecursiveVisitImpl(std::forward(f), std::forward(v), InPlaceIndex<0>); } template::Num >= 4, int> = 0> static R VisitImpl(F&& f, V&& v) { return BinarySearchVisitImpl(std::forward(f), std::forward(v), InPlaceIndex<0>, InPlaceIndex::Num - 1>); } }; struct AutoDeducedResultType; template struct VisitResultS { using type = R; }; template struct VisitResultS { using type = std::common_type_t()(std::declval()))...>; }; template using VisitResult = typename VisitResultS::type; } // namespace details // preconditions: template type arguments must be no less than 2 different type // and without reference and cv qualifiers // this Variant DO NOT guarantee exception safety template struct Variant { // NOLINT(cppcoreguidelines-pro-type-member-init) public: static_assert(sizeof...(Ts) > 1, "expected more than two types"); static_assert(Conj>...>, "reference types are not allowed here"); static_assert(Conj, std::is_volatile>>...>, "cv qualifiers are not allowed here"); // important precondition to optimize Visit via binary search static_assert(IsDifferentTypes, "expected all of different types"); static constexpr std::size_t Num = sizeof...(Ts); template static constexpr std::size_t IndexOfType = IndexGet; template static constexpr bool HasType = 
TypeIn; template using TypeByIndex = TypeGet; template, std::enable_if_t::value, int> = 0> Variant() { // NOLINT(cppcoreguidelines-pro-type-member-init) Construct<0>(); } // unlike std::variant, we only accept exact types to avoid wrong construction template>, int> = 0> Variant(T&& v) { // NOLINT(cppcoreguidelines-pro-type-member-init, google-explicit-constructor) Construct>(std::forward(v)); } template>, int> = 0> explicit Variant(InPlaceTypeT, // NOLINT(cppcoreguidelines-pro-type-member-init) Args&&... args) { Construct>(std::forward(args)...); } template = 0> explicit Variant(InPlaceIndexT, // NOLINT(cppcoreguidelines-pro-type-member-init) Args&&... args) { Construct(std::forward(args)...); } template decltype(auto) Visit(F&& f) & { using Result = details::VisitResult; return details::VariantPrivateScope::VisitImpl(std::forward(f), *this); } template decltype(auto) Visit(F&& f) && { using Result = details::VisitResult; return details::VariantPrivateScope::VisitImpl(std::forward(f), std::move(*this)); } template decltype(auto) Visit(F&& f) const& { using Result = details::VisitResult; return details::VariantPrivateScope::VisitImpl(std::forward(f), *this); } Variant(const Variant& v) { // NOLINT(cppcoreguidelines-pro-type-member-init) CopyConstruct(v); } Variant(Variant&& v) noexcept { // NOLINT(cppcoreguidelines-pro-type-member-init) CopyConstruct(std::move(v)); } template>, int> = 0> Variant& operator=(T&& v) { using Type = RemoveCVRef; Emplace(std::forward(v)); return *this; } Variant& operator=(const Variant& v) { Copy(v); return *this; } Variant& operator=(Variant&& v) noexcept { Copy(std::move(v)); return *this; } std::size_t Index() const { return type_index_; } template, int> = 0> bool Is() const { return type_index_ == IndexOfType; } ~Variant() { Destory(); } bool operator==(const Variant& v) const { if (type_index_ != v.type_index_) return false; return v.Visit( [this](const auto& elem) { return elem == Value>(); }); } bool operator!=(const Variant& v) const { return !operator==(v); } bool operator<(const Variant& v) const { if (type_index_ < v.type_index_) return true; if (type_index_ > v.type_index_) return false; return v.Visit( [this](const auto& elem) { return Value>() < elem; }); } bool operator>=(const Variant& v) const { return !(*this < v); } bool operator>(const Variant& v) const { if (type_index_ > v.type_index_) return true; if (type_index_ < v.type_index_) return false; return v.Visit( [this](const auto& elem) { return Value>() > elem; }); } bool operator<=(const Variant& v) const { return !(*this > v); } template, int> = 0> friend bool operator==(const Variant& v, const T& x) { if (v.type_index_ != IndexOfType) return false; return v.Value() == x; } template, int> = 0> friend bool operator!=(const Variant& v, const T& x) { return !(v == x); } template, int> = 0> friend bool operator==(const T& x, const Variant& v) { return v == x; } template, int> = 0> friend bool operator!=(const T& x, const Variant& v) { return !(v == x); } template T& Emplace(Args&&... args) { if (Is()) { return Value() = T(std::forward(args)...); } else { Destory(); Construct(std::forward(args)...); return Value(); } } template decltype(auto) Emplace(Args&&... 
args) { return Emplace>(std::forward(args)...); } template, int> = 0> T& Get() & { OF_MAYBE_ASSERT_EQ(Index(), IndexOfType); return Value(); } template, int> = 0> T&& Get() && { OF_MAYBE_ASSERT_EQ(Index(), IndexOfType); return std::move(*this).template Value(); } template, int> = 0> const T& Get() const& { OF_MAYBE_ASSERT_EQ(Index(), IndexOfType); return Value(); } template = 0> TypeByIndex& Get() & { OF_MAYBE_ASSERT_EQ(Index(), I); return Value(); } template = 0> TypeByIndex&& Get() && { OF_MAYBE_ASSERT_EQ(Index(), I); return std::move(*this).template Value(); } template = 0> const TypeByIndex& Get() const& { OF_MAYBE_ASSERT_EQ(Index(), I); return Value(); } protected: // use std::launder while updating to c++17 template, int> = 0> T& Value() & { return *reinterpret_cast(storage_); } template, int> = 0> T&& Value() && { return std::move(*reinterpret_cast(storage_)); } template, int> = 0> const T& Value() const& { return *reinterpret_cast(storage_); } template = 0> TypeByIndex& Value() & { return *reinterpret_cast*>(storage_); } template = 0> TypeByIndex&& Value() && { return std::move(*reinterpret_cast*>(storage_)); } template = 0> const TypeByIndex& Value() const& { return *reinterpret_cast*>(storage_); } private: static constexpr const std::size_t size = std::max({sizeof(Ts)...}); alignas(Ts...) unsigned char storage_[size]; std::uint8_t type_index_; friend struct details::VariantPrivateScope; template && IsAggregate, int> = 0> void Construct(Args&&... args) { new (storage_) T{std::forward(args)...}; type_index_ = IndexOfType; } template && !IsAggregate, int> = 0> void Construct(Args&&... args) { new (storage_) T(std::forward(args)...); type_index_ = IndexOfType; } template = 0> void Construct(Args&&... args) { Construct>(std::forward(args)...); } template void CopyConstruct(V&& v) { std::forward(v).Visit([this](auto&& elem) { using T = RemoveCVRef; new (storage_) T(std::forward(elem)); type_index_ = IndexOfType; }); } template void Copy(V&& v) { std::forward(v).Visit([this](auto&& elem) { using T = RemoveCVRef; if (Is()) { Value() = std::forward(elem); } else { Destory(); Construct(std::forward(elem)); } }); } void Destory() { Visit([this](auto& elem) { using T = RemoveCVRef; Value().~T(); }); } }; template using OptionalVariant = Variant; } // namespace maybe } // namespace oneflow namespace std { template struct hash> { size_t operator()(const oneflow::maybe::Variant& v) const noexcept { size_t seed = hash()(v.Index()); v.Visit([&seed](const auto& x) { using type = oneflow::maybe::RemoveCVRef; oneflow::maybe::HashCombine(seed, x); }); return seed; } }; } // namespace std #endif // ONEFLOW_MAYBE_VARIANT_H_ ================================================ FILE: oneflow/maybe/variant_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "oneflow/maybe/variant.h" using namespace oneflow::maybe; using namespace std::string_literals; TEST(Variant, Basics) { Variant a, b(1), c(1.2f), d(InPlaceType, 'a'), e(InPlaceType, 6.66); ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get(), 0); ASSERT_TRUE(b.Is()); ASSERT_EQ(b.Get(), 1); ASSERT_TRUE(c.Is()); ASSERT_EQ(c.Get(), 1.2f); ASSERT_TRUE(d.Is()); ASSERT_EQ(d.Get(), 'a'); ASSERT_TRUE(e.Is()); ASSERT_FLOAT_EQ(e.Get(), 6.66); Variant f(b), g(c), h(InPlaceIndex<1>, 2.33), i(InPlaceIndex<0>, 2.33); ASSERT_TRUE(f.Is()); ASSERT_EQ(f.Get(), 1); ASSERT_TRUE(g.Is()); ASSERT_EQ(g.Get(), 1.2f); ASSERT_TRUE(h.Is()); ASSERT_FLOAT_EQ(h.Get(), 2.33); ASSERT_TRUE(i.Is()); ASSERT_EQ(i.Get(), 2); a = 1; ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get(), 1); a = 1.3f; ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get(), 1.3f); a = b; ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get(), 1); a = c; ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get(), 1.2f); ASSERT_EQ((b.Visit>([](auto&& x) { return x + 1; })), (Variant(2))); ASSERT_EQ((c.Visit>([](auto&& x) { return x + 1; })), (Variant(2.2f))); ASSERT_EQ(a.Emplace<1>(1.3f), 1.3f); ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get<1>(), 1.3f); ASSERT_EQ(a.Emplace<0>(233), 233); ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get<0>(), 233); } TEST(Variant, NonPOD) { Variant> a; ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get(), false); a = true; ASSERT_TRUE(a.Is()); ASSERT_EQ(a.Get(), true); a = std::make_shared(233); ASSERT_EQ(a.Index(), 1); ASSERT_EQ(*a.Get<1>(), 233); ASSERT_EQ(a.Get<1>().use_count(), 1); { Variant> b = a; ASSERT_EQ(b.Index(), 1); ASSERT_EQ(*b.Get<1>(), 233); ASSERT_EQ(a.Get<1>().use_count(), 2); *b.Get<1>() = 234; } ASSERT_EQ(a.Get<1>().use_count(), 1); ASSERT_EQ(*a.Get<1>(), 234); Variant> b = std::move(a); ASSERT_EQ(b.Get<1>().use_count(), 1); ASSERT_EQ(*b.Get<1>(), 234); Variant> c = b; ASSERT_EQ(c.Get<1>().use_count(), 2); ASSERT_EQ(b, c); b = true; ASSERT_EQ(c.Get<1>().use_count(), 1); ASSERT_NE(b, c); } TEST(Variant, Optional) { OptionalVariant a, b(NullOpt), c(a); const char* hello = "hello"; std::size_t hash = 0, hash2 = 1, hash3 = 2; HashCombine(hash, NullOpt); HashCombine(hash2, 1); HashCombine(hash3, hello); ASSERT_TRUE(a == NullOpt); ASSERT_EQ(std::hash()(a), hash); a = 1; ASSERT_EQ(a, 1); ASSERT_EQ(std::hash()(a), hash2); a = NullOpt; ASSERT_EQ(a, NullOpt); ASSERT_EQ(std::hash()(a), hash); a = hello; ASSERT_EQ(a, hello); ASSERT_EQ(std::hash()(a), hash3); ASSERT_EQ(b, NullOpt); ASSERT_EQ(c, NullOpt); ASSERT_NE(a, b); } TEST(Variant, BinarySearchVisit) { const char* hello = "hello"; OptionalVariant x, y(123), z(1.2f), w(true); OptionalVariant a, b(123), c(1.2f), d(true), e(hello); ASSERT_EQ(x, NullOpt); ASSERT_EQ(y, 123); ASSERT_EQ(z, 1.2f); ASSERT_EQ(w, true); ASSERT_EQ(a, NullOpt); ASSERT_EQ(b, 123); ASSERT_EQ(c, 1.2f); ASSERT_EQ(d, true); ASSERT_EQ(e, hello); OptionalVariant a1(a), b1(b), c1(c), d1(d), e1(e); ASSERT_EQ(a1, NullOpt); ASSERT_EQ(b1, 123); ASSERT_EQ(c1, 1.2f); ASSERT_EQ(d1, true); ASSERT_EQ(e1, hello); a = 233; ASSERT_EQ(a, 233); a = hello; ASSERT_EQ(a, hello); a = c; ASSERT_EQ(a, 1.2f); ASSERT_EQ(1.2f, a); ASSERT_EQ(a, c); ASSERT_NE(a, b); } TEST(Variant, Compare) { OptionalVariant a, b, c(0), d(5), dd(5), e(-1.2f), f(2.3f), g(false), h(true); ASSERT_EQ(a, b); ASSERT_EQ(d, dd); ASSERT_NE(a, c); ASSERT_NE(c, d); ASSERT_NE(d, e); ASSERT_NE(e, f); ASSERT_NE(f, g); ASSERT_NE(g, h); ASSERT_LT(a, c); ASSERT_LT(c, d); ASSERT_LT(d, e); ASSERT_LT(e, f); ASSERT_LT(f, g); ASSERT_LT(g, h); ASSERT_GT(c, a); ASSERT_GT(d, c); ASSERT_GT(e, d); ASSERT_GT(f, e); ASSERT_GT(g, f); ASSERT_GT(h, 
g); ASSERT_LE(a, b); ASSERT_LE(b, c); ASSERT_LE(c, d); ASSERT_LE(d, dd); std::set> s{100, 2.3f, true, 3.3f, NullOpt, 0, false, 22, true, NullOpt}; ASSERT_EQ(s.size(), 8); auto iter = s.begin(); ASSERT_EQ(*(iter++), NullOpt); ASSERT_EQ(*(iter++), 0); ASSERT_EQ(*(iter++), 22); ASSERT_EQ(*(iter++), 100); ASSERT_EQ(*(iter++), 2.3f); ASSERT_EQ(*(iter++), 3.3f); ASSERT_EQ(*(iter++), false); ASSERT_EQ(*(iter++), true); } TEST(Variant, UniquePtr) { Variant> a("hello"s), b(std::make_unique(1)); ASSERT_EQ(a, "hello"s); ASSERT_EQ(*b.Get<1>(), 1); Variant> c(std::move(a)), d(std::move(b)); ASSERT_EQ(c, "hello"s); ASSERT_EQ(*d.Get<1>(), 1); } ================================================ FILE: oneflow/user/data/batch_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_BATCH_DATASET_H_ #define ONEFLOW_USER_DATA_BATCH_DATASET_H_ #include "oneflow/user/data/dataset.h" namespace oneflow { namespace data { template class BatchDataset final : public Dataset { public: using Base = Dataset; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; BatchDataset(int32_t batch_size, std::unique_ptr>&& dataset) : batch_size_(batch_size), nested_ds_(std::move(dataset)) {} ~BatchDataset() = default; BatchType Next() override { BatchType batch; batch.reserve(batch_size_); for (size_t i = 0; i < batch_size_; ++i) { BatchType tmp = nested_ds_->Next(); CHECK_EQ(tmp.size(), 1); batch.push_back(std::move(tmp[0])); } return batch; } private: int32_t batch_size_; std::unique_ptr> nested_ds_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_BATCH_DATASET_H_ ================================================ FILE: oneflow/user/data/batch_random_shuffle_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_DATA_BATCH_RANDOM_SHUFFLE_DATASET_H_ #define ONEFLOW_USER_DATA_BATCH_RANDOM_SHUFFLE_DATASET_H_ #include "oneflow/user/data/dataset.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/framework/op_kernel.h" namespace oneflow { namespace data { template class BatchRandomShuffleDataset final : public Dataset { public: using Base = Dataset; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; BatchRandomShuffleDataset(user_op::KernelInitContext* ctx, std::unique_ptr>&& data_set) : loader_(std::move(data_set)) { // random seed_ = ctx->Attr("seed"); if (seed_ == -1) { seed_ = NewRandomSeed(); } std::seed_seq seq({seed_}); rand_engine_ = std::default_random_engine(seq); // fill buffer initial_buffer_fill_ = ctx->Attr("shuffle_buffer_size"); for (int32_t i = 0; i < initial_buffer_fill_; ++i) { BatchType batch = loader_->Next(); batch_buffer_.push_back(std::move(batch)); } } ~BatchRandomShuffleDataset() = default; BatchType Next() override { BatchType batch = loader_->Next(); std::uniform_int_distribution<> dis(0, batch_buffer_.size() - 1); const int offset = dis(rand_engine_); std::swap(batch_buffer_.at(offset), batch); return batch; } private: std::unique_ptr> loader_; std::vector batch_buffer_; int32_t initial_buffer_fill_; std::default_random_engine rand_engine_; int64_t seed_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_BATCH_RANDOM_SHUFFLE_DATASET_H_ ================================================ FILE: oneflow/user/data/coco_data_reader.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/data/coco_data_reader.h" #include "oneflow/user/data/coco_dataset.h" #include "oneflow/user/data/distributed_training_dataset.h" #include "oneflow/user/data/group_batch_dataset.h" #include "oneflow/user/data/batch_dataset.h" #include "oneflow/user/data/distributed_util.h" #include "oneflow/core/persistence/file_system.h" #include "oneflow/core/persistence/persistent_in_stream.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace data { COCODataReader::COCODataReader(user_op::KernelInitContext* ctx) : DataReader(ctx) { batch_size_ = ctx->TensorDesc4ArgNameAndIndex("image", 0)->shape().elem_cnt(); if (auto* pool = TensorBufferPool::TryGet()) { pool->IncreasePoolSizeByBase(batch_size_); } std::shared_ptr meta(new COCOMeta( ctx->Attr("session_id"), ctx->Attr("annotation_file"), ctx->Attr("image_dir"), ctx->Attr("remove_images_without_annotations"))); std::unique_ptr> coco_dataset_ptr(new COCODataset(ctx, meta)); size_t world_size = 1; int64_t rank = 0; CHECK_JUST(InitDataSourceDistributedInfo(ctx, world_size, rank)); loader_.reset(new DistributedTrainingDataset( world_size, rank, ctx->Attr("stride_partition"), ctx->Attr("shuffle_after_epoch"), ctx->Attr("random_seed"), std::move(coco_dataset_ptr))); if (ctx->Attr("group_by_ratio")) { auto GetGroupId = [](const COCOImage& sample) { return static_cast(sample.height / sample.width); }; loader_.reset(new GroupBatchDataset(batch_size_, GetGroupId, std::move(loader_))); } else { loader_.reset(new BatchDataset(batch_size_, std::move(loader_))); } parser_.reset(new COCOParser(meta)); StartLoadThread(); } COCODataReader::~COCODataReader() { if (auto* pool = TensorBufferPool::TryGet()) { pool->DecreasePoolSizeByBase(batch_size_); } } COCOMeta::COCOMeta(int64_t session_id, const std::string& annotation_file, const std::string& image_dir, bool remove_images_without_annotations) : image_dir_(image_dir) { // Read content of annotation file (json format) to json obj PersistentInStream in_stream(session_id, DataFS(), annotation_file); std::string json_str; std::string line; while (in_stream.ReadLine(&line) == 0) { json_str += line; } std::istringstream in_str_stream(json_str); in_str_stream >> annotation_json_; // initialize image_ids_, image_id2image_ and image_id2anno_ids_ for (const auto& image : annotation_json_["images"]) { int64_t id = image["id"].get(); image_ids_.emplace_back(id); CHECK(image_id2image_.emplace(id, image).second); CHECK(image_id2anno_ids_.emplace(id, std::vector()).second); } // build anno map for (const auto& anno : annotation_json_["annotations"]) { int64_t id = anno["id"].get(); int64_t image_id = anno["image_id"].get(); // ignore crowd object for now if (anno["iscrowd"].get() == 1) { continue; } // check if invalid segmentation if (anno["segmentation"].is_array()) { for (const auto& poly : anno["segmentation"]) { // at least 3 points can compose a polygon // every point needs 2 element (x, y) to present CHECK_GT(poly.size(), 6); } } CHECK(anno_id2anno_.emplace(id, anno).second); image_id2anno_ids_.at(image_id).emplace_back(id); } // remove images without annotations if necessary if (remove_images_without_annotations) { HashSet to_remove_image_ids; for (int64_t image_id : image_ids_) { if (!ImageHasValidAnnotations(image_id)) { to_remove_image_ids.insert(image_id); } } image_ids_.erase(std::remove_if(image_ids_.begin(), image_ids_.end(), [&to_remove_image_ids](int64_t image_id) { return to_remove_image_ids.find(image_id) != to_remove_image_ids.end(); }), 
image_ids_.end()); } // sort image ids for reproducible results std::sort(image_ids_.begin(), image_ids_.end()); // build categories map std::vector category_ids; for (const auto& cat : annotation_json_["categories"]) { category_ids.emplace_back(cat["id"].get()); } std::sort(category_ids.begin(), category_ids.end()); int32_t contiguous_id = 1; for (int32_t category_id : category_ids) { CHECK(category_id2contiguous_id_.emplace(category_id, contiguous_id++).second); } } bool COCOMeta::ImageHasValidAnnotations(int64_t image_id) const { const std::vector& anno_id_vec = image_id2anno_ids_.at(image_id); if (anno_id_vec.empty()) { return false; } bool bbox_area_all_close_to_zero = true; size_t visible_keypoints_count = 0; for (int64_t anno_id : anno_id_vec) { const auto& anno = anno_id2anno_.at(anno_id); if (anno["bbox"][2] > 1 && anno["bbox"][3] > 1) { bbox_area_all_close_to_zero = false; } if (anno.contains("keypoints")) { const auto& keypoints = anno["keypoints"]; CHECK_EQ(keypoints.size() % 3, 0); FOR_RANGE(size_t, i, 0, keypoints.size() / 3) { int32_t keypoints_label = keypoints[i * 3 + 2].get(); if (keypoints_label > 0) { visible_keypoints_count += 1; } } } } // check if all boxes are close to zero area if (bbox_area_all_close_to_zero) { return false; } // keypoints task have a slight different critera for considering // if an annotation is valid if (!anno_id2anno_.at(anno_id_vec.at(0)).contains("keypoints")) { return true; } // for keypoint detection tasks, only consider valid images those // containing at least min_keypoints_per_image if (visible_keypoints_count >= kMinKeypointsPerImage) { return true; } return false; } } // namespace data } // namespace oneflow ================================================ FILE: oneflow/user/data/coco_data_reader.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_DATA_COCO_DATA_READER_H_ #define ONEFLOW_USER_DATA_COCO_DATA_READER_H_ #include "oneflow/user/data/data_reader.h" #include "oneflow/user/data/coco_parser.h" #include "oneflow/core/common/str_util.h" #include "nlohmann/json.hpp" namespace oneflow { namespace data { class COCODataReader final : public DataReader { public: COCODataReader(user_op::KernelInitContext* ctx); ~COCODataReader() override; protected: using DataReader::loader_; using DataReader::parser_; private: size_t batch_size_; }; class COCOMeta final { public: COCOMeta(int64_t session_id, const std::string& annotation_file, const std::string& image_dir, bool remove_images_without_annotations); ~COCOMeta() = default; int64_t Size() const { return image_ids_.size(); } int64_t GetImageId(int64_t index) const { return image_ids_.at(index); } int32_t GetImageHeight(int64_t index) const { int64_t image_id = image_ids_.at(index); return image_id2image_.at(image_id)["height"].get(); } int32_t GetImageWidth(int64_t index) const { int64_t image_id = image_ids_.at(index); return image_id2image_.at(image_id)["width"].get(); } std::string GetImageFilePath(int64_t index) const { int64_t image_id = image_ids_.at(index); const auto& image_json = image_id2image_.at(image_id); return JoinPath(image_dir_, image_json["file_name"].get()); } template std::vector GetBboxVec(int64_t index) const; template std::vector GetLabelVec(int64_t index) const; template void ReadSegmentationsToTensorBuffer(int64_t index, TensorBuffer* segm, TensorBuffer* segm_offset_mat) const; private: bool ImageHasValidAnnotations(int64_t image_id) const; static constexpr int kMinKeypointsPerImage = 10; nlohmann::json annotation_json_; std::string image_dir_; std::vector image_ids_; HashMap image_id2image_; HashMap anno_id2anno_; HashMap> image_id2anno_ids_; HashMap category_id2contiguous_id_; }; template std::vector COCOMeta::GetBboxVec(int64_t index) const { std::vector bbox_vec; int64_t image_id = image_ids_.at(index); const auto& anno_ids = image_id2anno_ids_.at(image_id); for (int64_t anno_id : anno_ids) { const auto& bbox_json = anno_id2anno_.at(anno_id)["bbox"]; CHECK(bbox_json.is_array()); CHECK_EQ(bbox_json.size(), 4); // COCO bounding box format is [left, top, width, height] // we need format xyxy const T alginment = static_cast(1); const T min_size = static_cast(0); T left = bbox_json[0].get(); T top = bbox_json[1].get(); T width = bbox_json[2].get(); T height = bbox_json[3].get(); T right = left + std::max(width - alginment, min_size); T bottom = top + std::max(height - alginment, min_size); // clip to image int32_t image_height = GetImageHeight(index); int32_t image_width = GetImageWidth(index); left = std::min(std::max(left, min_size), image_width - alginment); top = std::min(std::max(top, min_size), image_height - alginment); right = std::min(std::max(right, min_size), image_width - alginment); bottom = std::min(std::max(bottom, min_size), image_height - alginment); // ensure bbox is not empty if (right > left && bottom > top) { bbox_vec.insert(bbox_vec.end(), {left, top, right, bottom}); } } return bbox_vec; } template std::vector COCOMeta::GetLabelVec(int64_t index) const { std::vector label_vec; int64_t image_id = image_ids_.at(index); const auto& anno_ids = image_id2anno_ids_.at(image_id); for (int64_t anno_id : anno_ids) { int32_t category_id = anno_id2anno_.at(anno_id)["category_id"].get(); label_vec.emplace_back(category_id2contiguous_id_.at(category_id)); } return label_vec; } template void 
COCOMeta::ReadSegmentationsToTensorBuffer(int64_t index, TensorBuffer* segm, TensorBuffer* segm_index) const { if (segm == nullptr || segm_index == nullptr) { return; } int64_t image_id = image_ids_.at(index); const auto& anno_ids = image_id2anno_ids_.at(image_id); std::vector segm_vec; for (int64_t anno_id : anno_ids) { const auto& segm_json = anno_id2anno_.at(anno_id)["segmentation"]; if (!segm_json.is_array()) { continue; } for (const auto& poly_json : segm_json) { CHECK(poly_json.is_array()); for (const auto& elem : poly_json) { segm_vec.emplace_back(elem.get()); } } } CHECK_EQ(segm_vec.size() % 2, 0); int64_t num_pts = segm_vec.size() / 2; segm->Resize(Shape({num_pts, 2}), GetDataType::value); std::copy(segm_vec.begin(), segm_vec.end(), segm->mut_data()); segm_index->Resize(Shape({num_pts, 3}), DataType::kInt32); int32_t* index_ptr = segm_index->mut_data(); int i = 0; int32_t segm_idx = 0; for (int64_t anno_id : anno_ids) { const auto& segm_json = anno_id2anno_.at(anno_id)["segmentation"]; CHECK(segm_json.is_array()); FOR_RANGE(int32_t, poly_idx, 0, segm_json.size()) { const auto& poly_json = segm_json[poly_idx]; CHECK(poly_json.is_array()); CHECK_EQ(poly_json.size() % 2, 0); FOR_RANGE(int32_t, pt_idx, 0, poly_json.size() / 2) { index_ptr[i * 3 + 0] = pt_idx; index_ptr[i * 3 + 1] = poly_idx; index_ptr[i * 3 + 2] = segm_idx; i += 1; } } segm_idx += 1; } CHECK_EQ(i, num_pts); } } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_COCO_DATA_READER_H_ ================================================ FILE: oneflow/user/data/coco_dataset.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/data/coco_dataset.h" #include "oneflow/user/data/coco_data_reader.h" #include "oneflow/core/persistence/file_system.h" #include "oneflow/core/persistence/persistent_in_stream.h" namespace oneflow { namespace data { COCODataset::BatchType COCODataset::At(int64_t index) const { BatchType batch; batch.push_back(COCOImage()); auto& sample = batch.back(); sample.index = index; sample.id = meta_->GetImageId(index); sample.height = meta_->GetImageHeight(index); sample.width = meta_->GetImageWidth(index); const std::string& image_file_path = meta_->GetImageFilePath(index); PersistentInStream in_stream(session_id_, DataFS(), image_file_path); int64_t file_size = DataFS()->GetFileSize(image_file_path); sample.data.Resize(Shape({file_size}), DataType::kChar); CHECK_EQ(in_stream.ReadFully(sample.data.mut_data(), sample.data.nbytes()), 0); return batch; } size_t COCODataset::Size() const { return meta_->Size(); } } // namespace data } // namespace oneflow ================================================ FILE: oneflow/user/data/coco_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_COCO_DATASET_H_ #define ONEFLOW_USER_DATA_COCO_DATASET_H_ #include "oneflow/user/data/dataset.h" #include "oneflow/core/framework/op_kernel.h" namespace oneflow { namespace data { struct COCOImage { TensorBuffer data; int64_t index; int64_t id; int32_t height; int32_t width; }; class COCOMeta; class COCODataset final : public RandomAccessDataset { public: using Base = RandomAccessDataset; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; COCODataset(user_op::KernelInitContext* ctx, const std::shared_ptr& meta) : meta_(meta), session_id_(ctx->Attr("session_id")) {} ~COCODataset() = default; BatchType At(int64_t index) const override; size_t Size() const override; private: std::shared_ptr meta_; int64_t session_id_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_COCO_DATASET_H_ ================================================ FILE: oneflow/user/data/coco_parser.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/data/coco_parser.h" #include "oneflow/user/data/coco_data_reader.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { namespace data { void COCOParser::Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) { user_op::Tensor* image_tensor = ctx->Tensor4ArgNameAndIndex("image", 0); CHECK_NOTNULL(image_tensor); user_op::Tensor* image_id_tensor = ctx->Tensor4ArgNameAndIndex("image_id", 0); user_op::Tensor* image_size_tensor = ctx->Tensor4ArgNameAndIndex("image_size", 0); user_op::Tensor* bbox_tensor = ctx->Tensor4ArgNameAndIndex("gt_bbox", 0); user_op::Tensor* label_tensor = ctx->Tensor4ArgNameAndIndex("gt_label", 0); user_op::Tensor* segm_tensor = ctx->Tensor4ArgNameAndIndex("gt_segm", 0); user_op::Tensor* segm_index_tensor = ctx->Tensor4ArgNameAndIndex("gt_segm_index", 0); MultiThreadLoop(batch_data.size(), [&](size_t i) { TensorBuffer* image_buffer = image_tensor->mut_dptr() + i; COCOImage& image = batch_data[i]; image_buffer->Swap(image.data); if (image_size_tensor) { auto* image_size_ptr = image_size_tensor->mut_dptr() + i * 2; image_size_ptr[0] = meta_->GetImageHeight(image.index); image_size_ptr[1] = meta_->GetImageWidth(image.index); } if (image_id_tensor) { auto* image_id_ptr = image_id_tensor->mut_dptr(); image_id_ptr[i] = image.id; } if (bbox_tensor) { TensorBuffer* bbox_buffer = bbox_tensor->mut_dptr() + i; const auto& bbox_vec = meta_->GetBboxVec(image.index); CHECK_EQ(bbox_vec.size() % 4, 0); int64_t num_bboxes = bbox_vec.size() / 4; bbox_buffer->Resize(Shape({num_bboxes, 4}), DataType::kFloat); std::copy(bbox_vec.begin(), bbox_vec.end(), bbox_buffer->mut_data()); } if (label_tensor) { TensorBuffer* label_buffer = label_tensor->mut_dptr() + i; const auto& label_vec = meta_->GetLabelVec(image.index); label_buffer->Resize(Shape({static_cast(label_vec.size())}), DataType::kInt32); std::copy(label_vec.begin(), label_vec.end(), label_buffer->mut_data()); } if (segm_tensor && segm_index_tensor) { TensorBuffer* segm_buffer = segm_tensor->mut_dptr() + i; TensorBuffer* segm_index_buffer = segm_index_tensor->mut_dptr() + i; meta_->ReadSegmentationsToTensorBuffer(image.index, segm_buffer, segm_index_buffer); } }); // dynamic batch size if (image_tensor->shape_view().elem_cnt() != batch_data.size()) { CHECK_EQ(image_tensor->shape_view().NumAxes(), 1); image_tensor->mut_shape_view().Set(0, batch_data.size()); } if (image_id_tensor && image_id_tensor->shape_view().At(0) != batch_data.size()) { image_id_tensor->mut_shape_view().Set(0, batch_data.size()); } if (image_size_tensor && image_size_tensor->shape_view().At(0) != batch_data.size()) { image_size_tensor->mut_shape_view().Set(0, batch_data.size()); } if (bbox_tensor && bbox_tensor->shape_view().elem_cnt() != batch_data.size()) { CHECK_EQ(bbox_tensor->shape_view().NumAxes(), 1); bbox_tensor->mut_shape_view().Set(0, batch_data.size()); } if (label_tensor && label_tensor->shape_view().elem_cnt() != batch_data.size()) { CHECK_EQ(label_tensor->shape_view().NumAxes(), 1); label_tensor->mut_shape_view().Set(0, batch_data.size()); } if (segm_tensor && segm_index_tensor && segm_tensor->shape_view().elem_cnt() != batch_data.size()) { CHECK_EQ(segm_tensor->shape_view().NumAxes(), 1); CHECK_EQ(segm_index_tensor->shape_view().NumAxes(), 1); CHECK_EQ(segm_tensor->shape_view().elem_cnt(), segm_index_tensor->shape_view().elem_cnt()); segm_tensor->mut_shape_view().Set(0, batch_data.size()); segm_index_tensor->mut_shape_view().Set(0, batch_data.size()); } } } // namespace data } // namespace oneflow 
================================================ FILE: oneflow/user/data/coco_parser.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_COCO_PARSER_H_ #define ONEFLOW_USER_DATA_COCO_PARSER_H_ #include "oneflow/user/data/parser.h" #include "oneflow/user/data/coco_dataset.h" namespace oneflow { namespace data { class COCOMeta; class COCOParser final : public Parser { public: using Base = Parser; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; COCOParser(const std::shared_ptr& meta) : meta_(meta){}; ~COCOParser() = default; void Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) override; private: std::shared_ptr meta_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_COCO_PARSER_H_ ================================================ FILE: oneflow/user/data/data_reader.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
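For reference, COCOMeta::ReadSegmentationsToTensorBuffer in coco_data_reader.h above flattens every polygon of an image into two parallel tensors: a (num_pts, 2) array of x/y coordinates and a (num_pts, 3) index array whose rows are (point index within the polygon, polygon index within the annotation, annotation index within the image). A minimal sketch of the same layout built from plain vectors (the nested-vector input type is an assumption made for the example):

#include <cstdint>
#include <vector>

// polygons[segm][poly] holds the flat x0, y0, x1, y1, ... list of one polygon.
inline void FlattenSegmentations(const std::vector<std::vector<std::vector<double>>>& polygons,
                                 std::vector<double>* points,    // num_pts x 2, row-major
                                 std::vector<int32_t>* index) {  // num_pts x 3, row-major
  for (size_t segm_idx = 0; segm_idx < polygons.size(); ++segm_idx) {
    const auto& segm = polygons[segm_idx];
    for (size_t poly_idx = 0; poly_idx < segm.size(); ++poly_idx) {
      const auto& poly = segm[poly_idx];
      for (size_t pt_idx = 0; pt_idx < poly.size() / 2; ++pt_idx) {
        points->push_back(poly[pt_idx * 2 + 0]);
        points->push_back(poly[pt_idx * 2 + 1]);
        index->push_back(static_cast<int32_t>(pt_idx));    // which point of the polygon
        index->push_back(static_cast<int32_t>(poly_idx));  // which polygon of the annotation
        index->push_back(static_cast<int32_t>(segm_idx));  // which annotation of the image
      }
    }
  }
}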
*/ #ifndef ONEFLOW_USER_DATA_DATA_READER_H_ #define ONEFLOW_USER_DATA_DATA_READER_H_ #include "oneflow/user/data/dataset.h" #include "oneflow/user/data/parser.h" #include "oneflow/core/common/buffer.h" namespace oneflow { namespace data { static const int32_t kDataReaderBatchBufferSize = 4; template class DataReader { public: using SampleType = LoadTarget; using BatchType = std::vector; DataReader(user_op::KernelInitContext* ctx) : is_closed_(false), batch_buffer_(kDataReaderBatchBufferSize) {} virtual ~DataReader() { Close(); if (load_thrd_.joinable()) { load_thrd_.join(); } } void Read(user_op::KernelComputeContext* ctx) { CHECK(load_thrd_.joinable()) << "You should call StartLoadThread before read data"; auto batch = FetchBatchData(); parser_->Parse(batch, ctx); } void Close() { if (!is_closed_.load()) { is_closed_.store(true); batch_buffer_.Close(); } } protected: void StartLoadThread() { if (load_thrd_.joinable()) { return; } load_thrd_ = std::thread([this] { while (!is_closed_.load() && LoadBatch()) {} }); } std::unique_ptr> loader_; std::unique_ptr> parser_; private: BatchType FetchBatchData() { BatchType batch; CHECK_EQ(batch_buffer_.Pull(&batch), BufferStatus::kBufferStatusSuccess); return batch; } bool LoadBatch() { BatchType batch = loader_->Next(); return batch_buffer_.Push(std::move(batch)) == BufferStatus::kBufferStatusSuccess; } std::atomic is_closed_; Buffer batch_buffer_; std::thread load_thrd_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_DATA_READER_H_ ================================================ FILE: oneflow/user/data/dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_DATASET_H_ #define ONEFLOW_USER_DATA_DATASET_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/tensor_buffer.h" namespace oneflow { namespace data { static constexpr int kOneflowDatasetSeed = 524287; template class Dataset { public: using SampleType = LoadTarget; using BatchType = std::vector; Dataset() = default; virtual ~Dataset() = default; virtual BatchType Next() = 0; }; template class RandomAccessDataset : public Dataset { public: using Base = Dataset; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; RandomAccessDataset() : cur_idx_(0) {} virtual ~RandomAccessDataset() = default; virtual BatchType At(int64_t index) const = 0; virtual size_t Size() const = 0; BatchType Next() final { BatchType ret = this->At(cur_idx_); cur_idx_ += 1; if (cur_idx_ >= this->Size()) { cur_idx_ %= this->Size(); } return ret; } private: int64_t cur_idx_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_DATASET_H_ ================================================ FILE: oneflow/user/data/distributed_training_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_DISTRIBUTED_TRAINING_DATASET_H_ #define ONEFLOW_USER_DATA_DISTRIBUTED_TRAINING_DATASET_H_ #include "oneflow/user/data/dataset.h" namespace oneflow { namespace data { template class DistributedTrainingDataset final : public Dataset { public: using Base = Dataset; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; using NestedDS = RandomAccessDataset; DistributedTrainingDataset(int64_t parallel_num, int64_t parallel_id, bool stride_partition, bool shuffle, int64_t random_seed, std::unique_ptr&& dataset) : nested_ds_(std::move(dataset)), shuffle_(shuffle), stride_partition_(stride_partition), rnd_seed_(random_seed), num_shards_(parallel_num), pos_(0), pos_in_shard_(0), epoch_cnt_(0) { shard_size_ = std::ceil(static_cast(nested_ds_->Size()) / num_shards_); if (stride_partition) { pos_ = parallel_id; } else { pos_ = parallel_id * shard_size_; } index_seq_.resize(nested_ds_->Size()); std::iota(index_seq_.begin(), index_seq_.end(), 0); GenNewIndexSequence(); } virtual ~DistributedTrainingDataset() = default; virtual BatchType Next() override { // There are 2 partition strategies // assume epoch size is 10, index seq don't shuffle and there are 4 parts // stride partition strategy (when stride_partition is true): // | part1 | part2 | part3 | part4 | // iter0 | 0, 4, 8, | 1, 5, 9, | 2, 6, 0, | 3, 7, 1, | // iter1 | 2, 6, 0, | 3, 7, 1, | 4, 8, 2, | 5, 9, 3, | // contiguous partition strategy (when stride_partition is false): // | part1 | part2 | part3 | part4 | // iter0 | 0, 1, 2, | 3, 4, 5, | 6, 7, 8, | 9, 0, 1, | // iter1 | 2, 3, 4, | 5, 6, 7, | 8, 9, 0, | 1, 2, 3, | BatchType batch = nested_ds_->At(index_seq_.at(pos_)); if (stride_partition_) { pos_ += num_shards_; } else { pos_ += 1; pos_in_shard_ += 1; if (pos_in_shard_ == shard_size_) { pos_ += (num_shards_ - 1) * shard_size_; pos_in_shard_ = 0; } } CheckRanOutOfSize(); return batch; } private: void CheckRanOutOfSize() { if (pos_ >= index_seq_.size()) { GenNewIndexSequence(); pos_ %= index_seq_.size(); } } void GenNewIndexSequence() { if (shuffle_) { std::mt19937 engine(rnd_seed_ + epoch_cnt_); std::shuffle(index_seq_.begin(), index_seq_.end(), engine); } epoch_cnt_ += 1; } std::unique_ptr nested_ds_; bool shuffle_; bool stride_partition_; int64_t rnd_seed_; int64_t num_shards_; int64_t shard_size_; int64_t pos_; int64_t pos_in_shard_; int64_t epoch_cnt_; std::vector index_seq_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_DISTRIBUTED_TRAINING_DATASET_H_ ================================================ FILE: oneflow/user/data/distributed_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_DISTRIBUTED_UTIL_H_ #define ONEFLOW_USER_DATA_DISTRIBUTED_UTIL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace data { inline Maybe InitDataSourceDistributedInfo(user_op::KernelInitContext* ctx, size_t& world_size, int64_t& rank) { auto nd_sbp_str_vec = ctx->Attr>("nd_sbp"); if (nd_sbp_str_vec.empty()) { world_size = GlobalProcessCtx::WorldSize(); rank = GlobalProcessCtx::Rank(); } else { const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); CHECK_EQ_OR_RETURN(hierarchy.NumAxes(), nd_sbp_str_vec.size()); rank = 0; world_size = 1; using index_helper_t = NdIndexOffsetHelper; index_helper_t index_helper(hierarchy.dim_vec().data(), hierarchy.NumAxes()); int64_t nd_index[SHAPE_MAX_AXIS_SIZE] = {0}; index_helper.OffsetToNdIndex(ctx->parallel_ctx().parallel_id(), nd_index); for (int i = hierarchy.NumAxes() - 1; i >= 0; --i) { SbpParallel sbp; CHECK_OR_RETURN(ParseSbpParallelFromString(nd_sbp_str_vec[i], &sbp)); if (sbp.has_split_parallel()) { rank += nd_index[i] * world_size; world_size *= hierarchy.At(i); } } } return Maybe::Ok(); } } // namespace data } // namespace oneflow #endif // ONEFLOW_CUSTOMIZED_DATA_ONEREC_DATA_READER_H_ ================================================ FILE: oneflow/user/data/gpt_dataset.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
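The two partition strategies documented in DistributedTrainingDataset::Next above reduce to simple position arithmetic. A minimal sketch, assuming an un-shuffled index sequence, that reproduces the stride and contiguous tables from the comment (the helper name and the example of epoch size 10 with 4 shards are illustrative):

#include <cstdint>
#include <vector>

// Reproduce the index progression of DistributedTrainingDataset for one shard.
inline std::vector<int64_t> ShardIndices(bool stride_partition, int64_t parallel_id,
                                         int64_t epoch_size, int64_t num_shards, int64_t steps) {
  const int64_t shard_size = (epoch_size + num_shards - 1) / num_shards;  // ceil division
  int64_t pos = stride_partition ? parallel_id : parallel_id * shard_size;
  int64_t pos_in_shard = 0;
  std::vector<int64_t> out;
  for (int64_t step = 0; step < steps; ++step) {
    out.push_back(pos);  // index_seq_ is 0 .. epoch_size - 1 when shuffle is off
    if (stride_partition) {
      pos += num_shards;
    } else {
      pos += 1;
      pos_in_shard += 1;
      if (pos_in_shard == shard_size) {
        pos += (num_shards - 1) * shard_size;  // jump over the other shards
        pos_in_shard = 0;
      }
    }
    if (pos >= epoch_size) { pos %= epoch_size; }  // wrap into the next epoch
  }
  return out;
}

// ShardIndices(true, 2, 10, 4, 6)  -> 2, 6, 0, 4, 8, 2   (stride strategy, part3 of the table)
// ShardIndices(false, 2, 10, 4, 6) -> 6, 7, 8, 8, 9, 0   (contiguous strategy, part3 of the table)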
*/ #include "oneflow/user/data/gpt_dataset.h" #ifdef __linux__ #include #include #include #include #include #endif namespace oneflow { namespace data { namespace { void GetSplitDocIndices(std::vector* doc_indices, const std::vector& split_sizes, size_t split_index, size_t num_docs) { CHECK_LT(split_index, split_sizes.size()); size_t total_size = 0; FOR_RANGE(size_t, i, 0, split_sizes.size()) { total_size += split_sizes[i]; } size_t split_offset = 0; RoundModeGuard round_guard(FE_TONEAREST); FOR_RANGE(size_t, i, 0, split_index) { float ratio = static_cast(split_sizes[i]) / total_size; size_t split_size = static_cast(std::nearbyint(ratio * num_docs)); split_offset += split_size; } float ratio = static_cast(split_sizes[split_index]) / total_size; size_t split_size = static_cast(std::nearbyint(ratio * num_docs)); doc_indices->resize(split_size); std::iota(doc_indices->begin(), doc_indices->end(), split_offset); } size_t GetNumEpochs(size_t num_samples, size_t seq_length, size_t tokens_per_epoch) { // num_epochs * tokens_per_epoch >= num_samples * seq_length + 1 // +1 is because we need to retrieve seq_length + 1 token each time // but the last token will overlap with the first token of the next // sample except for the last sample. return static_cast( std::ceil(static_cast(num_samples * seq_length + 1) / tokens_per_epoch)); } size_t GetNumCompleteEpochs(size_t num_samples, size_t seq_length, size_t tokens_per_epoch) { size_t num_epochs = GetNumEpochs(num_samples, seq_length, tokens_per_epoch); if (num_epochs == 1) { return 1; } size_t num_samples_per_epoch = static_cast(std::floor(static_cast(tokens_per_epoch - 1) / seq_length)); size_t num_samples_exclude_last_epoch = static_cast( std::floor(static_cast((num_epochs - 1) * tokens_per_epoch - 1) / seq_length)); CHECK_LE(num_samples_exclude_last_epoch, num_samples); size_t last_epoch_num_samples = num_samples - num_samples_exclude_last_epoch; CHECK_LT(last_epoch_num_samples, num_samples_per_epoch); bool separate_last_epoch = last_epoch_num_samples < static_cast(0.8f * num_samples_per_epoch); return separate_last_epoch ? 
(num_epochs - 1) : num_epochs; } } // namespace constexpr char MegatronGPTIndex::kMagicCode[]; MegatronGPTIndex::MegatronGPTIndex(const std::string& index_file_path) { auto start = std::chrono::system_clock::now(); std::ifstream stream(index_file_path, std::ios::binary); CHECK(stream.is_open()) << "can't open dataset index file " << index_file_path; // verify magic code char magic_code[kMagicCodeLen]; stream.read(magic_code, kMagicCodeLen); CHECK_EQ(std::memcmp(magic_code, kMagicCode, kMagicCodeLen), 0); // read version stream.read(reinterpret_cast(&version_), sizeof(version_)); // read dtype stream.read(&dtype_code_, sizeof(dtype_code_)); // read size of sizes and doc_offsets uint64_t sizes_size = 0; stream.read(reinterpret_cast(&sizes_size), sizeof(sizes_size)); uint64_t doc_offsets_size = 0; stream.read(reinterpret_cast(&doc_offsets_size), sizeof(doc_offsets_size)); // NOTE: this check is not necessary CHECK_EQ(sizes_size + 1, doc_offsets_size); // read sizes sizes_.resize(sizes_size); stream.read(reinterpret_cast(sizes_.data()), sizeof(decltype(sizes_)::value_type) * sizes_.size()); // read addresses addresses_.resize(sizes_size); stream.read(reinterpret_cast(addresses_.data()), sizeof(decltype(addresses_)::value_type) * addresses_.size()); // read doc_offsets doc_offsets_.resize(doc_offsets_size); stream.read(reinterpret_cast(doc_offsets_.data()), sizeof(decltype(doc_offsets_)::value_type) * doc_offsets_.size()); // check eof int pos = stream.tellg(); stream.seekg(0, std::ios_base::end); CHECK_EQ(pos, stream.tellg()); // log std::chrono::duration elapse = std::chrono::system_clock::now() - start; VLOG(2) << "Load GPT Dataset index file successed, file_path: " << index_file_path << ", number of documents: " << this->num_docs() << ", elapsed time: " << elapse.count() << " ms"; } MappedBuffer::MappedBuffer(const std::string& filename) : mapped_(nullptr), size_(0) { #ifdef __linux__ int fd = open(filename.c_str(), O_RDONLY); CHECK(fd != -1) << "open " << filename << " failed: " << strerror(errno); struct stat s; CHECK(fstat(fd, &s) != -1) << "stat " << filename << " failed: " << strerror(errno); size_ = s.st_size; mapped_ = mmap(nullptr, size_, PROT_READ, MAP_PRIVATE, fd, 0); CHECK(mapped_ != MAP_FAILED) << "mmap " << filename << " failed: " << strerror(errno); close(fd); #endif } MappedBuffer::~MappedBuffer() { #ifdef __linux__ CHECK(munmap(mapped_, size_) == 0) << "munmap failed"; #endif } MegatronGPTMMapDataset::MegatronGPTMMapDataset(const std::string& data_file_prefix, size_t seq_len, size_t label_len, size_t num_samples, const std::vector& split_sizes, size_t split_index, bool shuffle, uint32_t seed) : seq_len_(seq_len), sample_len_(seq_len + label_len), num_samples_(num_samples), shuffle_(shuffle), seed_(seed), gen_(seed) { auto start = std::chrono::system_clock::now(); index_ = std::make_unique(data_file_prefix + ".idx"); data_ = std::make_unique(data_file_prefix + ".bin"); dtype_size_ = kDTypeCode2Size.at(index_->dtype_code()); std::vector epoch_doc_indices; GetSplitDocIndices(&epoch_doc_indices, split_sizes, split_index, index_->num_docs()); tokens_per_epoch_ = GetEpochNumTokens(epoch_doc_indices); num_epochs_ = GetNumEpochs(num_samples_, seq_len_, tokens_per_epoch_); num_complete_epochs_ = GetNumCompleteEpochs(num_samples_, seq_len_, tokens_per_epoch_); InitDocIndices(epoch_doc_indices, num_epochs_, num_complete_epochs_); size_t total_num_samples = static_cast( std::floor(static_cast(num_epochs_ * tokens_per_epoch_ - 1) / seq_len_)); InitSampleIndices(total_num_samples); 
InitShuffleIndices(sample_indices_.size()); std::chrono::duration elapse = std::chrono::system_clock::now() - start; VLOG(2) << "Create GPT Dataset successed, sequence length: " << seq_len_ << ", number of samples: " << num_samples_ << ", total number of samples: " << shuffle_indices_.size() << ", total number of documents: " << doc_indices_.size() << ", number of epochs: " << num_epochs_ << ", number of complete epochs: " << num_complete_epochs_ << ", shuffle: " << std::boolalpha << shuffle_ << ", random_seed: " << seed_ << ", elapsed time: " << elapse.count() << " ms"; } size_t MegatronGPTMMapDataset::GetEpochNumTokens(const std::vector& doc_indices) const { size_t num_tokens = 0; for (auto doc_index : doc_indices) { num_tokens += index_->doc_length(doc_index); } return num_tokens; } void MegatronGPTMMapDataset::InitDocIndices(const std::vector& epoch_doc_indices, size_t num_epochs, size_t num_complete_epochs) { doc_indices_.reserve(epoch_doc_indices.size() * num_epochs); InitDocIndices(epoch_doc_indices, num_complete_epochs); if (num_epochs != num_complete_epochs) { CHECK_EQ(num_complete_epochs + 1, num_epochs); InitDocIndices(epoch_doc_indices, 1); } } void MegatronGPTMMapDataset::InitDocIndices(const std::vector& epoch_doc_indices, size_t num_epochs) { auto start = std::distance(doc_indices_.cbegin(), doc_indices_.cend()); FOR_RANGE(size_t, i, 0, num_epochs) { doc_indices_.insert(doc_indices_.end(), epoch_doc_indices.cbegin(), epoch_doc_indices.cend()); } if (shuffle_) { std::shuffle(doc_indices_.begin() + start, doc_indices_.end(), gen_); } } void MegatronGPTMMapDataset::InitSampleIndices(size_t total_num_samples) { sample_indices_.reserve(total_num_samples); size_t doc_indices_idx = 0; size_t doc_offset = 0; FOR_RANGE(size_t, i, 0, total_num_samples) { if (doc_indices_idx >= doc_indices_.size()) { break; } sample_indices_.emplace_back(doc_indices_idx, doc_offset); int remaining_tokens = seq_len_; while (remaining_tokens > 0) { CHECK_LT(doc_indices_idx, doc_indices_.size()); size_t doc_len = index_->doc_length(doc_indices_[doc_indices_idx]); CHECK_LT(doc_offset, doc_len); doc_len -= doc_offset; if (remaining_tokens < doc_len) { // move offset inside doc doc_offset += remaining_tokens; } else { // move to next doc doc_indices_idx += 1; doc_offset = 0; } remaining_tokens -= doc_len; } } CHECK_EQ(sample_indices_.size(), total_num_samples); CHECK_GE(sample_indices_.size(), num_samples_); } void MegatronGPTMMapDataset::InitShuffleIndices(size_t total_num_samples) { shuffle_indices_.resize(total_num_samples); std::iota(shuffle_indices_.begin(), shuffle_indices_.end(), 0); if (shuffle_) { size_t num_samples = static_cast( std::floor(static_cast(num_complete_epochs_ * tokens_per_epoch_ - 1) / seq_len_)); CHECK_LE(num_samples, shuffle_indices_.size()); std::shuffle(shuffle_indices_.begin(), shuffle_indices_.begin() + num_samples, gen_); if (num_complete_epochs_ != num_epochs_) { std::shuffle(shuffle_indices_.begin() + num_samples, shuffle_indices_.end(), gen_); } } } const HashMap MegatronGPTMMapDataset::kDTypeCode2Size = { {1, 1}, // DataType::kUInt8 {2, 1}, // DataType::kInt8 {3, 2}, // DataType::kInt16 {4, 4}, // DataType::kInt32 {5, 8}, // DataType::kInt64 {6, 4}, // DataType::kFloat {7, 8}, // DataType::kDouble {8, 2}, // DataType::kUInt16 }; } // namespace data } // namespace oneflow ================================================ FILE: oneflow/user/data/gpt_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
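The epoch accounting in GetNumEpochs / GetNumCompleteEpochs above is easier to follow with concrete numbers. A minimal sketch restating the first formula, with an illustrative (made-up) configuration worked through in the comments:

#include <cmath>
#include <cstddef>

// num_epochs is the smallest number of epochs whose tokens cover
// num_samples * seq_length + 1 tokens (the +1 is the overlapping label token).
inline size_t NumEpochs(size_t num_samples, size_t seq_length, size_t tokens_per_epoch) {
  return static_cast<size_t>(
      std::ceil(static_cast<double>(num_samples * seq_length + 1) / tokens_per_epoch));
}

// Example (illustrative numbers): tokens_per_epoch = 1000, seq_length = 10, num_samples = 350.
//   NumEpochs                 = ceil(3501 / 1000) = 4
//   samples per epoch         = floor((1000 - 1) / 10) = 99
//   samples in epochs 1..3    = floor((3 * 1000 - 1) / 10) = 299
//   samples left for epoch 4  = 350 - 299 = 51 < 0.8 * 99, so the last epoch is
//   shuffled separately and GetNumCompleteEpochs returns 3 instead of 4.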
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_GPT_DATASET_H_ #define ONEFLOW_USER_DATA_GPT_DATASET_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace data { class MegatronGPTIndex final { public: MegatronGPTIndex(const std::string& index_file); ~MegatronGPTIndex() = default; static constexpr char kMagicCode[] = "MMIDIDX\x00\x00"; static constexpr size_t kMagicCodeLen = sizeof(kMagicCode) - 1; uint64_t version() const { return version_; } char dtype_code() const { return dtype_code_; } size_t num_docs() const { return sizes_.size(); } size_t doc_length(size_t doc_index) const { return sizes_.at(doc_index); } size_t doc_offset(size_t doc_index) const { return doc_offsets_.at(doc_index); } size_t address(size_t doc_index) const { return addresses_.at(doc_index); } private: uint64_t version_; char dtype_code_; std::vector sizes_; std::vector addresses_; std::vector doc_offsets_; }; class MappedBuffer final { public: MappedBuffer(const std::string& filename); ~MappedBuffer(); const void* ptr() const { return mapped_; } size_t size() const { return size_; } private: void* mapped_; size_t size_; }; class MegatronGPTMMapDataset final { public: MegatronGPTMMapDataset(const std::string& data_file_prefix, size_t seq_len, size_t label_len, size_t num_samples, const std::vector& split_sizes, size_t split_index, bool shuffle, uint32_t seed); OF_DISALLOW_COPY_AND_MOVE(MegatronGPTMMapDataset); ~MegatronGPTMMapDataset() = default; template void GetSample(size_t index, T* data) const; private: static const HashMap kDTypeCode2Size; size_t GetEpochNumTokens(const std::vector& doc_indices) const; void InitDocIndices(const std::vector& epoch_doc_indices, size_t num_epochs, size_t num_complete_epochs); void InitDocIndices(const std::vector& doc_indices, size_t num_epochs); void InitSampleIndices(size_t total_num_samples); void InitShuffleIndices(size_t total_num_samples); template void ReadTokens(const void* src, size_t offset, T* dst, size_t size) const; // initializer list size_t seq_len_; size_t sample_len_; size_t num_samples_; bool shuffle_; uint32_t seed_; std::mt19937 gen_; // initializing in constructor (in order as below) std::unique_ptr index_; std::unique_ptr data_; size_t dtype_size_; size_t tokens_per_epoch_; size_t num_epochs_; size_t num_complete_epochs_; std::vector doc_indices_; std::vector> sample_indices_; std::vector shuffle_indices_; }; template void MegatronGPTMMapDataset::GetSample(size_t index, T* data) const { CHECK_LT(index, shuffle_indices_.size()); const size_t sample_index = shuffle_indices_[index]; CHECK_LT(sample_index, sample_indices_.size()); size_t doc_indices_idx = sample_indices_[sample_index].first; size_t doc_offset = sample_indices_[sample_index].second; int remaining_tokens = sample_len_; while (remaining_tokens > 0) { CHECK_LT(doc_indices_idx, doc_indices_.size()); const size_t doc_index = doc_indices_[doc_indices_idx]; size_t offset = index_->address(doc_index) + doc_offset * dtype_size_; size_t num_tokens = index_->doc_length(doc_index); CHECK_LT(doc_offset, 
num_tokens); num_tokens -= doc_offset; if (num_tokens > remaining_tokens) { num_tokens = remaining_tokens; } else { doc_indices_idx += 1; doc_offset = 0; } ReadTokens(data_->ptr(), offset, data, num_tokens); data += num_tokens; remaining_tokens -= num_tokens; } CHECK_EQ(remaining_tokens, 0); } template void MegatronGPTMMapDataset::ReadTokens(const void* src, size_t bytes_offset, T* dst, size_t size) const { CHECK_NOTNULL(src); switch (index_->dtype_code()) { #define SWITCH_CASE_ENTRY(type_code, type) \ case type_code: { \ const auto* src_ptr = \ reinterpret_cast(static_cast(src) + bytes_offset); \ std::copy(src_ptr, src_ptr + size, dst); \ break; \ } SWITCH_CASE_ENTRY(1, uint8_t) SWITCH_CASE_ENTRY(2, int8_t) SWITCH_CASE_ENTRY(3, int16_t) SWITCH_CASE_ENTRY(4, int32_t) SWITCH_CASE_ENTRY(5, int64_t) SWITCH_CASE_ENTRY(6, float) SWITCH_CASE_ENTRY(7, double) SWITCH_CASE_ENTRY(8, uint16_t) #undef SWITCH_CASE_ENTRY default: { UNIMPLEMENTED(); } } } } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_GPT_DATASET_H_ ================================================ FILE: oneflow/user/data/group_batch_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_GROUP_BATCH_DATASET_H_ #define ONEFLOW_USER_DATA_GROUP_BATCH_DATASET_H_ #include "oneflow/user/data/dataset.h" namespace oneflow { namespace data { template class GroupBatchDataset final : public Dataset { public: using Base = Dataset; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; using NestedDS = Dataset; GroupBatchDataset(size_t batch_size, const std::function& GroupId4Sample, std::unique_ptr&& dataset) : nested_ds_(std::move(dataset)), batch_size_(batch_size), group_fn_(GroupId4Sample), order_count_(0) {} ~GroupBatchDataset() = default; BatchType Next() override { BatchType batch; int64_t group_id = FindEarliestBatchGroupId(); auto group_it = group_id2buffered_samples_.find(group_id); if (group_it != group_id2buffered_samples_.end()) { auto& batch_sample_list = group_it->second; if (!batch_sample_list.empty()) { std::swap(batch, batch_sample_list.front().data); batch_sample_list.pop_front(); } } while (batch.size() < batch_size_) { auto next_batch = nested_ds_->Next(); CHECK_EQ(next_batch.size(), 1); int64_t next_group_id = group_fn_(next_batch[0]); if (group_id == -1) { group_id = next_group_id; } if (group_id == next_group_id) { batch.emplace_back(std::move(next_batch[0])); } else { auto group_it = group_id2buffered_samples_.find(next_group_id); if (group_it == group_id2buffered_samples_.end()) { group_it = group_id2buffered_samples_.emplace(next_group_id, std::list()).first; } auto& batch_sample_list = group_it->second; if (batch_sample_list.empty() || batch_sample_list.back().data.size() == batch_size_) { BatchSample batch_sample; std::swap(batch_sample.data, next_batch); batch_sample.data.reserve(batch_size_); batch_sample.order = order_count_++; 
batch_sample_list.emplace_back(std::move(batch_sample)); } else { batch_sample_list.back().data.emplace_back(std::move(next_batch[0])); } } } return batch; } private: int64_t FindEarliestBatchGroupId() const { int64_t group_id = -1; int64_t min_order = -1; for (const auto& pair : group_id2buffered_samples_) { if (pair.second.size() > 0) { if (min_order == -1 || pair.second.front().order < min_order) { min_order = pair.second.front().order; group_id = pair.first; } } } return group_id; } struct BatchSample { BatchType data; int64_t order; }; std::unique_ptr nested_ds_; size_t batch_size_; std::function group_fn_; std::map> group_id2buffered_samples_; int64_t order_count_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_GROUP_BATCH_DATASET_H_ ================================================ FILE: oneflow/user/data/ofrecord_data_reader.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_OFRECORD_DATA_READER_H_ #define ONEFLOW_USER_DATA_OFRECORD_DATA_READER_H_ #include "oneflow/user/data/data_reader.h" #include "oneflow/user/data/ofrecord_dataset.h" #include "oneflow/user/data/ofrecord_parser.h" #include "oneflow/user/data/random_shuffle_dataset.h" #include "oneflow/user/data/batch_dataset.h" namespace oneflow { namespace data { class OFRecordDataReader final : public DataReader { public: OFRecordDataReader(user_op::KernelInitContext* ctx) : DataReader(ctx) { batch_size_ = ctx->TensorDesc4ArgNameAndIndex("out", 0)->shape().elem_cnt(); if (auto* pool = TensorBufferPool::TryGet()) { pool->IncreasePoolSizeByBase(batch_size_); } loader_.reset(new OFRecordDataset(ctx)); if (ctx->Attr("random_shuffle")) { loader_.reset(new RandomShuffleDataset(ctx, std::move(loader_))); } loader_.reset(new BatchDataset(batch_size_, std::move(loader_))); parser_.reset(new OFRecordParser()); StartLoadThread(); } ~OFRecordDataReader() override { if (auto* pool = TensorBufferPool::TryGet()) { pool->DecreasePoolSizeByBase(batch_size_); } } protected: using DataReader::loader_; using DataReader::parser_; private: size_t batch_size_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_OFRECORD_DATA_READER_H_ ================================================ FILE: oneflow/user/data/ofrecord_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
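GroupBatchDataset above keeps one list of partially filled batches per group id and only emits a batch once batch_size samples of the same group have been collected. A minimal usage sketch of the group callback it expects, with an illustrative sample type and an aspect-ratio grouping rule (both are assumptions for the example, not types from this repository):

#include <cstdint>
#include <functional>

// Illustrative sample type; a real pipeline would use the dataset's load target type.
struct FakeSample {
  int32_t height = 0;
  int32_t width = 0;
};

// Group landscape and portrait images separately so every batch is homogeneous,
// which is the typical use of group batching in detection pipelines.
inline std::function<int64_t(const FakeSample&)> MakeAspectRatioGroupFn() {
  return [](const FakeSample& sample) -> int64_t {
    return sample.width >= sample.height ? 0 : 1;
  };
}

// A GroupBatchDataset would then be built roughly as
//   GroupBatchDataset<...>(batch_size, MakeAspectRatioGroupFn(), std::move(nested))
// where `nested` yields single-sample batches, as the implementation above expects.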
*/ #ifndef ONEFLOW_USER_DATA_OFRECORD_DATASET_H_ #define ONEFLOW_USER_DATA_OFRECORD_DATASET_H_ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/persistence/persistent_in_stream.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/user/data/dataset.h" namespace oneflow { namespace data { class OFRecordDataset final : public Dataset { public: using Base = Dataset; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; OF_DISALLOW_COPY_AND_MOVE(OFRecordDataset); OFRecordDataset(user_op::KernelInitContext* ctx) { current_epoch_ = 0; shuffle_after_epoch_ = ctx->Attr("shuffle_after_epoch"); // in stream data_part_num_ = ctx->Attr("data_part_num"); std::string data_dir = ctx->Attr("data_dir"); std::string part_name_prefix = ctx->Attr("part_name_prefix"); int32_t part_name_suffix_length = ctx->Attr("part_name_suffix_length"); for (int i = 0; i < data_part_num_; ++i) { std::string num = std::to_string(i); int32_t zero_count = std::max(part_name_suffix_length - static_cast(num.length()), 0); data_file_paths_.emplace_back( JoinPath(data_dir, part_name_prefix + std::string(zero_count, '0') + num)); } bool is_local = false; // NOTE(zwx): OFRecordDataset is used by OFRecordDataReader and // OFRecordImageClassificationDataReader both, the latter has no attr nd_sbp, // so it couldn't work in DDP for now. The If condition here could be removed when // OFRecordImageClassificationDataReader had supported DDP (add attr nd_sbp) // or been deprecated. if (ctx->op_type_name() == "OFRecordReader") { auto nd_sbp_str_vec = ctx->Attr>("nd_sbp"); // NOTE(zwx): OFRecordDataset is not global since attr nd_sbp is empty, // we assume that it works in DDP if (nd_sbp_str_vec.empty()) { is_local = true; } } if (is_local) { parallel_id_ = GlobalProcessCtx::Rank(); parallel_num_ = GlobalProcessCtx::WorldSize(); } else { parallel_id_ = ctx->parallel_ctx().parallel_id(); parallel_num_ = ctx->parallel_ctx().parallel_num(); } CHECK_LE(parallel_num_, data_part_num_); BalancedSplitter bs(data_part_num_, parallel_num_); range_ = bs.At(parallel_id_); std::vector local_file_paths = GetLocalFilePaths(); in_stream_.reset( new PersistentInStream(DataFS(), local_file_paths, !shuffle_after_epoch_, false)); } ~OFRecordDataset() = default; BatchType Next() override { BatchType batch; batch.push_back(TensorBuffer()); ReadSample(batch.back()); return batch; } private: void ReadSample(TensorBuffer& tensor) { int64_t OFRecord_size = -1; char* size_ptr = reinterpret_cast(&OFRecord_size); if (in_stream_->ReadFully(size_ptr, sizeof(int64_t)) != 0) { ShuffleAfterEpoch(); CHECK_EQ(in_stream_->ReadFully(size_ptr, sizeof(int64_t)), 0); } CHECK_GT(OFRecord_size, 0); tensor.Resize(Shape({OFRecord_size}), DataType::kChar); CHECK_EQ(in_stream_->ReadFully(tensor.mut_data(), OFRecord_size), 0); } void ShuffleAfterEpoch() { CHECK(shuffle_after_epoch_); current_epoch_++; // move to next epoch std::mt19937 g(kOneflowDatasetSeed + current_epoch_); std::shuffle(data_file_paths_.begin(), data_file_paths_.end(), g); std::vector local_file_paths = GetLocalFilePaths(); in_stream_.reset(new PersistentInStream(DataFS(), local_file_paths, false, false)); } std::vector GetLocalFilePaths() { std::vector ret; for (int i = range_.begin(); i < range_.end(); ++i) { ret.emplace_back(data_file_paths_.at(i)); } return ret; } 
int32_t current_epoch_; bool shuffle_after_epoch_; int32_t data_part_num_; int32_t parallel_id_; int32_t parallel_num_; Range range_; std::vector data_file_paths_; std::unique_ptr in_stream_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_OFRECORD_DATASET_H_ ================================================ FILE: oneflow/user/data/ofrecord_image_classification_data_reader.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_ #define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_ #include "oneflow/user/data/data_reader.h" #include "oneflow/user/data/ofrecord_dataset.h" #include "oneflow/user/data/ofrecord_parser.h" #include "oneflow/user/data/random_shuffle_dataset.h" #include "oneflow/user/data/batch_dataset.h" #include "oneflow/user/data/ofrecord_image_classification_dataset.h" #include "oneflow/user/data/ofrecord_image_classification_parser.h" namespace oneflow { namespace data { class OFRecordImageClassificationDataReader final : public DataReader { public: explicit OFRecordImageClassificationDataReader(user_op::KernelInitContext* ctx) : DataReader(ctx) { batch_size_ = ctx->TensorDesc4ArgNameAndIndex("image", 0)->shape().elem_cnt(); if (auto* pool = TensorBufferPool::TryGet()) { pool->IncreasePoolSizeByBase(batch_size_); } std::unique_ptr> base(new OFRecordDataset(ctx)); if (ctx->Attr("random_shuffle")) { base.reset(new RandomShuffleDataset(ctx, std::move(base))); } loader_.reset(new OFRecordImageClassificationDataset(ctx, std::move(base))); loader_.reset( new BatchDataset(batch_size_, std::move(loader_))); parser_.reset(new OFRecordImageClassificationParser()); StartLoadThread(); } ~OFRecordImageClassificationDataReader() override { if (auto* pool = TensorBufferPool::TryGet()) { pool->DecreasePoolSizeByBase(batch_size_); } } protected: using DataReader::loader_; using DataReader::parser_; private: size_t batch_size_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_ ================================================ FILE: oneflow/user/data/ofrecord_image_classification_dataset.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
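OFRecordDataset above derives its part file list from data_dir, part_name_prefix, and a zero-padded part index. A minimal sketch of that naming scheme (the helper name and example values are illustrative):

#include <algorithm>
#include <cstdint>
#include <string>

// Build one part file name the same way OFRecordDataset does:
// prefix + enough leading zeros to reach part_name_suffix_length + the index.
inline std::string PartFileName(const std::string& part_name_prefix, int32_t part_index,
                                int32_t part_name_suffix_length) {
  const std::string num = std::to_string(part_index);
  const int32_t zero_count =
      std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0);
  return part_name_prefix + std::string(zero_count, '0') + num;
}

// PartFileName("part-", 3, 5)  -> "part-00003"
// PartFileName("part-", 12, 0) -> "part-12"   (no padding when the suffix length is 0 or less)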
*/ #include "oneflow/user/data/ofrecord_image_classification_dataset.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/user/image/image_util.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include namespace oneflow { namespace data { namespace { using DS = OFRecordImageClassificationDataset; void DecodeImageFromOFRecord(const OFRecord& record, const std::string& feature_name, const std::string& color_space, TensorBuffer* out) { auto image_feature_it = record.feature().find(feature_name); CHECK(image_feature_it != record.feature().end()); const Feature& image_feature = image_feature_it->second; CHECK(image_feature.has_bytes_list()); CHECK(image_feature.bytes_list().value_size() == 1); const std::string& src_data = image_feature.bytes_list().value(0); cv::Mat image = cv::imdecode(cv::Mat(1, src_data.size(), CV_8UC1, (void*)(src_data.data())), cv::IMREAD_COLOR); int W = image.cols; int H = image.rows; // convert color space if (ImageUtil::IsColor(color_space) && color_space != "BGR") { ImageUtil::ConvertColor("BGR", image, color_space, image); } CHECK(image.isContinuous()); const int c = ImageUtil::IsColor(color_space) ? 3 : 1; CHECK_EQ(c, image.channels()); Shape image_shape({H, W, c}); out->Resize(image_shape, DataType::kUInt8); CHECK_EQ(image_shape.elem_cnt(), out->nbytes()); CHECK_EQ(image_shape.elem_cnt(), image.total() * image.elemSize()); memcpy(out->mut_data(), image.ptr(), image_shape.elem_cnt()); } void DecodeLabelFromFromOFRecord(const OFRecord& record, const std::string& feature_name, TensorBuffer* out) { auto label_feature_it = record.feature().find(feature_name); CHECK(label_feature_it != record.feature().end()); const Feature& label_feature = label_feature_it->second; out->Resize(Shape({1}), DataType::kInt32); if (label_feature.has_int32_list()) { CHECK_EQ(label_feature.int32_list().value_size(), 1); *out->mut_data() = label_feature.int32_list().value(0); } else if (label_feature.has_int64_list()) { CHECK_EQ(label_feature.int64_list().value_size(), 1); *out->mut_data() = label_feature.int64_list().value(0); } else { UNIMPLEMENTED(); } } void LoadWorker(Dataset* record_dataset, std::vector>>* decode_in_buffers) { int64_t thread_idx = 0; bool shutdown = false; while (!shutdown) { auto records = record_dataset->Next(); for (auto& record : records) { auto& current_in_buffer = decode_in_buffers->at(thread_idx++); if (thread_idx >= decode_in_buffers->size()) { thread_idx = 0; } auto status = current_in_buffer->Push(std::move(record)); if (status == kBufferStatusErrorClosed) { shutdown = true; break; } CHECK(status == kBufferStatusSuccess); } } } void DecodeWorker(const std::string& image_feature_name, const std::string& label_feature_name, const std::string& color_space, Buffer* in_buffer, Buffer* out_buffer) { while (true) { TensorBuffer serialized_record; auto receive_status = in_buffer->Pull(&serialized_record); if (receive_status == kBufferStatusErrorClosed) { break; } CHECK(receive_status == kBufferStatusSuccess); OFRecord record; CHECK(record.ParseFromArray(serialized_record.data(), serialized_record.shape_view().elem_cnt())); ImageClassificationDataInstance instance; DecodeImageFromOFRecord(record, image_feature_name, color_space, &instance.image); DecodeLabelFromFromOFRecord(record, label_feature_name, &instance.label); auto send_status = out_buffer->Push(std::move(instance)); if (send_status == kBufferStatusErrorClosed) { break; } CHECK(send_status == kBufferStatusSuccess); } } int32_t 
GetNumLocalDecodeThreads(int32_t num_decode_threads_per_machine, const ParallelDesc& parallel_desc, const ParallelContext& parallel_ctx) { if (num_decode_threads_per_machine == 0) { num_decode_threads_per_machine = Singleton::Get()->ComputeThreadPoolSize(); } int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_ctx.parallel_id())); int64_t parallel_num_on_this_machine = parallel_desc.sorted_dev_phy_ids(machine_id).size(); return std::max(num_decode_threads_per_machine / parallel_num_on_this_machine, 1); } } // namespace OFRecordImageClassificationDataset::OFRecordImageClassificationDataset( user_op::KernelInitContext* ctx, std::unique_ptr&& dataset) : nested_ds_(std::move(dataset)), out_thread_idx_(0) { const std::string& color_space = ctx->Attr("color_space"); const std::string& image_feature_name = ctx->Attr("image_feature_name"); const std::string& label_feature_name = ctx->Attr("label_feature_name"); auto num_decode_threads_per_machine = ctx->Attr("num_decode_threads_per_machine"); auto decode_buffer_size_per_thread = ctx->Attr("decode_buffer_size_per_thread"); auto num_local_decode_threads = GetNumLocalDecodeThreads( num_decode_threads_per_machine, ctx->parallel_desc(), ctx->parallel_ctx()); decode_in_buffers_.reserve(num_local_decode_threads); decode_out_buffers_.reserve(num_local_decode_threads); for (int64_t i = 0; i < num_local_decode_threads; ++i) { decode_in_buffers_.emplace_back( std::make_unique>(decode_buffer_size_per_thread)); decode_out_buffers_.emplace_back( std::make_unique>(decode_buffer_size_per_thread)); decode_threads_.emplace_back(DecodeWorker, image_feature_name, label_feature_name, color_space, decode_in_buffers_.back().get(), decode_out_buffers_.back().get()); } load_thread_ = std::thread(LoadWorker, nested_ds_.get(), &decode_in_buffers_); } OFRecordImageClassificationDataset::~OFRecordImageClassificationDataset() { for (auto& out_buffer : decode_out_buffers_) { out_buffer->Close(); } for (auto& in_buffer : decode_in_buffers_) { in_buffer->Close(); } load_thread_.join(); for (auto& decode_thread : decode_threads_) { decode_thread.join(); } } } // namespace data } // namespace oneflow ================================================ FILE: oneflow/user/data/ofrecord_image_classification_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
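GetNumLocalDecodeThreads above splits a per-machine decode thread budget across the ranks placed on that machine. A minimal sketch of that division with illustrative numbers in the comments (the fallback for a zero budget, which uses the machine's compute thread pool size, is omitted here):

#include <algorithm>
#include <cstdint>

// num_decode_threads_per_machine is the total budget for one machine;
// parallel_num_on_this_machine is how many ranks of this op live on it.
inline int32_t LocalDecodeThreads(int32_t num_decode_threads_per_machine,
                                  int32_t parallel_num_on_this_machine) {
  return std::max(num_decode_threads_per_machine / parallel_num_on_this_machine, 1);
}

// LocalDecodeThreads(16, 4) -> 4   (four decode threads per local rank)
// LocalDecodeThreads(2, 4)  -> 1   (never fewer than one decode thread)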
*/ #ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_ #define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_ #include "oneflow/user/data/dataset.h" #include "oneflow/core/common/buffer.h" #include "oneflow/core/framework/op_kernel.h" namespace oneflow { namespace data { struct ImageClassificationDataInstance { TensorBuffer label; TensorBuffer image; }; class OFRecordImageClassificationDataset final : public Dataset { public: using Base = Dataset; using SampleType = Base::SampleType; using BatchType = Base::BatchType; using NestedDS = Dataset; using NestedSampleType = NestedDS::SampleType; OF_DISALLOW_COPY_AND_MOVE(OFRecordImageClassificationDataset); OFRecordImageClassificationDataset(user_op::KernelInitContext* ctx, std::unique_ptr&& dataset); ~OFRecordImageClassificationDataset() override; BatchType Next() override { size_t thread_idx = out_thread_idx_.fetch_add(1, std::memory_order_relaxed) % decode_out_buffers_.size(); CHECK_LT(thread_idx, decode_out_buffers_.size()); BatchType batch; SampleType sample; auto status = decode_out_buffers_[thread_idx]->Pull(&sample); CHECK_EQ(status, kBufferStatusSuccess); batch.push_back(std::move(sample)); return batch; } private: std::unique_ptr nested_ds_; std::thread load_thread_; std::vector decode_threads_; std::vector>> decode_in_buffers_; std::vector>> decode_out_buffers_; std::atomic out_thread_idx_; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_ ================================================ FILE: oneflow/user/data/ofrecord_image_classification_parser.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_ #define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_ #include "oneflow/user/data/parser.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/record/record.pb.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/user/data/ofrecord_image_classification_dataset.h" namespace oneflow { namespace data { class OFRecordImageClassificationParser final : public Parser { public: using Base = Parser; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; OFRecordImageClassificationParser() = default; ~OFRecordImageClassificationParser() override = default; void Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) override { const int64_t batch_size = batch_data.size(); user_op::Tensor* image_tensor = ctx->Tensor4ArgNameAndIndex("image", 0); CHECK_EQ(image_tensor->shape_view().NumAxes(), 1); CHECK_EQ(image_tensor->shape_view().At(0), batch_size); auto* image_buffers = image_tensor->mut_dptr(); user_op::Tensor* label_tensor = ctx->Tensor4ArgNameAndIndex("label", 0); CHECK_EQ(label_tensor->shape_view().NumAxes(), 1); CHECK_EQ(label_tensor->shape_view().At(0), batch_size); auto* label_buffers = label_tensor->mut_dptr(); for (size_t i = 0; i < batch_data.size(); ++i) { auto& instance = batch_data[i]; image_buffers[i].Swap(instance.image); label_buffers[i].Swap(instance.label); } } }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_ ================================================ FILE: oneflow/user/data/ofrecord_parser.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_DATA_OFRECORD_PARSER_H_ #define ONEFLOW_USER_DATA_OFRECORD_PARSER_H_ #include "oneflow/user/data/parser.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/record/record.pb.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { namespace data { class OFRecordParser final : public Parser { public: using Base = Parser; using SampleType = typename Base::SampleType; using BatchType = typename Base::BatchType; OFRecordParser() = default; ~OFRecordParser() = default; void Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) override { user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); OFRecord* dptr = out_tensor->mut_dptr(); MultiThreadLoop(batch_data.size(), [&](size_t i) { auto& sample = batch_data[i]; CHECK(dptr[i].ParseFromArray(sample.data(), sample.nbytes())); }); if (batch_data.size() != out_tensor->shape_view().elem_cnt()) { CHECK_EQ(out_tensor->mut_shape_view().NumAxes(), 1); out_tensor->mut_shape_view().Set(0, batch_data.size()); } } }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_OFRECORD_PARSER_H_ ================================================ FILE: oneflow/user/data/parser.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_DATA_PARSER_H_ #define ONEFLOW_USER_DATA_PARSER_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/op_kernel.h" namespace oneflow { namespace data { template class Parser { public: using SampleType = LoadTarget; using BatchType = std::vector; Parser() = default; virtual ~Parser() = default; virtual void Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) = 0; }; } // namespace data } // namespace oneflow #endif // ONEFLOW_USER_DATA_PARSER_H_ ================================================ FILE: oneflow/user/data/random_shuffle_dataset.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
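OFRecordParser::Parse above, like COCOParser::Parse earlier, supports a dynamic batch size by shrinking only the leading dimension of its 1-D output tensor when the final batch comes up short. A minimal helper sketch of that shared fix-up, assuming the usual framework include brings in user_op::Tensor (the helper itself is not part of the repository):

#include "oneflow/core/framework/op_kernel.h"

namespace oneflow {
namespace data {

// The last batch of an epoch may hold fewer samples than the tensor's static
// capacity, so only the leading (batch) dimension of the 1-D tensor is shrunk.
inline void ShrinkBatchDim(user_op::Tensor* tensor, size_t actual_batch_size) {
  if (tensor->shape_view().elem_cnt() != actual_batch_size) {
    CHECK_EQ(tensor->shape_view().NumAxes(), 1);
    tensor->mut_shape_view().Set(0, actual_batch_size);
  }
}

}  // namespace data
}  // namespace oneflow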
================================================
FILE: oneflow/user/data/random_shuffle_dataset.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_USER_DATA_RANDOM_SHUFFLE_DATASET_H_
#define ONEFLOW_USER_DATA_RANDOM_SHUFFLE_DATASET_H_

#include "oneflow/user/data/dataset.h"
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/op_kernel.h"

namespace oneflow {
namespace data {

template<typename LoadTarget>
class RandomShuffleDataset final : public Dataset<LoadTarget> {
 public:
  using Base = Dataset<LoadTarget>;
  using SampleType = typename Base::SampleType;
  using BatchType = typename Base::BatchType;

  RandomShuffleDataset(user_op::KernelInitContext* ctx,
                       std::unique_ptr<Dataset<LoadTarget>>&& dataset)
      : nested_ds_(std::move(dataset)) {
    // random
    seed_ = ctx->Attr<int64_t>("seed");
    if (seed_ == -1) { seed_ = NewRandomSeed(); }
    std::seed_seq seq({seed_});
    rand_engine_ = std::default_random_engine(seq);
    // fill buffer
    initial_buffer_fill_ = ctx->Attr<int32_t>("shuffle_buffer_size");
    int32_t remain_cnt = initial_buffer_fill_;
    while (remain_cnt > 0) {
      BatchType batch = nested_ds_->Next();
      for (auto& sample : batch) {
        sample_buffer_.push_back(std::move(sample));
        remain_cnt--;
      }
    }
  }
  ~RandomShuffleDataset() = default;

  BatchType Next() override {
    BatchType batch = nested_ds_->Next();
    for (auto& sample : batch) {
      std::uniform_int_distribution<> dis(0, sample_buffer_.size() - 1);
      int offset = dis(rand_engine_);
      std::swap(sample_buffer_[offset], sample);
    }
    return batch;
  }

 private:
  std::unique_ptr<Dataset<LoadTarget>> nested_ds_;
  std::vector<SampleType> sample_buffer_;
  int32_t initial_buffer_fill_;
  std::default_random_engine rand_engine_;
  int64_t seed_;
};

}  // namespace data
}  // namespace oneflow

#endif  // ONEFLOW_USER_DATA_RANDOM_SHUFFLE_DATASET_H_

================================================
FILE: oneflow/user/image/crop_window.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_USER_IMAGE_CROP_WINDOW_H_
#define ONEFLOW_USER_IMAGE_CROP_WINDOW_H_

#include "oneflow/core/common/shape.h"

namespace oneflow {

struct CropWindow {
  Shape anchor;
  Shape shape;
  CropWindow() : anchor{0, 0}, shape{0, 0} {}
};

}  // namespace oneflow

#endif  // ONEFLOW_USER_IMAGE_CROP_WINDOW_H_
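RandomShuffleDataset above implements a streaming (buffered) shuffle: it pre-fills a fixed-size buffer, then swaps every incoming sample with a uniformly chosen buffer slot and emits whatever was displaced. A self-contained sketch of that policy over plain integers (buffer size and stream length are made up for illustration):

#include <iostream>
#include <random>
#include <vector>

int main() {
  std::mt19937 rng(42);
  const int buffer_size = 4;

  // Pre-fill the shuffle buffer from the head of the stream.
  std::vector<int> buffer;
  int next = 0;
  while (static_cast<int>(buffer.size()) < buffer_size) { buffer.push_back(next++); }

  // For each further element: swap it into a random slot and emit the displaced value.
  // Elements therefore leave roughly in arrival order, jittered within the buffer window.
  std::vector<int> emitted;
  for (; next < 20; ++next) {
    int sample = next;
    std::uniform_int_distribution<int> dis(0, buffer_size - 1);
    std::swap(buffer[dis(rng)], sample);
    emitted.push_back(sample);
  }
  for (int v : emitted) { std::cout << v << " "; }
  std::cout << "\n";
}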
*/ #include "oneflow/user/image/image_util.h" #include namespace oneflow { bool ImageUtil::IsColor(const std::string& color_space) { if (color_space == "RGB" || color_space == "BGR") { return true; } else if (color_space == "GRAY") { return false; } else { UNIMPLEMENTED(); return false; } } void ImageUtil::ConvertColor(const std::string& input_color, const cv::Mat& input_img, const std::string& output_color, cv::Mat& output_img) { if (input_color == "BGR" && output_color == "RGB") { cv::cvtColor(input_img, output_img, cv::COLOR_BGR2RGB); } else { UNIMPLEMENTED(); } } cv::Mat GenCvMat4ImageBuffer(const TensorBuffer& image_buffer) { CHECK_EQ(image_buffer.shape_view().NumAxes(), 3); int h = image_buffer.shape_view().At(0); int w = image_buffer.shape_view().At(1); int channels = image_buffer.shape_view().At(2); DataType data_type = image_buffer.data_type(); if (channels == 1 && data_type == DataType::kUInt8) { return CreateMatWithPtr(h, w, CV_8UC1, image_buffer.data()); } else if (channels == 1 && data_type == DataType::kFloat) { return CreateMatWithPtr(h, w, CV_32FC1, image_buffer.data()); } else if (channels == 3 && data_type == DataType::kUInt8) { return CreateMatWithPtr(h, w, CV_8UC3, image_buffer.data()); } else if (channels == 3 && data_type == DataType::kFloat) { return CreateMatWithPtr(h, w, CV_32FC3, image_buffer.data()); } else { UNIMPLEMENTED(); } return cv::Mat(); } cv::Mat GenCvMat4ImageTensor(const user_op::Tensor* image_tensor, int image_offset) { int has_batch_dim = 0; if (image_tensor->shape_view().NumAxes() == 3) { has_batch_dim = 0; image_offset = 0; } else if (image_tensor->shape_view().NumAxes() == 4) { has_batch_dim = 1; CHECK_GE(image_offset, 0); CHECK_LT(image_offset, image_tensor->shape_view().At(0)); } else { UNIMPLEMENTED(); } int h = image_tensor->shape_view().At(0 + has_batch_dim); int w = image_tensor->shape_view().At(1 + has_batch_dim); int c = image_tensor->shape_view().At(2 + has_batch_dim); int elem_offset = image_offset * h * w * c; DataType data_type = image_tensor->data_type(); if (c == 1 && data_type == DataType::kUInt8) { return CreateMatWithPtr(h, w, CV_8UC1, image_tensor->dptr() + elem_offset); } else if (c == 1 && data_type == DataType::kFloat) { return CreateMatWithPtr(h, w, CV_32FC1, image_tensor->dptr() + elem_offset); } else if (c == 3 && data_type == DataType::kUInt8) { return CreateMatWithPtr(h, w, CV_8UC3, image_tensor->dptr() + elem_offset); } else if (c == 3 && data_type == DataType::kFloat) { return CreateMatWithPtr(h, w, CV_32FC3, image_tensor->dptr() + elem_offset); } else { UNIMPLEMENTED(); } return cv::Mat(); } void CvMatConvertToDataType(const cv::Mat& src, cv::Mat* dst, DataType dtype) { if (dtype == DataType::kUInt8) { src.convertTo(*dst, CV_8U); } else if (dtype == DataType::kFloat) { src.convertTo(*dst, CV_32F); } else { UNIMPLEMENTED(); } } int GetCvInterpolationFlag(const std::string& interp_type, int org_w, int org_h, int res_w, int res_h) { if (interp_type == "bilinear") { return cv::INTER_LINEAR; } else if (interp_type == "nearest_neighbor" || interp_type == "nn") { return cv::INTER_NEAREST; } else if (interp_type == "bicubic") { return cv::INTER_CUBIC; } else if (interp_type == "area") { return cv::INTER_AREA; } else if (interp_type == "auto") { if (res_w * res_h >= org_w * org_h) { return cv::INTER_LINEAR; } else { return cv::INTER_AREA; } } else { UNIMPLEMENTED(); } } bool CheckInterpolationValid(const std::string& interp_type, std::ostringstream& err) { if (interp_type != "bilinear" && interp_type != "nearest_neighbor" && 
interp_type != "nn" && interp_type != "bicubic" && interp_type != "area" && interp_type != "auto") { err << ", interpolation_type: " << interp_type << " (interpolation_type must be one of bilinear, nearest_neighbor(nn), bicubic, area and " "auto)"; return false; } return true; } } // namespace oneflow ================================================ FILE: oneflow/user/image/image_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_IMAGE_IMAGE_UTIL_H_ #define ONEFLOW_USER_IMAGE_IMAGE_UTIL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/framework/user_op_tensor.h" #include namespace oneflow { struct ImageUtil { static bool IsColor(const std::string& color_space); static void ConvertColor(const std::string& input_color, const cv::Mat& input_img, const std::string& output_color, cv::Mat& output_img); }; template inline cv::Mat CreateMatWithPtr(int H, int W, int type, const T* ptr, size_t step = cv::Mat::AUTO_STEP) { return cv::Mat(H, W, type, const_cast(ptr), step); } cv::Mat GenCvMat4ImageBuffer(const TensorBuffer& image_buffer); cv::Mat GenCvMat4ImageTensor(const user_op::Tensor* image_tensor, int image_offset); void CvMatConvertToDataType(const cv::Mat& src, cv::Mat* dst, DataType dtype); int GetCvInterpolationFlag(const std::string& inter_type, int org_w, int org_h, int res_w, int res_h); bool CheckInterpolationValid(const std::string& interp_type, std::ostringstream& ss); } // namespace oneflow #endif // ONEFLOW_USER_IMAGE_IMAGE_UTIL_H_ ================================================ FILE: oneflow/user/image/jpeg_decoder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================
FILE: oneflow/user/image/jpeg_decoder.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/
#include <cstring>
#include <vector>

#include "oneflow/user/image/jpeg_decoder.h"
#include "oneflow/user/image/image_util.h"

namespace oneflow {

class LibjpegCtx {
 public:
  explicit LibjpegCtx(struct jpeg_decompress_struct* compress_info)
      : compress_info_(compress_info) {}
  ~LibjpegCtx() { jpeg_destroy_decompress(compress_info_); }
  OF_DISALLOW_COPY_AND_MOVE(LibjpegCtx);

  struct jpeg_decompress_struct* compress_info() {
    return compress_info_;
  }

 private:
  struct jpeg_decompress_struct* compress_info_;
};

bool JpegPartialDecodeRandomCropImage(const unsigned char* data, size_t length,
                                      RandomCropGenerator* random_crop_gen,
                                      unsigned char* workspace, size_t workspace_size,
                                      cv::Mat* out_mat) {
  struct jpeg_decompress_struct compress_info {};
  struct jpeg_error_mgr jpeg_err {};
  compress_info.err = jpeg_std_error(&jpeg_err);
  jpeg_create_decompress(&compress_info);
  if (compress_info.err->msg_code != 0) { return false; }
  LibjpegCtx ctx_guard(&compress_info);

  jpeg_mem_src(ctx_guard.compress_info(), data, length);
  if (ctx_guard.compress_info()->err->msg_code != 0) { return false; }
  int rc = jpeg_read_header(ctx_guard.compress_info(), TRUE);
  if (rc != JPEG_HEADER_OK) { return false; }
  jpeg_start_decompress(ctx_guard.compress_info());

  int width = ctx_guard.compress_info()->output_width;
  int height = ctx_guard.compress_info()->output_height;
  int pixel_size = ctx_guard.compress_info()->output_components;

  unsigned int u_crop_x = 0, u_crop_y = 0, u_crop_w = width, u_crop_h = height;
  if (random_crop_gen) {
    CropWindow crop;
    random_crop_gen->GenerateCropWindow({height, width}, &crop);
    u_crop_y = crop.anchor.At(0);
    u_crop_x = crop.anchor.At(1);
    u_crop_h = crop.shape.At(0);
    u_crop_w = crop.shape.At(1);
  }

  unsigned int tmp_w = u_crop_w;
  jpeg_crop_scanline(ctx_guard.compress_info(), &u_crop_x, &tmp_w);
  if (jpeg_skip_scanlines(ctx_guard.compress_info(), u_crop_y) != u_crop_y) { return false; }

  int row_offset = (tmp_w - u_crop_w) * pixel_size;
  int out_row_stride = u_crop_w * pixel_size;

  std::vector<unsigned char> decode_output_buf;
  unsigned char* decode_output_pointer = nullptr;
  size_t image_space_size = width * pixel_size;
  if (image_space_size > workspace_size) {
    decode_output_buf.resize(image_space_size);
    decode_output_pointer = decode_output_buf.data();
  } else {
    decode_output_pointer = workspace;
  }

  out_mat->create(u_crop_h, u_crop_w, CV_8UC3);
  while (ctx_guard.compress_info()->output_scanline < u_crop_y + u_crop_h) {
    unsigned char* buffer_array[1];
    buffer_array[0] = decode_output_pointer;
    unsigned int read_line_index = ctx_guard.compress_info()->output_scanline;
    jpeg_read_scanlines(ctx_guard.compress_info(), buffer_array, 1);
    memcpy(out_mat->data + (read_line_index - u_crop_y) * out_row_stride,
           decode_output_pointer + row_offset, out_row_stride);
  }
  jpeg_skip_scanlines(ctx_guard.compress_info(), height - u_crop_y - u_crop_h);
  jpeg_finish_decompress(ctx_guard.compress_info());
  return true;
}

void OpenCvPartialDecodeRandomCropImage(const unsigned char* data, size_t length,
                                        RandomCropGenerator* random_crop_gen,
                                        const std::string& color_space, cv::Mat& out_mat) {
  cv::Mat image =
      cv::imdecode(cv::Mat(1, length, CV_8UC1, const_cast<unsigned char*>(data)),
                   ImageUtil::IsColor(color_space) ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);
  int W = image.cols;
  int H = image.rows;

  // random crop
  if (random_crop_gen != nullptr) {
    CHECK(image.data != nullptr);
    cv::Mat image_roi;
    CropWindow crop;
    random_crop_gen->GenerateCropWindow({H, W}, &crop);
    const int y = crop.anchor.At(0);
    const int x = crop.anchor.At(1);
    const int newH = crop.shape.At(0);
    const int newW = crop.shape.At(1);
    CHECK(newW > 0 && newW <= W);
    CHECK(newH > 0 && newH <= H);
    cv::Rect roi(x, y, newW, newH);
    image(roi).copyTo(out_mat);
    W = out_mat.cols;
    H = out_mat.rows;
    CHECK(W == newW);
    CHECK(H == newH);
  } else {
    image.copyTo(out_mat);
  }
}

}  // namespace oneflow

================================================
FILE: oneflow/user/image/jpeg_decoder.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_USER_IMAGE_JPEG_DECODER_H_
#define ONEFLOW_USER_IMAGE_JPEG_DECODER_H_

#include <jpeglib.h>
#include <string>
#include <opencv2/opencv.hpp>

#include "oneflow/user/image/random_crop_generator.h"

namespace oneflow {

bool JpegPartialDecodeRandomCropImage(const unsigned char* data, size_t length,
                                      RandomCropGenerator* random_crop_gen,
                                      unsigned char* workspace, size_t workspace_size,
                                      cv::Mat* out_mat);

void OpenCvPartialDecodeRandomCropImage(const unsigned char* data, size_t length,
                                        RandomCropGenerator* random_crop_gen,
                                        const std::string& color_space, cv::Mat& out_mat);

}  // namespace oneflow

#endif  // ONEFLOW_USER_IMAGE_JPEG_DECODER_H_
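One subtlety in JpegPartialDecodeRandomCropImage above: jpeg_crop_scanline may move the crop's left edge down to an iMCU boundary and widen the decoded region, so each decoded scanline is copied starting at row_offset to recover exactly the requested window. The plain arithmetic below illustrates that bookkeeping; the 8-pixel alignment is only an example, libjpeg chooses the actual boundary:

#include <cassert>

int main() {
  const int pixel_size = 3;    // RGB bytes per pixel
  const int requested_x = 21;  // left edge of the crop we want
  const int requested_w = 50;  // crop width we want

  // jpeg_crop_scanline() moves the left edge down to an iMCU boundary
  // (assume 8 px here, purely for illustration) and widens the region to compensate.
  const int aligned_x = (requested_x / 8) * 8;                    // 16
  const int decoded_w = requested_w + (requested_x - aligned_x);  // 55

  // The decoder now emits decoded_w pixels per scanline; the pixels we asked for
  // start row_offset bytes into each decoded line, and out_row_stride bytes of
  // them are copied into the output image row.
  const int row_offset = (decoded_w - requested_w) * pixel_size;
  const int out_row_stride = requested_w * pixel_size;
  assert(row_offset == (requested_x - aligned_x) * pixel_size);
  assert(out_row_stride == 150);
  return 0;
}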
================================================
FILE: oneflow/user/image/jpeg_decoder_test.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/
#include <cstdint>
#include <random>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include <opencv2/opencv.hpp>

#include "oneflow/user/image/jpeg_decoder.h"
#include "oneflow/user/image/image_util.h"

namespace oneflow {

// generate image
void GenerateImage(std::vector<unsigned char>& jpg, int w, int h) {
  std::vector<uint8_t> raw_data(w * h * 3);
  for (int i = 0; i < w; i++) {
    for (int j = 0; j < h; j++) {
      uint8_t r = 0, g = 0, b = 0;
      if (i < w / 2 && j < h / 2) {
        r = 255;
        g = 0;
        b = 0;
      } else if ((i >= w / 2 && j < h / 2) || (i < w / 2 && j >= h / 2)) {
        r = 0;
        g = 255;
        b = 0;
      } else if ((i >= w / 2) && (j >= h / 2)) {
        r = 0;
        g = 0;
        b = 255;
      }
      raw_data[3 * (i * w + j)] = b;
      raw_data[3 * (i * w + j) + 1] = g;
      raw_data[3 * (i * w + j) + 2] = r;
    }
  }
  std::vector<int> compression_params;
  compression_params.push_back(cv::IMWRITE_JPEG_QUALITY);
  compression_params.push_back(100);
  cv::Mat raw(h, w, CV_8UC3, (void*)raw_data.data(), cv::Mat::AUTO_STEP);
  cv::imencode(".jpg", raw, jpg);
}

TEST(JPEG, decoder) {
  constexpr size_t test_num = 3;
  std::vector<unsigned char> jpg;
  GenerateImage(jpg, 192, 192);

  std::seed_seq seq{1, 2, 3};
  std::vector<int64_t> seeds(test_num);
  seq.generate(seeds.begin(), seeds.end());

  for (int i = 0; i < test_num; i++) {
    cv::Mat libjpeg_image_mat;
    RandomCropGenerator libjpeg_random_crop_gen({0.1, 0.9}, {0.4, 0.6}, seeds[i], 1);
    RandomCropGenerator opencv_random_crop_gen({0.1, 0.9}, {0.4, 0.6}, seeds[i], 1);
    auto status = JpegPartialDecodeRandomCropImage(jpg.data(), jpg.size(),
                                                   &libjpeg_random_crop_gen, nullptr, 0,
                                                   &libjpeg_image_mat);
    ASSERT_EQ(status, true);

    cv::Mat opencv_image_mat;
    std::string color_space("RGB");
    OpenCvPartialDecodeRandomCropImage(jpg.data(), jpg.size(), &opencv_random_crop_gen,
                                       color_space, opencv_image_mat);
    ImageUtil::ConvertColor("BGR", opencv_image_mat, color_space, opencv_image_mat);

    cv::Mat checkout = libjpeg_image_mat - opencv_image_mat;
    auto sum = cv::sum(cv::sum(checkout));
    ASSERT_EQ(sum[0], 0);
    // cv::imwrite("jpeg.ppm", libjpeg_image_mat);
    // cv::imwrite("opencv.ppm", opencv_image_mat);
  }
}

}  // namespace oneflow
*/ #include "oneflow/user/image/random_crop_generator.h" namespace oneflow { RandomCropGenerator::RandomCropGenerator(AspectRatioRange aspect_ratio_range, AreaRange area_range, int64_t seed, int32_t num_attempts) : aspect_ratio_range_(aspect_ratio_range), aspect_ratio_log_dis_(std::log(aspect_ratio_range.first), std::log(aspect_ratio_range.second)), area_dis_(area_range.first, area_range.second), rand_gen_(seed), seed_(seed), num_attempts_(num_attempts) {} void RandomCropGenerator::GenerateCropWindow(const Shape& shape, CropWindow* crop_window) { CHECK_EQ(shape.NumAxes(), 2); CHECK(crop_window != nullptr); int H = shape.At(0); int W = shape.At(1); if (H <= 0 || W <= 0) { return; } float min_wh_ratio = aspect_ratio_range_.first; float max_wh_ratio = aspect_ratio_range_.second; float max_hw_ratio = 1 / aspect_ratio_range_.first; float min_area = W * H * area_dis_.a(); int maxW = std::max(1, static_cast(H * max_wh_ratio)); int maxH = std::max(1, static_cast(W * max_hw_ratio)); if (H * maxW < min_area) { crop_window->shape = Shape({H, maxW}); } else if (W * maxH < min_area) { crop_window->shape = Shape({maxH, W}); } else { int attempts_left = num_attempts_; for (; attempts_left > 0; attempts_left--) { float scale = area_dis_(rand_gen_); size_t original_area = H * W; float target_area = scale * original_area; float ratio = std::exp(aspect_ratio_log_dis_(rand_gen_)); int w = static_cast(std::roundf(sqrtf(target_area * ratio))); int h = static_cast(std::roundf(sqrtf(target_area / ratio))); w = std::max(w, 1); h = std::max(h, 1); crop_window->shape = Shape({h, w}); ratio = static_cast(w) / h; if (w <= W && h <= H && ratio >= min_wh_ratio && ratio <= max_wh_ratio) { break; } } if (attempts_left <= 0) { float max_area = area_dis_.b() * W * H; float ratio = static_cast(W) / H; if (ratio > max_wh_ratio) { crop_window->shape = Shape({H, maxW}); } else if (ratio < min_wh_ratio) { crop_window->shape = Shape({maxH, W}); } else { crop_window->shape = Shape({H, W}); } float scale = std::min(1.0f, max_area / (crop_window->shape.At(0) * crop_window->shape.At(1))); crop_window->shape.Set(0, std::max(1, crop_window->shape.At(0) * std::sqrt(scale))); crop_window->shape.Set(1, std::max(1, crop_window->shape.At(1) * std::sqrt(scale))); } } crop_window->anchor.Set( 0, std::uniform_int_distribution(0, H - crop_window->shape.At(0))(rand_gen_)); crop_window->anchor.Set( 1, std::uniform_int_distribution(0, W - crop_window->shape.At(1))(rand_gen_)); } void RandomCropGenerator::GenerateCropWindows(const Shape& shape, size_t n, std::vector* crop_windows) { std::seed_seq seq{seed_}; std::vector seeds(n); seq.generate(seeds.begin(), seeds.end()); crop_windows->resize(n); for (std::size_t i = 0; i < n; i++) { rand_gen_.seed(seeds.at(i)); GenerateCropWindow(shape, &(crop_windows->at(i))); } } } // namespace oneflow ================================================ FILE: oneflow/user/image/random_crop_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================
FILE: oneflow/user/image/random_crop_generator.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/
#ifndef ONEFLOW_USER_IMAGE_RANDOM_CROP_GENERATOR_H_
#define ONEFLOW_USER_IMAGE_RANDOM_CROP_GENERATOR_H_

#include "oneflow/user/image/crop_window.h"

namespace oneflow {

using AspectRatioRange = std::pair<float, float>;
using AreaRange = std::pair<float, float>;

class RandomCropGenerator {
 public:
  RandomCropGenerator(AspectRatioRange aspect_ratio_range, AreaRange area_range, int64_t seed,
                      int32_t num_attempts);

  void GenerateCropWindow(const Shape& shape, CropWindow* crop_window);
  void GenerateCropWindows(const Shape& shape, size_t n, std::vector<CropWindow>* crop_windows);

 private:
  AspectRatioRange aspect_ratio_range_;
  std::uniform_real_distribution<float> aspect_ratio_log_dis_;
  std::uniform_real_distribution<float> area_dis_;
  std::mt19937 rand_gen_;
  int64_t seed_;
  int32_t num_attempts_;
};

}  // namespace oneflow

#endif  // ONEFLOW_USER_IMAGE_RANDOM_CROP_GENERATOR_H_

================================================
FILE: oneflow/user/kernels/acc_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/ep/include/primitive/add.h"

namespace oneflow {

namespace {

class AccKernel final : public user_op::OpKernel {
 public:
  AccKernel() = default;
  ~AccKernel() override = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt());
    CHECK_EQ(in->data_type(), out->data_type());
    std::unique_ptr<ep::primitive::Add> primitive =
        ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->device_type(),
                                                               in->data_type());
    CHECK(primitive);
    primitive->Launch(ctx->stream(), out->dptr(), in->dptr(), out->mut_dptr(),
                      in->shape_view().elem_cnt());
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("acc").SetCreateFn<AccKernel>();

}  // namespace

}  // namespace oneflow

================================================
FILE: oneflow/user/kernels/activation_kernels.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/ep/include/primitive/binary_op.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/user/kernels/elementwise_primitive_kernel.h" namespace oneflow { REGISTER_USER_KERNEL("elu") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kElu, src->data_type(), dst->data_type(), ctx->Attr("alpha")); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kElu, "out", "in")); REGISTER_USER_KERNEL("elu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kEluBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr("alpha")); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kEluBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("celu") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kCelu, src->data_type(), dst->data_type(), ctx->Attr("alpha")); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kCelu, "out", "in")); REGISTER_USER_KERNEL("celu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "y", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kCeluBackwardWithDyY, src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr("alpha")); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kCeluBackwardWithDyY, "dx", "dy")); REGISTER_USER_KERNEL("hardswish") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kHardSwish, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kHardSwish, "out", "in")); REGISTER_USER_KERNEL("hardswish_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kHardswishBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kHardswishBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("hardsigmoid") .SetCreateFn([]() { return 
user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kHardSigmoid, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kHardSigmoid, "out", "in")); REGISTER_USER_KERNEL("hardsigmoid_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kHardsigmoidBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kHardsigmoidBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("hardshrink") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kHardShrink, src->data_type(), dst->data_type(), ctx->Attr("lambd")); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kHardShrink, "out", "in")) .SetInplaceProposalFn([](const user_op::InferContext&, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); return Maybe::Ok(); }); REGISTER_USER_KERNEL("hardshrink_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "y", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kHardshrinkBackwardWithDyY, src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr("lambd")); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kHardshrinkBackwardWithDyY, "dx", "dy")) .SetInplaceProposalFn([](const user_op::InferContext&, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "dy", 0, true)); return Maybe::Ok(); }); REGISTER_USER_KERNEL("hardtanh") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kHardTanh, src->data_type(), dst->data_type(), ctx->Attr("min_val"), ctx->Attr("max_val")); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kHardTanh, "out", "in")) .SetInplaceProposalFn([](const user_op::InferContext&, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); return Maybe::Ok(); }); REGISTER_USER_KERNEL("hardtanh_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "y", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = 
ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kHardtanhBackwardWithDyY, src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr("min_val"), ctx->Attr("max_val")); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kHardtanhBackwardWithDyY, "dx", "dy")) .SetInplaceProposalFn([](const user_op::InferContext&, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "dy", 0, true)); return Maybe::Ok(); }); REGISTER_USER_KERNEL("gelu") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kGelu, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kGelu, "out", "in")); REGISTER_USER_KERNEL("gelu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kGeluBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kGeluBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("fast_gelu") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kFastGelu, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kFastGelu, "out", "in")); REGISTER_USER_KERNEL("fast_gelu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kFastGeluBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kFastGeluBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("quick_gelu") .SetCreateFn([]() { return user_op::NewOpKernel( "y", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kQuickGelu, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kQuickGelu, "y", "x")); REGISTER_USER_KERNEL("quick_gelu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = 
ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kQuickGeluBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kQuickGeluBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("square_relu") .SetCreateFn([]() { return user_op::NewOpKernel( "y", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kSquareReLU, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSquareReLU, "y", "x")); REGISTER_USER_KERNEL("square_relu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kSquareReLUBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSquareReLUBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("leaky_relu") .SetCreateFn([]() { return user_op::NewOpKernel( "y", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kLeakyRelu, src->data_type(), dst->data_type(), ctx->Attr("alpha")); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kLeakyRelu, "y", "x")); REGISTER_USER_KERNEL("leaky_relu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kLeakyReluBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr("alpha")); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kLeakyReluBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("mish") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kMish, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kMish, "out", "in")); REGISTER_USER_KERNEL("mish_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kMishBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) 
.SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kMishBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("relu") .SetCreateFn([]() { return user_op::NewOpKernel( "y", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kRelu, "y", "x")) .SetInplaceProposalFn([](const user_op::InferContext&, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); return Maybe::Ok(); }); REGISTER_USER_KERNEL("relu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "y", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kReluBackwardWithDyY, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kReluBackwardWithDyY, "dx", "dy")) .SetInplaceProposalFn([](const user_op::InferContext&, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "dy", 0, true)); return Maybe::Ok(); }); REGISTER_USER_KERNEL("silu") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kSilu, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSilu, "out", "in")); REGISTER_USER_KERNEL("silu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kSiluBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSiluBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("trunc") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kTrunc, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kTrunc, "out", "in")); REGISTER_USER_KERNEL("selu") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kSelu, src->data_type(), dst->data_type()); }); }) 
.SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSelu, "out", "in")); REGISTER_USER_KERNEL("selu_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kSeluBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSeluBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("softshrink") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kSoftShrink, src->data_type(), dst->data_type(), ctx->Attr("alpha")); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSoftShrink, "out", "in")); REGISTER_USER_KERNEL("softshrink_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "y", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kSoftshrinkBackwardWithDyY, src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr("alpha")); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSoftshrinkBackwardWithDyY, "dx", "dy")); REGISTER_USER_KERNEL("softsign") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kSoftSign, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSoftSign, "out", "in")); REGISTER_USER_KERNEL("softsign_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kSoftsignBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSoftsignBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("softplus") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kSoftPlus, src->data_type(), dst->data_type(), ctx->Attr("beta"), ctx->Attr("threshold")); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSoftPlus, "out", "in")); REGISTER_USER_KERNEL("softplus_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { 
const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kSoftplusBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr("beta"), ctx->Attr("threshold")); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSoftplusBackwardWithDyX, "dx", "dy")); REGISTER_USER_KERNEL("tanh") .SetCreateFn([]() { return user_op::NewOpKernel( "y", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kTanh, src->data_type(), dst->data_type()); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kTanh, "y", "x")); REGISTER_USER_KERNEL("tanh_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "y", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kTanhBackwardWithDyY, src->data_type(), dst->data_type(), 1 /*max_num_dims*/); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kTanhBackwardWithDyY, "dx", "dy")); REGISTER_USER_KERNEL("threshold") .SetCreateFn([]() { return user_op::NewOpKernel( "out", "in", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kThreshold, src->data_type(), dst->data_type(), ctx->Attr("threshold_val"), ctx->Attr("value")); }); }) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kThreshold, "out", "in")); REGISTER_USER_KERNEL("threshold_grad") .SetCreateFn([]() { return user_op::NewOpKernel( "dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kThresholdBackwardWithDyX, src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr("threshold_val")); }); }) .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kThresholdBackwardWithDyX, "dx", "dy")); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/adaptive_avg_pool_cpu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/data_type.h" #include "oneflow/user/kernels/adaptive_pool_kernel_util.h" namespace oneflow { namespace { template void AvgForwardCompute(user_op::KernelComputeContext* ctx, const int32_t& dim) { user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape(); // TODO (Tianyu): Support 'channels_last' const std::string& data_format = ctx->Attr("data_format"); CHECK_OR_THROW(data_format == "channels_first") << "adaptive_avg_pool on cpu only supports NCHW data format"; const Shape& in = GetShape5D(x_shape, data_format, dim); const Shape& out = GetShape5D(y_shape, data_format, dim); const T* in_ptr = in_tensor->dptr(); T* out_ptr = out_tensor->mut_dptr(); const int64_t input_width = in.Count(4); const int64_t output_width = out.Count(4); const int64_t input_image_size = in.Count(3); const int64_t output_image_size = out.Count(3); const int64_t input_size = in.Count(2); const int64_t output_size = out.Count(2); FOR_RANGE(int64_t, n, 0, in.At(0)) { FOR_RANGE(int64_t, c, 0, in.At(1)) { FOR_RANGE(int64_t, od, 0, out.At(2)) { int64_t id0 = start_index(od, out.At(2), in.At(2)); int64_t id1 = end_index(od, out.At(2), in.At(2)); int64_t kd = id1 - id0; FOR_RANGE(int64_t, oh, 0, out.At(3)) { int64_t ih0 = start_index(oh, out.At(3), in.At(3)); int64_t ih1 = end_index(oh, out.At(3), in.At(3)); int64_t kh = ih1 - ih0; FOR_RANGE(int64_t, ow, 0, out.At(4)) { int64_t iw0 = start_index(ow, out.At(4), in.At(4)); int64_t iw1 = end_index(ow, out.At(4), in.At(4)); int64_t kw = iw1 - iw0; // Compute local average accT sum = static_cast(0); FOR_RANGE(int64_t, id, id0, id1) { FOR_RANGE(int64_t, ih, ih0, ih1) { FOR_RANGE(int64_t, iw, iw0, iw1) { sum += static_cast(in_ptr[id * input_image_size + ih * input_width + iw]); } } } out_ptr[od * output_image_size + oh * output_width + ow] = static_cast(sum / kd / kh / kw); } } } in_ptr += input_size; out_ptr += output_size; } } } template void AvgBackwardCompute(user_op::KernelComputeContext* ctx, const int32_t& dim) { user_op::Tensor* grad_input = ctx->Tensor4ArgNameAndIndex("dx", 0); const user_op::Tensor* grad_output = ctx->Tensor4ArgNameAndIndex("dy", 0); const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->shape(); const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->shape(); // TODO (Tianyu): Support 'channels_last' const std::string& data_format = ctx->Attr("data_format"); CHECK_OR_THROW(data_format == "channels_first") << "adaptive_avg_pool backward on cpu only supports NCHW data format"; const Shape& in = GetShape5D(dx_shape, data_format, dim); const Shape& out = GetShape5D(dy_shape, data_format, dim); const T* out_ptr = grad_output->dptr(); T* in_ptr = grad_input->mut_dptr(); std::fill(in_ptr, in_ptr + grad_input->shape_view().elem_cnt(), static_cast(0)); const int64_t input_width = in.Count(4); const int64_t output_width = out.Count(4); const int64_t input_image_size = in.Count(3); const int64_t output_image_size = out.Count(3); const int64_t input_size = in.Count(2); const int64_t output_size = out.Count(2); FOR_RANGE(int64_t, n, 0, in.At(0)) { FOR_RANGE(int64_t, c, 0, in.At(1)) { FOR_RANGE(int64_t, od, 0, out.At(2)) { int64_t id0 = start_index(od, out.At(2), in.At(2)); int64_t id1 = end_index(od, out.At(2), in.At(2)); int64_t kd = id1 - id0; FOR_RANGE(int64_t, oh, 0, out.At(3)) { int64_t ih0 = 
start_index(oh, out.At(3), in.At(3)); int64_t ih1 = end_index(oh, out.At(3), in.At(3)); int64_t kh = ih1 - ih0; FOR_RANGE(int64_t, ow, 0, out.At(4)) { int64_t iw0 = start_index(ow, out.At(4), in.At(4)); int64_t iw1 = end_index(ow, out.At(4), in.At(4)); int64_t kw = iw1 - iw0; T grad_delta = static_cast(out_ptr[od * output_image_size + oh * output_width + ow] / kd / kh / kw); FOR_RANGE(int64_t, id, id0, id1) { FOR_RANGE(int64_t, ih, ih0, ih1) { FOR_RANGE(int64_t, iw, iw0, iw1) { in_ptr[id * input_image_size + ih * input_width + iw] += grad_delta; } } } } } } in_ptr += input_size; out_ptr += output_size; } } } } // namespace template class AdaptivePool1DCpuKernel final : public user_op::OpKernel { public: AdaptivePool1DCpuKernel() = default; ~AdaptivePool1DCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { if (GetDataType::value == kFloat16) { AvgForwardCompute(ctx, 1); } else { AvgForwardCompute(ctx, 1); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class AdaptivePool2DCpuKernel final : public user_op::OpKernel { public: AdaptivePool2DCpuKernel() = default; ~AdaptivePool2DCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { if (GetDataType::value == kFloat16) { AvgForwardCompute(ctx, 2); } else { AvgForwardCompute(ctx, 2); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class AdaptivePool3DCpuKernel final : public user_op::OpKernel { public: AdaptivePool3DCpuKernel() = default; ~AdaptivePool3DCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { if (GetDataType::value == kFloat16) { AvgForwardCompute(ctx, 3); } else { AvgForwardCompute(ctx, 3); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class AdaptivePool1DCpuGradKernel final : public user_op::OpKernel { public: AdaptivePool1DCpuGradKernel() = default; ~AdaptivePool1DCpuGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { AvgBackwardCompute(ctx, 1); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class AdaptivePool2DCpuGradKernel final : public user_op::OpKernel { public: AdaptivePool2DCpuGradKernel() = default; ~AdaptivePool2DCpuGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { AvgBackwardCompute(ctx, 2); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class AdaptivePool3DCpuGradKernel final : public user_op::OpKernel { public: AdaptivePool3DCpuGradKernel() = default; ~AdaptivePool3DCpuGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { AvgBackwardCompute(ctx, 3); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_ADAPTIVE_POOL_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("adaptive_avg_pool1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_avg_pool2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_avg_pool3d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); #define 
REGISTER_ADAPTIVE_POOL_KERNEL_WITH_DEVICE(device) \ REGISTER_ADAPTIVE_POOL_KERNEL(device, float16) \ REGISTER_ADAPTIVE_POOL_KERNEL(device, float) \ REGISTER_ADAPTIVE_POOL_KERNEL(device, double) \ REGISTER_ADAPTIVE_POOL_KERNEL(device, int) REGISTER_ADAPTIVE_POOL_KERNEL_WITH_DEVICE(DeviceType::kCPU) #define REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("adaptive_avg_pool1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_avg_pool2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_avg_pool3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); #define REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL_WITH_DEVICE(device) \ REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, float16) \ REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, float) \ REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, double) \ REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, int) REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL_WITH_DEVICE(DeviceType::kCPU) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/adaptive_avg_pool_gpu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/core/common/data_type.h" #include "oneflow/core/kernel/util/cuda_half_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/operator/operator_util.h" #include "oneflow/user/utils/pool_util.h" #include "oneflow/user/kernels/adaptive_pool_kernel_util.h" namespace oneflow { namespace user_op { template __global__ void InitPtr(int elements, T* ptr) { int gid = (blockDim.x * blockIdx.x) + threadIdx.x; int step = gridDim.x * blockDim.x; while (gid < elements) { ptr[gid] = static_cast(0); gid += step; } } inline Shape GetShape5D(const Shape& shape, const std::string& data_format, int32_t dim) { FixedDimVector shape_3d = {GetInDim(shape, data_format, 0, dim), GetInDim(shape, data_format, 1, dim), GetInDim(shape, data_format, 2, dim)}; return Shape({shape.At(0), shape.At(1), shape_3d.at(0), shape_3d.at(1), shape_3d.at(2)}); } template __global__ void AdaptiveAvgPoolCudaKernel(const T* input, T* output, int num_elems, int in_d, int in_h, int in_w, int out_d, int out_h, int out_w) { const int out_panel_size = out_d * out_h * out_w; const int in_panel_size = in_d * in_h * in_w; CUDA_1D_KERNEL_LOOP(idx, num_elems) { // TODO (Tianyu): Replace following codes with 'NdIndexOffsetHelper' int bc_idx = idx / out_panel_size; int out_d_idx = (idx % out_panel_size) / out_w / out_h; int out_h_idx = (idx % out_panel_size) % (out_h * out_w) / out_w; int out_w_idx = (idx % out_panel_size) % (out_h * out_w) % out_w; int in_start_d = START_IND(out_d_idx, out_d, in_d); int in_end_d = END_IND(out_d_idx, out_d, in_d); int k_d = in_end_d - in_start_d; int in_start_h = START_IND(out_h_idx, out_h, in_h); int in_end_h = END_IND(out_h_idx, out_h, in_h); int k_h = in_end_h - in_start_h; int in_start_w = START_IND(out_w_idx, out_w, in_w); int in_end_w = END_IND(out_w_idx, out_w, in_w); int k_w = in_end_w - in_start_w; const T* in_ptr = input + bc_idx * in_panel_size + in_start_d * in_h * in_w + in_start_h * in_w + in_start_w; T sum = static_cast(0); for (int id = 0; id < k_d; ++id) { for (int ih = 0; ih < k_h; ++ih) { for (int iw = 0; iw < k_w; ++iw) { T val = *(in_ptr + ih * in_w + iw); sum += val; } } in_ptr += in_h * in_w; // next input depth } // Update output output[idx] = sum / static_cast(k_d) / static_cast(k_h) / static_cast(k_w); } } template __global__ void AdaptiveAvgPoolGradCudaKernel(T* input, const T* output, int num_elems, int in_d, int in_h, int in_w, int out_d, int out_h, int out_w) { const int out_panel_size = out_d * out_h * out_w; const int in_panel_size = in_d * in_h * in_w; CUDA_1D_KERNEL_LOOP(idx, num_elems) { // TODO (Tianyu): Replace following codes with 'NdIndexOffsetHelper' int bc_idx = idx / out_panel_size; int out_d_idx = (idx % out_panel_size) / out_w / out_h; int out_h_idx = (idx % out_panel_size) % (out_h * out_w) / out_w; int out_w_idx = (idx % out_panel_size) % (out_h * out_w) % out_w; int in_start_d = START_IND(out_d_idx, out_d, in_d); int in_end_d = END_IND(out_d_idx, out_d, in_d); int k_d = in_end_d - in_start_d; int in_start_h = START_IND(out_h_idx, out_h, in_h); int in_end_h = END_IND(out_h_idx, out_h, in_h); int k_h = in_end_h - in_start_h; int in_start_w = START_IND(out_w_idx, out_w, in_w); int in_end_w = END_IND(out_w_idx, out_w, in_w); int k_w = in_end_w - in_start_w; const T grad_delta = output[idx] / static_cast(k_d) / static_cast(k_h) / static_cast(k_w); T* input_ptr = input + bc_idx * in_panel_size + 
in_start_d * in_h * in_w + in_start_h * in_w + in_start_w; for (int id = 0; id < k_d; ++id) { for (int ih = 0; ih < k_h; ++ih) { for (int iw = 0; iw < k_w; ++iw) { // TODO (Tianyu): Use 'atmoic::Add' when necessary cuda::atomic::Add(input_ptr + ih * in_w + iw, grad_delta); } } input_ptr += in_h * in_w; // next input depth } } } template void AvgForwardCompute(KernelComputeContext* ctx, const int32_t& dim) { const Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const T* in_ptr = in_tensor->dptr(); T* out_ptr = out_tensor->mut_dptr(); const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape(); // TODO (Tianyu): Support 'channels_last' const std::string& data_format = ctx->Attr("data_format"); CHECK_OR_THROW(data_format == "channels_first") << "adaptive_avg_pool on cuda only supports NCHW data format"; const Shape& in = GetShape5D(x_shape, data_format, dim); const Shape& out = GetShape5D(y_shape, data_format, dim); const int out_elems = out_tensor->shape_view().elem_cnt(); RUN_CUDA_KERNEL((AdaptiveAvgPoolCudaKernel), ctx->stream(), out_elems, in_ptr, out_ptr, out_elems, in.At(2), in.At(3), in.At(4), out.At(2), out.At(3), out.At(4)); } template void AvgBackwardCompute(KernelComputeContext* ctx, const int32_t& dim) { const Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); const T* out_ptr = out_tensor->dptr(); T* in_ptr = in_tensor->mut_dptr(); const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->shape(); const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->shape(); // TODO (Tianyu): Support 'channels_last' const std::string& data_format = ctx->Attr("data_format"); CHECK_OR_THROW(data_format == "channels_first") << "adaptive_avg_pool backward on cuda only supports NCHW data format"; const Shape& in = GetShape5D(dx_shape, data_format, dim); const Shape& out = GetShape5D(dy_shape, data_format, dim); const int in_elems = in_tensor->shape_view().elem_cnt(); const int out_elems = out_tensor->shape_view().elem_cnt(); RUN_CUDA_KERNEL((InitPtr), ctx->stream(), in_elems, in_elems, in_ptr); RUN_CUDA_KERNEL((AdaptiveAvgPoolGradCudaKernel), ctx->stream(), out_elems, in_ptr, out_ptr, out_elems, in.At(2), in.At(3), in.At(4), out.At(2), out.At(3), out.At(4)); } template class GpuAdaptiveAvgPool1dKernel final : public OpKernel { public: GpuAdaptiveAvgPool1dKernel() = default; ~GpuAdaptiveAvgPool1dKernel() = default; private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { AvgForwardCompute(ctx, 1); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuAdaptiveAvgPool2dKernel final : public OpKernel { public: GpuAdaptiveAvgPool2dKernel() = default; ~GpuAdaptiveAvgPool2dKernel() = default; private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { AvgForwardCompute(ctx, 2); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuAdaptiveAvgPool3dKernel final : public OpKernel { public: GpuAdaptiveAvgPool3dKernel() = default; ~GpuAdaptiveAvgPool3dKernel() = default; private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { AvgForwardCompute(ctx, 3); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class 
GpuAdaptiveAvgPool1dGradKernel final : public OpKernel { public: GpuAdaptiveAvgPool1dGradKernel() = default; ~GpuAdaptiveAvgPool1dGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { AvgBackwardCompute(ctx, 1); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuAdaptiveAvgPool2dGradKernel final : public OpKernel { public: GpuAdaptiveAvgPool2dGradKernel() = default; ~GpuAdaptiveAvgPool2dGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { AvgBackwardCompute(ctx, 2); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuAdaptiveAvgPool3dGradKernel final : public OpKernel { public: GpuAdaptiveAvgPool3dGradKernel() = default; ~GpuAdaptiveAvgPool3dGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { AvgBackwardCompute(ctx, 3); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("adaptive_avg_pool1d") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_avg_pool2d") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_avg_pool3d") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("y", 0) == GetDataType::value)); REGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kCUDA, half); REGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kCUDA, float); REGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kCUDA, double); REGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kCUDA, int); #define REGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("adaptive_avg_pool1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_avg_pool2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_avg_pool3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("dx", 0) == GetDataType::value)); REGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, half); REGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, float); REGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, double); REGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, int); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/adaptive_max_pool_cpu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/adaptive_pool_kernel_util.h" namespace oneflow { namespace { template void AdapativeMaxPoolForward(user_op::KernelComputeContext* ctx) { user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0); const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape(); // TODO : Support 'channels_last' const std::string& data_format = ctx->Attr("data_format"); CHECK_OR_THROW(data_format == "channels_first") << "adaptive_max_pool on cpu only supports NCHW data format"; const Shape& in = GetShape5D(x_shape, data_format, dim); const Shape& out = GetShape5D(y_shape, data_format, dim); const T* in_ptr = in_tensor->dptr(); T* out_ptr = out_tensor->mut_dptr(); int64_t* index_ptr = index_tensor->mut_dptr(); const int64_t input_width = in.Count(4); const int64_t output_width = out.Count(4); const int64_t input_image_size = in.Count(3); const int64_t output_image_size = out.Count(3); const int64_t input_size = in.Count(2); const int64_t output_size = out.Count(2); FOR_RANGE(int64_t, n, 0, in.At(0)) { FOR_RANGE(int64_t, c, 0, in.At(1)) { FOR_RANGE(int64_t, od, 0, out.At(2)) { int64_t id0 = start_index(od, out.At(2), in.At(2)); int64_t id1 = end_index(od, out.At(2), in.At(2)); FOR_RANGE(int64_t, oh, 0, out.At(3)) { int64_t ih0 = start_index(oh, out.At(3), in.At(3)); int64_t ih1 = end_index(oh, out.At(3), in.At(3)); FOR_RANGE(int64_t, ow, 0, out.At(4)) { int64_t iw0 = start_index(ow, out.At(4), in.At(4)); int64_t iw1 = end_index(ow, out.At(4), in.At(4)); // Find out local max auto start_offset = id0 * input_image_size + ih0 * input_width + iw0; T local_max = in_ptr[start_offset]; int64_t local_max_index = start_offset; FOR_RANGE(int64_t, id, id0, id1) { FOR_RANGE(int64_t, ih, ih0, ih1) { FOR_RANGE(int64_t, iw, iw0, iw1) { auto cur_index = id * input_image_size + ih * input_width + iw; if (in_ptr[cur_index] > local_max) { local_max_index = cur_index; local_max = in_ptr[cur_index]; } } } } auto i = od * output_image_size + oh * output_width + ow; out_ptr[i] = local_max; index_ptr[i] = local_max_index; } } } in_ptr += input_size; index_ptr += output_size; out_ptr += output_size; } } } template void AdaptiveMaxPoolBackward(user_op::KernelComputeContext* ctx) { user_op::Tensor* grad_input = ctx->Tensor4ArgNameAndIndex("dx", 0); const user_op::Tensor* grad_output = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* return_indices = ctx->Tensor4ArgNameAndIndex("index", 0); const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->shape(); const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->shape(); // TODO : Support 'channels_last' const std::string& data_format = ctx->Attr("data_format"); CHECK_OR_THROW(data_format == "channels_first") << "adaptive_max_pool backward on cpu only supports NCHW data format"; const Shape& in = GetShape5D(dx_shape, data_format, dim); const Shape& out = GetShape5D(dy_shape, data_format, dim); const T* dy_ptr = grad_output->dptr(); const int64_t* indices_ptr = return_indices->dptr(); T* dx_ptr = grad_input->mut_dptr(); std::fill(dx_ptr, dx_ptr + grad_input->shape_view().elem_cnt(), static_cast(0)); const int64_t output_width = out.Count(4); const int64_t output_image_size = out.Count(3); const int64_t 
input_size = in.Count(2); const int64_t output_size = out.Count(2); FOR_RANGE(int64_t, n, 0, in.At(0)) { FOR_RANGE(int64_t, c, 0, in.At(1)) { FOR_RANGE(int64_t, od, 0, out.At(2)) { FOR_RANGE(int64_t, oh, 0, out.At(3)) { FOR_RANGE(int64_t, ow, 0, out.At(4)) { auto i = od * output_image_size + oh * output_width + ow; dx_ptr[indices_ptr[i]] += dy_ptr[i]; } } } dx_ptr += input_size; dy_ptr += output_size; indices_ptr += output_size; } } } } // namespace template class AdaptiveMaxPoolNDCpuKernel final : public user_op::OpKernel { public: AdaptiveMaxPoolNDCpuKernel() = default; ~AdaptiveMaxPoolNDCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { AdapativeMaxPoolForward(ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class AdaptiveMaxPoolNDGradCpuKernel final : public user_op::OpKernel { public: AdaptiveMaxPoolNDGradCpuKernel() = default; ~AdaptiveMaxPoolNDGradCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { AdaptiveMaxPoolBackward(ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_ADAPTIVE_MAX_POOLND_CPU(op_type_name, dtype, dim) \ REGISTER_USER_KERNEL(op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ \ REGISTER_USER_KERNEL(op_type_name "_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); #define REGISTER_ADAPTIVE_MAX_POOL_CPU(op_type_name, dim) \ REGISTER_ADAPTIVE_MAX_POOLND_CPU(op_type_name, double, dim); \ REGISTER_ADAPTIVE_MAX_POOLND_CPU(op_type_name, float, dim); \ REGISTER_ADAPTIVE_MAX_POOLND_CPU(op_type_name, int, dim); REGISTER_ADAPTIVE_MAX_POOL_CPU("adaptive_max_pool1d", 1); REGISTER_ADAPTIVE_MAX_POOL_CPU("adaptive_max_pool2d", 2); REGISTER_ADAPTIVE_MAX_POOL_CPU("adaptive_max_pool3d", 3); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/adaptive_max_pool_gpu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
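// Overview of the CUDA kernels in this file (descriptive notes on the code below):
//  - AdaptiveMaxPoolCudaKernel stores, for every output element, the flat
//    offset of the window maximum relative to the current (n, c) panel of size
//    in_d * in_h * in_w in `return_index`.
//  - The backward pass zero-fills dx with an ep::primitive Memset and then
//    scatters dy with cuda::atomic::Add at the stored offsets; atomics are
//    needed because overlapping windows can select the same input element.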
*/ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/core/common/data_type.h" #include "oneflow/core/kernel/util/cuda_half_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/operator/operator_util.h" #include "oneflow/user/utils/pool_util.h" #include "oneflow/user/kernels/adaptive_pool_kernel_util.h" namespace oneflow { namespace user_op { template __global__ void AdaptiveMaxPoolCudaKernel(const T* input, T* output, int64_t* return_index, int num_elems, int in_d, int in_h, int in_w, int out_d, int out_h, int out_w) { const int out_panel_size = out_d * out_h * out_w; const int in_panel_size = in_d * in_h * in_w; const int out_hw = out_w * out_h; CUDA_1D_KERNEL_LOOP(idx, num_elems) { int bc_idx = idx / out_panel_size; int out_d_idx = (idx % out_panel_size) / out_hw; int out_h_idx = (idx % out_panel_size) % (out_h * out_w) / out_w; int out_w_idx = (idx % out_panel_size) % (out_h * out_w) % out_w; int in_start_d = START_IND(out_d_idx, out_d, in_d); int in_end_d = END_IND(out_d_idx, out_d, in_d); int k_d = in_end_d - in_start_d; int in_start_h = START_IND(out_h_idx, out_h, in_h); int in_end_h = END_IND(out_h_idx, out_h, in_h); int k_h = in_end_h - in_start_h; int in_start_w = START_IND(out_w_idx, out_w, in_w); int in_end_w = END_IND(out_w_idx, out_w, in_w); int k_w = in_end_w - in_start_w; int64_t batch_idx_base = bc_idx * in_panel_size; const T* in_ptr = input + batch_idx_base + in_start_d * in_h * in_w + in_start_h * in_w + in_start_w; T local_max = in_ptr[0]; int64_t local_max_index = static_cast(in_ptr - input) - batch_idx_base; for (int id = 0; id < k_d; ++id) { for (int ih = 0; ih < k_h; ++ih) { for (int iw = 0; iw < k_w; ++iw) { T val = *(in_ptr + ih * in_w + iw); if (val > local_max) { local_max = val; local_max_index = in_ptr - input - batch_idx_base + ih * in_w + iw; } } } in_ptr += in_h * in_w; // next input depth } output[idx] = local_max; return_index[idx] = local_max_index; } } template __global__ void AdaptiveMaxPoolGradCudaKernel(T* input, const T* output, const int64_t* index, int dy_elems, int in_panel_size, int out_panel_size) { CUDA_1D_KERNEL_LOOP(idx, dy_elems) { int bc_idx = idx / out_panel_size; T* input_ptr = input + bc_idx * in_panel_size; cuda::atomic::Add(input_ptr + index[idx], output[idx]); } } template void MaxForwardCompute(KernelComputeContext* ctx) { const Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); Tensor* return_indices = ctx->Tensor4ArgNameAndIndex("index", 0); const T* in_ptr = in_tensor->dptr(); T* out_ptr = out_tensor->mut_dptr(); int64_t* index_ptr = return_indices->mut_dptr(); const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape(); // TODO: Support 'channels_last' const std::string& data_format = ctx->Attr("data_format"); CHECK_OR_THROW(data_format == "channels_first") << "adaptive_max_pool on CUDA only supports NCHW data format"; const Shape& in = GetShape5D(x_shape, data_format, dim); const Shape& out = GetShape5D(y_shape, data_format, dim); const int out_elems = out_tensor->shape_view().elem_cnt(); RUN_CUDA_KERNEL((AdaptiveMaxPoolCudaKernel), ctx->stream(), out_elems, in_ptr, out_ptr, index_ptr, out_elems, in.At(2), in.At(3), in.At(4), out.At(2), out.At(3), out.At(4)); } template void MaxBackwardCompute(KernelComputeContext* ctx) { const Tensor* out_tensor = 
ctx->Tensor4ArgNameAndIndex("dy", 0); Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); const user_op::Tensor* return_indices = ctx->Tensor4ArgNameAndIndex("index", 0); const T* out_ptr = out_tensor->dptr(); T* in_ptr = in_tensor->mut_dptr(); const int64_t* index_ptr = return_indices->dptr(); const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->shape(); const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->shape(); // TODO (Tianyu): Support 'channels_last' const std::string& data_format = ctx->Attr("data_format"); CHECK_OR_THROW(data_format == "channels_first") << "adaptive_max_pool backward on CUDA only supports NCHW data format"; const Shape& in = GetShape5D(dx_shape, data_format, dim); const Shape& out = GetShape5D(dy_shape, data_format, dim); const int in_elems = in_tensor->shape_view().elem_cnt(); const int out_elems = out_tensor->shape_view().elem_cnt(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); memset_primitive->Launch(ctx->stream(), in_ptr, 0, in_elems * sizeof(T)); RUN_CUDA_KERNEL((AdaptiveMaxPoolGradCudaKernel), ctx->stream(), out_elems, in_ptr, out_ptr, index_ptr, out_elems, in.At(2) * in.At(3) * in.At(4), out.At(2) * out.At(3) * out.At(4)); } template class GpuAdaptiveMaxPoolNdKernel final : public OpKernel { public: GpuAdaptiveMaxPoolNdKernel() = default; ~GpuAdaptiveMaxPoolNdKernel() = default; private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { MaxForwardCompute(ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuAdaptiveMaxPoolNdGradKernel final : public OpKernel { public: GpuAdaptiveMaxPoolNdGradKernel() = default; ~GpuAdaptiveMaxPoolNdGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { MaxBackwardCompute(ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_ADAPTIVE_MAXPOOL_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("adaptive_max_pool1d") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_max_pool2d") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_max_pool3d") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("y", 0) == GetDataType::value)); REGISTER_CUDA_ADAPTIVE_MAXPOOL_KERNEL(DeviceType::kCUDA, float); REGISTER_CUDA_ADAPTIVE_MAXPOOL_KERNEL(DeviceType::kCUDA, double); REGISTER_CUDA_ADAPTIVE_MAXPOOL_KERNEL(DeviceType::kCUDA, int); #define REGISTER_CUDA_ADAPTIVE_MAXPOOL_BACKWARD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("adaptive_max_pool1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_max_pool2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("adaptive_max_pool3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == device) \ && (HobDataType("dx", 0) == GetDataType::value)); REGISTER_CUDA_ADAPTIVE_MAXPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, float); REGISTER_CUDA_ADAPTIVE_MAXPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, double); REGISTER_CUDA_ADAPTIVE_MAXPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, int); } // namespace 
user_op } // namespace oneflow

================================================
FILE: oneflow/user/kernels/adaptive_pool_kernel_util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef _ONEFLOW_USER_KERNELS_ADAPTIVE_POOL_UTIL_H_
#define _ONEFLOW_USER_KERNELS_ADAPTIVE_POOL_UTIL_H_

#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/operator/operator_util.h"
#include "oneflow/user/utils/pool_util.h"

namespace oneflow {

namespace {

inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
  return (int64_t)std::floor((float)(a * c) / b);
}

inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
  return (int64_t)std::ceil((float)((a + 1) * c) / b);
}

#define START_IND(a, b, c) (int)std::floor((float)(a * c) / b)
#define END_IND(a, b, c) (int)std::ceil((float)((a + 1) * c) / b)

#define START_IND_INT(a, b, c) ((a * c) / b)
#define END_IND_INT(a, b, c) (((a + 1) * c + b - 1) / b)

inline Shape GetShape5D(const Shape& shape, const std::string& data_format, int32_t dim) {
  FixedDimVector shape_3d = {GetInDim(shape, data_format, 0, dim),
                             GetInDim(shape, data_format, 1, dim),
                             GetInDim(shape, data_format, 2, dim)};
  return Shape({shape.At(0), shape.At(1), shape_3d.at(0), shape_3d.at(1), shape_3d.at(2)});
}

}  // namespace

}  // namespace oneflow

#endif  // _ONEFLOW_USER_KERNELS_ADAPTIVE_POOL_UTIL_H_

================================================
FILE: oneflow/user/kernels/add_n_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewAddPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } class AddNKernel : public OpKernel, public CudaGraphSupport { public: OF_DISALLOW_COPY_AND_MOVE(AddNKernel); AddNKernel() = default; ~AddNKernel() override = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(KernelComputeContext* ctx) const override { auto primitive = NewAddPrimitive(ctx); CHECK(primitive); Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const DataType data_type = out->data_type(); const size_t count = out->shape_view().elem_cnt(); if (count == 0) { return; } size_t in_num = ctx->inputs().size(); std::vector srcs(in_num); for (size_t i = 0; i < in_num; ++i) { const Tensor* in_i = ctx->Tensor4ArgNameAndIndex("in", i); CHECK_EQ(in_i->shape_view().elem_cnt(), count); CHECK_EQ(in_i->data_type(), data_type); srcs[i] = in_i->template dptr(); } primitive->Launch(ctx->stream(), srcs.data(), in_num, out->mut_dptr(), count); } }; auto AddPrimitiveExists() { return hob::make_custom("AddPrimitiveExists", [](const KernelRegContext& ctx) { return NewAddPrimitive(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("add_n") .SetCreateFn() .SetIsMatchedHob(AddPrimitiveExists() == true) .SetInplaceProposalFn([](const InferContext&, const AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); return Maybe::Ok(); }); } // namespace } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/affine_grid_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/config_def.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "affine_grid_kernel.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewAffineGridMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("theta", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/true); } auto AffineGridMatmulPrimitiveExists() { return hob::make_custom("AffineGridMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewAffineGridMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewAffineGridGradMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dgrid", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/false); } auto AffineGridGradMatmulPrimitiveExists() { return hob::make_custom("AffineGridGradMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewAffineGridGradMatmulPrimitive(&ctx).operator bool(); }); } } // namespace template class AffineGridKernel final : public user_op::OpKernel { public: AffineGridKernel() = default; ~AffineGridKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* theta = ctx->Tensor4ArgNameAndIndex("theta", 0); user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const Shape& size = ctx->Attr("size"); const bool& align_corners = ctx->Attr("align_corners"); bool is_2d_grid = true; if (size.NumAxes() == 5) { is_2d_grid = false; } int64_t N = theta->shape_view().At(0); int64_t theta_h = theta->shape_view().At(1); int64_t theta_w = theta->shape_view().At(2); auto matmul = NewAffineGridMatmulPrimitive(ctx); CHECK(matmul); if (is_2d_grid) { int64_t H = size.At(2); int64_t W = size.At(3); // generate base grid GenerateBaseGridImp::Generate2D(ctx, tmp_buffer->mut_dptr(), H, W, align_corners); // Compute each batch for (int n = 0; n < N; n++) { matmul->Launch(ctx->stream(), H * W, theta_h, theta_w, /*alpha=*/1.0, tmp_buffer->dptr(), theta->dptr() + n * theta_h * theta_w, /*beta=*/0.0, grid->mut_dptr() + n * theta_h * H * W); } } else { int64_t D = size.At(2); int64_t H = size.At(3); int64_t W = size.At(4); // generate base grid GenerateBaseGridImp::Generate3D(ctx, tmp_buffer->mut_dptr(), D, H, W, align_corners); // Compute each batch for (int n = 0; n < N; n++) { matmul->Launch(ctx->stream(), D * H * W, theta_h, theta_w, /*alpha=*/1.0, tmp_buffer->dptr(), theta->dptr() + n * theta_h * theta_w, /*beta=*/0.0, grid->mut_dptr() + n * theta_h * D * H * W); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_AFFINE_GRID_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("affine_grid") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("theta", 0) == GetDataType::value) \ && AffineGridMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const Shape& size = ctx->Attr("size"); \ size_t tmp_buffer_size = size.Count(2) * (size.NumAxes() - 1) * sizeof(dtype); \ return tmp_buffer_size; \ }) 
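// Worked example for the 2-D case of the registration above: with theta of
// shape (N, 2, 3) and size = (N, C, H, W), SetInferTmpSizeFn reserves
// size.Count(2) * (size.NumAxes() - 1) = H * W * 3 elements, i.e. one
// (x, y, 1) row of the base grid per output position. Each per-batch matmul
// then computes grid[n] = base_grid (H*W x 3) * theta[n]^T (3 x 2), which is
// exactly the (N, H, W, 2) sampling grid; the 3-D case works the same way
// with (x, y, z, 1) rows and a (N, 3, 4) theta.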
REGISTER_AFFINE_GRID_KERNEL(DeviceType::kCPU, float); REGISTER_AFFINE_GRID_KERNEL(DeviceType::kCPU, double); #ifdef WITH_CUDA REGISTER_AFFINE_GRID_KERNEL(DeviceType::kCUDA, float); REGISTER_AFFINE_GRID_KERNEL(DeviceType::kCUDA, double); #endif template class AffineGridGradKernel final : public user_op::OpKernel { public: AffineGridGradKernel() = default; ~AffineGridGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex("dgrid", 0); user_op::Tensor* dtheta = ctx->Tensor4ArgNameAndIndex("dtheta", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const Shape& size = ctx->Attr("size"); const bool& align_corners = ctx->Attr("align_corners"); bool is_2d_grid = true; if (size.NumAxes() == 5) { is_2d_grid = false; } int64_t N = dtheta->shape_view().At(0); int64_t dtheta_h = dtheta->shape_view().At(1); int64_t dtheta_w = dtheta->shape_view().At(2); auto matmul = NewAffineGridGradMatmulPrimitive(ctx); CHECK(matmul); if (is_2d_grid) { int64_t H = size.At(2); int64_t W = size.At(3); // generate base grid GenerateBaseGridImp::Generate2D(ctx, tmp_buffer->mut_dptr(), H, W, align_corners); // Compute each batch for (int n = 0; n < N; n++) { matmul->Launch(ctx->stream(), dtheta_h, dtheta_w, H * W, /*alpha=*/1.0, dgrid->dptr() + n * dtheta_h * H * W, tmp_buffer->dptr(), /*beta=*/0.0, dtheta->mut_dptr() + n * dtheta_h * dtheta_w); } } else { int64_t D = size.At(2); int64_t H = size.At(3); int64_t W = size.At(4); GenerateBaseGridImp::Generate3D(ctx, tmp_buffer->mut_dptr(), D, H, W, align_corners); // Compute each batch for (int n = 0; n < N; n++) { matmul->Launch(ctx->stream(), dtheta_h, dtheta_w, D * H * W, /*alpha=*/1.0, dgrid->dptr() + n * dtheta_h * D * H * W, tmp_buffer->dptr(), /*beta=*/0.0, dtheta->mut_dptr() + n * dtheta_h * dtheta_w); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_AFFINE_GRID_GRAD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("affine_grid_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dgrid", 0) == GetDataType::value) \ && AffineGridGradMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const Shape& size = ctx->Attr("size"); \ size_t tmp_buffer_size = size.Count(2) * (size.NumAxes() - 1) * sizeof(dtype); \ return tmp_buffer_size; \ }) REGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCPU, float); REGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCPU, double); #ifdef WITH_CUDA REGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCUDA, float); REGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCUDA, double); #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/affine_grid_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/device/cuda_util.h" #include "affine_grid_kernel.h" namespace oneflow { namespace { template OF_DEVICE_FUNC data_type LinspaceGPU(int32_t index, int32_t num_steps) { if (num_steps <= 1) { return static_cast(0.0); } if (align_corners) { return static_cast(-1.0 + 2.0 / (num_steps - 1) * index); } else { return static_cast((-1.0 + 2.0 / (num_steps - 1) * index) * (num_steps - 1) / num_steps); } } template __global__ void Generate2DBaseGridGPUKernel(const int32_t nthreads, data_type* grid_ptr, int32_t H, int32_t W) { CUDA_1D_KERNEL_LOOP(index, nthreads) { const int32_t h = index / W; const int32_t w = index % W; const int32_t pixel_length = 3; data_type* row_ptr = grid_ptr + h * W * pixel_length; data_type* pixel_ptr = row_ptr + w * pixel_length; data_type h_value = LinspaceGPU(h, H); data_type w_value = LinspaceGPU(w, W); pixel_ptr[0] = w_value; pixel_ptr[1] = h_value; pixel_ptr[2] = static_cast(1.0); } } template __global__ void Generate3DBaseGridGPUKernel(const int32_t nthreads, data_type* grid_ptr, int32_t D, int32_t H, int32_t W) { CUDA_1D_KERNEL_LOOP(index, nthreads) { const int32_t d = index / H; const int32_t h = index % H; const int32_t pixel_length = 4; data_type* image_ptr = grid_ptr + d * H * W * pixel_length; data_type* row_ptr = image_ptr + h * W * pixel_length; data_type d_value = LinspaceGPU(d, D); data_type h_value = LinspaceGPU(h, H); for (int32_t w = 0; w < W; ++w) { data_type* pixel_ptr = row_ptr + w * pixel_length; data_type w_value = LinspaceGPU(w, W); pixel_ptr[0] = w_value; pixel_ptr[1] = h_value; pixel_ptr[2] = d_value; pixel_ptr[3] = static_cast(1.0); } } } } // namespace void GenerateBaseGridImp::Generate2D(user_op::KernelComputeContext* ctx, float* grid_ptr, int64_t H, int64_t W, bool align_corners) { int count = H * W; if (align_corners) { RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel), ctx->stream(), count, count, grid_ptr, H, W); } else { RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel), ctx->stream(), count, count, grid_ptr, H, W); } } void GenerateBaseGridImp::Generate2D(user_op::KernelComputeContext* ctx, double* grid_ptr, int64_t H, int64_t W, bool align_corners) { int count = H * W; if (align_corners) { RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel), ctx->stream(), count, count, grid_ptr, H, W); } else { RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel), ctx->stream(), count, count, grid_ptr, H, W); } } void GenerateBaseGridImp::Generate3D(user_op::KernelComputeContext* ctx, float* grid_ptr, int64_t D, int64_t H, int64_t W, bool align_corners) { int count = D * H; if (align_corners) { RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel), ctx->stream(), count, count, grid_ptr, D, H, W); } else { RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel), ctx->stream(), count, count, grid_ptr, D, H, W); } } void GenerateBaseGridImp::Generate3D(user_op::KernelComputeContext* ctx, double* grid_ptr, int64_t D, int64_t H, int64_t W, bool align_corners) { int count = D * H; if (align_corners) { RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel), ctx->stream(), count, count, grid_ptr, D, H, W); } else { RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel), ctx->stream(), count, count, grid_ptr, D, H, W); } } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/affine_grid_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_ #define _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_ #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/common/device_type.h" namespace oneflow { template struct GenerateBaseGridImp {}; template<> struct GenerateBaseGridImp { template static void Linspace(std::vector& grid, int64_t num_steps, bool align_corners) { if (num_steps <= 1) { for (auto& it : grid) { it = static_cast(0.0); } return; } if (align_corners) { for (int i = 0; i < num_steps; i++) { grid[i] = static_cast(-1.0 + 2.0 / (num_steps - 1) * i); } } else { for (int i = 0; i < num_steps; i++) { grid[i] = static_cast((-1.0 + 2.0 / (num_steps - 1) * i) * (num_steps - 1) / num_steps); } } } template static void Generate2D(user_op::KernelComputeContext*, data_type* grid_ptr, int64_t H, int64_t W, bool align_corners) { std::vector w_step(W); std::vector h_step(H); Linspace(w_step, W, align_corners); Linspace(h_step, H, align_corners); for (int h = 0; h < H; h++) { data_type* row_ptr = grid_ptr + h * W * 3; for (int w = 0; w < W; w++) { data_type* pixel_ptr = row_ptr + w * 3; pixel_ptr[0] = w_step[w]; pixel_ptr[1] = h_step[h]; pixel_ptr[2] = static_cast(1.0); } } } template static void Generate3D(user_op::KernelComputeContext*, data_type* grid_ptr, int64_t D, int64_t H, int64_t W, bool align_corners) { std::vector w_step(W); std::vector h_step(H); std::vector d_step(D); Linspace(w_step, W, align_corners); Linspace(h_step, H, align_corners); Linspace(d_step, D, align_corners); for (int d = 0; d < D; d++) { data_type* image_ptr = grid_ptr + d * H * W * 4; for (int h = 0; h < H; h++) { data_type* row_ptr = image_ptr + h * W * 4; for (int w = 0; w < W; w++) { data_type* pixel_ptr = row_ptr + w * 4; pixel_ptr[0] = w_step[w]; pixel_ptr[1] = h_step[h]; pixel_ptr[2] = d_step[d]; pixel_ptr[3] = static_cast(1.0); } } } } }; template<> struct GenerateBaseGridImp { static void Generate2D(user_op::KernelComputeContext* ctx, float* grid_ptr, int64_t H, int64_t W, bool align_corners); static void Generate2D(user_op::KernelComputeContext* ctx, double* grid_ptr, int64_t H, int64_t W, bool align_corners); static void Generate3D(user_op::KernelComputeContext* ctx, float* grid_ptr, int64_t D, int64_t H, int64_t W, bool align_corners); static void Generate3D(user_op::KernelComputeContext* ctx, double* grid_ptr, int64_t D, int64_t H, int64_t W, bool align_corners); }; } // namespace oneflow #endif // _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_ ================================================ FILE: oneflow/user/kernels/arange_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/arange_kernel_util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace user_op { class ArangeOpKernelCache final : public user_op::OpKernelCache { public: ArangeOpKernelCache(int32_t lower, int32_t upper) : lower_(lower), upper_(upper) {} ~ArangeOpKernelCache() override = default; int32_t lower() const { return lower_; } int32_t upper() const { return upper_; } private: const int32_t lower_; const int32_t upper_; }; template class ArangeKernel final : public OpKernel, public CudaGraphSupport { public: ArangeKernel() = default; ~ArangeKernel() = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { DataType dtype = ctx->Attr("dtype"); int64_t range_elem_cnt = 0; int64_t parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { if (IsIntegralDataType(dtype)) { int64_t integer_delta = ctx->Attr("integer_delta"); int64_t integer_start = ctx->Attr("integer_start"); int64_t integer_limit = ctx->Attr("integer_limit"); range_elem_cnt = std::ceil(static_cast(integer_limit - integer_start) / integer_delta); } else { double float_delta = ctx->Attr("float_delta"); double float_start = ctx->Attr("float_start"); double float_limit = ctx->Attr("float_limit"); range_elem_cnt = std::ceil(static_cast(float_limit - float_start) / float_delta); } const Shape& logical_shape = Shape({range_elem_cnt}); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); TensorSliceView view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); std::shared_ptr cache( new ArangeOpKernelCache(view.At(0).begin(), view.At(0).end())); return cache; } else { return nullptr; } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); T* output = out->mut_dptr(); const DataType dtype = ctx->Attr("dtype"); int64_t arange_elem_cnt = 0; T start = static_cast(0.0); T delta = static_cast(0.0); T limit = static_cast(0.0); if (IsIntegralDataType(dtype)) { start = static_cast(static_cast(ctx->Attr("integer_start"))); delta = static_cast(static_cast(ctx->Attr("integer_delta"))); limit = static_cast(static_cast(ctx->Attr("integer_limit"))); arange_elem_cnt = std::ceil(static_cast(limit - start) / static_cast(delta)); } else { // If we use static_cast(start, delta, limit) and std::ceil to calculate arange_elem_cnt, // it will cause rounding error. 
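// Worked example: float_start = 0.0, float_limit = 1.0, float_delta = 0.3
// yields ceil((1.0 - 0.0) / 0.3) = 4 elements (0.0, 0.3, 0.6, 0.9). The count
// is therefore derived from the double-typed attributes before start, delta
// and limit are cast to T, so it stays consistent with the logical element
// count computed in InitOpKernelCache even when T is float16 or float.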
double float_start = ctx->Attr("float_start"); double float_delta = ctx->Attr("float_delta"); double float_limit = ctx->Attr("float_limit"); arange_elem_cnt = std::ceil(static_cast(float_limit - float_start) / float_delta); start = static_cast(float_start); delta = static_cast(float_delta); limit = static_cast(float_limit); } if (arange_elem_cnt == 0) { return; } if (cache == nullptr) { ArangeFunctor()(ctx->stream(), start, delta, arange_elem_cnt, output); } else { const auto* arange_cache = dynamic_cast(cache); auto arange_len = arange_cache->upper() - arange_cache->lower(); auto lower = static_cast(static_cast(arange_cache->lower())); ArangeFunctor()(ctx->stream(), static_cast(start + delta * lower), delta, arange_len, output); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_ARANGE_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("arange").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobAttr("dtype") == GetDataType::value)); #define REGISTER_ARANGE_KERNELS_WITH_DEVICE(device) \ REGISTER_ARANGE_KERNEL(device, uint8_t) \ REGISTER_ARANGE_KERNEL(device, int8_t) \ REGISTER_ARANGE_KERNEL(device, int32_t) \ REGISTER_ARANGE_KERNEL(device, int64_t) \ REGISTER_ARANGE_KERNEL(device, float) \ REGISTER_ARANGE_KERNEL(device, double) #define REGISTER_ARANGE_KERNELS_WITH_CUDA_HALF(device) REGISTER_ARANGE_KERNEL(device, half) // Register CPU version REGISTER_ARANGE_KERNELS_WITH_DEVICE(DeviceType::kCPU); REGISTER_ARANGE_KERNEL(DeviceType::kCPU, float16); // Register GPU version #ifdef WITH_CUDA REGISTER_ARANGE_KERNELS_WITH_DEVICE(DeviceType::kCUDA); REGISTER_ARANGE_KERNELS_WITH_CUDA_HALF(DeviceType::kCUDA); #endif } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/arange_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/arange_kernel_util.h" namespace oneflow { namespace user_op { template struct ArangeFunctor final { void operator()(ep::Stream* stream, const T start, const T delta, const int64_t arange_elem_cnt, T* out) { DoArange(start, delta, arange_elem_cnt, out); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARANGE_FUNCTOR, (DeviceType::kCPU), ARANGE_DATA_TYPE_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARANGE_FUNCTOR, (DeviceType::kCPU), FLOAT16_DATA_TYPE_SEQ); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/arange_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/arange_kernel_util.h" namespace oneflow { namespace user_op { template __global__ void ArangeForwardGpuKernel(const T start, const T delta, const int64_t arange_elem_cnt, T* out) { // Use Loop to set the value DoArange(start, delta, arange_elem_cnt, out); } template<> __global__ void ArangeForwardGpuKernel(const half start, const half delta, const int64_t arange_elem_cnt, half* out) { // Use Loop to set the value XPU_1D_KERNEL_LOOP(i, arange_elem_cnt) { out[i] = start + static_cast(static_cast(i)) * delta; } } template struct ArangeFunctor final { void operator()(ep::Stream* stream, const T start, const T delta, const int64_t arange_elem_cnt, T* out) { // The thread num is set as arange_elem_cnt RUN_CUDA_KERNEL((ArangeForwardGpuKernel), stream, arange_elem_cnt, start, delta, arange_elem_cnt, out); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARANGE_FUNCTOR, (DeviceType::kCUDA), ARANGE_DATA_TYPE_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARANGE_FUNCTOR, (DeviceType::kCUDA), HALF_DATA_TYPE_SEQ); } // namespace user_op } // namespace oneflow #endif // End WITH_CUDA ================================================ FILE: oneflow/user/kernels/arange_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_ARANGE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_ARANGE_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ndarray/xpu_util.h" namespace oneflow { #define ARANGE_DATA_TYPE_SEQ \ FLOATING_DATA_TYPE_SEQ \ INT_DATA_TYPE_SEQ \ UNSIGNED_INT_DATA_TYPE_SEQ namespace user_op { template struct ArangeFunctor final { void operator()(ep::Stream* stream, const T start, const T delta, const int64_t arange_elem_cnt, T* out); }; template OF_DEVICE_FUNC void DoArange(const T start, const T delta, const int64_t arange_elem_cnt, T* out) { XPU_1D_KERNEL_LOOP(i, arange_elem_cnt) { out[i] = start + i * delta; } } #define INSTANTIATE_ARANGE_FUNCTOR(device_type_v, dtype_pair) \ template struct ArangeFunctor; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_ARANGE_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/arg_sort_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { template class CpuArgSortKernel final : public user_op::OpKernel { public: CpuArgSortKernel() = default; ~CpuArgSortKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1); const int32_t instance_num = in->shape_view().elem_cnt() / instance_size; const std::string& direction = ctx->Attr("direction"); const bool is_ascending = direction == "ASCENDING"; const bool is_descending = direction == "DESCENDING"; FOR_RANGE(int32_t, i, 0, instance_num) { const T* in_ptr_i = in->dptr() + i * instance_size; int32_t* out_ptr_i = out->mut_dptr() + i * instance_size; std::iota(out_ptr_i, out_ptr_i + instance_size, 0); auto comp = [&](const int32_t lhs, const int32_t rhs) { const T l = in_ptr_i[lhs]; const T r = in_ptr_i[rhs]; if (l == r) { return lhs < rhs; } else { if (is_ascending) { return l < r; } else if (is_descending) { return l > r; } else { LOG(FATAL) << "expected the input direction parameter value is \"ASCENDING\" or " "\"DESCENDING\", " << "but found the value is " << "\"" << direction << "\""; } } }; std::sort(out_ptr_i, out_ptr_i + instance_size, comp); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_ARG_SORT_KERNEL(dtype) \ REGISTER_USER_KERNEL("arg_sort") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_CPU_ARG_SORT_KERNEL(float) REGISTER_CPU_ARG_SORT_KERNEL(double) REGISTER_CPU_ARG_SORT_KERNEL(bool) REGISTER_CPU_ARG_SORT_KERNEL(int8_t) REGISTER_CPU_ARG_SORT_KERNEL(uint8_t) REGISTER_CPU_ARG_SORT_KERNEL(int32_t) REGISTER_CPU_ARG_SORT_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/arg_sort_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
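// Overview of the temporary-buffer layout used in this file: TmpBufferManager
// carves tmp_buffer into three consecutive, CUDA-aligned regions
//   [ sorted_in: elem_cnt * sizeof(T) | indices: elem_cnt * sizeof(int32_t) | sort temp storage ]
// and SetInferTmpSizeFn in the kernel registration reproduces the same sizes,
// asking InferTempStorageForSortPairsAscending/Descending how much scratch the
// row-wise sort of instance_num rows of length instance_size requires.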
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/radix_sort.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template class TmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager); TmpBufferManager(int32_t capacity, void* ptr, const ShapeView& in_shape) : capacity_{capacity}, sorted_in_elem_cnt_{in_shape.elem_cnt()}, indices_elem_cnt_{sorted_in_elem_cnt_} { const int32_t sorted_in_aligned_bytes = GetCudaAlignedSize(sorted_in_elem_cnt_ * sizeof(T)); const int32_t indices_aligned_bytes = GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int32_t)); sorted_in_ptr_ = reinterpret_cast(ptr); indices_ptr_ = reinterpret_cast(reinterpret_cast(sorted_in_ptr_) + sorted_in_aligned_bytes); temp_storage_ptr_ = reinterpret_cast(reinterpret_cast(indices_ptr_) + indices_aligned_bytes); temp_storage_bytes_ = capacity_ - sorted_in_aligned_bytes - indices_aligned_bytes; CHECK_GE(temp_storage_bytes_, 0); } ~TmpBufferManager() = default; T* SortedInPtr() const { return sorted_in_ptr_; } int32_t* IndicesPtr() const { return indices_ptr_; } void* TempStoragePtr() const { return temp_storage_ptr_; } int32_t TempStorageBytes() const { return temp_storage_bytes_; } private: int32_t capacity_; T* sorted_in_ptr_; int32_t* indices_ptr_; void* temp_storage_ptr_; int64_t sorted_in_elem_cnt_; int64_t indices_elem_cnt_; int32_t temp_storage_bytes_; }; __global__ void InitializeIndices(int32_t elem_cnt, int32_t* indices_ptr, int32_t instance_size) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i % instance_size; }; } } // namespace template class GpuArgSortKernel final : public user_op::OpKernel { public: GpuArgSortKernel() = default; ~GpuArgSortKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buf_manager(static_cast(tmp_buffer->shape_view().elem_cnt()), tmp_buffer->mut_dptr(), in->shape_view()); const int32_t elem_cnt = in->shape_view().elem_cnt(); const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1); const int32_t instance_num = elem_cnt / instance_size; const std::string& direction = ctx->Attr("direction"); InitializeIndices<<stream()->As()->cuda_stream()>>>( elem_cnt, buf_manager.IndicesPtr(), instance_size); if (direction == "ASCENDING") { SortPairsAscending(in->dptr(), buf_manager.IndicesPtr(), instance_num, instance_size, buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(), buf_manager.SortedInPtr(), out->mut_dptr(), ctx->stream()->As()->cuda_stream()); } else if (direction == "DESCENDING") { SortPairsDescending(in->dptr(), buf_manager.IndicesPtr(), instance_num, instance_size, buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(), buf_manager.SortedInPtr(), out->mut_dptr(), ctx->stream()->As()->cuda_stream()); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_ARG_SORT_KERNEL(dtype) \ REGISTER_USER_KERNEL("arg_sort") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("in", 0); \ 
const int32_t elem_cnt = in_shape.elem_cnt(); \ const int32_t instance_size = in_shape.dim_vec().back(); \ const int32_t instance_num = elem_cnt / instance_size; \ \ /* Sorted In */ \ const int32_t sorted_in_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(dtype)); \ /* Indices */ \ const int32_t indices_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(int32_t)); \ /* CUB Temp Storage */ \ int32_t temp_storage_bytes = -1; \ const std::string& direction = ctx->Attr("direction"); \ if (direction == "ASCENDING") { \ temp_storage_bytes = \ InferTempStorageForSortPairsAscending(instance_num, instance_size); \ } else if (direction == "DESCENDING") { \ temp_storage_bytes = \ InferTempStorageForSortPairsDescending(instance_num, instance_size); \ } else { \ UNIMPLEMENTED(); \ } \ \ return sorted_in_aligned_bytes + indices_aligned_bytes + temp_storage_bytes; \ }); REGISTER_CUDA_ARG_SORT_KERNEL(float) REGISTER_CUDA_ARG_SORT_KERNEL(double) REGISTER_CUDA_ARG_SORT_KERNEL(bool) REGISTER_CUDA_ARG_SORT_KERNEL(int8_t) REGISTER_CUDA_ARG_SORT_KERNEL(uint8_t) REGISTER_CUDA_ARG_SORT_KERNEL(int32_t) REGISTER_CUDA_ARG_SORT_KERNEL(int64_t) REGISTER_CUDA_ARG_SORT_KERNEL(half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/arg_where_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/data_type_seq.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/user/kernels/arg_where_kernel_util.h" namespace oneflow { namespace { template class ArgWhereKernel final : public user_op::OpKernel { public: ArgWhereKernel() = default; ~ArgWhereKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { int64_t ndims = ctx->Tensor4ArgNameAndIndex("input", 0)->shape_view().NumAxes(); if (ndims == 0) { // 0-dim tensor, elem_cnt of input is 1 CHECK_EQ(ctx->Tensor4ArgNameAndIndex("input", 0)->shape_view().elem_cnt(), 1); SetOutputSize( ctx->stream(), ctx->Tensor4ArgNameAndIndex("input", 0)->dptr(), ctx->Tensor4ArgNameAndIndex("output_size", 0)->mut_dptr()); return; } SwitchNdimCompute(SwitchCase(ndims), ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } #define COMPUTE_SWITCH_ENTRY(func_name, ndim) func_name DEFINE_STATIC_SWITCH_FUNC(void, NdimCompute, COMPUTE_SWITCH_ENTRY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ)); #undef COMPUTE_SWITCH_ENTRY template static void NdimCompute(user_op::KernelComputeContext* ctx) { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); user_op::Tensor* output_size = ctx->Tensor4ArgNameAndIndex("output_size", 0); user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); void* tmp_ptr = tmp ? tmp->mut_dptr() : nullptr; size_t tmp_size = tmp ? 
tmp->shape_view().elem_cnt() * GetSizeOfDataType(tmp->data_type()) : 0; ArgWhereKernelUtil::ArgWhere( ctx->stream(), input->shape_view(), input->dptr(), tmp_ptr, tmp_size, output->mut_dptr(), output_size->mut_dptr()); } }; template size_t GetWorkspaceBytesSize(int64_t elem_cnt) { return ArgWhereKernelUtil::GetWorkspaceBytesSize(nullptr, elem_cnt); } template struct SwitchUtil; template<> struct SwitchUtil { #define SWITCH_ENTRY(func_name, device, itype, otype, ndim) func_name DEFINE_STATIC_SWITCH_FUNC( size_t, GetWorkspaceBytesSize, SWITCH_ENTRY, MAKE_DEVICE_TYPE_CTRV_SEQ(OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU)), MAKE_DATA_TYPE_CTRV_SEQ(ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ), MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ), MAKE_NDIM_CTRV_SEQ(DIM_SEQ)); #undef SWITCH_ENTRY }; #ifdef WITH_CUDA template<> struct SwitchUtil { #define SWITCH_ENTRY(func_name, device, itype, otype, ndim) func_name DEFINE_STATIC_SWITCH_FUNC( size_t, GetWorkspaceBytesSize, SWITCH_ENTRY, MAKE_DEVICE_TYPE_CTRV_SEQ(OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA)), MAKE_DATA_TYPE_CTRV_SEQ(ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ), MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ), MAKE_NDIM_CTRV_SEQ(DIM_SEQ)); #undef SWITCH_ENTRY }; #endif // WITH_CUDA template size_t InferTempStorageBytesSize(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); if (input_shape.NumAxes() == 0) { return 0; } DataType input_dtype = ctx->InputDType("input", 0); DataType output_dtype = ctx->OutputDType("output", 0); return SwitchUtil::SwitchGetWorkspaceBytesSize( SwitchCase(device_type, input_dtype, output_dtype, input_shape.NumAxes()), input_shape.elem_cnt()); } } // namespace #define REGISTER_ARG_WHERE_KERNEL(device, itype, otype) \ REGISTER_USER_KERNEL("argwhere") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("output", 0) == GetDataType::value) \ && (user_op::HobDataType("output_size", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferTempStorageBytesSize); #define REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR(device, itype_pair, otype_pair) \ REGISTER_ARG_WHERE_KERNEL(device, OF_PP_PAIR_FIRST(itype_pair), OF_PP_PAIR_FIRST(otype_pair)) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR, (DeviceType::kCPU), FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR, (DeviceType::kCUDA), HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/arg_where_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
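SwitchNdimCompute and SwitchGetWorkspaceBytesSize above come from DEFINE_STATIC_SWITCH_FUNC, which bridges runtime values (rank, dtype, device) to compile-time template parameters. A hand-written sketch of the simplest case, dispatching on the axis count alone (only a few ranks shown; the real DIM_SEQ enumerates the supported ones):

#include <cstdint>
#include <cstdio>
#include <stdexcept>

// The templated worker wants NDIM as a compile-time constant.
template<int NDIM>
void NdimCompute() {
  std::printf("computing with NDIM = %d\n", NDIM);
}

// Runtime ndims -> compile-time NDIM; the macro generates this switch from DIM_SEQ.
void SwitchNdimCompute(int64_t ndims) {
  switch (ndims) {
    case 1: return NdimCompute<1>();
    case 2: return NdimCompute<2>();
    case 3: return NdimCompute<3>();
    case 4: return NdimCompute<4>();
    default: throw std::runtime_error("unsupported ndims");
  }
}

int main() { SwitchNdimCompute(3); }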
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/arg_where_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct ArgWhereKernelUtil { static void ArgWhere(ep::Stream* stream, const ShapeView& input_shape, const IN_T* input_ptr, void* temp_storage, size_t temp_storage_bytes, OUT_T* output_ptr, OUT_T* output_size_ptr) { // deal with empty blob if (input_shape.elem_cnt() == 0) { Memset(stream, output_size_ptr, 0, sizeof(OUT_T)); return; } const int64_t elem_cnt = input_shape.elem_cnt(); CHECK_LE(elem_cnt, std::numeric_limits::max()); OUT_T true_cnt = 0; OUT_T dims[NDIM] = {0}; std::transform(input_shape.ptr(), input_shape.ptr() + input_shape.NumAxes(), dims, [](int64_t dim) { return static_cast(dim); }); NdIndexOffsetHelper index_converter(dims); FOR_RANGE(int64_t, i, 0, elem_cnt) { if (static_cast(input_ptr[i])) { index_converter.OffsetToNdIndex(i, output_ptr + true_cnt * NDIM); true_cnt += 1; } } *output_size_ptr = true_cnt; } static size_t GetWorkspaceBytesSize(ep::Stream* stream, int64_t elem_cnt) { return 0; } }; INSTANTIATE_ARG_WHERE_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCPU) #define INSTANTIATE_CPU_FLOAT16_ARG_WHERE_KERNEL_UTIL \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARG_WHERE_KERNEL_UTIL_WITH_DTYPE_PAIR, \ (DeviceType::kCPU), FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, \ DIM_SEQ) INSTANTIATE_CPU_FLOAT16_ARG_WHERE_KERNEL_UTIL template void SetOutputSize(ep::Stream* stream, const IN_T* input_ptr, OUT_T* output_size_ptr) { if (*input_ptr == GetZeroVal()) { *output_size_ptr = GetZeroVal(); } else { *output_size_ptr = GetOneVal(); } } INSTANTIATE_SET_OUTPUT_SIZE_FOR_DEVICE(DeviceType::kCPU) #define INSTANTIATE_CPU_FLOAT16_SET_OUTPUT_SIZE \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SET_OUTPUT_SIZE_WITH_DTYPE_PAIR, \ (DeviceType::kCPU), FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) INSTANTIATE_CPU_FLOAT16_SET_OUTPUT_SIZE } // namespace oneflow ================================================ FILE: oneflow/user/kernels/arg_where_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
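The CPU ArgWhere above scans the flattened input once, converts each nonzero element's flat offset back into an n-d index, and appends that index row to the output while counting matches. A standalone sketch of the same loop for a fixed rank of 2, with plain vectors standing in for tensors:

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// arg_where for a row-major 2-D array: one (row, col) index pair per nonzero element.
std::vector<std::array<int64_t, 2>> ArgWhere2D(const std::vector<int>& data, int64_t rows,
                                               int64_t cols) {
  std::vector<std::array<int64_t, 2>> out;
  for (int64_t offset = 0; offset < rows * cols; ++offset) {
    if (data[offset] != 0) {
      out.push_back({offset / cols, offset % cols});  // flat offset -> n-d index
    }
  }
  return out;  // out.size() plays the role of the output_size tensor
}

int main() {
  const std::vector<int> data = {0, 5, 0, 0, 0, 7};  // 2 x 3 input, nonzeros at (0,1) and (1,2)
  for (const auto& idx : ArgWhere2D(data, 2, 3)) {
    std::printf("(%lld, %lld)\n", static_cast<long long>(idx[0]), static_cast<long long>(idx[1]));
  }
}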
*/ #include "oneflow/user/kernels/arg_where_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { constexpr int kBlockSize = cuda::elementwise::kBlockSize; int GetNumBlocks(int64_t elem_cnt) { int num_blocks = 0; OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks)); return num_blocks; } template struct StrideIterator { typedef StrideIterator self_type; typedef std::ptrdiff_t difference_type; typedef T value_type; typedef T* pointer; typedef T& reference; typedef std::random_access_iterator_tag iterator_category; explicit StrideIterator(T* ptr, size_t max_iters) : ptr_(ptr), max_iters_(max_iters) {} OF_DEVICE_FUNC reference operator[](int i) { assert(0 <= i && i < max_iters_); return *(ptr_ + (i * NDIM)); } private: T* ptr_; size_t max_iters_; }; template __global__ void __launch_bounds__(kBlockSize) CudaOffsetToNdIndexInplace(NdIndexOffsetHelper index_converter, const T* output_size_ptr, T* output_ptr) { CUDA_1D_KERNEL_LOOP_T(T, i, *output_size_ptr) { T* index_ptr = output_ptr + i * NDIM; index_converter.OffsetToNdIndex(*index_ptr, index_ptr); } } template struct IsTrue { __device__ __forceinline__ bool operator()(const T& val) const { return static_cast(val); } }; template cudaError_t SelectTrue(cudaStream_t stream, int num_items, void* temp_storage, size_t& temp_storage_bytes, const IN_T* input, OUT_ITER output_iter, OUT_T* num_selected) { IsTrue is_true; cub::TransformInputIterator, const IN_T*> flag_iter(input, is_true); cub::CountingInputIterator offset_counter(0); return cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, offset_counter, flag_iter, output_iter, num_selected, num_items, stream); } template __global__ void SetOutputSizeKernel(const IN_T* input_ptr, OUT_T* output_size_ptr) { if (*input_ptr == GetZeroVal()) { *output_size_ptr = GetZeroVal(); } else { *output_size_ptr = GetOneVal(); } } } // namespace template struct ArgWhereKernelUtil { static void ArgWhere(ep::Stream* stream, const ShapeView& input_shape, const IN_T* input_ptr, void* temp_storage, size_t temp_storage_bytes, OUT_T* output_ptr, OUT_T* output_size_ptr) { const int64_t elem_cnt = input_shape.elem_cnt(); // deal with empty blob if (elem_cnt == 0) { Memset(stream, output_size_ptr, 0, sizeof(OUT_T)); return; } CHECK_NOTNULL(stream); CHECK_LE(elem_cnt, std::numeric_limits::max()); size_t workspace = GetWorkspaceBytesSize(stream, elem_cnt); CHECK_LE(workspace, temp_storage_bytes); if (NDIM == 1) { OF_CUDA_CHECK((SelectTrue( stream->As()->cuda_stream(), input_shape.elem_cnt(), temp_storage, workspace, input_ptr, output_ptr, output_size_ptr))); } else { using OutputIterator = StrideIterator; OutputIterator output_iter(output_ptr, elem_cnt); OF_CUDA_CHECK((SelectTrue( stream->As()->cuda_stream(), elem_cnt, temp_storage, workspace, input_ptr, output_iter, output_size_ptr))); OUT_T dims[NDIM] = {0}; std::transform(input_shape.ptr(), input_shape.ptr() + input_shape.NumAxes(), dims, [](int64_t dim) { return static_cast(dim); }); NdIndexOffsetHelper index_converter(dims); CudaOffsetToNdIndexInplace <<As()->cuda_stream()>>>( index_converter, output_size_ptr, output_ptr); } } static size_t GetWorkspaceBytesSize(ep::Stream* stream, int64_t elem_cnt) { cudaStream_t cuda_stream = stream ? 
stream->As()->cuda_stream() : 0; size_t workspace = 0; if (NDIM == 1) { OF_CUDA_CHECK((SelectTrue(cuda_stream, elem_cnt, nullptr, workspace, nullptr, nullptr, nullptr))); } else { using OutputIterator = StrideIterator; OutputIterator output_iter(nullptr, elem_cnt); OF_CUDA_CHECK((SelectTrue( cuda_stream, elem_cnt, nullptr, workspace, nullptr, output_iter, nullptr))); } return workspace; } }; INSTANTIATE_ARG_WHERE_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCUDA) #define INSTANTIATE_CUDA_HALF_ARG_WHERE_KERNEL_UTIL \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARG_WHERE_KERNEL_UTIL_WITH_DTYPE_PAIR, \ (DeviceType::kCUDA), HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, \ DIM_SEQ) INSTANTIATE_CUDA_HALF_ARG_WHERE_KERNEL_UTIL template void SetOutputSize(ep::Stream* stream, const IN_T* input_ptr, OUT_T* output_size_ptr) { SetOutputSizeKernel <<<1, 1, 0, stream->As()->cuda_stream()>>>(input_ptr, output_size_ptr); } INSTANTIATE_SET_OUTPUT_SIZE_FOR_DEVICE(DeviceType::kCUDA) #define INSTANTIATE_CUDA_HALF_SET_OUTPUT_SIZE \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SET_OUTPUT_SIZE_WITH_DTYPE_PAIR, \ (DeviceType::kCUDA), HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) INSTANTIATE_CUDA_HALF_SET_OUTPUT_SIZE } // namespace oneflow ================================================ FILE: oneflow/user/kernels/arg_where_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
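GetWorkspaceBytesSize above leans on the standard CUB convention: when the temp-storage pointer is null, cub::DeviceSelect::Flagged only writes the required byte count and performs no work, so the same SelectTrue helper serves both the size query and the actual launch. A schematic of that two-phase convention in plain C++ (SelectTrueLike is a made-up stand-in, not the CUB API):

#include <cstddef>
#include <cstdio>
#include <vector>

// Made-up stand-in for a CUB-style primitive: pass temp_storage == nullptr to query the size.
int SelectTrueLike(void* temp_storage, size_t& temp_storage_bytes, int num_items) {
  if (temp_storage == nullptr) {
    temp_storage_bytes = static_cast<size_t>(num_items) * sizeof(int);  // pretend workspace need
    return 0;  // size-query call: no work performed
  }
  // ... the real selection would run here, using the provided workspace ...
  return 0;
}

int main() {
  size_t bytes = 0;
  SelectTrueLike(nullptr, bytes, 1024);         // phase 1: ask how much scratch space is needed
  std::vector<unsigned char> workspace(bytes);  // phase 2: allocate it and run for real
  SelectTrueLike(workspace.data(), bytes, 1024);
  std::printf("workspace bytes: %zu\n", bytes);
}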
*/ #ifndef ONEFLOW_USER_KERNELS_ARG_WHERE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_ARG_WHERE_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/common/shape_view.h" namespace oneflow { template struct ArgWhereKernelUtil { static void ArgWhere(ep::Stream* stream, const ShapeView& input_shape, const IN_T* input_ptr, void* temp_storage, size_t temp_storage_bytes, OUT_T* output_ptr, OUT_T* output_size_ptr); static size_t GetWorkspaceBytesSize(ep::Stream* stream, int64_t elem_cnt); }; #define INSTANTIATE_ARG_WHERE_KERNEL_UTIL(device, itype, otype, ndim) \ template struct ArgWhereKernelUtil; #define INSTANTIATE_ARG_WHERE_KERNEL_UTIL_WITH_DTYPE_PAIR(device, itype_pair, otype_pair, ndim) \ INSTANTIATE_ARG_WHERE_KERNEL_UTIL(device, OF_PP_PAIR_FIRST(itype_pair), \ OF_PP_PAIR_FIRST(otype_pair), ndim) #define INSTANTIATE_ARG_WHERE_KERNEL_UTIL_FOR_DEVICE(device) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( \ INSTANTIATE_ARG_WHERE_KERNEL_UTIL_WITH_DTYPE_PAIR, (device), \ ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, \ DIM_SEQ) template void SetOutputSize(ep::Stream* stream, const IN_T* input_ptr, OUT_T* output_size_ptr); #define INSTANTIATE_SET_OUTPUT_SIZE(device, itype, otype) \ template void SetOutputSize(ep::Stream * stream, const itype* input_ptr, \ otype* output_size_ptr); #define INSTANTIATE_SET_OUTPUT_SIZE_WITH_DTYPE_PAIR(device, itype_pair, otype_pair) \ INSTANTIATE_SET_OUTPUT_SIZE(device, OF_PP_PAIR_FIRST(itype_pair), OF_PP_PAIR_FIRST(otype_pair)) #define INSTANTIATE_SET_OUTPUT_SIZE_FOR_DEVICE(device) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( \ INSTANTIATE_SET_OUTPUT_SIZE_WITH_DTYPE_PAIR, (device), \ ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_ARG_WHERE_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/argmax_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/common/balanced_splitter.h" namespace oneflow { template class CpuArgMaxKernel final : public user_op::OpKernel { public: CpuArgMaxKernel() = default; ~CpuArgMaxKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t elem_cnt = in->shape_view().elem_cnt(); CHECK_GE(elem_cnt, 0); if (elem_cnt == 0) { return; } const T* in_ptr = in->dptr(); int64_t* out_ptr = out->mut_dptr(); const int64_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1); const int64_t instance_num = elem_cnt / instance_size; const int64_t num_thread = std::min(instance_num, (int64_t)Singleton::Get()->thread_num()); const BalancedSplitter bs(instance_num, num_thread); BlockingCounter bc(num_thread); FOR_RANGE(int64_t, thread_id, 0, num_thread) { const Range range = bs.At(thread_id); Singleton::Get()->AddWork([=, &bc]() { FOR_RANGE(int64_t, i, range.begin(), range.end()) { const T* in_ptr_i = in_ptr + i * instance_size; out_ptr[i] = std::distance(in_ptr_i, std::max_element(in_ptr_i, in_ptr_i + instance_size)); } bc.Decrease(); }); } bc.WaitForeverUntilCntEqualZero(); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_ARGMAX_KERNEL(dtype) \ REGISTER_USER_KERNEL("argmax").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_CPU_ARGMAX_KERNEL(bool) REGISTER_CPU_ARGMAX_KERNEL(float) REGISTER_CPU_ARGMAX_KERNEL(float16) REGISTER_CPU_ARGMAX_KERNEL(double) REGISTER_CPU_ARGMAX_KERNEL(uint8_t) REGISTER_CPU_ARGMAX_KERNEL(int8_t) REGISTER_CPU_ARGMAX_KERNEL(int32_t) REGISTER_CPU_ARGMAX_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/argmax_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template class TmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager); TmpBufferManager(int32_t capacity, void* ptr, int32_t instance_num) : capacity_{capacity}, key_value_out_elem_cnt_{instance_num} { const int32_t key_value_out_aligned_bytes = GetCudaAlignedSize(key_value_out_elem_cnt_ * sizeof(cub::KeyValuePair)); key_value_out_ptr_ = reinterpret_cast*>(ptr); temp_storage_ptr_ = reinterpret_cast(reinterpret_cast(key_value_out_ptr_) + key_value_out_aligned_bytes); temp_storage_bytes_ = capacity_ - key_value_out_aligned_bytes; CHECK_GE(temp_storage_bytes_, 0); } ~TmpBufferManager() = default; cub::KeyValuePair* KeyValueOutPtr() const { return key_value_out_ptr_; } void* TempStoragePtr() const { return temp_storage_ptr_; } int32_t TempStorageBytes() const { return temp_storage_bytes_; } private: int32_t capacity_; cub::KeyValuePair* key_value_out_ptr_; void* temp_storage_ptr_; int32_t key_value_out_elem_cnt_; int32_t temp_storage_bytes_; }; class MultiplyFunctor final { public: MultiplyFunctor(int32_t num_col) : num_col_(num_col) {} __host__ __device__ __forceinline__ int32_t operator()(int32_t idx) const { return idx * num_col_; } private: int32_t num_col_; }; template size_t InferTempStorageForArgMax(int32_t num_row, int32_t num_col) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); size_t temp_storage_bytes = 0; auto err = cub::DeviceSegmentedReduce::ArgMax*, SegmentOffsetIter>( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_in */ nullptr, /* d_out */ nullptr, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1, /* stream */ 0); OF_CUDA_CHECK(err); return temp_storage_bytes; } template void ArgMax(const T* in_ptr, int32_t num_row, int32_t num_col, void* temp_storage_ptr, int32_t temp_storage_bytes, cub::KeyValuePair* out_ptr, cudaStream_t stream) { size_t rt_inferred_temp_storage_bytes = InferTempStorageForArgMax(num_row, num_col); CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes); using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedReduce::ArgMax( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_in */ in_ptr, /* d_out */ out_ptr, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1, /* stream */ stream); OF_CUDA_CHECK(err); } template __global__ void WriteKeysToOutput(const int32_t instance_num, const cub::KeyValuePair* key_value_out_ptr, int64_t* out_ptr) { CUDA_1D_KERNEL_LOOP(i, instance_num) { out_ptr[i] = key_value_out_ptr[i].key; } } } // namespace template class GpuArgMaxKernel final : public user_op::OpKernel { public: GpuArgMaxKernel() = default; ~GpuArgMaxKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = 
ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int32_t elem_cnt = in->shape_view().elem_cnt(); CHECK_GE(elem_cnt, 0); if (elem_cnt == 0) { return; } const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1); const int32_t instance_num = elem_cnt / instance_size; TmpBufferManager buffer_manager(tmp_buffer->shape_view().elem_cnt(), tmp_buffer->mut_dptr(), instance_num); ArgMax(in->dptr(), instance_num, instance_size, buffer_manager.TempStoragePtr(), buffer_manager.TempStorageBytes(), buffer_manager.KeyValueOutPtr(), ctx->stream()->As()->cuda_stream()); WriteKeysToOutput<<stream()->As()->cuda_stream()>>>( instance_num, buffer_manager.KeyValueOutPtr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_ARGMAX_KERNEL(dtype) \ REGISTER_USER_KERNEL("argmax") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("in", 0); \ const int32_t instance_size = in_shape.dim_vec().back(); \ const int32_t instance_num = in_shape.elem_cnt() / instance_size; \ \ /* Key-Value Out */ \ int32_t key_value_out_bytes = \ GetCudaAlignedSize(instance_num * sizeof(cub::KeyValuePair)); \ \ /* CUB Temp Storage */ \ size_t temp_storage_bytes = InferTempStorageForArgMax(instance_num, instance_size); \ \ return key_value_out_bytes + temp_storage_bytes; \ }); REGISTER_CUDA_ARGMAX_KERNEL(bool) REGISTER_CUDA_ARGMAX_KERNEL(float) REGISTER_CUDA_ARGMAX_KERNEL(double) REGISTER_CUDA_ARGMAX_KERNEL(uint8_t) REGISTER_CUDA_ARGMAX_KERNEL(int8_t) REGISTER_CUDA_ARGMAX_KERNEL(int32_t) REGISTER_CUDA_ARGMAX_KERNEL(int64_t) REGISTER_CUDA_ARGMAX_KERNEL(half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/as_strided_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace { constexpr size_t NUM_DIM = 20; template struct AsStridedFunctor final { void operator()(ep::Stream* stream, const T* input_buf, T* output_buf, const int64_t* dest_dims, const int64_t* stride, const int64_t dest_num_dims, const int64_t storage_offset, const int64_t input_num, const int64_t output_num) { NdIndexOffsetHelper destIndexOffsetHelper(dest_dims, dest_num_dims); FOR_RANGE(int64_t, i, 0, output_num) { int64_t dst_index[NUM_DIM]; destIndexOffsetHelper.OffsetToNdIndex(i, dst_index, dest_num_dims); int64_t index_in_input = storage_offset; FOR_RANGE(int64_t, j, 0, dest_num_dims) { index_in_input += dst_index[j] * stride[j]; } output_buf[i] = input_buf[index_in_input]; } } }; template struct AsStridedGradFunctor final { void operator()(ep::Stream* stream, const T* dy_buf, T* dx_buf, const int64_t* dy_dims, const int64_t* stride, const int64_t dy_num_dims, const int64_t storage_offset, const int64_t dx_num, const int64_t dy_num) { NdIndexOffsetHelper destIndexOffsetHelper(dy_dims, dy_num_dims); FOR_RANGE(int64_t, i, 0, dy_num) { int64_t dy_index[NUM_DIM]; destIndexOffsetHelper.OffsetToNdIndex(i, dy_index, dy_num_dims); int64_t index_in_dx = storage_offset; FOR_RANGE(int64_t, j, 0, dy_num_dims) { index_in_dx += dy_index[j] * stride[j]; } dx_buf[index_in_dx] += dy_buf[i]; } } }; } // namespace template class CpuAsStridedKernel final : public user_op::OpKernel { public: CpuAsStridedKernel() = default; ~CpuAsStridedKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); const auto size = ctx->Attr>("size"); const auto stride = ctx->Attr>("stride"); const int64_t storage_offset = ctx->Attr("storage_offset"); size_t dest_num_dims = output->shape_view().NumAxes(); const int64_t* dest_dims = output->shape_view().ptr(); const size_t input_num = input->shape_view().Count(0); const size_t output_num = output->shape_view().Count(0); AsStridedFunctor()(ctx->stream(), input->dptr(), output->mut_dptr(), dest_dims, stride.data(), dest_num_dims, storage_offset, input_num, output_num); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class CpuAsStridedGradKernel final : public user_op::OpKernel { public: CpuAsStridedGradKernel() = default; ~CpuAsStridedGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto size = ctx->Attr>("size"); const auto stride = ctx->Attr>("stride"); const int64_t storage_offset = ctx->Attr("storage_offset"); size_t dy_num_dims = dy->shape_view().NumAxes(); const int64_t* dy_dims = dy->shape_view().ptr(); const size_t dx_num = dx->shape_view().Count(0); const size_t dy_num = dy->shape_view().Count(0); Memset(ctx->stream(), dx->mut_dptr(), 0, dx->shape_view().Count(0) * sizeof(T)); AsStridedGradFunctor()(ctx->stream(), dy->dptr(), dx->mut_dptr(), dy_dims, stride.data(), dy_num_dims, storage_offset, dx_num, dy_num); } 
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_ASSTRIDED_KERNEL(in_type) \ REGISTER_USER_KERNEL("as_strided") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value)); REGISTER_CPU_ASSTRIDED_KERNEL(float); REGISTER_CPU_ASSTRIDED_KERNEL(double); REGISTER_CPU_ASSTRIDED_KERNEL(int8_t); REGISTER_CPU_ASSTRIDED_KERNEL(uint8_t); REGISTER_CPU_ASSTRIDED_KERNEL(int32_t); REGISTER_CPU_ASSTRIDED_KERNEL(int64_t); #undef REGISTER_CPU_ASSTRIDED_KERNEL #define REGISTER_CPU_ASSTRIDED_GRAD_KERNEL(in_type) \ REGISTER_USER_KERNEL("as_strided_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value)); REGISTER_CPU_ASSTRIDED_GRAD_KERNEL(float); REGISTER_CPU_ASSTRIDED_GRAD_KERNEL(double); #undef REGISTER_CPU_ASSTRIDED_GRAD_KERNEL REGISTER_USER_KERNEL("as_strided") .SetCreateFn>() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("input", 0) == GetDataType::value)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/as_strided_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace { constexpr size_t NUM_DIM = 8; template struct AsStridedParams { NdIndexOffsetHelper destIndexOffsetHelper; int64_t dest_dims[num_dims]; int32_t stride[num_dims]; int32_t dest_num_dims; int32_t storage_offset; int32_t input_num; int32_t output_num; }; template __global__ void AsStrided_kernel(const T* input_buf, T* output_buf, AsStridedParams params) { const int64_t* dest_dims = reinterpret_cast(params.dest_dims); const int32_t* stride = reinterpret_cast(params.stride); CUDA_1D_KERNEL_LOOP_T(int64_t, i, params.output_num) { int64_t dst_index[NUM_DIM]; params.destIndexOffsetHelper.OffsetToNdIndex(i, dst_index, params.dest_num_dims); int32_t index_in_input = params.storage_offset; FOR_RANGE(int64_t, j, 0, params.dest_num_dims) { index_in_input += dst_index[j] * stride[j]; } output_buf[i] = input_buf[index_in_input]; } } template __global__ void AsStridedGrad_kernel(const T* dy_buf, T* dx_buf, AsStridedParams params) { const int64_t* dest_dims = reinterpret_cast(params.dest_dims); const int32_t* stride = reinterpret_cast(params.stride); CUDA_1D_KERNEL_LOOP_T(int64_t, i, params.output_num) { int64_t dy_index[NUM_DIM]; params.destIndexOffsetHelper.OffsetToNdIndex(i, dy_index, params.dest_num_dims); int32_t index_in_dx = params.storage_offset; FOR_RANGE(int64_t, j, 0, params.dest_num_dims) { index_in_dx += dy_index[j] * stride[j]; } cuda::atomic::Add(dx_buf + index_in_dx, dy_buf[i]); } } template struct AsStridedFunctor final { void operator()(ep::Stream* stream, const T* input_buf, T* output_buf, const int64_t* dest_dims, const int64_t* stride, const int64_t dest_num_dims, const int64_t storage_offset, const int64_t input_num, const int64_t output_num) { NdIndexOffsetHelper destIndexOffsetHelper(dest_dims, dest_num_dims); AsStridedParams params; params.destIndexOffsetHelper = destIndexOffsetHelper; FOR_RANGE(size_t, i, 0, dest_num_dims) { params.dest_dims[i] = dest_dims[i]; params.stride[i] = stride[i]; } params.dest_num_dims = dest_num_dims; params.storage_offset = storage_offset; params.input_num = input_num; params.output_num = output_num; AsStrided_kernel <<As()->cuda_stream()>>>(input_buf, output_buf, params); } }; template struct AsStridedGradFunctor final { void operator()(ep::Stream* stream, const T* dy_buf, T* dx_buf, const int64_t* dy_dims, const int64_t* stride, const int64_t dy_num_dims, const int64_t storage_offset, const int64_t dx_num, const int64_t dy_num) { NdIndexOffsetHelper dyIndexOffsetHelper(dy_dims, dy_num_dims); AsStridedParams params; params.destIndexOffsetHelper = dyIndexOffsetHelper; FOR_RANGE(size_t, i, 0, dy_num_dims) { params.dest_dims[i] = dy_dims[i]; params.stride[i] = stride[i]; } params.dest_num_dims = dy_num_dims; params.storage_offset = storage_offset; params.input_num = dx_num; params.output_num = dy_num; AsStridedGrad_kernel <<As()->cuda_stream()>>>(dy_buf, dx_buf, params); } }; } // namespace template class GpuAsStridedKernel final : public user_op::OpKernel { public: GpuAsStridedKernel() = default; ~GpuAsStridedKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); 
user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); const auto size = ctx->Attr>("size"); const auto stride = ctx->Attr>("stride"); const int64_t storage_offset = ctx->Attr("storage_offset"); size_t dest_num_dims = output->shape_view().NumAxes(); const int64_t* dest_dims = output->shape_view().ptr(); const size_t input_num = input->shape_view().Count(0); const size_t output_num = output->shape_view().Count(0); if (input_num == 0) { // 0-size tensor return; } AsStridedFunctor()(ctx->stream(), input->dptr(), output->mut_dptr(), dest_dims, stride.data(), dest_num_dims, storage_offset, input_num, output_num); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuAsStridedGradKernel final : public user_op::OpKernel { public: GpuAsStridedGradKernel() = default; ~GpuAsStridedGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto size = ctx->Attr>("size"); const auto stride = ctx->Attr>("stride"); const int64_t storage_offset = ctx->Attr("storage_offset"); size_t dy_num_dims = dy->shape_view().NumAxes(); const int64_t* dy_dims = dy->shape_view().ptr(); const size_t dx_num = dx->shape_view().Count(0); const size_t dy_num = dy->shape_view().Count(0); Memset(ctx->stream(), dx->mut_dptr(), 0, dx->shape_view().Count(0) * sizeof(T)); AsStridedGradFunctor()(ctx->stream(), dy->dptr(), dx->mut_dptr(), dy_dims, stride.data(), dy_num_dims, storage_offset, dx_num, dy_num); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GPU_ASSTRIDED_KERNEL(in_type) \ REGISTER_USER_KERNEL("as_strided") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value)); REGISTER_GPU_ASSTRIDED_KERNEL(half); REGISTER_GPU_ASSTRIDED_KERNEL(float); REGISTER_GPU_ASSTRIDED_KERNEL(double); REGISTER_GPU_ASSTRIDED_KERNEL(int8_t); REGISTER_GPU_ASSTRIDED_KERNEL(uint8_t); REGISTER_GPU_ASSTRIDED_KERNEL(int32_t); REGISTER_GPU_ASSTRIDED_KERNEL(int64_t); #undef REGISTER_GPU_ASSTRIDED_KERNEL #define REGISTER_GPU_ASSTRIDED_GRAD_KERNEL(in_type) \ REGISTER_USER_KERNEL("as_strided_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value)); REGISTER_GPU_ASSTRIDED_GRAD_KERNEL(half); REGISTER_GPU_ASSTRIDED_GRAD_KERNEL(float); REGISTER_GPU_ASSTRIDED_GRAD_KERNEL(double); #undef REGISTER_GPU_ASSTRIDED_GRAD_KERNEL REGISTER_USER_KERNEL("as_strided") .SetCreateFn>() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("input", 0) == GetDataType::value)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/assign_if_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
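The CUDA backward kernel above accumulates with cuda::atomic::Add rather than a plain store because overlapping strides can send several dy elements to the same dx offset, and those writes may come from different threads; the serial CPU functor earlier simply does dx_buf[index] += dy_buf[i]. A small sequential example of such an overlap, where a size {2, 2} view with stride {1, 1} touches offset 1 twice:

#include <cstdio>
#include <vector>

int main() {
  // dy laid out for a view of size {2, 2} with stride {1, 1} over a length-3 buffer.
  const std::vector<float> dy = {1, 10, 100, 1000};
  std::vector<float> dx(3, 0.f);
  const int size[2] = {2, 2}, stride[2] = {1, 1};
  for (int i = 0; i < 4; ++i) {
    const int a = i / size[1], b = i % size[1];
    dx[a * stride[0] + b * stride[1]] += dy[i];  // offsets 0, 1, 1, 2 -> offset 1 is hit twice
  }
  std::printf("%.0f %.0f %.0f\n", dx[0], dx[1], dx[2]);  // 1 110 1000
}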
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { namespace { template class AssignIfCPUKernel final : public user_op::OpKernel { public: AssignIfCPUKernel() = default; ~AssignIfCPUKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* condition = ctx->Tensor4ArgNameAndIndex("condition", 0); if ((assign_if == (*condition->dptr() == 0))) { return; } const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); user_op::Tensor* ref = ctx->Tensor4ArgNameAndIndex("ref", 0); if (value->dptr() == ref->dptr()) { return; } CHECK_EQ(value->shape_view(), ref->shape_view()); CHECK_EQ(value->data_type(), ref->data_type()); const size_t tensor_bytes_size = ref->shape_view().elem_cnt() * GetSizeOfDataType(ref->data_type()); AutoMemcpy(ctx->stream(), ref->mut_dptr(), value->dptr(), tensor_bytes_size, ref->mem_case(), value->mem_case()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; } // namespace #define REGISTER_ASSIGN_WITH_CONDITION_CPU_KERNEL(op_type_name, assign_if, condition_type) \ REGISTER_USER_KERNEL(op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("condition", 0) == GetDataType::value)); #define REGISTER_ASSIGN_IF_CPU_KERNEL(condition_cpp_type, condition_data_type) \ REGISTER_ASSIGN_WITH_CONDITION_CPU_KERNEL("assign_if", true, condition_cpp_type); \ REGISTER_ASSIGN_WITH_CONDITION_CPU_KERNEL("assign_if_not", false, condition_cpp_type) OF_PP_FOR_EACH_TUPLE(REGISTER_ASSIGN_IF_CPU_KERNEL, INT_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/assign_if_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
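The early return above skips the copy whenever the kernel's polarity disagrees with the condition: assign_if copies only when the condition is nonzero, assign_if_not only when it is zero. A truth-table sketch of that predicate:

#include <cstdio>

// Mirrors `if (assign_if == (condition == 0)) { return; }` in the kernels above.
bool ShouldSkip(bool assign_if, int condition) { return assign_if == (condition == 0); }

int main() {
  for (bool assign_if : {true, false}) {
    for (int condition : {0, 1}) {
      std::printf("assign_if=%d condition=%d -> %s\n", static_cast<int>(assign_if), condition,
                  ShouldSkip(assign_if, condition) ? "skip" : "copy");
    }
  }
}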
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void AssignGpu(int64_t elem_cnt, const C* condition, const T* value, T* ref) { if (assign_if == (*condition == 0)) { return; } CUDA_1D_KERNEL_LOOP(i, elem_cnt) { ref[i] = value[i]; } } template class AssignIfGPUKernel final : public user_op::OpKernel { public: AssignIfGPUKernel() = default; ~AssignIfGPUKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* condition = ctx->Tensor4ArgNameAndIndex("condition", 0); CHECK_EQ(condition->shape_view().NumAxes(), 1); CHECK_EQ(condition->shape_view().At(0), 1); const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); user_op::Tensor* ref = ctx->Tensor4ArgNameAndIndex("ref", 0); if (value->dptr() == ref->dptr()) { return; } CHECK_EQ(value->shape_view(), ref->shape_view()); CHECK_EQ(value->data_type(), ref->data_type()); const size_t elem_cnt = ref->shape_view().elem_cnt(); AssignGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, condition->dptr(), value->dptr(), ref->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; } // namespace #define REGISTER_ASSIGN_WITH_CONDITION_VALUE_CUDA_KERNEL(op_type_name, assign_if, condition_type, \ value_type) \ REGISTER_USER_KERNEL(op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("condition", 0) == GetDataType::value) \ && (user_op::HobDataType("value", 0) == GetDataType::value)); #define REGISTER_ASSIGN_IF_CUDA_KERNEL(condition_type, value_type) \ REGISTER_ASSIGN_WITH_CONDITION_VALUE_CUDA_KERNEL( \ "assign_if", true, OF_PP_PAIR_FIRST(condition_type), OF_PP_PAIR_FIRST(value_type)); \ REGISTER_ASSIGN_WITH_CONDITION_VALUE_CUDA_KERNEL( \ "assign_if_not", false, OF_PP_PAIR_FIRST(condition_type), OF_PP_PAIR_FIRST(value_type)) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ASSIGN_IF_CUDA_KERNEL, INT_DATA_TYPE_SEQ, POD_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/assign_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { namespace { class AssignKernel final : public user_op::OpKernel { public: AssignKernel() = default; ~AssignKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* value_tensor = ctx->Tensor4ArgNameAndIndex("value", 0); user_op::Tensor* ref_tensor = ctx->Tensor4ArgNameAndIndex("ref", 0); if (value_tensor->dptr() == ref_tensor->dptr()) { return; } size_t tensor_bytes_size = ref_tensor->shape_view().elem_cnt() * GetSizeOfDataType(ref_tensor->data_type()); size_t val_tensor_bytes_size = value_tensor->shape_view().elem_cnt() * GetSizeOfDataType(value_tensor->data_type()); CHECK_EQ(tensor_bytes_size, val_tensor_bytes_size); AutoMemcpy(ctx->stream(), ref_tensor->mut_dptr(), value_tensor->dptr(), tensor_bytes_size, ref_tensor->mem_case(), value_tensor->mem_case()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; } // namespace REGISTER_USER_KERNEL("assign").SetCreateFn(); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/avg_pool_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/avg_pool_kernel_util.h" namespace oneflow { struct AvgPoolOpKernelCache final : public user_op::OpKernelCache { AvgPoolParams3D params_3d; explicit AvgPoolOpKernelCache(const AvgPoolParams3D& params_3d) : params_3d(params_3d) {} const AvgPoolParams3D& GetParams3D() const { return params_3d; } }; std::shared_ptr CreateAvgOpKernelCache(user_op::KernelCacheContext* ctx, const int32_t& dim) { const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::vector& padding = ctx->Attr>("padding"); const std::vector& kernel_size = ctx->Attr>("kernel_size"); const std::vector& stride = ctx->Attr>("stride"); const bool ceil_mode = ctx->Attr("ceil_mode"); const bool count_include_pad = ctx->Attr("count_include_pad"); const int32_t divisor_override = ctx->Attr("divisor_override"); AvgPoolParams3D params_3d = AvgPoolParams3D(dim, x_shape, data_format, padding, kernel_size, stride, ceil_mode, count_include_pad, divisor_override); std::shared_ptr cache(new AvgPoolOpKernelCache(params_3d)); return cache; } template struct AvgPoolKernelUtil { static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { Avgpool1dForwardCompute( index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { Avgpool1dBackwardCompute( index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { Avgpool2dForwardCompute( index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { Avgpool2dBackwardCompute( index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { Avgpool3dForwardCompute( index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), 
params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { Avgpool3dBackwardCompute( index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } }; template class AvgPool1dKernel final : public user_op::OpKernel { public: AvgPool1dKernel() = default; ~AvgPool1dKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateAvgOpKernelCache(ctx, 1); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto* pool_cache = dynamic_cast(cache); const AvgPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = y->shape_view().elem_cnt(); const T* src = x->dptr(); T* dest = y->mut_dptr(); DimVector y_vector(2); y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1); y_vector.at(1) = y->shape_view().At(2); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(y_vector.data()); AvgPoolKernelUtil::Avgpool1dForward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } else { NdIndexOffsetHelper index_helper(y_vector.data()); AvgPoolKernelUtil::Avgpool1dForward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } }; }; template class AvgPool1dGradKernel final : public user_op::OpKernel { public: AvgPool1dGradKernel() = default; ~AvgPool1dGradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateAvgOpKernelCache(ctx, 1); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* pool_cache = dynamic_cast(cache); const AvgPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = dy->shape_view().elem_cnt(); const T* src = dy->dptr(); T* dest = dx->mut_dptr(); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); DimVector dy_vector(2); dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1); dy_vector.at(1) = dy->shape_view().At(2); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(dy_vector.data()); AvgPoolKernelUtil::Avgpool1dBackward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } 
else { NdIndexOffsetHelper index_helper(dy_vector.data()); AvgPoolKernelUtil::Avgpool1dBackward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } }; }; template class AvgPool2dKernel final : public user_op::OpKernel { public: AvgPool2dKernel() = default; ~AvgPool2dKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateAvgOpKernelCache(ctx, 2); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto* pool_cache = dynamic_cast(cache); const AvgPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = y->shape_view().elem_cnt(); const T* src = x->dptr(); T* dest = y->mut_dptr(); DimVector y_vector(3); y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1); y_vector.at(1) = y->shape_view().At(2); y_vector.at(2) = y->shape_view().At(3); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(y_vector.data()); AvgPoolKernelUtil::Avgpool2dForward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } else { NdIndexOffsetHelper index_helper(y_vector.data()); AvgPoolKernelUtil::Avgpool2dForward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } }; }; template class AvgPool2dGradKernel final : public user_op::OpKernel { public: AvgPool2dGradKernel() = default; ~AvgPool2dGradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateAvgOpKernelCache(ctx, 2); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* pool_cache = dynamic_cast(cache); const AvgPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = dy->shape_view().elem_cnt(); const T* src = dy->dptr(); T* dest = dx->mut_dptr(); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); DimVector dy_vector(3); dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1); dy_vector.at(1) = dy->shape_view().At(2); dy_vector.at(2) = dy->shape_view().At(3); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(dy_vector.data()); AvgPoolKernelUtil::Avgpool2dBackward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } else { NdIndexOffsetHelper index_helper(dy_vector.data()); AvgPoolKernelUtil::Avgpool2dBackward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } }; }; template class AvgPool3dKernel final : public user_op::OpKernel { public: AvgPool3dKernel() = default; ~AvgPool3dKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateAvgOpKernelCache(ctx, 3); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 
0); const auto* pool_cache = dynamic_cast(cache); const AvgPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = y->shape_view().elem_cnt(); const T* src = x->dptr(); T* dest = y->mut_dptr(); DimVector y_vector(4); y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1); y_vector.at(1) = y->shape_view().At(2); y_vector.at(2) = y->shape_view().At(3); y_vector.at(3) = y->shape_view().At(4); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(y_vector.data()); AvgPoolKernelUtil::Avgpool3dForward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } else { NdIndexOffsetHelper index_helper(y_vector.data()); AvgPoolKernelUtil::Avgpool3dForward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } }; }; template class AvgPool3dGradKernel final : public user_op::OpKernel { public: AvgPool3dGradKernel() = default; ~AvgPool3dGradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateAvgOpKernelCache(ctx, 3); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* pool_cache = dynamic_cast(cache); const AvgPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = dy->shape_view().elem_cnt(); const T* src = dy->dptr(); T* dest = dx->mut_dptr(); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); DimVector dy_vector(4); dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1); dy_vector.at(1) = dy->shape_view().At(2); dy_vector.at(2) = dy->shape_view().At(3); dy_vector.at(3) = dy->shape_view().At(4); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(dy_vector.data()); AvgPoolKernelUtil::Avgpool3dBackward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } else { NdIndexOffsetHelper index_helper(dy_vector.data()); AvgPoolKernelUtil::Avgpool3dBackward(ctx->stream(), index_helper, elem_num, src, dest, params_3d); } }; }; #define REGISTER_AVG_POOL_KERNELS(device, dtype) \ REGISTER_USER_KERNEL("avg_pool_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("avg_pool_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("avg_pool_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("avg_pool_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("avg_pool_3d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("avg_pool_3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); #define REGISTER_AVG_POOL_WITH_DEVICE(device) \ REGISTER_AVG_POOL_KERNELS(device, float) \ REGISTER_AVG_POOL_KERNELS(device, double) 
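// ---------------------------------------------------------------------------
// Editor's sketch, not part of the original file. The extraction above dropped
// template arguments, so the two branches of each
// `if (elem_num < GetMaxVal()) { ... } else { ... }` block read identically.
// Presumably the branch builds the NdIndexOffsetHelper over int32_t when every
// flat offset fits in 32 bits (cheaper index arithmetic, particularly on GPU)
// and falls back to int64_t otherwise; AVG_POOL_IDX_DATA_TYPE_SEQ in
// avg_pool_kernel_util.h lists both index types for instantiation. A minimal,
// self-contained illustration of that dispatch follows, guarded by `#if 0` so
// it cannot interfere with the listing. `OffsetToNdIndex` below is a
// hypothetical stand-in for the library's NdIndexOffsetHelper, not its API.
#if 0
#include <cstdint>
#include <iostream>
#include <limits>

// Decompose a flat offset into (n_c, h, w) for dims {N*C, H, W}, as the 2-D
// average-pool kernels above do with their 3-element DimVector.
template<typename IDX>
void OffsetToNdIndex(IDX offset, IDX h_dim, IDX w_dim, IDX* n_c, IDX* h, IDX* w) {
  *w = offset % w_dim;
  offset /= w_dim;
  *h = offset % h_dim;
  *n_c = offset / h_dim;
}

int main() {
  const int64_t dims[3] = {2 * 3, 8, 8};  // {N*C, H, W}
  const int64_t elem_num = dims[0] * dims[1] * dims[2];
  if (elem_num < std::numeric_limits<int32_t>::max()) {
    // Small tensor: 32-bit offsets are sufficient.
    int32_t n_c = 0, h = 0, w = 0;
    OffsetToNdIndex<int32_t>(100, 8, 8, &n_c, &h, &w);
    std::cout << n_c << " " << h << " " << w << std::endl;  // prints "1 4 4"
  } else {
    // Large tensor: use 64-bit offsets to avoid overflow.
    int64_t n_c = 0, h = 0, w = 0;
    OffsetToNdIndex<int64_t>(100, 8, 8, &n_c, &h, &w);
  }
  return 0;
}
#endif
// ---------------------------------------------------------------------------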
REGISTER_AVG_POOL_WITH_DEVICE(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_AVG_POOL_WITH_DEVICE(DeviceType::kCUDA) REGISTER_AVG_POOL_KERNELS(DeviceType::kCUDA, half) #endif OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_AVG_POOL_KERNEL_UTIL, (DeviceType::kCPU), AVG_POOL_DATA_TYPE_CPU_SEQ, AVG_POOL_IDX_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/avg_pool_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/avg_pool_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { constexpr int kBlockSize = cuda::elementwise::kBlockSize; int GetMinThreadNum(const int64_t elem_num) { return std::min(elem_num, kBlockSize); } int GetNumBlocks(int32_t elem_cnt) { int num_blocks = 0; OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks)); return num_blocks; } } // namespace template __launch_bounds__(kBlockSize) __global__ void DoCUDAAvgPool1dForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int32_t padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t x_length, const int32_t kernel_size_l, const int32_t stride_l, const bool count_include_pad, const int32_t divisor_override) { Avgpool1dForwardCompute(index_helper, elem_num, src, dest, padding_l, n_batch, n_channel, x_length, kernel_size_l, stride_l, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAAvgPool2dForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) { Avgpool2dForwardCompute(index_helper, elem_num, src, dest, padding_h, padding_w, n_batch, n_channel, x_height, x_width, kernel_size_h, kernel_size_w, stride_h, stride_w, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAAvgPool3dForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int32_t padding_t, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) { Avgpool3dForwardCompute(index_helper, elem_num, src, dest, padding_t, padding_h, padding_w, n_batch, n_channel, x_time, x_height, x_width, kernel_size_t, kernel_size_h, kernel_size_w, stride_t, stride_h, stride_w, 
count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAAvgPool1dBackward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t input_length, const int32_t kernel_size_l, const int32_t stride_l, const bool count_include_pad, const int32_t divisor_override) { Avgpool1dBackwardCompute(index_helper, elem_num, src, dest, padding_l, n_batch, n_channel, input_length, kernel_size_l, stride_l, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAAvgPool2dBackward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t input_height, const int32_t input_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) { Avgpool2dBackwardCompute(index_helper, elem_num, src, dest, padding_h, padding_w, n_batch, n_channel, input_height, input_width, kernel_size_h, kernel_size_w, stride_h, stride_w, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAAvgPool3dBackward( const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) { Avgpool3dBackwardCompute(index_helper, elem_num, src, dest, padding_t, padding_h, padding_w, n_batch, n_channel, x_time, x_height, x_width, kernel_size_t, kernel_size_h, kernel_size_w, stride_t, stride_h, stride_w, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoHalfAvgPool1dForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, int32_t padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t x_length, const int32_t kernel_size_l, const int32_t stride_l, const bool count_include_pad, const int32_t divisor_override) { HalfAvgpool1dForwardCompute(index_helper, elem_num, src, dest, padding_l, n_batch, n_channel, x_length, kernel_size_l, stride_l, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoHalfAvgPool2dForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) { HalfAvgpool2dForwardCompute(index_helper, elem_num, src, dest, padding_h, padding_w, n_batch, n_channel, x_height, x_width, kernel_size_h, kernel_size_w, stride_h, stride_w, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoHalfAvgPool3dForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, int32_t padding_t, const int32_t padding_h, const 
int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) { HalfAvgpool3dForwardCompute(index_helper, elem_num, src, dest, padding_t, padding_h, padding_w, n_batch, n_channel, x_time, x_height, x_width, kernel_size_t, kernel_size_h, kernel_size_w, stride_t, stride_h, stride_w, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoHalfAvgPool1dBackward(const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t input_length, const int32_t kernel_size_l, const int32_t stride_l, const bool count_include_pad, const int32_t divisor_override) { HalfAvgpool1dBackwardCompute(index_helper, elem_num, src, dest, padding_l, n_batch, n_channel, input_length, kernel_size_l, stride_l, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoHalfAvgPool2dBackward(const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t input_height, const int32_t input_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) { HalfAvgpool2dBackwardCompute(index_helper, elem_num, src, dest, padding_h, padding_w, n_batch, n_channel, input_height, input_width, kernel_size_h, kernel_size_w, stride_h, stride_w, count_include_pad, divisor_override); }; template __launch_bounds__(kBlockSize) __global__ void DoHalfAvgPool3dBackward( const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) { HalfAvgpool3dBackwardCompute(index_helper, elem_num, src, dest, padding_t, padding_h, padding_w, n_batch, n_channel, x_time, x_height, x_width, kernel_size_t, kernel_size_h, kernel_size_w, stride_t, stride_h, stride_w, count_include_pad, divisor_override); }; template struct AvgPoolKernelUtil { static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { DoCUDAAvgPool1dForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { DoCUDAAvgPool1dBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), 
params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { DoCUDAAvgPool2dForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { DoCUDAAvgPool2dBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { DoCUDAAvgPool3dForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d) { DoCUDAAvgPool3dBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } }; template struct AvgPoolKernelUtil { static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d) { DoHalfAvgPool1dForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d) { DoHalfAvgPool1dBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(), 
params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d) { DoHalfAvgPool2dForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d) { DoHalfAvgPool2dBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d) { DoHalfAvgPool3dForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d) { DoHalfAvgPool3dBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override()); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_AVG_POOL_KERNEL_UTIL, (DeviceType::kCUDA), AVG_POOL_DATA_TYPE_CUDA_SEQ, AVG_POOL_IDX_DATA_TYPE_SEQ); template struct AvgPoolKernelUtil; template struct AvgPoolKernelUtil; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/avg_pool_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/avg_pool_kernel_util.h" namespace oneflow { std::vector GetAvg3DVec(const std::vector& original_vec, int32_t NDims) { std::vector vec; FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - NDims); if (index < 0) { vec.emplace_back(1); } else { vec.emplace_back(original_vec.at(index)); } } return vec; } std::vector GetAvg3DPadVec(const std::vector& original_vec, int32_t NDims) { std::vector vec; FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - NDims); if (index < 0) { vec.emplace_back(0); } else { vec.emplace_back(original_vec.at(index)); } } return vec; } const int64_t GetNoDilationWindowedOutputShape(int64_t input_size, int32_t filter_size, int32_t stride, int32_t padding, bool ceil_mode) { int64_t output_size = (input_size + 2 * padding - (filter_size - 1) - 1 + stride + (ceil_mode ? stride - 1 : 0)) / stride; if (ceil_mode) { // ensure that the last pooling starts inside the image // needed to avoid problems in ceil mode if ((output_size - 1) * stride >= input_size + padding) { --output_size; } } return output_size; } void GetNoDilation3DOutputShape(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::vector& padding, const bool ceil_mode, DimVector* out) { out->clear(); out->resize(3); FOR_RANGE(size_t, i, 0, 3) { out->at(i) = GetNoDilationWindowedOutputShape(in.at(i), pool_size.at(i), strides.at(i), padding.at(i), ceil_mode); } } AvgPoolParams3D::AvgPoolParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format, const std::vector& padding, const std::vector& kernel_size, const std::vector& stride, const bool ceil_mode, const bool count_include_pad, const int32_t divisor_override) : dim_(dim), data_format_(data_format), padding_(GetAvg3DPadVec(padding, dim)), pool_size_3d_(GetAvg3DVec(kernel_size, dim)), stride_3d_(GetAvg3DVec(stride, dim)), ceil_mode_(ceil_mode), count_include_pad_(count_include_pad), divisor_override_(divisor_override) { x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim), GetInDim(x_shape, data_format, 2, dim)}; GetNoDilation3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, ceil_mode_, &y_3d_); if (data_format == "channels_first") { channel_num_ = x_shape.At(1); } else { CHECK_EQ(data_format_, "channels_last") << "data_format must be 'channels_first' or 'channels_last'"; channel_num_ = x_shape.At(x_shape.NumAxes() - 1); } batch_num_ = x_shape.At(0); } void AvgPoolParams3D::Reset(const ShapeView& x_shape) { x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_), GetInDim(x_shape, data_format_, 2, dim_)}; GetNoDilation3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, ceil_mode_, &y_3d_); } Shape AvgPoolParams3D::GetYShape() const { DimVector y_dim_vec; if (dim_ == 1) { y_dim_vec = {y_3d_.at(2)}; } else if (dim_ == 2) { y_dim_vec = {y_3d_.at(1), y_3d_.at(2)}; } else if (dim_ == 3) { y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}; } else { UNIMPLEMENTED(); } if (data_format_ == "channels_first") { y_dim_vec.insert(y_dim_vec.begin(), channel_num_); } else { 
CHECK_EQ(data_format_, "channels_last") << "data_format must be 'channels_first' or 'channels_last'"; y_dim_vec.insert(y_dim_vec.end(), channel_num_); } y_dim_vec.insert(y_dim_vec.begin(), batch_num_); return Shape(y_dim_vec); } Shape AvgPoolParams3D::GetXShape5D() const { return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)}); } Shape AvgPoolParams3D::GetYShape5D() const { return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}); } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/avg_pool_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_AVG_POOL_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_AVG_POOL_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/operator/operator_util.h" #include "oneflow/core/kernel/util/numerics.cuh" #include "oneflow/core/kernel/util/numeric_limits.cuh" #ifdef WITH_CUDA #include "oneflow/core/cuda/atomic.cuh" #endif // WITH_CUDA namespace oneflow { namespace { template OF_DEVICE_FUNC T XPU_INT_MIN(T a, T b) { return a <= b ? a : b; } template OF_DEVICE_FUNC T XPU_INT_MAX(T a, T b) { return a >= b ? 
a : b; } template struct XPUAdd { OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { #if defined(__CUDA_ARCH__) cuda::atomic::Add(y, *x); #else *y += *x; #endif }; }; } // namespace #define AVG_POOL_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define AVG_POOL_DATA_TYPE_CPU_SEQ AVG_POOL_DATA_TYPE_SEQ #define AVG_POOL_DATA_TYPE_CUDA_SEQ AVG_POOL_DATA_TYPE_SEQ #define AVG_POOL_IDX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) typedef small_vector FixedDimVector; class AvgPoolParams3D { public: AvgPoolParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format, const std::vector& padding, const std::vector& kernel_size, const std::vector& stride, const bool ceil_mode, const bool count_include_pad, const int32_t divisor_override); ~AvgPoolParams3D() = default; const std::string& data_format() const { return data_format_; } const std::vector& padding() const { return padding_; } const std::vector& pool_size_3d() const { return pool_size_3d_; } const std::vector& stride_3d() const { return stride_3d_; } const bool& ceil_mode() const { return ceil_mode_; } const bool& count_include_pad() const { return count_include_pad_; } const int32_t& divisor_override() const { return divisor_override_; } const int32_t& num_batch() const { return batch_num_; } const int32_t& num_channel() const { return channel_num_; } void Reset(const ShapeView& x_shape); Shape GetYShape() const; Shape GetXShape5D() const; Shape GetYShape5D() const; private: int32_t dim_; FixedDimVector x_3d_; FixedDimVector y_3d_; std::string data_format_; std::vector padding_; std::vector pool_size_3d_; std::vector stride_3d_; bool ceil_mode_; bool count_include_pad_; int32_t divisor_override_; int32_t batch_num_; int32_t channel_num_; }; template struct AvgPoolKernelUtil { static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d); static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d); static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d); static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d); static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d); static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const AvgPoolParams3D& params_3d); }; template OF_DEVICE_FUNC void Avgpool1dForwardCompute(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t x_length, const int32_t kernel_size_l, const int32_t stride_l, const bool count_include_pad, const int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, l; index_helper.OffsetToNdIndex(num, n_c, l); const IDX start_idx = n_c * x_length; IDX lstart = l * stride_l - padding_l; IDX lend = XPU_INT_MIN(lstart + kernel_size_l, x_length + padding_l); const IDX pool_size = (lend - lstart); 
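// pool_size is the nominal window length including padded positions; the next
// two statements clamp the window to the valid input range [0, x_length).
// divide_factor is then divisor_override when it is non-zero, otherwise
// pool_size when count_include_pad is set, otherwise the clamped width
// (lend - lstart). E.g. with x_length = 5, kernel_size_l = 3, stride_l = 1,
// padding_l = 1 and l = 0: lstart = -1, lend = 2, pool_size = 3; after
// clamping the valid width is 2, so the divisor is 3 with count_include_pad
// and 2 without it.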
lstart = XPU_INT_MAX(0, lstart); lend = XPU_INT_MIN(lend, x_length); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (lend - lstart); } } T sum = 0; const T* data = src + start_idx; for (IDX idx = lstart; idx < lend; idx += 1) { sum += data[idx]; } dest[num] = static_cast(sum / divide_factor); } } template OF_DEVICE_FUNC void Avgpool1dBackwardCompute(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t input_length, const int32_t kernel_size_l, const int32_t stride_l, const bool count_include_pad, const int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, l; index_helper.OffsetToNdIndex(num, n_c, l); const IDX start_idx = n_c * input_length; IDX lstart = l * stride_l - padding_l; IDX lend = XPU_INT_MIN(lstart + kernel_size_l, input_length + padding_l); const IDX pool_size = (lend - lstart); lstart = XPU_INT_MAX(IDX(0), lstart); lend = XPU_INT_MIN(lend, input_length); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (lend - lstart); } } T grad_delta = src[num] / divide_factor; T* data = dest + start_idx; for (IDX idx = lstart; idx < lend; idx += 1) { XPUAdd::Invoke(&grad_delta, &data[idx]); // dest[search_idx] += grad_delta } } } template OF_DEVICE_FUNC void Avgpool2dForwardCompute( const NdIndexOffsetHelper index_helper, int64_t elem_num, const T* src, T* dest, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, h, w; index_helper.OffsetToNdIndex(num, n_c, h, w); const IDX start_idx = n_c * x_width * x_height; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; IDX hend = XPU_INT_MIN(hstart + kernel_size_h, x_height + padding_h); IDX wend = XPU_INT_MIN(wstart + kernel_size_w, x_width + padding_w); const IDX pool_size = (hend - hstart) * (wend - wstart); hstart = XPU_INT_MAX(0, hstart); wstart = XPU_INT_MAX(0, wstart); hend = XPU_INT_MIN(hend, x_height); wend = XPU_INT_MIN(wend, x_width); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (hend - hstart) * (wend - wstart); } } T sum = 0; const T* data = src + start_idx; for (int64_t i = hstart; i < hend; i += 1) { for (int64_t j = wstart; j < wend; j += 1) { const IDX window_idx = i * x_width + j; sum += data[window_idx]; } } dest[num] = sum / divide_factor; } } template OF_DEVICE_FUNC void Avgpool2dBackwardCompute( const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t input_height, const int32_t input_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, h, w; index_helper.OffsetToNdIndex(num, n_c, h, w); const 
IDX start_idx = n_c * input_width * input_height; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; IDX hend = XPU_INT_MIN(hstart + kernel_size_h, input_height + padding_h); IDX wend = XPU_INT_MIN(wstart + kernel_size_w, input_width + padding_w); const IDX pool_size = (hend - hstart) * (wend - wstart); hstart = XPU_INT_MAX(IDX(0), hstart); wstart = XPU_INT_MAX(IDX(0), wstart); hend = XPU_INT_MIN(hend, input_height); wend = XPU_INT_MIN(wend, input_width); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (hend - hstart) * (wend - wstart); } } T grad_delta = src[num] / divide_factor; T* data = dest + start_idx; for (IDX i = hstart; i < hend; i += 1) { for (IDX j = wstart; j < wend; j += 1) { const IDX window_idx = i * input_width + j; XPUAdd::Invoke(&grad_delta, &data[window_idx]); // dest[search_idx] += grad_delta } } } } template OF_DEVICE_FUNC void Avgpool3dForwardCompute( const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, t, h, w; index_helper.OffsetToNdIndex(num, n_c, t, h, w); const IDX start_idx = n_c * x_time * x_height * x_width; IDX tstart = t * stride_t - padding_t; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; IDX tend = XPU_INT_MIN(tstart + kernel_size_t, x_time + padding_t); IDX hend = XPU_INT_MIN(hstart + kernel_size_h, x_height + padding_h); IDX wend = XPU_INT_MIN(wstart + kernel_size_w, x_width + padding_w); const IDX pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); tstart = XPU_INT_MAX(IDX(0), tstart); hstart = XPU_INT_MAX(IDX(0), hstart); wstart = XPU_INT_MAX(IDX(0), wstart); tend = XPU_INT_MIN(tend, x_time); hend = XPU_INT_MIN(hend, x_height); wend = XPU_INT_MIN(wend, x_width); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart); } } T sum = 0; const T* data = src + start_idx; for (IDX i = tstart; i < tend; i += 1) { for (IDX j = hstart; j < hend; j += 1) { for (IDX k = wstart; k < wend; k += 1) { const IDX window_idx = i * x_height * x_width + j * x_width + k; sum += data[window_idx]; } } } dest[num] = sum / divide_factor; } } template OF_DEVICE_FUNC void Avgpool3dBackwardCompute( const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, t, h, w; index_helper.OffsetToNdIndex(num, n_c, t, h, w); const IDX start_idx = n_c * x_time * x_width * x_height; 
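// n_c is the fused (batch * channel) index produced by the index helper (the
// grad kernels above collapse dims 0 and 1 of dy into one), so start_idx is
// the offset of this batch/channel slice within dx. Each dy element below
// scatters src[num] / divide_factor into its input window through
// XPUAdd::Invoke, which is a plain `+=` on CPU and an atomic add under
// __CUDA_ARCH__; that is also why the Compute() methods memset dx to zero
// before launching the backward kernels.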
IDX tstart = t * stride_t - padding_t; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; IDX tend = XPU_INT_MIN(tstart + kernel_size_t, x_time + padding_t); IDX hend = XPU_INT_MIN(hstart + kernel_size_h, x_height + padding_h); IDX wend = XPU_INT_MIN(wstart + kernel_size_w, x_width + padding_w); const IDX pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); tstart = XPU_INT_MAX(IDX(0), tstart); hstart = XPU_INT_MAX(IDX(0), hstart); wstart = XPU_INT_MAX(IDX(0), wstart); tend = XPU_INT_MIN(tend, x_time); hend = XPU_INT_MIN(hend, x_height); wend = XPU_INT_MIN(wend, x_width); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart); } } T grad_delta = src[num] / divide_factor; T* data = dest + start_idx; for (IDX i = tstart; i < tend; i += 1) { for (IDX j = hstart; j < hend; j += 1) { for (IDX k = wstart; k < wend; k += 1) { const IDX window_idx = i * x_height * x_width + j * x_width + k; XPUAdd::Invoke(&grad_delta, &data[window_idx]); // dest[search_idx] += grad_delta } } } } } #ifdef WITH_CUDA template struct AvgPoolKernelUtil { static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d); static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d); static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d); static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d); static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d); static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const half* src, half* dest, const AvgPoolParams3D& params_3d); }; template OF_DEVICE_FUNC void HalfAvgpool1dForwardCompute(const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t x_length, const int32_t kernel_size_l, const int32_t stride_l, const bool count_include_pad, const int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, l; index_helper.OffsetToNdIndex(num, n_c, l); const IDX start_idx = n_c * x_length; IDX lstart = l * stride_l - padding_l; IDX lend = XPU_INT_MIN(lstart + kernel_size_l, x_length + padding_l); const IDX pool_size = (lend - lstart); lstart = XPU_INT_MAX(0, lstart); lend = XPU_INT_MIN(lend, x_length); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (lend - lstart); } } float sum = 0; const half* data = src + start_idx; for (IDX idx = lstart; idx < lend; idx += 1) { sum += __half2float(data[idx]); } dest[num] = __float2half(sum / divide_factor); } } template OF_DEVICE_FUNC void HalfAvgpool1dBackwardCompute( const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t 
padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t input_length, const int32_t kernel_size_l, const int32_t stride_l, const bool count_include_pad, const int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, l; index_helper.OffsetToNdIndex(num, n_c, l); const IDX start_idx = n_c * input_length; IDX lstart = l * stride_l - padding_l; IDX lend = XPU_INT_MIN(lstart + kernel_size_l, input_length + padding_l); const IDX pool_size = (lend - lstart); lstart = XPU_INT_MAX(IDX(0), lstart); lend = XPU_INT_MIN(lend, input_length); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (lend - lstart); } } half grad_delta = static_cast(__half2float(src[num]) / divide_factor); half* data = dest + start_idx; for (IDX idx = lstart; idx < lend; idx += 1) { XPUAdd::Invoke(&grad_delta, &data[idx]); } } } template OF_DEVICE_FUNC void HalfAvgpool2dForwardCompute( const NdIndexOffsetHelper index_helper, int64_t elem_num, const half* src, half* dest, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, h, w; index_helper.OffsetToNdIndex(num, n_c, h, w); const IDX start_idx = n_c * x_width * x_height; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; IDX hend = XPU_INT_MIN(hstart + kernel_size_h, x_height + padding_h); IDX wend = XPU_INT_MIN(wstart + kernel_size_w, x_width + padding_w); const IDX pool_size = (hend - hstart) * (wend - wstart); hstart = XPU_INT_MAX(0, hstart); wstart = XPU_INT_MAX(0, wstart); hend = XPU_INT_MIN(hend, x_height); wend = XPU_INT_MIN(wend, x_width); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (hend - hstart) * (wend - wstart); } } float sum = 0; const half* data = src + start_idx; for (int64_t i = hstart; i < hend; i += 1) { for (int64_t j = wstart; j < wend; j += 1) { const IDX window_idx = i * x_width + j; sum += __half2float(data[window_idx]); } } dest[num] = __float2half(sum / divide_factor); } } template OF_DEVICE_FUNC void HalfAvgpool2dBackwardCompute( const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t input_height, const int32_t input_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, h, w; index_helper.OffsetToNdIndex(num, n_c, h, w); const IDX start_idx = n_c * input_width * input_height; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; IDX hend = XPU_INT_MIN(hstart + kernel_size_h, input_height + padding_h); IDX wend = XPU_INT_MIN(wstart + kernel_size_w, input_width + padding_w); const IDX pool_size = (hend - hstart) * (wend - wstart); hstart = XPU_INT_MAX(IDX(0), hstart); wstart = XPU_INT_MAX(IDX(0), wstart); hend = XPU_INT_MIN(hend, input_height); wend = XPU_INT_MIN(wend, input_width); IDX divide_factor; if 
(divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (hend - hstart) * (wend - wstart); } } half grad_delta = static_cast(__half2float(src[num]) / divide_factor); half* data = dest + start_idx; for (IDX i = hstart; i < hend; i += 1) { for (IDX j = wstart; j < wend; j += 1) { const IDX window_idx = i * input_width + j; XPUAdd::Invoke(&grad_delta, &data[window_idx]); } } } } template OF_DEVICE_FUNC void HalfAvgpool3dForwardCompute( const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, t, h, w; index_helper.OffsetToNdIndex(num, n_c, t, h, w); const IDX start_idx = n_c * x_time * x_height * x_width; IDX tstart = t * stride_t - padding_t; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; IDX tend = XPU_INT_MIN(tstart + kernel_size_t, x_time + padding_t); IDX hend = XPU_INT_MIN(hstart + kernel_size_h, x_height + padding_h); IDX wend = XPU_INT_MIN(wstart + kernel_size_w, x_width + padding_w); const IDX pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); tstart = XPU_INT_MAX(IDX(0), tstart); hstart = XPU_INT_MAX(IDX(0), hstart); wstart = XPU_INT_MAX(IDX(0), wstart); tend = XPU_INT_MIN(tend, x_time); hend = XPU_INT_MIN(hend, x_height); wend = XPU_INT_MIN(wend, x_width); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart); } } float sum = 0; const half* data = src + start_idx; for (IDX i = tstart; i < tend; i += 1) { for (IDX j = hstart; j < hend; j += 1) { for (IDX k = wstart; k < wend; k += 1) { const IDX window_idx = i * x_height * x_width + j * x_width + k; sum += __half2float(data[window_idx]); } } } dest[num] = __float2half(sum / divide_factor); } } template OF_DEVICE_FUNC void HalfAvgpool3dBackwardCompute( const NdIndexOffsetHelper index_helper, IDX elem_num, const half* src, half* dest, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, t, h, w; index_helper.OffsetToNdIndex(num, n_c, t, h, w); const IDX start_idx = n_c * x_time * x_width * x_height; IDX tstart = t * stride_t - padding_t; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; IDX tend = XPU_INT_MIN(tstart + kernel_size_t, x_time + padding_t); IDX hend = XPU_INT_MIN(hstart + kernel_size_h, x_height + padding_h); IDX wend = XPU_INT_MIN(wstart + kernel_size_w, x_width + padding_w); const IDX pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); tstart = XPU_INT_MAX(IDX(0), tstart); hstart = 
XPU_INT_MAX(IDX(0), hstart); wstart = XPU_INT_MAX(IDX(0), wstart); tend = XPU_INT_MIN(tend, x_time); hend = XPU_INT_MIN(hend, x_height); wend = XPU_INT_MIN(wend, x_width); IDX divide_factor; if (divisor_override != static_cast(0)) { divide_factor = divisor_override; } else { if (count_include_pad) { divide_factor = pool_size; } else { divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart); } } half grad_delta = static_cast(__half2float(src[num]) / divide_factor); half* data = dest + start_idx; for (IDX i = tstart; i < tend; i += 1) { for (IDX j = hstart; j < hend; j += 1) { for (IDX k = wstart; k < wend; k += 1) { const IDX window_idx = i * x_height * x_width + j * x_width + k; XPUAdd::Invoke(&grad_delta, &data[window_idx]); } } } } } #endif // WITH_CUDA #define INSTANTIATE_AVG_POOL_KERNEL_UTIL(device_type_v, dtype_pair, index_dtype_pair) \ template struct AvgPoolKernelUtil; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_AVG_POOL_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/batch_gather_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/batch_gather_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace user_op { template class BatchGatherKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: BatchGatherKernel() = default; ~BatchGatherKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t axis = indices->shape_view().NumAxes() - 1; const Shape flat_out_shape = Shape({out->shape_view().Count(0, axis), out->shape_view().At(axis), out->shape_view().Count(axis + 1)}); BatchGatherKernelUtilImpl::Forward( ctx->stream(), in->dptr(), indices->dptr(), flat_out_shape, in->shape_view().At(axis), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BATCH_GATHER_KERNEL(device, out_dtype, indices_dtype) \ REGISTER_USER_KERNEL("batch_gather") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_dtype)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BATCH_GATHER_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/batch_gather_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/batch_gather_kernel_util.h" namespace oneflow { namespace { Shape GetFlatShape(const ShapeView& shape, const int64_t axis) { CHECK_GT(shape.NumAxes(), 0); CHECK_GE(axis, 0); CHECK_LT(axis, shape.NumAxes()); return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)}); } template void BatchGatherForward(ep::Stream* stream, const Blob* in, const Blob* indices, Blob* out) { const int64_t axis = indices->shape_view().NumAxes() - 1; const Shape flat_out_shape = GetFlatShape(out->shape_view(), axis); BatchGatherKernelUtilImpl::Forward(stream, in->dptr(), indices->dptr(), flat_out_shape, in->shape_view().At(axis), out->mut_dptr()); } template void BatchGatherBackward(ep::Stream* stream, const Blob* out_diff, const Blob* indices, Blob* in_diff) { Memset(stream, in_diff->mut_dptr(), 0, in_diff->ByteSizeOfBlobBody()); const int64_t axis = indices->shape_view().NumAxes() - 1; const Shape flat_out_diff_shape = GetFlatShape(out_diff->shape_view(), axis); BatchGatherKernelUtilImpl::Backward( stream, out_diff->dptr(), indices->dptr(), flat_out_diff_shape, in_diff->shape_view().At(axis), in_diff->mut_dptr()); } template struct BatchGatherSwitchUtil final { #define MAKE_BATCH_GATHER_SWITCH_ENTRY(func_name, K) func_name #define DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(func_name) \ DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_BATCH_GATHER_SWITCH_ENTRY, \ MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ)); DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(BatchGatherForward); DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(BatchGatherBackward); #undef DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC #undef MAKE_BATCH_GATHER_SWITCH_ENTRY }; } // namespace template void BatchGatherKernelUtil::Forward(ep::Stream* stream, const Blob* in, const Blob* indices, Blob* out) { BatchGatherSwitchUtil::SwitchBatchGatherForward(SwitchCase(indices->data_type()), stream, in, indices, out); } template void BatchGatherKernelUtil::Backward(ep::Stream* stream, const Blob* out_diff, const Blob* indices, Blob* in_diff) { BatchGatherSwitchUtil::SwitchBatchGatherBackward( SwitchCase(indices->data_type()), stream, out_diff, indices, in_diff); } template struct BatchGatherKernelUtilImpl final { static void Forward(ep::Stream* stream, const T* in, const K* indices, const Shape& flat_out_shape, int64_t gather_dim_size, T* out); static void Backward(ep::Stream* stream, const T* out_diff, const K* indices, const Shape& flat_out_diff_shape, int64_t gather_dim_size, T* in_diff); }; template void BatchGatherKernelUtilImpl::Forward(ep::Stream* stream, const T* in, const K* indices, const Shape& flat_out_shape, const int64_t gather_dim_size, T* out) { const int64_t batch_num = flat_out_shape.At(0); const int64_t indices_num = flat_out_shape.At(1); const int64_t instance_size = flat_out_shape.At(2); FOR_RANGE(int64_t, batch_idx, 0, batch_num) { FOR_RANGE(int64_t, i, 0, indices_num) { const K idx = indices[batch_idx * indices_num + i]; CHECK(idx >= 0 && idx < gather_dim_size); const T* from = in + batch_idx * gather_dim_size 
* instance_size + idx * instance_size; T* to = out + batch_idx * indices_num * instance_size + i * instance_size; std::copy(from, from + instance_size, to); } } } template void BatchGatherKernelUtilImpl::Backward( ep::Stream* stream, const T* out_diff, const K* indices, const Shape& flat_out_diff_shape, const int64_t gather_dim_size, T* in_diff) { const int64_t batch_num = flat_out_diff_shape.At(0); const int64_t indices_num = flat_out_diff_shape.At(1); const int64_t instance_size = flat_out_diff_shape.At(2); FOR_RANGE(int64_t, batch_idx, 0, batch_num) { FOR_RANGE(int64_t, i, 0, indices_num) { const int64_t idx = indices[batch_idx * indices_num + i]; CHECK(idx >= 0 && idx < gather_dim_size); const T* from = out_diff + batch_idx * indices_num * instance_size + i * instance_size; T* to = in_diff + batch_idx * gather_dim_size * instance_size + idx * instance_size; std::transform(from, from + instance_size, to, to, std::plus()); } } } #define INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CPU(in_type_pair, index_type_pair) \ template struct BatchGatherKernelUtilImpl; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CPU, FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ); #undef INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CPU #define INSTANTIATE_BATCH_GATHER_KERNEL_UTIL(device_type, in_type_pair) \ template struct BatchGatherKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BATCH_GATHER_KERNEL_UTIL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ); #undef INSTANTIATE_BATCH_GATHER_KERNEL_UTIL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/batch_gather_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/user/kernels/batch_gather_kernel_util.h"
#include "oneflow/core/cuda/atomic.cuh"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include <assert.h>

namespace oneflow {

namespace {

template<typename K>
__device__ int64_t GetInOffset(const int64_t out_offset, const K* indices,
                               const int64_t indices_num, const int64_t instance_size,
                               const int64_t gather_dim_size) {
  // Decompose the flat output offset into (batch, index, inner) coordinates, then map it
  // to the corresponding flat offset in the input.
  const int64_t batch_idx = out_offset / (indices_num * instance_size);
  const int64_t indices_idx = out_offset % (indices_num * instance_size) / instance_size;
  const int64_t inner_idx = out_offset % instance_size;
  const int64_t idx = indices[batch_idx * indices_num + indices_idx];
  assert(idx >= 0 && idx < gather_dim_size);
  return batch_idx * gather_dim_size * instance_size + idx * instance_size + inner_idx;
}

template<typename T, typename K>
__global__ void BatchGatherForwardGpu(const int64_t elem_cnt, const T* in, const K* indices,
                                      const int64_t indices_num, const int64_t instance_size,
                                      const int64_t gather_dim_size, T* out) {
  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
    out[i] = in[GetInOffset<K>(i, indices, indices_num, instance_size, gather_dim_size)];
  }
}

template<typename T, typename K>
__global__ void BatchGatherBackwardGpu(const int64_t elem_cnt, const T* out_diff, const K* indices,
                                       const int64_t indices_num, const int64_t instance_size,
                                       const int64_t gather_dim_size, T* in_diff) {
  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
    // Atomic add: several output positions may gather from the same input row.
    cuda::atomic::Add(
        in_diff + GetInOffset<K>(i, indices, indices_num, instance_size, gather_dim_size),
        out_diff[i]);
  }
}

}  // namespace

template<typename T, typename K>
struct BatchGatherKernelUtilImpl<DeviceType::kCUDA, T, K> final {
  static void Forward(ep::Stream* stream, const T* in, const K* indices,
                      const Shape& flat_out_shape, const int64_t gather_dim_size, T* out);
  static void Backward(ep::Stream* stream, const T* out_diff, const K* indices,
                       const Shape& flat_out_diff_shape, const int64_t gather_dim_size,
                       T* in_diff);
};

template<typename T, typename K>
void BatchGatherKernelUtilImpl<DeviceType::kCUDA, T, K>::Forward(ep::Stream* stream, const T* in,
                                                                 const K* indices,
                                                                 const Shape& flat_out_shape,
                                                                 const int64_t gather_dim_size,
                                                                 T* out) {
  const int64_t batch_num = flat_out_shape.At(0);
  const int64_t indices_num = flat_out_shape.At(1);
  const int64_t instance_size = flat_out_shape.At(2);
  const int64_t elem_cnt = batch_num * indices_num * instance_size;
  BatchGatherForwardGpu<T, K>
      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
         stream->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, in, indices, indices_num,
                                                        instance_size, gather_dim_size, out);
}

template<typename T, typename K>
void BatchGatherKernelUtilImpl<DeviceType::kCUDA, T, K>::Backward(
    ep::Stream* stream, const T* out_diff, const K* indices, const Shape& flat_out_diff_shape,
    const int64_t gather_dim_size, T* in_diff) {
  const int64_t batch_num = flat_out_diff_shape.At(0);
  const int64_t indices_num = flat_out_diff_shape.At(1);
  const int64_t instance_size = flat_out_diff_shape.At(2);
  const int64_t elem_cnt = batch_num * indices_num * instance_size;
  BatchGatherBackwardGpu<T, K>
      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
         stream->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, out_diff, indices, indices_num,
                                                        instance_size, gather_dim_size, in_diff);
}

#define INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CUDA(in_type_pair, index_type_pair)          \
  template struct BatchGatherKernelUtilImpl<DeviceType::kCUDA, OF_PP_PAIR_FIRST(in_type_pair), \
                                            OF_PP_PAIR_FIRST(index_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CUDA,
                                 FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);
#undef INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CUDA

}  // namespace oneflow

================================================
FILE: oneflow/user/kernels/batch_gather_kernel_util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_BATCH_GATHER_KERNEL_UTIL_H_
#define ONEFLOW_USER_KERNELS_BATCH_GATHER_KERNEL_UTIL_H_

#include "oneflow/core/kernel/kernel.h"

namespace oneflow {

template<DeviceType device_type, typename T, typename K>
struct BatchGatherKernelUtilImpl final {
  static void Forward(ep::Stream* stream, const T* in, const K* indices,
                      const Shape& flat_out_shape, int64_t gather_dim_size, T* out);
  static void Backward(ep::Stream* stream, const T* out_diff, const K* indices,
                       const Shape& flat_out_diff_shape, int64_t gather_dim_size, T* in_diff);
};

template<DeviceType device_type, typename T>
struct BatchGatherKernelUtil final {
  static void Forward(ep::Stream* stream, const Blob* in, const Blob* indices, Blob* out);
  static void Backward(ep::Stream* stream, const Blob* out_diff, const Blob* indices,
                       Blob* in_diff);
};

}  // namespace oneflow

#endif  // ONEFLOW_USER_KERNELS_BATCH_GATHER_KERNEL_UTIL_H_

================================================
FILE: oneflow/user/kernels/batch_norm_backward_elemt_kernel.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/user/kernels/batch_norm_kernel_utils.h" // NOTE(Liang Depeng): // The implementation of batch_norm_backward_elemt kernel is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh namespace oneflow { namespace { template __global__ void batch_norm_backward_elemt_kernel( const IDX_TYPE batch_size, const IDX_TYPE channel_size, const IDX_TYPE spatial_size, const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, const T* weight_ptr, const T* sum_dy_ptr, const T* sum_dy_xmu_ptr, T* grad_in_ptr, const int32_t* count_ptr, const int64_t world_size) { int64_t total_numel = 0; for (int i = 0; i < world_size; i++) { total_numel += count_ptr[i]; } const ACC_T norm_fct = static_cast(1) / static_cast(total_numel); IDX_TYPE channel = blockIdx.x; if (channel >= channel_size) { return; } ACC_T m_c = mean_ptr[channel]; ACC_T m_dy_c = sum_dy_ptr[channel] * norm_fct; ACC_T factor_1_c = invstd_ptr[channel]; ACC_T factor_2_c = static_cast(weight_ptr[channel]); factor_2_c *= factor_1_c; factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu_ptr[channel] * norm_fct; IDX_TYPE batch_offset = spatial_size * channel_size; IDX_TYPE channel_offset = channel * spatial_size; IDX_TYPE bstep = blockDim.y * gridDim.y; for (IDX_TYPE batch = threadIdx.y + blockIdx.y * blockDim.y; batch < batch_size; batch += bstep) { IDX_TYPE offset = batch * batch_offset; for (IDX_TYPE feature = threadIdx.x; feature < spatial_size; feature += blockDim.x) { grad_in_ptr[offset + channel_offset + feature] = static_cast((grad_out_ptr[offset + channel_offset + feature] - m_dy_c - (input_ptr[offset + channel_offset + feature] - m_c) * factor_1_c) * factor_2_c); } } } template __global__ void batch_norm_backward_elemt_channels_last_kernel( const T* grad_out_ptr, const T* input_ptr, const ACC_T* mean_ptr, const ACC_T* invstd_ptr, const T* weight_ptr, const ACC_T* sum_dy_ptr, const ACC_T* sum_dy_xmu_ptr, const int32_t* count_ptr, T* grad_in_ptr, const IDX_TYPE world_size, const IDX_TYPE stride, const IDX_TYPE reduction_size) { IDX_TYPE total_numel = 0; for (IDX_TYPE i = 0; i < world_size; i++) { total_numel += count_ptr[i]; } auto norm_fct = static_cast(1) / static_cast(total_numel); // tensor dimension (m,c) // loop along m dimension IDX_TYPE inner_loop_stride = blockDim.y * gridDim.y; // offset along m dimension IDX_TYPE m_offset = blockIdx.y * blockDim.y + threadIdx.y; IDX_TYPE c_offset = blockIdx.x * blockDim.x + threadIdx.x; if (c_offset >= stride || m_offset >= reduction_size) { return; } auto m_c = mean_ptr[c_offset]; auto m_dy_c = sum_dy_ptr[c_offset] * norm_fct; auto factor_1_c = invstd_ptr[c_offset]; auto factor_2_c = (weight_ptr == nullptr ? 
ACC_T(1.0) : static_cast(weight_ptr[c_offset])) * factor_1_c; factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu_ptr[c_offset] * norm_fct; int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); int address_base = m_offset * stride + c_offset; int address_increment = inner_loop_stride * stride; for (int i = 0; i < loop_count; i++) { #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { if (c_offset < stride && m_offset < reduction_size) { grad_in_ptr[address_base] = static_cast((static_cast(grad_out_ptr[address_base]) - m_dy_c - (static_cast(input_ptr[address_base]) - m_c) * factor_1_c) * factor_2_c); } m_offset += inner_loop_stride; address_base += address_increment; } } } template struct BatchNormBackwardElemtFunctor final { void operator()(ep::Stream* stream, const int64_t batch_size, const int64_t channel_size, const int64_t spatial_size, const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, const T* weight_ptr, const T* sum_dy_ptr, const T* sum_dy_xmu_ptr, T* grad_in_ptr, const int32_t* count_ptr, const int64_t world_size) { using ACC_T = acc_type; // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, // weight/bias) - which we only do once and have a for loop afterwards - with having many // threads and blocks and good occupancy. Quiet likely, we could go with even more blocks than // 1024. The various planes are independent, so we use blocks for them. int tf = std::max(getNumThreads(spatial_size / 4), std::min(getNumThreads(spatial_size), 64)); int tb = std::max(64 / tf, 1); dim3 blocks_trans(channel_size, std::max(1, std::min((256 * 1024) / channel_size, (batch_size + tb - 1) / tb))); blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); dim3 threads_trans(tf, tb); if (batch_size * channel_size * spatial_size < std::numeric_limits::max()) { batch_norm_backward_elemt_kernel <<As()->cuda_stream()>>>( static_cast(batch_size), static_cast(channel_size), static_cast(spatial_size), grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_in_ptr, count_ptr, world_size); } else { batch_norm_backward_elemt_kernel <<As()->cuda_stream()>>>( batch_size, channel_size, spatial_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_in_ptr, count_ptr, world_size); } } }; template struct BatchNormBackwardElemtChannelLastFunctor final { void operator()(ep::Stream* stream, const int64_t stride, const int64_t reduction_size, const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, const T* weight_ptr, const T* sum_dy_ptr, const T* sum_dy_xmu_ptr, T* grad_in_ptr, const int32_t* count_ptr, const int64_t world_size) { using ACC_T = acc_type; dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid); if (stride * reduction_size < std::numeric_limits::max()) { batch_norm_backward_elemt_channels_last_kernel <<As()->cuda_stream()>>>( grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, count_ptr, grad_in_ptr, world_size, static_cast(stride), static_cast(reduction_size)); } else { batch_norm_backward_elemt_channels_last_kernel <<As()->cuda_stream()>>>( grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, count_ptr, grad_in_ptr, world_size, stride, reduction_size); } } }; } // namespace template class GpuBatchNormBackwardElemtKernel final : public user_op::OpKernel { public: GpuBatchNormBackwardElemtKernel() = default; 
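  // Summary of the element-wise stage implemented by the two functors above (derived from the
  // factor_1_c / factor_2_c arithmetic): with N = sum(count) over all ranks,
  //   grad_in = (grad_out - sum_dy / N - (x - mean) * invstd^2 * sum_dy_xmu / N) * weight * invstd,
  // i.e. the per-element part of synchronized batch-norm backward.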
~GpuBatchNormBackwardElemtKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* grad_out = ctx->Tensor4ArgNameAndIndex("grad_out", 0); const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex("invstd", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* sum_dy = ctx->Tensor4ArgNameAndIndex("sum_dy", 0); const user_op::Tensor* sum_dy_xmu = ctx->Tensor4ArgNameAndIndex("sum_dy_xmu", 0); const user_op::Tensor* count = ctx->Tensor4ArgNameAndIndex("count", 0); user_op::Tensor* grad_in = ctx->Tensor4ArgNameAndIndex("grad_in", 0); const T* grad_out_ptr = grad_out->dptr(); const T* input_ptr = input->dptr(); const T* mean_ptr = mean->dptr(); const T* invstd_ptr = invstd->dptr(); const T* weight_ptr = weight->dptr(); const T* sum_dy_ptr = sum_dy->dptr(); const T* sum_dy_xmu_ptr = sum_dy_xmu->dptr(); const int32_t* count_ptr = count->dptr(); T* grad_in_ptr = grad_in->mut_dptr(); const int32_t axis = ctx->Attr("axis"); bool use_channels_last_kernel = axis == 1 ? false : true; const int64_t world_size = count->shape_view().elem_cnt(); if (use_channels_last_kernel) { // NHWC format const int64_t stride = input->shape_view().At(axis); const int64_t reduction_size = input->shape_view().elem_cnt() / stride; BatchNormBackwardElemtChannelLastFunctor()( ctx->stream(), stride, reduction_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_in_ptr, count_ptr, world_size); } else { // NCHW format const int64_t batch_size = input->shape_view().At(0); const int64_t channel_size = input->shape_view().At(1); const int64_t spatial_size = input->shape_view().Count(2); BatchNormBackwardElemtFunctor()( ctx->stream(), batch_size, channel_size, spatial_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_in_ptr, count_ptr, world_size); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BATCH_NORM_BACKWARD_ELEMT_KERNEL(dtype) \ REGISTER_USER_KERNEL("batch_norm_backward_elemt") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("grad_out", 0) == GetDataType::value) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("mean", 0) == GetDataType::value) \ && (user_op::HobDataType("invstd", 0) == GetDataType::value) \ && (user_op::HobDataType("weight", 0) == GetDataType::value) \ && (user_op::HobDataType("sum_dy", 0) == GetDataType::value) \ && (user_op::HobDataType("sum_dy_xmu", 0) == GetDataType::value) \ && (user_op::HobDataType("count", 0) == GetDataType::value)) REGISTER_BATCH_NORM_BACKWARD_ELEMT_KERNEL(float); REGISTER_BATCH_NORM_BACKWARD_ELEMT_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/batch_norm_backward_reduce_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/user/kernels/batch_norm_kernel_utils.h" // NOTE(Liang Depeng): // The implementation of batch_norm_backward_reduce kernel is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh namespace oneflow { namespace { template static size_t InferTmpSizeForChannelLastKernel(user_op::InferContext* ctx) { const int32_t axis = ctx->Attr("axis"); const Shape& in_shape = ctx->InputTensorDesc("input", 0).shape(); const int64_t stride = in_shape.At(axis); const int64_t reduction_size = in_shape.elem_cnt() / stride; dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid, true); size_t tmp_size = 0; if (grid.y > 1) { tmp_size += 2 * stride * grid.y * sizeof(T); tmp_size += grid.x * sizeof(int32_t); } return tmp_size; } template struct Float2 { ACC_T v1, v2; __device__ Float2() {} __device__ Float2(T v1, T v2) : v1(static_cast(v1)), v2(static_cast(v2)) {} __device__ Float2(int v) : v1(static_cast(v)), v2(static_cast(v)) {} __device__ Float2& operator+=(const Float2& a) { v1 += a.v1; v2 += a.v2; return *this; } }; // Sum across all threads within a warp template static __device__ __forceinline__ T warpSum_(T val) { for (int i = 0; i < getMSB(WARP_SIZE); ++i) { val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); } return val; } template static __device__ __forceinline__ RES_T warpSum(RES_T value) { value.v1 = warpSum_(value.v1); value.v2 = warpSum_(value.v2); return value; } template __device__ RES_T reduce(const T* input_ptr, const T* grad_out_ptr, ACC_T r_mean, IDX_TYPE channel, IDX_TYPE batch_size, IDX_TYPE channel_size, IDX_TYPE spatial_size) { IDX_TYPE batch_offset = spatial_size * channel_size; IDX_TYPE channel_offset = channel * spatial_size; // first the reductions each thread does separately RES_T sum = static_cast(0); for (int batch = threadIdx.y; batch < batch_size; batch += blockDim.y) { IDX_TYPE offset = batch * batch_offset; for (int x = threadIdx.x; x < spatial_size; x += blockDim.x) { // sum += op(batch, plane, x); ACC_T g = grad_out_ptr[offset + channel_offset + x]; ACC_T c = static_cast(input_ptr[offset + channel_offset + x]) - r_mean; sum.v1 += g; sum.v2 += g * c; } } // first warpSum to get one value per thread to // one value per warp sum = warpSum(sum); // this writes each warps item into shared memory // there are at most WARP_SIZE items left because // there are at most WARP_SIZE**2 threads at the beginning __shared__ RES_T shared[WARP_SIZE]; __syncthreads(); int tid = threadIdx.x + threadIdx.y * blockDim.x; if (tid % WARP_SIZE == 0) { shared[tid / WARP_SIZE] = sum; } if (tid >= blockDim.x * blockDim.y / WARP_SIZE && tid < WARP_SIZE) { // zero out the other entries 
in shared shared[tid] = (RES_T)0; } __syncthreads(); // now have a second warpSum to reduce the intermediate values // from shared memory to a single number. The very first // thread writes it to shared memory. if (tid / WARP_SIZE == 0) { sum = warpSum(shared[tid]); if (tid == 0) { shared[0] = sum; } } __syncthreads(); // Everyone picks it up, should be broadcast into the whole grad_input return shared[0]; } template __global__ void batch_norm_backward_reduce_kernel( const IDX_TYPE batch_size, const IDX_TYPE channel_size, const IDX_TYPE spatial_size, const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, T* sum_dy_ptr, T* sum_dy_xmu_ptr, T* grad_weight_ptr, T* grad_bias_ptr) { IDX_TYPE channel = blockIdx.x; ACC_T r_mean = mean_ptr[channel]; ACC_T factor = invstd_ptr[channel]; auto res = reduce, T, ACC_T, IDX_TYPE>(input_ptr, grad_out_ptr, r_mean, channel, batch_size, channel_size, spatial_size); if (threadIdx.x == 0) { if (grad_weight_ptr != nullptr) { grad_weight_ptr[channel] = static_cast(res.v2 * factor); } if (grad_bias_ptr != nullptr) { grad_bias_ptr[channel] = static_cast(res.v1); } if (sum_dy_ptr != nullptr) { sum_dy_ptr[channel] = static_cast(res.v1); } if (sum_dy_xmu_ptr != nullptr) { sum_dy_xmu_ptr[channel] = static_cast(res.v2); } } } template __device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy, T& sum_dy_xmu, T* shmem_sum_dy, T* shmem_sum_dy_xmu) { // write to shared memory auto address_base = threadIdx.x + threadIdx.y * blockDim.x; #pragma unroll for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { if (threadIdx.y < offset * 2) { shmem_sum_dy[address_base] = sum_dy; shmem_sum_dy_xmu[address_base] = sum_dy_xmu; } __syncthreads(); if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { auto address = address_base + offset * blockDim.x; sum_dy += shmem_sum_dy[address]; sum_dy_xmu += shmem_sum_dy_xmu[address]; } } } template __global__ void batch_norm_backward_reduce_channels_last_kernel( const T* __restrict__ grad_output_ptr, const T* __restrict__ input_ptr, const ACC_T* __restrict__ mean_ptr, const ACC_T* __restrict__ inv_std_ptr, ACC_T* __restrict__ sum_dy_o_ptr, ACC_T* __restrict__ sum_dy_xmu_o_ptr, T* __restrict__ grad_weight_ptr, T* __restrict__ grad_bias_ptr, volatile ACC_T* staging_data_ptr, int32_t* semaphores_ptr, const IDX_TYPE reduction_size, const IDX_TYPE stride) { // hide latency with concurrency ACC_T sum_dy[PARALLEL_LOADS]; ACC_T sum_dy_xmu[PARALLEL_LOADS]; #pragma unroll for (int i = 0; i < PARALLEL_LOADS; i++) { sum_dy[i] = ACC_T(0); sum_dy_xmu[i] = ACC_T(0); } // tensor dimension (m,c) // loop along m dimension int inner_loop_stride = blockDim.y * gridDim.y; // offset along m dimension int m_offset = blockIdx.y * blockDim.y + threadIdx.y; int c_offset = blockIdx.x * blockDim.x + threadIdx.x; if (c_offset >= stride || m_offset >= reduction_size) { return; } int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); int address_base = m_offset * stride + c_offset; int address_increment = inner_loop_stride * stride; auto r_mean = mean_ptr[c_offset]; auto factor = inv_std_ptr[c_offset]; for (int i = 0; i < loop_count; i++) { ACC_T x_input[PARALLEL_LOADS]; ACC_T x_grad_output[PARALLEL_LOADS]; // load multiple data in #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { if (c_offset < stride && m_offset < reduction_size) { x_input[j] = input_ptr[address_base]; x_grad_output[j] = grad_output_ptr[address_base]; } else { x_input[j] = ACC_T(0); x_grad_output[j] = ACC_T(0); } 
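      // Out-of-range lanes load zeros above, so the accumulation below needs no extra
      // bounds checks inside the PARALLEL_LOADS loop.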
m_offset += inner_loop_stride; address_base += address_increment; } // calculate sum_dy / sum_dy_xmu #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { sum_dy[j] += x_grad_output[j]; sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); } } // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS #pragma unroll for (int j = 1; j < PARALLEL_LOADS; j++) { sum_dy[0] += sum_dy[j]; sum_dy_xmu[0] += sum_dy_xmu[j]; } // release array of registers auto sum_dy_th = sum_dy[0]; auto sum_dy_xmu_th = sum_dy_xmu[0]; // block-wise reduction with shared memory (since reduction cannot be done within a warp) static __shared__ ACC_T shmem_sum_dy[MAX_BLOCK_SIZE]; static __shared__ ACC_T shmem_sum_dy_xmu[MAX_BLOCK_SIZE]; merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); if (gridDim.y > 1) { volatile ACC_T* staging_sum_dy = staging_data_ptr; volatile ACC_T* staging_sum_dy_xmu = &staging_data_ptr[stride * gridDim.y]; address_base = c_offset + blockIdx.y * stride; // write data to staging_data; if (threadIdx.y == 0 && c_offset < stride) { staging_sum_dy[address_base] = sum_dy_th; staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; } __threadfence(); __syncthreads(); // ensuring writes to staging_ is visible to all blocks __shared__ bool is_last_block_done; // mark block done if (threadIdx.x == 0 && threadIdx.y == 0) { int old = atomicAdd(&semaphores_ptr[blockIdx.x], 1); is_last_block_done = (old == (gridDim.y - 1)); } __syncthreads(); // check that all data is now available in global memory if (is_last_block_done) { sum_dy_th = ACC_T(0.0); sum_dy_xmu_th = ACC_T(0.0); for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { address_base = c_offset + y * stride; sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : ACC_T(0.0)); sum_dy_xmu_th += (c_offset < stride ? 
staging_sum_dy_xmu[address_base] : ACC_T(0.0)); } merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); if (threadIdx.y == 0 && c_offset < stride) { if (grad_bias_ptr != nullptr) { grad_bias_ptr[c_offset] = static_cast(sum_dy_th); } if (grad_weight_ptr != nullptr) { grad_weight_ptr[c_offset] = static_cast(sum_dy_xmu_th * factor); } sum_dy_o_ptr[c_offset] = sum_dy_th; sum_dy_xmu_o_ptr[c_offset] = sum_dy_xmu_th; } } } else { if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { if (grad_bias_ptr != nullptr) { grad_bias_ptr[c_offset] = static_cast(sum_dy_th); } if (grad_weight_ptr != nullptr) { grad_weight_ptr[c_offset] = static_cast(sum_dy_xmu_th * factor); } sum_dy_o_ptr[c_offset] = sum_dy_th; sum_dy_xmu_o_ptr[c_offset] = sum_dy_xmu_th; } } } template struct BatchNormBackwardReduceFunctor final { void operator()(ep::Stream* stream, const int64_t batch_size, const int64_t channel_size, const int64_t spatial_size, const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, T* sum_dy_ptr, T* sum_dy_xmu_ptr, T* grad_weight_ptr, T* grad_bias_ptr) { using ACC_T = acc_type; int block_y = std::min(lastPow2(batch_size), MAX_BLOCK_SIZE / WARP_SIZE); // We want block_x to be at least a warp width int block_x = std::min(std::max(getNumThreads(spatial_size), WARP_SIZE), MAX_BLOCK_SIZE / block_y); const dim3 block(block_x, block_y); const dim3 grid(channel_size); if (batch_size * channel_size * spatial_size < std::numeric_limits::max()) { batch_norm_backward_reduce_kernel <<As()->cuda_stream()>>>( static_cast(batch_size), static_cast(channel_size), static_cast(spatial_size), grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr); } else { batch_norm_backward_reduce_kernel <<As()->cuda_stream()>>>( batch_size, channel_size, spatial_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr); } } }; template struct BatchNormBackwardReduceChannelLastFunctor final { void operator()(ep::Stream* stream, const int64_t stride, const int64_t reduction_size, const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, T* sum_dy_ptr, T* sum_dy_xmu_ptr, T* grad_weight_ptr, T* grad_bias_ptr, user_op::Tensor* tmp_buffer) { using ACC_T = acc_type; dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid, true); T* staging_data_ptr = nullptr; int32_t* semaphores_ptr = nullptr; if (grid.y > 1) { staging_data_ptr = tmp_buffer->mut_dptr(); semaphores_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + 2 * stride * grid.y * sizeof(T)); } if (stride * reduction_size < std::numeric_limits::max()) { batch_norm_backward_reduce_channels_last_kernel <<As()->cuda_stream()>>>( grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr, staging_data_ptr, semaphores_ptr, static_cast(reduction_size), static_cast(stride)); } else { batch_norm_backward_reduce_channels_last_kernel <<As()->cuda_stream()>>>( grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr, staging_data_ptr, semaphores_ptr, reduction_size, stride); } } }; } // namespace template class GpuBatchNormBackwardReduceKernel final : public user_op::OpKernel { public: GpuBatchNormBackwardReduceKernel() = default; ~GpuBatchNormBackwardReduceKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { 
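    // Reduces grad_out per channel into sum_dy / sum_dy_xmu (plus grad_weight and grad_bias);
    // the "axis" attribute selects the channels-last (NHWC) path when it is not 1 and the
    // NCHW path otherwise.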
const user_op::Tensor* grad_out = ctx->Tensor4ArgNameAndIndex("grad_out", 0); const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex("invstd", 0); user_op::Tensor* sum_dy = ctx->Tensor4ArgNameAndIndex("sum_dy", 0); user_op::Tensor* sum_dy_xmu = ctx->Tensor4ArgNameAndIndex("sum_dy_xmu", 0); user_op::Tensor* grad_weight = ctx->Tensor4ArgNameAndIndex("grad_weight", 0); user_op::Tensor* grad_bias = ctx->Tensor4ArgNameAndIndex("grad_bias", 0); const T* grad_out_ptr = grad_out->dptr(); const T* input_ptr = input->dptr(); const T* mean_ptr = mean->dptr(); const T* invstd_ptr = invstd->dptr(); T* sum_dy_ptr = sum_dy->mut_dptr(); T* sum_dy_xmu_ptr = sum_dy_xmu->mut_dptr(); T* grad_weight_ptr = grad_weight->mut_dptr(); T* grad_bias_ptr = grad_bias->mut_dptr(); const int32_t axis = ctx->Attr("axis"); bool use_channels_last_kernel = axis == 1 ? false : true; if (use_channels_last_kernel) { // NHWC format const int64_t stride = input->shape_view().At(axis); const int64_t reduction_size = input->shape_view().elem_cnt() / stride; user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); BatchNormBackwardReduceChannelLastFunctor()( ctx->stream(), stride, reduction_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr, tmp_buffer); } else { // NCHW format const int64_t batch_size = input->shape_view().At(0); const int64_t channel_size = input->shape_view().At(1); const int64_t spatial_size = input->shape_view().Count(2); BatchNormBackwardReduceFunctor()(ctx->stream(), batch_size, channel_size, spatial_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BATCH_NORM_BACKWARD_REDUCE_KERNEL(dtype) \ REGISTER_USER_KERNEL("batch_norm_backward_reduce") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("grad_out", 0) == GetDataType::value) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("mean", 0) == GetDataType::value) \ && (user_op::HobDataType("invstd", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferTmpSizeForChannelLastKernel) REGISTER_BATCH_NORM_BACKWARD_REDUCE_KERNEL(float); REGISTER_BATCH_NORM_BACKWARD_REDUCE_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/batch_norm_elemt_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/user/kernels/batch_norm_kernel_utils.h" // NOTE(Liang Depeng): // The implementation of batch_norm_elemt kernel is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh namespace oneflow { namespace { template __global__ void batch_norm_transform_input_channels_last_kernel( const T* __restrict__ input_ptr, const T* __restrict__ mean_ptr, const T* __restrict__ inv_std_ptr, const T* __restrict__ weight_ptr, const T* __restrict__ bias_ptr, T* __restrict__ out_ptr, const IDX_TYPE reduction_size, const IDX_TYPE stride) { // tensor dimension (m,c) // loop along m dimension IDX_TYPE inner_loop_stride = blockDim.y * gridDim.y; // offset along m dimension IDX_TYPE m_offset = blockIdx.y * blockDim.y + threadIdx.y; IDX_TYPE c_offset = blockIdx.x * blockDim.x + threadIdx.x; if (c_offset >= stride || m_offset >= reduction_size) { return; } auto m_c = mean_ptr[c_offset]; auto inv_std_c = static_cast(inv_std_ptr[c_offset]); auto w_c = weight_ptr == nullptr ? T(1.0) : static_cast(weight_ptr[c_offset]); auto b_c = bias_ptr == nullptr ? T(0.0) : static_cast(bias_ptr[c_offset]); IDX_TYPE loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); IDX_TYPE address_base = m_offset * stride + c_offset; IDX_TYPE address_increment = inner_loop_stride * stride; for (IDX_TYPE i = 0; i < loop_count; i++) { #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { if (c_offset < stride && m_offset < reduction_size) { out_ptr[address_base] = static_cast(w_c * (static_cast(input_ptr[address_base]) - m_c) * inv_std_c + b_c); } m_offset += inner_loop_stride; address_base += address_increment; } } } template __global__ void batch_norm_transform_input_kernel(const IDX_TYPE batch_size, const IDX_TYPE channel_size, const IDX_TYPE spatial_size, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, const T* weight_ptr, const T* bias_ptr, T* output_ptr) { IDX_TYPE channel = blockIdx.x; IDX_TYPE channel_offset = channel * spatial_size; IDX_TYPE batch_step = channel_size * spatial_size; if (channel >= channel_size) { return; } T gamma = static_cast(weight_ptr[channel]); T beta = static_cast(bias_ptr[channel]); T mean = static_cast(mean_ptr[channel]); T invstd = invstd_ptr[channel]; IDX_TYPE bstep = blockDim.y * gridDim.y; for (IDX_TYPE batch = threadIdx.y + blockIdx.y * blockDim.y; batch < batch_size; batch += bstep) { IDX_TYPE offset = batch * batch_step + channel_offset; for (IDX_TYPE feature = threadIdx.x; feature < spatial_size; feature += blockDim.x) { output_ptr[offset + feature] = static_cast(gamma * (input_ptr[offset + feature] - mean) * invstd + beta); } } } template struct BatchNormElemtFunctor final { void operator()(ep::Stream* stream, const int64_t batch_size, const int64_t channel_size, const int64_t spatial_size, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, const T* weight_ptr, const T* bias_ptr, T* output_ptr) { // The input_transform kernel is pointwise, but we need to balance reading parameters // (save_var/mean, weight/bias) - which we only do once and have a for 
loop afterwards - with // having many threads and blocks and good occupancy. Quiet likely, we could go with even more // blocks than 1024. The various planes are independent, so we use blocks for them. int tf = std::max(getNumThreads(spatial_size / 4), std::min(getNumThreads(spatial_size), 64)); int tb = std::max(64 / tf, 1); dim3 blocks_trans(channel_size, std::max(1, std::min((256 * 1024) / channel_size, (batch_size + tb - 1) / tb))); blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); dim3 threads_trans(tf, tb); if (batch_size * channel_size * spatial_size < std::numeric_limits::max()) { batch_norm_transform_input_kernel <<As()->cuda_stream()>>>( static_cast(batch_size), static_cast(channel_size), static_cast(spatial_size), input_ptr, mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr); } else { batch_norm_transform_input_kernel <<As()->cuda_stream()>>>( batch_size, channel_size, spatial_size, input_ptr, mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr); } } }; template struct BatchNormElemtChannelLastFunctor final { void operator()(ep::Stream* stream, const int64_t stride, const int64_t reduction_size, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, const T* weight_ptr, const T* bias_ptr, T* output_ptr) { dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid); if (reduction_size * stride < std::numeric_limits::max()) { batch_norm_transform_input_channels_last_kernel <<As()->cuda_stream()>>>( input_ptr, mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr, static_cast(reduction_size), static_cast(stride)); } else { batch_norm_transform_input_channels_last_kernel <<As()->cuda_stream()>>>( input_ptr, mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr, reduction_size, stride); } } }; } // namespace template class GpuBatchNormElemtKernel final : public user_op::OpKernel { public: GpuBatchNormElemtKernel() = default; ~GpuBatchNormElemtKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex("invstd", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); const T* input_ptr = input->dptr(); const T* mean_ptr = mean->dptr(); const T* invstd_ptr = invstd->dptr(); const T* weight_ptr = weight->dptr(); const T* bias_ptr = bias->dptr(); T* output_ptr = output->mut_dptr(); const int32_t axis = ctx->Attr("axis"); bool use_channels_last_kernel = axis == 1 ? 
false : true; if (use_channels_last_kernel) { // NHWC format const int64_t stride = input->shape_view().At(axis); const int64_t reduction_size = input->shape_view().elem_cnt() / stride; BatchNormElemtChannelLastFunctor()(ctx->stream(), stride, reduction_size, input_ptr, mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr); } else { // NCHW format const int64_t batch_size = input->shape_view().At(0); const int64_t channel_size = input->shape_view().At(1); const int64_t spatial_size = input->shape_view().Count(2); BatchNormElemtFunctor()(ctx->stream(), batch_size, channel_size, spatial_size, input_ptr, mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BATCH_NORM_ELEMT_KERNEL(dtype) \ REGISTER_USER_KERNEL("batch_norm_elemt") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("mean", 0) == GetDataType::value) \ && (user_op::HobDataType("invstd", 0) == GetDataType::value) \ && (user_op::HobDataType("weight", 0) == GetDataType::value) \ && (user_op::HobDataType("bias", 0) == GetDataType::value)) REGISTER_BATCH_NORM_ELEMT_KERNEL(float); REGISTER_BATCH_NORM_ELEMT_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/batch_norm_gather_stats_with_counts_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/user/kernels/batch_norm_kernel_utils.h" // NOTE(Liang Depeng): // The implementation of batch_norm_gather_stats_with_counts kernel is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh namespace oneflow { namespace { template __global__ void batch_norm_reduce_statistics_kernel(const int64_t world_size, const int64_t channel_size, const T* mean_ptr, const T* invstd_ptr, const T* counts_ptr, T* global_mean_ptr, T* global_invstd_ptr, T* running_mean_ptr, T* running_var_ptr, const float eps, const float momentum) { IDX_TYPE bid = blockIdx.x; IDX_TYPE tid = threadIdx.x; // first the reductions each thread does separately for (IDX_TYPE i = bid * blockDim.x + tid; i < channel_size; i += gridDim.x * blockDim.x) { ACC_T avg = 0; ACC_T var_n = 0; IDX_TYPE n = 0; for (IDX_TYPE j = 0; j < world_size; j++) { T count = counts_ptr[j]; ACC_T m = mean_ptr[j * channel_size + i]; ACC_T v = ACC_T(1.0) / (invstd_ptr[j * channel_size + i]); v = (v * v - eps) * count; ACC_T factor = 1.0 / (n + count); var_n += v + (avg - m) * (avg - m) * n * count * factor; avg = n * factor * avg + count * factor * m; n += count; } global_mean_ptr[i] = avg; global_invstd_ptr[i] = static_cast(1) / device_sqrt(var_n / n + eps); if (running_mean_ptr != nullptr) { running_mean_ptr[i] = static_cast((1 - momentum) * running_mean_ptr[i] + momentum * avg); } ACC_T unbiasedVar = var_n / (n - 1); if (running_var_ptr != nullptr) { running_var_ptr[i] = static_cast((1 - momentum) * running_var_ptr[i] + momentum * unbiasedVar); } } } template struct BatchNormGatherStatsWithCountsFunctor final { void operator()(ep::Stream* stream, const int64_t world_size, const int64_t channel_size, const T* mean_ptr, const T* invstd_ptr, const T* counts_ptr, T* global_mean_ptr, T* global_invstd_ptr, T* running_mean_ptr, T* running_var_ptr, const float eps, const float momentum) { using ACC_T = acc_type; int32_t block = getNumThreads(channel_size); int32_t grid = std::max(1, channel_size / block); if (world_size * channel_size < std::numeric_limits::max()) { batch_norm_reduce_statistics_kernel <<As()->cuda_stream()>>>( static_cast(world_size), static_cast(channel_size), mean_ptr, invstd_ptr, counts_ptr, global_mean_ptr, global_invstd_ptr, running_mean_ptr, running_var_ptr, eps, momentum); } else { batch_norm_reduce_statistics_kernel <<As()->cuda_stream()>>>( world_size, channel_size, mean_ptr, invstd_ptr, counts_ptr, global_mean_ptr, global_invstd_ptr, running_mean_ptr, running_var_ptr, eps, momentum); } } }; } // namespace template class GpuBatchNormGatherStatsWithCountsKernel final : public user_op::OpKernel { public: GpuBatchNormGatherStatsWithCountsKernel() = default; ~GpuBatchNormGatherStatsWithCountsKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex("invstd", 0); const 
user_op::Tensor* counts = ctx->Tensor4ArgNameAndIndex("counts", 0); user_op::Tensor* global_mean = ctx->Tensor4ArgNameAndIndex("global_mean", 0); user_op::Tensor* global_invstd = ctx->Tensor4ArgNameAndIndex("global_invstd", 0); const T* mean_ptr = mean->dptr(); const T* invstd_ptr = invstd->dptr(); const T* counts_ptr = counts->dptr(); T* global_mean_ptr = global_mean->mut_dptr(); T* global_invstd_ptr = global_invstd->mut_dptr(); T* running_mean_ptr = nullptr; T* running_var_ptr = nullptr; if (ctx->has_input("running_mean", 0)) { CHECK(ctx->has_input("running_var", 0)); running_mean_ptr = ctx->Tensor4ArgNameAndIndex("running_mean", 0)->mut_dptr(); running_var_ptr = ctx->Tensor4ArgNameAndIndex("running_var", 0)->mut_dptr(); } const float eps = ctx->Attr("eps"); const float momentum = ctx->Attr("momentum"); const int64_t world_size = mean->shape_view().At(0); const int64_t channel_size = mean->shape_view().At(1); BatchNormGatherStatsWithCountsFunctor()( ctx->stream(), world_size, channel_size, mean_ptr, invstd_ptr, counts_ptr, global_mean_ptr, global_invstd_ptr, running_mean_ptr, running_var_ptr, eps, momentum); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BATCH_NORM_GATHER_STATS_WITH_COUNTS_KERNEL(dtype) \ REGISTER_USER_KERNEL("batch_norm_gather_stats_with_counts") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("mean", 0) == GetDataType::value) \ && (user_op::HobDataType("invstd", 0) == GetDataType::value) \ && (user_op::HobDataType("counts", 0) == GetDataType::value)) REGISTER_BATCH_NORM_GATHER_STATS_WITH_COUNTS_KERNEL(float); REGISTER_BATCH_NORM_GATHER_STATS_WITH_COUNTS_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/batch_norm_kernel_utils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_USER_KERNELS_BATCH_NORM_UTILS_H_
#define ONEFLOW_USER_KERNELS_BATCH_NORM_UTILS_H_

// NOTE(Liang Depeng):
// Modified from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh

#if defined(__CUDACC__)

constexpr int ELEMENTS_PER_ITER = 4;  // enables concurrency within each thread to hide latency
constexpr int ELEMENTS_PER_THREAD = 16;
constexpr int OPTIMAL_TILE_W = 32;
constexpr int MAX_H_BLOCK = 128;
constexpr int32_t MAX_BLOCK_SIZE = 512;
constexpr unsigned MAX_GRID_SIZE = 65535u;

#define WARP_SIZE 32

// returns 2**floor(log2(n))
static int lastPow2(unsigned int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

/** Computes ceil(a / b) */
template<typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
static T ceil_div(T a, T b) {
  return (a + b - 1) / b;
}

static void flexible_launch_configs(const int reduction, const int stride, dim3& block, dim3& grid,
                                    const bool coop_flag = false) {
  int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
  int block_y =
      std::min(lastPow2(ceil_div(reduction, ELEMENTS_PER_THREAD)), MAX_BLOCK_SIZE / block_x);
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
  }

  int grid_x = ceil_div(stride, block_x);
  int grid_y = std::min(ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }

  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}

template<typename T>
struct AccumulateType {};

template<>
struct AccumulateType<float> {
  using type = float;
};

template<>
struct AccumulateType<double> {
  using type = double;
};

template<typename T>
using acc_type = typename AccumulateType<T>::type;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int32_t getNumThreads(int64_t nElem) {
  int32_t threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
  for (int32_t i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) { return threadSizes[i]; }
  }
  return MAX_BLOCK_SIZE;
}

template<typename T>
static __forceinline__ __device__ T device_sqrt(T val);

template<>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

template<>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}

template<typename T>
__device__ __forceinline__ T inv_std(T var, double eps) {
  T invstd = 0;
  if (var != static_cast<T>(0) || eps != static_cast<T>(0)) {
    invstd = static_cast<T>(1) / device_sqrt(var + eps);
  }
  return invstd;
}

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int32_t getMSB(int32_t val) { return 31 - __clz(val); }

template<typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
                                           unsigned int mask = 0xffffffff) {
  return __shfl_xor_sync(mask, value, laneMask, width);
}

#endif

#endif  // ONEFLOW_USER_KERNELS_BATCH_NORM_UTILS_H_

================================================
FILE: oneflow/user/kernels/batch_norm_stats_kernel.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/user/kernels/batch_norm_kernel_utils.h" // NOTE(Liang Depeng): // The implementation of batch_norm_stats kernel is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh namespace oneflow { namespace { template static size_t InferTmpSizeForChannelLastKernel(user_op::InferContext* ctx) { const int32_t axis = ctx->Attr("axis"); const Shape& in_shape = ctx->InputTensorDesc("input", 0).shape(); const int64_t stride = in_shape.At(axis); const int64_t reduction_size = in_shape.elem_cnt() / stride; dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid, true); size_t tmp_size = 0; if (grid.y > 1) { tmp_size += 4 * stride * grid.y * sizeof(T); tmp_size += grid.x * sizeof(int32_t); } return tmp_size; } template __device__ __forceinline__ void welford_merge_element(C& count, T& mean, T& m2n, const C& count_new, const T& mean_new, const T& m2n_new) { T factor = T(1.0) / ::max(C(1), (count + count_new)); T delta0 = mean - mean_new; mean = (mean_new * count_new + mean * count) * factor; m2n += m2n_new + delta0 * delta0 * count_new * count * factor; count += count_new; } // merge mean/m2n among threadIdx.y within block template __device__ __forceinline__ void welford_merge_block_vertical(C& count, T& mean, T& m2n, C* shmem_count, T* shmem_mean, T* shmem_m2n) { // write to shared memory auto address_base = threadIdx.x + threadIdx.y * blockDim.x; #pragma unroll for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { if (threadIdx.y < offset * 2) { shmem_mean[address_base] = mean; shmem_m2n[address_base] = m2n; shmem_count[address_base] = count; } __syncthreads(); if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { auto address = address_base + offset * blockDim.x; // read shared memory back to register for reduction auto count_new = shmem_count[address]; auto mean_new = shmem_mean[address]; auto m2n_new = shmem_m2n[address]; welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new); } } } template __global__ void batch_norm_collect_statistics_channels_last_kernel( const T* __restrict__ input_ptr, ACC_T* __restrict__ out_mean_ptr, ACC_T* __restrict__ out_invstd_ptr, volatile ACC_T* staging_data_ptr, int32_t* semaphores_ptr, const IDX_TYPE reduction_size, const IDX_TYPE stride, ACC_T epsilon) { // hide latency with concurrency ACC_T x_mean[PARALLEL_LOADS]; ACC_T m_2_n[PARALLEL_LOADS]; IDX_TYPE count[PARALLEL_LOADS]; #pragma unroll for (IDX_TYPE i = 0; i < PARALLEL_LOADS; i++) { x_mean[i] = ACC_T(0); m_2_n[i] = ACC_T(0); count[i] = ACC_T(0); } // tensor dimension (m,c) // loop along m dimension IDX_TYPE inner_loop_stride = blockDim.y * gridDim.y; // offset along 
m dimension IDX_TYPE m_offset = blockIdx.y * blockDim.y + threadIdx.y; IDX_TYPE c_offset = blockIdx.x * blockDim.x + threadIdx.x; IDX_TYPE loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); IDX_TYPE address_base = m_offset * stride + c_offset; IDX_TYPE address_increment = inner_loop_stride * stride; for (IDX_TYPE i = 0; i < loop_count; i++) { ACC_T x_math[PARALLEL_LOADS]; ACC_T x_count_inv[PARALLEL_LOADS]; ACC_T is_valid[PARALLEL_LOADS]; // load multiple data in #pragma unroll for (IDX_TYPE j = 0; j < PARALLEL_LOADS; j++) { if (c_offset < stride && m_offset < reduction_size) { x_math[j] = input_ptr[address_base]; count[j]++; x_count_inv[j] = ACC_T(1) / count[j]; is_valid[j] = ACC_T(1); } else { x_math[j] = ACC_T(0); x_count_inv[j] = ACC_T(0); is_valid[j] = ACC_T(0); } m_offset += inner_loop_stride; address_base += address_increment; } // calculate mean/m2n with welford #pragma unroll for (IDX_TYPE j = 0; j < PARALLEL_LOADS; j++) { ACC_T delta0 = x_math[j] - x_mean[j]; x_mean[j] += delta0 * x_count_inv[j]; ACC_T delta1 = x_math[j] - x_mean[j]; m_2_n[j] += delta0 * delta1 * is_valid[j]; } } // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS #pragma unroll for (IDX_TYPE j = 1; j < PARALLEL_LOADS; j++) { welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); } // release x_mean / m_2_n auto mean_th = x_mean[0]; auto m2_th = m_2_n[0]; auto count_th = count[0]; // block-wise reduction with shared memory (since reduction cannot be done within a warp) static __shared__ ACC_T shmem_mean[MAX_BLOCK_SIZE]; static __shared__ ACC_T shmem_m2n[MAX_BLOCK_SIZE]; static __shared__ IDX_TYPE shmem_count[MAX_BLOCK_SIZE]; welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); if (gridDim.y > 1) { volatile ACC_T* staging_mean = staging_data_ptr; volatile ACC_T* staging_m2n = &staging_data_ptr[stride * gridDim.y]; volatile IDX_TYPE* staging_count = reinterpret_cast(&staging_m2n[stride * gridDim.y]); address_base = c_offset + blockIdx.y * stride; // write data to staging_data_ptr; if (threadIdx.y == 0 && c_offset < stride) { staging_mean[address_base] = mean_th; staging_m2n[address_base] = m2_th; staging_count[address_base] = count_th; } __threadfence(); __syncthreads(); // ensuring writes to staging_ is visible to all blocks __shared__ bool is_last_block_done; // mark block done if (threadIdx.x == 0 && threadIdx.y == 0) { IDX_TYPE old = atomicAdd(&semaphores_ptr[blockIdx.x], 1); is_last_block_done = (old == (gridDim.y - 1)); } __syncthreads(); // check that all data is now available in global memory if (is_last_block_done) { count_th = 0; mean_th = ACC_T(0.0); m2_th = ACC_T(0.0); for (IDX_TYPE y = threadIdx.y; y < gridDim.y; y += blockDim.y) { address_base = c_offset + y * stride; IDX_TYPE count_new = c_offset < stride ? staging_count[address_base] : 0; ACC_T mean_new = c_offset < stride ? staging_mean[address_base] : ACC_T(0.0); ACC_T m2n_new = c_offset < stride ? 
staging_m2n[address_base] : ACC_T(0.0); welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new); } welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); if (threadIdx.y == 0 && c_offset < stride) { out_mean_ptr[c_offset] = static_cast(mean_th); out_invstd_ptr[c_offset] = inv_std(m2_th / count_th, epsilon); } } } else { if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { out_mean_ptr[c_offset] = static_cast(mean_th); out_invstd_ptr[c_offset] = inv_std(m2_th / count_th, epsilon); } } } template __global__ void batch_norm_collect_statistics_kernel(const T* input_ptr, const IDX_TYPE batch_size, const IDX_TYPE channel_size, const IDX_TYPE spatial_size, const ACC_T eps, T* mean_ptr, T* invstd_ptr) { __shared__ IDX_TYPE shared_n[2 * 2 * WARP_SIZE + WARP_SIZE]; IDX_TYPE channel_idx = blockIdx.x; IDX_TYPE N = batch_size * spatial_size; IDX_TYPE tid = threadIdx.x + threadIdx.y * blockDim.x; // Compute the mean and variance across (batch, x/y/z) // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block) // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm // and the parallel algorithm on the same page. // We use two shuffles to reduce across the entire block. // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description. ACC_T* shared_avg_var = (ACC_T*)&shared_n[WARP_SIZE]; // first the reductions each thread does separately ACC_T avg = 0; ACC_T var_n = 0; IDX_TYPE n = 0; const IDX_TYPE channel_offset = channel_idx * spatial_size; const IDX_TYPE batch_offset = channel_size * spatial_size; for (IDX_TYPE batch = threadIdx.y; batch < batch_size; batch += blockDim.y) { IDX_TYPE offset = batch * batch_offset + channel_offset; for (IDX_TYPE x = threadIdx.x; x < spatial_size; x += blockDim.x) { ACC_T v = input_ptr[offset + x]; ACC_T d1 = v - avg; n++; avg += d1 / n; var_n += d1 * (v - avg); } } // summing the result of all the threads within a warp // refer to: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm // first warpSum to get one value per thread to one value per warp for (IDX_TYPE i = 0; i < getMSB(WARP_SIZE); ++i) { ACC_T o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE); IDX_TYPE o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE); ACC_T factor = 1.0 / fmaxf(1.0, n + o_n); var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; avg = (n * avg + o_n * o_avg) * factor; n += o_n; } // this writes each warp's final sum result into shared memory // there are at most (thread_number_of_a_block / WARP_SIZE) results __syncthreads(); if (tid % WARP_SIZE == 0) { shared_n[tid / WARP_SIZE] = n; shared_avg_var[tid / WARP_SIZE * 2] = avg; shared_avg_var[tid / WARP_SIZE * 2 + 1] = var_n; } __syncthreads(); // now have a second warpSum to reduce the intermediate values // from shared memory to a single number. The very first // thread writes it to shared memory. if (tid < WARP_SIZE) { // initialize n, avg and var_n of each thread within the first warp n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_n[tid] : 0); avg = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid] : ACC_T(0)); var_n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? 
shared_avg_var[2 * tid + 1] : ACC_T(0)); for (IDX_TYPE i = 0; i < getMSB(WARP_SIZE); ++i) { ACC_T o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE); IDX_TYPE o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE); ACC_T factor = 1.0 / fmaxf(1.0, n + o_n); var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; avg = (n * avg + o_n * o_avg) * factor; n += o_n; } } // save the mean and inverse standard deviation if (tid == 0) { mean_ptr[channel_idx] = avg; invstd_ptr[channel_idx] = inv_std(var_n / N, eps); } } template struct BatchNormStatsFunctor final { void operator()(ep::Stream* stream, const user_op::Tensor* input, user_op::Tensor* mean, user_op::Tensor* invstd, const float eps) { using ACC_T = acc_type; const ShapeView& input_shape = input->shape_view(); const int64_t input_numel = input_shape.elem_cnt(); const int64_t spatial_size = input_shape.Count(2); dim3 blocks(input_shape.At(1)); int32_t tf = getNumThreads(spatial_size); dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE / tf)); const T* input_ptr = input->dptr(); T* mean_ptr = mean->mut_dptr(); T* invstd_ptr = invstd->mut_dptr(); if (input_numel < std::numeric_limits::max()) { batch_norm_collect_statistics_kernel <<As()->cuda_stream()>>>( input_ptr, static_cast(input_shape.At(0)), static_cast(input_shape.At(1)), static_cast(spatial_size), eps, mean_ptr, invstd_ptr); } else { batch_norm_collect_statistics_kernel <<As()->cuda_stream()>>>( input_ptr, input_shape.At(0), input_shape.At(1), spatial_size, eps, mean_ptr, invstd_ptr); } } }; template struct BatchNormStatsChannelLastFunctor final { void operator()(ep::Stream* stream, const user_op::Tensor* input, user_op::Tensor* mean, user_op::Tensor* invstd, user_op::Tensor* tmp_buffer, const float eps, const int32_t axis) { using ACC_T = acc_type; const ShapeView& input_shape = input->shape_view(); const int64_t stride = input_shape.At(axis); const int64_t reduction_size = input_shape.elem_cnt() / stride; dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid, true); T* staging_data_ptr = nullptr; int32_t* semaphores_ptr = nullptr; if (grid.y > 1) { staging_data_ptr = tmp_buffer->mut_dptr(); semaphores_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + 4 * stride * grid.y * sizeof(T)); } const T* input_ptr = input->dptr(); T* mean_ptr = mean->mut_dptr(); T* invstd_ptr = invstd->mut_dptr(); if (input_shape.elem_cnt() < std::numeric_limits::max()) { batch_norm_collect_statistics_channels_last_kernel <<As()->cuda_stream()>>>( input_ptr, mean_ptr, invstd_ptr, staging_data_ptr, semaphores_ptr, static_cast(reduction_size), static_cast(stride), eps); } else { batch_norm_collect_statistics_channels_last_kernel <<As()->cuda_stream()>>>( input_ptr, mean_ptr, invstd_ptr, staging_data_ptr, semaphores_ptr, reduction_size, stride, eps); } } }; } // namespace template class GpuBatchNormStatsKernel final : public user_op::OpKernel { public: GpuBatchNormStatsKernel() = default; ~GpuBatchNormStatsKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex("invstd", 0); const int32_t axis = ctx->Attr("axis"); const float eps = ctx->Attr("eps"); bool use_channels_last_kernel = axis == 1 ? 
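// axis == 1 corresponds to channels-first (NCHW); any other axis takes the channels-last
// path. The channels-last functor above may launch a two-stage reduction (grid.y > 1), in
// which case tmp_buffer holds the staging area for the per-block mean/m2n/count partials
// plus the per-channel semaphores used to elect the last block; the registered
// InferTmpSizeForChannelLastKernel presumably sizes the buffer to match that layout.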
false : true; if (use_channels_last_kernel) { // NHWC format user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); BatchNormStatsChannelLastFunctor()(ctx->stream(), input, mean, invstd, tmp_buffer, eps, axis); } else { // NCHW format BatchNormStatsFunctor()(ctx->stream(), input, mean, invstd, eps); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BATCH_NORM_STATS_KERNEL(dtype) \ REGISTER_USER_KERNEL("batch_norm_stats") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferTmpSizeForChannelLastKernel) REGISTER_BATCH_NORM_STATS_KERNEL(float); REGISTER_BATCH_NORM_STATS_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/bernoulli_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/user/kernels/random_mask_generator.h" namespace oneflow { template class BernoulliKerenl final : public user_op::OpKernel { public: BernoulliKerenl() = default; ~BernoulliKerenl() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCPU)); generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); const T* in_dptr = in_blob->dptr(); K* out_dptr = out_blob->mut_dptr(); CHECK_EQ(GetDataType(), in_blob->data_type()); CHECK_EQ(GetDataType(), out_blob->data_type()); CHECK_EQ(in_blob->shape_view().elem_cnt(), out_blob->shape_view().elem_cnt()); auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const auto& generator = kernel_state->generator(); CHECK_NOTNULL(generator); const auto& cpu_generator = CHECK_JUST(generator->Get()); double p = ctx->Attr("p"); // prob != -1 means use prob instead of tensor to generate random number if (p != static_cast(-1.0)) { for (int32_t i = 0; i < out_blob->shape_view().elem_cnt(); ++i) { std::bernoulli_distribution dis(p); *(out_dptr + i) = dis(cpu_generator->engine()) ? GetOneVal() : GetZeroVal(); } } else { for (int32_t i = 0; i < out_blob->shape_view().elem_cnt(); ++i) { double prob = static_cast(*(in_dptr + i)); CHECK(prob >= 0.0 && prob <= 1.0); std::bernoulli_distribution dis(prob); *(out_dptr + i) = dis(cpu_generator->engine()) ? 
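// std::bernoulli_distribution(prob) yields true with probability prob, so in this branch
// each output element is 1 with probability in[i] and 0 otherwise (E[out[i]] = in[i]);
// the shared CPU generator's engine is advanced for every sample drawn.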
GetOneVal() : GetZeroVal(); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BERNOULLI_KERNEL(in_dtype_pair, out_dtype_pair) \ REGISTER_USER_KERNEL("bernoulli") \ .SetCreateFn< \ BernoulliKerenl>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == OF_PP_PAIR_SECOND(in_dtype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BERNOULLI_KERNEL, FLOATING_DATA_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/bias_add_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { namespace { template std::unique_ptr NewPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("a", 0)->data_type(); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kAdd, data_type, data_type, 3); } class BiasAddUserKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: BiasAddUserKernel() = default; ~BiasAddUserKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* a_tensor = ctx->Tensor4ArgNameAndIndex("a", 0); const auto* b_tensor = ctx->Tensor4ArgNameAndIndex("b", 0); if (a_tensor->shape_view().elem_cnt() == 0 || b_tensor->shape_view().elem_cnt() == 0) { return; } auto* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t bias_add_axis = ctx->Attr("axis"); const int64_t outer_size = a_tensor->shape_view().Count(0, bias_add_axis); const int64_t bias_size = a_tensor->shape_view().At(bias_add_axis); const int64_t inner_size = a_tensor->shape_view().Count(bias_add_axis + 1); auto primitive = NewPrimitive(ctx); const int64_t src0_dims[3] = {outer_size, bias_size, inner_size}; const int64_t src1_dims[3] = {1, bias_size, 1}; primitive->Launch(ctx->stream(), 3, src0_dims, a_tensor->dptr(), 3, src1_dims, b_tensor->dptr(), out_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto PrimitiveExists() { return hob::make_custom("PrimitiveExists", [](const user_op::KernelRegContext& ctx) -> bool { return NewPrimitive(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("bias_add") .SetCreateFn() .SetIsMatchedHob(PrimitiveExists() == true) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "a", 0, true)); return Maybe::Ok(); }); } // namespace } // namespace oneflow ================================================ FILE: 
oneflow/user/kernels/binary_concat_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template __global__ void BinaryConcatKernel(const IDX out_elems, const IDX out_cols, const IDX in0_cols, const IDX in1_cols, const T* src0, const T* src1, T* dst) { CUDA_1D_KERNEL_LOOP_T(IDX, i, out_elems) { const IDX row = i / out_cols; const IDX col = i - row * out_cols; const T* src_ptr = nullptr; if (col < in0_cols) { src_ptr = src0 + row * in0_cols + col; } else { src_ptr = src1 + row * in1_cols + (col - in0_cols); } dst[i] = *src_ptr; } } template void LaunchBinaryConcatKernel(ep::Stream* stream, const IDX rows, const IDX in0_cols, const IDX in1_cols, const void* src0, const void* src1, void* dst) { const IDX out_cols = in0_cols + in1_cols; const IDX out_elems = rows * out_cols; RUN_CUDA_KERNEL((BinaryConcatKernel), stream, out_elems, out_elems, out_cols, in0_cols, in1_cols, reinterpret_cast(src0), reinterpret_cast(src1), reinterpret_cast(dst)); } template void DispatchIndexType(ep::Stream* stream, const int64_t rows, const int64_t in0_cols, const int64_t in1_cols, const void* src0, const void* src1, void* dst) { if (rows * (in0_cols + in1_cols) >= (1 >> 30)) { LaunchBinaryConcatKernel(stream, rows, in0_cols, in1_cols, src0, src1, dst); } else { LaunchBinaryConcatKernel(stream, rows, in0_cols, in1_cols, src0, src1, dst); } } void DispatchDataType(ep::Stream* stream, const int64_t rows, const int64_t in0_cols, const int64_t in1_cols, const void* src0, const void* src1, void* dst) { const uintptr_t src0_ptr = reinterpret_cast(src0); const uintptr_t src1_ptr = reinterpret_cast(src1); const uintptr_t dst_ptr = reinterpret_cast(dst); const auto IsAligned = [&](const size_t alignment) { return src0_ptr % alignment == 0 && src1_ptr % alignment == 0 && dst_ptr % alignment == 0 && in0_cols % alignment == 0 && in1_cols % alignment == 0; }; if (IsAligned(16)) { DispatchIndexType(stream, rows, in0_cols / 16, in1_cols / 16, src0, src1, dst); } else if (IsAligned(8)) { DispatchIndexType(stream, rows, in0_cols / 8, in1_cols / 8, src0, src1, dst); } else if (IsAligned(4)) { DispatchIndexType(stream, rows, in0_cols / 4, in1_cols / 4, src0, src1, dst); } else if (IsAligned(2)) { DispatchIndexType(stream, rows, in0_cols / 2, in1_cols / 2, src0, src1, dst); } else { DispatchIndexType(stream, rows, in0_cols, in1_cols, src0, src1, dst); } } void DispatchBinaryConcat(ep::Stream* stream, const int64_t elem_size, const int64_t rows, const int64_t in0_cols, const int64_t in1_cols, const void* src0, const void* src1, void* dst) { DispatchDataType(stream, rows, in0_cols * elem_size, in1_cols * elem_size, src0, src1, dst); } class ConcatKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: 
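// ConcatKernel below reduces a two-input concat to row-wise copies via DispatchBinaryConcat
// above: column widths are first converted to bytes, DispatchDataType then re-expresses them
// in the widest move type for which all pointers and widths stay aligned (16-, 8-, 4-, 2- or
// 1-byte packs; the exact element types are stripped from this excerpt), and
// DispatchIndexType chooses the index width. Note the guard there reads `(1 >> 30)`, which
// evaluates to 0, so its first branch is always taken; the intended threshold is presumably
// `(1 << 30)` elements, switching to wider indexing only for large outputs.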
ConcatKernel() = default; ~ConcatKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); const DataType data_type = out_tensor->data_type(); if (out_tensor->shape_view().elem_cnt() == 0) { return; } const int64_t axis = ctx->Attr("axis"); CHECK_GE(axis, 0); const int64_t num_axes = out_tensor->shape_view().NumAxes(); CHECK_LT(axis, num_axes); const int64_t out_cols = out_tensor->shape_view().Count(axis); const int64_t rows = out_tensor->shape_view().elem_cnt() / out_cols; CHECK_GT(rows, 0); CHECK_EQ(ctx->input_size("in"), 2); const user_op::Tensor* in0_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* in1_tensor = ctx->Tensor4ArgNameAndIndex("in", 1); CHECK_EQ(in0_tensor->data_type(), data_type); CHECK_EQ(in1_tensor->data_type(), data_type); if (in0_tensor->shape_view().elem_cnt() == 0) { CHECK_EQ(in1_tensor->shape_view(), out_tensor->shape_view()); Memcpy(ctx->stream(), out_tensor->mut_dptr(), in1_tensor->dptr(), out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(data_type)); return; } if (in1_tensor->shape_view().elem_cnt() == 0) { CHECK_EQ(in0_tensor->shape_view(), out_tensor->shape_view()); Memcpy(ctx->stream(), out_tensor->mut_dptr(), in0_tensor->dptr(), out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(data_type)); return; } CHECK_EQ(in0_tensor->shape_view().NumAxes(), num_axes); CHECK_EQ(in1_tensor->shape_view().NumAxes(), num_axes); for (int64_t i = 0; i < num_axes; ++i) { if (i != axis) { CHECK_EQ(in0_tensor->shape_view().At(i), out_tensor->shape_view().At(i)); CHECK_EQ(in1_tensor->shape_view().At(i), out_tensor->shape_view().At(i)); } } CHECK_EQ(in0_tensor->shape_view().At(axis) + in1_tensor->shape_view().At(axis), out_tensor->shape_view().At(axis)); const int64_t in0_cols = in0_tensor->shape_view().Count(axis); const int64_t in1_cols = in1_tensor->shape_view().Count(axis); DispatchBinaryConcat(ctx->stream(), GetSizeOfDataType(data_type), rows, in0_cols, in1_cols, in0_tensor->dptr(), in1_tensor->dptr(), out_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace REGISTER_USER_KERNEL("cat") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobInputSize("in") == 2)) .SetPriority(user_op::kKernelPriorityOptimized); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/binary_cross_entropy_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/loss_kernel_util.h" namespace oneflow { namespace user_op { namespace { using namespace loss; template void ComputeBinaryCrossEntropyOut(int64_t elem_cnt, const T* input, const T* target, T* out, const T* weight) { T negative_100 = static_cast(-100); FOR_RANGE(int64_t, i, 0, elem_cnt) { T input_val = input[i]; T target_val = target[i]; CHECK_LE(input_val, 1.0); CHECK_GE(input_val, 0.0); out[i] = (target_val - 1) * std::max(static_cast(std::log(1.0 - input_val)), negative_100) - target_val * std::max(static_cast(std::log(input_val)), negative_100); if (weight != nullptr) { out[i] *= weight[i]; } } } template void ComputeBinaryCrossEntropyGradOut(int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx, const T* weight) { const T eps = static_cast(1e-12); FOR_RANGE(int64_t, i, 0, elem_cnt) { T input_val = input[i]; T target_val = target[i]; T dy_val = dy[i]; dx[i] = dy_val * (input_val - target_val) / (std::max((static_cast(1.0) - input_val) * input_val, eps)); if (weight != nullptr) { dx[i] *= weight[i]; } } } template class BinaryCrossEntropyKernel final : public user_op::OpKernel { public: BinaryCrossEntropyKernel() = default; ~BinaryCrossEntropyKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const T* input = input_blob->dptr(); const T* target = target_blob->dptr(); T* out = out_blob->mut_dptr(); const T* weight = ctx->has_input("weight", 0) ? ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr() : nullptr; ComputeBinaryCrossEntropyOut(elem_cnt, input, target, out, weight); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class BinaryCrossEntropyGradKernel final : public user_op::OpKernel { public: BinaryCrossEntropyGradKernel() = default; ~BinaryCrossEntropyGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const T* dy = dy_blob->dptr(); const T* input = input_blob->dptr(); const T* target = target_blob->dptr(); T* dx = dx_blob->mut_dptr(); const T* weight = ctx->has_input("weight", 0) ? 
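// Gradient of the binary cross entropy above: with L = -(t * log(x) + (1 - t) * log(1 - x)),
//   dL/dx = -t / x + (1 - t) / (1 - x) = (x - t) / (x * (1 - x)),
// which is what ComputeBinaryCrossEntropyGradOut evaluates, clamping the denominator to at
// least eps = 1e-12 and scaling by dy (and by weight when one is provided).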
ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr() : nullptr; ComputeBinaryCrossEntropyGradOut(elem_cnt, input, target, dy, dx, weight); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_BINARY_CROSS_ENTROPY_KERNEL(dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); #define REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_BINARY_CROSS_ENTROPY_KERNEL(float) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(double) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/binary_cross_entropy_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/loss_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { using namespace loss; template struct BinaryCrossEntropyFunctor { T zero_; T one_; T negative_hundred_; BinaryCrossEntropyFunctor() : zero_(GetZeroVal()), one_(GetOneVal()), negative_hundred_(static_cast(-100)) {} __device__ __forceinline__ T operator()(T input_val, T target_val) const { assert(input_val >= zero_); assert(input_val <= one_); return (target_val - one_) * max(static_cast(log(one_ - input_val)), negative_hundred_) - target_val * max(static_cast(log(input_val)), negative_hundred_); } __device__ __forceinline__ T operator()(T input_val, T target_val, T weight_val) const { return (*this)(input_val, target_val) * weight_val; } }; template<> struct BinaryCrossEntropyFunctor { float zero_; float one_; float negative_hundred_; BinaryCrossEntropyFunctor() : zero_(0.f), one_(1.f), negative_hundred_(-100.f) {} __device__ __forceinline__ float operator()(float input_val, float target_val) const { assert(input_val >= zero_); assert(input_val <= one_); return (target_val - one_) * max(logf(one_ - input_val), negative_hundred_) - target_val * max(logf(input_val), negative_hundred_); } __device__ __forceinline__ float operator()(float input_val, float target_val, float weight_val) const { return (*this)(input_val, target_val) * weight_val; } }; template<> struct BinaryCrossEntropyFunctor { BinaryCrossEntropyFunctor float_functor; __device__ __forceinline__ half operator()(half input_val, half target_val) const { return __float2half(float_functor(__half2float(input_val), __half2float(target_val))); } __device__ __forceinline__ half operator()(half input_val, half target_val, half weight_val) const { return (*this)(input_val, target_val) * weight_val; } }; template struct BinaryCrossEntropyGradFunctor { T eps_; T one_; BinaryCrossEntropyGradFunctor() : eps_(static_cast(1e-12)), one_(GetOneVal()) {} __device__ __forceinline__ T operator()(T input_val, T target_val, T dy_val) const { return dy_val * (input_val - target_val) / max((one_ - input_val) * input_val, eps_); } __device__ __forceinline__ T operator()(T input_val, T target_val, T dy_val, T weight_val) const { return (*this)(input_val, target_val, dy_val) * weight_val; } }; template<> struct BinaryCrossEntropyGradFunctor { BinaryCrossEntropyGradFunctor float_functor; BinaryCrossEntropyGradFunctor() {} __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const { return __float2half( float_functor(__half2float(input_val), __half2float(target_val), __half2float(dy_val))); } __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val, half weight_val) const { return __float2half(float_functor(__half2float(input_val), __half2float(target_val), __half2float(dy_val), __half2float(weight_val))); } }; template class BinaryCrossEntropyKernel final : public user_op::OpKernel { public: BinaryCrossEntropyKernel() = default; ~BinaryCrossEntropyKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const T* input = 
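// BinaryCrossEntropyFunctor above evaluates
//   loss = (t - 1) * max(log(1 - x), -100) - t * max(log(x), -100),
// i.e. -(t * log(x) + (1 - t) * log(1 - x)) with each log clamped at -100 so that inputs of
// exactly 0 or 1 produce a large but finite loss instead of inf, matching the CPU kernel's
// convention.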
input_blob->dptr(); const T* target = target_blob->dptr(); T* out = out_blob->mut_dptr(); if (ctx->has_input("weight", 0)) { const T* weight = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); OF_CUDA_CHECK( (cuda::elementwise::Ternary(BinaryCrossEntropyFunctor(), elem_cnt, out, input, target, weight, ctx->stream()->As()->cuda_stream()))); } else { OF_CUDA_CHECK( (cuda::elementwise::Binary(BinaryCrossEntropyFunctor(), elem_cnt, out, input, target, ctx->stream()->As()->cuda_stream()))); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class BinaryCrossEntropyGradKernel final : public user_op::OpKernel { public: BinaryCrossEntropyGradKernel() = default; ~BinaryCrossEntropyGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const T* dy = dy_blob->dptr(); const T* input = input_blob->dptr(); const T* target = target_blob->dptr(); T* dx = dx_blob->mut_dptr(); if (ctx->has_input("weight", 0)) { const T* weight = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); using FunctorT = BinaryCrossEntropyGradFunctor; using FactoryT = cuda::elementwise::SimpleFactory; OF_CUDA_CHECK((cuda::elementwise::GenericLauncher::Launch( FactoryT(FunctorT()), elem_cnt, dx, input, target, dy, weight, ctx->stream()->As()->cuda_stream()))); } else { OF_CUDA_CHECK((cuda::elementwise::Ternary( BinaryCrossEntropyGradFunctor(), elem_cnt, dx, input, target, dy, ctx->stream()->As()->cuda_stream()))); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_BINARY_CROSS_ENTROPY_KERNEL(dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); #define REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_BINARY_CROSS_ENTROPY_KERNEL(half) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(float) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(double) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(half) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/binary_cross_entropy_with_logits_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/user/kernels/loss_kernel_util.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { namespace user_op { namespace { using namespace loss; template inline T ComputeMaxVal(const T x) { T y = -x; return y < 0 ? 0 : y; } template inline T CalSigmoid(const T x) { const T half_of_one = static_cast(0.5); return half_of_one * std::tanh(half_of_one * x) + half_of_one; } template void ComputeBinaryCrossEntropyWithLogitsOut(int64_t elem_cnt, const INPUT_T* input, const TARGET_T* target, TARGET_T* out, const TARGET_T* weight, const TARGET_T* pos_weight_processed) { FOR_RANGE(int64_t, i, 0, elem_cnt) { TARGET_T input_val = static_cast(input[i]); TARGET_T target_val = target[i]; TARGET_T max_val = ComputeMaxVal(input_val); if (out != nullptr) { if (pos_weight_processed == nullptr) { out[i] = (1 - target_val) * input_val + max_val + (std::log(std::exp(-max_val) + std::exp(-input_val - max_val))); } else { TARGET_T pos_weight_processed_val = pos_weight_processed[i] - target_val + 1; out[i] = (1 - target_val) * input_val + (pos_weight_processed_val * (std::log(std::exp(-max_val) + std::exp(-input_val - max_val)) + max_val)); } } if (weight != nullptr && out != nullptr) { out[i] *= weight[i]; } } } template void ComputeBinaryCrossEntropyWithLogitsGradOut(int64_t elem_cnt, const INPUT_T* input, const TARGET_T* target, const TARGET_T* dy, INPUT_T* dx, const TARGET_T* weight, const TARGET_T* pos_weight_processed) { FOR_RANGE(int64_t, i, 0, elem_cnt) { INPUT_T input_val = input[i]; TARGET_T target_val = target[i]; TARGET_T dy_val = dy[i]; TARGET_T input_sigmoid = static_cast(CalSigmoid(input_val)); TARGET_T dx_i_buffer = 0.0; if (pos_weight_processed == nullptr) { dx_i_buffer = (input_sigmoid - target_val) * dy_val; } else { dx_i_buffer = dy_val * ((pos_weight_processed[i] + 1 - target_val) * input_sigmoid - pos_weight_processed[i]); } if (weight != nullptr) { dx_i_buffer *= weight[i]; } dx[i] = static_cast(dx_i_buffer); } } template class BinaryCrossEntropyWithLogitsKernel final : public user_op::OpKernel { public: BinaryCrossEntropyWithLogitsKernel() = default; ~BinaryCrossEntropyWithLogitsKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); auto* tmp_buffer_blob = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); TARGET_T* out = out_blob->mut_dptr(); const TARGET_T* weight = ctx->has_input("weight", 0) ? 
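// ComputeBinaryCrossEntropyWithLogitsOut above evaluates the numerically stable form of
//   loss = (1 - t) * x + log(1 + exp(-x))
// by factoring out max_val = max(-x, 0):
//   log(1 + exp(-x)) = max_val + log(exp(-max_val) + exp(-x - max_val)),
// so neither exponential can overflow. When pos_weight is present, the log term is scaled by
// (pw * t - t + 1), where pw * t is the broadcast product prepared in pos_weight_processed
// below.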
ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr() : nullptr; TARGET_T* pos_weight_processed = nullptr; if (ctx->Attr("has_pos_weight")) { pos_weight_processed = tmp_buffer_blob->mut_dptr(); const TARGET_T* pos_weight = ctx->Tensor4ArgNameAndIndex("pos_weight", 0)->dptr(); Shape pos_weight_shape = Shape::Ones(target_blob->shape_view().NumAxes()); pos_weight_shape.Set(pos_weight_shape.NumAxes() - 1, ctx->Tensor4ArgNameAndIndex("pos_weight", 0)->shape_view().elem_cnt()); auto bcast_mul = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kMul, target_blob->data_type(), target_blob->data_type(), target_blob->shape_view().NumAxes()); CHECK(bcast_mul); bcast_mul->Launch(ctx->stream(), target_blob->shape_view().NumAxes(), target_blob->shape_view().ptr(), target, pos_weight_shape.NumAxes(), pos_weight_shape.dim_vec().data(), pos_weight, pos_weight_processed); } ComputeBinaryCrossEntropyWithLogitsOut(elem_cnt, input, target, out, weight, pos_weight_processed); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class BinaryCrossEntropyWithLogitsGradKernel final : public user_op::OpKernel { public: BinaryCrossEntropyWithLogitsGradKernel() = default; ~BinaryCrossEntropyWithLogitsGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); auto* tmp_buffer_blob = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const TARGET_T* dy = dy_blob->dptr(); const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); INPUT_T* dx = dx_blob->mut_dptr(); const TARGET_T* weight = ctx->has_input("weight", 0) ? 
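// Gradient of the with-logits loss above: for L = (1 - t) * x + c * log(1 + exp(-x)) with
// c = pw * t - t + 1,
//   dL/dx = (1 - t) + c * (sigmoid(x) - 1) = (pw * t + 1 - t) * sigmoid(x) - pw * t,
// which reduces to sigmoid(x) - t when no pos_weight is given. This matches the dx_i_buffer
// expressions in ComputeBinaryCrossEntropyWithLogitsGradOut above, with CalSigmoid using the
// identity sigmoid(x) = 0.5 * tanh(x / 2) + 0.5.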
ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr() : nullptr; TARGET_T* pos_weight_processed = nullptr; if (ctx->Attr("has_pos_weight")) { pos_weight_processed = tmp_buffer_blob->mut_dptr(); const TARGET_T* pos_weight = ctx->Tensor4ArgNameAndIndex("pos_weight", 0)->dptr(); Shape pos_weight_shape = Shape::Ones(target_blob->shape_view().NumAxes()); pos_weight_shape.Set(pos_weight_shape.NumAxes() - 1, ctx->Tensor4ArgNameAndIndex("pos_weight", 0)->shape_view().elem_cnt()); auto bcast_mul = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kMul, target_blob->data_type(), target_blob->data_type(), target_blob->shape_view().NumAxes()); CHECK(bcast_mul); bcast_mul->Launch(ctx->stream(), target_blob->shape_view().NumAxes(), target_blob->shape_view().ptr(), target, pos_weight_shape.NumAxes(), pos_weight_shape.dim_vec().data(), pos_weight, pos_weight_processed); } ComputeBinaryCrossEntropyWithLogitsGradOut(elem_cnt, input, target, dy, dx, weight, pos_weight_processed); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenFwInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const int64_t n = ctx->InputShape("target", 0).elem_cnt(); size_t tmp_buffer_size = 0; if (ctx->Attr("has_pos_weight")) { tmp_buffer_size += GetCudaAlignedSize(n * sizeof(T)); } return tmp_buffer_size; }; } template user_op::InferTmpSizeFn GenBwInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const int64_t n = ctx->InputShape("target", 0).elem_cnt(); size_t tmp_buffer_size = 0; if (ctx->Attr("has_pos_weight")) { tmp_buffer_size += GetCudaAlignedSize(n * sizeof(T)); } return tmp_buffer_size; }; } } // namespace #define REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(input_dtype, target_dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_with_logits") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(GenFwInferTmpSizeFn()); #define REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(input_dtype, target_dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_with_logits_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(GenBwInferTmpSizeFn()); REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(float, float) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(float, double) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(double, float) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(double, double) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(float, float) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(float, double) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(double, float) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(double, double) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/binary_cross_entropy_with_logits_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/user/kernels/loss_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { namespace user_op { namespace { using namespace loss; enum class WeightType { kNone, kWeight, kPosWeight, kBoth, }; template struct BinaryCrossEntropyWithLogitsFunctor; template struct BinaryCrossEntropyWithLogitsFunctor { TARGET_T zero_; TARGET_T one_; BinaryCrossEntropyWithLogitsFunctor() : zero_(GetZeroVal()), one_(GetOneVal()) {} __device__ __forceinline__ TARGET_T operator()(INPUT_T input_val, TARGET_T target_val) const { const TARGET_T input_val_ = static_cast(input_val); const TARGET_T max_val = -input_val_ < zero_ ? zero_ : -input_val_; return (one_ - target_val) * input_val_ + max_val + (log(exp(-max_val) + exp(-input_val_ - max_val))); } }; template struct BinaryCrossEntropyWithLogitsFunctor { TARGET_T zero_; TARGET_T one_; BinaryCrossEntropyWithLogitsFunctor() : zero_(GetZeroVal()), one_(GetOneVal()) {} __device__ __forceinline__ TARGET_T operator()(INPUT_T input_val, TARGET_T target_val, TARGET_T weight_val) const { const TARGET_T input_val_ = static_cast(input_val); const TARGET_T max_val = -input_val_ < zero_ ? zero_ : -input_val_; const TARGET_T pos_weight_processed_val = weight_val - target_val + one_; return (one_ - target_val) * input_val_ + (pos_weight_processed_val * (log(exp(-max_val) + exp(-input_val_ - max_val)) + max_val)); } }; template struct BinaryCrossEntropyWithLogitsFunctor { float zero_; float one_; BinaryCrossEntropyWithLogitsFunctor() : zero_(0.f), one_(1.f) {} __device__ __forceinline__ float operator()(INPUT_T input_val, float target_val) const { const float input_val_ = static_cast(input_val); const float max_val = -input_val_ < zero_ ? zero_ : -input_val_; return (one_ - target_val) * input_val_ + max_val + (logf(expf(-max_val) + expf(-input_val_ - max_val))); } }; template struct BinaryCrossEntropyWithLogitsFunctor { float zero_; float one_; BinaryCrossEntropyWithLogitsFunctor() : zero_(0.f), one_(1.f) {} __device__ __forceinline__ float operator()(INPUT_T input_val, float target_val, float weight_val) const { const float input_val_ = static_cast(input_val); const float max_val = -input_val_ < zero_ ? 
zero_ : -input_val_; const float pos_weight_processed_val = weight_val - target_val + one_; return (one_ - target_val) * input_val_ + (pos_weight_processed_val * (logf(expf(-max_val) + expf(-input_val_ - max_val)) + max_val)); } }; template struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ TARGET_T operator()(INPUT_T input_val, TARGET_T target_val, TARGET_T weight_val) const { return f(input_val, target_val) * weight_val; } }; template struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ TARGET_T operator()(INPUT_T input_val, TARGET_T target_val, TARGET_T weight_val, TARGET_T pos_weight_val) const { return f(input_val, target_val, pos_weight_val) * weight_val; } }; template struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ half operator()(INPUT_T input_val, half target_val) const { return __float2half(f(input_val, __half2float(target_val))); } }; template struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ half operator()(INPUT_T input_val, half target_val, half weight_val) const { return __float2half(f(input_val, __half2float(target_val), __half2float(weight_val))); } }; template struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ half operator()(INPUT_T input_val, half target_val, half weight_val) const { return __float2half(f(input_val, __half2float(target_val), __half2float(weight_val))); } }; template struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ half operator()(INPUT_T input_val, half target_val, half weight_val, half pos_weight_val) const { return __float2half(f(input_val, __half2float(target_val), __half2float(weight_val), __half2float(pos_weight_val))); } }; template<> struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ half operator()(half input_val, half target_val) const { return __float2half(f(__half2float(input_val), __half2float(target_val))); } }; template<> struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ half operator()(half input_val, half target_val, half weight_val) const { return __float2half( f(__half2float(input_val), __half2float(target_val), __half2float(weight_val))); } }; template<> struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ half operator()(half input_val, half target_val, half weight_val) const { return __float2half( f(__half2float(input_val), __half2float(target_val), __half2float(weight_val))); } }; template<> struct BinaryCrossEntropyWithLogitsFunctor { BinaryCrossEntropyWithLogitsFunctor f; __device__ __forceinline__ half operator()(half input_val, half target_val, half weight_val, half pos_weight_val) const { return __float2half(f(__half2float(input_val), __half2float(target_val), __half2float(weight_val), __half2float(pos_weight_val))); } }; template __device__ __forceinline__ T CalSigmoid(const T x) { const T half_of_one = static_cast(0.5); return half_of_one * tanh(half_of_one * x) + half_of_one; } template<> __device__ __forceinline__ float CalSigmoid(const float x) { const float half_of_one = static_cast(0.5); return half_of_one * tanhf(half_of_one * x) + half_of_one; } template<> __device__ __forceinline__ 
half CalSigmoid(const half x) { return __float2half(CalSigmoid(__half2float(x))); } template struct BinaryCrossEntropyWithLogitsGradFunctor; template struct BinaryCrossEntropyWithLogitsGradFunctor { __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, TARGET_T target_val, TARGET_T dy_val) const { return (CalSigmoid(input_val) - static_cast(target_val)) * static_cast(dy_val); } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { INPUT_T one_; BinaryCrossEntropyWithLogitsGradFunctor() : one_(GetOneVal()) {} __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, TARGET_T target_val, TARGET_T dy_val, TARGET_T weight_val) const { TARGET_T dx_tmp = dy_val * ((weight_val + one_ - target_val) * static_cast(CalSigmoid(input_val)) - weight_val); return static_cast(dx_tmp); } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { BinaryCrossEntropyWithLogitsGradFunctor f; __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, TARGET_T target_val, TARGET_T dy_val, TARGET_T weight_val) const { return f(input_val, target_val, dy_val) * static_cast(weight_val); } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { BinaryCrossEntropyWithLogitsGradFunctor f; __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, TARGET_T target_val, TARGET_T dy_val, TARGET_T weight_val, TARGET_T pos_weight_val) const { return f(input_val, target_val, dy_val, pos_weight_val) * static_cast(weight_val); } }; template<> struct BinaryCrossEntropyWithLogitsGradFunctor { __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const { return (CalSigmoid(input_val) - target_val) * dy_val; } }; template<> struct BinaryCrossEntropyWithLogitsGradFunctor { half one_; BinaryCrossEntropyWithLogitsGradFunctor() : one_(GetOneVal()) {} __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val, half weight_val) const { return dy_val * ((weight_val + one_ - target_val) * CalSigmoid(input_val) - weight_val); } }; template<> struct BinaryCrossEntropyWithLogitsGradFunctor { BinaryCrossEntropyWithLogitsGradFunctor f; __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val, half weight_val) const { return f(input_val, target_val, dy_val) * weight_val; } }; template<> struct BinaryCrossEntropyWithLogitsGradFunctor { BinaryCrossEntropyWithLogitsGradFunctor f; __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val, half weight_val, half pos_weight_val) const { return f(input_val, target_val, dy_val, pos_weight_val) * weight_val; } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, half target_val, half dy_val) const { return (CalSigmoid(input_val) - static_cast(__half2float(target_val))) * static_cast(__half2float(dy_val)); } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { INPUT_T one_; BinaryCrossEntropyWithLogitsGradFunctor() : one_(GetOneVal()) {} __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, half target_val, half dy_val, half weight_val) const { const INPUT_T dy_val_f = static_cast(__half2float(dy_val)); const INPUT_T target_val_f = static_cast(__half2float(target_val)); const INPUT_T weight_val_f = static_cast(__half2float(weight_val)); return dy_val_f * ((weight_val_f + one_ - target_val_f) * CalSigmoid(input_val)) - weight_val_f; } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { BinaryCrossEntropyWithLogitsGradFunctor f; __device__ 
__forceinline__ INPUT_T operator()(INPUT_T input_val, half target_val, half dy_val, half weight_val) const { return f(input_val, target_val, dy_val) * static_cast(__half2float(weight_val)); } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { BinaryCrossEntropyWithLogitsGradFunctor f; __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, half target_val, half dy_val, half weight_val, half pos_weight_val) const { return f(input_val, target_val, dy_val, pos_weight_val) * static_cast(__half2float(weight_val)); } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { __device__ __forceinline__ half operator()(half input_val, TARGET_T target_val, TARGET_T dy_val) const { const half dy_val_h = __float2half(static_cast(dy_val)); const half target_val_h = __float2half(static_cast(target_val)); return (CalSigmoid(input_val) - target_val_h) * dy_val_h; } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { half one_; BinaryCrossEntropyWithLogitsGradFunctor() : one_(GetOneVal()) {} __device__ __forceinline__ half operator()(half input_val, TARGET_T target_val, TARGET_T dy_val, TARGET_T weight_val) const { const half dy_val_h = __float2half(static_cast(dy_val)); const half target_val_h = __float2half(static_cast(target_val)); const half weight_val_h = __float2half(static_cast(weight_val)); return dy_val_h * ((weight_val_h + one_ - target_val_h) * CalSigmoid(input_val) - weight_val_h); } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { BinaryCrossEntropyWithLogitsGradFunctor f; __device__ __forceinline__ half operator()(half input_val, TARGET_T target_val, TARGET_T dy_val, TARGET_T weight_val) const { return f(input_val, target_val, dy_val) * __float2half(static_cast(weight_val)); } }; template struct BinaryCrossEntropyWithLogitsGradFunctor { BinaryCrossEntropyWithLogitsGradFunctor f; __device__ __forceinline__ half operator()(half input_val, TARGET_T target_val, TARGET_T dy_val, TARGET_T weight_val, TARGET_T pos_weight_val) const { return f(input_val, target_val, dy_val, pos_weight_val) * __float2half(static_cast(weight_val)); } }; template class BinaryCrossEntropyWithLogitsKernel final : public user_op::OpKernel { public: BinaryCrossEntropyWithLogitsKernel() = default; ~BinaryCrossEntropyWithLogitsKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); auto* tmp_buffer_blob = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); TARGET_T* out = out_blob->mut_dptr(); if (ctx->Attr("has_pos_weight")) { TARGET_T* pos_weight_processed = tmp_buffer_blob->mut_dptr(); const TARGET_T* pos_weight = ctx->Tensor4ArgNameAndIndex("pos_weight", 0)->dptr(); Shape pos_weight_shape = Shape::Ones(target_blob->shape_view().NumAxes()); pos_weight_shape.Set(pos_weight_shape.NumAxes() - 1, ctx->Tensor4ArgNameAndIndex("pos_weight", 0)->shape_view().elem_cnt()); auto bcast_mul = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kMul, target_blob->data_type(), target_blob->data_type(), target_blob->shape_view().NumAxes()); CHECK(bcast_mul); bcast_mul->Launch(ctx->stream(), target_blob->shape_view().NumAxes(), target_blob->shape_view().ptr(), 
target, pos_weight_shape.NumAxes(), pos_weight_shape.dim_vec().data(), pos_weight, pos_weight_processed); if (ctx->has_input("weight", 0)) { const TARGET_T* weight = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); using FunctorT = BinaryCrossEntropyWithLogitsFunctor; using FactoryT = cuda::elementwise::SimpleFactory; OF_CUDA_CHECK( (cuda::elementwise:: GenericLauncher::Launch( FactoryT(FunctorT()), elem_cnt, out, input, target, weight, pos_weight_processed, ctx->stream()->As()->cuda_stream()))); } else { OF_CUDA_CHECK((cuda::elementwise::Ternary( BinaryCrossEntropyWithLogitsFunctor(), elem_cnt, out, input, target, pos_weight_processed, ctx->stream()->As()->cuda_stream()))); } } else { if (ctx->has_input("weight", 0)) { const TARGET_T* weight = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); OF_CUDA_CHECK((cuda::elementwise::Ternary( BinaryCrossEntropyWithLogitsFunctor(), elem_cnt, out, input, target, weight, ctx->stream()->As()->cuda_stream()))); } else { OF_CUDA_CHECK((cuda::elementwise::Binary( BinaryCrossEntropyWithLogitsFunctor(), elem_cnt, out, input, target, ctx->stream()->As()->cuda_stream()))); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class BinaryCrossEntropyWithLogitsGradKernel final : public user_op::OpKernel { public: BinaryCrossEntropyWithLogitsGradKernel() = default; ~BinaryCrossEntropyWithLogitsGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); auto* tmp_buffer_blob = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const TARGET_T* dy = dy_blob->dptr(); const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); INPUT_T* dx = dx_blob->mut_dptr(); if (ctx->Attr("has_pos_weight")) { TARGET_T* pos_weight_processed = tmp_buffer_blob->mut_dptr(); const TARGET_T* pos_weight = ctx->Tensor4ArgNameAndIndex("pos_weight", 0)->dptr(); Shape pos_weight_shape = Shape::Ones(target_blob->shape_view().NumAxes()); pos_weight_shape.Set(pos_weight_shape.NumAxes() - 1, ctx->Tensor4ArgNameAndIndex("pos_weight", 0)->shape_view().elem_cnt()); auto bcast_mul = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kMul, target_blob->data_type(), target_blob->data_type(), target_blob->shape_view().NumAxes()); CHECK(bcast_mul); bcast_mul->Launch(ctx->stream(), target_blob->shape_view().NumAxes(), target_blob->shape_view().ptr(), target, pos_weight_shape.NumAxes(), pos_weight_shape.dim_vec().data(), pos_weight, pos_weight_processed); if (ctx->has_input("weight", 0)) { const TARGET_T* weight = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); using FunctorT = BinaryCrossEntropyWithLogitsGradFunctor; using FactoryT = cuda::elementwise::SimpleFactory; OF_CUDA_CHECK((cuda::elementwise::GenericLauncher< FactoryT, INPUT_T, INPUT_T, TARGET_T, TARGET_T, TARGET_T, TARGET_T>::Launch(FactoryT(FunctorT()), elem_cnt, dx, input, target, dy, weight, pos_weight_processed, ctx->stream()->As()->cuda_stream()))); } else { using FunctorT = BinaryCrossEntropyWithLogitsGradFunctor; using FactoryT = cuda::elementwise::SimpleFactory; OF_CUDA_CHECK( (cuda::elementwise:: GenericLauncher::Launch( FactoryT(FunctorT()), elem_cnt, dx, input, target, 
dy, pos_weight_processed, ctx->stream()->As()->cuda_stream()))); } } else { if (ctx->has_input("weight", 0)) { const TARGET_T* weight = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); using FunctorT = BinaryCrossEntropyWithLogitsGradFunctor; using FactoryT = cuda::elementwise::SimpleFactory; OF_CUDA_CHECK( (cuda::elementwise:: GenericLauncher::Launch( FactoryT(FunctorT()), elem_cnt, dx, input, target, dy, weight, ctx->stream()->As()->cuda_stream()))); } else { OF_CUDA_CHECK((cuda::elementwise::Ternary( BinaryCrossEntropyWithLogitsGradFunctor(), elem_cnt, dx, input, target, dy, ctx->stream()->As()->cuda_stream()))); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenFwInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const int64_t n = ctx->InputShape("input", 0).elem_cnt(); size_t tmp_buffer_size = 0; if (ctx->Attr("has_pos_weight")) { tmp_buffer_size += GetCudaAlignedSize(n * sizeof(T)); } return tmp_buffer_size; }; } template user_op::InferTmpSizeFn GenBwInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const int64_t n = ctx->InputShape("target", 0).elem_cnt(); size_t tmp_buffer_size = 0; if (ctx->Attr("has_pos_weight")) { tmp_buffer_size += GetCudaAlignedSize(n * sizeof(T)); } return tmp_buffer_size; }; } } // namespace #define REGISTER_BINARY_CROSS_ENTROPY_KERNEL(input_dtype, target_dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_with_logits") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(GenFwInferTmpSizeFn()); #define REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(input_dtype, target_dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_with_logits_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(GenBwInferTmpSizeFn()); REGISTER_BINARY_CROSS_ENTROPY_KERNEL(half, half) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(half, float) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(float, half) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(half, double) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(double, half) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(float, float) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(float, double) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(double, float) REGISTER_BINARY_CROSS_ENTROPY_KERNEL(double, double) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(half, half) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(half, float) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float, half) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(half, double) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double, half) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float, float) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float, double) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double, float) REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double, double) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/binary_cross_entropy_with_logits_mean_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/binary_cross_entropy_with_logits_mean_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/elementwise.cuh" #include #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace user_op { namespace { constexpr int32_t kBlockSize = 1024; constexpr int32_t kReduceLocalSumBlockSize = 1024; constexpr int32_t kSingleBlockProcessNumThreshold = 1024; template struct DefaultComputeType { using type = T; }; template<> struct DefaultComputeType { using type = float; }; template inline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size, int64_t max_blocks, int64_t waves, int* num_blocks) { int dev; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } int max_active_blocks; { cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func, block_size, dynamic_smem_size); } *num_blocks = std::max(1, std::min(max_blocks, sm_count * max_active_blocks * waves)); return cudaSuccess; } template __device__ __forceinline__ T Sigmoid(const T x) { const T half_of_one = static_cast(0.5); return half_of_one * tanh(half_of_one * x) + half_of_one; } template<> __device__ __forceinline__ half Sigmoid(const half x) { return __float2half(Sigmoid(__half2float(x))); } template __global__ void FusedBinaryCrossEntropyWithLogitsReduceMeanKernel(const INPUT_T* input, const TARGET_T* target, OUTPUT_T* out, const int64_t local_elem_cnt, const int64_t reduce_elem_cnt) { ComputeType zero = static_cast(0.0); ComputeType one = static_cast(1.0); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; ComputeType reduce_sum = static_cast(0.0); CUDA_1D_KERNEL_LOOP(i, local_elem_cnt) { const ComputeType input_val = static_cast(input[i]); const ComputeType target_val = static_cast(target[i]); const ComputeType max_val = -input_val < zero ? 
zero : -input_val; const ComputeType result = (one - target_val) * input_val + max_val + (log(exp(-max_val) + exp(-input_val - max_val))); reduce_sum += result; } const ComputeType block_reduce_sum = BlockReduce(temp_storage).Sum(reduce_sum); if (threadIdx.x == 0) { out[blockIdx.x] = static_cast(block_reduce_sum / reduce_elem_cnt); } } template __global__ void ReduceLocalSumKernel(INPUT_T* block_local_sum_buf, TARGET_T* out, int64_t elem_cnt) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; INPUT_T reduce_sum = 0.0; CUDA_1D_KERNEL_LOOP(i, elem_cnt) { reduce_sum += block_local_sum_buf[i]; } const INPUT_T block_reduce_sum = BlockReduce(temp_storage).Sum(reduce_sum); if (threadIdx.x == 0) { out[0] = block_reduce_sum; } } template struct BinaryCrossEntropyWithLogitsReduceMeanGradFunctor { OF_DEVICE_FUNC explicit BinaryCrossEntropyWithLogitsReduceMeanGradFunctor( const INPUT_T elem_cnt_reciprocal, const TARGET_T dy) : elem_cnt_reciprocal(elem_cnt_reciprocal), dy(dy) {} __device__ ComputeType operator()(const INPUT_T input_val, const TARGET_T target_val) const { const ComputeType input_val_ = static_cast(input_val); const ComputeType target_val_ = static_cast(target_val); const ComputeType dy_ = static_cast(dy); const ComputeType elem_cnt_reciprocal_ = static_cast(elem_cnt_reciprocal); return (Sigmoid(input_val_) - target_val_) * dy_ * elem_cnt_reciprocal_; } const TARGET_T dy; const INPUT_T elem_cnt_reciprocal; }; template struct BinaryCrossEntropyWithLogitsReduceMeanGradDyptrFunctor { OF_DEVICE_FUNC explicit BinaryCrossEntropyWithLogitsReduceMeanGradDyptrFunctor( const int32_t elem_cnt, const TARGET_T* dy_ptr) : elem_cnt_reciprocal(1.0f / elem_cnt), dy_ptr(dy_ptr) {} __device__ BinaryCrossEntropyWithLogitsReduceMeanGradFunctor operator()() const { return BinaryCrossEntropyWithLogitsReduceMeanGradFunctor( elem_cnt_reciprocal, *dy_ptr); } const TARGET_T* dy_ptr; const INPUT_T elem_cnt_reciprocal; }; template __global__ void FusedBCEReduceMeanFwBwKernel(const INPUT_T* input, const TARGET_T* target, TARGET_T* out, INPUT_T* input_grad, const ComputeType constant_output_grad, const ComputeType elem_cnt_reciprocal, const int32_t local_elem_cnt, const int32_t reduce_elem_cnt) { ComputeType zero = static_cast(0.0); ComputeType one = static_cast(1.0); BinaryCrossEntropyWithLogitsReduceMeanGradFunctor grad_functor( elem_cnt_reciprocal, constant_output_grad); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; TARGET_T reduce_sum = 0.0; CUDA_1D_KERNEL_LOOP(i, local_elem_cnt) { const INPUT_T input_val = input[i]; const TARGET_T target_val = target[i]; input_grad[i] = grad_functor(input_val, target_val); const ComputeType input_val_ = static_cast(input_val); const ComputeType target_val_ = static_cast(target_val); const ComputeType max_val = -input_val_ < zero ? 
zero : -input_val_; const ComputeType result = (one - target_val_) * input_val_ + max_val + (log(exp(-max_val) + exp(-input_val_ - max_val))); reduce_sum += result; } const ComputeType block_reduce_sum = BlockReduce(temp_storage).Sum(reduce_sum); if (threadIdx.x == 0) { out[blockIdx.x] = static_cast(block_reduce_sum / reduce_elem_cnt); } } template class FusedBCEMeanFwBwKernel final : public user_op::OpKernel, public CudaGraphSupport { public: FusedBCEMeanFwBwKernel() = default; ~FusedBCEMeanFwBwKernel() override = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateBCEWithLogitsReduceMeanKernelCache(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); int64_t local_elem_cnt = input_blob->shape_view().elem_cnt(); int64_t reduce_elem_cnt = local_elem_cnt; if (cache != nullptr) { // Because `out`'s SBP maybe P or B, we need to use reduce_elem_cnt as reduce_mean factor. const auto* bce_cache = dynamic_cast(cache); CHECK_NOTNULL(bce_cache); reduce_elem_cnt = bce_cache->reduce_elem_cnt(); } const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); using ComputeType = typename DefaultComputeType::type; ComputeType constant_output_grad = ctx->Attr("constant_value"); ComputeType elem_cnt_reciprocal = static_cast(1) / reduce_elem_cnt; if (local_elem_cnt <= kSingleBlockProcessNumThreshold) { FusedBCEReduceMeanFwBwKernel <<<1, kBlockSize, 0, ctx->stream()->As()->cuda_stream()>>>( input_blob->dptr(), target_blob->dptr(), out_blob->mut_dptr(), dx_blob->mut_dptr(), constant_output_grad, elem_cnt_reciprocal, local_elem_cnt, reduce_elem_cnt); } else { auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t tmp_buffer_elem_cnt = tmp_buffer->shape_view().elem_cnt() / sizeof(TARGET_T); const int64_t block_num = (local_elem_cnt + kBlockSize - 1) / kBlockSize; int launch_block = block_num; OF_CUDA_CHECK(GetNumBlocks(FusedBCEReduceMeanFwBwKernel, kBlockSize, 0, block_num, 32, &launch_block)); launch_block = std::min(tmp_buffer_elem_cnt, launch_block); FusedBCEReduceMeanFwBwKernel <<stream()->As()->cuda_stream()>>>( input_blob->dptr(), target_blob->dptr(), tmp_buffer->mut_dptr(), dx_blob->mut_dptr(), constant_output_grad, elem_cnt_reciprocal, local_elem_cnt, reduce_elem_cnt); ReduceLocalSumKernel <<<1, kReduceLocalSumBlockSize, 0, ctx->stream()->As()->cuda_stream()>>>( tmp_buffer->mut_dptr(), out_blob->mut_dptr(), block_num); } } }; template class BinaryCrossEntropyWithLogitsMeanKernel final : public user_op::OpKernel, public CudaGraphSupport { public: BinaryCrossEntropyWithLogitsMeanKernel() = default; ~BinaryCrossEntropyWithLogitsMeanKernel() override = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateBCEWithLogitsReduceMeanKernelCache(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const auto* 
input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t local_elem_cnt = input_blob->shape_view().elem_cnt(); int64_t reduce_elem_cnt = local_elem_cnt; if (cache != nullptr) { // Because `out`'s SBP maybe P or B, we need to use reduce_elem_cnt as reduce_mean factor. const auto* bce_cache = dynamic_cast(cache); CHECK_NOTNULL(bce_cache); reduce_elem_cnt = bce_cache->reduce_elem_cnt(); } const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); TARGET_T* out = out_blob->mut_dptr(); using ComputeType = typename DefaultComputeType::type; if (local_elem_cnt <= kSingleBlockProcessNumThreshold) { FusedBinaryCrossEntropyWithLogitsReduceMeanKernel <<<1, kBlockSize, 0, ctx->stream()->As()->cuda_stream()>>>( input_blob->dptr(), target_blob->dptr(), out_blob->mut_dptr(), local_elem_cnt, reduce_elem_cnt); } else { auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t tmp_buffer_elem_cnt = tmp_buffer->shape_view().elem_cnt() / sizeof(TARGET_T); const int64_t block_num = (local_elem_cnt + kBlockSize - 1) / kBlockSize; int launch_block = block_num; OF_CUDA_CHECK( GetNumBlocks(FusedBinaryCrossEntropyWithLogitsReduceMeanKernel, kBlockSize, 0, block_num, 32, &launch_block)); launch_block = std::min(tmp_buffer_elem_cnt, launch_block); FusedBinaryCrossEntropyWithLogitsReduceMeanKernel <<stream()->As()->cuda_stream()>>>( input_blob->dptr(), target_blob->dptr(), tmp_buffer->mut_dptr(), local_elem_cnt, reduce_elem_cnt); ReduceLocalSumKernel <<<1, kReduceLocalSumBlockSize, 0, ctx->stream()->As()->cuda_stream()>>>( tmp_buffer->mut_dptr(), out_blob->mut_dptr(), block_num); } } }; template class BinaryCrossEntropyWithLogitsReduceMeanGradKernel final : public user_op::OpKernel { public: BinaryCrossEntropyWithLogitsReduceMeanGradKernel() = default; ~BinaryCrossEntropyWithLogitsReduceMeanGradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateBCEWithLogitsReduceMeanKernelCache(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); int64_t local_elem_cnt = input_blob->shape_view().elem_cnt(); int64_t reduce_elem_cnt = local_elem_cnt; if (cache != nullptr) { // Because `out`'s SBP maybe P or B, we need to use reduce_elem_cnt as reduce_mean factor. 
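// `reduce_elem_cnt` must be the logical (global) element count of `input`: under a
// data-parallel placement each rank only sees its local shard, so dividing by
// `local_elem_cnt` would over-weight the mean. The cache built by
// CreateBCEWithLogitsReduceMeanKernelCache (defined in
// binary_cross_entropy_with_logits_mean_kernel_util.h below) stores the logical count taken
// from LogicalTensorDesc4ArgNameAndIndex("input", 0); when parallel_num() == 1 no cache is
// created and the local count already equals the global one.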
const auto* bce_cache = dynamic_cast(cache); CHECK_NOTNULL(bce_cache); reduce_elem_cnt = bce_cache->reduce_elem_cnt(); } const TARGET_T* dy = dy_blob->dptr(); const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); INPUT_T* dx = dx_blob->mut_dptr(); using ComputeType = typename DefaultComputeType::type; OF_CUDA_CHECK((cuda::elementwise::BinaryWithFactory( BinaryCrossEntropyWithLogitsReduceMeanGradDyptrFunctor( reduce_elem_cnt, dy), local_elem_cnt, dx, input, target, ctx->stream()->As()->cuda_stream()))); } }; } // namespace #define REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(input_dtype, target_dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_with_logits_reduce_mean") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const int64_t elem_cnt = ctx->InputShape("input", 0).elem_cnt(); \ const int64_t block_num = (elem_cnt + kBlockSize - 1) / kBlockSize; \ int launch_block = block_num; \ using compute_dtype = typename DefaultComputeType::type; \ OF_CUDA_CHECK(GetNumBlocks( \ FusedBinaryCrossEntropyWithLogitsReduceMeanKernel, \ kBlockSize, 0, block_num, 32, &launch_block)); \ const int64_t tmp_buffer_size = GetCudaAlignedSize(launch_block * sizeof(compute_dtype)); \ return tmp_buffer_size; \ }); #define REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(input_dtype, target_dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_with_logits_reduce_mean_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(half, half) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(half, float) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(float, half) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(half, double) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(double, half) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(float, float) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(float, double) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(double, float) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(double, double) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(half, half) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(half, float) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(float, half) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(half, double) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(double, half) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(float, float) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(float, double) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(double, float) REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(double, double) #define REGISTER_FUSED_BCE_REDUCE_MEAN_FW_BW_KERNEL(input_dtype, target_dtype) \ REGISTER_USER_KERNEL("fused_bce_reduce_mean_fw_bw") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const int64_t 
elem_cnt = ctx->InputShape("input", 0).elem_cnt(); \ const int64_t block_num = (elem_cnt + kBlockSize - 1) / kBlockSize; \ int launch_block = block_num; \ using compute_dtype = typename DefaultComputeType::type; \ OF_CUDA_CHECK(GetNumBlocks( \ FusedBinaryCrossEntropyWithLogitsReduceMeanKernel, \ kBlockSize, 0, block_num, 32, &launch_block)); \ const int64_t tmp_buffer_size = GetCudaAlignedSize(launch_block * sizeof(target_dtype)); \ return tmp_buffer_size; \ }); REGISTER_FUSED_BCE_REDUCE_MEAN_FW_BW_KERNEL(half, half) REGISTER_FUSED_BCE_REDUCE_MEAN_FW_BW_KERNEL(float, float) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/binary_cross_entropy_with_logits_mean_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { namespace user_op { namespace { class BCEWithLogitsReduceMeanKernelCache final : public user_op::OpKernelCache { public: BCEWithLogitsReduceMeanKernelCache(int64_t reduce_elem_cnt) : reduce_elem_cnt_(reduce_elem_cnt) {} ~BCEWithLogitsReduceMeanKernelCache() override = default; int64_t reduce_elem_cnt() const { return reduce_elem_cnt_; } private: const int64_t reduce_elem_cnt_; }; std::shared_ptr CreateBCEWithLogitsReduceMeanKernelCache( user_op::KernelCacheContext* ctx) { if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; } const int64_t reduce_elem_cnt = ctx->LogicalTensorDesc4ArgNameAndIndex("input", 0)->shape().elem_cnt(); return std::make_shared(reduce_elem_cnt); } } // namespace } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/binary_cross_entropy_with_logits_reduce_mean.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/user/kernels/binary_cross_entropy_with_logits_mean_kernel_util.h" #include "oneflow/user/kernels/loss_kernel_util.h" namespace oneflow { namespace user_op { namespace { using namespace loss; template inline T ComputeMaxVal(const T x) { T y = -x; return y < 0 ? 
0 : y; } template inline T CalSigmoid(const T x) { const T half_of_one = static_cast(0.5); return half_of_one * std::tanh(half_of_one * x) + half_of_one; } template struct ComputeBinaryCrossEntropyWithLogitsReduceMeanOutFunctor { inline ComputeType Compute(int64_t elem_cnt, const INPUT_T* input, const TARGET_T* target, int64_t reduce_elem_cnt) { ComputeType result = 0.0; FOR_RANGE(int64_t, i, 0, elem_cnt) { ComputeType input_val = static_cast(input[i]); ComputeType target_val = static_cast(target[i]); ComputeType max_val = ComputeMaxVal(input_val); result += (1 - target_val) * input_val + max_val + (std::log(std::exp(-max_val) + std::exp(-input_val - max_val))); } return static_cast(result) / reduce_elem_cnt; } }; template void ComputeBinaryCrossEntropyWithLogitsReduceMeanOut(int64_t elem_cnt, const INPUT_T* input, const TARGET_T* target, TARGET_T* out, int64_t reduce_elem_cnt) { if (sizeof(INPUT_T) > sizeof(TARGET_T)) { ComputeBinaryCrossEntropyWithLogitsReduceMeanOutFunctor f; out[0] = f.Compute(elem_cnt, input, target, reduce_elem_cnt); } else { ComputeBinaryCrossEntropyWithLogitsReduceMeanOutFunctor f; out[0] = f.Compute(elem_cnt, input, target, reduce_elem_cnt); } } template void ComputeBinaryCrossEntropyWithLogitsReduceMeanGradOut(int64_t elem_cnt, const INPUT_T* input, const TARGET_T* target, const TARGET_T* dy, INPUT_T* dx, int64_t reduce_elem_cnt) { INPUT_T dy_val = static_cast(dy[0]) / reduce_elem_cnt; FOR_RANGE(int64_t, i, 0, elem_cnt) { INPUT_T input_val = input[i]; INPUT_T target_val = static_cast(target[i]); INPUT_T input_sigmoid = CalSigmoid(input_val); dx[i] = (input_sigmoid - target_val) * dy_val; } } template class BinaryCrossEntropyWithLogitsReduceMeanKernel final : public user_op::OpKernel { public: BinaryCrossEntropyWithLogitsReduceMeanKernel() = default; ~BinaryCrossEntropyWithLogitsReduceMeanKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateBCEWithLogitsReduceMeanKernelCache(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t local_elem_cnt = input_blob->shape_view().elem_cnt(); int64_t reduce_elem_cnt = local_elem_cnt; if (cache != nullptr) { // Because `out`'s SBP maybe P or B, we need to use reduce_elem_cnt as reduce_mean factor. 
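// As in the CUDA kernels above, the cache (when present) supplies the logical element count of
// `input`, so the mean is taken over the full tensor rather than the local shard. The helper
// called below, ComputeBinaryCrossEntropyWithLogitsReduceMeanOut, accumulates in the wider of
// the two operand types and evaluates the loss in its log-sum-exp form,
//   (1 - y) * x + max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))),
// which equals (1 - y) * x + log(1 + exp(-x)) but avoids overflowing exp(-x) when x is very
// negative.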
const auto* bce_cache = dynamic_cast(cache); CHECK_NOTNULL(bce_cache); reduce_elem_cnt = bce_cache->reduce_elem_cnt(); } const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); TARGET_T* out = out_blob->mut_dptr(); ComputeBinaryCrossEntropyWithLogitsReduceMeanOut(local_elem_cnt, input, target, out, reduce_elem_cnt); } }; template class BinaryCrossEntropyWithLogitsReduceMeanGradKernel final : public user_op::OpKernel { public: BinaryCrossEntropyWithLogitsReduceMeanGradKernel() = default; ~BinaryCrossEntropyWithLogitsReduceMeanGradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateBCEWithLogitsReduceMeanKernelCache(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); int64_t local_elem_cnt = input_blob->shape_view().elem_cnt(); int64_t reduce_elem_cnt = local_elem_cnt; if (cache != nullptr) { // Because `out`'s SBP maybe P or B, we need to use reduce_elem_cnt as reduce_mean factor. const auto* bce_cache = dynamic_cast(cache); CHECK_NOTNULL(bce_cache); reduce_elem_cnt = bce_cache->reduce_elem_cnt(); } const TARGET_T* dy = dy_blob->dptr(); const INPUT_T* input = input_blob->dptr(); const TARGET_T* target = target_blob->dptr(); INPUT_T* dx = dx_blob->mut_dptr(); ComputeBinaryCrossEntropyWithLogitsReduceMeanGradOut(local_elem_cnt, input, target, dy, dx, reduce_elem_cnt); } }; } // namespace #define REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(input_dtype, target_dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_with_logits_reduce_mean") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); #define REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(input_dtype, \ target_dtype) \ REGISTER_USER_KERNEL("binary_cross_entropy_with_logits_reduce_mean_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(float, float) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(float, double) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(double, float) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(double, double) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(float, float) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(float, double) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(double, float) REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(double, double) } // namespace user_op } // namespace oneflow ================================================ FILE: 
oneflow/user/kernels/bincount_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/framework/user_op_hob.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { template void BinCountComputeWeight(const IDX* in_ptr, const T* weight, T* out_ptr, int64_t size) { FOR_RANGE(int64_t, i, 0, size) { IDX idx = *(in_ptr + i); out_ptr[idx] += weight[i]; } } template void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t size) { FOR_RANGE(int64_t, i, 0, size) { IDX idx = *(in_ptr + i); out_ptr[idx] += 1L; } } template class CpuBinCountKernel final : public user_op::OpKernel { public: CpuBinCountKernel() = default; ~CpuBinCountKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); size_t out_size = ctx->Attr("size") * sizeof(T); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const IDX* in_ptr = in->dptr(); T* out_ptr = out->mut_dptr(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size); int64_t in_size = in->shape_view().elem_cnt(); if (ctx->has_input("weight", 0)) { const T* weight_ptr = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); BinCountComputeWeight(in_ptr, weight_ptr, out_ptr, in_size); } else { BinCountCompute(in_ptr, out_ptr, in_size); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_BINCOUNT_KERNEL(idx_type, dtype) \ REGISTER_USER_KERNEL("bincount") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_CPU_BINCOUNT_KERNEL(int64_t, int64_t) REGISTER_CPU_BINCOUNT_KERNEL(int64_t, float16) REGISTER_CPU_BINCOUNT_KERNEL(int64_t, float) REGISTER_CPU_BINCOUNT_KERNEL(int64_t, double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/bincount_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/user_op_hob.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { namespace user_op { namespace { template __global__ static void BinCountCompute(const IDX* in_ptr, const T* weight, T* out_ptr, int64_t in_size, int64_t out_size) { if constexpr (UseGlobalMem) { CUDA_1D_KERNEL_LOOP(i, in_size) { IDX idx = *(in_ptr + i); cuda::atomic::Add(out_ptr + idx, weight[i]); } } else { __shared__ T shm[kCudaThreadsNumPerBlock]; T zero = GetZeroVal(); shm[threadIdx.x] = zero; __syncthreads(); CUDA_1D_KERNEL_LOOP(i, in_size) { IDX idx = *(in_ptr + i); cuda::atomic::Add(shm + idx, weight[i]); } __syncthreads(); if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); } } }; template __global__ static void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t in_size, int64_t out_size) { T one = GetOneVal(); if constexpr (UseGlobalMem) { CUDA_1D_KERNEL_LOOP(i, in_size) { IDX idx = *(in_ptr + i); cuda::atomic::Add(out_ptr + idx, one); } } else { __shared__ T shm[kCudaThreadsNumPerBlock]; T zero = GetZeroVal(); shm[threadIdx.x] = zero; __syncthreads(); CUDA_1D_KERNEL_LOOP(i, in_size) { IDX idx = *(in_ptr + i); cuda::atomic::Add(shm + idx, one); } __syncthreads(); if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); } } }; template static void BinCountDispatch(user_op::KernelComputeContext* ctx, const IDX* in_ptr, const T* weight_ptr, T* out_ptr, int64_t in_size, int64_t out_size) { if (weight_ptr) { BinCountCompute <<stream()->As()->cuda_stream()>>>(in_ptr, weight_ptr, out_ptr, in_size, out_size); } else { BinCountCompute <<stream()->As()->cuda_stream()>>>(in_ptr, out_ptr, in_size, out_size); } } template class CUDABinCountKernel final : public user_op::OpKernel { public: CUDABinCountKernel() = default; ~CUDABinCountKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); size_t out_size = ctx->Attr("size"); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const IDX* in_ptr = in->dptr(); T* out_ptr = out->mut_dptr(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size * sizeof(T)); const int64_t in_size = in->shape_view().elem_cnt(); if (in_size == 0) { return; } const T* weight_ptr = nullptr; if (ctx->has_input("weight", 0)) { weight_ptr = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr(); }; if (out_size > kCudaThreadsNumPerBlock) { BinCountDispatch(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size); } else { BinCountDispatch(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_CUDA_BINCOUNT_KERNEL(idx_type, dtype) \ REGISTER_USER_KERNEL("bincount") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_CUDA_BINCOUNT_KERNEL(int64_t, int64_t) REGISTER_CUDA_BINCOUNT_KERNEL(int64_t, half) REGISTER_CUDA_BINCOUNT_KERNEL(int64_t, float) REGISTER_CUDA_BINCOUNT_KERNEL(int64_t, double) } // namespace user_op } // 
namespace oneflow ================================================ FILE: oneflow/user/kernels/broadcast_div_grad_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { namespace { template class BroadcastDivGradKernel final : public user_op::OpKernel { public: BroadcastDivGradKernel() = default; ~BroadcastDivGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* z_tensor = ctx->Tensor4ArgNameAndIndex("z", 0); const user_op::Tensor* dz_tensor = ctx->Tensor4ArgNameAndIndex("dz", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const int64_t num_axes = dz_tensor->shape_view().NumAxes(); XpuVarNdarray dz(dz_tensor->shape_view(), dz_tensor->dptr(), num_axes); XpuVarNdarray const_tmp(dz.shape(), tmp_buffer->dptr()); XpuVarNdarray tmp(dz.shape(), tmp_buffer->mut_dptr()); auto bcast_div = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kDiv, z_tensor->data_type(), z_tensor->data_type(), z_tensor->shape_view().NumAxes()); CHECK(bcast_div); bcast_div->Launch(ctx->stream(), z_tensor->shape_view().NumAxes(), z_tensor->shape_view().ptr(), z_tensor->dptr(), y_tensor->shape_view().NumAxes(), y_tensor->shape_view().ptr(), y_tensor->dptr(), tmp_buffer->mut_dptr()); if (IsComplexDataType(z_tensor->data_type())) { auto conj = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kConj, z_tensor->data_type(), z_tensor->data_type()); CHECK(conj); const int64_t elem_cnt = dz_tensor->shape_view().elem_cnt(); conj->Launch(ctx->stream(), tmp_buffer->dptr(), tmp_buffer->mut_dptr(), elem_cnt); } auto bcast_mul = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kMul, dz_tensor->data_type(), dz_tensor->data_type(), dz_tensor->shape_view().NumAxes()); CHECK(bcast_mul); bcast_mul->Launch(ctx->stream(), dz_tensor->shape_view().NumAxes(), dz_tensor->shape_view().ptr(), tmp_buffer->dptr(), dz_tensor->shape_view().NumAxes(), dz_tensor->shape_view().ptr(), dz_tensor->dptr(), tmp_buffer->mut_dptr()); NdarrayUtil::ReduceSum( ctx->stream(), XpuVarNdarray(dy_tensor->shape_view(), dy_tensor->mut_dptr(), num_axes), const_tmp, tmp); auto negative = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kNegative, dy_tensor->data_type(), dy_tensor->data_type()); CHECK(negative); negative->Launch(ctx->stream(), dy_tensor->dptr(), dy_tensor->mut_dptr(), dy_tensor->shape_view().elem_cnt()); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define 
REGISTER_BROADCAST_DIV_GRAD_KERNEL(device, dtype_pair) \ REGISTER_USER_KERNEL("broadcast_div_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { \ const user_op::TensorDesc& z = ctx->InputTensorDesc("z", 0); \ DataType data_type = z.data_type(); \ const int64_t elem_cnt = z.shape().elem_cnt(); \ return GetCudaAlignedSize(elem_cnt * GetSizeOfDataType(data_type)); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), COMPLEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, (DeviceType::kCUDA), FLOAT16_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, (DeviceType::kCUDA), OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64)) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, (DeviceType::kCUDA), OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128)) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/broadcast_like_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template class BroadcastLikeKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: BroadcastLikeKernel() = default; ~BroadcastLikeKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex("like", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const auto& axis = ctx->Attr>("broadcast_axes"); const Shape& reduced_shape = CreateReducedShapeOrOnesShape(like_tensor->shape_view(), {axis.begin(), axis.end()}); NdarrayUtil::BroadcastTo( ctx->stream(), XpuVarNdarray(out_tensor->shape_view(), out_tensor->mut_dptr()), XpuVarNdarray(reduced_shape, in_tensor->dptr())); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_BROADCAST_LIKE_XPU_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("broadcast_like") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); #ifdef WITH_CUDA #define REGISTER_BROADCAST_LIKE_KERNEL(dtype) \ REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCPU, dtype) \ REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCUDA, dtype) #else #define REGISTER_BROADCAST_LIKE_KERNEL(dtype) \ REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCPU, dtype) #endif REGISTER_BROADCAST_LIKE_KERNEL(float) REGISTER_BROADCAST_LIKE_KERNEL(float16) REGISTER_BROADCAST_LIKE_KERNEL(double) REGISTER_BROADCAST_LIKE_KERNEL(bool) REGISTER_BROADCAST_LIKE_KERNEL(int8_t) REGISTER_BROADCAST_LIKE_KERNEL(int32_t) REGISTER_BROADCAST_LIKE_KERNEL(int64_t) REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCPU, std::complex) REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCPU, std::complex) #ifdef WITH_CUDA REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCUDA, cuComplex) REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCUDA, cuDoubleComplex) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/cast_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewBroadcastPrimitive(Context* ctx) { const DataType in_data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); const DataType out_data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); const size_t max_ndim = std::max(ctx->TensorDesc4ArgNameAndIndex("in", 0)->shape().NumAxes(), ctx->TensorDesc4ArgNameAndIndex("out", 0)->shape().NumAxes()); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kCast, in_data_type, out_data_type, max_ndim); } class CastKernel final : public OpKernel, public user_op::CudaGraphSupport { public: CastKernel() = default; ~CastKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* input = ctx->Tensor4ArgNameAndIndex("in", 0); Tensor* output = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = input->shape_view().elem_cnt(); // 0-size tensor CHECK_EQ(output->shape_view().elem_cnt(), elem_cnt) << "The number of cast op's input and output elements should be equal."; if (elem_cnt == 0) { return; } if (input->data_type() == output->data_type() && input->dptr() == output->dptr()) { return; } const size_t ndim = input->shape_view().NumAxes(); auto broadcast_primitive = NewBroadcastPrimitive(ctx); CHECK(broadcast_primitive); if (ndim == 0 && elem_cnt == 1) { // 0-dim tensor // TODO: remove these when BroadcastElementwiseUnary primitive support 0-dim(scalar) tensor Shape input_shape(DimVector{1}); Shape output_shape(DimVector{1}); Stride input_stride(DimVector{1}); Stride output_stride(DimVector{1}); const size_t scalar_ndim = 1; broadcast_primitive->Launch(ctx->stream(), scalar_ndim, input_shape.data(), input_stride.data(), input->dptr(), scalar_ndim, output_shape.data(), output_stride.data(), output->mut_dptr()); } else { broadcast_primitive->Launch( ctx->stream(), ndim, input->shape_view().data(), input->stride().data(), input->dptr(), ndim, output->shape_view().data(), output->stride().data(), output->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto BroadcastPrimitiveExists() { return hob::make_custom("BroadcastElementwiseUnaryPrimitiveExists", [](const user_op::KernelRegContext& ctx) -> bool { return NewBroadcastPrimitive(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("cast") .SetCreateFn() .SetIsMatchedHob(BroadcastPrimitiveExists() == true) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.InputDType("in", 0) == ctx.Attr("dtype")) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false)); } return Maybe::Ok(); }); REGISTER_USER_KERNEL("cast_like") .SetCreateFn() .SetIsMatchedHob(BroadcastPrimitiveExists() == true) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.InputDType("in", 0) == ctx.InputDType("like", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false)); } return Maybe::Ok(); }); } // namespace } // namespace user_op } // namespace oneflow 
================================================ FILE: oneflow/user/kernels/cast_to_static_shape_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { class CastToStaticShapeKernel final : public user_op::OpKernel { public: CastToStaticShapeKernel() = default; ~CastToStaticShapeKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0); const Shape& input_static_shape = ctx->TensorDesc4ArgNameAndIndex("input", 0)->shape(); user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output", 0); CHECK(input_tensor->shape_view() == ShapeView(input_static_shape)); CHECK_EQ(output_tensor->shape_view(), input_tensor->shape_view()); size_t output_tensor_size = output_tensor->shape_view().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()); std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type(), ep::primitive::MemcpyKind::kDtoD); CHECK(primitive) << "Can not create Memcpy primitive for device type " << ctx->stream()->device_type(); primitive->Launch(ctx->stream(), output_tensor->mut_dptr(), input_tensor->dptr(), output_tensor_size); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace REGISTER_USER_KERNEL("cast_to_static_shape") .SetCreateFn() .SetInplaceProposalFn([](const user_op::InferContext&, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("output", 0, "input", 0, false)); return Maybe::Ok(); }); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/categorical_ordinal_encode_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/categorical_ordinal_encode_kernel_util.h" namespace oneflow { template class CategoricalOrdinalEncodeKernel final : public user_op::OpKernel { public: CategoricalOrdinalEncodeKernel() = default; ~CategoricalOrdinalEncodeKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { bool hash_precomputed = ctx->Attr("hash_precomputed"); CHECK(hash_precomputed); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* table = ctx->Tensor4ArgNameAndIndex("table", 0); user_op::Tensor* size = ctx->Tensor4ArgNameAndIndex("size", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t table_elem_cnt = table->shape_view().elem_cnt(); CHECK_EQ(table_elem_cnt % 2, 0); const int64_t capacity = table_elem_cnt / 2; CategoricalOrdinalEncodeKernelUtil::Encode( ctx->stream(), capacity, table->mut_dptr(), size->mut_dptr(), in->shape_view().elem_cnt(), in->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(device, proto_type, cpp_type) \ REGISTER_USER_KERNEL("CategoricalOrdinalEncode") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("in", 0) == proto_type)); REGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(DeviceType::kCPU, DataType::kInt32, int32_t); REGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(DeviceType::kCPU, DataType::kInt64, int64_t); #ifdef WITH_CUDA REGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(DeviceType::kCUDA, DataType::kInt32, int32_t); REGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(DeviceType::kCUDA, DataType::kInt64, int64_t); #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/categorical_ordinal_encode_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/categorical_ordinal_encode_kernel_util.h" namespace oneflow { template struct CategoricalOrdinalEncodeKernelUtil { static void Encode(ep::Stream* stream, int64_t capacity, T* table, T* size, int64_t n, const T* hash, T* out) { for (int64_t i = 0; i < n; ++i) { const T h = hash[i]; bool success = false; for (int64_t count = 0; count < capacity; ++count) { size_t idx = (static_cast(h) + static_cast(count)) % static_cast(capacity); T* k_ptr = table + idx * 2; T* v_ptr = k_ptr + 1; if (*k_ptr == h) { out[i] = *v_ptr; success = true; break; } else if (*k_ptr == 0) { T new_size = *size + 1; *k_ptr = h; *v_ptr = new_size; out[i] = new_size; *size = new_size; success = true; break; } else { continue; } } CHECK(success); } } }; #define INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CPU(type_cpp, type_proto) \ template struct CategoricalOrdinalEncodeKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CPU, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/categorical_ordinal_encode_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef NDEBUG #undef NDEBUG #endif #include #include "oneflow/user/kernels/categorical_ordinal_encode_kernel_util.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { using CuInt64T = unsigned long long int; __device__ __inline__ int32_t AtomicCAS(int32_t* address, int32_t compare, int32_t val) { return atomicCAS(address, compare, val); } __device__ __inline__ int64_t AtomicCAS(int64_t* address, int64_t compare, int64_t val) { static_assert(sizeof(int64_t) == sizeof(CuInt64T), "size error"); return static_cast(atomicCAS(reinterpret_cast(address), static_cast(compare), static_cast(val))); } __device__ __inline__ int32_t AtomicAdd(int32_t* address, int32_t val) { return atomicAdd(address, val); } __device__ __inline__ int64_t AtomicAdd(int64_t* address, int64_t val) { static_assert(sizeof(int64_t) == sizeof(CuInt64T), "size error"); return static_cast( atomicAdd(reinterpret_cast(address), static_cast(val))); } template __device__ bool TryGetOrInsert(K* key, volatile V* value, V* size, const K hash, V* out) { K old_key = AtomicCAS(key, static_cast(0), hash); if (old_key == 0) { V v = AtomicAdd(size, 1) + 1; *value = v; *out = v; return true; } else if (old_key == hash) { while (true) { V v = *value; if (v != 0) { *out = v; break; } } return true; } else { return false; } } template __device__ bool GetOrInsertOne(const size_t capacity, T* table, T* size, const T hash, T* out) { if (hash == 0) { *out = 0; return true; } const size_t start_idx = static_cast(hash) % capacity; // fast path { T* key = table + start_idx * 2; T* value = key + 1; if (*key == hash && *value != 0) { *out = *value; return true; } } for (size_t count = 0; count < capacity; ++count) { const size_t idx = (start_idx + count) % capacity; T* key = table + idx * 2; T* value = key + 1; if (TryGetOrInsert(key, value, size, hash, out)) { return true; } } return false; } template __global__ void EncodeGpu(const size_t capacity, T* table, T* size, const int64_t n, const T* hash, T* out) { CUDA_1D_KERNEL_LOOP(i, n) { bool success = GetOrInsertOne(capacity, table, size, hash[i], out + i); assert(success); } } } // namespace template struct CategoricalOrdinalEncodeKernelUtil { static void Encode(ep::Stream* stream, int64_t capacity, T* table, T* size, int64_t n, const T* hash, T* out) { EncodeGpu <<As()->cuda_stream()>>>(capacity, table, size, n, hash, out); } }; #define INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CUDA(type_cpp, type_proto) \ template struct CategoricalOrdinalEncodeKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CUDA, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/categorical_ordinal_encode_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct CategoricalOrdinalEncodeKernelUtil { static void Encode(ep::Stream* stream, int64_t capacity, T* table, T* size, int64_t n, const T* hash, T* out); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/clip_by_value_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/clip_by_value_kernel.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #ifdef WITH_CUDA #include #endif namespace oneflow { namespace { template T GetDtypeMatchedValue(double floating, int64_t integral); template<> float GetDtypeMatchedValue(double floating, int64_t integral) { return static_cast(floating); } template<> double GetDtypeMatchedValue(double floating, int64_t integral) { return floating; } template<> int8_t GetDtypeMatchedValue(double floating, int64_t integral) { return static_cast(integral); } template<> int32_t GetDtypeMatchedValue(double floating, int64_t integral) { return static_cast(integral); } template<> int64_t GetDtypeMatchedValue(double floating, int64_t integral) { return integral; } #ifdef WITH_CUDA template<> half GetDtypeMatchedValue(double floating, int64_t integral) { #if CUDA_VERSION >= 11000 return __double2half(floating); #else return __float2half(static_cast(floating)); #endif } #endif template<> float16 GetDtypeMatchedValue(double floating, int64_t integral) { return static_cast(floating); } } // namespace template struct ClipKernelUtil { template static void Forward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, T* y) { FOR_RANGE(int64_t, i, 0, n) { y[i] = clip_func(x[i]); } } template static void Backward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, const T* dy, T* dx) { FOR_RANGE(int64_t, i, 0, n) { dx[i] = clip_func(x[i], dy[i]); } } }; template class ClipByScalarKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ClipByScalarKernel() = default; ~ClipByScalarKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); double floating_min = ctx->Attr("floating_min"); int64_t integral_min = ctx->Attr("integral_min"); double floating_max = ctx->Attr("floating_max"); int64_t integral_max = ctx->Attr("integral_max"); ClipByMinMaxFunctor clip_func(GetDtypeMatchedValue(floating_min, integral_min), GetDtypeMatchedValue(floating_max, integral_max)); ClipKernelUtil::Forward(ctx->stream(), clip_func, y->shape_view().elem_cnt(), x->dptr(), y->mut_dptr()); } bool 
AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ClipByScalarMinKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ClipByScalarMinKernel() = default; ~ClipByScalarMinKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); double floating_min = ctx->Attr("floating_min"); int64_t integral_min = ctx->Attr("integral_min"); ClipByMinFunctor clip_func(GetDtypeMatchedValue(floating_min, integral_min)); ClipKernelUtil::Forward(ctx->stream(), clip_func, y->shape_view().elem_cnt(), x->dptr(), y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ClipByScalarMaxKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ClipByScalarMaxKernel() = default; ~ClipByScalarMaxKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); double floating_max = ctx->Attr("floating_max"); int64_t integral_max = ctx->Attr("integral_max"); ClipByMaxFunctor clip_func(GetDtypeMatchedValue(floating_max, integral_max)); ClipKernelUtil::Forward(ctx->stream(), clip_func, y->shape_view().elem_cnt(), x->dptr(), y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ClipByScalarGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ClipByScalarGradKernel() = default; ~ClipByScalarGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); double floating_min = ctx->Attr("floating_min"); int64_t integral_min = ctx->Attr("integral_min"); double floating_max = ctx->Attr("floating_max"); int64_t integral_max = ctx->Attr("integral_max"); ClipByMinMaxGradFunctor clip_func(GetDtypeMatchedValue(floating_min, integral_min), GetDtypeMatchedValue(floating_max, integral_max)); ClipKernelUtil::Backward(ctx->stream(), clip_func, dx->shape_view().elem_cnt(), x->dptr(), dy->dptr(), dx->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ClipByScalarMinGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ClipByScalarMinGradKernel() = default; ~ClipByScalarMinGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); double floating_min = ctx->Attr("floating_min"); int64_t integral_min = ctx->Attr("integral_min"); ClipByMinGradFunctor clip_func(GetDtypeMatchedValue(floating_min, integral_min)); ClipKernelUtil::Backward(ctx->stream(), clip_func, dx->shape_view().elem_cnt(), x->dptr(), dy->dptr(), dx->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ClipByScalarMaxGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ClipByScalarMaxGradKernel() = default; 
~ClipByScalarMaxGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); double floating_max = ctx->Attr("floating_max"); int64_t integral_max = ctx->Attr("integral_max"); ClipByMaxGradFunctor clip_func(GetDtypeMatchedValue(floating_max, integral_max)); ClipKernelUtil::Backward(ctx->stream(), clip_func, dx->shape_view().elem_cnt(), x->dptr(), dy->dptr(), dx->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CLIP_KERNEL(op_type_name, kernel_name, device_type_v, dtype) \ REGISTER_USER_KERNEL(#op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("y", 0) == GetDataType::value)) \ .SetInplaceProposalFn( \ [](const user_op::InferContext&, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); \ return Maybe::Ok(); \ }); #define REGISTER_CLIP_GRAD_KERNEL(op_type_name, kernel_name, device_type_v, dtype) \ REGISTER_USER_KERNEL(#op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)) \ .SetInplaceProposalFn( \ [](const user_op::InferContext&, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "dy", 0, true)); \ return Maybe::Ok(); \ }); #define REGISTER_CLIP_KERNELS(device_type_v, dtype_pair) \ REGISTER_CLIP_KERNEL(clip_by_scalar, ClipByScalar, device_type_v, OF_PP_PAIR_FIRST(dtype_pair)) \ REGISTER_CLIP_KERNEL(clip_by_scalar_min, ClipByScalarMin, device_type_v, \ OF_PP_PAIR_FIRST(dtype_pair)) \ REGISTER_CLIP_KERNEL(clip_by_scalar_max, ClipByScalarMax, device_type_v, \ OF_PP_PAIR_FIRST(dtype_pair)) \ REGISTER_CLIP_GRAD_KERNEL(clip_by_scalar_grad, ClipByScalar, device_type_v, \ OF_PP_PAIR_FIRST(dtype_pair)) \ REGISTER_CLIP_GRAD_KERNEL(clip_by_scalar_min_grad, ClipByScalarMin, device_type_v, \ OF_PP_PAIR_FIRST(dtype_pair)) \ REGISTER_CLIP_GRAD_KERNEL(clip_by_scalar_max_grad, ClipByScalarMax, device_type_v, \ OF_PP_PAIR_FIRST(dtype_pair)) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CLIP_KERNELS, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) REGISTER_CLIP_KERNELS(DeviceType::kCPU, (float16, DataType::kFloat16)) #ifdef WITH_CUDA REGISTER_CLIP_KERNELS(DeviceType::kCUDA, (half, DataType::kFloat16)) #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/clip_by_value_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
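// Illustrative sketch, not from the OneFlow sources above: the clip ops carry both a floating
// attribute ("floating_min"/"floating_max") and an integral one ("integral_min"/"integral_max"),
// and GetDtypeMatchedValue<T> picks whichever matches the tensor's element type. A generic
// standalone equivalent could look like this; PickClipBound is a hypothetical name.
#include <cstdint>
#include <type_traits>

template<typename T>
T PickClipBound(double floating, int64_t integral) {
  if (std::is_floating_point<T>::value) {
    return static_cast<T>(floating);  // float/double bounds come from the double attribute
  }
  return static_cast<T>(integral);    // integer bounds come from the int64 attribute
}

// e.g. PickClipBound<float>(0.5, 0) == 0.5f, while PickClipBound<int32_t>(0.0, 7) == 7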
*/ #include "oneflow/user/kernels/clip_by_value_kernel.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace { template __global__ void CudaClipForward(F clip_func, int64_t n, const T* x, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { y[i] = clip_func(x[i]); } } template __global__ void CudaClipBackward(F clip_func, int64_t n, const T* x, const T* dy, T* dx) { CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = clip_func(x[i], dy[i]); } } } // namespace template struct ClipKernelUtil { template static void Forward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, T* y) { if (n == 0) { return; } RUN_CUDA_KERNEL((CudaClipForward), stream, n, clip_func, n, x, y); } template static void Backward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, const T* dy, T* dx) { if (n == 0) { return; } RUN_CUDA_KERNEL((CudaClipBackward), stream, n, clip_func, n, x, dy, dx); } }; #define INITIATE_CLIP_KERNEL_UTIL_CUDA(dtype, dtype_v) \ template struct ClipKernelUtil; \ template void ClipKernelUtil::Forward( \ ep::Stream*, ClipByMinFunctor, const int64_t n, const dtype*, dtype*); \ template void ClipKernelUtil::Forward( \ ep::Stream*, ClipByMaxFunctor, const int64_t n, const dtype*, dtype*); \ template void ClipKernelUtil::Forward( \ ep::Stream*, ClipByMinMaxFunctor, const int64_t n, const dtype*, dtype*); \ template void ClipKernelUtil::Backward( \ ep::Stream*, ClipByMinGradFunctor, const int64_t n, const dtype*, const dtype*, \ dtype*); \ template void ClipKernelUtil::Backward( \ ep::Stream*, ClipByMaxGradFunctor, const int64_t n, const dtype*, const dtype*, \ dtype*); \ template void ClipKernelUtil::Backward( \ ep::Stream*, ClipByMinMaxGradFunctor, const int64_t n, const dtype*, const dtype*, \ dtype*); OF_PP_FOR_EACH_TUPLE(INITIATE_CLIP_KERNEL_UTIL_CUDA, ARITHMETIC_DATA_TYPE_SEQ) INITIATE_CLIP_KERNEL_UTIL_CUDA(half, DataType::kFloat16) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/clip_by_value_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_CLIP_BY_VALUE_KERNEL_H_ #define ONEFLOW_USER_KERNELS_CLIP_BY_VALUE_KERNEL_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { template OF_DEVICE_FUNC T DeviceMin(T a, T b) { #if defined(__CUDA_ARCH__) return a < b ? a : b; #else return std::min(a, b); #endif } template OF_DEVICE_FUNC T DeviceMax(T a, T b) { #if defined(__CUDA_ARCH__) return a > b ? 
a : b; #else return std::max(a, b); #endif } template struct ClipByMinFunctor { ClipByMinFunctor(T min) : min_value(min) {} OF_DEVICE_FUNC T operator()(T value) { return DeviceMax(value, min_value); } T min_value; }; template struct ClipByMaxFunctor { ClipByMaxFunctor(T max) : max_value(max) {} OF_DEVICE_FUNC T operator()(T value) { return DeviceMin(value, max_value); } T max_value; }; template struct ClipByMinMaxFunctor { ClipByMinMaxFunctor(T min, T max) : min_value(min), max_value(max) {} OF_DEVICE_FUNC T operator()(T value) { return DeviceMin(DeviceMax(value, min_value), max_value); } T min_value; T max_value; }; template struct ClipByMinGradFunctor { ClipByMinGradFunctor(T min) : min_value(min) {} OF_DEVICE_FUNC T operator()(T value, T grad) { return value < min_value ? static_cast(0) : grad; } T min_value; }; template struct ClipByMaxGradFunctor { ClipByMaxGradFunctor(T max) : max_value(max) {} OF_DEVICE_FUNC T operator()(T value, T grad) { return value > max_value ? static_cast(0) : grad; } T max_value; }; template struct ClipByMinMaxGradFunctor { ClipByMinMaxGradFunctor(T min, T max) : min_value(min), max_value(max) {} OF_DEVICE_FUNC T operator()(T value, T grad) { return (value < min_value || value > max_value) ? static_cast(0) : grad; } T min_value; T max_value; }; template struct ClipKernelUtil { template static void Forward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, T* y); template static void Backward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, const T* dy, T* dx); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_CLIP_BY_VALUE_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/coco_reader_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
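// Illustrative sketch, not from the OneFlow sources above: the clip functors in the header are
// ordinary callables applied element-wise by ClipKernelUtil. This standalone host-side version
// (hypothetical names MinMaxClip/MinMaxClipGrad) shows the intended semantics: the forward pass
// clamps into [min, max], the backward pass zeroes the gradient wherever the input was clamped.
#include <algorithm>
#include <cstddef>

struct MinMaxClip {
  float min_value;
  float max_value;
  float operator()(float v) const { return std::min(std::max(v, min_value), max_value); }
};

struct MinMaxClipGrad {
  float min_value;
  float max_value;
  float operator()(float v, float grad) const {
    return (v < min_value || v > max_value) ? 0.0f : grad;
  }
};

inline void ClipForward(MinMaxClip f, std::size_t n, const float* x, float* y) {
  for (std::size_t i = 0; i < n; ++i) { y[i] = f(x[i]); }
}

inline void ClipBackward(MinMaxClipGrad f, std::size_t n, const float* x, const float* dy,
                         float* dx) {
  for (std::size_t i = 0; i < n; ++i) { dx[i] = f(x[i], dy[i]); }
}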
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/data/coco_data_reader.h" namespace oneflow { namespace { class COCOReaderWrapper final : public user_op::OpKernelState { public: explicit COCOReaderWrapper(user_op::KernelInitContext* ctx) : reader_(ctx) {} ~COCOReaderWrapper() = default; void Read(user_op::KernelComputeContext* ctx) { reader_.Read(ctx); } private: data::COCODataReader reader_; }; class COCOReaderKernel final : public user_op::OpKernel { public: COCOReaderKernel() = default; ~COCOReaderKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { std::shared_ptr reader(new COCOReaderWrapper(ctx)); return reader; } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* reader = dynamic_cast(state); reader->Read(ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace REGISTER_USER_KERNEL("COCOReader") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("image", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("image_id", 0) == DataType::kInt64) && (user_op::HobDataType("image_size", 0) == DataType::kInt32) && (user_op::HobDataType("gt_bbox", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("gt_label", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("gt_segm", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("gt_segm_index", 0) == DataType::kTensorBuffer)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/data_type.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h" #include "oneflow/user/kernels/collective_communication/include/all_gather.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h" namespace oneflow { namespace ccl { namespace { Maybe AllGatherImpl(const void* in, void* out, size_t elem_cnt, DataType dtype, Symbol parallel_desc) { int64_t parallel_num = parallel_desc->parallel_num(); if (parallel_num == 1) { if (in != out) { std::memcpy(out, in, elem_cnt * GetSizeOfDataType(dtype)); } return Maybe::Ok(); } char* char_out = reinterpret_cast(out); size_t chunk_size = elem_cnt * GetSizeOfDataType(dtype); BalancedSplitter bs(chunk_size * parallel_num, parallel_num); const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc)); CHECK_OR_RETURN(opt_parallel_id->has_value()) << kOfBugIssueUploadPrompt; const auto& rank_group = JUST(RankGroup::New(parallel_desc)); TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); int64_t parallel_id = JUST(*opt_parallel_id); // In-place operation will happen if in == out + parallel_id * chunk_size if (in != &char_out[parallel_id * chunk_size]) { memcpy(&char_out[parallel_id * chunk_size], in, chunk_size); } for (int64_t i = 0, part_id = parallel_id; i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) { int64_t send_part_id = part_id; const void* send_ptr = &char_out[bs.At(send_part_id).begin()]; size_t send_size = bs.At(send_part_id).size(); int64_t recv_part_id = RingDecrease(part_id, parallel_num); void* recv_ptr = &char_out[bs.At(recv_part_id).begin()]; size_t recv_size = bs.At(recv_part_id).size(); NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = const_cast(send_ptr); *size = send_size; *Cb = [] {}; return Maybe::Ok(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_ptr; *size = recv_size; *Cb = [] {}; return Maybe::Ok(); }); if (send_size > 0) { JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); } if (recv_size > 0) { JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); } JUST(ctx.WaitDone()); } return Maybe::Ok(); } } // namespace class CpuAllGather final : public AllGather { public: OF_DISALLOW_COPY_AND_MOVE(CpuAllGather); CpuAllGather() : datatype_(kInvalidDataType) {} ~CpuAllGather() = default; void Init(DataType datatype) override { this->datatype_ = datatype; } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communication_ctx) const override { const auto& cpu_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cpu_communication_ctx); CHECK_JUST(AllGatherImpl(in, out, elem_cnt, datatype_, cpu_communication_ctx->parallel_desc())); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const override { UNIMPLEMENTED(); } private: DataType datatype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, AllGather, CpuAllGather); } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp ================================================ /* Copyright 2020 The 
OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h" #include "oneflow/user/kernels/collective_communication/include/all_reduce.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h" namespace oneflow { namespace ccl { namespace { template struct AllReduceImpl final { static Maybe Call(const void* void_in, void* void_out, size_t elem_cnt, Symbol parallel_desc) { int64_t parallel_num = parallel_desc->parallel_num(); if (parallel_num == 1) { if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); } return Maybe::Ok(); } const T* in = reinterpret_cast(void_in); T* out = reinterpret_cast(void_out); BalancedSplitter bs(elem_cnt, parallel_num); auto recv_buffer = std::make_unique(bs.At(0).size()); Optional parallel_id; JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, ¶llel_id)); const auto& rank_group = JUST(RankGroup::New(parallel_desc)); TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); for (int64_t i = 0, part_id = JUST(parallel_id); i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) { int64_t send_part_id = part_id; const T* send_ptr = nullptr; if (i == 0) { send_ptr = &in[bs.At(send_part_id).begin()]; } else { send_ptr = &out[bs.At(send_part_id).begin()]; } size_t send_size = bs.At(send_part_id).size(); int64_t recv_part_id = RingDecrease(part_id, parallel_num); T* recv_ptr = recv_buffer.get(); size_t recv_size = bs.At(recv_part_id).size(); NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = const_cast(send_ptr); *size = send_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_ptr; *size = recv_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }); if (send_size > 0) { JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); } if (recv_size > 0) { JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); } JUST(ctx.WaitDone()); const T* cur_in = &in[bs.At(recv_part_id).begin()]; T* cur_out = &out[bs.At(recv_part_id).begin()]; if (recv_size > 0) { ReduceFunctor::Call(recv_size, cur_out, cur_in, recv_ptr); } } for (int64_t i = 0, part_id = RingIncrease(JUST(parallel_id), parallel_num); i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) { int64_t send_part_id = part_id; const T* send_ptr = &out[bs.At(send_part_id).begin()]; size_t send_size = bs.At(send_part_id).size(); int64_t recv_part_id = RingDecrease(part_id, parallel_num); T* recv_ptr = &out[bs.At(recv_part_id).begin()]; size_t recv_size = bs.At(recv_part_id).size(); NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, 
std::function* Cb) -> Maybe { *buffer = const_cast(send_ptr); *size = send_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_ptr; *size = recv_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }); if (send_size > 0) { JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); } if (recv_size > 0) { JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); } JUST(ctx.WaitDone()); } return Maybe::Ok(); } }; #define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name::Call DEFINE_STATIC_SWITCH_FUNC(Maybe, AllReduceImpl, MAKE_ALL_REDUCE_ENTRY, // NOLINT MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), // NOLINT REDUCE_TYPE_CTRV_SEQ); // NOLINT #undef MAKE_ALL_REDUCE_ENTRY } // namespace class CpuAllReduce final : public AllReduce { public: OF_DISALLOW_COPY_AND_MOVE(CpuAllReduce); CpuAllReduce() : datatype_(kInvalidDataType), reduce_type_(kInvalidReduceFunctorType) {} ~CpuAllReduce() = default; void Init(DataType datatype, ReduceType reduce_type) override { this->datatype_ = datatype; this->reduce_type_ = reduce_type; } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communication_ctx) const override { const auto& cpu_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cpu_communication_ctx) << kOfBugIssueUploadPrompt; CHECK_JUST(SwitchAllReduceImpl(SwitchCase(datatype_, reduce_type_), in, out, elem_cnt, cpu_communication_ctx->parallel_desc())); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const override { UNIMPLEMENTED(); } private: DataType datatype_; ReduceType reduce_type_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, AllReduce, CpuAllReduce); } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_broadcast.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
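// Illustrative sketch, not from the OneFlow sources above: the CPU all-gather/all-reduce files
// move data around a ring; at step i each rank forwards the chunk it obtained at step i-1 to its
// right neighbour, using the RingDecrease index walk. This in-memory simulation (all ranks in one
// process, one int per chunk) reproduces that walk and shows that after parallel_num - 1 steps
// every rank holds every chunk. In the real code the copy is the SendToNextRankInRing /
// ReceiveFromPrevRankInRing pair on a NaiveAsyncTransportCtx.
#include <cstdint>
#include <vector>

inline int64_t RingDec(int64_t n, int64_t size) { return (n - 1 + size) % size; }

// out[rank] is rank's output buffer with one slot per partition; on entry only
// out[rank][rank] is valid (the rank's own chunk).
inline void SimulateRingAllGather(std::vector<std::vector<int>>& out) {
  const int64_t parallel_num = static_cast<int64_t>(out.size());
  for (int64_t step = 0; step < parallel_num - 1; ++step) {
    for (int64_t rank = 0; rank < parallel_num; ++rank) {
      // partition sent by `rank` at this step: RingDec applied `step` times to `rank`
      int64_t send_part = rank;
      for (int64_t s = 0; s < step; ++s) { send_part = RingDec(send_part, parallel_num); }
      const int64_t next_rank = (rank + 1) % parallel_num;
      out[next_rank][send_part] = out[rank][send_part];
    }
  }
}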
*/ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ccl/ccl.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h" #include "oneflow/user/kernels/collective_communication/include/broadcast.h" namespace oneflow { namespace ccl { // Use CpuBroadcastImpl to avoid name conflict class CpuBroadcastImpl final : public Broadcast { public: OF_DISALLOW_COPY_AND_MOVE(CpuBroadcastImpl); CpuBroadcastImpl() : size_of_dtype_(0) {} ~CpuBroadcastImpl() = default; void Init(DataType datatype) override { CHECK(IsTriviallyCopyableDataType(datatype)); this->size_of_dtype_ = GetSizeOfDataType(datatype); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root, const std::shared_ptr& communication_ctx) const override { const auto& cpu_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cpu_communication_ctx); size_t buffer_size = elem_cnt * size_of_dtype_; const auto& transport_token = CHECK_JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); CHECK_JUST(CpuBroadcast(in, out, buffer_size, root, cpu_communication_ctx->parallel_desc(), transport_token)); } private: size_t size_of_dtype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Broadcast, CpuBroadcastImpl); } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COLLECTIVE_COMMUNICATION_UTIL_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COLLECTIVE_COMMUNICATION_UTIL_H_ #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/common/balanced_splitter.h" namespace oneflow { namespace ccl { inline int64_t RingDecrease(int64_t n, int64_t size) { return (n - 1 + size) % size; } inline int64_t RingIncrease(int64_t n, int64_t size) { return (n + 1 + size) % size; } template struct ReduceFunctor; template struct ReduceFunctor { static void Call(size_t size, T* out, const T* in0, const T* in1) { size_t thread_num = Singleton::Get()->thread_num(); BalancedSplitter bs(size, thread_num); MultiThreadLoop(thread_num, [&](size_t thread_idx) { size_t end = bs.At(thread_idx).end(); for (size_t i = bs.At(thread_idx).begin(); i < end; ++i) { out[i] = in0[i] + in1[i]; } }); } }; template struct ReduceFunctor { static void Call(size_t size, T* out, const T* in0, const T* in1) { size_t thread_num = Singleton::Get()->thread_num(); BalancedSplitter bs(size, thread_num); MultiThreadLoop(thread_num, [&](size_t thread_idx) { size_t end = bs.At(thread_idx).end(); for (size_t i = bs.At(thread_idx).begin(); i < end; ++i) { out[i] = std::max(in0[i], in1[i]); } }); } }; } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COLLECTIVE_COMMUNICATION_UTIL_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h" #include "oneflow/core/job/parallel_desc.h" namespace oneflow { namespace ccl { void CpuCommunicationContext::Init(Symbol parallel_desc) { parallel_desc_ = parallel_desc; } REGISTER_COLLECTIVE_COMMUNICATION_COMMUNICATOR(DeviceType::kCPU, CpuCommunicationContext); } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
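// Illustrative sketch, not from the OneFlow sources above: ReduceFunctor in the header splits the
// element range across worker threads with BalancedSplitter + MultiThreadLoop. The same idea with
// plain std::thread, splitting [0, n) into near-equal contiguous ranges; ParallelElementwiseSum is
// a hypothetical name.
#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

inline void ParallelElementwiseSum(std::size_t n, float* out, const float* in0, const float* in1,
                                   std::size_t thread_num) {
  thread_num = std::max<std::size_t>(1, std::min(thread_num, n == 0 ? 1 : n));
  std::vector<std::thread> workers;
  workers.reserve(thread_num);
  const std::size_t base = n / thread_num;
  const std::size_t rem = n % thread_num;  // the first `rem` ranges get one extra element
  std::size_t begin = 0;
  for (std::size_t t = 0; t < thread_num; ++t) {
    const std::size_t end = begin + base + (t < rem ? 1 : 0);
    workers.emplace_back([=]() {
      for (std::size_t i = begin; i < end; ++i) { out[i] = in0[i] + in1[i]; }
    });
    begin = end;
  }
  for (auto& w : workers) { w.join(); }
}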
*/ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COMMUNICATION_CONTEXT_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COMMUNICATION_CONTEXT_H_ #include "oneflow/user/kernels/collective_communication/include/communication_context.h" #include "oneflow/core/common/symbol.h" namespace oneflow { class ParallelDesc; namespace ccl { class CpuCommunicationContext : public CommunicationContext { public: explicit CpuCommunicationContext() = default; ~CpuCommunicationContext() override = default; void Init(Symbol) override; Symbol parallel_desc() const { return parallel_desc_; } private: Symbol parallel_desc_; }; } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COMMUNICATION_CONTEXT_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ccl/ccl.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/user/kernels/collective_communication/include/recv.h" namespace oneflow { namespace ccl { // Use CpuRecvImpl to avoid name conflict class CpuRecvImpl final : public Recv { public: OF_DISALLOW_COPY_AND_MOVE(CpuRecvImpl); CpuRecvImpl() : size_of_dtype_(0) {} ~CpuRecvImpl() = default; void Init(DataType datatype) override { CHECK(IsTriviallyCopyableDataType(datatype)); this->size_of_dtype_ = GetSizeOfDataType(datatype); } void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override { size_t buffer_size = elem_cnt * size_of_dtype_; CHECK_JUST(CpuRecv(out, buffer_size, src)); } void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, const ccl::CclComm& ccl_comm) const override { Launch(stream, out, elem_cnt, src); } private: size_t size_of_dtype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Recv, CpuRecvImpl); } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_reduce.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/data_type.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h" #include "oneflow/user/kernels/collective_communication/include/reduce.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h" namespace oneflow { namespace ccl { namespace { template struct ReduceImpl final { static Maybe Call(const void* void_in, void* void_out, size_t elem_cnt, int64_t root, Symbol parallel_desc) { const T* in = reinterpret_cast(void_in); T* out = reinterpret_cast(void_out); int64_t parallel_num = parallel_desc->parallel_num(); BalancedSplitter bs(elem_cnt, parallel_num); size_t size = root == GlobalProcessCtx::Rank() && void_in != void_out ? 0 : bs.At(0).size(); T* tmp_out = nullptr; // void_out is only used on rank root and ignored for other ranks. auto tmp_out_buffer = std::make_unique(size); int64_t parallel_id_of_root = JUST(parallel_desc->ParallelId4MachineDeviceId(root, GlobalProcessCtx::LocalRank(root))); if (root == GlobalProcessCtx::Rank() && void_in != void_out) { tmp_out = &reinterpret_cast(void_out)[bs.At(parallel_id_of_root).begin()]; } else { tmp_out = tmp_out_buffer.get(); } auto recv_buffer = std::make_unique(bs.At(0).size()); Optional parallel_id; JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, ¶llel_id)); const auto& rank_group = JUST(RankGroup::New(parallel_desc)); TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); for (int64_t i = 0, part_id = RingDecrease(JUST(parallel_id), parallel_num); i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) { int64_t send_part_id = part_id; const T* send_ptr = nullptr; if (i == 0) { send_ptr = &in[bs.At(send_part_id).begin()]; } else { send_ptr = tmp_out; } size_t send_size = bs.At(send_part_id).size(); int64_t recv_part_id = RingDecrease(part_id, parallel_num); T* recv_ptr = recv_buffer.get(); size_t recv_size = bs.At(recv_part_id).size(); NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = const_cast(send_ptr); *size = send_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_ptr; *size = recv_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }); if (send_size > 0) { JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); } if (recv_size > 0) { JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); } JUST(ctx.WaitDone()); const T* cur_in = &in[bs.At(recv_part_id).begin()]; if (recv_size > 0) { ReduceFunctor::Call(recv_size, tmp_out, cur_in, recv_ptr); } } if (root == GlobalProcessCtx::Rank() && void_in == void_out) { memcpy(&out[bs.At(parallel_id_of_root).begin()], tmp_out, bs.At(parallel_id_of_root).size() * sizeof(T)); } for (int64_t i = 0, part_id = RingIncrease(parallel_id_of_root, parallel_num); i < parallel_num - 1; ++i, part_id = RingIncrease(part_id, parallel_num)) { int64_t send_part_id = part_id; int64_t src_rank = JUST(parallel_desc->MachineId4ParallelId(send_part_id)); const T* send_ptr = tmp_out; size_t send_size = bs.At(send_part_id).size(); int64_t recv_part_id = part_id; T* recv_ptr = &out[bs.At(recv_part_id).begin()]; size_t recv_size = bs.At(recv_part_id).size(); if (send_size > 0 && src_rank == 
GlobalProcessCtx::Rank()) { NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = const_cast(send_ptr); *size = send_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { UNIMPLEMENTED_THEN_RETURN(); }); JUST(TransportUtil::SendDataToRank(root, transport_token, &ctx)); JUST(ctx.WaitDone()); } if (recv_size > 0 && root == GlobalProcessCtx::Rank()) { NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { UNIMPLEMENTED_THEN_RETURN(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_ptr; *size = recv_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }); JUST(TransportUtil::ReceiveDataFromRank(src_rank, transport_token, &ctx)); JUST(ctx.WaitDone()); } } return Maybe::Ok(); } }; #define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name::Call DEFINE_STATIC_SWITCH_FUNC(Maybe, ReduceImpl, MAKE_ALL_REDUCE_ENTRY, // NOLINT MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), // NOLINT REDUCE_TYPE_CTRV_SEQ); // NOLINT #undef MAKE_ALL_REDUCE_ENTRY } // namespace class CpuReduce final : public Reduce { public: OF_DISALLOW_COPY_AND_MOVE(CpuReduce); CpuReduce() : datatype_(kInvalidDataType), reduce_type_(kInvalidReduceFunctorType) {} ~CpuReduce() = default; void Init(DataType datatype, ReduceType reduce_type) override { this->datatype_ = datatype; this->reduce_type_ = reduce_type; } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root, const std::shared_ptr& communication_ctx) const override { const auto& cpu_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cpu_communication_ctx) << kOfBugIssueUploadPrompt; CHECK_JUST(SwitchReduceImpl(SwitchCase(datatype_, reduce_type_), in, out, elem_cnt, root, cpu_communication_ctx->parallel_desc())); } private: DataType datatype_; ReduceType reduce_type_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Reduce, CpuReduce); } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
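// Illustrative sketch, not from the OneFlow sources above: DEFINE_STATIC_SWITCH_FUNC in the files
// above generates a runtime switch over (DataType, ReduceType) that forwards to the matching
// template instantiation. A hand-written miniature of that dispatch, simplified so that only the
// element type is a template parameter (the real macro also fixes the reduce op at compile time);
// MiniDType/MiniReduce/SwitchReduce are hypothetical names.
#include <algorithm>
#include <cstddef>
#include <cstdint>

enum class MiniDType { kFloat, kInt32 };
enum class MiniReduce { kSum, kMax };

template<typename T>
void ReduceArraysT(MiniReduce op, std::size_t n, T* out, const T* in0, const T* in1) {
  if (op == MiniReduce::kSum) {
    for (std::size_t i = 0; i < n; ++i) { out[i] = in0[i] + in1[i]; }
  } else {
    for (std::size_t i = 0; i < n; ++i) { out[i] = std::max(in0[i], in1[i]); }
  }
}

inline void SwitchReduce(MiniDType dtype, MiniReduce op, std::size_t n, void* out, const void* in0,
                         const void* in1) {
  switch (dtype) {  // erase the runtime dtype into a typed call
    case MiniDType::kFloat:
      ReduceArraysT<float>(op, n, static_cast<float*>(out), static_cast<const float*>(in0),
                           static_cast<const float*>(in1));
      break;
    case MiniDType::kInt32:
      ReduceArraysT<int32_t>(op, n, static_cast<int32_t*>(out), static_cast<const int32_t*>(in0),
                             static_cast<const int32_t*>(in1));
      break;
  }
}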
*/ #include "oneflow/core/common/data_type.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h" #include "oneflow/user/kernels/collective_communication/include/reduce_scatter.h" #include "oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h" namespace oneflow { namespace ccl { namespace { template struct ReduceScatterImpl final { static Maybe Call(const void* void_in, void* void_out, size_t elem_cnt, Symbol parallel_desc) { int64_t parallel_num = parallel_desc->parallel_num(); if (parallel_num == 1) { if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); } return Maybe::Ok(); } const T* in = reinterpret_cast(void_in); T* out = reinterpret_cast(void_out); BalancedSplitter bs(elem_cnt * parallel_num, parallel_num); const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc)); CHECK_OR_RETURN(opt_parallel_id->has_value()) << kOfBugIssueUploadPrompt; int64_t parallel_id = JUST(*opt_parallel_id); auto recv_buffer = std::make_unique(bs.At(0).size()); const auto& rank_group = JUST(RankGroup::New(parallel_desc)); TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); for (int64_t i = 0, part_id = RingDecrease(parallel_id, parallel_num); i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) { int64_t send_part_id = part_id; const T* send_ptr = nullptr; if (i == 0) { send_ptr = &in[bs.At(send_part_id).begin()]; } else { send_ptr = out; } size_t send_size = bs.At(send_part_id).size(); int64_t recv_part_id = RingDecrease(part_id, parallel_num); T* recv_ptr = recv_buffer.get(); size_t recv_size = bs.At(recv_part_id).size(); NaiveAsyncTransportCtx ctx( transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = const_cast(send_ptr); *size = send_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { *buffer = recv_ptr; *size = recv_size * sizeof(T); *Cb = [] {}; return Maybe::Ok(); }); if (send_size > 0) { JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); } if (recv_size > 0) { JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); } JUST(ctx.WaitDone()); const T* cur_in = &in[bs.At(recv_part_id).begin()]; if (recv_size > 0) { ReduceFunctor::Call(recv_size, out, cur_in, recv_ptr); } } return Maybe::Ok(); } }; #define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name::Call DEFINE_STATIC_SWITCH_FUNC(Maybe, ReduceScatterImpl, MAKE_ALL_REDUCE_ENTRY, // NOLINT MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), // NOLINT REDUCE_TYPE_CTRV_SEQ); // NOLINT #undef MAKE_ALL_REDUCE_ENTRY } // namespace class CpuReduceScatter final : public ReduceScatter { public: OF_DISALLOW_COPY_AND_MOVE(CpuReduceScatter); CpuReduceScatter() : datatype_(kInvalidDataType), reduce_type_(kInvalidReduceFunctorType) {} ~CpuReduceScatter() = default; void Init(DataType datatype, ReduceType reduce_type) override { this->datatype_ = datatype; this->reduce_type_ = reduce_type; } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communication_ctx) const override { const auto& cpu_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cpu_communication_ctx) << kOfBugIssueUploadPrompt; CHECK_JUST(SwitchReduceScatterImpl(SwitchCase(datatype_, 
reduce_type_), in, out, elem_cnt, cpu_communication_ctx->parallel_desc())); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const override { UNIMPLEMENTED(); } private: DataType datatype_; ReduceType reduce_type_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, ReduceScatter, CpuReduceScatter); } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ccl/ccl.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/transport_util.h" #include "oneflow/user/kernels/collective_communication/include/send.h" namespace oneflow { namespace ccl { // Use CpuSendImpl to avoid name conflict class CpuSendImpl final : public Send { public: OF_DISALLOW_COPY_AND_MOVE(CpuSendImpl); CpuSendImpl() : size_of_dtype_(0) {} ~CpuSendImpl() = default; void Init(DataType datatype) override { CHECK(IsTriviallyCopyableDataType(datatype)); this->size_of_dtype_ = GetSizeOfDataType(datatype); } void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override { size_t buffer_size = elem_cnt * size_of_dtype_; CHECK_JUST(CpuSend(in, buffer_size, dst)); } void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, const ccl::CclComm& comm) const override { Launch(stream, in, elem_cnt, dst); } private: size_t size_of_dtype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Send, CpuSendImpl); } // namespace ccl } // namespace oneflow ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/all_gather.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { namespace ccl { class CudaAllGather final : public AllGather { public: OF_DISALLOW_COPY_AND_MOVE(CudaAllGather); CudaAllGather() : nccl_datatype_() {} ~CudaAllGather() = default; void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communication_ctx) const override { const auto& cuda_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cuda_communication_ctx) << kOfBugIssueUploadPrompt; OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, cuda_communication_ctx->nccl_comm(), stream->As()->cuda_stream())); } virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, *nccl_comm, stream->As()->cuda_stream())); } private: ncclDataType_t nccl_datatype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllGather, CudaAllGather); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
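// Illustrative sketch, not from the OneFlow sources above: the CUDA all-gather backend is a thin
// wrapper over ncclAllGather. Assuming an already-initialized ncclComm_t and a CUDA stream, the
// underlying call looks like this; elem_cnt is the per-rank element count and the receive buffer
// must hold elem_cnt * nranks elements. LaunchAllGatherF32 is a hypothetical helper name.
#include <cuda_runtime.h>
#include <nccl.h>

inline ncclResult_t LaunchAllGatherF32(const float* send, float* recv, size_t elem_cnt,
                                       ncclComm_t comm, cudaStream_t stream) {
  return ncclAllGather(send, recv, elem_cnt, ncclFloat, comm, stream);
}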
*/ #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/all_reduce.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { namespace ccl { namespace { inline ncclRedOp_t GetNcclReduceType(ReduceType reduce_type) { switch (reduce_type) { #define NCCL_REDUCE_TYPE_CASE(dtype) \ case ReduceType::k##dtype: return ncclRedOp_t::nccl##dtype NCCL_REDUCE_TYPE_CASE(Sum); NCCL_REDUCE_TYPE_CASE(Max); default: PRINT_BUG_PROMPT_AND_ABORT(); } } } // namespace class CudaAllReduce final : public AllReduce { public: OF_DISALLOW_COPY_AND_MOVE(CudaAllReduce); CudaAllReduce() : nccl_datatype_(), nccl_reduce_op_() {} ~CudaAllReduce() = default; void Init(DataType datatype, ReduceType reduce_type) override { this->nccl_datatype_ = GetNcclDataType(datatype); this->nccl_reduce_op_ = GetNcclReduceType(reduce_type); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communication_ctx) const override { const auto& cuda_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cuda_communication_ctx); OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, cuda_communication_ctx->nccl_comm(), stream->As()->cuda_stream())); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm, stream->As()->cuda_stream())); } private: ncclDataType_t nccl_datatype_; ncclRedOp_t nccl_reduce_op_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllReduce, CudaAllReduce); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/common/device_type.h" namespace oneflow { namespace ccl { class CudaAllToAll final : public AllToAll { public: OF_DISALLOW_COPY_AND_MOVE(CudaAllToAll); CudaAllToAll() : send_dtype_(), recv_dtype_(), nccl_send_dtype_(), nccl_recv_dtype_(), rank_count_(0) {} ~CudaAllToAll() = default; void Init(DataType send_dtype, DataType recv_dtype, size_t parallel_num) override { this->send_dtype_ = send_dtype; this->recv_dtype_ = recv_dtype; this->nccl_send_dtype_ = GetNcclDataType(send_dtype); this->nccl_recv_dtype_ = GetNcclDataType(recv_dtype); this->rank_count_ = parallel_num; } void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, int64_t recv_count, const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); int64_t send_offset = 0; int64_t recv_offset = 0; OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < this->rank_count_; ++i) { if (send_count > 0) { char* send_ptr = static_cast(send) + send_offset; OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, stream->As()->cuda_stream())); } send_offset += send_count * GetSizeOfDataType(this->send_dtype_); if (recv_count) { char* recv_ptr = static_cast(recv) + recv_offset; OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, stream->As()->cuda_stream())); } recv_offset += recv_count * GetSizeOfDataType(this->recv_dtype_); } OF_NCCL_CHECK(ncclGroupEnd()); } void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, void* recv, const void* recv_counts, const void* recv_offsets, const ccl::CclComm& ccl_comm, const bool has_input, const bool has_output) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); int64_t* send_counts_ptr = static_cast(const_cast(send_counts)); int64_t* recv_counts_ptr = static_cast(const_cast(recv_counts)); int64_t* send_offsets_ptr = static_cast(const_cast(send_offsets)); int64_t* recv_offsets_ptr = static_cast(const_cast(recv_offsets)); if (has_input || has_output) { OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < this->rank_count_; ++i) { if (has_input) { const uint64_t send_count = static_cast(send_counts_ptr[i]); if (send_count > 0) { uint64_t send_offset = static_cast(send_offsets_ptr[i]); char* send_ptr = static_cast(send) + send_offset; OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, stream->As()->cuda_stream())); } } if (has_output) { const uint64_t recv_count = static_cast(recv_counts_ptr[i]); if (recv_count > 0) { uint64_t recv_offset = static_cast(recv_offsets_ptr[i]); char* recv_ptr = static_cast(recv) + recv_offset; OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, stream->As()->cuda_stream())); } } } OF_NCCL_CHECK(ncclGroupEnd()); } } private: DataType send_dtype_; DataType recv_dtype_; ncclDataType_t nccl_send_dtype_; ncclDataType_t nccl_recv_dtype_; size_t rank_count_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllToAll, CudaAllToAll); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_broadcast.cpp ================================================ /* 
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/broadcast.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { namespace ccl { class CudaBroadcast final : public Broadcast { public: OF_DISALLOW_COPY_AND_MOVE(CudaBroadcast); CudaBroadcast() : nccl_datatype_() {} ~CudaBroadcast() = default; void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root, const std::shared_ptr& communication_ctx) const override { const auto& cuda_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cuda_communication_ctx); OF_NCCL_CHECK(ncclBroadcast( in, out, elem_cnt, nccl_datatype_, cuda_communication_ctx->nccl_index4rank(root), cuda_communication_ctx->nccl_comm(), stream->As()->cuda_stream())); } private: ncclDataType_t nccl_datatype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Broadcast, CudaBroadcast); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
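// Illustrative sketch, not from the OneFlow sources above: CudaAllToAll posts one ncclSend and one
// ncclRecv per peer inside an ncclGroupStart/ncclGroupEnd section so all transfers proceed
// concurrently (this requires NCCL >= 2.7). A minimal fixed-count variant for float data, assuming
// an initialized comm of size nranks; LaunchAllToAllF32 is a hypothetical helper name.
#include <cuda_runtime.h>
#include <nccl.h>

inline ncclResult_t LaunchAllToAllF32(const float* send, float* recv, size_t count_per_rank,
                                      int nranks, ncclComm_t comm, cudaStream_t stream) {
  ncclResult_t ret = ncclGroupStart();
  if (ret != ncclSuccess) { return ret; }
  for (int peer = 0; peer < nranks; ++peer) {
    // each rank sends its peer-th slice and receives into its peer-th slice
    ret = ncclSend(send + peer * count_per_rank, count_per_rank, ncclFloat, peer, comm, stream);
    if (ret != ncclSuccess) { return ret; }
    ret = ncclRecv(recv + peer * count_per_rank, count_per_rank, ncclFloat, peer, comm, stream);
    if (ret != ncclSuccess) { return ret; }
  }
  return ncclGroupEnd();
}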
*/ #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #ifdef WITH_CUDA namespace oneflow { namespace ccl { void CudaCommunicationContext::Init(Symbol parallel_desc) { std::set> device_set; FOR_RANGE(int64_t, parallel_id, 0, parallel_desc->parallel_num()) { int64_t machine_id = CHECK_JUST(parallel_desc->MachineId4ParallelId(parallel_id)); int64_t device_id = CHECK_JUST(parallel_desc->DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); rank2nccl_index_.emplace(machine_id, parallel_id); } nccl_comm_ = CHECK_NOTNULL(Singleton::Get()) ->As() ->GetCommForDevice(device_set); } REGISTER_COLLECTIVE_COMMUNICATION_COMMUNICATOR(DeviceType::kCUDA, CudaCommunicationContext); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_COMMUNICATION_CONTEXT_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_COMMUNICATION_CONTEXT_H_ #include "oneflow/user/kernels/collective_communication/include/communication_context.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/job/parallel_desc.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace ccl { class CudaCommunicationContext : public CommunicationContext { public: explicit CudaCommunicationContext() = default; ~CudaCommunicationContext() override = default; void Init(Symbol) override; ncclComm_t nccl_comm() const { return nccl_comm_; } int64_t nccl_index4rank(int rank) const { return rank2nccl_index_.at(rank); } private: ncclComm_t nccl_comm_; HashMap rank2nccl_index_; }; } // namespace ccl } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_COMMUNICATION_CONTEXT_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/recv.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { namespace ccl { class CudaRecv final : public Recv { public: OF_DISALLOW_COPY_AND_MOVE(CudaRecv); CudaRecv() : nccl_datatype_() {} ~CudaRecv() = default; void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); } void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override { #if HAS_NCCL_SEND_RECV const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src); OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, comm_and_peer_rank.second, comm_and_peer_rank.first, stream->As()->cuda_stream())); #else UNIMPLEMENTED() << "GPU recv is only supported when nccl version >= 2.7" #endif // HAS_NCCL_SEND_RECV } void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, const ccl::CclComm& ccl_comm) const override { #if HAS_NCCL_SEND_RECV ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, src, *comm, stream->As()->cuda_stream())); #else UNIMPLEMENTED() << "GPU recv is only supported when nccl version >= 2.7" #endif // HAS_NCCL_SEND_RECV } private: ncclDataType_t nccl_datatype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Recv, CudaRecv); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_reduce.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/reduce.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { namespace ccl { namespace { inline ncclRedOp_t GetNcclReduceType(ReduceType reduce_type) { switch (reduce_type) { #define NCCL_REDUCE_TYPE_CASE(dtype) \ case ReduceType::k##dtype: return ncclRedOp_t::nccl##dtype NCCL_REDUCE_TYPE_CASE(Sum); NCCL_REDUCE_TYPE_CASE(Max); default: PRINT_BUG_PROMPT_AND_ABORT(); } } } // namespace class CudaReduce final : public Reduce { public: OF_DISALLOW_COPY_AND_MOVE(CudaReduce); CudaReduce() : nccl_datatype_(), nccl_reduce_op_() {} ~CudaReduce() = default; void Init(DataType datatype, ReduceType reduce_type) override { this->nccl_datatype_ = GetNcclDataType(datatype); this->nccl_reduce_op_ = GetNcclReduceType(reduce_type); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root, const std::shared_ptr& communication_ctx) const override { const auto& cuda_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cuda_communication_ctx) << kOfBugIssueUploadPrompt; OF_NCCL_CHECK(ncclReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, cuda_communication_ctx->nccl_index4rank(root), cuda_communication_ctx->nccl_comm(), stream->As()->cuda_stream())); } private: ncclDataType_t nccl_datatype_; ncclRedOp_t nccl_reduce_op_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Reduce, CudaReduce); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/reduce_scatter.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { namespace ccl { namespace { inline ncclRedOp_t GetNcclReduceType(ReduceType reduce_type) { switch (reduce_type) { #define NCCL_REDUCE_TYPE_CASE(dtype) \ case ReduceType::k##dtype: return ncclRedOp_t::nccl##dtype NCCL_REDUCE_TYPE_CASE(Sum); NCCL_REDUCE_TYPE_CASE(Max); default: PRINT_BUG_PROMPT_AND_ABORT(); } } } // namespace class CudaReduceScatter final : public ReduceScatter { public: OF_DISALLOW_COPY_AND_MOVE(CudaReduceScatter); CudaReduceScatter() : nccl_datatype_(), nccl_reduce_op_() {} ~CudaReduceScatter() = default; void Init(DataType datatype, ReduceType reduce_type) override { this->nccl_datatype_ = GetNcclDataType(datatype); this->nccl_reduce_op_ = GetNcclReduceType(reduce_type); } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communication_ctx) const override { const auto& cuda_communication_ctx = std::dynamic_pointer_cast(communication_ctx); CHECK(cuda_communication_ctx) << kOfBugIssueUploadPrompt; OF_NCCL_CHECK(ncclReduceScatter(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, cuda_communication_ctx->nccl_comm(), stream->As()->cuda_stream())); } virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclReduceScatter(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm, stream->As()->cuda_stream())); } private: ncclDataType_t nccl_datatype_; ncclRedOp_t nccl_reduce_op_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, ReduceScatter, CudaReduceScatter); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { namespace ccl { class CudaSend final : public Send { public: OF_DISALLOW_COPY_AND_MOVE(CudaSend); CudaSend() : nccl_datatype_() {} ~CudaSend() = default; void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); } void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override { #if HAS_NCCL_SEND_RECV const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst); OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, comm_and_peer_rank.second, comm_and_peer_rank.first, stream->As()->cuda_stream())); #else UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7" #endif // HAS_NCCL_SEND_RECV } void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, const ccl::CclComm& ccl_comm) const override { #if HAS_NCCL_SEND_RECV ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, dst, *comm, stream->As()->cuda_stream())); #else UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7" #endif // HAS_NCCL_SEND_RECV } private: ncclDataType_t nccl_datatype_; }; REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Send, CudaSend); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/common/decorator.h" #ifdef WITH_CUDA #include "oneflow/core/job/eager_nccl_comm_manager.h" namespace oneflow { namespace ccl { std::pair RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) { std::set> device_set; const int64_t& rank = GlobalProcessCtx::Rank(); const int64_t peer_nccl_rank = (peer_process_id > rank) ? 1 : 0; device_set.emplace(rank, GlobalProcessCtx::LocalRank()); device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id)); return {CHECK_NOTNULL(Singleton::Get()) ->As() ->GetCommForDevice(device_set), peer_nccl_rank}; } decltype(GetNcclCommAndPeerNcclRank) GetNcclCommAndPeerNcclRank = DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_ #ifdef WITH_CUDA #include "oneflow/core/device/nccl_util.h" namespace oneflow { namespace ccl { extern std::pair (*GetNcclCommAndPeerNcclRank)(int64_t peer_process_i); } // namespace ccl } // namespace oneflow #endif // WITH_CUDA #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/all_gather.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_GATHER_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_GATHER_H_ #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { namespace ccl { class AllGather : public CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(AllGather); AllGather() = default; ~AllGather() override = default; virtual void Init(DataType dtype) = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communicator) const = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsAllGatherRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_GATHER_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/all_reduce.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_REDUCE_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_REDUCE_H_ #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { namespace ccl { class AllReduce : public CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(AllReduce); AllReduce() = default; ~AllReduce() override = default; virtual void Init(DataType dtype, ReduceType reduce_type) = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communicator) const = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsAllReduceRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_REDUCE_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/all_to_all.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_ #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { namespace ccl { class AllToAll : public CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(AllToAll); AllToAll() = default; ~AllToAll() override = default; virtual void Init(DataType send_dtype, DataType recv_dtype, size_t rank_count) = 0; // for normal alltoall(balanced send/resv count) virtual void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, int64_t recv_count, const ccl::CclComm& ccl_comm) const = 0; // for unbalanced all to all(e.g. nccl all2all using send/recv; hccl HcclAlltoAllV) virtual void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, void* recv, const void* recv_counts, const void* recv_offsets, const ccl::CclComm& ccl_comm, const bool has_input, const bool has_output) const = 0; }; inline bool IsAllToAllRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/broadcast.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_BROADCAST_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_BROADCAST_H_ #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { namespace ccl { class Broadcast : public CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(Broadcast); Broadcast() = default; ~Broadcast() override = default; virtual void Init(DataType dtype) = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root, const std::shared_ptr& communicator) const = 0; }; inline bool IsBroadcastRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_BROADCAST_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/collective_communication.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_ #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/user/kernels/collective_communication/include/communication_context.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { namespace ccl { #define REDUCE_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(kSum) \ OF_PP_MAKE_TUPLE_SEQ(kMax) enum ReduceType { kInvalidReduceFunctorType = 0, #define DEFINE_REDUCE_TYPE_ENUM_VALUE(enum_value) enum_value, OF_PP_FOR_EACH_TUPLE(DEFINE_REDUCE_TYPE_ENUM_VALUE, REDUCE_TYPE_SEQ) #undef DEFINE_REDUCE_TYPE_ENUM_VALUE kReduceTypeSize }; #define REDUCE_TYPE_CTRV_SEQ \ MAKE_TYPED_CTRV_SEQ(ReduceType, \ OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, REDUCE_TYPE_SEQ)) // abstruct base class for comm class CommBase { public: virtual ~CommBase() = default; // return impl of comm virtual void* getComm() const = 0; }; class CclComm { public: CclComm() {} explicit CclComm(std::shared_ptr comm) : comm_(std::move(comm)) {} void* getComm() const { return comm_->getComm(); } private: std::shared_ptr comm_{}; }; class CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveCommunication); CollectiveCommunication() = default; virtual ~CollectiveCommunication() = default; }; template static std::unique_ptr NewCollectiveCommunication( DeviceType device_type, Args&&... 
args) { std::unique_ptr collective_communication_entry = NewObjUniquePtr(device_type); if (!collective_communication_entry) { return nullptr; } collective_communication_entry->Init(std::forward(args)...); return collective_communication_entry; } #define REGISTER_COLLECTIVE_COMMUNICATION(device, Base, Derived) \ REGISTER_CLASS(DeviceType, device, Base, Derived) } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/communication_context.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_ #include "collective_communication.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/auto_registration_factory.h" namespace oneflow { namespace ccl { class CommunicationContext { public: CommunicationContext() = default; virtual ~CommunicationContext() = default; virtual void Init(Symbol) = 0; }; inline std::shared_ptr NewCommunicationContext( DeviceType device_type, Symbol parallel_desc) { CHECK_EQ(device_type, parallel_desc->device_type()) << "device_type not match placement (" << DeviceType_Name(device_type) << " vs. " << DeviceType_Name(parallel_desc->device_type()) << ". " << kOfBugIssueUploadPrompt; std::shared_ptr communication_ctx = std::shared_ptr(NewObj(device_type)); communication_ctx->Init(parallel_desc); return communication_ctx; } inline bool IsCommunicationContextRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } #define REGISTER_COLLECTIVE_COMMUNICATION_COMMUNICATOR(device, Derived) \ REGISTER_CLASS(DeviceType, device, CommunicationContext, Derived) } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/recv.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
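// A sketch (not repository code) of how the registry above is consumed from a kernel:
// NewCollectiveCommunication<> picks the backend registered for the DeviceType, forwards the
// remaining arguments to Init(), and the caller then launches on its ep::Stream. This
// mirrors the Send/Recv usage in oneflow/user/kernels/communicate_util.cpp further below;
// the function name and arguments here are illustrative only.
#include "oneflow/user/kernels/collective_communication/include/all_reduce.h"

namespace oneflow {
namespace ccl {

inline void AllReduceSumSketch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
                               DataType dtype, DeviceType device_type,
                               const std::shared_ptr<CommunicationContext>& comm_ctx) {
  std::unique_ptr<AllReduce> all_reduce =
      NewCollectiveCommunication<AllReduce>(device_type, dtype, kSum);
  CHECK(all_reduce) << "no AllReduce registered for this device type";
  all_reduce->Launch(stream, in, out, elem_cnt, comm_ctx);
}

}  // namespace ccl
}  // namespace oneflow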
*/ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECVH_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECVH_ #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { namespace ccl { class Recv : public CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(Recv); Recv() = default; ~Recv() override = default; virtual void Init(DataType dtype) = 0; virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0; virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsRecvRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECVH_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/reduce.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_H_ #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { namespace ccl { class Reduce : public CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(Reduce); Reduce() = default; ~Reduce() override = default; virtual void Init(DataType dtype, ReduceType reduce_type) = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root, const std::shared_ptr& communicator) const = 0; }; inline bool IsReduceRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/reduce_scatter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_SCATTER_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_SCATTER_H_ #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { namespace ccl { class ReduceScatter : public CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(ReduceScatter); ReduceScatter() = default; ~ReduceScatter() override = default; virtual void Init(DataType dtype, ReduceType reduce_type) = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communicator) const = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsReduceScatterRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_SCATTER_H_ ================================================ FILE: oneflow/user/kernels/collective_communication/include/send.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_ #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { namespace ccl { class Send : public CollectiveCommunication { public: OF_DISALLOW_COPY_AND_MOVE(Send); Send() = default; ~Send() override = default; virtual void Init(DataType dtype) = 0; virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const = 0; virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsSendRegistered(DeviceType device_type) { return IsClassRegistered(device_type); } } // namespace ccl } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_ ================================================ FILE: oneflow/user/kernels/combined_margin_loss_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/math_unary_elementwise_func.h" namespace oneflow { namespace { class CombinedMarginLossOpKernelCache final : public user_op::OpKernelCache { public: CombinedMarginLossOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} ~CombinedMarginLossOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } private: const int64_t lower_; const int64_t upper_; }; std::shared_ptr CreateCombinedMarginLossOpKernelCache( user_op::KernelCacheContext* ctx, const std::string& in_arg_name) { if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; } const SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(in_arg_name, 0); if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == 1 && ctx->parallel_ctx().parallel_num() > 1) { CHECK(ctx->SbpParallel4ArgNameAndIndex("label", 0).has_broadcast_parallel()); const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex(in_arg_name, 0); const auto depth = ctx->Attr("depth"); CHECK_EQ(depth, in_logical_desc->shape().At(1)); BalancedSplitter bs(depth, ctx->parallel_ctx().parallel_num()); return std::make_shared( bs.At(ctx->parallel_ctx().parallel_id()).begin(), bs.At(ctx->parallel_ctx().parallel_id()).end()); } else { return nullptr; } } } // namespace template class CombinedMarginLossCpuKernel final : public user_op::OpKernel { public: CombinedMarginLossCpuKernel() = default; ~CombinedMarginLossCpuKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCombinedMarginLossOpKernelCache(ctx, "x"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const T* x_ptr = x->dptr(); const K* label_ptr = ctx->Tensor4ArgNameAndIndex("label", 0)->dptr(); T* y_ptr = ctx->Tensor4ArgNameAndIndex("y", 0)->mut_dptr(); T* theta_ptr = ctx->Tensor4ArgNameAndIndex("theta", 0)->mut_dptr(); const float m1 = ctx->Attr("m1"); const float m2 = ctx->Attr("m2"); const float m3 = ctx->Attr("m3"); int64_t lower_bound = 0; if (cache != nullptr) { auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(x->shape_view().Count(1), kernel_cache->upper() - kernel_cache->lower()); lower_bound = kernel_cache->lower(); } const int64_t num_classes = x->shape_view().Count(1); FOR_RANGE(int32_t, i, 0, x->shape_view().elem_cnt()) { const int32_t row_id = i / num_classes; const int32_t col_id = i - row_id * num_classes; const T in_data = x_ptr[i]; T out_data = in_data; K label = label_ptr[row_id] - lower_bound; if (label == col_id) { const T theta_data = AcosFunctor::Forward(in_data); out_data = CosFunctor::Forward(theta_data * static_cast(m1) + static_cast(m2)) - static_cast(m3); theta_ptr[row_id] = theta_data; } else if ((label < 0 || label >= num_classes) && col_id == 0) { theta_ptr[row_id] = 0; } y_ptr[i] = out_data; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_COMBINED_MARGIN_LOSS_CPU_KERNEL(in_type, indices_type) \ REGISTER_USER_KERNEL("combined_margin_loss") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(in_type)) \ && 
(user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(indices_type))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COMBINED_MARGIN_LOSS_CPU_KERNEL, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class CombinedMarginLossGradCpuKernel final : public user_op::OpKernel { public: CombinedMarginLossGradCpuKernel() = default; ~CombinedMarginLossGradCpuKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCombinedMarginLossOpKernelCache(ctx, "dy"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const T* dy_ptr = dy->dptr(); const K* label_ptr = ctx->Tensor4ArgNameAndIndex("label", 0)->dptr(); const T* theta_ptr = ctx->Tensor4ArgNameAndIndex("theta", 0)->dptr(); T* dx_ptr = ctx->Tensor4ArgNameAndIndex("dx", 0)->mut_dptr(); const float m1 = ctx->Attr("m1"); const float m2 = ctx->Attr("m2"); int64_t lower_bound = 0; if (cache != nullptr) { auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(dy->shape_view().Count(1), kernel_cache->upper() - kernel_cache->lower()); lower_bound = kernel_cache->lower(); } const int64_t num_classes = dy->shape_view().Count(1); FOR_RANGE(int32_t, i, 0, dy->shape_view().elem_cnt()) { const int32_t row_id = i / num_classes; const int32_t col_id = i - row_id * num_classes; K label = label_ptr[row_id] - lower_bound; const T dy_data = dy_ptr[i]; const T theta_data = theta_ptr[row_id]; T dx_data = dy_data; if (label == col_id) { dx_data = dy_data * SinFunctor::Forward(theta_data * static_cast(m1) + static_cast(m2)) * static_cast(m1) / SinFunctor::Forward(theta_data); } dx_ptr[i] = dx_data; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_COMBINED_MARGIN_LOSS_GRAD_CPU_KERNEL(dy_type, indices_type) \ REGISTER_USER_KERNEL("combined_margin_loss_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dy", 0) == OF_PP_PAIR_SECOND(dy_type)) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(indices_type))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COMBINED_MARGIN_LOSS_GRAD_CPU_KERNEL, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/combined_margin_loss_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/math_unary_elementwise_func.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void GpuForward(const int64_t n, const int64_t num_classes, const int64_t lower_bound, const T m1, const T m2, const T m3, const T* in, const K* labels, T* out, T* theta) { CUDA_1D_KERNEL_LOOP(i, n) { const int32_t row_id = i / num_classes; const int32_t col_id = i - row_id * num_classes; const T in_data = in[i]; T out_data = in_data; K label = labels[row_id] - lower_bound; if (is_cosine_loss) { if (label == col_id) { out_data = in_data - m3; } } else { if (label == col_id) { const T theta_data = AcosFunctor::Forward(in_data); out_data = CosFunctor::Forward(theta_data * m1 + m2) - m3; theta[row_id] = theta_data; } else if ((label < 0 || label >= num_classes) && col_id == 0) { theta[row_id] = 0; } } out[i] = out_data; } } template __global__ void GpuBackward(const int64_t n, const int64_t num_classes, const int64_t lower_bound, const T m1, const T m2, const T m3, const T* dy, const K* labels, const T* theta, T* dx) { CUDA_1D_KERNEL_LOOP(i, n) { const int32_t row_id = i / num_classes; const int32_t col_id = i - row_id * num_classes; K label = labels[row_id] - lower_bound; const T dy_data = dy[i]; const T theta_data = theta[row_id]; T dx_data = dy_data; if (label == col_id && !is_cosine_loss) { dx_data = dy_data * SinFunctor::Forward(theta_data * m1 + m2) * m1 / SinFunctor::Forward(theta_data); } dx[i] = dx_data; } } class CombinedMarginLossOpKernelCache final : public user_op::OpKernelCache { public: CombinedMarginLossOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} ~CombinedMarginLossOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } private: const int64_t lower_; const int64_t upper_; }; std::shared_ptr CreateCombinedMarginLossOpKernelCache( user_op::KernelCacheContext* ctx, const std::string& in_arg_name) { if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; } const SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(in_arg_name, 0); if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == 1 && ctx->parallel_ctx().parallel_num() > 1) { CHECK(ctx->SbpParallel4ArgNameAndIndex("label", 0).has_broadcast_parallel()); const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex(in_arg_name, 0); const auto depth = ctx->Attr("depth"); CHECK_EQ(depth, in_logical_desc->shape().At(1)); BalancedSplitter bs(depth, ctx->parallel_ctx().parallel_num()); return std::make_shared( bs.At(ctx->parallel_ctx().parallel_id()).begin(), bs.At(ctx->parallel_ctx().parallel_id()).end()); } else { return nullptr; } } } // namespace template class CombinedMarginLossGpuKernel final : public user_op::OpKernel { public: CombinedMarginLossGpuKernel() = default; ~CombinedMarginLossGpuKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCombinedMarginLossOpKernelCache(ctx, "x"); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* y = 
ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* theta = ctx->Tensor4ArgNameAndIndex("theta", 0); const float m1 = ctx->Attr("m1"); const float m2 = ctx->Attr("m2"); const float m3 = ctx->Attr("m3"); int64_t lower_bound = 0; if (cache != nullptr) { auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(x->shape_view().Count(1), kernel_cache->upper() - kernel_cache->lower()); lower_bound = kernel_cache->lower(); } if (m1 == 1.0 && m2 == 0.0) { GpuForward <<shape_view().elem_cnt()), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>( x->shape_view().elem_cnt(), x->shape_view().Count(1), lower_bound, static_cast(m1), static_cast(m2), static_cast(m3), x->dptr(), label->dptr(), y->mut_dptr(), theta->mut_dptr()); } else { GpuForward <<shape_view().elem_cnt()), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>( x->shape_view().elem_cnt(), x->shape_view().Count(1), lower_bound, static_cast(m1), static_cast(m2), static_cast(m3), x->dptr(), label->dptr(), y->mut_dptr(), theta->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_COMBINED_MARGIN_LOSS_CUDA_KERNEL(in_type, indices_type) \ REGISTER_USER_KERNEL("combined_margin_loss") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(indices_type))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COMBINED_MARGIN_LOSS_CUDA_KERNEL, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class CombinedMarginLossGradGpuKernel final : public user_op::OpKernel { public: CombinedMarginLossGradGpuKernel() = default; ~CombinedMarginLossGradGpuKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCombinedMarginLossOpKernelCache(ctx, "dy"); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* theta = ctx->Tensor4ArgNameAndIndex("theta", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const float m1 = ctx->Attr("m1"); const float m2 = ctx->Attr("m2"); const float m3 = ctx->Attr("m3"); int64_t lower_bound = 0; if (cache != nullptr) { auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(dy->shape_view().Count(1), kernel_cache->upper() - kernel_cache->lower()); lower_bound = kernel_cache->lower(); } if (m1 == 1.0 && m2 == 0.0) { GpuBackward <<shape_view().elem_cnt()), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>( dy->shape_view().elem_cnt(), dy->shape_view().Count(1), lower_bound, static_cast(m1), static_cast(m2), static_cast(m3), dy->dptr(), label->dptr(), theta->dptr(), dx->mut_dptr()); } else { GpuBackward <<shape_view().elem_cnt()), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>( dy->shape_view().elem_cnt(), dy->shape_view().Count(1), lower_bound, static_cast(m1), static_cast(m2), static_cast(m3), dy->dptr(), label->dptr(), theta->dptr(), dx->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_COMBINED_MARGIN_LOSS_GRAD_CUDA_KERNEL(dy_type, indices_type) \ REGISTER_USER_KERNEL("combined_margin_loss_grad") \ 
.SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == OF_PP_PAIR_SECOND(dy_type)) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(indices_type))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COMBINED_MARGIN_LOSS_GRAD_CUDA_KERNEL, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/communicate_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/communicate_util.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/include/recv.h" namespace oneflow { namespace { const void** ThreadLocalSrcDataPtr() { static thread_local const void* data_ptr = nullptr; return &data_ptr; } } // namespace bool IsSendAndRecvRegistered(DeviceType device_type) { return ccl::IsSendRegistered(device_type) && ccl::IsRecvRegistered(device_type); } Maybe Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, DeviceType device_type, ep::Stream* stream) { if (GlobalProcessCtx::Rank() == dst) { auto** src_data_ptr = ThreadLocalSrcDataPtr(); CHECK_OR_RETURN(*src_data_ptr == nullptr); *src_data_ptr = in; } else { std::unique_ptr send = ccl::NewCollectiveCommunication(device_type, dtype); send->Launch(stream, in, elem_cnt, dst); } return Maybe::Ok(); } Maybe Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, DeviceType device_type, ep::Stream* stream) { if (GlobalProcessCtx::Rank() == src) { size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype); auto** src_data_ptr = ThreadLocalSrcDataPtr(); const void* in = *src_data_ptr; CHECK_OR_RETURN(*src_data_ptr != nullptr); std::unique_ptr memcpy_primitive = ep::primitive::NewPrimitive(device_type, ep::primitive::MemcpyKind::kDtoD); CHECK(memcpy_primitive) << "Can not create Memcpy primitive for device type " << device_type; memcpy_primitive->Launch(stream, out, in, buffer_size); *src_data_ptr = nullptr; } else { std::unique_ptr recv = ccl::NewCollectiveCommunication(device_type, dtype); recv->Launch(stream, out, elem_cnt, src); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/communicate_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
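// Send/Recv in communicate_util.cpp above special-case the "peer is myself" path: Send parks
// the input pointer in a thread-local slot and the matching Recv on the same thread consumes
// it with a device memcpy instead of a collective call. A simplified self-contained model of
// that rendezvous (not repository code; std::memcpy stands in for the ep::primitive::Memcpy
// kDtoD launch):
#include <cassert>
#include <cstddef>
#include <cstring>

namespace self_rank_demo {

inline const void*& ThreadLocalSrcSlot() {
  static thread_local const void* slot = nullptr;
  return slot;
}

inline void SendToSelf(const void* in) {
  assert(ThreadLocalSrcSlot() == nullptr);  // the previous self-send must have been consumed
  ThreadLocalSrcSlot() = in;
}

inline void RecvFromSelf(void* out, size_t bytes) {
  const void* in = ThreadLocalSrcSlot();
  assert(in != nullptr);        // SendToSelf must run first on this thread
  std::memcpy(out, in, bytes);  // the kernel copies device memory via the Memcpy primitive
  ThreadLocalSrcSlot() = nullptr;
}

}  // namespace self_rank_demo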
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_COMMUNICATE_UTIL_H_ #define ONEFLOW_USER_KERNELS_COMMUNICATE_UTIL_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/user_op_kernel_registry.h" namespace oneflow { bool IsSendAndRecvRegistered(DeviceType device_type); ALWAYS_INLINE inline auto HobIsSendAndRecvRegistered() { return hob::make_custom("HobIsSendAndRecvRegistered", [](const user_op::KernelRegContext& ctx) { return IsSendAndRecvRegistered(ctx.device_type()); }); } // Send data from in to rank dst, if cur rank equal dst, memcopy will happen. // Rank dst needs to call Recv with the same datatype and the same count from this rank. Maybe Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, DeviceType device_type, ep::Stream* stream); // Receive data from rank src into out, if cur rank equal src, memcopy will happen. // Rank src needs to call Send with the same datatype and the same count to this rank. Maybe Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, DeviceType device_type, ep::Stream* stream); } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_COMMUNICATE_UTIL_H_ ================================================ FILE: oneflow/user/kernels/complex_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/shape_view.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/elementwise_unary.h" #include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/unary_op.h" #include "oneflow/user/kernels/elementwise_primitive_kernel.h" #include #ifdef WITH_CUDA #include #endif // WITH_CUDA namespace oneflow { namespace user_op { #define COMPLEX_UNARY_ELEMENTWISE_PRIMITIVE_SEQ \ OF_PP_MAKE_TUPLE_SEQ("conj_physical", ep::primitive::UnaryOp::kConj) \ OF_PP_MAKE_TUPLE_SEQ("real", ep::primitive::UnaryOp::kReal) \ OF_PP_MAKE_TUPLE_SEQ("imag", ep::primitive::UnaryOp::kImag) #define COMPLEX_UNARY_GRAD_ELEMENTWISE_PRIMITIVE_SEQ \ OF_PP_MAKE_TUPLE_SEQ("real_grad", ep::primitive::UnaryOp::kRealGrad) \ OF_PP_MAKE_TUPLE_SEQ("imag_grad", ep::primitive::UnaryOp::kImagGrad) #define REGISTER_COMPLEX_KERNEL(name, UnaryOp) \ REGISTER_USER_KERNEL(name) \ .SetCreateFn([]() { \ return user_op::NewOpKernel( \ "out", "x", [](user_op::KernelComputeContext* ctx) { \ const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); \ const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); \ return ep::primitive::NewPrimitive( \ ctx->device_type(), UnaryOp, src->data_type(), dst->data_type()); \ }); \ }) \ .SetIsMatchedHob(UnaryPrimitiveExists(UnaryOp, "out", "x")); OF_PP_FOR_EACH_TUPLE(REGISTER_COMPLEX_KERNEL, COMPLEX_UNARY_ELEMENTWISE_PRIMITIVE_SEQ) #define REGISTER_COMPLEX_GRAD_KERNEL(name, UnaryOp) \ REGISTER_USER_KERNEL(name) \ .SetCreateFn([]() { \ return user_op::NewOpKernel( \ "dx", "dout", [](user_op::KernelComputeContext* ctx) { \ const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); \ const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dout", 0); \ return ep::primitive::NewPrimitive( \ ctx->device_type(), UnaryOp, src->data_type(), dst->data_type()); \ }); \ }) \ .SetIsMatchedHob(UnaryPrimitiveExists(UnaryOp, "dx", "dout")); OF_PP_FOR_EACH_TUPLE(REGISTER_COMPLEX_GRAD_KERNEL, COMPLEX_UNARY_GRAD_ELEMENTWISE_PRIMITIVE_SEQ) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/concat_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template std::unique_ptr NewCopyNdPrimitive(Context* ctx) { return ep::primitive::NewPrimitive(ctx->device_type(), 2); } class ConcatKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ConcatKernel() = default; ~ConcatKernel() = default; private: void InferShape(user_op::KernelInferContext* ctx) const override { const int64_t axis = ctx->Attr("axis"); DimVector dim_vec; for (const auto& in_arg_pair : ctx->inputs()) { const ShapeView& input_shape_view = ctx->ShapeView4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second); if (dim_vec.size() == 0) { input_shape_view.ToDimVector(&dim_vec); } else { CHECK_EQ(input_shape_view.NumAxes(), dim_vec.size()); FOR_RANGE(int64_t, i, 0, input_shape_view.NumAxes()) { if (i == axis) { dim_vec.at(i) += input_shape_view.At(i); } else { CHECK_EQ(input_shape_view.At(i), dim_vec.at(i)); } } } } ctx->MutShapeView4ArgNameAndIndex("out", 0).set_shape(Shape(dim_vec)); } void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); if (out_tensor->shape_view().elem_cnt() == 0) { return; } const int64_t axis = ctx->Attr("axis"); const int64_t out_cols = out_tensor->shape_view().Count(axis); const int64_t rows = out_tensor->shape_view().elem_cnt() / out_cols; CHECK_GT(rows, 0); auto primitive = NewCopyNdPrimitive(ctx); CHECK(primitive); int64_t out_col_offset = 0; for (const auto& in_arg_pair : ctx->inputs()) { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second); if (in_tensor->shape_view().elem_cnt() == 0) { continue; } const int64_t in_cols = in_tensor->shape_view().Count(axis); CHECK_EQ(in_tensor->shape_view().elem_cnt(), rows * in_cols); if (in_cols > 0) { DimVector dst_shape = {rows, out_cols}; DimVector dst_pos_vec = {0, out_col_offset}; DimVector src_shape = {rows, in_cols}; DimVector src_pos_vec = {0, 0}; DimVector extent_vec = {rows, in_cols}; primitive->Launch(ctx->stream(), out_tensor->data_type(), 2, out_tensor->mut_dptr(), dst_shape.data(), dst_pos_vec.data(), in_tensor->dptr(), src_shape.data(), src_pos_vec.data(), extent_vec.data()); } out_col_offset += in_cols; } CHECK_EQ(out_col_offset, out_cols); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto CopyNdPrimitiveExists() { return hob::make_custom("CopyNdPrimitiveExists", [](const user_op::KernelRegContext& ctx) -> bool { return NewCopyNdPrimitive(&ctx).operator bool(); }); } } // namespace REGISTER_USER_KERNEL("cat").SetCreateFn().SetIsMatchedHob(CopyNdPrimitiveExists() == true); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/constant_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/fill.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewFillPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } class ConstantKernel final : public OpKernel { public: ConstantKernel() = default; ~ConstantKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); bool is_complex_value = ctx->Attr("is_complex_value"); bool is_floating_value = ctx->Attr("is_floating_value"); const Scalar value = is_complex_value ? Scalar(ctx->Attr>("complex_value")) : (is_floating_value ? Scalar(ctx->Attr("floating_value")) : Scalar(ctx->Attr("integer_value"))); const int64_t elem_cnt = out_tensor->shape_view().elem_cnt(); CHECK_GE(elem_cnt, 0); if (elem_cnt == 0) { return; } std::unique_ptr fill = NewFillPrimitive(ctx); CHECK(fill); fill->Launch(ctx->stream(), out_tensor->mut_dptr(), value, elem_cnt); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto FillPrimitiveExists() { return hob::make_custom("FillPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewFillPrimitive(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("constant") .SetCreateFn() .SetIsMatchedHob(FillPrimitiveExists() == true); } // namespace } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/conv_cudnn_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cudnn_conv_util.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace { template struct CudnnConvArgsAndAlgo final { using AlgoT = decltype(std::declval().algo); CudnnConvArgs args; PerfT algo_perf; CudnnConvArgsAndAlgo(const user_op::Tensor* x, const user_op::Tensor* w, const user_op::Tensor* y, user_op::Tensor* buf, const user_op::KernelComputeContext* ctx, ep::Stream* stream, bool has_forced_algo, int32_t forced_algo) : args(*ctx, x->data_type(), x->shape_view(), w->data_type(), w->shape_view(), y->data_type(), y->shape_view(), ctx->Attr("data_format"), buf->shape_view().elem_cnt(), Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_heuristic_search_algo() || (!LazyMode::is_enabled()), Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_use_deterministic_algo_only(), Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_enable_pseudo_half() || (ctx->Attr("data_format") == "channels_last" && std::is_same::value)) { size_t byte_size_of_buf = buf->shape_view().elem_cnt(); AllocatedCudnnConvResource res(stream->As()->cudnn_handle(), const_cast(x->dptr()), const_cast(w->dptr()), const_cast(y->dptr()), buf->mut_dptr()); if (has_forced_algo) { algo_perf = GetCudnnConvAlgorithmPerferenceWithResource( &args, &res, static_cast(forced_algo)); } else { algo_perf = FindCudnnConvAlgorithmWithResource(&args, &res); } CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS) << "op (" << ctx->op_name() << ") find algorithm perference failed. algo: " << algo_perf.algo; CHECK_LE(algo_perf.memory, byte_size_of_buf) << "op (" << ctx->op_name() << ") find algorithm " << algo_perf.algo << ", need memory " << algo_perf.memory << ", but cudnn_buf_limit_byte is " << byte_size_of_buf; OF_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.Get(), algo_perf.mathType)); } CudnnConvArgsAndAlgo() = delete; OF_DISALLOW_COPY_AND_MOVE(CudnnConvArgsAndAlgo); }; template size_t InferTmpSizeWithCudnn(const user_op::TensorDesc* x, const user_op::TensorDesc* w, const user_op::TensorDesc* y, const user_op::InferContext& ctx, bool has_forced_algo, int32_t forced_algo) { using AlgoT = decltype(std::declval().algo); const auto& cudnn_conf = Singleton::Get()->resource().cudnn_conf(); size_t workspace_size = cudnn_conf.cudnn_buf_limit_mbyte() * 1024 * 1024; if (!x->is_dynamic()) { CudnnConvArgs args(ctx, x->data_type(), ShapeView(x->shape()), w->data_type(), ShapeView(w->shape()), y->data_type(), ShapeView(y->shape()), ctx.Attr("data_format"), workspace_size, cudnn_conf.cudnn_conv_heuristic_search_algo() || (!LazyMode::is_enabled()), cudnn_conf.cudnn_conv_use_deterministic_algo_only(), cudnn_conf.cudnn_conv_enable_pseudo_half() || (ctx.Attr("data_format") == "channels_last" && std::is_same::value)); PerfT algo_perf{}; if (has_forced_algo) { algo_perf = GetCudnnConvAlgorithmPerference(&args, static_cast(forced_algo)); } else { algo_perf = FindCudnnConvAlgorithm(&args); } CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS) << "op (" << ctx.op_name() << ") find algorithm perference failed. 
algo: " << algo_perf.algo; CHECK_LE(algo_perf.memory, workspace_size) << "op (" << ctx.op_name() << ") find algorithm " << algo_perf.algo << ", need memory " << algo_perf.memory << ", but cudnn_buf_limit_byte is " << workspace_size; workspace_size = algo_perf.memory; } workspace_size = std::max(size_t(1), workspace_size); return workspace_size; } // for 1d and 2d template CudnnTensorDesc* GetBiasCudnnTensorDesc(const std::string& data_format, int32_t filters, DataType data_type) { if (data_format == "channels_first") { return new CudnnTensorDesc(CUDNN_TENSOR_NCHW, data_type, 1, filters, 1, 1); } else { CHECK_EQ("channels_last", data_format); return new CudnnTensorDesc(CUDNN_TENSOR_NHWC, data_type, 1, filters, 1, 1); } } // for 3d and Nd template<> CudnnTensorDesc* GetBiasCudnnTensorDesc<3>(const std::string& data_format, int32_t filters, DataType data_type) { constexpr int NDims = 3 + 2; CHECK_EQ("channels_first", data_format) << "CUDNN Nd API only support channels first"; std::vector bias_dim(NDims, 1); std::vector stride_of_bias_tensor(NDims, 1); bias_dim[1] = filters; stride_of_bias_tensor[0] = filters; return new CudnnTensorDesc(data_type, NDims, bias_dim.data(), stride_of_bias_tensor.data()); } struct ConvCudnnOpKernelCache final : public user_op::OpKernelCache { std::unique_ptr bias_desc; }; template class ConvGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ConvGpuKernel() = default; ~ConvGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr CreateConvCudnnOpKernelCache( user_op::KernelCacheContext* ctx) const { const auto& data_format = ctx->Attr("data_format"); int32_t filters = ctx->Attr("filters"); std::shared_ptr state(new ConvCudnnOpKernelCache()); const user_op::TensorDesc* bias = ctx->TensorDesc4ArgNameAndIndex("bias", 0); if (bias != nullptr) { state->bias_desc.reset( GetBiasCudnnTensorDesc(data_format, filters, bias->data_type())); } return state; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateConvCudnnOpKernelCache(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); if (in->shape_view().elem_cnt() == 0) return; const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const auto& cudnn_conf = Singleton::Get()->resource().cudnn_conf(); CudnnConvArgsAndAlgo args_and_algo( in, weight, out, buf, ctx, ctx->stream(), cudnn_conf.has_cudnn_conv_force_fwd_algo(), cudnn_conf.cudnn_conv_force_fwd_algo()); const CudnnConvArgs& args = args_and_algo.args; const cudnnConvolutionFwdAlgoPerf_t& algo_perf = args_and_algo.algo_perf; const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); const void* beta = nullptr; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), out->data_type()); CHECK_EQ(add_to_output->shape_view(), out->shape_view()); Memcpy( ctx->stream(), out->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); beta = CudnnSPOnePtr(in->data_type()); } else { beta = CudnnSPZeroPtr(in->data_type()); } 
OF_CUDNN_CHECK(cudnnConvolutionForward( ctx->stream()->As()->cudnn_handle(), CudnnSPOnePtr(in->data_type()), args.xdesc.Get(), in->dptr(), args.wdesc.Get(), weight->dptr(), args.cdesc.Get(), algo_perf.algo, buf->mut_dptr(), args.params.max_ws_size, beta, args.ydesc.Get(), out->mut_dptr())); if (bias != nullptr) { const auto* conv_cache = dynamic_cast(cache); CHECK_NOTNULL(conv_cache); OF_CUDNN_CHECK(cudnnAddTensor(ctx->stream()->As()->cudnn_handle(), CudnnSPOnePtr(in->data_type()), conv_cache->bias_desc->Get(), bias->dptr(), CudnnSPOnePtr(in->data_type()), args.ydesc.Get(), out->mut_dptr())); } } bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, user_op::OpKernelState* state) const override { return Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_heuristic_search_algo(); } }; #define REGISTER_CONV_KERNEL(op_name, ndims) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const auto& in = ctx->InputTensorDesc("in", 0); \ if (in.shape().elem_cnt() == 0) return 0; \ const auto& weight = ctx->InputTensorDesc("weight", 0); \ const auto& out = ctx->OutputTensorDesc("out", 0); \ const auto& cudnn_conf = \ Singleton::Get()->resource().cudnn_conf(); \ return InferTmpSizeWithCudnn( \ &in, &weight, &out, *ctx, cudnn_conf.has_cudnn_conv_force_fwd_algo(), \ cudnn_conf.cudnn_conv_force_fwd_algo()); \ }) \ .SetInplaceProposalFn( \ [](const user_op::InferContext& ctx, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ if (ctx.has_input("_add_to_output", 0)) { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); \ } \ return Maybe::Ok(); \ }); REGISTER_CONV_KERNEL(conv1d, 1); REGISTER_CONV_KERNEL(conv2d, 2); REGISTER_CONV_KERNEL(conv3d, 3); class ConvDataGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: OF_DISALLOW_COPY_AND_MOVE(ConvDataGradGpuKernel); ConvDataGradGpuKernel() = default; ~ConvDataGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); if (dx->shape_view().elem_cnt() == 0) return; user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto& cudnn_conf = Singleton::Get()->resource().cudnn_conf(); CudnnConvArgsAndAlgo args_and_algo( dx, filter, dy, buf, ctx, ctx->stream(), cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), cudnn_conf.cudnn_conv_force_bwd_data_algo()); const CudnnConvArgs& args = args_and_algo.args; const cudnnConvolutionBwdDataAlgoPerf_t& algo_perf = args_and_algo.algo_perf; const void* alpha = CudnnSPOnePtr(dy->data_type()); const void* beta = nullptr; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), dx->data_type()); CHECK_EQ(add_to_output->shape_view(), dx->shape_view()); Memcpy( ctx->stream(), dx->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); beta = CudnnSPOnePtr(dy->data_type()); } else { beta = CudnnSPZeroPtr(dy->data_type()); } OF_CUDNN_CHECK(cudnnConvolutionBackwardData( ctx->stream()->As()->cudnn_handle(), alpha, 
args.wdesc.Get(), filter->dptr(), args.ydesc.Get(), dy->dptr(), args.cdesc.Get(), algo_perf.algo, buf->mut_dptr(), args.params.max_ws_size, beta, args.xdesc.Get(), dx->mut_dptr())); } bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, user_op::OpKernelState* state) const override { return Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_heuristic_search_algo(); } }; REGISTER_USER_KERNEL("conv_data_grad") .SetCreateFn() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { const auto& dy = ctx->InputTensorDesc("dy", 0); const auto& filter = ctx->InputTensorDesc("filter", 0); const auto& dx = ctx->OutputTensorDesc("dx", 0); if (dx.shape().elem_cnt() == 0) return 0; const auto& cudnn_conf = Singleton::Get()->resource().cudnn_conf(); return InferTmpSizeWithCudnn( &dx, &filter, &dy, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), cudnn_conf.cudnn_conv_force_bwd_data_algo()); }) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); class ConvFilterGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: OF_DISALLOW_COPY_AND_MOVE(ConvFilterGradGpuKernel); ConvFilterGradGpuKernel() = default; ~ConvFilterGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex("filter_diff", 0); if (x->shape_view().elem_cnt() == 0) { Memset( ctx->stream(), filter_diff->mut_dptr(), 0, filter_diff->shape_view().elem_cnt() * GetSizeOfDataType(filter_diff->data_type())); return; } user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto& cudnn_conf = Singleton::Get()->resource().cudnn_conf(); CudnnConvArgsAndAlgo args_and_algo( x, filter_diff, dy, buf, ctx, ctx->stream(), cudnn_conf.has_cudnn_conv_force_bwd_filter_algo(), cudnn_conf.cudnn_conv_force_bwd_filter_algo()); const CudnnConvArgs& args = args_and_algo.args; const cudnnConvolutionBwdFilterAlgoPerf_t& algo_perf = args_and_algo.algo_perf; OF_CUDNN_CHECK(cudnnConvolutionBackwardFilter( ctx->stream()->As()->cudnn_handle(), CudnnSPOnePtr(dy->data_type()), args.xdesc.Get(), x->dptr(), args.ydesc.Get(), dy->dptr(), args.cdesc.Get(), algo_perf.algo, buf->mut_dptr(), args.params.max_ws_size, CudnnSPZeroPtr(dy->data_type()), args.wdesc.Get(), filter_diff->mut_dptr())); } bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, user_op::OpKernelState* state) const override { return Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_heuristic_search_algo(); } }; REGISTER_USER_KERNEL("conv_filter_grad") .SetCreateFn() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { const auto& dy = ctx->InputTensorDesc("dy", 0); const auto& x = ctx->InputTensorDesc("x", 0); if (x.shape().elem_cnt() == 0) return 0; const auto& filter_diff = ctx->OutputTensorDesc("filter_diff", 0); const auto& cudnn_conf = Singleton::Get()->resource().cudnn_conf(); return InferTmpSizeWithCudnn( &x, &filter_diff, &dy, *ctx, 
cudnn_conf.has_cudnn_conv_force_bwd_filter_algo(), cudnn_conf.cudnn_conv_force_bwd_filter_algo()); }); struct ConvBiasGradState final : public user_op::OpKernelState { std::unique_ptr bias_diff_desc; }; class ConvBiasGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ConvBiasGradGpuKernel() = default; ~ConvBiasGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr CreateConvBiasGradState( user_op::KernelComputeContext* ctx) const { const auto* bias_diff = ctx->TensorDesc4ArgNameAndIndex("bias_diff", 0); const auto* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const auto& data_format = ctx->Attr("data_format"); std::shared_ptr state(new ConvBiasGradState()); if (data_format == "channels_first") { CHECK_EQ(dy->shape().At(1), bias_diff->shape().At(0)); state->bias_diff_desc.reset( new CudnnTensorDesc(CUDNN_TENSOR_NCHW, bias_diff->data_type(), 1, static_cast(bias_diff->shape().At(0)), 1, 1)); } else { CHECK(data_format == "channels_last") << "Illegal data_format: " << data_format; CHECK_EQ(dy->shape().At(dy->shape().NumAxes() - 1), bias_diff->shape().At(0)); state->bias_diff_desc.reset( new CudnnTensorDesc(CUDNN_TENSOR_NHWC, bias_diff->data_type(), 1, static_cast(bias_diff->shape().At(0)), 1, 1)); } return state; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* bias_diff = ctx->Tensor4ArgNameAndIndex("bias_diff", 0); CHECK_EQ(bias_diff->shape_view().NumAxes(), 1); CHECK_GE(dy->shape_view().NumAxes(), 3); CHECK_LE(dy->shape_view().NumAxes(), 5); const std::string& data_format = ctx->Attr("data_format"); std::unique_ptr dy_desc; dy_desc.reset(new CudnnTensorDesc(dy->data_type(), dy->shape_view(), data_format)); const auto& bias_grad_state = CreateConvBiasGradState(ctx); CHECK_NOTNULL(bias_grad_state.get()); OF_CUDNN_CHECK(cudnnConvolutionBackwardBias( ctx->stream()->As()->cudnn_handle(), CudnnSPOnePtr(dy->data_type()), dy_desc->Get(), dy->dptr(), CudnnSPZeroPtr(dy->data_type()), bias_grad_state->bias_diff_desc->Get(), bias_diff->mut_dptr())); } }; REGISTER_USER_KERNEL("conv_bias_grad") .SetCreateFn() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA); } // namespace } // namespace oneflow #endif ================================================ FILE: oneflow/user/kernels/conv_cutlass_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUTLASS #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/user/kernels/cutlass_conv_tuner.h" #include #include #include #include namespace oneflow { namespace { class Conv2dCutlassKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: Conv2dCutlassKernel() = default; ~Conv2dCutlassKernel() override = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK(add_to_output == nullptr); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto& padding_before = ctx->Attr>("padding_before"); const auto& dilation_rate = ctx->Attr>("dilation_rate"); const auto& strides = ctx->Attr>("strides"); const int n = in->shape_view().At(0); const int h = in->shape_view().At(1); const int w = in->shape_view().At(2); const int c = in->shape_view().At(3); const int k = weight->shape_view().At(0); const int r = weight->shape_view().At(1); const int s = weight->shape_view().At(2); CHECK_EQ(weight->shape_view().At(3), c); const int p = out->shape_view().At(1); const int q = out->shape_view().At(2); auto* stream = ctx->stream()->As(); cutlass::library::ConvFunctionalKey key( cutlass::library::Provider::kCUTLASS, cutlass::library::ConvKind::kFprop, cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC, cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC, cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC, cutlass::library::NumericTypeID::kF32, cutlass::library::NumericTypeID::kF32); const bool allow_half_accumulation = ParseBooleanFromEnv("ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION", false); if (allow_half_accumulation) { key.element_accumulator = cutlass::library::NumericTypeID::kF16; key.element_compute = cutlass::library::NumericTypeID::kF16; } cutlass::conv::Conv2dProblemSize problem_size( n, h, w, c, k, r, s, p, q, padding_before.at(0), padding_before.at(1), strides.at(0), strides.at(1), dilation_rate.at(0), dilation_rate.at(1), cutlass::conv::Mode::kCrossCorrelation); cutlass::library::Conv2dConfiguration configuraion; configuraion.split_k_mode = cutlass::conv::SplitKMode::kSerial; configuraion.problem_size = problem_size; configuraion.stride_a = {c, w * c, h * w * c}; configuraion.stride_b = {c, s * c, r * s * c}; configuraion.stride_c = {0, 0, 0}; cutlass::library::ConvArguments arguments; arguments.A = in->dptr(); arguments.B = weight->dptr(); arguments.reordered_B = nullptr; if (bias == nullptr) { arguments.C = nullptr; } else { arguments.C = bias->dptr(); } arguments.D = out->mut_dptr(); union SP { float f; half h; }; SP alpha; SP beta; if (allow_half_accumulation) { alpha.h = static_cast(1.0F); if (bias == nullptr) { beta.h = static_cast(0.0F); } else { beta.h = static_cast(1.0F); } } 
else { alpha.f = 1.0F; if (bias == nullptr) { beta.f = 0.0F; } else { beta.f = 1.0F; } } arguments.alpha = &alpha; arguments.beta = &beta; arguments.pointer_mode = cutlass::library::ScalarPointerMode::kHost; const cutlass::library::Operation* operation = nullptr; operation = [&]() -> const cutlass::library::Operation* { const std::string& tuning_cache = ctx->Attr("tuning_cache"); if (tuning_cache.empty()) { return nullptr; } auto tuning_cache_object = nlohmann::json::parse(tuning_cache); if (!tuning_cache_object.is_object()) { return nullptr; } auto it = tuning_cache_object.find("cutlass"); if (it == tuning_cache_object.end()) { return nullptr; } if (!it->is_string()) { return nullptr; } const std::string name = *it; return CutlassConvTuner::Get().GetConv2dOperation(name, stream, key, configuraion, arguments, tmp_buffer->mut_dptr(), tmp_buffer->shape_view().elem_cnt()); }(); if (!operation) { operation = CutlassConvTuner::Get().FindConv2dOperation(stream, key, configuraion, arguments, tmp_buffer->mut_dptr(), tmp_buffer->shape_view().elem_cnt()); } CHECK(operation != nullptr); const size_t host_workspace_size = operation->get_host_workspace_size(&configuraion); std::vector host_workspace(host_workspace_size, 0); auto init_status = operation->initialize(&configuraion, host_workspace.data(), tmp_buffer->mut_dptr(), stream->cuda_stream()); CHECK(init_status == cutlass::Status::kSuccess); auto run_status = operation->run(&arguments, host_workspace.data(), tmp_buffer->mut_dptr(), stream->cuda_stream()); CHECK(run_status == cutlass::Status::kSuccess); } }; REGISTER_USER_KERNEL("conv2d") .SetCreateFn() .SetIsMatchedHob( (user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobAttr("data_format") == "channels_last") && (user_op::HobAttr("groups") == 1) && (user_op::HobDataType("in", 0) == DataType::kFloat16) // Compatible with typo `KERENL` && ((user_op::HobEnvBool("ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL", false) == true) || (user_op::HobEnvBool("ONEFLOW_KERENL_CONV_ENABLE_CUTLASS_IMPL", false) == true))) .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { // use static workspace size return 128 * 1024 * 1024; }) .SetPriority(user_op::kKernelPriorityOptimized); } // namespace } // namespace oneflow #endif // WITH_CUTLASS ================================================ FILE: oneflow/user/kernels/conv_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ?
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewChannelsFirstMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/false); } auto ChannelsFirstMatmulPrimitiveExists() { return hob::make_custom("ChannelsFirstMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewChannelsFirstMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewChannelsLastMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/true); } auto ChannelsLastMatmulPrimitiveExists() { return hob::make_custom("ChannelsLastMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewChannelsLastMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvDataGradTransATransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/true); } auto ConvDataGradTransATransBMatmulPrimitiveExists() { return hob::make_custom("ConvDataGradTransATransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvDataGradTransATransBMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvDataGradTransANoTransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/false); } auto ConvDataGradTransANoTransBMatmulPrimitiveExists() { return hob::make_custom( "ConvDataGradTransANoTransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvDataGradTransANoTransBMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvWeightGradTransATransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/true); } auto ConvWeightGradTransATransBMatmulPrimitiveExists() { return hob::make_custom( "ConvWeightGradTransATransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvWeightGradTransATransBMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvWeightGradNoTransATransBMatmulPrimitive( Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/true); } auto ConvWeightGradNoTransATransBMatmulPrimitiveExists() { return hob::make_custom( "ConvWeightGradNoTransATransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvWeightGradNoTransATransBMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvBiasGradNoTransANoTransBMatmulPrimitive( Context* ctx) { const 
DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/false); } auto ConvBiasGradNoTransANoTransBMatmulPrimitiveExists() { return hob::make_custom( "ConvBiasGradNoTransANoTransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvBiasGradNoTransANoTransBMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvBiasGradTransANoTransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/false); } auto ConvBiasGradTransANoTransBMatmulPrimitiveExists() { return hob::make_custom( "ConvBiasGradTransANoTransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvBiasGradTransANoTransBMatmulPrimitive(&ctx).operator bool(); }); } template using Im2ColFunc = void (*)(const T* in_dptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf); template using Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr); template T* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) { return tensor->mut_dptr() + tensor->shape_view().Count(1) * idx; } template const T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) { return tensor->dptr() + tensor->shape_view().Count(1) * idx; } size_t CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape, const int32_t idx_offset) { int64_t col_buf_elem_cnt = 1; int64_t ndims = out_shape.NumAxes() - 2; for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); } for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); } return col_buf_elem_cnt; } template class ColBufWriter { public: ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : src_ptr_(src_ptr), dst_ptr_(dst_ptr), c_size_(c_size), id_size_(id_size), ih_size_(ih_size), iw_size_(iw_size), od_size_(od_size), oh_size_(oh_size), ow_size_(ow_size) {} virtual ~ColBufWriter() = default; virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; virtual void InvalidDFunc() = 0; virtual void InvalidHFunc() = 0; virtual void InvalidWFunc() = 0; virtual void NextImCSize() = 0; protected: const T* src_ptr_; T* dst_ptr_; int64_t c_size_; int64_t id_size_; int64_t ih_size_; int64_t iw_size_; int64_t od_size_; int64_t oh_size_; int64_t ow_size_; }; template class Im2ColWriter final : public ColBufWriter { public: Im2ColWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : ColBufWriter::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size, oh_size, ow_size) {} ~Im2ColWriter() = default; void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { *(this->dst_ptr_++) = this->src_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c]; } void CDHWWrite(int64_t c, 
int64_t id, int64_t ih, int64_t iw) override { *(this->dst_ptr_++) = this->src_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw]; } void InvalidDFunc() override { FOR_RANGE(int64_t, i, 0, this->od_size_) { *(this->dst_ptr_++) = 0; } } void InvalidHFunc() override { FOR_RANGE(int64_t, i, 0, this->oh_size_) { *(this->dst_ptr_++) = 0; } } void InvalidWFunc() override { FOR_RANGE(int64_t, i, 0, this->ow_size_) { *(this->dst_ptr_++) = 0; } } void NextImCSize() override { this->src_ptr_ += this->c_size_; } }; template class Col2ImWriter final : public ColBufWriter { public: Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : ColBufWriter::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size, oh_size, ow_size) {} ~Col2ImWriter() = default; void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] += *(this->src_ptr_++); } void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++); } void InvalidDFunc() override { this->src_ptr_ += this->od_size_; } void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; } void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; } void NextImCSize() override { this->dst_ptr_ += this->c_size_; } }; template using DHWValidFunc = void (ColBufWriter::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw); template class ColBufUtil final { public: ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before) : strides_(strides), dilation_rate_(dilation_rate), padding_before_(padding_before) { id_num_ = in_shape.At(dhw_offset); ih_num_ = in_shape.At(dhw_offset + 1); iw_num_ = in_shape.At(dhw_offset + 2); od_num_ = out_shape.At(dhw_offset); oh_num_ = out_shape.At(dhw_offset + 1); ow_num_ = out_shape.At(dhw_offset + 2); if (dhw_offset == 2) { dhw_valid_func_ = &ColBufWriter::CDHWWrite; } else { dhw_valid_func_ = &ColBufWriter::DHWCWrite; } } void operator()(ColBufWriter* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) { int64_t id = kd * dilation_rate_[0] - padding_before_[0]; FOR_RANGE(int64_t, od, 0, od_num_) { if (id < 0 || id >= id_num_) { col_buf_writer->InvalidDFunc(); } else { int64_t ih = kh * dilation_rate_[1] - padding_before_[1]; FOR_RANGE(int64_t, oh, 0, oh_num_) { if (ih < 0 || ih >= ih_num_) { col_buf_writer->InvalidHFunc(); } else { int64_t iw = kw * dilation_rate_[2] - padding_before_[2]; FOR_RANGE(int64_t, ow, 0, ow_num_) { if (iw < 0 || iw >= iw_num_) { col_buf_writer->InvalidWFunc(); } else { (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw); } iw += strides_[2]; } } ih += strides_[1]; } } id += strides_[0]; } } private: int64_t id_num_; int64_t ih_num_; int64_t iw_num_; int64_t od_num_; int64_t oh_num_; int64_t ow_num_; const int32_t* strides_; const int32_t* dilation_rate_; const int32_t* padding_before_; DHWValidFunc dhw_valid_func_; }; template struct ConvKernelUtil final { public: static void NCDHWIm2Col(const T* in_dptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, 
padding_before); Im2ColWriter col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer); } static void NDHWCIm2Col(const T* in_dptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before); Im2ColWriter col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), out_shape.Count(3, 4), 1); DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer); } static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before); Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer); } static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before); Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), out_shape.Count(3, 4), 1); DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer); } private: static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, ColBufWriter* col_buf_writer) { for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) { for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(4); ++kw) { col_buf_util(col_buf_writer, c, kd, kh, kw); } } } } } static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, ColBufWriter* col_buf_writer) { for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) { for (int64_t c = 0; c != weight_shape.At(4); ++c) { col_buf_util(col_buf_writer, c, kd, kh, kw); } } } } } }; template struct ConvOpKernelCache final : public user_op::OpKernelCache { Im2ColFunc im2col_func_ = nullptr; Col2ImFunc col2im_func_ = nullptr; Shape in_5d_shape_; Shape out_5d_shape_; Shape weight_5d_shape_; std::vector strides_3d_; std::vector dilation_rate_3d_; std::vector padding_before_3d_; bool is_out_diff_need_trans_ = false; int32_t idx_offset_{}; bool is_dynamic_{}; }; template std::shared_ptr> CreateConvOpKernelCache(user_op::KernelCacheContext* ctx, const std::string& in_name, const std::string& out_name, const std::string& weight_name) { const auto& data_format = ctx->Attr("data_format"); std::shared_ptr> cache(new ConvOpKernelCache()); if (data_format == "channels_first") { cache->im2col_func_ = ConvKernelUtil::NCDHWIm2Col; cache->col2im_func_ = ConvKernelUtil::NCDHWCol2Im; cache->is_out_diff_need_trans_ = false; 
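// Layout note for the cache being built here: every conv (1d/2d/3d) is canonicalized to a
// 3-d spatial problem. Gen5DShape below inserts size-1 dims at idx_offset (2 for
// channels_first NC[D]HW, 1 for channels_last N[D]HWC), so a channels_first conv2d, for
// example, is padded as
//   in     (N, C, H,  W )  -> (N, C, 1, H,  W )
//   weight (K, C, KH, KW)  -> (K, C, 1, KH, KW)
//   out    (N, K, OH, OW)  -> (N, K, 1, OH, OW)
// which lets conv1d/conv2d reuse the 3-d im2col/col2im and matmul code paths unchanged.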
cache->idx_offset_ = 2; } else { cache->im2col_func_ = ConvKernelUtil::NDHWCIm2Col; cache->col2im_func_ = ConvKernelUtil::NDHWCCol2Im; cache->is_out_diff_need_trans_ = true; cache->idx_offset_ = 1; } auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape { DimVector ret_vec(shape.dim_vec()); int32_t ndims = ret_vec.size() - 2; ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); return Shape(ret_vec); }; const auto* in_tensor = ctx->TensorDesc4ArgNameAndIndex(in_name, 0); const auto& in_shape = in_tensor->shape(); cache->in_5d_shape_ = Gen5DShape(in_shape, cache->idx_offset_); cache->out_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), cache->idx_offset_); cache->weight_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), cache->idx_offset_); auto Gen3DVec = [](const std::vector& origin_vec) -> std::vector { std::vector ret_vec = origin_vec; ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1); return ret_vec; }; cache->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); cache->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); cache->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); const auto& padding_before = ctx->Attr>("padding_before"); FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - padding_before.size()); if (index < 0) { cache->padding_before_3d_.emplace_back(0); } else { cache->padding_before_3d_.emplace_back(padding_before.at(index)); } } return cache; } template void InitBiasMulBuf(T* dptr, int64_t num) { for (int64_t i = 0; i < num; ++i) { dptr[i] = 1; } } template class ConvCpuKernel final : public user_op::OpKernel { public: ConvCpuKernel() = default; ~ConvCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateConvOpKernelCache(ctx, "in", "out", "weight"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const auto* conv_cache = dynamic_cast*>(cache); CHECK_NOTNULL(conv_cache); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); T* col_buf_dptr = tmp_buffer->mut_dptr(); bool is_bias_mul_inited = false; const auto& data_format = ctx->Attr("data_format"); std::unique_ptr matmul; if (data_format == "channels_first") { matmul = NewChannelsFirstMatmulPrimitive(ctx); } else { matmul = NewChannelsLastMatmulPrimitive(ctx); } CHECK(matmul); float beta = 0; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), out->data_type()); CHECK_EQ(add_to_output->shape_view(), out->shape_view()); Memcpy( ctx->stream(), out->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); beta = 1; } for (int64_t i = 0; i < in->shape_view().At(0); ++i) { conv_cache->im2col_func_(GetImgDptr(in, i), ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_), ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(), col_buf_dptr); // channels first: 
out = weight * col_buf // channels last: out = (weight * col_buf)(T) int32_t idx_offset = conv_cache->idx_offset_; matmul->Launch(ctx->stream(), conv_cache->weight_5d_shape_.At(0), // filter conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow conv_cache->weight_5d_shape_.Count(1), // ci * kd * kh * kw static_cast(1), weight->dptr(), col_buf_dptr, beta, GetImgMutDptr(out, i)); const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); if (bias != nullptr) { int64_t num_of_col_buf = CalcElemNumOfColBuf(out->shape_view(), weight->shape_view(), idx_offset); int64_t num_of_bias_mul = (tmp_buffer->shape_view().elem_cnt() - num_of_col_buf * sizeof(T)) / sizeof(T); CHECK_GT(num_of_bias_mul, 0); T* bias_mul_dptr = col_buf_dptr + num_of_col_buf; if (!is_bias_mul_inited) { InitBiasMulBuf(bias_mul_dptr, num_of_bias_mul); is_bias_mul_inited = true; } // channels first: out += bias * bias_mul // channels last: out += (bias * bias_mul)(T) matmul->Launch(ctx->stream(), conv_cache->weight_5d_shape_.At(0), // filter conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow 1, // 1 static_cast(1), bias->dptr(), bias_mul_dptr, static_cast(1), GetImgMutDptr(out, i)); } } } }; #define REGISTER_CONV_KERNEL(op_name, dtype, ndims) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("groups") == 1) \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && ChannelsFirstMatmulPrimitiveExists() \ && ChannelsLastMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& out_shape = ctx->OutputTensorDesc("out", 0).shape(); \ const auto& weight_shape = ctx->InputTensorDesc("weight", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ CalcElemNumOfColBuf(out_shape, weight_shape, idx_offset) * sizeof(dtype); \ bool has_bias = ctx->has_input("bias", 0); \ if (has_bias) { \ int64_t bias_mul_cnt = 1; \ for (int i = 0; i < ndims; ++i) { bias_mul_cnt *= out_shape.At(idx_offset + i); } \ tmp_buffer_size += bias_mul_cnt * sizeof(dtype); \ } \ return tmp_buffer_size; \ }) \ .SetInplaceProposalFn( \ [](const user_op::InferContext& ctx, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ if (ctx.has_input("_add_to_output", 0)) { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); \ } \ return Maybe::Ok(); \ }); REGISTER_CONV_KERNEL(conv1d, float, 1); REGISTER_CONV_KERNEL(conv2d, float, 2); REGISTER_CONV_KERNEL(conv3d, float, 3); REGISTER_CONV_KERNEL(conv1d, double, 1); REGISTER_CONV_KERNEL(conv2d, double, 2); REGISTER_CONV_KERNEL(conv3d, double, 3); template class ConvDataGradCpuKernel final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(ConvDataGradCpuKernel); ConvDataGradCpuKernel() = default; ~ConvDataGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateConvOpKernelCache(ctx, "dx", "dy", "filter"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const auto* conv_cache = dynamic_cast*>(cache); CHECK_NOTNULL(conv_cache); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0); user_op::Tensor* 
dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); Memset(ctx->stream(), dx->mut_dptr(), 0, dx->shape_view().elem_cnt() * sizeof(T)); std::unique_ptr matmul; if (conv_cache->is_out_diff_need_trans_) { matmul = NewConvDataGradTransATransBMatmulPrimitive(ctx); } else { matmul = NewConvDataGradTransANoTransBMatmulPrimitive(ctx); } CHECK(matmul); int32_t idx_offset = conv_cache->idx_offset_; FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) { // channels first: col_buf' = weight(T) * out[i]' // channels last : col_buf' = weight(T) * out[i]'(T) matmul->Launch(ctx->stream(), conv_cache->weight_5d_shape_.Count(1), // ci * kd * kh * kw conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow conv_cache->weight_5d_shape_.At(0), // filter static_cast(1), filter->dptr(), GetImgDptr(dy, i), static_cast(0), col_buf->mut_dptr()); // in' = col2im(col_buf') conv_cache->col2im_func_(col_buf->dptr(), ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_), ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(), GetImgMutDptr(dx, i)); } if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), dx->data_type()); CHECK_EQ(add_to_output->shape_view(), dx->shape_view()); std::unique_ptr primitive = ep::primitive::NewPrimitive(DeviceType::kCPU, add_to_output->data_type()); CHECK(primitive); primitive->Launch(ctx->stream(), dx->dptr(), add_to_output->dptr(), dx->mut_dptr(), add_to_output->shape_view().elem_cnt()); } } }; #define REGISTER_CONV_DATA_GRAD_KERNEL(op_name, dtype) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("groups") == 1) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && ConvDataGradTransATransBMatmulPrimitiveExists() \ && ConvDataGradTransANoTransBMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& out_diff_shape = ctx->InputTensorDesc("dy", 0).shape(); \ const auto& weight_shape = ctx->InputTensorDesc("filter", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ CalcElemNumOfColBuf(out_diff_shape, weight_shape, idx_offset) * sizeof(dtype); \ return tmp_buffer_size; \ }) REGISTER_CONV_DATA_GRAD_KERNEL(conv_data_grad, float); REGISTER_CONV_DATA_GRAD_KERNEL(conv_data_grad, double); template class ConvFilterGradCpuKernel final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(ConvFilterGradCpuKernel); ConvFilterGradCpuKernel() = default; ~ConvFilterGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateConvOpKernelCache(ctx, "x", "dy", "filter_diff"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const auto* conv_cache = dynamic_cast*>(cache); CHECK_NOTNULL(conv_cache); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex("filter_diff", 0); user_op::Tensor* col_buf = 
ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); Memset(ctx->stream(), filter_diff->mut_dptr(), 0, filter_diff->shape_view().elem_cnt() * sizeof(T)); std::unique_ptr matmul; if (conv_cache->is_out_diff_need_trans_) { matmul = NewConvWeightGradTransATransBMatmulPrimitive(ctx); } else { matmul = NewConvWeightGradNoTransATransBMatmulPrimitive(ctx); } CHECK(matmul); int32_t idx_offset = conv_cache->idx_offset_; FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) { conv_cache->im2col_func_(GetImgDptr(x, i), ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_), ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(), col_buf->mut_dptr()); // channels first: weight' += out[i]' * col_buf(T) // channels last : weight' += out[i]'(T) * col_buf(T) matmul->Launch(ctx->stream(), conv_cache->weight_5d_shape_.At(0), // filter conv_cache->weight_5d_shape_.Count(1), // ci * kd * kh * kw conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow static_cast(1), GetImgDptr(dy, i), col_buf->dptr(), static_cast(1), filter_diff->mut_dptr()); } } }; #define REGISTER_CONV_FILTER_GRAD_KERNEL(op_name, dtype) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("groups") == 1) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && ConvWeightGradTransATransBMatmulPrimitiveExists() \ && ConvWeightGradNoTransATransBMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& out_diff_shape = ctx->InputTensorDesc("dy", 0).shape(); \ const auto& weight_diff_shape = ctx->OutputTensorDesc("filter_diff", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ CalcElemNumOfColBuf(out_diff_shape, weight_diff_shape, idx_offset) * sizeof(dtype); \ return tmp_buffer_size; \ }) REGISTER_CONV_FILTER_GRAD_KERNEL(conv_filter_grad, float); REGISTER_CONV_FILTER_GRAD_KERNEL(conv_filter_grad, double); template class ConvBiasGradCpuKernel final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(ConvBiasGradCpuKernel); ConvBiasGradCpuKernel() = default; ~ConvBiasGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* bias_diff = ctx->Tensor4ArgNameAndIndex("bias_diff", 0); user_op::Tensor* bias_mul_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); InitBiasMulBuf(bias_mul_buf->mut_dptr(), bias_mul_buf->shape_view().elem_cnt() / sizeof(T)); Memset(ctx->stream(), bias_diff->mut_dptr(), 0, bias_diff->shape_view().elem_cnt() * sizeof(T)); const auto& data_format = ctx->Attr("data_format"); int32_t idx_offset; bool is_out_diff_need_trans = false; int32_t filter; if (data_format == "channels_first") { idx_offset = 2; is_out_diff_need_trans = false; filter = dy->shape_view().At(1); } else { idx_offset = 1; is_out_diff_need_trans = true; filter = dy->shape_view().At(dy->shape_view().NumAxes() - 1); } std::unique_ptr matmul; if (is_out_diff_need_trans) { matmul = NewConvBiasGradTransANoTransBMatmulPrimitive(ctx); } else { matmul = NewConvBiasGradNoTransANoTransBMatmulPrimitive(ctx); } CHECK(matmul); int ndims = dy->shape_view().NumAxes() - 2; FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) { // channels 
first: bias' += out' * bias_mul // channels last: bias' += out'(T) * bias_mul matmul->Launch(ctx->stream(), filter, // filter 1, // 1 dy->shape_view().Count(idx_offset, idx_offset + ndims), // od * oh * ow static_cast(1), GetImgDptr(dy, i), bias_mul_buf->dptr(), static_cast(1), bias_diff->mut_dptr()); } } }; #define REGISTER_CONV_BIAS_GRAD_KERNEL(op_name, dtype) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && ConvBiasGradTransANoTransBMatmulPrimitiveExists() \ && ConvBiasGradNoTransANoTransBMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const auto& out_diff_shape = ctx->InputTensorDesc("dy", 0).shape(); \ const int ndims = out_diff_shape.NumAxes() - 2; \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ int64_t bias_mul_cnt = 1; \ for (int i = 0; i < ndims; ++i) { bias_mul_cnt *= out_diff_shape.At(idx_offset + i); } \ return bias_mul_cnt * sizeof(dtype); \ }) REGISTER_CONV_BIAS_GRAD_KERNEL(conv_bias_grad, float); REGISTER_CONV_BIAS_GRAD_KERNEL(conv_bias_grad, double); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/convert_memory_format_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/convert_memory_format_util.h" namespace oneflow { class ConvertMemoryFormatKernel final : public user_op::OpKernel { public: ConvertMemoryFormatKernel() = default; ~ConvertMemoryFormatKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); ConvertMemoryFormat(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().data(), in->data_type(), in->dptr(), out->mut_dptr(), in->memory_format(), out->memory_format()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("convert_memory_format").SetCreateFn(); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/convert_memory_format_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/convert_memory_format_util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { std::unique_ptr NewPermutePrimitive(DeviceType device_type, const int& num_dims) { return ep::primitive::NewPrimitive(device_type, num_dims); } std::unique_ptr NewMemcpyPrimitive(DeviceType device_type) { return ep::primitive::NewPrimitive( device_type, ep::primitive::MemcpyKind::kDtoD); } void ComputeIdentity(ep::Stream* stream, int ndim, const int64_t* shape, DataType data_type, const void* in, void* out) { size_t count = 1; for (int i = 0; i < ndim; ++i) { count *= shape[i]; } auto memcpy_primitive = NewMemcpyPrimitive(stream->device_type()); CHECK(memcpy_primitive) << "Can not create Memcpy primitive for device type " << stream->device_type(); memcpy_primitive->Launch(stream, out, in, count * GetSizeOfDataType(data_type)); } void ComputeContiguousToChannelsLast(ep::Stream* stream, int ndim, const int64_t* shape, DataType data_type, const void* in, void* out) { if (ndim <= 2) { return ComputeIdentity(stream, ndim, shape, data_type, in, out); } std::vector permute(ndim); permute[0] = 0; permute[ndim - 1] = 1; for (int i = 0; i < ndim - 2; ++i) { permute[i + 1] = i + 2; } auto primitive = NewPermutePrimitive(stream->device_type(), ndim); CHECK_NOTNULL_OR_THROW(primitive); primitive->Launch(stream, data_type, ndim, shape, in, permute.data(), out); } void ComputeChannelsLastToContiguous(ep::Stream* stream, int ndim, const int64_t* shape, DataType data_type, const void* in, void* out) { if (ndim <= 2) { return ComputeIdentity(stream, ndim, shape, data_type, in, out); } std::vector permute(ndim); permute[0] = 0; permute[1] = ndim - 1; for (int i = 0; i < ndim - 2; ++i) { permute[i + 2] = i + 1; } auto primitive = NewPermutePrimitive(stream->device_type(), ndim); CHECK_NOTNULL_OR_THROW(primitive); primitive->Launch(stream, data_type, ndim, shape, in, permute.data(), out); } using ConvertMemoryFormatFunc = std::function; ConvertMemoryFormatFunc convert_funcs[kMemoryFormatCount][kMemoryFormatCount] = { /*kContiguous->other*/ {ComputeIdentity, ComputeContiguousToChannelsLast}, /*kChannelsLast->other*/ {ComputeChannelsLastToContiguous, ComputeIdentity}, }; void ConvertMemoryFormat(ep::Stream* stream, const user_op::Tensor* in, user_op::Tensor* out, MemoryFormat in_memory_format, MemoryFormat out_memory_format) { auto convert_func = convert_funcs[in_memory_format][out_memory_format]; convert_func(stream, in->shape_view().size(), in->shape_view().data(), in->data_type(), in->dptr(), out->mut_dptr()); } void ConvertMemoryFormat(ep::Stream* stream, int ndim, const int64_t* shape, DataType data_type, const void* in, void* out, MemoryFormat in_memory_format, MemoryFormat out_memory_format) { auto convert_func = convert_funcs[in_memory_format][out_memory_format]; convert_func(stream, ndim, shape, data_type, in, out); } void ConvertMemoryFormat(ep::Stream* stream, const ShapeView& shape, DataType data_type, const void* in, void* out, MemoryFormat in_memory_format, MemoryFormat out_memory_format) { ConvertMemoryFormat(stream, shape.size(), shape.data(), data_type, in, out, in_memory_format, out_memory_format); } } // namespace oneflow ================================================ 
FILE: oneflow/user/kernels/convert_memory_format_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { void ConvertMemoryFormat(ep::Stream* stream, const user_op::Tensor* in, user_op::Tensor* out, MemoryFormat in_memory_format, MemoryFormat out_memory_format); void ConvertMemoryFormat(ep::Stream* stream, int ndim, const int64_t* shape, DataType data_type, const void* in, void* out, MemoryFormat in_memory_format, MemoryFormat out_memory_format); void ConvertMemoryFormat(ep::Stream* stream, const ShapeView& shape, DataType data_type, const void* in, void* out, MemoryFormat in_memory_format, MemoryFormat out_memory_format); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/copy_data_content_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/fill.h" namespace oneflow { namespace { class CopyDataContentKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: CopyDataContentKernel() = default; ~CopyDataContentKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = in->shape_view().elem_cnt(); // For a 0-size tensor, we don't need to copy data, but we must // fill the output tensor with Scalar(0) because during backward propagation, this kernel will // also be used.
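// Descriptive note on the branch below (added comment, based on the code that follows):
// if the input is empty and the output is also empty, nothing is done; if the input is empty but
// the output is not, the output is overwritten with Scalar(0) via a Fill primitive so it is never
// left uninitialized. For a non-empty input, shapes and data types are checked to match and the
// data is copied with a device-to-device Memcpy primitive.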
if (elem_cnt == 0) { const int64_t out_elem_cnt = out->shape_view().elem_cnt(); CHECK_GE(out_elem_cnt, 0); if (out_elem_cnt == 0) { return; } std::unique_ptr fill = ep::primitive::NewPrimitive(ctx->device_type(), out->data_type()); CHECK(fill); fill->Launch(ctx->stream(), out->mut_dptr(), Scalar(0), out_elem_cnt); return; } CHECK_EQ(out->shape_view().elem_cnt(), elem_cnt); CHECK_EQ(in->data_type(), out->data_type()); if (elem_cnt > 0) { std::unique_ptr primitive = ep::primitive::NewPrimitive( ctx->stream()->device_type(), ep::primitive::MemcpyKind::kDtoD); CHECK(primitive) << "Can not create Memcpy primitive for device type " << ctx->stream()->device_type(); primitive->Launch(ctx->stream(), out->mut_dptr(), in->dptr(), elem_cnt * GetSizeOfDataType(in->data_type())); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_COPY_DATA_CONTENT_KERNEL(op_type_name) \ REGISTER_USER_KERNEL(op_type_name) \ .SetCreateFn() \ .SetInplaceProposalFn( \ [](const user_op::InferContext&, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false)); \ return Maybe::Ok(); \ }); REGISTER_COPY_DATA_CONTENT_KERNEL("squeeze"); REGISTER_COPY_DATA_CONTENT_KERNEL("reshape_like"); REGISTER_COPY_DATA_CONTENT_KERNEL("expand_dims"); REGISTER_COPY_DATA_CONTENT_KERNEL("reshape"); REGISTER_COPY_DATA_CONTENT_KERNEL("amp_white_identity"); REGISTER_COPY_DATA_CONTENT_KERNEL("amp_black_identity"); REGISTER_COPY_DATA_CONTENT_KERNEL("identity"); REGISTER_COPY_DATA_CONTENT_KERNEL("identity_buffer"); REGISTER_COPY_DATA_CONTENT_KERNEL("parallel_cast"); REGISTER_COPY_DATA_CONTENT_KERNEL("hierarchical_parallel_cast"); REGISTER_COPY_DATA_CONTENT_KERNEL("hierarchical_parallel_cast_like"); REGISTER_COPY_DATA_CONTENT_KERNEL("pinned_identity"); REGISTER_COPY_DATA_CONTENT_KERNEL("depend"); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/copy_hd_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { namespace { class CopyHdKernel final : public user_op::OpKernel { public: CopyHdKernel() = default; ~CopyHdKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK(in) << "input of copy not found"; const ShapeView& in_shape = in->shape_view(); if (in_shape.elem_cnt() == 0) { // 0 shape tensor do not need copy } else { const DataType in_data_type = in->data_type(); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK(out) << "output of copy not found, op: " << ctx->op_name(); CHECK_EQ(out->shape_view(), in_shape); CHECK_EQ(out->data_type(), in_data_type); ep::primitive::MemcpyKind kind{}; if (ctx->op_type_name() == "copy_h2d") { kind = ep::primitive::MemcpyKind::kHtoD; } else if (ctx->op_type_name() == "copy_d2h") { kind = ep::primitive::MemcpyKind::kDtoH; } else { UNIMPLEMENTED(); } std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type(), kind); primitive->Launch(ctx->stream(), out->mut_raw_dptr(), in->raw_dptr(), in_shape.elem_cnt() * GetSizeOfDataType(in_data_type)); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("copy_h2d").SetCreateFn(); REGISTER_USER_KERNEL("copy_d2h").SetCreateFn(); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/copy_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { namespace { class CopyKernel final : public user_op::OpKernel { public: CopyKernel() = default; ~CopyKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& in_shape = in->shape_view(); CHECK_EQ(out->shape_view(), in_shape); const DataType in_data_type = in->data_type(); CHECK_EQ(out->data_type(), in_data_type); if (in_shape.elem_cnt() == 0) { // 0 shape tensor do not need copy return; } else { AutoMemcpy(ctx->stream(), out->mut_raw_dptr(), in->raw_dptr(), in_shape.elem_cnt() * GetSizeOfDataType(in_data_type), out->mem_case(), in->mem_case()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("copy").SetCreateFn(); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/count_not_finite_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { template class MultiCountNotFiniteCpuKernel final : public user_op::OpKernel { public: MultiCountNotFiniteCpuKernel() = default; ~MultiCountNotFiniteCpuKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); int64_t* y_ptr = y->mut_dptr(); int64_t count = 0; FOR_RANGE(int32_t, i, 0, ctx->inputs().size()) { user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", i); const T* x_ptr = x->dptr(); FOR_RANGE(int32_t, j, 0, x->shape_view().elem_cnt()) { if (!std::isfinite(x_ptr[j])) { count++; } } } y_ptr[0] = count; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_COUNT_NOT_FINITE_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("count_not_finite") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_COUNT_NOT_FINITE_CPU_KERNEL(float) REGISTER_COUNT_NOT_FINITE_CPU_KERNEL(double) #define REGISTER_MULTI_COUNT_NOT_FINITE_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("multi_count_not_finite") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_MULTI_COUNT_NOT_FINITE_CPU_KERNEL(float) REGISTER_MULTI_COUNT_NOT_FINITE_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/count_not_finite_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template struct Param { const T* x[N]; int64_t x_elem_cnt[N]; int64_t* y; int64_t num_x; }; using CuInt64T = unsigned long long int; __device__ __inline__ int64_t AtomicAdd(int64_t* address, int64_t val) { static_assert(sizeof(int64_t) == sizeof(CuInt64T), "size error"); return static_cast( atomicAdd(reinterpret_cast(address), static_cast(val))); } template __inline__ __device__ bool IsFinite(T x) { return isfinite(x); } template<> __inline__ __device__ bool IsFinite(half x) { return IsFinite(static_cast(x)); } template __global__ void CountNotFiniteGpu(const int64_t n, const T* x, int64_t* y) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage cub_reduce_tmp_storage; int64_t thread_count = 0; CUDA_1D_KERNEL_LOOP(i, n) { if (!IsFinite(x[i])) { thread_count += 1; } } __syncthreads(); int64_t block_count_sum = BlockReduce(cub_reduce_tmp_storage).Reduce(thread_count, cub::Sum()); if (threadIdx.x == 0) { AtomicAdd(y, block_count_sum); } } template __global__ void MultiCountNotFiniteGpu(Param param) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage cub_reduce_tmp_storage; int64_t thread_count = 0; for (int32_t k = 0; k < param.num_x; ++k) { CUDA_1D_KERNEL_LOOP(i, param.x_elem_cnt[k]) { if (!IsFinite(param.x[k][i])) { thread_count += 1; } } } __syncthreads(); int64_t block_count_sum = BlockReduce(cub_reduce_tmp_storage).Reduce(thread_count, cub::Sum()); if (threadIdx.x == 0) { AtomicAdd(param.y, block_count_sum); } } constexpr int64_t kCountNotFiniteNumBlocks = 512; int GetCountNotFiniteNumBlocks(const int64_t elem_cnt) { return std::min((elem_cnt + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock, kCountNotFiniteNumBlocks); } } // namespace template class CountNotFiniteGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: CountNotFiniteGpuKernel() = default; ~CountNotFiniteGpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int64_t elem_cnt = x->shape_view().elem_cnt(); Memset(ctx->stream(), y->mut_dptr(), 0, y->shape_view().elem_cnt() * sizeof(int64_t)); CountNotFiniteGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, x->dptr(), y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_COUNT_NOT_FINITE_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("count_not_finite") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_COUNT_NOT_FINITE_CUDA_KERNEL(half) REGISTER_COUNT_NOT_FINITE_CUDA_KERNEL(float) REGISTER_COUNT_NOT_FINITE_CUDA_KERNEL(double) template class MultiCountNotFiniteGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiCountNotFiniteGpuKernel() = default; ~MultiCountNotFiniteGpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); Param para; Memset(ctx->stream(), y->mut_dptr(), 0, y->shape_view().elem_cnt() * 
sizeof(int64_t)); para.y = y->mut_dptr(); int64_t remain_size = ctx->inputs().size(); int64_t input_id = 0; while (remain_size > 0) { if (remain_size > 128) { remain_size -= 128; para.num_x = 128; } else { para.num_x = remain_size; remain_size = 0; } int64_t max_elem_cnt = 0; for (int32_t i = 0; i < para.num_x; ++i) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", input_id); input_id++; para.x[i] = x->dptr(); para.x_elem_cnt[i] = x->shape_view().elem_cnt(); max_elem_cnt = std::max(max_elem_cnt, x->shape_view().elem_cnt()); } MultiCountNotFiniteGpu <<stream()->As()->cuda_stream()>>>(para); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MULTI_COUNT_NOT_FINITE_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("multi_count_not_finite") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_MULTI_COUNT_NOT_FINITE_CUDA_KERNEL(half) REGISTER_MULTI_COUNT_NOT_FINITE_CUDA_KERNEL(float) REGISTER_MULTI_COUNT_NOT_FINITE_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ctc_greedy_decoder.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/ctc_greedy_decoder.h" namespace oneflow { namespace { template struct CTCGreedyDecoderFunctor final { void operator()(ep::Stream* stream, int64_t* decoded_ptr, T* neg_sum_logits_ptr, const T* log_probs_ptr, const int64_t* input_lengths_ptr, const bool merge_repeated, const int64_t max_input_length, const int64_t batch_size, const int64_t num_labels) { FOR_RANGE(int64_t, b, 0, batch_size) { CHECK_GE(max_input_length, input_lengths_ptr[b]); } NdIndexOffsetHelper input_helper(max_input_length, batch_size, num_labels); FOR_RANGE(int64_t, b, 0, batch_size) { int64_t prev_indices = -1, t_dec = 0; neg_sum_logits_ptr[b] = 0; FOR_RANGE(int64_t, t, 0, input_lengths_ptr[b]) { const T* prob_data_t = &log_probs_ptr[input_helper.NdIndexToOffset(t, b, 0)]; int64_t max_indice = std::max_element(prob_data_t, prob_data_t + num_labels) - prob_data_t; neg_sum_logits_ptr[b] -= prob_data_t[max_indice]; if (max_indice != num_labels - 1 && !(merge_repeated && (prev_indices == max_indice))) { decoded_ptr[b * max_input_length + t_dec] = max_indice; t_dec++; } prev_indices = max_indice; } FOR_RANGE(int64_t, t, t_dec, max_input_length) { decoded_ptr[b * max_input_length + t] = 0; } } } }; } // namespace REGISTER_CTC_GREEDY_DECODER_KERNELS(DeviceType::kCPU, float); REGISTER_CTC_GREEDY_DECODER_KERNELS(DeviceType::kCPU, double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ctc_greedy_decoder.cu ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/ctc_greedy_decoder.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void CtcGreedyDecodeGpuMultiThread(int64_t* decoded_ptr, T* neg_sum_logits_ptr, const T* log_probs_ptr, const int64_t* input_lengths_ptr, const bool merge_repeated, const int64_t max_input_length, const int64_t batch_size, const int64_t num_labels) { const int64_t bid = blockIdx.x; const int64_t tid = threadIdx.x; for (int64_t b = bid; b < batch_size; b += gridDim.x) { if (tid == 0) { if (input_lengths_ptr[b] > max_input_length) __trap(); } } for (int64_t b = bid; b < batch_size; b += gridDim.x) { extern __shared__ int64_t shared_max_indices_memory[]; int64_t* shared_max_indices = (int64_t*)shared_max_indices_memory; NdIndexOffsetHelper input_helper(max_input_length, batch_size, num_labels); for (int64_t t = tid; t < max_input_length; t += blockDim.x) { const T* prob_data_t = &log_probs_ptr[input_helper.NdIndexToOffset(t, b, 0)]; int64_t max_indice = 0; T max_value = -FLT_MAX; FOR_RANGE(int64_t, c, 0, num_labels) { const T prob = prob_data_t[c]; if (prob > max_value) { max_indice = c; max_value = prob; } } shared_max_indices[t] = max_indice; } __syncthreads(); if (tid == 0) { int64_t prev_indices = -1, t_dec = 0; FOR_RANGE(int64_t, t, 0, input_lengths_ptr[b]) { const T* prob_data_t = &log_probs_ptr[input_helper.NdIndexToOffset(t, b, 0)]; const int64_t indice_t = shared_max_indices[t]; neg_sum_logits_ptr[b] -= prob_data_t[indice_t]; if (indice_t != num_labels - 1 && !(merge_repeated && (prev_indices == indice_t))) { decoded_ptr[b * max_input_length + t_dec] = indice_t; t_dec++; } prev_indices = indice_t; } FOR_RANGE(int64_t, t, t_dec, max_input_length) { decoded_ptr[b * max_input_length + t] = 0; } } } } template __global__ void CtcGreedyDecodeGpu(int64_t* decoded_ptr, T* neg_sum_logits_ptr, const T* log_probs_ptr, const int64_t* input_lengths_ptr, const bool merge_repeated, const int64_t max_input_length, const int64_t batch_size, const int64_t num_labels) { for (int64_t b = 0; b < batch_size; b++) { if (input_lengths_ptr[b] > max_input_length) __trap(); } NdIndexOffsetHelper input_helper(max_input_length, batch_size, num_labels); CUDA_1D_KERNEL_LOOP(b, batch_size) { int prev_indices = -1, t_dec = 0; neg_sum_logits_ptr[b] = 0; FOR_RANGE(int64_t, t, 0, input_lengths_ptr[b]) { const T* prob_data_t = &log_probs_ptr[input_helper.NdIndexToOffset(t, b, 0)]; int64_t max_indice = -1; T max_value = -FLT_MAX; FOR_RANGE(int64_t, c, 0, num_labels) { if (prob_data_t[c] > max_value) { max_indice = c; max_value = prob_data_t[c]; } } neg_sum_logits_ptr[b] -= max_value; if (max_indice != num_labels - 1 && !(merge_repeated && (prev_indices == max_indice))) { decoded_ptr[b * max_input_length + t_dec] = max_indice; t_dec++; } prev_indices = max_indice; } FOR_RANGE(int64_t, t, t_dec, 
max_input_length) { decoded_ptr[b * max_input_length + t] = 0; } } } template struct CTCGreedyDecoderFunctor final { void operator()(ep::Stream* stream, int64_t* decoded_ptr, T* neg_sum_logits_ptr, const T* log_probs_ptr, const int64_t* input_lengths_ptr, const bool merge_repeated, const int64_t max_input_length, const int64_t batch_size, const int64_t num_labels) { int32_t thread_num = batch_size * kCudaThreadsNumPerBlock; int64_t shared_mem_size = max_input_length * sizeof(int64_t); int max_active_blocks; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, CtcGreedyDecodeGpu, kCudaThreadsNumPerBlock, shared_mem_size)); if (max_active_blocks > 0) { CtcGreedyDecodeGpuMultiThread<<As()->cuda_stream()>>>( decoded_ptr, neg_sum_logits_ptr, log_probs_ptr, input_lengths_ptr, merge_repeated, max_input_length, batch_size, num_labels); } else { CtcGreedyDecodeGpu<<As()->cuda_stream()>>>( decoded_ptr, neg_sum_logits_ptr, log_probs_ptr, input_lengths_ptr, merge_repeated, max_input_length, batch_size, num_labels); } } }; } // namespace REGISTER_CTC_GREEDY_DECODER_KERNELS(DeviceType::kCUDA, float); REGISTER_CTC_GREEDY_DECODER_KERNELS(DeviceType::kCUDA, double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ctc_greedy_decoder.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef _ONEFLOW_USER_KERNELS_CTC_GREEDY_DECODER_KERNEL_H_ #define _ONEFLOW_USER_KERNELS_CTC_GREEDY_DECODER_KERNEL_H_ #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace { template struct CTCGreedyDecoderFunctor final { void operator()(ep::Stream* stream, int64_t* decoded_ptr, T* neg_sum_logits_ptr, const T* log_probs_ptr, const int64_t* input_lengths_ptr, const bool merge_repeated, const int64_t max_input_length, const int64_t batch_size, const int64_t num_labels); }; } // namespace template class CTCGreedyDecoderKernel final : public user_op::OpKernel { public: CTCGreedyDecoderKernel() = default; ~CTCGreedyDecoderKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* log_probs = ctx->Tensor4ArgNameAndIndex("log_probs", 0); const user_op::Tensor* input_lengths = ctx->Tensor4ArgNameAndIndex("input_lengths", 0); user_op::Tensor* decoded = ctx->Tensor4ArgNameAndIndex("decoded", 0); user_op::Tensor* neg_sum_logits = ctx->Tensor4ArgNameAndIndex("neg_sum_logits", 0); const T* log_probs_ptr = log_probs->dptr(); const int64_t* input_lengths_ptr = input_lengths->dptr(); const bool merge_repeated = ctx->Attr("merge_repeated"); const int64_t max_input_length = log_probs->shape_view().At(0); const int64_t batch_size = log_probs->shape_view().At(1); const int64_t num_labels = log_probs->shape_view().At(2); CHECK_EQ(batch_size, input_lengths->shape_view().At(0)); int64_t* decoded_ptr = decoded->mut_dptr(); T* neg_sum_logits_ptr = neg_sum_logits->mut_dptr(); CTCGreedyDecoderFunctor()(ctx->stream(), decoded_ptr, neg_sum_logits_ptr, log_probs_ptr, input_lengths_ptr, merge_repeated, max_input_length, batch_size, num_labels); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CTC_GREEDY_DECODER_KERNELS(device, dtype) \ REGISTER_USER_KERNEL("ctc_greedy_decoder") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("log_probs", 0) == GetDataType::value)); } // namespace oneflow #endif // _ONEFLOW_USER_KERNELS_CTC_GREEDY_DECODER_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/ctc_loss_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/ctc_loss_kernel_util.h" namespace oneflow { template class CtcLossKernel final : public user_op::OpKernel { public: CtcLossKernel() = default; ~CtcLossKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* log_probs = ctx->Tensor4ArgNameAndIndex("log_probs", 0); const user_op::Tensor* targets = ctx->Tensor4ArgNameAndIndex("targets", 0); const user_op::Tensor* input_lengths = ctx->Tensor4ArgNameAndIndex("input_lengths", 0); const user_op::Tensor* target_lengths = ctx->Tensor4ArgNameAndIndex("target_lengths", 0); user_op::Tensor* loss = ctx->Tensor4ArgNameAndIndex("loss", 0); user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); const T* log_probs_ptr = log_probs->dptr(); const TARGET* targets_ptr = targets->dptr(); const IDX* input_lengths_ptr = input_lengths->dptr(); const IDX* target_lengths_ptr = target_lengths->dptr(); const int64_t blank = ctx->Attr("blank"); const int64_t max_input_length = log_probs->shape_view().At(0); const int64_t batch_size = log_probs->shape_view().At(1); const int64_t num_labels = log_probs->shape_view().At(2); const int64_t max_target_length = ctx->Attr("max_target_length"); const int32_t targets_ndim = targets->shape_view().NumAxes(); NdIndexOffsetHelper input_helper(max_input_length, batch_size, num_labels); NdIndexOffsetHelper alpha_helper(batch_size, max_input_length, 2 * max_target_length + 1); T* loss_ptr = loss->mut_dptr(); T* alpha_ptr = alpha->mut_dptr(); CtcLossKernelUtil::CtcLossForward( ctx->stream(), log_probs_ptr, targets_ptr, input_lengths_ptr, target_lengths_ptr, alpha_ptr, loss_ptr, input_helper, alpha_helper, batch_size, max_input_length, max_target_length, blank, targets_ndim); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CTC_LOSS_KERNEL(device, dtype, target_type, idx_dtype) \ REGISTER_USER_KERNEL("ctc_loss") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("log_probs", 0) == OF_PP_PAIR_SECOND(dtype)) \ && (user_op::HobDataType("targets", 0) == OF_PP_PAIR_SECOND(target_type)) \ && (user_op::HobDataType("input_lengths", 0) == OF_PP_PAIR_SECOND(idx_dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CTC_LOSS_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class CtcLossGradKernel final : public user_op::OpKernel { public: CtcLossGradKernel() = default; ~CtcLossGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* grad_out = ctx->Tensor4ArgNameAndIndex("grad_out", 0); const user_op::Tensor* loss = ctx->Tensor4ArgNameAndIndex("loss", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); const user_op::Tensor* log_probs = ctx->Tensor4ArgNameAndIndex("log_probs", 0); const user_op::Tensor* targets = ctx->Tensor4ArgNameAndIndex("targets", 0); const user_op::Tensor* input_lengths = ctx->Tensor4ArgNameAndIndex("input_lengths", 0); const user_op::Tensor* target_lengths = ctx->Tensor4ArgNameAndIndex("target_lengths", 0); user_op::Tensor* grad = ctx->Tensor4ArgNameAndIndex("grad", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const T* grad_out_ptr = grad_out->dptr(); const T* loss_ptr = loss->dptr(); const T* alpha_ptr = alpha->dptr(); const T* log_probs_ptr = log_probs->dptr(); const TARGET* targets_ptr = targets->dptr(); 
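// Note (added comment): tmp_buffer serves as the workspace for the CTC backward variable beta;
// the kernel registration below sizes it to batch_size * max_input_length *
// (2 * max_target_length + 1) elements of the log_probs data type.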
const IDX* input_lengths_ptr = input_lengths->dptr(); const IDX* target_lengths_ptr = target_lengths->dptr(); const int64_t blank = ctx->Attr("blank"); const bool zero_infinity = ctx->Attr("zero_infinity"); const int64_t batch_size = log_probs->shape_view().At(1); const int64_t num_labels = log_probs->shape_view().At(2); const int64_t max_input_length = log_probs->shape_view().At(0); const int64_t max_target_length = ctx->Attr("max_target_length"); const int32_t targets_ndim = targets->shape_view().NumAxes(); NdIndexOffsetHelper input_helper(max_input_length, batch_size, num_labels); NdIndexOffsetHelper beta_helper(batch_size, max_input_length, 2 * max_target_length + 1); T* grad_ptr = grad->mut_dptr(); T* beta_ptr = tmp_buffer->mut_dptr(); CtcLossKernelUtil::CtcLossBackward( ctx->stream(), grad_out_ptr, loss_ptr, alpha_ptr, log_probs_ptr, targets_ptr, input_lengths_ptr, target_lengths_ptr, beta_ptr, grad_ptr, input_helper, beta_helper, batch_size, max_input_length, max_target_length, num_labels, blank, zero_infinity, targets_ndim); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CTC_LOSS_BACKWARD_KERNEL(device, dtype, target_type, idx_dtype) \ REGISTER_USER_KERNEL("ctc_loss_grad") \ .SetCreateFn< \ CtcLossGradKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("log_probs", 0) == OF_PP_PAIR_SECOND(dtype)) \ && (user_op::HobDataType("targets", 0) == OF_PP_PAIR_SECOND(target_type)) \ && (user_op::HobDataType("input_lengths", 0) == OF_PP_PAIR_SECOND(idx_dtype))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& log_probs_shape = ctx->InputShape("log_probs", 0); \ const int64_t max_target_length = ctx->Attr("max_target_length"); \ int64_t elem_cnt = \ log_probs_shape.At(1) * log_probs_shape.At(0) * (2 * max_target_length + 1); \ return elem_cnt * sizeof(OF_PP_PAIR_FIRST(dtype)); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CTC_LOSS_BACKWARD_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ctc_loss_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/ctc_loss_kernel_util.h" namespace oneflow { template int64_t get_target_prime(const TARGET* targets_ptr, const IDX* target_lengths_ptr, int64_t max_target_length, int64_t b, int64_t s, int64_t blank, const int32_t targets_ndim) { if (s % 2 == 0) { return blank; } else { int64_t idx = 0; if (targets_ndim == 1) { FOR_RANGE(int64_t, i, 0, b) { idx += target_lengths_ptr[i]; } } else { // targets_ndim == 2 idx = b * max_target_length; } idx += s / 2; return static_cast(targets_ptr[idx]); } } template struct CtcLossKernelUtil final { static void CtcLossForward(ep::Stream* stream, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* alpha_ptr, T* loss_ptr, NdIndexOffsetHelper& input_helper, NdIndexOffsetHelper& alpha_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t blank, const int32_t targets_ndim); static void CtcLossBackward(ep::Stream* stream, const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* beta_ptr, T* grad_ptr, NdIndexOffsetHelper& input_helper, NdIndexOffsetHelper& beta_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t num_labels, const int64_t blank, const bool zero_infinity, const int32_t targets_ndim); }; template void CtcLossKernelUtil::CtcLossForward( ep::Stream* stream, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* alpha_ptr, T* loss_ptr, NdIndexOffsetHelper& input_helper, NdIndexOffsetHelper& alpha_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t blank, const int32_t targets_ndim) { constexpr T neginf = -std::numeric_limits::infinity(); FOR_RANGE(int64_t, b, 0, batch_size) { CHECK_GE(max_input_length, input_lengths_ptr[b]); CHECK_GE(max_target_length, target_lengths_ptr[b]); } FOR_RANGE(int32_t, b, 0, batch_size) { IDX input_length = input_lengths_ptr[b]; IDX target_length = target_lengths_ptr[b]; int64_t alpha_idx = alpha_helper.NdIndexToOffset(b, 0, 0); for (IDX s = 0; s < 2 * target_length + 1; s++) { alpha_ptr[alpha_idx + s] = neginf; } alpha_ptr[alpha_idx] = log_probs_ptr[input_helper.NdIndexToOffset(0, b, blank)]; if (target_length > 0) { TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, 1, blank, targets_ndim); alpha_ptr[alpha_idx + 1] = log_probs_ptr[input_helper.NdIndexToOffset(0, b, target)]; } for (IDX t = 1; t < input_length; t++) { for (IDX s = 0; s < 2 * target_length + 1; s++) { TARGET current_target_prime = get_target_prime( targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim); T la1 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s)]; T la2, la3, lamax = la1; if (s > 0) { la2 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 1)]; if (la2 > lamax) lamax = la2; } else { la2 = neginf; } if ((s > 1) && (get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, s - 2, blank, targets_ndim) != current_target_prime)) { la3 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 2)]; if (la3 > lamax) lamax = la3; } else { la3 = neginf; } if (lamax == neginf) lamax = 0; int64_t idx_t_s = alpha_helper.NdIndexToOffset(b, t, s); alpha_ptr[idx_t_s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - 
lamax)) + lamax + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; } } if (target_length == 0) { int64_t idx = alpha_helper.NdIndexToOffset(b, input_length - 1, 0); loss_ptr[b] = -alpha_ptr[idx]; } else { int64_t idx1 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2); int64_t idx2 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2 - 1); T l1 = alpha_ptr[idx1]; T l2 = alpha_ptr[idx2]; T m = std::max(l1, l2); m = ((m == neginf) ? 0 : m); T log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m; loss_ptr[b] = -log_likelihood; } } } template void CtcLossKernelUtil::CtcLossBackward( ep::Stream* stream, const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* beta_ptr, T* grad_ptr, NdIndexOffsetHelper& input_helper, NdIndexOffsetHelper& beta_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t num_labels, const int64_t blank, const bool zero_infinity, const int32_t targets_ndim) { constexpr T neginf = -std::numeric_limits::infinity(); int64_t elem_cnt = max_input_length * batch_size * num_labels; FOR_RANGE(int64_t, i, 0, elem_cnt) { grad_ptr[i] = neginf; } FOR_RANGE(int64_t, b, 0, batch_size) { IDX input_length = input_lengths_ptr[b]; IDX target_length = target_lengths_ptr[b]; T nll = loss_ptr[b]; if (zero_infinity && nll == std::numeric_limits::infinity()) { for (IDX t = 0; t < max_input_length; t++) { for (IDX c = 0; c < num_labels; c++) { grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = 0; } } continue; } if (input_length > 0) { int64_t beta_idx = beta_helper.NdIndexToOffset(b, input_length - 1, 0); for (IDX s = 0; s < 2 * target_length + 1; s++) { beta_ptr[beta_idx + s] = neginf; } beta_ptr[beta_idx + 2 * target_length] = log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)]; grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)] = alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)] + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)]; if (target_length > 0) { TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, 2 * target_length - 1, blank, targets_ndim); beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] = log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)]; grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)] = alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)]; } } for (IDX t = input_length - 2; t >= 0; t--) { for (IDX s = 2 * target_length; s >= 0; s--) { TARGET current_target_prime = get_target_prime( targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim); T lb1 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s)]; T lb2, lb3, lbmax = lb1; if (s < 2 * target_length) { lb2 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 1)]; if (lb2 > lbmax) lbmax = lb2; } else { lb2 = neginf; } if ((s < 2 * target_length - 1) && (get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, s + 2, blank, targets_ndim) != current_target_prime)) { lb3 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 2)]; if (lb3 > lbmax) lbmax = lb3; } else { lb3 = neginf; } if (lbmax == neginf) lbmax = 0; int64_t idx_t_s = 
beta_helper.NdIndexToOffset(b, t, s); beta_ptr[idx_t_s] = std::log(std::exp(lb1 - lbmax) + std::exp(lb2 - lbmax) + std::exp(lb3 - lbmax)) + lbmax + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; T log_alpha_beta = alpha_ptr[idx_t_s] + beta_ptr[idx_t_s]; T& lcab = grad_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; if (lcab == neginf) { lcab = log_alpha_beta; } else { T m = std::max(lcab, log_alpha_beta); lcab = std::log(std::exp(lcab - m) + std::exp(log_alpha_beta - m)) + m; } } } for (int32_t t = 0; t < input_length; t++) { for (int64_t c = 0; c < num_labels; c++) { T& res = grad_ptr[input_helper.NdIndexToOffset(t, b, c)]; T lp = log_probs_ptr[input_helper.NdIndexToOffset(t, b, c)]; res = (std::exp(lp) - std::exp(res + nll - lp)) * grad_out_ptr[b]; } } // zero the remainder if (input_length < max_input_length) { for (int64_t t = input_length; t < max_input_length; t++) { for (int64_t c = 0; c < num_labels; c++) { int64_t grad_idx = input_helper.NdIndexToOffset(t, b, c); grad_ptr[grad_idx] = 0; } } } } } #define INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU(device_type_v, log_probs_dtype_pair, \ targets_dtype_pair, input_lengths_dtype_pair) \ template struct CtcLossKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU, (DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #undef INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ctc_loss_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/ctc_loss_kernel_util.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace { template __device__ __inline__ static int64_t get_target_prime(const TARGET* targets_ptr, const IDX* target_lengths_ptr, int64_t max_target_length, int64_t b, int64_t s, int64_t blank, const int32_t targets_ndim) { if (s % 2 == 0) { return blank; } else { int64_t idx = 0; if (targets_ndim == 1) { FOR_RANGE(int64_t, i, 0, b) { idx += target_lengths_ptr[i]; } } else { // targets_ndim == 2 idx = b * max_target_length; } idx += s / 2; return static_cast(targets_ptr[idx]); } } template __global__ void CtcLossGpu(const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* alpha_ptr, T* loss_ptr, NdIndexOffsetHelper input_helper, NdIndexOffsetHelper alpha_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t blank, const int32_t targets_ndim) { constexpr T neginf = -INFINITY; const int32_t bid = blockIdx.x; const int32_t tid = threadIdx.x; for (int64_t b = bid; b < batch_size; b += gridDim.x) { if (tid == 0) { if (input_lengths_ptr[b] > max_input_length) __trap(); if (target_lengths_ptr[b] > max_target_length) __trap(); } } for (int64_t b = bid; b < batch_size; b += gridDim.x) { IDX input_length = input_lengths_ptr[b]; IDX target_length = target_lengths_ptr[b]; for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) { alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, s)] = neginf; } if (tid == 0) { alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, 0)] = log_probs_ptr[input_helper.NdIndexToOffset(0, b, blank)]; if (target_length > 0) { TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, 1, blank, targets_ndim); alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, 1)] = log_probs_ptr[input_helper.NdIndexToOffset(0, b, target)]; } } __syncthreads(); for (IDX t = 1; t < input_length; t++) { for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) { TARGET current_target_prime = get_target_prime( targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim); T la1 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s)]; T la2, la3, lamax = la1; if (s > 0) { la2 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 1)]; if (la2 > lamax) lamax = la2; } else { la2 = neginf; } if ((s > 1) && (get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, s - 2, blank, targets_ndim) != current_target_prime)) { la3 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 2)]; if (la3 > lamax) lamax = la3; } else { la3 = neginf; } if (lamax == neginf) lamax = 0; int64_t idx_t_s = alpha_helper.NdIndexToOffset(b, t, s); alpha_ptr[idx_t_s] = log(exp(la1 - lamax) + exp(la2 - lamax) + exp(la3 - lamax)) + lamax + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; } __syncthreads(); } if (tid == 0) { if (target_length == 0) { int64_t idx = alpha_helper.NdIndexToOffset(b, input_length - 1, 0); loss_ptr[b] = -alpha_ptr[idx]; } else { int64_t idx1 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2); int64_t idx2 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2 - 1); T l1 = alpha_ptr[idx1]; T l2 = alpha_ptr[idx2]; T m = max(l1, l2); m = ((m == neginf) ? 
0 : m); T log_likelihood = log(exp(l1 - m) + exp(l2 - m)) + m; loss_ptr[b] = -log_likelihood; } } } } template __global__ void CtcLossGradGpu( const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* beta_ptr, T* grad_ptr, NdIndexOffsetHelper input_helper, NdIndexOffsetHelper beta_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t num_labels, const int64_t blank, const bool zero_infinity, const int32_t targets_ndim) { constexpr T neginf = -INFINITY; const int32_t bid = blockIdx.x; const int32_t tid = threadIdx.x; for (int64_t b = bid; b < batch_size; b += gridDim.x) { IDX input_length = input_lengths_ptr[b]; IDX target_length = target_lengths_ptr[b]; T nll = loss_ptr[b]; if (zero_infinity && nll == INFINITY) { for (IDX t = tid; t < max_input_length; t += blockDim.x) { for (IDX c = 0; c < num_labels; c++) { grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = 0; } } __syncthreads(); continue; } if (input_length > 0) { for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) { beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, s)] = neginf; } if (tid == 0) { beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)] = log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)]; if (target_length > 0) { TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, 2 * target_length - 1, blank, targets_ndim); beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] = log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)]; } } __syncthreads(); } for (IDX t = input_length - 2; t >= 0; t--) { for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) { TARGET current_target_prime = get_target_prime( targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim); T lb1 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s)]; T lb2, lb3, lbmax = lb1; if (s < 2 * target_length) { lb2 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 1)]; if (lb2 > lbmax) lbmax = lb2; } else { lb2 = neginf; } if ((s < 2 * target_length - 1) && (get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, s + 2, blank, targets_ndim) != current_target_prime)) { lb3 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 2)]; if (lb3 > lbmax) lbmax = lb3; } else { lb3 = neginf; } if (lbmax == neginf) lbmax = 0; int64_t idx_t_s = beta_helper.NdIndexToOffset(b, t, s); beta_ptr[idx_t_s] = log(exp(lb1 - lbmax) + exp(lb2 - lbmax) + exp(lb3 - lbmax)) + lbmax + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; } __syncthreads(); } for (IDX t = tid; t < max_input_length; t += blockDim.x) { for (IDX c = 0; c < num_labels; c++) { grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = t < input_length ? 
neginf : 0; } } __syncthreads(); if (tid == 0) { grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)] = alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)] + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)]; if (target_length > 0) { TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, 2 * target_length - 1, blank, targets_ndim); grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)] = alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)]; } } __syncthreads(); for (IDX t = tid; t < input_length; t += blockDim.x) { for (IDX s = 0; (t < input_length - 1) && (s < 2 * target_length + 1); s += 1) { TARGET current_target_prime = get_target_prime( targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim); int64_t idx_t_s = beta_helper.NdIndexToOffset(b, t, s); T log_alpha_beta = alpha_ptr[idx_t_s] + beta_ptr[idx_t_s]; T& lcab = grad_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; if (lcab == neginf) { lcab = log_alpha_beta; } else { T m = max(lcab, log_alpha_beta); lcab = log(exp(lcab - m) + exp(log_alpha_beta - m)) + m; } } for (int32_t c = 0; c < num_labels; c++) { T& res = grad_ptr[input_helper.NdIndexToOffset(t, b, c)]; T lp = log_probs_ptr[input_helper.NdIndexToOffset(t, b, c)]; res = (exp(lp) - exp(res + nll - lp)) * grad_out_ptr[b]; } } } } } // namespace template struct CtcLossKernelUtil { static void CtcLossForward(ep::Stream* stream, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* alpha_ptr, T* loss_ptr, NdIndexOffsetHelper& input_helper, NdIndexOffsetHelper& alpha_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t blank, const int32_t targets_ndim) { int32_t thread_num = batch_size * kCudaThreadsNumPerBlock; RUN_CUDA_KERNEL((CtcLossGpu), stream, thread_num, log_probs_ptr, targets_ptr, input_lengths_ptr, target_lengths_ptr, alpha_ptr, loss_ptr, input_helper, alpha_helper, batch_size, max_input_length, max_target_length, blank, targets_ndim); } static void CtcLossBackward(ep::Stream* stream, const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* beta_ptr, T* grad_ptr, NdIndexOffsetHelper& input_helper, NdIndexOffsetHelper& beta_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t num_labels, const int64_t blank, const bool zero_infinity, const int32_t targets_ndim) { int32_t thread_num = batch_size * kCudaThreadsNumPerBlock; RUN_CUDA_KERNEL((CtcLossGradGpu), stream, thread_num, grad_out_ptr, loss_ptr, alpha_ptr, log_probs_ptr, targets_ptr, input_lengths_ptr, target_lengths_ptr, beta_ptr, grad_ptr, input_helper, beta_helper, batch_size, max_input_length, max_target_length, num_labels, blank, zero_infinity, targets_ndim); } }; #define INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CUDA(device_type_v, log_probs_dtype_pair, \ targets_dtype_pair, input_lengths_dtype_pair) \ template struct CtcLossKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CUDA, (DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #undef INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CUDA } // namespace oneflow 
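Both the CPU and CUDA CTC recursions above repeat the same numerically stable three-way log-add when updating alpha and beta: subtract the running maximum, exponentiate, sum, take the log, add the maximum back, and treat an all-neginf triple as contributing -inf. A minimal standalone sketch of that update follows; it is not a file from the repository, the helper name StableLogAdd3 is illustrative only, and it assumes C++11 or later.

#include <algorithm>
#include <cmath>
#include <limits>

// log(exp(a) + exp(b) + exp(c)) computed without overflow by factoring out the maximum,
// mirroring the (la1, la2, la3, lamax) and (lb1, lb2, lb3, lbmax) pattern in the kernels above.
template<typename T>
T StableLogAdd3(T a, T b, T c) {
  T m = std::max(a, std::max(b, c));
  // If every operand is -inf the result stays -inf; resetting m to 0 avoids computing inf - inf.
  if (m == -std::numeric_limits<T>::infinity()) { m = 0; }
  return std::log(std::exp(a - m) + std::exp(b - m) + std::exp(c - m)) + m;
}

In the kernels, the three operands are the alpha (or beta) values at the neighboring lattice positions, with the second and third operands set to neginf when the corresponding transition is not allowed (e.g. s == 0, or the target primes at s and s - 2 coincide).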
================================================ FILE: oneflow/user/kernels/ctc_loss_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { template struct CtcLossKernelUtil final { static void CtcLossForward(ep::Stream* stream, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* alpha_ptr, T* loss_ptr, NdIndexOffsetHelper& input_helper, NdIndexOffsetHelper& alpha_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t blank, const int32_t targets_ndim); static void CtcLossBackward(ep::Stream* stream, const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr, const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* beta_ptr, T* grad_ptr, NdIndexOffsetHelper& input_helper, NdIndexOffsetHelper& beta_helper, const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, const int64_t num_labels, const int64_t blank, const bool zero_infinity, const int32_t targets_ndim); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/cublas_bias_add_relu_matmul_grad_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. 
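// Note (added comment): the kernel below relies on the cuBLASLt epilogue
// CUBLASLT_EPILOGUE_DRELU_BGRAD to fuse the ReLU-backward masking (driven by the "aux" tensor)
// and the bias-gradient reduction (written to "d_bias") into a single cublasLtMatmul call,
// hence the CUDA_VERSION >= 11060 guard that follows.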
#if CUDA_VERSION >= 11060 namespace oneflow { namespace { template class CublasBiasAddReluMatmulGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: CublasBiasAddReluMatmulGradKernel() = default; ~CublasBiasAddReluMatmulGradKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCublasFusedMLPKernelCache(); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("aux", 0); user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_bias", 0); user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex("d_grad", 0); const auto* matmul_grad_cache = CHECK_NOTNULL(dynamic_cast(cache)); auto* cuda_stream = ctx->stream()->As(); const DataType data_type = dy->data_type(); const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; const double alpha = ctx->Attr("alpha"); const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); const double beta = 0.0; const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); // currently only support 2D matmul. DimVector dy_shape(2); dy->shape_view().ToDimVector(&dy_shape); DimVector weight_shape(2); weight->shape_view().ToDimVector(&weight_shape); cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; InferMatmulCublasMNK(dy_shape, weight_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, d_bias->dptr(), aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); /* a = dy, b = weight cublas_a=weight, cublas_b=dy */ OF_CUBLAS_CHECK( cublasLtMatmul(cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, weight->dptr(), matmul_grad_cache->cublas_a_desc, dy->dptr(), matmul_grad_cache->cublas_b_desc, &sp_beta, d_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUBLAS_BIAS_ADD_RELU_MATMUL_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("cublas_bias_add_relu_matmul_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("weight", 0) == GetDataType::value)); REGISTER_CUBLAS_BIAS_ADD_RELU_MATMUL_GRAD_KERNEL(float) REGISTER_CUBLAS_BIAS_ADD_RELU_MATMUL_GRAD_KERNEL(double) REGISTER_CUBLAS_BIAS_ADD_RELU_MATMUL_GRAD_KERNEL(half) } // namespace } // namespace oneflow #endif // CUDA_VERSION >= 11060 ================================================ FILE: oneflow/user/kernels/cublas_fused_matmul_bias_add_grad.cu 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/cuda/cuda_device.h" // CUBLASLT_EPILOGUE_BGRADB only support in cuda11.4.2 or higher version. // TODO(zhengzekang): In cuda11.6 version, CUBLASLT_EPILOGUE_BGRADB may occur illegal memory access // error in some shapes. #if CUDA_VERSION >= 11060 namespace oneflow { namespace { cudaDataType_t GetGemmComputeType(cudaDataType_t data_type) { switch (data_type) { case CUDA_R_32F: return CUDA_R_32F; case CUDA_R_64F: return CUDA_R_64F; case CUDA_R_16F: return CUDA_R_32F; #if CUDA_VERSION >= 11000 case CUDA_R_16BF: return CUDA_R_32F; #endif // CUDA_VERSION >= 11000 default: UNIMPLEMENTED(); return CUDA_R_32F; } } template class CublasMatmulBiasAddGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: CublasMatmulBiasAddGradKernel() = default; ~CublasMatmulBiasAddGradKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCublasFusedMLPKernelCache(); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* w_grad = ctx->Tensor4ArgNameAndIndex("w_grad", 0); user_op::Tensor* b_grad = ctx->Tensor4ArgNameAndIndex("b_grad", 0); const auto* matmul_grad_cache = CHECK_NOTNULL(dynamic_cast(cache)); auto* cuda_stream = ctx->stream()->As(); const DataType data_type = dy->data_type(); const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; const double alpha = 1.0; const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); const double beta = 0.0; const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); // currently only support 2D matmul. 
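// The single cublasLtMatmul below computes w_grad = dy^T * x; the CUBLASLT_EPILOGUE_BGRADB
// epilogue additionally reduces dy over the batch dimension into b_grad in the same call,
// so no separate bias-reduction kernel is needed (except for the cublas_k == 1 fallback below).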
DimVector dy_shape(2); dy->shape_view().ToDimVector(&dy_shape); DimVector x_shape(2); x->shape_view().ToDimVector(&x_shape); cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BGRADB; InferMatmulCublasMNK(dy_shape, x_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); if (cublas_k != 1) { SetCublasAttr( matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, b_grad->mut_dptr(), /*aux_ptr=*/nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); /* a = dy, b = x cublas_a=x, cublas_b=dy */ OF_CUBLAS_CHECK(cublasLtMatmul( cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, x->dptr(), matmul_grad_cache->cublas_a_desc, dy->dptr(), matmul_grad_cache->cublas_b_desc, &sp_beta, w_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, w_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); } else { // Cause cublasLtmatmul get wrong bias grad in cublas_k == 1. #if CUDA_VERSION >= 11000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; #else cublasGemmAlgo_t algo = (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DEFAULT; #endif cudaDataType_t gemm_compute_type = GetGemmComputeType(cuda_data_type); std::unique_ptr memcpy_primitive = ep::primitive::NewPrimitive( ctx->stream()->device_type(), ep::primitive::MemcpyKind::kDtoD); CHECK(memcpy_primitive); memcpy_primitive->Launch(ctx->stream(), b_grad->mut_dptr(), dy->dptr(), cublas_n * sizeof(T)); OF_CUBLAS_CHECK(cublasGemmEx( cuda_stream->cublas_handle(), CUBLAS_OP_N, CUBLAS_OP_T, cublas_m, cublas_n, cublas_k, &sp_alpha, x->dptr(), cuda_data_type, cublas_lda, dy->dptr(), cuda_data_type, cublas_ldb, &sp_beta, w_grad->mut_dptr(), cuda_data_type, cublas_ldc, gemm_compute_type, algo)); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUBLAS_MATMUL_BIAS_ADD_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("cublas_matmul_bias_add_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_CUBLAS_MATMUL_BIAS_ADD_GRAD_KERNEL(float) REGISTER_CUBLAS_MATMUL_BIAS_ADD_GRAD_KERNEL(double) REGISTER_CUBLAS_MATMUL_BIAS_ADD_GRAD_KERNEL(half) } // namespace } // namespace oneflow #endif // CUDA_VERSION >= 11060 ================================================ FILE: oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. #if CUDA_VERSION >= 11060 namespace oneflow { namespace { struct Comm { Comm(ncclComm_t comm) : comm(comm) {} ncclComm_t comm; }; class MatmulGradKernelState final : public user_op::OpKernelState { public: MatmulGradKernelState(user_op::KernelInitContext* ctx) : if_need_comm_(false), stream_name_(EagerNcclCommMgr::kDefaultStreamName) { OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_)); OF_CUDA_CHECK(cudaStreamCreate(&allreduce_stream_)); OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_)); workspace_size_ = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB", kDefaultWorkspaceSizeMb) * 1024 * 1024; OF_CUDA_CHECK(cudaMalloc(&workspace_, workspace_size_)); if (ctx->parallel_ctx().parallel_num() > 1) { parallel_conf_ = ctx->parallel_desc().parallel_conf(); } } ~MatmulGradKernelState() { OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); OF_CUBLAS_CHECK(cublasLtDestroy(cublas_lt_handle_)); OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_)); OF_CUDA_CHECK(cudaStreamSynchronize(allreduce_stream_)); OF_CUDA_CHECK(cudaStreamDestroy(allreduce_stream_)); OF_CUDA_CHECK(cudaFree(workspace_)); } cudaStream_t grad_cuda_stream() const { return cuda_stream_; } cudaStream_t allreduce_stream() const { return allreduce_stream_; } cublasLtHandle_t cublas_lt_handle() const { return cublas_lt_handle_; } size_t cublas_workspace_size() const { return workspace_size_; } void* cublas_workspace() const { return workspace_; } bool IfCommCreate() const { if (!comm_) { return false; } return true; } bool IfNeedComm() const { return if_need_comm_; } ncclComm_t comm() { return GetOrCreate().comm; } const Comm& GetOrCreate() { if (!comm_) { InitCommMgr(); } return *comm_; } void InitNeedComm(user_op::KernelInitContext* ctx) { if_need_comm_ = false; if (ctx->parallel_ctx().parallel_num() > 1) { const int64_t d_weights_size = ctx->output_size("d_weights"); if (ctx->SbpParallel4ArgNameAndIndex("d_weights", 0).has_broadcast_parallel()) { for (int i = 0; i < d_weights_size; i++) { CHECK(ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel()) << "All d_weight's SBP should be Broadcast. "; CHECK(ctx->SbpParallel4ArgNameAndIndex("d_biases", i).has_broadcast_parallel()) << "All d_bias's SBP should be Broadcast. 
"; } if (ctx->SbpParallel4ArgNameAndIndex("dy", 0).has_split_parallel()) { if_need_comm_ = true; } } } } void InitCommMgr() { std::set> device_set; const ParallelDesc parallel_desc(parallel_conf_); for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ncclComm_t comm; comm = comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); comm_.reset(new Comm(comm)); } private: cudaStream_t cuda_stream_{}; cudaStream_t allreduce_stream_{}; cublasLtHandle_t cublas_lt_handle_{}; void* workspace_{}; size_t workspace_size_; std::string stream_name_; std::unique_ptr comm_; bool if_need_comm_; ParallelConf parallel_conf_; }; template class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: CublasFusedMLPGradKernel() { OF_CUDA_CHECK(cudaEventCreate(&main_stream_event_)); OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event_)); OF_CUDA_CHECK(cudaEventCreate(&dweight_event_)); OF_CUDA_CHECK(cudaEventCreate(&allreduce_event_)); }; ~CublasFusedMLPGradKernel() override { OF_CUDA_CHECK(cudaEventDestroy(main_stream_event_)); OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event_)); OF_CUDA_CHECK(cudaEventDestroy(dweight_event_)); OF_CUDA_CHECK(cudaEventDestroy(allreduce_event_)); }; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCublasFusedMLPKernelCache(); } std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { std::shared_ptr kernel_state = std::make_shared(ctx); kernel_state->InitNeedComm(ctx); return kernel_state; } private: cudaEvent_t main_stream_event_; cudaEvent_t async_weight_grad_event_; cudaEvent_t dweight_event_; cudaEvent_t allreduce_event_; bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { auto* kernel_state = dynamic_cast(state); if (kernel_state->IfNeedComm()) { return kernel_state->IfCommCreate(); } else { return true; } } using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int64_t tmp_buf_elem_cnt = tmp_buffer->shape_view().elem_cnt(); const int64_t weight_num = ctx->input_size("weights"); user_op::Tensor* d_x = ctx->Tensor4ArgNameAndIndex("d_x", 0); const std::vector alpha_list = ctx->Attr>("alpha_list"); auto* kernel_state = dynamic_cast(state); const auto* matmul_grad_cache = CHECK_NOTNULL(dynamic_cast(cache)); ncclComm_t comm{}; bool if_need_comm = kernel_state->IfNeedComm(); if (if_need_comm) { comm = kernel_state->comm(); } void* dy_tmp_buf = tmp_buffer->mut_dptr(); size_t tmp_buf_offset = 0; auto* cuda_stream = ctx->stream()->As(); const DataType data_type = dy->data_type(); const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc 
= 0; const double alpha_one = 1.0; auto sp_alpha_one = GetCublasScalarParameter(alpha_one, cublas_compute_dtype); double alpha = 1.0; auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); double beta = 0.0; auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; // currently only support 2D matmul. DimVector weight_shape(2); DimVector hidden_shape(2); DimVector dy_shape(2); dy->shape_view().ToDimVector(&dy_shape); const void* dgrad_buf = dy->dptr(); const int64_t batch_size = dy->shape_view().At(0); const void* ones = nullptr; ep::CudaDevice* cuda_device = dynamic_cast(ctx->stream()->device()); CHECK_NOTNULL(cuda_device); ones = cuda_device->GetConstOnes(dy->data_type(), batch_size); if (ones == nullptr) { std::unique_ptr fill = ep::primitive::NewPrimitive(ctx->stream()->device_type(), data_type); CHECK(fill); fill->Launch(ctx->stream(), tmp_buffer->mut_dptr(), 1.0, batch_size); ones = tmp_buffer->mut_dptr(); tmp_buf_offset += GetCudaAlignedSize(batch_size * sizeof(T)); dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + tmp_buf_offset); } for (int idx = weight_num - 1; idx >= 0; idx--) { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); weight->shape_view().ToDimVector(&weight_shape); InferMatmulCublasMNK(dy_shape, weight_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); if (idx != 0) { alpha = alpha_list.at(idx - 1); sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, d_bias->mut_dptr(), aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); /* a = dy, b = weight cublas_a=weight, cublas_b=dy */ OF_CUDA_CHECK(cudaEventRecord(main_stream_event_, cuda_stream->cuda_stream())); OF_CUBLAS_CHECK(cublasLtMatmul( cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, matmul_grad_cache->cublas_b_desc, &sp_beta, dy_tmp_buf, matmul_grad_cache->cublas_c_desc, dy_tmp_buf, matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); } else { epilogue = CUBLASLT_EPILOGUE_DEFAULT; SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); /* a = dy, b = weight cublas_a=weight, cublas_b=dy */ OF_CUDA_CHECK(cudaEventRecord(main_stream_event_, cuda_stream->cuda_stream())); OF_CUBLAS_CHECK(cublasLtMatmul( cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha_one, weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, matmul_grad_cache->cublas_b_desc, &sp_beta, d_x->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_x->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), 
cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); } // step1: Get last layer's dbias. if (idx == weight_num - 1) { user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; ones_buf_shape.at(1) = batch_size; epilogue = CUBLASLT_EPILOGUE_DEFAULT; InferMatmulCublasMNK(ones_buf_shape, dy_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_)); OF_CUBLAS_CHECK(cublasLtMatmul( kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha_one, dgrad_buf, matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, &sp_beta, d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->grad_cuda_stream())); } user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); epilogue = CUBLASLT_EPILOGUE_DEFAULT; if (idx != 0) { const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here hidden->shape_view().ToDimVector(&hidden_shape); InferMatmulCublasMNK(dy_shape, hidden_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); if (idx != weight_num - 1) { // if idx == weight_num - 1, async_stream has wait main_stream_event_ in d_bias. OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_)); } OF_CUBLAS_CHECK(cublasLtMatmul( kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha_one, hidden->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->grad_cuda_stream())); OF_CUDA_CHECK(cudaEventRecord(dweight_event_, kernel_state->grad_cuda_stream())); // compute dy shape dy_shape.at(1) = weight_shape.at(1); // compute dybuf dgrad_buf = dy_tmp_buf; tmp_buf_offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T)); CHECK_LE(tmp_buf_offset, tmp_buf_elem_cnt) << "Tmp buffer offset should <= Tmp buffer elem_cnt. 
"; dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + tmp_buf_offset); } else { x->shape_view().ToDimVector(&hidden_shape); InferMatmulCublasMNK(dy_shape, hidden_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_)); OF_CUBLAS_CHECK(cublasLtMatmul( kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha_one, x->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->grad_cuda_stream())); OF_CUDA_CHECK(cudaEventRecord(dweight_event_, kernel_state->grad_cuda_stream())); } if (if_need_comm) { // Do Allreduce for d_bias and d_weight. // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event_)); OF_NCCL_CHECK(ncclGroupStart()); user_op::Tensor* allreduce_d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx); OF_NCCL_CHECK(ncclAllReduce(allreduce_d_bias->mut_dptr(), allreduce_d_bias->mut_dptr(), allreduce_d_bias->shape_view().elem_cnt(), GetNcclDataType(allreduce_d_bias->data_type()), ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), d_weight->shape_view().elem_cnt(), GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); OF_NCCL_CHECK(ncclGroupEnd()); if (idx == 0) { // We should sync allreduce before the kernel finish. 
OF_CUDA_CHECK(cudaEventRecord(allreduce_event_, kernel_state->allreduce_stream())); } } } if (if_need_comm) { OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event_)); } else { OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), dweight_event_)); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("cublas_fused_mlp_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const int64_t weight_num = ctx->input_size("weights"); \ const Shape& dy_shape = ctx->InputShape("dy", 0); \ int64_t m = dy_shape.At(0); \ int64_t k = dy_shape.At(1); \ int64_t tmp_buffer_size = 0; \ tmp_buffer_size += GetCudaAlignedSize(m * sizeof(dtype)); /*For last layer's bias grad*/ \ for (int idx = weight_num - 1; idx > 0; idx--) { \ const Shape& weight_shape = ctx->InputShape("weights", idx); \ k = weight_shape.At(1); \ tmp_buffer_size += GetCudaAlignedSize(m * k * sizeof(dtype)); \ } \ return tmp_buffer_size; \ }); REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(float) REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(double) REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(half) REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("cublas_fused_mlp_grad"); } // namespace } // namespace oneflow #endif // CUDA_VERSION >= 11060 ================================================ FILE: oneflow/user/kernels/cublas_fused_mlp_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. #if CUDA_VERSION >= 11060 namespace oneflow { namespace { template class CublasFusedMLPKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: CublasFusedMLPKernel() = default; ~CublasFusedMLPKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCublasFusedMLPKernelCache(); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { /* Fused DenseActivation Layer. Assume we have two layers: A: (m, k) B: (n, k) need transpose C: (j, n) need transpose tmp: A matmul B(transpose), its shape is (m, n) out: tmp matmul C(transpose), its shape is (m, j) */ const int32_t weight_size = ctx->input_size("weights"); const int32_t bias_size = ctx->input_size("biases"); CHECK_EQ(weight_size, bias_size) << "The number of weight and bias is not equal!. 
"; auto* cuda_stream = ctx->stream()->As(); const auto* matmul_cache = CHECK_NOTNULL(dynamic_cast(cache)); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); bool skip_final_activation = ctx->Attr("skip_final_activation"); const DataType data_type = out->data_type(); const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; const double alpha = 1.0; const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); const double beta = 0.0; const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); // Currently only support 2D matmul. DimVector in_shape(2); x->shape_view().ToDimVector(&in_shape); DimVector weight_shape(2); const void* in_buf_ptr = x->dptr(); for (int idx = 0; idx < weight_size; idx++) { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("biases", idx); user_op::Tensor* cublas_aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx); int64_t out_feature = weight->shape_view().At(0); weight->shape_view().ToDimVector(&weight_shape); InferMatmulCublasMNK(in_shape, weight_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_RELU_AUX_BIAS; bool need_aux = true; void* y_ptr = nullptr; if (idx == weight_size - 1) { y_ptr = ctx->Tensor4ArgNameAndIndex("out", 0)->mut_dptr(); if (skip_final_activation) { epilogue = CUBLASLT_EPILOGUE_BIAS; need_aux = false; } } else { y_ptr = ctx->Tensor4ArgNameAndIndex("hidden", idx)->mut_dptr(); } SetCublasAttr(matmul_cache, cublas_compute_dtype, cuda_data_type, need_aux, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, bias->dptr(), cublas_aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); OF_CUBLAS_CHECK(cublasLtMatmul( cuda_stream->cublas_lt_handle(), matmul_cache->operation_desc, &sp_alpha, weight->dptr(), matmul_cache->cublas_a_desc, in_buf_ptr, matmul_cache->cublas_b_desc, &sp_beta, y_ptr, matmul_cache->cublas_c_desc, y_ptr, matmul_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); // Set hidden_layer ptr as next layer's input. in_buf_ptr = y_ptr; // Set hidden_layer shape as next layer's input shape. 
in_shape.at(1) = out_feature; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(cpp_type, data_type) \ REGISTER_USER_KERNEL("cublas_fused_mlp") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == data_type)); REGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(double, DataType::kDouble); REGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(float, DataType::kFloat); REGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(half, DataType::kFloat16); REGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16); } // namespace } // namespace oneflow #endif // CUDA_VERSION >= 11060 ================================================ FILE: oneflow/user/kernels/cublas_fused_mlp_util.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #if defined(__CUDACC__) #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. #if CUDA_VERSION >= 11020 namespace oneflow { namespace { constexpr int32_t kAuxReluLdAlignRequirement = 128; constexpr size_t kDefaultWorkspaceSizeMb = 4; // 4M long AlignReluAuxLd(long aux_ld) { /* ReLu bit-mask matrix leading dimension in elements. Must be divisible by 128 and be no less than the number of rows in the output matrix. */ long old_aux_ld = aux_ld; return ((old_aux_ld + kAuxReluLdAlignRequirement - 1) / kAuxReluLdAlignRequirement) * kAuxReluLdAlignRequirement; } class CublasFusedMLPKernelCache final : public user_op::OpKernelCache { public: CublasFusedMLPKernelCache() { // Just for init. 
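// The descriptors below are created with placeholder 1x1 shapes and FP32 types; SetCublasAttr()
// overwrites every attribute (compute type, transposes, epilogue, bias/aux pointers, layouts)
// before each matmul, so only the handles themselves are cached here.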
OF_CUBLAS_CHECK(cublasLtMatmulDescCreate(&operation_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_a_desc, CUDA_R_32F, 1, 1, 1)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_b_desc, CUDA_R_32F, 1, 1, 1)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_c_desc, CUDA_R_32F, 1, 1, 1)); OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&cublas_preference)); } ~CublasFusedMLPKernelCache() override { OF_CUBLAS_CHECK(cublasLtMatmulDescDestroy(operation_desc)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_a_desc)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_b_desc)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_c_desc)); OF_CUBLAS_CHECK(cublasLtMatmulPreferenceDestroy(cublas_preference)); } cublasLtMatmulDesc_t operation_desc; cublasLtMatrixLayout_t cublas_a_desc; cublasLtMatrixLayout_t cublas_b_desc; cublasLtMatrixLayout_t cublas_c_desc; cublasLtMatmulPreference_t cublas_preference; }; std::shared_ptr CreateCublasFusedMLPKernelCache() { std::shared_ptr cache(new CublasFusedMLPKernelCache()); return cache; } Optional OptCudaDataType(DataType data_type) { switch (data_type) { case kFloat: return CUDA_R_32F; case kDouble: return CUDA_R_64F; case kFloat16: return CUDA_R_16F; case kBFloat16: return CUDA_R_16BF; default: return NullOpt; } } cudaDataType_t GetCudaDataType(DataType data_type) { auto cuda_data_type = OptCudaDataType(data_type); CHECK(cuda_data_type.has_value()); return cuda_data_type.value_or(CUDA_R_32F); } cublasComputeType_t GetComputeType(DataType data_type) { switch (data_type) { case kFloat: if (ParseBooleanFromEnv("ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION", true)) { return CUBLAS_COMPUTE_32F_FAST_TF32; } else { return CUBLAS_COMPUTE_32F; } case kDouble: return CUBLAS_COMPUTE_64F; case kFloat16: { const bool allow_half_accumulation = ParseBooleanFromEnv("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", false); if (allow_half_accumulation) { return CUBLAS_COMPUTE_16F; } else { return CUBLAS_COMPUTE_32F; } } case kBFloat16: return CUBLAS_COMPUTE_32F; default: UNIMPLEMENTED(); return CUBLAS_COMPUTE_32F; } } union CublasScalarParameter { double d; float s; half h; }; CublasScalarParameter GetCublasScalarParameter(Scalar scalar, cublasComputeType_t compute_type) { CublasScalarParameter sp{}; if (compute_type == CUBLAS_COMPUTE_64F) { sp.d = scalar.Value(); } else if (compute_type == CUBLAS_COMPUTE_32F || compute_type == CUBLAS_COMPUTE_32F_FAST_TF32) { sp.s = scalar.Value(); } else if (compute_type == CUBLAS_COMPUTE_16F) { sp.h = static_cast(scalar.Value()); } else { UNIMPLEMENTED(); } return sp; } void InferMatmulCublasMNK(const DimVector& a_shape, const DimVector& b_shape, ep::primitive::BlasTransposeType transpose_a, ep::primitive::BlasTransposeType transpose_b, size_t* cublas_m, size_t* cublas_n, size_t* cublas_k, int64_t* cublas_lda, int64_t* cublas_ldb, int64_t* cublas_ldc) { const int64_t num_a_axes = a_shape.size(); CHECK_GE(num_a_axes, 2); const int64_t num_b_axes = b_shape.size(); CHECK_GE(num_b_axes, 2); size_t m = 0, n = 0, k = 0; if (transpose_a == ep::primitive::BlasTransposeType::N) { m = a_shape.at(num_a_axes - 2); k = a_shape.at(num_a_axes - 1); *cublas_ldb = k; } else if (transpose_a == ep::primitive::BlasTransposeType::T) { m = a_shape.at(num_a_axes - 1); k = a_shape.at(num_a_axes - 2); *cublas_ldb = m; } else { UNIMPLEMENTED(); } if (transpose_b == ep::primitive::BlasTransposeType::N) { CHECK_EQ(b_shape.at(num_b_axes - 2), k); n = b_shape.at(num_b_axes - 1); *cublas_lda = n; } else if (transpose_b == 
ep::primitive::BlasTransposeType::T) { CHECK_EQ(b_shape.at(num_b_axes - 1), k); n = b_shape.at(num_b_axes - 2); *cublas_lda = k; } else { UNIMPLEMENTED(); } *cublas_m = n; *cublas_n = m; *cublas_k = k; *cublas_ldc = n; } void SetCublasMatrixLayout(cublasLtMatrixLayout_t layout_desc, cudaDataType_t cuda_data_type, cublasOperation_t cublas_trans, const size_t cublas_m1, const size_t cublas_n1, int64_t cublas_ld) { OF_CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(layout_desc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cuda_data_type, sizeof(cuda_data_type))); OF_CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute( layout_desc, CUBLASLT_MATRIX_LAYOUT_ROWS, cublas_trans == CUBLAS_OP_N ? &cublas_m1 : &cublas_n1, sizeof(cublas_m1))); OF_CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute( layout_desc, CUBLASLT_MATRIX_LAYOUT_COLS, cublas_trans == CUBLAS_OP_N ? &cublas_n1 : &cublas_m1, sizeof(cublas_m1))); OF_CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(layout_desc, CUBLASLT_MATRIX_LAYOUT_LD, &cublas_ld, sizeof(cublas_ld))); } void SetCublasEpilogue(const CublasFusedMLPKernelCache* matmul_cache, cublasLtEpilogue_t epilogue, const void* bias_ptr, const void* aux_ptr) { // Set epilogue OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute( matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); #if CUDA_VERSION >= 11060 const bool has_bias = (epilogue == CUBLASLT_EPILOGUE_RELU_BIAS || epilogue == CUBLASLT_EPILOGUE_BIAS || epilogue == CUBLASLT_EPILOGUE_RELU_AUX_BIAS || epilogue == CUBLASLT_EPILOGUE_DRELU_BGRAD || epilogue == CUBLASLT_EPILOGUE_BGRADB); #else const bool has_bias = (epilogue == CUBLASLT_EPILOGUE_RELU_BIAS || epilogue == CUBLASLT_EPILOGUE_BIAS); #endif // CUDA_VERSION >= 11060 if (has_bias) { // Set bias ptr OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); } else { // unset bias_ptr = nullptr; OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); } #if CUDA_VERSION >= 11060 if (epilogue == CUBLASLT_EPILOGUE_RELU_AUX_BIAS || epilogue == CUBLASLT_EPILOGUE_DRELU_BGRAD) { // Set aux ptr for backward. OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &aux_ptr, sizeof(aux_ptr))); } else { // Clear Aux ptr. 
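// The cached operation_desc is reused across calls with different epilogues, so an aux pointer
// left over from a previous RELU_AUX / DRELU setup must be explicitly reset to nullptr here.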
aux_ptr = nullptr; OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &aux_ptr, sizeof(aux_ptr))); } #endif // CUDA_VERSION >= 11060 } void SetCublasAttr(const CublasFusedMLPKernelCache* matmul_grad_cache, const cublasComputeType_t cublas_compute_dtype, const cudaDataType_t cuda_data_type, bool need_aux, ep::primitive::BlasTransposeType transpose_a, ep::primitive::BlasTransposeType transpose_b, cublasLtEpilogue_t epilogue, const void* d_bias_ptr, const void* aux_ptr, size_t cublas_m, size_t cublas_n, size_t cublas_k, int64_t cublas_lda, int64_t cublas_ldb, int64_t cublas_ldc) { OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute( matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_COMPUTE_TYPE, &cublas_compute_dtype, sizeof(cublas_compute_dtype))); size_t workspace_size = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB", kDefaultWorkspaceSizeMb) * 1024 * 1024; OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(matmul_grad_cache->cublas_preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); #if CUDA_VERSION < 12000 uint32_t pointer_mode = CUBLASLT_POINTER_MODE_MASK_HOST; OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(matmul_grad_cache->cublas_preference, CUBLASLT_MATMUL_PREF_POINTER_MODE_MASK, &pointer_mode, sizeof(pointer_mode))); #endif // CUDA_VERSION < 12000 // transpose_a = False, transpose_b = True. But in cublas is reversed. const cublasOperation_t cublas_trans_a = transpose_b == ep::primitive::BlasTransposeType::T ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t cublas_trans_b = transpose_a == ep::primitive::BlasTransposeType::T ? CUBLAS_OP_T : CUBLAS_OP_N; OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &cublas_trans_a, sizeof(cublas_trans_a))); OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &cublas_trans_b, sizeof(cublas_trans_b))); // Set epilogue SetCublasEpilogue(matmul_grad_cache, epilogue, d_bias_ptr, aux_ptr); /* Set AUX pointer LD If is used for CUBLASLT_EPILOGUE_DRELU_BGRAD, the AUX_LD need to align 128bit. If is used for CUBLASLT_EPILOGUE_DGELU_BGRAD, the AUX_LD need to align 8. For more details you can refer to CUBLAS docs: https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulDescAttributes_t `CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD`. 
*/ #if CUDA_VERSION >= 11060 if (need_aux) { long aligned_aux_ld = AlignReluAuxLd(cublas_ldc); OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aligned_aux_ld, sizeof(aligned_aux_ld))); } else { long no_need_aligned_aux_ld = 0; OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute( matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &no_need_aligned_aux_ld, sizeof(no_need_aligned_aux_ld))); } #endif // CUDA_VERSION >= 11060 // Set matrix layout SetCublasMatrixLayout(matmul_grad_cache->cublas_a_desc, cuda_data_type, cublas_trans_a, cublas_m, cublas_k, cublas_lda); SetCublasMatrixLayout(matmul_grad_cache->cublas_b_desc, cuda_data_type, cublas_trans_b, cublas_k, cublas_n, cublas_ldb); SetCublasMatrixLayout(matmul_grad_cache->cublas_c_desc, cuda_data_type, CUBLAS_OP_N, cublas_m, cublas_n, cublas_ldc); } } // namespace } // namespace oneflow #endif // CUDA_VERSION >= 11020 #endif // defined(__CUDACC__) ================================================ FILE: oneflow/user/kernels/cufft_plan_cache.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_ #define ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_ #include #include #include #include #include #include #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/kernel/kernel.h" namespace oneflow { namespace { constexpr int max_rank = 3; enum class CUFFT_EXCUTETYPE { R2C, C2C, C2R }; struct CuFFTDataTypeDesc { cudaDataType inputtype; cudaDataType outputtype; cudaDataType executiontype; }; } // namespace class CuFFTHandle { cufftHandle handle; public: CuFFTHandle() { OF_CUFFT_CHECK(cufftCreate(&handle)); } cufftHandle& get() { return handle; } const cufftHandle& get() const { return handle; } ~CuFFTHandle() { cufftDestroy(handle); } }; // NOTE: The implementation of `CuFFTDataLayout`, `cufft_simple_embed` and `as_cufft_embed` are // mostly taken from pytorch. For more details pls refer to `CuFFTPlanCache.h` in PyTorch. typedef long long cufft_size_type; typedef small_vector cufft_dim_vector; struct CuFFTDataLayout { small_vector embed; cufft_size_type stride, dist; bool must_clone, simple; }; // Returns a cufft embedding for a contiguous signal of the given size. // e.g. 
if the input is cloned, this will be the resulting data layout inline CuFFTDataLayout cufft_simple_embed(const cufft_dim_vector& sizes, bool onesided) { CuFFTDataLayout layout; layout.simple = true; layout.must_clone = false; layout.embed.assign(sizes.cbegin() + 1, sizes.cend()); if (onesided) { layout.embed.back() = sizes.back() / 2 + 1; } layout.stride = 1; layout.dist = 1; for (const auto& len : layout.embed) { layout.dist *= len; } return layout; } // Convert strides to a CuFFT embedded representation. // If strides cannot be embedded, returns a simple layout and sets must_clone flag inline CuFFTDataLayout as_cufft_embed(const cufft_dim_vector& strides, const cufft_dim_vector& sizes, bool onesided) { const auto signal_ndim = strides.size() - 1; CuFFTDataLayout layout; auto last_stride = strides[signal_ndim]; layout.must_clone = (last_stride <= 0); const auto last_dim_size = onesided ? sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim]; const auto signal_numel = std::accumulate(sizes.begin() + 1, sizes.end() - 1, (cufft_size_type)1, std::multiplies()) * last_dim_size; // Zero stides are not allowed, even if the batch size is one. // If that happens just set a dummy case if (sizes[0] == 1) { layout.dist = signal_numel; } else if (strides[0] == 0) { layout.must_clone = true; } else { layout.dist = strides[0]; } // Calculate the embedding shape, or set must_clone if the strides cannot be embedded layout.embed.resize(signal_ndim); for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) { auto stride = strides[i]; if (sizes[i] == 1) { layout.embed[i] = 1; } else if (stride > 0 && stride % last_stride == 0) { layout.embed[i] = stride / last_stride; last_stride = stride; } else { layout.must_clone = true; } } // must_clone == false if (layout.must_clone) { // If the input needs to be cloned, assume it will be contiguous layout = cufft_simple_embed(sizes, onesided); layout.must_clone = true; } else { layout.embed[0] = sizes[1]; layout.stride = strides[signal_ndim]; // Determine if layout represents a simple embedding (contiguous data) layout.simple = [&] { FOR_RANGE(int, i, 1, signal_ndim - 1) { if (layout.embed[i] != sizes[i + 1]) { return false; } } return (layout.stride == 1 && layout.dist == signal_numel && layout.embed.back() == last_dim_size); }(); } return layout; } struct CuFFTParams { int64_t ndim; cufft_dim_vector input_shape; cufft_dim_vector input_strides; cufft_dim_vector output_shape; cufft_dim_vector output_strides; cufft_dim_vector data_shape; CUFFT_EXCUTETYPE excute_type; DataType real_data_type; CuFFTParams() = default; CuFFTParams(const Shape& in_shape, const Shape& out_shape, const Stride& in_strides, const Stride& out_strides, int64_t dims, CUFFT_EXCUTETYPE type, DataType real) : ndim(dims), excute_type(type), real_data_type(real) { CHECK_OR_THROW(ndim >= 1 && ndim <= max_rank); CHECK_OR_THROW(in_shape.size() == ndim + 1); CHECK_OR_THROW(out_shape.size() == ndim + 1); CHECK_OR_THROW(in_shape.size() == in_strides.size()); CHECK_OR_THROW(out_shape.size() == out_strides.size()); data_shape.resize(ndim + 1); input_shape.resize(in_shape.size()); input_strides.resize(in_strides.size()); output_shape.resize(out_shape.size()); output_strides.resize(out_strides.size()); std::copy(in_strides.begin(), in_strides.end(), input_strides.begin()); std::copy(out_strides.begin(), out_strides.end(), output_strides.begin()); std::copy(in_shape.begin(), in_shape.end(), input_shape.begin()); std::copy(out_shape.begin(), out_shape.end(), output_shape.begin()); data_shape[0] = 
input_shape[0]; // batch size FOR_RANGE(int64_t, i, 0, ndim) { auto in_size = input_shape[i + 1]; auto out_size = output_shape[i + 1]; data_shape[i + 1] = std::max(in_size, out_size); CHECK_OR_THROW(in_size == data_shape[i + 1] || in_size == (data_shape[i + 1] / 2) + 1); CHECK_OR_THROW(out_size == data_shape[i + 1] || out_size == (data_shape[i + 1] / 2) + 1); } } }; class CuFFTConfig { public: CuFFTConfig(const CuFFTConfig&) = delete; CuFFTConfig& operator=(CuFFTConfig const&) = delete; ~CuFFTConfig() = default; explicit CuFFTConfig(CuFFTParams& params) { // NOLINT if (params.real_data_type == kBFloat16 || params.real_data_type == kFloat16) { // CuFFT support half data type, but there are some limits: // https://docs.nvidia.com/cuda/cufft/#half-precision-cufft-transforms CHECK_OR_THROW(false) << "Unsupported datatype kBFloat16 and kFloat16."; } CuFFTDataLayout input_layout = as_cufft_embed(params.input_strides, params.data_shape, params.excute_type == CUFFT_EXCUTETYPE::C2R); CuFFTDataLayout output_layout = as_cufft_embed(params.output_strides, params.data_shape, params.excute_type == CUFFT_EXCUTETYPE::R2C); bool clone_input = input_layout.must_clone; // that means: input should be contiguous because // original input can't be embeded const bool is_layout_simple = input_layout.simple && output_layout.simple; // disable cuFFT the default behavior of allocating work area at plan generating time OF_CUFFT_CHECK(cufftSetAutoAllocation(plan_handle_.get(), 0)); infer_cufft_type_(params.excute_type, params.real_data_type); // exclude input_shape[0] whtich is batch dim cufft_dim_vector fft_shape(params.data_shape.begin() + 1, params.data_shape.end()); cufft_size_type batch = params.data_shape[0]; if (is_layout_simple) { OF_CUFFT_CHECK(cufftXtMakePlanMany(plan_handle_.get(), params.ndim, fft_shape.data(), /*inembed=*/nullptr, /*istride=*/1, /*idist=*/1, /*inputtype=*/data_type_desc_.inputtype, /*onembed=*/nullptr, /*ostride=*/1, /*odist=*/1, /*outputtype=*/data_type_desc_.outputtype, /*batch=*/batch, /*workSize=*/&work_size_, /*executiontype=*/data_type_desc_.executiontype)); } else { OF_CUFFT_CHECK(cufftXtMakePlanMany( plan_handle_.get(), params.ndim, fft_shape.data(), /*inembed=*/input_layout.embed.data(), /*istride=*/input_layout.stride, /*idist=*/input_layout.dist, /*inputtype=*/data_type_desc_.inputtype, /*onembed=*/output_layout.embed.data(), /*ostride=*/output_layout.stride, /*odist=*/output_layout.dist, /*outputtype=*/data_type_desc_.outputtype, /*batch=*/batch, /*workSize=*/&work_size_, /*executiontype=*/data_type_desc_.executiontype)); } } size_t workspace_size() const { return work_size_; } const cufftHandle& plan() const { return plan_handle_.get(); } void excute(void* input, void* output, bool forward) { OF_CUFFT_CHECK( cufftXtExec(plan_handle_.get(), input, output, forward ? CUFFT_FORWARD : CUFFT_INVERSE)); } private: void infer_cufft_type_(CUFFT_EXCUTETYPE excute_type, DataType real_data_type) { if (real_data_type == kFloat) { data_type_desc_.executiontype = CUDA_C_32F; data_type_desc_.inputtype = excute_type == CUFFT_EXCUTETYPE::R2C ? CUDA_R_32F : CUDA_C_32F; data_type_desc_.outputtype = excute_type == CUFFT_EXCUTETYPE::C2R ? CUDA_R_32F : CUDA_C_32F; } else if (real_data_type == kDouble) { data_type_desc_.executiontype = CUDA_C_64F; data_type_desc_.inputtype = excute_type == CUFFT_EXCUTETYPE::R2C ? CUDA_R_64F : CUDA_C_64F; data_type_desc_.outputtype = excute_type == CUFFT_EXCUTETYPE::C2R ? 
CUDA_R_64F : CUDA_C_64F; } else { CHECK_OR_THROW(false) << "cuFFT doesn't support type " << real_data_type; } } CuFFTHandle plan_handle_; CuFFTDataTypeDesc data_type_desc_; size_t work_size_; }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_ ================================================ FILE: oneflow/user/kernels/cum_backward_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { // CumProd backward, formula: flip(cumsum(flip(dY * Y))) / X. template void CumProdBackward(const T* dy_ptr, T* dx_ptr, const T* output_ptr, const T* input_ptr, const int64_t up_space, const int64_t space, const int64_t down_space, const int64_t elem_cnt) { const auto step = space * down_space; for (size_t i = 0; i < up_space; i++) { const size_t base_ptr_offset = step * i; const T* input_ptr_base = input_ptr + base_ptr_offset; const T* output_ptr_base = output_ptr + base_ptr_offset; const T* dy_ptr_base = dy_ptr + base_ptr_offset; T* dx_ptr_base = dx_ptr + base_ptr_offset; // Use dx as tmp buffer for finding 0 element in the input. for (size_t j = 0; j < space; j++) { const size_t ptr_offset = j * down_space; auto* cur_input_ptr = input_ptr_base + ptr_offset; auto* cumsum_zeros_number_ptr = dx_ptr_base + ptr_offset; auto* last_cumsum_zeros_number_ptr = cumsum_zeros_number_ptr - down_space; for (size_t k = 0; k < down_space; k++) { int is_zero = cur_input_ptr[k] == 0 ? 1 : 0; cumsum_zeros_number_ptr[k] = is_zero + (j == 0 ? 0 : last_cumsum_zeros_number_ptr[k]); } } for (size_t j = 0; j < down_space; j++) { const auto* cur_output_ptr = output_ptr_base + j; const auto* cur_input_ptr = input_ptr_base + j; const auto* cur_dy_ptr = dy_ptr_base + j; auto* cur_dx_ptr = dx_ptr_base + j; const auto* cumsum_zeros_number_ptr = dx_ptr_base + j; size_t first_zero_index = space; // Find index of first zero in input. for (size_t k = 0; k < space; k++) { if (cumsum_zeros_number_ptr[k * down_space] == 1) { first_zero_index = k; break; } } // Suppose z is index of first zero element in input, // for element which index is less than z grad is computed as below: T reverse_cumsum = 0; for (size_t k = 0; k < first_zero_index; k++) { const size_t data_offset = (first_zero_index - k - 1) * down_space; reverse_cumsum += cur_output_ptr[data_offset] * cur_dy_ptr[data_offset]; cur_dx_ptr[data_offset] = reverse_cumsum / cur_input_ptr[data_offset]; } // For where index is z, its grad is computed as below: if (first_zero_index == space) { continue; } T cumprod = 1; T cumsum = 0; T cumprod_before_first_zero = first_zero_index == 0 ? 
1 : cur_output_ptr[(first_zero_index - 1) * down_space]; for (size_t k = first_zero_index; k < space; k++) { const size_t data_offset = k * down_space; // Recover dx_ptr default value if (cur_dx_ptr[data_offset] >= 1) { cur_dx_ptr[data_offset] = 0; } if (k != first_zero_index) { cumprod *= cur_input_ptr[data_offset]; } cumsum += cumprod_before_first_zero * cumprod * cur_dy_ptr[data_offset]; } cur_dx_ptr[first_zero_index * down_space] = cumsum; } } } } // namespace template class CpuCumProdGradKernel final : public user_op::OpKernel { public: CpuCumProdGradKernel() = default; ~CpuCumProdGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const auto* output = ctx->Tensor4ArgNameAndIndex("output", 0); const auto* input = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t elem_cnt = dy->shape_view().elem_cnt(); if (elem_cnt == 0) { return; } const auto* output_ptr = output->dptr(); const auto* input_ptr = input->dptr(); const auto* dy_ptr = dy->dptr(); auto* dx_ptr = dx->mut_dptr(); // data partition: up_space|space|down_space auto dim = ctx->Attr("dim"); auto up_space = elem_cnt / dx->shape_view().Count(dim); auto space = dx->shape_view().At(dim); auto down_space = dx->shape_view().Count(dim + 1); if (space == 1) { Memcpy(ctx->stream(), dx_ptr, dy_ptr, elem_cnt * sizeof(T)); return; } CumProdBackward(dy_ptr, dx_ptr, output_ptr, input_ptr, up_space, space, down_space, elem_cnt); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_CUMPROD_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("cumprod_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_CPU_CUMPROD_GRAD_KERNEL(float) REGISTER_CPU_CUMPROD_GRAD_KERNEL(double) #undef REGISTER_CPU_CUMPROD_GRAD_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/cum_backward_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { #ifdef WITH_CUDA namespace { template __global__ void CumProdBackward(const T* dy_ptr, T* dx_ptr, const T* output_ptr, const T* input_ptr, const int64_t up_space, const int64_t space, const int64_t down_space, const int64_t thread_num) { // A thread is responsible for a row along specific dimension. 
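// thread_num == up_space * down_space: thread i owns the row at (up_space_id, down_space_id)
// and walks that row's `space` elements with stride down_space.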
const size_t up_space_step = space * down_space; CUDA_1D_KERNEL_LOOP_T(size_t, i, thread_num) { const size_t up_space_id = i / down_space; const size_t down_space_id = i % down_space; const size_t ptr_offset = up_space_id * up_space_step + down_space_id; auto* dy_ptr_base = dy_ptr + ptr_offset; auto* dx_ptr_base = dx_ptr + ptr_offset; auto* input_ptr_base = input_ptr + ptr_offset; auto* output_ptr_base = output_ptr + ptr_offset; // Buffer storing number of zero element along specific dimension. // Use dx as tmp buffer. for (size_t j = 0; j < space; j++) { const size_t data_offset = j * down_space; int is_zero = input_ptr_base[data_offset] == 0 ? 1 : 0; dx_ptr_base[data_offset] = is_zero + (j == 0 ? 0 : dx_ptr_base[data_offset - down_space]); } // Find index of first zero in input. size_t first_zero_index = space; for (size_t j = 0; j < space; j++) { const size_t data_offset = j * down_space; if (dx_ptr_base[data_offset] == 1) { first_zero_index = j; break; } } // Suppose z is index of first zero element in input, // for element which index is less than z grad is computed as below: T reverse_cumsum = 0; for (size_t j = 0; j < first_zero_index; j++) { const size_t cur_index = first_zero_index - j - 1; const size_t data_offset = cur_index * down_space; reverse_cumsum += output_ptr_base[data_offset] * dy_ptr_base[data_offset]; dx_ptr_base[data_offset] = reverse_cumsum / input_ptr_base[data_offset]; } // Where index is z, its grad is computed as below: if (first_zero_index == space) { return; } T cumprod = 1; T cumsum = 0; T cumprod_before_first_zero = first_zero_index == 0 ? 1 : output_ptr_base[(first_zero_index - 1) * down_space]; for (size_t j = first_zero_index; j < space; j++) { const size_t down_space_offset = j * down_space; // Recover dx_ptr default value if (dx_ptr_base[down_space_offset] >= 1) { dx_ptr_base[down_space_offset] = 0; } if (j != first_zero_index) { cumprod *= input_ptr_base[down_space_offset]; } cumsum += cumprod_before_first_zero * dy_ptr_base[down_space_offset] * cumprod; } dx_ptr_base[first_zero_index * down_space] = cumsum; } } } // namespace template class GpuCumProdGradKernel final : public user_op::OpKernel { public: GpuCumProdGradKernel() = default; ~GpuCumProdGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* output = ctx->Tensor4ArgNameAndIndex("output", 0); const auto* input = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto elem_cnt = dy->shape_view().elem_cnt(); if (!elem_cnt) { return; } const auto* output_ptr = output->dptr(); const auto* input_ptr = input->dptr(); const auto* dy_ptr = dy->dptr(); auto* dx_ptr = dx->mut_dptr(); // Data partition: up_space|space|down_space auto dim = ctx->Attr("dim"); const auto up_space = elem_cnt / dx->shape_view().Count(dim); const auto space = dx->shape_view().At(dim); const auto down_space = dx->shape_view().Count(dim + 1); const size_t thread_num = up_space * down_space; if (space == 1) { Memcpy(ctx->stream(), dx_ptr, dy_ptr, elem_cnt * sizeof(T)); return; } ep::CudaLaunchConfig config{}; ctx->stream()->As()->InitLaunchConfigWithWaves( &config, thread_num, /*DefaultBlockSize*/ 256, /*max_wave*/ 1); CumProdBackward<<stream()->As()->cuda_stream()>>>( dy_ptr, dx_ptr, output_ptr, input_ptr, up_space, space, down_space, thread_num); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define 
REGISTER_CUDA_CUMPROD_GRAD_KERNEL(dtype)                                          \
  REGISTER_USER_KERNEL("cumprod_grad")                                            \
      .SetCreateFn<GpuCumProdGradKernel<dtype>>()                                 \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)            \
                       && (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));

REGISTER_CUDA_CUMPROD_GRAD_KERNEL(float)
REGISTER_CUDA_CUMPROD_GRAD_KERNEL(double)
#undef REGISTER_CUDA_CUMPROD_GRAD_KERNEL

#endif

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/cum_forward_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/ndarray/binary_func.h"

namespace oneflow {

namespace {

template<typename T, template<typename> class BinaryFunc>
void CumForward(const T* in_ptr, T* out_ptr, int64_t up_space, int64_t space, int64_t down_space,
                int64_t elem_cnt) {
  std::copy_n(in_ptr, elem_cnt, out_ptr);
  auto* tmp_out_ptr_base = out_ptr;
  auto step = space * down_space;
  for (auto i = 0; i < up_space; i++) {
    for (auto j = 1; j < space; j++) {
      auto* tmp_out_ptr = tmp_out_ptr_base + j * down_space;
      auto* last_tmp_out_ptr = tmp_out_ptr - down_space;
      for (auto k = 0; k < down_space; k++) {
        tmp_out_ptr[k] = BinaryFunc<T>::Invoke(tmp_out_ptr[k], last_tmp_out_ptr[k]);
      }
    }
    tmp_out_ptr_base += step;
  }
}

}  // namespace

template<typename T, template<typename> class BinaryFunc>
class CpuCumKernel : public user_op::OpKernel {
 public:
  CpuCumKernel() = default;
  ~CpuCumKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const auto* in = ctx->Tensor4ArgNameAndIndex("x", 0);
    auto elem_cnt = in->shape_view().elem_cnt();
    // judge whether tensor has 0 size dimension first
    if (!elem_cnt) { return; }
    auto* out = ctx->Tensor4ArgNameAndIndex("y", 0);
    auto dim = ctx->Attr<int64_t>("dim");
    const auto* in_ptr = in->dptr<T>();
    auto* out_ptr = out->mut_dptr<T>();
    // data partition: up_space|space|down_space
    auto up_space = elem_cnt / in->shape_view().Count(dim);
    auto space = in->shape_view().At(dim);
    auto down_space = in->shape_view().Count(dim + 1);
    CumForward<T, BinaryFunc>(in_ptr, out_ptr, up_space, space, down_space, elem_cnt);
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define CUMOP_SEQ                                 \
  OF_PP_MAKE_TUPLE_SEQ("cumprod", BinaryFuncMul)  \
  OF_PP_MAKE_TUPLE_SEQ("cumsum", BinaryFuncAdd)

#define REGISTER_CUMOP_KERNEL(dtype, op_name, op_functor)                                     \
  REGISTER_USER_KERNEL(op_name)                                                               \
      .SetCreateFn<CpuCumKernel<dtype, op_functor>>()                                         \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                         \
                       && (user_op::HobDataType("y", 0) == GetDataType<dtype>::value));

#define REGISTER_CUMOP_KERNEL_WITH_DTYPE(op_name, op_functor)  \
  REGISTER_CUMOP_KERNEL(int32_t, op_name, op_functor)          \
  REGISTER_CUMOP_KERNEL(int64_t, op_name, op_functor)          \
  REGISTER_CUMOP_KERNEL(float, op_name, op_functor)            \
  REGISTER_CUMOP_KERNEL(double, op_name, op_functor)

OF_PP_FOR_EACH_TUPLE(REGISTER_CUMOP_KERNEL_WITH_DTYPE, CUMOP_SEQ);

#undef REGISTER_CUMOP_KERNEL
#undef REGISTER_CUMOP_KERNEL_WITH_DTYPE
#undef CUMOP_SEQ

}  // namespace oneflow
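The up_space|space|down_space partition used by CumForward above treats the scanned axis as
"space", the product of the leading axes as "up_space", and the product of the trailing axes as
"down_space". A minimal standalone sketch of the same traversal, specialized to cumsum, is shown
below; it is illustrative only, and the helper name and shapes are invented rather than taken
from the repository.

// cum_forward_sketch.cpp (illustrative, not part of the OneFlow sources)
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Same traversal as CumForward above, with BinaryFunc fixed to addition.
void CumSumAlongDim(const float* in, float* out, int64_t up_space, int64_t space,
                    int64_t down_space, int64_t elem_cnt) {
  std::copy(in, in + elem_cnt, out);
  const int64_t step = space * down_space;
  for (int64_t i = 0; i < up_space; i++) {
    float* base = out + i * step;
    for (int64_t j = 1; j < space; j++) {
      for (int64_t k = 0; k < down_space; k++) {
        base[j * down_space + k] += base[(j - 1) * down_space + k];
      }
    }
  }
}

int main() {
  // A (2, 3, 2) tensor of ones scanned along dim = 1:
  // up_space = 2, space = 3, down_space = 2.
  std::vector<float> x(12, 1.0f), y(12, 0.0f);
  CumSumAlongDim(x.data(), y.data(), 2, 3, 2, 12);
  for (float v : y) { std::printf("%.0f ", v); }  // prints: 1 1 2 2 3 3 1 1 2 2 3 3
  std::printf("\n");
  return 0;
}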
================================================ FILE: oneflow/user/kernels/cum_forward_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ndarray/binary_func.h" namespace oneflow { #ifdef WITH_CUDA namespace { template inline T CeilDiv(T n, T m) { return (n + m - 1) / m; } template struct SumFunctor { __device__ __forceinline__ T operator()(const T a, const T b) const { return a + b; } }; template struct ProdFunctor { __device__ __forceinline__ T operator()(const T a, const T b) const { return a * b; } }; template class BinaryFunc> size_t InferTmpBufferSize(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("x", 0); const int64_t dim = ctx->Attr("dim"); const size_t dim_size = in_shape.At(dim); if (in_shape.elem_cnt() == dim_size) { size_t temp_storage_bytes = 0; OF_CUDA_CHECK(cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes, static_cast(nullptr), static_cast(nullptr), BinaryFunc(), dim_size)); return GetCudaAlignedSize(temp_storage_bytes); } return 0; } // total thread number: cs_up_space * cs_down_space // in cs_down_space part, use cs_down_space threads // to calculate as follows(m=cs_down_space-1, n=cs_space-1, '|' stands for dependency): template class BinaryFunc> __global__ void CumForwardGpu(const T* in_ptr, T* out_ptr, int64_t cs_up_space, int64_t cs_space, int64_t cs_down_space) { CUDA_1D_KERNEL_LOOP(i, cs_up_space * cs_down_space) { auto cs_up_space_id = i / cs_down_space; auto cs_down_space_id = i - (i / cs_down_space) * cs_down_space; auto* in_ptr_base = in_ptr + cs_up_space_id * cs_space * cs_down_space + cs_down_space_id; auto* out_ptr_base = out_ptr + cs_up_space_id * cs_space * cs_down_space + cs_down_space_id; // calculate cs_space data in one thread for (auto j = 0; j < cs_space; j++) { auto idx = j * cs_down_space; out_ptr_base[idx] = in_ptr_base[idx]; if (j != 0) { out_ptr_base[idx] = BinaryFunc()(out_ptr_base[idx], out_ptr_base[idx - cs_down_space]); } } } } template class BinaryFunc> void ScanOuterDim(ep::Stream* ep_stream, const ShapeView& in_shape, int64_t dim, const T* in_ptr, T* out_ptr) { // data partition: up_space|space|down_space auto up_space = in_shape.elem_cnt() / in_shape.Count(dim); auto space = in_shape.At(dim); auto down_space = in_shape.Count(dim + 1); auto thread_num = up_space * down_space; RUN_CUDA_KERNEL((CumForwardGpu), ep_stream, thread_num, in_ptr, out_ptr, up_space, space, down_space); } // Refer from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/ScanKernels.cu template class BinaryFunc> __device__ void ScanInnerMostDimKernelImpl(T* row_buf, T* src_, T* tgt_, const uint32_t num_rows, const uint32_t row_size, T init) { for (uint32_t block_row = blockIdx.x * blockDim.y; block_row < 
num_rows; block_row += blockDim.y * gridDim.x) { uint32_t row = block_row + threadIdx.y; T block_total = init; T* row_src = src_ + row * row_size; T* row_tgt = tgt_ + row * row_size; // Perform scan on one block at a time, keeping track of the total value of // all blocks processed so far. for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { // Load data into shared memory (two values per thread). uint32_t col1 = block_col + threadIdx.x; uint32_t col2 = block_col + num_threads_x + threadIdx.x; if (row < num_rows) { if (col1 < row_size) { row_buf[threadIdx.x] = row_src[col1]; } else { row_buf[threadIdx.x] = init; } if (col2 < row_size) { row_buf[num_threads_x + threadIdx.x] = row_src[col2]; } else { row_buf[num_threads_x + threadIdx.x] = init; } // Add the total value of all previous blocks to the first value of this block. if (threadIdx.x == 0) { row_buf[0] = BinaryFunc()(row_buf[0], block_total); } } __syncthreads(); for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { if (row < num_rows && threadIdx.x < s) { uint32_t offset = (2 * threadIdx.x + 1) * d - 1; row_buf[offset + d] = BinaryFunc()(row_buf[offset], row_buf[offset + d]); } __syncthreads(); } for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { if (row < num_rows && threadIdx.x < s - 1) { uint32_t offset = 2 * (threadIdx.x + 1) * d - 1; row_buf[offset + d] = BinaryFunc()(row_buf[offset], row_buf[offset + d]); } __syncthreads(); } // Write back to output. if (row < num_rows) { if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x]; if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x]; } block_total = row_buf[2 * num_threads_x - 1]; __syncthreads(); } } } template class BinaryFunc> __global__ void ScanInnerMostDimKernel(const T* in_ptr, T* out_ptr, const int64_t num_rows, const int64_t row_size, T init) { __shared__ T sbuf[num_threads_y][2 * num_threads_x]; T* row_buf = sbuf[threadIdx.y]; ScanInnerMostDimKernelImpl( row_buf, const_cast(in_ptr), out_ptr, num_rows, row_size, init); } template class BinaryFunctor> void ScanInnerMostDim(const T* in_ptr, T* out_ptr, const int64_t num_rows, const int64_t row_size, const ep::CudaStream* cuda_stream) { dim3 block(16, 32); const int64_t max_grid_dim = cuda_stream->device()->properties().maxGridSize[0]; dim3 grid(std::min(max_grid_dim, CeilDiv(num_rows, (int64_t)block.y))); if (std::is_same, SumFunctor>::value) { ScanInnerMostDimKernel <<cuda_stream()>>>(in_ptr, out_ptr, num_rows, row_size, /*init*/ 0); } else if (std::is_same, ProdFunctor>::value) { ScanInnerMostDimKernel <<cuda_stream()>>>(in_ptr, out_ptr, num_rows, row_size, /*init*/ 1); } else { UNIMPLEMENTED() << "Only Support cumsum and cumprod for now."; } } template class BinaryFunc> void CubInclusiveScan(user_op::Tensor* temp_buffer, const T* in_ptr, T* out_ptr, int64_t elem_cnt, const ep::CudaStream* cuda_stream) { auto* temp_storage = temp_buffer->mut_dptr(); size_t temp_storage_bytes = temp_buffer->shape_view().elem_cnt(); OF_CUDA_CHECK(cub::DeviceScan::InclusiveScan(temp_storage, temp_storage_bytes, in_ptr, out_ptr, BinaryFunc(), elem_cnt, cuda_stream->cuda_stream())); } } // namespace template class BinaryFunc> class GpuCumKernel : public user_op::OpKernel { public: GpuCumKernel() = default; ~GpuCumKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* in = ctx->Tensor4ArgNameAndIndex("x", 0); auto* out = ctx->Tensor4ArgNameAndIndex("y", 0); const 
ShapeView& in_shape = in->shape_view(); const int64_t dim = ctx->Attr("dim"); const int64_t dim_size = in_shape.At(dim); // Judge whether tensor has 0 size dimension first. auto elem_cnt = in_shape.elem_cnt(); if (!elem_cnt) { return; } const auto* in_ptr = in->dptr(); auto* out_ptr = out->mut_dptr(); const auto* cuda_stream = ctx->stream()->As(); if (elem_cnt == dim_size) { auto* temp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); CubInclusiveScan(temp_buffer, in_ptr, out_ptr, elem_cnt, cuda_stream); } else if (dim == in_shape.NumAxes() - 1) { // Treat all outer dimension as a single dimension. const int64_t num_rows = elem_cnt / dim_size; ScanInnerMostDim(in_ptr, out_ptr, num_rows, dim_size, cuda_stream); } else { ScanOuterDim(ctx->stream(), in_shape, dim, in_ptr, out_ptr); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define CUMOP_SEQ \ OF_PP_MAKE_TUPLE_SEQ("cumprod", ProdFunctor) \ OF_PP_MAKE_TUPLE_SEQ("cumsum", SumFunctor) #define REGISTER_CUMOP_KERNEL(dtype, op_name, op_functor) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferTmpBufferSize); #define REGISTER_CUMOP_KERNEL_WITH_DTYPE(op_name, op_functor) \ REGISTER_CUMOP_KERNEL(int32_t, op_name, op_functor) \ REGISTER_CUMOP_KERNEL(int64_t, op_name, op_functor) \ REGISTER_CUMOP_KERNEL(float, op_name, op_functor) \ REGISTER_CUMOP_KERNEL(double, op_name, op_functor) \ REGISTER_CUMOP_KERNEL(half, op_name, op_functor) OF_PP_FOR_EACH_TUPLE(REGISTER_CUMOP_KERNEL_WITH_DTYPE, CUMOP_SEQ); #undef REGISTER_CUMOP_KERNEL #undef REGISTER_CUMOP_KERNEL_WITH_DTYPE #undef CUMOP_SEQ #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/cutlass_conv_tuner.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUTLASS #include "oneflow/user/kernels/cutlass_conv_tuner.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/job/lazy_mode.h" #include #include #include namespace oneflow { namespace { bool IsWeakerAlginOperation(const cutlass::library::Operation* lhs, const cutlass::library::Operation* rhs) { const char* lhs_name = lhs->description().name; const char* rhs_name = rhs->description().name; const size_t len = std::strlen(lhs_name); const size_t suffix_len = std::strlen("align8"); if (std::strlen(rhs_name) != len) { return false; } if (len < suffix_len) { return false; } const size_t prefix_len = len - suffix_len; if (std::strncmp(lhs_name, rhs_name, prefix_len) != 0) { return false; } const auto& HasLegalSuffix = [&](const char* str) { if (std::strncmp(str + prefix_len, "align", std::strlen("align")) != 0) { return false; } const char align = str[len - 1]; return align == '8' || align == '4' || align == '2' || align == '1'; }; if ((!HasLegalSuffix(lhs_name)) || (!HasLegalSuffix(rhs_name))) { return false; } return lhs_name[len - 1] < rhs_name[len - 1]; } struct Conv2dOperationCacheKey { cutlass::library::ConvFunctionalKey functional_key; cutlass::library::Conv2dConfiguration configuraion; size_t alignment; Conv2dOperationCacheKey(cutlass::library::ConvFunctionalKey functional_key, cutlass::library::Conv2dConfiguration configuraion, cutlass::library::ConvArguments arguments) : functional_key(functional_key), configuraion(configuraion) { const auto IsStrideAligned = [&](const std::vector& stride, size_t n) { return std::all_of(stride.cbegin(), stride.cend(), [&](const int64_t& s) { return s % n == 0; }); }; CHECK_EQ(reinterpret_cast(arguments.A) % kCudaAlignSize, 0); CHECK_EQ(reinterpret_cast(arguments.B) % kCudaAlignSize, 0); CHECK_EQ(reinterpret_cast(arguments.C) % kCudaAlignSize, 0); CHECK_EQ(reinterpret_cast(arguments.D) % kCudaAlignSize, 0); const auto IsAligned = [&](size_t n) { return IsStrideAligned(configuraion.stride_a, n) && IsStrideAligned(configuraion.stride_b, n) && IsStrideAligned(configuraion.stride_c, n); }; if (IsAligned(8)) { alignment = 8; } else if (IsAligned(4)) { alignment = 4; } else if (IsAligned(2)) { alignment = 2; } else { alignment = 1; } } }; struct Conv2dProblemSizeHasher { size_t operator()(const cutlass::conv::Conv2dProblemSize& problem_size) const { size_t hash = 0; hash = HashCombine(hash, std::hash()(problem_size.N)); hash = HashCombine(hash, std::hash()(problem_size.H)); hash = HashCombine(hash, std::hash()(problem_size.W)); hash = HashCombine(hash, std::hash()(problem_size.C)); hash = HashCombine(hash, std::hash()(problem_size.P)); hash = HashCombine(hash, std::hash()(problem_size.Q)); hash = HashCombine(hash, std::hash()(problem_size.K)); hash = HashCombine(hash, std::hash()(problem_size.R)); hash = HashCombine(hash, std::hash()(problem_size.S)); hash = HashCombine(hash, std::hash()(problem_size.pad_h)); hash = HashCombine(hash, std::hash()(problem_size.pad_w)); hash = HashCombine(hash, std::hash()(problem_size.stride_h)); hash = HashCombine(hash, std::hash()(problem_size.stride_w)); hash = HashCombine(hash, std::hash()(problem_size.dilation_h)); hash = HashCombine(hash, std::hash()(problem_size.dilation_w)); hash = HashCombine(hash, std::hash()(static_cast(problem_size.mode))); hash = HashCombine(hash, std::hash()(problem_size.split_k_slices)); hash = HashCombine(hash, std::hash()(problem_size.groups)); return hash; } }; 
struct Conv2dConfigurationHasher { size_t operator()(const cutlass::library::Conv2dConfiguration& configuraion) const { size_t hash = std::hash()(static_cast(configuraion.split_k_mode)); hash = HashCombine(hash, Conv2dProblemSizeHasher()(configuraion.problem_size)); for (const int64_t v : configuraion.stride_a) { hash = HashCombine(hash, std::hash()(v)); } for (const int64_t v : configuraion.stride_b) { hash = HashCombine(hash, std::hash()(v)); } for (const int64_t v : configuraion.stride_c) { hash = HashCombine(hash, std::hash()(v)); } return hash; } }; struct Conv2dOperationCacheKeyHasher { size_t operator()(const Conv2dOperationCacheKey& key) const { size_t hash = cutlass::library::ConvFunctionalKeyHasher()(key.functional_key); hash = HashCombine(hash, Conv2dConfigurationHasher()(key.configuraion)); hash = HashCombine(hash, std::hash()(key.alignment)); return hash; } }; inline bool operator==(const cutlass::library::Conv2dConfiguration& lhs, const cutlass::library::Conv2dConfiguration& rhs) { return lhs.split_k_mode == rhs.split_k_mode && lhs.problem_size == rhs.problem_size && lhs.stride_a == rhs.stride_a && lhs.stride_b == rhs.stride_b && lhs.stride_c == rhs.stride_c; } inline bool operator==(const Conv2dOperationCacheKey& lhs, const Conv2dOperationCacheKey& rhs) { return lhs.functional_key == rhs.functional_key && lhs.configuraion == rhs.configuraion && lhs.alignment == rhs.alignment; } size_t GetTensorSize(cutlass::library::NumericTypeID element, cutlass::library::LayoutTypeID layout, const cutlass::Tensor4DCoord& extent, const std::vector& stride) { const size_t element_size = cutlass::library::sizeof_bits(element) / 8; size_t capacity = 0; if (layout == cutlass::library::LayoutTypeID::kTensorNHWC) { CHECK_EQ(stride.size(), 3); capacity = cutlass::layout::TensorNHWC(stride.at(0), stride.at(1), stride.at(2)).capacity(extent); } else { UNIMPLEMENTED(); } return capacity * element_size; } }; // namespace using CacheMap = std::unordered_map; struct CutlassConvTuner::Impl { std::mutex mutex; std::unordered_map cache; const cutlass::library::Operation* FindConv2dOperation( ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key, const cutlass::library::Conv2dConfiguration& configuraion, const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size); const cutlass::library::Operation* GetConv2dOperation( const std::string& name, ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key, const cutlass::library::Conv2dConfiguration& configuraion, const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size); }; const cutlass::library::Operation* CutlassConvTuner::Impl::FindConv2dOperation( ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key, const cutlass::library::Conv2dConfiguration& configuraion, const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size) { int dev = 0; OF_CUDA_CHECK(cudaGetDevice(&dev)); Conv2dOperationCacheKey cache_key(functional_key, configuraion, arguments); { std::lock_guard lock(mutex); const auto& device_cache = cache[dev]; const auto& it = device_cache.find(cache_key); if (it != device_cache.end()) { return it->second; } } cutlass::library::ConvArguments benchmark_arguments = arguments; void* benchmark_workspace = workspace; cudaStream_t benchmark_stream = stream->cuda_stream(); #ifdef WITH_CUDA_GRAPHS cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; if (stream->IsGraphCapturing()) { 
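    // When the caller's stream is being captured into a CUDA graph, the tuner switches the
    // thread's capture mode to relaxed and benchmarks on a dedicated stream with freshly
    // allocated scratch copies of A/B/C/D and the workspace (freed again after tuning), so
    // that the timing runs below are kept out of the capture.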
OF_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); OF_CUDA_CHECK(cudaStreamCreate(&benchmark_stream)); OF_CUDA_CHECK(cudaMalloc(&benchmark_workspace, workspace_size)); const size_t a_size = GetTensorSize(functional_key.element_A, functional_key.layout_A, configuraion.problem_size.activation_extent(), configuraion.stride_a); OF_CUDA_CHECK(cudaMalloc(&benchmark_arguments.A, a_size)); const size_t b_size = GetTensorSize(functional_key.element_B, functional_key.layout_B, configuraion.problem_size.filter_extent(), configuraion.stride_b); OF_CUDA_CHECK(cudaMalloc(&benchmark_arguments.B, b_size)); if (benchmark_arguments.C != nullptr) { const size_t c_size = GetTensorSize(functional_key.element_C, functional_key.layout_C, configuraion.problem_size.output_extent(), configuraion.stride_c); OF_CUDA_CHECK(cudaMalloc(&benchmark_arguments.C, c_size)); } const size_t d_size = GetTensorSize( functional_key.element_C, functional_key.layout_C, configuraion.problem_size.output_extent(), {configuraion.problem_size.K, configuraion.problem_size.K * configuraion.problem_size.Q, configuraion.problem_size.K * configuraion.problem_size.Q * configuraion.problem_size.P}); OF_CUDA_CHECK(cudaMalloc(&benchmark_arguments.D, d_size)); } #endif // WITH_CUDA_GRAPHS constexpr int turing_warmup_iters = 2; constexpr int turing_iters = 5; cudaEvent_t start{}; cudaEvent_t end{}; OF_CUDA_CHECK(cudaEventCreate(&start)); OF_CUDA_CHECK(cudaEventCreate(&end)); const cutlass::library::Operation* fastest_operation = nullptr; float fastest_time = 0; const auto& operations_map_it = cutlass::library::Singleton::get().operation_table.conv2d_operations.find(functional_key); CHECK(operations_map_it != cutlass::library::Singleton::get().operation_table.conv2d_operations.cend()); const cutlass::library::ConvOperationVectorMap& operations_map = operations_map_it->second; for (const auto& pair : operations_map) { std::map> operations; for (auto operation : pair.second) { operations.emplace(operation->description().name, operation); } const cutlass::library::Operation* prev_operation = nullptr; for (const auto& name_operation : operations) { const cutlass::library::Operation* operation = name_operation.second; if (prev_operation != nullptr && IsWeakerAlginOperation(operation, prev_operation)) { continue; } if (operation->description().tile_description.minimum_compute_capability * 10 > stream->cuda_arch() || operation->description().tile_description.maximum_compute_capability * 10 < stream->cuda_arch()) { continue; } auto status = operation->can_implement(&configuraion, &benchmark_arguments); if (status != cutlass::Status::kSuccess) { continue; } const size_t host_workspace_size = operation->get_host_workspace_size(&configuraion); const size_t device_workspace_size = operation->get_device_workspace_size(&configuraion); if (device_workspace_size > workspace_size) { continue; } std::vector host_workspace(host_workspace_size, 0); if (operation->initialize(&configuraion, host_workspace.data(), benchmark_workspace, benchmark_stream) != cutlass::Status::kSuccess) { continue; } const auto Run = [&]() { auto init_status = operation->initialize(&configuraion, host_workspace.data(), benchmark_workspace, benchmark_stream); CHECK(init_status == cutlass::Status::kSuccess); auto run_status = operation->run(&benchmark_arguments, host_workspace.data(), benchmark_workspace, benchmark_stream); CHECK(run_status == cutlass::Status::kSuccess); }; OF_CUDA_CHECK(cudaStreamSynchronize(benchmark_stream)); for (int i = 0; i < turing_warmup_iters; ++i) { Run(); } 
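    // Candidates that pass the checks above are then timed: turing_iters back-to-back runs
    // between the start/end CUDA events below, with the warmup runs above excluded. The
    // fastest operation is cached under the (functional key, configuration, alignment) key,
    // so later calls with the same problem skip this benchmark loop entirely.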
OF_CUDA_CHECK(cudaEventRecord(start, benchmark_stream)); for (int i = 0; i < turing_iters; ++i) { Run(); } OF_CUDA_CHECK(cudaEventRecord(end, benchmark_stream)); OF_CUDA_CHECK(cudaEventSynchronize(end)); float time = 0; OF_CUDA_CHECK(cudaEventElapsedTime(&time, start, end)); VLOG(3) << operation->description().name << " " << time; prev_operation = operation; if (fastest_operation == nullptr || time < fastest_time) { fastest_operation = operation; fastest_time = time; } } } OF_CUDA_CHECK(cudaEventDestroy(start)); OF_CUDA_CHECK(cudaEventDestroy(end)); #ifdef WITH_CUDA_GRAPHS if (stream->IsGraphCapturing()) { OF_CUDA_CHECK(cudaStreamSynchronize(benchmark_stream)); OF_CUDA_CHECK(cudaStreamDestroy(benchmark_stream)); OF_CUDA_CHECK(cudaFree(const_cast(benchmark_arguments.A))); OF_CUDA_CHECK(cudaFree(const_cast(benchmark_arguments.B))); OF_CUDA_CHECK(cudaFree(const_cast(benchmark_arguments.C))); OF_CUDA_CHECK(cudaFree(benchmark_arguments.D)); OF_CUDA_CHECK(cudaFree(benchmark_workspace)); OF_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); } #endif // WITH_CUDA_GRAPHS if (fastest_operation != nullptr) { VLOG(3) << "Fastest: " << fastest_operation->description().name << " " << fastest_time; { std::lock_guard lock(mutex); cache[dev][cache_key] = fastest_operation; } } return fastest_operation; } const cutlass::library::Operation* CutlassConvTuner::Impl::GetConv2dOperation( const std::string& name, ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key, const cutlass::library::Conv2dConfiguration& configuraion, const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size) { int dev = 0; OF_CUDA_CHECK(cudaGetDevice(&dev)); const auto& operations_map_it = cutlass::library::Singleton::get().operation_table.conv2d_operations.find(functional_key); if (operations_map_it == cutlass::library::Singleton::get().operation_table.conv2d_operations.cend()) { return nullptr; } const cutlass::library::ConvOperationVectorMap& operations_map = operations_map_it->second; for (const auto& pair : operations_map) { for (auto operation : pair.second) { if (name != operation->description().name) { continue; } if (operation->description().tile_description.minimum_compute_capability * 10 > stream->cuda_arch() || operation->description().tile_description.maximum_compute_capability * 10 < stream->cuda_arch()) { continue; } auto status = operation->can_implement(&configuraion, &arguments); if (status != cutlass::Status::kSuccess) { continue; } const size_t host_workspace_size = operation->get_host_workspace_size(&configuraion); const size_t device_workspace_size = operation->get_device_workspace_size(&configuraion); if (device_workspace_size > workspace_size) { continue; } std::vector host_workspace(host_workspace_size, 0); if (operation->initialize(&configuraion, host_workspace.data(), workspace, stream->cuda_stream()) != cutlass::Status::kSuccess) { continue; } return operation; } } return nullptr; } CutlassConvTuner::CutlassConvTuner() { impl_.reset(new Impl()); } const CutlassConvTuner& CutlassConvTuner::Get() { static CutlassConvTuner instance; return instance; } const cutlass::library::Operation* CutlassConvTuner::FindConv2dOperation( ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key, const cutlass::library::Conv2dConfiguration& configuraion, const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size) const { return impl_->FindConv2dOperation(stream, functional_key, configuraion, arguments, workspace, 
workspace_size); } const cutlass::library::Operation* CutlassConvTuner::GetConv2dOperation( const std::string& name, ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key, const cutlass::library::Conv2dConfiguration& configuraion, const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size) const { return impl_->GetConv2dOperation(name, stream, functional_key, configuraion, arguments, workspace, workspace_size); } } // namespace oneflow #endif // WITH_CUTLASS ================================================ FILE: oneflow/user/kernels/cutlass_conv_tuner.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_CUTLASS_CONV_TUNER_H_ #define ONEFLOW_USER_KERNELS_CUTLASS_CONV_TUNER_H_ #ifdef WITH_CUTLASS #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/job/lazy_mode.h" #include #include #include namespace oneflow { class CutlassConvTuner { public: OF_DISALLOW_COPY_AND_MOVE(CutlassConvTuner); ~CutlassConvTuner() = default; const cutlass::library::Operation* FindConv2dOperation( ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key, const cutlass::library::Conv2dConfiguration& configuraion, const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size) const; const cutlass::library::Operation* GetConv2dOperation( const std::string& name, ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key, const cutlass::library::Conv2dConfiguration& configuraion, const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size) const; static const CutlassConvTuner& Get(); private: CutlassConvTuner(); struct Impl; std::unique_ptr impl_; }; } // namespace oneflow #endif // WITH_CUTLASS #endif // ONEFLOW_USER_KERNELS_CUTLASS_CONV_TUNER_H_ ================================================ FILE: oneflow/user/kernels/data_shuffle_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/kernels/gather_kernel_util.h" #include "oneflow/user/kernels/unsorted_segment_sum_kernel_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/embedding/hash_functions.cuh" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/user/kernels/one_embedding_data_shuffle.cuh" namespace oneflow { namespace { enum class IdShuffleBufferType { kNumPartitionedUnique = 0, kPartitionedUniqueIds, kReceivedIds, kTableIds, kPartitionedUniqueTableIds, kReceivedTableIds, kWorkspace, kMaxType }; template class IdShuffleTmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(IdShuffleTmpBufferManager); IdShuffleTmpBufferManager(void* ptr, const int64_t num_ids, const int64_t parallel_num, bool need_table_ids, bool need_process_table_ids) : offset_(0), offsets_(static_cast(IdShuffleBufferType::kMaxType), -1), sizes_(static_cast(IdShuffleBufferType::kMaxType)), ptr_(ptr) { const int64_t num_table_ids = need_process_table_ids ? num_ids : 0; const size_t table_ids_bytes = need_table_ids ? num_ids * sizeof(U) : 0; AllocBuffer(IdShuffleBufferType::kNumPartitionedUnique, parallel_num * sizeof(IDX)); size_t partitioned_ids_bytes = parallel_num * num_ids * sizeof(K); AllocBuffer(IdShuffleBufferType::kPartitionedUniqueIds, partitioned_ids_bytes); AllocBuffer(IdShuffleBufferType::kReceivedIds, partitioned_ids_bytes); AllocBuffer(IdShuffleBufferType::kTableIds, table_ids_bytes); size_t partitioned_table_ids_bytes = parallel_num * num_table_ids * sizeof(U); AllocBuffer(IdShuffleBufferType::kPartitionedUniqueTableIds, partitioned_table_ids_bytes); AllocBuffer(IdShuffleBufferType::kReceivedTableIds, partitioned_table_ids_bytes); const size_t hash_table_capacity = parallel_num * num_ids; AllocBuffer(IdShuffleBufferType::kWorkspace, hash_table_capacity * sizeof(data_shuffle::TableEntry)); } template T* Ptr(IdShuffleBufferType type) { CHECK(ptr_ != nullptr); int64_t offset = offsets_.at(static_cast(type)); CHECK_NE(offset, -1); return reinterpret_cast(reinterpret_cast(ptr_) + offset); } int64_t Size(IdShuffleBufferType type) { return sizes_.at(static_cast(type)); } size_t TotalBufferSize() const { return offset_; } private: void AllocBuffer(IdShuffleBufferType type, size_t size) { const size_t type_id = static_cast(type); CHECK_EQ(offsets_.at(type_id), -1); offsets_.at(type_id) = offset_; sizes_.at(type_id) = size; offset_ += GetCudaAlignedSize(size); } size_t offset_; std::vector offsets_; std::vector sizes_; void* ptr_; }; template class DataShuffleKernelState final : public user_op::OpKernelState { public: explicit DataShuffleKernelState(user_op::KernelInitContext* ctx) : device_index_(-1), stream_name_(EagerNcclCommMgr::kDefaultStreamName), parallel_desc_(ctx->parallel_desc()) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); } OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX))); OF_CUDA_CHECK(cudaMallocHost( &host_num_unique_matrix_, parallel_desc_.parallel_num() * parallel_desc_.parallel_num() * sizeof(IDX))); const std::string& embedding_name = ctx->Attr("embedding_name"); const int64_t 
parallel_id = ctx->parallel_ctx().parallel_id(); embedding_state_ = Singleton::Get()->GetEmbeddingState( embedding_name, parallel_id); } ~DataShuffleKernelState() { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFreeHost(host_num_unique_matrix_)); } ncclComm_t comm() { return GetOrCreate().comm; } IDX* HostNumUniqueMatrix() { return host_num_unique_matrix_; } IDX* HostNumKeys() { return host_num_keys_; } embedding::EmbeddingState* EmbeddingState() { return embedding_state_; } private: struct Comm { Comm(ncclComm_t comm) : comm(comm) {} ncclComm_t comm; }; const Comm& GetOrCreate() { if (!comm_) { Init(); } return *comm_; } void Init() { std::set> device_set; for (int64_t parallel_id = 0; parallel_id < parallel_desc_.parallel_num(); ++parallel_id) { int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ncclComm_t comm; comm = comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); comm_.reset(new Comm(comm)); } int device_index_; bool has_independent_stream_; std::string stream_name_; ParallelDesc parallel_desc_; std::unique_ptr comm_; IDX* host_num_unique_matrix_; IDX* host_num_keys_; embedding::EmbeddingState* embedding_state_; }; } // namespace template class IdShuffleKernel final : public user_op::OpKernel { public: IdShuffleKernel() : current_iter_(0){}; ~IdShuffleKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); const user_op::Tensor* ids = ctx->Tensor4ArgNameAndIndex("ids", 0); user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0); user_op::Tensor* inverse_unique_partition_indices = ctx->Tensor4ArgNameAndIndex("inverse_unique_partition_indices", 0); user_op::Tensor* cur_rank_num_unique = ctx->Tensor4ArgNameAndIndex("cur_rank_num_unique", 0); user_op::Tensor* cur_rank_unique_ids = ctx->Tensor4ArgNameAndIndex("cur_rank_unique_ids", 0); user_op::Tensor* cur_rank_unique_table_ids = ctx->Tensor4ArgNameAndIndex("cur_rank_unique_table_ids", 0); user_op::Tensor* cur_rank_inverse_indices = ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int32_t num_tables = ctx->Attr("num_tables"); const int64_t padding_idx = ctx->Attr("padding_idx"); const bool has_padding_idx = ctx->Attr("has_padding_idx"); const bool has_table_ids = ctx->has_input("table_ids", 0); const bool need_gen_table_ids = (!has_table_ids && num_tables > 1); const bool need_process_table_ids = (has_table_ids || num_tables > 1); const int64_t num_ids = ids->shape_view().elem_cnt(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); IdShuffleTmpBufferManager buffer_manager( tmp_buffer->mut_dptr(), num_ids, parallel_num, need_gen_table_ids, need_process_table_ids); CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.TotalBufferSize()); ncclComm_t comm = 
kernel_state->comm(); IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix(); IDX* host_num_keys = kernel_state->HostNumKeys(); data_shuffle::IdShuffleDataPtrs data_ptrs; data_ptrs.ids_ptr = reinterpret_cast(ids->dptr()); if (has_table_ids) { const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex("table_ids", 0); data_ptrs.table_ids_ptr = reinterpret_cast(table_ids->dptr()); } else if (need_gen_table_ids) { data_shuffle::GenerateTableIds<<>>( num_ids, num_tables, buffer_manager.template Ptr(IdShuffleBufferType::kTableIds)); data_ptrs.table_ids_ptr = buffer_manager.template Ptr(IdShuffleBufferType::kTableIds); } else { data_ptrs.table_ids_ptr = nullptr; } data_ptrs.num_partitioned_unique = buffer_manager.template Ptr(IdShuffleBufferType::kNumPartitionedUnique); data_ptrs.partitioned_unique_ids = buffer_manager.template Ptr(IdShuffleBufferType::kPartitionedUniqueIds); data_ptrs.partitioned_unique_table_ids = buffer_manager.template Ptr(IdShuffleBufferType::kPartitionedUniqueTableIds); data_ptrs.workspace_ptr = buffer_manager.Ptr(IdShuffleBufferType::kWorkspace); data_ptrs.workspace_size = buffer_manager.Size(IdShuffleBufferType::kWorkspace); data_ptrs.received_ids = buffer_manager.template Ptr(IdShuffleBufferType::kReceivedIds); data_ptrs.received_table_ids = buffer_manager.template Ptr(IdShuffleBufferType::kReceivedTableIds); data_ptrs.num_unique_matrix_ptr = reinterpret_cast(num_unique_matrix->mut_dptr()); data_ptrs.inverse_unique_partition_indices_ptr = reinterpret_cast(inverse_unique_partition_indices->mut_dptr()); data_ptrs.cur_rank_num_unique_ptr = reinterpret_cast(cur_rank_num_unique->mut_dptr()); data_ptrs.cur_rank_unique_ids_ptr = reinterpret_cast(cur_rank_unique_ids->mut_dptr()); data_ptrs.cur_rank_unique_table_ids_ptr = reinterpret_cast(cur_rank_unique_table_ids->mut_dptr()); data_ptrs.cur_rank_inverse_indices_ptr = reinterpret_cast(cur_rank_inverse_indices->mut_dptr()); data_shuffle::IdShuffle(ctx->stream(), comm, data_ptrs, num_ids, parallel_id, parallel_num, num_unique_matrix->data_type(), ids->data_type(), cur_rank_unique_table_ids->data_type(), need_process_table_ids, has_padding_idx, padding_idx, host_num_unique_matrix, host_num_keys); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); std::vector num_unique_matrix_vec(parallel_num * parallel_num); std::memcpy(num_unique_matrix_vec.data(), host_num_unique_matrix, parallel_num * parallel_num * sizeof(IDX)); CHECK_EQ(sizeof(IDX), sizeof(uint32_t)) << "assume sizeof(IDX) equals to sizeof(uint32_t)"; embedding_state->SetIdNumUniqueMatrix(num_unique_matrix_vec, current_iter_); uint32_t final_num_unique = *host_num_keys; embedding_state->SetIdFinalNumUnique(final_num_unique, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define ID_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define TABLE_ID_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define IDX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, 
DataType::kInt32) #define REGISTER_CUDA_ID_SHUFFLE_KERNEL(k_dtype_pair, table_id_dtype_pair, idx_dtype_pair) \ REGISTER_USER_KERNEL("id_shuffle") \ .SetCreateFn< \ IdShuffleKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("ids", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \ && (user_op::HobDataType("cur_rank_unique_table_ids", 0) \ == OF_PP_PAIR_SECOND(table_id_dtype_pair)) \ && (user_op::HobDataType("num_unique_matrix", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ID_SHUFFLE_USE_P2P", false))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& ids = ctx->InputTensorDesc("ids", 0); \ const bool has_table_ids = ctx->has_input("table_ids", 0); \ const int32_t num_tables = ctx->Attr("num_tables"); \ const bool need_gen_table_ids = (!has_table_ids && num_tables > 1); \ const bool need_process_table_ids = (has_table_ids || num_tables > 1); \ IdShuffleTmpBufferManager \ buffer_manager(nullptr, ids.shape().elem_cnt(), ctx->parallel_desc().parallel_num(), \ need_gen_table_ids, need_process_table_ids); \ return buffer_manager.TotalBufferSize(); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ID_SHUFFLE_KERNEL, ID_DATA_TYPE_SEQ, TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) __device__ float RoundHalfAwayFromZero(const float x) { float abs_val = abs(x); float floor_val = floor(abs_val + static_cast(0.5)); return copysignf(floor_val, x); } // warp reduce version. constexpr int32_t kWarpSize = 32; constexpr int32_t kMaxColSize = 1024; template __inline__ __device__ T WarpMaxAllReduce(T val) { for (int32_t lane_mask = thread_group_width / 2; lane_mask > 0; lane_mask /= 2) { val = max(val, __shfl_xor_sync(0xffffffff, val, lane_mask, thread_group_width)); } return val; } inline cudaError_t GetWarpImplNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves, int* num_blocks) { int dev; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } int tpm; { cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); if (err != cudaSuccess) { return err; } } *num_blocks = std::max(1, std::min(max_blocks, sm_count * tpm / block_size * waves)); return cudaSuccess; } template __global__ void QuantizeWarpImplKernel(const T* src, int8_t* dst, T* quantize_factor, const int64_t rows, const int64_t cols) { static_assert(cols_per_thread % pack_size == 0, ""); static_assert(thread_group_width <= kWarpSize, ""); static_assert(kWarpSize % thread_group_width == 0, ""); constexpr int num_packs = cols_per_thread / pack_size; assert(cols <= cols_per_thread * thread_group_width); ComputeType buf[rows_per_access][cols_per_thread]; const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; const int num_global_thread_group = gridDim.x * blockDim.y; const int lane_id = threadIdx.x; const int64_t step = num_global_thread_group * rows_per_access; using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; using StoreType = cuda::elementwise::PackType; using StorePack = cuda::elementwise::Pack; for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { ComputeType thread_abs_max[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; row_id++) { ComputeType* row_buf = buf[row_id]; 
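      // Illustrative note on the per-row scheme implemented below (the example values are not
      // from the original source): quantize_factor[row] ends up holding max|x| over the row,
      // and each element is mapped to round_half_away_from_zero(x * 127 / max|x|). For a row
      // [0.5, -1.0, 2.0]: max|x| = 2.0, scale = 127 / 2.0 = 63.5, int8 output = [32, -64, 127];
      // DequantizeKernel later multiplies by 2.0 / 127 to recover roughly [0.504, -1.008, 2.0].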
thread_abs_max[row_id] = 0.0; #pragma unroll for (int pack_id = 0; pack_id < num_packs; pack_id++) { const int pack_offset = pack_id * pack_size; const int col = (pack_id * thread_group_width + lane_id) * pack_size; LoadPack load_pack; if (!padding || col < cols) { const int64_t load_offset = ((row + row_id) * cols + col) / pack_size; load_pack.storage = *(reinterpret_cast(src) + load_offset); #pragma unroll for (int i = 0; i < pack_size; i++) { row_buf[pack_offset + i] = static_cast(load_pack.elem[i]); thread_abs_max[row_id] = max(thread_abs_max[row_id], abs(row_buf[pack_offset + i])); } } else { #pragma unroll for (int i = 0; i < pack_size; i++) { row_buf[pack_offset + i] = 0.0; } } } } ComputeType warp_max[rows_per_access]; #pragma unroll for (int row_id = 0; row_id < rows_per_access; row_id++) { warp_max[row_id] = WarpMaxAllReduce(thread_abs_max[row_id]); if (threadIdx.x == 0) { quantize_factor[row + row_id] = static_cast(warp_max[row_id]); } ComputeType* row_buf = buf[row_id]; ComputeType quantize_factor_val = static_cast(127.0) / warp_max[row_id]; #pragma unroll for (int col = 0; col < cols_per_thread; col++) { row_buf[col] = RoundHalfAwayFromZero(row_buf[col] * quantize_factor_val); } #pragma unroll for (int pack_id = 0; pack_id < num_packs; pack_id++) { const int pack_offset = pack_id * pack_size; const int col = (pack_id * thread_group_width + lane_id) * pack_size; StorePack store_pack; if (!padding || col < cols) { const int64_t store_offset = ((row + row_id) * cols + col) / pack_size; for (int i = 0; i < pack_size; i++) { store_pack.elem[i] = static_cast(row_buf[pack_id * pack_size + i]); } *(reinterpret_cast(dst) + store_offset) = store_pack.storage; } } } } } template inline cudaError_t LaunchQuantizeWarpImpl(cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows, const int64_t cols) { constexpr int block_size = 128; constexpr int waves = 32; static_assert(block_size % thread_group_width == 0, ""); constexpr int thread_groups_per_block = block_size / thread_group_width; dim3 block_dim(thread_group_width, thread_groups_per_block); const int64_t num_blocks = (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; int grid_dim_x = 0; cudaError_t err = GetWarpImplNumBlocks(block_size, num_blocks, waves, &grid_dim_x); if (err != cudaSuccess) { return err; } QuantizeWarpImplKernel <<>>(src, dst, quantize_factor, rows, cols); return cudaPeekAtLastError(); } template inline cudaError_t DispatchQuantizeWarpImplPadding(cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows, const int64_t cols) { if (cols == cols_per_thread * thread_group_width) { return LaunchQuantizeWarpImpl(stream, src, dst, quantize_factor, rows, cols); } else { return LaunchQuantizeWarpImpl(stream, src, dst, quantize_factor, rows, cols); } } template typename std::enable_if::type DispatchQuantizeWarpImplCols( cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows, const int64_t cols) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchQuantizeWarpImplPadding(stream, src, dst, \ quantize_factor, rows, cols); \ } else { \ return DispatchQuantizeWarpImplPadding(stream, src, dst, \ quantize_factor, rows, cols); \ } \ } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define 
DEFINE_ONE_ELIF(col) \ else if (cols <= (col)*kWarpSize) { \ return DispatchQuantizeWarpImplPadding( \ stream, src, dst, quantize_factor, rows, cols); \ } DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(5) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(7) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(9) DEFINE_ONE_ELIF(10) DEFINE_ONE_ELIF(11) DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(13) DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(15) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(17) DEFINE_ONE_ELIF(18) DEFINE_ONE_ELIF(19) DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(21) DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(23) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(25) DEFINE_ONE_ELIF(26) DEFINE_ONE_ELIF(27) DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(29) DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(31) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template typename std::enable_if::type DispatchQuantizeWarpImplCols( cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows, const int64_t cols) { if (cols <= 0) { return cudaErrorInvalidValue; } #define DEFINE_ONE_ELIF(thread_group_width) \ else if (cols <= (thread_group_width)*pack_size) { \ if (rows % 2 == 0) { \ return DispatchQuantizeWarpImplPadding(stream, src, dst, \ quantize_factor, rows, cols); \ } else { \ return DispatchQuantizeWarpImplPadding(stream, src, dst, \ quantize_factor, rows, cols); \ } \ } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF #define DEFINE_ONE_ELIF(col) \ else if (cols <= (col)*kWarpSize) { \ return DispatchQuantizeWarpImplPadding( \ stream, src, dst, quantize_factor, rows, cols); \ } DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(10) DEFINE_ONE_ELIF(12) DEFINE_ONE_ELIF(14) DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(18) DEFINE_ONE_ELIF(20) DEFINE_ONE_ELIF(22) DEFINE_ONE_ELIF(24) DEFINE_ONE_ELIF(26) DEFINE_ONE_ELIF(28) DEFINE_ONE_ELIF(30) DEFINE_ONE_ELIF(32) #undef DEFINE_ONE_ELIF else { return cudaErrorInvalidValue; } } template struct DispatchQuantizeWarpImplPackSize { cudaError_t operator()(cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows, const int64_t cols) { if (cols % 2 == 0) { return DispatchQuantizeWarpImplCols(stream, src, dst, quantize_factor, rows, cols); } else { return DispatchQuantizeWarpImplCols(stream, src, dst, quantize_factor, rows, cols); } } }; template __global__ void DequantizeKernel(const int8_t* x, T* quantize_factor, T* out, IDX col_size, IDX elem_cnt); template __global__ void DequantizeKernel(const int8_t* x, T* quantize_factor, T* out, IDX col_size, IDX elem_cnt) { IDX global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; for (int index = global_thread_id * pack_size; index < elem_cnt; index += gridDim.x * blockDim.x * pack_size) { IDX quantize_factor_idx = index / col_size; ComputeType quantize_factor_val = static_cast(quantize_factor[quantize_factor_idx]) / static_cast(127.0); using LoadPackType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; using StorePackType = cuda::elementwise::PackType; using StorePack = cuda::elementwise::Pack; LoadPack load_pack{}; StorePack store_pack{}; load_pack.storage = *(reinterpret_cast(x) + index / pack_size); #pragma unroll for (int i = 0; i < pack_size; i++) { store_pack.elem[i] = static_cast(static_cast(load_pack.elem[i]) * quantize_factor_val); } *(reinterpret_cast(out) + index / pack_size) = store_pack.storage; } } template cudaError_t 
DispatchDequantizeKernelPackSize(cudaStream_t stream, const int8_t* src, T* quantize_factor, T* dst, const int64_t col_size, const int64_t elem_cnt) { const int64_t pack_num = elem_cnt / pack_size; int grid_size = 0; cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size); if (err != cudaSuccess) { return err; } DequantizeKernel <<>>(src, quantize_factor, dst, col_size, elem_cnt); return cudaSuccess; } template inline cudaError_t LaunchDequantizeKernel(cudaStream_t stream, const int8_t* src, T* quantize_factor, T* dst, const int64_t col_size, const int64_t elem_cnt) { constexpr int quantized_src_pack_size = cuda::elementwise::PackSize(); constexpr int dst_pack_size = cuda::elementwise::PackSize(); int launch_pack_size = std::min(quantized_src_pack_size, dst_pack_size); if (launch_pack_size == 8 && col_size % 8 == 0) { cudaError_t err = DispatchDequantizeKernelPackSize( stream, src, quantize_factor, dst, col_size, elem_cnt); if (err != cudaSuccess) { return err; } } else if (launch_pack_size == 4 && col_size % 4 == 0) { cudaError_t err = DispatchDequantizeKernelPackSize( stream, src, quantize_factor, dst, col_size, elem_cnt); if (err != cudaSuccess) { return err; } } else if (launch_pack_size == 2 && col_size % 2 == 0) { cudaError_t err = DispatchDequantizeKernelPackSize( stream, src, quantize_factor, dst, col_size, elem_cnt); if (err != cudaSuccess) { return err; } } else { cudaError_t err = DispatchDequantizeKernelPackSize( stream, src, quantize_factor, dst, col_size, elem_cnt); if (err != cudaSuccess) { return err; } } return cudaPeekAtLastError(); } template struct DefaultComputeType { using type = T; }; template<> struct DefaultComputeType { using type = float; }; template class EmbeddingShuffleKernel final : public user_op::OpKernel { public: EmbeddingShuffleKernel() : current_iter_(0) {} ~EmbeddingShuffleKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); std::unique_ptr allocator = embedding_state->NewTmpBufferAllocator(ctx); embedding_state->OnEmbeddingShuffleStart(ctx, current_iter_); const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0); const user_op::Tensor* cur_rank_inverse_indices = ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0); const user_op::Tensor* inverse_unique_partition_indices = ctx->Tensor4ArgNameAndIndex("inverse_unique_partition_indices", 0); user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); ncclComm_t comm = kernel_state->comm(); using ComputeType = typename DefaultComputeType::type; const int64_t embedding_size = ctx->Attr("embedding_size"); IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix(); DataType data_type = embeddings->data_type(); const int64_t num_ids = inverse_unique_partition_indices->shape_view().elem_cnt(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const bool skip_last_gather = ctx->Attr("skip_last_gather"); bool enable_quantized_comm_env_var = 
ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false); bool enable_quantized_comm = enable_quantized_comm_env_var && (embedding_size < kMaxColSize); if (enable_quantized_comm_env_var && !enable_quantized_comm) { LOG(WARNING) << "Only envrionment variable ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 and " "embedding_size less equal than 1024 can use quantized communication. "; } cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); const std::vector& num_unique_matrix_vec = embedding_state->GetIdNumUniqueMatrix(current_iter_); CHECK_EQ(sizeof(IDX), sizeof(uint32_t)) << "assume sizeof(IDX) equals to sizeof(uint32_t)"; ; std::memcpy(host_num_unique_matrix, num_unique_matrix_vec.data(), parallel_num * parallel_num * sizeof(IDX)); uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); int64_t cur_rank_num_ids = 0; for (int64_t i = 0; i < parallel_num; ++i) { cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id]; } int64_t unique_partitioned_num_ids = 0; for (int64_t i = 0; i < parallel_num; ++i) { unique_partitioned_num_ids += host_num_unique_matrix[parallel_id * parallel_num + i]; } const T* cur_rank_embeddings_ptr = reinterpret_cast( embedding_state->EmbeddingShuffleCurRankEmbeddings(current_iter_)); if (!enable_quantized_comm) { // 1. reverse cur_rank unique, from (num_unique, embedding_size) to (cur_rank_num_ids, // embedding_size) void* reverse_unique_cur_rank_embeddings; allocator->Allocate(&reverse_unique_cur_rank_embeddings, cur_rank_num_ids * embedding_size * sizeof(T)); GatherKernelUtilImpl::Forward( ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), cur_rank_num_ids, cur_rank_embeddings_ptr, Shape({1, num_unique, embedding_size}), reinterpret_cast(reverse_unique_cur_rank_embeddings), 0); // 2. send recv embedding, from (cur_rank_num_ids, embedding_size) to // (unique_partitioned_num_ids, embedding_size) if (skip_last_gather) { data_shuffle::ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, data_type, host_num_unique_matrix, reinterpret_cast(reverse_unique_cur_rank_embeddings), embeddings->mut_dptr()); allocator->Free(reverse_unique_cur_rank_embeddings); } else { void* received_embeddings; // T allocator->Allocate(&received_embeddings, GetCudaAlignedSize(unique_partitioned_num_ids * embedding_size * sizeof(T))); data_shuffle::ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, data_type, host_num_unique_matrix, reinterpret_cast(reverse_unique_cur_rank_embeddings), reinterpret_cast(received_embeddings)); allocator->Free(reverse_unique_cur_rank_embeddings); // 3. reverse unique_partition, from (unique_partitioned_num_ids, embedding_size) to // (num_ids, embedding_size) GatherKernelUtilImpl::Forward( ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), num_ids, reinterpret_cast(received_embeddings), Shape({1, unique_partitioned_num_ids, embedding_size}), embeddings->mut_dptr(), 0); allocator->Free(received_embeddings); } } else { CHECK(!skip_last_gather) << "when enable_quantized_comm, should not use fuse kernel."; // 1. 
quantize cur_rank_embeddings, from (num_unique, embedding_size) T to (num_unique, // embedding_size) int8_t, and get (num_unique,) T factor void* quantize_cur_rank_embeddings; // int8_t allocator->Allocate(&quantize_cur_rank_embeddings, num_unique * embedding_size * sizeof(int8_t)); void* cur_rank_quantize_factor; // T allocator->Allocate(&cur_rank_quantize_factor, num_unique * sizeof(T)); DispatchQuantizeWarpImplPackSize()( cuda_stream, cur_rank_embeddings_ptr, reinterpret_cast(quantize_cur_rank_embeddings), reinterpret_cast(cur_rank_quantize_factor), num_unique, embedding_size); // 2. reverse cur_rank unique, from (num_unique, embedding_size) to (cur_rank_num_ids, // embedding_size) void* reverse_unique_cur_rank_embeddings; // int8_t allocator->Allocate(&reverse_unique_cur_rank_embeddings, cur_rank_num_ids * embedding_size * sizeof(int8_t)); GatherKernelUtilImpl::Forward( ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), cur_rank_num_ids, reinterpret_cast(quantize_cur_rank_embeddings), Shape({1, num_unique, embedding_size}), reinterpret_cast(reverse_unique_cur_rank_embeddings), 0); allocator->Free(quantize_cur_rank_embeddings); // 3. reverse cur_rank quantize factor unique, from (num_unique) to (cur_rank_num_ids) void* reverse_cur_rank_quantize_factor; // T allocator->Allocate(&reverse_cur_rank_quantize_factor, cur_rank_num_ids * sizeof(T)); GatherKernelUtilImpl::Forward( ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), cur_rank_num_ids, reinterpret_cast(cur_rank_quantize_factor), Shape({1, num_unique, 1}), reinterpret_cast(reverse_cur_rank_quantize_factor), 0); allocator->Free(cur_rank_quantize_factor); // 4. send recv embedding and factor, from (cur_rank_num_ids, embedding_size) to // (unique_partitioned_num_ids, embedding_size) void* received_embeddings; // int8_t void* recv_quantize_factor; // T allocator->Allocate(&received_embeddings, unique_partitioned_num_ids * embedding_size * sizeof(int8_t)); allocator->Allocate(&recv_quantize_factor, unique_partitioned_num_ids * sizeof(T)); data_shuffle::ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, data_type, host_num_unique_matrix, reinterpret_cast(reverse_unique_cur_rank_embeddings), reinterpret_cast(received_embeddings), reinterpret_cast(reverse_cur_rank_quantize_factor), reinterpret_cast(recv_quantize_factor)); allocator->Free(reverse_unique_cur_rank_embeddings); allocator->Free(reverse_cur_rank_quantize_factor); // 5. reverse unique_partition, from (unique_partitioned_num_ids, embedding_size) to (num_ids, // embedding_size) void* reverse_recv_quantize_cur_rank_embeddings; // int8_t allocator->Allocate(&reverse_recv_quantize_cur_rank_embeddings, num_ids * embedding_size * sizeof(int8_t)); GatherKernelUtilImpl::Forward( ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), num_ids, reinterpret_cast(received_embeddings), Shape({1, unique_partitioned_num_ids, embedding_size}), reinterpret_cast(reverse_recv_quantize_cur_rank_embeddings), 0); allocator->Free(received_embeddings); // 6. 
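// Illustrative sketch: the row gather that GatherKernelUtilImpl::Forward performs in the
// steps above, expanding `num_unique` rows back to one row per id through the inverse
// indices from the unique step; row_size is embedding_size for the embedding data and 1
// for the per-row quantize factors. Host-side reference only, hypothetical name.
#include <algorithm>
#include <cstdint>

template<typename T, typename IDX>
void ReferenceGatherRows(const T* unique_rows, const IDX* inverse_indices, int64_t num_ids,
                         int64_t row_size, T* out) {
  for (int64_t i = 0; i < num_ids; ++i) {
    const T* src = unique_rows + static_cast<int64_t>(inverse_indices[i]) * row_size;
    std::copy(src, src + row_size, out + i * row_size);
  }
}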
reverse unique_partition_factor, from (unique_partitioned_num_ids) to (num_ids) void* reverse_recv_quantize_factor; // T allocator->Allocate(&reverse_recv_quantize_factor, num_ids * sizeof(T)); GatherKernelUtilImpl::Forward( ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), num_ids, reinterpret_cast(recv_quantize_factor), Shape({1, unique_partitioned_num_ids, 1}), reinterpret_cast(reverse_recv_quantize_factor), 0); allocator->Free(recv_quantize_factor); // 7. dequantize embeddings, from (num_ids, embedding_size) int8_t to (num_ids, // embedding_size) T int32_t dequantize_row_size = num_ids; IDX dequantize_elem_cnt = dequantize_row_size * embedding_size; OF_CUDA_CHECK((LaunchDequantizeKernel( cuda_stream, reinterpret_cast(reverse_recv_quantize_cur_rank_embeddings), reinterpret_cast(reverse_recv_quantize_factor), embeddings->mut_dptr(), embedding_size, dequantize_elem_cnt))); allocator->Free(reverse_recv_quantize_cur_rank_embeddings); allocator->Free(reverse_recv_quantize_factor); } embedding_state->OnEmbeddingShuffleEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_EMBEDDING_SHUFFLE_KERNEL(t_dtype_pair, idx_dtype_pair) \ REGISTER_USER_KERNEL("embedding_shuffle") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("cur_rank_embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)) \ && ((user_op::HobAttr("skip_last_gather") == false) \ || (!embedding::UseEmbeddingShuffleP2PKernel(OF_PP_PAIR_SECOND(t_dtype_pair), \ OF_PP_PAIR_SECOND(idx_dtype_pair)))) \ && (user_op::HobDataType("num_unique_matrix", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& inverse_unique_partition_indices = \ ctx->InputTensorDesc("inverse_unique_partition_indices", 0); \ const int64_t num_ids = inverse_unique_partition_indices.shape().elem_cnt(); \ const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); \ const int64_t cur_rank_max_num_ids = parallel_num * num_ids; \ const int64_t embedding_size = ctx->Attr("embedding_size"); \ bool enable_quantized_comm = \ ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false) \ && (embedding_size < kMaxColSize); \ size_t tmp_size = 0; \ if (embedding::UseDynamicMemoryAllocation()) { return tmp_size; } \ if (!enable_quantized_comm) { \ size_t reverse_cur_rank_embeddings_size = GetCudaAlignedSize( \ cur_rank_max_num_ids * embedding_size * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ size_t recv_unique_embeddings_size = reverse_cur_rank_embeddings_size; \ tmp_size = reverse_cur_rank_embeddings_size + recv_unique_embeddings_size; \ } else { \ size_t total_elem_cnt = cur_rank_max_num_ids * embedding_size; \ size_t reverse_cur_rank_embeddings_size = \ GetCudaAlignedSize(total_elem_cnt * sizeof(int8_t)); \ size_t recv_unique_embeddings = reverse_cur_rank_embeddings_size; \ size_t quantize_cur_rank_embeddings_size = reverse_cur_rank_embeddings_size; \ size_t reverse_recv_quantize_cur_rank_embeddings_size = \ reverse_cur_rank_embeddings_size; \ size_t cur_rank_quantize_factor_size = \ GetCudaAlignedSize(cur_rank_max_num_ids * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ size_t reverse_cur_rank_quantize_factor_size = cur_rank_quantize_factor_size; \ size_t recv_quantize_factor_size = cur_rank_quantize_factor_size; \ size_t reverse_recv_quantize_factor_size = cur_rank_quantize_factor_size; \ 
tmp_size = reverse_cur_rank_embeddings_size + recv_unique_embeddings \ + quantize_cur_rank_embeddings_size \ + reverse_recv_quantize_cur_rank_embeddings_size \ + cur_rank_quantize_factor_size + reverse_cur_rank_quantize_factor_size \ + recv_quantize_factor_size + reverse_recv_quantize_factor_size; \ } \ return tmp_size; \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_SHUFFLE_KERNEL, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class EmbeddingGradientShuffleKernel final : public user_op::OpKernel { public: EmbeddingGradientShuffleKernel() : current_iter_(0){}; ~EmbeddingGradientShuffleKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); std::unique_ptr allocator = embedding_state->NewTmpBufferAllocator(ctx); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0); const user_op::Tensor* cur_rank_inverse_indices = ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0); const user_op::Tensor* inverse_unique_partition_indices = ctx->Tensor4ArgNameAndIndex("inverse_unique_partition_indices", 0); user_op::Tensor* cur_rank_unique_embedding_grad = ctx->Tensor4ArgNameAndIndex("cur_rank_unique_embedding_grad", 0); const int64_t embedding_size = ctx->Attr("embedding_size"); const bool only_zero_valid_grad = ctx->Attr("only_zero_valid_grad"); IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix(); DataType data_type = embedding_grad->data_type(); const int64_t num_ids = inverse_unique_partition_indices->shape_view().elem_cnt(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const int64_t padded_embedding_size = data_shuffle::GetPaddedEmbeddingSize(data_type, embedding_size); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); ncclComm_t comm = kernel_state->comm(); using ComputeType = typename DefaultComputeType::type; bool enable_quantized_comm_env_var = ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false); bool enable_quantized_comm = enable_quantized_comm_env_var && (padded_embedding_size < kMaxColSize); if (enable_quantized_comm_env_var && !enable_quantized_comm) { LOG(WARNING) << "Only envrionment variable ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 and " "embedding_size less equal than 1024 can use quantized communication. 
"; } const bool skip_first_scatter = ctx->Attr("skip_first_scatter"); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); const std::vector& num_unique_matrix_vec = embedding_state->GetIdNumUniqueMatrix(current_iter_); CHECK_EQ(sizeof(IDX), sizeof(uint32_t)) << "assume sizeof(IDX) equals to sizeof(uint32_t)"; std::memcpy(host_num_unique_matrix, num_unique_matrix_vec.data(), parallel_num * parallel_num * sizeof(IDX)); uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); int64_t cur_rank_num_ids = 0; for (int64_t i = 0; i < parallel_num; ++i) { cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id]; } int64_t unique_partitioned_num_ids = 0; for (int64_t i = 0; i < parallel_num; ++i) { unique_partitioned_num_ids += host_num_unique_matrix[parallel_id * parallel_num + i]; } if (!enable_quantized_comm) { // 1. sum to unique grad, from (num_ids, embedding_size) to (unique_partitioned_num_ids, // padded_embedding_size) void* unique_partition_embedding_grad; // T allocator->Allocate(&unique_partition_embedding_grad, unique_partitioned_num_ids * padded_embedding_size * sizeof(T)); const T* unique_embedding_grad_ptr; if (skip_first_scatter) { unique_embedding_grad_ptr = embedding_grad->dptr(); } else { data_shuffle::UniquePartitionEmbeddingGrad( ctx->stream(), unique_partitioned_num_ids, num_ids, embedding_size, padded_embedding_size, host_num_unique_matrix, embedding_grad->dptr(), reinterpret_cast(inverse_unique_partition_indices->dptr()), reinterpret_cast(unique_partition_embedding_grad)); unique_embedding_grad_ptr = reinterpret_cast(unique_partition_embedding_grad); } // 2. send recv grad, from (unique_partitioned_num_ids, padded_embedding_size) to // (cur_rank_num_ids, padded_embedding_size) void* received_embedding_grad; // T allocator->Allocate(&received_embedding_grad, cur_rank_num_ids * padded_embedding_size * sizeof(T)); data_shuffle::ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, padded_embedding_size, data_type, host_num_unique_matrix, unique_embedding_grad_ptr, reinterpret_cast(received_embedding_grad)); // 3. sum to unique grad, from (cur_rank_num_ids, padded_embedding_size) to (num_unique, // padded_embedding_size) then slice to out from (num_unique, padded_embedding_size) to // (num_unique, embedding_size) should memset cur_rank_unique_embedding_grad all tensor for // amp count_not_finite // use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer. T* buffer_ptr = reinterpret_cast(unique_partition_embedding_grad); data_shuffle::UniqueCurRankEmbeddingGrad( ctx->stream(), data_type, cur_rank_num_ids, num_unique, embedding_size, padded_embedding_size, only_zero_valid_grad, cur_rank_unique_embedding_grad->shape_view().elem_cnt(), reinterpret_cast(received_embedding_grad), reinterpret_cast(cur_rank_inverse_indices->dptr()), cur_rank_unique_embedding_grad->mut_dptr(), buffer_ptr); allocator->Free(unique_partition_embedding_grad); allocator->Free(received_embedding_grad); } else { CHECK(!skip_first_scatter) << "when enable_quantized_comm, should not use fuse kernel."; // 1. 
sum to unique grad, from (num_ids, embedding_size) to (unique_partitioned_num_ids, // padded_embedding_size) void* unique_partition_embedding_grad; // T allocator->Allocate(&unique_partition_embedding_grad, unique_partitioned_num_ids * padded_embedding_size * sizeof(T)); data_shuffle::UniquePartitionEmbeddingGrad( ctx->stream(), unique_partitioned_num_ids, num_ids, embedding_size, padded_embedding_size, host_num_unique_matrix, embedding_grad->dptr(), reinterpret_cast(inverse_unique_partition_indices->dptr()), reinterpret_cast(unique_partition_embedding_grad)); // 2. Quantize unique_partition_embedding_grad, get // quantize_cur_rank_embedding_grad(unique_partitioned_num_ids, padded_embedding_size) int8_t // and cur_rank_quantize_factor(unique_partitioned_num_ids) T void* quantize_cur_rank_embedding_grad; // int8_t allocator->Allocate(&quantize_cur_rank_embedding_grad, unique_partitioned_num_ids * padded_embedding_size * sizeof(int8_t)); void* cur_rank_quantize_factor; // T allocator->Allocate(&cur_rank_quantize_factor, unique_partitioned_num_ids * sizeof(T)); DispatchQuantizeWarpImplPackSize()( cuda_stream, reinterpret_cast(unique_partition_embedding_grad), reinterpret_cast(quantize_cur_rank_embedding_grad), reinterpret_cast(cur_rank_quantize_factor), unique_partitioned_num_ids, padded_embedding_size); // 3. send recv grad, from (unique_partitioned_num_ids, padded_embedding_size) int8_t to // (cur_rank_num_ids, padded_embedding_size) int8_t send recv quantize_factor, from // (unique_partitioned_num_ids) T to (cur_rank_num_ids) T void* received_embedding_grad; // int8_t allocator->Allocate(&received_embedding_grad, cur_rank_num_ids * padded_embedding_size * sizeof(int8_t)); void* received_cur_rank_quantize_factor; // T allocator->Allocate(&received_cur_rank_quantize_factor, cur_rank_num_ids * sizeof(T)); data_shuffle::ShuffleEmbeddingsGrad( cuda_stream, comm, parallel_id, parallel_num, num_ids, padded_embedding_size, data_type, host_num_unique_matrix, reinterpret_cast(quantize_cur_rank_embedding_grad), reinterpret_cast(received_embedding_grad), reinterpret_cast(cur_rank_quantize_factor), reinterpret_cast(received_cur_rank_quantize_factor)); allocator->Free(quantize_cur_rank_embedding_grad); allocator->Free(cur_rank_quantize_factor); /* Host num unique matrix: | Partition0 | Partition1 | | Rank0 | 2 | 4 | | Rank1 | 3 | 3 | After ShuffleEmbeddingGrads, each rank will exchange partition. For example: Rank0 will have (matrix[rank0][part0] + matrix[rank1][part0]) grad tensor. Rank1 will have (matrix[rank0][part1] + matrix[rank1][part1]) grad tensor. */ // 4. dequantize grad, from (cur_rank_num_ids, padded_embedding_size) int8_t to // (cur_rank_num_ids, padded_embedding_size) T void* dequantize_cur_rank_embedding_grad; // T allocator->Allocate(&dequantize_cur_rank_embedding_grad, cur_rank_num_ids * padded_embedding_size * sizeof(T)); OF_CUDA_CHECK((LaunchDequantizeKernel( cuda_stream, reinterpret_cast(received_embedding_grad), reinterpret_cast(received_cur_rank_quantize_factor), reinterpret_cast(dequantize_cur_rank_embedding_grad), padded_embedding_size, cur_rank_num_ids * padded_embedding_size))); allocator->Free(received_embedding_grad); allocator->Free(received_cur_rank_quantize_factor); // use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer. T* buffer_ptr = reinterpret_cast(unique_partition_embedding_grad); // 5. 
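// Illustrative sketch of the two reductions done on host_num_unique_matrix in the kernels
// above (assuming IDX is uint32_t, as the CHECK_EQ asserts): entry m[i * parallel_num + j]
// counts the unique ids rank i contributes to partition j, so the number of rows this rank
// receives after the shuffle is a column sum and the number it sends is a row sum.
// Hypothetical helper name.
#include <cstdint>

inline void CountShuffledIds(const uint32_t* m, int64_t parallel_num, int64_t parallel_id,
                             int64_t* cur_rank_num_ids, int64_t* unique_partitioned_num_ids) {
  *cur_rank_num_ids = 0;            // rows owned by this rank's partition (received)
  *unique_partitioned_num_ids = 0;  // rows this rank sends out, over all partitions
  for (int64_t i = 0; i < parallel_num; ++i) {
    *cur_rank_num_ids += m[i * parallel_num + parallel_id];
    *unique_partitioned_num_ids += m[parallel_id * parallel_num + i];
  }
}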
sum to unique grad, from (cur_rank_num_ids, padded_embedding_size) to (num_unique, // padded_embedding_size) then slice to out from (num_unique, padded_embedding_size) to // (num_unique, embedding_size) should memset cur_rank_unique_embedding_grad all tensor for // amp count_not_finite data_shuffle::UniqueCurRankEmbeddingGrad( ctx->stream(), data_type, cur_rank_num_ids, num_unique, embedding_size, padded_embedding_size, only_zero_valid_grad, cur_rank_unique_embedding_grad->shape_view().elem_cnt(), reinterpret_cast(dequantize_cur_rank_embedding_grad), reinterpret_cast(cur_rank_inverse_indices->dptr()), cur_rank_unique_embedding_grad->mut_dptr(), buffer_ptr); allocator->Free(unique_partition_embedding_grad); allocator->Free(dequantize_cur_rank_embedding_grad); } current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_EMBEDDING_GRADIENT_SHUFFLE_KERNEL(t_dtype_pair, idx_dtype_pair) \ REGISTER_USER_KERNEL("embedding_gradient_shuffle") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("embedding_grad", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)) \ && ((user_op::HobAttr("skip_first_scatter") == false) \ || (!embedding::UseEmbeddingGradientShuffleP2PKernel( \ OF_PP_PAIR_SECOND(t_dtype_pair), OF_PP_PAIR_SECOND(idx_dtype_pair)))) \ && (user_op::HobDataType("num_unique_matrix", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& cur_rank_unique_embedding_grad = \ ctx->InputTensorDesc("cur_rank_unique_embedding_grad", 0); \ size_t cur_rank_embedding_grad_num = cur_rank_unique_embedding_grad.shape().At(0); \ size_t embedding_size = cur_rank_unique_embedding_grad.shape().At(1); \ size_t padded_embedding_size = data_shuffle::GetPaddedEmbeddingSize( \ cur_rank_unique_embedding_grad.data_type(), embedding_size); \ size_t cur_rank_embedding_grad_elem_cnt = \ cur_rank_embedding_grad_num * padded_embedding_size; \ bool enable_quantized_comm = \ ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false) \ && (padded_embedding_size < kMaxColSize); \ size_t tmp_size = 0; \ if (embedding::UseDynamicMemoryAllocation()) { return tmp_size; } \ if (!enable_quantized_comm) { \ size_t cur_rank_embedding_grad_size = GetCudaAlignedSize( \ cur_rank_embedding_grad_elem_cnt * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ tmp_size = 2 * cur_rank_embedding_grad_size; \ } else { \ size_t unique_partition_embedding_grad_size = GetCudaAlignedSize( \ cur_rank_embedding_grad_elem_cnt * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ size_t received_embedding_grad_size = \ GetCudaAlignedSize(cur_rank_embedding_grad_elem_cnt * sizeof(int8_t)); \ size_t quantize_cur_rank_embedding_grad_size = received_embedding_grad_size; \ size_t cur_rank_quantize_factor_size = GetCudaAlignedSize( \ cur_rank_embedding_grad_num * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ size_t received_cur_rank_quantize_factor_size = cur_rank_quantize_factor_size; \ size_t dequantize_cur_rank_embedding_grad_size = unique_partition_embedding_grad_size; \ tmp_size = unique_partition_embedding_grad_size + received_embedding_grad_size \ + quantize_cur_rank_embedding_grad_size + cur_rank_quantize_factor_size \ + received_cur_rank_quantize_factor_size \ + dequantize_cur_rank_embedding_grad_size; \ } \ return tmp_size; \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_GRADIENT_SHUFFLE_KERNEL, FLOATING_DATA_TYPE_SEQ 
HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class EmbeddingUniqueKeyValuePairKernelState final : public user_op::OpKernelState { public: explicit EmbeddingUniqueKeyValuePairKernelState(user_op::KernelInitContext* ctx) : device_index_(-1) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX))); const std::string& embedding_name = ctx->Attr("embedding_name"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); embedding_state_ = Singleton::Get()->GetEmbeddingState( embedding_name, parallel_id); } ~EmbeddingUniqueKeyValuePairKernelState() { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFreeHost(host_num_keys_)); } embedding::EmbeddingState* EmbeddingState() { return embedding_state_; } IDX* HostNumKeys() { return host_num_keys_; } private: int device_index_; embedding::EmbeddingState* embedding_state_; IDX* host_num_keys_; }; template class UniqueKeyValuePairKernel final : public user_op::OpKernel { public: UniqueKeyValuePairKernel() : current_iter_(0){}; ~UniqueKeyValuePairKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); const user_op::Tensor* keys = ctx->Tensor4ArgNameAndIndex("keys", 0); user_op::Tensor* num_unique = ctx->Tensor4ArgNameAndIndex("num_unique", 0); user_op::Tensor* unique_keys = ctx->Tensor4ArgNameAndIndex("unique_keys", 0); user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex("unique_values", 0); user_op::Tensor* inverse_indices = ctx->Tensor4ArgNameAndIndex("inverse_indices", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int32_t num_tables = ctx->Attr("num_tables"); const int64_t padding_idx = ctx->Attr("padding_idx"); const bool has_padding_idx = ctx->Attr("has_padding_idx"); const bool has_values = ctx->has_input("values", 0); const bool need_values_buffer = (!has_values && num_tables > 1); size_t values_buffer_bytes = need_values_buffer ? 
GetCudaAlignedSize(keys->shape_view().elem_cnt() * sizeof(V)) : 0; const int64_t num_keys = keys->shape_view().elem_cnt(); const int64_t hash_capacity = num_keys; const size_t workspace_bytes = GetCudaAlignedSize(hash_capacity * sizeof(data_shuffle::TableEntry)); CHECK_LE(values_buffer_bytes + workspace_bytes, tmp_buffer->shape_view().elem_cnt()); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); const V* values_ptr; if (has_values) { const user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex("values", 0); values_ptr = reinterpret_cast(values->dptr()); } else if (need_values_buffer) { V* values_buffer_ptr = reinterpret_cast(tmp_buffer->mut_dptr()); data_shuffle::GenerateTableIds<<>>(num_keys, num_tables, values_buffer_ptr); values_ptr = values_buffer_ptr; } else { values_ptr = nullptr; } const bool need_process_table_ids = (has_values || num_tables > 1); data_shuffle::TableEntry* workspace_ptr = reinterpret_cast*>( tmp_buffer->mut_dptr() + values_buffer_bytes); data_shuffle::UniqueAndPartition( cuda_stream, num_keys, hash_capacity, 1, reinterpret_cast(keys->dptr()), values_ptr, reinterpret_cast(num_unique->mut_dptr()), reinterpret_cast(unique_keys->mut_dptr()), reinterpret_cast(unique_values->mut_dptr()), reinterpret_cast(inverse_indices->mut_dptr()), workspace_ptr, workspace_bytes, need_process_table_ids, has_padding_idx, padding_idx); IDX* host_num_keys = kernel_state->HostNumKeys(); OF_CUDA_CHECK(cudaMemcpyAsync(host_num_keys, num_unique->mut_dptr(), sizeof(IDX), cudaMemcpyDefault, cuda_stream)); CHECK_JUST(ctx->stream()->Sync()); uint32_t num_unique_ids = *host_num_keys; embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); std::vector num_unique_matrix_vec({num_unique_ids}); embedding_state->SetIdNumUniqueMatrix(num_unique_matrix_vec, current_iter_); embedding_state->SetIdFinalNumUnique(num_unique_ids, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_UNIQUE_KEY_VALUE_PAIR_KERNEL(k_dtype_pair, value_dtype_pair, idx_dtype_pair) \ REGISTER_USER_KERNEL("unique_key_value_pair") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("keys", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \ && (user_op::HobDataType("inverse_indices", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && (user_op::HobDataType("unique_values", 0) == OF_PP_PAIR_SECOND(value_dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& keys = ctx->InputTensorDesc("keys", 0); \ const int64_t num_keys = keys.shape().elem_cnt(); \ const int64_t hash_capacity = num_keys; \ const size_t workspace_bytes = GetCudaAlignedSize( \ hash_capacity * sizeof(data_shuffle::TableEntry)); \ const int32_t num_tables = ctx->Attr("num_tables"); \ const bool has_values = ctx->has_input("values", 0); \ const bool need_values_buffer = (!has_values && num_tables > 1); \ size_t values_buffer_bytes = \ need_values_buffer \ ? 
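// Illustrative host-side reference (not the GPU hash-table path above) for what
// UniqueAndPartition with a single partition produces: the number of unique keys, the
// unique keys themselves, and, for every input position, the index of its key in that
// list. The device kernel may order the unique keys differently; only the mapping
// property matters to the callers. Hypothetical name.
#include <unordered_map>
#include <vector>

template<typename K, typename IDX>
IDX ReferenceUnique(const std::vector<K>& keys, std::vector<K>* unique_keys,
                    std::vector<IDX>* inverse_indices) {
  std::unordered_map<K, IDX> index_of;
  inverse_indices->resize(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    auto it = index_of.find(keys[i]);
    if (it == index_of.end()) {
      it = index_of.emplace(keys[i], static_cast<IDX>(unique_keys->size())).first;
      unique_keys->push_back(keys[i]);
    }
    (*inverse_indices)[i] = it->second;
  }
  return static_cast<IDX>(unique_keys->size());
}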
GetCudaAlignedSize(num_keys * sizeof(OF_PP_PAIR_FIRST(value_dtype_pair))) \ : 0; \ return workspace_bytes + values_buffer_bytes; \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_UNIQUE_KEY_VALUE_PAIR_KERNEL, ID_DATA_TYPE_SEQ, TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class OneEmbeddingGatherKernel final : public user_op::OpKernel { public: OneEmbeddingGatherKernel() : current_iter_(0) {} ~OneEmbeddingGatherKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingGatherStart(ctx, current_iter_); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); const int64_t num_indices = indices->shape_view().elem_cnt(); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const int64_t embedding_size = ctx->Attr("embedding_size"); const T* in_ptr = reinterpret_cast(embedding_state->EmbeddingGatherIn(current_iter_)); GatherKernelUtilImpl::Forward( ctx->stream(), reinterpret_cast(indices->dptr()), num_indices, in_ptr, Shape({1, num_unique, embedding_size}), out->mut_dptr(), 0); embedding_state->OnEmbeddingGatherEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_ONE_EMBEDDING_GATHER_KERNEL(in_type, indices_type) \ REGISTER_USER_KERNEL("one_embedding_gather") \ .SetCreateFn< \ OneEmbeddingGatherKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_type))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ONE_EMBEDDING_GATHER_KERNEL, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("id_shuffle"); REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("embedding_shuffle"); REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("embedding_gradient_shuffle"); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/deconv_cpu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewDeconvTransATransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, true, true); } template std::unique_ptr NewDeconvTransANoTransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, true, false); } auto DeconvTransATransBMatmulPrimitiveExists() { return hob::make_custom("DeconvTransATransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewDeconvTransATransBMatmulPrimitive(&ctx).operator bool(); }); } auto DeconvTransANoTransBMatmulPrimitiveExists() { return hob::make_custom("DeconvTransANoTransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewDeconvTransANoTransBMatmulPrimitive(&ctx).operator bool(); }); } template using Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr); template T* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) { return tensor->mut_dptr() + tensor->shape_view().Count(1) * idx; } template const T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) { return tensor->dptr() + tensor->shape_view().Count(1) * idx; } size_t CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape, const int32_t idx_offset) { int64_t col_buf_elem_cnt = 1; int64_t ndims = out_shape.NumAxes() - 2; for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); } for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); } return col_buf_elem_cnt; } template class ColBufWriter { public: ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : src_ptr_(src_ptr), dst_ptr_(dst_ptr), c_size_(c_size), id_size_(id_size), ih_size_(ih_size), iw_size_(iw_size), od_size_(od_size), oh_size_(oh_size), ow_size_(ow_size) {} virtual ~ColBufWriter() = default; virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; virtual void InvalidDFunc() = 0; virtual void InvalidHFunc() = 0; virtual void InvalidWFunc() = 0; virtual void NextImCSize() = 0; protected: const T* src_ptr_; T* dst_ptr_; int64_t c_size_; int64_t id_size_; int64_t ih_size_; int64_t iw_size_; int64_t od_size_; int64_t oh_size_; int64_t ow_size_; }; template class Col2ImWriter final : public ColBufWriter { public: Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t 
c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : ColBufWriter::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size, oh_size, ow_size) {} ~Col2ImWriter() = default; void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] += *(this->src_ptr_++); } void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++); } void InvalidDFunc() override { this->src_ptr_ += this->od_size_; } void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; } void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; } void NextImCSize() override { this->dst_ptr_ += this->c_size_; } }; template using DHWValidFunc = void (ColBufWriter::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw); template class ColBufUtil final { public: ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before) : strides_(strides), dilation_rate_(dilation_rate), padding_before_(padding_before) { id_num_ = in_shape.At(dhw_offset); ih_num_ = in_shape.At(dhw_offset + 1); iw_num_ = in_shape.At(dhw_offset + 2); od_num_ = out_shape.At(dhw_offset); oh_num_ = out_shape.At(dhw_offset + 1); ow_num_ = out_shape.At(dhw_offset + 2); if (dhw_offset == 2) { dhw_valid_func_ = &ColBufWriter::CDHWWrite; } else { dhw_valid_func_ = &ColBufWriter::DHWCWrite; } } void operator()(ColBufWriter* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) { int64_t id = kd * dilation_rate_[0] - padding_before_[0]; FOR_RANGE(int64_t, od, 0, od_num_) { if (id < 0 || id >= id_num_) { col_buf_writer->InvalidDFunc(); } else { int64_t ih = kh * dilation_rate_[1] - padding_before_[1]; FOR_RANGE(int64_t, oh, 0, oh_num_) { if (ih < 0 || ih >= ih_num_) { col_buf_writer->InvalidHFunc(); } else { int64_t iw = kw * dilation_rate_[2] - padding_before_[2]; FOR_RANGE(int64_t, ow, 0, ow_num_) { if (iw < 0 || iw >= iw_num_) { col_buf_writer->InvalidWFunc(); } else { (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw); } iw += strides_[2]; } } ih += strides_[1]; } } id += strides_[0]; } } private: int64_t id_num_; int64_t ih_num_; int64_t iw_num_; int64_t od_num_; int64_t oh_num_; int64_t ow_num_; const int32_t* strides_; const int32_t* dilation_rate_; const int32_t* padding_before_; DHWValidFunc dhw_valid_func_; }; template struct DeconvKernelUtil final { public: static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before); Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer); } static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before); Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), 
in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), out_shape.Count(3, 4), 1); DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer); } private: static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, ColBufWriter* col_buf_writer) { for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) { for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(4); ++kw) { col_buf_util(col_buf_writer, c, kd, kh, kw); } } } } } static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, ColBufWriter* col_buf_writer) { for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) { for (int64_t c = 0; c != weight_shape.At(4); ++c) { col_buf_util(col_buf_writer, c, kd, kh, kw); } } } } } }; template struct DeconvOpKernelCache final : public user_op::OpKernelCache { Col2ImFunc col2im_func_ = nullptr; Shape in_5d_shape_; Shape out_5d_shape_; Shape weight_5d_shape_; std::vector strides_3d_; std::vector dilation_rate_3d_; std::vector padding_before_3d_; bool is_out_diff_need_trans_ = false; int32_t idx_offset_ = 0; bool is_dynamic_ = false; void Update(const ShapeView& x_shape, const ShapeView& out_shape) { auto Gen5DShape = [](const ShapeView& shape, int32_t idx_offset) -> Shape { DimVector ret_vec; shape.ToDimVector(&ret_vec); int32_t ndims = ret_vec.size() - 2; ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); return Shape(ret_vec); }; if (is_dynamic_) { Shape in_shape; in_5d_shape_ = Gen5DShape(x_shape, idx_offset_); out_5d_shape_ = Gen5DShape(out_shape, idx_offset_); } } }; template std::shared_ptr> CreateDeconvOpKernelCache(user_op::KernelCacheContext* ctx, const std::string& in_name, const std::string& out_name, const std::string& weight_name) { const auto& data_format = ctx->Attr("data_format"); std::shared_ptr> cache(new DeconvOpKernelCache()); if (data_format == "channels_first") { cache->col2im_func_ = DeconvKernelUtil::NCDHWCol2Im; cache->is_out_diff_need_trans_ = false; cache->idx_offset_ = 2; } else { cache->col2im_func_ = DeconvKernelUtil::NDHWCCol2Im; cache->is_out_diff_need_trans_ = true; cache->idx_offset_ = 1; } auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape { DimVector ret_vec(shape.dim_vec()); int32_t ndims = ret_vec.size() - 2; ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); return Shape(ret_vec); }; cache->in_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), cache->idx_offset_); cache->out_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), cache->idx_offset_); cache->weight_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), cache->idx_offset_); auto Gen3DVec = [](const std::vector& origin_vec) -> std::vector { std::vector ret_vec = origin_vec; ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1); return ret_vec; }; cache->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); cache->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); cache->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); const auto& padding_before = ctx->Attr>("padding_before"); FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - padding_before.size()); if (index < 0) { cache->padding_before_3d_.push_back(0); } else { 
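// Illustrative reference (a minimal sketch, not the kernel below): a naive 1-D,
// single-channel transposed convolution with the stride/dilation/padding_before semantics
// assumed here. The CPU kernel below reaches the same result per image with one GEMM
// (weight transposed against the input) followed by the col2im accumulation above.
#include <algorithm>
#include <cstdint>

template<typename T>
void NaiveDeconv1D(const T* in, int64_t in_len, const T* weight, int64_t kernel_size,
                   int64_t stride, int64_t dilation, int64_t pad, T* out, int64_t out_len) {
  std::fill(out, out + out_len, static_cast<T>(0));
  for (int64_t i = 0; i < in_len; ++i) {
    for (int64_t k = 0; k < kernel_size; ++k) {
      const int64_t o = i * stride - pad + k * dilation;
      if (o >= 0 && o < out_len) { out[o] += in[i] * weight[k]; }
    }
  }
}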
cache->padding_before_3d_.push_back(padding_before.at(index)); } } return cache; } template class DeconvCpuKernel final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(DeconvCpuKernel); DeconvCpuKernel() = default; ~DeconvCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr != nullptr && (flag & user_op::OpKernelCache::kAttrNotChanged)) { auto deconv_cache = std::dynamic_pointer_cast>(*cache_ptr); deconv_cache->Update(ctx->TensorDesc4ArgNameAndIndex("in", 0)->shape(), ctx->TensorDesc4ArgNameAndIndex("out", 0)->shape()); return; } *cache_ptr = CreateDeconvOpKernelCache(ctx, "out", "in", "weight"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto deconv_cache = dynamic_cast*>(cache); CHECK_NOTNULL(deconv_cache); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); Memset(ctx->stream(), out->mut_dptr(), 0, out->shape_view().elem_cnt() * sizeof(T)); std::unique_ptr matmul; if (deconv_cache->is_out_diff_need_trans_) { matmul = NewDeconvTransATransBMatmulPrimitive(ctx); } else { matmul = NewDeconvTransANoTransBMatmulPrimitive(ctx); } CHECK(matmul); FOR_RANGE(int64_t, i, 0, in->shape_view().At(0)) { // channels first: col_buf' = weight(T) * in[i]' // channels last : col_buf' = weight(T) * in[i]'(T) // m, n, k int32_t idx_offset = deconv_cache->idx_offset_; matmul->Launch(ctx->stream(), deconv_cache->weight_5d_shape_.Count(1), deconv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), deconv_cache->weight_5d_shape_.At(0), static_cast(1), weight->dptr(), GetImgDptr(in, i), static_cast(0), col_buf->mut_dptr()); // out = col2im(col_buf') deconv_cache->col2im_func_( col_buf->dptr(), ShapeView(deconv_cache->in_5d_shape_), ShapeView(deconv_cache->weight_5d_shape_), ShapeView(deconv_cache->out_5d_shape_), deconv_cache->strides_3d_.data(), deconv_cache->dilation_rate_3d_.data(), deconv_cache->padding_before_3d_.data(), GetImgMutDptr(out, i)); } } }; #define REGISTER_DECONV_DATA_KERNEL(op_name, dtype) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("groups") == 1) \ && (user_op::HobDataType("out", 0) == GetDataType::value) \ && DeconvTransATransBMatmulPrimitiveExists() \ && DeconvTransANoTransBMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& in_shape = ctx->InputTensorDesc("in", 0).shape(); \ const auto& weight_shape = ctx->InputTensorDesc("weight", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ CalcElemNumOfColBuf(in_shape, weight_shape, idx_offset) * sizeof(dtype); \ return tmp_buffer_size; \ }) REGISTER_DECONV_DATA_KERNEL(deconv1d, float); REGISTER_DECONV_DATA_KERNEL(deconv1d, double); REGISTER_DECONV_DATA_KERNEL(deconv2d, float); REGISTER_DECONV_DATA_KERNEL(deconv2d, double); REGISTER_DECONV_DATA_KERNEL(deconv3d, float); REGISTER_DECONV_DATA_KERNEL(deconv3d, double); } // namespace } // namespace oneflow ================================================ FILE: 
oneflow/user/kernels/deconv_cudnn_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/device/cudnn_conv_util.h" #include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template struct CudnnDeConvArgsAndAlgo final { using AlgoT = decltype(std::declval().algo); CudnnConvArgs args; PerfT algo_perf; // CudnnDeConvArgsAndAlgo CudnnDeConvArgsAndAlgo(const user_op::Tensor* x, const user_op::Tensor* w, const user_op::Tensor* y, user_op::Tensor* buf, const user_op::KernelComputeContext* ctx, ep::Stream* stream, bool has_forced_algo, int32_t forced_algo) : args(*ctx, x->data_type(), x->shape_view(), w->data_type(), w->shape_view(), y->data_type(), y->shape_view(), ctx->Attr("data_format"), buf->shape_view().elem_cnt(), Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_heuristic_search_algo(), Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_use_deterministic_algo_only(), Singleton::Get() ->resource() .cudnn_conf() .cudnn_conv_enable_pseudo_half()) { size_t byte_size_of_buf = buf->shape_view().elem_cnt(); AllocatedCudnnConvResource res(stream->As()->cudnn_handle(), const_cast(x->dptr()), const_cast(w->dptr()), const_cast(y->dptr()), buf->mut_dptr()); if (has_forced_algo) { algo_perf = GetCudnnConvAlgorithmPerferenceWithResource( &args, &res, static_cast(forced_algo)); } else { algo_perf = FindCudnnConvAlgorithmWithResource(&args, &res); } CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS) << "op (" << ctx->op_name() << ") find algorithm perference failed. 
algo: " << algo_perf.algo; CHECK_LE(algo_perf.memory, byte_size_of_buf) << "op (" << ctx->op_name() << ") find algorithm " << algo_perf.algo << ", need memory " << algo_perf.memory << ", but cudnn_buf_limit_byte is " << byte_size_of_buf; } CudnnDeConvArgsAndAlgo() = delete; OF_DISALLOW_COPY_AND_MOVE(CudnnDeConvArgsAndAlgo); }; template size_t InferTmpSizeWithCudnn(const user_op::TensorDesc* x, const user_op::TensorDesc* w, const user_op::TensorDesc* y, const user_op::InferContext& ctx, bool has_forced_algo, int32_t forced_algo) { using AlgoT = decltype(std::declval().algo); const auto& cudnn_conf = Singleton::Get()->resource().cudnn_conf(); size_t workspace_size = cudnn_conf.cudnn_buf_limit_mbyte() * 1024 * 1024; if (!x->is_dynamic()) { CudnnConvArgs args(ctx, x->data_type(), ShapeView(x->shape()), w->data_type(), ShapeView(w->shape()), y->data_type(), ShapeView(y->shape()), ctx.Attr("data_format"), workspace_size, cudnn_conf.cudnn_conv_heuristic_search_algo(), cudnn_conf.cudnn_conv_use_deterministic_algo_only(), cudnn_conf.cudnn_conv_enable_pseudo_half()); PerfT algo_perf; if (has_forced_algo) { algo_perf = GetCudnnConvAlgorithmPerference(&args, static_cast(forced_algo)); } else { algo_perf = FindCudnnConvAlgorithm(&args); } CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS) << "op (" << ctx.op_name() << ") find algorithm perference failed. algo: " << algo_perf.algo; CHECK_LE(algo_perf.memory, workspace_size) << "op (" << ctx.op_name() << ") find algorithm " << algo_perf.algo << ", need memory " << algo_perf.memory << ", but cudnn_buf_limit_byte is " << workspace_size; workspace_size = algo_perf.memory; } workspace_size = std::max(size_t(1), workspace_size); return workspace_size; } } // namespace template class DeConvGpuKernel final : public user_op::OpKernel { public: DeConvGpuKernel() = default; ~DeConvGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); if (in->shape_view().elem_cnt() == 0) return; const auto& cudnn_conf = Singleton::Get()->resource().cudnn_conf(); CudnnDeConvArgsAndAlgo args_and_algo( out, weight, in, buf, ctx, ctx->stream(), cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), cudnn_conf.cudnn_conv_force_bwd_data_algo()); const CudnnConvArgs& args = args_and_algo.args; const cudnnConvolutionBwdDataAlgoPerf_t& algo_perf = args_and_algo.algo_perf; OF_CUDNN_CHECK(cudnnConvolutionBackwardData( ctx->stream()->As()->cudnn_handle(), CudnnSPOnePtr(), args.wdesc.Get(), weight->dptr(), args.ydesc.Get(), in->dptr(), args.cdesc.Get(), algo_perf.algo, buf->mut_dptr(), args.params.max_ws_size, CudnnSPZeroPtr(), args.xdesc.Get(), out->mut_dptr())); } }; #define REGISTER_DECONV_KERNEL(op_name, dtype, ndims) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const auto& in = ctx->InputTensorDesc("in", 0); \ if (in.shape().elem_cnt() == 0) return 0; \ const auto& weight = ctx->InputTensorDesc("weight", 0); \ const auto& out = ctx->OutputTensorDesc("out", 0); \ const auto& cudnn_conf = \ 
Singleton::Get()->resource().cudnn_conf(); \ return InferTmpSizeWithCudnn( \ &out, &weight, &in, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), \ cudnn_conf.cudnn_conv_force_bwd_data_algo()); \ }) REGISTER_DECONV_KERNEL(deconv1d, float, 1); REGISTER_DECONV_KERNEL(deconv2d, float, 2); REGISTER_DECONV_KERNEL(deconv3d, float, 3); REGISTER_DECONV_KERNEL(deconv1d, double, 1); REGISTER_DECONV_KERNEL(deconv2d, double, 2); REGISTER_DECONV_KERNEL(deconv3d, double, 3); } // namespace oneflow #endif ================================================ FILE: oneflow/user/kernels/deform_conv_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/user_op_hob.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template T get_coordinate_weight(const T* im_data, int height, int width, T y, T x, bool is_y_direction) { int y_l = floor(y); int x_l = floor(x); int y_h = y_l + 1; int x_h = x_l + 1; bool valid_y_l = 0 <= y_l && y_l < height; bool valid_y_h = 0 <= y_h && y_h < height; bool valid_x_l = 0 <= x_l && x_l < width; bool valid_x_h = 0 <= x_h && x_h < width; T zero = 0; T v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; T v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; T v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; T v_YX = (valid_y_h && valid_x_h) ? 
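// Worked derivation for the expressions below: writing the bilinear sample as a function
// of the fractional offsets dy = y - y_l and dx = x - x_l,
//   f(dy, dx) = (1-dy)(1-dx)*v_yx + (1-dy)*dx*v_yX + dy*(1-dx)*v_Yx + dy*dx*v_YX,
// its partial derivatives are
//   df/ddy = dx*(v_YX - v_yX) + (1-dx)*(v_Yx - v_yx)
//   df/ddx = dy*(v_YX - v_Yx) + (1-dy)*(v_yX - v_yx),
// which are exactly the two branches returned by get_coordinate_weight (is_y_direction
// true/false). Zero-padding the out-of-range neighbors above keeps the same formulas
// valid at the image border.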
im_data[y_h * width + x_h] : zero; if (is_y_direction) { T dx = x - x_l; return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); } else { T dy = y - y_l; return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); } } template T DeformableIm2ColBilinear(const T* bottom_data, const int data_width, const int height, const int width, T h, T w) { int h_low = floor(h); int w_low = floor(w); int h_high = h_low + 1; int w_high = w_low + 1; T lh = h - h_low; T lw = w - w_low; T hh = 1 - lh, hw = 1 - lw; T v1 = 0; if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; T v2 = 0; if (h_low >= 0 && w_high <= width - 1) v2 = bottom_data[h_low * data_width + w_high]; T v3 = 0; if (h_high <= height - 1 && w_low >= 0) v3 = bottom_data[h_high * data_width + w_low]; T v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) v4 = bottom_data[h_high * data_width + w_high]; T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); return val; } template T GetGradientWeight(T argmax_h, T argmax_w, const int h, const int w, const int height, const int width) { if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { // empty return 0; } int argmax_h_low = floor(argmax_h); int argmax_w_low = floor(argmax_w); int argmax_h_high = argmax_h_low + 1; int argmax_w_high = argmax_w_low + 1; T weight = 0; if (h == argmax_h_low && w == argmax_w_low) weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); if (h == argmax_h_low && w == argmax_w_high) weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); if (h == argmax_h_high && w == argmax_w_low) weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); if (h == argmax_h_high && w == argmax_w_high) weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); return weight; } template T GetCoordinateWeight(T argmax_h, T argmax_w, const int height, const int width, const T* im_data, const int data_width, const int bp_dir) { if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { // empty return static_cast(0); } int argmax_h_low = floor(argmax_h); int argmax_w_low = floor(argmax_w); int argmax_h_high = argmax_h_low + 1; int argmax_w_high = argmax_w_low + 1; T weight = 0; if (bp_dir == 0) { if (argmax_h_low >= 0 && argmax_w_low >= 0) weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; if (argmax_h_low >= 0 && argmax_w_high <= width - 1) weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; if (argmax_h_high <= height - 1 && argmax_w_low >= 0) weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; } else if (bp_dir == 1) { if (argmax_h_low >= 0 && argmax_w_low >= 0) weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; if (argmax_h_low >= 0 && argmax_w_high <= width - 1) weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; if (argmax_h_high <= height - 1 && argmax_w_low >= 0) weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; } return weight; } template T bilinear_interpolate(const T* in, int height, int width, T h, T w) { if (h <= -1 || 
height <= h || w <= -1 || width <= w) { return 0; } int h_low = floor(h); int w_low = floor(w); int h_high = h_low + 1; int w_high = w_low + 1; T lh = h - h_low; T lw = w - w_low; T hh = 1 - lh, hw = 1 - lw; T v1 = 0; if (h_low >= 0 && w_low >= 0) v1 = in[h_low * width + w_low]; T v2 = 0; if (h_low >= 0 && w_high <= width - 1) v2 = in[h_low * width + w_high]; T v3 = 0; if (h_high <= height - 1 && w_low >= 0) v3 = in[h_high * width + w_low]; T v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) v4 = in[h_high * width + w_high]; T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); return val; } template void DeformableIm2Col(int n, const T* input, const T* offset, const T* mask, int height, int width, int weight_h, int weight_w, int pad_h, int pad_w, int stride_h, int stride_w, int dilation_h, int dilation_w, int batch_sz, int n_in_channels, int n_offset_grps, int out_h, int out_w, bool use_mask, T* columns) { for (int index = 0; index != n; ++index) { const int out_x = index % out_w; const int out_y = (index / out_w) % out_h; const int out_b = (index / (out_w * out_h)) % batch_sz; const int in_c = index / (out_w * out_h * batch_sz); const int out_c = in_c * weight_h * weight_w; int c_per_offset_grp = n_in_channels / n_offset_grps; const int grp_idx = in_c / c_per_offset_grp; auto columns_ptr = columns + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + out_y * out_w + out_x); auto input_ptr = input + (out_b * (n_in_channels * height * width) + in_c * (height * width)); auto offset_ptr = offset + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; auto mask_ptr = mask; if (use_mask) { mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w; } for (int i = 0; i < weight_h; ++i) { for (int j = 0; j < weight_w; ++j) { const int mask_idx = i * weight_w + j; const int offset_idx = 2 * mask_idx; T mask_value = 1; if (use_mask) { mask_value = mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; } const T offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; const T offset_w = offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; const T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; const T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; *columns_ptr = mask_value * bilinear_interpolate(input_ptr, height, width, y, x); columns_ptr += batch_sz * out_h * out_w; } } } } template void DeformableCol2Im(int n, const T* col, const T* offset_data, const T* mask_data, int channels, int height, int width, int kernel_h, int kernel_w, int pad_h, int pad_w, int stride_h, int stride_w, int dilation_h, int dilation_w, int batch_sz, int n_offset_grps, int out_h, int out_w, bool use_mask, T* grad_im) { for (int index = 0; index != n; ++index) { const int out_x = index % out_w; const int out_y = (index / out_w) % out_h; const int b = (index / (out_w * out_h)) % batch_sz; const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); int c_per_offset_grp = channels / n_offset_grps; const int offset_grp = c / c_per_offset_grp; auto offset_ptr = offset_data; offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; auto mask_ptr = mask_data; if (use_mask) { mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * out_h * out_w; } const 
int mask_idx = i * kernel_w + j; const int offset_idx = 2 * mask_idx; const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; const T offset_h = offset_ptr[offset_h_ptr]; const T offset_w = offset_ptr[offset_w_ptr]; T mask_value = 1; if (use_mask) { mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; } const T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; const T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; for (int dy = -1; dy <= 1; dy++) { for (int dx = -1; dx <= 1; dx++) { int yp = (int)y + dy; int xp = (int)x + dx; if (0 <= yp && yp < height && 0 <= xp && xp < width && std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { int grad_pos = ((b * channels + c) * height + yp) * width + xp; T weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); grad_im[grad_pos] += mask_value * weight * col[index]; } } } } } template void DeformableCol2ImCoord(int n, const T* col_data, const T* im_data, const T* offset_data, const T* mask_data, int channels, int height, int width, int weight_h, int weight_w, int pad_h, int pad_w, int stride_h, int stride_w, int dilation_h, int dilation_w, int batch_sz, int offset_channels, int n_offset_grps, int out_h, int out_w, const bool use_mask, T* grad_offset, T* grad_mask) { for (int index = 0; index != n; ++index) { T grad_offset_val = 0; T grad_mask_val = 0; int w = index % out_w; int h = (index / out_w) % out_h; int w_w = (index / (out_w * out_h * 2)) % weight_w; int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; int c = (index / (out_w * out_h)) % offset_channels; int b = index / (out_w * out_h * offset_channels); const int offset_grp = c / (2 * weight_h * weight_w); const int col_step = weight_h * weight_w; int c_per_offset_grp = channels / n_offset_grps; auto col_ptr = col_data; col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * out_h; auto im_ptr = im_data; im_ptr += (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; auto offset_ptr = offset_data; offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w; auto mask_ptr = mask_data; if (use_mask) { mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * out_h * out_w; } const int offset_c = c - offset_grp * 2 * weight_h * weight_w; const bool is_y_direction = offset_c % 2 == 0; const int c_bound = c_per_offset_grp * weight_h * weight_w; for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w; int out_x = col_pos % out_w; int out_y = (col_pos / out_w) % out_h; int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; const int mask_idx = i * weight_w + j; const int offset_h_ptr = (((2 * mask_idx) * out_h + out_y) * out_w + out_x); const int offset_w_ptr = (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); const T offset_h = offset_ptr[offset_h_ptr]; const T offset_w = offset_ptr[offset_w_ptr]; T mask_value = 1; if (use_mask) { mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; } T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; const T weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); grad_offset_val += mask_value * weight * col_ptr[col_pos]; if (use_mask && is_y_direction) { grad_mask_val += 
col_ptr[col_pos] * bilinear_interpolate(im_ptr, height, width, y, x); } im_ptr += height * width; } grad_offset[index] = grad_offset_val; if (use_mask && is_y_direction) { const int idx = ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + w_w) * out_h + h) * out_w + w; grad_mask[idx] = grad_mask_val; } } } ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewPermutePrimitive(Context* ctx, const int& num_dims) { return ep::primitive::NewPrimitive(ctx->device_type(), num_dims); } template class DeformableConv2dCpuKernel final : public user_op::OpKernel { public: DeformableConv2dCpuKernel() = default; ~DeformableConv2dCpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex("offset", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& input_shape = input->shape_view(); const ShapeView& output_shape = output->shape_view(); const ShapeView& weight_shape = weight->shape_view(); const int64_t out_elem_cnt = output_shape.elem_cnt(); const int64_t output_bytes = (out_elem_cnt * sizeof(T)); T* column_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr() + output_bytes); const int32_t kW = weight->shape_view().At(2); const int32_t kH = weight->shape_view().At(3); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); const int32_t group = ctx->Attr("groups"); const int32_t deformable_group = ctx->Attr("offset_groups"); const bool use_mask = ctx->Attr("use_mask"); const int64_t outputWidth = ((input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW) + 1; const int64_t outputHeight = ((input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH) + 1; const int64_t column_nums = input_shape.At(1) * input_shape.At(0) * outputHeight * outputWidth; if (column_nums > 0) { DeformableIm2Col(column_nums, input->dptr(), offset->dptr(), mask->dptr(), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0), input_shape.At(1), deformable_group, output_shape.At(2), output_shape.At(3), use_mask, column_tmp_buffer); const int64_t weight_group_offset = weight->shape_view().elem_cnt() / group; const int64_t column_group_offset = input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group; const int64_t output_group_offset = out_elem_cnt / group; auto matmul = NewMatmulPrimitive(ctx->device_type(), output->data_type(), false, false); CHECK(matmul); FOR_RANGE(int, g, 0, group) { matmul->Launch(ctx->stream(), 
weight_shape.At(0) / group, input_shape.At(0) * outputHeight * outputWidth, input_shape.At(1) * kW * kH / group, static_cast(1), weight->dptr() + g * weight_group_offset, column_tmp_buffer + g * column_group_offset, static_cast(0), tmp_buffer->mut_dptr() + g * output_group_offset); } std::vector out_shapevec( {output_shape.At(1), output_shape.At(0), output_shape.At(2), output_shape.At(3)}); auto transpose = NewPermutePrimitive(ctx, output_shape.NumAxes()); CHECK(transpose); transpose->Launch(ctx->stream(), output->data_type(), output_shape.NumAxes(), out_shapevec.data(), tmp_buffer->dptr(), std::vector({1, 0, 2, 3}).data(), output->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class DeformableConv2dInputGradCpuKernel final : public user_op::OpKernel { public: DeformableConv2dInputGradCpuKernel() = default; ~DeformableConv2dInputGradCpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* output_grad = ctx->Tensor4ArgNameAndIndex("output_grad", 0); const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex("offset", 0); user_op::Tensor* input_grad = ctx->Tensor4ArgNameAndIndex("input_grad", 0); user_op::Tensor* offset_grad = ctx->Tensor4ArgNameAndIndex("offset_grad", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& output_grad_shape = output_grad->shape_view(); const ShapeView& input_shape = input->shape_view(); const ShapeView& weight_shape = weight->shape_view(); const int32_t kW = weight->shape_view().At(2); const int32_t kH = weight->shape_view().At(3); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); const int32_t group = ctx->Attr("groups"); const int32_t deformable_group = ctx->Attr("offset_groups"); const bool use_mask = ctx->Attr("use_mask"); const T* data_mask = nullptr; T* data_mask_grad = nullptr; if (use_mask) { data_mask = ctx->Tensor4ArgNameAndIndex("mask", 0)->dptr(); data_mask_grad = ctx->Tensor4ArgNameAndIndex("mask_grad", 0)->mut_dptr(); } const int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; const int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type()); primitive->Launch(ctx->stream(), input_grad->mut_dptr(), 0, input_grad->shape_view().elem_cnt() * sizeof(T)); if (use_mask) { primitive->Launch( ctx->stream(), data_mask_grad, 0, ctx->Tensor4ArgNameAndIndex("mask_grad", 0)->shape_view().elem_cnt() * sizeof(T)); } const int64_t nthreads_coord = outputHeight * outputWidth * 2 * kH * kW * deformable_group * input_shape.At(0); const int64_t nthreads_feat = outputHeight * outputWidth * input_shape.At(0) * kH * kW * input_shape.At(1); if (nthreads_coord > 0 && nthreads_feat > 0) { const int64_t weight_group_offset = weight_shape.elem_cnt() / group; const int64_t output_grad_group_offset = output_grad_shape.Count(1) / group; const int64_t column_group_offset = input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group; auto matmul = 
NewMatmulPrimitive(ctx->device_type(), input_grad->data_type(), true, true); CHECK(matmul); FOR_RANGE(int, g, 0, group) { matmul->Launch(ctx->stream(), weight_shape.Count(1), input_shape.At(0) * outputHeight * outputWidth, weight_shape.At(0) / group, static_cast(1), weight->dptr() + g * weight_group_offset, output_grad->dptr() + g * output_grad_group_offset, static_cast(0), tmp_buffer->mut_dptr() + g * column_group_offset); } DeformableCol2ImCoord( nthreads_coord, tmp_buffer->dptr(), input->dptr(), offset->dptr(), data_mask, input_shape.At(1), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0), 2 * kH * kW * deformable_group, deformable_group, outputHeight, outputWidth, use_mask, offset_grad->mut_dptr(), data_mask_grad); DeformableCol2Im(nthreads_feat, tmp_buffer->dptr(), offset->dptr(), data_mask, input_shape.At(1), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0), deformable_group, outputHeight, outputWidth, use_mask, input_grad->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class DeformableConv2dParamGradCpuKernel final : public user_op::OpKernel { public: DeformableConv2dParamGradCpuKernel() = default; ~DeformableConv2dParamGradCpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* output_grad = ctx->Tensor4ArgNameAndIndex("output_grad", 0); const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex("offset", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* weight_grad = ctx->Tensor4ArgNameAndIndex("weight_grad", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& output_grad_shape = output_grad->shape_view(); const ShapeView& weight_grad_shape = weight_grad->shape_view(); const ShapeView& input_shape = input->shape_view(); const int64_t out_elem_cnt = output_grad_shape.elem_cnt(); const int64_t output_bytes = (out_elem_cnt * sizeof(T)); T* column_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr() + output_bytes); const int32_t kW = weight->shape_view().At(2); const int32_t kH = weight->shape_view().At(3); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); const int32_t group = ctx->Attr("groups"); const int32_t deformable_group = ctx->Attr("offset_groups"); const bool use_mask = ctx->Attr("use_mask"); const int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; const int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; const int64_t column_nums = input_shape.At(1) * input_shape.At(0) * outputHeight * outputWidth; if (column_nums > 0) { DeformableIm2Col(column_nums, input->dptr(), offset->dptr(), mask->dptr(), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0), input_shape.At(1), deformable_group, output_grad_shape.At(2), output_grad_shape.At(3), use_mask, column_tmp_buffer); std::unique_ptr primitive = 
ep::primitive::NewPrimitive(ctx->stream()->device_type()); primitive->Launch(ctx->stream(), weight_grad->mut_dptr(), 0, weight_grad->shape_view().elem_cnt()); std::vector output_grad_buffer_vec({output_grad_shape.At(1), output_grad_shape.At(0), output_grad_shape.At(2), output_grad_shape.At(3)}); auto transpose = NewPermutePrimitive(ctx, output_grad_shape.NumAxes()); CHECK(transpose); transpose->Launch(ctx->stream(), output_grad->data_type(), output_grad_shape.NumAxes(), output_grad_buffer_vec.data(), output_grad->dptr(), std::vector({1, 0, 2, 3}).data(), tmp_buffer->mut_dptr()); const int64_t output_grad_group_offset = output_grad_shape.elem_cnt() / group; const int64_t column_group_offset = input_shape.At(1) * kW * kW * input_shape.At(0) * outputHeight * outputWidth / group; const int64_t weight_grad_group_offset = weight_grad->shape_view().elem_cnt() / group; FOR_RANGE(int, g, 0, group) { auto matmul = NewMatmulPrimitive(ctx->device_type(), weight_grad->data_type(), false, true); CHECK(matmul); matmul->Launch(ctx->stream(), weight_grad_shape.At(0) / group, input_shape.At(1) * kW * kH / group, input_shape.At(0) * outputHeight * outputWidth, static_cast(1), tmp_buffer->dptr() + g * output_grad_group_offset, column_tmp_buffer + g * column_group_offset, static_cast(0), weight_grad->mut_dptr() + g * weight_grad_group_offset); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DEFORM_CONV2D_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("deform_conv2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& input_shape = ctx->InputShape("input", 0); \ const Shape& output_shape = ctx->OutputShape("output", 0); \ const Shape& weight_shape = ctx->InputShape("weight", 0); \ const int32_t kW = weight_shape.At(2); \ const int32_t kH = weight_shape.At(3); \ const int32_t dW = ctx->Attr("stride_w"); \ const int32_t dH = ctx->Attr("stride_h"); \ const int32_t padW = ctx->Attr("pad_w"); \ const int32_t padH = ctx->Attr("pad_h"); \ const int32_t dilationW = ctx->Attr("dilation_w"); \ const int32_t dilationH = ctx->Attr("dilation_h"); \ const int64_t outputWidth = \ ((input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW) + 1; \ const int64_t outputHeight = \ ((input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH) + 1; \ const int64_t column_bytes = (input_shape.At(1) * kW * kH * input_shape.At(0) \ * outputHeight * outputWidth * sizeof(dtype)); \ const int64_t output_bytes = (output_shape.elem_cnt() * sizeof(dtype)); \ return column_bytes + output_bytes; \ }); REGISTER_DEFORM_CONV2D_CPU_KERNEL(float) REGISTER_DEFORM_CONV2D_CPU_KERNEL(double) #define REGISTER_DEFORM_CONV2D_INPUT_GRAD_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("deform_conv2d_input_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("weight", 0) == GetDataType::value) \ && (user_op::HobDataType("offset", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& input_shape = ctx->InputShape("input", 0); \ const Shape& weight_shape = ctx->InputShape("weight", 0); \ const int32_t kW = weight_shape.At(2); \ const int32_t kH = weight_shape.At(3); \ const int32_t dW = ctx->Attr("stride_w"); \ const int32_t dH = ctx->Attr("stride_h"); \ const int32_t padW = 
ctx->Attr("pad_w"); \ const int32_t padH = ctx->Attr("pad_h"); \ const int32_t dilationW = ctx->Attr("dilation_w"); \ const int32_t dilationH = ctx->Attr("dilation_h"); \ const int64_t outputWidth = \ (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; \ const int64_t outputHeight = \ (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; \ const int64_t column_bytes = input_shape.At(1) * kW * kH * input_shape.At(0) \ * outputHeight * outputWidth * sizeof(dtype); \ return column_bytes; \ }); REGISTER_DEFORM_CONV2D_INPUT_GRAD_CPU_KERNEL(float) REGISTER_DEFORM_CONV2D_INPUT_GRAD_CPU_KERNEL(double) #define REGISTER_DEFORM_CONV2D_PARAM_GRAD_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("deform_conv2d_param_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("offset", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& input_shape = ctx->InputShape("input", 0); \ const Shape& output_grad_shape = ctx->InputShape("output_grad", 0); \ const Shape& weight_shape = ctx->InputShape("weight", 0); \ const int32_t kW = weight_shape.At(2); \ const int32_t kH = weight_shape.At(3); \ const int32_t dW = ctx->Attr("stride_w"); \ const int32_t dH = ctx->Attr("stride_h"); \ const int32_t padW = ctx->Attr("pad_w"); \ const int32_t padH = ctx->Attr("pad_h"); \ const int32_t dilationW = ctx->Attr("dilation_w"); \ const int32_t dilationH = ctx->Attr("dilation_h"); \ const int64_t outputWidth = \ (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; \ const int64_t outputHeight = \ (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; \ const int64_t column_bytes = (input_shape.At(1) * kW * kH * input_shape.At(0) \ * outputHeight * outputWidth * sizeof(dtype)); \ const int64_t output_bytes = (output_grad_shape.elem_cnt() * sizeof(dtype)); \ return column_bytes + output_bytes; \ }); REGISTER_DEFORM_CONV2D_PARAM_GRAD_CPU_KERNEL(float) REGISTER_DEFORM_CONV2D_PARAM_GRAD_CPU_KERNEL(double) } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/deform_conv_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/user_op_hob.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { namespace { __device__ __forceinline__ float Add(float* address, float val) { return atomicAdd(address, val); } __device__ __forceinline__ double Add(double* address, double val) { #if __CUDA_ARCH__ >= 600 return atomicAdd(address, val); #else auto address_as_ull = reinterpret_cast(address); unsigned long long int old = *address_as_ull; unsigned long long int assumed = 0; do { assumed = old; old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); } while (assumed != old); return __longlong_as_double(old); #endif } template __device__ T bilinear_interpolate(const T* in, int height, int width, T h, T w) { if (h <= -1 || height <= h || w <= -1 || width <= w) { return 0; } int h_low = floor(h); int w_low = floor(w); int h_high = h_low + 1; int w_high = w_low + 1; T lh = h - h_low; T lw = w - w_low; T hh = 1 - lh, hw = 1 - lw; T v1 = 0; if (h_low >= 0 && w_low >= 0) v1 = in[h_low * width + w_low]; T v2 = 0; if (h_low >= 0 && w_high <= width - 1) v2 = in[h_low * width + w_high]; T v3 = 0; if (h_high <= height - 1 && w_low >= 0) v3 = in[h_high * width + w_low]; T v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) v4 = in[h_high * width + w_high]; T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); return val; } template __device__ T DeformableIm2ColBilinear(const T* bottom_data, const int data_width, const int height, const int width, T h, T w) { int h_low = floor(h); int w_low = floor(w); int h_high = h_low + 1; int w_high = w_low + 1; T lh = h - h_low; T lw = w - w_low; T hh = 1 - lh, hw = 1 - lw; T v1 = 0; if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; T v2 = 0; if (h_low >= 0 && w_high <= width - 1) v2 = bottom_data[h_low * data_width + w_high]; T v3 = 0; if (h_high <= height - 1 && w_low >= 0) v3 = bottom_data[h_high * data_width + w_low]; T v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) v4 = bottom_data[h_high * data_width + w_high]; T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); return val; } template __device__ T get_coordinate_weight(const T* im_data, int height, int width, T y, T x, bool is_y_direction) { int y_l = floor(y); int x_l = floor(x); int y_h = y_l + 1; int x_h = x_l + 1; bool valid_y_l = 0 <= y_l && y_l < height; bool valid_y_h = 0 <= y_h && y_h < height; bool valid_x_l = 0 <= x_l && x_l < width; bool valid_x_h = 0 <= x_h && x_h < width; T zero = 0; T v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; T v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; T v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; T v_YX = (valid_y_h && valid_x_h) ? 
im_data[y_h * width + x_h] : zero; if (is_y_direction) { T dx = x - x_l; return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); } else { T dy = y - y_l; return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); } } template __device__ T GetGradientWeight(T argmax_h, T argmax_w, const int h, const int w, const int height, const int width) { if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { // empty return static_cast(0); } int argmax_h_low = floor(argmax_h); int argmax_w_low = floor(argmax_w); int argmax_h_high = argmax_h_low + 1; int argmax_w_high = argmax_w_low + 1; T weight = 0; if (h == argmax_h_low && w == argmax_w_low) weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); if (h == argmax_h_low && w == argmax_w_high) weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); if (h == argmax_h_high && w == argmax_w_low) weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); if (h == argmax_h_high && w == argmax_w_high) weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); return weight; } template __device__ T GetCoordinateWeight(T argmax_h, T argmax_w, const int height, const int width, const T* im_data, const int data_width, const int bp_dir) { if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { // empty return 0; } int argmax_h_low = floor(argmax_h); int argmax_w_low = floor(argmax_w); int argmax_h_high = argmax_h_low + 1; int argmax_w_high = argmax_w_low + 1; T weight = 0; if (bp_dir == 0) { if (argmax_h_low >= 0 && argmax_w_low >= 0) weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; if (argmax_h_low >= 0 && argmax_w_high <= width - 1) weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; if (argmax_h_high <= height - 1 && argmax_w_low >= 0) weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; } else if (bp_dir == 1) { if (argmax_h_low >= 0 && argmax_w_low >= 0) weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; if (argmax_h_low >= 0 && argmax_w_high <= width - 1) weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; if (argmax_h_high <= height - 1 && argmax_w_low >= 0) weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; } return weight; } template __global__ void DeformableCol2Im(int n, const T* col, const T* offset_data, const T* mask_data, int channels, int height, int width, int kernel_h, int kernel_w, int pad_h, int pad_w, int stride_h, int stride_w, int dilation_h, int dilation_w, int batch_sz, int n_offset_grps, int out_h, int out_w, bool use_mask, T* grad_im) { CUDA_1D_KERNEL_LOOP(index, n) { const int out_x = index % out_w; const int out_y = (index / out_w) % out_h; const int b = (index / (out_w * out_h)) % batch_sz; const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); int c_per_offset_grp = channels / n_offset_grps; const int offset_grp = c / c_per_offset_grp; auto offset_ptr = offset_data; offset_ptr += (b * n_offset_grps + 
offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; auto mask_ptr = mask_data; if (use_mask) { mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * out_h * out_w; } const int mask_idx = i * kernel_w + j; const int offset_idx = 2 * mask_idx; const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; const T offset_h = offset_ptr[offset_h_ptr]; const T offset_w = offset_ptr[offset_w_ptr]; T mask_value = 1; if (use_mask) { mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; } const T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; const T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; for (int dy = -1; dy <= 1; dy++) { for (int dx = -1; dx <= 1; dx++) { int yp = (int)y + dy; int xp = (int)x + dx; if (0 <= yp && yp < height && 0 <= xp && xp < width && abs(y - yp) < 1 && abs(x - xp) < 1) { int grad_pos = ((b * channels + c) * height + yp) * width + xp; T weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); Add(grad_im + grad_pos, mask_value * weight * col[index]); } } } } } template __global__ void DeformableIm2Col(int n, const T* input, const T* offset, const T* mask, int height, int width, int weight_h, int weight_w, int pad_h, int pad_w, int stride_h, int stride_w, int dilation_h, int dilation_w, int batch_sz, int n_in_channels, int n_offset_grps, int out_h, int out_w, bool use_mask, T* columns) { CUDA_1D_KERNEL_LOOP(index, n) { const int out_x = index % out_w; const int out_y = (index / out_w) % out_h; const int out_b = (index / (out_w * out_h)) % batch_sz; const int in_c = index / (out_w * out_h * batch_sz); const int out_c = in_c * weight_h * weight_w; int c_per_offset_grp = n_in_channels / n_offset_grps; const int grp_idx = in_c / c_per_offset_grp; auto columns_ptr = columns; columns_ptr += (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + out_y * out_w + out_x); auto input_ptr = input; input_ptr += (out_b * (n_in_channels * height * width) + in_c * (height * width)); auto offset_ptr = offset; offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; auto mask_ptr = mask; if (use_mask) { mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w; } for (int i = 0; i < weight_h; ++i) { for (int j = 0; j < weight_w; ++j) { const int mask_idx = i * weight_w + j; const int offset_idx = 2 * mask_idx; T mask_value = 1; if (use_mask) { mask_value = mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; } const T offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; const T offset_w = offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; const T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; const T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; *columns_ptr = mask_value * bilinear_interpolate(input_ptr, height, width, y, x); columns_ptr += batch_sz * out_h * out_w; } } } } template __global__ void DeformableCol2imCoord(int n, const T* col_data, const T* im_data, const T* offset_data, const T* mask_data, int channels, int height, int width, int weight_h, int weight_w, int pad_h, int pad_w, int stride_h, int stride_w, int dilation_h, int dilation_w, int batch_sz, int offset_channels, int n_offset_grps, int out_h, int out_w, const bool use_mask, T* grad_offset, T* grad_mask) { CUDA_1D_KERNEL_LOOP(index, n) { T grad_offset_val = 0; T grad_mask_val = 0; int w = index % out_w; int h = (index / out_w) % out_h; int w_w 
= (index / (out_w * out_h * 2)) % weight_w; int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; int c = (index / (out_w * out_h)) % offset_channels; int b = index / (out_w * out_h * offset_channels); const int offset_grp = c / (2 * weight_h * weight_w); const int col_step = weight_h * weight_w; int c_per_offset_grp = channels / n_offset_grps; auto col_ptr = col_data; col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * out_h; auto im_ptr = im_data; im_ptr += (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; auto offset_ptr = offset_data; offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w; auto mask_ptr = mask_data; if (use_mask) { mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * out_h * out_w; } const int offset_c = c - offset_grp * 2 * weight_h * weight_w; const bool is_y_direction = offset_c % 2 == 0; const int c_bound = c_per_offset_grp * weight_h * weight_w; for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w; int out_x = col_pos % out_w; int out_y = (col_pos / out_w) % out_h; int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; const int mask_idx = i * weight_w + j; const int offset_h_ptr = (((2 * mask_idx) * out_h + out_y) * out_w + out_x); const int offset_w_ptr = (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); const T offset_h = offset_ptr[offset_h_ptr]; const T offset_w = offset_ptr[offset_w_ptr]; T mask_value = 1; if (use_mask) { mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; } T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; const T weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); grad_offset_val += mask_value * weight * col_ptr[col_pos]; if (use_mask && is_y_direction) { grad_mask_val += col_ptr[col_pos] * bilinear_interpolate(im_ptr, height, width, y, x); } im_ptr += height * width; } grad_offset[index] = grad_offset_val; if (use_mask && is_y_direction) { const int idx = ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + w_w) * out_h + h) * out_w + w; grad_mask[idx] = grad_mask_val; } } } } // namespace ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewPermutePrimitive(Context* ctx, const int& num_dims) { return ep::primitive::NewPrimitive(ctx->device_type(), num_dims); } template class DeformableConv2dCudaKernel final : public user_op::OpKernel { public: DeformableConv2dCudaKernel() = default; ~DeformableConv2dCudaKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex("offset", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& input_shape = input->shape_view(); const ShapeView& output_shape = output->shape_view(); const int64_t out_elem_cnt = output_shape.elem_cnt(); const int64_t output_bytes = GetCudaAlignedSize(out_elem_cnt * sizeof(T)); T* column_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr() + output_bytes); const int32_t kW = weight->shape_view().At(2); const int32_t kH = weight->shape_view().At(3); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); const int32_t group = ctx->Attr("groups"); const int32_t deformable_group = ctx->Attr("offset_groups"); const bool use_mask = ctx->Attr("use_mask"); const int32_t channel_per_deformable_group = input_shape.At(1) / deformable_group; const int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; const int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; const int64_t column_nums = input_shape.At(1) * input_shape.At(0) * outputWidth * outputHeight; if (column_nums > 0) { DeformableIm2Col<<stream()->As()->cuda_stream()>>>( column_nums, input->dptr(), offset->dptr(), mask->dptr(), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0), input_shape.At(1), deformable_group, output_shape.At(2), output_shape.At(3), use_mask, column_tmp_buffer); const int64_t weight_group_offset = weight->shape_view().elem_cnt() / group; const int64_t column_group_offset = input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group; const int64_t output_group_offset = out_elem_cnt / group; auto matmul = NewMatmulPrimitive(ctx->device_type(), output->data_type(), false, false); CHECK(matmul); FOR_RANGE(int, g, 0, group) { matmul->Launch(ctx->stream(), weight->shape_view().At(0) / group, input_shape.At(0) * outputHeight * outputWidth, input_shape.At(1) * kW * kH / group, static_cast(1), weight->dptr() + g * weight_group_offset, column_tmp_buffer + g * column_group_offset, static_cast(0), tmp_buffer->mut_dptr() + g * output_group_offset); } std::vector out_shapevec( {output_shape.At(1), 
output_shape.At(0), output_shape.At(2), output_shape.At(3)}); auto transpose = NewPermutePrimitive(ctx, output_shape.NumAxes()); CHECK(transpose); transpose->Launch(ctx->stream(), output->data_type(), output_shape.NumAxes(), out_shapevec.data(), tmp_buffer->dptr(), std::vector({1, 0, 2, 3}).data(), output->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class DeformableConv2dInputGradKernel final : public user_op::OpKernel { public: DeformableConv2dInputGradKernel() = default; ~DeformableConv2dInputGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* output_grad = ctx->Tensor4ArgNameAndIndex("output_grad", 0); const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex("offset", 0); user_op::Tensor* input_grad = ctx->Tensor4ArgNameAndIndex("input_grad", 0); user_op::Tensor* offset_grad = ctx->Tensor4ArgNameAndIndex("offset_grad", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& output_grad_shape = output_grad->shape_view(); const ShapeView& input_shape = input->shape_view(); const ShapeView& weight_shape = weight->shape_view(); const int32_t kW = weight->shape_view().At(2); const int32_t kH = weight->shape_view().At(3); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); const int32_t group = ctx->Attr("groups"); const int32_t deformable_group = ctx->Attr("offset_groups"); const bool use_mask = ctx->Attr("use_mask"); const T* data_mask = nullptr; T* data_mask_grad = nullptr; if (use_mask) { data_mask = ctx->Tensor4ArgNameAndIndex("mask", 0)->dptr(); data_mask_grad = ctx->Tensor4ArgNameAndIndex("mask_grad", 0)->mut_dptr(); } const int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; const int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type()); primitive->Launch(ctx->stream(), input_grad->mut_dptr(), 0, input_grad->shape_view().elem_cnt() * sizeof(T)); const int64_t nthreads_coord = outputHeight * outputWidth * 2 * deformable_group * input_shape.At(0) * kW * kH; const int64_t nthreads_feat = outputHeight * outputWidth * input_shape.At(0) * input_shape.At(1) * kW * kH; if (nthreads_coord > 0 && nthreads_feat > 0) { const int64_t weight_group_offset = weight_shape.elem_cnt() / group; const int64_t output_grad_group_offset = output_grad_shape.Count(1) / group; const int64_t column_group_offset = input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group; auto matmul = NewMatmulPrimitive(ctx->device_type(), input_grad->data_type(), true, true); CHECK(matmul); FOR_RANGE(int, g, 0, group) { matmul->Launch(ctx->stream(), weight_shape.Count(1), input_shape.At(0) * outputHeight * outputWidth, weight_shape.At(0) / group, static_cast(1), weight->dptr() + g * weight_group_offset, output_grad->dptr() + g * output_grad_group_offset, static_cast(0), tmp_buffer->mut_dptr() + g * column_group_offset); } DeformableCol2imCoord<<stream()->As()->cuda_stream()>>>( 
nthreads_coord, tmp_buffer->dptr(), input->dptr(), offset->dptr(), data_mask, input_shape.At(1), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0), 2 * kH * kW * deformable_group, deformable_group, outputHeight, outputWidth, use_mask, offset_grad->mut_dptr(), data_mask_grad); DeformableCol2Im<<stream()->As()->cuda_stream()>>>( nthreads_feat, tmp_buffer->dptr(), offset->dptr(), data_mask, input_shape.At(1), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0), deformable_group, outputHeight, outputWidth, use_mask, input_grad->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class DeformableConv2dParamGradKernel final : public user_op::OpKernel { public: DeformableConv2dParamGradKernel() = default; ~DeformableConv2dParamGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* output_grad = ctx->Tensor4ArgNameAndIndex("output_grad", 0); const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex("offset", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* weight_grad = ctx->Tensor4ArgNameAndIndex("weight_grad", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& output_grad_shape = output_grad->shape_view(); const ShapeView& weight_grad_shape = weight_grad->shape_view(); const ShapeView& input_shape = input->shape_view(); const int64_t out_elem_cnt = output_grad_shape.elem_cnt(); const int64_t output_bytes = GetCudaAlignedSize(out_elem_cnt * sizeof(T)); T* column_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr() + output_bytes); const int32_t kW = weight->shape_view().At(2); const int32_t kH = weight->shape_view().At(3); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); const int32_t group = ctx->Attr("groups"); const int32_t deformable_group = ctx->Attr("offset_groups"); const bool use_mask = ctx->Attr("use_mask"); const int32_t channel_per_deformable_group = input_shape.At(1) / deformable_group; const int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; const int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; const T* data_mask = nullptr; if (use_mask) { data_mask = ctx->Tensor4ArgNameAndIndex("mask", 0)->dptr(); } const int64_t column_nums = input_shape.At(1) * input_shape.At(0) * outputHeight * outputWidth; if (column_nums > 0) { DeformableIm2Col<<stream()->As()->cuda_stream()>>>( column_nums, input->dptr(), offset->dptr(), data_mask, input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0), input_shape.At(1), deformable_group, output_grad_shape.At(2), output_grad_shape.At(3), use_mask, column_tmp_buffer); std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type()); primitive->Launch(ctx->stream(), weight_grad->mut_dptr(), 0, weight_grad->shape_view().elem_cnt() * sizeof(T)); std::vector output_grad_buffer_vec({output_grad_shape.At(1), output_grad_shape.At(0), output_grad_shape.At(2), 
output_grad_shape.At(3)}); auto transpose = NewPermutePrimitive(ctx, output_grad_shape.NumAxes()); CHECK(transpose); transpose->Launch(ctx->stream(), output_grad->data_type(), output_grad_shape.NumAxes(), output_grad_buffer_vec.data(), output_grad->dptr(), std::vector({1, 0, 2, 3}).data(), tmp_buffer->mut_dptr()); const int64_t output_grad_group_offset = output_grad_shape.elem_cnt() / group; const int64_t column_group_offset = input_shape.At(1) * kW * kW * input_shape.At(0) * outputHeight * outputWidth / group; const int64_t weight_grad_group_offset = weight_grad->shape_view().elem_cnt() / group; FOR_RANGE(int, g, 0, group) { auto matmul = NewMatmulPrimitive(ctx->device_type(), weight_grad->data_type(), false, true); CHECK(matmul); matmul->Launch(ctx->stream(), weight_grad_shape.At(0) / group, input_shape.At(1) * kW * kH / group, input_shape.At(0) * outputHeight * outputWidth, static_cast(1), tmp_buffer->dptr() + g * output_grad_group_offset, column_tmp_buffer + g * column_group_offset, static_cast(0), weight_grad->mut_dptr() + g * weight_grad_group_offset); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DEFORM_CONV2D_GPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("deform_conv2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& input_shape = ctx->InputShape("input", 0); \ const Shape& output_shape = ctx->OutputShape("output", 0); \ const Shape& weight_shape = ctx->InputShape("weight", 0); \ const int32_t kW = weight_shape.At(2); \ const int32_t kH = weight_shape.At(3); \ const int32_t dW = ctx->Attr("stride_w"); \ const int32_t dH = ctx->Attr("stride_h"); \ const int32_t padW = ctx->Attr("pad_w"); \ const int32_t padH = ctx->Attr("pad_h"); \ const int32_t dilationW = ctx->Attr("dilation_w"); \ const int32_t dilationH = ctx->Attr("dilation_h"); \ const int64_t outputWidth = \ (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; \ const int64_t outputHeight = \ (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; \ const int64_t column_bytes = \ GetCudaAlignedSize(input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight \ * outputWidth * sizeof(dtype)); \ const int64_t output_bytes = GetCudaAlignedSize(output_shape.elem_cnt() * sizeof(dtype)); \ return column_bytes + output_bytes; \ }); REGISTER_DEFORM_CONV2D_GPU_KERNEL(float) REGISTER_DEFORM_CONV2D_GPU_KERNEL(double) #define REGISTER_DEFORM_CONV2D_INPUT_GRAD_GPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("deform_conv2d_input_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("weight", 0) == GetDataType::value) \ && (user_op::HobDataType("offset", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& input_shape = ctx->InputShape("input", 0); \ const Shape& weight_shape = ctx->InputShape("weight", 0); \ const int32_t kW = weight_shape.At(2); \ const int32_t kH = weight_shape.At(3); \ const int32_t dW = ctx->Attr("stride_w"); \ const int32_t dH = ctx->Attr("stride_h"); \ const int32_t padW = ctx->Attr("pad_w"); \ const int32_t padH = ctx->Attr("pad_h"); \ const int32_t dilationW = ctx->Attr("dilation_w"); \ const int32_t dilationH = ctx->Attr("dilation_h"); \ const int64_t outputWidth = \ (input_shape.At(3) + 2 * padW - 
(dilationW * (kW - 1) + 1)) / dW + 1; \ const int64_t outputHeight = \ (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; \ const int64_t column_bytes = \ GetCudaAlignedSize(input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight \ * outputWidth * sizeof(dtype)); \ return column_bytes; \ }); REGISTER_DEFORM_CONV2D_INPUT_GRAD_GPU_KERNEL(float) REGISTER_DEFORM_CONV2D_INPUT_GRAD_GPU_KERNEL(double) #define REGISTER_DEFORM_CONV2D_PARAM_GRAD_GPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("deform_conv2d_param_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("offset", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& input_shape = ctx->InputShape("input", 0); \ const Shape& output_grad_shape = ctx->InputShape("output_grad", 0); \ const Shape& weight_shape = ctx->InputShape("weight", 0); \ const int32_t kW = weight_shape.At(2); \ const int32_t kH = weight_shape.At(3); \ const int32_t dW = ctx->Attr("stride_w"); \ const int32_t dH = ctx->Attr("stride_h"); \ const int32_t padW = ctx->Attr("pad_w"); \ const int32_t padH = ctx->Attr("pad_h"); \ const int32_t dilationW = ctx->Attr("dilation_w"); \ const int32_t dilationH = ctx->Attr("dilation_h"); \ const int64_t outputWidth = \ (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; \ const int64_t outputHeight = \ (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; \ const int64_t column_bytes = \ GetCudaAlignedSize(input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight \ * outputWidth * sizeof(dtype)); \ const int64_t output_bytes = \ GetCudaAlignedSize(output_grad_shape.elem_cnt() * sizeof(dtype)); \ return column_bytes + output_bytes; \ }); REGISTER_DEFORM_CONV2D_PARAM_GRAD_GPU_KERNEL(float) REGISTER_DEFORM_CONV2D_PARAM_GRAD_GPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/det_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/eigen_util.h" namespace oneflow { namespace { static inline size_t BatchCount(const user_op::Tensor* batched_matrices) { size_t result = 1; for (size_t i = 0; i < batched_matrices->shape_view().NumAxes() - 2; i++) { result *= batched_matrices->shape_view().At(i); } return result; } static inline size_t MatrixStride(const user_op::Tensor* batched_matrices) { const int64_t num_axes = batched_matrices->shape_view().NumAxes(); return batched_matrices->shape_view().At(num_axes - 2) * batched_matrices->shape_view().At(num_axes - 1); } } // namespace template class DetKernel final : public user_op::OpKernel { public: DetKernel() = default; ~DetKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); auto batch_count = BatchCount(x); auto matrix_stride = MatrixStride(x); auto matrix_size = x->shape_view().At(x->shape_view().NumAxes() - 2); const T* x_ptr = x->dptr(); T* y_ptr = y->mut_dptr(); FOR_RANGE(int64_t, i, 0, batch_count) { ConstEigenMatrixMap x_mat(x_ptr + i * matrix_stride, matrix_size, matrix_size); if (x_mat.determinant() == 0) { LOG(FATAL) << "(Batch element " << i << "): the inversion could not be completed because the input matrix is singular."; } T y = x_mat.determinant(); *(y_ptr + i) = y; }; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DET_KERNEL(dtype) \ REGISTER_USER_KERNEL("det").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_DET_KERNEL(float) REGISTER_DET_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/diag_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/diag_kernel.h" namespace oneflow { namespace { template struct DiagFunctor final { void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t stride, int32_t in_dim) { if (in_dim == 1) { FOR_RANGE(int32_t, i, 0, size) { out_buf[i * stride] = in_buf[i]; } } else { FOR_RANGE(int32_t, i, 0, size) { out_buf[i] = in_buf[i * stride]; } } } }; template struct DiagGradFunctor final { void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t dx_cnt, int32_t dy_cnt, int32_t stride, int32_t in_dim) { if (in_dim == 1) { FOR_RANGE(int32_t, i, 0, dx_cnt) { dx_buf[i] = dy_buf[i * stride]; } } else { FOR_RANGE(int32_t, i, 0, dy_cnt) { dx_buf[i * stride] = dy_buf[i]; } } } }; } // namespace REGISTER_DIAG_KERNELS(DeviceType::kCPU, float); REGISTER_DIAG_KERNELS(DeviceType::kCPU, double); REGISTER_DIAG_KERNELS(DeviceType::kCPU, bool); REGISTER_DIAG_KERNELS(DeviceType::kCPU, uint8_t); REGISTER_DIAG_KERNELS(DeviceType::kCPU, int8_t); REGISTER_DIAG_KERNELS(DeviceType::kCPU, int32_t); REGISTER_DIAG_KERNELS(DeviceType::kCPU, int64_t); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/diag_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/diag_kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void vector_diagonal_kernel(T* out_buf, const T* in_buf, int32_t size, int32_t stride) { CUDA_1D_KERNEL_LOOP(i, size) { out_buf[i * stride] = in_buf[i]; } } template __global__ void matrix_diagonal_kernel(T* out_buf, const T* in_buf, int32_t size, int32_t stride) { CUDA_1D_KERNEL_LOOP(i, size) { out_buf[i] = in_buf[i * stride]; } } template struct DiagFunctor final { void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t stride, int32_t in_dim) { if (in_dim == 1) { vector_diagonal_kernel<<As()->cuda_stream()>>>(out_buf, in_buf, size, stride); } else { matrix_diagonal_kernel<<As()->cuda_stream()>>>(out_buf, in_buf, size, stride); } } }; template struct DiagGradFunctor final { void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t dx_cnt, int32_t dy_cnt, int32_t stride, int32_t in_dim) { if (in_dim == 1) { matrix_diagonal_kernel<<As()->cuda_stream()>>>(dx_buf, dy_buf, dx_cnt, stride); } else { vector_diagonal_kernel<<As()->cuda_stream()>>>(dx_buf, dy_buf, dy_cnt, stride); } } }; } // namespace REGISTER_DIAG_KERNELS(DeviceType::kCUDA, half); REGISTER_DIAG_KERNELS(DeviceType::kCUDA, float); REGISTER_DIAG_KERNELS(DeviceType::kCUDA, double); REGISTER_DIAG_KERNELS(DeviceType::kCUDA, bool); REGISTER_DIAG_KERNELS(DeviceType::kCUDA, uint8_t); REGISTER_DIAG_KERNELS(DeviceType::kCUDA, int8_t); REGISTER_DIAG_KERNELS(DeviceType::kCUDA, int32_t); REGISTER_DIAG_KERNELS(DeviceType::kCUDA, int64_t); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/diag_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef _ONEFLOW_USER_KERNELS_DIAG_KERNEL_H_ #define _ONEFLOW_USER_KERNELS_DIAG_KERNEL_H_ #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { template struct DiagFunctor final { void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t stride, int32_t in_dim); }; template struct DiagGradFunctor final { void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t dx_cnt, int32_t dy_cnt, int32_t stride, int32_t in_dim); }; } // namespace template class DiagKernel final : public user_op::OpKernel { public: DiagKernel() = default; ~DiagKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int32_t diagonal = ctx->Attr("diagonal"); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& out_shape = out->shape_view(); const ShapeView& in_shape = in->shape_view(); int32_t in_dim = in_shape.NumAxes(); const T* in_buf = in->dptr(); T* out_buf = out->mut_dptr(); Memset(ctx->stream(), out->mut_dptr(), 0, out_shape.elem_cnt() * sizeof(T)); if (in_dim == 1) { int32_t size = in_shape.elem_cnt(); out_buf += (diagonal >= 0 ? diagonal : -diagonal * out_shape.At(1)); DiagFunctor()(ctx->stream(), out_buf, in_buf, size, out_shape.At(1) + 1, in_dim); } else { int32_t size = 0; in_buf += (diagonal >= 0 ? diagonal : -diagonal * in_shape.At(1)); if (diagonal >= 0) { size = std::min(in_shape.At(0), in_shape.At(1) - diagonal); } else { size = std::min(in_shape.At(0) + diagonal, in_shape.At(1)); } DiagFunctor()(ctx->stream(), out_buf, in_buf, size, in_shape.At(1) + 1, in_dim); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class DiagBackwardKernel final : public user_op::OpKernel { public: DiagBackwardKernel() = default; ~DiagBackwardKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); int32_t diagonal = ctx->Attr("diagonal"); const ShapeView& dx_shape = dx->shape_view(); const ShapeView& dy_shape = dy->shape_view(); int32_t in_dim = dx_shape.NumAxes(); int32_t dy_cnt = dy_shape.Count(0); int32_t dx_cnt = dx_shape.Count(0); T* dx_buf = dx->mut_dptr(); const T* dy_buf = dy->dptr(); Memset(ctx->stream(), dx->mut_dptr(), 0, dx_shape.elem_cnt() * sizeof(T)); if (in_dim == 1) { dy_buf += (diagonal >= 0 ? diagonal : -diagonal * dy_shape.At(1)); DiagGradFunctor()(ctx->stream(), dx_buf, dy_buf, dx_cnt, dy_cnt, dy_shape.At(1) + 1, in_dim); } else { dx_buf += (diagonal >= 0 ? 
diagonal : -diagonal * dx_shape.At(1)); DiagGradFunctor()(ctx->stream(), dx_buf, dy_buf, dx_cnt, dy_cnt, dx_shape.At(1) + 1, in_dim); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DIAG_KERNELS(device, dtype) \ REGISTER_USER_KERNEL("diag").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("diag_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); } // namespace oneflow #endif // _ONEFLOW_USER_KERNELS_DIAG_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/diagonal_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template struct DiagonalFunctor final { void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t dim1, int32_t dim2) { int32_t offset_index = (dim1 + 1) * dim2; FOR_RANGE(int32_t, index, 0, size * dim2) { int32_t i = index / dim2; int32_t j = index - i * dim2; out_buf[j * size + i] = in_buf[i * offset_index + j]; } } }; template struct DiagonalGradFunctor final { void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1, int32_t dim2) { int32_t offset_index = (dim1 + 1) * dim2; FOR_RANGE(int32_t, index, 0, size * dim2) { int32_t i = index / dim2; int32_t j = index - i * dim2; dx_buf[i * offset_index + j] = dy_buf[j * size + i]; } } }; } // namespace template class CpuDiagonalKernel final : public user_op::OpKernel { public: CpuDiagonalKernel() = default; ~CpuDiagonalKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int32_t offset = ctx->Attr("offset"); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& out_shape = out->shape_view(); const ShapeView& in_shape = in->shape_view(); const T* in_buf = in->dptr(); T* out_buf = out->mut_dptr(); int32_t size = out_shape.At(out_shape.NumAxes() - 1); int32_t dim1 = in_shape.At(1); int32_t dim2 = 0; if (in_shape.NumAxes() <= 2) { dim2 = 1; } else { dim2 = in_shape.Count(2, in_shape.NumAxes()); } int32_t offset_in_bufer = (offset >= 0 ? 
offset * dim2 : -offset * dim1 * dim2); in_buf += offset_in_bufer; DiagonalFunctor()(ctx->stream(), out_buf, in_buf, size, dim1, dim2); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class CpuDiagonalBackwardKernel final : public user_op::OpKernel { public: CpuDiagonalBackwardKernel() = default; ~CpuDiagonalBackwardKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); int32_t offset = ctx->Attr("offset"); const ShapeView& dx_shape = dx->shape_view(); const ShapeView& dy_shape = dy->shape_view(); T* dx_buf = dx->mut_dptr(); const T* dy_buf = dy->dptr(); Memset(ctx->stream(), dx->mut_dptr(), 0, dx_shape.elem_cnt() * sizeof(T)); int32_t dim1 = dx_shape.At(1); int32_t dim2 = 0; if (dx_shape.NumAxes() <= 2) { dim2 = 1; } else { dim2 = dx_shape.Count(2, dx_shape.NumAxes()); } int32_t size = dy_shape.At(dy_shape.NumAxes() - 1); int32_t offset_in_bufer = (offset >= 0 ? offset * dim2 : -offset * dim1 * dim2); dx_buf += offset_in_bufer; DiagonalGradFunctor()(ctx->stream(), dx_buf, dy_buf, size, dim1, dim2); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DIAGONAL_KERNELS(dtype) \ REGISTER_USER_KERNEL("diagonal") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("diagonal_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_DIAGONAL_KERNELS(bool); REGISTER_DIAGONAL_KERNELS(float); REGISTER_DIAGONAL_KERNELS(double); REGISTER_DIAGONAL_KERNELS(int8_t); REGISTER_DIAGONAL_KERNELS(int32_t); REGISTER_DIAGONAL_KERNELS(int64_t); #undef REGISTER_DIAGONAL_KERNELS } // namespace oneflow ================================================ FILE: oneflow/user/kernels/diagonal_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void forward_diagonal_kernel(T* out_buf, const T* in_buf, int32_t size, int32_t dim1, int32_t dim2) { int32_t offset_index = (dim1 + 1) * dim2; CUDA_1D_KERNEL_LOOP(index, size * dim2) { int32_t i = index / dim2; int32_t j = index - i * dim2; out_buf[j * size + i] = in_buf[i * offset_index + j]; } } template __global__ void backward_diagonal_kernel(T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1, int32_t dim2) { int32_t offset_index = (dim1 + 1) * dim2; CUDA_1D_KERNEL_LOOP(index, size * dim2) { int32_t i = index / dim2; int32_t j = index - i * dim2; dx_buf[i * offset_index + j] = dy_buf[j * size + i]; } } template struct DiagonalFunctor final { void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t dim1, int32_t dim2) { if (size * dim2 > 0) { forward_diagonal_kernel <<As()->cuda_stream()>>>(out_buf, in_buf, size, dim1, dim2); } } }; template struct DiagonalGradFunctor final { void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1, int32_t dim2) { if (size * dim2 > 0) { backward_diagonal_kernel <<As()->cuda_stream()>>>(dx_buf, dy_buf, size, dim1, dim2); } } }; } // namespace template class GpuDiagonalKernel final : public user_op::OpKernel { public: GpuDiagonalKernel() = default; ~GpuDiagonalKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int32_t offset = ctx->Attr("offset"); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& out_shape = out->shape_view(); const ShapeView& in_shape = in->shape_view(); const T* in_buf = in->dptr(); T* out_buf = out->mut_dptr(); int32_t size = out_shape.At(out_shape.NumAxes() - 1); int32_t dim1 = in_shape.At(1); int32_t dim2 = 0; if (in_shape.NumAxes() <= 2) { dim2 = 1; } else { dim2 = in_shape.Count(2, in_shape.NumAxes()); } int32_t offset_in_bufer = (offset >= 0 ? offset * dim2 : -offset * dim1 * dim2); in_buf += offset_in_bufer; DiagonalFunctor()(ctx->stream(), out_buf, in_buf, size, dim1, dim2); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuDiagonalBackwardKernel final : public user_op::OpKernel { public: GpuDiagonalBackwardKernel() = default; ~GpuDiagonalBackwardKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); int32_t offset = ctx->Attr("offset"); const ShapeView& dx_shape = dx->shape_view(); const ShapeView& dy_shape = dy->shape_view(); T* dx_buf = dx->mut_dptr(); const T* dy_buf = dy->dptr(); Memset(ctx->stream(), dx->mut_dptr(), 0, dx_shape.elem_cnt() * sizeof(T)); int32_t dim1 = dx_shape.At(1); int32_t dim2 = 0; if (dx_shape.NumAxes() <= 2) { dim2 = 1; } else { dim2 = dx_shape.Count(2, dx_shape.NumAxes()); } int32_t size = dy_shape.At(dy_shape.NumAxes() - 1); int32_t offset_in_bufer = (offset >= 0 ? 
offset * dim2 : -offset * dim1 * dim2); dx_buf += offset_in_bufer; DiagonalGradFunctor()(ctx->stream(), dx_buf, dy_buf, size, dim1, dim2); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DIAGONAL_KERNELS(dtype) \ REGISTER_USER_KERNEL("diagonal") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("diagonal_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_DIAGONAL_KERNELS(bool); REGISTER_DIAGONAL_KERNELS(half); REGISTER_DIAGONAL_KERNELS(float); REGISTER_DIAGONAL_KERNELS(double); REGISTER_DIAGONAL_KERNELS(int8_t); REGISTER_DIAGONAL_KERNELS(int32_t); REGISTER_DIAGONAL_KERNELS(int64_t); #undef REGISTER_DIAGONAL_KERNELS } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dim_gather_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/dim_gather_kernel_util.h" namespace oneflow { namespace user_op { template struct DimGatherFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& input_nd_helper, const DimOpIndexNdHelper& index_nd_helper, int ndim, int64_t elem_cnt, int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) { DoDimGather(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index, input, output); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kCPU), DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ, INDEX_DATA_TYPE_SEQ); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dim_gather_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/dim_gather_kernel_util.h" namespace oneflow { namespace user_op { template __global__ void DoCUDADimGather(const DimOpIndexNdHelper input_nd_helper, const DimOpIndexNdHelper index_nd_helper, int ndim, int64_t elem_cnt, int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) { DoDimGather(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index, input, output); } template struct DimGatherFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& input_nd_helper, const DimOpIndexNdHelper& index_nd_helper, int ndim, int64_t elem_cnt, int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) { RUN_CUDA_KERNEL((DoCUDADimGather), stream, BlocksNum4ThreadsNum(elem_cnt), input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index, input, output); } }; // float16 special case of DimGatherFunctor template template struct DimGatherFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& input_nd_helper, const DimOpIndexNdHelper& index_nd_helper, int ndim, int64_t elem_cnt, int32_t dim_length, int32_t dim, const IDX_T* index, const float16* input, float16* output) { RUN_CUDA_KERNEL((DoCUDADimGather), stream, BlocksNum4ThreadsNum(elem_cnt), input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index, reinterpret_cast(input), reinterpret_cast(output)); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kCUDA), DIM_GATHER_SCATTER_DATA_TYPE_CUDA_SEQ, INDEX_DATA_TYPE_SEQ); } // namespace user_op } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/dim_gather_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_DIM_GATHER_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_DIM_GATHER_KERNEL_UTIL_H_ #ifdef WITH_CUDA #include "oneflow/core/cuda/atomic.cuh" #endif // WITH_CUDA #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { #define DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ \ ARITHMETIC_DATA_TYPE_SEQ \ UNSIGNED_INT_DATA_TYPE_SEQ \ BOOL_DATA_TYPE_SEQ #define DIM_GATHER_SCATTER_DATA_TYPE_CUDA_SEQ \ DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ \ FLOAT16_DATA_TYPE_SEQ constexpr int kDimGatherMaxDimCount = 8; template using DimOpIndexNdHelper = NdIndexOffsetHelper; namespace user_op { template struct DimGatherFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& input_nd_helper, const DimOpIndexNdHelper& index_nd_helper, int ndim, int64_t elem_cnt, int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output); }; template OF_DEVICE_FUNC void DoDimGather(const DimOpIndexNdHelper& input_nd_helper, const DimOpIndexNdHelper& index_nd_helper, int ndim, int64_t elem_cnt, int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) { XPU_1D_KERNEL_LOOP(index_offset, elem_cnt) { IDX_T coordinate[kDimGatherMaxDimCount] = {0}; const IDX_T x = index[index_offset]; #ifdef __CUDA_ARCH__ assert(x < dim_length && "gather index is out of bounds"); #else CHECK_LE(x, dim_length) << "RuntimeError: index " << x << " is out of bounds for dimension " << dim << " with size " << dim_length; #endif index_nd_helper.OffsetToNdIndex(index_offset, coordinate, ndim); coordinate[dim] = x; IDX_T input_offset = input_nd_helper.NdIndexToOffset(coordinate, ndim); output[index_offset] = input[input_offset]; } } template struct DeviceAdd { OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { #ifdef __CUDA_ARCH__ cuda::atomic::Add(y, *x); // TODO:(YaoChi), refine add using float16 -> half -> float -> half #else *y += *x; #endif }; }; // macros for functors instantiate(used by dim_gather_kernel_util.cu and dim_gather_kernel_uti.cpp) #define INSTANTIATE_DIM_GATHER_FUNCTOR(device_type_v, dtype_pair, itype_pair) \ template struct DimGatherFunctor; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DIM_GATHER_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/dim_gather_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/dim_gather_kernel_util.h" namespace oneflow { namespace user_op { namespace { template void ConvertShape2Array(const ShapeView& shape_view, IDX_T* array, int64_t num_axis) { FOR_RANGE(int64_t, i, 0, num_axis) { array[i] = shape_view.At(i); } } } // namespace template class DimGatherKernel final : public user_op::OpKernel { public: DimGatherKernel() = default; ~DimGatherKernel() override = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0); if (input_tensor->shape_view().elem_cnt() == 0) { return; } const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0); Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0); const int32_t dim = ctx->Attr("dim"); const IN_T* input = input_tensor->dptr(); const IDX_T* index = index_tensor->dptr(); IN_T* output = out_tensor->mut_dptr(); const Shape in_shape = ExpandDimIf0D(input_tensor->shape_view()); const auto ndim = in_shape.NumAxes(); const auto dim_length = in_shape.At(dim); DimOpIndexNdHelper input_nd_helper(in_shape.data(), ndim); DimOpIndexNdHelper index_nd_helper(index_tensor->shape_view().data(), ndim); DimGatherFunctor()(ctx->stream(), input_nd_helper, index_nd_helper, ndim, index_tensor->shape_view().elem_cnt(), dim_length, dim, index, input, output); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DIM_GATHER_KERNEL(device, dtype_pair, itype_pair) \ REGISTER_USER_KERNEL("dim_gather") \ .SetCreateFn< \ DimGatherKernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("index", 0) == OF_PP_PAIR_SECOND(itype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( REGISTER_DIM_GATHER_KERNEL, (DeviceType::kCPU), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DIM_GATHER_KERNEL, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif // WITH_CUDA } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dim_scatter_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/dim_scatter_kernel_util.h" namespace oneflow { namespace user_op { template class Opt> struct DimScatterFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& src_nd_helper, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, const IDX_T* index, const IN_T* src, IN_T* output) { DoDimScatter(src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src, output); } }; INSTANTIATE_DIM_SCATTER_CPU_FUNCTORS(DeviceType::kCPU, BinOpAddFunctor); INSTANTIATE_DIM_SCATTER_CPU_FUNCTORS(DeviceType::kCPU, BinOpMulFunctor); INSTANTIATE_DIM_SCATTER_CPU_FUNCTORS(DeviceType::kCPU, BinOpUpdateFunctor); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dim_scatter_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/user/kernels/dim_scatter_kernel_util.h" namespace oneflow { namespace user_op { template class Opt> __global__ void DoCUDADimScatter(const DimOpIndexNdHelper src_nd_helper, const DimOpIndexNdHelper idx_nd_helper, const DimOpIndexNdHelper output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, const IDX_T* index, const IN_T* src, IN_T* output) { DoDimScatter(src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src, output); } template class Opt> struct DimScatterFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& src_nd_helper, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, const IDX_T* index, const IN_T* src, IN_T* output) { RUN_CUDA_KERNEL((DoCUDADimScatter), stream, BlocksNum4ThreadsNum(elem_cnt), src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src, output); } }; template class Opt> struct DimScatterFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& src_nd_helper, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, const IDX_T* index, const float16* src, float16* output) { RUN_CUDA_KERNEL((DoCUDADimScatter), stream, BlocksNum4ThreadsNum(elem_cnt), src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, reinterpret_cast(src), reinterpret_cast(output)); } }; INSTANTIATE_DIM_SCATTER_CUDA_FUNCTORS(DeviceType::kCUDA, BinOpAddFunctor); INSTANTIATE_DIM_SCATTER_CUDA_FUNCTORS(DeviceType::kCUDA, BinOpMulFunctor); INSTANTIATE_DIM_SCATTER_CUDA_FUNCTORS(DeviceType::kCUDA, BinOpUpdateFunctor); } // namespace user_op } // namespace 
oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/dim_scatter_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_ #ifdef WITH_CUDA #include "oneflow/core/cuda/atomic.cuh" #include #endif // WITH_CUDA #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/error.pb.h" namespace oneflow { #define NO_HALF_UTIL_FOUND \ printf("cuda arch must >= 530"); \ assert(false) namespace user_op { constexpr int kDimGatherMaxDimCount = 8; template using DimOpIndexNdHelper = NdIndexOffsetHelper; #define INSTANTIATE_DIM_SCATTER_CPU_FUNCTORS(device_type, opt) \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; #define INSTANTIATE_DIM_SCATTER_CUDA_FUNCTORS(device_type, opt) \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; \ template struct DimScatterFunctor; template struct BinOpAddFunctor { OF_DEVICE_FUNC static void apply(const T* x, T* y) { #ifdef __CUDA_ARCH__ cuda::atomic::Add(y, *x); #else *y += *x; #endif } }; #ifdef WITH_CUDA template<> struct BinOpAddFunctor { OF_DEVICE_FUNC static void apply(const half* x, half* y) { #ifdef __CUDA_ARCH__ *y = __float2half(__half2float(*x) + __half2float(*y)); #else NO_HALF_UTIL_FOUND; #endif } }; #endif #define SPECIALIZE_BIN_OP_ADD_FUNCTOR(name, dtype) \ template<> \ struct name { \ OF_DEVICE_FUNC static void apply(const dtype* x, dtype* y) { *y += *x; } \ }; SPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpAddFunctor, bool) SPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpAddFunctor, int8_t) SPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpAddFunctor, uint8_t) SPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpAddFunctor, int64_t) template struct 
BinOpMulFunctor { OF_DEVICE_FUNC static void apply(const T* x, T* y) { #ifdef __CUDA_ARCH__ cuda::atomic::Mul(y, *x); #else *y *= *x; #endif } }; #ifdef WITH_CUDA template<> struct BinOpMulFunctor { OF_DEVICE_FUNC static void apply(const half* x, half* y) { #ifdef __CUDA_ARCH__ *y = __float2half(__half2float(*x) * __half2float(*y)); #else NO_HALF_UTIL_FOUND; #endif } }; #endif #define SPECIALIZE_BIN_OP_MUL_FUNCTOR(name, dtype) \ template<> \ struct name { \ OF_DEVICE_FUNC static void apply(const dtype* x, dtype* y) { *y *= *x; } \ }; SPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpMulFunctor, int8_t) SPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpMulFunctor, uint8_t) SPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpMulFunctor, int64_t) template<> struct BinOpMulFunctor { OF_DEVICE_FUNC static void apply(const bool* x, bool* y) { *y &= *x; } }; template struct BinOpUpdateFunctor { OF_DEVICE_FUNC static void apply(const T* x, T* y) { *y = *x; } }; template class Opt> struct DimScatterFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& src_nd_helper, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, const IDX_T* index, const IN_T* src, IN_T* output); }; template class Opt> OF_DEVICE_FUNC void DoDimScatter(const DimOpIndexNdHelper& src_nd_helper, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, const IDX_T* index, const IN_T* src, IN_T* output) { XPU_1D_KERNEL_LOOP(idx_offset, elem_cnt) { IDX_T coordinate[kDimGatherMaxDimCount] = {0}; idx_nd_helper.OffsetToNdIndex(idx_offset, coordinate, ndim); // idx_offset -> ijk IDX_T idx_elem = index[idx_offset]; if (upper_bound != 0 && idx_elem >= upper_bound) { #if __CUDA_ARCH__ __trap(); #else UNIMPLEMENTED() << "The index element " << idx_elem << " is out of bounds for dimension " << dim << " with size " << upper_bound << "."; #endif } IDX_T src_offset = src_nd_helper.NdIndexToOffset(coordinate, ndim); coordinate[dim] = idx_elem; IDX_T output_offset = output_nd_helper.NdIndexToOffset(coordinate, ndim); Opt::apply(src + src_offset, output + output_offset); } } } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/dim_scatter_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/util.h" #include "oneflow/user/kernels/dim_scatter_kernel_util.h" namespace oneflow { namespace user_op { template class Opt> class DimScatterKernel final : public user_op::OpKernel { public: DimScatterKernel() = default; ~DimScatterKernel() override = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0); const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0); Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0); const Tensor* src_tensor = ctx->Tensor4ArgNameAndIndex("src", 0); const int32_t dim = ctx->Attr("dim"); const IDX_T* index = index_tensor->dptr(); IN_T* output = out_tensor->mut_dptr(); size_t out_bytes_size = out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(out_tensor->data_type()); Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex("like", 0); const IN_T* src = src_tensor->dptr(); if (input_tensor) { Memcpy(ctx->stream(), output, input_tensor->dptr(), out_bytes_size); } else if (like_tensor) { Memset(ctx->stream(), output, 0, out_bytes_size); } else { UNIMPLEMENTED() << "Input tensor and like tensor cannot be empty simultaneously."; } const Shape src_shape = ExpandDimIf0D(src_tensor->shape_view()); const Shape index_shape = ExpandDimIf0D(index_tensor->shape_view()); const int ndim = src_shape.NumAxes(); DimOpIndexNdHelper src_nd_helper(src_shape.data(), ndim); DimOpIndexNdHelper idx_nd_helper(index_shape.data(), ndim); DimOpIndexNdHelper output_nd_helper(out_tensor->shape_view().data(), ndim); const int64_t upper_bound = [&]() { if (input_tensor) { const Shape input_shape = ExpandDimIf0D(input_tensor->shape_view()); return input_shape.At(dim); } else { const Shape like_shape = ExpandDimIf0D(like_tensor->shape_view()); return like_shape.At(dim); } }(); DimScatterFunctor()( ctx->stream(), src_nd_helper, idx_nd_helper, output_nd_helper, ndim, index_shape.elem_cnt(), dim, upper_bound, index, src, output); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, device, dtype, itype, opt) \ REGISTER_USER_KERNEL(op_type) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("like", 0) == GetDataType::value) \ && (user_op::HobDataType("index", 0) == GetDataType::value)); #define REGISTER_DIM_SCATTER_LIKE_CPU_KERNELS(op_type, opt) \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, bool, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, double, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float16, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, int32_t, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, bool, int64_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float, int64_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, double, int64_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float16, int64_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, int32_t, int64_t, opt); #define REGISTER_DIM_SCATTER_LIKE_CUDA_KERNELS(op_type, opt) \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, bool, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, float, 
int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, double, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, half, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, int32_t, int32_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, bool, int64_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, float, int64_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, double, int64_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, half, int64_t, opt); \ REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, int32_t, int64_t, opt); #define REGISTER_DIM_SCATTER_KERNEL(op_type, device, dtype_pair, itype_pair, opt) \ REGISTER_USER_KERNEL(#op_type) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("index", 0) == OF_PP_PAIR_SECOND(itype_pair))); #define REGISTER_DIM_SCATTER_CPU_KERNELS(dtype_pair, itype_pair) \ REGISTER_DIM_SCATTER_KERNEL(dim_scatter_add, DeviceType::kCPU, dtype_pair, itype_pair, \ BinOpAddFunctor); \ REGISTER_DIM_SCATTER_KERNEL(dim_scatter_mul, DeviceType::kCPU, dtype_pair, itype_pair, \ BinOpMulFunctor); \ REGISTER_DIM_SCATTER_KERNEL(dim_scatter_update, DeviceType::kCPU, dtype_pair, itype_pair, \ BinOpUpdateFunctor); #define REGISTER_DIM_SCATTER_CUDA_KERNELS(dtype_pair, itype_pair) \ REGISTER_DIM_SCATTER_KERNEL(dim_scatter_add, DeviceType::kCUDA, dtype_pair, itype_pair, \ BinOpAddFunctor); \ REGISTER_DIM_SCATTER_KERNEL(dim_scatter_mul, DeviceType::kCUDA, dtype_pair, itype_pair, \ BinOpMulFunctor); \ REGISTER_DIM_SCATTER_KERNEL(dim_scatter_update, DeviceType::kCUDA, dtype_pair, itype_pair, \ BinOpUpdateFunctor); REGISTER_DIM_SCATTER_LIKE_CPU_KERNELS("dim_scatter_add_like", BinOpAddFunctor); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DIM_SCATTER_CPU_KERNELS, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA REGISTER_DIM_SCATTER_LIKE_CUDA_KERNELS("dim_scatter_add_like", BinOpAddFunctor); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DIM_SCATTER_CUDA_KERNELS, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif // WITH_CUDA } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dim_scatter_scalar_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h" namespace oneflow { namespace user_op { template class Opt> struct DimScatterScalarFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, const IDX_T* index, const IN_T src, IN_T* output) { DoScatterScalarFunctor(idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src, output); } }; INSTANTIATE_DIM_SCATTER_SCARLAR_CPU_FUNCTORS(DeviceType::kCPU, UpdateScalarFunctor); INSTANTIATE_DIM_SCATTER_SCARLAR_CPU_FUNCTORS(DeviceType::kCPU, AddScalarFunctor); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dim_scatter_scalar_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h" namespace oneflow { namespace user_op { template class Opt> __global__ void DoCUDADimScatterScalar(const DimOpIndexNdHelper idx_nd_helper, const DimOpIndexNdHelper output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, const IDX_T* index, const IN_T src_scalar, IN_T* output) { DoScatterScalarFunctor(idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src_scalar, output); } template class Opt> struct DimScatterScalarFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, const IDX_T* index, const IN_T src, IN_T* output) { RUN_CUDA_KERNEL((DoCUDADimScatterScalar), stream, BlocksNum4ThreadsNum(elem_cnt), idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src, output); } }; template class Opt> struct DimScatterScalarFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, const IDX_T* index, const float16 src, float16* output) { RUN_CUDA_KERNEL((DoCUDADimScatterScalar), stream, BlocksNum4ThreadsNum(elem_cnt), idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src, reinterpret_cast(output)); } }; INSTANTIATE_DIM_SCATTER_SCARLAR_CUDA_FUNCTORS(DeviceType::kCUDA, UpdateScalarFunctor); INSTANTIATE_DIM_SCATTER_SCARLAR_CUDA_FUNCTORS(DeviceType::kCUDA, AddScalarFunctor); } // namespace user_op } // namespace oneflow #endif ================================================ FILE: oneflow/user/kernels/dim_scatter_scalar_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_ #ifdef WITH_CUDA #include "oneflow/core/cuda/atomic.cuh" #include #endif // WITH_CUDA #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/data_type.h" namespace oneflow { #define NO_HALF_UTIL_FOUND \ printf("cuda arch must >= 530"); \ assert(false) namespace user_op { constexpr int kDimGatherMaxDimCount = 8; template struct AddScalarFunctor { OF_DEVICE_FUNC static void apply(const T x, T* y) { #ifdef __CUDA_ARCH__ cuda::atomic::Add(y, x); #else *y += x; #endif } }; #ifdef WITH_CUDA template<> struct AddScalarFunctor { OF_DEVICE_FUNC static void apply(const half x, half* y) { #if __CUDA_ARCH__ *y = __float2half(__half2float(*y) + __half2float(x)); #else NO_HALF_UTIL_FOUND; #endif } }; #endif template<> struct AddScalarFunctor { OF_DEVICE_FUNC static void apply(const int8_t x, int8_t* y) { *y += x; } }; template<> struct AddScalarFunctor { OF_DEVICE_FUNC static void apply(const uint8_t x, uint8_t* y) { *y += x; } }; template<> struct AddScalarFunctor { OF_DEVICE_FUNC static void apply(const int64_t x, int64_t* y) { *y += x; } }; template struct UpdateScalarFunctor { OF_DEVICE_FUNC static void apply(const T x, T* y) { *y = x; } }; #define INSTANTIATE_DIM_SCATTER_SCARLAR_CPU_FUNCTORS(device_type, opt) \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; #define INSTANTIATE_DIM_SCATTER_SCARLAR_CUDA_FUNCTORS(device_type, opt) \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; \ template struct DimScatterScalarFunctor; template using DimOpIndexNdHelper = NdIndexOffsetHelper; template class Opt> struct DimScatterScalarFunctor final { void operator()(ep::Stream* stream, const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int 
ndim, const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, const IDX_T* index, const IN_T src, IN_T* output); }; template class Opt> OF_DEVICE_FUNC void DoScatterScalarFunctor(const DimOpIndexNdHelper& idx_nd_helper, const DimOpIndexNdHelper& output_nd_helper, const int ndim, const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, const IDX_T* index, const IN_T src, IN_T* output) { XPU_1D_KERNEL_LOOP(idx_offset, elem_cnt) { IDX_T coordinate[kDimGatherMaxDimCount] = {0}; idx_nd_helper.OffsetToNdIndex(idx_offset, coordinate, ndim); // idx_offset -> ijk IDX_T idx_elem = index[idx_offset]; if (idx_elem >= upper_bound) { #if __CUDA_ARCH__ __trap(); #else UNIMPLEMENTED() << "The index element " << idx_elem << " is out of bounds for dimension " << dim << " with size " << upper_bound << "."; #endif } coordinate[dim] = idx_elem; IDX_T output_offset = output_nd_helper.NdIndexToOffset(coordinate, ndim); Opt::apply(src, output + output_offset); } } } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/dim_scatter_scalar_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h" namespace oneflow { namespace user_op { template class Opt> class DimScatterScalarKernel final : public user_op::OpKernel { public: DimScatterScalarKernel() = default; ~DimScatterScalarKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0); const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0); Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0); const int32_t dim = ctx->Attr("dim"); const IDX_T* index = index_tensor->dptr(); IN_T* output = out_tensor->mut_dptr(); size_t out_bytes_size = out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(out_tensor->data_type()); Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex("like", 0); const IN_T src_scalar = static_cast(ctx->Attr("src_scalar")); if (input_tensor) { Memcpy(ctx->stream(), output, input_tensor->dptr(), out_bytes_size); } else if (like_tensor) { Memset(ctx->stream(), output, 0, out_bytes_size); } else { UNIMPLEMENTED() << "Input tensor and like tensor cannot be empty simultaneously."; } const int ndim = out_tensor->shape_view().NumAxes(); small_vector shape_vec(ndim); auto shape2dims = [&shape_vec, &ndim](const ShapeView& tensor_shape) -> void { std::transform(tensor_shape.ptr(), tensor_shape.ptr() + ndim, shape_vec.begin(), [](int32_t dim) -> IDX_T { return static_cast(dim); }); }; shape2dims(index_tensor->shape_view()); DimOpIndexNdHelper idx_nd_helper(shape_vec.data(), ndim); shape2dims(out_tensor->shape_view()); DimOpIndexNdHelper output_nd_helper(shape_vec.data(), ndim); int64_t upper_bound = 0; if (input_tensor) { upper_bound = input_tensor->shape_view().At(dim); // ensure the idx is smaller than upperbound } else { upper_bound = like_tensor->shape_view().At(dim); // ensure the idx is smaller than upperbound } DimScatterScalarFunctor()( ctx->stream(), idx_nd_helper, output_nd_helper, ndim, index_tensor->shape_view().elem_cnt(), dim, upper_bound, index, src_scalar, output); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SCATTERSCALAR_KERNEL(op_type_name, device, dtype_pair, itype_pair, opt) \ REGISTER_USER_KERNEL(#op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("index", 0) == OF_PP_PAIR_SECOND(itype_pair))); #define REGISTER_SCATTER_SCALAR_CPU_KERNELS(dtype_pair, itype_pair) \ REGISTER_SCATTERSCALAR_KERNEL(dim_scatter_update_scalar, DeviceType::kCPU, dtype_pair, \ itype_pair, UpdateScalarFunctor); \ REGISTER_SCATTERSCALAR_KERNEL(dim_scatter_add_scalar, DeviceType::kCPU, dtype_pair, itype_pair, \ AddScalarFunctor); #define REGISTER_SCATTER_SCALAR_CUDA_KERNELS(dtype_pair, itype_pair) \ REGISTER_SCATTERSCALAR_KERNEL(dim_scatter_update_scalar, DeviceType::kCUDA, dtype_pair, \ itype_pair, UpdateScalarFunctor); \ REGISTER_SCATTERSCALAR_KERNEL(dim_scatter_add_scalar, DeviceType::kCUDA, dtype_pair, itype_pair, \ AddScalarFunctor); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( REGISTER_SCATTER_SCALAR_CPU_KERNELS, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( REGISTER_SCATTER_SCALAR_CUDA_KERNELS, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif // WITH_CUDA } // namespace user_op } // namespace oneflow 
================================================ FILE: oneflow/user/kernels/distributions/common.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_COMMON_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_COMMON_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/random_generator.h" namespace oneflow { class DistributionKernelState : public user_op::OpKernelState { public: explicit DistributionKernelState(const std::shared_ptr& generator) : generator_(generator) {} const std::shared_ptr& generator() const { return generator_; } private: std::shared_ptr generator_; }; // FIXME: refine warning message #define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \ CHECK(var >= min && var <= max) << name << " is out of bounds for " << dtype; #define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \ if (var < -(1LL << digits) || var > (1LL << digits)) { \ LOG(WARNING) << name << " is out of bounds [-(2^" << digits << "), 2^" << digits << "]. " \ << "Due to precision limitations " << dtype \ << " can support discrete uniform distribution only within this range. " \ << "This warning will become an error in later version release."; \ } template void check_from_to_in_range(int64_t from, int64_t to_inc) { if (IsFloating::value) { const auto min = static_cast(std::numeric_limits::lowest()); const auto max = static_cast(std::numeric_limits::max()); CHECK_OUT_OF_BOUNDS(from, "from", min, max, GetDataType::value); CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, GetDataType::value); constexpr auto digits = std::numeric_limits::digits; WARN_OUT_OF_BOUNDS(from, "from", digits, GetDataType::value); WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, GetDataType::value); } else if (IsIntegral::value || IsUnsignedIntegral::value) { const auto min = static_cast(std::numeric_limits::lowest()); const auto max = static_cast(std::numeric_limits::max()); CHECK_OUT_OF_BOUNDS(from, "from", min, max, GetDataType::value); CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, GetDataType::value); } else { UNIMPLEMENTED() << "check_random_bounds handles only integral, floating-point and boolean types"; } } } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/distributions/distribution_template_util.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_DISTRIBUTIONS_TEMPLATE_UTIL_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_DISTRIBUTIONS_TEMPLATE_UTIL_H_ #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/kernels/fused_rnn_cell_kernel_util.h" #include "oneflow/core/common/scalar.h" #ifdef WITH_CUDA #include #include #endif namespace oneflow { namespace distribution { template struct DefaultComputeType { using type = T; }; #define OF_DEINFE_SPECIAL_DEFAULT_COMPUTE_TYPE(T, typeproto) \ template<> \ struct DefaultComputeType { \ using type = float; \ }; OF_PP_FOR_EACH_TUPLE(OF_DEINFE_SPECIAL_DEFAULT_COMPUTE_TYPE, INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ) #undef OF_DEINFE_SPECIAL_DEFAULT_COMPUTE_TYPE } // namespace distribution namespace { // launch bounds used for kernels const uint32_t block_size_bound = 256; const uint32_t grid_size_bound = 4; } // namespace #ifdef WITH_CUDA enum class DistributionOp { kNormal4, kNormal2Double, kUniform4, kUniform2Double, }; template struct DistributionFunctor; template<> struct DistributionFunctor { DistributionFunctor() {} __device__ float4 operator()(curandStatePhilox4_32_10_t* state) const { return curand_normal4(state); } }; template<> struct DistributionFunctor { DistributionFunctor() {} __device__ double2 operator()(curandStatePhilox4_32_10_t* state) const { return curand_normal2_double(state); } }; template<> struct DistributionFunctor { DistributionFunctor() {} __device__ float4 operator()(curandStatePhilox4_32_10_t* state) const { return curand_uniform4(state); } }; template<> struct DistributionFunctor { DistributionFunctor() {} __device__ double2 operator()(curandStatePhilox4_32_10_t* state) const { return curand_uniform2_double(state); } }; template OF_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) __global__ void DistributionElementwiseGridStrideKernel(int64_t numel, uint64_t seed, uint64_t offset, T* out_ptr, Distribution dist_func, Transform transform_func) { int idx = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, idx, offset, &state); int rounded_size = ((numel - 1) / (blockDim.x * gridDim.x * unroll_factor) + 1) * blockDim.x * gridDim.x * unroll_factor; for (int32_t linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { auto rand = dist_func(&state); #pragma unroll for (int ii = 0; ii < unroll_factor; ii++) { int li = linear_index + blockDim.x * gridDim.x * ii; if (li < numel) { out_ptr[li] = transform_func(static_cast((&rand.x)[ii])); } } } } #endif // WITH_CUDA } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_DISTRIBUTIONS_TEMPLATE_UTIL_H_ ================================================ FILE: oneflow/user/kernels/distributions/exponential_distribution.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/exponential_distribution.h" namespace oneflow { static uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) { return (static_cast(hi) << 32) | lo; } template static T uniform_real(V val, T from, T to) { constexpr auto MASK = static_cast((static_cast(1) << std::numeric_limits::digits) - 1); constexpr auto DIVISOR = static_cast(1) / (static_cast(1) << std::numeric_limits::digits); T x = (val & MASK) * DIVISOR; return (x * (to - from) + from); } template void ExponentialDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); auto gen = CHECK_JUST(generator->Get()); ep::pytorch_mt19937_engine& engine = gen->torch_engine(); for (int64_t i = 0; i < elem_cnt; ++i) { uint32_t random1 = engine(); uint32_t random2 = engine(); uint64_t rand_unit = make64BitsFrom32Bits(random1, random2); T random_val = uniform_real(rand_unit, 0.0, 1.0); dptr[i] = static_cast(-1.0) / lambd_ * std::log(static_cast(1.0) - random_val); } } #define INITIATE_CPU_UNIFORM_DISTRIBUTION(T, typeproto) \ template void ExponentialDistribution::operator()( \ ep::Stream* stream, const int64_t elem_cnt, T* dptr, \ const std::shared_ptr& generator) const; OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/exponential_distribution.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/kernels/distributions/distribution_template_util.cuh" #include "oneflow/user/kernels/distributions/exponential_distribution.h" #include "oneflow/user/kernels/fused_rnn_cell_kernel_util.h" namespace oneflow { template struct ExponentialTransformFunctor; template<> struct ExponentialTransformFunctor { ExponentialTransformFunctor(float epsilon, float lambd) : epsilon(epsilon), lambd(lambd) {} __device__ float operator()(float random_val) const { float log_rand = __logf(static_cast(random_val)); // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0. // we need log to be not 0, and not underflow when converted to half // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 // args float log = static_cast(random_val) >= static_cast(1.) - epsilon / 2 ? 
-epsilon / 2 : log_rand; return static_cast(-1.0) / lambd * log; } float epsilon; float lambd; }; template<> struct ExponentialTransformFunctor { ExponentialTransformFunctor(double epsilon, double lambd) : epsilon(epsilon), lambd(lambd) {} __device__ double operator()(double random_val) const { double log_rand = ::log(static_cast(random_val)); // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0. // we need log to be not 0, and not underflow when converted to half // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 // args double log = static_cast(random_val) >= static_cast(1.) - epsilon / 2 ? -epsilon / 2 : log_rand; return static_cast(-1.0) / lambd * log; } double epsilon; double lambd; }; template<> struct ExponentialTransformFunctor { ExponentialTransformFunctor(float epsilon, float lambd) : float_functor(epsilon, lambd) {} __device__ half operator()(float random_val) const { return static_cast(float_functor(random_val)); } ExponentialTransformFunctor float_functor; }; template<> void ExponentialDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, double* dptr, const std::shared_ptr& generator) const { CHECK_GT(elem_cnt, 0); const auto device_index = stream->device()->device_index(); auto gen = CHECK_JUST(generator->Get(device_index)); ep::CudaStream* cuda_stream = stream->As(); auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream); auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); uint64_t seed = gen->current_seed(); uint64_t offset = gen->get_philox_offset(counter_offset); ExponentialTransformFunctor transform_functor( std::numeric_limits::epsilon(), static_cast(lambd_)); DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } template<> void ExponentialDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, float* dptr, const std::shared_ptr& generator) const { CHECK_GT(elem_cnt, 0); const auto device_index = stream->device()->device_index(); auto gen = CHECK_JUST(generator->Get(device_index)); ep::CudaStream* cuda_stream = stream->As(); auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream); auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); uint64_t seed = gen->current_seed(); uint64_t offset = gen->get_philox_offset(counter_offset); ExponentialTransformFunctor transform_functor(std::numeric_limits::epsilon(), static_cast(lambd_)); DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } template<> void ExponentialDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, half* dptr, const std::shared_ptr& generator) const { CHECK_GT(elem_cnt, 0); const auto device_index = stream->device()->device_index(); auto gen = CHECK_JUST(generator->Get(device_index)); ep::CudaStream* cuda_stream = stream->As(); auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream); auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); uint64_t seed = gen->current_seed(); uint64_t offset = gen->get_philox_offset(counter_offset); ExponentialTransformFunctor 
transform_functor(std::numeric_limits::epsilon(), static_cast(lambd_)); DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/exponential_distribution.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_DISTRIBUTION_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_DISTRIBUTION_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/random_generator.h" #ifdef WITH_CUDA #include #include #endif namespace oneflow { template class ExponentialDistribution; template class ExponentialDistribution final { public: OF_DISALLOW_COPY_AND_MOVE(ExponentialDistribution); ExponentialDistribution(T lambd) : lambd_(lambd) {} ~ExponentialDistribution() = default; void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const; private: const T lambd_; }; #ifdef WITH_CUDA template class ExponentialDistribution final { public: OF_DISALLOW_COPY_AND_MOVE(ExponentialDistribution); ExponentialDistribution(T lambd) : lambd_(lambd) {} ~ExponentialDistribution() = default; void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const; private: const T lambd_; }; #endif // WITH_CUDA } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_DISTRIBUTION_H_ ================================================ FILE: oneflow/user/kernels/distributions/exponential_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/distributions/exponential_kernel.h" namespace oneflow { namespace { #define REGISTER_EXPONENTIAL_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("exponential") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobAttr("dtype") == GetDataType::value)); REGISTER_EXPONENTIAL_KERNEL(DeviceType::kCPU, float) REGISTER_EXPONENTIAL_KERNEL(DeviceType::kCPU, double) #ifdef WITH_CUDA REGISTER_EXPONENTIAL_KERNEL(DeviceType::kCUDA, float) REGISTER_EXPONENTIAL_KERNEL(DeviceType::kCUDA, double) REGISTER_EXPONENTIAL_KERNEL(DeviceType::kCUDA, half) #endif // WITH_CUDA } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/exponential_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_KERNEL_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_KERNEL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/distributions/exponential_distribution.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace { template class ExponentialKernel final : public user_op::OpKernel { public: ExponentialKernel() = default; ~ExponentialKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(device_type)); // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const float lambd = ctx->Attr("lambd"); int64_t elem_cnt = out->shape_view().elem_cnt(); T* out_dptr = out->mut_dptr(); auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); CHECK_NOTNULL(generator); ExponentialDistribution distribution(static_cast(lambd)); distribution(ctx->stream(), elem_cnt, out_dptr, generator); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/distributions/multinomial_with_replacement_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" // NOTE(Liang Depeng): The implementation of MultinomialWithReplacementCpuKernel is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/MultinomialKernel.cpp#L23 namespace oneflow { namespace { static size_t InferTmpSizeForCpuKernel(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("x", 0); int64_t n_categories = x.shape().At(x.shape().NumAxes() - 1); return n_categories * GetSizeOfDataType(x.data_type()); } template static T uniform_real(V val, T from, T to) { constexpr auto MASK = static_cast((static_cast(1) << std::numeric_limits::digits) - 1); constexpr auto DIVISOR = static_cast(1) / (static_cast(1) << std::numeric_limits::digits); T x = (val & MASK) * DIVISOR; return (x * (to - from) + from); } static uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) { return (static_cast(hi) << 32) | lo; } } // namespace template class MultinomialWithReplacementCpuKernel final : public user_op::OpKernel { public: MultinomialWithReplacementCpuKernel() = default; ~MultinomialWithReplacementCpuKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCPU)); // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); CHECK_NOTNULL(generator); auto cpu_gen = CHECK_JUST(generator->Get()); std::lock_guard lock(cpu_gen->mutex_); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const T* self_ptr = x->dptr(); int64_t* result_ptr = out->mut_dptr(); /* cumulative probability distribution vector */ T* cum_dist_ptr = tmp_buffer->mut_dptr(); int64_t n_categories = x->shape_view().At(x->shape_view().NumAxes() - 1); int64_t n_dist = x->shape_view().NumAxes() > 1 ? x->shape_view().At(0) : 1; const int32_t num_samples = ctx->Attr("num_samples"); int64_t self_stride_0 = x->shape_view().NumAxes() > 1 ? x->stride().at(0) : 0; int64_t self_stride_1 = x->stride().at(x->shape_view().NumAxes() - 1); int64_t result_dist_stride_0 = out->shape_view().NumAxes() > 1 ? 
out->stride().at(0) : 0; int64_t result_dist_stride_1 = out->stride().at(out->shape_view().NumAxes() - 1); ep::pytorch_mt19937_engine& engine = cpu_gen->torch_engine(); for (int i = 0; i < n_dist; ++i) { /* Get normalized cumulative distribution from prob distribution */ T sum = 0; T val; for (int j = 0; j < n_categories; ++j) { val = self_ptr[i * self_stride_0 + j * self_stride_1]; CHECK(val >= 0) << "invalid multinomial distribution (encountering probability entry < 0)"; CHECK(std::isfinite(val)) << "invalid multinomial distribution (encountering probability " "entry = infinity or NaN)"; sum += val; cum_dist_ptr[j] = sum; } CHECK(sum > 0) << "invalid multinomial distribution (sum of probabilities <= 0)"; /* normalize cumulative probability distribution so that last val is 1 i.e. doesn't assume original self row sums to one */ if ((sum > 0) || ((sum < 1.00001) && (sum > 0.99999))) { for (int j = 0; j < n_categories; ++j) { cum_dist_ptr[j] /= sum; } } for (int j = 0; j < num_samples; ++j) { /* sample a probability mass from a uniform distribution */ // at::uniform_real_distribution uniform(0, 1); // double uniform_sample = uniform(gen); uint32_t random1 = engine(); uint32_t random2 = engine(); uint64_t rand_unit = make64BitsFrom32Bits(random1, random2); double uniform_sample = uniform_real(rand_unit, 0.0, 1.0); // Do a binary search for the slot in which the prob falls // ie cum_dist[row][slot-1] < uniform_prob < cum_distr[row][slot] int left_pointer = 0; int right_pointer = n_categories; int mid_pointer = 0; T cum_prob; int sample_idx = 0; // Make sure the last cumulative distribution bucket sums to 1 cum_dist_ptr[(n_categories - 1)] = 1; while (right_pointer - left_pointer > 0) { mid_pointer = left_pointer + (right_pointer - left_pointer) / 2; cum_prob = cum_dist_ptr[mid_pointer]; if (cum_prob < uniform_sample) { left_pointer = mid_pointer + 1; } else { right_pointer = mid_pointer; } } sample_idx = left_pointer; // store in result tensor (will be incremented for lua compat by wrapper) result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] = sample_idx; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("multinomial_with_replacement") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferTmpSizeForCpuKernel); REGISTER_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL(float) REGISTER_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/multinomial_with_replacement_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" // NOTE(Liang Depeng): The implementation of MultinomialWithReplacementGpuKernel is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/MultinomialKernel.cu#L324 namespace oneflow { namespace { template __device__ int binarySearchForMultinomial(const T* cumdist, const T* dist, int32_t size, T val) { int start = 0; int end = size; while (end - start > 0) { int mid = start + (end - start) / 2; T midVal = cumdist[mid]; if (midVal < val) { start = mid + 1; } else { end = mid; } } if (start == size) { // No probability mass or precision problems; just return the // first non-zero element by setting start to size-1 here, // the code below will move it to the last non-zero probability // this actually can happen when the random number is 1 // (github pytorch issue #4858). start = size - 1; } while (start >= 1 && dist[start] == 0) start--; return start; } template __global__ void sampleMultinomialWithReplacement(uint64_t seed, uint64_t offset, int32_t totalSamples, int64_t* dest, int64_t distributions, int64_t categories, const T* normDistPrefixSum, const T* normDist) { // At the moment, each warp computes one sample value in the binary // search due to divergence. It seems possible to compute multiple // values and limit divergence though later on. // global index formula for 2D grid of 1D blocks int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, idx, offset, &state); // The block determines the distribution for which we generate a point for (int64_t curDist = blockIdx.y; curDist < distributions; curDist += gridDim.y) { for (int sample = blockIdx.x * blockDim.x + threadIdx.x; sample < totalSamples; sample += blockDim.x * gridDim.x) { // we are losing 3 out of 4 generated numbers but it's ok // this kernel is not very efficient anyway auto rand = curand_uniform4(&state); T r = static_cast(rand.x); // Find the bucket that a uniform sample lies in int choice = binarySearchForMultinomial(normDistPrefixSum + curDist * categories, normDist + curDist * categories, categories, r); dest[curDist * totalSamples + sample] = choice; } } } } // namespace template class MultinomialWithReplacementGpuKernel final : public user_op::OpKernel { public: MultinomialWithReplacementGpuKernel() = default; ~MultinomialWithReplacementGpuKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA)); // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); CHECK_NOTNULL(generator); auto gpu_gen = CHECK_JUST(generator->Get()); const user_op::Tensor* norm_dist = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* 
prefix_sum = ctx->Tensor4ArgNameAndIndex("prefix_sum", 0); CHECK_NOTNULL(prefix_sum); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const T* norm_dist_ptr = norm_dist->dptr(); const T* prefix_sum_ptr = prefix_sum->dptr(); int64_t* result_ptr = out->mut_dptr(); int64_t numCategories = norm_dist->shape_view().At(norm_dist->shape_view().NumAxes() - 1); int64_t numDist = norm_dist->shape_view().NumAxes() > 1 ? norm_dist->shape_view().At(0) : 1; const int32_t n_sample = ctx->Attr("num_samples"); // Binary search is warp divergent (so effectively we're running // with just a single thread), but for better utilization, // we need each block to have at least 4 warps. dim3 block(128); ep::CudaStream* stream = ctx->stream()->As(); // Each block will generate a sample from one // distribution concurrently. int grid_y = std::min(numDist, stream->device_properties().maxGridSize[1]); dim3 grid((n_sample - 1) / block.x + 1, grid_y); uint64_t seed = gpu_gen->current_seed(); uint64_t offset = gpu_gen->get_philox_offset(((numDist - 1) / grid.y + 1) * 4); // Sample with replacement sampleMultinomialWithReplacement<<cuda_stream()>>>( seed, offset, n_sample, result_ptr, numDist, numCategories, prefix_sum_ptr, norm_dist_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MULTINOMIAL_WITH_REPLACEMENT_GPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("multinomial_with_replacement") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value) \ && (user_op::HobDataType("prefix_sum", 0) == GetDataType::value)); REGISTER_MULTINOMIAL_WITH_REPLACEMENT_GPU_KERNEL(float) REGISTER_MULTINOMIAL_WITH_REPLACEMENT_GPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/normal_distribution.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/distributions/normal_distribution.h" #include "oneflow/core/framework/framework.h" namespace oneflow { template void NormalDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0) << "elem_cnt must be non-negative, but got " << elem_cnt; auto gen = CHECK_JUST(generator->Get()); std::normal_distribution random_distribution(mean_, std_); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(gen->engine()); } } #define INITIATE_CPU_NORMAL_DISTRIBUTION(T, typeproto) \ template void NormalDistribution::operator()( \ ep::Stream* stream, const int64_t elem_cnt, T* dptr, \ const std::shared_ptr& generator) const; OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_NORMAL_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ) // specialization for half template<> void NormalDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, float16* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0) << "elem_cnt must be non-negative, but got " << elem_cnt; auto gen = CHECK_JUST(generator->Get()); std::normal_distribution random_distribution(mean_, std_); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = static_cast(random_distribution(gen->engine())); } } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/normal_distribution.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/distributions/distribution_template_util.cuh" #include "oneflow/user/kernels/distributions/normal_distribution.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/device.h" namespace oneflow { template struct NormalTransformFunctor { NormalTransformFunctor(ComputeType mean, ComputeType std) : mean(mean), std(std) {} __device__ T operator()(ComputeType random_val) const { return static_cast(random_val * std + mean); } ComputeType mean; ComputeType std; }; template void NormalDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); if (elem_cnt == 0) return; const auto device_index = stream->device()->device_index(); auto gen = CHECK_JUST(generator->Get(device_index)); ep::CudaStream* cuda_stream = stream->As(); auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream); auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); uint64_t seed = gen->current_seed(); uint64_t offset = gen->get_philox_offset(counter_offset); using ComputeType = typename distribution::DefaultComputeType::type; NormalTransformFunctor transform_functor(static_cast(mean_), static_cast(std_)); if (std::is_same::value) { DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } else { DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } } #define INITIATE_CUDA_NORMAL_DISTRIBUTION(T, typeproto) \ template void NormalDistribution::operator()( \ ep::Stream* stream, const int64_t elem_cnt, T* dptr, \ const std::shared_ptr& generator) const; OF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_NORMAL_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ) INITIATE_CUDA_NORMAL_DISTRIBUTION(half, DataType::kFloat16) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/normal_distribution.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/random_generator.h" #ifdef WITH_CUDA #include #include #endif namespace oneflow { template class NormalDistribution; template class NormalDistribution final { public: OF_DISALLOW_COPY_AND_MOVE(NormalDistribution); NormalDistribution(T mean, T std) : mean_(mean), std_(std) {} ~NormalDistribution() = default; void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const; private: const T mean_; const T std_; }; #ifdef WITH_CUDA template class NormalDistribution final { public: OF_DISALLOW_COPY_AND_MOVE(NormalDistribution); NormalDistribution(T mean, T std) : mean_(mean), std_(std) {} ~NormalDistribution() = default; void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const; private: const T mean_; const T std_; }; #endif // WITH_CUDA } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_ ================================================ FILE: oneflow/user/kernels/distributions/normal_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/distributions/normal_kernel.h" namespace oneflow { namespace { #define REGISTER_UNIFORM_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("normal").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobAttr("dtype") == GetDataType::value)); REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float16) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double) #ifdef WITH_CUDA REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, half) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, float) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, double) #endif // WITH_CUDA } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/normal_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/distributions/normal_distribution.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace { template class NormalKernel final : public user_op::OpKernel { public: NormalKernel() = default; ~NormalKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(device_type)); // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const double mean = ctx->Attr("mean"); const double std = ctx->Attr("std"); int64_t elem_cnt = out->shape_view().elem_cnt(); T* out_dptr = out->mut_dptr(); auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); CHECK_NOTNULL(generator); NormalDistribution distribution(static_cast(mean), static_cast(std)); distribution(ctx->stream(), elem_cnt, out_dptr, generator); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/distributions/uniform_distribution.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/uniform_distribution.h" namespace oneflow { template class CPUUniformDistributionImpl; template class CPUUniformDistributionImpl::value>::type> { public: CPUUniformDistributionImpl(T low, T high) : random_distribution_(low, high) {} T operator()(std::mt19937& engine) { return random_distribution_(engine); } private: std::uniform_real_distribution random_distribution_; }; template void UniformDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0) << "elem_cnt must be non-negative, but got " << elem_cnt; auto gen = CHECK_JUST(generator->Get()); CPUUniformDistributionImpl impl(low_, high_); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = impl(gen->engine()); } } #define INITIATE_CPU_UNIFORM_DISTRIBUTION(T, typeproto) \ template void UniformDistribution::operator()( \ ep::Stream* stream, const int64_t elem_cnt, T* dptr, \ const std::shared_ptr& generator) const; OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ) // specialization for half template<> void UniformDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, float16* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0) << "elem_cnt must be non-negative, but got " << elem_cnt; auto gen = CHECK_JUST(generator->Get()); CPUUniformDistributionImpl impl(low_, high_); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = static_cast(impl(gen->engine())); } } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/uniform_distribution.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/data_type.h" #include "oneflow/user/kernels/distributions/uniform_distribution.h" #include "oneflow/user/kernels/distributions/distribution_template_util.cuh" #include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { template struct UniformTransformFunctor { UniformTransformFunctor(ComputeType low, ComputeType high) : low(low), high(high) {} __device__ T operator()(ComputeType rand_num) const { if (rand_num == static_cast(1.0)) { rand_num = static_cast(0.0); } return static_cast(rand_num * (high - low) + low); } ComputeType low; ComputeType high; }; template void UniformDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); if (elem_cnt == 0) return; const auto device_index = stream->device()->device_index(); auto gen = CHECK_JUST(generator->Get(device_index)); ep::CudaStream* cuda_stream = stream->As(); auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream); auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); uint64_t seed = gen->current_seed(); uint64_t offset = gen->get_philox_offset(counter_offset); using ComputeType = typename distribution::DefaultComputeType::type; UniformTransformFunctor transform_functor(static_cast(low_), static_cast(high_)); if (std::is_same::value) { DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } else { DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } } #define INITIATE_CUDA_UNIFORM_DISTRIBUTION(T, typeproto) \ template void UniformDistribution::operator()( \ ep::Stream* stream, const int64_t elem_cnt, T* dptr, \ const std::shared_ptr& generator) const; OF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ) INITIATE_CUDA_UNIFORM_DISTRIBUTION(half, DataType::kFloat16) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/uniform_distribution.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/random_generator.h" #ifdef WITH_CUDA #include #include #endif namespace oneflow { template class UniformDistribution; template class UniformDistribution final { public: OF_DISALLOW_COPY_AND_MOVE(UniformDistribution); UniformDistribution(T low, T high) : low_(low), high_(high) {} ~UniformDistribution() = default; void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const; private: const T low_; const T high_; }; #ifdef WITH_CUDA template class UniformDistribution final { public: OF_DISALLOW_COPY_AND_MOVE(UniformDistribution); UniformDistribution(T low, T high) : low_(low), high_(high) {} ~UniformDistribution() = default; void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const; private: const T low_; const T high_; }; #endif // WITH_CUDA } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_ ================================================ FILE: oneflow/user/kernels/distributions/uniform_int_distribution.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/user/kernels/distributions/uniform_int_distribution.h" namespace oneflow { template class CPUUniformIntDistributionImpl { public: CPUUniformIntDistributionImpl(int64_t low, int64_t high) : random_distribution_(low, high) {} T operator()(std::mt19937& engine) { return static_cast(random_distribution_(engine)); } private: std::uniform_int_distribution random_distribution_; }; template void UniformIntDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); auto gen = CHECK_JUST(generator->Get()); // std::uniform_int_distribution generates [low, high], but we want [low, high) here CPUUniformIntDistributionImpl impl(low_, high_ - 1); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = impl(gen->engine()); } } #define INITIATE_CPU_UNIFORM_INT_DISTRIBUTION(T, typeproto) \ template void UniformIntDistribution::operator()( \ ep::Stream* stream, const int64_t elem_cnt, T* dptr, \ const std::shared_ptr& generator) const; OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ) OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, INT_DATA_TYPE_SEQ) OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, UNSIGNED_INT_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/uniform_int_distribution.cu ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/user/kernels/distributions/uniform_int_distribution.h" #include "oneflow/user/kernels/distributions/distribution_template_util.cuh" #include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { template struct UniformIntTransformFunctor { UniformIntTransformFunctor(ComputeType low, ComputeType high) : low(low), high(high) {} __device__ T operator()(ComputeType rand_num) const { if (rand_num == 1.0) { rand_num = 0.0; } return static_cast(static_cast(rand_num * (high - low) + low)); } ComputeType low; ComputeType high; }; template void UniformIntDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); if (elem_cnt == 0) return; const auto device_index = stream->device()->device_index(); auto gen = CHECK_JUST(generator->Get(device_index)); ep::CudaStream* cuda_stream = stream->As(); auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream); auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); uint64_t seed = gen->current_seed(); uint64_t offset = gen->get_philox_offset(counter_offset); using ComputeType = typename distribution::DefaultComputeType::type; UniformIntTransformFunctor transform_functor(low_, high_); if (std::is_same::value) { DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } else { DistributionFunctor dist_functor; DistributionElementwiseGridStrideKernel <<As()->cuda_stream()>>>( elem_cnt, seed, offset, dptr, dist_functor, transform_functor); } } #define INITIATE_CUDA_UNIFORM_INT_DISTRIBUTION(T, typeproto) \ template void UniformIntDistribution::operator()( \ ep::Stream* stream, const int64_t elem_cnt, T* dptr, \ const std::shared_ptr& generator) const; OF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_UNIFORM_INT_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ) OF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_UNIFORM_INT_DISTRIBUTION, INT_DATA_TYPE_SEQ) OF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_UNIFORM_INT_DISTRIBUTION, UNSIGNED_INT_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/uniform_int_distribution.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/random_generator.h" #ifdef WITH_CUDA #include #include #endif namespace oneflow { template class UniformIntDistribution; template class UniformIntDistribution final { public: OF_DISALLOW_COPY_AND_MOVE(UniformIntDistribution); UniformIntDistribution(int64_t low, int64_t high) : low_(low), high_(high) {} ~UniformIntDistribution() = default; void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const; private: const int64_t low_; const int64_t high_; }; #ifdef WITH_CUDA template class UniformIntDistribution final { public: OF_DISALLOW_COPY_AND_MOVE(UniformIntDistribution); UniformIntDistribution(int64_t low, int64_t high) : low_(low), high_(high) {} ~UniformIntDistribution() = default; void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const; private: const int64_t low_; const int64_t high_; }; #endif // WITH_CUDA } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_ ================================================ FILE: oneflow/user/kernels/distributions/uniform_int_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/distributions/uniform_int_kernel.h" namespace oneflow { namespace { #define REGISTER_UNIFORM_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("uniform_int") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobAttr("dtype") == GetDataType::value)); REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, uint8_t) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int8_t) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int32_t) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int64_t) #ifdef WITH_CUDA REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, float) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, double) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, uint8_t) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, int8_t) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, int32_t) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, int64_t) #endif // WITH_CUDA } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/uniform_int_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/distributions/uniform_int_distribution.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace { // The following algorithm is adopted from pytorch: // The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can // be used as actual `from`. The current implementation of `random_` uses uint64_t arithmetics and // casts the result to the target dtype(scalar_t). This casting can result in generating numbers // that happen to be greater or equal to `to` value. For instance: // // auto actual = torch::empty({3, 3}, torch::half); // actual.random_(0, 65504); // // If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it // becomes 65504 and violates the requirement that random value must be less than `to`. To resolve // this issue `update_from` and `update_to` moves `from` to the right and `to` to the left to the // next closest value that won't go outside [from, to) after casting to the target dtype. For `to` = // 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous // available number for torch::half dtype. 
template<typename T>
int64_t update_from(int64_t from) {
  const auto from_plus_1 = static_cast<int64_t>(static_cast<T>(from + 1));
  if (from_plus_1 < from) {
    int64_t from_ = std::abs(from + 1);
    int n = 0;
    while (from_ >>= 1) ++n;
    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
    from = from_plus_1 + (1LL << (n - std::numeric_limits<T>::digits + 1));
  }
  return from;
}

template<typename T>
int64_t update_to(int64_t to) {
  const auto to_minus_1 = static_cast<int64_t>(static_cast<T>(to - 1));
  if (to_minus_1 >= to) {
    int64_t to_ = std::abs(to - 1);
    int n = 0;
    while (to_ >>= 1) ++n;
    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
    to = to_minus_1 - (1LL << (n - std::numeric_limits<T>::digits + 1));
  }
  return to;
}

template<DeviceType device_type, typename T>
class UniformIntKernel final : public user_op::OpKernel {
 public:
  UniformIntKernel() = default;
  ~UniformIntKernel() = default;

  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    const auto& generator = CHECK_JUST(one::MakeAutoGenerator());
    // When SBP is Split, each rank uses a different seed; otherwise, all ranks use the same seed.
    generator->set_current_seed(
        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>("seed"))));
    return std::make_shared<DistributionKernelState>(generator);
  }

 private:
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    int64_t from = ctx->Attr<int64_t>("from");
    int64_t to = ctx->Attr<int64_t>("to");
    CHECK_LT(from, to) << "uniform kernel expects 'from' to be less than 'to'";
    if (IsFloating<T>::value) {
      from = update_from<T>(from);
      to = update_to<T>(to);
      CHECK_LT(from, to) << "uniform kernel expects 'from' casted to dtype to be less than 'to'"
                            " casted to dtype";
    }
    check_from_to_in_range<T>(from, to - 1);
    int64_t elem_cnt = out->shape_view().elem_cnt();
    T* out_dptr = out->mut_dptr<T>();
    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);
    CHECK_NOTNULL(distribution_state);
    const auto& generator = distribution_state->generator();
    CHECK_NOTNULL(generator);
    UniformIntDistribution<device_type, T> distribution(from, to);
    distribution(ctx->stream(), elem_cnt, out_dptr, generator);
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

}  // namespace

}  // namespace oneflow

#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_


================================================
FILE: oneflow/user/kernels/distributions/uniform_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/user/kernels/distributions/uniform_kernel.h" namespace oneflow { namespace { #define REGISTER_UNIFORM_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("uniform").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobAttr("dtype") == GetDataType::value)); REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float16) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float) REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double) #ifdef WITH_CUDA REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, half) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, float) REGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, double) #endif // WITH_CUDA } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/distributions/uniform_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_ #define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/distributions/uniform_distribution.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace { template class UniformKernel final : public user_op::OpKernel { public: UniformKernel() = default; ~UniformKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(device_type)); // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const double from = ctx->Attr("from"); const double to = ctx->Attr("to"); check_from_to_in_range(from, to); int64_t elem_cnt = out->shape_view().elem_cnt(); T* out_dptr = out->mut_dptr(); auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); CHECK_NOTNULL(generator); UniformDistribution distribution(static_cast(from), static_cast(to)); distribution(ctx->stream(), elem_cnt, out_dptr, generator); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/dot_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { using namespace ep::primitive; template std::unique_ptr NewMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type, BlasTransposeType::N, BlasTransposeType::N); } auto MatmulPrimitiveExists() { return hob::make_custom("MatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewMatmulPrimitive(&ctx).operator bool(); }); } class DotKernel final : public user_op::OpKernel { public: DotKernel() = default; ~DotKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t n = x->shape_view().elem_cnt(); auto primitive = NewMatmulPrimitive(ctx); primitive->Launch(ctx->stream(), 1, 1, n, 1, x->dptr(), y->dptr(), 0, out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("dot").SetCreateFn().SetIsMatchedHob(MatmulPrimitiveExists() == true); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dropout_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/core/ep/include/primitive/add.h" namespace oneflow { namespace { template void MaskAndScale(ep::Stream* stream, const int64_t n, float scale, const T* x, const bool* mask, T* y) { for (int64_t i = 0; i < n; ++i) { y[i] = x[i] * static_cast(mask[i]) * scale; } } template void FusedDropoutKernel(ep::Stream* stream, const int64_t elem_cnt, const std::shared_ptr& cpu_gen, const float rate, float scale, const T* x, bool* mask, T* y) { /* `uniform_real_distribution` interval is [a, b). And `curand_uniform4` interval is (0, 1.0], so we use > in CUDA and use >= in CPU. 
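     With `rate` being the drop probability, both choices keep an element with probability
     (1 - rate): on CPU u is drawn from [0, 1), so P(u >= rate) = 1 - rate, while on CUDA u is
     drawn from (0, 1], so P(u > rate) = 1 - rate.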
*/ std::uniform_real_distribution random_distribution(GetZeroVal(), GetOneVal()); for (int64_t i = 0; i < elem_cnt; ++i) { mask[i] = random_distribution(cpu_gen->engine()) >= rate; y[i] = x[i] * static_cast(mask[i]) * scale; } } template class DropoutKernelCPU final : public user_op::OpKernel { public: DropoutKernelCPU() = default; ~DropoutKernelCPU() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(kCPU)); generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const float rate = ctx->Attr("rate"); float scale = 0.0f; if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); } auto* fused_dropout_kernel_state = dynamic_cast(state); CHECK_NOTNULL(fused_dropout_kernel_state); const auto& generator = fused_dropout_kernel_state->generator(); CHECK_NOTNULL(generator); std::shared_ptr cpu_generator = CHECK_JUST(generator->Get()); FusedDropoutKernel(ctx->stream(), in->shape_view().elem_cnt(), cpu_generator, rate, scale, in->dptr(), mask->mut_dptr(), out->mut_dptr()); if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), out->data_type()); CHECK_EQ(add_to_output->shape_view(), out->shape_view()); std::unique_ptr primitive = ep::primitive::NewPrimitive(DeviceType::kCPU, add_to_output->data_type()); CHECK(primitive); primitive->Launch(ctx->stream(), out->dptr(), add_to_output->dptr(), out->mut_dptr(), add_to_output->shape_view().elem_cnt()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DROPOUT_KERNEL_CPU(dtype) \ REGISTER_USER_KERNEL("dropout") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == GetDataType::value) \ && (user_op::HobDataType("mask", 0) == GetDataType::value)) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); \ return Maybe::Ok(); \ }); REGISTER_DROPOUT_KERNEL_CPU(float) REGISTER_DROPOUT_KERNEL_CPU(double) template class DropoutGradKernelCPU final : public user_op::OpKernel { public: DropoutGradKernelCPU() = default; ~DropoutGradKernelCPU() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const float scale = ctx->Attr("scale"); MaskAndScale(ctx->stream(), dy->shape_view().elem_cnt(), scale, dy->dptr(), mask->dptr(), dx->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DROPOUT_GRAD_KERNEL_CPU(dtype) \ REGISTER_USER_KERNEL("dropout_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)) \ .SetInplaceProposalFn([](const 
user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "dy", 0, true)); \ return Maybe::Ok(); \ }); REGISTER_DROPOUT_GRAD_KERNEL_CPU(float) REGISTER_DROPOUT_GRAD_KERNEL_CPU(double) } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dropout_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/device/cuda_pseudo_bfloat16.h" #include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace { constexpr int32_t kVecSize = 4; constexpr int32_t kBlockSize = 256; template constexpr int32_t GetDropoutPackSize() { // For float, bfloat16, half. return 4; }; template<> constexpr int32_t GetDropoutPackSize() { return 2; }; template<> constexpr int32_t GetDropoutPackSize() { return 2; }; union RandPack4 { float4 storage; float elem[4]; }; template struct GetPack2Type { using T2 = typename std::aligned_storage<2 * sizeof(T), 2 * sizeof(T)>::type; }; template<> struct GetPack2Type { using T2 = half2; }; #if CUDA_VERSION >= 11000 template<> struct GetPack2Type { using T2 = nv_bfloat162; }; #endif template using Pack2Type = typename GetPack2Type::T2; using H2PackType = typename std::aligned_storage<4 * sizeof(half), 4 * sizeof(half)>::type; template union H2Pack { cuda::elementwise::Pack pack_storage; Pack2Type h2[2]; __device__ H2Pack() { // do nothing } }; template<> union H2Pack { cuda::elementwise::Pack pack_storage; half2 h2[2]; __device__ H2Pack() { // do nothing } }; #if CUDA_VERSION >= 11000 template<> union H2Pack { cuda::elementwise::Pack pack_storage; nv_bfloat162 h2[2]; __device__ H2Pack() { // do nothing } }; #endif template __device__ Pack2Type Make2(float v); template<> __device__ Pack2Type Make2(float v) { return __float2half2_rn(v); } #if CUDA_VERSION >= 11000 template<> __device__ Pack2Type Make2(float v) { return __float2bfloat162_rn(v); } #endif #if CUDA_VERSION >= 11000 #define RETURN_VOID_IF_HALF \ typename std::enable_if_t<(std::is_same::value || std::is_same::value), \ void> #else #define RETURN_VOID_IF_HALF typename std::enable_if_t::value, void> #endif #define RETURN_VOID_IF_FLOAT typename std::enable_if_t::value, void> #define RETURN_VOID_IF_DOUBLE typename std::enable_if_t::value, void> template __global__ RETURN_VOID_IF_FLOAT FusedDropoutAddGpu(uint64_t seed, uint64_t offset, const int64_t elem_cnt, float rate, float scale, int64_t n_tail, const T* x, bool* mask, const T* addend, T* y, const T* tail_x, bool* tail_mask, const T* tail_addend, T* tail_y) { int32_t 
global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, global_thread_id, offset, &state); using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; using MaskType = cuda::elementwise::PackType; using MaskPack = cuda::elementwise::Pack; T t_scale = static_cast(scale); RandPack4 rand_uniform_pack4; for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt; linear_index += gridDim.x * blockDim.x * pack_size) { rand_uniform_pack4.storage = curand_uniform4(&state); const LoadType* x_load = reinterpret_cast(x + linear_index); LoadPack x_vec; x_vec.storage = *x_load; LoadPack addend_vec; if (has_addend) { const LoadType* addend_load = reinterpret_cast(addend + linear_index); addend_vec.storage = *addend_load; } MaskPack mask_vec; LoadPack y_vec; #pragma unroll for (int i = 0; i < pack_size; i++) { mask_vec.elem[i] = rand_uniform_pack4.elem[i] > rate; T tmp_float_mask = static_cast(mask_vec.elem[i]); y_vec.elem[i] = x_vec.elem[i] * tmp_float_mask * t_scale; if (has_addend) { y_vec.elem[i] += addend_vec.elem[i]; } } *(reinterpret_cast(y + linear_index)) = y_vec.storage; *(reinterpret_cast(mask + linear_index)) = mask_vec.storage; } if (tail && global_thread_id < n_tail) { const float rand_uniform = curand_uniform(&state); const bool mask_val = rand_uniform > rate; tail_mask[global_thread_id] = mask_val; T tmp_float_mask = static_cast(mask_val); T tmp_tail_out = tail_x[global_thread_id] * tmp_float_mask * t_scale; if (has_addend) { tmp_tail_out += tail_addend[global_thread_id]; } tail_y[global_thread_id] = tmp_tail_out; } } template __global__ RETURN_VOID_IF_HALF FusedDropoutAddGpu(uint64_t seed, uint64_t offset, const int64_t elem_cnt, float rate, float scale, int64_t n_tail, const T* x, bool* mask, const T* addend, T* y, const T* tail_x, bool* tail_mask, const T* tail_addend, T* tail_y) { int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, global_thread_id, offset, &state); using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; using StoreType = cuda::elementwise::PackType, pack_size / 2>; using StorePack = cuda::elementwise::Pack, pack_size / 2>; using MaskType = cuda::elementwise::PackType; using MaskPack = cuda::elementwise::Pack; RandPack4 rand_uniform_pack4; Pack2Type h2_scale = Make2(scale); for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt; linear_index += gridDim.x * blockDim.x * pack_size) { rand_uniform_pack4.storage = curand_uniform4(&state); const LoadType* x_load = reinterpret_cast(x + linear_index); H2Pack x_vec{}; x_vec.pack_storage.storage = *x_load; H2Pack addend_vec{}; if (has_addend) { const LoadType* addend_load = reinterpret_cast(addend + linear_index); addend_vec.pack_storage.storage = *addend_load; } MaskPack mask_vec; StorePack y_vec; StorePack one_or_zero_h2; mask_vec.elem[0] = rand_uniform_pack4.elem[0] > rate; float tmp_float_mask = static_cast(mask_vec.elem[0]); one_or_zero_h2.elem[0].x = tmp_float_mask; mask_vec.elem[1] = rand_uniform_pack4.elem[1] > rate; tmp_float_mask = static_cast(mask_vec.elem[1]); one_or_zero_h2.elem[0].y = tmp_float_mask; y_vec.elem[0] = __hmul2(__hmul2(x_vec.h2[0], one_or_zero_h2.elem[0]), h2_scale); mask_vec.elem[2] = rand_uniform_pack4.elem[2] > rate; tmp_float_mask = static_cast(mask_vec.elem[2]); one_or_zero_h2.elem[1].x = tmp_float_mask; mask_vec.elem[3] = rand_uniform_pack4.elem[3] > rate; 
tmp_float_mask = static_cast(mask_vec.elem[3]); one_or_zero_h2.elem[1].y = tmp_float_mask; y_vec.elem[1] = __hmul2(__hmul2(x_vec.h2[1], one_or_zero_h2.elem[1]), h2_scale); if (has_addend) { y_vec.elem[0] = __hadd2(y_vec.elem[0], addend_vec.h2[0]); y_vec.elem[1] = __hadd2(y_vec.elem[1], addend_vec.h2[1]); } *(reinterpret_cast(y + linear_index)) = y_vec.storage; *(reinterpret_cast(mask + linear_index)) = mask_vec.storage; } if (tail && global_thread_id < n_tail) { const float rand_uniform = curand_uniform(&state); const bool mask_val = rand_uniform > rate; tail_mask[global_thread_id] = mask_val; float tmp_half_mask = static_cast(mask_val); T tmp_tail_out = tail_x[global_thread_id] * static_cast(tmp_half_mask) * h2_scale.x; if (has_addend) { tmp_tail_out += tail_addend[global_thread_id]; } tail_y[global_thread_id] = tmp_tail_out; } } template __global__ RETURN_VOID_IF_DOUBLE FusedDropoutAddGpu(uint64_t seed, uint64_t offset, const int64_t elem_cnt, float rate, float scale, int64_t n_tail, const T* x, bool* mask, const T* addend, T* y, const T* tail_x, bool* tail_mask, const T* tail_addend, T* tail_y) { int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, global_thread_id, offset, &state); using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; using MaskType = cuda::elementwise::PackType; using MaskPack = cuda::elementwise::Pack; RandPack4 rand_uniform_pack4; bool grid_loop_rand_state = 0; for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt; linear_index += gridDim.x * blockDim.x * pack_size) { if (grid_loop_rand_state == 0) { rand_uniform_pack4.storage = curand_uniform4(&state); grid_loop_rand_state ^= 1; } else { // Use the last two random numbers we generated in previous iteration. 
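      // The double kernel uses a pack size of 2 while curand_uniform4 always produces 4 floats,
      // so each call is shared by two consecutive grid-stride iterations: the iteration that
      // drew the numbers consumes elem[0]/elem[1], and this branch recycles elem[2]/elem[3]
      // by shifting them down before they are read below.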
rand_uniform_pack4.elem[0] = rand_uniform_pack4.elem[2]; rand_uniform_pack4.elem[1] = rand_uniform_pack4.elem[3]; grid_loop_rand_state ^= 1; } const LoadType* x_load = reinterpret_cast(x + linear_index); LoadPack x_vec; x_vec.storage = *x_load; LoadPack addend_vec; if (has_addend) { const LoadType* addend_load = reinterpret_cast(addend + linear_index); addend_vec.storage = *addend_load; } MaskPack mask_vec; LoadPack y_vec; #pragma unroll for (int i = 0; i < pack_size; i++) { mask_vec.elem[i] = rand_uniform_pack4.elem[i] > rate; y_vec.elem[i] = x_vec.elem[i] * mask_vec.elem[i] * scale; if (has_addend) { y_vec.elem[i] += addend_vec.elem[i]; } } *(reinterpret_cast(y + linear_index)) = y_vec.storage; *(reinterpret_cast(mask + linear_index)) = mask_vec.storage; } if (tail && global_thread_id < n_tail) { const float rand_uniform = curand_uniform(&state); const bool mask_val = rand_uniform > rate; tail_mask[global_thread_id] = mask_val; double tmp_tail_out = tail_x[global_thread_id] * mask_val * scale; if (has_addend) { tmp_tail_out += tail_addend[global_thread_id]; } tail_y[global_thread_id] = tmp_tail_out; } } unsigned int ComputeGridSize(ep::Stream* stream, const int32_t block_size, const int64_t elem_cnt) { auto* cuda_stream = stream->As(); const int32_t max_threads_multi_process = cuda_stream->device_properties().maxThreadsPerMultiProcessor; const int32_t multi_processor_count = cuda_stream->device_properties().multiProcessorCount; unsigned int blocks_per_sm = max_threads_multi_process / block_size; unsigned int grid_size = std::max((int64_t)1, ((elem_cnt + block_size - 1) / block_size)); grid_size = std::min((unsigned int)multi_processor_count * blocks_per_sm, grid_size); return grid_size; } template void DispatchTail(ep::Stream* stream, const std::shared_ptr& cuda_generator, const int64_t elem_cnt, float rate, float scale, const T* x, bool* mask, const T* addend, T* y) { constexpr int pack_size = GetDropoutPackSize(); const int64_t pack_num = elem_cnt / pack_size; unsigned int grid_size = ComputeGridSize(stream, kBlockSize, pack_num); const int64_t tail_offset = pack_num * pack_size; const int64_t n_tail = elem_cnt - tail_offset; const bool tail = n_tail > 0 ? true : false; uint64_t offset = 0; uint64_t seed = cuda_generator->current_seed(); if (tail) { // If tail, we need generate randnum one more time, so here we add another `1`. 
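    // Philox offset bookkeeping: each grid-stride iteration calls curand_uniform4 once, i.e.
    // ceil(elem_cnt / (kBlockSize * grid_size * kVecSize)) calls per thread, each accounting
    // for kVecSize values; the trailing `+ 1` below reserves the extra draw made for the tail.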
uint64_t inc_offset = ((elem_cnt - 1) / (kBlockSize * grid_size * kVecSize) + 1) * kVecSize + 1; offset = cuda_generator->get_philox_offset(inc_offset); FusedDropoutAddGpu <<As()->cuda_stream()>>>( seed, offset, elem_cnt, rate, scale, n_tail, x, mask, addend, y, (x + tail_offset), (mask + tail_offset), (addend + tail_offset), (y + tail_offset)); } else { uint64_t inc_offset = ((elem_cnt - 1) / (kBlockSize * grid_size * kVecSize) + 1) * kVecSize; offset = cuda_generator->get_philox_offset(inc_offset); FusedDropoutAddGpu <<As()->cuda_stream()>>>( seed, offset, elem_cnt, rate, scale, n_tail, x, mask, addend, y, nullptr, nullptr, nullptr, nullptr); } } template struct MaskAndScaleFunctor { OF_DEVICE_FUNC explicit MaskAndScaleFunctor(float scale) : scale(scale) {} __device__ T operator()(T x, bool mask) const { return x * static_cast(mask) * static_cast(scale); } float scale; }; #if CUDA_VERSION >= 11000 template<> struct MaskAndScaleFunctor { OF_DEVICE_FUNC explicit MaskAndScaleFunctor(float scale) : scale(scale) {} __device__ nv_bfloat16 operator()(nv_bfloat16 x, bool mask) const { float float_mask = static_cast(mask); return x * static_cast(float_mask) * static_cast(scale); } float scale; }; #endif template class DropoutKernelGPU final : public user_op::OpKernel { public: DropoutKernelGPU() = default; ~DropoutKernelGPU() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA)); generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); auto* fused_dropout_kernel_state = dynamic_cast(state); CHECK_NOTNULL(fused_dropout_kernel_state); const auto& generator = fused_dropout_kernel_state->generator(); CHECK_NOTNULL(generator); auto* stream = ctx->stream(); const auto device_index = stream->device()->device_index(); std::shared_ptr cuda_generator = CHECK_JUST(generator->Get(device_index)); const float rate = ctx->Attr("rate"); float scale = 0.0; if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); } if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* addend = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); DispatchTail( stream, cuda_generator, in->shape_view().elem_cnt(), rate, scale, reinterpret_cast(in->dptr()), reinterpret_cast(mask->mut_dptr()), reinterpret_cast(addend->dptr()), reinterpret_cast(out->mut_dptr())); } else { DispatchTail(stream, cuda_generator, in->shape_view().elem_cnt(), rate, scale, reinterpret_cast(in->dptr()), reinterpret_cast(mask->mut_dptr()), nullptr, reinterpret_cast(out->mut_dptr())); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DROPOUT_KERNEL_GPU(cpp_type, data_type) \ REGISTER_USER_KERNEL("dropout").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == data_type) \ && (user_op::HobDataType("mask", 0) == GetDataType::value)) REGISTER_DROPOUT_KERNEL_GPU(half, DataType::kFloat16); REGISTER_DROPOUT_KERNEL_GPU(float, DataType::kFloat); REGISTER_DROPOUT_KERNEL_GPU(double, DataType::kDouble); #if 
CUDA_VERSION >= 11000 REGISTER_DROPOUT_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16); #endif template class DropoutGradKernelGPU final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: DropoutGradKernelGPU() = default; ~DropoutGradKernelGPU() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const float scale = ctx->Attr("scale"); const int64_t elem_cnt = dy->shape_view().elem_cnt(); OF_CUDA_CHECK((cuda::elementwise::Binary( MaskAndScaleFunctor(scale), elem_cnt, reinterpret_cast(dx->mut_dptr()), reinterpret_cast(dy->dptr()), reinterpret_cast(mask->dptr()), ctx->stream()->As()->cuda_stream()))); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DROPOUT_GRAD_KERNEL_GPU(cpp_type, data_type) \ REGISTER_USER_KERNEL("dropout_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == data_type)) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "dy", 0, true)); \ return Maybe::Ok(); \ }) REGISTER_DROPOUT_GRAD_KERNEL_GPU(half, DataType::kFloat16); REGISTER_DROPOUT_GRAD_KERNEL_GPU(float, DataType::kFloat); REGISTER_DROPOUT_GRAD_KERNEL_GPU(double, DataType::kDouble); #if CUDA_VERSION >= 11000 REGISTER_DROPOUT_GRAD_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16); #endif } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dropout_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_DROPOUT_KERNEL_H_ #define ONEFLOW_USER_KERNELS_DROPOUT_KERNEL_H_ #include "oneflow/user/kernels/random_mask_generator.h" #include "oneflow/core/framework/framework.h" namespace oneflow { class FusedDropoutKernelState : public user_op::OpKernelState { public: explicit FusedDropoutKernelState(const std::shared_ptr& generator) : generator_(generator) {} const std::shared_ptr& generator() const { return generator_; } private: std::shared_ptr generator_; }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_DROPOUT_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/dynamic_loss_scale_schedule_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { class DynamicLossScaleScheduleCpuKernel final : public user_op::OpKernel { public: DynamicLossScaleScheduleCpuKernel() = default; ~DynamicLossScaleScheduleCpuKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const auto* count_not_finite = ctx->Tensor4ArgNameAndIndex("count_not_finite", 0)->dptr(); auto* loss_scale = ctx->Tensor4ArgNameAndIndex("loss_scale", 0)->mut_dptr(); auto* good_step_counter = ctx->Tensor4ArgNameAndIndex("good_step_counter", 0)->mut_dptr(); const auto increment_period = ctx->Attr("increment_period"); const auto multiplier = ctx->Attr("multiplier"); if (*count_not_finite == 0) { int64_t cur_good_step_counter = *good_step_counter + 1; if (cur_good_step_counter >= increment_period) { const double old_loss_scale = *loss_scale; const double new_loss_scale = std::min(old_loss_scale * multiplier, static_cast(FLT_MAX)); *loss_scale = static_cast(new_loss_scale); cur_good_step_counter = 0; LOG(INFO) << "In past " << increment_period << " steps, there are no nan or inf in gradients, so we increase loss_scale from " << old_loss_scale << " to " << new_loss_scale; } *good_step_counter = cur_good_step_counter; } else { *good_step_counter = 0; const double old_loss_scale = *loss_scale; const double new_loss_scale = std::max(old_loss_scale / multiplier, 1.0); *loss_scale = static_cast(new_loss_scale); LOG(INFO) << "There are nan or inf in gradients, so we decrease loss_scale from " << old_loss_scale << " to " << new_loss_scale; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("dynamic_loss_scale_schedule") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/dynamic_loss_scale_schedule_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { __global__ void DynamicLossScaleScheduleGpu(const int64_t increment_period, const float multiplier, const int64_t* count_not_finite, float* loss_scale, int64_t* good_step_counter) { if (*count_not_finite == 0) { int64_t cur_good_step_counter = *good_step_counter + 1; if (cur_good_step_counter >= increment_period) { *loss_scale = static_cast( min(static_cast(*loss_scale) * multiplier, static_cast(FLT_MAX))); cur_good_step_counter = 0; } *good_step_counter = cur_good_step_counter; } else { *good_step_counter = 0; *loss_scale = static_cast(max(static_cast(*loss_scale) / multiplier, 1.0)); } } } // namespace class DynamicLossScaleScheduleGpuKernel final : public user_op::OpKernel { public: DynamicLossScaleScheduleGpuKernel() = default; ~DynamicLossScaleScheduleGpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* count_not_finite = ctx->Tensor4ArgNameAndIndex("count_not_finite", 0); user_op::Tensor* loss_scale = ctx->Tensor4ArgNameAndIndex("loss_scale", 0); user_op::Tensor* good_step_counter = ctx->Tensor4ArgNameAndIndex("good_step_counter", 0); const auto increment_period = ctx->Attr("increment_period"); const auto multiplier = ctx->Attr("multiplier"); DynamicLossScaleScheduleGpu<<<1, 1, 0, ctx->stream()->As()->cuda_stream()>>>( increment_period, multiplier, count_not_finite->dptr(), loss_scale->mut_dptr(), good_step_counter->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("dynamic_loss_scale_schedule") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eager_b_to_s_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/communicate_util.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_copier.h" #include "oneflow/core/framework/placement_sbp_util.h" namespace oneflow { namespace { Maybe> GetAllSplitNdSbp(int64_t axis, int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal); Maybe> GetAllBroadcastNdSbp(int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal); class EagerBToSOpKernelCache final : public user_op::OpKernelCache { public: explicit EagerBToSOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } ~EagerBToSOpKernelCache() override = default; const std::vector>>& sorted_elem_cnt2in_tensor_slice_copier_pair() const { return sorted_elem_cnt2in_tensor_slice_copier_pair_; } const std::vector>>& sorted_elem_cnt2out_tensor_slice_copier_pair() const { return sorted_elem_cnt2out_tensor_slice_copier_pair_; } const std::vector>& sorted_p2p_pair() const { return sorted_p2p_pair_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); const Shape& shape = ctx->Attr("shape"); DeviceType device_type = ctx->device_type(); DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); Symbol in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t out_parallel_num = out_parallel_desc->parallel_num(); for (int64_t out_parallel_id = 0; out_parallel_id < out_parallel_num; ++out_parallel_id) { int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id)); int64_t src = -1; const TensorSliceView& out_slice = GetTensorSliceView4ParallelId( *out_parallel_desc->hierarchy(), *CHECK_JUST( CachedGetAllSplitNdSbp(out_split_axis, out_parallel_desc->hierarchy()->NumAxes())), shape, out_parallel_id); CHECK(!out_slice.IsEmpty()); TensorSliceView in_slice; TensorSliceView intersection; { if (in_parallel_desc->ContainingMachineId(dst)) { src = dst; int64_t src_device_id = GlobalProcessCtx::LocalRank(src); int64_t in_parallel_id = CHECK_JUST(in_parallel_desc->ParallelId4MachineDeviceId(src, src_device_id)); in_slice = GetTensorSliceView4ParallelId( *in_parallel_desc->hierarchy(), *CHECK_JUST(CachedGetAllBroadcastNdSbp(in_parallel_desc->hierarchy()->NumAxes())), shape, in_parallel_id); // copy to out_slice from in_slice if src == dst intersection = out_slice; } else { int64_t in_parallel_num = in_parallel_desc->parallel_num(); int64_t in_parallel_id = out_parallel_id % in_parallel_num; src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id)); in_slice = 
GetTensorSliceView4ParallelId( *in_parallel_desc->hierarchy(), *CHECK_JUST(GetAllBroadcastNdSbp(in_parallel_desc->hierarchy()->NumAxes())), shape, in_parallel_id); intersection = out_slice.Intersect(in_slice); } } CHECK_NE(src, -1); CHECK(!in_slice.IsEmpty()); CHECK(!intersection.IsEmpty()); sorted_p2p_pair_.emplace_back(std::make_pair(src, dst)); sorted_elem_cnt2in_tensor_slice_copier_pair_.emplace_back(std::make_pair( intersection.shape().elem_cnt(), std::make_shared(intersection, in_slice, data_type, device_type))); sorted_elem_cnt2out_tensor_slice_copier_pair_.emplace_back(std::make_pair( intersection.shape().elem_cnt(), std::make_shared(out_slice, intersection, data_type, device_type))); } } std::vector>> sorted_elem_cnt2in_tensor_slice_copier_pair_; std::vector>> sorted_elem_cnt2out_tensor_slice_copier_pair_; std::vector> sorted_p2p_pair_; }; size_t InferEagerBToSKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); Shape shape = ctx->Attr("shape"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t out_parallel_num = out_parallel_desc->parallel_num(); if (out_parallel_num > 1) { CHECK_LT(out_split_axis, shape.NumAxes()); BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num); shape.Set(out_split_axis, bs.At(0).size()); } size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type()); return tensor_byte_size; } } // namespace class EagerBToSKernel final : public user_op::OpKernel { public: EagerBToSKernel() = default; ~EagerBToSKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const void* in_ptr = in->dptr(); void* out_ptr = out->mut_dptr(); void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const auto& sorted_elem_cnt2in_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair(); const auto& sorted_elem_cnt2out_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair(); const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); DeviceType device_type = ctx->device_type(); for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) { const auto& p2p_pair = sorted_p2p_pair.at(i); int64_t src = p2p_pair.first; int64_t dst = p2p_pair.second; if (src == dst && src == GlobalProcessCtx::Rank()) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2in_tensor_slice_copier_pair.at(i); const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), out_ptr, in_ptr); continue; } if (GlobalProcessCtx::Rank() == src) { const auto& elem_cnt2tensor_slice_copier_pair = 
sorted_elem_cnt2in_tensor_slice_copier_pair.at(i); const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, in->data_type(), dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2out_tensor_slice_copier_pair.at(i); const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; CHECK_JUST( Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream())); tensor_slice_copier->Copy(ctx->stream(), out_ptr, reinterpret_cast(tmp_buffer_ptr)); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_b_to_s") .SetCreateFn() .SetIsMatchedHob(HobIsSendAndRecvRegistered()) .SetInferTmpSizeFn(InferEagerBToSKernelTmpBufferSize); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eager_ccl_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/communicate_util.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/user/kernels/collective_communication/include/communication_context.h" #include "oneflow/user/kernels/collective_communication/include/all_reduce.h" #include "oneflow/user/kernels/collective_communication/include/reduce_scatter.h" #include "oneflow/user/kernels/collective_communication/include/all_gather.h" #include "oneflow/user/kernels/collective_communication/include/reduce.h" #include "oneflow/user/kernels/collective_communication/include/broadcast.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { auto AllReduceCollectiveCommunicationExists() { return hob::make_custom("AllReduceCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsAllReduceRegistered(device_type); }); } auto ReduceScatterCollectiveCommunicationExists() { return hob::make_custom("ReduceScatterCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsReduceScatterRegistered(device_type); }); } auto AllGatherCollectiveCommunicationExists() { return hob::make_custom("AllGatherCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsAllGatherRegistered(device_type); }); } auto ReduceCollectiveCommunicationExists() { return hob::make_custom("ReduceCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsReduceRegistered(device_type); }); } auto BroadcastCollectiveCommunicationExists() { return hob::make_custom("BroadcastCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsBroadcastRegistered(device_type); }); } class EagerCclOpKernelCache final : public user_op::OpKernelCache { public: explicit EagerCclOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } ~EagerCclOpKernelCache() override = default; const std::shared_ptr& communication_ctx() const { return communication_ctx_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& parallel_conf_txt = ctx->Attr("parallel_conf"); ParallelConf parallel_conf; CHECK(TxtString2PbMessage(parallel_conf_txt, ¶llel_conf)); Symbol parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); communication_ctx_ = ccl::NewCommunicationContext(parallel_desc->device_type(), parallel_desc); } std::shared_ptr communication_ctx_; }; void InitEagerCclOpKernelCache(user_op::KernelCacheContext* ctx, std::shared_ptr* cache_ptr) { // NOTE(jianhao): the cache only depends on parallel_conf, and the kernel is singleton // once parallel_conf is determined, so only init the cache at the first time. 
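  // Concretely: InitOpKernelCacheWithFlags() can be invoked on every launch, so the null check
  // below turns the parallel_conf-derived communication context into a build-once,
  // reuse-afterwards cache for this kernel instance.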
if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } } // namespace class EagerCclAllReduceKernel final : public user_op::OpKernel { public: EagerCclAllReduceKernel() = default; ~EagerCclAllReduceKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { InitEagerCclOpKernelCache(ctx, cache_ptr); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()) << kOfBugIssueUploadPrompt; CHECK_EQ(in->data_type(), out->data_type()) << kOfBugIssueUploadPrompt; ccl::ReduceType reduce_type = ccl::kSum; if (in->data_type() == kBool) { reduce_type = ccl::kMax; } std::unique_ptr all_reduce = ccl::NewCollectiveCommunication( ctx->device_type(), in->data_type(), reduce_type); all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), kernel_cache->communication_ctx()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_ccl_all_reduce") .SetCreateFn() .SetIsMatchedHob(AllReduceCollectiveCommunicationExists()); class EagerCclReduceScatterKernel final : public user_op::OpKernel { public: EagerCclReduceScatterKernel() = default; ~EagerCclReduceScatterKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { InitEagerCclOpKernelCache(ctx, cache_ptr); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr) << kOfBugIssueUploadPrompt; const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()) << kOfBugIssueUploadPrompt; const auto& op_type = ctx->Attr("op_type"); CHECK_EQ(op_type, "sum") << kOfBugIssueUploadPrompt; ccl::ReduceType reduce_type = ccl::kSum; if (in->data_type() == kBool) { reduce_type = ccl::kMax; } std::unique_ptr reduce_scatter = ccl::NewCollectiveCommunication(ctx->device_type(), in->data_type(), reduce_type); reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), kernel_cache->communication_ctx()); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_ccl_reduce_scatter") .SetCreateFn() .SetIsMatchedHob(ReduceScatterCollectiveCommunicationExists()); class EagerCclAllGatherKernel final : public user_op::OpKernel { public: EagerCclAllGatherKernel() = default; ~EagerCclAllGatherKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { InitEagerCclOpKernelCache(ctx, cache_ptr); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr) << kOfBugIssueUploadPrompt; const user_op::Tensor* in = 
ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()) << kOfBugIssueUploadPrompt; std::unique_ptr all_gather = ccl::NewCollectiveCommunication(ctx->device_type(), in->data_type()); all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), kernel_cache->communication_ctx()); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_ccl_all_gather") .SetCreateFn() .SetIsMatchedHob(AllGatherCollectiveCommunicationExists()); class EagerCclReduceKernel final : public user_op::OpKernel { public: EagerCclReduceKernel() = default; ~EagerCclReduceKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { InitEagerCclOpKernelCache(ctx, cache_ptr); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t root = ctx->Attr("root"); void* out_ptr = out->mut_dptr(); if (GlobalProcessCtx::Rank() == root) { CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); } if (out_ptr != nullptr) { CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); } ccl::ReduceType reduce_type = ccl::kSum; if (in->data_type() == kBool) { reduce_type = ccl::kMax; } std::unique_ptr reduce = ccl::NewCollectiveCommunication( ctx->device_type(), in->data_type(), reduce_type); reduce->Launch(ctx->stream(), in->dptr(), out_ptr, in->shape_view().elem_cnt(), root, kernel_cache->communication_ctx()); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_ccl_reduce") .SetCreateFn() .SetIsMatchedHob(ReduceCollectiveCommunicationExists()); class EagerCclBroadcastKernel final : public user_op::OpKernel { public: EagerCclBroadcastKernel() = default; ~EagerCclBroadcastKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { InitEagerCclOpKernelCache(ctx, cache_ptr); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { size_t size = ctx->input_size("in"); CHECK_EQ(size, ctx->output_size("out")); for (int i = 0; i < size; ++i) { ComputeForOneInput(ctx, cache, i); } } void ComputeForOneInput(user_op::KernelComputeContext* ctx, const user_op::OpKernelCache* cache, int index) const { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", index); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", index); int64_t root = ctx->Attr("root"); const void* in_ptr = in->dptr(); if (GlobalProcessCtx::Rank() == root) { CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); } if (in_ptr != nullptr) { CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); } std::unique_ptr broadcast = ccl::NewCollectiveCommunication(ctx->device_type(), out->data_type()); broadcast->Launch(ctx->stream(), in_ptr, 
out->mut_dptr(), out->shape_view().elem_cnt(), root, kernel_cache->communication_ctx()); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_ccl_broadcast") .SetCreateFn() .SetIsMatchedHob(BroadcastCollectiveCommunicationExists()); class EagerCclTouchKernel final : public user_op::OpKernel { public: EagerCclTouchKernel() = default; ~EagerCclTouchKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override{ // Do nothing. }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("eager_ccl_touch") .SetCreateFn() .SetIsMatchedHob(!(user_op::HobDeviceType() == DeviceType::kInvalidDevice) && !(user_op::HobDeviceType() == DeviceType::kMockDevice)); namespace { class EagerCclS2SCpuOpKernelCache final : public user_op::OpKernelCache { public: explicit EagerCclS2SCpuOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } ~EagerCclS2SCpuOpKernelCache() override = default; Symbol parallel_desc() const { return parallel_desc_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& parallel_conf_txt = ctx->Attr("parallel_conf"); ParallelConf parallel_conf; CHECK(TxtString2PbMessage(parallel_conf_txt, ¶llel_conf)); parallel_desc_ = SymbolOf(ParallelDesc(parallel_conf)); } Symbol parallel_desc_; }; size_t InferEagerCclS2SCpuKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); size_t tensor_byte_size = in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type()); // NOTE(hanbinbin): Set tmp_buffer_size to twice tensor_byte_size because the // SbpParallel4ArgNameAndIndex function of LocalUserOpInferContext is unimplemented return tensor_byte_size * 2; } Maybe>> RawGroupP2PPair( Symbol parallel_desc) { std::shared_ptr>> p2p_pairs = std::make_shared>>(); for (int64_t src : parallel_desc->sorted_machine_ids()) { for (int64_t dst : parallel_desc->sorted_machine_ids()) { p2p_pairs->emplace_back(std::make_pair(src, dst)); } } return p2p_pairs; } static constexpr auto* GroupP2PPair = DECORATE(&RawGroupP2PPair, ThreadLocal); } // namespace template class EagerCclS2SCPUKernel final : public user_op::OpKernel { public: EagerCclS2SCPUKernel() = default; ~EagerCclS2SCPUKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { // NOTE(jianhao): the cache only depends on parallel_conf, and the kernel is singleton // once parallel_conf is determined, so only init the cache at the first time. 
if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); // NOTE(hanbinbin): Compute logic copy from _nccl_logical_s2s const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t dtype_size = GetSizeOfDataType(in->data_type()); int64_t data_size = in->shape_view().elem_cnt() * dtype_size; // NOTE: in (transpose)-> pack_to_ptr (all2all)-> unpack_from_ptr (transpose)-> out const char* pack_to_ptr = in->dptr(); char* unpack_from_ptr = out->mut_dptr(); int64_t tmp_size = tmp_buffer->shape_view().elem_cnt(); CHECK_EQ(tmp_size, data_size * 2); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = kernel_cache->parallel_desc()->parallel_num(); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt()) << in->shape_view().ToString() << " vs " << out->shape_view().ToString(); const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t in_split_axis = ctx->Attr("in_split_axis"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); DimVector logical_shape_dim_vec; in->shape_view().ToDimVector(&logical_shape_dim_vec); logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks; if (out_split_axis != 0) { // Do pack. Need transpose in -> pack_to // pack use temp buffer offset: [0, data_size] pack_to_ptr = tmp_buffer->dptr(); DimVector transpose_in_dim_vec = logical_shape_dim_vec; CHECK_EQ(transpose_in_dim_vec.at(in_split_axis) % num_ranks, 0); transpose_in_dim_vec[in_split_axis] = transpose_in_dim_vec.at(in_split_axis) / num_ranks; CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0); transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks); std::vector perm; perm.emplace_back(out_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != out_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr()); } if (in_split_axis != 0) { // Do unpack. 
Need transpose unpack_from -> out // unpack use temp buffer offset: [tmp_size - data_size, tmp_size] unpack_from_ptr = tmp_buffer->mut_dptr() + (tmp_size - data_size); } { // NOTE: Do S2S const int64_t elem_per_chunk = elem_cnt / num_ranks; const int64_t chunk_size = elem_per_chunk * dtype_size; const auto& p2p_pairs = CHECK_JUST(GroupP2PPair(kernel_cache->parallel_desc())); for (const auto& pair : *p2p_pairs) { int64_t src = pair.first; int64_t dst = pair.second; if (GlobalProcessCtx::Rank() == src) { Symbol parallel_desc = kernel_cache->parallel_desc(); int64_t device_id = GlobalProcessCtx::LocalRank(dst); int64_t parallel_id = CHECK_JUST(parallel_desc->ParallelId4MachineDeviceId(dst, device_id)); CHECK_JUST(Send(reinterpret_cast(reinterpret_cast(pack_to_ptr) + parallel_id * chunk_size), elem_per_chunk, in->data_type(), dst, DeviceType::kCPU, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { Symbol parallel_desc = kernel_cache->parallel_desc(); int64_t device_id = GlobalProcessCtx::LocalRank(src); int64_t parallel_id = CHECK_JUST(parallel_desc->ParallelId4MachineDeviceId(src, device_id)); CHECK_JUST(Recv(reinterpret_cast(reinterpret_cast(unpack_from_ptr) + parallel_id * chunk_size), elem_per_chunk, out->data_type(), src, DeviceType::kCPU, ctx->stream())); } } } if (in_split_axis != 0) { // Do unpack. CHECK(unpack_from_ptr != out->mut_dptr()); DimVector unpack_from_dim_vec = logical_shape_dim_vec; CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0); unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks; CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0); unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks; unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_EAGER_CCL_S2S_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("eager_ccl_s2s") \ .SetCreateFn>() \ .SetIsMatchedHob(!(user_op::HobDeviceType() == DeviceType::kCUDA) \ && HobIsSendAndRecvRegistered() \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferEagerCclS2SCpuKernelTmpBufferSize); REGISTER_EAGER_CCL_S2S_CPU_KERNEL(int8_t) REGISTER_EAGER_CCL_S2S_CPU_KERNEL(int32_t) REGISTER_EAGER_CCL_S2S_CPU_KERNEL(int64_t) REGISTER_EAGER_CCL_S2S_CPU_KERNEL(bool) REGISTER_EAGER_CCL_S2S_CPU_KERNEL(float) REGISTER_EAGER_CCL_S2S_CPU_KERNEL(double) #undef REGISTER_EAGER_CCL_S2S_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eager_nccl_s2s_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/device/nccl_util.h"
#include "oneflow/core/job/eager_nccl_comm_manager.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/user/kernels/collective_communication/include/all_to_all.h"

#if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU)

namespace oneflow {

namespace {

class EagerCclS2SOpKernelCache final : public user_op::OpKernelCache {
 public:
  explicit EagerCclS2SOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }
  ~EagerCclS2SOpKernelCache() override = default;

  Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }
  const ccl::CclComm& ccl_comm() const { return ccl_comm_; }

 private:
  void Init(user_op::KernelCacheContext* ctx) {
    const std::string& parallel_conf_txt = ctx->Attr<std::string>("parallel_conf");
    ParallelConf parallel_conf;
    CHECK(TxtString2PbMessage(parallel_conf_txt, &parallel_conf));
    parallel_desc_ = SymbolOf(ParallelDesc(parallel_conf));
    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
    ccl_comm_ = comm_mgr->GetCclCommForParallelDesc(parallel_conf);
  }

  Symbol<ParallelDesc> parallel_desc_;
  ccl::CclComm ccl_comm_{};
};

size_t InferEagerCclS2SKernelTmpBufferSize(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
  size_t tensor_byte_size =
      GetCudaAlignedSize(in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type()));
  // NOTE(hanbinbin): Set tmp_buffer_size to twice tensor_byte_size because the
  // SbpParallel4ArgNameAndIndex function of LocalUserOpInferContext is unimplemented
  return tensor_byte_size * 2;
}

void InitEagerCclS2SOpKernelCache(user_op::KernelCacheContext* ctx,
                                  std::shared_ptr<user_op::OpKernelCache>* cache_ptr) {
  // NOTE(jianhao): the cache only depends on parallel_conf, and the kernel is singleton
  // once parallel_conf is determined, so only init the cache at the first time.
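  // The guard below creates the cache lazily and only once: it captures the parallel_desc and the
  // ccl::CclComm resolved through EagerCclCommMgr, so every later launch of the S2S kernel reuses
  // the same communicator. The matching tmp buffer (InferEagerCclS2SKernelTmpBufferSize above)
  // reserves room for two GetCudaAlignedSize()-aligned copies of the local tensor: one region for
  // the packed (transposed) input and one for the all-to-all receive area.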
if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } } // namespace template class EagerCclS2SKernel final : public user_op::OpKernel { public: EagerCclS2SKernel() = default; ~EagerCclS2SKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { InitEagerCclS2SOpKernelCache(ctx, cache_ptr); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); // NOTE(hanbinbin): Compute logic copy from _nccl_logical_s2s const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int64_t tmp_size = 0; const int64_t dtype_size = GetSizeOfDataType(in->data_type()); int64_t data_size = GetCudaAlignedSize(in->shape_view().elem_cnt() * dtype_size); // NOTE(chengcheng): in (transpose)-> pack_to_ptr (all2all)-> unpack_from_ptr (transpose)-> out const char* pack_to_ptr = in->dptr(); char* unpack_from_ptr = out->mut_dptr(); if (tmp_buffer) { tmp_size = tmp_buffer->shape_view().elem_cnt(); } CHECK(tmp_size == 0 || tmp_size == data_size || tmp_size == data_size * 2); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = kernel_cache->parallel_desc()->parallel_num(); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt()) << in->shape_view().ToString() << " vs " << out->shape_view().ToString(); const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t in_split_axis = ctx->Attr("in_split_axis"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); DimVector logical_shape_dim_vec; in->shape_view().ToDimVector(&logical_shape_dim_vec); logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks; if (out_split_axis != 0) { // NOTE(chengcheng): Do pack. Need transpose in -> pack_to // pack use temp buffer offset: [0, data_size] pack_to_ptr = tmp_buffer->dptr(); DimVector transpose_in_dim_vec = logical_shape_dim_vec; CHECK_EQ(transpose_in_dim_vec.at(in_split_axis) % num_ranks, 0); transpose_in_dim_vec[in_split_axis] = transpose_in_dim_vec.at(in_split_axis) / num_ranks; CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0); transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks); std::vector perm; perm.emplace_back(out_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != out_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr()); } if (in_split_axis != 0) { // NOTE(chengcheng): Do unpack. 
Need transpose unpack_from -> out // unpack use temp buffer offset: [tmp_size - data_size, tmp_size] unpack_from_ptr = tmp_buffer->mut_dptr() + (tmp_size - data_size); } { // NOTE: Do S2S const int64_t elem_per_chunk = elem_cnt / num_ranks; std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); const auto& ccl_comm = kernel_cache->ccl_comm(); all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { // Do unpack. CHECK(unpack_from_ptr != out->mut_dptr()); DimVector unpack_from_dim_vec = logical_shape_dim_vec; CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0); unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks; CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0); unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks; unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_EAGER_CCL_S2S_KERNEL(dtype) \ REGISTER_USER_KERNEL("eager_ccl_s2s") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferEagerCclS2SKernelTmpBufferSize); REGISTER_EAGER_CCL_S2S_KERNEL(int8_t) REGISTER_EAGER_CCL_S2S_KERNEL(int32_t) REGISTER_EAGER_CCL_S2S_KERNEL(int64_t) REGISTER_EAGER_CCL_S2S_KERNEL(bool) REGISTER_EAGER_CCL_S2S_KERNEL(float) REGISTER_EAGER_CCL_S2S_KERNEL(double) REGISTER_EAGER_CCL_S2S_KERNEL(float16) #undef REGISTER_EAGER_CCL_S2S_KERNEL } // namespace oneflow #endif // WITH_CUDA || WITH_NPU ================================================ FILE: oneflow/user/kernels/eager_p_to_b_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/communicate_util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { namespace { class EagerPToBOpKernelCache final : public user_op::OpKernelCache { public: explicit EagerPToBOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } ~EagerPToBOpKernelCache() override = default; const std::vector>& p2p_pair() const { return p2p_pair_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); Symbol in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t out_parallel_num = out_parallel_desc->parallel_num(); int64_t in_parallel_num = in_parallel_desc->parallel_num(); for (int64_t out_parallel_id = 0; out_parallel_id < out_parallel_num; ++out_parallel_id) { int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id)); for (int64_t in_parallel_id = 0; in_parallel_id < in_parallel_num; ++in_parallel_id) { int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id)); p2p_pair_.emplace_back(std::make_pair(src, dst)); } } } std::vector> p2p_pair_; }; size_t InferEagerPToBKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); const Shape& shape = ctx->Attr("shape"); size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type()); return tensor_byte_size; } } // namespace class EagerPToBKernel final : public user_op::OpKernel { public: EagerPToBKernel() = default; ~EagerPToBKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const void* in_ptr = in->dptr(); void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const int64_t total_elem_cnt = ctx->Attr("shape").elem_cnt(); const auto& p2p_pair = kernel_cache->p2p_pair(); DeviceType device_type = ctx->device_type(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(device_type); CHECK(memset_primitive) << "Can not create Memset primitive for device type " << device_type; memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, total_elem_cnt * GetSizeOfDataType(out->data_type())); std::unique_ptr add_primitive = ep::primitive::NewPrimitive(ctx->device_type(), in->data_type()); CHECK(add_primitive); for (const auto& pair : p2p_pair) { int64_t src = pair.first; int64_t dst = pair.second; if (GlobalProcessCtx::Rank() == src) { 
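        // Partial-sum -> broadcast: every source rank ships its full local partial tensor to every
        // destination rank. On the source side (this branch) the tensor is sent as-is; the
        // destination branch below receives it into tmp_buffer and accumulates it into the
        // zero-initialized output with the Add primitive.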
CHECK_JUST(Send(in_ptr, total_elem_cnt, in->data_type(), dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { CHECK_JUST(Recv(tmp_buffer_ptr, total_elem_cnt, out->data_type(), src, device_type, ctx->stream())); add_primitive->Launch(ctx->stream(), out->dptr(), tmp_buffer_ptr, out->mut_dptr(), total_elem_cnt); } } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_p_to_b") .SetCreateFn() .SetIsMatchedHob(HobIsSendAndRecvRegistered()) .SetInferTmpSizeFn(InferEagerPToBKernelTmpBufferSize); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eager_p_to_s_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/communicate_util.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_copier.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { namespace { Maybe> GetAllSplitNdSbp(int64_t axis, int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal); Maybe> GetAllPartialSumNdSbp(int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel(); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocal); class EagerPToSOpKernelCache final : public user_op::OpKernelCache { public: explicit EagerPToSOpKernelCache(user_op::KernelCacheContext* ctx) : elem_cnt_of_this_chunk_(0) { Init(ctx); } ~EagerPToSOpKernelCache() override = default; int64_t elem_cnt_of_this_chunk() const { return elem_cnt_of_this_chunk_; } const std::vector>>& sorted_elem_cnt2_in_tensor_slice_copier() const { return sorted_elem_cnt2_in_tensor_slice_copier_; } const std::vector>& sorted_p2p_pair() const { return sorted_p2p_pair_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); const Shape& shape = ctx->Attr("shape"); DeviceType device_type = ctx->device_type(); DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); Symbol in_parallel_desc = 
CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t out_parallel_num = out_parallel_desc->parallel_num(); int64_t in_parallel_num = in_parallel_desc->parallel_num(); elem_cnt_of_this_chunk_ = 0; for (int64_t out_parallel_id = 0; out_parallel_id < out_parallel_num; ++out_parallel_id) { int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id)); const TensorSliceView& out_slice = GetTensorSliceView4ParallelId( *out_parallel_desc->hierarchy(), *CHECK_JUST( CachedGetAllSplitNdSbp(out_split_axis, out_parallel_desc->hierarchy()->NumAxes())), shape, out_parallel_id); CHECK(!out_slice.IsEmpty()); for (int64_t in_parallel_id = 0; in_parallel_id < in_parallel_num; ++in_parallel_id) { int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id)); const TensorSliceView& in_slice = GetTensorSliceView4ParallelId( *in_parallel_desc->hierarchy(), *CHECK_JUST(CachedGetAllPartialSumNdSbp(in_parallel_desc->hierarchy()->NumAxes())), shape, in_parallel_id); CHECK(!in_slice.IsEmpty()); const TensorSliceView& intersection = out_slice.Intersect(in_slice); CHECK(!intersection.IsEmpty()); sorted_p2p_pair_.emplace_back(std::make_pair(src, dst)); sorted_elem_cnt2_in_tensor_slice_copier_.emplace_back(std::make_pair( intersection.shape().elem_cnt(), std::make_shared(intersection, in_slice, data_type, device_type))); } if (GlobalProcessCtx::Rank() == dst) { elem_cnt_of_this_chunk_ = sorted_elem_cnt2_in_tensor_slice_copier_.back().first; } } } int64_t elem_cnt_of_this_chunk_; std::vector>> sorted_elem_cnt2_in_tensor_slice_copier_; std::vector> sorted_p2p_pair_; }; size_t InferEagerPToSKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); Shape shape = ctx->Attr("shape"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t out_parallel_num = out_parallel_desc->parallel_num(); if (out_parallel_num > 1) { CHECK_LT(out_split_axis, shape.NumAxes()); BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num); shape.Set(out_split_axis, bs.At(0).size()); } size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type()); return tensor_byte_size; } } // namespace class EagerPToSKernel final : public user_op::OpKernel { public: EagerPToSKernel() = default; ~EagerPToSKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const void* in_ptr = in->dptr(); void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); int64_t elem_cnt_of_this_chunk = kernel_cache->elem_cnt_of_this_chunk(); const auto& sorted_elem_cnt2_in_tensor_slice_copier = kernel_cache->sorted_elem_cnt2_in_tensor_slice_copier(); const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); 
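    // Partial-sum -> split: for every (src, dst) pair prepared in the cache, the source rank uses
    // the cached TensorSliceCopier to cut dst's output slice out of its local partial tensor,
    // stages it in tmp_buffer and sends it; the destination rank receives the slice and adds it
    // into its zero-initialized output chunk.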
CHECK_EQ(sorted_elem_cnt2_in_tensor_slice_copier.size(), sorted_p2p_pair.size()); DeviceType device_type = ctx->device_type(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(device_type); CHECK(memset_primitive) << "Can not create Memset primitive for device type " << device_type; memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, elem_cnt_of_this_chunk * GetSizeOfDataType(out->data_type())); std::unique_ptr add_primitive = ep::primitive::NewPrimitive(ctx->device_type(), in->data_type()); CHECK(add_primitive); for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) { const auto& p2p_pair = sorted_p2p_pair.at(i); int64_t src = p2p_pair.first; int64_t dst = p2p_pair.second; if (GlobalProcessCtx::Rank() == src) { const auto& tensor_slice_copier = sorted_elem_cnt2_in_tensor_slice_copier.at(i).second; int64_t send_elem_cnt = sorted_elem_cnt2_in_tensor_slice_copier.at(i).first; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), send_elem_cnt, in->data_type(), dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { CHECK_JUST(Recv(tmp_buffer_ptr, elem_cnt_of_this_chunk, out->data_type(), src, device_type, ctx->stream())); add_primitive->Launch(ctx->stream(), out->dptr(), tmp_buffer_ptr, out->mut_dptr(), elem_cnt_of_this_chunk); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_p_to_s") .SetCreateFn() .SetIsMatchedHob(HobIsSendAndRecvRegistered()) .SetInferTmpSizeFn(InferEagerPToSKernelTmpBufferSize); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eager_s_to_b_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/communicate_util.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_copier.h" namespace oneflow { namespace { Maybe> GetAllSplitNdSbp(int64_t axis, int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal); Maybe> GetAllBroadcastNdSbp(int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal); class EagerSToBOpKernelCache final : public user_op::OpKernelCache { public: explicit EagerSToBOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } ~EagerSToBOpKernelCache() override = default; const std::vector>>& sorted_elem_cnt2in_tensor_slice_copier_pair() const { return sorted_elem_cnt2in_tensor_slice_copier_pair_; } const std::vector>>& sorted_elem_cnt2out_tensor_slice_copier_pair() const { return sorted_elem_cnt2out_tensor_slice_copier_pair_; } const std::vector>& sorted_p2p_pair() const { return sorted_p2p_pair_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t in_split_axis = ctx->Attr("in_split_axis"); const Shape& shape = ctx->Attr("shape"); DeviceType device_type = ctx->device_type(); DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); Symbol in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t out_parallel_num = out_parallel_desc->parallel_num(); int64_t in_parallel_num = in_parallel_desc->parallel_num(); for (int64_t out_parallel_id = 0; out_parallel_id < out_parallel_num; ++out_parallel_id) { int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id)); const TensorSliceView& out_slice = GetTensorSliceView4ParallelId( *out_parallel_desc->hierarchy(), *CHECK_JUST(CachedGetAllBroadcastNdSbp(out_parallel_desc->hierarchy()->NumAxes())), shape, out_parallel_id); CHECK(!out_slice.IsEmpty()); for (int64_t in_parallel_id = 0; in_parallel_id < in_parallel_num; ++in_parallel_id) { int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id)); const TensorSliceView& in_slice = GetTensorSliceView4ParallelId( *in_parallel_desc->hierarchy(), *CHECK_JUST( CachedGetAllSplitNdSbp(in_split_axis, in_parallel_desc->hierarchy()->NumAxes())), shape, in_parallel_id); CHECK(!in_slice.IsEmpty()); const TensorSliceView& intersection = out_slice.Intersect(in_slice); CHECK(!intersection.IsEmpty()); sorted_p2p_pair_.emplace_back(std::make_pair(src, dst)); sorted_elem_cnt2in_tensor_slice_copier_pair_.emplace_back(std::make_pair( intersection.shape().elem_cnt(), std::make_shared(intersection, in_slice, 
data_type, device_type))); sorted_elem_cnt2out_tensor_slice_copier_pair_.emplace_back(std::make_pair( intersection.shape().elem_cnt(), std::make_shared(out_slice, intersection, data_type, device_type))); } } } std::vector>> sorted_elem_cnt2in_tensor_slice_copier_pair_; std::vector>> sorted_elem_cnt2out_tensor_slice_copier_pair_; std::vector> sorted_p2p_pair_; }; size_t InferEagerSToBKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); Shape shape = ctx->Attr("shape"); const int64_t in_split_axis = ctx->Attr("in_split_axis"); const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); Symbol in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); int64_t in_parallel_num = in_parallel_desc->parallel_num(); if (in_parallel_num > 1) { CHECK_LT(in_split_axis, shape.NumAxes()); BalancedSplitter bs(shape.At(in_split_axis), in_parallel_num); shape.Set(in_split_axis, bs.At(0).size()); } size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type()); return tensor_byte_size; } } // namespace class EagerSToBKernel final : public user_op::OpKernel { public: EagerSToBKernel() = default; ~EagerSToBKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const void* in_ptr = in->dptr(); void* out_ptr = out->mut_dptr(); void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const auto& sorted_elem_cnt2in_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair(); const auto& sorted_elem_cnt2out_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair(); const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); DeviceType device_type = ctx->device_type(); for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) { const auto& p2p_pair = sorted_p2p_pair.at(i); int64_t src = p2p_pair.first; int64_t dst = p2p_pair.second; if (GlobalProcessCtx::Rank() == src) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2in_tensor_slice_copier_pair.at(i); const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, in->data_type(), dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2out_tensor_slice_copier_pair.at(i); const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; CHECK_JUST( Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream())); tensor_slice_copier->Copy(ctx->stream(), out_ptr, 
reinterpret_cast(tmp_buffer_ptr)); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_s_to_b") .SetCreateFn() .SetIsMatchedHob(HobIsSendAndRecvRegistered()) .SetInferTmpSizeFn(InferEagerSToBKernelTmpBufferSize); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eager_s_to_p_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/communicate_util.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_copier.h" namespace oneflow { namespace { Maybe> GetAllSplitNdSbp(int64_t axis, int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal); Maybe> GetAllPartialSumNdSbp(int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel(); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocal); class EagerSToPOpKernelCache final : public user_op::OpKernelCache { public: explicit EagerSToPOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } ~EagerSToPOpKernelCache() override = default; const std::vector>>& sorted_elem_cnt2in_tensor_slice_copier_pair() const { return sorted_elem_cnt2in_tensor_slice_copier_pair_; } const std::vector>>& sorted_elem_cnt2out_tensor_slice_copier_pair() const { return sorted_elem_cnt2out_tensor_slice_copier_pair_; } const std::vector>& sorted_p2p_pair() const { return sorted_p2p_pair_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t in_split_axis = ctx->Attr("in_split_axis"); const Shape& shape = ctx->Attr("shape"); DeviceType device_type = ctx->device_type(); DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); Symbol in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t in_parallel_num = in_parallel_desc->parallel_num(); for (int64_t in_parallel_id = 0; in_parallel_id < in_parallel_num; ++in_parallel_id) { int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id)); 
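      // Split -> partial-sum routing, resolved per input slice below: if the output placement
      // contains this source machine, the slice stays local (dst == src) and is copied straight
      // into the local partial-sum output; otherwise it is routed to the rank at
      // out_parallel_id = in_parallel_id % out_parallel_num and only the slice intersection is sent.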
int64_t dst = -1; const TensorSliceView& in_slice = GetTensorSliceView4ParallelId( *in_parallel_desc->hierarchy(), *CHECK_JUST( CachedGetAllSplitNdSbp(in_split_axis, in_parallel_desc->hierarchy()->NumAxes())), shape, in_parallel_id); CHECK(!in_slice.IsEmpty()); TensorSliceView out_slice; TensorSliceView intersection; { if (out_parallel_desc->ContainingMachineId(src)) { dst = src; int64_t dst_device_id = GlobalProcessCtx::LocalRank(dst); int64_t out_parallel_id = CHECK_JUST(in_parallel_desc->ParallelId4MachineDeviceId(dst, dst_device_id)); out_slice = GetTensorSliceView4ParallelId( *out_parallel_desc->hierarchy(), *CHECK_JUST(CachedGetAllPartialSumNdSbp(out_parallel_desc->hierarchy()->NumAxes())), shape, out_parallel_id); // copy to out_slice from in_slice if src == dst intersection = out_slice; } else { int64_t out_parallel_num = out_parallel_desc->parallel_num(); int64_t out_parallel_id = in_parallel_id % out_parallel_num; dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id)); out_slice = GetTensorSliceView4ParallelId( *out_parallel_desc->hierarchy(), *CHECK_JUST(CachedGetAllPartialSumNdSbp(in_parallel_desc->hierarchy()->NumAxes())), shape, out_parallel_id); intersection = out_slice.Intersect(in_slice); } } CHECK_NE(dst, -1); CHECK(!out_slice.IsEmpty()); CHECK(!intersection.IsEmpty()); sorted_p2p_pair_.emplace_back(std::make_pair(src, dst)); sorted_elem_cnt2in_tensor_slice_copier_pair_.emplace_back(std::make_pair( intersection.shape().elem_cnt(), std::make_shared(intersection, in_slice, data_type, device_type))); sorted_elem_cnt2out_tensor_slice_copier_pair_.emplace_back(std::make_pair( intersection.shape().elem_cnt(), std::make_shared(out_slice, intersection, data_type, device_type))); } } std::vector>> sorted_elem_cnt2in_tensor_slice_copier_pair_; std::vector>> sorted_elem_cnt2out_tensor_slice_copier_pair_; std::vector> sorted_p2p_pair_; }; size_t InferEagerSToPKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); Shape shape = ctx->Attr("shape"); const int64_t in_split_axis = ctx->Attr("in_split_axis"); const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); Symbol in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); int64_t in_parallel_num = in_parallel_desc->parallel_num(); if (in_parallel_num > 1) { CHECK_LT(in_split_axis, shape.NumAxes()); BalancedSplitter bs(shape.At(in_split_axis), in_parallel_num); shape.Set(in_split_axis, bs.At(0).size()); } return shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type()); } } // namespace class EagerSToPKernel final : public user_op::OpKernel { public: EagerSToPKernel() = default; ~EagerSToPKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const void* in_ptr = in->dptr(); void* out_ptr = out->mut_dptr(); void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const int64_t total_elem_cnt = ctx->Attr("shape").elem_cnt(); DeviceType device_type = 
ctx->device_type(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(device_type); CHECK(memset_primitive) << "Can not create Memset primitive for device type " << device_type; memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, total_elem_cnt * GetSizeOfDataType(out->data_type())); const auto& sorted_elem_cnt2in_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair(); const auto& sorted_elem_cnt2out_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair(); const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) { const auto& p2p_pair = sorted_p2p_pair.at(i); int64_t src = p2p_pair.first; int64_t dst = p2p_pair.second; if (src == dst && src == GlobalProcessCtx::Rank()) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2in_tensor_slice_copier_pair.at(i); const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), out_ptr, in_ptr); continue; } if (GlobalProcessCtx::Rank() == src) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2in_tensor_slice_copier_pair.at(i); const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, in->data_type(), dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2out_tensor_slice_copier_pair.at(i); const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; CHECK_JUST( Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream())); tensor_slice_copier->Copy(ctx->stream(), out_ptr, reinterpret_cast(tmp_buffer_ptr)); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_s_to_p") .SetCreateFn() .SetIsMatchedHob(HobIsSendAndRecvRegistered()) .SetInferTmpSizeFn(InferEagerSToPKernelTmpBufferSize); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eager_s_to_s_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/communicate_util.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_copier.h" #include "oneflow/core/control/global_process_ctx.h" namespace oneflow { namespace { bool ContainsEmptySlice(const std::vector& slices) { return std::any_of(slices.cbegin(), slices.cend(), [](const TensorSliceView& slice) { return slice.IsEmpty(); }); } Maybe> GetAllSplitNdSbp(int64_t axis, int64_t ndim) { NdSbp split_nd_sbp; for (int64_t i = 0; i < ndim; ++i) { split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis); } return SymbolOf(split_nd_sbp); } auto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal); class EagerNaiveSToSOpKernelCache final : public user_op::OpKernelCache { public: explicit EagerNaiveSToSOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } ~EagerNaiveSToSOpKernelCache() override = default; const std::vector>>& sorted_elem_cnt2in_tensor_slice_copier_pair() const { return sorted_elem_cnt2in_tensor_slice_copier_pair_; } const std::vector>>& sorted_elem_cnt2out_tensor_slice_copier_pair() const { return sorted_elem_cnt2out_tensor_slice_copier_pair_; } const std::vector>& sorted_p2p_pair() const { return sorted_p2p_pair_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t in_split_axis = ctx->Attr("in_split_axis"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); const Shape& shape = ctx->Attr("shape"); DeviceType device_type = ctx->device_type(); DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); Symbol in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t in_parallel_num = in_parallel_desc->parallel_num(); int64_t out_parallel_num = out_parallel_desc->parallel_num(); const std::vector in_slices = GetTensorSliceView(*in_parallel_desc->hierarchy(), *CHECK_JUST(CachedGetAllSplitNdSbp( in_split_axis, in_parallel_desc->hierarchy()->NumAxes())), shape); CHECK(!ContainsEmptySlice(in_slices)); const std::vector out_slices = GetTensorSliceView(*out_parallel_desc->hierarchy(), *CHECK_JUST(CachedGetAllSplitNdSbp( out_split_axis, out_parallel_desc->hierarchy()->NumAxes())), shape); CHECK(!ContainsEmptySlice(out_slices)); for (int64_t i = 0; i < out_parallel_num; ++i) { const TensorSliceView& out_slice = out_slices.at(i); for (int64_t j = 0; j < in_parallel_num; ++j) { const TensorSliceView& in_slice = in_slices.at(j); const TensorSliceView& intersection = out_slice.Intersect(in_slice); if (intersection.IsEmpty()) { continue; } int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(j)); int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(i)); sorted_p2p_pair_.emplace_back(std::make_pair(src, dst)); sorted_elem_cnt2in_tensor_slice_copier_pair_.emplace_back(std::make_pair( intersection.shape().elem_cnt(), std::make_shared(intersection, in_slice, data_type, device_type))); sorted_elem_cnt2out_tensor_slice_copier_pair_.emplace_back(std::make_pair( 
intersection.shape().elem_cnt(), std::make_shared(out_slice, intersection, data_type, device_type))); } } } std::vector>> sorted_elem_cnt2in_tensor_slice_copier_pair_; std::vector>> sorted_elem_cnt2out_tensor_slice_copier_pair_; std::vector> sorted_p2p_pair_; }; size_t InferNaiveSToSKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); Shape shape = ctx->Attr("shape"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); Symbol out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt)); int64_t out_parallel_num = out_parallel_desc->parallel_num(); if (out_parallel_num > 1) { CHECK_LT(out_split_axis, shape.NumAxes()); BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num); shape.Set(out_split_axis, bs.At(0).size()); } size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type()); return tensor_byte_size; } } // namespace class EagerNaiveSToSKernel final : public user_op::OpKernel { public: EagerNaiveSToSKernel() = default; ~EagerNaiveSToSKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const void* in_ptr = in->dptr(); void* out_ptr = out->mut_dptr(); void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const auto& sorted_elem_cnt2in_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair(); const auto& sorted_elem_cnt2out_tensor_slice_copier_pair = kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair(); const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); DeviceType device_type = ctx->device_type(); for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) { const auto& p2p_pair = sorted_p2p_pair.at(i); int64_t src = p2p_pair.first; int64_t dst = p2p_pair.second; if (GlobalProcessCtx::Rank() == src) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2in_tensor_slice_copier_pair.at(i); const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr); CHECK_JUST(Send(reinterpret_cast(tmp_buffer_ptr), elem_cnt, in->data_type(), dst, device_type, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { const auto& elem_cnt2tensor_slice_copier_pair = sorted_elem_cnt2out_tensor_slice_copier_pair.at(i); const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first; const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second; CHECK_JUST( Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream())); tensor_slice_copier->Copy(ctx->stream(), out_ptr, reinterpret_cast(tmp_buffer_ptr)); } } } bool AlwaysComputeWhenAllOutputsEmpty() const 
override { return false; }
};

REGISTER_USER_KERNEL("eager_naive_s_to_s")
    .SetCreateFn<EagerNaiveSToSKernel>()
    .SetIsMatchedHob(HobIsSendAndRecvRegistered())
    .SetInferTmpSizeFn(InferNaiveSToSKernelTmpBufferSize);

}  // namespace oneflow

================================================
FILE: oneflow/user/kernels/eager_symmetric_s_to_p_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/register/tensor_slice_copier.h"
#include "oneflow/core/control/global_process_ctx.h"

namespace oneflow {

namespace {

template<typename Context>
std::unique_ptr<ep::primitive::Memset> NewMemsetPrimitive(Context* ctx) {
  return ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());
}

auto MemsetPrimitiveExists() {
  return hob::make_custom("MemsetPrimitiveExists", [](const user_op::KernelRegContext& ctx) {
    return NewMemsetPrimitive(&ctx).operator bool();
  });
}

Maybe<Symbol<NdSbp>> GetAllSplitNdSbp(int64_t axis, int64_t ndim) {
  NdSbp split_nd_sbp;
  for (int64_t i = 0; i < ndim; ++i) {
    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);
  }
  return SymbolOf(split_nd_sbp);
}

auto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal);

Maybe<Symbol<NdSbp>> GetAllPartialSumNdSbp(int64_t ndim) {
  NdSbp split_nd_sbp;
  for (int64_t i = 0; i < ndim; ++i) {
    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel();
  }
  return SymbolOf(split_nd_sbp);
}

auto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocal);

class EagerSymmetricSToPOpKernelCache final : public user_op::OpKernelCache {
 public:
  explicit EagerSymmetricSToPOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }
  ~EagerSymmetricSToPOpKernelCache() override = default;

  const std::shared_ptr<TensorSliceCopier>& tensor_slice_copier() const {
    return tensor_slice_copier_;
  }

 private:
  void Init(user_op::KernelCacheContext* ctx) {
    const std::string& parallel_conf_txt = ctx->Attr<std::string>("parallel_conf");
    const int64_t in_split_axis = ctx->Attr<int64_t>("in_split_axis");
    const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0);
    const Shape& shape = in_logical_desc->shape();
    DeviceType device_type = ctx->device_type();
    DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type();
    ParallelConf parallel_conf;
    CHECK(TxtString2PbMessage(parallel_conf_txt, &parallel_conf));
    Symbol<ParallelDesc> parallel_desc = SymbolOf(ParallelDesc(parallel_conf));
    const TensorSliceView& in_slice = GetTensorSliceView4ParallelId(
        *parallel_desc->hierarchy(),
        *CHECK_JUST(CachedGetAllSplitNdSbp(in_split_axis, parallel_desc->hierarchy()->NumAxes())),
        shape, ctx->parallel_ctx().parallel_id());
    CHECK(!in_slice.IsEmpty());
    const TensorSliceView& out_slice = GetTensorSliceView4ParallelId(
        *parallel_desc->hierarchy(),
        *CHECK_JUST(CachedGetAllPartialSumNdSbp(parallel_desc->hierarchy()->NumAxes())),
shape, ctx->parallel_ctx().parallel_id()); CHECK(!out_slice.IsEmpty()); const TensorSliceView& intersection = out_slice.Intersect(in_slice); CHECK(!intersection.IsEmpty()); tensor_slice_copier_ = std::make_shared(out_slice, in_slice, data_type, device_type); } std::shared_ptr tensor_slice_copier_; }; } // namespace class EagerSymmetricSToPKernel final : public user_op::OpKernel { public: EagerSymmetricSToPKernel() = default; ~EagerSymmetricSToPKernel() override = default; void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(cache); CHECK(kernel_cache != nullptr); auto primitive = NewMemsetPrimitive(ctx); CHECK(primitive); // NOLINT const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const auto& out_shape_view = out->shape_view(); const void* in_ptr = in->dptr(); void* out_ptr = out->mut_dptr(); primitive->Launch(ctx->stream(), out->mut_dptr(), 0, out_shape_view.elem_cnt() * GetSizeOfDataType(out->data_type())); const auto& tensor_slice_copier = kernel_cache->tensor_slice_copier(); tensor_slice_copier->Copy(ctx->stream(), out_ptr, in_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("eager_symmetric_s_to_p") .SetCreateFn() .SetIsMatchedHob(MemsetPrimitiveExists() == true); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/elementwise_maximum_minimum_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/elementwise_maximum_minimum_kernel.h" namespace oneflow { namespace { template class Opt, typename T> struct ElemwiseXimumGradFunctor final { void operator()(ep::Stream* stream, int64_t elem_cnt, const T* dz, const T* x, const T* y, T* dx, T* dy) { XPU_1D_KERNEL_LOOP(idx, elem_cnt) { Opt()(dz[idx], x[idx], y[idx], dx ? &dx[idx] : nullptr, dy ? &dy[idx] : nullptr); } } }; template class Opt, typename T> struct ElemwiseXimumFunctor final { void operator()(ep::Stream* stream, int64_t elem_cnt, T* z, const T* x, const T* y) { FOR_RANGE(int64_t, idx, 0, elem_cnt) { z[idx] = Opt()(x[idx], y[idx]); } } }; } // namespace OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MAXIMUM_KERNELS, (DeviceType::kCPU), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MINIMUM_KERNELS, (DeviceType::kCPU), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/elementwise_maximum_minimum_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/elementwise_maximum_minimum_kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template class Opt, typename T> __global__ void ElementwiseXimumGradGpuKernel(int64_t elem_cnt, const T* dz, const T* x, const T* y, T* dx, T* dy) { XPU_1D_KERNEL_LOOP(idx, elem_cnt) { Opt()(dz[idx], x[idx], y[idx], dx ? &dx[idx] : nullptr, dy ? &dy[idx] : nullptr); } } template class Opt, typename T> struct ElemwiseXimumGradFunctor final { void operator()(ep::Stream* stream, int64_t elem_cnt, const T* dz, const T* x, const T* y, T* dx, T* dy) { ElementwiseXimumGradGpuKernel <<As()->cuda_stream()>>>(elem_cnt, dz, x, y, dx, dy); } }; template class Opt, typename T> struct ElemwiseXimumFunctor final { void operator()(ep::Stream* stream, int64_t elem_cnt, T* z, const T* x, const T* y) { OF_CUDA_CHECK(cuda::elementwise::Binary(Opt(), elem_cnt, z, x, y, stream->As()->cuda_stream())); } }; } // namespace OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MAXIMUM_KERNELS, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MINIMUM_KERNELS, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ) } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/elementwise_maximum_minimum_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef _ONEFLOW_USER_KERNELS_ELEMENTWISE_MAXIMUM_MINIMUM_KERNEL_H_ #define _ONEFLOW_USER_KERNELS_ELEMENTWISE_MAXIMUM_MINIMUM_KERNEL_H_ #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/framework/framework.h" namespace oneflow { template struct MaximumFunctor { OF_DEVICE_FUNC T operator()(const T x, const T y) const { return x > y ? x : y; } }; template struct MaximumGradFunctor { OF_DEVICE_FUNC void operator()(const T dz, const T x, const T y, T* dx, T* dy) { T dx_val = 0; T dy_val = 0; if (x > y) { dx_val = dz; } else if (x == y) { dx_val = dz / 2; dy_val = dz / 2; } else { dy_val = dz; } if (dx) { *dx = dx_val; } if (dy) { *dy = dy_val; } } }; template struct MinimumFunctor { OF_DEVICE_FUNC T operator()(const T x, const T y) const { return x < y ? 
x : y; } }; template struct MinimumGradFunctor { OF_DEVICE_FUNC void operator()(const T dz, const T x, const T y, T* dx, T* dy) { T dx_val = 0; T dy_val = 0; if (x < y) { dx_val = dz; } else if (x == y) { dx_val = dz / 2; dy_val = dz / 2; } else { dy_val = dz; } if (dx) { *dx = dx_val; } if (dy) { *dy = dy_val; } } }; namespace { template class Opt, typename T> struct ElemwiseXimumGradFunctor final { void operator()(ep::Stream* stream, int64_t elem_cnt, const T* dz, const T* x, const T* y, T* dx, T* dy); }; template class Opt, typename T> struct ElemwiseXimumFunctor final { void operator()(ep::Stream* stream, int64_t elem_cnt, T* z, const T* x, const T* y); }; } // namespace template class Opt, typename T> class ElemwiseXimumKernel final : public user_op::OpKernel { public: ElemwiseXimumKernel() = default; ~ElemwiseXimumKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0); int64_t n = tensor_x->shape_view().elem_cnt(); ElemwiseXimumFunctor()(ctx->stream(), n, tensor_z->mut_dptr(), tensor_x->dptr(), tensor_y->dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class Opt, typename T> class ElemwiseXimumBackwardKernel final : public user_op::OpKernel { public: ElemwiseXimumBackwardKernel() = default; ~ElemwiseXimumBackwardKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0); user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const T* dptr_dz = tensor_dz->dptr(); const T* dptr_x = tensor_x->dptr(); const T* dptr_y = tensor_y->dptr(); T* dptr_dx = tensor_dx ? tensor_dx->mut_dptr() : nullptr; T* dptr_dy = tensor_dy ? 
tensor_dy->mut_dptr() : nullptr; ElemwiseXimumGradFunctor()(ctx->stream(), tensor_dz->shape_view().elem_cnt(), dptr_dz, dptr_x, dptr_y, dptr_dx, dptr_dy); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MAXIMUM_KERNELS(device, dtype_pair) \ REGISTER_USER_KERNEL("elementwise_maximum") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype_pair))); \ REGISTER_USER_KERNEL("elementwise_maximum_backward") \ .SetCreateFn< \ ElemwiseXimumBackwardKernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype_pair))); #define REGISTER_MINIMUM_KERNELS(device, dtype_pair) \ REGISTER_USER_KERNEL("elementwise_minimum") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype_pair))); \ REGISTER_USER_KERNEL("elementwise_minimum_backward") \ .SetCreateFn< \ ElemwiseXimumBackwardKernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype_pair))); } // namespace oneflow #endif // _ONEFLOW_USER_KERNELS_ELEMENTWISE_MAXIMUM_MINIMUM_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/elementwise_primitive_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
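Note on the header that follows: UnaryPrimitiveKernel and BinaryPrimitiveKernel do no math themselves; each one stores the input/output argument names plus a factory functor, and at Compute time asks that functor for an ep::primitive to launch. The UnaryPrimitiveExists / BinaryPrimitiveExists hobs at the end of the header make a registration match only when such a primitive can actually be created for the device and dtypes. A hedged, illustrative sketch of the factory-functor idea (the op "relu", the argument names "x"/"y", and the choice of UnaryOp::kRelu are examples chosen for this sketch, not taken from the header; the lookup mirrors what the hob helpers below do):

// Illustrative only: building a UnaryPrimitiveKernel that would compute y = relu(x).
// Real registrations wire this up through REGISTER_USER_KERNEL at each op's own site.
UnaryPrimitiveKernel relu_kernel(
    "y", "x", [](user_op::KernelComputeContext* ctx) {
      const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0);
      const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0);
      // Look up an elementwise-unary primitive for this device / dtype combination.
      return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
          ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(), dst->data_type());
    });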
*/ #ifndef _ONEFLOW_USER_KERNELS_ELEMENTWISE_XPU_KERNEL_H_ #define _ONEFLOW_USER_KERNELS_ELEMENTWISE_XPU_KERNEL_H_ #include "oneflow/core/common/scalar.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/include/primitive/elementwise_unary.h" #include "oneflow/core/ep/include/primitive/unary_op.h" #include "oneflow/core/ep/include/primitive/binary_op.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { class UnaryPrimitiveKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: OF_DISALLOW_COPY_AND_MOVE(UnaryPrimitiveKernel); UnaryPrimitiveKernel() = default; ~UnaryPrimitiveKernel() = default; using PrimitiveFactoryFuncType = std::function( user_op::KernelComputeContext*)>; UnaryPrimitiveKernel(const std::string& output_name, const std::string& input_name, PrimitiveFactoryFuncType fn) : output_name_(output_name), input_name_(input_name), primitive_factory_func_(std::move(fn)) {} private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { auto primitive = primitive_factory_func_(ctx); CHECK(primitive); const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(input_name_, 0); user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(output_name_, 0); const ShapeView& input_shape = input_tensor->shape_view(); const ShapeView& output_shape = output_tensor->shape_view(); CHECK_EQ(input_shape, output_shape) << "Input shape should be equal to Output shape."; const int64_t elem_cnt = input_shape.elem_cnt(); if (elem_cnt != 0) { primitive->Launch(ctx->stream(), input_tensor->dptr(), output_tensor->mut_dptr(), elem_cnt); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::string output_name_; std::string input_name_; PrimitiveFactoryFuncType primitive_factory_func_; }; class BinaryPrimitiveKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: OF_DISALLOW_COPY_AND_MOVE(BinaryPrimitiveKernel); BinaryPrimitiveKernel() = default; ~BinaryPrimitiveKernel() = default; using PrimitiveFactoryFuncType = std::function( user_op::KernelComputeContext*)>; BinaryPrimitiveKernel(const std::string& output_name, const std::string& input_a_name, const std::string& input_b_name, PrimitiveFactoryFuncType fn) : output_name_(output_name), input_a_name_(input_a_name), input_b_name_(input_b_name), primitive_factory_func_(std::move(fn)) {} private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { auto primitive = primitive_factory_func_(ctx); CHECK(primitive); const user_op::Tensor* input_a_tensor = ctx->Tensor4ArgNameAndIndex(input_a_name_, 0); const user_op::Tensor* input_b_tensor = ctx->Tensor4ArgNameAndIndex(input_b_name_, 0); user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(output_name_, 0); const ShapeView& input_a_shape = input_a_tensor->shape_view(); const ShapeView& input_b_shape = input_b_tensor->shape_view(); const ShapeView& output_shape = output_tensor->shape_view(); CHECK_EQ(input_a_shape, input_b_shape) << "InputA shape should be equal to InputB shape."; CHECK_EQ(input_a_shape, output_shape) << "Input shape should be equal to Output shape."; const int64_t elem_cnt = input_a_shape.elem_cnt(); if (elem_cnt != 0) { primitive->Launch(ctx->stream(), input_a_shape.NumAxes(), input_a_shape.ptr(), input_a_tensor->dptr(), input_b_shape.NumAxes(), 
input_b_shape.ptr(), input_b_tensor->dptr(), output_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::string output_name_; std::string input_a_name_; std::string input_b_name_; PrimitiveFactoryFuncType primitive_factory_func_; }; namespace { auto UnaryPrimitiveExists(ep::primitive::UnaryOp op, const std::string& output_name, const std::string& input_name) { return hob::make_custom( "ElementwiseUnaryPrimitiveExists", [=](const user_op::KernelRegContext& ctx) { const user_op::TensorDesc* src = ctx.TensorDesc4ArgNameAndIndex(input_name, 0); const user_op::TensorDesc* dst = ctx.TensorDesc4ArgNameAndIndex(output_name, 0); auto primitive = ep::primitive::NewPrimitive( ctx.device_type(), op, src->data_type(), dst->data_type()); return primitive.operator bool(); }); } auto BinaryPrimitiveExists(ep::primitive::BinaryOp op, const std::string& output_name, const std::string& input_a_name) { return hob::make_custom( "BroadcastElementwiseBinaryPrimitiveExists", [=](const user_op::KernelRegContext& ctx) { const user_op::TensorDesc* src0 = ctx.TensorDesc4ArgNameAndIndex(input_a_name, 0); const user_op::TensorDesc* dst = ctx.TensorDesc4ArgNameAndIndex(output_name, 0); auto primitive = ep::primitive::NewPrimitive( ctx.device_type(), op, src0->data_type(), dst->data_type(), 1 /*max_num_dims*/); return primitive.operator bool(); }); } } // namespace } // namespace oneflow #endif // _ONEFLOW_USER_KERNELS_ELEMENTWISE_XPU_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/embedding_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
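Before the CPU embedding kernels that follow: the forward "embedding" op is a per-row gather from the weight table; in these kernels padding_idx and scale_grad_by_freq only affect the gradient. A minimal, self-contained sketch of that gather (all names here, such as toy_embedding, are invented for illustration and do not appear in the file):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// out[i, :] = weight[indices[i], :] for every flattened index i.
std::vector<float> toy_embedding(const std::vector<float>& weight,     // [emb_size, emb_dim], row-major
                                 const std::vector<int64_t>& indices,  // flattened index tensor
                                 int64_t emb_dim) {
  std::vector<float> out(indices.size() * emb_dim);
  for (std::size_t i = 0; i < indices.size(); ++i) {
    std::copy(weight.begin() + indices[i] * emb_dim,
              weight.begin() + (indices[i] + 1) * emb_dim,
              out.begin() + i * emb_dim);
  }
  return out;
}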
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/user/kernels/embedding_kernel_util.h" namespace oneflow { template class CpuEmbeddingRenormKernel final : public user_op::OpKernel { public: CpuEmbeddingRenormKernel() = default; ~CpuEmbeddingRenormKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const double max_norm = ctx->Attr("max_norm"); const double norm_type = ctx->Attr("norm_type"); const ShapeView& in_shape = in->shape_view(); const int64_t emb_size = in_shape.At(0); const int64_t emb_dim = in_shape.At(1); const T* in_buf = in->dptr(); const IndexType* indices_buf = indices->dptr(); T* out_buf = out->mut_dptr(); const int64_t num_indices = indices->shape_view().elem_cnt(); EmbeddingReNormFunctor()( ctx->stream(), in_buf, indices_buf, out_buf, max_norm, norm_type, num_indices, emb_size, emb_dim, nullptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class CpuEmbeddingKernel final : public user_op::OpKernel { public: CpuEmbeddingKernel() = default; ~CpuEmbeddingKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t padding_idx = ctx->Attr("padding_idx"); const bool scale_grad_by_freq = ctx->Attr("scale_grad_by_freq"); const ShapeView& out_shape = out->shape_view(); const int64_t num_indices = out_shape.Count(0, out_shape.NumAxes() - 1); const int64_t emb_size = weight->shape_view().At(0); const int64_t emb_dim = out_shape.At(out_shape.NumAxes() - 1); const T* weight_buf = weight->dptr(); const IndexType* indices_buf = indices->dptr(); T* out_buf = out->mut_dptr(); EmbeddingFunctor()(ctx->stream(), weight_buf, indices_buf, out_buf, padding_idx, scale_grad_by_freq, num_indices, emb_size, emb_dim); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class CpuEmbeddingGradKernel final : public user_op::OpKernel { public: CpuEmbeddingGradKernel() = default; ~CpuEmbeddingGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t padding_idx = ctx->Attr("padding_idx"); const bool scale_grad_by_freq = ctx->Attr("scale_grad_by_freq"); const ShapeView& dy_shape = dy->shape_view(); const int64_t num_indices = dy_shape.Count(0, dy_shape.NumAxes() - 1); const int64_t emb_size = weight->shape_view().At(0); const int64_t emb_dim = dy_shape.At(dy_shape.NumAxes() - 1); const T* dy_buf = dy->dptr(); const IndexType* indices_buf = indices->dptr(); T* dx_buf = dx->mut_dptr(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); memset_primitive->Launch(ctx->stream(), dx_buf, 0, dx->shape_view().Count(0) * sizeof(T)); EmbeddingGradFunctor()(ctx->stream(), dy_buf, 
indices_buf, dx_buf, padding_idx, scale_grad_by_freq, num_indices, emb_size, emb_dim, nullptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_EMBEDDING_KERNEL(in_type, indices_type) \ REGISTER_USER_KERNEL("embedding_renorm") \ .SetCreateFn< \ CpuEmbeddingRenormKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_type))); \ REGISTER_USER_KERNEL("embedding") \ .SetCreateFn< \ CpuEmbeddingKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("weight", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_type))); \ REGISTER_USER_KERNEL("embedding_grad") \ .SetCreateFn< \ CpuEmbeddingGradKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("weight", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_type))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CPU_EMBEDDING_KERNEL, EMBEDDING_DATA_TYPE_SEQ_CPU, INDEX_DATA_TYPE_SEQ) #undef REGISTER_CPU_EMBEDDING_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/embedding_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
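For the CUDA embedding kernels that follow, note the temp buffer: the registrations at the end of the file size it to hold one int32_t frequency counter per embedding row (emb_size = weight.shape[0]); the renorm kernel and, when scale_grad_by_freq is set, the grad kernel use those counters to record how often each row index occurs. A self-contained restatement of the size rule, with the alignment passed in as a parameter because this sketch does not reproduce GetCudaAlignedSize:

#include <cstddef>
#include <cstdint>

// Illustrative restatement of the SetInferTmpSizeFn used in the registrations below.
size_t toy_embedding_tmp_bytes(int64_t emb_size, size_t cuda_align) {
  size_t raw = sizeof(int32_t) * static_cast<size_t>(emb_size);
  return (raw + cuda_align - 1) / cuda_align * cuda_align;  // round up, as GetCudaAlignedSize does
}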
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/user/kernels/embedding_kernel_util.h" namespace oneflow { template class GpuEmbeddingRenormKernel final : public user_op::OpKernel { public: GpuEmbeddingRenormKernel() = default; ~GpuEmbeddingRenormKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const double max_norm = ctx->Attr("max_norm"); const double norm_type = ctx->Attr("norm_type"); const ShapeView& in_shape = in->shape_view(); const int64_t emb_size = in_shape.At(0); const int64_t emb_dim = in_shape.At(1); const T* in_buf = in->dptr(); const IndexType* indices_buf = indices->dptr(); T* out_buf = out->mut_dptr(); const int64_t num_indices = indices->shape_view().elem_cnt(); int32_t* tmp_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); memset_primitive->Launch(ctx->stream(), tmp_buf, 0, GetCudaAlignedSize(sizeof(int32_t) * emb_size)); EmbeddingReNormFunctor()( ctx->stream(), in_buf, indices_buf, out_buf, max_norm, norm_type, num_indices, emb_size, emb_dim, tmp_buf); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuEmbeddingKernel final : public user_op::OpKernel { public: GpuEmbeddingKernel() = default; ~GpuEmbeddingKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t padding_idx = ctx->Attr("padding_idx"); const bool scale_grad_by_freq = ctx->Attr("scale_grad_by_freq"); const int64_t num_indices = indices->shape_view().elem_cnt(); const int64_t emb_size = weight->shape_view().At(0); const int64_t emb_dim = weight->shape_view().At(1); const T* weight_buf = weight->dptr(); const IndexType* indices_buf = indices->dptr(); T* out_buf = out->mut_dptr(); EmbeddingFunctor()(ctx->stream(), weight_buf, indices_buf, out_buf, padding_idx, scale_grad_by_freq, num_indices, emb_size, emb_dim); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class GpuEmbeddingGradKernel final : public user_op::OpKernel { public: GpuEmbeddingGradKernel() = default; ~GpuEmbeddingGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t padding_idx = ctx->Attr("padding_idx"); const bool scale_grad_by_freq = ctx->Attr("scale_grad_by_freq"); const int64_t num_indices = indices->shape_view().elem_cnt(); const int64_t emb_size = weight->shape_view().At(0); const int64_t emb_dim = weight->shape_view().At(1); const T* dy_buf = dy->dptr(); const IndexType* indices_buf = 
indices->dptr(); T* dx_buf = dx->mut_dptr(); int32_t* tmp_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); memset_primitive->Launch(ctx->stream(), dx_buf, 0, dx->shape_view().elem_cnt() * sizeof(T)); memset_primitive->Launch(ctx->stream(), tmp_buf, 0, GetCudaAlignedSize(sizeof(int32_t) * emb_size)); EmbeddingGradFunctor()( ctx->stream(), dy_buf, indices_buf, dx_buf, padding_idx, scale_grad_by_freq, num_indices, emb_size, emb_dim, tmp_buf); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_EMBEDDING_KERNEL(in_type, indices_type) \ REGISTER_USER_KERNEL("embedding_renorm") \ .SetCreateFn< \ GpuEmbeddingRenormKernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_type))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const Shape& in_shape = ctx->InputShape("in", 0); \ const int64_t emb_size = in_shape.At(0); \ return GetCudaAlignedSize(sizeof(int32_t) * emb_size); \ }); \ REGISTER_USER_KERNEL("embedding") \ .SetCreateFn< \ GpuEmbeddingKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("weight", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_type))); \ REGISTER_USER_KERNEL("embedding_grad") \ .SetCreateFn< \ GpuEmbeddingGradKernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("weight", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_type))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const Shape& in_shape = ctx->InputShape("weight", 0); \ const int64_t emb_size = in_shape.At(0); \ return GetCudaAlignedSize(sizeof(int32_t) * emb_size); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_KERNEL, EMBEDDING_DATA_TYPE_SEQ_CUDA, INDEX_DATA_TYPE_SEQ) #undef REGISTER_CUDA_EMBEDDING_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/embedding_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
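The CPU functors in the next file implement three pieces: renorm (rewrite only the referenced rows whose norm_type-norm exceeds max_norm), the forward gather, and the backward scatter-add. A self-contained sketch of the backward scatter-add (names invented for this sketch; dx is assumed pre-zeroed, as the calling kernel does with a Memset primitive):

#include <cstddef>
#include <cstdint>
#include <vector>

void toy_embedding_grad(const std::vector<float>& dy,         // [num_indices, emb_dim]
                        const std::vector<int64_t>& indices,  // [num_indices]
                        std::vector<float>& dx,               // [emb_size, emb_dim], zero-filled
                        int64_t emb_dim, int64_t padding_idx) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] == padding_idx) { continue; }  // padding rows receive no gradient
    for (int64_t j = 0; j < emb_dim; ++j) {
      dx[indices[i] * emb_dim + j] += dy[i * emb_dim + j];  // scatter-add by row index
    }
  }
}

When scale_grad_by_freq is set, the functor afterwards divides each accumulated row of dx by how many times its index occurred.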
*/ #include "oneflow/user/kernels/embedding_kernel_util.h" namespace oneflow { template struct EmbeddingReNormFunctor final { void operator()(ep::Stream* stream, const T* in_buf, const IndexType* indices_buf, T* out_buf, const double max_norm, const double norm_type, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf) { auto sorted_indices = std::vector(indices_buf, indices_buf + num_indices); std::sort(sorted_indices.begin(), sorted_indices.end()); for (int64_t i = 0; i < num_indices; i++) { if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) { continue; } CHECK(sorted_indices[i] >= 0 && sorted_indices[i] < emb_size); double norm = 0; for (int64_t j = emb_dim * sorted_indices[i]; j < emb_dim * (sorted_indices[i] + 1); j++) { norm += std::pow(std::abs(in_buf[j]), norm_type); } norm = std::pow(norm, (1.0 / norm_type)); if (norm > max_norm) { double scale = max_norm / (norm + 1e-7); for (int64_t j = emb_dim * sorted_indices[i]; j < emb_dim * (sorted_indices[i] + 1); j++) { out_buf[j] = in_buf[j] * scale; } } } } }; template struct EmbeddingFunctor final { void operator()(ep::Stream* stream, const T* weight_buf, const IndexType* indices_buf, T* out_buf, const int64_t padding_idx, const bool scale_grad_by_freq, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim) { for (int64_t i = 0; i < num_indices; i++) { IndexType indice = indices_buf[i]; CHECK(indice >= 0 && indice < emb_size); const T* from = weight_buf + indice * emb_dim; T* to = out_buf + i * emb_dim; std::copy(from, from + emb_dim, to); } } }; template struct EmbeddingGradFunctor final { void operator()(ep::Stream* stream, const T* dy_buf, const IndexType* indices_buf, T* dx_buf, const int64_t padding_idx, const bool scale_grad_by_freq, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf) { for (int64_t i = 0; i < num_indices; i++) { IndexType indice = indices_buf[i]; if (indice != padding_idx) { const T* from = dy_buf + i * emb_dim; T* to = dx_buf + indice * emb_dim; std::transform(from, from + emb_dim, to, to, std::plus()); } } if (scale_grad_by_freq) { std::vector indice_freq(emb_size, 0); for (int64_t i = 0; i < num_indices; i++) { indice_freq[indices_buf[i]]++; } for (int64_t i = 0; i < emb_size; i++) { if (indice_freq[i] > 1) { T* from = dx_buf + i * emb_dim; for (int64_t j = 0; j < emb_dim; j++) { from[j] /= indice_freq[i]; } } } } } }; #define INITIATE_EMBEDDING_KERNEL_UTIL_CPU_IMPL(in_type_pair, index_type_pair) \ template struct EmbeddingReNormFunctor; \ template struct EmbeddingFunctor; \ template struct EmbeddingGradFunctor; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_EMBEDDING_KERNEL_UTIL_CPU_IMPL, EMBEDDING_DATA_TYPE_SEQ_CPU, INDEX_DATA_TYPE_SEQ); #undef INITIATE_EMBEDDING_KERNEL_UTIL_CPU_IMPL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/embedding_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/embedding_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template struct AccumulateType { using type = T; }; template<> struct AccumulateType { using type = float; }; template __global__ void embedding_kernel(const T* weight_buf, const IndexType* indices_buf, T* out_buf, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_indices * emb_dim) { IndexType indices_index = i / emb_dim; IndexType emb_dim_index = i - indices_index * emb_dim; IndexType emb_size_index = indices_buf[indices_index]; assert(emb_size_index >= 0 && emb_size_index < emb_size); IndexType from_index = emb_size_index * emb_dim + emb_dim_index; out_buf[i] = weight_buf[from_index]; } } template __global__ void embedding_grad_kernel(const T* dy_buf, const IndexType* indices_buf, T* dx_buf, const int64_t padding_idx, const int64_t num_indices, const int64_t emb_dim) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_indices * emb_dim) { IndexType indices_index = i / emb_dim; IndexType emb_dim_index = i - indices_index * emb_dim; IndexType emb_size_index = indices_buf[indices_index]; if (emb_size_index != padding_idx) { IndexType from_index = emb_size_index * emb_dim + emb_dim_index; cuda::atomic::Add(dx_buf + from_index, dy_buf[i]); } } } template __global__ void indices_freq_kernel(const IndexType* indices_buf, const int64_t num_indices, int32_t* indices_freq, const int64_t emb_size) { CUDA_1D_KERNEL_LOOP_T(IndexType, i, num_indices) { IndexType index = indices_buf[i]; assert(index >= 0 && index < emb_size); cuda::atomic::Add(indices_freq + index, 1); } } template __global__ void emb_scale_kernel(T* dx_buf, const int64_t emb_size, const int64_t emb_dim, int32_t* indices_freq) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, emb_size * emb_dim) { IndexType emb_size_index = i / emb_dim; if (indices_freq[emb_size_index] > 1) { dx_buf[i] /= static_cast(indices_freq[emb_size_index]); } } } template __global__ void embedding_renorm_kernel(const T* in_buf, T* out_buf, int32_t* indices_freq, const AccumType max_norm, const AccumType norm_type, const int64_t emb_size, const int64_t emb_dim) { int64_t tid = threadIdx.x; for (int64_t emb_idx = blockIdx.x; emb_idx < emb_size; emb_idx += gridDim.x) { if (indices_freq[emb_idx] == 0) { continue; } int64_t base_index = emb_idx * emb_dim; AccumType v = 0; for (int64_t i = tid; i < emb_dim; i += blockDim.x) { v += pow(abs(static_cast(in_buf[base_index + i])), norm_type); } using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ AccumType norm; v = BlockReduce(temp_storage).Sum(v); if (tid == 0) { norm = pow(v, static_cast(1.0 / norm_type)); } __syncthreads(); if (norm > max_norm) { auto scale = static_cast(max_norm / (norm + 1e-7)); for (int64_t i = tid; i < emb_dim; i += blockDim.x) { out_buf[base_index + i] = in_buf[base_index + i] * scale; } } } } } // namespace template struct EmbeddingReNormFunctor final { void operator()(ep::Stream* stream, const T* in_buf, const IndexType* indices_buf, T* out_buf, const double max_norm, const double norm_type, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf) { indices_freq_kernel<<As()->cuda_stream()>>>( indices_buf, num_indices, tmp_buf, emb_size); using AccumType = typename AccumulateType::type; 
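// Comment added for clarity: the two launches below implement renorm in two passes.
// 1) indices_freq_kernel records in tmp_buf which embedding rows are referenced (and how
//    often); rows that are never referenced are skipped by the renorm kernel.
// 2) embedding_renorm_kernel computes each referenced row's norm_type-norm with a block
//    reduction and, if it exceeds max_norm, rescales that row of out_buf by
//    max_norm / (norm + 1e-7).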
embedding_renorm_kernel <<As()->cuda_stream()>>>( in_buf, out_buf, tmp_buf, static_cast(max_norm), static_cast(norm_type), emb_size, emb_dim); } }; template struct EmbeddingFunctor final { void operator()(ep::Stream* stream, const T* weight_buf, const IndexType* indices_buf, T* out_buf, const int64_t padding_idx, const bool scale_grad_by_freq, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim) { embedding_kernel <<As()->cuda_stream()>>>(weight_buf, indices_buf, out_buf, num_indices, emb_size, emb_dim); } }; template struct EmbeddingGradFunctor final { void operator()(ep::Stream* stream, const T* dy_buf, const IndexType* indices_buf, T* dx_buf, const int64_t padding_idx, const bool scale_grad_by_freq, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf) { embedding_grad_kernel <<As()->cuda_stream()>>>(dy_buf, indices_buf, dx_buf, padding_idx, num_indices, emb_dim); if (scale_grad_by_freq) { indices_freq_kernel<<As()->cuda_stream()>>>( indices_buf, num_indices, tmp_buf, emb_size); emb_scale_kernel <<As()->cuda_stream()>>>(dx_buf, emb_size, emb_dim, tmp_buf); } } }; #define INITIATE_EMBEDDING_KERNEL_UTIL_CUDA_IMPL(in_type_pair, index_type_pair) \ template struct EmbeddingReNormFunctor; \ template struct EmbeddingFunctor; \ template struct EmbeddingGradFunctor; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_EMBEDDING_KERNEL_UTIL_CUDA_IMPL, EMBEDDING_DATA_TYPE_SEQ_CUDA, INDEX_DATA_TYPE_SEQ); #undef INITIATE_EMBEDDING_KERNEL_UTIL_CUDA_IMPL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/embedding_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_EMBEDDING_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_EMBEDDING_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct EmbeddingReNormFunctor final { void operator()(ep::Stream* stream, const T* in_buf, const IndexType* indices_buf, T* out_buf, const double max_norm, const double norm_type, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf); }; template struct EmbeddingFunctor final { void operator()(ep::Stream* stream, const T* weight_buf, const IndexType* indices_buf, T* out_buf, const int64_t padding_idx, const bool scale_grad_by_freq, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim); }; template struct EmbeddingGradFunctor final { void operator()(ep::Stream* stream, const T* dy_buf, const IndexType* indices_buf, T* dx_buf, const int64_t padding_idx, const bool scale_grad_by_freq, const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf); }; #define EMBEDDING_DATA_TYPE_SEQ_CPU FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ #define EMBEDDING_DATA_TYPE_SEQ_CUDA FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_EMBEDDING_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/empty_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace user_op { class EmptyKernel final : public OpKernel { public: EmptyKernel() = default; ~EmptyKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); // None POD type need check if (!IsTriviallyCopyableDataType(out->data_type())) { CHECK(out->shape_view().NumAxes() > 0 && out->shape_view().elem_cnt() == 0) << "None POD Tensor created by empty op must be 0-Size tensor."; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("empty").SetCreateFn(); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/erfinv_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include namespace oneflow { template class CpuErfinvKernel final : public user_op::OpKernel { public: CpuErfinvKernel() = default; ~CpuErfinvKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = x->shape_view().elem_cnt(); const T* x_ptr = x->dptr(); T* y_ptr = y->mut_dptr(); constexpr float central_range = 0.7; const T temp = static_cast(2.0) / static_cast(std::sqrt(M_PI)); T a[4] = {T(0.886226899), T(-1.645349621), T(0.914624893), T(-0.140543331)}; T b[4] = {T(-2.118377725), T(1.442710462), T(-0.329097515), T(0.012229801)}; T c[4] = {T(-1.970840454), T(-1.624906493), T(3.429567803), T(1.641345311)}; T d[2] = {T(3.543889200), T(1.637067800)}; FOR_RANGE(int32_t, i, 0, elem_cnt) { T z, num, dem; T x = x_ptr[i]; // Promise the correctness of inplace version. T x_abs = std::abs(x); if (x_abs > 1.0) { y_ptr[i] = std::numeric_limits::quiet_NaN(); continue; } if (x_abs == 1.0) { y_ptr[i] = std::copysign(std::numeric_limits::infinity(), x); continue; } if (x_abs <= static_cast(central_range)) { z = x * x; num = (((a[3] * z + a[2]) * z + a[1]) * z + a[0]); dem = ((((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + static_cast(1.0)); y_ptr[i] = x * num / dem; } else { z = std::sqrt(-std::log((static_cast(1.0) - x_abs) / static_cast(2.0))); num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; dem = (d[1] * z + d[0]) * z + static_cast(1.0); y_ptr[i] = std::copysign(num, x) / dem; } y_ptr[i] = y_ptr[i] - (std::erf(y_ptr[i]) - x) / (temp * std::exp(-y_ptr[i] * y_ptr[i])); y_ptr[i] = y_ptr[i] - (std::erf(y_ptr[i]) - x) / (temp * std::exp(-y_ptr[i] * y_ptr[i])); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_ERFINV_KERNEL(dtype) \ REGISTER_USER_KERNEL("erfinv") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)) \ .SetInplaceProposalFn( \ [](const user_op::InferContext&, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); \ return Maybe::Ok(); \ }); REGISTER_CPU_ERFINV_KERNEL(float) REGISTER_CPU_ERFINV_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/erfinv_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
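Note on the CPU kernel above: it seeds erfinv with a piecewise rational approximation (central range versus tails) and then applies the Newton update below twice to sharpen the result; the CUDA kernel in the file that follows instead calls the device erfinv() directly. A self-contained restatement of one Newton step, solving erf(y) = x with d/dy erf(y) = (2/sqrt(pi)) * exp(-y*y):

#include <cmath>

// One Newton refinement step as used by CpuErfinvKernel (illustrative restatement).
double erfinv_newton_step(double y, double x) {
  return y - (std::erf(y) - x) / ((2.0 / std::sqrt(M_PI)) * std::exp(-y * y));
}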
*/ #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/elementwise.cuh" namespace oneflow { template struct ErfInvFunctor { OF_DEVICE_FUNC ErfInvFunctor() {} OF_DEVICE_FUNC T operator()(T x) const { return erfinv(x); } }; template class GpuErfinvKernel final : public user_op::OpKernel { public: GpuErfinvKernel() = default; ~GpuErfinvKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = x->shape_view().elem_cnt(); OF_CUDA_CHECK(cuda::elementwise::Unary(ErfInvFunctor(), elem_cnt, y->mut_dptr(), x->dptr(), ctx->stream()->As()->cuda_stream())); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_ERFINV_KERNEL(dtype) \ REGISTER_USER_KERNEL("erfinv") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); \ return Maybe::Ok(); \ }); REGISTER_CUDA_ERFINV_KERNEL(float) REGISTER_CUDA_ERFINV_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/expand_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template std::unique_ptr NewPrimitive(Context* ctx) { const auto* in_desc = ctx->TensorDesc4ArgNameAndIndex("in", 0); const auto* out_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0); size_t max_ndim = std::max(in_desc->shape().size(), out_desc->shape().size()); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kIdentity, in_desc->data_type(), out_desc->data_type(), max_ndim); } auto PrimitiveExists() { return hob::make_custom("BroadcastElementwiseUnaryPrimitiveExists", [](const user_op::KernelRegContext& ctx) -> bool { return NewPrimitive(&ctx).operator bool(); }); } } // namespace class ExpandKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ExpandKernel() = default; ~ExpandKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); auto in_shape = in->shape_view(); auto out_shape = out->shape_view(); // handle 0-size tensor if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t dim) { return dim <= 0; })) { return; } auto prim = NewPrimitive(ctx); CHECK(prim); if (in_shape.size() == 0 && in_shape.elem_cnt() == 1) { // handle 0-dim tensor // NOTE: this handle will be remove when BroadcastElementwiseUnary primitive support 0-dim // tensor int64_t scalar_ndim = 1; Shape scalar_shape(DimVector{scalar_ndim}); Shape scalar_stride(DimVector{scalar_ndim}); prim->Launch(ctx->stream(), scalar_ndim, scalar_shape.data(), scalar_stride.data(), in->dptr(), out_shape.size(), out_shape.data(), out->stride().data(), out->mut_dptr()); } else { prim->Launch(ctx->stream(), in_shape.size(), in_shape.data(), in->stride().data(), in->dptr(), out_shape.size(), out_shape.data(), out->stride().data(), out->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("expand").SetCreateFn().SetIsMatchedHob(PrimitiveExists() == true); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eye_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/eye_kernel_util.h" #include "oneflow/core/common/data_type.h" namespace oneflow { namespace user_op { template class EyeKernel final : public OpKernel { public: EyeKernel() = default; ~EyeKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { int64_t rows = ctx->Attr("rows"); int64_t cols = ctx->Attr("cols"); if (rows == 0 || cols == 0) { return; } Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); T* out = out_tensor->mut_dptr(); Memset( ctx->stream(), out_tensor->mut_dptr(), 0, out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(out_tensor->data_type())); EyeFunctor()(ctx->stream(), cols, std::min(cols, rows), out); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_EYE_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("eye").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobAttr("dtype") == GetDataType::value)); #define REGISTER_EYE_KERNELS_WITH_DEVICE(device) \ REGISTER_EYE_KERNEL(device, bool) \ REGISTER_EYE_KERNEL(device, uint8_t) \ REGISTER_EYE_KERNEL(device, int8_t) \ REGISTER_EYE_KERNEL(device, int32_t) \ REGISTER_EYE_KERNEL(device, int64_t) \ REGISTER_EYE_KERNEL(device, float) \ REGISTER_EYE_KERNEL(device, double) // Register CPU version REGISTER_EYE_KERNELS_WITH_DEVICE(DeviceType::kCPU); // Register CUDA version #ifdef WITH_CUDA REGISTER_EYE_KERNELS_WITH_DEVICE(DeviceType::kCUDA); #endif #undef REGISTER_EYE_KERNELS_WITH_DEVICE #undef REGISTER_EYE_KERNEL } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eye_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/eye_kernel_util.h" namespace oneflow { namespace user_op { template struct EyeFunctor final { void operator()(ep::Stream* stream, const int64_t& cols, const int64_t& rows, T* out) { SetOneInDiag(cols, rows, out); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_EYE_FUNCTOR, (DeviceType::kCPU), EYE_DATA_TYPE_SEQ); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/eye_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/eye_kernel_util.h" namespace oneflow { namespace user_op { template __global__ void EyeForwardGpuKernel(const int64_t cols, const int64_t rows, T* out) { SetOneInDiag(cols, rows, out); } template struct EyeFunctor final { void operator()(ep::Stream* stream, const int64_t& cols, const int64_t& rows, T* out) { RUN_CUDA_KERNEL((EyeForwardGpuKernel), stream, rows, cols, rows, out); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_EYE_FUNCTOR, (DeviceType::kCUDA), EYE_DATA_TYPE_SEQ); } // namespace user_op } // namespace oneflow #endif // End WITH_CUDA ================================================ FILE: oneflow/user/kernels/eye_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_EYE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_EYE_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ndarray/xpu_util.h" namespace oneflow { namespace user_op { #define EYE_DATA_TYPE_SEQ \ FLOATING_DATA_TYPE_SEQ \ INT_DATA_TYPE_SEQ \ UNSIGNED_INT_DATA_TYPE_SEQ \ BOOL_DATA_TYPE_SEQ template struct EyeFunctor final { void operator()(ep::Stream* stream, const int64_t& cols, const int64_t& rows, T* out); }; template OF_DEVICE_FUNC void SetOneInDiag(const int64_t cols, const int64_t rows, T* out) { const T one = static_cast(1); XPU_1D_KERNEL_LOOP(i, rows) { const int64_t index = i * cols + i; out[index] = one; } } #define INSTANTIATE_EYE_FUNCTOR(device_type_v, dtype_pair) \ template struct EyeFunctor; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_EYE_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/fake_quantization_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include namespace oneflow { template void FakeQuantizationPerLayerSymmetric(const T* in_ptr, const T scale, const int32_t quantization_bit, const int64_t num_elements, T* out_ptr) { T upper_bound = static_cast(pow(2.0, quantization_bit - 1)) - 1; T lower_bound = -upper_bound - 1; FOR_RANGE(int64_t, i, 0, num_elements) { T out = std::nearbyint(in_ptr[i] / scale); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? 
lower_bound : out; out_ptr[i] = out * scale; } } template void FakeQuantizationPerLayerAffine(const T* in_ptr, const T scale, const T zero_point, const int32_t quantization_bit, const int64_t num_elements, T* out_ptr) { T upper_bound = static_cast(pow(2.0, quantization_bit)) - 1; T lower_bound = 0; uint8_t zero_point_uint8 = static_cast(std::round(zero_point)); FOR_RANGE(int64_t, i, 0, num_elements) { T out = std::nearbyint(in_ptr[i] / scale + zero_point_uint8); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? lower_bound : out; out_ptr[i] = (out - zero_point_uint8) * scale; } } template void FakeQuantizationPerLayerCambricon(const T* in_ptr, const T shift, const int32_t quantization_bit, const int64_t num_elements, T* out_ptr) { T upper_bound = static_cast(pow(2.0, quantization_bit - 1)) - 1; T lower_bound = -upper_bound - 1; T scale = static_cast(pow(2.0, static_cast(shift))); FOR_RANGE(int64_t, i, 0, num_elements) { T out = std::nearbyint(in_ptr[i] / scale); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? lower_bound : out; out_ptr[i] = out * scale; } } template class CpuFakeQuantizationKernel final : public user_op::OpKernel { public: CpuFakeQuantizationKernel() = default; ~CpuFakeQuantizationKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); const user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const std::string quantization_scheme = ctx->Attr("quantization_scheme"); const int32_t quantization_bit = ctx->Attr("quantization_bit"); const std::string quantization_formula = ctx->Attr("quantization_formula"); const T* in_ptr = in->dptr(); const T* scale_ptr = scale->dptr(); T* out_ptr = out->mut_dptr(); // round to even auto origin_round_mode = std::fegetround(); std::fesetround(FE_TONEAREST); if (quantization_formula == "google") { int64_t outer_num = 1; int64_t inner_num = in->shape_view().elem_cnt(); if (scale->shape_view().elem_cnt() > 1) { // per-channel quantization outer_num = in->shape_view().At(0); inner_num = in->shape_view().Count(1); } if (quantization_scheme == "symmetric") { FOR_RANGE(int64_t, c, 0, outer_num) { FakeQuantizationPerLayerSymmetric(in_ptr, scale_ptr[c], quantization_bit, inner_num, out_ptr); in_ptr += inner_num; out_ptr += inner_num; } } else { // quantization_scheme == "affine" const T* zero_point_ptr = zero_point->dptr(); FOR_RANGE(int64_t, c, 0, outer_num) { FakeQuantizationPerLayerAffine(in_ptr, scale_ptr[c], zero_point_ptr[c], quantization_bit, inner_num, out_ptr); in_ptr += inner_num; out_ptr += inner_num; } } } else if (quantization_formula == "cambricon") { FakeQuantizationPerLayerCambricon(in_ptr, scale_ptr[0], quantization_bit, in->shape_view().elem_cnt(), out_ptr); } else { UNIMPLEMENTED(); } std::fesetround(origin_round_mode); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FAKE_QUANTIZATION_KERNEL(dtype) \ REGISTER_USER_KERNEL("fake_quantization") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_FAKE_QUANTIZATION_KERNEL(float); REGISTER_FAKE_QUANTIZATION_KERNEL(double); } // namespace oneflow ================================================ FILE: 
oneflow/user/kernels/fake_quantization_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.cuh" namespace oneflow { namespace { template __global__ void FakeQuantizationSymmetric(const T* in_ptr, const T* scale_ptr, const int64_t scale_size, const int64_t elements, const int64_t panel_size, const double quantization_bit, T* out_ptr) { int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int64_t step = gridDim.x * blockDim.x; T upper_bound = static_cast(pow(2.0, quantization_bit - 1)) - 1; T lower_bound = -upper_bound - 1; while (gid < elements) { int64_t channel_index = gid / panel_size; int64_t scale_idx = min(scale_size - 1, channel_index); T scale = scale_ptr[scale_idx]; T out = nearbyint(in_ptr[gid] / scale); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? lower_bound : out; out_ptr[gid] = out * scale; gid += step; } } template __global__ void FakeQuantizationAffine(const T* in_ptr, const T* scale_ptr, const T* zero_point_ptr, const int64_t scale_size, const int64_t elements, const int64_t panel_size, const double quantization_bit, T* out_ptr) { int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int64_t step = gridDim.x * blockDim.x; T upper_bound = static_cast(pow(2.0, quantization_bit)) - 1; T lower_bound = 0; while (gid < elements) { int64_t channel_index = gid / panel_size; int64_t scale_idx = min(scale_size - 1, channel_index); T scale = scale_ptr[scale_idx]; T zero_point = zero_point_ptr[scale_idx]; T out = nearbyint(in_ptr[gid] / scale + zero_point); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? lower_bound : out; out_ptr[gid] = (out - zero_point) * scale; gid += step; } } template __global__ void FakeQuantizationCambricon(const T* in_ptr, const T* shift, const int64_t scale_size, const int64_t elements, const int64_t panel_size, const double quantization_bit, T* out_ptr) { int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int64_t step = gridDim.x * blockDim.x; T upper_bound = static_cast(pow(2.0, quantization_bit - 1)) - 1; T lower_bound = -upper_bound - 1; T scale = static_cast(pow(2.0, static_cast(shift[0]))); while (gid < elements) { T out = nearbyint(in_ptr[gid] / scale); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? 
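      // Note on the "cambricon" formula: unlike the "google" path, the scale tensor here carries
      // a shift exponent rather than a step size, so the effective step is scale = 2^shift[0].
      // For example (illustrative numbers), shift[0] = -3 gives a step of 0.125 and, with
      // quantization_bit = 8, the representable values are the multiples of 0.125 in
      // [-128 * 0.125, 127 * 0.125] = [-16, 15.875].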
lower_bound : out; out_ptr[gid] = out * scale; gid += step; } } } // namespace template class GpuFakeQuantizationKernel final : public user_op::OpKernel { public: GpuFakeQuantizationKernel() = default; ~GpuFakeQuantizationKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); const user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const std::string quantization_scheme = ctx->Attr("quantization_scheme"); const int32_t quantization_bit = ctx->Attr("quantization_bit"); const std::string quantization_formula = ctx->Attr("quantization_formula"); const int64_t elements = in->shape_view().elem_cnt(); const int64_t panel_size = in->shape_view().Count(1); const int64_t scale_size = scale->shape_view().elem_cnt(); // round to even auto origin_round_mode = std::fegetround(); std::fesetround(FE_TONEAREST); if (quantization_formula == "google") { if (quantization_scheme == "symmetric") { RUN_CUDA_KERNEL((FakeQuantizationSymmetric), ctx->stream(), elements, in->dptr(), scale->dptr(), scale_size, elements, panel_size, quantization_bit, out->mut_dptr()); } else { // quantization_scheme == "affine" RUN_CUDA_KERNEL((FakeQuantizationAffine), ctx->stream(), elements, in->dptr(), scale->dptr(), zero_point->dptr(), scale_size, elements, panel_size, quantization_bit, out->mut_dptr()); } } else if (quantization_formula == "cambricon") { RUN_CUDA_KERNEL((FakeQuantizationCambricon), ctx->stream(), elements, in->dptr(), scale->dptr(), scale_size, elements, panel_size, quantization_bit, out->mut_dptr()); } else { UNIMPLEMENTED(); } std::fesetround(origin_round_mode); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FAKE_QUANTIZATION_KERNEL(dtype) \ REGISTER_USER_KERNEL("fake_quantization") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_FAKE_QUANTIZATION_KERNEL(float); REGISTER_FAKE_QUANTIZATION_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fft_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/fft_kernel_util.h" #include #include "pocketfftplan.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/framework/user_op_tensor.h" namespace oneflow { template static void _conj_symmetry_cpu(T* data_out, const Shape& shape, const std::vector& strides, const int64_t last_dim, int64_t elem_count) { const oneflow::NdIndexStrideOffsetHelper helper(strides.data(), shape.size()); // NOTE: dims must be sorted int64_t last_dim_size = shape[last_dim]; int64_t last_dim_half = last_dim_size / 2; int64_t ndim = shape.size(); std::vector indices(ndim); for (int offset = 0; offset < elem_count; offset++) { helper.OffsetToNdIndex(offset, indices.data(), ndim); if (indices[last_dim] <= last_dim_half) { continue; } int64_t cur_last_dim_index = indices[last_dim]; // get symmetric indices[last_dim] = last_dim_size - cur_last_dim_index; int64_t symmetric_offset = helper.NdIndexToOffset(indices.data(), ndim); // conj data_out[offset] = std::conj(data_out[symmetric_offset]); } } template struct FillConjSymmetryUtil { static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape, const Stride& strides, const int64_t last_dim, int64_t elem_count) { std::vector strides_vec(strides.begin(), strides.end()); _conj_symmetry_cpu(/*data_out*/ data_out, /*shape*/ shape, /*strides*/ strides_vec, /*last_dim*/ last_dim, /*elem_count*/ elem_count); } }; template struct ComplexConvertUtil { static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst, size_t len, size_t n) { size_t fact_len = 2 * len - 2; // input_shape.back() for (int i = 0; i < n; i++) { int index_x = i / fact_len; int index_y = i % fact_len; if (index_y == 0) { dst[i] = in[index_x * len]; } else if (index_y == len - 1) { dst[i] = in[(index_x + 1) * len - 1]; } else if (index_y < len - 1 && index_y > 0) { dst[i] = in[index_x * len + index_y]; } else { auto index = (index_x + 2) * len - index_y - 2; auto realvalue = in[index].real(); dst[i].real(realvalue); auto imagvalue = -in[index].imag(); dst[i].imag(imagvalue); } } } static void ConvertComplexToReal(ep::Stream* stream, const complex_type* in, real_type* out, size_t n) { for (int i = 0; i < n; i++) { out[2 * i] = in[i].real(); out[2 * i + 1] = in[i].imag(); } } }; template struct FftC2CKernelUtil { static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, const std::vector& dims, FCT_TYPE norm_fct, DataType real_type) { PocketFFtParams params(input_shape, output_shape, input_stride, output_stride, dims, forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::C2C); PocketFFtConfig config(params); config.excute(data_in, data_out); } }; template struct FftR2CKernelUtil { static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, const std::vector& dims, IN norm_fct, DataType real_type) { PocketFFtParams params(input_shape, output_shape, input_stride, output_stride, dims, forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::R2C); PocketFFtConfig config(params); config.excute(data_in, data_out); } }; template struct FftC2RKernelUtil { static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& 
input_stride, const Stride& output_stride, bool forward, int64_t last_dim_size, const std::vector& dims, OUT norm_fct, DataType real_type) { PocketFFtParams params(input_shape, output_shape, input_stride, output_stride, dims, /*is_forward=*/false, norm_fct /*1.f*/, FFT_EXCUTETYPE::C2R); PocketFFtConfig config(params); config.excute(data_in, data_out); } }; template struct FftStftKernelUtil { static void FftStftForward(ep::Stream* stream, const IN* data_in, OUT* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, const std::vector& axes, IN norm_fct, int64_t len, int64_t dims, int64_t batch) { PocketFFtParams params(input_shape, output_shape, input_stride, output_stride, axes, forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::R2C); PocketFFtConfig config(params); int64_t in_offset = len; int64_t out_offset = len / 2 + 1; for (int j = 0; j < dims; j++) { for (int i = 0; i < batch; i++) { const IN* in = data_in + j * batch * in_offset + i * in_offset; OUT* out = data_out + j * batch * out_offset + i * out_offset; config.excute(in, out); } } } }; template struct FillConjSymmetryUtil>; template struct FillConjSymmetryUtil>; template struct ComplexConvertUtil>; template struct ComplexConvertUtil>; template struct FftC2CKernelUtil, float>; template struct FftC2CKernelUtil, double>; template struct FftR2CKernelUtil>; template struct FftR2CKernelUtil>; template struct FftC2RKernelUtil, float>; template struct FftC2RKernelUtil, double>; template struct FftStftKernelUtil>; template struct FftStftKernelUtil>; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fft_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/user_op_tensor.h" #include "oneflow/user/kernels/to_contiguous_kernel.h" #if CUDA_VERSION >= 11000 #include "oneflow/user/kernels/fft_kernel_util.h" #include "cufft_plan_cache.h" namespace oneflow { namespace { template __global__ void fft_apply_normalization(FFTTYPE* dst, const double normalization_scale, size_t n, bool IsNormalized) { if (!IsNormalized) { return; } CUDA_1D_KERNEL_LOOP(i, n) { dst[i].x *= normalization_scale; dst[i].y *= normalization_scale; }; } struct FillConjSymmetricParams { int64_t last_dim; int64_t elem_count; int64_t ndim; oneflow::NdIndexStrideOffsetHelper helper; int64_t last_dim_size; int64_t last_dim_half; FillConjSymmetricParams() = default; FillConjSymmetricParams(const Shape& shape, const Stride& strides, int64_t last_dim_, int64_t elemcnt) : last_dim(last_dim_), elem_count(elemcnt), ndim(strides.size()), helper(strides.data(), ndim) { CHECK_OR_THROW(strides.size() == shape.size()); last_dim_size = shape[last_dim]; last_dim_half = last_dim_size / 2; } }; } // namespace template __global__ void _conj_symmetry_cuda(T* data_out, FillConjSymmetricParams param) { CUDA_1D_KERNEL_LOOP_T(int64_t, offset, param.elem_count) { int64_t ndim = param.ndim; int64_t indices[SHAPE_MAX_AXIS_SIZE]; param.helper.OffsetToNdIndex(offset, indices, ndim); if (indices[param.last_dim] <= param.last_dim_half) { continue; } int64_t cur_last_dim_index = indices[param.last_dim]; // get symmetric indices[param.last_dim] = param.last_dim_size - cur_last_dim_index; int64_t symmetric_offset = param.helper.NdIndexToOffset(indices, ndim); // conj data_out[offset] = T{data_out[symmetric_offset].x, -data_out[symmetric_offset].y}; } } template struct FillConjSymmetryUtil { static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape, const Stride& strides, const int64_t last_dim, int64_t elem_count) { FillConjSymmetricParams param(shape, strides, last_dim, elem_count); _conj_symmetry_cuda<<As()->cuda_stream()>>>(data_out, param); } }; template __global__ void _convert_to_double_sized(const IN* in, OUT* dst, size_t len, size_t n) { size_t fact_len = 2 * len - 2; CUDA_1D_KERNEL_LOOP(i, n) { int index_x = i / fact_len; int index_y = i % fact_len; if (index_y == 0) { dst[i] = in[index_x * len]; } else if (index_y == len - 1) { dst[i] = in[(index_x + 1) * len - 1]; } else if (index_y < len - 1 && index_y > 0) { dst[i] = in[index_x * len + index_y]; } else { auto index = (index_x + 2) * len - index_y - 2; dst[i].x = in[index].x; dst[i].y = -in[index].y; } } } template __global__ void _convert_complex_to_real(const IN* in, OUT* out, size_t n) { CUDA_1D_KERNEL_LOOP(i, n) { out[2 * i] = in[i].x; out[2 * i + 1] = in[i].y; }; } template struct ComplexConvertUtil { static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst, size_t len, size_t n) { _convert_to_double_sized<<As()->cuda_stream()>>>(in, dst, len, n); } static void ConvertComplexToReal(ep::Stream* stream, const complex_type* in, real_type* out, size_t n) { _convert_complex_to_real<<As()->cuda_stream()>>>(in, out, n); } }; template class StftGpuKernel final : public user_op::OpKernel { public: StftGpuKernel() = default; ~StftGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); 
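    // The R2C transform below materializes only the non-redundant half of the spectrum: a real
    // frame of length fft_size yields fft_size / 2 + 1 unique complex bins (the rest follow from
    // Hermitian symmetry), e.g. fft_size = 16 -> 9 bins. That is why out_shape ends in
    // fft_size / 2 + 1; when onesided == false, ConvertToDoubleSized later mirrors the conjugate
    // half back so the last dimension regains its full length
    // (2 * (fft_size / 2 + 1) - 2 == fft_size for even fft_size).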
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const bool normalized = ctx->Attr("normalized"); const bool onesided = ctx->Attr("onesided"); const bool return_complex = ctx->Attr("return_complex"); const ShapeView& input_shape = input->shape_view(); const ShapeView& output_shape = output->shape_view(); const Stride& input_stride = input->stride(); const int out_elem_cnt = return_complex ? output->shape_view().elem_cnt() : output->shape_view().elem_cnt() / 2; const dtype_in* data_in = input->dptr(); dtype_in* data_out = output->mut_dptr(); dtype_out* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr()); int64_t ndim = 1; int64_t batch = static_cast(input_shape.At(1)); int64_t fft_size = static_cast(input_shape.At(2)); int64_t rank[1] = {fft_size}; const Stride& in_stride = {input_stride.at(1), input_stride.at(2)}; const Shape& in_shape = {batch, fft_size}; const Shape& out_shape = {batch, fft_size / 2 + 1}; Stride out_stride = Stride(out_shape); CuFFTParams params(in_shape, out_shape, in_stride, out_stride, ndim, CUFFT_EXCUTETYPE::R2C, input->data_type()); CuFFTConfig config(params); auto& plan = config.plan(); OF_CUFFT_CHECK(cufftSetStream(plan, ctx->stream()->As()->cuda_stream())); void* workspace{}; OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size())); OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace)); int64_t in_offset = input_stride.at(0); int64_t out_offset = std::accumulate(out_shape.begin(), out_shape.end(), 0, std::multiplies()); int64_t signal_groups_count = static_cast(input_shape.At(0)); for (int64_t i = 0; i < signal_groups_count; i++) { config.excute((void*)(data_in + i * in_offset), (void*)(out_tmp_buffer + i * out_offset), /*forward=*/true); } OF_CUDA_CHECK(cudaFree(workspace)); if (!onesided) { size_t last_dim_length = fft_size / 2 + 1; dtype_out* doublesided_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr()) + out_elem_cnt; ComplexConvertUtil::ConvertToDoubleSized( ctx->stream(), out_tmp_buffer, doublesided_tmp_buffer, last_dim_length, out_elem_cnt); out_tmp_buffer = doublesided_tmp_buffer; } const double normalization_scale = _fft_normalization_scale(input_shape.back(), normalized); fft_apply_normalization<<stream()->As()->cuda_stream()>>>( out_tmp_buffer, normalization_scale, out_elem_cnt, normalized); if (!return_complex) { ComplexConvertUtil::ConvertComplexToReal( ctx->stream(), out_tmp_buffer, data_out, out_elem_cnt); } else { // TODO(yzm):support return_complex after oneflow supports complex numbers } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_STFT_GPU_KERNEL(intype, outtype) \ REGISTER_USER_KERNEL("stft") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& output_shape = ctx->InputShape("output", 0); \ const bool return_complex = ctx->Attr("return_complex"); \ const bool onesided = ctx->Attr("onesided"); \ int64_t output_elem_cnt = \ return_complex ? output_shape.elem_cnt() : output_shape.elem_cnt() / 2; \ const int64_t output_bytes = GetCudaAlignedSize(output_elem_cnt * sizeof(outtype)); \ return onesided ? 
output_bytes : 2 * output_bytes; \ }); REGISTER_STFT_GPU_KERNEL(float, cufftComplex) REGISTER_STFT_GPU_KERNEL(double, cufftDoubleComplex) template class FftC2CKernelUtil { static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, const std::vector& dims, FCT_TYPE normalization, DataType real_type) { // NOTE: before calling `FftC2CKernelUtil`, input must be // batched out already CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(), CUFFT_EXCUTETYPE::C2C, real_type); CuFFTConfig config(params); auto& plan = config.plan(); OF_CUFFT_CHECK(cufftSetStream(plan, stream->As()->cuda_stream())); void* workspace{}; OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size())); OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace)); config.excute((void*)data_in, (void*)data_out, forward); OF_CUDA_CHECK(cudaFree(workspace)); } }; template struct FftR2CKernelUtil { static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, const std::vector& dims, IN normalization, DataType real_type) { // NOTE: before calling `FftR2CKernelUtil`, input must be batched // out already CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(), CUFFT_EXCUTETYPE::R2C, real_type); CuFFTConfig config(params); auto& plan = config.plan(); OF_CUFFT_CHECK(cufftSetStream(plan, stream->As()->cuda_stream())); void* workspace{}; OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size())); OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace)); config.excute((void*)data_in, (void*)data_out, forward); OF_CUDA_CHECK(cudaFree(workspace)); } }; template struct FftC2RKernelUtil { static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, int64_t last_dim_size, const std::vector& dims, OUT normalization, DataType real_type) { // NOTE: before calling `FftC2RKernelUtil`, input must be batched // out already CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(), CUFFT_EXCUTETYPE::C2R, real_type); CuFFTConfig config(params); auto& plan = config.plan(); OF_CUFFT_CHECK(cufftSetStream(plan, stream->As()->cuda_stream())); void* workspace{}; OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size())); OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace)); config.excute((void*)data_in, (void*)data_out, forward); OF_CUDA_CHECK(cudaFree(workspace)); } }; template struct FillConjSymmetryUtil; template struct FillConjSymmetryUtil; template struct ComplexConvertUtil; template struct ComplexConvertUtil; template struct FftC2CKernelUtil; template struct FftC2CKernelUtil; template struct FftR2CKernelUtil; template struct FftR2CKernelUtil; template struct FftC2RKernelUtil; template struct FftC2RKernelUtil; } // namespace oneflow #endif // CUDA_VERSION >= 11000 ================================================ FILE: oneflow/user/kernels/fft_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_ #include #include #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { template inline T _fft_normalization_scale(const int32_t frame_length, bool normalized) { if (!normalized) { return static_cast(1.0); } return static_cast(1.0 / std::sqrt(frame_length)); } template struct FillConjSymmetryUtil { static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape, const Stride& strides, const int64_t last_dim, int64_t elem_count); }; template struct ComplexConvertUtil { static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst, size_t len, size_t n); static void ConvertComplexToReal(ep::Stream* stream, const complex_type* in, real_type* out, size_t n); }; template struct FftC2CKernelUtil { static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, const std::vector& dims, FCT_TYPE norm_fct, DataType real_type); }; template struct FftR2CKernelUtil { static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, const std::vector& dims, IN norm_fct, DataType real_type); }; template struct FftC2RKernelUtil { static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, int64_t last_dim_size, const std::vector& dims, OUT norm_fct, DataType real_type); }; template struct FftStftKernelUtil { static void FftStftForward(ep::Stream* stream, const IN* data_in, OUT* data_out, const Shape& input_shape, const Shape& output_shape, const Stride& input_stride, const Stride& output_stride, bool forward, const std::vector& axes, IN norm_fct, int64_t len, int64_t dims, int64_t batch); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/fft_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include #include "pocketfftplan.h" #include "oneflow/core/common/stride.h" #include "oneflow/user/kernels/fft_kernel_util.h" using namespace pocketfft; namespace oneflow { template class FftC2CKernel final : public user_op::OpKernel { public: FftC2CKernel() = default; ~FftC2CKernel() = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); bool forward = ctx->Attr("forward"); double norm_fct = ctx->Attr("norm_fct"); const std::vector& dims = ctx->Attr>("dims"); const T* input_ptr = input->dptr(); T* out_ptr = out->mut_dptr(); Shape input_shape(input->shape_view()); Shape out_shape(out->shape_view()); if (input->data_type() == kComplex64) { FftC2CKernelUtil::FftC2CForward( ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(), forward, dims, static_cast(norm_fct), DataType::kFloat); } else if (input->data_type() == kComplex128) { FftC2CKernelUtil::FftC2CForward( ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(), forward, dims, static_cast(norm_fct), DataType::kDouble); } else { CHECK_OR_THROW(false) << "expects kComplex64 or kComplex128, but got " << input->data_type(); } } }; template class FftR2CKernel final : public user_op::OpKernel { public: FftR2CKernel() = default; ~FftR2CKernel() = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); bool onesided = ctx->Attr("onesided"); double norm_fct = ctx->Attr("norm_fct"); const std::vector& dims = ctx->Attr>("dims"); const dtype_in* input_ptr = input->dptr(); dtype_out* out_ptr = out->mut_dptr(); Shape input_shape(input->shape_view()); Shape out_shape(out->shape_view()); if (input->data_type() == kFloat || input->data_type() == kDouble) { FftR2CKernelUtil::FftR2CForward( ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(), /*forward=*/true, dims, norm_fct, /*real_type=*/input->data_type()); } else { CHECK_OR_THROW(false) << "expects kFloat or kDouble, but gets " << input->data_type(); } if (!onesided) { FillConjSymmetryUtil::FillConjSymmetryForward( ctx->stream(), out_ptr, out_shape, out->stride(), dims.back(), out_shape.elem_cnt()); } } }; template class FftC2RKernel final : public user_op::OpKernel { public: FftC2RKernel() = default; ~FftC2RKernel() = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t last_dim_size = ctx->Attr("last_dim_size"); double norm_fct = ctx->Attr("norm_fct"); const std::vector& dims = ctx->Attr>("dims"); const dtype_in* input_ptr = input->dptr(); dtype_out* out_ptr = out->mut_dptr(); Shape input_shape(input->shape_view()); Shape out_shape(out->shape_view()); out_shape[dims.back()] = last_dim_size; if (input->data_type() == kComplex64 || input->data_type() == kComplex128) { FftC2RKernelUtil::FftC2RForward( ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), 
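        // Recovering the real length from a one-sided complex input is ambiguous on its own
        // (a last dimension of k complex bins can come from a real signal of length 2k - 2 or
        // 2k - 1), which is why out_shape[dims.back()] was overwritten above with the explicit
        // last_dim_size attribute before dispatching, e.g. 5 bins -> 8 or 9 real samples.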
out->stride(), /*forward=*/false, /*last_dim_size=*/last_dim_size, dims, norm_fct, /*real_type=*/out->data_type()); } else { CHECK_OR_THROW(false) << "expects kComplex64 or kComplex128, but gets " << input->data_type(); } } }; template class StftCpuKernel final : public user_op::OpKernel { public: StftCpuKernel() = default; ~StftCpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto normalized = ctx->Attr("normalized"); const auto return_complex = ctx->Attr("return_complex"); const bool onesided = ctx->Attr("onesided"); const ShapeView input_shape = input->shape_view(); const ShapeView output_shape = output->shape_view(); const auto output_elem_cnt = output_shape.elem_cnt() / 2; int64_t dims = input_shape.At(0); int64_t batch = input_shape.At(1); int64_t len = input_shape.back(); const dtype_in* data_in = input->dptr(); dtype_in* data_out = output->mut_dptr(); dtype_out* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr()); Shape out_tmp_shape = Shape{len}; Stride out_tmp_stride = Stride(out_tmp_shape); std::vector axes(out_tmp_shape.size()); std::iota(axes.begin(), axes.end(), 0); auto norm_fct = _fft_normalization_scale(len, normalized); FftStftKernelUtil::FftStftForward( ctx->stream(), data_in, out_tmp_buffer, out_tmp_shape, out_tmp_shape, out_tmp_stride, out_tmp_stride, true, /*axes=*/axes, /*norm_fct=*/norm_fct, /*len=*/len, /*dims=*/dims, /*batch=*/batch); if (!onesided) { dtype_out* doublesided_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr()) + output_elem_cnt; size_t last_dim_length = len / 2 + 1; size_t elem_conut = output_elem_cnt; ComplexConvertUtil::ConvertToDoubleSized( ctx->stream(), out_tmp_buffer, doublesided_tmp_buffer, last_dim_length, elem_conut); out_tmp_buffer = doublesided_tmp_buffer; } if (!return_complex) { ComplexConvertUtil::ConvertComplexToReal( ctx->stream(), out_tmp_buffer, data_out, output_elem_cnt); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_STFT_CPU_KERNEL(dtype_in, dtype_out) \ REGISTER_USER_KERNEL("stft") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& output_shape = ctx->InputShape("output", 0); \ const bool return_complex = ctx->Attr("return_complex"); \ const bool onesided = ctx->Attr("onesided"); \ int64_t output_elem_cnt = \ return_complex ? output_shape.elem_cnt() : output_shape.elem_cnt() / 2; \ const int64_t output_bytes = (output_elem_cnt * sizeof(dtype_out)); \ return onesided ? 
output_bytes : 2 * output_bytes; \ }); REGISTER_STFT_CPU_KERNEL(double, std::complex) REGISTER_STFT_CPU_KERNEL(float, std::complex) #define REGISTER_FFTC2C_KERNELS(device_type, dtype, fct_type) \ REGISTER_USER_KERNEL("fft_c2c") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) REGISTER_FFTC2C_KERNELS(DeviceType::kCPU, std::complex, float); REGISTER_FFTC2C_KERNELS(DeviceType::kCPU, std::complex, double); #ifdef WITH_CUDA REGISTER_FFTC2C_KERNELS(DeviceType::kCUDA, cuComplex, float); REGISTER_FFTC2C_KERNELS(DeviceType::kCUDA, cuDoubleComplex, double); #endif #define REGISTER_FFTR2C_KERNELS(device_type, dtype_in, dtype_out) \ REGISTER_USER_KERNEL("fft_r2c") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) REGISTER_FFTR2C_KERNELS(DeviceType::kCPU, float, std::complex); REGISTER_FFTR2C_KERNELS(DeviceType::kCPU, double, std::complex); #ifdef WITH_CUDA REGISTER_FFTR2C_KERNELS(DeviceType::kCUDA, float, cuComplex); REGISTER_FFTR2C_KERNELS(DeviceType::kCUDA, double, cuDoubleComplex); #endif #define REGISTER_FFTC2R_KERNELS(device_type, dtype_in, dtype_out) \ REGISTER_USER_KERNEL("fft_c2r") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) REGISTER_FFTC2R_KERNELS(DeviceType::kCPU, std::complex, float); REGISTER_FFTC2R_KERNELS(DeviceType::kCPU, std::complex, double); #ifdef WITH_CUDA REGISTER_FFTC2R_KERNELS(DeviceType::kCUDA, cuComplex, float); REGISTER_FFTC2R_KERNELS(DeviceType::kCUDA, cuDoubleComplex, double); #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fill_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/scalar.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { template std::unique_ptr NewFillPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } } // namespace class FillKernel final : public user_op::OpKernel { public: FillKernel() = default; ~FillKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const bool is_floating_value = ctx->Attr("is_floating_value"); const Scalar value = is_floating_value ? 
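    // The op carries both a floating_value and an integral_value attribute, and the flag picks
    // which one feeds the Scalar, so integer fills are not routed through a double. That matters
    // for wide integers: a double has a 53-bit significand, so e.g. 2^53 + 1 = 9007199254740993
    // cannot be represented exactly and would silently round if passed as a floating value.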
Scalar(ctx->Attr("floating_value")) : Scalar(ctx->Attr("integral_value")); const int32_t elem_cnt = in->shape_view().elem_cnt(); CHECK_GE(elem_cnt, 0); if (elem_cnt == 0) { return; } std::unique_ptr fill = NewFillPrimitive(ctx); CHECK(fill); fill->Launch(ctx->stream(), out->mut_dptr(), value, elem_cnt); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto FillPrimitiveExists() { return hob::make_custom("FillPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewFillPrimitive(&ctx).operator bool(); }); } template class FillTensorCpuKernel final : public user_op::OpKernel { public: FillTensorCpuKernel() = default; ~FillTensorCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); const T value_ = value->dptr()[0]; const int32_t elem_cnt = in->shape_view().elem_cnt(); T* out_ptr = out->mut_dptr(); FOR_RANGE(int32_t, i, 0, elem_cnt) { out_ptr[i] = value_; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FILL_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("fill_tensor_") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_FILL_CPU_KERNEL(float) REGISTER_FILL_CPU_KERNEL(float16) REGISTER_FILL_CPU_KERNEL(double) REGISTER_FILL_CPU_KERNEL(int8_t) REGISTER_FILL_CPU_KERNEL(int32_t) REGISTER_FILL_CPU_KERNEL(int64_t) REGISTER_USER_KERNEL("fill_").SetCreateFn().SetIsMatchedHob(FillPrimitiveExists() == true); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fill_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template __global__ void FillTensorGpuForward(const int n, const T* value, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { y[i] = value[0]; } } }; // namespace template class FillTensorGpuKernel final : public user_op::OpKernel { public: FillTensorGpuKernel() = default; ~FillTensorGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); const int32_t elem_cnt = in->shape_view().elem_cnt(); RUN_CUDA_KERNEL((FillTensorGpuForward), ctx->stream(), elem_cnt, elem_cnt, value->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FILL_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fill_tensor_") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_FILL_CUDA_KERNEL(float) REGISTER_FILL_CUDA_KERNEL(half) REGISTER_FILL_CUDA_KERNEL(double) REGISTER_FILL_CUDA_KERNEL(int8_t) REGISTER_FILL_CUDA_KERNEL(int32_t) REGISTER_FILL_CUDA_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/flip_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace { const int32_t NDIMS = 16; struct SIZE_V { int32_t val[NDIMS]; }; struct VIS { bool val[NDIMS] = {false}; }; template void FlipCpuForward(const int32_t element, const int64_t total_dims, const SIZE_V sizes_v, const VIS vis, SIZE_V strides_v, const T* in_dptr, T* out_dptr) { for (int i = 0; i < element; i++) { int32_t cur_indices = i; int32_t rem = 0; int32_t dst_offset = 0; for (int32_t d = 0; d < total_dims; d++) { int32_t temp = cur_indices; cur_indices = cur_indices / strides_v.val[d]; rem = temp - cur_indices * strides_v.val[d]; dst_offset += vis.val[d] ? 
(sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d] : cur_indices * strides_v.val[d]; cur_indices = rem; } out_dptr[i] = in_dptr[dst_offset]; } } } // namespace template class FlipCpuKernel final : public user_op::OpKernel { public: FlipCpuKernel() = default; ~FlipCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = y_tensor->shape_view().elem_cnt(); if (elem_cnt == 0) { return; } const int32_t total_dims = y_tensor->shape_view().NumAxes(); std::vector dims = ctx->Attr>("dims"); VIS vis; for (auto x : dims) { vis.val[x] = true; } SIZE_V sizes_v; for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape_view().At(i); } // TODO(bbuf) delete strides caluculate, after tensor strides supported SIZE_V strides_v; strides_v.val[total_dims - 1] = 1; for (int32_t i = total_dims - 2; i >= 0; i--) { strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape_view().At(i + 1); } FlipCpuForward(elem_cnt, total_dims, sizes_v, vis, strides_v, x_tensor->dptr(), y_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FLIP_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("flip").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_FLIP_CPU_KERNEL(bool) REGISTER_FLIP_CPU_KERNEL(float) REGISTER_FLIP_CPU_KERNEL(double) REGISTER_FLIP_CPU_KERNEL(uint8_t) REGISTER_FLIP_CPU_KERNEL(int8_t) REGISTER_FLIP_CPU_KERNEL(int32_t) REGISTER_FLIP_CPU_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/flip_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { namespace { const int32_t NDIMS = 16; struct SIZE_V { int32_t val[NDIMS]; }; struct VIS { bool val[NDIMS] = {false}; }; template __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims, const SIZE_V sizes_v, const VIS vis, SIZE_V strides_v, const T* in_dptr, T* out_dptr) { CUDA_1D_KERNEL_LOOP(i, element) { int32_t cur_indices = i; int32_t rem = 0; int32_t dst_offset = 0; for (int32_t d = 0; d < total_dims; d++) { int32_t temp = cur_indices; cur_indices = cur_indices / strides_v.val[d]; rem = temp - cur_indices * strides_v.val[d]; dst_offset += vis.val[d] ? 
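      // Same index arithmetic as the CPU kernel, but each CUDA thread resolves its own mirrored
      // source offset, so reads from in_dptr are scattered while writes to out_dptr[i] stay
      // coalesced. The strides come from y_tensor->stride() in the launcher below, and the common
      // case of flipping only the innermost dimension is routed to the specialized
      // FlipLastDimGpuForward kernel instead.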
(sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d] : cur_indices * strides_v.val[d]; cur_indices = rem; } out_dptr[i] = in_dptr[dst_offset]; } } /* Example tensor: [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14]] Given parameters: BlockSize=4, GridSize=4 For each block_i, `block_begin_idx` is calculated as (i - 1) * BlockSize = (i - 1) * 4, and `thread_end_idx` is set to 4 for all blocks except the final block. In the final block, `thread_end_idx` is 2, representing the border index of the active thread. `i_ori` is an index referring to the original position of data stored in shm[threadIdx.x] before flipping. For instance, consider block 1 and thread 2 (element 6). The element is located at row 0, column 7 in the tensor. Its original index `i_ori` is 7, and after flipping, it is mapped to row 0, column 0. ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐ global mem before: │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ 9 │ A │ B │ C │ D │ x │ x │ └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘ block0 │ block1 │ block2 │ block3 ┌───┬───┬───┬───┼───┬───┬───┬───┼───┬───┬───┬───┼───┬───┬───┬───┐ shm after loading: │ 3 │ 2 │ 1 │ 0 │ 7 │ 6 │ 5 │ 4 │ B │ A │ 9 │ 8 │ D │ C │ x │ x │ └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘ ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐ global mem after: │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │ D │ C │ B │ A │ 9 │ 8 │ 7 │ x │ x │ └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘ */ template __global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size, const T* in_dptr, T* out_dptr) { __shared__ T shm[ep::CudaStream::kDefaultBlockSize]; CUDA_1D_KERNEL_LOOP(i, element) { int32_t block_begin_idx = blockDim.x * blockIdx.x; int32_t thread_end_idx = min(block_begin_idx + blockDim.x, element) - block_begin_idx; int32_t i_ori = block_begin_idx + (thread_end_idx - threadIdx.x - 1); shm[threadIdx.x] = in_dptr[i_ori]; __syncthreads(); int32_t row = i_ori / last_dim_size; int32_t col = last_dim_size - (i_ori - row * last_dim_size) - 1; out_dptr[row * last_dim_size + col] = shm[threadIdx.x]; } } } // namespace template class FlipGpuKernel final : public user_op::OpKernel { public: FlipGpuKernel() = default; ~FlipGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = y_tensor->shape_view().elem_cnt(); if (elem_cnt == 0) { return; } const int32_t total_dims = y_tensor->shape_view().NumAxes(); std::vector dims = ctx->Attr>("dims"); VIS vis; for (auto x : dims) { vis.val[x] = true; } if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { RUN_CUDA_KERNEL((FlipLastDimGpuForward), ctx->stream(), elem_cnt, elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), y_tensor->mut_dptr()); return; } SIZE_V sizes_v; for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape_view().At(i); } SIZE_V strides_v; for (int32_t i = 0; i < total_dims; i++) { strides_v.val[i] = CHECK_JUST(VectorAt(y_tensor->stride(), i)); } RUN_CUDA_KERNEL((FlipGpuForward), ctx->stream(), elem_cnt, elem_cnt, total_dims, sizes_v, vis, strides_v, x_tensor->dptr(), y_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FLIP_CUDA_KERNEL(dtype) \ 
REGISTER_USER_KERNEL("flip").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_FLIP_CUDA_KERNEL(bool) REGISTER_FLIP_CUDA_KERNEL(float) REGISTER_FLIP_CUDA_KERNEL(half) REGISTER_FLIP_CUDA_KERNEL(double) REGISTER_FLIP_CUDA_KERNEL(uint8_t) REGISTER_FLIP_CUDA_KERNEL(int8_t) REGISTER_FLIP_CUDA_KERNEL(int32_t) REGISTER_FLIP_CUDA_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fold_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/operator/operator_util.h" #include "oneflow/user/kernels/fold_kernel_util.h" namespace oneflow { namespace user_op { namespace { // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template class FoldOpKernelState : public OpKernelState { public: using ParamType = FoldParams; FoldOpKernelState(const ShapeView& input_shape, const std::vector& output_size, const std::vector& kernel_size, const std::vector& padding, const std::vector& stride, const std::vector& dilation) : params_(input_shape.At(0), input_shape.At(ParamType::kInputChannelDim), output_size.data(), input_shape.ptr() + SDIM, kernel_size.data(), padding.data(), stride.data(), dilation.data()) {} const ParamType& params() const { return params_; } private: ParamType params_; }; template std::shared_ptr> CreateFoldOpKernelState( const ShapeView& input_shape, const std::vector& output_size, const std::vector& kernel_size, const std::vector& padding, const std::vector& stride, const std::vector& dilation) { std::shared_ptr> state( new FoldOpKernelState(input_shape, output_size, kernel_size, padding, stride, dilation)); return state; } template class FoldKernel final : public OpKernel { public: FoldKernel() = default; ~FoldKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* input = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* output = ctx->Tensor4ArgNameAndIndex("y", 0); const std::vector output_size = ctx->Attr>("output_size"); const std::vector kernel_size = ctx->Attr>("kernel_size"); const std::vector dilation = ctx->Attr>("dilation_rate"); const std::vector padding = ctx->Attr>("padding"); const std::vector stride = ctx->Attr>("strides"); const auto& state_ptr = CreateFoldOpKernelState( input->shape_view(), output_size, kernel_size, padding, stride, dilation); const FoldParams params = state_ptr->params(); size_t out_bytes_size = output->shape_view().elem_cnt() * GetSizeOfDataType(output->data_type()); Memset(ctx->stream(), output->mut_dptr(), 0, out_bytes_size); FoldKernelUtil::Forward( ctx->stream(), ¶ms, input->dptr(), output->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace // Currently support 4-D tensor 
and NCHW format #define REGISTER_FOLD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("fold") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_FOLD_KERNEL(DeviceType::kCPU, float) REGISTER_FOLD_KERNEL(DeviceType::kCPU, double) #ifdef WITH_CUDA REGISTER_FOLD_KERNEL(DeviceType::kCUDA, float) REGISTER_FOLD_KERNEL(DeviceType::kCUDA, double) #endif // WITH_CUDA } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fold_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/fold_kernel_util.h" namespace oneflow { namespace user_op { // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template struct FoldKernelUtil { using ParamType = FoldParams; static void Forward(ep::Stream* stream, const void* raw_params, const T* input_ptr, T* output_ptr) { const auto* params = static_cast(raw_params); for (INDEX_T in_offset = 0; in_offset < params->in_elem_cnt; ++in_offset) { using ParamType = FoldParams; INDEX_T in_index[ParamType::kInputNDim] = {0}; INDEX_T out_index[ParamType::kOutputNDim] = {0}; params->in_index_helper.OffsetToNdIndex(in_offset, in_index); if (!FoldIndexTransform(*params, in_index, out_index)) { INDEX_T out_offset = params->out_index_helper.NdIndexToOffset(out_index); XPUAdd::Invoke(&input_ptr[in_offset], &output_ptr[out_offset]); } else { continue; } } } }; INSTANTIATE_FOLD_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCPU) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fold_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/fold_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { constexpr int kBlockSize = cuda::elementwise::kBlockSize; int GetNumBlocks(int64_t elem_cnt) { int num_blocks = 0; OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks)); return num_blocks; } // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template __global__ void CudaFoldForward(FoldParams params, const T* input_ptr, T* output_ptr) { CUDA_1D_KERNEL_LOOP_T(INDEX_T, in_offset, params.in_elem_cnt) { using ParamType = FoldParams; INDEX_T in_index[ParamType::kInputNDim] = {0}; INDEX_T out_index[ParamType::kOutputNDim] = {0}; params.in_index_helper.OffsetToNdIndex(in_offset, in_index); if (!FoldIndexTransform(params, in_index, out_index)) { INDEX_T out_offset = params.out_index_helper.NdIndexToOffset(out_index); XPUAdd::Invoke(&input_ptr[in_offset], &output_ptr[out_offset]); } else { continue; } } } } // namespace template struct FoldKernelUtil { using ParamType = FoldParams; static void Forward(ep::Stream* stream, const void* raw_params, const T* input_ptr, T* output_ptr) { const auto* fold_params = static_cast(raw_params); CudaFoldForward <<in_elem_cnt), kBlockSize, 0, stream->As()->cuda_stream()>>>(*fold_params, input_ptr, output_ptr); } }; INSTANTIATE_FOLD_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCUDA) } // namespace user_op } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/fold_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_FOLD_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_FOLD_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/ndarray/xpu_util.h" #ifdef WITH_CUDA #include "oneflow/core/cuda/atomic.cuh" #endif // WITH_CUDA namespace oneflow { namespace user_op { namespace { template struct XPUAdd { OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { #if defined(__CUDA_ARCH__) cuda::atomic::Add(y, *x); #else *y += *x; #endif }; }; } // namespace // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template struct FoldParams { static constexpr int kInputNDim = NDIM * 2 + 2; static constexpr int kOutputNDim = NDIM + 2; static constexpr int kInputChannelDim = (2 - SDIM) * NDIM * 2 + 1; static constexpr int kOutputChannelDim = (2 - SDIM) * NDIM + 1; static_assert(kInputChannelDim < kInputNDim, ""); static_assert(kOutputChannelDim < kOutputNDim, ""); FoldParams(const int64_t batch_size, const int64_t channels, const int32_t* output_size, const int64_t* spatial_dims, const int32_t* kernel_size, const int32_t* padding, const int32_t* stride, const int32_t* dilation); INDEX_T in_elem_cnt; INDEX_T out_elem_cnt; INDEX_T dims[NDIM]; int padding[NDIM]; int stride[NDIM]; int dilation[NDIM]; NdIndexOffsetHelper in_index_helper; NdIndexOffsetHelper out_index_helper; }; template FoldParams::FoldParams(const int64_t batch_size, const int64_t channels_columns, const int32_t* output_size, const int64_t* spatial_dims, const int32_t* kernel_size, const int32_t* padding, const int32_t* stride, const int32_t* dilation) : in_elem_cnt(0), out_elem_cnt(0), in_index_helper(0), out_index_helper(0) { INDEX_T input_dims[kInputNDim] = {0}; INDEX_T output_dims[kOutputNDim] = {0}; const int32_t channels = channels_columns / (kernel_size[0] * kernel_size[1]); // channels_columns = C*K*K this->in_elem_cnt = batch_size * channels; this->out_elem_cnt = batch_size * channels; input_dims[0] = batch_size; output_dims[0] = batch_size; input_dims[kInputChannelDim] = channels; output_dims[kOutputChannelDim] = channels; for (int d = 0; d < NDIM; ++d) { this->dims[d] = output_size[d]; this->padding[d] = padding[d]; this->stride[d] = stride[d]; this->dilation[d] = dilation[d]; input_dims[SDIM + NDIM + d] = (output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1; input_dims[SDIM + d] = kernel_size[d]; this->in_elem_cnt *= input_dims[SDIM + d] * input_dims[SDIM + NDIM + d]; // N,C*Kh*Kw, H*W output_dims[SDIM + d] = output_size[d]; this->out_elem_cnt *= output_dims[SDIM + d]; } in_index_helper = NdIndexOffsetHelper(input_dims); out_index_helper = NdIndexOffsetHelper(output_dims); } // index_a format: (N, C, D, H, W) or (N, D, H, W, C) // index_b format: (N, C, di, hi, wi, db, hb, wb) or (N, di, hi, wi, db, hb, wb, C) // return: true indicates out-of-bound, otherwise in-bound template OF_DEVICE_FUNC bool FoldIndexTransform(const FoldParams& params, const INDEX_T* index_a, INDEX_T* index_b) { // batch dim index transform index_b[0] = index_a[0]; // channel dim index transform using ParamType = FoldParams; index_b[ParamType::kOutputChannelDim] = index_a[ParamType::kInputChannelDim]; // spatial dim index transform #ifdef __CUDA_ARCH__ #pragma unroll #endif // D,H,W spatial dim index transform for (int64_t d = 0; d < NDIM; ++d) { INDEX_T idx = index_a[SDIM + NDIM + d] * 
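    // This is the im2col mapping run in reverse: for each spatial axis d, the column coordinates
    // (block index, kernel offset) recombine into an output coordinate
    //   idx = block_index * stride[d] + kernel_offset * dilation[d] - padding[d].
    // For example, with stride = 2, dilation = 1, padding = 1, block index 3 and kernel offset 0
    // land at output index 3 * 2 + 0 * 1 - 1 = 5. Positions outside [0, dims[d]) are reported as
    // out-of-bounds (return true), and the callers skip the atomic accumulation for them.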
params.stride[d] + index_a[SDIM + d] * params.dilation[d] - params.padding[d]; if (idx < 0 || idx >= params.dims[d]) return true; index_b[SDIM + d] = idx; } return false; } template struct FoldKernelUtil { static void Forward(ep::Stream* stream, const void* params, const T* input_ptr, T* output_ptr); }; #define SPATIAL_NDIM_SEQ OF_PP_MAKE_TUPLE_SEQ(1) OF_PP_MAKE_TUPLE_SEQ(2) OF_PP_MAKE_TUPLE_SEQ(3) #define SPATIAL_DIM_SEQ OF_PP_MAKE_TUPLE_SEQ(1) OF_PP_MAKE_TUPLE_SEQ(2) #define INSTANTIATE_FOLD_KERNEL_UTIL(device, dtype, itype, ndim, sdim) \ template struct FoldKernelUtil; #define INSTANTIATE_FOLD_KERNEL_UTIL_WITH_TYPE_PAIR(device, dtype_pair, itype_pair, ndim, sdim) \ INSTANTIATE_FOLD_KERNEL_UTIL(device, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(itype_pair), \ ndim, sdim) #define INSTANTIATE_FOLD_KERNEL_UTIL_FOR_DEVICE(device) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_FOLD_KERNEL_UTIL_WITH_TYPE_PAIR, (device), \ FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, SPATIAL_NDIM_SEQ, \ SPATIAL_DIM_SEQ) } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_FOLD_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/frac_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { template class CpuFracKernel final : public user_op::OpKernel { public: CpuFracKernel() = default; ~CpuFracKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = x->shape_view().elem_cnt(); const T* x_ptr = x->dptr(); T* y_ptr = y->mut_dptr(); FOR_RANGE(int32_t, i, 0, elem_cnt) { y_ptr[i] = x_ptr[i] - std::trunc(x_ptr[i]); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_FRAC_KERNEL(dtype) \ REGISTER_USER_KERNEL("frac").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_CPU_FRAC_KERNEL(float) REGISTER_CPU_FRAC_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/frac_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/util/cuda_half_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { // Write ReLU Functor. template struct FracForwardGpu { OF_DEVICE_FUNC T operator()(T x) const { return x - std::trunc(x); } }; } // namespace template class GpuFracKernel final : public user_op::OpKernel { public: GpuFracKernel() = default; ~GpuFracKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = x->shape_view().elem_cnt(); // Use CUDA Elementwise Template. OF_CUDA_CHECK( (cuda::elementwise::Unary(FracForwardGpu(), elem_cnt, y->mut_dptr(), x->dptr(), ctx->stream()->As()->cuda_stream()))); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GPU_FRAC_KERNEL(dtype) \ REGISTER_USER_KERNEL("frac").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_GPU_FRAC_KERNEL(float) REGISTER_GPU_FRAC_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_attention_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUTLASS #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/include/primitive/permute.h" #include "cutlass/arch/mma.h" #include "cutlass/gemm/warp/mma.h" #include "kernel_forward.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "trt_flash_attention/fmha.h" #include "trt_flash_attention/fmha_flash_attention.h" namespace oneflow { namespace user_op { namespace { void ParseDims(const ShapeView& shape, const std::string& layout, const Optional& batch_size, const Optional& seq_len, const Optional& num_heads, const Optional& head_size, int64_t tensor_index, int64_t* b, int64_t* m, int64_t* h, int64_t* k, int64_t* b_stride, int64_t* m_stride, int64_t* h_stride, int64_t* offset, bool* bm_packed) { if (shape.NumAxes() == 2) { if (layout == "(BM)(HK)" || layout == "(BM)(H2K)" || layout == "(BM)(H3K)") { *bm_packed = true; CHECK(batch_size); CHECK(seq_len); *b = CHECK_JUST(batch_size); *m = CHECK_JUST(seq_len); int64_t packed_n = 0; if (layout == "(BM)(HK)") { packed_n = 1; } else if (layout == "(BM)(H2K)") { packed_n = 2; } else if (layout == "(BM)(H3K)") { packed_n = 3; } else { UNIMPLEMENTED(); } const int64_t hidden_size = shape.At(1); if (num_heads) { const int64_t expected_h = CHECK_JUST(num_heads); const int64_t packed_h = packed_n * expected_h; CHECK_EQ(hidden_size % packed_h, 0); *h = expected_h; *k = hidden_size / packed_h; } else if (head_size) { const int64_t expected_k = CHECK_JUST(head_size); const int64_t packed_k = packed_n * expected_k; CHECK_EQ(hidden_size % packed_k, 0); *h = hidden_size / packed_k; *k = expected_k; } else { UNIMPLEMENTED(); } *h_stride = *k * packed_n; *m_stride = *h_stride * *h; *b_stride = 0; if (packed_n == 1) { *offset = 0; } else if (packed_n == 2) { CHECK_GE(tensor_index, 1); *offset = (tensor_index - 1) * *k; } else if (packed_n == 3) { *offset = tensor_index * *k; } else { UNIMPLEMENTED(); } } else { UNIMPLEMENTED(); } } else if (shape.NumAxes() == 3) { if (layout == "BM(HK)" || layout == "BM(H2K)" || layout == "BM(H3K)" || layout == "MB(HK)" || layout == "MB(H2K)" || layout == "MB(H3K)") { *bm_packed = false; bool batch_first = false; int64_t packed_n = 0; const std::string layout_bm = layout.substr(0, 2); const std::string layout_hk = layout.substr(2); if (layout_bm == "BM") { *b = shape.At(0); *m = shape.At(1); batch_first = true; } else if (layout_bm == "MB") { *b = shape.At(1); *m = shape.At(0); batch_first = false; } else { UNIMPLEMENTED(); } if (layout_hk == "(HK)") { packed_n = 1; } else if (layout_hk == "(H2K)") { packed_n = 2; } else if (layout_hk == "(H3K)") { packed_n = 3; } else { UNIMPLEMENTED(); } const int64_t hidden_size = shape.At(2); if (num_heads) { const int64_t expected_h = CHECK_JUST(num_heads); const int64_t packed_h = packed_n * expected_h; CHECK_EQ(hidden_size % packed_h, 0); *h = expected_h; *k = hidden_size / packed_h; } else if (head_size) { const int64_t expected_k = CHECK_JUST(head_size); const int64_t packed_k = packed_n * expected_k; CHECK_EQ(hidden_size % packed_k, 0); *h = hidden_size / packed_k; *k = expected_k; } else { UNIMPLEMENTED(); } *h_stride = *k * packed_n; if (batch_first) { *m_stride = *h_stride * *h; *b_stride = *m_stride * *m; } else { *b_stride = *h_stride * *h; *m_stride = *b_stride * *b; } if (packed_n == 1) { *offset = 0; } else if (packed_n == 2) { CHECK_GE(tensor_index, 1); *offset = (tensor_index - 1) * *k; } else if (packed_n == 3) { *offset = tensor_index * *k; 
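// Note on the packed-QKV offsets computed above: with "(H2K)" layouts only key
// (tensor_index == 1) and value (tensor_index == 2) live in the packed tensor, so the
// element offset is (tensor_index - 1) * k; with "(H3K)" layouts query/key/value are
// packed per head and the offset is tensor_index * k. The head stride is k * packed_n
// in both cases.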
} else { UNIMPLEMENTED(); } } else if (layout == "(BM)HK") { *bm_packed = true; CHECK(batch_size); CHECK(seq_len); *b = CHECK_JUST(batch_size); *m = CHECK_JUST(seq_len); *h = shape.At(1); *k = shape.At(2); *h_stride = *k; *m_stride = *h_stride * *h; *b_stride = 0; } else { UNIMPLEMENTED(); } } else if (shape.NumAxes() == 4) { *bm_packed = false; if (layout == "BMHK") { *b = shape.At(0); *m = shape.At(1); *h = shape.At(2); *k = shape.At(3); *h_stride = *k; *m_stride = *h_stride * *h; *b_stride = *m_stride * *m; } else if (layout == "BHMK") { *b = shape.At(0); *m = shape.At(2); *h = shape.At(1); *k = shape.At(3); *m_stride = *k; *h_stride = *m_stride * *m; *b_stride = *h_stride * *h; } else if (layout == "MBHK") { *b = shape.At(1); *m = shape.At(0); *h = shape.At(2); *k = shape.At(3); *h_stride = *k; *b_stride = *h_stride * *h; *m_stride = *b_stride * *b; } else { UNIMPLEMENTED(); } *offset = 0; } else { UNIMPLEMENTED(); }; if (batch_size) { const int64_t expected_b = CHECK_JUST(batch_size); CHECK_EQ(*b, expected_b); } if (seq_len) { const int64_t expected_m = CHECK_JUST(seq_len); CHECK_EQ(*m, expected_m); } if (num_heads) { const int64_t expected_h = CHECK_JUST(num_heads); CHECK_EQ(*h, expected_h); } if (head_size) { const int64_t expected_k = CHECK_JUST(head_size); CHECK_EQ(*k, expected_k); } } void ParseDims(const ShapeView& shape, const std::string& layout, const Optional& num_heads, const Optional& head_size, int64_t tensor_index, int64_t* b, int64_t* m, int64_t* h, int64_t* k, int64_t* b_stride, int64_t* m_stride, int64_t* h_stride, int64_t* offset) { bool bm_packed{}; ParseDims(shape, layout, Optional(), Optional(), num_heads, head_size, tensor_index, b, m, h, k, b_stride, m_stride, h_stride, offset, &bm_packed); } template struct alignas(pack_size * sizeof(T)) Pack { T elem[pack_size]; }; template __global__ void PackQkv(int b, int s, int nh, int d, const T* q, const T* k, const T* v, T* o, int32_t* seq_len) { int count = b * s * nh * d * 3; for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { int row = i / (d * 3); int out_col = i - row * (d * 3); T out; if (out_col < d) { out = q[row * d + out_col]; } else if (out_col < 2 * d) { out = k[row * d + out_col - d]; } else { out = v[row * d + out_col - d * 2]; } o[i] = out; } for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < b + 1; i += blockDim.x * gridDim.x) { seq_len[i] = i * s; } } struct Params { DataType data_type; int64_t num_batches; int64_t num_heads; int64_t query_seq_len; int64_t kv_seq_len; int64_t head_size; int64_t value_head_size; int64_t q_stride_b; int64_t q_stride_m; int64_t q_stride_h; int64_t k_stride_b; int64_t k_stride_m; int64_t k_stride_h; int64_t v_stride_b; int64_t v_stride_m; int64_t v_stride_h; std::string attn_mask_type; int64_t causal_diagonal_offset; const void* query_ptr; const void* key_ptr; const void* value_ptr; const void* attn_bias_ptr; const void* query_seq_start_ptr; const void* key_seq_start_ptr; const void* key_seq_len_ptr; int64_t attn_bias_stride_b; int64_t attn_bias_stride_h; int64_t attn_bias_stride_m; void* out_ptr; void* workspace; int64_t workspace_size; float scale; }; template void LaunchCutlassFmha(const Params& params, ep::CudaStream* stream) { // The fmha implementation below is based on xformers's fmha // implementation at: // https://github.com/facebookresearch/xformers/tree/main/xformers/csrc/attention/cuda/fmha using Attention = AttentionKernel; typename Attention::Params p{}; p.query_ptr = 
const_cast(reinterpret_cast(params.query_ptr)); p.key_ptr = const_cast(reinterpret_cast(params.key_ptr)); p.value_ptr = const_cast(reinterpret_cast(params.value_ptr)); p.attn_bias_ptr = const_cast(reinterpret_cast(params.attn_bias_ptr)); p.seqstart_q_ptr = const_cast(reinterpret_cast(params.query_seq_start_ptr)); p.seqstart_k_ptr = const_cast(reinterpret_cast(params.key_seq_start_ptr)); p.seqlen_k_ptr = const_cast(reinterpret_cast(params.key_seq_len_ptr)); p.logsumexp_ptr = nullptr; p.output_ptr = reinterpret_cast(params.out_ptr); if (Attention::kNeedsOutputAccumulatorBuffer) { using Acc = typename Attention::accum_t; CHECK_GE(params.workspace_size, params.num_batches * params.query_seq_len * params.num_heads * params.value_head_size * sizeof(Acc)); p.output_accum_ptr = reinterpret_cast(params.workspace); } else { p.output_accum_ptr = nullptr; } p.num_heads = params.num_heads; p.num_batches = params.num_batches; p.head_dim = params.head_size; p.head_dim_value = params.value_head_size; p.num_queries = params.query_seq_len; p.num_keys = params.kv_seq_len; p.q_strideM = params.q_stride_m; p.k_strideM = params.k_stride_m; p.v_strideM = params.v_stride_m; p.o_strideM = p.head_dim_value * p.num_heads; p.bias_strideM = params.attn_bias_stride_m; p.q_strideH = params.q_stride_h; p.k_strideH = params.k_stride_h; p.v_strideH = params.v_stride_h; p.bias_strideH = params.attn_bias_stride_h; p.q_strideB = params.q_stride_b; p.k_strideB = params.k_stride_b; p.v_strideB = params.v_stride_b; p.bias_strideB = params.attn_bias_stride_b; p.scale = params.scale; if (params.attn_mask_type == "none") { p.custom_mask_type = Attention::NoCustomMask; } else if (params.attn_mask_type == "causal_from_top_left") { p.custom_mask_type = Attention::CausalFromTopLeft; } else if (params.attn_mask_type == "causal_from_bottom_right") { p.custom_mask_type = Attention::CausalFromBottomRight; } else { UNIMPLEMENTED(); } p.causal_diagonal_offset = params.causal_diagonal_offset; p.use_dropout = false; constexpr auto kernel_fn = attention_kernel_batched_impl; int smem_bytes = sizeof(typename Attention::SharedStorage); if (smem_bytes > 0xc000) { static bool once = [&]() { cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); return true; }(); } CHECK(Attention::check_supported(p)); kernel_fn<<cuda_stream()>>>(p); } template void DispatchWithAttnBias(const Params& params, ep::CudaStream* stream) { if (params.attn_bias_ptr != nullptr) { LaunchCutlassFmha(params, stream); } else { LaunchCutlassFmha(params, stream); } } template void DispatchSingleValueIteration(const Params& params, ep::CudaStream* stream) { if (params.value_head_size <= keys_per_block) { DispatchWithAttnBias(params, stream); } else { DispatchWithAttnBias(params, stream); } } template void DispatchKeysPerBlock(const Params& params, ep::CudaStream* stream) { if (params.value_head_size <= 64) { DispatchSingleValueIteration(params, stream); } else { DispatchSingleValueIteration(params, stream); } } template void DispatchIsAligned(const Params& params, ep::CudaStream* stream) { if (reinterpret_cast(params.query_ptr) % 16 == 0 && reinterpret_cast(params.key_ptr) % 16 == 0 && reinterpret_cast(params.value_ptr) % 16 == 0 && params.attn_bias_stride_m % (16 / sizeof(T)) == 0 && params.head_size % (16 / sizeof(T)) == 0 && params.value_head_size % (16 / sizeof(T)) == 0) { DispatchKeysPerBlock(params, stream); } else { DispatchKeysPerBlock(params, stream); } } template void DispatchArchTag(const Params& params, ep::CudaStream* stream) { const 
int major = stream->device_properties().major; const int minor = stream->device_properties().minor; if (major == 8) { DispatchIsAligned(params, stream); } else if (major == 7) { if (minor == 5) { DispatchIsAligned(params, stream); } else { DispatchIsAligned(params, stream); } } else { UNIMPLEMENTED(); } } void DispatchCutlassFmha(const Params& params, ep::CudaStream* stream) { if (params.data_type == DataType::kFloat16) { DispatchArchTag(params, stream); } else if (params.data_type == DataType::kFloat) { DispatchArchTag(params, stream); } else { UNIMPLEMENTED(); } } class FusedMultiHeadAttentionInferenceKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedMultiHeadAttentionInferenceKernel() = default; ~FusedMultiHeadAttentionInferenceKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const Tensor* query = ctx->Tensor4ArgNameAndIndex("query", 0); const Tensor* key = ctx->Tensor4ArgNameAndIndex("key", 0); const Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); const Tensor* attn_bias = nullptr; if (ctx->has_input("attn_bias", 0)) { attn_bias = ctx->Tensor4ArgNameAndIndex("attn_bias", 0); } const Tensor* query_seq_start = nullptr; const Tensor* key_seq_start = nullptr; const Tensor* key_seq_len = nullptr; const float scale = ctx->Attr("scale"); if (ctx->has_input("query_seq_start", 0)) { CHECK(ctx->has_input("key_seq_start", 0)); query_seq_start = ctx->Tensor4ArgNameAndIndex("query_seq_start", 0); key_seq_start = ctx->Tensor4ArgNameAndIndex("key_seq_start", 0); CHECK(query_seq_start->data_type() == DataType::kInt32); CHECK(key_seq_start->data_type() == DataType::kInt32); CHECK_EQ(query_seq_start->shape_view().NumAxes(), 1); CHECK_GT(query_seq_start->shape_view().At(0), 1); CHECK(query_seq_start->shape_view() == key_seq_start->shape_view()); if (ctx->has_input("key_seq_len", 0)) { key_seq_len = ctx->Tensor4ArgNameAndIndex("key_seq_len", 0); CHECK(key_seq_len->data_type() == DataType::kInt32); CHECK_EQ(key_seq_len->shape_view().NumAxes(), 1); CHECK_EQ(key_seq_len->shape_view().At(0), query_seq_start->shape_view().At(0) - 1); } } else { CHECK(!ctx->has_input("key_seq_start", 0)); CHECK(!ctx->has_input("key_seq_len", 0)); } Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const DataType data_type = query->data_type(); CHECK_EQ(key->data_type(), data_type); CHECK_EQ(value->data_type(), data_type); CHECK_EQ(out->data_type(), data_type); const int64_t query_head_size = ctx->Attr("query_head_size"); const std::string& attn_mask_type = ctx->Attr("attn_mask_type"); const int64_t causal_diagonal_offset = ctx->Attr("causal_diagonal_offset"); CHECK_GE(causal_diagonal_offset, 0); const std::string& query_layout = ctx->Attr("query_layout"); const std::string& key_layout = ctx->Attr("key_layout"); const std::string& value_layout = ctx->Attr("value_layout"); const std::string& output_layout = ctx->Attr("output_layout"); Optional batch_size; if (query_seq_start != nullptr) { batch_size = query_seq_start->shape_view().At(0) - 1; } Optional query_max_seq_len; const int64_t attr_query_max_seq_len = ctx->Attr("query_max_seq_len"); if (attr_query_max_seq_len != 0) { query_max_seq_len = attr_query_max_seq_len; } Optional key_max_seq_len; const int64_t attr_key_max_seq_len = ctx->Attr("key_max_seq_len"); if (attr_key_max_seq_len != 0) { key_max_seq_len = attr_key_max_seq_len; } int64_t q_b = 0; int64_t q_m = 0; int64_t q_h = 
0; int64_t q_k = 0; int64_t q_b_stride = 0; int64_t q_m_stride = 0; int64_t q_h_stride = 0; int64_t q_offset = 0; bool q_bm_packed = false; ParseDims(query->shape_view(), query_layout, batch_size, query_max_seq_len, Optional(), query_head_size, 0, &q_b, &q_m, &q_h, &q_k, &q_b_stride, &q_m_stride, &q_h_stride, &q_offset, &q_bm_packed); if (q_bm_packed) { CHECK(query_seq_start != nullptr); } int64_t k_b = 0; int64_t k_m = 0; int64_t k_h = 0; int64_t k_k = 0; int64_t k_b_stride = 0; int64_t k_m_stride = 0; int64_t k_h_stride = 0; int64_t k_offset = 0; bool k_bm_packed = false; ParseDims(key->shape_view(), key_layout, q_b, key_max_seq_len, Optional(), query_head_size, 1, &k_b, &k_m, &k_h, &k_k, &k_b_stride, &k_m_stride, &k_h_stride, &k_offset, &k_bm_packed); CHECK_EQ(k_b, q_b); CHECK_EQ(k_h, q_h); CHECK_EQ(k_bm_packed, q_bm_packed); int64_t v_b = 0; int64_t v_m = 0; int64_t v_h = 0; int64_t v_k = 0; int64_t v_b_stride = 0; int64_t v_m_stride = 0; int64_t v_h_stride = 0; int64_t v_offset = 0; bool v_bm_packed = false; ParseDims(value->shape_view(), value_layout, q_b, k_m, q_h, Optional(), 2, &v_b, &v_m, &v_h, &v_k, &v_b_stride, &v_m_stride, &v_h_stride, &v_offset, &v_bm_packed); CHECK_EQ(v_b, q_b); CHECK_EQ(v_m, k_m); CHECK_EQ(v_bm_packed, k_bm_packed); if (output_layout == "BM(HK)") { CHECK(!q_bm_packed); CHECK_EQ(out->shape_view().NumAxes(), 3); CHECK_EQ(out->shape_view().At(0), q_b); CHECK_EQ(out->shape_view().At(1), q_m); CHECK_EQ(out->shape_view().At(2), q_h * v_k); } else if (output_layout == "MB(HK)") { CHECK(!q_bm_packed); CHECK_EQ(out->shape_view().NumAxes(), 3); CHECK_EQ(q_b, 1); CHECK_EQ(out->shape_view().At(0), q_m); CHECK_EQ(out->shape_view().At(1), q_b); CHECK_EQ(out->shape_view().At(2), q_h * v_k); } else if (output_layout == "(BM)(HK)") { CHECK(q_bm_packed); CHECK_EQ(out->shape_view().NumAxes(), 2); CHECK_EQ(out->shape_view().At(0), query->shape_view().At(0)); CHECK_EQ(out->shape_view().At(1), q_h * v_k); } else { UNIMPLEMENTED(); } auto* cuda_stream = ctx->stream()->As(); // Compatible with typo `KERENL` const bool enable_trt_flash_attn = ParseBooleanFromEnv( "ONEFLOW_KERNEL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL", ParseBooleanFromEnv("ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL", true)) && ParseBooleanFromEnv("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", false); const bool is_default_scale = std::abs(scale - 1.0 / std::sqrt(static_cast(q_k))) <= 1e-5; const int arch = cuda_stream->cuda_arch() / 10; const bool is_trt_supported_arch = (arch == 75 || arch == 80 || arch == 86 || arch == 89); const bool is_trt_supported_head_size = ((q_k == 40) || (q_k == 64)); // Avoid PackQKV overhead when seq_len is small. 
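// Dispatch note: the TensorRT flash-attention path below is taken only when every
// condition holds: fp16 data, the default softmax scale (|scale - 1/sqrt(head_size)| <= 1e-5),
// head size 40 or 64 with q_k == v_k, SM 7.5/8.0/8.6/8.9, query_seq_len == kv_seq_len and
// query_seq_len >= 512 (so the extra PackQkv copy pays off), no variable-length batching
// (query_seq_start == nullptr), no attention bias or mask, BMHK/BM(HK) layouts, and the
// ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION env var enabled. Otherwise the
// cutlass-based fmha kernel (DispatchCutlassFmha) is used.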
const bool is_long_seq_len = q_m >= 512; const bool is_trt_supported_layout = (query_layout == "BMHK" || query_layout == "BM(HK)") && (key_layout == "BMHK" || key_layout == "BM(HK)") && (value_layout == "BMHK" || value_layout == "BM(HK)") && (output_layout == "BMHK" || output_layout == "BM(HK)"); if (is_default_scale && query_seq_start == nullptr && enable_trt_flash_attn && data_type == DataType::kFloat16 && q_m == k_m && q_k == v_k && is_trt_supported_head_size && is_long_seq_len && is_trt_supported_arch && attn_mask_type == "none" && attn_bias == nullptr && is_trt_supported_layout) { // The fmha implementation below is based on TensorRT's multiHeadFlashAttentionPlugin // implementation at: // https://github.com/NVIDIA/TensorRT/tree/main/plugin/multiHeadFlashAttentionPlugin int32_t cu_seqlens_d_size = (q_b + 1) * sizeof(int32_t); int32_t* cu_seqlens_d = reinterpret_cast(tmp->mut_dptr()); half* packed_qkv = reinterpret_cast(tmp->mut_dptr() + GetCudaAlignedSize(cu_seqlens_d_size)); constexpr int pack_size = 4; using PackType = Pack; const int64_t count = q_b * q_m * q_h * q_k * 3 / pack_size; PackQkv<<<(count - 1 + 256) / 256, 256, 0, cuda_stream->cuda_stream()>>>( q_b, q_m, q_h, q_k / pack_size, reinterpret_cast(query->dptr()), reinterpret_cast(key->dptr()), reinterpret_cast(value->dptr()), reinterpret_cast(packed_qkv), cu_seqlens_d); #ifdef WITH_CUDA_GRAPHS cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; if (cuda_stream->IsGraphCapturing()) { OF_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); } #endif // WITH_CUDA_GRAPHS nvinfer1::plugin::FusedMultiHeadFlashAttentionKernel const* kernels = nvinfer1::plugin::getFMHAFlashCubinKernels(nvinfer1::plugin::DATA_TYPE_FP16, arch); #ifdef WITH_CUDA_GRAPHS if (cuda_stream->IsGraphCapturing()) { OF_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); } #endif // WITH_CUDA_GRAPHS nvinfer1::plugin::runFMHFAKernel(packed_qkv, cu_seqlens_d, out->mut_dptr(), q_b * q_m, arch, kernels, q_b, q_h, q_k, q_m, cuda_stream->cuda_stream()); return; } Params params{}; params.data_type = data_type; params.num_batches = q_b; params.num_heads = q_h; params.query_seq_len = q_m; params.kv_seq_len = k_m; params.head_size = q_k; params.value_head_size = v_k; params.scale = scale; params.q_stride_b = q_b_stride; params.q_stride_m = q_m_stride; params.q_stride_h = q_h_stride; params.k_stride_b = k_b_stride; params.k_stride_m = k_m_stride; params.k_stride_h = k_h_stride; params.v_stride_b = v_b_stride; params.v_stride_m = v_m_stride; params.v_stride_h = v_h_stride; params.query_ptr = query->dptr() + q_offset * GetSizeOfDataType(data_type); params.key_ptr = key->dptr() + k_offset * GetSizeOfDataType(data_type); params.value_ptr = value->dptr() + v_offset * GetSizeOfDataType(data_type); params.query_seq_start_ptr = query_seq_start == nullptr ? nullptr : query_seq_start->dptr(); params.key_seq_start_ptr = key_seq_start == nullptr ? nullptr : key_seq_start->dptr(); params.key_seq_len_ptr = key_seq_len == nullptr ? 
nullptr : key_seq_len->dptr(); params.out_ptr = out->mut_dptr(); const int64_t tmp_buffer_size = tmp->shape_view().elem_cnt(); params.workspace = tmp->mut_dptr(); params.workspace_size = tmp_buffer_size; params.attn_mask_type = attn_mask_type; params.causal_diagonal_offset = causal_diagonal_offset; if (attn_bias != nullptr) { const int64_t num_attn_bias_axes = attn_bias->shape_view().NumAxes(); CHECK_GE(num_attn_bias_axes, 1); CHECK_LE(num_attn_bias_axes, 4); DimVector padded_attn_bias_shape; for (int i = 0; i < 4 - num_attn_bias_axes; ++i) { padded_attn_bias_shape.push_back(1); } for (int i = 0; i < num_attn_bias_axes; ++i) { padded_attn_bias_shape.push_back(attn_bias->shape_view().At(i)); } CHECK_GE(padded_attn_bias_shape.at(3), k_m); int64_t bias_stride = padded_attn_bias_shape.at(3); if (padded_attn_bias_shape.at(2) == 1) { params.attn_bias_stride_m = 0; } else { CHECK_GE(padded_attn_bias_shape.at(2), q_m); params.attn_bias_stride_m = bias_stride; bias_stride *= padded_attn_bias_shape.at(2); } if (padded_attn_bias_shape.at(1) == 1) { params.attn_bias_stride_h = 0; } else { CHECK_EQ(padded_attn_bias_shape.at(1), q_h); params.attn_bias_stride_h = bias_stride; bias_stride *= q_h; } if (padded_attn_bias_shape.at(0) == 1) { params.attn_bias_stride_b = 0; } else { CHECK_EQ(padded_attn_bias_shape.at(0), q_b); params.attn_bias_stride_b = bias_stride; } params.attn_bias_ptr = attn_bias->dptr(); } else { params.attn_bias_ptr = nullptr; params.attn_bias_stride_m = 0; params.attn_bias_stride_h = 0; params.attn_bias_stride_b = 0; } DispatchCutlassFmha(params, cuda_stream); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; size_t InferTmpBufferSize(InferContext* ctx) { const auto& out_desc = ctx->OutputTensorDesc("out", 0); size_t buffer_size = 0; buffer_size += GetCudaAlignedSize(out_desc.shape().elem_cnt() * GetSizeOfDataType(DataType::kFloat)); buffer_size += GetCudaAlignedSize(out_desc.shape().elem_cnt() * GetSizeOfDataType(out_desc.data_type())) * 3; buffer_size += GetCudaAlignedSize((out_desc.shape().At(0) + 1) * GetSizeOfDataType(DataType::kInt32)); return buffer_size; } #define REGISTER_FUSED_MULTI_HEAD_ATTENTION_INFERENCE_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_multi_head_attention_inference") \ .SetCreateFn() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == dtype)) \ .SetInferTmpSizeFn(InferTmpBufferSize); REGISTER_FUSED_MULTI_HEAD_ATTENTION_INFERENCE_KERNEL(DataType::kFloat16) REGISTER_FUSED_MULTI_HEAD_ATTENTION_INFERENCE_KERNEL(DataType::kFloat) template struct ConcatParam { const void* past_ptr; const void* ptr; void* output_ptr; Index past_offset; Index offset; Index output_offset; Index past_m; Index past_stride_b; Index past_stride_m; Index past_stride_h; Index stride_b; Index stride_m; Index stride_h; Index output_stride_b; Index output_stride_m; Index output_stride_h; Index count; Index output_khm; Index output_kh; Index output_k; }; template struct BatchConcatParam { ConcatParam params[2]; }; template __device__ void ConcatPastKeyValue(ConcatParam p) { for (Index i = blockIdx.x * blockDim.x + threadIdx.x; i < p.count; i += blockDim.x * gridDim.x) { Index b_idx = i / p.output_khm; Index b_off = i - b_idx * p.output_khm; Index m_idx = b_off / p.output_kh; Index m_off = b_off - m_idx * p.output_kh; Index h_idx = m_off / p.output_k; Index k_idx = m_off - h_idx * p.output_k; T v; if (m_idx < p.past_m) { v = reinterpret_cast( p.past_ptr)[p.past_offset + b_idx * p.past_stride_b + m_idx * 
p.past_stride_m + h_idx * p.past_stride_h + k_idx]; } else { v = reinterpret_cast( p.ptr)[p.offset + b_idx * p.stride_b + (m_idx - p.past_m) * p.stride_m + h_idx * p.stride_h + k_idx]; } reinterpret_cast( p.output_ptr)[p.output_offset + b_idx * p.output_stride_b + m_idx * p.output_stride_m + h_idx * p.output_stride_h + k_idx] = v; } } template __global__ void BatchConcatPastKeyValue(BatchConcatParam params) { if (blockIdx.y == 0) { ConcatPastKeyValue::type, Index>(params.params[0]); } else if (blockIdx.y == 1) { ConcatPastKeyValue::type, Index>(params.params[1]); } else { // do nothing } } class FusedAttentionConcatPastKeyValueKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedAttentionConcatPastKeyValueKernel() = default; ~FusedAttentionConcatPastKeyValueKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const Tensor* key = ctx->Tensor4ArgNameAndIndex("key", 0); const Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); Tensor* output_key = ctx->Tensor4ArgNameAndIndex("output_key", 0); Tensor* output_value = ctx->Tensor4ArgNameAndIndex("output_value", 0); const DataType data_type = key->data_type(); const Tensor* past_key = nullptr; const Tensor* past_value = nullptr; if (ctx->has_input("past_key", 0)) { CHECK(ctx->has_input("past_value", 0)); past_key = ctx->Tensor4ArgNameAndIndex("past_key", 0); past_value = ctx->Tensor4ArgNameAndIndex("past_value", 0); CHECK_EQ(past_key->data_type(), data_type); CHECK_EQ(past_value->data_type(), data_type); } else { CHECK(!ctx->has_input("past_value", 0)); } CHECK_EQ(value->data_type(), data_type); CHECK_EQ(output_key->data_type(), data_type); CHECK_EQ(output_value->data_type(), data_type); const int64_t size_of_data_type = GetSizeOfDataType(data_type); const int64_t key_head_size = ctx->Attr("key_head_size"); const std::string& past_key_layout = ctx->Attr("past_key_layout"); const std::string& past_value_layout = ctx->Attr("past_value_layout"); const std::string& key_layout = ctx->Attr("key_layout"); const std::string& value_layout = ctx->Attr("value_layout"); int64_t pack_size = 16 / size_of_data_type; while (key_head_size % pack_size != 0) { pack_size /= 2; } auto ParsePackedDims = [](const ShapeView& shape, const std::string& layout, const Optional& num_heads, const Optional& head_size, int64_t tensor_index, int64_t* b, int64_t* m, int64_t* h, int64_t* k, int64_t* b_stride, int64_t* m_stride, int64_t* h_stride, int64_t* offset, int64_t pack_size) { ParseDims(shape, layout, num_heads, head_size, tensor_index, b, m, h, k, b_stride, m_stride, h_stride, offset); *k /= pack_size; *b_stride /= pack_size; *m_stride /= pack_size; *h_stride /= pack_size; *offset /= pack_size; }; int64_t key_b = 0; int64_t key_m = 0; int64_t key_h = 0; int64_t key_k = 0; int64_t key_b_stride = 0; int64_t key_m_stride = 0; int64_t key_h_stride = 0; int64_t key_offset = 0; ParsePackedDims(key->shape_view(), key_layout, Optional(), key_head_size, 1, &key_b, &key_m, &key_h, &key_k, &key_b_stride, &key_m_stride, &key_h_stride, &key_offset, pack_size); int64_t value_b = 0; int64_t value_m = 0; int64_t value_h = 0; int64_t value_k = 0; int64_t value_b_stride = 0; int64_t value_m_stride = 0; int64_t value_h_stride = 0; int64_t value_offset = 0; ParsePackedDims(value->shape_view(), value_layout, key_h, key_head_size, 2, &value_b, &value_m, &value_h, &value_k, &value_b_stride, &value_m_stride, &value_h_stride, &value_offset, pack_size); CHECK_EQ(value_b, 
key_b); CHECK_EQ(value_m, key_m); int64_t past_key_b = 0; int64_t past_key_m = 0; int64_t past_key_h = 0; int64_t past_key_k = 0; int64_t past_key_b_stride = 0; int64_t past_key_m_stride = 0; int64_t past_key_h_stride = 0; int64_t past_key_offset = 0; if (past_key != nullptr) { ParsePackedDims(past_key->shape_view(), past_key_layout, key_h, key_head_size, 1, &past_key_b, &past_key_m, &past_key_h, &past_key_k, &past_key_b_stride, &past_key_m_stride, &past_key_h_stride, &past_key_offset, pack_size); } int64_t past_value_b = 0; int64_t past_value_m = 0; int64_t past_value_h = 0; int64_t past_value_k = 0; int64_t past_value_b_stride = 0; int64_t past_value_m_stride = 0; int64_t past_value_h_stride = 0; int64_t past_value_offset = 0; if (past_value != nullptr) { ParsePackedDims(past_value->shape_view(), past_value_layout, key_h, key_head_size, 2, &past_value_b, &past_value_m, &past_value_h, &past_value_k, &past_value_b_stride, &past_value_m_stride, &past_value_h_stride, &past_value_offset, pack_size); } CHECK_EQ(past_value_b, past_key_b); CHECK_EQ(past_value_m, past_key_m); int64_t output_key_b = 0; int64_t output_key_m = 0; int64_t output_key_h = 0; int64_t output_key_k = 0; int64_t output_key_b_stride = 0; int64_t output_key_m_stride = 0; int64_t output_key_h_stride = 0; int64_t output_key_offset = 0; ParsePackedDims(output_key->shape_view(), past_key_layout, key_h, key_head_size, 1, &output_key_b, &output_key_m, &output_key_h, &output_key_k, &output_key_b_stride, &output_key_m_stride, &output_key_h_stride, &output_key_offset, pack_size); CHECK_EQ(output_key_b, key_b); CHECK_EQ(output_key_m, past_key_m + key_m); int64_t output_value_b = 0; int64_t output_value_m = 0; int64_t output_value_h = 0; int64_t output_value_k = 0; int64_t output_value_b_stride = 0; int64_t output_value_m_stride = 0; int64_t output_value_h_stride = 0; int64_t output_value_offset = 0; ParsePackedDims(output_value->shape_view(), past_value_layout, key_h, key_head_size, 2, &output_value_b, &output_value_m, &output_value_h, &output_value_k, &output_value_b_stride, &output_value_m_stride, &output_value_h_stride, &output_value_offset, pack_size); CHECK_EQ(output_value_b, key_b); CHECK_EQ(output_value_m, past_value_m + value_m); int64_t max_tensor_elem = (1 << 30) * pack_size; CHECK((past_key == nullptr || past_key->shape_view().elem_cnt() <= max_tensor_elem) && (past_value == nullptr || past_value->shape_view().elem_cnt() <= max_tensor_elem) && key->shape_view().elem_cnt() <= max_tensor_elem && value->shape_view().elem_cnt() <= max_tensor_elem && output_key->shape_view().elem_cnt() <= max_tensor_elem && output_value->shape_view().elem_cnt() <= max_tensor_elem); int64_t count = output_key_b * output_key_m * output_key_h * output_key_k; BatchConcatParam kv; kv.params[0].past_ptr = past_key == nullptr ? 
nullptr : past_key->dptr(); kv.params[0].ptr = key->dptr(); kv.params[0].output_ptr = output_key->mut_dptr(); kv.params[0].past_offset = past_key_offset; kv.params[0].offset = key_offset; kv.params[0].output_offset = output_key_offset; kv.params[0].past_m = past_key_m; kv.params[0].past_stride_b = past_key_b_stride; kv.params[0].past_stride_m = past_key_m_stride; kv.params[0].past_stride_h = past_key_h_stride; kv.params[0].stride_b = key_b_stride; kv.params[0].stride_m = key_m_stride; kv.params[0].stride_h = key_h_stride; kv.params[0].output_stride_b = output_key_b_stride; kv.params[0].output_stride_m = output_key_m_stride; kv.params[0].output_stride_h = output_key_h_stride; kv.params[0].count = count; kv.params[0].output_khm = output_key_k * output_key_h * output_key_m; kv.params[0].output_kh = output_key_k * output_key_h; kv.params[0].output_k = output_key_k; kv.params[1].past_ptr = past_value == nullptr ? nullptr : past_value->dptr(); kv.params[1].ptr = value->dptr(); kv.params[1].output_ptr = output_value->mut_dptr(); kv.params[1].past_offset = past_value_offset; kv.params[1].offset = value_offset; kv.params[1].output_offset = output_value_offset; kv.params[1].past_m = past_value_m; kv.params[1].past_stride_b = past_value_b_stride; kv.params[1].past_stride_m = past_value_m_stride; kv.params[1].past_stride_h = past_value_h_stride; kv.params[1].stride_b = value_b_stride; kv.params[1].stride_m = value_m_stride; kv.params[1].stride_h = value_h_stride; kv.params[1].output_stride_b = output_value_b_stride; kv.params[1].output_stride_m = output_value_m_stride; kv.params[1].output_stride_h = output_value_h_stride; kv.params[1].count = count; kv.params[1].output_khm = output_value_k * output_value_h * output_value_m; kv.params[1].output_kh = output_value_k * output_value_h; kv.params[1].output_k = output_value_k; constexpr uint32_t block_size = 256; const dim3 grid_size((count - 1 + block_size) / block_size, 2); const int64_t elem_size = size_of_data_type * pack_size; cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); if (elem_size == 16) { BatchConcatPastKeyValue<16, int32_t><<>>(kv); } else if (elem_size == 8) { BatchConcatPastKeyValue<8, int32_t><<>>(kv); } else if (elem_size == 4) { BatchConcatPastKeyValue<4, int32_t><<>>(kv); } else if (elem_size == 2) { BatchConcatPastKeyValue<2, int32_t><<>>(kv); } else if (elem_size == 1) { BatchConcatPastKeyValue<1, int32_t><<>>(kv); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("fused_attention_concat_past_key_value") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); template struct FusedApplyRotaryEmbParam { const T* x; const T* cos; const T* sin; const PositionType* position_ids; T* out; const T theta; const float inv_actual_rotary_size; // 1.0 / (rotary_size per rotary dimension) const IndexType actual_rotary_size; // rotary_size per rotary dimension const IndexType rotary_size; const IndexType rotate_stride; const IndexType k0; const IndexType k1; IndexType num_elements; const IndexType k; const IndexType x_offset; IndexType ref_stride[num_dims]; // b, m, h, k IndexType out_stride[num_dims]; // ordered descendingly by stride IndexType x_stride[num_dims]; IndexType position_b_stride; IndexType position_rotate_stride; IndexType sinuous_m_stride; FusedApplyRotaryEmbParam(const T* x, const T* cos, const T* sin, const PositionType* position_ids, T* out, const T theta, const float inv_actual_rotary_size, const IndexType 
actual_rotary_size, const IndexType rotary_size, const IndexType rotate_stride, const IndexType num_elements, const IndexType k, const IndexType k0, const IndexType k1, const IndexType x_offset) : x(x), cos(cos), sin(sin), position_ids(position_ids), out(out), theta(theta), inv_actual_rotary_size(inv_actual_rotary_size), actual_rotary_size(actual_rotary_size), rotary_size(rotary_size), rotate_stride(rotate_stride), num_elements(num_elements), k(k), k0(k0), k1(k1), x_offset(x_offset) {} }; template __global__ void IntervalKernel( FusedApplyRotaryEmbParam param) { for (IndexType packed_offset = threadIdx.x + blockIdx.x * blockDim.x; packed_offset < param.num_elements; packed_offset += blockDim.x * gridDim.x) { using LoadPack = cuda::elementwise::Packed; IndexType offset = packed_offset * PackSize; IndexType index[num_dims]; // b, m, h, k IndexType temp_offset = offset; for (int i = 0; i < num_dims - 1; i++) { IndexType ref_stride = param.ref_stride[i]; IndexType idx = temp_offset / ref_stride; index[i] = idx; temp_offset = temp_offset - idx * ref_stride; } index[num_dims - 1] = temp_offset; IndexType x_offset = param.x_offset; IndexType out_offset = 0; #pragma unroll for (int i = 0; i < num_dims; i++) { x_offset = x_offset + param.x_stride[i] * index[i]; out_offset = out_offset + param.out_stride[i] * index[i]; } const LoadPack x_vec = *reinterpret_cast(param.x + x_offset); const IndexType k_index = index[num_dims - 1]; if (k_index < param.rotary_size) { const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0; const IndexType b_index = index[0], m_index = index[1]; const IndexType position_id_offset = b_index * param.position_b_stride + position_rotate_index * param.position_rotate_stride + m_index; const PositionType position = param.position_ids ? 
param.position_ids[position_id_offset] : m_index; const IndexType actual_k_index = k_index % param.actual_rotary_size; const IndexType sinuous_offset = position * param.sinuous_m_stride + actual_k_index; LoadPack cos_vec, sin_vec, out_vec; if (param.cos && param.sin) { cos_vec = *reinterpret_cast(param.cos + sinuous_offset); sin_vec = *reinterpret_cast(param.sin + sinuous_offset); } else { const IndexType actual_ndim = param.rotary_size / rotary_emb_dim; #pragma unroll for (int i = 0; i < PackSize / 2; i++) { T val = position * expf(2.0f * static_cast(((actual_k_index >> 1) + i)) * param.inv_actual_rotary_size * logf(param.theta)); T cos_val = cosf(val); T sin_val = sinf(val); cos_vec.elem[i * 2] = cos_val; cos_vec.elem[i * 2 + 1] = cos_val; sin_vec.elem[i * 2] = sin_val; sin_vec.elem[i * 2 + 1] = sin_val; } } #pragma unroll for (int i = 0; i < PackSize / 2; i++) { out_vec.elem[i * 2] = x_vec.elem[i * 2] * cos_vec.elem[i * 2] - x_vec.elem[i * 2 + 1] * sin_vec.elem[i * 2]; out_vec.elem[i * 2 + 1] = x_vec.elem[i * 2 + 1] * cos_vec.elem[i * 2 + 1] + x_vec.elem[i * 2] * sin_vec.elem[i * 2 + 1]; } *(reinterpret_cast(param.out + out_offset)) = out_vec; } else { *(reinterpret_cast(param.out + out_offset)) = x_vec; } } } template __global__ void PlaneKernel( FusedApplyRotaryEmbParam param) { for (IndexType offset = threadIdx.x + blockIdx.x * blockDim.x; offset < param.num_elements; offset += blockDim.x * gridDim.x) { using LoadPack = cuda::elementwise::Packed; IndexType temp_offset = offset; IndexType index[num_dims]; #pragma unroll for (int i = 0; i < num_dims - 1; i++) { IndexType ref_stride = param.ref_stride[i]; IndexType idx = temp_offset / ref_stride; index[i] = idx; temp_offset = temp_offset - idx * ref_stride; } index[num_dims - 1] = temp_offset; const IndexType b_index = index[0], m_index = index[1], k_index = index[num_dims - 1]; const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0; const IndexType position_id_offset = b_index * param.position_b_stride + position_rotate_index * param.position_rotate_stride + m_index; const PositionType position = param.position_ids ? param.position_ids[position_id_offset] : m_index; const IndexType actual_k_index = k_index % param.actual_rotary_size; const IndexType sinuous_offset = position * param.k + actual_k_index; T cos_val, sin_val, out_val; if (param.cos && param.sin) { cos_val = *(param.cos + sinuous_offset); sin_val = *(param.sin + sinuous_offset); } else { T val = position * expf(2.0f * static_cast(k_index % (param.actual_rotary_size >> 1)) * param.inv_actual_rotary_size * logf(param.theta)); cos_val = cosf(val); sin_val = sinf(val); } LoadPack x_vec; IndexType x_offset = param.x_offset; IndexType out_offset = 0; #pragma unroll for (int i = 0; i < num_dims; i++) { x_offset = x_offset + param.x_stride[i] * index[i]; out_offset = out_offset + param.out_stride[i] * index[i]; } if (k_index < param.k0) { x_vec.elem[0] = *(param.x + x_offset); x_vec.elem[1] = (param.k0 - k_index > param.rotate_stride) ? static_cast(-*(param.x + x_offset + param.rotate_stride)) : *(param.x + x_offset - param.rotate_stride); out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; } else if (k_index < param.k1) { x_vec.elem[0] = *(param.x + x_offset); x_vec.elem[1] = (param.k1 - k_index > param.rotate_stride) ? 
static_cast(-*(param.x + x_offset + param.rotate_stride)) : *(param.x + x_offset - param.rotate_stride); out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; } else { out_val = *(param.x + x_offset); } *(param.out + out_offset) = out_val; } } template void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin, const PositionType* position_ids, T* out, const int64_t* position_shape, const std::string& x_layout, const std::string& output_layout, const std::string& mode, const T theta, const IndexType rotary_size, const IndexType b, const IndexType m, const IndexType h, const IndexType k, const IndexType x_b_stride, const IndexType x_m_stride, const IndexType x_h_stride, const IndexType x_offset, const IndexType out_b_stride, const IndexType out_m_stride, const IndexType out_h_stride, IndexType num_elements) { const IndexType k0 = rotary_size / rotary_emb_dim, k1 = rotary_size; // TODO: this only support 1d, 2d, rotary postional encoding const IndexType rotate_stride = rotary_size / (2 * rotary_emb_dim); const IndexType actual_rotary_size = rotary_size / rotary_emb_dim; const float inv_actual_rotary_size = 1.0 / actual_rotary_size; struct FusedApplyRotaryEmbParam param( x, cos, sin, position_ids, out, theta, inv_actual_rotary_size, actual_rotary_size, rotary_size, rotate_stride, num_elements, k, k0, k1, x_offset); const IndexType ref_strides[num_dims] = {m * h * k, h * k, k, 1}; const IndexType out_strides[num_dims] = {out_b_stride, out_m_stride, out_h_stride, 1}; const IndexType x_strides[num_dims] = {x_b_stride, x_m_stride, x_h_stride, 1}; param.sinuous_m_stride = actual_rotary_size; const IndexType position_m = position_shape ? static_cast(position_shape[2]) : m; param.position_rotate_stride = position_m; param.position_b_stride = position_m * rotary_emb_dim; // K has to be the last dimension, only k&m matters, therefore strides other than k&m does not // really needs to be computed #pragma unroll for (int i = 0; i < num_dims; i++) { param.ref_stride[i] = ref_strides[i]; param.out_stride[i] = out_strides[i]; param.x_stride[i] = x_strides[i]; } constexpr size_t blk_size = 128; if (mode == "plane") { param.num_elements = param.num_elements * PackSize; PlaneKernel <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( param); } else { IntervalKernel <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( param); } } template void DispatchPackSize(ep::CudaStream* stream, const T* x, const T* cos, const T* sin, const PositionType* position_ids, T* out, const int64_t* position_shape, const std::string& x_layout, const std::string& output_layout, const std::string& mode, const T theta, const IndexType rotary_size, const IndexType b, const IndexType m, const IndexType h, const IndexType k, const IndexType x_b_stride, const IndexType x_m_stride, const IndexType x_h_stride, const IndexType x_offset, const IndexType out_b_stride, const IndexType out_m_stride, const IndexType out_h_stride, IndexType num_elements) { const auto CheckPackSize = [&](const size_t PackSize) { bool r = (((reinterpret_cast(x) % (sizeof(T) * PackSize)) == 0) && (((rotary_size / rotary_emb_dim) % PackSize) == 0) && (((k - rotary_size) % PackSize) == 0) && ((16 / sizeof(T)) >= PackSize)); return r; }; if (CheckPackSize(8)) { num_elements /= 8; LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, 
out_b_stride, out_m_stride, out_h_stride, num_elements); } else if (CheckPackSize(4)) { num_elements /= 4; LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, num_elements); } else { num_elements /= 2; LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, num_elements); } } template void DispatchIndex(ep::CudaStream* stream, const T* x, const T* cos, const T* sin, const PositionType* position_ids, T* out, const int64_t* position_shape, const std::string& x_layout, const std::string& output_layout, const std::string& mode, const T theta, const int64_t rotary_size, const int64_t b, const int64_t m, const int64_t h, const int64_t k, const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride, const int64_t x_offset, const int64_t out_b_stride, const int64_t out_m_stride, const int64_t out_h_stride) { int64_t num_elements = b * m * h * k; if (num_elements < (1 << 30)) { DispatchPackSize( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, static_cast(rotary_size), static_cast(b), static_cast(m), static_cast(h), static_cast(k), static_cast(x_b_stride), static_cast(x_m_stride), static_cast(x_h_stride), static_cast(x_offset), static_cast(out_b_stride), static_cast(out_m_stride), static_cast(out_h_stride), static_cast(num_elements)); } else { DispatchPackSize( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, num_elements); } } template void DispatchRotaryEmbeddingDimension(ep::CudaStream* stream, const T* x, const T* cos, const T* sin, const PositionType* position_ids, T* out, const int64_t* position_shape, const std::string& x_layout, const std::string& output_layout, const std::string& mode, const T theta, const int64_t rotary_size, const int rotary_emb_dim, const int64_t b, const int64_t m, const int64_t h, const int64_t k, const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride, const int64_t x_offset, const int64_t out_b_stride, const int64_t out_m_stride, const int64_t out_h_stride) { if (rotary_emb_dim == 1) { DispatchIndex( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride); } else if (rotary_emb_dim == 2) { DispatchIndex( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride); } } template class FusedApplyRotaryEmbKernel final : public user_op::OpKernel { public: FusedApplyRotaryEmbKernel() = default; ~FusedApplyRotaryEmbKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* cos = nullptr; user_op::Tensor* sin = nullptr; user_op::Tensor* position_ids = nullptr; user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const std::string& x_layout = 
ctx->Attr("x_layout"); const std::string& output_layout = ctx->Attr("output_layout"); const std::string& mode = ctx->Attr("mode"); const int64_t tensor_index = ctx->Attr("tensor_index"); const int64_t k_size = ctx->Attr("k_size"); const int64_t rotary_size = ctx->Attr("rotary_size"); const float theta = 1.0f / ctx->Attr("base"); int rotary_emb_dim = 1; if (ctx->has_input("cos", 0)) { cos = ctx->Tensor4ArgNameAndIndex("cos", 0); } if (ctx->has_input("sin", 0)) { sin = ctx->Tensor4ArgNameAndIndex("sin", 0); } if (ctx->has_input("position_ids", 0)) { position_ids = ctx->Tensor4ArgNameAndIndex("position_ids", 0); rotary_emb_dim = position_ids->shape_view().At(1); } constexpr size_t ndims = 4; int64_t b = 0; int64_t m = 0; int64_t h = 0; int64_t k = 0; int64_t out_b_stride = 0, out_m_stride = 0, out_h_stride = 0, out_offset = 0; int64_t x_b_stride = 0, x_m_stride = 0, x_h_stride = 0, x_offset = 0; ParseDims(out->shape_view(), output_layout, Optional(), k_size, 0, &b, &m, &h, &k, &out_b_stride, &out_m_stride, &out_h_stride, &out_offset); ParseDims(x->shape_view(), x_layout, Optional(), k_size, tensor_index, &b, &m, &h, &k, &x_b_stride, &x_m_stride, &x_h_stride, &x_offset); // TODO: hard code num_dims & seems redundant template problem... DispatchRotaryEmbeddingDimension( ctx->stream()->As(), reinterpret_cast(x->dptr()), cos ? reinterpret_cast(cos->dptr()) : nullptr, sin ? reinterpret_cast(sin->dptr()) : nullptr, position_ids ? reinterpret_cast(position_ids->dptr()) : nullptr, reinterpret_cast(out->mut_dptr()), position_ids ? position_ids->shape_view().data() : nullptr, x_layout, output_layout, mode, static_cast(theta), rotary_size, rotary_emb_dim, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_APPLY_ROTARY_EMB_GPU(dtype, position_type) \ REGISTER_USER_KERNEL("fused_apply_rotary_emb") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value) \ && (user_op::HobInputSize("position_ids") == 1) \ && (user_op::HobDataType("position_ids", 0) == GetDataType::value)); #define REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(dtype) \ REGISTER_FUSED_APPLY_ROTARY_EMB_GPU(dtype, int64_t); \ REGISTER_FUSED_APPLY_ROTARY_EMB_GPU(dtype, int32_t); \ REGISTER_USER_KERNEL("fused_apply_rotary_emb") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value) \ && (user_op::HobInputSize("position_ids") == 0)); REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(float); REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(half); #if CUDA_VERSION >= 11000 REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(nv_bfloat16); #endif // CUDA_VERSION >= 11000 } // namespace } // namespace user_op } // namespace oneflow #endif // WITH_CUTLASS ================================================ FILE: oneflow/user/kernels/fused_bias_add_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #include "oneflow/core/device/cuda_pseudo_bfloat16.h" namespace oneflow { namespace { template struct GeluFunctor { __device__ T Compute(T x, int64_t i) const { return static_cast(0.5) * x * (static_cast(1.0) + erf(static_cast(M_SQRT1_2) * x)); } }; template<> struct GeluFunctor { GeluFunctor float_functor; __device__ half Compute(half x, int64_t i) const { return __float2half(float_functor.Compute(__half2float(x), i)); } __device__ half2 ComputeHalf2(half2 x, int64_t i) const { half2 y; y.x = __float2half(float_functor.Compute(__half2float(x.x), 2 * i)); y.y = __float2half(float_functor.Compute(__half2float(x.y), 2 * i + 1)); return y; } }; #if CUDA_VERSION >= 11000 template<> struct GeluFunctor { GeluFunctor float_functor; __device__ nv_bfloat16 Compute(nv_bfloat16 x, int64_t i) const { return static_cast(float_functor.Compute(static_cast(x), i)); } }; #endif template struct MaskAndScaleFunctor { MaskAndScaleFunctor(const bool* mask, float scale) : mask(mask), scale(scale) {} __device__ T Compute(T x, int64_t i) const { return x * static_cast(mask[i] * scale); } const bool* mask; float scale; }; template<> struct MaskAndScaleFunctor { MaskAndScaleFunctor(const bool* mask, float scale) : mask(mask), scale(scale) {} __device__ half Compute(half x, int64_t i) const { return x * static_cast(mask[i] * scale); } __device__ half2 ComputeHalf2(half2 x, int64_t i) const { const char2* mask_c2 = reinterpret_cast(mask); char2 mask_val = mask_c2[i]; half2 one_or_zero_h2; half2 h2_scale = __float2half2_rn(scale); one_or_zero_h2.x = mask_val.x; one_or_zero_h2.y = mask_val.y; return __hmul2(__hmul2(x, one_or_zero_h2), h2_scale); } const bool* mask; float scale; }; template struct MaskAndScaleAddFunctor { MaskAndScaleAddFunctor(const bool* mask, const T* addend, float scale) : mask(mask), addend(addend), scale(scale) {} __device__ T Compute(T x, int64_t i) const { return x * static_cast(mask[i] * scale) + addend[i]; } const bool* mask; const T* addend; float scale; }; template<> struct MaskAndScaleAddFunctor { MaskAndScaleAddFunctor(const bool* mask, const half* addend, float scale) : mask(mask), addend(addend), scale(scale) {} __device__ half Compute(half x, int64_t i) const { return x * static_cast(mask[i] * scale) + addend[i]; } __device__ half2 ComputeHalf2(half2 x, int64_t i) const { const char2* mask_c2 = reinterpret_cast(mask); const half2* addend_h2 = reinterpret_cast(addend); char2 mask_val = mask_c2[i]; half2 one_or_zero_h2; half2 h2_scale = __float2half2_rn(scale); one_or_zero_h2.x = mask_val.x; one_or_zero_h2.y = mask_val.y; return __hadd2(__hmul2(__hmul2(x, one_or_zero_h2), h2_scale), addend_h2[i]); } const bool* mask; const half* addend; float scale; }; template struct GeluGradFunctor { const T coef = std::sqrt(static_cast(2.0) / std::acos(static_cast(-1.0))); __device__ T Compute(T x, T dy, int64_t i) const { return static_cast(0.5) * (static_cast(1.0) + erf(static_cast(M_SQRT1_2) * x) + x * coef * 
exp(static_cast(-0.5) * x * x)) * dy; } }; template<> struct GeluGradFunctor { GeluGradFunctor float_functor; __device__ half Compute(half x, half dy, int64_t i) const { return __float2half(float_functor.Compute(__half2float(x), __half2float(dy), i)); } }; #if CUDA_VERSION >= 11000 template<> struct GeluGradFunctor { GeluGradFunctor float_functor; __device__ nv_bfloat16 Compute(nv_bfloat16 x, nv_bfloat16 dy, int64_t i) const { return static_cast( float_functor.Compute(static_cast(x), static_cast(dy), i)); } }; #endif template __global__ void FusedBiasAddGpu(FUNCTOR functor, const Index elem_cnt, const Index bias_size, const Index inner_size, const T* x, const T* bias, T* y) { const Index block_size = bias_size * inner_size; CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) { T x_i = x[i] + bias[(i % block_size) / inner_size]; y[i] = functor.Compute(x_i, i); } } template __global__ void FusedBiasAddGradGpu(FUNCTOR grad_functor, const Index elem_cnt, const Index bias_size, const Index inner_size, const T* x, const T* bias, const T* dy, T* dx) { const Index block_size = bias_size * inner_size; CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) { T x_i = x[i] + bias[(i % block_size) / inner_size]; dx[i] = grad_functor.Compute(x_i, dy[i], i); } } template __global__ void FusedBiasAddRowGpu(FUNCTOR functor, const Index elem_cnt, const Index bias_size, const T* x, const T* bias, T* y) { CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) { T x_i = x[i] + bias[i % bias_size]; y[i] = functor.Compute(x_i, i); } } template __global__ void FusedBiasAddGradRowGpu(FUNCTOR grad_functor, const Index elem_cnt, const Index bias_size, const T* x, const T* bias, const T* dy, T* dx) { CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) { T x_i = x[i] + bias[i % bias_size]; dx[i] = grad_functor.Compute(x_i, dy[i], i); } } template __global__ void FusedBiasAddRowGpuHalf2(FUNCTOR functor, const Index elem_cnt, const Index bias_size, const half* x, const half* bias, half* y) { const Index h2_elem_cnt = elem_cnt / 2; const Index h2_bias_size = bias_size / 2; const auto* x_h2 = reinterpret_cast(x); const auto* bias_h2 = reinterpret_cast(bias); auto* y_h2 = reinterpret_cast(y); CUDA_1D_KERNEL_LOOP_T(Index, i, h2_elem_cnt) { half2 x_i = __hadd2(x_h2[i], bias_h2[i % h2_bias_size]); y_h2[i] = functor.ComputeHalf2(x_i, i); } } template __global__ void FusedBiasAddGradRowGpuHalf2(FUNCTOR grad_functor, const Index elem_cnt, const Index bias_size, const half* x, const half* bias, const half* dy, half* dx) { const Index h2_elem_cnt = elem_cnt / 2; const Index h2_bias_size = bias_size / 2; const auto* x_h2 = reinterpret_cast(x); const auto* bias_h2 = reinterpret_cast(bias); const auto* dy_h2 = reinterpret_cast(dy); auto* dx_h2 = reinterpret_cast(dx); CUDA_1D_KERNEL_LOOP_T(Index, i, h2_elem_cnt) { half2 x_i = __hadd2(x_h2[i], bias_h2[i % h2_bias_size]); half2 dy_i = dy_h2[i]; half2 dx_i; dx_i.x = grad_functor.Compute(x_i.x, dy_i.x, 2 * i); dx_i.y = grad_functor.Compute(x_i.y, dy_i.y, 2 * i + 1); dx_h2[i] = dx_i; } } template __global__ void FusedBiasAddColGpu(FUNCTOR functor, const Index elem_cnt, const Index inner_size, const T* x, const T* bias, T* y) { CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) { T x_i = x[i] + bias[i / inner_size]; y[i] = functor.Compute(x_i, i); } } template __global__ void FusedBiasAddGradColGpu(FUNCTOR grad_functor, const Index elem_cnt, const Index inner_size, const T* x, const T* bias, const T* dy, T* dx) { CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) { T x_i = x[i] + bias[i / inner_size]; dx[i] = grad_functor.Compute(x_i, dy[i], i); } } 
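// A host-side reference of the index arithmetic shared by the fused bias-add
// kernels above: for a tensor viewed as (outer_size, bias_size, inner_size) and
// flattened to index i, the bias element is bias[(i % (bias_size * inner_size)) / inner_size].
// The Row/Col kernels are the inner_size == 1 and outer_size == 1 special cases.
// Gelu below is the erf form used by GeluFunctor; GeluGradFunctor uses the
// matching derivative 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x * x / 2) / sqrt(2 * pi).
// This sketch (hypothetical names) is for clarity only, not a CUDA replacement.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

float Gelu(float x) { return 0.5f * x * (1.0f + std::erf(x * 0.70710678f)); }

void FusedBiasAddGeluRef(const std::vector<float>& x, const std::vector<float>& bias,
                         int64_t outer, int64_t bias_size, int64_t inner, std::vector<float>* y) {
  const int64_t block = bias_size * inner;
  for (int64_t i = 0; i < outer * block; ++i) {
    (*y)[i] = Gelu(x[i] + bias[(i % block) / inner]);
  }
}

int main() {
  const int64_t outer = 2, bias_size = 3, inner = 4;
  std::vector<float> x(outer * bias_size * inner, 1.0f), bias = {0.0f, 1.0f, 2.0f};
  std::vector<float> y(x.size());
  FusedBiasAddGeluRef(x, bias, outer, bias_size, inner, &y);
  std::printf("y[0]=%f y[4]=%f\n", y[0], y[4]);  // y[4] picks up bias[1]
  return 0;
}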
template struct FusedBiasAddRow { static void Invoke(ep::Stream* stream, FUNCTOR functor, Index elem_cnt, Index bias_size, const T* x, const T* bias, T* y) { FusedBiasAddRowGpu <<As()->cuda_stream()>>>(functor, elem_cnt, bias_size, x, bias, y); } }; template struct FusedBiasAddRow { static void Invoke(ep::Stream* stream, FUNCTOR functor, Index elem_cnt, Index bias_size, const half* x, const half* bias, half* y) { if (bias_size % 2 == 0) { FusedBiasAddRowGpuHalf2 <<As()->cuda_stream()>>>(functor, elem_cnt, bias_size, x, bias, y); } else { FusedBiasAddRowGpu <<As()->cuda_stream()>>>(functor, elem_cnt, bias_size, x, bias, y); } } }; template void FusedBiasAddForwardImpl(ep::Stream* stream, FUNCTOR functor, Index outer_size, Index bias_size, Index inner_size, const T* x, const T* bias, T* y) { const Index elem_cnt = outer_size * bias_size * inner_size; if (inner_size == 1) { FusedBiasAddRow::Invoke(stream, functor, elem_cnt, bias_size, x, bias, y); } else if (outer_size == 1) { FusedBiasAddColGpu<<As()->cuda_stream()>>>( functor, elem_cnt, inner_size, x, bias, y); } else { FusedBiasAddGpu<<As()->cuda_stream()>>>( functor, elem_cnt, bias_size, inner_size, x, bias, y); } } template struct FusedBiasAddGradRow { static void Invoke(ep::Stream* stream, FUNCTOR grad_functor, Index elem_cnt, Index bias_size, const T* x, const T* bias, const T* dy, T* dx) { FusedBiasAddGradRowGpu <<As()->cuda_stream()>>>(grad_functor, elem_cnt, bias_size, x, bias, dy, dx); } }; template struct FusedBiasAddGradRow { static void Invoke(ep::Stream* stream, FUNCTOR grad_functor, Index elem_cnt, Index bias_size, const half* x, const half* bias, const half* dy, half* dx) { if (bias_size % 2 == 0) { FusedBiasAddGradRowGpuHalf2 <<As()->cuda_stream()>>>(grad_functor, elem_cnt, bias_size, x, bias, dy, dx); } else { FusedBiasAddGradRowGpu <<As()->cuda_stream()>>>(grad_functor, elem_cnt, bias_size, x, bias, dy, dx); } } }; template void FusedBiasAddGradImpl(ep::Stream* stream, FUNCTOR grad_functor, Index outer_size, Index bias_size, Index inner_size, const T* x, const T* bias, const T* dy, T* dx) { const Index elem_cnt = outer_size * bias_size * inner_size; if (inner_size == 1) { FusedBiasAddGradRow::Invoke(stream, grad_functor, elem_cnt, bias_size, x, bias, dy, dx); } else if (outer_size == 1) { FusedBiasAddGradColGpu <<As()->cuda_stream()>>>(grad_functor, elem_cnt, inner_size, x, bias, dy, dx); } else { FusedBiasAddGradGpu <<As()->cuda_stream()>>>(grad_functor, elem_cnt, bias_size, inner_size, x, bias, dy, dx); } } template void DispatchFusedBiasAddForwardImpl(ep::Stream* stream, FUNCTOR functor, int64_t n, int64_t outer_size, int64_t bias_size, int64_t inner_size, const T* x, const T* bias, T* y) { if (IsKernelSafeInt32(n)) { FusedBiasAddForwardImpl(stream, functor, outer_size, bias_size, inner_size, x, bias, y); } else { FusedBiasAddForwardImpl(stream, functor, outer_size, bias_size, inner_size, x, bias, y); } } } // namespace template class FusedFusedBiasAddKernel final : public user_op::OpKernel { public: FusedFusedBiasAddKernel() = default; ~FusedFusedBiasAddKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* a_tensor = ctx->Tensor4ArgNameAndIndex("a", 0); const auto* b_tensor = ctx->Tensor4ArgNameAndIndex("b", 0); auto* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t bias_add_axis = ctx->Attr("axis"); const int64_t outer_size = a_tensor->shape_view().Count(0, bias_add_axis); const int64_t bias_size = 
a_tensor->shape_view().At(bias_add_axis); const int64_t inner_size = a_tensor->shape_view().Count(bias_add_axis + 1); const auto n = a_tensor->shape_view().elem_cnt(); GeluFunctor gelu_functor{}; DispatchFusedBiasAddForwardImpl( ctx->stream(), gelu_functor, n, outer_size, bias_size, inner_size, a_tensor->dptr(), b_tensor->dptr(), out_tensor->mut_dptr()); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_BIAS_ADD_GELU_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_bias_add_gelu") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_FUSED_BIAS_ADD_GELU_KERNEL(float) REGISTER_FUSED_BIAS_ADD_GELU_KERNEL(double) REGISTER_FUSED_BIAS_ADD_GELU_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_FUSED_BIAS_ADD_GELU_KERNEL(nv_bfloat16) #endif template class FusedBiasAddMaskScaleKernel final : public user_op::OpKernel { public: FusedBiasAddMaskScaleKernel() = default; ~FusedBiasAddMaskScaleKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* a_tensor = ctx->Tensor4ArgNameAndIndex("a", 0); const auto* b_tensor = ctx->Tensor4ArgNameAndIndex("b", 0); const auto* mask_tensor = ctx->Tensor4ArgNameAndIndex("mask", 0); auto* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t bias_add_axis = ctx->Attr("axis"); const float scale = ctx->Attr("scale"); const int64_t outer_size = a_tensor->shape_view().Count(0, bias_add_axis); const int64_t bias_size = a_tensor->shape_view().At(bias_add_axis); const int64_t inner_size = a_tensor->shape_view().Count(bias_add_axis + 1); const auto n = a_tensor->shape_view().elem_cnt(); if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* addend = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); MaskAndScaleAddFunctor mask_and_scale_add_functor(mask_tensor->dptr(), addend->dptr(), scale); DispatchFusedBiasAddForwardImpl( ctx->stream(), mask_and_scale_add_functor, n, outer_size, bias_size, inner_size, a_tensor->dptr(), b_tensor->dptr(), out_tensor->mut_dptr()); } else { MaskAndScaleFunctor mask_and_scale_functor(mask_tensor->dptr(), scale); DispatchFusedBiasAddForwardImpl( ctx->stream(), mask_and_scale_functor, n, outer_size, bias_size, inner_size, a_tensor->dptr(), b_tensor->dptr(), out_tensor->mut_dptr()); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_bias_add_mask_scale") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(float) REGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(double) REGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(nv_bfloat16) #endif template class FusedFusedBiasAddGradKernel final : public user_op::OpKernel { public: FusedFusedBiasAddGradKernel() = default; ~FusedFusedBiasAddGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* a_tensor = ctx->Tensor4ArgNameAndIndex("a", 0); const auto* b_tensor = ctx->Tensor4ArgNameAndIndex("b", 0); const auto* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); const int32_t bias_add_axis = 
ctx->Attr("axis"); const int64_t outer_size = a_tensor->shape_view().Count(0, bias_add_axis); const int64_t bias_size = a_tensor->shape_view().At(bias_add_axis); const int64_t inner_size = a_tensor->shape_view().Count(bias_add_axis + 1); const auto n = a_tensor->shape_view().elem_cnt(); GeluGradFunctor gelu_grad_functor; if (IsKernelSafeInt32(n)) { FusedBiasAddGradImpl( ctx->stream(), gelu_grad_functor, outer_size, bias_size, inner_size, a_tensor->dptr(), b_tensor->dptr(), dy_tensor->dptr(), dx_tensor->mut_dptr()); } else { FusedBiasAddGradImpl( ctx->stream(), gelu_grad_functor, outer_size, bias_size, inner_size, a_tensor->dptr(), b_tensor->dptr(), dy_tensor->dptr(), dx_tensor->mut_dptr()); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_bias_add_gelu_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(float) REGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(double) REGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(nv_bfloat16) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_bias_add_scale_mask_softmax_dropout.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
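// A host-side sketch of what the fused_bias_add_mask_scale kernel registered
// above computes: dropout-style masking fused with the bias add, i.e.
// y = (a + bias) * mask * scale (+ optional "_add_to_output" residual), where
// scale is typically 1 / (1 - dropout_probability). FusedBiasAddMaskScaleRef is
// a hypothetical name; the bias indexing follows the same
// (outer, bias_size, inner) scheme as the other fused bias-add kernels.
#include <cstdint>
#include <cstdio>
#include <vector>

void FusedBiasAddMaskScaleRef(const std::vector<float>& a, const std::vector<float>& bias,
                              const std::vector<bool>& mask, const float* addend,
                              int64_t outer, int64_t bias_size, int64_t inner, float scale,
                              std::vector<float>* y) {
  const int64_t block = bias_size * inner;
  for (int64_t i = 0; i < outer * block; ++i) {
    float v = (a[i] + bias[(i % block) / inner]) * (mask[i] ? scale : 0.0f);
    if (addend != nullptr) { v += addend[i]; }
    (*y)[i] = v;
  }
}

int main() {
  std::vector<float> a(6, 1.0f), bias = {0.5f, -0.5f}, y(6);
  std::vector<bool> mask = {true, false, true, true, false, true};
  // outer=1, bias_size=2, inner=3; scale = 1/(1-p) with p=0.5
  FusedBiasAddMaskScaleRef(a, bias, mask, /*addend=*/nullptr, 1, 2, 3, 2.0f, &y);
  std::printf("y[0]=%f y[1]=%f\n", y[0], y[1]);
  return 0;
}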
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/kernels/fused_softmax.cuh" namespace oneflow { namespace cuda { namespace { template struct BroadcastMapper { using index_type = IndexType; IndexType src_dims[NDIM] = {0}; IndexType dst_dims[NDIM] = {0}; template BroadcastMapper(const DimType* arg_src_dims, const DimType* arg_dst_dims) { for (size_t i = 0; i < NDIM; ++i) { src_dims[i] = arg_src_dims[i]; } for (size_t i = 0; i < NDIM; ++i) { dst_dims[i] = arg_dst_dims[i]; } } __device__ IndexType map(IndexType src) const { NdIndexOffsetHelper src_index_helper(src_dims); NdIndexOffsetHelper dst_index_helper(dst_dims); IndexType src_index[NDIM]; IndexType dst_index[NDIM]; src_index_helper.OffsetToNdIndex(src, src_index); #pragma unroll for (int dim = 0; dim < NDIM; ++dim) { if (dst_dims[dim] == 1) { dst_index[dim] = 0; } else { dst_index[dim] = src_index[dim]; } } return dst_index_helper.NdIndexToOffset(dst_index); } }; template struct ElementwiseMapper { using index_type = IndexType; ElementwiseMapper() {} __device__ IndexType map(IndexType index) const { return index; } }; template struct BiasAddScaleMaskLoad { static_assert( std::is_same::value, ""); using IndexType = typename BiasMapper::index_type; const SRC* src; const SRC* bias; const MASK* mask; const DST fill; const DST scale; const IndexType row_size; const BiasMapper bias_mapper; const MaskMapper mask_mapper; BiasAddScaleMaskLoad(const SRC* src, const SRC* bias, const MASK* mask, const DST fill, const DST scale, const IndexType row_size, const BiasMapper bias_mapper, const MaskMapper mask_mapper) : src(src), bias(bias), mask(mask), fill(fill), scale(scale), row_size(row_size), bias_mapper(bias_mapper), mask_mapper(mask_mapper) {} template __device__ void load(DST* dst, IndexType row, IndexType col) { softmax::Pack src_pack; softmax::Pack bias_pack; softmax::Pack mask_pack; const IndexType offset = row * row_size + col; const IndexType bias_offset = bias_mapper.map(offset); const IndexType mask_offset = mask_mapper.map(offset); src_pack.storage = *(reinterpret_cast*>(src) + offset / N); bias_pack.storage = *(reinterpret_cast*>(bias) + bias_offset / N); mask_pack.storage = *(reinterpret_cast*>(mask) + mask_offset / N); #pragma unroll for (int i = 0; i < N; ++i) { if (mask_pack.elem[i] == 0) { dst[i] = fill; } else { dst[i] = static_cast(src_pack.elem[i] + bias_pack.elem[i]) * scale; } } } }; template void DispatchForward(cudaStream_t stream, const user_op::Tensor* x, const user_op::Tensor* bias, const user_op::Tensor* mask, const user_op::Tensor* dropout_mask, const float mask_fill, const float scale, const float dropout_scale, user_op::Tensor* y, user_op::Tensor* softmax_y) { using ComputeType = typename softmax::DefaultComputeType::type; using IndexType = int32_t; constexpr int kMaxNDim = 5; const auto& x_shape = x->shape_view(); CHECK_GE(x_shape.size(), 2); // the last dim is softmax dim which is considered as col int64_t ncol = x_shape[x_shape.size() - 1]; int64_t nrow = x_shape.elem_cnt() / ncol; fused_softmax::DropoutStore store( y->mut_dptr(), softmax_y->mut_dptr(), dropout_mask->dptr(), ncol, dropout_scale); size_t bias_sndim = 0; int64_t bias_x_sdims[kMaxNDim]; int64_t bias_sdims[kMaxNDim]; const auto& bias_shape = bias->shape_view(); fused_softmax::SimplifyBroadcastDims(x_shape.size(), x_shape.ptr(), bias_shape.size(), bias_shape.ptr(), &bias_sndim, bias_x_sdims, bias_sdims); size_t mask_sndim = 0; int64_t 
mask_x_sdims[kMaxNDim]; int64_t mask_sdims[kMaxNDim]; const auto& mask_shape = mask->shape_view(); fused_softmax::SimplifyBroadcastDims(x_shape.size(), x_shape.ptr(), mask_shape.size(), mask_shape.ptr(), &mask_sndim, mask_x_sdims, mask_sdims); #define DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper) \ BiasAddScaleMaskLoad load( \ x->dptr(), bias->dptr(), mask->dptr(), mask_fill, scale, ncol, bias_mapper, \ mask_mapper); \ OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( \ stream, load, store, nrow, ncol))) if (bias_sndim == 1 && mask_sndim == 1) { // bias elementwise // mask elementwise ElementwiseMapper bias_mapper; ElementwiseMapper mask_mapper; DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper); } else if (bias_sndim == 1 && mask_sndim == 2) { // bias elementwise // mask broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N) ElementwiseMapper bias_mapper; BroadcastMapper mask_mapper(mask_x_sdims, mask_sdims); DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper); } else if (bias_sndim == 1 && mask_sndim == 3) { // bias elementwise // mask broadcast: (M, 1, N) -> (M, K, N) ElementwiseMapper bias_mapper; BroadcastMapper mask_mapper(mask_x_sdims, mask_sdims); DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper); } else if (bias_sndim == 2 && mask_sndim == 1) { // bias broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N) // mask elementwise BroadcastMapper bias_mapper(bias_x_sdims, bias_sdims); ElementwiseMapper mask_mapper; DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper); } else if (bias_sndim == 2 && mask_sndim == 2) { // bias broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N) // mask broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N) BroadcastMapper bias_mapper(bias_x_sdims, bias_sdims); BroadcastMapper mask_mapper(mask_x_sdims, mask_sdims); DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper); } else if (bias_sndim == 2 && mask_sndim == 3) { // bias broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N) // mask broadcast: (M, 1, N) -> (M, K, N) BroadcastMapper bias_mapper(bias_x_sdims, bias_sdims); BroadcastMapper mask_mapper(mask_x_sdims, mask_sdims); DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper); // not support for now // } else if (bias_sndim == 3 && mask_sndim == 1) { // } else if (bias_sndim == 3 && mask_sndim == 2) { // } else if (bias_sndim == 3 && mask_sndim == 3) { } else { UNIMPLEMENTED() << ", bias_sndim=" << bias_sndim << ", mask_sndim=" << mask_sndim; } #undef DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX } template class FusedBiasAddScaleMaskSoftmaxDropoutKernel final : public user_op::OpKernel { public: FusedBiasAddScaleMaskSoftmaxDropoutKernel() = default; ~FusedBiasAddScaleMaskSoftmaxDropoutKernel() override = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); const user_op::Tensor* dropout_mask = ctx->Tensor4ArgNameAndIndex("dropout_mask", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex("softmax_y", 0); const float mask_fill = ctx->Attr("mask_fill_value"); const float scale = ctx->Attr("scale_value"); const float dropout_scale = ctx->Attr("dropout_scale_value"); const ShapeView& x_shape = x->shape_view(); // 
int32 index computing is much faster than int64 // TODO: consider using multiple int32 computing to substitute int64 computing CHECK_LT(x_shape.elem_cnt(), INT_MAX) << "only support int32 max limits size of elements"; DispatchForward(ctx->stream()->As()->cuda_stream(), x, bias, mask, dropout_mask, mask_fill, scale, dropout_scale, y, softmax_y); } }; } // namespace #define REGISTER_FUSED_BIAS_ADD_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(dtype, mask_dtype) \ REGISTER_USER_KERNEL("fused_bias_add_scale_mask_softmax_dropout") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value) \ && (user_op::HobDataType("mask", 0) == GetDataType::value)); REGISTER_FUSED_BIAS_ADD_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(float, bool) REGISTER_FUSED_BIAS_ADD_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(half, bool) #undef REGISTER_FUSED_BIAS_ADD_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL } // namespace cuda } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_cast_scale_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { template class FusedCastScaleCpuKernel final : public user_op::OpKernel { public: FusedCastScaleCpuKernel() = default; ~FusedCastScaleCpuKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const double scale_val = ctx->Attr("scale"); const int64_t n = x->shape_view().elem_cnt(); const T scale = *(scale_by_tensor->dptr()) * scale_val; const U* x_ptr = x->dptr(); T* y_ptr = y->mut_dptr(); FOR_RANGE(int64_t, i, 0, n) { y_ptr[i] = static_cast(x_ptr[i]) * scale; } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_CAST_SCALE_CPU_KERNEL(x_type, y_type) \ REGISTER_USER_KERNEL("fused_cast_scale") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_FUSED_CAST_SCALE_CPU_KERNEL(float, double); REGISTER_FUSED_CAST_SCALE_CPU_KERNEL(double, float); #undef REGISTER_FUSED_CAST_SCALE_CPU_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_cast_scale_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
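// A host-side reference of one row of the fused pipeline implemented above:
// softmax_y = softmax(scale * where(mask, x + bias, mask_fill)) and
// y = softmax_y * dropout_mask * dropout_scale. bias and mask may be broadcast
// against x; the kernel handles that generically with SimplifyBroadcastDims plus
// BroadcastMapper (an NdIndexOffsetHelper round trip that zeroes the size-1 dims),
// which this sketch sidesteps by taking already-expanded per-row pointers.
// BiasScaleMaskSoftmaxDropoutRow is a hypothetical name.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

void BiasScaleMaskSoftmaxDropoutRow(const float* x, const float* bias, const bool* mask,
                                    const bool* dropout_mask, int64_t ncol, float fill,
                                    float scale, float dropout_scale, float* y) {
  std::vector<float> tmp(ncol);
  float max_v = -INFINITY;
  for (int64_t j = 0; j < ncol; ++j) {
    tmp[j] = mask[j] ? (x[j] + bias[j]) * scale : fill;
    max_v = std::max(max_v, tmp[j]);
  }
  float sum = 0.0f;
  for (int64_t j = 0; j < ncol; ++j) { tmp[j] = std::exp(tmp[j] - max_v); sum += tmp[j]; }
  for (int64_t j = 0; j < ncol; ++j) {
    const float softmax_y = tmp[j] / sum;
    y[j] = dropout_mask[j] ? softmax_y * dropout_scale : 0.0f;
  }
}

int main() {
  const int64_t n = 4;
  float x[] = {1.0f, 2.0f, 3.0f, 4.0f}, bias[] = {0.0f, 0.0f, 0.0f, 0.0f}, y[4];
  bool mask[] = {true, true, false, true}, dropout_mask[] = {true, true, true, false};
  BiasScaleMaskSoftmaxDropoutRow(x, bias, mask, dropout_mask, n, -1e9f, 1.0f, 2.0f, y);
  std::printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
  return 0;
}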
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #include "oneflow/core/device/cuda_pseudo_bfloat16.h" namespace oneflow { namespace { template __global__ void FusedCastScaleGpu(const int64_t n, const T scale_val, const U* in, const T* scale_by_ptr, T* out) { const T scale = *scale_by_ptr * scale_val; CUDA_1D_KERNEL_LOOP(i, n) { out[i] = static_cast(in[i]) * scale; } } template<> __global__ void FusedCastScaleGpu(const int64_t n, const float scale_val, const half* in, const float* scale_by_ptr, float* out) { const float scale = *scale_by_ptr * scale_val; const int64_t n_2 = n / 2; const auto* in_2 = reinterpret_cast(in); auto* out_2 = reinterpret_cast(out); CUDA_1D_KERNEL_LOOP(i, n_2) { float2 f2 = __half22float2(in_2[i]); f2.x *= scale; f2.y *= scale; out_2[i] = f2; } if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) { out[n - 1] = __half2float(in[n - 1]) * scale; } } template<> __global__ void FusedCastScaleGpu(const int64_t n, const half scale_val, const float* in, const half* scale_by_ptr, half* out) { const half scale = *scale_by_ptr * scale_val; const half2 scale_h2 = __half2half2(scale); const int64_t n_2 = n / 2; const auto* in_2 = reinterpret_cast(in); auto* out_h2 = reinterpret_cast(out); CUDA_1D_KERNEL_LOOP(i, n_2) { half2 in_h2 = __float22half2_rn(in_2[i]); out_h2[i] = __hmul2(in_h2, scale_h2); } if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) { out[n - 1] = __float2half(in[n - 1]) * scale; } } #if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800 template<> __global__ void FusedCastScaleGpu(const int64_t n, const float scale_val, const nv_bfloat16* in, const float* scale_by_ptr, float* out) { const float scale = *scale_by_ptr * scale_val; const int64_t n_2 = n / 2; const auto* in_2 = reinterpret_cast(in); auto* out_2 = reinterpret_cast(out); CUDA_1D_KERNEL_LOOP(i, n_2) { float2 f2 = __bfloat1622float2(in_2[i]); f2.x *= scale; f2.y *= scale; out_2[i] = f2; } if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) { out[n - 1] = __bfloat162float(in[n - 1]) * scale; } } template<> __global__ void FusedCastScaleGpu(const int64_t n, const nv_bfloat16 scale_val, const float* in, const nv_bfloat16* scale_by_ptr, nv_bfloat16* out) { const nv_bfloat16 scale = *scale_by_ptr * scale_val; const nv_bfloat162 scale_h2 = __bfloat162bfloat162(scale); const int64_t n_2 = n / 2; const auto* in_2 = reinterpret_cast(in); auto* out_h2 = reinterpret_cast(out); CUDA_1D_KERNEL_LOOP(i, n_2) { nv_bfloat162 in_h2 = __float22bfloat162_rn(in_2[i]); out_h2[i] = __hmul2(in_h2, scale_h2); } if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) { out[n - 1] = __float2bfloat16(in[n - 1]) * scale; } } #endif template class FusedCastScaleGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedCastScaleGpuKernel() = default; ~FusedCastScaleGpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const 
override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int64_t n = x->shape_view().elem_cnt(); const double scale = ctx->Attr("scale"); const bool use_pack = (x->data_type() == DataType::kFloat && (y->data_type() == DataType::kFloat16 || y->data_type() == DataType::kBFloat16)) || (y->data_type() == DataType::kFloat && (x->data_type() == DataType::kFloat16 || x->data_type() == DataType::kBFloat16)); const int64_t launch_n = use_pack ? RoundUp(n, 2) / 2 : n; FusedCastScaleGpu<<stream()->As()->cuda_stream()>>>( n, static_cast(scale), x->dptr(), scale_by_tensor->dptr(), y->mut_dptr()); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(x_type, y_type) \ REGISTER_USER_KERNEL("fused_cast_scale") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(half, float); REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(half, double); REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(float, half); REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(float, double); REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(double, half); REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(double, float); #if CUDA_VERSION >= 11000 REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(nv_bfloat16, float); REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(float, nv_bfloat16); #endif #undef REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_center_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
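// Both the CPU and CUDA fused_cast_scale kernels above compute
// y[i] = static_cast<Dst>(x[i]) * (*scale_by_tensor * scale_attr); the CUDA
// specializations additionally process fp16/bf16 <-> fp32 pairs two at a time
// (half2 / float2) and let thread 0 handle the odd trailing element, which is
// why the launch count is RoundUp(n, 2) / 2 when packing applies. A minimal
// host sketch of the same semantics (FusedCastScaleRef is a hypothetical name):
#include <cstdint>
#include <cstdio>
#include <vector>

template<typename Dst, typename Src>
void FusedCastScaleRef(int64_t n, double scale_attr, const Src* x, const Dst* scale_by_tensor,
                       Dst* y) {
  const Dst scale = *scale_by_tensor * static_cast<Dst>(scale_attr);
  for (int64_t i = 0; i < n; ++i) { y[i] = static_cast<Dst>(x[i]) * scale; }
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f};
  std::vector<double> y(3);
  const double scale_by_tensor = 0.5;  // stands in for the one-element "scale_by_tensor" input
  FusedCastScaleRef<double, float>(3, /*scale_attr=*/4.0, x.data(), &scale_by_tensor, y.data());
  std::printf("%f %f %f\n", y[0], y[1], y[2]);
  return 0;
}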
*/ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template struct FusedCenterForwardFunctor { __device__ T Compute(T b_x_delta, T b_y_delta) const { return (b_x_delta * b_x_delta + b_y_delta * b_y_delta) / static_cast(4.0); } }; template<> struct FusedCenterForwardFunctor { FusedCenterForwardFunctor float_functor; __device__ half Compute(half b_x_delta, half b_y_delta) const { return __float2half(float_functor.Compute(__half2float(b_x_delta), __half2float(b_y_delta))); } }; template __global__ void FusedCenterForward(FUNCTOR functor, const int n, const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1, const T* b1_y2, const T* b2_y1, const T* b2_y2, T* rho) { CUDA_1D_KERNEL_LOOP(i, n) { const T b_x_delta = (b2_x1[i] + b2_x2[i] - b1_x1[i] - b1_x2[i]); const T b_y_delta = (b2_y1[i] + b2_y2[i] - b1_y1[i] - b1_y2[i]); rho[i] = functor.Compute(b_x_delta, b_y_delta); } } template __global__ void FusedCenterBackward(const int n, const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1, const T* b1_y2, const T* b2_y1, const T* b2_y2, const T* rho2_diff, T* b1_x1_diff, T* b1_x2_diff, T* b2_x1_diff, T* b2_x2_diff, T* b1_y1_diff, T* b1_y2_diff, T* b2_y1_diff, T* b2_y2_diff) { CUDA_1D_KERNEL_LOOP(i, n) { const T rho2_diff_i_2 = rho2_diff[i] / static_cast(2.0); const T b_x_diff = rho2_diff_i_2 * (b1_x1[i] + b1_x2[i] - b2_x1[i] - b2_x2[i]); const T b_y_diff = rho2_diff_i_2 * (b1_y1[i] + b1_y2[i] - b2_y1[i] - b2_y2[i]); b1_x1_diff[i] = b_x_diff; b1_x2_diff[i] = b_x_diff; b2_x1_diff[i] = b_x_diff * static_cast(-1.0); b2_x2_diff[i] = b_x_diff * static_cast(-1.0); b1_y1_diff[i] = b_y_diff; b1_y2_diff[i] = b_y_diff; b2_y1_diff[i] = b_y_diff * static_cast(-1.0); b2_y2_diff[i] = b_y_diff * static_cast(-1.0); } } } // namespace template class FusedCenterKernel final : public user_op::OpKernel { public: FusedCenterKernel() = default; ~FusedCenterKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex("b1_x1", 0); const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex("b1_x2", 0); const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex("b2_x1", 0); const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex("b2_x2", 0); const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex("b1_y1", 0); const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex("b1_y2", 0); const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex("b2_y1", 0); const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex("b2_y2", 0); user_op::Tensor* rho = ctx->Tensor4ArgNameAndIndex("rho2", 0); const int64_t elem_cnt = b1_x1->shape_view().elem_cnt(); FusedCenterForwardFunctor fused_center_forward_functor{}; RUN_CUDA_KERNEL((FusedCenterForward), ctx->stream(), elem_cnt, fused_center_forward_functor, elem_cnt, b1_x1->dptr(), b1_x2->dptr(), b2_x1->dptr(), b2_x2->dptr(), b1_y1->dptr(), b1_y2->dptr(), b2_y1->dptr(), b2_y2->dptr(), rho->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_CENTER_DIST_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_center_dist") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("rho2", 0) == GetDataType::value)); REGISTER_FUSED_GET_CENTER_DIST_CUDA_KERNEL(float) 
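// The fused_get_center_dist kernels above compute, per box pair given by corner
// coordinates, the squared distance between the two box centers. Since each
// center coordinate is (x1 + x2) / 2, this equals
// ((b2_x1 + b2_x2 - b1_x1 - b1_x2)^2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2)^2) / 4,
// which is exactly what FusedCenterForwardFunctor evaluates. The backward kernel
// uses d(rho2)/d(b1_x1) = d(rho2)/d(b1_x2) = (b1_x1 + b1_x2 - b2_x1 - b2_x2) / 2
// (times the incoming rho2_diff), with the opposite sign for the b2_* inputs.
// CenterDistRef is a hypothetical host-side check, not part of the file.
#include <cstdio>

float CenterDistRef(float b1_x1, float b1_y1, float b1_x2, float b1_y2,
                    float b2_x1, float b2_y1, float b2_x2, float b2_y2) {
  const float dx = 0.5f * (b2_x1 + b2_x2) - 0.5f * (b1_x1 + b1_x2);
  const float dy = 0.5f * (b2_y1 + b2_y2) - 0.5f * (b1_y1 + b1_y2);
  return dx * dx + dy * dy;
}

int main() {
  // centers (1, 1) and (5, 5) -> squared distance 32
  std::printf("rho2 = %f\n", CenterDistRef(0, 0, 2, 2, 4, 4, 6, 6));
  return 0;
}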
REGISTER_FUSED_GET_CENTER_DIST_CUDA_KERNEL(double) REGISTER_FUSED_GET_CENTER_DIST_CUDA_KERNEL(half) template class FusedCenterGradKernel final : public user_op::OpKernel { public: FusedCenterGradKernel() = default; ~FusedCenterGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex("b1_x1", 0); const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex("b1_x2", 0); const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex("b2_x1", 0); const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex("b2_x2", 0); const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex("b1_y1", 0); const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex("b1_y2", 0); const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex("b2_y1", 0); const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex("b2_y2", 0); const user_op::Tensor* rho2_diff = ctx->Tensor4ArgNameAndIndex("rho2_diff", 0); user_op::Tensor* b1_x1_diff = ctx->Tensor4ArgNameAndIndex("b1_x1_diff", 0); user_op::Tensor* b1_x2_diff = ctx->Tensor4ArgNameAndIndex("b1_x2_diff", 0); user_op::Tensor* b2_x1_diff = ctx->Tensor4ArgNameAndIndex("b2_x1_diff", 0); user_op::Tensor* b2_x2_diff = ctx->Tensor4ArgNameAndIndex("b2_x2_diff", 0); user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex("b1_y1_diff", 0); user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex("b1_y2_diff", 0); user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex("b2_y1_diff", 0); user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex("b2_y2_diff", 0); const int64_t elem_cnt = b1_x1_diff->shape_view().elem_cnt(); RUN_CUDA_KERNEL((FusedCenterBackward), ctx->stream(), elem_cnt, elem_cnt, b1_x1->dptr(), b1_x2->dptr(), b2_x1->dptr(), b2_x2->dptr(), b1_y1->dptr(), b1_y2->dptr(), b2_y1->dptr(), b2_y2->dptr(), rho2_diff->dptr(), b1_x1_diff->mut_dptr(), b1_x2_diff->mut_dptr(), b2_x1_diff->mut_dptr(), b2_x2_diff->mut_dptr(), b1_y1_diff->mut_dptr(), b1_y2_diff->mut_dptr(), b2_y1_diff->mut_dptr(), b2_y2_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_CENTER_DIST_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_center_dist_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("b1_x1", 0) == GetDataType::value)); REGISTER_FUSED_GET_CENTER_DIST_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_GET_CENTER_DIST_GRAD_CUDA_KERNEL(double) REGISTER_FUSED_GET_CENTER_DIST_GRAD_CUDA_KERNEL(half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_clip_grad.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/user/kernels/fused_clip_grad.h" namespace oneflow { namespace { constexpr int64_t kMultiReduceScaleMulPackSize = 64; template struct MultiClipGradParamPack { MultiClipGradParam params[kMultiReduceScaleMulPackSize]; size_t size; }; size_t InferFusedClipGradTempStorageSize(user_op::InferContext* ctx) { auto input_size = ctx->input_size("model_diff"); if (input_size == 0) { return 0; } int64_t max_elem_cnt = 0; int64_t pack_size = 0; int32_t num_blocks = 0; for (size_t i = 0; i < input_size; ++i) { int64_t elem_cnt = ctx->InputShape("model_diff", i).elem_cnt(); max_elem_cnt = std::max(max_elem_cnt, elem_cnt); pack_size++; if (pack_size == kMultiReduceScaleMulPackSize || i == input_size - 1) { CHECK_LT(max_elem_cnt, std::numeric_limits::max()); num_blocks += BlocksNum4ThreadsNum(static_cast(max_elem_cnt)); max_elem_cnt = 0; pack_size = 0; } } CHECK_LT(num_blocks, kCudaThreadsNumPerBlock * kCudaThreadsNumPerBlock * kCudaThreadsNumPerBlock) << "Too much blocks needed for computing " << ctx->op_name() << ", should be less than " << kCudaThreadsNumPerBlock << "*" << kCudaThreadsNumPerBlock << "*" << kCudaThreadsNumPerBlock << ", but got " << num_blocks; size_t elem_size = GetSizeOfDataType(ctx->InputDType("model_diff", 0)); return GetCudaAlignedSize(num_blocks * elem_size * 2); } template __global__ void MultiBlockClipGradGpu(MultiClipGradParamPack pack_params, T* scale, const float norm_type, const float max_norm, const ClipGradType clip_grad_type, const bool scale_writable) { T t = *scale; if (clip_grad_type == ClipGradType::ZeroType) { t = static_cast(t > 0); } else if (clip_grad_type == ClipGradType::PowerType) { t = std::pow(t, 1.f / norm_type); } if (scale_writable && blockDim.x * blockIdx.x + threadIdx.x == 0) { *scale = t; } t = max_norm / (t + 1e-6); if (t >= 1.) 
{ return; } for (int i = 0; i < pack_params.size; ++i) { auto& param = pack_params.params[i]; CUDA_1D_KERNEL_LOOP(j, param.size) { param.data[j] *= t; } } } } // namespace template struct MultiClipGrad { void operator()(ep::Stream* stream, std::vector>& params, T* scale, const float norm_type, const float max_norm, const ClipGradType clip_grad_type) { int32_t total_num_blocks = 0; for (size_t i = 0; i < params.size(); i += kMultiReduceScaleMulPackSize) { MultiClipGradParamPack pack_params{}; size_t max_elem_cnt = 0; pack_params.size = std::min(kMultiReduceScaleMulPackSize, params.size() - i); for (size_t j = 0; j < pack_params.size; ++j) { pack_params.params[j] = params[i + j]; max_elem_cnt = std::max(max_elem_cnt, pack_params.params[j].size); } int32_t num_blocks = BlocksNum4ThreadsNum(max_elem_cnt); bool scale_writable = static_cast(i + kMultiReduceScaleMulPackSize >= params.size()); MultiBlockClipGradGpu <<As()->cuda_stream()>>>( pack_params, scale, norm_type, max_norm, clip_grad_type, scale_writable); total_num_blocks += num_blocks; } } }; #define REGISTER_FUSED_CLIP_GRAD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("fused_clip_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferFusedClipGradTempStorageSize); REGISTER_FUSED_CLIP_GRAD_KERNEL(DeviceType::kCUDA, float); REGISTER_FUSED_CLIP_GRAD_KERNEL(DeviceType::kCUDA, double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_clip_grad.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_H_ #define ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/multi_reduce_kernel_util.h" #include "oneflow/user/kernels/fused_clip_grad_util.h" namespace oneflow { template class FusedClipGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedClipGradKernel() = default; ~FusedClipGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); T* out_ptr = out->mut_dptr(); T* temp = (ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0))->mut_dptr(); const int32_t input_size = ctx->input_size("model_diff"); const float max_norm = ctx->Attr("max_norm"); const float norm_type = ctx->Attr("norm_type"); std::vector> params; params.resize(input_size); for (size_t i = 0; i < input_size; ++i) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("model_diff", i); params[i].size = x->shape_view().elem_cnt(); params[i].data = x->dptr(); } if (norm_type == 0) { PowByZero func{}; MultiReduce> reduce_add{}; reduce_add(ctx->stream(), func, params, GetZeroVal(), out_ptr, temp); } else if (norm_type == INFINITY) { Abs func{}; MultiReduce> reduce_max{}; reduce_max(ctx->stream(), func, params, GetZeroVal(), out_ptr, temp); } else if (norm_type == -INFINITY) { Abs func{}; MultiReduce> reduce_min{}; reduce_min(ctx->stream(), func, params, std::numeric_limits::max(), out_ptr, temp); } else if (norm_type == 1) { Abs func{}; MultiReduce> reduce_sum{}; reduce_sum(ctx->stream(), func, params, GetZeroVal(), out_ptr, temp); } else if (norm_type == 2) { Square func{}; MultiReduce> reduce_sum{}; reduce_sum(ctx->stream(), func, params, GetZeroVal(), out_ptr, temp); } else { AbsPow func{norm_type}; MultiReduce> reduce_sum{}; reduce_sum(ctx->stream(), func, params, GetZeroVal(), out_ptr, temp); } std::vector> mut_params; mut_params.resize(input_size); for (size_t i = 0; i < input_size; ++i) { user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("model_diff", i); mut_params[i].size = x->shape_view().elem_cnt(); mut_params[i].data = x->mut_dptr(); } MultiClipGrad multi_clip_grad{}; if (norm_type == 0) { multi_clip_grad(ctx->stream(), mut_params, out_ptr, norm_type, max_norm, ClipGradType::ZeroType); } else if (std::abs(norm_type) == INFINITY || norm_type == 1) { multi_clip_grad(ctx->stream(), mut_params, out_ptr, norm_type, max_norm, ClipGradType::OtherType); } else { multi_clip_grad(ctx->stream(), mut_params, out_ptr, norm_type, max_norm, ClipGradType::PowerType); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_H_ ================================================ FILE: oneflow/user/kernels/fused_clip_grad_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
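// FusedClipGradKernel above is the fused counterpart of torch-style
// clip_grad_norm_: it first reduces all "model_diff" tensors to one scalar
// according to norm_type (sum of |x|^p for finite p, max/min of |x| for +/-inf,
// a non-zero count for norm_type == 0), then MultiBlockClipGradGpu finalizes the
// scalar (pow(t, 1/p) for the general "PowerType" case) and multiplies every
// gradient by max_norm / (norm + 1e-6) only when that coefficient is below 1.
// A host sketch of the common cases (finite p and +inf); the norm_type == 0 and
// -inf paths are omitted, and FusedClipGradRef is a hypothetical name.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float FusedClipGradRef(std::vector<std::vector<float>>& grads, float max_norm, float norm_type) {
  float total = 0.0f;
  for (const auto& g : grads) {
    for (float v : g) {
      if (norm_type == INFINITY) { total = std::max(total, std::fabs(v)); }
      else { total += std::pow(std::fabs(v), norm_type); }
    }
  }
  if (norm_type != INFINITY) { total = std::pow(total, 1.0f / norm_type); }
  const float coef = max_norm / (total + 1e-6f);
  if (coef < 1.0f) {
    for (auto& g : grads) { for (float& v : g) { v *= coef; } }
  }
  return total;
}

int main() {
  std::vector<std::vector<float>> grads = {{3.0f, 4.0f}, {0.0f, 12.0f}};
  const float norm = FusedClipGradRef(grads, /*max_norm=*/5.0f, /*norm_type=*/2.0f);
  std::printf("norm=%f grads[1][1]=%f\n", norm, grads[1][1]);  // norm 13, grads scaled by ~5/13
  return 0;
}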
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_UTIL_H_ #define ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_UTIL_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { template struct MultiClipGradParam { T* data; size_t size; }; enum ClipGradType : int { ZeroType, PowerType, OtherType, }; template struct MultiClipGrad { void operator()(ep::Stream* stream, std::vector>& params, T* scale, const float norm_type, const float max_norm, const ClipGradType clip_grad_type); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_UTIL_H_ ================================================ FILE: oneflow/user/kernels/fused_codegeex_qkv_reshape_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template struct alignas(sizeof(T) * pack_size) Packed { __device__ Packed(T val) { #pragma unroll for (int i = 0; i < pack_size; i++) { elem[i] = val; } } __device__ Packed() { // do nothing } union { T elem[pack_size]; }; __device__ void operator=(Packed packA) { #pragma unroll for (int i = 0; i < pack_size; i++) { elem[i] = packA.elem[i]; } } }; // [seq_length, batch_size, hidden_size] -> [seq_length, batch_size, head_num, size_per_head] template __global__ void batch_reshape_for_qkv(const int n, const T* query, const T* key, const T* value, T* new_query, T* new_key, T* new_value) { const auto* query_pack_ptr = reinterpret_cast*>(query); const auto* key_pack_ptr = reinterpret_cast*>(key); const auto* value_pack_ptr = reinterpret_cast*>(value); auto* new_query_pack_ptr = reinterpret_cast*>(new_query); auto* new_key_pack_ptr = reinterpret_cast*>(new_key); auto* new_value_pack_ptr = reinterpret_cast*>(new_value); assert(n % pack_size == 0); CUDA_1D_KERNEL_LOOP(i, n) { Packed query_pack = query_pack_ptr[i]; Packed key_pack = key_pack_ptr[i]; Packed value_pack = value_pack_ptr[i]; new_query_pack_ptr[i] = query_pack; new_key_pack_ptr[i] = key_pack; new_value_pack_ptr[i] = value_pack; } } }; // namespace template class FusedCodegeexQkvReshapeGpuKernel final : public user_op::OpKernel { public: FusedCodegeexQkvReshapeGpuKernel() = default; ~FusedCodegeexQkvReshapeGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { // [seq_length, batch_size, hidden_size] -> [seq_length, batch_size, head_num, size_per_head] const user_op::Tensor* query = ctx->Tensor4ArgNameAndIndex("query", 0); const user_op::Tensor* key = ctx->Tensor4ArgNameAndIndex("key", 0); const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); user_op::Tensor* new_query = 
ctx->Tensor4ArgNameAndIndex("new_query", 0); user_op::Tensor* new_key = ctx->Tensor4ArgNameAndIndex("new_key", 0); user_op::Tensor* new_value = ctx->Tensor4ArgNameAndIndex("new_value", 0); const int32_t n = query->shape_view().elem_cnt(); if (n % 4 == 0) { RUN_CUDA_KERNEL((batch_reshape_for_qkv), ctx->stream(), n / 4, n / 4, query->dptr(), key->dptr(), value->dptr(), new_query->mut_dptr(), new_key->mut_dptr(), new_value->mut_dptr()); } else if (n % 2 == 0) { RUN_CUDA_KERNEL((batch_reshape_for_qkv), ctx->stream(), n / 2, n / 2, query->dptr(), key->dptr(), value->dptr(), new_query->mut_dptr(), new_key->mut_dptr(), new_value->mut_dptr()); } else { RUN_CUDA_KERNEL((batch_reshape_for_qkv), ctx->stream(), n, n, query->dptr(), key->dptr(), value->dptr(), new_query->mut_dptr(), new_key->mut_dptr(), new_value->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_CODEGEEX_QKV_RESHAPE_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_codegeex_qkv_reshape") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("query", 0) == GetDataType::value)); REGISTER_FUSED_CODEGEEX_QKV_RESHAPE_CUDA_KERNEL(float) REGISTER_FUSED_CODEGEEX_QKV_RESHAPE_CUDA_KERNEL(half) REGISTER_FUSED_CODEGEEX_QKV_RESHAPE_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_cross_feature_interaction.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { enum InteractionMode { kVector = 0, kMatrix }; constexpr int kBlockSize = 256; void InferMatmulMNK(const ShapeView& a_shape, const ShapeView& b_shape, bool transpose_a, bool transpose_b, size_t* m, size_t* n, size_t* k) { const int64_t num_a_axes = a_shape.NumAxes(); CHECK_GE(num_a_axes, 2); const int64_t num_b_axes = b_shape.NumAxes(); CHECK_GE(num_b_axes, 2); if (!transpose_a) { *m = a_shape.At(num_a_axes - 2); *k = a_shape.At(num_a_axes - 1); } else { *m = a_shape.At(num_a_axes - 1); *k = a_shape.At(num_a_axes - 2); } if (!transpose_b) { CHECK_EQ(b_shape.At(num_b_axes - 2), *k); *n = b_shape.At(num_b_axes - 1); } else { CHECK_EQ(b_shape.At(num_b_axes - 1), *k); *n = b_shape.At(num_b_axes - 2); } } ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("x", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/true); } auto MatmulPrimitiveExists() { return hob::make_custom("MatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewMatmulPrimitive(&ctx).operator bool(); }); } template __global__ void FusedBiasAddMulAddResidualKernel(const T* in, const T* x, const T* x0, const T* bias, T* out, const IndexType cols, const IndexType elem_cnt) { const IndexType global_thread_id = blockDim.x * blockIdx.x + threadIdx.x; using LoadPack = cuda::elementwise::Packed; for (IndexType linear_index = global_thread_id * pack_size, step = gridDim.x * blockDim.x * pack_size; linear_index < elem_cnt; linear_index += step) { const IndexType row_idx = linear_index / cols; const IndexType col_idx = linear_index - row_idx * cols; const LoadPack* x0_load = reinterpret_cast(x0 + linear_index); const LoadPack* x_load = reinterpret_cast(x + linear_index); const LoadPack* bias_load = reinterpret_cast(bias + col_idx); LoadPack x0_vec = *x0_load; LoadPack x_vec = *x_load; LoadPack bias_vec = *bias_load; LoadPack out_store; if (mode == InteractionMode::kVector) { T in_val = in[row_idx]; #pragma unroll for (int i = 0; i < pack_size; i++) { out_store.elem[i] = x0_vec.elem[i] * in_val + bias_vec.elem[i] + x_vec.elem[i]; } } else if (mode == InteractionMode::kMatrix) { const LoadPack* in_load = reinterpret_cast(in + linear_index); LoadPack in_vec = *in_load; #pragma unroll for (int i = 0; i < pack_size; i++) { out_store.elem[i] = (in_vec.elem[i] + bias_vec.elem[i]) * x0_vec.elem[i] + x_vec.elem[i]; } } else { __trap(); } *(reinterpret_cast(out + linear_index)) = out_store; } } template int GetLaunchPackSize(const int64_t cols) { constexpr int type_pack_size = cuda::elementwise::PackSize(); for (int launch_pack_size = 8; launch_pack_size > 0; launch_pack_size /= 2) { if (type_pack_size >= launch_pack_size && cols % launch_pack_size == 0) { return launch_pack_size; } } return 1; } template void DispatchFusedBiasAddMulAddResidualPackSize(ep::Stream* stream, const T* in, const T* x, const T* x0, const T* bias, T* out, const IndexType cols, const IndexType elem_cnt) { int grid_size; const int pack_size = GetLaunchPackSize(cols); const int64_t pack_num = elem_cnt / pack_size; cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size); if (pack_size == 8) { FusedBiasAddMulAddResidualKernel <<As()->cuda_stream()>>>( in, x, x0, bias, out, cols, elem_cnt); } else if (pack_size == 4) { FusedBiasAddMulAddResidualKernel <<As()->cuda_stream()>>>( in, x, x0, bias, out, cols, elem_cnt); } else if (pack_size == 2) { FusedBiasAddMulAddResidualKernel <<As()->cuda_stream()>>>( in, x, x0, bias, out, cols, elem_cnt); } else { FusedBiasAddMulAddResidualKernel <<As()->cuda_stream()>>>( in, x, x0, bias, out, cols, elem_cnt); } } template void DispatchFusedBiasAddMulAddResidualIndexType(ep::Stream* stream, const T* in, const T* x, const T* x0, const T* bias, T* out, const int64_t cols, const int64_t elem_cnt) { 
if (elem_cnt < GetMaxVal()) { DispatchFusedBiasAddMulAddResidualPackSize(stream, in, x, x0, bias, out, cols, elem_cnt); } else { DispatchFusedBiasAddMulAddResidualPackSize(stream, in, x, x0, bias, out, cols, elem_cnt); } } template class FusedCrossFeatureInteractionKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedCrossFeatureInteractionKernel() = default; ~FusedCrossFeatureInteractionKernel() override = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { /* Cross Interaction v1: 1. x matmul weight. matmul_result0 -> (B, E) matmul (1, E) -> (B, 1) dx = dmatmul_result0 matmul weight dw = x matmul dmatmul_result0 2. matmul_result0 broadcast_mul x0. matmul_result1 -> (B, 1) broadcast_mul (B, E) -> (B, E) dmatmul_result0 = reduce_sum(dmatmul_result1 * x0, axis=1) dx0 = dmatmul_result1 broadcast_mul matmul_result0 3. matmul_result1 broadcast_add bias. matmul_result2 -> (B, E) broadcast_add (1, E) -> (B, E) dmatmul_result1 = dout dbias = reduce_sum(dmatmul_result2, axis=0) 4. matmul_result2 add x. out -> (B, E) elementwise_add (B, E) -> (B, E) dmatmul_result2 = dout, dx = dout. Cross Interaction Grad: dw = x matmul dmatmul_result0 dx0 = dmatmul_result1 broadcast_mul matmul_result0 dbias = reduce_sum(dmatmul_result2, axis=0) dx = (dmatmul_result0 matmul weight) + dout. Cross Interaction v2: 1. x matmul weight. matmul_result0 -> (B, E) matmul (E, E) -> (B, E) 2. matmul_result0 add bias. matmul_result1 -> (B, E) bias_add (1, E) -> (B, E) 3. matmul_result1 multiply x0. matmul_result2 -> (B, E) elementwise_mul (B, E) -> (B, E) 4. matmul_result2 add x. out -> (B, E) elementwise_add (B, E) -> (B, E) */ const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* x0 = ctx->Tensor4ArgNameAndIndex("x0", 0); const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* matmul_result = ctx->Tensor4ArgNameAndIndex("matmul_result", 0); const std::string interaction_mode = ctx->Attr("interaction_mode"); CHECK_EQ(out->shape_view().NumAxes(), 2); size_t m = 0, n = 0, k = 0; InferMatmulMNK(x->shape_view(), weight->shape_view(), /*trans_a=*/false, /*trans_b=*/true, &m, &n, &k); const double alpha = 1.0; double beta = 0.0; auto matmul = NewMatmulPrimitive(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, x->dptr(), weight->dptr(), beta, matmul_result->mut_dptr()); const int64_t elem_cnt = out->shape_view().elem_cnt(); const int64_t cols = out->shape_view().At(1); if (interaction_mode == "vector") { DispatchFusedBiasAddMulAddResidualIndexType( ctx->stream(), matmul_result->mut_dptr(), x->dptr(), x0->dptr(), bias->dptr(), out->mut_dptr(), cols, elem_cnt); } else { DispatchFusedBiasAddMulAddResidualIndexType( ctx->stream(), matmul_result->mut_dptr(), x->dptr(), x0->dptr(), bias->dptr(), out->mut_dptr(), cols, elem_cnt); } } }; #define REGISTER_FUSED_CROSS_FEATURE_INTERACTION_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_cross_feature_interaction") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value) \ && MatmulPrimitiveExists()); REGISTER_FUSED_CROSS_FEATURE_INTERACTION_KERNEL(float) REGISTER_FUSED_CROSS_FEATURE_INTERACTION_KERNEL(half) } // namespace } // 
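// The block comment in the Compute() above derives "Cross Interaction v1"; below
// is a small host reference of the fused forward in its "vector" interaction_mode,
// where weight has shape (1, E) so the matmul primitive produces one scalar per
// row (matmul_result), and FusedBiasAddMulAddResidualKernel then applies
// out = x0 * matmul_result + bias + x. The "matrix" mode instead uses an (E, E)
// weight and computes (matmul_result + bias) * x0 + x. CrossInteractionV1Ref is a
// hypothetical name and not a replacement for the primitive-based kernel.
#include <cstdint>
#include <cstdio>
#include <vector>

void CrossInteractionV1Ref(const std::vector<float>& x, const std::vector<float>& x0,
                           const std::vector<float>& w, const std::vector<float>& bias,
                           int64_t B, int64_t E, std::vector<float>* out) {
  for (int64_t r = 0; r < B; ++r) {
    float t = 0.0f;  // matmul_result[r] = dot(x[r, :], w)
    for (int64_t c = 0; c < E; ++c) { t += x[r * E + c] * w[c]; }
    for (int64_t c = 0; c < E; ++c) {
      (*out)[r * E + c] = x0[r * E + c] * t + bias[c] + x[r * E + c];
    }
  }
}

int main() {
  const int64_t B = 2, E = 3;
  std::vector<float> x(B * E, 1.0f), x0(B * E, 2.0f), w = {1.0f, 1.0f, 1.0f}, bias(E, 0.5f);
  std::vector<float> out(B * E);
  CrossInteractionV1Ref(x, x0, w, bias, B, E, &out);
  std::printf("out[0]=%f\n", out[0]);  // 2 * 3 + 0.5 + 1 = 7.5
  return 0;
}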
namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_cross_feature_interaction_grad.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { constexpr int kBlockSize = 256; void InferMatmulMNK(const DimVector& a_shape, const DimVector& b_shape, bool transpose_a, bool transpose_b, size_t* m, size_t* n, size_t* k) { const int64_t num_a_axes = a_shape.size(); CHECK_GE(num_a_axes, 2); const int64_t num_b_axes = b_shape.size(); CHECK_GE(num_b_axes, 2); if (!transpose_a) { *m = a_shape.at(num_a_axes - 2); *k = a_shape.at(num_a_axes - 1); } else { *m = a_shape.at(num_a_axes - 1); *k = a_shape.at(num_a_axes - 2); } if (!transpose_b) { CHECK_EQ(b_shape.at(num_b_axes - 2), *k); *n = b_shape.at(num_b_axes - 1); } else { CHECK_EQ(b_shape.at(num_b_axes - 1), *k); *n = b_shape.at(num_b_axes - 2); } } ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } template struct MulOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a * b; } }; template struct AddOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; } }; template int GetLaunchPackSize(const int64_t cols) { constexpr int type_pack_size = cuda::elementwise::PackSize(); for (int launch_pack_size = 8; launch_pack_size > 0; launch_pack_size /= 2) { if (type_pack_size >= launch_pack_size && cols % launch_pack_size == 0) { return launch_pack_size; } } return 1; } template __global__ void BroadcastMulKernel(const T* x, const T* y, T* out, const IndexType cols, const IndexType elem_cnt) { const IndexType global_thread_id = blockDim.x * blockIdx.x + threadIdx.x; using LoadPack = cuda::elementwise::Packed; for (IndexType linear_index = global_thread_id * pack_size, step = gridDim.x * blockDim.x * pack_size; linear_index < elem_cnt; linear_index += step) { const IndexType row_idx = linear_index / cols; const LoadPack* x_load = reinterpret_cast(x + linear_index); LoadPack x_vec = *x_load; LoadPack out_store; const T y_val = y[row_idx]; #pragma unroll for (int i = 0; i < pack_size; i++) { out_store.elem[i] = x_vec.elem[i] * y_val; } *(reinterpret_cast(out + linear_index)) = out_store; } } template void DispatchBroadcastMulPackSize(ep::Stream* stream, const T* x, const T* y, T* out, const IndexType cols, const IndexType elem_cnt) { int grid_size; const int pack_size = GetLaunchPackSize(cols); const int64_t pack_num = elem_cnt / pack_size; cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size); if (pack_size == 8) { BroadcastMulKernel <<As()->cuda_stream()>>>(x, y, out, cols, elem_cnt); } else if (pack_size == 4) { BroadcastMulKernel <<As()->cuda_stream()>>>(x, y, out, cols, elem_cnt); } else if (pack_size == 2) { BroadcastMulKernel <<As()->cuda_stream()>>>(x, y, out, cols, elem_cnt); } else { BroadcastMulKernel <<As()->cuda_stream()>>>(x, y, out, cols, elem_cnt); } } template void DispatchBroadcastMulIndexType(ep::Stream* stream, const T* x, const T* y, T* out, const int64_t cols, const int64_t elem_cnt) { if (elem_cnt < GetMaxVal()) { DispatchBroadcastMulPackSize(stream, x, y, out, cols, elem_cnt); } else { DispatchBroadcastMulPackSize(stream, x, y, out, cols, elem_cnt); } } template __global__ void BroadcastAddElementwiseMulKernel(const T* x, const T* y, const T* z, T* out, const IndexType cols, const IndexType elem_cnt) { const IndexType global_thread_id = blockDim.x * blockIdx.x + threadIdx.x; using LoadPack = cuda::elementwise::Packed; for (IndexType linear_index = global_thread_id * pack_size, step = gridDim.x * blockDim.x * pack_size; linear_index < elem_cnt; linear_index += step) { const IndexType row_idx = linear_index / cols; const IndexType col_idx = linear_index - row_idx * cols; const LoadPack* x_load = reinterpret_cast(x + linear_index); const LoadPack* y_load = reinterpret_cast(y + col_idx); const LoadPack* z_load = reinterpret_cast(z + linear_index); LoadPack x_vec = *x_load; LoadPack y_vec = *y_load; LoadPack z_vec = *z_load; LoadPack out_store; #pragma unroll for (int i = 0; i < pack_size; i++) { out_store.elem[i] = (x_vec.elem[i] + y_vec.elem[i]) * z_vec.elem[i]; } *(reinterpret_cast(out + linear_index)) = out_store; } } template void DispatchBroadcastAddElementwiseMulPackSize(ep::Stream* stream, const T* x, const T* y, const T* z, T* out, const IndexType cols, const IndexType elem_cnt) { int grid_size; const int 
pack_size = GetLaunchPackSize(cols); const int64_t pack_num = elem_cnt / pack_size; cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size); if (pack_size == 8) { BroadcastAddElementwiseMulKernel <<As()->cuda_stream()>>>(x, y, z, out, cols, elem_cnt); } else if (pack_size == 4) { BroadcastAddElementwiseMulKernel <<As()->cuda_stream()>>>(x, y, z, out, cols, elem_cnt); } else if (pack_size == 2) { BroadcastAddElementwiseMulKernel <<As()->cuda_stream()>>>(x, y, z, out, cols, elem_cnt); } else { BroadcastAddElementwiseMulKernel <<As()->cuda_stream()>>>(x, y, z, out, cols, elem_cnt); } } template void DispatchBroadcastAddElementwiseMulIndexType(ep::Stream* stream, const T* x, const T* y, const T* z, T* out, const int64_t cols, const int64_t elem_cnt) { if (elem_cnt < GetMaxVal()) { DispatchBroadcastAddElementwiseMulPackSize(stream, x, y, z, out, cols, elem_cnt); } else { DispatchBroadcastAddElementwiseMulPackSize(stream, x, y, z, out, cols, elem_cnt); } } } // namespace namespace user_op { std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewReduceMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/false); } auto ReduceMatmulPrimitiveExists() { return hob::make_custom("MatmulPrimitiveExists", [](const KernelRegContext& ctx) { return NewReduceMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewWeightGradMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("x", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/false); } auto WeightGradMatmulPrimitiveExists() { return hob::make_custom("MatmulPrimitiveExists", [](const KernelRegContext& ctx) { return NewWeightGradMatmulPrimitive(&ctx).operator bool(); }); } template class FusedCrossFeatureInteractionGradKernel final : public OpKernel, public CudaGraphSupport { public: FusedCrossFeatureInteractionGradKernel() = default; ~FusedCrossFeatureInteractionGradKernel() override = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const Tensor* x0 = ctx->Tensor4ArgNameAndIndex("x0", 0); const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const Tensor* matmul_result = ctx->Tensor4ArgNameAndIndex("matmul_result", 0); const int64_t batch_size = dy->shape_view().At(0); const int64_t hidden_size = dy->shape_view().At(1); const int64_t out_size = weight->shape_view().At(0); const int64_t dy_elem_cnt = dy->shape_view().elem_cnt(); Tensor* dx0 = ctx->Tensor4ArgNameAndIndex("dx0", 0); Tensor* dw = ctx->Tensor4ArgNameAndIndex("dw", 0); Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); Tensor* dbias = ctx->Tensor4ArgNameAndIndex("dbias", 0); Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); // step1: Get dbias. 
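/* Illustrative note (editor sketch, not part of the original kernel): the
 * column reduction dbias = reduce_sum(dy, axis=0) in step1 is expressed as a
 * GEMM against a constant ones vector so the generic ep::primitive matmul can
 * be reused instead of a dedicated reduction kernel:
 *
 *   ones : (1, B)   dy : (B, E)   =>   dbias = ones * dy : (1, E)
 *
 * with m = 1, n = hidden_size (E), k = batch_size (B). A host-side reference,
 * assuming row-major float buffers, would be:
 *
 *   for (int64_t col = 0; col < hidden_size; ++col) {
 *     float acc = 0.f;
 *     for (int64_t row = 0; row < batch_size; ++row) {
 *       acc += dy[row * hidden_size + col];
 *     }
 *     dbias[col] = acc;
 *   }
 */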
const T* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); if (cuda_device != nullptr) { ones = static_cast(cuda_device->GetConstOnes(dy->data_type(), batch_size)); } size_t m = 0, n = 0, k = 0; DimVector dy_shape(2); dy->shape_view().ToDimVector(&dy_shape); DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; ones_buf_shape.at(1) = batch_size; InferMatmulMNK(ones_buf_shape, dy_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n, &k); auto reduce_matmul = NewReduceMatmulPrimitive(ctx); CHECK(reduce_matmul); reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, ones, dy->dptr(), 0.0, dbias->mut_dptr()); // step2: Get dmatmul_result0. T* dy_mul_x0 = reinterpret_cast(tmp_buffer->mut_dptr()); T* dmatmul_result0 = reinterpret_cast(tmp_buffer->mut_dptr() + GetCudaAlignedSize(dy_elem_cnt * sizeof(T))); OF_CUDA_CHECK(cuda::elementwise::Binary(MulOp(), dy_elem_cnt, dy_mul_x0, dy->dptr(), x0->dptr(), ctx->stream()->As()->cuda_stream())); ones = static_cast(cuda_device->GetConstOnes(dy->data_type(), hidden_size)); DimVector dy_mul_x0_shape(2); dy->shape_view().ToDimVector(&dy_mul_x0_shape); ones_buf_shape.at(0) = hidden_size; ones_buf_shape.at(1) = 1; InferMatmulMNK(dy_mul_x0_shape, ones_buf_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n, &k); reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, dy_mul_x0, ones, 0.0, dmatmul_result0); // step3: Get dx T* dx_buf = reinterpret_cast(tmp_buffer->mut_dptr() + GetCudaAlignedSize(dy_elem_cnt * sizeof(T)) + GetCudaAlignedSize(batch_size * sizeof(T))); DimVector dmatmul_result_shape(2); dmatmul_result_shape.at(0) = batch_size; dmatmul_result_shape.at(1) = 1; // todo change to hidden size DimVector weight_shape(2); weight->shape_view().ToDimVector(&weight_shape); InferMatmulMNK(dmatmul_result_shape, weight_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n, &k); reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, dmatmul_result0, weight->dptr(), 0.0, reinterpret_cast(dx_buf)); OF_CUDA_CHECK(cuda::elementwise::Binary(AddOp(), dy_elem_cnt, dx->mut_dptr(), dx_buf, dy->dptr(), ctx->stream()->As()->cuda_stream())); // step4: Get dw. DimVector x_shape(2); x->shape_view().ToDimVector(&x_shape); InferMatmulMNK(dmatmul_result_shape, x_shape, /*trans_a=*/true, /*trans_b=*/false, &m, &n, &k); auto weight_grad_matmul = NewWeightGradMatmulPrimitive(ctx); CHECK(weight_grad_matmul); weight_grad_matmul->Launch(ctx->stream(), m, n, k, 1.0, dmatmul_result0, x->dptr(), 0.0, dw->mut_dptr()); // step5: Get dx0. 
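/* Illustrative note (editor sketch, not part of the original kernel): step5
 * computes dx0 = dy (B, E) * matmul_result (B, 1), broadcasting the per-row
 * scalar across the E = hidden_size columns. BroadcastMulKernel loads dy in
 * vectorized packs and fetches one scalar per row; a scalar reference is:
 *
 *   for (int64_t i = 0; i < batch_size * hidden_size; ++i) {
 *     dx0[i] = dy[i] * matmul_result[i / hidden_size];
 *   }
 *
 * DispatchBroadcastMulPackSize picks the pack width (8/4/2/1) from the
 * element type and from whether hidden_size divides evenly by it, and
 * DispatchBroadcastMulIndexType falls back to a wider index type only when
 * the element count does not fit the narrow one.
 */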
DispatchBroadcastMulIndexType(ctx->stream(), dy->dptr(), matmul_result->dptr(), dx0->mut_dptr(), hidden_size, dy_elem_cnt); } }; #define REGISTER_FUSED_CROSS_FEATURE_INTERACTION_V1_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_cross_feature_interaction_v1_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == DeviceType::kCUDA) \ && (HobDataType("dy", 0) == GetDataType::value) \ && ReduceMatmulPrimitiveExists() && WeightGradMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](InferContext* ctx) { \ size_t tmp_size = 0; \ const TensorDesc& dy = ctx->InputTensorDesc("dy", 0); \ const int64_t dy_elem_cnt = dy.shape().elem_cnt(); \ const int64_t batch_size = dy.shape().At(0); \ size_t dy_mul_x0_size = GetCudaAlignedSize(dy_elem_cnt * sizeof(dtype)); \ size_t dmatmul_result_size = GetCudaAlignedSize(batch_size * sizeof(dtype)); \ size_t dx_buf_size = dy_mul_x0_size; \ tmp_size = dy_mul_x0_size + dmatmul_result_size + dx_buf_size; \ return tmp_size; \ }); REGISTER_FUSED_CROSS_FEATURE_INTERACTION_V1_GRAD_KERNEL(float) REGISTER_FUSED_CROSS_FEATURE_INTERACTION_V1_GRAD_KERNEL(half) template class FusedCrossFeatureInteractionV2GradKernel final : public OpKernel, public CudaGraphSupport { public: FusedCrossFeatureInteractionV2GradKernel() = default; ~FusedCrossFeatureInteractionV2GradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: using user_op::OpKernel::Compute; void Compute(KernelComputeContext* ctx) const override { const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); const Tensor* x0 = ctx->Tensor4ArgNameAndIndex("x0", 0); const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const Tensor* matmul_result = ctx->Tensor4ArgNameAndIndex("matmul_result", 0); const int64_t batch_size = dy->shape_view().At(0); const int64_t in_size = weight->shape_view().At(1); const int64_t hidden_size = weight->shape_view().At(0); const int64_t dy_elem_cnt = dy->shape_view().elem_cnt(); Tensor* dx0 = ctx->Tensor4ArgNameAndIndex("dx0", 0); Tensor* dw = ctx->Tensor4ArgNameAndIndex("dw", 0); Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); Tensor* dbias = ctx->Tensor4ArgNameAndIndex("dbias", 0); Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); // step1: Get dx0. DispatchBroadcastAddElementwiseMulIndexType(ctx->stream(), matmul_result->dptr(), bias->dptr(), dy->dptr(), dx0->mut_dptr(), hidden_size, dy_elem_cnt); // step2: Get dmatmul_result0. T* dmatmul_result0 = reinterpret_cast(tmp_buffer->mut_dptr()); OF_CUDA_CHECK(cuda::elementwise::Binary(MulOp(), dy_elem_cnt, dmatmul_result0, dy->dptr(), x0->dptr(), ctx->stream()->As()->cuda_stream())); // step3: Get dx T* dx_buf = reinterpret_cast(tmp_buffer->mut_dptr() + GetCudaAlignedSize(dy_elem_cnt * sizeof(T))); DimVector dmatmul_result_shape(2); dmatmul_result_shape.at(0) = batch_size; dmatmul_result_shape.at(1) = hidden_size; DimVector weight_shape(2); weight->shape_view().ToDimVector(&weight_shape); size_t m = 0, n = 0, k = 0; InferMatmulMNK(dmatmul_result_shape, weight_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n, &k); auto reduce_matmul = NewReduceMatmulPrimitive(ctx); CHECK(reduce_matmul); reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, dmatmul_result0, weight->dptr(), 0.0, reinterpret_cast(dx_buf)); OF_CUDA_CHECK(cuda::elementwise::Binary(AddOp(), dy_elem_cnt, dx->mut_dptr(), dx_buf, dy->dptr(), ctx->stream()->As()->cuda_stream())); // step4: Get dw. 
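/* Illustrative note (editor sketch, not part of the original kernel): step4
 * forms the weight gradient by contracting over the batch dimension with a
 * transposed-A GEMM, dw (E, K) = dmatmul_result0^T (E, B) * x (B, K), where
 * E = hidden_size and K = in_size (the second dim of x and of weight). A
 * host-side reference, assuming row-major float buffers, would be:
 *
 *   for (int64_t e = 0; e < hidden_size; ++e) {
 *     for (int64_t c = 0; c < in_size; ++c) {
 *       float acc = 0.f;
 *       for (int64_t b = 0; b < batch_size; ++b) {
 *         acc += dmatmul_result0[b * hidden_size + e] * x[b * in_size + c];
 *       }
 *       dw[e * in_size + c] = acc;
 *     }
 *   }
 *
 * NewWeightGradMatmulPrimitive encodes this by building the matmul primitive
 * with transpose_a = true and transpose_b = false.
 */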
DimVector x_shape(2); x->shape_view().ToDimVector(&x_shape); InferMatmulMNK(dmatmul_result_shape, x_shape, /*trans_a=*/true, /*trans_b=*/false, &m, &n, &k); auto weight_grad_matmul = NewWeightGradMatmulPrimitive(ctx); CHECK(weight_grad_matmul); weight_grad_matmul->Launch(ctx->stream(), m, n, k, 1.0, dmatmul_result0, x->dptr(), 0.0, dw->mut_dptr()); // step5: Get dbias. const T* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); if (cuda_device != nullptr) { ones = static_cast(cuda_device->GetConstOnes(dy->data_type(), batch_size)); } DimVector dy_shape(2); dy->shape_view().ToDimVector(&dy_shape); DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; ones_buf_shape.at(1) = batch_size; InferMatmulMNK(ones_buf_shape, dy_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n, &k); reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, ones, reinterpret_cast(dmatmul_result0), 0.0, dbias->mut_dptr()); } }; #define REGISTER_FUSED_CROSS_FEATURE_INTERACTION_V2_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_cross_feature_interaction_v2_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((HobDeviceType() == DeviceType::kCUDA) \ && (HobDataType("dy", 0) == GetDataType::value) \ && ReduceMatmulPrimitiveExists() && WeightGradMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](InferContext* ctx) { \ size_t tmp_size = 0; \ const TensorDesc& dy = ctx->InputTensorDesc("dy", 0); \ const int64_t dy_elem_cnt = dy.shape().elem_cnt(); \ size_t dmatmul_result_size = GetCudaAlignedSize(dy_elem_cnt * sizeof(dtype)); \ size_t dx_buf_size = dmatmul_result_size; \ tmp_size = dmatmul_result_size + dx_buf_size; \ return tmp_size; \ }); REGISTER_FUSED_CROSS_FEATURE_INTERACTION_V2_GRAD_KERNEL(float) REGISTER_FUSED_CROSS_FEATURE_INTERACTION_V2_GRAD_KERNEL(half) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_dot_feature_interaction_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/include/primitive/batch_matmul.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/cuda/atomic.cuh" #include namespace oneflow { namespace { __global__ void GenerateGatherIndicesGpu(const int32_t elem_cnt, const int32_t stride, const int32_t in_cols, const int32_t offset, int32_t* gather_indices) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const int32_t row = i / stride; const int32_t col = i - row * stride; if (col < row + offset) { int32_t in_index = row * in_cols + col; int32_t idx = row * (offset + row - 1 + offset) / 2 + col; gather_indices[idx] = in_index; } } } template __global__ void GatherConcatGpu(int32_t elem_cnt, int32_t out_cols, int32_t valid_out_cols, int32_t in_cols, int32_t output_concat_end_dim, const int32_t* gather_indices, const T* in, const T* output_concat_ptr, T* out_ptr) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const int32_t row = i / out_cols; const int32_t col = i - row * out_cols; T out_val; if (col < output_concat_end_dim) { const int32_t output_concat_idx = row * output_concat_end_dim + col; out_val = output_concat_ptr[output_concat_idx]; } else if (col < valid_out_cols) { const int32_t gather_col_idx = gather_indices[col - output_concat_end_dim]; const int32_t in_offset = row * in_cols + gather_col_idx; out_val = in[in_offset]; } else { out_val = 0; } out_ptr[i] = out_val; } } template __global__ void ScatterSplitAddTransposeGpu(int32_t elem_cnt, int32_t stride_dim, int32_t out_dim, int32_t in_grad_stride, int32_t in_grad_matrix_dim, int32_t in_grad_matrix_valid_dim, int32_t output_concat_end_dim, const int32_t offset, const T* dy, T* output_concat_grad, T* in_grad) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const int32_t row = i / stride_dim; const int32_t col = i - row * stride_dim; if (col < output_concat_end_dim) { output_concat_grad[row * output_concat_end_dim + col] = dy[row * out_dim + col]; } else { int32_t in_col_id = col - output_concat_end_dim; const int32_t matrix_row = in_col_id / in_grad_matrix_dim; const int32_t matrix_col = in_col_id - matrix_row * in_grad_matrix_dim; T grad_val = 0; const T* row_dy = dy + row * out_dim + output_concat_end_dim; if (matrix_row < in_grad_matrix_valid_dim && matrix_col < in_grad_matrix_valid_dim) { if (matrix_col < matrix_row) { int32_t dy_col_idx = matrix_row * (offset + matrix_row - 1 + offset) / 2 + matrix_col; grad_val = row_dy[dy_col_idx]; } else if (matrix_row < matrix_col) { // transpose add int32_t trans_row_id = matrix_col; int32_t trans_col_id = matrix_row; int32_t dy_col_idx = trans_row_id * (offset + trans_row_id - 1 + offset) / 2 + trans_col_id; grad_val = row_dy[dy_col_idx]; } else if ((matrix_row == matrix_col) && (offset == 1)) { int32_t dy_col_idx = matrix_row * (offset + matrix_row - 1 + offset) / 2 + matrix_col; grad_val = row_dy[dy_col_idx] * static_cast(2); } } int32_t in_grad_offset = row * in_grad_stride + in_col_id; in_grad[in_grad_offset] = grad_val; } } } template void ConcatFeatures(user_op::KernelComputeContext* ctx, int64_t dst_rows, int64_t dst_cols, void* dst_ptr) { const int64_t feature_input_size = ctx->input_size("features"); auto primitive = ep::primitive::NewPrimitive(DeviceType::kCUDA, 2); DimVector dst_shape = {dst_rows, dst_cols}; int64_t out_col_offset = 0; for (int64_t i = 0; i < feature_input_size; ++i) { const user_op::Tensor* feature = ctx->Tensor4ArgNameAndIndex("features", i); const 
int64_t feature_rows = feature->shape_view().At(0); const int64_t feature_cols = feature->shape_view().Count(1); DimVector dst_pos_vec = {0, out_col_offset}; DimVector src_shape = {feature_rows, feature_cols}; DimVector src_pos_vec = {0, 0}; DimVector extent_vec = {feature_rows, feature_cols}; primitive->Launch(ctx->stream(), feature->data_type(), 2, dst_ptr, dst_shape.data(), dst_pos_vec.data(), feature->dptr(), src_shape.data(), src_pos_vec.data(), extent_vec.data()); out_col_offset += feature_cols; } int64_t pad_dim = dst_cols - out_col_offset; if (pad_dim > 0) { char* out_ptr = reinterpret_cast(dst_ptr) + out_col_offset * sizeof(T); OF_CUDA_CHECK(cudaMemset2DAsync(out_ptr, dst_cols * sizeof(T), 0, pad_dim * sizeof(T), dst_rows, ctx->stream()->As()->cuda_stream())); } } template void GatherConcatKernel(ep::Stream* stream, int32_t elem_cnt, int32_t out_dim, int32_t valid_out_dim, int32_t features_concated_dim, int32_t concated_padded_dim, int32_t output_concat_end_dim, bool self_interaction, const T* matmul_out, const T* output_concat_ptr, int32_t* gather_indices_ptr, T* out_ptr) { cudaStream_t cuda_stream = stream->As()->cuda_stream(); const int32_t gen_indices_elem_cnt = features_concated_dim * features_concated_dim; int32_t offset = self_interaction ? 1 : 0; GenerateGatherIndicesGpu<<>>(gen_indices_elem_cnt, features_concated_dim, concated_padded_dim, offset, gather_indices_ptr); int32_t matmul_stride = concated_padded_dim * concated_padded_dim; GatherConcatGpu<<>>( elem_cnt, out_dim, valid_out_dim, matmul_stride, output_concat_end_dim, gather_indices_ptr, matmul_out, output_concat_ptr, out_ptr); } template void ScatterSplitAddTranspose(ep::Stream* stream, int32_t batch_size, int32_t out_dim, int32_t concated_padded_dim, int32_t features_concated_dim, int32_t output_concat_end_dim, const bool self_interaction, const T* dy, T* output_concat_grad, T* matmul_out_grad_ptr) { int32_t stride_dim = output_concat_end_dim + concated_padded_dim * concated_padded_dim; int32_t matmul_stride = concated_padded_dim * concated_padded_dim; const int32_t elem_cnt = batch_size * stride_dim; int32_t offset = self_interaction ? 
1 : 0; ScatterSplitAddTransposeGpu<<As()->cuda_stream()>>>( elem_cnt, stride_dim, out_dim, matmul_stride, concated_padded_dim, features_concated_dim, output_concat_end_dim, offset, dy, output_concat_grad, matmul_out_grad_ptr); } template void ConcatFeaturesGrad(user_op::KernelComputeContext* ctx, const int64_t batch_size, const int64_t concated_padded_dim, const int64_t vector_size, const T* concated_features_grad) { auto primitive = ep::primitive::NewPrimitive(DeviceType::kCUDA, 2); DimVector src_shape = {batch_size, concated_padded_dim * vector_size}; int64_t in_col_offset = 0; for (int64_t i = 0; i < ctx->output_size("features_grad"); ++i) { user_op::Tensor* feature_grad = ctx->Tensor4ArgNameAndIndex("features_grad", i); const int64_t feature_grad_rows = feature_grad->shape_view().At(0); const int64_t feature_grad_cols = feature_grad->shape_view().Count(1); DimVector dst_shape = {feature_grad_rows, feature_grad_cols}; DimVector dst_pos_vec = {0, 0}; DimVector src_pos_vec = {0, in_col_offset}; DimVector extent_vec = {feature_grad_rows, feature_grad_cols}; in_col_offset += feature_grad_cols; primitive->Launch(ctx->stream(), feature_grad->data_type(), 2, feature_grad->mut_dptr(), dst_shape.data(), dst_pos_vec.data(), concated_features_grad, src_shape.data(), src_pos_vec.data(), extent_vec.data()); } } template struct DefaultComputeType { using type = T; }; template<> struct DefaultComputeType { using type = float; }; template struct alignas(sizeof(T) * pack_size) Pack { T elem[pack_size]; }; int64_t GetPaddedDim(int64_t dim) { const int64_t align_dim = 16; const int64_t padded_dim = (dim + align_dim - 1) / align_dim * align_dim; return padded_dim; } template struct DotFwdParam { const T* in[max_in]; int32_t in_feature_dim[max_in]; int32_t dim_start_offset[max_in]; const T* sparse_feature; const uint32_t* sparse_indices; int32_t sparse_dim; int32_t sparse_dim_start; int32_t features_dim; const T* output_concat; int32_t output_concat_size; T* out; int32_t num_in; }; #if __CUDA_ARCH__ >= 700 template class Wmma { public: __device__ void LoadA(const T* ptr, int ldm) { nvcuda::wmma::load_matrix_sync(a_, ptr, ldm); } __device__ void LoadB(const T* ptr, int ldm) { nvcuda::wmma::load_matrix_sync(b_, ptr, ldm); } __device__ void Store(AccType* ptr, int ldm) { nvcuda::wmma::store_matrix_sync(ptr, acc_, ldm, nvcuda::wmma::mem_row_major); } __device__ void Mma() { nvcuda::wmma::mma_sync(acc_, a_, b_, acc_); } __device__ void InitAcc() { nvcuda::wmma::fill_fragment(acc_, 0.0f); } __device__ __forceinline__ T Convert(T src) { return src; } private: nvcuda::wmma::fragment a_; nvcuda::wmma::fragment b_; nvcuda::wmma::fragment acc_; }; template class Wmma { public: #if __CUDA_ARCH__ >= 800 __device__ void LoadA(const float* ptr, int ldm) { nvcuda::wmma::load_matrix_sync(a_, ptr, ldm); } __device__ void LoadB(const float* ptr, int ldm) { nvcuda::wmma::load_matrix_sync(b_, ptr, ldm); } __device__ void Mma() { nvcuda::wmma::mma_sync(acc_, a_, b_, acc_); } __device__ __forceinline__ float Convert(float src) { return nvcuda::wmma::__float_to_tf32(src); } __device__ void Store(AccType* ptr, int ldm) { nvcuda::wmma::store_matrix_sync(ptr, acc_, ldm, nvcuda::wmma::mem_row_major); } __device__ void InitAcc() { nvcuda::wmma::fill_fragment(acc_, 0.0f); } #else __device__ void LoadA(const float* ptr, int ldm) { __trap(); } __device__ void LoadB(const float* ptr, int ldm) { __trap(); } __device__ void Mma() { __trap(); } __device__ __forceinline__ float Convert(float src) { return src; } __device__ void Store(AccType* 
ptr, int ldm) { __trap(); } __device__ void InitAcc() { __trap(); } #endif private: #if __CUDA_ARCH__ >= 800 nvcuda::wmma::fragment a_; nvcuda::wmma::fragment b_; nvcuda::wmma::fragment acc_; #endif }; #endif //__CUDA_ARCH__ >= 700 constexpr int kUnrollDim = 2; template __global__ void DotFeatureInteractionWmmaImpl( int m_num_tiles, int k_num_tiles, int64_t batch_size, int padded_num_rows, int vector_num_pack, int padded_vector_num_pack, int out_num_cols, int out_num_cols_num_pack, int in_shared_mem_cols, int in_shared_mem_cols_num_pack, int acc_shared_mem_cols, int acc_shared_mem_cols_num_pack, int offset, int output_padding, DotFwdParam param) { #if __CUDA_ARCH__ >= 700 Wmma wmma; extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; int warp_id = threadIdx.y; T* buf = reinterpret_cast(shared_buf); Pack* buf_pack = reinterpret_cast*>(shared_buf); ComputeType* acc_buf = reinterpret_cast(shared_buf + padded_num_rows * in_shared_mem_cols * sizeof(T)); int batch_idx = blockIdx.x; T* batch_out = param.out + batch_idx * out_num_cols; Pack* batch_out_pack = reinterpret_cast*>(param.out) + batch_idx * out_num_cols_num_pack; const int output_concat_size = param.output_concat_size; const T* batch_output_concat = (param.output_concat) ? (param.output_concat + batch_idx * output_concat_size) : nullptr; const uint32_t* batch_sparse_indices = (param.sparse_indices) ? (param.sparse_indices + batch_idx * param.sparse_dim) : nullptr; const Pack* sparse_feature_pack = (param.sparse_feature) ? reinterpret_cast*>(param.sparse_feature) : nullptr; for (int col = threadIdx.x; col < vector_num_pack; col += blockDim.x) { // load dense feature to shared_mem #pragma unroll for (int i = 0; i < max_in; ++i) { if (i >= param.num_in) { break; } const Pack* batch_in = reinterpret_cast*>(param.in[i]) + batch_idx * param.in_feature_dim[i] * vector_num_pack; for (int j = threadIdx.y * kUnrollDim; j < param.in_feature_dim[i]; j += blockDim.y * kUnrollDim) { #pragma unroll for (int k = 0; k < kUnrollDim; ++k) { int in_row = j + k; if (in_row >= param.in_feature_dim[i]) { break; } int buf_row = param.dim_start_offset[i] + in_row; Pack pack_in_val = batch_in[in_row * vector_num_pack + col]; #pragma unroll for (int t = 0; t < pack_size; ++t) { pack_in_val.elem[t] = wmma.Convert(pack_in_val.elem[t]); } buf_pack[buf_row * in_shared_mem_cols_num_pack + col] = pack_in_val; } } } // load sparse feature to shared_mem for (int j = threadIdx.y * kUnrollDim; j < param.sparse_dim; j += blockDim.y * kUnrollDim) { #pragma unroll for (int k = 0; k < kUnrollDim; ++k) { int in_row = j + k; if (in_row >= param.sparse_dim) { break; } int buf_row = param.sparse_dim_start + in_row; int sparse_in_row = batch_sparse_indices[in_row]; Pack pack_in_val = sparse_feature_pack[sparse_in_row * vector_num_pack + col]; #pragma unroll for (int t = 0; t < pack_size; ++t) { pack_in_val.elem[t] = wmma.Convert(pack_in_val.elem[t]); } buf_pack[buf_row * in_shared_mem_cols_num_pack + col] = pack_in_val; } } } Pack zero; #pragma unroll for (int k = 0; k < pack_size; ++k) { zero.elem[k] = wmma.Convert(0); } for (int row = threadIdx.y; row < param.features_dim; row += blockDim.y) { for (int col = vector_num_pack + threadIdx.x; col < padded_vector_num_pack; col += blockDim.x) { buf_pack[row * in_shared_mem_cols_num_pack + col] = zero; } } __syncthreads(); for (int blocks_id = warp_id; blocks_id < m_num_tiles * m_num_tiles; blocks_id += blockDim.y) { int blocks_row = blocks_id / m_num_tiles; int blocks_col = blocks_id - blocks_row * m_num_tiles; if 
(blocks_row >= blocks_col) { wmma.InitAcc(); for (int step = 0; step < k_num_tiles; ++step) { T* tile_a_ptr = buf + blocks_row * mn_tile_dim * in_shared_mem_cols + step * k_tile_dim; T* tile_b_ptr = buf + blocks_col * mn_tile_dim * in_shared_mem_cols + step * k_tile_dim; wmma.LoadA(tile_a_ptr, in_shared_mem_cols); wmma.LoadB(tile_b_ptr, in_shared_mem_cols); wmma.Mma(); } ComputeType* tile_ptr = acc_buf + blocks_row * mn_tile_dim * acc_shared_mem_cols + blocks_col * mn_tile_dim; wmma.Store(tile_ptr, acc_shared_mem_cols); } } __syncthreads(); T* emb_out = batch_out + output_concat_size; for (int base_row = threadIdx.y * kUnrollDim; base_row < param.features_dim; base_row += kUnrollDim * blockDim.y) { #pragma unroll for (int k = 0; k < kUnrollDim; ++k) { int row = base_row + k; if (row >= param.features_dim) { break; } for (int col = threadIdx.x; col < param.features_dim; col += blockDim.x) { if (col < row + offset) { int64_t idx = row * (offset + row - 1 + offset) / 2 + col; emb_out[idx] = static_cast(acc_buf[row * acc_shared_mem_cols + col]); } } } } int thread_id = threadIdx.y * blockDim.x + threadIdx.x; for (int i = thread_id; i < output_concat_size; i += blockDim.x * blockDim.y) { batch_out[i] = batch_output_concat[i]; } for (int i = thread_id; i < output_padding; i += blockDim.x * blockDim.y) { batch_out[out_num_cols - 1 - i] = 0; } #else __trap(); #endif // __CUDA_ARCH__ >= 700 } template struct KTileDim { static const int val = 16; }; template<> struct KTileDim { static const int val = 8; }; template struct DotFeatureInteractionKernel { static bool Launch(ep::Stream* stream, int64_t batch_size, int concated_padded_dim, int vector_size, int out_num_cols, bool self_interaction, int output_padding, const DotFwdParam& param) { const int block_size = 128; const int block_dim_x = 32; const int block_dim_y = block_size / block_dim_x; const int num_blocks = batch_size; const int mn_tile_dim = 16; const int k_tile_dim = KTileDim::val; const int64_t padded_vector_size = GetPaddedDim(vector_size); const int m_num_tiles = concated_padded_dim / mn_tile_dim; const int k_num_tiles = padded_vector_size / k_tile_dim; const int skew_in = 8; const int skew_acc = 8; const int in_shared_mem_num_cols = padded_vector_size + skew_in; const int acc_shared_mem_num_cols = concated_padded_dim + skew_acc; const size_t in_shared_mem_bytes = concated_padded_dim * in_shared_mem_num_cols * sizeof(T); using ComputeType = typename DefaultComputeType::type; const size_t acc_shared_mem_bytes = concated_padded_dim * acc_shared_mem_num_cols * sizeof(ComputeType); const size_t total_shared_mem_bytes = in_shared_mem_bytes + acc_shared_mem_bytes; const int32_t offset = self_interaction ? 
1 : 0; const int out_num_cols_num_pack = out_num_cols / pack_size; const int vector_num_pack = vector_size / pack_size; const int padded_vector_num_pack = padded_vector_size / pack_size; const int in_shared_mem_cols_num_pack = in_shared_mem_num_cols / pack_size; const int acc_shared_mem_cols_num_pack = acc_shared_mem_num_cols / pack_size; int max_active_blocks; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, DotFeatureInteractionWmmaImpl, block_size, total_shared_mem_bytes)); if (max_active_blocks <= 0) { return false; } cudaStream_t cuda_stream = stream->As()->cuda_stream(); DotFeatureInteractionWmmaImpl <<>>( m_num_tiles, k_num_tiles, batch_size, concated_padded_dim, vector_num_pack, padded_vector_num_pack, out_num_cols, out_num_cols_num_pack, in_shared_mem_num_cols, in_shared_mem_cols_num_pack, acc_shared_mem_num_cols, acc_shared_mem_cols_num_pack, offset, output_padding, param); return true; } }; template struct DotBwdParam { const T* out_grad; const T* in[max_in]; T* in_grad[max_in]; T* output_concat_grad; const T* sparse_feature; const uint32_t* sparse_indices; int32_t sparse_dim; int32_t sparse_dim_start; T* sparse_feature_grad; int32_t output_concat_size; int32_t in_feature_dim[max_in]; int32_t dim_start_offset[max_in]; int32_t features_dim; int32_t num_in; }; template __device__ __inline__ void AtomicAdd(Pack* address, Pack val) { #pragma unroll for (int i = 0; i < pack_size; ++i) { cuda::atomic::Add(reinterpret_cast(address) + i, static_cast(val.elem[i])); } } template<> __device__ __inline__ void AtomicAdd(Pack* address, Pack val) { half2 h2_val; h2_val.x = static_cast(val.elem[0]); h2_val.y = static_cast(val.elem[1]); cuda::atomic::Add(reinterpret_cast(address), h2_val); } template __global__ void DotFeatureInteractionBackwardWmmaImpl( int m_num_tiles, int n_num_tiles, int k_num_tiles, int64_t batch_size, int padded_num_rows, int vector_num_pack, int vector_num_sparse_grad_pack, int padded_vector_num_pack, int out_num_cols, int in_shared_mem_cols, int in_shared_mem_cols_num_pack, int in_shared_mem_cols_num_sparse_grad_pack, int matrix_out_grad_shared_mem_cols, int offset, DotBwdParam param) { #if __CUDA_ARCH__ >= 700 Wmma wmma; extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; int warp_id = threadIdx.y; T* in_buf = reinterpret_cast(shared_buf); Pack* in_buf_pack = reinterpret_cast*>(shared_buf); T* matrix_out_grad_buf = in_buf + padded_num_rows * in_shared_mem_cols; ComputeType* in_grad_buf = reinterpret_cast( matrix_out_grad_buf + padded_num_rows * matrix_out_grad_shared_mem_cols); Pack* in_grad_buf_pack = reinterpret_cast*>(in_grad_buf); int batch_idx = blockIdx.x; const T* batch_out_grad = param.out_grad + batch_idx * out_num_cols; const int output_concat_size = param.output_concat_size; T* batch_output_concat_grad = (param.output_concat_grad) ? (param.output_concat_grad + batch_idx * output_concat_size) : nullptr; const uint32_t* batch_sparse_indices = (param.sparse_indices) ? (param.sparse_indices + batch_idx * param.sparse_dim) : nullptr; const Pack* sparse_feature_pack = (param.sparse_feature) ? 
reinterpret_cast*>(param.sparse_feature) : nullptr; int features_dim = param.features_dim; // 1.split out_grad to concat_out_grad and matrix_out_grad buf int thread_id = threadIdx.x + threadIdx.y * blockDim.x; for (int i = thread_id; i < output_concat_size; i += blockDim.x * blockDim.y) { batch_output_concat_grad[i] = batch_out_grad[i]; } const T* batch_interaction_out_grad = batch_out_grad + output_concat_size; for (int matrix_row = threadIdx.y; matrix_row < padded_num_rows; matrix_row += blockDim.y) { for (int matrix_col = threadIdx.x; matrix_col < padded_num_rows; matrix_col += blockDim.x) { const int64_t i = matrix_row * matrix_out_grad_shared_mem_cols + matrix_col; T grad_val = 0; if (matrix_row < features_dim && matrix_col < features_dim) { if (matrix_col < matrix_row) { int32_t out_grad_col = matrix_row * (offset + matrix_row - 1 + offset) / 2 + matrix_col; grad_val = batch_interaction_out_grad[out_grad_col]; } else if (matrix_row < matrix_col) { // transpose add int32_t trans_row_id = matrix_col; int32_t trans_col_id = matrix_row; int32_t out_grad_col = trans_row_id * (offset + trans_row_id - 1 + offset) / 2 + trans_col_id; grad_val = batch_interaction_out_grad[out_grad_col]; } else if ((matrix_row == matrix_col) && (offset == 1)) { int32_t out_grad_col = matrix_row * (offset + matrix_row - 1 + offset) / 2 + matrix_col; grad_val = batch_interaction_out_grad[out_grad_col] * static_cast(2); } } matrix_out_grad_buf[i] = wmma.Convert(grad_val); } } // 2.load in to in in_buf for (int col = threadIdx.x; col < vector_num_pack; col += blockDim.x) { #pragma unroll for (int i = 0; i < max_in; ++i) { if (i >= param.num_in) { break; } const Pack* batch_in = reinterpret_cast*>(param.in[i]) + batch_idx * param.in_feature_dim[i] * vector_num_pack; for (int j = threadIdx.y * kUnrollDim; j < param.in_feature_dim[i]; j += blockDim.y * kUnrollDim) { #pragma unroll for (int k = 0; k < kUnrollDim; ++k) { int in_row = j + k; if (in_row >= param.in_feature_dim[i]) { break; } int buf_row = param.dim_start_offset[i] + in_row; Pack pack_in_val = batch_in[in_row * vector_num_pack + col]; #pragma unroll for (int t = 0; t < pack_size; ++t) { pack_in_val.elem[t] = wmma.Convert(pack_in_val.elem[t]); } in_buf_pack[buf_row * in_shared_mem_cols_num_pack + col] = pack_in_val; } } } // load sparse feature to shared_mem for (int j = threadIdx.y * kUnrollDim; j < param.sparse_dim; j += blockDim.y * kUnrollDim) { #pragma unroll for (int k = 0; k < kUnrollDim; ++k) { int in_row = j + k; if (in_row >= param.sparse_dim) { break; } int buf_row = param.sparse_dim_start + in_row; int sparse_in_row = batch_sparse_indices[in_row]; Pack pack_in_val = sparse_feature_pack[sparse_in_row * vector_num_pack + col]; #pragma unroll for (int t = 0; t < pack_size; ++t) { pack_in_val.elem[t] = wmma.Convert(pack_in_val.elem[t]); } in_buf_pack[buf_row * in_shared_mem_cols_num_pack + col] = pack_in_val; } } } Pack zero; #pragma unroll for (int k = 0; k < pack_size; ++k) { zero.elem[k] = wmma.Convert(0); } #pragma unroll for (int row = features_dim + threadIdx.y; row < padded_num_rows; row += blockDim.y) { for (int col = threadIdx.x; col < padded_vector_num_pack; col += blockDim.x) { in_buf_pack[row * in_shared_mem_cols_num_pack + col] = zero; } } for (int row = threadIdx.y; row < features_dim; row += blockDim.y) { for (int col = vector_num_pack + threadIdx.x; col < padded_vector_num_pack; col += blockDim.x) { in_buf_pack[row * in_shared_mem_cols_num_pack + col] = zero; } } __syncthreads(); for (int blocks_id = warp_id; blocks_id < m_num_tiles 
* n_num_tiles; blocks_id += blockDim.y) { int blocks_row = blocks_id / n_num_tiles; int blocks_col = blocks_id - blocks_row * n_num_tiles; wmma.InitAcc(); for (int step = 0; step < k_num_tiles; ++step) { // blocks_row is a row_id, step is a col_id. blocks_col is b col_id, // step is b row_id. T* tile_a_ptr = matrix_out_grad_buf + blocks_row * mn_tile_dim * matrix_out_grad_shared_mem_cols + step * k_tile_dim; T* tile_b_ptr = in_buf + step * k_tile_dim * in_shared_mem_cols + blocks_col * mn_tile_dim; wmma.LoadA(tile_a_ptr, matrix_out_grad_shared_mem_cols); wmma.LoadB(tile_b_ptr, in_shared_mem_cols); wmma.Mma(); } ComputeType* tile_ptr = in_grad_buf + blocks_row * mn_tile_dim * in_shared_mem_cols + blocks_col * mn_tile_dim; wmma.Store(tile_ptr, in_shared_mem_cols); } __syncthreads(); // 4.split in_grad buf to dx // shared_mem to dense dx for (int col = threadIdx.x; col < vector_num_pack; col += blockDim.x) { #pragma unroll for (int i = 0; i < max_in; ++i) { if (i >= param.num_in) { break; } Pack* batch_in_grad = reinterpret_cast*>(param.in_grad[i]) + batch_idx * param.in_feature_dim[i] * vector_num_pack; for (int j = threadIdx.y * kUnrollDim; j < param.in_feature_dim[i]; j += blockDim.y * kUnrollDim) { #pragma unroll for (int k = 0; k < kUnrollDim; ++k) { int in_row = j + k; if (in_row >= param.in_feature_dim[i]) { break; } int buf_row = param.dim_start_offset[i] + in_row; Pack grad_val; Pack buf_grad_val = in_grad_buf_pack[buf_row * in_shared_mem_cols_num_pack + col]; #pragma unroll for (int t = 0; t < pack_size; ++t) { grad_val.elem[t] = static_cast(buf_grad_val.elem[t]); } batch_in_grad[in_row * vector_num_pack + col] = grad_val; } } } } // shared_mem to sparse dx, sparse in grad use sparse_grad_pack_size Pack* in_grad_buf_sparse_grad_pack = reinterpret_cast*>(in_grad_buf); Pack* sparse_feature_grad_pack = reinterpret_cast*>(param.sparse_feature_grad); for (int col = threadIdx.x; col < vector_num_sparse_grad_pack; col += blockDim.x) { for (int j = threadIdx.y * kUnrollDim; j < param.sparse_dim; j += blockDim.y * kUnrollDim) { #pragma unroll for (int k = 0; k < kUnrollDim; ++k) { int in_row = j + k; if (in_row >= param.sparse_dim) { break; } int buf_row = param.sparse_dim_start + in_row; int sparse_in_row = batch_sparse_indices[in_row]; Pack buf_grad_val = in_grad_buf_sparse_grad_pack[buf_row * in_shared_mem_cols_num_sparse_grad_pack + col]; AtomicAdd( sparse_feature_grad_pack + sparse_in_row * vector_num_sparse_grad_pack + col, buf_grad_val); } } } #else __trap(); #endif // __CUDA_ARCH__ >= 700 } template struct DotFeatureInteractionBackwardKernel { static bool Launch(ep::Stream* stream, int64_t batch_size, int concated_padded_dim, int vector_size, int out_num_cols, bool self_interaction, const DotBwdParam& param) { const int block_size = 256; const int block_dim_x = 32; const int block_dim_y = block_size / block_dim_x; const int num_blocks = batch_size; const int mn_tile_dim = 16; const int k_tile_dim = KTileDim::val; const int64_t padded_vector_size = GetPaddedDim(vector_size); const int m_num_tiles = concated_padded_dim / mn_tile_dim; const int k_num_tiles = concated_padded_dim / k_tile_dim; const int n_num_tiles = padded_vector_size / mn_tile_dim; const int skew_in = 8; const int in_shared_mem_num_cols = padded_vector_size + skew_in; const int matrix_out_grad_shared_mem_cols = concated_padded_dim + skew_in; const size_t in_shared_mem_bytes = concated_padded_dim * in_shared_mem_num_cols * sizeof(T); const size_t matrix_out_grad_shared_mem_bytes = concated_padded_dim * 
matrix_out_grad_shared_mem_cols * sizeof(T); using ComputeType = typename DefaultComputeType::type; const size_t in_grad_shared_mem_bytes = concated_padded_dim * in_shared_mem_num_cols * sizeof(ComputeType); const size_t total_shared_mem_bytes = in_shared_mem_bytes + matrix_out_grad_shared_mem_bytes + in_grad_shared_mem_bytes; const int32_t offset = self_interaction ? 1 : 0; const int vector_num_pack = vector_size / pack_size; const int padded_vector_num_pack = padded_vector_size / pack_size; const int in_shared_mem_cols_num_pack = in_shared_mem_num_cols / pack_size; const int vector_num_sparse_grad_pack = vector_size / sparse_grad_pack_size; const int in_shared_mem_cols_num_sparse_grad_pack = in_shared_mem_num_cols / sparse_grad_pack_size; int max_active_blocks; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, DotFeatureInteractionBackwardWmmaImpl, block_size, total_shared_mem_bytes)); if (max_active_blocks <= 0) { return false; } cudaStream_t cuda_stream = stream->As()->cuda_stream(); DotFeatureInteractionBackwardWmmaImpl <<>>( m_num_tiles, n_num_tiles, k_num_tiles, batch_size, concated_padded_dim, vector_num_pack, vector_num_sparse_grad_pack, padded_vector_num_pack, out_num_cols, in_shared_mem_num_cols, in_shared_mem_cols_num_pack, in_shared_mem_cols_num_sparse_grad_pack, matrix_out_grad_shared_mem_cols, offset, param); return true; } }; template __global__ void MemsetGpu(int64_t parallel_num, int64_t vector_size, const uint32_t* num_valid, T* dst) { size_t count = 0; for (int i = 0; i < parallel_num; ++i) { count += num_valid[i] * vector_size; } const size_t pack_count = count / pack; Pack pack_value; for (int i = 0; i < pack; ++i) { pack_value.elem[i] = static_cast(0); } auto* pack_dst = reinterpret_cast*>(dst); CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value; } T* tail_dst = dst + pack_count * pack; const size_t tail_count = count - pack_count * pack; CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = static_cast(0); } } template typename std::enable_if<(pack != 0), void>::type LaunchPackMemsetGpu(cudaStream_t stream, const uint32_t* num_valid, T* ptr, size_t sm_count, int64_t vector_size, int64_t parallel_num) { MemsetGpu<<<2 * sm_count, 1024, 0, stream>>>(parallel_num, vector_size, num_valid, ptr); } template typename std::enable_if<(pack == 0), void>::type LaunchPackMemsetGpu(cudaStream_t stream, const uint32_t* num_valid, T* ptr, size_t sm_count, int64_t vector_size, int64_t parallel_num) { LOG(FATAL) << "wrong alignment"; } template void LaunchMemset(cudaStream_t stream, size_t sm_count, int64_t vector_size, int64_t parallel_num, const uint32_t* num_valid, T* ptr) { auto uintptr = reinterpret_cast(ptr); if (uintptr % 16 == 0) { LaunchPackMemsetGpu(stream, num_valid, ptr, sm_count, vector_size, parallel_num); } else if (uintptr % 8 == 0) { LaunchPackMemsetGpu(stream, num_valid, ptr, sm_count, vector_size, parallel_num); } else if (uintptr % 4 == 0) { LaunchPackMemsetGpu(stream, num_valid, ptr, sm_count, vector_size, parallel_num); } else if (uintptr % 2 == 0) { LaunchPackMemsetGpu(stream, num_valid, ptr, sm_count, vector_size, parallel_num); } else { LaunchPackMemsetGpu(stream, num_valid, ptr, sm_count, vector_size, parallel_num); } } template bool DispatchFeatureInteractionDotPackSize(user_op::KernelComputeContext* ctx, const int32_t input_size) { CHECK_LE(input_size, max_in) << input_size; user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t batch_size = out->shape_view().At(0); const 
int64_t out_num_cols = out->shape_view().At(1); const int64_t vector_size = ctx->TensorDesc4ArgNameAndIndex("features", 0)->shape().At(2); DotFwdParam param; param.num_in = input_size; param.out = out->mut_dptr(); int64_t features_concated_dim = 0; for (int i = 0; i < input_size; ++i) { param.in[i] = ctx->Tensor4ArgNameAndIndex("features", i)->dptr(); param.in_feature_dim[i] = ctx->TensorDesc4ArgNameAndIndex("features", i)->shape().At(1); param.dim_start_offset[i] = features_concated_dim; features_concated_dim += param.in_feature_dim[i]; } if (ctx->has_input("sparse_feature", 0)) { CHECK(ctx->has_input("sparse_indices", 0)); const user_op::Tensor* sparse_feature = ctx->Tensor4ArgNameAndIndex("sparse_feature", 0); const user_op::Tensor* sparse_indices = ctx->Tensor4ArgNameAndIndex("sparse_indices", 0); param.sparse_feature = sparse_feature->dptr(); CHECK_EQ(sparse_indices->data_type(), DataType::kUInt32); param.sparse_indices = reinterpret_cast(sparse_indices->dptr()); param.sparse_dim = ctx->TensorDesc4ArgNameAndIndex("sparse_indices", 0)->shape().At(1); param.sparse_dim_start = features_concated_dim; features_concated_dim += param.sparse_dim; } else { param.sparse_feature = nullptr; param.sparse_indices = nullptr; param.sparse_dim = 0; param.sparse_dim_start = 0; } const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim); param.features_dim = features_concated_dim; if (ctx->has_input("output_concat", 0)) { const user_op::Tensor* output_concat = ctx->Tensor4ArgNameAndIndex("output_concat", 0); param.output_concat = output_concat->dptr(); param.output_concat_size = output_concat->shape_view().At(1); } else { param.output_concat = nullptr; param.output_concat_size = 0; } const bool self_interaction = ctx->Attr("self_interaction"); const int32_t output_padding = ctx->Attr("output_padding"); if (vector_size % 4 == 0 && out_num_cols % 4 == 0) { return DotFeatureInteractionKernel::Launch( ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction, output_padding, param); } else if (vector_size % 2 == 0 && out_num_cols % 2 == 0) { return DotFeatureInteractionKernel::Launch( ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction, output_padding, param); } else { return DotFeatureInteractionKernel::Launch( ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction, output_padding, param); } } template bool DispatchFeatureInteractionDotBackwardPackSize(user_op::KernelComputeContext* ctx, const int32_t input_size) { CHECK_LE(input_size, max_in) << input_size; user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const int64_t batch_size = dy->shape_view().At(0); const int64_t out_num_cols = dy->shape_view().At(1); const int64_t vector_size = ctx->TensorDesc4ArgNameAndIndex("features", 0)->shape().At(2); DotBwdParam param; param.num_in = input_size; param.out_grad = dy->dptr(); int64_t features_concated_dim = 0; for (int i = 0; i < input_size; ++i) { param.in[i] = ctx->Tensor4ArgNameAndIndex("features", i)->dptr(); param.in_grad[i] = ctx->Tensor4ArgNameAndIndex("features_grad", i)->mut_dptr(); param.in_feature_dim[i] = ctx->TensorDesc4ArgNameAndIndex("features", i)->shape().At(1); param.dim_start_offset[i] = features_concated_dim; features_concated_dim += param.in_feature_dim[i]; } if (ctx->has_input("sparse_feature", 0)) { CHECK(ctx->has_input("sparse_indices", 0)); CHECK(ctx->has_input("num_valid_sparse_feature", 0)); CHECK(ctx->has_output("sparse_feature_grad", 0)); const 
user_op::Tensor* sparse_feature = ctx->Tensor4ArgNameAndIndex("sparse_feature", 0); const user_op::Tensor* sparse_indices = ctx->Tensor4ArgNameAndIndex("sparse_indices", 0); const user_op::Tensor* num_valid_sparse_feature = ctx->Tensor4ArgNameAndIndex("num_valid_sparse_feature", 0); param.sparse_feature = sparse_feature->dptr(); CHECK_EQ(sparse_indices->data_type(), DataType::kUInt32); param.sparse_indices = reinterpret_cast(sparse_indices->dptr()); param.sparse_dim = ctx->TensorDesc4ArgNameAndIndex("sparse_indices", 0)->shape().At(1); param.sparse_dim_start = features_concated_dim; features_concated_dim += param.sparse_dim; param.sparse_feature_grad = ctx->Tensor4ArgNameAndIndex("sparse_feature_grad", 0)->mut_dptr(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); CHECK_EQ(num_valid_sparse_feature->data_type(), DataType::kUInt32); LaunchMemset(ctx->stream()->As()->cuda_stream(), ctx->stream()->As()->device_properties().multiProcessorCount, vector_size, parallel_num, reinterpret_cast(num_valid_sparse_feature->dptr()) + parallel_id * parallel_num, param.sparse_feature_grad); } else { param.sparse_feature = nullptr; param.sparse_indices = nullptr; param.sparse_feature_grad = nullptr; param.sparse_dim = 0; param.sparse_dim_start = 0; } const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim); param.features_dim = features_concated_dim; if (ctx->has_output("output_concat_grad", 0)) { user_op::Tensor* output_concat_grad = ctx->Tensor4ArgNameAndIndex("output_concat_grad", 0); param.output_concat_grad = output_concat_grad->mut_dptr(); param.output_concat_size = output_concat_grad->shape_view().At(1); } else { param.output_concat_grad = nullptr; param.output_concat_size = 0; } const bool self_interaction = ctx->Attr("self_interaction"); if (vector_size % 4 == 0) { return DotFeatureInteractionBackwardKernel::Launch( ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction, param); } else if (vector_size % 2 == 0) { return DotFeatureInteractionBackwardKernel::Launch( ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction, param); } else { if (ctx->has_input("sparse_feature", 0) && dy->data_type() == DataType::kFloat16) { UNIMPLEMENTED() << "fused dot interaction backward kernel not support sparse_feature with pack_size 1, " "because atomicAdd(half) is too slow"; return false; } return DotFeatureInteractionBackwardKernel::Launch( ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction, param); } } template struct Param { const T* in[max_in]; int32_t in_feature_dim[max_in]; T* out; int32_t num_in; }; template __global__ void FeatureInteractionSum(int64_t batch_size, int64_t vector_num_pack, Param param) { using ComputeType = typename DefaultComputeType::type; Pack* dst_pack = reinterpret_cast*>(param.out); for (int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; batch_idx < batch_size; batch_idx += gridDim.x * blockDim.y) { Pack* batch_out = dst_pack + batch_idx * vector_num_pack; for (int col_id = threadIdx.x; col_id < vector_num_pack; col_id += blockDim.x) { Pack sum; Pack square_sum; #pragma unroll for (int k = 0; k < pack_size; ++k) { sum.elem[k] = static_cast(0); square_sum.elem[k] = static_cast(0); } for (int i = 0; i < max_in; ++i) { if (i >= param.num_in) { break; } const Pack* batch_in = reinterpret_cast*>(param.in[i]) + batch_idx * param.in_feature_dim[i] * vector_num_pack; #pragma unroll 
for (int j = 0; j < param.in_feature_dim[i]; ++j) { Pack val = batch_in[j * vector_num_pack + col_id]; #pragma unroll for (int k = 0; k < pack_size; ++k) { const ComputeType compute_val = static_cast(val.elem[k]); sum.elem[k] += compute_val; square_sum.elem[k] += compute_val * compute_val; } } } Pack out; #pragma unroll for (int k = 0; k < pack_size; ++k) { out.elem[k] = static_cast((sum.elem[k] * sum.elem[k] - square_sum.elem[k]) * static_cast(0.5)); } batch_out[col_id] = out; } } } template struct GradParam { const T* out_grad; const T* in[max_in]; int32_t in_feature_dim[max_in]; T* in_grad[max_in]; int32_t num_in; }; template __global__ void FeatureInteractionSumGrad(int64_t batch_size, int64_t vector_size, GradParam param) { using ComputeType = typename DefaultComputeType::type; for (int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; batch_idx < batch_size; batch_idx += gridDim.x * blockDim.y) { const T* batch_out_grad = param.out_grad + batch_idx * vector_size; for (int col_id = threadIdx.x; col_id < vector_size; col_id += blockDim.x) { ComputeType sum = 0; for (int i = 0; i < max_in; ++i) { if (i >= param.num_in) { break; } const T* batch_in = param.in[i] + batch_idx * param.in_feature_dim[i] * vector_size; for (int j = 0; j < param.in_feature_dim[i]; ++j) { sum += static_cast(batch_in[j * vector_size + col_id]); } } for (int i = 0; i < max_in; ++i) { if (i >= param.num_in) { break; } const int64_t in_batch_offset = batch_idx * param.in_feature_dim[i] * vector_size; const T* batch_in = param.in[i] + in_batch_offset; T* batch_in_grad = param.in_grad[i] + in_batch_offset; for (int j = 0; j < param.in_feature_dim[i]; ++j) { const int64_t offset = j * vector_size + col_id; batch_in_grad[offset] = static_cast(static_cast(batch_out_grad[col_id]) * (sum - static_cast(batch_in[offset]))); } } } } } void GetBlockDims(const int64_t vector_size, int* block_dim_x, int* block_dim_y) { const int block_size = 256; if (vector_size < block_size) { *block_dim_x = std::ceil(static_cast(vector_size) / 8) * 8; *block_dim_y = (block_size + *block_dim_x - 1) / *block_dim_x; } else { *block_dim_x = block_size; *block_dim_y = 1; } } int GetNumBlocks(const int64_t num_instances, const int64_t instance_per_block) { int max_blocks = (num_instances + instance_per_block - 1) / instance_per_block; return std::min(max_blocks, kCudaMaxBlocksNum); } template void DispatchFeatureInteractionSumPackSize(ep::Stream* stream, const int64_t batch_size, const int64_t vector_size, const Param& param) { int block_dim_x; int block_dim_y; const int pack_size = (vector_size % 2 == 0) ? 
2 : 1; const int64_t vector_num_pack = vector_size / pack_size; GetBlockDims(vector_num_pack, &block_dim_x, &block_dim_y); const int num_blocks = GetNumBlocks(batch_size, block_dim_y); dim3 block_dims = dim3(block_dim_x, block_dim_y); cudaStream_t cuda_stream = stream->As()->cuda_stream(); if (pack_size == 2) { FeatureInteractionSum <<>>(batch_size, vector_num_pack, param); } else { FeatureInteractionSum <<>>(batch_size, vector_num_pack, param); } } template void DispatchFeatureInteractionSumInputSize(user_op::KernelComputeContext* ctx, const int32_t input_size) { CHECK_LE(input_size, max_in) << input_size; user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t batch_size = out->shape_view().At(0); const int64_t vector_size = out->shape_view().At(1); Param param; param.num_in = input_size; param.out = out->mut_dptr(); for (int i = 0; i < input_size; ++i) { param.in[i] = ctx->Tensor4ArgNameAndIndex("features", i)->dptr(); param.in_feature_dim[i] = ctx->TensorDesc4ArgNameAndIndex("features", i)->shape().At(1); } DispatchFeatureInteractionSumPackSize(ctx->stream(), batch_size, vector_size, param); } template void DispatchFeatureInteractionSumGradInputSize(user_op::KernelComputeContext* ctx, const int32_t input_size) { CHECK_LE(input_size, max_in) << input_size; const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const int64_t batch_size = dy->shape_view().At(0); const int64_t vector_size = dy->shape_view().At(1); int block_dim_x; int block_dim_y; GetBlockDims(vector_size, &block_dim_x, &block_dim_y); const int num_blocks = GetNumBlocks(batch_size, block_dim_y); dim3 block_dims = dim3(block_dim_x, block_dim_y); GradParam param; param.num_in = input_size; param.out_grad = dy->dptr(); for (int i = 0; i < input_size; ++i) { param.in[i] = ctx->Tensor4ArgNameAndIndex("features", i)->dptr(); param.in_grad[i] = ctx->Tensor4ArgNameAndIndex("features_grad", i)->mut_dptr(); param.in_feature_dim[i] = ctx->TensorDesc4ArgNameAndIndex("features_grad", i)->shape().At(1); } FeatureInteractionSumGrad <<stream()->As()->cuda_stream()>>>( batch_size, vector_size, param); } } // namespace template class FusedDotFeatureInteractionPoolingSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedDotFeatureInteractionPoolingSumKernel() = default; ~FusedDotFeatureInteractionPoolingSumKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { CHECK(!ctx->has_input("sparse_feature", 0)) << "pooling sum, sparse_feature is not supported. "; const int input_size = ctx->input_size("features"); if (input_size == 1) { DispatchFeatureInteractionSumInputSize(ctx, input_size); } else if (input_size == 2) { DispatchFeatureInteractionSumInputSize(ctx, input_size); } else if (input_size <= 8) { DispatchFeatureInteractionSumInputSize(ctx, input_size); } else { CHECK_LE(input_size, 128) << "input_size must not greater than 128. 
"; DispatchFeatureInteractionSumInputSize(ctx, input_size); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_dot_feature_interaction") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value) \ && (user_op::HobAttr("pooling") == "sum")); REGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_KERNEL(float) REGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_KERNEL(half) template bool TryLaunchTensorCoreDotKernel(user_op::KernelComputeContext* ctx) { const int input_size = ctx->input_size("features"); if (input_size == 1) { return DispatchFeatureInteractionDotPackSize(ctx, input_size); } else if (input_size == 2) { return DispatchFeatureInteractionDotPackSize(ctx, input_size); } else if (input_size <= 8) { return DispatchFeatureInteractionDotPackSize(ctx, input_size); } else { CHECK_LE(input_size, 128) << "input_size must not greater than 128. "; return DispatchFeatureInteractionDotPackSize(ctx, input_size); } } template bool TryLaunchTensorCoreDotBackwardKernel(user_op::KernelComputeContext* ctx) { const int input_size = ctx->input_size("features"); if (input_size == 1) { return DispatchFeatureInteractionDotBackwardPackSize(ctx, input_size); } else if (input_size == 2) { return DispatchFeatureInteractionDotBackwardPackSize(ctx, input_size); } else if (input_size <= 8) { return DispatchFeatureInteractionDotBackwardPackSize(ctx, input_size); } else { CHECK_LE(input_size, 128) << "input_size must not greater than 128. "; return DispatchFeatureInteractionDotBackwardPackSize(ctx, input_size); } } template class FusedDotFeatureInteractionKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedDotFeatureInteractionKernel() = default; ~FusedDotFeatureInteractionKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const DataType data_type = out->data_type(); CHECK_LT(out->shape_view().elem_cnt(), GetMaxVal()); auto* cuda_stream = ctx->stream()->As(); if ((cuda_stream->device_properties().major >= 7 && data_type == DataType::kFloat16) || (cuda_stream->device_properties().major >= 8 && data_type == DataType::kFloat)) { bool success = TryLaunchTensorCoreDotKernel(ctx); if (success == true) { return; } } CHECK(!ctx->has_input("sparse_feature", 0)) << "sparse_feature is not supported. "; user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t batch_size = out->shape_view().At(0); int64_t features_concated_dim = 0; for (int64_t i = 0; i < ctx->input_size("features"); ++i) { features_concated_dim += ctx->TensorDesc4ArgNameAndIndex("features", i)->shape().At(1); } const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim); const int64_t vector_size = ctx->TensorDesc4ArgNameAndIndex("features", 0)->shape().At(2); const int64_t out_dim = out->shape_view().At(1); const int32_t output_padding = ctx->Attr("output_padding"); const int64_t valid_out_dim = out_dim - output_padding; const bool self_interaction = ctx->Attr("self_interaction"); T* matmul_out = reinterpret_cast(tmp_buffer->mut_dptr()); size_t matmul_out_size = GetCudaAlignedSize(batch_size * concated_padded_dim * concated_padded_dim * sizeof(T)); const int64_t interaction_dim = self_interaction ? 
features_concated_dim * (features_concated_dim + 1) / 2 : features_concated_dim * (features_concated_dim - 1) / 2; int32_t* gather_indices_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + matmul_out_size); size_t gather_indices_size = GetCudaAlignedSize(interaction_dim * sizeof(int32_t)); T* padded_concated_features_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + matmul_out_size + gather_indices_size); size_t padded_concated_features_size = GetCudaAlignedSize(batch_size * concated_padded_dim * vector_size * sizeof(T)); CHECK_GE(tmp_buffer->shape_view().elem_cnt(), matmul_out_size + gather_indices_size + padded_concated_features_size); ConcatFeatures(ctx, batch_size, concated_padded_dim * vector_size, padded_concated_features_ptr); auto batch_matmul = ep::primitive::NewPrimitive( ctx->device_type(), data_type, ep::primitive::BlasTransposeType::N, ep::primitive::BlasTransposeType::T); batch_matmul->Launch(ctx->stream(), batch_size, concated_padded_dim, concated_padded_dim, vector_size, 1.0, padded_concated_features_ptr, padded_concated_features_ptr, 0.0, matmul_out); int64_t output_concat_end_dim = 0; const T* output_concat_ptr = nullptr; if (ctx->has_input("output_concat", 0)) { user_op::Tensor* output_concat = ctx->Tensor4ArgNameAndIndex("output_concat", 0); output_concat_end_dim = output_concat->shape_view().At(1); output_concat_ptr = output_concat->dptr(); } CHECK_EQ(valid_out_dim, output_concat_end_dim + interaction_dim); GatherConcatKernel(ctx->stream(), out->shape_view().elem_cnt(), out_dim, valid_out_dim, features_concated_dim, concated_padded_dim, output_concat_end_dim, self_interaction, matmul_out, output_concat_ptr, gather_indices_ptr, out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenFusedDotFeatureInteractionInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const Shape& first_feature_shape = ctx->InputShape("features", 0); const int64_t batch_size = first_feature_shape.At(0); const int64_t vector_size = first_feature_shape.At(2); int64_t features_concated_dim = 0; for (int32_t i = 0; i < ctx->input_size("features"); ++i) { features_concated_dim += ctx->InputShape("features", i).At(1); } const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim); size_t matmul_out_size = GetCudaAlignedSize(batch_size * concated_padded_dim * concated_padded_dim * sizeof(T)); const bool self_interaction = ctx->Attr("self_interaction"); const int64_t interaction_dim = self_interaction ? 
features_concated_dim * (features_concated_dim + 1) / 2 : features_concated_dim * (features_concated_dim - 1) / 2; size_t gather_indices_size = GetCudaAlignedSize(interaction_dim * sizeof(int32_t)); size_t padded_concated_features_size = GetCudaAlignedSize(batch_size * concated_padded_dim * vector_size * sizeof(T)); return matmul_out_size + gather_indices_size + padded_concated_features_size; }; } #define REGISTER_FUSED_DOT_FEATURE_INTERACTION_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_dot_feature_interaction") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value) \ && (user_op::HobAttr("pooling") == "none")) \ .SetInferTmpSizeFn(GenFusedDotFeatureInteractionInferTmpSizeFn()); REGISTER_FUSED_DOT_FEATURE_INTERACTION_KERNEL(float) REGISTER_FUSED_DOT_FEATURE_INTERACTION_KERNEL(half) template class FusedDotFeatureInteractionGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedDotFeatureInteractionGradKernel() = default; ~FusedDotFeatureInteractionGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const DataType data_type = dy->data_type(); auto* cuda_stream = ctx->stream()->As(); if ((cuda_stream->device_properties().major >= 7 && data_type == DataType::kFloat16) || (cuda_stream->device_properties().major >= 8 && data_type == DataType::kFloat)) { bool success = TryLaunchTensorCoreDotBackwardKernel(ctx); if (success == true) { return; } } CHECK(!ctx->has_input("sparse_feature", 0)) << "sparse_feature is not supported. 
"; const int64_t batch_size = dy->shape_view().At(0); int64_t features_concated_dim = 0; for (int32_t i = 0; i < ctx->output_size("features_grad"); ++i) { features_concated_dim += ctx->TensorDesc4ArgNameAndIndex("features_grad", i)->shape().At(1); } const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim); const int64_t vector_size = ctx->TensorDesc4ArgNameAndIndex("features_grad", 0)->shape().At(2); const int64_t out_dim = dy->shape_view().At(1); const bool self_interaction = ctx->Attr("self_interaction"); T* matmul_out_grad_ptr = reinterpret_cast(tmp_buffer->mut_dptr()); size_t matmul_out_grad_size = GetCudaAlignedSize(batch_size * concated_padded_dim * concated_padded_dim * sizeof(T)); T* padded_concated_features_grad_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + matmul_out_grad_size); size_t padded_concated_features_grad_size = GetCudaAlignedSize(batch_size * concated_padded_dim * vector_size * sizeof(T)); T* padded_concated_features_ptr = reinterpret_cast( tmp_buffer->mut_dptr() + matmul_out_grad_size + padded_concated_features_grad_size); size_t padded_concated_features_size = padded_concated_features_grad_size; CHECK_LE( matmul_out_grad_size + padded_concated_features_grad_size + padded_concated_features_size, tmp_buffer->shape_view().elem_cnt()); ConcatFeatures(ctx, batch_size, concated_padded_dim * vector_size, padded_concated_features_ptr); T* output_concat_grad_ptr = nullptr; int64_t output_concat_end_dim = 0; if (ctx->has_output("output_concat_grad", 0)) { user_op::Tensor* output_concat_grad = ctx->Tensor4ArgNameAndIndex("output_concat_grad", 0); output_concat_grad_ptr = output_concat_grad->mut_dptr(); output_concat_end_dim = output_concat_grad->shape_view().At(1); } ScatterSplitAddTranspose(ctx->stream(), batch_size, out_dim, concated_padded_dim, features_concated_dim, output_concat_end_dim, self_interaction, dy->dptr(), output_concat_grad_ptr, matmul_out_grad_ptr); auto batch_matmul = ep::primitive::NewPrimitive( ctx->device_type(), data_type, ep::primitive::BlasTransposeType::N, ep::primitive::BlasTransposeType::N); batch_matmul->Launch(ctx->stream(), batch_size, concated_padded_dim, vector_size, concated_padded_dim, 1.0, matmul_out_grad_ptr, padded_concated_features_ptr, 0.0, padded_concated_features_grad_ptr); ConcatFeaturesGrad(ctx, batch_size, concated_padded_dim, vector_size, padded_concated_features_grad_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenFusedDotFeatureInteractionGradInferTmpSizeFn() { return [](user_op::InferContext* ctx) { int64_t features_concated_dim = 0; for (int32_t i = 0; i < ctx->output_size("features_grad"); ++i) { features_concated_dim += ctx->InputShape("features_grad", i).At(1); } const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim); const int64_t batch_size = ctx->InputShape("features_grad", 0).At(0); const int64_t vector_size = ctx->InputShape("features_grad", 0).At(2); size_t matmul_out_grad_size = GetCudaAlignedSize(batch_size * concated_padded_dim * concated_padded_dim * sizeof(T)); size_t padded_concated_features_grad_size = GetCudaAlignedSize(batch_size * concated_padded_dim * vector_size * sizeof(T)); size_t padded_concated_features_size = padded_concated_features_grad_size; return matmul_out_grad_size + padded_concated_features_grad_size + padded_concated_features_size; }; } #define REGISTER_FUSED_DOT_FEATURE_INTERACTION_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_dot_feature_interaction_grad") \ .SetCreateFn>() \ 
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \
                       && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)          \
                       && (user_op::HobAttr<std::string>("pooling") == "none"))                 \
      .SetInferTmpSizeFn(GenFusedDotFeatureInteractionGradInferTmpSizeFn<dtype>());

REGISTER_FUSED_DOT_FEATURE_INTERACTION_GRAD_KERNEL(float)
REGISTER_FUSED_DOT_FEATURE_INTERACTION_GRAD_KERNEL(half)

template<typename T>
class FusedDotFeatureInteractionPoolingSumGradKernel final : public user_op::OpKernel,
                                                             public user_op::CudaGraphSupport {
 public:
  FusedDotFeatureInteractionPoolingSumGradKernel() = default;
  ~FusedDotFeatureInteractionPoolingSumGradKernel() override = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    // Dispatch on the number of "features" inputs so the fixed-size pointer arrays in the
    // kernel parameter struct stay as small as possible.
    const int input_size = ctx->input_size("features");
    if (input_size == 1) {
      DispatchFeatureInteractionSumGradInputSize<T, 1>(ctx, input_size);
    } else if (input_size == 2) {
      DispatchFeatureInteractionSumGradInputSize<T, 2>(ctx, input_size);
    } else if (input_size <= 8) {
      DispatchFeatureInteractionSumGradInputSize<T, 8>(ctx, input_size);
    } else {
      CHECK_LE(input_size, 128) << "input_size must not be greater than 128. ";
      DispatchFeatureInteractionSumGradInputSize<T, 128>(ctx, input_size);
    }
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_GRAD_KERNEL(dtype)                   \
  REGISTER_USER_KERNEL("fused_dot_feature_interaction_grad")                                    \
      .SetCreateFn<FusedDotFeatureInteractionPoolingSumGradKernel<dtype>>()                     \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \
                       && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)          \
                       && (user_op::HobAttr<std::string>("pooling") == "sum"));

REGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_GRAD_KERNEL(float)
REGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_GRAD_KERNEL(half)

}  // namespace oneflow



================================================
FILE: oneflow/user/kernels/fused_gelu_mul_kernel.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace cuda { namespace fused_gelu { OF_DEVICE_FUNC float TanhApprox(float x) { #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) float r; asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); return r; #else return tanhf(x); #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) } template struct FusedFastGeluMulFunctor { static constexpr T alpha = static_cast(0.7978845608028654); static constexpr T beta = static_cast(0.044714998453855515); OF_DEVICE_FUNC FusedFastGeluMulFunctor() {} OF_DEVICE_FUNC T operator()(T x, T m) const { // ref to UnaryFunctor of kFastGelu const T half = static_cast(0.5); const T one = static_cast(1); const T tanh_in = alpha * (x + beta * x * x * x); return half * x * (one + tanh(tanh_in)) * m; } }; template<> struct FusedFastGeluMulFunctor { static constexpr float alpha = FusedFastGeluMulFunctor::alpha; static constexpr float beta = FusedFastGeluMulFunctor::beta; FusedFastGeluMulFunctor float_functor; OF_DEVICE_FUNC FusedFastGeluMulFunctor() {} OF_DEVICE_FUNC half operator()(const half x, const half m) const { #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) const float tanh_in = __half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x)); const float tanh_out = TanhApprox(tanh_in); return __float2half_rn(0.5F) * x * (__float2half_rn(1.0F) + __float2half_rn(tanh_out)) * m; #else return static_cast(float_functor(static_cast(x), static_cast(m))); #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) } #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) __device__ void Apply2(half* y, const half* x, const half* m) const { const half2 x2 = *(reinterpret_cast(x)); const float2 tanh_in = __half22float2( __hmul2(__float2half2_rn(alpha), __hadd2(x2, __hmul2(__hmul2(__hmul2(__float2half2_rn(beta), x2), x2), x2)))); float2 tanh_out; tanh_out.x = TanhApprox(tanh_in.x); tanh_out.y = TanhApprox(tanh_in.y); const half2 m2 = *(reinterpret_cast(m)); const half2 y2 = __hmul2(__hmul2(__hmul2(__float2half2_rn(0.5F), x2), __hadd2(__float2half2_rn(1.0F), __float22half2_rn(tanh_out))), m2); *reinterpret_cast(y) = y2; } #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) }; #if CUDA_VERSION >= 11000 template<> struct FusedFastGeluMulFunctor { FusedFastGeluMulFunctor float_functor; OF_DEVICE_FUNC FusedFastGeluMulFunctor() {} OF_DEVICE_FUNC nv_bfloat16 operator()(const nv_bfloat16 x, const nv_bfloat16 m) const { return __float2bfloat16(float_functor(__bfloat162float(x), __bfloat162float(m))); } }; #endif // CUDA_VERSION >= 11000 } // namespace fused_gelu template class FusedFastGeluMulKernel final : public user_op::OpKernel { public: FusedFastGeluMulKernel() = default; ~FusedFastGeluMulKernel() override = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* in = ctx->Tensor4ArgNameAndIndex("in", 0); const auto* multiplier = ctx->Tensor4ArgNameAndIndex("multiplier", 0); auto* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t elem_cnt = in->shape_view().elem_cnt(); OF_CUDA_CHECK((elementwise::Binary(fused_gelu::FusedFastGeluMulFunctor(), elem_cnt, out->mut_dptr(), in->dptr(), multiplier->dptr(), ctx->stream()->As()->cuda_stream()))); }; }; #define REGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_fast_gelu_mul") \ 
.SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(float) REGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(double) REGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(nv_bfloat16) #endif namespace fused_gelu { template struct FusedFastGeluMulGradFunctor { static constexpr T alpha = static_cast(0.7978845608028654); static constexpr T beta = static_cast(0.044714998453855515); __device__ FusedFastGeluMulGradFunctor() {} __device__ void operator()(T& x_diff, T& m_diff, const T& dy, const T& x, const T& m) const { const T one = static_cast(1); const T half = static_cast(0.5); const T pow3 = x * x * x; const T tanh_in = alpha * (x + beta * pow3); const T tanh_out = tanh(alpha * (x + beta * pow3)); // calc m_diff ref to UnaryFunctor of kFastGelu m_diff = half * x * (one + tanh(tanh_in)) * dy; // calc x_diff ref to BinaryOp::kFastGeluBackwardWithDyX const T dtanh = alpha * (half * x + beta * static_cast(1.5) * pow3); x_diff = (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out)) * m * dy; } }; template<> struct FusedFastGeluMulGradFunctor { static constexpr float alpha = FusedFastGeluMulGradFunctor::alpha; static constexpr float beta = FusedFastGeluMulGradFunctor::beta; FusedFastGeluMulGradFunctor float_functor; __device__ FusedFastGeluMulGradFunctor() {} __device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x, const half& m) const { #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) const half halpha = __float2half_rn(alpha); const half hbeta = __float2half_rn(beta); const half hone = __float2half_rn(1.0F); const half hhalf = __float2half_rn(0.5F); const half pow3 = x * x * x; const float tanh_in = __half2float(halpha * (x + hbeta * pow3)); const half tanh_out = __float2half_rn(TanhApprox(tanh_in)); // m_diff m_diff = hhalf * x * (hone + tanh_out) * dy; // x_diff const half dtanh = halpha * (hhalf * x + hbeta * __float2half_rn(1.5F) * pow3); x_diff = (hhalf + hhalf * tanh_out + dtanh * (hone - tanh_out * tanh_out)) * m * dy; #else float x_diff_float; float m_diff_float; float_functor(x_diff_float, m_diff_float, static_cast(dy), static_cast(x), static_cast(m)); x_diff = static_cast(x_diff_float); m_diff = static_cast(m_diff_float); #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) } #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) __device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x, const half* m) const { const half2 dy2 = *(reinterpret_cast(dy)); const half2 x2 = *(reinterpret_cast(x)); const half2 m2 = *(reinterpret_cast(m)); const half2 alpha2 = __float2half2_rn(alpha); const half2 beta2 = __float2half2_rn(beta); const half2 one2 = __float2half2_rn(1.0F); const half2 hhalf2 = __float2half2_rn(0.5F); const half2 pow3 = __hmul2(__hmul2(x2, x2), x2); const float2 tanh_in = __half22float2(__hmul2(alpha2, __hadd2(x2, __hmul2(beta2, pow3)))); float2 tanh_out; tanh_out.x = TanhApprox(tanh_in.x); tanh_out.y = TanhApprox(tanh_in.y); const half2 tanh_out2 = __float22half2_rn(tanh_out); // m_diff const half2 m_diff2 = __hmul2(__hmul2(hhalf2, __hmul2(x2, __hadd2(one2, tanh_out2))), dy2); // x_diff const half2 dtanh = __hmul2( alpha2, __hadd2(__hmul2(hhalf2, x2), __hmul2(beta2, __hmul2(pow3, __float2half2_rn(1.5F))))); const half2 x_diff2 = __hmul2(__hmul2(__hadd2(__hadd2(hhalf2, __hmul2(hhalf2, tanh_out2)), __hmul2(dtanh, __hsub2(one2, 
__hmul2(tanh_out2, tanh_out2)))), m2), dy2); *reinterpret_cast(x_diff) = x_diff2; *reinterpret_cast(m_diff) = m_diff2; } #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) }; #if CUDA_VERSION >= 11000 template<> struct FusedFastGeluMulGradFunctor { FusedFastGeluMulGradFunctor float_functor; __device__ FusedFastGeluMulGradFunctor() {} __device__ void operator()(nv_bfloat16& x_diff, nv_bfloat16& m_diff, const nv_bfloat16& dy, const nv_bfloat16& x, const nv_bfloat16& m) const { float x_diff_float; float m_diff_float; float_functor(x_diff_float, m_diff_float, __bfloat162float(dy), __bfloat162float(x), __bfloat162float(m)); x_diff = __float2bfloat16(x_diff_float); m_diff = __float2bfloat16(m_diff_float); } }; #endif // CUDA_VERSION >= 11000 template __device__ __forceinline__ typename std::enable_if::value == true && pack_size % 2 == 0, void>::type FusedFastGeluMulGradFunctorApplyPack(const FunctorT& functor, elementwise::Packed& x_diff_pack, elementwise::Packed& m_diff_pack, const elementwise::Packed& dy_pack, const elementwise::Packed& x_pack, const elementwise::Packed& m_pack) { #pragma unroll for (int j = 0; j < pack_size; j += 2) { functor.Apply2(x_diff_pack.elem + j, m_diff_pack.elem + j, dy_pack.elem + j, x_pack.elem + j, m_pack.elem + j); } } template __device__ __forceinline__ typename std::enable_if::value == false || pack_size % 2 != 0, void>::type FusedFastGeluMulGradFunctorApplyPack(const FunctorT& functor, elementwise::Packed& x_diff_pack, elementwise::Packed& m_diff_pack, const elementwise::Packed& dy_pack, const elementwise::Packed& x_pack, const elementwise::Packed& m_pack) { #pragma unroll for (int j = 0; j < pack_size; ++j) { functor(x_diff_pack.elem[j], m_diff_pack.elem[j], dy_pack.elem[j], x_pack.elem[j], m_pack.elem[j]); } } template __global__ void __launch_bounds__(elementwise::kBlockSize) FusedFastGeluMulGradCudaKernel(int64_t n_pack, elementwise::Packed* x_diff_pack, elementwise::Packed* m_diff_pack, const elementwise::Packed* dy_pack, const elementwise::Packed* x_pack, const elementwise::Packed* m_pack, int64_t n_tail, T* x_diff_tail, T* m_diff_tail, const T* dy_tail, const T* x_tail, const T* m_tail) { FusedFastGeluMulGradFunctor functor; const int global_tid = blockIdx.x * elementwise::kBlockSize + threadIdx.x; for (int64_t i = global_tid; i < n_pack; i += blockDim.x * gridDim.x) { FusedFastGeluMulGradFunctorApplyPack(functor, x_diff_pack[i], m_diff_pack[i], dy_pack[i], x_pack[i], m_pack[i]); } if (global_tid < n_tail) { functor(x_diff_tail[global_tid], m_diff_tail[global_tid], dy_tail[global_tid], x_tail[global_tid], m_tail[global_tid]); } } template cudaError_t LaunchFusedFastGeluMulGradCudaKernelByPack(cudaStream_t stream, int64_t n, T* x_diff, T* m_diff, const T* dy, const T* x, const T* m) { const int64_t n_pack = n / pack_size; const int64_t tail_offset = n_pack * pack_size; const int64_t n_tail = n - tail_offset; int num_blocks; { cudaError_t err = elementwise::GetNumBlocks(n_pack, &num_blocks); if (err != cudaSuccess) { return err; } } FusedFastGeluMulGradCudaKernel<<>>( n_pack, reinterpret_cast*>(x_diff), reinterpret_cast*>(m_diff), reinterpret_cast*>(dy), reinterpret_cast*>(x), reinterpret_cast*>(m), n_tail, x_diff + tail_offset, m_diff + tail_offset, dy + tail_offset, x + tail_offset, m + tail_offset); return cudaPeekAtLastError(); } template static cudaError_t LaunchFusedFastGeluMulGradCudaKernel(cudaStream_t stream, int64_t n, T* x_diff, T* m_diff, const T* dy, const T* x, const T* m) { constexpr int max_pack_size = elementwise::PackSize(); if 
(elementwise::IsAlignedForPack(x_diff, m_diff, dy, x, m)) { return LaunchFusedFastGeluMulGradCudaKernelByPack(stream, n, x_diff, m_diff, dy, x, m); } else { return LaunchFusedFastGeluMulGradCudaKernelByPack<1>(stream, n, x_diff, m_diff, dy, x, m); } } } // namespace fused_gelu template class FusedFastGeluMulGradKernel final : public user_op::OpKernel { public: FusedFastGeluMulGradKernel() = default; ~FusedFastGeluMulGradKernel() override = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* out_diff = ctx->Tensor4ArgNameAndIndex("out_diff", 0); const auto* in = ctx->Tensor4ArgNameAndIndex("in", 0); const auto* multiplier = ctx->Tensor4ArgNameAndIndex("multiplier", 0); auto* in_diff = ctx->Tensor4ArgNameAndIndex("in_diff", 0); auto* multiplier_diff = ctx->Tensor4ArgNameAndIndex("multiplier_diff", 0); int64_t elem_cnt = in->shape_view().elem_cnt(); OF_CUDA_CHECK((fused_gelu::LaunchFusedFastGeluMulGradCudaKernel( ctx->stream()->As()->cuda_stream(), elem_cnt, in_diff->mut_dptr(), multiplier_diff->mut_dptr(), out_diff->dptr(), in->dptr(), multiplier->dptr()))); }; }; #define REGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_fast_gelu_mul_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out_diff", 0) == GetDataType::value)); REGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(double) REGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(nv_bfloat16) #endif } // namespace cuda } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_get_bounding_boxes_coord_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template __global__ void FusedGetBounddingBoxesCoordForward(const int n, const T* x1, const T* y1, const T* w1, const T* h1, const T* x2, const T* y2, const T* w2, const T* h2, T* b1_x1, T* b1_x2, T* b1_y1, T* b1_y2, T* b2_x1, T* b2_x2, T* b2_y1, T* b2_y2) { CUDA_1D_KERNEL_LOOP(i, n) { const T w1_ = w1[i] / static_cast(2.0); const T h1_ = h1[i] / static_cast(2.0); const T w2_ = w2[i] / static_cast(2.0); const T h2_ = h2[i] / static_cast(2.0); const T x1_i = x1[i], y1_i = y1[i], x2_i = x2[i], y2_i = y2[i]; b1_x1[i] = x1_i - w1_; b1_x2[i] = x1_i + w1_; b1_y1[i] = y1_i - h1_; b1_y2[i] = y1_i + h1_; b2_x1[i] = x2_i - w2_; b2_x2[i] = x2_i + w2_; b2_y1[i] = y2_i - h2_; b2_y2[i] = y2_i + h2_; } } template __global__ void FusedGetBounddingBoxesCoordBackward( const int n, const T* b1_x1_diff, const T* b1_x2_diff, const T* b1_y1_diff, const T* b1_y2_diff, const T* b2_x1_diff, const T* b2_x2_diff, const T* b2_y1_diff, const T* b2_y2_diff, T* x1_diff, T* y1_diff, T* w1_diff, T* h1_diff, T* x2_diff, T* y2_diff, T* w2_diff, T* h2_diff) { CUDA_1D_KERNEL_LOOP(i, n) { const T b1_x1_diff_i = b1_x1_diff[i]; const T b1_x2_diff_i = b1_x2_diff[i]; const T b1_y1_diff_i = b1_y1_diff[i]; const T b1_y2_diff_i = b1_y2_diff[i]; const T b2_x1_diff_i = b2_x1_diff[i]; const T b2_x2_diff_i = b2_x2_diff[i]; const T b2_y2_diff_i = b2_y2_diff[i]; const T b2_y1_diff_i = b2_y1_diff[i]; x1_diff[i] = b1_x1_diff_i + b1_x2_diff_i; y1_diff[i] = b1_y1_diff_i + b1_y2_diff_i; w1_diff[i] = (b1_x2_diff_i - b1_x1_diff_i) / static_cast(2.0); h1_diff[i] = (b1_y2_diff_i - b1_y1_diff_i) / static_cast(2.0); x2_diff[i] = b2_x1_diff_i + b2_x2_diff_i; y2_diff[i] = b2_y1_diff_i + b2_y2_diff_i; w2_diff[i] = (b2_x2_diff_i - b2_x1_diff_i) / static_cast(2.0); h2_diff[i] = (b2_y2_diff_i - b2_y1_diff_i) / static_cast(2.0); } } }; // namespace template class FusedGetBounddingBoxesCoordGpuKernel final : public user_op::OpKernel { public: FusedGetBounddingBoxesCoordGpuKernel() = default; ~FusedGetBounddingBoxesCoordGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("x1", 0); const user_op::Tensor* y1 = ctx->Tensor4ArgNameAndIndex("y1", 0); const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex("w1", 0); const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex("h1", 0); const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("x2", 0); const user_op::Tensor* y2 = ctx->Tensor4ArgNameAndIndex("y2", 0); const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex("w2", 0); const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex("h2", 0); user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex("b1_x1", 0); user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex("b1_x2", 0); user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex("b1_y1", 0); user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex("b1_y2", 0); user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex("b2_x1", 0); user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex("b2_x2", 0); user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex("b2_y1", 0); user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex("b2_y2", 0); const int32_t elem_cnt = x1->shape_view().elem_cnt(); RUN_CUDA_KERNEL((FusedGetBounddingBoxesCoordForward), ctx->stream(), elem_cnt, elem_cnt, x1->dptr(), y1->dptr(), w1->dptr(), h1->dptr(), x2->dptr(), y2->dptr(), w2->dptr(), h2->dptr(), 
b1_x1->mut_dptr(), b1_x2->mut_dptr(), b1_y1->mut_dptr(), b1_y2->mut_dptr(), b2_x1->mut_dptr(), b2_x2->mut_dptr(), b2_y1->mut_dptr(), b2_y2->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_boundding_boxes_coord") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("b1_x1", 0) == GetDataType::value)); REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_CUDA_KERNEL(float) REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_CUDA_KERNEL(half) REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_CUDA_KERNEL(double) template class FusedGetBounddingBoxesCoordGradGpuKernel final : public user_op::OpKernel { public: FusedGetBounddingBoxesCoordGradGpuKernel() = default; ~FusedGetBounddingBoxesCoordGradGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* b1_x1_diff = ctx->Tensor4ArgNameAndIndex("b1_x1_diff", 0); const user_op::Tensor* b1_x2_diff = ctx->Tensor4ArgNameAndIndex("b1_x2_diff", 0); const user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex("b1_y1_diff", 0); const user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex("b1_y2_diff", 0); const user_op::Tensor* b2_x1_diff = ctx->Tensor4ArgNameAndIndex("b2_x1_diff", 0); const user_op::Tensor* b2_x2_diff = ctx->Tensor4ArgNameAndIndex("b2_x2_diff", 0); const user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex("b2_y1_diff", 0); const user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex("b2_y2_diff", 0); user_op::Tensor* x1_diff = ctx->Tensor4ArgNameAndIndex("x1_diff", 0); user_op::Tensor* y1_diff = ctx->Tensor4ArgNameAndIndex("y1_diff", 0); user_op::Tensor* w1_diff = ctx->Tensor4ArgNameAndIndex("w1_diff", 0); user_op::Tensor* h1_diff = ctx->Tensor4ArgNameAndIndex("h1_diff", 0); user_op::Tensor* x2_diff = ctx->Tensor4ArgNameAndIndex("x2_diff", 0); user_op::Tensor* y2_diff = ctx->Tensor4ArgNameAndIndex("y2_diff", 0); user_op::Tensor* w2_diff = ctx->Tensor4ArgNameAndIndex("w2_diff", 0); user_op::Tensor* h2_diff = ctx->Tensor4ArgNameAndIndex("h2_diff", 0); const int32_t elem_cnt = b1_x1_diff->shape_view().elem_cnt(); RUN_CUDA_KERNEL((FusedGetBounddingBoxesCoordBackward), ctx->stream(), elem_cnt, elem_cnt, b1_x1_diff->dptr(), b1_x2_diff->dptr(), b1_y1_diff->dptr(), b1_y2_diff->dptr(), b2_x1_diff->dptr(), b2_x2_diff->dptr(), b2_y1_diff->dptr(), b2_y2_diff->dptr(), x1_diff->mut_dptr(), y1_diff->mut_dptr(), w1_diff->mut_dptr(), h1_diff->mut_dptr(), x2_diff->mut_dptr(), y2_diff->mut_dptr(), w2_diff->mut_dptr(), h2_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_boundding_boxes_coord_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("b1_x1_diff", 0) == GetDataType::value)); REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_GRAD_CUDA_KERNEL(half) REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_GRAD_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_get_ciou_diagonal_angle_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/data_type.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template struct FusedCiouAngleForwardFunctor { __device__ T Compute(T w1, T h1, T w2, T h2, float eps) const { T angle = (atan(w2 / (h2 + eps)) - atan(w1 / (h1 + eps))) * (atan(w2 / (h2 + eps)) - atan(w1 / (h1 + eps))); return static_cast(4.0 / (M_PI * M_PI)) * angle; } }; template<> struct FusedCiouAngleForwardFunctor { __device__ half Compute(half w1, half h1, half w2, half h2, float eps) const { float w1f = __half2float(w1); float h1f = __half2float(h1); float w2f = __half2float(w2); float h2f = __half2float(h2); float angle = (atan(w2f / (h2f + eps)) - atan(w1f / (h1f + eps))) * (atan(w2f / (h2f + eps)) - atan(w1f / (h1f + eps))); return __float2half(static_cast(4.0 / (M_PI * M_PI)) * angle); } }; template __global__ void FusedCiouAngleForward(FUNCTOR functor, const int n, const T* w1, const T* h1, const T* w2, const T* h2, const float eps, T* v) { CUDA_1D_KERNEL_LOOP(i, n) { v[i] = functor.Compute(w1[i], h1[i], w2[i], h2[i], eps); } } template struct FusedCiouAngleBackwardFunctor { __device__ T ComputeW1(T h1, T angle_delta, T angle1, float eps) const { return static_cast(-1.0) * angle_delta / ((h1 + eps) * angle1); } __device__ T ComputeW2(T h2, T angle_delta, T angle2, float eps) const { return angle_delta / ((h2 + eps) * angle2); } __device__ T ComputeH1(T w1, T h1, T angle_delta, T angle1, float eps) const { return w1 * angle_delta / ((h1 + eps) * (h1 + eps) * angle1); } __device__ T ComputeH2(T w2, T h2, T angle_delta, T angle2, float eps) const { return static_cast(-1.0) * w2 * angle_delta / ((h2 + eps) * (h2 + eps) * angle2); } }; template<> struct FusedCiouAngleBackwardFunctor { __device__ half ComputeW1(half h1, half angle_delta, half angle1, float eps) const { float h1f = __half2float(h1); float angle_delta_f = __half2float(angle_delta); float angle1f = __half2float(angle1); return __float2half(-1.0 * angle_delta_f / ((h1f + eps) * angle1f)); } __device__ half ComputeW2(half h2, half angle_delta, half angle2, float eps) const { float h2f = __half2float(h2); float angle_delta_f = __half2float(angle_delta); float angle2f = __half2float(angle2); return __float2half(angle_delta_f / ((h2f + eps) * angle2f)); } __device__ half ComputeH1(half w1, half h1, half angle_delta, half angle1, float eps) const { float w1f = __half2float(w1); float h1f = __half2float(h1); float angle_delta_f = __half2float(angle_delta); float angle1f = __half2float(angle1); return __float2half(w1f * angle_delta_f / ((h1f + eps) * (h1f + eps) * angle1f)); } __device__ half ComputeH2(half w2, half h2, half angle_delta, half angle2, float eps) const { float w2f = __half2float(w2); float h2f = __half2float(h2); float angle_delta_f = __half2float(angle_delta); float angle2f = __half2float(angle2); return __float2half(-1.0 * w2f * angle_delta_f / ((h2f + eps) * (h2f + eps) * 
angle2f)); } }; template struct CalcAngleFunctor { __device__ T ComputeDelta(T w1, T h1, T w2, T h2, float eps) const { return static_cast(8.0) * (atan(w2 / (h2 + eps)) - atan(w1 / (h1 + eps))) / static_cast((M_PI * M_PI)); } __device__ T Compute1(T w1, T h1, float eps) const { return static_cast(1.0) + (w1 * w1 / ((h1 + eps) * (h1 + eps))); } __device__ T Compute2(T w2, T h2, float eps) const { return static_cast(1.0) + (w2 * w2 / ((h2 + eps) * (h2 + eps))); } }; template<> struct CalcAngleFunctor { __device__ half ComputeDelta(half w1, half h1, half w2, half h2, float eps) const { float w1f = __half2float(w1); float h1f = __half2float(h1); float w2f = __half2float(w2); float h2f = __half2float(h2); return __float2half(8.0 * (atan(w2f / (h2f + eps)) - atan(w1f / (h1f + eps))) / static_cast((M_PI * M_PI))); } __device__ half Compute1(half w1, half h1, float eps) const { float w1f = __half2float(w1); float h1f = __half2float(h1); return __float2half(1.0 + (w1f * w1f / ((h1f + eps) * (h1f + eps)))); } __device__ half Compute2(half w2, half h2, float eps) const { float w2f = __half2float(w2); float h2f = __half2float(h2); return __float2half(1.0 + (w2f * w2f / ((h2f + eps) * (h2f + eps)))); } }; template __global__ void FusedCiouAngleBackward(FUNCTOR_BACKWARD functor_backward, FUNCTOR_ANGLE functor_angle, const int n, const T* w1, const T* h1, const T* w2, const T* h2, const T* v_diff, const float eps, T* w1_diff, T* h1_diff, T* w2_diff, T* h2_diff) { CUDA_1D_KERNEL_LOOP(i, n) { const T w1_i = w1[i]; const T h1_i = h1[i]; const T w2_i = w2[i]; const T h2_i = h2[i]; const T v_diff_i = v_diff[i]; const T angle_delta_i = functor_angle.ComputeDelta(w1_i, h1_i, w2_i, h2_i, eps); const T angle1_i = functor_angle.Compute1(w1_i, h1_i, eps); const T angle2_i = functor_angle.Compute2(w2_i, h2_i, eps); w1_diff[i] = functor_backward.ComputeW1(h1_i, angle_delta_i, angle1_i, eps) * v_diff_i; w2_diff[i] = functor_backward.ComputeW2(h2_i, angle_delta_i, angle2_i, eps) * v_diff_i; h1_diff[i] = functor_backward.ComputeH1(w1_i, h1_i, angle_delta_i, angle1_i, eps) * v_diff_i; h2_diff[i] = functor_backward.ComputeH2(w2_i, h2_i, angle_delta_i, angle2_i, eps) * v_diff_i; } } } // namespace template class FusedGetCiouDiagonalAngleKernel final : public user_op::OpKernel { public: FusedGetCiouDiagonalAngleKernel() = default; ~FusedGetCiouDiagonalAngleKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex("w1", 0); const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex("h1", 0); const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex("w2", 0); const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex("h2", 0); const auto eps = ctx->Attr("eps"); user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex("v", 0); const int64_t elem_cnt = w1->shape_view().elem_cnt(); FusedCiouAngleForwardFunctor fused_get_ciou_diagonal_angle_functor{}; RUN_CUDA_KERNEL((FusedCiouAngleForward), ctx->stream(), elem_cnt, fused_get_ciou_diagonal_angle_functor, elem_cnt, w1->dptr(), h1->dptr(), w2->dptr(), h2->dptr(), eps, v->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_ciou_diagonal_angle") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("v", 0) == GetDataType::value)); 
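// --- Illustrative sketch added for clarity; not part of the original file. -------------------
// The forward kernel registered above evaluates the CIoU aspect-ratio consistency term
//   v = (4 / pi^2) * (atan(w2 / (h2 + eps)) - atan(w1 / (h1 + eps)))^2.
// A minimal scalar CPU reference of the same formula, assuming <cmath>-style atanf is
// available; the helper name CiouDiagonalAngleRef is hypothetical and does not exist in OneFlow:
inline float CiouDiagonalAngleRef(float w1, float h1, float w2, float h2, float eps) {
  constexpr float kPi = 3.14159265358979323846f;
  const float delta = atanf(w2 / (h2 + eps)) - atanf(w1 / (h1 + eps));
  return 4.0f / (kPi * kPi) * delta * delta;
}
// ----------------------------------------------------------------------------------------------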
REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_CUDA_KERNEL(float) REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_CUDA_KERNEL(double) REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_CUDA_KERNEL(half) template class FusedGetCiouDiagonalAngleGradKernel final : public user_op::OpKernel { public: FusedGetCiouDiagonalAngleGradKernel() = default; ~FusedGetCiouDiagonalAngleGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex("w1", 0); const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex("h1", 0); const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex("w2", 0); const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex("h2", 0); const user_op::Tensor* v_diff = ctx->Tensor4ArgNameAndIndex("v_diff", 0); const auto eps = ctx->Attr("eps"); user_op::Tensor* w1_diff = ctx->Tensor4ArgNameAndIndex("w1_diff", 0); user_op::Tensor* h1_diff = ctx->Tensor4ArgNameAndIndex("h1_diff", 0); user_op::Tensor* w2_diff = ctx->Tensor4ArgNameAndIndex("w2_diff", 0); user_op::Tensor* h2_diff = ctx->Tensor4ArgNameAndIndex("h2_diff", 0); const int64_t elem_cnt = w1->shape_view().elem_cnt(); FusedCiouAngleBackwardFunctor fused_get_ciou_diagonal_angle_grad_functor{}; CalcAngleFunctor calc_angle_functor{}; RUN_CUDA_KERNEL((FusedCiouAngleBackward), ctx->stream(), elem_cnt, fused_get_ciou_diagonal_angle_grad_functor, calc_angle_functor, elem_cnt, w1->dptr(), h1->dptr(), w2->dptr(), h2->dptr(), v_diff->dptr(), eps, w1_diff->mut_dptr(), h1_diff->mut_dptr(), w2_diff->mut_dptr(), h2_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_ciou_diagonal_angle_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("w1_diff", 0) == GetDataType::value)); REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_GRAD_CUDA_KERNEL(double) REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_GRAD_CUDA_KERNEL(half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_get_ciou_result_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template __global__ void FusedGetCiouResultForward(const int n, const T* v, const T* iou, const T* rho2, const T* c2, T* y, T* alpha, float eps) { CUDA_1D_KERNEL_LOOP(i, n) { const T v_i = v[i]; const T iou_i = iou[i]; const T alpha_i = v_i / (v_i - iou_i + static_cast(1.0 + eps)); y[i] = iou_i - (rho2[i] / c2[i] + v_i * alpha_i); alpha[i] = alpha_i; } } template __global__ void FusedGetCiouResultBackward(const int n, const T* dy, const T* alpha, const T* rho2, const T* c2, T* dv, T* diou, T* drho2, T* dc2) { CUDA_1D_KERNEL_LOOP(i, n) { const T c2_i = c2[i]; const T dy_i = dy[i]; dv[i] = -alpha[i] * dy_i; diou[i] = dy_i; drho2[i] = -dy_i / c2[i]; dc2[i] = rho2[i] / (c2_i * c2_i) * dy_i; } } }; // namespace template class FusedGetCiouResultGpuKernel final : public user_op::OpKernel { public: FusedGetCiouResultGpuKernel() = default; ~FusedGetCiouResultGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex("v", 0); const user_op::Tensor* iou = ctx->Tensor4ArgNameAndIndex("iou", 0); const user_op::Tensor* rho2 = ctx->Tensor4ArgNameAndIndex("rho2", 0); const user_op::Tensor* c2 = ctx->Tensor4ArgNameAndIndex("c2", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); float eps = ctx->Attr("eps"); const int32_t elem_cnt = v->shape_view().elem_cnt(); RUN_CUDA_KERNEL((FusedGetCiouResultForward), ctx->stream(), elem_cnt, elem_cnt, v->dptr(), iou->dptr(), rho2->dptr(), c2->dptr(), y->mut_dptr(), alpha->mut_dptr(), eps); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_CIOU_RESULT_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_ciou_result") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("v", 0) == GetDataType::value)); REGISTER_FUSED_GET_CIOU_RESULT_CUDA_KERNEL(float) REGISTER_FUSED_GET_CIOU_RESULT_CUDA_KERNEL(half) REGISTER_FUSED_GET_CIOU_RESULT_CUDA_KERNEL(double) template class FusedGetCiouResultGradGpuKernel final : public user_op::OpKernel { public: FusedGetCiouResultGradGpuKernel() = default; ~FusedGetCiouResultGradGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); const user_op::Tensor* rho2 = ctx->Tensor4ArgNameAndIndex("rho2", 0); const user_op::Tensor* c2 = ctx->Tensor4ArgNameAndIndex("c2", 0); user_op::Tensor* dv = ctx->Tensor4ArgNameAndIndex("dv", 0); user_op::Tensor* diou = ctx->Tensor4ArgNameAndIndex("diou", 0); user_op::Tensor* drho2 = ctx->Tensor4ArgNameAndIndex("drho2", 0); user_op::Tensor* dc2 = ctx->Tensor4ArgNameAndIndex("dc2", 0); const int32_t elem_cnt = dy->shape_view().elem_cnt(); RUN_CUDA_KERNEL((FusedGetCiouResultBackward), ctx->stream(), elem_cnt, elem_cnt, dy->dptr(), alpha->dptr(), rho2->dptr(), c2->dptr(), dv->mut_dptr(), diou->mut_dptr(), drho2->mut_dptr(), dc2->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_CIOU_RESULT_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_ciou_result_grad") \ .SetCreateFn>() \ 
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)); REGISTER_FUSED_GET_CIOU_RESULT_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_GET_CIOU_RESULT_GRAD_CUDA_KERNEL(half) REGISTER_FUSED_GET_CIOU_RESULT_GRAD_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_get_convex_diagonal_squared_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/math_util.h" namespace oneflow { namespace { template __global__ void FusedGetConvexDiagonalSquaredForward(const int n, const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1, const T* b1_y2, const T* b2_y1, const T* b2_y2, T* c2, const float eps) { CUDA_1D_KERNEL_LOOP(i, n) { const T cw = DeviceMax(b1_x2[i], b2_x2[i]) - DeviceMin(b1_x1[i], b2_x1[i]); const T ch = DeviceMax(b1_y2[i], b2_y2[i]) - DeviceMin(b1_y1[i], b2_y1[i]); c2[i] = cw * cw + ch * ch + static_cast(eps); } } template __global__ void FusedGetConvexDiagonalSquaredBackward( const int n, const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1, const T* b1_y2, const T* b2_y1, const T* b2_y2, const T* c2_diff, T* b1_x1_diff, T* b1_x2_diff, T* b2_x1_diff, T* b2_x2_diff, T* b1_y1_diff, T* b1_y2_diff, T* b2_y1_diff, T* b2_y2_diff, const float eps) { CUDA_1D_KERNEL_LOOP(i, n) { const T zero = static_cast(0), one = static_cast(1); const T cw = DeviceMax(b1_x2[i], b2_x2[i]) - DeviceMin(b1_x1[i], b2_x1[i]); const T ch = DeviceMax(b1_y2[i], b2_y2[i]) - DeviceMin(b1_y1[i], b2_y1[i]); const T c2_diff_cw = static_cast(2) * cw * c2_diff[i]; const T c2_diff_ch = static_cast(2) * ch * c2_diff[i]; b1_x2_diff[i] = c2_diff_cw * (b1_x2[i] > b2_x2[i] ? one : zero); b2_x2_diff[i] = c2_diff_cw * (b1_x2[i] > b2_x2[i] ? zero : one); b1_x1_diff[i] = -c2_diff_cw * (b1_x1[i] < b2_x1[i] ? one : zero); b2_x1_diff[i] = -c2_diff_cw * (b1_x1[i] < b2_x1[i] ? zero : one); b1_y2_diff[i] = c2_diff_ch * (b1_y2[i] > b2_y2[i] ? one : zero); b2_y2_diff[i] = c2_diff_ch * (b1_y2[i] > b2_y2[i] ? zero : one); b1_y1_diff[i] = -c2_diff_ch * (b1_y1[i] < b2_y1[i] ? one : zero); b2_y1_diff[i] = -c2_diff_ch * (b1_y1[i] < b2_y1[i] ? 
zero : one); } } } // namespace template class FusedGetConvexDiagonalSquaredKernel final : public user_op::OpKernel { public: FusedGetConvexDiagonalSquaredKernel() = default; ~FusedGetConvexDiagonalSquaredKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex("b1_x1", 0); const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex("b1_x2", 0); const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex("b2_x1", 0); const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex("b2_x2", 0); const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex("b1_y1", 0); const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex("b1_y2", 0); const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex("b2_y1", 0); const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex("b2_y2", 0); user_op::Tensor* c2 = ctx->Tensor4ArgNameAndIndex("c2", 0); const float eps = ctx->Attr("eps"); const int64_t elem_cnt = b1_x1->shape_view().elem_cnt(); RUN_CUDA_KERNEL((FusedGetConvexDiagonalSquaredForward), ctx->stream(), elem_cnt, elem_cnt, b1_x1->dptr(), b1_x2->dptr(), b2_x1->dptr(), b2_x2->dptr(), b1_y1->dptr(), b1_y2->dptr(), b2_y1->dptr(), b2_y2->dptr(), c2->mut_dptr(), eps); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_convex_diagonal_squared") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("b1_x1", 0) == GetDataType::value)); REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_CUDA_KERNEL(float) REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_CUDA_KERNEL(double) REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_CUDA_KERNEL(half) template class FusedGetConvexDiagonalSquaredGradKernel final : public user_op::OpKernel { public: FusedGetConvexDiagonalSquaredGradKernel() = default; ~FusedGetConvexDiagonalSquaredGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* c2_diff = ctx->Tensor4ArgNameAndIndex("c2_diff", 0); const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex("b1_x1", 0); const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex("b1_x2", 0); const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex("b2_x1", 0); const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex("b2_x2", 0); const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex("b1_y1", 0); const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex("b1_y2", 0); const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex("b2_y1", 0); const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex("b2_y2", 0); user_op::Tensor* b1_x1_diff = ctx->Tensor4ArgNameAndIndex("b1_x1_diff", 0); user_op::Tensor* b1_x2_diff = ctx->Tensor4ArgNameAndIndex("b1_x2_diff", 0); user_op::Tensor* b2_x1_diff = ctx->Tensor4ArgNameAndIndex("b2_x1_diff", 0); user_op::Tensor* b2_x2_diff = ctx->Tensor4ArgNameAndIndex("b2_x2_diff", 0); user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex("b1_y1_diff", 0); user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex("b1_y2_diff", 0); user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex("b2_y1_diff", 0); user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex("b2_y2_diff", 0); const float eps = ctx->Attr("eps"); const int64_t elem_cnt = b1_x1_diff->shape_view().elem_cnt(); 
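    // Descriptive comment added for clarity (not in the original source): the backward kernel
    // launched below differentiates c2 = cw^2 + ch^2 + eps, where
    //   cw = max(b1_x2, b2_x2) - min(b1_x1, b2_x1) and ch = max(b1_y2, b2_y2) - min(b1_y1, b2_y1),
    // so each input coordinate receives +/- 2 * cw * c2_diff (or +/- 2 * ch * c2_diff) only when
    // it is the coordinate attaining the corresponding max/min, which is what the ternary
    // selections between `one` and `zero` in the kernel implement.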
RUN_CUDA_KERNEL((FusedGetConvexDiagonalSquaredBackward), ctx->stream(), elem_cnt, elem_cnt, b1_x1->dptr(), b1_x2->dptr(), b2_x1->dptr(), b2_x2->dptr(), b1_y1->dptr(), b1_y2->dptr(), b2_y1->dptr(), b2_y2->dptr(), c2_diff->dptr(), b1_x1_diff->mut_dptr(), b1_x2_diff->mut_dptr(), b2_x1_diff->mut_dptr(), b2_x2_diff->mut_dptr(), b1_y1_diff->mut_dptr(), b1_y2_diff->mut_dptr(), b2_y1_diff->mut_dptr(), b2_y2_diff->mut_dptr(), eps); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_convex_diagonal_squared_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("b1_x1", 0) == GetDataType::value)); REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_GRAD_CUDA_KERNEL(double) REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_GRAD_CUDA_KERNEL(half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_get_intersection_area_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template struct MinMaxDeltaFunctor { __device__ T Compute(T b1_x2_i, T b2_x2_i, T b1_x1_i, T b2_x1_i) const { return min(b1_x2_i, b2_x2_i) - max(b1_x1_i, b2_x1_i); } }; template<> struct MinMaxDeltaFunctor { __device__ half Compute(half b1_x2_i, half b2_x2_i, half b1_x1_i, half b2_x1_i) const { const half b_x2_min = b1_x2_i < b2_x2_i ? b1_x2_i : b2_x2_i; const half b_x1_max = b1_x1_i > b2_x1_i ? 
b1_x1_i : b2_x1_i; return b_x2_min - b_x1_max; } }; template __global__ void FusedGetIntersectionAreaBackward(FUNCTOR functor, const int n, const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1, const T* b1_y2, const T* b2_y1, const T* b2_y2, const T* inter_diff, T* b1_x1_diff, T* b1_x2_diff, T* b2_x1_diff, T* b2_x2_diff, T* b1_y1_diff, T* b1_y2_diff, T* b2_y1_diff, T* b2_y2_diff) { CUDA_1D_KERNEL_LOOP(i, n) { const T inter_diff_i = inter_diff[i]; const T b_x_min_max = functor.Compute(b1_x2[i], b2_x2[i], b1_x1[i], b2_x1[i]); const T b_y_min_max = functor.Compute(b1_y2[i], b2_y2[i], b1_y1[i], b2_y1[i]); const T b_x_min_max_inter = b_x_min_max * inter_diff_i; const T b_y_min_max_inter = b_y_min_max * inter_diff_i; b1_x1_diff[i] = static_cast(0.0); b1_x2_diff[i] = static_cast(0.0); b2_x1_diff[i] = static_cast(0.0); b2_x2_diff[i] = static_cast(0.0); b1_y1_diff[i] = static_cast(0.0); b1_y2_diff[i] = static_cast(0.0); b2_y1_diff[i] = static_cast(0.0); b2_y2_diff[i] = static_cast(0.0); if (b_x_min_max > static_cast(0.0) && b_y_min_max > static_cast(0.0)) { if (b1_x1[i] >= b2_x1[i]) { b1_x1_diff[i] = static_cast(-1.0) * b_y_min_max_inter; } if (b1_x1[i] <= b2_x1[i]) { b2_x1_diff[i] = static_cast(-1.0) * b_y_min_max_inter; } if (b1_x2[i] <= b2_x2[i]) { b1_x2_diff[i] = b_y_min_max_inter; } if (b1_x2[i] >= b2_x2[i]) { b2_x2_diff[i] = b_y_min_max_inter; } if (b1_y1[i] >= b2_y1[i]) { b1_y1_diff[i] = static_cast(-1.0) * b_x_min_max_inter; } if (b1_y1[i] <= b2_y1[i]) { b2_y1_diff[i] = static_cast(-1.0) * b_x_min_max_inter; } if (b1_y2[i] <= b2_y2[i]) { b1_y2_diff[i] = b_x_min_max_inter; } if (b1_y2[i] >= b2_y2[i]) { b2_y2_diff[i] = b_x_min_max_inter; } } } } template __global__ void FusedGetIntersectionAreaForward(FUNCTOR functor, const int n, const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1, const T* b1_y2, const T* b2_y1, const T* b2_y2, T* inter) { CUDA_1D_KERNEL_LOOP(i, n) { const T b_x_min_max = functor.Compute(b1_x2[i], b2_x2[i], b1_x1[i], b2_x1[i]); const T b_y_min_max = functor.Compute(b1_y2[i], b2_y2[i], b1_y1[i], b2_y1[i]); inter[i] = static_cast(0.0); if (b_x_min_max > static_cast(0.0) && b_y_min_max > static_cast(0.0)) { inter[i] = b_x_min_max * b_y_min_max; } } } } // namespace template class FusedGetIntersectionAreaKernel final : public user_op::OpKernel { public: FusedGetIntersectionAreaKernel() = default; ~FusedGetIntersectionAreaKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex("b1_x1", 0); const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex("b1_x2", 0); const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex("b2_x1", 0); const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex("b2_x2", 0); const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex("b1_y1", 0); const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex("b1_y2", 0); const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex("b2_y1", 0); const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex("b2_y2", 0); user_op::Tensor* inter = ctx->Tensor4ArgNameAndIndex("inter", 0); const int64_t elem_cnt = b1_x2->shape_view().elem_cnt(); MinMaxDeltaFunctor min_max_delta_functor{}; RUN_CUDA_KERNEL((FusedGetIntersectionAreaForward), ctx->stream(), elem_cnt, min_max_delta_functor, elem_cnt, b1_x1->dptr(), b1_x2->dptr(), b2_x1->dptr(), b2_x2->dptr(), b1_y1->dptr(), b1_y2->dptr(), b2_y1->dptr(), b2_y2->dptr(), inter->mut_dptr()); } 
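  // Clarifying comment added by the editor (not in the original source): Compute() above launches
  // the forward kernel, which evaluates the axis-aligned box intersection
  //   inter = max(min(b1_x2, b2_x2) - max(b1_x1, b2_x1), 0) * max(min(b1_y2, b2_y2) - max(b1_y1, b2_y1), 0),
  // writing zero when the boxes do not overlap; the matching backward kernel only propagates
  // gradients when both the width and height deltas are positive.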
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_INTERSECTION_AREA_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_intersection_area") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("inter", 0) == GetDataType::value)); REGISTER_FUSED_GET_INTERSECTION_AREA_CUDA_KERNEL(float) REGISTER_FUSED_GET_INTERSECTION_AREA_CUDA_KERNEL(double) REGISTER_FUSED_GET_INTERSECTION_AREA_CUDA_KERNEL(half) template class FusedGetIntersectionAreaGradKernel final : public user_op::OpKernel { public: FusedGetIntersectionAreaGradKernel() = default; ~FusedGetIntersectionAreaGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex("b1_x1", 0); const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex("b1_x2", 0); const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex("b2_x1", 0); const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex("b2_x2", 0); const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex("b1_y1", 0); const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex("b1_y2", 0); const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex("b2_y1", 0); const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex("b2_y2", 0); user_op::Tensor* inter_diff = ctx->Tensor4ArgNameAndIndex("inter_diff", 0); user_op::Tensor* b1_x1_diff = ctx->Tensor4ArgNameAndIndex("b1_x1_diff", 0); user_op::Tensor* b1_x2_diff = ctx->Tensor4ArgNameAndIndex("b1_x2_diff", 0); user_op::Tensor* b2_x1_diff = ctx->Tensor4ArgNameAndIndex("b2_x1_diff", 0); user_op::Tensor* b2_x2_diff = ctx->Tensor4ArgNameAndIndex("b2_x2_diff", 0); user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex("b1_y1_diff", 0); user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex("b1_y2_diff", 0); user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex("b2_y1_diff", 0); user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex("b2_y2_diff", 0); const int64_t elem_cnt = b1_x1->shape_view().elem_cnt(); MinMaxDeltaFunctor min_max_delta_functor{}; RUN_CUDA_KERNEL((FusedGetIntersectionAreaBackward), ctx->stream(), elem_cnt, min_max_delta_functor, elem_cnt, b1_x1->dptr(), b1_x2->dptr(), b2_x1->dptr(), b2_x2->dptr(), b1_y1->dptr(), b1_y2->dptr(), b2_y1->dptr(), b2_y2->dptr(), inter_diff->dptr(), b1_x1_diff->mut_dptr(), b1_x2_diff->mut_dptr(), b2_x1_diff->mut_dptr(), b2_x2_diff->mut_dptr(), b1_y1_diff->mut_dptr(), b1_y2_diff->mut_dptr(), b2_y1_diff->mut_dptr(), b2_y2_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_INTERSECTION_AREA_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_intersection_area_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("b1_x1_diff", 0) == GetDataType::value)); REGISTER_FUSED_GET_INTERSECTION_AREA_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_GET_INTERSECTION_AREA_GRAD_CUDA_KERNEL(double) REGISTER_FUSED_GET_INTERSECTION_AREA_GRAD_CUDA_KERNEL(half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_get_iou_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"

namespace oneflow {

namespace {

template<typename T>
__global__ void FusedGetIouForward(const int n, const T* w1, const T* h1, const T* w2, const T* h2,
                                   const T* inter, T* iou, const float eps) {
  CUDA_1D_KERNEL_LOOP(i, n) {
    const T inter_i = inter[i];
    iou[i] = inter_i / (w1[i] * h1[i] + w2[i] * h2[i] - inter_i + static_cast<T>(eps));
  }
}

template<typename T>
__global__ void FusedGetIouBackward(const int n, const T* diou, const T* w1, const T* h1,
                                    const T* w2, const T* h2, const T* inter, T* dw1, T* dh1,
                                    T* dinter, const float eps) {
  CUDA_1D_KERNEL_LOOP(i, n) {
    const T w1_i = w1[i], h1_i = h1[i], w2_i = w2[i], h2_i = h2[i], inter_i = inter[i],
            diou_i = diou[i];
    const T w_h_eps = w1_i * h1_i + w2_i * h2_i + static_cast<T>(eps);
    const T w_h_eps_inter_diff = w_h_eps - inter_i;
    const T w_h_eps_inter_diff_square = w_h_eps_inter_diff * w_h_eps_inter_diff;
    const T common_for_dwh = -inter_i * diou_i / w_h_eps_inter_diff_square;
    dinter[i] = w_h_eps * diou_i / w_h_eps_inter_diff_square;
    dw1[i] = h1_i * common_for_dwh;
    dh1[i] = w1_i * common_for_dwh;
  }
}

}  // namespace

template<typename T>
class FusedGetIouGpuKernel final : public user_op::OpKernel {
 public:
  FusedGetIouGpuKernel() = default;
  ~FusedGetIouGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex("w1", 0);
    const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex("h1", 0);
    const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex("w2", 0);
    const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex("h2", 0);
    const user_op::Tensor* inter = ctx->Tensor4ArgNameAndIndex("inter", 0);
    user_op::Tensor* iou = ctx->Tensor4ArgNameAndIndex("iou", 0);
    float eps = ctx->Attr<float>("eps");
    const int32_t elem_cnt = w1->shape_view().elem_cnt();
    RUN_CUDA_KERNEL((FusedGetIouForward<T>), ctx->stream(), elem_cnt, elem_cnt, w1->dptr<T>(),
                    h1->dptr<T>(), w2->dptr<T>(), h2->dptr<T>(), inter->dptr<T>(),
                    iou->mut_dptr<T>(), eps);
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_FUSED_GET_IOU_CUDA_KERNEL(dtype)                         \
  REGISTER_USER_KERNEL("fused_get_iou")                                   \
      .SetCreateFn<FusedGetIouGpuKernel<dtype>>()                         \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)    \
                       && (user_op::HobDataType("iou", 0) == GetDataType<dtype>::value));

REGISTER_FUSED_GET_IOU_CUDA_KERNEL(float)
REGISTER_FUSED_GET_IOU_CUDA_KERNEL(half)
REGISTER_FUSED_GET_IOU_CUDA_KERNEL(double)

template<typename T>
class FusedGetIouGradGpuKernel final : public user_op::OpKernel {
 public:
  FusedGetIouGradGpuKernel() = default;
  ~FusedGetIouGradGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* diou = ctx->Tensor4ArgNameAndIndex("diou", 0);
    const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex("w1", 0);
    const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex("h1", 0);
    const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex("w2", 0);
    const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex("h2", 0);
    const user_op::Tensor* inter = ctx->Tensor4ArgNameAndIndex("inter", 0);
user_op::Tensor* dw1 = ctx->Tensor4ArgNameAndIndex("dw1", 0); user_op::Tensor* dh1 = ctx->Tensor4ArgNameAndIndex("dh1", 0); user_op::Tensor* dinter = ctx->Tensor4ArgNameAndIndex("dinter", 0); float eps = ctx->Attr("eps"); const int32_t elem_cnt = diou->shape_view().elem_cnt(); RUN_CUDA_KERNEL((FusedGetIouBackward), ctx->stream(), elem_cnt, elem_cnt, diou->dptr(), w1->dptr(), h1->dptr(), w2->dptr(), h2->dptr(), inter->dptr(), dw1->mut_dptr(), dh1->mut_dptr(), dinter->mut_dptr(), eps); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GET_IOU_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_get_iou_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("diou", 0) == GetDataType::value)); REGISTER_FUSED_GET_IOU_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_GET_IOU_GRAD_CUDA_KERNEL(half) REGISTER_FUSED_GET_IOU_GRAD_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_glu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/ep/include/primitive/unary_op.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/ep/common/primitive/unary_functor.h" #include "oneflow/core/ep/cuda/primitive/unary_functor.cuh" #include "oneflow/core/kernel/util/cuda_half_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #include "oneflow/core/device/cuda_pseudo_bfloat16.h" #if CUDA_VERSION >= 11020 #ifdef WITH_CUTLASS #include "device/dual_gemm.h" #include "thread/left_silu_and_mul.h" namespace cutlass { namespace epilogue { namespace thread { template typename Activation, typename ElementAccumulator_ = ElementOutput_, typename ElementCompute_ = ElementOutput_, FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> class RightActivationAndMul { public: using ElementOutput = ElementOutput_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; static int const kCount = Count; using FragmentOutput = Array; using FragmentAccumulator = Array; using ComputeFragment = Array; static FloatRoundStyle const kRound = Round; struct Params {}; private: ElementCompute alpha_; ElementCompute beta_; public: CUTLASS_HOST_DEVICE RightActivationAndMul(Params const& /*params*/) {} CUTLASS_HOST_DEVICE bool is_source_needed() const { return true; } CUTLASS_HOST_DEVICE void set_k_partition(int k_partition, int k_partition_count) { assert(false); } CUTLASS_HOST_DEVICE FragmentOutput 
operator()(FragmentAccumulator const& lhs, FragmentAccumulator const& rhs) const { NumericArrayConverter accumulator_to_output; FragmentOutput converted_lhs = accumulator_to_output(lhs); FragmentOutput converted_rhs = accumulator_to_output(rhs); Activation act; cutlass::multiplies mul; auto act_rhs = act(converted_rhs); return mul(act_rhs, converted_lhs); } CUTLASS_HOST_DEVICE ElementOutput operator()(ElementAccumulator const& lhs, ElementAccumulator const& rhs) const { ElementOutput convert_lhs(lhs); ElementOutput convert_rhs(rhs); Activation act; cutlass::multiplies mul; auto act_rhs = act(convert_rhs); return mul(act_rhs, convert_lhs); } }; } // namespace thread } // namespace epilogue } // namespace cutlass #endif // WITH_CUTLASS namespace oneflow { namespace { #ifdef WITH_CUTLASS template struct GetCutlassType { using type = T; }; template<> struct GetCutlassType { using type = cutlass::half_t; }; #if CUDA_VERSION >= 11000 template<> struct GetCutlassType { using type = cutlass::bfloat16_t; }; #endif template typename Activation> void DualGemmGegluHalf(ep::CudaStream* stream, int32_t m, int32_t n, int32_t k, const void* x, const void* w, const void* v, const void* b, const void* c, void* wx, int32_t wx_stride, void* vx, int32_t vx_stride, void* y) { constexpr int kStages = 5; constexpr bool kSplitKSerial = false; constexpr bool kUseBias = true; using ElementOperandA = cutlass::half_t; using ElementOperandB = cutlass::half_t; using ElementOutput = cutlass::half_t; using ElementAccumulator = Acc; using ElementCompute = Acc; using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; constexpr auto kScaleType = kUseBias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling : ( // No bias kSplitKSerial ? cutlass::epilogue::thread::ScaleType::Default : cutlass::epilogue::thread::ScaleType::Nothing); using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination::value, ElementAccumulator, ElementCompute, kScaleType>; using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination::value, ElementAccumulator, ElementCompute, kScaleType>; using EpilogueOutputOp2 = cutlass::epilogue::thread::RightActivationAndMul< ElementOutput, 128 / cutlass::sizeof_bits::value, Activation, ElementOutput, ElementCompute>; const ElementCompute alpha0 = ElementCompute(1); const ElementCompute beta0 = ElementCompute(kUseBias ? 1 : 0); const ElementCompute alpha1 = ElementCompute(1); const ElementCompute beta1 = ElementCompute(kUseBias ? 1 : 0); // Optionally, we might not need intermediate GEMM outputs constexpr bool kStoreD0 = true; constexpr bool kStoreD1 = true; using DualGemm = cutlass::gemm::device::DualGemm< ElementOperandA, cutlass::layout::RowMajor, ElementOperandB, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, EpilogueOutputOp2, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, kStages, kStoreD0, kStoreD1, kSplitKSerial>; int split_k_slices = DualGemm::kSplitKSerial ? 
2 : 1; typename cutlass::TensorRef tensor_a0( reinterpret_cast(x), k); typename cutlass::TensorRef tensor_b0( reinterpret_cast(w), k); typename cutlass::TensorRef tensor_b1( reinterpret_cast(v), k); typename cutlass::TensorRef tensor_bias0( reinterpret_cast(b), {0}); typename cutlass::TensorRef tensor_bias1( reinterpret_cast(c), {0}); typename cutlass::TensorRef tensor_d0( reinterpret_cast(wx), wx_stride); typename cutlass::TensorRef tensor_d1( reinterpret_cast(vx), vx_stride); typename cutlass::TensorRef tensor_out( reinterpret_cast(y), n); cutlass::gemm::GemmCoord problem_size(m, n, k); typename DualGemm::Arguments arguments{ problem_size, tensor_a0, tensor_b0, tensor_bias0, tensor_d0, tensor_b1, tensor_bias1, tensor_d1, tensor_out, {alpha0, beta0}, {alpha1, beta1}, {}, split_k_slices}; DualGemm dual_gemm_op; dual_gemm_op.initialize(arguments, stream->cublas_workspace(), stream->cuda_stream()); dual_gemm_op(stream->cuda_stream()); } template bool TryDispatchDualGemmImplActivation(ep::CudaStream* stream, const std::string& activation, int32_t m, int32_t n, int32_t k, const void* x, const void* w, const void* v, const void* b, const void* c, void* wx, int32_t wx_stride, void* vx, int32_t vx_stride, void* y) { if (activation == "fast_gelu") { DualGemmGegluHalf( stream, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y); return true; } else if (activation == "gelu") { DualGemmGegluHalf(stream, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y); return true; } else { return false; } } template bool TryDispatchDualGemmImplAccType(ep::CudaStream* stream, const std::string& activation, int32_t m, int32_t n, int32_t k, const T* x, const T* w, const T* v, const T* b, const T* c, T* wx, int32_t wx_stride, T* vx, int32_t vx_stride, T* y) { const bool allow_half_precision = ParseBooleanFromEnv("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", false); if (std::is_same::value) { if (allow_half_precision) { return TryDispatchDualGemmImplActivation( stream, activation, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y); } else { return TryDispatchDualGemmImplActivation(stream, activation, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y); } } else { return false; } } template bool TryDispatchDualGemmImplAlignment(ep::CudaStream* stream, const std::string& activation, int32_t m, int32_t n, int32_t k, const T* x, const T* w, const T* v, const T* b, const T* c, T* wx, int32_t wx_stride, T* vx, int32_t vx_stride, T* y) { if (m % 8 == 0 && n % 8 == 0 && k % 8 == 0 && reinterpret_cast(x) % (8 * sizeof(T)) == 0 && reinterpret_cast(w) % (8 * sizeof(T)) == 0 && reinterpret_cast(v) % (8 * sizeof(T)) == 0 && reinterpret_cast(b) % (8 * sizeof(T)) == 0 && reinterpret_cast(c) % (8 * sizeof(T)) == 0 && reinterpret_cast(wx) % (8 * sizeof(T)) == 0 && wx_stride % 8 == 0 && reinterpret_cast(vx) % (8 * sizeof(T)) == 0 && reinterpret_cast(y) % (8 * sizeof(T)) == 0 && vx_stride % 8 == 0) { return TryDispatchDualGemmImplAccType(stream, activation, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y); } else { return false; } } template bool TryDispatchDualGemmImplArchTag(ep::CudaStream* stream, const std::string& activation, int32_t m, int32_t n, int32_t k, const T* x, const T* w, const T* v, const T* b, const T* c, T* wx, int32_t wx_stride, T* vx, int32_t vx_stride, T* y) { const int arch = stream->cuda_arch(); if (arch == 800) { return TryDispatchDualGemmImplAlignment( stream, activation, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y); } else { return false; } } #endif // WITH_CUTLASS 
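// ---------------------------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the original file): the CUTLASS DualGemm path
// above fuses the two projections of a GLU/GEGLU layer, y = (x * W^T + b) * act(x * V^T + c),
// where the RightActivationAndMul epilogue computes act(rhs) * lhs, and dispatch is layered as
// ArchTag (sm_80) -> Alignment (8 elements) -> AccType -> Activation. As a reference for what the
// fused path produces, here is the unfused per-element math, assuming row-major x of shape (m, k)
// and W/V of shape (n, k); the helper name is made up for this note:
template<typename T, typename Act>
void ReferenceGluSketch(const T* x, const T* w, const T* v, const T* b, const T* c, int64_t m,
                        int64_t n, int64_t k, Act act, T* y) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      T wx = b ? b[j] : static_cast<T>(0);  // hidden-state projection: x * W^T + b
      T vx = c ? c[j] : static_cast<T>(0);  // gate projection: x * V^T + c
      for (int64_t p = 0; p < k; ++p) {
        wx += x[i * k + p] * w[j * k + p];
        vx += x[i * k + p] * v[j * k + p];
      }
      // matches FusedGluForwardGpu below: y = hidden_state * act(gate)
      y[i * n + j] = wx * act(vx);
    }
  }
}
// ---------------------------------------------------------------------------------------------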
template bool TryDispatchDualGemmImpl(ep::CudaStream* stream, const std::string& activation, int32_t m, int32_t n, int32_t k, const T* x, const T* w, const T* v, const T* b, const T* c, T* wx, int32_t wx_stride, T* vx, int32_t vx_stride, T* y) { #ifdef WITH_CUTLASS const bool enabled = ParseBooleanFromEnv("ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL", true); if (enabled) { return TryDispatchDualGemmImplArchTag(stream, activation, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y); } else { return false; } #else return false; #endif // WITH_CUTLASS } template __global__ void FusedGluForwardGpu( const IndexType m, const IndexType packed_n, const IndexType packed_num, const IndexType packed_stride, ep::primitive::UnaryFunctor act, T* matmul_wx, T* matmul_vx, T* y) { // obtain global thread index IndexType global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; // define type of Pack using LoadPack = cuda::elementwise::Packed; // workload of current thread for (IndexType packed_index = global_thread_id, step = gridDim.x * blockDim.x; packed_index < packed_num; packed_index += step) { // obtain the row and col index in output tensor "y" const IndexType y_packed_row = packed_index / packed_n; const IndexType y_packed_col = packed_index - y_packed_row * packed_n; // cast type to load type const LoadPack* matmul_wx_load = reinterpret_cast(matmul_wx) + (y_packed_row * packed_stride + y_packed_col); const LoadPack* matmul_vx_load = reinterpret_cast(matmul_vx) + (y_packed_row * packed_stride + y_packed_col); // init vectors LoadPack matmul_wx_vec = *matmul_wx_load; LoadPack matmul_vx_vec = *matmul_vx_load; LoadPack y_vec; #pragma unroll for (int i = 0; i < pack_size; i++) { // obtain the hidden_state and gate T hidden_state = matmul_wx_vec.elem[i]; T gate = matmul_vx_vec.elem[i]; // calculate activation T act_gate = act(gate); // calculate element-wise product y_vec.elem[i] = hidden_state * act_gate; } *(reinterpret_cast(y + packed_index * pack_size)) = y_vec; } } template void LaunchFusedGluForwardGpu(ep::Stream* stream, const IndexType m, const IndexType packed_n, const IndexType pack_num, const IndexType packed_stride, T* matmul_wx, T* matmul_vx, T* y) { constexpr int32_t block_size = 128; unsigned int grid_size = (pack_num + block_size - 1) / block_size; ep::primitive::UnaryFunctor act(0, 0); FusedGluForwardGpu <<As()->cuda_stream()>>>( m, packed_n, pack_num, packed_stride, act, matmul_wx, matmul_vx, y); } template void DispatchIndexType(ep::Stream* stream, const int64_t m, const int64_t packed_n, const int64_t pack_num, const int64_t packed_stride, T* matmul_wx, T* matmul_vx, T* y) { // dispatch index type if (pack_num < (1 << 30)) { LaunchFusedGluForwardGpu( stream, m, packed_n, pack_num, packed_stride, matmul_wx, matmul_vx, y); } else { LaunchFusedGluForwardGpu( stream, m, packed_n, pack_num, packed_stride, matmul_wx, matmul_vx, y); } } template::type = 0> void DispatchPackSize(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride, T* matmul_wx, T* matmul_vx, T* y) { DispatchIndexType(stream, m, n, m * n, stride, matmul_wx, matmul_vx, y); } template::type = 0> void DispatchPackSize(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride, T* matmul_wx, T* matmul_vx, T* y) { const int64_t pack_size = alignment / sizeof(T); const int64_t packed_n = n / pack_size; const int64_t pack_num = m * packed_n; const int64_t packed_stride = stride / pack_size; DispatchIndexType(stream, m, packed_n, pack_num, packed_stride, matmul_wx, matmul_vx, y); } template 
void DispatchAlignment(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride, T* matmul_wx, T* matmul_vx, T* y) { const auto IsAligned = [&](const size_t alignment) { const uintptr_t matmul_wx_ptr = reinterpret_cast(matmul_wx); const uintptr_t matmul_vx_ptr = reinterpret_cast(matmul_vx); const uintptr_t y_ptr = reinterpret_cast(y); return (/* memory address alignment */ matmul_wx_ptr % alignment == 0 && matmul_vx_ptr % alignment == 0 && y_ptr % alignment == 0 /* #element per row alignment */ && n % (alignment / sizeof(T)) == 0); }; if (IsAligned(16)) { DispatchPackSize(stream, m, n, stride, matmul_wx, matmul_vx, y); } else if (IsAligned(8)) { DispatchPackSize(stream, m, n, stride, matmul_wx, matmul_vx, y); } else if (IsAligned(4)) { DispatchPackSize(stream, m, n, stride, matmul_wx, matmul_vx, y); } else if (IsAligned(2)) { DispatchPackSize(stream, m, n, stride, matmul_wx, matmul_vx, y); } else { DispatchPackSize(stream, m, n, stride, matmul_wx, matmul_vx, y); } } template void DispatchActivationType(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride, T* matmul_wx, T* matmul_vx, T* y, const std::string& activation) { if (activation == "none") { DispatchAlignment(stream, m, n, stride, matmul_wx, matmul_vx, y); } else if (activation == "sigmoid") { DispatchAlignment(stream, m, n, stride, matmul_wx, matmul_vx, y); } else if (activation == "relu") { DispatchAlignment(stream, m, n, stride, matmul_wx, matmul_vx, y); } else if (activation == "gelu") { DispatchAlignment(stream, m, n, stride, matmul_wx, matmul_vx, y); } else if (activation == "fast_gelu") { DispatchAlignment(stream, m, n, stride, matmul_wx, matmul_vx, y); } else if (activation == "silu") { DispatchAlignment(stream, m, n, stride, matmul_wx, matmul_vx, y); } else { UNIMPLEMENTED(); } } template class GpuFusedGluKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: GpuFusedGluKernel() = default; ~GpuFusedGluKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCublasFusedMLPKernelCache(); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { // obtain tensors from context const user_op::Tensor* input_tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* input_tensor_w = ctx->Tensor4ArgNameAndIndex("w", 0); user_op::Tensor* out_tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* out_tensor_matmul_wx = ctx->Tensor4ArgNameAndIndex("matmul_wx", 0); // obtain optional tensors from context bool is_split_mode = false; user_op::Tensor* input_tensor_b = nullptr; user_op::Tensor* input_tensor_v = nullptr; user_op::Tensor* input_tensor_c = nullptr; user_op::Tensor* out_tensor_matmul_vx = nullptr; auto* cuda_stream = ctx->stream()->As(); const auto* fused_glu_cache = CHECK_NOTNULL(dynamic_cast(cache)); // check whether the user provide weight tensor v if (ctx->has_input("v", 0)) { input_tensor_v = ctx->Tensor4ArgNameAndIndex("v", 0); out_tensor_matmul_vx = ctx->Tensor4ArgNameAndIndex("matmul_vx", 0); is_split_mode = true; } bool has_b = ctx->has_input("b", 0); bool has_c = ctx->has_input("c", 0); // check whether the user provide bais tensors CHECK(!(has_b && (is_split_mode && !has_c))) << "expected existance of c, when provide tensors w, v and b"; bool has_bias = false; if (has_b && (is_split_mode && has_c)) { input_tensor_b = 
ctx->Tensor4ArgNameAndIndex("b", 0); input_tensor_c = ctx->Tensor4ArgNameAndIndex("c", 0); has_bias = true; } else if (has_b && (!is_split_mode)) { input_tensor_b = ctx->Tensor4ArgNameAndIndex("b", 0); has_bias = true; } else { has_bias = false; } cublasLtEpilogue_t epilogue; if (has_bias) { epilogue = CUBLASLT_EPILOGUE_BIAS; } else { epilogue = CUBLASLT_EPILOGUE_DEFAULT; } // obtain tensor shapes const ShapeView& x_shape = input_tensor_x->shape_view(); const ShapeView& w_shape = input_tensor_w->shape_view(); ShapeView b_shape; if (has_bias) { Shape _b_shape; input_tensor_b->shape_view().ToShape(&_b_shape); b_shape = ShapeView(_b_shape); } const ShapeView& y_shape = out_tensor_y->shape_view(); // validate dimension and number of axes CHECK_GT(x_shape.NumAxes(), 1) << "number of axes of \'x\' should have be greater than 1, yet get " << x_shape.NumAxes(); CHECK_EQ(w_shape.NumAxes(), 2) << "number of axes of \'w\' should have be equal to 2, yet get " << w_shape.NumAxes(); if (has_bias) { CHECK_EQ(b_shape.NumAxes(), 1) << "number of axes of \'b\' should have be equal to 1, yet get " << b_shape.NumAxes(); } // check input tensor shapes size_t x_num_axes = x_shape.NumAxes(); CHECK_EQ(w_shape.At(1), x_shape.At(x_num_axes - 1)) << "dimension 1 of \'w\'(" << w_shape.At(1) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_num_axes - 1) << ")"; if (has_bias) { CHECK_EQ(b_shape.At(0), w_shape.At(0)) << "dimension 0 of \'b\'(" << b_shape.At(0) << ") is not consistant with dimension 0 of \'w\'(" << w_shape.At(0) << ")"; } if (!is_split_mode) { CHECK_EQ(w_shape.At(1) % 2, 0) << "dimension 1 of \'w\' is not divisible by 2"; } // check optional input tensor shapes if (is_split_mode) { const ShapeView& v_shape = input_tensor_v->shape_view(); CHECK_EQ(v_shape.NumAxes(), 2) << "number of axes of \'v\' should have be equal to 2, yet get " << v_shape.NumAxes(); CHECK_EQ(v_shape, w_shape) << "the shape of \'v\' is not consistant with \'w\'"; if (has_bias) { const ShapeView& c_shape = input_tensor_c->shape_view(); CHECK_EQ(c_shape.NumAxes(), 1) << "number of axes of \'c\' should have be equal to 1, yet get " << c_shape.NumAxes(); CHECK_EQ(c_shape, b_shape) << "the shape of \'c\' is not consistant with \'b\'"; } } // obtain data type for cublaslt computation const DataType data_type = out_tensor_matmul_wx->data_type(); const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); // infer m, n, k const int64_t m = x_shape.Count(0, x_num_axes - 1); const int64_t n = y_shape.At(x_num_axes - 1); const int64_t k = x_shape.At(x_num_axes - 1); if (has_bias) { if (TryDispatchDualGemmImpl( ctx->stream()->As(), ctx->Attr("activation"), m, n, k, input_tensor_x->dptr(), input_tensor_w->dptr(), is_split_mode ? input_tensor_v->dptr() : input_tensor_w->dptr() + n * k, input_tensor_b->dptr(), is_split_mode ? input_tensor_c->dptr() : input_tensor_b->dptr() + n, out_tensor_matmul_wx->mut_dptr(), is_split_mode ? n : 2 * n, is_split_mode ? out_tensor_matmul_vx->mut_dptr() : out_tensor_matmul_wx->mut_dptr() + n, is_split_mode ? 
n : 2 * n, out_tensor_y->mut_dptr())) { return; } } // init scalar parameters for cublaslt const double alpha = 1.0; const double beta = 0.0; const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); // calculate matmul_wx (and matmul_vx) through cublaslt if (is_split_mode) { // define shape parameters to be inferred size_t cublas_wx_m = 0, cublas_wx_n = 0, cublas_wx_k = 0; int64_t cublas_wx_lda = 0, cublas_wx_ldb = 0, cublas_wx_ldc = 0; size_t cublas_vx_m = 0, cublas_vx_n = 0, cublas_vx_k = 0; int64_t cublas_vx_lda = 0, cublas_vx_ldb = 0, cublas_vx_ldc = 0; // init dim vector DimVector x_dim_vec({m, k}); DimVector w_dim_vec({n, k}); DimVector v_dim_vec({n, k}); // setup cublaslt matmul attributes InferMatmulCublasMNK(x_dim_vec, w_dim_vec, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_wx_m, &cublas_wx_n, &cublas_wx_k, &cublas_wx_lda, &cublas_wx_ldb, &cublas_wx_ldc); SetCublasAttr(fused_glu_cache, cublas_compute_dtype, cuda_data_type, false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, has_bias ? input_tensor_b->dptr() : nullptr, nullptr, cublas_wx_m, cublas_wx_n, cublas_wx_k, cublas_wx_lda, cublas_wx_ldb, cublas_wx_ldc); // setup algorithms cublasLtMatmulPreference_t preference = nullptr; size_t workspace_size = cuda_stream->cublas_workspace_size(); OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference)); OF_CUBLAS_CHECK( cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); int wx_returned_result = 0; cublasLtMatmulHeuristicResult_t wx_heuristic_result; OF_CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( cuda_stream->cublas_lt_handle(), fused_glu_cache->operation_desc, fused_glu_cache->cublas_a_desc, fused_glu_cache->cublas_b_desc, fused_glu_cache->cublas_c_desc, fused_glu_cache->cublas_c_desc, preference, 1, &wx_heuristic_result, &wx_returned_result)); CHECK_EQ(wx_returned_result, 1); // launch cublaslt matmul // out_tensor_matmul_wx = 1.0 * (input_tensor_w * input_tensor_x) + 1.0 * input_tensor_b OF_CUBLAS_CHECK(cublasLtMatmul( /*lightHandle*/ cuda_stream->cublas_lt_handle(), /*computeDesc*/ fused_glu_cache->operation_desc, /*alpha*/ &sp_alpha, /*A*/ input_tensor_w->dptr(), /*Adesc*/ fused_glu_cache->cublas_a_desc, /*B*/ input_tensor_x->dptr(), /*Bdesc*/ fused_glu_cache->cublas_b_desc, /*beta*/ &sp_beta, /*C*/ has_bias ? input_tensor_b->dptr() : nullptr, /*Cdesc*/ fused_glu_cache->cublas_c_desc, /*D*/ out_tensor_matmul_wx->mut_dptr(), /*Ddesc*/ fused_glu_cache->cublas_c_desc, /*algo*/ &wx_heuristic_result.algo, /*workspace*/ cuda_stream->cublas_workspace(), /*workspaceSizeInBytes*/ cuda_stream->cublas_workspace_size(), /*stream*/ cuda_stream->cuda_stream())); // setup cublaslt attributes InferMatmulCublasMNK(x_dim_vec, v_dim_vec, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_vx_m, &cublas_vx_n, &cublas_vx_k, &cublas_vx_lda, &cublas_vx_ldb, &cublas_vx_ldc); SetCublasAttr(fused_glu_cache, cublas_compute_dtype, cuda_data_type, false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, has_bias ? 
input_tensor_c->dptr() : nullptr, nullptr, cublas_vx_m, cublas_vx_n, cublas_vx_k, cublas_vx_lda, cublas_vx_ldb, cublas_vx_ldc); // setup algorithm int vx_returned_result = 0; cublasLtMatmulHeuristicResult_t vx_heuristic_result; OF_CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( cuda_stream->cublas_lt_handle(), fused_glu_cache->operation_desc, fused_glu_cache->cublas_a_desc, fused_glu_cache->cublas_b_desc, fused_glu_cache->cublas_c_desc, fused_glu_cache->cublas_c_desc, preference, 1, &vx_heuristic_result, &vx_returned_result)); CHECK_EQ(vx_returned_result, 1); cublasLtMatmulPreferenceDestroy(preference); // launch cublaslt matmul // out_tensor_matmul_vx = 1.0 * (input_tensor_v * input_tensor_x) + 1.0 * input_tensor_c OF_CUBLAS_CHECK(cublasLtMatmul( /*lightHandle*/ cuda_stream->cublas_lt_handle(), /*computeDesc*/ fused_glu_cache->operation_desc, /*alpha*/ &sp_alpha, /*A*/ input_tensor_v->dptr(), /*Adesc*/ fused_glu_cache->cublas_a_desc, /*B*/ input_tensor_x->dptr(), /*Bdesc*/ fused_glu_cache->cublas_b_desc, /*beta*/ &sp_beta, /*C*/ has_bias ? input_tensor_c->dptr() : nullptr, /*Cdesc*/ fused_glu_cache->cublas_c_desc, /*D*/ out_tensor_matmul_vx->mut_dptr(), /*Ddesc*/ fused_glu_cache->cublas_c_desc, /*algo*/ &wx_heuristic_result.algo, /*workspace*/ cuda_stream->cublas_workspace(), /*workspaceSizeInBytes*/ cuda_stream->cublas_workspace_size(), /*stream*/ cuda_stream->cuda_stream())); } else { // define shape parameters to be inferred size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; // init dim vector DimVector x_dim_vec({m, k}); DimVector w_dim_vec({2 * n, k}); // setup cublas attributes InferMatmulCublasMNK(x_dim_vec, w_dim_vec, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); SetCublasAttr(fused_glu_cache, cublas_compute_dtype, cuda_data_type, false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, has_bias ? input_tensor_b->dptr() : nullptr, nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); // setup algorithm cublasLtMatmulPreference_t preference = nullptr; size_t workspace_size = cuda_stream->cublas_workspace_size(); OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference)); OF_CUBLAS_CHECK( cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); int wx_returned_result = 0; cublasLtMatmulHeuristicResult_t wx_heuristic_result; OF_CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( cuda_stream->cublas_lt_handle(), fused_glu_cache->operation_desc, fused_glu_cache->cublas_a_desc, fused_glu_cache->cublas_b_desc, fused_glu_cache->cublas_c_desc, fused_glu_cache->cublas_c_desc, preference, 1, &wx_heuristic_result, &wx_returned_result)); CHECK_EQ(wx_returned_result, 1); cublasLtMatmulPreferenceDestroy(preference); // launch cublaslt matmul // out_tensor_matmul_wx = 1.0 * (input_tensor_w * input_tensor_x) + 1.0 * input_tensor_b OF_CUBLAS_CHECK(cublasLtMatmul( /*lightHandle*/ cuda_stream->cublas_lt_handle(), /*computeDesc*/ fused_glu_cache->operation_desc, /*alpha*/ &sp_alpha, /*A*/ input_tensor_w->dptr(), /*Adesc*/ fused_glu_cache->cublas_a_desc, /*B*/ input_tensor_x->dptr(), /*Bdesc*/ fused_glu_cache->cublas_b_desc, /*beta*/ &sp_beta, /*C*/ has_bias ? 
                                input_tensor_b->dptr() : nullptr,
          /*Cdesc*/ fused_glu_cache->cublas_c_desc,
          /*D*/ out_tensor_matmul_wx->mut_dptr(),
          /*Ddesc*/ fused_glu_cache->cublas_c_desc,
          /*algo*/ nullptr,
          /*workspace*/ cuda_stream->cublas_workspace(),
          /*workspaceSizeInBytes*/ cuda_stream->cublas_workspace_size(),
          /*stream*/ cuda_stream->cuda_stream()));
    }

    // dispatch according to activation type
    DispatchActivationType<T>(ctx->stream(),
                              /*m, n=*/m, n,
                              /*stride=*/is_split_mode ? n : 2 * n,
                              /*matmul_wx=*/out_tensor_matmul_wx->mut_dptr<T>(),
                              /*matmul_vx=*/
                              is_split_mode ? out_tensor_matmul_vx->mut_dptr<T>()
                                            : out_tensor_matmul_wx->mut_dptr<T>() + n,
                              /*y=*/out_tensor_y->mut_dptr<T>(),
                              /*activation=*/ctx->Attr<std::string>("activation"));
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

}  // namespace

#define REGISTER_GPU_FUSED_GLU_KERNEL(dtype)                              \
  REGISTER_USER_KERNEL("fused_glu")                                       \
      .SetCreateFn<GpuFusedGluKernel<dtype>>()                            \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)    \
                       && (user_op::HobDataType("y", 0) == GetDataType<dtype>::value));

REGISTER_GPU_FUSED_GLU_KERNEL(double)
REGISTER_GPU_FUSED_GLU_KERNEL(float)
REGISTER_GPU_FUSED_GLU_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_GPU_FUSED_GLU_KERNEL(nv_bfloat16)
#endif

}  // namespace oneflow

#endif  // CUDA_VERSION >= 11020


================================================
FILE: oneflow/user/kernels/fused_glu_without_linear_grad_kernel.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/device_type.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/ep/include/primitive/binary_op.h" #include "oneflow/core/ep/common/primitive/binary_functor.h" #include "oneflow/core/ep/cuda/primitive/binary_functor.cuh" #include "oneflow/core/ep/include/primitive/unary_op.h" #include "oneflow/core/ep/common/primitive/unary_functor.h" #include "oneflow/core/ep/cuda/primitive/unary_functor.cuh" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #include "oneflow/core/device/cuda_pseudo_bfloat16.h" namespace oneflow { namespace { // declear using "BinaryFunctor" from namespace "ep::primitive::broadcast_elementwise_binary" template using BinaryFunctor = ep::primitive::broadcast_elementwise_binary::BinaryFunctor; template __global__ void FusedGluWithoutLinearGradGpu( const IndexType m, const IndexType packed_n, const IndexType pack_num, const IndexType packed_stride, BinaryFunctor dact, ep::primitive::UnaryFunctor act, const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) { // define type of Pack using LoadPack = cuda::elementwise::Packed; // obtain global thread index IndexType global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; // workload of current thread for (IndexType packed_index = global_thread_id, step = gridDim.x * blockDim.x; packed_index < pack_num; packed_index += step) { // obtain the row and col index in output tensor "d_matmul_wx" and "d_matmul_vx" const IndexType packed_row = packed_index / packed_n; const IndexType packed_col = packed_index - packed_row * packed_n; // cast type to load type const LoadPack* dy_load = reinterpret_cast(dy) + (packed_row * packed_n + packed_col); const LoadPack* matmul_wx_load = reinterpret_cast(matmul_wx) + (packed_row * packed_stride + packed_col); const LoadPack* matmul_vx_load = reinterpret_cast(matmul_vx) + (packed_row * packed_stride + packed_col); // init vectors LoadPack dy_vec = *dy_load; LoadPack matmul_wx_vec = *matmul_wx_load; LoadPack matmul_vx_vec = *matmul_vx_load; LoadPack d_matmul_wx_vec; LoadPack d_matmul_vx_vec; #pragma unroll for (int i = 0; i < pack_size; i++) { // calculate the gradient of activated gate T d_act_gate = matmul_wx_vec.elem[i] * dy_vec.elem[i]; // calculate the gradient of hidden_state T gate = matmul_vx_vec.elem[i]; T act_gate = act(gate); d_matmul_wx_vec.elem[i] = act_gate * dy_vec.elem[i]; // d_hidden_state // calculate the gradient of gate d_matmul_vx_vec.elem[i] = dact(d_act_gate, gate); // d_gate } *(reinterpret_cast(d_matmul_wx) + (packed_row * packed_stride + packed_col)) = d_matmul_wx_vec; *(reinterpret_cast(d_matmul_vx) + (packed_row * packed_stride + packed_col)) = d_matmul_vx_vec; } } template void LaunchFusedGluWithoutLinearGradGpu(ep::Stream* stream, const IndexType m, const IndexType packed_n, const IndexType pack_num, const IndexType packed_stride, const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) { constexpr int32_t block_size = 128; unsigned int grid_size = (pack_num + block_size - 1) / block_size; ep::primitive::UnaryFunctor act(0, 0); BinaryFunctor dact(0, 0); FusedGluWithoutLinearGradGpu <<As()->cuda_stream()>>>( m, packed_n, pack_num, packed_stride, dact, act, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } template void DispatchIndexType(ep::Stream* stream, 
const int64_t m, const int64_t packed_n, const int64_t pack_num, const int64_t packed_stride, const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) { if (pack_num < (1 << 30)) { LaunchFusedGluWithoutLinearGradGpu( stream, m, packed_n, pack_num, packed_stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else { LaunchFusedGluWithoutLinearGradGpu( stream, m, packed_n, pack_num, packed_stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } } template::type = 0> void DispatchPackSize(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride, const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) { DispatchIndexType(stream, m, n, m * n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } template::type = 0> void DispatchPackSize(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride, const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) { const int64_t pack_size = alignment / sizeof(T); const int64_t packed_n = n / pack_size; const int64_t pack_num = m * packed_n; const int64_t packed_stride = stride / pack_size; DispatchIndexType( stream, m, packed_n, pack_num, packed_stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } template void DispatchAlignment(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride, const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) { const auto IsAligned = [&](const size_t alignment) { const uintptr_t dy_ptr = reinterpret_cast(dy); const uintptr_t matmul_wx_ptr = reinterpret_cast(matmul_wx); const uintptr_t matmul_vx_ptr = reinterpret_cast(matmul_vx); const uintptr_t d_matmul_wx_ptr = reinterpret_cast(d_matmul_wx); const uintptr_t d_matmul_vx_ptr = reinterpret_cast(d_matmul_vx); const int64_t pack_size = alignment / sizeof(T); return pack_size != 0 ? 
(/* memory address alignment */ dy_ptr % alignment == 0 && matmul_vx_ptr % alignment == 0 && matmul_wx_ptr % alignment == 0 && d_matmul_wx_ptr % alignment == 0 && d_matmul_vx_ptr % alignment == 0 /* #element per row alignment */ && n % (pack_size) == 0) : false; }; // dispatch alignment if (IsAligned(16)) { DispatchPackSize(stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else if (IsAligned(8)) { DispatchPackSize(stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else if (IsAligned(4)) { DispatchPackSize(stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else if (IsAligned(2)) { DispatchPackSize(stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else { DispatchPackSize(stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } } template void DispatchActivationType(ep::Stream* stream, const int64_t m, const int64_t n, const std::string& activation, const int64_t stride, const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) { if (activation == "none") { DispatchAlignment( stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else if (activation == "sigmoid") { DispatchAlignment( stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else if (activation == "relu") { DispatchAlignment( stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else if (activation == "gelu") { DispatchAlignment( stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else if (activation == "fast_gelu") { DispatchAlignment( stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else if (activation == "silu") { DispatchAlignment( stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx); } else { UNIMPLEMENTED(); } } template class GpuFusedGluWithoutLinearGradKernel final : public user_op::OpKernel { public: GpuFusedGluWithoutLinearGradKernel() = default; ~GpuFusedGluWithoutLinearGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { // obtain tensors from context const user_op::Tensor* input_tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* input_tensor_matmul_wx = ctx->Tensor4ArgNameAndIndex("matmul_wx", 0); user_op::Tensor* out_tensor_d_matmul_wx = ctx->Tensor4ArgNameAndIndex("d_matmul_wx", 0); // obtain optional tensors from context bool is_split_mode = false; user_op::Tensor* input_tensor_matmul_vx = nullptr; user_op::Tensor* out_tensor_d_matmul_vx = nullptr; if (ctx->has_input("matmul_vx", 0)) { input_tensor_matmul_vx = ctx->Tensor4ArgNameAndIndex("matmul_vx", 0); out_tensor_d_matmul_vx = ctx->Tensor4ArgNameAndIndex("d_matmul_vx", 0); is_split_mode = true; } // obtain tensor shapes and number of axes const ShapeView& dy_shape = input_tensor_dy->shape_view(); const ShapeView& matmul_wx_shape = input_tensor_matmul_wx->shape_view(); const ShapeView& d_matmul_wx_shape = out_tensor_d_matmul_wx->shape_view(); const size_t dy_num_axes = dy_shape.NumAxes(); const size_t matmul_wx_num_axes = matmul_wx_shape.NumAxes(); // validate dimension and number of axes CHECK_GE(dy_num_axes, 2) << "number of axes of \'dy\' should have be greater than 1, yet get " << dy_num_axes; CHECK_GE(matmul_wx_num_axes, 2) << "number of axes of \'matmul_wx\' should have be greater than 1, yet get " << matmul_wx_num_axes; CHECK_EQ(dy_num_axes, 
matmul_wx_num_axes) << "number of axes of \'dy\'(" << dy_num_axes << ") is not consistant with the one of \'matmul_wx\'(" << matmul_wx_num_axes << ")"; // check input shape if (is_split_mode) { CHECK_EQ(dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1)) << "the last dimension of \'dy\'(" << dy_shape.At(dy_num_axes - 1) << ") is not consistant with the last dimension of \'matmul_wx\'(" << matmul_wx_shape.At(matmul_wx_num_axes - 1) << ")"; } else { CHECK_EQ(2 * dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1)) << "two times of the last dimension of \'dy\'(" << 2 * dy_shape.At(dy_num_axes - 1) << ") is not consistant with the last dimension of \'matmul_wx\'(" << matmul_wx_shape.At(matmul_wx_num_axes - 1) << ")"; } // check optional input tensor shapes if (is_split_mode) { const user_op::Tensor* input_tensor_matmul_vx = ctx->Tensor4ArgNameAndIndex("matmul_vx", 0); const ShapeView& matmul_vx_shape = input_tensor_matmul_vx->shape_view(); const size_t matmul_vx_num_axes = matmul_vx_shape.NumAxes(); CHECK_GE(matmul_vx_num_axes, 2) << "number of axes of \'matmul_vx\' should have be greater than 1, yet get " << matmul_vx_num_axes; CHECK_EQ(matmul_vx_num_axes, dy_num_axes) << "number of axes of \'dy\'(" << dy_num_axes << ") is not consistant with the one of \'matmul_vx\'(" << matmul_vx_num_axes << ")"; CHECK_EQ(matmul_vx_shape.At(matmul_vx_num_axes - 1), dy_shape.At(dy_num_axes - 1)) << "the last dimension of \'dy\'(" << dy_shape.At(dy_num_axes - 1) << ") is not consistant with the last dimension of \'matmul_vx\'(" << matmul_vx_shape.At(matmul_vx_num_axes - 1) << ")"; } // infer m, n const int64_t m = dy_shape.Count(0, dy_num_axes - 1); const int64_t n = dy_shape.At(dy_num_axes - 1); // start dispatch process DispatchActivationType( ctx->stream(), /*m, n=*/m, n, /*activation=*/ctx->Attr("activation"), /*stride=*/is_split_mode ? n : n * 2, /*dy=*/input_tensor_dy->dptr(), /*matmul_wx=*/input_tensor_matmul_wx->dptr(), /*matmul_vx=*/ is_split_mode ? input_tensor_matmul_vx->dptr() : input_tensor_matmul_wx->dptr() + n, /*d_matmul_wx=*/out_tensor_d_matmul_wx->mut_dptr(), /*d_matmul_vx=*/ is_split_mode ? out_tensor_d_matmul_vx->mut_dptr() : out_tensor_d_matmul_wx->mut_dptr() + n); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_glu_without_linear_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("d_matmul_wx", 0) == GetDataType::value)); REGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(double) REGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(float) REGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(nv_bfloat16) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_gru_cell_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/user/kernels/fused_rnn_cell_kernel_util.h" // NOTE(Liang Depeng): The implementation of fused_gru_cell is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/RNN.cu namespace oneflow { namespace { template struct AccumulateType {}; template<> struct AccumulateType { using type = float; }; template<> struct AccumulateType { using type = double; }; template using acc_type = typename AccumulateType::type; #define H2F(input) static_cast(input) #define F2H(input) static_cast(input) template __device__ __forceinline__ T sigmoid(T in) { T one = static_cast(1.0); return one / (one + ::exp(-in)); } template #if __CUDA_ARCH__ >= 350 OF_LAUNCH_BOUNDS_2(512, 4) #endif __global__ void gru_cell_forward(const IDX_TYPE numel, const IDX_TYPE hidden_size, const T* input_gates_ptr, const T* hidden_gates_ptr, const T* hx_ptr, const T* input_bias_ptr, const T* hidden_bias_ptr, T* hy_ptr, T* workspace_ptr) { bool has_bias = input_bias_ptr != nullptr; for (IDX_TYPE linearIndex = blockIdx.x * blockDim.x + threadIdx.x; linearIndex < numel; linearIndex += gridDim.x * blockDim.x) { IDX_TYPE offset = (linearIndex / hidden_size) * 3 * hidden_size + linearIndex % hidden_size; T ir = input_gates_ptr[offset + 0 * hidden_size]; T ii = input_gates_ptr[offset + 1 * hidden_size]; T in = input_gates_ptr[offset + 2 * hidden_size]; T hr = hidden_gates_ptr[offset + 0 * hidden_size]; T hi = hidden_gates_ptr[offset + 1 * hidden_size]; T hn = hidden_gates_ptr[offset + 2 * hidden_size]; T hx = hx_ptr[linearIndex]; T* hy = &(hy_ptr[linearIndex]); T b1r, b1i, b1n, b2r, b2i, b2n; if (has_bias) { b1r = input_bias_ptr[linearIndex % hidden_size + 0 * hidden_size]; b1i = input_bias_ptr[linearIndex % hidden_size + 1 * hidden_size]; b1n = input_bias_ptr[linearIndex % hidden_size + 2 * hidden_size]; b2r = hidden_bias_ptr[linearIndex % hidden_size + 0 * hidden_size]; b2i = hidden_bias_ptr[linearIndex % hidden_size + 1 * hidden_size]; b2n = hidden_bias_ptr[linearIndex % hidden_size + 2 * hidden_size]; } else { b1r = F2H(0.0); b1i = F2H(0.0); b1n = F2H(0.0); b2r = F2H(0.0); b2i = F2H(0.0); b2n = F2H(0.0); } offset = (linearIndex / hidden_size) * 5 * hidden_size + linearIndex % hidden_size; ACC_T rg, ig, ng; rg = sigmoid(H2F(ir) + H2F(hr) + H2F(b1r) + H2F(b2r)); ig = sigmoid(H2F(ii) + H2F(hi) + H2F(b1i) + H2F(b2i)); ng = H2F(in) + H2F(b1n) + rg * (H2F(hn) + H2F(b2n)); ng = ::tanh(ng); *hy = F2H(ng + ig * (H2F(hx) - ng)); // SAVE FOR BACKWARDS workspace_ptr[offset + 0 * hidden_size] = F2H(rg); workspace_ptr[offset + 1 * hidden_size] = F2H(ig); workspace_ptr[offset + 2 * hidden_size] = F2H(ng); workspace_ptr[offset + 3 * hidden_size] = hx; workspace_ptr[offset + 4 * hidden_size] = F2H(H2F(hn) + H2F(b2n)); } } template #if __CUDA_ARCH__ >= 350 OF_LAUNCH_BOUNDS_2(512, 4) #endif __global__ void gru_cell_backward(const IDX_TYPE numel, const IDX_TYPE hidden_size, const T* grad_hy_ptr, const T* workspace_ptr, T* 
grad_input_gates_ptr, T* grad_hidden_gates_ptr, T* grad_hx_ptr) { for (IDX_TYPE linearIndex = blockIdx.x * blockDim.x + threadIdx.x; linearIndex < numel; linearIndex += gridDim.x * blockDim.x) { IDX_TYPE offset = (linearIndex / hidden_size) * 5 * hidden_size + linearIndex % hidden_size; T rg = workspace_ptr[offset + 0 * hidden_size]; T ig = workspace_ptr[offset + 1 * hidden_size]; T ng = workspace_ptr[offset + 2 * hidden_size]; T hx = workspace_ptr[offset + 3 * hidden_size]; T hn = workspace_ptr[offset + 4 * hidden_size]; T go = grad_hy_ptr[linearIndex]; offset = (linearIndex / hidden_size) * 3 * hidden_size + linearIndex % hidden_size; ACC_T gig = H2F(go) * (H2F(hx) - H2F(ng)) * (1 - H2F(ig)) * H2F(ig); ACC_T ghx = H2F(go) * H2F(ig); ACC_T gin = H2F(go) * (1 - H2F(ig)) * (1 - H2F(ng) * H2F(ng)); ACC_T ghn = gin * H2F(rg); ACC_T grg = gin * H2F(hn) * (1 - H2F(rg)) * H2F(rg); grad_input_gates_ptr[offset + 0 * hidden_size] = F2H(grg); grad_input_gates_ptr[offset + 1 * hidden_size] = F2H(gig); grad_input_gates_ptr[offset + 2 * hidden_size] = F2H(gin); grad_hidden_gates_ptr[offset + 0 * hidden_size] = F2H(grg); grad_hidden_gates_ptr[offset + 1 * hidden_size] = F2H(gig); grad_hidden_gates_ptr[offset + 2 * hidden_size] = F2H(ghn); if (grad_hx_ptr != nullptr) { grad_hx_ptr[linearIndex] = F2H(ghx); } } } template struct FusedGruCellGradFunctor final { void operator()(ep::Stream* stream, const int64_t hx_numel, const int64_t workspace_numel, const int64_t hidden_size, const T* grad_hy_ptr, const T* workspace_ptr, T* grad_input_gates_ptr, T* grad_hidden_gates_ptr, T* grad_hx_ptr) { using ACC_T = acc_type; if (workspace_numel < std::numeric_limits::max()) { RUN_CUDA_KERNEL((gru_cell_backward), stream, hx_numel, static_cast(hx_numel), static_cast(hidden_size), grad_hy_ptr, workspace_ptr, grad_input_gates_ptr, grad_hidden_gates_ptr, grad_hx_ptr); } else { RUN_CUDA_KERNEL((gru_cell_backward), stream, hx_numel, hx_numel, hidden_size, grad_hy_ptr, workspace_ptr, grad_input_gates_ptr, grad_hidden_gates_ptr, grad_hx_ptr); } } }; template<> void FusedGruCellGradFunctor::operator()( ep::Stream* stream, const int64_t hx_numel, const int64_t workspace_numel, const int64_t hidden_size, const float16* grad_hy_ptr, const float16* workspace_ptr, float16* grad_input_gates_ptr, float16* grad_hidden_gates_ptr, float16* grad_hx_ptr) { if (workspace_numel < std::numeric_limits::max()) { RUN_CUDA_KERNEL( (gru_cell_backward), stream, hx_numel, static_cast(hx_numel), static_cast(hidden_size), reinterpret_cast(grad_hy_ptr), reinterpret_cast(workspace_ptr), reinterpret_cast(grad_input_gates_ptr), reinterpret_cast(grad_hidden_gates_ptr), reinterpret_cast(grad_hx_ptr)); } else { RUN_CUDA_KERNEL( (gru_cell_backward), stream, hx_numel, hx_numel, hidden_size, reinterpret_cast(grad_hy_ptr), reinterpret_cast(workspace_ptr), reinterpret_cast(grad_input_gates_ptr), reinterpret_cast(grad_hidden_gates_ptr), reinterpret_cast(grad_hx_ptr)); } } template struct FusedGruCellFunctor final { void operator()(ep::Stream* stream, const int64_t hx_numel, const int64_t workspace_numel, const int64_t hidden_size, const T* input_gates_ptr, const T* hidden_gates_ptr, const T* hx_ptr, const T* input_bias_ptr, const T* hidden_bias_ptr, T* hy_ptr, T* workspace_ptr) { using ACC_T = acc_type; if (workspace_numel < std::numeric_limits::max()) { RUN_CUDA_KERNEL((gru_cell_forward), stream, hx_numel, static_cast(hx_numel), static_cast(hidden_size), input_gates_ptr, hidden_gates_ptr, hx_ptr, input_bias_ptr, hidden_bias_ptr, hy_ptr, workspace_ptr); } else { 
RUN_CUDA_KERNEL((gru_cell_forward), stream, hx_numel, hx_numel, hidden_size, input_gates_ptr, hidden_gates_ptr, hx_ptr, input_bias_ptr, hidden_bias_ptr, hy_ptr, workspace_ptr); } } }; template<> void FusedGruCellFunctor::operator()( ep::Stream* stream, const int64_t hx_numel, const int64_t workspace_numel, const int64_t hidden_size, const float16* input_gates_ptr, const float16* hidden_gates_ptr, const float16* hx_ptr, const float16* input_bias_ptr, const float16* hidden_bias_ptr, float16* hy_ptr, float16* workspace_ptr) { if (workspace_numel < std::numeric_limits::max()) { RUN_CUDA_KERNEL( (gru_cell_forward), stream, hx_numel, static_cast(hx_numel), static_cast(hidden_size), reinterpret_cast(input_gates_ptr), reinterpret_cast(hidden_gates_ptr), reinterpret_cast(hx_ptr), reinterpret_cast(input_bias_ptr), reinterpret_cast(hidden_bias_ptr), reinterpret_cast(hy_ptr), reinterpret_cast(workspace_ptr)); } else { RUN_CUDA_KERNEL((gru_cell_forward), stream, hx_numel, hx_numel, hidden_size, reinterpret_cast(input_gates_ptr), reinterpret_cast(hidden_gates_ptr), reinterpret_cast(hx_ptr), reinterpret_cast(input_bias_ptr), reinterpret_cast(hidden_bias_ptr), reinterpret_cast(hy_ptr), reinterpret_cast(workspace_ptr)); } } } // namespace template class GpuFusedGruCellKernel final : public user_op::OpKernel { public: GpuFusedGruCellKernel() = default; ~GpuFusedGruCellKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input_gates = ctx->Tensor4ArgNameAndIndex("input_gates", 0); const user_op::Tensor* hidden_gates = ctx->Tensor4ArgNameAndIndex("hidden_gates", 0); const user_op::Tensor* hx = ctx->Tensor4ArgNameAndIndex("hx", 0); user_op::Tensor* hy = ctx->Tensor4ArgNameAndIndex("hy", 0); user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex("workspace", 0); const T* input_bias_ptr = nullptr; const T* hidden_bias_ptr = nullptr; if (ctx->has_input("input_bias", 0)) { CHECK(ctx->has_input("hidden_bias", 0)); input_bias_ptr = ctx->Tensor4ArgNameAndIndex("input_bias", 0)->dptr(); hidden_bias_ptr = ctx->Tensor4ArgNameAndIndex("hidden_bias", 0)->dptr(); } const T* input_gates_ptr = input_gates->dptr(); const T* hidden_gates_ptr = hidden_gates->dptr(); const T* hx_ptr = hx->dptr(); T* hy_ptr = hy->mut_dptr(); T* workspace_ptr = workspace->mut_dptr(); const int64_t hx_numel = hx->shape_view().elem_cnt(); const int64_t workspace_numel = workspace->shape_view().elem_cnt(); const int64_t hidden_size = hx->shape_view().At(hx->shape_view().NumAxes() - 1); FusedGruCellFunctor()(ctx->stream(), hx_numel, workspace_numel, hidden_size, input_gates_ptr, hidden_gates_ptr, hx_ptr, input_bias_ptr, hidden_bias_ptr, hy_ptr, workspace_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_GRU_CELL_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_gru_cell") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("hx", 0) == GetDataType::value) \ && (user_op::HobDataType("input_gates", 0) == GetDataType::value) \ && (user_op::HobDataType("hidden_gates", 0) == GetDataType::value)) REGISTER_FUSED_GRU_CELL_KERNEL(float); REGISTER_FUSED_GRU_CELL_KERNEL(float16); class GpuFusedGruCellGradFloatKernel final : public user_op::OpKernel { public: GpuFusedGruCellGradFloatKernel() = default; ~GpuFusedGruCellGradFloatKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { 
const user_op::Tensor* grad_hy = ctx->Tensor4ArgNameAndIndex("grad_hy", 0); const user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex("workspace", 0); user_op::Tensor* grad_input_gates = ctx->Tensor4ArgNameAndIndex("grad_input_gates", 0); user_op::Tensor* grad_hidden_gates = ctx->Tensor4ArgNameAndIndex("grad_hidden_gates", 0); const float* grad_hy_ptr = grad_hy->dptr(); const float* workspace_ptr = workspace->dptr(); float* grad_input_gates_ptr = grad_input_gates->mut_dptr(); float* grad_hidden_gates_ptr = grad_hidden_gates->mut_dptr(); float* grad_hx_ptr = nullptr; if (ctx->has_output("grad_hx", 0)) { user_op::Tensor* grad_hx = ctx->Tensor4ArgNameAndIndex("grad_hx", 0); grad_hx_ptr = grad_hx->mut_dptr(); } const int64_t hx_numel = grad_hy->shape_view().elem_cnt(); const int64_t workspace_numel = workspace->shape_view().elem_cnt(); const int64_t hidden_size = grad_hy->shape_view().At(grad_hy->shape_view().NumAxes() - 1); FusedGruCellGradFunctor()(ctx->stream(), hx_numel, workspace_numel, hidden_size, grad_hy_ptr, workspace_ptr, grad_input_gates_ptr, grad_hidden_gates_ptr, grad_hx_ptr); if (ctx->has_output("grad_input_bias", 0) && ctx->has_output("grad_hidden_bias", 0)) { float* grad_input_bias_ptr = ctx->Tensor4ArgNameAndIndex("grad_input_bias", 0)->mut_dptr(); std::vector axis; axis.push_back(0); const Shape& reduced_shape = CreateReducedShape(grad_input_gates->shape_view(), {axis.begin(), axis.end()}); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, grad_input_bias_ptr), XpuVarNdarray(grad_input_gates->shape_view(), grad_input_gates->dptr()), XpuVarNdarray(tmp_buffer->shape_view(), tmp_buffer->mut_dptr())); float* grad_hidden_bias_ptr = ctx->Tensor4ArgNameAndIndex("grad_hidden_bias", 0)->mut_dptr(); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, grad_hidden_bias_ptr), XpuVarNdarray(grad_hidden_gates->shape_view(), grad_hidden_gates->dptr()), XpuVarNdarray(tmp_buffer->shape_view(), tmp_buffer->mut_dptr())); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("fused_gru_cell_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("grad_hy", 0) == GetDataType::value) && (user_op::HobDataType("workspace", 0) == GetDataType::value)) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { size_t tmp_bytes = 0; if (ctx->has_output("grad_input_bias", 0) && ctx->has_output("grad_hidden_bias", 0)) { const Shape& in_shape = ctx->InputTensorDesc("grad_hy", 0).shape(); tmp_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * 3 * sizeof(float)); } else { tmp_bytes = 0; } return tmp_bytes; }); class GpuFusedGruCellGradHalfKernel final : public user_op::OpKernel { public: GpuFusedGruCellGradHalfKernel() = default; ~GpuFusedGruCellGradHalfKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* grad_hy = ctx->Tensor4ArgNameAndIndex("grad_hy", 0); const user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex("workspace", 0); user_op::Tensor* grad_input_gates = ctx->Tensor4ArgNameAndIndex("grad_input_gates", 0); user_op::Tensor* grad_hidden_gates = ctx->Tensor4ArgNameAndIndex("grad_hidden_gates", 0); const float16* grad_hy_ptr = grad_hy->dptr(); const float16* workspace_ptr = workspace->dptr(); float16* grad_input_gates_ptr = grad_input_gates->mut_dptr(); float16* grad_hidden_gates_ptr = 
grad_hidden_gates->mut_dptr(); float16* grad_hx_ptr = nullptr; if (ctx->has_output("grad_hx", 0)) { user_op::Tensor* grad_hx = ctx->Tensor4ArgNameAndIndex("grad_hx", 0); grad_hx_ptr = grad_hx->mut_dptr(); } const int64_t hx_numel = grad_hy->shape_view().elem_cnt(); const int64_t workspace_numel = workspace->shape_view().elem_cnt(); const int64_t hidden_size = grad_hy->shape_view().At(grad_hy->shape_view().NumAxes() - 1); FusedGruCellGradFunctor()(ctx->stream(), hx_numel, workspace_numel, hidden_size, grad_hy_ptr, workspace_ptr, grad_input_gates_ptr, grad_hidden_gates_ptr, grad_hx_ptr); if (ctx->has_output("grad_input_bias", 0) && ctx->has_output("grad_hidden_bias", 0)) { std::vector axis; axis.push_back(0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& in_shape = grad_input_gates->shape_view(); const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()}); float* in_tmp_buffer = tmp_buffer->mut_dptr(); const size_t in_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); float* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr() + in_tmp_buffer_bytes); const size_t out_tmp_buffer_bytes = GetCudaAlignedSize(reduced_shape.elem_cnt() * sizeof(float)); float* reduce_tmp_buffer = reinterpret_cast( tmp_buffer->mut_dptr() + in_tmp_buffer_bytes + out_tmp_buffer_bytes); const size_t reduce_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); CHECK_LE(in_tmp_buffer_bytes + out_tmp_buffer_bytes + reduce_tmp_buffer_bytes, tmp_buffer->shape_view().elem_cnt()); auto h2f = ep::primitive::NewPrimitive( ctx->device_type(), DataType::kFloat16, DataType::kFloat); CHECK(h2f); auto f2h = ep::primitive::NewPrimitive( ctx->device_type(), DataType::kFloat, DataType::kFloat16); CHECK(f2h); h2f->Launch(ctx->stream(), grad_input_gates->dptr(), in_tmp_buffer, in_shape.elem_cnt()); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, out_tmp_buffer), XpuVarNdarray(in_shape, in_tmp_buffer), XpuVarNdarray(in_shape, reduce_tmp_buffer)); user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("grad_input_bias", 0); f2h->Launch(ctx->stream(), out_tmp_buffer, output_tensor->mut_dptr(), output_tensor->shape_view().elem_cnt()); h2f->Launch(ctx->stream(), grad_hidden_gates->dptr(), in_tmp_buffer, in_shape.elem_cnt()); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, out_tmp_buffer), XpuVarNdarray(in_shape, in_tmp_buffer), XpuVarNdarray(in_shape, reduce_tmp_buffer)); output_tensor = ctx->Tensor4ArgNameAndIndex("grad_hidden_bias", 0); f2h->Launch(ctx->stream(), out_tmp_buffer, output_tensor->mut_dptr(), output_tensor->shape_view().elem_cnt()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("fused_gru_cell_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("grad_hy", 0) == GetDataType::value) && (user_op::HobDataType("workspace", 0) == GetDataType::value)) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { size_t tmp_bytes = 0; if (ctx->has_output("grad_input_bias", 0) && ctx->has_output("grad_hidden_bias", 0)) { const Shape& in_shape = ctx->InputTensorDesc("grad_hy", 0).shape(); const Shape& out_shape = ctx->OutputTensorDesc("grad_input_bias", 0).shape(); tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * 3 * sizeof(float)) + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float))); } else { tmp_bytes = 0; } return tmp_bytes; }); } // namespace oneflow 
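// ----------------------------------------------------------------------------
// [Editor's illustrative sketch; not part of the original file.] In the two
// grad kernels above, grad_input_bias and grad_hidden_bias are simply the gate
// gradients summed over the leading batch axis, which is what the
// NdarrayReduce call with reduction axis {0} computes. The float16 kernel runs
// that reduction in float for accuracy: it casts the half gate gradients into
// a float scratch region (h2f), reduces there, and casts the reduced row back
// to half (f2h); tmp_buffer is carved into three aligned regions (cast input,
// reduce output, reduce scratch), which is what the registered InferTmpSizeFn
// reserves. A plain CPU reference of the reduction itself, using hypothetical
// names:
#include <cstdint>
#include <vector>

// grad_gates holds rows x cols values in row-major order; the result is the
// per-column sum, i.e. a bias gradient of shape [cols].
inline std::vector<float> BiasGradRef(const std::vector<float>& grad_gates,
                                      int64_t rows, int64_t cols) {
  std::vector<float> bias_grad(cols, 0.f);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) { bias_grad[c] += grad_gates[r * cols + c]; }
  }
  return bias_grad;
}
// ----------------------------------------------------------------------------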
================================================ FILE: oneflow/user/kernels/fused_lstm_cell_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/cuda/cuda_device.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/user/kernels/fused_rnn_cell_kernel_util.h" // NOTE(Liang Depeng): The implementation of fused_lstm_cell is modified from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/RNN.cu namespace oneflow { namespace { template struct AccumulateType {}; template<> struct AccumulateType { using type = float; }; template<> struct AccumulateType { using type = double; }; template using acc_type = typename AccumulateType::type; #define H2F(input) static_cast(input) #define F2H(input) static_cast(input) template __device__ __forceinline__ T sigmoid(T in) { T one = static_cast(1.0); return one / (one + ::exp(-in)); } template #if __CUDA_ARCH__ >= 350 OF_LAUNCH_BOUNDS_2(512, 4) #endif __global__ void lstm_cell_forward(const IDX_TYPE numel, const IDX_TYPE hidden_size, const T* input_gates_ptr, const T* hidden_gates_ptr, const T* cx_ptr, const T* input_bias_ptr, const T* hidden_bias_ptr, T* hy_ptr, T* cy_ptr, T* workspace_ptr) { bool has_bias = input_bias_ptr != nullptr; for (IDX_TYPE linearIndex = blockIdx.x * blockDim.x + threadIdx.x; linearIndex < numel; linearIndex += gridDim.x * blockDim.x) { IDX_TYPE offset = (linearIndex / hidden_size) * 4 * hidden_size + linearIndex % hidden_size; T iig = input_gates_ptr[offset + 0 * hidden_size]; T ifg = input_gates_ptr[offset + 1 * hidden_size]; T icg = input_gates_ptr[offset + 2 * hidden_size]; T iog = input_gates_ptr[offset + 3 * hidden_size]; T hig = hidden_gates_ptr[offset + 0 * hidden_size]; T hfg = hidden_gates_ptr[offset + 1 * hidden_size]; T hcg = hidden_gates_ptr[offset + 2 * hidden_size]; T hog = hidden_gates_ptr[offset + 3 * hidden_size]; T* wig = &(workspace_ptr[offset + 0 * hidden_size]); T* wfg = &(workspace_ptr[offset + 1 * hidden_size]); T* wcg = &(workspace_ptr[offset + 2 * hidden_size]); T* wog = &(workspace_ptr[offset + 3 * hidden_size]); T cx = cx_ptr[linearIndex]; T* hy = &(hy_ptr[linearIndex]); T* cy = &(cy_ptr[linearIndex]); T b1i, b1f, b1c, b1o; T b2i, b2f, b2c, b2o; if (has_bias) { b1i = input_bias_ptr[linearIndex % hidden_size + 0 * hidden_size]; b1f = input_bias_ptr[linearIndex % hidden_size + 1 * hidden_size]; b1c = input_bias_ptr[linearIndex % hidden_size + 2 * hidden_size]; b1o = input_bias_ptr[linearIndex % hidden_size + 3 * hidden_size]; b2i = hidden_bias_ptr[linearIndex % hidden_size 
+ 0 * hidden_size]; b2f = hidden_bias_ptr[linearIndex % hidden_size + 1 * hidden_size]; b2c = hidden_bias_ptr[linearIndex % hidden_size + 2 * hidden_size]; b2o = hidden_bias_ptr[linearIndex % hidden_size + 3 * hidden_size]; } else { b1i = F2H(0.0); b1f = F2H(0.0); b1c = F2H(0.0); b1o = F2H(0.0); b2i = F2H(0.0); b2f = F2H(0.0); b2c = F2H(0.0); b2o = F2H(0.0); } ACC_T ig, fg, cg, og; ACC_T f_hy, f_cy; ig = sigmoid(H2F(iig) + H2F(hig) + H2F(b1i) + H2F(b2i)); fg = sigmoid(H2F(ifg) + H2F(hfg) + H2F(b1f) + H2F(b2f)); cg = ::tanh(H2F(icg) + H2F(hcg) + H2F(b1c) + H2F(b2c)); og = sigmoid(H2F(iog) + H2F(hog) + H2F(b1o) + H2F(b2o)); f_cy = (fg * H2F(cx)) + (ig * cg); f_hy = og * ::tanh(f_cy); *hy = F2H(f_hy); *cy = F2H(f_cy); // SAVE FOR BACKWARDS // Also need cy and cx but can be saved easily in python *wig = F2H(ig); *wfg = F2H(fg); *wcg = F2H(cg); *wog = F2H(og); } } template #if __CUDA_ARCH__ >= 350 OF_LAUNCH_BOUNDS_2(512, 4) #endif __global__ void lstm_cell_backward(const IDX_TYPE numel, const IDX_TYPE hidden_size, const T* grad_hy_ptr, const T* grad_cy_ptr, const T* cx_ptr, const T* cy_ptr, const T* workspace_ptr, T* grad_gates_ptr, T* grad_cx_ptr) { for (IDX_TYPE linearIndex = blockIdx.x * blockDim.x + threadIdx.x; linearIndex < numel; linearIndex += gridDim.x * blockDim.x) { IDX_TYPE offset = (linearIndex / hidden_size) * 4 * hidden_size + linearIndex % hidden_size; T ig = workspace_ptr[offset + 0 * hidden_size]; T fg = workspace_ptr[offset + 1 * hidden_size]; T cg = workspace_ptr[offset + 2 * hidden_size]; T og = workspace_ptr[offset + 3 * hidden_size]; T* ih = &(grad_gates_ptr[offset + 0 * hidden_size]); T* fh = &(grad_gates_ptr[offset + 1 * hidden_size]); T* ch = &(grad_gates_ptr[offset + 2 * hidden_size]); T* oh = &(grad_gates_ptr[offset + 3 * hidden_size]); // will return hidden grads here T cx = cx_ptr[linearIndex]; T cy = cy_ptr[linearIndex]; ACC_T go = H2F(grad_hy_ptr[linearIndex]); ACC_T goc = H2F(grad_cy_ptr[linearIndex]); ACC_T gcx = ::tanh(H2F(cy)); ACC_T gog = go * gcx; gcx = go * H2F(og) * (1 - gcx * gcx) + goc; ACC_T gig = gcx * H2F(cg); ACC_T gfg = gcx * H2F(cx); ACC_T gcg = gcx * H2F(ig); gig = gig * (1 - H2F(ig)) * H2F(ig); gfg = gfg * (1 - H2F(fg)) * H2F(fg); gcg = gcg * (1 - H2F(cg) * H2F(cg)); gog = gog * (1 - H2F(og)) * H2F(og); *ih = F2H(gig); *fh = F2H(gfg); *ch = F2H(gcg); *oh = F2H(gog); if (grad_cx_ptr != nullptr) { gcx = gcx * H2F(fg); T* gi = &(grad_cx_ptr[linearIndex]); *gi = F2H(gcx); } } } template struct FusedLstmCellFunctor final { void operator()(ep::Stream* stream, const int64_t cx_numel, const int64_t workspace_numel, const int64_t hidden_size, const T* input_gates_ptr, const T* hidden_gates_ptr, const T* cx_ptr, const T* input_bias_ptr, const T* hidden_bias_ptr, T* hy_ptr, T* cy_ptr, T* workspace_ptr) { using ACC_T = acc_type; if (workspace_numel < std::numeric_limits::max()) { RUN_CUDA_KERNEL((lstm_cell_forward), stream, cx_numel, static_cast(cx_numel), static_cast(hidden_size), input_gates_ptr, hidden_gates_ptr, cx_ptr, input_bias_ptr, hidden_bias_ptr, hy_ptr, cy_ptr, workspace_ptr); } else { RUN_CUDA_KERNEL((lstm_cell_forward), stream, cx_numel, cx_numel, hidden_size, input_gates_ptr, hidden_gates_ptr, cx_ptr, input_bias_ptr, hidden_bias_ptr, hy_ptr, cy_ptr, workspace_ptr); } } }; template<> void FusedLstmCellFunctor::operator()( ep::Stream* stream, const int64_t cx_numel, const int64_t workspace_numel, const int64_t hidden_size, const float16* input_gates_ptr, const float16* hidden_gates_ptr, const float16* cx_ptr, const float16* input_bias_ptr, const 
float16* hidden_bias_ptr, float16* hy_ptr, float16* cy_ptr, float16* workspace_ptr) { if (workspace_numel < std::numeric_limits::max()) { RUN_CUDA_KERNEL( (lstm_cell_forward), stream, cx_numel, static_cast(cx_numel), static_cast(hidden_size), reinterpret_cast(input_gates_ptr), reinterpret_cast(hidden_gates_ptr), reinterpret_cast(cx_ptr), reinterpret_cast(input_bias_ptr), reinterpret_cast(hidden_bias_ptr), reinterpret_cast(hy_ptr), reinterpret_cast(cy_ptr), reinterpret_cast(workspace_ptr)); } else { RUN_CUDA_KERNEL((lstm_cell_forward), stream, cx_numel, cx_numel, hidden_size, reinterpret_cast(input_gates_ptr), reinterpret_cast(hidden_gates_ptr), reinterpret_cast(cx_ptr), reinterpret_cast(input_bias_ptr), reinterpret_cast(hidden_bias_ptr), reinterpret_cast(hy_ptr), reinterpret_cast(cy_ptr), reinterpret_cast(workspace_ptr)); } } template struct FusedLstmCellGradFunctor final { void operator()(ep::Stream* stream, const int64_t cx_numel, const int64_t workspace_numel, const int64_t hidden_size, const T* grad_hy_ptr, const T* grad_cy_ptr, const T* cx_ptr, const T* cy_ptr, const T* workspace_ptr, T* grad_gates_ptr, T* grad_cx_ptr) { using ACC_T = acc_type; if (workspace_numel < std::numeric_limits::max()) { RUN_CUDA_KERNEL((lstm_cell_backward), stream, cx_numel, static_cast(cx_numel), static_cast(hidden_size), grad_hy_ptr, grad_cy_ptr, cx_ptr, cy_ptr, workspace_ptr, grad_gates_ptr, grad_cx_ptr); } else { RUN_CUDA_KERNEL((lstm_cell_backward), stream, cx_numel, cx_numel, hidden_size, grad_hy_ptr, grad_cy_ptr, cx_ptr, cy_ptr, workspace_ptr, grad_gates_ptr, grad_cx_ptr); } } }; template<> void FusedLstmCellGradFunctor::operator()( ep::Stream* stream, const int64_t cx_numel, const int64_t workspace_numel, const int64_t hidden_size, const float16* grad_hy_ptr, const float16* grad_cy_ptr, const float16* cx_ptr, const float16* cy_ptr, const float16* workspace_ptr, float16* grad_gates_ptr, float16* grad_cx_ptr) { if (workspace_numel < std::numeric_limits::max()) { RUN_CUDA_KERNEL((lstm_cell_backward), stream, cx_numel, static_cast(cx_numel), static_cast(hidden_size), reinterpret_cast(grad_hy_ptr), reinterpret_cast(grad_cy_ptr), reinterpret_cast(cx_ptr), reinterpret_cast(cy_ptr), reinterpret_cast(workspace_ptr), reinterpret_cast(grad_gates_ptr), reinterpret_cast(grad_cx_ptr)); } else { RUN_CUDA_KERNEL((lstm_cell_backward), stream, cx_numel, cx_numel, hidden_size, reinterpret_cast(grad_hy_ptr), reinterpret_cast(grad_cy_ptr), reinterpret_cast(cx_ptr), reinterpret_cast(cy_ptr), reinterpret_cast(workspace_ptr), reinterpret_cast(grad_gates_ptr), reinterpret_cast(grad_cx_ptr)); } } } // namespace template class GpuFusedLstmCellKernel final : public user_op::OpKernel { public: GpuFusedLstmCellKernel() = default; ~GpuFusedLstmCellKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input_gates = ctx->Tensor4ArgNameAndIndex("input_gates", 0); const user_op::Tensor* hidden_gates = ctx->Tensor4ArgNameAndIndex("hidden_gates", 0); const user_op::Tensor* cx = ctx->Tensor4ArgNameAndIndex("cx", 0); user_op::Tensor* hy = ctx->Tensor4ArgNameAndIndex("hy", 0); user_op::Tensor* cy = ctx->Tensor4ArgNameAndIndex("cy", 0); user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex("workspace", 0); const T* input_bias_ptr = nullptr; const T* hidden_bias_ptr = nullptr; if (ctx->has_input("input_bias", 0)) { CHECK(ctx->has_input("hidden_bias", 0)); input_bias_ptr = ctx->Tensor4ArgNameAndIndex("input_bias", 0)->dptr(); hidden_bias_ptr = 
ctx->Tensor4ArgNameAndIndex("hidden_bias", 0)->dptr(); } const T* input_gates_ptr = input_gates->dptr(); const T* hidden_gates_ptr = hidden_gates->dptr(); const T* cx_ptr = cx->dptr(); T* hy_ptr = hy->mut_dptr(); T* cy_ptr = cy->mut_dptr(); T* workspace_ptr = workspace->mut_dptr(); const int64_t cx_numel = cx->shape_view().elem_cnt(); const int64_t workspace_numel = workspace->shape_view().elem_cnt(); const int64_t hidden_size = cx->shape_view().At(cx->shape_view().NumAxes() - 1); FusedLstmCellFunctor()(ctx->stream(), cx_numel, workspace_numel, hidden_size, input_gates_ptr, hidden_gates_ptr, cx_ptr, input_bias_ptr, hidden_bias_ptr, hy_ptr, cy_ptr, workspace_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_LSTM_CELL_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_lstm_cell") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("cx", 0) == GetDataType::value) \ && (user_op::HobDataType("input_gates", 0) == GetDataType::value) \ && (user_op::HobDataType("hidden_gates", 0) == GetDataType::value)) REGISTER_FUSED_LSTM_CELL_KERNEL(float); REGISTER_FUSED_LSTM_CELL_KERNEL(float16); class GpuFusedLstmCellGradFloatKernel final : public user_op::OpKernel { public: GpuFusedLstmCellGradFloatKernel() = default; ~GpuFusedLstmCellGradFloatKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* grad_hy = ctx->Tensor4ArgNameAndIndex("grad_hy", 0); const user_op::Tensor* grad_cy = ctx->Tensor4ArgNameAndIndex("grad_cy", 0); const user_op::Tensor* cx = ctx->Tensor4ArgNameAndIndex("cx", 0); const user_op::Tensor* cy = ctx->Tensor4ArgNameAndIndex("cy", 0); const user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex("workspace", 0); user_op::Tensor* grad_gates = ctx->Tensor4ArgNameAndIndex("grad_gates", 0); user_op::Tensor* grad_cx = ctx->Tensor4ArgNameAndIndex("grad_cx", 0); const float* grad_hy_ptr = grad_hy->dptr(); const float* grad_cy_ptr = grad_cy->dptr(); const float* cx_ptr = cx->dptr(); const float* cy_ptr = cy->dptr(); const float* workspace_ptr = workspace->dptr(); float* grad_gates_ptr = grad_gates->mut_dptr(); float* grad_cx_ptr = nullptr; if (ctx->has_output("grad_cx", 0)) { grad_cx_ptr = grad_cx->mut_dptr(); } const int64_t cx_numel = cx->shape_view().elem_cnt(); const int64_t workspace_numel = workspace->shape_view().elem_cnt(); const int64_t hidden_size = cx->shape_view().At(cx->shape_view().NumAxes() - 1); FusedLstmCellGradFunctor()(ctx->stream(), cx_numel, workspace_numel, hidden_size, grad_hy_ptr, grad_cy_ptr, cx_ptr, cy_ptr, workspace_ptr, grad_gates_ptr, grad_cx_ptr); if (ctx->has_output("grad_bias", 0)) { float* grad_bias_ptr = ctx->Tensor4ArgNameAndIndex("grad_bias", 0)->mut_dptr(); std::vector axis; axis.push_back(0); const Shape& reduced_shape = CreateReducedShape(workspace->shape_view(), {axis.begin(), axis.end()}); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, grad_bias_ptr), XpuVarNdarray(grad_gates->shape_view(), grad_gates->dptr()), XpuVarNdarray(tmp_buffer->shape_view(), tmp_buffer->mut_dptr())); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("fused_lstm_cell_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("grad_hy", 0) == GetDataType::value) && 
(user_op::HobDataType("grad_cy", 0) == GetDataType::value) && (user_op::HobDataType("cx", 0) == GetDataType::value) && (user_op::HobDataType("cy", 0) == GetDataType::value) && (user_op::HobDataType("workspace", 0) == GetDataType::value)) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { size_t tmp_bytes = 0; if (ctx->has_output("grad_bias", 0)) { const Shape& in_shape = ctx->InputTensorDesc("workspace", 0).shape(); tmp_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); } else { tmp_bytes = 0; } return tmp_bytes; }); class GpuFusedLstmCellGradHalfKernel final : public user_op::OpKernel { public: GpuFusedLstmCellGradHalfKernel() = default; ~GpuFusedLstmCellGradHalfKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* grad_hy = ctx->Tensor4ArgNameAndIndex("grad_hy", 0); const user_op::Tensor* grad_cy = ctx->Tensor4ArgNameAndIndex("grad_cy", 0); const user_op::Tensor* cx = ctx->Tensor4ArgNameAndIndex("cx", 0); const user_op::Tensor* cy = ctx->Tensor4ArgNameAndIndex("cy", 0); const user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex("workspace", 0); user_op::Tensor* grad_gates = ctx->Tensor4ArgNameAndIndex("grad_gates", 0); user_op::Tensor* grad_cx = ctx->Tensor4ArgNameAndIndex("grad_cx", 0); const float16* grad_hy_ptr = grad_hy->dptr(); const float16* grad_cy_ptr = grad_cy->dptr(); const float16* cx_ptr = cx->dptr(); const float16* cy_ptr = cy->dptr(); const float16* workspace_ptr = workspace->dptr(); float16* grad_gates_ptr = grad_gates->mut_dptr(); float16* grad_cx_ptr = nullptr; if (ctx->has_output("grad_cx", 0)) { grad_cx_ptr = grad_cx->mut_dptr(); } const int64_t cx_numel = cx->shape_view().elem_cnt(); const int64_t workspace_numel = workspace->shape_view().elem_cnt(); const int64_t hidden_size = cx->shape_view().At(cx->shape_view().NumAxes() - 1); FusedLstmCellGradFunctor()(ctx->stream(), cx_numel, workspace_numel, hidden_size, grad_hy_ptr, grad_cy_ptr, cx_ptr, cy_ptr, workspace_ptr, grad_gates_ptr, grad_cx_ptr); if (ctx->has_output("grad_bias", 0)) { std::vector axis; axis.push_back(0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& in_shape = grad_gates->shape_view(); const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()}); float* in_tmp_buffer = tmp_buffer->mut_dptr(); const size_t in_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); float* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr() + in_tmp_buffer_bytes); const size_t out_tmp_buffer_bytes = GetCudaAlignedSize(reduced_shape.elem_cnt() * sizeof(float)); float* reduce_tmp_buffer = reinterpret_cast( tmp_buffer->mut_dptr() + in_tmp_buffer_bytes + out_tmp_buffer_bytes); const size_t reduce_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); CHECK_LE(in_tmp_buffer_bytes + out_tmp_buffer_bytes + reduce_tmp_buffer_bytes, tmp_buffer->shape_view().elem_cnt()); auto h2f = ep::primitive::NewPrimitive( ctx->device_type(), DataType::kFloat16, DataType::kFloat); CHECK(h2f); auto f2h = ep::primitive::NewPrimitive( ctx->device_type(), DataType::kFloat, DataType::kFloat16); CHECK(f2h); h2f->Launch(ctx->stream(), grad_gates->dptr(), in_tmp_buffer, in_shape.elem_cnt()); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, out_tmp_buffer), XpuVarNdarray(in_shape, in_tmp_buffer), XpuVarNdarray(in_shape, reduce_tmp_buffer)); user_op::Tensor* output_tensor = 
ctx->Tensor4ArgNameAndIndex("grad_bias", 0); f2h->Launch(ctx->stream(), out_tmp_buffer, output_tensor->mut_dptr(), output_tensor->shape_view().elem_cnt()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("fused_lstm_cell_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("grad_hy", 0) == GetDataType::value) && (user_op::HobDataType("grad_cy", 0) == GetDataType::value) && (user_op::HobDataType("cx", 0) == GetDataType::value) && (user_op::HobDataType("cy", 0) == GetDataType::value) && (user_op::HobDataType("workspace", 0) == GetDataType::value)) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { size_t tmp_bytes = 0; if (ctx->has_output("grad_bias", 0)) { const Shape& in_shape = ctx->InputTensorDesc("workspace", 0).shape(); const Shape& out_shape = ctx->OutputTensorDesc("grad_bias", 0).shape(); tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)) + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float))); } else { tmp_bytes = 0; } return tmp_bytes; }); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_matmul_bias_add_relu_dropout.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" #include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/user/kernels/random_seed_util.h" // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. #if CUDA_VERSION >= 11060 namespace oneflow { namespace { constexpr int32_t kVecSize = 4; constexpr int32_t kBlockSize = 256; constexpr int32_t kWarpSize = 32; union RandPack4 { uint4 storage; uint32_t elem[4]; // store curand4 return val. 
}; template __device__ void SetCublasBitMask(const IndexType aux_ld, const IndexType row, const IndexType col, int32_t thread_bitmask, int32_t* mask) { IndexType linear_index = row * aux_ld + col; IndexType mask_index = linear_index / kWarpSize; IndexType mask_offset = linear_index - mask_index * kWarpSize; int32_t bitmask = thread_bitmask << mask_offset; for (int stride = kWarpSize / (pack_size * 2); stride > 0; stride /= 2) { bitmask |= __shfl_down_sync(__activemask(), bitmask, stride, kWarpSize); } if (mask_offset == 0) { mask[mask_index] = bitmask; } } template __global__ void FusedVectorizedReluDropoutKernel(uint64_t seed, uint64_t offset, const IndexType elem_cnt, const int32_t aux_ld, const IndexType cols, const uint32_t rate, float scale, T* x, int32_t* mask) { IndexType global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, global_thread_id, offset, &state); using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; T t_scale = static_cast(scale); RandPack4 rand_uniform_pack4; T zero_val = static_cast(0.0); for (IndexType linear_index = global_thread_id * kVecSize, step = gridDim.x * blockDim.x * kVecSize; linear_index < elem_cnt; linear_index += step) { const IndexType row = linear_index / cols; const IndexType col = linear_index - row * cols; int32_t thread_bitmask = 0; rand_uniform_pack4.storage = curand4(&state); LoadType* x_load = reinterpret_cast(x + linear_index); LoadPack x_vec; x_vec.storage = *x_load; LoadPack out_vec; #pragma unroll for (int i = 0; i < kVecSize; i++) { bool relu_mask = true; if (relu) { // Relu relu_mask = x_vec.elem[i] >= zero_val; } // dropout bool mask_val = rand_uniform_pack4.elem[i] > rate; // Combined relu_mask, dropout_mask together. 
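// ----------------------------------------------------------------------------
// [Editor's illustrative sketch; not part of the original file.] Each element
// handled above goes through one fused decision: the ReLU mask (x >= 0) and
// the dropout mask (random draw above the rate threshold) are AND-ed into a
// single keep bit, the kept value is scaled by 1/(1-p) (inverted dropout), and
// the keep bit is recorded for the backward pass. The comparison happens in
// integer space; the launcher converts the float rate once into
// uint_rate = UINT_MAX * rate, so "keep" means rand_u32 > uint_rate. A
// host-side reference of the per-element rule, with hypothetical names:
#include <climits>
#include <cstdint>

struct ReluDropoutElemRef {
  float out;
  bool keep;  // the bit that ends up in the cuBLASLt aux bitmask
};

inline ReluDropoutElemRef FusedReluDropoutRef(float x, uint32_t rand_u32, float rate,
                                              bool apply_relu) {
  const uint32_t uint_rate = UINT_MAX * rate;  // mirrors the launcher's conversion
  const float scale = rate < 1.f ? 1.f / (1.f - rate) : 0.f;
  const bool relu_keep = !apply_relu || (x >= 0.f);
  const bool dropout_keep = rand_u32 > uint_rate;
  ReluDropoutElemRef e;
  e.keep = relu_keep && dropout_keep;
  e.out = e.keep ? x * scale : 0.f;
  return e;
}
// ----------------------------------------------------------------------------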
bool combined_mask = relu_mask && mask_val; // Cause half/bfloat16 cannot directily convert from bool, here we cast to float type first T t_combined_mask = static_cast(static_cast(combined_mask)); thread_bitmask |= (combined_mask << i); out_vec.elem[i] = x_vec.elem[i] * t_combined_mask * t_scale; } *(reinterpret_cast(x + linear_index)) = out_vec.storage; SetCublasBitMask(aux_ld, row, col, thread_bitmask, mask); } } template __global__ void FusedPaddedVectorizedReluDropoutKernel(uint64_t seed, uint64_t offset, const IndexType aligned32_elem_cnt, const int32_t aux_ld, const IndexType aligned32_cols, const IndexType cols, const uint32_t rate, float scale, T* x, int32_t* mask) { IndexType global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, global_thread_id, offset, &state); using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; T t_scale = static_cast(scale); RandPack4 rand_uniform_pack4; T zero_val = static_cast(0.0); for (IndexType linear_index = global_thread_id * kVecSize, step = gridDim.x * blockDim.x * kVecSize; linear_index < aligned32_elem_cnt; linear_index += step) { const IndexType row = linear_index / aligned32_cols; const IndexType col = linear_index - row * aligned32_cols; int32_t thread_bitmask = 0; if (col < cols) { const IndexType actual_index = row * cols + col; rand_uniform_pack4.storage = curand4(&state); LoadType* x_load = reinterpret_cast(x + actual_index); LoadPack x_vec; x_vec.storage = *x_load; LoadPack out_vec; #pragma unroll for (int i = 0; i < kVecSize; i++) { bool relu_mask = true; if (relu) { // Relu relu_mask = x_vec.elem[i] >= zero_val; } // dropout bool mask_val = rand_uniform_pack4.elem[i] > rate; // Combined relu_mask, dropout_mask together. bool combined_mask = relu_mask && mask_val; // Cause half/bfloat16 cannot directily convert from bool, here we cast to float type first T t_combined_mask = static_cast(static_cast(combined_mask)); thread_bitmask |= (combined_mask << i); out_vec.elem[i] = x_vec.elem[i] * t_combined_mask * t_scale; } *(reinterpret_cast(x + actual_index)) = out_vec.storage; } SetCublasBitMask(aux_ld, row, col, thread_bitmask, mask); } } template __global__ void FusedWarpReluDropoutKernel(uint64_t seed, uint64_t offset, const IndexType elem_cnt, const IndexType aux_ld, const IndexType rows, const IndexType cols, const uint32_t rate, float scale, T* x, int32_t* mask) { const int32_t lane_id = threadIdx.x; const IndexType global_warp_id = blockIdx.x * blockDim.y + threadIdx.y; const IndexType step = gridDim.x * blockDim.y; const IndexType global_thread_id = global_warp_id * kWarpSize + lane_id; curandStatePhilox4_32_10_t state; curand_init(seed, global_thread_id, offset, &state); T t_scale = static_cast(scale); T zero_val = static_cast(0.0); RandPack4 rand_uniform_pack4; for (IndexType row = global_warp_id; row < rows; row += step) { for (IndexType col = lane_id; col < cols; col += kWarpSize * kVecSize) { const IndexType linear_index = row * cols + col; rand_uniform_pack4.storage = curand4(&state); #pragma unroll for (int i = 0; i < kVecSize; i++) { int32_t thread_bitmask = 0; int32_t cur_col = col + i * kWarpSize; int32_t cur_linear_index = linear_index + i * kWarpSize; if (cur_col < cols) { T x_val = x[cur_linear_index]; const uint32_t rand_uniform_val = rand_uniform_pack4.elem[i]; bool relu_mask = true; if (relu) { // relu relu_mask = x_val >= zero_val; } // dropout bool mask_val = rand_uniform_val > rate; // Combined relu_mask, dropout_mask together. 
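// ----------------------------------------------------------------------------
// [Editor's illustrative sketch; not part of the original file.] The bitmask
// filled in by SetCublasBitMask follows the cuBLASLt ReLU-aux convention: one
// int32 word covers 32 consecutive columns of a row (rows are laid out with an
// aux_ld pitch so each row starts on a word boundary), and bit i of a word is
// the keep flag of the i-th column in that group. With pack_size = 4, eight
// neighbouring threads contribute to the same word, so each thread shifts its
// four keep bits to their position and the partial words are OR-merged with
// __shfl_down_sync before the thread at bit offset 0 stores the result. The
// merge step in isolation:
__device__ inline int32_t OrReducePartialMask(int32_t partial_bits, int pack_size) {
  // log2(32 / pack_size) shuffle steps fold the partial words of one
  // 32-column group onto its lowest participating lane.
  for (int stride = 32 / (pack_size * 2); stride > 0; stride /= 2) {
    partial_bits |= __shfl_down_sync(__activemask(), partial_bits, stride, 32);
  }
  return partial_bits;
}
// ----------------------------------------------------------------------------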
bool combined_mask = relu_mask && mask_val; thread_bitmask = combined_mask; // Cause half/bfloat16 cannot directily convert from bool, here we cast to float type // first T t_combined_mask = static_cast(static_cast(combined_mask)); T out_val = x_val * t_combined_mask * t_scale; x[cur_linear_index] = out_val; } int32_t warp_mask = __ballot_sync(__activemask(), thread_bitmask); if (lane_id == 0) { mask[(row * aux_ld + cur_col) / kWarpSize] = warp_mask; } } } } } template unsigned int ComputeGridSize(ep::Stream* stream, Func func, const int64_t elem_cnt, const int32_t block_size) { auto* cuda_stream = stream->As(); const int64_t pack_num = elem_cnt / kVecSize; const int32_t num_blocks = std::max(1, (pack_num + block_size - 1) / block_size); const int32_t multi_processor_count = cuda_stream->device_properties().multiProcessorCount; int max_active_blocks = 0; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func, block_size, /*shared_memory*/ 0)); return std::min(num_blocks, max_active_blocks * multi_processor_count); } uint64_t RoundUp(uint64_t x, uint64_t y) { return (x + y - 1) / y * y; } template cudaError_t LaunchFusedReluDropoutKernel(ep::CudaStream* stream, const std::shared_ptr& cuda_generator, const int64_t elem_cnt, const int32_t aux_ld, const int64_t rows, const int64_t cols, float rate, float scale, T* x, int32_t* mask) { uint64_t offset = 0; uint64_t seed = cuda_generator->current_seed(); const uint32_t uint_rate = UINT_MAX * rate; unsigned int grid_size = 0; if (cols % 32 == 0) { // Launch Elementwise Vectorized Kernel. if (elem_cnt < GetMaxVal()) { grid_size = ComputeGridSize(stream, FusedVectorizedReluDropoutKernel, elem_cnt, kBlockSize); uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize); offset = cuda_generator->get_philox_offset(inc_offset); FusedVectorizedReluDropoutKernel <<cuda_stream()>>>(seed, offset, elem_cnt, aux_ld, cols, uint_rate, scale, x, mask); } else { grid_size = ComputeGridSize(stream, FusedVectorizedReluDropoutKernel, elem_cnt, kBlockSize); uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize); offset = cuda_generator->get_philox_offset(inc_offset); FusedVectorizedReluDropoutKernel <<cuda_stream()>>>(seed, offset, elem_cnt, aux_ld, cols, uint_rate, scale, x, mask); } } else { if (cols % 4 == 0) { // Padding cols to align kWarpSize. const int64_t align32_cols = (cols + kWarpSize - 1) / kWarpSize * kWarpSize; const int64_t align32_elem_cnt = rows * align32_cols; if (align32_elem_cnt < GetMaxVal()) { grid_size = ComputeGridSize(stream, FusedPaddedVectorizedReluDropoutKernel, align32_elem_cnt, kBlockSize); uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize); offset = cuda_generator->get_philox_offset(inc_offset); FusedPaddedVectorizedReluDropoutKernel <<cuda_stream()>>>(seed, offset, align32_elem_cnt, aux_ld, align32_cols, cols, uint_rate, scale, x, mask); } else { grid_size = ComputeGridSize(stream, FusedPaddedVectorizedReluDropoutKernel, align32_elem_cnt, kBlockSize); uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize); offset = cuda_generator->get_philox_offset(inc_offset); FusedPaddedVectorizedReluDropoutKernel <<cuda_stream()>>>(seed, offset, align32_elem_cnt, aux_ld, align32_cols, cols, uint_rate, scale, x, mask); } } else { // Process a row by using a warp. 
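// ----------------------------------------------------------------------------
// [Editor's illustrative sketch; not part of the original file.] In the
// warp-per-row kernel above every lane decides a single keep bit for its
// column, so no shuffle loop is needed: __ballot_sync packs the 32 lane
// predicates into one mask word in a single instruction and lane 0 stores it.
// The same idea in isolation, with hypothetical names:
__device__ inline void StoreWarpKeepMask(bool keep, int lane_id, int32_t* mask_word) {
  // Bit k of the ballot corresponds to lane k of the warp.
  const unsigned ballot = __ballot_sync(__activemask(), keep ? 1 : 0);
  if (lane_id == 0) { *mask_word = static_cast<int32_t>(ballot); }
}
// ----------------------------------------------------------------------------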
dim3 block_dim(kWarpSize, kBlockSize / kWarpSize); if (elem_cnt < GetMaxVal()) { grid_size = ComputeGridSize(stream, FusedWarpReluDropoutKernel, elem_cnt, kBlockSize); uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize); offset = cuda_generator->get_philox_offset(inc_offset); FusedWarpReluDropoutKernel <<cuda_stream()>>>( seed, offset, elem_cnt, aux_ld, rows, cols, uint_rate, scale, x, mask); } else { grid_size = ComputeGridSize(stream, FusedWarpReluDropoutKernel, elem_cnt, kBlockSize); uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize); offset = cuda_generator->get_philox_offset(inc_offset); FusedWarpReluDropoutKernel <<cuda_stream()>>>( seed, offset, elem_cnt, aux_ld, rows, cols, uint_rate, scale, x, mask); } } } return cudaPeekAtLastError(); } template class FusedMatmulBiasAddReluDropoutKernel final : public user_op::OpKernel { public: FusedMatmulBiasAddReluDropoutKernel() = default; ~FusedMatmulBiasAddReluDropoutKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCublasFusedMLPKernelCache(); } std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA)); generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { /* Fused DenseActivation Layer. Assume we have two layers: A: (m, k) B: (n, k) need transpose C: (j, n) need transpose tmp: A matmul B(transpose), its shape is (m, n) out: tmp matmul C(transpose), its shape is (m, j) */ const int32_t weight_size = ctx->input_size("weights"); const int32_t bias_size = ctx->input_size("biases"); CHECK_EQ(weight_size, bias_size) << "The number of weight and bias is not equal!. "; auto* cuda_stream = ctx->stream()->As(); const auto* matmul_cache = CHECK_NOTNULL(dynamic_cast(cache)); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); bool skip_final_activation = ctx->Attr("skip_final_activation"); auto* fused_dropout_kernel_state = dynamic_cast(state); CHECK_NOTNULL(fused_dropout_kernel_state); const auto& generator = fused_dropout_kernel_state->generator(); CHECK_NOTNULL(generator); const auto device_index = ctx->stream()->device()->device_index(); std::shared_ptr cuda_generator = CHECK_JUST(generator->Get(device_index)); const std::vector dropout_rate_list = ctx->Attr>("dropout_rate_list"); const DataType data_type = out->data_type(); const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; const double alpha = 1.0; const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); const double beta = 0.0; const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); // Currently only support 2D matmul. 
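// ----------------------------------------------------------------------------
// [Editor's illustrative sketch; not part of the original file.] The
// ComputeGridSize helper used by the launcher above sizes the grid for these
// grid-stride kernels: enough blocks to cover elem_cnt / kVecSize packed
// items, but no more than the device can keep resident at once, which it
// queries through the occupancy API. The same idea with the plain CUDA runtime
// (no ep::Stream), using hypothetical names:
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

template<typename KernelFn>
int OccupancyBoundedGridSize(KernelFn kernel, int64_t work_items, int block_size,
                             int device_ordinal) {
  const int64_t wanted_blocks =
      std::max<int64_t>(1, (work_items + block_size - 1) / block_size);
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device_ordinal);
  int max_active_blocks_per_sm = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks_per_sm, kernel,
                                                block_size, /*dynamicSMemSize=*/0);
  const int64_t resident_blocks =
      static_cast<int64_t>(max_active_blocks_per_sm) * prop.multiProcessorCount;
  return static_cast<int>(std::min(wanted_blocks, resident_blocks));
}
// ----------------------------------------------------------------------------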
DimVector in_shape(2); x->shape_view().ToDimVector(&in_shape); DimVector weight_shape(2); const void* in_buf_ptr = x->dptr(); size_t offset = 0; for (int idx = 0; idx < weight_size; idx++) { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("biases", idx); user_op::Tensor* cublas_aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx); const int64_t batchsize = in_shape.at(0); const int64_t out_feature = weight->shape_view().At(0); weight->shape_view().ToDimVector(&weight_shape); size_t matmul_out_elem_cnt = batchsize * out_feature; InferMatmulCublasMNK(in_shape, weight_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; void* matmul_out_ptr; float rate = dropout_rate_list.at(idx); float scale = 0.0; const int32_t aux_ld = AlignReluAuxLd(out_feature); if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); } if (idx == weight_size - 1) { matmul_out_ptr = ctx->Tensor4ArgNameAndIndex("out", 0)->mut_dptr(); } else { matmul_out_ptr = ctx->Tensor4ArgNameAndIndex("hidden", idx)->mut_dptr(); } SetCublasAttr(matmul_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, bias->dptr(), /*aux_ptr=*/nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); OF_CUBLAS_CHECK(cublasLtMatmul( cuda_stream->cublas_lt_handle(), matmul_cache->operation_desc, &sp_alpha, weight->dptr(), matmul_cache->cublas_a_desc, in_buf_ptr, matmul_cache->cublas_b_desc, &sp_beta, matmul_out_ptr, matmul_cache->cublas_c_desc, matmul_out_ptr, matmul_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); if (idx != weight_size - 1 || !skip_final_activation || rate != 0.0f) { OF_CUDA_CHECK(cudaMemsetAsync(cublas_aux->mut_dptr(), 0, cublas_aux->shape_view().elem_cnt() * sizeof(int32_t), cuda_stream->cuda_stream())); } if (idx != weight_size - 1 || !skip_final_activation) { // If it's not last layer or it's last layer but need relu. OF_CUDA_CHECK((LaunchFusedReluDropoutKernel( cuda_stream, cuda_generator, matmul_out_elem_cnt, aux_ld, batchsize, out_feature, rate, scale, reinterpret_cast(matmul_out_ptr), reinterpret_cast(cublas_aux->mut_dptr())))); // Set relu_droput_out ptr as next layer's input. in_buf_ptr = matmul_out_ptr; // Set hidden_layer shape as next layer's input shape. in_shape.at(1) = out_feature; } else { if (rate == 0.0f) { // It's last layer and dropout_rate is 0.0f, we do not launch FusedReluDropoutKernel. break; } else { // skip_final_activation but need dropout. 
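// ----------------------------------------------------------------------------
// [Editor's illustrative sketch; not part of the original file.] The loop
// above chains the layers of the fused MLP: layer idx runs a cuBLASLt GEMM
// with the BIAS epilogue, multiplying the (m, k) input by weights[idx] of
// shape (n, k) transposed, writes into hidden[idx] (or out for the last
// layer), optionally applies the fused ReLU+dropout kernel in place, and then
// reuses that buffer and its width as the next layer's input (in_buf_ptr and
// in_shape.at(1) are updated). The shape bookkeeping alone, as a host-side
// sketch with hypothetical names:
#include <cstdint>
#include <vector>

struct MlpLayerGemm { int64_t m; int64_t n; int64_t k; };  // (m, k) x (n, k)^T -> (m, n)

inline std::vector<MlpLayerGemm> ChainFusedMlpShapes(
    int64_t batchsize, int64_t in_feature, const std::vector<int64_t>& weight_out_features) {
  std::vector<MlpLayerGemm> layers;
  int64_t k = in_feature;
  for (int64_t n : weight_out_features) {
    layers.push_back({batchsize, n, k});  // this layer's GEMM problem
    k = n;                                // the output width feeds the next layer
  }
  return layers;
}
// ----------------------------------------------------------------------------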
OF_CUDA_CHECK((LaunchFusedReluDropoutKernel( cuda_stream, cuda_generator, matmul_out_elem_cnt, aux_ld, batchsize, out_feature, rate, scale, reinterpret_cast(matmul_out_ptr), reinterpret_cast(cublas_aux->mut_dptr())))); } } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_MATMUL_BIAS_ADD_RELU_DROPOUT_KERNEL_GPU(cpp_type, data_type) \ REGISTER_USER_KERNEL("fused_matmul_bias_add_relu_dropout") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == data_type)); REGISTER_FUSED_MATMUL_BIAS_ADD_RELU_DROPOUT_KERNEL_GPU(float, DataType::kFloat) REGISTER_FUSED_MATMUL_BIAS_ADD_RELU_DROPOUT_KERNEL_GPU(half, DataType::kFloat16) #if CUDA_VERSION >= 11000 REGISTER_FUSED_MATMUL_BIAS_ADD_RELU_DROPOUT_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16) #endif } // namespace } // namespace oneflow #endif // CUDA_VERSION >= 11060 ================================================ FILE: oneflow/user/kernels/fused_matmul_bias_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" // same with cublas_fused_mlp_util.cuh #if CUDA_VERSION >= 11020 namespace oneflow { namespace { class FusedMatmulBiasKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedMatmulBiasKernel() = default; ~FusedMatmulBiasKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateCublasFusedMLPKernelCache(); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto* cuda_stream = ctx->stream()->As(); const auto* matmul_cache = CHECK_NOTNULL(dynamic_cast(cache)); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const user_op::Tensor* _add_to_output = (ctx->has_input("_add_to_output", 0)) ? ctx->Tensor4ArgNameAndIndex("_add_to_output", 0) : nullptr; const DataType data_type = out->data_type(); const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; const double alpha = ctx->Attr("alpha"); const double beta = (ctx->has_input("_add_to_output", 0)) ? 
ctx->Attr("beta") : 0.0; const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); DimVector in_shape({x->shape_view().Count(0, x->shape_view().NumAxes() - 1), x->shape_view().At(x->shape_view().NumAxes() - 1)}); DimVector weight_shape(2); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); weight->shape_view().ToDimVector(&weight_shape); InferMatmulCublasMNK(in_shape, weight_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_m, &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; void* y_ptr = ctx->Tensor4ArgNameAndIndex("out", 0)->mut_dptr(); SetCublasAttr(matmul_cache, cublas_compute_dtype, cuda_data_type, false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, bias->dptr(), nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); cublasLtMatmulPreference_t preference = nullptr; size_t workspace_size = cuda_stream->cublas_workspace_size(); OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference)); OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); int returned_results = 0; cublasLtMatmulHeuristicResult_t heuristic_result; OF_CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( cuda_stream->cublas_lt_handle(), matmul_cache->operation_desc, matmul_cache->cublas_a_desc, matmul_cache->cublas_b_desc, matmul_cache->cublas_c_desc, matmul_cache->cublas_c_desc, preference, 1, &heuristic_result, &returned_results)); CHECK_EQ(returned_results, 1); cublasLtMatmulPreferenceDestroy(preference); OF_CUBLAS_CHECK(cublasLtMatmul( cuda_stream->cublas_lt_handle(), matmul_cache->operation_desc, &sp_alpha, weight->dptr(), matmul_cache->cublas_a_desc, x->dptr(), matmul_cache->cublas_b_desc, &sp_beta, (_add_to_output == nullptr) ? y_ptr : _add_to_output->dptr(), matmul_cache->cublas_c_desc, y_ptr, matmul_cache->cublas_c_desc, &heuristic_result.algo, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(data_type) \ REGISTER_USER_KERNEL("fused_matmul_bias") \ .SetCreateFn() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == data_type)); REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kDouble); REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat); REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat16); #if CUDA_VERSION >= 11000 REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kBFloat16); #endif // CUDA_VERSION >= 11000 } // namespace } // namespace oneflow #endif // CUDA_VERSION >= 11020 ================================================ FILE: oneflow/user/kernels/fused_relu_dropout_grad_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/cuda/elementwise.cuh" namespace oneflow { namespace { constexpr int32_t kWarpSize = 32; template __global__ void VectorizedReluDropoutBitmaskBackwardKernel( const IndexType elem_cnt, const IndexType cols, const IndexType aux_ld, const float scale, const IndexType n_tail, const IndexType tail_offset, const T* dy, const int32_t* mask, T* dx) { int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; using LoadStoreType = cuda::elementwise::PackType; using LoadStorePack = cuda::elementwise::Pack; T t_scale = static_cast(scale); for (IndexType linear_pack_index = global_thread_id * pack_size; linear_pack_index < elem_cnt; linear_pack_index += gridDim.x * blockDim.x * pack_size) { const LoadStoreType* dy_load = reinterpret_cast(dy + linear_pack_index); LoadStorePack dy_vec; dy_vec.storage = *dy_load; LoadStorePack dx_vec; #pragma unroll for (int i = 0; i < pack_size; i++) { const IndexType linear_index = (linear_pack_index + i); const IndexType row = linear_index / cols; const IndexType col = linear_index - row * cols; const int32_t col_mod_warpsize = col % kWarpSize; const IndexType aux_idx = ((row * aux_ld) + col) / kWarpSize; bool is_positive = mask[aux_idx] & (1 << col_mod_warpsize); dx_vec.elem[i] = dy_vec.elem[i] * static_cast(static_cast(is_positive)) * static_cast(scale); } *(reinterpret_cast(dx + linear_pack_index)) = dx_vec.storage; } if (tail && global_thread_id < n_tail) { const IndexType tail_index = tail_offset + global_thread_id; const IndexType tail_row = tail_index / cols; const IndexType tail_col = tail_index - tail_row * cols; const IndexType tail_col_mod_warpsize = tail_col % kWarpSize; const IndexType tail_aux_idx = ((tail_row * aux_ld) + tail_col) / kWarpSize; bool is_positive = mask[tail_aux_idx] & (1 << tail_col_mod_warpsize); dx[tail_index] = dy[tail_index] * static_cast(static_cast(is_positive)) * static_cast(scale); } } template void LaunchVectorizedReluDropoutBackwardKernel(ep::Stream* stream, const int64_t elem_cnt, const int64_t cols, const int64_t aux_ld, float scale, const T* dy, const int32_t* mask, T* dx) { constexpr int pack_size = cuda::elementwise::PackSize(); const int64_t pack_num = elem_cnt / pack_size; const int64_t tail_offset = pack_num * pack_size; const int64_t n_tail = elem_cnt - tail_offset; const bool tail = n_tail > 0 ? 
true : false; if (tail) { if (elem_cnt < GetMaxVal()) { stream->As()->LaunchKernelDefaultWaves( (VectorizedReluDropoutBitmaskBackwardKernel), std::max(1, pack_num), elem_cnt, cols, aux_ld, scale, n_tail, tail_offset, dy, mask, dx); } else { stream->As()->LaunchKernelDefaultWaves( (VectorizedReluDropoutBitmaskBackwardKernel), std::max(1, pack_num), elem_cnt, cols, aux_ld, scale, n_tail, tail_offset, dy, mask, dx); } } else { if (elem_cnt < GetMaxVal()) { stream->As()->LaunchKernelDefaultWaves( (VectorizedReluDropoutBitmaskBackwardKernel), std::max(1, pack_num), elem_cnt, cols, aux_ld, scale, /*n_tail=*/0, tail_offset, dy, mask, dx); } else { stream->As()->LaunchKernelDefaultWaves( (VectorizedReluDropoutBitmaskBackwardKernel), std::max(1, pack_num), elem_cnt, cols, aux_ld, scale, /*n_tail=*/0, tail_offset, dy, mask, dx); } } } template class FusedReluDropoutGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedReluDropoutGradKernel() = default; ~FusedReluDropoutGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const float scale = ctx->Attr("scale"); const int64_t cols = dy->shape_view().At(1); const int64_t aux_ld = mask->shape_view().At(1) * 32; const int64_t elem_cnt = dy->shape_view().elem_cnt(); LaunchVectorizedReluDropoutBackwardKernel( ctx->stream(), elem_cnt, cols, aux_ld, scale, reinterpret_cast(dy->dptr()), mask->dptr(), reinterpret_cast(dx->mut_dptr())); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_RELU_DROPOUT_GRAD_KERNEL_GPU(cpp_type, data_type) \ REGISTER_USER_KERNEL("fused_relu_dropout_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == data_type)); REGISTER_FUSED_RELU_DROPOUT_GRAD_KERNEL_GPU(float, DataType::kFloat) REGISTER_FUSED_RELU_DROPOUT_GRAD_KERNEL_GPU(half, DataType::kFloat16) #if CUDA_VERSION >= 11000 REGISTER_FUSED_RELU_DROPOUT_GRAD_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16) #endif } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_rnn_cell_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_FUSED_RNN_CELL_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_FUSED_RNN_CELL_KERNEL_UTIL_H_ // NOTE(Liang Depeng): Modified from // https://github.com/pytorch/pytorch/blob/master/c10/macros/Macros.h#L256 #if defined(__CUDACC__) // constants from // (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) // The maximum number of threads per multiprocessor is 1024 for Turing // architecture (7.5), 1536 for Geforce Ampere (8.6), and 2048 for all other // architectures. You'll get warnings if you exceed these constants. Hence, the // following macros adjust the input values from the user to resolve potential // warnings. #if __CUDA_ARCH__ == 750 constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; #elif __CUDA_ARCH__ == 860 constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; #else constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; #endif // CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; // CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block // size. 256 is a good number for this fallback and should give good occupancy // and versatility across all architectures. constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; // NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it // turns out that although __launch_bounds__ can take constexpr, it // can't take a constexpr that has anything to do with templates. // Currently we use launch_bounds that depend on template arguments in // Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, OF_MAX_THREADS_PER_BLOCK // and OF_MIN_BLOCKS_PER_SM are kept as macros. // Suppose you were planning to write __launch_bounds__(a, b), based on your // performance tuning on a modern GPU. Instead, you should write // __launch_bounds__(OF_MAX_THREADS_PER_BLOCK(a), OF_MIN_BLOCKS_PER_SM(a, b)), // which will also properly respect limits on old architectures. #define OF_MAX_THREADS_PER_BLOCK(val) \ (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK) #define OF_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ ? (blocks_per_sm) \ : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block)-1) / (threads_per_block)))) // OF_LAUNCH_BOUNDS is analogous to __launch_bounds__ #define OF_LAUNCH_BOUNDS_0 \ __launch_bounds__(256, 4) // default launch bounds that should give good occupancy and // versatility across all architectures. #define OF_LAUNCH_BOUNDS_1(max_threads_per_block) \ __launch_bounds__((OF_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) #define OF_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ __launch_bounds__((OF_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ (OF_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) #endif #endif // ONEFLOW_USER_KERNELS_FUSED_RNN_CELL_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/fused_scale_mask_bias_softmax.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/framework/user_op_tensor.h" namespace oneflow { namespace { template struct LoadWithBias { LoadWithBias(const SRC* x_ptr, const SRC* mask_ptr, const SRC* bias_ptr, const SRC scale, int64_t row_stride, int64_t bias_stride, int64_t row_size) : x_ptr_(x_ptr), mask_ptr_(mask_ptr), bias_ptr_(bias_ptr), scale_(scale), row_stride_(row_stride), bias_stride_(bias_stride), row_size_(row_size) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { cuda::softmax::Pack x; const int64_t offset = (row * row_size_ + col) / N; x.storage = *(reinterpret_cast*>(x_ptr_) + offset); cuda::softmax::Pack mask; const int64_t m_offset = (row / row_stride_ * row_size_ + col) / N; mask.storage = *(reinterpret_cast*>(mask_ptr_) + m_offset); cuda::softmax::Pack bias; /* 1). bias_stride_ = 0 for bias: [1, num_heads, seqlen_q, seqlen_kv] x: [batch_size, num_heads, seqlen_q, seqlen_kv] 2). bias_stride_ > 0 for bias: [ensemble_batch, 1, num_heads, seqlen_q, seqlen_kv] x: [ensemble_batch, batch_size, num_heads, seqlen_q, seqlen_kv] here, bias_stride_ = batch_size, row_stride_ = num_heads * seqlen_q x could be viewed as [B1, B2, B3] and bias could be viewed as [B1, 1, B3] where B1 = ensemble_batch, B2 = batch_size = bias_stride_, B3 = num_heads * seqlen_q = row_stride_ For row in range [0, B1 * B2 * B3), i.e. [0, ensemble_batch * batch_size * num_heads * seqlen_q), b1 = row/(B2*B3), b2 = (row%(B2*B3))/B3, b3 = row%B3; after broadcast b2 will be 0 for bias. And finally the corresponding (broadcast) row of bias will be: `b1 * B3 + b3 = row/(B2*B3) * B3 + row%B3 = row / (bias_stride_ * row_stride_) * row_stride_ + row % row_stride_` (e.g. with B2 = 2, B3 = 3, row = 10: b1 = 1, b2 = 1, b3 = 1, so the bias row is 1 * 3 + 1 = 4, matching 10 / 6 * 3 + 10 % 3). */ int64_t bias_offset = (bias_stride_ > 0) ?
((row / (bias_stride_ * row_stride_) * row_stride_ + row % row_stride_) * row_size_ + col) / N : (row % row_stride_ * row_size_ + col) / N; bias.storage = *(reinterpret_cast*>(bias_ptr_) + bias_offset); #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(x.elem[i]) * static_cast(scale_) + static_cast(mask.elem[i]) + static_cast(bias.elem[i]); } } const SRC* x_ptr_; const SRC* mask_ptr_; const SRC* bias_ptr_; const SRC scale_; int64_t row_stride_; int64_t bias_stride_; int64_t row_size_; }; template struct LoadWithoutBias { LoadWithoutBias(const SRC* x_ptr, const SRC* mask_ptr, const SRC scale, int64_t row_stride, int64_t row_size) : x_ptr_(x_ptr), mask_ptr_(mask_ptr), scale_(scale), row_stride_(row_stride), row_size_(row_size) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { cuda::softmax::Pack x; const int64_t offset = (row * row_size_ + col) / N; x.storage = *(reinterpret_cast*>(x_ptr_) + offset); cuda::softmax::Pack mask; const int64_t m_offset = (row / row_stride_ * row_size_ + col) / N; mask.storage = *(reinterpret_cast*>(mask_ptr_) + m_offset); #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(x.elem[i]) * static_cast(scale_) + static_cast(mask.elem[i]); } } const SRC* x_ptr_; const SRC* mask_ptr_; const SRC scale_; int64_t row_stride_; int64_t row_size_; }; template::type> void LaunchFusedSoftmaxForwardKernel(cudaStream_t stream, T* out, const T* x, const T* mask, const T* bias, T scale, const int64_t row_stride, const int64_t bias_stride, const int64_t rows, const int64_t row_size) { cuda::softmax::DirectStore store(out, row_size); if (bias != nullptr) { LoadWithBias load(x, mask, bias, scale, row_stride, bias_stride, row_size); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( stream, load, store, rows, row_size))); } else { LoadWithoutBias load(x, mask, scale, row_stride, row_size); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( stream, load, store, rows, row_size))); } }; template struct GradStore { GradStore(DST* dx, const SRC scale, int64_t row_size) : dx(dx), scale(scale), row_size(row_size) {} template __device__ void store(const SRC* dout, int64_t row, int64_t col) const { cuda::softmax::Pack x; const int64_t offset = (row * row_size + col) / N; #pragma unroll for (int i = 0; i < N; ++i) { x.elem[i] = static_cast(dout[i]) * static_cast(scale); } *(reinterpret_cast*>(dx) + offset) = x.storage; } DST* dx; const SRC scale; int64_t row_size; }; template::type> void LaunchSoftmaxBackwardKernel(cudaStream_t stream, T* dx, const T* y, const T* dy, T scale, const int64_t rows, const int64_t row_size) { GradStore store(dx, scale, row_size); cuda::softmax::DirectLoad load_y(y, row_size); cuda::softmax::DirectLoad load_dy(dy, row_size); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad( stream, load_y, load_dy, store, rows, row_size))); }; } // namespace template class FusedScaleMaskBiasSoftmaxKernel final : public user_op::OpKernel { public: FusedScaleMaskBiasSoftmaxKernel() = default; ~FusedScaleMaskBiasSoftmaxKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); const T scale = ctx->Attr("scale"); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); auto x_shape = x->shape_view(); auto axes = x_shape.NumAxes(); /* * axes=3 for x: [batch_size, num_heads, seq], mask: [batch_size, 1, seq], no bias here * 
axes=4 for x: [batch_size, num_heads, seq_len_q, seq_len_kv] * mask: [batch_size, 1, 1, seq_len_kv] * bias: [1, num_heads, seq_len_q, seq_len_kv] * axes=5 for x: [ensemble_batch, batch_size, num_heads, seq_len_q, seq_len_kv] * mask: [ensemble_batch, batch_size, 1, 1, seq_len_kv] * bias: [ensemble_batch, 1, num_heads, seq_len_q, seq_len_kv] * `axes=5` is equivalent to `axes=4` when ensemble_batch = 1 . * * row_stride is used for computing `mask` stride and * bias_stride for computing `bias` stride * row_stride is num_heads (for `axes=3`) or num_heads * seq_len_q (for `axes=4` & `axes=5`) * bias_stride is 0 (for `axes=4`) or batch_size (for `axes=5`) * row_size = seq_len_k (the last dimension of `x`) */ CHECK(axes == 3 || axes == 4 || axes == 5); auto mask_shape = mask->shape_view(); CHECK(mask_shape.NumAxes() == axes); const int row_size = x_shape.At(axes - 1); const int rows = x_shape.elem_cnt() / row_size; int row_stride = 1; for (int i = axes - 2; i >= 0; i--) { if (mask_shape.At(i) == 1) row_stride *= x_shape.At(i); else break; } user_op::Tensor* bias = nullptr; int64_t bias_stride = 0; if (ctx->has_input("bias", 0)) { bias = ctx->Tensor4ArgNameAndIndex("bias", 0); if (axes == 5 && x_shape.At(0) != 1) bias_stride = x_shape.At(1); } LaunchFusedSoftmaxForwardKernel(ctx->stream()->As()->cuda_stream(), out->mut_dptr(), x->dptr(), mask->dptr(), ctx->has_input("bias", 0) ? bias->dptr() : nullptr, scale, row_stride, bias_stride, rows, row_size); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_KERNEL_GPU(dtype) \ REGISTER_USER_KERNEL("fused_scale_mask_bias_softmax") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_KERNEL_GPU(half) #if CUDA_VERSION >= 11000 REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_KERNEL_GPU(nv_bfloat16) #endif REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_KERNEL_GPU(float) template class FusedScaleMaskBiasSoftmaxGradKernel final : public user_op::OpKernel { public: FusedScaleMaskBiasSoftmaxGradKernel() = default; ~FusedScaleMaskBiasSoftmaxGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const T scale = ctx->Attr("scale"); auto y_shape = y->shape_view(); const int64_t axes = y_shape.NumAxes(); int64_t row_size = y_shape.At(axes - 1); int64_t rows = y_shape.elem_cnt() / row_size; LaunchSoftmaxBackwardKernel(ctx->stream()->As()->cuda_stream(), dx->mut_dptr(), y->dptr(), dy->dptr(), scale, rows, row_size); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_GRAD_KERNEL_GPU(dtype) \ REGISTER_USER_KERNEL("fused_scale_mask_bias_softmax_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_GRAD_KERNEL_GPU(half) #if CUDA_VERSION >= 11000 REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_GRAD_KERNEL_GPU(nv_bfloat16) #endif REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_GRAD_KERNEL_GPU(float) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_scale_mask_softmax.cu 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/kernels/fused_softmax.cuh" namespace oneflow { namespace { template void LaunchBroadcastForwardKernel(cudaStream_t stream, const T* x, T* y, const MASK* mask, const int64_t elem_cnt, const int64_t rows, const int64_t cols, const float fill, const float scale, const int64_t* input_dims, const int64_t* mask_dims) { NdIndexOffsetHelper input_index_helper(input_dims); NdIndexOffsetHelper mask_index_helper(mask_dims); cuda::fused_softmax::BroadcastMaskSoftmaxParams params; params.src_index_helper = input_index_helper; params.mask_index_helper = mask_index_helper; params.mask_dims = mask_dims; params.row_size = cols; params.fill = fill; params.scale = scale; cuda::fused_softmax::BroadcastScaleMaskLoad load(x, mask, params); cuda::softmax::DirectStore store(y, cols); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( stream, load, store, rows, cols))); } template void LaunchElementwiseForwardKernel(cudaStream_t stream, const T* x, T* y, const MASK* mask, const int64_t rows, const int64_t cols, const float fill, const float scale) { cuda::fused_softmax::ElementwiseMaskSoftmaxParams params; params.row_size = cols; params.fill = fill; params.scale = scale; cuda::fused_softmax::ElementwiseScaleMaskLoad load(x, mask, params); cuda::softmax::DirectStore store(y, cols); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( stream, load, store, rows, cols))); } template void LaunchBroadcastBackwardKernel(cudaStream_t stream, const T* y, const T* dy, T* dx, const MASK* mask, const int64_t elem_cnt, const int64_t rows, const int64_t cols, const float fill, const float scale, const int64_t* input_dims, const int64_t* mask_dims) { NdIndexOffsetHelper input_index_helper(input_dims); NdIndexOffsetHelper mask_index_helper(mask_dims); cuda::fused_softmax::BroadcastMaskSoftmaxParams params; params.src_index_helper = input_index_helper; params.mask_index_helper = mask_index_helper; params.mask_dims = mask_dims; params.row_size = cols; params.fill = fill; params.scale = scale; cuda::softmax::DirectLoad load_y(y, cols); cuda::softmax::DirectLoad load_dy(dy, cols); cuda::fused_softmax::BroadcastScaleMaskStore store( dx, mask, params); OF_CUDA_CHECK(( cuda::softmax::DispatchSoftmaxGrad(stream, load_y, load_dy, store, rows, cols))); } template void LaunchElementwiseBackwardKernel(cudaStream_t stream, const T* y, const T* dy, T* dx, const MASK* mask, const int64_t rows, const int64_t cols, const float fill, const float scale) { cuda::fused_softmax::ElementwiseMaskSoftmaxParams params; params.row_size = cols; params.fill = fill; params.scale = scale; cuda::softmax::DirectLoad load_y(y, cols); cuda::softmax::DirectLoad load_dy(dy, cols); cuda::fused_softmax::ElementwiseScaleMaskStore store(dx, mask, params); OF_CUDA_CHECK(( 
cuda::softmax::DispatchSoftmaxGrad(stream, load_y, load_dy, store, rows, cols))); } constexpr int32_t kMaxNumDims = 5; template class FusedScaleMaskSoftmaxKernel final : public user_op::OpKernel { public: FusedScaleMaskSoftmaxKernel() = default; ~FusedScaleMaskSoftmaxKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const float mask_fill_value = ctx->Attr("mask_fill_value"); const float scale_value = ctx->Attr("scale_value"); const ShapeView& x_shape = x->shape_view(); const ShapeView& mask_shape = mask->shape_view(); CHECK_GE(x_shape.NumAxes(), 2); const int64_t elem_cnt = x_shape.elem_cnt(); const int64_t cols = x_shape.At(x_shape.NumAxes() - 1); const int64_t rows = x_shape.Count(0, x_shape.NumAxes() - 1); const size_t num_input_dims = x_shape.NumAxes(); const int64_t* input_dims = x_shape.ptr(); const size_t num_mask_dims = mask_shape.NumAxes(); const int64_t* mask_dims = mask_shape.ptr(); using ComputeType = typename cuda::softmax::DefaultComputeType::type; size_t simplified_num_dims = 0; int64_t simplified_input_dims[kMaxNumDims]; int64_t simplified_mask_dims[kMaxNumDims]; cuda::fused_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims, mask_dims, &simplified_num_dims, simplified_input_dims, simplified_mask_dims); if (simplified_num_dims == 1) { LaunchElementwiseForwardKernel( ctx->stream()->As()->cuda_stream(), x->dptr(), y->mut_dptr(), mask->dptr(), rows, cols, mask_fill_value, scale_value); } #define DEFINE_ONE_ELIF(dims) \ else if (simplified_num_dims == dims) { \ LaunchBroadcastForwardKernel( \ ctx->stream()->As()->cuda_stream(), x->dptr(), y->mut_dptr(), \ mask->dptr(), elem_cnt, rows, cols, mask_fill_value, scale_value, \ simplified_input_dims, simplified_mask_dims); \ } DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) #undef DEFINE_ONE_ELIF else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class FusedScaleMaskSoftmaxGradKernel final : public user_op::OpKernel { public: FusedScaleMaskSoftmaxGradKernel() = default; ~FusedScaleMaskSoftmaxGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const float scale_value = ctx->Attr("scale_value"); const float mask_fill_value = static_cast(0.0); const ShapeView& dy_shape = dy->shape_view(); const ShapeView& mask_shape = mask->shape_view(); CHECK_GE(dy_shape.NumAxes(), 2); const int64_t elem_cnt = dy_shape.elem_cnt(); const int64_t cols = dy_shape.At(dy_shape.NumAxes() - 1); const int64_t rows = dy_shape.Count(0, dy_shape.NumAxes() - 1); const int64_t* input_dims = dy_shape.ptr(); const size_t num_input_dims = dy_shape.NumAxes(); const int64_t* mask_dims = mask_shape.ptr(); const size_t num_mask_dims = mask_shape.NumAxes(); using ComputeType = typename cuda::softmax::DefaultComputeType::type; size_t simplified_num_dims = 0; int64_t simplified_input_dims[kMaxNumDims]; int64_t simplified_mask_dims[kMaxNumDims]; 
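// SimplifyBroadcastDims (defined in fused_softmax.cuh further below) collapses adjacent axes that
// share the same broadcast pattern so the launch only has to dispatch on a small, fixed number of
// dims. For example, dy of shape [16, 8, 128, 128] with mask [16, 1, 1, 128] simplifies to input
// dims [16, 1024, 128] and mask dims [16, 1, 128] (the two mask-broadcast axes merge), so
// simplified_num_dims == 3 and the 3-dim broadcast kernel below is selected.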
cuda::fused_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims, mask_dims, &simplified_num_dims, simplified_input_dims, simplified_mask_dims); if (simplified_num_dims == 1) { LaunchElementwiseBackwardKernel( ctx->stream()->As()->cuda_stream(), y->dptr(), dy->dptr(), dx->mut_dptr(), mask->dptr(), rows, cols, mask_fill_value, scale_value); } #define DEFINE_ONE_ELIF(dims) \ else if (simplified_num_dims == dims) { \ LaunchBroadcastBackwardKernel( \ ctx->stream()->As()->cuda_stream(), y->dptr(), dy->dptr(), \ dx->mut_dptr(), mask->dptr(), elem_cnt, rows, cols, mask_fill_value, scale_value, \ simplified_input_dims, simplified_mask_dims); \ } DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) #undef DEFINE_ONE_ELIF else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(dtype, mask_dtype) \ REGISTER_USER_KERNEL("fused_scale_mask_softmax") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value) \ && (user_op::HobDataType("mask", 0) == GetDataType::value)); REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(half, bool) REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(float, bool) #undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL #define REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(dtype, mask_dtype) \ REGISTER_USER_KERNEL("fused_scale_mask_softmax_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && (user_op::HobDataType("mask", 0) == GetDataType::value)); REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(half, bool) REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(float, bool) #undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_scale_mask_softmax_dropout.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/kernels/fused_softmax.cuh" namespace oneflow { namespace { template void LaunchBroadcastForwardKernel(cudaStream_t stream, const T* x, T* y, T* softmax_y, const MASK* mask, const bool* dropout_mask, const int64_t elem_cnt, const int64_t rows, const int64_t cols, const float fill, const float scale, const float dropout_scale, const int64_t* input_dims, const int64_t* mask_dims) { cuda::fused_softmax::DropoutStore store(y, softmax_y, dropout_mask, cols, dropout_scale); NdIndexOffsetHelper input_index_helper(input_dims); NdIndexOffsetHelper mask_index_helper(mask_dims); cuda::fused_softmax::BroadcastMaskSoftmaxParams params; params.src_index_helper = input_index_helper; params.mask_index_helper = mask_index_helper; params.mask_dims = mask_dims; params.row_size = cols; params.fill = fill; params.scale = scale; cuda::fused_softmax::BroadcastScaleMaskLoad load(x, mask, params); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( stream, load, store, rows, cols))); } template void LaunchElementwiseForwardKernel(cudaStream_t stream, const T* x, T* y, T* softmax_y, const MASK* mask, const bool* dropout_mask, const int64_t rows, const int64_t cols, const float fill, const float scale, const float dropout_scale) { cuda::fused_softmax::ElementwiseMaskSoftmaxParams params; params.row_size = cols; params.fill = fill; params.scale = scale; cuda::fused_softmax::ElementwiseScaleMaskLoad load(x, mask, params); cuda::fused_softmax::DropoutStore store(y, softmax_y, dropout_mask, cols, dropout_scale); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( stream, load, store, rows, cols))); } template void LaunchBroadcastBackwardKernel(cudaStream_t stream, const T* softmax_y, const T* dy, T* dx, const MASK* mask, const bool* dropout_mask, const int64_t elem_cnt, const int64_t rows, const int64_t cols, const float fill, const float scale, const float dropout_scale, const int64_t* input_dims, const int64_t* mask_dims) { cuda::fused_softmax::MaskScaleLoad load_dy(dy, dropout_mask, cols, dropout_scale); NdIndexOffsetHelper input_index_helper(input_dims, num_dims); NdIndexOffsetHelper mask_index_helper(mask_dims, num_dims); cuda::fused_softmax::BroadcastMaskSoftmaxParams params; params.src_index_helper = input_index_helper; params.mask_index_helper = mask_index_helper; params.mask_dims = mask_dims; params.row_size = cols; params.fill = fill; params.scale = scale; cuda::softmax::DirectLoad load_softmax_y(softmax_y, cols); cuda::fused_softmax::BroadcastScaleMaskStore store( dx, mask, params); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad( stream, load_softmax_y, load_dy, store, rows, cols))); } template void LaunchElementwiseBackwardKernel(cudaStream_t stream, const T* softmax_y, const T* dy, T* dx, const MASK* mask, const bool* dropout_mask, const int64_t rows, const int64_t cols, const float fill, const float scale, const float dropout_scale) { cuda::fused_softmax::ElementwiseMaskSoftmaxParams params; params.row_size = cols; params.fill = fill; params.scale = scale; cuda::softmax::DirectLoad load_softmax_y(softmax_y, cols); cuda::fused_softmax::MaskScaleLoad load_dy(dy, dropout_mask, cols, dropout_scale); cuda::fused_softmax::ElementwiseScaleMaskStore store(dx, mask, params); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad( stream, load_softmax_y, load_dy, store, rows, cols))); } constexpr int32_t kMaxNumDims = 5; template class FusedScaleMaskSoftmaxDropoutKernel 
final : public user_op::OpKernel { public: FusedScaleMaskSoftmaxDropoutKernel() = default; ~FusedScaleMaskSoftmaxDropoutKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); const user_op::Tensor* dropout_mask = ctx->Tensor4ArgNameAndIndex("dropout_mask", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const float mask_fill_value = ctx->Attr("mask_fill_value"); const float scale_value = ctx->Attr("scale_value"); const float dropout_scale_value = ctx->Attr("dropout_scale_value"); user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex("softmax_y", 0); const ShapeView& x_shape = x->shape_view(); const ShapeView& mask_shape = mask->shape_view(); CHECK_GE(x_shape.NumAxes(), 2); const int64_t elem_cnt = x_shape.elem_cnt(); const int64_t cols = x_shape.At(x_shape.NumAxes() - 1); const int64_t rows = x_shape.Count(0, x_shape.NumAxes() - 1); const size_t num_input_dims = x_shape.NumAxes(); const int64_t* input_dims = x_shape.ptr(); const size_t num_mask_dims = mask_shape.NumAxes(); const int64_t* mask_dims = mask_shape.ptr(); using ComputeType = typename cuda::softmax::DefaultComputeType::type; size_t simplified_num_dims = 0; int64_t simplified_input_dims[kMaxNumDims]; int64_t simplified_mask_dims[kMaxNumDims]; cuda::fused_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims, mask_dims, &simplified_num_dims, simplified_input_dims, simplified_mask_dims); if (simplified_num_dims == 1) { LaunchElementwiseForwardKernel( ctx->stream()->As()->cuda_stream(), x->dptr(), y->mut_dptr(), softmax_y->mut_dptr(), mask->dptr(), dropout_mask->dptr(), rows, cols, mask_fill_value, scale_value, dropout_scale_value); } #define DEFINE_ONE_ELIF(dims) \ else if (simplified_num_dims == dims) { \ LaunchBroadcastForwardKernel( \ ctx->stream()->As()->cuda_stream(), x->dptr(), y->mut_dptr(), \ softmax_y->mut_dptr(), mask->dptr(), dropout_mask->dptr(), elem_cnt, rows, \ cols, mask_fill_value, scale_value, dropout_scale_value, simplified_input_dims, \ simplified_mask_dims); \ } DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) #undef DEFINE_ONE_ELIF else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class FusedScaleMaskSoftmaxDropoutGradKernel final : public user_op::OpKernel { public: FusedScaleMaskSoftmaxDropoutGradKernel() = default; ~FusedScaleMaskSoftmaxDropoutGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex("softmax_y", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); const user_op::Tensor* dropout_mask = ctx->Tensor4ArgNameAndIndex("dropout_mask", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const float mask_fill_value = static_cast(0.0); const float scale_value = ctx->Attr("scale_value"); const float dropout_scale_value = ctx->Attr("dropout_scale_value"); const ShapeView& dy_shape = dy->shape_view(); const int64_t elem_cnt = dy_shape.elem_cnt(); const ShapeView& mask_shape = mask->shape_view(); CHECK_GE(dy_shape.NumAxes(), 2); const int64_t cols = dy_shape.At(dy_shape.NumAxes() - 1); const int64_t rows = dy_shape.Count(0, dy_shape.NumAxes() - 
1); const int64_t* input_dims = dy_shape.ptr(); const size_t num_input_dims = dy_shape.NumAxes(); const int64_t* mask_dims = mask_shape.ptr(); const size_t num_mask_dims = mask_shape.NumAxes(); using ComputeType = typename cuda::softmax::DefaultComputeType::type; cuda::softmax::DirectLoad load_softmax_y(softmax_y->dptr(), cols); size_t simplified_num_dims = 0; int64_t simplified_input_dims[kMaxNumDims]; int64_t simplified_mask_dims[kMaxNumDims]; cuda::fused_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims, mask_dims, &simplified_num_dims, simplified_input_dims, simplified_mask_dims); if (simplified_num_dims == 1) { LaunchElementwiseBackwardKernel( ctx->stream()->As()->cuda_stream(), softmax_y->dptr(), dy->dptr(), dx->mut_dptr(), mask->dptr(), dropout_mask->dptr(), rows, cols, mask_fill_value, scale_value, dropout_scale_value); } #define DEFINE_ONE_ELIF(dims) \ else if (simplified_num_dims == dims) { \ LaunchBroadcastBackwardKernel( \ ctx->stream()->As()->cuda_stream(), softmax_y->dptr(), dy->dptr(), \ dx->mut_dptr(), mask->dptr(), dropout_mask->dptr(), elem_cnt, rows, cols, \ static_cast(0.0), ctx->Attr("scale_value"), \ ctx->Attr("dropout_scale_value"), simplified_input_dims, simplified_mask_dims); \ } DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) #undef DEFINE_ONE_ELIF else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(dtype, mask_dtype) \ REGISTER_USER_KERNEL("fused_scale_mask_softmax_dropout") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value) \ && (user_op::HobDataType("mask", 0) == GetDataType::value)); REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(half, bool) REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(float, bool) #undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL #define REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_GRAD_KERNEL(dtype, mask_dtype) \ REGISTER_USER_KERNEL("fused_scale_mask_softmax_dropout_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value) \ && (user_op::HobDataType("mask", 0) == GetDataType::value)); REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_GRAD_KERNEL(half, bool) REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_GRAD_KERNEL(float, bool) #undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_GRAD_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/slice_util.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { inline cublasOperation_t GetCublasOp(char op) { switch (op) { case 'n': case 'N': { return CUBLAS_OP_N; } case 't': case 'T': { return CUBLAS_OP_T; } case 'c': case 'C': { return CUBLAS_OP_C; } default: { UNIMPLEMENTED(); } } return CUBLAS_OP_N; } template struct CudaDataTypeTrait; template<> struct CudaDataTypeTrait { const static cudaDataType_t value = CUDA_R_32F; }; template<> struct CudaDataTypeTrait { const static cudaDataType_t value = CUDA_R_16F; }; template void CublasBatchGemm(ep::CudaStream* stream, char transa, char transb, int64_t m, int64_t n, int64_t k, T alpha, const T* a, int64_t lda, int64_t stridea, const T* b, int64_t ldb, int64_t strideb, T beta, T* c, int64_t ldc, int64_t stridec, int64_t batch_size) { cublasOperation_t opa = GetCublasOp(transa); cublasOperation_t opb = GetCublasOp(transb); if (CUDA_VERSION >= 9010 && stream->cuda_arch() >= 500) { #if CUDA_VERSION >= 9010 cudaDataType_t data_type = CudaDataTypeTrait::value; OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx( stream->cublas_handle(), opa, opb, m, n, k, reinterpret_cast(&alpha), reinterpret_cast(a), data_type, lda, stridea, reinterpret_cast(b), data_type, ldb, strideb, reinterpret_cast(&beta), reinterpret_cast(c), data_type, ldc, stridec, batch_size, data_type, CUBLAS_GEMM_DEFAULT)); #else UNIMPLEMENTED(); #endif } } #if CUDA_VERSION >= 9010 template<> void CublasBatchGemm(ep::CudaStream* stream, char transa, char transb, int64_t m, int64_t n, int64_t k, half alpha, const half* a, int64_t lda, int64_t stridea, const half* b, int64_t ldb, int64_t strideb, half beta, half* c, int64_t ldc, int64_t stridec, int64_t batch_size) { using comp_t = float; cublasOperation_t opa = GetCublasOp(transa); cublasOperation_t opb = GetCublasOp(transb); if (stream->cuda_arch() >= 500) { float alpha_f = static_cast(alpha); float beta_f = static_cast(beta); #if CUDA_VERSION >= 11000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; #else cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; #endif cudaDataType_t data_type = CudaDataTypeTrait::value; cudaDataType_t comp_type = CudaDataTypeTrait::value; OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx( stream->cublas_handle(), opa, opb, m, n, k, &alpha_f, reinterpret_cast(a), data_type, lda, stridea, reinterpret_cast(b), data_type, ldb, strideb, &beta_f, reinterpret_cast(c), data_type, ldc, stridec, batch_size, comp_type, algo)); } } template<> void CublasBatchGemm(ep::CudaStream* stream, char transa, char transb, int64_t m, int64_t n, int64_t k, float16 alpha, const float16* a, int64_t lda, int64_t stridea, const float16* b, int64_t ldb, int64_t strideb, float16 beta, float16* c, int64_t ldc, int64_t stridec, int64_t batch_size) { CublasBatchGemm(stream, transa, transb, m, n, k, static_cast(alpha), reinterpret_cast(a), lda, stridea, reinterpret_cast(b), ldb, strideb, static_cast(beta), reinterpret_cast(c), ldc, stridec, batch_size); } #endif // CUDA_VERSION >= 9010 template void BatchedGemm(ep::Stream* stream, char opa, char opb, int64_t m, int64_t n, int64_t k, float alpha, const T* a, int64_t lda, int64_t stridea, const T* b, int64_t ldb, int64_t strideb, float beta, T* c, int64_t ldc, int64_t stridec, int64_t batch_size) { // swap m and n, a and b to convert from row-major to col-major CublasBatchGemm(stream->As(), opb, 
opa, n, m, k, static_cast(alpha), b, ldb, strideb, a, lda, stridea, static_cast(beta), c, ldc, stridec, batch_size); } SliceParams ConstructSliceParams4Value(int64_t seq_len, int64_t batch_size, int64_t num_heads, int64_t head_size) { // slice (s, b, n, 3, h) to (s, b, n, 1, h) SliceParams params; params.ndim = 4; params.dims[0] = seq_len; params.dims[1] = batch_size; params.dims[2] = num_heads; params.dims[3] = 3 * head_size; params.start[0] = 0; params.start[1] = 0; params.start[2] = 0; params.start[3] = 2 * head_size; params.step[0] = 1; params.step[1] = 1; params.step[2] = 1; params.step[3] = 1; params.size[0] = seq_len; params.size[1] = batch_size; params.size[2] = num_heads; params.size[3] = head_size; return params; } template void TransposeGpu(ep::Stream* stream, DataType data_type, const ShapeView& in_shape, const ShapeView& out_shape, const std::vector& perm, const T* in, T* out) { CHECK_EQ(in_shape.NumAxes(), out_shape.NumAxes()); int32_t num_axes = in_shape.NumAxes(); CHECK_EQ(num_axes, perm.size()); for (int i = 0; i < perm.size(); ++i) { CHECK_EQ(in_shape.At(perm[i]), out_shape.At(i)); } auto transpose = ep::primitive::NewPrimitive(stream->device_type(), in_shape.NumAxes()); CHECK(transpose); transpose->Launch(stream, data_type, in_shape.NumAxes(), in_shape.ptr(), in, perm.data(), out); } template class FusedSelfAttentionQueryMulKeyAndValueGpuKernel final : public user_op::OpKernel { public: FusedSelfAttentionQueryMulKeyAndValueGpuKernel() = default; ~FusedSelfAttentionQueryMulKeyAndValueGpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* h_tensor = ctx->Tensor4ArgNameAndIndex("hidden_states", 0); int64_t seq_len = h_tensor->shape_view().At(0); int64_t batch_size = h_tensor->shape_view().At(1); int64_t hidden_size = h_tensor->shape_view().At(2); int64_t head_size = ctx->Attr("head_size"); int64_t num_heads = hidden_size / (3 * head_size); int64_t ld = batch_size * hidden_size; int64_t stride = 3 * head_size; int64_t k_offset = head_size; // q * k: (sq, b, n, h) x (sk, b, n, h) => (b, n, sq, h) x (b, n, sk, h) // => (b, n, sq, h) x (b, n, h, sk) -> (b, n, sq, sk) float alpha = ctx->Attr("alpha"); user_op::Tensor* qmk_tensor = ctx->Tensor4ArgNameAndIndex("query_mul_key", 0); const T* q_dptr = h_tensor->dptr(); const T* k_dptr = h_tensor->dptr() + k_offset; BatchedGemm(ctx->stream(), 'N', 'T', seq_len, seq_len, head_size, alpha, q_dptr, ld, stride, k_dptr, ld, stride, 0.0f, qmk_tensor->mut_dptr(), seq_len, seq_len * seq_len, batch_size * num_heads); // slice v user_op::Tensor* tmp_v_tensor = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* v_tensor = ctx->Tensor4ArgNameAndIndex("value", 0); SliceParams params = ConstructSliceParams4Value(seq_len, batch_size, num_heads, head_size); SliceKernelUtil::Forward(ctx->stream(), params, h_tensor->dptr(), tmp_v_tensor->mut_dptr()); // v from (s, b, n, h) transpose to (b, n, s, h) Shape value_shape({seq_len, batch_size, num_heads, head_size}); TransposeGpu(ctx->stream(), h_tensor->data_type(), value_shape, v_tensor->shape_view(), {1, 2, 0, 3}, tmp_v_tensor->dptr(), v_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel final : public user_op::OpKernel { public: FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel() = default; ~FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel() override = default; private: 
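// Backward of the fused QKV op: value_grad (b, n, s, h) is transposed back to (s, b, n, h) and
// scattered into the value slice of hidden_states_grad via SliceKernelUtil::Backward; then
// grad_q = alpha * d(qmk) * k and grad_k = alpha * d(qmk)^T * q are written into the query and
// key slices of hidden_states_grad with strided batched GEMMs.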
using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* v_grad_tensor = ctx->Tensor4ArgNameAndIndex("value_grad", 0); const user_op::Tensor* qmk_grad_tensor = ctx->Tensor4ArgNameAndIndex("query_mul_key_grad", 0); const user_op::Tensor* h_tensor = ctx->Tensor4ArgNameAndIndex("hidden_states", 0); user_op::Tensor* tmp_v_tensor = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* h_grad_tensor = ctx->Tensor4ArgNameAndIndex("hidden_states_grad", 0); float alpha = ctx->Attr("alpha"); int64_t seq_len = h_grad_tensor->shape_view().At(0); int64_t batch_size = h_grad_tensor->shape_view().At(1); int64_t hidden_size = h_grad_tensor->shape_view().At(2); int64_t num_heads = v_grad_tensor->shape_view().At(1); int64_t head_size = v_grad_tensor->shape_view().At(3); int64_t ld = batch_size * hidden_size; int64_t stride = 3 * head_size; CHECK_EQ(hidden_size, num_heads * stride); // transpose from (b, n, s, h) to (s, b, n, h) Shape value_shape({seq_len, batch_size, num_heads, head_size}); TransposeGpu(ctx->stream(), v_grad_tensor->data_type(), v_grad_tensor->shape_view(), value_shape, {2, 0, 1, 3}, v_grad_tensor->dptr(), tmp_v_tensor->mut_dptr()); // slice v grad SliceParams params = ConstructSliceParams4Value(seq_len, batch_size, num_heads, head_size); SliceKernelUtil::Backward(ctx->stream(), params, tmp_v_tensor->dptr(), h_grad_tensor->mut_dptr()); // grad_q = grad_qmk * k // (b, n, sq, sk) x (b, n, sk, h) -> (b, n, s, h) <= (s, b, n, h) <= (s, b, n, 3, h) const T* qmk_grad_dptr = qmk_grad_tensor->dptr(); const T* k_dptr = h_tensor->dptr() + head_size; T* grad_q_dptr = h_grad_tensor->mut_dptr(); BatchedGemm(ctx->stream(), 'N', 'N', seq_len, head_size, seq_len, alpha, qmk_grad_dptr, seq_len, seq_len * seq_len, k_dptr, ld, stride, 0.0f, grad_q_dptr, ld, stride, batch_size * num_heads); // grad_k = grad_qmk * q // (b, n, sk, sq) x (b, n, sq, h) -> (b, n, sk, h) <= (s, b, n, h) <= (s, b, n, 3, h) const T* q_dptr = h_tensor->dptr(); T* grad_k_dptr = h_grad_tensor->mut_dptr() + head_size; BatchedGemm(ctx->stream(), 'T', 'N', seq_len, head_size, seq_len, alpha, qmk_grad_dptr, seq_len, seq_len * seq_len, q_dptr, ld, stride, 0.0f, grad_k_dptr, ld, stride, batch_size * num_heads); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; size_t InferTmpBufferSize(user_op::InferContext* ctx) { const Shape& value_shape = ctx->OutputShape("value", 0); DataType value_dtype = ctx->OutputDType("value", 0); return value_shape.elem_cnt() * GetSizeOfDataType(value_dtype); } size_t InferGradTmpBufferSize(user_op::InferContext* ctx) { const Shape& value_shape = ctx->InputShape("value_grad", 0); DataType value_dtype = ctx->InputDType("value_grad", 0); return value_shape.elem_cnt() * GetSizeOfDataType(value_dtype); } } // namespace #define REGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_self_attention_query_mul_key_and_value") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("hidden_states", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferTmpBufferSize); #define REGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_self_attention_query_mul_key_and_value_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("hidden_states", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferGradTmpBufferSize); 
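// In both forward and backward the temporary buffer is sized to one value tensor
// (seq_len * batch_size * num_heads * head_size elements) and serves as staging space for the
// slice/transpose between the (s, b, n, h) and (b, n, s, h) layouts.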
REGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_CUDA_KERNEL(float) REGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_CUDA_KERNEL(float16) REGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_GRAD_CUDA_KERNEL(float) REGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_GRAD_CUDA_KERNEL(float16) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_softmax.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_FUSED_SOFTMAX_H_ #define ONEFLOW_USER_KERNELS_FUSED_SOFTMAX_H_ #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { namespace cuda { namespace fused_softmax { inline void SimplifyBroadcastDims(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims, const int64_t* b_dims, size_t* simplified_num_dims, int64_t* simplified_a_dims, int64_t* simplified_b_dims) { const size_t num_max_dims = std::max(num_a_dims, num_b_dims); auto MakeGetDim = [num_max_dims](size_t num_dims, const int64_t* dims) { const int64_t num_padding_dims = num_max_dims - num_dims; return [num_padding_dims, dims](size_t index) { return index < num_padding_dims ? 1 : dims[index - num_padding_dims]; }; }; auto GetADim = MakeGetDim(num_a_dims, a_dims); auto GetBDim = MakeGetDim(num_b_dims, b_dims); *simplified_num_dims = 0; bool prev_broadcast_a = false; bool prev_broadcast_b = false; for (int64_t i = 0; i < num_max_dims; ++i) { const int64_t a_dim = GetADim(i); const int64_t b_dim = GetBDim(i); const int64_t broadcast_dim = std::max(a_dim, b_dim); CHECK_GT(broadcast_dim, 0); const bool broadcast_a = (a_dim == 1); const bool broadcast_b = (b_dim == 1); CHECK((a_dim == broadcast_dim) || broadcast_a); CHECK((b_dim == broadcast_dim) || broadcast_b); if (broadcast_dim == 1) { continue; } else if (*simplified_num_dims != 0 && (prev_broadcast_a == broadcast_a && prev_broadcast_b == broadcast_b)) { simplified_a_dims[*simplified_num_dims - 1] *= a_dim; simplified_b_dims[*simplified_num_dims - 1] *= b_dim; } else { simplified_a_dims[*simplified_num_dims] = a_dim; simplified_b_dims[*simplified_num_dims] = b_dim; *simplified_num_dims += 1; prev_broadcast_a = broadcast_a; prev_broadcast_b = broadcast_b; } } } template struct BroadcastMaskSoftmaxParams { NdIndexOffsetHelper src_index_helper; NdIndexOffsetHelper mask_index_helper; const int64_t* mask_dims{}; int64_t row_size; float fill; float scale; }; struct ElementwiseMaskSoftmaxParams { int64_t row_size; float fill; float scale; }; template struct BroadcastScaleMaskLoad { BroadcastScaleMaskLoad(const SRC* src, const MASK* mask, BroadcastMaskSoftmaxParams params) : src(src), mask(mask), params(params) { for (int i = 0; i < num_dims; i++) { mask_dims[i] = params.mask_dims[i]; } } template __device__ void load(DST* dst, int64_t row, int64_t col) { cuda::softmax::Pack pack; cuda::softmax::Pack mask_pack; const IndexType offset = row * params.row_size + col; IndexType input_index[num_dims]; IndexType 
mask_index[num_dims]; params.src_index_helper.OffsetToNdIndex(offset, input_index); for (int dim = 0; dim < num_dims; ++dim) { if (mask_dims[dim] == 1) { mask_index[dim] = 0; } else { mask_index[dim] = input_index[dim]; } } const IndexType mask_offset = params.mask_index_helper.NdIndexToOffset(mask_index); pack.storage = *(reinterpret_cast*>(src) + offset / N); mask_pack.storage = *(reinterpret_cast*>(mask) + mask_offset / N); #pragma unroll for (int i = 0; i < N; ++i) { if (mask_pack.elem[i] == 0) { dst[i] = static_cast(params.fill); } else { dst[i] = static_cast(pack.elem[i]) * static_cast(params.scale); } } } const SRC* src; const MASK* mask; int64_t mask_dims[num_dims]; BroadcastMaskSoftmaxParams params; }; template struct ElementwiseScaleMaskLoad { ElementwiseScaleMaskLoad(const SRC* src, const MASK* mask, ElementwiseMaskSoftmaxParams param) : src(src), mask(mask), param(param) {} template __device__ void load(DST* dst, int64_t row, int64_t col) { cuda::softmax::Pack pack; const int64_t offset = (row * param.row_size + col) / N; pack.storage = *(reinterpret_cast*>(src) + offset); cuda::softmax::Pack mask_pack; mask_pack.storage = *(reinterpret_cast*>(mask) + offset); #pragma unroll for (int i = 0; i < N; ++i) { if (mask_pack.elem[i] == 0) { dst[i] = static_cast(param.fill); } else { dst[i] = static_cast(pack.elem[i]) * static_cast(param.scale); } } } const SRC* src; const MASK* mask; ElementwiseMaskSoftmaxParams param; }; template struct BroadcastScaleMaskStore { BroadcastScaleMaskStore(DST* dst, const MASK* mask, BroadcastMaskSoftmaxParams params) : dst(dst), mask(mask), params(params) { for (int i = 0; i < num_dims; ++i) { mask_dims[i] = params.mask_dims[i]; } } template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::softmax::Pack pack; cuda::softmax::Pack mask_pack; const IndexType offset = row * params.row_size + col; IndexType input_index[num_dims]; IndexType mask_index[num_dims]; params.src_index_helper.OffsetToNdIndex(offset, input_index); for (int dim = 0; dim < num_dims; ++dim) { if (mask_dims[dim] == 1) { mask_index[dim] = 0; } else { mask_index[dim] = input_index[dim]; } } const IndexType mask_offset = params.mask_index_helper.NdIndexToOffset(mask_index); mask_pack.storage = *(reinterpret_cast*>(mask) + mask_offset / N); #pragma unroll for (int i = 0; i < N; ++i) { if (mask_pack.elem[i] == 0) { pack.elem[i] = static_cast(params.fill); } else { pack.elem[i] = static_cast(src[i]) * static_cast(params.scale); } } *(reinterpret_cast*>(dst) + offset / N) = pack.storage; } DST* dst; const MASK* mask; int64_t mask_dims[num_dims]; BroadcastMaskSoftmaxParams params; }; template struct ElementwiseScaleMaskStore { ElementwiseScaleMaskStore(DST* dst, const MASK* mask, ElementwiseMaskSoftmaxParams params) : dst(dst), mask(mask), params(params) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::softmax::Pack pack; const int64_t offset = (row * params.row_size + col) / N; cuda::softmax::Pack mask_pack; mask_pack.storage = *(reinterpret_cast*>(mask) + offset); #pragma unroll for (int i = 0; i < N; ++i) { if (mask_pack.elem[i] == 0) { pack.elem[i] = params.fill; } else { pack.elem[i] = static_cast(src[i]) * static_cast(params.scale); } } *(reinterpret_cast*>(dst) + offset) = pack.storage; } DST* dst; const MASK* mask; ElementwiseMaskSoftmaxParams params; }; template struct MaskScaleLoad { MaskScaleLoad(const SRC* src, const bool* mask, int64_t row_size, SRC scale) : src(src), mask(mask), row_size(row_size), scale(scale) {} template 
__device__ void load(DST* dst, int64_t row, int64_t col) const { cuda::softmax::Pack pack; const int64_t offset = (row * row_size + col) / N; pack.storage = *(reinterpret_cast*>(src) + offset); cuda::softmax::Pack mask_pack; mask_pack.storage = *(reinterpret_cast*>(mask) + offset); #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]) * static_cast(mask_pack.elem[i]) * static_cast(scale); } } const SRC* src; const bool* mask; int64_t row_size; SRC scale; }; template struct DropoutStore { DropoutStore(DST* dst, DST* softmax_y, const bool* mask, int64_t row_size, DST scale) : dst(dst), softmax_y(softmax_y), mask(mask), row_size(row_size), scale(scale) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::softmax::Pack softmax_y_pack; cuda::softmax::Pack dst_pack; const int64_t offset = (row * row_size + col) / N; cuda::softmax::Pack mask_pack; mask_pack.storage = *(reinterpret_cast*>(mask) + offset); #pragma unroll for (int i = 0; i < N; ++i) { softmax_y_pack.elem[i] = static_cast(src[i]); dst_pack.elem[i] = static_cast(src[i]) * static_cast(mask_pack.elem[i]) * static_cast(scale); } *(reinterpret_cast*>(softmax_y) + offset) = softmax_y_pack.storage; *(reinterpret_cast*>(dst) + offset) = dst_pack.storage; } DST* dst; DST* softmax_y; const bool* mask; int64_t row_size; DST scale; }; } // namespace fused_softmax } // namespace cuda } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_FUSED_SOFTMAX_H_ ================================================ FILE: oneflow/user/kernels/fused_tril_scale_softmax_mask_scale_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { template struct TrilScaleLoad { TrilScaleLoad(const SRC* src, int64_t tril_num_rows, int64_t row_size, int64_t diagonal, SRC fill, SRC scale) : src(src), tril_num_rows(tril_num_rows), row_size(row_size), diagonal(diagonal), fill(fill), scale(scale) {} template __device__ void load(DST* dst, int64_t row, int64_t col) { int64_t tril_row = row % tril_num_rows; int64_t diagonal_col_id = tril_row + diagonal; bool need_load = (col <= diagonal_col_id); cuda::softmax::Pack pack; if (need_load) { const int64_t offset = (row * row_size + col) / N; pack.storage = *(reinterpret_cast*>(src) + offset); } #pragma unroll for (int i = 0; i < N; ++i) { if (col + i > diagonal_col_id) { dst[i] = static_cast(fill); } else { dst[i] = static_cast(pack.elem[i]) * static_cast(scale); } } } const SRC* src; int64_t tril_num_rows; int64_t row_size; int64_t diagonal; SRC fill; SRC scale; }; template struct MaskAndScaleStore { MaskAndScaleStore(DST* dst, DST* softmax_y, const bool* mask, int64_t row_size, DST scale) : dst(dst), softmax_y(softmax_y), mask(mask), row_size(row_size), scale(scale) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::softmax::Pack softmax_y_pack; cuda::softmax::Pack dst_pack; const int64_t offset = (row * row_size + col) / N; cuda::softmax::Pack mask_pack; mask_pack.storage = *(reinterpret_cast*>(mask) + offset); #pragma unroll for (int i = 0; i < N; ++i) { softmax_y_pack.elem[i] = static_cast(src[i]); dst_pack.elem[i] = static_cast(src[i]) * static_cast(mask_pack.elem[i]) * static_cast(scale); } *(reinterpret_cast*>(softmax_y) + offset) = softmax_y_pack.storage; *(reinterpret_cast*>(dst) + offset) = dst_pack.storage; } DST* dst; DST* softmax_y; const bool* mask; int64_t row_size; DST scale; }; template struct MaskAndScaleLoad { MaskAndScaleLoad(const SRC* src, const bool* mask, int64_t row_size, SRC scale) : src(src), mask(mask), row_size(row_size), scale(scale) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { cuda::softmax::Pack pack; const int64_t offset = (row * row_size + col) / N; pack.storage = *(reinterpret_cast*>(src) + offset); cuda::softmax::Pack mask_pack; mask_pack.storage = *(reinterpret_cast*>(mask) + offset); #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]) * static_cast(mask_pack.elem[i]) * static_cast(scale); } } const SRC* src; const bool* mask; int64_t row_size; SRC scale; }; template struct TrilScaleStore { TrilScaleStore(DST* dst, int64_t tril_num_rows, int64_t row_size, int64_t diagonal, DST fill, DST scale) : dst(dst), tril_num_rows(tril_num_rows), row_size(row_size), diagonal(diagonal), fill(fill), scale(scale) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::softmax::Pack pack; const int64_t offset = (row * row_size + col) / N; int64_t tril_row = row % tril_num_rows; #pragma unroll for (int i = 0; i < N; ++i) { if (col + i > tril_row + diagonal) { pack.elem[i] = fill; } else { pack.elem[i] = static_cast(src[i]) * static_cast(scale); } } *(reinterpret_cast*>(dst) + offset) = pack.storage; } DST* dst; int64_t tril_num_rows; int64_t row_size; int64_t diagonal; DST fill; DST scale; }; template class FusedTrilScaleSoftmaxMaskScaleKernel final : public user_op::OpKernel { public: FusedTrilScaleSoftmaxMaskScaleKernel() = default; ~FusedTrilScaleSoftmaxMaskScaleKernel() override = default; 
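// Fuses, per row: y = softmax(tril_scale(x)) * mask * mask_scale_value, where tril_scale keeps
// x * tril_scale_value on and below the diagonal (shifted by the `diagonal` attribute) and
// substitutes tril_fill_value above it; the pre-mask softmax result is also written to softmax_y
// for reuse in the backward kernel.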
private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex("softmax_y", 0); const ShapeView& x_shape = x->shape_view(); CHECK_GE(x_shape.NumAxes(), 2); const int64_t cols = x_shape.At(x_shape.NumAxes() - 1); const int64_t rows = x_shape.Count(0, x_shape.NumAxes() - 1); const int64_t tril_num_rows = x_shape.At(x_shape.NumAxes() - 2); using ComputeType = typename cuda::softmax::DefaultComputeType::type; TrilScaleLoad load( x->dptr(), tril_num_rows, cols, ctx->Attr("diagonal"), ctx->Attr("tril_fill_value"), ctx->Attr("tril_scale_value")); MaskAndScaleStore store(y->mut_dptr(), softmax_y->mut_dptr(), mask->dptr(), cols, ctx->Attr("mask_scale_value")); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax( ctx->stream()->As()->cuda_stream(), load, store, rows, cols))); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_tril_scale_softmax_mask_scale") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL(half) REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL(float) REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL(double) #undef REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL template class FusedTrilScaleSoftmaxMaskScaleGradKernel final : public user_op::OpKernel { public: FusedTrilScaleSoftmaxMaskScaleGradKernel() = default; ~FusedTrilScaleSoftmaxMaskScaleGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex("softmax_y", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const ShapeView& dy_shape = dy->shape_view(); CHECK_GE(dy_shape.NumAxes(), 2); const int64_t cols = dy_shape.At(dy_shape.NumAxes() - 1); const int64_t rows = dy_shape.Count(0, dy_shape.NumAxes() - 1); const int64_t tril_num_rows = dy_shape.At(dy_shape.NumAxes() - 2); using ComputeType = typename cuda::softmax::DefaultComputeType::type; cuda::softmax::DirectLoad load_softmax_y(softmax_y->dptr(), cols); MaskAndScaleLoad load_dy(dy->dptr(), mask->dptr(), cols, ctx->Attr("mask_scale_value")); TrilScaleStore store(dx->mut_dptr(), tril_num_rows, cols, ctx->Attr("diagonal"), static_cast(0.0), ctx->Attr("tril_scale_value")); OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad( ctx->stream()->As()->cuda_stream(), load_softmax_y, load_dy, store, rows, cols))); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_tril_scale_softmax_mask_scale_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL(half) REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL(float) 
REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL(double) #undef REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_weighted_sum_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cpu/cpu_stream.h" namespace oneflow { namespace { template class FusedWeightedSumKernel final : public user_op::OpKernel { public: FusedWeightedSumKernel() = default; ~FusedWeightedSumKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t arity = ctx->input_size("in"); CHECK_GE(arity, 1); const std::vector& weights = ctx->Attr>("weights"); CHECK_EQ(weights.size(), arity); const float alpha = ctx->Attr("alpha"); const DataType data_type = out->data_type(); const ShapeView& shape = out->shape_view(); std::vector inputs(arity); for (int i = 0; i < arity; ++i) { const user_op::Tensor* in_i = ctx->Tensor4ArgNameAndIndex("in", i); CHECK(in_i->shape_view() == shape); CHECK_EQ(in_i->data_type(), data_type); inputs[i] = in_i->dptr(); } T* out_ptr = out->mut_dptr(); auto* cpu_stream = ctx->stream()->As(); cpu_stream->ParallelFor(0, shape.elem_cnt(), [&](int64_t s, int64_t e) { for (int64_t i = s; i < e; ++i) { T out = static_cast(0.0); for (int j = 0; j < arity; ++j) { out += inputs[j][i] * static_cast(weights[j]); } out_ptr[i] = out * static_cast(alpha); } }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_FUSED_WEIGHT_SUM_KERNEL(data_type, cpp_type) \ REGISTER_USER_KERNEL("fused_weighted_sum") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == data_type)) REGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kDouble, double); REGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kFloat, float); REGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kFloat16, float16); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/fused_weighted_sum_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template struct Params { const T* inputs[arity]; float weights[arity]; float alpha{}; T* output; int64_t n; }; template __global__ void WeightedSumKernel(Params params) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, params.n) { T out = 0; if (acc) { out = params.output[i]; } #pragma unroll for (int j = 0; j < arity; ++j) { out += params.inputs[j][i] * static_cast(params.weights[j]); } params.output[i] = out * static_cast(params.alpha); } } template void LaunchWeightedSum(ep::Stream* stream, int n, const T** inputs, const float* weights, float alpha, T* output) { Params params{}; for (int i = 0; i < arity; ++i) { params.inputs[i] = *(inputs + i); params.weights[i] = *(weights + i); } params.alpha = alpha; params.output = output; params.n = n; RUN_CUDA_KERNEL((WeightedSumKernel), stream, n, params); } template void DispatchWeightedSum(ep::Stream* stream, int arity, int64_t n, const T** inputs, const float* weights, float alpha, T* output) { if (arity == 1) { LaunchWeightedSum(stream, n, inputs, weights, alpha, output); } else if (arity == 2) { LaunchWeightedSum(stream, n, inputs, weights, alpha, output); } else if (arity == 3) { LaunchWeightedSum(stream, n, inputs, weights, alpha, output); } else if (arity == 4) { LaunchWeightedSum(stream, n, inputs, weights, alpha, output); } else if (arity == 5) { LaunchWeightedSum(stream, n, inputs, weights, alpha, output); } else if (arity == 6) { LaunchWeightedSum(stream, n, inputs, weights, alpha, output); } else if (arity == 7) { LaunchWeightedSum(stream, n, inputs, weights, alpha, output); } else if (arity == 8) { LaunchWeightedSum(stream, n, inputs, weights, alpha, output); } else if (arity > 8) { LaunchWeightedSum(stream, n, inputs, weights, 1.0F, output); DispatchWeightedSum(stream, arity - 8, n, inputs + 8, weights + 8, alpha, output); } else { UNIMPLEMENTED(); } } template class FusedWeightedSumKernel final : public user_op::OpKernel { public: FusedWeightedSumKernel() = default; ~FusedWeightedSumKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t arity = ctx->input_size("in"); CHECK_GE(arity, 1) << "input_size should be greater than 0."; const std::vector& weights = ctx->Attr>("weights"); CHECK_EQ(weights.size(), arity); const float alpha = ctx->Attr("alpha"); const DataType data_type = out->data_type(); const ShapeView& shape = out->shape_view(); std::vector inputs(arity); for (int i = 0; i < arity; ++i) { const user_op::Tensor* in_i = ctx->Tensor4ArgNameAndIndex("in", i); CHECK(in_i->shape_view() == shape); CHECK_EQ(in_i->data_type(), data_type); inputs[i] = in_i->dptr(); } DispatchWeightedSum(ctx->stream(), arity, shape.elem_cnt(), inputs.data(), weights.data(), alpha, out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_FUSED_WEIGHT_SUM_KERNEL(data_type, cpp_type) \ REGISTER_USER_KERNEL("fused_weighted_sum") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == data_type)) REGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kDouble, double); REGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kFloat, float); REGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kFloat16, half); } // namespace oneflow 
================================================ FILE: oneflow/user/kernels/gather_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/gather_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { namespace user_op { namespace { Shape GetFlatShape(ShapeView shape, int64_t axis) { return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)}); } class GatherOpKernelCache final : public user_op::OpKernelCache { public: GatherOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} ~GatherOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } private: const int64_t lower_; const int64_t upper_; }; void CheckNdSbp(const Shape& hierarchy, int64_t gather_axis, const NdSbp& in_nd_sbp, const NdSbp& indices_nd_sbp, const NdSbp& out_nd_sbp) { CHECK_EQ(hierarchy.NumAxes(), in_nd_sbp.sbp_parallel_size()); CHECK_EQ(hierarchy.NumAxes(), indices_nd_sbp.sbp_parallel_size()); CHECK_EQ(hierarchy.NumAxes(), out_nd_sbp.sbp_parallel_size()); if (hierarchy.elem_cnt() == 1) { return; } FOR_RANGE(int64_t, i, 0, hierarchy.NumAxes()) { const auto& in_sbp = in_nd_sbp.sbp_parallel(i); if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == gather_axis) { CHECK(indices_nd_sbp.sbp_parallel(i).has_broadcast_parallel()); CHECK(out_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()); } } } } // namespace template class GatherKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: GatherKernel() = default; ~GatherKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const auto axis = ctx->Attr("axis"); const NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); CheckNdSbp(hierarchy, axis, in_nd_sbp, ctx->NdSbp4ArgNameAndIndex("indices", 0), ctx->NdSbp4ArgNameAndIndex("out", 0)); const Shape in_logical_shape = ExpandDimIf0D(ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0)->shape()); TensorSliceView view = GetTensorSliceView4ParallelId(hierarchy, in_nd_sbp, in_logical_shape, ctx->parallel_ctx().parallel_id()); return std::make_shared(view.At(axis).begin(), view.At(axis).end()); } else { return nullptr; } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); const int64_t axis = ctx->Attr("axis"); const int64_t num_indices = indices->shape_view().elem_cnt(); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); if (out->shape_view().elem_cnt() == 0) { return; } const 
Shape in_shape = ExpandDimIf0D(in->shape_view()); int64_t offset = 0; if (cache != nullptr) { auto* gather_cache = dynamic_cast(cache); CHECK_NOTNULL(gather_cache); CHECK_EQ(in_shape.At(axis), gather_cache->upper() - gather_cache->lower()); offset = gather_cache->lower(); } GatherKernelUtilImpl::Forward(ctx->stream(), indices->dptr(), num_indices, in->dptr(), GetFlatShape(in_shape, axis), out->mut_dptr(), offset); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GATHER_KERNEL(device, in_type, indices_type) \ REGISTER_USER_KERNEL("gather") \ .SetCreateFn< \ GatherKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("in", 0) == OF_PP_PAIR_SECOND(in_type)) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(indices_type))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, DEVICE_TYPE_SEQ, GATHER_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) // For cpu float16/bfloat16 OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA // For cuda half OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #if CUDA_VERSION >= 11000 OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16), INDEX_DATA_TYPE_SEQ) #endif #endif } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/gather_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
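GatherKernel above derives an offset from the kernel cache when the "in" tensor is split along the gather axis, so global indices can be translated to rows of the local shard; rows that fall outside the shard are zero-filled (the output's partial-sum SBP then makes the shards add up). The standalone sketch below restates that index translation on a flattened (outer, gather, inner) shape, matching the CPU GatherKernelUtilImpl::Forward in gather_kernel_util.cpp that follows; names are illustrative.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Gather on a flattened (outer, gather_dim, inner) shape with a shard offset:
// indices are translated into the local shard, and out-of-range rows are
// written as zeros.
void GatherWithOffset(const float* in, const int64_t* indices,
                      int64_t outer, int64_t gather_dim, int64_t inner,
                      int64_t num_indices, int64_t offset, float* out) {
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t i = 0; i < num_indices; ++i) {
      const int64_t idx = indices[i] - offset;  // global index -> local row
      float* to = out + (o * num_indices + i) * inner;
      if (idx >= 0 && idx < gather_dim) {
        const float* from = in + (o * gather_dim + idx) * inner;
        std::copy(from, from + inner, to);
      } else {
        std::memset(to, 0, inner * sizeof(float));
      }
    }
  }
}

int main() {
  const float in[2 * 3 * 1] = {0, 1, 2, 10, 11, 12};  // outer=2, gather=3, inner=1
  const int64_t indices[2] = {4, 1};  // global indices; this shard covers rows 3..5
  float out[2 * 2 * 1] = {};
  GatherWithOffset(in, indices, 2, 3, 1, 2, /*offset=*/3, out);
  for (float v : out) { std::printf("%g ", v); }  // 1 0 11 0
  std::printf("\n");
}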
*/ #include "oneflow/user/kernels/gather_kernel_util.h" namespace oneflow { namespace { Shape GetFlatShape(const ShapeView& shape, int64_t axis) { CHECK_GT(shape.NumAxes(), 0); CHECK_GE(axis, 0); CHECK_LT(axis, shape.NumAxes()); return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)}); } template void GatherForward(ep::Stream* stream, const Blob* indices, const Blob* in, int64_t axis, Blob* out, const int64_t offset) { const Shape& flat_in_shape = GetFlatShape(in->shape_view(), axis); GatherKernelUtilImpl::Forward(stream, indices->dptr(), indices->shape_view().elem_cnt(), in->dptr(), flat_in_shape, out->mut_dptr(), offset); } template struct GatherSwitchUtil final { #define MAKE_GATHER_SWITCH_ENTRY(func_name, K) func_name #define DEFINE_GATHER_STATIC_SWITCH_FUNC(func_name) \ DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_GATHER_SWITCH_ENTRY, \ MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ)); DEFINE_GATHER_STATIC_SWITCH_FUNC(GatherForward); #undef DEFINE_GATHER_STATIC_SWITCH_FUNC #undef MAKE_GATHER_SWITCH_ENTRY }; } // namespace template void GatherKernelUtil::Forward(ep::Stream* stream, const Blob* indices, const Blob* in, const int64_t axis, Blob* out) { GatherKernelUtil::Forward(stream, indices, in, axis, out, 0); } template void GatherKernelUtil::Forward(ep::Stream* stream, const Blob* indices, const Blob* in, const int64_t axis, Blob* out, const int64_t offset) { GatherSwitchUtil::SwitchGatherForward(SwitchCase(indices->data_type()), stream, indices, in, axis, out, offset); } template struct GatherKernelUtilImpl final { static void Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const T* in, const Shape& flat_in_shape, T* out, const int64_t offset); }; template void GatherKernelUtilImpl::Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const T* in, const Shape& flat_in_shape, T* out, const int64_t offset) { const int64_t outer_dim_size = flat_in_shape.At(0); const int64_t gather_dim_size = flat_in_shape.At(1); const int64_t inner_dim_size = flat_in_shape.At(2); FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) { FOR_RANGE(int64_t, i, 0, num_indices) { CHECK_GE(indices[i], 0); const int64_t idx = indices[i] - offset; T* to = out + outer_idx * num_indices * inner_dim_size + i * inner_dim_size; if (idx >= 0 && idx < gather_dim_size) { const T* from = in + outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size; std::copy(from, from + inner_dim_size, to); } else { std::memset(reinterpret_cast(to), 0, inner_dim_size * sizeof(T)); } } } } #define INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL(in_type_pair, index_type_pair) \ template struct GatherKernelUtilImpl; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL, GATHER_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ, GATHER_INDEX_TYPE_SEQ); #undef INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL #define INITIATE_GATHER_KERNEL_UTIL(device_type, in_type_pair) \ template struct GatherKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL, DEVICE_TYPE_SEQ, GATHER_DATA_TYPE_SEQ); // For cpu float16/bfloat16 OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ); #undef INITIATE_GATHER_KERNEL_UTIL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/gather_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/gather_kernel_util.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 namespace oneflow { namespace { template __global__ void GatherForwardGpu(const IDX elem_cnt, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const K* indices, const T* in, const IDX gather_dim_size, T* out, const IDX offset) { IDX index[3]; CUDA_1D_KERNEL_LOOP_T(IDX, i, elem_cnt) { out_helper.OffsetToNdIndex(i, index); index[1] = indices[index[1]] - offset; T v{}; if (index[1] >= 0 && index[1] < gather_dim_size) { v = in[in_helper.NdIndexToOffset(index)]; } out[i] = v; } } bool IsSafeUseIndex32(int64_t outer_dim_size, int64_t gather_dim_size, int64_t inner_dim_size, int64_t num_indices) { const int64_t in_elem_cnt = outer_dim_size * gather_dim_size * inner_dim_size; const int64_t out_elem_cnt = outer_dim_size * num_indices * inner_dim_size; return std::max(out_elem_cnt, in_elem_cnt) < GetMaxVal() / 2; } template void DispatchIndexSize(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size, int64_t inner_dim_size, int64_t num_indices, int64_t offset, const K* indices, const T* in, T* out) { const int64_t out_elem_cnt = outer_dim_size * num_indices * inner_dim_size; if (IsSafeUseIndex32(outer_dim_size, gather_dim_size, inner_dim_size, num_indices)) { NdIndexOffsetHelper in_helper(outer_dim_size, gather_dim_size, inner_dim_size); NdIndexOffsetHelper out_helper(outer_dim_size, num_indices, inner_dim_size); GatherForwardGpu<<As()->cuda_stream()>>>( out_elem_cnt, in_helper, out_helper, indices, in, gather_dim_size, out, offset); } else { NdIndexOffsetHelper in_helper(outer_dim_size, gather_dim_size, inner_dim_size); NdIndexOffsetHelper out_helper(outer_dim_size, num_indices, inner_dim_size); GatherForwardGpu<<As()->cuda_stream()>>>( out_elem_cnt, in_helper, out_helper, indices, in, gather_dim_size, out, offset); } } template bool TryDispatchMovementType(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size, int64_t inner_dim_size, int64_t num_indices, int64_t offset, const K* indices, const void* in, void* out) { if (reinterpret_cast(in) % sizeof(T) == 0 && reinterpret_cast(out) % sizeof(T) == 0 && inner_dim_size % sizeof(T) == 0) { DispatchIndexSize(stream, outer_dim_size, gather_dim_size, inner_dim_size / sizeof(T), num_indices, offset, indices, static_cast(in), static_cast(out)); return true; } else { return false; } } template void DispatchMovementSize(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size, int64_t inner_dim_size, int64_t num_indices, int64_t offset, const K* indices, const void* in, void* out) { using Func = bool (*)(ep::Stream * stream, int64_t outer_dim_size, int64_t gather_dim_size, int64_t inner_dim_size, int64_t num_indices, int64_t offset, const K* indices, const void* in, void* out); Func funcs[] = { TryDispatchMovementType, // 16B 
TryDispatchMovementType, // 8B TryDispatchMovementType, // 4B TryDispatchMovementType, // 2B TryDispatchMovementType, // 1B }; for (size_t i = 0; i < sizeof(funcs) / sizeof(funcs[0]); ++i) { if (funcs[i](stream, outer_dim_size, gather_dim_size, inner_dim_size, num_indices, offset, indices, in, out)) { break; } } } } // namespace template struct GatherKernelUtilImpl final { static void Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const T* in, const Shape& flat_in_shape, T* out, const int64_t offset) { DispatchMovementSize(stream, flat_in_shape.At(0), flat_in_shape.At(1), flat_in_shape.At(2) * sizeof(T), num_indices, offset, indices, in, out); } }; #define INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL(in_type_pair, index_type_pair) \ template struct GatherKernelUtilImpl; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL, GATHER_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, GATHER_INDEX_TYPE_SEQ); #if CUDA_VERSION >= 11000 OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL, OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16), GATHER_INDEX_TYPE_SEQ); #endif #undef INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/gather_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_GATHER_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_GATHER_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct GatherKernelUtil final { static void Forward(ep::Stream* stream, const Blob* indices, const Blob* in, int64_t axis, Blob* out); static void Forward(ep::Stream* stream, const Blob* indices, const Blob* in, int64_t axis, Blob* out, int64_t offset); }; template struct GatherKernelUtilImpl final { static void Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const T* in, const Shape& flat_in_shape, T* out, int64_t offset); }; #define GATHER_DATA_TYPE_SEQ ARITHMETIC_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool) #define GATHER_INDEX_TYPE_SEQ INDEX_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_GATHER_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
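The CUDA gather util above widens its copies by expressing the inner dimension in bytes and trying movement types from 16 bytes down to 1 byte, keeping the first one that divides the inner byte count and to which both pointers are aligned. A small host-side sketch of that selection rule follows; it only captures the size choice, not the kernel launch, and the names are illustrative.

#include <cstdint>
#include <cstdio>

// First of 16/8/4/2/1 bytes that is compatible with both pointer alignments
// and the inner byte count, in the spirit of DispatchMovementSize above.
size_t PickMovementSize(const void* in, const void* out, size_t inner_bytes) {
  const size_t candidates[] = {16, 8, 4, 2, 1};
  for (size_t size : candidates) {
    if (reinterpret_cast<uintptr_t>(in) % size == 0
        && reinterpret_cast<uintptr_t>(out) % size == 0
        && inner_bytes % size == 0) {
      return size;
    }
  }
  return 1;  // unreachable: 1 byte always matches
}

int main() {
  alignas(16) float buf_in[64];
  alignas(16) float buf_out[64];
  // Inner dim of 6 floats = 24 bytes -> widest compatible move is 8 bytes.
  std::printf("movement size: %zu bytes\n", PickMovementSize(buf_in, buf_out, 24));
}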
See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { class GenerateRandomBatchPermutationIndicesCPUKernel final : public user_op::OpKernel { public: GenerateRandomBatchPermutationIndicesCPUKernel() = default; ~GenerateRandomBatchPermutationIndicesCPUKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { int64_t seed = ctx->Attr("seed"); return std::make_shared>(seed); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* random_generator = dynamic_cast*>(state); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); std::iota(y->mut_dptr(), y->mut_dptr() + y->shape_view().elem_cnt(), 0); std::shuffle(y->mut_dptr(), y->mut_dptr() + y->shape_view().elem_cnt(), *random_generator->Mutable()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("generate_random_batch_permutation_indices") .SetCreateFn() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCPU); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
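The CPU permutation kernel above fills the output with 0..n-1 via std::iota and then std::shuffle using a generator held in the kernel state, so the permutation is reproducible for a given seed. A standalone equivalent is sketched below; the engine type is an assumption (the kernel wraps its own RandomGenerator), and it sits outside the kernel framework.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

// Seeded iota + shuffle, mirroring the CPU batch-permutation kernel above.
std::vector<int32_t> RandomBatchPermutation(int64_t n, uint64_t seed) {
  std::vector<int32_t> indices(n);
  std::iota(indices.begin(), indices.end(), 0);
  std::mt19937_64 rng(seed);  // illustrative engine; the kernel uses its own generator
  std::shuffle(indices.begin(), indices.end(), rng);
  return indices;
}

int main() {
  for (int32_t i : RandomBatchPermutation(8, /*seed=*/42)) { std::printf("%d ", i); }
  std::printf("\n");
}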
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/random_generator.h" #include "oneflow/user/kernels/radix_sort.cuh" #include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { class TmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager); TmpBufferManager(const int32_t& batch_size, const int32_t& capacity, void* ptr) : capacity_{capacity}, random_value_elem_cnt_{batch_size}, sorted_value_elem_cnt_{batch_size}, indices_elem_cnt_{batch_size} { const int32_t random_value_aligned_bytes = GetCudaAlignedSize(random_value_elem_cnt_ * sizeof(float)); const int32_t sorted_value_aligned_bytes = GetCudaAlignedSize(sorted_value_elem_cnt_ * sizeof(float)); const int32_t indices_aligned_bytes = GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int32_t)); random_value_ptr_ = reinterpret_cast(ptr); sorted_value_ptr_ = reinterpret_cast(reinterpret_cast(random_value_ptr_) + random_value_aligned_bytes); indices_ptr_ = reinterpret_cast(reinterpret_cast(sorted_value_ptr_) + sorted_value_aligned_bytes); temp_storage_ptr_ = reinterpret_cast(reinterpret_cast(indices_ptr_) + indices_aligned_bytes); temp_storage_bytes_ = capacity_ - random_value_aligned_bytes - sorted_value_aligned_bytes - indices_aligned_bytes; CHECK_GE(temp_storage_bytes_, 0); } ~TmpBufferManager() = default; float* RandomValuePtr() const { return random_value_ptr_; } float* SortedValuePtr() const { return sorted_value_ptr_; } int32_t* IndicesPtr() const { return indices_ptr_; } void* TempStoragePtr() const { return temp_storage_ptr_; } int32_t TempStorageBytes() const { return temp_storage_bytes_; } private: int32_t capacity_; float* random_value_ptr_; float* sorted_value_ptr_; int32_t* indices_ptr_; void* temp_storage_ptr_; int32_t random_value_elem_cnt_; int32_t sorted_value_elem_cnt_; int32_t indices_elem_cnt_; int32_t temp_storage_bytes_; }; __global__ void InitializeIndices(int32_t elem_cnt, int32_t* indices_ptr) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i; }; } } // namespace class GenerateRandomBatchPermutationIndicesGPUKernel final : public user_op::OpKernel { public: GenerateRandomBatchPermutationIndicesGPUKernel() = default; ~GenerateRandomBatchPermutationIndicesGPUKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { int64_t seed = ctx->Attr("seed"); return std::make_shared>>( seed, ctx->stream()); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* random_generator = dynamic_cast>*>(state); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t batch_size = y->shape_view().At(0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buf_manager(batch_size, static_cast(tmp_buffer->shape_view().elem_cnt()), tmp_buffer->mut_dptr()); random_generator->Mutable()->Uniform(batch_size, buf_manager.RandomValuePtr()); InitializeIndices<<stream()->As()->cuda_stream()>>>( batch_size, buf_manager.IndicesPtr()); const int32_t argsort_instance_num = 1; const int32_t argsort_instance_size = batch_size; SortPairsAscending(buf_manager.RandomValuePtr(), buf_manager.IndicesPtr(), argsort_instance_num, argsort_instance_size, buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(), buf_manager.SortedValuePtr(), y->mut_dptr(), 
ctx->stream()->As()->cuda_stream()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("generate_random_batch_permutation_indices") .SetCreateFn() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { const Shape& y_shape = ctx->OutputShape("y", 0); const int32_t batch_size = y_shape.At(0); const int32_t random_value_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(float)); const int32_t sorted_value_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(float)); const int32_t indices_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(int32_t)); const int32_t argsort_instance_num = 1; const int32_t argsort_instance_size = batch_size; const int32_t temp_storage_bytes = InferTempStorageForSortPairsAscending( argsort_instance_num, argsort_instance_size); return random_value_aligned_bytes + sorted_value_aligned_bytes + indices_aligned_bytes + temp_storage_bytes; }); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/gpt_data_loader_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/data/gpt_dataset.h" #include "oneflow/user/data/distributed_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace { using namespace user_op; using namespace data; size_t GetNumShards(const Shape& hierarchy, const NdSbp& nd_sbp) { size_t num_shards = 1; FOR_RANGE(size_t, i, 0, nd_sbp.sbp_parallel_size()) { const auto& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { num_shards *= hierarchy.At(sbp_parallel.split_parallel().axis()); } } return num_shards; } size_t GetShardIndex(const Shape& hierarchy, const NdSbp& nd_sbp, size_t rank) { using index_helper_t = NdIndexOffsetHelper; size_t ndim = hierarchy.NumAxes(); CHECK_GT(ndim, 0); CHECK_LE(ndim, SHAPE_MAX_AXIS_SIZE); index_helper_t index_helper(hierarchy.dim_vec().data(), ndim); int64_t nd_index[SHAPE_MAX_AXIS_SIZE] = {0}; index_helper.OffsetToNdIndex(rank, nd_index); size_t stride = 1; size_t index = 0; for (int i = ndim - 1; i >= 0; --i) { const auto& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { index += nd_index[i] * stride; stride *= hierarchy.At(i); } } return index; } class GPTDataLoader final : public OpKernelState { public: GPTDataLoader(KernelInitContext* ctx) : batch_cnt_(0) { seq_len_ = ctx->Attr("seq_length"); label_len_ = 1; int64_t num_samples = ctx->Attr("num_samples"); dataset_ = std::make_unique( ctx->Attr("data_file_prefix"), seq_len_, label_len_, num_samples, ctx->Attr>("split_sizes"), ctx->Attr("split_index"), ctx->Attr("shuffle"), ctx->Attr("random_seed")); batch_size_ = ctx->TensorDesc4ArgNameAndIndex("out", 0)->shape().At(0); CHECK_JUST(InitDataSourceDistributedInfo(ctx, num_shards_, shard_index_)); } ~GPTDataLoader() 
= default; template void GetBatch(size_t iter, user_op::Tensor* tokens) const { const size_t sample_len = seq_len_ + label_len_; CHECK_EQ(tokens->shape_view().NumAxes(), 2); CHECK_EQ(tokens->shape_view().At(0), batch_size_); CHECK_EQ(tokens->shape_view().At(1), sample_len); T* dptr = tokens->mut_dptr(); for (size_t i = 0; i < batch_size_; ++i) { size_t sample_iter = iter * batch_size_ * num_shards_ + shard_index_ * batch_size_ + i; dataset_->GetSample(sample_iter, dptr + i * sample_len); } } template void NextBatch(user_op::Tensor* tokens) { GetBatch(batch_cnt_, tokens); batch_cnt_ += 1; } private: std::unique_ptr dataset_; size_t seq_len_; size_t label_len_; size_t batch_size_; size_t num_shards_; int64_t shard_index_; size_t batch_cnt_; }; template class GPTDataLoaderKernel final : public OpKernel { public: GPTDataLoaderKernel() = default; ~GPTDataLoaderKernel() = default; std::shared_ptr CreateOpKernelState(KernelInitContext* ctx) const override { std::shared_ptr reader(new GPTDataLoader(ctx)); return reader; } private: void Compute(KernelComputeContext* ctx, OpKernelState* state, const OpKernelCache*) const override { auto* loader = dynamic_cast(state); user_op::Tensor* iteration_tensor = ctx->Tensor4ArgNameAndIndex("iteration", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); if (iteration_tensor) { CHECK_EQ(iteration_tensor->shape_view().elem_cnt(), 1); CHECK_EQ(iteration_tensor->data_type(), DataType::kInt64); int64_t* iter_ptr = iteration_tensor->mut_dptr(); loader->GetBatch(*iter_ptr, out_tensor); *iter_ptr += 1; } else { loader->NextBatch(out_tensor); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_GPT_DATA_LOADER_KERNEL(dtype) \ REGISTER_USER_KERNEL("megatron_gpt_mmap_data_loader") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) REGISTER_GPT_DATA_LOADER_KERNEL(int32_t); REGISTER_GPT_DATA_LOADER_KERNEL(int64_t); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/greater_inplace_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
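GPTDataLoader::GetBatch above maps (iteration, shard, position in batch) to a global sample index as iter * batch_size * num_shards + shard_index * batch_size + i, so each shard reads a disjoint block of batch_size samples inside every global window. The snippet below restates only that addressing formula; parameter names are illustrative.

#include <cstdint>
#include <cstdio>

// Sample index used by GPTDataLoader::GetBatch: each iteration consumes
// batch_size * num_shards samples globally; shard k takes the k-th contiguous
// block of batch_size samples within that window.
int64_t SampleIndex(int64_t iter, int64_t batch_size, int64_t num_shards,
                    int64_t shard_index, int64_t i) {
  return iter * batch_size * num_shards + shard_index * batch_size + i;
}

int main() {
  // 2 shards, batch size 4: iteration 1 of shard 1 reads samples 12..15.
  for (int64_t i = 0; i < 4; ++i) {
    std::printf("%lld ", static_cast<long long>(SampleIndex(1, 4, 2, 1, i)));
  }
  std::printf("\n");
}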
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/greater_inplace_kernel_util.h" namespace oneflow { template class GreaterInplaceKernel final : public user_op::OpKernel { public: GreaterInplaceKernel() = default; ~GreaterInplaceKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const int64_t elem_cnt = x->shape_view().elem_cnt(); if (elem_cnt == 0) { return; } const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const T* x_ptr = x->dptr(); const T* y_ptr = y->dptr(); T* out_ptr = out->mut_dptr(); T* broadcast_y_ptr = tmp_buffer->mut_dptr(); if (x->shape_view() == y->shape_view()) { GreaterInplaceKernelUtil::Forward(ctx->stream(), elem_cnt, x_ptr, y_ptr, out_ptr); return; } GreaterInplaceKernelUtil::YBroadcastToX( ctx->stream(), elem_cnt, x_ptr, y_ptr, broadcast_y_ptr, x->shape_view(), y->shape_view()); GreaterInplaceKernelUtil::Forward(ctx->stream(), elem_cnt, x_ptr, broadcast_y_ptr, out_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GREATER_INPLACE_KERNEL(device_type, dtype) \ REGISTER_USER_KERNEL("broadcast_inplace_greater") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& x_shape = ctx->InputShape("x", 0); \ return GetCudaAlignedSize(x_shape.elem_cnt() * sizeof(dtype)); \ }); REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, float) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, double) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int8_t) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int32_t) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int64_t) #ifdef WITH_CUDA REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, half) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, float) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, double) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int8_t) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int32_t) REGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int64_t) #endif // WITH_CUDA template class ScalarGreaterInplaceKernel final : public user_op::OpKernel { public: ScalarGreaterInplaceKernel() = default; ~ScalarGreaterInplaceKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const int64_t elem_cnt = in->shape_view().elem_cnt(); if (elem_cnt == 0) { return; } user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); Scalar scalar_operand; if (ctx->Attr("has_int_operand")) { scalar_operand = ctx->Attr("int_operand"); } else if (ctx->Attr("has_float_operand")) { scalar_operand = ctx->Attr("float_operand"); } else { UNIMPLEMENTED(); } const T* in_ptr = in->dptr(); T* out_ptr = out->mut_dptr(); ScalarGreaterInplaceKernelUtil::Forward(ctx->stream(), elem_cnt, in_ptr, scalar_operand, out_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SCALAR_GREATER_INPLACE_KERNEL(device_type, dtype, value_type) \ REGISTER_USER_KERNEL("scalar_logical_inplace_greater") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("out", 0) == 
GetDataType::value)); REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, float, double) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, double, double) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int8_t, int64_t) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int32_t, int64_t) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int64_t, int64_t) #ifdef WITH_CUDA REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, half, double) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, float, double) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, double, double) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int8_t, int64_t) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int32_t, int64_t) REGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int64_t, int64_t) #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/greater_inplace_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/greater_inplace_kernel_util.h" namespace oneflow { template struct GreaterInplaceKernelUtil { static void Forward(ep::Stream* stream, const int64_t n, const T* x, const T* y, T* out) { FOR_RANGE(int64_t, i, 0, n) { out[i] = x[i] > y[i] ? static_cast(1) : static_cast(0); } } }; template struct ScalarGreaterInplaceKernelUtil { static void Forward(ep::Stream* stream, const int64_t n, const T* x, const Scalar operand, T* out) { FOR_RANGE(int64_t, i, 0, n) { out[i] = x[i] > static_cast(operand.Value()) ? static_cast(1) : static_cast(0); } } }; #define INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CPU(data_type, other) \ template struct GreaterInplaceKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CPU, GREATER_INPLACE_DATA_TYPE_SEQ_CPU) #undef INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CPU #define INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CPU(data_type, value_data_type) \ template struct ScalarGreaterInplaceKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CPU, GREATER_INPLACE_DATA_TYPE_SEQ_CPU, SCALAR_VALUE_DATA_TYPE_SEQ) #undef INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/greater_inplace_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/user/kernels/greater_inplace_kernel_util.h" namespace oneflow { namespace { template __global__ void GreaterInplacForwardGpu(const int64_t n, const T* x, const T* y, T* out) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { out[i] = x[i] > y[i] ? static_cast(1) : static_cast(0); } } template __global__ void ScalarGreaterInplacForwardGpu(const int64_t n, const T* x, const Scalar operand, T* out) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { out[i] = x[i] > static_cast(operand.Value()) ? static_cast(1) : static_cast(0); } } template<> __global__ void ScalarGreaterInplacForwardGpu(const int64_t n, const half* x, const Scalar operand, half* out) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { float operator_value = static_cast(operand.Value()); out[i] = x[i] > __float2half(operator_value) ? static_cast(1) : static_cast(0); } } template<> __global__ void ScalarGreaterInplacForwardGpu(const int64_t n, const half* x, const Scalar operand, half* out) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { float operator_value = static_cast(operand.Value()); out[i] = x[i] > __float2half(operator_value) ? static_cast(1) : static_cast(0); } } } // namespace template struct GreaterInplaceKernelUtil { static void Forward(ep::Stream* stream, const int64_t n, const T* x, const T* y, T* out) { RUN_CUDA_KERNEL((GreaterInplacForwardGpu), stream, n, n, x, y, out); } }; template struct ScalarGreaterInplaceKernelUtil { static void Forward(ep::Stream* stream, const int64_t n, const T* x, const Scalar operand, T* out) { RUN_CUDA_KERNEL((ScalarGreaterInplacForwardGpu), stream, n, n, x, operand, out); } }; #define INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CUDA(data_type, other) \ template struct GreaterInplaceKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CUDA, GREATER_INPLACE_DATA_TYPE_SEQ_CUDA) #undef INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CUDA #define INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CUDA(data_type, value_data_type) \ template struct ScalarGreaterInplaceKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CUDA, GREATER_INPLACE_DATA_TYPE_SEQ_CUDA, SCALAR_VALUE_DATA_TYPE_SEQ) #undef INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/greater_inplace_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_GREATER_INPLACE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_GREATER_INPLACE_KERNEL_UTIL_H_ #include "oneflow/core/common/scalar.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { template struct ScalarGreaterInplaceKernelUtil { static void Forward(ep::Stream* stream, const int64_t n, const T* x, const Scalar operand, T* out); }; template struct GreaterInplaceKernelUtil { static void Forward(ep::Stream* stream, const int64_t n, const T* x, const T* y, T* out); static void YBroadcastToX(ep::Stream* stream, const int64_t n, const T* x, const T* y, T* broadcast_y, const ShapeView& x_shape, const ShapeView& y_shape) { const int64_t x_ndim = x_shape.NumAxes(); const int64_t y_ndim = y_shape.NumAxes(); const int64_t num_prepend = x_ndim - y_ndim; std::vector prepend_shape(num_prepend, 1); std::vector broadcast_axes; for (int i = 0; i < y_ndim; ++i) { prepend_shape.emplace_back(y_shape.At(i)); } for (int i = 0; i < num_prepend; ++i) { broadcast_axes.emplace_back(i); } for (int i = num_prepend; i < prepend_shape.size(); ++i) { if (prepend_shape[i] != x_shape.At(i)) { if (prepend_shape[i] == 1) { broadcast_axes.emplace_back(i); } } } const Shape& reduced_shape = CreateReducedShapeOrOnesShape(x_shape, {broadcast_axes.begin(), broadcast_axes.end()}); NdarrayUtil::BroadcastTo(stream, XpuVarNdarray(x_shape, broadcast_y), XpuVarNdarray(reduced_shape, y)); } }; #define SCALAR_VALUE_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) \ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define GREATER_INPLACE_DATA_TYPE_SEQ_CPU \ FLOATING_DATA_TYPE_SEQ \ SIGNED_INT_DATA_TYPE_SEQ #ifdef WITH_CUDA #define GREATER_INPLACE_DATA_TYPE_SEQ_CUDA \ FLOATING_DATA_TYPE_SEQ \ SIGNED_INT_DATA_TYPE_SEQ \ HALF_DATA_TYPE_SEQ #endif } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_GREATER_INPLACE_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/grid_sample_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
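YBroadcastToX in the header above right-aligns y's shape against x's by prepending ones, then collects the axes over which y must be broadcast: the prepended axes plus any axis where y's extent is 1 while x's is larger, before handing off to NdarrayUtil::BroadcastTo. The standalone sketch below reproduces just that axis computation, with illustrative names.

#include <cstdint>
#include <cstdio>
#include <vector>

// Broadcast-axis computation in the spirit of YBroadcastToX above: y is
// right-aligned against x; prepended axes and axes where y has extent 1 (and
// x does not) are the ones to broadcast over.
std::vector<int> BroadcastAxes(const std::vector<int64_t>& x_shape,
                               const std::vector<int64_t>& y_shape) {
  const int num_prepend = static_cast<int>(x_shape.size() - y_shape.size());
  std::vector<int> axes;
  for (int i = 0; i < num_prepend; ++i) { axes.push_back(i); }
  for (size_t i = 0; i < y_shape.size(); ++i) {
    const size_t x_axis = i + num_prepend;
    if (y_shape[i] == 1 && x_shape[x_axis] != 1) {
      axes.push_back(static_cast<int>(x_axis));
    }
  }
  return axes;
}

int main() {
  // x: (2, 3, 4), y: (3, 1) -> broadcast over axes 0 and 2.
  for (int a : BroadcastAxes({2, 3, 4}, {3, 1})) { std::printf("%d ", a); }
  std::printf("\n");
}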
*/ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/framework/config_def.h" #include "grid_sample_kernel_util.h" namespace oneflow { template class GridSampleKernel final : public user_op::OpKernel { public: GridSampleKernel() = default; ~GridSampleKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); const std::string interpolation_mode = ctx->Attr("interpolation_mode"); const std::string padding_mode = ctx->Attr("padding_mode"); GridSamplerInterpolation interpolation = StringToGridSamplerInterpolation(interpolation_mode); GridSamplerPadding padding = StringToGridGridSamplerPadding(padding_mode); const bool align_corners = ctx->Attr("align_corners"); const ShapeView& input_shape = input->shape_view(); const ShapeView& grid_shape = grid->shape_view(); const ShapeView& output_shape = output->shape_view(); int64_t count = output_shape.elem_cnt() / input_shape.At(1); if (input_shape.NumAxes() == 4) { if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) { GridSampleKernelUtil::Forward4D( ctx, input, grid, output, interpolation, padding, align_corners, input_shape, grid_shape, output_shape, count); } else { GridSampleKernelUtil::Forward4D( ctx, input, grid, output, interpolation, padding, align_corners, input_shape, grid_shape, output_shape, count); } } else { if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) { GridSampleKernelUtil::Forward5D( ctx, input, grid, output, interpolation, padding, align_corners, input_shape, grid_shape, output_shape, count); } else { GridSampleKernelUtil::Forward5D( ctx, input, grid, output, interpolation, padding, align_corners, input_shape, grid_shape, output_shape, count); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GRID_SAMPLE_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("grid_sample") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) REGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCPU, float); REGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCPU, double); #ifdef WITH_CUDA REGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCUDA, float); REGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCUDA, double); #endif template class GridSampleGradKernel final : public user_op::OpKernel { public: GridSampleGradKernel() = default; ~GridSampleGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* doutput = ctx->Tensor4ArgNameAndIndex("doutput", 0); const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); user_op::Tensor* dinput = ctx->Tensor4ArgNameAndIndex("dinput", 0); user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex("dgrid", 0); const std::string interpolation_mode = ctx->Attr("interpolation_mode"); const std::string padding_mode = ctx->Attr("padding_mode"); GridSamplerInterpolation interpolation = StringToGridSamplerInterpolation(interpolation_mode); GridSamplerPadding padding = StringToGridGridSamplerPadding(padding_mode); const bool align_corners = ctx->Attr("align_corners"); const ShapeView& input_shape = input->shape_view(); 
const ShapeView& grid_shape = grid->shape_view(); const ShapeView& output_shape = doutput->shape_view(); int64_t count = output_shape.elem_cnt() / input_shape.At(1); Memset(ctx->stream(), dinput->mut_dptr(), 0, input_shape.elem_cnt() * sizeof(data_type)); if (input_shape.NumAxes() == 4) { if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) { GridSampleKernelUtil::Backward4D( ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners, input_shape, grid_shape, output_shape, count); } else { GridSampleKernelUtil::Backward4D( ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners, input_shape, grid_shape, output_shape, count); } } else { if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) { GridSampleKernelUtil::Backward5D( ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners, input_shape, grid_shape, output_shape, count); } else { GridSampleKernelUtil::Backward5D( ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners, input_shape, grid_shape, output_shape, count); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GRID_SAMPLE_GRAD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("grid_sample_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) REGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCPU, float); REGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCPU, double); #ifdef WITH_CUDA REGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCUDA, float); REGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCUDA, double); #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/grid_sample_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
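GridSampleKernel above (like the gather CUDA util earlier in this section) switches between 32-bit and 64-bit indexing based on whether the tensors involved stay below the int32 range. CanUse32BitIndex itself is defined in a header not shown here, so its exact threshold is not visible; the predicate below is a compact sketch of that kind of check under the assumption that any element count at or above INT32_MAX forces 64-bit indexing.

#include <cstdint>
#include <cstdio>
#include <initializer_list>
#include <limits>

// Illustrative: decide whether 32-bit indexing is safe for a set of element
// counts, in the spirit of the CanUse32BitIndex checks used above.
bool FitsInt32Index(std::initializer_list<int64_t> elem_counts) {
  for (int64_t n : elem_counts) {
    if (n >= std::numeric_limits<int32_t>::max()) { return false; }
  }
  return true;
}

int main() {
  std::printf("%d\n", FitsInt32Index({int64_t{1} << 31, 1024}));  // 0: needs 64-bit
}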
*/ #include "grid_sample_kernel_util.h" namespace oneflow { template struct GridSampleKernelUtil final { static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* output, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { GridSampler4DKernel( count, input->dptr(), grid->dptr(), output->mut_dptr(), input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), output_shape.At(2), output_shape.At(3), interpolation, padding, align_corners); } static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* output, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { GridSampler5DKernel( count, input->dptr(), grid->dptr(), output->mut_dptr(), input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), input_shape.At(4), output_shape.At(2), output_shape.At(3), output_shape.At(4), interpolation, padding, align_corners); } static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* dinput, user_op::Tensor* dgrid, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { GridSampler4DBackwardKernel( count, doutput->dptr(), input->dptr(), grid->dptr(), dinput->mut_dptr(), dgrid->mut_dptr(), input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), output_shape.At(2), output_shape.At(3), interpolation, padding, align_corners, input_shape.elem_cnt()); } static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* dinput, user_op::Tensor* dgrid, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { GridSampler5DBackwardKernel( count, doutput->dptr(), input->dptr(), grid->dptr(), dinput->mut_dptr(), dgrid->mut_dptr(), input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), input_shape.At(4), output_shape.At(2), output_shape.At(3), output_shape.At(4), interpolation, padding, align_corners, input_shape.elem_cnt()); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL, (DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/grid_sample_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "grid_sample_kernel_util.h" namespace oneflow { class CudnnGridSampleDesc final { public: OF_DISALLOW_COPY_AND_MOVE(CudnnGridSampleDesc); CudnnGridSampleDesc(DataType data_type, const ShapeView& shape) { std::vector tensor_dim({shape.ptr(), shape.ptr() + shape.NumAxes()}); OF_CUDNN_CHECK(cudnnCreateSpatialTransformerDescriptor(&val_)); OF_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(val_, CUDNN_SAMPLER_BILINEAR, GetCudnnDataType(data_type), shape.NumAxes(), tensor_dim.data())); } ~CudnnGridSampleDesc() { OF_CUDNN_CHECK(cudnnDestroySpatialTransformerDescriptor(val_)); } const cudnnSpatialTransformerDescriptor_t& Get() const { return val_; } private: cudnnSpatialTransformerDescriptor_t val_; }; template struct CudnnGridSampleKernelUtil { static bool CanRunWithCudnn(user_op::KernelComputeContext* ctx) { if (ctx->Attr("interpolation_mode") != "bilinear" || ctx->Attr("padding_mode") != "zeros" || !ctx->Attr("align_corners")) { return false; } const ShapeView& input_shape = ctx->Tensor4ArgNameAndIndex("input", 0)->shape_view(); if (input_shape.NumAxes() != 4 || input_shape.At(1) > 1024) { return false; } return true; } static void ForwardCompute(user_op::KernelComputeContext* ctx) { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); const ShapeView& input_shape = input->shape_view(); const ShapeView& output_shape = output->shape_view(); const DataType dtype = input->data_type(); CudnnTensorDesc input_desc(dtype, input_shape, "channels_first"); CudnnTensorDesc output_desc(dtype, output_shape, "channels_first"); CudnnGridSampleDesc transfomer_desc(dtype, output_shape); OF_CUDNN_CHECK(cudnnSpatialTfSamplerForward( ctx->stream()->As()->cudnn_handle(), transfomer_desc.Get(), CudnnSPOnePtr(), input_desc.Get(), input->dptr(), grid->dptr(), CudnnSPZeroPtr(), output_desc.Get(), output->mut_dptr())); } static void BackwardCompute(user_op::KernelComputeContext* ctx) { const user_op::Tensor* doutput = ctx->Tensor4ArgNameAndIndex("doutput", 0); const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); user_op::Tensor* dinput = ctx->Tensor4ArgNameAndIndex("dinput", 0); user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex("dgrid", 0); const ShapeView& input_shape = input->shape_view(); const ShapeView& output_shape = doutput->shape_view(); const ShapeView& dinput_shape = dinput->shape_view(); const DataType dtype = input->data_type(); CudnnTensorDesc input_desc(dtype, input_shape, "channels_first"); CudnnTensorDesc output_desc(dtype, output_shape, "channels_first"); CudnnTensorDesc dinput_desc(dtype, dinput_shape, "channels_first"); CudnnGridSampleDesc transfomer_desc(dtype, output_shape); OF_CUDNN_CHECK(cudnnSpatialTfSamplerBackward( ctx->stream()->As()->cudnn_handle(), transfomer_desc.Get(), CudnnSPOnePtr(), input_desc.Get(), input->dptr(), CudnnSPZeroPtr(), dinput_desc.Get(), dinput->mut_dptr(), CudnnSPOnePtr(), output_desc.Get(), doutput->dptr(), grid->dptr(), CudnnSPZeroPtr(), dgrid->mut_dptr())); } }; template __launch_bounds__(256) __global__ void CUDAGridSampler4DKernel(const index_type nthreads, const data_type* 
input_ptr, const data_type* grid_ptr, data_type* output_ptr, index_type N, index_type C, index_type inp_H, index_type inp_W, index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, const bool align_corners) { GridSampler4DKernel(nthreads, input_ptr, grid_ptr, output_ptr, N, C, inp_H, inp_W, out_H, out_W, interpolation_mode, padding_mode, align_corners); } template __launch_bounds__(512) __global__ void CUDAGridSampler5DKernel(const index_type nthreads, const data_type* input_ptr, const data_type* grid_ptr, data_type* output_ptr, index_type N, index_type C, index_type inp_D, index_type inp_H, index_type inp_W, index_type out_D, index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, const bool align_corners) { GridSampler5DKernel(nthreads, input_ptr, grid_ptr, output_ptr, N, C, inp_D, inp_H, inp_W, out_D, out_H, out_W, interpolation_mode, padding_mode, align_corners); } template __launch_bounds__(256) __global__ void CUDAGridSampler4DBackwardKernel( const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr, const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N, index_type C, index_type inp_H, index_type inp_W, index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, const bool align_corners, const index_type grad_input_memory_span) { GridSampler4DBackwardKernel(nthreads, grad_output_ptr, input_ptr, grid_ptr, grad_input_ptr, grad_grid_ptr, N, C, inp_H, inp_W, out_H, out_W, interpolation_mode, padding_mode, align_corners, grad_input_memory_span); } template __launch_bounds__(256) __global__ void CUDAGridSampler5DBackwardKernel( const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr, const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N, index_type C, index_type inp_D, index_type inp_H, index_type inp_W, index_type out_D, index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, const bool align_corners, const index_type grad_input_memory_span) { GridSampler5DBackwardKernel(nthreads, grad_output_ptr, input_ptr, grid_ptr, grad_input_ptr, grad_grid_ptr, N, C, inp_D, inp_H, inp_W, out_D, out_H, out_W, interpolation_mode, padding_mode, align_corners, grad_input_memory_span); } template struct GridSampleKernelUtil final { static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* output, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { if (CudnnGridSampleKernelUtil::CanRunWithCudnn(ctx) && CanUse32BitIndex({input_shape, grid_shape, output_shape})) { return CudnnGridSampleKernelUtil::ForwardCompute(ctx); } CUDAGridSampler4DKernel <<stream()->As()->cuda_stream()>>>( count, input->dptr(), grid->dptr(), output->mut_dptr(), input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), output_shape.At(2), output_shape.At(3), interpolation, padding, align_corners); } static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* output, GridSamplerInterpolation interpolation, 
GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { CUDAGridSampler5DKernel <<stream()->As()->cuda_stream()>>>( count, input->dptr(), grid->dptr(), output->mut_dptr(), input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), input_shape.At(4), output_shape.At(2), output_shape.At(3), output_shape.At(4), interpolation, padding, align_corners); } static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* dinput, user_op::Tensor* dgrid, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { if (CudnnGridSampleKernelUtil::CanRunWithCudnn(ctx) && CanUse32BitIndex({input_shape, grid_shape, output_shape})) { return CudnnGridSampleKernelUtil::BackwardCompute(ctx); } CUDAGridSampler4DBackwardKernel <<stream()->As()->cuda_stream()>>>( count, doutput->dptr(), input->dptr(), grid->dptr(), dinput->mut_dptr(), dgrid->mut_dptr(), input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), output_shape.At(2), output_shape.At(3), interpolation, padding, align_corners, input_shape.elem_cnt()); } static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* dinput, user_op::Tensor* dgrid, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { CUDAGridSampler5DBackwardKernel <<stream()->As()->cuda_stream()>>>( count, doutput->dptr(), input->dptr(), grid->dptr(), dinput->mut_dptr(), dgrid->mut_dptr(), input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), input_shape.At(4), output_shape.At(2), output_shape.At(3), output_shape.At(4), interpolation, padding, align_corners, input_shape.elem_cnt()); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL, (DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/grid_sample_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
 */
#ifndef ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_
#define ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_

#include "oneflow/core/common/shape_view.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/op_kernel.h"
#include "oneflow/core/ndarray/xpu_util.h"
#include "oneflow/user/kernels/clip_by_value_kernel.h"
#ifdef WITH_CUDA
#include "oneflow/core/cuda/atomic.cuh"
#endif  // WITH_CUDA

namespace oneflow {

enum class GridSamplerInterpolation { kBilinear = 0, kNearest, kBicubic };

enum class GridSamplerPadding { kZeros = 0, kBorder, kReflection };

static GridSamplerInterpolation StringToGridSamplerInterpolation(const std::string& mode) {
  if (mode == "bilinear") {
    return GridSamplerInterpolation::kBilinear;
  } else if (mode == "nearest") {
    return GridSamplerInterpolation::kNearest;
  }
  return GridSamplerInterpolation::kBicubic;
}

static GridSamplerPadding StringToGridGridSamplerPadding(const std::string& mode) {
  if (mode == "zeros") {
    return GridSamplerPadding::kZeros;
  } else if (mode == "border") {
    return GridSamplerPadding::kBorder;
  }
  return GridSamplerPadding::kReflection;
}

static bool CanUse32BitIndex(const std::initializer_list<ShapeView>& shapes) {
  for (const auto& shape : shapes) {
    if (shape.elem_cnt() >= std::numeric_limits<int32_t>::max()) { return false; }
  }
  return true;
}

inline int GridSampleGetBlocks(const int64_t number, const int64_t threads_per_block) {
  // Round-up division for a positive number that cannot cause integer overflow.
  auto block_num = (number - 1) / threads_per_block + 1;
  return static_cast<int>(block_num);
}

// This kernel implementation is referenced from:
// https://github.com/pytorch/pytorch with git commit id: e7724bb
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cu
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cuh

// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
//     -1 --> 0
//     +1 --> (size - 1)
//     scale_factor = (size - 1) / 2
// if not align_corners: -1 and +1 get sent to the image edges
//     -1 --> -0.5
//     +1 --> (size - 1) + 0.5 == size - 0.5
//     scale_factor = size / 2
template<typename scalar_t>
static OF_DEVICE_FUNC scalar_t GridSamplerUnnormalize(scalar_t coord, int size,
                                                      bool align_corners) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    return ((coord + 1.f) * size - 1) / 2;
  }
}

// GridSamplerUnnormalizeSetGrad works the same as GridSamplerUnnormalize
// except that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
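//
// A quick numeric illustration of the two conventions above (worked example, for size = 4):
//   align_corners == true : -1 --> 0.0,  +1 --> 3.0,  d(unnormalized)/d(coord) = (4 - 1) / 2 = 1.5
//   align_corners == false: -1 --> -0.5, +1 --> 3.5,  d(unnormalized)/d(coord) = 4 / 2 = 2.0
// GridSamplerUnnormalizeSetGrad below returns the unnormalized coordinate and stores that
// constant scale factor into *grad_in.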
template<typename scalar_t>
static OF_DEVICE_FUNC scalar_t GridSamplerUnnormalizeSetGrad(scalar_t coord, int size,
                                                             bool align_corners,
                                                             scalar_t* grad_in) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    *grad_in = static_cast<scalar_t>(size - 1) / 2;
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    *grad_in = static_cast<scalar_t>(size) / 2;
    return ((coord + 1.f) * size - 1) / 2;
  }
}

// Clips coordinates to between 0 and clip_limit - 1
template<typename scalar_t>
static OF_DEVICE_FUNC scalar_t ClipCoordinates(scalar_t in, int clip_limit) {
  return DeviceMin(static_cast<scalar_t>(clip_limit - 1), DeviceMax(in, static_cast<scalar_t>(0)));
}

// ClipCoordinatesSetGrad works similarly to ClipCoordinates except that
// it also returns the `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template<typename scalar_t>
static OF_DEVICE_FUNC scalar_t ClipCoordinatesSetGrad(scalar_t in, int clip_limit,
                                                      scalar_t* grad_in) {
  // Note that it is important for the gradient calculation that borders
  // are considered out of bounds.
  if (in <= static_cast<scalar_t>(0)) {
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  } else {
    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
    if (in >= max) {
      *grad_in = static_cast<scalar_t>(0);
      return max;
    } else {
      *grad_in = static_cast<scalar_t>(1);
      return in;
    }
  }
}

// Reflects coordinates until they fall between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer values
// can be represented as ints.
template<typename scalar_t>
static OF_DEVICE_FUNC scalar_t ReflectCoordinates(scalar_t in, int twice_low, int twice_high) {
  if (twice_low == twice_high) { return static_cast<scalar_t>(0); }
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = fabs(in - min);
  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
  scalar_t extra = fmod(in, span);
  int flips = static_cast<int>(floor(in / span));
  if (flips % 2 == 0) {
    return extra + min;
  } else {
    return span - extra + min;
  }
}

// ReflectCoordinatesSetGrad works similarly to ReflectCoordinates except
// that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template<typename scalar_t>
static OF_DEVICE_FUNC scalar_t ReflectCoordinatesSetGrad(scalar_t in, int twice_low,
                                                         int twice_high, scalar_t* grad_in) {
  if (twice_low == twice_high) {
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  }
  int grad_in_mult_ = 1;
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = in - min;
  if (in < static_cast<scalar_t>(0)) {
    grad_in_mult_ = -1;
    in = -in;
  } else {
    grad_in_mult_ = 1;
  }
  // `fmod` returns same sign as `in`, which is positive after the `if` above.
  scalar_t extra = fmod(in, span);
  int flips = static_cast<int>(floor(in / span));
  if (flips % 2 == 0) {
    *grad_in = static_cast<scalar_t>(grad_in_mult_);
    return extra + min;
  } else {
    *grad_in = static_cast<scalar_t>(-grad_in_mult_);
    return span - extra + min;
  }
}

#if defined(__CUDACC__)
template<typename scalar_t>
static __device__ __forceinline__ scalar_t safe_downgrade_to_int_range(scalar_t x) {
  // -100.0 does not have special meaning. This is just to make sure
  // it's not WithinBounds2D or WithinBounds3D, and does not cause
  // undefined behavior. See #35506.
// TODO(pei tingkuan): (explicit or implicit) type conversion from // INT_MAX - 1 to float(INT_MAX - 1) indeed changes value from // 2147483647 to 2147483648 and losses precision // Reference: https://stackoverflow.com/q/526070 if (x > static_cast(INT_MAX - 1) || x < INT_MIN || !isfinite(static_cast(x))) return static_cast(-100.0); return x; } #endif template static OF_DEVICE_FUNC scalar_t ComputeCoordinates(scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners) { if (padding_mode == GridSamplerPadding::kBorder) { // clip coordinates to image borders coord = ClipCoordinates(coord, size); } else if (padding_mode == GridSamplerPadding::kReflection) { // reflect coordinates by image borders if (align_corners) { coord = ReflectCoordinates(coord, 0, 2 * (size - 1)); } else { coord = ReflectCoordinates(coord, -1, 2 * size - 1); } // clip coordinates to image borders coord = ClipCoordinates(coord, size); } #if defined(__CUDACC__) coord = safe_downgrade_to_int_range(coord); #endif return coord; } // Computes the pixel source index value for a grid coordinate template static OF_DEVICE_FUNC scalar_t GridSamplerComputeSourceIndex(scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners) { coord = GridSamplerUnnormalize(coord, size, align_corners); coord = ComputeCoordinates(coord, size, padding_mode, align_corners); return coord; } // GridSamplerComputeSourceIndexSetGrad works similarly to // GridSamplerComputeSourceIndex except that it also returns the // `d output / d input` via pointer argument `grad_in`. // This is useful in the backward pass of grid_sampler. template static OF_DEVICE_FUNC scalar_t GridSamplerComputeSourceIndexSetGrad(scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners, scalar_t* grad_in) { scalar_t grad_clip, grad_refl; coord = GridSamplerUnnormalizeSetGrad(coord, size, align_corners, grad_in); if (padding_mode == GridSamplerPadding::kBorder) { // clip coordinates to image borders coord = ClipCoordinatesSetGrad(coord, size, &grad_clip); *grad_in = (*grad_in) * grad_clip; } else if (padding_mode == GridSamplerPadding::kReflection) { // reflect coordinates by image borders if (align_corners) { coord = ReflectCoordinatesSetGrad(coord, 0, 2 * (size - 1), &grad_refl); } else { coord = ReflectCoordinatesSetGrad(coord, -1, 2 * size - 1, &grad_refl); } // clip coordinates to image borders coord = ClipCoordinatesSetGrad(coord, size, &grad_clip); *grad_in = (*grad_in) * grad_refl * grad_clip; } #if defined(__CUDACC__) coord = safe_downgrade_to_int_range(coord); #endif return coord; } static OF_DEVICE_FUNC bool WithinBounds2D(int h, int w, int H, int W) { return h >= 0 && h < H && w >= 0 && w < W; } static OF_DEVICE_FUNC bool WithinBounds3D(int d, int h, int w, int D, int H, int W) { return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } template static OF_DEVICE_FUNC scalar_t GetValueBounded(const scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, GridSamplerPadding padding_mode, bool align_corners) { x = ComputeCoordinates(x, W, padding_mode, align_corners); y = ComputeCoordinates(y, H, padding_mode, align_corners); int ix = static_cast(x); int iy = static_cast(y); if (WithinBounds2D(iy, ix, H, W)) { return data[iy * sH + ix * sW]; } return static_cast(0); } template static OF_DEVICE_FUNC void SafeAdd2D(scalar_t* data, int h, int w, int sH, int sW, int H, int W, scalar_t delta, const index_t NC_offset, const index_t memory_span) { if (WithinBounds2D(h, w, H, W)) { #if defined(__CUDACC__) 
cuda::atomic::Add(data + NC_offset + h * sH + w * sW, delta); #else data[NC_offset + h * sH + w * sW] += delta; #endif } } template static OF_DEVICE_FUNC void SafeAdd3D(scalar_t* data, int d, int h, int w, int sD, int sH, int sW, int D, int H, int W, scalar_t delta, const index_t NC_offset, const index_t memory_span) { if (WithinBounds3D(d, h, w, D, H, W)) { #if defined(__CUDACC__) cuda::atomic::Add(data + NC_offset + d * sD + h * sH + w * sW, delta); #else data[NC_offset + d * sD + h * sH + w * sW] += delta; #endif } } template static OF_DEVICE_FUNC void AddValueBounded(scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, scalar_t delta, GridSamplerPadding padding_mode, bool align_corners, const index_t NC_offset, const index_t memory_span) { x = ComputeCoordinates(x, W, padding_mode, align_corners); y = ComputeCoordinates(y, H, padding_mode, align_corners); int ix = static_cast(x); int iy = static_cast(y); SafeAdd2D(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span); } // Calculate the differential of the cubic convolution, i.e. `d coeff / d x` template static OF_DEVICE_FUNC void GetCubicCoefficientsGrad(scalar_t coeffs[4], scalar_t t) { // Must be the same as forward calculation in // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients scalar_t A = -0.75; scalar_t x; x = -1 - t; // 1 < x = |-1 - tx| < 2 coeffs[0] = (-3 * A * x - 10 * A) * x - 8 * A; x = -t; // x = |0 - tx| <= 1 coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; x = 1 - t; // x = |1 - tx| <= 1 coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; x = 2 - t; // 1 < x = |2 - tx| < 2 coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; } // Based on // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm template OF_DEVICE_FUNC static accscalar_t CubicConvolution1(accscalar_t x, accscalar_t A) { return ((A + 2) * x - (A + 3)) * x * x + 1; } template OF_DEVICE_FUNC static accscalar_t CubicConvolution2(accscalar_t x, accscalar_t A) { return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; } template OF_DEVICE_FUNC static void GetCubicUpsamplingCoefficients(accscalar_t coeffs[4], accscalar_t t) { accscalar_t A = -0.75; accscalar_t x1 = t; coeffs[0] = CubicConvolution2(x1 + 1.0, A); coeffs[1] = CubicConvolution1(x1, A); // opposite coefficients accscalar_t x2 = 1.0 - t; coeffs[2] = CubicConvolution1(x2, A); coeffs[3] = CubicConvolution2(x2 + 1.0, A); } template OF_DEVICE_FUNC static accscalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, accscalar_t t) { accscalar_t coeffs[4]; GetCubicUpsamplingCoefficients(coeffs, t); return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; } template OF_DEVICE_FUNC void GridSampler4DKernel(const index_type nthreads, const data_type* input_ptr, const data_type* grid_ptr, data_type* output_ptr, index_type N, index_type C, index_type inp_H, index_type inp_W, index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, const bool align_corners) { index_type inp_sN = C * inp_H * inp_W; index_type inp_sC = inp_H * inp_W; index_type inp_sH = inp_W; index_type inp_sW = 1; index_type grid_sN = out_H * out_W * 2; index_type grid_sH = out_W * 2; index_type grid_sW = 2; index_type grid_sCoor = 1; index_type out_sN = C * out_H * out_W; index_type out_sC = out_H * out_W; index_type out_sH = out_W; index_type out_sW = 1; XPU_1D_KERNEL_LOOP(index, nthreads) { const index_type w = index % out_W; const index_type h = (index / out_W) % out_H; const 
index_type n = index / (out_H * out_W); const index_type grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid data_type x = grid_ptr[grid_offset]; data_type y = grid_ptr[grid_offset + grid_sCoor]; data_type ix = GridSamplerComputeSourceIndex(x, inp_W, padding_mode, align_corners); data_type iy = GridSamplerComputeSourceIndex(y, inp_H, padding_mode, align_corners); if (interpolation_mode == GridSamplerInterpolation::kBilinear) { // get NE, NW, SE, SW pixel values from (x, y) index_type ix_nw = static_cast(::floor(ix)); index_type iy_nw = static_cast(::floor(iy)); index_type ix_ne = ix_nw + 1; index_type iy_ne = iy_nw; index_type ix_sw = ix_nw; index_type iy_sw = iy_nw + 1; index_type ix_se = ix_nw + 1; index_type iy_se = iy_nw + 1; // get surfaces to each neighbor: data_type nw = (ix_se - ix) * (iy_se - iy); data_type ne = (ix - ix_sw) * (iy_sw - iy); data_type sw = (ix_ne - ix) * (iy - iy_ne); data_type se = (ix - ix_nw) * (iy - iy_nw); // calculate bilinear weighted pixel value and set output pixel auto inp_ptr_NC = input_ptr + n * inp_sN; auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW; for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { *out_ptr_NCHW = static_cast(0); if (WithinBounds2D(iy_nw, ix_nw, inp_H, inp_W)) { *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; } if (WithinBounds2D(iy_ne, ix_ne, inp_H, inp_W)) { *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; } if (WithinBounds2D(iy_sw, ix_sw, inp_H, inp_W)) { *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; } if (WithinBounds2D(iy_se, ix_se, inp_H, inp_W)) { *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; } } } else if (interpolation_mode == GridSamplerInterpolation::kNearest) { index_type ix_nearest = static_cast(::round(ix)); index_type iy_nearest = static_cast(::round(iy)); // assign nearest neighor pixel value to output pixel auto inp_ptr_NC = input_ptr + n * inp_sN; auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW; for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { if (WithinBounds2D(iy_nearest, ix_nearest, inp_H, inp_W)) { *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; } else { *out_ptr_NCHW = static_cast(0); } } } else if (interpolation_mode == GridSamplerInterpolation::kBicubic) { ix = GridSamplerUnnormalize(x, inp_W, align_corners); iy = GridSamplerUnnormalize(y, inp_H, align_corners); data_type ix_nw = ::floor(ix); data_type iy_nw = ::floor(iy); const data_type tx = ix - ix_nw; const data_type ty = iy - iy_nw; auto inp_ptr_NC = input_ptr + n * inp_sN; auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW; for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { data_type coefficients[4]; #ifdef __CUDA_ARCH__ #pragma unroll 4 #endif for (index_type i = 0; i < 4; ++i) { coefficients[i] = cubic_interp1d( GetValueBounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), GetValueBounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), GetValueBounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), GetValueBounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), tx); } *out_ptr_NCHW = cubic_interp1d(coefficients[0], coefficients[1], 
coefficients[2], coefficients[3], ty); } } } } template OF_DEVICE_FUNC void GridSampler5DKernel(const index_type nthreads, const data_type* input_ptr, const data_type* grid_ptr, data_type* output_ptr, index_type N, index_type C, index_type inp_D, index_type inp_H, index_type inp_W, index_type out_D, index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, const bool align_corners) { index_type inp_sN = C * inp_D * inp_H * inp_W; index_type inp_sC = inp_D * inp_H * inp_W; index_type inp_sD = inp_H * inp_W; index_type inp_sH = inp_W; index_type inp_sW = 1; index_type grid_sN = out_D * out_H * out_W * 3; index_type grid_sD = out_H * out_W * 3; index_type grid_sH = out_W * 3; index_type grid_sW = 3; index_type grid_sCoor = 1; index_type out_sN = C * out_D * out_H * out_W; index_type out_sC = out_D * out_H * out_W; index_type out_sD = out_H * out_W; index_type out_sH = out_W; index_type out_sW = 1; XPU_1D_KERNEL_LOOP(index, nthreads) { const index_type w = index % out_W; const index_type h = (index / out_W) % out_H; const index_type d = (index / (out_H * out_W)) % out_D; const index_type n = index / (out_D * out_H * out_W); const index_type grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; // get the corresponding input x, y, z co-ordinates from grid data_type ix = grid_ptr[grid_offset]; data_type iy = grid_ptr[grid_offset + grid_sCoor]; data_type iz = grid_ptr[grid_offset + 2 * grid_sCoor]; ix = GridSamplerComputeSourceIndex(ix, inp_W, padding_mode, align_corners); iy = GridSamplerComputeSourceIndex(iy, inp_H, padding_mode, align_corners); iz = GridSamplerComputeSourceIndex(iz, inp_D, padding_mode, align_corners); if (interpolation_mode == GridSamplerInterpolation::kBilinear) { // get corner pixel values from (x, y, z) // for 4d, we used north-east-south-west // for 5d, we add top-bottom index_type ix_tnw = static_cast(::floor(ix)); index_type iy_tnw = static_cast(::floor(iy)); index_type iz_tnw = static_cast(::floor(iz)); index_type ix_tne = ix_tnw + 1; index_type iy_tne = iy_tnw; index_type iz_tne = iz_tnw; index_type ix_tsw = ix_tnw; index_type iy_tsw = iy_tnw + 1; index_type iz_tsw = iz_tnw; index_type ix_tse = ix_tnw + 1; index_type iy_tse = iy_tnw + 1; index_type iz_tse = iz_tnw; index_type ix_bnw = ix_tnw; index_type iy_bnw = iy_tnw; index_type iz_bnw = iz_tnw + 1; index_type ix_bne = ix_tnw + 1; index_type iy_bne = iy_tnw; index_type iz_bne = iz_tnw + 1; index_type ix_bsw = ix_tnw; index_type iy_bsw = iy_tnw + 1; index_type iz_bsw = iz_tnw + 1; index_type ix_bse = ix_tnw + 1; index_type iy_bse = iy_tnw + 1; index_type iz_bse = iz_tnw + 1; // get surfaces to each neighbor: data_type tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); data_type tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); data_type tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); data_type tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); data_type bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); data_type bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); data_type bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); data_type bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); auto inp_ptr_NC = input_ptr + n * inp_sN; auto out_ptr_NCDHW = output_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, 
iy_tse, ix_tse) * tse // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse *out_ptr_NCDHW = static_cast(0); if (WithinBounds3D(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; } if (WithinBounds3D(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; } if (WithinBounds3D(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; } if (WithinBounds3D(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; } if (WithinBounds3D(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; } if (WithinBounds3D(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; } if (WithinBounds3D(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; } if (WithinBounds3D(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; } } } else if (interpolation_mode == GridSamplerInterpolation::kNearest) { index_type ix_nearest = static_cast(::round(ix)); index_type iy_nearest = static_cast(::round(iy)); index_type iz_nearest = static_cast(::round(iz)); // assign nearest neighor pixel value to output pixel auto inp_ptr_NC = input_ptr + n * inp_sN; auto out_ptr_NCDHW = output_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { if (WithinBounds3D(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; } else { *out_ptr_NCDHW = static_cast(0); } } } } } // Note [Passing pointer and offset to fastAtomicAdd] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // For its internal bounds checking, fastAtomicAdd needs to know where the destination address // lies relative to the entire tensor, so we pass the base grad_input_ptr and full offset // information, including batch * channel offset (NC_offset). 
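//
// Illustrative note on the offset layout used below: for an NCHW grad_input tensor, element
// (n, c, h, w) lives at linear offset
//     n * gInp_sN + c * gInp_sC + h * gInp_sH + w * gInp_sW.
// The backward kernels accumulate the n * gInp_sN + c * gInp_sC part into NC_offset while looping
// over channels, and SafeAdd2D / SafeAdd3D add the remaining spatial part before the (atomic, on
// CUDA) update; grad_input_memory_span is the total element count of the gradient tensor
// (input_shape.elem_cnt() at the call sites).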
template OF_DEVICE_FUNC void GridSampler4DBackwardKernel( const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr, const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N, index_type C, index_type inp_H, index_type inp_W, index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, const bool align_corners, const index_type grad_input_memory_span) { index_type inp_sN = C * inp_H * inp_W; index_type inp_sC = inp_H * inp_W; index_type inp_sH = inp_W; index_type inp_sW = 1; index_type grid_sN = out_H * out_W * 2; index_type grid_sH = out_W * 2; index_type grid_sW = 2; index_type grid_sCoor = 1; index_type gOut_sN = C * out_H * out_W; index_type gOut_sC = out_H * out_W; index_type gOut_sH = out_W; index_type gOut_sW = 1; index_type gInp_sN = inp_sN; index_type gInp_sC = inp_sC; index_type gInp_sH = inp_sH; index_type gInp_sW = inp_sW; index_type gGrid_sW = grid_sW; XPU_1D_KERNEL_LOOP(index, nthreads) { const index_type w = index % out_W; const index_type h = (index / out_W) % out_H; const index_type n = index / (out_H * out_W); const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid data_type x = grid_ptr[grid_offset]; data_type y = grid_ptr[grid_offset + grid_sCoor]; // multipliers for gradients on ix and iy data_type gix_mult, giy_mult; data_type ix = GridSamplerComputeSourceIndexSetGrad(x, inp_W, padding_mode, align_corners, &gix_mult); data_type iy = GridSamplerComputeSourceIndexSetGrad(y, inp_H, padding_mode, align_corners, &giy_mult); if (interpolation_mode == GridSamplerInterpolation::kBilinear) { // get NE, NW, SE, SW pixel values from (x, y) index_type ix_nw = static_cast(::floor(ix)); index_type iy_nw = static_cast(::floor(iy)); index_type ix_ne = ix_nw + 1; index_type iy_ne = iy_nw; index_type ix_sw = ix_nw; index_type iy_sw = iy_nw + 1; index_type ix_se = ix_nw + 1; index_type iy_se = iy_nw + 1; // get surfaces to each neighbor: data_type nw = (ix_se - ix) * (iy_se - iy); data_type ne = (ix - ix_sw) * (iy_sw - iy); data_type sw = (ix_ne - ix) * (iy - iy_ne); data_type se = (ix - ix_nw) * (iy - iy_nw); data_type gix = static_cast(0), giy = static_cast(0); const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; index_type NC_offset = n * gInp_sN; const data_type* inp_ptr_NC = input_ptr + n * inp_sN; for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) { data_type gOut = *gOut_ptr_NCHW; // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. 
SafeAdd2D(grad_input_ptr, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut, NC_offset, grad_input_memory_span); SafeAdd2D(grad_input_ptr, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut, NC_offset, grad_input_memory_span); SafeAdd2D(grad_input_ptr, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut, NC_offset, grad_input_memory_span); SafeAdd2D(grad_input_ptr, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut, NC_offset, grad_input_memory_span); // calculate grad_grid if (WithinBounds2D(iy_nw, ix_nw, inp_H, inp_W)) { data_type nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]; gix -= nw_val * (iy_se - iy) * gOut; giy -= nw_val * (ix_se - ix) * gOut; } if (WithinBounds2D(iy_ne, ix_ne, inp_H, inp_W)) { data_type ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]; gix += ne_val * (iy_sw - iy) * gOut; giy -= ne_val * (ix - ix_sw) * gOut; } if (WithinBounds2D(iy_sw, ix_sw, inp_H, inp_W)) { data_type sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]; gix -= sw_val * (iy - iy_ne) * gOut; giy += sw_val * (ix_ne - ix) * gOut; } if (WithinBounds2D(iy_se, ix_se, inp_H, inp_W)) { data_type se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]; gix += se_val * (iy - iy_nw) * gOut; giy += se_val * (ix - ix_nw) * gOut; } } // assuming grad_grid is contiguous // thus we can // 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW // 2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1] data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW; gGrid_ptr_NHW[0] = gix_mult * gix; gGrid_ptr_NHW[1] = giy_mult * giy; } else if (interpolation_mode == GridSamplerInterpolation::kNearest) { index_type ix_nearest = static_cast(::round(ix)); index_type iy_nearest = static_cast(::round(iy)); // assign nearest neighor pixel value to output pixel const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; index_type NC_offset = n * gInp_sN; for (index_type c = 0; c < C; ++c, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) { // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. SafeAdd2D(grad_input_ptr, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW, NC_offset, grad_input_memory_span); } // assuming grad_grid is contiguous // thus we can // 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW // 2. 
directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1] data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW; gGrid_ptr_NHW[0] = static_cast(0); gGrid_ptr_NHW[1] = static_cast(0); } else if (interpolation_mode == GridSamplerInterpolation::kBicubic) { ix = GridSamplerUnnormalizeSetGrad(x, inp_W, align_corners, &gix_mult); iy = GridSamplerUnnormalizeSetGrad(y, inp_H, align_corners, &giy_mult); data_type ix_nw = ::floor(ix); data_type iy_nw = ::floor(iy); const data_type tx = ix - ix_nw; const data_type ty = iy - iy_nw; data_type x_coeffs[4]; data_type y_coeffs[4]; data_type x_coeffs_grad[4]; data_type y_coeffs_grad[4]; GetCubicUpsamplingCoefficients(x_coeffs, tx); GetCubicUpsamplingCoefficients(y_coeffs, ty); GetCubicCoefficientsGrad(x_coeffs_grad, tx); GetCubicCoefficientsGrad(y_coeffs_grad, ty); data_type gix = static_cast(0); data_type giy = static_cast(0); const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; index_type NC_offset = n * gInp_sN; const data_type* inp_ptr_NC = input_ptr + n * inp_sN; for (index_type c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) { data_type gOut = *gOut_ptr_NCHW; #ifdef __CUDA_ARCH__ #pragma unroll 4 #endif for (index_type i = 0; i < 4; ++i) { #ifdef __CUDA_ARCH__ #pragma unroll 4 #endif for (index_type j = 0; j < 4; ++j) { // set input gradient. See Note [Passing pointer and offset to fastAtomicAdd]. AddValueBounded(grad_input_ptr, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H, gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners, NC_offset, grad_input_memory_span); // set grid gradient data_type val = GetValueBounded(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners); gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; } } } data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW; gGrid_ptr_NHW[0] = gix_mult * gix; gGrid_ptr_NHW[1] = giy_mult * giy; } } } template OF_DEVICE_FUNC void GridSampler5DBackwardKernel( const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr, const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N, index_type C, index_type inp_D, index_type inp_H, index_type inp_W, index_type out_D, index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, const bool align_corners, const index_type grad_input_memory_span) { index_type inp_sN = C * inp_D * inp_H * inp_W; index_type inp_sC = inp_D * inp_H * inp_W; index_type inp_sD = inp_H * inp_W; index_type inp_sH = inp_W; index_type inp_sW = 1; index_type grid_sN = out_D * out_H * out_W * 3; index_type grid_sD = out_H * out_W * 3; index_type grid_sH = out_W * 3; index_type grid_sW = 3; index_type grid_sCoor = 1; index_type gOut_sN = C * out_D * out_H * out_W; index_type gOut_sC = out_D * out_H * out_W; index_type gOut_sD = out_H * out_W; index_type gOut_sH = out_W; index_type gOut_sW = 1; index_type gInp_sN = inp_sN; index_type gInp_sC = inp_sC; index_type gInp_sD = inp_sD; index_type gInp_sH = inp_sH; index_type gInp_sW = inp_sW; index_type gGrid_sW = grid_sW; XPU_1D_KERNEL_LOOP(index, nthreads) { const index_type w = index % out_W; const index_type h = (index / out_W) % out_H; const index_type d = (index / (out_H * out_W)) % out_D; const index_type n = index / (out_D * out_H * out_W); const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH 
+ w * grid_sW; // get the corresponding input x, y, z co-ordinates from grid data_type ix = grid_ptr[grid_offset]; data_type iy = grid_ptr[grid_offset + grid_sCoor]; data_type iz = grid_ptr[grid_offset + 2 * grid_sCoor]; // multipliers for gradients on ix, iy, and iz data_type gix_mult, giy_mult, giz_mult; ix = GridSamplerComputeSourceIndexSetGrad(ix, inp_W, padding_mode, align_corners, &gix_mult); iy = GridSamplerComputeSourceIndexSetGrad(iy, inp_H, padding_mode, align_corners, &giy_mult); iz = GridSamplerComputeSourceIndexSetGrad(iz, inp_D, padding_mode, align_corners, &giz_mult); if (interpolation_mode == GridSamplerInterpolation::kBilinear) { // get corner pixel values from (x, y, z) // for 4d, we used north-east-south-west // for 5d, we add top-bottom index_type ix_tnw = static_cast(::floor(ix)); index_type iy_tnw = static_cast(::floor(iy)); index_type iz_tnw = static_cast(::floor(iz)); index_type ix_tne = ix_tnw + 1; index_type iy_tne = iy_tnw; index_type iz_tne = iz_tnw; index_type ix_tsw = ix_tnw; index_type iy_tsw = iy_tnw + 1; index_type iz_tsw = iz_tnw; index_type ix_tse = ix_tnw + 1; index_type iy_tse = iy_tnw + 1; index_type iz_tse = iz_tnw; index_type ix_bnw = ix_tnw; index_type iy_bnw = iy_tnw; index_type iz_bnw = iz_tnw + 1; index_type ix_bne = ix_tnw + 1; index_type iy_bne = iy_tnw; index_type iz_bne = iz_tnw + 1; index_type ix_bsw = ix_tnw; index_type iy_bsw = iy_tnw + 1; index_type iz_bsw = iz_tnw + 1; index_type ix_bse = ix_tnw + 1; index_type iy_bse = iy_tnw + 1; index_type iz_bse = iz_tnw + 1; // get surfaces to each neighbor: data_type tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); data_type tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); data_type tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); data_type tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); data_type bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); data_type bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); data_type bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); data_type bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); data_type gix = static_cast(0), giy = static_cast(0), giz = static_cast(0); const data_type* gOut_ptr_NCDHW = grad_output_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; index_type NC_offset = n * gInp_sN; const data_type* inp_ptr_NC = input_ptr + n * inp_sN; // calculate bilinear weighted pixel value and set output pixel for (index_type c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) { data_type gOut = *gOut_ptr_NCDHW; // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. 
SafeAdd3D(grad_input_ptr, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut, NC_offset, grad_input_memory_span); SafeAdd3D(grad_input_ptr, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut, NC_offset, grad_input_memory_span); SafeAdd3D(grad_input_ptr, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut, NC_offset, grad_input_memory_span); SafeAdd3D(grad_input_ptr, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut, NC_offset, grad_input_memory_span); SafeAdd3D(grad_input_ptr, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut, NC_offset, grad_input_memory_span); SafeAdd3D(grad_input_ptr, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut, NC_offset, grad_input_memory_span); SafeAdd3D(grad_input_ptr, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut, NC_offset, grad_input_memory_span); SafeAdd3D(grad_input_ptr, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut, NC_offset, grad_input_memory_span); // calculate grad_grid if (WithinBounds3D(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { data_type tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; } if (WithinBounds3D(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { data_type tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; } if (WithinBounds3D(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { data_type tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; } if (WithinBounds3D(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { data_type tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; } if (WithinBounds3D(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { data_type bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; } if (WithinBounds3D(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { data_type bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; } if (WithinBounds3D(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { data_type bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; } if (WithinBounds3D(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { data_type bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * 
inp_sH + ix_bse * inp_sW]; gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; } } // assuming grad_grid is contiguous // thus we can // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] data_type* gGrid_ptr_NDHW = grad_grid_ptr + index * gGrid_sW; gGrid_ptr_NDHW[0] = gix_mult * gix; gGrid_ptr_NDHW[1] = giy_mult * giy; gGrid_ptr_NDHW[2] = giz_mult * giz; } else if (interpolation_mode == GridSamplerInterpolation::kNearest) { auto ix_nearest = static_cast(::round(ix)); auto iy_nearest = static_cast(::round(iy)); auto iz_nearest = static_cast(::round(iz)); // assign nearest neighor pixel value to output pixel const data_type* gOut_ptr_NCDHW = grad_output_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; index_type NC_offset = n * gInp_sN; for (index_type c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC) { // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. SafeAdd3D(grad_input_ptr, iz_nearest, iy_nearest, ix_nearest, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW, NC_offset, grad_input_memory_span); } // assuming grad_grid is contiguous // thus we can // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] data_type* gGrid_ptr_NDHW = grad_grid_ptr + index * gGrid_sW; gGrid_ptr_NDHW[0] = static_cast(0); gGrid_ptr_NDHW[1] = static_cast(0); gGrid_ptr_NDHW[2] = static_cast(0); } } } template struct GridSampleKernelUtil final { static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* output, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count); static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* output, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count); static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* dinput, user_op::Tensor* dgrid, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count); static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, const user_op::Tensor* input, const user_op::Tensor* grid, user_op::Tensor* dinput, user_op::Tensor* dgrid, GridSamplerInterpolation interpolation, GridSamplerPadding padding, const bool align_corners, const ShapeView& input_shape, const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count); }; // macros for functors instantiate(used by grid_sample_kernel_util.cu, grid_sample_kernel_util.cpp) #define INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL(device_type, dtype_pair, itype_pair) \ template struct GridSampleKernelUtil; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_ ================================================ 
FILE: oneflow/user/kernels/group_conv_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewChannelsFirstMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/false); } auto ChannelsFirstMatmulPrimitiveExists() { return hob::make_custom("ChannelsFirstMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewChannelsFirstMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewChannelsLastMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/true); } auto ChannelsLastMatmulPrimitiveExists() { return hob::make_custom("ChannelsLastMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewChannelsLastMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvDataGradTransATransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/true); } auto ConvDataGradTransATransBMatmulPrimitiveExists() { return hob::make_custom("ConvDataGradTransATransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvDataGradTransATransBMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvDataGradTransANoTransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/false); } auto ConvDataGradTransANoTransBMatmulPrimitiveExists() { return hob::make_custom( "ConvDataGradTransANoTransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvDataGradTransANoTransBMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvWeightGradTransATransBMatmulPrimitive(Context* ctx) { const DataType data_type = 
ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/true); } auto ConvWeightGradTransATransBMatmulPrimitiveExists() { return hob::make_custom( "ConvWeightGradTransATransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvWeightGradTransATransBMatmulPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewConvWeightGradNoTransATransBMatmulPrimitive( Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/true); } auto ConvWeightGradNoTransATransBMatmulPrimitiveExists() { return hob::make_custom( "ConvWeightGradNoTransATransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewConvWeightGradNoTransATransBMatmulPrimitive(&ctx).operator bool(); }); } template using Im2ColFunc = void (*)(const T* in_dptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf); template using Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr); template T* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) { return tensor->mut_dptr() + tensor->shape_view().Count(1) * idx; } template const T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) { return tensor->dptr() + tensor->shape_view().Count(1) * idx; } size_t CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape, const int32_t idx_offset) { int64_t col_buf_elem_cnt = 1; int64_t ndims = out_shape.NumAxes() - 2; for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); } for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); } return col_buf_elem_cnt; } template class ColBufWriter { public: ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : src_ptr_(src_ptr), dst_ptr_(dst_ptr), c_size_(c_size), id_size_(id_size), ih_size_(ih_size), iw_size_(iw_size), od_size_(od_size), oh_size_(oh_size), ow_size_(ow_size) {} virtual ~ColBufWriter() = default; virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; virtual void InvalidDFunc() = 0; virtual void InvalidHFunc() = 0; virtual void InvalidWFunc() = 0; virtual void NextImCSize() = 0; protected: const T* src_ptr_; T* dst_ptr_; int64_t c_size_ = 0; int64_t id_size_ = 0; int64_t ih_size_ = 0; int64_t iw_size_ = 0; int64_t od_size_ = 0; int64_t oh_size_ = 0; int64_t ow_size_ = 0; }; template class Im2ColWriter final : public ColBufWriter { public: Im2ColWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : ColBufWriter::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size, oh_size, ow_size) {} ~Im2ColWriter() = default; void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { *(this->dst_ptr_++) = this->src_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c]; } void 
CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { *(this->dst_ptr_++) = this->src_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw]; } void InvalidDFunc() override { FOR_RANGE(int64_t, i, 0, this->od_size_) { *(this->dst_ptr_++) = 0; } } void InvalidHFunc() override { FOR_RANGE(int64_t, i, 0, this->oh_size_) { *(this->dst_ptr_++) = 0; } } void InvalidWFunc() override { FOR_RANGE(int64_t, i, 0, this->ow_size_) { *(this->dst_ptr_++) = 0; } } void NextImCSize() override { this->src_ptr_ += this->c_size_; } }; template class Col2ImWriter final : public ColBufWriter { public: Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : ColBufWriter::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size, oh_size, ow_size) {} ~Col2ImWriter() = default; void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] += *(this->src_ptr_++); } void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++); } void InvalidDFunc() override { this->src_ptr_ += this->od_size_; } void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; } void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; } void NextImCSize() override { this->dst_ptr_ += this->c_size_; } }; template using DHWValidFunc = void (ColBufWriter::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw); template class ColBufUtil final { public: ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, const int32_t id_num, const int32_t ih_num, const int32_t iw_num, const int32_t od_num, const int32_t oh_num, const int32_t ow_num) : strides_(strides), dilation_rate_(dilation_rate), padding_before_(padding_before), id_num_(id_num), ih_num_(ih_num), iw_num_(iw_num), od_num_(od_num), oh_num_(oh_num), ow_num_(ow_num) { if (dhw_offset == 2) { dhw_valid_func_ = &ColBufWriter::CDHWWrite; } else { dhw_valid_func_ = &ColBufWriter::DHWCWrite; } } void operator()(ColBufWriter* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) { int64_t id = kd * dilation_rate_[0] - padding_before_[0]; FOR_RANGE(int64_t, od, 0, od_num_) { if (id < 0 || id >= id_num_) { col_buf_writer->InvalidDFunc(); } else { int64_t ih = kh * dilation_rate_[1] - padding_before_[1]; FOR_RANGE(int64_t, oh, 0, oh_num_) { if (ih < 0 || ih >= ih_num_) { col_buf_writer->InvalidHFunc(); } else { int64_t iw = kw * dilation_rate_[2] - padding_before_[2]; FOR_RANGE(int64_t, ow, 0, ow_num_) { if (iw < 0 || iw >= iw_num_) { col_buf_writer->InvalidWFunc(); } else { (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw); } iw += strides_[2]; } } ih += strides_[1]; } } id += strides_[0]; } } private: const int32_t* strides_; const int32_t* dilation_rate_; const int32_t* padding_before_; DHWValidFunc dhw_valid_func_; int64_t id_num_; int64_t ih_num_; int64_t iw_num_; int64_t od_num_; int64_t oh_num_; int64_t ow_num_; }; template struct ConvKernelUtil final { public: static void NCDHWIm2Col(const T* in_dptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 2, 
strides, dilation_rate, padding_before, in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2), out_shape.At(3), out_shape.At(4)); Im2ColWriter col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer); } static void NDHWCIm2Col(const T* in_dptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before, in_shape.At(1), in_shape.At(2), in_shape.At(3), out_shape.At(1), out_shape.At(2), out_shape.At(3)); Im2ColWriter col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), out_shape.Count(3, 4), 1); DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer); } static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before, in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2), out_shape.At(3), out_shape.At(4)); Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer); } static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before, in_shape.At(1), in_shape.At(2), in_shape.At(3), out_shape.At(1), out_shape.At(2), out_shape.At(3)); Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), out_shape.Count(3, 4), 1); DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer); } private: static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, ColBufWriter* col_buf_writer) { for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) { for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(4); ++kw) { col_buf_util(col_buf_writer, c, kd, kh, kw); } } } } } static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, ColBufWriter* col_buf_writer) { for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) { for (int64_t c = 0; c != weight_shape.At(4); ++c) { col_buf_util(col_buf_writer, c, kd, kh, kw); } } } } } }; template struct ConvOpKernelCache final : public user_op::OpKernelCache { Im2ColFunc im2col_func_ = ConvKernelUtil::NCDHWIm2Col; Col2ImFunc col2im_func_ = ConvKernelUtil::NCDHWCol2Im; Shape in_5d_shape_; Shape out_5d_shape_; Shape weight_5d_shape_; std::vector strides_3d_; std::vector dilation_rate_3d_; std::vector padding_before_3d_; bool is_out_diff_need_trans_ = false; int32_t idx_offset_ = 0; bool is_dynamic_ = false; 
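// Number of conv groups (from the "groups" attr); the per-group loops in Compute() below advance the input/weight/output pointers by one group-sized step per iteration.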
int32_t groups = 1; }; template std::shared_ptr> CreateConvOpKernelCache(user_op::KernelCacheContext* ctx, const std::string& in_name, const std::string& out_name, const std::string& weight_name) { const auto& data_format = ctx->Attr("data_format"); std::shared_ptr> state(new ConvOpKernelCache()); if (data_format == "channels_first") { state->im2col_func_ = ConvKernelUtil::NCDHWIm2Col; state->col2im_func_ = ConvKernelUtil::NCDHWCol2Im; state->is_out_diff_need_trans_ = false; state->idx_offset_ = 2; } else { state->im2col_func_ = ConvKernelUtil::NDHWCIm2Col; state->col2im_func_ = ConvKernelUtil::NDHWCCol2Im; state->is_out_diff_need_trans_ = true; state->idx_offset_ = 1; } state->groups = ctx->Attr("groups"); auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape { DimVector ret_vec(shape.dim_vec()); int32_t ndims = ret_vec.size() - 2; ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); return Shape(ret_vec); }; state->in_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), state->idx_offset_); state->out_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), state->idx_offset_); state->weight_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), state->idx_offset_); auto Gen3DVec = [](const std::vector& origin_vec) -> std::vector { std::vector ret_vec = origin_vec; ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1); return ret_vec; }; state->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); state->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); state->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); const auto& padding_before = ctx->Attr>("padding_before"); FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - padding_before.size()); if (index < 0) { state->padding_before_3d_.push_back(0); } else { state->padding_before_3d_.push_back(padding_before.at(index)); } } return state; } template void InitBiasMulBuf(T* dptr, int64_t num) { for (int64_t i = 0; i < num; ++i) { dptr[i] = 1; } } template class ConvCpuKernel final : public user_op::OpKernel { public: ConvCpuKernel() = default; ~ConvCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateConvOpKernelCache(ctx, "in", "out", "weight"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const auto* conv_cache = dynamic_cast*>(cache); CHECK_NOTNULL(conv_cache); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); T* col_buf_dptr = tmp_buffer->mut_dptr(); int32_t idx_offset = conv_cache->idx_offset_; const int32_t input_group_interval = in->shape_view().At(1) / conv_cache->groups; const int32_t weight_group_interval = weight->shape_view().At(0) / conv_cache->groups; const int32_t output_group_interval = out->shape_view().At(1) / conv_cache->groups; const int32_t input_step = input_group_interval * in->shape_view().Count(2); const int32_t weight_step = weight_group_interval * weight->shape_view().Count(1); const int32_t output_step = output_group_interval * out->shape_view().Count(2); const int32_t m = conv_cache->weight_5d_shape_.At(0) / 
conv_cache->groups; const int32_t n = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3); const int32_t k = conv_cache->weight_5d_shape_.Count(1); bool is_bias_mul_inited = false; const auto& data_format = ctx->Attr("data_format"); std::unique_ptr matmul; if (data_format == "channels_first") { matmul = NewChannelsFirstMatmulPrimitive(ctx); } else { matmul = NewChannelsLastMatmulPrimitive(ctx); } CHECK(matmul); for (int64_t i = 0; i < in->shape_view().At(0); ++i) { const T* input_ptr = GetImgDptr(in, i); const T* weight_ptr = weight->dptr(); T* output_ptr = GetImgMutDptr(out, i); for (int64_t g = 0; g < conv_cache->groups; g++) { conv_cache->im2col_func_( input_ptr, ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_), ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(), col_buf_dptr); // channels first: out = weight * col_buf // channels last: out = (weight * col_buf)(T) matmul->Launch(ctx->stream(), m, // filter / groups n, // od * oh * ow k, // ci * kd * kh * kw / groups static_cast(1), weight_ptr, col_buf_dptr, static_cast(0), output_ptr); input_ptr += input_step; weight_ptr += weight_step; output_ptr += output_step; } const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); if (bias != nullptr) { int64_t num_of_col_buf = CalcElemNumOfColBuf(out->shape_view(), weight->shape_view(), idx_offset); int64_t num_of_bias_mul = (tmp_buffer->shape_view().elem_cnt() - num_of_col_buf * sizeof(T)) / sizeof(T); CHECK_GT(num_of_bias_mul, 0); T* bias_mul_dptr = col_buf_dptr + num_of_col_buf; if (!is_bias_mul_inited) { InitBiasMulBuf(bias_mul_dptr, num_of_bias_mul); is_bias_mul_inited = true; } // channels first: out += bias * bias_mul // channels last: out += (bias * bias_mul)(T) matmul->Launch(ctx->stream(), conv_cache->weight_5d_shape_.At(0), // filter conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow 1, // 1 static_cast(1), bias->dptr(), bias_mul_dptr, static_cast(1), GetImgMutDptr(out, i)); } } } }; #define REGISTER_CONV_KERNEL(op_name, dtype, ndims) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("groups") > 1) \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && ChannelsFirstMatmulPrimitiveExists() \ && ChannelsLastMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& out_shape = ctx->OutputTensorDesc("out", 0).shape(); \ const auto& weight_shape = ctx->InputTensorDesc("weight", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ CalcElemNumOfColBuf(out_shape, weight_shape, idx_offset) * sizeof(dtype); \ bool has_bias = ctx->has_input("bias", 0); \ if (has_bias) { \ int64_t bias_mul_cnt = 1; \ for (int i = 0; i < ndims; ++i) { bias_mul_cnt *= out_shape.At(idx_offset + i); } \ tmp_buffer_size += bias_mul_cnt * sizeof(dtype); \ } \ return tmp_buffer_size; \ }) REGISTER_CONV_KERNEL(conv1d, float, 1); REGISTER_CONV_KERNEL(conv2d, float, 2); REGISTER_CONV_KERNEL(conv3d, float, 3); REGISTER_CONV_KERNEL(conv1d, double, 1); REGISTER_CONV_KERNEL(conv2d, double, 2); REGISTER_CONV_KERNEL(conv3d, double, 3); template class ConvDataGradCpuKernel final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(ConvDataGradCpuKernel); ConvDataGradCpuKernel() = default; ~ConvDataGradCpuKernel() = default; bool 
AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateConvOpKernelCache(ctx, "dx", "dy", "filter"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const auto* conv_cache = dynamic_cast*>(cache); CHECK_NOTNULL(conv_cache); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int32_t idx_offset = conv_cache->idx_offset_; const int32_t dy_group_interval = dy->shape_view().At(1) / conv_cache->groups; const int32_t filter_group_interval = filter->shape_view().At(0) / conv_cache->groups; const int32_t dx_group_interval = dx->shape_view().At(1) / conv_cache->groups; const int32_t dx_step = dx_group_interval * dx->shape_view().Count(2); const int32_t filter_step = filter_group_interval * filter->shape_view().Count(1); const int32_t dy_step = dy_group_interval * dy->shape_view().Count(2); const int32_t m = conv_cache->weight_5d_shape_.Count(1); const int32_t n = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3); const int32_t k = conv_cache->weight_5d_shape_.At(0) / conv_cache->groups; Memset(ctx->stream(), dx->mut_dptr(), 0, dx->shape_view().elem_cnt() * sizeof(T)); std::unique_ptr matmul; if (conv_cache->is_out_diff_need_trans_) { matmul = NewConvDataGradTransATransBMatmulPrimitive(ctx); } else { matmul = NewConvDataGradTransANoTransBMatmulPrimitive(ctx); } CHECK(matmul); FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) { const T* filter_ptr = filter->dptr(); const T* dy_ptr = GetImgDptr(dy, i); T* dx_ptr = GetImgMutDptr(dx, i); FOR_RANGE(int64_t, g, 0, conv_cache->groups) { // channels first: col_buf' = weight(T) * out[i]' // channels last : col_buf' = weight(T) * out[i]'(T) matmul->Launch(ctx->stream(), m, // ci * kd * kh * kw / groups n, // od * oh * ow k, // filter / groups static_cast(1), filter_ptr, dy_ptr, static_cast(0), col_buf->mut_dptr()); // in' = col2im(col_buf') conv_cache->col2im_func_( col_buf->dptr(), ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_), ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(), dx_ptr); filter_ptr += filter_step; dy_ptr += dy_step; dx_ptr += dx_step; } } if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), dx->data_type()); CHECK_EQ(add_to_output->shape_view(), dx->shape_view()); std::unique_ptr primitive = ep::primitive::NewPrimitive(DeviceType::kCPU, add_to_output->data_type()); CHECK(primitive); primitive->Launch(ctx->stream(), dx->dptr(), add_to_output->dptr(), dx->mut_dptr(), add_to_output->shape_view().elem_cnt()); } } }; #define REGISTER_CONV_DATA_GRAD_KERNEL(op_name, dtype) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("groups") > 1) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && ConvDataGradTransATransBMatmulPrimitiveExists() \ && ConvDataGradTransANoTransBMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ 
const auto& out_diff_shape = ctx->InputTensorDesc("dy", 0).shape(); \ const auto& weight_shape = ctx->InputTensorDesc("filter", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ CalcElemNumOfColBuf(out_diff_shape, weight_shape, idx_offset) * sizeof(dtype); \ return tmp_buffer_size; \ }) REGISTER_CONV_DATA_GRAD_KERNEL(conv_data_grad, float); REGISTER_CONV_DATA_GRAD_KERNEL(conv_data_grad, double); template class ConvFilterGradCpuKernel final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(ConvFilterGradCpuKernel); ConvFilterGradCpuKernel() = default; ~ConvFilterGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateConvOpKernelCache(ctx, "x", "dy", "filter_diff"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const auto* conv_cache = dynamic_cast*>(cache); CHECK_NOTNULL(conv_cache); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex("filter_diff", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int32_t idx_offset = conv_cache->idx_offset_; const int32_t dy_group_interval = dy->shape_view().At(1) / conv_cache->groups; const int32_t filter_diff_group_interval = filter_diff->shape_view().At(0) / conv_cache->groups; const int32_t x_group_interval = x->shape_view().At(1) / conv_cache->groups; const int32_t x_step = x_group_interval * x->shape_view().Count(2); const int32_t dy_step = dy_group_interval * dy->shape_view().Count(2); const int32_t filter_diff_step = filter_diff_group_interval * filter_diff->shape_view().Count(1); const int32_t m = conv_cache->weight_5d_shape_.At(0) / conv_cache->groups; const int32_t n = conv_cache->weight_5d_shape_.Count(1); const int32_t k = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3); Memset(ctx->stream(), filter_diff->mut_dptr(), 0, filter_diff->shape_view().elem_cnt() * sizeof(T)); std::unique_ptr matmul; if (conv_cache->is_out_diff_need_trans_) { matmul = NewConvWeightGradTransATransBMatmulPrimitive(ctx); } else { matmul = NewConvWeightGradNoTransATransBMatmulPrimitive(ctx); } CHECK(matmul); FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) { const T* x_ptr = GetImgDptr(x, i); const T* dy_ptr = GetImgDptr(dy, i); T* filter_diff_ptr = filter_diff->mut_dptr(); FOR_RANGE(int64_t, g, 0, conv_cache->groups) { conv_cache->im2col_func_( x_ptr, ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_), ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(), col_buf->mut_dptr()); // channels first: weight' += out[i]' * col_buf(T) // channels last : weight' += out[i]'(T) * col_buf(T) matmul->Launch(ctx->stream(), m, // filter / groups n, // ci * kd * kh * kw k, // od * oh * ow / groups static_cast(1), dy_ptr, col_buf->dptr(), static_cast(1), filter_diff_ptr); x_ptr += x_step; dy_ptr += dy_step; filter_diff_ptr += filter_diff_step; } } } }; #define REGISTER_CONV_FILTER_GRAD_KERNEL(op_name, dtype) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("groups") > 1) \ && (user_op::HobDataType("dy", 0) == 
GetDataType::value) \ && ConvWeightGradTransATransBMatmulPrimitiveExists() \ && ConvWeightGradNoTransATransBMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& out_diff_shape = ctx->InputTensorDesc("dy", 0).shape(); \ const auto& weight_diff_shape = ctx->OutputTensorDesc("filter_diff", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ CalcElemNumOfColBuf(out_diff_shape, weight_diff_shape, idx_offset) * sizeof(dtype); \ return tmp_buffer_size; \ }) REGISTER_CONV_FILTER_GRAD_KERNEL(conv_filter_grad, float); REGISTER_CONV_FILTER_GRAD_KERNEL(conv_filter_grad, double); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/group_deconv_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewDeconvTransATransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, true, true); } template std::unique_ptr NewDeconvTransANoTransBMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, true, false); } auto DeconvTransATransBMatmulPrimitiveExists() { return hob::make_custom("DeconvTransATransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewDeconvTransATransBMatmulPrimitive(&ctx).operator bool(); }); } auto DeconvTransANoTransBMatmulPrimitiveExists() { return hob::make_custom("DeconvTransANoTransBMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewDeconvTransANoTransBMatmulPrimitive(&ctx).operator bool(); }); } template using Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr); template T* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) { return tensor->mut_dptr() + tensor->shape_view().Count(1) * idx; } template const T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) { return tensor->dptr() + tensor->shape_view().Count(1) * idx; } size_t CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape, const int32_t idx_offset) { int64_t col_buf_elem_cnt = 1; int64_t ndims = out_shape.NumAxes() - 2; for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); } for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); } return col_buf_elem_cnt; } template class ColBufWriter { public: ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : src_ptr_(src_ptr), dst_ptr_(dst_ptr), c_size_(c_size), id_size_(id_size), ih_size_(ih_size), iw_size_(iw_size), od_size_(od_size), oh_size_(oh_size), ow_size_(ow_size) {} virtual ~ColBufWriter() = default; virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; virtual void InvalidDFunc() = 0; virtual void InvalidHFunc() = 0; virtual void InvalidWFunc() = 0; virtual void NextImCSize() = 0; protected: const T* src_ptr_; T* dst_ptr_; int64_t c_size_ = 0; int64_t id_size_ = 0; int64_t ih_size_ = 0; int64_t iw_size_ = 0; int64_t od_size_ = 0; int64_t oh_size_ = 0; int64_t ow_size_ = 0; }; template class Col2ImWriter final : public ColBufWriter { public: Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) : ColBufWriter::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size, oh_size, ow_size) {} ~Col2ImWriter() = default; void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { 
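// Scatter-accumulate: add the current column-buffer value into the NDHWC image at offset (id, ih, iw, c).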
this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] += *(this->src_ptr_++); } void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++); } void InvalidDFunc() override { this->src_ptr_ += this->od_size_; } void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; } void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; } void NextImCSize() override { this->dst_ptr_ += this->c_size_; } }; template using DHWValidFunc = void (ColBufWriter::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw); template class ColBufUtil final { public: ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, const int32_t id_num, const int32_t ih_num, const int32_t iw_num, const int32_t od_num, const int32_t oh_num, const int32_t ow_num) : strides_(strides), dilation_rate_(dilation_rate), padding_before_(padding_before), id_num_(id_num), ih_num_(ih_num), iw_num_(iw_num), od_num_(od_num), oh_num_(oh_num), ow_num_(ow_num) { if (dhw_offset == 2) { dhw_valid_func_ = &ColBufWriter::CDHWWrite; } else { dhw_valid_func_ = &ColBufWriter::DHWCWrite; } } void operator()(ColBufWriter* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) { int64_t id = kd * dilation_rate_[0] - padding_before_[0]; FOR_RANGE(int64_t, od, 0, od_num_) { if (id < 0 || id >= id_num_) { col_buf_writer->InvalidDFunc(); } else { int64_t ih = kh * dilation_rate_[1] - padding_before_[1]; FOR_RANGE(int64_t, oh, 0, oh_num_) { if (ih < 0 || ih >= ih_num_) { col_buf_writer->InvalidHFunc(); } else { int64_t iw = kw * dilation_rate_[2] - padding_before_[2]; FOR_RANGE(int64_t, ow, 0, ow_num_) { if (iw < 0 || iw >= iw_num_) { col_buf_writer->InvalidWFunc(); } else { (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw); } iw += strides_[2]; } } ih += strides_[1]; } } id += strides_[0]; } } private: const int32_t* strides_; const int32_t* dilation_rate_; const int32_t* padding_before_; DHWValidFunc dhw_valid_func_; int64_t id_num_ = 0; int64_t ih_num_ = 0; int64_t iw_num_ = 0; int64_t od_num_ = 0; int64_t oh_num_ = 0; int64_t ow_num_ = 0; }; template struct DeconvKernelUtil final { public: static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before, in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2), out_shape.At(3), out_shape.At(4)); Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer); } static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before, in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2), out_shape.At(3), out_shape.At(4)); Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), 
out_shape.Count(3, 4), 1); DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer); } private: static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, ColBufWriter* col_buf_writer) { for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) { for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(4); ++kw) { col_buf_util(col_buf_writer, c, kd, kh, kw); } } } } } static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, ColBufWriter* col_buf_writer) { for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) { for (int64_t c = 0; c != weight_shape.At(4); ++c) { col_buf_util(col_buf_writer, c, kd, kh, kw); } } } } } }; template struct DeconvOpKernelCache final : public user_op::OpKernelCache { Col2ImFunc col2im_func_ = DeconvKernelUtil::NCDHWCol2Im; ; Shape in_5d_shape_; Shape out_5d_shape_; Shape weight_5d_shape_; std::vector strides_3d_; std::vector dilation_rate_3d_; std::vector padding_before_3d_; bool is_out_diff_need_trans_ = false; int32_t idx_offset_ = 0; bool is_dynamic_ = false; int32_t groups = 1; void Update(const ShapeView& x_shape, const ShapeView& out_shape) { auto Gen5DShape = [](const ShapeView& shape, int32_t idx_offset) -> Shape { DimVector ret_vec; shape.ToDimVector(&ret_vec); int32_t ndims = ret_vec.size() - 2; ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); return Shape(ret_vec); }; if (is_dynamic_) { Shape in_shape; in_5d_shape_ = Gen5DShape(x_shape, idx_offset_); out_5d_shape_ = Gen5DShape(out_shape, idx_offset_); } } }; template std::shared_ptr> CreateDeconvOpKernelCache(user_op::KernelCacheContext* ctx, const std::string& in_name, const std::string& out_name, const std::string& weight_name) { const auto& data_format = ctx->Attr("data_format"); std::shared_ptr> state(new DeconvOpKernelCache()); if (data_format == "channels_first") { state->col2im_func_ = DeconvKernelUtil::NCDHWCol2Im; state->is_out_diff_need_trans_ = false; state->idx_offset_ = 2; } else { state->col2im_func_ = DeconvKernelUtil::NDHWCCol2Im; state->is_out_diff_need_trans_ = true; state->idx_offset_ = 1; } auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape { DimVector ret_vec(shape.dim_vec()); int32_t ndims = ret_vec.size() - 2; ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); return Shape(ret_vec); }; state->groups = ctx->Attr("groups"); state->in_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), state->idx_offset_); state->out_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), state->idx_offset_); state->weight_5d_shape_ = Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), state->idx_offset_); auto Gen3DVec = [](const std::vector& origin_vec) -> std::vector { std::vector ret_vec = origin_vec; ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1); return ret_vec; }; state->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); state->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); state->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); const auto& padding_before = ctx->Attr>("padding_before"); FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - padding_before.size()); if (index < 0) { state->padding_before_3d_.emplace_back(0); } else { 
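// This spatial dimension has an explicit entry in padding_before; copy it through unchanged.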
state->padding_before_3d_.emplace_back(padding_before.at(index)); } } return state; } template class DeconvCpuKernel final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(DeconvCpuKernel); DeconvCpuKernel() = default; ~DeconvCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void InitOpKernelCacheWithFlags( user_op::KernelCacheContext* ctx, int8_t flag, std::shared_ptr* cache_ptr) const override { if (*cache_ptr != nullptr && (flag & user_op::OpKernelCache::kAttrNotChanged)) { auto deconv_cache = std::dynamic_pointer_cast>(*cache_ptr); deconv_cache->Update(ctx->TensorDesc4ArgNameAndIndex("in", 0)->shape(), ctx->TensorDesc4ArgNameAndIndex("out", 0)->shape()); return; } *cache_ptr = CreateDeconvOpKernelCache(ctx, "out", "in", "weight"); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { auto deconv_cache = dynamic_cast*>(cache); CHECK_NOTNULL(deconv_cache); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int32_t idx_offset = deconv_cache->idx_offset_; const int32_t input_group_interval = in->shape_view().At(1) / deconv_cache->groups; const int32_t weight_group_interval = weight->shape_view().At(0) / deconv_cache->groups; const int32_t output_group_interval = out->shape_view().At(1) / deconv_cache->groups; const int32_t input_step = input_group_interval * in->shape_view().Count(2); const int32_t weight_step = weight_group_interval * weight->shape_view().Count(1); const int32_t output_step = output_group_interval * out->shape_view().Count(2); const int32_t m = deconv_cache->weight_5d_shape_.Count(1); const int32_t n = deconv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3); const int32_t k = deconv_cache->weight_5d_shape_.At(0) / deconv_cache->groups; Memset(ctx->stream(), out->mut_dptr(), 0, out->shape_view().elem_cnt() * sizeof(T)); std::unique_ptr matmul; if (deconv_cache->is_out_diff_need_trans_) { matmul = NewDeconvTransATransBMatmulPrimitive(ctx); } else { matmul = NewDeconvTransANoTransBMatmulPrimitive(ctx); } CHECK(matmul); FOR_RANGE(int64_t, i, 0, in->shape_view().At(0)) { const T* input_ptr = GetImgDptr(in, i); const T* weight_ptr = weight->dptr(); T* output_ptr = GetImgMutDptr(out, i); FOR_RANGE(int64_t, g, 0, deconv_cache->groups) { matmul->Launch(ctx->stream(), m, // (co / groups) * kd * kh * kw n, // od * oh * ow k, // filter / groups static_cast(1), weight_ptr, input_ptr, static_cast(0), col_buf->mut_dptr()); // out = col2im(col_buf') deconv_cache->col2im_func_( col_buf->mut_dptr(), ShapeView(deconv_cache->in_5d_shape_), ShapeView(deconv_cache->weight_5d_shape_), ShapeView(deconv_cache->out_5d_shape_), deconv_cache->strides_3d_.data(), deconv_cache->dilation_rate_3d_.data(), deconv_cache->padding_before_3d_.data(), output_ptr); input_ptr += input_step; weight_ptr += weight_step; output_ptr += output_step; } } } }; #define REGISTER_DECONV_DATA_KERNEL(op_name, dtype) \ REGISTER_USER_KERNEL(#op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("groups") > 1) \ && (user_op::HobDataType("out", 0) == GetDataType::value) \ && DeconvTransATransBMatmulPrimitiveExists() \ && DeconvTransANoTransBMatmulPrimitiveExists()) \ 
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& in_shape = ctx->InputTensorDesc("in", 0).shape(); \ const auto& weight_shape = ctx->InputTensorDesc("weight", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ CalcElemNumOfColBuf(in_shape, weight_shape, idx_offset) * sizeof(dtype); \ return tmp_buffer_size; \ }) REGISTER_DECONV_DATA_KERNEL(deconv1d, float); REGISTER_DECONV_DATA_KERNEL(deconv1d, double); REGISTER_DECONV_DATA_KERNEL(deconv2d, float); REGISTER_DECONV_DATA_KERNEL(deconv2d, double); REGISTER_DECONV_DATA_KERNEL(deconv3d, float); REGISTER_DECONV_DATA_KERNEL(deconv3d, double); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/group_norm_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/cuda/primitive/unary_functor.cuh" #include "oneflow/core/cuda/layer_norm.cuh" #include #include "oneflow/core/kernel/cuda_graph_support.h" #ifdef WITH_CUTLASS #include #endif // WITH_CUTLASS namespace oneflow { namespace { template struct AffineStore { AffineStore(DST* y, int64_t row_size, int64_t channel_size, int64_t spatial_size, const DST* gamma, const DST* beta) : y(y), row_size(row_size), channel_size(channel_size), spatial_size(spatial_size), gamma(gamma), beta(beta), act(0, 0) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::layer_norm::Pack y_pack; const int64_t offset = row * row_size + col; const int64_t packed_offset = offset / PackSize; const int64_t gamma_beta_offset = (offset / spatial_size) % channel_size; DST gamma_val = 1.0; DST beta_val = 0.0; if (affine) { gamma_val = gamma[gamma_beta_offset]; beta_val = beta[gamma_beta_offset]; } #pragma unroll for (int i = 0; i < PackSize; ++i) { DST normalized_i = static_cast(src[i]); if (affine) { y_pack.elem[i] = act(normalized_i * gamma_val + beta_val); } else { // Direct Store. 
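// Without affine parameters the normalized value only passes through the activation functor.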
y_pack.elem[i] = act(normalized_i); } } *(reinterpret_cast*>(y) + packed_offset) = y_pack.storage; } bool CanPackAs(size_t pack_size) { return (spatial_size % pack_size) == 0; } DST* y; int64_t row_size; int64_t channel_size; int64_t spatial_size; const DST* gamma; const DST* beta; ep::primitive::UnaryFunctor act; }; template struct ScaleLoad { using LoadType = DST; ScaleLoad(const SRC* src, const SRC* gamma, int64_t row_size, int64_t channel_size, int64_t spatial_size) : src(src), gamma(gamma), row_size(row_size), channel_size(channel_size), spatial_size(spatial_size) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { cuda::layer_norm::Pack src_pack; cuda::layer_norm::Pack gamma_pack; const int64_t offset = row * row_size + col; const int64_t packed_offset = offset / PackSize; const int64_t gamma_offset = (offset / spatial_size) % channel_size; src_pack.storage = *(reinterpret_cast*>(src) + packed_offset); SRC gamma_val = static_cast(1.0); if (affine) { gamma_val = gamma[gamma_offset]; } #pragma unroll for (int i = 0; i < PackSize; ++i) { dst[i] = static_cast(src_pack.elem[i] * gamma_val); } } bool CanPackAs(size_t pack_size) { return (spatial_size % pack_size) == 0; } const SRC* src; const SRC* gamma; int64_t row_size; int64_t channel_size; int64_t spatial_size; }; #ifdef WITH_CUTLASS template struct ChannelsLastStore { ChannelsLastStore(DST* y, const DST* gamma, const DST* beta, int64_t spatial_size, int64_t channel_size, int64_t num_groups) : y(y), gamma(gamma), beta(beta), spatial_size(spatial_size), c0(num_groups), c1(channel_size / num_groups), act(0, 0) {} template __device__ void store(const SRC* src, int32_t row, int32_t col) { cuda::layer_norm::Pack y_pack; cuda::layer_norm::Pack gamma_pack; cuda::layer_norm::Pack beta_pack; int32_t spatial_idx; int32_t c1_idx; c1(spatial_idx, c1_idx, col); int32_t batch_idx; int32_t c0_idx; c0(batch_idx, c0_idx, row); const int32_t y_offset = (batch_idx * c0.divisor * c1.divisor * spatial_size + spatial_idx * c0.divisor * c1.divisor + c0_idx * c1.divisor + c1_idx) / PackSize; const int32_t gamma_beta_offset = (c0_idx * c1.divisor + c1_idx) / PackSize; if (affine) { gamma_pack.storage = *(reinterpret_cast*>(gamma) + gamma_beta_offset); beta_pack.storage = *(reinterpret_cast*>(beta) + gamma_beta_offset); } #pragma unroll for (int i = 0; i < PackSize; ++i) { DST normalized_i = static_cast(src[i]); if (affine) { y_pack.elem[i] = act(normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]); } else { // Direct Store. 
y_pack.elem[i] = act(normalized_i); } } *(reinterpret_cast*>(y) + y_offset) = y_pack.storage; } bool CanPackAs(size_t pack_size) { return (c1.divisor % pack_size) == 0; } DST* y; const DST* gamma; const DST* beta; int32_t spatial_size; cutlass::FastDivmod c0; cutlass::FastDivmod c1; ep::primitive::UnaryFunctor act; }; template struct ChannelsLastLoad { using LoadType = DST; ChannelsLastLoad(const SRC* src, int64_t spatial_size, int64_t channel_size, int64_t num_groups) : src(src), spatial_size(spatial_size), c0(num_groups), c1(channel_size / num_groups) {} template __device__ void load(DST* dst, int32_t row, int32_t col) const { int32_t spatial_idx; int32_t c1_idx; c1(spatial_idx, c1_idx, col); int32_t batch_idx; int32_t c0_idx; c0(batch_idx, c0_idx, row); cuda::layer_norm::Pack pack; const int32_t offset = (batch_idx * c0.divisor * c1.divisor * spatial_size + spatial_idx * c0.divisor * c1.divisor + c0_idx * c1.divisor + c1_idx) / N; pack.storage = *(reinterpret_cast*>(src) + offset); #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); } } bool CanPackAs(size_t pack_size) { return (c1.divisor % pack_size) == 0; } const SRC* src; int32_t spatial_size; cutlass::FastDivmod c0; cutlass::FastDivmod c1; }; #else template struct ChannelsLastStore { ChannelsLastStore(DST* y, const DST* gamma, const DST* beta, int64_t spatial_size, int64_t channel_size, int64_t num_groups) : y(y), gamma(gamma), beta(beta), spatial_size(spatial_size), c0(num_groups), c1(channel_size / num_groups), act(0, 0) {} template __device__ void store(const SRC* src, int32_t row, int32_t col) { cuda::layer_norm::Pack y_pack; cuda::layer_norm::Pack gamma_pack; cuda::layer_norm::Pack beta_pack; int32_t spatial_idx = col / c1; int32_t c1_idx = col - spatial_idx * c1; int32_t batch_idx = row / c0; int32_t c0_idx = row - batch_idx * c0; const int32_t y_offset = (batch_idx * c0 * c1 * spatial_size + spatial_idx * c0 * c1 + c0_idx * c1 + c1_idx) / PackSize; const int32_t gamma_beta_offset = (c0_idx * c1 + c1_idx) / PackSize; if (affine) { gamma_pack.storage = *(reinterpret_cast*>(gamma) + gamma_beta_offset); beta_pack.storage = *(reinterpret_cast*>(beta) + gamma_beta_offset); } #pragma unroll for (int i = 0; i < PackSize; ++i) { DST normalized_i = static_cast(src[i]); if (affine) { y_pack.elem[i] = act(normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]); } else { // Direct Store. 
y_pack.elem[i] = act(normalized_i); } } *(reinterpret_cast*>(y) + y_offset) = y_pack.storage; } bool CanPackAs(size_t pack_size) { return (c1 % pack_size) == 0; } DST* y; const DST* gamma; const DST* beta; int32_t spatial_size; int32_t c0; int32_t c1; ep::primitive::UnaryFunctor act; }; template struct ChannelsLastLoad { using LoadType = DST; ChannelsLastLoad(const SRC* src, int64_t spatial_size, int64_t channel_size, int64_t num_groups) : src(src), spatial_size(spatial_size), c0(num_groups), c1(channel_size / num_groups) {} template __device__ void load(DST* dst, int32_t row, int32_t col) const { int32_t spatial_idx = col / c1; int32_t c1_idx = col - spatial_idx * c1; int32_t batch_idx = row / c0; int32_t c0_idx = row - batch_idx * c0; cuda::layer_norm::Pack pack; const int32_t offset = (batch_idx * c0 * c1 * spatial_size + spatial_idx * c0 * c1 + c0_idx * c1 + c1_idx) / N; pack.storage = *(reinterpret_cast*>(src) + offset); #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); } } bool CanPackAs(size_t pack_size) { return (c1 % pack_size) == 0; } const SRC* src; int32_t spatial_size; int32_t c0; int32_t c1; }; #endif // WITH_CUTLASS template void GroupNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const int64_t channel_size, const int64_t spatial_size, const double epsilon, const T* x_ptr, const T* gamma_ptr, const T* beta_ptr, T* y_ptr, user_op::Tensor* mean, user_op::Tensor* inv_variance, bool channels_first) { using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; if (channels_first) { cuda::layer_norm::DirectLoad load(x_ptr, norm_size); AffineStore store(y_ptr, norm_size, channel_size, spatial_size, gamma_ptr, beta_ptr); cuda::layer_norm::DispatchLayerNorm( stream->As()->cuda_stream(), load, store, num_instances, norm_size, epsilon, mean->mut_dptr(), inv_variance->mut_dptr()); } else { ChannelsLastLoad load(x_ptr, spatial_size, channel_size, channel_size / (norm_size / spatial_size)); ChannelsLastStore store( y_ptr, gamma_ptr, beta_ptr, spatial_size, channel_size, channel_size / (norm_size / spatial_size)); cuda::layer_norm::DispatchLayerNorm( stream->As()->cuda_stream(), load, store, num_instances, norm_size, epsilon, mean->mut_dptr(), inv_variance->mut_dptr()); } } template void DispatchGroupNormAffine(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const int64_t channel_size, const int64_t spatial_size, const double epsilon, const T* x_ptr, const T* gamma_ptr, const T* beta_ptr, T* y_ptr, user_op::Tensor* mean, user_op::Tensor* inv_variance, bool channels_first) { if (gamma_ptr != nullptr && beta_ptr != nullptr) { GroupNormForwardGpu(stream, num_instances, norm_size, channel_size, spatial_size, epsilon, x_ptr, gamma_ptr, beta_ptr, y_ptr, mean, inv_variance, channels_first); } else { GroupNormForwardGpu(stream, num_instances, norm_size, channel_size, spatial_size, epsilon, x_ptr, gamma_ptr, beta_ptr, y_ptr, mean, inv_variance, channels_first); } } template void DispatchGroupNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const int64_t channel_size, const int64_t spatial_size, const double epsilon, const T* x_ptr, const T* gamma_ptr, const T* beta_ptr, T* y_ptr, user_op::Tensor* mean, user_op::Tensor* inv_variance, bool channels_first, const std::string& activation) { if (activation == "none") { DispatchGroupNormAffine( stream, num_instances, norm_size, channel_size, spatial_size, epsilon, x_ptr, gamma_ptr, beta_ptr, y_ptr, mean, 
inv_variance, channels_first); } else if (activation == "silu") { DispatchGroupNormAffine( stream, num_instances, norm_size, channel_size, spatial_size, epsilon, x_ptr, gamma_ptr, beta_ptr, y_ptr, mean, inv_variance, channels_first); } else { UNIMPLEMENTED(); } } template void GroupNormBackwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const int64_t channel_size, const int64_t spatial_size, const T* dy_ptr, const T* x_ptr, const user_op::Tensor* mean, const user_op::Tensor* inv_variance, const T* gamma_ptr, T* dx_ptr) { using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; cuda::layer_norm::DirectLoad load_x(x_ptr, norm_size); ScaleLoad load_scaled_dy(dy_ptr, gamma_ptr, norm_size, channel_size, spatial_size); cuda::layer_norm::DirectStore store(dx_ptr, norm_size); OF_CUDA_CHECK((cuda::layer_norm::DispatchLayerNormGrad( stream->As()->cuda_stream(), load_x, load_scaled_dy, store, mean->dptr(), inv_variance->dptr(), num_instances, norm_size))); } template void LaunchGroupNormBackward(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const int64_t channel_size, const int64_t spatial_size, const T* dy_ptr, const T* x_ptr, const user_op::Tensor* mean, const user_op::Tensor* inv_variance, const T* gamma_ptr, T* dx_ptr) { if (gamma_ptr != nullptr) { GroupNormBackwardGpu(stream, num_instances, norm_size, channel_size, spatial_size, dy_ptr, x_ptr, mean, inv_variance, gamma_ptr, dx_ptr); } else { GroupNormBackwardGpu(stream, num_instances, norm_size, channel_size, spatial_size, dy_ptr, x_ptr, mean, inv_variance, gamma_ptr, dx_ptr); } } } // namespace template class GroupNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: GroupNormGpuKernel() = default; ~GroupNormGpuKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); const double epsilon = ctx->Attr("epsilon"); const int32_t num_groups = ctx->Attr("num_groups"); const std::string& data_format = ctx->Attr("data_format"); CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON); const int64_t num_instances = mean->shape_view().elem_cnt(); // N*num_groups const int64_t norm_size = x->shape_view().elem_cnt() / num_instances; const int64_t batch_size = x->shape_view().At(0); int64_t channel_size = 0; bool channels_first = false; if (data_format == "channels_first") { channel_size = x->shape_view().At(1); channels_first = true; } else if (data_format == "channels_last") { channel_size = x->shape_view().At(x->shape_view().NumAxes() - 1); channels_first = false; } else { UNIMPLEMENTED(); } const int64_t spatial_size = x->shape_view().elem_cnt() / batch_size / channel_size; const T* gamma_ptr = nullptr; const T* beta_ptr = nullptr; if (ctx->has_input("gamma", 0) && ctx->has_input("beta", 0)) { const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); gamma_ptr = gamma->dptr(); CHECK_EQ(gamma->shape_view().elem_cnt(), channel_size); const user_op::Tensor* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); beta_ptr = ctx->Tensor4ArgNameAndIndex("beta", 0)->dptr(); CHECK_EQ(beta->shape_view().elem_cnt(), channel_size); } 
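// Reuse the layer_norm CUDA kernels: each of the num_instances (= N * num_groups) rows is normalized over norm_size = (channel_size / num_groups) * spatial_size elements.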
DispatchGroupNormForwardGpu(ctx->stream(), num_instances, norm_size, channel_size, spatial_size, epsilon, x->dptr(), gamma_ptr, beta_ptr, y->mut_dptr(), mean, inv_variance, channels_first, ctx->Attr("activation")); } }; #define REGISTER_GROUP_NORM_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("group_norm") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_GROUP_NORM_CUDA_KERNEL(half) REGISTER_GROUP_NORM_CUDA_KERNEL(float) REGISTER_GROUP_NORM_CUDA_KERNEL(double) #if CUDA_VERSION >= 11000 REGISTER_GROUP_NORM_CUDA_KERNEL(nv_bfloat16) #endif template class GroupNormGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: GroupNormGradGpuKernel() = default; ~GroupNormGradGpuKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t num_instances = mean->shape_view().elem_cnt(); const int64_t norm_size = x->shape_view().elem_cnt() / num_instances; const int64_t batch_size = x->shape_view().At(0); const int64_t channel_size = x->shape_view().At(1); const int64_t spatial_size = x->shape_view().elem_cnt() / batch_size / channel_size; const T* gamma_ptr = nullptr; if (ctx->has_input("gamma", 0)) { gamma_ptr = ctx->Tensor4ArgNameAndIndex("gamma", 0)->dptr(); } LaunchGroupNormBackward(ctx->stream(), num_instances, norm_size, channel_size, spatial_size, dy->dptr(), x->dptr(), mean, inv_variance, gamma_ptr, dx->mut_dptr()); }; }; #define REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("group_norm_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)); REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(half) REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(float) REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(double) #if CUDA_VERSION >= 11000 REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16) #endif constexpr int kReduceBlockSize = 512; constexpr int kBlockSize = 128; constexpr int kNumWaves = 32; inline cudaError_t GetReduceNumBlocks(int64_t n, int* num_blocks) { int dev; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } int tpm; { cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); if (err != cudaSuccess) { return err; } } *num_blocks = std::max(1, std::min(n, sm_count * tpm / kReduceBlockSize * kNumWaves)); return cudaSuccess; } inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) { int dev; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } int tpm; { cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); if (err != cudaSuccess) { return err; } } *num_blocks = 
std::max(1, std::min((n + kBlockSize - 1) / kBlockSize, sm_count * tpm / kBlockSize * kNumWaves)); return cudaSuccess; } template struct SumOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; } }; template struct GetPackType { using type = typename std::aligned_storage::type; }; template using PackType = typename GetPackType::type; template union Pack { static_assert(sizeof(PackType) == sizeof(T) * PackSize, ""); __device__ Pack(T val) { for (int i = 0; i < PackSize; i++) { elem[i] = val; } } T elem[PackSize]; PackType storage; }; constexpr int kMaxPackBytes = 128 / 8; constexpr int kMaxPackSize = 8; constexpr int Min(int a, int b) { return a < b ? a : b; } template constexpr int GetPackSize() { return Min(kMaxPackBytes / sizeof(T), kMaxPackSize); } template __global__ void GroupNormParamGradKernel(const T* dy, const T* x, const ComputeType* mean, const ComputeType* inv_var, ComputeType* dgamma_partial_sum, ComputeType* dbeta_partial_sum, const int32_t batch_size, const int32_t group_size, const int32_t channel_size, const int32_t spatial_size) { using LoadType = PackType; const int32_t batch_channel_size = batch_size * channel_size; for (int32_t batch_channel_id = blockIdx.x; batch_channel_id < batch_channel_size; batch_channel_id += gridDim.x) { const int32_t batch_id = batch_channel_id / channel_size; const int32_t channel_id = batch_channel_id % channel_size; const int32_t group_num = channel_size / group_size; const int32_t batch_group_id = batch_id * group_size + channel_id / group_num; ComputeType mean_val = mean[batch_group_id]; ComputeType inv_var_val = inv_var[batch_group_id]; Pack ds_sum_pack(0); Pack db_sum_pack(0); for (int32_t spatial = threadIdx.x * PackSize; spatial < spatial_size; spatial += blockDim.x * PackSize) { Pack dy_pack(0); Pack x_pack(0); const int32_t load_idx = batch_channel_id * spatial_size + spatial; const LoadType* dy_load = reinterpret_cast(dy + load_idx); dy_pack.storage = *dy_load; const LoadType* x_load = reinterpret_cast(x + load_idx); x_pack.storage = *x_load; #pragma unroll for (int i = 0; i < PackSize; i++) { ds_sum_pack.elem[i] += static_cast(dy_pack.elem[i]) * (static_cast(x_pack.elem[i]) - mean_val) * inv_var_val; db_sum_pack.elem[i] += static_cast(dy_pack.elem[i]); } } ComputeType ds_sum = 0.0; ComputeType db_sum = 0.0; #pragma unroll for (int i = 0; i < PackSize; i++) { ds_sum += ds_sum_pack.elem[i]; db_sum += db_sum_pack.elem[i]; } __syncthreads(); typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage1; __shared__ typename BlockReduce::TempStorage temp_storage2; ComputeType ds_sum_result = BlockReduce(temp_storage1).Reduce(ds_sum, SumOp()); ComputeType db_sum_result = BlockReduce(temp_storage2).Reduce(db_sum, SumOp()); if (threadIdx.x == 0) { dgamma_partial_sum[batch_channel_id] = ds_sum_result; dbeta_partial_sum[batch_channel_id] = db_sum_result; } } } template __global__ void BatchReduceGammaBetaGradKernel(ComputeType* ds_sum, ComputeType* db_sum, T* dgamma, T* dbeta, const int32_t batch_size, const int32_t group_size, const int32_t channel_size, const int32_t spatial_size) { const int32_t group_num = channel_size / group_size; CUDA_1D_KERNEL_LOOP(channel_idx, channel_size) { ComputeType dgamma_sum = 0.0; ComputeType dbeta_sum = 0.0; for (int batch_id = 0; batch_id < batch_size; batch_id++) { const int32_t batch_group_id = batch_id * group_size + channel_idx / group_num; const int32_t batch_channel_id = batch_id * channel_size + channel_idx; dgamma_sum += 
ds_sum[batch_channel_id]; dbeta_sum += db_sum[batch_channel_id]; } dgamma[channel_idx] = dgamma_sum; dbeta[channel_idx] = dbeta_sum; } } template int32_t GetLaunchPackSize(const int32_t spatial_size) { for (int pack_size = GetPackSize(); pack_size > 0; pack_size /= 2) { if (spatial_size % pack_size == 0) { return pack_size; } } return 1; } template void DispatchGroupNormParamGradKernel(ep::Stream* stream, const T* dy, const T* x, const ComputeType* mean, const ComputeType* inv_var, ComputeType* reduce_ds_buf, ComputeType* reduce_db_buf, const int32_t batch_size, const int32_t group_size, const int32_t channel_size, const int32_t spatial_size) { const int launch_pack_size = GetLaunchPackSize(spatial_size); int num_blocks; OF_CUDA_CHECK(GetReduceNumBlocks(batch_size * channel_size, &num_blocks)); if (launch_pack_size == 8) { GroupNormParamGradKernel <<As()->cuda_stream()>>>( dy, x, mean, inv_var, reduce_ds_buf, reduce_db_buf, batch_size, group_size, channel_size, spatial_size); } else if (launch_pack_size == 4) { GroupNormParamGradKernel <<As()->cuda_stream()>>>( dy, x, mean, inv_var, reduce_ds_buf, reduce_db_buf, batch_size, group_size, channel_size, spatial_size); } else if (launch_pack_size == 2) { GroupNormParamGradKernel <<As()->cuda_stream()>>>( dy, x, mean, inv_var, reduce_ds_buf, reduce_db_buf, batch_size, group_size, channel_size, spatial_size); } else { GroupNormParamGradKernel <<As()->cuda_stream()>>>( dy, x, mean, inv_var, reduce_ds_buf, reduce_db_buf, batch_size, group_size, channel_size, spatial_size); } } template class GroupNormParamGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: GroupNormParamGradGpuKernel() = default; ~GroupNormParamGradGpuKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); user_op::Tensor* dgamma = ctx->Tensor4ArgNameAndIndex("dgamma", 0); user_op::Tensor* dbeta = ctx->Tensor4ArgNameAndIndex("dbeta", 0); const int64_t num_instances = mean->shape_view().elem_cnt(); const int64_t norm_size = x->shape_view().elem_cnt() / num_instances; const int64_t batch_size = x->shape_view().At(0); const int64_t channel_size = x->shape_view().At(1); const int64_t spatial_size = x->shape_view().elem_cnt() / batch_size / channel_size; const int64_t group_size = num_instances / batch_size; user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; ComputeType* reduce_ds_buf_ptr = reinterpret_cast(tmp_buffer->mut_dptr()); ComputeType* reduce_db_buf_ptr = reinterpret_cast( tmp_buffer->mut_dptr() + batch_size * channel_size * sizeof(T)); DispatchGroupNormParamGradKernel( ctx->stream(), dy->dptr(), x->dptr(), mean->dptr(), inv_variance->dptr(), reduce_ds_buf_ptr, reduce_db_buf_ptr, batch_size, group_size, channel_size, spatial_size); int num_blocks; OF_CUDA_CHECK(GetNumBlocks(channel_size, &num_blocks)); // Note(zhengzekang): In large batchsize, it is recommend to use gemm to reduce. 
// i.e., a ones(1, N) matmul over the (N, C) partial sums.
    BatchReduceGammaBetaGradKernel<T, ComputeType>
        <<<num_blocks, kBlockSize, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
            reduce_ds_buf_ptr, reduce_db_buf_ptr, dgamma->mut_dptr<T>(), dbeta->mut_dptr<T>(), batch_size, group_size, channel_size, spatial_size); }; }; #define REGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(dtype, compute_dtype) \ REGISTER_USER_KERNEL("group_norm_param_grad") \ .SetCreateFn<GroupNormParamGradGpuKernel<dtype>>() \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const auto& x = ctx->InputTensorDesc("x", 0); \ const int64_t batch_size = x.shape().At(0); \ const int64_t channel_size = x.shape().At(1); \ size_t tmp_buffer_size = (2 * batch_size * channel_size) * sizeof(compute_dtype); \ return tmp_buffer_size; \ }) \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)); REGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(half, float) REGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(float, float) REGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(double, double) #if CUDA_VERSION >= 11000 REGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(nv_bfloat16, float) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/grouped_matmul_bias.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/scalar.h" namespace oneflow { struct Problem { Problem(int64_t m, int64_t n, int64_t k) : m(m), n(n), k(k) {} int64_t m; int64_t n; int64_t k; }; inline bool operator==(const Problem& lhs, const Problem& rhs) { return lhs.m == rhs.m && lhs.n == rhs.n && lhs.k == rhs.k; } } // namespace oneflow namespace std { template<> struct hash { std::size_t operator()(const oneflow::Problem& p) const { return oneflow::Hash(p.m, p.n, p.k); } }; } // namespace std namespace oneflow { namespace { constexpr int64_t kMaxProblemBatch = 64; template struct Buffer { const T* x; const T* w; const T* b; T* y; }; template struct Param { Param(const Problem& problem, std::vector> buffers) : problem(problem), n(buffers.size()) { std::copy(buffers.cbegin(), buffers.cend(), buffer); elem_cnt = n * problem.m * problem.n; } Problem problem; Buffer buffer[kMaxProblemBatch]; int n; int elem_cnt; }; template __global__ void InitPtrAndApplyBias(Param p, void** ptr_arr) { if (has_biases) { CUDA_1D_KERNEL_LOOP(i, p.elem_cnt) { const int32_t p_idx = i / (p.problem.m * p.problem.n); const int32_t y_idx = i % (p.problem.m * p.problem.n); const int32_t m_idx = y_idx / p.problem.n; const int32_t n_idx = y_idx % p.problem.n; p.buffer[p_idx].y[y_idx] = p.buffer[p_idx].b[n_idx]; } } CUDA_1D_KERNEL_LOOP(i, p.n) { ptr_arr[i] = const_cast(p.buffer[i].x); ptr_arr[i + kMaxProblemBatch] = const_cast(p.buffer[i].w); ptr_arr[i + 2 * kMaxProblemBatch] = p.buffer[i].y; } } union CublasScalarParameter { double d; float s; half h; }; CublasScalarParameter GetCublasScalarParameter(Scalar scalar, cudaDataType_t compute_type) { CublasScalarParameter sp{}; if (compute_type == CUDA_R_64F) { sp.d = scalar.Value(); } else if (compute_type == CUDA_R_32F) { sp.s = scalar.Value(); } else if (compute_type == CUDA_R_16F) { sp.h = static_cast(scalar.Value()); } else { UNIMPLEMENTED(); } return sp; } template void ApplyGroup(const Problem& problem, std::vector> ptrs, bool has_biases, void* workspace, ep::Stream* stream) { Param params(problem, ptrs); void** ptr_arr = reinterpret_cast(workspace); if (has_biases) { RUN_CUDA_KERNEL((InitPtrAndApplyBias), stream, params.elem_cnt, params, ptr_arr); } else { RUN_CUDA_KERNEL((InitPtrAndApplyBias), stream, params.n, params, ptr_arr); } float alpha = 1.0; float beta = has_biases ? 
1.0 : 0.0; cudaDataType_t data_type{}; cudaDataType_t compute_type{}; if (std::is_same::value) { data_type = CUDA_R_16F; const bool allow_half_accumulation = ParseBooleanFromEnv("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", false); if (allow_half_accumulation) { compute_type = CUDA_R_16F; } else { compute_type = CUDA_R_32F; } } else if (std::is_same::value) { data_type = CUDA_R_32F; compute_type = CUDA_R_32F; } else { UNIMPLEMENTED(); } auto sp_alpha = GetCublasScalarParameter(alpha, compute_type); auto sp_beta = GetCublasScalarParameter(beta, compute_type); OF_CUBLAS_CHECK(cublasGemmBatchedEx( stream->As()->cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, problem.n, problem.m, problem.k, &sp_alpha, ptr_arr + kMaxProblemBatch, data_type, problem.k, ptr_arr, data_type, problem.k, &sp_beta, ptr_arr + 2 * kMaxProblemBatch, data_type, problem.n, params.n, compute_type, CUBLAS_GEMM_DEFAULT)); } template class GroupedMatmulBiasKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: GroupedMatmulBiasKernel() = default; ~GroupedMatmulBiasKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { HashMap>> groups; const int32_t input_size = ctx->input_size("xs"); CHECK_EQ(ctx->input_size("weights"), input_size); const bool has_biases = ctx->has_input("biases", 0); if (has_biases) { CHECK_EQ(ctx->input_size("biases"), input_size); } CHECK_EQ(ctx->output_size("ys"), input_size); for (int32_t i = 0; i < input_size; ++i) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("xs", i); const user_op::Tensor* w = ctx->Tensor4ArgNameAndIndex("weights", i); const user_op::Tensor* b = has_biases ? ctx->Tensor4ArgNameAndIndex("biases", i) : nullptr; user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("ys", i); CHECK_GE(x->shape_view().NumAxes(), 2); const int64_t k = x->shape_view().At(x->shape_view().NumAxes() - 1); const int64_t m = x->shape_view().elem_cnt() / k; CHECK_EQ(w->shape_view().NumAxes(), 2); CHECK_EQ(w->shape_view().At(1), k); const int64_t n = w->shape_view().At(0); if (has_biases) { CHECK_EQ(b->shape_view().NumAxes(), 1); CHECK_EQ(b->shape_view().At(0), n); } CHECK_EQ(y->shape_view().NumAxes(), x->shape_view().NumAxes()); CHECK_EQ(y->shape_view().At(y->shape_view().NumAxes() - 1), n); for (int32_t j = 0; j < y->shape_view().NumAxes() - 1; ++j) { CHECK_EQ(y->shape_view().At(j), x->shape_view().At(j)); } groups[Problem(m, n, k)].push_back(Buffer{ x->dptr(), w->dptr(), has_biases ? 
b->dptr() : nullptr, y->mut_dptr()}); } void* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr(); for (const auto& group : groups) { for (size_t i = 0; i < group.second.size(); i += kMaxProblemBatch) { std::vector> ptrs( {group.second.begin() + i, group.second.begin() + i + std::min(group.second.size() - i, kMaxProblemBatch)}); ApplyGroup(group.first, ptrs, has_biases, workspace, ctx->stream()); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(cpp_type, data_type) \ REGISTER_USER_KERNEL("grouped_matmul_bias") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("ys", 0) == data_type)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ return kMaxProblemBatch * 3 * sizeof(void*); \ }); \ ; REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(float, DataType::kFloat) REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(half, DataType::kFloat16) } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/groupwise_quantization_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { template struct alignas(sizeof(T) * pack_size) AlignedArray { __device__ AlignedArray() { // do nothing } union { T elem[pack_size]; }; }; template struct Cast { __device__ void operator()(const AlignedArray& src, AlignedArray* dst) { #pragma unroll for (int i = 0; i < pack_size; ++i) { dst->elem[i] = static_cast(src.elem[i]); } } }; template struct Cast { __device__ void operator()(const AlignedArray& src, AlignedArray* dst) { #pragma unroll for (int i = 0; i < pack_size; ++i) { dst->elem[i] = static_cast(src.elem[i]); } } __device__ void operator()(const AlignedArray& src, AlignedArray* dst) { #pragma unroll for (int i = 0; i < pack_size; ++i) { const uint8_t q = src.elem[i]; const uint8_t hi = (q >> 4); const uint8_t lo = (q & 0xF); dst->elem[i * 2 + 0] = static_cast(hi); dst->elem[i * 2 + 1] = static_cast(lo); } } }; template struct Cast { __device__ void operator()(const AlignedArray& src, AlignedArray* dst) { #pragma unroll for (int i = 0; i < pack_size; ++i) { dst->elem[i] = static_cast(src.elem[i]); } } __device__ void operator()(const AlignedArray& src, AlignedArray* dst) { #pragma unroll for (int i = 0; i < pack_size; ++i) { const int8_t q = src.elem[i]; const int8_t hi = (q >> 4); int8_t lo = (q << 4); lo = (lo >> 4); dst->elem[i * 2 + 0] = static_cast(hi); dst->elem[i * 2 + 1] = static_cast(lo); } } }; template struct InplaceAddScalar { __device__ void operator()(AlignedArray* array, C scalar) { #pragma unroll for (int i = 0; i < pack_size; ++i) { array->elem[i] 
+= scalar; } } }; template struct InplaceFmaScalar { __device__ void operator()(AlignedArray* array, T m, T a) { #pragma unroll for (int i = 0; i < pack_size; ++i) { array->elem[i] = array->elem[i] * m + a; } } }; #if __CUDA_ARCH__ >= 530 template struct InplaceFmaScalar { __device__ void operator()(AlignedArray* array, half m, half a) { if (pack_size == 1) { #pragma unroll for (int i = 0; i < pack_size; ++i) { array->elem[i] = array->elem[i] * m + a; } } else { const half2 m2 = __half2half2(m); const half2 a2 = __half2half2(a); half2* h2 = reinterpret_cast(array->elem); #pragma unroll for (int i = 0; i < pack_size / 2; ++i) { h2[i] = __hfma2(h2[i], m2, a2); } } } }; #endif // __CUDA_ARCH__ >= 530 template struct InplaceFma { __device__ void operator()(AlignedArray* a, const AlignedArray& b, const AlignedArray& c) { #pragma unroll for (int i = 0; i < pack_size; ++i) { a->elem[i] = a->elem[i] * b.elem[i] + c.elem[i]; } } }; template struct InplaceMulScalar { __device__ void operator()(AlignedArray* a, T b) { #pragma unroll for (int i = 0; i < pack_size; ++i) { a->elem[i] = a->elem[i] * b; } } }; template struct MultiplyAccumulate { __device__ void operator()(const AlignedArray& a, const AlignedArray& b, C* sum) { #pragma unroll for (int i = 0; i < pack_size; ++i) { *sum += static_cast(a.elem[i] * b.elem[i]); } } }; template struct MultiplyAccumulate { __device__ void operator()(const AlignedArray& a, const AlignedArray& b, float* sum) { if (pack_size == 1) { #pragma unroll for (int i = 0; i < pack_size; ++i) { *sum += static_cast(a.elem[i] * b.elem[i]); } } else { const half2* a2 = reinterpret_cast(a.elem); const half2* b2 = reinterpret_cast(b.elem); for (int i = 0; i < pack_size / 2; ++i) { const half2 c2 = __hmul2(a2[i], b2[i]); const float2 f2 = __half22float2(c2); *sum += f2.x; *sum += f2.y; } } } }; template __global__ void Dequantize3D(Index packed_elem_cnt, Index group_size, Index packed_inner_size, const AlignedArray* quantized, const AlignedArray* scale, const AlignedArray* zero, AlignedArray* out) { const Index packed_group_inner_size = group_size * packed_inner_size; CUDA_1D_KERNEL_LOOP_T(Index, i, packed_elem_cnt) { const Index outer_id = outer_size_1 ?
0 : i / packed_group_inner_size; const Index group_inner_offset = i - outer_id * packed_group_inner_size; const Index group_id = group_inner_offset / packed_inner_size; const Index inner_id = group_inner_offset - group_id * packed_inner_size; const Index scale_offset = outer_id * packed_inner_size + inner_id; const AlignedArray group_scale = scale[scale_offset]; AlignedArray group_zero; if (symmetric) { if (std::is_same::value) { group_zero = group_scale; InplaceMulScalar()(&group_zero, -static_cast(((1 << (bits - 1)) - 1))); } else { #pragma unroll for (int i = 0; i < d_pack_size; ++i) { group_zero.elem[i] = 0; } } } else { group_zero = zero[scale_offset]; } AlignedArray values; const AlignedArray q = quantized[i]; Cast()(q, &values); InplaceFma()(&values, group_scale, group_zero); out[i] = values; } } template void LaunchDequantize3D(ep::CudaStream* stream, int64_t outer_size, int64_t group_size, int64_t inner_size, const U* in, const T* scale, const T* zero, T* out) { if constexpr (sizeof(T) * d_pack_size <= 16 && q_pack_size > 0) { const int64_t packed_elem_cnt = outer_size * group_size * inner_size / d_pack_size; const int64_t packed_inner_size = inner_size / d_pack_size; if (packed_elem_cnt <= (1 << 30)) { RUN_CUDA_KERNEL((Dequantize3D), stream, packed_elem_cnt, packed_elem_cnt, group_size, packed_inner_size, reinterpret_cast*>(in), reinterpret_cast*>(scale), reinterpret_cast*>(zero), reinterpret_cast*>(out)); } else { RUN_CUDA_KERNEL((Dequantize3D), stream, packed_elem_cnt, packed_elem_cnt, group_size, packed_inner_size, reinterpret_cast*>(in), reinterpret_cast*>(scale), reinterpret_cast*>(zero), reinterpret_cast*>(out)); } } else { UNIMPLEMENTED(); } } template void DispatchDequantize3DOuterSize1(ep::CudaStream* stream, int64_t outer_size, int64_t group_size, int64_t inner_size, const U* in, const T* scale, const T* zero, T* out) { if (outer_size == 1) { LaunchDequantize3D( stream, outer_size, group_size, inner_size, in, scale, zero, out); } else { LaunchDequantize3D( stream, outer_size, group_size, inner_size, in, scale, zero, out); } } template void DispatchDequantize3D(ep::CudaStream* stream, int64_t outer_size, int64_t group_size, int64_t inner_size, const U* in, const T* scale, const T* zero, T* out) { constexpr int32_t max_pack_size = 16 / sizeof(T); constexpr int32_t data_per_quant = 8 / num_bits; int32_t pack_size = max_pack_size; while (inner_size % pack_size != 0) { pack_size /= 2; } if (pack_size == 16) { DispatchDequantize3DOuterSize1( stream, outer_size, group_size, inner_size, in, scale, zero, out); } else if (pack_size == 8) { DispatchDequantize3DOuterSize1( stream, outer_size, group_size, inner_size, in, scale, zero, out); } else if (pack_size == 4) { DispatchDequantize3DOuterSize1( stream, outer_size, group_size, inner_size, in, scale, zero, out); } else if (pack_size == 2) { DispatchDequantize3DOuterSize1( stream, outer_size, group_size, inner_size, in, scale, zero, out); } else if (pack_size == 1) { DispatchDequantize3DOuterSize1( stream, outer_size, group_size, inner_size, in, scale, zero, out); } else { UNIMPLEMENTED(); } } template __global__ void DequantizeInnerSize1(Index packed_elem_cnt, Index packed_group_size, const AlignedArray* quantized, const T* scale, const T* zero, AlignedArray* out) { CUDA_1D_KERNEL_LOOP_T(Index, i, packed_elem_cnt) { const Index group_id = i / packed_group_size; const T group_scale = scale[group_id]; T group_zero; if (symmetric) { if (std::is_same::value) { group_zero = -static_cast(((1 << (bits - 1)) - 1)) * group_scale; } else { 
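        // (Editor's note: a hedged worked example of the symmetric mapping used here and in
        // Dequantize3D above.) With bits == 4, out = q * scale + zero is centered on zero:
        // for an unsigned payload q in [0, 14] the zero point is -(2^(bits-1) - 1) * scale = -7 * scale,
        // so q = 0 -> -7 * scale, q = 7 -> 0, q = 14 -> +7 * scale; for a signed payload
        // q in [-7, 7] the zero point is simply 0 and out = q * scale.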
group_zero = 0; } } else { group_zero = zero[group_id]; } AlignedArray values; AlignedArray q = quantized[i]; Cast()(q, &values); InplaceFmaScalar()(&values, group_scale, group_zero); out[i] = values; } } template void LaunchDequantizeInnerSize1(ep::CudaStream* stream, int64_t outer_size, int64_t group_size, const U* in, const T* scale, const T* zero, T* out) { if constexpr (sizeof(T) * d_pack_size <= 16 && q_pack_size > 0) { const int64_t packed_elem_cnt = outer_size * group_size / d_pack_size; const int64_t packed_group_size = group_size / d_pack_size; if (packed_elem_cnt <= (1 << 30)) { RUN_CUDA_KERNEL( (DequantizeInnerSize1), stream, packed_elem_cnt, packed_elem_cnt, packed_group_size, reinterpret_cast*>(in), scale, zero, reinterpret_cast*>(out)); } else { RUN_CUDA_KERNEL( (DequantizeInnerSize1), stream, packed_elem_cnt, packed_elem_cnt, packed_group_size, reinterpret_cast*>(in), scale, zero, reinterpret_cast*>(out)); } } else { UNIMPLEMENTED(); } } template void DispatchDequantizeInnerSize1PackSize(ep::CudaStream* stream, int64_t outer_size, int64_t group_size, const U* in, const T* scale, const T* zero, T* out) { constexpr int32_t max_pack_size = 16 / sizeof(T); int32_t pack_size = max_pack_size; while (group_size % pack_size != 0) { pack_size /= 2; } constexpr int32_t data_per_quant = 8 / num_bits; CHECK(group_size % data_per_quant == 0); if (pack_size == 16) { LaunchDequantizeInnerSize1( stream, outer_size, group_size, in, scale, zero, out); } else if (pack_size == 8) { LaunchDequantizeInnerSize1( stream, outer_size, group_size, in, scale, zero, out); } else if (pack_size == 4) { LaunchDequantizeInnerSize1( stream, outer_size, group_size, in, scale, zero, out); } else if (pack_size == 2) { LaunchDequantizeInnerSize1( stream, outer_size, group_size, in, scale, zero, out); } else if (pack_size == 1) { LaunchDequantizeInnerSize1( stream, outer_size, group_size, in, scale, zero, out); } else { UNIMPLEMENTED(); } } template void DispatchDequantizeSize(ep::CudaStream* stream, int64_t outer_size, int64_t group_size, int64_t inner_size, const U* in, const T* scale, const T* zero, T* out) { if (inner_size == 1) { DispatchDequantizeInnerSize1PackSize(stream, outer_size, group_size, in, scale, zero, out); } else { DispatchDequantize3D(stream, outer_size, group_size, inner_size, in, scale, zero, out); } } template void DispatchDequantize(ep::CudaStream* stream, int32_t num_bits, bool symmetric, int64_t outer_size, int64_t group_size, int64_t inner_size, const U* in, const T* scale, const T* zero, T* out) { if (num_bits == 4) { if (symmetric) { DispatchDequantizeSize(stream, outer_size, group_size, inner_size, in, scale, zero, out); } else { DispatchDequantizeSize(stream, outer_size, group_size, inner_size, in, scale, zero, out); } } else if (num_bits == 8) { if (symmetric) { DispatchDequantizeSize(stream, outer_size, group_size, inner_size, in, scale, zero, out); } else { DispatchDequantizeSize(stream, outer_size, group_size, inner_size, in, scale, zero, out); } } else { UNIMPLEMENTED(); } } template class GroupwiseDequantizeKernel final : public user_op::OpKernel { public: GroupwiseDequantizeKernel() = default; ~GroupwiseDequantizeKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); const user_op::Tensor* zero = nullptr; if (ctx->has_input("zero", 0)) { zero = 
ctx->Tensor4ArgNameAndIndex("zero", 0); } user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t group_size = ctx->Attr("group_size"); const int64_t group_dim = ctx->Attr("group_dim"); const int32_t num_bits = ctx->Attr("num_bits"); const bool symmetric = ctx->Attr("symmetric"); const int64_t num_in_axes = in->shape_view().NumAxes(); CHECK_GE(num_in_axes, 1); CHECK_EQ(scale->shape_view().NumAxes(), num_in_axes); if (zero != nullptr) { CHECK_EQ(zero->shape_view().NumAxes(), num_in_axes); } CHECK_EQ(out->shape_view().NumAxes(), num_in_axes); CHECK_GE(group_dim, 0); CHECK_LT(group_dim, num_in_axes); for (int i = 0; i < num_in_axes; ++i) { if (i == num_in_axes - 1) { CHECK_EQ(out->shape_view().At(i), in->shape_view().At(i) * (8 / num_bits)); } else { CHECK_EQ(out->shape_view().At(i), in->shape_view().At(i)); } } const int64_t group_dim_size = out->shape_view().At(group_dim); CHECK_GT(group_size, 0); CHECK_LE(group_size, group_dim_size); CHECK_EQ(group_dim_size % group_size, 0); const int64_t num_groups = group_dim_size / group_size; for (int i = 0; i < num_in_axes; ++i) { const int64_t expected_dim_size = i == group_dim ? num_groups : out->shape_view().At(i); CHECK_EQ(scale->shape_view().At(i), expected_dim_size); if (zero != nullptr) { CHECK_EQ(zero->shape_view().At(i), expected_dim_size); } } const int64_t outer_size = out->shape_view().Count(0, group_dim) * num_groups; const int64_t inner_size = out->shape_view().Count(group_dim + 1); if (in->data_type() == DataType::kUInt8) { DispatchDequantize(ctx->stream()->As(), num_bits, symmetric, outer_size, group_size, inner_size, in->dptr(), scale->dptr(), zero == nullptr ? nullptr : zero->dptr(), out->mut_dptr()); } else if (in->data_type() == DataType::kInt8) { DispatchDequantize(ctx->stream()->As(), num_bits, symmetric, outer_size, group_size, inner_size, in->dptr(), scale->dptr(), zero == nullptr ? nullptr : zero->dptr(), out->mut_dptr()); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_VECTOR_WISE_SYMMETRIC_DEQUANTIZE_KERNEL(dtype) \ REGISTER_USER_KERNEL("groupwise_dequantize") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("scale", 0) == GetDataType::value)) REGISTER_VECTOR_WISE_SYMMETRIC_DEQUANTIZE_KERNEL(half); REGISTER_VECTOR_WISE_SYMMETRIC_DEQUANTIZE_KERNEL(float); template __global__ void QuantizedMatmulBiasGroupN(int32_t M, int32_t N, int32_t K, int32_t group_size, const AlignedArray* __restrict__ x, const AlignedArray* __restrict__ w, const AlignedArray* __restrict__ scale, const AlignedArray* __restrict__ zero, const T* __restrict__ bias, T* __restrict__ out) { for (int32_t m = blockIdx.x; m < M; m += gridDim.x) { const auto* x_m = x + m * K; for (int32_t n = blockIdx.y; n < N; n += gridDim.y) { C t_sum = 0; const auto* w_n = w + n * K; const int64_t group_id = single_group ? 0 : n / group_size; const auto* scale_n = scale + group_id * K; const auto* zero_n = symmetric ? 
nullptr : zero + group_id * K; for (int32_t k = threadIdx.x; k < K; k += block_size) { auto xs = x_m[k]; auto ws = w_n[k]; auto scale_k = scale_n[k]; AlignedArray zero_k; if (symmetric) { if (std::is_same::value) { zero_k = scale_k; InplaceMulScalar()(&zero_k, -static_cast(((1 << (bits - 1)) - 1))); } else { for (int i = 0; i < d_pack_size; ++i) { zero_k.elem[i] = 0; } } } else { zero_k = zero_n[k]; } AlignedArray weights; Cast()(ws, &weights); InplaceFma()(&weights, scale_k, zero_k); MultiplyAccumulate()(xs, weights, &t_sum); } using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; C sum = BlockReduce(temp_storage).Sum(t_sum); if (threadIdx.x == 0) { if (bias != nullptr) { sum += static_cast(bias[n]); } out[m * N + n] = static_cast(sum); } __syncthreads(); } } } template void LaunchMatmulBiasGroupN(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k, int64_t group_size, const T* x, const U* w, const T* scale, const T* zero, const T* bias, T* out) { constexpr uint32_t max_grid_size = 8192; constexpr uint32_t block_size = 128; const int64_t int32_max = std::numeric_limits::max(); if (m * k > int32_max || n * k > int32_max || m * n > int32_max || m > int32_max - max_grid_size || n > int32_max - max_grid_size || k > int32_max - block_size) { UNIMPLEMENTED(); } if constexpr (sizeof(T) * d_pack_size <= 16 && q_pack_size > 0) { QuantizedMatmulBiasGroupN <<(m, max_grid_size), std::min(n, max_grid_size)), block_size, 0, stream->cuda_stream()>>>( m, n, k / d_pack_size, group_size, reinterpret_cast*>(x), reinterpret_cast*>(w), reinterpret_cast*>(scale), reinterpret_cast*>(zero), bias, out); } else { UNIMPLEMENTED(); } } template void DispatchMatmulBiasGroupNSingleGroup(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k, int64_t group_size, const T* x, const U* w, const T* scale, const T* zero, const T* bias, T* out) { if (n == group_size) { LaunchMatmulBiasGroupN( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else { LaunchMatmulBiasGroupN( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } } template void DispatchMatmulBiasGroupNPackSize(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k, int64_t group_size, const T* x, const U* w, const T* scale, const T* zero, const T* bias, T* out) { const int max_pack_size = 16 / sizeof(T); int pack_size = max_pack_size; while (k % pack_size != 0) { pack_size /= 2; } constexpr int32_t data_per_quant = 8 / num_bits; if (pack_size == 16) { DispatchMatmulBiasGroupNSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (pack_size == 8) { DispatchMatmulBiasGroupNSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (pack_size == 4) { DispatchMatmulBiasGroupNSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (pack_size == 2) { DispatchMatmulBiasGroupNSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (pack_size == 1) { DispatchMatmulBiasGroupNSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else { UNIMPLEMENTED(); } } template __global__ void QuantizedMatmulBiasGroupK(int32_t M, int32_t N, int32_t K, int32_t group_size, int32_t num_groups_per_n, const AlignedArray* __restrict__ x, const AlignedArray* __restrict__ w, const T* __restrict__ scale, const T* __restrict__ zero, const T* __restrict__ bias, T* __restrict__ out) { for (int32_t m = blockIdx.x; m < M; m += gridDim.x) { const auto* x_m = x + m * K; for (int32_t n = blockIdx.y; n < N; n 
+= gridDim.y) { C t_sum = 0; const auto* w_n = w + n * K; const auto* scale_n = scale + n * num_groups_per_n; const T* zero_n = symmetric ? nullptr : zero + n * num_groups_per_n; T group_scale; T group_zero; if (single_group) { group_scale = static_cast(scale_n[0]); if (symmetric) { if (std::is_same::value) { group_zero = -static_cast(((1 << (bits - 1)) - 1)) * group_scale; } else { group_zero = 0; } } else { group_zero = zero_n[0]; } } for (int32_t k = threadIdx.x; k < K; k += block_size) { if (!single_group) { auto group_id = k / group_size; group_scale = static_cast(scale_n[group_id]); if (symmetric) { if (std::is_same::value) { group_zero = -static_cast(((1 << (bits - 1)) - 1)) * group_scale; } else { group_zero = 0; } } else { group_zero = zero_n[group_id]; } } auto xs = x_m[k]; auto ws = w_n[k]; AlignedArray weights; Cast()(ws, &weights); InplaceFmaScalar()(&weights, group_scale, group_zero); MultiplyAccumulate()(xs, weights, &t_sum); } using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; C sum = BlockReduce(temp_storage).Sum(t_sum); if (threadIdx.x == 0) { if (bias != nullptr) { sum += static_cast(bias[n]); } out[m * N + n] = static_cast(sum); } __syncthreads(); } } } template void LaunchMatmulBiasGroupK(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k, int64_t group_size, const T* x, const U* w, const T* scale, const T* zero, const T* bias, T* out) { constexpr uint32_t max_grid_size = 8192; constexpr uint32_t block_size = 128; const int64_t int32_max = std::numeric_limits::max(); if (m * k > int32_max || n * k > int32_max || m * n > int32_max || m > int32_max - max_grid_size || n > int32_max - max_grid_size || k > int32_max - block_size) { UNIMPLEMENTED(); } if constexpr (sizeof(T) * d_pack_size <= 16 && q_pack_size > 0) { QuantizedMatmulBiasGroupK <<(m, max_grid_size), std::min(n, max_grid_size)), block_size, 0, stream->cuda_stream()>>>( m, n, k / d_pack_size, group_size / d_pack_size, k / group_size, reinterpret_cast*>(x), reinterpret_cast*>(w), scale, zero, bias, out); } else { UNIMPLEMENTED(); } } template void DispatchMatmulBiasGroupKSingleGroup(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k, int64_t group_size, const T* x, const U* w, const T* scale, const T* zero, const T* bias, T* out) { if (k == group_size) { LaunchMatmulBiasGroupK( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else { LaunchMatmulBiasGroupK( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } } template void DispatchMatmulBiasGroupKPackSize(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k, int64_t group_size, const T* x, const U* w, const T* scale, const T* zero, const T* bias, T* out) { const int max_pack_size = 16 / sizeof(T); int pack_size = max_pack_size; while (group_size % pack_size != 0) { pack_size /= 2; } constexpr int32_t data_per_quant = 8 / num_bits; if (pack_size == 16) { DispatchMatmulBiasGroupKSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (pack_size == 8) { DispatchMatmulBiasGroupKSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (pack_size == 4) { DispatchMatmulBiasGroupKSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (pack_size == 2) { DispatchMatmulBiasGroupKSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (pack_size == 1) { DispatchMatmulBiasGroupKSingleGroup( stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else { UNIMPLEMENTED(); } } template 
void DispatchMatmulBiasGroupDim(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k, int64_t group_dim, int64_t group_size, const T* x, const U* w, const T* scale, const T* zero, const T* bias, T* out) { if (group_dim == 0) { DispatchMatmulBiasGroupNPackSize(stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else if (group_dim == 1) { DispatchMatmulBiasGroupKPackSize(stream, m, n, k, group_size, x, w, scale, zero, bias, out); } else { UNIMPLEMENTED(); } } template void DispatchMatmulBias(ep::CudaStream* stream, int num_bits, bool symmetric, int64_t m, int64_t n, int64_t k, int64_t group_dim, int64_t group_size, const T* x, const U* w, const T* scale, const T* zero, const T* bias, T* out) { if (num_bits == 4) { if (symmetric) { DispatchMatmulBiasGroupDim(stream, m, n, k, group_dim, group_size, x, w, scale, zero, bias, out); } else { DispatchMatmulBiasGroupDim(stream, m, n, k, group_dim, group_size, x, w, scale, zero, bias, out); } } else if (num_bits == 8) { if (symmetric) { DispatchMatmulBiasGroupDim(stream, m, n, k, group_dim, group_size, x, w, scale, zero, bias, out); } else { DispatchMatmulBiasGroupDim(stream, m, n, k, group_dim, group_size, x, w, scale, zero, bias, out); } } else { UNIMPLEMENTED(); } } template class FusedLinearWithGroupwiseQuantizedWeightKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedLinearWithGroupwiseQuantizedWeightKernel() = default; ~FusedLinearWithGroupwiseQuantizedWeightKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* w = ctx->Tensor4ArgNameAndIndex("w", 0); const user_op::Tensor* w_scale = ctx->Tensor4ArgNameAndIndex("w_scale", 0); const user_op::Tensor* b = (ctx->has_input("b", 0)) ? ctx->Tensor4ArgNameAndIndex("b", 0) : nullptr; const user_op::Tensor* w_zero = (ctx->has_input("w_zero", 0)) ? ctx->Tensor4ArgNameAndIndex("w_zero", 0) : nullptr; user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const DataType data_type = x->data_type(); CHECK_EQ(w_scale->data_type(), data_type); CHECK_EQ(out->data_type(), data_type); const int64_t group_size = ctx->Attr("group_size"); const int64_t group_dim = ctx->Attr("group_dim"); CHECK(group_dim == 0 || group_dim == 1); const int32_t num_bits = ctx->Attr("num_bits"); const bool symmetric = ctx->Attr("symmetric"); CHECK_GE(x->shape_view().NumAxes(), 2); const int64_t k = x->shape_view().At(x->shape_view().NumAxes() - 1); const int64_t m = x->shape_view().elem_cnt() / k; CHECK_EQ(w->shape_view().NumAxes(), 2); if (num_bits == 4) { CHECK_EQ(w->shape_view().At(1) * 2, k); } else if (num_bits == 8) { CHECK_EQ(w->shape_view().At(1), k); } else { UNIMPLEMENTED(); } const int64_t n = w->shape_view().At(0); const int64_t group_dim_size = group_dim == 0 ? 
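    // (Editor's note: a hedged worked example with made-up sizes.) group_dim selects which GEMM
    // dimension is quantized in groups. With n = 4096, k = 11008 and group_size = 128:
    // group_dim == 1 groups along k, so num_groups = 11008 / 128 = 86 and w_scale is (n, num_groups)
    // = (4096, 86); group_dim == 0 groups along n, so num_groups = 4096 / 128 = 32 and w_scale is
    // (num_groups, k) = (32, 11008), as the shape checks below require.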
n : k; CHECK_GT(group_size, 0); CHECK_LE(group_size, group_dim_size); CHECK_EQ(group_dim_size % group_size, 0); const int64_t num_groups = group_dim_size / group_size; if (group_dim == 0) { CHECK_EQ(w_scale->shape_view().At(0), num_groups); CHECK_EQ(w_scale->shape_view().At(1), k); } else if (group_dim == 1) { CHECK_EQ(w_scale->shape_view().At(0), n); CHECK_EQ(w_scale->shape_view().At(1), num_groups); } else { UNIMPLEMENTED(); } if (w_zero != nullptr) { CHECK_EQ(w_zero->data_type(), data_type); CHECK(w_zero->shape_view() == w_scale->shape_view()); } if (b != nullptr) { CHECK_EQ(b->data_type(), data_type); CHECK_EQ(b->shape_view().NumAxes(), 1); CHECK_EQ(b->shape_view().At(0), n); } CHECK_EQ(x->shape_view().NumAxes(), out->shape_view().NumAxes()); for (int i = 0; i < x->shape_view().NumAxes() - 1; ++i) { CHECK_EQ(out->shape_view().At(i), x->shape_view().At(i)); } CHECK_EQ(out->shape_view().At(out->shape_view().NumAxes() - 1), n); if (symmetric) { CHECK(w_zero == nullptr); } else { CHECK(w_zero != nullptr); } const DataType quant_type = w->data_type(); if (quant_type == DataType::kUInt8) { DispatchMatmulBias( ctx->stream()->As(), num_bits, symmetric, m, n, k, group_dim, group_size, x->dptr(), w->dptr(), w_scale->dptr(), w_zero == nullptr ? nullptr : w_zero->dptr(), b == nullptr ? nullptr : b->dptr(), out->mut_dptr()); } else if (quant_type == DataType::kInt8) { DispatchMatmulBias( ctx->stream()->As(), num_bits, symmetric, m, n, k, group_dim, group_size, x->dptr(), w->dptr(), w_scale->dptr(), w_zero == nullptr ? nullptr : w_zero->dptr(), b == nullptr ? nullptr : b->dptr(), out->mut_dptr()); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(data_type, cpp_type) \ REGISTER_USER_KERNEL("fused_linear_with_groupwise_quantized_weight") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == data_type)); REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat, float); REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat16, half); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/host_scalar_add_by_tensor_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/radix_sort.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void ScalarAdd(int32_t elem_cnt, const T* in, const T scalar, T* out) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { out[i] = in[i] + scalar; }; } } // namespace template class HostScalarAddByTensorKernel final : public user_op::OpKernel { public: HostScalarAddByTensorKernel() = default; ~HostScalarAddByTensorKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* scalar = ctx->Tensor4ArgNameAndIndex("scalar", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = x->shape_view().elem_cnt(); CHECK_EQ(scalar->shape_view().elem_cnt(), 1); // val of scalar can be visited because it is host input. const T scalar_val = *scalar->dptr(); ScalarAdd<<stream()->As()->cuda_stream()>>>(elem_cnt, x->dptr(), scalar_val, y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_ARG_SORT_KERNEL(dtype) \ REGISTER_USER_KERNEL("host_scalar_add_by_tensor") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value) \ && (user_op::HobDataType("scalar", 0) == GetDataType::value) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_CUDA_ARG_SORT_KERNEL(float) REGISTER_CUDA_ARG_SORT_KERNEL(double) REGISTER_CUDA_ARG_SORT_KERNEL(bool) REGISTER_CUDA_ARG_SORT_KERNEL(int8_t) REGISTER_CUDA_ARG_SORT_KERNEL(uint8_t) REGISTER_CUDA_ARG_SORT_KERNEL(int32_t) REGISTER_CUDA_ARG_SORT_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/image_batch_align_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/user/image/image_util.h" #include namespace oneflow { namespace { template void CopyFromTensorBuffer(T* image_ptr, const TensorBuffer& image_buffer, const int batch_height, const int batch_width, const int channels) { CHECK_EQ(image_buffer.shape_view().NumAxes(), 3); const int h = image_buffer.shape_view().At(0); const int w = image_buffer.shape_view().At(1); const int c = image_buffer.shape_view().At(2); CHECK_LE(h, batch_height); CHECK_LE(w, batch_width); CHECK_EQ(c, channels); FOR_RANGE(int, i, 0, h) { const F* from = image_buffer.data() + i * w * c; T* to = image_ptr + i * batch_width * channels; std::transform(from, from + w * c, to, [](F v) { return static_cast(v); }); } } template struct ImageCopier final { #define MAKE_COPY_FROM_TENSOR_BUFFER_SWITCH_ENTRY(func_name, F) func_name DEFINE_STATIC_SWITCH_FUNC(void, CopyFromTensorBuffer, MAKE_COPY_FROM_TENSOR_BUFFER_SWITCH_ENTRY, MAKE_DATA_TYPE_CTRV_SEQ(IMAGE_DATA_TYPE_SEQ)) #undef MAKE_COPY_FROM_TENSOR_BUFFER_SWITCH_ENTRY }; } // namespace template class ImageBatchAlignKernel final : public user_op::OpKernel { public: ImageBatchAlignKernel() = default; ~ImageBatchAlignKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in_tensor->shape_view().NumAxes(), 1); CHECK_EQ(out_tensor->shape_view().NumAxes(), 4); const int64_t num_images = in_tensor->shape_view().elem_cnt(); const bool dynamic_out = ctx->Attr("dynamic_out"); CHECK_GT(num_images, 0); int64_t max_height = 0; int64_t max_width = 0; const int64_t channels = out_tensor->shape_view().At(3); FOR_RANGE(int, i, 0, num_images) { const TensorBuffer& image_buffer = in_tensor->dptr()[i]; max_height = std::max(max_height, image_buffer.shape_view().At(0)); max_width = std::max(max_width, image_buffer.shape_view().At(1)); CHECK_EQ(image_buffer.shape_view().At(2), channels); } int32_t alignment = ctx->Attr("alignment"); max_height = RoundUp(max_height, alignment); max_width = RoundUp(max_width, alignment); if (dynamic_out) { auto mut_shape_view = out_tensor->mut_shape_view(); mut_shape_view.Set(0, num_images); mut_shape_view.Set(1, max_height); mut_shape_view.Set(2, max_width); } memset(out_tensor->mut_dptr(), 0, out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(out_tensor->data_type())); MultiThreadLoop(num_images, [&](size_t i) { const TensorBuffer& image_buffer = in_tensor->dptr()[i]; T* out_ptr = out_tensor->mut_dptr() + i * max_height * max_width * channels; ImageCopier::SwitchCopyFromTensorBuffer(SwitchCase(image_buffer.data_type()), out_ptr, image_buffer, max_height, max_width, channels); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_IMAGE_BATCH_ALIGN_KERNEL(dtype) \ REGISTER_USER_KERNEL("image_batch_align") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_IMAGE_BATCH_ALIGN_KERNEL(uint8_t) REGISTER_IMAGE_BATCH_ALIGN_KERNEL(float) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/image_decode_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/user/image/image_util.h" #include namespace oneflow { namespace { void DecodeImage(const TensorBuffer& raw_bytes, TensorBuffer* image_buffer, const std::string& color_space, DataType data_type) { // should only support kChar, but numpy ndarray maybe cannot convert to char* CHECK(raw_bytes.data_type() == DataType::kChar || raw_bytes.data_type() == DataType::kInt8 || raw_bytes.data_type() == DataType::kUInt8); cv::_InputArray raw_bytes_arr(raw_bytes.data(), raw_bytes.elem_cnt()); cv::Mat image_mat = cv::imdecode( raw_bytes_arr, (ImageUtil::IsColor(color_space) ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE) | cv::IMREAD_ANYDEPTH); if (ImageUtil::IsColor(color_space) && color_space != "BGR") { ImageUtil::ConvertColor("BGR", image_mat, color_space, image_mat); } if (data_type == DataType::kUInt8) { image_mat.convertTo(image_mat, CV_8U); } else if (data_type == DataType::kFloat) { image_mat.convertTo(image_mat, CV_32F); } else { UNIMPLEMENTED(); } int64_t h = image_mat.rows; int64_t w = image_mat.cols; int64_t c = image_mat.channels(); image_buffer->Resize(Shape({h, w, c}), data_type); w *= c; if (image_mat.isContinuous()) { w *= h; h = 1; } char* image_ptr = image_buffer->mut_data(); FOR_RANGE(int64_t, i, 0, h) { memcpy(image_ptr + i * w, image_mat.ptr(i), w * GetSizeOfDataType(data_type)); } } } // namespace class ImageDecodeKernel final : public user_op::OpKernel { public: ImageDecodeKernel() = default; ~ImageDecodeKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in_tensor->shape_view().elem_cnt(), out_tensor->shape_view().elem_cnt()); CHECK_GT(in_tensor->shape_view().elem_cnt(), 0); const TensorBuffer* in_img_buf = in_tensor->dptr(); TensorBuffer* out_img_buf = out_tensor->mut_dptr(); const std::string& color_space = ctx->Attr("color_space"); const DataType data_type = ctx->Attr("data_type"); MultiThreadLoop(in_tensor->shape_view().elem_cnt(), [&](size_t i) { DecodeImage(in_img_buf[i], out_img_buf + i, color_space, data_type); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("image_decode") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)); ; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/image_object_preprocess_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/user/image/image_util.h" #include #include namespace oneflow { namespace { enum class FlipCode : int8_t { kNonFlip = 0x00, kHorizontalFlip = 0x01, kVerticalFlip = 0x02, kBothDirectionFlip = 0x03, }; bool operator&(FlipCode lhs, FlipCode rhs) { return static_cast(static_cast::type>(lhs) & static_cast::type>(rhs)); } int CvFlipCode(FlipCode flip_code) { if (flip_code == FlipCode::kHorizontalFlip) { return 1; } else if (flip_code == FlipCode::kVerticalFlip) { return 0; } else if (flip_code == FlipCode::kBothDirectionFlip) { return -1; } else { UNIMPLEMENTED(); } } void FlipImage(TensorBuffer* image_buffer, FlipCode flip_code) { cv::Mat image_mat = GenCvMat4ImageBuffer(*image_buffer); cv::flip(image_mat, image_mat, CvFlipCode(flip_code)); } template void FlipBoxes(TensorBuffer* boxes_buffer, int32_t image_width, int32_t image_height, FlipCode flip_code) { int num_boxes = boxes_buffer->shape_view().At(0); FOR_RANGE(int, i, 0, num_boxes) { T* cur_box_ptr = boxes_buffer->mut_data() + i * 4; if (flip_code & FlipCode::kHorizontalFlip) { T xmin = cur_box_ptr[0]; T xmax = cur_box_ptr[2]; cur_box_ptr[0] = image_width - xmax - static_cast(1); cur_box_ptr[2] = image_width - xmin - static_cast(1); } if (flip_code & FlipCode::kVerticalFlip) { T ymin = cur_box_ptr[1]; T ymax = cur_box_ptr[3]; cur_box_ptr[1] = image_height - ymax - static_cast(1); cur_box_ptr[3] = image_height - ymin - static_cast(1); } } } #define MAKE_FLIP_BOXES_SWITCH_ENTRY(func_name, T) func_name DEFINE_STATIC_SWITCH_FUNC(void, FlipBoxes, MAKE_FLIP_BOXES_SWITCH_ENTRY, MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ)); #undef MAKE_FLIP_BOXES_SWITCH_ENTRY template void ScaleBoxes(TensorBuffer* boxes_buffer, T scale_w, T scale_h) { int num_boxes = boxes_buffer->shape_view().At(0); FOR_RANGE(int, i, 0, num_boxes) { T* cur_box_ptr = boxes_buffer->mut_data() + i * 4; cur_box_ptr[0] *= scale_w; cur_box_ptr[1] *= scale_h; cur_box_ptr[2] *= scale_w; cur_box_ptr[3] *= scale_h; } } #define MAKE_SCALE_BOXES_SWITCH_ENTRY(func_name, T) func_name DEFINE_STATIC_SWITCH_FUNC(void, ScaleBoxes, MAKE_SCALE_BOXES_SWITCH_ENTRY, MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ)); #undef MAKE_SCALE_BOXES_SWITCH_ENTRY template void FlipPolygons(TensorBuffer* polygons_buffer, int32_t image_width, int32_t image_height, FlipCode flip_code) { int num_points = polygons_buffer->shape_view().At(0); FOR_RANGE(int, i, 0, num_points) { T* cur_poly_ptr = polygons_buffer->mut_data() + i * 2; if (flip_code & FlipCode::kHorizontalFlip) { cur_poly_ptr[0] = image_width - cur_poly_ptr[0]; } if (flip_code & FlipCode::kVerticalFlip) { cur_poly_ptr[1] = image_height - cur_poly_ptr[1]; } } } #define MAKE_FLIP_POLYGONS_SWITCH_ENTRY(func_name, T) func_name DEFINE_STATIC_SWITCH_FUNC(void, FlipPolygons, MAKE_FLIP_POLYGONS_SWITCH_ENTRY, MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ)); #undef MAKE_FLIP_POLYGONS_SWITCH_ENTRY template void ScalePolygons(TensorBuffer* poly_buffer, T scale_w, T scale_h) { int num_pts = poly_buffer->shape_view().At(0); FOR_RANGE(int, i, 0, num_pts) { T* cur_pt = 
poly_buffer->mut_data() + i * 2; cur_pt[0] *= scale_w; cur_pt[1] *= scale_h; } } #define MAKE_SCALE_POLYGONS_SWITCH_ENTRY(func_name, T) func_name DEFINE_STATIC_SWITCH_FUNC(void, ScalePolygons, MAKE_SCALE_POLYGONS_SWITCH_ENTRY, MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ)); #undef MAKE_SCALE_POLYGONS_SWITCH_ENTRY template void ImageNormalizeByChannel(TensorBuffer* image_buffer, const std::vector& std_vec, const std::vector& mean_vec) { CHECK_EQ(image_buffer->shape_view().NumAxes(), 3); int h = image_buffer->shape_view().At(0); int w = image_buffer->shape_view().At(1); int c = image_buffer->shape_view().At(2); CHECK_EQ(std_vec.size(), c); CHECK_EQ(mean_vec.size(), c); FOR_RANGE(int, i, 0, (h * w)) { T* image_data = image_buffer->mut_data() + i * c; FOR_RANGE(int, j, 0, c) { image_data[j] = (image_data[j] - mean_vec.at(j)) / std_vec.at(j); } } } #define MAKE_IMAGE_NORMALIZE_SWITCH_ENTRY(func_name, T) func_name DEFINE_STATIC_SWITCH_FUNC(void, ImageNormalizeByChannel, MAKE_IMAGE_NORMALIZE_SWITCH_ENTRY, MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ)); #undef MAKE_IMAGE_NORMALIZE_SWITCH_ENTRY template void PolygonsToMask(const TensorBuffer& polys, const TensorBuffer& polys_nd_index, TensorBuffer* masks, int32_t im_w, int32_t im_h) { CHECK_EQ(polys.shape_view().NumAxes(), 2); CHECK_EQ(polys.shape_view().At(1), 2); CHECK_EQ(polys_nd_index.shape_view().NumAxes(), 2); CHECK_EQ(polys_nd_index.shape_view().At(1), 3); int num_points = polys.shape_view().At(0); CHECK_EQ(polys_nd_index.shape_view().At(0), num_points); std::vector> poly_point_vec; std::vector mask_mat_vec; auto PolyToMask = [&]() { CHECK_GT(poly_point_vec.size(), 0); CHECK_GT(poly_point_vec.front().size(), 0); cv::Mat mask_mat = cv::Mat(im_h, im_w, CV_8SC1, cv::Scalar(0)); cv::fillPoly(mask_mat, poly_point_vec, cv::Scalar(1), cv::LINE_8); mask_mat_vec.emplace_back(std::move(mask_mat)); poly_point_vec.clear(); }; int origin_round_way = std::fegetround(); CHECK_EQ(std::fesetround(FE_TONEAREST), 0); FOR_RANGE(int, i, 0, num_points) { const I pt_idx = polys_nd_index.data()[i * 3 + 0]; const I poly_idx = polys_nd_index.data()[i * 3 + 1]; const I segm_idx = polys_nd_index.data()[i * 3 + 2]; if (segm_idx != mask_mat_vec.size()) { PolyToMask(); } if (poly_idx == poly_point_vec.size()) { poly_point_vec.emplace_back(std::vector()); } CHECK_EQ(segm_idx, mask_mat_vec.size()); CHECK_EQ(poly_idx, poly_point_vec.size() - 1); CHECK_EQ(pt_idx, poly_point_vec.back().size()); const T* pts_ptr = polys.data() + i * 2; cv::Point pt{static_cast(std::nearbyint(pts_ptr[0])), static_cast(std::nearbyint(pts_ptr[1]))}; poly_point_vec.back().emplace_back(std::move(pt)); } PolyToMask(); CHECK_EQ(std::fesetround(origin_round_way), 0); masks->Resize(Shape({static_cast(mask_mat_vec.size()), static_cast(im_h), static_cast(im_w)}), DataType::kInt8); int mask_idx = 0; for (const auto& mask_mat : mask_mat_vec) { CHECK(mask_mat.isContinuous()); CHECK_EQ(mask_mat.total(), im_h * im_w); memcpy(masks->mut_data() + mask_idx * im_h * im_w, mask_mat.ptr(), mask_mat.total() * sizeof(int8_t)); mask_idx += 1; } } #define MAKE_POLYGONS_TO_MASK_SWITCH_ENTRY(func_name, T, I) func_name DEFINE_STATIC_SWITCH_FUNC(void, PolygonsToMask, MAKE_POLYGONS_TO_MASK_SWITCH_ENTRY, MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ), MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ)); #undef MAKE_POLYGONS_TO_MASK_SWITCH_ENTRY } // namespace class ImageFlipKernel final : public user_op::OpKernel { public: ImageFlipKernel() = default; ~ImageFlipKernel() = default; private: void 
Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* flip_code_tensor = ctx->Tensor4ArgNameAndIndex("flip_code", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); int num_images = in_tensor->shape_view().elem_cnt(); CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images); MultiThreadLoop(num_images, [&](size_t i) { const TensorBuffer& in_buffer = in_tensor->dptr()[i]; CHECK_EQ(in_buffer.shape_view().NumAxes(), 3); TensorBuffer* out_buffer = out_tensor->mut_dptr() + i; out_buffer->CopyFrom(in_buffer); FlipCode flip_code = static_cast(flip_code_tensor->dptr()[i]); if (flip_code != FlipCode::kNonFlip) { FlipImage(out_buffer, flip_code); } }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class ObjectBboxFlipKernel final : public user_op::OpKernel { public: ObjectBboxFlipKernel() = default; ~ObjectBboxFlipKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* bbox_tensor = ctx->Tensor4ArgNameAndIndex("bbox", 0); const user_op::Tensor* image_size_tensor = ctx->Tensor4ArgNameAndIndex("image_size", 0); const user_op::Tensor* flip_code_tensor = ctx->Tensor4ArgNameAndIndex("flip_code", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); int num_images = bbox_tensor->shape_view().elem_cnt(); CHECK_GT(num_images, 0); CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images); CHECK_EQ(image_size_tensor->shape_view().At(0), num_images); CHECK_EQ(flip_code_tensor->shape_view().elem_cnt(), num_images); MultiThreadLoop(num_images, [&](size_t i) { const TensorBuffer& bbox_buffer = bbox_tensor->dptr()[i]; CHECK_EQ(bbox_buffer.shape_view().NumAxes(), 2); CHECK_EQ(bbox_buffer.shape_view().At(1), 4); TensorBuffer* out_bbox_buffer = out_tensor->mut_dptr() + i; out_bbox_buffer->CopyFrom(bbox_buffer); int32_t image_width = image_size_tensor->dptr()[i * 2 + 0]; int32_t image_height = image_size_tensor->dptr()[i * 2 + 1]; FlipCode flip_code = static_cast(flip_code_tensor->dptr()[i]); SwitchFlipBoxes(SwitchCase(out_bbox_buffer->data_type()), out_bbox_buffer, image_width, image_height, flip_code); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class ObjectBboxScaleKernel final : public user_op::OpKernel { public: ObjectBboxScaleKernel() = default; ~ObjectBboxScaleKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* bbox_tensor = ctx->Tensor4ArgNameAndIndex("bbox", 0); const user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex("scale", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); int num_images = bbox_tensor->shape_view().elem_cnt(); CHECK_GT(num_images, 0); CHECK_EQ(scale_tensor->shape_view().At(0), num_images); CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images); MultiThreadLoop(num_images, [&](size_t i) { const TensorBuffer& bbox_buffer = bbox_tensor->dptr()[i]; CHECK_EQ(bbox_buffer.shape_view().NumAxes(), 2); CHECK_EQ(bbox_buffer.shape_view().At(1), 4); TensorBuffer* out_bbox_buffer = out_tensor->mut_dptr() + i; out_bbox_buffer->CopyFrom(bbox_buffer); float scale_w = scale_tensor->dptr()[i * 2 + 0]; float scale_h = scale_tensor->dptr()[i * 2 + 1]; SwitchScaleBoxes(SwitchCase(out_bbox_buffer->data_type()), out_bbox_buffer, scale_w, scale_h); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class 
ObjectSegmentationPolygonFlipKernel final : public user_op::OpKernel { public: ObjectSegmentationPolygonFlipKernel() = default; ~ObjectSegmentationPolygonFlipKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* polygon_tensor = ctx->Tensor4ArgNameAndIndex("poly", 0); const user_op::Tensor* image_size_tensor = ctx->Tensor4ArgNameAndIndex("image_size", 0); const user_op::Tensor* flip_code_tensor = ctx->Tensor4ArgNameAndIndex("flip_code", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); int num_images = polygon_tensor->shape_view().elem_cnt(); CHECK_GT(num_images, 0); CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images); CHECK_EQ(image_size_tensor->shape_view().At(0), num_images); CHECK_EQ(flip_code_tensor->shape_view().elem_cnt(), num_images); MultiThreadLoop(num_images, [&](size_t i) { const TensorBuffer& polygons_buffer = polygon_tensor->dptr()[i]; CHECK_EQ(polygons_buffer.shape_view().NumAxes(), 2); CHECK_EQ(polygons_buffer.shape_view().At(1), 2); TensorBuffer* out_polygons_buffer = out_tensor->mut_dptr() + i; out_polygons_buffer->CopyFrom(polygons_buffer); int32_t image_width = image_size_tensor->dptr()[i * 2 + 0]; int32_t image_height = image_size_tensor->dptr()[i * 2 + 1]; FlipCode flip_code = static_cast(flip_code_tensor->dptr()[i]); SwitchFlipPolygons(SwitchCase(out_polygons_buffer->data_type()), out_polygons_buffer, image_width, image_height, flip_code); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class ObjectSegmentationPolygonScaleKernel final : public user_op::OpKernel { public: ObjectSegmentationPolygonScaleKernel() = default; ~ObjectSegmentationPolygonScaleKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* poly_tensor = ctx->Tensor4ArgNameAndIndex("poly", 0); const user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex("scale", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); int num_images = poly_tensor->shape_view().elem_cnt(); CHECK_GT(num_images, 0); CHECK_EQ(scale_tensor->shape_view().At(0), num_images); CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images); MultiThreadLoop(num_images, [&](size_t i) { const TensorBuffer& poly_buffer = poly_tensor->dptr()[i]; CHECK_EQ(poly_buffer.shape_view().NumAxes(), 2); CHECK_EQ(poly_buffer.shape_view().At(1), 2); TensorBuffer* out_poly_buffer = out_tensor->mut_dptr() + i; out_poly_buffer->CopyFrom(poly_buffer); float scale_w = scale_tensor->dptr()[i * 2 + 0]; float scale_h = scale_tensor->dptr()[i * 2 + 1]; SwitchScalePolygons(SwitchCase(out_poly_buffer->data_type()), out_poly_buffer, scale_w, scale_h); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class ImageNormalize final : public user_op::OpKernel { public: ImageNormalize() = default; ~ImageNormalize() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); int num_images = in_tensor->shape_view().elem_cnt(); CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images); const auto& std_vec = ctx->Attr>("std"); const auto& mean_vec = ctx->Attr>("mean"); MultiThreadLoop(num_images, [&](size_t i) { const TensorBuffer& in_buffer = in_tensor->dptr()[i]; CHECK_EQ(in_buffer.shape_view().NumAxes(), 3); TensorBuffer* out_buffer = out_tensor->mut_dptr() + i; 
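      // (Editor's note: a hedged worked example with illustrative attr values.) The normalize call
      // below applies pixel[c] = (pixel[c] - mean[c]) / std[c] per channel; with mean = {127.5,
      // 127.5, 127.5} and std = {58.4, 57.1, 57.4}, a pixel {128, 64, 200} maps to roughly
      // {0.0086, -1.112, 1.263}.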
out_buffer->CopyFrom(in_buffer); SwitchImageNormalizeByChannel(SwitchCase(out_buffer->data_type()), out_buffer, std_vec, mean_vec); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class ObjectSegmentationPolygonToMask final : public user_op::OpKernel { public: ObjectSegmentationPolygonToMask() = default; ~ObjectSegmentationPolygonToMask() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* poly_tensor = ctx->Tensor4ArgNameAndIndex("poly", 0); const user_op::Tensor* poly_index_tensor = ctx->Tensor4ArgNameAndIndex("poly_index", 0); const user_op::Tensor* image_size_tensor = ctx->Tensor4ArgNameAndIndex("image_size", 0); user_op::Tensor* mask_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); int num_images = poly_tensor->shape_view().elem_cnt(); CHECK_GT(num_images, 0); CHECK_EQ(poly_index_tensor->shape_view().elem_cnt(), num_images); CHECK_EQ(image_size_tensor->shape_view().At(0), num_images); CHECK_EQ(mask_tensor->shape_view().elem_cnt(), num_images); MultiThreadLoop(num_images, [&](size_t i) { const TensorBuffer& poly_buffer = poly_tensor->dptr()[i]; const TensorBuffer& poly_index_buffer = poly_index_tensor->dptr()[i]; int32_t image_width = image_size_tensor->dptr()[i * 2 + 0]; int32_t image_height = image_size_tensor->dptr()[i * 2 + 1]; TensorBuffer* mask_buffer = mask_tensor->mut_dptr() + i; SwitchPolygonsToMask(SwitchCase(poly_buffer.data_type(), poly_index_buffer.data_type()), poly_buffer, poly_index_buffer, mask_buffer, image_width, image_height); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; namespace { std::function(const user_op::InferContext&, user_op::AddInplaceArgPair)> MakeInplaceProposalFn(const std::string& input_arg_name) { return [input_arg_name](const user_op::InferContext& ctx, user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, input_arg_name, 0, true)); return Maybe::Ok(); }; } } // namespace REGISTER_USER_KERNEL("image_flip") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("flip_code", 0) == DataType::kInt8) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)) .SetInplaceProposalFn(MakeInplaceProposalFn("in")); REGISTER_USER_KERNEL("object_bbox_flip") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("bbox", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("image_size", 0) == DataType::kInt32) && (user_op::HobDataType("flip_code", 0) == DataType::kInt8) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)) .SetInplaceProposalFn(MakeInplaceProposalFn("bbox")); REGISTER_USER_KERNEL("object_bbox_scale") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("bbox", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("scale", 0) == DataType::kFloat) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)) .SetInplaceProposalFn(MakeInplaceProposalFn("bbox")); REGISTER_USER_KERNEL("object_segmentation_polygon_flip") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("poly", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("image_size", 0) == DataType::kInt32) && (user_op::HobDataType("flip_code", 0) == DataType::kInt8) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)) 
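// The inplace proposal attached below (built by MakeInplaceProposalFn above) suggests that
// "out" may share the buffer of the listed input ("poly" here); the framework decides whether
// to honor it, which is why the kernels above still CopyFrom the source buffer before mutating it.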
.SetInplaceProposalFn(MakeInplaceProposalFn("poly")); REGISTER_USER_KERNEL("object_segmentation_polygon_scale") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("poly", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("scale", 0) == DataType::kFloat) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)) .SetInplaceProposalFn(MakeInplaceProposalFn("poly")); REGISTER_USER_KERNEL("image_normalize") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)) .SetInplaceProposalFn(MakeInplaceProposalFn("in")); REGISTER_USER_KERNEL("object_segmentation_polygon_to_mask") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("poly", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("poly_index", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("image_size", 0) == DataType::kInt32) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/image_preprocess_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/memory_format.pb.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/user/image/image_util.h" #include "oneflow/user/kernels/random_crop_kernel_state.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace { template inline int64_t GetOffset(int64_t h, int64_t w, int64_t c, int64_t H, int64_t W, int64_t C); template<> inline int64_t GetOffset(int64_t h, int64_t w, int64_t c, int64_t H, int64_t W, int64_t C) { return c * H * W + h * W + w; // C, H, W } template<> inline int64_t GetOffset(int64_t h, int64_t w, int64_t c, int64_t H, int64_t W, int64_t C) { return h * W * C + w * C + c; // H, W, C } template inline int64_t GetInputW(int64_t out_w, int64_t out_W, int64_t in_W, float crop_pos_x); template<> inline int64_t GetInputW(int64_t out_w, int64_t out_W, int64_t in_W, float crop_pos_x) { return (in_W - out_W) * crop_pos_x + (out_W - 1 - out_w); } template<> inline int64_t GetInputW(int64_t out_w, int64_t out_W, int64_t in_W, float crop_pos_x) { return (in_W - out_W) * crop_pos_x + out_w; } template void CMN1Sample(int64_t C, int64_t in_H, int64_t in_W, int64_t out_H, int64_t out_W, float crop_pos_y, float crop_pos_x, const uint8_t* in_dptr, float* out_dptr, const std::vector& mean_vec, const std::vector& inv_std_vec) { CHECK_LE(out_H, in_H); CHECK_LE(out_W, in_W); for (int64_t c = 0; c < C; ++c) { float mean = mean_vec.at(c); float inv_std = inv_std_vec.at(c); for (int64_t out_h = 0; out_h < out_H; ++out_h) { int64_t in_h = (in_H - out_H) * crop_pos_y + out_h; for (int64_t out_w = 0; out_w < out_W; ++out_w) { int64_t in_w = GetInputW(out_w, out_W, in_W, crop_pos_x); int64_t in_offset = GetOffset(in_h, in_w, c, in_H, in_W, C); int64_t out_offset = GetOffset(out_h, out_w, c, out_H, out_W, C); out_dptr[out_offset] = (static_cast(in_dptr[in_offset]) - mean) * inv_std; } } } } std::vector GetMirrorVec(user_op::KernelComputeContext* ctx) { std::vector mirror; user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* mirror_blob = ctx->Tensor4ArgNameAndIndex("mirror", 0); int64_t record_num = in_blob->shape_view().At(0); if (mirror_blob) { CHECK_EQ(record_num, mirror_blob->shape_view().elem_cnt()); mirror.insert(mirror.end(), mirror_blob->dptr(), mirror_blob->dptr() + record_num); } else { mirror.resize(record_num, 0); } return mirror; } class CMNAttr final : public user_op::OpKernelState { public: CMNAttr(user_op::KernelInitContext* ctx) { mean_vec_ = ctx->Attr>("mean"); const std::vector& std_vec = ctx->Attr>("std"); const std::string& color_space = ctx->Attr("color_space"); int64_t C = ImageUtil::IsColor(color_space) ? 
3 : 1; CHECK(mean_vec_.size() == 1 || mean_vec_.size() == C); CHECK(std_vec.size() == 1 || std_vec.size() == C); for (float elem : std_vec) { inv_std_vec_.emplace_back(1.0f / elem); } if (mean_vec_.size() == 1) { mean_vec_.resize(C, mean_vec_.at(0)); } if (inv_std_vec_.size() == 1) { inv_std_vec_.resize(C, inv_std_vec_.at(0)); } } ~CMNAttr() = default; const std::vector& mean_vec() const { return mean_vec_; } const std::vector& inv_std_vec() const { return inv_std_vec_; } private: std::vector mean_vec_; std::vector inv_std_vec_; }; } // namespace class CropMirrorNormalizeFromStaticShapeToFloatKernel final : public user_op::OpKernel { public: CropMirrorNormalizeFromStaticShapeToFloatKernel() = default; ~CropMirrorNormalizeFromStaticShapeToFloatKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* cmn_attr = dynamic_cast(state); const std::vector& mean_vec = cmn_attr->mean_vec(); const std::vector& inv_std_vec = cmn_attr->inv_std_vec(); user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); std::vector mirror = GetMirrorVec(ctx); int64_t record_num = in_blob->shape_view().At(0); const std::string& color_space = ctx->Attr("color_space"); int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1; float crop_pos_y = ctx->Attr("crop_pos_y"); float crop_pos_x = ctx->Attr("crop_pos_x"); const std::string& output_layout = ctx->Attr("output_layout"); float* out_dptr = out_blob->mut_dptr(); const uint8_t* in_dptr = in_blob->dptr(); const ShapeView& in_shape = in_blob->shape_view(); int64_t N = in_shape.At(0); int64_t in_H = in_shape.At(1); int64_t in_W = in_shape.At(2); CHECK_EQ(C, in_shape.At(3)); int64_t in_image_elem_cnt = in_H * in_W * C; const ShapeView& out_shape = out_blob->shape_view(); CHECK_EQ(out_shape.NumAxes(), 4); CHECK_EQ(out_shape.At(0), N); if (output_layout == "NCHW") { CHECK_EQ(out_shape.At(1), C); int64_t out_H = out_shape.At(2); int64_t out_W = out_shape.At(3); int64_t out_image_elem_cnt = C * out_H * out_W; MultiThreadLoop(record_num, [&](size_t i) { if (mirror.at(i)) { CMN1Sample( C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_dptr + in_image_elem_cnt * i, out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec); } else { CMN1Sample( C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_dptr + in_image_elem_cnt * i, out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec); } }); } else if (output_layout == "NHWC") { CHECK_EQ(out_shape.At(3), C); int64_t out_H = out_shape.At(1); int64_t out_W = out_shape.At(2); int64_t out_image_elem_cnt = C * out_H * out_W; MultiThreadLoop(record_num, [&](size_t i) { if (mirror.at(i)) { CMN1Sample( C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_dptr + in_image_elem_cnt * i, out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec); } else { CMN1Sample( C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_dptr + in_image_elem_cnt * i, out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec); } }); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("crop_mirror_normalize_from_uint8") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kUInt8) && 
(user_op::HobDataType("out", 0) == DataType::kFloat)); class CropMirrorNormalizeFromTensorBufferToFloatKernel final : public user_op::OpKernel { public: CropMirrorNormalizeFromTensorBufferToFloatKernel() = default; ~CropMirrorNormalizeFromTensorBufferToFloatKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* cmn_attr = dynamic_cast(state); const std::vector& mean_vec = cmn_attr->mean_vec(); const std::vector& inv_std_vec = cmn_attr->inv_std_vec(); user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); std::vector mirror = GetMirrorVec(ctx); int64_t record_num = in_blob->shape_view().At(0); const std::string& color_space = ctx->Attr("color_space"); int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1; float crop_pos_y = ctx->Attr("crop_pos_y"); float crop_pos_x = ctx->Attr("crop_pos_x"); const std::string& output_layout = ctx->Attr("output_layout"); float* out_dptr = out_blob->mut_dptr(); const TensorBuffer* in_buffers = in_blob->dptr(); const ShapeView& in_shape = in_blob->shape_view(); int64_t N = in_shape.At(0); CHECK_EQ(in_shape.NumAxes(), 1); const ShapeView& out_shape = out_blob->shape_view(); CHECK_EQ(out_shape.NumAxes(), 4); CHECK_EQ(out_shape.At(0), N); if (output_layout == "NCHW") { CHECK_EQ(out_shape.At(1), C); int64_t out_H = out_shape.At(2); int64_t out_W = out_shape.At(3); int64_t out_image_elem_cnt = C * out_H * out_W; MultiThreadLoop(record_num, [&](size_t i) { const TensorBuffer* in_buffer = in_buffers + i; const Shape& in_shape = in_buffer->shape(); CHECK_EQ(in_shape.NumAxes(), 3); // H, W, C int64_t in_H = in_shape.At(0); int64_t in_W = in_shape.At(1); CHECK_EQ(C, in_shape.At(2)); if (mirror.at(i)) { CMN1Sample( C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_buffer->data(), out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec); } else { CMN1Sample( C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_buffer->data(), out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec); } }); } else if (output_layout == "NHWC") { CHECK_EQ(out_shape.At(3), C); int64_t out_H = out_shape.At(1); int64_t out_W = out_shape.At(2); int64_t out_image_elem_cnt = C * out_H * out_W; MultiThreadLoop(record_num, [&](size_t i) { const TensorBuffer* in_buffer = in_buffers + i; const Shape& in_shape = in_buffer->shape(); CHECK_EQ(in_shape.NumAxes(), 3); // H, W, C int64_t in_H = in_shape.At(0); int64_t in_W = in_shape.At(1); CHECK_EQ(C, in_shape.At(2)); if (mirror.at(i)) { CMN1Sample( C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_buffer->data(), out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec); } else { CMN1Sample( C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_buffer->data(), out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec); } }); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("crop_mirror_normalize_from_tensorbuffer") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("out", 0) == DataType::kFloat)); namespace { class RandBoolGen final : public user_op::OpKernelState { public: explicit RandBoolGen(float prob, int64_t seed) : dis_(prob), rng_(seed) 
{} ~RandBoolGen() = default; bool GetNextBool() { return dis_(rng_); } private: std::bernoulli_distribution dis_; std::mt19937 rng_; }; } // namespace class CoinFlipKernel final : public user_op::OpKernel { public: CoinFlipKernel() = default; ~CoinFlipKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { float prob = ctx->Attr("probability"); int64_t seed = CHECK_JUST(GetOpKernelRandomSeed(ctx)); std::shared_ptr rand_bool_gen(new RandBoolGen(prob, seed)); return rand_bool_gen; } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* rand_bool_gen = dynamic_cast(state); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); int8_t* dptr = out_blob->mut_dptr(); for (int32_t i = 0; i < out_blob->shape_view().elem_cnt(); ++i) { *(dptr + i) = rand_bool_gen->GetNextBool() ? 1 : 0; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("coin_flip") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("out", 0) == DataType::kInt8)); namespace { void ImageRandomCropImpl(const TensorBuffer* in_buffer, TensorBuffer* out_buffer, RandomCropGenerator* random_crop_gen) { cv::Mat image = GenCvMat4ImageBuffer(*in_buffer); int W = image.cols; int H = image.rows; cv::Mat image_roi; CropWindow crop; random_crop_gen->GenerateCropWindow({H, W}, &crop); const int y = crop.anchor.At(0); const int x = crop.anchor.At(1); const int new_h = crop.shape.At(0); const int new_w = crop.shape.At(1); CHECK(new_w > 0 && new_w <= W); CHECK(new_h > 0 && new_h <= H); cv::Rect roi(x, y, new_w, new_h); image(roi).copyTo(image_roi); image = image_roi; W = image.cols; H = image.rows; CHECK(image.isContinuous()); const int c = in_buffer->shape_view().At(2); CHECK_EQ(c, image.channels()); Shape image_shape({H, W, c}); out_buffer->Resize(image_shape, in_buffer->data_type()); memcpy(out_buffer->mut_data<>(), image.ptr(), out_buffer->nbytes()); } } // namespace class ImageRandomCropKernel final : public user_op::OpKernel { public: ImageRandomCropKernel() = default; ~ImageRandomCropKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return CreateRandomCropKernelState(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* crop_window_generators = dynamic_cast(state); CHECK_NOTNULL(crop_window_generators); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t record_num = out_blob->shape_view().elem_cnt(); CHECK(record_num > 0); user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK_EQ(out_blob->shape_view(), in_blob->shape_view()); const TensorBuffer* in_buffers = in_blob->dptr(); TensorBuffer* out_buffers = out_blob->mut_dptr(); MultiThreadLoop(record_num, [&](size_t i) { ImageRandomCropImpl(in_buffers + i, out_buffers + i, crop_window_generators->GetGenerator(i)); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("image_random_crop") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)); } // namespace oneflow ================================================ FILE: 
oneflow/user/kernels/image_preprocess_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/memory_format.pb.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/small_vector.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { struct NormalizeVal { float val[3]; }; class NormalizeAttr final : public user_op::OpKernelState { public: NormalizeAttr(user_op::KernelInitContext* ctx) { const std::vector& mean_vec = ctx->Attr>("mean"); if (mean_vec.size() == 1) { for (int i = 0; i < 3; ++i) { mean_.val[i] = mean_vec.at(0); } } else if (mean_vec.size() == 3) { for (int i = 0; i < 3; ++i) { mean_.val[i] = mean_vec.at(i); } } else { UNIMPLEMENTED(); } const std::vector& std_vec = ctx->Attr>("std"); if (std_vec.size() == 1) { for (int i = 0; i < 3; ++i) { inv_std_.val[i] = 1.0f / std_vec.at(0); } } else if (std_vec.size() == 3) { for (int i = 0; i < 3; ++i) { inv_std_.val[i] = 1.0f / std_vec.at(i); } } else { UNIMPLEMENTED(); } } ~NormalizeAttr() = default; const NormalizeVal& mean() const { return mean_; } const NormalizeVal& inv_std() const { return inv_std_; } private: NormalizeVal mean_; NormalizeVal inv_std_; }; template __device__ __forceinline__ void OutIdx2InIdx(int32_t* out_idx, int32_t* in_idx, const int8_t* mirror_dptr, int32_t out_W, int32_t H_offset, int32_t W_offset); template<> __device__ __forceinline__ void OutIdx2InIdx( int32_t* out_idx, int32_t* in_idx, const int8_t* mirror_dptr, int32_t out_W, int32_t H_offset, int32_t W_offset) { if (mirror_dptr && mirror_dptr[out_idx[0]]) { out_idx[3] = out_W - 1 - out_idx[3]; } in_idx[0] = out_idx[0]; // N in_idx[1] = out_idx[2] + H_offset; // H in_idx[2] = out_idx[3] + W_offset; // W in_idx[3] = out_idx[1]; // C } template<> __device__ __forceinline__ void OutIdx2InIdx( int32_t* out_idx, int32_t* in_idx, const int8_t* mirror_dptr, int32_t out_W, int32_t H_offset, int32_t W_offset) { if (mirror_dptr && mirror_dptr[out_idx[0]]) { out_idx[2] = out_W - 1 - out_idx[2]; } in_idx[0] = out_idx[0]; // N in_idx[1] = out_idx[1] + H_offset; // H in_idx[2] = out_idx[2] + W_offset; // W in_idx[3] = out_idx[3]; // C } template __global__ void CropMirrorNormalizeGpuImpl(int32_t elem_cnt, const uint8_t* in_dptr, float* out_dptr, const int8_t* mirror_dptr, int32_t out_W, const NdIndexOffsetHelper in_helper, const NdIndexOffsetHelper out_helper, int32_t H_offset, int32_t W_offset, const NormalizeVal mean, const NormalizeVal inv_std) { CUDA_1D_KERNEL_LOOP(out_offset, elem_cnt) { int32_t in_idx[4]; int32_t out_idx[4]; out_helper.OffsetToNdIndex(out_offset, out_idx); OutIdx2InIdx(out_idx, in_idx, mirror_dptr, out_W, H_offset, W_offset); float mean_val; float inv_std_val; const int32_t c = in_idx[3]; // When the compiler can't resolve array indices to constants it will put private arrays into // GPU local memory. 
Using local memory is slower than keeping array elements directly in // registers. if (c == 0) { mean_val = mean.val[0]; inv_std_val = inv_std.val[0]; } else if (c == 1) { mean_val = mean.val[1]; inv_std_val = inv_std.val[1]; } else if (c == 2) { mean_val = mean.val[2]; inv_std_val = inv_std.val[2]; } else { // undefined behavior assert(false); } int32_t in_offset = in_helper.NdIndexToOffset(in_idx); out_dptr[out_offset] = (static_cast(in_dptr[in_offset]) - mean_val) * inv_std_val; } } } // namespace class CropMirrorNormalizeGpuKernel final : public user_op::OpKernel { public: CropMirrorNormalizeGpuKernel() = default; ~CropMirrorNormalizeGpuKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* normalize_attr = dynamic_cast(state); const NormalizeVal& mean = normalize_attr->mean(); const NormalizeVal& inv_std = normalize_attr->inv_std(); user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); const std::string& output_layout = ctx->Attr("output_layout"); float* out_dptr = out_blob->mut_dptr(); const uint8_t* in_dptr = in_blob->dptr(); const ShapeView& in_shape = in_blob->shape_view(); const ShapeView& out_shape = out_blob->shape_view(); CHECK_EQ(in_shape.NumAxes(), 4); CHECK_EQ(out_shape.NumAxes(), 4); int32_t elem_cnt = out_shape.elem_cnt(); CHECK_LE(elem_cnt, GetMaxVal()); float crop_pos_y = ctx->Attr("crop_pos_y"); float crop_pos_x = ctx->Attr("crop_pos_x"); int32_t N = in_shape.At(0); int32_t in_H = in_shape.At(1); int32_t in_W = in_shape.At(2); int32_t C = in_shape.At(3); const NdIndexOffsetHelper in_helper(N, in_H, in_W, C); const int8_t* mirror_dptr = nullptr; user_op::Tensor* mirror_blob = ctx->Tensor4ArgNameAndIndex("mirror", 0); if (mirror_blob) { mirror_dptr = mirror_blob->dptr(); } if (output_layout == "NCHW") { CHECK_EQ(N, out_shape.At(0)); CHECK_EQ(C, out_shape.At(1)); int32_t out_H = out_shape.At(2); int32_t out_W = out_shape.At(3); CHECK_LE(out_H, in_H); CHECK_LE(out_W, in_W); int32_t H_offset = (in_H - out_H) * crop_pos_y; int32_t W_offset = (in_W - out_W) * crop_pos_x; const NdIndexOffsetHelper out_helper(N, C, out_H, out_W); CropMirrorNormalizeGpuImpl <<stream()->As()->cuda_stream()>>>( elem_cnt, in_dptr, out_dptr, mirror_dptr, out_W, in_helper, out_helper, H_offset, W_offset, mean, inv_std); } else if (output_layout == "NHWC") { CHECK_EQ(N, out_shape.At(0)); int32_t out_H = out_shape.At(1); int32_t out_W = out_shape.At(2); CHECK_EQ(C, out_shape.At(3)); CHECK_LE(out_H, in_H); CHECK_LE(out_W, in_W); int32_t H_offset = (in_H - out_H) * crop_pos_y; int32_t W_offset = (in_W - out_W) * crop_pos_x; const NdIndexOffsetHelper out_helper(N, out_H, out_W, C); CropMirrorNormalizeGpuImpl <<stream()->As()->cuda_stream()>>>( elem_cnt, in_dptr, out_dptr, mirror_dptr, out_W, in_helper, out_helper, H_offset, W_offset, mean, inv_std); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("crop_mirror_normalize_from_uint8") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("in", 0) == DataType::kUInt8) && (user_op::HobDataType("out", 0) == DataType::kFloat)); } // namespace oneflow 
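// Worked example of the crop/mirror indexing used by CropMirrorNormalizeGpuImpl above
// (numbers are illustrative only): for an NHWC uint8 input of 224x224, an output of 200x200
// and crop_pos_y = crop_pos_x = 0.5, the crop anchor is H_offset = (224 - 200) * 0.5 = 12 and
// W_offset = 12. With "NCHW" output layout the value stored at (n, c, h, w) is read from input
// pixel (n, h + 12, w + 12, c) and written as (uint8_value - mean[c]) * inv_std[c], where
// inv_std = 1 / std. When mirror[n] != 0 the column index is first reflected to out_W - 1 - w
// before the crop offset is added.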
================================================ FILE: oneflow/user/kernels/image_resize_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/user/image/image_util.h" #include #include namespace oneflow { namespace { template std::pair GetTargetResizedSize4ImageBuffer(const TensorBuffer& image_buffer, const bool resize_longer, const T target_size, const T min_size, const T max_size) { CHECK_GT(target_size, 0); if (min_size > 0) { CHECK_GE(target_size, min_size); } if (max_size > 0) { CHECK_LE(target_size, max_size); } CHECK_EQ(image_buffer.shape_view().NumAxes(), 3); const T origin_height = image_buffer.shape_view().At(0); const T origin_width = image_buffer.shape_view().At(1); // set round to banker's rounding int origin_round_way = std::fegetround(); CHECK_EQ(std::fesetround(FE_TONEAREST), 0); double org_min_size = std::min(origin_height, origin_width); double org_max_size = std::max(origin_height, origin_width); double aspect_ratio = org_min_size / org_max_size; double res_min_size = 0.0; double res_max_size = 0.0; if (resize_longer) { res_max_size = static_cast(target_size); res_min_size = std::nearbyint(res_max_size * aspect_ratio); if (min_size > 0 && res_min_size < min_size) { res_min_size = static_cast(min_size); res_max_size = std::nearbyint(res_min_size / aspect_ratio); } } else { res_min_size = static_cast(target_size); res_max_size = std::nearbyint(res_min_size / aspect_ratio); if (max_size > 0 && res_max_size > max_size) { res_max_size = static_cast(max_size); res_min_size = std::nearbyint(res_max_size * aspect_ratio); } } std::fesetround(origin_round_way); std::pair width_and_height; if (origin_width < origin_height) { width_and_height.first = static_cast(res_min_size); width_and_height.second = static_cast(res_max_size); } else { width_and_height.first = static_cast(res_max_size); width_and_height.second = static_cast(res_min_size); } return width_and_height; } bool CheckMatSizeMatch(const cv::Mat& mat, const bool resize_longer, const int32_t target_size, const int32_t min_size, const int32_t max_size) { bool is_size_match = true; int mat_min_size = std::min(mat.rows, mat.cols); int mat_max_size = std::max(mat.rows, mat.cols); if (resize_longer) { if (min_size > 0) { is_size_match = (mat_max_size >= target_size) && (mat_min_size >= min_size) && (mat_min_size == min_size || mat_max_size == target_size); } else { is_size_match = (mat_max_size == target_size); } } else { if (max_size > 0) { is_size_match = (mat_min_size <= target_size) && (mat_max_size <= max_size) && (mat_min_size == target_size || mat_max_size == max_size); } else { is_size_match = (mat_min_size == target_size); } } return is_size_match; } void ImageTargetResize(const TensorBuffer& image_buffer, TensorBuffer* resized_image_buffer, const bool resize_longer, const int32_t target_size, const int32_t min_size, const int32_t 
max_size, const std::string& interp_type) { const cv::Mat image_mat = GenCvMat4ImageBuffer(image_buffer); int64_t res_w = 0; int64_t res_h = 0; int64_t channels = image_mat.channels(); std::tie(res_w, res_h) = GetTargetResizedSize4ImageBuffer( image_buffer, resize_longer, target_size, min_size, max_size); resized_image_buffer->Resize(Shape({res_h, res_w, channels}), image_buffer.data_type()); cv::Mat res_image_mat = GenCvMat4ImageBuffer(*resized_image_buffer); int interp_flag = GetCvInterpolationFlag(interp_type, image_mat.cols, image_mat.rows, res_w, res_h); cv::resize(image_mat, res_image_mat, cv::Size(res_w, res_h), 0, 0, interp_flag); CHECK_EQ(res_image_mat.ptr(), resized_image_buffer->data()); CHECK(CheckMatSizeMatch(res_image_mat, resize_longer, target_size, min_size, max_size)); } class ImageResizeToFixedSizeKernel final : public user_op::OpKernel { public: ImageResizeToFixedSizeKernel() = default; ~ImageResizeToFixedSizeKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK_NOTNULL(in_tensor); const int64_t batch_size = in_tensor->shape_view().elem_cnt(); CHECK_GT(batch_size, 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(out_tensor->shape_view().NumAxes(), 4); CHECK_EQ(out_tensor->shape_view().At(0), batch_size); int64_t res_h = out_tensor->shape_view().At(1); int64_t res_w = out_tensor->shape_view().At(2); int64_t channels = out_tensor->shape_view().At(3); int64_t elem_cnt_per_img = res_h * res_w * channels; user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex("scale", 0); CHECK_EQ(scale_tensor->shape_view().NumAxes(), 2); CHECK_EQ(scale_tensor->shape_view().At(0), batch_size); CHECK_EQ(scale_tensor->shape_view().At(1), 2); MultiThreadLoop(batch_size, [&](size_t i) { const TensorBuffer& in_buffer = in_tensor->dptr()[i]; CHECK_EQ(in_buffer.shape_view().NumAxes(), 3); const int64_t origin_height = in_buffer.shape_view().At(0); const int64_t origin_width = in_buffer.shape_view().At(1); CHECK_EQ(in_buffer.shape_view().At(2), channels); DataType dtype = ctx->Attr("data_type"); int interp_flag = GetCvInterpolationFlag(ctx->Attr("interpolation_type"), origin_width, origin_height, res_w, res_h); const cv::Mat in_img_mat = GenCvMat4ImageBuffer(in_buffer); cv::Mat out_img_mat = GenCvMat4ImageTensor(out_tensor, i); if (in_buffer.data_type() == dtype) { cv::resize(in_img_mat, out_img_mat, cv::Size(res_w, res_h), 0, 0, interp_flag); } else { cv::Mat res_img_mat; cv::resize(in_img_mat, res_img_mat, cv::Size(res_w, res_h), 0, 0, interp_flag); CvMatConvertToDataType(res_img_mat, &out_img_mat, dtype); } char* cur_out_dptr = out_tensor->mut_dptr() + i * elem_cnt_per_img * GetSizeOfDataType(dtype); CHECK(out_img_mat.isContinuous()); CHECK_EQ(out_img_mat.ptr(), static_cast(cur_out_dptr)); CHECK_EQ(out_img_mat.cols, res_w); CHECK_EQ(out_img_mat.rows, res_h); CHECK_EQ(out_img_mat.channels(), channels); if (scale_tensor) { float* scale_dptr = scale_tensor->mut_dptr() + i * 2; scale_dptr[0] = static_cast(res_w) / static_cast(origin_width); scale_dptr[1] = static_cast(res_h) / static_cast(origin_height); } }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class ImageResizeKeepAspectRatioKernel final : public user_op::OpKernel { public: ImageResizeKeepAspectRatioKernel() = default; ~ImageResizeKeepAspectRatioKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const 
user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* size_tensor = ctx->Tensor4ArgNameAndIndex("size", 0); user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex("scale", 0); CHECK_NOTNULL(out_tensor); CHECK_NOTNULL(size_tensor); CHECK_NOTNULL(scale_tensor); const TensorBuffer* in_img_buf = in_tensor->dptr(); TensorBuffer* out_img_buf = out_tensor->mut_dptr(); TensorBuffer* scale_buf = scale_tensor->mut_dptr(); TensorBuffer* size_buf = size_tensor->mut_dptr(); const int64_t num_images = in_tensor->shape_view().elem_cnt(); const bool resize_longer = ctx->Attr("resize_longer"); const int32_t target_size = ctx->Attr("target_size"); const int32_t min_size = ctx->Attr("min_size"); const int32_t max_size = ctx->Attr("max_size"); const std::string& interp_type = ctx->Attr("interpolation_type"); MultiThreadLoop(num_images, [&](size_t i) { ImageTargetResize(in_img_buf[i], out_img_buf + i, resize_longer, target_size, min_size, max_size, interp_type); const int64_t org_h = in_img_buf[i].shape_view().At(0); const int64_t org_w = in_img_buf[i].shape_view().At(1); const int64_t res_h = out_img_buf[i].shape_view().At(0); const int64_t res_w = out_img_buf[i].shape_view().At(1); scale_buf[i].Resize(Shape({2}), DataType::kFloat); scale_buf[i].mut_data()[0] = static_cast(res_w) / static_cast(org_w); scale_buf[i].mut_data()[1] = static_cast(res_h) / static_cast(org_h); size_buf[i].Resize(Shape({2}), DataType::kInt32); size_buf[i].mut_data()[0] = static_cast(res_w); size_buf[i].mut_data()[1] = static_cast(res_h); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_IMAGE_RESIZE_KERNEL(dtype) \ REGISTER_USER_KERNEL("image_resize_to_fixed") \ .SetCreateFn() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) \ && (user_op::HobAttr("data_type") == GetDataType::value)); REGISTER_IMAGE_RESIZE_KERNEL(float) REGISTER_IMAGE_RESIZE_KERNEL(uint8_t) REGISTER_USER_KERNEL("image_resize_keep_aspect_ratio") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("size", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("scale", 0) == DataType::kTensorBuffer)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/image_target_resize_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/user/image/image_util.h" #include #include namespace oneflow { namespace { template std::pair GetTargetResizedSize4ImageBuffer(const TensorBuffer& image_buffer, const T target_size, const T max_size) { CHECK_EQ(image_buffer.shape_view().NumAxes(), 3); const T origin_height = image_buffer.shape_view().At(0); const T origin_width = image_buffer.shape_view().At(1); // set round to banker's rounding int origin_round_way = std::fegetround(); CHECK_EQ(std::fesetround(FE_TONEAREST), 0); double origin_min_size = std::min(origin_height, origin_width); double origin_max_size = std::max(origin_height, origin_width); double resized_min_size = static_cast(target_size); double resized_max_size = std::nearbyint((origin_max_size / origin_min_size) * resized_min_size); if (resized_max_size > max_size) { resized_max_size = static_cast(max_size); resized_min_size = std::nearbyint(resized_max_size * origin_min_size / origin_max_size); } std::pair height_and_width; if (origin_width < origin_height) { height_and_width.second = resized_min_size; height_and_width.first = resized_max_size; } else { height_and_width.first = resized_min_size; height_and_width.second = resized_max_size; } std::fesetround(origin_round_way); return height_and_width; } void ImageTargetResize(const TensorBuffer& image_buffer, TensorBuffer* resized_image_buffer, const int32_t target_size, const int32_t max_size) { CHECK_EQ(image_buffer.shape_view().NumAxes(), 3); CHECK_GT(target_size, 0); CHECK_GE(max_size, target_size); cv::Mat image_mat = GenCvMat4ImageBuffer(image_buffer); int64_t res_h = 0; int64_t res_w = 0; int64_t channels = image_mat.channels(); std::tie(res_h, res_w) = GetTargetResizedSize4ImageBuffer(image_buffer, target_size, max_size); resized_image_buffer->Resize(Shape({res_h, res_w, channels}), image_buffer.data_type()); cv::Mat res_image_mat = GenCvMat4ImageBuffer(*resized_image_buffer); cv::resize(image_mat, res_image_mat, cv::Size(res_w, res_h), 0, 0, cv::INTER_LINEAR); CHECK_EQ(res_image_mat.ptr(), resized_image_buffer->data()); CHECK_LE(std::max(res_image_mat.rows, res_image_mat.cols), max_size); CHECK(std::max(res_image_mat.rows, res_image_mat.cols) == max_size || std::min(res_image_mat.rows, res_image_mat.cols) == target_size); } } // namespace class ImageTargetResizeKernel final : public user_op::OpKernel { public: ImageTargetResizeKernel() = default; ~ImageTargetResizeKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* size_tensor = ctx->Tensor4ArgNameAndIndex("size", 0); user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex("scale", 0); CHECK_GT(in_tensor->shape_view().elem_cnt(), 0); CHECK_EQ(in_tensor->shape_view().elem_cnt(), out_tensor->shape_view().elem_cnt()); CHECK_EQ(in_tensor->shape_view().elem_cnt(), size_tensor->shape_view().At(0)); CHECK_EQ(in_tensor->shape_view().elem_cnt(), scale_tensor->shape_view().At(0)); const TensorBuffer* in_img_buf = in_tensor->dptr(); TensorBuffer* out_img_buf = out_tensor->mut_dptr(); int32_t* size_ptr = size_tensor ? size_tensor->mut_dptr() : nullptr; float* scale_ptr = scale_tensor ? 
scale_tensor->mut_dptr() : nullptr; const int32_t target_size = ctx->Attr("target_size"); const int32_t max_size = ctx->Attr("max_size"); MultiThreadLoop(in_tensor->shape_view().elem_cnt(), [&](size_t i) { ImageTargetResize(in_img_buf[i], out_img_buf + i, target_size, max_size); if (size_ptr != nullptr) { size_ptr[i * 2 + 0] = out_img_buf[i].shape_view().At(0); size_ptr[i * 2 + 1] = out_img_buf[i].shape_view().At(1); } if (scale_ptr != nullptr) { scale_ptr[i * 2 + 0] = static_cast(out_img_buf[i].shape_view().At(0)) / static_cast(in_img_buf[i].shape_view().At(0)); scale_ptr[i * 2 + 1] = static_cast(out_img_buf[i].shape_view().At(1)) / static_cast(in_img_buf[i].shape_view().At(1)); } }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("image_target_resize") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("size", 0) == DataType::kInt32) && (user_op::HobDataType("scale", 0) == DataType::kFloat)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/in_top_k_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/in_top_k_kernel_util.h" namespace oneflow { template class InTopkKernel final : public user_op::OpKernel { public: InTopkKernel() = default; ~InTopkKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* targets = ctx->Tensor4ArgNameAndIndex("targets", 0); const user_op::Tensor* predictions = ctx->Tensor4ArgNameAndIndex("predictions", 0); const int32_t k = ctx->Attr("k"); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(targets->shape_view().At(0), predictions->shape_view().At(0)); CHECK_EQ(targets->shape_view().NumAxes(), 1); CHECK_EQ(predictions->shape_view().NumAxes(), 2); const int32_t instance_num = predictions->shape_view().At(0); const int32_t classes_num = predictions->shape_view().At(1); InTopkKernelUtil::InTopk(ctx->stream(), instance_num, classes_num, targets->dptr(), predictions->dptr(), k, out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_IN_TOP_K_KERNEL(device, target_dtype_pair) \ REGISTER_USER_KERNEL("in_top_k") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("targets", 0) == OF_PP_PAIR_SECOND(target_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_IN_TOP_K_KERNEL, DEVICE_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #undef REGISTER_IN_TOP_K_KERNEL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/in_top_k_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/in_top_k_kernel_util.h" #include "oneflow/core/common/data_type_seq.h" namespace oneflow { template struct InTopkKernelUtil { static void InTopk(ep::Stream* stream, const int instance_num, const int classes_num, const T* targets, const float* predictions, const int k, bool* out) { FOR_RANGE(int32_t, idx, 0, instance_num) { T target = targets[idx]; bool cannot_say = (target >= classes_num) || !std::isfinite(predictions[idx * classes_num + target]); int32_t more_probable_classes = 0; if (!cannot_say) { const float target_prediction = predictions[idx * classes_num + target]; FOR_RANGE(int32_t, class_idx, 0, classes_num) { float pred = predictions[idx * classes_num + class_idx]; if (!std::isfinite(pred)) { cannot_say = true; break; } else if (pred > target_prediction) { ++more_probable_classes; if (more_probable_classes > k) break; } } } out[idx] = cannot_say ? false : (more_probable_classes < k); } } }; #define INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CPU(cpp_data_type, data_type) \ template struct InTopkKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CPU, INDEX_DATA_TYPE_SEQ) #undef INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/in_top_k_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/in_top_k_kernel_util.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace { template __global__ void InTopkGpu(const int instance_num, const int classes_num, const T* targets, const float* predictions, const int k, bool* out) { CUDA_1D_KERNEL_LOOP(idx, instance_num) { T target = targets[idx]; bool cannot_say = (target >= classes_num) || !isfinite(predictions[idx * classes_num + target]); int32_t more_probable_classes = 0; if (!cannot_say) { const float target_prediction = predictions[idx * classes_num + target]; FOR_RANGE(int32_t, class_idx, 0, classes_num) { float pred = predictions[idx * classes_num + class_idx]; if (!isfinite(pred)) { cannot_say = true; break; } else if (pred > target_prediction) { ++more_probable_classes; if (more_probable_classes > k) break; } } } out[idx] = cannot_say ? 
false : (more_probable_classes < k); } } } // namespace template struct InTopkKernelUtil { static void InTopk(ep::Stream* stream, const int instance_num, const int classes_num, const T* targets, const float* predictions, const int k, bool* out) { RUN_CUDA_KERNEL((InTopkGpu), stream, instance_num, instance_num, classes_num, targets, predictions, k, out); } }; #define INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CUDA(cpp_data_type, data_type) \ template struct InTopkKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CUDA, INDEX_DATA_TYPE_SEQ) #undef INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/in_top_k_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_IN_TOP_K_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_IN_TOP_K_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" namespace oneflow { template struct InTopkKernelUtil { static void InTopk(ep::Stream* stream, const int instance_num, const int classes_num, const T* targets, const float* predictions, const int k, bool* out); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_IN_TOP_K_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/index_add_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template void index_add_cpu_kernel(const int64_t n, const T* input, const IndexT* index, const T* source, T* output, const int64_t stride, const int64_t source_dim, const int64_t delta, const float alpha) { const int64_t stride_source_dim = stride * source_dim; for (int i = 0; i < n; i++) { int64_t pre_index = i / stride_source_dim; int64_t dim_index = (i - pre_index * stride_source_dim) / stride; IndexT source_dim_idx = index[dim_index]; int64_t output_index = i + (delta * pre_index + source_dim_idx - dim_index) * stride; output[output_index] += static_cast(alpha) * source[i]; } } }; // namespace template class IndexAddCpuKernel final : public user_op::OpKernel { public: IndexAddCpuKernel() = default; ~IndexAddCpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); const user_op::Tensor* index = ctx->Tensor4ArgNameAndIndex("index", 0); const user_op::Tensor* source = ctx->Tensor4ArgNameAndIndex("source", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); const int64_t dim = ctx->Attr("dim"); const float alpha = ctx->Attr("alpha"); const ShapeView& input_shape = input->shape_view(); const ShapeView& source_shape = source->shape_view(); std::vector input_stride(input->stride().begin(), input->stride().end()); const int64_t stride = input_stride[dim]; const int64_t source_dim = source_shape.At(dim); const int64_t delta = input_shape.At(dim) - source_dim; DataType index_dtype = index->data_type(); const int32_t n = source->shape_view().elem_cnt(); Memcpy( ctx->stream(), output->mut_dptr(), input->dptr(), input->shape_view().elem_cnt() * GetSizeOfDataType(input->data_type())); if (GetSizeOfDataType(index_dtype) == 4) { index_add_cpu_kernel(n, input->dptr(), index->dptr(), source->dptr(), output->mut_dptr(), stride, source_dim, delta, alpha); } else { index_add_cpu_kernel(n, input->dptr(), index->dptr(), source->dptr(), output->mut_dptr(), stride, source_dim, delta, alpha); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_INDEX_ADD_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("index_add") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("output", 0) == GetDataType::value)); REGISTER_INDEX_ADD_CPU_KERNEL(int8_t) REGISTER_INDEX_ADD_CPU_KERNEL(int32_t) REGISTER_INDEX_ADD_CPU_KERNEL(float) REGISTER_INDEX_ADD_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/index_add_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/cuda/atomic.cuh"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"

namespace oneflow {

namespace {

template<typename T, typename IndexT>
__global__ void index_add_cuda_kernel(const int64_t n, const T* input, const IndexT* index,
                                      const T* source, T* output, const int64_t stride,
                                      const int64_t source_dim, const int64_t delta,
                                      const float alpha) {
  // For x = flow.ones(5, 3)
  // source = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=flow.float)
  // index = flow.tensor([0, 4, 2])
  // dim = 0
  // We have:
  // stride = 3
  // source_dim = 3
  // stride * source_dim = 9
  // alpha = 1.0
  // delta = 5 - 3 = 2
  // For i = 8
  // pre_index = i / stride_source_dim = 8 / 9 = 0
  // dim_index = i % stride_source_dim / stride = 8 % 9 / 3 = 2
  // source_dim_idx = index[dim_index] = index[2] = 2
  // output_index = i + (delta * pre_index + source_dim_idx - dim_index) * stride
  //              = 8 + (2 * 0 + 2 - 2) * 3 = 8
  // cuda::atomic::Add(output + output_index, static_cast<T>(alpha) * source[i]) =>
  // output[8] += 1.0 * 9 = 10.0
  const int64_t stride_source_dim = stride * source_dim;
  CUDA_1D_KERNEL_LOOP(i, n) {
    int64_t pre_index = i / stride_source_dim;
    int64_t dim_index = (i - pre_index * stride_source_dim) / stride;
    IndexT source_dim_idx = index[dim_index];
    int64_t output_index = i + (delta * pre_index + source_dim_idx - dim_index) * stride;
    cuda::atomic::Add(output + output_index, static_cast<T>(alpha) * source[i]);
  }
}

};  // namespace

template<typename T>
class IndexAddGpuKernel final : public user_op::OpKernel {
 public:
  IndexAddGpuKernel() = default;
  ~IndexAddGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0);
    const user_op::Tensor* index = ctx->Tensor4ArgNameAndIndex("index", 0);
    const user_op::Tensor* source = ctx->Tensor4ArgNameAndIndex("source", 0);
    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0);
    const int64_t dim = ctx->Attr<int64_t>("dim");
    const float alpha = ctx->Attr<float>("alpha");
    const ShapeView& input_shape = input->shape_view();
    const ShapeView& source_shape = source->shape_view();
    std::vector<int64_t> input_stride(input->stride().begin(), input->stride().end());
    const int64_t stride = input_stride[dim];
    const int64_t source_dim = source_shape.At(dim);
    const int64_t delta = input_shape.At(dim) - source_dim;
    DataType index_dtype = index->data_type();
    const int32_t n = source->shape_view().elem_cnt();
    Memcpy<DeviceType::kCUDA>(
        ctx->stream(), output->mut_dptr(), input->dptr(),
        input->shape_view().elem_cnt() * GetSizeOfDataType(input->data_type()));
    if (GetSizeOfDataType(index_dtype) == 4) {
      RUN_CUDA_KERNEL((index_add_cuda_kernel<T, int32_t>), ctx->stream(), n, n, input->dptr<T>(),
                      index->dptr<int32_t>(), source->dptr<T>(), output->mut_dptr<T>(), stride,
                      source_dim, delta, alpha);
    } else {
      RUN_CUDA_KERNEL((index_add_cuda_kernel<T, int64_t>), ctx->stream(), n, n, input->dptr<T>(),
                      index->dptr<int64_t>(), source->dptr<T>(), output->mut_dptr<T>(), stride,
                      source_dim, delta, alpha);
    }
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_INDEX_ADD_CUDA_KERNEL(dtype)                                  \
  REGISTER_USER_KERNEL("index_add")                                            \
      .SetCreateFn<IndexAddGpuKernel<dtype>>()                                 \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)         \
                       && (user_op::HobDataType("output", 0) == GetDataType<dtype>::value));

REGISTER_INDEX_ADD_CUDA_KERNEL(float)
REGISTER_INDEX_ADD_CUDA_KERNEL(half)
REGISTER_INDEX_ADD_CUDA_KERNEL(double)

}
// namespace oneflow ================================================ FILE: oneflow/user/kernels/indexed_slices_reduce_sum_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.h" namespace oneflow { namespace { template class IndexedSlicesReduceSumKernel final : public user_op::OpKernel { public: IndexedSlicesReduceSumKernel() = default; ~IndexedSlicesReduceSumKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_indices = ctx->Tensor4ArgNameAndIndex("x_indices", 0); const user_op::Tensor* x_values = ctx->Tensor4ArgNameAndIndex("x_values", 0); user_op::Tensor* y_indices = ctx->Tensor4ArgNameAndIndex("y_indices", 0); user_op::Tensor* y_values = ctx->Tensor4ArgNameAndIndex("y_values", 0); user_op::Tensor* num_unique = ctx->Tensor4ArgNameAndIndex("num_unique", 0); user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); void* tmp_ptr = tmp ? tmp->mut_dptr() : nullptr; int64_t tmp_size = tmp ? tmp->shape_view().elem_cnt() * GetSizeOfDataType(tmp->data_type()) : 0; const int64_t n = x_indices->shape_view().elem_cnt(); const int64_t m = x_values->shape_view().elem_cnt() / n; IndexedSlicesReduceSumKernelUtil::ReduceSum( ctx->stream(), n, m, x_indices->dptr(), x_values->dptr(), num_unique->mut_dptr(), y_indices->mut_dptr(), y_values->mut_dptr(), tmp_ptr, tmp_size); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const auto& x_indices = ctx->InputTensorDesc("x_indices", 0); const auto& x_values = ctx->InputTensorDesc("x_values", 0); const int64_t n = x_indices.shape().elem_cnt(); const int64_t m = x_values.shape().elem_cnt() / n; int64_t workspace_size_in_bytes; IndexedSlicesReduceSumKernelUtil::GetReduceSumWorkspaceSizeInBytes( nullptr, n, m, &workspace_size_in_bytes); return workspace_size_in_bytes; }; } #define REGISTER_INDEXED_SLICES_REDUCE_SUM_KERNEL(device_type_v, data_type_pair, \ indices_type_pair) \ REGISTER_USER_KERNEL("indexed_slices_reduce_sum") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("x_values", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("x_indices", 0) == OF_PP_PAIR_SECOND(indices_type_pair))) \ .SetInferTmpSizeFn(GenInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_INDEXED_SLICES_REDUCE_SUM_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.h" #include "oneflow/user/kernels/unique_kernel_util.h" #include "oneflow/user/kernels/unsorted_segment_sum_kernel_util.h" namespace oneflow { template int64_t GetUniqueIdxSize(int64_t n) { return GetCudaAlignedSize(n * sizeof(IDX)); } template void IndexedSlicesReduceSumKernelUtil::ReduceSum( ep::Stream* stream, int64_t n, int64_t m, const K* indices, const T* values, IDX* num_unique_indices, K* indices_out, T* values_out, void* workspace, int64_t workspace_size_in_bytes) { const int64_t unique_idx_size = GetUniqueIdxSize(n); CHECK_LE(unique_idx_size, workspace_size_in_bytes); IDX* unique_idx_ptr = reinterpret_cast(workspace); void* unique_workspace_ptr = reinterpret_cast(workspace) + unique_idx_size; const int64_t unique_workspace_size = workspace_size_in_bytes - unique_idx_size; UniqueKernelUtil::Unique(stream, n, indices, num_unique_indices, indices_out, unique_idx_ptr, unique_workspace_ptr, unique_workspace_size, /*sorted*/ false); const Shape flat_in_shape({1, n, m}); Memset(stream, values_out, 0, n * m * sizeof(T)); UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( stream, unique_idx_ptr, values, n, n, 1, m, 0, values_out); } template void IndexedSlicesReduceSumKernelUtil::GetReduceSumWorkspaceSizeInBytes( ep::Stream* stream, int64_t n, int64_t m, int64_t* workspace_size_in_bytes) { int64_t unique_workspace_size; UniqueKernelUtil::GetUniqueWorkspaceSizeInBytes(stream, n, &unique_workspace_size); *workspace_size_in_bytes = GetUniqueIdxSize(n) + unique_workspace_size; } #define INSTANTIATE_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL(device_type, key_type_pair, \ val_type_pair, idx_type_pair) \ template struct IndexedSlicesReduceSumKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL, DEVICE_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL } // namespace oneflow ================================================ FILE: oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { template struct IndexedSlicesReduceSumKernelUtil { static void ReduceSum(ep::Stream* stream, int64_t n, int64_t m, const K* indices, const T* values, IDX* num_unique_indices, K* indices_out, T* values_out, void* workspace, int64_t workspace_size_in_bytes); static void GetReduceSumWorkspaceSizeInBytes(ep::Stream* stream, int64_t n, int64_t m, int64_t* workspace_size_in_bytes); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/inv_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/eigen_util.h" namespace oneflow { namespace { static inline size_t BatchCount(const user_op::Tensor* batched_matrices) { size_t result = 1; for (size_t i = 0; i < batched_matrices->shape_view().NumAxes() - 2; i++) { result *= batched_matrices->shape_view().At(i); } return result; } static inline size_t MatrixStride(const user_op::Tensor* batched_matrices) { const int64_t num_axes = batched_matrices->shape_view().NumAxes(); return batched_matrices->shape_view().At(num_axes - 2) * batched_matrices->shape_view().At(num_axes - 1); } } // namespace template class CpuInvKernel final : public user_op::OpKernel { public: CpuInvKernel() = default; ~CpuInvKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); auto batch_count = BatchCount(x); auto matrix_stride = MatrixStride(x); auto matrix_size = x->shape_view().At(x->shape_view().NumAxes() - 2); const T* x_ptr = x->dptr(); T* y_ptr = y->mut_dptr(); FOR_RANGE(int64_t, i, 0, batch_count) { ConstEigenMatrixMap x_mat(x_ptr + i * matrix_stride, matrix_size, matrix_size); EigenMatrixMap y_mat(y_ptr + i * matrix_stride, matrix_size, matrix_size); if (x_mat.determinant() == 0) { LOG(FATAL) << "(Batch element " << i << "): the inversion could not be completed because the input matrix is singular."; } y_mat = x_mat.inverse(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_INV_KERNEL(dtype) \ REGISTER_USER_KERNEL("inv").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_CPU_INV_KERNEL(float) REGISTER_CPU_INV_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/inv_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/arange_kernel_util.h" namespace oneflow { namespace { static inline size_t BatchCount(const user_op::Tensor* batched_matrices) { size_t result = 1; for (size_t i = 0; i < batched_matrices->shape_view().NumAxes() - 2; i++) { result *= batched_matrices->shape_view().At(i); } return result; } static inline size_t MatrixStride(const user_op::Tensor* batched_matrices) { const int64_t num_axes = batched_matrices->shape_view().NumAxes(); return batched_matrices->shape_view().At(num_axes - 2) * batched_matrices->shape_view().At(num_axes - 1); } void OFgetrfBatched(ep::Stream* stream, int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) { OF_CUBLAS_CHECK(cublasSgetrfBatched(stream->As()->cublas_handle(), n, dA_array, ldda, ipiv_array, info_array, batchsize)); } void OFgetrfBatched(ep::Stream* stream, int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) { OF_CUBLAS_CHECK(cublasDgetrfBatched(stream->As()->cublas_handle(), n, dA_array, ldda, ipiv_array, info_array, batchsize)); } void OFgetriBatched(ep::Stream* stream, int n, float** dA_array, int ldda, int* ipiv_array, float** dC_array, int lddc, int* info_array, int batchsize) { OF_CUBLAS_CHECK(cublasSgetriBatched(stream->As()->cublas_handle(), n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize)); } void OFgetriBatched(ep::Stream* stream, int n, double** dA_array, int ldda, int* ipiv_array, double** dC_array, int lddc, int* info_array, int batchsize) { OF_CUBLAS_CHECK(cublasDgetriBatched(stream->As()->cublas_handle(), n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize)); } } // namespace namespace user_op { template class CudaInvKernel final : public user_op::OpKernel { public: CudaInvKernel() = default; ~CudaInvKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); auto batch_count = BatchCount(x); auto matrix_stride = MatrixStride(x); auto matrix_size = x->shape_view().At(x->shape_view().NumAxes() - 2); const ShapeView& x_shape = x->shape_view(); const int64_t instance_num = x_shape.Count(0, x_shape.NumAxes() - 2); const int64_t infos_bytes = GetCudaAlignedSize(instance_num * sizeof(int)); const int64_t ipiv_bytes = GetCudaAlignedSize(batch_count * x_shape.At(x_shape.NumAxes() - 2) * sizeof(int)); const int64_t pptr_bytes = GetCudaAlignedSize(batch_count * sizeof(T*)); int* infos_getrf_ptr = tmp_buffer->mut_dptr(); int* infos_getrs_ptr = reinterpret_cast(reinterpret_cast(infos_getrf_ptr) + infos_bytes); int* ipiv_ptr = reinterpret_cast(reinterpret_cast(infos_getrs_ptr) + infos_bytes); T** x_pptr = reinterpret_cast(reinterpret_cast(ipiv_ptr) + 
ipiv_bytes); T** y_pptr = reinterpret_cast(reinterpret_cast(x_pptr) + pptr_bytes); T* x_copy_ptr = reinterpret_cast(reinterpret_cast(y_pptr) + pptr_bytes); Memcpy(ctx->stream(), x_copy_ptr, x->dptr(), x_shape.elem_cnt() * sizeof(T)); ArangeFunctor()(ctx->stream(), reinterpret_cast(x_copy_ptr), static_cast(matrix_stride * sizeof(T)), batch_count, reinterpret_cast(x_pptr)); ArangeFunctor()(ctx->stream(), reinterpret_cast(y->mut_dptr()), static_cast(matrix_stride * sizeof(T)), batch_count, reinterpret_cast(y_pptr)); Memset(ctx->stream(), infos_getrf_ptr, 0, infos_bytes); Memset(ctx->stream(), infos_getrs_ptr, 0, infos_bytes); Memset(ctx->stream(), ipiv_ptr, 0, ipiv_bytes); OFgetrfBatched(ctx->stream(), matrix_size, x_pptr, matrix_size, ipiv_ptr, infos_getrf_ptr, batch_count); OFgetriBatched(ctx->stream(), matrix_size, x_pptr, matrix_size, ipiv_ptr, y_pptr, matrix_size, infos_getrs_ptr, batch_count); std::vector infos_getrf_vec_host(batch_count, 0); std::vector infos_getrs_vec_host(batch_count, 0); OF_CUDA_CHECK(cudaMemcpyAsync(infos_getrf_vec_host.data(), infos_getrf_ptr, batch_count * sizeof(int), cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); OF_CUDA_CHECK(cudaMemcpyAsync(infos_getrs_vec_host.data(), infos_getrs_ptr, batch_count * sizeof(int), cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); CHECK_JUST(ctx->stream()->Sync()); FOR_RANGE(int64_t, i, 0, batch_count) { if (infos_getrf_vec_host[i] > 0) { LOG(FATAL) << "(Batch element " << i << "): The diagonal element " << infos_getrf_vec_host[i] << " is zero, the inversion could not be completed because the input matrix is " "singular."; } } FOR_RANGE(int64_t, i, 0, batch_count) { if (infos_getrs_vec_host[i] > 0) { LOG(FATAL) << "(Batch element " << i << "): The diagonal element " << infos_getrs_vec_host[i] << " is zero, the inversion could not be completed because the input matrix is " "singular."; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_INV_KERNEL(dtype) \ REGISTER_USER_KERNEL("inv") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& x_shape = ctx->InputShape("x", 0); \ auto batch_size = x_shape.Count(0, x_shape.NumAxes() - 2); \ const int64_t instance_num = x_shape.Count(0, x_shape.NumAxes() - 2); \ const int64_t infos_bytes = GetCudaAlignedSize(instance_num * sizeof(int)); \ const int64_t ipiv_bytes = \ GetCudaAlignedSize(batch_size * x_shape.At(x_shape.NumAxes() - 2) * sizeof(int)); \ const int64_t pptr_bytes = GetCudaAlignedSize(batch_size * sizeof(dtype*)); \ const int64_t x_copy_bytes = GetCudaAlignedSize(x_shape.elem_cnt() * sizeof(dtype)); \ return infos_bytes * 2 + ipiv_bytes + pptr_bytes * 2 + x_copy_bytes; \ }); REGISTER_CUDA_INV_KERNEL(float) REGISTER_CUDA_INV_KERNEL(double) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/kl_div_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/user/kernels/loss_kernel_util.h" namespace oneflow { namespace user_op { namespace { using namespace loss; template void ComputeKLDivOut(int64_t elem_cnt, const T* input, const T* target, T* out, const bool log_target) { if (log_target) { FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = std::exp(target[i]) * (target[i] - input[i]); } } else { FOR_RANGE(int64_t, i, 0, elem_cnt) { const auto out_val = target[i] * (SafeLog(target[i]) - input[i]); out[i] = target[i] > 0 ? out_val : static_cast(0); } } } template void ComputeKLDivGradOut(int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx, const bool log_target) { FOR_RANGE(int64_t, i, 0, elem_cnt) { const T dy_val = dy[i]; dx[i] = log_target ? (-std::exp(target[i]) * dy_val) : (target[i] > 0 ? -target[i] * dy_val : 0); } } template class KLDivKernel : public SimpleLossKernel> { public: void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input, const T* target, T* out) const { const bool log_target = ctx->Attr("log_target"); ComputeKLDivOut(elem_cnt, input, target, out, log_target); } }; template class KLDivGradKernel : public SimpleLossGradKernel> { public: void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx) const { const bool log_target = ctx->Attr("log_target"); ComputeKLDivGradOut(elem_cnt, input, target, dy, dx, log_target); } }; } // namespace REGISTER_SIMPLE_LOSS_KERNEL_CPU("kl_div_loss", KLDivKernel) REGISTER_SIMPLE_LOSS_GRAD_KERNEL_CPU("kl_div_loss_grad", KLDivGradKernel) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/kl_div_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/user/kernels/loss_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { using namespace loss; template struct KLDivFunctor { __device__ __forceinline__ T operator()(T input_val, T target_val) const { if (LOG_TARGET) { return exp(target_val) * (target_val - input_val); } else { const T zero_val = static_cast(0); const T out_val = target_val * (SafeLog(target_val) - input_val); return target_val > zero_val ? 
out_val : zero_val; } } }; template struct KLDivFunctor { __device__ __forceinline__ half operator()(half input_val, half target_val) const { if (LOG_TARGET) { return hexp(target_val) * (target_val - input_val); } else { const half zero_val = __float2half(0.f); const half out_val = target_val * (SafeLog(target_val) - input_val); return target_val > zero_val ? out_val : zero_val; } } }; template struct KLDivGradFunctor { __device__ __forceinline__ T operator()(T target_val, T dy_val) const { if (LOG_TARGET) { return -exp(target_val) * dy_val; } else { const T zero_val = static_cast(0); return target_val > zero_val ? -target_val * dy_val : zero_val; } } }; template struct KLDivGradFunctor { __device__ __forceinline__ half operator()(half target_val, half dy_val) const { if (LOG_TARGET) { return __hneg(hexp(target_val) * dy_val); } else { const half zero_val = __float2half(0.f); return target_val > zero_val ? __hneg(target_val * dy_val) : zero_val; } } }; template class KLDivKernel : public SimpleLossKernel> { public: void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input, const T* target, T* out) const { const bool log_target = ctx->Attr("log_target"); if (log_target) { OF_CUDA_CHECK( (cuda::elementwise::Binary(KLDivFunctor(), elem_cnt, out, input, target, ctx->stream()->As()->cuda_stream()))); } else { OF_CUDA_CHECK( (cuda::elementwise::Binary(KLDivFunctor(), elem_cnt, out, input, target, ctx->stream()->As()->cuda_stream()))); } } }; template class KLDivGradKernel : public SimpleLossGradKernel> { public: void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx) const { const bool log_target = ctx->Attr("log_target"); if (log_target) { OF_CUDA_CHECK((cuda::elementwise::Binary( KLDivGradFunctor(), elem_cnt, dx, target, dy, ctx->stream()->As()->cuda_stream()))); } else { OF_CUDA_CHECK((cuda::elementwise::Binary( KLDivGradFunctor(), elem_cnt, dx, target, dy, ctx->stream()->As()->cuda_stream()))); } } }; } // namespace REGISTER_SIMPLE_LOSS_KERNEL_CUDA("kl_div_loss", KLDivKernel) REGISTER_SIMPLE_LOSS_GRAD_KERNEL_CUDA("kl_div_loss_grad", KLDivGradKernel) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/l1_l2_regularize_gradient_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.h" namespace oneflow { namespace { template class L1L2RegularizeGradientKernel final : public user_op::OpKernel { public: L1L2RegularizeGradientKernel() = default; ~L1L2RegularizeGradientKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); L1L2RegularizeGradientKernelUtil::RegularizeGradient( ctx->stream(), out->shape_view().elem_cnt(), model->dptr(), model_diff->dptr(), out->mut_dptr(), l1, l2); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("l1_l2_regularize_gradient") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "model_diff", 0, true)); \ return Maybe::Ok(); \ }); REGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(DeviceType::kCPU, float) REGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(DeviceType::kCPU, double) #ifdef WITH_CUDA REGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(DeviceType::kCUDA, float) REGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(DeviceType::kCUDA, double) #endif } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.h" namespace oneflow { template struct L1L2RegularizeGradientKernelUtil { static void RegularizeGradient(ep::Stream* stream, int64_t n, const T* model, const T* model_diff, T* out, const T l1, const T l2) { FOR_RANGE(int64_t, i, 0, n) { const T model_val = model[i]; out[i] = model_diff[i] + l1 * (model_val >= 0 ? 1 : -1) + l2 * model_val; } } }; #define INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CPU(type_cpp, type_proto) \ template struct L1L2RegularizeGradientKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ); #undef INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void L1L2RegularizeGradientGpu(int64_t n, const T* model, const T* model_diff, T* out, const T l1, const T l2) { CUDA_1D_KERNEL_LOOP(i, n) { const T model_val = model[i]; out[i] = model_diff[i] + l1 * ((model_val >= 0) - (model_val <= 0)) + l2 * model_val; } } } // namespace template struct L1L2RegularizeGradientKernelUtil { static void RegularizeGradient(ep::Stream* stream, int64_t n, const T* model, const T* model_diff, T* out, const T l1, const T l2) { L1L2RegularizeGradientGpu<<As()->cuda_stream()>>>(n, model, model_diff, out, l1, l2); } }; #define INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CUDA(type_cpp, type_proto) \ template struct L1L2RegularizeGradientKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ); #undef INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct L1L2RegularizeGradientKernelUtil { static void RegularizeGradient(ep::Stream* stream, int64_t n, const T* model, const T* model_diff, T* out, T l1, T l2); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/l2_normalize_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template static void L2NormalizeForward(const int32_t n, const int32_t c, const int32_t d, const T epsilon, const T* in, T* square_x_sum, T* out) { for (int32_t i = 0; i < n; i++) { const int32_t offset = (i / d) * d * c + (i % d); for (int32_t j = 0; j < c; j++) { const T x = in[offset + j * d]; square_x_sum[i] += x * x; } const T norm = std::sqrt(std::max(square_x_sum[i], epsilon)); for (int32_t j = 0; j < c; j++) { const int32_t index = offset + j * d; out[index] = in[index] / norm; } } } template static void L2NormalizeBackward(const int32_t n, const int32_t c, const int32_t d, const T epsilon, const T* out, const T* out_diff, const T* square_x_sum, T* in_diff) { for (int32_t i = 0; i < n; i++) { const T norm = std::sqrt(std::max(square_x_sum[i], epsilon)); const int32_t offset = (i / d) * d * c + (i % d); if (square_x_sum[i] >= epsilon) { T y_dy_inner_prod = GetZeroVal(); for (int32_t j = 0; j < c; j++) { const int32_t index = offset + j * d; y_dy_inner_prod += out_diff[index] * out[index]; } for (int32_t j = 0; j < c; j++) { const int32_t index = offset + j * d; in_diff[index] = (1 / norm) * (out_diff[index] - y_dy_inner_prod * out[index]); } } else { for (int32_t j = 0; j < c; j++) { const int32_t index = offset + j * d; in_diff[index] = (1 / norm) * out_diff[index]; } } } } } // namespace template class CpuL2NormalizeKernel final : public user_op::OpKernel { public: CpuL2NormalizeKernel() = default; ~CpuL2NormalizeKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* square_x_sum = ctx->Tensor4ArgNameAndIndex("square_x_sum", 0); const float epsilon = ctx->Attr("epsilon"); int32_t axis = ctx->Attr("axis"); int32_t c = x->shape_view().At(axis); int32_t n = x->shape_view().elem_cnt() / c; int32_t d = x->shape_view().Count(axis + 1); size_t square_x_sum_byte_size = square_x_sum->shape_view().elem_cnt() * sizeof(T); Memset(ctx->stream(), square_x_sum->mut_dptr(), 0, square_x_sum_byte_size); L2NormalizeForward(n, c, d, static_cast(epsilon), x->dptr(), square_x_sum->mut_dptr(), y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_L2_NORMALIZE_KERNEL(dtype) \ REGISTER_USER_KERNEL("l2_normalize") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_CPU_L2_NORMALIZE_KERNEL(float) REGISTER_CPU_L2_NORMALIZE_KERNEL(double) template class CpuL2NormalizeGradKernel final : public user_op::OpKernel { public: CpuL2NormalizeGradKernel() = default; ~CpuL2NormalizeGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* square_x_sum = ctx->Tensor4ArgNameAndIndex("square_x_sum", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const float epsilon = ctx->Attr("epsilon"); int32_t axis = ctx->Attr("axis"); int32_t c = dy->shape_view().At(axis); int32_t n = dy->shape_view().elem_cnt() / c; int32_t d = dy->shape_view().Count(axis + 1); L2NormalizeBackward(n, c, d, static_cast(epsilon), y->dptr(), dy->dptr(), square_x_sum->dptr(), 
dx->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_L2_NORMALIZE_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("l2_normalize_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_CPU_L2_NORMALIZE_GRAD_KERNEL(float) REGISTER_CPU_L2_NORMALIZE_GRAD_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/l2_normalize_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/cuda/layer_norm.cuh" namespace oneflow { namespace { template __global__ void L2NormalizeForward(const int32_t n, const int32_t c, const int32_t d, const ComputeType epsilon, const T* in, ComputeType* square_x_sum, T* out) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; for (int32_t i = blockIdx.x; i < n; i += gridDim.x) { ComputeType sum = GetZeroVal(); const int32_t offset = (i / d) * d * c + (i % d); for (int32_t j = threadIdx.x; j < c; j += blockDim.x) { const ComputeType x = static_cast(in[offset + j * d]); sum += x * x; } const ComputeType reduce_sum = BlockReduce(temp_storage).Sum(sum); if (threadIdx.x == 0) { square_x_sum[i] = reduce_sum; } __syncthreads(); const ComputeType inv_norm = rsqrtf(fmaxf(square_x_sum[i], epsilon)); for (int32_t j = threadIdx.x; j < c; j += blockDim.x) { const int32_t index = offset + j * d; out[index] = static_cast(inv_norm * static_cast(in[index])); } } } template __global__ void L2NormalizeBackward(const int32_t n, const int32_t c, const int32_t d, const float epsilon, const T* out, const T* out_diff, const T* square_x_sum, T* in_diff) { for (int32_t i = blockIdx.x; i < n; i += gridDim.x) { const T inv_norm = rsqrt(fmaxf(square_x_sum[i], epsilon)); const int32_t offset = (i / d) * d * c + (i % d); if (square_x_sum[i] >= epsilon) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage_prod_sum; T y_dy_prod_sum = GetZeroVal(); for (int32_t j = threadIdx.x; j < c; j += blockDim.x) { const int32_t index = offset + j * d; y_dy_prod_sum += out[index] * out_diff[index]; } const T reduce_y_dy_prod_sum = BlockReduce(temp_storage_prod_sum).Sum(y_dy_prod_sum); __shared__ T y_dy_inner_prod; if (threadIdx.x == 0) { y_dy_inner_prod = reduce_y_dy_prod_sum; } __syncthreads(); for (int32_t j = threadIdx.x; j < c; j += blockDim.x) { const int32_t index = offset + j * d; in_diff[index] = inv_norm * (out_diff[index] - y_dy_inner_prod * out[index]); } } else { for (int32_t j = threadIdx.x; j < c; j += blockDim.x) { const int32_t index = offset + j * d; in_diff[index] = inv_norm * out_diff[index]; } } } } } // namespace template class GpuL2NormalizeKernel final : public user_op::OpKernel { public: 
GpuL2NormalizeKernel() = default; ~GpuL2NormalizeKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* square_x_sum = ctx->Tensor4ArgNameAndIndex("square_x_sum", 0); const float epsilon = ctx->Attr("epsilon"); int32_t axis = ctx->Attr("axis"); int32_t c = x->shape_view().At(axis); int32_t n = x->shape_view().elem_cnt() / c; int32_t d = x->shape_view().Count(axis + 1); using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; RUN_CUDA_KERNEL((L2NormalizeForward), ctx->stream(), n, n, c, d, static_cast(epsilon), x->dptr(), square_x_sum->mut_dptr(), y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_L2_NORMALIZE_KERNEL(dtype) \ REGISTER_USER_KERNEL("l2_normalize") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_CUDA_L2_NORMALIZE_KERNEL(half) REGISTER_CUDA_L2_NORMALIZE_KERNEL(float) REGISTER_CUDA_L2_NORMALIZE_KERNEL(double) template class GpuL2NormalizeGradKernel final : public user_op::OpKernel { public: GpuL2NormalizeGradKernel() = default; ~GpuL2NormalizeGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* square_x_sum = ctx->Tensor4ArgNameAndIndex("square_x_sum", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const float epsilon = ctx->Attr("epsilon"); int32_t axis = ctx->Attr("axis"); int32_t c = dy->shape_view().At(axis); int32_t n = dy->shape_view().elem_cnt() / c; int32_t d = dy->shape_view().Count(axis + 1); RUN_CUDA_KERNEL((L2NormalizeBackward), ctx->stream(), n, n, c, d, static_cast(epsilon), y->dptr(), dy->dptr(), square_x_sum->dptr(), dx->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_L2_NORMALIZE_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("l2_normalize_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_CUDA_L2_NORMALIZE_GRAD_KERNEL(float) REGISTER_CUDA_L2_NORMALIZE_GRAD_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/layer_norm_cpu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" namespace oneflow { template class LayerNormCpuKernel final : public user_op::OpKernel { public: LayerNormCpuKernel() = default; ~LayerNormCpuKernel() = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); }; }; #define REGISTER_LAYER_NORM_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("layer_norm") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_LAYER_NORM_CPU_KERNEL(float) REGISTER_LAYER_NORM_CPU_KERNEL(double) template class LayerNormGradCpuKernel final : public user_op::OpKernel { public: LayerNormGradCpuKernel() = default; ~LayerNormGradCpuKernel() = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); }; }; #define REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("layer_norm_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)); REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(float) REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(double) template class FuseLayerNormGradCpuKernel final : public user_op::OpKernel { public: FuseLayerNormGradCpuKernel() = default; ~FuseLayerNormGradCpuKernel() = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); }; }; #define REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("fuse_layer_norm_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)); REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(float) REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(double) template class LayerNormParamGradCpuKernel final : public user_op::OpKernel { public: LayerNormParamGradCpuKernel() = default; ~LayerNormParamGradCpuKernel() = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); }; }; #define REGISTER_LAYER_NORM_PARAM_GRAD_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("layer_norm_param_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)); REGISTER_LAYER_NORM_PARAM_GRAD_CPU_KERNEL(float) REGISTER_LAYER_NORM_PARAM_GRAD_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/layer_norm_gpu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/cuda/atomic.cuh" #include #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/layer_norm.cuh" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 namespace oneflow { namespace { template struct AffineStore { AffineStore(DST* y, int64_t row_size, const DST* gamma, const DST* beta) : y(y), row_size(row_size), gamma(gamma), beta(beta) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::layer_norm::Pack y_pack; cuda::layer_norm::Pack gamma_pack; cuda::layer_norm::Pack beta_pack; const int64_t offset = (row * row_size + col) / N; const int64_t gamma_offset = col / N; if (do_scale) { gamma_pack.storage = *(reinterpret_cast*>(gamma) + gamma_offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { gamma_pack.elem[i] = static_cast(1.f); } } if (do_center) { beta_pack.storage = *(reinterpret_cast*>(beta) + gamma_offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { beta_pack.elem[i] = static_cast(0.f); } } #pragma unroll for (int i = 0; i < N; ++i) { DST normalized_i = static_cast(src[i]); if (do_scale || do_center) { y_pack.elem[i] = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]; } else { y_pack.elem[i] = normalized_i; } } *(reinterpret_cast*>(y) + offset) = y_pack.storage; } DST* y; int64_t row_size; const DST* gamma; const DST* beta; }; template struct ScaleLoad { using LoadType = DST; ScaleLoad(const SRC* src, const SRC* gamma, int64_t row_size) : src(src), gamma(gamma), row_size(row_size) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { cuda::layer_norm::Pack src_pack; cuda::layer_norm::Pack gamma_pack; const int64_t offset = (row * row_size + col) / N; const int64_t gamma_offset = col / N; src_pack.storage = *(reinterpret_cast*>(src) + offset); if (do_scale) { gamma_pack.storage = *(reinterpret_cast*>(gamma) + gamma_offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { gamma_pack.elem[i] = static_cast(1.f); } } #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(src_pack.elem[i] * gamma_pack.elem[i]); } } const SRC* src; const SRC* gamma; int64_t row_size; }; template struct AddStore { AddStore(const DST* add_to_output, DST* dst, int64_t row_size) : add_to_output(add_to_output), dst(dst), row_size(row_size) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::layer_norm::Pack add_to_output_pack; cuda::layer_norm::Pack dst_pack; const int64_t offset = (row * row_size + col) / N; if (do_add) { add_to_output_pack.storage = *(reinterpret_cast*>(add_to_output) + offset); } #pragma unroll for (int i = 0; i < N; ++i) { if (do_add) { dst_pack.elem[i] = static_cast(src[i]) + add_to_output_pack.elem[i]; } else { dst_pack.elem[i] = static_cast(src[i]); } } *(reinterpret_cast*>(dst) + offset) = dst_pack.storage; } const DST* add_to_output; DST* dst; int64_t row_size; }; template __inline__ __device__ T WarpReduce(T val) { for (int mask = 16; mask > 0; mask /= 2) { val += __shfl_down_sync(0xffffffff, val, mask); } return val; } constexpr int tile_size = 32; constexpr int num_per_block = 4; constexpr int block_dim_x = 32; constexpr int block_dim_y = 32 / num_per_block; template __global__ void LayerNormParamGrad(int rows, 
int cols, const T* __restrict__ dy, const T* __restrict__ x, const ComputeType* __restrict__ mean, const ComputeType* __restrict__ inv_var, T* __restrict__ tmp_gamma_diff, T* __restrict__ tmp_beta_diff) { __shared__ ComputeType dgamma[32][33]; __shared__ ComputeType dbeta[32][33]; ComputeType dgamma_sum[num_per_block]; ComputeType dbeta_sum[num_per_block]; #pragma unroll for (int index = 0; index < num_per_block; ++index) { dgamma_sum[index] = 0; dbeta_sum[index] = 0; } const int col_id = blockIdx.x * blockDim.x + threadIdx.x; if (col_id < cols) { for (int i = blockIdx.y * tile_size + threadIdx.y; i < rows; i += tile_size * gridDim.y) { #pragma unroll for (int index = 0; index < num_per_block; ++index) { int row_id = i + index * blockDim.y; if (row_id < rows) { int offset = row_id * cols + col_id; const ComputeType dy_val = static_cast(dy[offset]); const ComputeType x_val = static_cast(x[offset]); const ComputeType mean_val = mean[row_id]; const ComputeType inv_var_val = inv_var[row_id]; dgamma_sum[index] += dy_val * (x_val - mean_val) * inv_var_val; dbeta_sum[index] += dy_val; } } } } #pragma unroll for (int index = 0; index < num_per_block; ++index) { dgamma[index * blockDim.y + threadIdx.y][threadIdx.x] = dgamma_sum[index]; dbeta[index * blockDim.y + threadIdx.y][threadIdx.x] = dbeta_sum[index]; } __syncthreads(); #pragma unroll for (int index = 0; index < num_per_block; ++index) { const int col_id = blockIdx.x * blockDim.x + threadIdx.y + index * blockDim.y; if (col_id < cols) { ComputeType gamma_sum = dgamma[threadIdx.x][threadIdx.y + index * blockDim.y]; ComputeType beta_sum = dbeta[threadIdx.x][threadIdx.y + index * blockDim.y]; ComputeType global_dgamma = WarpReduce(gamma_sum); ComputeType global_dbeta = WarpReduce(beta_sum); if (threadIdx.x == 0) { const int offset = blockIdx.y * cols + col_id; tmp_gamma_diff[offset] = global_dgamma; tmp_beta_diff[offset] = global_dbeta; } } } } template int GetGirdDimY(const int64_t num_instances, const int64_t norm_size) { using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; const int grid_dim_x = (norm_size + tile_size - 1) / tile_size; const int max_grid_dim_y = (num_instances + tile_size - 1) / tile_size; const int block_size = block_dim_x * block_dim_y; int max_active_blocks = 0; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, LayerNormParamGrad, block_size, 0)); int waves = 1; int dev; OF_CUDA_CHECK(cudaGetDevice(&dev)); int sm_count; OF_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev)); int num_blocks = max_active_blocks * sm_count * waves; int grid_dim_y = std::min(max_grid_dim_y, static_cast(num_blocks / grid_dim_x)); return std::max(grid_dim_y, 1); } template void LayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const double epsilon, const T* x_ptr, const T* gamma_ptr, const T* beta_ptr, T* y_ptr, user_op::Tensor* mean, user_op::Tensor* inv_variance) { using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; cuda::layer_norm::DirectLoad load(x_ptr, norm_size); AffineStore store(y_ptr, norm_size, gamma_ptr, beta_ptr); cuda::layer_norm::DispatchLayerNorm( stream->As()->cuda_stream(), load, store, num_instances, norm_size, epsilon, mean->mut_dptr(), inv_variance->mut_dptr()); } template void DispatchLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const double epsilon, const T* x_ptr, const T* gamma_ptr, const T* beta_ptr, T* y_ptr, user_op::Tensor* 
mean, user_op::Tensor* inv_variance) { if (gamma_ptr != nullptr && beta_ptr != nullptr) { LayerNormForwardGpu(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr, beta_ptr, y_ptr, mean, inv_variance); } else if (gamma_ptr != nullptr && beta_ptr == nullptr) { LayerNormForwardGpu(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr, beta_ptr, y_ptr, mean, inv_variance); } else if (gamma_ptr == nullptr && beta_ptr != nullptr) { LayerNormForwardGpu(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr, beta_ptr, y_ptr, mean, inv_variance); } else { LayerNormForwardGpu(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr, beta_ptr, y_ptr, mean, inv_variance); } } template void LayerNormBackwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const T* dy_ptr, const T* x_ptr, const user_op::Tensor* mean, const user_op::Tensor* inv_variance, const T* gamma_ptr, const T* add_to_output_ptr, T* dx_ptr) { using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; cuda::layer_norm::DirectLoad load_x(x_ptr, norm_size); ScaleLoad load_scaled_dy(dy_ptr, gamma_ptr, norm_size); AddStore store(add_to_output_ptr, dx_ptr, norm_size); OF_CUDA_CHECK((cuda::layer_norm::DispatchLayerNormGrad( stream->As()->cuda_stream(), load_x, load_scaled_dy, store, mean->dptr(), inv_variance->dptr(), num_instances, norm_size))); } template void DispatchLayerNormBackwardDoAdd(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const T* dy_ptr, const T* x_ptr, const user_op::Tensor* mean, const user_op::Tensor* inv_variance, const T* gamma_ptr, const T* add_to_output_ptr, T* dx_ptr) { if (add_to_output_ptr != nullptr) { LayerNormBackwardGpu(stream, num_instances, norm_size, dy_ptr, x_ptr, mean, inv_variance, gamma_ptr, add_to_output_ptr, dx_ptr); } else { LayerNormBackwardGpu(stream, num_instances, norm_size, dy_ptr, x_ptr, mean, inv_variance, gamma_ptr, add_to_output_ptr, dx_ptr); } } template void LaunchLayerNormBackward(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const T* dy_ptr, const T* x_ptr, const user_op::Tensor* mean, const user_op::Tensor* inv_variance, const T* gamma_ptr, const T* add_to_output_ptr, T* dx_ptr) { if (gamma_ptr != nullptr) { DispatchLayerNormBackwardDoAdd(stream, num_instances, norm_size, dy_ptr, x_ptr, mean, inv_variance, gamma_ptr, add_to_output_ptr, dx_ptr); } else { DispatchLayerNormBackwardDoAdd(stream, num_instances, norm_size, dy_ptr, x_ptr, mean, inv_variance, gamma_ptr, add_to_output_ptr, dx_ptr); } } } // namespace template class LayerNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: LayerNormGpuKernel() = default; ~LayerNormGpuKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); const double epsilon = ctx->Attr("epsilon"); CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON); const int64_t num_instances = mean->shape_view().elem_cnt(); const int64_t norm_size = x->shape_view().elem_cnt() / num_instances; const T* gamma_ptr = nullptr; const T* beta_ptr = nullptr; if (ctx->has_input("gamma", 0)) { const user_op::Tensor* gamma = 
ctx->Tensor4ArgNameAndIndex("gamma", 0); gamma_ptr = gamma->dptr(); CHECK_EQ(gamma->shape_view().elem_cnt(), norm_size); } if (ctx->has_input("beta", 0)) { beta_ptr = ctx->Tensor4ArgNameAndIndex("beta", 0)->dptr(); } DispatchLayerNormForwardGpu(ctx->stream(), num_instances, norm_size, epsilon, x->dptr(), gamma_ptr, beta_ptr, y->mut_dptr(), mean, inv_variance); }; }; #define REGISTER_LAYER_NORM_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("layer_norm") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_LAYER_NORM_CUDA_KERNEL(float) REGISTER_LAYER_NORM_CUDA_KERNEL(double) REGISTER_LAYER_NORM_CUDA_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_LAYER_NORM_CUDA_KERNEL(nv_bfloat16) #endif template class LayerNormGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: LayerNormGradGpuKernel() = default; ~LayerNormGradGpuKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t num_instances = mean->shape_view().elem_cnt(); const int64_t norm_size = x->shape_view().elem_cnt() / num_instances; const T* gamma_ptr = nullptr; if (ctx->has_input("gamma", 0)) { gamma_ptr = ctx->Tensor4ArgNameAndIndex("gamma", 0)->dptr(); } const T* add_to_output_ptr = nullptr; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), dx->data_type()); CHECK_EQ(add_to_output->shape_view(), dx->shape_view()); add_to_output_ptr = add_to_output->dptr(); } LaunchLayerNormBackward(ctx->stream(), num_instances, norm_size, dy->dptr(), x->dptr(), mean, inv_variance, gamma_ptr, add_to_output_ptr, dx->mut_dptr()); }; }; #define REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("layer_norm_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)) \ .SetInplaceProposalFn( \ [](const user_op::InferContext& ctx, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ if (ctx.has_input("_add_to_output", 0)) { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "_add_to_output", 0, true)); \ } \ return Maybe::Ok(); \ }); REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(float) REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(double) REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(nv_bfloat16) #endif template class LayerNormParamGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: LayerNormParamGradGpuKernel() = default; ~LayerNormParamGradGpuKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* mean = 
ctx->Tensor4ArgNameAndIndex("mean", 0); const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); const int64_t num_instances = mean->shape_view().elem_cnt(); const int64_t norm_size = x->shape_view().elem_cnt() / num_instances; user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const DataType data_type = dy->data_type(); const int grid_dim_x = (norm_size + tile_size - 1) / tile_size; const int grid_dim_y = GetGirdDimY(num_instances, norm_size); const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T); T* tmp_gamma_diff_ptr = reinterpret_cast(tmp_buffer->mut_dptr()); T* tmp_beta_diff_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + tmp_gamma_diff_size); T* reduce_buf_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + 2 * tmp_gamma_diff_size); using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; LayerNormParamGrad<<stream()->As()->cuda_stream()>>>( num_instances, norm_size, dy->dptr(), x->dptr(), mean->dptr(), inv_variance->dptr(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr); const int32_t m = norm_size; const int32_t n = 1; const int32_t k = grid_dim_y; std::unique_ptr fill = ep::primitive::NewPrimitive(ctx->stream()->device_type(), data_type); CHECK(fill); fill->Launch(ctx->stream(), reduce_buf_ptr, 1.0, grid_dim_y); std::unique_ptr matmul = ep::primitive::NewPrimitive( ctx->stream()->device_type(), data_type, ep::primitive::BlasTransposeType::T, ep::primitive::BlasTransposeType::N); CHECK(matmul); if (ctx->has_output("gamma_diff", 0)) { user_op::Tensor* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0); matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_gamma_diff_ptr, reduce_buf_ptr, 0.0, gamma_diff->mut_dptr()); } if (ctx->has_output("beta_diff", 0)) { user_op::Tensor* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0); matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_beta_diff_ptr, reduce_buf_ptr, 0.0, beta_diff->mut_dptr()); } }; }; #define REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("layer_norm_param_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); \ const bool has_gamma_diff = ctx->has_output("gamma_diff", 0); \ const bool has_beta_diff = ctx->has_output("beta_diff", 0); \ const auto& dy = ctx->InputTensorDesc("dy", 0); \ const int64_t num_instances = dy.shape().Count(0, begin_params_axis); \ const int64_t norm_size = dy.shape().Count(begin_params_axis); \ const int grid_dim_y = GetGirdDimY(num_instances, norm_size); \ size_t tmp_buffer_size = (2 * grid_dim_y * norm_size + grid_dim_y) * sizeof(dtype); \ return tmp_buffer_size; \ }); REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float) REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double) REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/lerp_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/lerp_kernel_util.h" namespace oneflow { template class LerpKernel final : public user_op::OpKernel { public: LerpKernel() = default; ~LerpKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* start = ctx->Tensor4ArgNameAndIndex("start", 0); const user_op::Tensor* end = ctx->Tensor4ArgNameAndIndex("end", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& start_shape = start->shape_view(); const ShapeView& end_shape = end->shape_view(); const ShapeView& weight_shape = weight->shape_view(); CHECK_EQ(start_shape, end_shape); CHECK_EQ(start_shape, weight_shape); const T* start_ptr = start->dptr(); const T* end_ptr = end->dptr(); const T* weight_ptr = weight->dptr(); T* out_ptr = out->mut_dptr(); LerpKernelUtil::Forward(ctx->stream(), start_shape.elem_cnt(), start_ptr, weight_ptr, end_ptr, out_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_LERP_KERNEL(device_type, dtype) \ REGISTER_USER_KERNEL("lerp").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_LERP_KERNEL(DeviceType::kCPU, float) REGISTER_LERP_KERNEL(DeviceType::kCPU, double) REGISTER_LERP_KERNEL(DeviceType::kCPU, uint8_t) REGISTER_LERP_KERNEL(DeviceType::kCPU, int8_t) REGISTER_LERP_KERNEL(DeviceType::kCPU, int32_t) REGISTER_LERP_KERNEL(DeviceType::kCPU, int64_t) #ifdef WITH_CUDA REGISTER_LERP_KERNEL(DeviceType::kCUDA, half) REGISTER_LERP_KERNEL(DeviceType::kCUDA, float) REGISTER_LERP_KERNEL(DeviceType::kCUDA, double) REGISTER_LERP_KERNEL(DeviceType::kCUDA, uint8_t) REGISTER_LERP_KERNEL(DeviceType::kCUDA, int8_t) REGISTER_LERP_KERNEL(DeviceType::kCUDA, int32_t) REGISTER_LERP_KERNEL(DeviceType::kCUDA, int64_t) #endif // WITH_CUDA template class LerpGradKernel final : public user_op::OpKernel { public: LerpGradKernel() = default; ~LerpGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* start = ctx->Tensor4ArgNameAndIndex("start", 0); const user_op::Tensor* end = ctx->Tensor4ArgNameAndIndex("end", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* out_diff = ctx->Tensor4ArgNameAndIndex("out_diff", 0); user_op::Tensor* start_diff = ctx->Tensor4ArgNameAndIndex("start_diff", 0); user_op::Tensor* end_diff = ctx->Tensor4ArgNameAndIndex("end_diff", 0); user_op::Tensor* weight_diff = ctx->Tensor4ArgNameAndIndex("weight_diff", 0); const ShapeView& start_shape = start->shape_view(); const ShapeView& end_shape = end->shape_view(); const ShapeView& weight_shape = weight->shape_view(); CHECK_EQ(start_shape, end_shape); CHECK_EQ(start_shape, weight_shape); const T* start_ptr = start->dptr(); const T* end_ptr = end->dptr(); const T* weight_ptr = weight->dptr(); const T* out_diff_ptr = out_diff->dptr(); T* start_diff_ptr = start_diff->mut_dptr(); T* end_diff_ptr = 
end_diff->mut_dptr(); T* weight_diff_ptr = weight_diff->mut_dptr(); LerpKernelUtil::Backward(ctx->stream(), start_shape.elem_cnt(), start_ptr, weight_ptr, end_ptr, out_diff_ptr, start_diff_ptr, weight_diff_ptr, end_diff_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_LERP_GRAD_KERNEL(device_type, dtype) \ REGISTER_USER_KERNEL("lerp_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("start_diff", 0) == GetDataType::value) \ && (user_op::HobDataType("weight_diff", 0) == GetDataType::value) \ && (user_op::HobDataType("end_diff", 0) == GetDataType::value)); REGISTER_LERP_GRAD_KERNEL(DeviceType::kCPU, float) REGISTER_LERP_GRAD_KERNEL(DeviceType::kCPU, double) #ifdef WITH_CUDA REGISTER_LERP_GRAD_KERNEL(DeviceType::kCUDA, half) REGISTER_LERP_GRAD_KERNEL(DeviceType::kCUDA, float) REGISTER_LERP_GRAD_KERNEL(DeviceType::kCUDA, double) #endif // WITH_CUDA template class ScalarLerpKernel final : public user_op::OpKernel { public: ScalarLerpKernel() = default; ~ScalarLerpKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* start = ctx->Tensor4ArgNameAndIndex("start", 0); const user_op::Tensor* end = ctx->Tensor4ArgNameAndIndex("end", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& start_shape = start->shape_view(); const ShapeView& end_shape = end->shape_view(); CHECK_EQ(start_shape, end_shape); const T* start_ptr = start->dptr(); const T* end_ptr = end->dptr(); T* out_ptr = out->mut_dptr(); Scalar scalar_operand; if (ctx->Attr("has_int_operand")) { scalar_operand = ctx->Attr("int_operand"); } else if (ctx->Attr("has_float_operand")) { scalar_operand = ctx->Attr("float_operand"); } else { UNIMPLEMENTED(); } ScalarLerpKernelUtil::Forward( ctx->stream(), start_shape.elem_cnt(), start_ptr, end_ptr, scalar_operand, out_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SCALAR_LERP_KERNEL(device_type, dtype, value_type) \ REGISTER_USER_KERNEL("scalar_lerp") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, float, double) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, double, double) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, int8_t, int64_t) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, int32_t, int64_t) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, int64_t, int64_t) #ifdef WITH_CUDA REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, half, double) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, float, double) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, double, double) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, int8_t, int64_t) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, int32_t, int64_t) REGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, int64_t, int64_t) #endif // WITH_CUDA template class ScalarLerpGradKernel final : public user_op::OpKernel { public: ScalarLerpGradKernel() = default; ~ScalarLerpGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* start = ctx->Tensor4ArgNameAndIndex("start", 0); const user_op::Tensor* end = ctx->Tensor4ArgNameAndIndex("end", 0); const user_op::Tensor* out_diff = ctx->Tensor4ArgNameAndIndex("out_diff", 0); user_op::Tensor* start_diff = ctx->Tensor4ArgNameAndIndex("start_diff", 0); user_op::Tensor* end_diff = 
ctx->Tensor4ArgNameAndIndex("end_diff", 0); const ShapeView& start_shape = start->shape_view(); const ShapeView& end_shape = end->shape_view(); CHECK_EQ(start_shape, end_shape); const T* start_ptr = start->dptr(); const T* end_ptr = end->dptr(); const T* out_diff_ptr = out_diff->dptr(); T* start_diff_ptr = start_diff->mut_dptr(); T* end_diff_ptr = end_diff->mut_dptr(); Scalar scalar_operand; if (ctx->Attr("has_int_operand")) { scalar_operand = ctx->Attr("int_operand"); } else if (ctx->Attr("has_float_operand")) { scalar_operand = ctx->Attr("float_operand"); } else { UNIMPLEMENTED(); } ScalarLerpKernelUtil::Backward( ctx->stream(), start_shape.elem_cnt(), start_ptr, end_ptr, out_diff_ptr, scalar_operand, start_diff_ptr, end_diff_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SCALAR_LERP_GRAD_KERNEL(device_type, dtype, value_type) \ REGISTER_USER_KERNEL("scalar_lerp_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("start_diff", 0) == GetDataType::value) \ && (user_op::HobDataType("end_diff", 0) == GetDataType::value)); REGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCPU, float, double) REGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCPU, double, double) #ifdef WITH_CUDA REGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCUDA, half, double) REGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCUDA, float, double) REGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCUDA, double, double) #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/lerp_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/user/kernels/lerp_kernel_util.h"

namespace oneflow {

template<typename T>
struct LerpKernelUtil<DeviceType::kCPU, T> {
  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,
                      const T* end, T* out) {
    // out = start + weight * (end - start), element-wise.
    FOR_RANGE(int64_t, i, 0, n) { out[i] = start[i] + weight[i] * (end[i] - start[i]); }
  }

  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,
                       const T* end, const T* out_diff, T* start_diff, T* weight_diff,
                       T* end_diff) {
    // d(out)/d(start) = 1 - weight, d(out)/d(weight) = end - start, d(out)/d(end) = weight.
    FOR_RANGE(int64_t, i, 0, n) {
      T out_diff_i = out_diff[i];
      start_diff[i] = (static_cast<T>(1.0) - weight[i]) * out_diff_i;
      weight_diff[i] = (end[i] - start[i]) * out_diff_i;
      end_diff[i] = weight[i] * out_diff_i;
    }
  }
};

template<typename T, typename ValueT>
struct ScalarLerpKernelUtil<DeviceType::kCPU, T, ValueT> {
  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* end,
                      const Scalar operand, T* out) {
    T weight = static_cast<T>(operand.Value<ValueT>());
    FOR_RANGE(int64_t, i, 0, n) { out[i] = start[i] + weight * (end[i] - start[i]); }
  }

  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* end,
                       const T* out_diff, const Scalar operand, T* start_diff, T* end_diff) {
    T weight = static_cast<T>(operand.Value<ValueT>());
    FOR_RANGE(int64_t, i, 0, n) {
      T out_diff_i = out_diff[i];
      start_diff[i] = (static_cast<T>(1.0) - weight) * out_diff_i;
      end_diff[i] = out_diff_i - start_diff[i];
    }
  }
};

#define INSTANTIATE_LERP_KERNEL_UTIL_CPU(data_type, other) \
  template struct LerpKernelUtil<DeviceType::kCPU, data_type>;

OF_PP_FOR_EACH_TUPLE(INSTANTIATE_LERP_KERNEL_UTIL_CPU, LERP_DATA_TYPE_SEQ_CPU)
#undef INSTANTIATE_LERP_KERNEL_UTIL_CPU

#define INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CPU(data_type, value_data_type)           \
  template struct ScalarLerpKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type), \
                                       OF_PP_PAIR_FIRST(value_data_type)>;

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CPU, LERP_DATA_TYPE_SEQ_CPU,
                                 SCALAR_VALUE_DATA_TYPE_SEQ)
#undef INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CPU

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/lerp_kernel_util.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/user/kernels/lerp_kernel_util.h"

namespace oneflow {

namespace {

template<typename T>
__global__ void LerpForwardGpu(const int n, const T* start, const T* weight, const T* end,
                               T* out) {
  CUDA_1D_KERNEL_LOOP(i, n) {
    const T start_i = start[i];
    out[i] = start_i + weight[i] * (end[i] - start_i);
  }
}

template<typename T, typename ValueT>
__global__ void ScalarLerpForwardGpu(const int n, const T* start, const ValueT weight,
                                     const T* end, T* out) {
  // For half, build the scalar weight through float; other types cast directly.
  T weight_calculate = 0.0;
  if constexpr (std::is_same<T, half>::value) {
    weight_calculate = __float2half(static_cast<float>(weight));
  } else {
    weight_calculate = static_cast<T>(weight);
  }
  CUDA_1D_KERNEL_LOOP(i, n) {
    const T start_i = start[i];
    out[i] = start_i + weight_calculate * (end[i] - start_i);
  }
}

template<typename T>
__global__ void LerpBackwardGpu(const int n, const T* start, const T* weight, const T* end,
                                const T* out_diff, T* start_diff, T* weight_diff, T* end_diff) {
  CUDA_1D_KERNEL_LOOP(i, n) {
    const T out_diff_i = out_diff[i];
    const T start_diff_i = (static_cast<T>(1.0) - weight[i]) * out_diff_i;
    start_diff[i] = start_diff_i;
    weight_diff[i] = (end[i] - start[i]) * out_diff_i;
    end_diff[i] = out_diff_i - start_diff_i;
  }
}

template<typename T, typename ValueT>
__global__ void ScalarLerpBackwardGpu(const int n, const T* start, const ValueT weight,
                                      const T* end, const T* out_diff, T* start_diff,
                                      T* end_diff) {
  T weight_calculate = 0.0;
  if constexpr (std::is_same<T, half>::value) {
    weight_calculate = __float2half(static_cast<float>(weight));
  } else {
    weight_calculate = static_cast<T>(weight);
  }
  CUDA_1D_KERNEL_LOOP(i, n) {
    T out_diff_i = out_diff[i];
    const T start_diff_i = (static_cast<T>(1.0) - weight_calculate) * out_diff_i;
    start_diff[i] = start_diff_i;
    end_diff[i] = out_diff_i - start_diff_i;
  }
}

}  // namespace

template<typename T>
struct LerpKernelUtil<DeviceType::kCUDA, T> {
  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,
                      const T* end, T* out) {
    RUN_CUDA_KERNEL((LerpForwardGpu<T>), stream, n, n, start, weight, end, out);
  }

  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,
                       const T* end, const T* out_diff, T* start_diff, T* weight_diff,
                       T* end_diff) {
    RUN_CUDA_KERNEL((LerpBackwardGpu<T>), stream, n, n, start, weight, end, out_diff, start_diff,
                    weight_diff, end_diff);
  }
};

template<typename T, typename ValueT>
struct ScalarLerpKernelUtil<DeviceType::kCUDA, T, ValueT> {
  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* end,
                      const Scalar operand, T* out) {
    ValueT weight = operand.Value<ValueT>();
    RUN_CUDA_KERNEL((ScalarLerpForwardGpu<T, ValueT>), stream, n, n, start, weight, end, out);
  }

  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* end,
                       const T* out_diff, const Scalar operand, T* start_diff, T* end_diff) {
    ValueT weight = operand.Value<ValueT>();
    RUN_CUDA_KERNEL((ScalarLerpBackwardGpu<T, ValueT>), stream, n, n, start, weight, end, out_diff,
                    start_diff, end_diff);
  }
};

#define INSTANTIATE_LERP_KERNEL_UTIL_CUDA(data_type, other) \
  template struct LerpKernelUtil<DeviceType::kCUDA, data_type>;

OF_PP_FOR_EACH_TUPLE(INSTANTIATE_LERP_KERNEL_UTIL_CUDA, LERP_DATA_TYPE_SEQ_CUDA)
#undef INSTANTIATE_LERP_KERNEL_UTIL_CUDA

#define INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CUDA(data_type, value_data_type)           \
  template struct ScalarLerpKernelUtil<DeviceType::kCUDA, OF_PP_PAIR_FIRST(data_type), \
                                       OF_PP_PAIR_FIRST(value_data_type)>;

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CUDA, LERP_DATA_TYPE_SEQ_CUDA,
                                 SCALAR_VALUE_DATA_TYPE_SEQ)
#undef INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CUDA

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/lerp_kernel_util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_LERP_KERNEL_UTIL_H_
#define ONEFLOW_USER_KERNELS_LERP_KERNEL_UTIL_H_

#include "oneflow/core/common/scalar.h"
#include "oneflow/core/common/shape_view.h"
#include "oneflow/core/ep/include/stream.h"
#include "oneflow/core/ndarray/ndarray_util.h"

namespace oneflow {

template<DeviceType device_type, typename T>
struct LerpKernelUtil {
  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,
                      const T* end, T* out);

  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,
                       const T* end, const T* out_diff, T* start_diff, T* weight_diff,
                       T* end_diff);
};

template<DeviceType device_type, typename T, typename ValueT>
struct ScalarLerpKernelUtil {
  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* end,
                      const Scalar operand, T* out);

  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* end,
                       const T* out_diff, const Scalar operand, T* start_diff, T* end_diff);
};

#define SCALAR_VALUE_DATA_TYPE_SEQ                \
  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) \
  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)

#define LERP_DATA_TYPE_SEQ_CPU \
  FLOATING_DATA_TYPE_SEQ       \
  SIGNED_INT_DATA_TYPE_SEQ     \
  UNSIGNED_INT_DATA_TYPE_SEQ

#ifdef WITH_CUDA
#define LERP_DATA_TYPE_SEQ_CUDA \
  FLOATING_DATA_TYPE_SEQ        \
  SIGNED_INT_DATA_TYPE_SEQ      \
  UNSIGNED_INT_DATA_TYPE_SEQ    \
  HALF_DATA_TYPE_SEQ
#endif

}  // namespace oneflow

#endif  // ONEFLOW_USER_KERNELS_LERP_KERNEL_UTIL_H_


================================================
FILE: oneflow/user/kernels/linalg_cross_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_kernel.h" namespace oneflow { template class CpuLinalgCrossKernel final : public user_op::OpKernel { public: CpuLinalgCrossKernel() = default; ~CpuLinalgCrossKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* other_tensor = ctx->Tensor4ArgNameAndIndex("other", 0); auto* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); const auto shape = input_tensor->shape_view(); const auto num_axes = shape.NumAxes(); int64_t dim = ctx->Attr("dim"); const auto strides = [&shape]() -> std::vector { std::vector result(shape.NumAxes(), 1); for (size_t i(0); i < result.size() - 1; ++i) { result[i] = shape.Count(i + 1); } return result; }(); const int64_t total = shape.elem_cnt() / 3; int64_t stride = strides[dim]; const T* input_ptr = input_tensor->dptr(); const T* other_ptr = other_tensor->dptr(); T* out_dtr = out_tensor->mut_dptr(); std::vector positions_in_dims(num_axes); int64_t start = 0; int64_t s = 0; while (s < total) { out_dtr[start + 0 * stride] = input_ptr[start + 1 * stride] * other_ptr[start + 2 * stride] - input_ptr[start + 2 * stride] * other_ptr[start + 1 * stride]; out_dtr[start + 1 * stride] = input_ptr[start + 2 * stride] * other_ptr[start + 0 * stride] - input_ptr[start + 0 * stride] * other_ptr[start + 2 * stride]; out_dtr[start + 2 * stride] = input_ptr[start + 0 * stride] * other_ptr[start + 1 * stride] - input_ptr[start + 1 * stride] * other_ptr[start + 0 * stride]; ++s; FOR_RANGE(int64_t, i, 0, num_axes) { if (i == dim) continue; ++positions_in_dims[i]; start += strides[i]; if (positions_in_dims[i] == shape.At(i) && i != num_axes - 1) { start -= positions_in_dims[i] * strides[i]; positions_in_dims[i] = 0; } else { break; } } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_LINALG_CROSS_KERNEL(dtype) \ REGISTER_USER_KERNEL("linalg_cross") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value)); REGISTER_CPU_LINALG_CROSS_KERNEL(float) REGISTER_CPU_LINALG_CROSS_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/linalg_cross_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_kernel.h"
#include "oneflow/core/device/cuda_util.h"

namespace {

// Treats the data as packed groups of 3 contiguous elements and computes one
// cross product per group.
template<typename T>
__global__ void LinalgCrossForward(const int64_t n, const T* input, const T* other, T* out) {
  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {
    const int64_t index = i * 3;
    out[index] = input[index + 1] * other[index + 2] - input[index + 2] * other[index + 1];
    out[index + 1] = input[index + 2] * other[index] - input[index] * other[index + 2];
    out[index + 2] = input[index] * other[index + 1] - input[index + 1] * other[index];
  }
}

}  // namespace

namespace oneflow {

template<typename T>
class CudaLinalgCrossKernel final : public user_op::OpKernel {
 public:
  CudaLinalgCrossKernel() = default;
  ~CudaLinalgCrossKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const auto* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0);
    const auto* other_tensor = ctx->Tensor4ArgNameAndIndex("other", 0);
    auto* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
    const int64_t n = input_tensor->shape_view().elem_cnt() / 3;
    if (n == 0) { return; }
    RUN_CUDA_KERNEL((LinalgCrossForward<T>), ctx->stream(), n, n, input_tensor->dptr<T>(),
                    other_tensor->dptr<T>(), out_tensor->mut_dptr<T>());
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_CUDA_LINALG_CROSS_KERNEL(dtype)                       \
  REGISTER_USER_KERNEL("linalg_cross")                                 \
      .SetCreateFn<CudaLinalgCrossKernel<dtype>>()                     \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
                       && (user_op::HobDataType("input", 0) == GetDataType<dtype>::value));

REGISTER_CUDA_LINALG_CROSS_KERNEL(float)
REGISTER_CUDA_LINALG_CROSS_KERNEL(double)

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/log_softmax_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/log_softmax.h" #include "oneflow/core/ep/include/primitive/log_softmax_backward.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template std::unique_ptr NewLogSoftmaxPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } auto LogSoftmaxPrimitiveExists() { return hob::make_custom("LogSoftmaxPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewLogSoftmaxPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewLogSoftmaxBackwardPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } auto LogSoftmaxBackwardPrimitiveExists() { return hob::make_custom("LogSoftmaxBackwardPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewLogSoftmaxBackwardPrimitive(&ctx).operator bool(); }); } class LogSoftmaxKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: LogSoftmaxKernel() = default; ~LogSoftmaxKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); const ShapeView& in_shape = in->shape_view(); const int64_t num_classes = in_shape.At(in_shape.NumAxes() - 1); const int64_t num_instances = in_shape.Count(0, in_shape.NumAxes() - 1); std::unique_ptr primitive = NewLogSoftmaxPrimitive(ctx); CHECK(primitive); primitive->Launch(ctx->stream(), num_instances, num_classes, in->dptr(), prob->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class LogSoftmaxGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: LogSoftmaxGradKernel() = default; ~LogSoftmaxGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t num_classes = prob->shape_view().At(prob->shape_view().NumAxes() - 1); const int64_t num_instances = prob->shape_view().elem_cnt() / num_classes; std::unique_ptr primitive = NewLogSoftmaxBackwardPrimitive(ctx); CHECK(primitive); primitive->Launch(ctx->stream(), num_instances, num_classes, prob->dptr(), dy->dptr(), dx->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace REGISTER_USER_KERNEL("log_softmax") .SetCreateFn() .SetIsMatchedHob(LogSoftmaxPrimitiveExists() == true); REGISTER_USER_KERNEL("log_softmax_grad") .SetCreateFn() .SetIsMatchedHob(LogSoftmaxBackwardPrimitiveExists() == true); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/logical_not_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/cuda_graph_support.h"
#include "oneflow/core/ep/include/primitive/elementwise_unary.h"
#include "oneflow/user/kernels/op_kernel_wrapper.h"

namespace oneflow {

namespace {

template<typename Context>
std::unique_ptr<ep::primitive::ElementwiseUnary> NewLogicalNotPrimitive(Context* ctx) {
  const DataType in_data_type = ctx->TensorDesc4ArgNameAndIndex("x", 0)->data_type();
  const DataType out_data_type = ctx->TensorDesc4ArgNameAndIndex("y", 0)->data_type();
  return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
      ctx->device_type(), ep::primitive::UnaryOp::kLogicalNot, in_data_type, out_data_type);
}

// Computes y = !x element-wise through the device-agnostic ElementwiseUnary primitive.
class LogicalNotKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
 public:
  LogicalNotKernel() = default;
  ~LogicalNotKernel() override = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0);
    user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0);
    int64_t n = tensor_x->shape_view().elem_cnt();
    if (n != 0) {
      auto primitive = NewLogicalNotPrimitive(ctx);
      CHECK(primitive);
      primitive->Launch(ctx->stream(), tensor_x->dptr(), tensor_y->mut_dptr(), n);
    }
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

auto LogicalNotPrimitiveExists() {
  return hob::make_custom("LogicalNotPrimitiveExists",
                          [](const user_op::KernelRegContext& ctx) -> bool {
                            return NewLogicalNotPrimitive(&ctx).operator bool();
                          });
}

REGISTER_USER_KERNEL("logical_not")
    .SetCreateFn<LogicalNotKernel>()
    .SetIsMatchedHob(LogicalNotPrimitiveExists());

}  // namespace

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/loss_kernel_util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_USER_KERNELS_LOSS_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_LOSS_KERNEL_UTIL_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace user_op { namespace loss { template class SimpleLossKernel : public user_op::OpKernel { public: SimpleLossKernel() = default; ~SimpleLossKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const T* input = input_blob->dptr(); const T* target = target_blob->dptr(); T* out = out_blob->mut_dptr(); static_cast(this)->ComputeOut(ctx, elem_cnt, input, target, out); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class SimpleLossGradKernel : public user_op::OpKernel { public: SimpleLossGradKernel() = default; ~SimpleLossGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t elem_cnt = input_blob->shape_view().elem_cnt(); const T* dy = dy_blob->dptr(); const T* input = input_blob->dptr(); const T* target = target_blob->dptr(); T* dx = dx_blob->mut_dptr(); static_cast(this)->ComputeOut(ctx, elem_cnt, input, target, dy, dx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; namespace { #define REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, device, dtype) \ REGISTER_USER_KERNEL(name).SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); #define REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, device, dtype) \ REGISTER_USER_KERNEL(name).SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); } // namespace #define REGISTER_SIMPLE_LOSS_KERNEL_CPU(name, kernel) \ REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCPU, float) \ REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCPU, double) #define REGISTER_SIMPLE_LOSS_KERNEL_CUDA(name, kernel) \ REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCUDA, half) \ REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCUDA, float) \ REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCUDA, double) #define REGISTER_SIMPLE_LOSS_GRAD_KERNEL_CPU(name, kernel) \ REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCPU, float) \ REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCPU, double) #define REGISTER_SIMPLE_LOSS_GRAD_KERNEL_CUDA(name, kernel) \ REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCUDA, half) \ REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCUDA, float) \ 
REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCUDA, double) } // namespace loss } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_LOSS_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/lu_decomposition_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { #if CUDA_VERSION >= 11000 static inline size_t BatchCount(const user_op::Tensor* batched_matrices) { size_t result = 1; for (size_t i = 0; i < batched_matrices->shape_view().NumAxes() - 2; i++) { result *= batched_matrices->shape_view().At(i); } return result; } static inline size_t MatrixStride(const user_op::Tensor* batched_matrices) { const int64_t num_axes = batched_matrices->shape_view().NumAxes(); return batched_matrices->shape_view().At(num_axes - 2) * batched_matrices->shape_view().At(num_axes - 1); } static inline size_t PivotStride(const user_op::Tensor* batched_pivot) { const int64_t num_axes = batched_pivot->shape_view().NumAxes(); return batched_pivot->shape_view().At(num_axes - 1); } void OFgetrf_bufferSize(ep::Stream* stream, int32_t m, int32_t n, float* dA_array, int32_t lda, int32_t& lwork) { OF_CUSOLVER_CHECK(cusolverDnSgetrf_bufferSize(stream->As()->cusolver_dn_handle(), m, n, dA_array, m, &lwork)); } void OFgetrf_bufferSize(ep::Stream* stream, int32_t m, int32_t n, double* dA_array, int32_t lda, int32_t& lwork) { OF_CUSOLVER_CHECK(cusolverDnDgetrf_bufferSize(stream->As()->cusolver_dn_handle(), m, n, dA_array, m, &lwork)); } void OFgetrf(ep::Stream* stream, int32_t m, int32_t n, float* dA_array, int32_t lda, float* d_work, int32_t* pivot_ptr, int32_t* d_info) { OF_CUSOLVER_CHECK(cusolverDnSgetrf(stream->As()->cusolver_dn_handle(), m, m, dA_array, lda, d_work, pivot_ptr, d_info)); } void OFgetrf(ep::Stream* stream, int32_t m, int32_t n, double* dA_array, int32_t lda, double* d_work, int32_t* pivot_ptr, int32_t* d_info) { OF_CUSOLVER_CHECK(cusolverDnDgetrf(stream->As()->cusolver_dn_handle(), m, m, dA_array, lda, d_work, pivot_ptr, d_info)); } } // namespace namespace user_op { template class LUDecompositionKernel final : public user_op::OpKernel { public: LUDecompositionKernel() = default; ~LUDecompositionKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* LU = ctx->Tensor4ArgNameAndIndex("LU", 0); user_op::Tensor* pivot = ctx->Tensor4ArgNameAndIndex("pivot", 0); auto stream = ctx->stream()->As(); // infer tmp buffer const int32_t m = x->shape_view().At(x->shape_view().NumAxes() - 2); const int32_t lda = m; const T* x_ptr = x->dptr(); T* LU_ptr = LU->mut_dptr(); int32_t* pivot_ptr = pivot->mut_dptr(); size_t batch_count = BatchCount(x); size_t matrix_stride = MatrixStride(x); size_t 
pivot_stride = PivotStride(x); std::unique_ptr memcpy_primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type(), ep::primitive::MemcpyKind::kDtoD); CHECK(memcpy_primitive) << "Can not create Memcpy primitive for device type " << ctx->stream()->device_type(); memcpy_primitive->Launch(stream, LU_ptr, x_ptr, sizeof(T) * x->shape_view().elem_cnt()); std::vector batched_info(batch_count, -1); int32_t* batched_d_info = nullptr; int32_t lwork = -1; T* d_work = nullptr; OF_CUDA_CHECK( cudaMalloc(reinterpret_cast(&batched_d_info), batch_count * sizeof(int32_t))); for (size_t batch = 0; batch < batch_count; batch++) { OFgetrf_bufferSize(stream, m, m, LU_ptr, m, lwork); OF_CUDA_CHECK(cudaMalloc(reinterpret_cast(&d_work), sizeof(T) * lwork)); OFgetrf(stream, m, m, LU_ptr + batch * matrix_stride, lda, d_work, pivot_ptr + batch * pivot_stride, batched_d_info + batch); OF_CUDA_CHECK(cudaFree(d_work)); } OF_CUDA_CHECK(cudaMemcpyAsync(batched_info.data(), batched_d_info, batch_count * sizeof(int32_t), cudaMemcpyDeviceToHost, stream->cuda_stream())); for (size_t i = 0; i < batched_info.size(); i++) { int32_t info = batched_info[i]; CHECK(info >= 0) << "LU decomposition: " << -info << "-th parameter of batch " << i << " is wrong"; } OF_CUDA_CHECK(cudaFree(batched_d_info)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_LU_DECOMPOSITION_KERNEL(dtype) \ REGISTER_USER_KERNEL("lu_decomposition") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_CUDA_LU_DECOMPOSITION_KERNEL(float) REGISTER_CUDA_LU_DECOMPOSITION_KERNEL(double) #endif } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/masked_fill_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/where_kernel_util.h" #include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { template class MaskedFillKernel final : public user_op::OpKernel { public: MaskedFillKernel() = default; ~MaskedFillKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); T scalar_operand = static_cast(0); if (ctx->Attr("has_int_operand")) { scalar_operand = static_cast(ctx->Attr("int_operand")); } else if (ctx->Attr("has_float_operand")) { scalar_operand = static_cast(ctx->Attr("float_operand")); } else if (ctx->Attr("has_bool_operand")) { scalar_operand = static_cast(ctx->Attr("bool_operand")); } else { UNIMPLEMENTED() << "The scalar in MaskedFill should be float or int."; } WhereKernelUtil::WhereXScalar( ctx->stream(), out->shape_view().elem_cnt(), mask->dptr(), scalar_operand, x->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MASKED_FILL_KERNEL(device_type_v, dtype_pair, ctype_pair) \ REGISTER_USER_KERNEL("masked_fill") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("mask", 0) == OF_PP_PAIR_SECOND(ctype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MASKED_FILL_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MASKED_FILL_KERNEL, (DeviceType::kCUDA), FLOAT16_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/math_binary_broadcast_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/user/ops/math_binary_broadcast_seq.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { template std::enable_if_t> NewBroadcastElementwiseBinaryPrimitive(Context* ctx) { const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* z = ctx->TensorDesc4ArgNameAndIndex("z", 0); size_t num_axes = z->shape().NumAxes(); return ep::primitive::NewPrimitive( ctx->device_type(), binary_op, x->data_type(), z->data_type(), num_axes, ctx->template Attr("atol"), ctx->template Attr("rtol")); } template std::enable_if_t> NewBroadcastElementwiseBinaryPrimitive(Context* ctx) { const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* z = ctx->TensorDesc4ArgNameAndIndex("z", 0); size_t num_axes = z->shape().NumAxes(); return ep::primitive::NewPrimitive( ctx->device_type(), binary_op, x->data_type(), z->data_type(), num_axes); } template class MathBinaryBroadcastEpKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MathBinaryBroadcastEpKernel() = default; ~MathBinaryBroadcastEpKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* z = ctx->Tensor4ArgNameAndIndex("z", 0); auto primitive = NewBroadcastElementwiseBinaryPrimitive(ctx); CHECK(primitive.get() != nullptr) << "Exceeds maximum supported dimensions"; const int64_t x_elem_cnt = x->shape_view().elem_cnt(); const int64_t y_elem_cnt = y->shape_view().elem_cnt(); size_t num_src0_dims = x->shape_view().NumAxes(); size_t num_src1_dims = y->shape_view().NumAxes(); int64_t zero_dim = 1; int64_t* src0_dims = const_cast(x->shape_view().ptr()); int64_t* src1_dims = const_cast(y->shape_view().ptr()); if (x_elem_cnt != 0 && y_elem_cnt != 0) { if (num_src0_dims == 0) { num_src0_dims = 1; src0_dims = &zero_dim; } if (num_src1_dims == 0) { num_src1_dims = 1; src1_dims = &zero_dim; } primitive->Launch(ctx->stream(), num_src0_dims, src0_dims, x->dptr(), num_src1_dims, src1_dims, y->dptr(), z->mut_dptr()); } else { // For 0-size Tensor return; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template auto MathBinaryBroadcastPrimitiveExists() { return hob::make_custom("MathBinaryBroadcastPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewBroadcastElementwiseBinaryPrimitive(&ctx). 
operator bool(); }); } #define REGISTER_BINARY_BROADCAST_EP_KERNEL(math_type_pair, binary_op) \ REGISTER_USER_KERNEL(math_type_pair) \ .SetCreateFn>() \ .SetIsMatchedHob(MathBinaryBroadcastPrimitiveExists() == true); REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_add", ep::primitive::BinaryOp::kAdd) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_sub", ep::primitive::BinaryOp::kSub) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_mul", ep::primitive::BinaryOp::kMul) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_div", ep::primitive::BinaryOp::kDiv) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_minimum", ep::primitive::BinaryOp::kMin) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_maximum", ep::primitive::BinaryOp::kMax) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_pow", ep::primitive::BinaryOp::kPow) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_equal", ep::primitive::BinaryOp::kEqual) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_not_equal", ep::primitive::BinaryOp::kNotEqual) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_greater", ep::primitive::BinaryOp::kGreaterThan) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_greater_equal", ep::primitive::BinaryOp::kGreaterEqual) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_less", ep::primitive::BinaryOp::kLessThan) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_less_equal", ep::primitive::BinaryOp::kLessEqual) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_isclose_eq_nan", ep::primitive::BinaryOp::kIsCloseEqualNan) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_isclose_neq_nan", ep::primitive::BinaryOp::kIsClose) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_logical_and", ep::primitive::BinaryOp::kLogicalAnd) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_logical_or", ep::primitive::BinaryOp::kLogicalOr) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_logical_xor", ep::primitive::BinaryOp::kLogicalXor) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_bitwise_and", ep::primitive::BinaryOp::kBitwiseAnd) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_bitwise_or", ep::primitive::BinaryOp::kBitwiseOr) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_bitwise_xor", ep::primitive::BinaryOp::kBitwiseXor) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_floor_mod", ep::primitive::BinaryOp::kFloorMod) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_fmod", ep::primitive::BinaryOp::kFmod) REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_zeta", ep::primitive::BinaryOp::kZeta) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/math_binary_elementwise_func.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_MATH_BINARY_ELEMENTWISE_FUNC_H_ #define ONEFLOW_USER_KERNELS_MATH_BINARY_ELEMENTWISE_FUNC_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/user/ops/math_binary_elementwise_seq.h" #include "oneflow/core/device/cuda_pseudo_half.h" #if defined(__CUDACC__) #include #define MATH_FUNC(name) name #else #include #define MATH_FUNC(name) std::name #endif namespace oneflow { #define DECLARE_BINARY_FUNCTOR(math_binary_elementwise_type, func_prefix) \ template \ struct func_prefix##Functor; OF_PP_FOR_EACH_TUPLE(DECLARE_BINARY_FUNCTOR, MATH_BINARY_ELEMENTWISE_FUNC_SEQ) template struct PowFunctor { static OF_DEVICE_FUNC const T Forward(const T x, const T y) { return MATH_FUNC(pow)(x, y); } static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) { return dz * y * (MATH_FUNC(pow)(x, y - T(1))); } static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) { if (x > T(0)) { return dz * MATH_FUNC(log)(x) * (MATH_FUNC(pow)(x, y)); } else { return T(0); } } }; template struct Atan2Functor { static OF_DEVICE_FUNC const T Forward(const T x, const T y) { return MATH_FUNC(atan2)(x, y); } static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) { return dz * (y / (x * x + y * y)); } static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) { return dz * -x / (y * y + x * x); } }; template struct FloorDivFunctor { static OF_DEVICE_FUNC const T Forward(const T x, const T y) { #if defined(__CUDACC__) return floor(fdividef(x, y)); #else return std::floor(x / y); #endif } static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) { return T(0); } static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) { return T(0); } }; template struct TruncDivFunctor { static OF_DEVICE_FUNC const T Forward(const T x, const T y) { #if defined(__CUDACC__) return trunc(fdividef(x, y)); #else return std::trunc(x / y); #endif } static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) { return T(0); } static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) { return T(0); } }; template struct XdivyFunctor { static OF_DEVICE_FUNC const T Forward(const T x, const T y) { if (T(0) == x) { return T(0); } else { return x / y; } } static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) { if (T(0) == x || T(0) == dz) { return T(0); } else { return dz / y; } } static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) { return dz * XdivyFunctor::Forward((-x), (y * y)); } }; template struct XlogyFunctor { static OF_DEVICE_FUNC const T Forward(const T x, const T y) { if (T(0) == x) { return T(0); } else { return x * MATH_FUNC(log)(y); } } static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) { if (T(0) == x || T(0) == dz) { return T(0); } else { return dz * MATH_FUNC(log)(y); } } static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) { return dz * XdivyFunctor::Forward(x, y); } }; #if defined(__CUDACC__) // half version #define OF_HALF_FUNC __device__ __forceinline__ #define MATH_FUNC_H_FW(name) __float2half(name(__half2float(x), __half2float(y))) #define MATH_FUNC_H_BW(name) __float2half(name(__half2float(x), __half2float(y), __half2float(dz))) template<> struct PowFunctor { static OF_HALF_FUNC const half Forward(const half x, const half y) { return MATH_FUNC_H_FW(PowFunctor::Forward); } static OF_HALF_FUNC const half 
BackwardXGrad(const half x, const half y, const half dz) { return MATH_FUNC_H_BW(PowFunctor::BackwardXGrad); } static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) { return MATH_FUNC_H_BW(PowFunctor::BackwardYGrad); } }; template<> struct Atan2Functor { static OF_HALF_FUNC const half Forward(const half x, const half y) { return MATH_FUNC_H_FW(Atan2Functor::Forward); } static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) { return __hmul(dz, __hdiv(y, __hadd(__hmul(y, y), __hmul(x, x)))); } static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) { return __hmul(dz, __hdiv(__hneg(x), __hadd(__hmul(y, y), __hmul(x, x)))); } }; template<> struct FloorDivFunctor { static OF_HALF_FUNC const half Forward(const half x, const half y) { return hfloor(__hdiv(x, y)); } static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) { return GetZeroVal(); } static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) { return GetZeroVal(); } }; template<> struct TruncDivFunctor { static OF_HALF_FUNC const half Forward(const half x, const half y) { return htrunc(__hdiv(x, y)); } static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) { return GetZeroVal(); } static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) { return GetZeroVal(); } }; template<> struct XdivyFunctor { static OF_HALF_FUNC const half Forward(const half x, const half y) { if (__heq(GetZeroVal(), x)) { return GetZeroVal(); } else { return __hdiv(x, y); } } static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) { if (__heq(GetZeroVal(), x)) { return GetZeroVal(); } else { return XdivyFunctor::Forward(dz, y); } } static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) { return __hmul(dz, XdivyFunctor::Forward(__hneg(x), __hmul(y, y))); } }; template<> struct XlogyFunctor { static OF_HALF_FUNC const half Forward(const half x, const half y) { if (__heq(GetZeroVal(), x)) { return GetZeroVal(); } else { return __hmul(x, hlog(y)); } } static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) { if (__heq(GetZeroVal(), x)) { return GetZeroVal(); } else { return XlogyFunctor::Forward(dz, y); } } static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) { return __hmul(dz, XdivyFunctor::Forward(x, y)); } }; #endif } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_MATH_BINARY_ELEMENTWISE_FUNC_H_ ================================================ FILE: oneflow/user/kernels/math_binary_elementwise_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/math_binary_elementwise_func.h" #include "oneflow/core/ep/cpu/cpu_stream.h" #include "oneflow/core/ep/cpu/cpu_device.h" namespace oneflow { template class BinaryFunctor, typename T> class MathBinaryElementwiseCpuKernel final : public user_op::OpKernel { public: MathBinaryElementwiseCpuKernel() = default; ~MathBinaryElementwiseCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0); const T* x = tensor_x->dptr(); const T* y = tensor_y->dptr(); T* z = tensor_z->mut_dptr(); int64_t n = tensor_x->shape_view().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); ep::CpuStream* cpu_stream = ctx->stream()->As(); cpu_stream->ParallelFor(0, n, [x, y, z](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { z[i] = BinaryFunctor::Forward(x[i], y[i]); } }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class BinaryFunctor, typename T> class MathBinaryElementwiseXGradCpuKernel final : public user_op::OpKernel { public: MathBinaryElementwiseXGradCpuKernel() = default; ~MathBinaryElementwiseXGradCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0); user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const T* x = tensor_x->dptr(); const T* y = tensor_y->dptr(); const T* dz = tensor_dz->dptr(); T* dx = tensor_dx->mut_dptr(); int64_t n = tensor_x->shape_view().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); for (int32_t i = 0; i < n; ++i) { dx[i] = BinaryFunctor::BackwardXGrad(x[i], y[i], dz[i]); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class BinaryFunctor, typename T> class MathBinaryElementwiseYGradCpuKernel final : public user_op::OpKernel { public: MathBinaryElementwiseYGradCpuKernel() = default; ~MathBinaryElementwiseYGradCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0); user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const T* x = tensor_x->dptr(); const T* y = tensor_y->dptr(); const T* dz = tensor_dz->dptr(); T* dy = tensor_dy->mut_dptr(); int64_t n = tensor_x->shape_view().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); for (int32_t i = 0; i < n; ++i) { dy[i] = BinaryFunctor::BackwardYGrad(x[i], y[i], dz[i]); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MATH_BINARY_ELEMENTWISE_CPU_KERNEL_AND_GRAD(math_type_pair, data_type_pair) \ REGISTER_USER_KERNEL(OF_PP_PAIR_FIRST(math_type_pair)) \ .SetCreateFn< \ MathBinaryElementwiseCpuKernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair))); \ \ REGISTER_USER_KERNEL((std::string("") + OF_PP_PAIR_FIRST(math_type_pair) + "_x_grad")) \ .SetCreateFn>() \ 
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair))); \ REGISTER_USER_KERNEL((std::string("") + OF_PP_PAIR_FIRST(math_type_pair) + "_y_grad")) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CPU_KERNEL_AND_GRAD, MATH_BINARY_ELEMENTWISE_FUNC_SEQ, FLOATING_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CPU_KERNEL_AND_GRAD, OF_PP_MAKE_TUPLE_SEQ("floordiv", FloorDiv) OF_PP_MAKE_TUPLE_SEQ("truncdiv", TruncDiv), INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/math_binary_elementwise_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/math_binary_elementwise_func.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template class BinaryFunctor, typename T> __global__ void MathBinaryElementwiseForwardGpu(const int64_t n, const T* x, const T* y, T* z) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { z[i] = BinaryFunctor::Forward(x[i], y[i]); } } template class BinaryFunctor, typename T> __global__ void MathBinaryElementwiseBackwardXGradGpu(const int64_t n, const T* x, const T* y, const T* dz, T* dx) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { dx[i] = BinaryFunctor::BackwardXGrad(x[i], y[i], dz[i]); } } template class BinaryFunctor, typename T> __global__ void MathBinaryElementwiseBackwardYGradGpu(const int64_t n, const T* x, const T* y, const T* dz, T* dy) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { dy[i] = BinaryFunctor::BackwardYGrad(x[i], y[i], dz[i]); } } } // namespace template class BinaryFunctor, typename T> class MathBinaryElementwiseGpuKernel final : public user_op::OpKernel { public: MathBinaryElementwiseGpuKernel() = default; ~MathBinaryElementwiseGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0); int64_t n = tensor_x->shape_view().elem_cnt(); if (n == 0) { return; } MathBinaryElementwiseForwardGpu <<stream()->As()->cuda_stream()>>>( n, tensor_x->dptr(), tensor_y->dptr(), tensor_z->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class BinaryFunctor, typename T> class MathBinaryElementwiseXGradGpuKernel final : public user_op::OpKernel { public: MathBinaryElementwiseXGradGpuKernel() = default; ~MathBinaryElementwiseXGradGpuKernel() = default; private: using user_op::OpKernel::Compute; void 
Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0); user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex("dx", 0); int64_t n = tensor_x->shape_view().elem_cnt(); if (n == 0) { return; } MathBinaryElementwiseBackwardXGradGpu <<stream()->As()->cuda_stream()>>>( n, tensor_x->dptr(), tensor_y->dptr(), tensor_dz->dptr(), tensor_dx->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class BinaryFunctor, typename T> class MathBinaryElementwiseYGradGpuKernel final : public user_op::OpKernel { public: MathBinaryElementwiseYGradGpuKernel() = default; ~MathBinaryElementwiseYGradGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0); user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0); int64_t n = tensor_x->shape_view().elem_cnt(); if (n == 0) { return; } MathBinaryElementwiseBackwardYGradGpu <<stream()->As()->cuda_stream()>>>( n, tensor_x->dptr(), tensor_y->dptr(), tensor_dz->dptr(), tensor_dy->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_KERNEL_AND_GRAD(math_type_pair, data_type_pair) \ REGISTER_USER_KERNEL(OF_PP_PAIR_FIRST(math_type_pair)) \ .SetCreateFn< \ MathBinaryElementwiseGpuKernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair))); \ \ REGISTER_USER_KERNEL((std::string("") + OF_PP_PAIR_FIRST(math_type_pair) + "_x_grad")) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair))); \ REGISTER_USER_KERNEL((std::string("") + OF_PP_PAIR_FIRST(math_type_pair) + "_y_grad")) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_KERNEL_AND_GRAD, MATH_BINARY_ELEMENTWISE_FUNC_SEQ, FLOATING_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_KERNEL_AND_GRAD, OF_PP_MAKE_TUPLE_SEQ("floordiv", FloorDiv) OF_PP_MAKE_TUPLE_SEQ("truncdiv", TruncDiv), INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ) template class BinaryFunctor> class MathBinaryElementwiseGpuHalfKernel final : public user_op::OpKernel { public: MathBinaryElementwiseGpuHalfKernel() = default; ~MathBinaryElementwiseGpuHalfKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0); const half* x = reinterpret_cast(tensor_x->dptr()); const half* y = reinterpret_cast(tensor_y->dptr()); half* z = reinterpret_cast(tensor_z->mut_dptr()); int64_t n = tensor_x->shape_view().elem_cnt(); if (n == 0) { return; } 
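    // One thread per element via a grid-stride loop: the launch below spreads the n
    // half-precision elements over the CUDA stream owned by the kernel's execution context.
    // The grid/block sizing helpers (BlocksNum4ThreadsNum, kCudaThreadsNumPerBlock) are the
    // usual OneFlow CUDA utilities assumed here; CUDA_1D_KERNEL_LOOP_T inside the __global__
    // kernel keeps the loop correct even when the grid does not cover n exactly.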
MathBinaryElementwiseForwardGpu<BinaryFunctor, half>
        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, z);
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

template<template<typename> class BinaryFunctor>
class MathBinaryElementwiseXGradGpuHalfKernel final : public user_op::OpKernel {
 public:
  MathBinaryElementwiseXGradGpuHalfKernel() = default;
  ~MathBinaryElementwiseXGradGpuHalfKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0);
    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0);
    const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0);
    user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
    const half* x = reinterpret_cast<const half*>(tensor_x->dptr<float16>());
    const half* y = reinterpret_cast<const half*>(tensor_y->dptr<float16>());
    const half* dz = reinterpret_cast<const half*>(tensor_dz->dptr<float16>());
    half* dx = reinterpret_cast<half*>(tensor_dx->mut_dptr<float16>());
    int64_t n = tensor_x->shape_view().elem_cnt();
    if (n == 0) { return; }
    MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, half>
        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, dz, dx);
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

template<template<typename> class BinaryFunctor>
class MathBinaryElementwiseYGradGpuHalfKernel final : public user_op::OpKernel {
 public:
  MathBinaryElementwiseYGradGpuHalfKernel() = default;
  ~MathBinaryElementwiseYGradGpuHalfKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0);
    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0);
    const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0);
    user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    const half* x = reinterpret_cast<const half*>(tensor_x->dptr<float16>());
    const half* y = reinterpret_cast<const half*>(tensor_y->dptr<float16>());
    const half* dz = reinterpret_cast<const half*>(tensor_dz->dptr<float16>());
    half* dy = reinterpret_cast<half*>(tensor_dy->mut_dptr<float16>());
    int64_t n = tensor_x->shape_view().elem_cnt();
    if (n == 0) { return; }
    MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, half>
        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, dz, dy);
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_HALF_KERNEL_AND_GRAD(math_type_str,                \
                                                                   math_func_prefix)             \
  REGISTER_USER_KERNEL(math_type_str)                                                            \
      .SetCreateFn<MathBinaryElementwiseGpuHalfKernel<OF_PP_CAT(math_func_prefix, Functor)>>()   \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                           \
                       && (user_op::HobDataType("x", 0) == DataType::kFloat16));                 \
                                                                                                  \
  REGISTER_USER_KERNEL((std::string("") + math_type_str + "_x_grad"))                            \
      .SetCreateFn<                                                                               \
          MathBinaryElementwiseXGradGpuHalfKernel<OF_PP_CAT(math_func_prefix, Functor)>>()       \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                           \
                       && (user_op::HobDataType("x", 0) == DataType::kFloat16));                 \
  REGISTER_USER_KERNEL((std::string("") + math_type_str + "_y_grad"))                            \
      .SetCreateFn<                                                                               \
          MathBinaryElementwiseYGradGpuHalfKernel<OF_PP_CAT(math_func_prefix, Functor)>>()       \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                           \
                       && (user_op::HobDataType("x", 0) == DataType::kFloat16));

OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_HALF_KERNEL_AND_GRAD,
                     MATH_BINARY_ELEMENTWISE_FUNC_SEQ)

}  // namespace oneflow



================================================
FILE: oneflow/user/kernels/math_unary_elementwise_func.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_ #define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_ #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/user/ops/math_unary_elementwise_seq.h" #include "oneflow/core/device/cuda_pseudo_half.h" #if defined(__CUDACC__) #include #define MATH_FUNC_F(name, x) name##f(x) #define MATH_FUNC_D(name, x) name(x) #else #include #define MATH_FUNC_F(name, x) std::name(x) #define MATH_FUNC_D(name, x) std::name(x) #endif namespace oneflow { #define DECLARE_UNARY_FUNCTOR(math_unary_elementwise_type, func_prefix) \ template \ struct func_prefix##Functor; OF_PP_FOR_EACH_TUPLE(DECLARE_UNARY_FUNCTOR, MATH_UNARY_ELEMENTWISE_FUNC_SEQ) template struct AbsFunctor { static OF_DEVICE_FUNC T Forward(const T x) { if (x == T(0)) return T(0); else return x < T(0) ? -x : x; } static OF_DEVICE_FUNC T Backward(const T x, const T dy) { if (x == T(0)) return T(0); else return x < T(0) ? -dy : dy; } }; template struct SignFunctor { static OF_DEVICE_FUNC T Forward(const T x) { return (T(0) < x) - (x < T(0)); } static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); } }; template<> struct RsqrtFunctor { static OF_DEVICE_FUNC float Forward(const float x) { #if defined(__CUDACC__) return rsqrtf(x); #else return 1.0f / std::sqrt(x); #endif } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x))); } }; template<> struct RsqrtFunctor { static OF_DEVICE_FUNC double Forward(const double x) { #if defined(__CUDACC__) return rsqrt(x); #else return 1.0 / std::sqrt(x); #endif } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x))); } }; // float version template<> struct AcosFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acos, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * -RsqrtFunctor::Forward(1.0f - x * x); } }; template<> struct AcoshFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acosh, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * RsqrtFunctor::Forward(x * x - 1.0f); } }; template<> struct AsinFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asin, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * RsqrtFunctor::Forward(1.0f - x * x); } }; template<> struct AsinhFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asinh, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * RsqrtFunctor::Forward(1.0f + x * x); } }; template<> struct AtanFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atan, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / (1.0f + x * x)); } }; template<> struct 
AtanhFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atanh, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / (1.0f - x * x)); } }; template<> struct NotEqualZeroFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } }; template<> struct CeilFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } }; template<> struct CosFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cos, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (-MATH_FUNC_F(sin, x)); } }; template<> struct CoshFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cosh, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * MATH_FUNC_F(sinh, x); } }; template<> struct ErfFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erf, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * 2.0f * RsqrtFunctor::Forward(M_PI) * expf(-x * x); } }; template<> struct ErfcFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erfc, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * -2.0f * RsqrtFunctor::Forward(M_PI) * expf(-x * x); } }; template<> struct ExpFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(exp, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * MATH_FUNC_F(exp, x); } }; template<> struct Expm1Functor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(expm1, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * MATH_FUNC_F(exp, x); } }; template<> struct FloorFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(floor, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } }; template<> struct LgammaFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(lgamma, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { // TODO(chengcheng): return: dy * digamma(x) assert(false); return 0.0f; } }; template<> struct LogFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / x); } }; template<> struct Log2Functor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f))); } }; template<> struct Log1pFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / (x + 1.0f)); } }; template<> struct LogSigmoidFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return -MATH_FUNC_F(log, (1.0f + MATH_FUNC_F(exp, -x))); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / (MATH_FUNC_F(exp, x) + 1.0f)); } }; template<> struct NegativeFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return -x; } 
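  // Backward receives the upstream gradient dy and returns dy * d(Forward)/dx;
  // the derivative of negation is -1, so the gradient is simply -dy.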
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return -dy; } }; template<> struct ReciprocalFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / x; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (-1.0f / (x * x)); } }; template<> struct ReciprocalNoNanFunctor { static OF_DEVICE_FUNC float Forward(const float x) { if (fabsf(x) <= 0.0f) { return 0.0f; } return 1.0f / x; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { if (fabsf(x) <= 0.0f) { return 0.0f; } return dy * (-1.0f / (x * x)); } }; template<> struct RintFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(rint, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } }; template<> struct RoundFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(nearbyint, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } }; template<> struct SigmoidFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / (1.0f + MATH_FUNC_F(exp, -x)); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { float y = 1.0f / (1.0f + MATH_FUNC_F(exp, -x)); return dy * (y * (1.0f - y)); } }; template<> struct SinFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sin, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * MATH_FUNC_F(cos, x); } }; template<> struct SinhFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sinh, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * MATH_FUNC_F(cosh, x); } }; template<> struct SqrtFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sqrt, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * 0.5f / MATH_FUNC_F(sqrt, x); } }; template<> struct SquareFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return x * x; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * 2.0f * x; } }; template<> struct TanFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(tan, x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / (MATH_FUNC_F(cos, x) * MATH_FUNC_F(cos, x))); } }; // double version template<> struct AcosFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acos, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * -RsqrtFunctor::Forward(1.0 - x * x); } }; template<> struct AcoshFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * -RsqrtFunctor::Forward(x * x - 1.0); } }; template<> struct AsinFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asin, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * RsqrtFunctor::Forward(1.0 - x * x); } }; template<> struct AsinhFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asinh, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * RsqrtFunctor::Forward(1.0 + x * x); } }; template<> struct AtanFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atan, x); 
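    // MATH_FUNC_D(atan, x) expands to the CUDA device function atan(x) under __CUDACC__ and to
    // std::atan(x) in host builds (see the MATH_FUNC_* macro definitions at the top of this file).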
} static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / (1.0 + x * x)); } }; template<> struct AtanhFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atanh, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / (1.0 - x * x)); } }; template<> struct NotEqualZeroFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0f; } }; template<> struct CeilFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; } }; template<> struct CosFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cos, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (-MATH_FUNC_D(sin, x)); } }; template<> struct CoshFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cosh, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * MATH_FUNC_D(sinh, x); } }; template<> struct ErfFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * 2.0 * RsqrtFunctor::Forward(M_PI) * expf(-x * x); } }; template<> struct ErfcFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * -2.0 * RsqrtFunctor::Forward(M_PI) * expf(-x * x); } }; template<> struct ExpFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(exp, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * MATH_FUNC_D(exp, x); } }; template<> struct Expm1Functor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(expm1, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * MATH_FUNC_D(exp, x); } }; template<> struct FloorFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(floor, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; } }; template<> struct LgammaFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(lgamma, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { // TODO(chengcheng): return: dy * digamma(x) assert(false); return 0.0; } }; template<> struct LogFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / x); } }; template<> struct Log2Functor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log2, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / (x * MATH_FUNC_D(log, 2.0))); } }; template<> struct Log1pFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log1p, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / (x + 1.0)); } }; template<> struct LogSigmoidFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return -MATH_FUNC_D(log, (1.0 + MATH_FUNC_D(exp, -x))); } static OF_DEVICE_FUNC double 
Backward(const double x, const double dy) { return dy * (1.0 / (MATH_FUNC_D(exp, x) + 1.0)); } }; template<> struct NegativeFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return -x; } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return -dy; } }; template<> struct ReciprocalFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return 1.0 / x; } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (-1.0 / (x * x)); } }; template<> struct ReciprocalNoNanFunctor { static OF_DEVICE_FUNC double Forward(const double x) { if (fabs(x) <= 0.0) { return 0.0; } return 1.0 / x; } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { if (fabs(x) <= 0.0) { return 0.0; } return dy * (-1.0 / (x * x)); } }; template<> struct RintFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(rint, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; } }; template<> struct RoundFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(nearbyint, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; } }; template<> struct SigmoidFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return 1.0 / (1.0 + MATH_FUNC_D(exp, -x)); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { double y = 1.0 / (1.0 + MATH_FUNC_D(exp, -x)); return dy * (y * (1.0 - y)); } }; template<> struct SinFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sin, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * MATH_FUNC_D(cos, x); } }; template<> struct SinhFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sinh, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * MATH_FUNC_D(cosh, x); } }; template<> struct SqrtFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sqrt, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (double)0.5 / MATH_FUNC_D(sqrt, x); } }; template<> struct SquareFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return x * x; } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * 2.0 * x; } }; template<> struct TanFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(tan, x); } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / (MATH_FUNC_D(cos, x) * MATH_FUNC_D(cos, x))); } }; #if defined(__CUDACC__) // half version #define OF_HALF_FUNC __device__ __forceinline__ #define MATH_FUNC_H(name, x) __float2half(name##f(__half2float(x))) #define HALF_VAL_HALF __float2half(0.5f) #define HALF_VAL_TWO __float2half(2.0f) #define HALF_VAL_2RSQRT_PI __float2half(1.1283791671f) template<> struct AbsFunctor { static OF_HALF_FUNC half Forward(const half x) { return __hlt(x, GetZeroVal()) ? __hneg(x) : x; } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hlt(x, GetZeroVal()) ? 
__hneg(dy) : dy; } }; template<> struct AcosFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acos, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hneg(hrsqrt(__hsub(GetOneVal(), __hmul(x, x))))); } }; template<> struct AcoshFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acosh, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrsqrt(__hsub(__hmul(x, x), GetOneVal()))); } }; template<> struct AsinFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asin, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrsqrt(__hsub(GetOneVal(), __hmul(x, x)))); } }; template<> struct AsinhFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asinh, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrsqrt(__hadd(GetOneVal(), __hmul(x, x)))); } }; template<> struct AtanFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atan, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hdiv(GetOneVal(), __hadd(GetOneVal(), __hmul(x, x)))); } }; template<> struct AtanhFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atanh, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hdiv(GetOneVal(), __hsub(GetOneVal(), __hmul(x, x)))); } }; template<> struct CeilFunctor { static OF_HALF_FUNC half Forward(const half x) { return hceil(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal(); } }; template<> struct NotEqualZeroFunctor { static OF_HALF_FUNC half Forward(const half x) { return x != static_cast(0.0); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal(); } }; template<> struct CosFunctor { static OF_HALF_FUNC half Forward(const half x) { return hcos(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hneg(hsin(x))); } }; template<> struct CoshFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(cosh, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, MATH_FUNC_H(sinh, x)); } }; template<> struct ErfFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erf, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x)))); } }; template<> struct ErfcFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erfc, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hneg(__hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x))))); } }; template<> struct ExpFunctor { static OF_HALF_FUNC half Forward(const half x) { return hexp(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); } }; template<> struct Expm1Functor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(expm1, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); } }; template<> struct FloorFunctor { static OF_HALF_FUNC half Forward(const half x) { return hfloor(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal(); } }; template<> struct LgammaFunctor { static OF_HALF_FUNC half Forward(const half 
x) { return MATH_FUNC_H(lgamma, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { // TODO(chengcheng): return: dy * digamma(x) assert(false); return GetZeroVal(); } }; template<> struct LogFunctor { static OF_HALF_FUNC half Forward(const half x) { return hlog(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(x)); } }; template<> struct Log2Functor { static OF_HALF_FUNC half Forward(const half x) { return hlog2(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(__hmul(x, hlog(HALF_VAL_TWO)))); } }; template<> struct Log1pFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(log1p, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(__hadd(x, GetOneVal()))); } }; template<> struct LogSigmoidFunctor { static OF_HALF_FUNC half Forward(const half x) { return __hneg(hlog(__hadd(GetOneVal(), hexp(__hneg(x))))); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(__hadd(hexp(x), GetOneVal()))); } }; template<> struct NegativeFunctor { static OF_HALF_FUNC half Forward(const half x) { return __hneg(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hneg(dy); } }; template<> struct ReciprocalFunctor { static OF_HALF_FUNC half Forward(const half x) { return hrcp(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hneg(hrcp(__hmul(x, x)))); } }; template<> struct ReciprocalNoNanFunctor { static OF_HALF_FUNC half Forward(const half x) { if (__heq(GetZeroVal(), x)) { return GetZeroVal(); } return hrcp(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { if (__heq(GetZeroVal(), x)) { return GetZeroVal(); } return __hmul(dy, __hneg(hrcp(__hmul(x, x)))); } }; template<> struct RintFunctor { static OF_HALF_FUNC half Forward(const half x) { return hrint(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal(); } }; template<> struct RoundFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(nearbyint, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal(); } }; template<> struct RsqrtFunctor { static OF_HALF_FUNC half Forward(const half x) { return hrsqrt(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hneg(hrcp(__hmul(HALF_VAL_TWO, hsqrt(__hmul(x, __hmul(x, x))))))); } }; template<> struct SigmoidFunctor { static OF_HALF_FUNC half Forward(const half x) { return hrcp(__hadd(GetOneVal(), hexp(__hneg(x)))); } static OF_HALF_FUNC half Backward(const half x, const half dy) { half y = hrcp(__hadd(GetOneVal(), hexp(__hneg(x)))); return __hmul(dy, __hmul(y, __hsub(GetOneVal(), y))); } }; template<> struct SignFunctor { static OF_HALF_FUNC half Forward(const half x) { if (__hgt(x, GetZeroVal())) { return GetOneVal(); } if (__hlt(x, GetZeroVal())) { return __hneg(GetOneVal()); } return GetZeroVal(); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal(); } }; template<> struct SinFunctor { static OF_HALF_FUNC half Forward(const half x) { return hsin(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hcos(x)); } }; template<> struct SinhFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(sinh, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, MATH_FUNC_H(cosh, 
x)); } }; template<> struct SqrtFunctor { static OF_HALF_FUNC half Forward(const half x) { return hsqrt(x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hdiv(HALF_VAL_HALF, hsqrt(x))); } }; template<> struct SquareFunctor { static OF_HALF_FUNC half Forward(const half x) { return __hmul(x, x); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, __hmul(HALF_VAL_TWO, x)); } }; template<> struct TanFunctor { static OF_HALF_FUNC half Forward(const half x) { return __hdiv(hsin(x), hcos(x)); } static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(__hmul(hcos(x), hcos(x)))); } }; #endif } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_ ================================================ FILE: oneflow/user/kernels/math_unary_elementwise_primitive_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/ep/include/primitive/binary_op.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/user/kernels/elementwise_primitive_kernel.h" namespace oneflow { #define MATH_UNARY_ELEMENTWISE_PRIMITIVE_SEQ \ OF_PP_MAKE_TUPLE_SEQ("abs", ep::primitive::UnaryOp::kAbs) \ OF_PP_MAKE_TUPLE_SEQ("acos", ep::primitive::UnaryOp::kAcos) \ OF_PP_MAKE_TUPLE_SEQ("acosh", ep::primitive::UnaryOp::kAcosh) \ OF_PP_MAKE_TUPLE_SEQ("asin", ep::primitive::UnaryOp::kAsin) \ OF_PP_MAKE_TUPLE_SEQ("asinh", ep::primitive::UnaryOp::kAsinh) \ OF_PP_MAKE_TUPLE_SEQ("atan", ep::primitive::UnaryOp::kAtan) \ OF_PP_MAKE_TUPLE_SEQ("atanh", ep::primitive::UnaryOp::kAtanh) \ OF_PP_MAKE_TUPLE_SEQ("ceil", ep::primitive::UnaryOp::kCeil) \ OF_PP_MAKE_TUPLE_SEQ("cos", ep::primitive::UnaryOp::kCos) \ OF_PP_MAKE_TUPLE_SEQ("cosh", ep::primitive::UnaryOp::kCosh) \ OF_PP_MAKE_TUPLE_SEQ("digamma", ep::primitive::UnaryOp::kDigamma) \ OF_PP_MAKE_TUPLE_SEQ("trigamma", ep::primitive::UnaryOp::kTrigamma) \ OF_PP_MAKE_TUPLE_SEQ("erf", ep::primitive::UnaryOp::kErf) \ OF_PP_MAKE_TUPLE_SEQ("erfc", ep::primitive::UnaryOp::kErfc) \ OF_PP_MAKE_TUPLE_SEQ("exp", ep::primitive::UnaryOp::kExp) \ OF_PP_MAKE_TUPLE_SEQ("exp2", ep::primitive::UnaryOp::kExp2) \ OF_PP_MAKE_TUPLE_SEQ("expm1", ep::primitive::UnaryOp::kExpm1) \ OF_PP_MAKE_TUPLE_SEQ("floor", ep::primitive::UnaryOp::kFloor) \ OF_PP_MAKE_TUPLE_SEQ("lgamma", ep::primitive::UnaryOp::kLgamma) \ OF_PP_MAKE_TUPLE_SEQ("log", ep::primitive::UnaryOp::kLog) \ OF_PP_MAKE_TUPLE_SEQ("log2", ep::primitive::UnaryOp::kLog2) \ OF_PP_MAKE_TUPLE_SEQ("log10", ep::primitive::UnaryOp::kLog10) \ OF_PP_MAKE_TUPLE_SEQ("log1p", ep::primitive::UnaryOp::kLog1p) \ OF_PP_MAKE_TUPLE_SEQ("log_sigmoid", ep::primitive::UnaryOp::kLogSigmoid) \ OF_PP_MAKE_TUPLE_SEQ("negative", ep::primitive::UnaryOp::kNegative) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal", ep::primitive::UnaryOp::kReciprocal) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan", ep::primitive::UnaryOp::kReciprocalNoNan) \ 
OF_PP_MAKE_TUPLE_SEQ("rint", ep::primitive::UnaryOp::kRint) \ OF_PP_MAKE_TUPLE_SEQ("round", ep::primitive::UnaryOp::kRound) \ OF_PP_MAKE_TUPLE_SEQ("rsqrt", ep::primitive::UnaryOp::kRsqrt) \ OF_PP_MAKE_TUPLE_SEQ("sigmoid", ep::primitive::UnaryOp::kSigmoid) \ OF_PP_MAKE_TUPLE_SEQ("sign", ep::primitive::UnaryOp::kSign) \ OF_PP_MAKE_TUPLE_SEQ("sin", ep::primitive::UnaryOp::kSin) \ OF_PP_MAKE_TUPLE_SEQ("sinh", ep::primitive::UnaryOp::kSinh) \ OF_PP_MAKE_TUPLE_SEQ("sqrt", ep::primitive::UnaryOp::kSqrt) \ OF_PP_MAKE_TUPLE_SEQ("square", ep::primitive::UnaryOp::kSquare) \ OF_PP_MAKE_TUPLE_SEQ("tan", ep::primitive::UnaryOp::kTan) \ OF_PP_MAKE_TUPLE_SEQ("not_equal_zero", ep::primitive::UnaryOp::kNotEqualZero) \ OF_PP_MAKE_TUPLE_SEQ("bitwise_not", ep::primitive::UnaryOp::kBitwiseNot) #define MATH_UNARY_ELEMENTWISE_GRAD_WITH_DY_X_PRIMITIVE_SEQ \ OF_PP_MAKE_TUPLE_SEQ("abs_grad", ep::primitive::BinaryOp::kAbsBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("acos_grad", ep::primitive::BinaryOp::kAcosBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("acosh_grad", ep::primitive::BinaryOp::kAcoshBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("asin_grad", ep::primitive::BinaryOp::kAsinBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("asinh_grad", ep::primitive::BinaryOp::kAsinhBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("atan_grad", ep::primitive::BinaryOp::kAtanBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("atanh_grad", ep::primitive::BinaryOp::kAtanhBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("cos_grad", ep::primitive::BinaryOp::kCosBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("cosh_grad", ep::primitive::BinaryOp::kCoshBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("digamma_grad", ep::primitive::BinaryOp::kDigammaBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("erf_grad", ep::primitive::BinaryOp::kErfBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("erfc_grad", ep::primitive::BinaryOp::kErfcBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("exp_grad", ep::primitive::BinaryOp::kExpBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("exp2_grad", ep::primitive::BinaryOp::kExp2BackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("expm1_grad", ep::primitive::BinaryOp::kExpm1BackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("log_grad", ep::primitive::BinaryOp::kLogBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("lgamma_grad", ep::primitive::BinaryOp::kLgammaBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("log2_grad", ep::primitive::BinaryOp::kLog2BackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("log10_grad", ep::primitive::BinaryOp::kLog10BackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("log1p_grad", ep::primitive::BinaryOp::kLog1pBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("log_sigmoid_grad", ep::primitive::BinaryOp::kLogSigmoidBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal_grad", ep::primitive::BinaryOp::kReciprocalBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan_grad", \ ep::primitive::BinaryOp::kReciprocalNoNanBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("rsqrt_grad", ep::primitive::BinaryOp::kRsqrtBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("sin_grad", ep::primitive::BinaryOp::kSinBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("sinh_grad", ep::primitive::BinaryOp::kSinhBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("sqrt_grad", ep::primitive::BinaryOp::kSqrtBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("square_grad", ep::primitive::BinaryOp::kSquareBackwardWithDyX) \ OF_PP_MAKE_TUPLE_SEQ("tan_grad", ep::primitive::BinaryOp::kTanBackwardWithDyX) #define MATH_UNARY_ELEMENTWISE_GRAD_WITH_DY_Y_PRIMITIVE_SEQ \ OF_PP_MAKE_TUPLE_SEQ("sigmoid_grad", ep::primitive::BinaryOp::kSigmoidBackwardWithDyY) #define REGISTER_MATH_UNARY_PRIMITIVE_KERNEL(name, UnaryOp) \ REGISTER_USER_KERNEL(name) \ 
.SetCreateFn([]() { \ return user_op::NewOpKernel( \ "y", "x", [](user_op::KernelComputeContext* ctx) { \ const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("y", 0); \ const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("x", 0); \ return ep::primitive::NewPrimitive( \ ctx->device_type(), UnaryOp, src->data_type(), dst->data_type()); \ }); \ }) \ .SetIsMatchedHob(UnaryPrimitiveExists(UnaryOp, "y", "x")); OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_PRIMITIVE_KERNEL, MATH_UNARY_ELEMENTWISE_PRIMITIVE_SEQ) #define REGISTER_MATH_UNARY_GRAD_PRIMITIVE_WITH_DY_X_KERNEL(name, BinaryOp) \ REGISTER_USER_KERNEL(name) \ .SetCreateFn([]() { \ return user_op::NewOpKernel< \ BinaryPrimitiveKernel>("dx", "dy", "x", [](user_op::KernelComputeContext* ctx) { \ const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); \ const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); \ return ep::primitive::NewPrimitive( \ ctx->device_type(), BinaryOp, src->data_type(), dst->data_type(), \ 1 /*max_num_dims*/); \ }); \ }) \ .SetIsMatchedHob(BinaryPrimitiveExists(BinaryOp, "dx", "dy")); OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_GRAD_PRIMITIVE_WITH_DY_X_KERNEL, MATH_UNARY_ELEMENTWISE_GRAD_WITH_DY_X_PRIMITIVE_SEQ) #define REGISTER_MATH_UNARY_GRAD_PRIMITIVE_WITH_DY_Y_KERNEL(name, BinaryOp) \ REGISTER_USER_KERNEL(name) \ .SetCreateFn([]() { \ return user_op::NewOpKernel< \ BinaryPrimitiveKernel>("dx", "dy", "y", [](user_op::KernelComputeContext* ctx) { \ const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0); \ const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); \ return ep::primitive::NewPrimitive( \ ctx->device_type(), BinaryOp, src->data_type(), dst->data_type(), \ 1 /*max_num_dims*/); \ }); \ }) \ .SetIsMatchedHob(BinaryPrimitiveExists(BinaryOp, "dx", "dy")); OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_GRAD_PRIMITIVE_WITH_DY_Y_KERNEL, MATH_UNARY_ELEMENTWISE_GRAD_WITH_DY_Y_PRIMITIVE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/matmul_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/framework/config_def.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/ep/include/primitive/batch_matmul.h" #include "oneflow/core/ep/include/primitive/broadcast_matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } template ep::primitive::BlasTransposeType GetBlasTransposeType(Context* ctx, const std::string& attr) { return GetBlasTransposeType(ctx->template Attr(attr)); } void InferMatmulMNK(const ShapeView& a_shape, const ShapeView& b_shape, const ShapeView& c_shape, ep::primitive::BlasTransposeType transpose_a, ep::primitive::BlasTransposeType transpose_b, size_t* m, size_t* n, size_t* k) { const int64_t num_a_axes = a_shape.NumAxes(); CHECK_GE(num_a_axes, 2); const int64_t num_b_axes = b_shape.NumAxes(); CHECK_GE(num_b_axes, 2); const int64_t num_c_axes = c_shape.NumAxes(); CHECK_GE(num_c_axes, 2); if (transpose_a == ep::primitive::BlasTransposeType::N) { *m = a_shape.At(num_a_axes - 2); *k = a_shape.At(num_a_axes - 1); } else if (transpose_a == ep::primitive::BlasTransposeType::T) { *m = a_shape.At(num_a_axes - 1); *k = a_shape.At(num_a_axes - 2); } else { UNIMPLEMENTED(); } if (transpose_b == ep::primitive::BlasTransposeType::N) { CHECK_EQ(b_shape.At(num_b_axes - 2), *k); *n = b_shape.At(num_b_axes - 1); } else if (transpose_b == ep::primitive::BlasTransposeType::T) { CHECK_EQ(b_shape.At(num_b_axes - 1), *k); *n = b_shape.At(num_b_axes - 2); } else { UNIMPLEMENTED(); } CHECK_EQ(c_shape.At(num_c_axes - 2), *m); CHECK_EQ(c_shape.At(num_c_axes - 1), *n); } template std::unique_ptr NewMemcpyPrimitive(Context* ctx) { return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::MemcpyKind::kDtoD); } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, ctx->template Attr("transpose_a"), ctx->template Attr("transpose_b")); } template std::unique_ptr NewBatchMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); const auto trans_a = GetBlasTransposeType(ctx, "transpose_a"); const auto trans_b = GetBlasTransposeType(ctx, "transpose_b"); return ep::primitive::NewPrimitive( ctx->device_type(), data_type, trans_a, trans_b); } template std::unique_ptr NewBroadcastMatmulPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); const auto trans_a = GetBlasTransposeType(ctx, "transpose_a"); const auto trans_b = GetBlasTransposeType(ctx, "transpose_b"); const int64_t a_num_axes = ctx->TensorDesc4ArgNameAndIndex("a", 0)->shape().NumAxes(); const int64_t b_num_axes = ctx->TensorDesc4ArgNameAndIndex("b", 0)->shape().NumAxes(); const int64_t max_num_axes = std::max(a_num_axes, b_num_axes); return ep::primitive::NewPrimitive( ctx->device_type(), data_type, trans_a, trans_b, max_num_axes); } auto MemcpyPrimitiveExists() { return hob::make_custom("MemcpyPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewMemcpyPrimitive(&ctx).operator bool(); }); } auto MatmulPrimitiveExists() { return hob::make_custom("MatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewMatmulPrimitive(&ctx).operator bool(); }); } auto BatchMatmulPrimitiveExists() { return hob::make_custom("BatchMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { 
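    // Registration-time guard: the batch_matmul kernel is matched only if a BatchMatmul
    // primitive can actually be constructed for this device/data-type combination; the
    // probe primitive is discarded right after the bool conversion.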
return NewBatchMatmulPrimitive(&ctx).operator bool(); }); } auto BroadcastMatmulPrimitiveExists() { return hob::make_custom("BroadcastMatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewBroadcastMatmulPrimitive(&ctx).operator bool(); }); } class MatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MatmulKernel() = default; ~MatmulKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const auto trans_a = GetBlasTransposeType(ctx, "transpose_a"); const auto trans_b = GetBlasTransposeType(ctx, "transpose_b"); const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); CHECK_EQ(a->shape_view().NumAxes(), 2); const DataType data_type = a->data_type(); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); CHECK_EQ(b->shape_view().NumAxes(), 2); CHECK_EQ(b->data_type(), data_type); const int32_t elem_cnt_a = a->shape_view().elem_cnt(); const int32_t elem_cnt_b = b->shape_view().elem_cnt(); CHECK_GE(elem_cnt_a, 0); CHECK_GE(elem_cnt_b, 0); if (elem_cnt_a == 0 || elem_cnt_b == 0) { return; } user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(out->shape_view().NumAxes(), 2); CHECK_EQ(out->data_type(), data_type); size_t m = 0, n = 0, k = 0; InferMatmulMNK(a->shape_view(), b->shape_view(), out->shape_view(), trans_a, trans_b, &m, &n, &k); const double alpha = ctx->Attr("alpha"); double beta = 0.0; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), data_type); CHECK_EQ(add_to_output->shape_view(), out->shape_view()); auto memcpy = NewMemcpyPrimitive(ctx); CHECK(memcpy); memcpy->Launch(ctx->stream(), out->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(data_type)); beta = 1.0; } auto matmul = NewMatmulPrimitive(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr()); } }; REGISTER_USER_KERNEL("matmul") .SetCreateFn() .SetIsMatchedHob(MemcpyPrimitiveExists() && MatmulPrimitiveExists()) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); class BatchMatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: BatchMatmulKernel() = default; ~BatchMatmulKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const auto trans_a = GetBlasTransposeType(ctx, "transpose_a"); const auto trans_b = GetBlasTransposeType(ctx, "transpose_b"); const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); const DataType data_type = a->data_type(); const int64_t num_axes = a->shape_view().NumAxes(); CHECK_GT(num_axes, 2); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); CHECK_EQ(b->data_type(), data_type); CHECK_EQ(b->shape_view().NumAxes(), num_axes); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(out->data_type(), data_type); CHECK_EQ(out->shape_view().NumAxes(), num_axes); size_t m = 0; size_t n = 0; size_t k = 0; InferMatmulMNK(a->shape_view(), b->shape_view(), out->shape_view(), trans_a, trans_b, &m, 
&n, &k); size_t batch_size = 1; for (size_t i = 0; i < num_axes - 2; ++i) { const int64_t dim_size = a->shape_view().At(i); CHECK_GT(dim_size, 0); CHECK_EQ(b->shape_view().At(i), dim_size); CHECK_EQ(out->shape_view().At(i), dim_size); batch_size *= dim_size; } const double alpha = ctx->Attr("alpha"); double beta = 0.0; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), data_type); CHECK_EQ(add_to_output->shape_view(), out->shape_view()); auto memcpy = NewMemcpyPrimitive(ctx); CHECK(memcpy); memcpy->Launch(ctx->stream(), out->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(data_type)); beta = 1.0; } auto batch_matmul = NewBatchMatmulPrimitive(ctx); CHECK(batch_matmul); batch_matmul->Launch(ctx->stream(), batch_size, m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr()); } }; REGISTER_USER_KERNEL("batch_matmul") .SetCreateFn() .SetIsMatchedHob(MemcpyPrimitiveExists() && BatchMatmulPrimitiveExists()) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); class BroadcastMatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: BroadcastMatmulKernel() = default; ~BroadcastMatmulKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { double alpha = ctx->Attr("alpha"); const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); double beta = 0.0; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->shape_view(), out->shape_view()); auto memcpy = NewMemcpyPrimitive(ctx); CHECK(memcpy); memcpy->Launch( ctx->stream(), out->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); beta = 1.0; } const int64_t a_num_axes = a->shape_view().NumAxes(); const int64_t b_num_axes = b->shape_view().NumAxes(); const int64_t out_num_axes = out->shape_view().NumAxes(); auto broadcast_matmul = NewBroadcastMatmulPrimitive(ctx); CHECK(broadcast_matmul); broadcast_matmul->Launch(ctx->stream(), alpha, a_num_axes, a->shape_view().ptr(), a->dptr(), b_num_axes, b->shape_view().ptr(), b->dptr(), beta, out_num_axes, out->shape_view().ptr(), out->mut_dptr()); } }; REGISTER_USER_KERNEL("broadcast_matmul") .SetCreateFn() .SetIsMatchedHob(MemcpyPrimitiveExists() && BroadcastMatmulPrimitiveExists()) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); template std::unique_ptr NewMatmulPrimitiveForBroadcastMatmulGradB(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, true, false); } class BroadcastMatmulGradBKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: 
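  // Computes the "broadcast_matmul_grad_b" op: all leading axes of both inputs are folded into
  // the reduction dimension k, and the result is out(m, n) = alpha * a^T * b + beta * out,
  // using a plain matmul primitive created with transpose_a = true
  // (see NewMatmulPrimitiveForBroadcastMatmulGradB above).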
BroadcastMatmulGradBKernel() = default; ~BroadcastMatmulGradBKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { double alpha = ctx->Attr("alpha"); const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); double beta = 0.0; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->shape_view(), out->shape_view()); auto memcpy = NewMemcpyPrimitive(ctx); CHECK(memcpy); memcpy->Launch( ctx->stream(), out->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); beta = 1.0; } CHECK_EQ(a->shape_view().NumAxes(), b->shape_view().NumAxes()); int64_t k = a->shape_view().Count(0, a->shape_view().NumAxes() - 1); CHECK_EQ(b->shape_view().Count(0, b->shape_view().NumAxes() - 1), k); int64_t m = a->shape_view().At(a->shape_view().NumAxes() - 1); int64_t n = b->shape_view().At(b->shape_view().NumAxes() - 1); auto matmul = NewMatmulPrimitiveForBroadcastMatmulGradB(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr()); } }; auto PrimitiveExistsForBroadcastMatmulGradB() { return hob::make_custom("MatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewMatmulPrimitiveForBroadcastMatmulGradB(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("broadcast_matmul_grad_b") .SetCreateFn() .SetIsMatchedHob(MemcpyPrimitiveExists() && PrimitiveExistsForBroadcastMatmulGradB()) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/matrix_vector_product_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/framework/config_def.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } template std::unique_ptr NewMemcpyPrimitive(Context* ctx) { return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::MemcpyKind::kDtoD); } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewMatrixVectorProductPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, false, false); } template std::unique_ptr NewMatrixVectorProductGradAPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, false, true); } template std::unique_ptr NewMatrixVectorProductGradBPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, true, false); } auto MatrixVectorProductPrimitiveExists() { return hob::make_custom("NewMatrixVectorProductPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewMatrixVectorProductPrimitive(&ctx).operator bool(); }); } auto MatrixVectorProductGradAPrimitiveExists() { return hob::make_custom("NewMatrixVectorProductGradAPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewMatrixVectorProductGradAPrimitive(&ctx).operator bool(); }); } auto MatrixVectorProductGradBPrimitiveExists() { return hob::make_custom("NewMatrixVectorProductGradBPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewMatrixVectorProductGradBPrimitive(&ctx).operator bool(); }); } class MatrixVectorProductKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MatrixVectorProductKernel() = default; ~MatrixVectorProductKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { /* A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m) */ const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); CHECK_EQ(a->shape_view().NumAxes(), 2) << "A Numdims should be equal to 2. "; const DataType data_type = a->data_type(); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); CHECK_EQ(b->shape_view().NumAxes(), 1) << "B Numdims should be equal to 1. "; CHECK_EQ(b->data_type(), data_type) << "Matrix A Datatype should be equal to Vector B"; user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(out->shape_view().NumAxes(), 1) << "Out Numdims should be equal to 1. "; CHECK_EQ(out->data_type(), data_type) << "Out Datatype should be equal to input's. 
"; size_t m = a->shape_view().At(0); size_t k = a->shape_view().At(1); size_t n = 1; const double alpha = 1.0; double beta = 0.0; auto matmul = NewMatrixVectorProductPrimitive(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr()); } }; REGISTER_USER_KERNEL("matrix_vector_product") .SetCreateFn() .SetIsMatchedHob(MatrixVectorProductPrimitiveExists()); class MatrixVectorProductGradAKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MatrixVectorProductGradAKernel() = default; ~MatrixVectorProductGradAKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { /* A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m) GradA = dy (m) matmul B(k) -> (m, 1) (k, 1)_transpose */ const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); size_t m = dy->shape_view().At(0); size_t k = 1; size_t n = b->shape_view().At(0); const double alpha = 1.0; double beta = 0.0; auto matmul = NewMatrixVectorProductGradAPrimitive(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, dy->dptr(), b->dptr(), beta, dx->mut_dptr()); } }; REGISTER_USER_KERNEL("matrix_vector_product_grad_a") .SetCreateFn() .SetIsMatchedHob(MatrixVectorProductGradAPrimitiveExists()); class MatrixVectorProductGradBKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MatrixVectorProductGradBKernel() = default; ~MatrixVectorProductGradBKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { /* A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m) GradB = dy_transpose (1, m) matmul A(m, k) */ const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); size_t m = 1; size_t k = dy->shape_view().At(0); size_t n = a->shape_view().At(1); const double alpha = 1.0; double beta = 0.0; auto matmul = NewMatrixVectorProductGradBPrimitive(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, dy->dptr(), a->dptr(), beta, dx->mut_dptr()); } }; REGISTER_USER_KERNEL("matrix_vector_product_grad_b") .SetCreateFn() .SetIsMatchedHob(MatrixVectorProductGradBPrimitiveExists()); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/max_pool_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/max_pool_kernel_util.h" namespace oneflow { struct PoolOpKernelCache final : public user_op::OpKernelCache { MaxPoolParams3D params_3d; explicit PoolOpKernelCache(const MaxPoolParams3D& params_3d) : params_3d(params_3d) {} const MaxPoolParams3D& GetParams3D() const { return params_3d; } }; std::shared_ptr CreatePoolOpKernelCache(user_op::KernelCacheContext* ctx, const int32_t& dim) { const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::vector& padding = ctx->Attr>("padding"); const std::vector& kernel_size = ctx->Attr>("kernel_size"); const std::vector& stride = ctx->Attr>("stride"); const std::vector& dilation = ctx->Attr>("dilation"); const bool return_indices = ctx->Attr("return_indices"); const bool ceil_mode = ctx->Attr("ceil_mode"); MaxPoolParams3D params_3d = MaxPoolParams3D(dim, x_shape, data_format, padding, kernel_size, stride, dilation, return_indices, ceil_mode); std::shared_ptr cache(new PoolOpKernelCache(params_3d)); return cache; } namespace { template void Maxpool2dForwardComputeCLast(const NdIndexOffsetHelper& index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_height, const int32_t x_width, const int32_t y_height, const int32_t y_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const int32_t dilation_h, const int32_t dilation_w) { IDX n = 0, h = 0, w = 0, c = 0; for (IDX num = 0; num < elem_num; ++num) { index_helper.OffsetToNdIndex(num, n, h, w, c); const IDX x_start_idx = n * x_height * x_width * n_channel; const IDX y_start_idx = n * y_height * y_width * n_channel; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; const IDX hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height ? (hstart + (kernel_size_h - 1) * dilation_h + 1) : x_height; const IDX wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width ? 
(wstart + (kernel_size_w - 1) * dilation_w + 1) : x_width; while (hstart < 0) { hstart += dilation_h; } while (wstart < 0) { wstart += dilation_w; } /* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */ IDX max_index = hstart * x_width + wstart; IDX src_idx = 0; /* equal to -std::numeric_limits::infinity(); */ T max_value = detail::numeric_limits::lower_bound(); for (IDX i = hstart; i < hend; i += dilation_h) { for (IDX j = wstart; j < wend; j += dilation_w) { const IDX window_idx = i * x_width * n_channel + j * n_channel + c; const IDX search_idx = x_start_idx + window_idx; T val = src[search_idx]; if (val > max_value || detail::numerics::isnan(val)) { max_value = val; max_index = window_idx; src_idx = search_idx; } } } const IDX out_idx = y_start_idx + h * y_width * n_channel + w * n_channel + c; dest[out_idx] = src[src_idx]; indice_ptr[out_idx] = max_index; } } } // namespace template struct PoolKernelUtil { static void Maxpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { Maxpool1dForwardCompute( index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.dilation_3d()[2]); } static void Maxpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { Maxpool1dBackwardCompute(index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), params_3d.num_channel(), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4)); } static void Maxpool2dForwardCFirst(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { Maxpool2dForwardComputeCFirst( index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); } static void Maxpool2dBackwardCFirst(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { Maxpool2dBackwardComputeCFirst( index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); } static void Maxpool2dForwardCLast(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { Maxpool2dForwardComputeCLast( index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); } static void 
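[Editor's note] The channels-last forward compute above clamps each pooling window to the input extent (`hend`/`wend`), advances negative starts by whole dilation steps so padded taps are skipped, and records both the running max and its flat index, with NaN treated as the maximum so it propagates. A standalone 1-D sketch of that window scan, assuming float data and negative infinity as the initial lower bound:

#include <cmath>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Max + argmax over one dilated, padded 1-D window, mirroring the hstart/hend logic above.
std::pair<float, int64_t> MaxOverWindow(const std::vector<float>& x, int64_t out_pos,
                                        int32_t kernel, int32_t stride, int32_t padding,
                                        int32_t dilation) {
  int64_t start = out_pos * stride - padding;
  int64_t end = std::min<int64_t>(start + (kernel - 1) * dilation + 1,
                                  static_cast<int64_t>(x.size()));
  while (start < 0) { start += dilation; }  // skip taps that fall in the padding
  float max_value = -std::numeric_limits<float>::infinity();
  int64_t max_index = start;
  for (int64_t i = start; i < end; i += dilation) {
    if (x[i] > max_value || std::isnan(x[i])) {  // NaN wins, as in the kernels above
      max_value = x[i];
      max_index = i;
    }
  }
  return {max_value, max_index};
}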
Maxpool2dBackwardCLast(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { Maxpool2dBackwardComputeCLast( index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); } static void Maxpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { Maxpool3dForwardCompute( index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0], params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[0], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); } static void Maxpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { Maxpool3dBackwardCompute(index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), params_3d.num_channel(), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); } }; template class MaxPool1dKernel final : public user_op::OpKernel { public: MaxPool1dKernel() = default; ~MaxPool1dKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreatePoolOpKernelCache(ctx, 1); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); const auto* pool_cache = dynamic_cast(cache); const MaxPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = y->shape_view().elem_cnt(); const T* src = x->dptr(); T* dest = y->mut_dptr(); int64_t* indice_ptr = indice->mut_dptr(); DimVector y_vector(2); y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1); y_vector.at(1) = y->shape_view().At(2); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(y_vector.data()); PoolKernelUtil::Maxpool1dForward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } else { NdIndexOffsetHelper index_helper(y_vector.data()); PoolKernelUtil::Maxpool1dForward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } } }; template class MaxPool1dGradKernel final : public user_op::OpKernel { public: MaxPool1dGradKernel() = default; ~MaxPool1dGradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreatePoolOpKernelCache(ctx, 1); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) 
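[Editor's note] Each pooling kernel above picks 32-bit or 64-bit index arithmetic at runtime: when the element count fits in the int32 range it builds a 32-bit NdIndexOffsetHelper, otherwise a 64-bit one, so the common case pays for narrower index math. A minimal sketch of that dispatch pattern (the lambda-based helper below is purely illustrative, not part of these sources):

#include <cstdint>
#include <limits>

// Run `func` with a 32-bit index type when the element count allows it, otherwise 64-bit.
template<typename Func>
void DispatchIndexType(int64_t elem_num, Func&& func) {
  if (elem_num < std::numeric_limits<int32_t>::max()) {
    func(int32_t{0});  // the argument only carries the index type
  } else {
    func(int64_t{0});
  }
}

// Usage sketch:
//   DispatchIndexType(elem_num, [&](auto tag) {
//     using IDX = decltype(tag);
//     // build an IDX-typed index helper and run the compute with it
//   });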
const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* pool_cache = dynamic_cast(cache); const MaxPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = dy->shape_view().elem_cnt(); const T* src = dy->dptr(); const int64_t* indice_ptr = indice->dptr(); T* dest = dx->mut_dptr(); DimVector dy_vector(2); dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1); dy_vector.at(1) = dy->shape_view().At(2); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(dy_vector.data()); PoolKernelUtil::Maxpool1dBackward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } else { NdIndexOffsetHelper index_helper(dy_vector.data()); PoolKernelUtil::Maxpool1dBackward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } }; }; template class MaxPool2dKernel final : public user_op::OpKernel { public: MaxPool2dKernel() = default; ~MaxPool2dKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreatePoolOpKernelCache(ctx, 2); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); const auto* pool_cache = dynamic_cast(cache); const MaxPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = y->shape_view().elem_cnt(); const T* src = x->dptr(); T* dest = y->mut_dptr(); int64_t* indice_ptr = indice->mut_dptr(); const std::string& data_format = ctx->Attr("data_format"); if (data_format == "channels_first") { DimVector y_vector(3); y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1); y_vector.at(1) = y->shape_view().At(2); y_vector.at(2) = y->shape_view().At(3); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(y_vector.data()); PoolKernelUtil::Maxpool2dForwardCFirst( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } else { NdIndexOffsetHelper index_helper(y_vector.data()); PoolKernelUtil::Maxpool2dForwardCFirst( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } } else if (data_format == "channels_last") { DimVector y_vector; y->shape_view().ToDimVector(&y_vector); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(y_vector.data()); PoolKernelUtil::Maxpool2dForwardCLast( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } else { NdIndexOffsetHelper index_helper(y_vector.data()); PoolKernelUtil::Maxpool2dForwardCLast( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } } else { UNIMPLEMENTED() << "Unsupported data_format"; } }; }; template class MaxPool2dGradKernel final : public user_op::OpKernel { public: MaxPool2dGradKernel() = default; ~MaxPool2dGradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreatePoolOpKernelCache(ctx, 2); } 
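[Editor's note] For `channels_first` data the 2-D kernels above fold batch and channel into one leading dimension before constructing the index helper, so an NCHW output is decomposed as (N*C, H, W); `channels_last` keeps the full NHWC shape because the channel coordinate is needed inside the window loop. A tiny sketch of that collapse:

#include <cstdint>
#include <vector>

// Collapse (N, C, H, W) into the (N*C, H, W) dim vector used for channels_first pooling.
std::vector<int64_t> CollapseBatchAndChannel(const std::vector<int64_t>& nchw) {
  return {nchw[0] * nchw[1], nchw[2], nchw[3]};
}

// Example: {8, 64, 32, 32} -> {512, 32, 32}; a flat output offset is then decomposed into
// (n_c, h, w), matching index_helper.OffsetToNdIndex(num, n_c, h, w) in the kernels above.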
private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* pool_cache = dynamic_cast(cache); const MaxPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = dy->shape_view().elem_cnt(); const T* src = dy->dptr(); const int64_t* indice_ptr = indice->dptr(); T* dest = dx->mut_dptr(); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); const std::string& data_format = ctx->Attr("data_format"); if (data_format == "channels_first") { DimVector dy_vector(3); dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1); dy_vector.at(1) = dy->shape_view().At(2); dy_vector.at(2) = dy->shape_view().At(3); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(dy_vector.data()); PoolKernelUtil::Maxpool2dBackwardCFirst( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } else { NdIndexOffsetHelper index_helper(dy_vector.data()); PoolKernelUtil::Maxpool2dBackwardCFirst( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } } else if (data_format == "channels_last") { DimVector dy_vector; dy->shape_view().ToDimVector(&dy_vector); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(dy_vector.data()); PoolKernelUtil::Maxpool2dBackwardCLast( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } else { NdIndexOffsetHelper index_helper(dy_vector.data()); PoolKernelUtil::Maxpool2dBackwardCLast( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } } else { UNIMPLEMENTED() << "Unsupported data_format"; } }; }; template class MaxPool3dKernel final : public user_op::OpKernel { public: MaxPool3dKernel() = default; ~MaxPool3dKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreatePoolOpKernelCache(ctx, 3); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); const auto* pool_cache = dynamic_cast(cache); const MaxPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = y->shape_view().elem_cnt(); const T* src = x->dptr(); T* dest = y->mut_dptr(); int64_t* indice_ptr = indice->mut_dptr(); DimVector y_vector(4); y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1); y_vector.at(1) = y->shape_view().At(2); y_vector.at(2) = y->shape_view().At(3); y_vector.at(3) = y->shape_view().At(4); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(y_vector.data()); PoolKernelUtil::Maxpool3dForward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } else { NdIndexOffsetHelper index_helper(y_vector.data()); PoolKernelUtil::Maxpool3dForward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } }; }; template class MaxPool3dGradKernel final : public user_op::OpKernel { public: MaxPool3dGradKernel() = default; ~MaxPool3dGradKernel() = 
default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreatePoolOpKernelCache(ctx, 3); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* pool_cache = dynamic_cast(cache); const MaxPoolParams3D& params_3d = pool_cache->GetParams3D(); const int64_t elem_num = dy->shape_view().elem_cnt(); const T* src = dy->dptr(); const int64_t* indice_ptr = indice->dptr(); T* dest = dx->mut_dptr(); DimVector dy_vector(4); dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1); dy_vector.at(1) = dy->shape_view().At(2); dy_vector.at(2) = dy->shape_view().At(3); dy_vector.at(3) = dy->shape_view().At(4); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(dy_vector.data()); PoolKernelUtil::Maxpool3dBackward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } else { NdIndexOffsetHelper index_helper(dy_vector.data()); PoolKernelUtil::Maxpool3dBackward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d); } }; }; #define REGISTER_POOL_KERNELS(device, dtype) \ REGISTER_USER_KERNEL("max_pool_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_pool_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_pool_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_pool_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_pool_3d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_pool_3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); #define REGISTER_POOL_WITH_DEVICE(device) \ REGISTER_POOL_KERNELS(device, int32_t) \ REGISTER_POOL_KERNELS(device, float) \ REGISTER_POOL_KERNELS(device, double) REGISTER_POOL_WITH_DEVICE(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_POOL_WITH_DEVICE(DeviceType::kCUDA) REGISTER_POOL_KERNELS(DeviceType::kCUDA, half) #endif OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_POOL_KERNEL_UTIL, (DeviceType::kCPU), POOL_DATA_TYPE_CPU_SEQ, POOL_IDX_DATA_TYPE_SEQ); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/max_pool_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #ifdef WITH_CUDA #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/max_pool_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { constexpr int kBlockSize = cuda::elementwise::kBlockSize; int GetMinThreadNum(int64_t elem_num) { return std::min(elem_num, kBlockSize); } int GetNumBlocks(int64_t elem_cnt) { int num_blocks = 0; OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks)); return num_blocks; } template __device__ __inline__ void Maxpool2dForwardComputeCLast( const NdIndexOffsetHelper& index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int64_t n_batch, const int64_t n_channel, const int64_t x_height, const int64_t x_width, const int64_t y_height, const int64_t y_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const int32_t dilation_h, const int32_t dilation_w) { IDX n, h, w, c; CUDA_1D_KERNEL_LOOP(num, elem_num) { index_helper.OffsetToNdIndex(num, n, h, w, c); const IDX x_start_idx = n * n_channel * x_width * x_height; const IDX y_start_idx = n * n_channel * y_height * y_width; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; const IDX hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height ? (hstart + (kernel_size_h - 1) * dilation_h + 1) : x_height; const IDX wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width ? 
(wstart + (kernel_size_w - 1) * dilation_w + 1) : x_width; while (hstart < 0) { hstart += dilation_h; } while (wstart < 0) { wstart += dilation_w; } /* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */ IDX max_index = hstart * x_width + wstart; IDX src_idx = 0; /* equal to -std::numeric_limits::infinity(); */ T max_value = detail::numeric_limits::lower_bound(); for (IDX i = hstart; i < hend; i += dilation_h) { for (IDX j = wstart; j < wend; j += dilation_w) { const IDX window_idx = i * x_width * n_channel + j * n_channel + c; const IDX search_idx = x_start_idx + window_idx; T val = src[search_idx]; if (val > max_value || detail::numerics::isnan(val)) { max_value = val; max_index = window_idx; src_idx = search_idx; } } } const IDX out_idx = y_start_idx + h * y_width * n_channel + w * n_channel + c; dest[out_idx] = src[src_idx]; indice_ptr[out_idx] = max_index; } } } // namespace template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxPool1dForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, int32_t padding_l, int32_t n_batch, int32_t n_channel, int32_t x_length, int32_t kernel_size_l, int32_t stride_l, int32_t dilation_l) { Maxpool1dForwardCompute(index_helper, elem_num, src, dest, indice_ptr, padding_l, n_batch, n_channel, x_length, kernel_size_l, stride_l, dilation_l); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxPool2dForwardCFirst(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, int32_t padding_h, int32_t padding_w, int32_t n_batch, int32_t n_channel, int32_t x_height, int32_t x_width, int32_t kernel_size_h, int32_t kernel_size_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w) { Maxpool2dForwardComputeCFirst( index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w, n_batch, n_channel, x_height, x_width, kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h, dilation_w); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxPool2dForwardCLast(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, int32_t padding_h, int32_t padding_w, int32_t n_batch, int32_t n_channel, int32_t x_height, int32_t x_width, int32_t y_height, int32_t y_width, int32_t kernel_size_h, int32_t kernel_size_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w) { Maxpool2dForwardComputeCLast(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w, n_batch, n_channel, x_height, x_width, y_height, y_width, kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h, dilation_w); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxPool3dForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, int32_t padding_t, int32_t padding_h, int32_t padding_w, int32_t n_batch, int32_t n_channel, int32_t x_time, int32_t x_height, int32_t x_width, int32_t kernel_size_t, int32_t kernel_size_h, int32_t kernel_size_w, int32_t stride_t, int32_t stride_h, int32_t stride_w, int32_t dilation_t, int32_t dilation_h, int32_t dilation_w) { Maxpool3dForwardCompute(index_helper, elem_num, src, dest, indice_ptr, padding_t, padding_h, padding_w, n_batch, n_channel, x_time, x_height, x_width, kernel_size_t, kernel_size_h, kernel_size_w, stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w); }; template __launch_bounds__(kBlockSize) __global__ void 
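[Editor's note] The CUDA kernels above are launched with at most `kBlockSize` threads per block and a block count obtained from `cuda::elementwise::GetNumBlocks`, relying on the grid-stride `CUDA_1D_KERNEL_LOOP` to cover any remaining elements. A host-side sketch of that sizing arithmetic; the block cap below is an illustrative assumption, the real helper derives its limit from the device:

#include <algorithm>
#include <cstdint>

constexpr int kBlockSizeSketch = 256;   // stand-in for cuda::elementwise::kBlockSize
constexpr int kMaxBlocksSketch = 8192;  // illustrative cap only

// Threads per block: never launch more threads than there are elements.
int MinThreadNum(int64_t elem_num) {
  return static_cast<int>(std::min<int64_t>(elem_num, kBlockSizeSketch));
}

// Blocks: ceil(elem_num / block size), clamped; the grid-stride loop absorbs the rest.
int NumBlocks(int64_t elem_num) {
  int64_t blocks = (elem_num + kBlockSizeSketch - 1) / kBlockSizeSketch;
  return static_cast<int>(std::min<int64_t>(blocks, kMaxBlocksSketch));
}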
DoCUDAMaxPool1dBackward(const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel, const int32_t src_length, const int32_t dst_length) { Maxpool1dBackwardCompute(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel, src_length, dst_length); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxPool2dBackwardCFirst(const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel, const int32_t src_height, const int32_t src_width, const int32_t dst_height, const int32_t dst_width) { Maxpool2dBackwardComputeCFirst(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel, src_height, src_width, dst_height, dst_width); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxPool2dBackwardCLast(const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel, const int32_t src_height, const int32_t src_width, const int32_t dst_height, const int32_t dst_width) { Maxpool2dBackwardComputeCLast(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel, src_height, src_width, dst_height, dst_width); }; template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxPool3dBackward(const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel, const int32_t src_time, const int32_t src_height, const int32_t src_width, const int32_t dst_time, const int32_t dst_height, const int32_t dst_width) { Maxpool3dBackwardCompute(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel, src_time, src_height, src_width, dst_time, dst_height, dst_width); }; template struct PoolKernelUtil { static void Maxpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { DoCUDAMaxPool1dForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.dilation_3d()[2]); } static void Maxpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { DoCUDAMaxPool1dBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), params_3d.num_channel(), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4)); } static void Maxpool2dForwardCFirst(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { DoCUDAMaxPool2dForwardCFirst<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); } static void Maxpool2dBackwardCFirst(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const 
T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { DoCUDAMaxPool2dBackwardCFirst<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); } static void Maxpool2dForwardCLast(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { DoCUDAMaxPool2dForwardCLast<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); } static void Maxpool2dBackwardCLast(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { DoCUDAMaxPool2dBackwardCLast<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); } static void Maxpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { DoCUDAMaxPool3dForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0], params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[0], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); } static void Maxpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) { DoCUDAMaxPool3dBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), params_3d.num_channel(), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_POOL_KERNEL_UTIL, (DeviceType::kCUDA), POOL_DATA_TYPE_CUDA_SEQ, POOL_IDX_DATA_TYPE_SEQ); } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/max_pool_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/utils/pool_util.h" #include "oneflow/user/kernels/max_pool_kernel_util.h" namespace oneflow { void GetWindowedOutputShape(int64_t input_size, int32_t filter_size, int32_t stride, int32_t padding, bool ceil_mode, int32_t dilation_rate, int64_t* output_ptr) { int64_t output_size = (input_size + 2 * padding - dilation_rate * (filter_size - 1) - 1 + stride + (ceil_mode ? stride - 1 : 0)) / stride; if (ceil_mode) { // ensure that the last pool starts inside the image // needed to avoid problems in ceil mode if ((output_size - 1) * stride >= input_size + padding) { --output_size; } } *output_ptr = output_size; } void Get3DOutputShape(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::vector& padding, const bool ceil_mode, std::vector dilation_rate, DimVector* out) { out->clear(); out->resize(3); FOR_RANGE(size_t, i, 0, 3) { int64_t* out_ptr = &(*out).at(i); GetWindowedOutputShape(in.at(i), pool_size.at(i), strides.at(i), padding.at(i), ceil_mode, dilation_rate.at(i), out_ptr); } } MaxPoolParams3D::MaxPoolParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format, const std::vector& padding, const std::vector& kernel_size, const std::vector& stride, const std::vector& dilation, const bool return_indices, const bool ceil_mode) : dim_(dim), data_format_(data_format), padding_(Get3DVec(padding, dim)), pool_size_3d_(Get3DVec(kernel_size, dim)), stride_3d_(Get3DVec(stride, dim)), dilation_3d_(Get3DVec(dilation, dim)), return_indices_(return_indices), ceil_mode_(ceil_mode) { x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim), GetInDim(x_shape, data_format, 2, dim)}; Get3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_); if (data_format == "channels_first") { channel_num_ = x_shape.At(1); } else { CHECK_EQ(data_format_, "channels_last") << "data_format must be 'channels_first' or 'channels_last'"; channel_num_ = x_shape.At(x_shape.NumAxes() - 1); } batch_num_ = x_shape.At(0); } void MaxPoolParams3D::Reset(const ShapeView& x_shape) { x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_), GetInDim(x_shape, data_format_, 2, dim_)}; Get3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_); } Shape MaxPoolParams3D::GetYShape() const { DimVector y_dim_vec; if (dim_ == 1) { y_dim_vec = {y_3d_.at(2)}; } else if (dim_ == 2) { y_dim_vec = {y_3d_.at(1), y_3d_.at(2)}; } else if (dim_ == 3) { y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}; } else { UNIMPLEMENTED(); } if (data_format_ == "channels_first") { y_dim_vec.insert(y_dim_vec.begin(), channel_num_); } else { CHECK_EQ(data_format_, "channels_last") << "data_format must be 'channels_first' or 'channels_last'"; y_dim_vec.insert(y_dim_vec.end(), channel_num_); } y_dim_vec.insert(y_dim_vec.begin(), batch_num_); return Shape(y_dim_vec); } Shape MaxPoolParams3D::GetXShape5D() const { return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)}); } Shape MaxPoolParams3D::GetYShape5D() const { return Shape({batch_num_, 
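[Editor's note] GetWindowedOutputShape above implements the standard pooling size rule: floor mode computes `(in + 2*pad - dilation*(kernel-1) - 1) / stride + 1`, while ceil mode adds `stride - 1` to the numerator and then drops the final window if it would start entirely inside the padding. A standalone restatement with a worked example:

#include <cstdint>

// Pooling output length, equivalent to GetWindowedOutputShape above.
int64_t PoolOutputSize(int64_t in, int32_t kernel, int32_t stride, int32_t pad, bool ceil_mode,
                       int32_t dilation) {
  int64_t numerator = in + 2 * pad - dilation * (kernel - 1) - 1 + (ceil_mode ? stride - 1 : 0);
  int64_t out = numerator / stride + 1;
  if (ceil_mode && (out - 1) * stride >= in + pad) { --out; }  // last pool must start inside the image
  return out;
}

// Example: in=5, kernel=2, stride=2, pad=0, dilation=1 -> floor mode gives 2, ceil mode gives 3.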
channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}); } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/max_pool_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_POOL_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_POOL_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/operator/operator_util.h" #include "oneflow/core/kernel/util/numerics.cuh" #include "oneflow/core/kernel/util/numeric_limits.cuh" #ifdef WITH_CUDA #include "oneflow/core/cuda/atomic.cuh" #endif // WITH_CUDA namespace oneflow { #define POOL_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define POOL_IDX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define POOL_DATA_TYPE_CPU_SEQ POOL_DATA_TYPE_SEQ #define POOL_DATA_TYPE_CUDA_SEQ POOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) typedef small_vector FixedDimVector; template struct DeviceAdd { OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { #if defined(__CUDA_ARCH__) cuda::atomic::Add(y, *x); #else *y += *x; #endif }; }; class MaxPoolParams3D { public: MaxPoolParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format, const std::vector& padding, const std::vector& kernel_size, const std::vector& stride, const std::vector& dilation, const bool return_indices, const bool ceil_mode); ~MaxPoolParams3D() = default; const std::string& data_format() const { return data_format_; } const std::vector& padding() const { return padding_; } const std::vector& pool_size_3d() const { return pool_size_3d_; } const std::vector& stride_3d() const { return stride_3d_; } const std::vector& dilation_3d() const { return dilation_3d_; } const bool& return_indices() const { return return_indices_; } const bool& ceil_mode() const { return ceil_mode_; } const int32_t& num_batch() const { return batch_num_; } const int32_t& num_channel() const { return channel_num_; } void Reset(const ShapeView& x_shape); Shape GetYShape() const; Shape GetXShape5D() const; Shape GetYShape5D() const; private: int32_t dim_; FixedDimVector x_3d_; FixedDimVector y_3d_; std::string data_format_; std::vector padding_; std::vector pool_size_3d_; std::vector stride_3d_; std::vector dilation_3d_; bool return_indices_; bool ceil_mode_; int32_t batch_num_; int32_t channel_num_; }; template struct PoolKernelUtil { static void Maxpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d); static void Maxpool1dBackward(ep::Stream* stream, const 
NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d); static void Maxpool2dForwardCFirst(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d); static void Maxpool2dBackwardCFirst(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d); static void Maxpool2dForwardCLast(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d); static void Maxpool2dBackwardCLast(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d); static void Maxpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const MaxPoolParams3D& params_3d); static void Maxpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const MaxPoolParams3D& params_3d); }; template OF_DEVICE_FUNC void Maxpool1dForwardCompute(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const int32_t padding_l, const int32_t n_batch, const int32_t n_channel, const int32_t x_length, const int32_t kernel_size_l, const int32_t stride_l, const int32_t dilation_l) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, l; index_helper.OffsetToNdIndex(num, n_c, l); IDX lstart = l * stride_l - padding_l; const IDX lend = (lstart + (kernel_size_l - 1) * dilation_l + 1) <= x_length ? 
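[Editor's note] The element-wise computes above walk flat output offsets and use the index helper's OffsetToNdIndex to recover the multi-dimensional coordinates; for the 1-D kernels that is a single div/mod against the innermost length. A minimal sketch of the two-dimensional decomposition:

#include <cstdint>

// Recover (n_c, l) from a flat offset over an (N*C, L) view, as OffsetToNdIndex does
// for the 1-D pooling kernels above.
inline void OffsetTo2dIndex(int64_t offset, int64_t length, int64_t* n_c, int64_t* l) {
  *n_c = offset / length;
  *l = offset % length;
}

// Example: offset 1031 with L = 512 -> n_c = 2, l = 7.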
(lstart + (kernel_size_l - 1) * dilation_l + 1) : x_length; while (lstart < 0) { lstart += dilation_l; } /* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */ IDX max_index = lstart; /* equal to -std::numeric_limits::infinity(); */ T max_value = detail::numeric_limits::lower_bound(); const T* data = src + n_c * x_length; for (IDX idx = lstart; idx < lend; idx += dilation_l) { const IDX window_idx = idx; T val = data[window_idx]; if (val > max_value || detail::numerics::isnan(val)) { max_value = val; max_index = idx; } } dest[num] = max_value; indice_ptr[num] = max_index; } } template OF_DEVICE_FUNC void Maxpool1dBackwardCompute(const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel, const int32_t src_length, const int32_t dst_length) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, l; index_helper.OffsetToNdIndex(num, n_c, l); const IDX src_start = n_c * src_length; const IDX dst_start = n_c * dst_length; const IDX index = src_start + l; const IDX max_index = dst_start + indice_ptr[index]; if (max_index != -1) { /* update gradient, equals to dest[max_index] += src[index]; */ DeviceAdd::Invoke(src + index, dest + max_index); } } } template OF_DEVICE_FUNC void Maxpool2dForwardComputeCFirst( const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h, const int32_t stride_w, const int32_t dilation_h, const int32_t dilation_w) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, h, w; index_helper.OffsetToNdIndex(num, n_c, h, w); IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; const IDX hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height ? (hstart + (kernel_size_h - 1) * dilation_h + 1) : x_height; const IDX wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width ? (wstart + (kernel_size_w - 1) * dilation_w + 1) : x_width; while (hstart < 0) { hstart += dilation_h; } while (wstart < 0) { wstart += dilation_w; } /* equal to -std::numeric_limits::infinity(); */ T max_value = detail::numeric_limits::lower_bound(); /* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */ IDX max_index = hstart * x_width + wstart; const T* data = src + n_c * x_width * x_height; for (IDX i = hstart; i < hend; i += dilation_h) { for (IDX j = wstart; j < wend; j += dilation_w) { const IDX window_idx = i * x_width + j; T val = data[window_idx]; /* NOTE: std::isnan(val) only supports a few data types, see: https://en.cppreference.com/w/cpp/numeric/math/isnan and when use gcc/g++ 4.x to compile, the following exception will be throw: new_kernel_util.cu:24] Check failed: cudaMemcpyAsync(dst, src, sz, cudaMemcpyDefault, ctx->cuda_stream() ) : unspecified launch failure (719) but if use gcc/g++ 7.x to compile, everything is ok! the exact reason is still unknown! 
*/ if (val > max_value || detail::numerics::isnan(val)) { max_index = window_idx; max_value = val; } } } dest[num] = max_value; indice_ptr[num] = max_index; } } template OF_DEVICE_FUNC void Maxpool2dBackwardComputeCFirst( const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel, const int32_t src_height, const int32_t src_width, const int32_t dst_height, const int32_t dst_width) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, h, w; index_helper.OffsetToNdIndex(num, n_c, h, w); const IDX src_start = n_c * src_height * src_width; const IDX dst_start = n_c * dst_height * dst_width; const IDX index = src_start + h * src_width + w; const IDX max_index = dst_start + indice_ptr[index]; if (max_index != -1) { /* update gradient, equals to dest[max_index] += src[index]; */ DeviceAdd::Invoke(src + index, dest + max_index); } } } template OF_DEVICE_FUNC void Maxpool2dBackwardComputeCLast(const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel, const int32_t src_height, const int32_t src_width, const int32_t dst_height, const int32_t dst_width) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n, c, h, w; index_helper.OffsetToNdIndex(num, n, c, h, w); const IDX src_start = n * src_height * src_width * n_channel; const IDX dst_start = n * dst_height * dst_width * n_channel; const IDX index = src_start + h * src_width + w; const IDX max_index = dst_start + indice_ptr[index]; if (max_index != -1) { /* update gradient, equals to dest[max_index] += src[index]; */ DeviceAdd::Invoke(src + index, dest + max_index); } } } template OF_DEVICE_FUNC void Maxpool3dForwardCompute( const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, int64_t* indice_ptr, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height, const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, const int32_t dilation_t, const int32_t dilation_h, const int32_t dilation_w) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, t, h, w; index_helper.OffsetToNdIndex(num, n_c, t, h, w); IDX tstart = t * stride_t - padding_t; IDX hstart = h * stride_h - padding_h; IDX wstart = w * stride_w - padding_w; const IDX t1 = tstart + (kernel_size_t - 1) * dilation_t + 1; const IDX t2 = hstart + (kernel_size_h - 1) * dilation_h + 1; const IDX t3 = wstart + (kernel_size_w - 1) * dilation_w + 1; const IDX tend = t1 <= x_time ? t1 : x_time; const IDX hend = t2 <= x_height ? t2 : x_height; const IDX wend = t3 <= x_width ? 
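[Editor's note] Every max-pool backward compute above is a scatter-add: the upstream gradient at each output position is added to the input cell recorded in `indice_ptr`, routed through `DeviceAdd` so the CUDA build uses an atomic add (overlapping windows may share an argmax) while the CPU build is a plain `+=` into a zero-initialised `dx`. A CPU-side reference sketch for the 1-D case:

#include <cstdint>
#include <vector>

// dx[indice[i]] += dy[i] for every output element i of one (n, c) plane.
// On CUDA the += must be atomic because different output windows can point at the same input cell.
void MaxPool1dBackwardReference(const std::vector<float>& dy, const std::vector<int64_t>& indice,
                                std::vector<float>& dx /* zero-filled, length = input length */) {
  for (size_t i = 0; i < dy.size(); ++i) { dx[indice[i]] += dy[i]; }
}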
t3 : x_width; while (tstart < 0) { tstart += dilation_t; } while (hstart < 0) { hstart += dilation_h; } while (wstart < 0) { wstart += dilation_w; } IDX max_index = tstart * x_height * x_width + hstart * x_width + wstart; const T* data = src + n_c * x_time * x_width * x_height; T max_value = detail::numeric_limits::lower_bound(); for (IDX zi = tstart; zi < tend; zi += dilation_t) { for (IDX i = hstart; i < hend; i += dilation_h) { for (IDX j = wstart; j < wend; j += dilation_w) { const IDX window_idx = zi * x_height * x_width + i * x_width + j; T val = data[window_idx]; if (val > max_value || detail::numerics::isnan(val)) { max_value = val; max_index = window_idx; } } } /* set output to local max */ dest[num] = max_value; /* store location of max */ indice_ptr[num] = max_index; } } } template OF_DEVICE_FUNC void Maxpool3dBackwardCompute(const NdIndexOffsetHelper index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel, const int32_t src_time, const int32_t src_height, const int32_t src_width, const int32_t dst_time, const int32_t dst_height, const int32_t dst_width) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX n_c, t, h, w; index_helper.OffsetToNdIndex(num, n_c, t, h, w); const IDX src_start = n_c * src_time * src_height * src_width; const IDX dst_start = n_c * dst_time * dst_height * dst_width; const IDX index = src_start + t * src_height * src_width + h * src_width + w; const IDX max_index = dst_start + indice_ptr[index]; if (max_index != -1) { DeviceAdd::Invoke(src + index, dest + max_index); } } } #define INSTANTIATE_POOL_KERNEL_UTIL(device_type_v, dtype_pair, index_dtype_pair) \ template struct PoolKernelUtil; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_POOL_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/max_unpool_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "fmt/core.h" #include "oneflow/core/common/bfloat16.h" #include "oneflow/core/common/throw.h" #include "oneflow/user/kernels/max_unpool_kernel_util.h" namespace oneflow { namespace { template void MaxUnpoolNdForwardOrBackward(const NdIndexOffsetHelper& index_helper, const IDX elem_num, const int64_t* indice_ptr, const int64_t hwd_size, const int64_t out_elem_num, const F& f) { XPU_1D_KERNEL_LOOP(num, elem_num) { IDX bc_idx, hwd_idx; index_helper.OffsetToNdIndex(num, bc_idx, hwd_idx); IDX idx = bc_idx * hwd_size + indice_ptr[num]; CHECK_OR_THROW(idx >= 0 && idx < out_elem_num) << fmt::format( "Found an invalid max index: {}, output volumes are of size {}", idx, out_elem_num); f(num, idx); } } } // namespace template struct UnpoolKernelUtil { static void MaxUnpoolNdForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int64_t y_hwd_size, const int64_t y_elem_num) { MaxUnpoolNdForwardOrBackward(index_helper, elem_num, indice_ptr, y_hwd_size, y_elem_num, [&](int64_t num, IDX idx) { dest[idx] = src[num]; }); } static void MaxUnpoolNdBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int64_t dy_hwd_size, const int64_t dy_elem_num) { MaxUnpoolNdForwardOrBackward(index_helper, elem_num, indice_ptr, dy_hwd_size, dy_elem_num, [&](int64_t num, IDX idx) { dest[num] = src[idx]; }); } }; template class MaxUnpoolNdKernel final : public user_op::OpKernel { public: MaxUnpoolNdKernel() = default; ~MaxUnpoolNdKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int64_t elem_num = x->shape_view().elem_cnt(); const T* src = x->dptr(); const int64_t* indice_ptr = indice->dptr(); T* dest = y->mut_dptr(); DimVector x_vector(2); x_vector.at(0) = x->shape_view().At(0) * x->shape_view().At(1); int64_t y_hwd_size = 1; x_vector.at(1) = std::accumulate(x->shape_view().begin() + 2, x->shape_view().end(), 1, std::multiplies()); y_hwd_size = std::accumulate(y->shape_view().begin() + 2, y->shape_view().end(), 1, std::multiplies()); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); memset_primitive->Launch(ctx->stream(), dest, 0, y->shape_view().elem_cnt() * GetSizeOfDataType(y->data_type())); const int64_t y_elem_num = y->shape_view().elem_cnt(); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(x_vector.data()); UnpoolKernelUtil::MaxUnpoolNdForward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, y_hwd_size, y_elem_num); } else { NdIndexOffsetHelper index_helper(x_vector.data()); UnpoolKernelUtil::MaxUnpoolNdForward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, y_hwd_size, y_elem_num); } } }; template class MaxUnpoolNdGradKernel final : public user_op::OpKernel { public: MaxUnpoolNdGradKernel() = default; ~MaxUnpoolNdGradKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* indice = 
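[Editor's note] Max-unpooling above is the inverse scatter: each input element is written to `bc_idx * hwd_size + indices[num]` in a zero-filled output, and the CPU path throws via CHECK_OR_THROW when an index falls outside the output volume (the CUDA kernels instead skip out-of-range indices). A standalone sketch of the forward scatter over flattened (N*C, spatial) buffers:

#include <cstdint>
#include <stdexcept>
#include <vector>

// y[bc * y_hwd + indices[i]] = x[i], with y zero-filled beforehand.
// x_hwd and y_hwd are the flattened spatial extents of the input and output respectively.
void MaxUnpoolForwardReference(const std::vector<float>& x, const std::vector<int64_t>& indices,
                               int64_t x_hwd, int64_t y_hwd, std::vector<float>& y) {
  for (size_t i = 0; i < x.size(); ++i) {
    int64_t bc = static_cast<int64_t>(i) / x_hwd;
    int64_t dst = bc * y_hwd + indices[i];
    if (dst < 0 || dst >= static_cast<int64_t>(y.size())) {
      throw std::runtime_error("Found an invalid max index");  // CPU kernel uses CHECK_OR_THROW
    }
    y[dst] = x[i];
  }
}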
ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t elem_num = dx->shape_view().elem_cnt(); const T* src = dy->dptr(); const int64_t* indice_ptr = indice->dptr(); T* dest = dx->mut_dptr(); DimVector dx_vector(2); dx_vector.at(0) = dx->shape_view().At(0) * dx->shape_view().At(1); int64_t dy_hwd_size = 1; dx_vector.at(1) = std::accumulate(dx->shape_view().begin() + 2, dx->shape_view().end(), 1, std::multiplies()); dy_hwd_size = std::accumulate(dy->shape_view().begin() + 2, dy->shape_view().end(), 1, std::multiplies()); const int64_t dy_elem_num = dy->shape_view().elem_cnt(); if (elem_num < GetMaxVal()) { NdIndexOffsetHelper index_helper(dx_vector.data()); UnpoolKernelUtil::MaxUnpoolNdBackward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, dy_hwd_size, dy_elem_num); } else { NdIndexOffsetHelper index_helper(dx_vector.data()); UnpoolKernelUtil::MaxUnpoolNdBackward( ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, dy_hwd_size, dy_elem_num); } }; }; #define REGISTER_UNPOOL_KERNELS(device, dtype) \ REGISTER_USER_KERNEL("max_unpool_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_unpool_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_unpool_3d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_unpool_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_unpool_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("max_unpool_3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNPOOL_KERNEL_UTIL, (DeviceType::kCPU), UNPOOL_DATA_TYPE_CPU_SEQ, UNPOOL_IDX_DATA_TYPE_SEQ); #define REGISTER_UNPOOL_WITH_DEVICE(device) \ REGISTER_UNPOOL_KERNELS(device, int32_t) \ REGISTER_UNPOOL_KERNELS(device, int64_t) \ REGISTER_UNPOOL_KERNELS(device, float) \ REGISTER_UNPOOL_KERNELS(device, double) REGISTER_UNPOOL_WITH_DEVICE(DeviceType::kCPU) REGISTER_UNPOOL_KERNELS(DeviceType::kCPU, float16) REGISTER_UNPOOL_KERNELS(DeviceType::kCPU, bfloat16) #ifdef WITH_CUDA REGISTER_UNPOOL_WITH_DEVICE(DeviceType::kCUDA) REGISTER_UNPOOL_KERNELS(DeviceType::kCUDA, half) #if CUDA_VERSION >= 11000 REGISTER_UNPOOL_KERNELS(DeviceType::kCUDA, nv_bfloat16) #endif // CUDA_VERSION >= 11000 #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/max_unpool_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/max_unpool_kernel_util.h" #include namespace oneflow { namespace { constexpr int kBlockSize = cuda::elementwise::kBlockSize; int GetMinThreadNum(int64_t elem_num) { return std::min(elem_num, kBlockSize); } int GetNumBlocks(int64_t elem_cnt) { int num_blocks = 0; OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks)); return num_blocks; } } // namespace template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxUnpoolNdForward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int64_t y_hwd_size, const int64_t y_elem_num) { CUDA_1D_KERNEL_LOOP_T(IDX, num, elem_num) { IDX bc_idx, hwd_idx; index_helper.OffsetToNdIndex(num, bc_idx, hwd_idx); IDX dest_idx = bc_idx * y_hwd_size + indice_ptr[num]; if (dest_idx >= 0 && dest_idx < y_elem_num) { dest[dest_idx] = src[num]; } } } template __launch_bounds__(kBlockSize) __global__ void DoCUDAMaxUnpoolNdBackward(const NdIndexOffsetHelper index_helper, IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int64_t dy_hwd_size, const int64_t dy_elem_num) { CUDA_1D_KERNEL_LOOP_T(IDX, num, elem_num) { IDX bc_idx, hwd_idx; index_helper.OffsetToNdIndex(num, bc_idx, hwd_idx); IDX src_idx = bc_idx * dy_hwd_size + indice_ptr[num]; if (src_idx >= 0 && src_idx < dy_elem_num) { dest[num] = src[src_idx]; } else { dest[num] = 0.0f; } } } template struct UnpoolKernelUtil { static void MaxUnpoolNdForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int64_t y_hwd_size, const int64_t y_elem_num) { DoCUDAMaxUnpoolNdForward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, y_hwd_size, y_elem_num); } static void MaxUnpoolNdBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int64_t dy_hwd_size, const int64_t dy_elem_num) { DoCUDAMaxUnpoolNdBackward<<As()->cuda_stream()>>>( index_helper, elem_num, src, dest, indice_ptr, dy_hwd_size, dy_elem_num); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNPOOL_KERNEL_UTIL, (DeviceType::kCUDA), UNPOOL_DATA_TYPE_CUDA_SEQ, UNPOOL_IDX_DATA_TYPE_SEQ); #if CUDA_VERSION >= 11000 OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNPOOL_KERNEL_UTIL, (DeviceType::kCUDA), OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16), UNPOOL_IDX_DATA_TYPE_SEQ); #endif // CUDA_VERSION >= 11000 } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/max_unpool_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/operator/operator_util.h" #include "oneflow/user/utils/pool_util.h" #include "oneflow/user/kernels/max_unpool_kernel_util.h" namespace oneflow { namespace { void GetWindowedOutputShape(int64_t input_size, int32_t filter_size, int32_t stride, int32_t padding, int64_t* output_ptr) { int64_t output_size = (input_size - 1) * stride - 2 * padding + filter_size; *output_ptr = output_size; } void Get3DOutputShape(const DimVector& in, const std::vector& pool_size, const std::vector& strides, const std::vector& padding, DimVector* out) { out->clear(); out->resize(3); FOR_RANGE(size_t, i, 0, 3) { int64_t* out_ptr = &(*out).at(i); GetWindowedOutputShape(in.at(i), pool_size.at(i), strides.at(i), padding.at(i), out_ptr); } } } // namespace MaxUnpoolParams3D::MaxUnpoolParams3D(const int32_t dim, const ShapeView& x_shape, const std::vector& padding, const std::vector& kernel_size, const std::vector& stride) : dim_(dim), padding_(Get3DVec(padding, dim)), pool_size_3d_(Get3DVec(kernel_size, dim)), stride_3d_(Get3DVec(stride, dim)), batch_num_(x_shape.At(0)), channel_num_(x_shape.At(1)) { std::string data_format = "channels_first"; x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim), GetInDim(x_shape, data_format, 2, dim)}; Get3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, &y_3d_); } void MaxUnpoolParams3D::Reset(const ShapeView& x_shape) { std::string data_format = "channels_first"; x_3d_ = {GetInDim(x_shape, data_format, 0, dim_), GetInDim(x_shape, data_format, 1, dim_), GetInDim(x_shape, data_format, 2, dim_)}; Get3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, &y_3d_); } int64_t MaxUnpoolParams3D::GetYStride() const { return y_3d_.at(0) * y_3d_.at(1) * y_3d_.at(2); } Shape MaxUnpoolParams3D::GetYShape() const { DimVector y_dim_vec; if (dim_ == 1) { y_dim_vec = {y_3d_.at(2)}; } else if (dim_ == 2) { y_dim_vec = {y_3d_.at(1), y_3d_.at(2)}; } else if (dim_ == 3) { y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}; } else { UNIMPLEMENTED(); } y_dim_vec.insert(y_dim_vec.begin(), channel_num_); y_dim_vec.insert(y_dim_vec.begin(), batch_num_); return Shape(y_dim_vec); } Shape MaxUnpoolParams3D::GetXShape5D() const { return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)}); } Shape MaxUnpoolParams3D::GetYShape5D() const { return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}); } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/max_unpool_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_UNPOOL_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_UNPOOL_KERNEL_UTIL_H_ #include "oneflow/core/ndarray/xpu_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { #define UNPOOL_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) \ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define UNPOOL_IDX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define UNPOOL_DATA_TYPE_CPU_SEQ \ UNPOOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) \ OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16) #define UNPOOL_DATA_TYPE_CUDA_SEQ \ UNPOOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) // OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16) typedef small_vector FixedDimVector; class MaxUnpoolParams3D { public: MaxUnpoolParams3D(const int32_t dim, const ShapeView& x_shape, const std::vector& padding, const std::vector& kernel_size, const std::vector& stride); ~MaxUnpoolParams3D() = default; const std::vector& padding() const { return padding_; } const std::vector& pool_size_3d() const { return pool_size_3d_; } const std::vector& stride_3d() const { return stride_3d_; } const int32_t& num_batch() const { return batch_num_; } const int32_t& num_channel() const { return channel_num_; } void Reset(const ShapeView& x_shape); Shape GetYShape() const; Shape GetXShape5D() const; Shape GetYShape5D() const; int64_t GetYStride() const; private: int32_t dim_; FixedDimVector x_3d_; FixedDimVector y_3d_; std::vector padding_; std::vector pool_size_3d_; std::vector stride_3d_; int32_t batch_num_; int32_t channel_num_; }; template struct UnpoolKernelUtil { static void MaxUnpoolNdForward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int64_t y_hwd_size, const int64_t y_elem_num); static void MaxUnpoolNdBackward(ep::Stream* stream, const NdIndexOffsetHelper& index_helper, const IDX elem_num, const T* src, T* dest, const int64_t* indice_ptr, const int64_t dy_hwd_size, const int64_t dy_elem_num); }; #define INSTANTIATE_UNPOOL_KERNEL_UTIL(device_type_v, dtype_pair, index_dtype_pair) \ template struct UnpoolKernelUtil; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_UNPOOL_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/median_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template class CpuMedianKernel final : public user_op::OpKernel { public: CpuMedianKernel() = default; ~CpuMedianKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("input", 0); const int64_t size = in->shape_view().elem_cnt(); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("output", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); T* out_ptr = out->mut_dptr(); Memcpy(ctx->stream(), tmp_buffer->mut_dptr(), in->dptr(), size * sizeof(T)); T* first = tmp_buffer->mut_dptr(); T* last = first + size; T* median = first + (size - 1) / 2; std::nth_element(first, median, last); *out_ptr = *median; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_MEDIAN_KERNEL(dtype) \ REGISTER_USER_KERNEL("median") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ return ctx->InputShape("input", 0).elem_cnt() * sizeof(dtype); \ }); REGISTER_CPU_MEDIAN_KERNEL(float) REGISTER_CPU_MEDIAN_KERNEL(double) REGISTER_CPU_MEDIAN_KERNEL(int8_t) REGISTER_CPU_MEDIAN_KERNEL(uint8_t) REGISTER_CPU_MEDIAN_KERNEL(int32_t) REGISTER_CPU_MEDIAN_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/median_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/radix_sort.cuh" namespace oneflow { template class CudaMedianKernel final : public user_op::OpKernel { public: CudaMedianKernel() = default; ~CudaMedianKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("output", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int32_t instance_size = in->shape_view().elem_cnt(); const size_t sort_tensor_buffer_bytes = GetCudaAlignedSize(instance_size * sizeof(T)); SortKeysAscending( in->dptr(), 1, instance_size, reinterpret_cast(tmp_buffer->mut_dptr() + sort_tensor_buffer_bytes), tmp_buffer->shape_view().elem_cnt() - sort_tensor_buffer_bytes, tmp_buffer->mut_dptr(), ctx->stream()->As()->cuda_stream()); Memcpy(ctx->stream(), out->mut_dptr(), tmp_buffer->mut_dptr() + (instance_size - 1) / 2, sizeof(T)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_MEDIAN_KERNEL(dtype) \ REGISTER_USER_KERNEL("median") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const Shape& in_shape = ctx->InputShape("input", 0); \ const int32_t instance_size = in_shape.elem_cnt(); \ size_t sort_tmp_buffer_bytes = \ InferTempStorageForSortKeysAscending(1, instance_size); \ size_t sort_tensor_buffer_bytes = GetCudaAlignedSize(instance_size * sizeof(dtype)); \ return sort_tmp_buffer_bytes + sort_tensor_buffer_bytes; \ }); REGISTER_CUDA_MEDIAN_KERNEL(float) REGISTER_CUDA_MEDIAN_KERNEL(double) REGISTER_CUDA_MEDIAN_KERNEL(int8_t) REGISTER_CUDA_MEDIAN_KERNEL(uint8_t) REGISTER_CUDA_MEDIAN_KERNEL(int32_t) REGISTER_CUDA_MEDIAN_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/median_with_indices_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { template class CpuMedianWithIndicesKernel final : public user_op::OpKernel { public: CpuMedianWithIndicesKernel() = default; ~CpuMedianWithIndicesKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("input", 0); const int64_t num_axes = in->shape_view().NumAxes(); const int64_t size = in->shape_view().elem_cnt(); if (size == 0) return; const int64_t stride = in->shape_view().At(num_axes - 1); const int64_t instance_num = size / stride; user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex("values", 0); user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); Memcpy(ctx->stream(), tmp_buffer->mut_dptr(), in->dptr(), size * sizeof(T)); const int64_t thread_num = std::min(instance_num, (int64_t)Singleton::Get()->thread_num()); const BalancedSplitter bs(instance_num, thread_num); BlockingCounter bc(thread_num); FOR_RANGE(int64_t, thread_id, 0, thread_num) { const Range range = bs.At(thread_id); Singleton::Get()->AddWork([=, &bc]() { FOR_RANGE(int64_t, i, range.begin(), range.end()) { T* in_ptr = tmp_buffer->mut_dptr() + i * stride; T* val_ptr = values->mut_dptr() + i; int64_t* ind_ptr = indices->mut_dptr() + i; std::vector idx(stride); auto first = idx.begin(); auto last = idx.end(); std::iota(first, last, 0); auto nth = first; nth += (stride - 1) / 2; std::nth_element(first, nth, last, [&in_ptr](int64_t i, int64_t j) { return in_ptr[i] < in_ptr[j] || (in_ptr[i] == in_ptr[j] && i < j); }); *val_ptr = in_ptr[*nth]; *ind_ptr = *nth; } bc.Decrease(); }); } bc.WaitForeverUntilCntEqualZero(); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(dtype) \ REGISTER_USER_KERNEL("median_with_indices") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ return ctx->InputShape("input", 0).elem_cnt() * sizeof(dtype); \ }); REGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(float) REGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(double) REGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(int8_t) REGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(uint8_t) REGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(int32_t) REGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/median_with_indices_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/radix_sort.cuh" namespace oneflow { namespace { template __global__ void MedianSelectCuda(const IDX reduce_elem_cnt, const IDX stride, const T* in, const int64_t* sort_indices, T* values, int64_t* indices) { IDX nth = (stride - 1) / 2; CUDA_1D_KERNEL_LOOP_T(IDX, i, reduce_elem_cnt) { values[i] = in[i * stride + nth]; indices[i] = sort_indices[i * stride + nth]; } } bool IsSafeUseIndex32(int64_t elem_cnt) { return elem_cnt < GetMaxVal() / 2; } template void DispatchIndexSize(ep::Stream* stream, const int64_t elem_cnt, const int64_t stride, const T* in, const int64_t* sort_indices, T* out, int64_t* out_indices) { const int64_t reduce_elem_cnt = elem_cnt / stride; if (IsSafeUseIndex32(elem_cnt)) { RUN_CUDA_KERNEL((MedianSelectCuda), stream, reduce_elem_cnt, reduce_elem_cnt, stride, in, sort_indices, out, out_indices); } else { RUN_CUDA_KERNEL((MedianSelectCuda), stream, reduce_elem_cnt, reduce_elem_cnt, stride, in, sort_indices, out, out_indices); } } template class TmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager); TmpBufferManager(size_t capacity, void* ptr, const ShapeView& in_shape) : capacity_{capacity}, sorted_in_elem_cnt_{in_shape.elem_cnt()}, indices_elem_cnt_{sorted_in_elem_cnt_} { const size_t sort_tensor_buffer_bytes = GetCudaAlignedSize(sorted_in_elem_cnt_ * sizeof(T)); const size_t sort_indices_buffer_bytes = GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int64_t)); sorted_in_ptr_ = reinterpret_cast(ptr); in_indices_ptr_ = reinterpret_cast(reinterpret_cast(sorted_in_ptr_) + sort_tensor_buffer_bytes); out_indices_ptr_ = reinterpret_cast(reinterpret_cast(in_indices_ptr_) + sort_indices_buffer_bytes); temp_storage_ptr_ = reinterpret_cast(reinterpret_cast(out_indices_ptr_) + sort_indices_buffer_bytes); temp_storage_bytes_ = capacity_ - sort_tensor_buffer_bytes - sort_indices_buffer_bytes * 2; CHECK_GE(temp_storage_bytes_, 0); } ~TmpBufferManager() = default; T* SortedInPtr() const { return sorted_in_ptr_; } int64_t* InIndicesPtr() const { return in_indices_ptr_; } int64_t* OutIndicesPtr() const { return out_indices_ptr_; } void* TempStoragePtr() const { return temp_storage_ptr_; } size_t TempStorageBytes() const { return temp_storage_bytes_; } private: size_t capacity_; T* sorted_in_ptr_; int64_t* in_indices_ptr_; int64_t* out_indices_ptr_; void* temp_storage_ptr_; int64_t sorted_in_elem_cnt_; int64_t indices_elem_cnt_; size_t temp_storage_bytes_; }; __global__ void InitializeIndices(int64_t elem_cnt, int64_t* indices_ptr, int64_t instance_size) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i % instance_size; }; } } // namespace template class CudaMedianWithIndicesKernel final : public user_op::OpKernel { public: CudaMedianWithIndicesKernel() = default; ~CudaMedianWithIndicesKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("input", 0); if (in->shape_view().elem_cnt() == 0) return; user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex("values", 0); user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buf_manager(tmp_buffer->shape_view().elem_cnt(), tmp_buffer->mut_dptr(), in->shape_view()); const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t instance_size = 
in->shape_view().At(in->shape_view().NumAxes() - 1); const int64_t instance_num = elem_cnt / instance_size; RUN_CUDA_KERNEL(InitializeIndices, ctx->stream(), elem_cnt, elem_cnt, buf_manager.InIndicesPtr(), instance_size); SortPairsAscending(in->dptr(), buf_manager.InIndicesPtr(), instance_num, instance_size, buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(), buf_manager.SortedInPtr(), buf_manager.OutIndicesPtr(), ctx->stream()->As()->cuda_stream()); DispatchIndexSize(ctx->stream(), elem_cnt, instance_size, buf_manager.SortedInPtr(), buf_manager.OutIndicesPtr(), values->mut_dptr(), indices->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(dtype) \ REGISTER_USER_KERNEL("median_with_indices") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const Shape& in_shape = ctx->InputShape("input", 0); \ const int64_t instance_size = in_shape.dim_vec().back(); \ const int64_t instance_num = in_shape.elem_cnt() / instance_size; \ size_t sort_tmp_buffer_bytes = \ InferTempStorageForSortPairsAscending(instance_num, instance_size); \ size_t sort_tensor_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype)); \ size_t sort_indices_buffer_bytes = \ GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(int64_t)); \ return sort_tmp_buffer_bytes + sort_tensor_buffer_bytes + sort_indices_buffer_bytes * 2; \ }); REGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(float) REGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(double) REGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(int8_t) REGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(uint8_t) REGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(int32_t) REGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/min_max_observer_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include namespace oneflow { template void GenQuantScaleSymmetric(const T* in_ptr, const int32_t quantization_bit, const int64_t num_elements, T* scale, T* zero_point) { T in_max = *std::max_element(in_ptr, in_ptr + num_elements); T in_min = *std::min_element(in_ptr, in_ptr + num_elements); in_max = std::max(std::abs(in_max), std::abs(in_min)); T denominator = static_cast(pow(2.0, quantization_bit - 1)) - 1; *scale = in_max / denominator; *zero_point = 0; } template void GenQuantScaleAffine(const T* in_ptr, const int32_t quantization_bit, const int64_t num_elements, T* scale, T* zero_point) { T in_max = *std::max_element(in_ptr, in_ptr + num_elements); T in_min = *std::min_element(in_ptr, in_ptr + num_elements); T denominator = static_cast(pow(2.0, quantization_bit)) - 1; *scale = (in_max - in_min) / denominator; *zero_point = -std::nearbyint(in_min / (*scale)); } template void GenQuantScaleCambricon(const T* in_ptr, const int32_t quantization_bit, const int64_t num_elements, T* scale, T* zero_point) { T in_max = *std::max_element(in_ptr, in_ptr + num_elements); T in_min = *std::min_element(in_ptr, in_ptr + num_elements); in_max = std::max(std::abs(in_max), std::abs(in_min)); *scale = std::floor(std::log2(in_max)) - (quantization_bit - 2); *zero_point = 0; } template class CpuMinMaxObserverKernel final : public user_op::OpKernel { public: CpuMinMaxObserverKernel() = default; ~CpuMinMaxObserverKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0); const std::string quantization_scheme = ctx->Attr("quantization_scheme"); const int32_t quantization_bit = ctx->Attr("quantization_bit"); const bool per_layer_quantization = ctx->Attr("per_layer_quantization"); const std::string quantization_formula = ctx->Attr("quantization_formula"); const T* in_ptr = in->dptr(); T* scale_ptr = scale->mut_dptr(); T* zero_point_ptr = zero_point->mut_dptr(); if (quantization_formula == "google") { // NOTE(Liang Depeng): per-layer quantization by default int64_t outer_num = 1; int64_t inner_num = in->shape_view().elem_cnt(); if (!per_layer_quantization) { // per-channel quantization outer_num = in->shape_view().At(0); inner_num = in->shape_view().Count(1); } if (quantization_scheme == "symmetric") { FOR_RANGE(int64_t, c, 0, outer_num) { GenQuantScaleSymmetric(in_ptr, quantization_bit, inner_num, scale_ptr, zero_point_ptr); in_ptr += inner_num; scale_ptr += 1; zero_point_ptr += 1; } } else { // quantization_scheme == "affine" FOR_RANGE(int64_t, c, 0, outer_num) { GenQuantScaleAffine(in_ptr, quantization_bit, inner_num, scale_ptr, zero_point_ptr); in_ptr += inner_num; scale_ptr += 1; zero_point_ptr += 1; } } } else if (quantization_formula == "cambricon") { if (!per_layer_quantization) { UNIMPLEMENTED() << " per-channel mode is not supported in cambricon scheme"; } GenQuantScaleCambricon(in_ptr, quantization_bit, in->shape_view().elem_cnt(), scale_ptr, zero_point_ptr); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MIN_MAX_OBSERVER_KERNEL(dtype) \ REGISTER_USER_KERNEL("min_max_observer") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) 
REGISTER_MIN_MAX_OBSERVER_KERNEL(float); REGISTER_MIN_MAX_OBSERVER_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/min_max_observer_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { // NOTE(Liang Depeng): refer to // https://stackoverflow.com/questions/17371275/implementing-max-reduce-in-cuda template __global__ void ReduceMaxMinPerLayer(const T* input_ptr, const int64_t elements, T* max_ptr, T* min_ptr) { extern __shared__ unsigned char shared_max_min_memory[]; T* shared_max = reinterpret_cast(shared_max_min_memory); T* shared_min = shared_max + blockDim.x; int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; shared_max[tid] = -FLT_MAX; shared_min[tid] = -FLT_MAX; while (gid < elements) { shared_max[tid] = max(shared_max[tid], input_ptr[gid]); shared_min[tid] = max(shared_min[tid], -input_ptr[gid]); gid += gridDim.x * blockDim.x; } __syncthreads(); gid = (blockDim.x * blockIdx.x) + tid; for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { if (tid < s && gid < elements) { shared_max[tid] = max(shared_max[tid], shared_max[tid + s]); shared_min[tid] = max(shared_min[tid], shared_min[tid + s]); } __syncthreads(); } if (tid == 0) { cuda::atomic::Max(max_ptr, shared_max[0]); cuda::atomic::Max(min_ptr, shared_min[0]); } } template __global__ void ReduceMaxMinPerChannel(const T* input_ptr, const int64_t elements, const int64_t num_channels, const int64_t panel_size, T* max_ptr, T* min_ptr) { extern __shared__ unsigned char shared_max_min_memory[]; T* shared_max = reinterpret_cast(shared_max_min_memory); T* shared_min = shared_max + blockDim.x; int64_t cur_channel = blockIdx.x; int64_t tid = threadIdx.x; while (cur_channel < num_channels) { shared_max[tid] = -FLT_MAX; shared_min[tid] = -FLT_MAX; int64_t index = (panel_size * cur_channel) + tid; int64_t end = panel_size * (cur_channel + 1); while (index < end && index < elements) { shared_max[tid] = max(shared_max[tid], input_ptr[index]); shared_min[tid] = max(shared_min[tid], -input_ptr[index]); index += blockDim.x; } __syncthreads(); for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { if (tid < s) { shared_max[tid] = max(shared_max[tid], shared_max[tid + s]); shared_min[tid] = max(shared_min[tid], shared_min[tid + s]); } __syncthreads(); } if (tid == 0) { cuda::atomic::Max(&max_ptr[cur_channel], shared_max[0]); cuda::atomic::Max(&min_ptr[cur_channel], shared_min[0]); } // __syncthreads(); cur_channel += gridDim.x; } } template __global__ void InitMaxMin(const int64_t elements, T* max_ptr, T* min_ptr) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { max_ptr[gid] = -FLT_MAX; min_ptr[gid] = -FLT_MAX; gid += 
gridDim.x * blockDim.x; } } template __global__ void CalScaleZeroPointSymmetric(const T* max_ptr, const T* min_ptr, const int64_t elements, const double quantization_bit, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T weight_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid])); T denominator = static_cast(pow(2.0, quantization_bit - 1)) - 1; scale[gid] = weight_max / denominator; zero_point[gid] = 0; gid += gridDim.x * blockDim.x; } } template __global__ void CalScaleZeroPointAffine(const T* max_ptr, const T* min_ptr, const int64_t elements, const double quantization_bit, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T denominator = static_cast(pow(2.0, quantization_bit)) - 1; T min = -min_ptr[gid]; T s = (max_ptr[gid] - min) / denominator; scale[gid] = s; zero_point[gid] = -nearbyint(min / s); gid += gridDim.x * blockDim.x; } } template __global__ void CalScaleZeroPointCambricon(const T* max_ptr, const T* min_ptr, const int64_t elements, const double quantization_bit, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T weight_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid])); // T denominator = static_cast(pow(2.0, quantization_bit - 1)) - 1; scale[gid] = floor(log2(weight_max)) - (quantization_bit - 2); zero_point[gid] = 0; gid += gridDim.x * blockDim.x; } } ep::CudaLaunchConfig GetLaunchConfig(ep::CudaStream* stream, size_t thread_num, size_t shared_mem_size) { ep::CudaLaunchConfig config; stream->InitLaunchConfigWithWaves(&config, thread_num, kCudaThreadsNumPerBlock, 1); config.shared_mem_size = shared_mem_size; return config; } } // namespace #define LAUNCH_CUDA_KERNEL(func, stream, thread_num, shared_mem_size, ...) \ (stream)->LaunchKernel(func, GetLaunchConfig((stream), thread_num, shared_mem_size), __VA_ARGS__); template class GpuMinMaxObserverKernel final : public user_op::OpKernel { public: GpuMinMaxObserverKernel() = default; ~GpuMinMaxObserverKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const std::string quantization_scheme = ctx->Attr("quantization_scheme"); const int32_t quantization_bit = ctx->Attr("quantization_bit"); const bool per_layer_quantization = ctx->Attr("per_layer_quantization"); const std::string quantization_formula = ctx->Attr("quantization_formula"); const int64_t elements = in->shape_view().elem_cnt(); const int64_t channel = scale->shape_view().At(0); const int64_t panel_size = elements / channel; T* max_ptr = tmp_buffer->mut_dptr(); T* min_ptr = max_ptr + channel; auto* cuda_stream = ctx->stream()->As(); LAUNCH_CUDA_KERNEL((InitMaxMin), cuda_stream, channel, 0, channel, max_ptr, min_ptr); if (per_layer_quantization) { LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer), cuda_stream, elements, kCudaThreadsNumPerBlock * 2 * sizeof(T), in->dptr(), elements, max_ptr, min_ptr); } else { // per-channel quantization // NOTE(Liang Depeng): each block of threads will be responsible for // computing the max and min values of the whole channel. 
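// The launch below requests channel * kCudaThreadsNumPerBlock threads (roughly one block per
// channel; blocks stride over channels when the grid is capped) and
// 2 * kCudaThreadsNumPerBlock * sizeof(T) bytes of dynamic shared memory. The first half of the
// shared buffer holds each thread's running maximum, the second half the running maximum of the
// negated inputs, so the true per-channel minimum is recovered later as -min_ptr[c] (see
// CalScaleZeroPointAffine above). This is also why InitMaxMin fills both global buffers with
// -FLT_MAX.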
LAUNCH_CUDA_KERNEL((ReduceMaxMinPerChannel), cuda_stream, channel * kCudaThreadsNumPerBlock, kCudaThreadsNumPerBlock * 2 * sizeof(T), in->dptr(), elements, channel, panel_size, max_ptr, min_ptr); } if (quantization_formula == "google") { if (quantization_scheme == "symmetric") { LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric), cuda_stream, channel, 0, max_ptr, min_ptr, channel, static_cast(quantization_bit), scale->mut_dptr(), zero_point->mut_dptr()); } else { // quantization_scheme == "affine" LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine), cuda_stream, channel, 0, max_ptr, min_ptr, channel, static_cast(quantization_bit), scale->mut_dptr(), zero_point->mut_dptr()); } } else if (quantization_formula == "cambricon") { if (!per_layer_quantization) { UNIMPLEMENTED() << " per-channel mode is not supported in cambricon scheme"; } LAUNCH_CUDA_KERNEL((CalScaleZeroPointCambricon), cuda_stream, channel, 0, max_ptr, min_ptr, channel, static_cast(quantization_bit), scale->mut_dptr(), zero_point->mut_dptr()); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MIN_MAX_OBSERVER_KERNEL(dtype) \ REGISTER_USER_KERNEL("min_max_observer") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 1; \ if (ctx->Attr("per_layer_quantization") == false) { \ const Shape& in_shape = ctx->InputShape("in", 0); \ tmp_buffer_size = in_shape.At(0); \ } \ return 2 * tmp_buffer_size * sizeof(dtype); \ }) REGISTER_MIN_MAX_OBSERVER_KERNEL(float); REGISTER_MIN_MAX_OBSERVER_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/mode_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { template std::unique_ptr NewMemcpyPrimitive(Context* ctx) { return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::MemcpyKind::kDtoD); } template class CpuModeKernel final : public user_op::OpKernel { public: CpuModeKernel() = default; ~CpuModeKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("input", 0); const int64_t num_axes = in->shape_view().NumAxes(); const int64_t size = in->shape_view().elem_cnt(); if (size == 0) return; const int64_t stride = in->shape_view().At(num_axes - 1); const int64_t instance_num = size / stride; user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex("values", 0); user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); auto memcpy = NewMemcpyPrimitive(ctx); CHECK(memcpy); memcpy->Launch(ctx->stream(), tmp_buffer->mut_dptr(), in->dptr(), size * sizeof(T)); const int64_t thread_num = std::min(instance_num, (int64_t)Singleton::Get()->thread_num()); const BalancedSplitter bs(instance_num, thread_num); BlockingCounter bc(thread_num); FOR_RANGE(int64_t, thread_id, 0, thread_num) { const Range range = bs.At(thread_id); Singleton::Get()->AddWork([=, &bc]() { FOR_RANGE(int64_t, i, range.begin(), range.end()) { T* in_ptr = tmp_buffer->mut_dptr() + i * stride; T* val_ptr = values->mut_dptr() + i; int64_t* ind_ptr = indices->mut_dptr() + i; std::vector> elements(stride); T mode = 0; int64_t mode_idx = 0; int64_t temp_freq = 0; int64_t max_freq = 0; FOR_RANGE(int64_t, idx, 0, stride) { elements[idx] = std::make_pair(*(in_ptr + idx), idx); } std::sort(elements.begin(), elements.end(), [=](const auto& i, const auto& j) { return i.first < j.first; }); FOR_RANGE(int64_t, idx, 0, stride) { temp_freq++; if ((idx == stride - 1) || (elements[idx].first != elements[idx + 1].first)) { if (temp_freq > max_freq) { mode = elements[idx].first; mode_idx = elements[idx].second; max_freq = temp_freq; } temp_freq = 0; } } *val_ptr = mode; *ind_ptr = mode_idx; } bc.Decrease(); }); } bc.WaitForeverUntilCntEqualZero(); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_MODE_KERNEL(dtype) \ REGISTER_USER_KERNEL("mode") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("input", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ return ctx->InputShape("input", 0).elem_cnt() * sizeof(dtype); \ }); REGISTER_CPU_MODE_KERNEL(float) REGISTER_CPU_MODE_KERNEL(double) REGISTER_CPU_MODE_KERNEL(int8_t) REGISTER_CPU_MODE_KERNEL(uint8_t) REGISTER_CPU_MODE_KERNEL(int32_t) REGISTER_CPU_MODE_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/model_update_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/model_update_kernel_util.h" namespace oneflow { namespace { // For bias correction compute in CPU. template T Fastpow(T a, int64_t b) { T ans = static_cast(1); while (b) { if (b & 1) { ans *= a; } a *= a; b >>= 1; } return ans; } template void SumSquares2(int64_t n, const T* src0, T* dst0, const T* src1, T* dst1) { *dst0 += cblas_dot(n, src0, 1, src0, 1); *dst1 += cblas_dot(n, src1, 1, src1, 1); } } // namespace template struct SGDUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, C* model_copy); }; template void SGDUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, C* model_copy) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; for (int64_t i = 0; i != n; ++i) { if (model_copy != nullptr) { FusedSGDUpdateFunctor()(model_diff + i, model + i, model_copy + i, scale, l1, l2, weight_decay, learning_rate_val); } else { SGDUpdateFunctor()(model_diff + i, model + i, scale, l1, l2, weight_decay, learning_rate_val); } } } template struct SGDUpdateKernelUtil; template struct SGDUpdateKernelUtil; template struct IndexedSlicesSGDUpdateKernelUtil { static void Update(ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model); }; template void IndexedSlicesSGDUpdateKernelUtil::Update( ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model) { const int64_t n = *num_unique_instance * feature_size; T lr = *learning_rate; lr *= lr_scale; FOR_RANGE(int64_t, i, 0, n) { const IDX indices_idx = i / feature_size; const IDX inner_idx = i - indices_idx * feature_size; const IDX instance_id = indices[indices_idx]; if (instance_id >= lower_bound && instance_id < upper_bound) { const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx; SGDUpdateFunctor()(values + i, model + model_idx, static_cast(1), 0.0, 0.0, weight_decay, lr); } } } #define INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CPU(val_type_pair, key_type_pair, \ idx_type_pair) \ template struct IndexedSlicesSGDUpdateKernelUtil< \ DeviceType::kCPU, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \ OF_PP_PAIR_FIRST(idx_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CPU, 
FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CPU template struct MomentumUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum); }; template void MomentumUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; for (int64_t i = 0; i != n; ++i) { MomentumUpdateFunctor()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val); } } template struct MomentumUpdateKernelUtil; template struct MomentumUpdateKernelUtil; template struct IndexedSlicesMomentumMdUpdateKernelUtil { static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model, T* momentum); }; template void IndexedSlicesMomentumMdUpdateKernelUtil::Update( ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) { const int64_t n = *num_unique_instance * feature_size; T lr = *learning_rate; lr *= lr_scale; for (int64_t i = 0; i != n; ++i) { const IDX indices_idx = i / feature_size; const IDX inner_idx = i - indices_idx * feature_size; const IDX instance_id = indices[indices_idx]; if (instance_id >= lower_bound && instance_id < upper_bound) { const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx; MomentumUpdateFunctor()(values + i, model + model_idx, momentum + model_idx, 1.0, 0.0, 0.0, beta, dampening, nesterov, maximize, weight_decay, lr); } } } #define INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CPU( \ val_type_pair, key_type_pair, idx_type_pair) \ template struct IndexedSlicesMomentumMdUpdateKernelUtil< \ DeviceType::kCPU, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \ OF_PP_PAIR_FIRST(idx_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CPU template struct AdamUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const 
T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v); }; template void AdamUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; } if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; } learning_rate_val *= lr_scale; FOR_RANGE(int64_t, i, 0, n) { if (model_copy != nullptr) { FusedAdamUpdateFunctor()(model_diff + i, model + i, model_copy + i, m + i, v + i, max_v + i, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, bias_correction1_val, bias_correction2_val, learning_rate_val); } else { AdamUpdateFunctor()(model_diff + i, model + i, m + i, v + i, max_v + i, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, bias_correction1_val, bias_correction2_val, learning_rate_val); } } } template struct AdamUpdateKernelUtil; template struct AdamUpdateKernelUtil; template struct IndexedSlicesAdamMdUpdateKernelUtil { static void Update(ep::Stream* stream, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float lr, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const K* indices, const T* values, T* model, T* m, T* v, T* max_v) { if (learning_rate != nullptr) { lr = *learning_rate; } lr *= lr_scale; float bias_correction1 = 1.0; float bias_correction2 = 1.0; if (bias_correction1_ptr != nullptr) { bias_correction1 = *bias_correction1_ptr; } if (bias_correction2_ptr != nullptr) { bias_correction2 = *bias_correction2_ptr; } const int64_t n = *num_unique_instance * feature_size; FOR_RANGE(int64_t, i, 0, n) { const IDX indices_idx = i / feature_size; const IDX inner_idx = i - indices_idx * feature_size; const IDX instance_id = indices[indices_idx]; if (instance_id >= lower_bound && instance_id < upper_bound) { const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx; AdamUpdateFunctor()(values + i, model + model_idx, m + model_idx, v + model_idx, max_v + i, /*scale=*/1.0, /*l1=*/0.0, /*l2=*/0.0, beta1, beta2, epsilon, weight_decay, amsgrad, bias_correction1, bias_correction2, lr); } } } }; #define INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CPU(val_type_pair, key_type_pair, \ idx_type_pair) \ template struct IndexedSlicesAdamMdUpdateKernelUtil< \ DeviceType::kCPU, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \ OF_PP_PAIR_FIRST(idx_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef 
INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CPU template struct AdagradUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* sum); }; template void AdagradUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* sum) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (train_step_ptr != nullptr) { train_step = *train_step_ptr + 1; } // train_step_ptr start from zero. if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val = learning_rate_val * lr_scale / (1 + (train_step - 1) * lr_decay); FOR_RANGE(int64_t, i, 0, n) { AdagradUpdateFunctor()(model_diff + i, model + i, sum + i, scale, l1, l2, epsilon, weight_decay, learning_rate_val); } } template struct AdagradUpdateKernelUtil; template struct AdagradUpdateKernelUtil; template struct LambUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, bool do_bias_correction, float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer); }; template void LambUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, bool do_bias_correction, float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate_ptr != nullptr) { learning_rate_val = *learning_rate_ptr; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; } if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; } FOR_RANGE(int64_t, i, 0, n) { LambGradFunctor()(model_diff + i, adam_diff + i, model + i, m + i, v + i, scale, l1, l2, beta1, beta2, epsilon, do_bias_correction, bias_correction1_val, bias_correction2_val); } T* w_norm_2 = norm_buffer; T* g_norm_2 = norm_buffer + 1; Memset(stream, norm_buffer, 0, 2 * sizeof(T)); SumSquares2(n, model, w_norm_2, adam_diff, g_norm_2); learning_rate_val *= lr_scale; const float lr = LambLRFunctor()(learning_rate_val, w_norm_2, g_norm_2); FOR_RANGE(int64_t, i, 0, n) { LambUpdateFunctor()(lr, weight_decay, adam_diff + i, model + i); } } template struct LambUpdateKernelUtil; template struct LambUpdateKernelUtil; template<> struct BiasCorrectionFactorKernelUtil { static void 
BiasCorrectionFactorCompute(ep::Stream* stream, float beta, const int64_t* train_step, float* out); }; void BiasCorrectionFactorKernelUtil::BiasCorrectionFactorCompute( ep::Stream* stream, float beta, const int64_t* train_step, float* out) { const float bias_correction_factor = 1.0 - Fastpow(beta, *train_step + 1); *out = bias_correction_factor; } template struct RmsPropUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon, float weight_decay, float decay_rate, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* mean_square, T* mean_gradient); }; template void RmsPropUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon, float weight_decay, float decay_rate, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* mean_square, T* mean_gradient) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; if (centered) { FOR_RANGE(int64_t, i, 0, n) { RmsPropUpdateFunctor()(model_diff + i, model + i, n, scale, l1, l2, mean_square + i, mean_gradient + i, epsilon, weight_decay, decay_rate, learning_rate_val); } } else { FOR_RANGE(int64_t, i, 0, n) { RmsPropUpdateFunctor()(model_diff + i, model + i, n, scale, l1, l2, mean_square + i, nullptr, epsilon, weight_decay, decay_rate, learning_rate_val); } } } template struct RmsPropUpdateKernelUtil; template struct RmsPropUpdateKernelUtil; template struct LarsUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon, float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum, T* data_tmp, T* model_diff_tmp); }; template void LarsUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon, float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum, T* data_tmp, T* model_diff_tmp) { if (skip_if != nullptr && *skip_if != 0) { return; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } T model_norm = data_tmp[0]; T model_diff_norm = data_tmp[1]; FOR_RANGE(int64_t, i, 0, n) { model_diff_tmp[i] = CastScaleRegularizeGradientFunctor()(model_diff[i], model[i], scale, l1, l2); } Memset(stream, data_tmp, 0, 2 * sizeof(T)); SumSquares2(n, model, &model_norm, model_diff_tmp, &model_diff_norm); model_norm = std::sqrt(model_norm); model_diff_norm = std::sqrt(model_diff_norm); T lars = static_cast(1); if (model_norm > 0 && model_diff_norm > 0) { lars = lars_coefficient * model_norm / (epsilon + model_diff_norm + weight_decay * model_norm); } T lr = *learning_rate; lr *= lr_scale; T local_learning_rate = lr * lars; FOR_RANGE(int64_t, i, 0, n) { LarsUpdateFunctor()(model_diff_tmp + i, model + i, momentum_beta, momentum + i, weight_decay, local_learning_rate); } } template struct LarsUpdateKernelUtil; template struct LarsUpdateKernelUtil; template struct FtrlUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T 
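// Illustrative sketch, not part of the original source: how the CPU LarsUpdateKernelUtil
// above derives its layer-wise local learning rate. The helper name and the double
// accumulators are assumptions for this example; the formula mirrors the code:
// lars = coeff * ||w|| / (eps + ||g|| + weight_decay * ||w||) when both norms are
// positive, and local_lr = lr * lr_scale * lars.
#include <cmath>
#include <cstdint>
inline float lars_local_lr_sketch(const float* model, const float* reg_diff, int64_t n,
                                  float lr, float lr_scale, float lars_coefficient,
                                  float epsilon, float weight_decay) {
  double w_sq = 0.0, g_sq = 0.0;
  for (int64_t i = 0; i < n; ++i) {
    w_sq += static_cast<double>(model[i]) * model[i];
    g_sq += static_cast<double>(reg_diff[i]) * reg_diff[i];
  }
  const float model_norm = std::sqrt(static_cast<float>(w_sq));
  const float diff_norm = std::sqrt(static_cast<float>(g_sq));
  float lars = 1.0f;
  if (model_norm > 0.0f && diff_norm > 0.0f) {
    lars = lars_coefficient * model_norm / (epsilon + diff_norm + weight_decay * model_norm);
  }
  return lr * lr_scale * lars;
}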
scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* accumulate, T* z); }; template void FtrlUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* accumulate, T* z) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; for (int64_t i = 0; i != n; ++i) { FtrlUpdateFunctor()(model_diff + i, model + i, accumulate + i, z + i, scale, l1, l2, lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val); } } template struct FtrlUpdateKernelUtil; template struct FtrlUpdateKernelUtil; template struct AdadeltaUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs, T* acc_deltas); }; template void AdadeltaUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs, T* acc_deltas) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; for (int64_t i = 0; i != n; ++i) { AdadeltaUpdateFunctor()(model_diff + i, model + i, square_avgs + i, acc_deltas + i, scale, l1, l2, rho, epsilon, maximize, weight_decay, learning_rate_val); } } template struct AdadeltaUpdateKernelUtil; template struct AdadeltaUpdateKernelUtil; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/model_update_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/model_update_kernel_util.h" #include #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void SGDUpdateGpu(int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, C* model_copy) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; CUDA_1D_KERNEL_LOOP(i, n) { if (model_copy != nullptr) { FusedSGDUpdateFunctor()(model_diff + i, model + i, model_copy + i, scale, l1, l2, weight_decay, learning_rate_val); } else { SGDUpdateFunctor()(model_diff + i, model + i, scale, l1, l2, weight_decay, learning_rate_val); } } } template __global__ void IndexedSlicesSGDUpdateGpu(float weight_decay, float lr_scale, const IDX feature_size, const int64_t lower_bound, const int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model) { const int64_t n = *num_unique_instance * feature_size; T lr = *learning_rate; lr *= lr_scale; CUDA_1D_KERNEL_LOOP_T(IDX, i, n) { const IDX indices_idx = i / feature_size; const IDX inner_idx = i - indices_idx * feature_size; const IDX instance_id = indices[indices_idx]; if (instance_id >= lower_bound && instance_id < upper_bound) { const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx; SGDUpdateFunctor()(values + i, model + model_idx, static_cast(1), 0.0, 0.0, weight_decay, lr); } } } template __global__ void SumSquares2(int64_t n, const T* src0, T* dst0, const T* src1, T* dst1) { T t_sum0 = 0; T t_sum1 = 0; CUDA_1D_KERNEL_LOOP(i, n) { t_sum0 += src0[i] * src0[i]; t_sum1 += src1[i] * src1[i]; } typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage0; __shared__ typename BlockReduce::TempStorage temp_storage1; T b_sum0 = BlockReduce(temp_storage0).Sum(t_sum0); T b_sum1 = BlockReduce(temp_storage1).Sum(t_sum1); if (threadIdx.x == 0) { cuda::atomic::Add(dst0, b_sum0); cuda::atomic::Add(dst1, b_sum1); } } } // namespace template struct SGDUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, C* model_copy); }; template void SGDUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, C* model_copy) { SGDUpdateGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model, model_copy); } template struct SGDUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, float16* model_copy); }; template void SGDUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, 
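// Illustrative sketch, not part of the original source: the per-element arithmetic that
// SGDUpdateGpu applies through CastScaleRegularizeGradientFunctor and SGDUpdateFunctor.
// The helper name is an assumption made for this example.
inline void sgd_update_one_sketch(float grad, float& model, float scale, float l1, float l2,
                                  float weight_decay, float lr) {
  // scale the raw gradient and add the L1 (sign of the weight) and L2 terms
  const float sign = (model >= 0.0f ? 1.0f : 0.0f) - (model <= 0.0f ? 1.0f : 0.0f);
  const float g = grad * scale + l1 * sign + l2 * model;
  // one SGD step, with weight decay folded into the same update
  model -= lr * (g + weight_decay * model);
}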
float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, float16* model_copy) { SGDUpdateKernelUtil::Update( stream, n, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model, reinterpret_cast(model_copy)); } template struct SGDUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, float16* model_copy); }; template void SGDUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, float16* model_copy) { SGDUpdateKernelUtil::Update( stream, n, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, reinterpret_cast(model_diff), model, reinterpret_cast(model_copy)); } template struct SGDUpdateKernelUtil; template struct SGDUpdateKernelUtil; template struct SGDUpdateKernelUtil; template struct IndexedSlicesSGDUpdateKernelUtil { static void Update(ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model); }; template void IndexedSlicesSGDUpdateKernelUtil::Update( ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model) { IndexedSlicesSGDUpdateGpu <<As()->cuda_stream()>>>( weight_decay, lr_scale, feature_size, lower_bound, upper_bound, num_unique_instance, learning_rate, indices, values, model); } #define INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CUDA(val_type_pair, key_type_pair, \ idx_type_pair) \ template struct IndexedSlicesSGDUpdateKernelUtil< \ DeviceType::kCUDA, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \ OF_PP_PAIR_FIRST(idx_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ); #undef INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CUDA namespace { template __global__ void MomentumUpdateGpu(int64_t n, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; CUDA_1D_KERNEL_LOOP(i, n) { MomentumUpdateFunctor()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val); } } template __global__ void IndexedSlicesMomentumUpdateGpu(T beta, float dampening, bool nesterov, bool maximize, float weight_decay, float lr_scale, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* 
num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) { const int64_t n = *num_unique_instance * feature_size; T lr = *learning_rate; lr *= lr_scale; CUDA_1D_KERNEL_LOOP(i, n) { const IDX indices_idx = i / feature_size; const IDX inner_idx = i - indices_idx * feature_size; const IDX instance_id = indices[indices_idx]; if (instance_id >= lower_bound && instance_id < upper_bound) { const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx; MomentumUpdateFunctor()(values + i, model + model_idx, momentum + model_idx, static_cast(1), 0.0, 0.0, beta, dampening, nesterov, maximize, weight_decay, lr); } } } } // namespace template struct MomentumUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum); }; template void MomentumUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum) { MomentumUpdateGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model, momentum); } template struct MomentumUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* momentum); }; template void MomentumUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* momentum) { MomentumUpdateKernelUtil::Update( stream, n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, reinterpret_cast(model_diff), model, momentum); } template struct MomentumUpdateKernelUtil; template struct MomentumUpdateKernelUtil; template struct MomentumUpdateKernelUtil; template struct IndexedSlicesMomentumMdUpdateKernelUtil { static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model, T* momentum); }; template void IndexedSlicesMomentumMdUpdateKernelUtil::Update( ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) { IndexedSlicesMomentumUpdateGpu <<As()->cuda_stream()>>>( beta, dampening, 
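// Illustrative sketch, not part of the original source: the scalar update performed by
// MomentumUpdateFunctor for each element, including the dampening, Nesterov and maximize
// options. g is assumed to be the already scaled/regularized gradient; the helper name is
// an assumption made for this example.
inline void momentum_update_one_sketch(float g, float& model, float& momentum, float beta,
                                       float dampening, bool nesterov, bool maximize,
                                       float weight_decay, float lr) {
  const float next_momentum = beta * momentum + (1.0f - dampening) * g;
  momentum = next_momentum;
  const float step = nesterov ? g + beta * next_momentum : next_momentum;
  const float alpha = maximize ? lr : -lr;  // maximize flips the direction of the step
  model = model + alpha * step - lr * weight_decay * model;
}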
nesterov, maximize, weight_decay, lr_scale, feature_size, lower_bound, upper_bound, num_unique_instance, learning_rate, indices, values, model, momentum); } #define INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA( \ val_type_pair, key_type_pair, idx_type_pair) \ template struct IndexedSlicesMomentumMdUpdateKernelUtil< \ DeviceType::kCUDA, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \ OF_PP_PAIR_FIRST(idx_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ); #undef INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA namespace { __global__ void BiasCorrectionFactorKernelGpu(float beta, const int64_t* train_step, float* out) { const auto exponent = static_cast(*train_step + 1); const float bias_correction_factor = 1.0 - static_cast(pow(beta, exponent)); *out = bias_correction_factor; } template __global__ void AdamUpdateGpu(int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; } if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; } learning_rate_val *= lr_scale; CUDA_1D_KERNEL_LOOP(i, n) { if (model_copy != nullptr) { FusedAdamUpdateFunctor()(model_diff + i, model + i, model_copy + i, m + i, v + i, max_v + i, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, bias_correction1_val, bias_correction2_val, learning_rate_val); } else { AdamUpdateFunctor()(model_diff + i, model + i, m + i, v + i, max_v + i, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, bias_correction1_val, bias_correction2_val, learning_rate_val); } } } template __global__ void AdamUpdateBetaTGpu(const T beta1, const T beta2, const int64_t* skip_if, T* beta1_t, T* beta2_t) { if (skip_if != nullptr && *skip_if != 0) { return; } *beta1_t *= beta1; *beta2_t *= beta2; } template __global__ void IndexedSlicesAdamUpdateGpu( float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float lr, float lr_scale, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const K* indices, const T* values, T* model, T* m, T* v, T* max_v) { if (learning_rate != nullptr) { lr = *learning_rate; } lr *= lr_scale; float bias_correction1 = 1.0; float bias_correction2 = 1.0; if (bias_correction1_ptr != nullptr) { bias_correction1 = *bias_correction1_ptr; } if (bias_correction2_ptr != nullptr) { bias_correction2 = *bias_correction2_ptr; } const int64_t n = *num_unique_instance * feature_size; CUDA_1D_KERNEL_LOOP(i, n) { const IDX indices_idx = i / feature_size; const IDX inner_idx = i - indices_idx * feature_size; const IDX instance_id = indices[indices_idx]; if (instance_id >= lower_bound && 
instance_id < upper_bound) { const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx; AdamUpdateFunctor()(values + i, model + model_idx, m + model_idx, v + model_idx, max_v + i, static_cast(1), 0, 0, beta1, beta2, epsilon, weight_decay, amsgrad, bias_correction1, bias_correction2, lr); } } } template __global__ void LambGradGpu(int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, bool do_bias_correction, float bias_correction1_val, float bias_correction2_val, const float* bias_correction1_ptr, const float* bias_correction2_ptr) { if (skip_if != nullptr && *skip_if != 0) { return; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; } if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; } CUDA_1D_KERNEL_LOOP(i, n) { LambGradFunctor()(model_diff + i, adam_diff + i, model + i, m + i, v + i, scale, l1, l2, beta1, beta2, epsilon, do_bias_correction, bias_correction1_val, bias_correction2_val); } } template __global__ void LambUpdateGpu(int64_t n, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate_ptr, const int64_t* skip_if, const T* w_norm_2, const T* g_norm_2, const T* adam_diff, T* model) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate_ptr != nullptr) { learning_rate_val = *learning_rate_ptr; } learning_rate_val *= lr_scale; const float lr = LambLRFunctor()(learning_rate_val, w_norm_2, g_norm_2); CUDA_1D_KERNEL_LOOP(i, n) { LambUpdateFunctor()(lr, weight_decay, adam_diff + i, model + i); } } } // namespace template struct AdamUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v); }; template void AdamUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v) { AdamUpdateGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, lr_scale, bias_correction1_val, bias_correction2_val, learning_rate, scale_by_ptr, skip_if, bias_correction1_ptr, bias_correction2_ptr, model_diff, model, model_copy, m, v, max_v); } template struct AdamUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* 
bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff, T* model, float16* model_copy, T* m, T* v, T* max_v); }; template void AdamUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff, T* model, float16* model_copy, T* m, T* v, T* max_v) { AdamUpdateKernelUtil::Update( stream, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, lr_scale, bias_correction1_val, bias_correction2_val, learning_rate, scale_by_ptr, skip_if, bias_correction1_ptr, bias_correction2_ptr, model_diff, model, reinterpret_cast(model_copy), m, v, max_v); } template struct AdamUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const float16* model_diff, T* model, float16* model_copy, T* m, T* v, T* max_v); }; template void AdamUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const float16* model_diff, T* model, float16* model_copy, T* m, T* v, T* max_v) { AdamUpdateKernelUtil::Update( stream, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, lr_scale, bias_correction1_val, bias_correction2_val, learning_rate, scale_by_ptr, skip_if, bias_correction1_ptr, bias_correction2_ptr, reinterpret_cast(model_diff), model, reinterpret_cast(model_copy), m, v, max_v); } template struct AdamUpdateKernelUtil; template struct AdamUpdateKernelUtil; template struct AdamUpdateKernelUtil; template __global__ void AdagradUpdateGpu(int64_t n, T scale, float l1, float l2, float lr_decay, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* sum) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (train_step_ptr != nullptr) { train_step = *train_step_ptr + 1; } // train_step_ptr start from zero. 
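// Illustrative sketch, not part of the original source: the learning-rate decay and the
// per-element accumulation used by AdagradUpdateFunctor. train_step is the 1-based step,
// as set just above; the helper name is an assumption made for this example.
#include <cmath>
#include <cstdint>
inline void adagrad_update_one_sketch(float g, float& model, float& sum, float base_lr,
                                      float lr_scale, float lr_decay, int64_t train_step,
                                      float epsilon, float weight_decay) {
  // the effective rate shrinks as training progresses
  const float lr = base_lr * lr_scale / (1.0f + (train_step - 1) * lr_decay);
  sum += g * g;  // accumulate squared (already scaled/regularized) gradients
  model = model - lr / (std::sqrt(sum) + epsilon) * g - lr * weight_decay * model;
}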
if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val = learning_rate_val * lr_scale / (1 + (train_step - 1) * lr_decay); CUDA_1D_KERNEL_LOOP(i, n) { AdagradUpdateFunctor()(model_diff + i, model + i, sum + i, scale, l1, l2, epsilon, weight_decay, learning_rate_val); } } template struct AdagradUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* sum); }; template void AdagradUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* sum) { AdagradUpdateGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, lr_decay, epsilon, weight_decay, learning_rate_val, lr_scale, train_step, learning_rate, train_step_ptr, scale_by_ptr, skip_if, model_diff, model, sum); } template struct AdagradUpdateKernelUtil; template struct AdagradUpdateKernelUtil; template struct LambUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, bool do_bias_correction, float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer); }; template void LambUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, bool do_bias_correction, float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer) { LambGradGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, beta1, beta2, epsilon, scale_by_ptr, skip_if, model_diff, adam_diff, model, m, v, do_bias_correction, bias_correction1_val, bias_correction2_val, bias_correction1_ptr, bias_correction2_ptr); T* w_norm_2 = norm_buffer; T* g_norm_2 = norm_buffer + 1; Memset(stream, norm_buffer, 0, 2 * sizeof(T)); SumSquares2 <<As()->cuda_stream()>>>(n, model, w_norm_2, adam_diff, g_norm_2); LambUpdateGpu<<As()->cuda_stream()>>>( n, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, skip_if, w_norm_2, g_norm_2, adam_diff, model); } template struct LambUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, bool do_bias_correction, float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer); }; template void LambUpdateKernelUtil::Update( 
ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, bool do_bias_correction, float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer) { LambUpdateKernelUtil::Update( stream, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, learning_rate_val, lr_scale, do_bias_correction, bias_correction1_val, bias_correction2_val, learning_rate_ptr, bias_correction1_ptr, bias_correction2_ptr, scale_by_ptr, skip_if, reinterpret_cast(model_diff), adam_diff, model, m, v, norm_buffer); } template struct LambUpdateKernelUtil; template struct LambUpdateKernelUtil; template struct LambUpdateKernelUtil; template struct IndexedSlicesAdamMdUpdateKernelUtil { static void Update(ep::Stream* stream, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float lr, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const K* indices, const T* values, T* model, T* m, T* v, T* max_v); }; template void IndexedSlicesAdamMdUpdateKernelUtil::Update( ep::Stream* stream, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float lr, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const K* indices, const T* values, T* model, T* m, T* v, T* max_v) { IndexedSlicesAdamUpdateGpu <<As()->cuda_stream()>>>( beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, lr, lr_scale, feature_size, lower_bound, upper_bound, num_unique_instance, learning_rate, bias_correction1_ptr, bias_correction2_ptr, indices, values, model, m, v, max_v); } #define INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CUDA( \ val_type_pair, key_type_pair, idx_type_pair) \ template struct IndexedSlicesAdamMdUpdateKernelUtil< \ DeviceType::kCUDA, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \ OF_PP_PAIR_FIRST(idx_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ); #undef INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CUDA template<> struct BiasCorrectionFactorKernelUtil { static void BiasCorrectionFactorCompute(ep::Stream* stream, float beta, const int64_t* train_step, float* out); }; void BiasCorrectionFactorKernelUtil::BiasCorrectionFactorCompute( ep::Stream* stream, float beta, const int64_t* train_step, float* out) { BiasCorrectionFactorKernelGpu<<<1, 1, 0, stream->As()->cuda_stream()>>>( beta, train_step, out); } namespace { template __global__ void RmsPropUpdateGpu(int64_t n, T scale, float l1, float l2, T* mean_square, T* mean_gradient, float epsilon, float weight_decay, float decay_rate, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate 
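// Illustrative sketch, not part of the original source: the bias-correction factors that
// BiasCorrectionFactorKernelGpu above writes out, shown for both Adam moments. train_step
// is 0-based, matching the kernel; the helper name is an assumption made for this example.
#include <cmath>
#include <cstdint>
inline void bias_correction_factors_sketch(float beta1, float beta2, int64_t train_step,
                                           float& bias_correction1, float& bias_correction2) {
  // factor = 1 - beta^(t+1); it approaches 1 as training proceeds
  bias_correction1 = 1.0f - std::pow(beta1, static_cast<float>(train_step + 1));
  bias_correction2 = 1.0f - std::pow(beta2, static_cast<float>(train_step + 1));
}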
!= nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; CUDA_1D_KERNEL_LOOP(i, n) { RmsPropUpdateFunctor()(model_diff + i, model + i, n, scale, l1, l2, mean_square + i, (centered ? mean_gradient + i : nullptr), epsilon, weight_decay, decay_rate, learning_rate_val); } } } // namespace template struct RmsPropUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon, float weight_decay, float decay_rate, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* mean_square, T* mean_gradient); }; template void RmsPropUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon, float weight_decay, float decay_rate, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* mean_square, T* mean_gradient) { if (centered) { RmsPropUpdateGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, mean_square, mean_gradient, epsilon, weight_decay, decay_rate, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model); } else { RmsPropUpdateGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, mean_square, mean_gradient, epsilon, weight_decay, decay_rate, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model); } } template struct RmsPropUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon, float weight_decay, float decay_rate, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* mean_square, T* mean_gradient); }; template void RmsPropUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon, float weight_decay, float decay_rate, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* mean_square, T* mean_gradient) { RmsPropUpdateKernelUtil::Update( stream, n, scale, l1, l2, centered, epsilon, weight_decay, decay_rate, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, reinterpret_cast(model_diff), model, mean_square, mean_gradient); } template struct RmsPropUpdateKernelUtil; template struct RmsPropUpdateKernelUtil; template struct RmsPropUpdateKernelUtil; namespace { template __global__ void LarsScaleModelDiffGpu(int64_t n, T scale, float l1, float l2, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* model_diff_tmp) { if (skip_if != nullptr && *skip_if != 0) { return; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } CUDA_1D_KERNEL_LOOP(i, n) { model_diff_tmp[i] = CastScaleRegularizeGradientFunctor()(model_diff[i], model[i], scale, l1, l2); } } template __global__ void LarsGetLocalLearningRateGpu(const float* learning_rate, float lr_scale, T weight_decay, T epsilon, T lars_coefficient, const int64_t* skip_if, T* data_tmp) { if (skip_if != nullptr && *skip_if != 0) { return; } T* model_norm = &data_tmp[0]; T* model_diff_norm = &data_tmp[1]; T* local_learning_rate = &data_tmp[2]; *model_norm = std::sqrt(*model_norm); *model_diff_norm = std::sqrt(*model_diff_norm); T lars = static_cast(1); if 
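// Illustrative sketch, not part of the original source: the scalar update applied by
// RmsPropUpdateFunctor for the centered and uncentered cases, assuming g is the already
// scaled/regularized gradient. The helper name is an assumption made for this example.
#include <cmath>
inline void rmsprop_update_one_sketch(float g, float& model, float& mean_square,
                                      float* mean_gradient, bool centered, float epsilon,
                                      float decay_rate, float lr) {
  mean_square = (1.0f - decay_rate) * g * g + decay_rate * mean_square;
  float denom = mean_square;
  if (centered && mean_gradient != nullptr) {
    *mean_gradient = (1.0f - decay_rate) * g + decay_rate * *mean_gradient;
    denom = mean_square - (*mean_gradient) * (*mean_gradient);  // centered variance estimate
  }
  model -= lr * g / std::sqrt(denom + epsilon);
}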
(*model_norm > 0 && *model_diff_norm > 0) { lars = lars_coefficient * (*model_norm) / (epsilon + (*model_diff_norm) + weight_decay * (*model_norm)); } T lr = *learning_rate; lr *= lr_scale; *local_learning_rate = lr * lars; } template __global__ void LarsUpdateGpu(int64_t n, float momentum_beta, T* momentum, float weight_decay, const int64_t* skip_if, T* local_learning_rate, T* model_diff_tmp, T* model) { if (skip_if != nullptr && *skip_if != 0) { return; } CUDA_1D_KERNEL_LOOP(i, n) { LarsUpdateFunctor()(model_diff_tmp + i, model + i, momentum_beta, momentum + i, weight_decay, *local_learning_rate); } } } // namespace template struct LarsUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon, float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum, T* data_tmp, T* model_diff_tmp); }; template void LarsUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon, float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum, T* data_tmp, T* model_diff_tmp) { LarsScaleModelDiffGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, scale_by_ptr, skip_if, model_diff, model, model_diff_tmp); T* model_norm = data_tmp; T* model_diff_norm = data_tmp + 1; T* local_learning_rate = data_tmp + 2; Memset(stream, data_tmp, 0, 2 * sizeof(T)); SumSquares2<<As()->cuda_stream()>>>(n, model, model_norm, model_diff_tmp, model_diff_norm); LarsGetLocalLearningRateGpu<<<1, 1, 0, stream->As()->cuda_stream()>>>( learning_rate, lr_scale, weight_decay, epsilon, lars_coefficient, skip_if, data_tmp); LarsUpdateGpu<<As()->cuda_stream()>>>( n, momentum_beta, momentum, weight_decay, skip_if, local_learning_rate, model_diff_tmp, model); } template struct LarsUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon, float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* momentum, T* data_tmp, T* model_diff_tmp); }; template void LarsUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon, float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* momentum, T* data_tmp, T* model_diff_tmp) { LarsUpdateKernelUtil::Update( stream, n, scale, l1, l2, momentum_beta, epsilon, lars_coefficient, weight_decay, lr_scale, learning_rate, scale_by_ptr, skip_if, reinterpret_cast(model_diff), model, momentum, data_tmp, model_diff_tmp); } template struct LarsUpdateKernelUtil; template struct LarsUpdateKernelUtil; template struct LarsUpdateKernelUtil; template __global__ void FtrlUpdateGpu(int64_t n, T scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* accumulate, T* z) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr 
!= nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; CUDA_1D_KERNEL_LOOP(i, n) { FtrlUpdateFunctor()(model_diff + i, model + i, accumulate + i, z + i, scale, l1, l2, lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val); } } template struct FtrlUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* accumulate, T* z); }; template void FtrlUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* accumulate, T* z) { FtrlUpdateGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model, accumulate, z); } template struct FtrlUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* accumulate, T* z); }; template void FtrlUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* accumulate, T* z) { FtrlUpdateKernelUtil::Update( stream, n, scale, l1, l2, lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, reinterpret_cast(model_diff), model, accumulate, z); } template struct FtrlUpdateKernelUtil; template struct FtrlUpdateKernelUtil; template struct FtrlUpdateKernelUtil; template __global__ void AdadeltaUpdateGpu(int64_t n, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs, T* acc_deltas) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; CUDA_1D_KERNEL_LOOP(i, n) { AdadeltaUpdateFunctor()(model_diff + i, model + i, square_avgs + i, acc_deltas + i, scale, l1, l2, rho, epsilon, maximize, weight_decay, learning_rate_val); } } template struct AdadeltaUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs, T* acc_deltas); }; template void AdadeltaUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate_val, float 
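// Illustrative sketch, not part of the original source: the scalar FTRL step encoded by
// FtrlUpdateFunctor, assuming g is the already scaled/regularized gradient. The helper
// name is an assumption made for this example.
#include <cmath>
inline void ftrl_update_one_sketch(float g, float& model, float& accumulate, float& z,
                                   float lr_power, float lambda1, float lambda2, float beta,
                                   float weight_decay, float lr) {
  const float lr_reciprocal = 1.0f / lr;
  const float next_accumulate = accumulate + g * g;
  const float acc_powered = std::pow(accumulate, lr_power);
  const float next_acc_powered = std::pow(next_accumulate, lr_power);
  const float sigma = (next_acc_powered - acc_powered) * lr_reciprocal;
  const float new_z = z + g - sigma * model;
  float new_model = 0.0f;  // weights inside the L1 "dead zone" are set to zero
  if (std::abs(new_z) >= lambda1) {
    new_model = (std::copysign(lambda1, new_z) - new_z)
                    / ((beta + next_acc_powered) * lr_reciprocal + lambda2)
                - lr * weight_decay * model;
  }
  model = new_model;
  accumulate = next_accumulate;
  z = new_z;
}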
lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs, T* acc_deltas) { AdadeltaUpdateGpu<<As()->cuda_stream()>>>( n, scale, l1, l2, rho, epsilon, maximize, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model, square_avgs, acc_deltas); } template struct AdadeltaUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* square_avgs, T* acc_deltas); }; template void AdadeltaUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* square_avgs, T* acc_deltas) { AdadeltaUpdateKernelUtil::Update( stream, n, scale, l1, l2, rho, epsilon, maximize, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, reinterpret_cast(model_diff), model, square_avgs, acc_deltas); } template struct AdadeltaUpdateKernelUtil; template struct AdadeltaUpdateKernelUtil; template struct AdadeltaUpdateKernelUtil; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/model_update_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_MODEL_UPDATE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_MODEL_UPDATE_KERNEL_UTIL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/math_unary_elementwise_func.h" namespace oneflow { template struct CastScaleRegularizeGradientFunctor { OF_DEVICE_FUNC T operator()(G model_diff, T model, T scale, float l1, float l2) const { return static_cast(model_diff) * scale + l1 * ((model >= 0) - (model <= 0)) + l2 * model; } }; template struct SGDUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, T scale, float l1, float l2, float weight_decay, float learning_rate) const { const T model_val = *model; const T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T next_model = model_val - learning_rate * (model_diff_t + weight_decay * model_val); *model = next_model; } }; template struct FusedSGDUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, C* model_copy, T scale, float l1, float l2, float weight_decay, float learning_rate) const { const T model_val = *model; const T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T next_model = model_val - learning_rate * (model_diff_t + weight_decay * model_val); *model = next_model; *model_copy = static_cast(next_model); } }; template struct SGDUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, C* model_copy); }; template struct IndexedSlicesSGDUpdateKernelUtil final { static void Update(ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model); }; template struct MomentumUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, T* momentum, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate) const { const T model_val = *model; T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); T next_momentum = beta * *momentum + (1.0f - dampening) * model_diff_t; *momentum = next_momentum; if (!nesterov) { model_diff_t = next_momentum; } else { model_diff_t += beta * next_momentum; } T alpha = -learning_rate; if (maximize) { alpha = learning_rate; } const T next_model = model_val + alpha * model_diff_t - learning_rate * weight_decay * model_val; *model = next_model; } }; template struct AdamUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, T* m, T* v, T* max_v, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, float bias_correction1, float bias_correction2, float learning_rate) const { const T model_val = *model; T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T next_m = beta1 * *m + (1 - beta1) * model_diff_t; *m = next_m; const T next_v = beta2 * *v + (1 - beta2) * model_diff_t * model_diff_t; *v = next_v; T denom = 0; if (amsgrad) { const T next_max_v = *max_v > next_v ? *max_v : next_v; // use std::max has bug in GPU kernel. 
*max_v = next_max_v; denom = (sqrt(next_max_v) / sqrt(bias_correction2)) + epsilon; } else { denom = (sqrt(next_v) / sqrt(bias_correction2)) + epsilon; } const T step_size = learning_rate / bias_correction1; *model = model_val - step_size * (next_m / denom) - learning_rate * weight_decay * model_val; } }; template struct FusedAdamUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, float bias_correction1, float bias_correction2, float learning_rate) const { const T model_val = *model; T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T next_m = beta1 * *m + (1 - beta1) * model_diff_t; *m = next_m; const T next_v = beta2 * *v + (1 - beta2) * model_diff_t * model_diff_t; *v = next_v; T denom = 0; if (amsgrad) { const T next_max_v = *max_v > next_v ? *max_v : next_v; // use std::max has bug in GPU kernel. *max_v = next_max_v; denom = (sqrt(next_max_v) / sqrt(bias_correction2)) + epsilon; } else { denom = (sqrt(next_v) / sqrt(bias_correction2)) + epsilon; } const T step_size = learning_rate / bias_correction1; const T next_model = model_val - step_size * (next_m / denom) - learning_rate * weight_decay * model_val; *model = next_model; *model_copy = static_cast(next_model); } }; template struct AdagradUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, T* sum, T scale, float l1, float l2, float epsilon, float weight_decay, float learning_rate) { const T model_val = *model; T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T next_sum = *sum + model_diff_t * model_diff_t; *sum = next_sum; *model = model_val - learning_rate / (sqrt(next_sum) + epsilon) * model_diff_t - learning_rate * weight_decay * model_val; } }; template struct LambGradFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* adam_diff, T* model, T* m, T* v, float scale, float l1, float l2, float beta1, float beta2, float epsilon, bool do_bias_correction, float bias_correction1, float bias_correction2) const { const T model_val = *model; T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T next_m = beta1 * *m + (1 - beta1) * model_diff_t; const T next_v = beta2 * *v + (1 - beta2) * model_diff_t * model_diff_t; *m = next_m; *v = next_v; T numerator = 0; T denominator = 0; if (do_bias_correction) { numerator = next_m / bias_correction1; denominator = (sqrt(next_v) / sqrt(bias_correction2)) + epsilon; } else { numerator = next_m; denominator = sqrt(next_v) + epsilon; } *adam_diff = numerator / denominator; } }; template struct LambLRFunctor { OF_DEVICE_FUNC float operator()(const float learning_rate_val, const T* w_norm_2, const T* g_norm_2) const { float lr = learning_rate_val; const T w_norm_val = sqrt(*w_norm_2); const T g_norm_val = sqrt(*g_norm_2); T trust_ratio = 1; if (w_norm_val > 0 && g_norm_val > 0) { trust_ratio = w_norm_val / g_norm_val; } lr *= trust_ratio; return lr; } }; template struct LambUpdateFunctor { OF_DEVICE_FUNC void operator()(const float learning_rate, const float weight_decay, const T* adam_diff, T* model) const { const T model_val = *model; *model = model_val - learning_rate * (*adam_diff + weight_decay * model_val); } }; template struct FtrlUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, T* accumulate, T* z, T scale, float l1, 
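// Illustrative sketch, not part of the original source: the per-element Adam / AMSGrad
// step performed by AdamUpdateFunctor above, assuming g has already been scaled and
// regularized. The helper name is an assumption made for this example.
#include <algorithm>
#include <cmath>
inline void adam_update_one_sketch(float g, float& model, float& m, float& v, float& max_v,
                                   float beta1, float beta2, float epsilon, float weight_decay,
                                   bool amsgrad, float bias_correction1, float bias_correction2,
                                   float lr) {
  m = beta1 * m + (1.0f - beta1) * g;      // first moment
  v = beta2 * v + (1.0f - beta2) * g * g;  // second moment
  float denom;
  if (amsgrad) {
    max_v = std::max(max_v, v);            // keep the running maximum of v
    denom = std::sqrt(max_v) / std::sqrt(bias_correction2) + epsilon;
  } else {
    denom = std::sqrt(v) / std::sqrt(bias_correction2) + epsilon;
  }
  const float step_size = lr / bias_correction1;
  model = model - step_size * (m / denom) - lr * weight_decay * model;
}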
float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate) { const T model_val = *model; const T z_val = *z; const float lr_reciprocal = static_cast(1.0) / learning_rate; T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T accumulate_val = *accumulate; const T next_accumulate_val = accumulate_val + model_diff_t * model_diff_t; const T acc_powered = pow(accumulate_val, lr_power); const T next_acc_powered = pow(next_accumulate_val, lr_power); const T sigma = (next_acc_powered - acc_powered) * lr_reciprocal; const T new_z_val = z_val + model_diff_t - sigma * model_val; T new_model = static_cast(0.0); if (abs(new_z_val) >= lambda1) { new_model = (copysign(lambda1, new_z_val) - new_z_val) / ((beta + next_acc_powered) * lr_reciprocal + lambda2) - learning_rate * weight_decay * model_val; } *model = new_model; *accumulate = next_accumulate_val; *z = new_z_val; } }; template struct AdadeltaUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, T* square_avgs, T* acc_deltas, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate) { const T model_val = *model; T model_diff_val = *model_diff; if (maximize) { model_diff_val = -model_diff_val; } T model_diff_t = CastScaleRegularizeGradientFunctor()(model_diff_val, model_val, scale, l1, l2); T square_avgs_val = *square_avgs; T new_square_avgs_val = square_avgs_val * rho + model_diff_t * model_diff_t * (1.0f - rho); T square_avgs_std = sqrt(new_square_avgs_val + epsilon); T acc_delta_val = *acc_deltas; T delta = sqrt(acc_delta_val + epsilon) / square_avgs_std * model_diff_t; T new_acc_deltas = acc_delta_val * rho + delta * delta * (1.0f - rho); T new_model = model_val - learning_rate * delta; *model = new_model; *square_avgs = new_square_avgs_val; *acc_deltas = new_acc_deltas; } }; template struct BiasCorrectionFactorKernelUtil { public: static void BiasCorrectionFactorCompute(ep::Stream* stream, float beta, const int64_t* train_step, float* out); }; template struct MomentumUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening, bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum); }; template struct IndexedSlicesMomentumMdUpdateKernelUtil { static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values, T* model, T* momentum); }; template struct AdamUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v); }; template struct AdagradUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay, float epsilon, float weight_decay, 
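// Illustrative sketch, not part of the original source: the scalar Adadelta step encoded
// by AdadeltaUpdateFunctor, assuming g is the already scaled/regularized gradient (negated
// first when maximize is set). The helper name is an assumption made for this example.
#include <cmath>
inline void adadelta_update_one_sketch(float g, float& model, float& square_avg,
                                       float& acc_delta, float rho, float epsilon, float lr) {
  square_avg = square_avg * rho + g * g * (1.0f - rho);              // running average of g^2
  const float std_dev = std::sqrt(square_avg + epsilon);
  const float delta = std::sqrt(acc_delta + epsilon) / std_dev * g;  // rescaled step
  acc_delta = acc_delta * rho + delta * delta * (1.0f - rho);        // running average of delta^2
  model -= lr * delta;
}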
float learning_rate_val, float lr_scale, int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* sum); }; template struct IndexedSlicesAdamMdUpdateKernelUtil { static void Update(ep::Stream* stream, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float lr, float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const K* indices, const T* values, T* model, T* m, T* v, T* max_v); }; template struct LambUpdateKernelUtil { public: static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, float learning_rate_val, float lr_scale, bool do_bias_correction, float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer); }; template struct FtrlUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1, float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* accumulate, T* z); }; template struct AdadeltaUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon, bool maximize, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs, T* acc_deltas); }; template struct RmsPropUpdateFunctor { OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, int64_t n, T scale, float l1, float l2, T* mean_square, T* mean_gradient, float epsilon, float weight_decay, float decay_rate, const float learning_rate) const { const T model_val = *model; T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, *model, scale, l1, l2); T mean_square_val = *mean_square; mean_square_val = (1 - decay_rate) * model_diff_t * model_diff_t + decay_rate * mean_square_val; *mean_square = mean_square_val; T denom_t; if (centered) { T mean_gradient_val = *mean_gradient; mean_gradient_val = (1 - decay_rate) * model_diff_t + decay_rate * mean_gradient_val; *mean_gradient = mean_gradient_val; denom_t = mean_square_val - mean_gradient_val * mean_gradient_val; } else { denom_t = *mean_square; } *model = model_val - learning_rate * model_diff_t * RsqrtFunctor::Forward(denom_t + epsilon); } }; template struct RmsPropUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon, float weight_decay, float decay_rate, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* mean_square, T* mean_gradient); }; template struct LarsUpdateFunctor { OF_DEVICE_FUNC void operator()(T* model_diff_tmp, T* model, float momentum_beta, T* momentum, float weight_decay, const T local_learning_rate) const { const T model_val = *model; T next_momentum = *momentum * momentum_beta 
- local_learning_rate * *model_diff_tmp; *momentum = next_momentum; const T next_model = model_val + next_momentum - local_learning_rate * weight_decay * model_val; *model = next_model; } }; template struct LarsUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon, float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum, T* data_tmp, T* model_diff_tmp); }; #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/model_update_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/model_update_kernel_util.h" #include "oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template class TmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager); TmpBufferManager(void* ptr, const int64_t num_indices, const int64_t num_values) : ptr_(ptr) { CHECK_NE(num_indices, 0); CHECK_NE(num_values, 0); const size_t unique_diff_indices_bytes = GetCudaAlignedSize(num_indices * sizeof(K)); const size_t unique_diff_values_bytes = GetCudaAlignedSize(num_values * sizeof(T)); const size_t num_unique_diff_indices_bytes = GetCudaAlignedSize(1 * sizeof(int32_t)); CHECK_EQ(num_values % num_indices, 0); IndexedSlicesReduceSumKernelUtil::GetReduceSumWorkspaceSizeInBytes( nullptr, num_indices, num_values / num_indices, &unique_workspace_bytes_); unique_diff_indices_offset_ = 0; unique_diff_values_offset_ = unique_diff_indices_offset_ + unique_diff_indices_bytes; num_unique_diff_indices_offset_ = unique_diff_values_offset_ + unique_diff_values_bytes; unique_workspace_offset_ = num_unique_diff_indices_offset_ + num_unique_diff_indices_bytes; CHECK_GE(unique_workspace_bytes_, 0); total_buffer_size_ = unique_diff_indices_bytes + unique_diff_values_bytes + num_unique_diff_indices_bytes + static_cast(unique_workspace_bytes_); } ~TmpBufferManager() = default; int64_t UniqueWorkspaceBytes() const { return unique_workspace_bytes_; } size_t GetTotalBufferSize() const { return total_buffer_size_; } K* UniqueDiffIndicesPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + unique_diff_indices_offset_); } T* UniqueDiffValuesPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + unique_diff_values_offset_); } int32_t* NumUniqueDiffIndicesPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + num_unique_diff_indices_offset_); } char* UniqueWorkspacePtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(ptr_) + unique_workspace_offset_; } private: size_t 
unique_diff_indices_offset_; size_t unique_diff_values_offset_; size_t num_unique_diff_indices_offset_; size_t unique_workspace_offset_; int64_t unique_workspace_bytes_; size_t total_buffer_size_; void* ptr_; }; class IndexedSlicesUpdateOpKernelCache final : public user_op::OpKernelCache { public: IndexedSlicesUpdateOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} ~IndexedSlicesUpdateOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } private: const int64_t lower_; const int64_t upper_; }; std::shared_ptr CreateIndexedSlicesUpdateOpKernelCache( user_op::KernelCacheContext* ctx) { const SbpParallel& model_sbp = ctx->SbpParallel4ArgNameAndIndex("model", 0); const user_op::TensorDesc* model_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("model", 0); const int64_t num_model_instances = model_logical_desc->shape().At(0); if (model_sbp.has_split_parallel() && model_sbp.split_parallel().axis() == 0 && ctx->parallel_ctx().parallel_num() > 1) { CHECK(ctx->SbpParallel4ArgNameAndIndex("model_diff_indices", 0).has_broadcast_parallel()); CHECK(ctx->SbpParallel4ArgNameAndIndex("model_diff_values", 0).has_broadcast_parallel()); BalancedSplitter bs(num_model_instances, ctx->parallel_ctx().parallel_num()); return std::make_shared( bs.At(ctx->parallel_ctx().parallel_id()).begin(), bs.At(ctx->parallel_ctx().parallel_id()).end()); } else { return std::make_shared(0, num_model_instances); } } template class SGDUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SGDUpdateKernel() = default; ~SGDUpdateKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); const auto scale = ctx->Attr("scale"); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); const float* learning_rate_ptr = nullptr; C* model_copy_ptr = nullptr; if (ctx->has_input("model_copy", 0)) { user_op::Tensor* model_copy = ctx->Tensor4ArgNameAndIndex("model_copy", 0); model_copy_ptr = model_copy->mut_dptr(); } if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } SGDUpdateKernelUtil::Update( ctx->stream(), model->shape_view().elem_cnt(), static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, model_diff->dptr(), model->mut_dptr(), model_copy_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_SGD_UPDATE_KERNEL(device, dtype, gtype, ctype) \ 
REGISTER_USER_KERNEL("sgd_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)); REGISTER_SGD_UPDATE_KERNEL(DeviceType::kCPU, float, float, float16); REGISTER_SGD_UPDATE_KERNEL(DeviceType::kCPU, double, double, float16); #ifdef WITH_CUDA REGISTER_SGD_UPDATE_KERNEL(DeviceType::kCUDA, float, float16, float16); REGISTER_SGD_UPDATE_KERNEL(DeviceType::kCUDA, float, float, float16); REGISTER_SGD_UPDATE_KERNEL(DeviceType::kCUDA, double, double, float16); #endif // WITH_CUDA template user_op::InferTmpSizeFn GenInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const user_op::TensorDesc& indices = ctx->InputTensorDesc("model_diff_indices", 0); const user_op::TensorDesc& values = ctx->InputTensorDesc("model_diff_values", 0); const int64_t num_indices = indices.shape().elem_cnt(); const int64_t num_values = values.shape().elem_cnt(); TmpBufferManager buffer_manager(nullptr, num_indices, num_values); return buffer_manager.GetTotalBufferSize(); }; } template class IndexedSlicesSGDUpdateKernel final : public user_op::OpKernel { public: IndexedSlicesSGDUpdateKernel() = default; ~IndexedSlicesSGDUpdateKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateIndexedSlicesUpdateOpKernelCache(ctx); } private: using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil; using MdUpdateUtilT = IndexedSlicesSGDUpdateKernelUtil; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); const user_op::Tensor* model_diff_indices = ctx->Tensor4ArgNameAndIndex("model_diff_indices", 0); const user_op::Tensor* model_diff_values = ctx->Tensor4ArgNameAndIndex("model_diff_values", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); const auto weight_decay = ctx->Attr("weight_decay"); const auto lr_scale = ctx->Attr("learning_rate_scale"); const int64_t num_indices = model_diff_indices->shape_view().elem_cnt(); const int64_t num_values = model_diff_values->shape_view().elem_cnt(); if (num_indices == 0) { CHECK_EQ(num_values, 0); return; } CHECK_NE(num_values, 0); CHECK_EQ(num_values % num_indices, 0); const int64_t feature_size = num_values / num_indices; auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(model->shape_view().At(0), kernel_cache->upper() - kernel_cache->lower()); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buffer_manager(tmp_buffer->mut_dptr(), num_indices, num_values); CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.GetTotalBufferSize()); ReduceSumUtilT::ReduceSum( ctx->stream(), num_indices, feature_size, model_diff_indices->dptr(), model_diff_values->dptr(), buffer_manager.NumUniqueDiffIndicesPtr(), buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes()); MdUpdateUtilT::Update(ctx->stream(), weight_decay, lr_scale, num_indices, feature_size, kernel_cache->lower(), kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate->dptr(), buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } 
}; #define REGISTER_INDEXED_SLICES_SGD_UPDATE_KERNEL(device_type_v, data_type_pair, \ indices_type_pair) \ REGISTER_USER_KERNEL("indexed_slices_sgd_update") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("model", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("model_diff_values", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("model_diff_indices", 0) \ == OF_PP_PAIR_SECOND(indices_type_pair))) \ .SetInferTmpSizeFn(GenInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_INDEXED_SLICES_SGD_UPDATE_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class MomentumUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MomentumUpdateKernel() = default; ~MomentumUpdateKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { float learning_rate_val = ctx->Attr("learning_rate_val"); double scale = ctx->Attr("scale"); float l1 = ctx->Attr("l1"); float l2 = ctx->Attr("l2"); float beta = ctx->Attr("beta"); const float dampening = ctx->Attr("dampening"); const bool nesterov = ctx->Attr("nesterov"); const bool maximize = ctx->Attr("maximize"); float weight_decay = ctx->Attr("weight_decay"); const auto lr_scale = ctx->Attr("learning_rate_scale"); const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* momentum = ctx->Tensor4ArgNameAndIndex("momentum", 0); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } MomentumUpdateKernelUtil::Update( ctx->stream(), model->shape_view().elem_cnt(), static_cast(scale), l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, model_diff->dptr(), model->mut_dptr(), momentum->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_MOMENTUM_UPDATE_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("momentum_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)); REGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCPU, float, float); REGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCPU, double, double); #ifdef WITH_CUDA REGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, float, float16); REGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif // WITH_CUDA template class IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel { public: IndexedSlicesMomentumUpdateKernel() = 
default; ~IndexedSlicesMomentumUpdateKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateIndexedSlicesUpdateOpKernelCache(ctx); } private: using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil; using MdUpdateUtilT = IndexedSlicesMomentumMdUpdateKernelUtil; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); const user_op::Tensor* model_diff_indices = ctx->Tensor4ArgNameAndIndex("model_diff_indices", 0); const user_op::Tensor* model_diff_values = ctx->Tensor4ArgNameAndIndex("model_diff_values", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* momentum = ctx->Tensor4ArgNameAndIndex("momentum", 0); const auto beta = ctx->Attr("beta"); const float dampening = ctx->Attr("dampening"); const bool nesterov = ctx->Attr("nesterov"); const bool maximize = ctx->Attr("maximize"); const auto weight_decay = ctx->Attr("weight_decay"); const float lr_scale = ctx->Attr("learning_rate_scale"); const int64_t num_indices = model_diff_indices->shape_view().elem_cnt(); const int64_t num_values = model_diff_values->shape_view().elem_cnt(); if (num_indices == 0) { CHECK_EQ(num_values, 0); return; } CHECK_NE(num_values, 0); CHECK_EQ(num_values % num_indices, 0); const int64_t feature_size = num_values / num_indices; CHECK_EQ(feature_size, model_diff_values->shape_view().Count(model_diff_indices->shape_view().NumAxes())); auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(model->shape_view().At(0), kernel_cache->upper() - kernel_cache->lower()); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buffer_manager(tmp_buffer->mut_dptr(), num_indices, num_values); CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.GetTotalBufferSize()); ReduceSumUtilT::ReduceSum( ctx->stream(), num_indices, feature_size, model_diff_indices->dptr(), model_diff_values->dptr(), buffer_manager.NumUniqueDiffIndicesPtr(), buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes()); MdUpdateUtilT::Update(ctx->stream(), beta, dampening, nesterov, maximize, weight_decay, lr_scale, num_indices, feature_size, kernel_cache->lower(), kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate->dptr(), buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr(), momentum->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_INDEXED_SLICES_MOMENTUM_UPDATE_KERNEL(device_type_v, data_type_pair, \ indices_type_pair) \ REGISTER_USER_KERNEL("indexed_slices_momentum_update") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("model", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("model_diff_values", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("model_diff_indices", 0) \ == OF_PP_PAIR_SECOND(indices_type_pair))) \ .SetInferTmpSizeFn(GenInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_INDEXED_SLICES_MOMENTUM_UPDATE_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class AdamUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: 
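  // Dense Adam update, dispatched to AdamUpdateKernelUtil (declared in
  // model_update_kernel_util.h). With the scaled/regularized gradient g it follows the
  // standard Adam/AMSGrad recurrences:
  //   m <- beta1 * m + (1 - beta1) * g
  //   v <- beta2 * v + (1 - beta2) * g * g        (max_v <- max(max_v, v) when amsgrad)
  //   step = (m / bias_correction1) / (sqrt(v / bias_correction2) + epsilon)   [when do_bias_correction]
  // plus a weight_decay term applied to the parameter itself.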
  AdamUpdateKernel() = default;
  ~AdamUpdateKernel() override = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0);
    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0);
    user_op::Tensor* m = ctx->Tensor4ArgNameAndIndex("m", 0);
    user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex("v", 0);
    const auto scale = ctx->Attr<double>("scale");
    const auto l1 = ctx->Attr<float>("l1");
    const auto l2 = ctx->Attr<float>("l2");
    const auto beta1 = ctx->Attr<float>("beta1");
    const auto beta2 = ctx->Attr<float>("beta2");
    const auto epsilon = ctx->Attr<float>("epsilon");
    const auto weight_decay = ctx->Attr<float>("weight_decay");
    const bool amsgrad = ctx->Attr<bool>("amsgrad");
    const bool do_bias_correction = ctx->Attr<bool>("do_bias_correction");
    const float lr_scale = ctx->Attr<float>("learning_rate_scale");
    T* max_v_ptr = nullptr;
    if (amsgrad) {
      user_op::Tensor* max_v = ctx->Tensor4ArgNameAndIndex("max_v", 0);
      max_v_ptr = max_v->mut_dptr<T>();
      CHECK(max_v_ptr != nullptr);
    }
    const float learning_rate_val = ctx->Attr<float>("learning_rate_val");
    const float* learning_rate_ptr = nullptr;
    if (ctx->has_input("learning_rate", 0)) {
      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0);
      learning_rate_ptr = learning_rate->dptr<float>();
    }
    const float bias_correction1_val = ctx->Attr<float>("bias_correction1_val");
    const float* bias_correction1_ptr = nullptr;
    if (ctx->has_input("bias_correction1", 0)) {
      const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex("bias_correction1", 0);
      CHECK_EQ(bias_correction1->shape_view().elem_cnt(), 1);  // Just for Lazy Optional Input Check.
      bias_correction1_ptr = bias_correction1->dptr<float>();
    }
    const float bias_correction2_val = ctx->Attr<float>("bias_correction2_val");
    const float* bias_correction2_ptr = nullptr;
    if (ctx->has_input("bias_correction2", 0)) {
      const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex("bias_correction2", 0);
      CHECK_EQ(bias_correction2->shape_view().elem_cnt(), 1);  // Just for Lazy Optional Input Check.
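      // In graph (lazy) mode the bias correction factors arrive as scalar tensors, typically
      // produced by the adam_bias_correction_factor op registered below; otherwise the kernel
      // util falls back to the bias_correction*_val attributes.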
bias_correction2_ptr = bias_correction2->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } C* model_copy_ptr = nullptr; if (ctx->has_input("model_copy", 0)) { user_op::Tensor* model_copy = ctx->Tensor4ArgNameAndIndex("model_copy", 0); model_copy_ptr = model_copy->mut_dptr(); } AdamUpdateKernelUtil::Update( ctx->stream(), model->shape_view().elem_cnt(), static_cast(scale), l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, lr_scale, bias_correction1_val, bias_correction2_val, learning_rate_ptr, scale_by_ptr, skip_if_ptr, bias_correction1_ptr, bias_correction2_ptr, model_diff->dptr(), model->mut_dptr(), model_copy_ptr, m->mut_dptr(), v->mut_dptr(), max_v_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_ADAM_UPDATE_KERNEL(device, dtype, gtype, ctype) \ REGISTER_USER_KERNEL("adam_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)); REGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCPU, float, float, float16); REGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCPU, double, double, float16); #ifdef WITH_CUDA REGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, float, float16, float16); REGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, float, float, float16); REGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, double, double, float16); #endif // WITH_CUDA template class AdagradUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: AdagradUpdateKernel() = default; ~AdagradUpdateKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* sum = ctx->Tensor4ArgNameAndIndex("sum", 0); const auto scale = ctx->Attr("scale"); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); const auto lr_decay = ctx->Attr("lr_decay"); const auto epsilon = ctx->Attr("epsilon"); const auto weight_decay = ctx->Attr("weight_decay"); const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; const int64_t train_step_val = ctx->Attr("train_step_val"); const int64_t* train_step_ptr = nullptr; const float lr_scale = ctx->Attr("learning_rate_scale"); if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } if (ctx->has_input("train_step", 0)) { const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex("train_step", 0); train_step_ptr = train_step->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); 
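      // scale_by_tensor, when present, supplies a runtime gradient scaling factor (e.g. for
      // dynamic loss scaling or gradient accumulation); it must be a scalar with the same
      // dtype as the model, as checked here.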
CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } AdagradUpdateKernelUtil::Update( ctx->stream(), model->shape_view().elem_cnt(), static_cast(scale), l1, l2, lr_decay, epsilon, weight_decay, learning_rate_val, lr_scale, train_step_val, learning_rate_ptr, train_step_ptr, scale_by_ptr, skip_if_ptr, model_diff->dptr(), model->mut_dptr(), sum->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_ADAGRAD_UPDATE_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("adagrad_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)); REGISTER_ADAGRAD_UPDATE_KERNEL(DeviceType::kCPU, float, float); REGISTER_ADAGRAD_UPDATE_KERNEL(DeviceType::kCPU, double, double); #ifdef WITH_CUDA REGISTER_ADAGRAD_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_ADAGRAD_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif // WITH_CUDA template class IndexedSlicesAdamUpdateKernel final : public user_op::OpKernel { public: IndexedSlicesAdamUpdateKernel() = default; ~IndexedSlicesAdamUpdateKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateIndexedSlicesUpdateOpKernelCache(ctx); } private: using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil; using MdUpdateUtilT = IndexedSlicesAdamMdUpdateKernelUtil; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const float* bias_correction1_ptr = nullptr; if (ctx->has_input("bias_correction1", 0)) { const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex("bias_correction1", 0); CHECK_EQ(bias_correction1->shape_view().elem_cnt(), 1); bias_correction1_ptr = bias_correction1->dptr(); } const float* bias_correction2_ptr = nullptr; if (ctx->has_input("bias_correction2", 0)) { const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex("bias_correction2", 0); CHECK_EQ(bias_correction2->shape_view().elem_cnt(), 1); bias_correction2_ptr = bias_correction2->dptr(); } const user_op::Tensor* model_diff_indices = ctx->Tensor4ArgNameAndIndex("model_diff_indices", 0); const user_op::Tensor* model_diff_values = ctx->Tensor4ArgNameAndIndex("model_diff_values", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* m = ctx->Tensor4ArgNameAndIndex("m", 0); user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex("v", 0); const auto beta1 = ctx->Attr("beta1"); const auto beta2 = ctx->Attr("beta2"); const auto epsilon = ctx->Attr("epsilon"); const auto weight_decay = ctx->Attr("weight_decay"); const bool amsgrad = ctx->Attr("amsgrad"); const bool do_bias_correction = ctx->Attr("do_bias_correction"); const float lr_scale = ctx->Attr("learning_rate_scale"); T* max_v_ptr = nullptr; if (amsgrad) { user_op::Tensor* max_v = 
ctx->Tensor4ArgNameAndIndex("max_v", 0); max_v_ptr = max_v->mut_dptr(); } auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(model->shape_view().At(0), kernel_cache->upper() - kernel_cache->lower()); const int64_t num_indices = model_diff_indices->shape_view().elem_cnt(); const int64_t num_values = model_diff_values->shape_view().elem_cnt(); if (num_indices == 0) { CHECK_EQ(num_values, 0); return; } CHECK_NE(num_values, 0); CHECK_EQ(num_values % num_indices, 0); const int64_t feature_size = num_values / num_indices; CHECK_EQ(feature_size, model_diff_values->shape_view().Count(model_diff_indices->shape_view().NumAxes())); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buffer_manager(tmp_buffer->mut_dptr(), num_indices, num_values); CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.GetTotalBufferSize()); ReduceSumUtilT::ReduceSum( ctx->stream(), num_indices, feature_size, model_diff_indices->dptr(), model_diff_values->dptr(), buffer_manager.NumUniqueDiffIndicesPtr(), buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes()); MdUpdateUtilT::Update( ctx->stream(), beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, lr_scale, num_indices, feature_size, kernel_cache->lower(), kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate_ptr, bias_correction1_ptr, bias_correction2_ptr, buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr(), m->mut_dptr(), v->mut_dptr(), max_v_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_INDEXED_SLICES_ADAM_UPDATE_KERNEL(device_type_v, data_type_pair, \ indices_type_pair) \ REGISTER_USER_KERNEL("indexed_slices_adam_update") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("model", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("model_diff_values", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("model_diff_indices", 0) \ == OF_PP_PAIR_SECOND(indices_type_pair))) \ .SetInferTmpSizeFn(GenInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_INDEXED_SLICES_ADAM_UPDATE_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class LambTmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(LambTmpBufferManager); LambTmpBufferManager(void* ptr, const int64_t n) : ptr_(ptr) { const size_t adam_diff_bytes = GetCudaAlignedSize(n * sizeof(T)); norm_buffer_bytes_ = GetCudaAlignedSize(2 * sizeof(T)); adam_diff_offset_ = 0; norm_buffer_offset_ = adam_diff_offset_ + adam_diff_bytes; total_buffer_size_ = adam_diff_bytes + norm_buffer_bytes_; } ~LambTmpBufferManager() = default; size_t GetNormBufferSize() const { return norm_buffer_bytes_; } size_t GetTotalBufferSize() const { return total_buffer_size_; } T* AdamDiffPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + adam_diff_offset_); } T* NormBufferPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + norm_buffer_offset_); } private: size_t adam_diff_offset_; size_t norm_buffer_offset_; size_t total_buffer_size_; size_t norm_buffer_bytes_; void* ptr_; }; template class LambUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: LambUpdateKernel() = default; 
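  // LAMB update: an Adam-style step is first materialized in the adam_diff scratch region of
  // LambTmpBufferManager, and the parameter is then moved by that step scaled with a trust
  // ratio derived from the two norms (parameter norm and update norm) held in norm_buffer.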
~LambUpdateKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* m = ctx->Tensor4ArgNameAndIndex("m", 0); user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex("v", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); LambTmpBufferManager tbm(tmp_buffer->mut_dptr(), model->shape_view().elem_cnt()); const auto scale = ctx->Attr("scale"); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); const auto beta1 = ctx->Attr("beta1"); const auto beta2 = ctx->Attr("beta2"); const auto epsilon = ctx->Attr("epsilon"); const auto weight_decay = ctx->Attr("weight_decay"); const auto lr_scale = ctx->Attr("learning_rate_scale"); const bool do_bias_correction = ctx->Attr("do_bias_correction"); const float bias_correction1_val = ctx->Attr("bias_correction1_val"); const float* bias_correction1_ptr = nullptr; if (ctx->has_input("bias_correction1", 0)) { const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex("bias_correction1", 0); // Just for Lazy optional input check. CHECK_EQ(bias_correction1->shape_view().elem_cnt(), 1); bias_correction1_ptr = bias_correction1->dptr(); } const float bias_correction2_val = ctx->Attr("bias_correction2_val"); const float* bias_correction2_ptr = nullptr; if (ctx->has_input("bias_correction2", 0)) { const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex("bias_correction2", 0); CHECK_EQ(bias_correction2->shape_view().elem_cnt(), 1); bias_correction2_ptr = bias_correction2->dptr(); } const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } LambUpdateKernelUtil::Update( ctx->stream(), m->shape_view().elem_cnt(), scale, l1, l2, beta1, beta2, epsilon, weight_decay, learning_rate_val, lr_scale, do_bias_correction, bias_correction1_val, bias_correction2_val, learning_rate_ptr, bias_correction1_ptr, bias_correction2_ptr, scale_by_ptr, skip_if_ptr, model_diff->dptr(), tbm.AdamDiffPtr(), model->mut_dptr(), m->mut_dptr(), v->mut_dptr(), tbm.NormBufferPtr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; template user_op::InferTmpSizeFn LambGenInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); LambTmpBufferManager tbm(nullptr, model.shape().elem_cnt()); return tbm.GetTotalBufferSize(); }; } #define REGISTER_LAMB_UPDATE_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("lamb_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && 
(user_op::HobDataType("model_diff", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(LambGenInferTmpSizeFn()); REGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCPU, float, float); REGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCPU, double, double); #ifdef WITH_CUDA REGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCUDA, float, float16); REGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif // WITH_CUDA template class BiasCorrectionFactorKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: BiasCorrectionFactorKernel() = default; ~BiasCorrectionFactorKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex("train_step", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const auto beta = ctx->Attr("beta"); BiasCorrectionFactorKernelUtil::BiasCorrectionFactorCompute( ctx->stream(), beta, train_step->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_ADAM_BIAS_CORRECTION_FACTOR_KERNEL(device) \ REGISTER_USER_KERNEL("adam_bias_correction_factor") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device)); REGISTER_ADAM_BIAS_CORRECTION_FACTOR_KERNEL(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_ADAM_BIAS_CORRECTION_FACTOR_KERNEL(DeviceType::kCUDA) #endif // WITH_CUDA template class RmsPropUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: RmsPropUpdateKernel() = default; ~RmsPropUpdateKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* mean_square = ctx->Tensor4ArgNameAndIndex("mean_square", 0); const auto scale = ctx->Attr("scale"); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); const auto decay_rate = ctx->Attr("decay_rate"); const auto epsilon = ctx->Attr("epsilon"); const auto centered = ctx->Attr("centered"); const auto weight_decay = ctx->Attr("weight_decay"); const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } T* mean_gradient_ptr = nullptr; if (centered) { user_op::Tensor* mean_gradient = ctx->Tensor4ArgNameAndIndex("mean_gradient", 0); mean_gradient_ptr = mean_gradient->mut_dptr(); } RmsPropUpdateKernelUtil::Update( ctx->stream(), model->shape_view().elem_cnt(), static_cast(scale), l1, l2, centered, epsilon, weight_decay, decay_rate, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, 
skip_if_ptr, model_diff->dptr(), model->mut_dptr(), mean_square->mut_dptr(), mean_gradient_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_RMSPROP_UPDATE_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("rmsprop_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)); REGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCPU, float, float); REGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCPU, double, double); #ifdef WITH_CUDA REGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCUDA, float, float16); REGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif // WITH_CUDA template class LarsTmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(LarsTmpBufferManager); LarsTmpBufferManager(void* ptr, const int64_t n) : ptr_(ptr) { model_diff_size_ = GetCudaAlignedSize(n * sizeof(T)); model_diff_offset_ = 0; data_tmp_size_ = GetCudaAlignedSize(3 * sizeof(T)); data_tmp_offset_ = model_diff_offset_ + model_diff_size_; total_buffer_size_ = model_diff_size_ + data_tmp_size_; } ~LarsTmpBufferManager() = default; size_t GetTotalBufferSize() const { return total_buffer_size_; } size_t GetDataTmpBufferSize() const { return data_tmp_size_; } T* ModelDiffPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + model_diff_offset_); } T* DataTmpPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + data_tmp_offset_); } private: size_t model_diff_offset_; size_t model_diff_size_; size_t data_tmp_offset_; size_t data_tmp_size_; size_t total_buffer_size_; void* ptr_; }; template class LarsUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: LarsUpdateKernel() = default; ~LarsUpdateKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* momentum = ctx->Tensor4ArgNameAndIndex("momentum", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); LarsTmpBufferManager tlm(tmp_buffer->mut_dptr(), model->shape_view().elem_cnt()); const auto scale = ctx->Attr("scale"); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); const auto momentum_beta = ctx->Attr("momentum_beta"); const auto epsilon = ctx->Attr("epsilon"); const auto lars_coefficient = ctx->Attr("lars_coefficient"); const auto weight_decay = ctx->Attr("weight_decay"); const auto lr_scale = ctx->Attr("learning_rate_scale"); const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } LarsUpdateKernelUtil::Update( ctx->stream(), model->shape_view().elem_cnt(), 
static_cast(scale), l1, l2, momentum_beta, epsilon, lars_coefficient, weight_decay, lr_scale, learning_rate->dptr(), scale_by_ptr, skip_if_ptr, model_diff->dptr(), model->mut_dptr(), momentum->mut_dptr(), tlm.DataTmpPtr(), tlm.ModelDiffPtr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; template user_op::InferTmpSizeFn LarsGenInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); LarsTmpBufferManager tlm(nullptr, model.shape().elem_cnt()); return tlm.GetTotalBufferSize(); }; } #define REGISTER_LARS_UPDATE_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("lars_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(LarsGenInferTmpSizeFn()); REGISTER_LARS_UPDATE_KERNEL(DeviceType::kCPU, float, float); REGISTER_LARS_UPDATE_KERNEL(DeviceType::kCPU, double, double); #ifdef WITH_CUDA REGISTER_LARS_UPDATE_KERNEL(DeviceType::kCUDA, float, float16); REGISTER_LARS_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_LARS_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif // WITH_CUDA template class FtrlUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FtrlUpdateKernel() = default; ~FtrlUpdateKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* accumulate = ctx->Tensor4ArgNameAndIndex("accumulate", 0); user_op::Tensor* z = ctx->Tensor4ArgNameAndIndex("z", 0); const auto scale = ctx->Attr("scale"); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); const float lr_power = ctx->Attr("lr_power"); const float lambda1 = ctx->Attr("lambda1"); const float lambda2 = ctx->Attr("lambda2"); const float beta = ctx->Attr("beta"); const float weight_decay = ctx->Attr("weight_decay"); // TODO(zhengzekang): Undefined behavior for ftrl optimizer with weight_decay in `abs(new_z_val) // < lambda1` condition. CHECK_EQ(weight_decay, static_cast(0.0)) << "Currently not support for setting weight decay. 
"; const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } FtrlUpdateKernelUtil::Update( ctx->stream(), model->shape_view().elem_cnt(), static_cast(scale), l1, l2, lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, model_diff->dptr(), model->mut_dptr(), accumulate->mut_dptr(), z->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_FTRL_UPDATE_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("ftrl_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)); REGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCPU, float, float); REGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCPU, double, double); #ifdef WITH_CUDA REGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCUDA, float, float16); REGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif // WITH_CUDA template class AdadeltaUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: AdadeltaUpdateKernel() = default; ~AdadeltaUpdateKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* square_avgs = ctx->Tensor4ArgNameAndIndex("square_avgs", 0); user_op::Tensor* acc_deltas = ctx->Tensor4ArgNameAndIndex("acc_deltas", 0); const auto scale = ctx->Attr("scale"); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); const float rho = ctx->Attr("rho"); const float epsilon = ctx->Attr("epsilon"); const bool maximize = ctx->Attr("maximize"); const float weight_decay = ctx->Attr("weight_decay"); const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), model->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if 
          = ctx->Tensor4ArgNameAndIndex("skip_if", 0);
      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);
      skip_if_ptr = skip_if->dptr<int64_t>();
    }
    AdadeltaUpdateKernelUtil<device_type, T, G>::Update(
        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, rho,
        epsilon, maximize, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr,
        scale_by_ptr, skip_if_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(),
        square_avgs->mut_dptr<T>(), acc_deltas->mut_dptr<T>());
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }
};

#define REGISTER_ADADELTA_UPDATE_KERNEL(device, dtype, gtype)                                   \
  REGISTER_USER_KERNEL("adadelta_update")                                                       \
      .SetCreateFn<AdadeltaUpdateKernel<device, dtype, gtype>>()                                \
      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                     \
                       && (user_op::HobDataType("model", 0) == GetDataType<dtype>::value)       \
                       && (user_op::HobDataType("model_diff", 0) == GetDataType<gtype>::value));

REGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCPU, float, float);
REGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCPU, double, double);
#ifdef WITH_CUDA
REGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);
REGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCUDA, float, float);
REGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCUDA, double, double);
#endif  // WITH_CUDA

}  // namespace

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/moving_average_min_max_observer_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include namespace oneflow { template void GenQuantScalePerLayerSymmetric(const T* in, const int64_t current_train_step, const int64_t stop_update_after_iters, const bool is_training, const int32_t quantization_bit, const int64_t num_elements, const float momentum, T* moving_max, T* moving_min, T* scale, T* zero_point) { if (current_train_step <= stop_update_after_iters && is_training) { T in_max = *std::max_element(in, in + num_elements); T in_min = *std::min_element(in, in + num_elements); in_max = std::max(std::abs(in_max), std::abs(in_min)); T moving_max_val = *moving_max; if (moving_max_val == 0) { *moving_max = in_max; } else { *moving_max = moving_max_val * momentum + in_max * (1 - momentum); } // NOTE(Liang Depeng): symmetric quantization only use moving_max to calculate the scale *moving_min = *moving_max; } T denominator = static_cast(pow(2.0, quantization_bit - 1)) - 1; *scale = (*moving_max) / denominator; *zero_point = 0; } template void GenQuantScalePerLayerAffine(const T* in, const int64_t current_train_step, const int64_t stop_update_after_iters, const bool is_training, const int32_t quantization_bit, const int64_t num_elements, const float momentum, T* moving_max, T* moving_min, T* scale, T* zero_point) { if (current_train_step <= stop_update_after_iters && is_training) { T in_max = *std::max_element(in, in + num_elements); T in_min = *std::min_element(in, in + num_elements); T moving_max_val = *moving_max; if (moving_max_val == 0) { *moving_max = in_max; } else { *moving_max = moving_max_val * momentum + in_max * (1 - momentum); } T moving_min_val = *moving_min; if (moving_min_val == 0) { *moving_min = in_min; } else { *moving_min = moving_min_val * momentum + in_min * (1 - momentum); } } T denominator = static_cast(pow(2.0, quantization_bit)) - 1; *scale = ((*moving_max) - (*moving_min)) / denominator; *zero_point = -std::round((*moving_min) / (*scale)); } template void GenQuantScalePerLayerCambricon(const T* in, const int64_t current_train_step, const int64_t stop_update_after_iters, const bool is_training, const int32_t quantization_bit, const int64_t num_elements, const float momentum, T* moving_max, T* moving_min, T* scale, T* zero_point) { if (current_train_step <= stop_update_after_iters && is_training) { T in_max = *std::max_element(in, in + num_elements); T in_min = *std::min_element(in, in + num_elements); in_max = std::max(std::abs(in_max), std::abs(in_min)); T moving_max_val = *moving_max; if (moving_max_val == 0) { *moving_max = in_max; } else { *moving_max = moving_max_val * momentum + in_max * (1 - momentum); } // NOTE(Liang Depeng): symmetric quantization only use moving_max to calculate the scale *moving_min = *moving_max; } *scale = std::floor(std::log2(*moving_max)) - (quantization_bit - 2); *zero_point = 0; } template class CpuMovingAverageMinMaxObserverKernel final : public user_op::OpKernel { public: CpuMovingAverageMinMaxObserverKernel() = default; ~CpuMovingAverageMinMaxObserverKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* current_train_step = ctx->Tensor4ArgNameAndIndex("current_train_step", 0); user_op::Tensor* moving_max = ctx->Tensor4ArgNameAndIndex("moving_max", 0); user_op::Tensor* moving_min = ctx->Tensor4ArgNameAndIndex("moving_min", 0); user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); user_op::Tensor* zero_point = 
ctx->Tensor4ArgNameAndIndex("zero_point", 0); const std::string quantization_scheme = ctx->Attr("quantization_scheme"); const int32_t quantization_bit = ctx->Attr("quantization_bit"); const float momentum = ctx->Attr("momentum"); const int64_t stop_update_after_iters = ctx->Attr("stop_update_after_iters"); const bool is_training = ctx->Attr("training"); const std::string quantization_formula = ctx->Attr("quantization_formula"); const T* in_ptr = in->dptr(); const int64_t* current_train_step_ptr = current_train_step->dptr(); T* moving_max_ptr = moving_max->mut_dptr(); T* moving_min_ptr = moving_min->mut_dptr(); T* scale_ptr = scale->mut_dptr(); T* zero_point_ptr = zero_point->mut_dptr(); int64_t num_elements = in->shape_view().elem_cnt(); if (quantization_formula == "google") { if (quantization_scheme == "symmetric") { GenQuantScalePerLayerSymmetric(in_ptr, *current_train_step_ptr, stop_update_after_iters, is_training, quantization_bit, num_elements, momentum, moving_max_ptr, moving_min_ptr, scale_ptr, zero_point_ptr); } else { // quantization_scheme == "affine" GenQuantScalePerLayerAffine(in_ptr, *current_train_step_ptr, stop_update_after_iters, is_training, quantization_bit, num_elements, momentum, moving_max_ptr, moving_min_ptr, scale_ptr, zero_point_ptr); } } else if (quantization_formula == "cambricon") { GenQuantScalePerLayerCambricon(in_ptr, *current_train_step_ptr, stop_update_after_iters, is_training, quantization_bit, num_elements, momentum, moving_max_ptr, moving_min_ptr, scale_ptr, zero_point_ptr); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(dtype) \ REGISTER_USER_KERNEL("moving_average_min_max_observer") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(float); REGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/moving_average_min_max_observer_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { // NOTE(Liang Depeng): refer to // https://stackoverflow.com/questions/17371275/implementing-max-reduce-in-cuda template __global__ void ReduceMaxMinPerLayer(const T* input_ptr, const int64_t elements, T* max_ptr, T* min_ptr) { extern __shared__ unsigned char shared_max_min_memory[]; T* shared_max = reinterpret_cast(shared_max_min_memory); T* shared_min = shared_max + blockDim.x; int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; shared_max[tid] = -FLT_MAX; shared_min[tid] = -FLT_MAX; while (gid < elements) { shared_max[tid] = max(shared_max[tid], input_ptr[gid]); shared_min[tid] = max(shared_min[tid], -input_ptr[gid]); gid += gridDim.x * blockDim.x; } __syncthreads(); gid = (blockDim.x * blockIdx.x) + tid; for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { if (tid < s && gid < elements) { shared_max[tid] = max(shared_max[tid], shared_max[tid + s]); shared_min[tid] = max(shared_min[tid], shared_min[tid + s]); } __syncthreads(); } if (tid == 0) { cuda::atomic::Max(max_ptr, shared_max[0]); cuda::atomic::Max(min_ptr, shared_min[0]); } } template __global__ void InitMaxMin(const int64_t elements, T* max_ptr, T* min_ptr) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { max_ptr[gid] = -FLT_MAX; min_ptr[gid] = -FLT_MAX; gid += gridDim.x * blockDim.x; } } template __global__ void CalScaleZeroPointSymmetric(const int64_t elements, const double quantization_bit, const float momentum, const T* max_ptr, const T* min_ptr, T* moving_max_ptr, T* moving_min_ptr, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T activation_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid])); T denominator = static_cast(pow(2.0, quantization_bit - 1)) - 1; if (moving_max_ptr[gid] == 0) moving_max_ptr[gid] = activation_max; else moving_max_ptr[gid] = moving_max_ptr[gid] * momentum + activation_max * (1 - momentum); // NOTE(Liang Depeng): symmetric quantization only use moving_max to calculate the scale moving_min_ptr[gid] = moving_max_ptr[gid]; scale[gid] = moving_max_ptr[gid] / denominator; zero_point[gid] = 0; gid += gridDim.x * blockDim.x; } } template __global__ void CalFreezeScaleZeroPointSymmetric(const int64_t elements, const double quantization_bit, const float momentum, const T* moving_max_ptr, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T denominator = static_cast(pow(2.0, quantization_bit - 1)) - 1; scale[gid] = moving_max_ptr[gid] / denominator; zero_point[gid] = 0; gid += gridDim.x * blockDim.x; } } template __global__ void CalScaleZeroPointAffine(const int64_t elements, const double quantization_bit, const float momentum, const T* max_ptr, const T* min_ptr, T* moving_max_ptr, T* moving_min_ptr, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T denominator = static_cast(pow(2.0, quantization_bit)) - 1; if (moving_max_ptr[gid] == 0) moving_max_ptr[gid] = max_ptr[gid]; else moving_max_ptr[gid] = moving_max_ptr[gid] * momentum + max_ptr[gid] * (1 - momentum); if (moving_min_ptr[gid] == 0) moving_min_ptr[gid] = -min_ptr[gid]; else moving_min_ptr[gid] = moving_min_ptr[gid] * momentum + 
-min_ptr[gid] * (1 - momentum); T min = moving_min_ptr[gid]; T s = (moving_max_ptr[gid] - min) / denominator; scale[gid] = s; zero_point[gid] = -round(min / s); gid += gridDim.x * blockDim.x; } } template __global__ void CalFreezeScaleZeroPointAffine(const int64_t elements, const double quantization_bit, const float momentum, const T* moving_max_ptr, const T* moving_min_ptr, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T denominator = static_cast(pow(2.0, quantization_bit)) - 1; T min = moving_min_ptr[gid]; T s = (moving_max_ptr[gid] - min) / denominator; scale[gid] = s; zero_point[gid] = -round(min / s); gid += gridDim.x * blockDim.x; } } template __global__ void CalScaleZeroPointCambricon(const int64_t elements, const double quantization_bit, const float momentum, const T* max_ptr, const T* min_ptr, T* moving_max_ptr, T* moving_min_ptr, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T activation_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid])); if (moving_max_ptr[gid] == 0) moving_max_ptr[gid] = activation_max; else moving_max_ptr[gid] = moving_max_ptr[gid] * momentum + activation_max * (1 - momentum); // NOTE(Liang Depeng): cambricon quantization only use moving_max to calculate the scale moving_min_ptr[gid] = moving_max_ptr[gid]; scale[gid] = floor(log2(moving_max_ptr[gid])) - (quantization_bit - 2); zero_point[gid] = 0; gid += gridDim.x * blockDim.x; } } template __global__ void CalFreezeScaleZeroPointCambricon(const int64_t elements, const double quantization_bit, const float momentum, const T* moving_max_ptr, T* scale, T* zero_point) { int64_t tid = threadIdx.x; int64_t gid = (blockDim.x * blockIdx.x) + tid; while (gid < elements) { T denominator = static_cast(pow(2.0, quantization_bit - 1)) - 1; scale[gid] = floor(log2(moving_max_ptr[gid])) - (quantization_bit - 2); zero_point[gid] = 0; gid += gridDim.x * blockDim.x; } } ep::CudaLaunchConfig GetLaunchConfig(ep::CudaStream* stream, size_t thread_num, size_t shared_mem_size) { ep::CudaLaunchConfig config; stream->InitLaunchConfigWithWaves(&config, thread_num, kCudaThreadsNumPerBlock, 1); config.shared_mem_size = shared_mem_size; return config; } } // namespace #define LAUNCH_CUDA_KERNEL(func, stream, thread_num, shared_mem_size, ...) 
\ (stream)->LaunchKernel(func, GetLaunchConfig((stream), thread_num, shared_mem_size), __VA_ARGS__); template class GpuMovingAverageMinMaxObserverKernel final : public user_op::OpKernel { public: GpuMovingAverageMinMaxObserverKernel() = default; ~GpuMovingAverageMinMaxObserverKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* current_train_step = ctx->Tensor4ArgNameAndIndex("current_train_step", 0); user_op::Tensor* moving_max = ctx->Tensor4ArgNameAndIndex("moving_max", 0); user_op::Tensor* moving_min = ctx->Tensor4ArgNameAndIndex("moving_min", 0); user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const bool is_training = ctx->Attr("training"); const int64_t stop_update_after_iters = ctx->Attr("stop_update_after_iters"); const std::string quantization_scheme = ctx->Attr("quantization_scheme"); const int32_t quantization_bit = ctx->Attr("quantization_bit"); const float momentum = ctx->Attr("momentum"); const std::string quantization_formula = ctx->Attr("quantization_formula"); int64_t elements = in->shape_view().elem_cnt(); T* max_ptr = tmp_buffer->mut_dptr(); T* min_ptr = max_ptr + 1; int64_t* host_current_train_step_ptr = new int64_t[current_train_step->shape_view().elem_cnt()]; OF_CUDA_CHECK(cudaMemcpy(host_current_train_step_ptr, current_train_step->dptr(), current_train_step->shape_view().elem_cnt() * sizeof(int64_t), cudaMemcpyDefault)); auto* cuda_stream = ctx->stream()->As(); if (*host_current_train_step_ptr <= stop_update_after_iters && is_training) { LAUNCH_CUDA_KERNEL((InitMaxMin), cuda_stream, 1, 0, 1, max_ptr, min_ptr); LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer), cuda_stream, elements, kCudaThreadsNumPerBlock * 2 * sizeof(T), in->dptr(), elements, max_ptr, min_ptr); } bool moving = (*host_current_train_step_ptr <= stop_update_after_iters) && is_training; if (quantization_formula == "google") { if (quantization_scheme == "symmetric") { if (moving) { LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, max_ptr, min_ptr, moving_max->mut_dptr(), moving_min->mut_dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } else { LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointSymmetric), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, moving_max->dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } } else { // quantization_scheme == "affine" if (moving) { LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, max_ptr, min_ptr, moving_max->mut_dptr(), moving_min->mut_dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } else { LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointAffine), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, moving_max->dptr(), moving_min->dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } } } else if (quantization_formula == "cambricon") { if (moving) { LAUNCH_CUDA_KERNEL((CalScaleZeroPointCambricon), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, max_ptr, min_ptr, moving_max->mut_dptr(), moving_min->mut_dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } else { LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointCambricon), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, 
                           moving_max->dptr<T>(), scale->mut_dptr<T>(), zero_point->mut_dptr<T>());
      }
    } else {
      UNIMPLEMENTED();
    }
    delete[] host_current_train_step_ptr;
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(dtype)                                   \
  REGISTER_USER_KERNEL("moving_average_min_max_observer")                                        \
      .SetCreateFn<GpuMovingAverageMinMaxObserverKernel<dtype>>()                                \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                           \
                       && (user_op::HobDataType("in", 0) == GetDataType<dtype>::value))          \
      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { return 2 * sizeof(dtype); })

REGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(float);
REGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(double);

}  // namespace oneflow

================================================
FILE: oneflow/user/kernels/multi_reduce_kernel_util.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
*/
#ifndef ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNEL_UTIL_H_
#define ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNEL_UTIL_H_

#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/device_type.h"
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/ep/include/stream.h"

namespace oneflow {

template<typename T>
struct MultiReduceParam {
  const T* data;
  size_t size;
};

template<DeviceType device_type, typename T, typename TransformFn, typename ReduceFn>
struct MultiReduce {
  void operator()(ep::Stream* stream, TransformFn transform,
                  const std::vector<MultiReduceParam<T>>& params, T init, T* ret, T* temp);
};

template<typename T, typename TransformFn, typename ReduceFn>
struct MultiReduce<DeviceType::kCPU, T, TransformFn, ReduceFn> {
  void operator()(ep::Stream* stream, TransformFn transform,
                  const std::vector<MultiReduceParam<T>>& params, T init, T* ret, T* temp) {
    *ret = init;
    ReduceFn reduce{};
    FOR_RANGE(size_t, i, 0, params.size()) {
      const auto& p = params[i];
      FOR_RANGE(size_t, j, 0, p.size) { *ret = reduce(*ret, transform(p.data[j])); }
    }
  }
};

template<typename T>
struct BinaryAdd {
  OF_DEVICE_FUNC T operator()(const T& x, const T& y) const { return x + y; }
};

template<typename T>
struct BinaryMax {
  OF_DEVICE_FUNC T operator()(const T& x, const T& y) const { return x > y ? x : y; }
};

template<typename T>
struct BinaryMin {
  OF_DEVICE_FUNC T operator()(const T& x, const T& y) const { return x < y ? x : y; }
};

template<typename T>
struct Abs {
  OF_DEVICE_FUNC T operator()(const T& x) const { return x < GetZeroVal<T>() ? -x : x; }
};

template<typename T>
struct PowByZero {
  OF_DEVICE_FUNC T operator()(const T& x) const { return x != GetZeroVal<T>() ? GetOneVal<T>() : x; }
};

template<typename T>
struct Square {
  OF_DEVICE_FUNC T operator()(const T& x) const { return x * x; }
};

template<typename T>
struct AbsPow {
  explicit AbsPow(const T& base) : base_(base) {}
  OF_DEVICE_FUNC T operator()(const T& x) {
    T abs_x = x < GetZeroVal<T>() ? -x : x;
#if defined(__CUDA_ARCH__)
    return pow(abs_x, base_);
#else
    return std::pow(abs_x, base_);
#endif
  }

 private:
  T base_;
};

}  // namespace oneflow

#endif  // ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNEL_UTIL_H_

================================================
FILE: oneflow/user/kernels/multi_reduce_kernels.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/multi_reduce_kernels.h" namespace oneflow { #define REGISTER_MULTI_REDUCE_SUM_POW_ABS_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("multi_reduce_sum_pow_abs") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); #define REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL(op_type_name, ximum_enum, dtype) \ REGISTER_USER_KERNEL(op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); #define REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNELS(dtype) \ REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL("multi_reduce_max_abs", Ximum::kMax, dtype) \ REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL("multi_reduce_min_abs", Ximum::kMin, dtype) \ REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL("local_multi_reduce_max_abs", Ximum::kMax, dtype) \ REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL("local_multi_reduce_min_abs", Ximum::kMin, dtype) REGISTER_MULTI_REDUCE_SUM_POW_ABS_CPU_KERNEL(float) REGISTER_MULTI_REDUCE_SUM_POW_ABS_CPU_KERNEL(double) REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNELS(float) REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNELS(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/multi_reduce_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
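The CPU path registered above is a plain sequential fold: transform every element of every input buffer, then reduce everything into a single scalar. A minimal standalone sketch of that pattern, with hypothetical names and no ep::Stream dependency:

// Illustrative sketch (not OneFlow code): transform-then-reduce over multiple buffers.
#include <cstdio>
#include <vector>

template<typename T>
struct Span { const T* data; size_t size; };

template<typename T, typename TransformFn, typename ReduceFn>
T MultiReduceCpu(TransformFn transform, ReduceFn reduce, const std::vector<Span<T>>& spans, T init) {
  T ret = init;
  for (const auto& s : spans) {
    for (size_t j = 0; j < s.size; ++j) { ret = reduce(ret, transform(s.data[j])); }
  }
  return ret;
}

int main() {
  std::vector<float> a = {1.0f, -2.0f, 3.0f};
  std::vector<float> b = {-4.0f, 0.5f};
  std::vector<Span<float>> spans = {{a.data(), a.size()}, {b.data(), b.size()}};
  auto abs_fn = [](float x) { return x < 0.0f ? -x : x; };
  auto add_fn = [](float x, float y) { return x + y; };
  // Sum of absolute values across both tensors: 1 + 2 + 3 + 4 + 0.5 = 10.5
  std::printf("sum_abs = %f\n", MultiReduceCpu<float>(abs_fn, add_fn, spans, 0.0f));
  return 0;
}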
*/ #include "oneflow/user/kernels/multi_reduce_kernels.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/device/cuda_util.h" #include #include namespace oneflow { namespace { constexpr int64_t kMultiReduceMaxPackSize = 64; template struct MultiReduceParamsPack { MultiReduceParam params[kMultiReduceMaxPackSize]; size_t size; }; template __global__ void MultiBlockReduceGpu(TransformFn transform, const MultiReduceParamsPack pack_params, const T init, T* out) { ReduceFn reduce_fn{}; T t_out = init; for (int i = 0; i < pack_params.size; ++i) { const auto& param = pack_params.params[i]; CUDA_1D_KERNEL_LOOP(j, param.size) { t_out = reduce_fn(t_out, transform(param.data[j])); } } typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; T b_out = BlockReduce(temp_storage).Reduce(t_out, reduce_fn); if (threadIdx.x == 0) { out[blockIdx.x] = b_out; } } size_t InferTempStorageSize(user_op::InferContext* ctx) { auto input_size = ctx->input_size("x"); if (input_size == 0) { return 0; } int64_t max_elem_cnt = 0; int64_t pack_size = 0; int32_t num_blocks = 0; for (size_t i = 0; i < input_size; ++i) { int64_t elem_cnt = ctx->InputShape("x", i).elem_cnt(); max_elem_cnt = std::max(max_elem_cnt, elem_cnt); pack_size++; if (pack_size == kMultiReduceMaxPackSize || i == input_size - 1) { CHECK_LT(max_elem_cnt, std::numeric_limits::max()); num_blocks += BlocksNum4ThreadsNum(static_cast(max_elem_cnt)); max_elem_cnt = 0; pack_size = 0; } } CHECK_LT(num_blocks, kCudaThreadsNumPerBlock * kCudaThreadsNumPerBlock * kCudaThreadsNumPerBlock) << "Too much blocks needed for computing " << ctx->op_name() << ", should be less than " << kCudaThreadsNumPerBlock << "*" << kCudaThreadsNumPerBlock << "*" << kCudaThreadsNumPerBlock << ", but got " << num_blocks; size_t elem_size = GetSizeOfDataType(ctx->InputDType("x", 0)); return GetCudaAlignedSize(num_blocks * elem_size * 2); } } // namespace template struct MultiReduce { void operator()(ep::Stream* stream, TransformFn transform, const std::vector>& params, T init, T* ret, T* temp) { CHECK_NOTNULL(temp); int32_t total_num_blocks = 0; for (size_t i = 0; i < params.size(); i += kMultiReduceMaxPackSize) { MultiReduceParamsPack pack_params{}; size_t max_elem_cnt = 0; pack_params.size = std::min(kMultiReduceMaxPackSize, params.size() - i); for (size_t j = 0; j < pack_params.size; ++j) { pack_params.params[j] = params[i + j]; max_elem_cnt = std::max(max_elem_cnt, pack_params.params[j].size); } int32_t num_blocks = BlocksNum4ThreadsNum(max_elem_cnt); MultiBlockReduceGpu <<As()->cuda_stream()>>>( transform, pack_params, init, temp + total_num_blocks); total_num_blocks += num_blocks; } size_t wksp_size = 0; auto DeviceReduce = [&](void* temp_storage) -> void { OF_CUDA_CHECK(cub::DeviceReduce::Reduce(temp_storage, wksp_size, temp, ret, total_num_blocks, ReduceFn{}, init, stream->As()->cuda_stream())); }; DeviceReduce(nullptr); // NOTE(zwx): We have allocated the temp storage with the space // that can hold all the elements to reduce, // normally the `temp_storage_bytes` for cub::DeviceReduce shouldn't exceed it. 
CHECK_LE(wksp_size, total_num_blocks * sizeof(T)) << wksp_size << " size in bytes of temp storage is needed for doing cub::DeviceReduce, " << "but only allocated " << total_num_blocks * sizeof(T); DeviceReduce(temp + total_num_blocks); } }; #define REGISTER_MULTI_REDUCE_SUM_POW_ABS_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("multi_reduce_sum_pow_abs") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferTempStorageSize); #define REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL(op_type_name, ximum_enum, dtype) \ REGISTER_USER_KERNEL(op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferTempStorageSize); #define REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNELS(dtype) \ REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL("multi_reduce_max_abs", Ximum::kMax, dtype) \ REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL("multi_reduce_min_abs", Ximum::kMin, dtype) \ REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL("local_multi_reduce_max_abs", Ximum::kMax, dtype) \ REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL("local_multi_reduce_min_abs", Ximum::kMin, dtype) REGISTER_MULTI_REDUCE_SUM_POW_ABS_CUDA_KERNEL(float) REGISTER_MULTI_REDUCE_SUM_POW_ABS_CUDA_KERNEL(double) REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNELS(float) REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNELS(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/multi_reduce_kernels.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNELS_H_ #define ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNELS_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/multi_reduce_kernel_util.h" namespace oneflow { template class MultiReduceSumPowAbsKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiReduceSumPowAbsKernel() = default; ~MultiReduceSumPowAbsKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache*) const override { std::vector> params; params.resize(ctx->input_size("x")); for (size_t i = 0; i < params.size(); ++i) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", i); params[i].size = x->shape_view().elem_cnt(); params[i].data = x->dptr(); } user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); T* y_dptr = y->mut_dptr(); user_op::Tensor* temp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); T* tmp_dptr = temp ? 
temp->mut_dptr() : nullptr; float p = ctx->Attr("p"); if (p == 0) { PowByZero func{}; MultiReduce> reduce_sum{}; reduce_sum(ctx->stream(), func, params, GetZeroVal(), y_dptr, tmp_dptr); } else if (p == 1) { Abs func{}; MultiReduce> reduce_sum{}; reduce_sum(ctx->stream(), func, params, GetZeroVal(), y_dptr, tmp_dptr); } else if (p == 2) { Square func{}; MultiReduce> reduce_sum{}; reduce_sum(ctx->stream(), func, params, GetZeroVal(), y_dptr, tmp_dptr); } else { AbsPow func{p}; MultiReduce> reduce_sum{}; reduce_sum(ctx->stream(), func, params, GetZeroVal(), y_dptr, tmp_dptr); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; enum class Ximum { kMax = 0, kMin = 1, }; template class MultiReduceXimumAbsKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiReduceXimumAbsKernel() = default; ~MultiReduceXimumAbsKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache*) const override { std::vector> params; params.resize(ctx->input_size("x")); for (size_t i = 0; i < params.size(); ++i) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", i); params[i].size = x->shape_view().elem_cnt(); params[i].data = x->dptr(); } user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* temp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); T* tmp_dptr = temp ? temp->mut_dptr() : nullptr; Abs abs{}; if (X == Ximum::kMax) { MultiReduce> reduce_max{}; reduce_max(ctx->stream(), abs, params, GetZeroVal(), y->mut_dptr(), tmp_dptr); } else if (X == Ximum::kMin) { MultiReduce> reduce_min{}; reduce_min(ctx->stream(), abs, params, std::numeric_limits::max(), y->mut_dptr(), tmp_dptr); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNELS_H_ ================================================ FILE: oneflow/user/kernels/multi_tensor_model_update_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
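The attribute p in MultiReduceSumPowAbsKernel above selects the per-element transform before the shared sum reduction: p == 0 counts non-zero entries (PowByZero), p == 1 sums |x| (Abs), p == 2 sums x^2 (Square), and anything else falls back to the generic |x|^p (AbsPow). A small self-contained sketch of that dispatch:

// Illustrative sketch (not OneFlow code): the per-element transform chosen by `p`
// in multi_reduce_sum_pow_abs, i.e. sum_i |x_i|^p with cheap special cases.
#include <cmath>
#include <cstdio>
#include <vector>

double SumPowAbs(const std::vector<double>& xs, double p) {
  double sum = 0.0;
  for (double x : xs) {
    double term;
    if (p == 0.0) {
      term = (x != 0.0) ? 1.0 : 0.0;     // PowByZero: counts non-zero entries
    } else if (p == 1.0) {
      term = std::fabs(x);               // Abs
    } else if (p == 2.0) {
      term = x * x;                      // Square
    } else {
      term = std::pow(std::fabs(x), p);  // AbsPow: generic |x|^p
    }
    sum += term;
  }
  return sum;
}

int main() {
  std::vector<double> xs = {0.0, -1.0, 2.0, -3.0};
  for (double p : {0.0, 1.0, 2.0, 3.0}) { std::printf("p=%g -> %g\n", p, SumPowAbs(xs, p)); }
  return 0;
}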
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/multi_tensor_model_update_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template class MultiTensorSGDUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiTensorSGDUpdateKernel() = default; ~MultiTensorSGDUpdateKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int64_t n_tensor = ctx->input_size("model"); const double scale = ctx->Attr("scale"); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const float weight_decay = ctx->Attr("weight_decay"); const float* learning_rate_ptr = nullptr; const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex("model", 0)->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } TensorTupleParams<2> tensor_tuple_params{}; int32_t count = 0; int32_t total_elem_cnt = 0; for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) { tensor_tuple_params.ptr[0][count] = (ctx->Tensor4ArgNameAndIndex("model", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[1][count] = (ctx->Tensor4ArgNameAndIndex("model_diff", tensor_idx))->mut_dptr(); const int64_t tensor_elem_cnt = ctx->Tensor4ArgNameAndIndex("model", tensor_idx)->shape_view().elem_cnt(); tensor_tuple_params.sizes[count] = tensor_elem_cnt; count += 1; total_elem_cnt += tensor_elem_cnt; if (count == kMaxTuples || tensor_idx == n_tensor - 1) { MultiTensorSGDUpdateKernelUtil::Update( ctx->stream(), total_elem_cnt, count, static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, tensor_tuple_params); count = 0; total_elem_cnt = 0; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("multi_tensor_sgd_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)); #ifdef WITH_CUDA REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_KERNEL(DeviceType::kCUDA, float, float16); REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif template class MultiTensorMomentumUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiTensorMomentumUpdateKernel() = default; ~MultiTensorMomentumUpdateKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { 
const int64_t n_tensor = ctx->input_size("model"); const double scale = ctx->Attr("scale"); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const float weight_decay = ctx->Attr("weight_decay"); const float* learning_rate_ptr = nullptr; const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); const float momentum = ctx->Attr("momentum"); const float dampening = ctx->Attr("dampening"); const bool nesterov = ctx->Attr("nesterov"); const bool maximize = ctx->Attr("maximize"); if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex("model", 0)->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } TensorTupleParams<3> tensor_tuple_params{}; int32_t count = 0; int32_t total_elem_cnt = 0; for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) { tensor_tuple_params.ptr[0][count] = (ctx->Tensor4ArgNameAndIndex("model", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[1][count] = (ctx->Tensor4ArgNameAndIndex("model_diff", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[2][count] = (ctx->Tensor4ArgNameAndIndex("momentum_buf", tensor_idx))->mut_dptr(); const int64_t tensor_elem_cnt = ctx->Tensor4ArgNameAndIndex("model", tensor_idx)->shape_view().elem_cnt(); tensor_tuple_params.sizes[count] = tensor_elem_cnt; count += 1; total_elem_cnt += tensor_elem_cnt; if (count == kMaxTuples || tensor_idx == n_tensor - 1) { MultiTensorMomentumUpdateKernelUtil::Update( ctx->stream(), total_elem_cnt, count, static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, momentum, dampening, nesterov, maximize, tensor_tuple_params); count = 0; total_elem_cnt = 0; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("multi_tensor_momentum_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value) \ && (user_op::HobDataType("momentum_buf", 0) == GetDataType::value)); #ifdef WITH_CUDA REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, float, float16); REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif template class MultiTensorAdamUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiTensorAdamUpdateKernel() = default; ~MultiTensorAdamUpdateKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int64_t n_tensor = ctx->input_size("model"); const auto scale = ctx->Attr("scale"); const float l1 
= ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const float beta1 = ctx->Attr("beta1"); const float beta2 = ctx->Attr("beta2"); const float epsilon = ctx->Attr("epsilon"); const float weight_decay = ctx->Attr("weight_decay"); const bool amsgrad = ctx->Attr("amsgrad"); const bool do_bias_correction = ctx->Attr("do_bias_correction"); if (amsgrad) { UNIMPLEMENTED() << "Multi Tensor Adam Update do not support amsgrad = True. "; } const float* learning_rate_ptr = nullptr; const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const float bias_correction1_val = ctx->Attr("bias_correction1_val"); const float* bias_correction1_ptr = nullptr; if (ctx->has_input("bias_correction1", 0)) { const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex("bias_correction1", 0); CHECK_EQ(bias_correction1->shape_view().elem_cnt(), 1); // Just for Lazy Optional Input Check. bias_correction1_ptr = bias_correction1->dptr(); } const float bias_correction2_val = ctx->Attr("bias_correction2_val"); const float* bias_correction2_ptr = nullptr; if (ctx->has_input("bias_correction2", 0)) { const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex("bias_correction2", 0); CHECK_EQ(bias_correction2->shape_view().elem_cnt(), 1); // Just for Lazy Optional Input Check. bias_correction2_ptr = bias_correction2->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex("model", 0)->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } TensorTupleParams<4> tensor_tuple_params{}; int32_t count = 0; int32_t total_elem_cnt = 0; for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) { tensor_tuple_params.ptr[0][count] = (ctx->Tensor4ArgNameAndIndex("model", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[1][count] = (ctx->Tensor4ArgNameAndIndex("model_diff", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[2][count] = (ctx->Tensor4ArgNameAndIndex("m", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[3][count] = (ctx->Tensor4ArgNameAndIndex("v", tensor_idx))->mut_dptr(); const int64_t tensor_elem_cnt = ctx->Tensor4ArgNameAndIndex("model", tensor_idx)->shape_view().elem_cnt(); tensor_tuple_params.sizes[count] = tensor_elem_cnt; count += 1; total_elem_cnt += tensor_elem_cnt; if (count == kMaxTuples || tensor_idx == n_tensor - 1) { MultiTensorAdamUpdateKernelUtil::Update( ctx->stream(), total_elem_cnt, count, static_cast(scale), l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, bias_correction1_ptr, bias_correction2_ptr, tensor_tuple_params); count = 0; total_elem_cnt = 0; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_KERNEL(device, dtype, gtype) \ 
REGISTER_USER_KERNEL("multi_tensor_adam_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value)); #ifdef WITH_CUDA REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, float, float16); REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, float, float); REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, double, double); #endif template class MultiTensorSGDUpdateWithCastKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiTensorSGDUpdateWithCastKernel() = default; ~MultiTensorSGDUpdateWithCastKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int64_t n_tensor = ctx->input_size("model"); const double scale = ctx->Attr("scale"); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const float weight_decay = ctx->Attr("weight_decay"); const float* learning_rate_ptr = nullptr; const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex("model", 0)->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } TensorTupleParams<3> tensor_tuple_params{}; int32_t count = 0; int32_t total_elem_cnt = 0; for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) { tensor_tuple_params.ptr[0][count] = (ctx->Tensor4ArgNameAndIndex("model", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[1][count] = (ctx->Tensor4ArgNameAndIndex("model_diff", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[2][count] = (ctx->Tensor4ArgNameAndIndex("model_copy", tensor_idx))->mut_dptr(); const int64_t tensor_elem_cnt = ctx->Tensor4ArgNameAndIndex("model", tensor_idx)->shape_view().elem_cnt(); tensor_tuple_params.sizes[count] = tensor_elem_cnt; count += 1; total_elem_cnt += tensor_elem_cnt; if (count == kMaxTuples || tensor_idx == n_tensor - 1) { MultiTensorSGDUpdateWithCastKernelUtil::Update( ctx->stream(), total_elem_cnt, count, static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, tensor_tuple_params); count = 0; total_elem_cnt = 0; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_WITH_CAST_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("multi_tensor_sgd_update_with_cast") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value) \ && (user_op::HobDataType("model_copy", 0) == GetDataType::value)); #ifdef WITH_CUDA 
REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float); REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float16); #endif template class MultiTensorMomentumUpdateWithCastKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiTensorMomentumUpdateWithCastKernel() = default; ~MultiTensorMomentumUpdateWithCastKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int64_t n_tensor = ctx->input_size("model"); const double scale = ctx->Attr("scale"); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const float weight_decay = ctx->Attr("weight_decay"); const float* learning_rate_ptr = nullptr; const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); const float momentum = ctx->Attr("momentum"); const float dampening = ctx->Attr("dampening"); const bool nesterov = ctx->Attr("nesterov"); const bool maximize = ctx->Attr("maximize"); if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex("model", 0)->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } TensorTupleParams<4> tensor_tuple_params{}; int32_t count = 0; int32_t total_elem_cnt = 0; for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) { tensor_tuple_params.ptr[0][count] = (ctx->Tensor4ArgNameAndIndex("model", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[1][count] = (ctx->Tensor4ArgNameAndIndex("model_diff", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[2][count] = (ctx->Tensor4ArgNameAndIndex("momentum_buf", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[3][count] = (ctx->Tensor4ArgNameAndIndex("model_copy", tensor_idx))->mut_dptr(); const int64_t tensor_elem_cnt = ctx->Tensor4ArgNameAndIndex("model", tensor_idx)->shape_view().elem_cnt(); tensor_tuple_params.sizes[count] = tensor_elem_cnt; count += 1; total_elem_cnt += tensor_elem_cnt; if (count == kMaxTuples || tensor_idx == n_tensor - 1) { MultiTensorMomentumUpdateWithCastKernelUtil::Update( ctx->stream(), total_elem_cnt, count, static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, momentum, dampening, nesterov, maximize, tensor_tuple_params); count = 0; total_elem_cnt = 0; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_WITH_CAST_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("multi_tensor_momentum_update_with_cast") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value) \ && (user_op::HobDataType("momentum_buf", 0) == GetDataType::value) \ && 
(user_op::HobDataType("model_copy", 0) == GetDataType::value)); #ifdef WITH_CUDA REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float); REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float16); #endif template class MultiTensorAdamUpdateWithCastKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiTensorAdamUpdateWithCastKernel() = default; ~MultiTensorAdamUpdateWithCastKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int64_t n_tensor = ctx->input_size("model"); const auto scale = ctx->Attr("scale"); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const float beta1 = ctx->Attr("beta1"); const float beta2 = ctx->Attr("beta2"); const float epsilon = ctx->Attr("epsilon"); const float weight_decay = ctx->Attr("weight_decay"); const bool amsgrad = ctx->Attr("amsgrad"); const bool do_bias_correction = ctx->Attr("do_bias_correction"); if (amsgrad) { UNIMPLEMENTED() << "Multi Tensor Adam Update do not support amsgrad = True. "; } const float* learning_rate_ptr = nullptr; const float learning_rate_val = ctx->Attr("learning_rate_val"); const float lr_scale = ctx->Attr("learning_rate_scale"); if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const float bias_correction1_val = ctx->Attr("bias_correction1_val"); const float* bias_correction1_ptr = nullptr; if (ctx->has_input("bias_correction1", 0)) { const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex("bias_correction1", 0); CHECK_EQ(bias_correction1->shape_view().elem_cnt(), 1); // Just for Lazy Optional Input Check. bias_correction1_ptr = bias_correction1->dptr(); } const float bias_correction2_val = ctx->Attr("bias_correction2_val"); const float* bias_correction2_ptr = nullptr; if (ctx->has_input("bias_correction2", 0)) { const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex("bias_correction2", 0); CHECK_EQ(bias_correction2->shape_view().elem_cnt(), 1); // Just for Lazy Optional Input Check. 
bias_correction2_ptr = bias_correction2->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex("model", 0)->data_type()); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } TensorTupleParams<5> tensor_tuple_params{}; int32_t count = 0; int32_t total_elem_cnt = 0; for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) { tensor_tuple_params.ptr[0][count] = (ctx->Tensor4ArgNameAndIndex("model", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[1][count] = (ctx->Tensor4ArgNameAndIndex("model_diff", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[2][count] = (ctx->Tensor4ArgNameAndIndex("m", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[3][count] = (ctx->Tensor4ArgNameAndIndex("v", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[4][count] = (ctx->Tensor4ArgNameAndIndex("model_copy", tensor_idx))->mut_dptr(); const int64_t tensor_elem_cnt = ctx->Tensor4ArgNameAndIndex("model", tensor_idx)->shape_view().elem_cnt(); tensor_tuple_params.sizes[count] = tensor_elem_cnt; count += 1; total_elem_cnt += tensor_elem_cnt; if (count == kMaxTuples || tensor_idx == n_tensor - 1) { MultiTensorAdamUpdateWithCastKernelUtil::Update( ctx->stream(), total_elem_cnt, count, static_cast(scale), l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, bias_correction1_ptr, bias_correction2_ptr, tensor_tuple_params); count = 0; total_elem_cnt = 0; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_WITH_CAST_KERNEL(device, dtype, gtype) \ REGISTER_USER_KERNEL("multi_tensor_adam_update_with_cast") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value) \ && (user_op::HobDataType("model_diff", 0) == GetDataType::value) \ && (user_op::HobDataType("model_copy", 0) == GetDataType::value)); #ifdef WITH_CUDA REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float); REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float16); #endif template class MultiTensorYoloV5WeightUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiTensorYoloV5WeightUpdateKernel() = default; ~MultiTensorYoloV5WeightUpdateKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const int64_t n_tensor = ctx->input_size("model"); const float d = ctx->Attr("d"); TensorTupleParams<2> tensor_tuple_params{}; int32_t count = 0; int32_t total_elem_cnt = 0; for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) { tensor_tuple_params.ptr[0][count] = (ctx->Tensor4ArgNameAndIndex("model", tensor_idx))->mut_dptr(); tensor_tuple_params.ptr[1][count] = (ctx->Tensor4ArgNameAndIndex("model_update", tensor_idx))->mut_dptr(); const int64_t tensor_elem_cnt = ctx->Tensor4ArgNameAndIndex("model", 
tensor_idx)->shape_view().elem_cnt(); tensor_tuple_params.sizes[count] = tensor_elem_cnt; count += 1; total_elem_cnt += tensor_elem_cnt; if (count == kMaxTuples || tensor_idx == n_tensor - 1) { MultiTensorYoloV5WeightUpdateKernelUtil::Update( ctx->stream(), total_elem_cnt, count, d, tensor_tuple_params); count = 0; total_elem_cnt = 0; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_MULTI_TENSOR_YOLOV5_WEIGHT_UPDATE_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("multi_tensor_yolov5_weight_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("model", 0) == GetDataType::value)); #ifdef WITH_CUDA REGISTER_MULTI_TENSOR_YOLOV5_WEIGHT_UPDATE_KERNEL(DeviceType::kCUDA, float); #endif } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/multi_tensor_model_update_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/model_update_kernel_util.h" #include "oneflow/user/kernels/multi_tensor_model_update_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { constexpr int kBlockSize = 256; constexpr int kUnrollSize = 4; unsigned int ComputeGridSize(ep::Stream* stream, const int32_t block_size, const int64_t elem_cnt) { auto* cuda_stream = stream->As(); const int32_t max_threads_multi_process = cuda_stream->device_properties().maxThreadsPerMultiProcessor; const int32_t multi_processor_count = cuda_stream->device_properties().multiProcessorCount; unsigned int blocks_per_sm = max_threads_multi_process / block_size; unsigned int grid_size = ((elem_cnt + block_size - 1) / block_size); grid_size = std::min((unsigned int)multi_processor_count * blocks_per_sm, grid_size); return grid_size; } template __global__ void MultiTensorSGDUpdateGpu(int64_t num_tensor, T scale, const float l1, const float l2, const float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams tensor_tuple_params) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; int64_t v_block_id = blockIdx.x; for (int64_t tensor_idx = 0; tensor_idx < num_tensor; tensor_idx++) { const int64_t tensor_elem_cnt = tensor_tuple_params.sizes[tensor_idx]; T* model_ptr = (T*)tensor_tuple_params.ptr[0][tensor_idx]; G* model_diff_ptr = (G*)tensor_tuple_params.ptr[1][tensor_idx]; half* model_copy_ptr = nullptr; if (N == 3) { model_copy_ptr = (half*)tensor_tuple_params.ptr[2][tensor_idx]; } for (int64_t i = v_block_id * blockDim.x * kUnrollSize + threadIdx.x; i < tensor_elem_cnt; i += blockDim.x * gridDim.x * kUnrollSize) { T model_val[kUnrollSize] = {0}; G 
model_diff[kUnrollSize] = {0}; #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { model_val[ilp] = *(model_ptr + actual_idx); model_diff[ilp] = *(model_diff_ptr + actual_idx); } } #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { T model_diff_t = CastScaleRegularizeGradientFunctor()( model_diff[ilp], model_val[ilp], scale, l1, l2); model_val[ilp] = model_val[ilp] - learning_rate_val * (model_diff_t + weight_decay * model_val[ilp]); } } #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { *(model_ptr + actual_idx) = model_val[ilp]; if (N == 3) { *(model_copy_ptr + actual_idx) = static_cast(model_val[ilp]); } } } } v_block_id -= tensor_tuple_params.block_offset[tensor_idx]; if (v_block_id < 0) { v_block_id += gridDim.x; } } } template struct MultiTensorSGDUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params); }; template void MultiTensorSGDUpdateKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params) { const unsigned int grid_size = ComputeGridSize(stream->As(), kBlockSize, elem_cnt); for (int i = 0; i < n_tensor; i++) { tensor_tuple_params.block_offset[i] = ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize)) % grid_size; } MultiTensorSGDUpdateGpu <<As()->cuda_stream()>>>( n_tensor, static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, tensor_tuple_params); } template struct MultiTensorSGDUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params); }; template void MultiTensorSGDUpdateKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params) { MultiTensorSGDUpdateKernelUtil::Update( stream, elem_cnt, n_tensor, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, tensor_tuple_params); } template struct MultiTensorSGDUpdateKernelUtil; template struct MultiTensorSGDUpdateKernelUtil; template struct MultiTensorSGDUpdateKernelUtil; template __global__ void MultiTensorMomentumUpdateGpu( int64_t num_tensor, T scale, const float l1, const float l2, const float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams tensor_tuple_params) { if 
(skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } learning_rate_val *= lr_scale; int64_t v_block_id = blockIdx.x; for (int64_t tensor_idx = 0; tensor_idx < num_tensor; tensor_idx++) { const int64_t tensor_elem_cnt = tensor_tuple_params.sizes[tensor_idx]; T* model_ptr = (T*)tensor_tuple_params.ptr[0][tensor_idx]; G* model_diff_ptr = (G*)tensor_tuple_params.ptr[1][tensor_idx]; T* momentum_buf_ptr = (T*)tensor_tuple_params.ptr[2][tensor_idx]; half* model_copy_ptr = nullptr; if (N == 4) { model_copy_ptr = (half*)tensor_tuple_params.ptr[3][tensor_idx]; } for (int64_t i = v_block_id * blockDim.x * kUnrollSize + threadIdx.x; i < tensor_elem_cnt; i += blockDim.x * gridDim.x * kUnrollSize) { T model_val[kUnrollSize] = {0}; G model_diff[kUnrollSize] = {0}; T momentum_buf[kUnrollSize] = {0}; #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { model_val[ilp] = *(model_ptr + actual_idx); model_diff[ilp] = *(model_diff_ptr + actual_idx); momentum_buf[ilp] = *(momentum_buf_ptr + actual_idx); } } #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { T model_diff_t = CastScaleRegularizeGradientFunctor()( model_diff[ilp], model_val[ilp], scale, l1, l2); if (weight_decay != 0.f) { model_diff_t += weight_decay * model_val[ilp]; } momentum_buf[ilp] = momentum * momentum_buf[ilp] + (1.f - dampening) * model_diff_t; if (nesterov) model_diff_t += momentum * momentum_buf[ilp]; else model_diff_t = momentum_buf[ilp]; T alpha = -learning_rate_val; if (maximize) alpha = learning_rate_val; model_val[ilp] += alpha * model_diff_t; } } #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { *(model_ptr + actual_idx) = model_val[ilp]; *(momentum_buf_ptr + actual_idx) = momentum_buf[ilp]; if (N == 4) { *(model_copy_ptr + actual_idx) = static_cast(model_val[ilp]); } } } } v_block_id -= tensor_tuple_params.block_offset[tensor_idx]; if (v_block_id < 0) { v_block_id += gridDim.x; } } } template struct MultiTensorMomentumUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<3> tensor_tuple_params); }; template void MultiTensorMomentumUpdateKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<3> tensor_tuple_params) { const unsigned int grid_size = ComputeGridSize(stream->As(), kBlockSize, elem_cnt); for (int i = 0; i < n_tensor; i++) { tensor_tuple_params.block_offset[i] = ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize)) % grid_size; } MultiTensorMomentumUpdateGpu <<As()->cuda_stream()>>>( n_tensor, static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, 
learning_rate, scale_by_ptr, skip_if, momentum, dampening, nesterov, maximize, tensor_tuple_params); } template struct MultiTensorMomentumUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<3> tensor_tuple_params); }; template void MultiTensorMomentumUpdateKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<3> tensor_tuple_params) { MultiTensorMomentumUpdateKernelUtil::Update( stream, elem_cnt, n_tensor, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, momentum, dampening, nesterov, maximize, tensor_tuple_params); } template struct MultiTensorMomentumUpdateKernelUtil; template struct MultiTensorMomentumUpdateKernelUtil; template struct MultiTensorMomentumUpdateKernelUtil; template __global__ void MultiTensorAdamUpdateGpu(int64_t num_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1_ptr, const float* bias_correction2_ptr, TensorTupleParams tensor_tuple_params) { if (skip_if != nullptr && *skip_if != 0) { return; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; } if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; } learning_rate_val *= lr_scale; int64_t v_block_id = blockIdx.x; for (int64_t tensor_idx = 0; tensor_idx < num_tensor; tensor_idx++) { const int64_t tensor_elem_cnt = tensor_tuple_params.sizes[tensor_idx]; T* model_ptr = (T*)tensor_tuple_params.ptr[0][tensor_idx]; G* model_diff_ptr = (G*)tensor_tuple_params.ptr[1][tensor_idx]; T* m_ptr = (T*)tensor_tuple_params.ptr[2][tensor_idx]; T* v_ptr = (T*)tensor_tuple_params.ptr[3][tensor_idx]; half* model_copy_ptr = nullptr; if (N == 5) { model_copy_ptr = (half*)tensor_tuple_params.ptr[4][tensor_idx]; } for (int64_t i = v_block_id * blockDim.x * kUnrollSize + threadIdx.x; i < tensor_elem_cnt; i += blockDim.x * gridDim.x * kUnrollSize) { T model_val[kUnrollSize] = {0}; T m_val[kUnrollSize] = {0}; T v_val[kUnrollSize] = {0}; G model_diff[kUnrollSize] = {0}; #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { model_val[ilp] = *(model_ptr + actual_idx); m_val[ilp] = *(m_ptr + actual_idx); v_val[ilp] = *(v_ptr + actual_idx); model_diff[ilp] = *(model_diff_ptr + actual_idx); } } #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { T model_diff_t = CastScaleRegularizeGradientFunctor()( model_diff[ilp], model_val[ilp], scale, 
l1, l2); m_val[ilp] = beta1 * m_val[ilp] + (1 - beta1) * model_diff_t; v_val[ilp] = beta2 * v_val[ilp] + (1 - beta2) * model_diff_t * model_diff_t; T denom = (sqrt(v_val[ilp]) / sqrt(bias_correction2_val)) + epsilon; const T step_size = learning_rate_val / bias_correction1_val; model_val[ilp] = model_val[ilp] - step_size * (m_val[ilp] / denom) - learning_rate_val * weight_decay * model_val[ilp]; } } #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { *(model_ptr + actual_idx) = model_val[ilp]; *(m_ptr + actual_idx) = m_val[ilp]; *(v_ptr + actual_idx) = v_val[ilp]; if (N == 5) { *(model_copy_ptr + actual_idx) = static_cast(model_val[ilp]); } } } } v_block_id -= tensor_tuple_params.block_offset[tensor_idx]; if (v_block_id < 0) { v_block_id += gridDim.x; } } } template struct MultiTensorAdamUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<4> tensor_tuple_params); }; template void MultiTensorAdamUpdateKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<4> tensor_tuple_params) { const unsigned int grid_size = ComputeGridSize(stream->As(), kBlockSize, elem_cnt); for (int i = 0; i < n_tensor; i++) { tensor_tuple_params.block_offset[i] = ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize)) % grid_size; } MultiTensorAdamUpdateGpu <<As()->cuda_stream()>>>( n_tensor, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale, learning_rate, scale_by_ptr, skip_if, bias_correction1, bias_correction2, tensor_tuple_params); } template struct MultiTensorAdamUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<4> tensor_tuple_params); }; template void MultiTensorAdamUpdateKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, 
TensorTupleParams<4> tensor_tuple_params) { MultiTensorAdamUpdateKernelUtil::Update( stream, elem_cnt, n_tensor, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale, learning_rate, scale_by_ptr, skip_if, bias_correction1, bias_correction2, tensor_tuple_params); } template struct MultiTensorAdamUpdateKernelUtil; template struct MultiTensorAdamUpdateKernelUtil; template struct MultiTensorAdamUpdateKernelUtil; template struct MultiTensorSGDUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params); }; template void MultiTensorSGDUpdateWithCastKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params) { const unsigned int grid_size = ComputeGridSize(stream->As(), kBlockSize, elem_cnt); for (int i = 0; i < n_tensor; i++) { tensor_tuple_params.block_offset[i] = ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize)) % grid_size; } MultiTensorSGDUpdateGpu <<As()->cuda_stream()>>>( n_tensor, static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, tensor_tuple_params); } template struct MultiTensorSGDUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params); }; template void MultiTensorSGDUpdateWithCastKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params) { MultiTensorSGDUpdateWithCastKernelUtil::Update( stream, elem_cnt, n_tensor, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, tensor_tuple_params); } template struct MultiTensorSGDUpdateWithCastKernelUtil; template struct MultiTensorSGDUpdateWithCastKernelUtil; template struct MultiTensorMomentumUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<4> tensor_tuple_params); }; template void MultiTensorMomentumUpdateWithCastKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<4> tensor_tuple_params) { const 
unsigned int grid_size = ComputeGridSize(stream->As(), kBlockSize, elem_cnt); for (int i = 0; i < n_tensor; i++) { tensor_tuple_params.block_offset[i] = ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize)) % grid_size; } MultiTensorMomentumUpdateGpu <<As()->cuda_stream()>>>( n_tensor, static_cast(scale), l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, momentum, dampening, nesterov, maximize, tensor_tuple_params); } template struct MultiTensorMomentumUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<4> tensor_tuple_params); }; template void MultiTensorMomentumUpdateWithCastKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<4> tensor_tuple_params) { MultiTensorMomentumUpdateWithCastKernelUtil::Update( stream, elem_cnt, n_tensor, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, momentum, dampening, nesterov, maximize, tensor_tuple_params); } template struct MultiTensorMomentumUpdateWithCastKernelUtil; template struct MultiTensorMomentumUpdateWithCastKernelUtil; template struct MultiTensorAdamUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<5> tensor_tuple_params); }; template void MultiTensorAdamUpdateWithCastKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<5> tensor_tuple_params) { const unsigned int grid_size = ComputeGridSize(stream->As(), kBlockSize, elem_cnt); for (int i = 0; i < n_tensor; i++) { tensor_tuple_params.block_offset[i] = ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize)) % grid_size; } MultiTensorAdamUpdateGpu <<As()->cuda_stream()>>>( n_tensor, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale, learning_rate, scale_by_ptr, skip_if, bias_correction1, bias_correction2, tensor_tuple_params); } template struct MultiTensorAdamUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t 
n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<5> tensor_tuple_params); }; template void MultiTensorAdamUpdateWithCastKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<5> tensor_tuple_params) { MultiTensorAdamUpdateWithCastKernelUtil::Update( stream, elem_cnt, n_tensor, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale, learning_rate, scale_by_ptr, skip_if, bias_correction1, bias_correction2, tensor_tuple_params); } template struct MultiTensorAdamUpdateWithCastKernelUtil; template struct MultiTensorAdamUpdateWithCastKernelUtil; template __global__ void MultiTensorYoloModelEmaUpdateGpu(int64_t num_tensor, const float d, TensorTupleParams tensor_tuple_params) { int64_t v_block_id = blockIdx.x; for (int64_t tensor_idx = 0; tensor_idx < num_tensor; tensor_idx++) { const int64_t tensor_elem_cnt = tensor_tuple_params.sizes[tensor_idx]; T* model_ptr = (T*)tensor_tuple_params.ptr[0][tensor_idx]; T* model_update_ptr = (T*)tensor_tuple_params.ptr[1][tensor_idx]; for (int64_t i = v_block_id * blockDim.x * kUnrollSize + threadIdx.x; i < tensor_elem_cnt; i += blockDim.x * gridDim.x * kUnrollSize) { T model_val[kUnrollSize] = {0}; T model_update_val[kUnrollSize] = {0}; #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { model_val[ilp] = *(model_ptr + actual_idx); model_update_val[ilp] = *(model_update_ptr + actual_idx); } } #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { model_val[ilp] *= d; model_val[ilp] += (1 - d) * model_update_val[ilp]; } } #pragma unroll for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) { int64_t actual_idx = i + ilp * blockDim.x; if (actual_idx < tensor_elem_cnt) { *(model_ptr + actual_idx) = model_val[ilp]; *(model_update_ptr + actual_idx) = model_update_val[ilp]; } } } v_block_id -= tensor_tuple_params.block_offset[tensor_idx]; if (v_block_id < 0) { v_block_id += gridDim.x; } } } template struct MultiTensorYoloV5WeightUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, float d, TensorTupleParams<2> tensor_tuple_params); }; template<> struct MultiTensorYoloV5WeightUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, float d, TensorTupleParams<2> tensor_tuple_params); }; template void MultiTensorYoloV5WeightUpdateKernelUtil::Update( ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, float d, TensorTupleParams<2> tensor_tuple_params) { const unsigned int grid_size = ComputeGridSize(stream->As(), kBlockSize, elem_cnt); for (int 
i = 0; i < n_tensor; i++) { tensor_tuple_params.block_offset[i] = ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize)) % grid_size; } MultiTensorYoloModelEmaUpdateGpu <<As()->cuda_stream()>>>( n_tensor, d, tensor_tuple_params); } template struct MultiTensorYoloV5WeightUpdateKernelUtil; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/multi_tensor_model_update_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_MULTI_TENSOR_MODEL_UPDATE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_MULTI_TENSOR_MODEL_UPDATE_KERNEL_UTIL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { // Kernel arg size has 4K limit, but currently we set process 32 tensors in each kernel. constexpr int kMaxTuples = 32; template struct TensorTupleParams { void* ptr[N][kMaxTuples]; int64_t sizes[kMaxTuples]; int32_t block_offset[kMaxTuples]; }; template struct MultiTensorSGDUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params); }; template struct MultiTensorMomentumUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<3> tensor_tuple_params); }; template struct MultiTensorAdamUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<4> tensor_tuple_params); }; template struct MultiTensorSGDUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params); }; template struct MultiTensorMomentumUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate, const 
T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov, const bool maximize, TensorTupleParams<4> tensor_tuple_params); }; template struct MultiTensorAdamUpdateWithCastKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2, TensorTupleParams<5> tensor_tuple_params); }; template struct MultiTensorYoloV5WeightUpdateKernelUtil { static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, float d, TensorTupleParams<2> tensor_tuple_params); }; } // namespace oneflow #endif ================================================ FILE: oneflow/user/kernels/mutable_cast_once_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewCastPrimitive(Context* ctx) { const DataType in_data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); const DataType out_data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), in_data_type, out_data_type); } class MutableCastOnceOpKernelState final : public OpKernelState { public: MutableCastOnceOpKernelState() : cast_once_flag_(false) {} void SetDone() { if (!cast_once_flag_) { cast_once_flag_ = true; } } bool IsDone() { return cast_once_flag_; } private: bool cast_once_flag_ = false; }; class MutableCastOnce final : public OpKernel { public: MutableCastOnce() = default; ~MutableCastOnce() = default; std::shared_ptr CreateOpKernelState(KernelInitContext* ctx) const override { return std::make_shared(); } private: void Compute(KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* cast_state = CHECK_NOTNULL(dynamic_cast(state)); if (cast_state->IsDone()) { return; } const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = input_tensor->shape_view().elem_cnt(); CHECK_EQ(output_tensor->shape_view().elem_cnt(), elem_cnt); auto cast_primitive = NewCastPrimitive(ctx); CHECK(cast_primitive); cast_primitive->Launch(ctx->stream(), input_tensor->dptr(), output_tensor->mut_dptr(), elem_cnt); cast_state->SetDone(); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto CastPrimitiveExists() { return 
hob::make_custom("CastPrimitiveExists", [](const user_op::KernelRegContext& ctx) -> bool { return NewCastPrimitive(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("mutable_cast_once") .SetCreateFn() .SetIsMatchedHob(CastPrimitiveExists() == true); } // namespace } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/narrow_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewCopyNdPrimitive(Context* ctx) { return ep::primitive::NewPrimitive(ctx->device_type(), 3); } template std::unique_ptr NewMemsetPrimitive(Context* ctx) { return ep::primitive::NewPrimitive(ctx->device_type()); } auto CopyNdPrimitiveExists() { return hob::make_custom("CopyNdPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewCopyNdPrimitive(&ctx).operator bool(); }); } auto MemsetPrimitiveExists() { return hob::make_custom("MemsetPrimitiveExists", [](const KernelRegContext& ctx) { return NewMemsetPrimitive(&ctx).operator bool(); }); } } // namespace class NarrowKernel final : public user_op::OpKernel { public: NarrowKernel() = default; ~NarrowKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); if (in->shape_view().elem_cnt() == 0) { return; } user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t& dim = ctx->Attr("dim"); const int64_t& start = ctx->Attr("start"); int64_t length = out->shape_view().At(dim); const ShapeView in_shape = in->shape_view(); auto copy_nd_primitive = NewCopyNdPrimitive(ctx); CHECK(copy_nd_primitive); const int64_t outer_dim = in_shape.Count(0, dim); const int64_t inner_dim = in_shape.Count(dim + 1); const int64_t narrow_dim = in_shape.At(dim); DimVector dst_shape = {outer_dim, length, inner_dim}; DimVector dst_pos_vec = {0, 0, 0}; DimVector src_shape = {outer_dim, narrow_dim, inner_dim}; DimVector src_pos_vec = {0, start, 0}; DimVector extent_vec = {outer_dim, length, inner_dim}; copy_nd_primitive->Launch(ctx->stream(), out->data_type(), 3, out->mut_dptr(), dst_shape.data(), dst_pos_vec.data(), in->dptr(), src_shape.data(), src_pos_vec.data(), extent_vec.data()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class NarrowGradKernel final : public user_op::OpKernel { public: NarrowGradKernel() = default; ~NarrowGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t& dim = ctx->Attr("dim"); const int64_t& start = ctx->Attr("start"); int64_t length = 
dy->shape_view().At(dim); size_t dx_byte_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); void* dst = dx->mut_dptr(); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->device_type()); CHECK(memset_primitive); memset_primitive->Launch(ctx->stream(), dst, 0, dx_byte_size); auto copy_nd_primitive = NewCopyNdPrimitive(ctx); CHECK(copy_nd_primitive); const ShapeView dx_shape = dx->shape_view(); const int64_t outer_dim = dx_shape.Count(0, dim); const int64_t inner_dim = dx_shape.Count(dim + 1); const int64_t narrow_dim = dx_shape.At(dim); DimVector dst_shape = {outer_dim, narrow_dim, inner_dim}; DimVector dst_pos_vec = {0, start, 0}; DimVector src_shape = {outer_dim, length, inner_dim}; DimVector src_pos_vec = {0, 0, 0}; DimVector extent_vec = {outer_dim, length, inner_dim}; copy_nd_primitive->Launch(ctx->stream(), dx->data_type(), 3, dst, dst_shape.data(), dst_pos_vec.data(), dy->dptr(), src_shape.data(), src_pos_vec.data(), extent_vec.data()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("narrow").SetCreateFn().SetIsMatchedHob(CopyNdPrimitiveExists() == true); REGISTER_USER_KERNEL("narrow_grad") .SetCreateFn() .SetIsMatchedHob(MemsetPrimitiveExists() && CopyNdPrimitiveExists()); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" #include "oneflow/user/kernels/collective_communication/include/all_reduce.h" #include "oneflow/user/kernels/collective_communication/include/all_gather.h" #include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU) namespace oneflow { namespace { auto AllReduceCollectiveCommunicationExists() { return hob::make_custom("AllReduceCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsAllReduceRegistered(device_type); }); } auto AllGatherCollectiveCommunicationExists() { return hob::make_custom("AllGatherCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsAllGatherRegistered(device_type); }); } auto AllToAllCollectiveCommunicationExists() { return hob::make_custom("AllToAllCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsAllToAllRegistered(device_type); }); } class CclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { public: explicit CclLogical2DSameDim0KernelCommState(user_op::KernelInitContext* ctx) : is_init_(false), stream_name_(EagerCclCommMgr::kDefaultCclStreamName), parallel_desc_(ctx->parallel_desc()), this_parallel_id_(ctx->parallel_ctx().parallel_id()) { if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); } } ~CclLogical2DSameDim0KernelCommState() override = default; const ccl::CclComm& ccl_comm() { if (!is_init_) { Init(); } return ccl_comm_; } int64_t num_ranks() { if (!is_init_) { Init(); } return num_ranks_; } const std::string& stream_name() const { return stream_name_; } private: void Init() { CHECK(!is_init_); const Shape& hierarchy = *parallel_desc_.hierarchy(); CHECK_EQ(hierarchy.NumAxes(), 2); const int64_t group_size = hierarchy.At(1); EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, this_parallel_id_, "SameDim0"); num_ranks_ = group_size; is_init_ = true; } bool is_init_; std::string stream_name_; ParallelDesc parallel_desc_; int64_t this_parallel_id_; int64_t num_ranks_{}; ccl::CclComm ccl_comm_{}; }; class CclLogical2DSameDim0AllGatherNoncontinuousKernelState : public CclLogical2DSameDim0KernelCommState { public: explicit CclLogical2DSameDim0AllGatherNoncontinuousKernelState(user_op::KernelInitContext* ctx) : CclLogical2DSameDim0KernelCommState(ctx), src_split_axis_(-1) {} ~CclLogical2DSameDim0AllGatherNoncontinuousKernelState() override = default; int64_t src_split_axis() const { return src_split_axis_; } void set_src_split_axis(int64_t split_axis) { src_split_axis_ = split_axis; } private: int64_t src_split_axis_; }; class CclLogical2DSameDim0All2AllKernelState : public CclLogical2DSameDim0KernelCommState { public: explicit 
CclLogical2DSameDim0All2AllKernelState(user_op::KernelInitContext* ctx) : CclLogical2DSameDim0KernelCommState(ctx), src_split_axis_(-1), dst_split_axis_(-1) {} ~CclLogical2DSameDim0All2AllKernelState() override = default; int64_t src_split_axis() const { return src_split_axis_; } void set_src_split_axis(int64_t split_axis) { src_split_axis_ = split_axis; } int64_t dst_split_axis() const { return dst_split_axis_; } void set_dst_split_axis(int64_t split_axis) { dst_split_axis_ = split_axis; } private: int64_t src_split_axis_; int64_t dst_split_axis_; }; class CclLogical2DSameDim0AllReduce final : public user_op::OpKernel { public: CclLogical2DSameDim0AllReduce() = default; ~CclLogical2DSameDim0AllReduce() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* comm_state = dynamic_cast(state); CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); VLOG(3) << "[NcclLogical2D][SameDim0AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } ccl::CclComm ccl_comm = comm_state->ccl_comm(); std::unique_ptr ccl_all_reduce = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; class CclLogical2DSameDim0AllGather final : public user_op::OpKernel { public: CclLogical2DSameDim0AllGather() = default; ~CclLogical2DSameDim0AllGather() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* comm_state = dynamic_cast(state); CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = comm_state->num_ranks(); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); VLOG(3) << "[NcclLogical2D][SameDim0AllGather] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; std::unique_ptr ccl_all_gather = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); ccl::CclComm ccl_comm = comm_state->ccl_comm(); ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; template class 
CclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKernel { public: CclLogical2DSameDim0AllGatherNoncontinuous() = default; ~CclLogical2DSameDim0AllGatherNoncontinuous() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { auto state = std::make_shared(ctx); NdSbp src_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel()); state->set_src_split_axis(src_nd_sbp.sbp_parallel(1).split_parallel().axis()); return state; } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t dtype_size = GetSizeOfDataType(in->data_type()); int64_t data_size = GetCudaAlignedSize(out->shape_view().elem_cnt() * dtype_size); void* unpack_from_ptr = tmp_buffer->mut_dptr(); CHECK_EQ(tmp_buffer->shape_view().elem_cnt(), data_size); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = kernel_state->num_ranks(); const int64_t in_split_axis = kernel_state->src_split_axis(); DimVector logical_shape_dim_vec; in->shape_view().ToDimVector(&logical_shape_dim_vec); logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks; VLOG(3) << "[NcclLogical2D][SameDim0AllGatherNoncontinuous] " << kernel_state->stream_name() << " " << ctx->op_name() << std::endl; // NOTE(chengcheng): Do AllGather CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); std::unique_ptr ccl_all_gather = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); ccl::CclComm ccl_comm = kernel_state->ccl_comm(); ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), ccl_comm); CHECK_GT(in_split_axis, 0); // NOTE(chengcheng): Do unpack. 
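// Editor's note (illustrative sketch, not part of the original source): the unpack here is a
// pure layout transform. The all-gather leaves rank r's chunk in the r-th slice of a leading
// num_ranks dimension, so the gathered buffer has shape (num_ranks, ..., S / num_ranks, ...),
// while the output wants the split axis S restored in place. The shape/permutation derivation
// used just below amounts to (hypothetical standalone form):
//
//   std::vector<int64_t> dims = logical_dims;       // full logical shape
//   dims[split_axis] /= num_ranks;                  // per-rank extent on the split axis
//   dims.insert(dims.begin(), num_ranks);           // leading ranks dimension
//   std::vector<int32_t> perm;
//   for (size_t i = 1; i < dims.size(); ++i) { perm.push_back(i); }
//   perm.insert(perm.begin() + split_axis, 0);      // move the ranks dim back to the split axis
//
// Example with assumed values: logical shape (4, 6), num_ranks = 2, split_axis = 1. The gathered
// buffer is viewed as (2, 4, 3) and permuted with perm = (1, 0, 2) into (4, 2, 3), which is the
// contiguous (4, 6) output.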
DimVector unpack_from_dim_vec = logical_shape_dim_vec; CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0); unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks; unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; size_t Infer2DSameDim0AllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& out_tensor = ctx->OutputTensorDesc("out", 0); return GetCudaAlignedSize(out_tensor.shape().elem_cnt() * GetSizeOfDataType(out_tensor.data_type())); } template class CclLogical2DSameDim0All2All final : public user_op::OpKernel { public: CclLogical2DSameDim0All2All() = default; ~CclLogical2DSameDim0All2All() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { auto state = std::make_shared(ctx); NdSbp src_nd_sbp; NdSbp dst_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", &dst_nd_sbp)); CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2); CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel()); CHECK(dst_nd_sbp.sbp_parallel(1).has_split_parallel()); state->set_src_split_axis(src_nd_sbp.sbp_parallel(1).split_parallel().axis()); state->set_dst_split_axis(dst_nd_sbp.sbp_parallel(1).split_parallel().axis()); return state; } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int64_t tmp_size = 0; const int64_t dtype_size = GetSizeOfDataType(in->data_type()); int64_t data_size = GetCudaAlignedSize(in->shape_view().elem_cnt() * dtype_size); // NOTE(chengcheng): in (transpose)-> pack_to_ptr (all2all)-> unpack_from_ptr (transpose)-> out const char* pack_to_ptr = in->dptr(); char* unpack_from_ptr = out->mut_dptr(); if (tmp_buffer) { tmp_size = tmp_buffer->shape_view().elem_cnt(); } CHECK(tmp_size == 0 || tmp_size == data_size || tmp_size == data_size * 2); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = kernel_state->num_ranks(); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt()); const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t in_split_axis = kernel_state->src_split_axis(); const int64_t out_split_axis = kernel_state->dst_split_axis(); DimVector logical_shape_dim_vec; in->shape_view().ToDimVector(&logical_shape_dim_vec); logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks; VLOG(3) << 
"[NcclLogical2D][SameDim0All2All] " << kernel_state->stream_name() << " " << ctx->op_name() << std::endl; if (out_split_axis != 0) { // NOTE(chengcheng): Do pack. Need transpose in -> pack_to // pack use temp buffer offset: [0, data_size] pack_to_ptr = CHECK_NOTNULL(tmp_buffer)->dptr(); DimVector transpose_in_dim_vec = logical_shape_dim_vec; CHECK_EQ(transpose_in_dim_vec.at(in_split_axis) % num_ranks, 0); transpose_in_dim_vec[in_split_axis] = transpose_in_dim_vec.at(in_split_axis) / num_ranks; CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0); transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks); std::vector perm; perm.emplace_back(out_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != out_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr()); } if (in_split_axis != 0) { // NOTE(chengcheng): Do unpack. Need transpose unpack_from -> out // unpack use temp buffer offset: [tmp_size - data_size, tmp_size] unpack_from_ptr = CHECK_NOTNULL(tmp_buffer)->mut_dptr() + (tmp_size - data_size); } { // NOTE(chengcheng): Do S2S const int64_t elem_per_chunk = elem_cnt / num_ranks; std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); ccl::CclComm ccl_comm = kernel_state->ccl_comm(); all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { // Do unpack. 
CHECK(unpack_from_ptr != out->mut_dptr()); DimVector unpack_from_dim_vec = logical_shape_dim_vec; CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0); unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks; CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0); unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks; unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; size_t Infer2DSameDim0All2AllKernelTmpBufferSize(user_op::InferContext* ctx) { size_t ret = 0; const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); size_t tensor_byte_size = GetCudaAlignedSize(in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type())); NdSbp src_nd_sbp; NdSbp dst_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", &dst_nd_sbp)); CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2); CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel()); CHECK(dst_nd_sbp.sbp_parallel(1).has_split_parallel()); if (src_nd_sbp.sbp_parallel(1).split_parallel().axis() != 0) { ret += tensor_byte_size; } if (dst_nd_sbp.sbp_parallel(1).split_parallel().axis() != 0) { ret += tensor_byte_size; } return ret; } class CclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState { public: explicit CclLogical2DSameDim1KernelCommState(user_op::KernelInitContext* ctx) : is_init_(false), stream_name_(EagerCclCommMgr::kDefaultCclStreamName), parallel_desc_(ctx->parallel_desc()), this_parallel_id_(ctx->parallel_ctx().parallel_id()) { if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); } } ~CclLogical2DSameDim1KernelCommState() = default; const ccl::CclComm& ccl_comm() { if (!is_init_) { const Shape& hierarchy = *parallel_desc_.hierarchy(); CHECK_EQ(hierarchy.NumAxes(), 2); EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, this_parallel_id_, "SameDim1"); is_init_ = true; } return ccl_comm_; } const std::string& stream_name() const { return stream_name_; } private: bool is_init_; std::string stream_name_; ParallelDesc parallel_desc_; int64_t this_parallel_id_; ccl::CclComm ccl_comm_{}; }; class CclLogical2DSameDim1AllReduce final : public user_op::OpKernel { public: CclLogical2DSameDim1AllReduce() = default; ~CclLogical2DSameDim1AllReduce() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* comm_state = dynamic_cast(state); CHECK(comm_state != 
nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); VLOG(3) << "[NcclLogical2D][SameDim1AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } ccl::CclComm ccl_comm = comm_state->ccl_comm(); std::unique_ptr ccl_all_reduce = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; } // namespace REGISTER_USER_KERNEL("_nccl_logical_2D_same_dim0_all_reduce") .SetCreateFn() .SetIsMatchedHob(AllReduceCollectiveCommunicationExists()); REGISTER_USER_KERNEL("_nccl_logical_2D_same_dim0_all_gather") .SetCreateFn() .SetIsMatchedHob(AllGatherCollectiveCommunicationExists()); #define REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(dtype) \ REGISTER_USER_KERNEL("_nccl_logical_2D_same_dim0_all_gather_noncontinuous") \ .SetCreateFn>() \ .SetIsMatchedHob(AllGatherCollectiveCommunicationExists() \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(Infer2DSameDim0AllGatherNoncontinuousKernelTmpBufferSize); REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(bool) REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(int8_t) REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(int32_t) REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(int64_t) REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(float) REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(double) REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(float16) #if defined(__CUDA_BF16_TYPES_EXIST__) REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(nv_bfloat16) #endif #define REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(dtype) \ REGISTER_USER_KERNEL("_nccl_logical_2D_same_dim0_all2all") \ .SetCreateFn>() \ .SetIsMatchedHob(AllToAllCollectiveCommunicationExists() \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(Infer2DSameDim0All2AllKernelTmpBufferSize); REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(bool) REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(int8_t) REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(int32_t) REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(int64_t) REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(float) REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(double) REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(float16) #if defined(__CUDA_BF16_TYPES_EXIST__) REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(nv_bfloat16) #endif REGISTER_USER_KERNEL("_nccl_logical_2D_same_dim1_all_reduce") .SetCreateFn() .SetIsMatchedHob(AllReduceCollectiveCommunicationExists()); REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_2D_same_dim0_all_reduce"); REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_2D_same_dim0_all_gather"); REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_2D_same_dim0_all_gather_noncontinuous"); REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_2D_same_dim0_all2all"); 
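// Editor's note (illustrative, not part of the original source): "SameDim0" and "SameDim1" name
// the two ways a 2D hierarchy is sliced into communicators. With a (dim0, dim1) hierarchy and,
// assuming the usual row-major layout of parallel ids, e.g. hierarchy (2, 4) over ranks 0..7:
//
//   SameDim0 groups (size hierarchy.At(1)): {0,1,2,3} and {4,5,6,7}    // ranks sharing their axis-0 index
//   SameDim1 groups (size hierarchy.At(0)): {0,4}, {1,5}, {2,6}, {3,7} // ranks sharing their axis-1 index
//
//   // hypothetical helpers under that row-major assumption:
//   //   same_dim0 group of a rank: rank / dim1   (group size dim1)
//   //   same_dim1 group of a rank: rank % dim1   (group size dim0)
//
// which matches num_ranks = hierarchy.At(1) used by the SameDim0 comm state above.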
REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_2D_same_dim1_all_reduce"); } // namespace oneflow #endif // WITH_CUDA || WITH_NPU || WITH_MLU ================================================ FILE: oneflow/user/kernels/nccl_logical_fusion_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" #include "collective_communication/include/collective_communication.h" #include "collective_communication/include/send.h" #include "collective_communication/include/recv.h" #include "collective_communication/include/all_gather.h" #include "collective_communication/include/all_reduce.h" #include "collective_communication/include/all_to_all.h" #include "collective_communication/include/reduce_scatter.h" #if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU) namespace oneflow { namespace { size_t GetTmpBufferSizeByNcclType(const std::string& nccl_type, size_t in_tensor_byte_size, size_t out_tensor_byte_size, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp) { if (nccl_type == "_nccl_logical_all_gather_noncontinuous") { return out_tensor_byte_size; } else if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") { return in_tensor_byte_size; } else if (nccl_type == "_nccl_logical_s2s") { size_t ret = 0; CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1); CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel()); CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel()); if (src_nd_sbp.sbp_parallel(0).split_parallel().axis() != 0) { ret += in_tensor_byte_size; } if (dst_nd_sbp.sbp_parallel(0).split_parallel().axis() != 0) { ret += in_tensor_byte_size; } return ret; } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") { return out_tensor_byte_size; } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { size_t ret = 0; CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2); CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel()); CHECK(dst_nd_sbp.sbp_parallel(1).has_split_parallel()); if (src_nd_sbp.sbp_parallel(1).split_parallel().axis() != 0) { ret += in_tensor_byte_size; } if (dst_nd_sbp.sbp_parallel(1).split_parallel().axis() != 0) { ret += in_tensor_byte_size; } return ret; } return 0; } size_t GetTensorByteSize(const user_op::TensorDesc& tensor_desc) { return GetCudaAlignedSize(tensor_desc.shape().elem_cnt() * GetSizeOfDataType(tensor_desc.data_type())); } size_t GetTensorByteSize(const user_op::Tensor& tensor) { return GetCudaAlignedSize(tensor.shape_view().elem_cnt() * GetSizeOfDataType(tensor.data_type())); } class CclLogicalFusionKernelState 
: public user_op::OpKernelState { public: explicit CclLogicalFusionKernelState(user_op::KernelInitContext* ctx) : is_init_(false), stream_name_(EagerCclCommMgr::kDefaultCclStreamName), parallel_desc_(ctx->parallel_desc()), this_parallel_id_(ctx->parallel_ctx().parallel_id()), num_ranks_(-1), comm_key_("InvalidKey"), nccl_num_(-1) { if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); } InitSplitAxisAndTmpBufferOffset(ctx); } ~CclLogicalFusionKernelState() override = default; ccl::CclComm ccl_comm() { if (!is_init_) { InitComm(); } return ccl_comm_; } int64_t num_ranks() { if (!is_init_) { InitComm(); } return num_ranks_; } const std::string& stream_name() const { return stream_name_; } int64_t src_split_axis(int32_t i) const { CHECK_GE(i, 0); CHECK_LT(i, src_split_axis_list_.size()); return src_split_axis_list_.at(i); } int64_t dst_split_axis(int32_t i) const { CHECK_GE(i, 0); CHECK_LT(i, dst_split_axis_list_.size()); return dst_split_axis_list_.at(i); } int32_t nccl_num() const { return nccl_num_; } size_t tmp_buffer_offset(int32_t i) { CHECK_GE(i, 0); CHECK_LT(i, tmp_buffer_offset_.size()); return tmp_buffer_offset_.at(i); } size_t tmp_buffer_size(int32_t i) { CHECK_GE(i, 0); CHECK_LT(i, tmp_buffer_size_.size()); return tmp_buffer_size_.at(i); } private: void InitComm() { CHECK(!is_init_); const Shape& hierarchy = *parallel_desc_.hierarchy(); if (hierarchy.NumAxes() == 1) { num_ranks_ = parallel_desc_.parallel_num(); } else if (hierarchy.NumAxes() == 2) { CHECK(comm_key_ == "SameDim0" || comm_key_ == "SameDim1"); if (comm_key_ == "SameDim0") { const int64_t group_size = hierarchy.At(1); num_ranks_ = group_size; } else if (comm_key_ == "SameDim1") { const int64_t group_size = hierarchy.At(0); num_ranks_ = group_size; } else { UNIMPLEMENTED(); } } else { UNIMPLEMENTED(); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, this_parallel_id_, comm_key_); is_init_ = true; } void UpdateOrCheckEqCommKey(const std::string& val) { if (comm_key_ == "InvalidKey") { comm_key_ = val; } else { CHECK_EQ(comm_key_, val); } } void UpdateSplitAxisByNcclType(const std::string& nccl_type, const int32_t i, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp) { if (nccl_type == "_nccl_logical_all_gather_noncontinuous") { CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1); CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel()); src_split_axis_list_.at(i) = src_nd_sbp.sbp_parallel(0).split_parallel().axis(); } else if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") { CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1); CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel()); dst_split_axis_list_.at(i) = dst_nd_sbp.sbp_parallel(0).split_parallel().axis(); } else if (nccl_type == "_nccl_logical_s2s") { CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1); CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel()); CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel()); src_split_axis_list_.at(i) = src_nd_sbp.sbp_parallel(0).split_parallel().axis(); dst_split_axis_list_.at(i) = dst_nd_sbp.sbp_parallel(0).split_parallel().axis(); CHECK_NE(src_split_axis_list_.at(i), dst_split_axis_list_.at(i)); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") { CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel()); src_split_axis_list_.at(i) = 
src_nd_sbp.sbp_parallel(1).split_parallel().axis(); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2); CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel()); CHECK(dst_nd_sbp.sbp_parallel(1).has_split_parallel()); src_split_axis_list_.at(i) = src_nd_sbp.sbp_parallel(1).split_parallel().axis(); dst_split_axis_list_.at(i) = dst_nd_sbp.sbp_parallel(1).split_parallel().axis(); CHECK_NE(src_split_axis_list_.at(i), dst_split_axis_list_.at(i)); } } void InitSplitAxisAndTmpBufferOffset(user_op::KernelInitContext* ctx) { nccl_num_ = ctx->input_size("in"); const std::vector& src_nd_sbp_str_list = ctx->Attr>("src_nd_sbp_str_list"); const std::vector& dst_nd_sbp_str_list = ctx->Attr>("dst_nd_sbp_str_list"); const std::vector& nccl_type_list = ctx->Attr>("nccl_type_list"); CHECK_EQ(nccl_num_, ctx->output_size("out")); src_split_axis_list_.resize(nccl_num_, -1); dst_split_axis_list_.resize(nccl_num_, -1); CHECK_EQ(src_nd_sbp_str_list.size(), nccl_num_); CHECK_EQ(dst_nd_sbp_str_list.size(), nccl_num_); CHECK_EQ(nccl_type_list.size(), nccl_num_); CHECK_EQ(src_split_axis_list_.size(), nccl_num_); CHECK_EQ(dst_split_axis_list_.size(), nccl_num_); size_t total_buffer_size = 0; for (int32_t i = 0; i < nccl_num_; ++i) { NdSbp src_nd_sbp; NdSbp dst_nd_sbp; CHECK(ParseNdSbpFromLongString(src_nd_sbp_str_list.at(i), &src_nd_sbp)); CHECK(ParseNdSbpFromLongString(dst_nd_sbp_str_list.at(i), &dst_nd_sbp)); const std::string& nccl_type = nccl_type_list.at(i); UpdateOrCheckEqCommKey(GetCommKeyFromNcclType(nccl_type)); UpdateSplitAxisByNcclType(nccl_type, i, src_nd_sbp, dst_nd_sbp); size_t in_tensor_byte_size = GetTensorByteSize(*ctx->TensorDesc4ArgNameAndIndex("in", i)); size_t out_tensor_byte_size = GetTensorByteSize(*ctx->TensorDesc4ArgNameAndIndex("out", i)); tmp_buffer_offset_.push_back(total_buffer_size); size_t tmp_buffer_size = GetTmpBufferSizeByNcclType( nccl_type, in_tensor_byte_size, out_tensor_byte_size, src_nd_sbp, dst_nd_sbp); tmp_buffer_size_.push_back(tmp_buffer_size); total_buffer_size += tmp_buffer_size; } // NOTE(chengcheng): last element of vector is total_buffer_size tmp_buffer_offset_.push_back(total_buffer_size); CHECK_EQ(tmp_buffer_offset_.size(), nccl_num_ + 1); CHECK_EQ(tmp_buffer_size_.size(), nccl_num_); const user_op::TensorDesc* tmp_buffer_tensor_desc = ctx->TensorDesc4ArgNameAndIndex("tmp_buffer", 0); if (tmp_buffer_tensor_desc == nullptr) { CHECK_EQ(total_buffer_size, 0); } else { CHECK_EQ(total_buffer_size, GetTensorByteSize(*tmp_buffer_tensor_desc)); } } bool is_init_; std::string stream_name_; ParallelDesc parallel_desc_; int64_t this_parallel_id_; int64_t num_ranks_; std::string comm_key_; int32_t nccl_num_; std::vector src_split_axis_list_; std::vector dst_split_axis_list_; std::vector tmp_buffer_offset_; std::vector tmp_buffer_size_; ccl::CclComm ccl_comm_{}; }; class CclLogicalFusionKernel final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(CclLogicalFusionKernel); CclLogicalFusionKernel() = default; ~CclLogicalFusionKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { EagerCclCommMgr* comm_mgr = 
CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); return true; } }; const void* UpdatePackToPtrByNcclType(const void* pack_to_ptr, const std::string& nccl_type, user_op::Tensor* tmp_buffer, CclLogicalFusionKernelState* kernel_state, const int32_t i) { CHECK_NOTNULL(tmp_buffer); const void* tmp_dptr = static_cast(tmp_buffer->dptr() + kernel_state->tmp_buffer_offset(i)); if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") { return tmp_dptr; } else if (nccl_type == "_nccl_logical_s2s") { if (kernel_state->dst_split_axis(i) != 0) { return tmp_dptr; // need do pack; } } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { if (kernel_state->dst_split_axis(i) != 0) { return tmp_dptr; // need do pack; } } return pack_to_ptr; } void* UpdateUnpackFromPtrByNcclType(void* unpack_from_ptr, const std::string& nccl_type, user_op::Tensor* tmp_buffer, const user_op::Tensor* in, CclLogicalFusionKernelState* kernel_state, const int32_t i) { CHECK_NOTNULL(tmp_buffer); void* tmp_dptr = static_cast(tmp_buffer->mut_dptr() + kernel_state->tmp_buffer_offset(i)); int64_t data_size = GetTensorByteSize(*in); int64_t tmp_buffer_size = kernel_state->tmp_buffer_size(i); if (nccl_type == "_nccl_logical_all_gather_noncontinuous") { return tmp_dptr; } else if (nccl_type == "_nccl_logical_s2s") { if (kernel_state->src_split_axis(i) != 0) { CHECK(tmp_buffer_size == data_size || tmp_buffer_size == 2 * data_size); return static_cast(static_cast(tmp_dptr) + (tmp_buffer_size - data_size)); } } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") { return tmp_dptr; } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { if (kernel_state->src_split_axis(i) != 0) { CHECK(tmp_buffer_size == data_size || tmp_buffer_size == 2 * data_size); return static_cast(static_cast(tmp_dptr) + (tmp_buffer_size - data_size)); } } return unpack_from_ptr; } void DoPackBeforeNcclGroup(void* pack_to_ptr, const std::string& nccl_type, const user_op::Tensor* in, user_op::KernelComputeContext* ctx, CclLogicalFusionKernelState* kernel_state, const int32_t i) { if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") { // Do pack before reduce scatter const int64_t num_ranks = kernel_state->num_ranks(); const int64_t out_split_axis = kernel_state->dst_split_axis(i); DimVector transpose_in_dim_vec; in->shape_view().ToDimVector(&transpose_in_dim_vec); transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks); const Shape transpose_in_shape(transpose_in_dim_vec); std::vector perm; perm.emplace_back(out_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != out_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), pack_to_ptr); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Do pack before [ReduceScatter]"; } else if (nccl_type == "_nccl_logical_s2s") { const int64_t out_split_axis = kernel_state->dst_split_axis(i); if (out_split_axis != 0) { // Do pack before all2all const int64_t num_ranks = kernel_state->num_ranks(); DimVector transpose_in_dim_vec; in->shape_view().ToDimVector(&transpose_in_dim_vec); 
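// The permute below factors dim[out_split_axis] into (num_ranks, dim / num_ranks) and moves the
// num_ranks factor to the front, so each destination rank's slice becomes one contiguous chunk of
// the packed buffer that the subsequent all-to-all can send as-is.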
CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0); transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks); std::vector perm; perm.emplace_back(out_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != out_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), pack_to_ptr); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Do pack before [All2All]"; } } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { const int64_t out_split_axis = kernel_state->dst_split_axis(i); if (out_split_axis != 0) { const int64_t num_ranks = kernel_state->num_ranks(); DimVector transpose_in_dim_vec; in->shape_view().ToDimVector(&transpose_in_dim_vec); CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0); transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks); std::vector perm; perm.emplace_back(out_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != out_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), pack_to_ptr); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Do pack before [2DSameDim0All2All]"; } } } void DoNcclComputeByNcclTypeInGroup(const void* pack_to_ptr, void* unpack_from_ptr, const std::string& nccl_type, const user_op::Tensor* in, user_op::Tensor* out, user_op::KernelComputeContext* ctx, CclLogicalFusionKernelState* kernel_state, const int32_t i, ccl::CclComm ccl_comm) { std::unique_ptr ccl_send = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); std::unique_ptr ccl_recv = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); const int64_t num_ranks = kernel_state->num_ranks(); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Try launch nccl_type: " << nccl_type; if (nccl_type == "_nccl_logical_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } std::unique_ptr ccl_all_reduce = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_reduce_scatter") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks); ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } 
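// For bool tensors the reduction is switched from kSum to kMax above, which amounts to an
// element-wise logical OR across ranks; the ReduceScatter below is launched with the per-rank
// element count, i.e. out->shape_view().elem_cnt().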
std::unique_ptr ccl_reduce_scatter = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_all_gather") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); std::unique_ptr ccl_all_gather = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_all_gather_noncontinuous") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() != unpack_from_ptr); // do unpack from ptr -> out CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); std::unique_ptr ccl_all_gather = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") { CHECK(in->dptr() != pack_to_ptr); // do in -> pack to ptr CHECK(out->mut_dptr() == unpack_from_ptr); ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } std::unique_ptr ccl_reduce_scatter = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_reduce_scatter->Launch(ctx->stream(), pack_to_ptr, out->mut_dptr(), out->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_s2s") { const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t elem_per_chunk = elem_cnt / num_ranks; std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, unpack_from_ptr, elem_per_chunk, ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } std::unique_ptr ccl_all_reduce = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); std::unique_ptr ccl_all_gather = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() != unpack_from_ptr); // do unpack from ptr -> out CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); std::unique_ptr ccl_all_gather = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, 
in->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t elem_per_chunk = elem_cnt / num_ranks; std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, unpack_from_ptr, elem_per_chunk, ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim1_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } std::unique_ptr ccl_all_reduce = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); } else { UNIMPLEMENTED(); } VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " launched nccl_type: " << nccl_type; } void DoUnpackAfterNcclGroup(void* unpack_from_ptr, const std::string& nccl_type, const user_op::Tensor* in, user_op::Tensor* out, user_op::KernelComputeContext* ctx, CclLogicalFusionKernelState* kernel_state, const int32_t i) { const int64_t num_ranks = kernel_state->num_ranks(); const int64_t in_split_axis = kernel_state->src_split_axis(i); const int64_t out_split_axis = kernel_state->dst_split_axis(i); if (nccl_type == "_nccl_logical_all_gather_noncontinuous") { CHECK_GT(in_split_axis, 0); DimVector unpack_from_dim_vec; in->shape_view().ToDimVector(&unpack_from_dim_vec); unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Do unpack after [AllGatherNoncontinuous]"; } else if (nccl_type == "_nccl_logical_s2s") { CHECK_GE(in_split_axis, 0); CHECK_GE(out_split_axis, 0); if (in_split_axis != 0) { // Do unpack. 
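// After the all-to-all the receive buffer holds one chunk per peer rank, stacked along a leading
// num_ranks axis; the permute below merges that leading axis back into position in_split_axis,
// restoring the full extent of the input split axis in the output tensor.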
CHECK(unpack_from_ptr != out->mut_dptr()); DimVector unpack_from_dim_vec; in->shape_view().ToDimVector(&unpack_from_dim_vec); CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0); unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks; unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Do unpack after [All2All]"; } } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") { DimVector unpack_from_dim_vec; in->shape_view().ToDimVector(&unpack_from_dim_vec); CHECK_GT(in_split_axis, 0); // NOTE(chengcheng): Do unpack. unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Do unpack after [SameDim0AllGatherNoncontinuous]"; } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { CHECK_GE(in_split_axis, 0); CHECK_GE(out_split_axis, 0); if (in_split_axis != 0) { DimVector unpack_from_dim_vec; in->shape_view().ToDimVector(&unpack_from_dim_vec); // Do unpack. 
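// Same unpack scheme as the plain s2s case above, except that num_ranks here is the size of the
// dim0 group (hierarchy.At(1)), i.e. the ranks sharing the same coordinate on hierarchy axis 0.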
CHECK(unpack_from_ptr != out->mut_dptr()); CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0); unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks; unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); } } } void CclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const { auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const int32_t nccl_num = kernel_state->nccl_num(); const std::vector& nccl_type_list = ctx->Attr>("nccl_type_list"); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); // NOTE(chengcheng): // pack: in.dptr -> pack_to_ptr // if not pack : pack_to_ptr = in.dptr // nccl: pack_to_ptr -> unpack_from_ptr // unpack: unpack_from_ptr ->out.dptr // if not unpack: unpack_from_ptr = out.dptr std::vector pack_to_ptr_list(nccl_num, nullptr); std::vector unpack_from_ptr_list(nccl_num, nullptr); std::vector dtype_list(nccl_num, DataType::kInvalidDataType); for (int32_t i = 0; i < nccl_num; ++i) { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", i); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", i); pack_to_ptr_list.at(i) = in->dptr(); unpack_from_ptr_list.at(i) = out->mut_dptr(); dtype_list.at(i) = in->data_type(); CHECK_EQ(dtype_list.at(i), out->data_type()); } // try to do pack before all nccl for (int32_t i = 0; i < nccl_num; ++i) { if (kernel_state->tmp_buffer_size(i) == 0) { continue; } const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", i); // TODO(chengcheng): refactor code by template. pack_to_ptr_list.at(i) = UpdatePackToPtrByNcclType(pack_to_ptr_list.at(i), nccl_type_list.at(i), tmp_buffer, kernel_state, i); unpack_from_ptr_list.at(i) = UpdateUnpackFromPtrByNcclType( unpack_from_ptr_list.at(i), nccl_type_list.at(i), tmp_buffer, in, kernel_state, i); DoPackBeforeNcclGroup(const_cast(pack_to_ptr_list.at(i)) /* mut dptr */, nccl_type_list.at(i), in, ctx, kernel_state, i); } // NOTE(chengcheng): init nccl comm need before ncclGroupStart. 
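// ccl_comm() creates the communicator lazily on first use and caches it in the kernel state,
// keyed by parallel_desc, stream name and comm_key (see InitComm above), so it is fetched once
// here before the grouped collective launches below.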
ccl::CclComm ccl_comm = kernel_state->ccl_comm(); // do nccl compute in group // TODO:(zhaoluyang) replacre ncclGroupStart/ncclGroupEnd with ccl CclGroupStart/CclGroupEnd // OF_NCCL_CHECK(ncclGroupStart()); for (int32_t i = 0; i < nccl_num; ++i) { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", i); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", i); DoNcclComputeByNcclTypeInGroup(pack_to_ptr_list.at(i), unpack_from_ptr_list.at(i), nccl_type_list.at(i), in, out, ctx, kernel_state, i, ccl_comm); } // OF_NCCL_CHECK(ncclGroupEnd()); // try to do unpack after all nccl for (int32_t i = 0; i < nccl_num; ++i) { if (kernel_state->tmp_buffer_size(i) == 0) { continue; } const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", i); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", i); DoUnpackAfterNcclGroup(unpack_from_ptr_list.at(i), nccl_type_list.at(i), in, out, ctx, kernel_state, i); } } size_t InferNcclLogicalFusionKernelTmpBufferSize(user_op::InferContext* ctx) { size_t total_buffer_size = 0; const auto& src_nd_sbp_str_list = ctx->Attr>("src_nd_sbp_str_list"); const auto& dst_nd_sbp_str_list = ctx->Attr>("dst_nd_sbp_str_list"); const auto& nccl_type_list = ctx->Attr>("nccl_type_list"); int32_t nccl_num = nccl_type_list.size(); CHECK_EQ(nccl_num, ctx->input_size("in")); CHECK_EQ(nccl_num, ctx->output_size("out")); CHECK_EQ(nccl_num, src_nd_sbp_str_list.size()); CHECK_EQ(nccl_num, dst_nd_sbp_str_list.size()); for (int32_t i = 0; i < nccl_num; ++i) { const std::string& nccl_type = nccl_type_list.at(i); size_t in_tensor_byte_size = GetTensorByteSize(ctx->InputTensorDesc("in", i)); size_t out_tensor_byte_size = GetTensorByteSize(ctx->OutputTensorDesc("out", i)); NdSbp src_nd_sbp; NdSbp dst_nd_sbp; CHECK(ParseNdSbpFromLongString(src_nd_sbp_str_list.at(i), &src_nd_sbp)); CHECK(ParseNdSbpFromLongString(dst_nd_sbp_str_list.at(i), &dst_nd_sbp)); total_buffer_size += GetTmpBufferSizeByNcclType(nccl_type, in_tensor_byte_size, out_tensor_byte_size, src_nd_sbp, dst_nd_sbp); } return total_buffer_size; } REGISTER_USER_KERNEL("_nccl_logical_fusion") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) || (user_op::HobDeviceType() == DeviceType::kNPU) || (user_op::HobDeviceType() == DeviceType::kMLU)) .SetInferTmpSizeFn(InferNcclLogicalFusionKernelTmpBufferSize); // TODO: SetIsMatchedHob support multi devices(not including cpu) } // namespace } // namespace oneflow #endif // WITH_CUDA || WITH_NPU || WITH_MLU ================================================ FILE: oneflow/user/kernels/nccl_logical_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" #include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #include "oneflow/user/kernels/collective_communication/include/all_reduce.h" #include "oneflow/user/kernels/collective_communication/include/all_gather.h" #include "oneflow/user/kernels/collective_communication/include/reduce_scatter.h" #include "oneflow/user/kernels/collective_communication/include/broadcast.h" #include "oneflow/user/kernels/collective_communication/include/reduce.h" #if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU) namespace oneflow { namespace { auto AllReduceCollectiveCommunicationExists() { return hob::make_custom("AllReduceCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsAllReduceRegistered(device_type); }); } auto ReduceScatterCollectiveCommunicationExists() { return hob::make_custom("ReduceScatterCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsReduceScatterRegistered(device_type); }); } auto AllGatherCollectiveCommunicationExists() { return hob::make_custom("AllGatherCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsAllGatherRegistered(device_type); }); } auto ReduceCollectiveCommunicationExists() { return hob::make_custom("ReduceCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsReduceRegistered(device_type); }); } auto BroadcastCollectiveCommunicationExists() { return hob::make_custom("BroadcastCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsBroadcastRegistered(device_type); }); } auto AllToAllCollectiveCommunicationExists() { return hob::make_custom("AllToAllCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsCommunicationContextRegistered(device_type) && ccl::IsAllToAllRegistered(device_type); }); } class NcclLogicalKernelCommState : public user_op::OpKernelState { public: explicit NcclLogicalKernelCommState(user_op::KernelInitContext* ctx) : is_init_(false), stream_name_(EagerCclCommMgr::kDefaultCclStreamName), parallel_desc_(ctx->parallel_desc()) { if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); } } ~NcclLogicalKernelCommState() override = default; const ccl::CclComm& ccl_comm() { if (!is_init_) { EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ccl_comm_ = comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc_, stream_name_); is_init_ = true; } return ccl_comm_; } const std::string& stream_name() const { return stream_name_; } private: bool is_init_; std::string 
stream_name_; ParallelDesc parallel_desc_; ccl::CclComm ccl_comm_{}; }; class NcclLogicalAllGatherNoncontinuousKernelState : public NcclLogicalKernelCommState { public: explicit NcclLogicalAllGatherNoncontinuousKernelState(user_op::KernelInitContext* ctx) : NcclLogicalKernelCommState(ctx), src_split_axis_(-1) {} ~NcclLogicalAllGatherNoncontinuousKernelState() override = default; int64_t src_split_axis() const { return src_split_axis_; } void set_src_split_axis(int64_t split_axis) { src_split_axis_ = split_axis; } private: int64_t src_split_axis_; }; class NcclLogicalReduceScatterNoncontinuousKernelState : public NcclLogicalKernelCommState { public: explicit NcclLogicalReduceScatterNoncontinuousKernelState(user_op::KernelInitContext* ctx) : NcclLogicalKernelCommState(ctx), dst_split_axis_(-1) {} ~NcclLogicalReduceScatterNoncontinuousKernelState() override = default; int64_t dst_split_axis() const { return dst_split_axis_; } void set_dst_split_axis(int64_t split_axis) { dst_split_axis_ = split_axis; } private: int64_t dst_split_axis_; }; class NcclLogicalS2SKernelState : public NcclLogicalKernelCommState { public: explicit NcclLogicalS2SKernelState(user_op::KernelInitContext* ctx) : NcclLogicalKernelCommState(ctx), src_split_axis_(-1), dst_split_axis_(-1) {} ~NcclLogicalS2SKernelState() override = default; int64_t src_split_axis() const { return src_split_axis_; } void set_src_split_axis(int64_t split_axis) { src_split_axis_ = split_axis; } int64_t dst_split_axis() const { return dst_split_axis_; } void set_dst_split_axis(int64_t split_axis) { dst_split_axis_ = split_axis; } private: int64_t src_split_axis_; int64_t dst_split_axis_; }; class CclLogicalAllReduceKernel final : public user_op::OpKernel { public: CclLogicalAllReduceKernel() = default; ~CclLogicalAllReduceKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* comm_state = dynamic_cast(state); CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); VLOG(3) << "[NcclLogical][AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; ccl::CclComm ccl_comm = comm_state->ccl_comm(); ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } std::unique_ptr ccl_all_reduce = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; class CclLogicalReduceScatterKernel final : public user_op::OpKernel { public: CclLogicalReduceScatterKernel() = default; ~CclLogicalReduceScatterKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const 
user_op::OpKernelCache*) const override { auto* comm_state = dynamic_cast(state); CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks); VLOG(3) << "[NcclLogical][ReduceScatter] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; ccl::CclComm ccl_comm = comm_state->ccl_comm(); ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } std::unique_ptr ccl_reduce_scatter = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; class CclLogicalAllGatherKernel final : public user_op::OpKernel { public: CclLogicalAllGatherKernel() = default; ~CclLogicalAllGatherKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* comm_state = dynamic_cast(state); CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); VLOG(3) << "[NcclLogical][AllGather] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; ccl::CclComm ccl_comm = comm_state->ccl_comm(); std::unique_ptr ccl_all_gather = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; template class CclLogicalAllGatherNoncontinuous final : public user_op::OpKernel { public: CclLogicalAllGatherNoncontinuous() = default; ~CclLogicalAllGatherNoncontinuous() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { auto state = std::make_shared(ctx); NdSbp src_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1); CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel()); state->set_src_split_axis(src_nd_sbp.sbp_parallel(0).split_parallel().axis()); return state; } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* 
out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t dtype_size = GetSizeOfDataType(in->data_type()); int64_t data_size = GetCudaAlignedSize(out->shape_view().elem_cnt() * dtype_size); void* unpack_from_ptr = tmp_buffer->mut_dptr(); CHECK_EQ(tmp_buffer->shape_view().elem_cnt(), data_size); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); const int64_t in_split_axis = kernel_state->src_split_axis(); DimVector logical_shape_dim_vec; in->shape_view().ToDimVector(&logical_shape_dim_vec); logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks; VLOG(3) << "[NcclLogical][AllGatherNoncontinuous] " << kernel_state->stream_name() << " " << ctx->op_name() << std::endl; // NOTE(chengcheng): Do AllGather CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); ccl::CclComm ccl_comm = kernel_state->ccl_comm(); std::unique_ptr ccl_all_gather = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), ccl_comm); CHECK_GT(in_split_axis, 0); // NOTE(chengcheng): Do unpack. DimVector unpack_from_dim_vec = logical_shape_dim_vec; CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0); unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks; unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; size_t InferAllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& out_tensor = ctx->OutputTensorDesc("out", 0); return GetCudaAlignedSize(out_tensor.shape().elem_cnt() * GetSizeOfDataType(out_tensor.data_type())); } template class CclLogicalReduceScatterNoncontinuous final : public user_op::OpKernel { public: CclLogicalReduceScatterNoncontinuous() = default; ~CclLogicalReduceScatterNoncontinuous() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { auto state = std::make_shared(ctx); NdSbp dst_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", &dst_nd_sbp)); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1); CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel()); state->set_dst_split_axis(dst_nd_sbp.sbp_parallel(0).split_parallel().axis()); return state; } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t 
dtype_size = GetSizeOfDataType(in->data_type()); int64_t data_size = GetCudaAlignedSize(in->shape_view().elem_cnt() * dtype_size); CHECK_EQ(tmp_buffer->shape_view().elem_cnt(), data_size); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); const int64_t out_split_axis = kernel_state->dst_split_axis(); DimVector logical_shape_dim_vec; in->shape_view().ToDimVector(&logical_shape_dim_vec); DimVector transpose_in_dim_vec = logical_shape_dim_vec; transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks); const Shape transpose_in_shape(transpose_in_dim_vec); std::vector perm; perm.emplace_back(out_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != out_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr()); VLOG(3) << "[NcclLogical][ReduceScatterNoncontinuous] " << kernel_state->stream_name() << " " << ctx->op_name() << std::endl; ccl::CclComm ccl_comm = kernel_state->ccl_comm(); ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } std::unique_ptr ccl_reduce_scatter = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type(), ccl_reduce_type); ccl_reduce_scatter->Launch(ctx->stream(), tmp_buffer->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; size_t InferReduceScatterNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->OutputTensorDesc("in", 0); return GetCudaAlignedSize(in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type())); } template class CclLogicalS2SKernel final : public user_op::OpKernel { public: CclLogicalS2SKernel() = default; ~CclLogicalS2SKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { auto state = std::make_shared(ctx); NdSbp src_nd_sbp; NdSbp dst_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", &dst_nd_sbp)); CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1); CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel()); CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel()); state->set_src_split_axis(src_nd_sbp.sbp_parallel(0).split_parallel().axis()); state->set_dst_split_axis(dst_nd_sbp.sbp_parallel(0).split_parallel().axis()); return state; } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int64_t tmp_size = 0; const 
int64_t dtype_size = GetSizeOfDataType(in->data_type()); int64_t data_size = GetCudaAlignedSize(in->shape_view().elem_cnt() * dtype_size); // NOTE(chengcheng): in (transpose)-> pack_to_ptr (all2all)-> unpack_from_ptr (transpose)-> out const char* pack_to_ptr = in->dptr(); char* unpack_from_ptr = out->mut_dptr(); if (tmp_buffer) { tmp_size = tmp_buffer->shape_view().elem_cnt(); } CHECK(tmp_size == 0 || tmp_size == data_size || tmp_size == data_size * 2); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt()); const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t in_split_axis = kernel_state->src_split_axis(); const int64_t out_split_axis = kernel_state->dst_split_axis(); DimVector logical_shape_dim_vec; in->shape_view().ToDimVector(&logical_shape_dim_vec); logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks; VLOG(3) << "[NcclLogical][S2S] " << kernel_state->stream_name() << " " << ctx->op_name() << std::endl; if (out_split_axis != 0) { // NOTE(chengcheng): Do pack. Need transpose in -> pack_to // pack use temp buffer offset: [0, data_size] pack_to_ptr = CHECK_NOTNULL(tmp_buffer)->dptr(); DimVector transpose_in_dim_vec = logical_shape_dim_vec; CHECK_EQ(transpose_in_dim_vec.at(in_split_axis) % num_ranks, 0); transpose_in_dim_vec[in_split_axis] = transpose_in_dim_vec.at(in_split_axis) / num_ranks; CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0); transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks; transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks); std::vector perm; perm.emplace_back(out_split_axis); FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) { if (i != out_split_axis) { perm.emplace_back(i); } } auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), transpose_in_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(), transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr()); } if (in_split_axis != 0) { // NOTE(chengcheng): Do unpack. Need transpose unpack_from -> out // unpack use temp buffer offset: [tmp_size - data_size, tmp_size] unpack_from_ptr = CHECK_NOTNULL(tmp_buffer)->mut_dptr() + (tmp_size - data_size); } { // NOTE(chengcheng): Do S2S const int64_t elem_per_chunk = elem_cnt / num_ranks; std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); ccl::CclComm ccl_comm = kernel_state->ccl_comm(); all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { // Do unpack. 
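// tmp_buffer layout for s2s: the pack region occupies [0, data_size) and the unpack region
// occupies [tmp_size - data_size, tmp_size); when both split axes are non-zero, tmp_size is
// 2 * data_size, so the two regions never overlap.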
CHECK(unpack_from_ptr != out->mut_dptr()); DimVector unpack_from_dim_vec = logical_shape_dim_vec; CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0); unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks; CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0); unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks; unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks); std::vector perm; FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); } perm.insert(perm.begin() + in_split_axis, 0); auto transpose = ep::primitive::NewPrimitive( ctx->stream()->device_type(), unpack_from_dim_vec.size()); CHECK(transpose); transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(), unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr()); } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; size_t InferS2SKernelTmpBufferSize(user_op::InferContext* ctx) { size_t ret = 0; const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); size_t tensor_byte_size = GetCudaAlignedSize(in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type())); NdSbp src_nd_sbp; NdSbp dst_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", &dst_nd_sbp)); CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1); CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel()); CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel()); if (src_nd_sbp.sbp_parallel(0).split_parallel().axis() != 0) { ret += tensor_byte_size; } if (dst_nd_sbp.sbp_parallel(0).split_parallel().axis() != 0) { ret += tensor_byte_size; } return ret; } } // namespace REGISTER_USER_KERNEL("_nccl_logical_all_reduce") .SetCreateFn() .SetIsMatchedHob(AllReduceCollectiveCommunicationExists()); REGISTER_USER_KERNEL("_nccl_logical_reduce_scatter") .SetCreateFn() .SetIsMatchedHob(ReduceScatterCollectiveCommunicationExists()); REGISTER_USER_KERNEL("_nccl_logical_all_gather") .SetCreateFn() .SetIsMatchedHob(AllGatherCollectiveCommunicationExists()); #define REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(dtype) \ REGISTER_USER_KERNEL("_nccl_logical_all_gather_noncontinuous") \ .SetCreateFn>() \ .SetIsMatchedHob(AllGatherCollectiveCommunicationExists() \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferAllGatherNoncontinuousKernelTmpBufferSize); REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(bool) REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int8_t) REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int32_t) REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int64_t) REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float) REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(double) REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float16) #if defined(__CUDA_BF16_TYPES_EXIST__) REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(nv_bfloat16) #endif #define REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(dtype) \ REGISTER_USER_KERNEL("_nccl_logical_reduce_scatter_noncontinuous") \ .SetCreateFn>() \ .SetIsMatchedHob(ReduceScatterCollectiveCommunicationExists() \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == 
GetDataType::value)) \ .SetInferTmpSizeFn(InferReduceScatterNoncontinuousKernelTmpBufferSize); REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(bool) REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int8_t) REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int32_t) REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int64_t) REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(float) REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(double) REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(float16) #if defined(__CUDA_BF16_TYPES_EXIST__) REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(nv_bfloat16) #endif #define REGISTER_S2S_KERNEL(dtype) \ REGISTER_USER_KERNEL("_nccl_logical_s2s") \ .SetCreateFn>() \ .SetIsMatchedHob(AllToAllCollectiveCommunicationExists() \ && (user_op::HobDataType("in", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferS2SKernelTmpBufferSize); REGISTER_S2S_KERNEL(bool) REGISTER_S2S_KERNEL(int8_t) REGISTER_S2S_KERNEL(int32_t) REGISTER_S2S_KERNEL(int64_t) REGISTER_S2S_KERNEL(float) REGISTER_S2S_KERNEL(double) REGISTER_S2S_KERNEL(float16) #if defined(__CUDA_BF16_TYPES_EXIST__) REGISTER_S2S_KERNEL(nv_bfloat16) #endif REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_all_reduce"); REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_reduce_scatter"); REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_all_gather"); REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_all_gather_noncontinuous"); REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT("_nccl_logical_s2s"); } // namespace oneflow #endif // WITH_CUDA || WITH_NPU || WITH_MLU ================================================ FILE: oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "collective_communication/include/collective_communication.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/user/ops/nccl_logical_util.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_copier.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" #include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU) namespace oneflow { class CclLogicalSendRecvState final : public user_op::OpKernelState { public: explicit CclLogicalSendRecvState(user_op::KernelInitContext* ctx); const std::vector>& in_tensor_slice_copier_vec() const { return in_tensor_slice_copier_vec_; } const std::vector>& out_tensor_slice_copier_vec() const { return out_tensor_slice_copier_vec_; } bool src_nd_sbp_has_no_partial_parallel() const { return src_nd_sbp_no_partial_parallel_; } const std::vector& send_elem_cnts() const { return send_elem_cnts_; } const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } ccl::CclComm ccl_comm() const { return GetOrCreateComm().ccl_comm; } private: struct Comm { Comm(ccl::CclComm comm) : ccl_comm(comm) {} ccl::CclComm ccl_comm; }; void InitComm() const; const Comm& GetOrCreateComm() const { if (!ccl_comm_) { InitComm(); } return *ccl_comm_; } std::string stream_name_; std::unique_ptr parallel_desc_; mutable std::unique_ptr ccl_comm_; bool src_nd_sbp_no_partial_parallel_; std::vector> in_tensor_slice_copier_vec_; std::vector> out_tensor_slice_copier_vec_; std::vector send_elem_cnts_; std::vector recv_elem_cnts_; }; CclLogicalSendRecvState::CclLogicalSendRecvState(user_op::KernelInitContext* ctx) : stream_name_(EagerCclCommMgr::kDefaultCclStreamName) { if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); } const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); parallel_desc_ = std::make_unique(ctx->parallel_desc()); NdSbp src_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); NdSbp dst_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", &dst_nd_sbp)); const auto& parallel_hierarchy = parallel_desc_->hierarchy(); src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp); CHECK_EQ(src_nd_sbp.sbp_parallel_size(), parallel_hierarchy->NumAxes()); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), parallel_hierarchy->NumAxes()); const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); const DataType data_type = in_logical_desc->data_type(); const Shape& logical_shape = Shape(in_logical_desc->shape()); const DeviceType device_type = parallel_desc_->device_type(); const int64_t parallel_num = parallel_desc_->parallel_num(); std::vector src_send_intersections; std::vector dst_recv_intersections; GetRankSendRecvIntersection(parallel_id, /*merge_parallel_desc=*/*parallel_desc_, /*in_parallel_desc=*/*parallel_desc_, /*out_parallel_desc=*/*parallel_desc_, src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections, 
&dst_recv_intersections); CHECK_EQ(src_send_intersections.size(), parallel_num); send_elem_cnts_.resize(parallel_num); in_tensor_slice_copier_vec_.resize(parallel_num); const TensorSliceView& cur_rank_in_slice = GetTensorSliceView4ParallelId(*parallel_hierarchy, src_nd_sbp, logical_shape, parallel_id); for (int64_t i = 0; i < parallel_num; ++i) { const TensorSliceView& intersection = src_send_intersections.at(i); if (!intersection.IsEmpty()) { send_elem_cnts_.at(i) = intersection.shape().elem_cnt(); in_tensor_slice_copier_vec_.at(i).reset( new TensorSliceCopier(intersection, cur_rank_in_slice, data_type, device_type)); } } CHECK_EQ(dst_recv_intersections.size(), parallel_num); recv_elem_cnts_.resize(parallel_num); out_tensor_slice_copier_vec_.resize(parallel_num); const TensorSliceView& cur_rank_out_slice = GetTensorSliceView4ParallelId(*parallel_hierarchy, dst_nd_sbp, logical_shape, parallel_id); for (int64_t i = 0; i < parallel_num; ++i) { const TensorSliceView& intersection = dst_recv_intersections.at(i); if (!intersection.IsEmpty()) { recv_elem_cnts_.at(i) = intersection.shape().elem_cnt(); out_tensor_slice_copier_vec_.at(i).reset( new TensorSliceCopier(cur_rank_out_slice, intersection, data_type, device_type)); } } } void CclLogicalSendRecvState::InitComm() const { EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ccl::CclComm ccl_comm = comm_mgr->GetCclCommForParallelDescAndStreamName(*parallel_desc_.get(), stream_name_); ccl_comm_.reset(new Comm(ccl_comm)); } class CclLogicalSendRecv final : public user_op::OpKernel { public: OF_DISALLOW_COPY_AND_MOVE(CclLogicalSendRecv); CclLogicalSendRecv() = default; ~CclLogicalSendRecv() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); return comm_mgr->IsAsyncLaunchCclLogicalKernel(); } }; void CclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const { auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); ccl::CclComm ccl_comm = kernel_state->ccl_comm(); const std::vector& send_elem_cnts = kernel_state->send_elem_cnts(); const std::vector& recv_elem_cnts = kernel_state->recv_elem_cnts(); const int64_t parallel_num = send_elem_cnts.size(); const DataType data_type = in->data_type(); std::vector send_in_ptr; std::vector recv_out_ptr; std::vector send_offsets; std::vector recv_offsets; char* buf_ptr = tmp_buffer->mut_dptr(); uint64_t offset = 0; for (int64_t i = 0; i < parallel_num; ++i) { void* send_ptr = reinterpret_cast(buf_ptr + offset); send_in_ptr.push_back(send_ptr); send_offsets.push_back(offset); offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type); } const uint64_t recv_offset = offset; for (int64_t i = 0; i < parallel_num; ++i) { void* recv_ptr = reinterpret_cast(buf_ptr + offset); recv_out_ptr.push_back(recv_ptr); recv_offsets.push_back(offset - recv_offset); offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type); } 
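// tmp_buffer layout: all per-destination send chunks first, then all per-source recv chunks
// (plus, when the source nd_sbp has partial sum, a trailing staging area at buf_ptr + offset used
// for accumulation). send_offsets are absolute offsets into the buffer, while recv_offsets are
// relative to recv_offset, matching the send/recv base pointers passed to AllToAll::Launch below.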
const std::vector>& in_tensor_slice_copier_vec = kernel_state->in_tensor_slice_copier_vec(); for (int64_t i = 0; i < parallel_num; ++i) { if (in_tensor_slice_copier_vec.at(i)) { in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr()); } } std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), data_type, data_type, parallel_num); void* send_buf = reinterpret_cast(buf_ptr); void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, /*has_input=*/true, /*has_output=*/true); const std::vector>& out_tensor_slice_copier_vec = kernel_state->out_tensor_slice_copier_vec(); if (kernel_state->src_nd_sbp_has_no_partial_parallel()) { for (int64_t i = 0; i < parallel_num; ++i) { if (out_tensor_slice_copier_vec.at(i)) { out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i)); } } } else { std::unique_ptr add_primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type(), out->data_type()); CHECK(add_primitive); std::unique_ptr memset_primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type()); CHECK(memset_primitive); bool is_first_slice = true; for (int64_t i = 0; i < parallel_num; ++i) { if (out_tensor_slice_copier_vec.at(i)) { if (is_first_slice) { is_first_slice = false; if (recv_elem_cnts.at(i) != out->shape_view().elem_cnt()) { // if not same shape, memset out memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, out->shape_view().elem_cnt() * GetSizeOfDataType(data_type)); } out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i)); } else { if (recv_elem_cnts.at(i) == out->shape_view().elem_cnt()) { add_primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(), out->shape_view().elem_cnt()); } else { void* out_buf = reinterpret_cast(buf_ptr + offset); memset_primitive->Launch(ctx->stream(), out_buf, 0, out->shape_view().elem_cnt() * GetSizeOfDataType(data_type)); out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out_buf, recv_out_ptr.at(i)); add_primitive->Launch(ctx->stream(), out->dptr(), out_buf, out->mut_dptr(), out->shape_view().elem_cnt()); } } } } } } size_t InferTmpBufferSize(user_op::InferContext* ctx) { const Shape& out_shape = ctx->OutputShape("out", 0); const user_op::TensorDesc* logical_in_tensor = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); const Shape& logical_shape = logical_in_tensor->shape(); const DataType data_type = logical_in_tensor->data_type(); const NdSbp& src_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); const NdSbp& dst_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const int64_t parallel_num = ctx->parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); std::vector src_send_intersections; std::vector dst_recv_intersections; const auto& parallel_desc = ctx->parallel_desc(); GetRankSendRecvIntersection(parallel_id, /*merge_parallel_desc=*/parallel_desc, /*in_parallel_desc=*/parallel_desc, /*out_parallel_desc=*/parallel_desc, src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections, &dst_recv_intersections); int64_t buf_count = 0; CHECK_EQ(src_send_intersections.size(), parallel_num); for (int64_t i = 0; i < parallel_num; ++i) { const TensorSliceView& intersection = src_send_intersections.at(i); if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); } } for (int64_t i = 0; i < 
parallel_num; ++i) { const TensorSliceView& intersection = dst_recv_intersections.at(i); if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); } } if (NdSbpHasPartialParallel(src_nd_sbp)) { // Note: when src_nd_sbp has a partial_sum component, an out-sized buffer is needed to copy into and then add to out. buf_count += out_shape.elem_cnt(); } return buf_count * GetSizeOfDataType(data_type); } // TODO(zhaoluyang): make SetIsMatchedHob support multiple device types (not including CPU) REGISTER_USER_KERNEL("_nccl_logical_send_recv") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) || (user_op::HobDeviceType() == DeviceType::kNPU) || (user_op::HobDeviceType() == DeviceType::kMLU)) .SetInferTmpSizeFn(InferTmpBufferSize); } // namespace oneflow #endif // WITH_CUDA || WITH_NPU || WITH_MLU ================================================ FILE: oneflow/user/kernels/nd_index_slice_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/nd_index_slice_kernels.h" namespace oneflow { template struct GatherNdFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* dense, T* slices) const { DoGatherNd(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape, indices, dense, slices); } }; template struct ScatterNdAddFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const { DoScatterNdAdd(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape, indices, slices, dense); } }; template struct ScatterNdUpdateFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const { DoScatterNdUpdate(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape, indices, slices, dense); } }; template struct ScatterNdUpdateWithStrideFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const { DoScatterNdUpdateWithStride(args.num_slices * args.slice_size, args, indices, slices, dense); } }; template struct FillByNdIndexFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, T* dense, T value) const { DoFillByNdIndex(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape, indices, dense, value); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ND_INDEX_SLICE_FUNCTORS, (DeviceType::kCPU), ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ND_INDEX_SLICE_KERNELS, (DeviceType::kCPU), ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace oneflow
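For reference, the gather_nd/scatter_nd family above and in the CUDA and header files below all reduce to the same offset computation: each of the leading index_ndims dimensions of the dense tensor is addressed through one row of the indices tensor, and the remaining dimensions form a contiguous slice of slice_size elements. The following is a minimal, self-contained sketch of that arithmetic, not the OneFlow kernel itself; SliceOffsetToDenseOffset is a hypothetical helper mirroring OffsetInSliceToOffsetInDense from nd_index_slice_util.h, assuming a contiguous row-major dense tensor.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Maps element n of the flattened slices tensor to the corresponding offset in the
// dense tensor. Only the leading index_ndims dims of dense_shape are indexed; the
// trailing dims form one slice of slice_size elements. Negative indices wrap around,
// matching the shifted_index handling in the kernel.
int64_t SliceOffsetToDenseOffset(int64_t slice_size, int64_t index_ndims,
                                 const std::vector<int64_t>& dense_shape,
                                 const std::vector<int64_t>& indices, int64_t n) {
  const int64_t slice_idx = n / slice_size;
  const int64_t* nd_index = indices.data() + slice_idx * index_ndims;
  int64_t offset = 0;
  int64_t product = 1;
  for (int64_t i = index_ndims - 1; i >= 0; --i) {
    assert(nd_index[i] < dense_shape[i] && nd_index[i] >= -dense_shape[i]);
    const int64_t shifted = nd_index[i] < 0 ? nd_index[i] + dense_shape[i] : nd_index[i];
    offset += shifted * product;
    product *= dense_shape[i];
  }
  return offset * slice_size + n % slice_size;
}

int main() {
  // dense is a 3x4 row-major matrix; gather rows 0 and -1 (which wraps to row 2).
  const std::vector<int64_t> dense_shape = {3, 4};
  const std::vector<float> dense = {0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23};
  const std::vector<int64_t> indices = {0, -1};  // shape: (num_slices=2, index_ndims=1)
  const int64_t index_ndims = 1;
  const int64_t slice_size = 4;  // elements per gathered row
  const int64_t num_slices = 2;
  std::vector<float> slices(num_slices * slice_size);
  for (int64_t n = 0; n < num_slices * slice_size; ++n) {
    slices[n] =
        dense[SliceOffsetToDenseOffset(slice_size, index_ndims, dense_shape, indices, n)];
  }
  for (float v : slices) { std::cout << v << " "; }  // prints: 0 1 2 3 20 21 22 23
  std::cout << std::endl;
  return 0;
}

The scatter variants walk the same mapping in the opposite direction: DoScatterNdUpdate writes dense[offset] = slices[n], and DoScatterNdAdd accumulates into dense[offset] through DeviceAdd, which on CUDA uses atomic adds for the types that support them.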
================================================ FILE: oneflow/user/kernels/nd_index_slice_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/nd_index_slice_kernels.h" #include "oneflow/core/cuda/atomic.cuh" namespace oneflow { namespace { template __global__ void CudaGatherNd(NdIndexSliceArgs args, const I* indices, const T* dense, T* slices) { DoGatherNd(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape, indices, dense, slices); } template __global__ void CudaScatterNdAdd(NdIndexSliceArgs args, const I* indices, const T* slices, T* dense) { DoScatterNdAdd(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape, indices, slices, dense); } template __global__ void CudaScatterNdUpdate(NdIndexSliceArgs args, const I* indices, const T* slices, T* dense) { DoScatterNdUpdate(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape, indices, slices, dense); } template __global__ void CudaScatterNdUpdateWithStride(NdIndexSliceArgs args, const I* indices, const T* slices, T* dense) { DoScatterNdUpdateWithStride(args.num_slices * args.slice_size, args, indices, slices, dense); } template __global__ void CudaFillByNdIndex(NdIndexSliceArgs args, const I* indices, T* dense, T value) { DoFillByNdIndex(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape, indices, dense, value); } } // namespace template struct GatherNdFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* dense, T* slices) const { RUN_CUDA_KERNEL((CudaGatherNd), stream, args.num_slices * args.slice_size, args, indices, dense, slices); } }; template struct ScatterNdAddFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const { RUN_CUDA_KERNEL((CudaScatterNdAdd), stream, args.num_slices * args.slice_size, args, indices, slices, dense); } }; template struct ScatterNdUpdateFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const { RUN_CUDA_KERNEL((CudaScatterNdUpdate), stream, args.num_slices * args.slice_size, args, indices, slices, dense); } }; template struct ScatterNdUpdateWithStrideFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const { RUN_CUDA_KERNEL((CudaScatterNdUpdateWithStride), stream, args.num_slices * args.slice_size, args, indices, slices, dense); } }; template struct FillByNdIndexFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, T* dense, T value) const { RUN_CUDA_KERNEL((CudaFillByNdIndex), stream, args.num_slices * args.slice_size, args, indices, dense, value); } }; template struct DeviceAdd { __device__ __forceinline__ static void Invoke(const T* x, T* y) 
{ cuda::atomic::Add(y, *x); } }; template<> struct DeviceAdd { __device__ __forceinline__ static void Invoke(const bool* x, bool* y) { *y += *x; } }; template<> struct DeviceAdd { __device__ __forceinline__ static void Invoke(const uint8_t* x, uint8_t* y) { *y += *x; } }; template<> struct DeviceAdd { __device__ __forceinline__ static void Invoke(const int8_t* x, int8_t* y) { *y += *x; } }; template<> struct DeviceAdd { __device__ __forceinline__ static void Invoke(const int64_t* x, int64_t* y) { *y += *x; } }; #define CUDA_ATOMIC_ADD_SUPPORTED_DATA_TYPE_SEQ \ FLOATING_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( INSTANTIATE_GATHER_ND_FUNCTOR, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCATTER_ND_ADD_FUNCTOR, (DeviceType::kCUDA), CUDA_ATOMIC_ADD_SUPPORTED_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_FILL_BY_ND_INDEX_FUNCTOR, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( REGISTER_GATHER_ND_KERNELS, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( REGISTER_SCATTER_ND_KERNELS, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SCATTER_ND_LIKE_KERNELS, (DeviceType::kCUDA), CUDA_ATOMIC_ADD_SUPPORTED_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( REGISTER_TENSOR_GATHER_ND_UPDATE_KERNELS, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_TENSOR_GATHER_ND_ADD_KERNELS, (DeviceType::kCUDA), CUDA_ATOMIC_ADD_SUPPORTED_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template<> struct DeviceAdd { __device__ __forceinline__ static void Invoke(const float16* x, float16* y) { cuda::atomic::Add(reinterpret_cast(y), *(reinterpret_cast(x))); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ND_INDEX_SLICE_FUNCTORS, (DeviceType::kCUDA), FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ND_INDEX_SLICE_KERNELS, (DeviceType::kCUDA), FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #if defined(__CUDA_BF16_TYPES_EXIST__) template<> struct DeviceAdd { __device__ __forceinline__ static void Invoke(const bfloat16* x, bfloat16* y) { cuda::atomic::Add(reinterpret_cast(y), *(reinterpret_cast(x))); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ND_INDEX_SLICE_FUNCTORS, (DeviceType::kCUDA), BFLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ND_INDEX_SLICE_KERNELS, (DeviceType::kCUDA), BFLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/nd_index_slice_kernels.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_KERNELS_H_ #define ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_KERNELS_H_ #include "oneflow/user/kernels/nd_index_slice_util.h" #include "oneflow/core/common/tensor_meta.h" namespace oneflow { template class GatherNdKernel final : public user_op::OpKernel { public: GatherNdKernel() = default; ~GatherNdKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ScatterNdKernel final : public user_op::OpKernel { public: ScatterNdKernel() = default; ~ScatterNdKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class TensorScatterNdUpdateKernel final : public user_op::OpKernel { public: TensorScatterNdUpdateKernel() = default; ~TensorScatterNdUpdateKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class TensorScatterNdAddKernel final : public user_op::OpKernel { public: TensorScatterNdAddKernel() = default; ~TensorScatterNdAddKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template void GatherNdKernel::Compute(user_op::KernelComputeContext* ctx) const { const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); const user_op::Tensor* params = ctx->Tensor4ArgNameAndIndex("params", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); if (params->shape_view().elem_cnt() == 0 || indices->shape_view().elem_cnt() == 0) { return; } auto args = ConstructNdIndexSliceArgs(*params, *out, *indices); GatherNdFunctor()(ctx->stream(), args, indices->dptr(), params->dptr(), out->mut_dptr()); } template void ScatterNdKernel::Compute(user_op::KernelComputeContext* ctx) const { const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); const user_op::Tensor* updates = ctx->Tensor4ArgNameAndIndex("updates", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); size_t out_bytes_size = out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type()); Memset(ctx->stream(), out->mut_dptr(), 0, out_bytes_size); if (indices->shape_view().elem_cnt() == 0) { return; } auto args = ConstructNdIndexSliceArgs(*out, *updates, *indices); ScatterNdAddFunctor()(ctx->stream(), args, indices->dptr(), updates->dptr(), out->mut_dptr()); } template void TensorScatterNdUpdateKernel::Compute( user_op::KernelComputeContext* ctx) const { const user_op::Tensor* params = ctx->Tensor4ArgNameAndIndex("params", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); const user_op::Tensor* updates = ctx->Tensor4ArgNameAndIndex("updates", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); size_t 
out_bytes_size = out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type()); Memcpy(ctx->stream(), out->mut_dptr(), params->dptr(), out_bytes_size); if (indices->shape_view().elem_cnt() == 0) { return; } auto args = ConstructNdIndexSliceArgs(*params, *updates, *indices); if (one::IsContiguous(params->shape_view(), params->stride()) && one::IsContiguous(updates->shape_view(), updates->stride())) { ScatterNdUpdateFunctor()(ctx->stream(), args, indices->dptr(), updates->dptr(), out->mut_dptr()); } else { ScatterNdUpdateWithStrideFunctor()(ctx->stream(), args, indices->dptr(), updates->dptr(), out->mut_dptr()); } } template void TensorScatterNdAddKernel::Compute( user_op::KernelComputeContext* ctx) const { const user_op::Tensor* params = ctx->Tensor4ArgNameAndIndex("params", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); const user_op::Tensor* updates = ctx->Tensor4ArgNameAndIndex("updates", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); size_t out_bytes_size = out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type()); Memcpy(ctx->stream(), out->mut_dptr(), params->dptr(), out_bytes_size); if (indices->shape_view().elem_cnt() == 0) { return; } auto args = ConstructNdIndexSliceArgs(*params, *updates, *indices); ScatterNdAddFunctor()(ctx->stream(), args, indices->dptr(), updates->dptr(), out->mut_dptr()); } #define REGISTER_GATHER_SCATTER_ND_KERNELS(op_type_name, op, device_type_v, dtype_pair, \ itype_pair) \ REGISTER_USER_KERNEL(#op_type_name) \ .SetCreateFn< \ op##Kernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(itype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))); #define REGISTER_TENSOR_SCATTER_ND_OPT_KERNELS(op_type_name, opt, device_type_v, dtype_pair, \ itype_pair) \ REGISTER_USER_KERNEL(#op_type_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("indices", 0) == OF_PP_PAIR_SECOND(itype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInplaceProposalFn( \ [](const user_op::InferContext&, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "params", 0, true)); \ return Maybe::Ok(); \ }); #define REGISTER_GATHER_ND_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_GATHER_SCATTER_ND_KERNELS(gather_nd, GatherNd, device_type_v, dtype_pair, itype_pair) #define REGISTER_SCATTER_ND_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_GATHER_SCATTER_ND_KERNELS(scatter_nd, ScatterNd, device_type_v, dtype_pair, itype_pair) #define REGISTER_SCATTER_ND_LIKE_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_GATHER_SCATTER_ND_KERNELS(scatter_nd_like, ScatterNd, device_type_v, dtype_pair, \ itype_pair) #define REGISTER_TENSOR_GATHER_ND_UPDATE_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_TENSOR_SCATTER_ND_OPT_KERNELS(tensor_scatter_nd_update, Update, device_type_v, \ dtype_pair, itype_pair) #define REGISTER_TENSOR_GATHER_ND_ADD_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_TENSOR_SCATTER_ND_OPT_KERNELS(tensor_scatter_nd_add, Add, device_type_v, dtype_pair, \ itype_pair) #define REGISTER_ND_INDEX_SLICE_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_GATHER_ND_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_SCATTER_ND_KERNELS(device_type_v, dtype_pair, itype_pair) \ 
REGISTER_SCATTER_ND_LIKE_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_TENSOR_GATHER_ND_UPDATE_KERNELS(device_type_v, dtype_pair, itype_pair) \ REGISTER_TENSOR_GATHER_ND_ADD_KERNELS(device_type_v, dtype_pair, itype_pair) } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_KERNELS_H_ ================================================ FILE: oneflow/user/kernels/nd_index_slice_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_UTIL_H_ #define ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_UTIL_H_ #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/xpu_util.h" namespace oneflow { struct NdIndexSliceArgs { static const size_t kMaxDims = 8; int64_t num_slices; // The number of slices (indices_shape.Count(0, -1)) int64_t slice_size; // The element_cnt of each slice (sliced_shape.Count(indices_num_axes-1)) int64_t index_ndims; // The number of dims which are sliced (indices_shape.At(-1)) int64_t dense_ndims; int64_t dense_shape[kMaxDims]; int64_t dense_stride[kMaxDims]; int64_t slices_ndims; int64_t slices_shape[kMaxDims]; int64_t slices_stride[kMaxDims]; }; inline NdIndexSliceArgs ConstructNdIndexSliceArgs(const user_op::Tensor& dense, const user_op::Tensor& slices, const user_op::Tensor& indices) { NdIndexSliceArgs args; std::memset(&args, 0, sizeof(NdIndexSliceArgs)); args.num_slices = indices.shape_view().Count(0, indices.shape_view().NumAxes() - 1); args.index_ndims = indices.shape_view().At(indices.shape_view().NumAxes() - 1); args.slice_size = slices.shape_view().Count(indices.shape_view().NumAxes() - 1); args.dense_ndims = dense.shape_view().NumAxes(); FOR_RANGE(int64_t, i, 0, dense.shape_view().NumAxes()) { args.dense_shape[i] = dense.shape_view().At(i); args.dense_stride[i] = dense.stride().at(i); } args.slices_ndims = slices.shape_view().NumAxes(); FOR_RANGE(int64_t, i, 0, slices.stride().size()) { args.slices_shape[i] = slices.shape_view().At(i); args.slices_stride[i] = slices.stride().at(i); } return args; } template struct GatherNdFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* dense, T* slices) const; }; template struct ScatterNdAddFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const; }; template struct ScatterNdUpdateFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const; }; template struct ScatterNdUpdateWithStrideFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) const; }; template struct FillByNdIndexFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, T* dense, T value) const; }; template OF_DEVICE_FUNC int64_t 
OffsetInSliceToOffsetInDense(int64_t slice_size, int64_t index_ndims, const int64_t* dense_shape, const I* indices, int64_t n) { int64_t slice_idx = n / slice_size; const I* nd_index = indices + slice_idx * index_ndims; int64_t offset = 0; int64_t product = 1; int64_t shifted_index = 0; for (int64_t i = index_ndims - 1; i >= 0; --i) { #if defined(__CUDACC__) assert(nd_index[i] < dense_shape[i] && nd_index[i] >= -dense_shape[i] && "index out of bounds"); #else CHECK(nd_index[i] < dense_shape[i] && nd_index[i] >= -dense_shape[i]) << "IndexError: index " << nd_index[i] << " is out of bounds for dimension " << i << " with size " << dense_shape[i]; #endif shifted_index = nd_index[i] < 0 && nd_index[i] >= -dense_shape[i] ? nd_index[i] + dense_shape[i] : nd_index[i]; offset += shifted_index * product; product *= dense_shape[i]; } return offset * slice_size + n % slice_size; } OF_DEVICE_FUNC int64_t GetMemoryOffset4ElementIdx(int64_t n, int64_t ndims, const int64_t* shape, const int64_t* stride) { int64_t offset = 0; for (int64_t i = ndims - 1; i >= 0; --i) { offset += n % shape[i] * stride[i]; n /= shape[i]; } return offset; } template OF_DEVICE_FUNC void DoGatherNd(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims, const int64_t* dense_shape, const I* indices, const T* dense, T* slices) { XPU_1D_KERNEL_LOOP(i, elem_cnt) { int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i); slices[i] = dense[offset]; } } template struct DeviceAdd { OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { *y += *x; } }; template OF_DEVICE_FUNC void DoScatterNdAdd(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims, const int64_t* dense_shape, const I* indices, const T* slices, T* dense) { XPU_1D_KERNEL_LOOP(i, elem_cnt) { int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i); DeviceAdd::Invoke(slices + i, dense + offset); } } template OF_DEVICE_FUNC void DoScatterNdUpdate(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims, const int64_t* dense_shape, const I* indices, const T* slices, T* dense) { XPU_1D_KERNEL_LOOP(i, elem_cnt) { int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i); dense[offset] = slices[i]; } } template OF_DEVICE_FUNC void DoScatterNdUpdateWithStride(int64_t elem_cnt, const NdIndexSliceArgs& args, const I* indices, const T* slices, T* dense) { XPU_1D_KERNEL_LOOP(i, elem_cnt) { // dense tensor memory offset int64_t dense_index = OffsetInSliceToOffsetInDense(args.slice_size, args.index_ndims, args.dense_shape, indices, i); int64_t dense_mem_offset = GetMemoryOffset4ElementIdx(dense_index, args.dense_ndims, args.dense_shape, args.dense_stride); // update tensor memory offset int64_t slice_mem_offset = GetMemoryOffset4ElementIdx(i, args.slices_ndims, args.slices_shape, args.slices_stride); dense[dense_mem_offset] = slices[slice_mem_offset]; } } template OF_DEVICE_FUNC void DoFillByNdIndex(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims, const int64_t* dense_shape, const I* indices, T* dense, T value) { XPU_1D_KERNEL_LOOP(i, elem_cnt) { int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i); dense[offset] = value; } } #define INSTANTIATE_GATHER_ND_FUNCTOR(device_type_v, dtype_pair, itype_pair) \ template struct GatherNdFunctor; #define INSTANTIATE_SCATTER_ND_ADD_FUNCTOR(device_type_v, dtype_pair, itype_pair) \ template struct ScatterNdAddFunctor; #define INSTANTIATE_FILL_BY_ND_INDEX_FUNCTOR(device_type_v, 
dtype_pair, itype_pair) \ template struct FillByNdIndexFunctor; #define INSTANTIATE_ND_INDEX_SLICE_FUNCTORS(device_type_v, dtype_pair, itype_pair) \ INSTANTIATE_GATHER_ND_FUNCTOR(device_type_v, dtype_pair, itype_pair) \ INSTANTIATE_SCATTER_ND_ADD_FUNCTOR(device_type_v, dtype_pair, itype_pair) \ INSTANTIATE_FILL_BY_ND_INDEX_FUNCTOR(device_type_v, dtype_pair, itype_pair) } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_UTIL_H_ ================================================ FILE: oneflow/user/kernels/nll_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/user/kernels/nll_kernel_util.h" namespace oneflow { namespace { class NLLKernelCache final : public user_op::OpKernelCache { public: NLLKernelCache(int64_t class_start, int64_t num_classes) : class_start_(class_start), num_classes_(num_classes) {} ~NLLKernelCache() override = default; int64_t class_start() const { return class_start_; } int64_t num_classes() const { return num_classes_; } private: const int64_t class_start_; const int64_t num_classes_; }; std::shared_ptr CreateNLLKernelCache(user_op::KernelCacheContext* ctx) { CHECK_GT(ctx->parallel_ctx().parallel_num(), 0) << ctx->op_name() << ": invalid parallel_ctx"; if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; } const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("input", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); CHECK_EQ(nd_sbp.sbp_parallel_size(), hierarchy.NumAxes()) << ctx->op_name() << ": Expected input sbp " << NdSbpToString(nd_sbp) << " match hierarchy " << hierarchy.ToString(); const Shape& shape = ctx->LogicalTensorDesc4ArgNameAndIndex("input", 0)->shape(); const int64_t class_axis = shape.NumAxes() - 1; bool split_class_dim = false; for (const auto& sbp : nd_sbp.sbp_parallel()) { if (sbp.has_split_parallel() && sbp.split_parallel().axis() == class_axis) { split_class_dim = true; break; } } if (!split_class_dim) { return nullptr; } TensorSliceView view = GetTensorSliceView4ParallelId(hierarchy, nd_sbp, shape, ctx->parallel_ctx().parallel_id()); return std::make_shared(view.At(class_axis).begin(), view.At(class_axis).size()); } } // namespace template class NLLKernel final : public user_op::OpKernel { public: NLLKernel() = default; ~NLLKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateNLLKernelCache(ctx); } private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const auto* input = ctx->Tensor4ArgNameAndIndex("input", 0); const auto* target = ctx->Tensor4ArgNameAndIndex("target", 0); auto* output = ctx->Tensor4ArgNameAndIndex("output", 0); auto* out_weight = 
ctx->Tensor4ArgNameAndIndex("out_weight", 0); const int64_t N = target->shape_view().elem_cnt(); const int64_t C = input->shape_view().At(input->shape_view().NumAxes() - 1); CHECK_LE(N, std::numeric_limits::max()) << "Expected batch size not exceed int32 numeric limits"; K class_start = 0; if (cache) { const auto* spec_cache = dynamic_cast(cache); CHECK_NOTNULL(spec_cache); CHECK_EQ(spec_cache->num_classes(), C) << ctx->op_name() << ": expected num_classes " << C << ", got " << spec_cache->num_classes(); class_start = spec_cache->class_start(); } const K ignore_index = static_cast(ctx->Attr("ignore_index")); const T* weight_dptr = nullptr; if (ctx->has_input("weight", 0)) { weight_dptr = CHECK_NOTNULL(ctx->Tensor4ArgNameAndIndex("weight", 0))->dptr(); } NLLKernelUtil::Forward(ctx->stream(), static_cast(N), static_cast(C), class_start, ignore_index, input->dptr(), target->dptr(), weight_dptr, output->mut_dptr(), out_weight->mut_dptr()); } }; template class NLLGradKernel final : public user_op::OpKernel { public: NLLGradKernel() = default; ~NLLGradKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateNLLKernelCache(ctx); } private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const auto* target = ctx->Tensor4ArgNameAndIndex("target", 0); const auto* out_grad = ctx->Tensor4ArgNameAndIndex("out_grad", 0); auto* in_grad = ctx->Tensor4ArgNameAndIndex("in_grad", 0); const int64_t N = target->shape_view().elem_cnt(); const int64_t C = in_grad->shape_view().At(in_grad->shape_view().NumAxes() - 1); CHECK_LE(N, std::numeric_limits::max()) << "Expected batch size not exceed int32 numeric limits"; K class_start = 0; if (cache) { const auto* spec_cache = dynamic_cast(cache); CHECK_NOTNULL(spec_cache); CHECK_EQ(spec_cache->num_classes(), C) << ctx->op_name() << ": expected num_classes " << C << ", got " << spec_cache->num_classes(); class_start = spec_cache->class_start(); } const K ignore_index = static_cast(ctx->Attr("ignore_index")); const T* weight_dptr = nullptr; if (ctx->has_input("weight", 0)) { weight_dptr = CHECK_NOTNULL(ctx->Tensor4ArgNameAndIndex("weight", 0))->dptr(); } NLLKernelUtil::Backward( ctx->stream(), static_cast(N), static_cast(C), class_start, ignore_index, out_grad->dptr(), target->dptr(), weight_dptr, in_grad->mut_dptr()); } }; #define REGISTER_NLL_KERNELS(device, dtype, ltype) \ REGISTER_USER_KERNEL("nll").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("nll_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input", 0) == GetDataType::value) \ && (user_op::HobDataType("target", 0) == GetDataType::value) \ && (user_op::HobDataType("out_grad", 0) == GetDataType::value)) REGISTER_NLL_KERNELS(DeviceType::kCPU, float, int32_t); REGISTER_NLL_KERNELS(DeviceType::kCPU, float, int64_t); REGISTER_NLL_KERNELS(DeviceType::kCPU, double, int32_t); REGISTER_NLL_KERNELS(DeviceType::kCPU, double, int64_t); #ifdef WITH_CUDA REGISTER_NLL_KERNELS(DeviceType::kCUDA, float, int32_t); REGISTER_NLL_KERNELS(DeviceType::kCUDA, float, int64_t); REGISTER_NLL_KERNELS(DeviceType::kCUDA, double, int32_t); REGISTER_NLL_KERNELS(DeviceType::kCUDA, double, 
int64_t); REGISTER_NLL_KERNELS(DeviceType::kCUDA, half, int32_t); REGISTER_NLL_KERNELS(DeviceType::kCUDA, half, int64_t); #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/nll_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/nll_kernel_util.h" namespace oneflow { template struct NLLKernelUtil { static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes, const K class_start, const K ignore_index, const T* input, const K* target, const T* weight, T* out, T* out_weight) { FOR_RANGE(int32_t, i, 0, num_samples) { K label = target[i]; T w = T{0}; T y = T{0}; if (label != ignore_index) { label -= class_start; if (label >= 0 && label < num_classes) { w = weight ? weight[label] : T{1}; y = -(input[i * num_classes + label] * w); } } out[i] = y; out_weight[i] = w; } } static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes, const K class_start, const K ignore_index, const T* out_grad, const K* target, const T* weight, T* in_grad) { Memset(stream, in_grad, 0, RoundUp(num_samples * num_classes * sizeof(T), kBlobBodyAlignSize)); FOR_RANGE(int32_t, i, 0, num_samples) { K label = target[i]; if (label == ignore_index) { continue; } label -= class_start; if (label >= 0 && label < num_classes) { const T w = weight ? -weight[label] : T(-1); in_grad[i * num_classes + label] = out_grad[i] * w; } } } }; template struct NLLKernelUtil; template struct NLLKernelUtil; template struct NLLKernelUtil; template struct NLLKernelUtil; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/nll_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/nll_kernel_util.h" #include "oneflow/core/cuda/atomic.cuh" namespace oneflow { namespace { template __global__ void NLLForward(const int32_t num_samples, const K num_classes, const K class_start, const K ignore_index, const T* input, const K* target, const T* weight, T* out, T* out_weight) { const T zero = GetZeroVal(); const T one = GetOneVal(); CUDA_1D_KERNEL_LOOP(i, num_samples) { K label = target[i]; T w = zero; T y = zero; if (label != ignore_index) { label -= class_start; if (label >= 0 && label < num_classes) { w = weight ? 
weight[label] : one; y = -(input[i * num_classes + label] * w); } } out[i] = y; out_weight[i] = w; } } template __global__ void NLLBackward(const int32_t num_samples, const K num_classes, const K class_start, const K ignore_index, const T* out_grad, const K* target, const T* weight, T* in_grad) { const T one = GetOneVal(); const T zero = GetZeroVal(); CUDA_1D_KERNEL_LOOP_T(K, i, num_samples * num_classes) { const K n = i / num_classes; const K idx = i - n * num_classes; const K label = target[n]; if (label != ignore_index && idx == label - class_start) { in_grad[i] = out_grad[n] * (weight ? -weight[idx] : -one); } else { in_grad[i] = zero; } } } } // namespace template struct NLLKernelUtil { static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes, const K class_start, const K ignore_index, const T* input, const K* target, const T* weight, T* out, T* out_weight) { NLLForward<<As()->cuda_stream()>>>(num_samples, num_classes, class_start, ignore_index, input, target, weight, out, out_weight); } static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes, const K class_start, const K ignore_index, const T* out_grad, const K* target, const T* weight, T* in_grad) { NLLBackward<<As()->cuda_stream()>>>( num_samples, num_classes, class_start, ignore_index, out_grad, target, weight, in_grad); } }; template struct NLLKernelUtil; template struct NLLKernelUtil; template struct NLLKernelUtil; template struct NLLKernelUtil; template struct NLLKernelUtil; template struct NLLKernelUtil; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/nll_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct NLLKernelUtil { static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes, const K class_start, const K ignore_index, const T* input, const K* target, const T* weight, T* out, T* out_weight); static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes, const K class_start, const K ignore_index, const T* out_grad, const K* target, const T* weight, T* in_grad); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/nms_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __inline__ T IoU(T const* const a, T const* const b) { T interS = std::max(std::min(a[2], b[2]) - std::max(a[0], b[0]), static_cast(0.f)) * std::max(std::min(a[3], b[3]) - std::max(a[1], b[1]), static_cast(0.f)); T Sa = (a[2] - a[0]) * (a[3] - a[1]); T Sb = (b[2] - b[0]) * (b[3] - b[1]); return interS / (Sa + Sb - interS); } } // namespace template class NmsCpuKernel final : public user_op::OpKernel { public: NmsCpuKernel() = default; ~NmsCpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* boxes_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* keep_blob = ctx->Tensor4ArgNameAndIndex("out", 0); const T* boxes = boxes_blob->dptr(); int8_t* keep = keep_blob->mut_dptr(); const int num_boxes = boxes_blob->shape_view().At(0); int num_keep = ctx->Attr("keep_n"); if (num_keep <= 0 || num_keep > num_boxes) { num_keep = num_boxes; } const float iou_threshold = ctx->Attr("iou_threshold"); for (int i = 0; i < num_boxes; i++) { keep[i] = -1; } for (int i = 0; i < num_boxes; i++) { if (keep[i] == 0) continue; keep[i] = 1; for (int j = i + 1; j < num_boxes; j++) { if (IoU(boxes + i * 4, boxes + j * 4) > iou_threshold) { keep[j] = 0; } } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_NMS_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("nms").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == DataType::kInt8) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_NMS_CPU_KERNEL(float) REGISTER_NMS_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/nms_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { constexpr int kBlockSize = sizeof(int64_t) * 8; template __host__ __device__ __forceinline__ T CeilDiv(T a, T b) { return (a + b - 1) / b; } template __host__ __device__ __forceinline__ T IoU(T const* const a, T const* const b) { T interS = max(min(a[2], b[2]) - max(a[0], b[0]), 0.f) * max(min(a[3], b[3]) - max(a[1], b[1]), 0.f); T Sa = (a[2] - a[0]) * (a[3] - a[1]); T Sb = (b[2] - b[0]) * (b[3] - b[1]); return interS / (Sa + Sb - interS); } template __global__ void CalcSuppressionBitmaskMatrix(int num_boxes, float iou_threshold, const T* boxes, int64_t* suppression_bmask_matrix) { const int row = blockIdx.y; const int col = blockIdx.x; if (row > col) return; const int row_size = min(num_boxes - row * kBlockSize, kBlockSize); const int col_size = min(num_boxes - col * kBlockSize, kBlockSize); __shared__ T block_boxes[kBlockSize * 4]; if (threadIdx.x < col_size) { block_boxes[threadIdx.x * 4 + 0] = boxes[(kBlockSize * col + threadIdx.x) * 4 + 0]; block_boxes[threadIdx.x * 4 + 1] = boxes[(kBlockSize * col + threadIdx.x) * 4 + 1]; block_boxes[threadIdx.x * 4 + 2] = boxes[(kBlockSize * col + threadIdx.x) * 4 + 2]; block_boxes[threadIdx.x * 4 + 3] = boxes[(kBlockSize * col + threadIdx.x) * 4 + 3]; } __syncthreads(); if (threadIdx.x < row_size) { const int cur_box_idx = kBlockSize * row + threadIdx.x; const T* cur_box_ptr = boxes + cur_box_idx * 4; unsigned long long bits = 0; int start = 0; if (row == col) { start = threadIdx.x + 1; } for (int i = start; i < col_size; i++) { if (IoU(cur_box_ptr, block_boxes + i * 4) > iou_threshold) { bits |= 1Ull << i; } } suppression_bmask_matrix[cur_box_idx * gridDim.y + col] = bits; } } __global__ void ScanSuppression(int num_boxes, int num_blocks, int num_keep, int64_t* suppression_bmask, int8_t* keep_mask) { extern __shared__ int64_t remv[]; remv[threadIdx.x] = 0; for (int i = 0; i < num_boxes; ++i) { int block_n = i / kBlockSize; int block_i = i % kBlockSize; if (!(remv[block_n] & (1Ull << block_i))) { remv[threadIdx.x] |= suppression_bmask[i * num_blocks + threadIdx.x]; if (threadIdx.x == block_n && num_keep > 0) { keep_mask[i] = 1; num_keep -= 1; } } } } } // namespace template class NmsGpuKernel final : public user_op::OpKernel { public: NmsGpuKernel() = default; ~NmsGpuKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* boxes_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* keep_blob = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_blob = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const T* boxes = boxes_blob->dptr(); int8_t* keep = keep_blob->mut_dptr(); int64_t* suppression_mask = tmp_blob->mut_dptr(); const int num_boxes = boxes_blob->shape_view().At(0); int num_keep = ctx->Attr("keep_n"); if (num_keep <= 0 || num_keep > num_boxes) { num_keep = num_boxes; } const int num_blocks = CeilDiv(num_boxes, kBlockSize); Memset(ctx->stream(), suppression_mask, 0, num_boxes * num_blocks * sizeof(int64_t)); Memset(ctx->stream(), keep, 0, num_boxes * sizeof(int8_t)); dim3 blocks(num_blocks, num_blocks); dim3 threads(kBlockSize); CalcSuppressionBitmaskMatrix<<stream()->As()->cuda_stream()>>>( num_boxes, ctx->Attr("iou_threshold"), boxes, suppression_mask); ScanSuppression<<<1, num_blocks, num_blocks * sizeof(int64_t), ctx->stream()->As()->cuda_stream()>>>( 
num_boxes, num_blocks, num_keep, suppression_mask, keep); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_NMS_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("nms") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == DataType::kInt8) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->Shape4ArgNameAndIndex("in", 0); \ int64_t num_boxes = in_shape.At(0); \ int64_t blocks = CeilDiv(num_boxes, kBlockSize); \ return num_boxes * blocks * sizeof(int64_t); \ }); REGISTER_NMS_CUDA_KERNEL(float) REGISTER_NMS_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/noncontiguous_binary_op.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/user_op_tensor.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/primitive/fast_integer_math.h" #include "oneflow/core/cuda/elementwise.cuh" namespace oneflow { namespace { #define MaxDims 6 #define MAX2(a, b) ((a) > (b)) ? 
(a) : (b) #define MAX3(a, b, c) MAX2(MAX2(a, b), c) using cuda::elementwise::Packed; #define DEFINE_BINARY_FUNCTOR(OP, expr) \ template \ struct OP { \ __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a expr b; } \ }; \ template<> \ struct OP { \ __device__ __forceinline__ half operator()(const half& a, const half& b) const { \ return __float2half(__half2float(a) expr __half2float(b)); \ } \ }; DEFINE_BINARY_FUNCTOR(Add, +) DEFINE_BINARY_FUNCTOR(Sub, -) DEFINE_BINARY_FUNCTOR(Mul, *) DEFINE_BINARY_FUNCTOR(Div, /) #undef DEFINE_BINARY_FUNCTOR #define DEFINE_BINARY_OP_GRAD_FUNCTOR(OP, dl_expr, dr_expr) \ template \ struct OP##Grad { \ __device__ __forceinline__ void operator()(const T& dout, const T& a, const T& b, T* da, \ T* db) const { \ *da = dl_expr dout; \ *db = dr_expr dout; \ } \ }; \ template<> \ struct OP##Grad { \ __device__ __forceinline__ void operator()(const half& hdout, const half& ha, const half& hb, \ half* hda, half* hdb) const { \ float dout, a, b; \ dout = __half2float(hdout), a = __half2float(ha), b = __half2float(hb); \ *hda = __float2half(dl_expr dout); \ *hdb = __float2half(dr_expr dout); \ } \ }; DEFINE_BINARY_OP_GRAD_FUNCTOR(Add, 1 *, 1 *) DEFINE_BINARY_OP_GRAD_FUNCTOR(Sub, 1 *, -1 *) DEFINE_BINARY_OP_GRAD_FUNCTOR(Mul, b*, a*) DEFINE_BINARY_OP_GRAD_FUNCTOR(Div, 1 / b*, -a / b / b*) #undef DEFINE_BINARY_OP_GRAD_FUNCTOR template __global__ void noncontiguous_binary_op_kernel(IndexType n_pack, Store y, Loader1 x1, Loader2 x2) { Packed pack_y; Packed pack_x1; Packed pack_x2; CUDA_1D_KERNEL_LOOP_T(IndexType, i, n_pack) { x1.load(i, &pack_x1); x2.load(i, &pack_x2); #pragma unroll for (int j = 0; j < pack_size; ++j) pack_y.elem[j] = BinaryOp()(static_cast(pack_x1.elem[j]), static_cast(pack_x2.elem[j])); // todo: Apply2 y.store(i, &pack_y); } }; template struct LoadStore { LoadStore(FastIntegerMath fast_integer_math[MaxDims], const int ndims, const int strides[MaxDims], const Src* src, Dst* dst = nullptr, bool is_contiguous = false) : ndims_(ndims), src_(src), dst_(dst), is_contiguous_(is_contiguous) { for (int i = 0; i < ndims; i++) { strides_[i] = static_cast(strides[i]); fast_integer_math_[i] = fast_integer_math[i]; } } OF_DEVICE_FUNCTION IndexType index2offset(IndexType index) { IndexType offset = 0; IndexType div = 0, mod = 0; #pragma unroll for (int dim = ndims_ - 1; dim >= 0; --dim) { if (index == 0) break; fast_integer_math_[dim].divmod(index, &div, &mod); index = div; offset += mod * strides_[dim]; } return offset; } OF_DEVICE_FUNCTION void load(IndexType idx, Packed* pack) { IndexType offset; if (is_contiguous_) offset = idx * pack_size; else offset = index2offset(idx); *pack = *(reinterpret_cast*>(src_ + offset)); } OF_DEVICE_FUNCTION void store(IndexType idx, Packed* pack) { IndexType offset; if (is_contiguous_) offset = idx * pack_size; else offset = index2offset(idx); *(reinterpret_cast*>(dst_ + offset)) = *pack; } int ndims_; int pack_dim_; bool is_contiguous_; const Src* src_; Dst* dst_; IndexType strides_[MaxDims]; FastIntegerMath fast_integer_math_[MaxDims]; }; template void launch_noncontiguous_binary_op_kernel(cudaStream_t stream, const IndexType n_pack, Store& store, Load1& load1, Load2& load2) { int num_blocks = 1, block_size = cuda::elementwise::kBlockSize; cudaError_t err = cuda::elementwise::GetNumBlocks(n_pack, &num_blocks); CHECK(err == cudaSuccess); noncontiguous_binary_op_kernel <<>>(n_pack, store, load1, load2); } template void dispatchOp(cudaStream_t stream, const std::string& op, const IndexType n_pack, Store& 
store, Load1& load1, Load2& load2) { if (op == "add") launch_noncontiguous_binary_op_kernel, R, lhs, rhs>( stream, n_pack, store, load1, load2); else if (op == "sub") launch_noncontiguous_binary_op_kernel, R, lhs, rhs>( stream, n_pack, store, load1, load2); else if (op == "mul") launch_noncontiguous_binary_op_kernel, R, lhs, rhs>( stream, n_pack, store, load1, load2); else if (op == "div") launch_noncontiguous_binary_op_kernel, R, lhs, rhs>( stream, n_pack, store, load1, load2); else UNIMPLEMENTED_THEN_THROW(); } template void dispatchInplace(cudaStream_t stream, const bool inplace, const std::string& op, const int& ndims, const IndexType n_pack, const int sizes[MaxDims], const int strides[][MaxDims], R* y, const lhs* x1, const rhs* x2) { typedef FastIntegerMath FastIntegerMathT; FastIntegerMathT fast_integer_math[MaxDims]; for (int i = 0; i < ndims; ++i) fast_integer_math[i] = FastIntegerMathT(sizes[i]); if (inplace) { LoadStore load_store(fast_integer_math, ndims, strides[0], x1, y); LoadStore loader2(fast_integer_math, ndims, strides[2], x2); dispatchOp(stream, op, n_pack, load_store, load_store, loader2); } else { LoadStore store(fast_integer_math, ndims, strides[0], nullptr, y); LoadStore loader1(fast_integer_math, ndims, strides[1], x1); LoadStore loader2(fast_integer_math, ndims, strides[2], x2); dispatchOp(stream, op, n_pack, store, loader1, loader2); } } template void dispatchIndexType(cudaStream_t stream, const bool inplace, const std::string& op, const int& ndims, const int64_t& n_pack, const int sizes[MaxDims], const int strides[][MaxDims], R* y, const lhs* x1, const rhs* x2) { if ((n_pack * pack_size) >> 30 == 0) { int32_t n = (int32_t)n_pack; dispatchInplace(stream, inplace, op, ndims, n, sizes, strides, y, x1, x2); } else dispatchInplace(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1, x2); } template void dispatchPacksize(cudaStream_t stream, const bool inplace, const std::string& op, const int& ndims, const int64_t n_pack, int pack_size, const int sizes[MaxDims], const int strides[][MaxDims], R* y, const lhs* x1, const rhs* x2) { if (pack_size == 8) dispatchIndexType<8, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1, x2); else if (pack_size == 4) dispatchIndexType<4, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1, x2); else if (pack_size == 2) dispatchIndexType<2, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1, x2); else if (pack_size == 1) dispatchIndexType<1, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1, x2); else UNIMPLEMENTED(); } } // namespace template class NonContiguousBinaryOpKernel final : public user_op::OpKernel { public: NonContiguousBinaryOpKernel() = default; ~NonContiguousBinaryOpKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("lhs", 0); const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("rhs", 0); const std::string op = ctx->Attr("op"); const bool inplace = ctx->Attr("inplace"); int ndims = y->shape_view().NumAxes(); const ShapeView& shape = y->shape_view(); int sizes[MaxDims]; int strides[3][MaxDims]; int pack_size = 1; int64_t elem_cnt = 1; int max_elem_size = MAX3(GetSizeOfDataType(y->data_type()), GetSizeOfDataType(x1->data_type()), GetSizeOfDataType(x2->data_type())); for (int i = 0; i < ndims; ++i) { sizes[i] = shape.At(i); elem_cnt *= 
shape.At(i); strides[0][i] = y->stride()[i]; strides[1][i] = x1->stride()[i]; strides[2][i] = x2->stride()[i]; if (x1->stride()[i] == 1 && x2->stride()[i] == 1 && y->stride()[i] == 1) { pack_size = 16 / max_elem_size; while (pack_size > 1 && sizes[i] % pack_size) pack_size >>= 1; sizes[i] = sizes[i] / pack_size; strides[0][i] *= pack_size; strides[1][i] *= pack_size; strides[2][i] *= pack_size; } } dispatchPacksize(ctx->stream()->As()->cuda_stream(), inplace, op, ndims, elem_cnt / pack_size, pack_size, sizes, strides, y->mut_dptr(), x1->dptr(), x2->dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(dtype, lhs, rhs) \ REGISTER_USER_KERNEL("noncontiguous_binary_op") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value) \ && (user_op::HobDataType("lhs", 0) == GetDataType::value) \ && (user_op::HobDataType("rhs", 0) == GetDataType::value)); // output_type, lhs_type, rhs_type REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(float, float, float) REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(half, half, half) // #if CUDA_VERSION >= 11000 // REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(nv_bfloat16, nv_bfloat16, nv_bfloat16) // #endif // ------------------------------------- grad kernel ------------------------------------- template __global__ void noncontiguous_binary_op_grad_kernel(IndexType n_pack, Loadery dy, Loader1 load1, Loader2 load2) { Packed pack_dy; Packed pack_x1; Packed pack_x2; Packed pack_dx1; Packed pack_dx2; CUDA_1D_KERNEL_LOOP_T(IndexType, i, n_pack) { load1.load(i, &pack_x1); load2.load(i, &pack_x2); dy.load(i, &pack_dy); #pragma unroll for (int j = 0; j < pack_size; ++j) BinaryOp()(pack_dy.elem[j], pack_x1.elem[j], pack_x2.elem[j], &pack_dx1.elem[j], &pack_dx2.elem[j]); // todo: Apply2 load1.store(i, &pack_dx1); load2.store(i, &pack_dx2); } }; template void launch_noncontiguous_binary_op_grad_kernel(cudaStream_t stream, const IndexType n_pack, Loady& load_y, Load1& load1, Load2& load2) { int num_blocks = 1, block_size = cuda::elementwise::kBlockSize; cudaError_t err = cuda::elementwise::GetNumBlocks(n_pack, &num_blocks); CHECK(err == cudaSuccess); noncontiguous_binary_op_grad_kernel <<>>(n_pack, load_y, load1, load2); } template void dispatchOpGrad(cudaStream_t stream, const std::string& op, const IndexType& n_pack, Loady& load_y, Load1& load1, Load2& load2) { if (op == "add") launch_noncontiguous_binary_op_grad_kernel, R, lhs, rhs>( stream, n_pack, load_y, load1, load2); else if (op == "sub") launch_noncontiguous_binary_op_grad_kernel, R, lhs, rhs>( stream, n_pack, load_y, load1, load2); else if (op == "mul") launch_noncontiguous_binary_op_grad_kernel, R, lhs, rhs>( stream, n_pack, load_y, load1, load2); else if (op == "div") launch_noncontiguous_binary_op_grad_kernel, R, lhs, rhs>( stream, n_pack, load_y, load1, load2); else UNIMPLEMENTED_THEN_THROW(); } template void dispatchLoader(cudaStream_t stream, const std::string& op, const int& ndims, const IndexType n_pack, const int sizes[MaxDims], const int strides[][MaxDims], lhs* dx1, rhs* dx2, const R* dy, const lhs* x1, const rhs* x2) { typedef FastIntegerMath FastIntegerMathT; FastIntegerMathT fast_integer_math[MaxDims]; for (int i = 0; i < ndims; ++i) fast_integer_math[i] = FastIntegerMathT(sizes[i]); LoadStore load_y(fast_integer_math, ndims, strides[0], dy); LoadStore loader_store1( fast_integer_math, ndims, strides[1], x1, 
dx1); LoadStore loader_store2( fast_integer_math, ndims, strides[2], x2, dx2); dispatchOpGrad(stream, op, n_pack, load_y, loader_store1, loader_store2); } template void dispatchIndexTypeGrad(cudaStream_t stream, const std::string& op, const int& ndims, const int64_t& n_pack, const int sizes[MaxDims], const int strides[][MaxDims], lhs* dx1, rhs* dx2, const R* dy, const lhs* x1, const rhs* x2) { if ((n_pack * pack_size) >> 30 == 0) { int32_t n = (int32_t)n_pack; dispatchLoader(stream, op, ndims, n, sizes, strides, dx1, dx2, dy, x1, x2); } else dispatchLoader(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy, x1, x2); } template void dispatchPacksizeGrad(cudaStream_t stream, const std::string& op, const int& ndims, const int64_t& n_pack, int& pack_size, const int sizes[MaxDims], const int strides[][MaxDims], lhs* dx1, rhs* dx2, const R* dy, const lhs* x1, const rhs* x2) { if (pack_size == 8) dispatchIndexTypeGrad<8, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy, x1, x2); else if (pack_size == 4) dispatchIndexTypeGrad<4, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy, x1, x2); else if (pack_size == 2) dispatchIndexTypeGrad<2, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy, x1, x2); else if (pack_size == 1) dispatchIndexTypeGrad<1, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy, x1, x2); else UNIMPLEMENTED(); } template class NonContiguousBinaryOpGradKernel final : public user_op::OpKernel { public: NonContiguousBinaryOpGradKernel() = default; ~NonContiguousBinaryOpGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("lhs", 0); const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("rhs", 0); user_op::Tensor* dx1 = ctx->Tensor4ArgNameAndIndex("dlhs", 0); user_op::Tensor* dx2 = ctx->Tensor4ArgNameAndIndex("drhs", 0); const std::string op = ctx->Attr("op"); const bool inplace = ctx->Attr("inplace"); CHECK(inplace == false) << "inplace should be set to `false` to compute gradients."; int ndims = dy->shape_view().NumAxes(); const ShapeView& shape = dy->shape_view(); int sizes[MaxDims]; int strides[3][MaxDims]; int pack_size = 1; int64_t elem_cnt = 1; int max_elem_size = MAX3(GetSizeOfDataType(dy->data_type()), GetSizeOfDataType(x1->data_type()), GetSizeOfDataType(x2->data_type())); for (int i = 0; i < ndims; ++i) { sizes[i] = shape.At(i); elem_cnt *= shape.At(i); strides[0][i] = dy->stride()[i]; strides[1][i] = x1->stride()[i]; strides[2][i] = x2->stride()[i]; if (x1->stride()[i] == 1 && x2->stride()[i] == 1 && dy->stride()[i] == 1) { pack_size = 16 / max_elem_size; while (pack_size > 1 && sizes[i] % pack_size) pack_size >>= 1; sizes[i] = sizes[i] / pack_size; strides[0][i] *= pack_size; strides[1][i] *= pack_size; strides[2][i] *= pack_size; } } dispatchPacksizeGrad(ctx->stream()->As()->cuda_stream(), op, ndims, elem_cnt / pack_size, pack_size, sizes, strides, dx1->mut_dptr(), dx2->mut_dptr(), dy->dptr(), x1->dptr(), x2->dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(dtype, lhs, rhs) \ REGISTER_USER_KERNEL("noncontiguous_binary_op_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType::value) \ && 
(user_op::HobDataType("lhs", 0) == GetDataType::value) \ && (user_op::HobDataType("rhs", 0) == GetDataType::value)); // output_type, lhs_type, rhs_type REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(float, float, float) REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(half, half, half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/nop_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class NopKernel final : public user_op::OpKernel { public: NopKernel() = default; ~NopKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { // do nothing } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_NOP_KERNEL(op_type_name) \ REGISTER_USER_KERNEL(op_type_name).SetCreateFn(); REGISTER_NOP_KERNEL("cast_to_tick") REGISTER_NOP_KERNEL("acc_ctrl_tick") REGISTER_NOP_KERNEL("repeat") } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/normalization_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" namespace oneflow { template static void ComputeMeanAndVar(const T* input_ptr, T* mean_ptr, T* inv_variance_ptr, T* moving_mean_ptr, T* moving_variance_ptr, const int64_t batch_size, const int64_t channel_size, const int64_t spatial_size, const float epsilon, const float momentum) { // NOTE(Liang Depeng): the following parameters were used to compute mean and var const int64_t jump_step = spatial_size * channel_size; const int64_t reduce_count = batch_size * spatial_size; const int64_t unbias_reduce_count = reduce_count - 1; const T reduce_scale_factor = static_cast(1) / reduce_count; const T unbias_reduce_scale_factor = static_cast(1) / unbias_reduce_count; const T unbias_reduce_scale_factor_m2 = unbias_reduce_scale_factor * -static_cast(2); const T unbias_reduce_scale_factor_mn = reduce_count * unbias_reduce_scale_factor; const T exponential_average_factor = 1.0f - momentum; for (int64_t channel = 0; channel < channel_size; ++channel) { const T* temp_input_ptr = input_ptr + channel * spatial_size; T sum = 0; T sum_square = 0; for (int64_t batch = 0; batch < batch_size; ++batch) { for (int64_t s = 0; s < spatial_size; ++s) { const T x = temp_input_ptr[s]; sum += x; sum_square += x * x; } temp_input_ptr += jump_step; } const T temp_mean = sum * reduce_scale_factor; mean_ptr[channel] = temp_mean; const T temp_mean_square = temp_mean * temp_mean; const T temp_variance = sum_square * reduce_scale_factor - temp_mean_square; const T temp_unbias_variance = sum_square * unbias_reduce_scale_factor + unbias_reduce_scale_factor_m2 * temp_mean * sum + unbias_reduce_scale_factor_mn * temp_mean_square; inv_variance_ptr[channel] = static_cast(1) / std::sqrt(temp_variance + epsilon); if (moving_mean_ptr != nullptr && moving_variance_ptr != nullptr) { moving_mean_ptr[channel] = moving_mean_ptr[channel] * momentum + temp_mean * exponential_average_factor; moving_variance_ptr[channel] = moving_variance_ptr[channel] * momentum + temp_unbias_variance * exponential_average_factor; } } } template static void Normalize(const T* input_ptr, const T* mean_ptr, const T* variance_ptr, const T* gamma_ptr, const T* beta_ptr, T* output_ptr, const int64_t batch_size, const int64_t channel_size, const int64_t spatial_size, const float epsilon, const bool training) { const T* temp_input_ptr = input_ptr; T* temp_output_ptr = output_ptr; const int64_t all_channels = batch_size * channel_size; int64_t channel = -1; for (int64_t ac = 0; ac < all_channels; ++ac) { channel += 1; if (channel >= channel_size) { channel = 0; } T inv_variance = variance_ptr[channel]; if (!training) { inv_variance = 1.0f / std::sqrt(inv_variance + epsilon); } const T gamma = gamma_ptr[channel] * inv_variance; const T beta = beta_ptr[channel]; const T mean = mean_ptr[channel]; for (int64_t s = 0; s < spatial_size; ++s) { temp_output_ptr[s] = (temp_input_ptr[s] - mean) * gamma + beta; } temp_input_ptr += spatial_size; temp_output_ptr += spatial_size; } } template static void AddToOutput(const T* add_to_output_ptr, T* output_ptr, const int64_t elem_count) { for (int64_t i = 0; i < elem_count; ++i) { output_ptr[i] += add_to_output_ptr[i]; } } template static void AddRelu(const T* addend_ptr, int32_t* mask_ptr, T* output_ptr, const int64_t elem_cnt) { const int32_t step = 32; const int64_t outer_loop = elem_cnt / step; const int64_t remain_loop_start_idx = outer_loop * step; T* temp_output_ptr = output_ptr; for (int64_t outer = 0; outer < outer_loop; ++outer) { int32_t mask = 0; for (int32_t s = 0; s < 
step; ++s) { const T sum = temp_output_ptr[s] + addend_ptr[s]; const bool is_positive = (sum > 0); mask = mask | (static_cast(is_positive) << s); temp_output_ptr[s] = is_positive ? sum : 0; } mask_ptr[outer] = mask; addend_ptr += step; temp_output_ptr += step; } if (remain_loop_start_idx < elem_cnt) { int32_t mask_val = 0; const int32_t remain = elem_cnt - remain_loop_start_idx; for (int32_t i = 0; i < remain; ++i) { const T sum = temp_output_ptr[i] + addend_ptr[i]; const bool is_positive = (sum > 0); mask_val = mask_val | (static_cast(is_positive) << i); temp_output_ptr[i] = is_positive ? sum : 0; } mask_ptr[outer_loop] = mask_val; } } template static void Relu(int32_t* mask_ptr, T* output_ptr, const int64_t elem_cnt) { const int32_t step = 32; const int64_t outer_loop = elem_cnt / step; const int64_t remain_loop_start_idx = outer_loop * step; T* temp_output_ptr = output_ptr; for (int64_t outer = 0; outer < outer_loop; ++outer) { int32_t mask_val = 0; for (int32_t s = 0; s < step; ++s) { const T output = temp_output_ptr[s]; const bool is_positive = (output > 0); mask_val = mask_val | (static_cast(is_positive) << s); temp_output_ptr[s] = is_positive ? output : 0; } mask_ptr[outer] = mask_val; temp_output_ptr += step; } if (remain_loop_start_idx < elem_cnt) { int32_t mask_val = 0; const int32_t remain = elem_cnt - remain_loop_start_idx; for (int32_t i = 0; i < remain; ++i) { const T output = temp_output_ptr[i]; const bool is_positive = (output > 0); mask_val = mask_val | (static_cast(is_positive) << i); temp_output_ptr[i] = is_positive ? output : 0; } mask_ptr[outer_loop] = mask_val; } } template static void AddReluGrad(const T* dy_ptr, const int32_t* mask_ptr, T* addend_diff_ptr, const int64_t elem_cnt) { const int32_t step = 32; const int64_t outer_loop = elem_cnt / step; const int64_t remain_loop_start_idx = outer_loop * step; for (int64_t outer = 0; outer < outer_loop; ++outer) { const int32_t mask_val = mask_ptr[outer]; for (int32_t s = 0; s < step; ++s) { bool is_positive = mask_val & (1 << s); addend_diff_ptr[s] = static_cast(is_positive) * dy_ptr[s]; } addend_diff_ptr += step; dy_ptr += step; } if (remain_loop_start_idx < elem_cnt) { const int32_t mask_val = mask_ptr[outer_loop]; const int32_t remain = elem_cnt - remain_loop_start_idx; for (int32_t i = 0; i < remain; ++i) { bool is_positive = mask_val & (1 << i); addend_diff_ptr[i] = static_cast(is_positive) * dy_ptr[i]; } } } template static void ReluGrad(const T* dy_ptr, const int32_t* mask_ptr, T* relu_dx_ptr, const int64_t elem_cnt) { const int32_t step = 32; const int64_t outer_loop = elem_cnt / step; const int64_t remain_loop_start_idx = outer_loop * step; for (int64_t outer = 0; outer < outer_loop; ++outer) { const int32_t mask_val = mask_ptr[outer]; for (int32_t s = 0; s < step; ++s) { bool is_positive = mask_val & (1 << s); relu_dx_ptr[s] = static_cast(is_positive) * dy_ptr[s]; } relu_dx_ptr += step; dy_ptr += step; } if (remain_loop_start_idx < elem_cnt) { const int32_t mask_val = mask_ptr[outer_loop]; const int32_t remain = elem_cnt - remain_loop_start_idx; for (int32_t i = 0; i < remain; ++i) { bool is_positive = mask_val & (1 << i); relu_dx_ptr[i] = static_cast(is_positive) * dy_ptr[i]; } } } static size_t InferGradTmpSizeForCpuKernel(user_op::InferContext* ctx) { const auto& dy = ctx->InputTensorDesc("dy", 0); size_t tmp_size = 0; if (ctx->op_type_name() == "normalization_add_relu_grad" && !ctx->has_output("addend_diff", 0)) { tmp_size += dy.shape().elem_cnt() * GetSizeOfDataType(dy.data_type()); } return tmp_size; } 
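// ----------------------------------------------------------------------------
// Illustrative sketch (not part of the original kernels above): the forward
// helpers AddRelu/Relu pack one "output stayed positive" bit per element into
// int32 words, 32 elements per word, and the backward helpers AddReluGrad/
// ReluGrad re-read that bit to decide whether dy propagates. A minimal scalar
// round trip for float data could look like the following; the function names
// are hypothetical and these helpers are not referenced elsewhere.
static inline void SketchPackReluMask(const float* out, int32_t* mask, int64_t elem_cnt) {
  for (int64_t i = 0; i < elem_cnt; ++i) {
    if (i % 32 == 0) { mask[i / 32] = 0; }
    // bit (i % 32) of word (i / 32) records whether element i is positive
    mask[i / 32] |= static_cast<int32_t>(out[i] > 0.f) << (i % 32);
  }
}
static inline void SketchUnpackReluGrad(const float* dy, const int32_t* mask, float* dx,
                                        int64_t elem_cnt) {
  for (int64_t i = 0; i < elem_cnt; ++i) {
    const bool keep = mask[i / 32] & (1 << (i % 32));
    dx[i] = keep ? dy[i] : 0.f;  // gradient passes only where the forward output was positive
  }
}
// ----------------------------------------------------------------------------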
// NOTE(Liang Depeng): helper functions to process datas for specific channel over all samples. template static inline void ForEachFast(const T* data, const int64_t batch_size, const int64_t spatial_size, const int64_t jump_step, const int64_t channel_idx, DataProcessor data_processor) { const int64_t start_offset = channel_idx * spatial_size; const T* tmp_data = data + start_offset; for (int64_t outer = 0; outer < batch_size; ++outer) { for (int64_t i = 0; i < spatial_size; ++i) { data_processor(&tmp_data[i]); } tmp_data += jump_step; } } template static inline void ForEachFast(const T* in_data1, const T* in_data2, const int64_t batch_size, const int64_t spatial_size, const int64_t jump_step, const int64_t channel_idx, DataProcessor data_processor) { const int64_t start_offset = channel_idx * spatial_size; const T* tmp_in_data1 = in_data1 + start_offset; const T* tmp_in_data2 = in_data2 + start_offset; for (int64_t outer = 0; outer < batch_size; ++outer) { for (int64_t i = 0; i < spatial_size; ++i) { data_processor(&tmp_in_data1[i], &tmp_in_data2[i]); } tmp_in_data1 += jump_step; tmp_in_data2 += jump_step; } } template static inline void ForEachFast(const T* in_data, T* out_data, const int64_t batch_size, const int64_t spatial_size, const int64_t jump_step, const int64_t channel_idx, DataProcessor data_processor) { const int64_t start_offset = channel_idx * spatial_size; const T* tmp_in_data = in_data + start_offset; T* tmp_out_data = out_data + start_offset; for (int64_t outer = 0; outer < batch_size; ++outer) { for (int64_t i = 0; i < spatial_size; ++i) { data_processor(&tmp_in_data[i], &tmp_out_data[i]); } tmp_in_data += jump_step; tmp_out_data += jump_step; } } template class NormalizationInferenceCpuKernel final : public user_op::OpKernel { public: NormalizationInferenceCpuKernel() = default; ~NormalizationInferenceCpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const bool training = ctx->Attr("training"); CHECK(!training); const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0); auto* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); auto* moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0); auto* moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0); const auto axis = ctx->Attr("axis"); const auto epsilon = ctx->Attr("epsilon"); const DataType data_type = x->data_type(); CHECK_EQ(x->shape_view(), y->shape_view()); CHECK_EQ(y->data_type(), data_type); CHECK_GE(axis, 0); CHECK_LT(axis, x->shape_view().NumAxes()); if (axis == 1) { // NOTE(Liang Depeng): NCHW format const T* input_ptr = x->dptr(); const T* gamma_ptr = gamma->dptr(); const T* beta_ptr = beta->dptr(); T* output_ptr = y->mut_dptr(); T* moving_mean_ptr = moving_mean->mut_dptr(); T* moving_variance_ptr = moving_variance->mut_dptr(); const int64_t batch_size = x->shape_view().At(0); const int64_t channel_size = x->shape_view().At(axis); const int64_t spatial_size = x->shape_view().Count(axis + 1); // NOTE(Liang Depeng): // compute the normalization result Normalize(input_ptr, moving_mean_ptr, moving_variance_ptr, gamma_ptr, beta_ptr, output_ptr, batch_size, channel_size, spatial_size, epsilon, false); if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), y->data_type()); 
CHECK_EQ(add_to_output->shape_view(), y->shape_view()); AddToOutput(add_to_output->dptr(), output_ptr, x->shape_view().elem_cnt()); } } else { // TODO(Liang Depeng): NHWC format UNIMPLEMENTED() << "cpu normalization op only support nchw data_format now!"; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BN_INFERENCE_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("normalization") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value) \ && (user_op::HobAttr("training") == false)) \ .SetInplaceProposalFn( \ [](const user_op::InferContext& ctx, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ if (ctx.has_input("_add_to_output", 0)) { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "_add_to_output", 0, true)); \ } \ return Maybe::Ok(); \ }); REGISTER_BN_INFERENCE_CPU_KERNEL(float) REGISTER_BN_INFERENCE_CPU_KERNEL(double) #undef REGISTER_BN_INFERENCE_CPU_KERNEL template class NormalizationTrainCpuKernel final : public user_op::OpKernel { public: NormalizationTrainCpuKernel() = default; ~NormalizationTrainCpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { if (ctx->op_type_name() == "normalization") { CHECK(ctx->Attr("training")); } const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0); auto* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto axis = ctx->Attr("axis"); const auto epsilon = ctx->Attr("epsilon"); const auto momentum = ctx->Attr("momentum"); const DataType data_type = x->data_type(); CHECK_EQ(x->shape_view(), y->shape_view()); CHECK_EQ(y->data_type(), data_type); CHECK_GE(axis, 0); CHECK_LT(axis, x->shape_view().NumAxes()); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); user_op::Tensor* moving_mean = nullptr; user_op::Tensor* moving_variance = nullptr; if (ctx->has_input("moving_mean", 0)) { CHECK(ctx->has_input("moving_variance", 0)); moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0); moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0); } if (axis == 1) { // NOTE(Liang Depeng): NCHW format const T* input_ptr = x->dptr(); const T* gamma_ptr = gamma->dptr(); const T* beta_ptr = beta->dptr(); T* output_ptr = y->mut_dptr(); T* mean_ptr = mean->mut_dptr(); T* inv_variance_ptr = inv_variance->mut_dptr(); T* moving_mean_ptr = nullptr; T* moving_variance_ptr = nullptr; if (moving_mean != nullptr && moving_variance != nullptr) { moving_mean_ptr = moving_mean->mut_dptr(); moving_variance_ptr = moving_variance->mut_dptr(); } const int64_t batch_size = x->shape_view().At(0); const int64_t channel_size = x->shape_view().At(axis); const int64_t spatial_size = x->shape_view().Count(axis + 1); // NOTE(Liang Depeng): // Compute mean & inv_variance and update moving_mean & moving_variance for each channel. 
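// In formula form, per channel c over N = batch_size * spatial_size elements:
//   mean[c]         = sum(x) / N
//   inv_variance[c] = 1 / sqrt(sum(x^2)/N - mean[c]^2 + epsilon)
//   moving_mean[c]     = momentum * moving_mean[c]     + (1 - momentum) * mean[c]
//   moving_variance[c] = momentum * moving_variance[c] + (1 - momentum) * var_unbiased[c]
// where var_unbiased[c] divides the centered sum of squares by N - 1.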
ComputeMeanAndVar(input_ptr, mean_ptr, inv_variance_ptr, moving_mean_ptr, moving_variance_ptr, batch_size, channel_size, spatial_size, epsilon, momentum); // NOTE(Liang Depeng): // compute the normalization result Normalize(input_ptr, mean_ptr, inv_variance_ptr, gamma_ptr, beta_ptr, output_ptr, batch_size, channel_size, spatial_size, epsilon, true); if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), y->data_type()); CHECK_EQ(add_to_output->shape_view(), y->shape_view()); AddToOutput(add_to_output->dptr(), output_ptr, x->shape_view().elem_cnt()); } if (ctx->op_type_name() == "normalization_add_relu") { CHECK(!ctx->has_input("_add_to_output", 0)); auto* mask = ctx->Tensor4ArgNameAndIndex("reserve_space", 0); if (ctx->has_input("addend", 0)) { const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0); AddRelu(addend->dptr(), mask->mut_dptr(), output_ptr, x->shape_view().elem_cnt()); } else { Relu(mask->mut_dptr(), output_ptr, x->shape_view().elem_cnt()); } } } else { // TODO(Liang Depeng): NHWC format } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BN_TRAIN_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("normalization") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value) \ && (user_op::HobAttr("training") == true)) \ .SetInplaceProposalFn( \ [](const user_op::InferContext& ctx, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ if (ctx.has_input("_add_to_output", 0)) { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "_add_to_output", 0, true)); \ } \ return Maybe::Ok(); \ }); REGISTER_BN_TRAIN_CPU_KERNEL(float) REGISTER_BN_TRAIN_CPU_KERNEL(double) #undef REGISTER_BN_TRAIN_CPU_KERNEL #define REGISTER_BN_ADD_RELU_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("normalization_add_relu") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_BN_ADD_RELU_CPU_KERNEL(float) REGISTER_BN_ADD_RELU_CPU_KERNEL(double) #undef REGISTER_BN_ADD_RELU_CPU_KERNEL template class NormalizationGradCpuKernel final : public user_op::OpKernel { public: NormalizationGradCpuKernel() = default; ~NormalizationGradCpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0); auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); auto* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0); auto* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0); const auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto axis = ctx->Attr("axis"); const DataType data_type = x->data_type(); CHECK_EQ(dy->shape_view(), x->shape_view()); CHECK_EQ(dy->data_type(), data_type); CHECK_EQ(dx->shape_view(), x->shape_view()); CHECK_EQ(dx->data_type(), data_type); CHECK_GE(axis, 0); CHECK_LT(axis, x->shape_view().NumAxes()); const T* dy_ptr = nullptr; if (ctx->op_type_name() == "normalization_grad") { dy_ptr = dy->dptr(); } else if (ctx->op_type_name() == "normalization_add_relu_grad") { const auto* 
mask = ctx->Tensor4ArgNameAndIndex("reserve_space", 0); if (ctx->has_output("addend_diff", 0)) { user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex("addend_diff", 0); AddReluGrad(dy->dptr(), mask->dptr(), addend_diff->mut_dptr(), dy->shape_view().elem_cnt()); dy_ptr = addend_diff->dptr(); } else { ReluGrad(dy->dptr(), mask->dptr(), tmp_buffer->mut_dptr(), dy->shape_view().elem_cnt()); dy_ptr = tmp_buffer->dptr(); } } else { UNIMPLEMENTED(); } if (axis == 1) { // NOTE(Liang Depeng): NCHW format const T* x_ptr = x->dptr(); const T* gamma_ptr = gamma->dptr(); const T* mean_ptr = mean->dptr(); const T* inv_variance_ptr = inv_variance->dptr(); T* dx_ptr = dx->mut_dptr(); T* gamma_diff_ptr = gamma_diff->mut_dptr(); T* beta_diff_ptr = beta_diff->mut_dptr(); const int64_t batch_size = x->shape_view().At(0); const int64_t channel_size = x->shape_view().At(axis); const int64_t spatial_size = x->shape_view().Count(axis + 1); const int64_t jump_step = spatial_size * channel_size; const int64_t reduce_count = batch_size * spatial_size; // NOTE(Liang Depeng): // Borrow the MXNet implementation to compute dx, gamma_diff and beta_diff. // For more details pls refers to: // https://github.com/apache/incubator-mxnet/blob/master/src/operator/nn/batch_norm.cc for (int64_t channel = 0; channel < channel_size; ++channel) { const T gamma_c = gamma_ptr[channel]; const T mean_c = mean_ptr[channel]; const T inv_variance_c = inv_variance_ptr[channel]; // NOTE(Liang Depeng): sum dy for specific channel over all samples T sum_dy_out = 0; ForEachFast(dy_ptr, batch_size, spatial_size, jump_step, channel, [&sum_dy_out](const T* dy_data) { sum_dy_out += *dy_data; }); // NOTE(Liang Depeng): dot product of the x and dy T dotp = 0; ForEachFast(x_ptr, dy_ptr, batch_size, spatial_size, jump_step, channel, [&dotp, mean_c](const T* x_data, const T* dy_data) { dotp += (*x_data - mean_c) * (*dy_data); }); // NOTE(Liang Depeng): projection of dy on to output scaled by std const T k = dotp * inv_variance_c * inv_variance_c / reduce_count; const T iw = inv_variance_c * gamma_c; const T grad_mean_c = sum_dy_out / reduce_count; ForEachFast( x_ptr, dx_ptr, batch_size, spatial_size, jump_step, channel, [&mean_c, &k](const T* x_data, T* dx_data) { *dx_data = (*x_data - mean_c) * k; }); ForEachFast(dy_ptr, dx_ptr, batch_size, spatial_size, jump_step, channel, [iw, grad_mean_c](const T* dy_data, T* dx_data) { *dx_data = (*dy_data - grad_mean_c - *dx_data) * iw; }); gamma_diff_ptr[channel] = dotp * inv_variance_c; beta_diff_ptr[channel] = sum_dy_out; } } else { // TODO(Liang Depeng): NHWC format } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BN_GRAD_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("normalization_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_BN_GRAD_CPU_KERNEL(float) REGISTER_BN_GRAD_CPU_KERNEL(double) #undef REGISTER_BN_GRAD_CPU_KERNEL #define REGISTER_BN_ADD_RELU_GRAD_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("normalization_add_relu_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferGradTmpSizeForCpuKernel); REGISTER_BN_ADD_RELU_GRAD_CPU_KERNEL(float) REGISTER_BN_ADD_RELU_GRAD_CPU_KERNEL(double) #undef REGISTER_BN_ADD_RELU_GRAD_CPU_KERNEL } // namespace oneflow ================================================ FILE: 
oneflow/user/kernels/normalization_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/device/cuda_pseudo_bfloat16.h" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #include #if (CUDNN_VERSION >= 7401) #define BN_ENABLE_EX_API #endif namespace oneflow { namespace { cudnnBatchNormMode_t getCudnnBatchNormMode(const int64_t dim) { if (dim == 2) { return CUDNN_BATCHNORM_PER_ACTIVATION; } else if (ParseBooleanFromEnv("ONEFLOW_ENABLE_NHWC", false)) { return CUDNN_BATCHNORM_SPATIAL_PERSISTENT; } else { // NOTE(Liang Depeng): The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was // introduced in CuDNN 7 for performance optimization, but it results in // accuracy losses in convolution models such as ResNeXt-101 and // video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL return CUDNN_BATCHNORM_SPATIAL; } } void InferDimSizeAndDataFormat(const ShapeView& x_shape, const int32_t axis, int32_t* n, int32_t* c, int32_t* h, int32_t* w, cudnnTensorFormat_t* format) { if (x_shape.Count(axis + 1) == 1) { if (axis == 0) { *n = 1; *h = 1; } else { *n = x_shape.At(0); *h = x_shape.Count(1, axis); } *w = 1; *c = x_shape.At(axis); *format = CUDNN_TENSOR_NHWC; } else { *n = x_shape.Count(0, axis); *c = x_shape.At(axis); *h = x_shape.Count(axis + 1); *w = 1; *format = CUDNN_TENSOR_NCHW; } } void InferXYCudnnTensorDesc(const ShapeView& xy_shape, const DataType& data_type, const int32_t axis, cudnnTensorDescriptor_t xy_desc) { int32_t n, c, h, w; cudnnTensorFormat_t format; InferDimSizeAndDataFormat(xy_shape, axis, &n, &c, &h, &w, &format); OF_CUDNN_CHECK( cudnnSetTensor4dDescriptor(xy_desc, format, GetCudnnDataType(data_type), n, c, h, w)); } void InferParamCudnnTensorDesc(const cudnnTensorDescriptor_t xy_desc, cudnnBatchNormMode_t mode, cudnnTensorDescriptor_t param_desc) { OF_CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(param_desc, xy_desc, mode)); } class CudnnTensorDescHelper final { public: OF_DISALLOW_COPY_AND_MOVE(CudnnTensorDescHelper); CudnnTensorDescHelper(const ShapeView& xy_shape, const DataType& data_type, const int32_t axis, cudnnBatchNormMode_t mode) { OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&xy_desc_)); InferXYCudnnTensorDesc(xy_shape, data_type, axis, xy_desc_); OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(¶m_desc_)); InferParamCudnnTensorDesc(xy_desc_, mode, param_desc_); int n, c, h, w, n_stride, c_stride, h_stride, w_stride; OF_CUDNN_CHECK(cudnnGetTensor4dDescriptor(param_desc_, ¶m_data_type_, &n, &c, &h, &w, &n_stride, &c_stride, &h_stride, &w_stride)); param_size_ = c; } ~CudnnTensorDescHelper() { OF_CUDNN_CHECK(cudnnDestroyTensorDescriptor(param_desc_)); 
OF_CUDNN_CHECK(cudnnDestroyTensorDescriptor(xy_desc_)); } cudnnTensorDescriptor_t xy_desc() const { return xy_desc_; } cudnnTensorDescriptor_t param_desc() const { return param_desc_; } void CheckParamTensor(const user_op::Tensor* tensor) const { CHECK_NOTNULL(tensor); CHECK_EQ(tensor->shape_view().NumAxes(), 1); CHECK_EQ(tensor->shape_view().At(0), param_size_); CHECK_EQ(GetCudnnDataType(tensor->data_type()), param_data_type_); } private: cudnnTensorDescriptor_t xy_desc_ = nullptr; cudnnTensorDescriptor_t param_desc_ = nullptr; cudnnDataType_t param_data_type_; int32_t param_size_ = 0; }; size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type, const int32_t axis) { #if defined(BN_ENABLE_EX_API) cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x_shape.NumAxes()); const CudnnTensorDescHelper desc_helper(x_shape, data_type, axis, mode); size_t size_in_bytes; cudnnHandle_t handle = Singleton::Get()->Get(); OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( handle, mode, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &size_in_bytes)); Singleton::Get()->Put(handle); return std::max(size_in_bytes, static_cast(1)); #else return 1; #endif } size_t InferTrainTmpSize(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("x", 0); const auto axis = ctx->Attr("axis"); return InferTrainWorkspaceSize(x.shape(), x.data_type(), axis); } size_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type, const int32_t axis) { #if defined(BN_ENABLE_EX_API) cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x_shape.NumAxes()); const CudnnTensorDescHelper desc_helper(x_shape, data_type, axis, mode); size_t size_in_bytes; cudnnHandle_t handle = Singleton::Get()->Get(); OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize( handle, mode, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &size_in_bytes)); Singleton::Get()->Put(handle); return std::max(size_in_bytes, static_cast(1)); #else return 1; #endif } size_t InferGradTmpSize(user_op::InferContext* ctx) { const auto& dy = ctx->InputTensorDesc("dy", 0); const auto axis = ctx->Attr("axis"); size_t tmp_size = 0; if (ctx->op_type_name() == "normalization_add_relu_grad" && !ctx->has_output("addend_diff", 0)) { tmp_size += GetCudaAlignedSize(dy.shape().elem_cnt() * GetSizeOfDataType(dy.data_type())); } tmp_size += GetCudaAlignedSize(InferGradWorkspaceSize(dy.shape(), dy.data_type(), axis)); return tmp_size; } constexpr int64_t kCudaWarpSize = 32; template __global__ void ReluGpu(int64_t n, const T* x, T* y, int32_t* mask) { const int32_t lane_id = threadIdx.x % kCudaWarpSize; const T zero = static_cast(0.f); CUDA_1D_KERNEL_LOOP(i, n) { const T x_val = x[i]; const bool is_positive = (x_val > zero); int32_t warp_mask = __ballot_sync(__activemask(), static_cast(is_positive)); if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; } y[i] = is_positive ? x_val : zero; } } template __global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int32_t* mask) { const int32_t lane_id = threadIdx.x % kCudaWarpSize; const T zero = static_cast(0.f); CUDA_1D_KERNEL_LOOP(i, n) { const T sum = x[i] + addend[i]; const bool is_positive = (sum > zero); int32_t warp_mask = __ballot_sync(__activemask(), static_cast(is_positive)); if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; } y[i] = is_positive ? 
sum : zero; } } template void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int32_t* mask) { ReluGpu<<As()->cuda_stream()>>>(n, x, y, mask); } template void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int32_t* mask) { AddReluGpu<<As()->cuda_stream()>>>(n, x, addend, y, mask); } template __global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const T* dy, T* addend_diff) { int32_t lane_id = threadIdx.x % kCudaWarpSize; CUDA_1D_KERNEL_LOOP(i, n) { int32_t mask_val = mask[i / kCudaWarpSize]; bool is_positive = mask_val & (1 << lane_id); addend_diff[i] = static_cast(is_positive) * dy[i]; } } #if CUDA_VERSION >= 11000 template<> __global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const nv_bfloat16* dy, nv_bfloat16* addend_diff) { int32_t lane_id = threadIdx.x % kCudaWarpSize; CUDA_1D_KERNEL_LOOP(i, n) { int32_t mask_val = mask[i / kCudaWarpSize]; bool is_positive = mask_val & (1 << lane_id); addend_diff[i] = static_cast(static_cast(is_positive)) * dy[i]; } } #endif template void ReluBackward(ep::Stream* stream, int64_t n, const int32_t* mask, const T* dy, T* addend_diff) { ReluBackwardGpu<<As()->cuda_stream()>>>(n, mask, dy, addend_diff); } void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, void* y, int32_t* mask) { if (data_type == kFloat) { Relu(stream, n, reinterpret_cast(x), reinterpret_cast(y), mask); } else if (data_type == kDouble) { Relu(stream, n, reinterpret_cast(x), reinterpret_cast(y), mask); } else if (data_type == kFloat16) { Relu(stream, n, reinterpret_cast(x), reinterpret_cast(y), mask); } else if (data_type == kBFloat16) { #if CUDA_VERSION >= 11000 Relu(stream, n, reinterpret_cast(x), reinterpret_cast(y), mask); #else UNIMPLEMENTED(); #endif } else { UNIMPLEMENTED(); } } void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, const void* addend, void* y, int32_t* mask) { if (data_type == kFloat) { AddRelu(stream, n, reinterpret_cast(x), reinterpret_cast(addend), reinterpret_cast(y), mask); } else if (data_type == kDouble) { AddRelu(stream, n, reinterpret_cast(x), reinterpret_cast(addend), reinterpret_cast(y), mask); } else if (data_type == kFloat16) { AddRelu(stream, n, reinterpret_cast(x), reinterpret_cast(addend), reinterpret_cast(y), mask); } else if (data_type == kBFloat16) { #if CUDA_VERSION >= 11000 AddRelu(stream, n, reinterpret_cast(x), reinterpret_cast(addend), reinterpret_cast(y), mask); #else UNIMPLEMENTED(); #endif } else { UNIMPLEMENTED(); } } void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int32_t* mask, const void* dy, void* addend_diff) { if (data_type == kFloat) { ReluBackward(stream, n, mask, reinterpret_cast(dy), reinterpret_cast(addend_diff)); } else if (data_type == kDouble) { ReluBackward(stream, n, mask, reinterpret_cast(dy), reinterpret_cast(addend_diff)); } else if (data_type == kFloat16) { ReluBackward(stream, n, mask, reinterpret_cast(dy), reinterpret_cast(addend_diff)); } else if (data_type == kBFloat16) { #if CUDA_VERSION >= 11000 ReluBackward(stream, n, mask, reinterpret_cast(dy), reinterpret_cast(addend_diff)); #else UNIMPLEMENTED(); #endif } else { UNIMPLEMENTED(); } } class NormalizationInferenceKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: NormalizationInferenceKernel() = default; ~NormalizationInferenceKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const bool training 
= ctx->Attr("training"); CHECK(!training); const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0); auto* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); auto* moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0); auto* moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0); const auto axis = ctx->Attr("axis"); const auto epsilon = ctx->Attr("epsilon"); const DataType data_type = x->data_type(); CHECK_EQ(x->shape_view(), y->shape_view()); CHECK_EQ(y->data_type(), data_type); CHECK_GE(axis, 0); CHECK_LT(axis, x->shape_view().NumAxes()); cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes()); const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode); desc_helper.CheckParamTensor(gamma); desc_helper.CheckParamTensor(beta); desc_helper.CheckParamTensor(moving_mean); desc_helper.CheckParamTensor(moving_variance); const void* sp_alpha = CudnnSPOnePtr(data_type); const void* sp_beta; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), y->data_type()); CHECK_EQ(add_to_output->shape_view(), y->shape_view()); Memcpy( ctx->stream(), y->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); sp_beta = CudnnSPOnePtr(data_type); } else { sp_beta = CudnnSPZeroPtr(data_type); } OF_CUDNN_CHECK(cudnnBatchNormalizationForwardInference( ctx->stream()->As()->cudnn_handle(), mode, sp_alpha, sp_beta, desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(), moving_variance->dptr(), epsilon)); if (ctx->op_type_name() == "normalization_add_relu") { CHECK(!ctx->has_input("_add_to_output", 0)); const int64_t elem_cnt = x->shape_view().elem_cnt(); auto* mask = ctx->Tensor4ArgNameAndIndex("reserve_space", 0); if (ctx->has_input("addend", 0)) { const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0); AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(), mask->mut_dptr()); } else { Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(), mask->mut_dptr()); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("normalization") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobAttr("training") == false)) .SetInplaceProposalFn([](const user_op::InferContext& ctx, user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); REGISTER_USER_KERNEL("normalization_add_relu") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobAttr("training") == false)); class NormalizationTrainKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: NormalizationTrainKernel() = default; ~NormalizationTrainKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { if (ctx->op_type_name() == "normalization") { CHECK(ctx->Attr("training")); } const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0); auto* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto axis = ctx->Attr("axis"); 
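// Overall flow of this training kernel: build cuDNN tensor/param descriptors,
// optionally copy `_add_to_output` into y and switch the cuDNN blend factor
// (sp_beta) to one so the result is accumulated rather than overwritten, run
// the cuDNN batch-norm forward-training call (the Ex API when no reserve space
// is needed and the workspace fits in `tmp_buffer`, otherwise the plain API),
// and finally, for "normalization_add_relu", apply the add+ReLU epilogue and
// record the per-element sign mask in `reserve_space` for the backward pass.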
const auto epsilon = ctx->Attr("epsilon"); const auto momentum = ctx->Attr("momentum"); const DataType data_type = x->data_type(); CHECK_EQ(x->shape_view(), y->shape_view()); CHECK_EQ(y->data_type(), data_type); CHECK_GE(axis, 0); CHECK_LT(axis, x->shape_view().NumAxes()); cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes()); const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); desc_helper.CheckParamTensor(gamma); desc_helper.CheckParamTensor(beta); desc_helper.CheckParamTensor(mean); desc_helper.CheckParamTensor(inv_variance); user_op::Tensor* moving_mean = nullptr; user_op::Tensor* moving_variance = nullptr; if (ctx->has_input("moving_mean", 0)) { CHECK(ctx->has_input("moving_variance", 0)); moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0); moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0); desc_helper.CheckParamTensor(moving_mean); desc_helper.CheckParamTensor(moving_variance); } const void* sp_alpha = CudnnSPOnePtr(data_type); const void* sp_beta; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->data_type(), y->data_type()); CHECK_EQ(add_to_output->shape_view(), y->shape_view()); Memcpy( ctx->stream(), y->mut_dptr(), add_to_output->dptr(), add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); sp_beta = CudnnSPOnePtr(data_type); } else { sp_beta = CudnnSPZeroPtr(data_type); } #if defined(BN_ENABLE_EX_API) size_t workspace_size; OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( ctx->stream()->As()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &workspace_size)); size_t reserve_space_size; OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( ctx->stream()->As()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN, nullptr, desc_helper.xy_desc(), &reserve_space_size)); auto* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); if (reserve_space_size == 0 && workspace_size <= workspace->shape_view().elem_cnt()) { OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx( ctx->stream()->As()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN, sp_alpha, sp_beta, desc_helper.xy_desc(), x->dptr(), nullptr, nullptr, desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean ? moving_mean->mut_dptr() : NULL, moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(), inv_variance->mut_dptr(), nullptr, workspace->mut_dptr(), workspace->shape_view().elem_cnt(), nullptr, 0)); } else { OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining( ctx->stream()->As()->cudnn_handle(), mode, sp_alpha, sp_beta, desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean ? moving_mean->mut_dptr() : NULL, moving_variance ? 
moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(), inv_variance->mut_dptr())); } #else OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining( ctx->stream()->As()->cudnn_handle(), mode, sp_alpha, sp_beta, desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean ? moving_mean->mut_dptr() : NULL, moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(), inv_variance->mut_dptr())); #endif if (ctx->op_type_name() == "normalization_add_relu") { CHECK(!ctx->has_input("_add_to_output", 0)); const int64_t elem_cnt = x->shape_view().elem_cnt(); auto* mask = ctx->Tensor4ArgNameAndIndex("reserve_space", 0); if (ctx->has_input("addend", 0)) { const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0); AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(), mask->mut_dptr()); } else { Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(), mask->mut_dptr()); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("normalization") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobAttr("training") == true)) .SetInferTmpSizeFn(InferTrainTmpSize) .SetInplaceProposalFn([](const user_op::InferContext& ctx, user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); REGISTER_USER_KERNEL("normalization_add_relu") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobAttr("training") == true)) .SetInferTmpSizeFn(InferTrainTmpSize); class NormalizationGradUserKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: NormalizationGradUserKernel() = default; ~NormalizationGradUserKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0); auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); auto* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0); auto* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0); const auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto axis = ctx->Attr("axis"); const auto epsilon = ctx->Attr("epsilon"); const DataType data_type = x->data_type(); CHECK_EQ(dy->shape_view(), x->shape_view()); CHECK_EQ(dy->data_type(), data_type); CHECK_EQ(dx->shape_view(), x->shape_view()); CHECK_EQ(dx->data_type(), data_type); CHECK_GE(axis, 0); CHECK_LT(axis, x->shape_view().NumAxes()); cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes()); const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode); desc_helper.CheckParamTensor(gamma); desc_helper.CheckParamTensor(gamma_diff); desc_helper.CheckParamTensor(beta_diff); desc_helper.CheckParamTensor(mean); desc_helper.CheckParamTensor(inv_variance); void* bn_workspace_ptr; size_t bn_workspace_size; const void* bn_dy_ptr; if (ctx->op_type_name() == "normalization_grad") { bn_workspace_ptr = tmp_buffer->mut_dptr(); bn_workspace_size = 
tmp_buffer->shape_view().elem_cnt(); bn_dy_ptr = dy->dptr(); } else if (ctx->op_type_name() == "normalization_add_relu_grad") { const int64_t elem_cnt = dy->shape_view().elem_cnt(); const auto* mask = ctx->Tensor4ArgNameAndIndex("reserve_space", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); if (ctx->has_output("addend_diff", 0)) { user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex("addend_diff", 0); ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr(), dy->dptr(), addend_diff->mut_dptr()); bn_workspace_ptr = tmp_buffer->mut_dptr(); bn_workspace_size = tmp_buffer->shape_view().elem_cnt(); bn_dy_ptr = addend_diff->dptr(); } else { const size_t tmp_buffer_size = tmp_buffer->shape_view().elem_cnt(); const size_t relu_dx_size = GetCudaAlignedSize(dy->shape_view().elem_cnt() * GetSizeOfDataType(data_type)); CHECK_GE(tmp_buffer_size, relu_dx_size); ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr(), dy->dptr(), tmp_buffer->mut_dptr()); bn_workspace_ptr = tmp_buffer->mut_dptr() + relu_dx_size; bn_workspace_size = tmp_buffer_size - relu_dx_size; bn_dy_ptr = tmp_buffer->dptr(); } } else { UNIMPLEMENTED(); } #if defined(BN_ENABLE_EX_API) size_t workspace_size; OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize( ctx->stream()->As()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &workspace_size)); size_t reserve_space_size; OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( ctx->stream()->As()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN, nullptr, desc_helper.xy_desc(), &reserve_space_size)); if (reserve_space_size == 0 && workspace_size <= bn_workspace_size) { OF_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx( ctx->stream()->As()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), nullptr, nullptr, desc_helper.xy_desc(), bn_dy_ptr, nullptr, nullptr, desc_helper.xy_desc(), dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), nullptr, gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr(), nullptr, bn_workspace_ptr, bn_workspace_size, nullptr, 0)); } else { OF_CUDNN_CHECK(cudnnBatchNormalizationBackward( ctx->stream()->As()->cudnn_handle(), mode, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), bn_dy_ptr, desc_helper.xy_desc(), dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr())); } #else OF_CUDNN_CHECK(cudnnBatchNormalizationBackward( ctx->stream()->As()->cudnn_handle(), mode, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), bn_dy_ptr, desc_helper.xy_desc(), dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr())); #endif } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("normalization_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)) .SetInferTmpSizeFn(InferGradTmpSize); #define REGISTER_BN_ADD_RELU_GRAD_KERNEL(dtype) 
REGISTER_USER_KERNEL("normalization_add_relu_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)) .SetInferTmpSizeFn(InferGradTmpSize); #if (CUDNN_VERSION >= 7401) size_t InferFusedNormalizationAddReluTmpSize(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("x", 0); const auto axis = ctx->Attr("axis"); const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis, CUDNN_BATCHNORM_SPATIAL_PERSISTENT); size_t size_in_bytes; cudnnHandle_t handle = Singleton::Get()->Get(); CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0); cudnnBatchNormOps_t ops; cudnnTensorDescriptor_t z_desc; if (ctx->has_input("addend", 0)) { ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; z_desc = desc_helper.xy_desc(); } else { ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; z_desc = nullptr; } OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes)); Singleton::Get()->Put(handle); return std::max(size_in_bytes, static_cast(1)); } size_t InferFusedNormalizationAddReluGradTmpSize(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("x", 0); const auto axis = ctx->Attr("axis"); const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis, CUDNN_BATCHNORM_SPATIAL_PERSISTENT); size_t size_in_bytes; cudnnHandle_t handle = Singleton::Get()->Get(); CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0); cudnnBatchNormOps_t ops; cudnnTensorDescriptor_t z_desc; if (ctx->has_output("addend_diff", 0)) { ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; z_desc = desc_helper.xy_desc(); } else { ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; z_desc = nullptr; } OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize( handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), desc_helper.xy_desc(), desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes)); Singleton::Get()->Put(handle); return std::max(size_in_bytes, static_cast(1)); } class FusedNormalizationAddReluKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedNormalizationAddReluKernel() = default; ~FusedNormalizationAddReluKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0); auto* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); auto* moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0); auto* moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0); const auto axis = ctx->Attr("axis"); const auto epsilon = ctx->Attr("epsilon"); const auto momentum = ctx->Attr("momentum"); auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0); auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const DataType data_type = x->data_type(); CHECK_EQ(x->shape_view(), y->shape_view()); CHECK_EQ(y->data_type(), data_type); CHECK_GE(axis, 0); CHECK_LT(axis, x->shape_view().NumAxes()); const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, 
CUDNN_BATCHNORM_SPATIAL_PERSISTENT); desc_helper.CheckParamTensor(gamma); desc_helper.CheckParamTensor(beta); desc_helper.CheckParamTensor(moving_mean); desc_helper.CheckParamTensor(moving_variance); desc_helper.CheckParamTensor(mean); desc_helper.CheckParamTensor(inv_variance); CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0); cudnnTensorDescriptor_t z_desc; const void* z_ptr; cudnnBatchNormOps_t ops; if (ctx->has_input("addend", 0)) { z_desc = desc_helper.xy_desc(); z_ptr = ctx->Tensor4ArgNameAndIndex("addend", 0)->dptr(); ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; } else { z_desc = nullptr; z_ptr = nullptr; ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; } size_t min_workspace_size; OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( ctx->stream()->As()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &min_workspace_size)); const size_t workspace_size = tmp_buffer->shape_view().elem_cnt(); CHECK_GE(workspace_size, min_workspace_size); size_t min_reserve_space_size; OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( ctx->stream()->As()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size)); const size_t reserve_space_size = reserve_space->shape_view().elem_cnt(); CHECK_GE(reserve_space_size, min_reserve_space_size); OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx( ctx->stream()->As()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), z_desc, z_ptr, desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(), moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr(), activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size, reserve_space->mut_dptr(), reserve_space_size)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)) .SetInferTmpSizeFn(InferFusedNormalizationAddReluTmpSize); class FusedNormalizationAddReluGradUserKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: FusedNormalizationAddReluGradUserKernel() = default; ~FusedNormalizationAddReluGradUserKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0); const auto* y = ctx->Tensor4ArgNameAndIndex("y", 0); auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); auto* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0); auto* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0); const auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); const auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); const auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0); auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto axis = ctx->Attr("axis"); const auto epsilon = ctx->Attr("epsilon"); const DataType data_type = x->data_type(); 
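  // The fused backward pass below mirrors the forward kernel: after validating shapes it queries
  // cuDNN for the minimum workspace and reserve-space sizes of the fused BN(+Add)+ReLU op, then
  // issues a single cudnnBatchNormalizationBackwardEx call that produces dx, gamma_diff,
  // beta_diff and, when "addend_diff" is present, the gradient of the residual addend.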
CHECK_EQ(dy->shape_view(), x->shape_view()); CHECK_EQ(dy->data_type(), data_type); CHECK_EQ(dx->shape_view(), x->shape_view()); CHECK_EQ(dx->data_type(), data_type); CHECK_GE(axis, 0); CHECK_LT(axis, x->shape_view().NumAxes()); const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, CUDNN_BATCHNORM_SPATIAL_PERSISTENT); desc_helper.CheckParamTensor(gamma); desc_helper.CheckParamTensor(beta); desc_helper.CheckParamTensor(gamma_diff); desc_helper.CheckParamTensor(beta_diff); desc_helper.CheckParamTensor(mean); desc_helper.CheckParamTensor(inv_variance); CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0); cudnnTensorDescriptor_t dz_desc; void* dz_ptr; cudnnBatchNormOps_t ops; if (ctx->has_output("addend_diff", 0)) { dz_desc = desc_helper.xy_desc(); dz_ptr = ctx->Tensor4ArgNameAndIndex("addend_diff", 0)->mut_dptr(); ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; } else { dz_desc = nullptr; dz_ptr = nullptr; ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; } size_t min_workspace_size; OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize( ctx->stream()->As()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), desc_helper.xy_desc(), desc_helper.xy_desc(), dz_desc, desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &min_workspace_size)); const size_t workspace_size = tmp_buffer->shape_view().elem_cnt(); CHECK_GE(workspace_size, min_workspace_size); size_t min_reserve_space_size; OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( ctx->stream()->As()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size)); const size_t reserve_space_size = reserve_space->shape_view().elem_cnt(); CHECK_GE(reserve_space_size, min_reserve_space_size); OF_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx( ctx->stream()->As()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->dptr(), desc_helper.xy_desc(), dy->dptr(), dz_desc, dz_ptr, desc_helper.xy_desc(), dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr(), activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size, const_cast(reserve_space->dptr()), reserve_space_size)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)) .SetInferTmpSizeFn(InferFusedNormalizationAddReluGradTmpSize); #endif } // namespace } // namespace oneflow #endif ================================================ FILE: oneflow/user/kernels/nvtx_range_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"

#ifdef OF_ENABLE_PROFILER
#include <nvToolsExt.h>  // NVTX range API: nvtxRangeStartA / nvtxRangeEnd
#endif  // OF_ENABLE_PROFILER

namespace oneflow {

namespace {

#ifdef OF_ENABLE_PROFILER
// Maps an NVTX mark name ("<mark_prefix>-<counter>") to the range id returned by nvtxRangeStartA,
// so the matching nvtx_end kernel can close the range.
static thread_local HashMap<std::string, nvtxRangeId_t> mark2range_id;
#endif

}  // namespace

class NvtxOpKernelState final : public user_op::OpKernelState {
 public:
  NvtxOpKernelState() : counter_(0) {
#ifndef OF_ENABLE_PROFILER
    LOG(WARNING) << "To use NVTX, run cmake with -DBUILD_PROFILER=ON";
#endif
  }
  ~NvtxOpKernelState() override = default;

  int64_t counter() const { return counter_; }
  void IncreaseCount() { counter_ += 1; }

 private:
  int64_t counter_;
};

class NvtxStartKernel final : public user_op::OpKernel {
 public:
  NvtxStartKernel() = default;
  ~NvtxStartKernel() override = default;

  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    return std::make_shared<NvtxOpKernelState>();
  }

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    CHECK_EQ(out->shape_view(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
    // Forward the input unchanged (the op is also proposed as inplace below).
    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr(), in->dptr(),
                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
#ifdef OF_ENABLE_PROFILER
    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
    const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
    const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
    nvtxRangeId_t range_id = nvtxRangeStartA(mark.c_str());
    CHECK(mark2range_id.emplace(mark, range_id).second);
    kernel_state->IncreaseCount();
#endif  // OF_ENABLE_PROFILER
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("nvtx_start")
    .SetCreateFn<NvtxStartKernel>()
    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
    .SetInplaceProposalFn([](const user_op::InferContext&,
                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
      OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
      return Maybe<void>::Ok();
    });

class NvtxEndKernel final : public user_op::OpKernel {
 public:
  NvtxEndKernel() = default;
  ~NvtxEndKernel() override = default;

  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    return std::make_shared<NvtxOpKernelState>();
  }

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    CHECK_EQ(out->shape_view(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
#ifdef OF_ENABLE_PROFILER
    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
    const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
    const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
    auto it = mark2range_id.find(mark.c_str());
    CHECK(it != mark2range_id.end());
    nvtxRangeId_t range_id = it->second;
    mark2range_id.erase(it);
    nvtxRangeEnd(range_id);
    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr(), in->dptr(), in_shape.elem_cnt() *
GetSizeOfDataType(in_data_type)); kernel_state->IncreaseCount(); #endif } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("nvtx_end") .SetCreateFn() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) .SetInplaceProposalFn([](const user_op::InferContext&, user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false)); return Maybe::Ok(); }); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ofrecord_decoder_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/user/image/random_crop_generator.h" #include "oneflow/user/image/image_util.h" #include "oneflow/user/kernels/random_crop_kernel_state.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/user/image/jpeg_decoder.h" #include #include namespace oneflow { namespace { template void DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_cnt, bool truncate, bool dim1_varying_length) { if (feature.has_bytes_list()) { CHECK_EQ(feature.bytes_list().value_size(), 1); const auto& value0 = feature.bytes_list().value(0); auto in_dptr = reinterpret_cast(value0.c_str()); sample_elem_cnt = std::min(sample_elem_cnt, value0.size()); std::transform(in_dptr, in_dptr + sample_elem_cnt, dptr, [](int8_t v) { return static_cast(v); }); } #define DEFINE_ONE_ELIF(PbT, CppT) \ else if (feature.has_##PbT##_list()) { \ const auto& list = feature.PbT##_list(); \ const CppT* in_dptr = list.value().data(); \ const int64_t padding_elem_num = truncate ? 
sample_elem_cnt - list.value_size() : 0; \ if (truncate) { \ sample_elem_cnt = std::min(sample_elem_cnt, list.value_size()); \ } else { \ if (dim1_varying_length) { \ sample_elem_cnt = list.value_size(); \ } else { \ CHECK_EQ(sample_elem_cnt, list.value_size()); \ } \ } \ std::transform(in_dptr, in_dptr + sample_elem_cnt, dptr, \ [](CppT v) { return static_cast(v); }); \ if (padding_elem_num > 0) { \ std::memset(dptr + sample_elem_cnt, 0, padding_elem_num * sizeof(T)); \ } \ } DEFINE_ONE_ELIF(float, float) DEFINE_ONE_ELIF(double, double) DEFINE_ONE_ELIF(int32, int32_t) DEFINE_ONE_ELIF(int64, int64_t) #undef DEFINE_ONE_ELIF else { UNIMPLEMENTED(); } } } // namespace template class OFRecordRawDecoderKernel final : public user_op::OpKernel { public: OFRecordRawDecoderKernel() = default; ~OFRecordRawDecoderKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); // TODO(chengcheng): remove record num in record blob, fix by shape elem cnt int64_t record_num = in_blob->shape_view().At(0); int64_t sample_elem_cnt = out_blob->shape_view().Count(1); CHECK(record_num > 0); const OFRecord* records = in_blob->dptr(); T* out_dptr = out_blob->mut_dptr(); const std::string& name = ctx->Attr("name"); bool truncate = ctx->Attr("truncate"); bool dim1_varying_length = ctx->Attr("dim1_varying_length"); MultiThreadLoop(record_num, [&](size_t i) { const OFRecord& record = *(records + i); T* dptr = out_dptr + i * sample_elem_cnt; CHECK(record.feature().find(name) != record.feature().end()) << "Field " << name << " not found"; const Feature& feature = record.feature().at(name); DecodeOneRawOFRecord(feature, dptr, sample_elem_cnt, truncate, dim1_varying_length); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_RAW_DECODER_KERNEL(dtype) \ REGISTER_USER_KERNEL("ofrecord_raw_decoder") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == DataType::kOFRecord) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_RAW_DECODER_KERNEL(char) REGISTER_RAW_DECODER_KERNEL(float) REGISTER_RAW_DECODER_KERNEL(double) REGISTER_RAW_DECODER_KERNEL(int8_t) REGISTER_RAW_DECODER_KERNEL(int32_t) REGISTER_RAW_DECODER_KERNEL(int64_t) REGISTER_RAW_DECODER_KERNEL(uint8_t) class OFRecordBytesDecoderKernel final : public user_op::OpKernel { public: OFRecordBytesDecoderKernel() = default; ~OFRecordBytesDecoderKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(out->shape_view(), in->shape_view()); CHECK_EQ(in->data_type(), DataType::kOFRecord); CHECK_EQ(out->data_type(), DataType::kTensorBuffer); const int64_t num_instances = in->shape_view().elem_cnt(); const auto* records = in->dptr(); auto* buffers = out->mut_dptr(); const std::string& name = ctx->Attr("name"); MultiThreadLoop(num_instances, [&](size_t i) { const OFRecord& record = *(records + i); TensorBuffer* buffer = buffers + i; auto it = record.feature().find(name); CHECK(it != record.feature().end()) << "Field " << name << " not found"; const Feature& feature = it->second; CHECK(feature.has_bytes_list()); CHECK_EQ(feature.bytes_list().value_size(), 1); const int64_t size = 
feature.bytes_list().value(0).size(); buffer->Resize(Shape({size}), DataType::kUInt8); memcpy(buffer->mut_data(), feature.bytes_list().value(0).data(), size); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("ofrecord_bytes_decoder") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kOFRecord) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)); namespace { void DecodeRandomCropImageFromOneRecord(const OFRecord& record, TensorBuffer* buffer, const std::string& name, const std::string& color_space, RandomCropGenerator* random_crop_gen) { CHECK(record.feature().find(name) != record.feature().end()) << "Field " << name << " not found"; const Feature& feature = record.feature().at(name); CHECK(feature.has_bytes_list()); CHECK(feature.bytes_list().value_size() == 1); const std::string& src_data = feature.bytes_list().value(0); cv::Mat image; if (JpegPartialDecodeRandomCropImage(reinterpret_cast(src_data.data()), src_data.size(), random_crop_gen, nullptr, 0, &image)) { // convert color space // jpeg decode output RGB if (ImageUtil::IsColor(color_space) && color_space != "RGB") { ImageUtil::ConvertColor("RGB", image, color_space, image); } } else { OpenCvPartialDecodeRandomCropImage(reinterpret_cast(src_data.data()), src_data.size(), random_crop_gen, color_space, image); // convert color space // opencv decode output BGR if (ImageUtil::IsColor(color_space) && color_space != "BGR") { ImageUtil::ConvertColor("BGR", image, color_space, image); } } int W = image.cols; int H = image.rows; CHECK(image.isContinuous()); const int c = ImageUtil::IsColor(color_space) ? 3 : 1; CHECK_EQ(c, image.channels()); Shape image_shape({H, W, c}); buffer->Resize(image_shape, DataType::kUInt8); CHECK_EQ(image_shape.elem_cnt(), buffer->nbytes()); CHECK_EQ(image_shape.elem_cnt(), image.total() * image.elemSize()); memcpy(buffer->mut_data(), image.ptr(), image_shape.elem_cnt()); } } // namespace class OFRecordImageDecoderRandomCropKernel final : public user_op::OpKernel { public: OFRecordImageDecoderRandomCropKernel() = default; ~OFRecordImageDecoderRandomCropKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return CreateRandomCropKernelState(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* crop_window_generators = dynamic_cast(state); CHECK_NOTNULL(crop_window_generators); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t record_num = out_blob->shape_view().At(0); CHECK(record_num > 0); user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK_EQ(out_blob->shape_view(), in_blob->shape_view()); const OFRecord* records = in_blob->dptr(); TensorBuffer* buffers = out_blob->mut_dptr(); const std::string& name = ctx->Attr("name"); const std::string& color_space = ctx->Attr("color_space"); MultiThreadLoop(record_num, [&](size_t i) { const OFRecord& record = *(records + i); TensorBuffer* buffer = buffers + i; RandomCropGenerator* gen = crop_window_generators->GetGenerator(i); DecodeRandomCropImageFromOneRecord(record, buffer, name, color_space, gen); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("ofrecord_image_decoder_random_crop") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && 
(user_op::HobDataType("in", 0) == DataType::kOFRecord) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)); class OFRecordImageDecoderKernel final : public user_op::OpKernel { public: OFRecordImageDecoderKernel() = default; ~OFRecordImageDecoderKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t record_num = out_blob->shape_view().At(0); CHECK(record_num > 0); user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK_EQ(out_blob->shape_view(), in_blob->shape_view()); const OFRecord* records = in_blob->dptr(); TensorBuffer* buffers = out_blob->mut_dptr(); const std::string& name = ctx->Attr("name"); const std::string& color_space = ctx->Attr("color_space"); MultiThreadLoop(record_num, [&](size_t i) { const OFRecord& record = *(records + i); TensorBuffer* buffer = buffers + i; DecodeRandomCropImageFromOneRecord(record, buffer, name, color_space, nullptr); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("ofrecord_image_decoder") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kOFRecord) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/data/ofrecord_image_classification_data_reader.h" namespace oneflow { namespace { class OFRecordImageClassificationReaderKernelState final : public user_op::OpKernelState { public: explicit OFRecordImageClassificationReaderKernelState(user_op::KernelInitContext* ctx) : reader_(ctx) {} ~OFRecordImageClassificationReaderKernelState() override = default; void Read(user_op::KernelComputeContext* ctx) { reader_.Read(ctx); } private: data::OFRecordImageClassificationDataReader reader_; }; } // namespace class OFRecordImageClassificationReaderKernel final : public user_op::OpKernel { public: OFRecordImageClassificationReaderKernel() = default; ~OFRecordImageClassificationReaderKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* reader = dynamic_cast(state); CHECK_NOTNULL(reader); reader->Read(ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("ofrecord_image_classification_reader") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("image", 0) == DataType::kTensorBuffer) && (user_op::HobDataType("label", 0) == DataType::kTensorBuffer)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ofrecord_reader_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/data/ofrecord_data_reader.h"

namespace oneflow {

namespace {

// Thin OpKernelState wrapper that owns the OFRecordDataReader for the lifetime of the kernel.
class OFRecordReaderWrapper final : public user_op::OpKernelState {
 public:
  explicit OFRecordReaderWrapper(user_op::KernelInitContext* ctx) : reader_(ctx) {}
  ~OFRecordReaderWrapper() = default;

  void Read(user_op::KernelComputeContext* ctx) { reader_.Read(ctx); }

 private:
  data::OFRecordDataReader reader_;
};

}  // namespace

class OFRecordReaderKernel final : public user_op::OpKernel {
 public:
  OFRecordReaderKernel() = default;
  ~OFRecordReaderKernel() override = default;

  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    std::shared_ptr<OFRecordReaderWrapper> reader(new OFRecordReaderWrapper(ctx));
    return reader;
  }

 private:
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    auto* reader = dynamic_cast<OFRecordReaderWrapper*>(state);
    reader->Read(ctx);
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("OFRecordReader")
    .SetCreateFn<OFRecordReaderKernel>()
    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)
                     && (user_op::HobDataType("out", 0) == DataType::kOFRecord));

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/one_embedding_data_shuffle.cuh
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/kernels/gather_kernel_util.h" #include "oneflow/user/kernels/unsorted_segment_sum_kernel_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/embedding/hash_functions.cuh" namespace oneflow { namespace data_shuffle { template struct TableEntry { K key; uint32_t value; }; template __global__ void GenerateTableIds(int32_t elem_cnt, int32_t num_tables, U* table_ids) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { table_ids[i] = i % num_tables; } } namespace { constexpr uint32_t PADDING_REV_INDEX = 0xffffffff; template __global__ void HashTableUniqueAndPartitionPairs( const uint32_t table_capacity, const uint32_t num_keys, int32_t num_partition, IDX* unique_counts, TableEntry* table, const K* keys, const V* values, K* partitioned_unique_keys, V* partitioned_unique_values, IDX* reverse_index, bool need_process_values, const bool has_padding_idx, const int64_t padding_idx) { CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_keys) { IDX r_index_plus_one = 0; const K key = keys[i]; if (has_padding_idx && key == padding_idx) { reverse_index[i] = PADDING_REV_INDEX; } else { size_t key_hash = HASH()(key); uint32_t partition_id = key_hash % num_partition; IDX* unique_count = unique_counts + partition_id; K* unique_keys = partitioned_unique_keys + partition_id * num_keys; uint32_t pos = key_hash % table_capacity; const K key_hi = (key | 0x1); const K key_lo = (key & 0x1); uint32_t counter = 0; while (r_index_plus_one == 0) { bool prob_next = false; K* key_ptr = &table[pos].key; volatile uint32_t* table_value_ptr = &table[pos].value; const K old_key = cuda::atomic::CAS(key_ptr, 0, key_hi); if (old_key == 0) { IDX unique_pos = cuda::atomic::Add(unique_count, 1); r_index_plus_one = unique_pos + 1; unique_keys[unique_pos] = key; if (need_process_values) { partitioned_unique_values[partition_id * num_keys + unique_pos] = values[i]; } *table_value_ptr = ((r_index_plus_one << 1U) | key_lo); } else if (old_key == key_hi) { const uint32_t value = *table_value_ptr; if (value == 0) { // do nothing } else if ((value & 0x1) == key_lo) { r_index_plus_one = (value >> 1U); } else { prob_next = true; } } else { prob_next = true; } if (prob_next) { pos += 1; counter += 1; if (pos >= table_capacity) { pos -= table_capacity; } if (counter >= table_capacity) { __trap(); } } } reverse_index[i] = partition_id * num_keys + r_index_plus_one - 1; } } } template __global__ void ComputeOffset(int32_t n, IDX* value) { IDX sum = 0; for (int i = 0; i < n; ++i) { IDX count = value[i]; value[i] = sum; sum += count; } } template __global__ void ContiguousInverseUniquePartitionIndices(const int32_t num_ids, IDX* indices_offset, IDX* inverse_ptr) { CUDA_1D_KERNEL_LOOP(i, num_ids) { int inverse_indice = inverse_ptr[i]; int partition_id = inverse_indice / num_ids; int partition_indice = inverse_indice - partition_id * num_ids; int new_offset = indices_offset[partition_id]; inverse_ptr[i] = new_offset + partition_indice; } } template void ShuffleData(cudaStream_t cuda_stream, ncclComm_t comm, DataType data_type, const std::vector& send_offsets, const std::vector& send_elem_cnt, const T* send_data, const std::vector& recv_offsets, const std::vector& recv_elem_cnt, T* recv_data) { ncclDataType_t nccl_data_type = GetNcclDataType(data_type); const int64_t parallel_num = send_offsets.size(); 
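  // All-to-all exchange: within a single ncclGroupStart()/ncclGroupEnd() group, post one send and
  // one recv per peer rank, using the per-peer offsets and element counts supplied by the caller,
  // so every rank ends up with exactly the data that was partitioned to it.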
OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < parallel_num; ++i) { OF_NCCL_CHECK(ncclSend(send_data + send_offsets.at(i), send_elem_cnt.at(i), nccl_data_type, i, comm, cuda_stream)); OF_NCCL_CHECK(ncclRecv(recv_data + recv_offsets.at(i), recv_elem_cnt.at(i), nccl_data_type, i, comm, cuda_stream)); } OF_NCCL_CHECK(ncclGroupEnd()); } template void MakeShuffleIdParams(const IDX* host_num_unique_matrix, const int64_t num_ids, const int64_t row_size, int64_t parallel_id, int64_t parallel_num, std::vector* scatter_offset_vec, std::vector* scatter_elem_cnt_vec, std::vector* gather_offset_vec, std::vector* gather_elem_cnt_vec) { scatter_offset_vec->resize(parallel_num); scatter_elem_cnt_vec->resize(parallel_num); gather_offset_vec->resize(parallel_num); gather_elem_cnt_vec->resize(parallel_num); int64_t gather_offset = 0; for (int64_t i = 0; i < parallel_num; ++i) { const int64_t scatter_elem_cnt = host_num_unique_matrix[parallel_id * parallel_num + i] * row_size; const int64_t gather_elem_cnt = host_num_unique_matrix[i * parallel_num + parallel_id] * row_size; scatter_offset_vec->at(i) = i * num_ids * row_size; scatter_elem_cnt_vec->at(i) = scatter_elem_cnt; gather_offset_vec->at(i) = gather_offset; gather_elem_cnt_vec->at(i) = gather_elem_cnt; gather_offset += gather_elem_cnt; } } template void MakeShuffleParams(const IDX* host_num_unique_matrix, const int64_t num_ids, const int64_t row_size, int64_t parallel_id, int64_t parallel_num, std::vector* scatter_offset_vec, std::vector* scatter_elem_cnt_vec, std::vector* gather_offset_vec, std::vector* gather_elem_cnt_vec) { scatter_offset_vec->resize(parallel_num); scatter_elem_cnt_vec->resize(parallel_num); gather_offset_vec->resize(parallel_num); gather_elem_cnt_vec->resize(parallel_num); int64_t gather_offset = 0; int64_t scatter_offset = 0; for (int64_t i = 0; i < parallel_num; ++i) { const int64_t scatter_elem_cnt = host_num_unique_matrix[parallel_id * parallel_num + i] * row_size; const int64_t gather_elem_cnt = host_num_unique_matrix[i * parallel_num + parallel_id] * row_size; scatter_offset_vec->at(i) = scatter_offset; scatter_elem_cnt_vec->at(i) = scatter_elem_cnt; gather_offset_vec->at(i) = gather_offset; gather_elem_cnt_vec->at(i) = gather_elem_cnt; scatter_offset += scatter_elem_cnt; gather_offset += gather_elem_cnt; } } template void ShuffleIdsAndTableIds(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id, int64_t parallel_num, int64_t num_ids, DataType ids_data_type, DataType table_ids_data_type, IDX* host_num_unique_matrix, K* partitioned_unique_ids, U* partitioned_unique_table_ids, K* received_ids, U* received_table_ids, int64_t* received_elem_cnt, bool need_process_table_ids) { std::vector send_offsets; std::vector send_elem_cnt; std::vector recv_offsets; std::vector recv_elem_cnt; MakeShuffleIdParams(host_num_unique_matrix, num_ids, 1, parallel_id, parallel_num, &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt); ShuffleData(cuda_stream, comm, ids_data_type, send_offsets, send_elem_cnt, partitioned_unique_ids, recv_offsets, recv_elem_cnt, received_ids); *received_elem_cnt = recv_offsets.at(parallel_num - 1) + recv_elem_cnt.at(parallel_num - 1); if (need_process_table_ids) { ShuffleData(cuda_stream, comm, table_ids_data_type, send_offsets, send_elem_cnt, partitioned_unique_table_ids, recv_offsets, recv_elem_cnt, received_table_ids); } } template __global__ void UnsortedSegmentHalfGpu(const IDX in_h2_elem_cnt, const IDX h2_inner_dim_size, const IDX inner_dim_size, const half* data, const K* 
segment_ids, const IDX num_segments, half2* out_h2) { CUDA_1D_KERNEL_LOOP_T(IDX, i, in_h2_elem_cnt) { const IDX segment_id_idx = i / h2_inner_dim_size; const IDX h2_inner_idx = i - segment_id_idx * h2_inner_dim_size; const IDX inner_idx_0 = 2 * h2_inner_idx; const IDX inner_idx_1 = inner_idx_0 + 1; const half* data_row = data + segment_id_idx * inner_dim_size; half2 val; val.x = data_row[inner_idx_0]; val.y = (inner_idx_1 >= inner_dim_size) ? static_cast(0) : data_row[inner_idx_1]; const IDX idx = segment_ids[segment_id_idx]; const IDX out_h2_offset = idx * h2_inner_dim_size + h2_inner_idx; cuda::atomic::Add(out_h2 + out_h2_offset, val); } } template struct UnsortedSegmentSumPad { void operator()(ep::Stream* stream, const K* segment_ids, const T* data, int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size, int64_t padded_inner_dim_size, T* out) const { UNIMPLEMENTED(); } }; template struct UnsortedSegmentSumPad { void operator()(ep::Stream* stream, const K* segment_ids, const half* data, int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size, int64_t padded_inner_dim_size, half* out) const { const int64_t data_elem_cnt = num_segment_ids * inner_dim_size; const int64_t out_elem_cnt = num_segments * padded_inner_dim_size; CHECK_EQ(padded_inner_dim_size % 2, 0); CHECK_EQ(inner_dim_size + 1, padded_inner_dim_size); const int64_t h2_inner_dim_size = padded_inner_dim_size / 2; const int64_t in_h2_elem_cnt = num_segment_ids * h2_inner_dim_size; if (std::max(data_elem_cnt, out_elem_cnt) < GetMaxVal() / 2) { UnsortedSegmentHalfGpu <<As()->cuda_stream()>>>( in_h2_elem_cnt, h2_inner_dim_size, inner_dim_size, data, segment_ids, num_segments, reinterpret_cast(out)); } else { UnsortedSegmentHalfGpu <<As()->cuda_stream()>>>( in_h2_elem_cnt, h2_inner_dim_size, inner_dim_size, data, segment_ids, num_segments, reinterpret_cast(out)); } } }; template void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const T* data, int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size, int64_t padded_inner_dim_size, T* out) { if (inner_dim_size == padded_inner_dim_size) { UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( stream, segment_ids, data, num_segment_ids, num_segments, 1, inner_dim_size, 0, out); } else { CHECK_EQ(inner_dim_size + 1, padded_inner_dim_size); UnsortedSegmentSumPad()(stream, segment_ids, data, num_segment_ids, num_segments, inner_dim_size, padded_inner_dim_size, out); } } } // namespace template void UniqueAndPartition(cudaStream_t cuda_stream, int64_t num_ids, size_t capacity, int64_t num_partition, const K* ids, const V* table_ids, IDX* num_partitioned_unique_ids_ptr, K* partitioned_unique_ids, V* partitioned_unique_table_ids, IDX* inverse_unique_partition_indices, void* workspace_ptr, size_t workspace_bytes, bool need_process_table_ids, const bool has_padding_idx, const int64_t padding_idx) { size_t table_capacity_bytes = capacity * sizeof(TableEntry); CHECK_GE(workspace_bytes, table_capacity_bytes); OF_CUDA_CHECK(cudaMemsetAsync(workspace_ptr, 0, table_capacity_bytes, cuda_stream)); OF_CUDA_CHECK( cudaMemsetAsync(num_partitioned_unique_ids_ptr, 0, num_partition * sizeof(IDX), cuda_stream)); HashTableUniqueAndPartitionPairs <<>>( capacity, num_ids, num_partition, num_partitioned_unique_ids_ptr, reinterpret_cast*>(workspace_ptr), ids, table_ids, partitioned_unique_ids, partitioned_unique_table_ids, inverse_unique_partition_indices, need_process_table_ids, has_padding_idx, padding_idx); } template void ShuffleEmbeddings(cudaStream_t 
cuda_stream, ncclComm_t comm, int64_t parallel_id, int64_t parallel_num, int64_t num_ids, int64_t embedding_size, DataType data_type, IDX* host_num_unique_matrix, const T* reverse_unique_cur_rank_embeddings, T* received_embeddings) { std::vector send_offsets; std::vector send_elem_cnt; std::vector recv_offsets; std::vector recv_elem_cnt; MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num, &recv_offsets, &recv_elem_cnt, &send_offsets, &send_elem_cnt); ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt, reverse_unique_cur_rank_embeddings, recv_offsets, recv_elem_cnt, received_embeddings); } // Quantized Version. template void ShuffleEmbeddings(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id, int64_t parallel_num, int64_t num_ids, int64_t embedding_size, DataType data_type, IDX* host_num_unique_matrix, int8_t* reverse_unique_cur_rank_embeddings, int8_t* received_embeddings, T* reverse_cur_rank_quantize_factor, T* recv_quantize_factor) { std::vector send_offsets; std::vector send_elem_cnt; std::vector recv_offsets; std::vector recv_elem_cnt; // shuffle quantized_embedding MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num, &recv_offsets, &recv_elem_cnt, &send_offsets, &send_elem_cnt); ShuffleData(cuda_stream, comm, DataType::kInt8, send_offsets, send_elem_cnt, reverse_unique_cur_rank_embeddings, recv_offsets, recv_elem_cnt, received_embeddings); // shuffle quantize_factor MakeShuffleParams(host_num_unique_matrix, num_ids, /*embedding_size=*/1, parallel_id, parallel_num, &recv_offsets, &recv_elem_cnt, &send_offsets, &send_elem_cnt); ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt, reverse_cur_rank_quantize_factor, recv_offsets, recv_elem_cnt, recv_quantize_factor); } template void ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id, int64_t parallel_num, int64_t num_ids, int64_t embedding_size, DataType data_type, IDX* host_num_unique_matrix, const T* unique_partition_embedding_grad, T* received_embeddings_grad) { std::vector send_offsets; std::vector send_elem_cnt; std::vector recv_offsets; std::vector recv_elem_cnt; MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num, &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt); ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt, unique_partition_embedding_grad, recv_offsets, recv_elem_cnt, received_embeddings_grad); } // Quantize Version. template void ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id, int64_t parallel_num, int64_t num_ids, int64_t embedding_size, DataType data_type, IDX* host_num_unique_matrix, int8_t* unique_partition_embedding_grad, int8_t* received_embeddings_grad, T* cur_rank_quantize_factor, T* received_cur_rank_quantize_factor) { std::vector send_offsets; std::vector send_elem_cnt; std::vector recv_offsets; std::vector recv_elem_cnt; // Shuffle Embedding Grad. MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num, &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt); ShuffleData(cuda_stream, comm, DataType::kInt8, send_offsets, send_elem_cnt, unique_partition_embedding_grad, recv_offsets, recv_elem_cnt, received_embeddings_grad); // Shuffle Quantize factor. 
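  // The per-row quantization scale is a single value, so the same shuffle is reused with a row
  // size of 1 (see the /*embedding_size=*/1 argument below).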
MakeShuffleParams(host_num_unique_matrix, num_ids, /*embedding_size=*/1, parallel_id, parallel_num, &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt); ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt, cur_rank_quantize_factor, recv_offsets, recv_elem_cnt, received_cur_rank_quantize_factor); } inline int64_t GetPaddedEmbeddingSize(DataType data_type, int64_t embedding_size) { if (data_type == DataType::kFloat16 && embedding_size % 2 != 0) { return embedding_size + 1; } else { return embedding_size; } } template void UniquePartitionEmbeddingGrad(ep::Stream* stream, int64_t unique_partitioned_num_ids, int64_t num_ids, int64_t embedding_size, int64_t padded_embedding_size, const IDX* host_num_unique_matrix, const T* embedding_grad, const IDX* inverse_unique_partition_indices, T* unique_partition_embedding_grad) { const int64_t valid_value_size = unique_partitioned_num_ids * padded_embedding_size * sizeof(T); OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad, 0, valid_value_size, stream->As()->cuda_stream())); UnsortedSegmentSum(stream, inverse_unique_partition_indices, embedding_grad, num_ids, unique_partitioned_num_ids, embedding_size, padded_embedding_size, unique_partition_embedding_grad); } template void UniqueCurRankEmbeddingGrad(ep::Stream* stream, DataType data_type, int64_t cur_rank_num_ids, int64_t num_unique, int64_t embedding_size, int64_t padded_embedding_size, bool only_zero_valid_grad, int64_t cur_rank_unique_embedding_grad_elem_cnt, const T* cur_rank_embedding_grad, const IDX* cur_rank_inverse_indices, T* cur_rank_unique_embedding_grad, T* tmp_buffer) { cudaStream_t cuda_stream = stream->As()->cuda_stream(); // memset cur_rank_unique_embedding_grad, if only_zero_valid_grad, only memset valid data. if (only_zero_valid_grad) { OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad, 0, num_unique * embedding_size * sizeof(T), cuda_stream)); } else { OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad, 0, cur_rank_unique_embedding_grad_elem_cnt * sizeof(T), cuda_stream)); } T* unsorted_segment_sum_out; if (embedding_size != padded_embedding_size) { unsorted_segment_sum_out = tmp_buffer; size_t buffer_size = GetCudaAlignedSize(num_unique * padded_embedding_size * sizeof(T)); OF_CUDA_CHECK(cudaMemsetAsync(unsorted_segment_sum_out, 0, buffer_size, cuda_stream)); } else { // cur_rank_unique_embedding_grad's has been memset, not need to memset again. 
unsorted_segment_sum_out = cur_rank_unique_embedding_grad; } UnsortedSegmentSum(stream, cur_rank_inverse_indices, cur_rank_embedding_grad, cur_rank_num_ids, num_unique, padded_embedding_size, padded_embedding_size, unsorted_segment_sum_out); if (embedding_size != padded_embedding_size) { std::unique_ptr primitive = ep::primitive::NewPrimitive(DeviceType::kCUDA, 2); DimVector dst_shape = {num_unique, embedding_size}; DimVector dst_pos_vec = {0, 0}; DimVector src_shape = {num_unique, padded_embedding_size}; DimVector src_pos_vec = {0, 0}; DimVector extent_vec = {num_unique, embedding_size}; primitive->Launch(stream, data_type, 2, cur_rank_unique_embedding_grad, dst_shape.data(), dst_pos_vec.data(), unsorted_segment_sum_out, src_shape.data(), src_pos_vec.data(), extent_vec.data()); } } template struct IdShuffleDataPtrs { const K* ids_ptr; const U* table_ids_ptr; IDX* num_partitioned_unique; K* partitioned_unique_ids; U* partitioned_unique_table_ids; IDX* num_unique_matrix_ptr; IDX* inverse_unique_partition_indices_ptr; void* workspace_ptr; size_t workspace_size; K* received_ids; U* received_table_ids; IDX* cur_rank_num_unique_ptr; K* cur_rank_unique_ids_ptr; U* cur_rank_unique_table_ids_ptr; IDX* cur_rank_inverse_indices_ptr; }; template void IdShuffle(ep::Stream* stream, ncclComm_t comm, const IdShuffleDataPtrs& data_ptrs, int64_t num_ids, int64_t parallel_id, int64_t parallel_num, DataType num_unique_matrix_dtype, DataType ids_dtype, DataType table_ids_dtype, bool need_process_table_ids, const bool has_padding_idx, const int64_t padding_idx, IDX* host_num_unique_matrix, IDX* host_num_keys) { cudaStream_t cuda_stream = stream->As()->cuda_stream(); size_t hash_table_capacity = parallel_num * num_ids; UniqueAndPartition( cuda_stream, num_ids, hash_table_capacity, parallel_num, data_ptrs.ids_ptr, data_ptrs.table_ids_ptr, data_ptrs.num_partitioned_unique, data_ptrs.partitioned_unique_ids, data_ptrs.partitioned_unique_table_ids, data_ptrs.inverse_unique_partition_indices_ptr, data_ptrs.workspace_ptr, data_ptrs.workspace_size, need_process_table_ids, has_padding_idx, padding_idx); OF_NCCL_CHECK(ncclAllGather(data_ptrs.num_partitioned_unique, data_ptrs.num_unique_matrix_ptr, parallel_num, GetNcclDataType(num_unique_matrix_dtype), comm, cuda_stream)); OF_CUDA_CHECK(cudaMemcpyAsync(host_num_unique_matrix, data_ptrs.num_unique_matrix_ptr, parallel_num * parallel_num * sizeof(IDX), cudaMemcpyDefault, cuda_stream)); CHECK_JUST(stream->Sync()); if (parallel_num > 1) { // use num_partitioned_unique as indices_offset buffer, so should after ncclAllGather. 
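    // ComputeOffset turns the per-partition unique counts into exclusive prefix sums, and
    // ContiguousInverseUniquePartitionIndices then rebases every reverse index from its
    // "partition_id * num_ids + local_index" form onto those contiguous offsets.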
ComputeOffset<<<1, 1, 0, cuda_stream>>>(parallel_num, data_ptrs.num_partitioned_unique); ContiguousInverseUniquePartitionIndices<<>>( num_ids, data_ptrs.num_partitioned_unique, data_ptrs.inverse_unique_partition_indices_ptr); } int64_t received_elem_cnt = 0; ShuffleIdsAndTableIds(cuda_stream, comm, parallel_id, parallel_num, num_ids, ids_dtype, table_ids_dtype, host_num_unique_matrix, data_ptrs.partitioned_unique_ids, data_ptrs.partitioned_unique_table_ids, data_ptrs.received_ids, data_ptrs.received_table_ids, &received_elem_cnt, need_process_table_ids); UniqueAndPartition( cuda_stream, received_elem_cnt, hash_table_capacity, 1, data_ptrs.received_ids, data_ptrs.received_table_ids, data_ptrs.cur_rank_num_unique_ptr, data_ptrs.cur_rank_unique_ids_ptr, data_ptrs.cur_rank_unique_table_ids_ptr, data_ptrs.cur_rank_inverse_indices_ptr, data_ptrs.workspace_ptr, data_ptrs.workspace_size, need_process_table_ids, has_padding_idx, padding_idx); if (!need_process_table_ids) { OF_CUDA_CHECK(cudaMemsetAsync(data_ptrs.cur_rank_unique_table_ids_ptr, 0, received_elem_cnt * sizeof(U), cuda_stream)); } OF_CUDA_CHECK(cudaMemcpyAsync(host_num_keys, data_ptrs.cur_rank_num_unique_ptr, sizeof(IDX), cudaMemcpyDefault, cuda_stream)); CHECK_JUST(stream->Sync()); } } // namespace data_shuffle } // namespace oneflow ================================================ FILE: oneflow/user/kernels/one_embedding_embedding_gradient_shuffle_p2p_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include #if CUDA_VERSION >= 11030 namespace oneflow { namespace { template struct alignas(sizeof(T) * pack_size) Pack { T elem[pack_size]; }; template __device__ __inline__ void AtomicAdd(Pack* address, Pack val) { #pragma unroll for (int i = 0; i < pack_size; ++i) { cuda::atomic::Add(reinterpret_cast(address) + i, static_cast(val.elem[i])); } } template<> __device__ __inline__ void AtomicAdd(Pack* address, Pack val) { half2 h2_val; h2_val.x = static_cast(val.elem[0]); h2_val.y = static_cast(val.elem[1]); cuda::atomic::Add(reinterpret_cast(address), h2_val); } template struct Param { const IDX* cur_rank_inverse_indices; const Pack* unique_partitioned_embedding_grads[N]; int32_t* is_kernel_start[N]; const IDX* num_unique_matrix; Pack* cur_rank_unique_embedding_grad_ptr; }; template __global__ void EmbeddingGradientShuffleCudaKernel(int64_t parallel_id, int64_t parallel_num, int64_t embedding_num_pack, Param param) { #pragma unroll 1 for (int i = 0; i < parallel_num; ++i) { int rank_id = (parallel_id + i) % parallel_num; IDX cur_rank_index_offset = 0; for (int k = 0; k < rank_id; ++k) { cur_rank_index_offset += param.num_unique_matrix[k * parallel_num + parallel_id]; } IDX in_index_offset = 0; for (int k = 0; k < parallel_id; ++k) { in_index_offset += param.num_unique_matrix[rank_id * parallel_num + k]; } const IDX* cur_rank_inverse_indices_ptr = param.cur_rank_inverse_indices + cur_rank_index_offset; const Pack* unique_partitioned_embedding_grad_ptr = param.unique_partitioned_embedding_grads[rank_id] + in_index_offset * embedding_num_pack; Pack* cur_rank_unique_embedding_grad_ptr = param.cur_rank_unique_embedding_grad_ptr; const int copy_cnt = param.num_unique_matrix[rank_id * parallel_num + parallel_id] * embedding_num_pack; CUDA_1D_KERNEL_LOOP_T(int, j, copy_cnt) { int in_row_id = j / embedding_num_pack; int col_id = j - in_row_id * embedding_num_pack; int out_row_id = cur_rank_inverse_indices_ptr[in_row_id]; Pack grad_val = unique_partitioned_embedding_grad_ptr[j]; AtomicAdd(cur_rank_unique_embedding_grad_ptr + out_row_id * embedding_num_pack + col_id, grad_val); } } } template __global__ void BarrierKernel(int32_t parallel_id, int32_t parallel_num, Param param) { int count = param.is_kernel_start[parallel_id][parallel_id]; if (threadIdx.x < parallel_num) { volatile int32_t* start_f = param.is_kernel_start[parallel_id]; volatile int32_t* remote_start_f = param.is_kernel_start[threadIdx.x]; start_f[threadIdx.x] = count + 1; while (remote_start_f[parallel_id] < count + 1) {} } } struct IpcMemHandleOffset { cudaIpcMemHandle_t handle; int64_t offset; }; void GetPtrs(user_op::KernelComputeContext* ctx, std::vector* unique_partitioned_embedding_grad_ptr, std::vector* is_kernel_start_ptr) { const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); unique_partitioned_embedding_grad_ptr->at(parallel_id) = const_cast(ctx->Tensor4ArgNameAndIndex("embedding_grad", 0)->dptr()); std::string name = ctx->op_name(); { std::vector push_handle_offset; push_handle_offset.resize(2); OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(0).handle, unique_partitioned_embedding_grad_ptr->at(parallel_id))); 
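    // Handle 0 covers this rank's embedding_grad buffer and handle 1 its is_kernel_start flags.
    // Because cudaIpcGetMemHandle refers to the base of the underlying allocation, the code below
    // also records the pointer's offset from that base (via cuPointerGetAttribute) so that peers
    // can reconstruct the exact address after cudaIpcOpenMemHandle. Both handle/offset pairs are
    // published through the ctrl service (PushKV) under the key "<op_name><parallel_id>".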
OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(1).handle, is_kernel_start_ptr->at(parallel_id))); cudaError_t (*func)(void*, CUpointer_attribute, CUdeviceptr); OF_CUDA_CHECK( cudaGetDriverEntryPoint("cuPointerGetAttribute", (void**)(&func), cudaEnableDefault)); void* embedding_grad_base; OF_CUDA_CHECK(func(&embedding_grad_base, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)(unique_partitioned_embedding_grad_ptr->at(parallel_id)))); push_handle_offset.at(0).offset = reinterpret_cast(unique_partitioned_embedding_grad_ptr->at(parallel_id)) - reinterpret_cast(embedding_grad_base); push_handle_offset.at(1).offset = 0; Singleton::Get()->PushKV( name + std::to_string(parallel_id), std::string(reinterpret_cast(push_handle_offset.data()), 2 * sizeof(IpcMemHandleOffset))); } for (int64_t i = 0; i < parallel_num; ++i) { std::string key = name + std::to_string(i); if (parallel_id != i) { std::vector handle_offset; handle_offset.resize(2); Singleton::Get()->PullKV(key, [i, &handle_offset](const std::string& val) { memcpy(handle_offset.data(), val.data(), 2 * sizeof(IpcMemHandleOffset)); }); OF_CUDA_CHECK(cudaIpcOpenMemHandle(&unique_partitioned_embedding_grad_ptr->at(i), handle_offset.at(0).handle, cudaIpcMemLazyEnablePeerAccess)); unique_partitioned_embedding_grad_ptr->at(i) = reinterpret_cast(unique_partitioned_embedding_grad_ptr->at(i)) + handle_offset.at(0).offset; OF_CUDA_CHECK(cudaIpcOpenMemHandle(&is_kernel_start_ptr->at(i), handle_offset.at(1).handle, cudaIpcMemLazyEnablePeerAccess)); is_kernel_start_ptr->at(i) = reinterpret_cast(is_kernel_start_ptr->at(i)) + handle_offset.at(1).offset; } } } template class DataShuffleKernelState final : public user_op::OpKernelState { public: explicit DataShuffleKernelState(user_op::KernelInitContext* ctx) : device_index_(-1), parallel_desc_(ctx->parallel_desc()), parallel_id_(ctx->parallel_ctx().parallel_id()) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); int64_t parallel_num = parallel_desc_.parallel_num(); unique_partitioned_embedding_grad_ptr_.resize(parallel_num); is_kernel_start_ptr_.resize(parallel_num); size_t is_kernel_start_size = GetCudaAlignedSize(parallel_num * sizeof(int32_t)); OF_CUDA_CHECK(cudaMalloc(&is_kernel_start_ptr_.at(parallel_id_), is_kernel_start_size)); OF_CUDA_CHECK(cudaMemset(is_kernel_start_ptr_.at(parallel_id_), 0, is_kernel_start_size)); } ~DataShuffleKernelState() { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFree(is_kernel_start_ptr_.at(parallel_id_))); } std::vector* UniquePartitionedEmbeddingGrads() { return &unique_partitioned_embedding_grad_ptr_; } std::vector* IsKernelStart() { return &is_kernel_start_ptr_; } private: int device_index_; ParallelDesc parallel_desc_; int64_t parallel_id_; std::vector unique_partitioned_embedding_grad_ptr_; std::vector is_kernel_start_ptr_; }; constexpr int pack_size = 2; template __global__ void MemsetCurRankEmbeddingGrad(int64_t parallel_id, int64_t parallel_num, int64_t vector_size, const uint32_t* num_unique_matrix, T* dst) { size_t count = 0; for (int i = 0; i < parallel_num; ++i) { count += num_unique_matrix[i * parallel_num + parallel_id] * vector_size; } const size_t pack_count = count / pack; Pack pack_value; for (int i = 0; i < pack; ++i) { pack_value.elem[i] = static_cast(0); } auto* pack_dst = reinterpret_cast*>(dst); CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value; } T* tail_dst = dst + pack_count * pack; const size_t tail_count = count - pack_count * pack; CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = 
static_cast(0); } } template typename std::enable_if<(pack != 0), void>::type LaunchPackMemsetCurRankEmbeddingGrad( cudaStream_t stream, const uint32_t* num_unique_matrix, T* ptr, int sm_count, int64_t vector_size, int64_t parallel_id, int64_t parallel_num) { MemsetCurRankEmbeddingGrad<<<2 * sm_count, 1024, 0, stream>>>( parallel_id, parallel_num, vector_size, num_unique_matrix, ptr); } template typename std::enable_if<(pack == 0), void>::type LaunchPackMemsetCurRankEmbeddingGrad( cudaStream_t stream, const uint32_t* num_unique_matrix, T* ptr, int sm_count, int64_t vector_size, int64_t parallel_id, int64_t parallel_num) { LOG(FATAL) << "wrong alignment"; } template void LaunchMemsetCurRankEmbeddingGrad(cudaStream_t stream, int sm_count, int64_t vector_size, int64_t parallel_id, int64_t parallel_num, const uint32_t* num_unique_matrix, T* ptr) { auto uintptr = reinterpret_cast(ptr); if (uintptr % 16 == 0) { LaunchPackMemsetCurRankEmbeddingGrad( stream, num_unique_matrix, ptr, sm_count, vector_size, parallel_id, parallel_num); } else if (uintptr % 8 == 0) { LaunchPackMemsetCurRankEmbeddingGrad(stream, num_unique_matrix, ptr, sm_count, vector_size, parallel_id, parallel_num); } else if (uintptr % 4 == 0) { LaunchPackMemsetCurRankEmbeddingGrad(stream, num_unique_matrix, ptr, sm_count, vector_size, parallel_id, parallel_num); } else if (uintptr % 2 == 0) { LaunchPackMemsetCurRankEmbeddingGrad(stream, num_unique_matrix, ptr, sm_count, vector_size, parallel_id, parallel_num); } else { LaunchPackMemsetCurRankEmbeddingGrad(stream, num_unique_matrix, ptr, sm_count, vector_size, parallel_id, parallel_num); } } } // namespace template class EmbeddingGraidientShuffleP2PKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: EmbeddingGraidientShuffleP2PKernel() : current_iter_(0) {} ~EmbeddingGraidientShuffleP2PKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { if (current_iter_ == 0) { return false; } else { return true; } } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { CHECK(!embedding::UseDynamicMemoryAllocation()); CHECK(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION", false)); // only support skip last gather. CHECK(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT", true)); // when no identity, every time the cur_rank_inverse_indices // will change becauseof regster num=2. 
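    // P2P gradient shuffle: each rank zeroes its cur_rank_unique_embedding_grad (only the valid
    // rows when only_zero_valid_grad is set), synchronizes with its peers via BarrierKernel over
    // the IPC-shared is_kernel_start flags, and then reads the peer embedding_grad buffers
    // directly, atomically accumulating every row into the slot given by cur_rank_inverse_indices.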
auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0); const user_op::Tensor* cur_rank_inverse_indices = ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0); user_op::Tensor* cur_rank_unique_embedding_grad = ctx->Tensor4ArgNameAndIndex("cur_rank_unique_embedding_grad", 0); const int64_t embedding_size = ctx->Attr("embedding_size"); const bool only_zero_valid_grad = ctx->Attr("only_zero_valid_grad"); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const int sm_count = ctx->stream()->As()->device_properties().multiProcessorCount; const bool skip_first_scatter = ctx->Attr("skip_first_scatter"); CHECK(skip_first_scatter); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); if (current_iter_ == 0) { GetPtrs(ctx, kernel_state->UniquePartitionedEmbeddingGrads(), kernel_state->IsKernelStart()); } CHECK_EQ(kernel_state->UniquePartitionedEmbeddingGrads()->at(parallel_id), embedding_grad->dptr()); Param param; CHECK_EQ(embedding_size % pack_size, 0); CHECK_LE(parallel_num, 8); param.cur_rank_unique_embedding_grad_ptr = reinterpret_cast*>(cur_rank_unique_embedding_grad->mut_dptr()); for (int i = 0; i < parallel_num; ++i) { param.unique_partitioned_embedding_grads[i] = reinterpret_cast*>( kernel_state->UniquePartitionedEmbeddingGrads()->at(i)); param.is_kernel_start[i] = reinterpret_cast(kernel_state->IsKernelStart()->at(i)); } param.cur_rank_inverse_indices = reinterpret_cast(cur_rank_inverse_indices->dptr()); param.num_unique_matrix = reinterpret_cast(num_unique_matrix->dptr()); int64_t embedding_num_pack = embedding_size / pack_size; if (only_zero_valid_grad) { LaunchMemsetCurRankEmbeddingGrad(cuda_stream, sm_count, embedding_size, parallel_id, parallel_num, reinterpret_cast(num_unique_matrix->dptr()), cur_rank_unique_embedding_grad->mut_dptr()); } else { OF_CUDA_CHECK(cudaMemsetAsync( cur_rank_unique_embedding_grad->mut_dptr(), 0, cur_rank_unique_embedding_grad->shape_view().elem_cnt() * sizeof(T), cuda_stream)); } BarrierKernel<<<1, parallel_num, 0, cuda_stream>>>(parallel_id, parallel_num, param); const int num_blocks = 2 * ctx->stream()->As()->device_properties().multiProcessorCount; EmbeddingGradientShuffleCudaKernel<<>>( parallel_id, parallel_num, embedding_num_pack, param); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; REGISTER_USER_KERNEL("embedding_gradient_shuffle") .SetCreateFn>() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("embedding_grad", 0) == DataType::kFloat16) && (user_op::HobDataType("num_unique_matrix", 0) == DataType::kUInt32) && (user_op::HobAttr("skip_first_scatter") == true) && (embedding::UseEmbeddingGradientShuffleP2PKernel(DataType::kFloat16, DataType::kUInt32))); } // namespace oneflow #endif // CUDA_VERSION >= 11030 ================================================ FILE: oneflow/user/kernels/one_embedding_embedding_shuffle_p2p_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include #if CUDA_VERSION >= 11030 namespace oneflow { namespace { template struct alignas(sizeof(T) * pack_size) Pack { T elem[pack_size]; }; template struct Param { IDX* inverse_indices[N]; Pack* unique_embeddings[N]; int32_t* is_kernel_start[N]; const IDX* num_unique_matrix; Pack* embedding_ptr; }; template __global__ void EmbeddingShuffleCudaKernel(int parallel_id, int parallel_num, int embedding_num_pack, Param param) { #pragma unroll 1 for (int i = 0; i < parallel_num; ++i) { int rank_id = (parallel_id + i) % parallel_num; IDX out_index_offset = 0; for (int k = 0; k < rank_id; ++k) { out_index_offset += param.num_unique_matrix[parallel_id * parallel_num + k]; } IDX in_index_offset = 0; for (int k = 0; k < parallel_id; ++k) { in_index_offset += param.num_unique_matrix[k * parallel_num + rank_id]; } const IDX* inverse_indices_ptr = param.inverse_indices[rank_id] + in_index_offset; const Pack* unique_embeddings_ptr = param.unique_embeddings[rank_id]; Pack* embedding_ptr = param.embedding_ptr + out_index_offset * embedding_num_pack; const int copy_cnt = param.num_unique_matrix[parallel_id * parallel_num + rank_id] * embedding_num_pack; CUDA_1D_KERNEL_LOOP_T(int, j, copy_cnt) { int out_row_id = j / embedding_num_pack; int in_row_id = inverse_indices_ptr[out_row_id]; int col_id = j - out_row_id * embedding_num_pack; embedding_ptr[j] = unique_embeddings_ptr[in_row_id * embedding_num_pack + col_id]; } } } template __global__ void EmbeddingShuffleCopyKernel(int parallel_id, int parallel_num, int embedding_num_pack, Param param) { #pragma unroll 1 for (int i = 0; i < parallel_num; ++i) { int rank_id = (parallel_id + i) % parallel_num; IDX out_index_offset = 0; for (int k = 0; k < rank_id; ++k) { out_index_offset += param.num_unique_matrix[parallel_id * parallel_num + k]; } IDX in_index_offset = 0; for (int k = 0; k < parallel_id; ++k) { in_index_offset += param.num_unique_matrix[k * parallel_num + rank_id]; } const Pack* unique_embeddings_ptr = param.unique_embeddings[rank_id] + in_index_offset * embedding_num_pack; Pack* embedding_ptr = param.embedding_ptr + out_index_offset * embedding_num_pack; const int copy_cnt = param.num_unique_matrix[parallel_id * parallel_num + rank_id] * embedding_num_pack; CUDA_1D_KERNEL_LOOP_T(int, j, copy_cnt) { embedding_ptr[j] = unique_embeddings_ptr[j]; } } } template __global__ void GatherKernel(int parallel_id, int parallel_num, int embedding_num_pack, const IDX* num_unique_matrix, const IDX* inverse_indices, const Pack* unique_embeddings, Pack* gather_out_unique_embeddings) { int cur_rank_num_ids = 0; for (int i = 0; i < parallel_num; ++i) { cur_rank_num_ids += num_unique_matrix[i * parallel_num + parallel_id]; } int out_cnt = cur_rank_num_ids * embedding_num_pack; CUDA_1D_KERNEL_LOOP_T(int, i, out_cnt) { int out_row_id = i / embedding_num_pack; int 
in_row_id = inverse_indices[out_row_id]; int col_id = i - out_row_id * embedding_num_pack; gather_out_unique_embeddings[i] = unique_embeddings[in_row_id * embedding_num_pack + col_id]; } } template __global__ void BarrierKernel(int32_t parallel_id, int32_t parallel_num, Param param) { int count = param.is_kernel_start[parallel_id][parallel_id]; if (threadIdx.x < parallel_num) { volatile int32_t* start_f = param.is_kernel_start[parallel_id]; volatile int32_t* remote_start_f = param.is_kernel_start[threadIdx.x]; start_f[threadIdx.x] = count + 1; while (remote_start_f[parallel_id] < count + 1) {} } } struct IpcMemHandleOffset { cudaIpcMemHandle_t handle; int64_t offset; }; bool DisableFuseGatherCopy() { return ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_P2P_DISABLE_FUSE_GATHER_COPY", false); } void GetPtrs(user_op::KernelComputeContext* ctx, std::vector* unique_embeddings_ptr, std::vector* inverse_indices_ptr, std::vector* is_kernel_start_ptr) { const int64_t num_ids = ctx->TensorDesc4ArgNameAndIndex("inverse_unique_partition_indices", 0)->shape().elem_cnt(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); inverse_indices_ptr->at(parallel_id) = const_cast(ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0)->dptr()); if (DisableFuseGatherCopy()) { unique_embeddings_ptr->at(parallel_id) = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr(); } else { unique_embeddings_ptr->at(parallel_id) = const_cast(ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0)->dptr()); } std::string name = ctx->op_name(); { std::vector push_handle_offset; push_handle_offset.resize(3); OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(0).handle, unique_embeddings_ptr->at(parallel_id))); OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(1).handle, inverse_indices_ptr->at(parallel_id))); OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(2).handle, is_kernel_start_ptr->at(parallel_id))); cudaError_t (*func)(void*, CUpointer_attribute, CUdeviceptr); OF_CUDA_CHECK( cudaGetDriverEntryPoint("cuPointerGetAttribute", (void**)(&func), cudaEnableDefault)); void* unique_embeddings_base; OF_CUDA_CHECK(func(&unique_embeddings_base, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)(unique_embeddings_ptr->at(parallel_id)))); push_handle_offset.at(0).offset = reinterpret_cast(unique_embeddings_ptr->at(parallel_id)) - reinterpret_cast(unique_embeddings_base); void* inverse_indices_base; OF_CUDA_CHECK(func(&inverse_indices_base, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)(inverse_indices_ptr->at(parallel_id)))); push_handle_offset.at(1).offset = reinterpret_cast(inverse_indices_ptr->at(parallel_id)) - reinterpret_cast(inverse_indices_base); push_handle_offset.at(2).offset = 0; Singleton::Get()->PushKV( name + std::to_string(parallel_id), std::string(reinterpret_cast(push_handle_offset.data()), 3 * sizeof(IpcMemHandleOffset))); } for (int64_t i = 0; i < parallel_num; ++i) { std::string key = name + std::to_string(i); if (parallel_id != i) { std::vector handle_offset; handle_offset.resize(3); Singleton::Get()->PullKV(key, [i, &handle_offset](const std::string& val) { memcpy(handle_offset.data(), val.data(), 3 * sizeof(IpcMemHandleOffset)); }); OF_CUDA_CHECK(cudaIpcOpenMemHandle(&unique_embeddings_ptr->at(i), handle_offset.at(0).handle, cudaIpcMemLazyEnablePeerAccess)); unique_embeddings_ptr->at(i) = reinterpret_cast(unique_embeddings_ptr->at(i)) + handle_offset.at(0).offset; 
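      // cudaIpcOpenMemHandle returns the base address of the peer's whole allocation, not
      // of the tensor itself, so each rank publishes handle + byte-offset pairs: the offset
      // is computed on the pushing side above with
      // cuPointerGetAttribute(CU_POINTER_ATTRIBUTE_RANGE_START_ADDR) and added back here
      // once the handle is opened. The same handle/offset step is repeated next for the
      // inverse-indices buffer and the is_kernel_start flags (whose offset is always 0,
      // since they come from a dedicated cudaMalloc).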
OF_CUDA_CHECK(cudaIpcOpenMemHandle(&inverse_indices_ptr->at(i), handle_offset.at(1).handle, cudaIpcMemLazyEnablePeerAccess)); inverse_indices_ptr->at(i) = reinterpret_cast(inverse_indices_ptr->at(i)) + handle_offset.at(1).offset; OF_CUDA_CHECK(cudaIpcOpenMemHandle(&is_kernel_start_ptr->at(i), handle_offset.at(2).handle, cudaIpcMemLazyEnablePeerAccess)); is_kernel_start_ptr->at(i) = reinterpret_cast(is_kernel_start_ptr->at(i)) + handle_offset.at(2).offset; } } } template class DataShuffleKernelState final : public user_op::OpKernelState { public: explicit DataShuffleKernelState(user_op::KernelInitContext* ctx) : device_index_(-1), parallel_desc_(ctx->parallel_desc()), parallel_id_(ctx->parallel_ctx().parallel_id()) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); int64_t parallel_num = parallel_desc_.parallel_num(); unique_embeddings_ptr_.resize(parallel_num); inverse_indices_ptr_.resize(parallel_num); is_kernel_start_ptr_.resize(parallel_num); size_t is_kernel_start_size = GetCudaAlignedSize(parallel_num * sizeof(int32_t)); OF_CUDA_CHECK(cudaMalloc(&is_kernel_start_ptr_.at(parallel_id_), is_kernel_start_size)); OF_CUDA_CHECK(cudaMemset(is_kernel_start_ptr_.at(parallel_id_), 0, is_kernel_start_size)); } ~DataShuffleKernelState() { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFree(is_kernel_start_ptr_.at(parallel_id_))); } std::vector* UniqueEmbeddings() { return &unique_embeddings_ptr_; } std::vector* InverseIndices() { return &inverse_indices_ptr_; } std::vector* IsKernelStart() { return &is_kernel_start_ptr_; } private: int device_index_; ParallelDesc parallel_desc_; int64_t parallel_id_; std::vector unique_embeddings_ptr_; std::vector inverse_indices_ptr_; std::vector is_kernel_start_ptr_; }; template void LaunchKernel(user_op::KernelComputeContext* ctx, DataShuffleKernelState* kernel_state) { const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0); user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); const int64_t embedding_size = ctx->Attr("embedding_size"); DataType data_type = embeddings->data_type(); Param param; CHECK_LE(parallel_num, 8); param.embedding_ptr = reinterpret_cast*>(embeddings->mut_dptr()); for (int i = 0; i < parallel_num; ++i) { param.inverse_indices[i] = reinterpret_cast(kernel_state->InverseIndices()->at(i)); param.unique_embeddings[i] = reinterpret_cast*>(kernel_state->UniqueEmbeddings()->at(i)); param.is_kernel_start[i] = reinterpret_cast(kernel_state->IsKernelStart()->at(i)); } param.num_unique_matrix = reinterpret_cast(num_unique_matrix->dptr()); int64_t embedding_num_pack = embedding_size / pack_size; cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); BarrierKernel<<<1, parallel_num, 0, cuda_stream>>>(parallel_id, parallel_num, param); const int num_blocks = 2 * ctx->stream()->As()->device_properties().multiProcessorCount; if (DisableFuseGatherCopy()) { CHECK_EQ(kernel_state->UniqueEmbeddings()->at(parallel_id), ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->dptr()) << parallel_id; GatherKernel<<>>( parallel_id, parallel_num, embedding_num_pack, param.num_unique_matrix, param.inverse_indices[parallel_id], reinterpret_cast*>( ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0)->dptr()), param.unique_embeddings[parallel_id]); EmbeddingShuffleCopyKernel<<>>(parallel_id, parallel_num, embedding_num_pack, param); } else { 
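    // Fused path (default): EmbeddingShuffleCudaKernel gathers rows straight from every
    // peer's cur_rank_embeddings buffer through the shared inverse indices, skipping the
    // intermediate GatherKernel + EmbeddingShuffleCopyKernel round trip through tmp_buffer
    // that the ONEFLOW_ONE_EMBEDDING_P2P_DISABLE_FUSE_GATHER_COPY branch above uses.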
CHECK_EQ(kernel_state->UniqueEmbeddings()->at(parallel_id), ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0)->dptr()) << parallel_id; EmbeddingShuffleCudaKernel<<>>(parallel_id, parallel_num, embedding_num_pack, param); } if (!ctx->Attr("is_train")) { BarrierKernel<<<1, parallel_num, 0, cuda_stream>>>( parallel_id, parallel_num, param); // if in eval, should add last barrier. } } } // namespace template class EmbeddingShuffleP2PKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: EmbeddingShuffleP2PKernel() : current_iter_(0) {} ~EmbeddingShuffleP2PKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { if (current_iter_ == 0) { return false; } else { return true; } } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { CHECK(!embedding::UseDynamicMemoryAllocation()); CHECK(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION", false)); // only support skip last gather. CHECK(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT", true)); // when no identity, every time the cur_rank_inverse_indices // will change becauseof regster num=2. auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); const user_op::Tensor* cur_rank_inverse_indices = ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0); const user_op::Tensor* inverse_unique_partition_indices = ctx->Tensor4ArgNameAndIndex("inverse_unique_partition_indices", 0); const bool skip_last_gather = ctx->Attr("skip_last_gather"); CHECK(skip_last_gather); const int64_t embedding_size = ctx->Attr("embedding_size"); if (current_iter_ == 0) { GetPtrs(ctx, kernel_state->UniqueEmbeddings(), kernel_state->InverseIndices(), kernel_state->IsKernelStart()); } const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); CHECK_EQ(kernel_state->InverseIndices()->at(parallel_id), cur_rank_inverse_indices->dptr()) << parallel_id; if (embedding_size % 4 == 0) { LaunchKernel(ctx, kernel_state); } else if (embedding_size % 2 == 0) { LaunchKernel(ctx, kernel_state); } else { LaunchKernel(ctx, kernel_state); } current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; REGISTER_USER_KERNEL("embedding_shuffle") .SetCreateFn>() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("cur_rank_embeddings", 0) == DataType::kFloat16) && (user_op::HobDataType("num_unique_matrix", 0) == DataType::kUInt32) && (user_op::HobAttr("skip_last_gather") == true) && (embedding::UseEmbeddingShuffleP2PKernel(DataType::kFloat16, DataType::kUInt32))) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { return GetCudaAlignedSize(ctx->InputTensorDesc("cur_rank_embeddings", 0).shape().elem_cnt() * sizeof(half)); }); } // namespace oneflow #endif // CUDA_VERSION >= 11030 ================================================ FILE: oneflow/user/kernels/one_embedding_id_shuffle_p2p_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/embedding/hash_functions.cuh" #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/core/control/ctrl_client.h" namespace oneflow { namespace { template struct TableEntry { K key; uint32_t value; }; template __global__ void HashTableUniqueAndPartitionPairs( const uint32_t table_capacity, const uint32_t num_keys, int32_t num_partition, IDX* unique_counts, TableEntry* table, const K* keys, const V* values, K* partitioned_unique_keys, V* partitioned_unique_values, IDX* reverse_index, bool need_process_values, int32_t* is_kernel_start) { CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_keys) { IDX r_index_plus_one = 0; const K key = keys[i]; size_t key_hash = HASH()(key); uint32_t partition_id = key_hash % num_partition; IDX* unique_count = unique_counts + partition_id; K* unique_keys = partitioned_unique_keys + partition_id * num_keys; uint32_t pos = key_hash % table_capacity; const K key_hi = (key | 0x1); const K key_lo = (key & 0x1); uint32_t counter = 0; while (r_index_plus_one == 0) { bool prob_next = false; K* key_ptr = &table[pos].key; volatile uint32_t* table_value_ptr = &table[pos].value; const K old_key = cuda::atomic::CAS(key_ptr, 0, key_hi); if (old_key == 0) { IDX unique_pos = cuda::atomic::Add(unique_count, 1); r_index_plus_one = unique_pos + 1; unique_keys[unique_pos] = key; if (need_process_values) { partitioned_unique_values[partition_id * num_keys + unique_pos] = values[i]; } *table_value_ptr = ((r_index_plus_one << 1U) | key_lo); } else if (old_key == key_hi) { const uint32_t value = *table_value_ptr; if (value == 0) { // do nothing } else if ((value & 0x1) == key_lo) { r_index_plus_one = (value >> 1U); } else { prob_next = true; } } else { prob_next = true; } if (prob_next) { pos += 1; counter += 1; if (pos >= table_capacity) { pos -= table_capacity; } if (counter >= table_capacity) { __trap(); } } } reverse_index[i] = partition_id * num_keys + r_index_plus_one - 1; } } template struct Param { IDX* num_unique[N]; K* unique_ids[N]; U* unique_table_ids[N]; int32_t* is_kernel_start[N]; IDX* num_unique_matrix; int32_t* counter; }; template struct alignas(sizeof(T) * pack_size) Pack { T elem[pack_size]; }; template __global__ void BarrierAndMemset(int32_t parallel_id, int32_t parallel_num, Param param, Pack* workspace_ptr, size_t workspace_num_pack, IDX* counter, int num_counter) { int count; if (blockIdx.x == 0) { count = param.is_kernel_start[parallel_id][parallel_id]; if (threadIdx.x < parallel_num) { volatile int32_t* start_f = param.is_kernel_start[parallel_id]; start_f[threadIdx.x] = count + 1; } } Pack pack_value; for (int i = 0; i < pack_size; ++i) { pack_value.elem[i] = static_cast(0); } CUDA_1D_KERNEL_LOOP(i, workspace_num_pack) { workspace_ptr[i] = pack_value; } int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; if (global_thread_id < num_counter) { counter[global_thread_id] = 0; } if (blockIdx.x == 0) { if (threadIdx.x < parallel_num) { volatile int32_t* remote_start_f 
= param.is_kernel_start[threadIdx.x]; while (remote_start_f[parallel_id] < count + 1) {} } } } template __global__ void HashTableUniquePairs(const uint32_t table_capacity, const uint32_t num_ids, int32_t parallel_num, int32_t parallel_id, IDX* unique_count, TableEntry* table, Param param, K* unique_keys, V* unique_values, IDX* reverse_index, bool need_process_values) { #pragma unroll 1 for (int i = 0; i < parallel_num; ++i) { int rank_id = (parallel_id + i) % parallel_num; const IDX* num_uniques = param.num_unique[rank_id]; CUDA_1D_KERNEL_LOOP_T(int, rank_index, num_uniques[parallel_id]) { const IDX* num_uniques = param.num_unique[rank_id]; // if (rank_index >= num_uniques[parallel_id]) { continue; } const K* keys = param.unique_ids[rank_id]; const V* values = param.unique_table_ids[rank_id]; IDX index_offset = 0; for (int k = 0; k < rank_id; ++k) { index_offset += param.num_unique[k][parallel_id]; } IDX r_index_plus_one = 0; const K key = keys[rank_index]; size_t key_hash = HASH()(key); uint32_t pos = key_hash % table_capacity; const K key_hi = (key | 0x1); const K key_lo = (key & 0x1); uint32_t counter = 0; while (r_index_plus_one == 0) { bool prob_next = false; K* key_ptr = &table[pos].key; volatile uint32_t* table_value_ptr = &table[pos].value; const K old_key = cuda::atomic::CAS(key_ptr, 0, key_hi); if (old_key == 0) { IDX unique_pos = cuda::atomic::Add(unique_count, 1); r_index_plus_one = unique_pos + 1; unique_keys[unique_pos] = key; if (need_process_values) { unique_values[unique_pos] = values[rank_index]; } *table_value_ptr = ((r_index_plus_one << 1U) | key_lo); } else if (old_key == key_hi) { const uint32_t value = *table_value_ptr; if (value == 0) { // do nothing } else if ((value & 0x1) == key_lo) { r_index_plus_one = (value >> 1U); } else { prob_next = true; } } else { prob_next = true; } if (prob_next) { pos += 1; counter += 1; if (pos >= table_capacity) { pos -= table_capacity; } if (counter >= table_capacity) { __trap(); } } } reverse_index[rank_index + index_offset] = r_index_plus_one - 1; if (rank_index < parallel_num) { param.num_unique_matrix[i * parallel_num + rank_index] = param.num_unique[i][rank_index]; } } } } template __global__ void GenerateTableIdsAndMemsetUniqueWorkspace(int32_t elem_cnt, int32_t num_tables, U* table_ids, Pack* workspace_ptr, size_t workspace_num_pack, IDX* counter, int num_counter) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { table_ids[i] = i % num_tables; } Pack pack_value; for (int i = 0; i < pack_size; ++i) { pack_value.elem[i] = static_cast(0); } CUDA_1D_KERNEL_LOOP(i, workspace_num_pack) { workspace_ptr[i] = pack_value; } int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; if (global_thread_id < num_counter) { counter[global_thread_id] = 0; } } template void UniqueAndPartition(cudaStream_t cuda_stream, int64_t num_blocks, int64_t num_ids, size_t capacity, int64_t num_partition, const K* ids, const V* table_ids, IDX* num_partitioned_unique_ids_ptr, K* partitioned_unique_ids, V* partitioned_unique_table_ids, IDX* inverse_unique_partition_indices, void* workspace_ptr, size_t workspace_bytes, bool need_process_table_ids, int32_t* is_kernel_start_ptr) { size_t table_capacity_bytes = capacity * sizeof(TableEntry); CHECK_GE(workspace_bytes, table_capacity_bytes); HashTableUniqueAndPartitionPairs<<>>( capacity, num_ids, num_partition, num_partitioned_unique_ids_ptr, reinterpret_cast*>(workspace_ptr), ids, table_ids, partitioned_unique_ids, partitioned_unique_table_ids, inverse_unique_partition_indices, need_process_table_ids, 
is_kernel_start_ptr); } enum class IdShuffleBufferType { kTableIds = 0, kWorkspace, kMaxType }; template class IdShuffleTmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(IdShuffleTmpBufferManager); IdShuffleTmpBufferManager(void* ptr, const int64_t num_ids, const int64_t parallel_num, bool need_table_ids, bool need_process_table_ids) : offset_(0), offsets_(static_cast(IdShuffleBufferType::kMaxType), -1), sizes_(static_cast(IdShuffleBufferType::kMaxType)), ptr_(ptr) { const int64_t num_table_ids = need_process_table_ids ? num_ids : 0; const size_t table_ids_bytes = need_table_ids ? num_ids * sizeof(U) : 0; AllocBuffer(IdShuffleBufferType::kTableIds, table_ids_bytes); const size_t hash_table_capacity = parallel_num * num_ids; AllocBuffer(IdShuffleBufferType::kWorkspace, hash_table_capacity * sizeof(TableEntry)); } template T* Ptr(IdShuffleBufferType type) { CHECK(ptr_ != nullptr); int64_t offset = offsets_.at(static_cast(type)); CHECK_NE(offset, -1); return reinterpret_cast(reinterpret_cast(ptr_) + offset); } int64_t Size(IdShuffleBufferType type) { return sizes_.at(static_cast(type)); } size_t TotalBufferSize() const { return offset_; } private: void AllocBuffer(IdShuffleBufferType type, size_t size) { const size_t type_id = static_cast(type); CHECK_EQ(offsets_.at(type_id), -1); offsets_.at(type_id) = offset_; sizes_.at(type_id) = size; offset_ += GetCudaAlignedSize(size); } size_t offset_; std::vector offsets_; std::vector sizes_; void* ptr_; }; template class DataShuffleKernelState final : public user_op::OpKernelState { public: explicit DataShuffleKernelState(user_op::KernelInitContext* ctx) : device_index_(-1), parallel_desc_(ctx->parallel_desc()), parallel_id_(ctx->parallel_ctx().parallel_id()) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); int64_t parallel_num = parallel_desc_.parallel_num(); OF_CUDA_CHECK( cudaMallocHost(&host_num_unique_matrix_, parallel_num * parallel_num * sizeof(IDX))); OF_CUDA_CHECK(cudaMallocHost(&host_cur_rank_num_unique_, sizeof(IDX))); const std::string& embedding_name = ctx->Attr("embedding_name"); const int64_t parallel_id = parallel_id_; embedding_state_ = Singleton::Get()->GetEmbeddingState( embedding_name, parallel_id); const int64_t num_ids = ctx->TensorDesc4ArgNameAndIndex("ids", 0)->shape().elem_cnt(); num_partitioned_unique_size_ = GetCudaAlignedSize(parallel_num * sizeof(IDX)); partitioned_unique_ids_size_ = GetCudaAlignedSize(parallel_num * num_ids * sizeof(K)); partitioned_unique_table_ids_size_ = GetCudaAlignedSize(parallel_num * num_ids * sizeof(U)); is_kernel_start_size_ = GetCudaAlignedSize(parallel_num * sizeof(int32_t)); size_t buffer_size = num_partitioned_unique_size_ + partitioned_unique_ids_size_ + partitioned_unique_table_ids_size_ + is_kernel_start_size_; buffer_ptrs_.resize(parallel_num); cudaMalloc(&buffer_ptrs_.at(parallel_id), buffer_size); cudaMemset(buffer_ptrs_.at(parallel_id), 0, buffer_size); } ~DataShuffleKernelState() { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFreeHost(host_cur_rank_num_unique_)); OF_CUDA_CHECK(cudaFreeHost(host_num_unique_matrix_)); OF_CUDA_CHECK(cudaFree(buffer_ptrs_.at(parallel_id_))); } std::vector* BufferPtrs() { return &buffer_ptrs_; } IDX* HostNumUniqueMatrix() { return host_num_unique_matrix_; } IDX* HostCurRankNumUnique() { return host_cur_rank_num_unique_; } embedding::EmbeddingState* EmbeddingState() { return embedding_state_; } IDX* NumPartitionedUnique(int64_t parallel_id) { return reinterpret_cast(buffer_ptrs_.at(parallel_id)); } K* PartitionedUniqueIds(int64_t 
parallel_id) { return reinterpret_cast(reinterpret_cast(buffer_ptrs_.at(parallel_id)) + num_partitioned_unique_size_); } U* PartitionedUniqueTableIds(int64_t parallel_id) { return reinterpret_cast(reinterpret_cast(buffer_ptrs_.at(parallel_id)) + num_partitioned_unique_size_ + partitioned_unique_ids_size_); } int32_t* IsKernelStart(int64_t parallel_id) { return reinterpret_cast(reinterpret_cast(buffer_ptrs_.at(parallel_id)) + num_partitioned_unique_size_ + partitioned_unique_ids_size_ + partitioned_unique_table_ids_size_); } private: int device_index_; ParallelDesc parallel_desc_; int64_t parallel_id_; IDX* host_num_unique_matrix_; IDX* host_cur_rank_num_unique_; std::vector buffer_ptrs_; size_t num_partitioned_unique_size_; size_t partitioned_unique_ids_size_; size_t partitioned_unique_table_ids_size_; size_t is_kernel_start_size_; embedding::EmbeddingState* embedding_state_; }; void GetPtrs(user_op::KernelComputeContext* ctx, std::vector* buffer_ptrs) { const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); std::string name = ctx->op_name(); cudaIpcMemHandle_t handle; OF_CUDA_CHECK(cudaIpcGetMemHandle(&handle, buffer_ptrs->at(parallel_id))); Singleton::Get()->PushKV( name + std::to_string(parallel_id), std::string(reinterpret_cast(&handle), sizeof(cudaIpcMemHandle_t))); for (int64_t i = 0; i < parallel_num; ++i) { std::string key = name + std::to_string(i); if (parallel_id != i) { cudaIpcMemHandle_t handle; Singleton::Get()->PullKV(key, [&handle](const std::string& val) { memcpy(&handle, val.data(), sizeof(cudaIpcMemHandle_t)); }); OF_CUDA_CHECK( cudaIpcOpenMemHandle(&buffer_ptrs->at(i), handle, cudaIpcMemLazyEnablePeerAccess)); } } } template __global__ void BarrierAndComputeOut(int32_t parallel_id, int32_t parallel_num, int32_t num_ids, Param param, IDX* num_partitioned_unique, IDX* inverse_ptr, IDX* num_unique_matrix, IDX* host_num_unique_matrix, IDX* cur_rank_num_unique, IDX* host_cur_rank_num_unique) { int count; if (blockIdx.x == 0) { count = param.is_kernel_start[parallel_id][parallel_id]; if (threadIdx.x < parallel_num) { volatile int32_t* start_f = param.is_kernel_start[parallel_id]; start_f[threadIdx.x] = count + 1; } } if (parallel_num > 1) { CUDA_1D_KERNEL_LOOP(i, num_ids) { int inverse_indice = inverse_ptr[i]; int partition_id = inverse_indice / num_ids; int partition_indice = inverse_indice - partition_id * num_ids; int new_offset = 0; for (int k = 0; k < partition_id; ++k) { new_offset += num_partitioned_unique[k]; } inverse_ptr[i] = new_offset + partition_indice; } } int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; if (global_thread_id < parallel_num * parallel_num) { host_num_unique_matrix[global_thread_id] = num_unique_matrix[global_thread_id]; } if (global_thread_id == 0) { host_cur_rank_num_unique[global_thread_id] = cur_rank_num_unique[global_thread_id]; } if (blockIdx.x == 0) { if (threadIdx.x < parallel_num) { volatile int32_t* remote_start_f = param.is_kernel_start[threadIdx.x]; while (remote_start_f[parallel_id] < count + 1) {} } } } } // namespace template class IdShuffleP2PKernel final : public user_op::OpKernel { public: IdShuffleP2PKernel() : current_iter_(0){}; ~IdShuffleP2PKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) 
const override { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); const user_op::Tensor* ids = ctx->Tensor4ArgNameAndIndex("ids", 0); user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0); user_op::Tensor* inverse_unique_partition_indices = ctx->Tensor4ArgNameAndIndex("inverse_unique_partition_indices", 0); user_op::Tensor* cur_rank_num_unique = ctx->Tensor4ArgNameAndIndex("cur_rank_num_unique", 0); user_op::Tensor* cur_rank_unique_ids = ctx->Tensor4ArgNameAndIndex("cur_rank_unique_ids", 0); user_op::Tensor* cur_rank_unique_table_ids = ctx->Tensor4ArgNameAndIndex("cur_rank_unique_table_ids", 0); user_op::Tensor* cur_rank_inverse_indices = ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int32_t num_tables = ctx->Attr("num_tables"); const bool has_table_ids = ctx->has_input("table_ids", 0); const bool need_gen_table_ids = (!has_table_ids && num_tables > 1); const bool need_process_table_ids = (has_table_ids || num_tables > 1); const int64_t num_ids = ids->shape_view().elem_cnt(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); IdShuffleTmpBufferManager buffer_manager( tmp_buffer->mut_dptr(), num_ids, parallel_num, need_gen_table_ids, need_process_table_ids); CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.TotalBufferSize()); if (current_iter_ == 0) { GetPtrs(ctx, kernel_state->BufferPtrs()); } const int num_blocks = 2 * ctx->stream()->As()->device_properties().multiProcessorCount; IDX* num_partitioned_unique = kernel_state->NumPartitionedUnique(parallel_id); K* partitioned_unique_ids = kernel_state->PartitionedUniqueIds(parallel_id); U* partitioned_unique_table_ids = kernel_state->PartitionedUniqueTableIds(parallel_id); IDX* num_unique_matrix_ptr = reinterpret_cast(num_unique_matrix->mut_dptr()); size_t hash_table_capacity = parallel_num * num_ids; void* workspace_ptr = buffer_manager.Ptr(IdShuffleBufferType::kWorkspace); size_t workspace_size = buffer_manager.Size(IdShuffleBufferType::kWorkspace); const U* table_ids_ptr; bool skip_memset = false; if (has_table_ids) { const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex("table_ids", 0); table_ids_ptr = reinterpret_cast(table_ids->dptr()); } else if (need_gen_table_ids) { CHECK_EQ(workspace_size % 16, 0); CHECK_EQ(reinterpret_cast(workspace_ptr) % 16, 0); GenerateTableIdsAndMemsetUniqueWorkspace<<>>( num_ids, num_tables, buffer_manager.template Ptr(IdShuffleBufferType::kTableIds), reinterpret_cast*>(workspace_ptr), workspace_size / 16, num_partitioned_unique, parallel_num); table_ids_ptr = buffer_manager.template Ptr(IdShuffleBufferType::kTableIds); skip_memset = true; } else { table_ids_ptr = nullptr; } if (!skip_memset) { OF_CUDA_CHECK(cudaMemsetAsync(workspace_ptr, 0, workspace_size, cuda_stream)); OF_CUDA_CHECK( cudaMemsetAsync(num_partitioned_unique, 0, parallel_num * sizeof(IDX), cuda_stream)); } UniqueAndPartition( cuda_stream, num_blocks, num_ids, hash_table_capacity, parallel_num, reinterpret_cast(ids->dptr()), table_ids_ptr, num_partitioned_unique, partitioned_unique_ids, partitioned_unique_table_ids, reinterpret_cast(inverse_unique_partition_indices->mut_dptr()), workspace_ptr, workspace_size, need_process_table_ids, kernel_state->IsKernelStart(parallel_id)); IDX* cur_rank_num_unique_ids_ptr = 
reinterpret_cast(cur_rank_num_unique->mut_dptr()); Param param; CHECK_LE(parallel_num, 8); for (int i = 0; i < parallel_num; ++i) { param.num_unique[i] = kernel_state->NumPartitionedUnique(i); param.unique_ids[i] = kernel_state->PartitionedUniqueIds(i) + parallel_id * num_ids; param.unique_table_ids[i] = kernel_state->PartitionedUniqueTableIds(i) + parallel_id * num_ids; param.is_kernel_start[i] = kernel_state->IsKernelStart(i); } param.num_unique_matrix = num_unique_matrix_ptr; CHECK_EQ(workspace_size % 16, 0); CHECK_EQ(reinterpret_cast(workspace_ptr) % 16, 0); int workspace_num_pack = workspace_size / 16; BarrierAndMemset<<>>( parallel_id, parallel_num, param, reinterpret_cast*>(workspace_ptr), workspace_num_pack, cur_rank_num_unique_ids_ptr, 1); HashTableUniquePairs <<>>( hash_table_capacity, num_ids, parallel_num, parallel_id, cur_rank_num_unique_ids_ptr, reinterpret_cast*>(workspace_ptr), param, reinterpret_cast(cur_rank_unique_ids->mut_dptr()), reinterpret_cast(cur_rank_unique_table_ids->mut_dptr()), reinterpret_cast(cur_rank_inverse_indices->mut_dptr()), need_process_table_ids); IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix(); IDX* host_cur_rank_num_unique = kernel_state->HostCurRankNumUnique(); BarrierAndComputeOut<<>>( parallel_id, parallel_num, num_ids, param, num_partitioned_unique, reinterpret_cast(inverse_unique_partition_indices->mut_dptr()), num_unique_matrix_ptr, host_num_unique_matrix, cur_rank_num_unique_ids_ptr, host_cur_rank_num_unique); if (!need_process_table_ids) { OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_table_ids->mut_dptr(), 0, cur_rank_unique_table_ids->shape_view().elem_cnt() * sizeof(U), cuda_stream)); } embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); std::vector num_unique_matrix_vec(parallel_num * parallel_num); CHECK_JUST(ctx->stream()->Sync()); std::memcpy(num_unique_matrix_vec.data(), host_num_unique_matrix, parallel_num * parallel_num * sizeof(IDX)); CHECK_EQ(sizeof(IDX), sizeof(uint32_t)) << "assume sizeof(IDX) equals to sizeof(uint32_t)"; embedding_state->SetIdNumUniqueMatrix(num_unique_matrix_vec, current_iter_); uint32_t final_num_unique = *host_cur_rank_num_unique; embedding_state->SetIdFinalNumUnique(final_num_unique, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define ID_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define TABLE_ID_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define IDX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define REGISTER_CUDA_ID_SHUFFLE_P2P_KERNEL(k_dtype_pair, table_id_dtype_pair, idx_dtype_pair) \ REGISTER_USER_KERNEL("id_shuffle") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("ids", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \ && (user_op::HobDataType("cur_rank_unique_table_ids", 0) \ == OF_PP_PAIR_SECOND(table_id_dtype_pair)) \ && (user_op::HobDataType("num_unique_matrix", 0) == 
OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ID_SHUFFLE_USE_P2P", false)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& ids = ctx->InputTensorDesc("ids", 0); \ const bool has_table_ids = ctx->has_input("table_ids", 0); \ const int32_t num_tables = ctx->Attr("num_tables"); \ const bool need_gen_table_ids = (!has_table_ids && num_tables > 1); \ const bool need_process_table_ids = (has_table_ids || num_tables > 1); \ IdShuffleTmpBufferManager \ buffer_manager(nullptr, ids.shape().elem_cnt(), ctx->parallel_desc().parallel_num(), \ need_gen_table_ids, need_process_table_ids); \ return buffer_manager.TotalBufferSize(); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ID_SHUFFLE_P2P_KERNEL, ID_DATA_TYPE_SEQ, TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/one_embedding_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/embedding/key_value_store.h" #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/device.h" #include "oneflow/user/kernels/one_embedding_data_shuffle.cuh" #include #include namespace oneflow { namespace { enum class InitializerType { kUniform, kNormal, kConstant, kTruncNormal }; struct EmbeddingInitializer { InitializerType type; union { struct { float low; float high; } uniform_param; struct { float mean; float std; } normal_param; struct { float value; } constant_param; struct { float mean; float std; float a; float b; } trunc_normal_param; }; bool operator==(const EmbeddingInitializer& rhs) const { if (this->type != rhs.type) { return false; } if (rhs.type == InitializerType::kUniform) { return (this->uniform_param.low == rhs.uniform_param.low) && (this->uniform_param.high == rhs.uniform_param.high); } else if (rhs.type == InitializerType::kNormal) { return (this->normal_param.mean == rhs.normal_param.mean) && (this->normal_param.std == rhs.normal_param.std); } else if (rhs.type == InitializerType::kConstant) { return this->constant_param.value == rhs.constant_param.value; } else if (rhs.type == InitializerType::kTruncNormal) { return (this->trunc_normal_param.mean == rhs.trunc_normal_param.mean) && (this->trunc_normal_param.std == rhs.trunc_normal_param.std) && (this->trunc_normal_param.a == rhs.trunc_normal_param.a) && (this->trunc_normal_param.b == rhs.trunc_normal_param.b); } else { UNIMPLEMENTED(); return false; } } }; void ParseInitializerFromJson(const nlohmann::json& initializer, EmbeddingInitializer* embedding_initializer) { CHECK(initializer.contains("type")); CHECK(initializer["type"].is_string()); std::string type = initializer["type"].get(); if (type == "uniform") { 
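    // e.g. (illustrative values only): {"type": "uniform", "low": -0.05, "high": 0.05}.
    // The other branches below expect {"type": "normal", "mean", "std"},
    // {"type": "constant", "value"} and {"type": "trunc_normal", "mean", "std", "a", "b"},
    // mirroring the CHECK(initializer.contains(...)) guards in each branch.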
embedding_initializer->type = InitializerType::kUniform; CHECK(initializer.contains("low")); CHECK(initializer.contains("high")); CHECK(initializer["low"].is_number()); CHECK(initializer["high"].is_number()); embedding_initializer->uniform_param.low = initializer["low"]; embedding_initializer->uniform_param.high = initializer["high"]; } else if (type == "normal") { CHECK(initializer.contains("mean")); CHECK(initializer.contains("std")); CHECK(initializer["mean"].is_number()); CHECK(initializer["std"].is_number()); embedding_initializer->type = InitializerType::kNormal; embedding_initializer->normal_param.mean = initializer["mean"]; embedding_initializer->normal_param.std = initializer["std"]; } else if (type == "constant") { CHECK(initializer.contains("value")); CHECK(initializer["value"].is_number()); embedding_initializer->type = InitializerType::kConstant; embedding_initializer->constant_param.value = initializer["value"]; } else if (type == "trunc_normal") { CHECK(initializer.contains("mean")); CHECK(initializer.contains("std")); CHECK(initializer.contains("a")); CHECK(initializer.contains("b")); CHECK(initializer["mean"].is_number()); CHECK(initializer["std"].is_number()); CHECK(initializer["a"].is_number()); CHECK(initializer["b"].is_number()); embedding_initializer->type = InitializerType::kTruncNormal; embedding_initializer->trunc_normal_param.mean = initializer["mean"]; embedding_initializer->trunc_normal_param.std = initializer["std"]; embedding_initializer->trunc_normal_param.a = initializer["a"]; embedding_initializer->trunc_normal_param.b = initializer["b"]; } else { UNIMPLEMENTED() << "Unsupported initializer type"; } } int32_t ParseJsonToUniqueInitializerVecAndReturnOffset( const nlohmann::json& initializer, std::vector* initializers) { EmbeddingInitializer embedding_initializer; ParseInitializerFromJson(initializer, &embedding_initializer); for (int32_t i = 0; i < initializers->size(); ++i) { if (initializers->at(i) == embedding_initializer) { return i; } } initializers->push_back(embedding_initializer); return initializers->size() - 1; } void SetInitializerIndex(int32_t row_id, int32_t col_start, int32_t col_end, int64_t line_size, int8_t index, std::vector* initializer_index) { int64_t row_offset = row_id * line_size; for (int32_t col = col_start; col < col_end; ++col) { initializer_index->at(row_offset + col) = index; } } void ParseAndSetStateInitializerIndex(const std::string& state_initializer, const int32_t num_tables, const int64_t line_size, const int64_t embedding_size, std::vector* initializer_params, std::vector* initializer_index) { if (line_size == embedding_size) { return; } CHECK(!state_initializer.empty()); auto initializers = nlohmann::json::parse(state_initializer); CHECK(initializers.is_array()); const int num_states = line_size / embedding_size - 1; CHECK_EQ(num_states, initializers.size()); for (int32_t i = 0; i < num_states; ++i) { int32_t offset = ParseJsonToUniqueInitializerVecAndReturnOffset(initializers.at(i), initializer_params); int32_t col_start = embedding_size + i * embedding_size; int32_t col_end = col_start + embedding_size; CHECK_LE(col_end, line_size); for (int32_t j = 0; j < num_tables; ++j) { SetInitializerIndex(j, col_start, col_end, line_size, offset, initializer_index); } } } void ParseAndSetStepInitializerIndex(const int32_t num_tables, const int64_t line_size, const int64_t embedding_size, std::vector* initializer_params, std::vector* initializer_index) { if (line_size % embedding_size == 0) { return; } nlohmann::json initializer; 
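  // line_size is not an exact multiple of embedding_size here (see the early return above),
  // so the leftover tail columns [line_size / embedding_size * embedding_size, line_size)
  // of every table row are mapped to a constant-zero initializer built below.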
initializer["type"] = "constant"; initializer["value"] = 0.0; int32_t offset = ParseJsonToUniqueInitializerVecAndReturnOffset(initializer, initializer_params); int32_t col_start = line_size / embedding_size * embedding_size; int32_t col_end = line_size; CHECK_LE(col_end, line_size); for (int32_t j = 0; j < num_tables; ++j) { SetInitializerIndex(j, col_start, col_end, line_size, offset, initializer_index); } } void ParseAndSetModelInitializerIndex(const nlohmann::json& tables, const std::vector& column_dims, const int32_t num_tables, const int32_t num_columns, const int64_t line_size, const int64_t embedding_size, std::vector* initializer_params, std::vector* initializer_index) { for (int32_t i = 0; i < num_tables; ++i) { auto table = tables.at(i); CHECK(table.contains("columns")); auto columns = table["columns"]; CHECK(columns.is_array()); CHECK_EQ(num_columns, columns.size()) << "columns size must equal to num embedding dims"; int32_t col_start = 0; for (int k = 0; k < columns.size(); ++k) { auto column = columns.at(k); CHECK(column.contains("initializer")); int32_t offset = ParseJsonToUniqueInitializerVecAndReturnOffset(column["initializer"], initializer_params); int32_t col_end = col_start + column_dims.at(k); SetInitializerIndex(i, col_start, col_end, line_size, offset, initializer_index); col_start = col_end; } CHECK_EQ(col_start, embedding_size); } } void ParseInitializers(const int64_t line_size, const int64_t embedding_size, const std::string& state_initializer, const std::string& json_serialized, std::vector* initializer_params, std::vector* initializer_index) { auto json_object = nlohmann::json::parse(json_serialized); CHECK(json_object.contains("column_dims")); std::vector column_dims = json_object["column_dims"]; const int32_t num_columns = column_dims.size(); CHECK(json_object.contains("tables")); auto tables = json_object["tables"]; CHECK(tables.is_array()); const int32_t num_tables = tables.size(); initializer_index->resize(num_tables * line_size); ParseAndSetStepInitializerIndex(num_tables, line_size, embedding_size, initializer_params, initializer_index); ParseAndSetStateInitializerIndex(state_initializer, num_tables, line_size, embedding_size, initializer_params, initializer_index); ParseAndSetModelInitializerIndex(tables, column_dims, num_tables, num_columns, line_size, embedding_size, initializer_params, initializer_index); } template class EmbeddingKernelState final : public user_op::OpKernelState { public: explicit EmbeddingKernelState(user_op::KernelInitContext* ctx) : device_index_(-1) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX))); const std::string& embedding_name = ctx->Attr("embedding_name"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); key_value_store_ = Singleton::Get()->GetKeyValueStore( embedding_name, parallel_id); uint32_t max_query_length = ctx->TensorDesc4ArgNameAndIndex("unique_ids", 0)->shape().elem_cnt(); key_value_store_->ReserveQueryLength(max_query_length); embedding_state_ = Singleton::Get()->GetEmbeddingState( embedding_name, parallel_id); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); const std::string& state_initializer = ctx->Attr("state_initializer"); std::vector initializer_param; std::vector initializer_index; ParseInitializers(line_size, embedding_size, state_initializer, ctx->Attr("embedding_tables"), &initializer_param, &initializer_index); const size_t param_size_bytes = initializer_param.size() * 
sizeof(EmbeddingInitializer); OF_CUDA_CHECK(cudaMallocHost(&host_initializer_param_, param_size_bytes)); std::memcpy(host_initializer_param_, initializer_param.data(), param_size_bytes); OF_CUDA_CHECK(cudaMalloc(&device_initializer_param_, param_size_bytes)); OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_param_, host_initializer_param_, param_size_bytes, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); const size_t index_size_bytes = initializer_index.size() * sizeof(int8_t); OF_CUDA_CHECK(cudaMallocHost(&host_initializer_index_, index_size_bytes)); std::memcpy(host_initializer_index_, initializer_index.data(), index_size_bytes); OF_CUDA_CHECK(cudaMalloc(&device_initializer_index_, index_size_bytes)); OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_index_, host_initializer_index_, index_size_bytes, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); } ~EmbeddingKernelState() override { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFreeHost(host_num_keys_)); OF_CUDA_CHECK(cudaFreeHost(host_initializer_param_)); OF_CUDA_CHECK(cudaFree(device_initializer_param_)); OF_CUDA_CHECK(cudaFreeHost(host_initializer_index_)); OF_CUDA_CHECK(cudaFree(device_initializer_index_)); } void* HostNumKeys() { return host_num_keys_; } embedding::KeyValueStore* KeyValueStore() { return key_value_store_; } embedding::EmbeddingState* EmbeddingState() { return embedding_state_; } const int8_t* InitializerIndex() { return device_initializer_index_; } const EmbeddingInitializer* Initializers() { return device_initializer_param_; } private: int device_index_; void* host_num_keys_; embedding::KeyValueStore* key_value_store_; embedding::EmbeddingState* embedding_state_; EmbeddingInitializer* host_initializer_param_; EmbeddingInitializer* device_initializer_param_; int8_t* host_initializer_index_; int8_t* device_initializer_index_; }; class EmbeddingPutKernelState final : public user_op::OpKernelState { public: explicit EmbeddingPutKernelState(user_op::KernelInitContext* ctx) { const std::string& embedding_name = ctx->Attr("embedding_name"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); key_value_store_ = Singleton::Get()->GetKeyValueStore( embedding_name, parallel_id); uint32_t max_query_length = ctx->TensorDesc4ArgNameAndIndex("unique_ids", 0)->shape().elem_cnt(); key_value_store_->ReserveQueryLength(max_query_length); embedding_state_ = Singleton::Get()->GetEmbeddingState( embedding_name, parallel_id); } ~EmbeddingPutKernelState() override = default; embedding::KeyValueStore* KeyValueStore() { return key_value_store_; } embedding::EmbeddingState* EmbeddingState() { return embedding_state_; } private: embedding::KeyValueStore* key_value_store_; embedding::EmbeddingState* embedding_state_; }; template __global__ void InitValueKernel(uint64_t seed, const int32_t line_size, const int32_t embedding_size, const EmbeddingInitializer* initializer_param, const int8_t* initializer_index, const K* unique_ids, const U* table_ids, const uint32_t* num_missing_keys, const uint32_t* missing_indices, T* values) { int64_t n = *num_missing_keys * line_size; CUDA_1D_KERNEL_LOOP(i, n) { int row = i / line_size; int col = i - row * line_size; const uint32_t index = missing_indices[row]; const int64_t offset = index * line_size + col; const int32_t table_idx = table_ids[index]; const K id = unique_ids[index]; curandStatePhilox4_32_10_t state; curand_init(seed, id, col, &state); const int32_t initializer_idx = initializer_index[table_idx * line_size + col]; EmbeddingInitializer initializer = 
initializer_param[initializer_idx]; T value; if (initializer.type == InitializerType::kUniform) { const float low = initializer.uniform_param.low; const float high = initializer.uniform_param.high; value = curand_uniform(&state) * (high - low) + low; } else if (initializer.type == InitializerType::kNormal) { const float mean = initializer.normal_param.mean; const float std = initializer.normal_param.std; value = curand_normal(&state) * std + mean; } else if (initializer.type == InitializerType::kConstant) { value = initializer.constant_param.value; } else if (initializer.type == InitializerType::kTruncNormal) { const float mean = initializer.trunc_normal_param.mean; const float std = initializer.trunc_normal_param.std; const float a = initializer.trunc_normal_param.a; const float b = initializer.trunc_normal_param.b; while (true) { value = curand_normal(&state) * std + mean; if (value >= a && value <= b) { break; } skipahead(line_size, &state); } } else { __trap(); } values[offset] = value; } } template void LookupAndInitMissing(ep::Stream* stream, uint64_t seed, embedding::KeyValueStore* store, const EmbeddingInitializer* initializer_param, const int8_t* initializer_index, void* host_num_keys, uint32_t num_unique, const int64_t embedding_size, const int64_t line_size, const bool put_to_store, const void* unique_ids, const void* table_ids, void* num_missing_ptr, void* missing_indices, void* store_values) { store->Get(stream, num_unique, unique_ids, store_values, reinterpret_cast(num_missing_ptr), reinterpret_cast(missing_indices)); CHECK_GE(sizeof(IDX), sizeof(uint32_t)); // host_num_keys's buffer size is sizeof(IDX) OF_CUDA_CHECK(cudaMemcpyAsync(host_num_keys, num_missing_ptr, sizeof(uint32_t), cudaMemcpyDefault, stream->As()->cuda_stream())); CHECK_JUST(stream->Sync()); uint32_t num_missing = *reinterpret_cast(host_num_keys); // init missing values if (num_missing > 0) { const int64_t elem_cnt = num_missing * line_size; const int64_t num_blocks = BlocksNum4ThreadsNum(elem_cnt); InitValueKernel <<As()->cuda_stream()>>>( seed, line_size, embedding_size, initializer_param, initializer_index, reinterpret_cast(unique_ids), reinterpret_cast(table_ids), reinterpret_cast(num_missing_ptr), reinterpret_cast(missing_indices), reinterpret_cast(store_values)); } if (put_to_store) { store->Put(stream, num_unique, unique_ids, store_values); } } template void LookupAndInitMissing(ep::Stream* stream, EmbeddingKernelState* kernel_state, uint64_t seed, uint32_t num_unique, const int64_t embedding_size, const int64_t line_size, const bool put_to_store, const void* unique_ids, const void* table_ids, void* num_missing_ptr, void* missing_indices, void* store_values) { embedding::KeyValueStore* store = kernel_state->KeyValueStore(); const EmbeddingInitializer* initializer_param = kernel_state->Initializers(); const int8_t* initializer_index = kernel_state->InitializerIndex(); void* host_num_keys = kernel_state->HostNumKeys(); LookupAndInitMissing(stream, seed, store, initializer_param, initializer_index, host_num_keys, num_unique, embedding_size, line_size, put_to_store, unique_ids, table_ids, num_missing_ptr, missing_indices, store_values); } template struct alignas(sizeof(T) * pack_size) Pack { T elem[pack_size]; }; template __global__ void FusedInitSliceCast(const int32_t elem_cnt, uint64_t seed, const int32_t line_size, const int32_t embedding_size, const int32_t line_num_pack, const int32_t embedding_num_pack, const EmbeddingInitializer* initializer_param, const int8_t* initializer_index, const K* 
unique_ids, const U* table_ids, const uint8_t* lookup_mask, Pack* values, Pack* embeddings) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { int row = i / line_num_pack; int col = i - row * line_num_pack; Pack value_i; if (!lookup_mask[row]) { const int32_t table_idx = table_ids[row]; const K id = unique_ids[row]; curandStatePhilox4_32_10_t state; curand_init(seed, id, col, &state); #pragma unroll for (int k = 0; k < pack_size; ++k) { const int32_t initializer_idx = initializer_index[table_idx * line_size + col * pack_size + k]; EmbeddingInitializer initializer = initializer_param[initializer_idx]; T value; if (initializer.type == InitializerType::kUniform) { const float low = initializer.uniform_param.low; const float high = initializer.uniform_param.high; value = curand_uniform(&state) * (high - low) + low; } else if (initializer.type == InitializerType::kNormal) { const float mean = initializer.normal_param.mean; const float std = initializer.normal_param.std; value = curand_normal(&state) * std + mean; } else if (initializer.type == InitializerType::kConstant) { value = initializer.constant_param.value; } else if (initializer.type == InitializerType::kTruncNormal) { const float mean = initializer.trunc_normal_param.mean; const float std = initializer.trunc_normal_param.std; const float a = initializer.trunc_normal_param.a; const float b = initializer.trunc_normal_param.b; while (true) { value = curand_normal(&state) * std + mean; if (value >= a && value <= b) { break; } skipahead(line_size, &state); } } else { __trap(); } value_i.elem[k] = value; } values[i] = value_i; } else { value_i = values[i]; } if (embeddings != nullptr && col < embedding_num_pack) { int64_t embedding_offset = row * embedding_num_pack + col; Pack embedding_i; #pragma unroll for (int k = 0; k < pack_size; ++k) { embedding_i.elem[k] = static_cast(value_i.elem[k]); } embeddings[embedding_offset] = embedding_i; } } } template void InitMissingAndSliceCast(cudaStream_t cuda_stream, uint32_t num_unique, const int64_t embedding_size, const int64_t line_size, uint64_t seed, const EmbeddingInitializer* initializer_param, const int8_t* initializer_index, const void* unique_ids, const void* table_ids, const uint8_t* mask, T* values_ptr, V* embeddings_ptr) { int32_t pack_size; if (embedding_size % 4 == 0 && line_size % 4 == 0) { pack_size = 4; } else if (embedding_size % 2 == 0 && line_size % 2 == 0) { pack_size = 2; } else { pack_size = 1; } int32_t embedding_num_pack = embedding_size / pack_size; int32_t line_num_pack = line_size / pack_size; int64_t value_elem_cnt = num_unique * line_size; int64_t value_elem_num_pack = value_elem_cnt / pack_size; const int64_t num_blocks = BlocksNum4ThreadsNum(value_elem_num_pack); if (pack_size == 4) { FusedInitSliceCast<<>>( value_elem_num_pack, seed, line_size, embedding_size, line_num_pack, embedding_num_pack, initializer_param, initializer_index, reinterpret_cast(unique_ids), reinterpret_cast(table_ids), mask, reinterpret_cast*>(values_ptr), reinterpret_cast*>(embeddings_ptr)); } else if (pack_size == 2) { FusedInitSliceCast<<>>( value_elem_num_pack, seed, line_size, embedding_size, line_num_pack, embedding_num_pack, initializer_param, initializer_index, reinterpret_cast(unique_ids), reinterpret_cast(table_ids), mask, reinterpret_cast*>(values_ptr), reinterpret_cast*>(embeddings_ptr)); } else { FusedInitSliceCast<<>>( value_elem_num_pack, seed, line_size, embedding_size, line_num_pack, embedding_num_pack, initializer_param, initializer_index, reinterpret_cast(unique_ids), 
reinterpret_cast(table_ids), mask, reinterpret_cast*>(values_ptr), reinterpret_cast*>(embeddings_ptr)); } } template void LookupAndFusedInitMissingSliceCast(ep::Stream* stream, EmbeddingKernelState* kernel_state, uint64_t seed, uint32_t num_unique, const int64_t embedding_size, const int64_t line_size, DataType value_dtype, DataType embedding_dtype, const void* unique_ids, const void* table_ids, uint8_t* lookup_mask_ptr, void* values_ptr, void* embeddings_ptr) { embedding::KeyValueStore* store = kernel_state->KeyValueStore(); const EmbeddingInitializer* initializer_param = kernel_state->Initializers(); const int8_t* initializer_index = kernel_state->InitializerIndex(); cudaStream_t cuda_stream = stream->As()->cuda_stream(); store->Get(stream, num_unique, unique_ids, values_ptr, lookup_mask_ptr); if (embedding_dtype == value_dtype) { InitMissingAndSliceCast( cuda_stream, num_unique, embedding_size, line_size, seed, initializer_param, initializer_index, reinterpret_cast(unique_ids), reinterpret_cast(table_ids), lookup_mask_ptr, reinterpret_cast(values_ptr), reinterpret_cast(embeddings_ptr)); } else if (embedding_dtype == DataType::kFloat16) { InitMissingAndSliceCast( cuda_stream, num_unique, embedding_size, line_size, seed, initializer_param, initializer_index, reinterpret_cast(unique_ids), reinterpret_cast(table_ids), lookup_mask_ptr, reinterpret_cast(values_ptr), reinterpret_cast(embeddings_ptr)); } else { UNIMPLEMENTED() << "Unimplemented data_type " << embedding_dtype; } } template __global__ void Copy2D(int64_t out_elem_cnt, const int32_t in_cols, const int32_t out_cols, const T* in, U* out) { CUDA_1D_KERNEL_LOOP(i, out_elem_cnt) { const int32_t row = i / out_cols; const int32_t col = i - row * out_cols; const int64_t in_offset = row * in_cols + col; out[i] = static_cast(in[in_offset]); } } template void CopyValuesToEmbeddings(ep::Stream* stream, int64_t num_unique, const int32_t embedding_size, const int32_t value_size, const DataType value_dtype, const DataType embedding_dtype, const T* values, void* embeddings) { bool need_cast = (value_dtype != embedding_dtype); bool need_copy_nd = (embedding_size != value_size); CHECK(need_cast || need_copy_nd); if (need_cast && !need_copy_nd) { const int64_t cast_elem_count = num_unique * embedding_size; std::unique_ptr cast_primitive = ep::primitive::NewPrimitive(DeviceType::kCUDA, value_dtype, embedding_dtype); cast_primitive->Launch(stream, values, embeddings, cast_elem_count); } else if (!need_cast && need_copy_nd) { const int32_t ndims = 2; DimVector src_pos_vec(ndims, 0); DimVector dst_pos_vec(ndims, 0); DimVector src_shape = {num_unique, value_size}; DimVector dst_shape = {num_unique, embedding_size}; DimVector extent_shape = {num_unique, embedding_size}; std::unique_ptr copy_nd_primitive = ep::primitive::NewPrimitive(DeviceType::kCUDA, ndims); CHECK(copy_nd_primitive); copy_nd_primitive->Launch(stream, value_dtype, ndims, embeddings, dst_shape.data(), dst_pos_vec.data(), values, src_shape.data(), src_pos_vec.data(), extent_shape.data()); } else { const int64_t embedding_elem_cnt = num_unique * embedding_size; if (embedding_dtype == DataType::kFloat16) { Copy2D<<As()->cuda_stream()>>>( embedding_elem_cnt, value_size, embedding_size, values, reinterpret_cast(embeddings)); } else { UNIMPLEMENTED(); } } } template user_op::InferTmpSizeFn GenEmbeddingInferTmpSizeFn() { return [](user_op::InferContext* ctx) { size_t total_buffer_size = 0; if (embedding::UseDynamicMemoryAllocation()) { return total_buffer_size; } const user_op::TensorDesc& 
unique_ids = ctx->InputTensorDesc("unique_ids", 0); int64_t num_ids = unique_ids.shape().elem_cnt(); size_t num_missing_size = GetCudaAlignedSize(sizeof(uint32_t)); size_t missing_indices_size = GetCudaAlignedSize(num_ids * sizeof(uint32_t)); size_t value_buffer_size; if (is_prefetch) { size_t value_byte_size = ctx->Attr("line_size") * sizeof(T); value_buffer_size = GetCudaAlignedSize(num_ids * value_byte_size); } else { value_buffer_size = 0; } total_buffer_size = num_missing_size + missing_indices_size + value_buffer_size; return total_buffer_size; }; } class IdShuffleCopyOutKernelState final : public user_op::OpKernelState { public: explicit IdShuffleCopyOutKernelState(user_op::KernelInitContext* ctx) { const std::string& embedding_name = ctx->Attr("embedding_name"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); embedding_state_ = Singleton::Get()->GetEmbeddingState( embedding_name, parallel_id); } ~IdShuffleCopyOutKernelState() override = default; embedding::EmbeddingState* EmbeddingState() { return embedding_state_; } private: embedding::EmbeddingState* embedding_state_; }; template struct IdShuffleCopyOutParam { uint32_t final_num_unique_ids; const K* cur_rank_unique_ids; K* out_cur_rank_unique_ids; const U* cur_rank_unique_table_ids; U* out_cur_rank_unique_table_ids; uint32_t cur_rank_num_ids; const IDX* cur_rank_inverse_indices; IDX* out_cur_rank_inverse_indices; uint32_t num_ids; const IDX* inverse_unique_partition_indices; IDX* out_inverse_unique_partition_indices; uint32_t num_unique_matrix_cnt; const IDX* num_unique_matrix; IDX* out_num_unique_matrix; const IDX* cur_rank_num_unique; IDX* out_cur_rank_num_unique; }; template __global__ void CopyGpu(IdShuffleCopyOutParam param) { CUDA_1D_KERNEL_LOOP_T(uint32_t, i, param.final_num_unique_ids) { param.out_cur_rank_unique_ids[i] = param.cur_rank_unique_ids[i]; param.out_cur_rank_unique_table_ids[i] = param.cur_rank_unique_table_ids[i]; } CUDA_1D_KERNEL_LOOP_T(uint32_t, i, param.cur_rank_num_ids) { param.out_cur_rank_inverse_indices[i] = param.cur_rank_inverse_indices[i]; } CUDA_1D_KERNEL_LOOP_T(uint32_t, i, param.num_ids) { param.out_inverse_unique_partition_indices[i] = param.inverse_unique_partition_indices[i]; } CUDA_1D_KERNEL_LOOP_T(uint32_t, i, param.num_unique_matrix_cnt) { param.out_num_unique_matrix[i] = param.num_unique_matrix[i]; } if (blockIdx.x * blockDim.x + threadIdx.x == 0) { *param.out_cur_rank_num_unique = *param.cur_rank_num_unique; } } } // namespace template class EmbeddingPrefetchKernel final : public user_op::OpKernel { public: EmbeddingPrefetchKernel() : current_iter_(0){}; ~EmbeddingPrefetchKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); std::unique_ptr allocator = embedding_state->NewTmpBufferAllocator(ctx); uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex("unique_ids", 0); const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex("table_ids", 0); const int64_t embedding_size = 
ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); const int64_t seed = ctx->Attr("seed"); void* num_missing_ptr; allocator->Allocate(&num_missing_ptr, sizeof(uint32_t)); void* missing_indices_ptr; allocator->Allocate(&missing_indices_ptr, num_unique * sizeof(uint32_t)); void* values_ptr; allocator->Allocate(&values_ptr, num_unique * line_size * sizeof(T)); LookupAndInitMissing( ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, true, unique_ids->dptr(), table_ids->dptr(), num_missing_ptr, missing_indices_ptr, values_ptr); allocator->Free(num_missing_ptr); allocator->Free(missing_indices_ptr); allocator->Free(values_ptr); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define EMBEDDING_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) #define ID_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define TABLE_ID_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define IDX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define REGISTER_CUDA_EMBEDDING_PREFETCH_KERNEL(t_dtype_pair, k_dtype_pair, table_dtype_pair, \ idx_dtype_pair) \ REGISTER_USER_KERNEL("embedding_prefetch") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("unique_ids", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \ && (user_op::HobDataType("table_ids", 0) == OF_PP_PAIR_SECOND(table_dtype_pair)) \ && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))) \ .SetInferTmpSizeFn(GenEmbeddingInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_PREFETCH_KERNEL, EMBEDDING_DATA_TYPE_SEQ, ID_DATA_TYPE_SEQ, TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class EmbeddingLookupKernel final : public user_op::OpKernel { public: EmbeddingLookupKernel() : current_iter_(0){}; ~EmbeddingLookupKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); std::unique_ptr allocator = embedding_state->NewTmpBufferAllocator(ctx); embedding_state->OnEmbeddingLookupStart(ctx, current_iter_); const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex("unique_ids", 0); const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex("table_ids", 0); user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex("unique_values", 0); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); const bool has_output_embeddings = ctx->has_output("embeddings", 0); const 
int64_t seed = ctx->Attr("seed"); uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); void* values_ptr = embedding_state->LookupUniqueValues(current_iter_); if (has_output_embeddings && kernel_state->KeyValueStore()->IsFusionSupported()) { void* embeddings_ptr = embedding_state->LookupEmbeddings(current_iter_); user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); void* lookup_mask_ptr; allocator->Allocate(&lookup_mask_ptr, num_unique * sizeof(uint8_t)); LookupAndFusedInitMissingSliceCast( ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, unique_values->data_type(), embeddings->data_type(), unique_ids->dptr(), table_ids->dptr(), reinterpret_cast(lookup_mask_ptr), values_ptr, embeddings_ptr); allocator->Free(lookup_mask_ptr); } else { void* num_missing_ptr; allocator->Allocate(&num_missing_ptr, sizeof(uint32_t)); void* missing_indices_ptr; allocator->Allocate(&missing_indices_ptr, num_unique * sizeof(uint32_t)); LookupAndInitMissing( ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, false, unique_ids->dptr(), table_ids->dptr(), num_missing_ptr, missing_indices_ptr, values_ptr); allocator->Free(num_missing_ptr); allocator->Free(missing_indices_ptr); if (has_output_embeddings) { void* embeddings_ptr = embedding_state->LookupEmbeddings(current_iter_); user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); CopyValuesToEmbeddings(ctx->stream(), num_unique, embedding_size, line_size, unique_values->data_type(), embeddings->data_type(), reinterpret_cast(values_ptr), embeddings_ptr); } } embedding_state->OnEmbeddingLookupEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_EMBEDDING_LOOKUP_KERNEL(t_dtype_pair, k_dtype_pair, table_dtype_pair, \ idx_dtype_pair) \ REGISTER_USER_KERNEL("embedding_lookup") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("unique_values", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)) \ && (user_op::HobDataType("unique_ids", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \ && (user_op::HobDataType("table_ids", 0) == OF_PP_PAIR_SECOND(table_dtype_pair)) \ && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))) \ .SetInferTmpSizeFn(GenEmbeddingInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_LOOKUP_KERNEL, EMBEDDING_DATA_TYPE_SEQ, ID_DATA_TYPE_SEQ, TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class EmbeddingPutKernel final : public user_op::OpKernel { public: EmbeddingPutKernel() : current_iter_(0){}; ~EmbeddingPutKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); embedding::KeyValueStore* store = kernel_state->KeyValueStore(); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingPutStart(ctx, current_iter_); const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex("unique_ids", 0); const user_op::Tensor* unique_embeddings = 
ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0); uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); store->Put(ctx->stream(), num_unique, unique_ids->dptr(), embedding_state->EmbeddingPutUniqueEmbeddings(current_iter_)); embedding_state->OnEmbeddingPutEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_EMBEDDING_PUT_KERNEL(dtype, typeproto) \ REGISTER_USER_KERNEL("embedding_put") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("num_unique_ids", 0) == typeproto)); OF_PP_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_PUT_KERNEL, IDX_DATA_TYPE_SEQ) template class OneEmbeddingFusedSgdUpdatePutKernel final : public user_op::OpKernel { public: OneEmbeddingFusedSgdUpdatePutKernel() : current_iter_(0){}; ~OneEmbeddingFusedSgdUpdatePutKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); embedding::KeyValueStore* store = kernel_state->KeyValueStore(); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingFusedUpdatePutStart(ctx, current_iter_); const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex("unique_ids", 0); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); const float* learning_rate_ptr = learning_rate->dptr(); const auto scale = ctx->Attr("scale"); uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); store->FusedHalfUpdatePut( ctx->stream(), num_unique, unique_ids->dptr(), embedding_state->EmbeddingFusedUpdatePutUniqueEmbeddings(current_iter_), embedding_grad->dptr(), learning_rate_ptr, scale); embedding_state->OnEmbeddingFusedUpdatePutEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_ONE_EMBEDDING_FUSED_SGD_UPDATE_PUT_KERNEL(dtype, typeproto) \ REGISTER_USER_KERNEL("one_embedding_fused_sgd_update_put") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("num_unique_ids", 0) == typeproto) \ && (user_op::HobDataType("unique_embeddings", 0) == DataType::kFloat) \ && (user_op::HobDataType("embedding_grad", 0) == DataType::kFloat16)); OF_PP_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_FUSED_SGD_UPDATE_PUT_KERNEL, IDX_DATA_TYPE_SEQ) template class IdShuffleCopyOutKernel final : public user_op::OpKernel { public: IdShuffleCopyOutKernel() : current_iter_(0){}; ~IdShuffleCopyOutKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); 
embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const std::vector& num_unique_matrix_vec = embedding_state->GetIdNumUniqueMatrix(current_iter_); uint32_t cur_rank_num_ids = 0; for (int64_t i = 0; i < parallel_num; ++i) { cur_rank_num_ids += num_unique_matrix_vec.at(i * parallel_num + parallel_id); } IdShuffleCopyOutParam param; param.final_num_unique_ids = num_unique; param.cur_rank_unique_ids = reinterpret_cast(ctx->Tensor4ArgNameAndIndex("cur_rank_unique_ids", 0)->dptr()); param.out_cur_rank_unique_ids = reinterpret_cast(ctx->Tensor4ArgNameAndIndex("out_cur_rank_unique_ids", 0)->mut_dptr()); param.cur_rank_unique_table_ids = reinterpret_cast( ctx->Tensor4ArgNameAndIndex("cur_rank_unique_table_ids", 0)->dptr()); param.out_cur_rank_unique_table_ids = reinterpret_cast( ctx->Tensor4ArgNameAndIndex("out_cur_rank_unique_table_ids", 0)->mut_dptr()); param.cur_rank_num_ids = cur_rank_num_ids; param.cur_rank_inverse_indices = reinterpret_cast( ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0)->dptr()); param.out_cur_rank_inverse_indices = reinterpret_cast( ctx->Tensor4ArgNameAndIndex("out_cur_rank_inverse_indices", 0)->mut_dptr()); param.num_ids = ctx->Tensor4ArgNameAndIndex("inverse_unique_partition_indices", 0)->shape_view().elem_cnt(); param.inverse_unique_partition_indices = reinterpret_cast( ctx->Tensor4ArgNameAndIndex("inverse_unique_partition_indices", 0)->dptr()); param.out_inverse_unique_partition_indices = reinterpret_cast( ctx->Tensor4ArgNameAndIndex("out_inverse_unique_partition_indices", 0)->mut_dptr()); param.num_unique_matrix_cnt = parallel_num * parallel_num; param.num_unique_matrix = reinterpret_cast(ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0)->dptr()); param.out_num_unique_matrix = reinterpret_cast(ctx->Tensor4ArgNameAndIndex("out_num_unique_matrix", 0)->mut_dptr()); param.cur_rank_num_unique = reinterpret_cast(ctx->Tensor4ArgNameAndIndex("cur_rank_num_unique", 0)->dptr()); param.out_cur_rank_num_unique = reinterpret_cast( ctx->Tensor4ArgNameAndIndex("out_cur_rank_num_unique", 0)->mut_dptr()); CopyGpu<<stream()->As()->cuda_stream()>>>(param); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_ID_SHUFFLE_COPY_OUT_KERNEL(k_dtype_pair, table_id_dtype_pair, \ idx_dtype_pair) \ REGISTER_USER_KERNEL("id_shuffle_copy_out") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("cur_rank_unique_ids", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \ && (user_op::HobDataType("cur_rank_unique_table_ids", 0) \ == OF_PP_PAIR_SECOND(table_id_dtype_pair)) \ && (user_op::HobDataType("num_unique_matrix", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ID_SHUFFLE_COPY_OUT_KERNEL, ID_DATA_TYPE_SEQ, TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) enum class FusedEmbeddingBufferType { // id shuffle kNumPartitionedUnique = 0, kPartitionedUniqueIds, kReceivedIds, kTableIds, kPartitionedUniqueTableIds, kReceivedTableIds, kWorkspace, kNumUniqueMatrix, kInverseUniquePartitionIndices, kCurRankNumUnique, kCurRankUniqueIds, kCurRankUniqueTableIds, kCurRankInverseIndices, // embedding lookup kNumMissing, kMissingIndices, kCurRankUniqueValues, kCurRankUniqueEmbeddings, // embedding shuffle kReverseUniqueCurRankEmbeddings, kReceivedEmbeddings, kMaxType }; template class FusedEmbeddingTmpBufferManager final { 
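// Note: this manager carves every temporary buffer used by the fused lookup
// kernel out of one tmp_buffer allocation. AllocBuffer records a CUDA-aligned
// offset and size per FusedEmbeddingBufferType entry, and TotalBufferSize()
// reports the running total, which is why the registration macro below can
// construct the manager with a null ptr just to size the buffer in
// SetInferTmpSizeFn.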
public: OF_DISALLOW_COPY_AND_MOVE(FusedEmbeddingTmpBufferManager); FusedEmbeddingTmpBufferManager(void* ptr, const int64_t num_ids, const int64_t parallel_num, bool need_process_table_ids, int64_t line_size, int64_t embedding_size, bool need_unique_values, bool need_embeddings, DataType value_dtype, DataType embedding_dtype) : offset_(0), offsets_(static_cast(FusedEmbeddingBufferType::kMaxType), -1), sizes_(static_cast(FusedEmbeddingBufferType::kMaxType)), ptr_(ptr) { // id shuffle const int64_t num_table_ids = need_process_table_ids ? num_ids : 0; const size_t table_ids_bytes = need_process_table_ids ? num_ids * sizeof(U) : 0; AllocBuffer(FusedEmbeddingBufferType::kNumPartitionedUnique, parallel_num * sizeof(IDX)); size_t partitioned_ids_bytes = parallel_num * num_ids * sizeof(K); AllocBuffer(FusedEmbeddingBufferType::kPartitionedUniqueIds, partitioned_ids_bytes); AllocBuffer(FusedEmbeddingBufferType::kReceivedIds, partitioned_ids_bytes); AllocBuffer(FusedEmbeddingBufferType::kTableIds, table_ids_bytes); size_t partitioned_table_ids_bytes = parallel_num * num_table_ids * sizeof(U); AllocBuffer(FusedEmbeddingBufferType::kPartitionedUniqueTableIds, partitioned_table_ids_bytes); AllocBuffer(FusedEmbeddingBufferType::kReceivedTableIds, partitioned_table_ids_bytes); const size_t hash_table_capacity = parallel_num * num_ids; AllocBuffer(FusedEmbeddingBufferType::kWorkspace, hash_table_capacity * sizeof(data_shuffle::TableEntry)); size_t num_unique_matrix_bytes = parallel_num * parallel_num * sizeof(IDX); AllocBuffer(FusedEmbeddingBufferType::kNumUniqueMatrix, num_unique_matrix_bytes); size_t inverse_unique_partition_indices_bytes = num_ids * sizeof(IDX); AllocBuffer(FusedEmbeddingBufferType::kInverseUniquePartitionIndices, inverse_unique_partition_indices_bytes); size_t cur_rank_num_ids = parallel_num * num_ids; size_t cur_rank_num_table_ids = cur_rank_num_ids; size_t cur_rank_num_unique_bytes = sizeof(uint32_t); AllocBuffer(FusedEmbeddingBufferType::kCurRankNumUnique, cur_rank_num_unique_bytes); size_t cur_rank_unique_ids_bytes = cur_rank_num_ids * sizeof(K); AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueIds, cur_rank_unique_ids_bytes); size_t cur_rank_unique_table_ids_bytes = cur_rank_num_table_ids * sizeof(U); AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueTableIds, cur_rank_unique_table_ids_bytes); size_t cur_rank_inverse_indices_bytes = cur_rank_num_ids * sizeof(IDX); AllocBuffer(FusedEmbeddingBufferType::kCurRankInverseIndices, cur_rank_inverse_indices_bytes); // embedding lookup size_t num_missing_bytes = sizeof(uint32_t); AllocBuffer(FusedEmbeddingBufferType::kNumMissing, num_missing_bytes); size_t missing_indices_bytes = cur_rank_num_ids * sizeof(uint32_t); AllocBuffer(FusedEmbeddingBufferType::kMissingIndices, missing_indices_bytes); if (need_unique_values) { size_t cur_rank_unique_values_bytes = cur_rank_num_ids * line_size * GetSizeOfDataType(value_dtype); AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueValues, cur_rank_unique_values_bytes); } if (need_embeddings) { size_t cur_rank_unique_embeddings_bytes = cur_rank_num_ids * embedding_size * GetSizeOfDataType(embedding_dtype); AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueEmbeddings, cur_rank_unique_embeddings_bytes); } // embedding shuffle size_t reverse_unique_cur_rank_embeddings_bytes = cur_rank_num_ids * embedding_size * GetSizeOfDataType(embedding_dtype); AllocBuffer(FusedEmbeddingBufferType::kReverseUniqueCurRankEmbeddings, reverse_unique_cur_rank_embeddings_bytes); size_t received_embeddings_bytes = 
cur_rank_num_ids * embedding_size * GetSizeOfDataType(embedding_dtype); AllocBuffer(FusedEmbeddingBufferType::kReceivedEmbeddings, received_embeddings_bytes); } template T* Ptr(FusedEmbeddingBufferType type) const { CHECK(ptr_ != nullptr); int64_t offset = offsets_.at(static_cast(type)); CHECK_NE(offset, -1); return reinterpret_cast(reinterpret_cast(ptr_) + offset); } int64_t Size(FusedEmbeddingBufferType type) const { return sizes_.at(static_cast(type)); } size_t TotalBufferSize() const { return offset_; } private: void AllocBuffer(FusedEmbeddingBufferType type, size_t size) { const size_t type_id = static_cast(type); CHECK_EQ(offsets_.at(type_id), -1); offsets_.at(type_id) = offset_; sizes_.at(type_id) = size; offset_ += GetCudaAlignedSize(size); } size_t offset_; std::vector offsets_; std::vector sizes_; void* ptr_; }; void MakeConstantInitializerAttr(const int64_t embedding_size, const int64_t line_size, const std::vector& values, std::string* initializer_attr) { if (embedding_size == line_size) { return; } const int32_t num_states = line_size / embedding_size - 1; CHECK_GT(num_states, 0) << "num_states " << num_states; CHECK(values.size() == 0 || num_states == values.size()) << "must set " << num_states << " optimizer states init value, but get " << values.size(); nlohmann::json initializers; for (int32_t i = 0; i < num_states; ++i) { nlohmann::json initializer; initializer["type"] = "constant"; const float initial_value = values.size() > 0 ? values.at(i) : 0.0; initializer["value"] = initial_value; initializers.push_back(initializer); } *initializer_attr = initializers.dump(); } template class OneEmbeddingFusedLookupKernelState final : public user_op::OpKernelState { public: explicit OneEmbeddingFusedLookupKernelState(user_op::KernelInitContext* ctx) : device_index_(-1), stream_name_(EagerNcclCommMgr::kDefaultStreamName), parallel_desc_(ctx->parallel_desc()) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX))); OF_CUDA_CHECK( cudaMallocHost(&host_num_unique_matrix_, parallel_num * parallel_num * sizeof(IDX))); const std::string& embedding_name = ctx->Attr("embedding_name"); key_value_store_ = Singleton::Get()->GetKeyValueStore( embedding_name, parallel_id); uint32_t max_query_length = ctx->TensorDesc4ArgNameAndIndex("ids", 0)->shape().elem_cnt() * parallel_num; key_value_store_->ReserveQueryLength(max_query_length); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); // Note(guoran): This op have no optimizer info, so set embedding states initializer constant // 0, which may make error in optimizer with initial_accumulator_value like adagrad and ftrl. 
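// Note: a rough sketch of the JSON that MakeConstantInitializerAttr (defined
// above) emits for the call below; e.g. with line_size == 3 * embedding_size
// and no explicit values, state_initializer becomes roughly
//   [{"type":"constant","value":0.0},{"type":"constant","value":0.0}]
// i.e. one constant-0 initializer per optimizer state slot.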
std::string state_initializer; MakeConstantInitializerAttr(embedding_size, line_size, {}, &state_initializer); std::vector initializer_param; std::vector initializer_index; ParseInitializers(line_size, embedding_size, state_initializer, ctx->Attr("embedding_tables"), &initializer_param, &initializer_index); const size_t param_size_bytes = initializer_param.size() * sizeof(EmbeddingInitializer); OF_CUDA_CHECK(cudaMallocHost(&host_initializer_param_, param_size_bytes)); std::memcpy(host_initializer_param_, initializer_param.data(), param_size_bytes); OF_CUDA_CHECK(cudaMalloc(&device_initializer_param_, param_size_bytes)); OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_param_, host_initializer_param_, param_size_bytes, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); const size_t index_size_bytes = initializer_index.size() * sizeof(int8_t); OF_CUDA_CHECK(cudaMallocHost(&host_initializer_index_, index_size_bytes)); std::memcpy(host_initializer_index_, initializer_index.data(), index_size_bytes); OF_CUDA_CHECK(cudaMalloc(&device_initializer_index_, index_size_bytes)); OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_index_, host_initializer_index_, index_size_bytes, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); } ~OneEmbeddingFusedLookupKernelState() override { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFreeHost(host_num_keys_)); OF_CUDA_CHECK(cudaFreeHost(host_num_unique_matrix_)); OF_CUDA_CHECK(cudaFreeHost(host_initializer_param_)); OF_CUDA_CHECK(cudaFree(device_initializer_param_)); OF_CUDA_CHECK(cudaFreeHost(host_initializer_index_)); OF_CUDA_CHECK(cudaFree(device_initializer_index_)); } ncclComm_t comm() { return GetOrCreate().comm; } IDX* HostNumUniqueMatrix() { return host_num_unique_matrix_; } IDX* HostNumKeys() { return host_num_keys_; } embedding::KeyValueStore* KeyValueStore() { return key_value_store_; } const int8_t* InitializerIndex() { return device_initializer_index_; } const EmbeddingInitializer* Initializers() { return device_initializer_param_; } private: struct Comm { Comm(ncclComm_t comm) : comm(comm) {} ncclComm_t comm; }; const Comm& GetOrCreate() { if (!comm_) { Init(); } return *comm_; } void Init() { std::set> device_set; for (int64_t parallel_id = 0; parallel_id < parallel_desc_.parallel_num(); ++parallel_id) { int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ncclComm_t comm; comm = comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); comm_.reset(new Comm(comm)); } int device_index_; std::string stream_name_; ParallelDesc parallel_desc_; std::unique_ptr comm_; IDX* host_num_keys_; IDX* host_num_unique_matrix_; embedding::KeyValueStore* key_value_store_; EmbeddingInitializer* host_initializer_param_; EmbeddingInitializer* device_initializer_param_; int8_t* host_initializer_index_; int8_t* device_initializer_index_; }; template void LookupAndInitMissing(ep::Stream* stream, OneEmbeddingFusedLookupKernelState* kernel_state, uint64_t seed, uint32_t num_unique, const int64_t embedding_size, const int64_t line_size, const bool put_to_store, const void* unique_ids, const void* table_ids, void* num_missing_ptr, void* missing_indices, void* store_values) { embedding::KeyValueStore* store = kernel_state->KeyValueStore(); const EmbeddingInitializer* initializer_param = 
kernel_state->Initializers(); const int8_t* initializer_index = kernel_state->InitializerIndex(); void* host_num_keys = kernel_state->HostNumKeys(); LookupAndInitMissing(stream, seed, store, initializer_param, initializer_index, host_num_keys, num_unique, embedding_size, line_size, put_to_store, unique_ids, table_ids, num_missing_ptr, missing_indices, store_values); } template void SetIdShuffleDataPtrsParam(const void* ids_ptr, const FusedEmbeddingTmpBufferManager& buffer_manager, data_shuffle::IdShuffleDataPtrs* data_ptrs) { data_ptrs->ids_ptr = reinterpret_cast(ids_ptr); data_ptrs->table_ids_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kTableIds); data_ptrs->num_partitioned_unique = buffer_manager.template Ptr(FusedEmbeddingBufferType::kNumPartitionedUnique); data_ptrs->partitioned_unique_ids = buffer_manager.template Ptr(FusedEmbeddingBufferType::kPartitionedUniqueIds); data_ptrs->partitioned_unique_table_ids = buffer_manager.template Ptr(FusedEmbeddingBufferType::kPartitionedUniqueTableIds); data_ptrs->workspace_ptr = buffer_manager.Ptr(FusedEmbeddingBufferType::kWorkspace); data_ptrs->workspace_size = buffer_manager.Size(FusedEmbeddingBufferType::kWorkspace); data_ptrs->received_ids = buffer_manager.template Ptr(FusedEmbeddingBufferType::kReceivedIds); data_ptrs->received_table_ids = buffer_manager.template Ptr(FusedEmbeddingBufferType::kReceivedTableIds); data_ptrs->inverse_unique_partition_indices_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kInverseUniquePartitionIndices); data_ptrs->num_unique_matrix_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kNumUniqueMatrix); data_ptrs->cur_rank_num_unique_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankNumUnique); data_ptrs->cur_rank_unique_ids_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankUniqueIds); data_ptrs->cur_rank_unique_table_ids_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankUniqueTableIds); data_ptrs->cur_rank_inverse_indices_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankInverseIndices); } template class OneEmbeddingFusedLookupKernel final : public user_op::OpKernel { public: OneEmbeddingFusedLookupKernel() = default; ~OneEmbeddingFusedLookupKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { // IDX type is uint32_t, table_ids type is uint8_t. DataType num_unique_matrix_dtype = DataType::kUInt32; DataType table_ids_dtype = DataType::kUInt8; CHECK_EQ(sizeof(IDX), GetSizeOfDataType(num_unique_matrix_dtype)); CHECK_EQ(sizeof(U), GetSizeOfDataType(table_ids_dtype)); auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); const user_op::Tensor* ids = ctx->Tensor4ArgNameAndIndex("ids", 0); user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); const int32_t num_tables = ctx->Attr("num_tables"); // default uint8_t as table_ids type, so num_tables can not greater than 256. 
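// Note: the remainder of this Compute roughly follows
//   id shuffle (unique + partition across ranks, NCCL exchange)
//   -> LookupAndInitMissing on this rank's unique ids
//   -> optional cast/slice of values into embeddings
//   -> gather by inverse indices, ShuffleEmbeddings back, final gather
// as sketched from the calls below.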
CHECK_LE(num_tables, 256) << num_tables; const bool has_table_ids = ctx->has_input("table_ids", 0); const bool need_process_table_ids = (has_table_ids || num_tables > 1); const int64_t num_ids = ids->shape_view().elem_cnt(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); DataType value_dtype = ctx->Attr("dtype"); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); const int64_t padding_idx = ctx->Attr("padding_idx"); const bool has_padding_idx = ctx->Attr("has_padding_idx"); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); bool need_unique_values = true; bool need_embeddings = (line_size != embedding_size) || (value_dtype != embeddings->data_type()); FusedEmbeddingTmpBufferManager buffer_manager( tmp_buffer->mut_dptr(), num_ids, parallel_num, need_process_table_ids, line_size, embedding_size, need_unique_values, need_embeddings, value_dtype, embeddings->data_type()); CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.TotalBufferSize()); ncclComm_t comm = kernel_state->comm(); IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix(); IDX* host_num_keys = kernel_state->HostNumKeys(); data_shuffle::IdShuffleDataPtrs data_ptrs; SetIdShuffleDataPtrsParam(ids->dptr(), buffer_manager, &data_ptrs); // overwrite data_ptrs.table_ids_ptr if (need_process_table_ids) { U* tmp_table_ids_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kTableIds); data_ptrs.table_ids_ptr = tmp_table_ids_ptr; if (has_table_ids) { // use table_id default data_type uint8, if has input table_ids with different data_type, // cast it to uint8. const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex("table_ids", 0); if (table_ids->data_type() != table_ids_dtype) { std::unique_ptr cast_primitive = ep::primitive::NewPrimitive( DeviceType::kCUDA, table_ids->data_type(), table_ids_dtype); cast_primitive->Launch(ctx->stream(), table_ids->dptr(), tmp_table_ids_ptr, table_ids->shape_view().elem_cnt()); } else { data_ptrs.table_ids_ptr = reinterpret_cast(table_ids->dptr()); } } else { const int32_t num_tables = ctx->Attr("num_tables"); data_shuffle::GenerateTableIds<<>>(num_ids, num_tables, tmp_table_ids_ptr); } } else { data_ptrs.table_ids_ptr = nullptr; } data_shuffle::IdShuffle(ctx->stream(), comm, data_ptrs, num_ids, parallel_id, parallel_num, num_unique_matrix_dtype, ids->data_type(), table_ids_dtype, need_process_table_ids, has_padding_idx, padding_idx, host_num_unique_matrix, host_num_keys); uint32_t num_unique = *host_num_keys; // lookup and put, if is_full_cache, not put to store. uint32_t* num_missing_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kNumMissing); uint32_t* missing_indices_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kMissingIndices); void* values_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankUniqueValues); T* cur_rank_embeddings_ptr = need_embeddings ? 
buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankUniqueEmbeddings) : reinterpret_cast(values_ptr); const bool is_full_cache = ctx->Attr("is_full_cache"); const bool put_to_store = (!is_full_cache); const int64_t seed = ctx->Attr("seed"); LookupAndInitMissing( ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, put_to_store, data_ptrs.cur_rank_unique_ids_ptr, data_ptrs.cur_rank_unique_table_ids_ptr, num_missing_ptr, missing_indices_ptr, values_ptr); if (need_embeddings) { CopyValuesToEmbeddings(ctx->stream(), num_unique, embedding_size, line_size, value_dtype, embeddings->data_type(), reinterpret_cast(values_ptr), cur_rank_embeddings_ptr); } // embedding shuffle int64_t cur_rank_num_ids = 0; for (int64_t i = 0; i < parallel_num; ++i) { cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id]; } int64_t unique_partitioned_num_ids = 0; for (int64_t i = 0; i < parallel_num; ++i) { unique_partitioned_num_ids += host_num_unique_matrix[parallel_id * parallel_num + i]; } T* reverse_unique_cur_rank_embeddings_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kReverseUniqueCurRankEmbeddings); T* received_embeddings_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kReceivedEmbeddings); GatherKernelUtilImpl::Forward( ctx->stream(), data_ptrs.cur_rank_inverse_indices_ptr, cur_rank_num_ids, cur_rank_embeddings_ptr, Shape({1, num_unique, embedding_size}), reverse_unique_cur_rank_embeddings_ptr, 0); data_shuffle::ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, embeddings->data_type(), host_num_unique_matrix, reverse_unique_cur_rank_embeddings_ptr, received_embeddings_ptr); GatherKernelUtilImpl::Forward( ctx->stream(), data_ptrs.inverse_unique_partition_indices_ptr, num_ids, received_embeddings_ptr, Shape({1, unique_partitioned_num_ids, embedding_size}), embeddings->mut_dptr(), 0); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto SingleDeviceKernel() { return hob::make_custom("SingleDeviceKernel", [](const user_op::KernelRegContext& ctx) { return (ctx.parallel_ctx().parallel_num() == 1); }); } // Note(guoran): Default use U type as uint8_t, IDX as uint32_t. Because table_ids is optional, so // can not use it in hob, if has table_ids input and dtype is not uint8_t cast to uint8_t in kernel. 
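// Note: "one_embedding_fused_lookup" appears to be registered twice: the macro
// below matches the multi-device case (!SingleDeviceKernel(), i.e.
// parallel_num > 1), while a separate local variant further down matches
// SingleDeviceKernel(), so the hob selects the kernel by parallel_num.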
#define REGISTER_CUDA_ONE_EMBEDDING_FUSED_LOOKUP_KERNEL(k_dtype_pair, t_dtype_pair, v_dtype_pair) \ REGISTER_USER_KERNEL("one_embedding_fused_lookup") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("ids", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \ && (user_op::HobDataType("embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)) \ && (user_op::HobAttr("dtype") == OF_PP_PAIR_SECOND(v_dtype_pair)) \ && !SingleDeviceKernel()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& ids = ctx->InputTensorDesc("ids", 0); \ const user_op::TensorDesc& embeddings = ctx->OutputTensorDesc("embeddings", 0); \ const bool has_table_ids = ctx->has_input("table_ids", 0); \ const int32_t num_tables = ctx->Attr("num_tables"); \ const bool need_process_table_ids = (has_table_ids || num_tables > 1); \ DataType value_dtype = ctx->Attr("dtype"); \ const int64_t embedding_size = ctx->Attr("embedding_size"); \ const int64_t line_size = ctx->Attr("line_size"); \ bool need_embeddings = \ (line_size != embedding_size) || (value_dtype != embeddings.data_type()); \ FusedEmbeddingTmpBufferManager \ buffer_manager(nullptr, ids.shape().elem_cnt(), ctx->parallel_ctx().parallel_num(), \ need_process_table_ids, line_size, embedding_size, true, \ need_embeddings, value_dtype, embeddings.data_type()); \ return buffer_manager.TotalBufferSize(); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_FUSED_LOOKUP_KERNEL, ID_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, EMBEDDING_DATA_TYPE_SEQ) template class OneEmbeddingFusedLookupLocalKernelState final : public user_op::OpKernelState { public: explicit OneEmbeddingFusedLookupLocalKernelState(user_op::KernelInitContext* ctx) : device_index_(-1) { OF_CUDA_CHECK(cudaGetDevice(&device_index_)); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX))); const std::string& embedding_name = ctx->Attr("embedding_name"); key_value_store_ = Singleton::Get()->GetKeyValueStore( embedding_name, parallel_id); uint32_t max_query_length = ctx->TensorDesc4ArgNameAndIndex("ids", 0)->shape().elem_cnt(); key_value_store_->ReserveQueryLength(max_query_length); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); // Note(guoran): This op have no optimizer info, so set embedding states initializer constant // 0, which may make error in optimizer with initial_accumulator_value like adagrad and ftrl. 
std::string state_initializer; MakeConstantInitializerAttr(embedding_size, line_size, {}, &state_initializer); std::vector initializer_param; std::vector initializer_index; ParseInitializers(line_size, embedding_size, state_initializer, ctx->Attr("embedding_tables"), &initializer_param, &initializer_index); const size_t param_size_bytes = initializer_param.size() * sizeof(EmbeddingInitializer); OF_CUDA_CHECK(cudaMallocHost(&host_initializer_param_, param_size_bytes)); std::memcpy(host_initializer_param_, initializer_param.data(), param_size_bytes); OF_CUDA_CHECK(cudaMalloc(&device_initializer_param_, param_size_bytes)); OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_param_, host_initializer_param_, param_size_bytes, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); const size_t index_size_bytes = initializer_index.size() * sizeof(int8_t); OF_CUDA_CHECK(cudaMallocHost(&host_initializer_index_, index_size_bytes)); std::memcpy(host_initializer_index_, initializer_index.data(), index_size_bytes); OF_CUDA_CHECK(cudaMalloc(&device_initializer_index_, index_size_bytes)); OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_index_, host_initializer_index_, index_size_bytes, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); } ~OneEmbeddingFusedLookupLocalKernelState() override { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaFreeHost(host_num_keys_)); OF_CUDA_CHECK(cudaFreeHost(host_initializer_param_)); OF_CUDA_CHECK(cudaFree(device_initializer_param_)); OF_CUDA_CHECK(cudaFreeHost(host_initializer_index_)); OF_CUDA_CHECK(cudaFree(device_initializer_index_)); } IDX* HostNumKeys() { return host_num_keys_; } embedding::KeyValueStore* KeyValueStore() { return key_value_store_; } const int8_t* InitializerIndex() { return device_initializer_index_; } const EmbeddingInitializer* Initializers() { return device_initializer_param_; } private: int device_index_; IDX* host_num_keys_; embedding::KeyValueStore* key_value_store_; EmbeddingInitializer* host_initializer_param_; EmbeddingInitializer* device_initializer_param_; int8_t* host_initializer_index_; int8_t* device_initializer_index_; }; template void LookupAndInitMissing(ep::Stream* stream, OneEmbeddingFusedLookupLocalKernelState* kernel_state, uint64_t seed, uint32_t num_unique, const int64_t embedding_size, const int64_t line_size, const bool put_to_store, const void* unique_ids, const void* table_ids, void* num_missing_ptr, void* missing_indices, void* store_values) { embedding::KeyValueStore* store = kernel_state->KeyValueStore(); const EmbeddingInitializer* initializer_param = kernel_state->Initializers(); const int8_t* initializer_index = kernel_state->InitializerIndex(); void* host_num_keys = kernel_state->HostNumKeys(); LookupAndInitMissing(stream, seed, store, initializer_param, initializer_index, host_num_keys, num_unique, embedding_size, line_size, put_to_store, unique_ids, table_ids, num_missing_ptr, missing_indices, store_values); } template class FusedLocalEmbeddingTmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(FusedLocalEmbeddingTmpBufferManager); FusedLocalEmbeddingTmpBufferManager(void* ptr, const int64_t num_ids, bool need_process_table_ids, int64_t line_size, int64_t embedding_size, bool need_embeddings, DataType value_dtype, DataType embedding_dtype) : offset_(0), offsets_(static_cast(FusedEmbeddingBufferType::kMaxType), -1), sizes_(static_cast(FusedEmbeddingBufferType::kMaxType)), ptr_(ptr) { // id shuffle const size_t table_ids_bytes = need_process_table_ids ? 
num_ids * sizeof(U) : 0; AllocBuffer(FusedEmbeddingBufferType::kTableIds, table_ids_bytes); const size_t hash_table_capacity = num_ids; AllocBuffer(FusedEmbeddingBufferType::kWorkspace, hash_table_capacity * sizeof(data_shuffle::TableEntry)); size_t cur_rank_num_ids = num_ids; size_t cur_rank_num_table_ids = cur_rank_num_ids; size_t cur_rank_num_unique_bytes = sizeof(uint32_t); AllocBuffer(FusedEmbeddingBufferType::kCurRankNumUnique, cur_rank_num_unique_bytes); size_t cur_rank_unique_ids_bytes = cur_rank_num_ids * sizeof(K); AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueIds, cur_rank_unique_ids_bytes); size_t cur_rank_unique_table_ids_bytes = cur_rank_num_table_ids * sizeof(U); AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueTableIds, cur_rank_unique_table_ids_bytes); size_t cur_rank_inverse_indices_bytes = cur_rank_num_ids * sizeof(IDX); AllocBuffer(FusedEmbeddingBufferType::kCurRankInverseIndices, cur_rank_inverse_indices_bytes); // embedding lookup size_t num_missing_bytes = sizeof(uint32_t); AllocBuffer(FusedEmbeddingBufferType::kNumMissing, num_missing_bytes); size_t missing_indices_bytes = cur_rank_num_ids * sizeof(uint32_t); AllocBuffer(FusedEmbeddingBufferType::kMissingIndices, missing_indices_bytes); size_t cur_rank_unique_values_bytes = cur_rank_num_ids * line_size * GetSizeOfDataType(value_dtype); AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueValues, cur_rank_unique_values_bytes); if (need_embeddings) { size_t cur_rank_unique_embeddings_bytes = cur_rank_num_ids * embedding_size * GetSizeOfDataType(embedding_dtype); AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueEmbeddings, cur_rank_unique_embeddings_bytes); } } template T* Ptr(FusedEmbeddingBufferType type) const { CHECK(ptr_ != nullptr); int64_t offset = offsets_.at(static_cast(type)); CHECK_NE(offset, -1); return reinterpret_cast(reinterpret_cast(ptr_) + offset); } int64_t Size(FusedEmbeddingBufferType type) const { return sizes_.at(static_cast(type)); } size_t TotalBufferSize() const { return offset_; } private: void AllocBuffer(FusedEmbeddingBufferType type, size_t size) { const size_t type_id = static_cast(type); CHECK_EQ(offsets_.at(type_id), -1); offsets_.at(type_id) = offset_; sizes_.at(type_id) = size; offset_ += GetCudaAlignedSize(size); } size_t offset_; std::vector offsets_; std::vector sizes_; void* ptr_; }; template class OneEmbeddingFusedLookupLocalKernel final : public user_op::OpKernel { public: OneEmbeddingFusedLookupLocalKernel() = default; ~OneEmbeddingFusedLookupLocalKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { // IDX type is uint32_t, table_ids type is uint8_t. DataType num_unique_matrix_dtype = DataType::kUInt32; DataType table_ids_dtype = DataType::kUInt8; CHECK_EQ(sizeof(IDX), GetSizeOfDataType(num_unique_matrix_dtype)); CHECK_EQ(sizeof(U), GetSizeOfDataType(table_ids_dtype)); auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); const user_op::Tensor* ids = ctx->Tensor4ArgNameAndIndex("ids", 0); user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); const int32_t num_tables = ctx->Attr("num_tables"); // default uint8_t as table_ids type, so num_tables can not greater than 256. 
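// Note: compared with the multi-device kernel above, this local path appears
// to skip the NCCL id/embedding shuffles entirely: it runs UniqueAndPartition
// with a single partition, looks up/initializes the unique rows, and gathers
// them directly into the output embeddings.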
CHECK_LE(num_tables, 256) << num_tables; const bool has_table_ids = ctx->has_input("table_ids", 0); const bool need_process_table_ids = (has_table_ids || num_tables > 1); const int64_t num_ids = ids->shape_view().elem_cnt(); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); DataType value_dtype = ctx->Attr("dtype"); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); const int64_t padding_idx = ctx->Attr("padding_idx"); const bool has_padding_idx = ctx->Attr("has_padding_idx"); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); bool need_embeddings = (line_size != embedding_size) || (value_dtype != embeddings->data_type()); FusedLocalEmbeddingTmpBufferManager buffer_manager( tmp_buffer->mut_dptr(), num_ids, need_process_table_ids, line_size, embedding_size, need_embeddings, value_dtype, embeddings->data_type()); CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.TotalBufferSize()); IDX* host_num_keys = kernel_state->HostNumKeys(); const U* table_ids_ptr = nullptr; if (need_process_table_ids) { U* tmp_table_ids_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kTableIds); table_ids_ptr = tmp_table_ids_ptr; if (has_table_ids) { // use table_id default data_type uint8, if has input table_ids with different data_type, // cast it to uint8. const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex("table_ids", 0); if (table_ids->data_type() != table_ids_dtype) { std::unique_ptr cast_primitive = ep::primitive::NewPrimitive( DeviceType::kCUDA, table_ids->data_type(), table_ids_dtype); cast_primitive->Launch(ctx->stream(), table_ids->dptr(), tmp_table_ids_ptr, table_ids->shape_view().elem_cnt()); } else { table_ids_ptr = reinterpret_cast(table_ids->dptr()); } } else { const int32_t num_tables = ctx->Attr("num_tables"); data_shuffle::GenerateTableIds<<>>(num_ids, num_tables, tmp_table_ids_ptr); } } IDX* num_unique_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankNumUnique); K* unique_ids_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankUniqueIds); U* unique_table_ids_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankUniqueTableIds); IDX* inverse_indices_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankInverseIndices); void* workspace_ptr = buffer_manager.Ptr(FusedEmbeddingBufferType::kWorkspace); const size_t workspace_bytes = buffer_manager.Size(FusedEmbeddingBufferType::kWorkspace); int64_t hash_capacity = num_ids; data_shuffle::UniqueAndPartition( cuda_stream, num_ids, hash_capacity, 1, reinterpret_cast(ids->dptr()), table_ids_ptr, num_unique_ptr, unique_ids_ptr, unique_table_ids_ptr, inverse_indices_ptr, reinterpret_cast*>(workspace_ptr), workspace_bytes, need_process_table_ids, has_padding_idx, padding_idx); OF_CUDA_CHECK(cudaMemcpyAsync(host_num_keys, num_unique_ptr, sizeof(IDX), cudaMemcpyDefault, cuda_stream)); CHECK_JUST(ctx->stream()->Sync()); uint32_t num_unique = *host_num_keys; // lookup and put, if is_full_cache, not put to store. uint32_t* num_missing_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kNumMissing); uint32_t* missing_indices_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kMissingIndices); void* values_ptr = buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankUniqueValues); T* cur_rank_embeddings_ptr = need_embeddings ? 
buffer_manager.template Ptr(FusedEmbeddingBufferType::kCurRankUniqueEmbeddings) : reinterpret_cast(values_ptr); const bool is_full_cache = ctx->Attr("is_full_cache"); const bool put_to_store = (!is_full_cache); const int64_t seed = ctx->Attr("seed"); LookupAndInitMissing( ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, put_to_store, unique_ids_ptr, unique_table_ids_ptr, num_missing_ptr, missing_indices_ptr, values_ptr); if (need_embeddings) { CopyValuesToEmbeddings(ctx->stream(), num_unique, embedding_size, line_size, value_dtype, embeddings->data_type(), reinterpret_cast(values_ptr), cur_rank_embeddings_ptr); } // gather GatherKernelUtilImpl::Forward( ctx->stream(), inverse_indices_ptr, num_ids, cur_rank_embeddings_ptr, Shape({1, num_unique, embedding_size}), embeddings->mut_dptr(), 0); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; // Note(guoran): Default use U type as uint8_t, IDX as uint32_t. Because table_ids is optional, so // can not use it in hob, if has table_ids input and dtype is not uint8_t cast to uint8_t in kernel. #define REGISTER_CUDA_ONE_EMBEDDING_FUSED_LOOKUP_LOCAL_KERNEL(k_dtype_pair, t_dtype_pair, \ v_dtype_pair) \ REGISTER_USER_KERNEL("one_embedding_fused_lookup") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("ids", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \ && (user_op::HobDataType("embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)) \ && (user_op::HobAttr("dtype") == OF_PP_PAIR_SECOND(v_dtype_pair)) \ && SingleDeviceKernel()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& ids = ctx->InputTensorDesc("ids", 0); \ const user_op::TensorDesc& embeddings = ctx->OutputTensorDesc("embeddings", 0); \ const bool has_table_ids = ctx->has_input("table_ids", 0); \ const int32_t num_tables = ctx->Attr("num_tables"); \ const bool need_process_table_ids = (has_table_ids || num_tables > 1); \ DataType value_dtype = ctx->Attr("dtype"); \ const int64_t embedding_size = ctx->Attr("embedding_size"); \ const int64_t line_size = ctx->Attr("line_size"); \ bool need_embeddings = \ (line_size != embedding_size) || (value_dtype != embeddings.data_type()); \ FusedLocalEmbeddingTmpBufferManager \ buffer_manager(nullptr, ids.shape().elem_cnt(), need_process_table_ids, line_size, \ embedding_size, need_embeddings, value_dtype, embeddings.data_type()); \ return buffer_manager.TotalBufferSize(); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_FUSED_LOOKUP_LOCAL_KERNEL, ID_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, EMBEDDING_DATA_TYPE_SEQ) class OneEmbeddingFusedLookupGradKernel final : public user_op::OpKernel { public: OneEmbeddingFusedLookupGradKernel() = default; ~OneEmbeddingFusedLookupGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { // do nothing } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("one_embedding_fused_lookup_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/one_embedding_update_kernels.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/user/kernels/model_update_kernel_util.h" #include "oneflow/core/embedding/embedding_manager.h" namespace oneflow { namespace { template __global__ void SGDUpdateKernel(const int64_t embedding_size, T scale, float l1, float l2, float weight_decay, float learning_rate_val, const IDX* num_unique_ids, const float* learning_rate, const T* scale_by_ptr, const T* down_scale_by_ptr, const int64_t* skip_if, const G* model_diff, const T* model, T* updated_model) { if (skip_if != nullptr && *skip_if != 0) { const int64_t n = *num_unique_ids * embedding_size; CUDA_1D_KERNEL_LOOP(i, n) { updated_model[i] = model[i]; } } else { if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } const int64_t n = *num_unique_ids * embedding_size; CUDA_1D_KERNEL_LOOP(i, n) { updated_model[i] = model[i]; SGDUpdateFunctor()(model_diff + i, updated_model + i, scale, l1, l2, weight_decay, learning_rate_val); } } } __device__ void GetMomentumOffset(const int32_t line_size, const int32_t embedding_size, int64_t model_diff_offset, int64_t* model_offset, int64_t* momentum_offset) { const int32_t row = model_diff_offset / embedding_size; const int32_t col = model_diff_offset - row * embedding_size; *model_offset = row * line_size + col; *momentum_offset = *model_offset + embedding_size; } template __global__ void MomentumUpdateKernel(const int64_t line_size, const int64_t embedding_size, T scale, float l1, float l2, float weight_decay, float beta, float dampening, bool nesterov, bool maximize, float learning_rate_val, const IDX* num_unique_ids, const float* learning_rate, const T* scale_by_ptr, const T* down_scale_by_ptr, const int64_t* skip_if, const G* model_diff, const T* unique_values, T* updated_unique_values) { if (skip_if != nullptr && *skip_if != 0) { const int64_t n = *num_unique_ids * line_size; CUDA_1D_KERNEL_LOOP(i, n) { updated_unique_values[i] = unique_values[i]; } } else { if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } const int64_t n = *num_unique_ids * embedding_size; CUDA_1D_KERNEL_LOOP(i, n) { int64_t model_offset; int64_t momentum_offset; GetMomentumOffset(line_size, embedding_size, i, &model_offset, &momentum_offset); updated_unique_values[model_offset] = unique_values[model_offset]; updated_unique_values[momentum_offset] = unique_values[momentum_offset]; MomentumUpdateFunctor()(model_diff + i, updated_unique_values + model_offset, updated_unique_values + momentum_offset, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val); } } } __device__ void GetAdamOffset(const int32_t line_size, const int32_t embedding_size, int64_t model_diff_offset, int64_t* model_offset, int64_t* 
m_offset, int64_t* v_offset) { const int32_t row = model_diff_offset / embedding_size; const int32_t col = model_diff_offset - row * embedding_size; *model_offset = row * line_size + col; *m_offset = *model_offset + embedding_size; *v_offset = *model_offset + 2 * embedding_size; } template __global__ void AdamUpdateKernel(const int32_t line_size, const int32_t embedding_size, T scale, float l1, float l2, float weight_decay, float beta1, float beta2, float epsilon, float learning_rate_val, float bias_correction1_val, float bias_correction2_val, const float* bias_correction1_ptr, const float* bias_correction2_ptr, const IDX* num_unique_ids, const float* learning_rate, const T* scale_by_ptr, const T* down_scale_by_ptr, const int64_t* skip_if, const G* model_diff, const T* unique_values, T* updated_unique_values) { if (skip_if != nullptr && *skip_if != 0) { const int64_t n = *num_unique_ids * line_size; CUDA_1D_KERNEL_LOOP(i, n) { // The n is the unique_values elem_cnt, so not need to use GetAdamOffset. updated_unique_values[i] = unique_values[i]; } } else { if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; } if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; } if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } const int64_t n = *num_unique_ids * embedding_size; // The n is model_diff elem_cnt. CUDA_1D_KERNEL_LOOP(i, n) { int64_t model_offset; int64_t m_offset; int64_t v_offset; GetAdamOffset(line_size, embedding_size, i, &model_offset, &m_offset, &v_offset); updated_unique_values[model_offset] = unique_values[model_offset]; updated_unique_values[m_offset] = unique_values[m_offset]; updated_unique_values[v_offset] = unique_values[v_offset]; AdamUpdateFunctor()(model_diff + i, updated_unique_values + model_offset, updated_unique_values + m_offset, updated_unique_values + v_offset, nullptr, scale, l1, l2, beta1, beta2, epsilon, weight_decay, false, bias_correction1_val, bias_correction2_val, learning_rate_val); } } } // Note(guoran): The SmartDecaySparseAdam is from // https://github.com/pytorch/pytorch/blob/master/caffe2/sgd/adam_op.h#L57 template __global__ void SmartDecaySparseAdamUpdateKernel( const int32_t line_size, const int32_t embedding_size, T scale, float l1, float l2, float weight_decay, float beta1, float beta2, float epsilon, float learning_rate_val, int64_t step_col_offset, const IDX* num_unique_ids, const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr, const T* down_scale_by_ptr, const int64_t* skip_if, const G* model_diff, const T* unique_values, T* updated_unique_values) { if (skip_if != nullptr && *skip_if != 0) { const int64_t n = *num_unique_ids * line_size; CUDA_1D_KERNEL_LOOP(i, n) { // The n is the unique_values elem_cnt, so not need to use GetAdamOffset. updated_unique_values[i] = unique_values[i]; } } else { if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } const int64_t n = *num_unique_ids * embedding_size; // The n is model_diff elem_cnt. 
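// Note: each unique row is laid out as [model | m | v | ... | step], where `step` (an int64_t
// stored in-place at step_col_offset) records the iteration at which this row was last touched.
// When a row reappears after skip_step iterations, m and v are decayed by beta1^skip_step and
// beta2^skip_step, and the model update adds a catch-up term
// m * beta1 * (1 - beta1^(skip_step - 1)) / (1 - beta1), i.e.
// m * (beta1 + beta1^2 + ... + beta1^(skip_step - 1)), the geometric sum of the momentum that the
// skipped (zero-gradient) steps would roughly have contributed. With skip_step == 1 the loop
// below reduces to a plain Adam step without bias correction.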
CUDA_1D_KERNEL_LOOP(i, n) { const int32_t row = i / embedding_size; const int32_t col = i - row * embedding_size; int64_t model_offset = row * line_size + col; int64_t m_offset = model_offset + embedding_size; int64_t v_offset = model_offset + 2 * embedding_size; int64_t step_offset = row * line_size + step_col_offset; const T model_val = *(unique_values + model_offset); const T m_val = *(unique_values + m_offset); const T v_val = *(unique_values + v_offset); T model_diff_t = CastScaleRegularizeGradientFunctor()(*(model_diff + i), model_val, scale, l1, l2); int64_t prev_step = *reinterpret_cast(unique_values + step_offset); int64_t cur_step = *train_step_ptr + 1; int64_t skip_step = cur_step - prev_step; float catchup = 0.0; if (skip_step > 1) { catchup = m_val * beta1 * (1 - pow(beta1, skip_step - 1)) / (1 - beta1); } const T next_m = pow(beta1, skip_step) * m_val + (1 - beta1) * model_diff_t; const T next_v = pow(beta2, skip_step) * v_val + (1 - beta2) * model_diff_t * model_diff_t; updated_unique_values[m_offset] = next_m; updated_unique_values[v_offset] = next_v; updated_unique_values[model_offset] = model_val - (learning_rate_val * (next_m + catchup)) / (sqrt(next_v) + epsilon); if (col == 0) { *reinterpret_cast(updated_unique_values + step_offset) = cur_step; } } } } template __global__ void AdagradUpdateKernel(const int64_t line_size, const int64_t embedding_size, T scale, float l1, float l2, float weight_decay, float lr_decay, float epsilon, float learning_rate_val, int64_t train_step, const IDX* num_unique_ids, const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr, const T* down_scale_by_ptr, const int64_t* skip_if, const G* model_diff, const T* unique_values, T* updated_unique_values) { if (skip_if != nullptr && *skip_if != 0) { const int64_t n = *num_unique_ids * line_size; CUDA_1D_KERNEL_LOOP(i, n) { updated_unique_values[i] = unique_values[i]; } } else { if (train_step_ptr != nullptr) { train_step = *train_step_ptr + 1; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } learning_rate_val = learning_rate_val / (1 + (train_step - 1) * lr_decay); const int64_t n = *num_unique_ids * embedding_size; CUDA_1D_KERNEL_LOOP(i, n) { int64_t model_offset; int64_t sum_offset; GetMomentumOffset(line_size, embedding_size, i, &model_offset, &sum_offset); updated_unique_values[model_offset] = unique_values[model_offset]; updated_unique_values[sum_offset] = unique_values[sum_offset]; AdagradUpdateFunctor()(model_diff + i, updated_unique_values + model_offset, updated_unique_values + sum_offset, scale, l1, l2, epsilon, weight_decay, learning_rate_val); } } } __device__ void GetFtrlOffset(const int32_t line_size, const int32_t embedding_size, int64_t model_diff_offset, int64_t* model_offset, int64_t* accumulate_offset, int64_t* z_offset) { const int32_t row = model_diff_offset / embedding_size; const int32_t col = model_diff_offset - row * embedding_size; *model_offset = row * line_size + col; *accumulate_offset = *model_offset + embedding_size; *z_offset = *model_offset + 2 * embedding_size; } template __global__ void FtrlUpdateKernel(const int32_t line_size, const int32_t embedding_size, T scale, float l1, float l2, float weight_decay, float lr_power, float lambda1, float lambda2, float beta, float learning_rate_val, const IDX* num_unique_ids, const float* learning_rate, const T* scale_by_ptr, const T* down_scale_by_ptr, 
const int64_t* skip_if, const G* model_diff, const T* unique_values, T* updated_unique_values) { if (skip_if != nullptr && *skip_if != 0) { const int64_t n = *num_unique_ids * line_size; CUDA_1D_KERNEL_LOOP(i, n) { updated_unique_values[i] = unique_values[i]; } } else { if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; } if (learning_rate != nullptr) { learning_rate_val = *learning_rate; } const int64_t n = *num_unique_ids * embedding_size; CUDA_1D_KERNEL_LOOP(i, n) { int64_t model_offset; int64_t accumulate_offset; int64_t z_offset; GetFtrlOffset(line_size, embedding_size, i, &model_offset, &accumulate_offset, &z_offset); updated_unique_values[model_offset] = unique_values[model_offset]; updated_unique_values[accumulate_offset] = unique_values[accumulate_offset]; updated_unique_values[z_offset] = unique_values[z_offset]; FtrlUpdateFunctor()(model_diff + i, updated_unique_values + model_offset, updated_unique_values + accumulate_offset, updated_unique_values + z_offset, scale, l1, l2, lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val); } } } class EmbeddingUpdateKernelState final : public user_op::OpKernelState { public: explicit EmbeddingUpdateKernelState(user_op::KernelInitContext* ctx) { const std::string& embedding_name = ctx->Attr("embedding_name"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); embedding_state_ = Singleton::Get()->GetEmbeddingState( embedding_name, parallel_id); } ~EmbeddingUpdateKernelState() override = default; embedding::EmbeddingState* EmbeddingState() { return embedding_state_; } private: embedding::EmbeddingState* embedding_state_; }; } // namespace template class SgdEmbeddingUpdateKernel final : public user_op::OpKernel { public: SgdEmbeddingUpdateKernel() = default; ~SgdEmbeddingUpdateKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_); const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2); const int64_t line_size = ctx->Attr("line_size"); const int64_t embedding_size = ctx->Attr("embedding_size"); CHECK_EQ(line_size, embedding_size); const auto scale = ctx->Attr("scale"); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const T* down_scale_by_ptr = nullptr; if (ctx->has_input("down_scale_by_tensor", 0)) { 
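// down_scale_by_tensor is an optional single-element tensor; when present, SGDUpdateKernel
// divides the gradient scale by its value (scale /= *down_scale_by_ptr), mirroring how
// scale_by_tensor multiplies it.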
const user_op::Tensor* down_scale_by_tensor = ctx->Tensor4ArgNameAndIndex("down_scale_by_tensor", 0); CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1); down_scale_by_ptr = down_scale_by_tensor->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } // update kernel const T* unique_embeddings_ptr = reinterpret_cast(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_)); T* updated_unique_embeddings_ptr = reinterpret_cast( embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_)); const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const int64_t embedding_grad_elem_cnt = num_unique * embedding_size; SGDUpdateKernel <<stream()->As()->cuda_stream()>>>( embedding_size, scale, l1, l2, weight_decay, learning_rate_val, reinterpret_cast(num_unique_ids->dptr()), learning_rate_ptr, scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr(), unique_embeddings_ptr, updated_unique_embeddings_ptr); embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define IDX_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define REGISTER_CUDA_ONE_EMBEDDING_SGD_UPDATE_KERNEL(t_dtype_pair, g_type_pair, idx_dtype_pair) \ REGISTER_USER_KERNEL("one_embedding_sgd_update") \ .SetCreateFn< \ SgdEmbeddingUpdateKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && (user_op::HobDataType("embedding_grad", 0) == OF_PP_PAIR_SECOND(g_type_pair)) \ && (user_op::HobDataType("unique_embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_SGD_UPDATE_KERNEL, FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class MomentumEmbeddingUpdateKernel final : public user_op::OpKernel { public: MomentumEmbeddingUpdateKernel() : current_iter_(0){}; ~MomentumEmbeddingUpdateKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_); const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2); const int64_t line_size = ctx->Attr("line_size"); const int64_t embedding_size = ctx->Attr("embedding_size"); CHECK_EQ(line_size, embedding_size * 2); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const auto beta = ctx->Attr("beta"); // TODO: Suppoprt dampening, nesterov, maximize in OneEmbeddingMomentumUpdate(zhengzekang). 
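// Until the TODO above is addressed, the update is plain momentum SGD: dampening is fixed at 0
// and nesterov/maximize are disabled. Each unique row stores [embedding | momentum], hence the
// CHECK_EQ(line_size, embedding_size * 2) above.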
const float dampening = 0.0; const bool nesterov = false; const bool maximize = false; const auto scale = ctx->Attr("scale"); const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const T* down_scale_by_ptr = nullptr; if (ctx->has_input("down_scale_by_tensor", 0)) { const user_op::Tensor* down_scale_by_tensor = ctx->Tensor4ArgNameAndIndex("down_scale_by_tensor", 0); CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1); down_scale_by_ptr = down_scale_by_tensor->dptr(); } const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } // update kernel const T* unique_embeddings_ptr = reinterpret_cast(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_)); T* updated_unique_embeddings_ptr = reinterpret_cast( embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_)); const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const int64_t embedding_grad_elem_cnt = num_unique * embedding_size; MomentumUpdateKernel <<stream()->As()->cuda_stream()>>>( line_size, embedding_size, scale, l1, l2, weight_decay, beta, dampening, nesterov, maximize, learning_rate_val, reinterpret_cast(num_unique_ids->dptr()), learning_rate_ptr, scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr(), unique_embeddings_ptr, updated_unique_embeddings_ptr); embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_ONE_EMBEDDING_MOMENTUM_UPDATE_KERNEL(t_dtype_pair, g_type_pair, \ idx_dtype_pair) \ REGISTER_USER_KERNEL("one_embedding_momentum_update") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && (user_op::HobDataType("embedding_grad", 0) == OF_PP_PAIR_SECOND(g_type_pair)) \ && (user_op::HobDataType("unique_embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_MOMENTUM_UPDATE_KERNEL, FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class AdamEmbeddingUpdateKernel final : public user_op::OpKernel { public: AdamEmbeddingUpdateKernel() : current_iter_(0){}; ~AdamEmbeddingUpdateKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_); const user_op::Tensor* num_unique_ids = 
ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); user_op::Tensor* updated_unique_embeddings = ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0); CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2); const int64_t line_size = ctx->Attr("line_size"); const int64_t embedding_size = ctx->Attr("embedding_size"); CHECK_EQ(line_size, embedding_size * 3); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const auto beta1 = ctx->Attr("beta1"); const auto beta2 = ctx->Attr("beta2"); const auto epsilon = ctx->Attr("epsilon"); const bool do_bias_correction = ctx->Attr("do_bias_correction"); const auto scale = ctx->Attr("scale"); const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const T* down_scale_by_ptr = nullptr; if (ctx->has_input("down_scale_by_tensor", 0)) { const user_op::Tensor* down_scale_by_tensor = ctx->Tensor4ArgNameAndIndex("down_scale_by_tensor", 0); CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1); down_scale_by_ptr = down_scale_by_tensor->dptr(); } const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } const float bias_correction1_val = ctx->Attr("bias_correction1_val"); const float* bias_correction1_ptr = nullptr; if (ctx->has_input("bias_correction1", 0)) { bias_correction1_ptr = ctx->Tensor4ArgNameAndIndex("bias_correction1", 0)->dptr(); } const float bias_correction2_val = ctx->Attr("bias_correction2_val"); const float* bias_correction2_ptr = nullptr; if (ctx->has_input("bias_correction2", 0)) { bias_correction2_ptr = ctx->Tensor4ArgNameAndIndex("bias_correction2", 0)->dptr(); } // update kernel const T* unique_embeddings_ptr = reinterpret_cast(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_)); T* updated_unique_embeddings_ptr = reinterpret_cast( embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_)); const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const int64_t embedding_grad_elem_cnt = num_unique * embedding_size; AdamUpdateKernel <<stream()->As()->cuda_stream()>>>( line_size, embedding_size, static_cast(scale), l1, l2, weight_decay, beta1, beta2, epsilon, learning_rate_val, bias_correction1_val, bias_correction2_val, bias_correction1_ptr, bias_correction2_ptr, reinterpret_cast(num_unique_ids->dptr()), learning_rate_ptr, scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr(), unique_embeddings_ptr, updated_unique_embeddings_ptr); embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define 
REGISTER_CUDA_ONE_EMBEDDING_ADAM_UPDATE_KERNEL(t_dtype_pair, g_type_pair, idx_dtype_pair) \ REGISTER_USER_KERNEL("one_embedding_adam_update") \ .SetCreateFn< \ AdamEmbeddingUpdateKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && (user_op::HobDataType("embedding_grad", 0) == OF_PP_PAIR_SECOND(g_type_pair)) \ && (user_op::HobDataType("unique_embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_ADAM_UPDATE_KERNEL, FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class SmartDecaySparseAdamEmbeddingUpdateKernel final : public user_op::OpKernel { public: SmartDecaySparseAdamEmbeddingUpdateKernel() : current_iter_(0){}; ~SmartDecaySparseAdamEmbeddingUpdateKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_); const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); user_op::Tensor* updated_unique_embeddings = ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0); CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2); const int64_t line_size = ctx->Attr("line_size"); const int64_t embedding_size = ctx->Attr("embedding_size"); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const auto beta1 = ctx->Attr("beta1"); const auto beta2 = ctx->Attr("beta2"); const auto epsilon = ctx->Attr("epsilon"); const bool do_bias_correction = ctx->Attr("do_bias_correction"); const auto scale = ctx->Attr("scale"); const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const T* down_scale_by_ptr = nullptr; if (ctx->has_input("down_scale_by_tensor", 0)) { const user_op::Tensor* down_scale_by_tensor = ctx->Tensor4ArgNameAndIndex("down_scale_by_tensor", 0); CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1); down_scale_by_ptr = down_scale_by_tensor->dptr(); } const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const int64_t train_step_val = ctx->Attr("train_step_val"); const int64_t* train_step_ptr = nullptr; if (ctx->has_input("train_step", 0)) { const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex("train_step", 0); train_step_ptr = train_step->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } // 
update kernel const T* unique_embeddings_ptr = reinterpret_cast(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_)); T* updated_unique_embeddings_ptr = reinterpret_cast( embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_)); const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const int64_t embedding_grad_elem_cnt = num_unique * embedding_size; const int64_t value_dtype_size = GetSizeOfDataType(updated_unique_embeddings->data_type()); const int64_t step_dtype_size = sizeof(int64_t); const int64_t model_and_states_bytes = embedding_size * 3 * value_dtype_size; const int64_t align_to_step_size_bytes = (model_and_states_bytes + step_dtype_size - 1) / step_dtype_size * step_dtype_size; const int64_t step_col_offset = align_to_step_size_bytes / value_dtype_size; const int64_t smart_decay_sparse_adam_line_size = (align_to_step_size_bytes + step_dtype_size) / value_dtype_size; CHECK_EQ(line_size, smart_decay_sparse_adam_line_size); SmartDecaySparseAdamUpdateKernel <<stream()->As()->cuda_stream()>>>( line_size, embedding_size, static_cast(scale), l1, l2, weight_decay, beta1, beta2, epsilon, learning_rate_val, step_col_offset, reinterpret_cast(num_unique_ids->dptr()), learning_rate_ptr, train_step_ptr, scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr(), unique_embeddings_ptr, updated_unique_embeddings_ptr); embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_ONE_EMBEDDING_SMART_DECAY_SPARSE_ADAM_UPDATE_KERNEL( \ t_dtype_pair, g_type_pair, idx_dtype_pair) \ REGISTER_USER_KERNEL("one_embedding_smart_decay_sparse_adam_update") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && (user_op::HobDataType("embedding_grad", 0) == OF_PP_PAIR_SECOND(g_type_pair)) \ && (user_op::HobDataType("unique_embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_SMART_DECAY_SPARSE_ADAM_UPDATE_KERNEL, FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class AdagradEmbeddingUpdateKernel final : public user_op::OpKernel { public: AdagradEmbeddingUpdateKernel() : current_iter_(0){}; ~AdagradEmbeddingUpdateKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_); const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); user_op::Tensor* updated_unique_embeddings = ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0); CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2); const int64_t line_size = ctx->Attr("line_size"); const int64_t embedding_size = ctx->Attr("embedding_size"); 
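// For Adagrad each unique row stores [model | accumulated sum of squared gradients], so
// line_size must be 2 * embedding_size; AdagradUpdateKernel reuses GetMomentumOffset to locate
// the sum slot.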
CHECK_EQ(line_size, embedding_size * 2); const float l1 = ctx->Attr("l1"); const float l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const auto lr_decay = ctx->Attr("lr_decay"); const auto epsilon = ctx->Attr("epsilon"); const auto scale = ctx->Attr("scale"); const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const T* down_scale_by_ptr = nullptr; if (ctx->has_input("down_scale_by_tensor", 0)) { const user_op::Tensor* down_scale_by_tensor = ctx->Tensor4ArgNameAndIndex("down_scale_by_tensor", 0); CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1); down_scale_by_ptr = down_scale_by_tensor->dptr(); } const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const int64_t train_step_val = ctx->Attr("train_step_val"); const int64_t* train_step_ptr = nullptr; if (ctx->has_input("train_step", 0)) { const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex("train_step", 0); train_step_ptr = train_step->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } // update kernel const T* unique_embeddings_ptr = reinterpret_cast(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_)); T* updated_unique_embeddings_ptr = reinterpret_cast( embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_)); const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const int64_t embedding_grad_elem_cnt = num_unique * embedding_size; AdagradUpdateKernel <<stream()->As()->cuda_stream()>>>( line_size, embedding_size, static_cast(scale), l1, l2, weight_decay, lr_decay, epsilon, learning_rate_val, train_step_val, reinterpret_cast(num_unique_ids->dptr()), learning_rate_ptr, train_step_ptr, scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr(), unique_embeddings_ptr, updated_unique_embeddings_ptr); embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_ONE_EMBEDDING_ADAGRAD_UPDATE_KERNEL(t_dtype_pair, g_type_pair, \ idx_dtype_pair) \ REGISTER_USER_KERNEL("one_embedding_adagrad_update") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && (user_op::HobDataType("embedding_grad", 0) == OF_PP_PAIR_SECOND(g_type_pair)) \ && (user_op::HobDataType("unique_embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_ADAGRAD_UPDATE_KERNEL, FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) template class FtrlEmbeddingUpdateKernel final : public user_op::OpKernel { public: FtrlEmbeddingUpdateKernel() : current_iter_(0){}; ~FtrlEmbeddingUpdateKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return 
std::make_shared(ctx); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* kernel_state = dynamic_cast(state); CHECK(kernel_state != nullptr); embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState(); embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_); const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2) << "The NumAxes of embedding_grad should be equal to 2. "; const int64_t line_size = ctx->Attr("line_size"); const int64_t embedding_size = ctx->Attr("embedding_size"); CHECK_EQ(line_size, embedding_size * 3) << "The line_size should be equal to 3 x embedding_size. "; const float l1 = 0.0; const float l2 = 0.0; const float weight_decay = ctx->Attr("weight_decay"); // TODO(zhengzekang): Undefined behavior for ftrl optimizer with weight_decay in `abs(new_z_val) // < lambda1` condition. CHECK_EQ(weight_decay, static_cast(0.0)) << "Currently not support for setting weight decay. "; const float lr_power = ctx->Attr("lr_power"); const float lambda1 = ctx->Attr("lambda1"); const float lambda2 = ctx->Attr("lambda2"); const float beta = ctx->Attr("beta"); const double scale = ctx->Attr("scale"); const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } const T* down_scale_by_ptr = nullptr; if (ctx->has_input("down_scale_by_tensor", 0)) { const user_op::Tensor* down_scale_by_tensor = ctx->Tensor4ArgNameAndIndex("down_scale_by_tensor", 0); CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1); down_scale_by_ptr = down_scale_by_tensor->dptr(); } const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); learning_rate_ptr = learning_rate->dptr(); } const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape_view().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } // update kernel const T* unique_embeddings_ptr = reinterpret_cast(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_)); T* updated_unique_embeddings_ptr = reinterpret_cast( embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_)); const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_); const int64_t embedding_grad_elem_cnt = num_unique * embedding_size; FtrlUpdateKernel <<stream()->As()->cuda_stream()>>>( line_size, embedding_size, static_cast(scale), l1, l2, weight_decay, lr_power, lambda1, lambda2, beta, learning_rate_val, reinterpret_cast(num_unique_ids->dptr()), learning_rate_ptr, scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr(), unique_embeddings_ptr, updated_unique_embeddings_ptr); embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_); current_iter_++; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } mutable int64_t current_iter_; }; #define REGISTER_CUDA_ONE_EMBEDDING_FTRL_UPDATE_KERNEL(t_dtype_pair, g_type_pair, 
idx_dtype_pair) \ REGISTER_USER_KERNEL("one_embedding_ftrl_update") \ .SetCreateFn< \ FtrlEmbeddingUpdateKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ && (user_op::HobDataType("embedding_grad", 0) == OF_PP_PAIR_SECOND(g_type_pair)) \ && (user_op::HobDataType("unique_embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_FTRL_UPDATE_KERNEL, FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/one_hot_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/include/primitive/fill.h" namespace oneflow { template class CpuOneHotKernel final : public user_op::OpKernel { public: CpuOneHotKernel() = default; ~CpuOneHotKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t num_indices = indices->shape_view().elem_cnt(); const int64_t depth = ctx->Attr("depth"); const DataType dtype = ctx->Attr("dtype"); const T on_value = IsFloatingDataType(dtype) ? static_cast(ctx->Attr("floating_on_value")) : static_cast(ctx->Attr("integer_on_value")); const T off_value = IsFloatingDataType(dtype) ? 
static_cast(ctx->Attr("floating_off_value")) : static_cast(ctx->Attr("integer_off_value")); const K* indices_dptr = indices->dptr(); T* out_dptr = out->mut_dptr(); std::unique_ptr fill = ep::primitive::NewPrimitive(ctx->stream()->device_type(), out->data_type()); CHECK(fill); fill->Launch(ctx->stream(), out->mut_dptr(), off_value, out->shape_view().elem_cnt()); FOR_RANGE(int64_t, i, 0, num_indices) { const int64_t idx = indices_dptr[i]; CHECK_GE(idx, 0); CHECK_LT(idx, depth); out_dptr[i * depth + idx] = on_value; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_ONE_HOT_KERNEL(dtype, itype) \ REGISTER_USER_KERNEL("one_hot").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("indices", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_CPU_ONE_HOT_KERNEL(int32_t, int32_t) REGISTER_CPU_ONE_HOT_KERNEL(int32_t, int64_t) REGISTER_CPU_ONE_HOT_KERNEL(int64_t, int32_t) REGISTER_CPU_ONE_HOT_KERNEL(int64_t, int64_t) REGISTER_CPU_ONE_HOT_KERNEL(float, int32_t) REGISTER_CPU_ONE_HOT_KERNEL(float, int64_t) REGISTER_CPU_ONE_HOT_KERNEL(double, int32_t) REGISTER_CPU_ONE_HOT_KERNEL(double, int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/one_hot_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace { template __global__ void OneHotEncodeGpu(int64_t elem_cnt, const int64_t depth, const T on_value, const T off_value, const K* indices, T* out) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const int64_t row = i / depth; const int64_t col = i - row * depth; const int64_t idx = indices[row]; assert(idx >= 0 && idx < depth); out[i] = (idx == col) ? on_value : off_value; } } } // namespace template class GpuOneHotKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: GpuOneHotKernel() = default; ~GpuOneHotKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t num_indices = indices->shape_view().elem_cnt(); const int64_t depth = ctx->Attr("depth"); const DataType dtype = ctx->Attr("dtype"); const T on_value = IsFloatingDataType(dtype) ? static_cast(ctx->Attr("floating_on_value")) : static_cast(ctx->Attr("integer_on_value")); const T off_value = IsFloatingDataType(dtype) ? 
static_cast(ctx->Attr("floating_off_value")) : static_cast(ctx->Attr("integer_off_value")); RUN_CUDA_KERNEL((OneHotEncodeGpu), ctx->stream(), num_indices * depth, num_indices * depth, depth, on_value, off_value, indices->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_ONE_HOT_KERNEL(dtype, itype) \ REGISTER_USER_KERNEL("one_hot").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("indices", 0) == GetDataType::value) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_CUDA_ONE_HOT_KERNEL(int32_t, int32_t) REGISTER_CUDA_ONE_HOT_KERNEL(int32_t, int64_t) REGISTER_CUDA_ONE_HOT_KERNEL(int64_t, int32_t) REGISTER_CUDA_ONE_HOT_KERNEL(int64_t, int64_t) REGISTER_CUDA_ONE_HOT_KERNEL(float, int32_t) REGISTER_CUDA_ONE_HOT_KERNEL(float, int64_t) REGISTER_CUDA_ONE_HOT_KERNEL(double, int32_t) REGISTER_CUDA_ONE_HOT_KERNEL(double, int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/ones_like_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/switch_func.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/fill.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewFillPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } class OnesLikeKernel final : public user_op::OpKernel { public: OnesLikeKernel() = default; ~OnesLikeKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); std::unique_ptr fill = ep::primitive::NewPrimitive(ctx->stream()->device_type(), out->data_type()); CHECK(fill); fill->Launch(ctx->stream(), out->mut_dptr(), 1, out->shape_view().elem_cnt()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto FillPrimitiveExists() { return hob::make_custom("FillPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewFillPrimitive(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("ones_like") .SetCreateFn() .SetIsMatchedHob(FillPrimitiveExists()); } // namespace } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/op_kernel_wrapper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_OP_KERNEL_STATE_WRAPPER_H_
#define ONEFLOW_USER_KERNELS_OP_KERNEL_STATE_WRAPPER_H_

#include "oneflow/core/framework/op_kernel.h"

namespace oneflow {

// Wraps an arbitrary value of type T so it can be carried around as user_op::OpKernelState.
template<typename T>
class OpKernelStateWrapper final : public user_op::OpKernelState {
 public:
  template<typename... Args>
  explicit OpKernelStateWrapper(Args&&... args) : data_(std::forward<Args>(args)...) {}
  ~OpKernelStateWrapper() = default;

  const T& Get() const { return data_; }
  T* Mutable() { return &data_; }

 private:
  T data_;
};

// Same wrapper, but for user_op::OpKernelCache.
template<typename T>
class OpKernelCacheWrapper final : public user_op::OpKernelCache {
 public:
  template<typename... Args>
  explicit OpKernelCacheWrapper(Args&&... args) : data_(std::forward<Args>(args)...) {}
  ~OpKernelCacheWrapper() = default;

  const T& Get() const { return data_; }
  T* Mutable() { return &data_; }

 private:
  T data_;
};

}  // namespace oneflow

#endif  // ONEFLOW_USER_KERNELS_OP_KERNEL_STATE_WRAPPER_H_

================================================
FILE: oneflow/user/kernels/p2p_comm_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/rank_group.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/include/recv.h" namespace oneflow { namespace { namespace { auto SendCollectiveCommunicationExists() { return hob::make_custom("SendCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsSendRegistered(device_type); }); } auto RecvCollectiveCommunicationExists() { return hob::make_custom("RecvCollectiveCommunicationExists", [=](const user_op::KernelRegContext& ctx) { DeviceType device_type = ctx.device_type(); return ccl::IsRecvRegistered(device_type); }); } } // namespace class SendKernel final : public user_op::OpKernel { public: SendKernel() = default; ~SendKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const auto& dst_process_id = ctx->Attr("dst_process_id"); std::unique_ptr send = ccl::NewCollectiveCommunication(ctx->device_type(), in->data_type()); send->Launch(ctx->stream(), in->dptr(), in->shape_view().elem_cnt(), dst_process_id); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; class RecvKernel final : public user_op::OpKernel { public: RecvKernel() = default; ~RecvKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const auto& src_process_id = ctx->Attr("src_process_id"); std::unique_ptr recv = ccl::NewCollectiveCommunication(ctx->device_type(), out->data_type()); recv->Launch(ctx->stream(), out->mut_dptr(), out->shape_view().elem_cnt(), src_process_id); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("send").SetCreateFn().SetIsMatchedHob( SendCollectiveCommunicationExists()); REGISTER_USER_KERNEL("recv").SetCreateFn().SetIsMatchedHob( RecvCollectiveCommunicationExists()); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/pack_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { namespace { template class PackKernel final : public user_op::OpKernel { public: PackKernel() = default; ~PackKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>>( std::make_pair(0, 0)); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const auto pack_num = ctx->Attr("pack_num"); if (in->shape_view().NumAxes() > 0) { CHECK_EQ(in->shape_view().NumAxes(), out->shape_view().NumAxes()); CHECK_EQ(out->shape_view().At(0), in->shape_view().At(0) * pack_num); for (int64_t i = 1; i < in->shape_view().NumAxes(); ++i) { CHECK_EQ(out->shape_view().At(i), in->shape_view().At(i)); } } else { // NOTE(chengcheng): for Scalar input pack CHECK_EQ(in->shape_view().NumAxes(), 0); CHECK_EQ(out->shape_view().NumAxes(), 1); CHECK_EQ(in->shape_view().elem_cnt(), 1); CHECK_EQ(out->shape_view().elem_cnt(), pack_num); } const int64_t copy_size = in->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type()); auto* state_wrapper = dynamic_cast>*>(state); CHECK_NOTNULL(state_wrapper); const size_t index = state_wrapper->Get().first; CHECK_EQ(state_wrapper->Get().second, pack_num); Memcpy(ctx->stream(), out->mut_dptr() + index * copy_size, in->dptr(), copy_size); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_PACK_KERNEL(device) \ REGISTER_USER_KERNEL("pack").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device)); OF_PP_FOR_EACH_TUPLE(REGISTER_PACK_KERNEL, DEVICE_TYPE_SEQ) #if defined(WITH_MLU) REGISTER_PACK_KERNEL(DeviceType::kMLU) #endif #undef REGISTER_PACK_KERNEL } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/pad_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/constant_pad.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewConstantPadPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("y", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } auto ConstantPadPrimitiveExists() { return hob::make_custom("ConstantPadPrimitiveExists", [](const KernelRegContext& ctx) { return NewConstantPadPrimitive(&ctx).operator bool(); }); } } // namespace class PadKernel final : public OpKernel, public CudaGraphSupport { public: PadKernel() = default; ~PadKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); if (y->shape_view().NumAxes() > 0 && y->shape_view().elem_cnt() == 0) { // if output is 0-shape tensor, than do nothing and return return; } Scalar value; if (IsIntegralDataType(x->data_type()) || x->data_type() == kBool) { value = Scalar(ctx->Attr("integral_constant_value")); } else { value = Scalar(ctx->Attr("floating_constant_value")); } const auto& padding_before = ctx->Attr>("padding_before"); const auto& padding_after = ctx->Attr>("padding_after"); const int64_t ndims = x->shape_view().NumAxes(); CHECK_EQ(padding_before.size(), ndims); std::unique_ptr pad_primitive = NewConstantPadPrimitive(ctx); CHECK(pad_primitive); pad_primitive->Launch(ctx->stream(), ndims, x->shape_view().ptr(), x->dptr(), padding_before.data(), padding_after.data(), value, y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("pad").SetCreateFn().SetIsMatchedHob(ConstantPadPrimitiveExists()); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/partial_fc_sample_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/user/kernels/gather_kernel_util.h" #include "oneflow/core/common/not_equal_to_previous_adjacent_iterator.h" #include #include #include #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { template int64_t GetCubSortPairsTempStorageSize(int64_t n) { size_t cub_sort_temp_store_size = 0; OF_CUDA_CHECK((cub::DeviceRadixSort::SortPairs(nullptr, cub_sort_temp_store_size, nullptr, nullptr, nullptr, nullptr, n))); size_t temp_store_size = GetCudaAlignedSize(cub_sort_temp_store_size); CHECK_GE(temp_store_size, 0); CHECK_LT(temp_store_size, static_cast(GetMaxVal())); return static_cast(temp_store_size); } template int64_t GetCubScanTempStorageSize(int64_t n) { size_t cub_scan_temp_store_size = 0; NotEqualToPreviousAdjacentIterator unique_counting_iter(nullptr, 0); OF_CUDA_CHECK((cub::DeviceScan::InclusiveSum, K*>( nullptr, cub_scan_temp_store_size, unique_counting_iter, nullptr, n))); size_t temp_store_size = GetCudaAlignedSize(cub_scan_temp_store_size); CHECK_GE(temp_store_size, 0); CHECK_LT(temp_store_size, static_cast(GetMaxVal())); return static_cast(temp_store_size); } template class TmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager); TmpBufferManager(void* ptr, const int64_t device_num_class, const int64_t batch_size, const int64_t parallel_num) : ptr_(ptr) { const int64_t buffer_elem_cnt = std::max(device_num_class, batch_size); const size_t cub_sort_keys_bytes = GetCudaAlignedSize(buffer_elem_cnt * sizeof(K)); const size_t cub_sort_values_bytes = GetCudaAlignedSize(buffer_elem_cnt * sizeof(K)); const size_t cub_sort_keys_out_bytes = GetCudaAlignedSize(buffer_elem_cnt * sizeof(K)); const size_t cub_sort_values_out_bytes = GetCudaAlignedSize(buffer_elem_cnt * sizeof(K)); const size_t bound_index_bytes = GetCudaAlignedSize((parallel_num + 1) * sizeof(K)); const size_t bound_value_bytes = GetCudaAlignedSize((parallel_num + 1) * sizeof(K)); cub_tmp_storage_bytes_ = std::max(GetCubSortPairsTempStorageSize(buffer_elem_cnt), GetCubScanTempStorageSize(batch_size)); cub_sort_keys_offset_ = 0; cub_sort_values_offset_ = cub_sort_keys_offset_ + cub_sort_keys_bytes; cub_sort_keys_out_offset_ = cub_sort_values_offset_ + cub_sort_values_bytes; cub_sort_values_out_offset_ = cub_sort_keys_out_offset_ + cub_sort_keys_out_bytes; cub_tmp_storage_offset_ = cub_sort_values_out_offset_ + cub_sort_values_out_bytes; bound_index_offset_ = cub_tmp_storage_offset_ + cub_tmp_storage_bytes_; bound_value_offset_ = bound_index_offset_ + bound_index_bytes; total_buffer_size_ = cub_sort_keys_bytes + cub_sort_values_bytes + cub_sort_keys_out_bytes + cub_sort_values_out_bytes + cub_tmp_storage_bytes_ + bound_index_bytes + bound_value_bytes; } ~TmpBufferManager() = default; size_t GetTotalBufferSize() const { return total_buffer_size_; } size_t GetCubTmpStorageSize() const { return cub_tmp_storage_bytes_; } K* CubSortKeysPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + cub_sort_keys_offset_); } K* CubSortValuesPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + cub_sort_values_offset_); } K* CubSortKeysOutPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + cub_sort_keys_out_offset_); } K* CubSortValuesOutPtr() const { CHECK(ptr_ != nullptr); return 
reinterpret_cast(reinterpret_cast(ptr_) + cub_sort_values_out_offset_); } void* CubTmpStoragePtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + cub_tmp_storage_offset_); } K* BoundIndexPtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + bound_index_offset_); } K* BoundValuePtr() const { CHECK(ptr_ != nullptr); return reinterpret_cast(reinterpret_cast(ptr_) + bound_value_offset_); } private: size_t cub_sort_keys_offset_; size_t cub_sort_values_offset_; size_t cub_sort_keys_out_offset_; size_t cub_sort_values_out_offset_; size_t cub_tmp_storage_offset_; size_t bound_index_offset_; size_t bound_value_offset_; size_t cub_tmp_storage_bytes_; size_t total_buffer_size_; void* ptr_; }; __global__ void SetupKernel(int64_t seed, curandState* state) { const int id = blockIdx.x * blockDim.x + threadIdx.x; size_t local_seed = (static_cast(seed) + 0x9e3779b9U + (static_cast(id) << 6U) + (static_cast(id) >> 2U)); curand_init(local_seed, 0, 0, &state[id]); } template __global__ void GenerateGpu(curandState* state, const int64_t n, const int64_t max_val, K* buffer) { const int id = blockIdx.x * blockDim.x + threadIdx.x; curandState localState = state[id]; CUDA_1D_KERNEL_LOOP(i, n) { buffer[i] = static_cast(curand(&localState) % max_val); } state[id] = localState; } class DistributedPartialFcSampleOpKernelState final : public user_op::OpKernelState { public: DistributedPartialFcSampleOpKernelState(ep::Stream* stream, int64_t lower, int64_t upper, int64_t num_sample_per_rank, int64_t seed) : lower_(lower), upper_(upper), num_sample_per_rank_(num_sample_per_rank) { CHECK_NOTNULL(stream); const int64_t num_classes = upper_ - lower_; OF_CUDA_CHECK(cudaMalloc(&curand_states_, BlocksNum4ThreadsNum(num_classes) * kCudaThreadsNumPerBlock * sizeof(curandState))); SetupKernel<<As()->cuda_stream()>>>(seed, curand_states_); } ~DistributedPartialFcSampleOpKernelState() { cudaError_t ret = cudaFree(curand_states_); if (ret != cudaErrorCudartUnloading) { OF_CUDA_CHECK(ret); } }; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } int64_t num_sample_per_rank() const { return num_sample_per_rank_; } template void GenRandom(ep::Stream* stream, const int64_t n, const int64_t max_val, K* buffer) { GenerateGpu <<As()->cuda_stream()>>>(curand_states_, n, max_val, buffer); } private: const int64_t lower_; const int64_t upper_; const int64_t num_sample_per_rank_; curandState* curand_states_; }; template __global__ void IotaKernel(int64_t n, K* out) { CUDA_1D_KERNEL_LOOP(i, n) { out[i] = static_cast(i); } } template __global__ void MarkPositive(const int64_t n, const int64_t offset, const int64_t num_classes, const K* labels, K* out) { CUDA_1D_KERNEL_LOOP(i, n) { K label = labels[i] - offset; if (label >= 0 && label < num_classes) { out[label] = label - num_classes; } } } template __global__ void GetSampledLabel(const int64_t n, const int64_t offset, const K* label, K* sampled_label) { CUDA_1D_KERNEL_LOOP(i, n) { sampled_label[i] = label[i] + offset; } } template __global__ void GetLabelMap(const int64_t n, const int64_t parallel_num, const int64_t num_sample_per_rank, const K* bound_index, const K* bound_value, K* label_map) { CUDA_1D_KERNEL_LOOP(i, n) { #pragma unroll for (int64_t j = 0; j < parallel_num; j++) { if (i >= bound_index[j] && i < bound_index[j + 1]) { label_map[i] = label_map[i] - bound_value[j] + j * num_sample_per_rank; } } } } template __global__ void GetPartionBound(const int64_t n, const int64_t parallel_num, 
const int64_t num_classes_per_rank, const K* key_ptr, const K* value_ptr, K* bound_index, K* bound_value) { CUDA_1D_KERNEL_LOOP(i, n) { if (i != 0) { const K cur_in = key_ptr[i] / num_classes_per_rank; const K pre_in = key_ptr[i - 1] / num_classes_per_rank; if (cur_in > pre_in) { assert(cur_in < parallel_num); #pragma unroll for (int32_t j = pre_in + 1; j <= cur_in; ++j) { bound_index[j] = static_cast(i); bound_value[j] = value_ptr[i]; } } } } CUDA_1D_KERNEL_LOOP(i, parallel_num + 1) { const K first_in = key_ptr[0] / num_classes_per_rank; const K last_in = key_ptr[n - 1] / num_classes_per_rank; if (i <= first_in) { bound_index[i] = 0; bound_value[i] = value_ptr[0]; } else if (i > last_in) { bound_index[i] = n; bound_value[i] = value_ptr[n - 1]; } } } template __global__ void GetMappedLabel(const int64_t n, const K* label_map_key, const K* label_map_value, K* mapped_label) { CUDA_1D_KERNEL_LOOP(i, n) { mapped_label[label_map_key[i]] = label_map_value[i]; } } template void MapLabel(ep::Stream* stream, const int64_t num_classes, const int64_t batch_size, const int64_t lower_bound, const int64_t parallel_num, const int64_t num_sample, size_t temp_storage_bytes, const K* label_ptr, K* mapped_label_ptr, K* cub_sort_values_ptr, K* cub_sort_keys_out_ptr, K* cub_sort_values_out_ptr, void* cub_tmp_storage_ptr, K* bound_index_ptr, K* bound_value_ptr) { IotaKernel<<As()->cuda_stream()>>>(batch_size, cub_sort_values_ptr); OF_CUDA_CHECK((cub::DeviceRadixSort::SortPairs( cub_tmp_storage_ptr, temp_storage_bytes, label_ptr, cub_sort_keys_out_ptr, cub_sort_values_ptr, cub_sort_values_out_ptr, batch_size, 0, sizeof(K) * 8, stream->As()->cuda_stream()))); NotEqualToPreviousAdjacentIterator unique_counting_iter(cub_sort_keys_out_ptr, 0); OF_CUDA_CHECK((cub::DeviceScan::InclusiveSum, K*>( cub_tmp_storage_ptr, temp_storage_bytes, unique_counting_iter, cub_sort_values_ptr, batch_size, stream->As()->cuda_stream()))); GetPartionBound<<As()->cuda_stream()>>>( batch_size, parallel_num, num_classes, cub_sort_keys_out_ptr, cub_sort_values_ptr, bound_index_ptr, bound_value_ptr); GetLabelMap<<As()->cuda_stream()>>>( batch_size, parallel_num, num_sample, bound_index_ptr, bound_value_ptr, cub_sort_values_ptr); GetMappedLabel<<As()->cuda_stream()>>>( batch_size, cub_sort_values_out_ptr, cub_sort_values_ptr, mapped_label_ptr); } } // namespace template class DistributedPartialFcSampleGpuKernel final : public user_op::OpKernel { public: DistributedPartialFcSampleGpuKernel() = default; ~DistributedPartialFcSampleGpuKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex("weight", 0); const TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("weight", 0); const int64_t class_num = in_logical_desc->shape().At(0); const int64_t num_sample = ctx->Attr("num_sample"); int64_t seed = ctx->Attr("seed"); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t num_sample_per_rank = RoundUp(num_sample, parallel_num) / parallel_num; if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == 0 && parallel_num > 1) { std::seed_seq seq{seed}; std::vector seeds(parallel_num); seq.generate(seeds.begin(), seeds.end()); seed = seeds.at(ctx->parallel_ctx().parallel_id()); CHECK(ctx->SbpParallel4ArgNameAndIndex("label", 0).has_broadcast_parallel()); BalancedSplitter bs(class_num, parallel_num); return std::make_shared( ctx->stream(), 
bs.At(ctx->parallel_ctx().parallel_id()).begin(), bs.At(ctx->parallel_ctx().parallel_id()).end(), num_sample_per_rank, seed); } else { return std::make_shared(ctx->stream(), 0, class_num, num_sample_per_rank, seed); } } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* mapped_label = ctx->Tensor4ArgNameAndIndex("mapped_label", 0); user_op::Tensor* sampled_label = ctx->Tensor4ArgNameAndIndex("sampled_label", 0); user_op::Tensor* sampled_weight = ctx->Tensor4ArgNameAndIndex("sampled_weight", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t batch_size = label->shape_view().At(0); const int64_t num_classes = weight->shape_view().At(0); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); TmpBufferManager buffer_manager(tmp_buffer->mut_dptr(), num_classes, batch_size, parallel_num); auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); CHECK_EQ(num_classes, kernel_state->upper() - kernel_state->lower()); const int64_t lower_bound = kernel_state->lower(); const int64_t num_sample = kernel_state->num_sample_per_rank(); kernel_state->GenRandom(ctx->stream(), num_classes, num_classes, buffer_manager.CubSortKeysPtr()); MarkPositive<<stream()->As()->cuda_stream()>>>( batch_size, lower_bound, num_classes, label->dptr(), buffer_manager.CubSortKeysPtr()); IotaKernel<<stream()->As()->cuda_stream()>>>( num_classes, buffer_manager.CubSortValuesPtr()); size_t temp_storage_bytes = buffer_manager.GetCubTmpStorageSize(); OF_CUDA_CHECK((cub::DeviceRadixSort::SortPairs( buffer_manager.CubTmpStoragePtr(), temp_storage_bytes, buffer_manager.CubSortKeysPtr(), buffer_manager.CubSortKeysOutPtr(), buffer_manager.CubSortValuesPtr(), buffer_manager.CubSortValuesOutPtr(), num_classes, 0, sizeof(K) * 8, ctx->stream()->As()->cuda_stream()))); GetSampledLabel<<stream()->As()->cuda_stream()>>>( num_sample, lower_bound, buffer_manager.CubSortValuesOutPtr(), sampled_label->mut_dptr()); GatherKernelUtilImpl::Forward( ctx->stream(), buffer_manager.CubSortValuesOutPtr(), num_sample, weight->dptr(), Shape({1, num_classes, weight->shape_view().Count(1)}), sampled_weight->mut_dptr(), 0); MapLabel(ctx->stream(), num_classes, batch_size, lower_bound, parallel_num, num_sample, buffer_manager.GetCubTmpStorageSize(), label->dptr(), mapped_label->mut_dptr(), buffer_manager.CubSortValuesPtr(), buffer_manager.CubSortKeysOutPtr(), buffer_manager.CubSortValuesOutPtr(), buffer_manager.CubTmpStoragePtr(), buffer_manager.BoundIndexPtr(), buffer_manager.BoundValuePtr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DISTRIBUTED_PARTIAL_FC_SAMPLE_CUDA_KERNEL(dtype_pair, ltype_pair) \ REGISTER_USER_KERNEL("distributed_partial_fc_sample") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ && (user_op::HobDataType("weight", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { \ const int64_t num_classes = ctx->InputTensorDesc("weight", 0).shape().At(0); \ const int64_t batch_size = ctx->InputTensorDesc("label", 0).shape().At(0); \ const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); \ 
TmpBufferManager buffer_manager(nullptr, num_classes, \ batch_size, parallel_num); \ return buffer_manager.GetTotalBufferSize(); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DISTRIBUTED_PARTIAL_FC_SAMPLE_CUDA_KERNEL, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class DistributedPartialFcSampleDisableBoxingGpuKernel final : public user_op::OpKernel { public: DistributedPartialFcSampleDisableBoxingGpuKernel() = default; ~DistributedPartialFcSampleDisableBoxingGpuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const user_op::Tensor* sampled_weight_diff = ctx->Tensor4ArgNameAndIndex("sampled_weight_diff", 0); const user_op::Tensor* sampled_label = ctx->Tensor4ArgNameAndIndex("sampled_label", 0); user_op::Tensor* boxing_disabled_sampled_weight_diff = ctx->Tensor4ArgNameAndIndex("boxing_disabled_sampled_weight_diff", 0); user_op::Tensor* boxing_disabled_sampled_label = ctx->Tensor4ArgNameAndIndex("boxing_disabled_sampled_label", 0); Memcpy(ctx->stream(), boxing_disabled_sampled_weight_diff->mut_dptr(), sampled_weight_diff->dptr(), sampled_weight_diff->shape_view().elem_cnt() * GetSizeOfDataType(sampled_weight_diff->data_type())); Memcpy( ctx->stream(), boxing_disabled_sampled_label->mut_dptr(), sampled_label->dptr(), sampled_label->shape_view().elem_cnt() * GetSizeOfDataType(sampled_label->data_type())); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_DISTRIBUTED_PARTIAL_FC_SAMPLE_DISABLE_BOXING_CUDA_KERNEL(dtype_pair, ltype_pair) \ REGISTER_USER_KERNEL("distributed_partial_fc_sample_disable_boxing") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("sampled_label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ && (user_op::HobDataType("sampled_weight_diff", 0) == OF_PP_PAIR_SECOND(dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DISTRIBUTED_PARTIAL_FC_SAMPLE_DISABLE_BOXING_CUDA_KERNEL, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace user_op } // namespace oneflow #endif ================================================ FILE: oneflow/user/kernels/pocketfft_hdronly.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* This file is part of pocketfft. Copyright (C) 2010-2021 Max-Planck-Society Copyright (C) 2019-2020 Peter Bell For the odd-sized DCT-IV transforms: Copyright (C) 2003, 2007-14 Matteo Frigo Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology Authors: Martin Reinecke, Peter Bell All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the documentation and/or
  other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its contributors may
  be used to endorse or promote products derived from this software without
  specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifndef POCKETFFT_HDRONLY_H
#define POCKETFFT_HDRONLY_H

#ifndef __cplusplus
#error This file is C++ and requires a C++ compiler.
#endif

#if !(__cplusplus >= 201103L || _MSVC_LANG + 0L >= 201103L)
#error This file requires at least C++11 support.
#endif

#ifndef POCKETFFT_CACHE_SIZE
#define POCKETFFT_CACHE_SIZE 0
#endif

#include <cmath>
#include <cstring>
#include <cstdlib>
#include <stdexcept>
#include <memory>
#include <vector>
#include <complex>
#if POCKETFFT_CACHE_SIZE != 0
#include <array>
#include <mutex>
#endif

#ifndef POCKETFFT_NO_MULTITHREADING
#include <mutex>
#include <condition_variable>
#include <thread>
#include <queue>
#include <atomic>
#include <functional>
#include <new>

#ifdef POCKETFFT_PTHREADS
#include <pthread.h>
#endif
#endif

#if defined(__GNUC__)
#define POCKETFFT_NOINLINE __attribute__((noinline))
#define POCKETFFT_RESTRICT __restrict__
#elif defined(_MSC_VER)
#define POCKETFFT_NOINLINE __declspec(noinline)
#define POCKETFFT_RESTRICT __restrict
#else
#define POCKETFFT_NOINLINE
#define POCKETFFT_RESTRICT
#endif

namespace pocketfft {
namespace detail {

using std::ptrdiff_t;
using std::size_t;

// Always use std:: for <cmath> functions
template<typename T> T cos(T) = delete;
template<typename T> T sin(T) = delete;
template<typename T> T sqrt(T) = delete;

using shape_t = std::vector<size_t>;
using stride_t = std::vector<ptrdiff_t>;

constexpr bool FORWARD = true, BACKWARD = false;

// only enable vector support for gcc>=5.0 and clang>=5.0
#ifndef POCKETFFT_NO_VECTORS
#define POCKETFFT_NO_VECTORS
#if defined(__INTEL_COMPILER)
// do nothing. This is necessary because this compiler also sets __GNUC__.
#elif defined(__clang__) // AppleClang has their own version numbering #ifdef __apple_build_version__ #if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1) #undef POCKETFFT_NO_VECTORS #endif #elif __clang_major__ >= 5 #undef POCKETFFT_NO_VECTORS #endif #elif defined(__GNUC__) #if __GNUC__ >= 5 #undef POCKETFFT_NO_VECTORS #endif #endif #endif template struct VLEN { static constexpr size_t val = 1; }; #ifndef POCKETFFT_NO_VECTORS #if (defined(__AVX512F__)) template<> struct VLEN { static constexpr size_t val = 16; }; template<> struct VLEN { static constexpr size_t val = 8; }; #elif (defined(__AVX__)) template<> struct VLEN { static constexpr size_t val = 8; }; template<> struct VLEN { static constexpr size_t val = 4; }; #elif (defined(__SSE2__)) template<> struct VLEN { static constexpr size_t val = 4; }; template<> struct VLEN { static constexpr size_t val = 2; }; #elif (defined(__VSX__)) template<> struct VLEN { static constexpr size_t val = 4; }; template<> struct VLEN { static constexpr size_t val = 2; }; #elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) template<> struct VLEN { static constexpr size_t val = 4; }; template<> struct VLEN { static constexpr size_t val = 2; }; #else #define POCKETFFT_NO_VECTORS #endif #endif #if __cplusplus >= 201703L inline void* aligned_alloc(size_t align, size_t size) { // aligned_alloc() requires that the requested size is a multiple of "align" void* ptr = ::aligned_alloc(align, (size + align - 1) & (~(align - 1))); if (!ptr) throw std::bad_alloc(); return ptr; } inline void aligned_dealloc(void* ptr) { free(ptr); } #else // portable emulation inline void* aligned_alloc(size_t align, size_t size) { align = std::max(align, alignof(max_align_t)); void* ptr = malloc(size + align); if (!ptr) throw std::bad_alloc(); void* res = reinterpret_cast((reinterpret_cast(ptr) & ~(uintptr_t(align - 1))) + uintptr_t(align)); (reinterpret_cast(res))[-1] = ptr; return res; } inline void aligned_dealloc(void* ptr) { if (ptr) free((reinterpret_cast(ptr))[-1]); } #endif template class arr { private: T* p; size_t sz; #if defined(POCKETFFT_NO_VECTORS) static T* ralloc(size_t num) { if (num == 0) return nullptr; void* res = malloc(num * sizeof(T)); if (!res) throw std::bad_alloc(); return reinterpret_cast(res); } static void dealloc(T* ptr) { free(ptr); } #else static T* ralloc(size_t num) { if (num == 0) return nullptr; void* ptr = aligned_alloc(64, num * sizeof(T)); return static_cast(ptr); } static void dealloc(T* ptr) { aligned_dealloc(ptr); } #endif public: arr() : p(0), sz(0) {} arr(size_t n) : p(ralloc(n)), sz(n) {} arr(arr&& other) : p(other.p), sz(other.sz) { other.p = nullptr; other.sz = 0; } ~arr() { dealloc(p); } void resize(size_t n) { if (n == sz) return; dealloc(p); p = ralloc(n); sz = n; } T& operator[](size_t idx) { return p[idx]; } const T& operator[](size_t idx) const { return p[idx]; } T* data() { return p; } const T* data() const { return p; } size_t size() const { return sz; } }; template struct cmplx { T r, i; cmplx() {} cmplx(T r_, T i_) : r(r_), i(i_) {} void Set(T r_, T i_) { r = r_; i = i_; } void Set(T r_) { r = r_; i = T(0); } cmplx& operator+=(const cmplx& other) { r += other.r; i += other.i; return *this; } template cmplx& operator*=(T2 other) { r *= other; i *= other; return *this; } template cmplx& operator*=(const cmplx& other) { T tmp = r * other.r - i * other.i; i = r * other.i + i * other.r; r = tmp; return *this; } template cmplx& operator+=(const cmplx& other) { r += other.r; i += other.i; return *this; } 
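
[Editor's aside, not pocketfft code] The special_mul member defined just below is the workhorse of every pass kernel: with fwd == true it multiplies by the conjugate of its argument, with fwd == false by the argument itself, which is how a single set of butterflies serves both transform directions. A minimal standalone illustration of that convention, assuming only the standard <complex> header (the names here are hypothetical, chosen for this sketch):

#include <cassert>
#include <cmath>
#include <complex>

// Mirrors cmplx<T>::special_mul<fwd>(): the forward direction multiplies by the
// conjugate twiddle, the backward direction by the twiddle itself.
template<bool fwd, typename T>
std::complex<T> special_mul_demo(std::complex<T> v1, std::complex<T> v2) {
  return fwd ? v1 * std::conj(v2) : v1 * v2;
}

int main() {
  const std::complex<double> a(1.0, 2.0), b(3.0, -4.0);
  // The two directions differ exactly by conjugation of the second factor.
  assert(std::abs(special_mul_demo<true>(a, b) - a * std::conj(b)) < 1e-12);
  assert(std::abs(special_mul_demo<false>(a, b) - a * b) < 1e-12);
  return 0;
}
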
template cmplx& operator-=(const cmplx& other) { r -= other.r; i -= other.i; return *this; } template auto operator*(const T2& other) const -> cmplx { return {r * other, i * other}; } template auto operator+(const cmplx& other) const -> cmplx { return {r + other.r, i + other.i}; } template auto operator-(const cmplx& other) const -> cmplx { return {r - other.r, i - other.i}; } template auto operator*(const cmplx& other) const -> cmplx { return {r * other.r - i * other.i, r * other.i + i * other.r}; } template auto special_mul(const cmplx& other) const -> cmplx { using Tres = cmplx; return fwd ? Tres(r * other.r + i * other.i, i * other.r - r * other.i) : Tres(r * other.r - i * other.i, r * other.i + i * other.r); } }; template inline void PM(T& a, T& b, T c, T d) { a = c + d; b = c - d; } template inline void PMINPLACE(T& a, T& b) { T t = a; a += b; b = t - b; } template inline void MPINPLACE(T& a, T& b) { T t = a; a -= b; b = t + b; } template cmplx conj(const cmplx& a) { return {a.r, -a.i}; } template void special_mul(const cmplx& v1, const cmplx& v2, cmplx& res) { res = fwd ? cmplx(v1.r * v2.r + v1.i * v2.i, v1.i * v2.r - v1.r * v2.i) : cmplx(v1.r * v2.r - v1.i * v2.i, v1.r * v2.i + v1.i * v2.r); } template void ROT90(cmplx& a) { auto tmp_ = a.r; a.r = -a.i; a.i = tmp_; } template void ROTX90(cmplx& a) { auto tmp_ = fwd ? -a.r : a.r; a.r = fwd ? a.i : -a.i; a.i = tmp_; } // // twiddle factor section // template class sincos_2pibyn { private: using Thigh = typename std::conditional<(sizeof(T) > sizeof(double)), T, double>::type; size_t N, mask, shift; arr> v1, v2; static cmplx calc(size_t x, size_t n, Thigh ang) { x <<= 3; if (x < 4 * n) // first half { if (x < 2 * n) // first quadrant { if (x < n) return cmplx(std::cos(Thigh(x) * ang), std::sin(Thigh(x) * ang)); return cmplx(std::sin(Thigh(2 * n - x) * ang), std::cos(Thigh(2 * n - x) * ang)); } else // second quadrant { x -= 2 * n; if (x < n) return cmplx(-std::sin(Thigh(x) * ang), std::cos(Thigh(x) * ang)); return cmplx(-std::cos(Thigh(2 * n - x) * ang), std::sin(Thigh(2 * n - x) * ang)); } } else { x = 8 * n - x; if (x < 2 * n) // third quadrant { if (x < n) return cmplx(std::cos(Thigh(x) * ang), -std::sin(Thigh(x) * ang)); return cmplx(std::sin(Thigh(2 * n - x) * ang), -std::cos(Thigh(2 * n - x) * ang)); } else // fourth quadrant { x -= 2 * n; if (x < n) return cmplx(-std::sin(Thigh(x) * ang), -std::cos(Thigh(x) * ang)); return cmplx(-std::cos(Thigh(2 * n - x) * ang), -std::sin(Thigh(2 * n - x) * ang)); } } } public: POCKETFFT_NOINLINE sincos_2pibyn(size_t n) : N(n) { constexpr auto pi = 3.141592653589793238462643383279502884197L; Thigh ang = Thigh(0.25L * pi / n); size_t nval = (n + 2) / 2; shift = 1; while ((size_t(1) << shift) * (size_t(1) << shift) < nval) ++shift; mask = (size_t(1) << shift) - 1; v1.resize(mask + 1); v1[0].Set(Thigh(1), Thigh(0)); for (size_t i = 1; i < v1.size(); ++i) v1[i] = calc(i, n, ang); v2.resize((nval + mask) / (mask + 1)); v2[0].Set(Thigh(1), Thigh(0)); for (size_t i = 1; i < v2.size(); ++i) v2[i] = calc(i * (mask + 1), n, ang); } cmplx operator[](size_t idx) const { if (2 * idx <= N) { auto x1 = v1[idx & mask], x2 = v2[idx >> shift]; return cmplx(T(x1.r * x2.r - x1.i * x2.i), T(x1.r * x2.i + x1.i * x2.r)); } idx = N - idx; auto x1 = v1[idx & mask], x2 = v2[idx >> shift]; return cmplx(T(x1.r * x2.r - x1.i * x2.i), -T(x1.r * x2.i + x1.i * x2.r)); } }; struct util // hack to avoid duplicate symbols { static POCKETFFT_NOINLINE size_t largest_prime_factor(size_t n) { size_t res = 1; while ((n & 1) == 0) { 
res = 2; n >>= 1; } for (size_t x = 3; x * x <= n; x += 2) while ((n % x) == 0) { res = x; n /= x; } if (n > 1) res = n; return res; } static POCKETFFT_NOINLINE double cost_guess(size_t n) { constexpr double lfp = 1.1; // penalty for non-hardcoded larger factors size_t ni = n; double result = 0.; while ((n & 1) == 0) { result += 2; n >>= 1; } for (size_t x = 3; x * x <= n; x += 2) while ((n % x) == 0) { result += (x <= 5) ? double(x) : lfp * double(x); // penalize larger prime factors n /= x; } if (n > 1) result += (n <= 5) ? double(n) : lfp * double(n); return result * double(ni); } /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) { if (n <= 12) return n; size_t bestfac = 2 * n; for (size_t f11 = 1; f11 < bestfac; f11 *= 11) for (size_t f117 = f11; f117 < bestfac; f117 *= 7) for (size_t f1175 = f117; f1175 < bestfac; f1175 *= 5) { size_t x = f1175; while (x < n) x *= 2; for (;;) { if (x < n) x *= 3; else if (x > n) { if (x < bestfac) bestfac = x; if (x & 1) break; x >>= 1; } else return n; } } return bestfac; } /* returns the smallest composite of 2, 3, 5 which is >= n */ static POCKETFFT_NOINLINE size_t good_size_real(size_t n) { if (n <= 6) return n; size_t bestfac = 2 * n; for (size_t f5 = 1; f5 < bestfac; f5 *= 5) { size_t x = f5; while (x < n) x *= 2; for (;;) { if (x < n) x *= 3; else if (x > n) { if (x < bestfac) bestfac = x; if (x & 1) break; x >>= 1; } else return n; } } return bestfac; } static size_t prod(const shape_t& shape) { size_t res = 1; for (auto sz : shape) res *= sz; return res; } static POCKETFFT_NOINLINE void sanity_check(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, bool inplace) { auto ndim = shape.size(); if (ndim < 1) throw std::runtime_error("ndim must be >= 1"); if ((stride_in.size() != ndim) || (stride_out.size() != ndim)) throw std::runtime_error("stride dimension mismatch"); if (inplace && (stride_in != stride_out)) throw std::runtime_error("stride mismatch"); } static POCKETFFT_NOINLINE void sanity_check(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, bool inplace, const shape_t& axes) { sanity_check(shape, stride_in, stride_out, inplace); auto ndim = shape.size(); shape_t tmp(ndim, 0); for (auto ax : axes) { if (ax >= ndim) throw std::invalid_argument("bad axis number"); if (++tmp[ax] > 1) throw std::invalid_argument("axis specified repeatedly"); } } static POCKETFFT_NOINLINE void sanity_check(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, bool inplace, size_t axis) { sanity_check(shape, stride_in, stride_out, inplace); if (axis >= shape.size()) throw std::invalid_argument("bad axis number"); } #ifdef POCKETFFT_NO_MULTITHREADING static size_t thread_count(size_t /*nthreads*/, const shape_t& /*shape*/, size_t /*axis*/, size_t /*vlen*/) { return 1; } #else static size_t thread_count(size_t nthreads, const shape_t& shape, size_t axis, size_t vlen) { if (nthreads == 1) return 1; size_t size = prod(shape); size_t parallel = size / (shape[axis] * vlen); if (shape[axis] < 1000) parallel /= 4; size_t max_threads = nthreads == 0 ? 
std::thread::hardware_concurrency() : nthreads; return std::max(size_t(1), std::min(parallel, max_threads)); } #endif }; namespace threading { #ifdef POCKETFFT_NO_MULTITHREADING constexpr inline size_t thread_id() { return 0; } constexpr inline size_t num_threads() { return 1; } template void thread_map(size_t /* nthreads */, Func f) { f(); } #else inline size_t& thread_id() { static thread_local size_t thread_id_ = 0; return thread_id_; } inline size_t& num_threads() { static thread_local size_t num_threads_ = 1; return num_threads_; } static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); class latch { std::atomic num_left_; std::mutex mut_; std::condition_variable completed_; using lock_t = std::unique_lock; public: latch(size_t n) : num_left_(n) {} void count_down() { lock_t lock(mut_); if (--num_left_) return; completed_.notify_all(); } void wait() { lock_t lock(mut_); completed_.wait(lock, [this] { return is_ready(); }); } bool is_ready() { return num_left_ == 0; } }; template class concurrent_queue { std::queue q_; std::mutex mut_; std::atomic size_; using lock_t = std::lock_guard; public: void push(T val) { lock_t lock(mut_); ++size_; q_.push(std::move(val)); } bool try_pop(T& val) { if (size_ == 0) return false; lock_t lock(mut_); // Queue might have been emptied while we acquired the lock if (q_.empty()) return false; val = std::move(q_.front()); --size_; q_.pop(); return true; } bool empty() const { return size_ == 0; } }; // C++ allocator with support for over-aligned types template struct aligned_allocator { using value_type = T; template aligned_allocator(const aligned_allocator&) {} aligned_allocator() = default; T* allocate(size_t n) { void* mem = aligned_alloc(alignof(T), n * sizeof(T)); return static_cast(mem); } void deallocate(T* p, size_t /*n*/) { aligned_dealloc(p); } }; class thread_pool { // A reasonable guess, probably close enough for most hardware static constexpr size_t cache_line_size = 64; struct alignas(cache_line_size) worker { std::thread thread; std::condition_variable work_ready; std::mutex mut; std::atomic_flag busy_flag = ATOMIC_FLAG_INIT; std::function work; void worker_main(std::atomic& shutdown_flag, std::atomic& unscheduled_tasks, concurrent_queue>& overflow_work) { using lock_t = std::unique_lock; bool expect_work = true; while (!shutdown_flag || expect_work) { std::function local_work; if (expect_work || unscheduled_tasks == 0) { lock_t lock(mut); // Wait until there is work to be executed work_ready.wait(lock, [&] { return (work || shutdown_flag); }); local_work.swap(work); expect_work = false; } bool marked_busy = false; if (local_work) { marked_busy = true; local_work(); } if (!overflow_work.empty()) { if (!marked_busy && busy_flag.test_and_set()) { expect_work = true; continue; } marked_busy = true; while (overflow_work.try_pop(local_work)) { --unscheduled_tasks; local_work(); } } if (marked_busy) busy_flag.clear(); } } }; concurrent_queue> overflow_work_; std::mutex mut_; std::vector> workers_; std::atomic shutdown_; std::atomic unscheduled_tasks_; using lock_t = std::lock_guard; void create_threads() { lock_t lock(mut_); size_t nthreads = workers_.size(); for (size_t i = 0; i < nthreads; ++i) { try { auto* worker = &workers_[i]; worker->busy_flag.clear(); worker->work = nullptr; worker->thread = std::thread( [worker, this] { worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); }); } catch (...) 
{ shutdown_locked(); throw; } } } void shutdown_locked() { shutdown_ = true; for (auto& worker : workers_) worker.work_ready.notify_all(); for (auto& worker : workers_) if (worker.thread.joinable()) worker.thread.join(); } public: explicit thread_pool(size_t nthreads) : workers_(nthreads) { create_threads(); } thread_pool() : thread_pool(max_threads) {} ~thread_pool() { shutdown(); } void submit(std::function work) { lock_t lock(mut_); if (shutdown_) throw std::runtime_error("Work item submitted after shutdown"); ++unscheduled_tasks_; // First check for any idle workers and wake those for (auto& worker : workers_) if (!worker.busy_flag.test_and_set()) { --unscheduled_tasks_; { lock_t lock(worker.mut); worker.work = std::move(work); } worker.work_ready.notify_one(); return; } // If no workers were idle, push onto the overflow queue for later overflow_work_.push(std::move(work)); } void shutdown() { lock_t lock(mut_); shutdown_locked(); } void restart() { shutdown_ = false; create_threads(); } }; inline thread_pool& get_pool() { static thread_pool pool; #ifdef POCKETFFT_PTHREADS static std::once_flag f; std::call_once(f, [] { pthread_atfork( +[] { get_pool().shutdown(); }, // prepare +[] { get_pool().restart(); }, // parent +[] { get_pool().restart(); } // child ); }); #endif return pool; } /** Map a function f over nthreads */ template void thread_map(size_t nthreads, Func f) { if (nthreads == 0) nthreads = max_threads; if (nthreads == 1) { f(); return; } auto& pool = get_pool(); latch counter(nthreads); std::exception_ptr ex; std::mutex ex_mut; for (size_t i = 0; i < nthreads; ++i) { pool.submit([&f, &counter, &ex, &ex_mut, i, nthreads] { thread_id() = i; num_threads() = nthreads; try { f(); } catch (...) { std::lock_guard lock(ex_mut); ex = std::current_exception(); } counter.count_down(); }); } counter.wait(); if (ex) std::rethrow_exception(ex); } #endif } // namespace threading // // complex FFTPACK transforms // template class cfftp { private: struct fctdata { size_t fct; cmplx*tw, *tws; }; size_t length; arr> mem; std::vector fact; void add_factor(size_t factor) { fact.push_back({factor, nullptr, nullptr}); } template void pass2(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const cmplx* POCKETFFT_RESTRICT wa) const { auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 2 * c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; if (ido == 1) for (size_t k = 0; k < l1; ++k) { CH(0, k, 0) = CC(0, 0, k) + CC(0, 1, k); CH(0, k, 1) = CC(0, 0, k) - CC(0, 1, k); } else for (size_t k = 0; k < l1; ++k) { CH(0, k, 0) = CC(0, 0, k) + CC(0, 1, k); CH(0, k, 1) = CC(0, 0, k) - CC(0, 1, k); for (size_t i = 1; i < ido; ++i) { CH(i, k, 0) = CC(i, 0, k) + CC(i, 1, k); special_mul(CC(i, 0, k) - CC(i, 1, k), WA(0, i), CH(i, k, 1)); } } } #define POCKETFFT_PREP3(idx) \ T t0 = CC(idx, 0, k), t1, t2; \ PM(t1, t2, CC(idx, 1, k), CC(idx, 2, k)); \ CH(idx, k, 0) = t0 + t1; #define POCKETFFT_PARTSTEP3a(u1, u2, twr, twi) \ { \ T ca = t0 + t1 * twr; \ T cb{-t2.i * twi, t2.r * twi}; \ PM(CH(0, k, u1), CH(0, k, u2), ca, cb); \ } #define POCKETFFT_PARTSTEP3b(u1, u2, twr, twi) \ { \ T ca = t0 + t1 * twr; \ T cb{-t2.i * twi, t2.r * twi}; \ special_mul(ca + cb, WA(u1 - 1, i), CH(i, k, u1)); \ special_mul(ca - cb, WA(u2 - 1, i), CH(i, k, u2)); \ } template void pass3(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, 
T* POCKETFFT_RESTRICT ch, const cmplx* POCKETFFT_RESTRICT wa) const { constexpr T0 tw1r = -0.5, tw1i = (fwd ? -1 : 1) * T0(0.8660254037844386467637231707529362L); auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 3 * c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; if (ido == 1) for (size_t k = 0; k < l1; ++k) { POCKETFFT_PREP3(0) POCKETFFT_PARTSTEP3a(1, 2, tw1r, tw1i) } else for (size_t k = 0; k < l1; ++k) { { POCKETFFT_PREP3(0) POCKETFFT_PARTSTEP3a(1, 2, tw1r, tw1i) } for (size_t i = 1; i < ido; ++i) { POCKETFFT_PREP3(i) POCKETFFT_PARTSTEP3b(1, 2, tw1r, tw1i) } } } #undef POCKETFFT_PARTSTEP3b #undef POCKETFFT_PARTSTEP3a #undef POCKETFFT_PREP3 template void pass4(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const cmplx* POCKETFFT_RESTRICT wa) const { auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 4 * c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; if (ido == 1) for (size_t k = 0; k < l1; ++k) { T t1, t2, t3, t4; PM(t2, t1, CC(0, 0, k), CC(0, 2, k)); PM(t3, t4, CC(0, 1, k), CC(0, 3, k)); ROTX90(t4); PM(CH(0, k, 0), CH(0, k, 2), t2, t3); PM(CH(0, k, 1), CH(0, k, 3), t1, t4); } else for (size_t k = 0; k < l1; ++k) { { T t1, t2, t3, t4; PM(t2, t1, CC(0, 0, k), CC(0, 2, k)); PM(t3, t4, CC(0, 1, k), CC(0, 3, k)); ROTX90(t4); PM(CH(0, k, 0), CH(0, k, 2), t2, t3); PM(CH(0, k, 1), CH(0, k, 3), t1, t4); } for (size_t i = 1; i < ido; ++i) { T t1, t2, t3, t4; T cc0 = CC(i, 0, k), cc1 = CC(i, 1, k), cc2 = CC(i, 2, k), cc3 = CC(i, 3, k); PM(t2, t1, cc0, cc2); PM(t3, t4, cc1, cc3); ROTX90(t4); CH(i, k, 0) = t2 + t3; special_mul(t1 + t4, WA(0, i), CH(i, k, 1)); special_mul(t2 - t3, WA(1, i), CH(i, k, 2)); special_mul(t1 - t4, WA(2, i), CH(i, k, 3)); } } } #define POCKETFFT_PREP5(idx) \ T t0 = CC(idx, 0, k), t1, t2, t3, t4; \ PM(t1, t4, CC(idx, 1, k), CC(idx, 4, k)); \ PM(t2, t3, CC(idx, 2, k), CC(idx, 3, k)); \ CH(idx, k, 0).r = t0.r + t1.r + t2.r; \ CH(idx, k, 0).i = t0.i + t1.i + t2.i; #define POCKETFFT_PARTSTEP5a(u1, u2, twar, twbr, twai, twbi) \ { \ T ca, cb; \ ca.r = t0.r + twar * t1.r + twbr * t2.r; \ ca.i = t0.i + twar * t1.i + twbr * t2.i; \ cb.i = twai * t4.r twbi * t3.r; \ cb.r = -(twai * t4.i twbi * t3.i); \ PM(CH(0, k, u1), CH(0, k, u2), ca, cb); \ } #define POCKETFFT_PARTSTEP5b(u1, u2, twar, twbr, twai, twbi) \ { \ T ca, cb, da, db; \ ca.r = t0.r + twar * t1.r + twbr * t2.r; \ ca.i = t0.i + twar * t1.i + twbr * t2.i; \ cb.i = twai * t4.r twbi * t3.r; \ cb.r = -(twai * t4.i twbi * t3.i); \ special_mul(ca + cb, WA(u1 - 1, i), CH(i, k, u1)); \ special_mul(ca - cb, WA(u2 - 1, i), CH(i, k, u2)); \ } template void pass5(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const cmplx* POCKETFFT_RESTRICT wa) const { constexpr T0 tw1r = T0(0.3090169943749474241022934171828191L), tw1i = (fwd ? -1 : 1) * T0(0.9510565162951535721164393333793821L), tw2r = T0(-0.8090169943749474241022934171828191L), tw2i = (fwd ? 
-1 : 1) * T0(0.5877852522924731291687059546390728L); auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 5 * c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; if (ido == 1) for (size_t k = 0; k < l1; ++k) { POCKETFFT_PREP5(0) POCKETFFT_PARTSTEP5a(1, 4, tw1r, tw2r, +tw1i, +tw2i) POCKETFFT_PARTSTEP5a(2, 3, tw2r, tw1r, +tw2i, -tw1i) } else for (size_t k = 0; k < l1; ++k) { { POCKETFFT_PREP5(0) POCKETFFT_PARTSTEP5a(1, 4, tw1r, tw2r, +tw1i, +tw2i) POCKETFFT_PARTSTEP5a(2, 3, tw2r, tw1r, +tw2i, -tw1i) } for (size_t i = 1; i < ido; ++i) { POCKETFFT_PREP5(i) POCKETFFT_PARTSTEP5b(1, 4, tw1r, tw2r, +tw1i, +tw2i) POCKETFFT_PARTSTEP5b(2, 3, tw2r, tw1r, +tw2i, -tw1i) } } } #undef POCKETFFT_PARTSTEP5b #undef POCKETFFT_PARTSTEP5a #undef POCKETFFT_PREP5 #define POCKETFFT_PREP7(idx) \ T t1 = CC(idx, 0, k), t2, t3, t4, t5, t6, t7; \ PM(t2, t7, CC(idx, 1, k), CC(idx, 6, k)); \ PM(t3, t6, CC(idx, 2, k), CC(idx, 5, k)); \ PM(t4, t5, CC(idx, 3, k), CC(idx, 4, k)); \ CH(idx, k, 0).r = t1.r + t2.r + t3.r + t4.r; \ CH(idx, k, 0).i = t1.i + t2.i + t3.i + t4.i; #define POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, out1, out2) \ { \ T ca, cb; \ ca.r = t1.r + x1 * t2.r + x2 * t3.r + x3 * t4.r; \ ca.i = t1.i + x1 * t2.i + x2 * t3.i + x3 * t4.i; \ cb.i = y1 * t7.r y2 * t6.r y3 * t5.r; \ cb.r = -(y1 * t7.i y2 * t6.i y3 * t5.i); \ PM(out1, out2, ca, cb); \ } #define POCKETFFT_PARTSTEP7a(u1, u2, x1, x2, x3, y1, y2, y3) \ POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, CH(0, k, u1), CH(0, k, u2)) #define POCKETFFT_PARTSTEP7(u1, u2, x1, x2, x3, y1, y2, y3) \ { \ T da, db; \ POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, da, db) \ special_mul(da, WA(u1 - 1, i), CH(i, k, u1)); \ special_mul(db, WA(u2 - 1, i), CH(i, k, u2)); \ } template void pass7(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const cmplx* POCKETFFT_RESTRICT wa) const { constexpr T0 tw1r = T0(0.6234898018587335305250048840042398L), tw1i = (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L), tw2r = T0(-0.2225209339563144042889025644967948L), tw2i = (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L), tw3r = T0(-0.9009688679024191262361023195074451L), tw3i = (fwd ? 
-1 : 1) * T0(0.433883739117558120475768332848359L); auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 7 * c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; if (ido == 1) for (size_t k = 0; k < l1; ++k) { POCKETFFT_PREP7(0) POCKETFFT_PARTSTEP7a(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i) POCKETFFT_PARTSTEP7a(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i) POCKETFFT_PARTSTEP7a(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i) } else for (size_t k = 0; k < l1; ++k) { { POCKETFFT_PREP7(0) POCKETFFT_PARTSTEP7a(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i) POCKETFFT_PARTSTEP7a(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i) POCKETFFT_PARTSTEP7a(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i) } for (size_t i = 1; i < ido; ++i) { POCKETFFT_PREP7(i) POCKETFFT_PARTSTEP7(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i) POCKETFFT_PARTSTEP7(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i) POCKETFFT_PARTSTEP7(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i) } } } #undef POCKETFFT_PARTSTEP7 #undef POCKETFFT_PARTSTEP7a0 #undef POCKETFFT_PARTSTEP7a #undef POCKETFFT_PREP7 template void ROTX45(T& a) const { constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L); if (fwd) { auto tmp_ = a.r; a.r = hsqt2 * (a.r + a.i); a.i = hsqt2 * (a.i - tmp_); } else { auto tmp_ = a.r; a.r = hsqt2 * (a.r - a.i); a.i = hsqt2 * (a.i + tmp_); } } template void ROTX135(T& a) const { constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L); if (fwd) { auto tmp_ = a.r; a.r = hsqt2 * (a.i - a.r); a.i = hsqt2 * (-tmp_ - a.i); } else { auto tmp_ = a.r; a.r = hsqt2 * (-a.r - a.i); a.i = hsqt2 * (tmp_ - a.i); } } template void pass8(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const cmplx* POCKETFFT_RESTRICT wa) const { auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 8 * c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; if (ido == 1) for (size_t k = 0; k < l1; ++k) { T a0, a1, a2, a3, a4, a5, a6, a7; PM(a1, a5, CC(0, 1, k), CC(0, 5, k)); PM(a3, a7, CC(0, 3, k), CC(0, 7, k)); PMINPLACE(a1, a3); ROTX90(a3); ROTX90(a7); PMINPLACE(a5, a7); ROTX45(a5); ROTX135(a7); PM(a0, a4, CC(0, 0, k), CC(0, 4, k)); PM(a2, a6, CC(0, 2, k), CC(0, 6, k)); PM(CH(0, k, 0), CH(0, k, 4), a0 + a2, a1); PM(CH(0, k, 2), CH(0, k, 6), a0 - a2, a3); ROTX90(a6); PM(CH(0, k, 1), CH(0, k, 5), a4 + a6, a5); PM(CH(0, k, 3), CH(0, k, 7), a4 - a6, a7); } else for (size_t k = 0; k < l1; ++k) { { T a0, a1, a2, a3, a4, a5, a6, a7; PM(a1, a5, CC(0, 1, k), CC(0, 5, k)); PM(a3, a7, CC(0, 3, k), CC(0, 7, k)); PMINPLACE(a1, a3); ROTX90(a3); ROTX90(a7); PMINPLACE(a5, a7); ROTX45(a5); ROTX135(a7); PM(a0, a4, CC(0, 0, k), CC(0, 4, k)); PM(a2, a6, CC(0, 2, k), CC(0, 6, k)); PM(CH(0, k, 0), CH(0, k, 4), a0 + a2, a1); PM(CH(0, k, 2), CH(0, k, 6), a0 - a2, a3); ROTX90(a6); PM(CH(0, k, 1), CH(0, k, 5), a4 + a6, a5); PM(CH(0, k, 3), CH(0, k, 7), a4 - a6, a7); } for (size_t i = 1; i < ido; ++i) { T a0, a1, a2, a3, a4, a5, a6, a7; PM(a1, a5, CC(i, 1, k), CC(i, 5, k)); PM(a3, a7, CC(i, 3, k), CC(i, 7, k)); ROTX90(a7); PMINPLACE(a1, a3); ROTX90(a3); PMINPLACE(a5, a7); ROTX45(a5); ROTX135(a7); PM(a0, a4, CC(i, 0, k), CC(i, 4, k)); PM(a2, a6, CC(i, 2, k), CC(i, 6, k)); PMINPLACE(a0, a2); CH(i, k, 0) = a0 + a1; special_mul(a0 - 
a1, WA(3, i), CH(i, k, 4)); special_mul(a2 + a3, WA(1, i), CH(i, k, 2)); special_mul(a2 - a3, WA(5, i), CH(i, k, 6)); ROTX90(a6); PMINPLACE(a4, a6); special_mul(a4 + a5, WA(0, i), CH(i, k, 1)); special_mul(a4 - a5, WA(4, i), CH(i, k, 5)); special_mul(a6 + a7, WA(2, i), CH(i, k, 3)); special_mul(a6 - a7, WA(6, i), CH(i, k, 7)); } } } #define POCKETFFT_PREP11(idx) \ T t1 = CC(idx, 0, k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ PM(t2, t11, CC(idx, 1, k), CC(idx, 10, k)); \ PM(t3, t10, CC(idx, 2, k), CC(idx, 9, k)); \ PM(t4, t9, CC(idx, 3, k), CC(idx, 8, k)); \ PM(t5, t8, CC(idx, 4, k), CC(idx, 7, k)); \ PM(t6, t7, CC(idx, 5, k), CC(idx, 6, k)); \ CH(idx, k, 0).r = t1.r + t2.r + t3.r + t4.r + t5.r + t6.r; \ CH(idx, k, 0).i = t1.i + t2.i + t3.i + t4.i + t5.i + t6.i; #define POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, out1, out2) \ { \ T ca = t1 + t2 * x1 + t3 * x2 + t4 * x3 + t5 * x4 + t6 * x5, cb; \ cb.i = y1 * t11.r y2 * t10.r y3 * t9.r y4 * t8.r y5 * t7.r; \ cb.r = -(y1 * t11.i y2 * t10.i y3 * t9.i y4 * t8.i y5 * t7.i); \ PM(out1, out2, ca, cb); \ } #define POCKETFFT_PARTSTEP11a(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5) \ POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, CH(0, k, u1), CH(0, k, u2)) #define POCKETFFT_PARTSTEP11(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5) \ { \ T da, db; \ POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, da, db) \ special_mul(da, WA(u1 - 1, i), CH(i, k, u1)); \ special_mul(db, WA(u2 - 1, i), CH(i, k, u2)); \ } template void pass11(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const cmplx* POCKETFFT_RESTRICT wa) const { constexpr T0 tw1r = T0(0.8412535328311811688618116489193677L), tw1i = (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L), tw2r = T0(0.4154150130018864255292741492296232L), tw2i = (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L), tw3r = T0(-0.1423148382732851404437926686163697L), tw3i = (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L), tw4r = T0(-0.6548607339452850640569250724662936L), tw4i = (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L), tw5r = T0(-0.9594929736144973898903680570663277L), tw5i = (fwd ? 
-1 : 1) * T0(0.2817325568414296977114179153466169L); auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 11 * c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; if (ido == 1) for (size_t k = 0; k < l1; ++k) { POCKETFFT_PREP11(0) POCKETFFT_PARTSTEP11a(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i, +tw2i, +tw3i, +tw4i, +tw5i) POCKETFFT_PARTSTEP11a(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r, +tw2i, +tw4i, -tw5i, -tw3i, -tw1i) POCKETFFT_PARTSTEP11a(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r, +tw3i, -tw5i, -tw2i, +tw1i, +tw4i) POCKETFFT_PARTSTEP11a(4, 7, tw4r, tw3r, tw1r, tw5r, tw2r, +tw4i, -tw3i, +tw1i, +tw5i, -tw2i) POCKETFFT_PARTSTEP11a(5, 6, tw5r, tw1r, tw4r, tw2r, tw3r, +tw5i, -tw1i, +tw4i, -tw2i, +tw3i) } else for (size_t k = 0; k < l1; ++k) { { POCKETFFT_PREP11(0) POCKETFFT_PARTSTEP11a(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i, +tw2i, +tw3i, +tw4i, +tw5i) POCKETFFT_PARTSTEP11a(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r, +tw2i, +tw4i, -tw5i, -tw3i, -tw1i) POCKETFFT_PARTSTEP11a(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r, +tw3i, -tw5i, -tw2i, +tw1i, +tw4i) POCKETFFT_PARTSTEP11a(4, 7, tw4r, tw3r, tw1r, tw5r, tw2r, +tw4i, -tw3i, +tw1i, +tw5i, -tw2i) POCKETFFT_PARTSTEP11a(5, 6, tw5r, tw1r, tw4r, tw2r, tw3r, +tw5i, -tw1i, +tw4i, -tw2i, +tw3i) } for (size_t i = 1; i < ido; ++i) { POCKETFFT_PREP11(i) POCKETFFT_PARTSTEP11(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i, +tw2i, +tw3i, +tw4i, +tw5i) POCKETFFT_PARTSTEP11(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r, +tw2i, +tw4i, -tw5i, -tw3i, -tw1i) POCKETFFT_PARTSTEP11(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r, +tw3i, -tw5i, -tw2i, +tw1i, +tw4i) POCKETFFT_PARTSTEP11(4, 7, tw4r, tw3r, tw1r, tw5r, tw2r, +tw4i, -tw3i, +tw1i, +tw5i, -tw2i) POCKETFFT_PARTSTEP11(5, 6, tw5r, tw1r, tw4r, tw2r, tw3r, +tw5i, -tw1i, +tw4i, -tw2i, +tw3i) } } } #undef POCKETFFT_PARTSTEP11 #undef POCKETFFT_PARTSTEP11a0 #undef POCKETFFT_PARTSTEP11a #undef POCKETFFT_PREP11 template void passg(size_t ido, size_t ip, size_t l1, T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const cmplx* POCKETFFT_RESTRICT wa, const cmplx* POCKETFFT_RESTRICT csarr) const { const size_t cdim = ip; size_t ipph = (ip + 1) / 2; size_t idl1 = ido * l1; auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + cdim * c)]; }; auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& { return cc[a + ido * (b + l1 * c)]; }; auto CX2 = [cc, idl1](size_t a, size_t b) -> T& { return cc[a + idl1 * b]; }; auto CH2 = [ch, idl1](size_t a, size_t b) -> const T& { return ch[a + idl1 * b]; }; arr> wal(ip); wal[0] = cmplx(1., 0.); for (size_t i = 1; i < ip; ++i) wal[i] = cmplx(csarr[i].r, fwd ? 
-csarr[i].i : csarr[i].i); for (size_t k = 0; k < l1; ++k) for (size_t i = 0; i < ido; ++i) CH(i, k, 0) = CC(i, 0, k); for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) for (size_t k = 0; k < l1; ++k) for (size_t i = 0; i < ido; ++i) PM(CH(i, k, j), CH(i, k, jc), CC(i, j, k), CC(i, jc, k)); for (size_t k = 0; k < l1; ++k) for (size_t i = 0; i < ido; ++i) { T tmp = CH(i, k, 0); for (size_t j = 1; j < ipph; ++j) tmp += CH(i, k, j); CX(i, k, 0) = tmp; } for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) { // j=0 for (size_t ik = 0; ik < idl1; ++ik) { CX2(ik, l).r = CH2(ik, 0).r + wal[l].r * CH2(ik, 1).r + wal[2 * l].r * CH2(ik, 2).r; CX2(ik, l).i = CH2(ik, 0).i + wal[l].r * CH2(ik, 1).i + wal[2 * l].r * CH2(ik, 2).i; CX2(ik, lc).r = -wal[l].i * CH2(ik, ip - 1).i - wal[2 * l].i * CH2(ik, ip - 2).i; CX2(ik, lc).i = wal[l].i * CH2(ik, ip - 1).r + wal[2 * l].i * CH2(ik, ip - 2).r; } size_t iwal = 2 * l; size_t j = 3, jc = ip - 3; for (; j < ipph - 1; j += 2, jc -= 2) { iwal += l; if (iwal > ip) iwal -= ip; cmplx xwal = wal[iwal]; iwal += l; if (iwal > ip) iwal -= ip; cmplx xwal2 = wal[iwal]; for (size_t ik = 0; ik < idl1; ++ik) { CX2(ik, l).r += CH2(ik, j).r * xwal.r + CH2(ik, j + 1).r * xwal2.r; CX2(ik, l).i += CH2(ik, j).i * xwal.r + CH2(ik, j + 1).i * xwal2.r; CX2(ik, lc).r -= CH2(ik, jc).i * xwal.i + CH2(ik, jc - 1).i * xwal2.i; CX2(ik, lc).i += CH2(ik, jc).r * xwal.i + CH2(ik, jc - 1).r * xwal2.i; } } for (; j < ipph; ++j, --jc) { iwal += l; if (iwal > ip) iwal -= ip; cmplx xwal = wal[iwal]; for (size_t ik = 0; ik < idl1; ++ik) { CX2(ik, l).r += CH2(ik, j).r * xwal.r; CX2(ik, l).i += CH2(ik, j).i * xwal.r; CX2(ik, lc).r -= CH2(ik, jc).i * xwal.i; CX2(ik, lc).i += CH2(ik, jc).r * xwal.i; } } } // shuffling and twiddling if (ido == 1) for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) for (size_t ik = 0; ik < idl1; ++ik) { T t1 = CX2(ik, j), t2 = CX2(ik, jc); PM(CX2(ik, j), CX2(ik, jc), t1, t2); } else { for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) for (size_t k = 0; k < l1; ++k) { T t1 = CX(0, k, j), t2 = CX(0, k, jc); PM(CX(0, k, j), CX(0, k, jc), t1, t2); for (size_t i = 1; i < ido; ++i) { T x1, x2; PM(x1, x2, CX(i, k, j), CX(i, k, jc)); size_t idij = (j - 1) * (ido - 1) + i - 1; special_mul(x1, wa[idij], CX(i, k, j)); idij = (jc - 1) * (ido - 1) + i - 1; special_mul(x2, wa[idij], CX(i, k, jc)); } } } } template void pass_all(T c[], T0 fct) const { if (length == 1) { c[0] *= fct; return; } size_t l1 = 1; arr ch(length); T *p1 = c, *p2 = ch.data(); for (size_t k1 = 0; k1 < fact.size(); k1++) { size_t ip = fact[k1].fct; size_t l2 = ip * l1; size_t ido = length / l2; if (ip == 4) pass4(ido, l1, p1, p2, fact[k1].tw); else if (ip == 8) pass8(ido, l1, p1, p2, fact[k1].tw); else if (ip == 2) pass2(ido, l1, p1, p2, fact[k1].tw); else if (ip == 3) pass3(ido, l1, p1, p2, fact[k1].tw); else if (ip == 5) pass5(ido, l1, p1, p2, fact[k1].tw); else if (ip == 7) pass7(ido, l1, p1, p2, fact[k1].tw); else if (ip == 11) pass11(ido, l1, p1, p2, fact[k1].tw); else { passg(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws); std::swap(p1, p2); } std::swap(p1, p2); l1 = l2; } if (p1 != c) { if (fct != 1.) for (size_t i = 0; i < length; ++i) c[i] = ch[i] * fct; else std::copy_n(p1, length, c); } else if (fct != 1.) for (size_t i = 0; i < length; ++i) c[i] *= fct; } public: template void exec(T c[], T0 fct, bool fwd) const { fwd ? 
pass_all(c, fct) : pass_all(c, fct); } private: POCKETFFT_NOINLINE void factorize() { size_t len = length; while ((len & 7) == 0) { add_factor(8); len >>= 3; } while ((len & 3) == 0) { add_factor(4); len >>= 2; } if ((len & 1) == 0) { len >>= 1; // factor 2 should be at the front of the factor list add_factor(2); std::swap(fact[0].fct, fact.back().fct); } for (size_t divisor = 3; divisor * divisor <= len; divisor += 2) while ((len % divisor) == 0) { add_factor(divisor); len /= divisor; } if (len > 1) add_factor(len); } size_t twsize() const { size_t twsize = 0, l1 = 1; for (size_t k = 0; k < fact.size(); ++k) { size_t ip = fact[k].fct, ido = length / (l1 * ip); twsize += (ip - 1) * (ido - 1); if (ip > 11) twsize += ip; l1 *= ip; } return twsize; } void comp_twiddle() { sincos_2pibyn twiddle(length); size_t l1 = 1; size_t memofs = 0; for (size_t k = 0; k < fact.size(); ++k) { size_t ip = fact[k].fct, ido = length / (l1 * ip); fact[k].tw = mem.data() + memofs; memofs += (ip - 1) * (ido - 1); for (size_t j = 1; j < ip; ++j) for (size_t i = 1; i < ido; ++i) fact[k].tw[(j - 1) * (ido - 1) + i - 1] = twiddle[j * l1 * i]; if (ip > 11) { fact[k].tws = mem.data() + memofs; memofs += ip; for (size_t j = 0; j < ip; ++j) fact[k].tws[j] = twiddle[j * l1 * ido]; } l1 *= ip; } } public: POCKETFFT_NOINLINE cfftp(size_t length_) : length(length_) { if (length == 0) throw std::runtime_error("zero-length FFT requested"); if (length == 1) return; factorize(); mem.resize(twsize()); comp_twiddle(); } }; // // real-valued FFTPACK transforms // template class rfftp { private: struct fctdata { size_t fct; T0 *tw, *tws; }; size_t length; arr mem; std::vector fact; void add_factor(size_t factor) { fact.push_back({factor, nullptr, nullptr}); } /* (a+ib) = conj(c+id) * (e+if) */ template inline void MULPM(T1& a, T1& b, T2 c, T2 d, T3 e, T3 f) const { a = c * e + d * f; b = c * f - d * e; } template void radf2(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa) const { auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + l1 * c)]; }; auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + 2 * c)]; }; for (size_t k = 0; k < l1; k++) PM(CH(0, 0, k), CH(ido - 1, 1, k), CC(0, k, 0), CC(0, k, 1)); if ((ido & 1) == 0) for (size_t k = 0; k < l1; k++) { CH(0, 1, k) = -CC(ido - 1, k, 1); CH(ido - 1, 0, k) = CC(ido - 1, k, 0); } if (ido <= 2) return; for (size_t k = 0; k < l1; k++) for (size_t i = 2; i < ido; i += 2) { size_t ic = ido - i; T tr2, ti2; MULPM(tr2, ti2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1)); PM(CH(i - 1, 0, k), CH(ic - 1, 1, k), CC(i - 1, k, 0), tr2); PM(CH(i, 0, k), CH(ic, 1, k), ti2, CC(i, k, 0)); } } // a2=a+b; b2=i*(b-a); #define POCKETFFT_REARRANGE(rx, ix, ry, iy) \ { \ auto t1 = rx + ry, t2 = ry - rx, t3 = ix + iy, t4 = ix - iy; \ rx = t1; \ ix = t3; \ ry = t4; \ iy = t2; \ } template void radf3(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa) const { constexpr T0 taur = -0.5, taui = T0(0.8660254037844386467637231707529362L); auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + l1 * c)]; }; auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + 3 * c)]; }; for (size_t k = 0; k < l1; 
k++) { T cr2 = CC(0, k, 1) + CC(0, k, 2); CH(0, 0, k) = CC(0, k, 0) + cr2; CH(0, 2, k) = taui * (CC(0, k, 2) - CC(0, k, 1)); CH(ido - 1, 1, k) = CC(0, k, 0) + taur * cr2; } if (ido == 1) return; for (size_t k = 0; k < l1; k++) for (size_t i = 2; i < ido; i += 2) { size_t ic = ido - i; T di2, di3, dr2, dr3; MULPM(dr2, di2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1)); // d2=conj(WA0)*CC1 MULPM(dr3, di3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2), CC(i, k, 2)); // d3=conj(WA1)*CC2 POCKETFFT_REARRANGE(dr2, di2, dr3, di3); CH(i - 1, 0, k) = CC(i - 1, k, 0) + dr2; // c add CH(i, 0, k) = CC(i, k, 0) + di2; T tr2 = CC(i - 1, k, 0) + taur * dr2; // c add T ti2 = CC(i, k, 0) + taur * di2; T tr3 = taui * dr3; // t3 = taui*i*(d3-d2)? T ti3 = taui * di3; PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr2, tr3); // PM(i) = t2+t3 PM(CH(i, 2, k), CH(ic, 1, k), ti3, ti2); // PM(ic) = conj(t2-t3) } } template void radf4(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa) const { constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L); auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + l1 * c)]; }; auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + 4 * c)]; }; for (size_t k = 0; k < l1; k++) { T tr1, tr2; PM(tr1, CH(0, 2, k), CC(0, k, 3), CC(0, k, 1)); PM(tr2, CH(ido - 1, 1, k), CC(0, k, 0), CC(0, k, 2)); PM(CH(0, 0, k), CH(ido - 1, 3, k), tr2, tr1); } if ((ido & 1) == 0) for (size_t k = 0; k < l1; k++) { T ti1 = -hsqt2 * (CC(ido - 1, k, 1) + CC(ido - 1, k, 3)); T tr1 = hsqt2 * (CC(ido - 1, k, 1) - CC(ido - 1, k, 3)); PM(CH(ido - 1, 0, k), CH(ido - 1, 2, k), CC(ido - 1, k, 0), tr1); PM(CH(0, 3, k), CH(0, 1, k), ti1, CC(ido - 1, k, 2)); } if (ido <= 2) return; for (size_t k = 0; k < l1; k++) for (size_t i = 2; i < ido; i += 2) { size_t ic = ido - i; T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4; MULPM(cr2, ci2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1)); MULPM(cr3, ci3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2), CC(i, k, 2)); MULPM(cr4, ci4, WA(2, i - 2), WA(2, i - 1), CC(i - 1, k, 3), CC(i, k, 3)); PM(tr1, tr4, cr4, cr2); PM(ti1, ti4, ci2, ci4); PM(tr2, tr3, CC(i - 1, k, 0), cr3); PM(ti2, ti3, CC(i, k, 0), ci3); PM(CH(i - 1, 0, k), CH(ic - 1, 3, k), tr2, tr1); PM(CH(i, 0, k), CH(ic, 3, k), ti1, ti2); PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr3, ti4); PM(CH(i, 2, k), CH(ic, 1, k), tr4, ti3); } } template void radf5(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa) const { constexpr T0 tr11 = T0(0.3090169943749474241022934171828191L), ti11 = T0(0.9510565162951535721164393333793821L), tr12 = T0(-0.8090169943749474241022934171828191L), ti12 = T0(0.5877852522924731291687059546390728L); auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + l1 * c)]; }; auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + 5 * c)]; }; for (size_t k = 0; k < l1; k++) { T cr2, cr3, ci4, ci5; PM(cr2, ci5, CC(0, k, 4), CC(0, k, 1)); PM(cr3, ci4, CC(0, k, 3), CC(0, k, 2)); CH(0, 0, k) = CC(0, k, 0) + cr2 + cr3; CH(ido - 1, 1, k) = CC(0, k, 0) + tr11 * cr2 + tr12 * cr3; CH(0, 2, k) = ti11 * ci5 + ti12 * ci4; CH(ido - 1, 3, k) = CC(0, k, 0) + tr12 * cr2 + tr11 * cr3; CH(0, 4, 
k) = ti12 * ci5 - ti11 * ci4; } if (ido == 1) return; for (size_t k = 0; k < l1; ++k) for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) { T di2, di3, di4, di5, dr2, dr3, dr4, dr5; MULPM(dr2, di2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1)); MULPM(dr3, di3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2), CC(i, k, 2)); MULPM(dr4, di4, WA(2, i - 2), WA(2, i - 1), CC(i - 1, k, 3), CC(i, k, 3)); MULPM(dr5, di5, WA(3, i - 2), WA(3, i - 1), CC(i - 1, k, 4), CC(i, k, 4)); POCKETFFT_REARRANGE(dr2, di2, dr5, di5); POCKETFFT_REARRANGE(dr3, di3, dr4, di4); CH(i - 1, 0, k) = CC(i - 1, k, 0) + dr2 + dr3; CH(i, 0, k) = CC(i, k, 0) + di2 + di3; T tr2 = CC(i - 1, k, 0) + tr11 * dr2 + tr12 * dr3; T ti2 = CC(i, k, 0) + tr11 * di2 + tr12 * di3; T tr3 = CC(i - 1, k, 0) + tr12 * dr2 + tr11 * dr3; T ti3 = CC(i, k, 0) + tr12 * di2 + tr11 * di3; T tr5 = ti11 * dr5 + ti12 * dr4; T ti5 = ti11 * di5 + ti12 * di4; T tr4 = ti12 * dr5 - ti11 * dr4; T ti4 = ti12 * di5 - ti11 * di4; PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr2, tr5); PM(CH(i, 2, k), CH(ic, 1, k), ti5, ti2); PM(CH(i - 1, 4, k), CH(ic - 1, 3, k), tr3, tr4); PM(CH(i, 4, k), CH(ic, 3, k), ti4, ti3); } } #undef POCKETFFT_REARRANGE template void radfg(size_t ido, size_t ip, size_t l1, T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa, const T0* POCKETFFT_RESTRICT csarr) const { const size_t cdim = ip; size_t ipph = (ip + 1) / 2; size_t idl1 = ido * l1; auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> T& { return cc[a + ido * (b + cdim * c)]; }; auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> const T& { return ch[a + ido * (b + l1 * c)]; }; auto C1 = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& { return cc[a + ido * (b + l1 * c)]; }; auto C2 = [cc, idl1](size_t a, size_t b) -> T& { return cc[a + idl1 * b]; }; auto CH2 = [ch, idl1](size_t a, size_t b) -> T& { return ch[a + idl1 * b]; }; if (ido > 1) { for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 114 { size_t is = (j - 1) * (ido - 1), is2 = (jc - 1) * (ido - 1); for (size_t k = 0; k < l1; ++k) // 113 { size_t idij = is; size_t idij2 = is2; for (size_t i = 1; i <= ido - 2; i += 2) // 112 { T t1 = C1(i, k, j), t2 = C1(i + 1, k, j), t3 = C1(i, k, jc), t4 = C1(i + 1, k, jc); T x1 = wa[idij] * t1 + wa[idij + 1] * t2, x2 = wa[idij] * t2 - wa[idij + 1] * t1, x3 = wa[idij2] * t3 + wa[idij2 + 1] * t4, x4 = wa[idij2] * t4 - wa[idij2 + 1] * t3; PM(C1(i, k, j), C1(i + 1, k, jc), x3, x1); PM(C1(i + 1, k, j), C1(i, k, jc), x2, x4); idij += 2; idij2 += 2; } } } } for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 123 for (size_t k = 0; k < l1; ++k) // 122 MPINPLACE(C1(0, k, jc), C1(0, k, j)); // everything in C // memset(ch,0,ip*l1*ido*sizeof(double)); for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) // 127 { for (size_t ik = 0; ik < idl1; ++ik) // 124 { CH2(ik, l) = C2(ik, 0) + csarr[2 * l] * C2(ik, 1) + csarr[4 * l] * C2(ik, 2); CH2(ik, lc) = csarr[2 * l + 1] * C2(ik, ip - 1) + csarr[4 * l + 1] * C2(ik, ip - 2); } size_t iang = 2 * l; size_t j = 3, jc = ip - 3; for (; j < ipph - 3; j += 4, jc -= 4) // 126 { iang += l; if (iang >= ip) iang -= ip; T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1]; iang += l; if (iang >= ip) iang -= ip; T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1]; iang += l; if (iang >= ip) iang -= ip; T0 ar3 = csarr[2 * iang], ai3 = csarr[2 * iang + 1]; iang += l; if (iang >= ip) iang -= ip; T0 ar4 = csarr[2 * iang], ai4 = csarr[2 * iang + 1]; for (size_t ik = 0; ik < idl1; ++ik) // 125 { CH2(ik, l) += ar1 * 
C2(ik, j) + ar2 * C2(ik, j + 1) + ar3 * C2(ik, j + 2) + ar4 * C2(ik, j + 3); CH2(ik, lc) += ai1 * C2(ik, jc) + ai2 * C2(ik, jc - 1) + ai3 * C2(ik, jc - 2) + ai4 * C2(ik, jc - 3); } } for (; j < ipph - 1; j += 2, jc -= 2) // 126 { iang += l; if (iang >= ip) iang -= ip; T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1]; iang += l; if (iang >= ip) iang -= ip; T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1]; for (size_t ik = 0; ik < idl1; ++ik) // 125 { CH2(ik, l) += ar1 * C2(ik, j) + ar2 * C2(ik, j + 1); CH2(ik, lc) += ai1 * C2(ik, jc) + ai2 * C2(ik, jc - 1); } } for (; j < ipph; ++j, --jc) // 126 { iang += l; if (iang >= ip) iang -= ip; T0 ar = csarr[2 * iang], ai = csarr[2 * iang + 1]; for (size_t ik = 0; ik < idl1; ++ik) // 125 { CH2(ik, l) += ar * C2(ik, j); CH2(ik, lc) += ai * C2(ik, jc); } } } for (size_t ik = 0; ik < idl1; ++ik) // 101 CH2(ik, 0) = C2(ik, 0); for (size_t j = 1; j < ipph; ++j) // 129 for (size_t ik = 0; ik < idl1; ++ik) // 128 CH2(ik, 0) += C2(ik, j); // everything in CH at this point! // memset(cc,0,ip*l1*ido*sizeof(double)); for (size_t k = 0; k < l1; ++k) // 131 for (size_t i = 0; i < ido; ++i) // 130 CC(i, 0, k) = CH(i, k, 0); for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 137 { size_t j2 = 2 * j - 1; for (size_t k = 0; k < l1; ++k) // 136 { CC(ido - 1, j2, k) = CH(0, k, j); CC(0, j2 + 1, k) = CH(0, k, jc); } } if (ido == 1) return; for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 140 { size_t j2 = 2 * j - 1; for (size_t k = 0; k < l1; ++k) // 139 for (size_t i = 1, ic = ido - i - 2; i <= ido - 2; i += 2, ic -= 2) // 138 { CC(i, j2 + 1, k) = CH(i, k, j) + CH(i, k, jc); CC(ic, j2, k) = CH(i, k, j) - CH(i, k, jc); CC(i + 1, j2 + 1, k) = CH(i + 1, k, j) + CH(i + 1, k, jc); CC(ic + 1, j2, k) = CH(i + 1, k, jc) - CH(i + 1, k, j); } } } template void radb2(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa) const { auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 2 * c)]; }; auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; for (size_t k = 0; k < l1; k++) PM(CH(0, k, 0), CH(0, k, 1), CC(0, 0, k), CC(ido - 1, 1, k)); if ((ido & 1) == 0) for (size_t k = 0; k < l1; k++) { CH(ido - 1, k, 0) = 2 * CC(ido - 1, 0, k); CH(ido - 1, k, 1) = -2 * CC(0, 1, k); } if (ido <= 2) return; for (size_t k = 0; k < l1; ++k) for (size_t i = 2; i < ido; i += 2) { size_t ic = ido - i; T ti2, tr2; PM(CH(i - 1, k, 0), tr2, CC(i - 1, 0, k), CC(ic - 1, 1, k)); PM(ti2, CH(i, k, 0), CC(i, 0, k), CC(ic, 1, k)); MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), ti2, tr2); } } template void radb3(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa) const { constexpr T0 taur = -0.5, taui = T0(0.8660254037844386467637231707529362L); auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 3 * c)]; }; auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; for (size_t k = 0; k < l1; k++) { T tr2 = 2 * CC(ido - 1, 1, k); T cr2 = CC(0, 0, k) + taur * tr2; CH(0, k, 0) = CC(0, 0, k) + tr2; T ci3 = 2 * taui * CC(0, 2, k); PM(CH(0, k, 2), CH(0, k, 1), cr2, ci3); } if (ido == 1) return; for (size_t k = 0; k < l1; k++) for (size_t i = 2, ic = ido - 
2; i < ido; i += 2, ic -= 2) { T tr2 = CC(i - 1, 2, k) + CC(ic - 1, 1, k); // t2=CC(I) + conj(CC(ic)) T ti2 = CC(i, 2, k) - CC(ic, 1, k); T cr2 = CC(i - 1, 0, k) + taur * tr2; // c2=CC +taur*t2 T ci2 = CC(i, 0, k) + taur * ti2; CH(i - 1, k, 0) = CC(i - 1, 0, k) + tr2; // CH=CC+t2 CH(i, k, 0) = CC(i, 0, k) + ti2; T cr3 = taui * (CC(i - 1, 2, k) - CC(ic - 1, 1, k)); // c3=taui*(CC(i)-conj(CC(ic))) T ci3 = taui * (CC(i, 2, k) + CC(ic, 1, k)); T di2, di3, dr2, dr3; PM(dr3, dr2, cr2, ci3); // d2= (cr2-ci3, ci2+cr3) = c2+i*c3 PM(di2, di3, ci2, cr3); // d3= (cr2+ci3, ci2-cr3) = c2-i*c3 MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), di2, dr2); // ch = WA*d2 MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), di3, dr3); } } template void radb4(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa) const { constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L); auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 4 * c)]; }; auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; for (size_t k = 0; k < l1; k++) { T tr1, tr2; PM(tr2, tr1, CC(0, 0, k), CC(ido - 1, 3, k)); T tr3 = 2 * CC(ido - 1, 1, k); T tr4 = 2 * CC(0, 2, k); PM(CH(0, k, 0), CH(0, k, 2), tr2, tr3); PM(CH(0, k, 3), CH(0, k, 1), tr1, tr4); } if ((ido & 1) == 0) for (size_t k = 0; k < l1; k++) { T tr1, tr2, ti1, ti2; PM(ti1, ti2, CC(0, 3, k), CC(0, 1, k)); PM(tr2, tr1, CC(ido - 1, 0, k), CC(ido - 1, 2, k)); CH(ido - 1, k, 0) = tr2 + tr2; CH(ido - 1, k, 1) = sqrt2 * (tr1 - ti1); CH(ido - 1, k, 2) = ti2 + ti2; CH(ido - 1, k, 3) = -sqrt2 * (tr1 + ti1); } if (ido <= 2) return; for (size_t k = 0; k < l1; ++k) for (size_t i = 2; i < ido; i += 2) { T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4; size_t ic = ido - i; PM(tr2, tr1, CC(i - 1, 0, k), CC(ic - 1, 3, k)); PM(ti1, ti2, CC(i, 0, k), CC(ic, 3, k)); PM(tr4, ti3, CC(i, 2, k), CC(ic, 1, k)); PM(tr3, ti4, CC(i - 1, 2, k), CC(ic - 1, 1, k)); PM(CH(i - 1, k, 0), cr3, tr2, tr3); PM(CH(i, k, 0), ci3, ti2, ti3); PM(cr4, cr2, tr1, tr4); PM(ci2, ci4, ti1, ti4); MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), ci2, cr2); MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), ci3, cr3); MULPM(CH(i, k, 3), CH(i - 1, k, 3), WA(2, i - 2), WA(2, i - 1), ci4, cr4); } } template void radb5(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa) const { constexpr T0 tr11 = T0(0.3090169943749474241022934171828191L), ti11 = T0(0.9510565162951535721164393333793821L), tr12 = T0(-0.8090169943749474241022934171828191L), ti12 = T0(0.5877852522924731291687059546390728L); auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + 5 * c)]; }; auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; for (size_t k = 0; k < l1; k++) { T ti5 = CC(0, 2, k) + CC(0, 2, k); T ti4 = CC(0, 4, k) + CC(0, 4, k); T tr2 = CC(ido - 1, 1, k) + CC(ido - 1, 1, k); T tr3 = CC(ido - 1, 3, k) + CC(ido - 1, 3, k); CH(0, k, 0) = CC(0, 0, k) + tr2 + tr3; T cr2 = CC(0, 0, k) + tr11 * tr2 + tr12 * tr3; T cr3 = CC(0, 0, k) + tr12 * tr2 + tr11 * tr3; T ci4, ci5; MULPM(ci5, ci4, ti5, ti4, ti11, ti12); PM(CH(0, k, 4), CH(0, k, 1), cr2, ci5); PM(CH(0, 
k, 3), CH(0, k, 2), cr3, ci4); } if (ido == 1) return; for (size_t k = 0; k < l1; ++k) for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) { T tr2, tr3, tr4, tr5, ti2, ti3, ti4, ti5; PM(tr2, tr5, CC(i - 1, 2, k), CC(ic - 1, 1, k)); PM(ti5, ti2, CC(i, 2, k), CC(ic, 1, k)); PM(tr3, tr4, CC(i - 1, 4, k), CC(ic - 1, 3, k)); PM(ti4, ti3, CC(i, 4, k), CC(ic, 3, k)); CH(i - 1, k, 0) = CC(i - 1, 0, k) + tr2 + tr3; CH(i, k, 0) = CC(i, 0, k) + ti2 + ti3; T cr2 = CC(i - 1, 0, k) + tr11 * tr2 + tr12 * tr3; T ci2 = CC(i, 0, k) + tr11 * ti2 + tr12 * ti3; T cr3 = CC(i - 1, 0, k) + tr12 * tr2 + tr11 * tr3; T ci3 = CC(i, 0, k) + tr12 * ti2 + tr11 * ti3; T ci4, ci5, cr5, cr4; MULPM(cr5, cr4, tr5, tr4, ti11, ti12); MULPM(ci5, ci4, ti5, ti4, ti11, ti12); T dr2, dr3, dr4, dr5, di2, di3, di4, di5; PM(dr4, dr3, cr3, ci4); PM(di3, di4, ci3, cr4); PM(dr5, dr2, cr2, ci5); PM(di2, di5, ci2, cr5); MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), di2, dr2); MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), di3, dr3); MULPM(CH(i, k, 3), CH(i - 1, k, 3), WA(2, i - 2), WA(2, i - 1), di4, dr4); MULPM(CH(i, k, 4), CH(i - 1, k, 4), WA(3, i - 2), WA(3, i - 1), di5, dr5); } } template void radbg(size_t ido, size_t ip, size_t l1, T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch, const T0* POCKETFFT_RESTRICT wa, const T0* POCKETFFT_RESTRICT csarr) const { const size_t cdim = ip; size_t ipph = (ip + 1) / 2; size_t idl1 = ido * l1; auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + cdim * c)]; }; auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + l1 * c)]; }; auto C1 = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& { return cc[a + ido * (b + l1 * c)]; }; auto C2 = [cc, idl1](size_t a, size_t b) -> T& { return cc[a + idl1 * b]; }; auto CH2 = [ch, idl1](size_t a, size_t b) -> T& { return ch[a + idl1 * b]; }; for (size_t k = 0; k < l1; ++k) // 102 for (size_t i = 0; i < ido; ++i) // 101 CH(i, k, 0) = CC(i, 0, k); for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 108 { size_t j2 = 2 * j - 1; for (size_t k = 0; k < l1; ++k) { CH(0, k, j) = 2 * CC(ido - 1, j2, k); CH(0, k, jc) = 2 * CC(0, j2 + 1, k); } } if (ido != 1) { for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 111 { size_t j2 = 2 * j - 1; for (size_t k = 0; k < l1; ++k) for (size_t i = 1, ic = ido - i - 2; i <= ido - 2; i += 2, ic -= 2) // 109 { CH(i, k, j) = CC(i, j2 + 1, k) + CC(ic, j2, k); CH(i, k, jc) = CC(i, j2 + 1, k) - CC(ic, j2, k); CH(i + 1, k, j) = CC(i + 1, j2 + 1, k) - CC(ic + 1, j2, k); CH(i + 1, k, jc) = CC(i + 1, j2 + 1, k) + CC(ic + 1, j2, k); } } } for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) { for (size_t ik = 0; ik < idl1; ++ik) { C2(ik, l) = CH2(ik, 0) + csarr[2 * l] * CH2(ik, 1) + csarr[4 * l] * CH2(ik, 2); C2(ik, lc) = csarr[2 * l + 1] * CH2(ik, ip - 1) + csarr[4 * l + 1] * CH2(ik, ip - 2); } size_t iang = 2 * l; size_t j = 3, jc = ip - 3; for (; j < ipph - 3; j += 4, jc -= 4) { iang += l; if (iang > ip) iang -= ip; T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1]; iang += l; if (iang > ip) iang -= ip; T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1]; iang += l; if (iang > ip) iang -= ip; T0 ar3 = csarr[2 * iang], ai3 = csarr[2 * iang + 1]; iang += l; if (iang > ip) iang -= ip; T0 ar4 = csarr[2 * iang], ai4 = csarr[2 * iang + 1]; for (size_t ik = 0; ik < idl1; ++ik) { C2(ik, l) += ar1 * CH2(ik, j) + ar2 * CH2(ik, j + 1) + ar3 * CH2(ik, j + 2) + ar4 * CH2(ik, j + 3); C2(ik, lc) += ai1 * CH2(ik, jc) + ai2 * 
CH2(ik, jc - 1) + ai3 * CH2(ik, jc - 2) + ai4 * CH2(ik, jc - 3); } } for (; j < ipph - 1; j += 2, jc -= 2) { iang += l; if (iang > ip) iang -= ip; T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1]; iang += l; if (iang > ip) iang -= ip; T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1]; for (size_t ik = 0; ik < idl1; ++ik) { C2(ik, l) += ar1 * CH2(ik, j) + ar2 * CH2(ik, j + 1); C2(ik, lc) += ai1 * CH2(ik, jc) + ai2 * CH2(ik, jc - 1); } } for (; j < ipph; ++j, --jc) { iang += l; if (iang > ip) iang -= ip; T0 war = csarr[2 * iang], wai = csarr[2 * iang + 1]; for (size_t ik = 0; ik < idl1; ++ik) { C2(ik, l) += war * CH2(ik, j); C2(ik, lc) += wai * CH2(ik, jc); } } } for (size_t j = 1; j < ipph; ++j) for (size_t ik = 0; ik < idl1; ++ik) CH2(ik, 0) += CH2(ik, j); for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 124 for (size_t k = 0; k < l1; ++k) PM(CH(0, k, jc), CH(0, k, j), C1(0, k, j), C1(0, k, jc)); if (ido == 1) return; for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 127 for (size_t k = 0; k < l1; ++k) for (size_t i = 1; i <= ido - 2; i += 2) { CH(i, k, j) = C1(i, k, j) - C1(i + 1, k, jc); CH(i, k, jc) = C1(i, k, j) + C1(i + 1, k, jc); CH(i + 1, k, j) = C1(i + 1, k, j) + C1(i, k, jc); CH(i + 1, k, jc) = C1(i + 1, k, j) - C1(i, k, jc); } // All in CH for (size_t j = 1; j < ip; ++j) { size_t is = (j - 1) * (ido - 1); for (size_t k = 0; k < l1; ++k) { size_t idij = is; for (size_t i = 1; i <= ido - 2; i += 2) { T t1 = CH(i, k, j), t2 = CH(i + 1, k, j); CH(i, k, j) = wa[idij] * t1 - wa[idij + 1] * t2; CH(i + 1, k, j) = wa[idij] * t2 + wa[idij + 1] * t1; idij += 2; } } } } template void copy_and_norm(T* c, T* p1, T0 fct) const { if (p1 != c) { if (fct != 1.) for (size_t i = 0; i < length; ++i) c[i] = fct * p1[i]; else std::copy_n(p1, length, c); } else if (fct != 1.) 
for (size_t i = 0; i < length; ++i) c[i] *= fct; } public: template void exec(T c[], T0 fct, bool r2hc) const { if (length == 1) { c[0] *= fct; return; } size_t nf = fact.size(); arr ch(length); T *p1 = c, *p2 = ch.data(); if (r2hc) for (size_t k1 = 0, l1 = length; k1 < nf; ++k1) { size_t k = nf - k1 - 1; size_t ip = fact[k].fct; size_t ido = length / l1; l1 /= ip; if (ip == 4) radf4(ido, l1, p1, p2, fact[k].tw); else if (ip == 2) radf2(ido, l1, p1, p2, fact[k].tw); else if (ip == 3) radf3(ido, l1, p1, p2, fact[k].tw); else if (ip == 5) radf5(ido, l1, p1, p2, fact[k].tw); else { radfg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws); std::swap(p1, p2); } std::swap(p1, p2); } else for (size_t k = 0, l1 = 1; k < nf; k++) { size_t ip = fact[k].fct, ido = length / (ip * l1); if (ip == 4) radb4(ido, l1, p1, p2, fact[k].tw); else if (ip == 2) radb2(ido, l1, p1, p2, fact[k].tw); else if (ip == 3) radb3(ido, l1, p1, p2, fact[k].tw); else if (ip == 5) radb5(ido, l1, p1, p2, fact[k].tw); else radbg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws); std::swap(p1, p2); l1 *= ip; } copy_and_norm(c, p1, fct); } private: void factorize() { size_t len = length; while ((len % 4) == 0) { add_factor(4); len >>= 2; } if ((len % 2) == 0) { len >>= 1; // factor 2 should be at the front of the factor list add_factor(2); std::swap(fact[0].fct, fact.back().fct); } for (size_t divisor = 3; divisor * divisor <= len; divisor += 2) while ((len % divisor) == 0) { add_factor(divisor); len /= divisor; } if (len > 1) add_factor(len); } size_t twsize() const { size_t twsz = 0, l1 = 1; for (size_t k = 0; k < fact.size(); ++k) { size_t ip = fact[k].fct, ido = length / (l1 * ip); twsz += (ip - 1) * (ido - 1); if (ip > 5) twsz += 2 * ip; l1 *= ip; } return twsz; } void comp_twiddle() { sincos_2pibyn twid(length); size_t l1 = 1; T0* ptr = mem.data(); for (size_t k = 0; k < fact.size(); ++k) { size_t ip = fact[k].fct, ido = length / (l1 * ip); if (k < fact.size() - 1) // last factor doesn't need twiddles { fact[k].tw = ptr; ptr += (ip - 1) * (ido - 1); for (size_t j = 1; j < ip; ++j) for (size_t i = 1; i <= (ido - 1) / 2; ++i) { fact[k].tw[(j - 1) * (ido - 1) + 2 * i - 2] = twid[j * l1 * i].r; fact[k].tw[(j - 1) * (ido - 1) + 2 * i - 1] = twid[j * l1 * i].i; } } if (ip > 5) // special factors required by *g functions { fact[k].tws = ptr; ptr += 2 * ip; fact[k].tws[0] = 1.; fact[k].tws[1] = 0.; for (size_t i = 2, ic = 2 * ip - 2; i <= ic; i += 2, ic -= 2) { fact[k].tws[i] = twid[i / 2 * (length / ip)].r; fact[k].tws[i + 1] = twid[i / 2 * (length / ip)].i; fact[k].tws[ic] = twid[i / 2 * (length / ip)].r; fact[k].tws[ic + 1] = -twid[i / 2 * (length / ip)].i; } } l1 *= ip; } } public: POCKETFFT_NOINLINE rfftp(size_t length_) : length(length_) { if (length == 0) throw std::runtime_error("zero-length FFT requested"); if (length == 1) return; factorize(); mem.resize(twsize()); comp_twiddle(); } }; // // complex Bluestein transforms // template class fftblue { private: size_t n, n2; cfftp plan; arr> mem; cmplx*bk, *bkf; template void fft(cmplx c[], T0 fct) const { arr> akf(n2); /* initialize a_k and FFT it */ for (size_t m = 0; m < n; ++m) special_mul(c[m], bk[m], akf[m]); auto zero = akf[0] * T0(0); for (size_t m = n; m < n2; ++m) akf[m] = zero; plan.exec(akf.data(), 1., true); /* do the convolution */ akf[0] = akf[0].template special_mul(bkf[0]); for (size_t m = 1; m < (n2 + 1) / 2; ++m) { akf[m] = akf[m].template special_mul(bkf[m]); akf[n2 - m] = akf[n2 - m].template special_mul(bkf[m]); } if ((n2 & 1) == 0) akf[n2 / 2] = akf[n2 / 
2].template special_mul(bkf[n2 / 2]); /* inverse FFT */ plan.exec(akf.data(), 1., false); /* multiply by b_k */ for (size_t m = 0; m < n; ++m) c[m] = akf[m].template special_mul(bk[m]) * fct; } public: POCKETFFT_NOINLINE fftblue(size_t length) : n(length), n2(util::good_size_cmplx(n * 2 - 1)), plan(n2), mem(n + n2 / 2 + 1), bk(mem.data()), bkf(mem.data() + n) { /* initialize b_k */ sincos_2pibyn tmp(2 * n); bk[0].Set(1, 0); size_t coeff = 0; for (size_t m = 1; m < n; ++m) { coeff += 2 * m - 1; if (coeff >= 2 * n) coeff -= 2 * n; bk[m] = tmp[coeff]; } /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */ arr> tbkf(n2); T0 xn2 = T0(1) / T0(n2); tbkf[0] = bk[0] * xn2; for (size_t m = 1; m < n; ++m) tbkf[m] = tbkf[n2 - m] = bk[m] * xn2; for (size_t m = n; m <= (n2 - n); ++m) tbkf[m].Set(0., 0.); plan.exec(tbkf.data(), 1., true); for (size_t i = 0; i < n2 / 2 + 1; ++i) bkf[i] = tbkf[i]; } template void exec(cmplx c[], T0 fct, bool fwd) const { fwd ? fft(c, fct) : fft(c, fct); } template void exec_r(T c[], T0 fct, bool fwd) { arr> tmp(n); if (fwd) { auto zero = T0(0) * c[0]; for (size_t m = 0; m < n; ++m) tmp[m].Set(c[m], zero); fft(tmp.data(), fct); c[0] = tmp[0].r; std::copy_n(&tmp[1].r, n - 1, &c[1]); } else { tmp[0].Set(c[0], c[0] * 0); std::copy_n(c + 1, n - 1, &tmp[1].r); if ((n & 1) == 0) tmp[n / 2].i = T0(0) * c[0]; for (size_t m = 1; 2 * m < n; ++m) tmp[n - m].Set(tmp[m].r, -tmp[m].i); fft(tmp.data(), fct); for (size_t m = 0; m < n; ++m) c[m] = tmp[m].r; } } }; // // flexible (FFTPACK/Bluestein) complex 1D transform // template class pocketfft_c { private: std::unique_ptr> packplan; std::unique_ptr> blueplan; size_t len; public: POCKETFFT_NOINLINE pocketfft_c(size_t length) : len(length) { if (length == 0) throw std::runtime_error("zero-length FFT requested"); size_t tmp = (length < 50) ? 0 : util::largest_prime_factor(length); if (tmp * tmp <= length) { packplan = std::unique_ptr>(new cfftp(length)); return; } double comp1 = util::cost_guess(length); double comp2 = 2 * util::cost_guess(util::good_size_cmplx(2 * length - 1)); comp2 *= 1.5; /* fudge factor that appears to give good overall performance */ if (comp2 < comp1) // use Bluestein blueplan = std::unique_ptr>(new fftblue(length)); else packplan = std::unique_ptr>(new cfftp(length)); } template POCKETFFT_NOINLINE void exec(cmplx c[], T0 fct, bool fwd) const { packplan ? packplan->exec(c, fct, fwd) : blueplan->exec(c, fct, fwd); } size_t length() const { return len; } }; // // flexible (FFTPACK/Bluestein) real-valued 1D transform // template class pocketfft_r { private: std::unique_ptr> packplan; std::unique_ptr> blueplan; size_t len; public: POCKETFFT_NOINLINE pocketfft_r(size_t length) : len(length) { if (length == 0) throw std::runtime_error("zero-length FFT requested"); size_t tmp = (length < 50) ? 0 : util::largest_prime_factor(length); if (tmp * tmp <= length) { packplan = std::unique_ptr>(new rfftp(length)); return; } double comp1 = 0.5 * util::cost_guess(length); double comp2 = 2 * util::cost_guess(util::good_size_cmplx(2 * length - 1)); comp2 *= 1.5; /* fudge factor that appears to give good overall performance */ if (comp2 < comp1) // use Bluestein blueplan = std::unique_ptr>(new fftblue(length)); else packplan = std::unique_ptr>(new rfftp(length)); } template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const { packplan ? 
packplan->exec(c, fct, fwd) : blueplan->exec_r(c, fct, fwd); } size_t length() const { return len; } }; // // sine/cosine transforms // template class T_dct1 { private: pocketfft_r fftplan; public: POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2 * (length - 1)) {} template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int /*type*/, bool /*cosine*/) const { constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L); size_t N = fftplan.length(), n = N / 2 + 1; if (ortho) { c[0] *= sqrt2; c[n - 1] *= sqrt2; } arr tmp(N); tmp[0] = c[0]; for (size_t i = 1; i < n; ++i) tmp[i] = tmp[N - i] = c[i]; fftplan.exec(tmp.data(), fct, true); c[0] = tmp[0]; for (size_t i = 1; i < n; ++i) c[i] = tmp[2 * i - 1]; if (ortho) { c[0] *= sqrt2 * T0(0.5); c[n - 1] *= sqrt2 * T0(0.5); } } size_t length() const { return fftplan.length() / 2 + 1; } }; template class T_dst1 { private: pocketfft_r fftplan; public: POCKETFFT_NOINLINE T_dst1(size_t length) : fftplan(2 * (length + 1)) {} template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool /*cosine*/) const { size_t N = fftplan.length(), n = N / 2 - 1; arr tmp(N); tmp[0] = tmp[n + 1] = c[0] * 0; for (size_t i = 0; i < n; ++i) { tmp[i + 1] = c[i]; tmp[N - 1 - i] = -c[i]; } fftplan.exec(tmp.data(), fct, true); for (size_t i = 0; i < n; ++i) c[i] = -tmp[2 * i + 2]; } size_t length() const { return fftplan.length() / 2 - 1; } }; template class T_dcst23 { private: pocketfft_r fftplan; std::vector twiddle; public: POCKETFFT_NOINLINE T_dcst23(size_t length) : fftplan(length), twiddle(length) { sincos_2pibyn tw(4 * length); for (size_t i = 0; i < length; ++i) twiddle[i] = tw[i + 1].r; } template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int type, bool cosine) const { constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L); size_t N = length(); size_t NS2 = (N + 1) / 2; if (type == 2) { if (!cosine) for (size_t k = 1; k < N; k += 2) c[k] = -c[k]; c[0] *= 2; if ((N & 1) == 0) c[N - 1] *= 2; for (size_t k = 1; k < N - 1; k += 2) MPINPLACE(c[k + 1], c[k]); fftplan.exec(c, fct, false); for (size_t k = 1, kc = N - 1; k < NS2; ++k, --kc) { T t1 = twiddle[k - 1] * c[kc] + twiddle[kc - 1] * c[k]; T t2 = twiddle[k - 1] * c[k] - twiddle[kc - 1] * c[kc]; c[k] = T0(0.5) * (t1 + t2); c[kc] = T0(0.5) * (t1 - t2); } if ((N & 1) == 0) c[NS2] *= twiddle[NS2 - 1]; if (!cosine) for (size_t k = 0, kc = N - 1; k < kc; ++k, --kc) std::swap(c[k], c[kc]); if (ortho) c[0] *= sqrt2 * T0(0.5); } else { if (ortho) c[0] *= sqrt2; if (!cosine) for (size_t k = 0, kc = N - 1; k < NS2; ++k, --kc) std::swap(c[k], c[kc]); for (size_t k = 1, kc = N - 1; k < NS2; ++k, --kc) { T t1 = c[k] + c[kc], t2 = c[k] - c[kc]; c[k] = twiddle[k - 1] * t2 + twiddle[kc - 1] * t1; c[kc] = twiddle[k - 1] * t1 - twiddle[kc - 1] * t2; } if ((N & 1) == 0) c[NS2] *= 2 * twiddle[NS2 - 1]; fftplan.exec(c, fct, true); for (size_t k = 1; k < N - 1; k += 2) MPINPLACE(c[k], c[k + 1]); if (!cosine) for (size_t k = 1; k < N; k += 2) c[k] = -c[k]; } } size_t length() const { return fftplan.length(); } }; template class T_dcst4 { private: size_t N; std::unique_ptr> fft; std::unique_ptr> rfft; arr> C2; public: POCKETFFT_NOINLINE T_dcst4(size_t length) : N(length), fft((N & 1) ? nullptr : new pocketfft_c(N / 2)), rfft((N & 1) ? new pocketfft_r(N) : nullptr), C2((N & 1) ? 
0 : N / 2) { if ((N & 1) == 0) { sincos_2pibyn tw(16 * N); for (size_t i = 0; i < N / 2; ++i) C2[i] = conj(tw[8 * i + 1]); } } template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool cosine) const { size_t n2 = N / 2; if (!cosine) for (size_t k = 0, kc = N - 1; k < n2; ++k, --kc) std::swap(c[k], c[kc]); if (N & 1) { // The following code is derived from the FFTW3 function apply_re11() // and is released under the 3-clause BSD license with friendly // permission of Matteo Frigo and Steven G. Johnson. arr y(N); { size_t i = 0, m = n2; for (; m < N; ++i, m += 4) y[i] = c[m]; for (; m < 2 * N; ++i, m += 4) y[i] = -c[2 * N - m - 1]; for (; m < 3 * N; ++i, m += 4) y[i] = -c[m - 2 * N]; for (; m < 4 * N; ++i, m += 4) y[i] = c[4 * N - m - 1]; for (; i < N; ++i, m += 4) y[i] = c[m - 4 * N]; } rfft->exec(y.data(), fct, true); { auto SGN = [](size_t i) { constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L); return (i & 2) ? -sqrt2 : sqrt2; }; c[n2] = y[0] * SGN(n2 + 1); size_t i = 0, i1 = 1, k = 1; for (; k < n2; ++i, ++i1, k += 2) { c[i] = y[2 * k - 1] * SGN(i1) + y[2 * k] * SGN(i); c[N - i1] = y[2 * k - 1] * SGN(N - i) - y[2 * k] * SGN(N - i1); c[n2 - i1] = y[2 * k + 1] * SGN(n2 - i) - y[2 * k + 2] * SGN(n2 - i1); c[n2 + i1] = y[2 * k + 1] * SGN(n2 + i + 2) + y[2 * k + 2] * SGN(n2 + i1); } if (k == n2) { c[i] = y[2 * k - 1] * SGN(i + 1) + y[2 * k] * SGN(i); c[N - i1] = y[2 * k - 1] * SGN(i + 2) + y[2 * k] * SGN(i1); } } // FFTW-derived code ends here } else { // even length algorithm from // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ arr> y(n2); for (size_t i = 0; i < n2; ++i) { y[i].Set(c[2 * i], c[N - 1 - 2 * i]); y[i] *= C2[i]; } fft->exec(y.data(), fct, true); for (size_t i = 0, ic = n2 - 1; i < n2; ++i, --ic) { c[2 * i] = 2 * (y[i].r * C2[i].r - y[i].i * C2[i].i); c[2 * i + 1] = -2 * (y[ic].i * C2[ic].r + y[ic].r * C2[ic].i); } } if (!cosine) for (size_t k = 1; k < N; k += 2) c[k] = -c[k]; } size_t length() const { return N; } }; // // multi-D infrastructure // template std::shared_ptr get_plan(size_t length) { #if POCKETFFT_CACHE_SIZE == 0 return std::make_shared(length); #else constexpr size_t nmax = POCKETFFT_CACHE_SIZE; static std::array, nmax> cache; static std::array last_access{{0}}; static size_t access_counter = 0; static std::mutex mut; auto find_in_cache = [&]() -> std::shared_ptr { for (size_t i = 0; i < nmax; ++i) if (cache[i] && (cache[i]->length() == length)) { // no need to update if this is already the most recent entry if (last_access[i] != access_counter) { last_access[i] = ++access_counter; // Guard against overflow if (access_counter == 0) last_access.fill(0); } return cache[i]; } return nullptr; }; { std::lock_guard lock(mut); auto p = find_in_cache(); if (p) return p; } auto plan = std::make_shared(length); { std::lock_guard lock(mut); auto p = find_in_cache(); if (p) return p; size_t lru = 0; for (size_t i = 1; i < nmax; ++i) if (last_access[i] < last_access[lru]) lru = i; cache[lru] = plan; last_access[lru] = ++access_counter; } return plan; #endif } class arr_info { protected: shape_t shp; stride_t str; public: arr_info(const shape_t& shape_, const stride_t& stride_) : shp(shape_), str(stride_) {} size_t ndim() const { return shp.size(); } size_t size() const { return util::prod(shp); } const shape_t& shape() const { return shp; } size_t shape(size_t i) const { return shp[i]; } const stride_t& stride() const { return str; } const ptrdiff_t& stride(size_t i) const { return str[i]; } 
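  // Illustrative note (the names/values below are hypothetical, not taken from this header):
  // arr_info stores shapes in element counts but strides in bytes, and the cndarr/ndarr
  // wrappers defined just below index raw memory by byte offset. For a contiguous 2-D
  // double array with shape {rows, cols} a caller would pass
  //   shape_t shp{rows, cols};
  //   stride_t str{ptrdiff_t(cols * sizeof(double)), ptrdiff_t(sizeof(double))};
  //   cndarr<double> a(data, shp, str);
  // so that a[i * str[0] + j * str[1]] addresses the element at (i, j).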
}; template class cndarr : public arr_info { protected: const char* d; public: cndarr(const void* data_, const shape_t& shape_, const stride_t& stride_) : arr_info(shape_, stride_), d(reinterpret_cast(data_)) {} const T& operator[](ptrdiff_t ofs) const { return *reinterpret_cast(d + ofs); } }; template class ndarr : public cndarr { public: ndarr(void* data_, const shape_t& shape_, const stride_t& stride_) : cndarr::cndarr(const_cast(data_), shape_, stride_) {} T& operator[](ptrdiff_t ofs) { return *reinterpret_cast(const_cast(cndarr::d + ofs)); } }; template class multi_iter { private: shape_t pos; const arr_info &iarr, &oarr; ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; size_t idim, rem; void advance_i() { for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) { auto i = size_t(i_); if (i == idim) continue; p_ii += iarr.stride(i); p_oi += oarr.stride(i); if (++pos[i] < iarr.shape(i)) return; pos[i] = 0; p_ii -= ptrdiff_t(iarr.shape(i)) * iarr.stride(i); p_oi -= ptrdiff_t(oarr.shape(i)) * oarr.stride(i); } } public: multi_iter(const arr_info& iarr_, const arr_info& oarr_, size_t idim_) : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), idim(idim_), rem(iarr.size() / iarr.shape(idim)) { auto nshares = threading::num_threads(); if (nshares == 1) return; if (nshares == 0) throw std::runtime_error("can't run with zero threads"); auto myshare = threading::thread_id(); if (myshare >= nshares) throw std::runtime_error("impossible share requested"); size_t nbase = rem / nshares; size_t additional = rem % nshares; size_t lo = myshare * nbase + ((myshare < additional) ? myshare : additional); size_t hi = lo + nbase + (myshare < additional); size_t todo = hi - lo; size_t chunk = rem; for (size_t i = 0; i < pos.size(); ++i) { if (i == idim) continue; chunk /= iarr.shape(i); size_t n_advance = lo / chunk; pos[i] += n_advance; p_ii += ptrdiff_t(n_advance) * iarr.stride(i); p_oi += ptrdiff_t(n_advance) * oarr.stride(i); lo -= n_advance * chunk; } rem = todo; } void advance(size_t n) { if (rem < n) throw std::runtime_error("underrun"); for (size_t i = 0; i < n; ++i) { p_i[i] = p_ii; p_o[i] = p_oi; advance_i(); } rem -= n; } ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i) * str_i; } ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i) * str_i; } ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i) * str_o; } ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i) * str_o; } size_t length_in() const { return iarr.shape(idim); } size_t length_out() const { return oarr.shape(idim); } ptrdiff_t stride_in() const { return str_i; } ptrdiff_t stride_out() const { return str_o; } size_t remaining() const { return rem; } }; class simple_iter { private: shape_t pos; const arr_info& arr; ptrdiff_t p; size_t rem; public: simple_iter(const arr_info& arr_) : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {} void advance() { --rem; for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) { auto i = size_t(i_); p += arr.stride(i); if (++pos[i] < arr.shape(i)) return; pos[i] = 0; p -= ptrdiff_t(arr.shape(i)) * arr.stride(i); } } ptrdiff_t ofs() const { return p; } size_t remaining() const { return rem; } }; class rev_iter { private: shape_t pos; const arr_info& arr; std::vector rev_axis; std::vector rev_jump; size_t last_axis, last_size; shape_t shp; ptrdiff_t p, rp; size_t rem; public: rev_iter(const arr_info& arr_, const shape_t& axes) : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 
0), rev_jump(arr_.ndim(), 1), p(0), rp(0) { for (auto ax : axes) rev_axis[ax] = 1; last_axis = axes.back(); last_size = arr.shape(last_axis) / 2 + 1; shp = arr.shape(); shp[last_axis] = last_size; rem = 1; for (auto i : shp) rem *= i; } void advance() { --rem; for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) { auto i = size_t(i_); p += arr.stride(i); if (!rev_axis[i]) rp += arr.stride(i); else { rp -= arr.stride(i); if (rev_jump[i]) { rp += ptrdiff_t(arr.shape(i)) * arr.stride(i); rev_jump[i] = 0; } } if (++pos[i] < shp[i]) return; pos[i] = 0; p -= ptrdiff_t(shp[i]) * arr.stride(i); if (rev_axis[i]) { rp -= ptrdiff_t(arr.shape(i) - shp[i]) * arr.stride(i); rev_jump[i] = 1; } else rp -= ptrdiff_t(shp[i]) * arr.stride(i); } } ptrdiff_t ofs() const { return p; } ptrdiff_t rev_ofs() const { return rp; } size_t remaining() const { return rem; } }; template struct VTYPE {}; template using vtype_t = typename VTYPE::type; #ifndef POCKETFFT_NO_VECTORS template<> struct VTYPE { using type = float __attribute__((vector_size(VLEN::val * sizeof(float)))); }; template<> struct VTYPE { using type = double __attribute__((vector_size(VLEN::val * sizeof(double)))); }; template<> struct VTYPE { using type = long double __attribute__((vector_size(VLEN::val * sizeof(long double)))); }; #endif template arr alloc_tmp(const shape_t& shape, size_t axsize, size_t elemsize) { auto othersize = util::prod(shape) / axsize; auto tmpsize = axsize * ((othersize >= VLEN::val) ? VLEN::val : 1); return arr(tmpsize * elemsize); } template arr alloc_tmp(const shape_t& shape, const shape_t& axes, size_t elemsize) { size_t fullsize = util::prod(shape); size_t tmpsize = 0; for (size_t i = 0; i < axes.size(); ++i) { auto axsize = shape[axes[i]]; auto othersize = fullsize / axsize; auto sz = axsize * ((othersize >= VLEN::val) ? 
VLEN::val : 1); if (sz > tmpsize) tmpsize = sz; } return arr(tmpsize * elemsize); } template void copy_input(const multi_iter& it, const cndarr>& src, cmplx>* POCKETFFT_RESTRICT dst) { for (size_t i = 0; i < it.length_in(); ++i) for (size_t j = 0; j < vlen; ++j) { dst[i].r[j] = src[it.iofs(j, i)].r; dst[i].i[j] = src[it.iofs(j, i)].i; } } template void copy_input(const multi_iter& it, const cndarr& src, vtype_t* POCKETFFT_RESTRICT dst) { for (size_t i = 0; i < it.length_in(); ++i) for (size_t j = 0; j < vlen; ++j) dst[i][j] = src[it.iofs(j, i)]; } template void copy_input(const multi_iter& it, const cndarr& src, T* POCKETFFT_RESTRICT dst) { if (dst == &src[it.iofs(0)]) return; // in-place for (size_t i = 0; i < it.length_in(); ++i) dst[i] = src[it.iofs(i)]; } template void copy_output(const multi_iter& it, const cmplx>* POCKETFFT_RESTRICT src, ndarr>& dst) { for (size_t i = 0; i < it.length_out(); ++i) for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i)].Set(src[i].r[j], src[i].i[j]); } template void copy_output(const multi_iter& it, const vtype_t* POCKETFFT_RESTRICT src, ndarr& dst) { for (size_t i = 0; i < it.length_out(); ++i) for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i)] = src[i][j]; } template void copy_output(const multi_iter& it, const T* POCKETFFT_RESTRICT src, ndarr& dst) { if (src == &dst[it.oofs(0)]) return; // in-place for (size_t i = 0; i < it.length_out(); ++i) dst[it.oofs(i)] = src[i]; } template struct add_vec { using type = vtype_t; }; template struct add_vec> { using type = cmplx>; }; template using add_vec_t = typename add_vec::type; template POCKETFFT_NOINLINE void general_nd(const cndarr& in, ndarr& out, const shape_t& axes, T0 fct, size_t nthreads, const Exec& exec, const bool allow_inplace = true) { std::shared_ptr plan; for (size_t iax = 0; iax < axes.size(); ++iax) { size_t len = in.shape(axes[iax]); if ((!plan) || (len != plan->length())) plan = get_plan(len); threading::thread_map(util::thread_count(nthreads, in.shape(), axes[iax], VLEN::val), [&] { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(in.shape(), len, sizeof(T)); const auto& tin(iax == 0 ? in : out); multi_iter it(tin, out, axes[iax]); #ifndef POCKETFFT_NO_VECTORS if (vlen > 1) while (it.remaining() >= vlen) { it.advance(vlen); auto tdatav = reinterpret_cast*>(storage.data()); exec(it, tin, out, tdatav, *plan, fct); } #endif while (it.remaining() > 0) { it.advance(1); auto buf = allow_inplace && it.stride_out() == sizeof(T) ? 
&out[it.oofs(0)] : reinterpret_cast(storage.data()); exec(it, tin, out, buf, *plan, fct); } }); // end of parallel region fct = T0(1); // factor has been applied, use 1 for remaining axes } } struct ExecC2C { bool forward; template void operator()(const multi_iter& it, const cndarr>& in, ndarr>& out, T* buf, const pocketfft_c& plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, forward); copy_output(it, buf, out); } }; template void copy_hartley(const multi_iter& it, const vtype_t* POCKETFFT_RESTRICT src, ndarr& dst) { for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, 0)] = src[0][j]; size_t i = 1, i1 = 1, i2 = it.length_out() - 1; for (i = 1; i < it.length_out() - 1; i += 2, ++i1, --i2) for (size_t j = 0; j < vlen; ++j) { dst[it.oofs(j, i1)] = src[i][j] + src[i + 1][j]; dst[it.oofs(j, i2)] = src[i][j] - src[i + 1][j]; } if (i < it.length_out()) for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i1)] = src[i][j]; } template void copy_hartley(const multi_iter& it, const T* POCKETFFT_RESTRICT src, ndarr& dst) { dst[it.oofs(0)] = src[0]; size_t i = 1, i1 = 1, i2 = it.length_out() - 1; for (i = 1; i < it.length_out() - 1; i += 2, ++i1, --i2) { dst[it.oofs(i1)] = src[i] + src[i + 1]; dst[it.oofs(i2)] = src[i] - src[i + 1]; } if (i < it.length_out()) dst[it.oofs(i1)] = src[i]; } struct ExecHartley { template void operator()(const multi_iter& it, const cndarr& in, ndarr& out, T* buf, const pocketfft_r& plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, true); copy_hartley(it, buf, out); } }; struct ExecDcst { bool ortho; int type; bool cosine; template void operator()(const multi_iter& it, const cndarr& in, ndarr& out, T* buf, const Tplan& plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, ortho, type, cosine); copy_output(it, buf, out); } }; template POCKETFFT_NOINLINE void general_r2c(const cndarr& in, ndarr>& out, size_t axis, bool forward, T fct, size_t nthreads) { auto plan = get_plan>(in.shape(axis)); size_t len = in.shape(axis); threading::thread_map(util::thread_count(nthreads, in.shape(), axis, VLEN::val), [&] { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(in.shape(), len, sizeof(T)); multi_iter it(in, out, axis); #ifndef POCKETFFT_NO_VECTORS if (vlen > 1) while (it.remaining() >= vlen) { it.advance(vlen); auto tdatav = reinterpret_cast*>(storage.data()); copy_input(it, in, tdatav); plan->exec(tdatav, fct, true); for (size_t j = 0; j < vlen; ++j) out[it.oofs(j, 0)].Set(tdatav[0][j]); size_t i = 1, ii = 1; if (forward) for (; i < len - 1; i += 2, ++ii) for (size_t j = 0; j < vlen; ++j) out[it.oofs(j, ii)].Set(tdatav[i][j], tdatav[i + 1][j]); else for (; i < len - 1; i += 2, ++ii) for (size_t j = 0; j < vlen; ++j) out[it.oofs(j, ii)].Set(tdatav[i][j], -tdatav[i + 1][j]); if (i < len) for (size_t j = 0; j < vlen; ++j) out[it.oofs(j, ii)].Set(tdatav[i][j]); } #endif while (it.remaining() > 0) { it.advance(1); auto tdata = reinterpret_cast(storage.data()); copy_input(it, in, tdata); plan->exec(tdata, fct, true); out[it.oofs(0)].Set(tdata[0]); size_t i = 1, ii = 1; if (forward) for (; i < len - 1; i += 2, ++ii) out[it.oofs(ii)].Set(tdata[i], tdata[i + 1]); else for (; i < len - 1; i += 2, ++ii) out[it.oofs(ii)].Set(tdata[i], -tdata[i + 1]); if (i < len) out[it.oofs(ii)].Set(tdata[i]); } }); // end of parallel region } template POCKETFFT_NOINLINE void general_c2r(const cndarr>& in, ndarr& out, size_t axis, bool forward, T fct, size_t nthreads) { auto plan = get_plan>(out.shape(axis)); size_t len = out.shape(axis); 
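  // Sketch of what the loops inside the thread_map body below do: they rebuild FFTPACK's
  // half-complex layout [r0, r1, i1, r2, i2, ...] in the scratch buffer from the
  // length/2+1 complex input values (conjugating the imaginary parts when `forward` is
  // true), run the real-valued plan in the backward direction in place, and then scatter
  // the real result into `out`.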
threading::thread_map(util::thread_count(nthreads, in.shape(), axis, VLEN::val), [&] { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(out.shape(), len, sizeof(T)); multi_iter it(in, out, axis); #ifndef POCKETFFT_NO_VECTORS if (vlen > 1) while (it.remaining() >= vlen) { it.advance(vlen); auto tdatav = reinterpret_cast*>(storage.data()); for (size_t j = 0; j < vlen; ++j) tdatav[0][j] = in[it.iofs(j, 0)].r; { size_t i = 1, ii = 1; if (forward) for (; i < len - 1; i += 2, ++ii) for (size_t j = 0; j < vlen; ++j) { tdatav[i][j] = in[it.iofs(j, ii)].r; tdatav[i + 1][j] = -in[it.iofs(j, ii)].i; } else for (; i < len - 1; i += 2, ++ii) for (size_t j = 0; j < vlen; ++j) { tdatav[i][j] = in[it.iofs(j, ii)].r; tdatav[i + 1][j] = in[it.iofs(j, ii)].i; } if (i < len) for (size_t j = 0; j < vlen; ++j) tdatav[i][j] = in[it.iofs(j, ii)].r; } plan->exec(tdatav, fct, false); copy_output(it, tdatav, out); } #endif while (it.remaining() > 0) { it.advance(1); auto tdata = reinterpret_cast(storage.data()); tdata[0] = in[it.iofs(0)].r; { size_t i = 1, ii = 1; if (forward) for (; i < len - 1; i += 2, ++ii) { tdata[i] = in[it.iofs(ii)].r; tdata[i + 1] = -in[it.iofs(ii)].i; } else for (; i < len - 1; i += 2, ++ii) { tdata[i] = in[it.iofs(ii)].r; tdata[i + 1] = in[it.iofs(ii)].i; } if (i < len) tdata[i] = in[it.iofs(ii)].r; } plan->exec(tdata, fct, false); copy_output(it, tdata, out); } }); // end of parallel region } struct ExecR2R { bool r2h, forward; template void operator()(const multi_iter& it, const cndarr& in, ndarr& out, T* buf, const pocketfft_r& plan, T0 fct) const { copy_input(it, in, buf); if ((!r2h) && forward) for (size_t i = 2; i < it.length_out(); i += 2) buf[i] = -buf[i]; plan.exec(buf, fct, r2h); if (r2h && (!forward)) for (size_t i = 2; i < it.length_out(); i += 2) buf[i] = -buf[i]; copy_output(it, buf, out); } }; template void c2c(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, const shape_t& axes, bool forward, const std::complex* data_in, std::complex* data_out, T fct, size_t nthreads = 1) { if (util::prod(shape) == 0) return; util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes); cndarr> ain(data_in, shape, stride_in); ndarr> aout(data_out, shape, stride_out); general_nd>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); } template void dct(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, const shape_t& axes, int type, const T* data_in, T* data_out, T fct, bool ortho, size_t nthreads = 1) { if ((type < 1) || (type > 4)) throw std::invalid_argument("invalid DCT type"); if (util::prod(shape) == 0) return; util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); const ExecDcst exec{ortho, type, true}; if (type == 1) general_nd>(ain, aout, axes, fct, nthreads, exec); else if (type == 4) general_nd>(ain, aout, axes, fct, nthreads, exec); else general_nd>(ain, aout, axes, fct, nthreads, exec); } template void dst(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, const shape_t& axes, int type, const T* data_in, T* data_out, T fct, bool ortho, size_t nthreads = 1) { if ((type < 1) || (type > 4)) throw std::invalid_argument("invalid DST type"); if (util::prod(shape) == 0) return; util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); const ExecDcst exec{ortho, type, false}; if (type == 1) 
general_nd>(ain, aout, axes, fct, nthreads, exec); else if (type == 4) general_nd>(ain, aout, axes, fct, nthreads, exec); else general_nd>(ain, aout, axes, fct, nthreads, exec); } template void r2c(const shape_t& shape_in, const stride_t& stride_in, const stride_t& stride_out, size_t axis, bool forward, const T* data_in, std::complex* data_out, T fct, size_t nthreads = 1) { if (util::prod(shape_in) == 0) return; util::sanity_check(shape_in, stride_in, stride_out, false, axis); cndarr ain(data_in, shape_in, stride_in); shape_t shape_out(shape_in); shape_out[axis] = shape_in[axis] / 2 + 1; ndarr> aout(data_out, shape_out, stride_out); general_r2c(ain, aout, axis, forward, fct, nthreads); } template void r2c(const shape_t& shape_in, const stride_t& stride_in, const stride_t& stride_out, const shape_t& axes, bool forward, const T* data_in, std::complex* data_out, T fct, size_t nthreads = 1) { if (util::prod(shape_in) == 0) return; util::sanity_check(shape_in, stride_in, stride_out, false, axes); r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, fct, nthreads); if (axes.size() == 1) return; shape_t shape_out(shape_in); shape_out[axes.back()] = shape_in[axes.back()] / 2 + 1; auto newaxes = shape_t{axes.begin(), --axes.end()}; c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, T(1), nthreads); } template void c2r(const shape_t& shape_out, const stride_t& stride_in, const stride_t& stride_out, size_t axis, bool forward, const std::complex* data_in, T* data_out, T fct, size_t nthreads = 1) { if (util::prod(shape_out) == 0) return; util::sanity_check(shape_out, stride_in, stride_out, false, axis); shape_t shape_in(shape_out); shape_in[axis] = shape_out[axis] / 2 + 1; cndarr> ain(data_in, shape_in, stride_in); ndarr aout(data_out, shape_out, stride_out); general_c2r(ain, aout, axis, forward, fct, nthreads); } template void c2r(const shape_t& shape_out, const stride_t& stride_in, const stride_t& stride_out, const shape_t& axes, bool forward, const std::complex* data_in, T* data_out, T fct, size_t nthreads = 1) { if (util::prod(shape_out) == 0) return; if (axes.size() == 1) return c2r(shape_out, stride_in, stride_out, axes[0], forward, data_in, data_out, fct, nthreads); util::sanity_check(shape_out, stride_in, stride_out, false, axes); auto shape_in = shape_out; shape_in[axes.back()] = shape_out[axes.back()] / 2 + 1; auto nval = util::prod(shape_in); stride_t stride_inter(shape_in.size()); stride_inter.back() = sizeof(cmplx); for (int i = int(shape_in.size()) - 2; i >= 0; --i) stride_inter[size_t(i)] = stride_inter[size_t(i + 1)] * ptrdiff_t(shape_in[size_t(i + 1)]); arr> tmp(nval); auto newaxes = shape_t{axes.begin(), --axes.end()}; c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(), T(1), nthreads); c2r(shape_out, stride_inter, stride_out, axes.back(), forward, tmp.data(), data_out, fct, nthreads); } template void r2r_fftpack(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, const shape_t& axes, bool real2hermitian, bool forward, const T* data_in, T* data_out, T fct, size_t nthreads = 1) { if (util::prod(shape) == 0) return; util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); general_nd>(ain, aout, axes, fct, nthreads, ExecR2R{real2hermitian, forward}); } template void r2r_separable_hartley(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, const shape_t& axes, const 
T* data_in, T* data_out, T fct, size_t nthreads = 1) { if (util::prod(shape) == 0) return; util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); general_nd>(ain, aout, axes, fct, nthreads, ExecHartley{}, false); } template void r2r_genuine_hartley(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out, const shape_t& axes, const T* data_in, T* data_out, T fct, size_t nthreads = 1) { if (util::prod(shape) == 0) return; if (axes.size() == 1) return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, data_out, fct, nthreads); util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes); shape_t tshp(shape); tshp[axes.back()] = tshp[axes.back()] / 2 + 1; arr> tdata(util::prod(tshp)); stride_t tstride(shape.size()); tstride.back() = sizeof(std::complex); for (size_t i = tstride.size() - 1; i > 0; --i) tstride[i - 1] = tstride[i] * ptrdiff_t(tshp[i]); r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads); cndarr> atmp(tdata.data(), tshp, tstride); ndarr aout(data_out, shape, stride_out); simple_iter iin(atmp); rev_iter iout(aout, axes); while (iin.remaining() > 0) { auto v = atmp[iin.ofs()]; aout[iout.ofs()] = v.r + v.i; aout[iout.rev_ofs()] = v.r - v.i; iin.advance(); iout.advance(); } } } // namespace detail using detail::BACKWARD; using detail::c2c; using detail::c2r; using detail::dct; using detail::dst; using detail::FORWARD; using detail::r2c; using detail::r2r_fftpack; using detail::r2r_genuine_hartley; using detail::r2r_separable_hartley; using detail::shape_t; using detail::stride_t; } // namespace pocketfft #undef POCKETFFT_NOINLINE #undef POCKETFFT_RESTRICT #endif // POCKETFFT_HDRONLY_H ================================================ FILE: oneflow/user/kernels/pocketfftplan.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "pocketfft_hdronly.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { enum class FFT_EXCUTETYPE { R2C, C2C, C2R }; template struct PocketFFtParams { bool IsForward; FFT_EXCUTETYPE excute_type; dtype fct; pocketfft::shape_t axes; pocketfft::stride_t in_stridef; pocketfft::stride_t out_stridef; pocketfft::shape_t input_shape; pocketfft::shape_t output_shape; PocketFFtParams() = default; PocketFFtParams(const Shape& in_shape, const Shape& out_shape, const Stride& in_stride, const Stride& out_stride, const std::vector& dims, const bool is_forward, const dtype f, FFT_EXCUTETYPE type) : IsForward(is_forward), excute_type(type), fct(f), axes(dims.begin(), dims.end()), in_stridef(in_stride.begin(), in_stride.end()), out_stridef(out_stride.begin(), out_stride.end()) { input_shape.resize(in_shape.size()); output_shape.resize(out_shape.size()); std::copy(in_shape.begin(), in_shape.end(), input_shape.begin()); std::copy(out_shape.begin(), out_shape.end(), output_shape.begin()); // calc element size size_t in_elemsize = type == FFT_EXCUTETYPE::C2C || type == FFT_EXCUTETYPE::C2R ? sizeof(std::complex) : sizeof(dtype); size_t out_elemsize = type == FFT_EXCUTETYPE::R2C || type == FFT_EXCUTETYPE::C2C ? sizeof(std::complex) : sizeof(dtype); for (auto& s : in_stridef) { s *= in_elemsize; } for (auto& s : out_stridef) { s *= out_elemsize; } } }; template class PocketFFtConfig { public: PocketFFtConfig(const PocketFFtConfig&) = delete; PocketFFtConfig& operator=(PocketFFtConfig const&) = delete; explicit PocketFFtConfig(const PocketFFtParams& params) : fftparams(params) {} void excute(const std::complex* in, std::complex* out) { pocketfft::c2c(fftparams.input_shape, fftparams.in_stridef, fftparams.out_stridef, fftparams.axes, fftparams.IsForward, in, out, fftparams.fct); } void excute(const dtype* in, std::complex* out) { pocketfft::r2c(fftparams.input_shape, fftparams.in_stridef, fftparams.out_stridef, fftparams.axes, fftparams.IsForward, in, out, fftparams.fct); } void excute(const std::complex* in, dtype* out) { pocketfft::c2r(fftparams.output_shape, fftparams.in_stridef, fftparams.out_stridef, fftparams.axes, fftparams.IsForward, in, out, fftparams.fct); } private: PocketFFtParams fftparams; }; } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/prelu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { template class CpuPReluKernel final : public user_op::OpKernel { public: CpuPReluKernel() = default; ~CpuPReluKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const T* x_ptr = x->dptr(); const T* alpha_ptr = alpha->dptr(); T* y_ptr = y->mut_dptr(); const int32_t elem_cnt = x->shape_view().elem_cnt(); const int32_t alpha_size = alpha->shape_view().elem_cnt(); const int batch = x->shape_view().At(0); const int channels = (x->shape_view().NumAxes() == 1) ? 1 : x->shape_view().At(1); const int32_t inner_size = elem_cnt / batch / channels; FOR_RANGE(int32_t, i, 0, elem_cnt) { y_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : x_ptr[i] * alpha_ptr[(i / inner_size) % alpha_size]; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_PRELU_KERNEL(dtype) \ REGISTER_USER_KERNEL("prelu").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_CPU_PRELU_KERNEL(float) REGISTER_CPU_PRELU_KERNEL(double) template class CpuPReluGradKernel final : public user_op::OpKernel { public: CpuPReluGradKernel() = default; ~CpuPReluGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex("alpha_diff", 0); const T* x_ptr = x->dptr(); const T* alpha_ptr = alpha->dptr(); const T* dy_ptr = dy->dptr(); T* dx_ptr = dx->mut_dptr(); T* alpha_diff_ptr = alpha_diff->mut_dptr(); const int32_t elem_cnt = x->shape_view().elem_cnt(); const int32_t alpha_size = alpha->shape_view().elem_cnt(); const int batch = x->shape_view().At(0); const int channels = (x->shape_view().NumAxes() == 1) ? 1 : x->shape_view().At(1); const int32_t inner_size = elem_cnt / batch / channels; Memset(ctx->stream(), alpha_diff->mut_dptr(), 0, alpha_diff->shape_view().elem_cnt() * sizeof(T)); for (int i = 0; i < elem_cnt; i++) { const T x_i = x_ptr[i]; const T dy_i = dy_ptr[i]; const T alpha_i = alpha_ptr[(i / inner_size) % alpha_size]; dx_ptr[i] = x_i > 0 ? dy_i : dy_i * alpha_i; alpha_diff_ptr[(i / inner_size) % alpha_size] += x_i > 0 ? 0 : dy_i * x_i; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_PRELU_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("prelu_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_CPU_PRELU_GRAD_KERNEL(float) REGISTER_CPU_PRELU_GRAD_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/prelu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { Shape CreatePreluLeftExtendedShape(const ShapeView& shape, const int32_t alpha_size) { DimVector dim_vec(shape.NumAxes()); dim_vec.at(0) = 1LL; dim_vec.at(1) = alpha_size; for (int i = 2; i < shape.NumAxes(); i++) { dim_vec.at(i) = 1LL; } return Shape(std::move(dim_vec)); } template struct PreluForwardSingleAlphaFunctor { OF_DEVICE_FUNC explicit PreluForwardSingleAlphaFunctor(const T alpha) : alpha(alpha) {} __device__ T operator()(T x) const { return (x > static_cast(0.0)) ? x : (alpha * x); } const T alpha; }; template struct PreluForwardSingleAlphaPtrFunctor { OF_DEVICE_FUNC explicit PreluForwardSingleAlphaPtrFunctor(const T* alpha_ptr) : alpha_ptr(alpha_ptr) {} __device__ PreluForwardSingleAlphaFunctor operator()() const { return PreluForwardSingleAlphaFunctor(*alpha_ptr); } const T* alpha_ptr; }; template __global__ void PReluBackwardSingleAlphaGpu(const IndexType elem_cnt, const int64_t n_tail, const T* x, const T* alpha, const T* dy, T* dx, T* alpha_diff, const T* tail_x, const T* tail_dy, T* tail_dx, T* tail_alpha_diff) { int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; T zero_val = static_cast(0); T alpha_val = alpha[0]; for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt; linear_index += gridDim.x * blockDim.x * pack_size) { const LoadType* x_load = reinterpret_cast(x + linear_index); LoadPack x_vec; x_vec.storage = *x_load; const LoadType* dy_load = reinterpret_cast(dy + linear_index); LoadPack dy_vec; dy_vec.storage = *dy_load; LoadPack dx_vec; T zero_val = static_cast(0.0); if (alpha_requires_grad) { LoadPack dalpha_vec; #pragma unroll for (int i = 0; i < pack_size; i++) { if (x_vec.elem[i] > zero_val) { dx_vec.elem[i] = dy_vec.elem[i]; dalpha_vec.elem[i] = zero_val; } else { dx_vec.elem[i] = dy_vec.elem[i] * alpha_val; dalpha_vec.elem[i] = dy_vec.elem[i] * x_vec.elem[i]; } } *(reinterpret_cast(dx + linear_index)) = dx_vec.storage; *(reinterpret_cast(alpha_diff + linear_index)) = dalpha_vec.storage; } else { #pragma unroll for (int i = 0; i < pack_size; i++) { if (x_vec.elem[i] > zero_val) { dx_vec.elem[i] = dy_vec.elem[i]; } else { dx_vec.elem[i] = dy_vec.elem[i] * alpha_val; } } *(reinterpret_cast(dx + linear_index)) = dx_vec.storage; } } if (tail && global_thread_id < n_tail) { const T tail_dy_val = tail_dy[global_thread_id]; if (tail_x[global_thread_id] > zero_val) { tail_dx[global_thread_id] = tail_dy_val; if (alpha_requires_grad) { tail_alpha_diff[global_thread_id] = zero_val; } } else { tail_dx[global_thread_id] = alpha_val * tail_dy_val; if (alpha_requires_grad) { tail_alpha_diff[global_thread_id] = tail_x[global_thread_id] * tail_dy_val; } } } } template __global__ void BroadcastPReluMultiAlphaNaiveForwardGpu(const int32_t elem_cnt, const int32_t alpha_size, const int32_t inner_size, const T* x, const T* alpha, T* y) { const T zero_val = 
static_cast(0.0); CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const T x_i = x[i]; int32_t alpha_idx = (i / inner_size) % alpha_size; y[i] = x_i > zero_val ? x_i : x_i * alpha[alpha_idx]; } } template __global__ void PReluForwardMultiAlphaGpu(const IndexType elem_cnt, const IndexType alpha_size, const IndexType inner_size, const T* x, const T* alpha, T* y) { int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; T zero_val = static_cast(0); for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt; linear_index += gridDim.x * blockDim.x * pack_size) { IndexType alpha_idx = (linear_index / inner_size) % alpha_size; const LoadType* x_load = reinterpret_cast(x + linear_index); LoadPack x_vec; x_vec.storage = *x_load; LoadPack y_vec; T alpha_val = alpha[alpha_idx]; #pragma unroll for (int i = 0; i < pack_size; i++) { y_vec.elem[i] = x_vec.elem[i] > zero_val ? x_vec.elem[i] : x_vec.elem[i] * alpha_val; } *(reinterpret_cast(y + linear_index)) = y_vec.storage; } } template __global__ void BroadcastPReluMultiAlphaNaiveBackwardGpu(const int32_t elem_cnt, const int32_t alpha_size, const int32_t inner_size, const T* x, const T* alpha, const T* dy, T* dx, T* alpha_diff) { const T zero_val = static_cast(0.0); CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const T x_i = x[i]; const T dy_i = dy[i]; int32_t alpha_i = (i / inner_size) % alpha_size; if (x_i > zero_val) { dx[i] = dy_i; if (alpha_requires_grad) { alpha_diff[i] = zero_val; } } else { dx[i] = dy_i * alpha[alpha_i]; if (alpha_requires_grad) { alpha_diff[i] = dy_i * x_i; } } } } template __global__ void PReluBackwardMultiAlphaGpu(const IndexType elem_cnt, const IndexType alpha_size, const IndexType inner_size, const T* x, const T* alpha, const T* dy, T* dx, T* alpha_diff) { int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; using LoadType = cuda::elementwise::PackType; using LoadPack = cuda::elementwise::Pack; T zero_val = static_cast(0); for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt; linear_index += gridDim.x * blockDim.x * pack_size) { IndexType alpha_idx = (linear_index / inner_size) % alpha_size; const LoadType* x_load = reinterpret_cast(x + linear_index); LoadPack x_vec; x_vec.storage = *x_load; const LoadType* dy_load = reinterpret_cast(dy + linear_index); LoadPack dy_vec; dy_vec.storage = *dy_load; LoadPack dx_vec; T alpha_val = alpha[alpha_idx]; if (alpha_requires_grad) { LoadPack dalpha_vec; T zero_val = static_cast(0.0); #pragma unroll for (int i = 0; i < pack_size; i++) { if (x_vec.elem[i] > zero_val) { dx_vec.elem[i] = dy_vec.elem[i]; dalpha_vec.elem[i] = zero_val; } else { dx_vec.elem[i] = dy_vec.elem[i] * alpha_val; dalpha_vec.elem[i] = dy_vec.elem[i] * x_vec.elem[i]; } } *(reinterpret_cast(dx + linear_index)) = dx_vec.storage; *(reinterpret_cast(alpha_diff + linear_index)) = dalpha_vec.storage; } else { #pragma unroll for (int i = 0; i < pack_size; i++) { if (x_vec.elem[i] > zero_val) { dx_vec.elem[i] = dy_vec.elem[i]; } else { dx_vec.elem[i] = dy_vec.elem[i] * alpha_val; } } *(reinterpret_cast(dx + linear_index)) = dx_vec.storage; } } } constexpr int32_t kBlockSize = 256; template int GetLaunchPackSize(const int64_t inner_size) { constexpr int type_pack_size = cuda::elementwise::PackSize(); for (int launch_pack_size = 8; launch_pack_size > 0; launch_pack_size /= 2) { if (type_pack_size >= launch_pack_size && inner_size % launch_pack_size == 0) { return launch_pack_size; } } return 
1; } template void DispatchPreluForwardPackSize(ep::Stream* stream, const int64_t elem_cnt, const int64_t alpha_size, const int64_t inner_size, const T* x, const T* alpha, T* y) { int grid_size; const int pack_size = GetLaunchPackSize(inner_size); const int64_t pack_num = elem_cnt / pack_size; cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size); if (pack_size == 8) { PReluForwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, y); } else if (pack_size == 4) { PReluForwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, y); } else if (pack_size == 2) { PReluForwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, y); } else { BroadcastPReluMultiAlphaNaiveForwardGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, y); } } template void DispatchPreluForwardIndex(ep::Stream* stream, const int64_t elem_cnt, const int64_t alpha_size, const int64_t inner_size, const T* x, const T* alpha, T* y) { if (elem_cnt < GetMaxVal()) { DispatchPreluForwardPackSize(stream, elem_cnt, alpha_size, inner_size, x, alpha, y); } else { DispatchPreluForwardPackSize(stream, elem_cnt, alpha_size, inner_size, x, alpha, y); } } template void DispatchPreluBackwardPackSize(ep::Stream* stream, const int64_t elem_cnt, const int64_t alpha_size, const int64_t inner_size, const T* x, const T* alpha, const T* dy, T* dx, T* alpha_diff, const bool alpha_requires_grad) { int grid_size; const int pack_size = GetLaunchPackSize(inner_size); const int64_t pack_num = elem_cnt / pack_size; cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size); if (pack_size == 8) { if (alpha_requires_grad) { PReluBackwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff); } else { PReluBackwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff); } } else if (pack_size == 4) { if (alpha_requires_grad) { PReluBackwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff); } else { PReluBackwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff); } } else if (pack_size == 2) { if (alpha_requires_grad) { PReluBackwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff); } else { PReluBackwardMultiAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff); } } else { if (alpha_requires_grad) { BroadcastPReluMultiAlphaNaiveBackwardGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff); } else { BroadcastPReluMultiAlphaNaiveBackwardGpu <<As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff); } } } template void DispatchPreluBackwardIndex(ep::Stream* stream, const int64_t elem_cnt, const int64_t alpha_size, const int64_t inner_size, const T* x, const T* alpha, const T* dy, T* dx, T* alpha_diff, const bool alpha_requires_grad) { if (elem_cnt < GetMaxVal()) { DispatchPreluBackwardPackSize(stream, elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff, alpha_requires_grad); } else { DispatchPreluBackwardPackSize(stream, elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff, alpha_requires_grad); } } template void DispatchPreluBackwardSingleAlphaTail(ep::Stream* stream, const IndexType elem_cnt, const T* x, const T* alpha, 
const T* dy, T* dx, T* alpha_diff, const bool alpha_requires_grad) { constexpr int pack_size = cuda::elementwise::PackSize(); const int64_t pack_num = elem_cnt / pack_size; int grid_size; cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size); const int64_t tail_offset = pack_num * pack_size; const int64_t n_tail = elem_cnt - tail_offset; const bool tail = n_tail > 0 ? true : false; if (tail) { if (alpha_requires_grad) { PReluBackwardSingleAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset, dx + tail_offset, alpha_diff + tail_offset); } else { PReluBackwardSingleAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset, dx + tail_offset, alpha_diff + tail_offset); } } else { if (alpha_requires_grad) { PReluBackwardSingleAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset, dx + tail_offset, alpha_diff + tail_offset); } else { PReluBackwardSingleAlphaGpu <<As()->cuda_stream()>>>( elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset, dx + tail_offset, alpha_diff + tail_offset); } } } template void DispatchPreluBackwardSingleAlphaIndex(ep::Stream* stream, const int64_t elem_cnt, const T* x, const T* alpha, const T* dy, T* dx, T* alpha_diff, const bool alpha_requires_grad) { if (elem_cnt < GetMaxVal()) { DispatchPreluBackwardSingleAlphaTail(stream, elem_cnt, x, alpha, dy, dx, alpha_diff, alpha_requires_grad); } else { DispatchPreluBackwardSingleAlphaTail(stream, elem_cnt, x, alpha, dy, dx, alpha_diff, alpha_requires_grad); } } } // namespace template class GpuPReluKernel final : public user_op::OpKernel { public: GpuPReluKernel() = default; ~GpuPReluKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = x->shape_view().elem_cnt(); const int32_t batch = x->shape_view().At(0); const int32_t channels = (x->shape_view().NumAxes() == 1) ? 
1 : x->shape_view().At(1); const int32_t alpha_size = alpha->shape_view().elem_cnt(); const int32_t inner_size = elem_cnt / batch / channels; if (alpha_size == 1) { OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory( PreluForwardSingleAlphaPtrFunctor(reinterpret_cast(alpha->dptr())), elem_cnt, reinterpret_cast(y->mut_dptr()), reinterpret_cast(x->dptr()), ctx->stream()->As()->cuda_stream()))); } else { DispatchPreluForwardIndex( ctx->stream(), elem_cnt, alpha_size, inner_size, reinterpret_cast(x->dptr()), reinterpret_cast(alpha->dptr()), reinterpret_cast(y->mut_dptr())); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_PRELU_KERNEL(dtype) \ REGISTER_USER_KERNEL("prelu").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_CUDA_PRELU_KERNEL(half) REGISTER_CUDA_PRELU_KERNEL(float) REGISTER_CUDA_PRELU_KERNEL(double) template class GpuPReluGradKernel final : public user_op::OpKernel { public: GpuPReluGradKernel() = default; ~GpuPReluGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex("alpha_diff", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const bool alpha_requires_grad = ctx->Attr("alpha_requires_grad"); const int32_t elem_cnt = x->shape_view().elem_cnt(); T* broadcasted_alpha_diff = tmp_buffer->mut_dptr(); T* reduce_sum_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + GetCudaAlignedSize(elem_cnt * sizeof(T))); const int32_t batch = x->shape_view().At(0); const int32_t channels = (x->shape_view().NumAxes() == 1) ? 
1 : x->shape_view().At(1); const int32_t alpha_size = alpha->shape_view().elem_cnt(); const int32_t inner_size = elem_cnt / batch / channels; const Shape& left_extended_shape = CreatePreluLeftExtendedShape(ShapeView(x->shape_view()), alpha_size); if (alpha_size == 1) { DispatchPreluBackwardSingleAlphaIndex(ctx->stream(), elem_cnt, x->dptr(), alpha->dptr(), dy->dptr(), dx->mut_dptr(), broadcasted_alpha_diff, alpha_requires_grad); } else { DispatchPreluBackwardIndex(ctx->stream(), elem_cnt, alpha_size, inner_size, x->dptr(), alpha->dptr(), dy->dptr(), dx->mut_dptr(), broadcasted_alpha_diff, alpha_requires_grad); } if (alpha_requires_grad) { NdarrayUtil::ReduceSum( ctx->stream(), XpuVarNdarray(left_extended_shape, alpha_diff->mut_dptr()), XpuVarNdarray(x->shape_view(), broadcasted_alpha_diff), XpuVarNdarray(x->shape_view(), reduce_sum_tmp_buf)); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_PRELU_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("prelu_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("x", 0); \ const Shape& alpha_shape = ctx->InputShape("alpha", 0); \ const int64_t tmp_buffer_size = \ 2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype)); \ return tmp_buffer_size; \ }); REGISTER_CUDA_PRELU_GRAD_KERNEL(half) REGISTER_CUDA_PRELU_GRAD_KERNEL(float) REGISTER_CUDA_PRELU_GRAD_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/quantization_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include namespace oneflow { template void QuantizationPerLayerSymmetric(const T* in_ptr, const T scale, const int32_t quantization_bit, const int64_t num_elements, T* out_ptr) { T upper_bound = static_cast(pow(2.0, quantization_bit - 1)) - 1; T lower_bound = -upper_bound - 1; FOR_RANGE(int64_t, i, 0, num_elements) { T out = std::nearbyint(in_ptr[i] / scale); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? lower_bound : out; out_ptr[i] = out; } } template void QuantizationPerLayerAffine(const T* in_ptr, const T scale, const T zero_point, const int32_t quantization_bit, const int64_t num_elements, T* out_ptr) { T upper_bound = static_cast(pow(2.0, quantization_bit)) - 1; T lower_bound = 0; uint8_t zero_point_uint8 = static_cast(std::round(zero_point)); FOR_RANGE(int64_t, i, 0, num_elements) { T out = std::nearbyint(in_ptr[i] / scale + zero_point_uint8); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? 
lower_bound : out; out_ptr[i] = out; } } template void QuantizationPerLayerCambricon(const T* in_ptr, const T shift, const int32_t quantization_bit, const int64_t num_elements, T* out_ptr) { T upper_bound = static_cast(pow(2.0, quantization_bit - 1)) - 1; T lower_bound = -upper_bound - 1; T scale = static_cast(pow(2.0, static_cast(shift))); FOR_RANGE(int64_t, i, 0, num_elements) { T out = std::nearbyint(in_ptr[i] / scale); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? lower_bound : out; out_ptr[i] = out; } } template class CpuQuantizationKernel final : public user_op::OpKernel { public: CpuQuantizationKernel() = default; ~CpuQuantizationKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); const user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const std::string quantization_scheme = ctx->Attr("quantization_scheme"); const int32_t quantization_bit = ctx->Attr("quantization_bit"); const std::string quantization_formula = ctx->Attr("quantization_formula"); const T* in_ptr = in->dptr(); const T* scale_ptr = scale->dptr(); T* out_ptr = out->mut_dptr(); // round to even auto origin_round_mode = std::fegetround(); std::fesetround(FE_TONEAREST); if (quantization_formula == "google") { int64_t outer_num = 1; int64_t inner_num = in->shape_view().elem_cnt(); if (scale->shape_view().elem_cnt() > 1) { // per-channel quantization outer_num = in->shape_view().At(0); inner_num = in->shape_view().Count(1); } if (quantization_scheme == "symmetric") { FOR_RANGE(int64_t, c, 0, outer_num) { QuantizationPerLayerSymmetric(in_ptr, scale_ptr[c], quantization_bit, inner_num, out_ptr); in_ptr += inner_num; out_ptr += inner_num; } } else { // quantization_scheme == "affine" const T* zero_point_ptr = zero_point->dptr(); FOR_RANGE(int64_t, c, 0, outer_num) { QuantizationPerLayerAffine(in_ptr, scale_ptr[c], zero_point_ptr[c], quantization_bit, inner_num, out_ptr); in_ptr += inner_num; out_ptr += inner_num; } } } else if (quantization_formula == "cambricon") { QuantizationPerLayerCambricon(in_ptr, scale_ptr[0], quantization_bit, in->shape_view().elem_cnt(), out_ptr); } else { UNIMPLEMENTED(); } std::fesetround(origin_round_mode); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_QUANTIZATION_KERNEL(dtype) \ REGISTER_USER_KERNEL("quantization") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_QUANTIZATION_KERNEL(float); REGISTER_QUANTIZATION_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/quantization_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
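// ---- Illustrative aside (not part of the repository) ----------------------------------------
// The CPU quantization kernel switches the FP environment to FE_TONEAREST so std::nearbyint
// rounds ties to even, scales the input, and clamps to the signed range of the chosen bit
// width. A self-contained sketch of the symmetric per-layer path; the function name is
// hypothetical.
#include <cfenv>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> QuantizeSymmetricReference(const std::vector<float>& in, float scale,
                                              int32_t quantization_bit) {
  const float upper_bound = static_cast<float>((1 << (quantization_bit - 1)) - 1);  // e.g. 127
  const float lower_bound = -upper_bound - 1.0f;                                    // e.g. -128
  const int origin_round_mode = std::fegetround();
  std::fesetround(FE_TONEAREST);  // nearbyint() now rounds halves to even
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    float q = std::nearbyint(in[i] / scale);
    q = q > upper_bound ? upper_bound : q;
    q = q < lower_bound ? lower_bound : q;
    out[i] = q;
  }
  std::fesetround(origin_round_mode);  // restore the caller's rounding mode
  return out;
}
// Per-channel quantization in the kernel is this same loop applied once per outer slice, with
// scale_ptr[c] (and zero_point_ptr[c] in the affine scheme) substituted for each channel.
// ----------------------------------------------------------------------------------------------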
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.cuh" namespace oneflow { namespace { template __global__ void QuantizationSymmetric(const T* in_ptr, const T* scale_ptr, const int64_t scale_size, const int64_t elements, const int64_t panel_size, const double quantization_bit, T* out_ptr) { int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int64_t step = gridDim.x * blockDim.x; T upper_bound = static_cast(pow(2.0, quantization_bit - 1)) - 1; T lower_bound = -upper_bound - 1; while (gid < elements) { int64_t channel_index = gid / panel_size; int64_t scale_idx = min(scale_size - 1, channel_index); T scale = scale_ptr[scale_idx]; T out = nearbyint(in_ptr[gid] / scale); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? lower_bound : out; out_ptr[gid] = out; gid += step; } } template __global__ void QuantizationAffine(const T* in_ptr, const T* scale_ptr, const T* zero_point_ptr, const int64_t scale_size, const int64_t elements, const int64_t panel_size, const double quantization_bit, T* out_ptr) { int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int64_t step = gridDim.x * blockDim.x; T upper_bound = static_cast(pow(2.0, quantization_bit)) - 1; T lower_bound = 0; while (gid < elements) { int64_t channel_index = gid / panel_size; int64_t scale_idx = min(scale_size - 1, channel_index); T scale = scale_ptr[scale_idx]; T zero_point = zero_point_ptr[scale_idx]; T out = nearbyint(in_ptr[gid] / scale + zero_point); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? lower_bound : out; out_ptr[gid] = out; gid += step; } } template __global__ void QuantizationCambricon(const T* in_ptr, const T* shift, const int64_t scale_size, const int64_t elements, const int64_t panel_size, const double quantization_bit, T* out_ptr) { int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int64_t step = gridDim.x * blockDim.x; T upper_bound = static_cast(pow(2.0, quantization_bit - 1)) - 1; T lower_bound = -upper_bound - 1; T scale = static_cast(pow(2.0, static_cast(shift[0]))); while (gid < elements) { T out = nearbyint(in_ptr[gid] / scale); out = out > upper_bound ? upper_bound : out; out = out < lower_bound ? 
lower_bound : out; out_ptr[gid] = out; gid += step; } } } // namespace template class GpuQuantizationKernel final : public user_op::OpKernel { public: GpuQuantizationKernel() = default; ~GpuQuantizationKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0); const user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const std::string quantization_scheme = ctx->Attr("quantization_scheme"); const int32_t quantization_bit = ctx->Attr("quantization_bit"); const std::string quantization_formula = ctx->Attr("quantization_formula"); const int64_t elements = in->shape_view().elem_cnt(); const int64_t panel_size = in->shape_view().Count(1); const int64_t scale_size = scale->shape_view().elem_cnt(); // round to even auto origin_round_mode = std::fegetround(); std::fesetround(FE_TONEAREST); if (quantization_formula == "google") { if (quantization_scheme == "symmetric") { RUN_CUDA_KERNEL((QuantizationSymmetric), ctx->stream(), elements, in->dptr(), scale->dptr(), scale_size, elements, panel_size, quantization_bit, out->mut_dptr()); } else { // quantization_scheme == "affine" RUN_CUDA_KERNEL((QuantizationAffine), ctx->stream(), elements, in->dptr(), scale->dptr(), zero_point->dptr(), scale_size, elements, panel_size, quantization_bit, out->mut_dptr()); } } else if (quantization_formula == "cambricon") { RUN_CUDA_KERNEL((QuantizationCambricon), ctx->stream(), elements, in->dptr(), scale->dptr(), scale_size, elements, panel_size, quantization_bit, out->mut_dptr()); } else { UNIMPLEMENTED(); } std::fesetround(origin_round_mode); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_QUANTIZATION_KERNEL(dtype) \ REGISTER_USER_KERNEL("quantization") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_QUANTIZATION_KERNEL(float); REGISTER_QUANTIZATION_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/radix_sort.cuh ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
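// ---- Illustrative aside (not part of the repository) ----------------------------------------
// The CUDA quantization kernels above fold per-layer and per-channel quantization into a single
// indexing rule: element gid uses scale index min(scale_size - 1, gid / panel_size), where
// panel_size is Count(1) of the input shape. With one scale this always yields index 0. A tiny
// host-side sketch of that rule (hypothetical helper name):
#include <algorithm>
#include <cstdint>

inline int64_t ScaleIndexForElement(int64_t gid, int64_t panel_size, int64_t scale_size) {
  const int64_t channel_index = gid / panel_size;  // which slice along dim 0 this element is in
  return std::min(scale_size - 1, channel_index);  // clamps to 0 in the per-layer case
}
// ----------------------------------------------------------------------------------------------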
*/ #ifndef ONEFLOW_USER_KERNELS_RADIX_SORT_CUH_ #define ONEFLOW_USER_KERNELS_RADIX_SORT_CUH_ #include #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace { class MultiplyFunctor final { public: MultiplyFunctor(int32_t num_col) : num_col_(num_col) {} __host__ __device__ __forceinline__ int32_t operator()(int32_t idx) const { return idx * num_col_; } private: int32_t num_col_; }; } // namespace template size_t InferTempStorageForSortPairsAscending(int32_t num_row, int32_t num_col) { size_t temp_storage_bytes = 0; if (num_row > 1) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedRadixSort::SortPairs( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ nullptr, /* d_keys_out */ nullptr, /* d_values_in */ nullptr, /* d_values_out */ nullptr, /* num_items */ num_row * num_col, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ 0); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortPairs( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ nullptr, /* d_keys_out */ nullptr, /* d_values_in */ nullptr, /* d_values_out */ nullptr, /* num_items */ num_row * num_col, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ 0); OF_CUDA_CHECK(err); } return temp_storage_bytes; } template size_t InferTempStorageForSortPairsDescending(int32_t num_row, int32_t num_col) { size_t temp_storage_bytes = 0; if (num_row > 1) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ nullptr, /* d_keys_out */ nullptr, /* d_values_in */ nullptr, /* d_values_out */ nullptr, /* num_items */ num_row * num_col, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ 0); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortPairsDescending( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ nullptr, /* d_keys_out */ nullptr, /* d_values_in */ nullptr, /* d_values_out */ nullptr, /* num_items */ num_row * num_col, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ 0); OF_CUDA_CHECK(err); } return temp_storage_bytes; } template size_t InferTempStorageForSortKeysAscending(int32_t num_row, int32_t num_col) { size_t temp_storage_bytes = 0; if (num_row > 1) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedRadixSort::SortKeys( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ nullptr, /* d_keys_out */ nullptr, /* num_items */ num_row * num_col, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ 
segment_offset_iter + 1, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ 0); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortKeys( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ nullptr, /* d_keys_out */ nullptr, /* num_items */ num_row * num_col, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ 0); OF_CUDA_CHECK(err); } return temp_storage_bytes; } template size_t InferTempStorageForSortKeysDescending(int32_t num_row, int32_t num_col) { size_t temp_storage_bytes = 0; if (num_row > 1) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedRadixSort::SortKeysDescending( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ nullptr, /* d_keys_out */ nullptr, /* num_items */ num_row * num_col, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ 0); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortKeysDescending( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ nullptr, /* d_keys_out */ nullptr, /* num_items */ num_row * num_col, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ 0); OF_CUDA_CHECK(err); } return temp_storage_bytes; } template void SortPairsAscending(const KeyType* keys_ptr, const ValueType* values_ptr, int32_t num_row, int32_t num_col, void* temp_storage_ptr, int32_t temp_storage_bytes, KeyType* sorted_keys_ptr, ValueType* sorted_values_ptr, cudaStream_t stream) { size_t rt_inferred_temp_storage_bytes = InferTempStorageForSortPairsAscending(num_row, num_col); CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes); if (num_row > 1) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedRadixSort::SortPairs( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_keys_in */ keys_ptr, /* d_keys_out */ sorted_keys_ptr, /* d_values_in */ values_ptr, /* d_values_out */ sorted_values_ptr, /* num_items */ num_row * num_col, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ stream); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortPairs( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_keys_in */ keys_ptr, /* d_keys_out */ sorted_keys_ptr, /* d_values_in */ values_ptr, /* d_values_out */ sorted_values_ptr, /* num_items */ num_row * num_col, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ stream); OF_CUDA_CHECK(err); } } template void SortPairsDescending(const KeyType* keys_ptr, const ValueType* values_ptr, int32_t num_row, int32_t num_col, void* temp_storage_ptr, int32_t temp_storage_bytes, KeyType* sorted_keys_ptr, ValueType* sorted_values_ptr, cudaStream_t stream) { size_t rt_inferred_temp_storage_bytes = InferTempStorageForSortPairsDescending(num_row, num_col); 
CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes); if (num_row > 1) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_keys_in */ keys_ptr, /* d_keys_out */ sorted_keys_ptr, /* d_values_in */ values_ptr, /* d_values_out */ sorted_values_ptr, /* num_items */ num_row * num_col, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ stream); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortPairsDescending( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_keys_in */ keys_ptr, /* d_keys_out */ sorted_keys_ptr, /* d_values_in */ values_ptr, /* d_values_out */ sorted_values_ptr, /* num_items */ num_row * num_col, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ stream); OF_CUDA_CHECK(err); } } template void SortKeysAscending(const KeyType* keys_ptr, int32_t num_row, int32_t num_col, void* temp_storage_ptr, int32_t temp_storage_bytes, KeyType* sorted_keys_ptr, cudaStream_t stream) { size_t rt_inferred_temp_storage_bytes = InferTempStorageForSortKeysAscending(num_row, num_col); CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes); if (num_row > 1) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedRadixSort::SortKeys( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_keys_in */ keys_ptr, /* d_keys_out */ sorted_keys_ptr, /* num_items */ num_row * num_col, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ stream); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortKeys( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_keys_in */ keys_ptr, /* d_keys_out */ sorted_keys_ptr, /* num_items */ num_row * num_col, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ stream); OF_CUDA_CHECK(err); } } template void SortKeysDescending(const KeyType* keys_ptr, int32_t num_row, int32_t num_col, void* temp_storage_ptr, int32_t temp_storage_bytes, KeyType* sorted_keys_ptr, cudaStream_t stream) { size_t rt_inferred_temp_storage_bytes = InferTempStorageForSortKeysDescending(num_row, num_col); CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes); if (num_row > 1) { using SegmentOffsetIter = cub::TransformInputIterator>; cub::CountingInputIterator counting_iter(0); MultiplyFunctor multiply_functor(num_col); SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); auto err = cub::DeviceSegmentedRadixSort::SortKeysDescending( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_keys_in */ keys_ptr, /* d_keys_out */ sorted_keys_ptr, /* num_items */ num_row * num_col, /* num_segments */ num_row, /* d_begin_offsets */ segment_offset_iter, /* 
d_end_offsets */ segment_offset_iter + 1, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ stream); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortKeysDescending( /* d_temp_storage */ temp_storage_ptr, /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, /* d_keys_in */ keys_ptr, /* d_keys_out */ sorted_keys_ptr, /* num_items */ num_row * num_col, /* begin_bit */ 0, /* end_bit */ sizeof(KeyType) * 8, /* stream */ stream); OF_CUDA_CHECK(err); } } } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_RADIX_SORT_CUH_ ================================================ FILE: oneflow/user/kernels/random_crop_kernel_state.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/user/kernels/random_crop_kernel_state.h" namespace oneflow { std::shared_ptr CreateRandomCropKernelState( user_op::KernelInitContext* ctx) { int32_t num_attempts = ctx->Attr("num_attempts"); CHECK(num_attempts >= 1); const std::vector& random_aspect_ratio = ctx->Attr>("random_aspect_ratio"); CHECK(random_aspect_ratio.size() == 2 && 0 < random_aspect_ratio.at(0) && random_aspect_ratio.at(0) <= random_aspect_ratio.at(1)); const std::vector& random_area = ctx->Attr>("random_area"); CHECK(random_area.size() == 2 && 0 < random_area.at(0) && random_area.at(0) <= random_area.at(1)); const user_op::TensorDesc* out_tensor_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0); return std::shared_ptr(new RandomCropKernelState( out_tensor_desc->shape().elem_cnt(), CHECK_JUST(GetOpKernelRandomSeed(ctx)), {random_aspect_ratio.at(0), random_aspect_ratio.at(1)}, {random_area.at(0), random_area.at(1)}, num_attempts)); } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/random_crop_kernel_state.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
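// ---- Illustrative aside (not part of the repository) ----------------------------------------
// radix_sort.cuh follows CUB's two-phase convention: every Sort* routine is first invoked with a
// null d_temp_storage pointer solely to obtain temp_storage_bytes, and only the second call with
// a real workspace performs the sort. A hedged sketch of that pattern for a plain ascending key
// sort; the cudaMalloc/cudaFree workspace handling is illustrative, not the kernels' tmp_buffer
// path.
#include <cub/cub.cuh>
#include "oneflow/core/device/cuda_util.h"

void SortKeysAscendingExample(const int32_t* d_keys_in, int32_t* d_keys_out, int num_items,
                              cudaStream_t stream) {
  size_t temp_storage_bytes = 0;
  // Phase 1: null workspace pointer -> only temp_storage_bytes is written.
  OF_CUDA_CHECK(cub::DeviceRadixSort::SortKeys(nullptr, temp_storage_bytes, d_keys_in, d_keys_out,
                                               num_items, 0, sizeof(int32_t) * 8, stream));
  void* d_temp_storage = nullptr;
  OF_CUDA_CHECK(cudaMalloc(&d_temp_storage, temp_storage_bytes));
  // Phase 2: identical arguments plus a real workspace -> the sort actually runs.
  OF_CUDA_CHECK(cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys_in,
                                               d_keys_out, num_items, 0, sizeof(int32_t) * 8,
                                               stream));
  OF_CUDA_CHECK(cudaFree(d_temp_storage));
}
// In the header above, phase 1 is wrapped in the InferTempStorage* helpers so a kernel's
// SetInferTmpSizeFn can reserve the workspace ahead of time instead of allocating in Compute().
// ----------------------------------------------------------------------------------------------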
*/
#ifndef ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_
#define ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_

#include "oneflow/core/framework/framework.h"
#include "oneflow/user/image/random_crop_generator.h"

namespace oneflow {

class RandomCropKernelState final : public user_op::OpKernelState {
 public:
  explicit RandomCropKernelState(int32_t size, int64_t seed, AspectRatioRange aspect_ratio_range,
                                 AreaRange area_range, int32_t num_attempts)
      : gens_(size) {
    std::seed_seq seq{seed};
    std::vector<int64_t> seeds(size);
    seq.generate(seeds.begin(), seeds.end());
    for (int32_t i = 0; i < size; ++i) {
      gens_.at(i).reset(
          new RandomCropGenerator(aspect_ratio_range, area_range, seeds.at(i), num_attempts));
    }
  }
  ~RandomCropKernelState() = default;

  RandomCropGenerator* GetGenerator(int32_t idx) { return gens_.at(idx).get(); }

 private:
  std::vector<std::unique_ptr<RandomCropGenerator>> gens_;
};

std::shared_ptr<RandomCropKernelState> CreateRandomCropKernelState(user_op::KernelInitContext* ctx);

}  // namespace oneflow

#endif  // ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_

================================================
FILE: oneflow/user/kernels/random_mask_generator.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/random_mask_generator.h"

namespace oneflow {

void RandomMaskGenerator<DeviceType::kCPU>::Generate(ep::Stream* stream, const int64_t n,
                                                     const float rate, bool* mask) {
  CHECK_GE(n, 0);
  std::uniform_real_distribution<float> random_distribution(GetZeroVal<float>(),
                                                            GetOneVal<float>());
  for (int64_t i = 0; i < n; ++i) { mask[i] = random_distribution(generator_->engine()) > rate; }
}

template class RandomMaskGenerator<DeviceType::kCPU>;

}  // namespace oneflow

================================================
FILE: oneflow/user/kernels/random_mask_generator.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
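// ---- Illustrative aside (not part of the repository) ----------------------------------------
// RandomCropKernelState above fans a single kernel seed out into one seed per output image with
// std::seed_seq, so every RandomCropGenerator gets its own reproducible random stream. A minimal
// standard-library sketch of that fan-out; the engine type and function name are illustrative.
#include <cstdint>
#include <random>
#include <vector>

std::vector<std::mt19937_64> MakeDecorrelatedEngines(int64_t seed, int32_t size) {
  std::seed_seq seq{seed};
  std::vector<std::uint32_t> seeds(size);
  seq.generate(seeds.begin(), seeds.end());  // deterministic, well-mixed per-index seeds
  std::vector<std::mt19937_64> engines;
  engines.reserve(size);
  for (int32_t i = 0; i < size; ++i) { engines.emplace_back(seeds[i]); }
  return engines;
}
// ----------------------------------------------------------------------------------------------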
*/ #include "oneflow/user/kernels/random_mask_generator.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/kernels/distributions/distribution_template_util.cuh" namespace oneflow { namespace { using PackType = ulonglong2; union Pack { PackType p_value; bool b_value[sizeof(PackType)]; }; __device__ bool GenMask(curandStatePhilox4_32_10_t* state, const float rate) { return curand_uniform(state) > rate; } __global__ void GenerateGpu(uint64_t seed, uint64_t offset, const int64_t n, const float rate, bool* mask) { const int id = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, id, offset, &state); PackType* pack_mask = reinterpret_cast(mask); Pack pack; CUDA_1D_KERNEL_LOOP(i, n / sizeof(PackType)) { #pragma unroll for (int j = 0; j < sizeof(PackType); j += 4) { auto rand = curand_uniform4(&state); pack.b_value[j] = (&rand.x)[0] > rate; pack.b_value[j + 1] = (&rand.x)[1] > rate; pack.b_value[j + 2] = (&rand.x)[2] > rate; pack.b_value[j + 3] = (&rand.x)[3] > rate; } pack_mask[i] = pack.p_value; } const int32_t rem_cnt = n % sizeof(PackType); const int32_t rem_offset = n - rem_cnt; if (id < rem_cnt) { mask[id + rem_offset] = GenMask(&state, rate); } } } // namespace void RandomMaskGenerator::Generate(ep::Stream* stream, const int64_t n, const float rate, bool* mask) { if (n == 0) return; ep::CudaStream* cuda_stream = stream->As(); auto execution_policy = generator_->CalcExecutionPolicy(n, cuda_stream); auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); uint64_t seed = generator_->current_seed(); uint64_t offset = generator_->get_philox_offset(counter_offset); GenerateGpu<<As()->cuda_stream()>>>(seed, offset, n, rate, mask); } template class RandomMaskGenerator; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/random_mask_generator.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_RANDOM_MASK_GENERATOR_H_ #define ONEFLOW_USER_KERNELS_RANDOM_MASK_GENERATOR_H_ #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/random_generator.h" #ifdef WITH_CUDA #include #include #endif namespace oneflow { template class RandomMaskGenerator; template<> class RandomMaskGenerator final { public: OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator); RandomMaskGenerator(const std::shared_ptr& generator, const int device_index = -1) { generator_ = CHECK_JUST(generator->Get(device_index)); } ~RandomMaskGenerator() = default; void Generate(ep::Stream* stream, int64_t n, float rate, bool* mask); private: std::shared_ptr generator_; }; #ifdef WITH_CUDA template<> class RandomMaskGenerator final { public: OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator); RandomMaskGenerator(const std::shared_ptr& generator, const int device_index = -1) { generator_ = CHECK_JUST(generator->Get(device_index)); } ~RandomMaskGenerator() = default; void Generate(ep::Stream* stream, int64_t n, float rate, bool* mask); private: std::shared_ptr generator_; }; #endif } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_RANDOM_MASK_GENERATOR_H_ ================================================ FILE: oneflow/user/kernels/random_mask_like_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/random_mask_like_kernel.h" namespace oneflow { namespace { #define REGISTER_RANDOM_MASK_LIKE_KERNEL(device) \ REGISTER_USER_KERNEL("random_mask_like") \ .SetCreateFn>() \ .SetIsMatchedHob(user_op::HobDeviceType() == device); REGISTER_RANDOM_MASK_LIKE_KERNEL(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_RANDOM_MASK_LIKE_KERNEL(DeviceType::kCUDA) #endif } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/random_mask_like_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_ #define ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/random_mask_generator.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/core/ep/include/device.h" namespace oneflow { class RandomMaskLikeKernelState : public user_op::OpKernelState { public: explicit RandomMaskLikeKernelState(const std::shared_ptr& generator) : generator_(generator) {} const std::shared_ptr& generator() const { return generator_; } private: std::shared_ptr generator_; }; namespace { template class RandomMaskLikeKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: RandomMaskLikeKernel() = default; ~RandomMaskLikeKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(device_type)); generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const user_op::Tensor* like = ctx->Tensor4ArgNameAndIndex("like", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t elem_cnt = like->shape_view().elem_cnt(); bool* mask = out->mut_dptr(); auto* random_mask_like_state = dynamic_cast(state); CHECK_NOTNULL(random_mask_like_state); const auto& generator = random_mask_like_state->generator(); CHECK_NOTNULL(generator); auto* stream = ctx->stream(); const auto device_index = stream->device()->device_index(); auto random_mask_like_gen = std::make_shared>(generator, device_index); random_mask_like_gen->Generate(stream, elem_cnt, ctx->Attr("rate"), mask); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/random_seed_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { Maybe GetOpKernelRandomSeed(const user_op::KernelInitContext* ctx) { int64_t seed = ctx->Attr("seed"); if (!ctx->Attr("has_seed")) { seed = NewRandomSeed(); } return GetOpKernelRandomSeedInCurrentRank(ctx, seed); } // NOTE: Get random seed in current rank, and ensure that it will have same seed between // broadcast sbp and it will be different between split sbp. // // It will scan nd_sbp from last axis to first axis(It likes the algorithm in NdIndexOffsetHelper). // If sbp is broadcast, this axis will skip. 
// If sbp is split, it will use rand_id to accumulate the offset. Maybe GetRandomSeedForRank(const ParallelDesc& placement, const NdSbp& nd_sbp, uint64_t init_seed, int64_t rank_id) { uint64_t seed = init_seed; const Shape& hierarchy = *placement.hierarchy(); int64_t seed_idx = 0; int64_t stride = 1; for (int i = nd_sbp.sbp_parallel_size() - 1; i >= 0; --i) { // coordinate at axis i int coord = rank_id % hierarchy.At(i); rank_id = (rank_id - coord) / hierarchy.At(i); // coordinate reset to 0 if broadcast if (nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { // do nothing } else if (nd_sbp.sbp_parallel(i).has_split_parallel()) { seed_idx += coord * stride; stride *= hierarchy.At(i); } else { // other sbp is not allowed return Error::RuntimeError() << "random source op only support broadcast or split"; } } std::seed_seq seq{init_seed}; std::vector seeds(stride); seq.generate(seeds.begin(), seeds.end()); seed = JUST(VectorAt(seeds, seed_idx)); return seed; } Maybe GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx, uint64_t init_seed, const user_op::OpArg& arg) { if (ctx->parallel_ctx().parallel_num() == 1) { return init_seed; } CHECK_OR_RETURN(ctx->has_output(arg.name(), arg.index())) << arg.name() << "_" << arg.index() << " not exist"; const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(arg.name(), arg.index()); return GetRandomSeedForRank(ctx->parallel_desc(), nd_sbp, init_seed, ctx->parallel_ctx().parallel_id()); } Maybe GetGeneratorForLazyOrGlobal(const std::shared_ptr& generator, bool is_lazy, const Optional>& placement, const Optional>& nd_sbp) { bool is_global = placement.has_value() && nd_sbp.has_value(); if (!is_lazy && !is_global) { return generator; } const auto& eager_cached_generator = generator->children_generators(); if (!is_lazy) { Symbol placement_val = JUST(placement); Symbol nd_sbp_val = JUST(nd_sbp); if (eager_cached_generator.find(std::make_pair(placement_val, nd_sbp_val)) != eager_cached_generator.end()) { return JUST(MapAt(eager_cached_generator, std::make_pair(placement_val, nd_sbp_val))); } } uint64_t init_seed = 0; if (is_lazy) { auto cpu_gen = JUST(generator->Get(0)); CHECK_OR_RETURN(cpu_gen) << "expect a CPUGenerator"; init_seed = cpu_gen->engine()(); } else { init_seed = generator->current_seed(); } auto new_gen = JUST(one::MakeGenerator(JUST(generator->device())->type())); if (is_lazy) { new_gen->set_current_seed(init_seed); return new_gen; } uint64_t rank_seed = init_seed; if (JUST(placement)->parallel_num() > 1) { JUST(one::functional::BroadcastSeedToAllRanks(&init_seed, /*root=*/0)); rank_seed = JUST( GetRandomSeedForRank(*JUST(placement), *JUST(nd_sbp), init_seed, GlobalProcessCtx::Rank())); } new_gen->set_current_seed(rank_seed); if (!is_lazy) { generator->add_children_generator(JUST(placement), JUST(nd_sbp), new_gen); } return new_gen; } Maybe GetGeneratorForLazyOrGlobal(const std::shared_ptr& generator, bool is_lazy, const std::shared_ptr& input) { if (input->is_global()) { return GetGeneratorForLazyOrGlobal(generator, is_lazy, JUST(input->parallel_desc()), JUST(input->nd_sbp())); } else { return GetGeneratorForLazyOrGlobal(generator, is_lazy, NullOpt, NullOpt); } } } // namespace oneflow ================================================ FILE: oneflow/user/kernels/random_seed_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
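// ---- Illustrative aside (not part of the repository) ----------------------------------------
// GetRandomSeedForRank above turns (placement hierarchy, nd_sbp, rank id) into an index into a
// seed_seq-generated table: broadcast axes are skipped, so every rank in a broadcast group picks
// the same seed, while split axes contribute coord * stride like a row-major offset. A simplified
// standalone sketch of that computation; plain vectors and a bool-per-axis stand in for the
// ParallelDesc/NdSbp types.
#include <cstdint>
#include <random>
#include <vector>

uint64_t SeedForRankSketch(const std::vector<int64_t>& hierarchy,
                           const std::vector<bool>& is_split, uint64_t init_seed,
                           int64_t rank_id) {
  int64_t seed_idx = 0;
  int64_t stride = 1;
  for (int64_t i = static_cast<int64_t>(hierarchy.size()) - 1; i >= 0; --i) {
    const int64_t coord = rank_id % hierarchy[i];
    rank_id /= hierarchy[i];
    if (is_split[i]) {  // split axis: distinct coordinate -> distinct seed
      seed_idx += coord * stride;
      stride *= hierarchy[i];
    }  // broadcast axis: coordinate ignored, so all its ranks share seed_idx
  }
  std::seed_seq seq{init_seed};
  std::vector<uint64_t> seeds(stride);  // one seed per distinct split-coordinate combination
  seq.generate(seeds.begin(), seeds.end());
  return seeds[seed_idx];
}
// ----------------------------------------------------------------------------------------------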
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_RANDOM_SEED_UTIL_H_ #define ONEFLOW_USER_KERNELS_RANDOM_SEED_UTIL_H_ #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/random_generator.h" namespace oneflow { Maybe GetRandomSeedForRank(const ParallelDesc& placement, const NdSbp& nd_sbp, uint64_t init_seed, int64_t rank_id); Maybe GetOpKernelRandomSeed(const user_op::KernelInitContext* ctx); Maybe GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx, uint64_t init_seed, const user_op::OpArg& arg = {"out", 0}); Maybe GetGeneratorForLazyOrGlobal(const std::shared_ptr& generator, bool is_lazy, const Optional>& placement, const Optional>& nd_sbp); Maybe GetGeneratorForLazyOrGlobal(const std::shared_ptr& generator, bool is_lazy, const std::shared_ptr& input); } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_RANDOM_SEED_UTIL_H_ ================================================ FILE: oneflow/user/kernels/randperm_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/container_util.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/arange_kernel_util.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { class CpuRandPermKernelCache final : public user_op::OpKernelCache { public: CpuRandPermKernelCache(int32_t lower, int32_t upper) : lower_(lower), upper_(upper) {} ~CpuRandPermKernelCache() override = default; int32_t lower() const { return lower_; } int32_t upper() const { return upper_; } private: const int32_t lower_; const int32_t upper_; }; class CpuRandPermKernel final : public user_op::OpKernel { public: CpuRandPermKernel() = default; ~CpuRandPermKernel() = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { int64_t parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); int64_t parallel_id = ctx->parallel_ctx().parallel_id(); int32_t n = ctx->Attr("n"); const Shape& logical_shape = Shape({n}); TensorSliceView view = GetTensorSliceView4ParallelId(hierarchy, nd_sbp, logical_shape, parallel_id); std::shared_ptr cache( new CpuRandPermKernelCache(view.At(0).begin(), view.At(0).end())); return cache; } else { return nullptr; } } std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(kCPU)); generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int32_t* output = out->mut_dptr(); const int32_t n = ctx->Attr("n"); if (n == 0) { return; } user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int32_t* temp = tmp_buffer->mut_dptr(); auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); const auto& cpu_generator = CHECK_JUST(generator->Get()); CHECK_NOTNULL(generator); if (cache == nullptr) { user_op::ArangeFunctor()(ctx->stream(), 0, 1, n, output); std::shuffle(output, output + n, cpu_generator->engine()); } else { const auto* arange_cache = dynamic_cast(cache); user_op::ArangeFunctor()(ctx->stream(), 0, 1, n, temp); std::shuffle(temp, temp + n, cpu_generator->engine()); auto len = arange_cache->upper() - arange_cache->lower(); memcpy(output, temp + arange_cache->lower(), sizeof(int32_t) * len); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("randperm") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { const int32_t n = ctx->Attr("n"); return n * sizeof(int32_t); }); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/randperm_kernel.cu ================================================ /* Copyright 2020 The 
OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/arange_kernel_util.h" #include "oneflow/user/kernels/radix_sort.cuh" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/distributions/distribution_template_util.cuh" #include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { __global__ void GeneKeysAndValues(const int32_t n, uint64_t seed, uint64_t offset, int32_t* values, int32_t* keys) { const int id = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, id, offset, &state); CUDA_1D_KERNEL_LOOP(i, n) { keys[i] = curand(&state); values[i] = i; } } __global__ void tempcopy2output(const int32_t n, const int32_t offset, int32_t* temp, int32_t* output) { CUDA_1D_KERNEL_LOOP(i, n) { output[i] = temp[offset + i]; } } class GpuRandPermKernelCache final : public user_op::OpKernelCache { public: GpuRandPermKernelCache(int32_t lower, int32_t upper) : lower_(lower), upper_(upper) {} ~GpuRandPermKernelCache() override = default; int32_t lower() const { return lower_; } int32_t upper() const { return upper_; } private: const int32_t lower_; const int32_t upper_; }; namespace { template size_t GetCubSortPairsTempStorageSize(int64_t n) { size_t cub_sort_temp_store_size = 0; OF_CUDA_CHECK((cub::DeviceRadixSort::SortPairs(nullptr, cub_sort_temp_store_size, nullptr, nullptr, nullptr, nullptr, n))); size_t temp_store_size = GetCudaAlignedSize(cub_sort_temp_store_size); CHECK_GE(temp_store_size, 0) << "temp_store_size should >= 0."; CHECK_LT(temp_store_size, static_cast(GetMaxVal())) << "temp_store_size should < " << static_cast(GetMaxVal()); return temp_store_size; } } // namespace class GpuRandPermKernel final : public user_op::OpKernel { public: GpuRandPermKernel() = default; ~GpuRandPermKernel() = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { int64_t parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); int64_t parallel_id = ctx->parallel_ctx().parallel_id(); int32_t n = ctx->Attr("n"); const Shape& logical_shape = Shape({n}); TensorSliceView view = GetTensorSliceView4ParallelId(hierarchy, nd_sbp, logical_shape, parallel_id); std::shared_ptr cache( new GpuRandPermKernelCache(view.At(0).begin(), view.At(0).end())); return cache; } else { return nullptr; } } std::shared_ptr 
CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(kCUDA)); generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int32_t* output = out->mut_dptr(); const int32_t n = ctx->Attr("n"); if (n == 0) { return; } user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); CHECK_NOTNULL(generator); auto* stream = ctx->stream(); const auto device_index = stream->device()->device_index(); const auto& gpu_generator = CHECK_JUST(generator->Get(device_index)); ep::CudaStream* cuda_stream = stream->As(); auto execution_policy = gpu_generator->CalcExecutionPolicy(n, cuda_stream); auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); uint64_t seed = gpu_generator->current_seed(); uint64_t offset = gpu_generator->get_philox_offset(counter_offset); // layout for tmp |...key(in and out,2xN)..|....value....|.... space for sort function....| // values are the desired indexes ,and keys are generated randomly. void* tmp = tmp_buffer->mut_dptr(); int32_t* key_base = reinterpret_cast(tmp); const int32_t key_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t)); int32_t* value_base = reinterpret_cast(reinterpret_cast(key_base) + 2 * key_aligned_bytes); const int32_t indices_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t)); int32_t* temp_buffer_base = reinterpret_cast(reinterpret_cast(value_base) + indices_aligned_bytes); const int32_t temp_buffer_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t)); void* tmp_base = reinterpret_cast(reinterpret_cast(temp_buffer_base) + temp_buffer_aligned_bytes); size_t temp_storage_bytes = GetCubSortPairsTempStorageSize(n); GeneKeysAndValues<<As()->cuda_stream()>>>( n, seed, offset, value_base, key_base); if (cache == nullptr) { auto err = cub::DeviceRadixSort::SortPairs( /* d_temp_storage */ tmp_base, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ key_base, /* d_keys_out */ key_base + n, /* d_values_in */ value_base, /* d_values_out */ output, /* num_items */ n, /* begin_bit */ 0, /* end_bit */ sizeof(int32_t) * 8, /* stream */ ctx->stream()->As()->cuda_stream()); OF_CUDA_CHECK(err); } else { auto err = cub::DeviceRadixSort::SortPairs( /* d_temp_storage */ tmp_base, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ key_base, /* d_keys_out */ key_base + n, /* d_values_in */ value_base, /* d_values_out */ temp_buffer_base, /* num_items */ n, /* begin_bit */ 0, /* end_bit */ sizeof(int32_t) * 8, /* stream */ ctx->stream()->As()->cuda_stream()); OF_CUDA_CHECK(err); const auto* randperm_cache = dynamic_cast(cache); auto len = randperm_cache->upper() - randperm_cache->lower(); const int64_t offset = randperm_cache->lower(); int32_t block_num = gpu_generator->max_block_num(); tempcopy2output<<stream()->As()->cuda_stream()>>>( len, offset, temp_buffer_base, output); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("randperm") .SetCreateFn() 
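// Note: the tmp_buffer sized by SetInferTmpSizeFn below mirrors the layout consumed in Compute:
// | keys (in + out, 2 * N) | values (N) | per-rank staging copy (N) | cub radix-sort workspace |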
.SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { const int32_t n = ctx->Attr("n"); /* Sorted In */ const int32_t sorted_in_aligned_bytes = 2 * GetCudaAlignedSize(n * sizeof(int32_t)); /* Indices */ const int32_t indices_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t)); const int32_t temp_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t)); /* CUB Temp Storage */ const int32_t temp_storage_bytes = GetCubSortPairsTempStorageSize(n); return sorted_in_aligned_bytes + indices_aligned_bytes + temp_storage_bytes + temp_aligned_bytes; }); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/raw_reader_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/buffer.h" #include "oneflow/core/embedding/posix_file.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/channel.h" namespace oneflow { namespace { struct Block { size_t file_index; size_t offset_in_file; }; struct BatchReaderRequest { std::shared_ptr> blocks; void* buffer{}; }; class BatchReader { public: OF_DISALLOW_COPY_AND_MOVE(BatchReader); BatchReader(std::vector>&& files, std::vector&& blocks, size_t block_size_bytes, size_t num_workers) : head_(0), tail_(0), files_(std::move(files)), blocks_(blocks), block_size_bytes_(block_size_bytes), num_workers_(num_workers) { for (size_t i = 0; i < num_workers_; ++i) { Worker worker; auto* sq = new Channel(); auto* cq = new Channel(); worker.sq.reset(sq); worker.cq.reset(cq); worker.thread = std::thread([sq, cq, this]() { while (true) { BatchReaderRequest request; auto status = sq->Receive(&request); if (status == kChannelStatusErrorClosed) { break; } CHECK_EQ(status, kChannelStatusSuccess) << "channel error"; size_t buffer_offset = 0; for (size_t i = 0; i < request.blocks->size(); ++i) { size_t block_index = request.blocks->at(i); const Block& block = blocks_[block_index]; size_t remaining = block_size_bytes_; size_t file_index = block.file_index; size_t file_offset = block.offset_in_file; while (remaining != 0) { const size_t bytes_to_read = std::min(remaining, files_.at(file_index)->Size() - file_offset); PCHECK(pread(files_[file_index]->fd(), reinterpret_cast(request.buffer) + buffer_offset, bytes_to_read, file_offset) == bytes_to_read) << "file read error"; remaining -= bytes_to_read; buffer_offset += bytes_to_read; if (remaining != 0) { file_index = (file_index + 1) % files_.size(); file_offset = 0; } } } CHECK(cq->Send(std::move(request)) == kChannelStatusSuccess) << "channel error"; } }); workers_.emplace_back(std::move(worker)); } } ~BatchReader() { for (auto& work : workers_) { work.Close(); } } void SubmitRequest(BatchReaderRequest&& request) { size_t worker_id = head_.fetch_add(1, std::memory_order_relaxed) % workers_.size(); 
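// Requests are handed out to the worker channels round-robin: head_ counts submissions and
// tail_ counts completions, so WaitCompleted drains the same worker queues in the same order
// the requests were submitted.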
workers_.at(worker_id).sq->Send(std::move(request)); } void WaitCompleted(BatchReaderRequest* request) { size_t worker_id = tail_.fetch_add(1, std::memory_order_relaxed) % workers_.size(); workers_.at(worker_id).cq->Receive(request); } private: struct Worker { std::thread thread; std::unique_ptr> sq; std::unique_ptr> cq; void Close() { sq->Close(); cq->Close(); thread.join(); } }; std::atomic head_; std::atomic tail_; std::vector workers_; std::vector> files_; std::vector blocks_; size_t block_size_bytes_; size_t num_workers_; }; size_t GetNumShards(const Shape& hierarchy, const NdSbp& nd_sbp) { size_t num_shards = 1; FOR_RANGE(size_t, i, 0, nd_sbp.sbp_parallel_size()) { const auto& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { num_shards *= hierarchy.At(sbp_parallel.split_parallel().axis()); } } return num_shards; } size_t GetShardIndex(const Shape& hierarchy, const NdSbp& nd_sbp, size_t rank) { using index_helper_t = NdIndexOffsetHelper; size_t ndim = hierarchy.NumAxes(); CHECK_GT(ndim, 0) << "wrong hierarchy"; CHECK_LE(ndim, SHAPE_MAX_AXIS_SIZE) << "wrong hierarchy"; index_helper_t index_helper(hierarchy.dim_vec().data(), ndim); int64_t nd_index[SHAPE_MAX_AXIS_SIZE] = {0}; index_helper.OffsetToNdIndex(rank, nd_index); size_t stride = 1; size_t index = 0; for (int i = ndim - 1; i >= 0; --i) { const auto& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { index += nd_index[i] * stride; stride *= hierarchy.At(i); } } return index; } class BatchGenerator { public: OF_DISALLOW_COPY_AND_MOVE(BatchGenerator); BatchGenerator() = default; virtual ~BatchGenerator() = default; virtual void Next(size_t* blocks) = 0; }; class SequentialBatchGenerator : public BatchGenerator { public: OF_DISALLOW_COPY_AND_MOVE(SequentialBatchGenerator); SequentialBatchGenerator(size_t shard_index, size_t num_shards, size_t num_batches, size_t num_blocks_per_batch) : shard_index_(shard_index), num_shards_(num_shards), num_batches_(num_batches), num_blocks_per_batch_(num_blocks_per_batch), num_blocks_per_local_batch_(num_blocks_per_batch_ / num_shards_), next_batch_index_(0) {} ~SequentialBatchGenerator() override = default; void Next(size_t* blocks) override { const size_t batch_index = next_batch_index_; next_batch_index_ = (batch_index + 1) % num_batches_; for (size_t i = 0; i < num_blocks_per_local_batch_; ++i) { blocks[i] = batch_index * num_blocks_per_batch_ + shard_index_ * num_blocks_per_local_batch_ + i; } } private: size_t shard_index_; size_t num_shards_; size_t num_batches_; size_t num_blocks_per_batch_; size_t num_blocks_per_local_batch_; size_t next_batch_index_; }; class RandomShuffleBatchGenerator : public BatchGenerator { public: OF_DISALLOW_COPY_AND_MOVE(RandomShuffleBatchGenerator); RandomShuffleBatchGenerator(size_t shard_index, size_t num_shards, size_t num_batches, size_t num_blocks_per_batch, std::mt19937_64 generator) : shard_index_(shard_index), num_shards_(num_shards), num_batches_(num_batches), num_blocks_per_batch_(num_blocks_per_batch), num_blocks_per_local_batch_(num_blocks_per_batch_ / num_shards_), current_batch_pos_(0), generator_(generator) { batches_.resize(num_batches_); std::iota(batches_.begin(), batches_.end(), 0); } ~RandomShuffleBatchGenerator() override = default; void Next(size_t* blocks) override { size_t target_batch_pos = generator_() % (batches_.size() - current_batch_pos_) + current_batch_pos_; if (target_batch_pos != current_batch_pos_) { std::swap(batches_[target_batch_pos], batches_[current_batch_pos_]); } 
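// The swap above is a Fisher-Yates style draw without replacement: a random not-yet-visited
// batch is moved into the current position, so every batch index is produced exactly once per
// pass over batches_.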
const size_t batch_index = batches_[current_batch_pos_]; for (size_t i = 0; i < num_blocks_per_local_batch_; ++i) { blocks[i] = batch_index * num_blocks_per_batch_ + shard_index_ * num_blocks_per_local_batch_ + i; } current_batch_pos_ = (current_batch_pos_ + 1) % batches_.size(); if (current_batch_pos_ == 0) { shard_index_ = (shard_index_ + 1) % num_shards_; } } private: size_t shard_index_; size_t num_shards_; size_t num_batches_; size_t num_blocks_per_batch_; size_t num_blocks_per_local_batch_; std::vector batches_; size_t current_batch_pos_; std::mt19937_64 generator_; }; class RawReaderKernelState final : public user_op::OpKernelState { public: OF_DISALLOW_COPY_AND_MOVE(RawReaderKernelState); explicit RawReaderKernelState(user_op::KernelInitContext* ctx) { const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); num_shards_ = GetNumShards(*ctx->parallel_desc().hierarchy(), nd_sbp); shard_index_ = GetShardIndex(*ctx->parallel_desc().hierarchy(), nd_sbp, ctx->parallel_ctx().parallel_id()); batch_size_ = ctx->Attr("batch_size"); CHECK_EQ(batch_size_ % num_shards_, 0) << "batch_size must be a multiple of num_shards"; local_batch_size_ = batch_size_ / num_shards_; random_shuffle_ = ctx->Attr("random_shuffle"); block_size_ = ctx->Attr("shuffle_block_size"); if (block_size_ <= 0 || !random_shuffle_) { block_size_ = local_batch_size_; } CHECK_EQ(batch_size_ % block_size_, 0) << "batch_size must be a multiple of block_size"; if (block_size_ > local_batch_size_) { block_size_ = local_batch_size_; } const std::vector& filenames = ctx->Attr>("files"); const Shape& instance_shape = ctx->Attr("shape"); const size_t elem_cnt = instance_shape.elem_cnt(); CHECK_GT(elem_cnt, 0) << "instance size must be greater than 0"; DimVector dim_vec; dim_vec.push_back(local_batch_size_); for (int64_t i = 0; i < instance_shape.NumAxes(); ++i) { dim_vec.push_back(instance_shape.At(i)); } out_shape_ = Shape(dim_vec); data_type_ = ctx->Attr("data_type"); instance_size_ = ctx->Attr("shape").elem_cnt() * GetSizeOfDataType(data_type_); CHECK_GT(batch_size_, 0) << "batch size must be greater than 0"; size_t num_instances = 0; std::vector> files; int flags = O_RDONLY; if (ParseBooleanFromEnv("ONEFLOW_RAW_READER_FORCE_DIRECT_IO", false)) { flags |= O_DIRECT; } for (const auto& filename : filenames) { std::unique_ptr file(new embedding::PosixFile(filename, flags, 0644)); if (file->Size() == 0) { continue; } CHECK_EQ(file->Size() % instance_size_, 0) << "file_size must be a multiple of instance_size"; num_instances += file->Size() / instance_size_; files.emplace_back(std::move(file)); } if ((flags & O_DIRECT) != 0) { num_batches_ = num_instances / batch_size_; } else { num_batches_ = RoundUp(num_instances, batch_size_) / batch_size_; } block_size_bytes_ = block_size_ * instance_size_; local_batch_size_bytes_ = local_batch_size_ * instance_size_; num_blocks_per_local_batch_ = local_batch_size_ / block_size_; const size_t num_blocks = num_batches_ * (batch_size_ / block_size_); size_t file_index = 0; size_t offset_in_file = 0; std::vector blocks; for (size_t i = 0; i < num_blocks; ++i) { blocks.emplace_back(Block{file_index, offset_in_file}); size_t remaining = block_size_bytes_; while (remaining != 0) { if (files[file_index]->Size() - offset_in_file >= remaining) { offset_in_file += remaining; if (offset_in_file == files[file_index]->Size()) { offset_in_file = 0; } remaining = 0; } else { remaining -= (files[file_index]->Size() - offset_in_file); offset_in_file = 0; file_index = (file_index + 1) % files.size(); } } } if 
(random_shuffle_) { std::mt19937_64 generator; generator.seed(ctx->Attr("seed")); std::shuffle(blocks.begin(), blocks.end(), generator); batch_generator_.reset(new RandomShuffleBatchGenerator( shard_index_, num_shards_, num_batches_, batch_size_ / block_size_, generator)); } else { batch_generator_.reset(new SequentialBatchGenerator(shard_index_, num_shards_, num_batches_, batch_size_ / block_size_)); } const size_t num_workers = ParseIntegerFromEnv("ONEFLOW_RAW_READER_NUM_WORKERS", 1); batch_reader_.reset( new BatchReader(std::move(files), std::move(blocks), block_size_bytes_, num_workers)); prefetching_qd_ = ParseIntegerFromEnv("ONEFLOW_RAW_READER_PREFETCHING_QUEUE_DEPTH", 256); for (size_t i = 0; i < prefetching_qd_; ++i) { BatchReaderRequest request; request.blocks = std::make_shared>(); if (ctx->device_type() == DeviceType::kCPU) { request.buffer = aligned_alloc(4096, RoundUp(local_batch_size_bytes_, 4096)); // NOLINT } else if (ctx->device_type() == DeviceType::kCUDA) { #ifdef WITH_CUDA int dev = 0; OF_CUDA_CHECK(cudaGetDevice(&dev)); OF_CUDA_CHECK(NumaAwareCudaMallocHost(dev, &request.buffer, local_batch_size_bytes_)); #else UNIMPLEMENTED(); #endif } else { UNIMPLEMENTED(); } request.blocks = std::make_shared>(local_batch_size_ / block_size_); batch_generator_->Next(request.blocks->data()); batch_reader_->SubmitRequest(std::move(request)); } device_type_ = ctx->device_type(); } ~RawReaderKernelState() { for (size_t i = 0; i < prefetching_qd_; ++i) { BatchReaderRequest request; batch_reader_->WaitCompleted(&request); if (device_type_ == DeviceType::kCPU) { free(request.buffer); // NOLINT } else if (device_type_ == DeviceType::kCUDA) { #ifdef WITH_CUDA OF_CUDA_CHECK(cudaFreeHost(request.buffer)); #else UNIMPLEMENTED(); #endif } else { UNIMPLEMENTED(); } } } void Next(user_op::KernelComputeContext* ctx) { auto* tensor = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(tensor->data_type(), data_type_) << "data type mismatch"; CHECK(tensor->shape_view() == ShapeView(out_shape_)) << "shape mismatch"; BatchReaderRequest request; batch_reader_->WaitCompleted(&request); if (ctx->stream()->device_type() == DeviceType::kCPU) { std::memcpy(tensor->mut_dptr(), request.buffer, local_batch_size_bytes_); } else if (ctx->stream()->device_type() == DeviceType::kCUDA) { #ifdef WITH_CUDA OF_CUDA_CHECK(cudaMemcpyAsync(tensor->mut_dptr(), request.buffer, local_batch_size_bytes_, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); #else UNIMPLEMENTED(); #endif } else { UNIMPLEMENTED() << "only support CPU or CUDA"; } CHECK_JUST(ctx->stream()->Sync()); CHECK(request.blocks) << "blocks is NULL"; CHECK_EQ(request.blocks->size(), num_blocks_per_local_batch_) << "blocks size mismatch"; batch_generator_->Next(request.blocks->data()); batch_reader_->SubmitRequest(std::move(request)); } private: size_t instance_size_; size_t batch_size_; size_t local_batch_size_; size_t num_batches_; size_t num_shards_; size_t shard_index_; size_t block_size_; size_t block_size_bytes_; size_t num_blocks_per_local_batch_; size_t local_batch_size_bytes_; bool random_shuffle_; Shape out_shape_; DataType data_type_; std::unique_ptr batch_generator_; std::unique_ptr batch_reader_; DeviceType device_type_; size_t prefetching_qd_; }; } // namespace class RawReaderKernel final : public user_op::OpKernel { public: RawReaderKernel() = default; ~RawReaderKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { std::shared_ptr state(new RawReaderKernelState(ctx)); return state; } 
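// Compute (below) pops one completed batch from the reader and immediately submits a fresh
// request, so the prefetching queue depth chosen in RawReaderKernelState stays constant.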
private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { auto* reader = CHECK_NOTNULL(dynamic_cast(state)); reader->Next(ctx); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("raw_reader").SetCreateFn(); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/reduce_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #ifdef WITH_CUDA #include "oneflow/core/ep/cuda/cuda_device.h" #endif // WITH_CUDA #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewReduceMatmulTransAPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("input_tensor", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/false); } template std::unique_ptr NewReduceMatmulNoTransAPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("input_tensor", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/false); } template std::unique_ptr NewFillPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("output_tensor", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } auto ReduceMatmulTransAPrimitiveExists() { return hob::make_custom("ReduceMatmulTransAPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewReduceMatmulTransAPrimitive(&ctx).operator bool(); }); } auto ReduceMatmulNoTransAPrimitiveExists() { return hob::make_custom("ReduceMatmulNoTransAPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewReduceMatmulNoTransAPrimitive(&ctx).operator bool(); }); } auto FillPrimitiveExists() { return hob::make_custom("FillPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewFillPrimitive(&ctx).operator bool(); }); } template class BinaryFunc, DeviceType device_type, typename T, typename K> 
class ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ReduceKernel() = default; ~ReduceKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input_tensor", 0); user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto& axis = ctx->Attr>("axis"); const int32_t output_elem_cnt = output_tensor->shape_view().elem_cnt(); if (input_tensor->shape_view().elem_cnt() == 0) { if (output_tensor->shape_view().elem_cnt() != 0) { Scalar init_value = [&]() { if (std::is_same, BinaryFuncAny>::value) { return Scalar(0); } if (std::is_same, BinaryFuncAll>::value) { return Scalar(1); } return Scalar(0); }(); CHECK_GE(output_elem_cnt, 0); if (output_elem_cnt == 0) { return; } std::unique_ptr fill = NewFillPrimitive(ctx); CHECK(fill); fill->Launch(ctx->stream(), output_tensor->mut_dptr(), init_value, output_elem_cnt); } return; } const Shape& reduced_shape = CreateReducedShape(input_tensor->shape_view(), {axis.begin(), axis.end()}); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, output_tensor->mut_dptr()), XpuVarNdarray(input_tensor->shape_view(), input_tensor->dptr()), XpuVarNdarray(tmp_buffer->shape_view(), tmp_buffer->mut_dptr())); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_REDUCE_XPU_KERNEL(op_name, binary_func, device, dtype) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("output_tensor", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("input_tensor", 0); \ return in_shape.elem_cnt() * sizeof(dtype); \ }); #define REGISTER_REDUCE_LOGICAL_XPU_KERNEL(op_name, binary_func, device, dtype) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("input_tensor", 0) == GetDataType::value) \ && (user_op::HobDataType("output_tensor", 0) == DataType::kBool) \ && FillPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("input_tensor", 0); \ return in_shape.elem_cnt() * sizeof(dtype); \ }); #define REGISTER_REDUCE_ARITHMETIC_KERNELS(device, dtype) \ REGISTER_REDUCE_XPU_KERNEL("reduce_prod", BinaryFuncProd, device, dtype) \ REGISTER_REDUCE_XPU_KERNEL("reduce_min", BinaryFuncMin, device, dtype) \ REGISTER_REDUCE_XPU_KERNEL("reduce_max", BinaryFuncMax, device, dtype) #define REGISTER_REDUCE_NANSUM_KERNELS(device, dtype) \ REGISTER_REDUCE_XPU_KERNEL("reduce_nansum", BinaryFuncNanSum, device, dtype) #define REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(device) \ REGISTER_REDUCE_ARITHMETIC_KERNELS(device, bool) \ REGISTER_REDUCE_ARITHMETIC_KERNELS(device, float) \ REGISTER_REDUCE_ARITHMETIC_KERNELS(device, double) \ REGISTER_REDUCE_ARITHMETIC_KERNELS(device, int8_t) \ REGISTER_REDUCE_ARITHMETIC_KERNELS(device, uint8_t) \ REGISTER_REDUCE_ARITHMETIC_KERNELS(device, int32_t) \ REGISTER_REDUCE_ARITHMETIC_KERNELS(device, int64_t) #define REGISTER_REDUCE_NANSUM_KERNELS_BY_DEVICE(device) \ REGISTER_REDUCE_NANSUM_KERNELS(device, float) \ REGISTER_REDUCE_NANSUM_KERNELS(device, double) REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCPU) 
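// reduce_nansum is only registered for floating point types below; NaN handling is not
// meaningful for the bool/integer kernels registered above.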
REGISTER_REDUCE_NANSUM_KERNELS_BY_DEVICE(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCUDA) REGISTER_REDUCE_NANSUM_KERNELS_BY_DEVICE(DeviceType::kCUDA) #endif #define REGISTER_REDUCE_SUM_KERNELS(device, dtype) \ REGISTER_REDUCE_XPU_KERNEL("reduce_sum", BinaryFuncSum, device, dtype) #define REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(device) \ REGISTER_REDUCE_SUM_KERNELS(device, double) \ REGISTER_REDUCE_SUM_KERNELS(device, int8_t) \ REGISTER_REDUCE_SUM_KERNELS(device, uint8_t) \ REGISTER_REDUCE_SUM_KERNELS(device, int32_t) \ REGISTER_REDUCE_SUM_KERNELS(device, int64_t) REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, std::complex) REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, std::complex) #ifdef WITH_CUDA REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCUDA, cuComplex) REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCUDA, cuDoubleComplex) #endif REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCUDA) #endif REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, float) REGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, float16) #define REGISTER_REDUCE_LOGICAL_KERNELS(device) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, bool) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, bool) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, float) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, float) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, double) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, double) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, int8_t) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, int8_t) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, uint8_t) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, uint8_t) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, int32_t) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, int32_t) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_any", BinaryFuncAny, device, int64_t) \ REGISTER_REDUCE_LOGICAL_XPU_KERNEL("reduce_all", BinaryFuncAll, device, int64_t) REGISTER_REDUCE_LOGICAL_KERNELS(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_REDUCE_LOGICAL_KERNELS(DeviceType::kCUDA) namespace { std::vector RegularAxis(const std::vector& axis) { std::vector regular_axis = axis; std::sort(regular_axis.begin(), regular_axis.end()); return regular_axis; } void GetReduceSumLayout(const std::vector& axis, const ShapeView& in_shape, bool* is_axis_contiguous, int64_t* outer_size, int64_t* inner_size, int64_t* reduce_size) { if (!axis.empty()) { *is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size()); *outer_size = in_shape.Count(0, axis.front()); *inner_size = in_shape.Count(axis.back() + 1); *reduce_size = in_shape.Count(axis.front(), axis.back() + 1); } } } // namespace class ReduceSumHalfKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ReduceSumHalfKernel() = default; ~ReduceSumHalfKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { std::vector axis = RegularAxis(ctx->Attr>("axis")); const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input_tensor", 0); user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0); user_op::Tensor* tmp_buffer = 
ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& in_shape = input_tensor->shape_view(); const DataType data_type = input_tensor->data_type(); bool is_axis_contiguous = false; int64_t outer_size = 0, inner_size = 0, reduce_size = 0; GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size, &reduce_size); if (is_axis_contiguous && (outer_size == 1 || inner_size == 1)) { bool trans_a = (inner_size != 1); const int32_t m = (inner_size == 1) ? outer_size : inner_size; const int32_t n = 1; const int32_t k = reduce_size; const void* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); if (cuda_device != nullptr) { ones = cuda_device->GetConstOnes(data_type, reduce_size); } if (ones == nullptr) { std::unique_ptr fill = ep::primitive::NewPrimitive(ctx->stream()->device_type(), data_type); CHECK(fill); fill->Launch(ctx->stream(), tmp_buffer->mut_dptr(), 1.0, reduce_size); ones = tmp_buffer->dptr(); } std::unique_ptr matmul; if (trans_a) { matmul = NewReduceMatmulTransAPrimitive(ctx); } else { matmul = NewReduceMatmulNoTransAPrimitive(ctx); } matmul->Launch(ctx->stream(), m, n, k, 1.0, input_tensor->dptr(), ones, 0.0, output_tensor->mut_dptr()); } else { const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()}); float* in_tmp_buffer = tmp_buffer->mut_dptr(); const size_t in_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); float* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr() + in_tmp_buffer_bytes); const size_t out_tmp_buffer_bytes = GetCudaAlignedSize(reduced_shape.elem_cnt() * sizeof(float)); float* reduce_tmp_buffer = reinterpret_cast( tmp_buffer->mut_dptr() + in_tmp_buffer_bytes + out_tmp_buffer_bytes); const size_t reduce_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); CHECK_LE(in_tmp_buffer_bytes + out_tmp_buffer_bytes + reduce_tmp_buffer_bytes, tmp_buffer->shape_view().elem_cnt()); auto h2f = ep::primitive::NewPrimitive( ctx->device_type(), data_type, DataType::kFloat); CHECK(h2f); auto f2h = ep::primitive::NewPrimitive( ctx->device_type(), DataType::kFloat, data_type); CHECK(f2h); h2f->Launch(ctx->stream(), input_tensor->dptr(), in_tmp_buffer, in_shape.elem_cnt()); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, out_tmp_buffer), XpuVarNdarray(in_shape, in_tmp_buffer), XpuVarNdarray(in_shape, reduce_tmp_buffer)); f2h->Launch(ctx->stream(), out_tmp_buffer, output_tensor->mut_dptr(), output_tensor->shape_view().elem_cnt()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_REDUCE_SUM_HALF_KERNEL(dtype) \ REGISTER_USER_KERNEL("reduce_sum") \ .SetCreateFn() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("output_tensor", 0) == GetDataType::value) \ && ReduceMatmulTransAPrimitiveExists() \ && ReduceMatmulNoTransAPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputTensorDesc("input_tensor", 0).shape(); \ const Shape& out_shape = ctx->OutputTensorDesc("output_tensor", 0).shape(); \ const auto& axis = RegularAxis(ctx->Attr>("axis")); \ bool is_axis_contiguous = false; \ int64_t outer_size = 0, inner_size = 0, reduce_size = 0; \ GetReduceSumLayout(axis, ShapeView(in_shape), &is_axis_contiguous, &outer_size, \ &inner_size, &reduce_size); \ size_t tmp_bytes = 0; \ if (is_axis_contiguous && (outer_size == 1 || inner_size == 1)) { \ tmp_bytes = GetCudaAlignedSize(reduce_size * 
sizeof(dtype)); \ } else { \ tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)) \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float))); \ } \ return tmp_bytes; \ }); REGISTER_REDUCE_SUM_HALF_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_REDUCE_SUM_HALF_KERNEL(nv_bfloat16) #endif class ReduceSumFloatCudaKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ReduceSumFloatCudaKernel() = default; ~ReduceSumFloatCudaKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { std::vector axis = RegularAxis(ctx->Attr>("axis")); const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input_tensor", 0); user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& in_shape = input_tensor->shape_view(); if (input_tensor->shape_view().elem_cnt() == 0) { if (output_tensor->shape_view().elem_cnt() != 0) { Memset( ctx->stream(), output_tensor->mut_dptr(), 0, output_tensor->shape_view().elem_cnt() * GetSizeOfDataType(output_tensor->data_type())); } return; } bool is_axis_contiguous = false; int64_t outer_size = 0, inner_size = 0, reduce_size = 0; GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size, &reduce_size); const float* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); if (cuda_device != nullptr) { ones = static_cast(cuda_device->GetConstOnes(DataType::kFloat, reduce_size)); } if ((!axis.empty()) && in_shape.NumAxes() > 0 && is_axis_contiguous && (outer_size == 1 || inner_size == 1) && ones != nullptr && ParseBooleanFromEnv("ONEFLOW_KERNEL_REDUCE_SUM_USE_MATMUL", false)) { ep::primitive::BlasTransposeType trans_a = (inner_size == 1) ? ep::primitive::BlasTransposeType::N : ep::primitive::BlasTransposeType::T; ep::primitive::BlasTransposeType trans_b = ep::primitive::BlasTransposeType::N; const int32_t m = (inner_size == 1) ? outer_size : inner_size; const int32_t n = 1; const int32_t k = reduce_size; #if CUDA_VERSION >= 11000 CublasMathModeGuard guard(ctx->stream()->As()->cublas_handle()); // disable tf32 guard.SetMathMode(CUBLAS_DEFAULT_MATH); #endif // defined(WITH_CUDA) && CUDA_VERSION >= 11000 auto matmul = ep::primitive::NewPrimitive( DeviceType::kCUDA, DataType::kFloat, trans_a, trans_b); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, 1.0, input_tensor->dptr(), ones, 0.0, output_tensor->mut_dptr()); } else { const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()}); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, output_tensor->mut_dptr()), XpuVarNdarray(input_tensor->shape_view(), input_tensor->dptr()), XpuVarNdarray(tmp_buffer->shape_view(), tmp_buffer->mut_dptr())); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("reduce_sum") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) && (user_op::HobDataType("output_tensor", 0) == DataType::kFloat)) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputTensorDesc("input_tensor", 0).shape(); return GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); }); #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/reduce_like_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewReduceMatmulTransAPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("y", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true, /*transpose_b=*/false); } template std::unique_ptr NewReduceMatmulNoTransAPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("y", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false, /*transpose_b=*/false); } auto ReduceMatmulTransAPrimitiveExists() { return hob::make_custom("ReduceMatmulTransAPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewReduceMatmulTransAPrimitive(&ctx).operator bool(); }); } auto ReduceMatmulNoTransAPrimitiveExists() { return hob::make_custom("ReduceMatmulNoTransAPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewReduceMatmulNoTransAPrimitive(&ctx).operator bool(); }); } size_t ReduceSumLikeInferTmpSize(user_op::InferContext* ctx) { if (ctx->Attr>("axis").empty()) { return 0; } const user_op::TensorDesc& tensor_desc_x = ctx->InputTensorDesc("x", 0); return tensor_desc_x.shape().elem_cnt() * GetSizeOfDataType(tensor_desc_x.data_type()); } } // namespace template class ReduceSumLikeOpKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ReduceSumLikeOpKernel() = default; ~ReduceSumLikeOpKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto& axis = ctx->Attr>("axis"); if (tensor_x->shape_view().elem_cnt() == 0) { if (tensor_y->shape_view().elem_cnt() != 0) { Memset( ctx->stream(), tensor_y->mut_dptr(), 0, tensor_y->shape_view().elem_cnt() * GetSizeOfDataType(tensor_y->data_type())); } return; } if (axis.empty()) { CHECK_EQ(tensor_x->shape_view(), tensor_y->shape_view()); Memcpy( ctx->stream(), tensor_y->mut_dptr(), tensor_x->dptr(), tensor_x->shape_view().elem_cnt() * GetSizeOfDataType(tensor_x->data_type())); } else { 
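// Non-empty axis: reduce through NdarrayUtil<device_type, T>::ReduceSum, which needs a scratch
// buffer of the same byte size as x (see ReduceSumLikeInferTmpSize above).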
user_op::Tensor* tensor_tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); T* temp_storage = static_cast(tensor_tmp->mut_dptr()); NdarrayUtil::ReduceSum( ctx->stream(), XpuVarNdarray(CreateReducedShape(tensor_x->shape_view(), {axis.begin(), axis.end()}), tensor_y->mut_dptr()), XpuVarNdarray(tensor_x->shape_view(), tensor_x->dptr(), tensor_x->shape_view().NumAxes()), XpuVarNdarray(tensor_x->shape_view(), temp_storage, tensor_x->shape_view().NumAxes())); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_REDUCE_SUM_LIKE_KERNEL(device, data_type_pair) \ REGISTER_USER_KERNEL("reduce_sum_like") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(data_type_pair))) \ .SetInferTmpSizeFn(ReduceSumLikeInferTmpSize); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), COMPLEX_DATA_TYPE_SEQ); #if defined(WITH_CUDA) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64)); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128)); #endif // WITH_CUDA #if defined(WITH_CUDA) namespace { std::vector RegularAxis(const std::vector& axis) { std::vector regular_axis = axis; std::sort(regular_axis.begin(), regular_axis.end()); return regular_axis; } void GetReduceSumLayout(const std::vector& axis, const ShapeView& in_shape, bool* is_axis_contiguous, int64_t* outer_size, int64_t* inner_size, int64_t* reduce_size) { *is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size()); *outer_size = in_shape.Count(0, axis.front()); *inner_size = in_shape.Count(axis.back() + 1); *reduce_size = in_shape.Count(axis.front(), axis.back() + 1); } } // namespace class ReduceSumLikeHalfKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ReduceSumLikeHalfKernel() = default; ~ReduceSumLikeHalfKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { std::vector axis = RegularAxis(ctx->Attr>("axis")); const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); if (axis.empty()) { CHECK_EQ(tensor_x->shape_view(), tensor_y->shape_view()); Memcpy( ctx->stream(), tensor_y->mut_dptr(), tensor_x->dptr(), tensor_x->shape_view().elem_cnt() * GetSizeOfDataType(tensor_x->data_type())); } else { user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const ShapeView& in_shape = tensor_x->shape_view(); bool is_axis_contiguous = false; int64_t outer_size = 0, inner_size = 0, reduce_size = 0; GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size, &reduce_size); if (is_axis_contiguous && (outer_size == 1 || inner_size == 1)) { bool trans_a = (inner_size != 1); const int32_t m = (inner_size == 1) ? 
outer_size : inner_size; const int32_t n = 1; const int32_t k = reduce_size; std::unique_ptr fill = ep::primitive::NewPrimitive(ctx->stream()->device_type(), tensor_x->data_type()); CHECK(fill); fill->Launch(ctx->stream(), tmp_buffer->mut_dptr(), 1.0, reduce_size); std::unique_ptr matmul; if (trans_a) { matmul = NewReduceMatmulTransAPrimitive(ctx); } else { matmul = NewReduceMatmulNoTransAPrimitive(ctx); } CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, 1.0, tensor_x->dptr(), tmp_buffer->dptr(), 0.0, tensor_y->mut_dptr()); } else { const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()}); float* in_tmp_buffer = tmp_buffer->mut_dptr(); const size_t in_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); float* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr() + in_tmp_buffer_bytes); const size_t out_tmp_buffer_bytes = GetCudaAlignedSize(reduced_shape.elem_cnt() * sizeof(float)); float* reduce_tmp_buffer = reinterpret_cast( tmp_buffer->mut_dptr() + in_tmp_buffer_bytes + out_tmp_buffer_bytes); const size_t reduce_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)); CHECK_LE(in_tmp_buffer_bytes + out_tmp_buffer_bytes + reduce_tmp_buffer_bytes, tmp_buffer->shape_view().elem_cnt()); auto h2f = ep::primitive::NewPrimitive( ctx->device_type(), tensor_x->data_type(), DataType::kFloat); CHECK(h2f); auto f2h = ep::primitive::NewPrimitive( ctx->device_type(), DataType::kFloat, tensor_x->data_type()); CHECK(f2h); h2f->Launch(ctx->stream(), tensor_x->dptr(), in_tmp_buffer, in_shape.elem_cnt()); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, out_tmp_buffer), XpuVarNdarray(in_shape, in_tmp_buffer), XpuVarNdarray(in_shape, reduce_tmp_buffer)); f2h->Launch(ctx->stream(), out_tmp_buffer, tensor_y->mut_dptr(), tensor_y->shape_view().elem_cnt()); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_REDUCE_SUM_LIKE_HALF_KERNEL(dtype) \ REGISTER_USER_KERNEL("reduce_sum_like") \ .SetCreateFn() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value) \ && ReduceMatmulTransAPrimitiveExists() \ && ReduceMatmulNoTransAPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputTensorDesc("x", 0).shape(); \ const Shape& out_shape = ctx->OutputTensorDesc("y", 0).shape(); \ const auto& axis = RegularAxis(ctx->Attr>("axis")); \ if (axis.empty()) { \ size_t tmp_bytes = 0; \ return tmp_bytes; \ } \ bool is_axis_contiguous = false; \ int64_t outer_size = 0, inner_size = 0, reduce_size = 0; \ GetReduceSumLayout(axis, ShapeView(in_shape), &is_axis_contiguous, &outer_size, \ &inner_size, &reduce_size); \ size_t tmp_bytes = 0; \ if (is_axis_contiguous && (outer_size == 1 || inner_size == 1)) { \ tmp_bytes = GetCudaAlignedSize(reduce_size * sizeof(dtype)); \ } else { \ tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)) \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float))); \ } \ return tmp_bytes; \ }); REGISTER_REDUCE_SUM_LIKE_HALF_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_REDUCE_SUM_LIKE_HALF_KERNEL(nv_bfloat16) #endif #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/reflection_pad_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/reflection_pad_kernels_util.h" namespace oneflow { namespace user_op { template class ReflectionPad1dKernel final : public OpKernel { public: ReflectionPad1dKernel() = default; ~ReflectionPad1dKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto& padding = ctx->Attr>("padding"); const int64_t ndims = x->shape_view().NumAxes(); CHECK_EQ(padding.size(), ndims - 1); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t w_idx = 2; const int64_t pad_left = padding[0]; const int64_t n_batch = y->shape_view().At(n_idx); const int64_t n_channel = y->shape_view().At(c_idx); const int64_t y_width = y->shape_view().At(w_idx); const int64_t x_width = x->shape_view().At(w_idx); IN_T* dest = y->mut_dptr(); const IN_T* src = x->dptr(); DimVector y_vector; y->shape_view().ToDimVector(&y_vector); NdIndexOffsetHelper index_helper(y_vector.data()); ReflectionPad1dFunctor()(ctx->stream(), src, dest, index_helper, n_batch, n_channel, y_width, x_width, pad_left); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ReflectionPad1dGradKernel final : public OpKernel { public: ReflectionPad1dGradKernel() = default; ~ReflectionPad1dGradKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto& padding = ctx->Attr>("padding"); const int64_t ndims = dy->shape_view().NumAxes(); CHECK_EQ(padding.size(), ndims - 1); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t w_idx = 2; const int64_t pad_left = padding[0]; const int64_t n_batch = dy->shape_view().At(n_idx); const int64_t n_channel = dy->shape_view().At(c_idx); const int64_t dy_width = dy->shape_view().At(w_idx); const int64_t dx_width = dx->shape_view().At(w_idx); const IN_T* src = dy->dptr(); IN_T* dest = dx->mut_dptr(); DimVector dy_vector; dy->shape_view().ToDimVector(&dy_vector); NdIndexOffsetHelper index_helper(dy_vector.data()); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); ReflectionPad1dGradFunctor()(ctx->stream(), src, dest, index_helper, n_batch, n_channel, dy_width, dx_width, pad_left); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ReflectionPad2dKernel final : public OpKernel { public: ReflectionPad2dKernel() = default; ~ReflectionPad2dKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto& padding = 
ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t h_idx = 2; const int64_t w_idx = 3; const int64_t pad_left = padding[0]; const int64_t pad_top = padding[2]; const int64_t n_batch = y->shape_view().At(n_idx); const int64_t n_channel = y->shape_view().At(c_idx); const int64_t y_height = y->shape_view().At(h_idx); const int64_t y_width = y->shape_view().At(w_idx); const int64_t x_height = x->shape_view().At(h_idx); const int64_t x_width = x->shape_view().At(w_idx); IN_T* dest = y->mut_dptr(); const IN_T* src = x->dptr(); DimVector y_vector; y->shape_view().ToDimVector(&y_vector); NdIndexOffsetHelper index_helper(y_vector.data()); ReflectionPad2dFunctor()(ctx->stream(), src, dest, index_helper, n_batch, n_channel, y_height, y_width, x_height, x_width, pad_left, pad_top); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ReflectionPad2dGradKernel final : public OpKernel { public: ReflectionPad2dGradKernel() = default; ~ReflectionPad2dGradKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t h_idx = 2; const int64_t w_idx = 3; int64_t pad_left = padding[0]; int64_t pad_top = padding[2]; int64_t n_batch = dy->shape_view().At(n_idx); int64_t n_channel = dy->shape_view().At(c_idx); int64_t dy_height = dy->shape_view().At(h_idx); int64_t dy_width = dy->shape_view().At(w_idx); int64_t dx_height = dx->shape_view().At(h_idx); int64_t dx_width = dx->shape_view().At(w_idx); const IN_T* src = dy->dptr(); IN_T* dest = dx->mut_dptr(); DimVector dy_vector; dy->shape_view().ToDimVector(&dy_vector); NdIndexOffsetHelper index_helper(dy_vector.data()); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); ReflectionPad2dGradFunctor()(ctx->stream(), src, dest, index_helper, n_batch, n_channel, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_REFLECTION_PAD_ND_KERNELS(device, dtype) \ REGISTER_USER_KERNEL("reflection_pad1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("reflection_pad1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("reflection_pad2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("reflection_pad2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); #define REGISTER_REFLECTION_PAD_ND_WITH_DEVICE(device) \ REGISTER_REFLECTION_PAD_ND_KERNELS(device, float) \ REGISTER_REFLECTION_PAD_ND_KERNELS(device, double) \ REGISTER_REFLECTION_PAD_ND_KERNELS(device, int32_t) REGISTER_REFLECTION_PAD_ND_WITH_DEVICE(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_REFLECTION_PAD_ND_WITH_DEVICE(DeviceType::kCUDA) REGISTER_REFLECTION_PAD_ND_KERNELS(DeviceType::kCUDA, float16) #endif } // namespace user_op } // namespace oneflow ================================================ FILE: 
oneflow/user/kernels/reflection_pad_kernels_util.cpp
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/reflection_pad_kernels_util.h"
#include "oneflow/core/framework/framework.h"

namespace oneflow {
namespace user_op {

template<typename IN_T>
struct ReflectionPad1dFunctor<DeviceType::kCPU, IN_T> final {
  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,
                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,
                  const int64_t n_channel, const int64_t y_width, const int64_t x_width,
                  const int64_t pad_left) {
    const int64_t dest_num = n_channel * y_width;
    const int64_t src_num = n_channel * x_width;
    const int64_t elem_num = n_batch * dest_num;
    DoReflectionPad1d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width,
                            pad_left);
  }
};

template<typename IN_T>
struct ReflectionPad1dGradFunctor<DeviceType::kCPU, IN_T> final {
  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,
                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,
                  const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,
                  const int64_t pad_left) {
    const int64_t dest_num = n_channel * dx_width;
    const int64_t src_num = n_channel * dy_width;
    const int64_t elem_num = n_batch * src_num;
    DoReflectionPad1dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_width,
                                dx_width, pad_left);
  }
};

template<typename IN_T>
struct ReflectionPad2dFunctor<DeviceType::kCPU, IN_T> final {
  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,
                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,
                  const int64_t n_channel, const int64_t y_height, const int64_t y_width,
                  const int64_t x_height, const int64_t x_width, const int64_t pad_left,
                  const int64_t pad_top) {
    const int64_t dest_num = n_channel * y_height * y_width;
    const int64_t src_num = n_channel * x_height * x_width;
    const int64_t elem_num = n_batch * dest_num;
    DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height,
                            y_width, x_height, x_width, pad_left, pad_top);
  }
};

template<typename IN_T>
struct ReflectionPad2dGradFunctor<DeviceType::kCPU, IN_T> final {
  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,
                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,
                  const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,
                  const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,
                  const int64_t pad_top) {
    const int64_t dest_num = n_channel * dx_height * dx_width;
    const int64_t src_num = n_channel * dy_height * dy_width;
    const int64_t elem_num = n_batch * src_num;
    DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,
                                dy_width, dx_height, dx_width, pad_left, pad_top);
  }
};

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD_FUNCTOR, (DeviceType::kCPU),
                                 PADDING_DATA_TYPE_CPU_SEQ);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD_GRAD_FUNCTOR, (DeviceType::kCPU),
                                 PADDING_DATA_TYPE_CPU_SEQ);

}  // namespace user_op
}  // namespace oneflow

================================================ FILE:
oneflow/user/kernels/reflection_pad_kernels_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/reflection_pad_kernels_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { template __global__ void DoCUDAReflectionPad1d(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { DoReflectionPad1d(src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left); }; template __global__ void DoCUDAReflectionPad1dGrad(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { DoReflectionPad1dGrad(src, dest, index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left); }; template __global__ void DoCUDAReflectionPad2d(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { DoReflectionPad2d(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); }; template __global__ void DoCUDAReflectionPad2dGrad(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { DoReflectionPad2dGrad(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); }; template struct ReflectionPad1dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { const int64_t dest_num = n_channel * y_width; const int64_t src_num = n_channel * x_width; const int64_t elem_num = n_batch * dest_num; DoCUDAReflectionPad1d<<As()->cuda_stream()>>>( src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left); } }; // float16 implementation template<> void ReflectionPad1dFunctor::operator()( ep::Stream* stream, const float16* src, float16* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { const int64_t dest_num = n_channel * y_width; const int64_t src_num = n_channel * x_width; const 
int64_t elem_num = n_batch * dest_num; DoCUDAReflectionPad1d<<As()->cuda_stream()>>>( reinterpret_cast(src), reinterpret_cast(dest), index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left); } template struct ReflectionPad1dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { const int64_t dest_num = n_channel * dx_width; const int64_t src_num = n_channel * dy_width; const int64_t elem_num = n_batch * src_num; DoCUDAReflectionPad1dGrad<<As()->cuda_stream()>>>( src, dest, index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left); } }; // float16 implementation template<> void ReflectionPad1dGradFunctor::operator()( ep::Stream* stream, const float16* src, float16* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { const int64_t dest_num = n_channel * dx_width; const int64_t src_num = n_channel * dy_width; const int64_t elem_num = n_batch * src_num; DoCUDAReflectionPad1dGrad<<As()->cuda_stream()>>>( reinterpret_cast(src), reinterpret_cast(dest), index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left); } template struct ReflectionPad2dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * y_height * y_width; const int64_t src_num = n_channel * x_height * x_width; const int64_t elem_num = n_batch * dest_num; DoCUDAReflectionPad2d<<As()->cuda_stream()>>>( src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); } }; // float16 implementation template<> void ReflectionPad2dFunctor::operator()( ep::Stream* stream, const float16* src, float16* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * y_height * y_width; const int64_t src_num = n_channel * x_height * x_width; const int64_t elem_num = n_batch * dest_num; DoCUDAReflectionPad2d<<As()->cuda_stream()>>>( reinterpret_cast(src), reinterpret_cast(dest), index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); } template struct ReflectionPad2dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * dx_height * dx_width; const int64_t src_num = n_channel * dy_height * dy_width; const int64_t elem_num = n_batch * src_num; DoCUDAReflectionPad2dGrad<<As()->cuda_stream()>>>( src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); } }; // float16 implementation template<> void ReflectionPad2dGradFunctor::operator()( ep::Stream* stream, const 
float16* src, float16* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * dx_height * dx_width; const int64_t src_num = n_channel * dy_height * dy_width; const int64_t elem_num = n_batch * src_num; DoCUDAReflectionPad2dGrad<<As()->cuda_stream()>>>( reinterpret_cast(src), reinterpret_cast(dest), index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); } OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD_FUNCTOR, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), PADDING_DATA_TYPE_CUDA_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD_GRAD_FUNCTOR, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), PADDING_DATA_TYPE_CUDA_SEQ); } // namespace user_op } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/reflection_pad_kernels_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_REFLECTION_PAD_KERNELS_UTIL_H_ #define ONEFLOW_USER_KERNELS_REFLECTION_PAD_KERNELS_UTIL_H_ #ifdef WITH_CUDA #include "oneflow/core/cuda/atomic.cuh" #endif // WITH_CUDA #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/ndarray/xpu_util.h" namespace oneflow { #define PADDING_DATA_TYPE_CPU_SEQ \ FLOATING_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define PADDING_DATA_TYPE_CUDA_SEQ \ FLOAT16_DATA_TYPE_SEQ \ PADDING_DATA_TYPE_CPU_SEQ namespace user_op { template struct DeviceAdd { OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { #if defined(__CUDA_ARCH__) cuda::atomic::Add(y, *x); #else *y += *x; #endif }; }; template struct ReflectionPad1dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left); }; template struct ReflectionPad1dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left); }; template struct ReflectionPad2dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top); }; template struct ReflectionPad2dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_height, const 
int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top); }; template OF_DEVICE_FUNC void DoReflectionPad1d(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { XPU_1D_KERNEL_LOOP(k, elem_num) { int64_t n, c, j, ip_x; int64_t coord_y[3]; index_helper.OffsetToNdIndex(k, coord_y); n = coord_y[0]; c = coord_y[1]; j = coord_y[2]; if (j < pad_left) { ip_x = pad_left * 2 - j; } else if (j >= pad_left && j < x_width + pad_left) { ip_x = j; } else { ip_x = (x_width + pad_left - 1) * 2 - j; } ip_x = ip_x - pad_left; int64_t dest_index = n * dest_num + c * y_width + j; int64_t src_index = n * src_num + c * x_width + ip_x; dest[dest_index] = src[src_index]; } } template OF_DEVICE_FUNC void DoReflectionPad1dGrad(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { XPU_1D_KERNEL_LOOP(k, elem_num) { int64_t n, c, j, ip_x; int64_t coord[3]; index_helper.OffsetToNdIndex(k, coord); n = coord[0]; c = coord[1]; j = coord[2]; if (j < pad_left) { ip_x = pad_left * 2 - j; } else if (j >= pad_left && j < dx_width + pad_left) { ip_x = j; } else { ip_x = (dx_width + pad_left - 1) * 2 - j; } ip_x = ip_x - pad_left; int64_t src_index = n * src_num + c * dy_width + j; int64_t dest_index = n * dest_num + c * dx_width + ip_x; DeviceAdd::Invoke(src + src_index, dest + dest_index); } } template OF_DEVICE_FUNC void DoReflectionPad2d(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { XPU_1D_KERNEL_LOOP(k, elem_num) { int64_t n, c, i, j, ip_x, ip_y; int64_t coord_y[4]; index_helper.OffsetToNdIndex(k, coord_y); n = coord_y[0]; c = coord_y[1]; i = coord_y[2]; j = coord_y[3]; if (j < pad_left) { ip_x = pad_left * 2 - j; } else if (j >= pad_left && j < x_width + pad_left) { ip_x = j; } else { ip_x = (x_width + pad_left - 1) * 2 - j; } if (i < pad_top) { ip_y = pad_top * 2 - i; } else if (i >= pad_top && i < x_height + pad_top) { ip_y = i; } else { ip_y = (x_height + pad_top - 1) * 2 - i; } ip_x = ip_x - pad_left; ip_y = ip_y - pad_top; int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j; int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x; dest[dest_index] = src[src_index]; } } template OF_DEVICE_FUNC void DoReflectionPad2dGrad(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { XPU_1D_KERNEL_LOOP(k, elem_num) { int64_t n, c, i, j, ip_x, ip_y; int64_t coord[4]; index_helper.OffsetToNdIndex(k, coord); n = coord[0]; c = coord[1]; i = coord[2]; j = coord[3]; if (j < pad_left) { ip_x = pad_left * 2 - j; } else if (j >= pad_left && j < dx_width + pad_left) { ip_x = j; } else { ip_x = (dx_width + pad_left - 1) * 2 - j; } if (i < pad_top) { ip_y = pad_top * 2 - i; } else if (i >= pad_top && i < dx_height + pad_top) { ip_y = i; } else { ip_y = 
(dx_height + pad_top - 1) * 2 - i; } ip_x = ip_x - pad_left; ip_y = ip_y - pad_top; int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j; int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x; DeviceAdd::Invoke(src + src_index, dest + dest_index); } } // macros for functors instantiate #define INSTANTIATE_REFLECTION_PAD_FUNCTOR(device_type_v, dtype_pair) \ template struct ReflectionPad1dFunctor; \ template struct ReflectionPad2dFunctor; #define INSTANTIATE_REFLECTION_PAD_GRAD_FUNCTOR(device_type_v, dtype_pair) \ template struct ReflectionPad1dGradFunctor; \ template struct ReflectionPad2dGradFunctor; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_REFLECTION_PAD_KERNELS_UTIL_H_ ================================================ FILE: oneflow/user/kernels/repeat_interleave_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/roll_kernel_utils.h" #include namespace oneflow { template class CpuRepeatInterLeaveKernel final : public user_op::OpKernel { public: CpuRepeatInterLeaveKernel() = default; ~CpuRepeatInterLeaveKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* cumsum = ctx->Tensor4ArgNameAndIndex("cumsum", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const T* in_ptr = in->dptr(); const T* cumsum_ptr = cumsum->dptr(); T* out_ptr = out->mut_dptr(); for (T i = 0; i < in->shape_view().At(0); i++) { T end = cumsum_ptr[i]; T size = in_ptr[i]; T start = end - size; for (T j = start; j < end; j++) { out_ptr[j] = i; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_REPEAT_INTER_LEAVE_KERNEL(dtype) \ REGISTER_USER_KERNEL("repeat_interleave") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_REPEAT_INTER_LEAVE_KERNEL(int32_t); REGISTER_REPEAT_INTER_LEAVE_KERNEL(int64_t); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/repeat_interleave_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/roll_kernel_utils.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { template __global__ void repeat_interleave(const T* in_ptr, const T* cumsum_ptr, T* out_ptr, const int64_t num) { CUDA_1D_KERNEL_LOOP(i, num) { T end = cumsum_ptr[i]; T size = in_ptr[i]; T start = end - size; for (T j = start; j < end; j++) { out_ptr[j] = i; } } } } // namespace template class GpuRepeatInterLeaveKernel final : public user_op::OpKernel { public: GpuRepeatInterLeaveKernel() = default; ~GpuRepeatInterLeaveKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* cumsum = ctx->Tensor4ArgNameAndIndex("cumsum", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t& repeat_num = ctx->Attr("repeat_num"); const T* in_ptr = in->dptr(); const T* cumsum_ptr = cumsum->dptr(); T* out_ptr = out->mut_dptr(); repeat_interleave<<shape_view().At(0)), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>( in_ptr, cumsum_ptr, out_ptr, in->shape_view().At(0)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_REPEAT_INTER_LEAVE_KERNEL(dtype) \ REGISTER_USER_KERNEL("repeat_interleave") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_REPEAT_INTER_LEAVE_KERNEL(int32_t); REGISTER_REPEAT_INTER_LEAVE_KERNEL(int64_t); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/replication_pad_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/replication_pad_kernels_util.h" namespace oneflow { namespace user_op { template class ReplicationPad1dKernel final : public OpKernel { public: ReplicationPad1dKernel() = default; ~ReplicationPad1dKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t w_idx = 2; const int64_t pad_left = padding[0]; const int64_t n_batch = y->shape_view().At(n_idx); const int64_t n_channel = y->shape_view().At(c_idx); const int64_t y_width = y->shape_view().At(w_idx); const int64_t x_width = x->shape_view().At(w_idx); IN_T* dest = y->mut_dptr(); const IN_T* src = x->dptr(); DimVector y_vector; y->shape_view().ToDimVector(&y_vector); NdIndexOffsetHelper index_helper(y_vector.data()); ReplicationPad1dFunctor()(ctx->stream(), src, dest, index_helper, n_batch, n_channel, y_width, x_width, pad_left); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ReplicationPad1dGradKernel final : public OpKernel { public: ReplicationPad1dGradKernel() = default; ~ReplicationPad1dGradKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t w_idx = 2; const int64_t pad_left = padding[0]; const int64_t n_batch = dy->shape_view().At(n_idx); const int64_t n_channel = dy->shape_view().At(c_idx); const int64_t dy_width = dy->shape_view().At(w_idx); const int64_t dx_width = dx->shape_view().At(w_idx); const IN_T* src = dy->dptr(); IN_T* dest = dx->mut_dptr(); DimVector dy_vector; dy->shape_view().ToDimVector(&dy_vector); NdIndexOffsetHelper index_helper(dy_vector.data()); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); ReplicationPad1dGradFunctor()( ctx->stream(), src, dest, index_helper, n_batch, n_channel, dy_width, dx_width, pad_left); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ReplicationPad2dKernel final : public OpKernel { public: ReplicationPad2dKernel() = default; ~ReplicationPad2dKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t h_idx = 2; const int64_t w_idx = 3; const int64_t pad_left = padding[0]; const int64_t pad_top = padding[2]; const int64_t n_batch = y->shape_view().At(n_idx); const int64_t n_channel = y->shape_view().At(c_idx); const int64_t y_height = y->shape_view().At(h_idx); const int64_t y_width = y->shape_view().At(w_idx); const int64_t x_height = x->shape_view().At(h_idx); const int64_t x_width = x->shape_view().At(w_idx); IN_T* dest = y->mut_dptr(); const IN_T* src = x->dptr(); DimVector y_vector; y->shape_view().ToDimVector(&y_vector); NdIndexOffsetHelper index_helper(y_vector.data()); 
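    // The functor called next uses the index helper (built from the padded output shape) to map
    // each flat output offset back to an (n, c, h, w) coordinate; coordinates that fall in the
    // padded border are clamped to the nearest valid input row/column, which is what makes this
    // replication ("edge") padding rather than reflection padding.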
ReplicationPad2dFunctor()(ctx->stream(), src, dest, index_helper, n_batch, n_channel, y_height, y_width, x_height, x_width, pad_left, pad_top); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class ReplicationPad2dGradKernel final : public OpKernel { public: ReplicationPad2dGradKernel() = default; ~ReplicationPad2dGradKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t h_idx = 2; const int64_t w_idx = 3; const int64_t pad_left = padding[0]; const int64_t pad_top = padding[2]; const int64_t n_batch = dy->shape_view().At(n_idx); const int64_t n_channel = dy->shape_view().At(c_idx); const int64_t dy_height = dy->shape_view().At(h_idx); const int64_t dy_width = dy->shape_view().At(w_idx); const int64_t dx_height = dx->shape_view().At(h_idx); const int64_t dx_width = dx->shape_view().At(w_idx); const IN_T* src = dy->dptr(); IN_T* dest = dx->mut_dptr(); DimVector dy_vector; dy->shape_view().ToDimVector(&dy_vector); NdIndexOffsetHelper index_helper(dy_vector.data()); size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type()); Memset(ctx->stream(), dest, 0, out_bytes_size); ReplicationPad2dGradFunctor()(ctx->stream(), src, dest, index_helper, n_batch, n_channel, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_REPLICATION_PAD_ND_KERNELS(device, dtype) \ REGISTER_USER_KERNEL("replication_pad1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("replication_pad1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("replication_pad2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("replication_pad2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); #define REGISTER_REPLICATION_PAD_ND_WITH_DEVICE(device) \ REGISTER_REPLICATION_PAD_ND_KERNELS(device, float) \ REGISTER_REPLICATION_PAD_ND_KERNELS(device, double) \ REGISTER_REPLICATION_PAD_ND_KERNELS(device, int32_t) REGISTER_REPLICATION_PAD_ND_WITH_DEVICE(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_REPLICATION_PAD_ND_WITH_DEVICE(DeviceType::kCUDA) REGISTER_REPLICATION_PAD_ND_KERNELS(DeviceType::kCUDA, float16) #endif } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/replication_pad_kernels_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/replication_pad_kernels_util.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace user_op { template struct ReplicationPad1dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { const int64_t dest_num = n_channel * y_width; const int64_t src_num = n_channel * x_width; const int64_t elem_num = n_batch * dest_num; DoReplicationPad1d(src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left); } }; template struct ReplicationPad1dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { const int64_t dest_num = n_channel * dx_width; const int64_t src_num = n_channel * dy_width; const int64_t elem_num = n_batch * src_num; DoReplicationPad1dGrad(src, dest, index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left); } }; template struct ReplicationPad2dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * y_height * y_width; const int64_t src_num = n_channel * x_height * x_width; const int64_t elem_num = n_batch * dest_num; DoReplicationPad2d(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); } }; template struct ReplicationPad2dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * dx_height * dx_width; const int64_t src_num = n_channel * dy_height * dy_width; const int64_t elem_num = n_batch * src_num; DoReplicationPad2dGrad(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD_FUNCTOR, (DeviceType::kCPU), PADDING_DATA_TYPE_CPU_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD_GRAD_FUNCTOR, (DeviceType::kCPU), PADDING_DATA_TYPE_CPU_SEQ); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/replication_pad_kernels_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include #ifdef WITH_CUDA #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/replication_pad_kernels_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { template __global__ void DoCUDAReplicationPad1d(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { DoReplicationPad1d(src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left); }; template __global__ void DoCUDAReplicationPad1dGrad(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { DoReplicationPad1dGrad(src, dest, index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left); }; template __global__ void DoCUDAReplicationPad2d(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { DoReplicationPad2d(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); }; template __global__ void DoCUDAReplicationPad2dGrad(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { DoReplicationPad2dGrad(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); }; template struct ReplicationPad1dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { const int64_t dest_num = n_channel * y_width; const int64_t src_num = n_channel * x_width; const int64_t elem_num = n_batch * dest_num; DoCUDAReplicationPad1d<<As()->cuda_stream()>>>( src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left); } }; // float16 implementation template<> void ReplicationPad1dFunctor::operator()( ep::Stream* stream, const float16* src, float16* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { const int64_t dest_num = n_channel * y_width; const int64_t src_num = n_channel * x_width; const int64_t elem_num = n_batch * dest_num; DoCUDAReplicationPad1d<<As()->cuda_stream()>>>( reinterpret_cast(src), reinterpret_cast(dest), index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left); } template struct ReplicationPad1dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { const int64_t dest_num = n_channel * dx_width; const int64_t src_num = 
n_channel * dy_width; const int64_t elem_num = n_batch * src_num; DoCUDAReplicationPad1dGrad<<As()->cuda_stream()>>>( src, dest, index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left); } }; // float16 implementation template<> void ReplicationPad1dGradFunctor::operator()( ep::Stream* stream, const float16* src, float16* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { const int64_t dest_num = n_channel * dx_width; const int64_t src_num = n_channel * dy_width; const int64_t elem_num = n_batch * src_num; DoCUDAReplicationPad1dGrad<<As()->cuda_stream()>>>( reinterpret_cast(src), reinterpret_cast(dest), index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left); } template struct ReplicationPad2dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * y_height * y_width; const int64_t src_num = n_channel * x_height * x_width; const int64_t elem_num = n_batch * dest_num; DoCUDAReplicationPad2d<<As()->cuda_stream()>>>( src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); } }; // float16 implementation template<> void ReplicationPad2dFunctor::operator()( ep::Stream* stream, const float16* src, float16* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * y_height * y_width; const int64_t src_num = n_channel * x_height * x_width; const int64_t elem_num = n_batch * dest_num; DoCUDAReplicationPad2d<<As()->cuda_stream()>>>( reinterpret_cast(src), reinterpret_cast(dest), index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); } template struct ReplicationPad2dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * dx_height * dx_width; const int64_t src_num = n_channel * dy_height * dy_width; const int64_t elem_num = n_batch * src_num; DoCUDAReplicationPad2dGrad<<As()->cuda_stream()>>>( src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); } }; // float16 implementation template<> void ReplicationPad2dGradFunctor::operator()( ep::Stream* stream, const float16* src, float16* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { const int64_t dest_num = n_channel * dx_height * dx_width; const int64_t src_num = n_channel * dy_height * dy_width; const int64_t elem_num = n_batch * src_num; DoCUDAReplicationPad2dGrad<<As()->cuda_stream()>>>( reinterpret_cast(src), reinterpret_cast(dest), index_helper, 
elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); } OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD_FUNCTOR, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), PADDING_DATA_TYPE_CUDA_SEQ); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD_GRAD_FUNCTOR, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), PADDING_DATA_TYPE_CUDA_SEQ); } // namespace user_op } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/replication_pad_kernels_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_REPLICATION_PAD_KERNELS_UTIL_H_ #define ONEFLOW_USER_KERNELS_REPLICATION_PAD_KERNELS_UTIL_H_ #ifdef WITH_CUDA #include "oneflow/core/cuda/atomic.cuh" #endif // WITH_CUDA #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/ndarray/xpu_util.h" namespace oneflow { #define PADDING_DATA_TYPE_CPU_SEQ \ FLOATING_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define PADDING_DATA_TYPE_CUDA_SEQ \ FLOAT16_DATA_TYPE_SEQ \ PADDING_DATA_TYPE_CPU_SEQ namespace user_op { template struct DeviceAdd { OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { #if defined(__CUDA_ARCH__) cuda::atomic::Add(y, *x); #else *y += *x; #endif }; }; template struct ReplicationPad1dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left); }; template struct ReplicationPad1dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left); }; template struct ReplicationPad2dFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top); }; template struct ReplicationPad2dGradFunctor final { void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t n_batch, const int64_t n_channel, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top); }; template OF_DEVICE_FUNC void DoReplicationPad1d(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t y_width, const int64_t x_width, const int64_t pad_left) { XPU_1D_KERNEL_LOOP(k, elem_num) { int64_t n, c, j, ip_x; int64_t coord_y[3]; index_helper.OffsetToNdIndex(k, coord_y); n = coord_y[0]; c = coord_y[1]; j = 
coord_y[2]; if (j < pad_left) { ip_x = pad_left; } else if (j >= pad_left && j < x_width + pad_left) { ip_x = j; } else { ip_x = x_width + pad_left - 1; } ip_x = ip_x - pad_left; int64_t dest_index = n * dest_num + c * y_width + j; int64_t src_index = n * src_num + c * x_width + ip_x; dest[dest_index] = src[src_index]; } } template OF_DEVICE_FUNC void DoReplicationPad1dGrad(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t dy_width, const int64_t dx_width, const int64_t pad_left) { XPU_1D_KERNEL_LOOP(k, elem_num) { int64_t n, c, j, ip_x; int64_t coord[3]; index_helper.OffsetToNdIndex(k, coord); n = coord[0]; c = coord[1]; j = coord[2]; if (j < pad_left) { ip_x = pad_left; } else if (j >= pad_left && j < dx_width + pad_left) { ip_x = j; } else { ip_x = dx_width + pad_left - 1; } ip_x = ip_x - pad_left; int64_t src_index = n * src_num + c * dy_width + j; int64_t dest_index = n * dest_num + c * dx_width + ip_x; DeviceAdd::Invoke(src + src_index, dest + dest_index); } } template OF_DEVICE_FUNC void DoReplicationPad2d(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t y_height, const int64_t y_width, const int64_t x_height, const int64_t x_width, const int64_t pad_left, const int64_t pad_top) { XPU_1D_KERNEL_LOOP(k, elem_num) { int64_t n, c, i, j, ip_x, ip_y; int64_t coord_y[4]; index_helper.OffsetToNdIndex(k, coord_y); n = coord_y[0]; c = coord_y[1]; i = coord_y[2]; j = coord_y[3]; if (j < pad_left) { ip_x = pad_left; } else if (j >= pad_left && j < x_width + pad_left) { ip_x = j; } else { ip_x = x_width + pad_left - 1; } if (i < pad_top) { ip_y = pad_top; } else if (i >= pad_top && i < x_height + pad_top) { ip_y = i; } else { ip_y = x_height + pad_top - 1; } ip_x = ip_x - pad_left; ip_y = ip_y - pad_top; int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j; int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x; dest[dest_index] = src[src_index]; } } template OF_DEVICE_FUNC void DoReplicationPad2dGrad(const IN_T* src, IN_T* dest, const NdIndexOffsetHelper& index_helper, const int64_t elem_num, const int64_t src_num, const int64_t dest_num, const int64_t dy_height, const int64_t dy_width, const int64_t dx_height, const int64_t dx_width, const int64_t pad_left, const int64_t pad_top) { XPU_1D_KERNEL_LOOP(k, elem_num) { int64_t n, c, i, j, ip_x, ip_y; int64_t coord[4]; index_helper.OffsetToNdIndex(k, coord); n = coord[0]; c = coord[1]; i = coord[2]; j = coord[3]; if (j < pad_left) { ip_x = pad_left; } else if (j >= pad_left && j < dx_width + pad_left) { ip_x = j; } else { ip_x = dx_width + pad_left - 1; } if (i < pad_top) { ip_y = pad_top; } else if (i >= pad_top && i < dx_height + pad_top) { ip_y = i; } else { ip_y = dx_height + pad_top - 1; } ip_x = ip_x - pad_left; ip_y = ip_y - pad_top; int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j; int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x; DeviceAdd::Invoke(src + src_index, dest + dest_index); } } // macros for functors instantiate(used by pad2d_kernels_util.cu) #define INSTANTIATE_REPLICATION_PAD_FUNCTOR(device_type_v, dtype_pair) \ template struct ReplicationPad1dFunctor; \ template struct ReplicationPad2dFunctor; #define INSTANTIATE_REPLICATION_PAD_GRAD_FUNCTOR(device_type_v, dtype_pair) \ template struct 
ReplicationPad1dGradFunctor; \ template struct ReplicationPad2dGradFunctor; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_REPLICATION_PAD_KERNELS_UTIL_H_ ================================================ FILE: oneflow/user/kernels/rms_norm_gpu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/cuda/rms_norm.cuh" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 namespace oneflow { namespace cuda { namespace rms_norm { template struct AffineStore { AffineStore(DST* dst, const DST* weight, int32_t row_size) : dst(dst), weight(weight), row_size(row_size) {} template __device__ void store(const SRC* src, int32_t row, int32_t col) { layer_norm::Pack dst_pack; layer_norm::Pack weight_pack; const int32_t offset = (row * row_size + col) / N; const int32_t weight_offset = col / N; if (affine) { weight_pack.storage = *(reinterpret_cast*>(weight) + weight_offset); } #pragma unroll for (int i = 0; i < N; ++i) { if (affine) { dst_pack.elem[i] = static_cast(src[i]) * weight_pack.elem[i]; } else { dst_pack.elem[i] = static_cast(src[i]); } } *(reinterpret_cast*>(dst) + offset) = dst_pack.storage; } DST* dst; const DST* weight; int32_t row_size; }; template struct AffineLoad { AffineLoad(const SRC* src, const SRC* weight, int32_t row_size) : src(src), weight(weight), row_size(row_size) {} template __device__ void load(DST* dst, int32_t row, int32_t col) const { layer_norm::Pack src_pack; layer_norm::Pack weight_pack; const int32_t offset = (row * row_size + col) / N; src_pack.storage = *(reinterpret_cast*>(src) + offset); if (affine) { const int32_t weight_offset = col / N; weight_pack.storage = *(reinterpret_cast*>(weight) + weight_offset); } #pragma unroll for (int i = 0; i < N; ++i) { if (affine) { dst[i] = static_cast(src_pack.elem[i] * weight_pack.elem[i]); } else { dst[i] = static_cast(src_pack.elem[i]); } } } const SRC* src; const SRC* weight; int32_t row_size; }; template void DispatchRmsNormForwardAffine(ep::Stream* stream, const int64_t nrow, const int64_t ncol, const double eps, const T* x_dptr, const T* w_dptr, T* y_dptr, ComputeType* inv_rms) { layer_norm::DirectLoad load(x_dptr, ncol); AffineStore store(y_dptr, w_dptr, ncol); OF_CUDA_CHECK((LaunchRmsNorm( stream->As()->cuda_stream(), load, store, nrow, ncol, eps, inv_rms))); } template void RmsNormForward(ep::Stream* stream, const int64_t nrow, const int64_t ncol, const double eps, const T* x_dptr, const T* w_dptr, T* y_dptr, ComputeType* inv_rms) { if (w_dptr) { DispatchRmsNormForwardAffine(stream, nrow, ncol, eps, x_dptr, w_dptr, y_dptr, inv_rms); } else { DispatchRmsNormForwardAffine(stream, nrow, ncol, eps, x_dptr, w_dptr, y_dptr, inv_rms); } } 
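// Scalar reference (illustrative) for the forward dispatch above: per row of length ncol the
// kernel is expected to produce
//   inv_rms ~= 1 / sqrt(mean(x_i * x_i) + eps)
//   y_i      = x_i * inv_rms * w_i   (w_i treated as 1 when the optional "weight" input is absent)
// The fused CUDA kernels doing this are launched via LaunchRmsNorm from
// oneflow/core/cuda/rms_norm.cuh. A single-row CPU sketch of the same math, kept as a comment for
// clarity; RmsNormRowReference is a hypothetical name used only here and assumes <cmath>:
//
//   template<typename T, typename ComputeType>
//   void RmsNormRowReference(const T* x, const T* w, T* y, ComputeType* inv_rms, int64_t ncol,
//                            double eps) {
//     ComputeType sum_sq = 0;
//     for (int64_t i = 0; i < ncol; ++i) {
//       sum_sq += static_cast<ComputeType>(x[i]) * static_cast<ComputeType>(x[i]);
//     }
//     *inv_rms = static_cast<ComputeType>(1.0 / std::sqrt(sum_sq / ncol + eps));
//     for (int64_t i = 0; i < ncol; ++i) {
//       const ComputeType scale = w ? (*inv_rms) * static_cast<ComputeType>(w[i]) : (*inv_rms);
//       y[i] = static_cast<T>(static_cast<ComputeType>(x[i]) * scale);
//     }
//   }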
template void DispatchRmsNormBackwardAffine(ep::Stream* stream, const int64_t nrow, const int64_t ncol, const T* dy_dptr, const T* x_dptr, const T* weight_dptr, const ComputeType* inv_rms, T* dx_ptr) { layer_norm::DirectLoad load_x(x_dptr, ncol); AffineLoad load_dy(dy_dptr, weight_dptr, ncol); layer_norm::DirectStore store(dx_ptr, ncol); OF_CUDA_CHECK((rms_norm::LaunchRmsNormGrad(stream->As()->cuda_stream(), nrow, ncol, load_x, load_dy, store, inv_rms))); } template void RmsNormBackward(ep::Stream* stream, const int64_t nrow, const int64_t ncol, const T* dy_dptr, const T* x_dptr, const T* weight_dptr, const ComputeType* inv_rms, T* dx_dptr) { if (weight_dptr) { DispatchRmsNormBackwardAffine(stream, nrow, ncol, dy_dptr, x_dptr, weight_dptr, inv_rms, dx_dptr); } else { DispatchRmsNormBackwardAffine(stream, nrow, ncol, dy_dptr, x_dptr, weight_dptr, inv_rms, dx_dptr); } } } // namespace rms_norm template class RmsNormKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: RmsNormKernel() = default; ~RmsNormKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* inv_rms = ctx->Tensor4ArgNameAndIndex("inv_rms", 0); const double eps = ctx->Attr("epsilon"); const Shape& normalized_shape = ctx->Attr("normalized_shape"); const int64_t ncol = normalized_shape.elem_cnt(); const int64_t nrow = inv_rms->shape_view().elem_cnt(); const T* weight_dptr = nullptr; if (ctx->has_input("weight", 0)) { const auto* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); CHECK_EQ(weight->shape_view().elem_cnt(), ncol); weight_dptr = weight->dptr(); } CHECK_EQ(x->shape_view().elem_cnt(), ncol * nrow); CHECK_LT(nrow * ncol, std::numeric_limits::max()) << "The size of tensor exceeds int32 max limit. 
The kernel don't support large tensor."; using ComputeType = typename layer_norm::DefaultComputeType::type; rms_norm::RmsNormForward(ctx->stream(), nrow, ncol, eps, x->dptr(), weight_dptr, y->mut_dptr(), inv_rms->mut_dptr()); }; }; #define REGISTER_RMS_NORM_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("rms_norm") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_RMS_NORM_CUDA_KERNEL(float) REGISTER_RMS_NORM_CUDA_KERNEL(double) REGISTER_RMS_NORM_CUDA_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_RMS_NORM_CUDA_KERNEL(nv_bfloat16) #endif template class RmsNormGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: RmsNormGradKernel() = default; ~RmsNormGradKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* inv_rms = ctx->Tensor4ArgNameAndIndex("inv_rms", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t nrow = inv_rms->shape_view().elem_cnt(); const int64_t ncol = x->shape_view().elem_cnt() / nrow; const T* weight_dptr = nullptr; if (ctx->has_input("weight", 0)) { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); CHECK_EQ(ncol, weight->shape_view().elem_cnt()); weight_dptr = weight->dptr(); } CHECK_LT(nrow * ncol, std::numeric_limits::max()) << "The size of tensor exceeds int32 max limit. The kernel don't support large tensor."; using ComputeType = typename layer_norm::DefaultComputeType::type; rms_norm::RmsNormBackward(ctx->stream(), nrow, ncol, dy->dptr(), x->dptr(), weight_dptr, inv_rms->dptr(), dx->mut_dptr()); }; }; #define REGISTER_RMS_NORM_GRAD_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("rms_norm_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)); REGISTER_RMS_NORM_GRAD_CUDA_KERNEL(float) REGISTER_RMS_NORM_GRAD_CUDA_KERNEL(double) REGISTER_RMS_NORM_GRAD_CUDA_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_RMS_NORM_GRAD_CUDA_KERNEL(nv_bfloat16) #endif namespace { constexpr int kNProcPerThread = 4; } // namespace template class RmsNormParamGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: RmsNormParamGradKernel() = default; ~RmsNormParamGradKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* inv_rms = ctx->Tensor4ArgNameAndIndex("inv_rms", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* weight_grad = ctx->Tensor4ArgNameAndIndex("weight_grad", 0); const int64_t nrow = inv_rms->shape_view().elem_cnt(); const int64_t ncol = weight_grad->shape_view().elem_cnt(); CHECK_LT(nrow * ncol, std::numeric_limits::max()) << "The size of tensor exceeds int32 max limit. 
The kernel don't support large tensor."; // step 1: dx = dy * y and reduce partial rows in a block const int block_dim_x = rms_norm::kWarpSize; const int block_dim_y = rms_norm::kWarpSize / kNProcPerThread; int grid_dim_x; int grid_dim_y; OF_CUDA_CHECK((rms_norm::GetGrid2Dim(nrow, ncol, block_dim_x, block_dim_y, &grid_dim_x, &grid_dim_y))); // tmp weight shape [grid_dim_y, ncol] (reduce nrow -> grid_dim_y) size_t tmp_weight_grad_size = grid_dim_y * ncol; T* tmp_weight_grad_dptr = reinterpret_cast(tmp_buffer->mut_dptr()); using ComputeType = typename layer_norm::DefaultComputeType::type; dim3 grid_dims(grid_dim_x, grid_dim_y); dim3 block_dims(block_dim_x, block_dim_y); rms_norm::RmsNormParamGrad <<stream()->As()->cuda_stream()>>>( nrow, ncol, dy->dptr(), x->dptr(), inv_rms->dptr(), tmp_weight_grad_dptr); // step 2: reduce rows throught gemm to calculate weight grad // fill ones matrix with shape (grid_dim_y, 1) const int32_t m = ncol; const int32_t n = 1; const int32_t k = grid_dim_y; const DataType data_type = dy->data_type(); auto fill = ep::primitive::NewPrimitive( ctx->stream()->device_type(), data_type); CHECK(fill); T* tmp_ones_dptr = tmp_buffer->mut_dptr() + tmp_weight_grad_size; fill->Launch(ctx->stream(), tmp_ones_dptr, 1.0, k); // tmp weight grad (grid_dim_y, ncol) (T) * tmp ones (grid_dim_y, 1) (N) // -> weight grad (ncol, 1) auto matmul = ep::primitive::NewPrimitive( ctx->stream()->device_type(), data_type, ep::primitive::BlasTransposeType::T, ep::primitive::BlasTransposeType::N); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, /*alpha*/ 1.0, tmp_weight_grad_dptr, tmp_ones_dptr, /*beta*/ 0.0, weight_grad->mut_dptr()); }; }; template size_t InferRmsNormParamGradTempBufferSize(user_op::InferContext* ctx) { const auto& shape = ctx->InputTensorDesc("dy", 0).shape(); const auto& b_shape = ctx->InputTensorDesc("inv_rms", 0).shape(); const int64_t nrow = b_shape.elem_cnt(); const int64_t ncol = shape.elem_cnt() / nrow; const int block_dim_x = rms_norm::kWarpSize; const int block_dim_y = rms_norm::kWarpSize / kNProcPerThread; int grid_dim_x; int grid_dim_y; OF_CUDA_CHECK((rms_norm::GetGrid2Dim(nrow, ncol, block_dim_x, block_dim_y, &grid_dim_x, &grid_dim_y))); return (grid_dim_y * ncol + grid_dim_y) * sizeof(T); } #define REGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("rms_norm_param_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dy", 0) == GetDataType::value)) \ .SetInferTmpSizeFn(InferRmsNormParamGradTempBufferSize); REGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(float) REGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(double) REGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16) #endif } // namespace cuda } // namespace oneflow ================================================ FILE: oneflow/user/kernels/roc_auc_score_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <algorithm>
#include <thread>

#include "oneflow/core/framework/framework.h"

namespace oneflow {

namespace {

template<typename L, typename P>
double RocAucScore(size_t n, const L* label, const P* pred, float* buffer) {
  size_t p_samples_count = 0;
  // Negative samples (label == 0) are stored with their score negated; the sign is used to
  // recover the label after sorting by absolute value (i.e. by score).
  for (size_t i = 0; i < n; ++i) {
    if (label[i] == 0) {
      buffer[i] = -pred[i];
    } else {
      p_samples_count += 1;
      buffer[i] = pred[i];
    }
  }
  const size_t n_samples_count = n - p_samples_count;
  constexpr size_t kParallelSortThreshold = 1024;
  auto comp = [](float a, float b) { return fabs(a) < fabs(b); };
  if (n < kParallelSortThreshold) {
    std::sort(buffer, buffer + n, comp);
  } else {
    // Sort four quarters in parallel, then merge.
    const size_t m2 = n / 2;
    const size_t m1 = m2 / 2;
    const size_t m3 = (m2 + n) / 2;
    std::thread t0([&] { std::sort(buffer, buffer + m1, comp); });
    std::thread t1([&] { std::sort(buffer + m1, buffer + m2, comp); });
    std::thread t2([&] { std::sort(buffer + m2, buffer + m3, comp); });
    std::thread t3([&] { std::sort(buffer + m3, buffer + n, comp); });
    t0.join();
    t1.join();
    t2.join();
    t3.join();
    std::inplace_merge(buffer, buffer + m1, buffer + m2, comp);
    std::inplace_merge(buffer + m2, buffer + m3, buffer + n, comp);
    std::inplace_merge(buffer, buffer + m2, buffer + n, comp);
  }
  // Rank-sum (Mann-Whitney) formulation of AUC; tied scores share their average rank.
  size_t tmp_n = 0;
  double tmp_rank_sum = 0;
  double rank_sum = 0;
  size_t tmp_p_samples_count = 0;
  for (size_t i = 0; i < n; ++i) {
    if (i != 0 && fabs(buffer[i]) != fabs(buffer[i - 1])) {
      rank_sum += tmp_p_samples_count * (tmp_rank_sum / tmp_n);
      tmp_n = 0;
      tmp_rank_sum = 0;
      tmp_p_samples_count = 0;
    }
    if (buffer[i] > 0) { tmp_p_samples_count += 1; }
    tmp_rank_sum += (i + 1);
    tmp_n += 1;
  }
  rank_sum += tmp_p_samples_count * (tmp_rank_sum / tmp_n);
  return (rank_sum - p_samples_count * (p_samples_count + 1) / 2)
         / (p_samples_count * n_samples_count);
}

template<typename L, typename P>
class RocAucScoreKernel final : public user_op::OpKernel {
 public:
  OF_DISALLOW_COPY_AND_MOVE(RocAucScoreKernel);
  RocAucScoreKernel() = default;
  ~RocAucScoreKernel() override = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0);
    const user_op::Tensor* pred = ctx->Tensor4ArgNameAndIndex("pred", 0);
    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    P* out_ptr = out->mut_dptr<P>();
    CHECK_EQ(label->shape_view().elem_cnt(), pred->shape_view().elem_cnt());
    CHECK_EQ(out->shape_view().elem_cnt(), 1);
    out_ptr[0] = RocAucScore(label->shape_view().elem_cnt(), label->dptr<L>(), pred->dptr<P>(),
                             tmp_buffer->mut_dptr<float>());
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_ROC_AUC_SCORE_KERNEL(label_type, label_cpp_type, pred_type, pred_cpp_type) \
  REGISTER_USER_KERNEL("roc_auc_score")                                                     \
      .SetCreateFn<RocAucScoreKernel<label_cpp_type, pred_cpp_type>>()                      \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \
                       && (user_op::HobDataType("label", 0) == label_type)                 \
                       && (user_op::HobDataType("pred", 0) == pred_type))                  \
      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                        \
        const Shape& pred_shape = ctx->InputShape("pred", 0);                              \
        size_t tmp_buffer_size = pred_shape.elem_cnt() * sizeof(float);                    \
        return tmp_buffer_size;                                                            \
      })

REGISTER_ROC_AUC_SCORE_KERNEL(DataType::kDouble, double, DataType::kFloat, float);
REGISTER_ROC_AUC_SCORE_KERNEL(DataType::kFloat, float, DataType::kFloat, float);
REGISTER_ROC_AUC_SCORE_KERNEL(DataType::kInt32, int, DataType::kFloat, float);
REGISTER_ROC_AUC_SCORE_KERNEL(DataType::kInt64, int64_t, DataType::kFloat, float);
REGISTER_ROC_AUC_SCORE_KERNEL(DataType::kInt8, int8_t, DataType::kFloat, float);
REGISTER_ROC_AUC_SCORE_KERNEL(DataType::kUInt8, uint8_t, DataType::kFloat, float);

}  // namespace
}  // namespace oneflow

================================================
FILE: oneflow/user/kernels/roi_align_kernel.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __device__ T BilinearInterpolate(const T* channel_dptr, const int32_t height, const int32_t width, T y, T x) { if (y < -1.0 || y > height || x < -1.0 || x > width) { return 0; } if (y <= 0) { y = 0; } if (x <= 0) { x = 0; } int32_t y_low = static_cast(y); int32_t x_low = static_cast(x); int32_t y_high = 0; int32_t x_high = 0; if (y_low >= height - 1) { y_low = height - 1; y_high = y_low; y = static_cast(y_low); } else { y_high = y_low + 1; } if (x_low >= width - 1) { x_low = width - 1; x_high = x_low; x = static_cast(x_low); } else { x_high = x_low + 1; } const T ly = y - y_low; const T lx = x - x_low; const T hy = 1.f - ly; const T hx = 1.f - lx; // https://en.wikipedia.org/wiki/Bilinear_interpolation const int64_t q11 = y_low * width + x_low; const int64_t q21 = y_low * width + x_high; const int64_t q12 = y_high * width + x_low; const int64_t q22 = y_high * width + x_high; // no 1 / (x_high - x_low) * (y_high - y_low) because it will always be 1 in RoI Align return (hy * hx) * channel_dptr[q11] + (hy * lx) * channel_dptr[q21] + (ly * hx) * channel_dptr[q12] + (ly * lx) * channel_dptr[q22]; } template __device__ bool BilinearInterpolateDiff(const T bin_diff_avg, const int64_t height, const int64_t width, T y, T x, T& diff11, T& diff21, T& diff12, T& diff22, int32_t& x_low, int32_t& x_high, int32_t& y_low, int32_t& y_high) { if (y < -1.0 || y > height || x < -1.0 || x > width) { return false; } if (y <= 0) { y = 0; } if (x <= 0) { x = 0; } y_low = static_cast(y); x_low = static_cast(x); if (y_low >= height - 1) { y_low = height - 1; y_high = y_low; y = static_cast(y_low); } else { y_high = y_low + 1; } if (x_low >= width - 1) { x_low = width - 1; x_high = x_low; x = static_cast(x_low); } else { x_high = x_low + 1; } const T ly = y - y_low; const T lx = x - x_low; const T hy = 1.f - ly; const T hx = 1.f - lx; diff11 = bin_diff_avg * hy * hx; diff21 = bin_diff_avg * hy * lx; diff12 = bin_diff_avg * ly * hx; diff22 = bin_diff_avg * ly * lx; return true; } template __global__ void RoiAlignForward(const int64_t nthreads, const T* in_dptr, const T* rois_dptr, const T spatial_scale, const int32_t sampling_ratio, const int64_t channel_num, const int64_t height, const int64_t width, const int64_t pooled_height, const int64_t pooled_width, const bool aligned, T* out_dptr) { const int64_t pooled_area = pooled_height * pooled_width; const int64_t channel_pooled_area = channel_num * pooled_height * pooled_width; CUDA_1D_KERNEL_LOOP(index, nthreads) { const int64_t h = (index / pooled_width) % pooled_height; const int64_t w = index % pooled_width; const int64_t c = (index / pooled_area) % channel_num; const int64_t r = index / channel_pooled_area; const T* offset_rois_dptr = rois_dptr + r * 5; const int64_t n = static_cast(offset_rois_dptr[0]); const T align_offset = aligned ? static_cast(0.5) : static_cast(0.f); const T roi_start_w = offset_rois_dptr[1] * spatial_scale - align_offset; const T roi_start_h = offset_rois_dptr[2] * spatial_scale - align_offset; const T roi_end_w = offset_rois_dptr[3] * spatial_scale - align_offset; const T roi_end_h = offset_rois_dptr[4] * spatial_scale - align_offset; T roi_height = roi_end_h - roi_start_h; T roi_width = roi_end_w - roi_start_w; // aligned == false is for compatibility. 
the argument "aligned" doesn't have the semantic of // determining minimum roi size if (aligned == false) { roi_height = max(roi_height, static_cast(1.0)); roi_width = max(roi_width, static_cast(1.0)); } const T bin_height = static_cast(roi_height) / static_cast(pooled_height); const T bin_width = static_cast(roi_width) / static_cast(pooled_width); const int32_t bin_grid_height = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); const int32_t bin_grid_width = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); const T count = max(bin_grid_height * bin_grid_width, 1); const T* channel_dptr = in_dptr + (n * channel_num + c) * height * width; T out_val = 0.0; FOR_RANGE(int64_t, grid_i, 0, bin_grid_height) { // + .5f for center position T y = roi_start_h + h * bin_height + static_cast(grid_i + 0.5f) * bin_height / static_cast(bin_grid_height); FOR_RANGE(int64_t, grid_j, 0, bin_grid_width) { T x = roi_start_w + w * bin_width + static_cast(grid_j + 0.5f) * bin_width / static_cast(bin_grid_width); out_val += BilinearInterpolate(channel_dptr, height, width, y, x); } } out_dptr[index] = out_val / count; } } template __global__ void RoiAlignBackward(const int64_t nthreads, const T* out_diff_dptr, const T* rois_dptr, const T spatial_scale, const int32_t sampling_ratio, const int64_t channel_num, const int64_t height, const int64_t width, const int64_t pooled_height, const int64_t pooled_width, const bool aligned, T* in_diff_dptr) { const int64_t pooled_area = pooled_height * pooled_width; const int64_t channel_pooled_area = channel_num * pooled_height * pooled_width; CUDA_1D_KERNEL_LOOP(index, nthreads) { const int64_t h = (index / pooled_width) % pooled_height; const int64_t w = index % pooled_width; const int64_t c = (index / pooled_area) % channel_num; const int64_t r = index / channel_pooled_area; const T* offset_rois_dptr = rois_dptr + r * 5; const int64_t n = static_cast(offset_rois_dptr[0]); const T align_offset = aligned ? static_cast(0.5) : static_cast(0.f); const T roi_start_w = offset_rois_dptr[1] * spatial_scale - align_offset; const T roi_start_h = offset_rois_dptr[2] * spatial_scale - align_offset; const T roi_end_w = offset_rois_dptr[3] * spatial_scale - align_offset; const T roi_end_h = offset_rois_dptr[4] * spatial_scale - align_offset; T roi_width = roi_end_w - roi_start_w; T roi_height = roi_end_h - roi_start_h; // aligned == false is for compatibility. the argument "aligned" doesn't have the semantic of // determining minimum roi size if (aligned == false) { roi_height = max(roi_height, static_cast(1.0)); roi_width = max(roi_width, static_cast(1.0)); } const T bin_height = static_cast(roi_height) / static_cast(pooled_height); const T bin_width = static_cast(roi_width) / static_cast(pooled_width); const int32_t bin_grid_height = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); const int32_t bin_grid_width = (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); const T count = max(bin_grid_height * bin_grid_width, 1); const T bin_diff_avg = out_diff_dptr[index] / count; T* in_diff_channel_dptr = in_diff_dptr + (n * channel_num + c) * height * width; FOR_RANGE(int64_t, grid_i, 0, bin_grid_height) { // + .5f for center position T y = roi_start_h + h * bin_height + static_cast(grid_i + 0.5f) * bin_height / static_cast(bin_grid_height); FOR_RANGE(int64_t, grid_j, 0, bin_grid_width) { T x = roi_start_w + w * bin_width + static_cast(grid_j + 0.5f) * bin_width / static_cast(bin_grid_width); T diff11 = 0; T diff21 = 0; T diff12 = 0; T diff22 = 0; int32_t x_low = 0; int32_t x_high = 0; int32_t y_low = 0; int32_t y_high = 0; bool has_diff = BilinearInterpolateDiff(bin_diff_avg, height, width, y, x, diff11, diff21, diff12, diff22, x_low, x_high, y_low, y_high); if (has_diff) { const int64_t q11 = y_low * width + x_low; const int64_t q21 = y_low * width + x_high; const int64_t q12 = y_high * width + x_low; const int64_t q22 = y_high * width + x_high; atomicAdd(in_diff_channel_dptr + q11, diff11); atomicAdd(in_diff_channel_dptr + q21, diff21); atomicAdd(in_diff_channel_dptr + q12, diff12); atomicAdd(in_diff_channel_dptr + q22, diff22); } } } } } } // namespace template class RoIAlignKernel final : public user_op::OpKernel { public: RoIAlignKernel() = default; ~RoIAlignKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_blob = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* rois_blob = ctx->Tensor4ArgNameAndIndex("rois", 0); if (rois_blob->shape_view().elem_cnt() == 0) { return; } user_op::Tensor* y_blob = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t pooled_h = ctx->Attr("pooled_h"); const int32_t pooled_w = ctx->Attr("pooled_w"); const float spatial_scale = ctx->Attr("spatial_scale"); const int32_t sampling_ratio = ctx->Attr("sampling_ratio"); const bool aligned = ctx->Attr("aligned"); const int64_t elem_cnt = y_blob->shape_view().elem_cnt(); RoiAlignForward<<stream()->As()->cuda_stream()>>>( elem_cnt, x_blob->dptr(), rois_blob->dptr(), spatial_scale, sampling_ratio, x_blob->shape_view().At(1), x_blob->shape_view().At(2), x_blob->shape_view().At(3), pooled_h, pooled_w, aligned, y_blob->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class RoIAlignGradKernel final : public user_op::OpKernel { public: RoIAlignGradKernel() = default; ~RoIAlignGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); if (dx_blob == nullptr) { return; } Memset(ctx->stream(), dx_blob->mut_dptr(), 0, dx_blob->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* rois_blob = ctx->Tensor4ArgNameAndIndex("rois", 0); const int32_t pooled_h = ctx->Attr("pooled_h"); const int32_t pooled_w = ctx->Attr("pooled_w"); const float spatial_scale = ctx->Attr("spatial_scale"); const int32_t sampling_ratio = ctx->Attr("sampling_ratio"); const bool aligned = ctx->Attr("aligned"); const int64_t elem_cnt = dy_blob->shape_view().elem_cnt(); if (elem_cnt > 0) { RoiAlignBackward<<stream()->As()->cuda_stream()>>>( elem_cnt, dy_blob->dptr(), rois_blob->dptr(), spatial_scale, sampling_ratio, dx_blob->shape_view().At(1), dx_blob->shape_view().At(2), dx_blob->shape_view().At(3), 
pooled_h, pooled_w, aligned, dx_blob->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("roi_align") .SetCreateFn>() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA); REGISTER_USER_KERNEL("roi_align_grad") .SetCreateFn>() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/roll_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/roll_kernel_utils.h" #include namespace oneflow { template class CpuRollKernel final : public user_op::OpKernel { public: CpuRollKernel() = default; ~CpuRollKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const std::vector& shifts = ctx->Attr>("shifts"); const std::vector& dims = ctx->Attr>("dims"); SHAPE new_shape{}; SHIFTS new_shifts{}; int32_t num_axes = 0; computeParams(in->shape_view(), shifts, dims, new_shifts.val, new_shape.val, &num_axes); const T* in_ptr = in->dptr(); T* out_ptr = out->mut_dptr(); const int32_t size = out->shape_view().elem_cnt(); STRIDE stride{}; initStride(stride, new_shape, num_axes); transformShifts(new_shifts.val, new_shape.val, num_axes); for (int32_t i = 0; i < size; ++i) { int shifted_i = switchGetShiftedIndex(i, new_shifts.val, new_shape.val, stride.val, num_axes); out_ptr[i] = in_ptr[shifted_i]; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_ROLL_KERNEL(dtype) \ REGISTER_USER_KERNEL("roll").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_ROLL_KERNEL(float); REGISTER_ROLL_KERNEL(double); REGISTER_ROLL_KERNEL(bool); REGISTER_ROLL_KERNEL(uint8_t); REGISTER_ROLL_KERNEL(int8_t); REGISTER_ROLL_KERNEL(int32_t); REGISTER_ROLL_KERNEL(int64_t); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/roll_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/roll_kernel_utils.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void RollCudaKernel(const T* in_ptr, const SHIFTS shifts, const SHAPE shape, const STRIDE stride, const int64_t elements, T* out_ptr) { int32_t global_index = (blockDim.x * blockIdx.x) + threadIdx.x; int32_t step = gridDim.x * blockDim.x; while (global_index < elements) { int32_t shifted_global_index = getShiftedIndex(global_index, shifts.val, shape.val, stride.val); out_ptr[global_index] = in_ptr[shifted_global_index]; global_index += step; } } template struct GpuRollFunctor final { void operator()(ep::Stream* stream, const T* in_ptr, const SHIFTS shifts, const SHAPE shape, const STRIDE stride, const int64_t elements, T* out_ptr) { RollCudaKernel<<As()->cuda_stream()>>>( in_ptr, shifts, shape, stride, elements, out_ptr); } }; template struct GpuRollFunctor final { void operator()(ep::Stream* stream, const float16* in_ptr, const SHIFTS shifts, const SHAPE shape, const STRIDE stride, const int64_t elements, float16* out_ptr) { RollCudaKernel<<As()->cuda_stream()>>>( reinterpret_cast(in_ptr), shifts, shape, stride, elements, reinterpret_cast(out_ptr)); } }; template __global__ void RollFlattenCudaKernel(const T* in_ptr, const int64_t start, const int64_t elem_count_minus_start, const int64_t elements, T* out_ptr) { int64_t global_index = (blockDim.x * blockIdx.x) + threadIdx.x; int32_t step = gridDim.x * blockDim.x; while (global_index < elements) { int64_t source_idx = 0; if (global_index >= elem_count_minus_start) { source_idx = global_index - elem_count_minus_start; } else { source_idx = global_index + start; } out_ptr[global_index] = in_ptr[source_idx]; global_index += step; } } template struct GpuRollFlattenFunctor final { void operator()(ep::Stream* stream, const T* in_ptr, const int64_t start, const int64_t elem_count_minus_start, const int64_t elements, T* out_ptr) { RollFlattenCudaKernel<<As()->cuda_stream()>>>( in_ptr, start, elem_count_minus_start, elements, out_ptr); } }; template<> void GpuRollFlattenFunctor::operator()(ep::Stream* stream, const float16* in_ptr, const int64_t start, const int64_t elem_count_minus_start, const int64_t elements, float16* out_ptr) { RollFlattenCudaKernel<<As()->cuda_stream()>>>( reinterpret_cast(in_ptr), start, elem_count_minus_start, elements, reinterpret_cast(out_ptr)); } template __global__ void Roll1DimCudaKernel(const T* in_ptr, const int32_t stride_x_size, const int32_t stride, const int32_t size_minus_start, const int32_t size_minus_start_x_stride, const int32_t start_x_stride, const int64_t elements, T* out_ptr) { int32_t global_index = (blockDim.x * blockIdx.x) + threadIdx.x; int32_t step = gridDim.x * blockDim.x; while (global_index < elements) { // roll dim idx is the index of linear_index along the rolling dimension. int32_t roll_dim_idx = global_index % stride_x_size / stride; // index into the source data to find appropriate value. 
int32_t source_idx = 0; if (roll_dim_idx >= size_minus_start) { source_idx = global_index - size_minus_start_x_stride; } else { source_idx = global_index + start_x_stride; } out_ptr[global_index] = in_ptr[source_idx]; global_index += step; } } template struct GpuRoll1DimFunctor final { void operator()(ep::Stream* stream, const T* in_ptr, const int32_t stride_x_size, const int32_t stride, const int32_t size_minus_start, const int32_t size_minus_start_x_stride, const int32_t start_x_stride, const int64_t elements, T* out_ptr) { Roll1DimCudaKernel<<As()->cuda_stream()>>>( in_ptr, stride_x_size, stride, size_minus_start, size_minus_start_x_stride, start_x_stride, elements, out_ptr); } }; template<> void GpuRoll1DimFunctor::operator()(ep::Stream* stream, const float16* in_ptr, const int32_t stride_x_size, const int32_t stride, const int32_t size_minus_start, const int32_t size_minus_start_x_stride, const int32_t start_x_stride, const int64_t elements, float16* out_ptr) { Roll1DimCudaKernel<<As()->cuda_stream()>>>( reinterpret_cast(in_ptr), stride_x_size, stride, size_minus_start, size_minus_start_x_stride, start_x_stride, elements, reinterpret_cast(out_ptr)); } } // namespace template class GpuRollKernel final : public user_op::OpKernel { public: GpuRollKernel() = default; ~GpuRollKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const std::vector& shifts = ctx->Attr>("shifts"); const std::vector& dims = ctx->Attr>("dims"); const T* in_ptr = in->dptr(); T* out_ptr = out->mut_dptr(); const int64_t elem_count = out->shape_view().elem_cnt(); if (dims[0] == -1) { // NOTE(Liang Depeng): Borrow the implementation of pytorch and simplify to 1d array case. int64_t start = (elem_count - shifts[0]) % elem_count; if (start < 0) start = start + elem_count; const int64_t elem_count_minus_start = elem_count - start; GpuRollFlattenFunctor()(ctx->stream(), in_ptr, start, elem_count_minus_start, elem_count, out_ptr); } else { SHAPE new_shape{}; SHIFTS new_shifts{}; int32_t num_axes = 0; computeParams(in->shape_view(), shifts, dims, new_shifts.val, new_shape.val, &num_axes); STRIDE stride{}; initStride(stride, new_shape, num_axes); if (dims.size() == 1) { // NOTE(Liang Depeng): Borrow the implementation of pytorch const int32_t size = new_shape.val[dims[0]]; int32_t start = (size - new_shifts.val[dims[0]]) % size; // Behavior of % is different in C++ vs Python for negative numbers. This // corrects the difference. 
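// e.g. with size = 5 and a shift of 7, (5 - 7) % 5 evaluates to -2 in C++ (it would be 3 in
// Python), so adding size below restores the intended start of 3.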
if (start < 0) start = start + size; const int32_t stride_x_size = stride.val[dims[0]] * size; const int32_t size_minus_start = size - start; const int32_t size_minus_start_x_stride = size_minus_start * stride.val[dims[0]]; const int32_t start_x_stride = start * stride.val[dims[0]]; GpuRoll1DimFunctor()(ctx->stream(), in_ptr, stride_x_size, stride.val[dims[0]], size_minus_start, size_minus_start_x_stride, start_x_stride, elem_count, out_ptr); } else { transformShifts(new_shifts.val, new_shape.val, num_axes); switch (num_axes) { case 1: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 2: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 3: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 4: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 5: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 6: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 7: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 8: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 9: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 10: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 11: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 12: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 13: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 14: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 15: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; case 16: GpuRollFunctor()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count, out_ptr); break; default: break; } } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_ROLL_KERNEL(dtype) \ REGISTER_USER_KERNEL("roll").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_ROLL_KERNEL(float); REGISTER_ROLL_KERNEL(double); REGISTER_ROLL_KERNEL(float16); REGISTER_ROLL_KERNEL(bool); REGISTER_ROLL_KERNEL(uint8_t); REGISTER_ROLL_KERNEL(int8_t); REGISTER_ROLL_KERNEL(int32_t); REGISTER_ROLL_KERNEL(int64_t); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/roll_kernel_utils.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_ROLL_KERNEL_UTILS_H_ #define ONEFLOW_ROLL_KERNEL_UTILS_H_ #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { const int32_t kMaxDims = 16; struct SHIFTS { int32_t val[kMaxDims]; }; struct SHAPE { int32_t val[kMaxDims]; }; struct STRIDE { STRIDE() { for (int i = 0; i < kMaxDims; ++i) { val[i] = 1; } } int32_t val[kMaxDims]; }; template OF_DEVICE_FUNC int32_t getShiftedIndex(const int32_t global_index, const int32_t* shifts, const int32_t* shape, const int32_t* stride) { int32_t remaining = global_index; int32_t shifted_global_index = 0; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int32_t i = 0; i < Dim; ++i) { const int32_t idx = remaining / stride[i]; // NOTE(Liang Depeng): Compute the shifted index of each axis. int32_t shifted_idx = (idx - shifts[i]); // NOTE(Liang Depeng): This correct the results. if (shifted_idx < 0) shifted_idx = shifted_idx + shape[i]; if (shifted_idx >= shape[i]) shifted_idx = shifted_idx - shape[i]; shifted_global_index += shifted_idx * stride[i]; remaining = remaining - idx * stride[i]; } return shifted_global_index; } OF_DEVICE_FUNC int32_t switchGetShiftedIndex(const int32_t global_index, const int32_t* shifts, const int32_t* shape, const int32_t* stride, int n) { switch (n) { case 1: return getShiftedIndex<1>(global_index, shifts, shape, stride); case 2: return getShiftedIndex<2>(global_index, shifts, shape, stride); case 3: return getShiftedIndex<3>(global_index, shifts, shape, stride); case 4: return getShiftedIndex<4>(global_index, shifts, shape, stride); case 5: return getShiftedIndex<5>(global_index, shifts, shape, stride); case 6: return getShiftedIndex<6>(global_index, shifts, shape, stride); case 7: return getShiftedIndex<7>(global_index, shifts, shape, stride); case 8: return getShiftedIndex<8>(global_index, shifts, shape, stride); case 9: return getShiftedIndex<9>(global_index, shifts, shape, stride); case 10: return getShiftedIndex<10>(global_index, shifts, shape, stride); case 11: return getShiftedIndex<11>(global_index, shifts, shape, stride); case 12: return getShiftedIndex<12>(global_index, shifts, shape, stride); case 13: return getShiftedIndex<13>(global_index, shifts, shape, stride); case 14: return getShiftedIndex<14>(global_index, shifts, shape, stride); case 15: return getShiftedIndex<15>(global_index, shifts, shape, stride); case 16: return getShiftedIndex<16>(global_index, shifts, shape, stride); } return 0; } static void initStride(STRIDE& stride, const SHAPE& dim_vec, const int32_t dims) { for (int i = dims - 2; i >= 0; --i) { stride.val[i] = dim_vec.val[i + 1] * stride.val[i + 1]; } } static void transformShifts(int32_t* shifts, int32_t* shape, int n) { for (int i = 0; i < n; ++i) { shifts[i] = shifts[i] % shape[i]; } // NOLINT } static void computeParams(const ShapeView& in_shape, const std::vector& shifts, const std::vector& dims, int32_t* new_shifts, int32_t* new_shape, int32_t* new_num_axes) { if (dims[0] == -1) { // NOTE(Liang Depeng): // If user did not set the dims parameter, // the input tensor will be flattened before rolling, // which means we can think of the input tensor as an 1 dimensional array. new_shifts[0] = shifts[0]; *new_num_axes = 1; new_shape[0] = in_shape.elem_cnt(); } else { std::map dim_to_shift; for (int i = 0; i < shifts.size(); ++i) { dim_to_shift.emplace(dims[i], shifts[i]); } // NOTE(Liang Depeng): // Compute the shift parameter for each axis. 
// For those axis which user did not specified shift value, will be set to 0 for (int i = 0; i < in_shape.NumAxes(); ++i) { if (dim_to_shift.count(i) > 0) { new_shifts[i] = dim_to_shift.at(i); } else { new_shifts[i] = 0; } new_shape[i] = in_shape.At(i); } *new_num_axes = in_shape.NumAxes(); } } } // namespace } // namespace oneflow #endif // ONEFLOW_ROLL_KERNEL_UTILS_H_ ================================================ FILE: oneflow/user/kernels/rrelu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace { template static T uniform_real(V val, T from, T to) { constexpr auto MASK = static_cast((static_cast(1) << std::numeric_limits::digits) - 1); constexpr auto DIVISOR = static_cast(1) / (static_cast(1) << std::numeric_limits::digits); T x = (val & MASK) * DIVISOR; return (x * (to - from) + from); } static uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) { return (static_cast(hi) << 32) | lo; } } // namespace template class CpuRReluKernel final : public user_op::OpKernel { public: CpuRReluKernel() = default; ~CpuRReluKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCPU)); generator->set_current_seed(CHECK_JUST( GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed"), {"output", 0}))); return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const int64_t size = in->shape_view().elem_cnt(); if (size == 0) return; user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("output", 0); user_op::Tensor* noise_data = ctx->Tensor4ArgNameAndIndex("noise_data", 0); const T& lower = ctx->Attr("lower"); const T& upper = ctx->Attr("upper"); T* out_ptr = out->mut_dptr(); T* noise_ptr = noise_data->mut_dptr(); const T* in_ptr = in->dptr(); auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); CHECK_NOTNULL(generator); auto cpu_gen = CHECK_JUST(generator->Get()); std::lock_guard lock(cpu_gen->mutex_); ep::pytorch_mt19937_engine& engine = cpu_gen->torch_engine(); FOR_RANGE(int64_t, i, 0, size) { if (*(in_ptr + i) >= 0) { noise_ptr[i] = 1; out_ptr[i] = in_ptr[i]; } else { uint32_t random1 = engine(); uint32_t random2 = engine(); uint64_t rand_unit = make64BitsFrom32Bits(random1, random2); T uniform_sample = uniform_real(rand_unit, lower, upper); noise_ptr[i] = uniform_sample; out_ptr[i] = in_ptr[i] * uniform_sample; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define 
REGISTER_CPU_RRelu_KERNEL(dtype) \ REGISTER_USER_KERNEL("rrelu").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) REGISTER_CPU_RRelu_KERNEL(float); REGISTER_CPU_RRelu_KERNEL(double); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/rrelu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/distributions/normal_distribution.h" #include "oneflow/user/kernels/distributions/distribution_template_util.cuh" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" namespace oneflow { namespace { template struct UniformTransformFunctor { UniformTransformFunctor(ComputeType range, ComputeType lower) : range(range), lower(lower) {} __device__ T operator()(ComputeType random_val) const { return static_cast(random_val * range + lower); } ComputeType range; ComputeType lower; }; template OF_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) __global__ void RReluKernel(int64_t numel, uint64_t seed, uint64_t offset, const T* in_ptr, T* out_ptr, T* noise_data_ptr, Distribution dist_func, Transform transform_func) { int idx = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seed, idx, offset, &state); int rounded_size = ((numel - 1) / (blockDim.x * gridDim.x * unroll_factor) + 1) * blockDim.x * gridDim.x * unroll_factor; for (int32_t linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { auto rand = dist_func(&state); #pragma unroll for (int ii = 0; ii < unroll_factor; ii++) { int li = linear_index + blockDim.x * gridDim.x * ii; if (li < numel) { T r = transform_func(static_cast((&rand.x)[ii])); if (in_ptr[li] <= static_cast(0)) { out_ptr[li] = in_ptr[li] * r; noise_data_ptr[li] = r; } else { out_ptr[li] = in_ptr[li]; noise_data_ptr[li] = static_cast(1); } } } } } } // namespace template class CudaRReluKernel final : public user_op::OpKernel { public: CudaRReluKernel() = default; ~CudaRReluKernel() = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA)); generator->set_current_seed(CHECK_JUST( GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed"), {"output", 0}))); return std::make_shared(generator); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const int64_t size = in->shape_view().elem_cnt(); if (size == 0) return; user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("output", 0); user_op::Tensor* noise_data = 
ctx->Tensor4ArgNameAndIndex("noise_data", 0); const T& lower = ctx->Attr("lower"); const T& upper = ctx->Attr("upper"); T* out_ptr = out->mut_dptr(); T* noise_ptr = noise_data->mut_dptr(); const T* in_ptr = in->dptr(); auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); CHECK_NOTNULL(generator); ep::CudaStream* cuda_stream = ctx->stream()->As(); const auto device_index = ctx->stream()->device()->device_index(); std::shared_ptr cuda_gen = CHECK_JUST(generator->Get(device_index)); auto execution_policy = cuda_gen->CalcExecutionPolicy(size, cuda_stream); auto counter_offset = std::get<0>(execution_policy); uint64_t seed = cuda_gen->current_seed(); uint64_t offset = cuda_gen->get_philox_offset(counter_offset); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); using ComputeType = typename distribution::DefaultComputeType::type; UniformTransformFunctor transform_functor( static_cast(upper - lower), static_cast(lower)); if (std::is_same::value) { DistributionFunctor dist_functor; RReluKernel <<cuda_stream()>>>( size, seed, offset, in_ptr, out_ptr, noise_ptr, dist_functor, transform_functor); } else { // float DistributionFunctor dist_functor; RReluKernel <<cuda_stream()>>>( size, seed, offset, in_ptr, out_ptr, noise_ptr, dist_functor, transform_functor); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_RRELU_KERNEL(dtype) \ REGISTER_USER_KERNEL("rrelu").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_CUDA_RRELU_KERNEL(float) REGISTER_CUDA_RRELU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/same_padding_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/include/primitive/fill.h" namespace oneflow { namespace { template std::unique_ptr NewFillPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("y", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } template std::unique_ptr NewCopyNdPrimitive(Context* ctx) { const auto& in_arg_pair = ctx->inputs().front(); const int64_t ndims = ctx->TensorDesc4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second)->shape().NumAxes(); return ep::primitive::NewPrimitive(ctx->device_type(), ndims); } auto FillPrimitiveExists() { return hob::make_custom("FillPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewFillPrimitive(&ctx).operator bool(); }); } auto CopyNdPrimitiveExists() { return hob::make_custom("CopyNdPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewCopyNdPrimitive(&ctx).operator bool(); }); } } // namespace class SamePaddingKernel final : public user_op::OpKernel { public: SamePaddingKernel() = default; ~SamePaddingKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int64_t num_axes = x->shape_view().NumAxes(); const std::string& padding = ctx->Attr("padding"); const std::string& data_format = ctx->Attr("data_format"); const std::vector kernel_size = ctx->Attr>("kernel_size"); const std::vector strides = ctx->Attr>("strides"); const std::vector dilation_rate = ctx->Attr>("dilation_rate"); std::vector padding_before(num_axes, 0); const size_t idx_offset = IdxOffset(data_format); const int32_t num_spatial_dims = x->shape_view().NumAxes() - 2; for (int32_t i = 0; i < num_spatial_dims; ++i) { int32_t padding_small = 0; int32_t padding_large = 0; CHECK_JUST(CalcSamePadding(x->shape_view().At(idx_offset + i), kernel_size.at(i), // NOLINT dilation_rate.at(i), strides.at(i), &padding_small, // NOLINT &padding_large)); // NOLINT if (padding == "same_lower") { padding_before[idx_offset + i] = padding_large; } else if (padding == "same_upper") { padding_before[idx_offset + i] = padding_small; } else { UNIMPLEMENTED(); } CHECK_EQ(y->shape_view().At(idx_offset + i), x->shape_view().At(idx_offset + i) + padding_small + padding_large); } CHECK_EQ(padding_before.size(), num_axes); std::unique_ptr fill_primitive = NewFillPrimitive(ctx); CHECK(fill_primitive); fill_primitive->Launch(ctx->stream(), y->mut_dptr(), Scalar(0), y->shape_view().elem_cnt()); DimVector src_pos_vec(num_axes, 0); DimVector dst_pos_vec(padding_before.cbegin(), padding_before.cend()); std::unique_ptr copy_nd_primitive = NewCopyNdPrimitive(ctx); CHECK(copy_nd_primitive); copy_nd_primitive->Launch(ctx->stream(), x->data_type(), num_axes, y->mut_dptr(), y->shape_view().ptr(), dst_pos_vec.data(), x->dptr(), x->shape_view().ptr(), src_pos_vec.data(), x->shape_view().ptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("same_padding") .SetCreateFn() .SetIsMatchedHob(FillPrimitiveExists() && CopyNdPrimitiveExists()); class SamePaddingGradKernel final : public user_op::OpKernel { public: SamePaddingGradKernel() = default; ~SamePaddingGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = 
ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t num_axes = dy->shape_view().NumAxes(); const std::string& padding = ctx->Attr("padding"); const std::string& data_format = ctx->Attr("data_format"); const std::vector kernel_size = ctx->Attr>("kernel_size"); const std::vector strides = ctx->Attr>("strides"); const std::vector dilation_rate = ctx->Attr>("dilation_rate"); std::vector padding_before(num_axes, 0); const size_t idx_offset = IdxOffset(data_format); const int32_t num_spatial_dims = dy->shape_view().NumAxes() - 2; for (int32_t i = 0; i < num_spatial_dims; ++i) { int32_t padding_small = 0; int32_t padding_large = 0; CHECK_JUST(CalcSamePadding(dx->shape_view().At(idx_offset + i), kernel_size.at(i), // NOLINT dilation_rate.at(i), strides.at(i), &padding_small, // NOLINT &padding_large)); // NOLINT if (padding == "same_lower") { padding_before[idx_offset + i] = padding_large; } else if (padding == "same_upper") { padding_before[idx_offset + i] = padding_small; } else { UNIMPLEMENTED(); } CHECK_EQ(dy->shape_view().At(idx_offset + i), dx->shape_view().At(idx_offset + i) + padding_small + padding_large); } DimVector dst_pos_vec(num_axes, 0); DimVector src_pos_vec(padding_before.cbegin(), padding_before.cend()); std::unique_ptr primitive = NewCopyNdPrimitive(ctx); CHECK(primitive); primitive->Launch(ctx->stream(), dy->data_type(), num_axes, dx->mut_dptr(), dx->shape_view().ptr(), dst_pos_vec.data(), dy->dptr(), dy->shape_view().ptr(), src_pos_vec.data(), dx->shape_view().ptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("same_padding_grad") .SetCreateFn() .SetIsMatchedHob(CopyNdPrimitiveExists() == true); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/scalar_bitwise_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { namespace { template std::unique_ptr NewBinaryPrimitive( Context* ctx, ep::primitive::BinaryOp op) { const user_op::TensorDesc* in = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("out", 0); const int64_t ndims = in->shape().NumAxes(); return ep::primitive::NewPrimitive( ctx->device_type(), op, in->data_type(), out->data_type(), ndims); } template auto PrimitiveExists() { return hob::make_custom("BroadcastElementwiseBinaryPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewBinaryPrimitive(&ctx, op).operator bool(); }); } template class ScalarBitwiseKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ScalarBitwiseKernel() = default; ~ScalarBitwiseKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); Scalar scalar_operand = ctx->Attr("operand"); int64_t elem_cnt = out->shape_view().elem_cnt(); if (elem_cnt != 0) { std::unique_ptr primitive = NewBinaryPrimitive(ctx, op); CHECK(primitive); primitive->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(), in->dptr(), scalar_operand, out->mut_dptr()); } else { // For 0-d Tensor return; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UNARY_BITWISE_SCALAR_ELEMWISE_USER_KERNEL(kernel_name, binary_op) \ REGISTER_USER_KERNEL(kernel_name) \ .SetCreateFn>() \ .SetIsMatchedHob(PrimitiveExists()); REGISTER_UNARY_BITWISE_SCALAR_ELEMWISE_USER_KERNEL("scalar_bitwise_and", ep::primitive::BinaryOp::kBitwiseAnd); REGISTER_UNARY_BITWISE_SCALAR_ELEMWISE_USER_KERNEL("scalar_bitwise_or", ep::primitive::BinaryOp::kBitwiseOr); REGISTER_UNARY_BITWISE_SCALAR_ELEMWISE_USER_KERNEL("scalar_bitwise_xor", ep::primitive::BinaryOp::kBitwiseXor); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/scalar_by_tensor_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { namespace { template std::unique_ptr NewBroadcastElementwiseBinaryPrimitive( Context* ctx, ep::primitive::BinaryOp op) { const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* y = ctx->TensorDesc4ArgNameAndIndex("y", 0); const int64_t ndims = y->shape().NumAxes(); return ep::primitive::NewPrimitive( ctx->device_type(), op, x->data_type(), y->data_type(), ndims); } template auto BroadcastElementwiseBinaryPrimitiveExists() { return hob::make_custom("BroadcastElementwiseBinaryPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewBroadcastElementwiseBinaryPrimitive(&ctx, op).operator bool(); }); } template class ScalarByTensorKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ScalarByTensorKernel() = default; ~ScalarByTensorKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* scalar = ctx->Tensor4ArgNameAndIndex("scalar", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); int64_t elem_cnt = y->shape_view().elem_cnt(); if (elem_cnt != 0) { std::unique_ptr primitive = NewBroadcastElementwiseBinaryPrimitive(ctx, op); CHECK(primitive); primitive->Launch(ctx->stream(), x->shape_view().NumAxes(), x->shape_view().ptr(), x->dptr(), scalar->shape_view().NumAxes(), scalar->shape_view().ptr(), scalar->dptr(), y->mut_dptr()); } else { // For 0-size Tensor return; } }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace #define REGISTER_SCALAR_BY_TENSOR_KERNEL(op_name, binary_op) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob(BroadcastElementwiseBinaryPrimitiveExists()) \ .SetInplaceProposalFn( \ [](const user_op::InferContext&, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); \ return Maybe::Ok(); \ }); #define SCALAR_BY_TENSOR_SEQ \ OF_PP_MAKE_TUPLE_SEQ("scalar_add_by_tensor", ep::primitive::BinaryOp::kAdd) \ OF_PP_MAKE_TUPLE_SEQ("scalar_sub_by_tensor", ep::primitive::BinaryOp::kSub) \ OF_PP_MAKE_TUPLE_SEQ("scalar_mul_by_tensor", ep::primitive::BinaryOp::kMul) \ OF_PP_MAKE_TUPLE_SEQ("scalar_div_by_tensor", ep::primitive::BinaryOp::kDiv) OF_PP_FOR_EACH_TUPLE(REGISTER_SCALAR_BY_TENSOR_KERNEL, SCALAR_BY_TENSOR_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/scalar_logical_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { namespace { template std::unique_ptr NewBinaryPrimitive( Context* ctx, ep::primitive::BinaryOp op) { const user_op::TensorDesc* in = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("out", 0); const int64_t ndims = in->shape().NumAxes(); return ep::primitive::NewPrimitive( ctx->device_type(), op, in->data_type(), out->data_type(), ndims); } template auto PrimitiveExists() { return hob::make_custom("BroadcastElementwiseBinaryPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewBinaryPrimitive(&ctx, op).operator bool(); }); } template class ScalarLogicalKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ScalarLogicalKernel() = default; ~ScalarLogicalKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); Scalar scalar_operand; if (ctx->Attr("has_int_operand")) { scalar_operand = ctx->Attr("int_operand"); } else if (ctx->Attr("has_float_operand")) { scalar_operand = ctx->Attr("float_operand"); } else { UNIMPLEMENTED(); } int64_t elem_cnt = out->shape_view().elem_cnt(); if (elem_cnt != 0) { std::unique_ptr primitive = NewBinaryPrimitive(ctx, op); CHECK(primitive); primitive->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(), in->dptr(), scalar_operand, out->mut_dptr()); } else { // For 0-d Tensor return; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(kernel_name, binary_op) \ REGISTER_USER_KERNEL(kernel_name) \ .SetCreateFn>() \ .SetIsMatchedHob(PrimitiveExists()); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_equal", ep::primitive::BinaryOp::kEqual); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_not_equal", ep::primitive::BinaryOp::kNotEqual); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_greater", ep::primitive::BinaryOp::kGreaterThan); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_greater_equal", ep::primitive::BinaryOp::kGreaterEqual); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_less", ep::primitive::BinaryOp::kLessThan); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_less_equal", ep::primitive::BinaryOp::kLessEqual); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_or", ep::primitive::BinaryOp::kLogicalOr); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_xor", ep::primitive::BinaryOp::kLogicalXor); REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL("scalar_logical_and", ep::primitive::BinaryOp::kLogicalAnd); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/scalar_math_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template std::unique_ptr NewBroadcastElementwiseBinaryPrimitive( Context* ctx, ep::primitive::BinaryOp op) { const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex("in", 0); const user_op::TensorDesc* y = ctx->TensorDesc4ArgNameAndIndex("out", 0); const int64_t ndims = y->shape().NumAxes(); return ep::primitive::NewPrimitive( ctx->device_type(), op, x->data_type(), y->data_type(), ndims); } template auto BroadcastElementwiseBinaryPrimitiveExists() { return hob::make_custom("BroadcastElementwiseBinaryPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewBroadcastElementwiseBinaryPrimitive(&ctx, op).operator bool(); }); } template std::unique_ptr NewBroadcastElementwiseAttrBinaryPrimitive(Context* ctx, ep::primitive::BinaryOp op) { const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const int64_t ndims = dy->shape().NumAxes(); Scalar value; if (ctx->template Attr("has_int_operand")) { value = Scalar(ctx->template Attr("int_operand")); } else if (ctx->template Attr("has_float_operand")) { value = Scalar(ctx->template Attr("float_operand")); } else { UNIMPLEMENTED(); } return ep::primitive::NewPrimitive( ctx->device_type(), op, x->data_type(), dy->data_type(), ndims, value); } template auto BroadcastElementwiseAttrBinaryPrimitiveExists() { return hob::make_custom( "BroadcastElementwiseBinaryAttrPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewBroadcastElementwiseAttrBinaryPrimitive(&ctx, op).operator bool(); }); } } // namespace template class ScalarMathKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ScalarMathKernel() = default; ~ScalarMathKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); Scalar value; if (ctx->Attr("has_int_operand")) { value = Scalar(ctx->Attr("int_operand")); } else if (ctx->Attr("has_float_operand")) { value = Scalar(ctx->Attr("float_operand")); } else { UNIMPLEMENTED(); } int64_t elem_cnt = out->shape_view().elem_cnt(); if (elem_cnt != 0) { const bool is_add_sub_0 = (op == ep::primitive::BinaryOp::kAdd || op == ep::primitive::BinaryOp::kSub) && value.Value() == 0.0; const bool is_mul_div_1 = (op == ep::primitive::BinaryOp::kMul || op == ep::primitive::BinaryOp::kDiv) && value.Value() == 1.0; if ((is_add_sub_0 || is_mul_div_1) && in->dptr() == out->dptr()) { return; } std::unique_ptr primitive = NewBroadcastElementwiseBinaryPrimitive(ctx, op); CHECK(primitive); primitive->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(), in->dptr(), value, out->mut_dptr()); } else { // For 0-d Tensor return; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template 
class ScalarReverseMathKernel final : public user_op::OpKernel { public: ScalarReverseMathKernel() = default; ~ScalarReverseMathKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); Scalar value; if (ctx->Attr("has_int_operand")) { value = Scalar(ctx->Attr("int_operand")); } else if (ctx->Attr("has_float_operand")) { value = Scalar(ctx->Attr("float_operand")); } else { UNIMPLEMENTED(); } int64_t elem_cnt = out->shape_view().elem_cnt(); if (elem_cnt != 0) { std::unique_ptr primitive = NewBroadcastElementwiseBinaryPrimitive(ctx, op); CHECK(primitive); primitive->Launch(ctx->stream(), value, in->shape_view().NumAxes(), in->shape_view().ptr(), in->dptr(), out->mut_dptr()); } else { // For 0-d Tensor return; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define SCALAR_MATH_SEQ \ OF_PP_MAKE_TUPLE_SEQ("scalar_add", ep::primitive::BinaryOp::kAdd) \ OF_PP_MAKE_TUPLE_SEQ("scalar_mul", ep::primitive::BinaryOp::kMul) \ OF_PP_MAKE_TUPLE_SEQ("scalar_div", ep::primitive::BinaryOp::kDiv) \ OF_PP_MAKE_TUPLE_SEQ("scalar_floordiv", ep::primitive::BinaryOp::kFloorDiv) \ OF_PP_MAKE_TUPLE_SEQ("scalar_truncdiv", ep::primitive::BinaryOp::kTruncDiv) \ OF_PP_MAKE_TUPLE_SEQ("scalar_fmod", ep::primitive::BinaryOp::kFmod) \ OF_PP_MAKE_TUPLE_SEQ("scalar_pow", ep::primitive::BinaryOp::kPow) #define REGISTER_UNARY_MATH_SCALAR_ELEMWISE_USER_KERNEL(op_name, binary_op) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((BroadcastElementwiseBinaryPrimitiveExists())) \ .SetInplaceProposalFn( \ [](const user_op::InferContext& ctx, \ const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); \ return Maybe::Ok(); \ }); OF_PP_FOR_EACH_TUPLE(REGISTER_UNARY_MATH_SCALAR_ELEMWISE_USER_KERNEL, SCALAR_MATH_SEQ) #define REGISTER_UNARY_MATH_SCALAR_REVERSE_ELEMWISE_USER_KERNEL(op_name, binary_op) \ REGISTER_USER_KERNEL(op_name).SetCreateFn>().SetIsMatchedHob( \ (BroadcastElementwiseBinaryPrimitiveExists())); REGISTER_UNARY_MATH_SCALAR_REVERSE_ELEMWISE_USER_KERNEL("scalar_reverse_pow", ep::primitive::BinaryOp::kPow) template class ScalarPowGradKernel final : public user_op::OpKernel { public: ScalarPowGradKernel() = default; ~ScalarPowGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); const int64_t elem_cnt = dx_tensor->shape_view().elem_cnt(); if (elem_cnt != 0) { std::unique_ptr primitive = NewBroadcastElementwiseAttrBinaryPrimitive(ctx, op); CHECK(primitive); primitive->Launch(ctx->stream(), x_tensor->shape_view().NumAxes(), x_tensor->shape_view().ptr(), x_tensor->dptr(), dy_tensor->shape_view().NumAxes(), dy_tensor->shape_view().ptr(), dy_tensor->dptr(), dx_tensor->mut_dptr()); } else { // For 0-d Tensor return; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_BINARY_MATH_WITH_ATTR_ELEMWISE_USER_KERNEL(op_name, binary_op) \ REGISTER_USER_KERNEL(op_name).SetCreateFn>().SetIsMatchedHob( \ (BroadcastElementwiseAttrBinaryPrimitiveExists())); REGISTER_BINARY_MATH_WITH_ATTR_ELEMWISE_USER_KERNEL("scalar_pow_grad", 
ep::primitive::BinaryOp::kScalarBasePowerGrad); REGISTER_BINARY_MATH_WITH_ATTR_ELEMWISE_USER_KERNEL("scalar_reverse_pow_grad", ep::primitive::BinaryOp::kScalarExpPowerGrad); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/scaled_dot_product_attention_grad_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/user_op_tensor.h" #if CUDA_VERSION >= 11070 #ifdef WITH_CUTLASS #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/include/primitive/permute.h" #include "cutlass/arch/mma.h" #include "cutlass/gemm/warp/mma.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/user/kernels/scaled_dot_product_attention_kernel.h" // from flash_attention #include "oneflow/user/kernels/scaled_dot_product_attention_util.h" namespace oneflow { namespace user_op { namespace { static size_t InferTmpBufferSizeForFlashAttentionGradKernel(InferContext* ctx) { const auto& q_shape = ctx->InputTensorDesc("query", 0).shape(); const int batch_size = q_shape.At(0); const int seqlen_q = q_shape.At(1); const int num_heads = q_shape.At(2); const int head_size = q_shape.At(3); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); size_t buffer_size = 0; buffer_size += GetCudaAlignedSize(batch_size * num_heads * seqlen_q_rounded * GetSizeOfDataType(DataType::kFloat)); buffer_size += GetCudaAlignedSize(batch_size * seqlen_q_rounded * num_heads * head_size_rounded * GetSizeOfDataType(DataType::kFloat)); return buffer_size; } class ScaledDotProductFlashAttentionGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ScaledDotProductFlashAttentionGradKernel() = default; ~ScaledDotProductFlashAttentionGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const Tensor* grad_out = ctx->Tensor4ArgNameAndIndex("grad_out", 0); const Tensor* query = ctx->Tensor4ArgNameAndIndex("query", 0); const Tensor* key = ctx->Tensor4ArgNameAndIndex("key", 0); const Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); const Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const Tensor* softmax_lse = ctx->Tensor4ArgNameAndIndex("softmax_lse", 0); const Tensor* 
rng_state = ctx->Tensor4ArgNameAndIndex("rng_state", 0); const Tensor* alibi_slopes_ = nullptr; if (ctx->has_input("alibi_slopes_", 0)) { alibi_slopes_ = ctx->Tensor4ArgNameAndIndex("alibi_slopes_", 0); } const float p_dropout = ctx->Attr("p_dropout"); const float softmax_scale = ctx->Attr("softmax_scale"); bool is_causal = ctx->Attr("is_causal"); int window_size_left = ctx->Attr("window_size_left"); int window_size_right = ctx->Attr("window_size_right"); Tensor* grad_q = ctx->Tensor4ArgNameAndIndex("grad_q", 0); Tensor* grad_k = ctx->Tensor4ArgNameAndIndex("grad_k", 0); Tensor* grad_v = ctx->Tensor4ArgNameAndIndex("grad_v", 0); Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); void* tmp_ptr = tmp->mut_dptr(); auto* cuda_device = dynamic_cast(ctx->stream()->device()); auto dprops = cuda_device->properties(); auto* cuda_stream = ctx->stream()->As(); bool is_dropout = p_dropout > 0.0f; if (is_causal) { window_size_right = 0; } const int arch = cuda_stream->cuda_arch() / 10; const bool is_supported_arch = (arch == 80 || arch == 86 || arch == 89 || arch == 90); CHECK(is_supported_arch); const DataType data_type = query->data_type(); const bool is_supported_dtype = (data_type == DataType::kFloat16 || data_type == DataType::kBFloat16); CHECK(is_supported_dtype); CHECK_EQ(key->data_type(), data_type); CHECK_EQ(value->data_type(), data_type); CHECK_EQ(grad_out->data_type(), data_type); CHECK_EQ(out->data_type(), data_type); CHECK_EQ(softmax_lse->data_type(), DataType::kFloat); CHECK_EQ(rng_state->data_type(), DataType::kUInt64); // check contiguous last dimension. CHECK_EQ(CHECK_JUST(VectorAt(grad_out->stride(), 3)), 1); CHECK_EQ(CHECK_JUST(VectorAt(query->stride(), 3)), 1); CHECK_EQ(CHECK_JUST(VectorAt(key->stride(), 3)), 1); CHECK_EQ(CHECK_JUST(VectorAt(value->stride(), 3)), 1); CHECK_EQ(CHECK_JUST(VectorAt(out->stride(), 3)), 1); const int batch_size = query->shape_view().At(0); const int seqlen_q = query->shape_view().At(1); const int num_heads = query->shape_view().At(2); const int head_size = query->shape_view().At(3); const int seqlen_k = key->shape_view().At(1); const int num_heads_k = key->shape_view().At(2); const int head_size_og = grad_out->shape_view().At(3); // check tensor shape. 
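    // Shapes verified below:
    //   grad_out    : (batch_size, seqlen_q, num_heads,   head_size_og)
    //   query / out : (batch_size, seqlen_q, num_heads,   head_size)
    //   key / value : (batch_size, seqlen_k, num_heads_k, head_size)
    //   softmax_lse : (batch_size, num_heads, seqlen_q)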
CHECK_EQ(grad_out->shape_view().At(0), batch_size);
    CHECK_EQ(grad_out->shape_view().At(1), seqlen_q);
    CHECK_EQ(grad_out->shape_view().At(2), num_heads);
    CHECK_EQ(grad_out->shape_view().At(3), head_size_og);
    CHECK_EQ(query->shape_view().At(0), batch_size);
    CHECK_EQ(query->shape_view().At(1), seqlen_q);
    CHECK_EQ(query->shape_view().At(2), num_heads);
    CHECK_EQ(query->shape_view().At(3), head_size);
    CHECK_EQ(key->shape_view().At(0), batch_size);
    CHECK_EQ(key->shape_view().At(1), seqlen_k);
    CHECK_EQ(key->shape_view().At(2), num_heads_k);
    CHECK_EQ(key->shape_view().At(3), head_size);
    CHECK_EQ(value->shape_view().At(0), batch_size);
    CHECK_EQ(value->shape_view().At(1), seqlen_k);
    CHECK_EQ(value->shape_view().At(2), num_heads_k);
    CHECK_EQ(value->shape_view().At(3), head_size);
    CHECK_EQ(out->shape_view().At(0), batch_size);
    CHECK_EQ(out->shape_view().At(1), seqlen_q);
    CHECK_EQ(out->shape_view().At(2), num_heads);
    CHECK_EQ(out->shape_view().At(3), head_size);
    CHECK_EQ(softmax_lse->shape_view().At(0), batch_size);
    CHECK_EQ(softmax_lse->shape_view().At(1), num_heads);
    CHECK_EQ(softmax_lse->shape_view().At(2), seqlen_q);
    CHECK_GT(batch_size, 0);   // batch size must be positive
    CHECK_LE(head_size, 256);  // only support head dimensions at most 256
    // FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout
    // requires A100/A800 or H100/H800
    if (head_size > 192 && (head_size <= 224 || is_dropout)) { CHECK((arch == 80 || arch == 90)); }
    CHECK(num_heads % num_heads_k == 0);  // Number of heads in key/value must divide number of heads in query
    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
    // bool loop = seqlen_k > blocksize_c;
    // TODO: change later, for now set to true for simplicity
    bool loop = true;
    // size: batch_size x num_heads x seqlen_q_rounded; datatype: float
    void* softmax_d_ptr = tmp_ptr;
    tmp_ptr = reinterpret_cast<char*>(tmp_ptr)
              + GetCudaAlignedSize(batch_size * num_heads * seqlen_q_rounded
                                   * GetSizeOfDataType(DataType::kFloat));
    // set to false by default.
    // TODO(chende): can get from forward kernel(add input in python interface, it's only used for
    // backward).
    bool deterministic = false;
    void* dq_accum_ptr;
    if (loop) {
      // size: batch_size x seqlen_q_rounded x num_heads x head_size_rounded; datatype: float
      dq_accum_ptr = tmp_ptr;
    }
    Flash_bwd_params params;
    set_params_dgrad(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k, head_size, head_size_rounded, query, key, value, out,
                     grad_out, grad_q, grad_k, grad_v, nullptr, nullptr,
                     loop ? dq_accum_ptr : nullptr,
                     // loop ? dk_accum.data_ptr() : nullptr,
                     // loop ? dv_accum.data_ptr() : nullptr,
                     nullptr, nullptr, const_cast<void*>(softmax_lse->dptr()), softmax_d_ptr,
                     p_dropout, softmax_scale, window_size_left, window_size_right, deterministic);
    params.dq_accum_split_stride = !deterministic ?
0 : seqlen_q_rounded * num_heads * head_size_rounded; auto launch = &run_mha_bwd; params.rng_state = const_cast(rng_state->dptr()); set_params_alibi(params, alibi_slopes_, batch_size, num_heads); if (seqlen_q > 0) { launch(params, cuda_stream->cuda_stream()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(dtype) \ REGISTER_USER_KERNEL("scaled_dot_product_flash_attention_grad") \ .SetCreateFn() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == dtype)) \ .SetInferTmpSizeFn(InferTmpBufferSizeForFlashAttentionGradKernel); REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kFloat16) REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kBFloat16) } // namespace } // namespace user_op } // namespace oneflow #endif // WITH_CUTLASS #endif // CUDA_VERSION >= 11070 ================================================ FILE: oneflow/user/kernels/scaled_dot_product_attention_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/throw.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/user_op_tensor.h" #if CUDA_VERSION >= 11070 #ifdef WITH_CUTLASS #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/include/primitive/permute.h" #include "cutlass/arch/mma.h" #include "cutlass/gemm/warp/mma.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/user/kernels/scaled_dot_product_attention_kernel.h" // from flash_attention #include "oneflow/user/kernels/scaled_dot_product_attention_util.h" namespace oneflow { namespace user_op { namespace { static size_t InferTmpBufferSizeForFlashAttentionKernel(InferContext* ctx) { const float p_dropout = ctx->Attr("p_dropout"); const auto& q_shape = ctx->InputTensorDesc("query", 0).shape(); const auto& k_shape = ctx->InputTensorDesc("key", 0).shape(); const int batch_size = q_shape.At(0); const int seqlen_q = q_shape.At(1); const int num_heads = q_shape.At(2); const int head_size_og = q_shape.At(3); const int seqlen_k = k_shape.At(1); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); const int head_size_rounded = round_multiple(head_size, 32); int dev; { cudaError_t err = cudaGetDevice(&dev); if (err != cudaSuccess) { return err; } } int sm_count; { cudaError_t err = cudaDeviceGetAttribute(&sm_count, 
cudaDevAttrMultiProcessorCount, dev); if (err != cudaSuccess) { return err; } } const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; const int num_m_blocks = (seqlen_q + 64 - 1) / 64; size_t buffer_size = 0; // for splitKV and splitKV is not implemented for dropout. if (p_dropout == 0.0f) { int num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, sm_count, num_n_blocks, 128); buffer_size += GetCudaAlignedSize(num_splits * batch_size * num_heads * seqlen_q * GetSizeOfDataType(DataType::kFloat)); buffer_size += GetCudaAlignedSize(num_splits * batch_size * num_heads * seqlen_q * head_size_rounded * GetSizeOfDataType(DataType::kFloat)); } return buffer_size; } class ScaledDotProductFlashAttentionKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: ScaledDotProductFlashAttentionKernel() = default; ~ScaledDotProductFlashAttentionKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA)); generator->set_current_seed( CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); return std::make_shared(generator); } private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const Tensor* query = ctx->Tensor4ArgNameAndIndex("query", 0); const Tensor* key = ctx->Tensor4ArgNameAndIndex("key", 0); const Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); const Tensor* alibi_slopes_ = nullptr; if (ctx->has_input("alibi_slopes_", 0)) { // default to null, it will never get input for current flash-attn version. alibi_slopes_ = ctx->Tensor4ArgNameAndIndex("alibi_slopes_", 0); CHECK(!alibi_slopes_) << "alibi_slopes should not have value"; } const float p_dropout = ctx->Attr("p_dropout"); const float softmax_scale = ctx->Attr("softmax_scale"); bool is_causal = ctx->Attr("is_causal"); int window_size_left = ctx->Attr("window_size_left"); int window_size_right = ctx->Attr("window_size_right"); uint64_t seed = ctx->Attr("seed"); Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); Tensor* softmax_lse = ctx->Tensor4ArgNameAndIndex("softmax_lse", 0); Tensor* rng_state = ctx->Tensor4ArgNameAndIndex("rng_state", 0); Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); void* tmp_ptr = tmp->mut_dptr(); auto* cuda_device = dynamic_cast(ctx->stream()->device()); auto dprops = cuda_device->properties(); auto* cuda_stream = ctx->stream()->As(); const int arch = cuda_stream->cuda_arch() / 10; const bool is_supported_arch = (arch == 80 || arch == 86 || arch == 89 || arch == 90); CHECK(is_supported_arch) << "only supports CUDA Arch 80, 86, 89 and 90."; const DataType data_type = query->data_type(); const bool is_supported_dtype = (data_type == DataType::kFloat16 || data_type == DataType::kBFloat16); CHECK(is_supported_dtype); CHECK_EQ(key->data_type(), data_type); CHECK_EQ(value->data_type(), data_type); CHECK_EQ(out->data_type(), data_type); CHECK_EQ(softmax_lse->data_type(), DataType::kFloat); // check contiguous last dimension. 
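    // A last-axis stride of 1 means every (batch, seq, head) row of query/key/value is a densely
    // packed head_size vector, which the flash-attention launch below relies on when it derives
    // the remaining row/head/batch strides from the tensors.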
CHECK_EQ(CHECK_JUST(VectorAt(query->stride(), 3)), 1);
    CHECK_EQ(CHECK_JUST(VectorAt(key->stride(), 3)), 1);
    CHECK_EQ(CHECK_JUST(VectorAt(value->stride(), 3)), 1);
    const int batch_size = query->shape_view().At(0);
    const int seqlen_q = query->shape_view().At(1);
    const int num_heads = query->shape_view().At(2);
    const int head_size_og = query->shape_view().At(3);
    const int seqlen_k = key->shape_view().At(1);
    const int num_heads_k = key->shape_view().At(2);
    // check tensor shape.
    CHECK_EQ(query->shape_view().At(0), batch_size);
    CHECK_EQ(query->shape_view().At(1), seqlen_q);
    CHECK_EQ(query->shape_view().At(2), num_heads);
    CHECK_EQ(query->shape_view().At(3), head_size_og);
    CHECK_EQ(key->shape_view().At(0), batch_size);
    CHECK_EQ(key->shape_view().At(1), seqlen_k);
    CHECK_EQ(key->shape_view().At(2), num_heads_k);
    CHECK_EQ(key->shape_view().At(3), head_size_og);
    CHECK_EQ(value->shape_view().At(0), batch_size);
    CHECK_EQ(value->shape_view().At(1), seqlen_k);
    CHECK_EQ(value->shape_view().At(2), num_heads_k);
    CHECK_EQ(value->shape_view().At(3), head_size_og);
    CHECK_EQ(out->shape_view().At(0), batch_size);
    CHECK_EQ(out->shape_view().At(1), seqlen_q);
    CHECK_EQ(out->shape_view().At(2), num_heads);
    CHECK_EQ(out->shape_view().At(3), head_size_og);
    CHECK_EQ(softmax_lse->shape_view().At(0), batch_size);
    CHECK_EQ(softmax_lse->shape_view().At(1), num_heads);
    CHECK_EQ(softmax_lse->shape_view().At(2), seqlen_q);
    CHECK_GT(batch_size, 0);      // batch size must be positive
    CHECK_LE(head_size_og, 256);  // only support head dimensions at most 256
    CHECK(num_heads % num_heads_k == 0);  // Number of heads in key/value must divide number of heads in query
    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }
    // causal=true is the same as causal=false in this case
    if (seqlen_q == 1 && !alibi_slopes_) { is_causal = false; }
    if (is_causal) { window_size_right = 0; }
    const int seqlenq_ngroups_swapped = 0;
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
    Flash_fwd_params params;
    set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k, head_size, head_size_rounded, query, key, value, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*seqused_k=*/nullptr,
                     /*return_softmax=*/nullptr, softmax_lse->mut_dptr(), p_dropout, softmax_scale,
                     window_size_left, window_size_right);
    int64_t counter_offset = params.b * params.h * 32;
    params.rng_state = rng_state->mut_dptr();
    set_params_splitkv(params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
                       head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, tmp_ptr);
    if (p_dropout > 0.0f) {  // todo generator.
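      // Dropout needs device-side random numbers: fetch the CUDA generator stored in the kernel
      // state (created in CreateOpKernelState) and advance its Philox offset by counter_offset
      // (= b * h * 32) so that successive launches consume disjoint random streams.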
auto* flash_attention_kernel_state = dynamic_cast(state); CHECK_NOTNULL(flash_attention_kernel_state); const auto& generator = flash_attention_kernel_state->generator(); CHECK_NOTNULL(generator); const auto device_index = cuda_device->device_index(); std::shared_ptr cuda_generator = CHECK_JUST(generator->Get(device_index)); params.philox_args = at::PhiloxCudaState(seed, cuda_generator->get_philox_offset(counter_offset)); } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); if (seqlen_k > 0) { run_mha_fwd(params, cuda_stream->cuda_stream()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(dtype) \ REGISTER_USER_KERNEL("scaled_dot_product_flash_attention") \ .SetCreateFn() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == dtype)) \ .SetInferTmpSizeFn(InferTmpBufferSizeForFlashAttentionKernel); REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kFloat16) REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kBFloat16) } // namespace } // namespace user_op } // namespace oneflow #endif // WITH_CUTLASS #endif // CUDA_VERSION >= 11070 ================================================ FILE: oneflow/user/kernels/scaled_dot_product_attention_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_ #define ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_ #include "oneflow/user/kernels/random_mask_generator.h" #include "oneflow/core/framework/framework.h" namespace oneflow { class ScaledDotProductFlashAttentionKernelState : public user_op::OpKernelState { public: explicit ScaledDotProductFlashAttentionKernelState( const std::shared_ptr& generator) : generator_(generator) {} const std::shared_ptr& generator() const { return generator_; } private: std::shared_ptr generator_; }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/scaled_dot_product_attention_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_FLASH_ATTENTION_UTIL_H_ #define ONEFLOW_USER_KERNELS_FLASH_ATTENTION_UTIL_H_ #include "oneflow/core/framework/user_op_tensor.h" #include "oneflow/core/common/util.h" #include "flash.h" #include "static_switch.h" namespace oneflow { namespace user_op { namespace { void set_params_fprop(Flash_fwd_params& params, // sizes const size_t b, const size_t seqlen_q, const size_t seqlen_k, const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, // device pointers const Tensor* q, const Tensor* k, const Tensor* v, Tensor* out, void* cu_seqlens_q_d, void* cu_seqlens_k_d, void* seqused_k, void* p_d, void* softmax_lse_d, float p_dropout, float softmax_scale, int window_size_left, int window_size_right, bool seqlenq_ngroups_swapped = false) { // Reset the parameters std::memset(¶ms, 0, sizeof(params)); params.is_bf16 = q->data_type() == DataType::kBFloat16; // Set the pointers and strides. params.q_ptr = const_cast(q->dptr()); params.k_ptr = const_cast(k->dptr()); params.v_ptr = const_cast(v->dptr()); // All stride are in elements, not bytes. params.q_row_stride = CHECK_JUST(VectorAt(q->stride(), 1)); params.k_row_stride = CHECK_JUST(VectorAt(k->stride(), 1)); params.v_row_stride = CHECK_JUST(VectorAt(v->stride(), 1)); params.q_head_stride = CHECK_JUST(VectorAt(q->stride(), 2)); params.k_head_stride = CHECK_JUST(VectorAt(k->stride(), 2)); params.v_head_stride = CHECK_JUST(VectorAt(v->stride(), 2)); params.o_ptr = out->mut_dptr(); params.o_row_stride = CHECK_JUST(VectorAt(out->stride(), 1)); params.o_head_stride = CHECK_JUST(VectorAt(out->stride(), 2)); if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = CHECK_JUST(VectorAt(q->stride(), 0)); params.k_batch_stride = CHECK_JUST(VectorAt(k->stride(), 0)); params.v_batch_stride = CHECK_JUST(VectorAt(v->stride(), 0)); params.o_batch_stride = CHECK_JUST(VectorAt(out->stride(), 0)); if (seqlenq_ngroups_swapped) { params.q_batch_stride *= seqlen_q; params.o_batch_stride *= seqlen_q; } } params.cu_seqlens_q = static_cast(cu_seqlens_q_d); params.cu_seqlens_k = static_cast(cu_seqlens_k_d); params.seqused_k = static_cast(seqused_k); // P = softmax(QK^T) params.p_ptr = p_d; // Softmax sum params.softmax_lse_ptr = softmax_lse_d; // Set the dimensions. params.b = b; params.h = h; params.h_k = h_k; params.h_h_k_ratio = h / h_k; params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.seqlen_q_rounded = seqlen_q_rounded; params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; // Set the different scale values. params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; // Set this to probability of keeping an element to simplify things. params.p_dropout = 1.f - p_dropout; // Convert p from float to int so we don't have to convert the random uint to float to compare. 
// [Minor] We want to round down since when we do the comparison we use <= instead of < // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; CHECK_LT(p_dropout, 1.f); #ifdef FLASHATTENTION_DISABLE_DROPOUT TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); #endif // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. params.is_causal = window_size_left < 0 && window_size_right == 0; if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } params.window_size_left = window_size_left; params.window_size_right = window_size_right; #ifdef FLASHATTENTION_DISABLE_LOCAL TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), "This flash attention build does not support local attention."); #endif params.is_seqlens_k_cumulative = true; #ifdef FLASHATTENTION_DISABLE_UNEVEN_K TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); #endif } void set_params_dgrad(Flash_bwd_params& params, // sizes const size_t b, const size_t seqlen_q, const size_t seqlen_k, const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, // device pointers const Tensor* q, const Tensor* k, const Tensor* v, const Tensor* out, const Tensor* dout, Tensor* dq, Tensor* dk, Tensor* dv, void* cu_seqlens_q_d, void* cu_seqlens_k_d, void* dq_accum_d, void* dk_accum_d, void* dv_accum_d, void* softmax_lse_d, void* dsoftmax_sum_d, float p_dropout, float softmax_scale, int window_size_left, int window_size_right, bool deterministic) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, q, k, v, const_cast(out), cu_seqlens_q_d, cu_seqlens_k_d, nullptr, nullptr, softmax_lse_d, p_dropout, softmax_scale, window_size_left, window_size_right); // Set the pointers and strides. 
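  // dout shares the (batch, seqlen, head, head_size) layout of out, and dq/dk/dv mirror q/k/v,
  // so the same batch/row/head stride slots are filled here; as in set_params_fprop, all strides
  // are element counts, not bytes.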
params.do_ptr = const_cast(dout->dptr()); params.do_row_stride = CHECK_JUST(VectorAt(dout->stride(), 1)); params.do_head_stride = CHECK_JUST(VectorAt(dout->stride(), 2)); params.dq_ptr = dq->mut_dptr(); params.dk_ptr = dk->mut_dptr(); params.dv_ptr = dv->mut_dptr(); params.dq_row_stride = CHECK_JUST(VectorAt(dq->stride(), 1)); params.dk_row_stride = CHECK_JUST(VectorAt(dk->stride(), 1)); params.dv_row_stride = CHECK_JUST(VectorAt(dv->stride(), 1)); params.dq_head_stride = CHECK_JUST(VectorAt(dq->stride(), 2)); params.dk_head_stride = CHECK_JUST(VectorAt(dk->stride(), 2)); params.dv_head_stride = CHECK_JUST(VectorAt(dv->stride(), 2)); if (cu_seqlens_q_d == nullptr) { params.do_batch_stride = CHECK_JUST(VectorAt(dout->stride(), 0)); params.dq_batch_stride = CHECK_JUST(VectorAt(dq->stride(), 0)); params.dk_batch_stride = CHECK_JUST(VectorAt(dk->stride(), 0)); params.dv_batch_stride = CHECK_JUST(VectorAt(dv->stride(), 0)); } params.dq_accum_ptr = dq_accum_d; params.dk_accum_ptr = dk_accum_d; params.dv_accum_ptr = dv_accum_d; // Softmax sum params.dsoftmax_sum = dsoftmax_sum_d; params.deterministic = deterministic; } void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 run_mha_fwd_(params, stream); } else { run_mha_fwd_splitkv_dispatch(params, stream); } }); }); } void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { run_mha_bwd_(params, stream); }); }); } // Find the number of splits that maximizes the occupancy. For example, if we have // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is // better than having 3 splits (efficiency = 0.67). However, we also don't want too many // splits as that would incur more HBM reads/writes. // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } max_splits = std::min({max_splits, num_SMs, num_n_blocks}); float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks // (i.e. it's 11 splits anyway). // So we check if the number of blocks per split is the same as the previous num_splits. 
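  // Putting both rules together for the example above (batch * n_heads = 48, 108 SMs, all three
  // split counts eligible): the efficiencies for 1, 2 and 3 splits are 0.44, 0.89 and 0.67; the
  // best is 0.89, and the smallest split count reaching 0.85 * 0.89 ~= 0.76 is 2, so 2 is chosen.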
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); }; for (int num_splits = 1; num_splits <= max_splits; num_splits++) { if (!is_split_eligible(num_splits)) { efficiency.push_back(0.f); } else { float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; float eff = n_waves / ceil(n_waves); // printf("num_splits = %d, eff = %f\n", num_splits, eff); if (eff > max_efficiency) { max_efficiency = eff; } efficiency.push_back(eff); } } for (int num_splits = 1; num_splits <= max_splits; num_splits++) { if (!is_split_eligible(num_splits)) { continue; } if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { // printf("num_splits chosen = %d\n", num_splits); return num_splits; } } return 1; } void set_params_splitkv(Flash_fwd_params& params, const int batch_size, const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, const int head_size_rounded, const float p_dropout, const int num_splits, cudaDeviceProp& dprops, void* tmp_ptr) { // This needs to match with run_mha_fwd_splitkv_dispatch const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; params.num_splits = num_splits; if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout if (num_splits < 1) { params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops.multiProcessorCount, num_n_blocks, 128); } if (params.num_splits > 1) { size_t softmax_lse_accum_size = params.num_splits * batch_size * num_heads * max_seqlen_q * sizeof(float); params.softmax_lseaccum_ptr = tmp_ptr; params.oaccum_ptr = reinterpret_cast(tmp_ptr) + GetCudaAlignedSize(softmax_lse_accum_size); } CHECK_LE(params.num_splits, 128); } } void set_params_alibi(Flash_fwd_params& params, const Tensor* alibi_slopes_, int batch_size, int num_heads) { // TODO(ChenDe): Need Support Alibi params. // default to null CHECK(!alibi_slopes_) << "alibi_slopes should be null."; params.alibi_slopes_ptr = nullptr; } } // namespace } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_FLASH_ATTENTION_UTIL_H_ ================================================ FILE: oneflow/user/kernels/search_sorted_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/search_sorted_kernel_util.h" namespace oneflow { template class CpuSearchSortedKernel final : public user_op::OpKernel { public: CpuSearchSortedKernel() = default; ~CpuSearchSortedKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* sorted_sequence = ctx->Tensor4ArgNameAndIndex("sorted_sequence", 0); const user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex("values", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const bool& right = ctx->Attr("right"); const T* values_ptr = values->dptr(); const T* sequence_ptr = sorted_sequence->dptr(); K* out_ptr = out->mut_dptr(); const int32_t instance_num = values->shape_view().elem_cnt(); bool is_values_scalar = values->shape_view().NumAxes() == 0; bool is_sequence_1d = (sorted_sequence->shape_view().NumAxes() == 1); K values_shape_last = is_values_scalar ? 1 : values->shape_view().At(values->shape_view().NumAxes() - 1); K sequence_shape_last = sorted_sequence->shape_view().At(sorted_sequence->shape_view().NumAxes() - 1); FOR_RANGE(int32_t, i, 0, instance_num) { K start_bd = is_sequence_1d ? 0 : i / values_shape_last * sequence_shape_last; K end_bd = start_bd + sequence_shape_last; K pos = !right ? cus_lower_bound(start_bd, end_bd, values_ptr[i], sequence_ptr) - start_bd : cus_upper_bound(start_bd, end_bd, values_ptr[i], sequence_ptr) - start_bd; out_ptr[i] = pos; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_SEARCH_SORTED_KERNEL(in_dtype, out_dtype) \ REGISTER_USER_KERNEL("searchsorted") \ .SetCreateFn< \ CpuSearchSortedKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("sorted_sequence", 0) == OF_PP_PAIR_SECOND(in_dtype)) \ && (user_op::HobDataType("values", 0) == OF_PP_PAIR_SECOND(in_dtype)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CPU_SEARCH_SORTED_KERNEL, ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class CpuSearchSortedScalarKernel final : public user_op::OpKernel { public: CpuSearchSortedScalarKernel() = default; ~CpuSearchSortedScalarKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* sorted_sequence = ctx->Tensor4ArgNameAndIndex("sorted_sequence", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const bool& right = ctx->Attr("right"); const T& values = static_cast(ctx->Attr("values")); const T* sequence_ptr = sorted_sequence->dptr(); K* out_ptr = out->mut_dptr(); K sequence_shape_last = sorted_sequence->shape_view().At(0); K pos = !right ? 
cus_lower_bound(0, sequence_shape_last, values, sequence_ptr) : cus_upper_bound(0, sequence_shape_last, values, sequence_ptr); out_ptr[0] = pos; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_SEARCH_SORTED_SCALAR_KERNEL(in_dtype, out_dtype) \ REGISTER_USER_KERNEL("searchsorted_scalar") \ .SetCreateFn< \ CpuSearchSortedScalarKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("sorted_sequence", 0) == OF_PP_PAIR_SECOND(in_dtype)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CPU_SEARCH_SORTED_SCALAR_KERNEL, ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/search_sorted_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/user/kernels/search_sorted_kernel_util.h" namespace oneflow { template __global__ void DoSearchSortedLogical(int32_t instance_num, bool is_sequence_1d, K values_shape_last, K sequence_shape_last, bool right, const T* values_ptr, const T* sequence_ptr, K* out_ptr) { CUDA_1D_KERNEL_LOOP(i, instance_num) { K start_bd = is_sequence_1d ? 0 : i / values_shape_last * sequence_shape_last; K end_bd = start_bd + sequence_shape_last; K pos = !right ? cus_lower_bound(start_bd, end_bd, values_ptr[i], sequence_ptr) - start_bd : cus_upper_bound(start_bd, end_bd, values_ptr[i], sequence_ptr) - start_bd; out_ptr[i] = pos; } } template __global__ void DoSearchSortedScalarLogical(K sequence_shape_last, bool right, const T values, const T* sequence_ptr, K* out_ptr) { CUDA_1D_KERNEL_LOOP(i, 1) { K pos = !right ? cus_lower_bound(0, sequence_shape_last, values, sequence_ptr) : cus_upper_bound(0, sequence_shape_last, values, sequence_ptr); out_ptr[0] = pos; } } template class GpuSearchSortedKernel final : public user_op::OpKernel { public: GpuSearchSortedKernel() = default; ~GpuSearchSortedKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* sorted_sequence = ctx->Tensor4ArgNameAndIndex("sorted_sequence", 0); const user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex("values", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const bool& right = ctx->Attr("right"); const T* values_ptr = values->dptr(); const T* sequence_ptr = sorted_sequence->dptr(); K* out_ptr = out->mut_dptr(); const int32_t instance_num = values->shape_view().elem_cnt(); bool is_values_scalar = values->shape_view().NumAxes() == 0; bool is_sequence_1d = (sorted_sequence->shape_view().NumAxes() == 1); K values_shape_last = is_values_scalar ? 
1 : values->shape_view().At(values->shape_view().NumAxes() - 1); K sequence_shape_last = sorted_sequence->shape_view().At(sorted_sequence->shape_view().NumAxes() - 1); RUN_CUDA_KERNEL((DoSearchSortedLogical), ctx->stream(), instance_num, instance_num, is_sequence_1d, values_shape_last, sequence_shape_last, right, values_ptr, sequence_ptr, out_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GPU_SEARCH_SORTED_KERNEL(in_dtype, out_dtype) \ REGISTER_USER_KERNEL("searchsorted") \ .SetCreateFn< \ GpuSearchSortedKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("sorted_sequence", 0) == OF_PP_PAIR_SECOND(in_dtype)) \ && (user_op::HobDataType("values", 0) == OF_PP_PAIR_SECOND(in_dtype)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GPU_SEARCH_SORTED_KERNEL, ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) template class GpuSearchSortedScalarKernel final : public user_op::OpKernel { public: GpuSearchSortedScalarKernel() = default; ~GpuSearchSortedScalarKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* sorted_sequence = ctx->Tensor4ArgNameAndIndex("sorted_sequence", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const bool& right = ctx->Attr("right"); const T& values = static_cast(ctx->Attr("values")); const T* sequence_ptr = sorted_sequence->dptr(); K* out_ptr = out->mut_dptr(); K sequence_shape_last = sorted_sequence->shape_view().At(0); RUN_CUDA_KERNEL((DoSearchSortedScalarLogical), ctx->stream(), 1, sequence_shape_last, right, values, sequence_ptr, out_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GPU_SEARCH_SORTED_SCALAR_KERNEL(in_dtype, out_dtype) \ REGISTER_USER_KERNEL("searchsorted_scalar") \ .SetCreateFn< \ GpuSearchSortedScalarKernel>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("sorted_sequence", 0) == OF_PP_PAIR_SECOND(in_dtype)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GPU_SEARCH_SORTED_SCALAR_KERNEL, ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/search_sorted_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" template OF_DEVICE_FUNC K cus_lower_bound(K start, K end, const T val, const T* bd) { while (start < end) { const K mid = start + ((end - start) >> 1); const T mid_val = bd[mid]; if (!(mid_val >= val)) { start = mid + 1; } else { end = mid; } } return start; } template OF_DEVICE_FUNC K cus_upper_bound(K start, K end, const T val, const T* bd) { while (start < end) { const K mid = start + ((end - start) >> 1); const T mid_val = bd[mid]; if (!(mid_val > val)) { start = mid + 1; } else { end = mid; } } return start; } ================================================ FILE: oneflow/user/kernels/sigmoid_cross_entropy_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/sigmoid_cross_entropy_kernel.h" namespace oneflow { namespace { template class Opt, typename PredT, typename LabelT> struct ElemwiseSigmoidCrossEntropyGradFunctor final { void operator()(ep::Stream* stream, int64_t n, PredT* prediction_diff, const PredT* prediction, const LabelT* label, const PredT* loss_diff) { FOR_RANGE(int64_t, i, 0, n) { prediction_diff[i] = Opt()(prediction[i], label[i], loss_diff[i]); } } }; template class Opt, typename PredT, typename LabelT> struct ElemwiseSigmoidCrossEntropyFunctor final { void operator()(ep::Stream* stream, int64_t n, PredT* loss, const PredT* prediction, const LabelT* label) { FOR_RANGE(int64_t, i, 0, n) { loss[i] = Opt()(prediction[i], label[i]); } } }; } // namespace REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, float, int32_t) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, double, int32_t) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, float, int8_t) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, double, int8_t) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, float, float) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, double, double) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, float, int32_t) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, double, int32_t) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, float, int8_t) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, double, int8_t) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, float, float) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, double, double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sigmoid_cross_entropy_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/sigmoid_cross_entropy_kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template class Opt, typename PredT, typename LabelT> struct ElemwiseSigmoidCrossEntropyGradFunctor final { void operator()(ep::Stream* stream, int64_t n, PredT* prediction_diff, const PredT* prediction, const LabelT* label, const PredT* loss_diff) { OF_CUDA_CHECK(cuda::elementwise::Ternary(Opt(), n, prediction_diff, prediction, label, loss_diff, stream->As()->cuda_stream())); } }; template class Opt, typename PredT, typename LabelT> struct ElemwiseSigmoidCrossEntropyFunctor final { void operator()(ep::Stream* stream, int64_t n, PredT* loss, const PredT* prediction, const LabelT* label) { OF_CUDA_CHECK(cuda::elementwise::Binary(Opt(), n, loss, prediction, label, stream->As()->cuda_stream())); } }; } // namespace REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, float, int32_t) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, double, int32_t) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, float, int8_t) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, double, int8_t) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, float, float) REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, double, double) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, float, int32_t) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, double, int32_t) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, float, int8_t) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, double, int8_t) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, float, float) REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, double, double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sigmoid_cross_entropy_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_SIGMOID_CROSS_ENTROPY_KERNEL_H_ #define ONEFLOW_USER_KERNELS_SIGMOID_CROSS_ENTROPY_KERNEL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/math_unary_elementwise_func.h" namespace oneflow { template struct SigmoidCrossEntropyFunctor { OF_DEVICE_FUNC PredT operator()(const PredT prediction, const LabelT label) const { return -1.f * prediction * (label - (prediction >= 0)) + LogFunctor::Forward( 1 + ExpFunctor::Forward(prediction - 2 * prediction * (prediction >= 0))); } }; template struct SigmoidCrossEntropyGradFunctor { OF_DEVICE_FUNC PredT operator()(const PredT prediction, const LabelT label, const PredT loss_diff) const { return loss_diff * (1.f / (1.f + ExpFunctor::Forward(-prediction)) - label); } }; namespace { template class Opt, typename PredT, typename LabelT> struct ElemwiseSigmoidCrossEntropyGradFunctor final { void operator()(ep::Stream* stream, int64_t n, PredT* prediction_diff, const PredT* prediction, const LabelT* label, const PredT* loss_diff); }; template class Opt, typename PredT, typename LabelT> struct ElemwiseSigmoidCrossEntropyFunctor final { void operator()(ep::Stream* stream, int64_t n, PredT* loss, const PredT* prediction, const LabelT* label); }; } // namespace template class Opt, typename PredT, typename LabelT> class SigmoidCrossEntropyKernel final : public user_op::OpKernel { public: SigmoidCrossEntropyKernel() = default; ~SigmoidCrossEntropyKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* loss = ctx->Tensor4ArgNameAndIndex("loss", 0); const auto n = prediction->shape_view().elem_cnt(); ElemwiseSigmoidCrossEntropyFunctor()( ctx->stream(), n, loss->mut_dptr(), prediction->dptr(), label->dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(device_type, dtype, ltype) \ REGISTER_USER_KERNEL("sigmoid_cross_entropy") \ .SetCreateFn< \ SigmoidCrossEntropyKernel>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("label", 0) == GetDataType::value) \ && (user_op::HobDataType("loss", 0) == GetDataType::value)); template class Opt, typename PredT, typename LabelT> class SigmoidCrossEntropyGradKernel final : public user_op::OpKernel { public: SigmoidCrossEntropyGradKernel() = default; ~SigmoidCrossEntropyGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* loss_diff = ctx->Tensor4ArgNameAndIndex("loss_diff", 0); const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex("prediction_diff", 0); const int64_t n = prediction->shape_view().elem_cnt(); ElemwiseSigmoidCrossEntropyGradFunctor()( ctx->stream(), n, prediction_diff->mut_dptr(), prediction->dptr(), label->dptr(), loss_diff->dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(device_type, dtype, ltype) \ REGISTER_USER_KERNEL("sigmoid_cross_entropy_grad") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type) \ && 
(user_op::HobDataType("label", 0) == GetDataType::value) \ && (user_op::HobDataType("prediction_diff", 0) == GetDataType::value)); } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_SIGMOID_CROSS_ENTROPY_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/skip_layer_norm_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/cuda/atomic.cuh" #include #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/layer_norm.cuh" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #include "oneflow/core/device/cuda_pseudo_bfloat16.h" namespace oneflow { namespace { template struct SkipLoad { using LoadType = DST; SkipLoad(const SRC* src, const SRC* bias, const SRC* skip, float alpha, int64_t row_size) : src(src), bias(bias), skip(skip), alpha(alpha), row_size(row_size) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { cuda::layer_norm::Pack src_pack; cuda::layer_norm::Pack bias_pack; cuda::layer_norm::Pack skip_pack; const int64_t offset = (row * row_size + col) / N; const int64_t bias_offset = col / N; src_pack.storage = *(reinterpret_cast*>(src) + offset); if (bias) { bias_pack.storage = *(reinterpret_cast*>(bias) + bias_offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { bias_pack.elem[i] = static_cast(0.f); } } if (skip) { skip_pack.storage = *(reinterpret_cast*>(skip) + offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { skip_pack.elem[i] = static_cast(0.f); } } #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(src_pack.elem[i] + bias_pack.elem[i] + skip_pack.elem[i] * static_cast(alpha)); } } const SRC* src; const SRC* bias; const SRC* skip; double alpha; int64_t row_size; }; template struct AffineStore { AffineStore(DST* y, int64_t row_size, const DST* gamma, const DST* beta) : y(y), row_size(row_size), gamma(gamma), beta(beta) {} template __device__ void store(const SRC* src, int64_t row, int64_t col) { cuda::layer_norm::Pack y_pack; cuda::layer_norm::Pack gamma_pack; cuda::layer_norm::Pack beta_pack; const int64_t offset = (row * row_size + col) / N; const int64_t gamma_offset = col / N; if (do_scale) { gamma_pack.storage = *(reinterpret_cast*>(gamma) + gamma_offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { gamma_pack.elem[i] = static_cast(1.f); } } if (do_center) { beta_pack.storage = *(reinterpret_cast*>(beta) + gamma_offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { beta_pack.elem[i] = static_cast(0.f); } } #pragma unroll for (int i = 0; i < N; ++i) { DST normalized_i = static_cast(src[i]); if (do_scale || do_center) { y_pack.elem[i] = 
normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]; } else { y_pack.elem[i] = normalized_i; } } *(reinterpret_cast*>(y) + offset) = y_pack.storage; } DST* y; int64_t row_size; const DST* gamma; const DST* beta; }; template void LaunchSkipLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const double epsilon, const T* x_ptr, const T* gamma_ptr, const T* beta_ptr, const T* bias_ptr, const T* skip_ptr, const double alpha, T* y_ptr, user_op::Tensor* mean, user_op::Tensor* inv_variance) { constexpr int32_t block_size = 128; unsigned int nb_element = norm_size * num_instances; unsigned int grid_size = (nb_element + block_size - 1) / block_size; using ComputeType = typename cuda::layer_norm::DefaultComputeType::type; SkipLoad load(x_ptr, bias_ptr, skip_ptr, alpha, norm_size); AffineStore store(y_ptr, norm_size, gamma_ptr, beta_ptr); cuda::layer_norm::DispatchLayerNorm( stream->As()->cuda_stream(), load, store, num_instances, norm_size, epsilon, mean->mut_dptr(), inv_variance->mut_dptr()); } template void DispatchSkipLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size, const double epsilon, const T* x_ptr, const T* gamma_ptr, const T* beta_ptr, const T* bias_ptr, const T* skip_ptr, const double alpha, T* y_ptr, user_op::Tensor* mean, user_op::Tensor* inv_variance) { #define LAUNCH_GPU_KERNEL(has_gamma, has_beta) \ LaunchSkipLayerNormForwardGpu( \ stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr, beta_ptr, bias_ptr, skip_ptr, \ alpha, y_ptr, mean, inv_variance); if (gamma_ptr != nullptr && beta_ptr != nullptr) { LAUNCH_GPU_KERNEL(true, true); } else if (gamma_ptr != nullptr && beta_ptr == nullptr) { LAUNCH_GPU_KERNEL(true, false); } else if (gamma_ptr == nullptr && beta_ptr != nullptr) { LAUNCH_GPU_KERNEL(false, true); } else { LAUNCH_GPU_KERNEL(false, false); } #undef LAUNCH_GPU_KERNEL } template class SkipLayerNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SkipLayerNormGpuKernel() = default; ~SkipLayerNormGpuKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { // obtain x and check its shape const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const ShapeView& x_shape = x->shape_view(); CHECK_GE(x_shape.NumAxes(), 2) << "number of axes of \'x\' should be greater than or equal to 2, yet get " << x_shape.NumAxes(); // obtain gamma and check its shape const T* gamma_ptr = nullptr; ShapeView gamma_shape; if (ctx->has_input("gamma", 0)) { const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); gamma_shape = gamma->shape_view(); gamma_ptr = gamma->dptr(); CHECK_EQ(gamma_shape.NumAxes(), 1) << "number of axes of \'gamma\' should be equal to 1, yet get " << gamma_shape.NumAxes(); CHECK_EQ(gamma_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'gamma\'(" << gamma_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } // obtain beta and check its shape const T* beta_ptr = nullptr; ShapeView beta_shape; if (ctx->has_input("beta", 0)) { const user_op::Tensor* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); beta_shape = beta->shape_view(); beta_ptr = beta->dptr(); CHECK_EQ(beta_shape.NumAxes(), 1) << "number of axes of \'beta\' should be equal to 1, yet get " << beta_shape.NumAxes(); CHECK_EQ(beta_shape.At(0), 
x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'beta\'(" << beta_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } // obtain bias and check its shape const T* bias_ptr = nullptr; ShapeView bias_shape; if (ctx->has_input("bias", 0)) { const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); bias_shape = bias->shape_view(); bias_ptr = bias->dptr(); CHECK_EQ(bias_shape.NumAxes(), 1) << "number of axes of \'bias\' should be equal to 1, yet get " << bias_shape.NumAxes(); CHECK_EQ(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'bias\'(" << bias_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } // obtain residual and check its shape const T* skip_ptr = nullptr; ShapeView skip_shape; if (ctx->has_input("skip", 0)) { const user_op::Tensor* skip = ctx->Tensor4ArgNameAndIndex("skip", 0); skip_shape = skip->shape_view(); skip_ptr = skip->dptr(); CHECK_EQ(skip_shape, x_shape); } // obtain epsilon and check its value const double epsilon = ctx->Attr("epsilon"); const double alpha = ctx->Attr("alpha"); // obtain output tensors user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0); const ShapeView& y_shape = y->shape_view(); const ShapeView& mean_shape = mean->shape_view(); const ShapeView& inv_variance_shape = inv_variance->shape_view(); // calculate number of instances and norm size const int64_t num_instances = mean->shape_view().elem_cnt(); const int64_t norm_size = x->shape_view().elem_cnt() / num_instances; // dispatch kernel DispatchSkipLayerNormForwardGpu(ctx->stream(), num_instances, norm_size, epsilon, x->dptr(), gamma_ptr, beta_ptr, bias_ptr, skip_ptr, alpha, y->mut_dptr(), mean, inv_variance); } }; } // namespace #define REGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("skip_layer_norm") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(float) REGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(double) REGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(nv_bfloat16) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/skip_rms_norm_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/cuda/atomic.cuh" #include #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/include/primitive/matmul.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/rms_norm.cuh" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 #include "oneflow/core/device/cuda_pseudo_bfloat16.h" namespace oneflow { namespace cuda { namespace rms_norm { template struct SkipLoad { using LoadType = DST; SkipLoad(const SRC* src, const SRC* bias, const SRC* skip, const float alpha, int64_t row_size) : src(src), bias(bias), skip(skip), alpha(alpha), row_size(row_size) {} template __device__ void load(DST* dst, int64_t row, int64_t col) const { layer_norm::Pack src_pack; layer_norm::Pack bias_pack; layer_norm::Pack skip_pack; const int64_t offset = (row * row_size + col) / N; const int64_t bias_offset = col / N; src_pack.storage = *(reinterpret_cast*>(src) + offset); if (bias) { bias_pack.storage = *(reinterpret_cast*>(bias) + bias_offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { bias_pack.elem[i] = static_cast(0.f); } } if (skip) { skip_pack.storage = *(reinterpret_cast*>(skip) + offset); } else { #pragma unroll for (int i = 0; i < N; ++i) { skip_pack.elem[i] = static_cast(0.f); } } #pragma unroll for (int i = 0; i < N; ++i) { dst[i] = static_cast(src_pack.elem[i] + bias_pack.elem[i] + skip_pack.elem[i] * static_cast(alpha)); } } const SRC* src; const SRC* bias; const SRC* skip; float alpha; int64_t row_size; }; template struct AffineStore { AffineStore(DST* dst, const DST* weight, int32_t row_size) : dst(dst), weight(weight), row_size(row_size) {} template __device__ void store(const SRC* src, int32_t row, int32_t col) { layer_norm::Pack dst_pack; layer_norm::Pack weight_pack; const int32_t offset = (row * row_size + col) / N; const int32_t weight_offset = col / N; if (affine) { weight_pack.storage = *(reinterpret_cast*>(weight) + weight_offset); } #pragma unroll for (int i = 0; i < N; ++i) { if (affine) { dst_pack.elem[i] = static_cast(src[i]) * weight_pack.elem[i]; } else { dst_pack.elem[i] = static_cast(src[i]); } } *(reinterpret_cast*>(dst) + offset) = dst_pack.storage; } DST* dst; const DST* weight; int32_t row_size; }; template void DispatchSkipRmsNormForwardAffine(ep::Stream* stream, const int64_t nrow, const int64_t ncol, const double eps, const double alpha, const T* x_dptr, const T* w_dptr, const T* skip_dptr, const T* bias_dptr, T* y_dptr, ComputeType* inv_rms) { constexpr int32_t block_size = 128; unsigned int nb_element = nrow * ncol; unsigned int grid_size = (nb_element + block_size - 1) / block_size; SkipLoad load(x_dptr, bias_dptr, skip_dptr, alpha, ncol); AffineStore store(y_dptr, w_dptr, ncol); OF_CUDA_CHECK((LaunchRmsNorm( stream->As()->cuda_stream(), load, store, nrow, ncol, eps, inv_rms))); } template void SkipRmsNormForward(ep::Stream* stream, const int64_t nrow, const int64_t ncol, const double eps, const double alpha, const T* x_dptr, const T* w_dptr, const T* skip_dptr, const T* bias_dptr, T* y_dptr, ComputeType* inv_rms) { if (w_dptr) { DispatchSkipRmsNormForwardAffine( stream, nrow, ncol, eps, alpha, x_dptr, w_dptr, skip_dptr, bias_dptr, y_dptr, inv_rms); } else { DispatchSkipRmsNormForwardAffine( stream, nrow, ncol, eps, alpha, x_dptr, w_dptr, skip_dptr, bias_dptr, y_dptr, inv_rms); } } } // 
namespace rms_norm template class SkipRmsNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SkipRmsNormGpuKernel() = default; ~SkipRmsNormGpuKernel() = default; private: using user_op::OpKernel::Compute; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { // obtain x and check its shape const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const ShapeView& x_shape = x->shape_view(); CHECK_GE(x_shape.NumAxes(), 2) << "number of axes of \'x\' should be greater than or equal to 2, yet get " << x_shape.NumAxes(); // obtain weight and check its shape const T* weight_ptr = nullptr; ShapeView weight_shape; if (ctx->has_input("weight", 0)) { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); weight_shape = weight->shape_view(); weight_ptr = weight->dptr(); CHECK_EQ(weight_shape.NumAxes(), 1) << "number of axes of \'weight\' should be equal to 1, yet get " << weight_shape.NumAxes(); CHECK_EQ(weight_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'weight\'(" << weight_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } // obtain bias and check its shape const T* bias_ptr = nullptr; ShapeView bias_shape; if (ctx->has_input("bias", 0)) { const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); bias_shape = bias->shape_view(); bias_ptr = bias->dptr(); CHECK_EQ(bias_shape.NumAxes(), 1) << "number of axes of \'bias\' should be equal to 1, yet get " << bias_shape.NumAxes(); CHECK_EQ(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'bias\'(" << bias_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } // obtain skip and check its shape const T* skip_ptr = nullptr; ShapeView skip_shape; if (ctx->has_input("skip", 0)) { const user_op::Tensor* skip = ctx->Tensor4ArgNameAndIndex("skip", 0); skip_shape = skip->shape_view(); skip_ptr = skip->dptr(); CHECK_EQ(skip_shape, x_shape); } // obtain epsilon and check its value const double epsilon = ctx->Attr("epsilon"); const double alpha = ctx->Attr("alpha"); // obtain output tensors user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* inv_rms = ctx->Tensor4ArgNameAndIndex("inv_rms", 0); const ShapeView& y_shape = y->shape_view(); const ShapeView& inv_rms_shape = inv_rms->shape_view(); // calculate number of instances and norm size const int64_t nrow = inv_rms->shape_view().elem_cnt(); const int64_t ncol = x->shape_view().elem_cnt() / nrow; // dispatch kernel using ComputeType = typename layer_norm::DefaultComputeType::type; rms_norm::SkipRmsNormForward(ctx->stream(), nrow, ncol, epsilon, alpha, x->dptr(), weight_ptr, skip_ptr, bias_ptr, y->mut_dptr(), inv_rms->mut_dptr()); } }; #define REGISTER_SKIP_RMS_NORM_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("skip_rms_norm") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_SKIP_RMS_NORM_CUDA_KERNEL(float) REGISTER_SKIP_RMS_NORM_CUDA_KERNEL(double) REGISTER_SKIP_RMS_NORM_CUDA_KERNEL(half) #if CUDA_VERSION >= 11000 REGISTER_SKIP_RMS_NORM_CUDA_KERNEL(nv_bfloat16) #endif } // namespace cuda } // namespace oneflow ================================================ FILE: oneflow/user/kernels/slice_kernel.cpp ================================================ /* Copyright 2020 
The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/common/switch_func.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/slice_util.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/user/kernels/op_kernel_wrapper.h"
#include "oneflow/core/kernel/cuda_graph_support.h"

namespace oneflow {

namespace {

const int SPLIT_AXIS_FOR_NON_SPLIT = -1;

// [start, end)
int64_t GetSizeInSlice(const int64_t start, const int64_t end, const int64_t step) {
  if (end <= start) { return 0; }
  return (end - start - 1) / step + 1;
}

class SliceContext final {
 public:
  struct SplitInfo {
    // These fields show how the logical tensor is split.
    // The logical tensor is split on the axis `split_axis`.
    // The physical tensor on the current device covers the range [lower, upper).
    // The length of the logical tensor on `split_axis` is `logical_length`.
    // Example:
    //   Variable shape = (8, 7, 6, 5), sbp = S(0), on 4 devices, then on the first card:
    //   split_axis = 0
    //   lower = 0
    //   upper = 2
    //   logical_length = 8
    const int64_t split_axis;
    const int64_t lower;
    const int64_t upper;
    const int64_t logical_length;
  };

  SliceContext() : axis_bitset_(0) {}

  Maybe<void> PushSplitInfo(int64_t split_axis, int64_t lower, int64_t upper,
                            int64_t logical_length) {
    if (split_axis != SPLIT_AXIS_FOR_NON_SPLIT) {
      // a split_axis can only be pushed once
      CHECK_OR_RETURN(!IsAxisPushed(split_axis))
          << "split_axis " << split_axis << " has been pushed to SliceContext";
      CHECK_GE_OR_RETURN(split_axis, 0)
          << "split_axis must be >= 0 or equal to SPLIT_AXIS_FOR_NON_SPLIT";
      axis_bitset_ |= ((uint32_t)1 << split_axis);  // NOLINT
    }
    split_info_vec_.emplace_back(SplitInfo{split_axis, lower, upper, logical_length});
    return Maybe<void>::Ok();
  }
  const std::vector<SplitInfo>& GetSplitInfo() const { return split_info_vec_; }
  bool IsAxisPushed(int64_t split_axis) const {
    if (split_axis == SPLIT_AXIS_FOR_NON_SPLIT) { return false; }
    CHECK_GE(split_axis, 0) << "split_axis must be >= 0 or equal to SPLIT_AXIS_FOR_NON_SPLIT";
    return (axis_bitset_ & ((uint32_t)1 << split_axis)) != 0;  // NOLINT
  }

 private:
  std::vector<SplitInfo> split_info_vec_;
  uint32_t axis_bitset_;
};

void ConstructSliceParamsLarge(const SliceContext& ctx, const std::vector<int64_t>& start_vec,
                               const std::vector<int64_t>& stop_vec,
                               const std::vector<int64_t>& step_vec, const ShapeView& shape,
                               SliceParams* slice_param) {
  const int64_t ndim = shape.NumAxes();
  CHECK_LE(ndim, kSliceMaxDims);
  CHECK_EQ(start_vec.size(), ndim);
  CHECK_EQ(stop_vec.size(), ndim);
  CHECK_EQ(step_vec.size(), ndim);
  slice_param->ndim = ndim;
  FOR_RANGE(int, i, 0, slice_param->ndim) {
    const int64_t dim_size = shape.At(i);
    const int64_t start_in_full_large = start_vec.at(i);
    const int64_t stop_in_full_large = stop_vec.at(i);
    const int64_t step = step_vec.at(i);
    CHECK_GT(step, 0);
    int64_t start_in_splitted_large = start_in_full_large;
    int64_t stop_in_splitted_large = stop_in_full_large;
    // the large tensor has a split sbp attribute
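// Illustrative walk-through of the remapping below (added note, not part of the
// original source): suppose the split axis has logical length 8, the requested
// slice is start=1, stop=8, step=3 (logical indices 1, 4, 7), and this rank owns
// [lower=4, upper=8). The first adjustment moves the start forward to the next
// logical index hit by the slice that is >= lower:
//   4 + (3 - (4 - 1) % 3) % 3 == 4.
// After clamping both ends to [4, 8) and subtracting lower, the local slice is
// start=0, stop=4, step=3, so GetSizeInSlice(0, 4, 3) == 2 local elements,
// which correspond to logical indices 4 and 7.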
for (const auto& split_info : ctx.GetSplitInfo()) { if (split_info.split_axis == i) { if (start_in_splitted_large < split_info.lower) { start_in_splitted_large = split_info.lower + (step - (split_info.lower - start_in_splitted_large) % step) % step; } start_in_splitted_large = std::min(std::max(start_in_splitted_large, split_info.lower), split_info.upper); stop_in_splitted_large = std::min(std::max(stop_in_splitted_large, split_info.lower), split_info.upper); start_in_splitted_large -= split_info.lower; stop_in_splitted_large -= split_info.lower; } } const int64_t slice_size = GetSizeInSlice(start_in_splitted_large, stop_in_splitted_large, step); slice_param->dims[i] = dim_size; slice_param->start[i] = start_in_splitted_large; slice_param->step[i] = step; slice_param->size[i] = slice_size; } } void ConstructSliceParamsSmall(const SliceContext& ctx, const std::vector& start_vec, const std::vector& stop_vec, const std::vector& step_vec, const ShapeView& shape, SliceParams* slice_param) { const int64_t ndim = shape.NumAxes(); CHECK_LE(ndim, kSliceMaxDims); CHECK_EQ(start_vec.size(), ndim); CHECK_EQ(stop_vec.size(), ndim); CHECK_EQ(step_vec.size(), ndim); slice_param->ndim = ndim; FOR_RANGE(int, i, 0, slice_param->ndim) { const int64_t start_in_full_large = start_vec.at(i); const int64_t step = step_vec.at(i); CHECK_GT(step, 0); // small tensor has broadcast/partialsum sbp attribute const int64_t dim_size = shape.At(i); int64_t start_in_full_small = 0; int64_t stop_in_full_small = dim_size; for (const auto& split_info : ctx.GetSplitInfo()) { if (split_info.split_axis == i) { start_in_full_small = GetSizeInSlice(start_in_full_large, split_info.lower, step); stop_in_full_small = GetSizeInSlice(start_in_full_large, split_info.upper, step); start_in_full_small = std::min(std::max(start_in_full_small, 0), dim_size); stop_in_full_small = std::min(std::max(stop_in_full_small, 0), dim_size); } } const int64_t slice_size = stop_in_full_small - start_in_full_small; slice_param->dims[i] = dim_size; slice_param->start[i] = start_in_full_small; slice_param->step[i] = 1; slice_param->size[i] = slice_size; } } SliceParams ConstructSliceParams(user_op::KernelComputeContext* ctx, const user_op::Tensor* entire, const user_op::Tensor* sliced) { const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); const auto& step_vec = ctx->Attr>("step"); const int64_t ndim = entire->shape_view().NumAxes(); CHECK_LE(ndim, kSliceMaxDims); if (entire->shape_view().NumAxes() == 1) { CHECK_LE(sliced->shape_view().NumAxes(), 1); } else { CHECK_EQ(sliced->shape_view().NumAxes(), ndim); } CHECK_EQ(start_vec.size(), ndim); CHECK_EQ(stop_vec.size(), ndim); CHECK_EQ(step_vec.size(), ndim); SliceParams params; if (entire->shape_view().NumAxes() == 1 && sliced->shape_view().NumAxes() == 0) { params.ndim = ndim; params.dims[0] = entire->shape_view().At(0); params.start[0] = RegulateSliceStart(start_vec.at(0), entire->shape_view().At(0)); params.step[0] = step_vec.at(0); params.size[0] = 1; return params; } params.ndim = ndim; FOR_RANGE(int, i, 0, params.ndim) { const int64_t dim_size = entire->shape_view().At(i); const int64_t slice_size = sliced->shape_view().At(i); const int64_t step = step_vec.at(i); CHECK_NE(step, 0); const int64_t start = RegulateSliceStart(start_vec.at(i), dim_size); const int64_t stop = RegulateSliceStop(stop_vec.at(i), dim_size); if (step > 0) { CHECK_LT(start + step * (slice_size - 1), stop); } else { CHECK_GT(start + step * (slice_size - 1), stop); } params.dims[i] = dim_size; 
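// Worked example of the regulated indices above (added note, not part of the
// original source): for a dimension of size 5, a negative start_vec[i] = -2 is
// regulated to start = 3 and stop_vec[i] = -1 to stop = 4; with step = 1 the
// sliced extent is a single element and the CHECK above holds:
// start + step * (slice_size - 1) = 3 < 4 = stop.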
params.start[i] = start; params.step[i] = step; params.size[i] = slice_size; } return params; } } // namespace template void WriteSlice(user_op::KernelComputeContext* ctx, const user_op::Tensor* src, user_op::Tensor* dst, const SliceContext& slice_ctx, const bool from_large_to_small) { const user_op::Tensor* large = from_large_to_small ? src : dst; const user_op::Tensor* small = from_large_to_small ? dst : src; // Check physical tensor's shape for (const auto& split_info : slice_ctx.GetSplitInfo()) { if (split_info.split_axis != SPLIT_AXIS_FOR_NON_SPLIT) { CHECK_EQ(large->shape_view().At(split_info.split_axis), split_info.upper - split_info.lower) << "split_info shape mismatch physical tensor shape"; } } const std::vector start_attr = ctx->Attr>("start"); const std::vector stop_attr = ctx->Attr>("stop"); const std::vector step_attr = ctx->Attr>("step"); const int64_t ndim = start_attr.size(); std::vector positive_start_vec(ndim); std::vector positive_stop_vec(ndim); // regulate axis number std::vector logical_dims(ndim); { for (int i = 0; i < ndim; i++) { if (!slice_ctx.IsAxisPushed(i)) { // axis is not split, logical shape is same as physical shape logical_dims[i] = large->shape_view().At(i); } } for (const auto& split_info : slice_ctx.GetSplitInfo()) { if (split_info.split_axis != SPLIT_AXIS_FOR_NON_SPLIT) { logical_dims[split_info.split_axis] = split_info.logical_length; } } } for (int i = 0; i < ndim; i++) { positive_start_vec[i] = RegulateSliceStart(start_attr[i], logical_dims[i]); positive_stop_vec[i] = RegulateSliceStop(stop_attr[i], logical_dims[i]); } SliceParams large_slice_param; std::copy(large->stride().begin(), large->stride().end(), large_slice_param.stride); SliceParams small_slice_param; std::copy(small->stride().begin(), small->stride().end(), small_slice_param.stride); ConstructSliceParamsLarge(slice_ctx, positive_start_vec, positive_stop_vec, step_attr, large->shape_view(), &large_slice_param); ConstructSliceParamsSmall(slice_ctx, positive_start_vec, positive_stop_vec, step_attr, small->shape_view(), &small_slice_param); CHECK_EQ(large_slice_param.elem_cnt(), small_slice_param.elem_cnt()); if (large_slice_param.ndim == 0 && small_slice_param.ndim == 0) { // Copy data directly for scalar tensor AutoMemcpy(ctx->stream(), dst->mut_dptr(), src->dptr(), sizeof(T), src->mem_case(), dst->mem_case()); return; } if (from_large_to_small) { if (small_slice_param.elem_cnt() == small->shape_view().elem_cnt()) { SliceKernelUtil::Forward(ctx->stream(), large_slice_param, src->dptr(), dst->mut_dptr()); } else { AutoMemset(ctx->stream(), dst->mut_dptr(), 0, dst->shape_view().elem_cnt() * GetSizeOfDataType(dst->data_type()), dst->mem_case()); SliceKernelUtil::Forward(ctx->stream(), large_slice_param, small_slice_param, src->dptr(), dst->mut_dptr()); } } else { SliceKernelUtil::Forward(ctx->stream(), small_slice_param, large_slice_param, src->dptr(), dst->mut_dptr()); } } template class SliceKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SliceKernel() = default; ~SliceKernel() = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { SliceContext slice_ctx; if (ctx->parallel_ctx().parallel_num() == 1) { // split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split' CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0)); } else { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); NdSbp in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("x", 0); { const NdSbp& y_nd_sbp = 
ctx->NdSbp4ArgNameAndIndex("y", 0); // If x and y both split in the same axis(must be full slice), // we can consider the physical tensor is broadcast in this axis. FOR_RANGE(int32_t, i, 0, parallel_hierarchy.NumAxes()) { const SbpParallel& x_sbp = in_nd_sbp.sbp_parallel(i); const SbpParallel& y_sbp = y_nd_sbp.sbp_parallel(i); if (x_sbp.has_split_parallel() && y_sbp.has_split_parallel()) { CHECK_EQ(x_sbp.split_parallel().axis(), y_sbp.split_parallel().axis()); in_nd_sbp.mutable_sbp_parallel(i)->clear_split_parallel(); in_nd_sbp.mutable_sbp_parallel(i)->mutable_broadcast_parallel(); } } } const Shape& logical_shape = ctx->LogicalTensorDesc4ArgNameAndIndex("x", 0)->shape(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const TensorSliceView& slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, in_nd_sbp, logical_shape, parallel_id); for (int i = 0; i < logical_shape.NumAxes(); ++i) { const Range& range = slice_view.At(i); if (range.begin() != 0 || range.end() != logical_shape.At(i)) { CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i))); } } } return std::make_shared>(slice_ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); if (y_tensor->shape_view().elem_cnt() == 0) { return; } const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); const SliceContext& slice_ctx = dynamic_cast*>(cache)->Get(); WriteSlice(ctx, x_tensor, y_tensor, slice_ctx, /*from_large_to_small=*/true); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class SliceUpdateKernel final : public user_op::OpKernel { public: SliceUpdateKernel() = default; ~SliceUpdateKernel() = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { SliceContext slice_ctx; if (ctx->parallel_ctx().parallel_num() == 1) { // split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split' CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0)); } else { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); NdSbp ref_nd_sbp = ctx->NdSbp4ArgNameAndIndex("ref", 0); { const NdSbp& value_nd_sbp = ctx->NdSbp4ArgNameAndIndex("value", 0); // If ref and value both split in the same axis(full slice), // we can consider the physical tensor is broadcast in this axis. 
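// Added note (not part of the original source): for example, with a 1-D
// hierarchy of 2 ranks and both "ref" and "value" placed as S(0), each rank
// already holds the matching halves of ref and value, so the loop below
// rewrites that sbp dimension to broadcast; GetTensorSliceView4ParallelId then
// reports the full range on that axis and no SplitInfo is pushed for it.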
for (int i = 0; i < parallel_hierarchy.NumAxes(); ++i) { const SbpParallel& ref_sbp = ref_nd_sbp.sbp_parallel(i); const SbpParallel& value_sbp = value_nd_sbp.sbp_parallel(i); if (ref_sbp.has_split_parallel() && value_sbp.has_split_parallel()) { CHECK_EQ(ref_sbp.split_parallel().axis(), value_sbp.split_parallel().axis()); ref_nd_sbp.mutable_sbp_parallel(i)->clear_split_parallel(); ref_nd_sbp.mutable_sbp_parallel(i)->mutable_broadcast_parallel(); } } } const Shape& logical_shape = ctx->LogicalTensorDesc4ArgNameAndIndex("ref", 0)->shape(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const TensorSliceView& slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, ref_nd_sbp, logical_shape, parallel_id); for (int i = 0; i < logical_shape.NumAxes(); ++i) { const Range& range = slice_view.At(i); if (range.begin() != 0 || range.end() != logical_shape.At(i)) { CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i))); } } } return std::make_shared>(slice_ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* value_tensor = ctx->Tensor4ArgNameAndIndex("value", 0); user_op::Tensor* ref_tensor = ctx->Tensor4ArgNameAndIndex("ref", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); if (y_tensor->shape_view().elem_cnt() == 0) { return; } // When eager executing, y_tensor shared the same memory with ref_tensor if (ref_tensor->dptr() != y_tensor->dptr()) { // lazy run AutoMemcpy(ctx->stream(), y_tensor->mut_dptr(), ref_tensor->dptr(), y_tensor->shape_view().elem_cnt() * sizeof(T), ref_tensor->mem_case(), y_tensor->mem_case()); } const SliceContext& slice_ctx = dynamic_cast*>(cache)->Get(); WriteSlice(ctx, value_tensor, y_tensor, slice_ctx, /*from_large_to_small=*/false); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; template class SliceGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SliceGradKernel() = default; ~SliceGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); size_t dx_byte_size = dx_tensor->shape_view().elem_cnt() * sizeof(T); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_byte_size); if (dy_tensor->shape_view().elem_cnt() == 0) { return; } SliceParams params = ConstructSliceParams(ctx, dx_tensor, dy_tensor); SliceKernelUtil::Backward(ctx->stream(), params, dy_tensor->dptr(), dx_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SLICE_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("slice").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("slice_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("slice_update") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("ref", 0) == GetDataType::value)); #define REGISTER_SLICE_KERNEL_WITH_DEVICE(device) \ REGISTER_SLICE_KERNEL(device, bool) \ REGISTER_SLICE_KERNEL(device, float16) \ REGISTER_SLICE_KERNEL(device, float) \ REGISTER_SLICE_KERNEL(device, double) \ REGISTER_SLICE_KERNEL(device, 
int32_t) \ REGISTER_SLICE_KERNEL(device, int64_t) \ REGISTER_SLICE_KERNEL(device, int8_t) \ REGISTER_SLICE_KERNEL(device, uint8_t) REGISTER_SLICE_KERNEL(DeviceType::kCPU, std::complex) REGISTER_SLICE_KERNEL(DeviceType::kCPU, std::complex) #ifdef WITH_CUDA REGISTER_SLICE_KERNEL(DeviceType::kCUDA, cuComplex) REGISTER_SLICE_KERNEL(DeviceType::kCUDA, cuDoubleComplex) #endif REGISTER_SLICE_KERNEL_WITH_DEVICE(DeviceType::kCPU) REGISTER_SLICE_KERNEL(DeviceType::kCPU, bfloat16) #ifdef WITH_CUDA REGISTER_SLICE_KERNEL_WITH_DEVICE(DeviceType::kCUDA) #if CUDA_VERSION >= 11000 REGISTER_SLICE_KERNEL(DeviceType::kCUDA, nv_bfloat16) #endif #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/slice_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/slice_util.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { SliceParams FoldContiguousFullSliceDimensions(const SliceParams& params) { SliceParams fold_slice_params; bool full_slice_on_prev_axis = false; FOR_RANGE(int, i, 0, params.ndim) { bool full_slice_on_cur_axis = params.IsFullSlice(i); if (full_slice_on_cur_axis && full_slice_on_prev_axis) { int cur_dim = fold_slice_params.ndim - 1; fold_slice_params.dims[cur_dim] *= params.dims[i]; fold_slice_params.size[cur_dim] *= params.size[i]; } else { int cur_dim = fold_slice_params.ndim; fold_slice_params.dims[cur_dim] = params.dims[i]; fold_slice_params.start[cur_dim] = params.start[i]; fold_slice_params.step[cur_dim] = params.step[i]; fold_slice_params.size[cur_dim] = params.size[i]; fold_slice_params.ndim += 1; } full_slice_on_prev_axis = full_slice_on_cur_axis; } return fold_slice_params; } template struct SliceKernelUtil { static void Forward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced) { SliceParams fold_slice_params = FoldContiguousFullSliceDimensions(params); SwitchDoForward(SwitchCase(fold_slice_params.ndim), stream, fold_slice_params, entire, sliced); } static void Forward(ep::Stream* stream, const SliceParams& entire_params, const SliceParams& sliced_params, const T* entire, T* sliced) { SwitchDoForward(SwitchCase(entire_params.ndim), stream, entire_params, sliced_params, entire, sliced); } static void Backward(ep::Stream* stream, const SliceParams& params, const T* sliced, T* entire) { SliceParams fold_slice_params = FoldContiguousFullSliceDimensions(params); SwitchDoBackward(SwitchCase(fold_slice_params.ndim), stream, fold_slice_params, sliced, entire); } private: template static void DoForward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced) { CHECK_EQ(params.ndim, NDIM); int64_t elem_cnt = params.elem_cnt(); SliceIndexHelper entire_idx_cvtr(params.dims); SliceIndexHelper sliced_idx_cvtr(params.size); MultiThreadLoop(elem_cnt, [&](int64_t i) { int64_t offset = SliceOffsetToEntireOffset(i, params, entire_idx_cvtr, 
sliced_idx_cvtr); sliced[i] = entire[offset]; }); } template static void SteppedMultiThreadLoop(size_t elem_cnt, size_t step, const DoEachT& DoEach) { if (elem_cnt == 0) { return; } CHECK_GT(step, 0); CHECK_EQ(elem_cnt % step, 0); MultiThreadLoop(elem_cnt / step, [&](size_t i) { DoEach(i * step); }); } template static void DoForward(ep::Stream* stream, const SliceParams& entire_params, const SliceParams& sliced_params, const T* entire, T* sliced) { CHECK_EQ(entire_params.ndim, NDIM); CHECK_EQ(sliced_params.ndim, NDIM); int64_t elem_cnt = entire_params.elem_cnt(); SliceIndexHelper entire_splitted_large_idx_cvtr = NdIndexStrideOffsetHelper(entire_params.stride); SliceIndexHelper sliced_splitted_large_idx_cvtr(entire_params.size); SliceIndexHelper entire_full_small_idx_cvtr = NdIndexStrideOffsetHelper(sliced_params.stride); SliceIndexHelper sliced_full_small_idx_cvtr(sliced_params.size); int cnt = 1; int entire_target_stride = 1; int sliced_target_stride = 1; // Calculate the length of continuous part for (int i = NDIM - 1; i >= 0; i--) { if (entire_params.stride[i] != entire_target_stride || sliced_params.stride[i] != sliced_target_stride) { break; } entire_target_stride *= entire_params.size[i]; sliced_target_stride *= sliced_params.size[i]; if (sliced_params.step[i] == 1 && entire_params.step[i] == 1) { cnt *= sliced_params.size[i]; } if (!entire_params.IsFullSlice(i) || !sliced_params.IsFullSlice(i)) { break; } } SteppedMultiThreadLoop(elem_cnt, cnt, [&](int64_t i) { const int64_t entire_offset = SliceOffsetToEntireOffset( i, entire_params, entire_splitted_large_idx_cvtr, sliced_splitted_large_idx_cvtr); const int64_t sliced_offset = SliceOffsetToEntireOffset( i, sliced_params, entire_full_small_idx_cvtr, sliced_full_small_idx_cvtr); std::copy(entire + entire_offset, entire + entire_offset + cnt, sliced + sliced_offset); }); } template static void DoBackward(ep::Stream* stream, const SliceParams& params, const T* sliced, T* entire) { CHECK_EQ(params.ndim, NDIM); int64_t elem_cnt = params.elem_cnt(); SliceIndexHelper entire_idx_cvtr(params.dims); SliceIndexHelper sliced_idx_cvtr(params.size); MultiThreadLoop(elem_cnt, [&](int64_t i) { int64_t offset = SliceOffsetToEntireOffset(i, params, entire_idx_cvtr, sliced_idx_cvtr); entire[offset] = sliced[i]; }); } #define MAKE_SLICE_KERNEL_UTIL_SWITCH_ENTRY(func_name, N) \ SliceKernelUtil::func_name #define DEFINE_SLICE_KERNEL_UTIL_SWITCH_STATIC_METHOD(func_name) \ DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_SLICE_KERNEL_UTIL_SWITCH_ENTRY, \ MAKE_NDIM_CTRV_SEQ(DIM_SEQ)); DEFINE_SLICE_KERNEL_UTIL_SWITCH_STATIC_METHOD(DoForward); DEFINE_SLICE_KERNEL_UTIL_SWITCH_STATIC_METHOD(DoBackward); #undef DEFINE_SLICE_KERNEL_UTIL_SWITCH_STATIC_METHOD #undef MAKE_SLICE_KERNEL_UTIL_SWITCH_ENTRY }; INSTANTIATE_SLICE_KERNEL_UTIL_WITH_DEVICE(DeviceType::kCPU) INSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCPU, bfloat16) INSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCPU, std::complex) INSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCPU, std::complex) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/slice_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/slice_util.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #if CUDA_VERSION >= 11000 #include #endif // CUDA_VERSION >= 11000 namespace oneflow { namespace { template __global__ void SliceForwardGpu(const int n, SliceParams params, SliceIndexHelper entire_idx_cvtr, SliceIndexHelper sliced_idx_cvtr, const T* entire, T* sliced) { CUDA_1D_KERNEL_LOOP(i, n) { int64_t offset = SliceOffsetToEntireOffset(i, params, entire_idx_cvtr, sliced_idx_cvtr); sliced[i] = entire[offset]; } } template __global__ void SliceForwardGpu(const int n, SliceParams entire_params, SliceParams sliced_params, SliceIndexHelper entire_splitted_large_idx_cvtr, SliceIndexHelper sliced_splitted_large_idx_cvtr, SliceIndexHelper entire_full_small_idx_cvtr, SliceIndexHelper sliced_full_small_idx_cvtr, const T* entire, T* sliced) { CUDA_1D_KERNEL_LOOP(i, n) { int64_t entire_offset = SliceOffsetToEntireOffset( i, entire_params, entire_splitted_large_idx_cvtr, sliced_splitted_large_idx_cvtr); int64_t sliced_offset = SliceOffsetToEntireOffset( i, sliced_params, entire_full_small_idx_cvtr, sliced_full_small_idx_cvtr); sliced[sliced_offset] = entire[entire_offset]; } } template __global__ void SliceBackwardGpu(const int n, SliceParams params, SliceIndexHelper entire_idx_cvtr, SliceIndexHelper sliced_idx_cvtr, T* entire, const T* sliced) { CUDA_1D_KERNEL_LOOP(i, n) { int64_t offset = SliceOffsetToEntireOffset(i, params, entire_idx_cvtr, sliced_idx_cvtr); entire[offset] = sliced[i]; } } template void LaunchSliceForward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced) { CHECK_EQ(params.ndim, NDIM); int64_t elem_cnt = params.elem_cnt(); SliceIndexHelper entire_idx_cvtr(params.dims); SliceIndexHelper sliced_idx_cvtr(params.size); if (elem_cnt == 0) { return; } SliceForwardGpu<<As()->cuda_stream()>>>( elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced); } template void LaunchSliceForward(ep::Stream* stream, const SliceParams& entire_params, const SliceParams& sliced_params, const T* entire, T* sliced) { CHECK_EQ(entire_params.ndim, NDIM); CHECK_EQ(sliced_params.ndim, NDIM); int64_t elem_cnt = entire_params.elem_cnt(); if (elem_cnt == 0) { return; } SliceIndexHelper entire_splitted_large_idx_cvtr = NdIndexStrideOffsetHelper(entire_params.stride); SliceIndexHelper sliced_splitted_large_idx_cvtr(entire_params.size); SliceIndexHelper entire_full_small_idx_cvtr = NdIndexStrideOffsetHelper(sliced_params.stride); SliceIndexHelper sliced_full_small_idx_cvtr(sliced_params.size); SliceForwardGpu<<As()->cuda_stream()>>>( elem_cnt, entire_params, sliced_params, entire_splitted_large_idx_cvtr, sliced_splitted_large_idx_cvtr, entire_full_small_idx_cvtr, sliced_full_small_idx_cvtr, entire, sliced); } template void LaunchSliceBackward(ep::Stream* stream, const SliceParams& params, const T* sliced, T* entire) { CHECK_EQ(params.ndim, NDIM); int64_t elem_cnt = params.elem_cnt(); SliceIndexHelper entire_idx_cvtr(params.dims); SliceIndexHelper sliced_idx_cvtr(params.size); if (elem_cnt == 0) { return; } SliceBackwardGpu<<As()->cuda_stream()>>>( elem_cnt, 
params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced); } template struct SliceSwitchUtil final { #define MAKE_SLICE_SWITCH_ENTRY(func_name, N) func_name #define DEFINE_SLICE_SWITCH_UTIL_STATIC_METHOD(func_name) \ DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_SLICE_SWITCH_ENTRY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ)) DEFINE_SLICE_SWITCH_UTIL_STATIC_METHOD(LaunchSliceForward) DEFINE_SLICE_SWITCH_UTIL_STATIC_METHOD(LaunchSliceBackward) #undef DEFINE_SLICE_SWITCH_UTIL_STATIC_METHOD #undef MAKE_SLICE_SWITCH_ENTRY }; template size_t GetPackSize(const SliceParams& params, const T* entire, const T* sliced) { CHECK_GT(params.ndim, 0); const int64_t last_dim = params.ndim - 1; const int64_t mask = (params.dims[last_dim] * sizeof(T)) | (params.start[last_dim] * sizeof(T)) | (params.size[last_dim] * sizeof(T)) | static_cast(reinterpret_cast(entire)) | static_cast(reinterpret_cast(sliced)); if ((mask & 0xF) == 0) { return 16; } else if ((mask & 0x7) == 0) { return 8; } else if ((mask & 0x3) == 0) { return 4; } else if ((mask & 0x1) == 0) { return 2; } else { return 1; } } template void GetPackedParams(const SliceParams& params, const T* entire, const T* sliced, size_t* pack_size, SliceParams* packed_params) { CHECK_GT(params.ndim, 0); const int64_t last_dim = params.ndim - 1; if (params.step[last_dim] == 1) { *pack_size = GetPackSize(params, entire, sliced); CHECK_GE(*pack_size, sizeof(T)); const int64_t elem_per_pack = *pack_size / sizeof(T); *packed_params = params; packed_params->dims[last_dim] /= elem_per_pack; packed_params->start[last_dim] /= elem_per_pack; packed_params->size[last_dim] /= elem_per_pack; } else { *pack_size = sizeof(T); *packed_params = params; } } } // namespace template struct SliceKernelUtil { static void Forward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced) { SliceParams fold_slice_params = FoldContiguousFullSliceDimensions(params); size_t pack_size; SliceParams packed_params{}; GetPackedParams(fold_slice_params, entire, sliced, &pack_size, &packed_params); if (pack_size == 1) { SliceSwitchUtil::SwitchLaunchSliceForward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(entire), reinterpret_cast(sliced)); } else if (pack_size == 2) { SliceSwitchUtil::SwitchLaunchSliceForward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(entire), reinterpret_cast(sliced)); } else if (pack_size == 4) { SliceSwitchUtil::SwitchLaunchSliceForward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(entire), reinterpret_cast(sliced)); } else if (pack_size == 8) { SliceSwitchUtil::SwitchLaunchSliceForward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(entire), reinterpret_cast(sliced)); } else if (pack_size == 16) { SliceSwitchUtil::SwitchLaunchSliceForward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(entire), reinterpret_cast(sliced)); } else { UNIMPLEMENTED(); } } static void Forward(ep::Stream* stream, const SliceParams& entire_params, const SliceParams& sliced_params, const T* entire, T* sliced) { SliceSwitchUtil::SwitchLaunchSliceForward(SwitchCase(entire_params.ndim), stream, entire_params, sliced_params, entire, sliced); } static void Backward(ep::Stream* stream, const SliceParams& params, const T* sliced, T* entire) { SliceParams fold_slice_params = FoldContiguousFullSliceDimensions(params); size_t pack_size; SliceParams packed_params{}; GetPackedParams(fold_slice_params, entire, sliced, &pack_size, &packed_params); if (pack_size == 1) { 
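// Added note (not part of the original source): when the innermost dimension is
// contiguous (step == 1), GetPackedParams chooses the largest power-of-two pack
// width (up to 16 bytes) that evenly divides the innermost extents in bytes and
// the alignment of both pointers; otherwise it falls back to sizeof(T). Each
// branch below re-runs the same kernel with the buffers reinterpreted at that
// pack width, so e.g. a fully 16-byte-aligned float tensor is moved four floats
// per thread.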
SliceSwitchUtil::SwitchLaunchSliceBackward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(sliced), reinterpret_cast(entire)); } else if (pack_size == 2) { SliceSwitchUtil::SwitchLaunchSliceBackward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(sliced), reinterpret_cast(entire)); } else if (pack_size == 4) { SliceSwitchUtil::SwitchLaunchSliceBackward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(sliced), reinterpret_cast(entire)); } else if (pack_size == 8) { SliceSwitchUtil::SwitchLaunchSliceBackward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(sliced), reinterpret_cast(entire)); } else if (pack_size == 16) { SliceSwitchUtil::SwitchLaunchSliceBackward( SwitchCase(packed_params.ndim), stream, packed_params, reinterpret_cast(sliced), reinterpret_cast(entire)); } else { UNIMPLEMENTED(); } } }; INSTANTIATE_SLICE_KERNEL_UTIL_WITH_DEVICE(DeviceType::kCUDA) INSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCUDA, cuComplex) INSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCUDA, cuDoubleComplex) #if CUDA_VERSION >= 11000 INSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCUDA, nv_bfloat16) #endif } // namespace oneflow ================================================ FILE: oneflow/user/kernels/slice_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_SLICE_UTIL_H_ #define ONEFLOW_USER_KERNELS_SLICE_UTIL_H_ #include #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/util.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { inline int64_t RegulateSliceStart(int64_t start, int64_t size) { // slice start must be in range [-size, size) // after changing to positive order it should be in range [0, size) start = std::min(std::max(start, -size), size - 1); return (start < 0) ? (start + size) : start; } inline int64_t RegulateSliceStop(int64_t stop, int64_t size) { // slice stop must be in range [-size-1, size] // after changing to positive order it should be in range [-1, size] stop = std::min(std::max(stop, -size - 1), size); return (stop < 0) ? 
(stop + size) : stop; } constexpr size_t kSliceMaxDims = 8; struct SliceParams { int64_t ndim = 0; int64_t dims[kSliceMaxDims]{0}; int64_t stride[kSliceMaxDims]{0}; int64_t start[kSliceMaxDims]{0}; int64_t step[kSliceMaxDims]{0}; int64_t size[kSliceMaxDims]{0}; int64_t elem_cnt() const { if (ndim == 0) { return 0; } int64_t elem_cnt = 1; FOR_RANGE(int, i, 0, ndim) { elem_cnt *= size[i]; } return elem_cnt; } bool IsFullSlice(int dim) const { CHECK_GE(dim, 0); CHECK_LT(dim, ndim); if (step[dim] != 1) { return false; } if (start[dim] != 0) { return false; } if (size[dim] != dims[dim]) { return false; } return true; } std::string ToString() const { std::stringstream ss("SliceParams:"); for (int i = 0; i < ndim; ++i) { ss << "\n\tdim: " << i << ", start: " << start[i] << ", step: " << step[i] << ", stride: " << stride[i] << ", size: " << size[i] << ", dims: " << dims[i]; } return ss.str(); } }; SliceParams FoldContiguousFullSliceDimensions(const SliceParams& params); template using SliceIndexHelper = NdIndexOffsetHelper; template OF_DEVICE_FUNC int64_t SliceOffsetToEntireOffset(int64_t offset, const SliceParams& params, const SliceIndexHelper& entire_idx_cvtr, const SliceIndexHelper& sliced_idx_cvtr) { int64_t nd_index[NDIM] = {0}; sliced_idx_cvtr.OffsetToNdIndex(offset, nd_index); #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int64_t i = 0; i < NDIM; ++i) { nd_index[i] = params.start[i] + params.step[i] * nd_index[i]; assert(nd_index[i] >= 0); assert(nd_index[i] < params.dims[i]); } return entire_idx_cvtr.NdIndexToOffset(nd_index); } template struct SliceKernelUtil { static void Forward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced); static void Forward(ep::Stream* stream, const SliceParams& entire_params, const SliceParams& sliced_params, const T* entire, T* sliced); static void Backward(ep::Stream* stream, const SliceParams& params, const T* sliced, T* entire); }; #define INSTANTIATE_SLICE_KERNEL_UTIL(device, dtype) template struct SliceKernelUtil; #define INSTANTIATE_SLICE_KERNEL_UTIL_WITH_DEVICE(device) \ INSTANTIATE_SLICE_KERNEL_UTIL(device, bool) \ INSTANTIATE_SLICE_KERNEL_UTIL(device, float16) \ INSTANTIATE_SLICE_KERNEL_UTIL(device, float) \ INSTANTIATE_SLICE_KERNEL_UTIL(device, double) \ INSTANTIATE_SLICE_KERNEL_UTIL(device, int32_t) \ INSTANTIATE_SLICE_KERNEL_UTIL(device, int64_t) \ INSTANTIATE_SLICE_KERNEL_UTIL(device, int8_t) \ INSTANTIATE_SLICE_KERNEL_UTIL(device, uint8_t) } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_SLICE_UTIL_H_ ================================================ FILE: oneflow/user/kernels/smooth_l1_loss_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/loss_kernel_util.h" namespace oneflow { namespace user_op { namespace { using namespace loss; template void ComputeSmoothL1Out(int64_t elem_cnt, const T* input, const T* target, T* out, const float beta) { FOR_RANGE(int64_t, i, 0, elem_cnt) { const T abs_diff = std::abs(input[i] - target[i]); if (abs_diff < beta) { out[i] = 0.5 * abs_diff * abs_diff / beta; } else { out[i] = abs_diff - 0.5 * beta; } } } template void ComputeSmoothL1GradOut(int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx, const float beta) { FOR_RANGE(int64_t, i, 0, elem_cnt) { const T diff = input[i] - target[i]; const T abs_diff = std::abs(diff); if (abs_diff < beta) { dx[i] = diff / beta; } else { dx[i] = (diff > GetZeroVal()) - (diff < GetZeroVal()); } const T dy_val = dy[i]; dx[i] = dx[i] * dy_val; } } template class SmoothL1LossKernel : public SimpleLossKernel> { public: void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input, const T* target, T* out) const { const float beta = ctx->Attr("beta"); ComputeSmoothL1Out(elem_cnt, input, target, out, beta); } }; template class SmoothL1LossGradKernel : public SimpleLossGradKernel> { public: void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx) const { const float beta = ctx->Attr("beta"); ComputeSmoothL1GradOut(elem_cnt, input, target, dy, dx, beta); } }; } // namespace REGISTER_SIMPLE_LOSS_KERNEL_CPU("smooth_l1_loss", SmoothL1LossKernel) REGISTER_SIMPLE_LOSS_GRAD_KERNEL_CPU("smooth_l1_loss_grad", SmoothL1LossGradKernel) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/smooth_l1_loss_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/loss_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { using namespace loss; template struct SmoothL1Functor { float beta_; float inv_beta_; T half_of_one_; SmoothL1Functor(float beta) : beta_(beta), inv_beta_(static_cast(1.0 / beta)), half_of_one_(static_cast(0.5)) {} __device__ __forceinline__ T operator()(T input_val, T target_val) const { const T abs_diff = abs(input_val - target_val); if (abs_diff < beta_) { return half_of_one_ * abs_diff * abs_diff * inv_beta_; } else { return abs_diff - half_of_one_ * beta_; } } }; template<> struct SmoothL1Functor { half beta_; half inv_beta_; half zero_; half half_of_one_; SmoothL1Functor(float beta) : beta_(__float2half(beta)), inv_beta_(__float2half(static_cast(1.0 / beta))), zero_(__float2half(0.f)), half_of_one_(__float2half(0.5f)) {} __device__ __forceinline__ half operator()(half input_val, half target_val) const { const half diff = input_val - target_val; const half abs_diff = diff < zero_ ? __hneg(diff) : diff; if (abs_diff < beta_) { return half_of_one_ * abs_diff * abs_diff * inv_beta_; } else { return abs_diff - half_of_one_ * beta_; } } }; template struct SmoothL1GradFunctor { float beta_; float inv_beta_; T zero_; SmoothL1GradFunctor(float beta) : beta_(beta), inv_beta_(static_cast(1.0 / beta)), zero_(GetZeroVal()) {} __device__ __forceinline__ T operator()(T input_val, T target_val, T dy_val) const { const T diff = input_val - target_val; const T abs_diff = abs(diff); T dx_val; if (abs_diff < beta_) { dx_val = diff * inv_beta_; } else { dx_val = (diff > zero_) - (diff < zero_); } return dx_val * dy_val; } }; template<> struct SmoothL1GradFunctor { half beta_; half inv_beta_; half zero_; half one_; SmoothL1GradFunctor(float beta) : beta_(__float2half(beta)), inv_beta_(__float2half(static_cast(1.0 / beta))), zero_(__float2half(0.f)), one_(__float2half(1.f)) {} __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const { const half diff = input_val - target_val; const half abs_diff = diff < zero_ ? __hneg(diff) : diff; half dx_val; if (abs_diff < beta_) { dx_val = diff * inv_beta_; } else { dx_val = (diff > zero_) - (diff < zero_); } return dx_val * dy_val; } }; template class SmoothL1LossKernel : public SimpleLossKernel> { public: void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input, const T* target, T* out) const { const float beta = ctx->Attr("beta"); OF_CUDA_CHECK((cuda::elementwise::Binary(SmoothL1Functor(beta), elem_cnt, out, input, target, ctx->stream()->As()->cuda_stream()))); } }; template class SmoothL1LossGradKernel : public SimpleLossGradKernel> { public: void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx) const { const float beta = ctx->Attr("beta"); OF_CUDA_CHECK( (cuda::elementwise::Ternary(SmoothL1GradFunctor(beta), elem_cnt, dx, input, target, dy, ctx->stream()->As()->cuda_stream()))); } }; } // namespace REGISTER_SIMPLE_LOSS_KERNEL_CUDA("smooth_l1_loss", SmoothL1LossKernel) REGISTER_SIMPLE_LOSS_GRAD_KERNEL_CUDA("smooth_l1_loss_grad", SmoothL1LossGradKernel) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/softmax_cross_entropy_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/softmax_cross_entropy_kernel.h"
#include "oneflow/core/kernel/kernel_util.cuh"

namespace oneflow {
namespace user_op {

template<typename T>
struct CrossEntropyKernelUtil<DeviceType::kCPU, T> {
  // Cross entropy over softmax probabilities: y[i] = -sum_j labels[i][j] * log(prob[i][j]).
  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,
                             const int64_t num_classes, const T* x, const T* labels, T* y) {
    FOR_RANGE(int64_t, i, 0, num_instances) {
      T tmp = 0;
      FOR_RANGE(int64_t, j, 0, num_classes) {
        T label = labels[i * num_classes + j];
        T prob = x[i * num_classes + j];
        // tmp -= label * SafeLog(prob);
        tmp -= label * logf((prob > 1e-20) ? prob : 1e-20);
      }
      y[i] = tmp;
    }
  }
  // Gradient of softmax fused with cross entropy: dx = dy * (prob - label).
  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,
                                     const int64_t num_classes, const T* prob, const T* labels,
                                     const T* dy, T* dx) {
    FOR_RANGE(int64_t, i, 0, elem_cnt) {
      const int32_t row_id = i / num_classes;
      dx[i] = dy[row_id] * (prob[i] - labels[i]);
    }
  }
};

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL,
                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL,
                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ)

}  // namespace user_op
}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/softmax_cross_entropy_kernel.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/user/kernels/softmax_cross_entropy_kernel.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace user_op { namespace { constexpr int64_t kCrossEntropyGpuBlockSize = 128; template __global__ void ComputeEntropyGpu(const int64_t num_instances, const int64_t num_classes, const T* x, const T* labels, T* y) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; const int tid = threadIdx.x; for (int row = blockIdx.x; row < num_instances; row += gridDim.x) { const int row_offset = row * num_classes; const T* in_row = x + row_offset; const T* label_row = labels + row_offset; T result = 0; for (int col = tid; col < num_classes; col += kCrossEntropyGpuBlockSize) { T label = label_row[col]; T prob = in_row[col]; result += -label * SafeLog(prob); } __syncthreads(); T row_reduce_result = BlockReduce(temp_storage).Reduce(result, cub::Sum()); if (0 == tid) { y[row] = row_reduce_result; } } } __global__ void ComputeEntropyGpuHalf(const int64_t num_instances, const int64_t num_classes, const half* x, const half* labels, half* y) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; const int tid = threadIdx.x; for (int row = blockIdx.x; row < num_instances; row += gridDim.x) { const int row_offset = row * num_classes; const half* in_row = x + row_offset; const half* label_row = labels + row_offset; float result = 0; for (int col = tid; col < num_classes; col += kCrossEntropyGpuBlockSize) { float label = __half2float(label_row[col]); float prob = __half2float(in_row[col]); result += -label * SafeLog(prob); } __syncthreads(); float row_reduce_result = BlockReduce(temp_storage).Reduce(result, cub::Sum()); if (0 == tid) { y[row] = __float2half(row_reduce_result); } } } template __global__ void ComputeDiffWithSoftmaxGpu(const int64_t elem_cnt, const int64_t num_classes, const T* prob, const T* labels, const T* dy, T* dx) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const int32_t row_id = i / num_classes; dx[i] = dy[row_id] * (prob[i] - labels[i]); } } __global__ void ComputeDiffWithSoftmaxGpuHalf(const int64_t elem_cnt, const int64_t num_classes, const half* prob, const half* labels, const half* dy, half* dx) { #if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const int32_t row_id = i / num_classes; dx[i] = __hmul(dy[row_id], __hsub(prob[i], labels[i])); } #else printf("use half need nvcc arch >= 530"); assert(false); #endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ } } // namespace int GetCrossEntropyNumBlocks(const int num_instances) { return std::min(static_cast(num_instances), kCudaMaxBlocksNum); } int GetCrossEntropyBlockSize() { return kCrossEntropyGpuBlockSize; } template struct CrossEntropyKernelUtil { static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const T* x, const T* labels, T* y) { OF_CUDA_CHECK(cudaMemsetAsync(y, 0, sizeof(T) * num_instances, stream->As()->cuda_stream())); ComputeEntropyGpu<<As()->cuda_stream()>>>(num_instances, num_classes, x, labels, y); } static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt, const int64_t num_classes, const T* prob, const T* labels, const T* dy, T* dx) { ComputeDiffWithSoftmaxGpu<<As()->cuda_stream()>>>( elem_cnt, num_classes, prob, labels, dy, dx); } }; template<> struct CrossEntropyKernelUtil { static void ComputeEntropy(ep::Stream* stream, const int64_t 
num_instances, const int64_t num_classes, const float16* x, const float16* labels, float16* y) { OF_CUDA_CHECK(cudaMemsetAsync(y, 0, sizeof(float16) * num_instances, stream->As()->cuda_stream())); ComputeEntropyGpuHalf<<As()->cuda_stream()>>>( num_instances, num_classes, reinterpret_cast(x), reinterpret_cast(labels), reinterpret_cast(y)); } static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt, const int64_t num_classes, const float16* prob, const float16* labels, const float16* dy, float16* dx) { ComputeDiffWithSoftmaxGpuHalf<<As()->cuda_stream()>>>( elem_cnt, num_classes, reinterpret_cast(prob), reinterpret_cast(labels), reinterpret_cast(dy), reinterpret_cast(dx)); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/softmax_cross_entropy_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/softmax.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewSoftmaxPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("prediction", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } auto SoftmaxPrimitiveExists() { return hob::make_custom("SoftmaxPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewSoftmaxPrimitive(&ctx).operator bool(); }); } } // namespace template struct CrossEntropyKernelUtil { static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const T* x, const T* labels, T* y); static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt, const int64_t num_classes, const T* prob, const T* labels, const T* dy, T* dx); }; template class SoftmaxCrossEntropyKernel final : public user_op::OpKernel { public: SoftmaxCrossEntropyKernel() = default; ~SoftmaxCrossEntropyKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const auto num_axes = label->shape_view().NumAxes(); const int64_t num_instances = label->shape_view().Count(0, num_axes - 1); const int64_t num_classes = label->shape_view().At(num_axes - 1); std::unique_ptr primitive = NewSoftmaxPrimitive(ctx); CHECK(primitive); primitive->Launch(ctx->stream(), num_instances, num_classes, prediction->dptr(), prob->mut_dptr()); CrossEntropyKernelUtil::ComputeEntropy(ctx->stream(), num_instances, num_classes, prob->dptr(), label->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL(device_type_v, dtype_pair) \ REGISTER_USER_KERNEL("softmax_cross_entropy") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && SoftmaxPrimitiveExists()); template class SoftmaxCrossEntropyGradKernel final : public user_op::OpKernel { public: SoftmaxCrossEntropyGradKernel() = default; ~SoftmaxCrossEntropyGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex("prediction_diff", 0); const int64_t num_instances = dy->shape_view().elem_cnt(); CHECK_EQ(prob->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prob->shape_view().elem_cnt() / num_instances; CrossEntropyKernelUtil::ComputeDiffWithSoftmax( ctx->stream(), prediction_diff->shape_view().elem_cnt(), num_classes, prob->dptr(), label->dptr(), dy->dptr(), prediction_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL(device_type_v, dtype_pair) \ 
REGISTER_USER_KERNEL("softmax_cross_entropy_grad") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && (user_op::HobDataType("prediction_diff", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("prediction_diff", 0, "prob", 0, true)); \ return Maybe::Ok(); \ }); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/softmax_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/softmax.h" #include "oneflow/core/ep/include/primitive/softmax_backward.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template std::unique_ptr NewSoftmaxPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("in", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } auto SoftmaxPrimitiveExists() { return hob::make_custom("SoftmaxPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewSoftmaxPrimitive(&ctx).operator bool(); }); } template std::unique_ptr NewSoftmaxBackwardPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } auto SoftmaxBackwardPrimitiveExists() { return hob::make_custom("SoftmaxBackwardPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewSoftmaxBackwardPrimitive(&ctx).operator bool(); }); } } // namespace class SoftmaxKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SoftmaxKernel() = default; ~SoftmaxKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& in_shape = in->shape_view(); const int64_t cols = in_shape.At(in_shape.NumAxes() - 1); const int64_t rows = in_shape.Count(0, in_shape.NumAxes() - 1); std::unique_ptr primitive = NewSoftmaxPrimitive(ctx); CHECK(primitive); primitive->Launch(ctx->stream(), rows, cols, in->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("softmax").SetCreateFn().SetIsMatchedHob( SoftmaxPrimitiveExists() == true); class SoftmaxGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SoftmaxGradKernel() = default; ~SoftmaxGradKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const 
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
    const int64_t num_classes = y->shape_view().At(y->shape_view().NumAxes() - 1);
    const int64_t num_instances = y->shape_view().elem_cnt() / num_classes;
    std::unique_ptr<ep::primitive::SoftmaxBackward> primitive = NewSoftmaxBackwardPrimitive(ctx);
    CHECK(primitive);
    primitive->Launch(ctx->stream(), num_instances, num_classes, y->dptr(), dy->dptr(),
                      dx->mut_dptr());
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("softmax_grad")
    .SetCreateFn<SoftmaxGradKernel>()
    .SetIsMatchedHob(SoftmaxBackwardPrimitiveExists() == true);

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/sort_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"

namespace oneflow {

template<typename T>
class CpuSortKernel final : public user_op::OpKernel {
 public:
  CpuSortKernel() = default;
  ~CpuSortKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    // Copy the input into the output buffer, then sort each innermost row in place.
    Memcpy<DeviceType::kCPU>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
                             in->shape_view().elem_cnt() * sizeof(T));
    const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);
    const int32_t instance_num = in->shape_view().elem_cnt() / instance_size;
    const std::string& direction = ctx->Attr<std::string>("direction");
    const bool is_ascending = direction == "ASCENDING";
    const bool is_descending = direction == "DESCENDING";
    FOR_RANGE(int32_t, i, 0, instance_num) {
      T* out_ptr_i = out->mut_dptr<T>() + i * instance_size;
      if (is_ascending) {
        std::sort(out_ptr_i, out_ptr_i + instance_size, std::less<T>());
      } else if (is_descending) {
        std::sort(out_ptr_i, out_ptr_i + instance_size, std::greater<T>());
      } else {
        UNIMPLEMENTED();
      }
    }
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_CPU_SORT_KERNEL(dtype)                                               \
  REGISTER_USER_KERNEL("sort").SetCreateFn<CpuSortKernel<dtype>>().SetIsMatchedHob(   \
      (user_op::HobDeviceType() == DeviceType::kCPU)                                  \
      && (user_op::HobDataType("out", 0) == GetDataType<dtype>::value));

REGISTER_CPU_SORT_KERNEL(float)
REGISTER_CPU_SORT_KERNEL(double)
REGISTER_CPU_SORT_KERNEL(int32_t)
REGISTER_CPU_SORT_KERNEL(int64_t)

}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/sort_kernel.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/radix_sort.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { template class GpuSortKernel final : public user_op::OpKernel { public: GpuSortKernel() = default; ~GpuSortKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); Memcpy(ctx->stream(), out->mut_dptr(), in->dptr(), in->shape_view().elem_cnt() * sizeof(T)); const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1); const int32_t instance_num = in->shape_view().elem_cnt() / instance_size; const std::string& direction = ctx->Attr("direction"); if (direction == "ASCENDING") { SortKeysAscending(in->dptr(), instance_num, instance_size, tmp_buffer->mut_dptr(), tmp_buffer->shape_view().elem_cnt(), out->mut_dptr(), ctx->stream()->As()->cuda_stream()); } else if (direction == "DESCENDING") { SortKeysDescending(in->dptr(), instance_num, instance_size, tmp_buffer->mut_dptr(), tmp_buffer->shape_view().elem_cnt(), out->mut_dptr(), ctx->stream()->As()->cuda_stream()); } else { UNIMPLEMENTED(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_SORT_KERNEL(dtype) \ REGISTER_USER_KERNEL("sort") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const Shape& in_shape = ctx->InputShape("in", 0); \ const int32_t instance_size = in_shape.dim_vec().back(); \ const int32_t instance_num = in_shape.elem_cnt() / instance_size; \ const std::string& direction = ctx->Attr("direction"); \ if (direction == "ASCENDING") { \ return InferTempStorageForSortKeysAscending(instance_num, instance_size); \ } else if (direction == "DESCENDING") { \ return InferTempStorageForSortKeysDescending(instance_num, instance_size); \ } else { \ UNIMPLEMENTED(); \ return 0; \ } \ }); REGISTER_CUDA_SORT_KERNEL(float) REGISTER_CUDA_SORT_KERNEL(double) REGISTER_CUDA_SORT_KERNEL(int32_t) REGISTER_CUDA_SORT_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sparse_cross_entropy_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/user/kernels/sparse_cross_entropy_kernel_util.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { namespace user_op { namespace { class SparseCrossEntropyOpKernelCache final : public user_op::OpKernelCache { public: SparseCrossEntropyOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} ~SparseCrossEntropyOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } private: const int64_t lower_; const int64_t upper_; }; } // namespace template class SparseCrossEntropyKernel final : public user_op::OpKernel { public: SparseCrossEntropyKernel() = default; ~SparseCrossEntropyKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t num_instances = label->shape_view().elem_cnt(); CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances; const int64_t lower_bound = 0; const int64_t depth = ctx->Attr("depth"); SparseCrossEntropyKernelUtil::ComputeEntropy( ctx->stream(), num_instances, num_classes, depth, lower_bound, prediction->dptr(), label->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class SparseCrossEntropyMsKernel final : public user_op::OpKernel { public: SparseCrossEntropyMsKernel() = default; ~SparseCrossEntropyMsKernel() = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("prediction", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); const TensorDesc* prediction_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("prediction", 0); const int64_t class_axis = prediction_logical_desc->shape().NumAxes() - 1; TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, nd_sbp, prediction_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); return std::make_shared(view.At(class_axis).begin(), view.At(class_axis).end()); } else { return nullptr; } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t num_instances = label->shape_view().elem_cnt(); CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances; const int64_t depth = ctx->Attr("depth"); int64_t lower_bound = 0; if (cache != nullptr) { auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower()); lower_bound = kernel_cache->lower(); } Memset(ctx->stream(), out->mut_dptr(), 0, out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type())); 
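    // NOTE(added for clarity, not in the original source): in this model-split ("ms") variant the
    // class dimension of `prediction` is sharded, so each rank only holds the classes
    // [lower, upper) recorded in the kernel cache above, while labels stay global in [0, depth).
    // After subtracting lower_bound, ComputeEntropy below writes y[i] only for labels that fall in
    // the local shard; the output was zero-filled first so that, presumably, a later cross-rank
    // reduction can combine the partial per-rank results.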
SparseCrossEntropyKernelUtil::ComputeEntropy( ctx->stream(), num_instances, num_classes, depth, lower_bound, prediction->dptr(), label->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SPARSE_CROSS_ENTROPY_KERNEL(kernel_class, kernel_name, device_type_v, dtype_pair, \ ltype_pair) \ REGISTER_USER_KERNEL(kernel_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_KERNEL, (SparseCrossEntropyKernel), ("sparse_cross_entropy"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_KERNEL, (SparseCrossEntropyKernel), ("sparse_cross_entropy"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_KERNEL, (SparseCrossEntropyMsKernel), ("sparse_cross_entropy_ms"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_KERNEL, (SparseCrossEntropyMsKernel), ("sparse_cross_entropy_ms"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif template class SparseCrossEntropyGradKernel final : public user_op::OpKernel { public: SparseCrossEntropyGradKernel() = default; ~SparseCrossEntropyGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex("prediction_diff", 0); const int64_t num_instances = label->shape_view().elem_cnt(); CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances; const int64_t lower_bound = 0; const int64_t depth = ctx->Attr("depth"); size_t prediction_diff_bytes_size = prediction_diff->shape_view().elem_cnt() * GetSizeOfDataType(prediction_diff->data_type()); Memset(ctx->stream(), prediction_diff->mut_dptr(), 0, prediction_diff_bytes_size); SparseCrossEntropyKernelUtil::ComputeDiff( ctx->stream(), num_instances, num_classes, depth, lower_bound, prediction->dptr(), label->dptr(), dy->dptr(), prediction_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class SparseCrossEntropyMsGradKernel final : public user_op::OpKernel { public: SparseCrossEntropyMsGradKernel() = default; ~SparseCrossEntropyMsGradKernel() = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("prediction", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); const TensorDesc* prediction_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("prediction", 0); const int64_t class_axis = prediction_logical_desc->shape().NumAxes() - 1; TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, nd_sbp, 
prediction_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); return std::make_shared(view.At(class_axis).begin(), view.At(class_axis).end()); } else { return nullptr; } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex("prediction_diff", 0); const int64_t num_instances = label->shape_view().elem_cnt(); CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances; const int64_t depth = ctx->Attr("depth"); int64_t lower_bound = 0; if (cache != nullptr) { auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower()); lower_bound = kernel_cache->lower(); } size_t prediction_diff_bytes_size = prediction_diff->shape_view().elem_cnt() * GetSizeOfDataType(prediction_diff->data_type()); Memset(ctx->stream(), prediction_diff->mut_dptr(), 0, prediction_diff_bytes_size); SparseCrossEntropyKernelUtil::ComputeDiff( ctx->stream(), num_instances, num_classes, depth, lower_bound, prediction->dptr(), label->dptr(), dy->dptr(), prediction_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL(kernel_class, kernel_name, device_type_v, \ dtype_pair, ltype_pair) \ REGISTER_USER_KERNEL(kernel_name) \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ && (user_op::HobDataType("prediction_diff", 0) == OF_PP_PAIR_SECOND(dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL, (SparseCrossEntropyGradKernel), ("sparse_cross_entropy_grad"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL, (SparseCrossEntropyGradKernel), ("sparse_cross_entropy_grad"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL, (SparseCrossEntropyMsGradKernel), ("sparse_cross_entropy_ms_grad"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL, (SparseCrossEntropyMsGradKernel), ("sparse_cross_entropy_ms_grad"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sparse_cross_entropy_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/sparse_cross_entropy_kernel_util.h"
#include "oneflow/core/kernel/kernel_util.cuh"

namespace oneflow {
namespace user_op {

template<typename T, typename K>
struct SparseCrossEntropyKernelUtil<DeviceType::kCPU, T, K> {
  // Sparse cross entropy: y[i] = -log(x[i][label_i]) for labels that fall in the local class
  // range [lower_bound, lower_bound + num_classes).
  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,
                             const int64_t num_classes, const int64_t depth,
                             const int64_t lower_bound, const T* x, const K* labels, T* y) {
    FOR_RANGE(int64_t, i, 0, num_instances) {
      CHECK_GE(labels[i], 0);
      CHECK_LT(labels[i], depth);
      K label = labels[i] - lower_bound;
      if (label >= 0 && label < num_classes) { y[i] = -SafeLog(x[i * num_classes + label]); }
    }
  }
  static void ComputeDiff(ep::Stream* stream, const int64_t num_instances,
                          const int64_t num_classes, const int64_t depth,
                          const int64_t lower_bound, const T* x, const K* labels, const T* dy,
                          T* dx) {
    FOR_RANGE(int64_t, i, 0, num_instances) {
      CHECK_GE(labels[i], 0);
      CHECK_LT(labels[i], depth);
      K label = labels[i] - lower_bound;
      if (label >= 0 && label < num_classes) {
        dx[i * num_classes + label] = -dy[i] / MaxWithLogThreshold(x[i * num_classes + label]);
      }
    }
  }
  // Gradient when the op is fused with softmax: dx = dy * (prob - onehot(label)).
  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,
                                     const int64_t num_classes, const int64_t depth,
                                     const int64_t lower_bound, const T* prob, const K* labels,
                                     const T* dy, T* dx) {
    FOR_RANGE(int64_t, i, 0, elem_cnt) {
      const int32_t row_id = i / num_classes;
      const int32_t col_id = i - row_id * num_classes;
      CHECK_GE(labels[row_id], 0);
      CHECK_LT(labels[row_id], depth);
      K label = labels[row_id] - lower_bound;
      if (label == col_id) {
        dx[i] = dy[row_id] * (prob[i] - 1);
      } else {
        dx[i] = dy[row_id] * prob[i];
      }
    }
  }
};

#define INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU(data_type_pair, index_type_pair) \
  template struct SparseCrossEntropyKernelUtil<                                           \
      DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(index_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU,
                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);
#undef INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU

}  // namespace user_op
}  // namespace oneflow


================================================
FILE: oneflow/user/kernels/sparse_cross_entropy_kernel_util.cu
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/user/kernels/sparse_cross_entropy_kernel_util.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { template __global__ void ComputeEntropyGpu(const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* x, const K* labels, T* y) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) { assert(labels[i] >= 0); assert(labels[i] < depth); K label = labels[i] - lower_bound; if (label >= 0 && label < num_classes) { y[i] = -SafeLog(x[i * num_classes + label]); } } } template __global__ void ComputeEntropyGpuHalf(const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const half* x, const K* labels, half* y) { #if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) { assert(labels[i] >= 0); assert(labels[i] < depth); K label = labels[i] - lower_bound; if (label >= 0 && label < num_classes) { y[i] = __float2half(-SafeLog(__half2float(x[i * num_classes + label]))); } } #else printf("use half need nvcc arch >= 530"); assert(false); #endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ } template __global__ void ComputeDiffGpu(const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* x, const K* labels, const T* dy, T* dx) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) { assert(labels[i] >= 0); assert(labels[i] < depth); K label = labels[i] - lower_bound; if (label >= 0 && label < num_classes) { dx[i * num_classes + label] = -dy[i] / MaxWithLogThreshold(x[i * num_classes + label]); } } } template __global__ void ComputeDiffGpuHalf(const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const half* x, const K* labels, const half* dy, half* dx) { #if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) { assert(labels[i] >= 0); assert(labels[i] < depth); K label = labels[i] - lower_bound; if (label >= 0 && label < num_classes) { dx[i * num_classes + label] = __hneg(__hdiv(__float2half(dy[i]), MaxWithLogThreshold(x[i * num_classes + label]))); } } #else printf("use half need nvcc arch >= 530"); assert(false); #endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ } template __global__ void ComputeDiffWithSoftmaxGpu(const int64_t elem_cnt, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* prob, const K* labels, const T* dy, T* dx) { CUDA_1D_KERNEL_LOOP_T(IndexType, i, elem_cnt) { const IndexType row_id = i / num_classes; const IndexType col_id = i - row_id * num_classes; assert(labels[row_id] >= 0); assert(labels[row_id] < depth); K label = labels[row_id] - lower_bound; if (label == col_id) { dx[i] = dy[row_id] * (prob[i] - 1); } else { dx[i] = dy[row_id] * prob[i]; } } } template __global__ void ComputeDiffWithSoftmaxGpuHalf(const int64_t elem_cnt, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const half* prob, const K* labels, const half* dy, half* dx) { #if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) CUDA_1D_KERNEL_LOOP_T(IndexType, i, elem_cnt) { // NOTE(chengcheng): int division ('/') of i will reduce performance of int64_t. 
const IndexType row_id = i / num_classes; const IndexType col_id = i - row_id * num_classes; assert(labels[row_id] >= 0); assert(labels[row_id] < depth); K label = labels[row_id] - lower_bound; if (label == col_id) { dx[i] = __hmul(dy[row_id], __hsub(prob[i], __float2half(1.0))); } else { dx[i] = __hmul(dy[row_id], prob[i]); } } #else printf("use half need nvcc arch >= 530"); assert(false); #endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ } template __global__ void ComputeDiffWithSoftmaxGpuHalf2(const int64_t elem_cnt, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const half* prob, const K* labels, const half* dy, half* dx) { #if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) const int64_t h2_num_classes = num_classes / 2; const int64_t h2_elem_cnt = elem_cnt / 2; const auto* prob_h2 = reinterpret_cast(prob); auto* dx_h2 = reinterpret_cast(dx); CUDA_1D_KERNEL_LOOP_T(IndexType, i, h2_elem_cnt) { const IndexType row_id = i / h2_num_classes; const IndexType h2_col_id = i - row_id * h2_num_classes; assert(labels[row_id] >= 0); assert(labels[row_id] < depth); K label = labels[row_id] - lower_bound; const half2 prob_h2_i = prob_h2[i]; const half dy_row = dy[row_id]; half2 dx_h2_i; dx_h2_i.x = __hmul(dy_row, __hsub(prob_h2_i.x, static_cast(label == 2 * h2_col_id))); dx_h2_i.y = __hmul(dy_row, __hsub(prob_h2_i.y, static_cast(label == 2 * h2_col_id + 1))); dx_h2[i] = dx_h2_i; } #else printf("use half need nvcc arch >= 530"); assert(false); #endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ } } // namespace template struct SparseCrossEntropyKernelUtil { static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* x, const K* labels, T* y) { ComputeEntropyGpu<<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, x, labels, y); } static void ComputeDiff(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* x, const K* labels, const T* dy, T* dx) { ComputeDiffGpu<<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, x, labels, dy, dx); } static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* prob, const K* labels, const T* dy, T* dx) { if (elem_cnt < GetMaxVal() / 2) { ComputeDiffWithSoftmaxGpu <<As()->cuda_stream()>>>(elem_cnt, num_classes, depth, lower_bound, prob, labels, dy, dx); } else { ComputeDiffWithSoftmaxGpu <<As()->cuda_stream()>>>(elem_cnt, num_classes, depth, lower_bound, prob, labels, dy, dx); } } }; template struct SparseCrossEntropyKernelUtil { static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const float16* x, const K* labels, float16* y) { ComputeEntropyGpuHalf<<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, reinterpret_cast(x), labels, reinterpret_cast(y)); } static void ComputeDiff(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const float16* x, const K* labels, const float16* dy, float16* dx) { ComputeDiffGpuHalf<<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, reinterpret_cast(x), labels, reinterpret_cast(dy), reinterpret_cast(dx)); } static void ComputeDiffWithSoftmax(ep::Stream* stream, 
const int64_t elem_cnt, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const float16* prob, const K* labels, const float16* dy, float16* dx) { if (num_classes % 2 == 0) { if (elem_cnt < GetMaxVal() / 2) { ComputeDiffWithSoftmaxGpuHalf2 <<As()->cuda_stream()>>>( elem_cnt, num_classes, depth, lower_bound, reinterpret_cast(prob), labels, reinterpret_cast(dy), reinterpret_cast(dx)); } else { ComputeDiffWithSoftmaxGpuHalf2 <<As()->cuda_stream()>>>( elem_cnt, num_classes, depth, lower_bound, reinterpret_cast(prob), labels, reinterpret_cast(dy), reinterpret_cast(dx)); } } else { if (elem_cnt < GetMaxVal() / 2) { ComputeDiffWithSoftmaxGpuHalf <<As()->cuda_stream()>>>( elem_cnt, num_classes, depth, lower_bound, reinterpret_cast(prob), labels, reinterpret_cast(dy), reinterpret_cast(dx)); } else { ComputeDiffWithSoftmaxGpuHalf <<As()->cuda_stream()>>>( elem_cnt, num_classes, depth, lower_bound, reinterpret_cast(prob), labels, reinterpret_cast(dy), reinterpret_cast(dx)); } } } }; #define INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CUDA(data_type_pair, index_type_pair) \ template struct SparseCrossEntropyKernelUtil< \ DeviceType::kCUDA, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(index_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CUDA } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sparse_cross_entropy_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { namespace user_op { template struct SparseCrossEntropyKernelUtil { static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* x, const K* labels, T* y); static void ComputeDiff(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* x, const K* labels, const T* dy, T* dx); static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* prob, const K* labels, const T* dy, T* dx); }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/user/kernels/sparse_cross_entropy_kernel_util.h" #include "oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/ep/include/primitive/log_softmax.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewLogSoftmaxPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("prediction", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } auto LogSoftmaxPrimitiveExists() { return hob::make_custom("LogSoftmaxPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewLogSoftmaxPrimitive(&ctx).operator bool(); }); } class SparseSoftmaxCrossEntropyOpKernelCache final : public user_op::OpKernelCache { public: SparseSoftmaxCrossEntropyOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} ~SparseSoftmaxCrossEntropyOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } private: const int64_t lower_; const int64_t upper_; }; } // namespace template class SparseSoftmaxCrossEntropyKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SparseSoftmaxCrossEntropyKernel() = default; ~SparseSoftmaxCrossEntropyKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t num_instances = label->shape_view().elem_cnt(); CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances; const int64_t lower_bound = 0; const int64_t depth = ctx->Attr("depth"); std::unique_ptr primitive = NewLogSoftmaxPrimitive(ctx); CHECK(primitive); primitive->Launch(ctx->stream(), num_instances, num_classes, prediction->dptr(), prob->mut_dptr()); const K* labels = label->dptr(); const T* prob_ptr = prob->dptr(); T* out_ptr = out->mut_dptr(); FOR_RANGE(int64_t, i, 0, num_instances) { CHECK_GE(labels[i], 0); CHECK_LT(labels[i], depth); K _label = labels[i] - lower_bound; if (_label >= 0 && _label < num_classes) { out_ptr[i] = -prob_ptr[i * num_classes + _label]; } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class SparseSoftmaxCrossEntropyMsKernel final : public user_op::OpKernel { public: SparseSoftmaxCrossEntropyMsKernel() = default; ~SparseSoftmaxCrossEntropyMsKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { LOG(FATAL) << "SparseSoftmaxCrossEntropyMsKernel should be split to ops"; 
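    // NOTE(added for clarity, not in the original source): this Compute is intentionally a hard
    // failure. The "sparse_softmax_cross_entropy_ms" op is presumably rewritten into simpler ops
    // before execution, so reaching this point means that rewrite did not happen.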
} bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL(kernel_class, kernel_name, device_type_v, \ dtype_pair, ltype_pair) \ REGISTER_USER_KERNEL(kernel_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ && LogSoftmaxPrimitiveExists()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL, (SparseSoftmaxCrossEntropyKernel), ("sparse_softmax_cross_entropy"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL, (SparseSoftmaxCrossEntropyMsKernel), ("sparse_softmax_cross_entropy_ms"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL, (SparseSoftmaxCrossEntropyMsKernel), ("sparse_softmax_cross_entropy_ms"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif template class SparseSoftmaxCrossEntropyGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SparseSoftmaxCrossEntropyGradKernel() = default; ~SparseSoftmaxCrossEntropyGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex("prediction_diff", 0); const int64_t num_instances = label->shape_view().elem_cnt(); CHECK_EQ(prob->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prob->shape_view().elem_cnt() / num_instances; const int64_t lower_bound = 0; const int64_t depth = ctx->Attr("depth"); SparseSoftmaxCrossEntropyKernelUtil::ComputeDiff( ctx->stream(), prediction_diff->shape_view().elem_cnt(), num_classes, depth, lower_bound, prob->dptr(), label->dptr(), dy->dptr(), prediction_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class SparseSoftmaxCrossEntropyMsGradKernel final : public user_op::OpKernel { public: SparseSoftmaxCrossEntropyMsGradKernel() = default; ~SparseSoftmaxCrossEntropyMsGradKernel() = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("prob", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); const TensorDesc* prob_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("prob", 0); const int64_t class_axis = prob_logical_desc->shape().NumAxes() - 1; TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, nd_sbp, prob_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); return std::make_shared(view.At(class_axis).begin(), view.At(class_axis).end()); } else { return nullptr; } } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const 
user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex("prediction_diff", 0); const int64_t num_instances = label->shape_view().elem_cnt(); CHECK_EQ(prob->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prob->shape_view().elem_cnt() / num_instances; const int64_t depth = ctx->Attr("depth"); int64_t lower_bound = 0; if (cache != nullptr) { auto* kernel_cache = dynamic_cast(cache); CHECK_NOTNULL(kernel_cache); CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower()); lower_bound = kernel_cache->lower(); } SparseCrossEntropyKernelUtil::ComputeDiffWithSoftmax( ctx->stream(), prediction_diff->shape_view().elem_cnt(), num_classes, depth, lower_bound, prob->dptr(), label->dptr(), dy->dptr(), prediction_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL(kernel_class, kernel_name, \ device_type_v, dtype_pair, ltype_pair) \ REGISTER_USER_KERNEL(kernel_name) \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ && (user_op::HobDataType("prediction_diff", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("prediction_diff", 0, "prob", 0, true)); \ return Maybe::Ok(); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL, (SparseSoftmaxCrossEntropyGradKernel), ("sparse_softmax_cross_entropy_grad"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL, (SparseSoftmaxCrossEntropyGradKernel), ("sparse_softmax_cross_entropy_grad"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL, (SparseSoftmaxCrossEntropyMsGradKernel), ("sparse_softmax_cross_entropy_ms_grad"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL, (SparseSoftmaxCrossEntropyMsGradKernel), ("sparse_softmax_cross_entropy_ms_grad"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #endif } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/sparse_cross_entropy_kernel_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { template void ComputeProb(ep::Stream* stream, const int64_t row, const int64_t col, const T* in, T* prob) { using ComputeType = typename cuda::softmax::DefaultComputeType::type; cuda::softmax::DirectLoad load(in, col); cuda::softmax::DirectStore store(prob, col); OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax( stream->As()->cuda_stream(), load, store, row, col))); } template<> void ComputeProb(ep::Stream* stream, const int64_t row, const int64_t col, const float16* in, float16* prob) { cuda::softmax::DirectLoad load(reinterpret_cast(in), col); cuda::softmax::DirectStore store(reinterpret_cast(prob), col); OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax( stream->As()->cuda_stream(), load, store, row, col))); } template __global__ void ComputeSparseSoftmaxCrossEntropyResultGpu(const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const K* labels, const T* prob, T* out) { CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) { assert(labels[i] >= 0); assert(labels[i] < depth); K label = labels[i] - lower_bound; if (label >= 0 && label < num_classes) { out[i] = -prob[i * num_classes + label]; } } } template inline typename std::enable_if::value, void>::type ComputeSparseSoftmaxCrossEntropyResult(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const K* labels, const T* prob, T* out) { ComputeSparseSoftmaxCrossEntropyResultGpu <<As()->cuda_stream()>>>(num_instances, num_classes, depth, lower_bound, labels, prob, out); } template inline typename std::enable_if::value, void>::type ComputeSparseSoftmaxCrossEntropyResult(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const K* labels, const T* prob, T* out) { #if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) ComputeSparseSoftmaxCrossEntropyResultGpu <<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, labels, reinterpret_cast(prob), reinterpret_cast(out)); #else printf("use half need nvcc arch >= 530"); assert(false); #endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ } } // namespace template class SparseSoftmaxCrossEntropyKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SparseSoftmaxCrossEntropyKernel() = default; ~SparseSoftmaxCrossEntropyKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t num_instances = label->shape_view().elem_cnt(); CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0); const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances; const int64_t lower_bound = 0; const int64_t depth = ctx->Attr("depth"); ComputeProb(ctx->stream(), num_instances, num_classes, prediction->dptr(), prob->mut_dptr()); ComputeSparseSoftmaxCrossEntropyResult(ctx->stream(), 
num_instances, num_classes, depth, lower_bound, label->dptr(), prob->dptr(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL(dtype_pair, ltype_pair) \ REGISTER_USER_KERNEL("sparse_softmax_cross_entropy") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.h" namespace oneflow { namespace user_op { template struct SparseSoftmaxCrossEntropyKernelUtil { static void ComputeDiff(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* prob, const K* labels, const T* dy, T* dx) { FOR_RANGE(int64_t, i, 0, num_instances) { const int32_t row_id = i / num_classes; const int32_t col_id = i - row_id * num_classes; CHECK_GE(labels[row_id], 0); CHECK_LT(labels[row_id], depth); K label = labels[row_id] - lower_bound; if (label == col_id) { dx[i] = dy[row_id] * (std::exp(prob[i]) - 1); } else { dx[i] = dy[row_id] * std::exp(prob[i]); } } } }; #define INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CPU(data_type_pair, index_type_pair) \ template struct SparseSoftmaxCrossEntropyKernelUtil< \ DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(index_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CPU } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
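The CPU ComputeDiff above relies on prob holding log-softmax values, so std::exp(prob[i]) recovers the softmax probability and the update is the standard identity dx = dy * (softmax(x) - onehot(label)). A scalar reference of that identity, written only as a sanity-check sketch with hypothetical names (not part of the kernels in this section):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// For one row: d(-log softmax(x)[y]) / dx[j] = softmax(x)[j] - (j == y), scaled by dy.
std::vector<double> SparseSoftmaxXentGradRow(const std::vector<double>& x, int64_t y, double dy) {
  double max_x = x.empty() ? 0.0 : x[0];
  for (double v : x) { max_x = std::max(max_x, v); }
  double denom = 0.0;
  for (double v : x) { denom += std::exp(v - max_x); }
  std::vector<double> dx(x.size());
  for (size_t j = 0; j < x.size(); ++j) {
    const double p = std::exp(x[j] - max_x) / denom;  // softmax(x)[j], i.e. exp(log_prob[j])
    dx[j] = dy * (p - (static_cast<int64_t>(j) == y ? 1.0 : 0.0));
  }
  return dx;
}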
*/ #include "oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.h" #include "oneflow/core/cuda/softmax.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { template __inline__ __device__ T Exp(T x); template<> __inline__ __device__ float Exp(float x) { #ifdef OF_SOFTMAX_USE_FAST_MATH return __expf(x); #else return exp(x); #endif } template<> __inline__ __device__ double Exp(double x) { return exp(x); } template<> __inline__ __device__ half Exp(half x) { #ifdef OF_SOFTMAX_USE_FAST_MATH return __float2half(__expf(__half2float(x))); #else return __float2half(exp(__half2float(x))); #endif } template __global__ void ComputeDiffGpu(const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* prob, const K* labels, const T* dy, T* dx) { CUDA_1D_KERNEL_LOOP_T(IndexType, i, num_instances) { const IndexType row_id = i / num_classes; const IndexType col_id = i - row_id * num_classes; assert(labels[row_id] >= 0); assert(labels[row_id] < depth); K label = labels[row_id] - lower_bound; if (label == col_id) { dx[i] = dy[row_id] * (Exp(prob[i]) - 1); } else { dx[i] = dy[row_id] * Exp(prob[i]); } } } template __global__ void ComputeDiffGpuHalf(const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const half* prob, const K* labels, const half* dy, half* dx) { CUDA_1D_KERNEL_LOOP_T(IndexType, i, num_instances) { const IndexType row_id = i / num_classes; const IndexType col_id = i - row_id * num_classes; assert(labels[row_id] >= 0); assert(labels[row_id] < depth); K label = labels[row_id] - lower_bound; if (label == col_id) { dx[i] = __hmul(dy[row_id], __hsub(Exp(prob[i]), __float2half(1.0))); } else { dx[i] = __hmul(dy[row_id], Exp(prob[i])); } } } } // namespace template struct SparseSoftmaxCrossEntropyKernelUtil { static void ComputeDiff(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* prob, const K* labels, const T* dy, T* dx) { if (num_instances < GetMaxVal() / 2) { ComputeDiffGpu<<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, prob, labels, dy, dx); } else { // NOTE(chengcheng): int division ('/') of i will reduce performance of int64_t. 
ComputeDiffGpu<<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, prob, labels, dy, dx); } } }; template struct SparseSoftmaxCrossEntropyKernelUtil { static void ComputeDiff(ep::Stream* stream, const int64_t num_instances, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const float16* prob, const K* labels, const float16* dy, float16* dx) { if (num_instances < GetMaxVal() / 2) { ComputeDiffGpuHalf<<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, reinterpret_cast(prob), labels, reinterpret_cast(dy), reinterpret_cast(dx)); } else { ComputeDiffGpuHalf<<As()->cuda_stream()>>>( num_instances, num_classes, depth, lower_bound, reinterpret_cast(prob), labels, reinterpret_cast(dy), reinterpret_cast(dx)); } } }; #define INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CUDA(data_type_pair, index_type_pair) \ template struct SparseSoftmaxCrossEntropyKernelUtil< \ DeviceType::kCUDA, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(index_type_pair)>; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CUDA } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_H_ #include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { namespace user_op { template struct SparseSoftmaxCrossEntropyKernelUtil { static void ComputeDiff(ep::Stream* stream, const int64_t elem_cnt, const int64_t num_classes, const int64_t depth, const int64_t lower_bound, const T* prob, const K* labels, const T* dy, T* dx); }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/split_like_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
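As the NOTE in the CUDA ComputeDiff above points out, every element computes row_id = i / num_classes, and 64-bit integer division is markedly slower on GPUs than 32-bit, so the launcher uses int32_t loop indices whenever num_instances stays safely inside the int32 range and falls back to int64_t otherwise. A minimal sketch of that dispatch pattern, with hypothetical names:

#include <cstdint>
#include <limits>

// Stand-in for launching the kernel instantiated with IndexType as its loop-index type.
template<typename IndexType>
void RunWithIndexType(int64_t elem_cnt) {
  (void)elem_cnt;  // the real code launches ComputeDiffGpu<T, K, IndexType> here
}

inline void DispatchIndexType(int64_t elem_cnt) {
  // Mirrors the `num_instances < GetMaxVal() / 2` guard in the kernel util above.
  if (elem_cnt < std::numeric_limits<int32_t>::max() / 2) {
    RunWithIndexType<int32_t>(elem_cnt);  // cheap 32-bit division per element
  } else {
    RunWithIndexType<int64_t>(elem_cnt);  // correctness fallback for huge tensors
  }
}

The split keeps the common case fast while still handling element counts that do not fit in 32 bits.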
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" namespace oneflow { namespace { template std::unique_ptr NewCopyNdPrimitive(Context* ctx) { return ep::primitive::NewPrimitive(ctx->device_type(), 2); } class SplitLikeKernel final : public user_op::OpKernel { public: SplitLikeKernel() = default; ~SplitLikeKernel() override = default; private: void InferShape(user_op::KernelInferContext* ctx) const override { const auto axis = ctx->Attr("axis"); const ShapeView& in_shape_view = ctx->ShapeView4ArgNameAndIndex("in", 0); int64_t total_dim_size = 0; const int64_t like_num_axes = ctx->ShapeView4ArgNameAndIndex("like", 0).NumAxes(); const int64_t in_num_axes = in_shape_view.NumAxes(); CHECK_LE(like_num_axes, in_num_axes); CHECK_LT(axis, like_num_axes); FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { const ShapeView& like_shape_view = ctx->ShapeView4ArgNameAndIndex("like", i); CHECK_EQ(like_shape_view.NumAxes(), like_num_axes); FOR_RANGE(int64_t, j, 0, like_num_axes) { if (j == axis) { total_dim_size += like_shape_view.At(j); } else { CHECK_EQ(like_shape_view.At(j), in_shape_view.At(j)); } } if (ctx->TensorDesc4ArgNameAndIndex("out", i)->is_dynamic()) { auto mut_shape_view = ctx->MutShapeView4ArgNameAndIndex("out", i); DimVector out_i_dim_vec; like_shape_view.ToDimVector(&out_i_dim_vec); FOR_RANGE(int64_t, j, like_num_axes, in_num_axes) { out_i_dim_vec.emplace_back(in_shape_view.At(j)); } mut_shape_view.set_shape(Shape(out_i_dim_vec)); } } CHECK_EQ(total_dim_size, in_shape_view.At(axis)); } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); const auto axis = ctx->Attr("axis"); const int64_t in_cols = in_tensor->shape_view().Count(axis); const int64_t rows = in_tensor->shape_view().elem_cnt() / in_cols; CHECK_GT(rows, 0); auto primitive = NewCopyNdPrimitive(ctx); CHECK(primitive); int64_t in_col_offset = 0; for (const auto& out_arg_pair : ctx->outputs()) { user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second); const int64_t out_cols = out_tensor->shape_view().Count(axis); CHECK_EQ(out_tensor->shape_view().elem_cnt(), rows * out_cols); if (out_cols > 0) { DimVector dst_shape = {rows, out_cols}; DimVector dst_pos_vec = {0, 0}; DimVector src_shape = {rows, in_cols}; DimVector src_pos_vec = {0, in_col_offset}; DimVector extent_vec = {rows, out_cols}; primitive->Launch(ctx->stream(), out_tensor->data_type(), 2, out_tensor->mut_dptr(), dst_shape.data(), dst_pos_vec.data(), in_tensor->dptr(), src_shape.data(), src_pos_vec.data(), extent_vec.data()); } in_col_offset += out_cols; } CHECK_EQ(in_col_offset, in_cols); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto CopyNdPrimitiveExists() { return hob::make_custom("CopyNdPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewCopyNdPrimitive(&ctx).operator bool(); }); } } // namespace REGISTER_USER_KERNEL("split_like") .SetCreateFn() .SetIsMatchedHob(CopyNdPrimitiveExists() == true); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sqrt_square_sum_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/sqrt_square_sum_kernel_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace user_op { int64_t getThreadNumBlocks(int64_t n) { int64_t num_blocks = 1; #ifdef WITH_CUDA num_blocks = BlocksNum4ThreadsNum(n); #endif return num_blocks; } template class SqrtSquareSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SqrtSquareSumKernel() = default; ~SqrtSquareSumKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); SqrtSquareSumKernelUtil::SqrtSquareSum(ctx->stream(), x->shape_view().elem_cnt(), x->dptr(), y->mut_dptr(), tmp->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SQUARE_SUM_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("sqrt_square_sum") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const auto& x_shape = ctx->InputTensorDesc("x", 0).shape(); \ const int32_t num_blocks = getThreadNumBlocks(x_shape.Count(0)); \ int64_t tmp_buffer_size = num_blocks; \ return tmp_buffer_size * sizeof(OF_PP_PAIR_FIRST(dtype)); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SQUARE_SUM_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sqrt_square_sum_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
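The SetInferTmpSizeFn above reserves num_blocks elements of the output type because the multi-block CUDA path (in the .cu file later in this section) first writes one partial sum of squares per thread block into tmp and then lets a single block add them up and take the square root; on CPU-only builds getThreadNumBlocks returns 1 and the buffer degenerates to a single element. A minimal sketch of the sizing rule, with hypothetical stand-ins for the framework's launch constants:

#include <algorithm>
#include <cstddef>
#include <cstdint>

constexpr int64_t kThreadsPerBlock = 512;  // stand-in for kCudaThreadsNumPerBlock
constexpr int64_t kMaxBlocks = 8192;       // stand-in for the framework's block-count cap

inline int64_t NumBlocksFor(int64_t n) {
  return std::min((n + kThreadsPerBlock - 1) / kThreadsPerBlock, kMaxBlocks);
}

inline size_t TmpBufferBytes(int64_t elem_cnt, size_t elem_size) {
  return static_cast<size_t>(NumBlocksFor(elem_cnt)) * elem_size;  // num_blocks * sizeof(T)
}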
*/ #include "oneflow/user/kernels/sqrt_square_sum_kernel_util.h" namespace oneflow { template struct SqrtSquareSumKernelUtil { static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp) { T sum = 0; FOR_RANGE(int64_t, i, 0, n) { sum += x[i] * x[i]; } *y = std::sqrt(sum); } }; #define INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU(type_cpp, type_proto) \ template struct SqrtSquareSumKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ); #undef INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sqrt_square_sum_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/sqrt_square_sum_kernel_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { template __global__ void SqrtSquareSumForOneThreadBlock(int64_t n, const T* x, T* y) { T t_sum = 0; CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i] * x[i]; } typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; T b_sum = BlockReduce(temp_storage).Sum(t_sum); if (threadIdx.x == 0) { *y = sqrt(b_sum); } } template __global__ void SqrtSumForMultiThreadBlock(int64_t n, const T* x, T* y) { T t_sum = 0; CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i]; } typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; T b_sum = BlockReduce(temp_storage).Sum(t_sum); if (threadIdx.x == 0) { *y = sqrt(b_sum); } } template __global__ void SquareSumForMultiThreadBlock(int64_t n, const T* x, T* tmp) { T t_sum = 0; CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i] * x[i]; } typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; T b_sum = BlockReduce(temp_storage).Sum(t_sum); if (threadIdx.x == 0) { tmp[blockIdx.x] = b_sum; } } } // namespace template struct SqrtSquareSumKernelUtil { static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp) { const int32_t num_blocks = BlocksNum4ThreadsNum(n); CHECK_GE(num_blocks, 0); if (num_blocks == 1) { SqrtSquareSumForOneThreadBlock <<<1, kCudaThreadsNumPerBlock, 0, stream->As()->cuda_stream()>>>(n, x, y); } else { Memset(stream, y, 0, sizeof(T)); SquareSumForMultiThreadBlock <<As()->cuda_stream()>>>( n, x, tmp); SqrtSumForMultiThreadBlock <<<1, kCudaThreadsNumPerBlock, 0, stream->As()->cuda_stream()>>>( num_blocks, tmp, y); } } }; #define INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA(type_cpp, type_proto) \ template struct SqrtSquareSumKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ); #undef INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/sqrt_square_sum_kernel_util.h ================================================ /* Copyright 2020 
The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct SqrtSquareSumKernelUtil { static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/square_sum_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/square_sum_kernel_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace user_op { template class SquareSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: SquareSumKernel() = default; ~SquareSumKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); SquareSumKernelUtil::SquareSum(ctx->stream(), x->shape_view().elem_cnt(), x->dptr(), y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_SQUARE_SUM_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("square_sum") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SQUARE_SUM_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ) template class MultiSquareSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MultiSquareSumKernel() = default; ~MultiSquareSumKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { std::vector> params; params.resize(ctx->input_size("x")); for (int64_t i = 0; i < params.size(); ++i) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", i); params[i].count = x->shape_view().elem_cnt(); params[i].ptr = x->dptr(); } user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); SquareSumKernelUtil::MultiSquareSum(ctx->stream(), params, y->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define 
REGISTER_MULTI_SQUARE_SUM_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("multi_square_sum") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MULTI_SQUARE_SUM_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/square_sum_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/square_sum_kernel_util.h" namespace oneflow { template struct SquareSumKernelUtil { static void SquareSum(ep::Stream* stream, int64_t n, const T* x, T* y) { T sum = 0; FOR_RANGE(int64_t, i, 0, n) { sum += x[i] * x[i]; } *y = sum; } static void MultiSquareSum(ep::Stream* stream, const std::vector>& params, T* y) { T sum = 0; FOR_RANGE(int64_t, i, 0, params.size()) { const auto& p = params[i]; FOR_RANGE(int64_t, j, 0, p.count) { sum += p.ptr[j] * p.ptr[j]; } } *y = sum; } }; #define INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU(type_cpp, type_proto) \ template struct SquareSumKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ); #undef INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/square_sum_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/square_sum_kernel_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { template __global__ void SquareSumGpu(int64_t n, const T* x, T* y) { T t_sum = 0; CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i] * x[i]; } typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; T b_sum = BlockReduce(temp_storage).Sum(t_sum); if (threadIdx.x == 0) { if (ONE_BLOCK) { *y = b_sum; } else { cuda::atomic::Add(y, b_sum); } } } constexpr int64_t kMultiSquareSumMaxSize = 64; template struct MultiSquareSumParams { SquareSumParam params[kMultiSquareSumMaxSize]; int32_t size; }; template __global__ void MultiSquareSumGpu(const MultiSquareSumParams params, T* y) { T t_sum = 0; for (int i = 0; i < params.size; ++i) { const SquareSumParam param = params.params[i]; CUDA_1D_KERNEL_LOOP(j, param.count) { t_sum += param.ptr[j] * param.ptr[j]; } } typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; T b_sum = BlockReduce(temp_storage).Sum(t_sum); if (threadIdx.x == 0) { cuda::atomic::Add(y, b_sum); } } } // namespace template struct SquareSumKernelUtil { static void SquareSum(ep::Stream* stream, int64_t n, const T* x, T* y) { const int32_t num_blocks = BlocksNum4ThreadsNum(n); CHECK_GE(num_blocks, 0); if (num_blocks == 0) { Memset(stream, y, 0, sizeof(T)); } else if (num_blocks == 1) { SquareSumGpu <<<1, kCudaThreadsNumPerBlock, 0, stream->As()->cuda_stream()>>>(n, x, y); } else { Memset(stream, y, 0, sizeof(T)); SquareSumGpu <<As()->cuda_stream()>>>( n, x, y); } } static void MultiSquareSum(ep::Stream* stream, const std::vector>& params, T* y) { Memset(stream, y, 0, sizeof(T)); for (int64_t start = 0; start < params.size(); start += kMultiSquareSumMaxSize) { MultiSquareSumParams gpu_params{}; int64_t max_count = 0; gpu_params.size = std::min(start + kMultiSquareSumMaxSize, params.size()) - start; for (int64_t i = 0; i < gpu_params.size; ++i) { gpu_params.params[i] = params[start + i]; max_count = std::max(max_count, gpu_params.params[i].count); } MultiSquareSumGpu<<As()->cuda_stream()>>>(gpu_params, y); } } }; #define INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CUDA(type_cpp, type_proto) \ template struct SquareSumKernelUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ); #undef INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/square_sum_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#ifndef ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_
#define ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_

#include "oneflow/core/kernel/kernel_util.h"

namespace oneflow {

template<typename T>
struct SquareSumParam {
  const T* ptr;
  int64_t count;
};

template<DeviceType device_type, typename T>
struct SquareSumKernelUtil {
  static void SquareSum(ep::Stream* stream, int64_t n, const T* x, T* y);
  static void MultiSquareSum(ep::Stream* stream, const std::vector<SquareSumParam<T>>& params,
                             T* y);
};

}  // namespace oneflow

#endif  // ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_

================================================
FILE: oneflow/user/kernels/ssp_variable_proxy_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"

namespace oneflow {

namespace {

template<DeviceType device_type>
class SspVariableProxyKernel final : public user_op::OpKernel {
 public:
  SspVariableProxyKernel() = default;
  ~SspVariableProxyKernel() override = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* var = ctx->Tensor4ArgNameAndIndex("var", 0);
    const user_op::Tensor* ref = ctx->Tensor4ArgNameAndIndex("ref", 0);
    CHECK_EQ(var->dptr(), ref->dptr());
    user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0);
    const ShapeView& in_shape = ref->shape_view();
    CHECK_EQ(value->shape_view(), in_shape);
    const DataType in_data_type = ref->data_type();
    CHECK_EQ(value->data_type(), in_data_type);
    Memcpy<device_type>(ctx->stream(), value->mut_dptr(), ref->dptr(),
                        in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_SSP_VARIABLE_PROXY_KERNEL(device)                                              \
  REGISTER_USER_KERNEL("ssp_variable_proxy")                                                    \
      .SetCreateFn<SspVariableProxyKernel<device>>()                                            \
      .SetIsMatchedHob(user_op::HobDeviceType() == device)                                      \
      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \
                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
        OF_RETURN_IF_ERROR(AddInplaceArgPairFn("ref", 0, "var", 0, true));                      \
        return Maybe<void>::Ok();                                                               \
      });

REGISTER_SSP_VARIABLE_PROXY_KERNEL(DeviceType::kCPU)

#ifdef WITH_CUDA
REGISTER_SSP_VARIABLE_PROXY_KERNEL(DeviceType::kCUDA)
#endif

}  // namespace

}  // namespace oneflow

================================================
FILE: oneflow/user/kernels/stack_kernel.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/copy_nd.h" namespace oneflow { namespace { template std::unique_ptr NewCopyNdPrimitive(Context* ctx) { return ep::primitive::NewPrimitive(ctx->device_type(), 2); } class StackKernel final : public user_op::OpKernel { public: StackKernel() = default; ~StackKernel() = default; private: void InferShape(user_op::KernelInferContext* ctx) const override { const ShapeView& first_input_shape_view = ctx->ShapeView4ArgNameAndIndex("in", 0); const int64_t axis = ctx->Attr("axis"); const int64_t in_num_axes = first_input_shape_view.NumAxes(); DimVector out_dim_vec(in_num_axes + 1); for (int i = 0; i < in_num_axes + 1; i++) { if (i == axis) { continue; } else { out_dim_vec.at(i) = first_input_shape_view.At(i); } } for (const auto& in_arg_pair : ctx->inputs()) { const ShapeView& input_shape_view = ctx->ShapeView4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second); CHECK_EQ(input_shape_view.NumAxes(), first_input_shape_view.NumAxes()); FOR_RANGE(int64_t, i, 0, in_num_axes + 1) { if (i == axis) { out_dim_vec.at(axis) += 1; } else if (i < axis) { CHECK_EQ(input_shape_view.At(i), out_dim_vec.at(i)) << " Stack expects each tensor to be equal size" ", but got " << first_input_shape_view.ToString() << " at first input and " << input_shape_view.ToString(); } else { CHECK_EQ(input_shape_view.At(i - 1), out_dim_vec.at(i)) << " Stack expects each tensor to be equal size" ", but got " << first_input_shape_view.ToString() << " at first input and " << input_shape_view.ToString(); } } } ctx->MutShapeView4ArgNameAndIndex("out", 0).set_shape(Shape(out_dim_vec)); } void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); if (out_tensor->shape_view().elem_cnt() == 0) { return; } const int64_t axis = ctx->Attr("axis"); const int64_t out_cols = out_tensor->shape_view().Count(axis); const int64_t rows = out_tensor->shape_view().Count(0, axis); CHECK_GT(rows, 0) << "The multiplicative from axis 0 to axis " << axis - 1 << " should be greater than 0. "; auto primitive = NewCopyNdPrimitive(ctx); CHECK(primitive) << "Error in Stack kernel NewCopyNdPrimitive. "; int64_t out_col_offset = 0; for (const auto& in_arg_pair : ctx->inputs()) { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second); if (in_tensor->shape_view().elem_cnt() == 0) { continue; } const int64_t in_cols = in_tensor->shape_view().Count(axis); CHECK_EQ(in_tensor->shape_view().elem_cnt(), rows * in_cols) << "The element count of input tensor is not equal to `rows * in_cols`. "; if (in_cols > 0) { DimVector dst_shape = {rows, out_cols}; DimVector dst_pos_vec = {0, out_col_offset}; DimVector src_shape = {rows, in_cols}; DimVector src_pos_vec = {0, 0}; DimVector extent_vec = {rows, in_cols}; primitive->Launch(ctx->stream(), out_tensor->data_type(), 2, out_tensor->mut_dptr(), dst_shape.data(), dst_pos_vec.data(), in_tensor->dptr(), src_shape.data(), src_pos_vec.data(), extent_vec.data()); } out_col_offset += in_cols; } CHECK_EQ(out_col_offset, out_cols) << "The out column offset is not equal to out columns. 
"; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto CopyNdPrimitiveExists() { return hob::make_custom("CopyNdPrimitiveExists", [](const user_op::KernelRegContext& ctx) -> bool { return NewCopyNdPrimitive(&ctx).operator bool(); }); } } // namespace REGISTER_USER_KERNEL("stack").SetCreateFn().SetIsMatchedHob(CopyNdPrimitiveExists() == true); class StackGradKernel final : public user_op::OpKernel { public: StackGradKernel() = default; ~StackGradKernel() override = default; private: void InferShape(user_op::KernelInferContext* ctx) const override { const auto axis = ctx->Attr("axis"); const ShapeView& in_shape_view = ctx->ShapeView4ArgNameAndIndex("in", 0); int64_t total_dim_size = 0; const int64_t like_num_axes = ctx->ShapeView4ArgNameAndIndex("like", 0).NumAxes(); const int64_t in_num_axes = in_shape_view.NumAxes(); CHECK_LE(like_num_axes, in_num_axes) << "The num axes of `like` tensor should be less equal to num axes of `in` tensor. "; CHECK_LE(axis, like_num_axes) << "The axis should be less than or equal to num axes of `like` tensor. "; FOR_RANGE(size_t, i, 0, ctx->outputs().size()) { const ShapeView& like_shape_view = ctx->ShapeView4ArgNameAndIndex("like", i); CHECK_EQ(like_shape_view.NumAxes(), like_num_axes) << "The num axes of `like` tensor at index " << i << " should be equal to first `like` tensor. "; FOR_RANGE(int64_t, j, 0, like_num_axes + 1) { if (j == axis) { total_dim_size += like_shape_view.Count(j); } else if (j < axis) { CHECK_EQ(in_shape_view.At(j), like_shape_view.At(j)) << " Stack Grad expects the shape of input tensor is equal to like tensor's. " ", but got " << in_shape_view.ToString() << " at input and " << like_shape_view.ToString() << "at like "; } else { CHECK_EQ(in_shape_view.At(j), like_shape_view.At(j - 1)) << " Stack Grad expects the shape of input tensor is equal to like tensor's. " ", but got " << in_shape_view.ToString() << " at input and " << like_shape_view.ToString() << "at like "; } } if (ctx->TensorDesc4ArgNameAndIndex("out", i)->is_dynamic()) { auto mut_shape_view = ctx->MutShapeView4ArgNameAndIndex("out", i); DimVector out_i_dim_vec; like_shape_view.ToDimVector(&out_i_dim_vec); mut_shape_view.set_shape(Shape(out_i_dim_vec)); } } CHECK_EQ(total_dim_size, in_shape_view.Count(axis)) << "The sum of dim size of each `like` tensor should be equal to `in` tensor count from " "axis " << axis; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); const int64_t axis = ctx->Attr("axis"); const int64_t in_cols = in_tensor->shape_view().Count(axis); const int64_t rows = in_tensor->shape_view().Count(0, axis); CHECK_GT(rows, 0) << "The multiplicative from axis 0 to axis " << axis - 1 << " should be greater than 0. "; auto primitive = NewCopyNdPrimitive(ctx); CHECK(primitive) << "Error in Stack Grad kernel NewCopyNdPrimitive. "; int64_t in_col_offset = 0; for (const auto& out_arg_pair : ctx->outputs()) { user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second); const int64_t out_cols = out_tensor->shape_view().Count(axis); CHECK_EQ(out_tensor->shape_view().elem_cnt(), rows * out_cols) << "The element count of output tensor is not equal to `rows * out_cols`. 
"; if (out_cols > 0) { DimVector dst_shape = {rows, out_cols}; DimVector dst_pos_vec = {0, 0}; DimVector src_shape = {rows, in_cols}; DimVector src_pos_vec = {0, in_col_offset}; DimVector extent_vec = {rows, out_cols}; primitive->Launch(ctx->stream(), out_tensor->data_type(), 2, out_tensor->mut_dptr(), dst_shape.data(), dst_pos_vec.data(), in_tensor->dptr(), src_shape.data(), src_pos_vec.data(), extent_vec.data()); } in_col_offset += out_cols; } CHECK_EQ(in_col_offset, in_cols) << "The in column offset is not equal to in columns."; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("stack_grad") .SetCreateFn() .SetIsMatchedHob(CopyNdPrimitiveExists() == true); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/stateful_opkernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/compute_complexity_fn_context.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/framework/global_tensor_infer_cache.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/profiler/profile_manager.h" #include "oneflow/core/profiler/event_recorder.h" #include "oneflow/core/eager/call_context.h" namespace oneflow { namespace one { class GlobalTensorInferResult; using ArgVec = std::vector>; using EagerBlobObjectListRawPtr = const std::vector>*; using GlobalTensorInferResultRawPtr = const GlobalTensorInferResult*; class ZeroCopyBaseContextHelper { public: ZeroCopyBaseContextHelper(const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple) : input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {} #define RETURN_IF_FOUND(inputs, outputs, post_action) \ int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), \ arg_name, index); \ if (i >= 0) { return (inputs).at(i) post_action; } \ i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \ index); \ if (i >= 0) { return (outputs).at(i) post_action; } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, const int32_t index) const { RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get()); return nullptr; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, const int32_t index) const { RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get()); return nullptr; } user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const 
std::string& arg_name, const int32_t index) const { RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get()); if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); } return nullptr; } const GlobalTensorMeta* GlobalTensorMeta4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, const int32_t index) const { const auto& global_tensor_infer_result = call_ctx->global_tensor_infer_result(); RETURN_IF_FOUND(global_tensor_infer_result->input_tensor_metas(), global_tensor_infer_result->output_tensor_metas(), .shared_from_symbol().get()); return nullptr; } Optional> parallel_desc(eager::CallContext* call_ctx) const { const auto& global_tensor_infer_result = call_ctx->global_tensor_infer_result(); if (!global_tensor_infer_result) { return Optional>(); } if (!global_tensor_infer_result->input_tensor_metas().empty()) { return global_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc(); } else if (!global_tensor_infer_result->output_tensor_metas().empty()) { return global_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc(); } else { UNIMPLEMENTED(); return Optional>(); } } const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { const auto& parallel_desc = this->parallel_desc(call_ctx); if (parallel_desc.has_value()) { const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc); return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol)); } else { static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx()); return single_device_parallel_ctx; } } const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); } const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); } private: static int32_t TryGetTensorTupleIndex(const std::unordered_map>& arg_name2bn_index2tensor_tuple_index, const std::string& arg_name, const int32_t arg_index) { auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name); if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); } return -1; } static ParallelContext MakeSingleDeviceParallelCtx() { ParallelContext single_device_parallel_ctx; single_device_parallel_ctx.set_parallel_id(0); single_device_parallel_ctx.set_parallel_num(1); return single_device_parallel_ctx; } std::shared_ptr input_arg_tuple_; std::shared_ptr output_arg_tuple_; }; class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper { public: UserKernelBaseContextHelper(DeviceType device_type, const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple) : ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {} ~UserKernelBaseContextHelper() = default; DeviceType device_type() const { return device_type_; } const JobDesc& job_desc() const { UNIMPLEMENTED(); return *(const JobDesc*)nullptr; } private: const DeviceType device_type_; }; class UserOpInferContextHelper final { public: UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf, const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple) : user_op_conf_(user_op_conf), zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {} ~UserOpInferContextHelper() = default; const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { UNIMPLEMENTED(); return nullptr; } const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx, const 
std::string& arg_name, int32_t index) const { return *TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } const user_op::TensorDesc& OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return *TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } user_op::TensorDesc* MutOutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return zero_copy_base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return Shape4ArgNameAndIndex(call_ctx, arg_name, index); } const Shape& OutputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return Shape4ArgNameAndIndex(call_ctx, arg_name, index); } void SetOutputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index, const Shape& shape) const { SetShape4ArgNameAndIndex(call_ctx, arg_name, index, shape); } const Shape& Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).shape(); } void SetShape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index, const Shape& shape) const { return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->set_shape(shape); } const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return Stride4ArgNameAndIndex(call_ctx, arg_name, index); } const Stride& OutputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return Stride4ArgNameAndIndex(call_ctx, arg_name, index); } void SetOutputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index, const Stride& stride) const { return SetStride4ArgNameAndIndex(call_ctx, arg_name, index, stride); } const Stride& Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).stride(); } void SetStride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index, const Stride& stride) const { return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->set_stride(stride); } DataType InputDType(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return Dtype4ArgNameAndIndex(call_ctx, arg_name, index); } DataType OutputDType(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return Dtype4ArgNameAndIndex(call_ctx, arg_name, index); } void SetOutputDType(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index, DataType data_type) const { return SetDtype4ArgNameAndIndex(call_ctx, arg_name, index, data_type); } DataType Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).data_type(); 
} void SetDtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index, DataType data_type) const { return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index) ->set_data_type(data_type); } bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index); } bool OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index); } void SetOutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index, bool is_dynamic) const { return SetIsDynamic4ArgNameAndIndex(call_ctx, arg_name, index, is_dynamic); } bool IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).is_dynamic(); } void SetIsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index, bool is_dynamic) const { return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index) ->set_is_dynamic(is_dynamic); } const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); } const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); } const JobDesc* job_desc() const { UNIMPLEMENTED(); return nullptr; } const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx); } const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const { return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx)); } const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index); CHECK_EQ(nd_sbp.sbp_parallel_size(), 1); return nd_sbp.sbp_parallel(0); } const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.GlobalTensorMeta4ArgNameAndIndex( call_ctx, arg_name, index)) ->nd_sbp(); } int64_t parallel_num(eager::CallContext* call_ctx) const { return parallel_ctx(call_ctx).parallel_num(); } const std::string& input(const std::string& arg_name, int32_t index) const { return user_op_conf().input(arg_name, index); } const std::string& output(const std::string& arg_name, int32_t index) const { return user_op_conf().output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const { return user_op_conf().has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const { return user_op_conf().has_output(arg_name, index); } int32_t input_size(const std::string& arg_name) const { return user_op_conf().input_size(arg_name); } int32_t output_size(const std::string& arg_name) const { return user_op_conf().output_size(arg_name); } const std::string& op_name() const { return user_op_conf().op_name(); } const std::string& op_type_name() const { return user_op_conf().op_type_name(); } const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); } const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; } const std::shared_ptr& Attr4Name(eager::CallContext* call_ctx, const std::string& attr_name) const { return call_ctx->composed_attrs().Attr4Name(attr_name); } private: const user_op::TensorDesc& 
NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { const user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; } return *tensor_desc; } user_op::TensorDesc* MutNonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { user_op::TensorDesc* tensor_desc = MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; } return tensor_desc; } const user_op::UserOpConfWrapper* user_op_conf_; ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_; }; class UserOpInferContext : public user_op::InferContext { public: UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx) : helper_(helper), call_ctx_(call_ctx) {} ~UserOpInferContext() override = default; const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { return helper_->InputTensorDesc(call_ctx_, arg_name, index); } const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, int32_t index) const override { return helper_->OutputTensorDesc(call_ctx_, arg_name, index); } user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override { return helper_->MutOutputTensorDesc(call_ctx_, arg_name, index); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const Shape& InputShape(const std::string& arg_name, int32_t index) const override { return helper_->InputShape(call_ctx_, arg_name, index); } const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return helper_->OutputShape(call_ctx_, arg_name, index); } void SetOutputShape(const std::string& arg_name, int32_t index, const Shape& shape) override { return helper_->SetOutputShape(call_ctx_, arg_name, index, shape); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index); } void SetShape4ArgNameAndIndex(const std::string& arg_name, int32_t index, const Shape& shape) override { return helper_->SetShape4ArgNameAndIndex(call_ctx_, arg_name, index, shape); } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { return helper_->InputStride(call_ctx_, arg_name, index); } const Stride& OutputStride(const std::string& arg_name, int32_t index) const override { return helper_->InputStride(call_ctx_, arg_name, index); } void SetOutputStride(const std::string& arg_name, int32_t index, const Stride& stride) override { return helper_->SetOutputStride(call_ctx_, arg_name, index, stride); } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index); } void SetStride4ArgNameAndIndex(const 
std::string& arg_name, int32_t index, const Stride& stride) override { return helper_->SetStride4ArgNameAndIndex(call_ctx_, arg_name, index, stride); } DataType InputDType(const std::string& arg_name, int32_t index) const override { return helper_->InputDType(call_ctx_, arg_name, index); } DataType OutputDType(const std::string& arg_name, int32_t index) const override { return helper_->OutputDType(call_ctx_, arg_name, index); } void SetOutputDType(const std::string& arg_name, int32_t index, DataType data_type) override { return helper_->SetOutputDType(call_ctx_, arg_name, index, data_type); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index); } void SetDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index, DataType data_type) override { return helper_->SetDtype4ArgNameAndIndex(call_ctx_, arg_name, index, data_type); } MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const override { return MemoryFormat4ArgNameAndIndex(arg_name, index); } MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const override { return MemoryFormat4ArgNameAndIndex(arg_name, index); } void SetOutputMemoryFormat(const std::string& arg_name, int32_t index, MemoryFormat memory_format) override { return SetMemoryFormat4ArgNameAndIndex(arg_name, index, memory_format); } MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return TensorDesc4ArgNameAndIndex(arg_name, index)->memory_format(); } void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index, MemoryFormat memory_format) override { MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_memory_format(memory_format); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { return helper_->InputIsDynamic(call_ctx_, arg_name, index); } bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { return helper_->OutputIsDynamic(call_ctx_, arg_name, index); } void SetOutputIsDynamic(const std::string& arg_name, int32_t index, bool is_dynamic) override { return helper_->SetOutputIsDynamic(call_ctx_, arg_name, index, is_dynamic); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index); } void SetIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index, bool is_dynamic) override { return helper_->SetIsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index, is_dynamic); } const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& outputs() const override { return helper_->outputs(); } const JobDesc* job_desc() const override { return helper_->job_desc(); } const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index); } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index); } int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); } const std::string& input(const std::string& arg_name, int32_t index) const override { 
return helper_->input(arg_name, index); } const std::string& output(const std::string& arg_name, int32_t index) const override { return helper_->output(arg_name, index); } bool has_input(const std::string& arg_name, int32_t index) const override { return helper_->has_input(arg_name, index); } bool has_output(const std::string& arg_name, int32_t index) const override { return helper_->has_output(arg_name, index); } int32_t input_size(const std::string& arg_name) const override { return helper_->input_size(arg_name); } int32_t output_size(const std::string& arg_name) const override { return helper_->output_size(arg_name); } const std::string& op_name() const override { return helper_->op_name(); } const std::string& op_type_name() const override { return helper_->op_type_name(); } const std::string& op_loc() const override { return helper_->op_loc(); } private: const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return helper_->Attr4Name(call_ctx_, attr_name); } const UserOpInferContextHelper* helper_; eager::CallContext* call_ctx_; }; class UserKernelComputeContextHelper final { public: UserKernelComputeContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf, const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple) : user_op_conf_(user_op_conf), base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {} ~UserKernelComputeContextHelper() = default; const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index); } DeviceType device_type() const { return base_ctx_helper_.device_type(); } const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { return base_ctx_helper_.parallel_ctx(call_ctx); } const ArgVec& inputs() const { return base_ctx_helper_.inputs(); } const ArgVec& outputs() const { return base_ctx_helper_.outputs(); } const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; } const std::shared_ptr& Attr4Name(eager::CallContext* call_ctx, const std::string& attr_name) const { return call_ctx->composed_attrs().Attr4Name(attr_name); } private: const user_op::UserOpConfWrapper* user_op_conf_; UserKernelBaseContextHelper base_ctx_helper_; }; class UserKernelComputeContext final : public user_op::KernelComputeContext { public: UserKernelComputeContext(const UserKernelComputeContextHelper* helper, eager::CallContext* call_ctx, ep::Stream* stream) : helper_(helper), call_ctx_(call_ctx), stream_(stream) {} ~UserKernelComputeContext() = default; const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index); } ep::Stream* stream() override { CHECK_NOTNULL(stream_); return stream_; } DeviceType device_type() const override { return helper_->device_type(); } const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& 
outputs() const override { return helper_->outputs(); } private: const user_op::UserOpConfWrapper& user_op_conf() const override { return helper_->user_op_conf(); } const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return helper_->Attr4Name(call_ctx_, attr_name); } const UserKernelComputeContextHelper* helper_; eager::CallContext* call_ctx_; ep::Stream* stream_; }; class UserKernelRegContextHelper final { public: UserKernelRegContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf, const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple) : user_op_conf_(user_op_conf), base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {} ~UserKernelRegContextHelper() = default; DeviceType device_type() const { return base_ctx_helper_.device_type(); } const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { return base_ctx_helper_.parallel_ctx(call_ctx); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } const ArgVec& inputs() const { return base_ctx_helper_.inputs(); } const ArgVec& outputs() const { return base_ctx_helper_.outputs(); } const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; } const std::shared_ptr& Attr4Name(eager::CallContext* call_ctx, const std::string& attr_name) const { return call_ctx->composed_attrs().Attr4Name(attr_name); } private: const user_op::UserOpConfWrapper* user_op_conf_; UserKernelBaseContextHelper base_ctx_helper_; }; class UserKernelRegContext final : public user_op::KernelRegContext { public: UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx) : helper_(helper), call_ctx_(call_ctx) {} ~UserKernelRegContext() = default; DeviceType device_type() const override { return helper_->device_type(); } const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& outputs() const override { return helper_->outputs(); } const user_op::UserOpConfWrapper& user_op_conf() const override { return helper_->user_op_conf(); } private: const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return helper_->Attr4Name(call_ctx_, attr_name); } const UserKernelRegContextHelper* helper_; eager::CallContext* call_ctx_; }; class UserKernelInitAndCacheContextHelper final { public: UserKernelInitAndCacheContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf, const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple) : user_op_conf_(user_op_conf), base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {} ~UserKernelInitAndCacheContextHelper() = default; DeviceType device_type() const { return base_ctx_helper_.device_type(); } const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { return base_ctx_helper_.parallel_ctx(call_ctx); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } const 
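/*
 * A minimal sketch of what a stateless user kernel's Compute() typically does with the
 * KernelComputeContext defined above. Schematic only: the "in"/"out" arg names, the
 * float dtype, and the identity copy are hypothetical.
 *
 *   void Compute(user_op::KernelComputeContext* ctx) const override {
 *     const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
 *     user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
 *     const int64_t elem_cnt = in->shape_view().elem_cnt();
 *     const float* src = in->dptr<float>();
 *     float* dst = out->mut_dptr<float>();
 *     // device kernels would enqueue work on ctx->stream(); a plain CPU loop suffices here
 *     for (int64_t i = 0; i < elem_cnt; ++i) { dst[i] = src[i]; }
 *   }
 */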
user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.GlobalTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index); } const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index); CHECK_EQ(nd_sbp.sbp_parallel_size(), 1); return nd_sbp.sbp_parallel(0); } const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return *CHECK_NOTNULL( base_ctx_helper_.GlobalTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index)) ->nd_sbp(); } const ArgVec& inputs() const { return base_ctx_helper_.inputs(); } const ArgVec& outputs() const { return base_ctx_helper_.outputs(); } const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const { return *CHECK_JUST(base_ctx_helper_.parallel_desc(call_ctx)); } const std::shared_ptr& Attr4Name(eager::CallContext* call_ctx, const std::string& attr_name) const { return call_ctx->composed_attrs().Attr4Name(attr_name); } const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; } private: const user_op::UserOpConfWrapper* user_op_conf_; UserKernelBaseContextHelper base_ctx_helper_; }; class UserKernelInitAndCacheContext final : public user_op::KernelInitContext, public user_op::KernelCacheContext { public: UserKernelInitAndCacheContext(const UserKernelInitAndCacheContextHelper* helper, eager::CallContext* call_ctx, ep::Stream* stream) : helper_(helper), call_ctx_(call_ctx), stream_(stream) {} ~UserKernelInitAndCacheContext() override = default; ep::Stream* stream() override { CHECK_NOTNULL(stream_); return stream_; } DeviceType device_type() const override { return helper_->device_type(); } const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index); } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index); } const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& outputs() const override { return helper_->outputs(); } const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); } private: const std::shared_ptr& Attr4Name( const std::string& attr_name) const override { return helper_->Attr4Name(call_ctx_, attr_name); } const user_op::UserOpConfWrapper& user_op_conf() const override { return helper_->user_op_conf(); } const UserKernelInitAndCacheContextHelper* helper_; eager::CallContext* call_ctx_; ep::Stream* stream_; }; namespace { Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr& op_conf, const ArgVec& indexed_input_pairs, const ArgVec& indexed_output_pairs, OpArgsVector* input_tuple_indexes4const_ibns, OpArgsVector* input_tuple_indexes4mut_ibns, OpArgsVector* 
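/*
 * The init/cache context defined above is what a kernel sees when it builds per-kernel
 * state or cache. A minimal sketch of the kernel-side hook, following the pattern used by
 * the pooling kernels later in this extract; MyCache (a hypothetical user_op::OpKernelCache
 * subclass) and the "x" arg name are illustrative only.
 *
 *   std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
 *       user_op::KernelCacheContext* ctx) const override {
 *     const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape();
 *     return std::make_shared<MyCache>(x_shape);
 *   }
 *
 * Per-kernel state works the same way through CreateOpKernelState(user_op::KernelInitContext*),
 * and TryInitOpKernelStateAndCache() below memoizes both per OpKernel instance.
 */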
output_tuple_indexes4mut_obns, OpArgsVector* output_tuple_indexes4mut2_obns, small_vector* output_tuple_indexes2is_mut2_type) { const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name()); CHECK_NOTNULL_OR_RETURN(op_reg_val); ArgModifierSignature arg_modifier_signature; for (const auto& pair : indexed_input_pairs) { const std::string ibn = GenRepeatedBn(pair.first, pair.second); arg_modifier_signature.mutable_ibn2input_blob_modifier()->insert( {ibn, user_op::InputArgModifier()}); } for (const auto& pair : indexed_output_pairs) { const std::string obn = GenRepeatedBn(pair.first, pair.second); arg_modifier_signature.mutable_obn2output_blob_modifier()->insert( {obn, user_op::OutputArgModifier()}); } user_op::UserOpConfWrapper op_conf_wrapper(op_conf); if (op_reg_val->input_arg_modify_fn) { user_op::GetInputArgModifier GetInputArgModifierFn = [&arg_modifier_signature](const std::string& in_arg_name, int32_t in_arg_index) -> user_op::InputArgModifier* { const std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index); auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier(); return &map->at(ibn); }; JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper)); } if (op_reg_val->output_arg_modify_fn) { user_op::GetOutputArgModifier GetOutputArgModifierFn = [&arg_modifier_signature](const std::string& in_arg_name, int32_t in_arg_index) -> user_op::OutputArgModifier* { const std::string obn = GenRepeatedBn(in_arg_name, in_arg_index); auto* map = arg_modifier_signature.mutable_obn2output_blob_modifier(); return &map->at(obn); }; JUST(op_reg_val->output_arg_modify_fn(GetOutputArgModifierFn, op_conf_wrapper)); } for (int i = 0; i < indexed_input_pairs.size(); i++) { const auto& pair = indexed_input_pairs.at(i); const std::string ibn = GenRepeatedBn(pair.first, pair.second); if (arg_modifier_signature.ibn2input_blob_modifier().at(ibn).is_mutable()) { input_tuple_indexes4mut_ibns->emplace_back(i); } else { input_tuple_indexes4const_ibns->emplace_back(i); } } for (int i = 0; i < indexed_output_pairs.size(); i++) { const auto& pair = indexed_output_pairs.at(i); const std::string obn = GenRepeatedBn(pair.first, pair.second); if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) { output_tuple_indexes4mut_obns->emplace_back(i); output_tuple_indexes2is_mut2_type->emplace_back(false); } else { output_tuple_indexes4mut2_obns->emplace_back(i); output_tuple_indexes2is_mut2_type->emplace_back(true); } } return Maybe::Ok(); } } // namespace /* static */ Maybe StatefulOpKernel::New( const std::shared_ptr& op_conf, const Symbol& stream, const AttrMap& base_attrs, const std::shared_ptr& parallel_desc, const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple) { auto opkernel = std::shared_ptr(new StatefulOpKernel()); opkernel->base_attrs_ = base_attrs; opkernel->op_conf_ = op_conf; opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf)); opkernel->stream_ = stream; opkernel->input_arg_tuple_ = input_arg_tuple; opkernel->output_arg_tuple_ = output_arg_tuple; const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag())); const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get(); opkernel->op_infer_ctx_helper_.reset( new UserOpInferContextHelper(user_op_conf, input_arg_tuple, output_arg_tuple)); opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper( device_type, 
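/*
 * The is_mutable() classification above is driven by ops that register an input arg
 * modifier. A rough sketch of the op-side registration; the op and arg names are
 * hypothetical, and the SetInputArgModifyFn hook shown here is assumed to be the usual
 * registry API rather than taken from this file.
 *
 *   REGISTER_USER_OP("my_inplace_op")
 *       .Input("ref")
 *       .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn,
 *                               const user_op::UserOpConfWrapper&) -> Maybe<void> {
 *         user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0);
 *         CHECK_OR_RETURN(ref_modifier != nullptr);
 *         ref_modifier->set_is_mutable(true);
 *         return Maybe<void>::Ok();
 *       });
 *
 * Inputs marked mutable end up in input_tuple_indexes4mut_ibns_; all others are treated
 * as const inputs.
 */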
opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_, opkernel->output_arg_tuple_)); opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper( device_type, user_op_conf, input_arg_tuple, output_arg_tuple)); opkernel->reg_ctx_helper_.reset( new UserKernelRegContextHelper(device_type, user_op_conf, input_arg_tuple, output_arg_tuple)); const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name()); CHECK_NOTNULL_OR_RETURN(op_reg_val); if (op_reg_val->logical_tensor_desc_infer_fn) { opkernel->tensor_desc_infer_fn_ = op_reg_val->logical_tensor_desc_infer_fn; } else { return Error::UnimplementedError(); } opkernel->data_type_infer_fn_ = op_reg_val->data_type_infer_fn; JUST(InitTensorTupleIndexes4Bns( op_conf, input_arg_tuple->indexed_arg_name_and_index(), output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_, &opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_, &opkernel->output_tuple_indexes4mut2_obns_, &opkernel->output_tuple_indexes2is_mut2_type_)); return opkernel; } StatefulOpKernel::~StatefulOpKernel() = default; size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx, const user_op::OpKernel* user_opkernel) const { UserOpInferContext op_infer_ctx(op_infer_ctx_helper_.get(), call_ctx); const auto& InferTmpSizeFn = GetInferTmpSizeFn(user_opkernel); return InferTmpSizeFn(&op_infer_ctx); } Maybe StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx, const user_op::OpKernel** user_opkernel, bool* need_temp_storage) { DataType primary_dtype = kInvalidDataType; const auto& inputs = call_ctx->inputs(); const auto& outputs = call_ctx->outputs(); if (likely(!inputs.empty())) { primary_dtype = inputs[0]->data_type(); } else if (likely(!outputs.empty())) { primary_dtype = outputs[0]->data_type(); } else { // do nothing } UserKernelRegContext reg_ctx(reg_ctx_helper_.get(), call_ctx); for (const auto& pair : dtype2cached_kernels_[primary_dtype]) { if (likely(pair.first->is_matched_hob->get(reg_ctx))) { *need_temp_storage = pair.first->need_temp_storage; *user_opkernel = pair.second.get(); return Maybe::Ok(); } } OF_PROFILER_RANGE_GUARD("fallback"); const auto& op_type_name = user_op_conf_->op_type_name(); const auto* kernel_reg_val = JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_type_name, reg_ctx)); CHECK_NOTNULL(kernel_reg_val); auto* kernel = kernel_reg_val->create_fn(); dtype2cached_kernels_[primary_dtype].push_back( {kernel_reg_val, std::shared_ptr(kernel)}); infer_tmp_size_fn_map_.emplace(kernel, &kernel_reg_val->infer_tmp_size_fn); *need_temp_storage = kernel_reg_val->need_temp_storage; *user_opkernel = kernel; return Maybe::Ok(); } void StatefulOpKernel::TryInitOpKernelStateAndCache(eager::CallContext* call_ctx, ep::Stream* stream, const user_op::OpKernel* op_kernel, user_op::OpKernelState** state, user_op::OpKernelCache** cache) { UserKernelInitAndCacheContext init_and_cache_ctx(init_and_cache_ctx_helper_.get(), call_ctx, stream); if (state != nullptr) { auto it = op_kernel_state_map_.find(op_kernel); if (it != op_kernel_state_map_.end()) { *state = it->second.get(); } else { auto created_state = op_kernel->CreateOpKernelState(&init_and_cache_ctx); op_kernel_state_map_.emplace(op_kernel, created_state); *state = created_state.get(); } } { auto& cache_in_map = op_kernel_cache_map_[op_kernel]; op_kernel->InitOpKernelCacheWithFlags(&init_and_cache_ctx, user_op::OpKernelCache::kAllMayChanged, &cache_in_map); *cache = 
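/*
 * ChooseOpKernel() above walks the kernels registered for this op type name, evaluates
 * each registration's is_matched_hob predicate against the reg context, and caches the
 * first match per primary input dtype in dtype2cached_kernels_. The registration side is
 * the REGISTER_USER_KERNEL pattern used by the kernels later in this extract; the op and
 * kernel names below are hypothetical.
 *
 *   REGISTER_USER_KERNEL("my_op")
 *       .SetCreateFn<MyCpuKernel<float>>()
 *       .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)
 *                        && (user_op::HobDataType("in", 0) == GetDataType<float>::value));
 *
 * The hob expression above is exactly what pair.first->is_matched_hob->get(reg_ctx)
 * evaluates when a call arrives.
 */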
cache_in_map.get(); } } const user_op::InferTmpSizeFn& StatefulOpKernel::GetInferTmpSizeFn( const user_op::OpKernel* op_kernel) const { return *infer_tmp_size_fn_map_.at(op_kernel); } user_op::TensorDescInferFn StatefulOpKernel::TensorDescInferFn() const { return tensor_desc_infer_fn_; } user_op::DataTypeInferFn StatefulOpKernel::DataTypeInferFn() const { return data_type_infer_fn_; } void StatefulOpKernel::Compute(eager::CallContext* call_ctx, ep::Stream* stream, const user_op::OpKernel* user_opkernel, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const { UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, stream); auto* compute_ctx = &compute_context; OF_PROFILER_RANGE_GUARD("Compute"); auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder( op_type_name(), #if defined(WITH_CUDA) [compute_ctx]() -> int64_t { const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t { const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) { const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second); return mem_size + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type()); }; return std::accumulate(args.begin(), args.end(), static_cast(0), Func); }; return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs()); }, #endif [call_ctx]() -> std::pair { std::stringstream ss; std::size_t hash = 0; for (size_t i = 0; i < call_ctx->inputs().size(); i++) { const auto& shape = call_ctx->inputs().at(i)->shape(); ss << shape; if (i != call_ctx->inputs().size() - 1) { ss << ", "; } AddHash(&hash, shape); } return {ss.str(), hash}; }, [call_ctx]() -> std::pair { const std::string attr_str = call_ctx->composed_attrs().ToString(); return {attr_str, std::hash{}(attr_str)}; })); user_opkernel->Compute(compute_ctx, state, cache); CHECK_JUST(compute_ctx->stream()->GetAsyncError()); } } // namespace one } // namespace oneflow ================================================ FILE: oneflow/user/kernels/stateful_opkernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_ #define ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_ #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/user_op_kernel_registry.h" #include "oneflow/core/framework/arg_tuple.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/common/op_args_vector.h" namespace oneflow { class AttrMap; namespace vm { struct OpCallInstructionUtil; } namespace eager { class CallContext; } namespace one { using ArgVec = std::vector>; class UserKernelRegContextHelper; class UserOpInferContextHelper; class UserKernelInitAndCacheContextHelper; class UserKernelComputeContextHelper; class StatefulOpKernel final { public: OF_DISALLOW_COPY_AND_MOVE(StatefulOpKernel); static Maybe New(const std::shared_ptr& op_conf, const Symbol& stream, const AttrMap& base_attrs, const std::shared_ptr& parallel_desc, const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple); ~StatefulOpKernel(); const Symbol& stream() const { return stream_; } const std::shared_ptr& mem_case() const { return stream_->device()->mem_case(); } const std::string& op_type_name() const { return op_conf_->user_conf().op_type_name(); } const OpArgsVector& input_tuple_indexes4const_ibns() const { return input_tuple_indexes4const_ibns_; } const OpArgsVector& input_tuple_indexes4mut_ibns() const { return input_tuple_indexes4mut_ibns_; } const OpArgsVector& output_tuple_indexes4mut_obns() const { return output_tuple_indexes4mut_obns_; } const OpArgsVector& output_tuple_indexes4mut2_obns() const { return output_tuple_indexes4mut2_obns_; } bool output_is_mut2_type(int64_t index) const { return output_tuple_indexes2is_mut2_type_.at(index); } const AttrMap& base_attrs() const { return base_attrs_; } size_t InferTmpSize(eager::CallContext* call_ctx, const user_op::OpKernel* user_opkernel) const; Maybe ChooseOpKernel(eager::CallContext* call_ctx, const user_op::OpKernel** user_opkernel, bool* need_temp_storage); const OperatorConf& op_conf() const { return *op_conf_; } const ArgTuple* input_arg_tuple() const { return input_arg_tuple_.get(); } const ArgTuple* output_arg_tuple() const { return output_arg_tuple_.get(); } private: friend struct vm::OpCallInstructionUtil; StatefulOpKernel() = default; void Compute(eager::CallContext* call_ctx, ep::Stream* stream, const user_op::OpKernel* user_opkernel, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const; user_op::TensorDescInferFn TensorDescInferFn() const; user_op::DataTypeInferFn DataTypeInferFn() const; void TryInitOpKernelStateAndCache(eager::CallContext* call_ctx, ep::Stream* stream, const user_op::OpKernel* op_kernel, user_op::OpKernelState** state, user_op::OpKernelCache** cache); user_op::OpKernelState* mut_opkernel_state(const user_op::OpKernel* opkernel) { return op_kernel_state_map_.at(opkernel).get(); } const user_op::InferTmpSizeFn& GetInferTmpSizeFn(const user_op::OpKernel* op_kernel) const; std::shared_ptr op_conf_; AttrMap base_attrs_; std::unique_ptr user_op_conf_; Symbol stream_; std::unique_ptr reg_ctx_helper_; std::unique_ptr op_infer_ctx_helper_; std::unique_ptr init_and_cache_ctx_helper_; std::unique_ptr compute_ctx_helper_; std::shared_ptr input_arg_tuple_; std::shared_ptr output_arg_tuple_; user_op::TensorDescInferFn tensor_desc_infer_fn_; user_op::DataTypeInferFn 
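/*
 * A rough sketch of how the eager runtime drives a StatefulOpKernel. Compute() and
 * TryInitOpKernelStateAndCache() are private and reachable only via the friend struct
 * vm::OpCallInstructionUtil; the sequence below is schematic and assumes call_ctx, stream,
 * and opkernel are in scope at the real call site, which also handles buffer allocation
 * around these steps.
 *
 *   const user_op::OpKernel* user_kernel = nullptr;
 *   bool need_temp_storage = false;
 *   JUST(opkernel->ChooseOpKernel(call_ctx, &user_kernel, &need_temp_storage));
 *   // if need_temp_storage, size the buffer via opkernel->InferTmpSize(call_ctx, user_kernel)
 *   user_op::OpKernelState* state = nullptr;
 *   user_op::OpKernelCache* cache = nullptr;
 *   opkernel->TryInitOpKernelStateAndCache(call_ctx, stream, user_kernel, &state, &cache);
 *   opkernel->Compute(call_ctx, stream, user_kernel, state, cache);
 */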
data_type_infer_fn_; // NOTE: every device has its own stateful local opkernel instance, // so only group kernels by dtype std::array>>, DataType_ARRAYSIZE> dtype2cached_kernels_; HashMap> op_kernel_state_map_; HashMap> op_kernel_cache_map_; HashMap infer_tmp_size_fn_map_; OpArgsVector input_tuple_indexes4const_ibns_; OpArgsVector input_tuple_indexes4mut_ibns_; OpArgsVector output_tuple_indexes4mut_obns_; OpArgsVector output_tuple_indexes4mut2_obns_; OpArgsVector output_tuple_indexes2is_mut2_type_; }; } // namespace one } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_ ================================================ FILE: oneflow/user/kernels/summary_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/summary/events_writer.h" #include "oneflow/user/summary/env_time.h" #include "oneflow/user/summary/histogram.h" #include "oneflow/user/summary/event_writer_helper.h" #include #include namespace oneflow { namespace summary { template class SummaryWriteScalar final : public user_op::OpKernel { public: SummaryWriteScalar() = default; ~SummaryWriteScalar() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* step = ctx->Tensor4ArgNameAndIndex("step", 0); const user_op::Tensor* tag = ctx->Tensor4ArgNameAndIndex("tag", 0); const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("in", 0); T* tvalue = const_cast(value->dptr()); CHECK_NOTNULL(tvalue); int64_t* istep = const_cast(step->dptr()); CHECK_NOTNULL(istep); int8_t* ctag = const_cast(tag->dptr()); CHECK_NOTNULL(ctag); std::string tag_str(reinterpret_cast(ctag), tag->shape_view().elem_cnt()); EventWriterHelper::WriteScalarToFile( istep[0], static_cast(tvalue[0]), tag_str); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_SCALAR_USER_KERNEL(dtype) \ REGISTER_USER_KERNEL("summary_write_scalar") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_SCALAR_USER_KERNEL(double) REGISTER_SCALAR_USER_KERNEL(float) REGISTER_SCALAR_USER_KERNEL(int64_t) REGISTER_SCALAR_USER_KERNEL(int32_t) class CreateSummaryWriter final : public user_op::OpKernel { public: CreateSummaryWriter() = default; ~CreateSummaryWriter() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const std::string& logdir = ctx->Attr("logdir"); CHECK_JUST(Singleton::Get()->Init(logdir)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("create_summary_writer") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)); class FlushSummaryWriter final : public user_op::OpKernel { public: FlushSummaryWriter() = default; ~FlushSummaryWriter() = default; private: void Compute(user_op::KernelComputeContext* ctx) const 
override { Singleton::Get()->Flush(); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("flush_summary_writer") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)); template class SummaryWriteHistogram final : public user_op::OpKernel { public: SummaryWriteHistogram() = default; ~SummaryWriteHistogram() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* step = ctx->Tensor4ArgNameAndIndex("step", 0); const user_op::Tensor* tag = ctx->Tensor4ArgNameAndIndex("tag", 0); const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("in", 0); int64_t* istep = const_cast(step->dptr()); CHECK_NOTNULL(istep); int8_t* ctag = const_cast(tag->dptr()); CHECK_NOTNULL(ctag); std::string tag_str(reinterpret_cast(ctag), tag->shape_view().elem_cnt()); EventWriterHelper::WriteHistogramToFile(static_cast(istep[0]), *value, tag_str); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_HISTOGRAM_USER_KERNEL(dtype) \ REGISTER_USER_KERNEL("summary_write_histogram") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)); REGISTER_HISTOGRAM_USER_KERNEL(double) REGISTER_HISTOGRAM_USER_KERNEL(float) REGISTER_HISTOGRAM_USER_KERNEL(int64_t) REGISTER_HISTOGRAM_USER_KERNEL(int32_t) REGISTER_HISTOGRAM_USER_KERNEL(int8_t) REGISTER_HISTOGRAM_USER_KERNEL(uint8_t) template class SummaryWritePb final : public user_op::OpKernel { public: SummaryWritePb() = default; ~SummaryWritePb() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* step = ctx->Tensor4ArgNameAndIndex("step", 0); const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("in", 0); int64_t* istep = const_cast(step->dptr()); CHECK_NOTNULL(istep); int8_t* cvalue = const_cast(value->dptr()); CHECK_NOTNULL(cvalue); std::string value_str(reinterpret_cast(cvalue), value->shape_view().elem_cnt()); EventWriterHelper::WritePbToFile(istep[0], value_str); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("summary_write_pb") .SetCreateFn>() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == GetDataType::value)); template class SummaryWriteImage final : public user_op::OpKernel { public: SummaryWriteImage() = default; ~SummaryWriteImage() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* step = ctx->Tensor4ArgNameAndIndex("step", 0); const user_op::Tensor* tag = ctx->Tensor4ArgNameAndIndex("tag", 0); const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex("in", 0); int64_t* istep = const_cast(step->dptr()); CHECK_NOTNULL(istep); char* ctag = const_cast(tag->dptr()); CHECK_NOTNULL(ctag); std::string tag_str(ctag, tag->shape_view().elem_cnt()); EventWriterHelper::WriteImageToFile(static_cast(istep[0]), *value, tag_str); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("summary_write_image") .SetCreateFn>() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == GetDataType::value)); } // namespace summary } // namespace oneflow ================================================ FILE: oneflow/user/kernels/tensor_buffer_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/thread/thread_manager.h" namespace oneflow { namespace { class TensorBufferToTensorKernel final : public user_op::OpKernel { public: TensorBufferToTensorKernel() = default; ~TensorBufferToTensorKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& in_shape = in->shape_view(); CHECK_EQ(in->data_type(), DataType::kTensorBuffer); const ShapeView& out_shape = out->shape_view(); const auto& instance_shape = ctx->Attr("instance_shape"); CHECK_EQ(out_shape.NumAxes(), in_shape.NumAxes() + instance_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, in_shape.NumAxes()) { CHECK_EQ(out_shape.At(i), in_shape.At(i)); } FOR_RANGE(int64_t, i, 0, instance_shape.NumAxes()) { CHECK_EQ(out_shape.At(i + in_shape.NumAxes()), instance_shape.At(i)); } const auto data_type = ctx->Attr("dtype"); CHECK_EQ(out->data_type(), data_type); const int64_t instance_size = instance_shape.elem_cnt() * GetSizeOfDataType(data_type); const auto* in_ptr = in->dptr(); auto* out_ptr = out->mut_dptr(); MultiThreadLoop(in_shape.elem_cnt(), [&](size_t i) { const TensorBuffer* tensor_buffer = in_ptr + i; CHECK_EQ(tensor_buffer->nbytes(), instance_size); CHECK_EQ(tensor_buffer->data_type(), data_type); CHECK(tensor_buffer->shape_view() == instance_shape); Memcpy(ctx->stream(), out_ptr + i * instance_size, tensor_buffer->data(), instance_size); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("tensor_buffer_to_tensor") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer)); class TensorToTensorBufferKernel final : public user_op::OpKernel { public: TensorToTensorBufferKernel() = default; ~TensorToTensorBufferKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& in_shape = in->shape_view(); const ShapeView& out_shape = out->shape_view(); const auto instance_dims = ctx->Attr("instance_dims"); CHECK_LT(instance_dims, in_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, in_shape.NumAxes() - instance_dims) { CHECK_EQ(out_shape.At(i), in_shape.At(i)); } DimVector instance_dim_vec; FOR_RANGE(int64_t, i, in_shape.NumAxes() - instance_dims, in_shape.NumAxes()) { instance_dim_vec.emplace_back(in_shape.At(i)); } const Shape instance_shape(instance_dim_vec); const auto data_type = in->data_type(); CHECK(IsTriviallyCopyableDataType(data_type)); const int64_t instance_size = instance_shape.elem_cnt() * GetSizeOfDataType(data_type); const auto* in_ptr = 
in->dptr(); auto* out_ptr = out->mut_dptr(); MultiThreadLoop(in_shape.Count(0, in_shape.NumAxes() - instance_dims), [&](size_t i) { TensorBuffer* tensor_buffer = out_ptr + i; tensor_buffer->Resize(instance_shape, data_type); CHECK_EQ(tensor_buffer->nbytes(), instance_size); Memcpy(ctx->stream(), tensor_buffer->mut_data(), in_ptr + i * instance_size, instance_size); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("tensor_to_tensor_buffer") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("out", 0) == DataType::kTensorBuffer)); template class GenTensorBuffer final : public user_op::OpKernel { public: GenTensorBuffer() = default; ~GenTensorBuffer() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t num_tensor_buffers = ctx->Attr("shape").elem_cnt(); const std::vector& shape_list = ctx->Attr>("shape_list"); const std::vector& value_list = ctx->Attr>("value_list"); CHECK_EQ(num_tensor_buffers, shape_list.size()); CHECK_EQ(num_tensor_buffers, value_list.size()); MultiThreadLoop(num_tensor_buffers, [&](size_t i) { TensorBuffer* tensor_buffer = out->mut_dptr() + i; const Shape& shape = shape_list.at(i); tensor_buffer->Resize(shape, GetDataType::value); T* begin = reinterpret_cast(tensor_buffer->mut_data()); std::fill(begin, begin + shape.elem_cnt(), static_cast(value_list.at(i))); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_GEN_TENSOR_BUFFER_KERNEL(dtype) \ REGISTER_USER_KERNEL("gen_tensor_buffer") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("data_type") == GetDataType::value)); REGISTER_GEN_TENSOR_BUFFER_KERNEL(int32_t) REGISTER_GEN_TENSOR_BUFFER_KERNEL(int64_t) REGISTER_GEN_TENSOR_BUFFER_KERNEL(float) REGISTER_GEN_TENSOR_BUFFER_KERNEL(double) #undef REGISTER_GEN_TENSOR_BUFFER_KERNEL class TensorBufferToListOfTensors final : public user_op::OpKernel { public: TensorBufferToListOfTensors() = default; ~TensorBufferToListOfTensors() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK_GT(in->shape_view().elem_cnt(), 0); CHECK_EQ(in->data_type(), DataType::kTensorBuffer); const DataType out_dtype = ctx->Attr("out_dtype"); CHECK(IsTriviallyCopyableDataType(out_dtype)); const bool dynamic_out = ctx->Attr("dynamic_out"); const auto* in_ptr = in->dptr(); MultiThreadLoop(in->shape_view().elem_cnt(), [&](size_t i) { const TensorBuffer* tensor_buffer = in_ptr + i; user_op::Tensor* out_i = ctx->Tensor4ArgNameAndIndex("out", i); CHECK_EQ(out_dtype, tensor_buffer->data_type()); if (dynamic_out) { CHECK_LE(tensor_buffer->shape_view().elem_cnt(), out_i->shape_view().elem_cnt()); out_i->mut_shape_view().set_shape(tensor_buffer->shape_view()); } else { CHECK_EQ(tensor_buffer->shape_view().elem_cnt(), out_i->shape_view().elem_cnt()); } Memcpy(ctx->stream(), out_i->mut_dptr(), tensor_buffer->data(), tensor_buffer->nbytes()); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("tensor_buffer_to_list_of_tensors") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer)); class TensorBufferToListOfTensorsV2 final : public 
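/*
 * A minimal sketch of the TensorBuffer API these kernels rely on, in isolation; the
 * shape, dtype, and fill value are arbitrary.
 *
 *   TensorBuffer buf;
 *   buf.Resize(Shape({2, 3}), DataType::kFloat);             // allocate 2*3 float elements
 *   float* data = reinterpret_cast<float*>(buf.mut_data());  // raw payload pointer
 *   std::fill(data, data + 6, 1.0f);
 *   // buf.shape_view(), buf.data_type() and buf.nbytes() now describe the payload,
 *   // which is what tensor_buffer_to_tensor validates before each Memcpy above.
 */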
user_op::OpKernel { public: TensorBufferToListOfTensorsV2() = default; ~TensorBufferToListOfTensorsV2() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK_GT(in->shape_view().elem_cnt(), 0); CHECK_EQ(in->data_type(), DataType::kTensorBuffer); const std::vector& out_dtypes = ctx->Attr>("out_dtypes"); const bool dynamic_out = ctx->Attr("dynamic_out"); const auto* in_ptr = in->dptr(); MultiThreadLoop(in->shape_view().elem_cnt(), [&](size_t i) { CHECK(IsTriviallyCopyableDataType(out_dtypes[i])); const TensorBuffer* tensor_buffer = in_ptr + i; user_op::Tensor* out_i = ctx->Tensor4ArgNameAndIndex("out", i); CHECK_EQ(out_dtypes[i], tensor_buffer->data_type()); if (dynamic_out) { CHECK_LE(tensor_buffer->shape_view().elem_cnt(), out_i->shape_view().elem_cnt()); out_i->mut_shape_view().set_shape(tensor_buffer->shape_view()); } else { CHECK_EQ(tensor_buffer->shape_view().elem_cnt(), out_i->shape_view().elem_cnt()); } Memcpy(ctx->stream(), out_i->mut_dptr(), tensor_buffer->data(), tensor_buffer->nbytes()); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; REGISTER_USER_KERNEL("tensor_buffer_to_list_of_tensors_v2") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) && (user_op::HobDataType("in", 0) == DataType::kTensorBuffer)); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/tensor_constant_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/tensor_fill.h" namespace oneflow { namespace user_op { namespace { template std::unique_ptr NewTensorFillPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return ep::primitive::NewPrimitive(ctx->device_type(), data_type); } class TensorConstantKernel final : public OpKernel { public: TensorConstantKernel() = default; ~TensorConstantKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const Tensor* value_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK(value_tensor->shape_view().NumAxes() <= 1 && value_tensor->shape_view().elem_cnt() == 1) << "Only scalar tensor as filled value is supported!"; const int64_t elem_cnt = out_tensor->shape_view().elem_cnt(); CHECK_GE(elem_cnt, 0); if (elem_cnt == 0) { return; } std::unique_ptr tensor_fill = NewTensorFillPrimitive(ctx); CHECK(tensor_fill); tensor_fill->Launch(ctx->stream(), value_tensor->raw_dptr(), out_tensor->mut_dptr(), elem_cnt); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto TensorFillPrimitiveExists() { return hob::make_custom("TensorFillPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewTensorFillPrimitive(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("tensor_constant") .SetCreateFn() .SetIsMatchedHob(TensorFillPrimitiveExists() == true); } // namespace } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/tf_pool_cpu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/utils/pool_util.h" #include "oneflow/core/common/eigen_util.h" namespace oneflow { namespace { struct PoolOpKernelCache final : public user_op::OpKernelCache { Params3D params_3d; explicit PoolOpKernelCache(const Params3D& params_3d) : params_3d(params_3d) {} const Params3D& GetParams3D() const { return params_3d; } }; std::shared_ptr InitPoolOpKernelCache(user_op::KernelCacheContext* ctx, const int32_t& dim) { const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::string& padding = ctx->Attr("padding"); const auto& padding_before = ctx->Attr>("padding_before"); const auto& padding_after = ctx->Attr>("padding_after"); const std::vector& pool_size = ctx->Attr>("pool_size"); const std::vector& strides = ctx->Attr>("strides"); const bool ceil_mode = ctx->Attr("ceil_mode"); Params3D params_3d = Params3D(dim, x_shape, data_format, padding, padding_before, padding_after, pool_size, strides, ceil_mode); std::shared_ptr state(new PoolOpKernelCache(params_3d)); return state; } template struct PoolCpuKernelUtil { public: typedef std::function ForwardInitialize; typedef std::function CFirstProcess; typedef std::function& in_mat, EigenMatrixMap& out_mat)> CLastProcess; typedef std::function CFirstFinalize; typedef std::function& out_mat)> CLastFinalize; typedef std::function CFirstProcessGrad; typedef std::function& out_arr, ConstEigenArrayMap& in_arr, ConstEigenArrayMap& out_diff_arr, EigenArrayMap& in_diff_arr)> CLastProcessGrad; static void CFirstForward(const Params3D& params_3d, const user_op::Tensor* in_blob, user_op::Tensor* out_blob, const ForwardInitialize& initialize, const CFirstProcess& process, const CFirstFinalize& finalize) { const Shape& in = params_3d.GetXShape5D(); const Shape& out = params_3d.GetYShape5D(); const std::vector& pool_size = params_3d.pool_size_3d(); const std::vector& strides = params_3d.strides_3d(); const std::vector& padding_before = params_3d.padding_before_3d(); const T* input = in_blob->dptr(); T* output = out_blob->mut_dptr(); FOR_RANGE(int64_t, n, 0, in.At(0)) { FOR_RANGE(int64_t, c, 0, in.At(1)) { FOR_RANGE(int64_t, pd, 0, out.At(2)) { int64_t dstart = pd * strides.at(0) - padding_before.at(0); int64_t dend = std::min(dstart + pool_size.at(0), in.At(2)); dstart = std::max(dstart, static_cast(0)); FOR_RANGE(int64_t, ph, 0, out.At(3)) { int64_t hstart = ph * strides.at(1) - padding_before.at(1); int64_t hend = std::min(hstart + pool_size.at(1), in.At(3)); hstart = std::max(hstart, static_cast(0)); FOR_RANGE(int64_t, pw, 0, out.At(4)) { int64_t wstart = pw * strides.at(2) - padding_before.at(2); int64_t wend = std::min(wstart + pool_size.at(2), in.At(4)); wstart = std::max(wstart, static_cast(0)); const int64_t pool_index = pd * out.Count(3) + ph * out.At(4) + pw; T res = initialize(); FOR_RANGE(int64_t, d, dstart, dend) { FOR_RANGE(int64_t, h, hstart, hend) { FOR_RANGE(int64_t, w, wstart, wend) { const int64_t input_index = d * in.Count(3) + h * in.At(4) + w; process(input[input_index], res); } } } finalize((dend - dstart) * (hend - hstart) * (wend - wstart), res); output[pool_index] = res; } } } input += in.Count(2); output += out.Count(2); } } } static void CFirstBackward(const Params3D& params_3d, const user_op::Tensor* out_diff_blob, const user_op::Tensor* out_blob, const user_op::Tensor* in_blob, user_op::Tensor* in_diff_blob, const 
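/*
 * Worked example of the window clamping in CFirstForward above, for a single spatial
 * dimension with in = 5, pool_size = 3, stride = 2, padding_before = 1:
 *
 *   pd = 0: dstart = 0*2 - 1 = -1 -> clamped to 0, dend = min(-1 + 3, 5) = 2, window [0, 2)
 *   pd = 1: dstart = 1*2 - 1 =  1,                 dend = min( 1 + 3, 5) = 4, window [1, 4)
 *   pd = 2: dstart = 2*2 - 1 =  3,                 dend = min( 3 + 3, 5) = 5, window [3, 5)
 *
 * finalize() receives the clipped window size (dend - dstart) * (hend - hstart) * (wend - wstart),
 * which is how the average-pool variant divides by the number of in-bounds elements only.
 */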
CFirstProcessGrad& process) { const Shape& in = params_3d.GetXShape5D(); const Shape& out = params_3d.GetYShape5D(); const std::vector& pool_size = params_3d.pool_size_3d(); const std::vector& strides = params_3d.strides_3d(); const std::vector& padding_before = params_3d.padding_before_3d(); const T* output_diff = out_diff_blob->dptr(); const T* output = out_blob->dptr(); const T* input = in_blob->dptr(); std::memset(in_diff_blob->mut_dptr(), T(0), in.elem_cnt() * sizeof(T)); T* input_diff = in_diff_blob->mut_dptr(); FOR_RANGE(int64_t, n, 0, in.At(0)) { FOR_RANGE(int64_t, c, 0, in.At(1)) { FOR_RANGE(int64_t, pd, 0, out.At(2)) { int64_t dstart = pd * strides.at(0) - padding_before.at(0); int64_t dend = std::min(dstart + pool_size.at(0), in.At(2)); dstart = std::max(dstart, static_cast(0)); FOR_RANGE(int64_t, ph, 0, out.At(3)) { int64_t hstart = ph * strides.at(1) - padding_before.at(1); int64_t hend = std::min(hstart + pool_size.at(1), in.At(3)); hstart = std::max(hstart, static_cast(0)); FOR_RANGE(int64_t, pw, 0, out.At(4)) { int64_t wstart = pw * strides.at(2) - padding_before.at(2); int64_t wend = std::min(wstart + pool_size.at(2), in.At(4)); wstart = std::max(wstart, static_cast(0)); const int64_t size = (dend - dstart) * (hend - hstart) * (wend - wstart); const int64_t pool_index = pd * out.Count(3) + ph * out.At(4) + pw; FOR_RANGE(int64_t, d, dstart, dend) { FOR_RANGE(int64_t, h, hstart, hend) { FOR_RANGE(int64_t, w, wstart, wend) { const int64_t index = d * in.Count(3) + h * in.At(4) + w; process(input[index], output[pool_index], output_diff[pool_index], size, input_diff[index]); } } } } } } // offset input += in.Count(2); input_diff += in.Count(2); output += out.Count(2); output_diff += out.Count(2); } } } static void CLastForward(const Params3D& params_3d, const user_op::Tensor* in_blob, user_op::Tensor* out_blob, const ForwardInitialize& forward_initialize, const CLastProcess& process, const CLastFinalize& finalize) { const Shape& in = params_3d.GetXShape5D(); const Shape& out = params_3d.GetYShape5D(); const std::vector& pool_size = params_3d.pool_size_3d(); const std::vector& strides = params_3d.strides_3d(); const std::vector& padding_before = params_3d.padding_before_3d(); ConstEigenMatrixMap in_mat(in_blob->dptr(), in.At(1), in.elem_cnt() / in.At(1)); EigenMatrixMap out_mat(out_blob->mut_dptr(), out.At(1), out.elem_cnt() / out.At(1)); FOR_RANGE(int64_t, n, 0, in.At(0)) { FOR_RANGE(int64_t, pd, 0, out.At(2)) { int64_t dstart = pd * strides.at(0) - padding_before.at(0); int64_t dend = std::min(dstart + pool_size.at(0), in.At(2)); dstart = std::max(dstart, static_cast(0)); FOR_RANGE(int64_t, ph, 0, out.At(3)) { int64_t hstart = ph * strides.at(1) - padding_before.at(1); int64_t hend = std::min(hstart + pool_size.at(1), in.At(3)); hstart = std::max(hstart, static_cast(0)); FOR_RANGE(int64_t, pw, 0, out.At(4)) { int64_t wstart = pw * strides.at(2) - padding_before.at(2); int64_t wend = std::min(wstart + pool_size.at(2), in.At(4)); wstart = std::max(wstart, static_cast(0)); const int out_col = ((n * out.At(2) + pd) * out.At(3) + ph) * out.At(4) + pw; out_mat.col(out_col).setConstant(forward_initialize()); FOR_RANGE(int64_t, d, dstart, dend) { FOR_RANGE(int64_t, h, hstart, hend) { FOR_RANGE(int64_t, w, wstart, wend) { const int in_col = ((n * in.At(2) + d) * in.At(3) + h) * in.At(4) + w; process(in_col, out_col, in_mat, out_mat); } } } finalize((hend - hstart) * (wend - wstart) * (dend - dstart), out_col, out_mat); } } } } } static void CLastBackward(const Params3D& params_3d, 
const user_op::Tensor* out_diff_blob, const user_op::Tensor* out_blob, const user_op::Tensor* in_blob, user_op::Tensor* in_diff_blob, const CLastProcessGrad& process) { const Shape& in = params_3d.GetXShape5D(); const Shape& out = params_3d.GetYShape5D(); const std::vector& pool_size = params_3d.pool_size_3d(); const std::vector& strides = params_3d.strides_3d(); const std::vector& padding_before = params_3d.padding_before_3d(); // caffe2 implementation: need check ConstEigenArrayMap out_mat(out_blob->dptr(), out.At(1), out.elem_cnt() / out.At(1)); ConstEigenArrayMap in_mat(in_blob->dptr(), in.At(1), in.elem_cnt() / in.At(1)); ConstEigenArrayMap out_diff_mat(out_diff_blob->dptr(), out.At(1), out.elem_cnt() / out.At(1)); std::memset(in_diff_blob->mut_dptr(), T(0), in.elem_cnt() * sizeof(T)); EigenArrayMap in_diff_mat(in_diff_blob->mut_dptr(), in.At(1), in.elem_cnt() / in.At(1)); FOR_RANGE(int64_t, n, 0, in.At(0)) { FOR_RANGE(int64_t, pd, 0, out.At(2)) { int64_t dstart = pd * strides.at(0) - padding_before.at(0); int64_t dend = std::min(dstart + pool_size.at(0), in.At(2)); dstart = std::max(dstart, static_cast(0)); FOR_RANGE(int64_t, ph, 0, out.At(3)) { int64_t hstart = ph * strides.at(1) - padding_before.at(1); int64_t hend = std::min(hstart + pool_size.at(1), in.At(3)); hstart = std::max(hstart, static_cast(0)); FOR_RANGE(int64_t, pw, 0, out.At(4)) { int64_t wstart = pw * strides.at(2) - padding_before.at(2); int64_t wend = std::min(wstart + pool_size.at(2), in.At(4)); wstart = std::max(wstart, static_cast(0)); const int64_t pool_index = ((n * out.At(2) + pd) * out.At(3) + ph) * out.At(4) + pw; const int64_t size = (dend - dstart) * (hend - hstart) * (wend - wstart); FOR_RANGE(int64_t, d, dstart, dend) { FOR_RANGE(int64_t, h, hstart, hend) { FOR_RANGE(int64_t, w, wstart, wend) { const int64_t input_index = ((n * in.At(2) + d) * in.At(3) + h) * in.At(4) + w; process(pool_index, input_index, size, out_mat, in_mat, out_diff_mat, in_diff_mat); } } } } } } } } static void AvgFWCompute(user_op::KernelComputeContext* ctx, const PoolOpKernelCache* pool_state) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); CHECK_NOTNULL(pool_state); const std::string& data_format = ctx->Attr("data_format"); if (data_format == "channels_first") { CFirstForward( pool_state->GetParams3D(), x, y, GetZeroVal, [](const T& lhs, T& rhs) { rhs += lhs; }, [](const int64_t size, T& out) { out /= size; }); } else if (data_format == "channels_last") { CLastForward( pool_state->GetParams3D(), x, y, GetZeroVal, [](const int64_t in_col, const int64_t out_col, ConstEigenMatrixMap& in_mat, EigenMatrixMap& out_mat) { out_mat.col(out_col) += in_mat.col(in_col); }, [](const int64_t size, const int64_t col, EigenMatrixMap& out_mat) { out_mat.col(col) /= size; }); } else { UNIMPLEMENTED(); } } static void AvgBWCompute(user_op::KernelComputeContext* ctx, const PoolOpKernelCache* pool_state) { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); CHECK_NOTNULL(pool_state); const std::string& data_format = ctx->Attr("data_format"); if (data_format == "channels_first") { CFirstBackward(pool_state->GetParams3D(), dy, y, x, dx, [](const T& in, const T& out, const T& out_diff, const int64_t size, T& in_diff) { in_diff += (out_diff / static_cast(size)); }); } else if 
(data_format == "channels_last") { CLastBackward(pool_state->GetParams3D(), dy, y, x, dx, [](const int64_t out_col, const int64_t in_col, const int64_t size, ConstEigenArrayMap& out_arr, ConstEigenArrayMap& in_arr, ConstEigenArrayMap& out_diff_arr, EigenArrayMap& in_diff_arr) { in_diff_arr.col(in_col) += out_diff_arr.col(out_col) / static_cast(size); }); } else { UNIMPLEMENTED(); } } static void MaxFWCompute(user_op::KernelComputeContext* ctx, const PoolOpKernelCache* pool_state) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); CHECK_NOTNULL(pool_state); const std::string& data_format = ctx->Attr("data_format"); if (data_format == "channels_first") { CFirstForward( pool_state->GetParams3D(), x, y, GetMinVal, [](const T& lhs, T& rhs) { if (lhs > rhs) { rhs = lhs; } }, [](const int64_t size, T& out) {}); } else if (data_format == "channels_last") { CLastForward( pool_state->GetParams3D(), x, y, GetMinVal, [](const int64_t in_col, const int64_t out_col, ConstEigenMatrixMap& in_mat, EigenMatrixMap& out_mat) { out_mat.col(out_col) = out_mat.col(out_col).cwiseMax(in_mat.col(in_col)); }, [](const int64_t size, const int64_t col, EigenMatrixMap& out_mat) {}); } else { UNIMPLEMENTED(); } } static void MaxBWCompute(user_op::KernelComputeContext* ctx, const PoolOpKernelCache* pool_state) { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); CHECK_NOTNULL(pool_state); const std::string& data_format = ctx->Attr("data_format"); if (data_format == "channels_first") { CFirstBackward( pool_state->GetParams3D(), dy, y, x, dx, [](const T& in, const T& out, const T& out_diff, const int64_t size, T& in_diff) { if (in == out) { in_diff += out_diff; } }); } else if (data_format == "channels_last") { CLastBackward( pool_state->GetParams3D(), dy, y, x, dx, [](const int64_t out_col, const int64_t in_col, const int64_t size, ConstEigenArrayMap& out_arr, ConstEigenArrayMap& in_arr, ConstEigenArrayMap& out_diff_arr, EigenArrayMap& in_diff_arr) { in_diff_arr.col(in_col) += out_diff_arr.col(out_col) * (in_arr.col(in_col).cwiseEqual(out_arr.col(out_col)).template cast()); }); } else { UNIMPLEMENTED(); } } }; } // namespace template class AvgPool1DCpuKernel final : public user_op::OpKernel { public: AvgPool1DCpuKernel() = default; ~AvgPool1DCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 1); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::AvgFWCompute(ctx, dynamic_cast(cache)); }; }; template class AvgPool1DGradCpuKernel final : public user_op::OpKernel { public: AvgPool1DGradCpuKernel() = default; ~AvgPool1DGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 1); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::AvgBWCompute(ctx, dynamic_cast(cache)); }; }; template class AvgPool2DCpuKernel final : public 
user_op::OpKernel { public: AvgPool2DCpuKernel() = default; ~AvgPool2DCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 2); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::AvgFWCompute(ctx, dynamic_cast(cache)); }; }; template class AvgPool2DGradCpuKernel final : public user_op::OpKernel { public: AvgPool2DGradCpuKernel() = default; ~AvgPool2DGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 2); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::AvgBWCompute(ctx, dynamic_cast(cache)); }; }; template class AvgPool3DCpuKernel final : public user_op::OpKernel { public: AvgPool3DCpuKernel() = default; ~AvgPool3DCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 3); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::AvgFWCompute(ctx, dynamic_cast(cache)); }; }; template class AvgPool3DGradCpuKernel final : public user_op::OpKernel { public: AvgPool3DGradCpuKernel() = default; ~AvgPool3DGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 3); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::AvgBWCompute(ctx, dynamic_cast(cache)); }; }; template class MaxPool1DCpuKernel final : public user_op::OpKernel { public: MaxPool1DCpuKernel() = default; ~MaxPool1DCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 1); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::MaxFWCompute(ctx, dynamic_cast(cache)); }; }; template class MaxPool1DGradCpuKernel final : public user_op::OpKernel { public: MaxPool1DGradCpuKernel() = default; ~MaxPool1DGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 1); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::MaxBWCompute(ctx, dynamic_cast(cache)); }; }; template class MaxPool2DCpuKernel final : public user_op::OpKernel { public: MaxPool2DCpuKernel() = default; ~MaxPool2DCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return 
InitPoolOpKernelCache(ctx, 2); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::MaxFWCompute(ctx, dynamic_cast(cache)); }; }; template class MaxPool2DGradCpuKernel final : public user_op::OpKernel { public: MaxPool2DGradCpuKernel() = default; ~MaxPool2DGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 2); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::MaxBWCompute(ctx, dynamic_cast(cache)); }; }; template class MaxPool3DCpuKernel final : public user_op::OpKernel { public: MaxPool3DCpuKernel() = default; ~MaxPool3DCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 3); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::MaxFWCompute(ctx, dynamic_cast(cache)); }; }; template class MaxPool3DGradCpuKernel final : public user_op::OpKernel { public: MaxPool3DGradCpuKernel() = default; ~MaxPool3DGradCpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return InitPoolOpKernelCache(ctx, 3); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolCpuKernelUtil::MaxBWCompute(ctx, dynamic_cast(cache)); }; }; #define REGISTER_POOL_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("tf_avg_pool_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_avg_pool_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_avg_pool_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_avg_pool_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_avg_pool_3d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_avg_pool_3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_max_pool_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_max_pool_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_max_pool_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ 
&& (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_max_pool_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_max_pool_3d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("tf_max_pool_3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_POOL_CPU_KERNEL(float) REGISTER_POOL_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/tf_pool_gpu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/user/utils/pool_util.h" #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { class CudnnPoolDesc final { public: OF_DISALLOW_COPY_AND_MOVE(CudnnPoolDesc); CudnnPoolDesc(cudnnPoolingMode_t pooling_mode, int dims, const int* window, const int* padding, const int* stride) { OF_CUDNN_CHECK(cudnnCreatePoolingDescriptor(&val_)); OF_CUDNN_CHECK(cudnnSetPoolingNdDescriptor(val_, pooling_mode, CUDNN_NOT_PROPAGATE_NAN, dims, window, padding, stride)); } ~CudnnPoolDesc() { OF_CUDNN_CHECK(cudnnDestroyPoolingDescriptor(val_)); } const cudnnPoolingDescriptor_t& Get() const { return val_; } private: cudnnPoolingDescriptor_t val_; }; class GPUPoolOpKernelCache final : public user_op::OpKernelCache { public: GPUPoolOpKernelCache(const int32_t dim, const std::string& pooling_type, const ShapeView& x_shape, const ShapeView& y_shape, const std::string& data_format, const DataType& dtype, const Params3D& params_3d) : pooling_type_(pooling_type) { Reset(dim, pooling_type, x_shape, y_shape, data_format, dtype, params_3d); } ~GPUPoolOpKernelCache() = default; void Reset(const int32_t dim, const std::string& pooling_type, const ShapeView& x_shape, const ShapeView& y_shape, const std::string& data_format, const DataType& dtype, const Params3D& params_3d) { FixedVector pool_size(dim); FixedVector padding(dim); FixedVector strides(dim); FOR_RANGE(int, i, 0, dim) { int32_t index_in_3d = i + 3 - dim; pool_size[i] = params_3d.pool_size_3d().at(index_in_3d); padding[i] = params_3d.padding_before_3d().at(index_in_3d); strides[i] = params_3d.strides_3d().at(index_in_3d); } x_desc_.reset(new CudnnTensorDesc(dtype, x_shape, data_format)); y_desc_.reset(new CudnnTensorDesc(dtype, y_shape, data_format)); cudnnPoolingMode_t pooling_mode; if (pooling_type == "AVG") { pooling_mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; } else if (pooling_type == "MAX") { pooling_mode = CUDNN_POOLING_MAX; } else { 
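// NOTE: only the "AVG" and "MAX" pooling types are mapped to cuDNN modes above
// (CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING and CUDNN_POOLING_MAX); any other
// pooling_type falls through to UNIMPLEMENTED() below. The exclude-padding average
// appears to match the CPU kernels, which also clamp the window before dividing.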
UNIMPLEMENTED(); } pooling_desc_.reset( new CudnnPoolDesc(pooling_mode, dim, pool_size.data(), padding.data(), strides.data())); } static std::shared_ptr FromKernelComputeContext( const int32_t& dim, const std::string& pooling_type, user_op::KernelCacheContext* ctx) { if (pooling_type != "MAX" && pooling_type != "AVG") { UNIMPLEMENTED(); } const ShapeView& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::string& padding = ctx->Attr("padding"); const auto& padding_before = ctx->Attr>("padding_before"); const auto& padding_after = ctx->Attr>("padding_after"); const std::vector& pool_size = ctx->Attr>("pool_size"); const std::vector& strides = ctx->Attr>("strides"); const bool ceil_mode = ctx->Attr("ceil_mode"); const Params3D params_3d(dim, x_shape, data_format, padding, padding_before, padding_after, pool_size, strides, ceil_mode); const ShapeView& y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape(); const DataType dtype = ctx->TensorDesc4ArgNameAndIndex("x", 0)->data_type(); return std::make_shared(dim, pooling_type, x_shape, y_shape, data_format, dtype, params_3d); } const cudnnTensorDescriptor_t& cudnn_x_tensor_desc() const { return x_desc_->Get(); } const cudnnTensorDescriptor_t& cudnn_y_tensor_desc() const { return y_desc_->Get(); } const cudnnPoolingDescriptor_t& cudnn_pooling_desc() const { return pooling_desc_->Get(); } private: std::unique_ptr x_desc_; std::unique_ptr y_desc_; std::unique_ptr pooling_desc_; std::string pooling_type_; }; struct PoolGpuKernelUtil { static void FWCompute(user_op::KernelComputeContext* ctx, const GPUPoolOpKernelCache* gpu_pool_op_kernel_cache) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); CHECK(gpu_pool_op_kernel_cache != nullptr); OF_CUDNN_CHECK(cudnnPoolingForward( ctx->stream()->As()->cudnn_handle(), gpu_pool_op_kernel_cache->cudnn_pooling_desc(), CudnnSPOnePtr(x->data_type()), gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), x->dptr(), CudnnSPZeroPtr(x->data_type()), gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), y->mut_dptr())); } static void BWCompute(user_op::KernelComputeContext* ctx, const GPUPoolOpKernelCache* gpu_pool_op_kernel_cache) { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); CHECK(gpu_pool_op_kernel_cache != nullptr); OF_CUDNN_CHECK(cudnnPoolingBackward( ctx->stream()->As()->cudnn_handle(), gpu_pool_op_kernel_cache->cudnn_pooling_desc(), CudnnSPOnePtr(y->data_type()), gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), y->dptr(), gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), dy->dptr(), gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), x->dptr(), CudnnSPZeroPtr(y->data_type()), gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), dx->mut_dptr())); } }; } // namespace class AvgPool1DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: AvgPool1DGpuKernel() = default; ~AvgPool1DGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(1, "AVG", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* 
cache) const override { PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; class AvgPool1DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: AvgPool1DGradGpuKernel() = default; ~AvgPool1DGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(1, "AVG", ctx); } void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; class AvgPool2DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: AvgPool2DGpuKernel() = default; ~AvgPool2DGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(2, "AVG", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; class AvgPool2DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: AvgPool2DGradGpuKernel() = default; ~AvgPool2DGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(2, "AVG", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; class AvgPool3DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: AvgPool3DGpuKernel() = default; ~AvgPool3DGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(3, "AVG", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; class AvgPool3DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: AvgPool3DGradGpuKernel() = default; ~AvgPool3DGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(3, "AVG", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; class MaxPool1DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MaxPool1DGpuKernel() = default; ~MaxPool1DGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(1, "MAX", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, 
user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; class MaxPool1DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MaxPool1DGradGpuKernel() = default; ~MaxPool1DGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(1, "MAX", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; class MaxPool2DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MaxPool2DGpuKernel() = default; ~MaxPool2DGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(2, "MAX", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; class MaxPool2DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MaxPool2DGradGpuKernel() = default; ~MaxPool2DGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(2, "MAX", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; class MaxPool3DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MaxPool3DGpuKernel() = default; ~MaxPool3DGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(3, "MAX", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; class MaxPool3DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: MaxPool3DGradGpuKernel() = default; ~MaxPool3DGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return GPUPoolOpKernelCache::FromKernelComputeContext(3, "MAX", ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; REGISTER_USER_KERNEL("tf_avg_pool_1d") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_avg_pool_1d_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_avg_pool_2d") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); 
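// NOTE: unlike the CPU registrations (which also constrain HobDataType), these CUDA
// registrations match on device type only. The element type is carried by the cuDNN
// tensor/pooling descriptors built in GPUPoolOpKernelCache, and CudnnSPOnePtr /
// CudnnSPZeroPtr pick the scaling factors from the tensors' runtime data_type.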
REGISTER_USER_KERNEL("tf_avg_pool_2d_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_avg_pool_3d") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_avg_pool_3d_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_max_pool_1d") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_max_pool_1d_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_max_pool_2d") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_max_pool_2d_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_max_pool_3d") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); REGISTER_USER_KERNEL("tf_max_pool_3d_grad") .SetCreateFn() .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)); } // namespace oneflow #endif ================================================ FILE: oneflow/user/kernels/tf_prelu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { template class TfCpuPReluKernel final : public user_op::OpKernel { public: TfCpuPReluKernel() = default; ~TfCpuPReluKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); user_op::Tensor* broadcasted_alpha = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const T* x_ptr = x->dptr(); T* y_ptr = y->mut_dptr(); T* broadcasted_alpha_ptr = broadcasted_alpha->mut_dptr(); const int32_t elem_cnt = x->shape_view().elem_cnt(); const Shape& left_extended_shape = CreateLeftExtendedShape(ShapeView(alpha->shape_view()), x->shape_view().NumAxes()); NdarrayUtil::BroadcastTo( ctx->stream(), XpuVarNdarray(x->shape_view(), broadcasted_alpha_ptr), XpuVarNdarray(left_extended_shape, alpha->dptr())); FOR_RANGE(int32_t, i, 0, elem_cnt) { y_ptr[i] = x_ptr[i] > 0 ? 
x_ptr[i] : x_ptr[i] * broadcasted_alpha_ptr[i]; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_TF_CPU_PRELU_KERNEL(dtype) \ REGISTER_USER_KERNEL("tf_prelu") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("x", 0); \ return GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype)); \ }); REGISTER_TF_CPU_PRELU_KERNEL(float) REGISTER_TF_CPU_PRELU_KERNEL(double) template class TfCpuPReluGradKernel final : public user_op::OpKernel { public: TfCpuPReluGradKernel() = default; ~TfCpuPReluGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex("alpha_diff", 0); const T* x_ptr = x->dptr(); const T* dy_ptr = dy->dptr(); T* dx_ptr = dx->mut_dptr(); const int32_t elem_cnt = x->shape_view().elem_cnt(); T* broadcasted_alpha_ptr = tmp_buffer->mut_dptr(); T* broadcasted_alpha_diff = reinterpret_cast(tmp_buffer->mut_dptr() + GetCudaAlignedSize(elem_cnt * sizeof(T))); T* reduce_sum_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + 2 * GetCudaAlignedSize(elem_cnt * sizeof(T))); const Shape& left_extended_shape = CreateLeftExtendedShape(ShapeView(alpha->shape_view()), x->shape_view().NumAxes()); NdarrayUtil::BroadcastTo( ctx->stream(), XpuVarNdarray(x->shape_view(), broadcasted_alpha_ptr), XpuVarNdarray(left_extended_shape, alpha->dptr())); FOR_RANGE(int32_t, i, 0, elem_cnt) { dx_ptr[i] = x_ptr[i] > 0 ? dy_ptr[i] : dy_ptr[i] * broadcasted_alpha_ptr[i]; broadcasted_alpha_diff[i] = x_ptr[i] > 0 ? 0 : dy_ptr[i] * x_ptr[i]; } NdarrayUtil::ReduceSum( ctx->stream(), XpuVarNdarray(left_extended_shape, alpha_diff->mut_dptr()), XpuVarNdarray(x->shape_view(), broadcasted_alpha_diff), XpuVarNdarray(x->shape_view(), reduce_sum_tmp_buf)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_TF_CPU_PRELU_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("tf_prelu_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("x", 0); \ return 3 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype)); \ }); REGISTER_TF_CPU_PRELU_GRAD_KERNEL(float) REGISTER_TF_CPU_PRELU_GRAD_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/tf_prelu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void BroadcastPReluForwardGpu(const int32_t elem_cnt, const int32_t alpha_size, const int32_t inner_size, const T* x, const T* alpha, T* y) { T zero_val = static_cast(0.0); CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const T x_i = x[i]; const T alpha_i = alpha[(i / inner_size) % alpha_size]; y[i] = x_i > zero_val ? x_i : x_i * alpha_i; } } template __global__ void BroadcastPReluBackwardGpu(const int32_t elem_cnt, const int32_t alpha_size, const int32_t inner_size, const T* x, const T* alpha, const T* dy, T* dx, T* alpha_diff) { T zero_val = static_cast(0.0); CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const T x_i = x[i]; const T dy_i = dy[i]; const T alpha_i = alpha[(i / inner_size) % alpha_size]; T dx_i = zero_val; T alpha_diff_i = zero_val; if (x_i > zero_val) { dx_i = dy_i; alpha_diff_i = zero_val; } else { dx_i = dy_i * alpha_i; alpha_diff_i = dy_i * x_i; } dx[i] = dx_i; alpha_diff[i] = alpha_diff_i; } } template __global__ void ElemwisePReluForwardGpu(const int32_t elem_cnt, const T* x, const T* alpha, T* y) { T zero_val = static_cast(0.0); CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const T x_i = x[i]; const T alpha_i = alpha[i]; y[i] = x_i > zero_val ? x_i : x_i * alpha_i; } } template __global__ void ElemwisePReluBackwardGpu(const int32_t elem_cnt, const T* x, const T* alpha, const T* dy, T* dx, T* alpha_diff) { T zero_val = static_cast(0.0); CUDA_1D_KERNEL_LOOP(i, elem_cnt) { const T x_i = x[i]; const T dy_i = dy[i]; const T alpha_i = alpha[i]; T dx_i = zero_val; T alpha_diff_i = zero_val; if (x_i > zero_val) { dx_i = dy_i; alpha_diff_i = zero_val; } else { dx_i = dy_i * alpha_i; alpha_diff_i = dy_i * x_i; } dx[i] = dx_i; alpha_diff[i] = alpha_diff_i; } } bool IsAlphaShapeContiguous(const ShapeView& alpha_shape, const ShapeView& x_shape) { if (alpha_shape.elem_cnt() == 1) { return true; } int64_t begin_idx = -1; for (int64_t i = 0; i < alpha_shape.NumAxes(); ++i) { if (alpha_shape.At(i) != 1) { begin_idx = i; break; } } CHECK_NE(begin_idx, -1); int64_t end_idx = -1; for (int64_t i = alpha_shape.NumAxes(); i > 0; --i) { if (alpha_shape.At(i - 1) != 1) { end_idx = i; break; } } CHECK_NE(end_idx, -1); if (alpha_shape.elem_cnt() == x_shape.Count(begin_idx + 1, end_idx + 1)) { return true; } else { return false; } } int32_t GetOuterSize(const ShapeView& alpha_shape, const ShapeView& x_shape) { int32_t outer_size = x_shape.At(0); for (int32_t i = 0; i < alpha_shape.NumAxes(); ++i) { if (alpha_shape.At(i) == 1) { outer_size *= x_shape.At(i + 1); } else { break; } } return outer_size; } } // namespace template class TfGpuPReluKernel final : public user_op::OpKernel { public: TfGpuPReluKernel() = default; ~TfGpuPReluKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = x->shape_view().elem_cnt(); if (IsAlphaShapeContiguous(alpha->shape_view(), x->shape_view())) { const int32_t outer_size = GetOuterSize(alpha->shape_view(), x->shape_view()); const int32_t alpha_size = alpha->shape_view().elem_cnt(); const int32_t inner_size = elem_cnt / 
outer_size / alpha_size; BroadcastPReluForwardGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x->dptr(), alpha->dptr(), y->mut_dptr()); } else { user_op::Tensor* broadcasted_alpha = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const Shape& left_extended_shape = CreateLeftExtendedShape(ShapeView(alpha->shape_view()), x->shape_view().NumAxes()); NdarrayUtil::BroadcastTo( ctx->stream(), XpuVarNdarray(x->shape_view(), broadcasted_alpha->mut_dptr()), XpuVarNdarray(left_extended_shape, alpha->dptr())); ElemwisePReluForwardGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, x->dptr(), broadcasted_alpha->dptr(), y->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_TF_CUDA_PRELU_KERNEL(dtype) \ REGISTER_USER_KERNEL("tf_prelu") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("x", 0); \ const Shape& alpha_shape = ctx->InputShape("alpha", 0); \ const int64_t tmp_buffer_size = \ IsAlphaShapeContiguous(alpha_shape, in_shape) \ ? 0 \ : GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype)); \ return tmp_buffer_size; \ }); REGISTER_TF_CUDA_PRELU_KERNEL(half) REGISTER_TF_CUDA_PRELU_KERNEL(float) REGISTER_TF_CUDA_PRELU_KERNEL(double) template class TfGpuPReluGradKernel final : public user_op::OpKernel { public: TfGpuPReluGradKernel() = default; ~TfGpuPReluGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex("alpha_diff", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int32_t elem_cnt = x->shape_view().elem_cnt(); T* broadcasted_alpha_diff = tmp_buffer->mut_dptr(); T* reduce_sum_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + GetCudaAlignedSize(elem_cnt * sizeof(T))); const Shape& left_extended_shape = CreateLeftExtendedShape(ShapeView(alpha->shape_view()), x->shape_view().NumAxes()); if (IsAlphaShapeContiguous(alpha->shape_view(), x->shape_view())) { const int32_t outer_size = GetOuterSize(alpha->shape_view(), x->shape_view()); const int32_t alpha_size = alpha->shape_view().elem_cnt(); const int32_t inner_size = elem_cnt / outer_size / alpha_size; BroadcastPReluBackwardGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, alpha_size, inner_size, x->dptr(), alpha->dptr(), dy->dptr(), dx->mut_dptr(), broadcasted_alpha_diff); } else { T* broadcasted_alpha = reinterpret_cast(tmp_buffer->mut_dptr() + 2 * GetCudaAlignedSize(elem_cnt * sizeof(T))); NdarrayUtil::BroadcastTo( ctx->stream(), XpuVarNdarray(x->shape_view(), broadcasted_alpha), XpuVarNdarray(left_extended_shape, alpha->dptr())); ElemwisePReluBackwardGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, x->dptr(), broadcasted_alpha, dy->dptr(), dx->mut_dptr(), broadcasted_alpha_diff); } NdarrayUtil::ReduceSum( ctx->stream(), XpuVarNdarray(left_extended_shape, alpha_diff->mut_dptr()), XpuVarNdarray(x->shape_view(), broadcasted_alpha_diff), XpuVarNdarray(x->shape_view(), reduce_sum_tmp_buf)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return 
false; } }; #define REGISTER_TF_CUDA_PRELU_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("tf_prelu_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("x", 0); \ const Shape& alpha_shape = ctx->InputShape("alpha", 0); \ const int64_t tmp_buffer_size = \ IsAlphaShapeContiguous(alpha_shape, in_shape) \ ? 2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype)) \ : 3 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype)); \ return tmp_buffer_size; \ }); REGISTER_TF_CUDA_PRELU_GRAD_KERNEL(half) REGISTER_TF_CUDA_PRELU_GRAD_KERNEL(float) REGISTER_TF_CUDA_PRELU_GRAD_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/throw_error_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { class ThrowErrorKernel final : public user_op::OpKernel { public: ThrowErrorKernel() = default; ~ThrowErrorKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { THROW(RuntimeError) << "throw error kernel"; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("throw_error").SetCreateFn(); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/to_contiguous_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/to_contiguous_kernel.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { template struct ToContiguousUtil : ToContiguousUtilBase { using ToContiguousUtilBase::ToContiguousUtilBase; static constexpr size_t dsize = sizeof(T); void operator()() { if (contiguous_dim == -1) { // 0-dim tensor std::memcpy(out_dptr, in_dptr, block_size * dsize); } else { // if input tensor's strides equals to output's, than just copy one memory-contiguous tensor bool is_same = true; for (int64_t i = contiguous_dim; i != -1; --i) { if (out_stride[i] != in_stride[i]) { is_same = false; break; } } if (is_same) { std::memcpy(out_dptr + out_offset * dsize, in_dptr + in_offset * dsize, element_count * dsize); } else { const int64_t ndim = contiguous_dim + 1; int64_t coordinates[ndim]; for (int64_t i = 0; i < element_count; i += block_size) { memset(coordinates, 0, sizeof(int64_t) * ndim); out_offset = i; in_offset = 0; // compute coords(output offset to coords) int64_t remaining = out_offset; for (int i = 0; i < ndim; ++i) { const int64_t idx = remaining / out_stride[i]; coordinates[i] = idx; remaining = remaining - idx * out_stride[i]; } // compute input offset for (int64_t dim = 0; dim < ndim; ++dim) { in_offset += in_stride[dim] * coordinates[dim]; } // copy block_size data to output std::memcpy(out_dptr + out_offset * dsize, in_dptr + in_offset * dsize, block_size * dsize); } } } } }; namespace { template class ToContiguousKernel final : public user_op::OpKernel { public: ToContiguousKernel() = default; ~ToContiguousKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& in_shape = in->shape_view(); CHECK_EQ(out->shape_view(), in_shape); const DataType in_data_type = in->data_type(); CHECK_EQ(out->data_type(), in_data_type); std::vector in_stride(in->stride().begin(), in->stride().end()); const char* in_dptr = static_cast(in->raw_dptr()); char* out_dptr = static_cast(out->mut_raw_dptr()); ToContiguousUtil(ctx->stream(), in_shape, in_stride, in_dptr, out_dptr)(); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_TO_CONTIGUOUS_KERNEL(device_type, cpp_type, data_type) \ REGISTER_USER_KERNEL("to_contiguous") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ && (user_op::HobDataType("in", 0) == data_type)); #define REGISTER_TO_CONTIGUOUS_CPU_KERNEL(cpp_type, data_type) \ REGISTER_TO_CONTIGUOUS_KERNEL(DeviceType::kCPU, cpp_type, data_type) #define REGISTER_TO_CONTIGUOUS_CUDA_KERNEL(cpp_type, data_type) \ REGISTER_TO_CONTIGUOUS_KERNEL(DeviceType::kCUDA, cpp_type, data_type) #define REGISTER_TO_CONTIGUOUS_KERNEL_FOR_CPU_TYPES \ OF_PP_FOR_EACH_TUPLE(REGISTER_TO_CONTIGUOUS_CPU_KERNEL, TO_CONTIGUOUS_CPU_TYPES) #define REGISTER_TO_CONTIGUOUS_KERNEL_FOR_CUDA_TYPES \ OF_PP_FOR_EACH_TUPLE(REGISTER_TO_CONTIGUOUS_CUDA_KERNEL, \ TO_CONTIGUOUS_COMMON_TYPES TO_CONTIGUOUS_CUDA_SPECIAL_TYPE) REGISTER_TO_CONTIGUOUS_KERNEL_FOR_CPU_TYPES #ifdef WITH_CUDA REGISTER_TO_CONTIGUOUS_KERNEL_FOR_CUDA_TYPES #endif } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/to_contiguous_kernel.cu 
================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/device_type.pb.h" #include "oneflow/user/kernels/to_contiguous_kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/cuda/elementwise.cuh" namespace oneflow { namespace { constexpr int32_t kThreadWorkSize = 4; constexpr int32_t kNumThreads = 32 * 4; constexpr int32_t get_min_threads_num() { return kNumThreads; } constexpr int32_t get_block_work_size() { return kThreadWorkSize * kNumThreads; } constexpr int32_t get_num_blocks(int64_t elem_cnt) { return (elem_cnt + get_block_work_size() - 1) / get_block_work_size(); } struct StrideParam { int32_t stride[SHAPE_MAX_AXIS_SIZE]; StrideParam(const int64_t* stride_vec, const size_t ndim) { for (size_t i = 0; i < ndim; ++i) { stride[i] = stride_vec[i]; } } }; template __device__ __forceinline__ IndexType compute_index(IndexType out_offset, const StrideParam& out_params, const StrideParam& in_params) { IndexType in_offset = 0; IndexType remaining = out_offset; #pragma unroll for (size_t i = 0; i < ndim; ++i) { const IndexType idx = static_cast(remaining / out_params.stride[i]); remaining -= idx * out_params.stride[i]; in_offset += idx * in_params.stride[i]; } return in_offset; } template __global__ void ToContiguousForwardGpuParallel(IndexType count, const StrideParam in_stride, const StrideParam out_stride, const T* in_dptr, T* out_dptr, const int32_t num_block_threads, const int32_t thread_work_size, const int32_t block_work_size) { IndexType remaining = count - block_work_size * blockIdx.x; IndexType idx = blockIdx.x; IndexType thread_idx = threadIdx.x; #pragma unroll for (int32_t i = 0; i < thread_work_size; i++) { if (thread_idx >= remaining) { return; } IndexType out_idx = thread_idx + block_work_size * idx; IndexType in_idx = compute_index(out_idx, out_stride, in_stride); out_dptr[out_idx] = in_dptr[in_idx]; thread_idx += num_block_threads; } } template void LaunchToContiguousKernel(ep::Stream* stream, IndexType count, const size_t ndim, IndexType block_size, const std::vector& in_stride, const DimVector& out_stride, const char* in_dptr, char* out_dptr) { const int32_t num_blocks = get_num_blocks(count); constexpr int32_t num_threads = get_min_threads_num(); constexpr int32_t block_work_size = get_block_work_size(); StrideParam param_in_stride(in_stride.data(), ndim), param_out_stride(out_stride.data(), ndim); switch (ndim) { #define TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(dim) \ case dim: \ ToContiguousForwardGpuParallel \ <<As()->cuda_stream()>>>( \ count, param_in_stride, param_out_stride, reinterpret_cast(in_dptr), \ reinterpret_cast(out_dptr), num_threads, kThreadWorkSize, block_work_size); \ break; TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(1) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(2) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(3) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(4) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(5) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(6) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(7) 
TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(8) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(9) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(10) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(11) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(12) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(13) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(14) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(15) TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(16) default: break; #undef TO_CONTIGUOUS_FORWARD_GPU_PARALLEL } } } // namespace template struct ToContiguousUtil : ToContiguousUtilBase { using ToContiguousUtilBase::ToContiguousUtilBase; static constexpr size_t dsize = sizeof(T); void operator()() { int constant_memory_size = 0; const size_t ndims = contiguous_dim + 1; if (ndims == 0) { // 0-dim tensor OF_CUDA_CHECK(cudaMemcpyAsync(out_dptr, in_dptr, block_size * dsize, cudaMemcpyDeviceToDevice, stream->As()->cuda_stream())); } else { bool is_same = true; for (int64_t i = contiguous_dim; i != -1; --i) { if (out_stride[i] != in_stride[i]) { is_same = false; break; } } if (is_same) { // if input tensor's strides equals to output's, than just copy one memory-contiguous tensor OF_CUDA_CHECK(cudaMemcpyAsync(out_dptr, in_dptr, element_count * dsize, cudaMemcpyDeviceToDevice, stream->As()->cuda_stream())); } else { if (element_count < GetMaxVal()) { LaunchToContiguousKernel(stream, element_count, ndims, block_size, in_stride, out_stride, in_dptr, out_dptr); } else { LaunchToContiguousKernel(stream, element_count, ndims, block_size, in_stride, out_stride, in_dptr, out_dptr); } } } } }; #define INSTANTIATE_TO_CONTIGUOUS_UTILS_FOR_CUDA(cpp_type, data_type) \ template struct ToContiguousUtil; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_TO_CONTIGUOUS_UTILS_FOR_CUDA, TO_CONTIGUOUS_COMMON_TYPES TO_CONTIGUOUS_CUDA_SPECIAL_TYPE) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/to_contiguous_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_TO_CONTIGUOUS_KERNEL_H_ #define ONEFLOW_USER_KERNELS_TO_CONTIGUOUS_KERNEL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { class ToContiguousUtilParam { protected: ToContiguousUtilParam(ep::Stream* stream, const ShapeView& in_shape, const std::vector& in_stride, const char* in_dptr, char* out_dptr) : stream(stream), in_shape(in_shape), in_stride(in_stride), in_dptr(in_dptr), out_dptr(out_dptr) {} ep::Stream* stream; const ShapeView& in_shape; const std::vector& in_stride; const char* in_dptr; char* out_dptr; }; class ToContiguousUtilBase : public ToContiguousUtilParam { public: ToContiguousUtilBase(ep::Stream* stream, const ShapeView& in_shape, const std::vector& in_stride, const char* in_dptr, char* out_dptr) : ToContiguousUtilParam(stream, in_shape, in_stride, in_dptr, out_dptr), block_size(1), contiguous_dim(in_shape.NumAxes() - 1), out_stride(in_shape.NumAxes()), in_offset(0), out_offset(0), element_count(1) { for (int64_t i = contiguous_dim; i != -1; --i) { out_stride[i] = element_count; element_count *= in_shape.At(i); } for (int64_t i = contiguous_dim; i != -1; --i) { if (block_size == in_stride[i]) { block_size *= in_shape.At(i); } else { break; } } } int64_t block_size = 1; int64_t contiguous_dim = 0; DimVector out_stride; int64_t in_offset = 0; int64_t out_offset = 0; int64_t element_count = 1; }; template struct ToContiguousUtil : ToContiguousUtilBase { using ToContiguousUtilBase::ToContiguousUtilBase; void operator()(); }; } // namespace oneflow #define TO_CONTIGUOUS_COMMON_TYPES \ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool) \ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar) \ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) \ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) \ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define TO_CONTIGUOUS_CPU_TYPES \ TO_CONTIGUOUS_COMMON_TYPES COMPLEX_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ( \ float16, DataType::kFloat16) OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16) #ifdef WITH_CUDA #if CUDA_VERSION >= 11000 #define TO_CONTIGUOUS_CUDA_SPECIAL_TYPE \ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) \ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16) \ OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64) \ OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128) #else #define TO_CONTIGUOUS_CUDA_SPECIAL_TYPE \ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) \ OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64) \ OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128) #endif // CUDA_VERSION >= 11000 #endif // WITH_CUDA #endif // ONEFLOW_USER_KERNELS_TO_CONTIGUOUS_KERNEL_H_ ================================================ FILE: oneflow/user/kernels/top_k_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/common/balanced_splitter.h" namespace oneflow { namespace { template void ComputeTopOne(const T* in_ptr, const Range& range, int64_t instance_size, int64_t* out_ptr) { FOR_RANGE(int64_t, i, range.begin(), range.end()) { const T* in_ptr_i = in_ptr + i * instance_size; out_ptr[i] = std::distance(in_ptr_i, std::max_element(in_ptr_i, in_ptr_i + instance_size)); } } template void ComputeTopK(const T* in_ptr, int64_t* indices_ptr, const Range& range, int64_t instance_size, int64_t k, bool sorted, int64_t* out_ptr) { FOR_RANGE(int64_t, i, range.begin(), range.end()) { const int64_t offset = i * instance_size; const T* in_ptr_i = in_ptr + offset; int64_t* indices_ptr_i = indices_ptr + offset; std::iota(indices_ptr_i, indices_ptr_i + instance_size, 0); auto comp = [&](const int64_t lhs, const int64_t rhs) { const T l = in_ptr_i[lhs]; const T r = in_ptr_i[rhs]; if (l == r) { return lhs < rhs; } else { return l > r; } }; std::nth_element(indices_ptr_i, indices_ptr_i + k, indices_ptr_i + instance_size, comp); if (sorted) { std::sort(indices_ptr_i, indices_ptr_i + k, comp); } std::copy(indices_ptr_i, indices_ptr_i + k, out_ptr + i * k); } } template void CpuTopK(ep::Stream* /*stream*/, const T* in_ptr, int64_t* indices_ptr, int64_t instance_num, int64_t instance_size, int64_t k, bool sorted, int64_t* out_ptr) { const int64_t num_thread = std::min(instance_num, static_cast(Singleton::Get()->thread_num())); const BalancedSplitter bs(instance_num, num_thread); BlockingCounter bc(num_thread); FOR_RANGE(int64_t, thread_id, 0, num_thread) { const Range range = bs.At(thread_id); Singleton::Get()->AddWork([=, &bc]() { if (k == 1) { ComputeTopOne(in_ptr, range, instance_size, out_ptr); } else { ComputeTopK(in_ptr, indices_ptr, range, instance_size, k, sorted, out_ptr); } bc.Decrease(); }); } bc.WaitForeverUntilCntEqualZero(); } } // namespace template class TopKCpuKernel final : public user_op::OpKernel { public: TopKCpuKernel() = default; ~TopKCpuKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); if (in->shape_view().elem_cnt() == 0) { return; } user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1); const int64_t instance_num = in->shape_view().elem_cnt() / instance_size; const int64_t k = std::min(static_cast(ctx->Attr("k")), instance_size); int64_t* indices_ptr = tmp_buffer ? 
tmp_buffer->mut_dptr() : nullptr; CpuTopK(ctx->stream(), in->dptr(), indices_ptr, instance_num, instance_size, k, ctx->Attr("sorted"), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_TOP_K_KERNEL(dtype) \ REGISTER_USER_KERNEL("top_k") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("in", 0); \ return ctx->Attr("k") > 1 ? in_shape.elem_cnt() * sizeof(int64_t) : 0; \ }); REGISTER_CPU_TOP_K_KERNEL(float) REGISTER_CPU_TOP_K_KERNEL(double) REGISTER_CPU_TOP_K_KERNEL(int8_t) REGISTER_CPU_TOP_K_KERNEL(uint8_t) REGISTER_CPU_TOP_K_KERNEL(int32_t) REGISTER_CPU_TOP_K_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/top_k_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/radix_sort.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template T PowOf2Floor(T val, int64_t max_power) { CHECK_GT(val, GetZeroVal()); T max_floor = static_cast(std::pow(2, max_power)); val = std::min(val, max_floor); T ret = GetOneVal(); while (true) { ret *= 2; if (ret >= val) { return ret == val ? 
ret : ret / 2; } } } template T PowOf2Ceil(T val, int64_t max_power) { CHECK_GT(val, GetZeroVal()); T max_ceil = static_cast(std::pow(2, max_power)); val = std::min(val, max_ceil); T ret = GetOneVal(); while (true) { ret *= 2; if (ret >= val) { return ret; } } } template __device__ void BitonicSwap(T* data, const int64_t i, const int64_t j, const bool dir, const Compare& comp) { if (comp(data[i], data[j]) == dir) { T tmp = data[i]; data[i] = data[j]; data[j] = tmp; } } // https://en.wikipedia.org/wiki/Bitonic_sorter template __device__ void BitonicSort(T* data, const int64_t elem_cnt, const Compare& comp) { // The element count of instance should be pow-of-2 assert(elem_cnt > 0 && !(elem_cnt & (elem_cnt - 1))); // Generate a bitonic sequence from input for (int64_t size = 2; size <= elem_cnt / 2; size *= 2) { // Merge 2 bitonic sequences of length 'size' into a bitonic sequence of length '2 * size' for (int64_t stride = size / 2; stride > 0; stride /= 2) { for (int64_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; swap_id += blockDim.x) { // Change dir at intervals of 'size / 2' swaps const bool dir = swap_id & (size / 2); // Locate the pair {pos, pos + stride} which is going te be swaped if needed const int pos = 2 * swap_id - (swap_id & (stride - 1)); BitonicSwap(data, pos, pos + stride, dir, comp); __syncthreads(); } } } // Sort the bitonic sequence for (int64_t stride = elem_cnt / 2; stride > 0; stride /= 2) { for (int64_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; swap_id += blockDim.x) { // Locate the pair {pos, pos + stride} which is going te be swaped if needed const int pos = 2 * swap_id - (swap_id & (stride - 1)); BitonicSwap(data, pos, pos + stride, false, comp); __syncthreads(); } } } template class Entry final { public: __device__ __forceinline__ Entry(int64_t index, T value) : index_(index), value_(value) {} __device__ __forceinline__ int64_t GetIndex() const { return index_; } __device__ __forceinline__ T GetValue() const { return value_; } __device__ __forceinline__ void SetIndex(int64_t index) { index_ = index; } __device__ __forceinline__ void SetValue(T value) { value_ = value; } __device__ __forceinline__ bool operator<(const Entry& entry) const { return (value_ < entry.GetValue()) || (value_ == entry.GetValue() && index_ > entry.GetIndex()); } __device__ __forceinline__ bool operator>(const Entry& entry) const { return (value_ > entry.GetValue()) || (value_ == entry.GetValue() && index_ < entry.GetIndex()); } private: int64_t index_; T value_; }; template class MinHeap final { public: __device__ __forceinline__ MinHeap(Entry* data, const int64_t heap_size, const int64_t init_index, const T init_value) : data_(data), heap_size_(heap_size) { for (int64_t i = 0; i < heap_size; ++i) { data_[i].SetIndex(init_index); data_[i].SetValue(init_value); } } __device__ __forceinline__ Entry& Top() { return data_[0]; } __device__ __forceinline__ void Swap(const int64_t i, const int64_t j) { auto tmp = data_[j]; data_[j] = data_[i]; data_[i] = tmp; } __device__ __forceinline__ void MinHeapify(int64_t index) { while (true) { const int64_t left = 2 * index + 1; const int64_t right = 2 * index + 2; int64_t min = index; if (left < heap_size_ && data_[left] < data_[min]) { min = left; } if (right < heap_size_ && data_[right] < data_[min]) { min = right; } if (min == index) { return; } Swap(min, index); index = min; } } private: Entry* data_; int64_t heap_size_; }; template __global__ void HeapTopKKernel(const T* in_ptr, const int64_t instance_num, const int64_t instance_size, const 
int64_t k, const int64_t heap_size, const int64_t init_index, const T init_value, int64_t* out_ptr) { extern __shared__ char smem[]; auto* shared_entries = reinterpret_cast*>(smem); // Divide elements to be sorted into disjoint sets (# of sets == # of heaps). // Each thread in the thread block manipulates one heap to select top heap_size entries from // corresponding set const T* input = in_ptr + blockIdx.x * instance_size; auto heap = MinHeap(shared_entries + threadIdx.x * heap_size, heap_size, init_index, init_value); for (int64_t i = threadIdx.x; i < instance_size; i += blockDim.x) { auto entry = Entry(i, input[i]); if (entry > heap.Top()) { heap.Top() = entry; heap.MinHeapify(0); } } __syncthreads(); // Merge all heaps into a unified, sorted array BitonicSort(shared_entries, blockDim.x * heap_size, [](const Entry& x, const Entry& y) { return x > y; }); // Write top_k elements in sorted array to output for (int64_t i = threadIdx.x; i < k; i += blockDim.x) { (out_ptr + blockIdx.x * k)[i] = shared_entries[i].GetIndex(); } } template class TmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager); TmpBufferManager(int64_t capacity, void* ptr, const ShapeView& in_shape) : capacity_{capacity}, sorted_in_elem_cnt_{in_shape.elem_cnt()}, indices_elem_cnt_{sorted_in_elem_cnt_}, sorted_indices_elem_cnt_{sorted_in_elem_cnt_} { const int64_t sorted_in_aligned_bytes = GetCudaAlignedSize(sorted_in_elem_cnt_ * sizeof(T)); const int64_t indices_aligned_bytes = GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int64_t)); const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes; sorted_in_ptr_ = reinterpret_cast(ptr); indices_ptr_ = reinterpret_cast(reinterpret_cast(sorted_in_ptr_) + sorted_in_aligned_bytes); sorted_indices_ptr_ = reinterpret_cast(reinterpret_cast(indices_ptr_) + indices_aligned_bytes); temp_storage_ptr_ = reinterpret_cast(reinterpret_cast(sorted_indices_ptr_) + sorted_indices_aligned_bytes); temp_storage_bytes_ = capacity_ - sorted_in_aligned_bytes - indices_aligned_bytes - sorted_indices_aligned_bytes; CHECK_GE(temp_storage_bytes_, 0); } ~TmpBufferManager() = default; T* SortedInPtr() const { return sorted_in_ptr_; } int64_t* IndicesPtr() const { return indices_ptr_; } int64_t* SortedIndicesPtr() const { return sorted_indices_ptr_; } void* TempStoragePtr() const { return temp_storage_ptr_; } int64_t TempStorageBytes() const { return temp_storage_bytes_; } private: int64_t capacity_; T* sorted_in_ptr_; int64_t* indices_ptr_; int64_t* sorted_indices_ptr_; void* temp_storage_ptr_; int64_t sorted_in_elem_cnt_; int64_t indices_elem_cnt_; int64_t sorted_indices_elem_cnt_; int64_t temp_storage_bytes_; }; __global__ void InitializeIndices(int64_t elem_cnt, int64_t* indices_ptr, int64_t instance_size) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i % instance_size; }; } } // namespace template class GpuTopKKernel final : public user_op::OpKernel { public: GpuTopKKernel() = default; ~GpuTopKKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); if (in->shape_view().elem_cnt() == 0) { return; } user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1); const int64_t instance_num = elem_cnt / instance_size; const int64_t k = std::min(static_cast(ctx->Attr("k")), instance_size); if (k > 30 
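// Dispatch: for larger k (the cutoff of 30 appears to be an empirically chosen
// threshold) or a single instance, fall through to the CUB radix-sort path, which
// fully sorts every instance and then copies out the first k indices per row;
// otherwise use the shared-memory heap + bitonic-sort kernel above, which selects
// the top k without sorting whole instances.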
|| instance_num == 1) { user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buf_manager(static_cast(tmp_buffer->shape_view().elem_cnt()), tmp_buffer->mut_dptr(), in->shape_view()); InitializeIndices<<stream()->As()->cuda_stream()>>>( elem_cnt, buf_manager.IndicesPtr(), instance_size); SortPairsDescending(in->dptr(), buf_manager.IndicesPtr(), instance_num, instance_size, buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(), buf_manager.SortedInPtr(), buf_manager.SortedIndicesPtr(), ctx->stream()->As()->cuda_stream()); OF_CUDA_CHECK(cudaMemcpy2DAsync( out->mut_dptr(), k * sizeof(int64_t), buf_manager.SortedIndicesPtr(), instance_size * sizeof(int64_t), k * sizeof(int64_t), instance_num, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); } else { // Use as many heaps as possible (# of heaps == # of threads used in thread block). // Limitation 1: size of shared memory // We also need heap_size * num_heap to be pow-of-2 which is necessary for bitonic sort const int64_t heap_size = PowOf2Ceil(k, 16); int32_t num_heap = PowOf2Floor(kCudaMaxSharedMemoryByteSize / (heap_size * sizeof(Entry)), 16); // Limitation 2: # of threads in thread block num_heap = std::min(num_heap, kCudaThreadsNumPerBlock); HeapTopKKernel<<), ctx->stream()->As()->cuda_stream()>>>( in->dptr(), instance_num, instance_size, k, heap_size, GetMaxVal(), GetMinVal(), out->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_TOP_K_KERNEL(dtype) \ REGISTER_USER_KERNEL("top_k") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("in", 0); \ const int64_t elem_cnt = in_shape.elem_cnt(); \ const int64_t instance_size = in_shape.dim_vec().back(); \ const int64_t instance_num = elem_cnt / instance_size; \ \ /* Sorted In*/ \ const int64_t sorted_in_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(dtype)); \ /* Indices */ \ const int64_t indices_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(int64_t)); \ /* Sorted Indices */ \ const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes; \ /* CUB Temp Storage */ \ int64_t temp_storage_bytes = \ InferTempStorageForSortPairsDescending(instance_num, instance_size); \ \ return sorted_in_aligned_bytes + indices_aligned_bytes + sorted_indices_aligned_bytes \ + temp_storage_bytes; \ }); REGISTER_CUDA_TOP_K_KERNEL(float) REGISTER_CUDA_TOP_K_KERNEL(double) REGISTER_CUDA_TOP_K_KERNEL(uint8_t) REGISTER_CUDA_TOP_K_KERNEL(int8_t) REGISTER_CUDA_TOP_K_KERNEL(int32_t) REGISTER_CUDA_TOP_K_KERNEL(int64_t) REGISTER_CUDA_TOP_K_KERNEL(half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/transpose_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/common/primitive/permute.h" namespace oneflow { namespace user_op { namespace { bool IsIdentity(const ShapeView& in_shape, const std::vector& perm) { constexpr int kMaxNumDims = 12; CHECK_LE(in_shape.NumAxes(), kMaxNumDims); CHECK_EQ(in_shape.NumAxes(), perm.size()); size_t simplified_num_dims{}; int64_t simplified_src_dims[kMaxNumDims]{}; int simplified_permutation[kMaxNumDims]{}; ep::primitive::permute::SimplifyPermutation( in_shape.NumAxes(), in_shape.ptr(), perm.data(), &simplified_num_dims, simplified_src_dims, simplified_permutation); for (int i = 0; i < simplified_num_dims; ++i) { if (simplified_permutation[i] != i) { return false; } } return true; } } // namespace template std::unique_ptr NewPermutePrimitive(Context* ctx) { const int64_t num_dims = ctx->TensorDesc4ArgNameAndIndex("output", 0)->shape().NumAxes(); return ep::primitive::NewPrimitive(ctx->device_type(), num_dims); } class TransposeKernel final : public OpKernel, public user_op::CudaGraphSupport { public: OF_DISALLOW_COPY_AND_MOVE(TransposeKernel); TransposeKernel() = default; ~TransposeKernel() override = default; private: void Compute(KernelComputeContext* ctx) const override { auto primitive = NewPermutePrimitive(ctx); CHECK(primitive); const Tensor* tensor_in = ctx->Tensor4ArgNameAndIndex("input", 0); Tensor* tensor_out = ctx->Tensor4ArgNameAndIndex("output", 0); const auto& perm = ctx->Attr>("perm"); const ShapeView& in_shape = tensor_in->shape_view(); DataType dtype = tensor_out->data_type(); size_t num_dims = tensor_in->shape_view().NumAxes(); const int64_t* src_dims = in_shape.ptr(); int64_t elem_cnt = tensor_out->shape_view().elem_cnt(); if (elem_cnt != 0) { if (IsIdentity(in_shape, perm)) { // if permute vector is 0,1,...,n, do data copy directly AutoMemcpy(ctx->stream(), tensor_out->mut_dptr(), tensor_in->dptr(), elem_cnt * GetSizeOfDataType(dtype), tensor_out->mem_case(), tensor_in->mem_case()); } else { primitive->Launch(ctx->stream(), dtype, num_dims, src_dims, tensor_in->dptr(), perm.data(), tensor_out->mut_dptr()); } } else { // For 0-d Tensor return; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; auto PermutePrimitiveExists() { return hob::make_custom("PermutePrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewPermutePrimitive(&ctx).operator bool(); }); } REGISTER_USER_KERNEL("transpose") .SetCreateFn() .SetIsMatchedHob(PermutePrimitiveExists() == true) .SetInplaceProposalFn([](const user_op::InferContext& ctx, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { const ShapeView input_shape(ctx.InputShape("input", 0)); const auto& perm = ctx.Attr>("perm"); if (IsIdentity(input_shape, perm)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("output", 0, "input", 0, false)); } return Maybe::Ok(); }); } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/tril_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template class CpuTrilKernel final : public user_op::OpKernel { public: CpuTrilKernel() = default; ~CpuTrilKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0); const auto shape = x->shape_view(); const auto diagonal = ctx->Attr("diagonal"); const int64_t num_rows = shape.At(shape.NumAxes() - 2); const int64_t num_cols = shape.At(shape.NumAxes() - 1); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); T* y_dptr = y->mut_dptr(); const T* x_dptr = x->dptr(); const T fill = ctx->Attr("is_floating_fill_value") ? static_cast(ctx->Attr("floating_fill_value")) : static_cast(ctx->Attr("integer_fill_value")); int64_t matrix_size = num_rows * num_cols; for (int64_t k = 0; k < shape.elem_cnt(); ++k) { int64_t offset_in_matrix = k % matrix_size; int64_t i = offset_in_matrix / num_cols; int64_t j = offset_in_matrix - num_cols * i; y_dptr[k] = j > i + diagonal ? fill : x_dptr[k]; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_TRIL_KERNEL(dtype) \ REGISTER_USER_KERNEL("tril").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_CPU_TRIL_KERNEL(float) REGISTER_CPU_TRIL_KERNEL(double) REGISTER_CPU_TRIL_KERNEL(bool) REGISTER_CPU_TRIL_KERNEL(uint8_t) REGISTER_CPU_TRIL_KERNEL(int8_t) REGISTER_CPU_TRIL_KERNEL(int32_t) REGISTER_CPU_TRIL_KERNEL(int64_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/tril_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/util/cuda_half_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void TrilGpu(const int64_t elem_cnt, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const T* x, const T fill, T* y) { const int64_t matrix_size = num_rows * num_cols; CUDA_1D_KERNEL_LOOP_T(int64_t, k, elem_cnt) { const int64_t offset_in_matrix = k % matrix_size; const int64_t i = offset_in_matrix / num_cols; const int64_t j = offset_in_matrix - num_cols * i; y[k] = j > i + diagonal ? 
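// Lower-triangle predicate: keep x[k] when its column j is at or below the shifted
// diagonal (j <= i + diagonal) and write `fill` above it. For a 3x3 matrix with
// diagonal == 0 this keeps (0,0), (1,0), (1,1), (2,0), (2,1) and (2,2).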
fill : x[k]; } } template __global__ void TrilWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const T* x, const T fill, T* y) { const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize; const int64_t lan_id = threadIdx.x % kCudaWarpSize; const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize; for (int64_t i = warp_id; i < total_rows; i += num_warp) { const int64_t row = i % num_rows; for (int64_t col = lan_id; col < num_cols; col += kCudaWarpSize) { const int64_t idx = i * num_cols + col; y[idx] = col > row + diagonal ? fill : x[idx]; } } } template<> __global__ void TrilWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const half* x, const half fill, half* y) { const int64_t h2_num_cols = num_cols / 2; const auto* x_h2 = reinterpret_cast(x); auto* y_h2 = reinterpret_cast(y); const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize; const int64_t lan_id = threadIdx.x % kCudaWarpSize; const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize; for (int64_t i = warp_id; i < total_rows; i += num_warp) { const int64_t row = i % num_rows; for (int64_t col = lan_id; col < h2_num_cols; col += kCudaWarpSize) { const int64_t idx = i * h2_num_cols + col; const half2 x_val = x_h2[idx]; half2 y_val; y_val.x = (2 * col) > row + diagonal ? fill : x_val.x; y_val.y = (2 * col + 1) > row + diagonal ? fill : x_val.y; y_h2[idx] = y_val; } } } template __global__ void FusedScaleTrilGpu(const int64_t elem_cnt, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const T scale, const T* x, const T fill, T* y) { const int64_t matrix_size = num_rows * num_cols; CUDA_1D_KERNEL_LOOP_T(int64_t, k, elem_cnt) { const int64_t offset_in_matrix = k % matrix_size; const int64_t i = offset_in_matrix / num_cols; const int64_t j = offset_in_matrix - num_cols * i; y[k] = j > i + diagonal ? fill : (scale * x[k]); } } template __global__ void FusedScaleTrilWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const T scale, const T* x, const T fill, T* y) { const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize; const int64_t lan_id = threadIdx.x % kCudaWarpSize; const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize; for (int64_t i = warp_id; i < total_rows; i += num_warp) { const int64_t row = i % num_rows; for (int64_t col = lan_id; col < num_cols; col += kCudaWarpSize) { const int64_t idx = i * num_cols + col; y[idx] = col > row + diagonal ? 
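// Fused variant: the input is scaled and masked in a single pass instead of running
// a separate scale kernel before tril. The half specialization below additionally
// packs two columns per lane through half2 loads, which is why callers only take
// the warp-per-row path when num_cols % (kCudaWarpSize * 2) == 0.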
fill : (scale * x[idx]); } } } template<> __global__ void FusedScaleTrilWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const half scale, const half* x, const half fill, half* y) { const int64_t h2_num_cols = num_cols / 2; const auto* x_h2 = reinterpret_cast(x); auto* y_h2 = reinterpret_cast(y); const half2 h2_scale = __half2half2(scale); const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize; const int64_t lan_id = threadIdx.x % kCudaWarpSize; const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize; for (int64_t i = warp_id; i < total_rows; i += num_warp) { const int64_t row = i % num_rows; for (int64_t col = lan_id; col < h2_num_cols; col += kCudaWarpSize) { const int64_t idx = i * h2_num_cols + col; const half2 scaled_x = __hmul2(h2_scale, x_h2[idx]); half2 y_val; y_val.x = (2 * col) > row + diagonal ? fill : scaled_x.x; y_val.y = (2 * col + 1) > row + diagonal ? fill : scaled_x.y; y_h2[idx] = y_val; } } } template T GetAttrVal(bool is_floating_val, double floating_value, int64_t integer_value) { return is_floating_val ? static_cast(floating_value) : static_cast(integer_value); } template<> half GetAttrVal(bool is_floating_val, double floating_value, int64_t integer_value) { return is_floating_val ? __float2half(floating_value) : __float2half(integer_value); } } // namespace template class GpuTrilKernel final : public user_op::OpKernel { public: GpuTrilKernel() = default; ~GpuTrilKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0); const auto shape = x->shape_view(); const auto diagonal = ctx->Attr("diagonal"); const int64_t num_rows = shape.At(shape.NumAxes() - 2); const int64_t num_cols = shape.At(shape.NumAxes() - 1); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t elem_cnt = shape.elem_cnt(); const T fill = GetAttrVal(ctx->Attr("is_floating_fill_value"), ctx->Attr("floating_fill_value"), ctx->Attr("integer_fill_value")); if (num_cols % (kCudaWarpSize * 2) == 0) { const int64_t total_rows = elem_cnt / num_cols; TrilWarpProcessRowGpu<<stream()->As()->cuda_stream()>>>( total_rows, num_rows, num_cols, diagonal, x->dptr(), fill, y->mut_dptr()); } else { TrilGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, num_rows, num_cols, diagonal, x->dptr(), fill, y->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_TRIL_KERNEL(dtype) \ REGISTER_USER_KERNEL("tril") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); \ return Maybe::Ok(); \ }); REGISTER_CUDA_TRIL_KERNEL(float) REGISTER_CUDA_TRIL_KERNEL(double) REGISTER_CUDA_TRIL_KERNEL(bool) REGISTER_CUDA_TRIL_KERNEL(uint8_t) REGISTER_CUDA_TRIL_KERNEL(int8_t) REGISTER_CUDA_TRIL_KERNEL(int32_t) REGISTER_CUDA_TRIL_KERNEL(int64_t) REGISTER_CUDA_TRIL_KERNEL(half) template class GpuFusedScaleTrilKernel final : public user_op::OpKernel { public: GpuFusedScaleTrilKernel() = default; ~GpuFusedScaleTrilKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const 
user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0); const auto shape = x->shape_view(); const auto diagonal = ctx->Attr("diagonal"); const int32_t num_rows = shape.At(shape.NumAxes() - 2); const int32_t num_cols = shape.At(shape.NumAxes() - 1); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t elem_cnt = shape.elem_cnt(); const T fill = GetAttrVal(ctx->Attr("is_floating_fill_value"), ctx->Attr("floating_fill_value"), ctx->Attr("integer_fill_value")); const T scale = GetAttrVal(ctx->Attr("is_floating_scale_value"), ctx->Attr("floating_scale_value"), ctx->Attr("integer_scale_value")); if (num_cols % (kCudaWarpSize * 2) == 0) { const int64_t total_rows = elem_cnt / num_cols; FusedScaleTrilWarpProcessRowGpu<<stream()->As()->cuda_stream()>>>( total_rows, num_rows, num_cols, diagonal, scale, x->dptr(), fill, y->mut_dptr()); } else { FusedScaleTrilGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, num_rows, num_cols, diagonal, scale, x->dptr(), fill, y->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(dtype) \ REGISTER_USER_KERNEL("fused_scale_tril") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); \ return Maybe::Ok(); \ }); REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(float) REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(double) REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(bool) REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(uint8_t) REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(int8_t) REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(int32_t) REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(int64_t) REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(half) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/triu_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template class CpuTriuKernel final : public user_op::OpKernel { public: CpuTriuKernel() = default; ~CpuTriuKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0); const auto shape = x->shape_view(); const auto diagonal = ctx->Attr("diagonal"); const int64_t num_rows = shape.At(shape.NumAxes() - 2); const int64_t num_cols = shape.At(shape.NumAxes() - 1); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); T* y_dptr = y->mut_dptr(); const T* x_dptr = x->dptr(); int64_t matrix_size = num_rows * num_cols; for (int64_t k = 0; k < shape.elem_cnt(); ++k) { int64_t offset_in_matrix = k % matrix_size; int64_t i = offset_in_matrix / num_cols; int64_t j = offset_in_matrix - num_cols * i; y_dptr[k] = j < i + diagonal ? static_cast(0) : x_dptr[k]; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CPU_TRIU_KERNEL(dtype) \ REGISTER_USER_KERNEL("triu").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("out", 0) == GetDataType::value)); REGISTER_CPU_TRIU_KERNEL(float16) REGISTER_CPU_TRIU_KERNEL(float) REGISTER_CPU_TRIU_KERNEL(double) REGISTER_CPU_TRIU_KERNEL(uint8_t) REGISTER_CPU_TRIU_KERNEL(int8_t) REGISTER_CPU_TRIU_KERNEL(int32_t) REGISTER_CPU_TRIU_KERNEL(int64_t) REGISTER_CPU_TRIU_KERNEL(bool) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/triu_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/util/cuda_half_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void TriuGpu(const int64_t elem_cnt, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const T* x, T* y) { const int64_t matrix_size = num_rows * num_cols; CUDA_1D_KERNEL_LOOP_T(int64_t, k, elem_cnt) { const int64_t offset_in_matrix = k % matrix_size; const int64_t i = offset_in_matrix / num_cols; const int64_t j = offset_in_matrix - num_cols * i; y[k] = j < i + diagonal ? static_cast(0) : x[k]; } } template __global__ void TriuWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const T* x, T* y) { const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize; const int64_t lan_id = threadIdx.x % kCudaWarpSize; const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize; for (int64_t i = warp_id; i < total_rows; i += num_warp) { const int64_t row = i % num_rows; for (int64_t col = lan_id; col < num_cols; col += kCudaWarpSize) { const int64_t idx = i * num_cols + col; y[idx] = col < row + diagonal ? 
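// Warp-per-row mapping: warp_id strides over the flattened rows while the lanes of
// each warp stride over the columns, so neighbouring lanes touch neighbouring
// addresses and accesses stay coalesced. An element is zeroed when it lies strictly
// below the shifted diagonal (col < row + diagonal); triu keeps everything else.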
static_cast(0) : x[idx]; } } } template<> __global__ void TriuWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows, const int64_t num_cols, const int64_t diagonal, const half* x, half* y) { const int64_t h2_num_cols = num_cols / 2; const auto* x_h2 = reinterpret_cast(x); auto* y_h2 = reinterpret_cast(y); const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize; const int64_t lan_id = threadIdx.x % kCudaWarpSize; const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize; for (int64_t i = warp_id; i < total_rows; i += num_warp) { const int64_t row = i % num_rows; for (int64_t col = lan_id; col < h2_num_cols; col += kCudaWarpSize) { const int64_t idx = i * h2_num_cols + col; const half2 x_val = x_h2[idx]; half2 y_val; y_val.x = (2 * col) < row + diagonal ? static_cast(0) : x_val.x; y_val.y = (2 * col + 1) < row + diagonal ? static_cast(0) : x_val.y; y_h2[idx] = y_val; } } } } // namespace template class GpuTriuKernel final : public user_op::OpKernel { public: GpuTriuKernel() = default; ~GpuTriuKernel() override = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0); const auto shape = x->shape_view(); const auto diagonal = ctx->Attr("diagonal"); const int64_t num_rows = shape.At(shape.NumAxes() - 2); const int64_t num_cols = shape.At(shape.NumAxes() - 1); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t elem_cnt = shape.elem_cnt(); if (elem_cnt == 0) { return; } if (num_cols % (kCudaWarpSize * 2) == 0) { const int64_t total_rows = elem_cnt / num_cols; TriuWarpProcessRowGpu<<stream()->As()->cuda_stream()>>>( total_rows, num_rows, num_cols, diagonal, x->dptr(), y->mut_dptr()); } else { TriuGpu<<stream()->As()->cuda_stream()>>>( elem_cnt, num_rows, num_cols, diagonal, x->dptr(), y->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_TRIU_KERNEL(dtype) \ REGISTER_USER_KERNEL("triu") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("out", 0) == GetDataType::value)) \ .SetInplaceProposalFn([](const user_op::InferContext&, \ user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); \ return Maybe::Ok(); \ }); REGISTER_CUDA_TRIU_KERNEL(half) REGISTER_CUDA_TRIU_KERNEL(float) REGISTER_CUDA_TRIU_KERNEL(double) REGISTER_CUDA_TRIU_KERNEL(uint8_t) REGISTER_CUDA_TRIU_KERNEL(int8_t) REGISTER_CUDA_TRIU_KERNEL(int32_t) REGISTER_CUDA_TRIU_KERNEL(int64_t) REGISTER_CUDA_TRIU_KERNEL(bool) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/tuple_identity_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" namespace oneflow { namespace { template class TupleIdentityKernel final : public user_op::OpKernel { public: TupleIdentityKernel() = default; ~TupleIdentityKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const int64_t in_size = ctx->input_size("in"); CHECK_EQ(ctx->output_size("out"), in_size); for (int64_t i = 0; i < in_size; ++i) { const user_op::Tensor* in_i = ctx->Tensor4ArgNameAndIndex("in", i); user_op::Tensor* out_i = ctx->Tensor4ArgNameAndIndex("out", i); const DataType data_type = in_i->data_type(); CHECK_EQ(out_i->data_type(), data_type); const ShapeView& shape = in_i->shape_view(); CHECK_EQ(out_i->shape_view(), shape); Memcpy(ctx->stream(), out_i->mut_dptr(), in_i->dptr(), shape.elem_cnt() * GetSizeOfDataType(data_type)); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_TUPLE_IDENTITY_KERNEL(device) \ REGISTER_USER_KERNEL("tuple_identity") \ .SetCreateFn>() \ .SetIsMatchedHob(user_op::HobDeviceType() == device); REGISTER_TUPLE_IDENTITY_KERNEL(DeviceType::kCPU) #ifdef WITH_CUDA REGISTER_TUPLE_IDENTITY_KERNEL(DeviceType::kCUDA) #endif } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/two_stage_reduce_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/user/kernels/two_stage_reduce_kernel_util.h" #include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { namespace user_op { template class BinaryFunc, DeviceType device_type, typename T> class ReduceDeviceStageKernel final : public OpKernel { public: ReduceDeviceStageKernel() = default; ~ReduceDeviceStageKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* count = ctx->Tensor4ArgNameAndIndex("count", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); T* reduce_tmp_buf = tmp_buffer->mut_dptr(); int32_t* mask_tmp_buf = tmp_buffer->mut_dptr(); const size_t tmp_bytes = GetCudaAlignedSize(in->shape_view().elem_cnt() * std::max(sizeof(T), sizeof(int32_t))); int32_t* reduce_sum_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + tmp_bytes); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(out->shape_view(), out->mut_dptr()), XpuVarNdarray(in->shape_view(), in->dptr()), XpuVarNdarray(in->shape_view(), reduce_tmp_buf)); auto bcast_eq = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kEqual, in->data_type(), DataType::kBool, in->shape_view().NumAxes()); CHECK(bcast_eq); bcast_eq->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(), in->dptr(), out->shape_view().NumAxes(), out->shape_view().ptr(), out->dptr(), mask->mut_dptr()); auto cast = ep::primitive::NewPrimitive( ctx->device_type(), DataType::kInt8, DataType::kInt32); CHECK(cast); cast->Launch(ctx->stream(), mask->dptr(), mask_tmp_buf, mask->shape_view().elem_cnt()); NdarrayUtil::ReduceSum( ctx->stream(), XpuVarNdarray(count->shape_view(), count->mut_dptr()), XpuVarNdarray(mask->shape_view(), mask_tmp_buf), XpuVarNdarray(mask->shape_view(), reduce_sum_tmp_buf)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenDeviceStageInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const size_t tmp_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * std::max(sizeof(T), sizeof(int32_t))); const size_t reduce_sum_tmp_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(int32_t)); return tmp_bytes + reduce_sum_tmp_bytes; }; } #define REGISTER_REDUCE_DEVICE_STAGE_KERNEL(op_name, binary_func, device, dtype_pair) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn(GenDeviceStageInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_DEVICE_STAGE_KERNEL, ("reduce_max_device_stage"), (BinaryFuncMax), DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_DEVICE_STAGE_KERNEL, ("reduce_min_device_stage"), (BinaryFuncMin), DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ) template class ReduceDeviceStageGradKernel final : public OpKernel { public: ReduceDeviceStageGradKernel() = default; ~ReduceDeviceStageGradKernel() = default; private: void 
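// Backward of the device stage: out_diff is divided element-wise by `count` (the
// number of positions that reached the reduced value), broadcast back to the input
// shape, and finally masked so that only positions flagged in `mask` receive a
// share of the gradient.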
Compute(KernelComputeContext* ctx) const override { const user_op::Tensor* out_diff = ctx->Tensor4ArgNameAndIndex("out_diff", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); const user_op::Tensor* count = ctx->Tensor4ArgNameAndIndex("count", 0); user_op::Tensor* in_diff = ctx->Tensor4ArgNameAndIndex("in_diff", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); T* tmp_buf_ptr = tmp_buffer->mut_dptr(); const size_t tmp_bytes = GetCudaAlignedSize(out_diff->shape_view().elem_cnt() * sizeof(T)); T* broadcasted_tmp_buf_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + tmp_bytes); TwoStageReduceKernelUtil::Divide( ctx->stream(), out_diff->shape_view().elem_cnt(), out_diff->dptr(), count->dptr(), tmp_buf_ptr); NdarrayUtil::BroadcastTo( ctx->stream(), XpuVarNdarray(in_diff->shape_view(), broadcasted_tmp_buf_ptr), XpuVarNdarray(out_diff->shape_view(), tmp_buf_ptr)); TwoStageReduceKernelUtil::Mask( ctx->stream(), in_diff->shape_view().elem_cnt(), broadcasted_tmp_buf_ptr, mask->dptr(), in_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenDeviceStageGradInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const Shape& out_diff_shape = ctx->InputShape("out_diff", 0); const Shape& in_diff_shape = ctx->OutputShape("in_diff", 0); const size_t tmp_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(T)); const size_t broadcasted_tmp_bytes = GetCudaAlignedSize(in_diff_shape.elem_cnt() * sizeof(T)); return tmp_bytes + broadcasted_tmp_bytes; }; } #define REGISTER_REDUCE_DEVICE_STAGE_GRAD_KERNEL(op_name, device, dtype_pair) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("in_diff", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn(GenDeviceStageGradInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_DEVICE_STAGE_GRAD_KERNEL, ("reduce_max_device_stage_grad"), DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_DEVICE_STAGE_GRAD_KERNEL, ("reduce_min_device_stage_grad"), DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ) template class BinaryFunc, DeviceType device_type, typename T> class ReduceGlobalStageKernel final : public OpKernel { public: ReduceGlobalStageKernel() = default; ~ReduceGlobalStageKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto& axis = ctx->Attr>("axis"); const Shape& reduced_shape = CreateReducedShape(in->shape_view(), {axis.begin(), axis.end()}); NdarrayReduce::Reduce( ctx->stream(), XpuVarNdarray(reduced_shape, out->mut_dptr()), XpuVarNdarray(in->shape_view(), in->dptr()), XpuVarNdarray(in->shape_view(), tmp_buffer->mut_dptr())); auto bcast_eq = ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::BinaryOp::kEqual, in->data_type(), DataType::kBool, in->shape_view().NumAxes()); CHECK(bcast_eq); bcast_eq->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(), in->dptr(), reduced_shape.NumAxes(), reduced_shape.dim_vec().data(), out->dptr(), mask->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } 
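// Together, the *_device_stage kernels above and this *_global_stage kernel form a
// two-stage distributed max/min reduction: each device first reduces locally and
// emits (out, mask, count); this kernel then reduces those per-device results along
// `axis` and recomputes the equality mask against the globally reduced value, so
// the corresponding grad kernels can route gradients only to the winning elements.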
}; #define REGISTER_REDUCE_GLOBAL_STAGE_KERNEL(op_name, binary_func, device, dtype_pair) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("in", 0); \ return in_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair)); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_GLOBAL_STAGE_KERNEL, ("reduce_max_global_stage"), (BinaryFuncMax), DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_GLOBAL_STAGE_KERNEL, ("reduce_min_global_stage"), (BinaryFuncMin), DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ) template class ReduceGlobalStageGradKernel final : public OpKernel { public: ReduceGlobalStageGradKernel() = default; ~ReduceGlobalStageGradKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const user_op::Tensor* out_diff = ctx->Tensor4ArgNameAndIndex("out_diff", 0); const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); const user_op::Tensor* device_count = ctx->Tensor4ArgNameAndIndex("device_count", 0); user_op::Tensor* in_diff = ctx->Tensor4ArgNameAndIndex("in_diff", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); int32_t* device_count_with_mask = tmp_buffer->mut_dptr(); const size_t device_count_with_mask_bytes = GetCudaAlignedSize(device_count->shape_view().elem_cnt() * sizeof(int32_t)); int32_t* global_count = reinterpret_cast(tmp_buffer->mut_dptr() + device_count_with_mask_bytes); const size_t global_count_bytes = GetCudaAlignedSize(out_diff->shape_view().elem_cnt() * sizeof(int32_t)); int32_t* reduce_sum_tmp_buf = reinterpret_cast( tmp_buffer->mut_dptr() + device_count_with_mask_bytes + global_count_bytes); const size_t reduce_sum_tmp_bytes = GetCudaAlignedSize(device_count->shape_view().elem_cnt() * sizeof(int32_t)); T* divided_buf_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + device_count_with_mask_bytes + global_count_bytes + reduce_sum_tmp_bytes); const size_t divided_buf_bytes = GetCudaAlignedSize(out_diff->shape_view().elem_cnt() * sizeof(T)); T* broadcasted_divided_buf_ptr = reinterpret_cast(tmp_buffer->mut_dptr() + device_count_with_mask_bytes + global_count_bytes + reduce_sum_tmp_bytes + divided_buf_bytes); TwoStageReduceKernelUtil::Mask( ctx->stream(), device_count->shape_view().elem_cnt(), device_count->dptr(), mask->dptr(), device_count_with_mask); const auto& axis = ctx->Attr>("axis"); const Shape& reduced_shape = CreateReducedShape(device_count->shape_view(), {axis.begin(), axis.end()}); NdarrayUtil::ReduceSum( ctx->stream(), XpuVarNdarray(reduced_shape, global_count), XpuVarNdarray(device_count->shape_view(), device_count_with_mask), XpuVarNdarray(device_count->shape_view(), reduce_sum_tmp_buf)); TwoStageReduceKernelUtil::Divide( ctx->stream(), out_diff->shape_view().elem_cnt(), out_diff->dptr(), global_count, divided_buf_ptr); NdarrayUtil::BroadcastTo( ctx->stream(), XpuVarNdarray(in_diff->shape_view(), broadcasted_divided_buf_ptr), XpuVarNdarray(out_diff->shape_view(), divided_buf_ptr)); TwoStageReduceKernelUtil::Scale( ctx->stream(), in_diff->shape_view().elem_cnt(), broadcasted_divided_buf_ptr, device_count_with_mask, in_diff->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenGlobalStageGradInferTmpSizeFn() { return 
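// The lambda below mirrors the tmp_buffer layout consumed by
// ReduceGlobalStageGradKernel::Compute: five cuda-aligned regions packed back to
// back (device_count_with_mask, global_count, the reduce-sum scratch, the divided
// buffer and the broadcasted divided buffer).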
[](user_op::InferContext* ctx) { const Shape& device_count_shape = ctx->InputShape("device_count", 0); const Shape& out_diff_shape = ctx->InputShape("out_diff", 0); const Shape& in_diff_shape = ctx->OutputShape("in_diff", 0); const size_t device_count_with_mask_bytes = GetCudaAlignedSize(device_count_shape.elem_cnt() * sizeof(int32_t)); const size_t global_count_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(int32_t)); const size_t reduce_sum_tmp_bytes = GetCudaAlignedSize(device_count_shape.elem_cnt() * sizeof(int32_t)); const size_t divided_buf_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(T)); const size_t broadcasted_divided_buf_bytes = GetCudaAlignedSize(in_diff_shape.elem_cnt() * sizeof(T)); const size_t total_bytes = device_count_with_mask_bytes + global_count_bytes + reduce_sum_tmp_bytes + divided_buf_bytes + broadcasted_divided_buf_bytes; return total_bytes; }; } #define REGISTER_REDUCE_GLOBAL_STAGE_GRAD_KERNEL(op_name, device, dtype_pair) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("in_diff", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn(GenGlobalStageGradInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_GLOBAL_STAGE_GRAD_KERNEL, ("reduce_max_global_stage_grad"), DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_GLOBAL_STAGE_GRAD_KERNEL, ("reduce_min_global_stage_grad"), DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/two_stage_reduce_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/two_stage_reduce_kernel_util.h" #include "oneflow/core/common/data_type_seq.h" namespace oneflow { template struct TwoStageReduceKernelUtil { static void Divide(ep::Stream* stream, const int64_t n, const T* x, const K* count, T* y) { FOR_RANGE(int64_t, i, 0, n) { y[i] = x[i] / count[i]; } } static void Mask(ep::Stream* stream, const int64_t n, const T* x, const K* mask, T* y) { FOR_RANGE(int64_t, i, 0, n) { y[i] = static_cast(mask[i]) * x[i]; } } static void Scale(ep::Stream* stream, const int64_t n, const T* x, const K* scale, T* y) { FOR_RANGE(int64_t, i, 0, n) { y[i] = x[i] * static_cast(scale[i]); } } }; #define INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CPU(data_type_pair, index_type_pair) \ template struct TwoStageReduceKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ); #undef INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/two_stage_reduce_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/two_stage_reduce_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template __global__ void DivideGpu(const int64_t n, const T* x, const K* count, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] / count[i]; } } template __global__ void MaskGpu(const int64_t n, const T* x, const K* mask, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { y[i] = static_cast(mask[i]) * x[i]; } } template __global__ void ScaleGpu(const int64_t n, const T* x, const K* scale, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] * static_cast(scale[i]); } } } // namespace template struct TwoStageReduceKernelUtil { static void Divide(ep::Stream* stream, const int64_t n, const T* x, const K* count, T* y) { DivideGpu<<As()->cuda_stream()>>>(n, x, count, y); } static void Mask(ep::Stream* stream, const int64_t n, const T* x, const K* mask, T* y) { MaskGpu<<As()->cuda_stream()>>>(n, x, mask, y); } static void Scale(ep::Stream* stream, const int64_t n, const T* x, const K* scale, T* y) { ScaleGpu<<As()->cuda_stream()>>>(n, x, scale, y); } }; #define INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CUDA(data_type_pair, index_type_pair) \ template struct TwoStageReduceKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ); #undef INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/two_stage_reduce_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_TWO_STAGE_REDUCE_UTIL_H_ #define ONEFLOW_USER_KERNELS_TWO_STAGE_REDUCE_UTIL_H_ #include "oneflow/core/ep/include/stream.h" namespace oneflow { template struct TwoStageReduceKernelUtil { static void Divide(ep::Stream* stream, const int64_t n, const T* x, const K* count, T* y); static void Mask(ep::Stream* stream, const int64_t n, const T* x, const K* mask, T* y); static void Scale(ep::Stream* stream, const int64_t n, const T* x, const K* scale, T* y); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_TWO_STAGE_REDUCE_UTIL_H_ ================================================ FILE: oneflow/user/kernels/unfold_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/operator/operator_util.h" #include "oneflow/user/kernels/unfold_kernel_util.h" namespace oneflow { namespace user_op { namespace { // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template class UnfoldOpKernelState : public OpKernelState { public: using ParamType = UnfoldParams; UnfoldOpKernelState(const ShapeView& input_shape, const std::vector& kernel_size, const std::vector& padding, const std::vector& stride, const std::vector& dilation) : params_(input_shape.At(0), input_shape.At(ParamType::kInputChannelDim), input_shape.ptr() + SDIM, kernel_size.data(), padding.data(), stride.data(), dilation.data()) {} const ParamType& params() const { return params_; } private: ParamType params_; }; template std::shared_ptr> CreateUnfoldOpKernelState( const ShapeView& input_shape, const std::vector& kernel_size, const std::vector& padding, const std::vector& stride, const std::vector& dilation) { std::shared_ptr> state( new UnfoldOpKernelState(input_shape, kernel_size, padding, stride, dilation)); return state; } template class UnfoldKernel final : public OpKernel { public: UnfoldKernel() = default; ~UnfoldKernel() = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* input = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* output = ctx->Tensor4ArgNameAndIndex("y", 0); const std::vector kernel_size = ctx->Attr>("kernel_size"); const std::vector padding = ctx->Attr>("padding"); const std::vector stride = ctx->Attr>("strides"); const std::vector dilation = ctx->Attr>("dilation_rate"); const auto& state_ptr = CreateUnfoldOpKernelState( input->shape_view(), kernel_size, padding, stride, dilation); const UnfoldParams params = state_ptr->params(); UnfoldKernelUtil::Forward( ctx->stream(), ¶ms, input->dptr(), output->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; } // namespace // Currently support 4-D tensor and NCHW format #define REGISTER_UNFOLD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("unfold") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == device) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_UNFOLD_KERNEL(DeviceType::kCPU, float) REGISTER_UNFOLD_KERNEL(DeviceType::kCPU, double) #ifdef WITH_CUDA REGISTER_UNFOLD_KERNEL(DeviceType::kCUDA, float) REGISTER_UNFOLD_KERNEL(DeviceType::kCUDA, double) #endif // WITH_CUDA } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unfold_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/unfold_kernel_util.h" namespace oneflow { namespace user_op { // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template struct UnfoldKernelUtil { using ParamType = UnfoldParams; static void Forward(ep::Stream* stream, const UnfoldParams* raw_params, const T* input_ptr, T* output_ptr) { for (INDEX_T out_offset = 0; out_offset < raw_params->out_elem_cnt; ++out_offset) { using ParamType = UnfoldParams; INDEX_T in_index[ParamType::kInputNDim] = {0}; INDEX_T out_index[ParamType::kOutputNDim] = {0}; raw_params->out_index_helper.OffsetToNdIndex(out_offset, out_index); if (!UnfoldIndexTransform(*raw_params, out_index, in_index)) { INDEX_T in_offset = raw_params->in_index_helper.NdIndexToOffset(in_index); output_ptr[out_offset] = input_ptr[in_offset]; } else { output_ptr[out_offset] = static_cast(kUnfoldPaddingValue); } } } }; INSTANTIATE_UNFOLD_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCPU) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unfold_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/user/kernels/unfold_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace user_op { namespace { constexpr int kBlockSize = cuda::elementwise::kBlockSize; int GetNumBlocks(int64_t elem_cnt) { int num_blocks = 0; OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks)); return num_blocks; } // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template __global__ void CudaUnfoldForward(UnfoldParams params, const T* in, T* out) { CUDA_1D_KERNEL_LOOP_T(INDEX_T, out_offset, params.out_elem_cnt) { using ParamType = UnfoldParams; INDEX_T in_index[ParamType::kInputNDim] = {0}; INDEX_T out_index[ParamType::kOutputNDim] = {0}; params.out_index_helper.OffsetToNdIndex(out_offset, out_index); if (!UnfoldIndexTransform(params, out_index, in_index)) { INDEX_T in_offset = params.in_index_helper.NdIndexToOffset(in_index); out[out_offset] = in[in_offset]; } else { out[out_offset] = static_cast(kUnfoldPaddingValue); } } } } // namespace template struct UnfoldKernelUtil { using ParamType = UnfoldParams; static void Forward(ep::Stream* stream, const UnfoldParams* params, const T* input_ptr, T* output_ptr) { CudaUnfoldForward <<out_elem_cnt), kBlockSize, 0, stream->As()->cuda_stream()>>>(*params, input_ptr, output_ptr); } }; INSTANTIATE_UNFOLD_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCUDA) } // namespace user_op } // namespace oneflow #endif // WITH_CUDA ================================================ FILE: oneflow/user/kernels/unfold_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_UNFOLD_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_UNFOLD_KERNEL_UTIL_H_ #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/core/ndarray/xpu_util.h" namespace oneflow { namespace user_op { constexpr int kUnfoldPaddingValue = 0; // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template struct UnfoldParams { static constexpr int kInputNDim = NDIM + 2; static constexpr int kOutputNDim = NDIM * 2 + 2; static constexpr int kInputChannelDim = (2 - SDIM) * NDIM + 1; static constexpr int kOutputChannelDim = (2 - SDIM) * NDIM * 2 + 1; static_assert(kInputChannelDim < kInputNDim, ""); static_assert(kOutputChannelDim < kOutputNDim, ""); UnfoldParams(const int64_t batch_size, const int64_t channels, const int64_t* spatial_dims, const int32_t* kernel_size, const int32_t* padding, const int32_t* stride, const int32_t* dilation); INDEX_T in_elem_cnt; INDEX_T out_elem_cnt; INDEX_T dims[NDIM]; int padding[NDIM]; int stride[NDIM]; int dilation[NDIM]; NdIndexOffsetHelper in_index_helper; NdIndexOffsetHelper out_index_helper; }; // NDIM range: (1, 2, 3) // SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first template UnfoldParams::UnfoldParams(const int64_t batch_size, const int64_t channels, const int64_t* spatial_dims, const int32_t* kernel_size, const int32_t* padding, const int32_t* stride, const int32_t* dilation) : in_index_helper(0), out_index_helper(0) { INDEX_T input_dims[kInputNDim] = {0}; INDEX_T output_dims[kOutputNDim] = {0}; in_elem_cnt = batch_size * channels; out_elem_cnt = batch_size * channels; input_dims[0] = batch_size; output_dims[0] = batch_size; input_dims[kInputChannelDim] = channels; output_dims[kOutputChannelDim] = channels; for (int d = 0; d < NDIM; ++d) { this->in_elem_cnt *= spatial_dims[d]; this->dims[d] = spatial_dims[d]; this->padding[d] = padding[d]; this->stride[d] = stride[d]; this->dilation[d] = dilation[d]; input_dims[SDIM + d] = spatial_dims[d]; output_dims[SDIM + d] = kernel_size[d]; output_dims[SDIM + NDIM + d] = (spatial_dims[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1; out_elem_cnt *= output_dims[SDIM + d] * output_dims[SDIM + NDIM + d]; } in_index_helper = NdIndexOffsetHelper(input_dims); out_index_helper = NdIndexOffsetHelper(output_dims); } // index_a format: (N, C, di, hi, wi, db, hb, wb) or (N, di, hi, wi, db, hb, wb, C) // index_b format: (N, C, D, H, W) or (N, D, H, W, C) // return: true indicates out-of-bound, otherwise in-bound template OF_DEVICE_FUNC bool UnfoldIndexTransform(const UnfoldParams& params, const INDEX_T* index_a, INDEX_T* index_b) { // batch dim index transform index_b[0] = index_a[0]; // channel dim index transform using ParamType = UnfoldParams; index_b[ParamType::kInputChannelDim] = index_a[ParamType::kOutputChannelDim]; // spatial dim index 
transform #ifdef __CUDA_ARCH__ #pragma unroll #endif // D,H,W spatial dim index transform for (int64_t d = 0; d < NDIM; ++d) { INDEX_T idx = index_a[SDIM + NDIM + d] * params.stride[d] + index_a[SDIM + d] * params.dilation[d] - params.padding[d]; if (idx < 0 || idx >= params.dims[d]) return true; index_b[SDIM + d] = idx; } return false; } template struct UnfoldKernelUtil { static void Forward(ep::Stream* stream, const UnfoldParams* params, const T* input_ptr, T* output_ptr); }; #define SPATIAL_NDIM_SEQ OF_PP_MAKE_TUPLE_SEQ(1) OF_PP_MAKE_TUPLE_SEQ(2) OF_PP_MAKE_TUPLE_SEQ(3) #define SPATIAL_DIM_SEQ OF_PP_MAKE_TUPLE_SEQ(1) OF_PP_MAKE_TUPLE_SEQ(2) #define INSTANTIATE_UNFOLD_KERNEL_UTIL(device, dtype, itype, ndim, sdim) \ template struct UnfoldKernelUtil; #define INSTANTIATE_UNFOLD_KERNEL_UTIL_WITH_TYPE_PAIR(device, dtype_pair, itype_pair, ndim, sdim) \ INSTANTIATE_UNFOLD_KERNEL_UTIL(device, OF_PP_PAIR_FIRST(dtype_pair), \ OF_PP_PAIR_FIRST(itype_pair), ndim, sdim) #define INSTANTIATE_UNFOLD_KERNEL_UTIL_FOR_DEVICE(device) \ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNFOLD_KERNEL_UTIL_WITH_TYPE_PAIR, (device), \ FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, SPATIAL_NDIM_SEQ, \ SPATIAL_DIM_SEQ) } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_UNFOLD_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/unfold_tensor_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/operator/operator_util.h" #include "oneflow/user/kernels/unfold_tensor_kernel_utils.h" namespace oneflow { template class UnfoldTensorKernel final : public user_op::OpKernel { public: UnfoldTensorKernel() = default; ~UnfoldTensorKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("y", 0); const ShapeView& in_shape = in->shape_view(); std::vector out_shape; out_shape.resize(out->shape_view().NumAxes()); for (int i = 0; i < out->shape_view().NumAxes(); ++i) { out_shape[i] = out->shape_view().At(i); } const int32_t in_dims = in_shape.NumAxes(); const int32_t out_dims = out_shape.size(); const int32_t dimension = ctx->Attr("dimension"); const int32_t step = ctx->Attr("step"); std::vector in_stride(in_dims, 1); for (int32_t i = in_dims - 2; i >= 0; --i) { in_stride[i] = in_shape.At(i + 1) * in_stride.at(i + 1); } std::vector out_stride(in_dims + 1); out_stride[in_dims] = in_dims == 0 ? 
1 : in_stride[dimension]; for (int d = 0; d < in_dims; ++d) { if (d == dimension) { out_stride[d] = step * in_stride[d]; } else { out_stride[d] = in_stride[d]; } } const T* in_ptr = in->dptr(); T* out_ptr = out->mut_dptr(); const int32_t out_size = out->shape_view().elem_cnt(); for (int32_t i = 0; i < out_size; ++i) { int offset = Offset(i, out_stride.data(), out_shape.data(), out_dims - 1); out_ptr[i] = in_ptr[offset]; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UNFOLD_TENSOR_KERNEL(dtype) \ REGISTER_USER_KERNEL("unfold_tensor") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); REGISTER_UNFOLD_TENSOR_KERNEL(float) REGISTER_UNFOLD_TENSOR_KERNEL(double) REGISTER_UNFOLD_TENSOR_KERNEL(int64_t) REGISTER_UNFOLD_TENSOR_KERNEL(int32_t) template class UnfoldTensorGradKernel final : public user_op::OpKernel { public: UnfoldTensorGradKernel() = default; ~UnfoldTensorGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dout = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* din = ctx->Tensor4ArgNameAndIndex("dx", 0); const ShapeView& in_shape = in->shape_view(); const int32_t in_dims = in_shape.NumAxes(); std::vector din_stride(in_dims, 1); for (int32_t i = in_dims - 2; i >= 0; --i) { din_stride[i] = in_shape.At(i + 1) * din_stride.at(i + 1); } std::vector dout_shape; dout_shape.resize(dout->shape_view().NumAxes()); for (int i = 0; i < dout->shape_view().NumAxes(); ++i) { dout_shape[i] = dout->shape_view().At(i); } const int32_t dout_dims = dout_shape.size(); const int32_t dimension = ctx->Attr("dimension"); const int32_t step = ctx->Attr("step"); std::vector dout_stride(in_dims + 1); dout_stride[in_dims] = in_dims == 0 ? 1 : din_stride[dimension]; for (int d = 0; d < in_dims; ++d) { if (d == dimension) { dout_stride[d] = step * din_stride[d]; } else { dout_stride[d] = din_stride[d]; } } const T* dout_ptr = dout->dptr(); T* din_ptr = din->mut_dptr(); std::fill(din_ptr, din_ptr + din->shape_view().elem_cnt(), static_cast(0)); const int32_t dout_size = dout->shape_view().elem_cnt(); for (int32_t i = 0; i < dout_size; ++i) { int offset = Offset(i, dout_stride.data(), dout_shape.data(), dout_dims - 1); din_ptr[offset] += dout_ptr[i]; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("unfold_tensor_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x", 0) == GetDataType::value)); REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(float) REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(double) REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(int64_t) REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(int32_t) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unfold_tensor_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
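The unfold_tensor grad kernel above zero-fills dx with std::fill and then scatter-adds every dy element back to the input position it was read from, so input elements shared by overlapping windows receive several gradient contributions. A minimal 1-D illustration of that accumulation, written as plain standalone C++ rather than repository code:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int32_t in_len = 7, size = 3, step = 2;
  const int32_t num_windows = (in_len - size) / step + 1;  // 3 windows
  // dy has one entry per (window, column); all-ones makes dx count contributions.
  std::vector<float> dy(num_windows * size, 1.0f);
  std::vector<float> dx(in_len, 0.0f);  // zero-fill, as the kernel does with std::fill
  for (int32_t i = 0; i < num_windows; ++i) {
    for (int32_t j = 0; j < size; ++j) { dx[i * step + j] += dy[i * size + j]; }
  }
  // dx = 1 1 2 1 2 1 1 : positions 2 and 4 sit in two overlapping windows.
  for (float v : dx) { std::cout << v << ' '; }
  std::cout << '\n';
  return 0;
}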
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/unfold_tensor_kernel_utils.h" namespace oneflow { namespace { const int32_t NDIMS = 16; struct STRIDES { int32_t val[NDIMS]; }; template __global__ void UnfoldTensorCudaKernel(const T* in_ptr, const STRIDES out_stride, const STRIDES out_shape, const int32_t out_dims, const int32_t elements, T* out_ptr) { int32_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int32_t step = gridDim.x * blockDim.x; while (gid < elements) { int32_t offset = Offset(gid, out_stride.val, out_shape.val, out_dims - 1); out_ptr[gid] = in_ptr[offset]; gid += step; } } template __global__ void UnfoldTensorGradCudaKernel(const T* dout_ptr, const STRIDES dout_stride, const STRIDES dout_shape, const int32_t dout_dims, const int32_t elements, T* din_ptr) { int32_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int32_t step = gridDim.x * blockDim.x; while (gid < elements) { int32_t offset = Offset(gid, dout_stride.val, dout_shape.val, dout_dims - 1); cuda::atomic::Add(&din_ptr[offset], dout_ptr[gid]); gid += step; } } template __global__ void InitPtr(const int32_t elements, T* ptr) { int32_t gid = (blockDim.x * blockIdx.x) + threadIdx.x; int32_t step = gridDim.x * blockDim.x; while (gid < elements) { ptr[gid] = static_cast(0); gid += step; } } template struct GpuUnfoldTensorFunctor final { void operator()(ep::Stream* stream, const T* in_ptr, const STRIDES out_stride, const STRIDES out_shape, const int32_t out_dims, const int32_t elements, T* out_ptr) { RUN_CUDA_KERNEL((UnfoldTensorCudaKernel), stream, elements, in_ptr, out_stride, out_shape, out_dims, elements, out_ptr); } }; template struct GpuUnfoldTensorGradFunctor final { void operator()(ep::Stream* stream, const T* dout_ptr, const STRIDES dout_stride, const STRIDES dout_shape, const int32_t dout_dims, const int32_t dout_elements, const int32_t din_elements, T* din_ptr) { RUN_CUDA_KERNEL((InitPtr), stream, din_elements, din_elements, din_ptr); RUN_CUDA_KERNEL((UnfoldTensorGradCudaKernel), stream, dout_elements, dout_ptr, dout_stride, dout_shape, dout_dims, dout_elements, din_ptr); } }; } // namespace template class GpuUnfoldTensorKernel final : public user_op::OpKernel { public: GpuUnfoldTensorKernel() = default; ~GpuUnfoldTensorKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("y", 0); const ShapeView& in_shape = in->shape_view(); std::vector out_shape; out_shape.resize(out->shape_view().NumAxes()); for (int i = 0; i < out->shape_view().NumAxes(); ++i) { out_shape[i] = out->shape_view().At(i); } const int32_t in_dims = in_shape.NumAxes(); const int32_t out_dims = out_shape.size(); const int32_t dimension = ctx->Attr("dimension"); const int32_t step = ctx->Attr("step"); std::vector in_stride(in_dims, 1); for (int32_t i = in_dims - 2; i >= 0; --i) { in_stride[i] = in_shape.At(i + 1) * in_stride.at(i + 1); } std::vector out_stride(in_dims + 1); out_stride[in_dims] = in_dims == 0 ? 
1 : in_stride[dimension]; for (int d = 0; d < in_dims; ++d) { if (d == dimension) { out_stride[d] = step * in_stride[d]; } else { out_stride[d] = in_stride[d]; } } const T* in_ptr = in->dptr(); T* out_ptr = out->mut_dptr(); const int32_t out_size = out->shape_view().elem_cnt(); STRIDES out_stride_cuda; for (int i = 0; i < out_dims; ++i) { out_stride_cuda.val[i] = out_stride[i]; } STRIDES out_shape_cuda; for (int i = 0; i < out_dims; ++i) { out_shape_cuda.val[i] = out_shape[i]; } GpuUnfoldTensorFunctor()(ctx->stream(), in_ptr, out_stride_cuda, out_shape_cuda, out_dims, out_size, out_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UNFOLD_TENSOR_KERNEL(dtype) \ REGISTER_USER_KERNEL("unfold_tensor") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)) REGISTER_UNFOLD_TENSOR_KERNEL(float); REGISTER_UNFOLD_TENSOR_KERNEL(double); REGISTER_UNFOLD_TENSOR_KERNEL(int32_t); REGISTER_UNFOLD_TENSOR_KERNEL(int64_t); template class GpuUnfoldTensorGradKernel final : public user_op::OpKernel { public: GpuUnfoldTensorGradKernel() = default; ~GpuUnfoldTensorGradKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dout = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* din = ctx->Tensor4ArgNameAndIndex("dx", 0); const ShapeView& in_shape = in->shape_view(); const int32_t in_dims = in_shape.NumAxes(); std::vector din_stride(in_dims, 1); for (int32_t i = in_dims - 2; i >= 0; --i) { din_stride[i] = in_shape.At(i + 1) * din_stride.at(i + 1); } std::vector dout_shape; dout_shape.resize(dout->shape_view().NumAxes()); for (int i = 0; i < dout->shape_view().NumAxes(); ++i) { dout_shape[i] = dout->shape_view().At(i); } const int32_t dout_dims = dout_shape.size(); const int32_t dimension = ctx->Attr("dimension"); const int32_t step = ctx->Attr("step"); std::vector dout_stride(in_dims + 1); dout_stride[in_dims] = in_dims == 0 ? 1 : din_stride[dimension]; for (int d = 0; d < in_dims; ++d) { if (d == dimension) { dout_stride[d] = step * din_stride[d]; } else { dout_stride[d] = din_stride[d]; } } STRIDES dout_stride_cuda; for (int i = 0; i < dout_dims; ++i) { dout_stride_cuda.val[i] = dout_stride[i]; } STRIDES dout_shape_cuda; for (int i = 0; i < dout_dims; ++i) { dout_shape_cuda.val[i] = dout_shape[i]; } const T* dout_ptr = dout->dptr(); T* din_ptr = din->mut_dptr(); const int32_t dout_size = dout->shape_view().elem_cnt(); const int32_t din_size = din->shape_view().elem_cnt(); GpuUnfoldTensorGradFunctor()(ctx->stream(), dout_ptr, dout_stride_cuda, dout_shape_cuda, dout_dims, dout_size, din_size, din_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("unfold_tensor_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x", 0) == GetDataType::value)) REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(float); REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(double); REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(int32_t); REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(int64_t); } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unfold_tensor_kernel_utils.h ================================================ /* Copyright 2020 The OneFlow Authors. 
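Both the CPU and CUDA unfold_tensor kernels above rely on the Offset() helper that this header defines just below: it peels a flat output index into per-dimension coordinates against out_shape (innermost dimension first) and re-encodes those coordinates with out_stride, which describes the possibly overlapping unfolded view of the input. A standalone sketch with a worked 1-D example, for illustration only:

#include <cassert>
#include <cstdint>

// Decode a flat output index digit-by-digit against out_shape and re-encode it with
// out_stride; n is the index of the last (innermost) dimension.
int32_t OffsetSketch(int32_t in_offset, const int32_t* out_stride, const int32_t* out_shape,
                     int32_t n) {
  int32_t out_offset = 0;
  for (int32_t dim = n; dim >= 0; --dim) {
    const int32_t remaining = in_offset % out_shape[dim];
    out_offset += remaining * out_stride[dim];
    in_offset /= out_shape[dim];
  }
  return out_offset;
}

int main() {
  // A length-7 input unfolded with size 3, step 2 gives output shape (3, 3) and output
  // strides (step * 1, 1) = (2, 1). Flat output index 4 is (window 1, column 1), which
  // must read input offset 1 * 2 + 1 = 3.
  const int32_t out_shape[2] = {3, 3};
  const int32_t out_stride[2] = {2, 1};
  assert(OffsetSketch(4, out_stride, out_shape, 1) == 3);
  return 0;
}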
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_UNFOLD_TENSOR_KERNEL_UTILS_H_ #define ONEFLOW_UNFOLD_TENSOR_KERNEL_UTILS_H_ #include "oneflow/core/framework/framework.h" namespace oneflow { OF_DEVICE_FUNC int32_t Offset(int32_t in_offset, const int32_t* out_stride, const int32_t* out_shape, const int32_t n) { int32_t remaining = 0; int32_t out_offset = 0; #ifdef __CUDA_ARCH__ #pragma unroll #endif for (int32_t dim = n; dim >= 0; --dim) { remaining = in_offset % out_shape[dim]; out_offset += remaining * out_stride[dim]; in_offset = in_offset / out_shape[dim]; } return out_offset; } } // namespace oneflow #endif // ONEFLOW_UNFOLD_TENSOR_KERNEL_UTILS_H_ ================================================ FILE: oneflow/user/kernels/unique_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/unique_kernel_util.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { template class UniqueKernel final : public user_op::OpKernel { public: UniqueKernel() = default; ~UniqueKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* idx = ctx->Tensor4ArgNameAndIndex("idx", 0); user_op::Tensor* num_unique = ctx->Tensor4ArgNameAndIndex("num_unique", 0); user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const bool& sorted = ctx->Attr("sorted"); void* tmp_ptr = tmp ? tmp->mut_dptr() : nullptr; int64_t tmp_size = tmp ? 
tmp->shape_view().elem_cnt() * GetSizeOfDataType(tmp->data_type()) : 0; UniqueKernelUtil::Unique( ctx->stream(), x->shape_view().elem_cnt(), x->dptr(), num_unique->mut_dptr(), y->mut_dptr(), idx->mut_dptr(), tmp_ptr, tmp_size, sorted); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("x", 0); int64_t workspace_size_in_bytes = 0; UniqueKernelUtil::GetUniqueWorkspaceSizeInBytes( nullptr, x.shape().elem_cnt(), &workspace_size_in_bytes); return workspace_size_in_bytes; }; } #define REGISTER_UNIQUE_KERNEL(device_type_v, data_type_pair, indices_type_pair) \ REGISTER_USER_KERNEL("unique") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("idx", 0) == OF_PP_PAIR_SECOND(indices_type_pair))) \ .SetInferTmpSizeFn(GenInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNIQUE_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unique_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
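GenInferTmpSizeFn above ties the kernel's temporary buffer to whatever size the device-specific utility reports for the input length, and Compute later receives that allocation as the optional tmp_buffer tensor. The same two-phase contract, reduced to self-contained C++; FakeUniqueUtil is a made-up stand-in, not a OneFlow API:

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Made-up utility that needs n * sizeof(int64_t) scratch bytes: it sorts a copy of the
// input inside the workspace and counts unique adjacent values.
struct FakeUniqueUtil {
  static void GetWorkspaceSizeInBytes(int64_t n, int64_t* bytes) {
    *bytes = n * static_cast<int64_t>(sizeof(int64_t));
  }
  static int64_t CountUnique(const int64_t* in, int64_t n, void* workspace, int64_t /*bytes*/) {
    int64_t* buf = static_cast<int64_t*>(workspace);
    std::memcpy(buf, in, n * sizeof(int64_t));
    std::sort(buf, buf + n);
    return std::unique(buf, buf + n) - buf;
  }
};

int main() {
  const std::vector<int64_t> x = {3, 1, 3, 2};
  int64_t ws_bytes = 0;
  FakeUniqueUtil::GetWorkspaceSizeInBytes(static_cast<int64_t>(x.size()), &ws_bytes);  // "InferTmpSizeFn"
  std::vector<char> tmp(ws_bytes);                                                     // "tmp_buffer"
  std::cout << FakeUniqueUtil::CountUnique(x.data(), x.size(), tmp.data(), ws_bytes) << '\n';  // 3
  return 0;
}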
*/ #include "oneflow/user/kernels/unique_kernel_util.h" namespace oneflow { template struct UniqueKernelUtil { static void Unique(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out, void* workspace, int64_t workspace_size_in_bytes, bool sorted) { UniqueKernelUtil::UniqueWithCounts( stream, n, in, num_unique, unique_out, idx_out, nullptr, workspace, workspace_size_in_bytes, sorted); } static void UniqueWithCounts(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out, IDX* count, void* workspace, int64_t workspace_size_in_bytes, bool sorted) { std::vector sorted_idx(n); std::iota(sorted_idx.begin(), sorted_idx.end(), 0); if (sorted) { std::sort(sorted_idx.begin(), sorted_idx.end(), [&in](size_t a, size_t b) { return in[a] < in[b]; }); } HashMap map; for (int64_t i : sorted_idx) { KEY in_i = in[i]; auto it = map.find(in_i); if (it == map.end()) { IDX idx = map.size(); if (count != nullptr) { count[idx] = 1; } idx_out[i] = idx; unique_out[idx] = in_i; map[in_i] = idx; } else { IDX idx = it->second; if (count != nullptr) { count[idx] += 1; } idx_out[i] = idx; } } *num_unique = map.size(); } static void GetUniqueWorkspaceSizeInBytes(ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes) { *workspace_size_in_bytes = 1; } static void GetUniqueWithCountsWorkspaceSizeInBytes(ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes) { *workspace_size_in_bytes = 1; } }; #define INSTANTIATE_UNIQUE_KERNEL_UTIL_CPU(key_type_pair, idx_type_pair) \ template struct UniqueKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNIQUE_KERNEL_UTIL_CPU, ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_UNIQUE_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unique_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/kernels/unique_kernel_util.h" #include "oneflow/core/cuda/unique.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { constexpr cuda::unique::Flag kUniqueFlag = cuda::unique::kOutputInverseIndices; constexpr cuda::unique::Flag kUniqueWithCountsFlag = cuda::unique::kOutputInverseIndices | cuda::unique::kOutputCounts; } // namespace template struct UniqueKernelUtil { static void Unique(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out, void* workspace, int64_t workspace_size_in_bytes, bool sorted); static void UniqueWithCounts(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out, IDX* count, void* workspace, int64_t workspace_size_in_bytes, bool sorted); static void GetUniqueWorkspaceSizeInBytes(ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes); static void GetUniqueWithCountsWorkspaceSizeInBytes(ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes); }; template void UniqueKernelUtil::Unique( ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out, void* workspace, int64_t workspace_size_in_bytes, bool sorted /* not used, always return sorted output in CUDA,it`s the same as torch.unique*/) { OF_CUDA_CHECK((cuda::unique::Launch(kUniqueFlag, n, in, unique_out, num_unique, idx_out, nullptr, workspace, workspace_size_in_bytes, stream->As()->cuda_stream()))); } template void UniqueKernelUtil::UniqueWithCounts( ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out, IDX* count, void* workspace, int64_t workspace_size_in_bytes, bool sorted /* not used, always return sorted output in CUDA,it`s the same as torch.unique*/) { OF_CUDA_CHECK((cuda::unique::Launch( kUniqueWithCountsFlag, n, in, unique_out, num_unique, idx_out, count, workspace, workspace_size_in_bytes, stream->As()->cuda_stream()))); } template void UniqueKernelUtil::GetUniqueWorkspaceSizeInBytes( ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes) { size_t ws = 0; OF_CUDA_CHECK((cuda::unique::GetWorkspaceSize(kUniqueFlag, n, &ws))); *workspace_size_in_bytes = static_cast(ws); } template void UniqueKernelUtil::GetUniqueWithCountsWorkspaceSizeInBytes( ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes) { size_t ws = 0; OF_CUDA_CHECK((cuda::unique::GetWorkspaceSize(kUniqueWithCountsFlag, n, &ws))); *workspace_size_in_bytes = static_cast(ws); } #define INSTANTIATE_UNIQUE_KERNEL_UTIL_CUDA(key_type_pair, idx_type_pair) \ template struct UniqueKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNIQUE_KERNEL_UTIL_CUDA, ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); #undef INSTANTIATE_UNIQUE_KERNEL_UTIL_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unique_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_KERNELS_UNIQUE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_UNIQUE_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { template struct UniqueKernelUtil { static void Unique(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out, void* workspace, int64_t workspace_size_in_bytes, bool sorted); static void UniqueWithCounts(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out, IDX* count, void* workspace, int64_t workspace_size_in_bytes, bool sorted); static void GetUniqueWorkspaceSizeInBytes(ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes); static void GetUniqueWithCountsWorkspaceSizeInBytes(ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes); }; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_UNIQUE_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/unique_with_counts_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/unique_kernel_util.h" #include "oneflow/core/framework/framework.h" namespace oneflow { namespace { template class UniqueWithCountsKernel final : public user_op::OpKernel { public: UniqueWithCountsKernel() = default; ~UniqueWithCountsKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* idx = ctx->Tensor4ArgNameAndIndex("idx", 0); user_op::Tensor* count = ctx->Tensor4ArgNameAndIndex("count", 0); user_op::Tensor* num_unique = ctx->Tensor4ArgNameAndIndex("num_unique", 0); user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const bool& sorted = ctx->Attr("sorted"); void* tmp_ptr = tmp ? tmp->mut_dptr() : nullptr; int64_t tmp_size = tmp ? 
tmp->shape_view().elem_cnt() * GetSizeOfDataType(tmp->data_type()) : 0; UniqueKernelUtil::UniqueWithCounts( ctx->stream(), x->shape_view().elem_cnt(), x->dptr(), num_unique->mut_dptr(), y->mut_dptr(), idx->mut_dptr(), count->mut_dptr(), tmp_ptr, tmp_size, sorted); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template user_op::InferTmpSizeFn GenInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("x", 0); int64_t workspace_size_in_bytes; UniqueKernelUtil::GetUniqueWithCountsWorkspaceSizeInBytes( nullptr, x.shape().elem_cnt(), &workspace_size_in_bytes); return workspace_size_in_bytes; }; } #define REGISTER_UNIQUE_WITH_COUNTS_KERNEL(device_type_v, data_type_pair, indices_type_pair) \ REGISTER_USER_KERNEL("unique_with_counts") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device_type_v) \ && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ && (user_op::HobDataType("idx", 0) == OF_PP_PAIR_SECOND(indices_type_pair))) \ .SetInferTmpSizeFn(GenInferTmpSizeFn()); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNIQUE_WITH_COUNTS_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unpack_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { namespace { template class UnpackKernel final : public user_op::OpKernel { public: UnpackKernel() = default; ~UnpackKernel() override = default; std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { return std::make_shared>>( std::make_pair(0, 0)); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK_GT(in->shape_view().NumAxes(), 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); CHECK_EQ(in->shape_view().NumAxes(), out->shape_view().NumAxes()); const auto unpack_num = ctx->Attr("unpack_num"); CHECK_EQ(out->shape_view().At(0) * unpack_num, in->shape_view().At(0)); for (int64_t i = 1; i < in->shape_view().NumAxes(); ++i) { CHECK_EQ(out->shape_view().At(i), in->shape_view().At(i)); } const int64_t copy_size = out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type()); auto* state_wrapper = dynamic_cast>*>(state); CHECK_NOTNULL(state_wrapper); const size_t index = state_wrapper->Get().first; CHECK_EQ(state_wrapper->Get().second, unpack_num); Memcpy(ctx->stream(), out->mut_dptr(), in->dptr() + index * copy_size, copy_size); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UNPACK_KERNEL(device) \ REGISTER_USER_KERNEL("unpack").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device)); OF_PP_FOR_EACH_TUPLE(REGISTER_UNPACK_KERNEL, DEVICE_TYPE_SEQ) #if defined(WITH_MLU) REGISTER_UNPACK_KERNEL(DeviceType::kMLU) #endif #undef REGISTER_UNPACK_KERNEL } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unsorted_batch_segment_sum_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/batch_gather_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace user_op { namespace { Shape GetFlatShape(const ShapeView& shape, const int64_t axis) { CHECK_GT(shape.NumAxes(), 0); CHECK_GE(axis, 0); CHECK_LT(axis, shape.NumAxes()); return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)}); } } // namespace template class UnsortedBatchSegmentSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: UnsortedBatchSegmentSumKernel() = default; ~UnsortedBatchSegmentSumKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* data = ctx->Tensor4ArgNameAndIndex("data", 0); const user_op::Tensor* segment_ids = ctx->Tensor4ArgNameAndIndex("segment_ids", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t axis = segment_ids->shape_view().NumAxes() - 1; const Shape& flat_data_shape = GetFlatShape(data->shape_view(), axis); Memset(ctx->stream(), out->mut_dptr(), 0, out->shape_view().elem_cnt() * sizeof(T)); BatchGatherKernelUtilImpl::Backward( ctx->stream(), data->dptr(), segment_ids->dptr(), flat_data_shape, out->shape_view().At(axis), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_UNSORTED_BATCH_SEGMENT_SUM_KERNEL(device, out_dtype, segment_ids_dtype) \ REGISTER_USER_KERNEL("unsorted_batch_segment_sum") \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("segment_ids", 0) == OF_PP_PAIR_SECOND(segment_ids_dtype)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_dtype))); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_BATCH_SEGMENT_SUM_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unsorted_segment_sum_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/unsorted_segment_sum_kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/ep/include/primitive/cast.h" #ifdef WITH_CUDA #include #endif namespace oneflow { namespace user_op { namespace { void CheckNdSbp(const Shape& hierarchy, int64_t sum_axis, const NdSbp& segment_ids_nd_sbp, const NdSbp& data_nd_sbp, const NdSbp& out_nd_sbp) { CHECK_EQ(hierarchy.NumAxes(), segment_ids_nd_sbp.sbp_parallel_size()); CHECK_EQ(hierarchy.NumAxes(), data_nd_sbp.sbp_parallel_size()); CHECK_EQ(hierarchy.NumAxes(), out_nd_sbp.sbp_parallel_size()); if (hierarchy.elem_cnt() == 1) { return; } FOR_RANGE(int64_t, i, 0, hierarchy.NumAxes()) { const auto& out_sbp = out_nd_sbp.sbp_parallel(i); if (out_sbp.has_split_parallel() && out_sbp.split_parallel().axis() == sum_axis) { CHECK(segment_ids_nd_sbp.sbp_parallel(i).has_broadcast_parallel()); CHECK(data_nd_sbp.sbp_parallel(i).has_broadcast_parallel()); } } } class UnsortedSegmentSumOpKernelCache final : public user_op::OpKernelCache { public: UnsortedSegmentSumOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} ~UnsortedSegmentSumOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } private: const int64_t lower_; const int64_t upper_; }; std::shared_ptr CreateUnsortedSegmentSumOpKernelCache( user_op::KernelCacheContext* ctx) { if (ctx->parallel_ctx().parallel_num() > 1) { const auto axis = ctx->Attr("axis"); const NdSbp& out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); CheckNdSbp(hierarchy, axis, ctx->NdSbp4ArgNameAndIndex("segment_ids", 0), ctx->NdSbp4ArgNameAndIndex("data", 0), out_nd_sbp); const TensorDesc* out_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("out", 0); TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, out_nd_sbp, out_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); return std::make_shared(view.At(axis).begin(), view.At(axis).end()); } else { return nullptr; } } } // namespace template class UnsortedSegmentSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: UnsortedSegmentSumKernel() = default; ~UnsortedSegmentSumKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateUnsortedSegmentSumOpKernelCache(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* data = ctx->Tensor4ArgNameAndIndex("data", 0); const user_op::Tensor* segment_ids = ctx->Tensor4ArgNameAndIndex("segment_ids", 0); int64_t axis = ctx->Attr("axis"); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t outer_dim_size = out->shape_view().Count(0, axis); int64_t num_segments = out->shape_view().At(axis); int64_t inner_dim_size = out->shape_view().Count(axis + 1); int64_t num_segment_ids = segment_ids->shape_view().elem_cnt(); Memset(ctx->stream(), out->mut_dptr(), 0, out->shape_view().elem_cnt() * sizeof(T)); int64_t offset = 0; if (cache != nullptr) { auto* sum_cache = dynamic_cast(cache); CHECK_NOTNULL(sum_cache); CHECK_EQ(out->shape_view().At(axis), sum_cache->upper() - sum_cache->lower()); offset = sum_cache->lower(); } if (num_segment_ids != 0) { UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( ctx->stream(), 
segment_ids->dptr(), data->dptr(), num_segment_ids, num_segments, outer_dim_size, inner_dim_size, offset, out->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_UNSORTED_SEGMENT_SUM_KERNEL(device, out_type, segment_ids_type, kernel_type) \ REGISTER_USER_KERNEL(kernel_type) \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == device) \ && (user_op::HobDataType("segment_ids", 0) == OF_PP_PAIR_SECOND(segment_ids_type)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_type))); #define REGISTER_UNSORTED_SEGMENT_SUM_KERNEL_CASE(device_type, out_type, segment_ids_type) \ REGISTER_UNSORTED_SEGMENT_SUM_KERNEL(device_type, out_type, segment_ids_type, \ ("unsorted_segment_sum")) #define REGISTER_UNSORTED_SEGMENT_SUM_LIKE_KERNEL_CASE(device_type, out_type, segment_ids_type) \ REGISTER_UNSORTED_SEGMENT_SUM_KERNEL(device_type, out_type, segment_ids_type, \ ("unsorted_segment_sum_like")) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_SEGMENT_SUM_KERNEL_CASE, DEVICE_TYPE_SEQ, UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_SEGMENT_SUM_LIKE_KERNEL_CASE, DEVICE_TYPE_SEQ, UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #ifdef WITH_CUDA template class UnsortedSegmentSumHalfKernel final : public user_op::OpKernel { public: UnsortedSegmentSumHalfKernel() = default; ~UnsortedSegmentSumHalfKernel() override = default; std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { return CreateUnsortedSegmentSumOpKernelCache(ctx); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, const user_op::OpKernelCache* cache) const override { const user_op::Tensor* data = ctx->Tensor4ArgNameAndIndex("data", 0); const user_op::Tensor* segment_ids = ctx->Tensor4ArgNameAndIndex("segment_ids", 0); int64_t axis = ctx->Attr("axis"); user_op::Tensor* tmp_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t outer_dim_size = out->shape_view().Count(0, axis); int64_t num_segments = out->shape_view().At(axis); int64_t inner_dim_size = out->shape_view().Count(axis + 1); int64_t num_segment_ids = segment_ids->shape_view().elem_cnt(); Memset(ctx->stream(), tmp_buf->mut_dptr(), 0, out->shape_view().elem_cnt() * sizeof(float)); int64_t offset = 0; if (cache != nullptr) { auto* sum_cache = dynamic_cast(cache); CHECK_NOTNULL(sum_cache); CHECK_EQ(out->shape_view().At(axis), sum_cache->upper() - sum_cache->lower()); offset = sum_cache->lower(); } UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( ctx->stream(), segment_ids->dptr(), data->dptr(), num_segment_ids, num_segments, outer_dim_size, inner_dim_size, offset, tmp_buf->mut_dptr()); auto f2h = ep::primitive::NewPrimitive( ctx->device_type(), DataType::kFloat, out->data_type()); CHECK(f2h); f2h->Launch(ctx->stream(), tmp_buf->dptr(), out->mut_dptr(), out->shape_view().elem_cnt()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; #define REGISTER_UNSORTED_SEGMENT_SUM_HALF_HALF_KERNEL(out_type, segment_ids_type, kernel_type) \ REGISTER_USER_KERNEL(kernel_type) \ .SetCreateFn>() \ .SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("segment_ids", 0) == OF_PP_PAIR_SECOND(segment_ids_type)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_type))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& out_shape = 
ctx->OutputShape("out", 0); \ return GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float)); \ }); #define REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE(out_type, segment_ids_type) \ REGISTER_UNSORTED_SEGMENT_SUM_HALF_HALF_KERNEL(out_type, segment_ids_type, \ ("unsorted_segment_sum")) \ REGISTER_UNSORTED_SEGMENT_SUM_HALF_HALF_KERNEL(out_type, segment_ids_type, \ ("unsorted_segment_sum_like")) OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE, FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) #if CUDA_VERSION >= 11000 OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE, OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16), INDEX_DATA_TYPE_SEQ) #endif #undef REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE #endif // WITH_CUDA } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unsorted_segment_sum_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/unsorted_segment_sum_kernel_util.h" namespace oneflow { template struct UnsortedSegmentSumKernelUtil final { static void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const T* data, int64_t num_segment_ids, int64_t num_segments, int64_t outer_dim_size, int64_t inner_dim_size, int64_t segment_id_offset, T* out); }; template void UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( ep::Stream* stream, const K* segment_ids, const T* data, int64_t num_segment_ids, int64_t num_segments, int64_t outer_dim_size, int64_t inner_dim_size, int64_t segment_id_offset, T* out) { FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) { FOR_RANGE(int64_t, i, 0, num_segment_ids) { CHECK_GE(segment_ids[i], 0); const int64_t idx = segment_ids[i] - segment_id_offset; T* to = out + outer_idx * num_segments * inner_dim_size + idx * inner_dim_size; if (idx >= 0 && idx < num_segments) { const T* from = data + outer_idx * num_segment_ids * inner_dim_size + i * inner_dim_size; std::transform(from, from + inner_dim_size, to, to, std::plus()); } } } } #define INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU(in_type_pair, index_type_pair) \ template struct UnsortedSegmentSumKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU, UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ); #undef INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unsorted_segment_sum_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/user/kernels/unsorted_segment_sum_kernel_util.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include namespace oneflow { namespace { template __device__ __forceinline__ bool IsZero(T v) { return v == 0; } template<> __device__ __forceinline__ bool IsZero(half v) { return v == static_cast(0); } #if CUDA_VERSION >= 11000 template<> __device__ __forceinline__ bool IsZero(nv_bfloat16 v) { return v == __float2bfloat16(0); } #endif template<> __device__ __forceinline__ bool IsZero(half2 v) { return v.x == static_cast(0) && v.y == static_cast(0); } template __global__ void UnsortedSegmentSumGpu(const IDX data_elem_cnt, const NdIndexOffsetHelper in_helper, const NdIndexOffsetHelper out_helper, const U* data, const K* segment_ids, const IDX num_segments, const IDX segment_id_offset, T* out) { CUDA_1D_KERNEL_LOOP_T(IDX, i, data_elem_cnt) { const U val = data[i]; if (!IsZero(val)) { IDX outer_idx, segment_id_idx, inner_idx; in_helper.OffsetToNdIndex(i, outer_idx, segment_id_idx, inner_idx); const K origin_idx = segment_ids[segment_id_idx]; assert(origin_idx >= 0); const IDX idx = origin_idx - segment_id_offset; if (idx >= 0 && idx < num_segments) { const int64_t out_offset = out_helper.NdIndexToOffset(outer_idx, idx, inner_idx); if (out_offset >= 0) { cuda::atomic::Add(out + out_offset, static_cast(val)); } } } } } template __global__ void UnsortedSegmentColSumGpu(const IDX data_elem_cnt, const NdIndexOffsetHelper in_helper, const NdIndexOffsetHelper out_helper, const U* data, const K* segment_ids, const IDX num_segments, const IDX segment_id_offset, T* out) { CUDA_1D_KERNEL_LOOP_T(IDX, i, data_elem_cnt) { const U val = data[i]; if (!IsZero(val)) { IDX outer_idx, segment_id_idx; in_helper.OffsetToNdIndex(i, outer_idx, segment_id_idx); const K origin_idx = segment_ids[segment_id_idx]; assert(origin_idx >= 0); const IDX idx = origin_idx - segment_id_offset; if (idx >= 0 && idx < num_segments) { const int64_t out_offset = out_helper.NdIndexToOffset(outer_idx, idx); if (out_offset >= 0) { cuda::atomic::Add(out + out_offset, static_cast(val)); } } } } } template __global__ void UnsortedSegmentRowSumGpu(const IDX data_elem_cnt, const NdIndexOffsetHelper in_helper, const NdIndexOffsetHelper out_helper, const U* data, const K* segment_ids, const IDX num_segments, const IDX segment_id_offset, T* out) { CUDA_1D_KERNEL_LOOP_T(IDX, i, data_elem_cnt) { const U val = data[i]; if (!IsZero(val)) { IDX segment_id_idx, inner_idx; in_helper.OffsetToNdIndex(i, segment_id_idx, inner_idx); const K origin_idx = segment_ids[segment_id_idx]; assert(origin_idx >= 0); const IDX idx = origin_idx - segment_id_offset; if (idx >= 0 && idx < num_segments) { const int64_t out_offset = out_helper.NdIndexToOffset(idx, inner_idx); if (out_offset >= 0) { cuda::atomic::Add(out + out_offset, static_cast(val)); } } } } } template void UnsortedSegmentSumUtil(ep::Stream* stream, const K* segment_ids, const U* data, IDX num_segment_ids, IDX num_segments, IDX outer_dim_size, IDX inner_dim_size, IDX 
segment_id_offset, T* out) { const IDX data_elem_cnt = num_segment_ids * outer_dim_size * inner_dim_size; if (inner_dim_size == 1) { NdIndexOffsetHelper in_helper(outer_dim_size, num_segment_ids); NdIndexOffsetHelper out_helper(outer_dim_size, num_segments); UnsortedSegmentColSumGpu <<As()->cuda_stream()>>>(data_elem_cnt, in_helper, out_helper, data, segment_ids, num_segments, segment_id_offset, out); } else if (outer_dim_size == 1) { NdIndexOffsetHelper in_helper(num_segment_ids, inner_dim_size); NdIndexOffsetHelper out_helper(num_segments, inner_dim_size); UnsortedSegmentRowSumGpu <<As()->cuda_stream()>>>(data_elem_cnt, in_helper, out_helper, data, segment_ids, num_segments, segment_id_offset, out); } else { NdIndexOffsetHelper in_helper(outer_dim_size, num_segment_ids, inner_dim_size); NdIndexOffsetHelper out_helper(outer_dim_size, num_segments, inner_dim_size); UnsortedSegmentSumGpu <<As()->cuda_stream()>>>(data_elem_cnt, in_helper, out_helper, data, segment_ids, num_segments, segment_id_offset, out); } } template void DispatchDataType(ep::Stream* stream, const K* segment_ids, const U* data, int64_t num_segment_ids, int64_t num_segments, int64_t outer_dim_size, int64_t inner_dim_size, int64_t segment_id_offset, T* out) { auto* cuda_stream = stream->As(); if (std::is_same::value && std::is_same::value && cuda_stream->device_properties().major >= 6 && reinterpret_cast(data) % sizeof(half2) == 0 && reinterpret_cast(out) % sizeof(half2) == 0 && inner_dim_size % 2 == 0) { UnsortedSegmentSumUtil( stream, segment_ids, reinterpret_cast(data), num_segment_ids, num_segments, outer_dim_size, inner_dim_size / 2, segment_id_offset, reinterpret_cast(out)); } else { UnsortedSegmentSumUtil(stream, segment_ids, data, num_segment_ids, num_segments, outer_dim_size, inner_dim_size, segment_id_offset, out); } } } // namespace template struct UnsortedSegmentSumKernelUtil final { static void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const U* data, int64_t num_segment_ids, int64_t num_segments, int64_t outer_dim_size, int64_t inner_dim_size, int64_t segment_id_offset, T* out) { const int64_t data_elem_cnt = num_segment_ids * outer_dim_size * inner_dim_size; const int64_t out_elem_cnt = outer_dim_size * num_segments * inner_dim_size; if (std::max(data_elem_cnt, out_elem_cnt) < GetMaxVal() / 2) { DispatchDataType(stream, segment_ids, data, num_segment_ids, num_segments, outer_dim_size, inner_dim_size, segment_id_offset, out); } else { DispatchDataType(stream, segment_ids, data, num_segment_ids, num_segments, outer_dim_size, inner_dim_size, segment_id_offset, out); } } }; template struct UnsortedSegmentSumKernelUtil final { static void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const float16* data, int64_t num_segment_ids, int64_t num_segments, int64_t outer_dim_size, int64_t inner_dim_size, int64_t segment_id_offset, float* out) { UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( stream, segment_ids, reinterpret_cast(data), num_segment_ids, num_segments, outer_dim_size, inner_dim_size, segment_id_offset, out); } }; #define INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CUDA(in_type_pair, index_type_pair) \ template struct UnsortedSegmentSumKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CUDA, UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ); #undef INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CUDA #define INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_HALF_CUDA(in_type_pair, index_type_pair, \ out_type_pair) \ template struct 
UnsortedSegmentSumKernelUtil; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_HALF_CUDA, OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat), UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ, FLOAT16_DATA_TYPE_SEQ); #if CUDA_VERSION >= 11000 OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_HALF_CUDA, OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat), UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ, OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)); #endif #undef INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_HALF_CUDA template struct UnsortedSegmentSumKernelUtil; template struct UnsortedSegmentSumKernelUtil; template struct UnsortedSegmentSumKernelUtil; } // namespace oneflow ================================================ FILE: oneflow/user/kernels/unsorted_segment_sum_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_KERNELS_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_ #define ONEFLOW_CORE_KERNELS_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { template struct UnsortedSegmentSumKernelUtil final { static void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const U* data, int64_t num_segment_ids, int64_t num_segments, int64_t outer_dim_size, int64_t inner_dim_size, int64_t segment_id_offset, T* out); }; #define UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ \ FLOATING_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ \ INDEX_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) } // namespace oneflow #endif // ONEFLOW_CORE_KERNELS_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/upsample_bicubic_2d_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
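One detail of the CUDA utility above worth calling out: it picks the kernel's offset type at run time, running the int32_t instantiation whenever both the flattened data and output element counts fit comfortably within 32-bit range (the GetMaxVal() / 2 guard above; 32-bit index arithmetic is cheaper on the GPU) and falling back to int64_t otherwise; a separate check enables the half2-vectorized path only when both half pointers are half2-aligned and the inner dimension is even. A sketch of just the index-width decision, with std::numeric_limits standing in for GetMaxVal:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>

// Prefer 32-bit offsets when every flat index stays well below INT32_MAX.
bool CanUseInt32Index(int64_t data_elem_cnt, int64_t out_elem_cnt) {
  return std::max(data_elem_cnt, out_elem_cnt)
         < static_cast<int64_t>(std::numeric_limits<int32_t>::max()) / 2;
}

int main() {
  std::cout << CanUseInt32Index(1 << 20, 1 << 18) << '\n';           // 1: small tensors
  std::cout << CanUseInt32Index(int64_t{1} << 33, 1 << 18) << '\n';  // 0: needs 64-bit offsets
  return 0;
}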
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/user/kernels/upsample_kernel.h" namespace oneflow { template class UpsampleBicubic2dCPUKernel final : public user_op::OpKernel { public: UpsampleBicubic2dCPUKernel() = default; ~UpsampleBicubic2dCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const T* in_ptr = x_tensor->dptr(); T* out_ptr = y_tensor->mut_dptr(); const bool align_corners = ctx->Attr("align_corners"); const int nbatch = x_tensor->shape_view().At(0); const int channels = x_tensor->shape_view().At(1); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t in_width = x_tensor->shape_view().At(3); const int64_t out_height = y_tensor->shape_view().At(2); const int64_t out_width = y_tensor->shape_view().At(3); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { memcpy(out_ptr, in_ptr, sizeof(T) * nbatch * channels * in_height * in_width); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); for (int64_t output_y = 0; output_y < out_height; output_y++) { for (int64_t output_x = 0; output_x < out_width; output_x++) { const T* in = in_ptr; T* out = out_ptr; const T real_x = GetAreaPixel(scale_width, output_x, align_corners, /*cubic=*/true); int64_t input_x = std::floor(real_x); const T t_x = real_x - input_x; const T real_y = GetAreaPixel(scale_height, output_y, align_corners, /*cubic=*/true); int64_t input_y = std::floor(real_y); const T t_y = real_y - input_y; for (int64_t c = 0; c < channels * nbatch; c++) { T coefficients[4]; // Interpolate 4 times in the x direction for (int64_t i = 0; i < 4; i++) { coefficients[i] = cubic_interp1d(upsample_get_value_bounded(in, in_width, in_height, input_x - 1, input_y - 1 + i), upsample_get_value_bounded(in, in_width, in_height, input_x + 0, input_y - 1 + i), upsample_get_value_bounded(in, in_width, in_height, input_x + 1, input_y - 1 + i), upsample_get_value_bounded(in, in_width, in_height, input_x + 2, input_y - 1 + i), t_x); } // Interpolate in the y direction using x interpolations out[output_y * out_width + output_x] = cubic_interp1d( coefficients[0], coefficients[1], coefficients[2], coefficients[3], t_y); // Move to next channel in += in_width * in_height; out += out_width * out_height; } } } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleBicubic2dGradCPUKernel final : public user_op::OpKernel { public: UpsampleBicubic2dGradCPUKernel() = default; ~UpsampleBicubic2dGradCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); T* in_ptr = 
dx_tensor->mut_dptr(); const T* out_ptr = dy_tensor->dptr(); const bool align_corners = ctx->Attr("align_corners"); const int nbatch = dx_tensor->shape_view().At(0); int channels = dx_tensor->shape_view().At(1); channels = channels * nbatch; const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t in_width = dx_tensor->shape_view().At(3); const int64_t out_height = dy_tensor->shape_view().At(2); const int64_t out_width = dy_tensor->shape_view().At(3); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { memcpy(in_ptr, out_ptr, sizeof(T) * channels * in_height * in_width); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); for (int64_t output_y = 0; output_y < out_height; output_y++) { for (int64_t output_x = 0; output_x < out_width; output_x++) { T* in = in_ptr; const T* out = out_ptr; T real_x = GetAreaPixel(scale_width, output_x, align_corners, true); int64_t input_x = std::floor(real_x); T t_x = real_x - input_x; T real_y = GetAreaPixel(scale_height, output_y, align_corners, true); int64_t input_y = std::floor(real_y); T t_y = real_y - input_y; T x_coeffs[4]; T y_coeffs[4]; get_cubic_upsample_coefficients(x_coeffs, t_x); get_cubic_upsample_coefficients(y_coeffs, t_y); for (int64_t c = 0; c < channels; c++) { T out_value = out[output_y * out_width + output_x]; for (int64_t i = 0; i < 4; i++) { for (int64_t j = 0; j < 4; j++) { upsample_increment_value_bounded(in, in_width, in_height, input_x - 1 + i, input_y - 1 + j, out_value * y_coeffs[j] * x_coeffs[i]); } } in += in_width * in_height; out += out_width * out_height; } } } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_bicubic_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_bicubic_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(float) REGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_bicubic_2d_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
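// Illustrative sketch (not part of the original file): the 1-D cubic convolution the
// bicubic kernels above build on (see cubic_interp1d / get_cubic_upsample_coefficients in
// upsample_kernel.h), with the coefficient A = -0.75 used there. Function names here are
// ad hoc. w0..w3 weight the samples at x-1, x, x+1, x+2 for a fractional offset t in
// [0, 1); the four weights sum to 1, t == 0 returns p1 exactly and t == 1 returns p2.
static double CubicConv1(double x, double A) {
  return ((A + 2.0) * x - (A + 3.0)) * x * x + 1.0;  // |x| <= 1 branch of the kernel
}
static double CubicConv2(double x, double A) {
  return ((A * x - 5.0 * A) * x + 8.0 * A) * x - 4.0 * A;  // 1 < |x| < 2 branch
}
static double CubicInterp1DExample(double p0, double p1, double p2, double p3, double t) {
  const double A = -0.75;
  const double w0 = CubicConv2(t + 1.0, A);
  const double w1 = CubicConv1(t, A);
  const double w2 = CubicConv1(1.0 - t, A);
  const double w3 = CubicConv2(2.0 - t, A);
  return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
}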
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/upsample_kernel.h" #include "oneflow/core/kernel/kernel_util.cuh" namespace oneflow { namespace { template __device__ void upsample_increment_value_bounded_cuda(T* data, int64_t width, int64_t height, int64_t element, int64_t x, int64_t y, T value) { int64_t access_x = max(min(x, width - 1), static_cast(0)); int64_t access_y = max(min(y, height - 1), static_cast(0)); cuda::atomic::FastAdd(data, access_y * width + access_x, element, value); } template __global__ void UpsampleBicubic2dForward(const int64_t elem_cnt, const T* in_dptr, const int64_t nbatch, const int64_t channels, const int64_t in_height, const int64_t in_width, const int64_t out_height, const int64_t out_width, const float scale_height, const float scale_width, bool align_corners, T* out_dptr) { CUDA_1D_KERNEL_LOOP(idx, elem_cnt) { const int output_x = idx % out_width; const int output_y = idx / out_width; const T* in = in_dptr; T* out = out_dptr; const T real_x = GetAreaPixel(scale_width, output_x, align_corners, /*cubic=*/true); int64_t input_x = floor(1.0 * real_x); const T t_x = real_x - input_x; const T real_y = GetAreaPixel(scale_height, output_y, align_corners, /*cubic=*/true); int64_t input_y = floor(1.0 * real_y); const T t_y = real_y - input_y; for (int64_t c = 0; c < channels * nbatch; c++) { T coefficients[4]; // Interpolate 4 times in the x direction for (int64_t i = 0; i < 4; i++) { coefficients[i] = cubic_interp1d( upsample_get_value_bounded(in, in_width, in_height, input_x - 1, input_y - 1 + i), upsample_get_value_bounded(in, in_width, in_height, input_x + 0, input_y - 1 + i), upsample_get_value_bounded(in, in_width, in_height, input_x + 1, input_y - 1 + i), upsample_get_value_bounded(in, in_width, in_height, input_x + 2, input_y - 1 + i), t_x); } // Interpolate in the y direction using x interpolations out[output_y * out_width + output_x] = cubic_interp1d( coefficients[0], coefficients[1], coefficients[2], coefficients[3], t_y); // Move to next channel in += in_width * in_height; out += out_width * out_height; } } } template __global__ void UpsampleBicubic2dBackward(const int64_t elem_cnt, const T* dy_dptr, const int64_t nbatch, const int64_t channels, const int64_t in_height, const int64_t in_width, const int64_t out_height, const int64_t out_width, const float scale_height, const float scale_width, bool align_corners, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(idx, elem_cnt) { const int output_x = idx % out_width; const int output_y = idx / out_width; T* in = dx_dptr; const T* out = dy_dptr; T real_x = GetAreaPixel(scale_width, output_x, align_corners, true); int64_t input_x = floor(1.0 * real_x); T t_x = real_x - input_x; T real_y = GetAreaPixel(scale_height, output_y, align_corners, true); int64_t input_y = floor(1.0 * real_y); T t_y = real_y - input_y; T x_coeffs[4]; T y_coeffs[4]; get_cubic_upsample_coefficients(x_coeffs, t_x); get_cubic_upsample_coefficients(y_coeffs, t_y); for (int64_t c = 0; c < channels * nbatch; c++) { T out_value = out[output_y * out_width + output_x]; for (int64_t i = 0; i < 4; i++) { for (int64_t j = 0; j < 4; j++) { upsample_increment_value_bounded_cuda(in, in_width, in_height, elem_cnt, input_x - 1 + i, input_y - 1 + j, out_value * y_coeffs[j] * x_coeffs[i]); } } in += in_width * in_height; out += out_width * out_height; } } } } // namespace template class 
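// Illustrative sketch (not part of the original file): the border handling used by
// upsample_get_value_bounded and upsample_increment_value_bounded(_cuda) above.
// Out-of-range reads and gradient writes are clamped to the nearest valid pixel
// (replicate padding), which the 4x4 bicubic window needs near image edges.
// The helper name is ad hoc.
#include <cstdint>
static int64_t ClampIndexExample(int64_t v, int64_t size) {
  if (v < 0) { return 0; }
  if (v > size - 1) { return size - 1; }
  return v;
}
// A bounded read is then simply:
//   data[ClampIndexExample(y, height) * width + ClampIndexExample(x, width)]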
UpsampleBicubic2dGPUKernel final : public user_op::OpKernel { public: UpsampleBicubic2dGPUKernel() = default; ~UpsampleBicubic2dGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const T* in_ptr = x_tensor->dptr(); T* out_ptr = y_tensor->mut_dptr(); const bool align_corners = ctx->Attr("align_corners"); const int nbatch = x_tensor->shape_view().At(0); const int channels = x_tensor->shape_view().At(1); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t in_width = x_tensor->shape_view().At(3); const int64_t out_height = y_tensor->shape_view().At(2); const int64_t out_width = y_tensor->shape_view().At(3); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } const int64_t elem_cnt = out_height * out_width; if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type())); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); RUN_CUDA_KERNEL((UpsampleBicubic2dForward), ctx->stream(), elem_cnt, elem_cnt, x_tensor->dptr(), nbatch, channels, in_height, in_width, out_height, out_width, scale_height, scale_width, align_corners, y_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleBicubic2dGradGPUKernel final : public user_op::OpKernel { public: UpsampleBicubic2dGradGPUKernel() = default; ~UpsampleBicubic2dGradGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const bool align_corners = ctx->Attr("align_corners"); const int nbatch = dx_tensor->shape_view().At(0); const int channels = dx_tensor->shape_view().At(1); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t in_width = dx_tensor->shape_view().At(3); const int64_t out_height = dy_tensor->shape_view().At(2); const int64_t out_width = dy_tensor->shape_view().At(3); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } const int64_t elem_cnt = out_height * out_width; if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type())); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, 
width_scale); RUN_CUDA_KERNEL((UpsampleBicubic2dBackward), ctx->stream(), elem_cnt, elem_cnt, dy_tensor->dptr(), nbatch, channels, in_height, in_width, out_height, out_width, scale_height, scale_width, align_corners, dx_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPLE_BICUBIC_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_bicubic_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_bicubic_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPLE_BICUBIC_CUDA_KERNEL(float) REGISTER_UPSAMPLE_BICUBIC_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/user/kernels/upsample_kernel.h" namespace oneflow { namespace { template static void UpsampleBilinear2DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const int64_t in_width, const T scale_h, const T scale_w, const bool align_corners, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h, w; out_helper.OffsetToNdIndex(index, n, c, h, w); BilinearParam params; GetBilinearParam(align_corners, h, w, in_height, in_width, scale_h, scale_w, ¶ms); const int64_t top_offset = in_helper.NdIndexToOffset(n, c, params.top_h_index, 0); const int64_t bottom_offset = in_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0); const T top_left = in_dptr[top_offset + params.left_w_index]; const T top_right = in_dptr[top_offset + params.right_w_index]; const T bottom_left = in_dptr[bottom_offset + params.left_w_index]; const T bottom_right = in_dptr[bottom_offset + params.right_w_index]; out_dptr[index] = (1 - params.h_lerp) * ((1 - params.w_lerp) * top_left + params.w_lerp * top_right) + params.h_lerp * ((1 - params.w_lerp) * bottom_left + params.w_lerp * bottom_right); } } template static void UpsampleBilinearBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t dx_height, const int64_t dx_width, const T scale_h, const T scale_w, const bool align_corners, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h, w; dy_helper.OffsetToNdIndex(index, n, c, h, w); BilinearParam params; GetBilinearParam(align_corners, h, w, dx_height, dx_width, scale_h, scale_w, ¶ms); const int64_t top_offset = dx_helper.NdIndexToOffset(n, c, params.top_h_index, 0); 
const int64_t bottom_offset = dx_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0); const T dy = dy_dptr[index]; const T dbottom = params.h_lerp * dy; T* dx_dptr_bottom_offset = dx_dptr + bottom_offset; *(dx_dptr_bottom_offset + params.left_w_index) += static_cast((1 - params.w_lerp) * dbottom); *(dx_dptr_bottom_offset + params.right_w_index) += static_cast(params.w_lerp * dbottom); const T dtop = dy - dbottom; T* dx_dptr_top_offset = dx_dptr + top_offset; *(dx_dptr_top_offset + params.left_w_index) += static_cast((1 - params.w_lerp) * dtop); *(dx_dptr_top_offset + params.right_w_index) += static_cast(params.w_lerp * dtop); } } } // namespace template class UpsampleBilinear2DCPUKernel final : public user_op::OpKernel { public: UpsampleBilinear2DCPUKernel() = default; ~UpsampleBilinear2DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const bool align_corners = ctx->Attr("align_corners"); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2), x_tensor->shape_view().At(3)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2), y_tensor->shape_view().At(3)); const int64_t nbatch = x_tensor->shape_view().At(0); const int64_t channels = x_tensor->shape_view().At(1); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t in_width = x_tensor->shape_view().At(3); const int64_t out_height = y_tensor->shape_view().At(2); const int64_t out_width = y_tensor->shape_view().At(3); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { memcpy(y_tensor->mut_dptr(), x_tensor->dptr(), sizeof(T) * nbatch * channels * in_height * in_width); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); UpsampleBilinear2DForward(elem_cnt, x_tensor->dptr(), in_helper, out_helper, in_height, in_width, scale_height, scale_width, align_corners, y_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleBilinear2DGradCPUKernel final : public user_op::OpKernel { public: UpsampleBilinear2DGradCPUKernel() = default; ~UpsampleBilinear2DGradCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const bool align_corners = ctx->Attr("align_corners"); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper dy_helper( dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), 
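// Illustrative sketch (not part of the original file): the 2x2 blend that
// UpsampleBilinear2DForward performs once GetBilinearParam has produced the neighbour
// indices and lerp weights. Struct and function names here are ad hoc but mirror the
// BilinearParam members used above.
struct Bilinear2x2Example {
  double top_left, top_right, bottom_left, bottom_right;
};
static double BilinearBlendExample(const Bilinear2x2Example& v, double h_lerp, double w_lerp) {
  const double top = (1.0 - w_lerp) * v.top_left + w_lerp * v.top_right;
  const double bottom = (1.0 - w_lerp) * v.bottom_left + w_lerp * v.bottom_right;
  return (1.0 - h_lerp) * top + h_lerp * bottom;
}
// UpsampleBilinearBackward distributes dy back to the same four cells with the same
// weights, which is why dx is zero-filled (Memset) before accumulation.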
dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3)); NdIndexOffsetHelper dx_helper( dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3)); const int64_t nbatch = dx_tensor->shape_view().At(0); const int64_t channels = dx_tensor->shape_view().At(1); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t in_width = dx_tensor->shape_view().At(3); const int64_t out_height = dy_tensor->shape_view().At(2); const int64_t out_width = dy_tensor->shape_view().At(3); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { memcpy(dx_tensor->mut_dptr(), dy_tensor->dptr(), sizeof(T) * nbatch * channels * in_height * in_width); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); UpsampleBilinearBackward(elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, in_height, in_width, scale_height, scale_width, align_corners, dx_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_bilinear_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_bilinear_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(float) REGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_bilinear_2d_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/upsample_kernel.h" namespace oneflow { namespace { __device__ __forceinline__ void GetBilinearParamHalf(const bool align_corners, const int64_t h, const int64_t w, const int64_t in_height, const int64_t in_width, const double scale_h, const double scale_w, BilinearParam* params) { half h1r; if (align_corners) { h1r = static_cast(scale_h * static_cast(h)); } else { h1r = h1r = static_cast((static_cast(h) + 0.5f) * scale_h - 0.5f); h1r = h1r < static_cast(0.0) ? static_cast(0.0) : h1r; } const int64_t h1 = int(h1r); const int64_t h1p = (h1 < in_height - 1) ? 
1 : 0; half w1r; if (align_corners) { w1r = static_cast(scale_w * static_cast(w)); } else { w1r = static_cast((static_cast(w) + 0.5f) * scale_w - 0.5f); w1r = w1r < static_cast(0.0) ? static_cast(0.0) : w1r; } const int64_t w1 = int(w1r); const int64_t w1p = (w1 < in_width - 1) ? 1 : 0; params->top_h_index = h1; params->bottom_h_index = h1 + h1p; params->h_lerp = h1r - static_cast(h1 * 1.0); params->left_w_index = w1; params->right_w_index = w1 + w1p; params->w_lerp = w1r - static_cast(w1 * 1.0); } template __global__ void UpsampleBilinear2DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const int64_t in_width, const T scale_h, const T scale_w, const bool align_corners, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h, w; out_helper.OffsetToNdIndex(index, n, c, h, w); BilinearParam params; GetBilinearParam(align_corners, h, w, in_height, in_width, scale_h, scale_w, ¶ms); const int64_t top_offset = in_helper.NdIndexToOffset(n, c, params.top_h_index, 0); const int64_t bottom_offset = in_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0); const T top_left = in_dptr[top_offset + params.left_w_index]; const T top_right = in_dptr[top_offset + params.right_w_index]; const T bottom_left = in_dptr[bottom_offset + params.left_w_index]; const T bottom_right = in_dptr[bottom_offset + params.right_w_index]; out_dptr[index] = (1 - params.h_lerp) * ((1 - params.w_lerp) * top_left + params.w_lerp * top_right) + params.h_lerp * ((1 - params.w_lerp) * bottom_left + params.w_lerp * bottom_right); } } template<> __global__ void UpsampleBilinear2DForward(const int64_t elem_cnt, const half* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const int64_t in_width, const half scale_h, const half scale_w, const bool align_corners, half* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h, w; out_helper.OffsetToNdIndex(index, n, c, h, w); BilinearParam params; GetBilinearParamHalf(align_corners, h, w, in_height, in_width, scale_h, scale_w, ¶ms); const int64_t top_offset = in_helper.NdIndexToOffset(n, c, params.top_h_index, 0); const int64_t bottom_offset = in_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0); const half top_left = in_dptr[top_offset + params.left_w_index]; const half top_right = in_dptr[top_offset + params.right_w_index]; const half bottom_left = in_dptr[bottom_offset + params.left_w_index]; const half bottom_right = in_dptr[bottom_offset + params.right_w_index]; out_dptr[index] = (static_cast(1.0) - params.h_lerp) * ((static_cast(1.0) - params.w_lerp) * top_left + params.w_lerp * top_right) + params.h_lerp * ((static_cast(1.0) - params.w_lerp) * bottom_left + params.w_lerp * bottom_right); } } template __global__ void UpsampleBilinearBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t dx_height, const int64_t dx_width, const T scale_h, const T scale_w, const bool align_corners, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h, w; dy_helper.OffsetToNdIndex(index, n, c, h, w); BilinearParam params; GetBilinearParam(align_corners, h, w, dx_height, dx_width, scale_h, scale_w, ¶ms); const int64_t top_offset = dx_helper.NdIndexToOffset(n, c, params.top_h_index, 0); const int64_t bottom_offset = dx_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0); const T dy = dy_dptr[index]; const T dbottom = params.h_lerp * dy; T* 
dx_dptr_bottom_offset = dx_dptr + bottom_offset; cuda::atomic::FastAdd(dx_dptr_bottom_offset, params.left_w_index, elem_cnt, static_cast((1 - params.w_lerp) * dbottom)); cuda::atomic::FastAdd(dx_dptr_bottom_offset, params.right_w_index, elem_cnt, static_cast(params.w_lerp * dbottom)); const T dtop = dy - dbottom; T* dx_dptr_top_offset = dx_dptr + top_offset; cuda::atomic::FastAdd(dx_dptr_top_offset, params.left_w_index, elem_cnt, static_cast((1 - params.w_lerp) * dtop)); cuda::atomic::FastAdd(dx_dptr_top_offset, params.right_w_index, elem_cnt, static_cast(params.w_lerp * dtop)); } } template<> __global__ void UpsampleBilinearBackward(const int64_t elem_cnt, const half* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t dx_height, const int64_t dx_width, const half scale_h, const half scale_w, const bool align_corners, half* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h, w; dy_helper.OffsetToNdIndex(index, n, c, h, w); BilinearParam params; GetBilinearParamHalf(align_corners, h, w, dx_height, dx_width, scale_h, scale_w, ¶ms); const int64_t top_offset = dx_helper.NdIndexToOffset(n, c, params.top_h_index, 0); const int64_t bottom_offset = dx_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0); const half dy = dy_dptr[index]; const half dbottom = params.h_lerp * dy; half* dx_dptr_bottom_offset = dx_dptr + bottom_offset; cuda::atomic::FastAdd(dx_dptr_bottom_offset, params.left_w_index, elem_cnt, static_cast((static_cast(1.0) - params.w_lerp) * dbottom)); cuda::atomic::FastAdd(dx_dptr_bottom_offset, params.right_w_index, elem_cnt, static_cast(params.w_lerp * dbottom)); const half dtop = dy - dbottom; half* dx_dptr_top_offset = dx_dptr + top_offset; cuda::atomic::FastAdd(dx_dptr_top_offset, params.left_w_index, elem_cnt, static_cast((static_cast(1.0) - params.w_lerp) * dtop)); cuda::atomic::FastAdd(dx_dptr_top_offset, params.right_w_index, elem_cnt, static_cast(params.w_lerp * dtop)); } } } // namespace template class UpsampleBilinear2DGPUKernel final : public user_op::OpKernel { public: UpsampleBilinear2DGPUKernel() = default; ~UpsampleBilinear2DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const bool align_corners = ctx->Attr("align_corners"); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2), x_tensor->shape_view().At(3)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2), y_tensor->shape_view().At(3)); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t in_width = x_tensor->shape_view().At(3); const int64_t out_height = y_tensor->shape_view().At(2); const int64_t out_width = y_tensor->shape_view().At(3); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type())); 
} else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); RUN_CUDA_KERNEL((UpsampleBilinear2DForward), ctx->stream(), elem_cnt, elem_cnt, x_tensor->dptr(), in_helper, out_helper, in_height, in_width, scale_height, scale_width, align_corners, y_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleBilinear2DGradGPUKernel final : public user_op::OpKernel { public: UpsampleBilinear2DGradGPUKernel() = default; ~UpsampleBilinear2DGradGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const bool align_corners = ctx->Attr("align_corners"); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper dy_helper( dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3)); NdIndexOffsetHelper dx_helper( dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3)); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t in_width = dx_tensor->shape_view().At(3); const int64_t out_height = dy_tensor->shape_view().At(2); const int64_t out_width = dy_tensor->shape_view().At(3); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type())); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); RUN_CUDA_KERNEL((UpsampleBilinearBackward), ctx->stream(), elem_cnt, elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, in_height, in_width, scale_height, scale_width, align_corners, dx_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPLE_BILINEAR_2D_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_bilinear_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_bilinear_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPLE_BILINEAR_2D_CUDA_KERNEL(half) REGISTER_UPSAMPLE_BILINEAR_2D_CUDA_KERNEL(float) REGISTER_UPSAMPLE_BILINEAR_2D_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_kernel.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/nd_index_offset_helper.h" #include OF_DEVICE_FUNC double GetLinearInputIndex(const int64_t out_dim_idx, const double scale, bool align_corners) { if (align_corners) { return static_cast(scale * out_dim_idx); } else { double src_idx = scale * (out_dim_idx + 0.5) - 0.5; return static_cast(src_idx < 0 ? 0 : src_idx); } } OF_DEVICE_FUNC static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const double scale, const int64_t in_dim_size) { int64_t index = static_cast(floorf(out_dim_idx * scale)); index = index > in_dim_size - 1 ? in_dim_size - 1 : index; return index; } OF_DEVICE_FUNC double GetAreaPixelScale(const int64_t input_size, const int64_t output_size, bool align_corners, const double scale) { if (align_corners) { if (output_size > 1) { return static_cast(input_size - 1) / (output_size - 1); } else { return 0; } } else { return (scale > 0. ? 1.0 / scale : static_cast(input_size) / output_size); } } OF_DEVICE_FUNC double GetAreaPixel(const double scale, const int64_t dst_index, bool align_corners, bool cubic = false) { if (align_corners) { return scale * dst_index; } else { double src_idx = scale * (dst_index + 0.5) - 0.5; return (!cubic && src_idx < 0) ? static_cast(0) : src_idx; } } template struct BilinearParam { int64_t top_h_index; int64_t bottom_h_index; int64_t left_w_index; int64_t right_w_index; T w_lerp; T h_lerp; }; template OF_DEVICE_FUNC void GetBilinearParam(const bool align_corners, const int64_t h, const int64_t w, const int64_t in_height, const int64_t in_width, const double scale_h, const double scale_w, BilinearParam* params) { T h1r; if (align_corners) { h1r = scale_h * static_cast(h); } else { h1r = (static_cast(h) + 0.5f) * scale_h - 0.5f; h1r = h1r < 0 ? 0 : h1r; } const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; T w1r; if (align_corners) { w1r = scale_w * static_cast(w); } else { w1r = (static_cast(w) + 0.5f) * scale_w - 0.5f; w1r = w1r < 0 ? 0 : w1r; } const int64_t w1 = w1r; const int64_t w1p = (w1 < in_width - 1) ? 1 : 0; params->top_h_index = h1; params->bottom_h_index = h1 + h1p; params->h_lerp = h1r - h1; params->left_w_index = w1; params->right_w_index = w1 + w1p; params->w_lerp = w1r - w1; } template OF_DEVICE_FUNC void upsample_increment_value_bounded(T* data, int64_t width, int64_t height, int64_t x, int64_t y, T value) { int64_t access_x = std::max(std::min(x, width - 1), static_cast(0)); int64_t access_y = std::max(std::min(y, height - 1), static_cast(0)); data[access_y * width + access_x] += value; } template OF_DEVICE_FUNC T upsample_get_value_bounded(const T* data, const int64_t width, const int64_t height, const int64_t x, const int64_t y) { int64_t access_x = x; access_x = access_x > width - 1 ? width - 1 : access_x; access_x = access_x < 0 ? 0 : access_x; int64_t access_y = y; access_y = access_y > height - 1 ? height - 1 : access_y; access_y = access_y < 0 ? 
0 : access_y; return data[access_y * width + access_x]; } // Based on // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm template OF_DEVICE_FUNC T cubic_convolution1(const T x, const T A) { return ((A + static_cast(2.0)) * x - (A + static_cast(3.0))) * x * x + static_cast(1.0); } template OF_DEVICE_FUNC T cubic_convolution2(const T x, const T A) { return ((A * x - static_cast(5.0) * A) * x + static_cast(8.0) * A) * x - static_cast(4.0) * A; } template OF_DEVICE_FUNC void get_cubic_upsample_coefficients(T coeffs[4], const T t) { T A = -0.75; T x1 = t; coeffs[0] = cubic_convolution2(x1 + 1.0, A); coeffs[1] = cubic_convolution1(x1, A); // opposite coefficients T x2 = 1.0 - t; coeffs[2] = cubic_convolution1(x2, A); coeffs[3] = cubic_convolution2(x2 + 1.0, A); } template OF_DEVICE_FUNC T cubic_interp1d(const T x0, const T x1, const T x2, const T x3, const T t) { T coeffs[4]; get_cubic_upsample_coefficients(coeffs, t); return x0 * coeffs[0] * 1.0 + x1 * coeffs[1] * 1.0 + x2 * coeffs[2] * 1.0 + x3 * coeffs[3] * 1.0; } ================================================ FILE: oneflow/user/kernels/upsample_linear_1d_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/user/kernels/upsample_kernel.h" namespace oneflow { namespace { template static void UpsampleLinear1DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int in_height, const double scale_factor, bool align_corners, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h; out_helper.OffsetToNdIndex(index, n, c, h); const double h1r = GetLinearInputIndex(h, scale_factor, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; const double h1lambda = h1r - h1; const double h0lambda = static_cast(1.) - h1lambda; out_dptr[index] = h0lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1)] + h1lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1 + h1p)]; } } template static void UpsampleLinear1DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int in_height, const double scale_factor, bool align_corners, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h; dy_helper.OffsetToNdIndex(index, n, c, h); const double h1r = GetLinearInputIndex(h, scale_factor, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; const double h1lambda = h1r - h1; const double h0lambda = static_cast(1.) 
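// Illustrative sketch (not part of the original file): how an output index maps back to a
// fractional input coordinate, combining GetAreaPixelScale and GetAreaPixel from
// upsample_kernel.h above. With align_corners the two grids share their endpoints;
// otherwise pixel centres are aligned and the result is clamped at 0 (except on the
// cubic path, which keeps negative values). The function name is ad hoc.
#include <cstdint>
static double SourceIndexExample(int64_t in_size, int64_t out_size, int64_t dst_index,
                                 bool align_corners, double scale_attr /* 0 if unset */) {
  if (align_corners) {
    const double scale = out_size > 1 ? static_cast<double>(in_size - 1) / (out_size - 1) : 0.0;
    return scale * dst_index;
  }
  const double scale =
      scale_attr > 0.0 ? 1.0 / scale_attr : static_cast<double>(in_size) / out_size;
  const double src = scale * (dst_index + 0.5) - 0.5;
  return src < 0.0 ? 0.0 : src;
}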
- h1lambda; *(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1)) += h0lambda * dy_dptr[index]; *(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1 + h1p)) += h1lambda * dy_dptr[index]; } } } // namespace template class UpsampleLinear1DCPUKernel final : public user_op::OpKernel { public: UpsampleLinear1DCPUKernel() = default; ~UpsampleLinear1DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2)); const int64_t nbatch = x_tensor->shape_view().At(0); const int64_t channels = x_tensor->shape_view().At(1); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(2); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("scale_factor"); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { memcpy(y_tensor->mut_dptr(), x_tensor->dptr(), sizeof(T) * nbatch * channels * in_height); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); UpsampleLinear1DForward(elem_cnt, x_tensor->dptr(), in_helper, out_helper, in_height, scale_height, align_corners, y_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleLinearGrad1DCPUKernel final : public user_op::OpKernel { public: UpsampleLinearGrad1DCPUKernel() = default; ~UpsampleLinearGrad1DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const bool align_corners = ctx->Attr("align_corners"); NdIndexOffsetHelper dy_helper(dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2)); NdIndexOffsetHelper dx_helper(dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2)); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); const int64_t nbatch = dx_tensor->shape_view().At(0); const int64_t channels = dx_tensor->shape_view().At(1); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(2); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("scale_factor"); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { memcpy(dx_tensor->mut_dptr(), dy_tensor->dptr(), sizeof(T) * nbatch * channels * in_height); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); UpsampleLinear1DBackward(elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, in_height, scale_height, align_corners, dx_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define 
REGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_linear_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_linear_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(float) REGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_linear_1d_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/upsample_kernel.h" namespace oneflow { namespace { template __global__ void UpsampleLinear1DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int in_height, const double scale_factor, bool align_corners, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h; out_helper.OffsetToNdIndex(index, n, c, h); const double h1r = GetLinearInputIndex(h, scale_factor, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; const double h1lambda = h1r - h1; const double h0lambda = static_cast(1.) - h1lambda; out_dptr[index] = h0lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1)] + h1lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1 + h1p)]; } } template __global__ void UpsampleLinear1DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int in_height, const double scale_factor, bool align_corners, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h; dy_helper.OffsetToNdIndex(index, n, c, h); const double h1r = GetLinearInputIndex(h, scale_factor, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; const double h1lambda = h1r - h1; const double h0lambda = static_cast(1.) 
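// Illustrative sketch (not part of the original file): the neighbour/weight split shared by
// UpsampleLinear1DForward and UpsampleLinear1DBackward above. A fractional source index
// h1r contributes to input cells h1 and h1 + h1p with weights (1 - lambda) and lambda;
// the backward pass scatters dy with exactly those weights. The helper name is ad hoc.
#include <cstdint>
static void LinearSplitExample(double h1r, int64_t in_height, int64_t* h1, int64_t* h1p,
                               double* h0lambda, double* h1lambda) {
  *h1 = static_cast<int64_t>(h1r);
  *h1p = (*h1 < in_height - 1) ? 1 : 0;  // collapse to a single cell at the right edge
  *h1lambda = h1r - *h1;
  *h0lambda = 1.0 - *h1lambda;
}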
- h1lambda; cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, h1), elem_cnt, static_cast(h0lambda * dy_dptr[index])); cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, h1 + h1p), elem_cnt, static_cast(h1lambda * dy_dptr[index])); } } } // namespace template class UpsampleLinear1DGPUKernel final : public user_op::OpKernel { public: UpsampleLinear1DGPUKernel() = default; ~UpsampleLinear1DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2)); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(2); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("scale_factor"); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type())); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); RUN_CUDA_KERNEL((UpsampleLinear1DForward), ctx->stream(), elem_cnt, elem_cnt, x_tensor->dptr(), in_helper, out_helper, in_height, scale_height, align_corners, y_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleLinearGrad1DGPUKernel final : public user_op::OpKernel { public: UpsampleLinearGrad1DGPUKernel() = default; ~UpsampleLinearGrad1DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const bool align_corners = ctx->Attr("align_corners"); NdIndexOffsetHelper dy_helper(dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2)); NdIndexOffsetHelper dx_helper(dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2)); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(2); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("scale_factor"); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type())); } else { const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); RUN_CUDA_KERNEL((UpsampleLinear1DBackward), ctx->stream(), elem_cnt, elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, in_height, scale_height, align_corners, 
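// Illustrative sketch (not part of the original file): why the CUDA backward kernels above
// accumulate through cuda::atomic::FastAdd. Each thread owns one dy element, and several
// output elements can map to the same dx cell, so concurrent updates must be atomic.
// A minimal standalone equivalent using plain CUDA atomicAdd (kernel name and parameters
// are ad hoc):
#include <cstdint>
__global__ void ScatterAddExample(const float* dy, const int64_t* dx_index, int64_t n,
                                  float* dx /* zero-filled beforehand */) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { atomicAdd(dx + dx_index[i], dy[i]); }
}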
dx_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPLELINEAR1D_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_linear_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_linear_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPLELINEAR1D_CUDA_KERNEL(float) REGISTER_UPSAMPLELINEAR1D_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_nearest_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/user/kernels/upsample_kernel.h" namespace oneflow { namespace { template static void UpsampleNearest1DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const double scale_factor, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h; out_helper.OffsetToNdIndex(index, n, c, h); const int64_t in_h = GetNearestInputIndex(h, scale_factor, in_height); out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h)]; } } template static void UpsampleNearest1DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t in_height, const double scale_factor, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h; dy_helper.OffsetToNdIndex(index, n, c, h); const int64_t dx_h = GetNearestInputIndex(h, scale_factor, in_height); *(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_h)) += dy_dptr[index]; } } template static void UpsampleNearest2DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const int64_t in_width, const double scale_h, const double scale_w, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h, w; out_helper.OffsetToNdIndex(index, n, c, h, w); const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height); const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width); out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h, in_w)]; } } template static void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t dx_height, const int64_t dx_width, const double scale_h, const double scale_w, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h, w; dy_helper.OffsetToNdIndex(index, n, c, h, w); const int64_t dx_h 
= GetNearestInputIndex(h, scale_h, dx_height); const int64_t dx_w = GetNearestInputIndex(w, scale_w, dx_width); *(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_h, dx_w)) += dy_dptr[index]; } } template static void UpsampleNearest3DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_depth, const int64_t in_height, const int64_t in_width, const float scale_d, const float scale_h, const float scale_w, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, d, h, w; out_helper.OffsetToNdIndex(index, n, c, d, h, w); const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height); const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width); const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth); out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_d, in_h, in_w)]; } } template static void UpsampleNearest3DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t in_depth, const int64_t in_height, const int64_t in_width, const float scale_d, const float scale_h, const float scale_w, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, d, h, w; dy_helper.OffsetToNdIndex(index, n, c, d, h, w); const int64_t dx_h = GetNearestInputIndex(h, scale_h, in_height); const int64_t dx_w = GetNearestInputIndex(w, scale_w, in_width); const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth); *(dx_dptr + dx_helper.NdIndexToOffset(n, c, in_d, dx_h, dx_w)) += dy_dptr[index]; } } } // namespace template class UpsampleNearest1DCPUKernel final : public user_op::OpKernel { public: UpsampleNearest1DCPUKernel() = default; ~UpsampleNearest1DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("scale_factor"); const int64_t nbatch = x_tensor->shape_view().At(0); const int64_t channels = x_tensor->shape_view().At(1); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(2); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { memcpy(y_tensor->mut_dptr(), x_tensor->dptr(), sizeof(T) * nbatch * channels * in_height); } else { NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2)); UpsampleNearest1DForward(elem_cnt, x_tensor->dptr(), in_helper, out_helper, x_tensor->shape_view().At(2), 1.f / height_scale, y_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleNearestGrad1DCPUKernel final : public user_op::OpKernel { public: UpsampleNearestGrad1DCPUKernel() = default; ~UpsampleNearestGrad1DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = 
ctx->Tensor4ArgNameAndIndex("dy", 0); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("scale_factor"); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); const int64_t nbatch = dx_tensor->shape_view().At(0); const int64_t channels = dx_tensor->shape_view().At(1); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(2); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { memcpy(dx_tensor->mut_dptr(), dy_tensor->dptr(), sizeof(T) * nbatch * channels * in_height); } else { NdIndexOffsetHelper dy_helper(dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2)); NdIndexOffsetHelper dx_helper(dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2)); UpsampleNearest1DBackward(elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, dx_tensor->shape_view().At(2), 1.f / height_scale, dx_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPNEAREST1D_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_nearest_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_nearest_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPNEAREST1D_CPU_KERNEL(float) REGISTER_UPSAMPNEAREST1D_CPU_KERNEL(double) template class UpsampleNearest2DCPUKernel final : public user_op::OpKernel { public: UpsampleNearest2DCPUKernel() = default; ~UpsampleNearest2DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t nbatch = x_tensor->shape_view().At(0); const int64_t channels = x_tensor->shape_view().At(1); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t in_width = x_tensor->shape_view().At(3); const int64_t out_height = y_tensor->shape_view().At(2); const int64_t out_width = y_tensor->shape_view().At(3); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { memcpy(y_tensor->mut_dptr(), x_tensor->dptr(), sizeof(T) * nbatch * channels * in_height * in_width); } else { NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2), x_tensor->shape_view().At(3)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2), y_tensor->shape_view().At(3)); UpsampleNearest2DForward(elem_cnt, x_tensor->dptr(), in_helper, out_helper, x_tensor->shape_view().At(2), x_tensor->shape_view().At(3), 1.f / height_scale, 1.f / width_scale, y_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleNearest2DGradCPUKernel final : 
public user_op::OpKernel { public: UpsampleNearest2DGradCPUKernel() = default; ~UpsampleNearest2DGradCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t nbatch = dx_tensor->shape_view().At(0); const int64_t channels = dx_tensor->shape_view().At(1); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t in_width = dx_tensor->shape_view().At(3); const int64_t out_height = dy_tensor->shape_view().At(2); const int64_t out_width = dy_tensor->shape_view().At(3); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { memcpy(dx_tensor->mut_dptr(), dy_tensor->dptr(), sizeof(T) * nbatch * channels * in_height * in_width); } else { NdIndexOffsetHelper dy_helper( dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3)); NdIndexOffsetHelper dx_helper( dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3)); UpsampleNearest2DBackward(elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), 1.f / height_scale, 1.f / width_scale, dx_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_nearest_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_nearest_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(float) REGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(double) template class UpsampleNearest3DCPUKernel final : public user_op::OpKernel { public: UpsampleNearest3DCPUKernel() = default; ~UpsampleNearest3DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_blob = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_blob = ctx->Tensor4ArgNameAndIndex("y", 0); const std::vector output_size = ctx->Attr>("output_size"); double depth_scale = ctx->Attr("depth_scale"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t in_depth = x_blob->shape_view().At(2); const int64_t in_height = x_blob->shape_view().At(3); const int64_t in_width = x_blob->shape_view().At(4); const int64_t out_depth = y_blob->shape_view().At(2); const int64_t out_height = y_blob->shape_view().At(3); const int64_t out_width = y_blob->shape_view().At(4); const int64_t elem_cnt = y_blob->shape_view().elem_cnt(); if (!output_size.empty()) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / 
static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } NdIndexOffsetHelper in_helper( x_blob->shape_view().At(0), x_blob->shape_view().At(1), x_blob->shape_view().At(2), x_blob->shape_view().At(3), x_blob->shape_view().At(4)); NdIndexOffsetHelper out_helper( y_blob->shape_view().At(0), y_blob->shape_view().At(1), y_blob->shape_view().At(2), y_blob->shape_view().At(3), y_blob->shape_view().At(4)); UpsampleNearest3DForward(elem_cnt, x_blob->dptr(), in_helper, out_helper, x_blob->shape_view().At(2), x_blob->shape_view().At(3), x_blob->shape_view().At(4), 1.f / depth_scale, 1.f / height_scale, 1.f / width_scale, y_blob->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleNearestGrad3DCPUKernel final : public user_op::OpKernel { public: UpsampleNearestGrad3DCPUKernel() = default; ~UpsampleNearestGrad3DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); if (dx_blob == nullptr) { return; } Memset(ctx->stream(), dx_blob->mut_dptr(), 0, dx_blob->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); const std::vector output_size = ctx->Attr>("output_size"); double depth_scale = ctx->Attr("depth_scale"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t in_depth = dx_blob->shape_view().At(2); const int64_t in_height = dx_blob->shape_view().At(3); const int64_t in_width = dx_blob->shape_view().At(4); const int64_t out_depth = dy_blob->shape_view().At(2); const int64_t out_height = dy_blob->shape_view().At(3); const int64_t out_width = dy_blob->shape_view().At(4); const int64_t elem_cnt = dy_blob->shape_view().elem_cnt(); if (!output_size.empty()) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } NdIndexOffsetHelper dy_helper( dy_blob->shape_view().At(0), dy_blob->shape_view().At(1), dy_blob->shape_view().At(2), dy_blob->shape_view().At(3), dy_blob->shape_view().At(4)); NdIndexOffsetHelper dx_helper( dx_blob->shape_view().At(0), dx_blob->shape_view().At(1), dx_blob->shape_view().At(2), dx_blob->shape_view().At(3), dx_blob->shape_view().At(4)); UpsampleNearest3DBackward(elem_cnt, dy_blob->dptr(), dy_helper, dx_helper, dx_blob->shape_view().At(2), dx_blob->shape_view().At(3), dx_blob->shape_view().At(4), 1.f / depth_scale, 1.f / height_scale, 1.f / width_scale, dx_blob->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPNEAREST3D_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_nearest_3d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_nearest_3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPNEAREST3D_CPU_KERNEL(float) REGISTER_UPSAMPNEAREST3D_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_nearest_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/kernel/kernel_util.cuh" #include "oneflow/user/kernels/upsample_kernel.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { template __global__ void UpsampleNearest1DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const double scale_factor, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h; out_helper.OffsetToNdIndex(index, n, c, h); const int64_t in_h = GetNearestInputIndex(h, scale_factor, in_height); out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h)]; } } template __global__ void UpsampleNearest1DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t in_height, const double scale_factor, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h; dy_helper.OffsetToNdIndex(index, n, c, h); const int64_t dx_h = GetNearestInputIndex(h, scale_factor, in_height); cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, dx_h), elem_cnt, static_cast(dy_dptr[index])); } } template __global__ void UpsampleNearest2DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const int64_t in_width, const double scale_h, const double scale_w, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h, w; out_helper.OffsetToNdIndex(index, n, c, h, w); const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height); const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width); out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h, in_w)]; } } template struct alignas(2 * sizeof(T)) Pack2X { T x; T y; }; template __global__ void UpsampleNearest2D2XForward(const int32_t in_elem_cnt, const T* in_dptr, const int32_t in_height, const int32_t in_width, T* out_dptr) { const int32_t in_hw_size = in_width * in_height; CUDA_1D_KERNEL_LOOP(index, in_elem_cnt) { const T in_value = in_dptr[index]; const int32_t nc_idx = index / in_hw_size; const int32_t hw_off = index - nc_idx * in_hw_size; const int32_t h = hw_off / in_width; const int32_t w = hw_off - h * in_width; Pack2X out_value{in_value, in_value}; Pack2X* out_pack_dptr = reinterpret_cast*>(out_dptr); out_pack_dptr[nc_idx * in_hw_size * 2 + h * 2 * in_width + w] = out_value; out_pack_dptr[nc_idx * in_hw_size * 2 + (h * 2 + 1) * in_width + w] = out_value; } } template __global__ void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t dx_height, const int64_t dx_width, const double scale_h, const double scale_w, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h, w; dy_helper.OffsetToNdIndex(index, n, c, h, w); const int64_t 
dx_h = GetNearestInputIndex(h, scale_h, dx_height); const int64_t dx_w = GetNearestInputIndex(w, scale_w, dx_width); cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, dx_h, dx_w), elem_cnt, static_cast(dy_dptr[index])); } } template __global__ void UpsampleNearest2D2XBackward(const int32_t in_elem_cnt, const T* dy_dptr, const int32_t dx_height, const int32_t dx_width, T* dx_dptr) { const int32_t dx_hw_size = dx_height * dx_width; CUDA_1D_KERNEL_LOOP(index, in_elem_cnt) { T dx_value = 0.0; const int32_t nc_idx = index / dx_hw_size; const int32_t dx_hw_off = index - nc_idx * dx_hw_size; const int32_t dx_h = dx_hw_off / dx_width; const int32_t dx_w = dx_hw_off - dx_h * dx_width; const Pack2X* dy_pack_dptr = reinterpret_cast*>(dy_dptr); const Pack2X dy_pack_value1 = dy_pack_dptr[nc_idx * dx_hw_size * 2 + dx_h * 2 * dx_width + dx_w]; const Pack2X dy_pack_value2 = dy_pack_dptr[nc_idx * dx_hw_size * 2 + (dx_h * 2 + 1) * dx_width + dx_w]; dx_value += dy_pack_value1.x; dx_value += dy_pack_value1.y; dx_value += dy_pack_value2.x; dx_value += dy_pack_value2.y; dx_dptr[index] = dx_value; } } template __global__ void UpsampleNearest3DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_depth, const int64_t in_height, const int64_t in_width, const float scale_d, const float scale_h, const float scale_w, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, d, h, w; out_helper.OffsetToNdIndex(index, n, c, d, h, w); const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height); const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width); const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth); out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_d, in_h, in_w)]; } } template __global__ void UpsampleNearest3DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t in_depth, const int64_t in_height, const int64_t in_width, const float scale_d, const float scale_h, const float scale_w, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, d, h, w; dy_helper.OffsetToNdIndex(index, n, c, d, h, w); const int64_t dx_h = GetNearestInputIndex(h, scale_h, in_height); const int64_t dx_w = GetNearestInputIndex(w, scale_w, in_width); const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth); cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, in_d, dx_h, dx_w), elem_cnt, static_cast(dy_dptr[index])); } } } // namespace template class UpsampleNearest1DGPUKernel final : public user_op::OpKernel { public: UpsampleNearest1DGPUKernel() = default; ~UpsampleNearest1DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("scale_factor"); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(2); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type())); } else { NdIndexOffsetHelper 
in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2)); RUN_CUDA_KERNEL((UpsampleNearest1DForward), ctx->stream(), elem_cnt, elem_cnt, x_tensor->dptr(), in_helper, out_helper, x_tensor->shape_view().At(2), 1.f / height_scale, y_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel { public: UpsampleNearestGrad1DGPUKernel() = default; ~UpsampleNearestGrad1DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("scale_factor"); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(2); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type())); } else { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); NdIndexOffsetHelper dy_helper(dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2)); NdIndexOffsetHelper dx_helper(dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2)); RUN_CUDA_KERNEL((UpsampleNearest1DBackward), ctx->stream(), elem_cnt, elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, dx_tensor->shape_view().At(2), 1.f / height_scale, dx_tensor->mut_dptr()); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPNEAREST1D_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_nearest_1d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_nearest_1d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPNEAREST1D_CUDA_KERNEL(float) REGISTER_UPSAMPNEAREST1D_CUDA_KERNEL(double) template class UpsampleNearest2DGPUKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: UpsampleNearest2DGPUKernel() = default; ~UpsampleNearest2DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t out_elem_cnt = y_tensor->shape_view().elem_cnt(); const int64_t in_elem_cnt = x_tensor->shape_view().elem_cnt(); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t in_width = x_tensor->shape_view().At(3); const int64_t out_height = y_tensor->shape_view().At(2); const int64_t 
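// Fast-path note for the dispatch below: UpsampleNearest2D2XForward (defined above) covers the
// exact 2x upscale in both H and W. It loops over *input* elements and writes each value to its
// 2x2 output block as two Pack2X stores (two horizontally adjacent outputs at a time), issuing
// half as many global stores as the generic gather kernel; the in_elem_cnt <= (1 << 29) guard
// keeps that kernel's 32-bit indexing in range.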
out_width = y_tensor->shape_view().At(3); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type())); } else { const int64_t n = x_tensor->shape_view().At(0); const int64_t c = x_tensor->shape_view().At(1); if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 29) { RUN_CUDA_KERNEL(UpsampleNearest2D2XForward, ctx->stream(), in_elem_cnt, in_elem_cnt, x_tensor->dptr(), in_height, in_width, y_tensor->mut_dptr()); } else { NdIndexOffsetHelper in_helper(n, c, in_height, in_width); NdIndexOffsetHelper out_helper(n, c, out_height, out_width); RUN_CUDA_KERNEL((UpsampleNearest2DForward), ctx->stream(), out_elem_cnt, out_elem_cnt, x_tensor->dptr(), in_helper, out_helper, in_height, in_width, 1.f / height_scale, 1.f / width_scale, y_tensor->mut_dptr()); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleNearest2DGradGPUKernel final : public user_op::OpKernel { public: UpsampleNearest2DGradGPUKernel() = default; ~UpsampleNearest2DGradGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const std::vector output_size = ctx->Attr>("output_size"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); const int64_t in_elem_cnt = dx_tensor->shape_view().elem_cnt(); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t in_width = dx_tensor->shape_view().At(3); const int64_t out_height = dy_tensor->shape_view().At(2); const int64_t out_width = dy_tensor->shape_view().At(3); if (!output_size.empty()) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type())); } else { if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 29) { RUN_CUDA_KERNEL(UpsampleNearest2D2XBackward, ctx->stream(), in_elem_cnt, in_elem_cnt, dy_tensor->dptr(), dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), dx_tensor->mut_dptr()); } else { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); NdIndexOffsetHelper dy_helper( dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3)); NdIndexOffsetHelper dx_helper( dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3)); RUN_CUDA_KERNEL((UpsampleNearest2DBackward), ctx->stream(), elem_cnt, elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), 1.f / height_scale, 1.f / width_scale, dx_tensor->mut_dptr()); } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPLE_NEAREST_2D_CUDA_KERNEL(dtype) \ 
REGISTER_USER_KERNEL("upsample_nearest_2d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_nearest_2d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPLE_NEAREST_2D_CUDA_KERNEL(float) REGISTER_UPSAMPLE_NEAREST_2D_CUDA_KERNEL(half) REGISTER_UPSAMPLE_NEAREST_2D_CUDA_KERNEL(double) template class UpsampleNearest3DGPUKernel final : public user_op::OpKernel { public: UpsampleNearest3DGPUKernel() = default; ~UpsampleNearest3DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const std::vector output_size = ctx->Attr>("output_size"); double depth_scale = ctx->Attr("depth_scale"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t in_depth = x_tensor->shape_view().At(2); const int64_t in_height = x_tensor->shape_view().At(3); const int64_t in_width = x_tensor->shape_view().At(4); const int64_t out_depth = y_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(3); const int64_t out_width = y_tensor->shape_view().At(4); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); if (!output_size.empty()) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2), x_tensor->shape_view().At(3), x_tensor->shape_view().At(4)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2), y_tensor->shape_view().At(3), y_tensor->shape_view().At(4)); RUN_CUDA_KERNEL((UpsampleNearest3DForward), ctx->stream(), elem_cnt, elem_cnt, x_tensor->dptr(), in_helper, out_helper, x_tensor->shape_view().At(2), x_tensor->shape_view().At(3), x_tensor->shape_view().At(4), 1.f / depth_scale, 1.f / height_scale, 1.f / width_scale, y_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleNearestGrad3DGPUKernel final : public user_op::OpKernel { public: UpsampleNearestGrad3DGPUKernel() = default; ~UpsampleNearestGrad3DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const std::vector output_size = ctx->Attr>("output_size"); double depth_scale = ctx->Attr("depth_scale"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); const int64_t in_depth = dx_tensor->shape_view().At(2); const int64_t in_height = dx_tensor->shape_view().At(3); const int64_t in_width = dx_tensor->shape_view().At(4); const int64_t out_depth = dy_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(3); const int64_t out_width = 
dy_tensor->shape_view().At(4); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); if (!output_size.empty()) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } NdIndexOffsetHelper dy_helper( dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3), dy_tensor->shape_view().At(4)); NdIndexOffsetHelper dx_helper( dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4)); RUN_CUDA_KERNEL((UpsampleNearest3DBackward), ctx->stream(), elem_cnt, elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4), 1.f / depth_scale, 1.f / height_scale, 1.f / width_scale, dx_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPNEAREST3D_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_nearest_3d") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_nearest_3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPNEAREST3D_CUDA_KERNEL(float) REGISTER_UPSAMPNEAREST3D_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/user/kernels/upsample_kernel.h" namespace oneflow { namespace { template static void UpsampleTrilinear3DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_depth, const int64_t in_height, const int64_t in_width, const T rdepth, const T rheight, const T rwidth, const bool align_corners, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, d, h, w; out_helper.OffsetToNdIndex(index, n, c, d, h, w); const T t1r = GetAreaPixel(rdepth, d, align_corners); const int64_t t1 = t1r; const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0; const T t1lambda = t1r - t1; const T t0lambda = static_cast(1.) - t1lambda; const T h1r = GetAreaPixel(rheight, h, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; const T h1lambda = h1r - h1; const T h0lambda = static_cast(1.) - h1lambda; const T w1r = GetAreaPixel(rwidth, w, align_corners); const int64_t w1 = w1r; const int64_t w1p = (w1 < in_width - 1) ? 
1 : 0; const T w1lambda = w1r - w1; const T w0lambda = static_cast(1.) - w1lambda; const T* pos1 = &in_dptr[in_helper.NdIndexToOffset(n, c, t1, h1, w1)]; out_dptr[index] = t0lambda * (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) + h1lambda * (w0lambda * pos1[h1p * in_width] + w1lambda * pos1[h1p * in_width + w1p])) + t1lambda * (h0lambda * (w0lambda * pos1[t1p * in_height * in_width] + w1lambda * pos1[t1p * in_height * in_width + w1p]) + h1lambda * (w0lambda * pos1[t1p * in_height * in_width + h1p * in_width] + w1lambda * pos1[t1p * in_height * in_width + h1p * in_width + w1p])); } } template static void UpsampleTrilinear3DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t in_depth, const int64_t in_height, const int64_t in_width, const T rdepth, const T rheight, const T rwidth, const bool align_corners, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, d, h, w; dy_helper.OffsetToNdIndex(index, n, c, d, h, w); const T t1r = GetAreaPixel(rdepth, d, align_corners); const int64_t t1 = t1r; const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0; const T t1lambda = t1r - t1; const T t0lambda = static_cast(1.) - t1lambda; const T h1r = GetAreaPixel(rheight, h, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; const T h1lambda = h1r - h1; const T h0lambda = static_cast(1.) - h1lambda; const T w1r = GetAreaPixel(rwidth, w, align_corners); const int64_t w1 = w1r; const int64_t w1p = (w1 < in_width - 1) ? 1 : 0; const T w1lambda = w1r - w1; const T w0lambda = static_cast(1.) - w1lambda; T* pos1 = &dx_dptr[dx_helper.NdIndexToOffset(n, c, t1, h1, w1)]; const T* pos2 = &dy_dptr[index]; pos1[0] += t0lambda * h0lambda * w0lambda * pos2[0]; pos1[w1p] += t0lambda * h0lambda * w1lambda * pos2[0]; pos1[h1p * in_width] += t0lambda * h1lambda * w0lambda * pos2[0]; pos1[h1p * in_width + w1p] += t0lambda * h1lambda * w1lambda * pos2[0]; pos1[t1p * in_height * in_width] += t1lambda * h0lambda * w0lambda * pos2[0]; pos1[t1p * in_height * in_width + w1p] += t1lambda * h0lambda * w1lambda * pos2[0]; pos1[t1p * in_height * in_width + h1p * in_width] += t1lambda * h1lambda * w0lambda * pos2[0]; pos1[t1p * in_height * in_width + h1p * in_width + w1p] += t1lambda * h1lambda * w1lambda * pos2[0]; } } } // namespace template class UpsampleTrilinear3DCPUKernel final : public user_op::OpKernel { public: UpsampleTrilinear3DCPUKernel() = default; ~UpsampleTrilinear3DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2), x_tensor->shape_view().At(3), x_tensor->shape_view().At(4)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2), y_tensor->shape_view().At(3), y_tensor->shape_view().At(4)); const int64_t in_depth = x_tensor->shape_view().At(2); const int64_t in_height = x_tensor->shape_view().At(3); const int64_t in_width = x_tensor->shape_view().At(4); const int64_t out_depth = y_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(3); const int64_t out_width = 
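// Reading the trilinear helpers above: GetAreaPixel turns an output index into a fractional
// source coordinate; (t1, h1, w1) is the low corner of the enclosing 2x2x2 input cell, the *1p
// offsets drop to 0 at the far edges so reads stay in bounds, and the *0lambda / *1lambda pairs
// are the linear blend weights. The forward pass blends the eight corner values; the backward
// pass scatters each dy value onto the same eight corners with the same weights.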
y_tensor->shape_view().At(4); const std::vector output_size = ctx->Attr>("output_size"); double depth_scale = ctx->Attr("depth_scale"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); if (!output_size.empty()) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale); const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); UpsampleTrilinear3DForward(elem_cnt, x_tensor->dptr(), in_helper, out_helper, x_tensor->shape_view().At(2), x_tensor->shape_view().At(3), x_tensor->shape_view().At(4), scale_depth, scale_height, scale_width, align_corners, y_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleTrilinearGrad3DCPUKernel final : public user_op::OpKernel { public: UpsampleTrilinearGrad3DCPUKernel() = default; ~UpsampleTrilinearGrad3DCPUKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper dy_helper( dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3), dy_tensor->shape_view().At(4)); NdIndexOffsetHelper dx_helper( dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4)); const int64_t in_depth = dx_tensor->shape_view().At(2); const int64_t in_height = dx_tensor->shape_view().At(3); const int64_t in_width = dx_tensor->shape_view().At(4); const int64_t out_depth = dy_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(3); const int64_t out_width = dy_tensor->shape_view().At(4); const std::vector output_size = ctx->Attr>("output_size"); double depth_scale = ctx->Attr("depth_scale"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); if (!output_size.empty()) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale); const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); UpsampleTrilinear3DBackward(elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4), scale_depth, scale_height, scale_width, align_corners, dx_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_trilinear_3d") \ .SetCreateFn>() \ 
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); \ REGISTER_USER_KERNEL("upsample_trilinear_3d_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); REGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(float) REGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/upsample_trilinear_3d_kernel.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/upsample_kernel.h" namespace oneflow { namespace { template __global__ void UpsampleTrilinear3DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_depth, const int64_t in_height, const int64_t in_width, const T rdepth, const T rheight, const T rwidth, const bool align_corners, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, d, h, w; out_helper.OffsetToNdIndex(index, n, c, d, h, w); const T t1r = GetAreaPixel(rdepth, d, align_corners); const int64_t t1 = t1r; const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0; const T t1lambda = t1r - t1; const T t0lambda = static_cast(1.) - t1lambda; const T h1r = GetAreaPixel(rheight, h, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; const T h1lambda = h1r - h1; const T h0lambda = static_cast(1.) - h1lambda; const T w1r = GetAreaPixel(rwidth, w, align_corners); const int64_t w1 = w1r; const int64_t w1p = (w1 < in_width - 1) ? 1 : 0; const T w1lambda = w1r - w1; const T w0lambda = static_cast(1.) 
- w1lambda; const T* pos1 = &in_dptr[in_helper.NdIndexToOffset(n, c, t1, h1, w1)]; out_dptr[index] = t0lambda * (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) + h1lambda * (w0lambda * pos1[h1p * in_width] + w1lambda * pos1[h1p * in_width + w1p])) + t1lambda * (h0lambda * (w0lambda * pos1[t1p * in_height * in_width] + w1lambda * pos1[t1p * in_height * in_width + w1p]) + h1lambda * (w0lambda * pos1[t1p * in_height * in_width + h1p * in_width] + w1lambda * pos1[t1p * in_height * in_width + h1p * in_width + w1p])); } } template __global__ void UpsampleTrilinear3DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t in_depth, const int64_t in_height, const int64_t in_width, const T rdepth, const T rheight, const T rwidth, const bool align_corners, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, d, h, w; dy_helper.OffsetToNdIndex(index, n, c, d, h, w); const T t1r = GetAreaPixel(rdepth, d, align_corners); const int64_t t1 = t1r; const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0; const T t1lambda = t1r - t1; const T t0lambda = static_cast(1.) - t1lambda; const T h1r = GetAreaPixel(rheight, h, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; const T h1lambda = h1r - h1; const T h0lambda = static_cast(1.) - h1lambda; const T w1r = GetAreaPixel(rwidth, w, align_corners); const int64_t w1 = w1r; const int64_t w1p = (w1 < in_width - 1) ? 1 : 0; const T w1lambda = w1r - w1; const T w0lambda = static_cast(1.) - w1lambda; T* pos1 = &dx_dptr[dx_helper.NdIndexToOffset(n, c, t1, h1, w1)]; const T* pos2 = &dy_dptr[index]; cuda::atomic::FastAdd(pos1, 0, elem_cnt, t0lambda * h0lambda * w0lambda * pos2[0]); cuda::atomic::FastAdd(pos1, w1p, elem_cnt, t0lambda * h0lambda * w1lambda * pos2[0]); cuda::atomic::FastAdd(pos1, h1p * in_width, elem_cnt, t0lambda * h1lambda * w0lambda * pos2[0]); cuda::atomic::FastAdd(pos1, h1p * in_width + w1p, elem_cnt, t0lambda * h1lambda * w1lambda * pos2[0]); cuda::atomic::FastAdd(pos1, t1p * in_height * in_width, elem_cnt, t1lambda * h0lambda * w0lambda * pos2[0]); cuda::atomic::FastAdd(pos1, t1p * in_height * in_width + w1p, elem_cnt, t1lambda * h0lambda * w1lambda * pos2[0]); cuda::atomic::FastAdd(pos1, t1p * in_height * in_width + h1p * in_width, elem_cnt, t1lambda * h1lambda * w0lambda * pos2[0]); cuda::atomic::FastAdd(pos1, t1p * in_height * in_width + h1p * in_width + w1p, elem_cnt, t1lambda * h1lambda * w1lambda * pos2[0]); } } } // namespace template class UpsampleTrilinear3DGPUKernel final : public user_op::OpKernel { public: UpsampleTrilinear3DGPUKernel() = default; ~UpsampleTrilinear3DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper in_helper( x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2), x_tensor->shape_view().At(3), x_tensor->shape_view().At(4)); NdIndexOffsetHelper out_helper( y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2), y_tensor->shape_view().At(3), y_tensor->shape_view().At(4)); const int64_t in_depth = x_tensor->shape_view().At(2); const int64_t in_height = x_tensor->shape_view().At(3); const 
int64_t in_width = x_tensor->shape_view().At(4); const int64_t out_depth = y_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(3); const int64_t out_width = y_tensor->shape_view().At(4); const std::vector output_size = ctx->Attr>("output_size"); double depth_scale = ctx->Attr("depth_scale"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); if (!output_size.empty()) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale); const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); RUN_CUDA_KERNEL((UpsampleTrilinear3DForward), ctx->stream(), elem_cnt, elem_cnt, x_tensor->dptr(), in_helper, out_helper, x_tensor->shape_view().At(2), x_tensor->shape_view().At(3), x_tensor->shape_view().At(4), scale_depth, scale_height, scale_width, align_corners, y_tensor->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template class UpsampleTrilinearGrad3DGPUKernel final : public user_op::OpKernel { public: UpsampleTrilinearGrad3DGPUKernel() = default; ~UpsampleTrilinearGrad3DGPUKernel() = default; private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape_view().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); NdIndexOffsetHelper dy_helper( dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3), dy_tensor->shape_view().At(4)); NdIndexOffsetHelper dx_helper( dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4)); const int64_t in_depth = dx_tensor->shape_view().At(2); const int64_t in_height = dx_tensor->shape_view().At(3); const int64_t in_width = dx_tensor->shape_view().At(4); const int64_t out_depth = dy_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(3); const int64_t out_width = dy_tensor->shape_view().At(4); const std::vector output_size = ctx->Attr>("output_size"); double depth_scale = ctx->Attr("depth_scale"); double height_scale = ctx->Attr("height_scale"); double width_scale = ctx->Attr("width_scale"); if (!output_size.empty()) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale); const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); RUN_CUDA_KERNEL((UpsampleTrilinear3DBackward), ctx->stream(), elem_cnt, elem_cnt, dy_tensor->dptr(), dy_helper, dx_helper, dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4), 
scale_depth, scale_height, scale_width, align_corners, dx_tensor->mut_dptr<T>()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_UPSAMPTRILINEAR3D_CUDA_KERNEL(dtype) \ REGISTER_USER_KERNEL("upsample_trilinear_3d") \ .SetCreateFn<UpsampleTrilinear3DGPUKernel<dtype>>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \ REGISTER_USER_KERNEL("upsample_trilinear_3d_grad") \ .SetCreateFn<UpsampleTrilinearGrad3DGPUKernel<dtype>>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)); REGISTER_UPSAMPTRILINEAR3D_CUDA_KERNEL(float) REGISTER_UPSAMPTRILINEAR3D_CUDA_KERNEL(double) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/util_ops_kernels.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/elementwise_primitive_kernel.h" namespace oneflow { namespace user_op { #define UTIL_OPS_SEQ \ OF_PP_MAKE_TUPLE_SEQ("isinf", ep::primitive::UnaryOp::kIsInf) \ OF_PP_MAKE_TUPLE_SEQ("isnan", ep::primitive::UnaryOp::kIsNan) \ OF_PP_MAKE_TUPLE_SEQ("isfinite", ep::primitive::UnaryOp::kIsFinite) #define REGISTER_UTIL_OPS(op_name, op_kind) \ REGISTER_USER_KERNEL(op_name) \ .SetCreateFn([]() { \ return user_op::NewOpKernel<UnaryPrimitiveKernel>( \ "out", "in", [](user_op::KernelComputeContext* ctx) { \ const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); \ const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); \ return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>( \ ctx->device_type(), op_kind, src->data_type(), dst->data_type()); \ }); \ }) \ .SetIsMatchedHob(UnaryPrimitiveExists(op_kind, "out", "in")); OF_PP_FOR_EACH_TUPLE(REGISTER_UTIL_OPS, UTIL_OPS_SEQ) #undef REGISTER_UTIL_OPS #undef UTIL_OPS_SEQ } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/variance_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_reduce.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/user/kernels/variance_kernel_util.h" namespace oneflow { namespace user_op { template class VarKernel final : public user_op::OpKernel { public: VarKernel() = default; ~VarKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); const bool unbiased = ctx->Attr("unbiased"); const T* in_ptr = input->dptr(); T* out_ptr = output->mut_dptr(); const std::vector axis = ctx->Attr>("dim"); const int64_t input_dim_element = input->shape_view().elem_cnt(); int64_t axis_dim_element = 1; for (int64_t i = 0; i < axis.size(); ++i) { axis_dim_element *= input->shape_view().At(axis[i]); } // when computing the variance with all the elements, the implementation of cuda kernel may use // tmp buffer for computation. ComputeType* tmp_buffer_ptr = (input_dim_element > 0 && (axis.size() == input->shape_view().NumAxes() || input_dim_element == axis_dim_element) && DeviceType::kCUDA == device_type) ? ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr() : nullptr; VarParamHelper param_helper(input->shape_view(), axis, unbiased); VarFunctor()(ctx->stream(), in_ptr, out_ptr, tmp_buffer_ptr, param_helper.param); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_VAR_CPU_KERNEL(dtype, compute_type) \ REGISTER_USER_KERNEL("var") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobAttr("dtype") == GetDataType::value)); REGISTER_VAR_CPU_KERNEL(float, double) REGISTER_VAR_CPU_KERNEL(double, double) REGISTER_VAR_CPU_KERNEL(float16, double) REGISTER_VAR_CPU_KERNEL(bfloat16, double) #undef REGISTER_VAR_CPU_KERNEL #ifdef WITH_CUDA template size_t InferTmpBufferSize(user_op::InferContext* ctx) { const TensorDesc& input = ctx->InputTensorDesc("input", 0); const Shape& input_shape = input.shape(); const std::vector axis = ctx->Attr>("dim"); const int64_t input_dim_element = input.shape().elem_cnt(); int64_t axis_dim_element = 1; for (int64_t i = 0; i < axis.size(); ++i) { axis_dim_element *= input.shape().At(axis[i]); } if (input_dim_element > 0 && (axis.size() == input_shape.NumAxes() || input_dim_element == axis_dim_element)) { return GetCudaAlignedSize( std::min(static_cast(std::ceil(std::sqrt(input.shape().elem_cnt()))), kCudaMaxBlocksNum) * sizeof(ComputeType) * 3); } return 0; } #define REGISTER_VAR_CUDA_KERNEL(dtype, compute_type) \ REGISTER_USER_KERNEL("var") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobAttr("dtype") == GetDataType::value)) \ .SetInferTmpSizeFn(InferTmpBufferSize); REGISTER_VAR_CUDA_KERNEL(float, double) REGISTER_VAR_CUDA_KERNEL(double, double) REGISTER_VAR_CUDA_KERNEL(half, double) #if CUDA_VERSION >= 11000 REGISTER_VAR_CUDA_KERNEL(nv_bfloat16, double) #endif // CUDA_VERSION >= 11000 #undef REGISTER_VAR_CUDA_KERNEL #endif // WITH_CUDA } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/variance_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/variance_kernel_util.h" namespace oneflow { namespace user_op { template<typename T, typename ComputeType> struct VarFunctor<DeviceType::kCPU, T, ComputeType> final { void operator()(ep::Stream* stream, const T* in_ptr, T* out_ptr, ComputeType* tmp_buffer_ptr, const VarParam var_param) { // if var_param.parallel_num is 0, do nothing, return 0-size tensor if (IsNanOut(var_param)) { for (size_t i = 0; i < var_param.parallel_num; i++) { out_ptr[i] = std::numeric_limits<T>::quiet_NaN(); } } else { for (size_t i = 0; i < var_param.parallel_num; i++) { const size_t input_offset = LinearIndex2Offset( i, var_param.dim_size_in_caxis, var_param.stride_in_caxis, var_param.caxis_size); ComputeVarUsingWelford(&in_ptr[input_offset], &out_ptr[i], var_param); } } } }; template struct VarFunctor<DeviceType::kCPU, float, double>; template struct VarFunctor<DeviceType::kCPU, double, double>; template struct VarFunctor<DeviceType::kCPU, float16, double>; template struct VarFunctor<DeviceType::kCPU, bfloat16, double>; } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/variance_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/user/kernels/variance_kernel_util.h" #include "oneflow/core/cuda/layer_norm.cuh" namespace oneflow { namespace user_op { namespace { template __inline__ __device__ T Nan(); template<> __inline__ __device__ float Nan() { return CUDART_NAN_F; } template<> __inline__ __device__ double Nan() { return CUDART_NAN; } template<> __inline__ __device__ half Nan() { return half(CUDART_NAN_F); } #if CUDA_VERSION >= 11000 template<> __inline__ __device__ nv_bfloat16 Nan() { return nv_bfloat16(CUDART_NAN_F); } #endif } // namespace template __global__ void ComputeVarUsingWelfordWrapper(const T* in_ptr, T* out_ptr, const VarParam var_param, bool is_nan) { if (is_nan) { CUDA_1D_KERNEL_LOOP(i, var_param.parallel_num) { out_ptr[i] = Nan(); } } else { CUDA_1D_KERNEL_LOOP(i, var_param.parallel_num) { const size_t input_offset = LinearIndex2Offset( i, var_param.dim_size_in_caxis, var_param.stride_in_caxis, var_param.caxis_size); ComputeVarUsingWelford(&in_ptr[input_offset], &out_ptr[i], var_param); } } } namespace { template inline __device__ void WelfordReduce(const T* in_ptr, ComputeType* mean, ComputeType* m2, ComputeType* count, const size_t total_elem_cnt, const size_t start, const size_t step) { ComputeType old_mean = 0.0; for (size_t i = start; i < total_elem_cnt; i += step) { ++(*count); old_mean = *mean; *mean += (static_cast(in_ptr[i]) - *mean) / *count; *m2 += (static_cast(in_ptr[i]) - *mean) * (static_cast(in_ptr[i]) - old_mean); } } template inline __device__ void WelfordCombine(const T* b_mean, const T* b_m2, const T* b_count, T* mean, T* m2, T* count, const size_t total_elem_cnt, const size_t start, const size_t step) { for (size_t i = start; i < total_elem_cnt; i += step) { cuda::layer_norm::WelfordCombine(b_mean[i], b_m2[i], b_count[i], mean, m2, count); } } __device__ int32_t done_block_count = 0; } // namespace template __global__ void ComputeVarScalarOut(const T* in_ptr, T* out_ptr, ComputeType* tmp_buffer_ptr, const VarParam var_param, bool is_nan) { if (is_nan) { if (blockIdx.x == 0 && threadIdx.x == 0) { *out_ptr = Nan(); } return; } const size_t elems_per_block = var_param.elem_cnt / gridDim.x; const size_t elems_per_thread = elems_per_block / blockDim.x; // tail element number in block size_t tail_elems = elems_per_block % blockDim.x; ComputeType thread_mean = 0.0; ComputeType thread_m2 = 0.0; ComputeType thread_count = 0.0; // every thread deal it's elems if (elems_per_thread > 0) { const size_t block_offset = blockIdx.x * elems_per_block; WelfordReduce(&in_ptr[block_offset], &thread_mean, &thread_m2, &thread_count, elems_per_block - tail_elems, threadIdx.x, blockDim.x); } // thread 0 of last block handles tail element between blocks if (blockIdx.x == gridDim.x - 1 && threadIdx.x == 0) { tail_elems += var_param.elem_cnt % gridDim.x; } // thread 0 deal tail elems if (tail_elems != 0 && threadIdx.x == 0) { const size_t tail_offset = blockIdx.x * elems_per_block + blockDim.x * elems_per_thread; WelfordReduce(&in_ptr[tail_offset], &thread_mean, &thread_m2, &thread_count, tail_elems, /*tail start=*/0, /*step=*/1); } ComputeType block_mean = 0; ComputeType block_m2 = 0; ComputeType block_count = 0; cuda::layer_norm::WelfordBlockAllReduce(thread_mean, thread_m2, thread_count, &block_mean, &block_m2, &block_count); if (gridDim.x == 1) { if (threadIdx.x == 0) { *out_ptr = cuda::layer_norm::Div(block_m2, (var_param.unbiased ? 
block_count - 1 : block_count)); } return; } ComputeType* tmp_mean_ptr = tmp_buffer_ptr; ComputeType* tmp_m2_ptr = &tmp_mean_ptr[gridDim.x]; ComputeType* tmp_count_ptr = &tmp_m2_ptr[gridDim.x]; if (threadIdx.x == 0) { tmp_mean_ptr[blockIdx.x] = block_mean; tmp_m2_ptr[blockIdx.x] = block_m2; tmp_count_ptr[blockIdx.x] = block_count; } __shared__ bool is_last_block; if (threadIdx.x == 0) { is_last_block = atomicAdd(&done_block_count, 1) == gridDim.x - 1; } __syncthreads(); if (is_last_block) { ComputeType last_block_thread_mean = 0; ComputeType last_block_thread_m2 = 0; ComputeType last_block_thread_count = 0; const size_t welforddatas_per_thread = gridDim.x / blockDim.x; const size_t tail_welforddatas = gridDim.x % blockDim.x; if (welforddatas_per_thread > 0) { WelfordCombine(tmp_mean_ptr, tmp_m2_ptr, tmp_count_ptr, &last_block_thread_mean, &last_block_thread_m2, &last_block_thread_count, gridDim.x - tail_welforddatas, threadIdx.x, blockDim.x); } // thread 0 deal tail welford data if (tail_welforddatas != 0 && threadIdx.x == 0) { const size_t last_block_tail_offset = blockDim.x * welforddatas_per_thread; WelfordCombine(&tmp_mean_ptr[last_block_tail_offset], &tmp_m2_ptr[last_block_tail_offset], &tmp_count_ptr[last_block_tail_offset], &last_block_thread_mean, &last_block_thread_m2, &last_block_thread_count, tail_welforddatas, /*tail start=*/0, /*step=*/1); } ComputeType final_mean = 0; ComputeType final_m2 = 0; ComputeType final_count = 0; cuda::layer_norm::WelfordBlockAllReduce( last_block_thread_mean, last_block_thread_m2, last_block_thread_count, &final_mean, &final_m2, &final_count); if (threadIdx.x == 0) { *out_ptr = cuda::layer_norm::Div(final_m2, (var_param.unbiased ? final_count - 1 : final_count)); done_block_count = 0; } } } template struct VarFunctor final { void operator()(ep::Stream* stream, const T* in_ptr, T* out_ptr, ComputeType* tmp_buffer_ptr, const VarParam var_param) { int grid_dim = 0; int block_dim = 0; SetGridDimAndBlockDim(var_param.elem_cnt, &grid_dim, &block_dim); if (var_param.parallel_num == 1) { ComputeVarScalarOut <<As()->cuda_stream()>>>( in_ptr, out_ptr, tmp_buffer_ptr, var_param, IsNanOut(var_param)); } else { // if var_param.parallel_num is 0, do nothing, return 0-size tensor if (var_param.parallel_num == 0) { return; } RUN_CUDA_KERNEL((ComputeVarUsingWelfordWrapper), stream, var_param.parallel_num, in_ptr, out_ptr, var_param, IsNanOut(var_param)); } } }; template struct VarFunctor; template struct VarFunctor; template struct VarFunctor; #if CUDA_VERSION >= 11000 template struct VarFunctor; #endif } // namespace user_op } // namespace oneflow ================================================ FILE: oneflow/user/kernels/variance_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
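ComputeVarScalarOut first reduces per-thread Welford statistics within each block, then lets the last block to finish merge every block's (mean, m2, count) triple from the temporary buffer. The merge is the standard parallel Welford combination (Chan et al.); a host-side sketch of that step, using a hypothetical WelfordState struct rather than the cuda::layer_norm helpers:

struct WelfordState { double mean = 0.0, m2 = 0.0, count = 0.0; };

// Combine the statistics of two disjoint data chunks into one state.
WelfordState WelfordMerge(const WelfordState& a, const WelfordState& b) {
  WelfordState out;
  out.count = a.count + b.count;
  if (out.count == 0.0) { return out; }
  const double delta = b.mean - a.mean;
  out.mean = a.mean + delta * (b.count / out.count);
  // m2 grows by the within-chunk terms plus delta^2 * n_a * n_b / (n_a + n_b).
  out.m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / out.count;
  return out;
}

The temporary buffer therefore needs room for three ComputeType values per block (mean, m2, count), which matches how tmp_mean_ptr / tmp_m2_ptr / tmp_count_ptr are laid out above.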
*/ #ifndef ONEFLOW_USER_KERNELS_VARIANCE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_VARIANCE_KERNEL_UTIL_H_ #include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { namespace user_op { OF_DEVICE_FUNC size_t LinearIndex2Offset(const size_t linear_index, const int32_t* dim_size_in_axis_ptr, const int32_t* stride_vec_ptr, const int32_t size) { // low dim at begin size_t offset = 0; size_t num_dim = 0; for (int j = 0; j < size; j++) { num_dim = (j == 0 ? linear_index : (num_dim / dim_size_in_axis_ptr[j - 1])); offset += num_dim % dim_size_in_axis_ptr[j] * stride_vec_ptr[j]; } return offset; } namespace { constexpr size_t MaxDims = 8; } // namespace struct VarParam { VarParam() : unbiased(true), parallel_num(1), elem_cnt(1), axis_size(1), caxis_size(1) {} bool unbiased; size_t parallel_num; size_t elem_cnt; int32_t axis_size; int32_t caxis_size; int32_t stride_in_axis[MaxDims]; int32_t dim_size_in_axis[MaxDims]; int32_t stride_in_caxis[MaxDims]; int32_t dim_size_in_caxis[MaxDims]; }; class VarParamHelper final { public: VarParamHelper() = delete; explicit VarParamHelper(const ShapeView& input_shape, const std::vector axis, const bool unbiased) : axis_(axis), input_shape_(input_shape) { param.unbiased = unbiased; ComputeStrideVec(axis_, param.stride_in_axis); caxis_ = GetCAxis(); ComputeStrideVec(caxis_, param.stride_in_caxis); GetDimSizeInAxis(axis_, param.dim_size_in_axis); GetDimSizeInAxis(caxis_, param.dim_size_in_caxis); ComputeElemCntAndParallelNum(); param.axis_size = axis_.size(); param.caxis_size = caxis_.size(); } VarParam param; private: void ComputeElemCntAndParallelNum() { for (int i = 0; i < axis_.size(); i++) { param.elem_cnt *= input_shape_.At(axis_[i]); } for (int i = 0; i < caxis_.size(); i++) { param.parallel_num *= input_shape_.At(caxis_[i]); } } void ComputeStrideVec(const std::vector axis, int32_t* stride_vec) { // low dim at begin const int axis_size = axis.size(); for (int i = 0; i < axis_size; i++) { int stride = 1; if (axis.at(i) + 1 == input_shape_.NumAxes()) { stride_vec[axis_size - 1 - i] = 1; } else { for (int j = axis.at(i) + 1; j < input_shape_.NumAxes(); j++) { stride *= input_shape_.At(j); } stride_vec[axis_size - 1 - i] = stride; } } } std::vector GetCAxis() { std::vector caxis; caxis.resize(input_shape_.NumAxes()); std::iota(caxis.begin(), caxis.end(), 0); for (int i = 0; i < axis_.size(); i++) { caxis.erase(caxis.begin() + axis_.at(i) - i); } return caxis; } void GetDimSizeInAxis(const std::vector axis, int32_t* dim_size_in_axis) { // low dim at begin const int axis_size = axis.size(); for (int i = 0; i < axis_size; i++) { dim_size_in_axis[axis_size - 1 - i] = input_shape_.At(axis.at(i)); } } const std::vector axis_; const ShapeView input_shape_; std::vector caxis_; }; template OF_DEVICE_FUNC void ComputeVarUsingWelford(const T* in_ptr, T* out_ptr, const VarParam& var_param) { size_t count = 0; // torch use double even for float data, so here float will result in accuracy error. 
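LinearIndex2Offset decodes a linear index over a chosen set of axes (fastest-varying axis stored first) into a flat memory offset by peeling the index off in mixed radix and weighting each digit with that axis's stride. A host-side sketch of the same arithmetic, with hypothetical names:

#include <cstdint>
#include <vector>

// dim_size and stride list the selected axes with the low (fastest) dim first.
std::int64_t LinearToOffset(std::int64_t linear_index,
                            const std::vector<std::int32_t>& dim_size,
                            const std::vector<std::int32_t>& stride) {
  std::int64_t offset = 0;
  std::int64_t rest = linear_index;
  for (std::size_t j = 0; j < dim_size.size(); ++j) {
    offset += (rest % dim_size[j]) * stride[j];
    rest /= dim_size[j];
  }
  return offset;
}

// For a row-major 2x3 tensor with its axes listed low-dim-first as sizes {3, 2}
// and strides {1, 3}, linear index 4 maps to offset 1*1 + 1*3 = 4.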
ComputeType mean = 0.0; ComputeType old_mean = 0.0; ComputeType m2 = 0.0; for (size_t i = 0; i < var_param.elem_cnt; i++) { const size_t offset = LinearIndex2Offset(i, var_param.dim_size_in_axis, var_param.stride_in_axis, var_param.axis_size); count++; old_mean = mean; mean += (static_cast(in_ptr[offset]) - mean) / count; m2 += (static_cast(in_ptr[offset]) - mean) * (static_cast(in_ptr[offset]) - old_mean); } *out_ptr = m2 / (var_param.unbiased ? count - 1 : count); } namespace { OF_DEVICE_FUNC bool IsNanOut(const VarParam var_param) { return (var_param.elem_cnt == 0) || (var_param.elem_cnt == 1 && var_param.unbiased == true); } #ifdef WITH_CUDA void SetGridDimAndBlockDim(const size_t total_elem_cnt, int* grid_dim, int* block_dim) { // when total_elem_cnt > 2 * kCudaThreadsNumPerBlock, use two cuda kernel if (total_elem_cnt > (kCudaThreadsNumPerBlock << 1)) { *grid_dim = std::min(static_cast(std::ceil(std::sqrt(total_elem_cnt))), kCudaMaxBlocksNum); *block_dim = kCudaThreadsNumPerBlock; } else { *grid_dim = 1; int32_t aligned_block_dim = (total_elem_cnt >= kCudaThreadsNumPerBlock) ? kCudaThreadsNumPerBlock // avoid get block_dim = 0 when total_elem_cnt is 0 : std::max((total_elem_cnt + kCudaWarpSize - 1) / kCudaWarpSize, 1) * kCudaWarpSize; *block_dim = std::min(aligned_block_dim, kCudaThreadsNumPerBlock); } } #endif // WITH_CUDA } // namespace template struct VarFunctor final { void operator()(ep::Stream* stream, const T* in_ptr, T* out_ptr, ComputeType* tmp_buffer_ptr, const VarParam var_param); }; } // namespace user_op } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_VARIANCE_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/vector_matrix_product_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/framework/config_def.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/matmul.h" namespace oneflow { namespace { ep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N; } template std::unique_ptr NewMemcpyPrimitive(Context* ctx) { return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::MemcpyKind::kDtoD); } std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, bool transpose_a, bool transpose_b) { const auto trans_a = GetBlasTransposeType(transpose_a); const auto trans_b = GetBlasTransposeType(transpose_b); return ep::primitive::NewPrimitive(device_type, data_type, trans_a, trans_b); } template std::unique_ptr NewVectorMatrixProductPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, false, false); } template std::unique_ptr NewVectorMatrixProductGradAPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, false, true); } template std::unique_ptr NewVectorMatrixProductGradBPrimitive(Context* ctx) { const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->data_type(); return NewMatmulPrimitive(ctx->device_type(), data_type, true, false); } auto VectorMatrixProductPrimitiveExists() { return hob::make_custom("NewVectorMatrixProductPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewVectorMatrixProductPrimitive(&ctx).operator bool(); }); } auto VectorMatrixProductGradAPrimitiveExists() { return hob::make_custom("NewVectorMatrixProductGradAPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewVectorMatrixProductGradAPrimitive(&ctx).operator bool(); }); } auto VectorMatrixProductGradBPrimitiveExists() { return hob::make_custom("NewVectorMatrixProductGradBPrimitiveExists", [](const user_op::KernelRegContext& ctx) { return NewVectorMatrixProductGradBPrimitive(&ctx).operator bool(); }); } class VectorMatrixProductKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: VectorMatrixProductKernel() = default; ~VectorMatrixProductKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { /* A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n) */ const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); CHECK_EQ(a->shape_view().NumAxes(), 1) << "A Numdims should be equal to 1. "; const DataType data_type = a->data_type(); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); CHECK_EQ(b->shape_view().NumAxes(), 2) << "B Numdims should be equal to 2. "; CHECK_EQ(b->data_type(), data_type) << "Matrix A Datatype should be equal to Vector B"; user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(out->shape_view().NumAxes(), 1) << "Out Numdims should be equal to 1. "; CHECK_EQ(out->data_type(), data_type) << "Out Datatype should be equal to input's. 
"; size_t m = 1; size_t k = a->shape_view().At(0); size_t n = b->shape_view().At(1); const double alpha = 1.0; double beta = 0.0; auto matmul = NewVectorMatrixProductPrimitive(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr()); } }; REGISTER_USER_KERNEL("vector_matrix_product") .SetCreateFn() .SetIsMatchedHob(VectorMatrixProductPrimitiveExists()); class VectorMatrixProductGradAKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: VectorMatrixProductGradAKernel() = default; ~VectorMatrixProductGradAKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { /* A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n) GradA = dy (n) matmul B_transpose(n, k) -> (1, n) matmul (n, k) */ const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); size_t m = 1; size_t k = dy->shape_view().At(0); size_t n = b->shape_view().At(0); const double alpha = 1.0; double beta = 0.0; auto matmul = NewVectorMatrixProductGradAPrimitive(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, dy->dptr(), b->dptr(), beta, dx->mut_dptr()); } }; REGISTER_USER_KERNEL("vector_matrix_product_grad_a") .SetCreateFn() .SetIsMatchedHob(VectorMatrixProductGradAPrimitiveExists()); class VectorMatrixProductGradBKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: VectorMatrixProductGradBKernel() = default; ~VectorMatrixProductGradBKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { /* A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n) GradB = a_transpose (k, 1) matmul dy (1, n) */ const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); size_t m = a->shape_view().At(0); size_t k = 1; size_t n = dy->shape_view().At(0); const double alpha = 1.0; double beta = 0.0; auto matmul = NewVectorMatrixProductGradBPrimitive(ctx); CHECK(matmul); matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), dy->dptr(), beta, dx->mut_dptr()); } }; REGISTER_USER_KERNEL("vector_matrix_product_grad_b") .SetCreateFn() .SetIsMatchedHob(VectorMatrixProductGradBPrimitiveExists()); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/where_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/where.h" namespace oneflow { namespace { template auto NewPrimitive(Context* ctx) -> std::unique_ptr { const user_op::TensorDesc* cond_desc = ctx->TensorDesc4ArgNameAndIndex("condition", 0); const user_op::TensorDesc* out_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0); return ep::primitive::NewPrimitive( ctx->device_type(), cond_desc->data_type(), out_desc->data_type(), out_desc->shape().NumAxes()); } auto PrimitiveExists() { return hob::make_custom("PrimitiveExists", [](const user_op::KernelRegContext& ctx) -> bool { return NewPrimitive(&ctx).operator bool(); }); } class WhereKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: WhereKernel() = default; ~WhereKernel() = default; private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* cond = ctx->Tensor4ArgNameAndIndex("condition", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); if (out->shape_view().elem_cnt() == 0) { return; } auto primitive = NewPrimitive(ctx); CHECK(primitive); primitive->Launch(ctx->stream(), cond->shape_view().size(), cond->shape_view().ptr(), cond->dptr(), x->shape_view().size(), x->shape_view().ptr(), x->dptr(), y->shape_view().size(), y->shape_view().ptr(), y->dptr(), out->mut_dptr()); } }; REGISTER_USER_KERNEL("where").SetCreateFn().SetIsMatchedHob(PrimitiveExists() == true); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/kernels/where_kernel_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/where_kernel_util.h" namespace oneflow { template struct WhereKernelUtil { static void Where(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs, const T* rhs, T* out) { FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast(cond[i]) ? lhs[i] : rhs[i]; } } static void WhereXScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T x_scalar, const T* rhs, T* out) { FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast(cond[i]) ? x_scalar : rhs[i]; } } static void WhereYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs, const T y_scalar, T* out) { FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast(cond[i]) ? lhs[i] : y_scalar; } } static void WhereXYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T x_scalar, const T y_scalar, T* out) { FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast(cond[i]) ? 
x_scalar : y_scalar; } } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_WHERE_FUNCTOR, (DeviceType::kCPU), ARITHMETIC_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/where_kernel_util.cu ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/where_kernel_util.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { template struct WhereFunctor { OF_DEVICE_FUNC T operator()(CondT cond, T lhs, T rhs) const { return static_cast(cond) ? lhs : rhs; } }; template struct WhereScalarXFunctor { OF_DEVICE_FUNC explicit WhereScalarXFunctor(T scalar) : x_scalar(scalar) {} OF_DEVICE_FUNC T operator()(CondT cond, T rhs) const { return static_cast(cond) ? x_scalar : rhs; } const T x_scalar; }; template struct WhereScalarYFunctor { OF_DEVICE_FUNC explicit WhereScalarYFunctor(T scalar) : y_scalar(scalar) {} OF_DEVICE_FUNC T operator()(CondT cond, T lhs) const { return static_cast(cond) ? lhs : y_scalar; } const T y_scalar; }; template struct WhereScalarXYFunctor { OF_DEVICE_FUNC explicit WhereScalarXYFunctor(T x_scalar, T y_scalar) : x_scalar(x_scalar), y_scalar(y_scalar) {} OF_DEVICE_FUNC T operator()(CondT cond) const { return static_cast(cond) ? x_scalar : y_scalar; } const T x_scalar; const T y_scalar; }; } // namespace template struct WhereKernelUtil { static void Where(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs, const T* rhs, T* out) { cuda::elementwise::Ternary(WhereFunctor(), elem_cnt, out, cond, lhs, rhs, stream->As()->cuda_stream()); } static void WhereXScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T x_scalar, const T* rhs, T* out) { cuda::elementwise::Binary(WhereScalarXFunctor(x_scalar), elem_cnt, out, cond, rhs, stream->As()->cuda_stream()); } static void WhereYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs, const T y_scalar, T* out) { cuda::elementwise::Binary(WhereScalarYFunctor(y_scalar), elem_cnt, out, cond, lhs, stream->As()->cuda_stream()); } static void WhereXYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T x_scalar, const T y_scalar, T* out) { cuda::elementwise::Unary(WhereScalarXYFunctor(x_scalar, y_scalar), elem_cnt, out, cond, stream->As()->cuda_stream()); } }; OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_WHERE_FUNCTOR, (DeviceType::kCUDA), ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ) } // namespace oneflow ================================================ FILE: oneflow/user/kernels/where_kernel_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
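The CUDA where utilities wrap each variant in a small functor and hand it to the generic elementwise launcher, so the scalar variants become binary or unary kernels with the scalar captured in the functor. A host-side sketch of that functor-driven pattern (illustrative names, not the cuda::elementwise API):

#include <cstdint>

// Elementwise select driven by a functor, mirroring how WhereFunctor and the
// scalar variants are plugged into Ternary/Binary/Unary launches.
template<typename CondT, typename T, typename Functor>
void ElementwiseSelect(const Functor& fn, std::int64_t n,
                       const CondT* cond, const T* lhs, const T* rhs, T* out) {
  for (std::int64_t i = 0; i < n; ++i) { out[i] = fn(cond[i], lhs[i], rhs[i]); }
}

struct SelectFunctor {
  template<typename CondT, typename T>
  T operator()(CondT c, T a, T b) const { return static_cast<bool>(c) ? a : b; }
};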
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_KERNELS_WHERE_KERNEL_UTIL_H_ #define ONEFLOW_USER_KERNELS_WHERE_KERNEL_UTIL_H_ #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/ep/include/stream.h" namespace oneflow { template struct WhereKernelUtil { static void Where(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs, const T* rhs, T* out); static void WhereXScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T x_scalar, const T* rhs, T* out); static void WhereYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs, const T y_scalar, T* out); static void WhereXYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T x_scalar, const T y_scalar, T* out); }; #define INSTANTIATE_WHERE_FUNCTOR(device_type_v, dtype_pair, ctype_pair) \ template struct WhereKernelUtil; } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_WHERE_KERNEL_UTIL_H_ ================================================ FILE: oneflow/user/kernels/zero_like_kernel.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/memset.h" namespace oneflow { namespace { class ZeroLikeKernel final : public user_op::OpKernel { public: ZeroLikeKernel() = default; ~ZeroLikeKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = out->shape_view().elem_cnt(); if (elem_cnt > 0) { std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->stream()->device_type()); CHECK(primitive) << "Can not create Memset primitive for device type " << ctx->stream()->device_type(); primitive->Launch(ctx->stream(), out->mut_dptr(), 0, elem_cnt * GetSizeOfDataType(out->data_type())); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; REGISTER_USER_KERNEL("zero_like").SetCreateFn(); } // namespace } // namespace oneflow ================================================ FILE: oneflow/user/ops/acc_ctrl_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
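ZeroLikeKernel does no arithmetic at all: it sizes the output in bytes (element count times the dtype's byte width) and issues one Memset. A trivial host-side sketch of the same sizing rule, not the primitive API:

#include <cstddef>
#include <cstring>

void ZeroFill(void* out, std::size_t elem_cnt, std::size_t bytes_per_elem) {
  if (elem_cnt == 0) { return; }  // empty outputs skip the launch entirely
  std::memset(out, 0, elem_cnt * bytes_per_elem);
}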
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe AccCtrlTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, Shape({1})); return Maybe::Ok(); } /*static*/ Maybe AccCtrlTickOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe AccCtrlTickOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe AccCtrlTickOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), // NOLINT(maybe-need-error-msg) parallel_hierarchy.NumAxes()); // NOLINT(maybe-need-error-msg) NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); in_distribution->clear_sbp_parallel(); out_distribution->clear_sbp_parallel(); // in use hint in_distribution->CopyFrom(in_dis_hint); for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { // out dim1 = broadcast out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); } return Maybe::Ok(); } /* static */ Maybe AccCtrlTickOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe AccCtrlTickOp::InferOutputBlobTimeShape( user_op::InferOutputBlobTimeShapeFnContext* ctx) { const int32_t max_acc_num = ctx->user_op_conf().attr("max_acc_num"); const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex("in", 0); DimVector time_shape_dim_vec = in_time_shape.dim_vec(); // NOLINT(maybe-need-error-msg) CHECK_OR_RETURN(!time_shape_dim_vec.empty()); // NOLINT(maybe-need-error-msg) if (time_shape_dim_vec.back() == max_acc_num) { time_shape_dim_vec.pop_back(); } else if (time_shape_dim_vec.back() % max_acc_num == 0) { time_shape_dim_vec.back() /= max_acc_num; } else { const int64_t elem_cnt = in_time_shape.elem_cnt(); CHECK_EQ_OR_RETURN(elem_cnt % max_acc_num, 0); time_shape_dim_vec.resize(1); time_shape_dim_vec.back() = elem_cnt / max_acc_num; } *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/acc_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
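AccCtrlTickOp's time-shape inference divides the input time shape by max_acc_num: if the innermost time dim equals max_acc_num it is dropped, if it is divisible it is divided, otherwise the whole time shape is flattened and divided. A sketch of that rule with a hypothetical helper:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

std::vector<std::int64_t> AccTimeShape(std::vector<std::int64_t> dims, std::int32_t max_acc_num) {
  if (dims.back() == max_acc_num) {
    dims.pop_back();                          // exactly one accumulation per output step
  } else if (dims.back() % max_acc_num == 0) {
    dims.back() /= max_acc_num;               // shrink the innermost time dim
  } else {
    const std::int64_t elem_cnt = std::accumulate(dims.begin(), dims.end(),
                                                  std::int64_t{1}, std::multiplies<>());
    dims.assign(1, elem_cnt / max_acc_num);   // fall back to a flat time shape
  }
  return dims;
}

AccOp below applies the same rule to its own output time shape.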
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe AccOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe AccOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe AccOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return AccOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe AccOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe AccOp::InferOutputBlobTimeShape( user_op::InferOutputBlobTimeShapeFnContext* ctx) { const int32_t max_acc_num = ctx->user_op_conf().attr("max_acc_num"); const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex("in", 0); DimVector time_shape_dim_vec = in_time_shape.dim_vec(); CHECK_OR_RETURN(!time_shape_dim_vec.empty()); if (time_shape_dim_vec.back() == max_acc_num) { time_shape_dim_vec.pop_back(); } else if (time_shape_dim_vec.back() % max_acc_num == 0) { time_shape_dim_vec.back() /= max_acc_num; } else { const int64_t elem_cnt = in_time_shape.elem_cnt(); time_shape_dim_vec.resize(1); time_shape_dim_vec.back() = elem_cnt / max_acc_num; } *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/adaptive_max_pool_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferFWTensorDesc(user_op::InferContext* ctx) { std::vector output_size = ctx->Attr>("output_size"); const std::string& data_format = ctx->Attr("data_format"); const Shape& x_shape = ctx->InputShape("x", 0); DimVector out_shape(x_shape.NumAxes()); out_shape[0] = x_shape.dim_vec()[0]; out_shape[1] = x_shape.dim_vec()[1]; if (data_format == "channels_first") { out_shape[1] = x_shape.dim_vec()[1]; for (int i = 2; i < out_shape.size(); ++i) { out_shape[i] = output_size.size() > i - 2 ? output_size[i - 2] : output_size[0]; } } else { out_shape[3] = x_shape.dim_vec()[3]; for (int i = 1; i < out_shape.size() - 1; ++i) { out_shape[i] = output_size.size() > i - 1 ? 
output_size[i - 1] : output_size[0]; } } ctx->SetOutputShape("y", 0, Shape(out_shape)); ctx->SetOutputShape("index", 0, Shape(out_shape)); return Maybe::Ok(); } Maybe InferBWTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); ctx->SetOutputIsDynamic("dx", 0, ctx->InputIsDynamic("x", 0)); return Maybe::Ok(); } Maybe FwGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); // only for nchw FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("index", 0), i) .Build(); } return Maybe::Ok(); } Maybe BwGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Split(user_op::OpArg("index", 0), i) .Build(); } return Maybe::Ok(); } Maybe InferFWDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("index", 0, DataType::kInt64); return Maybe::Ok(); } Maybe InferBWDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace #define DEF_ADAPTIVE_MAX_POOL_OP(op_class_name_prefix) \ /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferFWTensorDesc(ctx); \ } \ \ /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ return FwGetSbpFn(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ return InferFWDataType(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##GradOp::InferLogicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferBWTensorDesc(ctx); \ } \ \ /*static*/ \ Maybe op_class_name_prefix##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ \ Maybe op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return BwGetSbpFn(ctx); \ } \ \ /* static */ \ Maybe op_class_name_prefix##GradOp::InferDataType(user_op::InferContext* ctx) { \ return InferBWDataType(ctx); \ } DEF_ADAPTIVE_MAX_POOL_OP(AdaptiveMaxPool1D); DEF_ADAPTIVE_MAX_POOL_OP(AdaptiveMaxPool2D); DEF_ADAPTIVE_MAX_POOL_OP(AdaptiveMaxPool3D); #undef DEF_ADAPTIVE_MAX_POOL_OP } // namespace oneflow ================================================ FILE: oneflow/user/ops/adaptive_pool_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
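Both the adaptive max pool above and the adaptive average pool below infer the output shape the same way: keep the batch and channel dims and overwrite each spatial dim with the requested output_size, broadcasting a single value to every spatial dim. A sketch for the channels_first branch, as a hypothetical helper:

#include <cstdint>
#include <vector>

std::vector<std::int64_t> AdaptivePoolOutShape(const std::vector<std::int64_t>& x_shape,      // e.g. {N, C, H, W}
                                               const std::vector<std::int64_t>& output_size) {
  std::vector<std::int64_t> out = x_shape;
  for (std::size_t i = 2; i < out.size(); ++i) {
    out[i] = (output_size.size() > i - 2) ? output_size[i - 2] : output_size[0];
  }
  return out;
}

// AdaptivePoolOutShape({8, 3, 32, 32}, {7}) -> {8, 3, 7, 7}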
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferFWTensorDesc(user_op::InferContext* ctx) { std::vector output_size = ctx->Attr>("output_size"); const std::string& data_format = ctx->Attr("data_format"); const Shape& x_shape = ctx->InputShape("x", 0); DimVector out_shape(x_shape.NumAxes()); out_shape[0] = x_shape.dim_vec()[0]; if (data_format == "channels_first") { out_shape[1] = x_shape.dim_vec()[1]; for (int i = 2; i < out_shape.size(); ++i) { out_shape[i] = output_size.size() > i - 2 ? output_size[i - 2] : output_size[0]; } } else { out_shape[3] = x_shape.dim_vec()[3]; for (int i = 1; i < out_shape.size() - 1; ++i) { out_shape[i] = output_size.size() > i - 1 ? output_size[i - 1] : output_size[0]; } } ctx->SetOutputShape("y", 0, Shape(out_shape)); return Maybe::Ok(); } Maybe InferBWTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); ctx->SetOutputIsDynamic("dx", 0, ctx->InputIsDynamic("x", 0)); return Maybe::Ok(); } Maybe FwGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); // only for nchw FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } Maybe BwGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } Maybe InferFWDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe InferBWDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace #define DEF_ADAPTIVE_AVG_POOL_OP(op_class_name_prefix) \ /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferFWTensorDesc(ctx); \ } \ \ /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ return FwGetSbpFn(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ return InferFWDataType(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##GradOp::InferLogicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferBWTensorDesc(ctx); \ } \ \ /*static*/ Maybe op_class_name_prefix##GradOp::InferPhysicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return BwGetSbpFn(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##GradOp::InferDataType( \ user_op::InferContext* ctx) { \ return InferBWDataType(ctx); \ } DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool1D) DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool2D) DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool3D) #undef DEF_ADAPTIVE_AVG_POOL_OP } // namespace oneflow ================================================ FILE: oneflow/user/ops/add_n_op.cpp ================================================ /* Copyright 2020 The 
OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe AddNOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& in_0 = ctx->InputTensorDesc("in", 0); auto* out = ctx->MutOutputTensorDesc("out", 0); CHECK_NOTNULL_OR_RETURN(out); // NOLINT(maybe-need-error-msg) for (const auto& pair : ctx->inputs()) { const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); if (in_0.shape().NumAxes() > 0 && cur_in.shape().NumAxes() > 0) { CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape()) << Error::RuntimeError() << "inconsistent tensor size, expected all tensor to have the same shapes, " << "but got " << in_0.shape().ToString() << " and " << cur_in.shape().ToString(); } } out->set_shape(in_0.shape()); out->set_is_dynamic(in_0.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe AddNOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe AddNOp::GetSbp(user_op::SbpContext* ctx) { int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); for (int64_t i = 0; i < num_axes; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg("out", 0)).Build(); return Maybe::Ok(); } /* static */ Maybe AddNOp::InferDataType(user_op::InferContext* ctx) { const auto& in_0 = ctx->InputTensorDesc("in", 0); auto* out = ctx->MutOutputTensorDesc("out", 0); CHECK_NOTNULL_OR_RETURN(out); // NOLINT(maybe-need-error-msg) for (const auto& pair : ctx->inputs()) { const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); CHECK_EQ_OR_RETURN(in_0.data_type(), cur_in.data_type()) << Error::RuntimeError() << ctx->op_name() << " expected all tenser to have same type, but found " << DataType_Name(in_0.data_type()) << " and " << DataType_Name(cur_in.data_type()); } out->set_data_type(in_0.data_type()); return Maybe::Ok(); } /*static*/ Maybe AddNOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("in") >= 2) << Error::RuntimeError() << "The number of input tensors should be greater than or equal to 2"; return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/affine_grid_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
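AddNOp requires at least two inputs with identical shapes and data types and produces their elementwise sum; its SBP signatures simply split or partial-sum all inputs and the output together. A reference sketch of the semantics over float vectors (not the kernel, which runs per device):

#include <stdexcept>
#include <vector>

std::vector<float> AddN(const std::vector<std::vector<float>>& inputs) {
  if (inputs.size() < 2) { throw std::invalid_argument("add_n expects at least 2 inputs"); }
  std::vector<float> out(inputs[0].size(), 0.0f);
  for (const auto& in : inputs) {
    if (in.size() != out.size()) { throw std::invalid_argument("inconsistent tensor size"); }
    for (std::size_t i = 0; i < out.size(); ++i) { out[i] += in[i]; }
  }
  return out;
}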
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { bool pass_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; const auto& size = conf.attr("size"); if (size.NumAxes() != 4 && size.NumAxes() != 5) { err << "dimension of size can't be:" << size.NumAxes(); pass_checked = false; } for (int i = 0; i < size.NumAxes(); i++) { if (size.At(i) <= 0) { err << "element of size can't be:" << size.At(i); } } if (pass_checked) { return Maybe::Ok(); } else { return oneflow::Error::CheckFailedError() << err.str(); } } } // namespace /* static */ Maybe AffineGridOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); user_op::TensorDesc* grid = ctx->MutOutputTensorDesc("grid", 0); const Shape& size = ctx->Attr("size"); // Only support 2D or 3D affine grid with NCHW layout // For 2D grid: theta = { N, 2, 3 }, // size = { N, C, H, W } // grid = { N, H, W, 2 } // For 3D grid: theta = { N, 3, 4 }, // size = { N, C, D, H, W } // grid = { N, D, H, W, 3 } bool is_2d_grid = true; if (theta.shape().At(1) == 2) { CHECK_EQ_OR_RETURN(theta.shape().At(2), 3) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << "Dimension of size MUST be 4, when 2d affine grid"; CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) << "Theta and size MUST have same batch dimension"; is_2d_grid = true; } else if (theta.shape().At(1) == 3) { CHECK_EQ_OR_RETURN(theta.shape().At(2), 4) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; CHECK_EQ_OR_RETURN(size.NumAxes(), 5) << "Dimension of size MUST be 4, when 3d affine grid"; CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) << "Theta and size MUST have same batch dimension"; is_2d_grid = false; } else { CHECK_OR_RETURN(false) << "Theta MUST be 2D or 3D grid"; } grid->set_is_dynamic(theta.is_dynamic()); Shape grid_shape; if (is_2d_grid) { grid_shape = {size.At(0), size.At(2), size.At(3), 2}; } else { grid_shape = {size.At(0), size.At(2), size.At(3), size.At(4), 3}; } grid->set_shape(grid_shape); return Maybe::Ok(); } /*static*/ Maybe AffineGridOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); user_op::TensorDesc* grid = ctx->MutOutputTensorDesc("grid", 0); const Shape& size = ctx->Attr("size"); // Only support 2D or 3D affine grid with NCHW layout // For 2D grid: theta = { N, 2, 3 }, // size = { N, C, H, W } // grid = { N, H, W, 2 } // For 3D grid: theta = { N, 3, 4 }, // size = { N, C, D, H, W } // grid = { N, D, H, W, 3 } const Shape& theta_shape = theta.shape(); bool is_2d_grid = true; if (theta_shape.At(1) == 2) { CHECK_EQ_OR_RETURN(theta_shape.At(2), 3) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << "Dimension of size MUST be 4, when 2d affine grid"; is_2d_grid = true; } else if (theta_shape.At(1) == 3) { CHECK_EQ_OR_RETURN(theta_shape.At(2), 4) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; CHECK_EQ_OR_RETURN(size.NumAxes(), 5) << "Dimension of size MUST be 4, when 3d affine grid"; is_2d_grid = false; } else { CHECK_OR_RETURN(false) << "Theta MUST be 2D or 3D grid"; } 
int64_t N = size.At(0); const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("theta", 0); Shape logical_shape = theta_shape; logical_shape.Set(0, size.At(0)); const auto& physical_shape = JUST(GetPhysicalShape(logical_shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx())); N = physical_shape->At(0); } CHECK_EQ_OR_RETURN(theta_shape.At(0), N) << "The dimension 0 size of theta shape should be " << N << ", but got " << theta_shape.At(0); grid->set_is_dynamic(theta.is_dynamic()); Shape grid_shape; if (is_2d_grid) { grid_shape = {N, size.At(2), size.At(3), 2}; } else { grid_shape = {N, size.At(2), size.At(3), size.At(4), 3}; } grid->set_shape(grid_shape); return Maybe::Ok(); } /* static */ Maybe AffineGridOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("theta", 0), 0) .Split(user_op::OpArg("grid", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe AffineGridOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_(def, conf); } /* static */ Maybe AffineGridOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("grid", 0, ctx->InputDType("theta", 0)); return Maybe::Ok(); } /* static */ Maybe AffineGridGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dgrid = ctx->InputTensorDesc("dgrid", 0); const Shape& size = ctx->Attr("size"); if (size.NumAxes() == 4) { ctx->MutOutputTensorDesc("dtheta", 0)->set_shape(Shape({dgrid.shape().At(0), 2, 3})); } else if (size.NumAxes() == 5) { ctx->MutOutputTensorDesc("dtheta", 0)->set_shape(Shape({dgrid.shape().At(0), 3, 4})); } else { CHECK_OR_RETURN(false) << "size MUST be 4D or 5D"; } return Maybe::Ok(); } /*static*/ Maybe AffineGridGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe AffineGridGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dgrid", 0), 0) .Split(user_op::OpArg("dtheta", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe AffineGridGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_(def, conf); } /* static */ Maybe AffineGridGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dtheta", 0, ctx->InputDType("dgrid", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/amp_white_identity_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
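AffineGridOp only has to map the theta and size attributes to a grid shape: a 2D theta (N, 2, 3) with a 4D size yields grid (N, H, W, 2), and a 3D theta (N, 3, 4) with a 5D size yields grid (N, D, H, W, 3); the physical variant additionally rescales N according to how the batch dim is split. A sketch of the logical shape rule with a hypothetical helper:

#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<std::int64_t> AffineGridShape(const std::vector<std::int64_t>& size) {
  if (size.size() == 4) { return {size[0], size[2], size[3], 2}; }           // N, H, W, 2
  if (size.size() == 5) { return {size[0], size[2], size[3], size[4], 3}; }  // N, D, H, W, 3
  throw std::invalid_argument("size must be 4D or 5D");
}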
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe AmpWhiteIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_shape(in.shape()); out->set_is_dynamic(in.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe AmpWhiteIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe AmpWhiteIdentityOp::GetSbp(user_op::SbpContext* ctx) { const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); for (int i = 0; i < in.shape().NumAxes(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe AmpWhiteIdentityOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(in.data_type()); return Maybe::Ok(); } /* static */ Maybe AmpBlackIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_shape(in.shape()); out->set_is_dynamic(in.is_dynamic()); return Maybe::Ok(); } /* static */ Maybe AmpBlackIdentityOp::GetSbp(user_op::SbpContext* ctx) { const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); for (int i = 0; i < in.shape().NumAxes(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe AmpBlackIdentityOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(in.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/arange_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe ArangeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { DataType dtype = ctx->Attr("dtype"); int64_t range_elem_cnt = 0; if (IsIntegralDataType(dtype)) { int64_t integer_delta = ctx->Attr("integer_delta"); CHECK_NE_OR_RETURN(integer_delta, static_cast(0)) << "RuntimeError: step must be nonzero. 
"; int64_t integer_start = ctx->Attr("integer_start"); int64_t integer_limit = ctx->Attr("integer_limit"); // CHECK when limit > start, delta > 0; limit < start, delta < 0; CHECK_GE_OR_RETURN((integer_limit - integer_start) / integer_delta, static_cast(0)) << "RuntimeError: upper bound and larger bound inconsistent with step sign"; range_elem_cnt = std::ceil(static_cast(integer_limit - integer_start) / integer_delta); } else { double float_delta = ctx->Attr("float_delta"); CHECK_NE_OR_RETURN(float_delta, static_cast(0.0)) << "RuntimeError: step must be nonzero. "; double float_start = ctx->Attr("float_start"); double float_limit = ctx->Attr("float_limit"); // CHECK when limit > start, delta > 0; limit < start, delta < 0; // CHECK_GE For 0-Dim Tensor CHECK_GE_OR_RETURN((float_limit - float_start) / float_delta, static_cast(0.0)) << "RuntimeError: upper bound and larger bound inconsistent with step sign"; range_elem_cnt = std::ceil(static_cast(float_limit - float_start) / float_delta); } ctx->SetOutputShape("out", 0, Shape({range_elem_cnt})); return Maybe::Ok(); } /*static*/ Maybe ArangeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { DataType dtype = ctx->Attr("dtype"); int64_t range_elem_cnt = 0; if (IsIntegralDataType(dtype)) { int64_t integer_delta = ctx->Attr("integer_delta"); if (integer_delta == static_cast(0)) { return Error::RuntimeError() << " step must be nonzero. "; } int64_t integer_start = ctx->Attr("integer_start"); int64_t integer_limit = ctx->Attr("integer_limit"); // CHECK when limit > start, delta > 0; limit < start, delta < 0; if ((integer_limit - integer_start) / integer_delta < static_cast(0)) { return Error::RuntimeError() << " upper bound and larger bound inconsistent with step sign"; } range_elem_cnt = std::ceil(static_cast(integer_limit - integer_start) / integer_delta); } else { double float_delta = ctx->Attr("float_delta"); if (float_delta == static_cast(0.0)) { return Error::RuntimeError() << " step must be nonzero. 
"; } double float_start = ctx->Attr("float_start"); double float_limit = ctx->Attr("float_limit"); // CHECK when limit > start, delta > 0; limit < start, delta < 0; // CHECK_GE For 0-Dim Tensor if ((float_limit - float_start) / float_delta < static_cast(0.0)) { return Error::RuntimeError() << " upper bound and larger bound inconsistent with step sign"; } range_elem_cnt = std::ceil(static_cast(float_limit - float_start) / float_delta); } const Shape& logical_shape = Shape({range_elem_cnt}); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); return Maybe::Ok(); } /* static */ Maybe ArangeOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe ArangeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe ArangeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->Attr("dtype")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/arg_sort_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> ArgSortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> ArgSortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> ArgSortOp::GetSbp(user_op::SbpContext* ctx) {
  // The current implementation can only do arg_sort in the last dimension and should use
  // Broadcast (by default) instead of Split for that dimension
  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {
    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
  }
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> ArgSortOp::CheckAttr(const user_op::UserOpDefWrapper& def,
                                              const user_op::UserOpConfWrapper& conf) {
  const std::string& direction = conf.attr<std::string>("direction");
  CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING")
      << Error::RuntimeError()
      << "expected the input direction parameter value is \"ASCENDING\" or \"DESCENDING\", "
      << "but found the value is " << direction;
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> ArgSortOp::InferDataType(user_op::InferContext* ctx) {
  if (ctx->parallel_desc().device_type() == DeviceType::kNPU) {
    ctx->SetOutputDType("out", 0, DataType::kInt64);
  } else {
    ctx->SetOutputDType("out", 0, DataType::kInt32);
  }
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/user/ops/arg_where_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); output_desc->set_shape(Shape({input_shape.elem_cnt(), input_shape.NumAxes()})); output_desc->set_is_dynamic(true); user_op::TensorDesc* output_size_desc = ctx->MutOutputTensorDesc("output_size", 0); output_size_desc->set_shape(Shape({1})); return Maybe::Ok(); } } // namespace /* static */ Maybe ArgwhereOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc(ctx); } /*static*/ Maybe ArgwhereOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ArgwhereOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe ArgwhereOp::InferDataType(user_op::InferContext* ctx) { const DataType dtype = ctx->Attr("dtype"); user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); output_desc->set_data_type(dtype); user_op::TensorDesc* output_size_desc = ctx->MutOutputTensorDesc("output_size", 0); output_size_desc->set_data_type(dtype); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/argmax_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe ArgmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { auto dim_vec = ctx->InputShape("in", 0).dim_vec(); dim_vec.pop_back(); ctx->SetOutputShape("out", 0, Shape(std::move(dim_vec))); return Maybe::Ok(); } /*static*/ Maybe ArgmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ArgmaxOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } /* static */ Maybe ArgmaxOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kInt64); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/as_strided_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ auto AsStridedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const auto& size = ctx->Attr>("size"); const auto& stride = ctx->Attr>("stride"); CHECK_EQ_OR_RETURN(size.size(), stride.size()) << "mismatch in length of strides and shape"; DimVector out_vec; out_vec.insert(out_vec.end(), size.cbegin(), size.cend()); user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); output_desc->set_shape(Shape(out_vec)); return Maybe::Ok(); } /*static*/ auto AsStridedOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return AsStridedOp::InferLogicalTensorDesc(ctx); } /*static*/ auto AsStridedOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); } /*static*/ auto AsStridedOp::InferDataType(user_op::InferContext* ctx) -> Maybe { ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } /* static */ auto AsStridedGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const Shape& input_shape = ctx->InputShape("input", 0); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_shape(input_shape); return Maybe::Ok(); } /*static*/ auto AsStridedGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return AsStridedGradOp::InferLogicalTensorDesc(ctx); } /*static*/ auto AsStridedGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); } /*static*/ auto AsStridedGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { ctx->SetOutputDType("dx", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/assign_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0); CHECK_OR_RETURN(!ref_desc.is_dynamic()); CHECK_OR_RETURN(ref_desc.shape() == value_desc.shape()); if (ctx->has_input("condition", 0)) { const user_op::TensorDesc& condition = ctx->InputTensorDesc("condition", 0); CHECK_OR_RETURN(condition.shape().NumAxes() == 1); CHECK_OR_RETURN(condition.shape().At(0) == 1); } return Maybe::Ok(); } Maybe GetSbpSignatures(user_op::SbpContext* ctx) { const user_op::TensorDesc& ref_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0); FOR_RANGE(int64_t, axis, 0, ref_desc.shape().NumAxes()) { if (ctx->user_op_conf().has_input("condition", 0)) { ctx->NewBuilder() .Split(user_op::OpArg("ref", 0), axis) .Split(user_op::OpArg("value", 0), axis) .Broadcast(user_op::OpArg("condition", 0)) .Build(); } else { ctx->NewBuilder().Split(ctx->inputs(), axis).Build(); } } return Maybe::Ok(); } Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0); CHECK_OR_RETURN(ref_modifier != nullptr); ref_modifier->set_is_mutable(true); user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0); CHECK_OR_RETURN(value_modifier != nullptr); value_modifier->set_requires_grad(false); if (conf.has_input("condition", 0)) { user_op::InputArgModifier* condition_modifier = GetInputArgModifierFn("condition", 0); CHECK_OR_RETURN(condition_modifier != nullptr); condition_modifier->set_requires_grad(false); } return Maybe::Ok(); } Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0); CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type()) << Error::RuntimeError() << DataType_Name(ref_desc.data_type()) << " vs." << DataType_Name(value_desc.data_type()); if (ctx->has_input("condition", 0)) { const user_op::TensorDesc& condition = ctx->InputTensorDesc("condition", 0); CHECK_OR_RETURN(IsIndexDataType(condition.data_type())); } return Maybe::Ok(); } } // namespace #define DEF_ASSIGN_OP(op_class_name) \ /* static */ Maybe op_class_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferTensorDesc(ctx); \ } \ \ /*static*/ Maybe op_class_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe op_class_name::GetSbp(user_op::SbpContext* ctx) { \ return GetSbpSignatures(ctx); \ } \ \ /* static */ Maybe op_class_name::ModifyInputArg( \ const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \ return InputArgModifierFn(GetInputArgModifierFn, conf); \ } \ \ /* static */ Maybe op_class_name::InferDataType(user_op::InferContext* ctx) { \ return InferDataType_(ctx); \ } DEF_ASSIGN_OP(AssignUserOp) DEF_ASSIGN_OP(AssignIfOp) DEF_ASSIGN_OP(AssignIfNotOp) #undef DEF_ASSIGN_OP } // namespace oneflow ================================================ FILE: oneflow/user/ops/avg_pool_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/avg_pool_kernel_util.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

namespace {

typedef std::function<Maybe<void>(user_op::InferContext* ctx)> TensorDescInferFn;

TensorDescInferFn AvgPoolMakeForwardTensorDescInferFn(const int32_t dim) {
  return [dim](user_op::InferContext* ctx) -> Maybe<void> {
    const Shape& x_shape = ctx->Shape4ArgNameAndIndex("x", 0);
    const std::string& data_format = ctx->Attr<std::string>("data_format");
    const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
    const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
    const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
    const bool ceil_mode = ctx->Attr<bool>("ceil_mode");
    const bool count_include_pad = ctx->Attr<bool>("count_include_pad");
    const int32_t& divisor_override = ctx->Attr<int32_t>("divisor_override");

    CHECK_EQ_OR_RETURN(kernel_size.size(), dim)
        << Error::RuntimeError() << "kernel_size.size() should be equal to dim.";
    for (int32_t pool_dim : kernel_size) {
      CHECK_GT_OR_RETURN(pool_dim, 0)
          << Error::RuntimeError() << "kernel size should be greater than 0, but got: " << pool_dim;
    }
    CHECK_EQ_OR_RETURN(stride.size(), dim)
        << Error::RuntimeError() << "stride.size() should be equal to dim.";
    for (int32_t stride_dim : stride) {
      CHECK_GT_OR_RETURN(stride_dim, 0)
          << Error::RuntimeError() << "stride size should be greater than 0, but got: "
          << stride_dim;
    }
    for (int32_t i = 0; i < padding.size(); i++) {
      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i])
          << "pad should be smaller than half of kernel size";
    }

    const AvgPoolParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride,
                                    ceil_mode, count_include_pad, divisor_override);
    user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0);
    *y_desc = ctx->InputTensorDesc("x", 0);
    y_desc->set_shape(params_3d.GetYShape());
    return Maybe<void>::Ok();
  };
}

Maybe<void> AvgPoolForwardGetSbpFn(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes() - 2)) {
    ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build();
  }
  return Maybe<void>::Ok();
}

Maybe<void> AvgPoolBackwardGetSbpFn(user_op::SbpContext* ctx) {
  FOR_RANGE(int64_t, i, 0, 2) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("x", 0), i)
        .Split(user_op::OpArg("dy", 0), i)
        .Split(user_op::OpArg("dx", 0), i)
        .Build();
  }
  return Maybe<void>::Ok();
}

// Logically, the computation cost of a pool op is the product of the output data amount and the
// pool kernel data amount. After adding sbp, we simply divide it by the parallel number if the
// output data is split, because splitting the input and using partial sum for the output is not
// a valid sbp for this op for now.
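//
// For example (hypothetical shapes, for illustration only): suppose an AvgPool2D op produces a
// logical output "y" of shape (8, 64, 14, 14) with kernel_size = {3, 3}. The logical cost is
// 8 * 64 * 14 * 14 * 3 * 3 = 903168. If the nd_sbp of "y" is S(0) over a parallel hierarchy of
// 4 devices, the cost returned by GetComputationCost becomes 903168 / 4 = 225792.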
Maybe GetComputationCost(user_op::ComputeComplexityFnContext* ctx, const std::string& blob_name) { const std::vector pool_size = ctx->Attr>("kernel_size"); double logical_computation_cost = std::accumulate( pool_size.begin(), pool_size.end(), ctx->Shape4ArgNameAndIndex(blob_name, 0).elem_cnt(), std::multiplies()); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); const auto& nd_sbp_y = ctx->NdSbp4ArgNameAndIndex(blob_name, 0); for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_y.sbp_parallel_size(); dim_sbp++) { if (nd_sbp_y.sbp_parallel(dim_sbp).has_split_parallel()) { logical_computation_cost /= parallel_hierarchy->At(dim_sbp); } } return logical_computation_cost; } Maybe BackwardTensorDescInferFn(user_op::InferContext* ctx) { *ctx->MutOutputTensorDesc("dx", 0) = ctx->InputTensorDesc("x", 0); return Maybe::Ok(); } Maybe FwInferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe BwInferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace #define IMPLEMENT_AVGPOOL_FUNCS(name, ndim) \ /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ return AvgPoolForwardGetSbpFn(ctx); \ } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return AvgPoolMakeForwardTensorDescInferFn(ndim)(ctx); \ } \ /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ return FwInferDataType(ctx); \ } \ /*static*/ Maybe name##Op::GetComputeComplexity( \ user_op::ComputeComplexityFnContext* ctx) { \ return GetComputationCost(ctx, "y"); \ } IMPLEMENT_AVGPOOL_FUNCS(AvgPool1D, 1) IMPLEMENT_AVGPOOL_FUNCS(AvgPool2D, 2) IMPLEMENT_AVGPOOL_FUNCS(AvgPool3D, 3) #undef IMPLEMENT_AVGPOOL_FUNCS #define IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(name) \ /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return AvgPoolBackwardGetSbpFn(ctx); \ } \ /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return BackwardTensorDescInferFn(ctx); \ } \ /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ return BwInferDataType(ctx); \ } \ /*static*/ Maybe name##GradOp::GetComputeComplexity( \ user_op::ComputeComplexityFnContext* ctx) { \ return GetComputationCost(ctx, "dy"); \ } IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool1D) IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool2D) IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool3D) #undef IMPLEMENT_AVGPOOL_BACKWARD_FUNCS } // namespace oneflow ================================================ FILE: oneflow/user/ops/batch_gather_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe BatchGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0) << Error::RuntimeError() << "The dimension of the input tensor should be greater than zero, " << "but got " << in.shape().NumAxes(); const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0) << Error::RuntimeError() << "The dimension of the indices tensor should be greater than zero, " << "but got " << indices.shape().NumAxes(); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_LE_OR_RETURN(indices.shape().dim_vec().size(), in.shape().dim_vec().size()) << Error::RuntimeError() << "The dimension of the input tensor should be greater than or equal to the dimension of " "the indices tensor, " << "but found that the dimension of the input tensor is " << in.shape().dim_vec().size() << ", and the dimension of the indices tensor is " << indices.shape().dim_vec().size(); FOR_RANGE(int64_t, i, 0, indices.shape().dim_vec().size() - 1) { if (in.is_dynamic() && indices.is_dynamic() == false) { CHECK_GE_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)) << Error::RuntimeError() << "The size of indices tensor should be greater than or equal to the " "size of input tensor " << " at dimension " << i << " when the input tensor is dynamic and the indices tensor is not dynamic"; } else if (in.is_dynamic() == false && indices.is_dynamic()) { LOG(FATAL) << "The indices tensor is not allowed to be dynamic when the input tensor is not dynamic"; } else { CHECK_EQ_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)) << Error::RuntimeError() << "The size of indices tensor must match the size of input tensor" << " at dimension " << i << " when two tensors are both dynamic or neither"; } } DimVector dim_vec(in.shape().dim_vec()); dim_vec.at(indices.shape().NumAxes() - 1) = indices.shape().dim_vec().back(); out->set_shape(Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe BatchGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BatchGatherOp::GetSbp(user_op::SbpContext* ctx) { const int64_t indices_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); if (indices_num_axes > 1) { FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { ctx->NewBuilder() .Split(user_op::OpArg("indices", 0), i) .Split(user_op::OpArg("in", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } } ctx->NewBuilder() .Broadcast(user_op::OpArg("indices", 0)) .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe BatchGatherOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); CHECK_OR_RETURN(indices_modifier != nullptr); // NOLINT(maybe-need-error-msg) indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe BatchGatherOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); CHECK_OR_RETURN(IsIndexDataType(indices.data_type())) << Error::TypeError() << "The dtype of the indices tensor must be int32 or int64"; const 
user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(in.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/batch_norm_backward_elemt_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { std::function(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx, const Shape& shape) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_shape(shape); } return Maybe::Ok(); }; } std::function(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx, DataType data_type) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_data_type(data_type); } return Maybe::Ok(); }; } } // namespace /* static */ Maybe BatchNormBackwardElemtOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const Shape& x_shape = x.shape(); const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, x_shape); JUST(SetOutTensorDesc("grad_in")); return Maybe::Ok(); } /*static*/ Maybe BatchNormBackwardElemtOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BatchNormBackwardElemtOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe BatchNormBackwardElemtOp::InferDataType(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const auto data_type = x.data_type(); const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type; const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type); JUST(SetOutDataType("grad_in")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/batch_norm_backward_reduce_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { std::function(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx, const Shape& shape) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_shape(shape); } return Maybe::Ok(); }; } std::function(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx, DataType data_type) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_data_type(data_type); } return Maybe::Ok(); }; } } // namespace /* static */ Maybe BatchNormBackwardReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const Shape& x_shape = x.shape(); const auto axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0) << "channel axis should be larger than 0"; CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()) << "channel axis should be less than " << x_shape.NumAxes(); const Shape param_shape({x_shape.At(axis)}); const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, param_shape); JUST(SetOutTensorDesc("sum_dy")); JUST(SetOutTensorDesc("sum_dy_xmu")); JUST(SetOutTensorDesc("grad_weight")); JUST(SetOutTensorDesc("grad_bias")); return Maybe::Ok(); } /*static*/ Maybe BatchNormBackwardReduceOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BatchNormBackwardReduceOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe BatchNormBackwardReduceOp::InferDataType(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const auto data_type = x.data_type(); const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type; const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type); JUST(SetOutDataType("sum_dy")); JUST(SetOutDataType("sum_dy_xmu")); JUST(SetOutDataType("grad_weight")); JUST(SetOutDataType("grad_bias")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/batch_norm_elemt_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { std::function(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx, const Shape& shape) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_shape(shape); } return Maybe::Ok(); }; } std::function(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx, DataType data_type) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_data_type(data_type); } return Maybe::Ok(); }; } } // namespace /* static */ Maybe BatchNormElemtOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const Shape& x_shape = x.shape(); const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, x_shape); JUST(SetOutTensorDesc("output")); return Maybe::Ok(); } /*static*/ Maybe BatchNormElemtOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BatchNormElemtOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe BatchNormElemtOp::InferDataType(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const auto data_type = x.data_type(); const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type; const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type); JUST(SetOutDataType("output")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/batch_norm_gather_stats_with_counts_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { std::function(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx, const Shape& shape) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_shape(shape); } return Maybe::Ok(); }; } std::function(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx, DataType data_type) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_data_type(data_type); } return Maybe::Ok(); }; } } // namespace /* static */ Maybe BatchNormGatherStatsWithCountsOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const auto& mean = ctx->InputTensorDesc("mean", 0); const Shape& mean_shape = mean.shape(); const Shape param_shape({mean_shape.At(1)}); const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, param_shape); JUST(SetOutTensorDesc("global_mean")); JUST(SetOutTensorDesc("global_invstd")); return Maybe::Ok(); } /*static*/ Maybe BatchNormGatherStatsWithCountsOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BatchNormGatherStatsWithCountsOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe BatchNormGatherStatsWithCountsOp::InferDataType( user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const auto data_type = x.data_type(); const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type; const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type); JUST(SetOutDataType("global_mean")); JUST(SetOutDataType("global_invstd")); return Maybe::Ok(); } /* static */ Maybe BatchNormGatherStatsWithCountsOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { if (conf.has_input("running_mean", 0)) { CHECK_OR_RETURN(conf.has_input("running_var", 0)) << "running_mean and running_var should be provided as inputs in the same time."; user_op::InputArgModifier* running_mean_modifier = GetInputArgModifierFn("running_mean", 0); CHECK_OR_RETURN(running_mean_modifier != nullptr) << "input arg modifier of running_mean is null."; running_mean_modifier->set_is_mutable(true); running_mean_modifier->set_requires_grad(false); user_op::InputArgModifier* running_var_modifier = GetInputArgModifierFn("running_var", 0); CHECK_OR_RETURN(running_var_modifier != nullptr) << "input arg modifier of running_var is null."; running_var_modifier->set_is_mutable(true); running_var_modifier->set_requires_grad(false); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/batch_norm_stats_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { std::function(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx, const Shape& shape) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_shape(shape); } return Maybe::Ok(); }; } std::function(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx, DataType data_type) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr) << "output tensordesc of " << bn << " is null."; tensor_desc->set_data_type(data_type); } return Maybe::Ok(); }; } } // namespace /* static */ Maybe BatchNormStatsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const Shape& x_shape = x.shape(); const auto axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0) << "channel axis should be larger than 0"; CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()) << "channel axis should be less than " << x_shape.NumAxes(); const Shape param_shape({x_shape.At(axis)}); const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, param_shape); JUST(SetOutTensorDesc("mean")); JUST(SetOutTensorDesc("invstd")); return Maybe::Ok(); } /*static*/ Maybe BatchNormStatsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BatchNormStatsOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe BatchNormStatsOp::InferDataType(user_op::InferContext* ctx) { const auto& x = ctx->InputTensorDesc("input", 0); const auto data_type = x.data_type(); const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type; const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type); JUST(SetOutDataType("mean")); JUST(SetOutDataType("invstd")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/bernoulli_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe BernoulliOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); out_tensor->set_shape(in_tensor.shape()); return Maybe::Ok(); } /*static*/ Maybe BernoulliOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BernoulliOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } /* static */ Maybe BernoulliOp::InferDataType(user_op::InferContext* ctx) { user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); out_tensor->set_data_type(ctx->Attr("dtype")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/bias_add_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe BiasAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); const auto bias_add_axis = ctx->Attr("axis"); CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1) << Error::RuntimeError() << "Bias tensor has to be a one-dimensional vector"; CHECK_GE_OR_RETURN(bias_add_axis, 0) << Error::RuntimeError() << "The size of the axis must greater than or equal to 0, " << "but got " << bias_add_axis; CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()) << Error::IndexError() << "Dimension out of range (expected to be in range of [0" << ", " << a_tensor_desc.shape().NumAxes() - 1 << "]," << " but got " << bias_add_axis << ")"; CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)) << Error::RuntimeError() << "The size of tensor " << a_tensor_desc.shape().ToString() << " must match the size of tensor " << b_tensor_desc.shape().ToString() << " at dimension " << bias_add_axis; ctx->SetOutputShape("out", 0, ctx->InputShape("a", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("a", 0)); return Maybe::Ok(); } /*static*/ Maybe BiasAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BiasAddOp::GetSbp(user_op::SbpContext* ctx) { const auto axis = ctx->Attr("axis"); for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); ++i) { if (i == axis) { continue; } ctx->NewBuilder() .Split(user_op::OpArg("a", 0), i) .Broadcast(user_op::OpArg("b", 0)) .Split(ctx->outputs(), i) .Build(); } ctx->NewBuilder() .Split(user_op::OpArg("b", 0), 0) .Split(user_op::OpArg("a", 0), axis) .Split(ctx->outputs(), axis) .Build(); return Maybe::Ok(); } /* static */ Maybe BiasAddOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("a", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/binary_cross_entropy_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDescFn_(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()); if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic()); CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape()); } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(input_desc.is_dynamic()); out_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(input_desc.data_type()) << ", but got " << DataType_Name(target_desc.data_type()); if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.data_type(), input_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(input_desc.data_type()) << ", but got " << DataType_Name(weight_desc.data_type()); } ctx->SetOutputDType("out", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); const auto& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()); CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape()); if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic()); CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape()); } user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_is_dynamic(input_desc.is_dynamic()); dx_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } Maybe InferGradDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(input_desc.data_type()) << ", but got " << DataType_Name(target_desc.data_type()); if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.data_type(), input_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(input_desc.data_type()) << ", but got " << DataType_Name(weight_desc.data_type()); } ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace /* static */ Maybe BinaryCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDescFn_(ctx); } /*static*/ Maybe BinaryCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BinaryCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { return GenLossForwardDefaultGetSbpFn()(ctx); } /* static */ Maybe BinaryCrossEntropyOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); CHECK_OR_RETURN(target_modifier != nullptr); target_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe BinaryCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { return InferDataType_(ctx); } /* static */ Maybe BinaryCrossEntropyGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferGradTensorDescFn(ctx); } /*static*/ Maybe BinaryCrossEntropyGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BinaryCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { return GenLossBackwardDefaultGetSbpFn()(ctx); } /* static */ Maybe BinaryCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { return InferGradDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << "Input shape should be equal to Target shape. 
"; if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic()); CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape()); } if (ctx->Attr("has_pos_weight")) { const auto& pos_weight_desc = ctx->InputTensorDesc("pos_weight", 0); CHECK_EQ_OR_RETURN(pos_weight_desc.is_dynamic(), input_desc.is_dynamic()); } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(input_desc.is_dynamic()); out_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.data_type(), target_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(target_desc.data_type()) << ", but got " << DataType_Name(weight_desc.data_type()); } if (ctx->Attr("has_pos_weight")) { const auto& pos_weight_desc = ctx->InputTensorDesc("pos_weight", 0); CHECK_EQ_OR_RETURN(pos_weight_desc.data_type(), target_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(target_desc.data_type()) << ", but got " << DataType_Name(pos_weight_desc.data_type()); } ctx->SetOutputDType("out", 0, ctx->InputDType("target", 0)); return Maybe::Ok(); } Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); const auto& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << "Input shape should be equal to Target shape. "; CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape()) << "Dy shape should be equal to Target shape. "; if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic()); CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape()); } if (ctx->Attr("has_pos_weight")) { const auto& pos_weight_desc = ctx->InputTensorDesc("pos_weight", 0); CHECK_EQ_OR_RETURN(pos_weight_desc.is_dynamic(), input_desc.is_dynamic()); } user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_is_dynamic(input_desc.is_dynamic()); dx_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } Maybe InferGradDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.data_type(), target_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(weight_desc.data_type()) << ", but got " << DataType_Name(target_desc.data_type()); } if (ctx->Attr("has_pos_weight")) { const auto& pos_weight_desc = ctx->InputTensorDesc("pos_weight", 0); CHECK_EQ_OR_RETURN(pos_weight_desc.data_type(), target_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(target_desc.data_type()) << ", but got " << DataType_Name(pos_weight_desc.data_type()); } ctx->SetOutputDType("dx", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } } // namespace /* static */ Maybe BinaryCrossEntropyWithLogitsOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDescFn(ctx); } /*static*/ Maybe BinaryCrossEntropyWithLogitsOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BinaryCrossEntropyWithLogitsOp::GetSbp(user_op::SbpContext* ctx) { return GenLossForwardDefaultGetSbpFn( [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { if (ctx->user_op_conf().has_input("pos_weight", 0)) { builder.Broadcast(user_op::OpArg("pos_weight", 0)); } })(ctx); } /* static */ Maybe BinaryCrossEntropyWithLogitsOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); CHECK_OR_RETURN(target_modifier != nullptr); target_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe BinaryCrossEntropyWithLogitsOp::InferDataType(user_op::InferContext* ctx) { return InferDataType_(ctx); } /* static */ Maybe BinaryCrossEntropyWithLogitsGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferGradTensorDescFn(ctx); } /*static*/ Maybe BinaryCrossEntropyWithLogitsGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BinaryCrossEntropyWithLogitsGradOp::GetSbp(user_op::SbpContext* ctx) { return GenLossBackwardDefaultGetSbpFn( [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { if (ctx->user_op_conf().has_input("pos_weight", 0)) { builder.Broadcast(user_op::OpArg("pos_weight", 0)); } })(ctx); } /* static */ Maybe BinaryCrossEntropyWithLogitsGradOp::InferDataType( user_op::InferContext* ctx) { return InferGradDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << "Input shape should be equal to Target shape. "; user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(false); out_desc->set_shape(Shape({})); return Maybe::Ok(); } Maybe InferFwDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); ctx->SetOutputDType("out", 0, ctx->InputDType("target", 0)); return Maybe::Ok(); } Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << "Input shape should be equal to Target shape. "; user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_is_dynamic(false); dx_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } Maybe InferGradDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); ctx->SetOutputDType("dx", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } } // namespace /* static */ Maybe BinaryCrossEntropyWithLogitsReduceMeanOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDescFn(ctx); } /*static*/ Maybe BinaryCrossEntropyWithLogitsReduceMeanOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BinaryCrossEntropyWithLogitsReduceMeanOp::GetSbp( user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe BinaryCrossEntropyWithLogitsReduceMeanOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); CHECK_OR_RETURN(target_modifier != nullptr) << "target_modifier should not be nullptr. 
"; target_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe BinaryCrossEntropyWithLogitsReduceMeanOp::InferDataType( user_op::InferContext* ctx) { return InferFwDataType(ctx); } /* static */ Maybe BinaryCrossEntropyWithLogitsReduceMeanGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferGradTensorDescFn(ctx); } /*static*/ Maybe BinaryCrossEntropyWithLogitsReduceMeanGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BinaryCrossEntropyWithLogitsReduceMeanGradOp::GetSbp( user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Broadcast(user_op::OpArg("dy", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe BinaryCrossEntropyWithLogitsReduceMeanGradOp::InferDataType( user_op::InferContext* ctx) { return InferGradDataType(ctx); } /* static */ Maybe FusedBCEReduceMeanFwBwOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << "Input shape should be equal to Target shape. "; user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(false); out_desc->set_shape(Shape({})); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_is_dynamic(false); dx_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } /*static*/ Maybe FusedBCEReduceMeanFwBwOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedBCEReduceMeanFwBwOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) .PartialSum(user_op::OpArg("out", 0)) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe FusedBCEReduceMeanFwBwOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()], DType::priority_order[DType::Float16()->data_type()]); DataType out_dtype = ctx->Attr("out_dtype"); if (out_dtype == DataType::kInvalidDataType) { out_dtype = target_desc.data_type(); } ctx->SetOutputDType("out", 0, out_dtype); ctx->SetOutputDType("dx", 0, input_desc.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/bincount_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("out", 0); const int64_t size = ctx->Attr("size"); output_desc->set_shape(Shape({size})); return Maybe::Ok(); } } // namespace /* static */ Maybe BinCountOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc(ctx); } /*static*/ Maybe BinCountOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BinCountOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe BinCountOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("out", 0); if (ctx->has_input("weight", 0)) { const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weight", 0); output_desc->set_data_type(weight_desc.data_type()); } else { output_desc->set_data_type(input_desc.data_type()); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/broadcast_div_grad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe BroadcastDivGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("dy", 0, ctx->InputShape("y", 0)); ctx->SetOutputIsDynamic("dy", 0, ctx->InputIsDynamic("y", 0)); return Maybe::Ok(); } /*static*/ Maybe BroadcastDivGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BroadcastDivGradOp::GetSbp(user_op::SbpContext* ctx) { const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) { const int64_t axis_y = y_shape.NumAxes() - 1 - i; const int64_t axis_z = z_shape.NumAxes() - 1 - i; if (y_shape.At(axis_y) == z_shape.At(axis_z)) { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), axis_y) .Split(user_op::OpArg("z", 0), axis_z) .Split(user_op::OpArg("dz", 0), axis_z) .Split(user_op::OpArg("dy", 0), axis_y) .Build(); } else { ctx->NewBuilder() .Broadcast(user_op::OpArg("y", 0)) .Split(user_op::OpArg("z", 0), axis_z) .Split(user_op::OpArg("dz", 0), axis_z) .PartialSum(user_op::OpArg("dy", 0)) .Build(); } } ctx->NewBuilder() .Broadcast(user_op::OpArg("y", 0)) .PartialSum(user_op::OpArg("z", 0)) .Broadcast(user_op::OpArg("dz", 0)) .PartialSum(user_op::OpArg("dy", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("y", 0)) .Broadcast(user_op::OpArg("z", 0)) .PartialSum(user_op::OpArg("dz", 0)) .PartialSum(user_op::OpArg("dy", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe BroadcastDivGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dy", 0, ctx->InputDType("y", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/broadcast_like_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe GetSbpSignatures(user_op::SbpContext* ctx) { const auto& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const auto& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); int32_t x_num_axes = x_shape.NumAxes(); int32_t like_num_axes = like_shape.NumAxes(); const auto& reduced_axes = ctx->Attr>("broadcast_axes"); HashSet conf_axes; ReduceSbpUtil::GetRegularAxes(like_num_axes, reduced_axes, &conf_axes); auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, like_num_axes); int32_t num_reduced_axis = 0; FOR_RANGE(int64_t, i, 0, like_num_axes) { if (IsReducedAxis(i)) { ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("like", 0), i) .Split(user_op::OpArg("y", 0), i) .Build(); if (x_num_axes < like_num_axes) { num_reduced_axis += 1; } } else { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i - num_reduced_axis) .Split(user_op::OpArg("like", 0), i) .Split(user_op::OpArg("y", 0), i) .Build(); } } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("like", 0)) .Broadcast(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); } bool IsAxesLegal(const AxisVector& axis_vec, const Shape& like_shape, const Shape& in_shape) { Shape reduced_like_shape = CreateReducedShape(like_shape, axis_vec); if (like_shape.NumAxes() > in_shape.NumAxes()) { std::vector in_shape_vec; in_shape_vec.reserve(in_shape.NumAxes()); std::vector like_shape_vec; like_shape_vec.reserve(reduced_like_shape.NumAxes()); for (const int64_t& dim : in_shape.dim_vec()) { if (dim != 1) { in_shape_vec.emplace_back(dim); } } for (const int64_t& dim : reduced_like_shape.dim_vec()) { if (dim != 1) { like_shape_vec.emplace_back(dim); } } if (in_shape_vec.size() > like_shape_vec.size()) { return false; } else { return std::equal(in_shape_vec.begin(), in_shape_vec.end(), like_shape_vec.begin()); } } return reduced_like_shape.dim_vec() == in_shape.dim_vec(); } Maybe InferTensorDesc(user_op::InferContext* ctx) { const auto& broadcast_axes = ctx->Attr>("broadcast_axes"); CHECK_OR_RETURN(!broadcast_axes.empty()); const Shape& in_shape = ctx->InputShape("x", 0); const Shape& like_shape = ctx->InputShape("like", 0); const AxisVector axis_vec = {broadcast_axes.begin(), broadcast_axes.end()}; CHECK_OR_RETURN(IsAxesLegal(axis_vec, like_shape, in_shape)) << Error::RuntimeError() << "Invalid input parameter: like shape:" << like_shape.ToString() << ", in shape:" << in_shape.ToString() << ", axis_vec size:" << axis_vec.size(); ctx->SetOutputShape("y", 0, like_shape); ctx->SetOutputStride("y", 0, Stride(like_shape)); return Maybe::Ok(); } } // namespace /* static */ Maybe BroadcastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc(ctx); } /*static*/ Maybe BroadcastLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BroadcastLikeOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpSignatures(ctx); } /* static */ Maybe BroadcastLikeOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { 
user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); CHECK_OR_RETURN(like_modifier != nullptr); like_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe BroadcastLikeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/buffer_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe IdentityBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe IdentityBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IdentityBufferOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe IdentityBufferOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cast_like_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
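// ---------------------------------------------------------------------------
// Illustration: "broadcast_axes" lists the axes of `like` along which `x` is
// expanded. With hypothetical shapes, x = (4), like = (2, 3, 4) and
// broadcast_axes = {0, 1}: reducing `like` over {0, 1} gives (1, 1, 4), whose
// non-unit dims match x, so IsAxesLegal accepts it and y takes the full like
// shape (2, 3, 4). A small helper mirroring that reduction step (a sketch, not
// part of this file):
#include <cstdint>
#include <vector>

static std::vector<int64_t> ReducedLikeShapeSketch(std::vector<int64_t> like_shape,
                                                   const std::vector<int32_t>& broadcast_axes) {
  for (int32_t axis : broadcast_axes) { like_shape[axis] = 1; }
  return like_shape;
}
// ---------------------------------------------------------------------------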
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe CastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe CastLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CastLikeOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); for (int i = 0; i < in_shape.NumAxes(); ++i) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Split(user_op::OpArg("dtype_like", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("dtype_like", 0)) .Broadcast(user_op::OpArg("in", 0)) .Broadcast(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("dtype_like", 0)) .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("dtype_like", 0)) .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe CastLikeOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* dtype_like_modifier = GetInputArgModifierFn("dtype_like", 0); CHECK_NOTNULL_OR_RETURN(dtype_like_modifier); dtype_like_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe CastLikeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dtype_like_tensor_desc = ctx->InputTensorDesc("dtype_like", 0); user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0); output_tensor_desc->set_data_type(dtype_like_tensor_desc.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cast_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" namespace oneflow { namespace { Maybe> MakeCastStream(const Symbol& in_device, const bool pin_memory) { if (pin_memory) { CHECK_OR_RETURN(in_device->type() == "cpu") << "cast op only support pin_memory in cpu device but got " << in_device->type(); // TODO:(zhaoluyang) Parsing pin-memory-device from python auto pin_device = JUST(Device::New("cuda")); return Stream::New(pin_device, StreamType::kPinnedCompute); } return Stream::New(in_device, StreamType::kCompute); } } // namespace /* static */ Maybe CastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& input_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0); output_tensor_desc->set_shape(input_tensor_desc.shape()); output_tensor_desc->set_stride( input_tensor_desc.stride()); // output's stride should consistent with input's output_tensor_desc->set_is_dynamic(input_tensor_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe CastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CastOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe CastOp::InferDataType(user_op::InferContext* ctx) { user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0); output_tensor_desc->set_data_type(ctx->Attr("dtype")); return Maybe::Ok(); } /* static */ Maybe> CastOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { const Symbol& in_device = ctx->InputTensorDevice4ArgNameAndIndex("in", 0); *ctx->OutputTensorDevice4ArgNameAndIndex("out", 0) = in_device; const bool pin_memory = ctx->Attr("pin_memory"); return MakeCastStream(in_device, pin_memory); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cast_to_static_shape_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe CastToStaticShapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); output_desc->set_shape(input_desc.shape()); output_desc->set_is_dynamic(false); return Maybe::Ok(); } /*static*/ Maybe CastToStaticShapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CastToStaticShapeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); FOR_RANGE(int64_t, i, 0, input_desc.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("input", 0), i) .Split(user_op::OpArg("output", 0), i) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("input", 0)) .PartialSum(user_op::OpArg("output", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe CastToStaticShapeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cast_to_tick_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe CastToTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, Shape({1})); return Maybe::Ok(); } /*static*/ Maybe CastToTickOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CastToTickOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe CastToTickOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes()); NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); in_distribution->clear_sbp_parallel(); out_distribution->clear_sbp_parallel(); // in use hint in_distribution->CopyFrom(in_dis_hint); for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { // out dim1 = broadcast out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); } return Maybe::Ok(); } /* static */ Maybe CastToTickOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/categorical_ordinal_encode_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe CategoricalOrdinalEncodeOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& table_shape = ctx->InputShape("table", 0); CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); const Shape& size_shape = ctx->InputShape("size", 0); CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /* static */ Maybe CategoricalOrdinalEncodeOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->parallel_ctx().parallel_num(), 1); const Shape& table_shape = ctx->InputShape("table", 0); CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); const Shape& size_shape = ctx->InputShape("size", 0); CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /* static */ Maybe CategoricalOrdinalEncodeOp::GetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(ctx->parallel_num(), 1); return Maybe::Ok(); } /* static */ Maybe CategoricalOrdinalEncodeOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* table = GetInputArgModifierFn("table", 0); table->set_is_mutable(true); table->set_requires_grad(false); user_op::InputArgModifier* size = GetInputArgModifierFn("size", 0); size->set_is_mutable(true); size->set_requires_grad(false); user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); in->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe CategoricalOrdinalEncodeOp::CheckAttr( const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { CHECK_OR_RETURN(conf.attr("hash_precomputed")); return Maybe::Ok(); } /* static */ Maybe CategoricalOrdinalEncodeOp::InferDataType(user_op::InferContext* ctx) { DataType data_type = ctx->InputDType("in", 0); CHECK_OR_RETURN(IsIndexDataType(data_type)); CHECK_EQ_OR_RETURN(ctx->InputDType("table", 0), data_type) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("table", 0)) << ", but got " << DataType_Name(data_type); CHECK_EQ_OR_RETURN(ctx->InputDType("size", 0), data_type) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("size", 0)) << ", but got " << DataType_Name(data_type); ctx->SetOutputDType("out", 0, data_type); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/celu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe CeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe CeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CeluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe CeluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe CeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == y_shape); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe CeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CeluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe CeluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("y", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("y", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("y", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/clip_by_value_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferClipTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } Maybe GetClipSbpSignature(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } Maybe InferClipGradTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } Maybe GetClipGradSbpSignature(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } Maybe InferClipTensorDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe InferClipGradDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace #define DEF_CLIP_BY_VALUE_OP(op_class_name_prefix) \ /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferClipTensorDesc(ctx); \ } \ \ /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ return GetClipSbpSignature(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ return InferClipTensorDataType(ctx); \ } \ /* static */ Maybe op_class_name_prefix##GradOp::InferLogicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferClipGradTensorDesc(ctx); \ } \ /*static*/ Maybe op_class_name_prefix##GradOp::InferPhysicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /* static */ Maybe op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return GetClipGradSbpSignature(ctx); \ } \ /* static */ Maybe op_class_name_prefix##GradOp::InferDataType( \ user_op::InferContext* ctx) { \ return InferClipGradDataType(ctx); \ } DEF_CLIP_BY_VALUE_OP(ClipByScalar) DEF_CLIP_BY_VALUE_OP(ClipByScalarMin) DEF_CLIP_BY_VALUE_OP(ClipByScalarMax) #undef DEF_CLIP_BY_VALUE_OP } // namespace oneflow ================================================ FILE: oneflow/user/ops/coco_reader_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe COCOReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { int64_t batch_size = ctx->Attr("batch_size"); user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc("image", 0); image_desc->set_shape(Shape({batch_size})); user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc("image_id", 0); image_id_desc->set_shape(Shape({batch_size})); user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc("image_size", 0); image_size_desc->set_shape(Shape({batch_size, 2})); user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc("gt_bbox", 0); bbox_desc->set_shape(Shape({batch_size})); user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc("gt_label", 0); label_desc->set_shape(Shape({batch_size})); user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc("gt_segm", 0); segm_desc->set_shape(Shape({batch_size})); user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc("gt_segm_index", 0); segm_index_desc->set_shape(Shape({batch_size})); return Maybe::Ok(); } /* static */ Maybe COCOReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("image", 0); CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex("image_id", 0)); CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex("image_size", 0)); CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex("gt_bbox", 0)); CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex("gt_label", 0)); CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex("gt_segm", 0)); CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex("gt_segm_index", 0)); int64_t batch_size = ctx->Attr("batch_size"); int64_t parallel_num = ctx->parallel_ctx().parallel_num(); int64_t device_batch_size = batch_size; if (parallel_num > 1) { int64_t split_num = 1; const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= hierarchy.At(i); } } CHECK_EQ_OR_RETURN(device_batch_size % split_num, 0); device_batch_size /= split_num; } user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc("image", 0); image_desc->set_shape(Shape({device_batch_size})); user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc("image_id", 0); image_id_desc->set_shape(Shape({device_batch_size})); user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc("image_size", 0); image_size_desc->set_shape(Shape({device_batch_size, 2})); user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc("gt_bbox", 0); bbox_desc->set_shape(Shape({device_batch_size})); user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc("gt_label", 0); label_desc->set_shape(Shape({device_batch_size})); user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc("gt_segm", 0); segm_desc->set_shape(Shape({device_batch_size})); user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc("gt_segm_index", 0); segm_index_desc->set_shape(Shape({device_batch_size})); return Maybe::Ok(); } /* static */ Maybe COCOReaderOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe COCOReaderOp::ModifyOutputArg( const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); CHECK_OR_RETURN(image_modifier != 
nullptr); image_modifier->set_header_infered_before_compute(false); user_op::OutputArgModifier* image_id_modifier = GetOutputArgModifierFn("image_id", 0); CHECK_OR_RETURN(image_id_modifier != nullptr); image_id_modifier->set_header_infered_before_compute(false); user_op::OutputArgModifier* image_size_modifier = GetOutputArgModifierFn("image_size", 0); CHECK_OR_RETURN(image_size_modifier != nullptr); image_size_modifier->set_header_infered_before_compute(false); user_op::OutputArgModifier* gt_bbox_modifier = GetOutputArgModifierFn("gt_bbox", 0); CHECK_OR_RETURN(gt_bbox_modifier != nullptr); gt_bbox_modifier->set_header_infered_before_compute(false); user_op::OutputArgModifier* gt_label_modifier = GetOutputArgModifierFn("gt_label", 0); CHECK_OR_RETURN(gt_label_modifier != nullptr); gt_label_modifier->set_header_infered_before_compute(false); user_op::OutputArgModifier* gt_segm_modifier = GetOutputArgModifierFn("gt_segm", 0); CHECK_OR_RETURN(gt_segm_modifier != nullptr); gt_segm_modifier->set_header_infered_before_compute(false); user_op::OutputArgModifier* gt_segm_index_modifier = GetOutputArgModifierFn("gt_segm_index", 0); CHECK_OR_RETURN(gt_segm_index_modifier != nullptr); gt_segm_index_modifier->set_header_infered_before_compute(false); return Maybe::Ok(); } /* static */ Maybe COCOReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_split_parallel()->set_axis(0); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe COCOReaderOp::InferDataType(user_op::InferContext* ctx) { user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc("image", 0); image_desc->set_data_type(DataType::kTensorBuffer); user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc("image_id", 0); image_id_desc->set_data_type(DataType::kInt64); user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc("image_size", 0); image_size_desc->set_data_type(DataType::kInt32); user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc("gt_bbox", 0); bbox_desc->set_data_type(DataType::kTensorBuffer); user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc("gt_label", 0); label_desc->set_data_type(DataType::kTensorBuffer); user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc("gt_segm", 0); segm_desc->set_data_type(DataType::kTensorBuffer); user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc("gt_segm_index", 0); segm_index_desc->set_data_type(DataType::kTensorBuffer); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/combined_margin_loss_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
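// ---------------------------------------------------------------------------
// Illustration: COCOReaderOp::InferPhysicalTensorDesc above divides the logical
// batch across every split axis of the placement hierarchy. With hypothetical
// numbers batch_size = 64, hierarchy = (2, 4) and both axes split, split_num
// becomes 8 and each rank reads 8 samples. A minimal sketch of that arithmetic:
#include <cstdint>
#include <vector>

static int64_t DeviceBatchSizeSketch(int64_t batch_size, const std::vector<int64_t>& hierarchy,
                                     const std::vector<bool>& is_split_axis) {
  int64_t split_num = 1;
  for (size_t i = 0; i < hierarchy.size(); ++i) {
    if (is_split_axis[i]) { split_num *= hierarchy[i]; }
  }
  return batch_size / split_num;  // the op checks batch_size % split_num == 0
}
// ---------------------------------------------------------------------------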
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe CombinedMarginLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); user_op::TensorDesc* theta = ctx->MutOutputTensorDesc("theta", 0); CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); ctx->SetIsDynamic4ArgNameAndIndex("y", 0, ctx->InputIsDynamic("x", 0)); theta->set_is_dynamic(x.is_dynamic()); theta->set_shape(label.shape()); return Maybe::Ok(); } /*static*/ Maybe CombinedMarginLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CombinedMarginLossOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("y", 0), 0) .Split(user_op::OpArg("theta", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("x", 0), 1) .Broadcast(user_op::OpArg("label", 0)) .Split(user_op::OpArg("y", 0), 1) .PartialSum(user_op::OpArg("theta", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe CombinedMarginLossOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* label_arg_modifier = GetInputArgModifierFn("label", 0); label_arg_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe CombinedMarginLossOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("theta", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /* static */ Maybe CombinedMarginLossGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); CHECK_EQ_OR_RETURN(label.shape().At(0), dy.shape().At(0)); CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0)); CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2); ctx->SetOutputShape("dx", 0, ctx->InputShape("dy", 0)); ctx->SetIsDynamic4ArgNameAndIndex("dx", 0, ctx->InputIsDynamic("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe CombinedMarginLossGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CombinedMarginLossGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("theta", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 1) .Broadcast(user_op::OpArg("label", 0)) .Broadcast(user_op::OpArg("theta", 0)) .Split(user_op::OpArg("dx", 0), 1) .Build(); return Maybe::Ok(); } /* static */ Maybe CombinedMarginLossGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/comm_net_device_infer_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/ops/comm_net_device_infer_util.h" #include "oneflow/core/common/decorator.h" namespace oneflow { namespace { Maybe> RawGetTransportDevice(Symbol device) { return Stream::New(JUST(Device::New(device->type())), StreamType::kCcl); } } // namespace decltype(GetTransportDevice) GetTransportDevice = DECORATE(&RawGetTransportDevice, ThreadLocal); Maybe> DefaultGetOutputDeivce(user_op::DeviceAndStreamInferContext* ctx) { CHECK_GT_OR_RETURN(ctx->inputs().size(), 0); return ctx->InputTensorDevice4ArgNameAndIndex("in", 0); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/comm_net_device_infer_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_CORE_USER_OP_NCCL_DEVICE_INFER_UTIL_H_ #define ONEFLOW_CORE_USER_OP_NCCL_DEVICE_INFER_UTIL_H_ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/common/env_var/eager.h" #include "oneflow/core/job/lazy_mode.h" namespace oneflow { extern Maybe> (*GetTransportDevice)(Symbol); Maybe> DefaultGetOutputDeivce(user_op::DeviceAndStreamInferContext* ctx); template> (*GetOutputDeivce)(user_op::DeviceAndStreamInferContext*) = DefaultGetOutputDeivce> Maybe> DeviceAndStreamInferFn(user_op::DeviceAndStreamInferContext* ctx) { Symbol output_device = JUST(GetOutputDeivce(ctx)); for (const auto& pair : ctx->outputs()) { *ctx->OutputTensorDevice4ArgNameAndIndex(pair.first, pair.second) = output_device; } if (EagerNcclUseComputeStream() && !LazyMode::is_enabled()) { return GetDefaultStreamByDevice(output_device); } return GetTransportDevice(output_device); } } // namespace oneflow #endif // ONEFLOW_CORE_USER_OP_NCCL_DEVICE_INFER_UTIL_H_ ================================================ FILE: oneflow/user/ops/complex_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
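// ---------------------------------------------------------------------------
// Illustration: DeviceAndStreamInferFn above is the shared device/stream hook
// for communication ops -- every output inherits the device chosen by
// GetOutputDeivce, and the op runs on the default compute stream when eager
// NCCL is configured to reuse it, otherwise on a dedicated CCL transport
// stream. A hypothetical op (sketch only, no such op is defined here) would
// plug it in roughly like this:
//
//   /* static */ Maybe<Symbol<Stream>> MyCollectiveOp::InferDeviceAndStream(
//       user_op::DeviceAndStreamInferContext* ctx) {
//     return DeviceAndStreamInferFn<&DefaultGetOutputDeivce>(ctx);
//   }
// ---------------------------------------------------------------------------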
See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { static std::map complex_to_real_map{{DataType::kComplex32, DataType::kFloat16}, {DataType::kComplex64, DataType::kFloat}, {DataType::kComplex128, DataType::kDouble}}; static std::map real_to_complex_map{{DataType::kFloat16, DataType::kComplex32}, {DataType::kFloat, DataType::kComplex64}, {DataType::kDouble, DataType::kComplex128}}; /*static*/ Maybe RealOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /*static*/ Maybe RealOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe RealOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RealOp::InferDataType(user_op::InferContext* ctx) { const std::pair& input_arg = ctx->inputs().at(0); const user_op::TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second); const std::pair& output_arg = ctx->outputs().at(0); ctx->SetOutputDType(output_arg.first, output_arg.second, complex_to_real_map[tensor_desc.data_type()]); return Maybe::Ok(); } /*static*/ Maybe RealGradOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /*static*/ Maybe RealGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe RealGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RealGradOp::InferDataType(user_op::InferContext* ctx) { const std::pair& input_arg = ctx->inputs().at(0); const user_op::TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second); const std::pair& output_arg = ctx->outputs().at(0); ctx->SetOutputDType(output_arg.first, output_arg.second, real_to_complex_map[tensor_desc.data_type()]); return Maybe::Ok(); } /*static*/ Maybe ImagOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /*static*/ Maybe ImagOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe ImagOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ImagOp::InferDataType(user_op::InferContext* ctx) { const std::pair& input_arg = ctx->inputs().at(0); const user_op::TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second); const std::pair& output_arg = ctx->outputs().at(0); ctx->SetOutputDType(output_arg.first, output_arg.second, complex_to_real_map[tensor_desc.data_type()]); return Maybe::Ok(); } /*static*/ Maybe ImagGradOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /*static*/ Maybe ImagGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe ImagGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ImagGradOp::InferDataType(user_op::InferContext* ctx) { const std::pair& input_arg = ctx->inputs().at(0); const user_op::TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second); const std::pair& output_arg = ctx->outputs().at(0); ctx->SetOutputDType(output_arg.first, 
output_arg.second, real_to_complex_map[tensor_desc.data_type()]); return Maybe::Ok(); } /*static*/ Maybe ConjPhysicalOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /*static*/ Maybe ConjPhysicalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe ConjPhysicalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ConjPhysicalOp::InferDataType(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/concat_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe ConcatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("in", 0); const int64_t axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, first_in_desc.shape().NumAxes()); DimVector out_dim_vec = first_in_desc.shape().dim_vec(); out_dim_vec.at(axis) = 0; int64_t first_axes = first_in_desc.shape().NumAxes(); int64_t first_elemcnt = first_in_desc.shape().elem_cnt(); int64_t dynamic_dim_size = 0; for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); if (first_elemcnt == 0 and first_axes == 1) { if (in_desc.shape().elem_cnt() != 0 or in_desc.shape().NumAxes() != 1) { out_dim_vec = in_desc.shape().dim_vec(); out_dim_vec.at(axis) = 0; first_axes = in_desc.shape().NumAxes(); first_elemcnt = in_desc.shape().elem_cnt(); } else { continue; } } else if (in_desc.shape().elem_cnt() != 0 or in_desc.shape().NumAxes() != 1) { CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), first_axes); } FOR_RANGE(int64_t, i, 0, in_desc.shape().NumAxes()) { if (in_desc.shape().elem_cnt() == 0 and in_desc.shape().NumAxes() == 1) { continue; } if (i == axis) { if (in_desc.is_dynamic()) { dynamic_dim_size += in_desc.shape().At(i); } else { out_dim_vec.at(axis) += in_desc.shape().At(i); } } else { CHECK_EQ_OR_RETURN(in_desc.shape().At(i), out_dim_vec.at(i)); } } } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); const int64_t max_dim_size = ctx->Attr("max_dim_size"); CHECK_LE_OR_RETURN(out_dim_vec.at(axis), max_dim_size); if (dynamic_dim_size == 0) { out_desc->set_is_dynamic(false); } else { out_desc->set_is_dynamic(true); out_dim_vec.at(axis) = max_dim_size; } out_desc->set_shape(Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ConcatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ConcatOp::GetSbp(user_op::SbpContext* ctx) { const int64_t axis = ctx->Attr("axis"); const 
user_op::TensorDesc& first_in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, first_in_desc.shape().NumAxes()) { if (i == axis) { continue; } ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe ConcatOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("in", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(in_desc.data_type()) << ", but got " << DataType_Name(first_in_desc.data_type()); } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(first_in_desc.data_type()); return Maybe::Ok(); } /*static*/ Maybe ConcatOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("in") >= 2); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/constant_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe ConstantOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, Shape(ctx->Attr("shape").dim_vec())); return Maybe::Ok(); } /*static*/ Maybe ConstantOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& logical_shape = ctx->Attr("shape"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); return Maybe::Ok(); } /* static */ Maybe ConstantOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe ConstantOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe ConstantOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->Attr("dtype")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/conv_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
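// ---------------------------------------------------------------------------
// Illustration: ConcatOp (defined above) accumulates the concat axis across its
// inputs and requires every other dim to match; dynamic inputs instead force
// the concat axis to the "max_dim_size" attribute. A minimal sketch of the
// static case, assuming all inputs share the same rank:
#include <cstdint>
#include <vector>

static std::vector<int64_t> ConcatShapeSketch(const std::vector<std::vector<int64_t>>& ins,
                                              int64_t axis) {
  std::vector<int64_t> out = ins.front();
  out[axis] = 0;
  for (const auto& s : ins) { out[axis] += s[axis]; }
  return out;  // e.g. {2, 3} and {2, 5} concatenated on axis 1 -> {2, 8}
}
// ---------------------------------------------------------------------------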
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { template Maybe InferTensorDesc4Conv(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(NDims + 2, in.shape().NumAxes()) << "Conv" << NDims << "D op's input shape ndim should equal to " << NDims + 2 << " ,but got: " << in.shape().NumAxes(); auto data_format = ctx->Attr("data_format"); auto kernel_size = ctx->Attr>("kernel_size"); CHECK_EQ_OR_RETURN(NDims, kernel_size.size()); int32_t filters = ctx->Attr("filters"); size_t idx_offset = IdxOffset(data_format); { const auto& padding_before = ctx->Attr>("padding_before"); auto dilation_rate = ctx->Attr>("dilation_rate"); auto strides = ctx->Attr>("strides"); CHECK_EQ_OR_RETURN(NDims, dilation_rate.size()); CHECK_EQ_OR_RETURN(NDims, strides.size()); CHECK_EQ_OR_RETURN(NDims, padding_before.size()); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); DimVector out_shape(NDims + 2); out_shape.at(0) = in.shape().At(0); const size_t c_dim = data_format == "channels_first" ? 1 : NDims + 1; out_shape.at(c_dim) = filters; for (int32_t i = 0; i < NDims; ++i) { JUST(CalcConvOut(in.shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i), strides.at(i), padding_before.at(i), &out_shape.at(idx_offset + i))); } out->set_is_dynamic(in.is_dynamic()); out->set_shape(Shape(out_shape)); } { int32_t groups = ctx->Attr("groups"); CHECK_GT_OR_RETURN(groups, 0); CHECK_LE_OR_RETURN(groups, filters); CHECK_EQ_OR_RETURN(filters % groups, 0); DimVector weight_shape(in.shape().dim_vec()); weight_shape.at(0) = filters; if (data_format == "channels_first") { CHECK_LE_OR_RETURN(groups, weight_shape.at(1)); CHECK_EQ_OR_RETURN(weight_shape.at(1) % groups, 0); weight_shape.at(1) = weight_shape.at(1) / groups; } else if (data_format == "channels_last") { CHECK_LE_OR_RETURN(groups, weight_shape.at(NDims + 1)); CHECK_EQ_OR_RETURN(weight_shape.at(NDims + 1) % groups, 0); weight_shape.at(NDims + 1) = weight_shape.at(NDims + 1) / groups; } else { UNIMPLEMENTED_THEN_RETURN(); } for (size_t i = 0; i < NDims; ++i) { weight_shape.at(idx_offset + i) = kernel_size.at(i); } const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight.shape(), Shape(weight_shape)); } bool has_bias = ctx->has_input("bias", 0); if (has_bias) { const user_op::TensorDesc& bias = ctx->InputTensorDesc("bias", 0); CHECK_EQ_OR_RETURN(bias.shape(), Shape({filters})); } return Maybe::Ok(); } Maybe GetSbpSignatures4Conv(user_op::SbpContext* ctx) { bool has_bias = false; for (const auto& pair : ctx->inputs()) { if (pair.first == "bias") { CHECK_EQ_OR_RETURN(0, pair.second); has_bias = true; break; } } if (has_bias) { ctx->NewBuilder() .Split(ctx->inputs(), 0) .Split(user_op::OpArg("in", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Broadcast(user_op::OpArg("bias", 0)) .Split(user_op::OpArg("out", 0), 0) .Build(); } else { 
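// Data-parallel signature for the no-bias case: axis 0 (batch) of "in" and "out" is
// split while "weight" is broadcast, mirroring the with-bias branch above.
// Illustrative shapes (assumed): with 4 ranks and in = (32, 3, 224, 224), each rank
// convolves an (8, 3, 224, 224) slice against the full weight tensor.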
ctx->NewBuilder() .Split(ctx->inputs(), 0) .Split(user_op::OpArg("in", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("out", 0), 0) .Build(); } return Maybe::Ok(); } /* Example for conv2d: ComputationCost = ((k*k + k*k-1)*c + c-1 + bias?1:0) * out_channel * out_width * out_height * batch_size = (2*k*k*c - 1 + bias?1:0) * out_channel * out_width * out_height * batch_size ≈ 2*k*k*c * out_channel * out_width * out_height * batch_size */ Maybe ConvComputationCost(user_op::ComputeComplexityFnContext* ctx) { const std::vector kernel_size = ctx->Attr>("kernel_size"); const std::string data_format = ctx->Attr("data_format"); const user_op::TensorDesc* in = ctx->TensorDesc4ArgNameAndIndex("in", 0); const size_t c_dim = data_format == "channels_first" ? 1 : in->shape().NumAxes() - 1; const int32_t c = in->shape().At(c_dim); const user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("out", 0); double cost = std::accumulate(kernel_size.begin(), kernel_size.end(), 1.0, std::multiplies()); cost = cost * 2 * c; cost *= std::accumulate(out->shape().dim_vec().begin(), out->shape().dim_vec().end(), 1.0, std::multiplies()); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); const auto& nd_sbp_out = ctx->NdSbp4ArgNameAndIndex("out", 0); for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_out.sbp_parallel_size(); dim_sbp++) { if (nd_sbp_out.sbp_parallel(dim_sbp).has_split_parallel()) { cost /= parallel_hierarchy->At(dim_sbp); } } return cost; } template Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { bool is_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; const auto& data_format = conf.attr("data_format"); if (!(data_format == "channels_first" || data_format == "channels_last")) { err << " data_format:" << data_format; is_checked = false; } if (NDims != 0) { const auto& padding_before = conf.attr>("padding_before"); if (padding_before.size() != NDims) { err << " padding_before: number of element is " << padding_before.size(); is_checked = false; } const auto& kernel_size = conf.attr>("kernel_size"); if (kernel_size.size() != NDims) { err << " kernel_size: number of element is " << kernel_size.size(); is_checked = false; } const auto& strides = conf.attr>("strides"); if (strides.size() != NDims) { err << " strides: number of element is " << strides.size(); is_checked = false; } const auto& dilation_rate = conf.attr>("dilation_rate"); if (dilation_rate.size() != NDims) { err << " dilation_rate: number of element is " << dilation_rate.size(); is_checked = false; } } if (is_checked) { return Maybe::Ok(); } else { return oneflow::Error::CheckFailedError() << err.str(); } } } // namespace /* static */ Maybe Conv1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4Conv<1>(ctx); } /*static*/ Maybe Conv1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe Conv1DOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpSignatures4Conv(ctx); } /* static */ Maybe Conv1DOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { return ConvComputationCost(ctx); } /* static */ Maybe Conv1DOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_<1>(def, conf); } /* static */ Maybe Conv1DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* 
static */ Maybe Conv2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4Conv<2>(ctx); } /*static*/ Maybe Conv2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe Conv2DOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpSignatures4Conv(ctx); } /* static */ Maybe Conv2DOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { return ConvComputationCost(ctx); } /* static */ Maybe Conv2DOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_<2>(def, conf); } /* static */ Maybe Conv2DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe Conv3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4Conv<3>(ctx); } /*static*/ Maybe Conv3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe Conv3DOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpSignatures4Conv(ctx); } /* static */ Maybe Conv3DOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { return ConvComputationCost(ctx); } /* static */ Maybe Conv3DOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_<3>(def, conf); } /* static */ Maybe Conv3DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe ConvDataGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); CHECK_GE_OR_RETURN(num_spatial_dims, 1); CHECK_LE_OR_RETURN(num_spatial_dims, 3); CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); CHECK_EQ_OR_RETURN(x_like.shape().NumAxes(), num_spatial_dims + 2); if (ctx->has_input("_add_to_output", 0)) { const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape()); } ctx->SetOutputShape("dx", 0, ctx->InputShape("x_like", 0)); ctx->SetOutputIsDynamic("dx", 0, ctx->InputIsDynamic("x_like", 0)); return Maybe::Ok(); } /*static*/ Maybe ConvDataGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ConvDataGradOp::GetSbp(user_op::SbpContext* ctx) { std::vector split_args; split_args.emplace_back("dy", 0); split_args.emplace_back("x_like", 0); split_args.emplace_back("dx", 0); if (ctx->user_op_conf().has_input("_add_to_output", 0)) { split_args.emplace_back("_add_to_output", 0); } ctx->NewBuilder().Split(split_args, 0).Broadcast(user_op::OpArg("filter", 0)).Build(); return Maybe::Ok(); } /* static */ Maybe ConvDataGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_<0>(def, conf); } /* static */ Maybe ConvDataGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); CHECK_EQ_OR_RETURN(x_like.data_type(), dy.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(dy.data_type()) << ", but got " << DataType_Name(x_like.data_type()); if (ctx->has_input("_add_to_output", 0)) { const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.data_type(), x_like.data_type()) << "InferDataType Failed. Expected " << DataType_Name(add_to_output.data_type()) << ", but got " << DataType_Name(x_like.data_type()); } ctx->SetOutputDType("dx", 0, ctx->InputDType("x_like", 0)); return Maybe::Ok(); } /* static */ Maybe ConvDataGradOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { const std::vector kernel_size = ctx->Attr>("kernel_size"); const user_op::TensorDesc* dx = ctx->TensorDesc4ArgNameAndIndex("dx", 0); const user_op::TensorDesc* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const size_t c_dim = ctx->Attr("data_format") == "channels_first" ? 1 : dy->shape().NumAxes() - 1; double cost = std::accumulate(kernel_size.begin(), kernel_size.end(), 1.0, std::multiplies()) * std::accumulate(dx->shape().dim_vec().begin(), dx->shape().dim_vec().end(), 1.0, std::multiplies()) * 2.0 * dy->shape().At(c_dim); const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex("dx", 0); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); for (int32_t dim_sbp = 0; dim_sbp < nd_sbp.sbp_parallel_size(); dim_sbp++) { if (nd_sbp.sbp_parallel(dim_sbp).has_split_parallel()) { cost /= parallel_hierarchy->At(dim_sbp); } } return cost; } /* static */ Maybe ConvFilterGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); const int32_t groups = ctx->Attr("groups"); const std::string& data_format = ctx->Attr("data_format"); const std::vector kernel_size = ctx->Attr>("kernel_size"); CHECK_GE_OR_RETURN(num_spatial_dims, 1); CHECK_LE_OR_RETURN(num_spatial_dims, 3); CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); CHECK_EQ_OR_RETURN(x.shape().NumAxes(), num_spatial_dims + 2); CHECK_GT_OR_RETURN(groups, 0); DimVector filter_diff_dim_vec; if (data_format == "channels_first") { CHECK_LE_OR_RETURN(groups, x.shape().At(1)); CHECK_LE_OR_RETURN(groups, dy.shape().At(1)); CHECK_EQ_OR_RETURN(x.shape().At(1) % groups, 0); CHECK_EQ_OR_RETURN(dy.shape().At(1) % groups, 0); filter_diff_dim_vec.emplace_back(dy.shape().At(1)); filter_diff_dim_vec.emplace_back(x.shape().At(1) / groups); filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), kernel_size.cend()); } else { CHECK_EQ_OR_RETURN("channels_last", data_format); CHECK_EQ_OR_RETURN(groups, 1); filter_diff_dim_vec.emplace_back(dy.shape().dim_vec().back()); filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), kernel_size.cend()); filter_diff_dim_vec.emplace_back(x.shape().dim_vec().back() / groups); } user_op::TensorDesc* filter_diff = ctx->MutOutputTensorDesc("filter_diff", 0); filter_diff->set_shape(Shape(filter_diff_dim_vec)); filter_diff->set_is_dynamic(false); return Maybe::Ok(); } /*static*/ Maybe ConvFilterGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ConvFilterGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .PartialSum(user_op::OpArg("filter_diff", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe ConvFilterGradOp::CheckAttr(const 
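// For reference, the filter gradient inferred above has shape
// (dy_channels, x_channels / groups, *kernel_size) in channels_first layout.
// Illustrative example (values assumed): dy = (N, 64, Ho, Wo), x = (N, 32, Hi, Wi),
// groups = 1, kernel_size = {3, 3} yields filter_diff = (64, 32, 3, 3).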
user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_<0>(def, conf); } /* static */ Maybe ConvFilterGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); CHECK_EQ_OR_RETURN(x.data_type(), dy.data_type()) << "InferDataType Failed. Expected " << DataType_Name(dy.data_type()) << ", but got " << DataType_Name(x.data_type()); user_op::TensorDesc* filter_diff = ctx->MutOutputTensorDesc("filter_diff", 0); filter_diff->set_data_type(x.data_type()); return Maybe::Ok(); } /* static */ Maybe ConvFilterGradOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { const std::vector kernel_size = ctx->Attr>("kernel_size"); const user_op::TensorDesc* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex("x", 0); const size_t c_dim = ctx->Attr("data_format") == "channels_first" ? 1 : x->shape().NumAxes() - 1; double cost = std::accumulate(kernel_size.begin(), kernel_size.end(), 1.0, std::multiplies()) * std::accumulate(dy->shape().dim_vec().begin(), dy->shape().dim_vec().end(), 1.0, std::multiplies()) * 2.0 * x->shape().At(c_dim); const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex("dy", 0); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); for (int32_t dim_sbp = 0; dim_sbp < nd_sbp.sbp_parallel_size(); dim_sbp++) { if (nd_sbp.sbp_parallel(dim_sbp).has_split_parallel()) { cost /= parallel_hierarchy->At(dim_sbp); } } return cost; } /* static */ Maybe ConvBiasGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* bias_diff = ctx->MutOutputTensorDesc("bias_diff", 0); int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); std::string data_format = ctx->Attr("data_format"); CHECK_GE_OR_RETURN(num_spatial_dims, 1); CHECK_LE_OR_RETURN(num_spatial_dims, 3); CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); if (data_format == "channels_first") { bias_diff->set_shape(Shape({dy.shape().At(1)})); } else if (data_format == "channels_last") { bias_diff->set_shape(Shape({dy.shape().At(dy.shape().NumAxes() - 1)})); } else { OF_UNIMPLEMENTED(); } return Maybe::Ok(); } /*static*/ Maybe ConvBiasGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ConvBiasGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .PartialSum(user_op::OpArg("bias_diff", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe ConvBiasGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { std::string data_format = conf.attr("data_format"); if (data_format == "channels_first" || data_format == "channels_last") { return Maybe::Ok(); } return oneflow::Error::CheckFailedError() << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": data_format:" << data_format; } /* static */ Maybe ConvBiasGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* bias_diff = ctx->MutOutputTensorDesc("bias_diff", 0); bias_diff->set_data_type(dy.data_type()); return Maybe::Ok(); } /* static */ Maybe ConvBiasGradOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { const user_op::TensorDesc* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const std::string data_format 
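// The cost below is the element count of "dy", divided by the hierarchy extent of
// every split dimension in dy's nd-sbp. Illustrative example (numbers assumed):
// dy with 1,048,576 elements under S(0) on a 4-rank hierarchy costs 262,144 per rank.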
= ctx->Attr("data_format"); double cost = std::accumulate(dy->shape().dim_vec().begin(), dy->shape().dim_vec().end(), 1.0, std::multiplies()); const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex("dy", 0); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); for (int32_t dim_sbp = 0; dim_sbp < nd_sbp.sbp_parallel_size(); dim_sbp++) { if (nd_sbp.sbp_parallel(dim_sbp).has_split_parallel()) { cost /= parallel_hierarchy->At(dim_sbp); } } return cost; } } // namespace oneflow ================================================ FILE: oneflow/user/ops/convert_memory_format_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/ops/convert_memory_format_op.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { static Shape ComputeShapeIdentity(const Shape& shape) { return shape; } Shape ComputeShapeContiguousToChannelsLast(const Shape& shape) { int ndim = shape.size(); if (ndim <= 2) { return ComputeShapeIdentity(shape); } Shape target_shape(ndim); target_shape[0] = shape[0]; target_shape[ndim - 1] = shape[1]; for (int i = 0; i < ndim - 2; ++i) { target_shape[i + 1] = shape[i + 2]; } return target_shape; } Shape ComputeShapeChannelsLastToContiguous(const Shape& shape) { int ndim = shape.size(); if (ndim <= 2) { return ComputeShapeIdentity(shape); } Shape target_shape(ndim); target_shape[0] = shape[0]; target_shape[1] = shape[ndim - 1]; for (int i = 0; i < ndim - 2; ++i) { target_shape[i + 2] = shape[i + 1]; } return target_shape; } static Maybe GetSbpIdentity(user_op::SbpContext* ctx, const Shape& shape) { for (int32_t i = 0; i < shape.size(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } static Maybe GetSbpContiguousToChannelsLast(user_op::SbpContext* ctx, const Shape& shape) { int ndim = shape.size(); if (ndim <= 2) { return GetSbpIdentity(ctx, shape); } ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); ctx->NewBuilder().Split(ctx->inputs(), 1).Split(ctx->outputs(), ndim - 1).Build(); for (int32_t i = 0; i < ndim - 2; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i + 2).Split(ctx->outputs(), i + 1).Build(); } return Maybe::Ok(); } static Maybe GetSbpChannelsLastToContiguous(user_op::SbpContext* ctx, const Shape& shape) { int ndim = shape.size(); if (ndim <= 2) { return GetSbpIdentity(ctx, shape); } ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); ctx->NewBuilder().Split(ctx->inputs(), ndim - 1).Split(ctx->outputs(), 1).Build(); for (int32_t i = 0; i < ndim - 2; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i + 1).Split(ctx->outputs(), i + 2).Build(); } return Maybe::Ok(); } using ComputeShapeFunc = std::function; using GetSbpFunc = std::function(user_op::SbpContext* ctx, const Shape& shape)>; static ComputeShapeFunc compute_shape_funcs[kMemoryFormatCount][kMemoryFormatCount] = { /*kContiguous->other*/ 
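// Rows are indexed by the source MemoryFormat and columns by the target MemoryFormat,
// matching compute_shape_funcs[memory_format][target_memory_format] below.
// Illustrative example (shape assumed): a contiguous (8, 3, 32, 32) tensor converted
// to channels_last becomes (8, 32, 32, 3), and converting back restores (8, 3, 32, 32).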
{ComputeShapeIdentity, ComputeShapeContiguousToChannelsLast}, /*kChannelsLast->other*/ {ComputeShapeChannelsLastToContiguous, ComputeShapeIdentity}, }; static GetSbpFunc get_sbp_funcs[kMemoryFormatCount][kMemoryFormatCount] = { /*kContiguous->other*/ {GetSbpIdentity, GetSbpContiguousToChannelsLast}, /*kChannelsLast->other*/ {GetSbpChannelsLastToContiguous, GetSbpIdentity}, }; Shape ComputeConvertMemoryFormatShape(const Shape& shape, MemoryFormat memory_format, MemoryFormat target_memory_format) { auto shape_func = compute_shape_funcs[memory_format][target_memory_format]; return shape_func(shape); } static Maybe GetConvertMemoryFormatSbp(user_op::SbpContext* ctx, const Shape& shape, MemoryFormat memory_format, MemoryFormat target_memory_format) { auto sbp_func = get_sbp_funcs[memory_format][target_memory_format]; return sbp_func(ctx, shape); } /*static*/ Maybe ConvertMemoryFormatOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& input_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); const auto& memory_format = ctx->Attr("memory_format"); JUST(GetConvertMemoryFormatSbp(ctx, input_tensor.shape(), input_tensor.memory_format(), memory_format)); ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /*static*/ Maybe ConvertMemoryFormatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc("out", 0); const Shape& in_shape = in_tensor_desc.shape(); const auto& memory_format = ctx->Attr("memory_format"); out_tensor_desc->set_is_dynamic(in_tensor_desc.is_dynamic()); out_tensor_desc->set_shape( ComputeConvertMemoryFormatShape(in_shape, in_tensor_desc.memory_format(), memory_format)); out_tensor_desc->set_memory_format(memory_format); return Maybe::Ok(); } /*static*/ Maybe ConvertMemoryFormatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ConvertMemoryFormatOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/convert_memory_format_op.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" namespace oneflow { Shape ComputeShapeContiguousToChannelsLast(const Shape& shape); Shape ComputeShapeChannelsLastToContiguous(const Shape& shape); Shape ComputeConvertMemoryFormatShape(const Shape& shape, MemoryFormat memory_format, MemoryFormat target_memory_format); } // namespace oneflow ================================================ FILE: oneflow/user/ops/copy_hd_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { Maybe InferLogical(user_op::InferContext* ctx) { UNIMPLEMENTED_THEN_RETURN() << "copy hd should only exist in physical graph"; } Maybe InferPhysical(user_op::InferContext* ctx) { *ctx->MutOutputTensorDesc("out", 0) = ctx->InputTensorDesc("in", 0); return Maybe::Ok(); } Maybe FwGetSbpFn(user_op::SbpContext* ctx) { return Maybe::Ok(); } Maybe InferFWDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace Maybe CopyD2HOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferLogical(ctx); } Maybe CopyD2HOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferPhysical(ctx); } Maybe CopyD2HOp::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } Maybe CopyD2HOp::InferDataType(user_op::InferContext* ctx) { return InferFWDataType(ctx); } Maybe CopyH2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferLogical(ctx); } Maybe CopyH2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferPhysical(ctx); } Maybe CopyH2DOp::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } Maybe CopyH2DOp::InferDataType(user_op::InferContext* ctx) { return InferFWDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/copy_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/common/env_var/stream.h" namespace oneflow { namespace { StreamType GetH2DStreamType() { if (ThreadLocalEnvBool()) { return StreamType::kHost2Device; } else { return StreamType::kCompute; } } Maybe> MakeCopyStream(const Symbol& in_device, const Symbol& out_device, const bool pin_memory) { if (in_device->type() != "cpu" && out_device->type() == "cpu") { return Stream::New(in_device, StreamType::kDevice2Host); } else if (in_device->type() == "cpu" && out_device->type() != "cpu") { return Stream::New(out_device, GetH2DStreamType()); } else if (in_device->type() == "cpu" && out_device->type() == "cpu" && pin_memory) { // TODO:(zhaoluyang) Parsing pin-memory-device from python auto pin_device = JUST(Device::New("cuda")); return Stream::New(pin_device, StreamType::kPinnedCompute); } else { CHECK_EQ_OR_RETURN(in_device->type(), out_device->type()); return Stream::New(out_device, StreamType::kCompute); } } } // namespace /* static */ Maybe CopyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputStride("out", 0, ctx->InputStride("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe CopyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CopyOp::GetSbp(user_op::SbpContext* ctx) { const auto& inputs = ctx->inputs(); CHECK_EQ_OR_RETURN(inputs.size(), 1); const auto& input = ctx->LogicalTensorDesc4InputArgNameAndIndex(inputs[0].first, inputs[0].second); for (int64_t axis = 0; axis < input.shape().NumAxes(); ++axis) { ctx->NewBuilder().Split(inputs, axis).Split(ctx->outputs(), axis).Build(); } ctx->NewBuilder().PartialSum(inputs).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe CopyOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> CopyOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { Symbol out_device = ctx->Attr>("device"); *ctx->OutputTensorDevice4ArgNameAndIndex("out", 0) = out_device; const Symbol& in_device = ctx->InputTensorDevice4ArgNameAndIndex("in", 0); const bool pin_memory = ctx->Attr("pin_memory"); return MakeCopyStream(in_device, out_device, pin_memory); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/count_not_finite_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe CountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); y_desc->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe CountNotFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CountNotFiniteOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).PartialSum(user_op::OpArg("y", 0)).Build(); } return Maybe::Ok(); } /* static */ Maybe CountNotFiniteOp::InferDataType(user_op::InferContext* ctx) { user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); y_desc->set_data_type(DataType::kInt64); return Maybe::Ok(); } /* static */ Maybe MultiCountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); y_desc->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe MultiCountNotFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultiCountNotFiniteOp::GetSbp(user_op::SbpContext* ctx) { int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) { min_num_axes = std::min(min_num_axes, ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes()); } for (int64_t i = 0; i < min_num_axes; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build(); } return Maybe::Ok(); } /* static */ Maybe MultiCountNotFiniteOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& first_x_desc = ctx->InputTensorDesc("x", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(x_desc.data_type(), first_x_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(first_x_desc.data_type()) << ", but got " << DataType_Name(x_desc.data_type()); } user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); y_desc->set_data_type(DataType::kInt64); return Maybe::Ok(); } /*static*/ Maybe MultiCountNotFiniteOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("x") >= 1); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/ctc_loss_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe CtcLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0); const int64_t batch_size = log_probs.shape().At(1); const int64_t max_target_length = ctx->Attr("max_target_length"); if (targets.shape().NumAxes() == 2) { CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size); CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length); } CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size); CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size); CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); ctx->SetOutputShape("loss", 0, Shape({batch_size})); ctx->SetOutputShape("alpha", 0, Shape({batch_size, log_probs.shape().At(0), 2 * max_target_length + 1})); return Maybe::Ok(); } /*static*/ Maybe CtcLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CtcLossOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 .Split(user_op::OpArg("targets", 0), 0) .Split(user_op::OpArg("input_lengths", 0), 0) .Split(user_op::OpArg("target_lengths", 0), 0) .Split(user_op::OpArg("loss", 0), 0) .Split(user_op::OpArg("alpha", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe CtcLossOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("loss", 0, ctx->InputDType("log_probs", 0)); ctx->SetOutputDType("alpha", 0, ctx->InputDType("log_probs", 0)); return Maybe::Ok(); } /* static */ Maybe CtcLossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0); const int64_t batch_size = log_probs.shape().At(1); const int64_t max_target_length = ctx->Attr("max_target_length"); if (targets.shape().NumAxes() == 2) { CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size); CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length); } CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size); CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size); CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); ctx->SetOutputShape("grad", 0, log_probs.shape()); return Maybe::Ok(); } /*static*/ Maybe CtcLossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CtcLossGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("grad_out", 0), 0) .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 .Split(user_op::OpArg("targets", 0), 0) .Split(user_op::OpArg("input_lengths", 0), 0) .Split(user_op::OpArg("target_lengths", 0), 0) .Split(user_op::OpArg("loss", 0), 0) .Split(user_op::OpArg("alpha", 0), 0) .Split(user_op::OpArg("grad", 0), 1) .Build(); return Maybe::Ok(); } /* static 
*/ Maybe CtcLossGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("grad", 0, ctx->InputDType("log_probs", 0)); return Maybe::Ok(); } /* static */ Maybe CtcGreedyDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); const int64_t batch_size = log_probs.shape().At(1); CHECK_EQ_OR_RETURN(batch_size, input_lengths.shape().At(0)); ctx->SetOutputShape("decoded", 0, Shape({batch_size, log_probs.shape().At(0)})); ctx->SetOutputShape("neg_sum_logits", 0, Shape({batch_size, 1})); return Maybe::Ok(); } /*static*/ Maybe CtcGreedyDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CtcGreedyDecoderOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 .Split(user_op::OpArg("input_lengths", 0), 0) .Split(user_op::OpArg("decoded", 0), 0) .Split(user_op::OpArg("neg_sum_logits", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe CtcGreedyDecoderOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("decoded", 0, ctx->InputDType("input_lengths", 0)); ctx->SetOutputDType("neg_sum_logits", 0, ctx->InputDType("log_probs", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) { const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const int64_t bias_size = weight_desc.shape().At(1); Shape d_grad_shape({dy_desc.shape().At(0), weight_desc.shape().At(1)}); ctx->SetOutputShape("d_grad", 0, d_grad_shape); ctx->SetOutputShape("d_bias", 0, Shape({bias_size})); return Maybe::Ok(); } Maybe InferDataType4MatmulBackward(user_op::InferContext* ctx) { const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(weight_desc.data_type(), dy_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(dy_desc.data_type()) << ", but got " << DataType_Name(weight_desc.data_type()); user_op::TensorDesc* d_grad_desc = ctx->MutOutputTensorDesc("d_grad", 0); user_op::TensorDesc* d_bias_desc = ctx->MutOutputTensorDesc("d_bias", 0); d_grad_desc->set_data_type(dy_desc.data_type()); d_bias_desc->set_data_type(dy_desc.data_type()); return Maybe::Ok(); } } // namespace /* static */ Maybe CublasBiasAddReluMatmulGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4FusedMatmulBackward(ctx); } /*static*/ Maybe CublasBiasAddReluMatmulGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CublasBiasAddReluMatmulGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("aux", 0), 0) .Split(user_op::OpArg("d_grad", 0), 0) .PartialSum(user_op::OpArg("d_bias", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe CublasBiasAddReluMatmulGradOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4MatmulBackward(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4MatmulBiasAddBackward(user_op::InferContext* ctx) { /* x (m, k) w (n, k) need transpose bias (n, ) y (m, n) w_grad = dy_transpose matmul x */ const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const int64_t bias_size = dy_desc.shape().At(1); Shape w_grad_shape({dy_desc.shape().At(1), x_desc.shape().At(1)}); ctx->SetOutputShape("w_grad", 0, w_grad_shape); ctx->SetOutputShape("b_grad", 0, Shape({bias_size})); return Maybe::Ok(); } Maybe InferDataType4MatmulBiasAddBackward(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(x_desc.data_type(), dy_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(dy_desc.data_type()) << ", but got " << DataType_Name(x_desc.data_type()); user_op::TensorDesc* w_grad_desc = ctx->MutOutputTensorDesc("w_grad", 0); user_op::TensorDesc* b_grad_desc = ctx->MutOutputTensorDesc("b_grad", 0); w_grad_desc->set_data_type(dy_desc.data_type()); b_grad_desc->set_data_type(dy_desc.data_type()); return Maybe::Ok(); } } // namespace /* static */ Maybe CublasMatmulBiasAddGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4MatmulBiasAddBackward(ctx); } /*static*/ Maybe CublasMatmulBiasAddGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CublasMatmulBiasAddGradOp::GetSbp(user_op::SbpContext* ctx) { /* dy need transpose. assume dy(m, n), x(m, k), dbias=(n, 1) dw = dy_T matmul x */ ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 1) .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("w_grad", 0), 0) .Split(user_op::OpArg("b_grad", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .PartialSum(user_op::OpArg("w_grad", 0)) .PartialSum(user_op::OpArg("b_grad", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe CublasMatmulBiasAddGradOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4MatmulBiasAddBackward(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cublas_fused_mlp_grad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) { const int64_t weight_num = ctx->input_size("weights"); const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); for (int idx = weight_num - 1; idx >= 0; idx--) { const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weights", idx); ctx->SetOutputShape("d_weights", idx, weight_desc.shape()); ctx->SetOutputShape("d_biases", idx, Shape({weight_desc.shape().At(0)})); } ctx->SetOutputShape("d_x", 0, x_desc.shape()); return Maybe::Ok(); } Maybe InferDataType4MatmulBackward(user_op::InferContext* ctx) { const int64_t weight_num = ctx->input_size("weights"); const int64_t dweight_num = ctx->output_size("d_weights"); CHECK_EQ(weight_num, dweight_num) << "The number of weights and d_weights should be equal. "; const int64_t dbias_size = ctx->output_size("d_biases"); CHECK_EQ(weight_num, dbias_size) << "The number of d_biases should be equal to weight_num. " "Because last layer's bias_grad is computed by ReduceSum. 
"; const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); for (int idx = weight_num - 1; idx >= 0; idx--) { ctx->SetOutputDType("d_weights", idx, dy_desc.data_type()); ctx->SetOutputDType("d_biases", idx, dy_desc.data_type()); } ctx->SetOutputDType("d_x", 0, dy_desc.data_type()); return Maybe::Ok(); } } // namespace /* static */ Maybe CublasFusedMLPGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4FusedMatmulBackward(ctx); } /*static*/ Maybe CublasFusedMLPGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CublasFusedMLPGradOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0); builder.Split(user_op::OpArg("dy", 0), 0); for (int i = 0; i < ctx->user_op_conf().input_size("weights"); ++i) { builder.Broadcast(user_op::OpArg("weights", i)); } for (int i = 0; i < ctx->user_op_conf().input_size("cublas_aux"); ++i) { builder.Split(user_op::OpArg("cublas_aux", i), 0); } for (int i = 0; i < ctx->user_op_conf().input_size("hidden"); ++i) { builder.Split(user_op::OpArg("hidden", i), 0); } builder.Split(user_op::OpArg("d_x", 0), 0); if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { // FusedMLPGradKernel do allreduce for dbias and dweight, so here convert from PartialSum to // Broadcast. for (int i = 0; i < ctx->user_op_conf().output_size("d_biases"); ++i) { builder.Broadcast(user_op::OpArg("d_biases", i)); } for (int i = 0; i < ctx->user_op_conf().output_size("d_weights"); ++i) { builder.Broadcast(user_op::OpArg("d_weights", i)); } } else { for (int i = 0; i < ctx->user_op_conf().output_size("d_biases"); ++i) { builder.PartialSum(user_op::OpArg("d_biases", i)); } for (int i = 0; i < ctx->user_op_conf().output_size("d_weights"); ++i) { builder.PartialSum(user_op::OpArg("d_weights", i)); } } builder.Build(); return Maybe::Ok(); } /* static */ Maybe CublasFusedMLPGradOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4MatmulBackward(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cublas_fused_mlp_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { constexpr int32_t kAuxReluLdAlignRequirement = 128; long AlignReluAuxLd(long aux_ld) { /* ReLu bit-mask matrix leading dimension in elements. Must be divisible by 128 and be no less than the number of rows in the output matrix. 
*/ long old_aux_ld = aux_ld; return ((old_aux_ld + kAuxReluLdAlignRequirement - 1) / kAuxReluLdAlignRequirement) * kAuxReluLdAlignRequirement; } Maybe InferTensorDesc4FusedMatmul(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); int32_t weight_size = ctx->input_size("weights"); int32_t bias_size = ctx->input_size("biases"); CHECK_EQ_OR_RETURN(weight_size, bias_size); /* A: (m, k) B: (n, k) need transpose C: (m, n) */ int64_t m = 0, n = 0, k = 0, cublas_aux_ld = 0; m = x_desc.shape().At(0); k = x_desc.shape().At(1); for (int32_t idx = 0; idx < weight_size; idx++) { // skip first input weight. const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weights", idx); const user_op::TensorDesc& bias_desc = ctx->InputTensorDesc("biases", idx); CHECK_EQ_OR_RETURN(weight_desc.shape().NumAxes(), 2); CHECK_EQ_OR_RETURN(bias_desc.shape().NumAxes(), 1); n = weight_desc.shape().At(0); CHECK_EQ_OR_RETURN(bias_desc.shape().At(0), n); CHECK_EQ_OR_RETURN(weight_desc.shape().At(1), k); cublas_aux_ld = n; // Set Middle result shape. long cublas_aligned_aux_ld = AlignReluAuxLd(cublas_aux_ld); int64_t aux_size = cublas_aligned_aux_ld / 32; // Cause we use int32_t as dtype ctx->SetOutputShape("cublas_aux", idx, Shape({m, aux_size})); ctx->SetOutputShape("hidden", idx, Shape({m, n})); // Set for next layer. k = n; } ctx->SetOutputShape("out", 0, Shape({m, n})); return Maybe::Ok(); } Maybe InferDataType4Matmul(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("x", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(first_in_desc.data_type()) << ", but got " << DataType_Name(in_desc.data_type()); } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(first_in_desc.data_type()); for (int32_t i = 0; i < ctx->output_size("hidden"); i++) { user_op::TensorDesc* hidden_desc = ctx->MutOutputTensorDesc("hidden", i); hidden_desc->set_data_type(first_in_desc.data_type()); } for (int32_t i = 0; i < ctx->output_size("cublas_aux"); i++) { user_op::TensorDesc* aux_desc = ctx->MutOutputTensorDesc("cublas_aux", i); aux_desc->set_data_type(DataType::kInt32); } return Maybe::Ok(); } } // namespace /* static */ Maybe CublasFusedMLPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4FusedMatmul(ctx); } /*static*/ Maybe CublasFusedMLPOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CublasFusedMLPOp::GetSbp(user_op::SbpContext* ctx) { // Currently Only support S0 B B B B ... 
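// That is, "x", the "cublas_aux"/"hidden" buffers, and "out" are split along the
// batch axis (S(0)) while every weight and bias is broadcast (B), i.e. plain data
// parallelism. Illustrative example (sizes assumed): with 2 ranks and x = (128, 512),
// each rank pushes a (64, 512) slice through the full MLP.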
S0 auto builder = ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0); for (int i = 0; i < ctx->user_op_conf().input_size("weights"); ++i) { builder.Broadcast(user_op::OpArg("weights", i)); } for (int i = 0; i < ctx->user_op_conf().input_size("biases"); ++i) { builder.Broadcast(user_op::OpArg("biases", i)); } for (int i = 0; i < ctx->user_op_conf().output_size("cublas_aux"); ++i) { builder.Split(user_op::OpArg("cublas_aux", i), 0); } for (int i = 0; i < ctx->user_op_conf().output_size("hidden"); ++i) { builder.Split(user_op::OpArg("hidden", i), 0); } builder.Split(user_op::OpArg("out", 0), 0); builder.Build(); return Maybe::Ok(); } /* static */ Maybe CublasFusedMLPOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Matmul(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/cum_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe CumsumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } Maybe CumsumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } Maybe CumsumOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); auto dim = ctx->Attr("dim"); for (auto i = dim + 1; i < in_tensor_desc.shape().NumAxes(); i++) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } Maybe CumsumOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe CumProdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } Maybe CumProdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } Maybe CumProdOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); auto dim = ctx->Attr("dim"); for (auto i = dim + 1; i < in_tensor_desc.shape().NumAxes(); i++) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } Maybe CumProdOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe CumProdGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("dy", 0)); return Maybe::Ok(); } Maybe CumProdGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } Maybe CumProdGradOp::GetSbp(user_op::SbpContext* ctx) { const auto& dy_tensor_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); auto dim = ctx->Attr("dim"); for (auto i = dim + 1; i < 
dy_tensor_desc.shape().NumAxes(); i++) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("output", 0), i) .Split(user_op::OpArg("input", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } Maybe CumProdGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/data_shuffle_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/embedding/embedding_manager.h" #include "oneflow/core/operator/operator.h" namespace oneflow { /* static */ Maybe UniqueKeyValuePairOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& keys_shape = ctx->InputShape("keys", 0); const int32_t num_tables = ctx->Attr("num_tables"); CHECK_GE_OR_RETURN(num_tables, 1) << "num_tables must greater than 1, but get " << num_tables; if (ctx->has_input("values", 0)) { const Shape& values_shape = ctx->InputShape("values", 0); CHECK_EQ_OR_RETURN(keys_shape, values_shape) << "keys_shape must equal to values_shape"; } else { if (num_tables > 1) { CHECK_EQ_OR_RETURN(keys_shape.NumAxes(), 2); CHECK_EQ_OR_RETURN(keys_shape.At(1), num_tables) << "keys cols must equal to num_tables"; } } ctx->SetOutputShape("num_unique", 0, Shape({1})); ctx->SetOutputShape("unique_keys", 0, Shape({keys_shape.elem_cnt()})); ctx->SetOutputShape("unique_values", 0, Shape({keys_shape.elem_cnt()})); ctx->SetOutputShape("inverse_indices", 0, keys_shape); return Maybe::Ok(); } /*static*/ Maybe UniqueKeyValuePairOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe UniqueKeyValuePairOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe UniqueKeyValuePairOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("num_unique", 0, DataType::kInt32); ctx->SetOutputDType("unique_keys", 0, ctx->InputDType("keys", 0)); ctx->SetOutputDType("inverse_indices", 0, DataType::kInt32); if (ctx->has_input("values", 0)) { ctx->SetOutputDType("unique_values", 0, ctx->InputDType("values", 0)); } else { ctx->SetOutputDType("unique_values", 0, DataType::kUInt8); } return Maybe::Ok(); } /* static */ Maybe IdShuffleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& ids_shape = ctx->InputShape("ids", 0); const int32_t num_tables = ctx->Attr("num_tables"); CHECK_GE_OR_RETURN(num_tables, 1) << "num_tables must greater than 1, but get " << num_tables; if (ctx->has_input("table_ids", 0)) { const Shape& table_ids_shape = ctx->InputShape("table_ids", 0); CHECK_EQ_OR_RETURN(ids_shape, table_ids_shape) << "ids_shape must equal to table_ids_shape"; } else { if (num_tables > 1) { CHECK_EQ_OR_RETURN(ids_shape.NumAxes(), 2); CHECK_EQ_OR_RETURN(ids_shape.At(1), 
num_tables) << "ids cols must equal to num_tables"; } } const int64_t num_ids = ids_shape.elem_cnt(); const int64_t parallel_num = ctx->parallel_num(); ctx->SetOutputShape("num_unique_matrix", 0, Shape({parallel_num * parallel_num})); ctx->SetOutputShape("inverse_unique_partition_indices", 0, ids_shape); ctx->SetOutputShape("cur_rank_num_unique", 0, Shape({1})); ctx->SetOutputShape("cur_rank_unique_ids", 0, Shape({num_ids * parallel_num})); ctx->SetOutputShape("cur_rank_inverse_indices", 0, Shape({num_ids * parallel_num})); ctx->SetOutputShape("cur_rank_unique_table_ids", 0, Shape({num_ids * parallel_num})); return Maybe::Ok(); } /* static */ Maybe IdShuffleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IdShuffleOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(ctx->inputs(), 0) .Split(ctx->outputs(), 0) .Broadcast(user_op::OpArg("num_unique_matrix", 0)) .Broadcast(user_op::OpArg("cur_rank_num_unique", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe IdShuffleOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("num_unique_matrix", 0, DataType::kUInt32); ctx->SetOutputDType("inverse_unique_partition_indices", 0, DataType::kUInt32); ctx->SetOutputDType("cur_rank_num_unique", 0, DataType::kUInt32); ctx->SetOutputDType("cur_rank_unique_ids", 0, ctx->InputDType("ids", 0)); ctx->SetOutputDType("cur_rank_inverse_indices", 0, DataType::kUInt32); if (ctx->has_input("table_ids", 0)) { ctx->SetOutputDType("cur_rank_unique_table_ids", 0, ctx->InputDType("table_ids", 0)); } else { ctx->SetOutputDType("cur_rank_unique_table_ids", 0, DataType::kUInt8); } return Maybe::Ok(); } /* static */ Maybe EmbeddingShuffleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& cur_rank_embeddings_shape = ctx->InputShape("cur_rank_embeddings", 0); const Shape& num_unique_matrix_shape = ctx->InputShape("num_unique_matrix", 0); const Shape& cur_rank_inverse_indices_shape = ctx->InputShape("cur_rank_inverse_indices", 0); const Shape& inverse_unique_partition_indices_shape = ctx->InputShape("inverse_unique_partition_indices", 0); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t num_ids = inverse_unique_partition_indices_shape.elem_cnt(); const int64_t parallel_num = ctx->parallel_num(); if (embedding::UseDynamicMemoryAllocation()) { CHECK_EQ_OR_RETURN(cur_rank_embeddings_shape.elem_cnt(), 1) << "if use dynamic memory allocation, cur_rank_embeddings elem_cnt should be 1."; } else { CHECK_EQ_OR_RETURN(cur_rank_embeddings_shape.NumAxes(), 2) << "cur_rank_embeddings num_axes should be 2."; CHECK_EQ_OR_RETURN(cur_rank_embeddings_shape.At(0), parallel_num * num_ids) << " got " << cur_rank_embeddings_shape.At(0) << " and " << parallel_num * num_ids; CHECK_EQ_OR_RETURN(embedding_size, cur_rank_embeddings_shape.At(1)) << " got " << embedding_size << " and " << cur_rank_embeddings_shape.At(1); } CHECK_EQ_OR_RETURN(num_unique_matrix_shape.elem_cnt(), parallel_num * parallel_num); CHECK_EQ_OR_RETURN(cur_rank_inverse_indices_shape.elem_cnt(), parallel_num * num_ids); DimVector out_dim_vec = inverse_unique_partition_indices_shape.dim_vec(); out_dim_vec.push_back(embedding_size); ctx->SetOutputShape("embeddings", 0, Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe EmbeddingShuffleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EmbeddingShuffleOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder() 
.Split(ctx->inputs(), 0) .Broadcast(user_op::OpArg("num_unique_matrix", 0)) .Split(ctx->outputs(), 0); if (embedding::UseDynamicMemoryAllocation()) { builder.Broadcast(user_op::OpArg("cur_rank_embeddings", 0)).Build(); } else { builder.Split(user_op::OpArg("cur_rank_embeddings", 0), 0).Build(); } return Maybe::Ok(); } /* static */ Maybe EmbeddingShuffleOp::InferDataType(user_op::InferContext* ctx) { CHECK_OR_RETURN(ctx->InputDType("num_unique_matrix", 0) == DataType::kUInt32); CHECK_OR_RETURN(ctx->InputDType("cur_rank_inverse_indices", 0) == DataType::kUInt32); CHECK_OR_RETURN(ctx->InputDType("inverse_unique_partition_indices", 0) == DataType::kUInt32); ctx->SetOutputDType("embeddings", 0, ctx->InputDType("cur_rank_embeddings", 0)); return Maybe::Ok(); } /* static */ Maybe EmbeddingGradientShuffleOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& embedding_grad_shape = ctx->InputShape("embedding_grad", 0); const Shape& num_unique_matrix_shape = ctx->InputShape("num_unique_matrix", 0); const Shape& cur_rank_inverse_indices_shape = ctx->InputShape("cur_rank_inverse_indices", 0); const Shape& inverse_unique_partition_indices_shape = ctx->InputShape("inverse_unique_partition_indices", 0); const int64_t num_ids = inverse_unique_partition_indices_shape.elem_cnt(); const int64_t parallel_num = ctx->parallel_num(); CHECK_EQ_OR_RETURN(embedding_grad_shape.elem_cnt() % num_ids, 0); const int64_t embedding_size = embedding_grad_shape.elem_cnt() / num_ids; CHECK_EQ_OR_RETURN(num_unique_matrix_shape.elem_cnt(), parallel_num * parallel_num); CHECK_EQ_OR_RETURN(cur_rank_inverse_indices_shape.elem_cnt(), parallel_num * num_ids); DimVector out_dim_vec = cur_rank_inverse_indices_shape.dim_vec(); out_dim_vec.push_back(embedding_size); ctx->SetOutputShape("cur_rank_unique_embedding_grad", 0, Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe EmbeddingGradientShuffleOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EmbeddingGradientShuffleOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(ctx->inputs(), 0) .Broadcast(user_op::OpArg("num_unique_matrix", 0)) .Split(ctx->outputs(), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe EmbeddingGradientShuffleOp::InferDataType(user_op::InferContext* ctx) { CHECK_OR_RETURN(ctx->InputDType("num_unique_matrix", 0) == DataType::kUInt32); CHECK_OR_RETURN(ctx->InputDType("cur_rank_inverse_indices", 0) == DataType::kUInt32); CHECK_OR_RETURN(ctx->InputDType("inverse_unique_partition_indices", 0) == DataType::kUInt32); ctx->SetOutputDType("cur_rank_unique_embedding_grad", 0, ctx->InputDType("embedding_grad", 0)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const Shape& indices_shape = ctx->InputShape("indices", 0); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t num_ids = indices_shape.elem_cnt(); const int64_t parallel_num = ctx->parallel_num(); if (embedding::UseDynamicMemoryAllocation()) { CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), 1) << "if use dynamic memory allocation, in elem_cnt should be 1."; } else { CHECK_EQ_OR_RETURN(in_shape.NumAxes(), 2) << "in num_axes should be 2."; CHECK_EQ_OR_RETURN(in_shape.At(0), parallel_num * num_ids) << " got " << in_shape.At(0) << " and " << parallel_num * num_ids; CHECK_EQ_OR_RETURN(embedding_size, in_shape.At(1)) << " got " << embedding_size << " and " << in_shape.At(1); } DimVector 
out_dim_vec = indices_shape.dim_vec(); out_dim_vec.push_back(embedding_size); ctx->SetOutputShape("out", 0, Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingGatherOp::GetSbp(user_op::SbpContext* ctx) { // Only used in parallel_num = 1. return Maybe::Ok(); } /* static */ Maybe OneEmbeddingGatherOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } REGISTER_USER_OP_SAME_OUTPUT_BLOB_REGST_NUM_WITH_FUNC("id_shuffle", []() { if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION", false)) { return 2; } else { return 1; } }); } // namespace oneflow ================================================ FILE: oneflow/user/ops/deconv_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { template Maybe InferTensorDesc4DeConv(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(NDims + 2, in.shape().NumAxes()); const std::string& data_format = ctx->Attr("data_format"); const auto& kernel_size = ctx->Attr>("kernel_size"); CHECK_EQ_OR_RETURN(NDims, kernel_size.size()); const int32_t filters = ctx->Attr("filters"); size_t idx_offset = IdxOffset(data_format); int32_t groups = ctx->Attr("groups"); { const auto& dilation_rate = ctx->Attr>("dilation_rate"); const auto& output_padding = ctx->Attr>("output_padding"); const auto& strides = ctx->Attr>("strides"); const auto& padding_before = ctx->Attr>("padding_before"); CHECK_EQ_OR_RETURN(NDims, dilation_rate.size()); CHECK_EQ_OR_RETURN(NDims, strides.size()); CHECK_EQ_OR_RETURN(NDims, output_padding.size()); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); DimVector out_shape(NDims + 2); out_shape.at(0) = in.shape().At(0); const size_t c_dim = data_format == "channels_first" ? 1 : NDims + 1; out_shape.at(c_dim) = filters; for (int32_t i = 0; i < NDims; ++i) { int32_t effective_filter_size = (kernel_size.at(i) - 1) * dilation_rate.at(i) + 1; out_shape.at(idx_offset + i) = (in.shape().At(idx_offset + i) - 1) * strides.at(i) - 2 * padding_before.at(i) + output_padding.at(i) + effective_filter_size; } if (in.shape().At(0) != 0) { for (int i = 0; i < out_shape.size(); i++) { CHECK_GT_OR_RETURN(out_shape[i], 0) << "RuntimeError: Given input size per channel: (" << Shape(in.shape()) << "). Calculated output size per channel: (" << Shape(out_shape) << "). 
Output size is too small"; } } out->set_is_dynamic(in.is_dynamic()); out->set_shape(Shape(out_shape)); } { DimVector weight_shape(in.shape().dim_vec()); if (data_format == "channels_first") { weight_shape.at(0) = in.shape().At(1); weight_shape.at(1) = filters / groups; } else if (data_format == "channels_last") { weight_shape.at(0) = in.shape().At(NDims + 1); weight_shape.at(NDims + 1) = filters / groups; } else { UNIMPLEMENTED_THEN_RETURN(); } for (size_t i = 0; i < NDims; ++i) { weight_shape.at(idx_offset + i) = kernel_size.at(i); } const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight.shape(), Shape(weight_shape)); } return Maybe::Ok(); } Maybe InferDataType_(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } Maybe GetSbpSignatures4DeConv(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("out", 0), 0) .Build(); return Maybe::Ok(); } template Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { bool is_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; const std::string& data_format = conf.attr("data_format"); if (!(data_format == "channels_first" || data_format == "channels_last")) { err << " data_format:" << data_format; is_checked = false; } if (NDims != 0) { const auto& padding_before = conf.attr>("padding_before"); if (padding_before.size() != NDims) { err << " padding_before: number of element is " << padding_before.size(); is_checked = false; } const auto& kernel_size = conf.attr>("kernel_size"); if (kernel_size.size() != NDims) { err << " kernel_size: number of element is " << kernel_size.size(); is_checked = false; } const auto& strides = conf.attr>("strides"); if (strides.size() != NDims) { err << " strides: number of element is " << strides.size(); is_checked = false; } const auto& dilation_rate = conf.attr>("dilation_rate"); if (dilation_rate.size() != NDims) { err << " dilation_rate: number of element is " << dilation_rate.size(); is_checked = false; } } if (is_checked) { return Maybe::Ok(); } else { return oneflow::Error::CheckFailedError() << err.str(); } } } // namespace /* static */ Maybe Deconv1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4DeConv<1>(ctx); } /*static*/ Maybe Deconv1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe Deconv1DOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpSignatures4DeConv(ctx); } /* static */ Maybe Deconv1DOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_<1>(def, conf); } /* static */ Maybe Deconv1DOp::InferDataType(user_op::InferContext* ctx) { return InferDataType_(ctx); } /* static */ Maybe Deconv2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4DeConv<2>(ctx); } /*static*/ Maybe Deconv2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe Deconv2DOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpSignatures4DeConv(ctx); } /* static */ Maybe Deconv2DOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_<2>(def, conf); } /* static */ Maybe Deconv2DOp::InferDataType(user_op::InferContext* ctx) { return 
InferDataType_(ctx); } /* static */ Maybe Deconv3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4DeConv<3>(ctx); } /*static*/ Maybe Deconv3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe Deconv3DOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpSignatures4DeConv(ctx); } /* static */ Maybe Deconv3DOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return CheckAttr_<3>(def, conf); } /* static */ Maybe Deconv3DOp::InferDataType(user_op::InferContext* ctx) { return InferDataType_(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/deform_conv_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe DeformConv2dOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); const Shape& offset_shape = ctx->InputShape("offset", 0); const Shape& mask_shape = ctx->InputShape("mask", 0); const int32_t kW = weight_shape.at(3); const int32_t kH = weight_shape.at(2); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); const int32_t deformable_group = ctx->Attr("offset_groups"); const bool use_mask = ctx->Attr("use_mask"); bool has_bias = ctx->has_input("bias", 0); if (has_bias) { const Shape& bias_shape = ctx->InputShape("bias", 0); std::cout << "bias_shape:" << bias_shape.ToString() << std::endl; CHECK_EQ_OR_RETURN(bias_shape.At(0), weight_shape.At(0)); } CHECK_OR_RETURN(dW > 0 && dH > 0) << Error::RuntimeError() << "The stride must be greater than 0,but got " << dW << " and " << dH; CHECK_OR_RETURN(kW > 0 && kH > 0) << Error::RuntimeError() << "The weight must be greater than 0,but got " << kW << " and " << kH; CHECK_OR_RETURN(padW >= 0 && padH >= 0) << Error::RuntimeError() << "The pad must be greater than or equal to 0,but got " << padW << " and " << padH; CHECK_OR_RETURN(dilationW > 0 && dilationH > 0) << Error::RuntimeError() << "The dilation must be greater than 0,but got " << dilationH << " and " << dilationW; CHECK_EQ_OR_RETURN(input_shape.NumAxes(), 4); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(weight_shape.NumAxes(), 4); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(offset_shape.NumAxes(), 4); // NOLINT(maybe-need-error-msg) if (use_mask) { CHECK_EQ_OR_RETURN(mask_shape.NumAxes(), 4); } // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(weight_shape.At(2), kH); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(weight_shape.At(3), kW); // NOLINT(maybe-need-error-msg) 
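// Illustrative note (not part of the original source): the offset/mask checks and the
// output-size checks below follow the usual convolution arithmetic,
//   out = (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1,
// e.g. in = 56, pad = 1, dilation = 1, k = 3, stride = 1 gives (56 + 2 - 3) / 1 + 1 = 56,
// and the offset input is expected to carry 2 * kH * kW values (an x/y pair per kernel
// tap) for each of the `offset_groups` deformable groups.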
CHECK_EQ_OR_RETURN(offset_shape.At(1), deformable_group * 2 * kW * kH) << Error::RuntimeError() << "offset.shape[1] is not valid: got: " << offset_shape.At(1) << " ,expected: " << deformable_group * 2 * kW * kH; if (use_mask) { CHECK_EQ_OR_RETURN(mask_shape.At(1), deformable_group * kW * kH) << Error::RuntimeError() << "mask.shape[1] is not valid: got: " << mask_shape.At(1) << " expected: " << deformable_group * kW * kH; } CHECK_EQ_OR_RETURN(offset_shape.At(0), input_shape.At(0)) << Error::RuntimeError() << "invalid batch size of offset:got: " << offset_shape.At(0) << " ,expected: " << input_shape.At(0); int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; CHECK_OR_RETURN(outputWidth > 0 && outputHeight > 0) << Error::RuntimeError() << "Calculated output size too small - out_h: " << outputHeight << " ,out_w: " << outputWidth; CHECK_OR_RETURN(offset_shape.At(2) == outputHeight && offset_shape.At(3) == outputWidth) << Error::RuntimeError() << "invalid offset output dims: got ( " << offset_shape.At(2) << ", " << offset_shape.At(3) << ")" << ",expected: " << "(" << outputHeight << ", " << outputWidth << ")"; if (use_mask) { CHECK_OR_RETURN(mask_shape.At(2) == outputHeight && mask_shape.At(3) == outputWidth) << Error::RuntimeError() << "invalid mask output dims: got ( " << mask_shape.At(2) << ", " << mask_shape.At(3) << ")" << ",expected: " << "(" << outputHeight << ", " << outputWidth << ")"; } ctx->SetOutputShape("output", 0, Shape({input_shape.At(0), weight_shape.At(0), outputHeight, outputWidth})); ctx->SetOutputIsDynamic("output", 0, ctx->InputIsDynamic("input", 0)); return Maybe::Ok(); } /* static */ Maybe DeformConv2dInputGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); const Shape& offset_shape = ctx->InputShape("offset", 0); const Shape& output_grad_shape = ctx->InputShape("output_grad", 0); const int32_t kW = weight_shape.at(3); const int32_t kH = weight_shape.at(2); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); const bool use_mask = ctx->Attr("use_mask"); CHECK_EQ_OR_RETURN(weight_shape.NumAxes(), 4); int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; CHECK_EQ_OR_RETURN(output_grad_shape.At(2), outputHeight); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(output_grad_shape.At(3), outputWidth); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(output_grad_shape.At(1), weight_shape.At(0)); // NOLINT(maybe-need-error-msg) ctx->SetOutputShape("input_grad", 0, ctx->InputShape("input", 0)); ctx->SetOutputShape("offset_grad", 0, ctx->InputShape("offset", 0)); ctx->SetOutputIsDynamic("input_grad", 0, ctx->InputIsDynamic("input", 0)); ctx->SetOutputIsDynamic("offset_grad", 0, false); if (use_mask) { ctx->SetOutputShape("mask_grad", 0, Shape({offset_shape.At(0), offset_shape.At(1) / 2, offset_shape.At(2), offset_shape.At(3)})); ctx->SetOutputIsDynamic("mask_grad", 0, false); } return Maybe::Ok(); } Maybe DeformConv2dParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { 
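// Illustrative summary (not part of the original source): this hook recomputes the
// expected output height/width from the input shape and the conv attributes, checks
// that output_grad matches them, and gives weight_grad the same shape as weight,
// since the gradient w.r.t. a filter has the filter's own layout.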
const Shape& input_shape = ctx->InputShape("input", 0); const Shape& output_grad_shape = ctx->InputShape("output_grad", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); const int32_t kW = weight_shape.at(3); const int32_t kH = weight_shape.at(2); const int32_t dW = ctx->Attr("stride_w"); const int32_t dH = ctx->Attr("stride_h"); const int32_t padW = ctx->Attr("pad_w"); const int32_t padH = ctx->Attr("pad_h"); const int32_t dilationW = ctx->Attr("dilation_w"); const int32_t dilationH = ctx->Attr("dilation_h"); int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; CHECK_EQ_OR_RETURN(output_grad_shape.At(2), outputHeight); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(output_grad_shape.At(3), outputWidth); // NOLINT(maybe-need-error-msg) ctx->SetOutputShape("weight_grad", 0, ctx->InputShape("weight", 0)); ctx->SetOutputIsDynamic("weight_grad", 0, false); return Maybe::Ok(); } /* static */ Maybe DeformConv2dOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DeformConv2dInputGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DeformConv2dParamGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DeformConv2dOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } /* static */ Maybe DeformConv2dInputGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("input_grad", 0, ctx->InputDType("input", 0)); ctx->SetOutputDType("offset_grad", 0, ctx->InputDType("offset", 0)); const bool use_mask = ctx->Attr("use_mask"); if (use_mask) { ctx->SetOutputDType("mask_grad", 0, ctx->InputDType("mask", 0)); } return Maybe::Ok(); } /* static */ Maybe DeformConv2dParamGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("weight_grad", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } /* static */ Maybe DeformConv2dOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("offset", 0), 0) .Split(user_op::OpArg("mask", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("output", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe DeformConv2dInputGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("output_grad", 0), 0) .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("offset", 0), 0) .Split(user_op::OpArg("mask", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("input_grad", 0), 0) .Split(user_op::OpArg("offset_grad", 0), 0) .Split(user_op::OpArg("mask_grad", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe DeformConv2dParamGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("output_grad", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("mask", 0), 0) .Split(user_op::OpArg("offset", 0), 0) .PartialSum(user_op::OpArg("weight_grad", 0)) .Build(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/depend_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> DependOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0));
  ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> DependOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> DependOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("in", 0), i)
        .Broadcast(user_op::OpArg("depend_tensor", 0))
        .Split(user_op::OpArg("out", 0), i)
        .Build();
  }
  ctx->NewBuilder()
      .PartialSum(user_op::OpArg("in", 0))
      .Broadcast(user_op::OpArg("depend_tensor", 0))
      .PartialSum(user_op::OpArg("out", 0))
      .Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> DependOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/user/ops/det_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/*static*/ Maybe<void> DetOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const auto& x_desc = ctx->InputTensorDesc("x", 0);
  auto x_shape = x_desc.shape();
  ctx->SetOutputShape("y", 0, Shape(x_shape.begin(), x_shape.end() - 2));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> DetOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/*static*/ Maybe<void> DetOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
  FOR_RANGE(int64_t, i, 0, x.shape().NumAxes() - 2) {
    ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build();
  }
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> DetOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0));
  return Maybe<void>::Ok();
}

}  // namespace oneflow
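A minimal standalone sketch (illustrative only, not a file in the repository) of the shape rule DetOp uses above: the determinant consumes the trailing two (square matrix) axes, so the output keeps only the leading batch axes, and GetSbp can split on any of those batch axes because every shard still holds whole matrices.

#include <cassert>
#include <vector>

// Output shape of a batched determinant: drop the last two axes of the input.
std::vector<long> DetOutputShape(const std::vector<long>& x_shape) {
  assert(x_shape.size() >= 2);  // needs at least one matrix
  return std::vector<long>(x_shape.begin(), x_shape.end() - 2);
}

// Example: DetOutputShape({4, 3, 5, 5}) == {4, 3}; splitting the (4, 3, 5, 5)
// input on axis 0 or 1 leaves complete 5x5 matrices on every shard.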
================================================
FILE: oneflow/user/ops/diag_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> DiagOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  const int32_t diagonal = ctx->Attr<int32_t>("diagonal");
  const ShapeView& in_shape = in.shape();
  const int32_t in_dim = in_shape.NumAxes();
  CHECK_GE_OR_RETURN(in_dim, 1);
  CHECK_LE_OR_RETURN(in_dim, 2);
  DimVector out_dim_vec = {0};
  if (in_dim == 1) {
    int32_t out_tensor_size = in_shape.At(0) + std::abs(diagonal);
    out_dim_vec[0] = out_tensor_size;
    out_dim_vec.emplace_back(out_tensor_size);
  } else {
    if (diagonal >= 0) {
      out_dim_vec[0] = std::min(in_shape.At(0), in_shape.At(1) - diagonal);
    } else {
      out_dim_vec[0] = std::min(in_shape.At(0) + diagonal, in_shape.At(1));
    }
    // For 0-size Tensor.
    CHECK_GE_OR_RETURN(out_dim_vec[0], 0);  // NOLINT
  }
  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0);
  out_desc->set_is_dynamic(false);
  out_desc->set_shape(Shape(out_dim_vec));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> DiagOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> DiagOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> DiagOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> DiagGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  const Shape& in_shape = in.shape();
  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0);
  dx_desc->set_shape(Shape(in_shape.dim_vec()));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> DiagGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> DiagGradOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> DiagGradOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0));
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/user/ops/diagonal_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe DiagonalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); const int32_t offset = ctx->Attr("offset"); const ShapeView& in_shape = in.shape(); const int32_t in_dim = in_shape.NumAxes(); CHECK_GE_OR_RETURN(in_dim, 2); DimVector out_dim_vec = {}; FOR_RANGE(int32_t, index, 2, in_dim) { out_dim_vec.push_back(in_shape.At(index)); } int32_t last_dim = 0; if (offset >= 0) { last_dim = std::min(in_shape.At(0), in_shape.At(1) - offset); } else { last_dim = std::min(in_shape.At(0) + offset, in_shape.At(1)); } if (last_dim < 0) { last_dim = 0; } out_dim_vec.push_back(last_dim); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(false); out_desc->set_shape(Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe DiagonalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DiagonalOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe DiagonalOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe DiagonalGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); const Shape& in_shape = in.shape(); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_shape(Shape(in_shape.dim_vec())); return Maybe::Ok(); } /*static*/ Maybe DiagonalGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DiagonalGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe DiagonalGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/dim_gather_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/dim_gather_kernel_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe DimGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("input", 0); int64_t input_num_axes = in.shape().NumAxes(); // For 0-dim tensor CHECK_GE_OR_RETURN(input_num_axes, 0); // NOLINT CHECK_LE_OR_RETURN(input_num_axes, kDimGatherMaxDimCount); const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); int64_t index_num_axes = index.shape().NumAxes(); const int32_t dim = ctx->Attr("dim"); // For 0-dim tensor CHECK_GE_OR_RETURN(dim, 0); CHECK_LE_OR_RETURN(dim, input_num_axes); // NOLINT if (input_num_axes > 0) { CHECK_GE_OR_RETURN(input_num_axes, index_num_axes); } // NOLINT CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic()); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("output", 0); out->set_shape(index.shape()); return Maybe::Ok(); } /*static*/ Maybe DimGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DimGatherOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& index_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0); int64_t index_num_axes = index_tensor.shape().NumAxes(); const int32_t dim = ctx->Attr("dim"); FOR_RANGE(int64_t, i, 0, index_num_axes) { if (i != dim) { ctx->NewBuilder() .Split(user_op::OpArg("index", 0), i) .Split(user_op::OpArg("input", 0), i) .Split(user_op::OpArg("output", 0), i) .Build(); } else if (i == dim) { ctx->NewBuilder() .Broadcast(user_op::OpArg("input", 0)) .Split(user_op::OpArg("index", 0), i) .Split(user_op::OpArg("output", 0), i) .Build(); } } ctx->NewBuilder() .PartialSum(user_op::OpArg("input", 0)) .Broadcast(user_op::OpArg("index", 0)) .PartialSum(user_op::OpArg("output", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe DimGatherOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe DimGatherOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); CHECK_OR_RETURN(IsIndexDataType(index.data_type())); const user_op::TensorDesc& in = ctx->InputTensorDesc("input", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("output", 0); out->set_data_type(in.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/dim_scatter_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/error.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/user/kernels/dim_scatter_kernel_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc* input = ctx->has_input("input", 0) ? &ctx->InputTensorDesc("input", 0) : nullptr; const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); const user_op::TensorDesc* like = ctx->has_input("like", 0) ? &ctx->InputTensorDesc("like", 0) : nullptr; const user_op::TensorDesc& src = ctx->InputTensorDesc("src", 0); int32_t dim = ctx->Attr("dim"); // check index.numaxes == src.num_axes == input/like.numaxes int64_t src_num_axes = src.shape().NumAxes(); // For 0-dim Tensor CHECK_GE_OR_RETURN(src_num_axes, 0); // NOLINT CHECK_LE_OR_RETURN(src_num_axes, user_op::kDimGatherMaxDimCount); int64_t index_num_axes = index.shape().NumAxes(); CHECK_EQ_OR_RETURN(src_num_axes, index_num_axes); int64_t output_num_axes = 0; if (input) { output_num_axes = input->shape().NumAxes(); } else if (like) { output_num_axes = like->shape().NumAxes(); } else { OF_UNIMPLEMENTED() << "Input tensor and like tensor cannot be empty simultaneously."; } // For 0-dim Tensor if (output_num_axes != 0 && index_num_axes != 0) { CHECK_EQ_OR_RETURN(output_num_axes, index_num_axes); // NOLINT } else if (output_num_axes != 0) { CHECK_LE_OR_RETURN(output_num_axes, 1); // NOLINT } else { CHECK_LE_OR_RETURN(index_num_axes, 1); // NOLINT } // check index.shape(i) <= input/like.shape(i) FOR_RANGE(int64_t, i, 0, index_num_axes) { if (i == dim) continue; if (input) { CHECK_LE_OR_RETURN(index.shape().At(i), input->shape().At(i)); } else { CHECK_LE_OR_RETURN(index.shape().At(i), like->shape().At(i)); } } // check index.shape(i) <= src.shape(i) FOR_RANGE(int64_t, i, 0, index_num_axes) { if (i == dim) continue; CHECK_LE_OR_RETURN(index.shape().At(i), src.shape().At(i)); } user_op::TensorDesc* out = ctx->MutOutputTensorDesc("output", 0); out->set_shape(input ? 
input->shape() : like->shape()); return Maybe::Ok(); } Maybe InferScalarTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); int32_t dim = ctx->Attr("dim"); // check index.numaxes == src.num_axes == input/like.numaxes int64_t output_num_axes = input.shape().NumAxes(); int64_t index_num_axes = index.shape().NumAxes(); // For 0-dim tensor CHECK_GE_OR_RETURN(output_num_axes, index_num_axes); // NOLINT // check index.shape(i) <= input/like.shape(i) FOR_RANGE(int64_t, i, 0, index_num_axes) { if (i == dim) continue; CHECK_LE_OR_RETURN(index.shape().At(i), input.shape().At(i)); } user_op::TensorDesc* out = ctx->MutOutputTensorDesc("output", 0); out->set_shape(input.shape()); return Maybe::Ok(); } Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); return Maybe::Ok(); } Maybe InputScalarArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); return Maybe::Ok(); } void _SetSbp(user_op::SbpContext* ctx, const char* like_or_input) { const int32_t dim = ctx->Attr("dim"); const Shape& index_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0).shape(); const Shape& src_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("src", 0).shape(); const Shape& input_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(like_or_input, 0).shape(); FOR_RANGE(int64_t, i, 0, index_tensor_shape.NumAxes()) { if (i == dim) { continue; } int64_t len = index_tensor_shape.At(i); if (len == src_tensor_shape.At(i) && len == input_tensor_shape.At(i)) { ctx->NewBuilder() .Split(user_op::OpArg("index", 0), i) .Split(user_op::OpArg("src", 0), i) .Split(user_op::OpArg(like_or_input, 0), i) .Split(user_op::OpArg("output", 0), i) .Build(); } } ctx->NewBuilder() .PartialSum(user_op::OpArg("src", 0)) .Broadcast(user_op::OpArg("index", 0)) .PartialSum(user_op::OpArg("output", 0)) .PartialSum(user_op::OpArg(like_or_input, 0)) .Build(); } Maybe SetSbpLike(user_op::SbpContext* ctx) { _SetSbp(ctx, "like"); return Maybe::Ok(); } Maybe SetSbpScatter(user_op::SbpContext* ctx) { _SetSbp(ctx, "input"); return Maybe::Ok(); } Maybe SetSbpScatterScalar(user_op::SbpContext* ctx) { const int32_t dim = ctx->Attr("dim"); const Shape& index_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0).shape(); const Shape& input_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); FOR_RANGE(int64_t, i, 0, index_tensor_shape.NumAxes()) { if (i == dim) { continue; } if (index_tensor_shape.At(i) == input_tensor_shape.At(i)) { ctx->NewBuilder() .Split(user_op::OpArg("index", 0), i) .Split(user_op::OpArg("input", 0), i) .Split(user_op::OpArg("output", 0), i) .Build(); } } return Maybe::Ok(); } Maybe InferDtype(user_op::InferContext* ctx) { const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); CHECK_OR_RETURN(IsIndexDataType(index.data_type())); if (ctx->has_input("input", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("input", 0), ctx->InputDType("src", 0)) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("src", 0)) << ", but got " << DataType_Name(ctx->InputDType("input", 0)); } else { CHECK_EQ_OR_RETURN(ctx->InputDType("like", 0), ctx->InputDType("src", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("like", 0)) << ", but got " << DataType_Name(ctx->InputDType("src", 0)); } ctx->SetOutputDType("output", 0, ctx->InputDType("src", 0)); return Maybe::Ok(); } Maybe InferScalarDtype(user_op::InferContext* ctx) { const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); CHECK_OR_RETURN(IsIndexDataType(index.data_type())); ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } } // namespace /* static */ Maybe DimScatterAddLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc(ctx); } /*static*/ Maybe DimScatterAddLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DimScatterAddLikeOp::GetSbp(user_op::SbpContext* ctx) { return SetSbpLike(ctx); } /* static */ Maybe DimScatterAddLikeOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return InputArgModifierFn(GetInputArgModifierFn, conf); } /* static */ Maybe DimScatterAddLikeOp::InferDataType(user_op::InferContext* ctx) { return InferDtype(ctx); } #define DEF_SCATTER_OP(op_class_name) \ /* static */ Maybe op_class_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferTensorDesc(ctx); \ } \ \ /*static*/ Maybe op_class_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe op_class_name::GetSbp(user_op::SbpContext* ctx) { \ return SetSbpScatter(ctx); \ } \ \ /* static */ Maybe op_class_name::ModifyInputArg( \ const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \ return InputArgModifierFn(GetInputArgModifierFn, conf); \ } \ \ /* static */ Maybe op_class_name::InferDataType(user_op::InferContext* ctx) { \ return InferDtype(ctx); \ } #define DEF_SCATTER_SCALAR_OP(optypename) \ /* static */ Maybe optypename::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferScalarTensorDesc(ctx); \ } \ \ /*static*/ Maybe optypename::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe optypename::GetSbp(user_op::SbpContext* ctx) { \ return SetSbpScatterScalar(ctx); \ } \ \ /* static */ Maybe optypename::ModifyInputArg( \ const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \ return InputScalarArgModifierFn(GetInputArgModifierFn, conf); \ } \ \ /* static */ Maybe optypename::InferDataType(user_op::InferContext* ctx) { \ return InferScalarDtype(ctx); \ } DEF_SCATTER_OP(DimScatterAddOp); DEF_SCATTER_OP(DimScatterUpdateOp); DEF_SCATTER_OP(DimScatterMulOp); DEF_SCATTER_SCALAR_OP(DimScatterUpdateScalarOp); DEF_SCATTER_SCALAR_OP(DimScatterAddScalarOp); DEF_SCATTER_SCALAR_OP(DimScatterMulScalarOp); } // namespace oneflow ================================================ FILE: oneflow/user/ops/distributions/exponential_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe ExponentialOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("out_shape"); DimVector dim_vec; if (shape.NumAxes() > 0) { dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); } ctx->SetOutputShape("out", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ExponentialOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& logical_shape = ctx->Attr("out_shape"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); return Maybe::Ok(); } /* static */ Maybe ExponentialOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe ExponentialOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe ExponentialOp::InferDataType(user_op::InferContext* ctx) { auto dtype = ctx->Attr("dtype"); ctx->SetOutputDType("out", 0, dtype); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/distributions/multinomial_with_replacement_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe MultinomialWithReplacementOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { int32_t num_samples = ctx->Attr("num_samples"); const Shape& x_shape = ctx->InputShape("x", 0); if (x_shape.NumAxes() == 1) { ctx->SetOutputShape("out", 0, Shape({num_samples})); } else { ctx->SetOutputShape("out", 0, Shape({x_shape.At(0), num_samples})); } return Maybe::Ok(); } /*static*/ Maybe MultinomialWithReplacementOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultinomialWithReplacementOp::GetSbp(user_op::SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); if (x_shape.NumAxes() == 2) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); } return Maybe::Ok(); } /* static */ Maybe MultinomialWithReplacementOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kInt64); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/distributions/normal_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe NormalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->Attr("shape")); return Maybe::Ok(); } /*static*/ Maybe NormalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& logical_shape = ctx->Attr("shape"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); return Maybe::Ok(); } /* static */ Maybe NormalOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe NormalOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe NormalOp::InferDataType(user_op::InferContext* ctx) { auto dtype = ctx->Attr("dtype"); ctx->SetOutputDType("out", 0, dtype); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/distributions/uniform_int_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe UniformIntOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); DimVector dim_vec; if (shape.NumAxes() > 0) { dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); } ctx->SetOutputShape("out", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe UniformIntOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& logical_shape = ctx->Attr("shape"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); return Maybe::Ok(); } /* static */ Maybe UniformIntOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe UniformIntOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe UniformIntOp::InferDataType(user_op::InferContext* ctx) { auto dtype = ctx->Attr("dtype"); ctx->SetOutputDType("out", 0, dtype); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/distributions/uniform_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe UniformOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); DimVector dim_vec; if (shape.NumAxes() > 0) { dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); } ctx->SetOutputShape("out", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe UniformOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& logical_shape = ctx->Attr("shape"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); return Maybe::Ok(); } /* static */ Maybe UniformOp::GetSbp(user_op::SbpContext* ctx) { const Shape& logical_shape = ctx->Attr("shape"); int64_t num_axes = logical_shape.NumAxes(); for (int i = 0; i < num_axes; ++i) { ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } /* static */ Maybe UniformOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe UniformOp::InferDataType(user_op::InferContext* ctx) { auto dtype = ctx->Attr("dtype"); ctx->SetOutputDType("out", 0, dtype); return Maybe::Ok(); } /* static */ Maybe UniformOp::DumpNdSbpSignatureForOpConfFn(const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf) { return user_op::SetSrcOpNdSbp(nd_sbp_sig, "out_0", op_conf); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/dot_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe DotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); CHECK_OR_RETURN(x.shape() == y.shape()) << Error::RuntimeError() << "inconsistent tensor size, expected tensor to have the same number of elements, but got " << x.shape().elem_cnt() << " and " << y.shape().elem_cnt() << " elements respectively"; CHECK_OR_RETURN(x.shape().NumAxes() == 1) << Error::RuntimeError() << "1D tensors expected, but got " << x.shape().NumAxes() << "D tensors"; ctx->SetOutputShape("out", 0, Shape({})); return Maybe::Ok(); } /*static*/ Maybe DotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DotOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("y", 0), 0) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe DotOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); CHECK_OR_RETURN(x.data_type() == y.data_type()) << Error::RuntimeError() << "expected both vectors to have same dtype, but found " << DataType_Name(x.data_type()) << " and " << DataType_Name(y.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/dropout_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe DropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); ctx->SetOutputShape("out", 0, in_shape); ctx->SetOutputShape("mask", 0, in_shape); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe DropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DropoutOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build(); } return Maybe::Ok(); } /* static */ Maybe DropoutOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { float rate = conf.attr("rate"); CHECK_GE_OR_RETURN(rate, 0.0); return Maybe::Ok(); } /* static */ Maybe DropoutOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); ctx->SetOutputDType("mask", 0, DataType::kBool); return Maybe::Ok(); } /* static */ Maybe DropoutGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); ctx->SetOutputShape("dx", 0, dy_shape); ctx->SetOutputIsDynamic("dx", 0, ctx->InputIsDynamic("dy", 0)); CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape); return Maybe::Ok(); } /*static*/ Maybe DropoutGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DropoutGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), axis) .Split(user_op::OpArg("mask", 0), axis) .Split(user_op::OpArg("dx", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe DropoutGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { float scale = conf.attr("scale"); CHECK_GT_OR_RETURN(scale, 1); return Maybe::Ok(); } /* static */ Maybe DropoutGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kBool) << "InferDataType Failed. 
Expected " << DataType_Name(DataType::kBool) << ", but got " << DataType_Name(ctx->InputDType("mask", 0)); return Maybe::Ok(); } /* static */ Maybe RandomMaskLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("like", 0)); return Maybe::Ok(); } /*static*/ Maybe RandomMaskLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe RandomMaskLikeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); FOR_RANGE(int64_t, axis, 0, like_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("like", 0), axis) .Split(user_op::OpArg("out", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe RandomMaskLikeOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { float rate = conf.attr("rate"); CHECK_GE_OR_RETURN(rate, 0); CHECK_LT_OR_RETURN(rate, 1); return Maybe::Ok(); } /* static */ Maybe RandomMaskLikeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kBool); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { bool IsScalarTensor(const user_op::TensorDesc* desc) { return desc->shape().NumAxes() == 1 && desc->shape().At(0) == 1; } bool IsTensorWithType(const user_op::TensorDesc* desc, DataType data_type) { return desc->data_type() == data_type; } } // namespace /* static */ Maybe DynamicLossScaleScheduleOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("count_not_finite", 0)))); CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("loss_scale", 0)))); CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("good_step_counter", 0)))); return Maybe::Ok(); } /*static*/ Maybe DynamicLossScaleScheduleOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe DynamicLossScaleScheduleOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe DynamicLossScaleScheduleOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* loss_scale = GetInputArgModifierFn("loss_scale", 0); CHECK_OR_RETURN(loss_scale != nullptr); loss_scale->set_is_mutable(true); user_op::InputArgModifier* good_step_counter = GetInputArgModifierFn("good_step_counter", 0); CHECK_OR_RETURN(good_step_counter != nullptr); good_step_counter->set_is_mutable(true); return Maybe::Ok(); } /* static */ Maybe DynamicLossScaleScheduleOp::InferDataType(user_op::InferContext* ctx) { CHECK_OR_RETURN( IsTensorWithType(&(ctx->InputTensorDesc("count_not_finite", 0)), DataType::kInt64)); CHECK_OR_RETURN(IsTensorWithType(&(ctx->InputTensorDesc("loss_scale", 0)), DataType::kFloat)); CHECK_OR_RETURN( IsTensorWithType(&(ctx->InputTensorDesc("good_step_counter", 0)), DataType::kInt64)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/eager_b_to_s_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { // Can only be called in local TODO: move this comment to ods /* static */ Maybe EagerBToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); Symbol out_parallel_desc = JUST(TxtStringToPlacement(out_parallel_conf_txt)); DimVector dim_vec{shape.dim_vec()}; int64_t out_parallel_num = out_parallel_desc->parallel_num(); if (out_parallel_num > 1) { CHECK_LT_OR_RETURN(out_split_axis, shape.NumAxes()); BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num); const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc)); int64_t parallel_id = opt_parallel_id->value_or(0); dim_vec[out_split_axis] = bs.At(parallel_id).size(); } ctx->SetOutputShape("out", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe EagerBToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EagerBToSOp::GetSbp(user_op::SbpContext* ctx) { return Error::TypeError() << "eager_b_to_s op doesn't support global tensor!"; } /* static */ Maybe EagerBToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { return Error::TypeError() << "eager_b_to_s op doesn't support global tensor!"; } /* static */ Maybe EagerBToSOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> EagerBToSOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/eager_ccl_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe EagerCclAllReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe EagerCclAllReduceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EagerCclAllReduceOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().PartialSum(user_op::OpArg("in", 0)).Broadcast(user_op::OpArg("out", 0)).Build(); return Maybe::Ok(); } /* static */ Maybe EagerCclAllReduceOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> EagerCclAllReduceOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe EagerCclBroadcastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { size_t size = ctx->input_size("in"); const std::vector& shape_list = ctx->Attr>("shape_list"); CHECK_EQ_OR_RETURN(size, ctx->output_size("out")) << "the size of input tensor tuple should equal the size of output tensor tuple."; for (int i = 0; i < size; ++i) { ctx->SetOutputShape("out", i, JUST(VectorAt(shape_list, i))); } return Maybe::Ok(); } /*static*/ Maybe EagerCclBroadcastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EagerCclBroadcastOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().PartialSum(ctx->inputs()).Broadcast(ctx->outputs()).Build(); ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); ctx->NewBuilder().Split(ctx->inputs(), 0).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe EagerCclBroadcastOp::InferDataType(user_op::InferContext* ctx) { size_t size = ctx->input_size("in"); CHECK_EQ_OR_RETURN(size, ctx->output_size("out")) << "the size of input tensor tuple should equal the size of output tensor tuple."; for (int i = 0; i < size; ++i) { ctx->SetOutputDType("out", i, ctx->InputDType("in", i)); } return Maybe::Ok(); } /* static */ Maybe> EagerCclBroadcastOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe EagerCclTouchOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe EagerCclTouchOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } /* static */ Maybe EagerCclTouchOp::GetSbp(user_op::SbpContext* ctx) { // local only return Maybe::Ok(); } /* static */ Maybe EagerCclTouchOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } /* static */ Maybe> EagerCclTouchOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe EagerCclReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe EagerCclReduceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe 
EagerCclReduceOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN() << "global tensor are not supported"; } /* static */ Maybe EagerCclReduceOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> EagerCclReduceOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe EagerCclReduceScatterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /* static */ Maybe EagerCclReduceScatterOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { const auto& input_shape = ctx->InputShape("in", 0); const auto& shape = ctx->Attr("output_shape"); Symbol parallel_desc = JUST(TxtStringToPlacement(ctx->Attr("parallel_conf"))); CHECK_EQ_OR_RETURN(input_shape.elem_cnt(), shape.elem_cnt() * parallel_desc->parallel_num()) << Error::RuntimeError() << "output tensor size must be equal to world_size times input tensor size"; CHECK_EQ_OR_RETURN(ctx->InputDType("in", 0), ctx->Attr("output_dtype")) << Error::RuntimeError() << "output tensor must have the same type as input tensor"; ctx->SetOutputShape("out", 0, ctx->Attr("output_shape")); ctx->SetOutputDType("out", 0, ctx->Attr("output_dtype")); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe EagerCclReduceScatterOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe EagerCclReduceScatterOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel() || sbp_hint.has_broadcast_parallel()); } in_nd_sbp->clear_sbp_parallel(); out_nd_sbp->clear_sbp_parallel(); // P2S or B2S const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); in_nd_sbp->CopyFrom(in_dis_hint); for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); } return Maybe::Ok(); } /* static */ Maybe EagerCclReduceScatterOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> EagerCclReduceScatterOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe EagerCclAllGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe EagerCclAllGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const auto& input_shape = ctx->InputShape("in", 0); const auto& shape = ctx->Attr("output_shape"); Symbol parallel_desc = JUST(TxtStringToPlacement(ctx->Attr("parallel_conf"))); CHECK_EQ_OR_RETURN(input_shape.elem_cnt() * parallel_desc->parallel_num(), shape.elem_cnt()) << Error::RuntimeError() << "output tensor size must be equal to world_size times input tensor size"; CHECK_EQ_OR_RETURN(ctx->InputDType("in", 0), ctx->Attr("output_dtype")) << 
Error::RuntimeError() << Error::RuntimeError() << "output tensor must have the same type as input tensor"; ctx->SetOutputShape("out", 0, ctx->Attr("output_shape")); ctx->SetOutputDType("out", 0, ctx->Attr("output_dtype")); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe EagerCclAllGatherOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe EagerCclAllGatherOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { CHECK_OR_RETURN(sbp_hint.has_split_parallel()); CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); } in_nd_sbp->clear_sbp_parallel(); out_nd_sbp->clear_sbp_parallel(); // S(0)->B const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); out_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); } return Maybe::Ok(); } /* static */ Maybe EagerCclAllGatherOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> EagerCclAllGatherOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe EagerCclS2SOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe EagerCclS2SOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe EagerCclS2SOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { CHECK_OR_RETURN(sbp_hint.has_split_parallel()); CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); } in_nd_sbp->clear_sbp_parallel(); out_nd_sbp->clear_sbp_parallel(); // S(in)->S(out) const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); } return Maybe::Ok(); } /* static */ Maybe EagerCclS2SOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> EagerCclS2SOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/eager_p_to_b_op.cpp 
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/user/ops/comm_net_device_infer_util.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

// Can only be called in local
/* static */ Maybe<void> EagerPToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, Shape(ctx->Attr<Shape>("shape").dim_vec()));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> EagerPToBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> EagerPToBOp::GetSbp(user_op::SbpContext* ctx) {
  return Error::TypeError() << "eager_p_to_b op doesn't support global tensor!";
}

/* static */ Maybe<void> EagerPToBOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
  return Error::TypeError() << "eager_p_to_b op doesn't support global tensor!";
}

/* static */ Maybe<void> EagerPToBOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

/* static */ Maybe<Symbol<Stream>> EagerPToBOp::InferDeviceAndStream(
    user_op::DeviceAndStreamInferContext* ctx) {
  return DeviceAndStreamInferFn(ctx);
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/eager_p_to_s_op.cpp
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/user/ops/comm_net_device_infer_util.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> EagerPToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& shape = ctx->Attr<Shape>("shape");
  const std::string& out_parallel_conf_txt = ctx->Attr<std::string>("out_parallel_conf");
  const int64_t out_split_axis = ctx->Attr<int64_t>("out_split_axis");
  Symbol<ParallelDesc> out_parallel_desc = JUST(TxtStringToPlacement(out_parallel_conf_txt));
  DimVector dim_vec{shape.dim_vec()};
  int64_t out_parallel_num = out_parallel_desc->parallel_num();
  if (out_parallel_num > 1) {
    CHECK_LT_OR_RETURN(out_split_axis, shape.NumAxes());
    BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num);
    const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));
    int64_t parallel_id = opt_parallel_id->value_or(0);
    dim_vec[out_split_axis] = bs.At(parallel_id).size();
  }
  ctx->SetOutputShape("out", 0, Shape(dim_vec));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> EagerPToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> EagerPToSOp::GetSbp(user_op::SbpContext* ctx) {
  return Error::TypeError() << "eager_p_to_s op doesn't support global tensor!";
}

/* static */ Maybe<void> EagerPToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
  return Error::TypeError() << "eager_p_to_s op doesn't support global tensor!";
}

/* static */ Maybe<void> EagerPToSOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

/* static */ Maybe<Symbol<Stream>> EagerPToSOp::InferDeviceAndStream(
    user_op::DeviceAndStreamInferContext* ctx) {
  return DeviceAndStreamInferFn(ctx);
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/eager_s_to_b_op.cpp
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/user/ops/comm_net_device_infer_util.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> EagerSToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, Shape(ctx->Attr<Shape>("shape").dim_vec()));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> EagerSToBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> EagerSToBOp::GetSbp(user_op::SbpContext* ctx) {
  return Error::TypeError() << "eager_s_to_b op doesn't support global tensor!";
}

/* static */ Maybe<void> EagerSToBOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
  return Error::TypeError() << "eager_s_to_b op doesn't support global tensor!";
}

/* static */ Maybe<void> EagerSToBOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

/* static */ Maybe<Symbol<Stream>> EagerSToBOp::InferDeviceAndStream(
    user_op::DeviceAndStreamInferContext* ctx) {
  return DeviceAndStreamInferFn(ctx);
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/eager_s_to_p_op.cpp
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/user/ops/comm_net_device_infer_util.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> EagerSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, Shape(ctx->Attr<Shape>("shape").dim_vec()));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> EagerSToPOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> EagerSToPOp::GetSbp(user_op::SbpContext* ctx) {
  return Error::TypeError() << "eager_s_to_p op doesn't support global tensor!";
}

/* static */ Maybe<void> EagerSToPOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
  return Error::TypeError() << "eager_s_to_p op doesn't support global tensor!";
}

/* static */ Maybe<void> EagerSToPOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

/* static */ Maybe<Symbol<Stream>> EagerSToPOp::InferDeviceAndStream(
    user_op::DeviceAndStreamInferContext* ctx) {
  return DeviceAndStreamInferFn(ctx);
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/eager_s_to_s_op.cpp
================================================
/* Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe EagerNaiveSToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); Symbol out_parallel_desc = JUST(TxtStringToPlacement(out_parallel_conf_txt)); DimVector dim_vec{shape.dim_vec()}; int64_t out_parallel_num = out_parallel_desc->parallel_num(); if (out_parallel_num > 1) { CHECK_LE_OR_RETURN(out_split_axis, shape.NumAxes()); BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num); const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc)); int64_t parallel_id = opt_parallel_id->value_or(0); dim_vec[out_split_axis] = bs.At(parallel_id).size(); } ctx->SetOutputShape("out", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe EagerNaiveSToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EagerNaiveSToSOp::GetSbp(user_op::SbpContext* ctx) { return Error::TypeError() << "eager_naive_s_to_s op doesn't support global tensor!"; } /* static */ Maybe EagerNaiveSToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { return Error::TypeError() << "eager_naive_s_to_s op doesn't support global tensor!"; } /* static */ Maybe EagerNaiveSToSOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> EagerNaiveSToSOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/eager_symmetric_s_to_p_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe EagerSymmetricSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe EagerSymmetricSToPOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EagerSymmetricSToPOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .PartialSum(user_op::OpArg("out", 0)) .Build(); } return Maybe::Ok(); } /* static */ Maybe EagerSymmetricSToPOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { CHECK_OR_RETURN(sbp_hint.has_split_parallel()); CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); } in_nd_sbp->clear_sbp_parallel(); out_nd_sbp->clear_sbp_parallel(); const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); out_nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel(); } return Maybe::Ok(); } /* static */ Maybe EagerSymmetricSToPOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> EagerSymmetricSToPOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/elementwise_maximum_minimum_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { using namespace user_op; Maybe GetSbpSignature_(SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { if (x_shape.At(i) == 1 && y_shape.At(i) == 1) { continue; } if (x_shape.At(i) == y_shape.At(i)) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } else { UNIMPLEMENTED(); } } return Maybe::Ok(); } Maybe InferTensorDesc_(InferContext* ctx) { const TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(tensor_x.shape().NumAxes(), tensor_y.shape().NumAxes()) << "Shape of tensor x and y should be same"; FOR_RANGE(int64_t, i, 0, tensor_x.shape().NumAxes()) { CHECK_EQ_OR_RETURN(tensor_x.shape().At(i), tensor_y.shape().At(i)); } TensorDesc* tensor_dx = ctx->MutOutputTensorDesc("dx", 0); TensorDesc* tensor_dy = ctx->MutOutputTensorDesc("dy", 0); if (tensor_dx) { tensor_dx->set_shape(tensor_x.shape()); } if (tensor_dy) { tensor_dy->set_shape(tensor_y.shape()); } return Maybe::Ok(); } Maybe InferDataType_(InferContext* ctx) { const TensorDesc& tensor_dz = ctx->InputTensorDesc("dz", 0); TensorDesc* tensor_dx = ctx->MutOutputTensorDesc("dx", 0); TensorDesc* tensor_dy = ctx->MutOutputTensorDesc("dy", 0); if (tensor_dx) { tensor_dx->set_data_type(tensor_dz.data_type()); } if (tensor_dy) { tensor_dy->set_data_type(tensor_dz.data_type()); } return Maybe::Ok(); } } // namespace #define DEF_ELEMENTWISE_XIMUM_FW_OP(op_class_name_prefix) \ /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ user_op::InferContext* ctx) { \ return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ } \ \ /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ } #define DEF_ELEMENTWISE_XIMUM_BW_OP(op_class_name_prefix) \ /* static */ Maybe op_class_name_prefix##BackwardOp::InferLogicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferTensorDesc_(ctx); \ } \ \ /*static*/ Maybe op_class_name_prefix##BackwardOp::InferPhysicalTensorDesc( \ user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##BackwardOp::GetSbp(user_op::SbpContext* ctx) { \ return GetSbpSignature_(ctx); \ } \ \ /* static */ Maybe op_class_name_prefix##BackwardOp::InferDataType( \ user_op::InferContext* ctx) { \ return InferDataType_(ctx); \ } #define REGISTER_ELEMENTWISE_XIMUM_OP(op_type_name, op_class_name_prefix) \ DEF_ELEMENTWISE_XIMUM_FW_OP(op_class_name_prefix); \ DEF_ELEMENTWISE_XIMUM_BW_OP(op_class_name_prefix); REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_maximum", ElementwiseMaximum); REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_minimum", ElementwiseMinimum); } // namespace oneflow ================================================ FILE: oneflow/user/ops/elu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe EluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe EluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe EluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe EluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe EluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe EluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/embedding_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe EmbeddingRenormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe EmbeddingRenormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe EmbeddingRenormOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe EmbeddingRenormOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe EmbeddingOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& weight_shape = ctx->InputShape("weight", 0); const Shape& indices_shape = ctx->InputShape("indices", 0); DimVector out_dim_vec; out_dim_vec.insert(out_dim_vec.end(), indices_shape.dim_vec().cbegin(), indices_shape.dim_vec().cend()); out_dim_vec.push_back(weight_shape.At(1)); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_shape(Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe EmbeddingOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe EmbeddingOp::GetSbp(user_op::SbpContext* ctx) { const int64_t indices_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); const bool scale_grad_by_freq = ctx->Attr("scale_grad_by_freq"); if (!scale_grad_by_freq) { FOR_RANGE(int64_t, i, 0, indices_num_axes) { ctx->NewBuilder() .Split(user_op::OpArg("indices", 0), i) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("out", 0), i) .Build(); } } return Maybe::Ok(); } /*static*/ Maybe EmbeddingOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("weight", 0)); return Maybe::Ok(); } /* static */ Maybe EmbeddingOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); CHECK_OR_RETURN(indices_modifier != nullptr); // NOLINT(maybe-need-error-msg) indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe EmbeddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& weight_shape = ctx->InputShape("weight", 0); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_shape(weight_shape); return Maybe::Ok(); } /*static*/ Maybe EmbeddingGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return EmbeddingGradOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe EmbeddingGradOp::GetSbp(user_op::SbpContext* ctx) { const bool scale_grad_by_freq = ctx->Attr("scale_grad_by_freq"); const int64_t indices_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); if (!scale_grad_by_freq) { for (int32_t i = 0; i < indices_num_axes; i++) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("indices", 0), i) .PartialSum(user_op::OpArg("dx", 0)) .Build(); } } return Maybe::Ok(); } /* static */ Maybe EmbeddingGradOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); CHECK_OR_RETURN(indices_modifier != nullptr); // NOLINT(maybe-need-error-msg) 
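  // indices hold integer lookup ids rather than differentiable values, so the grad op,
  // like the forward op, marks them as not requiring gradients.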
indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe EmbeddingGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("weight", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("weight", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/empty_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" namespace oneflow { namespace { Maybe> MakeEmptyStream(const Symbol& out_device, const bool pin_memory) { if (pin_memory) { CHECK_OR_RETURN(out_device->type() == "cpu") << "empty op only support pin_memory in cpu device but got " << out_device->type(); // TODO:(zhaoluyang) Parsing pin-memory-device from python auto pin_device = JUST(Device::New("cuda")); return Stream::New(pin_device, StreamType::kPinnedCompute); } return Stream::New(out_device, StreamType::kCompute); } } // namespace /* static */ Maybe EmptyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, Shape(ctx->Attr("shape").dim_vec())); ctx->SetOutputStride("out", 0, Stride(Shape(ctx->Attr("shape").dim_vec()))); return Maybe::Ok(); } /* static */ Maybe EmptyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& logical_shape = ctx->Attr("shape"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); ctx->SetOutputStride("out", 0, Stride(physical_shape)); return Maybe::Ok(); } /* static */ Maybe EmptyOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe EmptyOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe EmptyOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->Attr("dtype")); return Maybe::Ok(); } /* static */ Maybe> EmptyOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { Symbol out_device = JUST(Device::New(ctx->Attr("device_type"), ctx->Attr("device_id"))); *ctx->OutputTensorDevice4ArgNameAndIndex("out", 0) = out_device; const bool pin_memory = ctx->Attr("pin_memory"); return 
MakeEmptyStream(out_device, pin_memory); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/erfinv_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe ErfInvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); ctx->SetOutputShape("y", 0, x_shape); return Maybe::Ok(); } /*static*/ Maybe ErfInvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ErfInvOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe ErfInvOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/expand_dims_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { int32_t TransformNegativeAxisToPositive(int32_t axis, const int32_t num_axes) { axis = axis < 0 ? 
axis + num_axes + 1 : axis; CHECK_GE(axis, 0); CHECK_LE(axis, num_axes); return axis; } } // namespace /* static */ Maybe ExpandDimsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const int32_t axis = TransformNegativeAxisToPositive(ctx->Attr("axis"), in_shape.NumAxes()); auto dim_vec = in_shape.dim_vec(); dim_vec.insert(dim_vec.begin() + axis, 1); ctx->SetOutputShape("out", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ExpandDimsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ExpandDimsOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); const int32_t axis = TransformNegativeAxisToPositive(ctx->Attr("axis"), in_tensor.shape().NumAxes()); auto dim_vec = in_tensor.shape().dim_vec(); FOR_RANGE(int32_t, in_axis, 0, dim_vec.size()) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), in_axis) .Split(user_op::OpArg("out", 0), in_axis < axis ? in_axis : in_axis + 1) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe ExpandDimsOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/expand_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { namespace { Maybe InferExpandOutputStride(const Shape& input_shape, const Stride& input_stride, const Shape& expand_shape, Stride* output_stride) { CHECK_EQ_OR_RETURN(input_shape.size(), input_stride.size()); // NOLINT(maybe-need-error-msg) size_t lpad = expand_shape.size() - input_shape.size(); CHECK_GE_OR_RETURN(lpad, 0); // NOLINT(maybe-need-error-msg) output_stride->resize(expand_shape.size(), 0); for (int i = expand_shape.size() - 1; i >= 0; --i) { int64_t dim = i < lpad ? 
1 : input_shape[i - lpad]; if (dim == expand_shape[i]) { if (i >= lpad) { output_stride->at(i) = input_stride[i - lpad]; } else if (i < expand_shape.size() - 1) { output_stride->at(i) = output_stride->at(i + 1) * expand_shape[i + 1]; } } else { CHECK_EQ_OR_RETURN(dim, 1); // NOLINT(maybe-need-error-msg) } } // NOTE: expand op can only output a contiguous stride, // because lazy mode doesn't support the to_contiguous op for now *output_stride = Stride(expand_shape); return Maybe<void>::Ok(); } } // namespace /* static */ Maybe<void> ExpandOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("in", 0); const Stride& input_stride = ctx->InputStride("in", 0); const Shape& expand_shape = ctx->Attr<Shape>("expand_shape"); ctx->SetOutputShape("out", 0, expand_shape); Stride output_stride; JUST(InferExpandOutputStride(input_shape, input_stride, expand_shape, &output_stride)); ctx->SetOutputStride("out", 0, output_stride); return Maybe<void>::Ok(); } /*static*/ Maybe<void> ExpandOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("in", 0); const Stride& input_stride = ctx->InputStride("in", 0); const auto& global_expand_shape = ctx->Attr<Shape>("expand_shape"); const auto& output_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const auto& device_mesh = *ctx->parallel_desc().hierarchy(); const auto& rank = ctx->parallel_ctx().parallel_id(); const auto local_view = GetTensorSliceView4ParallelId(device_mesh, output_sbp, global_expand_shape, rank); const auto& local_expand_shape = local_view.shape(); ctx->SetOutputShape("out", 0, local_expand_shape); Stride output_stride; JUST(InferExpandOutputStride(input_shape, input_stride, local_expand_shape, &output_stride)); ctx->SetOutputStride("out", 0, output_stride); return Maybe<void>::Ok(); } /* static */ Maybe<void> ExpandOp::GetSbp(user_op::SbpContext* ctx) { const auto& global_in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const auto& global_expand_shape = ctx->Attr<Shape>("expand_shape"); size_t lpad = global_expand_shape.size() - global_in_shape.size(); CHECK_GE_OR_RETURN(lpad, 0); // NOLINT(maybe-need-error-msg) for (size_t i = 0; i < global_in_shape.size(); ++i) { if (global_in_shape[i] == global_expand_shape[i + lpad]) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Split(user_op::OpArg("out", 0), i + lpad) .Build(); } } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe<void>::Ok(); } /* static */ Maybe<void> ExpandOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/eye_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe<void> EyeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { int64_t rows = ctx->Attr<int64_t>("rows"); int64_t cols = ctx->Attr<int64_t>("cols"); ctx->SetOutputShape("out", 0, Shape({rows, cols})); return Maybe<void>::Ok(); } /*static*/ Maybe<void> EyeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe<void> EyeOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe<void>::Ok(); } /* static */ Maybe<void> EyeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->Attr<DataType>("dtype")); return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fake_quantization_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe<void> FakeQuantizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const Shape& scale_shape = ctx->InputShape("scale", 0); const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); // NOTE(Liang Depeng): scale_shape.elem_cnt() > 1 means per-channel quantization for // convolution weights.
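// For example, a convolution weight with in_shape (Cout, Cin, Kh, Kw) quantized per output
// channel provides scale and zero_point with elem_cnt() == Cout == in_shape.At(0); per-layer
// quantization instead passes a single scalar scale/zero_point (elem_cnt() == 1).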
if (scale_shape.elem_cnt() > 1) { CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0)); CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); } ctx->SetOutputShape("out", 0, in_shape); return Maybe<void>::Ok(); } /*static*/ Maybe<void> FakeQuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe<void> FakeQuantizationOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); const Shape& logical_scale_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape(); ctx->NewBuilder() .Broadcast(user_op::OpArg("in", 0)) .Broadcast(user_op::OpArg("scale", 0)) .Broadcast(user_op::OpArg("zero_point", 0)) .Broadcast(user_op::OpArg("out", 0)) .Build(); if (logical_scale_shape.elem_cnt() > 1) { // NOTE(Liang Depeng): only consider convolution weight per-channel quantization ctx->NewBuilder() .Split(user_op::OpArg("in", 0), 0) .Split(user_op::OpArg("scale", 0), 0) .Split(user_op::OpArg("zero_point", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Build(); } else { // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise // ops ctx->NewBuilder() .Split(user_op::OpArg("in", 0), 0) .Broadcast(user_op::OpArg("scale", 0)) .Broadcast(user_op::OpArg("zero_point", 0)) .Split(user_op::OpArg("out", 0), 0) .Build(); } FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Broadcast(user_op::OpArg("scale", 0)) .Broadcast(user_op::OpArg("zero_point", 0)) .Split(user_op::OpArg("out", 0), i) .Build(); } return Maybe<void>::Ok(); } /* static */ Maybe<void> FakeQuantizationOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); CHECK_OR_RETURN(scale != nullptr); scale->set_requires_grad(false); user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); CHECK_OR_RETURN(zero_point != nullptr); zero_point->set_requires_grad(false); return Maybe<void>::Ok(); } /* static */ Maybe<void> FakeQuantizationOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { const int32_t quantization_bit = conf.attr<int32_t>("quantization_bit"); CHECK_GT_OR_RETURN(quantization_bit, 1); CHECK_LE_OR_RETURN(quantization_bit, 8); std::string quantization_scheme = conf.attr<std::string>("quantization_scheme"); CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); std::string quantization_formula = conf.attr<std::string>("quantization_formula"); CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); return Maybe<void>::Ok(); } /* static */ Maybe<void> FakeQuantizationOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fft_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe FftC2COp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("input", 0); Stride out_stride = Stride(in_shape); // contiguous ctx->SetOutputShape("out", 0, in_shape); ctx->SetOutputStride("out", 0, out_stride); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("input", 0)); return Maybe::Ok(); } /*static*/ Maybe FftC2COp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FftC2COp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("input", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe FftC2COp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } /* static */ Maybe FftR2COp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("input", 0); const auto& dims = ctx->Attr>("dims"); bool onesided = ctx->Attr("onesided"); Shape out_shape = in_shape; auto last_dim = dims.back(); if (onesided) { out_shape[last_dim] = out_shape[last_dim] / 2 + 1; } Stride out_stride = Stride(out_shape); ctx->SetOutputShape("out", 0, out_shape); ctx->SetOutputStride("out", 0, out_stride); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("input", 0)); return Maybe::Ok(); } /*static*/ Maybe FftR2COp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FftR2COp::GetSbp(user_op::SbpContext* ctx) { // TO-DO : Validate sbp ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe FftR2COp::InferDataType(user_op::InferContext* ctx) { const DataType& input_type = ctx->InputDType("input", 0); switch (input_type) { case (kFloat): ctx->SetOutputDType("out", 0, kComplex64); break; case (kDouble): ctx->SetOutputDType("out", 0, kComplex128); break; default: CHECK_OR_RETURN(false) << "RuntimeError: dtype can't be handled"; } return Maybe::Ok(); } /* static */ Maybe FftC2ROp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("input", 0); const auto& dims = ctx->Attr>("dims"); int64_t last_dim_size = ctx->Attr("last_dim_size"); Shape out_shape = in_shape; out_shape[dims.back()] = last_dim_size; Stride out_stride = Stride(out_shape); ctx->SetOutputShape("out", 0, out_shape); ctx->SetOutputStride("out", 0, out_stride); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("input", 0)); return Maybe::Ok(); } /*static*/ Maybe FftC2ROp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FftC2ROp::GetSbp(user_op::SbpContext* ctx) { // TO-DO : Validate sbp ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe FftC2ROp::InferDataType(user_op::InferContext* ctx) { const DataType& input_type = ctx->InputDType("input", 0); switch (input_type) { case (kComplex64): ctx->SetOutputDType("out", 0, kFloat); break; case (kComplex128): ctx->SetOutputDType("out", 0, kDouble); break; default: CHECK_OR_RETURN(false) << "RuntimeError: dtype can't be 
handled"; } return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fill_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe<void> FillOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); ctx->SetOutputShape("out", 0, in_shape); ctx->SetOutputStride("out", 0, ctx->InputStride("in", 0)); return Maybe<void>::Ok(); } /*static*/ Maybe<void> FillOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe<void> FillOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe<void>::Ok(); } /* static */ Maybe<void> FillOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe<void>::Ok(); } /* static */ Maybe<void> FillTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); ctx->SetOutputShape("out", 0, in_shape); ctx->SetOutputStride("out", 0, ctx->InputStride("in", 0)); return Maybe<void>::Ok(); } /*static*/ Maybe<void> FillTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe<void> FillTensorOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Broadcast(user_op::OpArg("value", 0)) .Split(user_op::OpArg("out", 0), i) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("value", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe<void>::Ok(); } /* static */ Maybe<void> FillTensorOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/flip_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto FlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const int input_dims = x_desc.shape().NumAxes(); const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims"); CHECK_OR_RETURN(dims.size() <= input_dims) << "len of dims must be less than or equal to len of input tensor"; for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << "dims parameter is illegal."; } user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); y_desc->set_shape(x_desc.shape()); return Maybe<void>::Ok(); } /*static*/ auto FlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> { return FlipOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FlipOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims"); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { bool flag = true; for (auto x : dims) { if (x == i) { flag = false; break; } } if (flag) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } } return Maybe<void>::Ok(); } /*static*/ auto FlipOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/frac_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto FracOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); y_desc->set_shape(x_desc.shape()); return Maybe<void>::Ok(); } /*static*/ auto FracOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> { return FracOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FracOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe<void>::Ok(); } /*static*/ auto FracOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_attention_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe ParseDims(const Shape& shape, const std::string& layout, const Optional& batch_size, const Optional& seq_len, const Optional& num_heads, const Optional& head_size, int64_t* b, int64_t* m, int64_t* h, int64_t* k, bool* bm_packed) { if (shape.NumAxes() == 2) { if (layout == "(BM)(HK)" || layout == "(BM)(H2K)" || layout == "(BM)(H3K)") { *bm_packed = true; CHECK_OR_RETURN(batch_size); CHECK_OR_RETURN(seq_len); *b = JUST(batch_size); *m = JUST(seq_len); int64_t packed_n = 0; if (layout == "(BM)(HK)") { packed_n = 1; } else if (layout == "(BM)(H2K)") { packed_n = 2; } else if (layout == "(BM)(H3K)") { packed_n = 3; } else { UNIMPLEMENTED_THEN_RETURN(); } const int64_t hidden_size = shape.At(1); if (num_heads) { const int64_t expected_h = JUST(num_heads); const int64_t packed_h = packed_n * expected_h; CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0); *h = expected_h; *k = hidden_size / packed_h; } else if (head_size) { const int64_t expected_k = JUST(head_size); const int64_t packed_k = packed_n * expected_k; CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0); *h = hidden_size / packed_k; *k = expected_k; } else { UNIMPLEMENTED_THEN_RETURN(); } } else { UNIMPLEMENTED_THEN_RETURN(); } } else if (shape.NumAxes() == 3) { if (layout == "BM(HK)" || layout == "MB(HK)" || layout == "BM(H2K)" || layout == "MB(H2K)" || layout == "BM(H3K)" || layout == "MB(H3K)") { *bm_packed = false; int64_t packed_n = 0; if (layout == "BM(HK)") { *b = shape.At(0); *m = shape.At(1); packed_n = 1; } else if (layout == "MB(HK)") { *b = shape.At(1); *m = shape.At(0); packed_n = 1; } else if (layout == "BM(H2K)") { *b = shape.At(0); *m = shape.At(1); packed_n = 2; } else if (layout == "MB(H2K)") { *b = shape.At(1); *m = shape.At(0); packed_n = 2; } else if (layout == "BM(H3K)") { *b = shape.At(0); *m = shape.At(1); packed_n = 3; } else if (layout == "MB(H3K)") { *b = shape.At(1); *m = shape.At(0); packed_n = 3; } else { UNIMPLEMENTED_THEN_RETURN(); } const int64_t hidden_size = shape.At(2); if (num_heads) { const int64_t expected_h = JUST(num_heads); const int64_t packed_h = packed_n * expected_h; CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0); *h = expected_h; *k = hidden_size / packed_h; } else if (head_size) { const int64_t expected_k = JUST(head_size); const int64_t packed_k = packed_n * expected_k; CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0); *h = hidden_size / packed_k; *k = expected_k; } else { UNIMPLEMENTED_THEN_RETURN(); } } else if (layout == "(BM)HK") { *bm_packed = true; CHECK_OR_RETURN(batch_size); CHECK_OR_RETURN(seq_len); *b = JUST(batch_size); *m = JUST(seq_len); *h = shape.At(1); *k = shape.At(2); } else { UNIMPLEMENTED_THEN_RETURN(); } } else if (shape.NumAxes() == 4) { *bm_packed = false; if (layout == "BMHK") { *b = shape.At(0); *m = shape.At(1); *h = shape.At(2); *k = shape.At(3); } else if (layout == "BHMK") { *b = shape.At(0); *m = shape.At(2); *h = shape.At(1); *k = shape.At(3); } else if (layout == "MBHK") { *b = shape.At(1); *m = shape.At(0); *h = shape.At(2); *k = shape.At(3); } else { 
UNIMPLEMENTED_THEN_RETURN(); } } else { UNIMPLEMENTED_THEN_RETURN(); }; if (batch_size) { const int64_t expected_b = JUST(batch_size); CHECK_EQ_OR_RETURN(*b, expected_b); } if (seq_len) { const int64_t expected_m = JUST(seq_len); CHECK_EQ_OR_RETURN(*m, expected_m); } if (num_heads) { const int64_t expected_h = JUST(num_heads); CHECK_EQ_OR_RETURN(*h, expected_h); } if (head_size) { const int64_t expected_k = JUST(head_size); CHECK_EQ_OR_RETURN(*k, expected_k); } return Maybe::Ok(); } Maybe ParseDims(const Shape& shape, const std::string& layout, const Optional& num_heads, const Optional& head_size, int64_t* b, int64_t* m, int64_t* h, int64_t* k) { bool bm_packed{}; return ParseDims(shape, layout, Optional(), Optional(), num_heads, head_size, b, m, h, k, &bm_packed); } Maybe LayoutToShape(int64_t b, int64_t m, int64_t h, int64_t k, const std::string& layout) { if (layout == "BM(HK)") { return Shape({b, m, h * k}); } else if (layout == "BM(H2K)") { return Shape({b, m, h * k * 2}); } else if (layout == "BM(H3K)") { return Shape({b, m, h * k * 3}); } else if (layout == "MB(HK)") { return Shape({m, b, h * k}); } else if (layout == "MB(H2K)") { return Shape({m, b, h * k * 2}); } else if (layout == "MB(H3K)") { return Shape({m, b, h * k * 3}); } else if (layout == "BMHK") { return Shape({b, m, h, k}); } else if (layout == "BHMK") { return Shape({b, h, m, k}); } else if (layout == "MBHK") { return Shape({m, b, h, k}); } else { UNIMPLEMENTED_THEN_RETURN(); } } Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t* b_split_axis, int64_t* h_split_axis) { if (layout == "BM(HK)" || layout == "BM(H2K)" || layout == "BM(H3K)") { *b_split_axis = 0; if (can_hk_split) { *h_split_axis = 2; } else { *h_split_axis = -1; } } else if (layout == "MB(HK)" || layout == "MB(H2K)" || layout == "MB(H3K)") { *b_split_axis = 1; if (can_hk_split) { *h_split_axis = 2; } else { *h_split_axis = -1; } } else if (layout == "BMHK") { *b_split_axis = 0; *h_split_axis = 2; } else if (layout == "BHMK") { *b_split_axis = 0; *h_split_axis = 1; } else if (layout == "MBHK") { *b_split_axis = 1; *h_split_axis = 2; } else if (layout == "(BM)HK") { *b_split_axis = -1; *h_split_axis = 1; } else if (layout == "(BM)(HK)" || layout == "(BM)(H2K)" || layout == "(BM)(H3K)") { *b_split_axis = -1; if (can_hk_split) { *h_split_axis = 1; } else { *h_split_axis = -1; } } else { UNIMPLEMENTED_THEN_RETURN(); } return Maybe::Ok(); }; } // namespace /*static*/ auto FusedMultiHeadAttentionInferenceOp::InferDataType(user_op::InferContext* ctx) -> Maybe { DataType query_type = ctx->InputDType("query", 0); DataType key_type = ctx->InputDType("key", 0); DataType value_type = ctx->InputDType("value", 0); CHECK_EQ_OR_RETURN(key_type, query_type); CHECK_EQ_OR_RETURN(value_type, query_type); if (ctx->has_input("attn_bias", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("attn_bias", 0), query_type); } if (ctx->has_input("query_seq_start", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("query_seq_start", 0), DataType::kInt32); } if (ctx->has_input("key_seq_start", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("key_seq_start", 0), DataType::kInt32); } if (ctx->has_input("key_seq_len", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("key_seq_len", 0), DataType::kInt32); } ctx->SetOutputDType("out", 0, query_type); return Maybe::Ok(); } /*static*/ auto FusedMultiHeadAttentionInferenceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { const int64_t query_head_size = ctx->Attr("query_head_size"); CHECK_GE_OR_RETURN(query_head_size, 1); Optional 
batch_size; if (ctx->has_input("query_seq_start", 0)) { CHECK_OR_RETURN(ctx->has_input("key_seq_start", 0)); const Shape& query_seq_start_shape = ctx->InputShape("query_seq_start", 0); CHECK_EQ_OR_RETURN(query_seq_start_shape.NumAxes(), 1); CHECK_GT_OR_RETURN(query_seq_start_shape.At(0), 1); CHECK_OR_RETURN(ctx->InputShape("key_seq_start", 0) == query_seq_start_shape); batch_size = query_seq_start_shape.At(0) - 1; if (ctx->has_input("key_seq_len", 0)) { const Shape& key_seq_len_shape = ctx->InputShape("key_seq_len", 0); CHECK_EQ_OR_RETURN(key_seq_len_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(key_seq_len_shape.At(0), query_seq_start_shape.At(0) - 1); } } else { CHECK_OR_RETURN(!ctx->has_input("key_seq_start", 0)); CHECK_OR_RETURN(!ctx->has_input("key_seq_len", 0)); } Optional query_max_seq_len; const int64_t attr_query_max_seq_len = ctx->Attr("query_max_seq_len"); if (attr_query_max_seq_len != 0) { query_max_seq_len = attr_query_max_seq_len; } Optional key_max_seq_len; const int64_t attr_key_max_seq_len = ctx->Attr("key_max_seq_len"); if (attr_key_max_seq_len != 0) { key_max_seq_len = attr_key_max_seq_len; } const Shape& query_shape = ctx->InputShape("query", 0); const std::string& query_layout = ctx->Attr("query_layout"); int64_t q_b = 0; int64_t q_m = 0; int64_t q_h = 0; int64_t q_k = 0; bool q_bm_packed = false; JUST(ParseDims(query_shape, query_layout, batch_size, query_max_seq_len, Optional(), query_head_size, &q_b, &q_m, &q_h, &q_k, &q_bm_packed)); if (q_bm_packed) { CHECK_OR_RETURN(ctx->has_input("query_seq_start", 0)); } const Shape& key_shape = ctx->InputShape("key", 0); const std::string& key_layout = ctx->Attr("key_layout"); int64_t k_b = 0; int64_t k_m = 0; int64_t k_h = 0; int64_t k_k = 0; bool k_bm_packed = false; JUST(ParseDims(key_shape, key_layout, q_b, key_max_seq_len, q_h, q_k, &k_b, &k_m, &k_h, &k_k, &k_bm_packed)); CHECK_EQ_OR_RETURN(k_b, q_b); CHECK_EQ_OR_RETURN(k_h, q_h); CHECK_EQ_OR_RETURN(k_bm_packed, q_bm_packed); const Shape& value_shape = ctx->InputShape("value", 0); const std::string& value_layout = ctx->Attr("value_layout"); int64_t v_b = 0; int64_t v_m = 0; int64_t v_h = 0; int64_t v_k = 0; bool v_bm_packed = false; JUST(ParseDims(value_shape, value_layout, q_b, k_m, q_h, Optional(), &v_b, &v_m, &v_h, &v_k, &v_bm_packed)); CHECK_EQ_OR_RETURN(v_b, q_b); CHECK_EQ_OR_RETURN(v_m, k_m); CHECK_EQ_OR_RETURN(v_bm_packed, k_bm_packed); if (ctx->has_input("attn_bias", 0)) { const Shape& attn_bias_shape = ctx->InputShape("attn_bias", 0); const int64_t num_attn_bias_axes = attn_bias_shape.NumAxes(); CHECK_GE_OR_RETURN(num_attn_bias_axes, 1); CHECK_LE_OR_RETURN(num_attn_bias_axes, 4); DimVector padded_attn_bias_shape; for (int i = 0; i < 4 - num_attn_bias_axes; ++i) { padded_attn_bias_shape.push_back(1); } for (int i = 0; i < num_attn_bias_axes; ++i) { padded_attn_bias_shape.push_back(attn_bias_shape.At(i)); } CHECK_OR_RETURN(padded_attn_bias_shape.at(0) == 1 || padded_attn_bias_shape.at(0) == q_b); CHECK_OR_RETURN(padded_attn_bias_shape.at(1) == 1 || padded_attn_bias_shape.at(1) == q_h); CHECK_OR_RETURN(padded_attn_bias_shape.at(2) == 1 || padded_attn_bias_shape.at(2) >= q_m); CHECK_OR_RETURN(padded_attn_bias_shape.at(3) >= k_m); } const std::string& output_layout = ctx->Attr("output_layout"); const bool o_bm_packed = output_layout == "(BM)(HK)"; CHECK_EQ(o_bm_packed, q_bm_packed); if (output_layout == "(BM)(HK)") { ctx->SetOutputShape("out", 0, Shape({query_shape.At(0), q_h * v_k})); } else if (output_layout == "BM(HK)") { ctx->SetOutputShape("out", 0, Shape({q_b, q_m, q_h 
* v_k})); } else if (output_layout == "MB(HK)") { ctx->SetOutputShape("out", 0, Shape({q_m, q_b, q_h * v_k})); } else { UNIMPLEMENTED_THEN_RETURN(); } return Maybe::Ok(); } /*static*/ auto FusedMultiHeadAttentionInferenceOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) -> Maybe { return FusedMultiHeadAttentionInferenceOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedMultiHeadAttentionInferenceOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const int64_t query_head_size = ctx->user_op_conf().attr("query_head_size"); const std::string& query_layout = ctx->user_op_conf().attr("query_layout"); const std::string& key_layout = ctx->user_op_conf().attr("key_layout"); const std::string& value_layout = ctx->user_op_conf().attr("value_layout"); const std::string& output_layout = ctx->user_op_conf().attr("output_layout"); int64_t num_heads = 0; const user_op::TensorDesc& query = ctx->LogicalTensorDesc4InputArgNameAndIndex("query", 0); if (query.shape().NumAxes() == 2) { if (query_layout == "(BM)(HK)") { CHECK_EQ_OR_RETURN(query.shape().At(1) % query_head_size, 0); num_heads = query.shape().At(1) / query_head_size; } else if (query_layout == "(BM)(H3K)") { CHECK_EQ_OR_RETURN(query.shape().At(1) % (query_head_size * 3), 0); num_heads = query.shape().At(1) / (query_head_size * 3); } else { UNIMPLEMENTED_THEN_RETURN(); } } else if (query.shape().NumAxes() == 3) { if (query_layout == "BM(HK)" || query_layout == "MB(HK)") { CHECK_EQ_OR_RETURN(query.shape().At(2) % query_head_size, 0); num_heads = query.shape().At(2) / query_head_size; } else if (query_layout == "BM(H3K)" || query_layout == "MB(H3K)") { CHECK_EQ_OR_RETURN(query.shape().At(2) % (query_head_size * 3), 0); num_heads = query.shape().At(2) / (query_head_size * 3); } else if (query_layout == "(BM)HK") { num_heads = query.shape().At(1); } else { UNIMPLEMENTED_THEN_RETURN(); } } else if (query.shape().NumAxes() == 4) { if (query_layout == "BMHK") { num_heads = query.shape().At(2); } else if (query_layout == "BHMK") { num_heads = query.shape().At(1); } else { UNIMPLEMENTED_THEN_RETURN(); } } else { UNIMPLEMENTED_THEN_RETURN(); } const bool can_hk_split = num_heads % ctx->parallel_num() == 0; int64_t q_b_split_axis = -1; int64_t q_h_split_axis = -1; JUST(ParseSplitAxis(query_layout, can_hk_split, &q_b_split_axis, &q_h_split_axis)); int64_t k_b_split_axis = -1; int64_t k_h_split_axis = -1; JUST(ParseSplitAxis(key_layout, can_hk_split, &k_b_split_axis, &k_h_split_axis)); int64_t v_b_split_axis = -1; int64_t v_h_split_axis = -1; JUST(ParseSplitAxis(value_layout, can_hk_split, &v_b_split_axis, &v_h_split_axis)); int64_t o_b_split_axis = -1; int64_t o_h_split_axis = -1; JUST(ParseSplitAxis(output_layout, can_hk_split, &o_b_split_axis, &o_h_split_axis)); std::vector attn_bias_arg; if (ctx->user_op_conf().has_input("attn_bias", 0)) { attn_bias_arg.emplace_back("attn_bias", 0); } std::vector var_len_args; if (ctx->user_op_conf().has_input("query_seq_start", 0)) { var_len_args.emplace_back("query_seq_start", 0); } if (ctx->user_op_conf().has_input("key_seq_start", 0)) { var_len_args.emplace_back("key_seq_start", 0); } if (ctx->user_op_conf().has_input("key_seq_len", 0)) { var_len_args.emplace_back("key_seq_len", 0); } if (q_b_split_axis >= 0 && k_b_split_axis >= 0 && v_b_split_axis >= 0 && o_b_split_axis >= 0 && var_len_args.empty()) { bool broadcast_attn_bias = false; if (ctx->user_op_conf().has_input("attn_bias", 0)) { const user_op::TensorDesc& attn_bias = ctx->LogicalTensorDesc4InputArgNameAndIndex("attn_bias", 0); if 
(attn_bias.shape().NumAxes() < 4 || attn_bias.shape().At(0) == 1) { broadcast_attn_bias = true; } } if (broadcast_attn_bias) { ctx->NewBuilder() .Split(user_op::OpArg("query", 0), q_b_split_axis) .Split(user_op::OpArg("key", 0), k_b_split_axis) .Split(user_op::OpArg("value", 0), v_b_split_axis) .Broadcast(attn_bias_arg) .Split(ctx->outputs(), o_b_split_axis) .Build(); } else { ctx->NewBuilder() .Split(user_op::OpArg("query", 0), q_b_split_axis) .Split(user_op::OpArg("key", 0), k_b_split_axis) .Split(user_op::OpArg("value", 0), v_b_split_axis) .Split(attn_bias_arg, 0) .Split(ctx->outputs(), o_b_split_axis) .Build(); } } if (q_h_split_axis >= 0 && k_h_split_axis >= 0 && v_h_split_axis >= 0 && o_h_split_axis >= 0) { bool broadcast_attn_bias = false; if (ctx->user_op_conf().has_input("attn_bias", 0)) { const user_op::TensorDesc& attn_bias = ctx->LogicalTensorDesc4InputArgNameAndIndex("attn_bias", 0); if (attn_bias.shape().NumAxes() == 4) { if (attn_bias.shape().At(1) == 1) { broadcast_attn_bias = true; } } else if (attn_bias.shape().NumAxes() == 3) { if (attn_bias.shape().At(0) == 1) { broadcast_attn_bias = true; } } else { broadcast_attn_bias = true; } } if (broadcast_attn_bias) { ctx->NewBuilder() .Split(user_op::OpArg("query", 0), q_h_split_axis) .Split(user_op::OpArg("key", 0), k_h_split_axis) .Split(user_op::OpArg("value", 0), v_h_split_axis) .Broadcast(attn_bias_arg) .Broadcast(var_len_args) .Split(ctx->outputs(), o_h_split_axis) .Build(); } else { ctx->NewBuilder() .Split(user_op::OpArg("query", 0), q_h_split_axis) .Split(user_op::OpArg("key", 0), k_h_split_axis) .Split(user_op::OpArg("value", 0), v_h_split_axis) .Split(attn_bias_arg, 1) .Broadcast(var_len_args) .Split(ctx->outputs(), o_h_split_axis) .Build(); } } return Maybe::Ok(); } /*static*/ auto FusedAttentionConcatPastKeyValueOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const DataType data_type = ctx->InputDType("key", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("value", 0), data_type); if (ctx->has_input("past_key", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("past_key", 0), data_type); } if (ctx->has_input("past_value", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("past_value", 0), data_type); } ctx->SetOutputDType("output_key", 0, data_type); ctx->SetOutputDType("output_value", 0, data_type); return Maybe::Ok(); } /*static*/ auto FusedAttentionConcatPastKeyValueOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { const int64_t key_head_size = ctx->Attr("key_head_size"); CHECK_GE_OR_RETURN(key_head_size, 1); const Shape& key_shape = ctx->InputShape("key", 0); const std::string& key_layout = ctx->Attr("key_layout"); int64_t k_b = 0; int64_t k_m = 0; int64_t k_h = 0; int64_t k_k = 0; JUST( ParseDims(key_shape, key_layout, Optional(), key_head_size, &k_b, &k_m, &k_h, &k_k)); const Shape& value_shape = ctx->InputShape("value", 0); const std::string& value_layout = ctx->Attr("value_layout"); int64_t v_b = 0; int64_t v_m = 0; int64_t v_h = 0; int64_t v_k = 0; JUST(ParseDims(value_shape, value_layout, k_h, k_k, &v_b, &v_m, &v_h, &v_k)); CHECK_EQ_OR_RETURN(v_b, k_b); CHECK_EQ_OR_RETURN(v_m, k_m); int64_t past_k_b = 0; int64_t past_k_m = 0; int64_t past_k_h = 0; int64_t past_k_k = 0; int64_t past_v_b = 0; int64_t past_v_m = 0; int64_t past_v_h = 0; int64_t past_v_k = 0; const std::string& past_key_layout = ctx->Attr("past_key_layout"); const std::string& past_value_layout = ctx->Attr("past_value_layout"); if (ctx->has_input("past_key", 0)) { CHECK_OR_RETURN(ctx->has_input("past_value", 0)); const Shape& past_key_shape = 
ctx->InputShape("past_key", 0); JUST(ParseDims(past_key_shape, past_key_layout, k_h, k_k, &past_k_b, &past_k_m, &past_k_h, &past_k_k)); CHECK_EQ_OR_RETURN(past_k_b, k_b); const Shape& past_value_shape = ctx->InputShape("past_value", 0); JUST(ParseDims(past_value_shape, past_value_layout, k_h, k_k, &past_v_b, &past_v_m, &past_v_h, &past_v_k)); CHECK_EQ_OR_RETURN(past_v_b, k_b); CHECK_EQ_OR_RETURN(past_v_m, past_k_m); } else { CHECK_OR_RETURN(!ctx->has_input("past_value", 0)); } ctx->SetOutputShape("output_key", 0, *JUST(LayoutToShape(k_b, past_k_m + k_m, k_h, k_k, past_key_layout))); ctx->SetOutputShape("output_value", 0, *JUST(LayoutToShape(v_b, past_v_m + v_m, v_h, v_k, past_value_layout))); return Maybe::Ok(); } /*static*/ auto FusedAttentionConcatPastKeyValueOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) -> Maybe { return FusedAttentionConcatPastKeyValueOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedAttentionConcatPastKeyValueOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const int64_t key_head_size = ctx->user_op_conf().attr("key_head_size"); const std::string& past_key_layout = ctx->user_op_conf().attr("past_key_layout"); const std::string& past_value_layout = ctx->user_op_conf().attr("past_value_layout"); const std::string& key_layout = ctx->user_op_conf().attr("key_layout"); const std::string& value_layout = ctx->user_op_conf().attr("value_layout"); int64_t num_heads = 0; { int64_t b = 0; int64_t m = 0; int64_t k = 0; const user_op::TensorDesc& key = ctx->LogicalTensorDesc4InputArgNameAndIndex("key", 0); JUST(ParseDims(key.shape(), key_layout, Optional(), key_head_size, &b, &m, &num_heads, &k)); } const bool can_hk_split = num_heads % ctx->parallel_num() == 0; int64_t past_k_b_split_axis = -1; int64_t past_k_h_split_axis = -1; JUST(ParseSplitAxis(past_key_layout, can_hk_split, &past_k_b_split_axis, &past_k_h_split_axis)); int64_t past_v_b_split_axis = -1; int64_t past_v_h_split_axis = -1; JUST(ParseSplitAxis(past_value_layout, can_hk_split, &past_v_b_split_axis, &past_v_h_split_axis)); int64_t k_b_split_axis = -1; int64_t k_h_split_axis = -1; JUST(ParseSplitAxis(key_layout, can_hk_split, &k_b_split_axis, &k_h_split_axis)); int64_t v_b_split_axis = -1; int64_t v_h_split_axis = -1; JUST(ParseSplitAxis(value_layout, can_hk_split, &v_b_split_axis, &v_h_split_axis)); std::vector past_key_arg; if (ctx->user_op_conf().has_input("past_key", 0)) { past_key_arg.emplace_back("past_key", 0); } std::vector past_value_arg; if (ctx->user_op_conf().has_input("past_value", 0)) { past_value_arg.emplace_back("past_value", 0); } if (past_k_b_split_axis >= 0 && past_v_b_split_axis >= 0 && k_b_split_axis >= 0 && v_b_split_axis >= 0) { ctx->NewBuilder() .Split(past_key_arg, past_k_b_split_axis) .Split(past_value_arg, past_v_b_split_axis) .Split(user_op::OpArg("key", 0), k_b_split_axis) .Split(user_op::OpArg("value", 0), v_b_split_axis) .Split(user_op::OpArg("output_key", 0), past_k_b_split_axis) .Split(user_op::OpArg("output_value", 0), past_v_b_split_axis) .Build(); } if (past_k_h_split_axis >= 0 && past_v_h_split_axis >= 0 && k_h_split_axis >= 0 && v_h_split_axis >= 0) { ctx->NewBuilder() .Split(past_key_arg, past_k_h_split_axis) .Split(past_value_arg, past_v_h_split_axis) .Split(user_op::OpArg("key", 0), k_h_split_axis) .Split(user_op::OpArg("value", 0), v_h_split_axis) .Split(user_op::OpArg("output_key", 0), past_k_h_split_axis) .Split(user_op::OpArg("output_value", 0), past_v_h_split_axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe 
FusedApplyRotaryEmbOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const std::string& x_layout = ctx->Attr("x_layout"); const std::string& output_layout = ctx->Attr("output_layout"); const std::string& mode = ctx->Attr("mode"); const int64_t rotary_size = ctx->Attr("rotary_size"); const int64_t k_size = ctx->Attr("k_size"); const int64_t tensor_index = ctx->Attr("tensor_index"); CHECK_OR_RETURN((tensor_index >= 0) && (tensor_index <= 2)) << "tensor_index should be in range [0, 2]."; CHECK_OR_RETURN((mode == "interval") || (mode == "plane")) << "mode should be either \"interval\" or \"plane\"."; CHECK_OR_RETURN(output_layout != "BM(H2K)" && output_layout != "BM(H3K)" && output_layout != "MB(H2K)" && output_layout != "MB(H3K)") << "output_layout should not be \"BM(H2k)\", \"BM(H3K)\", \"MB(H2K)\", \"MB(H3K)\"."; int64_t b = 0, m = 0, h = 0, k = 0; JUST(ParseDims(x_desc.shape(), x_layout, Optional(), Optional(k_size), &b, &m, &h, &k)); CHECK_LE_OR_RETURN(rotary_size, k) << "rotary_size should be no more than K of input x."; int64_t rotary_emb_dim = 1; if (ctx->has_input("position_ids", 0)) { const user_op::TensorDesc& position_ids_desc = ctx->InputTensorDesc("position_ids", 0); CHECK_EQ_OR_RETURN(position_ids_desc.shape().NumAxes(), 3) << "ndims of position_ids should be equal to 3, either in form of B1M or B2M."; CHECK_EQ_OR_RETURN(position_ids_desc.shape().At(0), b) << "1st dim of position_ids should be equal to B."; CHECK_EQ_OR_RETURN(position_ids_desc.shape().At(2), m) << "3rd dim of position_ids should be equal to M."; rotary_emb_dim = position_ids_desc.shape().At(1); CHECK_OR_RETURN(rotary_emb_dim == 1 || rotary_emb_dim == 2) << "2nd dim of position_ids should be 1 or 2."; } const int64_t actual_rotary_size = rotary_size / rotary_emb_dim; CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0) << "rotary_size should be a multiple of 2 * rotary_encoding_dim."; bool has_cos = ctx->has_input("cos", 0); bool has_sin = ctx->has_input("sin", 0); // TODO: fused_apply_rotary_emb have same logic no matter name if (has_cos && has_sin) { const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); const user_op::TensorDesc& sin_desc = ctx->InputTensorDesc("sin", 0); CHECK_EQ_OR_RETURN(cos_desc.shape().NumAxes(), 2) << "The number of dimensions of cos should be equal to 2."; CHECK_OR_RETURN(cos_desc.shape() == sin_desc.shape()) << "The dimensions of cos & sin should be the same."; CHECK_EQ_OR_RETURN(cos_desc.shape().At(1), actual_rotary_size) << "The 1st dimension of cos & sin should equal to rotary_size // " "rotary_embedding_dimension."; } else if (!has_cos && !has_sin) { // Do nothing } else { UNIMPLEMENTED_THEN_RETURN(); } if (!ctx->has_input("position_ids", 0)) { if (has_cos && has_sin) { const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); CHECK_GE_OR_RETURN(cos_desc.shape().At(0), m) << "M of cos should be no less than M of x if position_ids is not given."; // K of cos & sin is checked inside ParseDims } } Shape out_shape = *JUST(LayoutToShape(b, m, h, k, output_layout)); ctx->SetOutputShape("out", 0, out_shape); return Maybe::Ok(); } /*static*/ Maybe FusedApplyRotaryEmbOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedApplyRotaryEmbOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); int num_heads = -1; const int64_t k_size = ctx->Attr("k_size"); const 
std::string& x_layout = ctx->Attr("x_layout"); const std::string& output_layout = ctx->Attr("output_layout"); if (x_desc.shape().NumAxes() == 2) { if (x_layout == "(BM)(HK)") { CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % k_size, 0); num_heads = x_desc.shape().At(1) / k_size; } else if (x_layout == "(BM)(H3K)") { CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % (k_size * 3), 0); num_heads = x_desc.shape().At(1) / (k_size * 3); } else { UNIMPLEMENTED_THEN_RETURN(); } } else if (x_desc.shape().NumAxes() == 3) { if (x_layout == "BM(HK)" || x_layout == "MB(HK)") { CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % k_size, 0); num_heads = x_desc.shape().At(2) / k_size; } else if (x_layout == "BM(H3K)" || x_layout == "MB(H3K)") { CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % (k_size * 3), 0); num_heads = x_desc.shape().At(2) / (k_size * 3); } else if (x_layout == "(BM)HK") { num_heads = x_desc.shape().At(1); } else { UNIMPLEMENTED_THEN_RETURN(); } } else if (x_desc.shape().NumAxes() == 4) { if (x_layout == "BMHK") { num_heads = x_desc.shape().At(2); } else if (x_layout == "BHMK") { num_heads = x_desc.shape().At(1); } else { UNIMPLEMENTED_THEN_RETURN(); } } else { UNIMPLEMENTED_THEN_RETURN(); } const bool can_hk_split = num_heads % ctx->parallel_num() == 0; int64_t x_b_split_axis = -1; int64_t x_h_split_axis = -1; JUST(ParseSplitAxis(x_layout, can_hk_split, &x_b_split_axis, &x_h_split_axis)); int64_t o_b_split_axis = -1; int64_t o_h_split_axis = -1; JUST(ParseSplitAxis(output_layout, can_hk_split, &o_b_split_axis, &o_h_split_axis)); if (x_b_split_axis >= 0 && o_b_split_axis >= 0) { auto builder = ctx->NewBuilder() .Split(user_op::OpArg("x", 0), x_b_split_axis) .Split(user_op::OpArg("out", 0), o_b_split_axis); if (ctx->user_op_conf().has_input("cos", 0)) builder = builder.Broadcast(user_op::OpArg("cos", 0)).Broadcast(user_op::OpArg("sin", 0)); if (ctx->user_op_conf().has_input("position_ids", 0)) builder = builder.Split(user_op::OpArg("position_ids", 0), 0); builder.Build(); } if (x_h_split_axis >= 0 && o_h_split_axis >= 0) { auto builder = ctx->NewBuilder() .Split(user_op::OpArg("x", 0), x_h_split_axis) .Split(user_op::OpArg("out", 0), o_h_split_axis); if (ctx->user_op_conf().has_input("cos", 0)) builder = builder.Broadcast(user_op::OpArg("cos", 0)).Broadcast(user_op::OpArg("sin", 0)); if (ctx->user_op_conf().has_input("position_ids", 0)) builder = builder.Broadcast(user_op::OpArg("position_ids", 0)); builder.Build(); } return Maybe::Ok(); } /* static */ Maybe FusedApplyRotaryEmbOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("x", 0); bool has_sinuous = ctx->has_input("cos", 0); if (has_sinuous) { const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); const user_op::TensorDesc& sin_desc = ctx->InputTensorDesc("sin", 0); CHECK_EQ_OR_RETURN(cos_desc.data_type(), first_in_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(first_in_desc.data_type()) << ", but got " << DataType_Name(cos_desc.data_type()); CHECK_EQ_OR_RETURN(sin_desc.data_type(), first_in_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(first_in_desc.data_type()) << ", but got " << DataType_Name(sin_desc.data_type()); } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(first_in_desc.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_bias_add_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto FusedBiasAddGeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); const auto bias_add_axis = ctx->Attr("axis"); CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); CHECK_GE_OR_RETURN(bias_add_axis, 0); CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); ctx->SetOutputShape("out", 0, a_tensor_desc.shape()); ctx->SetOutputIsDynamic("out", 0, a_tensor_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedBiasAddGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return FusedBiasAddGeluOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedBiasAddGeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); ctx->SetOutputDType("out", 0, a_tensor_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedBiasAddGeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const auto axis = ctx->Attr("axis"); for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); ++i) { if (i == axis) { continue; } ctx->NewBuilder() .Split(user_op::OpArg("a", 0), i) .Broadcast(user_op::OpArg("b", 0)) .Split(ctx->outputs(), i) .Build(); } ctx->NewBuilder() .Split(user_op::OpArg("b", 0), 0) .Split(user_op::OpArg("a", 0), axis) .Split(ctx->outputs(), axis) .Build(); return Maybe::Ok(); } /*static*/ auto FusedBiasAddGeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); const auto bias_add_axis = ctx->Attr("axis"); CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); CHECK_GE_OR_RETURN(bias_add_axis, 0); CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); ctx->SetOutputShape("dx", 0, a_tensor_desc.shape()); ctx->SetOutputIsDynamic("dx", 0, a_tensor_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedBiasAddGeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return FusedBiasAddGeluGradOp::InferLogicalTensorDesc(ctx); } /*static*/ auto 
FusedBiasAddGeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); ctx->SetOutputDType("dx", 0, a_tensor_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedBiasAddGeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const auto axis = ctx->Attr("axis"); for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); ++i) { if (i == axis) { continue; } ctx->NewBuilder() .Split(user_op::OpArg("a", 0), i) .Split(user_op::OpArg("dy", 0), i) .Broadcast(user_op::OpArg("b", 0)) .Split(ctx->outputs(), i) .Build(); } ctx->NewBuilder() .Split(user_op::OpArg("b", 0), 0) .Split(user_op::OpArg("a", 0), axis) .Split(user_op::OpArg("dy", 0), axis) .Split(ctx->outputs(), axis) .Build(); return Maybe::Ok(); } /*static*/ auto FusedBiasAddMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); const auto& mask_tensor_desc = ctx->InputTensorDesc("mask", 0); const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); const auto bias_add_axis = ctx->Attr("axis"); CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); CHECK_GE_OR_RETURN(bias_add_axis, 0); CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape()); ctx->SetOutputShape("out", 0, a_tensor_desc.shape()); ctx->SetOutputIsDynamic("out", 0, a_tensor_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedBiasAddMaskScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return FusedBiasAddMaskScaleOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedBiasAddMaskScaleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); ctx->SetOutputDType("out", 0, a_tensor_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedBiasAddMaskScaleOp::ModifyInputArg( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) -> Maybe { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); CHECK_OR_RETURN(mask_modifier != nullptr); mask_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ auto FusedBiasAddMaskScaleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const auto axis = ctx->Attr("axis"); std::vector split_args; split_args.emplace_back("a", 0); split_args.emplace_back("mask", 0); split_args.emplace_back("out", 0); if (ctx->user_op_conf().has_input("_add_to_output", 0)) { split_args.emplace_back("_add_to_output", 0); } for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); ++i) { if (i == axis) { continue; } ctx->NewBuilder().Split(split_args, i).Broadcast(user_op::OpArg("b", 0)).Build(); } ctx->NewBuilder().Split(user_op::OpArg("b", 0), 0).Split(split_args, axis).Build(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_bias_add_scale_mask_softmax_dropout_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { bool CheckBroadcastable(const Shape& shape, const Shape& broadcast_shape) { int left_pad = broadcast_shape.size() - shape.size(); if (left_pad < 0) { return false; } for (int i = 0; i < shape.size(); ++i) { int j = i + left_pad; if (shape[i] != 1 && shape[i] != broadcast_shape[j]) { return false; } } return true; } bool CheckBroadcastAndSimplifyDims(const Shape& shape, const Shape& broadcast_shape, int& simplified_ndim, int64_t* simplified_dims) { int lpad = broadcast_shape.size() - shape.size(); if (lpad < 0) { return false; } simplified_ndim = 0; bool prev_broadcast = false; for (int i = 0; i < broadcast_shape.size(); ++i) { int64_t dim = (i < lpad) ? 1 : shape[i - lpad]; int64_t broadcast_dim = broadcast_shape[i]; if (dim != 1 && dim != broadcast_dim) { return false; } bool broadcast = (dim == 1 && broadcast_dim != 1); if (simplified_ndim > 0 && broadcast == prev_broadcast) { // fold to prev dim simplified_dims[simplified_ndim - 1] *= dim; } else { simplified_dims[simplified_ndim] = dim; simplified_ndim += 1; } prev_broadcast = broadcast; } return true; } // return lpad int GetBroadcastDims(const Shape& shape, const Shape& broadcast_shape, HashSet& broadcast_dims) { int lpad = broadcast_shape.size() - shape.size(); if (lpad < 0) { return lpad; } for (int i = 0; i < broadcast_shape.size(); ++i) { if (i < lpad) { broadcast_dims.insert(i); } else { int j = i - lpad; if (shape[j] == 1 && shape[j] != broadcast_shape[i]) { broadcast_dims.insert(i); } if (shape[j] != 1 && shape[j] != broadcast_shape[i]) { return -1; } } } return lpad; } } // namespace Maybe FusedBiasAddScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& bias_shape = ctx->InputShape("bias", 0); const Shape& mask_shape = ctx->InputShape("mask", 0); const Shape& dropout_mask_shape = ctx->InputShape("dropout_mask", 0); CHECK_GE_OR_RETURN(x_shape.size(), 2) << Error::RuntimeError() << "x has at least 2 dimensions"; CHECK_EQ_OR_RETURN(x_shape.back(), mask_shape.back()) << " Last dimension of x and mask should be equal, which is softmax dimension."; CHECK_EQ_OR_RETURN(dropout_mask_shape, x_shape) << Error::RuntimeError() << "dropout_mask shape " << dropout_mask_shape.ToString() << " should be equal to x shape " << x_shape.ToString(); int simplified_bias_ndim = 0; int simplified_mask_ndim = 0; DimVector simplified_bias_dims(x_shape.size()); DimVector simplified_mask_dims(x_shape.size()); CHECK_OR_RETURN(CheckBroadcastAndSimplifyDims(bias_shape, x_shape, simplified_bias_ndim, simplified_bias_dims.data())) << Error::RuntimeError() << "bias shape " << bias_shape.ToString() << " could not be broadcast to x shape " << x_shape.ToString(); CHECK_OR_RETURN(CheckBroadcastAndSimplifyDims(mask_shape, x_shape, simplified_mask_ndim, simplified_mask_dims.data())) << Error::RuntimeError() << "mask shape " << mask_shape.ToString() << " could not be broadcast to x shape " << x_shape.ToString(); CHECK_GT_OR_RETURN(simplified_bias_ndim, 0); // 
NOLINT(maybe-need-error-msg) CHECK_GT_OR_RETURN(simplified_mask_ndim, 0); // NOLINT(maybe-need-error-msg) // (1, ) -> (K, ) // (M, 1) -> (M, N) // (1, N) -> (M, N) // (M, 1, N) -> (M, K, N) if ((simplified_bias_ndim == 2 && simplified_bias_dims[0] != 1) || simplified_bias_ndim > 2) { return Error::RuntimeError() << "bias only support (1, N)->(M, N) broadcast, but got bias shape " << bias_shape.ToString() << " broadcast to x shape " << x_shape.ToString(); } if (simplified_mask_ndim > 3 || (simplified_mask_ndim == 3 && simplified_mask_dims[1] != 1)) { return Error::RuntimeError() << "mask support (M, 1)->(M, N) or (1, N)->(M, N) or (M, 1, " "N)->(M, K, N) broadcast, but got mask shape " << mask_shape.ToString() << " broadcast to x shape " << x_shape.ToString(); } ctx->SetOutputShape("y", 0, x_shape); ctx->SetOutputShape("softmax_y", 0, x_shape); ctx->SetOutputIsDynamic("y", 0, ctx->InputIsDynamic("x", 0)); ctx->SetOutputIsDynamic("softmax_y", 0, ctx->InputIsDynamic("x", 0)); return Maybe::Ok(); } Maybe FusedBiasAddScaleMaskSoftmaxDropoutOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } Maybe FusedBiasAddScaleMaskSoftmaxDropoutOp::InferDataType(user_op::InferContext* ctx) { const DataType x_dtype = ctx->InputDType("x", 0); const DataType bias_dtype = ctx->InputDType("bias", 0); const DataType mask_dtype = ctx->InputDType("mask", 0); const DataType dropout_mask_dtype = ctx->InputDType("dropout_mask", 0); CHECK_EQ_OR_RETURN(bias_dtype, x_dtype) << Error::RuntimeError() << "Expected bias data type " << DataType_Name(x_dtype) << ", but got " << DataType_Name(bias_dtype); CHECK_OR_RETURN(IsBoolDataType(mask_dtype) || IsIntegralDataType(mask_dtype)) << Error::RuntimeError() << "Expected mask data type to be bool or integer, but got " << DataType_Name(mask_dtype); CHECK_OR_RETURN(IsBoolDataType(dropout_mask_dtype)) << Error::RuntimeError() << "Expected dropout_mask data type to be bool, but got " << DataType_Name(dropout_mask_dtype); ctx->SetOutputDType("y", 0, x_dtype); ctx->SetOutputDType("softmax_y", 0, x_dtype); return Maybe::Ok(); } Maybe FusedBiasAddScaleMaskSoftmaxDropoutOp::GetSbp(user_op::SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const Shape& bias_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("bias", 0).shape(); const Shape& mask_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("mask", 0).shape(); const Shape& dropout_mask_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("dropout_mask", 0).shape(); CHECK_GE_OR_RETURN(x_shape.size(), 2) << Error::RuntimeError() << "x has at least 2 dimensions"; CHECK_EQ_OR_RETURN(dropout_mask_shape, x_shape) << Error::RuntimeError() << "dropout_mask_shape shape " << dropout_mask_shape.ToString() << " should be equal to x shape " << x_shape.ToString(); HashSet bias_broadcast_dims; HashSet mask_broadcast_dims; int bias_lpad = GetBroadcastDims(bias_shape, x_shape, bias_broadcast_dims); int mask_lpad = GetBroadcastDims(mask_shape, x_shape, mask_broadcast_dims); CHECK_GE_OR_RETURN(bias_lpad, 0) << Error::RuntimeError() << "bias shape " << bias_shape.ToString() << " could not be broadcast to x shape " << x_shape.ToString(); CHECK_GE_OR_RETURN(mask_lpad, 0) << Error::RuntimeError() << "mask shape " << mask_shape.ToString() << " could not be broadcast to x shape " << x_shape.ToString(); std::vector split_args = { {"x", 0}, {"dropout_mask", 0}, {"y", 0}, {"softmax_y", 0}, }; for (int i = 0; i < x_shape.size(); ++i) { bool bias_can_split = 
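// For each axis i of x: bias/mask can only be split along i when i is not one of that
// operand's broadcast axes; where a split is possible, the operand's own split axis is
// i minus its left padding (lpad), because shorter shapes are right-aligned against x.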
(bias_broadcast_dims.find(i) == bias_broadcast_dims.end()); bool mask_can_split = (mask_broadcast_dims.find(i) == mask_broadcast_dims.end()); if (bias_can_split && mask_can_split) { CHECK_GE_OR_RETURN(i, bias_lpad); // NOLINT(maybe-need-error-msg) CHECK_GE_OR_RETURN(i, mask_lpad); // NOLINT(maybe-need-error-msg) ctx->NewBuilder() .Split(split_args, i) .Split(user_op::OpArg("bias", 0), i - bias_lpad) .Split(user_op::OpArg("mask", 0), i - mask_lpad) .Build(); } else if (bias_can_split) { CHECK_GE_OR_RETURN(i, bias_lpad); // NOLINT(maybe-need-error-msg) ctx->NewBuilder() .Split(split_args, i) .Split(user_op::OpArg("bias", 0), i - bias_lpad) .Broadcast(user_op::OpArg("mask", 0)) .Build(); } else if (mask_can_split) { CHECK_GE_OR_RETURN(i, mask_lpad); // NOLINT(maybe-need-error-msg) ctx->NewBuilder() .Split(split_args, i) .Broadcast(user_op::OpArg("bias", 0)) .Split(user_op::OpArg("mask", 0), i - mask_lpad) .Build(); } else { ctx->NewBuilder() .Split(split_args, i) .Broadcast(user_op::OpArg("bias", 0)) .Broadcast(user_op::OpArg("mask", 0)) .Build(); } } return Maybe::Ok(); } Maybe FusedBiasAddScaleMaskSoftmaxDropoutOp::ModifyInputArg( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); user_op::InputArgModifier* dropout_mask_modifier = GetInputArgModifierFn("dropout_mask", 0); CHECK_OR_RETURN(mask_modifier != nullptr) << " cannot find mask input."; CHECK_OR_RETURN(dropout_mask_modifier != nullptr) << " cannot find dropout mask input."; mask_modifier->set_requires_grad(false); dropout_mask_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_cast_scale_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedCastScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); CHECK_EQ_OR_RETURN(scale_by_tensor.shape().NumAxes(), 1); CHECK_EQ_OR_RETURN(scale_by_tensor.shape().At(0), 1); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_is_dynamic(x.is_dynamic()); y->set_shape(x.shape()); return Maybe::Ok(); } Maybe FusedCastScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedCastScaleOp::InferLogicalTensorDesc(ctx); } Maybe FusedCastScaleOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_data_type(scale_by_tensor.data_type()); return Maybe::Ok(); } Maybe FusedCastScaleOp::GetSbp(user_op::SbpContext* ctx) { const auto& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); for (int i = 0; i < x.shape().NumAxes(); ++i) { ctx->NewBuilder() .Broadcast(user_op::OpArg("scale_by_tensor", 0)) .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("y", 0), i) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("scale_by_tensor", 0)) .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("scale_by_tensor", 0)) .PartialSum(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_center_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedCenterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc("b1_x2", 0); const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc("b1_y1", 0); const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc("b1_y2", 0); const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc("b2_x1", 0); const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc("b2_x2", 0); const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc("b2_y1", 0); const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc("b2_y2", 0); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_x2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y2.shape()); user_op::TensorDesc* rho = ctx->MutOutputTensorDesc("rho2", 0); rho->set_is_dynamic(b1_x1.is_dynamic()); rho->set_shape(b1_x1.shape()); return Maybe::Ok(); } Maybe FusedCenterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedCenterOp::InferLogicalTensorDesc(ctx); } Maybe FusedCenterOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc("b1_x2", 0); const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc("b1_y1", 0); const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc("b1_y2", 0); const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc("b2_x1", 0); const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc("b2_x2", 0); const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc("b2_y1", 0); const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc("b2_y2", 0); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_x2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y2.data_type()); user_op::TensorDesc* rho = ctx->MutOutputTensorDesc("rho2", 0); rho->set_data_type(b1_x1.data_type()); return Maybe::Ok(); } Maybe FusedCenterOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("b1_x1", 0); FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("b1_x1", 0), i) .Split(user_op::OpArg("b1_x2", 0), i) .Split(user_op::OpArg("b1_y1", 0), i) .Split(user_op::OpArg("b1_y2", 0), i) .Split(user_op::OpArg("b2_x1", 0), i) .Split(user_op::OpArg("b2_x2", 0), i) .Split(user_op::OpArg("b2_y1", 0), i) .Split(user_op::OpArg("b2_y2", 0), i) .Split(user_op::OpArg("rho2", 0), i) .Build(); } return Maybe::Ok(); } Maybe FusedCenterGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc("b1_x2", 0); const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc("b1_y1", 0); const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc("b1_y2", 0); const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc("b2_x1", 0); const 
user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc("b2_x2", 0); const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc("b2_y1", 0); const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc("b2_y2", 0); const user_op::TensorDesc& rho2_diff = ctx->InputTensorDesc("rho2_diff", 0); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_x2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), rho2_diff.shape()); user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc("b1_x1_diff", 0); b1_x1_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_x1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc("b1_x2_diff", 0); b1_x2_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_x2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc("b2_x1_diff", 0); b2_x1_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_x1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc("b2_x2_diff", 0); b2_x2_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_x2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc("b1_y1_diff", 0); b1_y1_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_y1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc("b1_y2_diff", 0); b1_y2_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_y2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc("b2_y1_diff", 0); b2_y1_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_y1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc("b2_y2_diff", 0); b2_y2_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_y2_diff->set_shape(b1_x1.shape()); return Maybe::Ok(); } Maybe FusedCenterGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedCenterGradOp::InferLogicalTensorDesc(ctx); } Maybe FusedCenterGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc("b1_x2", 0); const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc("b1_y1", 0); const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc("b1_y2", 0); const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc("b2_x1", 0); const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc("b2_x2", 0); const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc("b2_y1", 0); const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc("b2_y2", 0); const user_op::TensorDesc& rho2_diff = ctx->InputTensorDesc("rho2_diff", 0); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_x2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), rho2_diff.data_type()); user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc("b1_x1_diff", 0); b1_x1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc("b1_x2_diff", 0); 
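// The data-type checks above require every input to share b1_x1's dtype, so each of the
// eight coordinate-gradient outputs below simply copies b1_x1.data_type().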
b1_x2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc("b2_x1_diff", 0); b2_x1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc("b2_x2_diff", 0); b2_x2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc("b1_y1_diff", 0); b1_y1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc("b1_y2_diff", 0); b1_y2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc("b2_y1_diff", 0); b2_y1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc("b2_y2_diff", 0); b2_y2_diff->set_data_type(b1_x1.data_type()); return Maybe::Ok(); } Maybe FusedCenterGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("b1_x1", 0); FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("b1_x1", 0), i) .Split(user_op::OpArg("b1_x2", 0), i) .Split(user_op::OpArg("b1_y1", 0), i) .Split(user_op::OpArg("b1_y2", 0), i) .Split(user_op::OpArg("b2_x1", 0), i) .Split(user_op::OpArg("b2_x2", 0), i) .Split(user_op::OpArg("b2_y1", 0), i) .Split(user_op::OpArg("b2_y2", 0), i) .Split(user_op::OpArg("rho2_diff", 0), i) .Split(user_op::OpArg("b1_x1_diff", 0), i) .Split(user_op::OpArg("b1_x2_diff", 0), i) .Split(user_op::OpArg("b1_y1_diff", 0), i) .Split(user_op::OpArg("b1_y2_diff", 0), i) .Split(user_op::OpArg("b2_x1_diff", 0), i) .Split(user_op::OpArg("b2_x2_diff", 0), i) .Split(user_op::OpArg("b2_y1_diff", 0), i) .Split(user_op::OpArg("b2_y2_diff", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_clip_grad_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn, const std::string& arg_name, int32_t arg_index) { user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index); CHECK_NOTNULL_OR_RETURN(arg_modifier) << "Arg Modifier should not be null. 
"; arg_modifier->set_is_mutable(true); return Maybe::Ok(); } Maybe InputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { for (int64_t i = 0; i < conf.input_size("model_diff"); i++) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model_diff", i)); } return Maybe::Ok(); } } // namespace /* static */ Maybe FusedClipGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& in_0 = ctx->InputTensorDesc("model_diff", 0); auto* out = ctx->MutOutputTensorDesc("out", 0); for (int64_t i = 1; i < ctx->input_size("model_diff"); ++i) { const auto& cur_in = ctx->InputTensorDesc("model_diff", i); CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape()) << Error::RuntimeError() << "inconsistent tensor size, expected all tensor to have the same shape, " << "but got " << in_0.shape().DebugStr() << " and " << cur_in.shape().DebugStr(); } out->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe FusedClipGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedClipGradOp::GetSbp(user_op::SbpContext* ctx) { const int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff", 0).shape().NumAxes(); for (int64_t i = 0; i < num_axes; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg("out", 0)).Build(); return Maybe::Ok(); } /* static */ Maybe FusedClipGradOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return InputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe FusedClipGradOp::InferDataType(user_op::InferContext* ctx) { const auto& in_0 = ctx->InputTensorDesc("model_diff", 0); auto* out = ctx->MutOutputTensorDesc("out", 0); const DataType data_type = in_0.data_type(); for (int64_t i = 1; i < ctx->input_size("model_diff"); ++i) { const auto& cur_in = ctx->InputTensorDesc("model_diff", i); CHECK_EQ_OR_RETURN(cur_in.data_type(), data_type) << Error::RuntimeError() << ctx->op_name() << " expected all tenser to have same type, but found " << DataType_Name(cur_in.data_type()) << " and " << DataType_Name(data_type); } out->set_data_type(data_type); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_codegeex_qkv_reshape.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedCodegeexQkvReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& query = ctx->InputTensorDesc("query", 0); const user_op::TensorDesc& key = ctx->InputTensorDesc("key", 0); const user_op::TensorDesc& value = ctx->InputTensorDesc("value", 0); const int32_t num_attention_heads = ctx->Attr("num_attention_heads"); CHECK_EQ_OR_RETURN(query.shape().size(), 3) << "query shape size should be equal 3"; CHECK_EQ_OR_RETURN(key.shape().size(), 3) << "key shape size should be equal 3"; CHECK_EQ_OR_RETURN(value.shape().size(), 3) << "value shape size should be equal 3"; CHECK_EQ_OR_RETURN(query.shape(), key.shape()) << "query, key, value should has same shape in codegeex attention block"; CHECK_EQ_OR_RETURN(query.shape(), value.shape()) << "query, key, value should has same shape in codegeex attention block"; CHECK_EQ_OR_RETURN(query.shape()[2] % num_attention_heads, 0) << "hidden_size must be divisible by num_attention_heads"; Shape new_shape(DimVector{query.shape()[0], query.shape()[1], num_attention_heads, query.shape()[2] / num_attention_heads}); user_op::TensorDesc* new_query = ctx->MutOutputTensorDesc("new_query", 0); new_query->set_is_dynamic(query.is_dynamic()); new_query->set_shape(new_shape); user_op::TensorDesc* new_key = ctx->MutOutputTensorDesc("new_key", 0); new_key->set_is_dynamic(key.is_dynamic()); new_key->set_shape(new_shape); user_op::TensorDesc* new_value = ctx->MutOutputTensorDesc("new_value", 0); new_value->set_is_dynamic(value.is_dynamic()); new_value->set_shape(new_shape); return Maybe::Ok(); } Maybe FusedCodegeexQkvReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedCodegeexQkvReshapeOp::InferLogicalTensorDesc(ctx); } Maybe FusedCodegeexQkvReshapeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& query = ctx->InputTensorDesc("query", 0); const user_op::TensorDesc& key = ctx->InputTensorDesc("key", 0); const user_op::TensorDesc& value = ctx->InputTensorDesc("value", 0); user_op::TensorDesc* new_query = ctx->MutOutputTensorDesc("new_query", 0); new_query->set_data_type(query.data_type()); user_op::TensorDesc* new_key = ctx->MutOutputTensorDesc("new_key", 0); new_key->set_data_type(key.data_type()); user_op::TensorDesc* new_value = ctx->MutOutputTensorDesc("new_value", 0); new_value->set_data_type(value.data_type()); return Maybe::Ok(); } Maybe FusedCodegeexQkvReshapeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& query = ctx->LogicalTensorDesc4InputArgNameAndIndex("query", 0); FOR_RANGE(int64_t, i, 0, query.shape().NumAxes() - 1) { ctx->NewBuilder() .Split(user_op::OpArg("query", 0), i) .Split(user_op::OpArg("key", 0), i) .Split(user_op::OpArg("value", 0), i) .Split(user_op::OpArg("new_query", 0), i) .Split(user_op::OpArg("new_key", 0), i) .Split(user_op::OpArg("new_value", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_cross_feature_interaction_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe FusedCrossFeatureInteractionOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); CHECK_EQ_OR_RETURN(x_shape.At(1), weight_shape.At(1)) << "Matmul K dims should be equal. "; ctx->SetOutputShape("matmul_result", 0, Shape({x_shape.At(0), weight_shape.At(0)})); const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& bias_shape = ctx->InputShape("bias", 0); CHECK_EQ_OR_RETURN(bias_shape.At(0), x0_shape.At(1)) << "Bias dim should be equal to X0 dim1. "; ctx->SetOutputShape("out", 0, x0_shape); return Maybe::Ok(); } /* static */ Maybe FusedCrossFeatureInteractionOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedCrossFeatureInteractionOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("x0", 0), 0) .Broadcast(user_op::OpArg("bias", 0)) .Split(user_op::OpArg("matmul_result", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe FusedCrossFeatureInteractionOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("matmul_result", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /* static */ Maybe FusedCrossFeatureInteractionV1GradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); ctx->SetOutputShape("dx0", 0, x0_shape); ctx->SetOutputShape("dw", 0, weight_shape); ctx->SetOutputShape("dx", 0, x0_shape); ctx->SetOutputShape("dbias", 0, Shape({x0_shape.At(1)})); return Maybe::Ok(); } /* static */ Maybe FusedCrossFeatureInteractionV1GradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedCrossFeatureInteractionV1GradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("x0", 0), 0) .Split(user_op::OpArg("matmul_result", 0), 0) .Split(user_op::OpArg("dx0", 0), 0) .PartialSum(user_op::OpArg("dw", 0)) .Split(user_op::OpArg("dx", 0), 0) .PartialSum(user_op::OpArg("dbias", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe FusedCrossFeatureInteractionV1GradOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("dx0", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("dw", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("dbias", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /* static */ Maybe FusedCrossFeatureInteractionV2GradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& weight_shape = 
ctx->InputShape("weight", 0); ctx->SetOutputShape("dx0", 0, x0_shape); ctx->SetOutputShape("dw", 0, weight_shape); ctx->SetOutputShape("dx", 0, x0_shape); ctx->SetOutputShape("dbias", 0, Shape({x0_shape.At(1)})); return Maybe::Ok(); } /* static */ Maybe FusedCrossFeatureInteractionV2GradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedCrossFeatureInteractionV2GradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Broadcast(user_op::OpArg("weight", 0)) .Broadcast(user_op::OpArg("bias", 0)) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("x0", 0), 0) .Split(user_op::OpArg("matmul_result", 0), 0) .Split(user_op::OpArg("dx0", 0), 0) .PartialSum(user_op::OpArg("dw", 0)) .Split(user_op::OpArg("dx", 0), 0) .PartialSum(user_op::OpArg("dbias", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe FusedCrossFeatureInteractionV2GradOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("dx0", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("dw", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("dbias", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_dot_feature_interaction_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe FusedDotFeatureInteractionOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const int64_t feature_input_size = ctx->input_size("features"); CHECK_GE_OR_RETURN(feature_input_size, 1); const Shape& first_feature_shape = ctx->InputShape("features", 0); CHECK_EQ_OR_RETURN(first_feature_shape.NumAxes(), 3); const int64_t batch_size = first_feature_shape.At(0); const int64_t vector_size = first_feature_shape.At(2); int64_t features_concated_dim = first_feature_shape.At(1); for (int64_t i = 1; i < feature_input_size; ++i) { const Shape& feature_shape = ctx->InputShape("features", i); CHECK_EQ_OR_RETURN(feature_shape.NumAxes(), 3); CHECK_EQ_OR_RETURN(feature_shape.At(0), batch_size); CHECK_EQ_OR_RETURN(feature_shape.At(2), vector_size); features_concated_dim += feature_shape.At(1); } const std::string& pooling = ctx->Attr("pooling"); if (pooling == "sum") { ctx->SetOutputShape("out", 0, Shape({batch_size, vector_size})); return Maybe::Ok(); } if (ctx->has_input("sparse_feature", 0)) { CHECK_OR_RETURN(pooling == "none") << "only none pooling support sparse feature."; CHECK_OR_RETURN(ctx->has_input("sparse_indices", 0)) << "if input sparse_feature exists, must have input sparse_indices."; const Shape& sparse_feature_shape = ctx->InputShape("sparse_feature", 0); const Shape& sparse_indices_shape = ctx->InputShape("sparse_indices", 0); CHECK_EQ_OR_RETURN(sparse_indices_shape.NumAxes(), 2) << "sparse_indices num_axes must be 2, but get " << sparse_indices_shape.NumAxes(); CHECK_EQ_OR_RETURN(sparse_indices_shape.At(0), batch_size) << "get " << sparse_indices_shape.At(0) << " and " << batch_size; CHECK_EQ_OR_RETURN(sparse_feature_shape.At(sparse_feature_shape.NumAxes() - 1), vector_size) << "get " << sparse_feature_shape.At(sparse_feature_shape.NumAxes() - 1) << " and " << vector_size; features_concated_dim += sparse_indices_shape.At(1); } const bool self_interaction = ctx->Attr("self_interaction"); const int32_t output_padding = ctx->Attr("output_padding"); const int64_t interaction_dim = self_interaction ? 
features_concated_dim * (features_concated_dim + 1) / 2 : features_concated_dim * (features_concated_dim - 1) / 2; int64_t out_dim = interaction_dim + output_padding; if (ctx->has_input("output_concat", 0)) { const Shape& output_concat_shape = ctx->InputShape("output_concat", 0); CHECK_EQ_OR_RETURN(output_concat_shape.NumAxes(), 2); CHECK_EQ_OR_RETURN(output_concat_shape.At(0), batch_size); out_dim += output_concat_shape.At(1); } ctx->SetOutputShape("out", 0, Shape({batch_size, out_dim})); return Maybe::Ok(); } /* static */ Maybe FusedDotFeatureInteractionOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedDotFeatureInteractionOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0); if (ctx->user_op_conf().has_input("num_valid_sparse_feature", 0)) { builder.Broadcast(user_op::OpArg("num_valid_sparse_feature", 0)); } builder.Build(); return Maybe::Ok(); } /* static */ Maybe FusedDotFeatureInteractionOp::InferDataType(user_op::InferContext* ctx) { const int64_t feature_input_size = ctx->input_size("features"); CHECK_GE_OR_RETURN(feature_input_size, 1); DataType first_feature_dtype = ctx->InputDType("features", 0); for (int64_t i = 1; i < feature_input_size; ++i) { CHECK_EQ_OR_RETURN(first_feature_dtype, ctx->InputDType("features", i)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("features", i)) << ", but got " << DataType_Name(first_feature_dtype); } if (ctx->has_input("output_concat", 0)) { CHECK_EQ_OR_RETURN(first_feature_dtype, ctx->InputDType("output_concat", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("output_concat", 0)) << ", but got " << DataType_Name(first_feature_dtype); } if (ctx->has_input("sparse_feature", 0)) { CHECK_EQ_OR_RETURN(first_feature_dtype, ctx->InputDType("sparse_feature", 0)) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("sparse_feature", 0)) << ", but got " << DataType_Name(first_feature_dtype); } ctx->SetOutputDType("out", 0, first_feature_dtype); return Maybe::Ok(); } /* static */ Maybe FusedDotFeatureInteractionGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const int64_t batch_size = dy_shape.At(0); CHECK_EQ_OR_RETURN(ctx->output_size("features_grad"), ctx->input_size("features")) << "features_grad and features must have same size"; for (int64_t i = 0; i < ctx->output_size("features_grad"); ++i) { ctx->SetOutputShape("features_grad", i, ctx->InputShape("features", i)); } if (ctx->has_output("output_concat_grad", 0)) { const int32_t output_concat_grad_dim = ctx->Attr("output_concat_grad_dim"); ctx->SetOutputShape("output_concat_grad", 0, Shape({batch_size, output_concat_grad_dim})); } if (ctx->has_output("sparse_feature_grad", 0)) { ctx->SetOutputShape("sparse_feature_grad", 0, ctx->InputShape("sparse_feature", 0)); } return Maybe::Ok(); } /* static */ Maybe FusedDotFeatureInteractionGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedDotFeatureInteractionGradOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0); if (ctx->user_op_conf().has_input("num_valid_sparse_feature", 0)) { builder.Broadcast(user_op::OpArg("num_valid_sparse_feature", 0)); } builder.Build(); return Maybe::Ok(); } /* static */ Maybe FusedDotFeatureInteractionGradOp::InferDataType( user_op::InferContext* ctx) { DataType dy_dtype = ctx->InputDType("dy", 0); for (int64_t i = 0; i < ctx->output_size("features_grad"); ++i) { ctx->SetOutputDType("features_grad", i, dy_dtype); } if (ctx->has_output("output_concat_grad", 0)) { ctx->SetOutputDType("output_concat_grad", 0, dy_dtype); } if (ctx->has_output("sparse_feature_grad", 0)) { ctx->SetOutputDType("sparse_feature_grad", 0, dy_dtype); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_get_boundding_boxes_coord_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedGetBounddingBoxesCoordOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x1 = ctx->InputTensorDesc("x1", 0); Shape x1_shape = x1.shape(); user_op::TensorDesc* b1_x1 = ctx->MutOutputTensorDesc("b1_x1", 0); b1_x1->set_is_dynamic(x1.is_dynamic()); b1_x1->set_shape(x1_shape); user_op::TensorDesc* b1_x2 = ctx->MutOutputTensorDesc("b1_x2", 0); b1_x2->set_is_dynamic(x1.is_dynamic()); b1_x2->set_shape(x1_shape); user_op::TensorDesc* b1_y1 = ctx->MutOutputTensorDesc("b1_y1", 0); b1_y1->set_is_dynamic(x1.is_dynamic()); b1_y1->set_shape(x1_shape); user_op::TensorDesc* b1_y2 = ctx->MutOutputTensorDesc("b1_y2", 0); b1_y2->set_is_dynamic(x1.is_dynamic()); b1_y2->set_shape(x1_shape); user_op::TensorDesc* b2_x1 = ctx->MutOutputTensorDesc("b2_x1", 0); b2_x1->set_is_dynamic(x1.is_dynamic()); b2_x1->set_shape(x1_shape); user_op::TensorDesc* b2_x2 = ctx->MutOutputTensorDesc("b2_x2", 0); b2_x2->set_is_dynamic(x1.is_dynamic()); b2_x2->set_shape(x1_shape); user_op::TensorDesc* b2_y1 = ctx->MutOutputTensorDesc("b2_y1", 0); b2_y1->set_is_dynamic(x1.is_dynamic()); b2_y1->set_shape(x1_shape); user_op::TensorDesc* b2_y2 = ctx->MutOutputTensorDesc("b2_y2", 0); b2_y2->set_is_dynamic(x1.is_dynamic()); b2_y2->set_shape(x1_shape); return Maybe::Ok(); } Maybe FusedGetBounddingBoxesCoordOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetBounddingBoxesCoordOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetBounddingBoxesCoordOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x1 = ctx->InputTensorDesc("x1", 0); user_op::TensorDesc* b1_x1 = ctx->MutOutputTensorDesc("b1_x1", 0); b1_x1->set_data_type(x1.data_type()); user_op::TensorDesc* b1_x2 = ctx->MutOutputTensorDesc("b1_x2", 0); b1_x2->set_data_type(x1.data_type()); user_op::TensorDesc* b1_y1 = ctx->MutOutputTensorDesc("b1_y1", 0); b1_y1->set_data_type(x1.data_type()); user_op::TensorDesc* b1_y2 = ctx->MutOutputTensorDesc("b1_y2", 0); b1_y2->set_data_type(x1.data_type()); user_op::TensorDesc* b2_x1 = ctx->MutOutputTensorDesc("b2_x1", 0); b2_x1->set_data_type(x1.data_type()); user_op::TensorDesc* b2_x2 = ctx->MutOutputTensorDesc("b2_x2", 0); b2_x2->set_data_type(x1.data_type()); user_op::TensorDesc* b2_y1 = ctx->MutOutputTensorDesc("b2_y1", 0); b2_y1->set_data_type(x1.data_type()); user_op::TensorDesc* b2_y2 = ctx->MutOutputTensorDesc("b2_y2", 0); b2_y2->set_data_type(x1.data_type()); return Maybe::Ok(); } Maybe FusedGetBounddingBoxesCoordOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); FOR_RANGE(int64_t, i, 0, x1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x1", 0), i) .Split(user_op::OpArg("y1", 0), i) .Split(user_op::OpArg("w1", 0), i) .Split(user_op::OpArg("h1", 0), i) .Split(user_op::OpArg("x2", 0), i) .Split(user_op::OpArg("y2", 0), i) .Split(user_op::OpArg("w2", 0), i) .Split(user_op::OpArg("h2", 0), i) .Split(user_op::OpArg("b1_x1", 0), i) .Split(user_op::OpArg("b1_x2", 0), i) .Split(user_op::OpArg("b1_y1", 0), i) .Split(user_op::OpArg("b1_y2", 0), i) .Split(user_op::OpArg("b2_x1", 0), i) .Split(user_op::OpArg("b2_x2", 0), i) .Split(user_op::OpArg("b2_y1", 0), i) .Split(user_op::OpArg("b2_y2", 0), i) .Build(); } return Maybe::Ok(); } Maybe FusedGetBounddingBoxesCoordGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1_diff = 
ctx->InputTensorDesc("b1_x1_diff", 0); user_op::TensorDesc* x1_diff = ctx->MutOutputTensorDesc("x1_diff", 0); x1_diff->set_is_dynamic(b1_x1_diff.is_dynamic()); x1_diff->set_shape(b1_x1_diff.shape()); user_op::TensorDesc* y1_diff = ctx->MutOutputTensorDesc("y1_diff", 0); y1_diff->set_is_dynamic(b1_x1_diff.is_dynamic()); y1_diff->set_shape(b1_x1_diff.shape()); user_op::TensorDesc* w1_diff = ctx->MutOutputTensorDesc("w1_diff", 0); w1_diff->set_is_dynamic(b1_x1_diff.is_dynamic()); w1_diff->set_shape(b1_x1_diff.shape()); user_op::TensorDesc* h1_diff = ctx->MutOutputTensorDesc("h1_diff", 0); h1_diff->set_is_dynamic(b1_x1_diff.is_dynamic()); h1_diff->set_shape(b1_x1_diff.shape()); user_op::TensorDesc* x2_diff = ctx->MutOutputTensorDesc("x2_diff", 0); x2_diff->set_is_dynamic(b1_x1_diff.is_dynamic()); x2_diff->set_shape(b1_x1_diff.shape()); user_op::TensorDesc* y2_diff = ctx->MutOutputTensorDesc("y2_diff", 0); y2_diff->set_is_dynamic(b1_x1_diff.is_dynamic()); y2_diff->set_shape(b1_x1_diff.shape()); user_op::TensorDesc* w2_diff = ctx->MutOutputTensorDesc("w2_diff", 0); w2_diff->set_is_dynamic(b1_x1_diff.is_dynamic()); w2_diff->set_shape(b1_x1_diff.shape()); user_op::TensorDesc* h2_diff = ctx->MutOutputTensorDesc("h2_diff", 0); h2_diff->set_is_dynamic(b1_x1_diff.is_dynamic()); h2_diff->set_shape(b1_x1_diff.shape()); return Maybe::Ok(); } Maybe FusedGetBounddingBoxesCoordGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetBounddingBoxesCoordGradOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetBounddingBoxesCoordGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1_diff = ctx->InputTensorDesc("b1_x1_diff", 0); user_op::TensorDesc* x1_diff = ctx->MutOutputTensorDesc("x1_diff", 0); x1_diff->set_data_type(b1_x1_diff.data_type()); user_op::TensorDesc* y1_diff = ctx->MutOutputTensorDesc("y1_diff", 0); y1_diff->set_data_type(b1_x1_diff.data_type()); user_op::TensorDesc* w1_diff = ctx->MutOutputTensorDesc("w1_diff", 0); w1_diff->set_data_type(b1_x1_diff.data_type()); user_op::TensorDesc* h1_diff = ctx->MutOutputTensorDesc("h1_diff", 0); h1_diff->set_data_type(b1_x1_diff.data_type()); user_op::TensorDesc* x2_diff = ctx->MutOutputTensorDesc("x2_diff", 0); x2_diff->set_data_type(b1_x1_diff.data_type()); user_op::TensorDesc* y2_diff = ctx->MutOutputTensorDesc("y2_diff", 0); y2_diff->set_data_type(b1_x1_diff.data_type()); user_op::TensorDesc* w2_diff = ctx->MutOutputTensorDesc("w2_diff", 0); w2_diff->set_data_type(b1_x1_diff.data_type()); user_op::TensorDesc* h2_diff = ctx->MutOutputTensorDesc("h2_diff", 0); h2_diff->set_data_type(b1_x1_diff.data_type()); return Maybe::Ok(); } Maybe FusedGetBounddingBoxesCoordGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& b1_x1_diff = ctx->LogicalTensorDesc4InputArgNameAndIndex("b1_x1_diff", 0); FOR_RANGE(int64_t, i, 0, b1_x1_diff.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("b1_x1_diff", 0), i) .Split(user_op::OpArg("b1_x2_diff", 0), i) .Split(user_op::OpArg("b1_y1_diff", 0), i) .Split(user_op::OpArg("b1_y2_diff", 0), i) .Split(user_op::OpArg("b2_x1_diff", 0), i) .Split(user_op::OpArg("b2_x2_diff", 0), i) .Split(user_op::OpArg("b2_y1_diff", 0), i) .Split(user_op::OpArg("b2_y2_diff", 0), i) .Split(user_op::OpArg("x1_diff", 0), i) .Split(user_op::OpArg("y1_diff", 0), i) .Split(user_op::OpArg("w1_diff", 0), i) .Split(user_op::OpArg("h1_diff", 0), i) .Split(user_op::OpArg("x2_diff", 0), i) .Split(user_op::OpArg("y2_diff", 0), i) .Split(user_op::OpArg("w2_diff", 0), i) 
.Split(user_op::OpArg("h2_diff", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_get_ciou_diagonal_angle_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedGetCiouDiagonalAngleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& w1 = ctx->InputTensorDesc("w1", 0); const user_op::TensorDesc& h1 = ctx->InputTensorDesc("h1", 0); const user_op::TensorDesc& w2 = ctx->InputTensorDesc("w2", 0); const user_op::TensorDesc& h2 = ctx->InputTensorDesc("h2", 0); CHECK_EQ_OR_RETURN(w1.shape(), h1.shape()); CHECK_EQ_OR_RETURN(w1.shape(), w2.shape()); CHECK_EQ_OR_RETURN(w1.shape(), h2.shape()); user_op::TensorDesc* v = ctx->MutOutputTensorDesc("v", 0); v->set_is_dynamic(w1.is_dynamic()); v->set_shape(w1.shape()); return Maybe::Ok(); } Maybe FusedGetCiouDiagonalAngleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetCiouDiagonalAngleOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetCiouDiagonalAngleOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& w1 = ctx->InputTensorDesc("w1", 0); const user_op::TensorDesc& h1 = ctx->InputTensorDesc("h1", 0); const user_op::TensorDesc& w2 = ctx->InputTensorDesc("w2", 0); const user_op::TensorDesc& h2 = ctx->InputTensorDesc("h2", 0); CHECK_EQ_OR_RETURN(w1.data_type(), h1.data_type()); CHECK_EQ_OR_RETURN(w1.data_type(), w2.data_type()); CHECK_EQ_OR_RETURN(w1.data_type(), h2.data_type()); user_op::TensorDesc* v = ctx->MutOutputTensorDesc("v", 0); v->set_data_type(w1.data_type()); return Maybe::Ok(); } Maybe FusedGetCiouDiagonalAngleOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("w1", 0); FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("w1", 0), i) .Split(user_op::OpArg("h1", 0), i) .Split(user_op::OpArg("w2", 0), i) .Split(user_op::OpArg("h2", 0), i) .Split(user_op::OpArg("v", 0), i) .Build(); } return Maybe::Ok(); } Maybe FusedGetCiouDiagonalAngleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& w1 = ctx->InputTensorDesc("w1", 0); const user_op::TensorDesc& h1 = ctx->InputTensorDesc("h1", 0); const user_op::TensorDesc& w2 = ctx->InputTensorDesc("w2", 0); const user_op::TensorDesc& h2 = ctx->InputTensorDesc("h2", 0); const user_op::TensorDesc& v_diff = ctx->InputTensorDesc("v_diff", 0); CHECK_EQ_OR_RETURN(w1.shape(), h1.shape()); CHECK_EQ_OR_RETURN(w1.shape(), w2.shape()); CHECK_EQ_OR_RETURN(w1.shape(), h2.shape()); CHECK_EQ_OR_RETURN(w1.shape(), v_diff.shape()); user_op::TensorDesc* w1_diff = ctx->MutOutputTensorDesc("w1_diff", 0); w1_diff->set_is_dynamic(w1.is_dynamic()); w1_diff->set_shape(w1.shape()); user_op::TensorDesc* h1_diff = 
ctx->MutOutputTensorDesc("h1_diff", 0); h1_diff->set_is_dynamic(w1.is_dynamic()); h1_diff->set_shape(w1.shape()); user_op::TensorDesc* w2_diff = ctx->MutOutputTensorDesc("w2_diff", 0); w2_diff->set_is_dynamic(w1.is_dynamic()); w2_diff->set_shape(w1.shape()); user_op::TensorDesc* h2_diff = ctx->MutOutputTensorDesc("h2_diff", 0); h2_diff->set_is_dynamic(w1.is_dynamic()); h2_diff->set_shape(w1.shape()); return Maybe::Ok(); } Maybe FusedGetCiouDiagonalAngleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetCiouDiagonalAngleGradOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetCiouDiagonalAngleGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& w1 = ctx->InputTensorDesc("w1", 0); const user_op::TensorDesc& h1 = ctx->InputTensorDesc("h1", 0); const user_op::TensorDesc& w2 = ctx->InputTensorDesc("w2", 0); const user_op::TensorDesc& h2 = ctx->InputTensorDesc("h2", 0); const user_op::TensorDesc& v_diff = ctx->InputTensorDesc("v_diff", 0); CHECK_EQ_OR_RETURN(w1.data_type(), h1.data_type()); CHECK_EQ_OR_RETURN(w1.data_type(), w2.data_type()); CHECK_EQ_OR_RETURN(w1.data_type(), h2.data_type()); CHECK_EQ_OR_RETURN(w1.data_type(), v_diff.data_type()); user_op::TensorDesc* w1_diff = ctx->MutOutputTensorDesc("w1_diff", 0); w1_diff->set_is_dynamic(w1.is_dynamic()); w1_diff->set_data_type(w1.data_type()); user_op::TensorDesc* h1_diff = ctx->MutOutputTensorDesc("h1_diff", 0); h1_diff->set_is_dynamic(w1.is_dynamic()); h1_diff->set_data_type(w1.data_type()); user_op::TensorDesc* w2_diff = ctx->MutOutputTensorDesc("w2_diff", 0); w2_diff->set_is_dynamic(w1.is_dynamic()); w2_diff->set_data_type(w1.data_type()); user_op::TensorDesc* h2_diff = ctx->MutOutputTensorDesc("h2_diff", 0); h2_diff->set_is_dynamic(w1.is_dynamic()); h2_diff->set_data_type(w1.data_type()); return Maybe::Ok(); } Maybe FusedGetCiouDiagonalAngleGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& w1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("w1", 0); FOR_RANGE(int64_t, i, 0, w1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("w1", 0), i) .Split(user_op::OpArg("h1", 0), i) .Split(user_op::OpArg("w2", 0), i) .Split(user_op::OpArg("h1", 0), i) .Split(user_op::OpArg("v_diff", 0), i) .Split(user_op::OpArg("w1_diff", 0), i) .Split(user_op::OpArg("h1_diff", 0), i) .Split(user_op::OpArg("w2_diff", 0), i) .Split(user_op::OpArg("h2_diff", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_get_ciou_result_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedGetCiouResultOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& v = ctx->InputTensorDesc("v", 0); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_is_dynamic(v.is_dynamic()); y->set_shape(v.shape()); user_op::TensorDesc* ahpha = ctx->MutOutputTensorDesc("alpha", 0); ahpha->set_is_dynamic(v.is_dynamic()); ahpha->set_shape(v.shape()); return Maybe::Ok(); } Maybe FusedGetCiouResultOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetCiouResultOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetCiouResultOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& v = ctx->InputTensorDesc("v", 0); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_data_type(v.data_type()); user_op::TensorDesc* alpha = ctx->MutOutputTensorDesc("alpha", 0); alpha->set_data_type(v.data_type()); return Maybe::Ok(); } Maybe FusedGetCiouResultOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& v = ctx->LogicalTensorDesc4InputArgNameAndIndex("v", 0); FOR_RANGE(int64_t, i, 0, v.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("v", 0), i) .Split(user_op::OpArg("iou", 0), i) .Split(user_op::OpArg("rho2", 0), i) .Split(user_op::OpArg("c2", 0), i) .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("alpha", 0), i) .Build(); } return Maybe::Ok(); } Maybe FusedGetCiouResultGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* dv = ctx->MutOutputTensorDesc("dv", 0); dv->set_is_dynamic(dy.is_dynamic()); dv->set_shape(dy.shape()); user_op::TensorDesc* diou = ctx->MutOutputTensorDesc("diou", 0); diou->set_is_dynamic(dy.is_dynamic()); diou->set_shape(dy.shape()); user_op::TensorDesc* drho2 = ctx->MutOutputTensorDesc("drho2", 0); drho2->set_is_dynamic(dy.is_dynamic()); drho2->set_shape(dy.shape()); user_op::TensorDesc* dc2 = ctx->MutOutputTensorDesc("dc2", 0); dc2->set_is_dynamic(dy.is_dynamic()); dc2->set_shape(dy.shape()); return Maybe::Ok(); } Maybe FusedGetCiouResultGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetCiouResultGradOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetCiouResultGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* dv = ctx->MutOutputTensorDesc("dv", 0); dv->set_data_type(dy.data_type()); user_op::TensorDesc* diou = ctx->MutOutputTensorDesc("diou", 0); diou->set_data_type(dy.data_type()); user_op::TensorDesc* drho2 = ctx->MutOutputTensorDesc("drho2", 0); drho2->set_data_type(dy.data_type()); user_op::TensorDesc* dc2 = ctx->MutOutputTensorDesc("dc2", 0); dc2->set_data_type(dy.data_type()); return Maybe::Ok(); } Maybe FusedGetCiouResultGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& dy = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); FOR_RANGE(int64_t, i, 0, dy.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("alpha", 0), i) .Split(user_op::OpArg("rho2", 0), i) .Split(user_op::OpArg("c2", 0), i) .Split(user_op::OpArg("dv", 0), i) .Split(user_op::OpArg("diou", 0), i) .Split(user_op::OpArg("drho2", 0), i) .Split(user_op::OpArg("dc2", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: 
oneflow/user/ops/fused_get_convex_diagonal_squared_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedGetConvexDiagonalSquaredOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); user_op::TensorDesc* c2 = ctx->MutOutputTensorDesc("c2", 0); c2->set_is_dynamic(b1_x1.is_dynamic()); c2->set_shape(b1_x1.shape()); return Maybe::Ok(); } Maybe FusedGetConvexDiagonalSquaredOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetConvexDiagonalSquaredOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetConvexDiagonalSquaredOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); user_op::TensorDesc* c2 = ctx->MutOutputTensorDesc("c2", 0); c2->set_data_type(b1_x1.data_type()); return Maybe::Ok(); } Maybe FusedGetConvexDiagonalSquaredOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("b1_x1", 0); FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("b1_x1", 0), i) .Split(user_op::OpArg("b1_x2", 0), i) .Split(user_op::OpArg("b2_x1", 0), i) .Split(user_op::OpArg("b2_x2", 0), i) .Split(user_op::OpArg("b1_y1", 0), i) .Split(user_op::OpArg("b1_y2", 0), i) .Split(user_op::OpArg("b2_y1", 0), i) .Split(user_op::OpArg("b2_y2", 0), i) .Split(user_op::OpArg("c2", 0), i) .Build(); } return Maybe::Ok(); } Maybe FusedGetConvexDiagonalSquaredGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc("b1_x1_diff", 0); b1_x1_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_x1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc("b1_x2_diff", 0); b1_x2_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_x2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc("b2_x1_diff", 0); b2_x1_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_x1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc("b2_x2_diff", 0); b2_x2_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_x2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc("b1_y1_diff", 0); b1_y1_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_y1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc("b1_y2_diff", 0); b1_y2_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_y2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc("b2_y1_diff", 0); b2_y1_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_y1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc("b2_y2_diff", 0); 
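// Each coordinate diff output in this function mirrors b1_x1's shape and dynamic flag;
// the gradient is computed elementwise over the eight box-corner tensors.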
b2_y2_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_y2_diff->set_shape(b1_x1.shape()); return Maybe::Ok(); } Maybe FusedGetConvexDiagonalSquaredGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return FusedGetConvexDiagonalSquaredGradOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetConvexDiagonalSquaredGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc("b1_x1_diff", 0); b1_x1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc("b1_x2_diff", 0); b1_x2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc("b2_x1_diff", 0); b2_x1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc("b2_x2_diff", 0); b2_x2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc("b1_y1_diff", 0); b1_y1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc("b1_y2_diff", 0); b1_y2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc("b2_y1_diff", 0); b2_y1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc("b2_y2_diff", 0); b2_y2_diff->set_data_type(b1_x1.data_type()); return Maybe::Ok(); } Maybe FusedGetConvexDiagonalSquaredGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("b1_x1", 0); FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("c2_diff", 0), i) .Split(user_op::OpArg("b1_x1", 0), i) .Split(user_op::OpArg("b1_x2", 0), i) .Split(user_op::OpArg("b2_x1", 0), i) .Split(user_op::OpArg("b2_x2", 0), i) .Split(user_op::OpArg("b1_y1", 0), i) .Split(user_op::OpArg("b1_y2", 0), i) .Split(user_op::OpArg("b2_y1", 0), i) .Split(user_op::OpArg("b2_y2", 0), i) .Split(user_op::OpArg("b1_x1_diff", 0), i) .Split(user_op::OpArg("b1_x2_diff", 0), i) .Split(user_op::OpArg("b2_x1_diff", 0), i) .Split(user_op::OpArg("b2_x2_diff", 0), i) .Split(user_op::OpArg("b1_y1_diff", 0), i) .Split(user_op::OpArg("b1_y2_diff", 0), i) .Split(user_op::OpArg("b2_y1_diff", 0), i) .Split(user_op::OpArg("b2_y2_diff", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_get_intersection_area_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedGetIntersectionAreaOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc("b1_x2", 0); const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc("b1_y1", 0); const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc("b1_y2", 0); const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc("b2_x1", 0); const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc("b2_x2", 0); const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc("b2_y1", 0); const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc("b2_y2", 0); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_x2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y2.shape()); user_op::TensorDesc* inter = ctx->MutOutputTensorDesc("inter", 0); inter->set_is_dynamic(b1_x1.is_dynamic()); inter->set_shape(b1_x1.shape()); return Maybe::Ok(); } Maybe FusedGetIntersectionAreaOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetIntersectionAreaOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetIntersectionAreaOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc("b1_x2", 0); const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc("b1_y1", 0); const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc("b1_y2", 0); const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc("b2_x1", 0); const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc("b2_x2", 0); const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc("b2_y1", 0); const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc("b2_y2", 0); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_x2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y2.data_type()); user_op::TensorDesc* inter = ctx->MutOutputTensorDesc("inter", 0); inter->set_data_type(b1_x1.data_type()); return Maybe::Ok(); } Maybe FusedGetIntersectionAreaOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("b1_x1", 0); FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("b1_x1", 0), i) .Split(user_op::OpArg("b1_x2", 0), i) .Split(user_op::OpArg("b1_y1", 0), i) .Split(user_op::OpArg("b1_y2", 0), i) .Split(user_op::OpArg("b2_x1", 0), i) .Split(user_op::OpArg("b2_x2", 0), i) .Split(user_op::OpArg("b2_y1", 0), i) .Split(user_op::OpArg("b2_y2", 0), i) .Split(user_op::OpArg("inter", 0), i) .Build(); } return Maybe::Ok(); } Maybe FusedGetIntersectionAreaGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc("b1_x2", 0); const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc("b1_y1", 0); const user_op::TensorDesc& b1_y2 = 
ctx->InputTensorDesc("b1_y2", 0); const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc("b2_x1", 0); const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc("b2_x2", 0); const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc("b2_y1", 0); const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc("b2_y2", 0); const user_op::TensorDesc& inter_diff = ctx->InputTensorDesc("inter_diff", 0); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_x2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y1.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y2.shape()); CHECK_EQ_OR_RETURN(b1_x1.shape(), inter_diff.shape()); user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc("b1_x1_diff", 0); b1_x1_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_x1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc("b1_x2_diff", 0); b1_x2_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_x2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc("b2_x1_diff", 0); b2_x1_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_x1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc("b2_x2_diff", 0); b2_x2_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_x2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc("b1_y1_diff", 0); b1_y1_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_y1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc("b1_y2_diff", 0); b1_y2_diff->set_is_dynamic(b1_x1.is_dynamic()); b1_y2_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc("b2_y1_diff", 0); b2_y1_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_y1_diff->set_shape(b1_x1.shape()); user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc("b2_y2_diff", 0); b2_y2_diff->set_is_dynamic(b1_x1.is_dynamic()); b2_y2_diff->set_shape(b1_x1.shape()); return Maybe::Ok(); } Maybe FusedGetIntersectionAreaGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetIntersectionAreaGradOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetIntersectionAreaGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc("b1_x1", 0); const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc("b1_x2", 0); const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc("b1_y1", 0); const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc("b1_y2", 0); const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc("b2_x1", 0); const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc("b2_x2", 0); const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc("b2_y1", 0); const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc("b2_y2", 0); const user_op::TensorDesc& inter_diff = ctx->InputTensorDesc("inter_diff", 0); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_x2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y1.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y2.data_type()); CHECK_EQ_OR_RETURN(b1_x1.data_type(), inter_diff.data_type()); user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc("b1_x1_diff", 0); 
b1_x1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc("b1_x2_diff", 0); b1_x2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc("b2_x1_diff", 0); b2_x1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc("b2_x2_diff", 0); b2_x2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc("b1_y1_diff", 0); b1_y1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc("b1_y2_diff", 0); b1_y2_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc("b2_y1_diff", 0); b2_y1_diff->set_data_type(b1_x1.data_type()); user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc("b2_y2_diff", 0); b2_y2_diff->set_data_type(b1_x1.data_type()); return Maybe::Ok(); } Maybe FusedGetIntersectionAreaGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("b1_x1", 0); FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("b1_x1", 0), i) .Split(user_op::OpArg("b1_x2", 0), i) .Split(user_op::OpArg("b1_y1", 0), i) .Split(user_op::OpArg("b1_y2", 0), i) .Split(user_op::OpArg("b2_x1", 0), i) .Split(user_op::OpArg("b2_x2", 0), i) .Split(user_op::OpArg("b2_y1", 0), i) .Split(user_op::OpArg("b2_y2", 0), i) .Split(user_op::OpArg("inter_diff", 0), i) .Split(user_op::OpArg("b1_x1_diff", 0), i) .Split(user_op::OpArg("b1_x2_diff", 0), i) .Split(user_op::OpArg("b1_y1_diff", 0), i) .Split(user_op::OpArg("b1_y2_diff", 0), i) .Split(user_op::OpArg("b2_x1_diff", 0), i) .Split(user_op::OpArg("b2_x2_diff", 0), i) .Split(user_op::OpArg("b2_y1_diff", 0), i) .Split(user_op::OpArg("b2_y2_diff", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_get_iou_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe FusedGetIouOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& w1 = ctx->InputTensorDesc("w1", 0); user_op::TensorDesc* iou = ctx->MutOutputTensorDesc("iou", 0); iou->set_is_dynamic(w1.is_dynamic()); iou->set_shape(w1.shape()); return Maybe::Ok(); } Maybe FusedGetIouOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetIouOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetIouOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& w1 = ctx->InputTensorDesc("w1", 0); user_op::TensorDesc* iou = ctx->MutOutputTensorDesc("iou", 0); iou->set_data_type(w1.data_type()); return Maybe::Ok(); } Maybe FusedGetIouOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& w1 = ctx->LogicalTensorDesc4InputArgNameAndIndex("w1", 0); FOR_RANGE(int64_t, i, 0, w1.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("w1", 0), i) .Split(user_op::OpArg("h1", 0), i) .Split(user_op::OpArg("w2", 0), i) .Split(user_op::OpArg("h2", 0), i) .Split(user_op::OpArg("inter", 0), i) .Split(user_op::OpArg("iou", 0), i) .Build(); } return Maybe::Ok(); } Maybe FusedGetIouGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& diou = ctx->InputTensorDesc("diou", 0); user_op::TensorDesc* dw1 = ctx->MutOutputTensorDesc("dw1", 0); dw1->set_is_dynamic(diou.is_dynamic()); dw1->set_shape(diou.shape()); user_op::TensorDesc* dh1 = ctx->MutOutputTensorDesc("dh1", 0); dh1->set_is_dynamic(diou.is_dynamic()); dh1->set_shape(diou.shape()); user_op::TensorDesc* dinter = ctx->MutOutputTensorDesc("dinter", 0); dinter->set_is_dynamic(diou.is_dynamic()); dinter->set_shape(diou.shape()); return Maybe::Ok(); } Maybe FusedGetIouGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return FusedGetIouGradOp::InferLogicalTensorDesc(ctx); } Maybe FusedGetIouGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& diou = ctx->InputTensorDesc("diou", 0); user_op::TensorDesc* dw1 = ctx->MutOutputTensorDesc("dw1", 0); dw1->set_data_type(diou.data_type()); user_op::TensorDesc* dh1 = ctx->MutOutputTensorDesc("dh1", 0); dh1->set_data_type(diou.data_type()); user_op::TensorDesc* dinter = ctx->MutOutputTensorDesc("dinter", 0); dinter->set_data_type(diou.data_type()); return Maybe::Ok(); } Maybe FusedGetIouGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& dy = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); FOR_RANGE(int64_t, i, 0, dy.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("diou", 0), i) .Split(user_op::OpArg("w1", 0), i) .Split(user_op::OpArg("h1", 0), i) .Split(user_op::OpArg("w2", 0), i) .Split(user_op::OpArg("h2", 0), i) .Split(user_op::OpArg("inter", 0), i) .Split(user_op::OpArg("dw1", 0), i) .Split(user_op::OpArg("dh1", 0), i) .Split(user_op::OpArg("dinter", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_glu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ auto FusedGluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> { // check whether the user provides the weight tensor v bool is_split_mode = false; if (ctx->user_op_conf().has_input("v", 0)) { is_split_mode = true; } bool has_b = ctx->user_op_conf().has_input("b", 0); bool has_c = ctx->user_op_conf().has_input("c", 0); // check whether the user provides the bias tensors CHECK_OR_RETURN(!(has_b && (is_split_mode && !has_c))) << "expected existence of c when providing tensors w, v and b"; bool has_bias = false; if (has_b && (is_split_mode && has_c)) { has_bias = true; } else if (has_b && (!is_split_mode)) { has_bias = true; } else { has_bias = false; } // data parallelism for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes() - 1; ++i) { if (is_split_mode && has_bias) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("w", 0)) .Broadcast(user_op::OpArg("b", 0)) .Broadcast(user_op::OpArg("v", 0)) .Broadcast(user_op::OpArg("c", 0)) .Split(ctx->outputs(), i) .Build(); } else if (is_split_mode && !has_bias) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("w", 0)) .Broadcast(user_op::OpArg("v", 0)) .Split(ctx->outputs(), i) .Build(); } else if (!is_split_mode && has_bias) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("w", 0)) .Broadcast(user_op::OpArg("b", 0)) .Split(ctx->outputs(), i) .Build(); } else if (!is_split_mode && !has_bias) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("w", 0)) .Split(ctx->outputs(), i) .Build(); } } // model parallelism if (is_split_mode && has_bias) { ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("w", 0), 0) .Split(user_op::OpArg("b", 0), 0) .Split(user_op::OpArg("v", 0), 0) .Split(user_op::OpArg("c", 0), 0) .Split(ctx->outputs(), ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes() - 1) .Build(); } else if (is_split_mode && !has_bias) { ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("w", 0), 0) .Split(user_op::OpArg("v", 0), 0) .Split(ctx->outputs(), ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes() - 1) .Build(); } return Maybe<void>::Ok(); } /* static */ auto FusedGluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> { // obtain input shape const Shape& x_shape = ctx->InputShape("x", 0); const Shape& w_shape = ctx->InputShape("w", 0); // check whether the user provides the weight tensor v bool is_split_mode = false; if (ctx->has_input("v", 0)) { is_split_mode = true; } bool has_b = ctx->has_input("b", 0); bool has_c = ctx->has_input("c", 0); // check whether the user provides the bias tensors CHECK_OR_RETURN(!(has_b && (is_split_mode && !has_c))) << "expected existence of c when providing tensors w, v and b"; bool has_bias = false; if (has_b && (is_split_mode && has_c)) { has_bias = true; } else if (has_b && (!is_split_mode)) { has_bias = true; } else { has_bias = false; } // check dimensions of x, w and b
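// Worked shape example (editorial illustration): with x of shape (B, M, K) and w of
// shape (N, K), the checks below require w.shape[1] == K and, in split mode, v to match
// the shape of w. The output y keeps the leading dimensions of x and ends with N in
// split mode, or with N / 2 when both GLU halves are packed into a single w.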
CHECK_GT_OR_RETURN(x_shape.NumAxes(), 1) << "number of axes of \'x\' should be greater than 1, but got " << x_shape.NumAxes(); CHECK_EQ_OR_RETURN(w_shape.NumAxes(), 2) << "number of axes of \'w\' should be equal to 2, but got " << w_shape.NumAxes(); if (has_bias) { const Shape& b_shape = ctx->InputShape("b", 0); CHECK_EQ_OR_RETURN(b_shape.NumAxes(), 1) << "number of axes of \'b\' should be equal to 1, but got " << b_shape.NumAxes(); } // check input shapes of w and b size_t x_num_axes = x_shape.NumAxes(); CHECK_EQ_OR_RETURN(w_shape.At(1), x_shape.At(x_num_axes - 1)) << "dimension 1 of \'w\'(" << w_shape.At(1) << ") is not consistent with the last dimension of \'x\'(" << x_shape.At(x_num_axes - 1) << ")"; if (has_bias) { const Shape& b_shape = ctx->InputShape("b", 0); CHECK_EQ_OR_RETURN(b_shape.At(0), w_shape.At(0)) << "dimension 0 of \'b\'(" << b_shape.At(0) << ") is not consistent with dimension 0 of \'w\'(" << w_shape.At(0) << ")"; } if (!is_split_mode) { CHECK_EQ_OR_RETURN(w_shape.At(1) % 2, 0) << "dimension 1 of \'w\' is not divisible by 2"; } // check both dimensions and input shapes of v and c (optional) if (is_split_mode) { const Shape& v_shape = ctx->InputShape("v", 0); CHECK_EQ_OR_RETURN(v_shape.NumAxes(), 2) << "number of axes of \'v\' should be equal to 2, but got " << v_shape.NumAxes(); CHECK_OR_RETURN(v_shape == w_shape) << "the shape of \'v\' is not consistent with \'w\'"; if (has_bias) { const Shape& b_shape = ctx->InputShape("b", 0); const Shape& c_shape = ctx->InputShape("c", 0); CHECK_EQ_OR_RETURN(c_shape.NumAxes(), 1) << "number of axes of \'c\' should be equal to 1, but got " << c_shape.NumAxes(); CHECK_OR_RETURN(c_shape == b_shape) << "the shape of \'c\' is not consistent with \'b\'"; } } // set shape of the output tensor y Shape y_shape = x_shape; // borrow from input shape size_t y_num_axes = x_num_axes; if (is_split_mode) { y_shape.Set(y_num_axes - 1, w_shape.At(0)); } else { y_shape.Set(y_num_axes - 1, w_shape.At(0) / 2); } user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc("y", 0); y_tensor->set_shape(y_shape); // set shape of the output tensors of both matmul_wx and matmul_vx Shape matmul_wx_shape = x_shape; // borrow from input shape matmul_wx_shape.Set(x_num_axes - 1, w_shape.At(0)); user_op::TensorDesc* matmul_wx_tensor = ctx->MutOutputTensorDesc("matmul_wx", 0); matmul_wx_tensor->set_shape(matmul_wx_shape); if (is_split_mode) { user_op::TensorDesc* matmul_vx_tensor = ctx->MutOutputTensorDesc("matmul_vx", 0); matmul_vx_tensor->set_shape(y_shape); } return Maybe<void>::Ok(); } /* static */ auto FusedGluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> { return InferLogicalTensorDesc(ctx); } /* static */ auto FusedGluOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> { // obtain input data types DataType x_dtype = ctx->InputDType("x", 0); // check whether the user provides the weight tensor v bool is_split_mode = false; if (ctx->has_input("v", 0)) { is_split_mode = true; } bool has_b = ctx->has_input("b", 0); bool has_c = ctx->has_input("c", 0); // check whether the user provides the bias tensors CHECK_OR_RETURN(!(has_b && (is_split_mode && !has_c))) << "expected existence of c when providing tensors w, v and b"; bool has_bias = false; if (has_b && (is_split_mode && has_c)) { has_bias = true; } else if (has_b && (!is_split_mode)) { has_bias = true; } else { has_bias = false; } // check types of x, w and b CHECK_EQ_OR_RETURN(ctx->InputDType("w", 0), x_dtype) << "data type of \'w\' is not consistent with \'x\'"; if (has_bias) {
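// Summary of the accepted input combinations (derived from the branches above):
//   w              -> packed mode, no bias
//   w, b           -> packed mode, with bias
//   w, v           -> split mode,  no bias
//   w, v, b, c     -> split mode,  with bias
//   w, v, b (no c) -> rejected by the CHECK above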
CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), x_dtype) << "data type of \'b\' is not consistent with \'x\'"; } // check types of v and c (optional) if (is_split_mode) { CHECK_EQ_OR_RETURN(ctx->InputDType("v", 0), x_dtype) << "data type of \'v\' is not consistent with \'x\'"; if (has_bias) { CHECK_EQ_OR_RETURN(ctx->InputDType("c", 0), x_dtype) << "data type of \'c\' is not consistent with \'x\'"; } } // set output data type ctx->SetOutputDType("y", 0, x_dtype); ctx->SetOutputDType("matmul_wx", 0, x_dtype); if (is_split_mode) { ctx->SetOutputDType("matmul_vx", 0, x_dtype); } return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_glu_without_linear_grad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ auto FusedGluWithoutLinearGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> { // check existence of optional args bool is_split_mode = false; if (ctx->user_op_conf().has_input("matmul_vx", 0)) { is_split_mode = true; } for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0).shape().NumAxes() - 1; ++i) { if (is_split_mode) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("matmul_wx", 0), i) .Split(user_op::OpArg("matmul_vx", 0), i) .Split(ctx->outputs(), i) .Build(); } else { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("matmul_wx", 0), i) .Split(ctx->outputs(), i) .Build(); } } return Maybe<void>::Ok(); } /* static */ auto FusedGluWithoutLinearGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> { // obtain input shape const Shape& dy_shape = ctx->InputShape("dy", 0); const Shape& matmul_wx_shape = ctx->InputShape("matmul_wx", 0); // check existence of optional args bool is_split_mode = false; if (ctx->has_input("matmul_vx", 0)) { is_split_mode = true; } // obtain dimensions of dy and matmul_wx size_t dy_num_axes = dy_shape.NumAxes(); size_t matmul_wx_num_axes = matmul_wx_shape.NumAxes(); // check dimensions of dy and matmul_wx CHECK_GT_OR_RETURN(dy_num_axes, 1) << "number of axes of \'dy\' should be greater than 1, but got " << dy_num_axes; CHECK_GT_OR_RETURN(matmul_wx_num_axes, 1) << "number of axes of \'matmul_wx\' should be greater than 1, but got " << matmul_wx_num_axes; CHECK_EQ_OR_RETURN(dy_num_axes, matmul_wx_num_axes) << "number of axes of \'dy\'(" << dy_num_axes << ") is not consistent with the one of \'matmul_wx\'(" << matmul_wx_num_axes << ")"; // check input shapes of dy and matmul_wx for (uint64_t i = 0; i < dy_num_axes - 1; i++) { size_t dy_size = dy_shape.At(i); size_t matmul_wx_size = matmul_wx_shape.At(i); CHECK_EQ_OR_RETURN(dy_size, matmul_wx_size) << "dimension " << i << " of \'dy\'(" << dy_size << ") and \'matmul_wx\'(" << matmul_wx_size << ") are not consistent"; } if (is_split_mode) {
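// In split mode matmul_wx already has the final output width, so dy and matmul_wx must
// agree on their last dimension; in packed mode matmul_wx carries both GLU halves and is
// expected to be exactly twice as wide as dy (checked in the else branch below).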
CHECK_EQ_OR_RETURN(dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1)) << "the last dimension of \'dy\'(" << dy_shape.At(dy_num_axes - 1) << ") is not consistent with the last dimension of \'matmul_wx\'(" << matmul_wx_shape.At(matmul_wx_num_axes - 1) << ")"; } else { CHECK_EQ_OR_RETURN(2 * dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1)) << "two times the last dimension of \'dy\'(" << 2 * dy_shape.At(dy_num_axes - 1) << ") is not consistent with the last dimension of \'matmul_wx\'(" << matmul_wx_shape.At(matmul_wx_num_axes - 1) << ")"; } // check both dimensions and input shapes of matmul_vx (optional) if (is_split_mode) { // obtain input shape const Shape& matmul_vx_shape = ctx->InputShape("matmul_vx", 0); // check dimensions of matmul_vx size_t matmul_vx_num_axes = matmul_vx_shape.NumAxes(); CHECK_GT_OR_RETURN(matmul_vx_num_axes, 1) << "number of axes of \'matmul_vx\' should be greater than 1, but got " << matmul_vx_num_axes; CHECK_EQ_OR_RETURN(matmul_vx_num_axes, dy_num_axes) << "number of axes of \'dy\'(" << dy_num_axes << ") is not consistent with the one of \'matmul_vx\'(" << matmul_vx_num_axes << ")"; // check input shapes of dy and matmul_vx for (uint64_t i = 0; i < dy_num_axes - 1; i++) { size_t dy_size = dy_shape.At(i); size_t matmul_vx_size = matmul_vx_shape.At(i); CHECK_EQ_OR_RETURN(dy_size, matmul_vx_size) << "dimension " << i << " of \'dy\'(" << dy_size << ") and \'matmul_vx\'(" << matmul_vx_size << ") are not consistent"; } CHECK_EQ_OR_RETURN(matmul_vx_shape.At(matmul_vx_num_axes - 1), dy_shape.At(dy_num_axes - 1)) << "the last dimension of \'dy\'(" << dy_shape.At(dy_num_axes - 1) << ") is not consistent with the last dimension of \'matmul_vx\'(" << matmul_vx_shape.At(matmul_vx_num_axes - 1) << ")"; } // set shape of the output tensor d_matmul_wx Shape d_matmul_wx_shape = matmul_wx_shape; // borrow from input shape user_op::TensorDesc* d_matmul_wx_tensor = ctx->MutOutputTensorDesc("d_matmul_wx", 0); d_matmul_wx_tensor->set_shape(d_matmul_wx_shape); // set shape of the output tensor d_matmul_vx (optional) if (is_split_mode) { const Shape& matmul_vx_shape = ctx->InputShape("matmul_vx", 0); Shape d_matmul_vx_shape = matmul_vx_shape; // borrow from input shape user_op::TensorDesc* d_matmul_vx_tensor = ctx->MutOutputTensorDesc("d_matmul_vx", 0); d_matmul_vx_tensor->set_shape(d_matmul_vx_shape); } return Maybe<void>::Ok(); } /* static */ auto FusedGluWithoutLinearGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> { return InferLogicalTensorDesc(ctx); } /* static */ auto FusedGluWithoutLinearGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> { // obtain input data types DataType dy_dtype = ctx->InputDType("dy", 0); // check types of matmul_wx CHECK_EQ_OR_RETURN(ctx->InputDType("matmul_wx", 0), dy_dtype) << "data type of \'matmul_wx\' is not consistent with \'dy\'"; bool is_split_mode = ctx->has_input("matmul_vx", 0); // check types of matmul_vx (optional) if (is_split_mode) { CHECK_EQ_OR_RETURN(ctx->InputDType("matmul_vx", 0), dy_dtype) << "data type of \'matmul_vx\' is not consistent with \'dy\'"; } // set output data type ctx->SetOutputDType("d_matmul_wx", 0, dy_dtype); if (is_split_mode) { ctx->SetOutputDType("d_matmul_vx", 0, dy_dtype); } return Maybe<void>::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_gru_cell_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe FusedGruCellOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& hx_shape = ctx->InputShape("hx", 0); ctx->SetOutputShape("hy", 0, hx_shape); ctx->SetOutputShape("workspace", 0, Shape({hx_shape.At(0), hx_shape.At(1) * 5})); return Maybe::Ok(); } /*static*/ Maybe FusedGruCellOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedGruCellOp::GetSbp(user_op::SbpContext* ctx) { // input_gates shape: [batch_size, hidden_size * 3] // hidden_gates shape: [batch_size, hidden_size * 3] // hx shape: [batch_size, hidden_size] // input_bias shape: [hidden_size * 3] // hidden_bias shape: [hidden_size * 3] // hy shape: [batch_size, hidden_size] // workspace shape: [batch_size, hidden_size * 5] std::vector broadcast_args; if (ctx->user_op_conf().has_input("input_bias", 0)) { broadcast_args.emplace_back("input_bias", 0); } if (ctx->user_op_conf().has_input("hidden_bias", 0)) { broadcast_args.emplace_back("hidden_bias", 0); } std::vector split_args; split_args.emplace_back("input_gates", 0); split_args.emplace_back("hidden_gates", 0); split_args.emplace_back("hx", 0); split_args.emplace_back("hy", 0); split_args.emplace_back("workspace", 0); ctx->NewBuilder().Split(split_args, 0).Broadcast(broadcast_args).Build(); return Maybe::Ok(); } /* static */ Maybe FusedGruCellOp::InferDataType(user_op::InferContext* ctx) { DataType in_types = ctx->InputDType("hx", 0); ctx->SetOutputDType("hy", 0, in_types); ctx->SetOutputDType("workspace", 0, in_types); return Maybe::Ok(); } /* static */ Maybe FusedGruCellGradOp ::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& grad_hy_shape = ctx->InputShape("grad_hy", 0); DimVector dim_vec({grad_hy_shape.At(0), grad_hy_shape.At(1) * 3}); ctx->SetOutputShape("grad_input_gates", 0, Shape(dim_vec)); ctx->SetOutputShape("grad_hidden_gates", 0, Shape(dim_vec)); if (ctx->has_output("grad_hx", 0)) { ctx->SetOutputShape("grad_hx", 0, grad_hy_shape); } if (ctx->has_output("grad_input_bias", 0) && ctx->has_output("grad_hidden_bias", 0)) { ctx->SetOutputShape("grad_input_bias", 0, Shape({grad_hy_shape.At(1) * 3})); ctx->SetOutputShape("grad_hidden_bias", 0, Shape({grad_hy_shape.At(1) * 3})); } return Maybe::Ok(); } /*static*/ Maybe FusedGruCellGradOp ::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedGruCellGradOp ::GetSbp(user_op::SbpContext* ctx) { // grad_hy shape: [batch_size, hidden_size] // workspace shape: [batch_size, hidden_size * 5] // grad_input_gates shape: [batch_size, hidden_size * 3] // grad_hidden_gates shape: [batch_size, hidden_size * 3] // grad_hx shape: [batch_size, hidden_size] // grad_input_bias shape: [hidden_size * 3] // grad_hidden_bias shape: [hidden_size * 3] std::vector partial_sum_args; if (ctx->user_op_conf().has_output("grad_input_bias", 0)) { 
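// The bias gradients are reduced over the batch dimension, so while every other tensor
// is split along axis 0 (the batch axis) below, the bias gradients are produced as
// PartialSum.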
partial_sum_args.emplace_back("grad_input_bias", 0); } if (ctx->user_op_conf().has_output("grad_hidden_bias", 0)) { partial_sum_args.emplace_back("grad_hidden_bias", 0); } std::vector split_args; split_args.emplace_back("grad_hy", 0); split_args.emplace_back("workspace", 0); split_args.emplace_back("grad_input_gates", 0); split_args.emplace_back("grad_hidden_gates", 0); if (ctx->user_op_conf().has_output("grad_hx", 0)) { split_args.emplace_back("grad_hx", 0); } ctx->NewBuilder().Split(split_args, 0).PartialSum(partial_sum_args).Build(); return Maybe::Ok(); } /* static */ Maybe FusedGruCellGradOp ::InferDataType(user_op::InferContext* ctx) { DataType in_types = ctx->InputDType("grad_hy", 0); ctx->SetOutputDType("grad_input_gates", 0, in_types); ctx->SetOutputDType("grad_hidden_gates", 0, in_types); if (ctx->has_output("grad_hx", 0)) { ctx->SetOutputDType("grad_hx", 0, in_types); } if (ctx->has_output("grad_input_bias", 0)) { ctx->SetOutputDType("grad_input_bias", 0, in_types); } if (ctx->has_output("grad_hidden_bias", 0)) { ctx->SetOutputDType("grad_hidden_bias", 0, in_types); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_linear_with_groupwise_quantized_weight_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4FusedMatmulBias(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); CHECK_GE_OR_RETURN(x_desc.shape().NumAxes(), 2); const int64_t k = x_desc.shape().At(x_desc.shape().NumAxes() - 1); const user_op::TensorDesc& w_desc = ctx->InputTensorDesc("w", 0); CHECK_EQ_OR_RETURN(w_desc.shape().NumAxes(), 2); const int64_t n = w_desc.shape().At(0); const int32_t num_bits = ctx->Attr("num_bits"); if (num_bits == 8) { CHECK_EQ_OR_RETURN(w_desc.shape().At(1), k); } else if (num_bits == 4) { CHECK_EQ_OR_RETURN(w_desc.shape().At(1) * 2, k); } else { UNIMPLEMENTED_THEN_RETURN(); } const int64_t group_dim = ctx->Attr("group_dim"); CHECK_OR_RETURN(group_dim == 0 || group_dim == 1); const int64_t group_dim_size = group_dim == 0 ? 
n : k; const int64_t group_size = ctx->Attr("group_size"); CHECK_GT_OR_RETURN(group_size, 1); CHECK_LE_OR_RETURN(group_size, group_dim_size); CHECK_EQ_OR_RETURN(group_dim_size % group_size, 0); const int64_t num_groups = group_dim_size / group_size; const user_op::TensorDesc& w_scale_desc = ctx->InputTensorDesc("w_scale", 0); CHECK_EQ_OR_RETURN(w_scale_desc.shape().NumAxes(), 2); if (group_dim == 0) { CHECK_EQ_OR_RETURN(w_scale_desc.shape().At(0), num_groups); CHECK_EQ_OR_RETURN(w_scale_desc.shape().At(1), k); } else if (group_dim == 1) { CHECK_EQ_OR_RETURN(w_scale_desc.shape().At(0), n); CHECK_EQ_OR_RETURN(w_scale_desc.shape().At(1), num_groups); } else { UNIMPLEMENTED_THEN_RETURN(); } Shape out_shape = x_desc.shape(); out_shape[x_desc.shape().NumAxes() - 1] = n; if (ctx->has_input("b", 0)) { const user_op::TensorDesc& b_desc = ctx->InputTensorDesc("b", 0); CHECK_EQ_OR_RETURN(b_desc.shape().NumAxes(), 1); CHECK_EQ_OR_RETURN(b_desc.shape().At(0), n); } if (ctx->has_input("w_zero", 0)) { const user_op::TensorDesc& w_zero_desc = ctx->InputTensorDesc("w_zero", 0); CHECK_OR_RETURN(w_zero_desc.shape() == w_scale_desc.shape()); } ctx->SetOutputShape("out", 0, out_shape); return Maybe::Ok(); } Maybe InferDataType4MatmulBias(user_op::InferContext* ctx) { const DataType data_type = ctx->InputDType("x", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("w_scale", 0), data_type); if (ctx->has_input("w_zero", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("w_zero", 0), data_type); } if (ctx->has_input("b", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), data_type); } if (ctx->Attr("symmetric")) { CHECK_OR_RETURN(ctx->InputDType("w", 0) == DataType::kUInt8 || ctx->InputDType("w", 0) == DataType::kInt8); } else { CHECK_EQ_OR_RETURN(ctx->InputDType("w", 0), DataType::kUInt8); } ctx->SetOutputDType("out", 0, data_type); return Maybe::Ok(); } } // namespace /* static */ Maybe FusedLinearWithGroupwiseQuantizedWeightOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4FusedMatmulBias(ctx); } /*static*/ Maybe FusedLinearWithGroupwiseQuantizedWeightOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedLinearWithGroupwiseQuantizedWeightOp::GetSbp( user_op::SbpContext* ctx) { // (b, m, k) * (n, k) const auto& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const int64_t x_num_axes = x_shape.NumAxes(); const int64_t out_num_axes = x_num_axes; const int32_t k_x_axis = x_num_axes - 1; std::vector bias_args; if (ctx->user_op_conf().has_input("b", 0)) { bias_args.emplace_back("b", 0); } std::vector scale_args; scale_args.emplace_back("w_scale", 0); if (ctx->user_op_conf().has_input("w_zero", 0)) { scale_args.emplace_back("w_zero", 0); } for (int i = 0; i < x_shape.NumAxes() - 1; i++) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("w", 0)) .Broadcast(scale_args) .Broadcast(bias_args) .Split(user_op::OpArg("out", 0), i) .Build(); } const int64_t group_dim = ctx->user_op_conf().attr("group_dim"); const int64_t group_size = ctx->user_op_conf().attr("group_size"); CHECK_OR_RETURN(group_dim == 0 || group_dim == 1); const auto& x_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const auto& w_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("w", 0); CHECK_GE_OR_RETURN(x_desc.shape().NumAxes(), 2); CHECK_EQ_OR_RETURN(w_desc.shape().NumAxes(), 2); const int64_t k = x_desc.shape().At(x_desc.shape().NumAxes() - 1); const int64_t n = w_desc.shape().At(0); const int64_t group_dim_size 
= group_dim == 0 ? n : k; CHECK_EQ_OR_RETURN(group_dim_size % group_size, 0); const int64_t num_groups = group_dim_size / group_size; // B x S(n_axis) -> S(n_axis) if (group_dim == 1 || num_groups % ctx->parallel_num() == 0) { ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("w", 0), 0) .Split(scale_args, 0) .Split(bias_args, 0) .Split(user_op::OpArg("out", 0), out_num_axes - 1) .Build(); } // S(x_k_axis) x S(w_k_axis) -> P if (group_dim == 0 || num_groups % ctx->parallel_num() == 0) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), k_x_axis) .Split(user_op::OpArg("w", 0), 1) .Split(scale_args, 1) .PartialSum(bias_args) .PartialSum(user_op::OpArg("out", 0)) .Build(); } // P x B -> P ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("w", 0)) .Broadcast(scale_args) .PartialSum(bias_args) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe FusedLinearWithGroupwiseQuantizedWeightOp::InferDataType( user_op::InferContext* ctx) { return InferDataType4MatmulBias(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_lstm_cell_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe FusedLstmCellOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& cx_shape = ctx->InputShape("cx", 0); ctx->SetOutputShape("hy", 0, cx_shape); ctx->SetOutputShape("cy", 0, cx_shape); ctx->SetOutputShape("workspace", 0, ctx->InputShape("input_gates", 0)); return Maybe::Ok(); } /*static*/ Maybe FusedLstmCellOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedLstmCellOp::GetSbp(user_op::SbpContext* ctx) { // input_gates shape: [batch_size, hidden_size * 4] // hidden_gates shape: [batch_size, hidden_size * 4] // cx shape: [batch_size, hidden_size] // input_bias shape: [hidden_size * 4] // hidden_bias shape: [hidden_size * 4] // hy shape: [batch_size, hidden_size] // cy shape: [batch_size, hidden_size] // workspace shape: [batch_size, hidden_size * 4] std::vector broadcast_args; if (ctx->user_op_conf().has_input("input_bias", 0)) { broadcast_args.emplace_back("input_bias", 0); } if (ctx->user_op_conf().has_input("hidden_bias", 0)) { broadcast_args.emplace_back("hidden_bias", 0); } std::vector split_args; split_args.emplace_back("input_gates", 0); split_args.emplace_back("hidden_gates", 0); split_args.emplace_back("cx", 0); split_args.emplace_back("hy", 0); split_args.emplace_back("cy", 0); split_args.emplace_back("workspace", 0); ctx->NewBuilder().Split(split_args, 0).Broadcast(broadcast_args).Build(); return Maybe::Ok(); } /* static */ Maybe FusedLstmCellOp::InferDataType(user_op::InferContext* ctx) { DataType in_types = ctx->InputDType("cx", 0); ctx->SetOutputDType("hy", 0, in_types); ctx->SetOutputDType("cy", 0, in_types); ctx->SetOutputDType("workspace", 0, in_types); return Maybe::Ok(); } /* static */ Maybe FusedLstmCellGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("grad_gates", 0, ctx->InputShape("workspace", 0)); if (ctx->has_output("grad_cx", 0)) { ctx->SetOutputShape("grad_cx", 0, ctx->InputShape("cx", 0)); } if (ctx->has_output("grad_bias", 0)) { ctx->SetOutputShape("grad_bias", 0, Shape({ctx->InputShape("workspace", 0).At(1)})); } return Maybe::Ok(); } /*static*/ Maybe FusedLstmCellGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedLstmCellGradOp::GetSbp(user_op::SbpContext* ctx) { // grad_hy shape: [batch_size, hidden_size] // grad_cy shape: [batch_size, hidden_size] // cx shape: [batch_size, hidden_size] // cy shape: [batch_size, hidden_size] // workspace shape: [batch_size, hidden_size * 4] // grad_gates shape: [batch_size, hidden_size * 4] // grad_cx shape: [batch_size, hidden_size] // grad_bias shape: [hidden_size * 4] std::vector partial_sum_args; if (ctx->user_op_conf().has_output("grad_bias", 0)) { partial_sum_args.emplace_back("grad_bias", 0); } std::vector split_args; split_args.emplace_back("grad_hy", 0); split_args.emplace_back("grad_cy", 0); split_args.emplace_back("cx", 0); split_args.emplace_back("cy", 0); split_args.emplace_back("workspace", 0); split_args.emplace_back("grad_gates", 0); if (ctx->user_op_conf().has_output("grad_cx", 0)) { split_args.emplace_back("grad_cx", 0); } ctx->NewBuilder().Split(split_args, 0).PartialSum(partial_sum_args).Build(); return Maybe::Ok(); } /* static */ Maybe FusedLstmCellGradOp::InferDataType(user_op::InferContext* ctx) { DataType in_types = ctx->InputDType("grad_hy", 0); ctx->SetOutputDType("grad_gates", 0, 
in_types); if (ctx->has_output("grad_cx", 0)) { ctx->SetOutputDType("grad_cx", 0, in_types); } if (ctx->has_output("grad_bias", 0)) { ctx->SetOutputDType("grad_bias", 0, in_types); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { constexpr int32_t kAuxReluLdAlignRequirement = 128; long AlignReluAuxLd(long aux_ld) { /* ReLu bit-mask matrix leading dimension in elements. Must be divisible by 128 and be no less than the number of rows in the output matrix. */ long old_aux_ld = aux_ld; return ((old_aux_ld + kAuxReluLdAlignRequirement - 1) / kAuxReluLdAlignRequirement) * kAuxReluLdAlignRequirement; } Maybe InferTensorDesc4FusedMatmul(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); int32_t weight_size = ctx->input_size("weights"); int32_t bias_size = ctx->input_size("biases"); CHECK_EQ_OR_RETURN(weight_size, bias_size) << "Weight num should be equal to bias num. "; /* A: (m, k) B: (n, k) need transpose C: (m, n) */ int64_t m = 0, n = 0, k = 0, cublas_aux_ld = 0; m = x_desc.shape().At(0); k = x_desc.shape().At(1); for (int32_t idx = 0; idx < weight_size; idx++) { // skip first input weight. const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weights", idx); const user_op::TensorDesc& bias_desc = ctx->InputTensorDesc("biases", idx); CHECK_EQ_OR_RETURN(weight_desc.shape().NumAxes(), 2) << "Weight's ndim should be equal to 2. "; CHECK_EQ_OR_RETURN(bias_desc.shape().NumAxes(), 1) << "Bias's ndim should be equal to 1. "; n = weight_desc.shape().At(0); CHECK_EQ_OR_RETURN(bias_desc.shape().At(0), n) << "Bias shape should be equal to N. Assume (M, K) matmul (N, K, transpose_b=True) " "bias_add (N, ). "; CHECK_EQ_OR_RETURN(weight_desc.shape().At(1), k) << "Weight shape should be equal to K. Assume (M, K) matmul (N, K, transpose_b=True) " "bias_add (N, ). "; cublas_aux_ld = n; // Set Middle result shape. long cublas_aligned_aux_ld = AlignReluAuxLd(cublas_aux_ld); int64_t aux_size = cublas_aligned_aux_ld / 32; // Cause we use int32_t as dtype ctx->SetOutputShape("cublas_aux", idx, Shape({m, aux_size})); ctx->SetOutputShape("hidden", idx, Shape({m, n})); // Set for next layer. k = n; } ctx->SetOutputShape("out", 0, Shape({m, n})); return Maybe::Ok(); } Maybe InferDataType4Matmul(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("x", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(in_desc.data_type()) << ", but got " << DataType_Name(first_in_desc.data_type()); } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(first_in_desc.data_type()); for (int32_t i = 0; i < ctx->output_size("hidden"); i++) { user_op::TensorDesc* hidden_desc = ctx->MutOutputTensorDesc("hidden", i); hidden_desc->set_data_type(first_in_desc.data_type()); } for (int32_t i = 0; i < ctx->output_size("cublas_aux"); i++) { user_op::TensorDesc* aux_desc = ctx->MutOutputTensorDesc("cublas_aux", i); aux_desc->set_data_type(DataType::kInt32); } return Maybe::Ok(); } } // namespace /* static */ Maybe FusedMatmulBiasAddReluDropoutOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4FusedMatmul(ctx); } /*static*/ Maybe FusedMatmulBiasAddReluDropoutOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedMatmulBiasAddReluDropoutOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0); for (int i = 0; i < ctx->user_op_conf().input_size("weights"); ++i) { builder.Broadcast(user_op::OpArg("weights", i)); } for (int i = 0; i < ctx->user_op_conf().input_size("biases"); ++i) { builder.Broadcast(user_op::OpArg("biases", i)); } for (int i = 0; i < ctx->user_op_conf().output_size("cublas_aux"); ++i) { builder.Split(user_op::OpArg("cublas_aux", i), 0); } for (int i = 0; i < ctx->user_op_conf().output_size("hidden"); ++i) { builder.Split(user_op::OpArg("hidden", i), 0); } builder.Split(user_op::OpArg("out", 0), 0); builder.Build(); return Maybe::Ok(); } /* static */ Maybe FusedMatmulBiasAddReluDropoutOp::InferDataType( user_op::InferContext* ctx) { return InferDataType4Matmul(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_matmul_bias_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4FusedMatmulBias(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); /* x: (m_i, ... 
m_1, k) weight: (n, k) need transpose bias: (n) */ CHECK_GE_OR_RETURN(x_desc.shape().NumAxes(), 2); const int64_t k = x_desc.shape().At(x_desc.shape().NumAxes() - 1); const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& bias_desc = ctx->InputTensorDesc("bias", 0); CHECK_EQ_OR_RETURN(weight_desc.shape().NumAxes(), 2); CHECK_EQ_OR_RETURN(bias_desc.shape().NumAxes(), 1); const int64_t n = weight_desc.shape().At(0); CHECK_EQ_OR_RETURN(bias_desc.shape().At(0), n); CHECK_EQ_OR_RETURN(weight_desc.shape().At(1), k); Shape out_shape = x_desc.shape(); out_shape[x_desc.shape().NumAxes() - 1] = n; ctx->SetOutputShape("out", 0, out_shape); if (ctx->has_input("_add_to_output", 0)) { const user_op::TensorDesc& _add_to_output_desc = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(_add_to_output_desc.shape(), out_shape); } return Maybe::Ok(); } Maybe InferDataType4MatmulBias(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("x", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(first_in_desc.data_type()) << ", but got " << DataType_Name(in_desc.data_type()); } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(first_in_desc.data_type()); if (ctx->has_input("_add_to_output", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("_add_to_output", 0), out_desc->data_type()) << "InferDataType Failed. _add_to_output Expected " << DataType_Name(out_desc->data_type()) << ", but got " << DataType_Name(ctx->InputDType("_add_to_output", 0)); } return Maybe::Ok(); } } // namespace /* static */ Maybe FusedMatmulBiasOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4FusedMatmulBias(ctx); } /*static*/ Maybe FusedMatmulBiasOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedMatmulBiasOp::GetSbp(user_op::SbpContext* ctx) { // (b, m, k) * (n, k) const auto& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const int64_t x_num_axes = x_shape.NumAxes(); const int64_t out_num_axes = x_num_axes; const int32_t k_x_axis = x_num_axes - 1; std::vector out_and_add_to_output_args; out_and_add_to_output_args.emplace_back("out", 0); if (ctx->user_op_conf().has_input("_add_to_output", 0)) { out_and_add_to_output_args.emplace_back("_add_to_output", 0); } for (int i = 0; i < x_shape.NumAxes() - 1; i++) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("weight", 0)) .Broadcast(user_op::OpArg("bias", 0)) .Split(out_and_add_to_output_args, i) .Build(); } // B x S(n_axis) -> S(n_axis) ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("weight", 0), 0) .Split(user_op::OpArg("bias", 0), 0) .Split(out_and_add_to_output_args, out_num_axes - 1) .Build(); // S(x_k_axis) x S(w_k_axis) -> P ctx->NewBuilder() .Split(user_op::OpArg("x", 0), k_x_axis) .Split(user_op::OpArg("weight", 0), 1) .PartialSum(user_op::OpArg("bias", 0)) .PartialSum(out_and_add_to_output_args) .Build(); // P x B -> P ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("weight", 0)) .PartialSum(user_op::OpArg("bias", 0)) .PartialSum(out_and_add_to_output_args) .Build(); // B x P -> P ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 
0)) .PartialSum(user_op::OpArg("weight", 0)) .PartialSum(user_op::OpArg("bias", 0)) .PartialSum(out_and_add_to_output_args) .Build(); return Maybe::Ok(); } /* static */ Maybe FusedMatmulBiasOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4MatmulBias(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_relu_dropout_grad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4FusedReluDropoutGrad(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("dy", 0)); return Maybe::Ok(); } Maybe InferDataType4FusedReluDropoutGrad(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace /* static */ Maybe FusedReluDropoutGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4FusedReluDropoutGrad(ctx); } /*static*/ Maybe FusedReluDropoutGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedReluDropoutGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("mask", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe FusedReluDropoutGradOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4FusedReluDropoutGrad(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_scale_mask_bias_softmax_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/framework/user_op_conf.h" namespace oneflow { /*static*/ auto FusedScaleMaskBiasSoftmaxOp::InferDataType(user_op::InferContext* ctx) -> Maybe { DataType query_type = ctx->InputDType("x", 0); DataType mask_bias_type = ctx->InputDType("mask", 0); CHECK_EQ_OR_RETURN(mask_bias_type, query_type); if (ctx->has_input("bias", 0)) { DataType bias_type = ctx->InputDType("bias", 0); CHECK_EQ_OR_RETURN(bias_type, query_type); } ctx->SetOutputDType("out", 0, query_type); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskBiasSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const float scale = ctx->Attr("scale"); CHECK_LE_OR_RETURN(scale, 1.); const Shape& x_shape = ctx->InputShape("x", 0); const Shape& mask_shape = ctx->InputShape("mask", 0); CHECK_OR_RETURN(x_shape[-1] == mask_shape[-1] && x_shape[0] == mask_shape[0]); if (ctx->has_input("bias", 0)) { const Shape& bias_shape = ctx->InputShape("bias", 0); CHECK_OR_RETURN(mask_shape[-1] == bias_shape[-1]); CHECK_OR_RETURN(mask_shape[0] == bias_shape[0] || bias_shape[0] == 1); for (int i = 1; i < x_shape.NumAxes() - 1; i++) { CHECK_OR_RETURN((mask_shape[i] == 1 || bias_shape[i] == 1) && mask_shape[i] * bias_shape[i] == x_shape[i]); } } else { auto axes = x_shape.NumAxes(); bool reach1 = false; for (int i = 0; i < axes - 1; i++) { CHECK_OR_RETURN((mask_shape[i] == x_shape[i] && !reach1) || (1 == mask_shape[i])); reach1 = (1 == mask_shape[i]); } } ctx->SetOutputShape("out", 0, x_shape); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskBiasSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferLogicalTensorDesc(ctx); } /*static*/ auto FusedScaleMaskBiasSoftmaxOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { if (ctx->Attr("inplace") == false) ctx->NewBuilder() .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("mask", 0), 0) .Broadcast(user_op::OpArg("bias", 0)) .Split(user_op::OpArg("out", 0), 0) .Build(); else ctx->NewBuilder() .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("mask", 0), 0) .Broadcast(user_op::OpArg("bias", 0)) .Build(); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskBiasSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { DataType y_type = ctx->InputDType("y", 0); DataType dy_type = ctx->InputDType("dy", 0); CHECK_EQ_OR_RETURN(y_type, dy_type); ctx->SetOutputDType("dx", 0, y_type); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskBiasSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_EQ_OR_RETURN(y_shape, dy_shape); ctx->SetOutputShape("dx", 0, y_shape); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskBiasSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferLogicalTensorDesc(ctx); } /*static*/ auto FusedScaleMaskBiasSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), 0) .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); const auto x_shape = x_desc.shape(); const auto mask_shape = mask_desc.shape(); CHECK_EQ_OR_RETURN(x_desc.shape().At(x_shape.NumAxes() - 1), mask_desc.shape().At(mask_shape.NumAxes() - 1)) << " last dim of x and mask is not equal."; ctx->SetOutputShape("y", 0, x_desc.shape()); ctx->SetOutputIsDynamic("y", 0, x_desc.is_dynamic()); ctx->SetOutputShape("softmax_y", 0, x_desc.shape()); ctx->SetOutputIsDynamic("softmax_y", 0, x_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return FusedScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << "InferDataType Failed. 
Expected " << DataType_Name(DataType::kBool) << ", but got " << DataType_Name(mask_desc.data_type()); ctx->SetOutputDType("y", 0, x_desc.data_type()); ctx->SetOutputDType("softmax_y", 0, x_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxDropoutOp::ModifyInputArg( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) -> Maybe { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); user_op::InputArgModifier* dropout_mask_modifier = GetInputArgModifierFn("dropout_mask", 0); CHECK_OR_RETURN(mask_modifier != nullptr) << " cannot find mask input."; CHECK_OR_RETURN(dropout_mask_modifier != nullptr) << " cannot find dropout mask input."; mask_modifier->set_requires_grad(false); dropout_mask_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxDropoutOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2) << " x num axes at least 2."; const user_op::TensorDesc& mask_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("mask", 0); CHECK_EQ_OR_RETURN(x_tensor.shape().NumAxes(), mask_tensor.shape().NumAxes()) << " x num axes must equal with mask."; FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { // NOTE(chengcheng): mask support broadcast, when dim value = 1, sbp = broadcast if (mask_tensor.shape().At(axis) == 1) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), axis) .Broadcast(user_op::OpArg("mask", 0)) .Split(user_op::OpArg("dropout_mask", 0), axis) .Split(user_op::OpArg("y", 0), axis) .Split(user_op::OpArg("softmax_y", 0), axis) .Build(); } else { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), axis) .Split(user_op::OpArg("mask", 0), axis) .Split(user_op::OpArg("dropout_mask", 0), axis) .Split(user_op::OpArg("y", 0), axis) .Split(user_op::OpArg("softmax_y", 0), axis) .Build(); } } return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); CHECK_EQ_OR_RETURN(dy_desc.shape(), softmax_y_desc.shape()) << " dy and y shape must equal."; CHECK_EQ_OR_RETURN(dy_desc.shape().At(dy_desc.shape().NumAxes() - 1), mask_desc.shape().At(mask_desc.shape().NumAxes() - 1)) << " last dim of y and mask is not equal."; user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_shape(dy_desc.shape()); dx_desc->set_is_dynamic(dy_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) -> Maybe { return FusedScaleMaskSoftmaxDropoutGradOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); CHECK_EQ_OR_RETURN(dy_desc.data_type(), softmax_y_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(softmax_y_desc.data_type()) << ", but got " << DataType_Name(dy_desc.data_type()); CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << "InferDataType Failed. Expected " << DataType_Name(DataType::kBool) << ", but got " << DataType_Name(mask_desc.data_type()); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_data_type(dy_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2) << " dy num axes at least 2."; const user_op::TensorDesc& mask_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("mask", 0); CHECK_EQ_OR_RETURN(dy_tensor.shape().NumAxes(), mask_tensor.shape().NumAxes()) << " dy num axes must equal with mask."; FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { if (mask_tensor.shape().At(axis) == 1) { ctx->NewBuilder() .Split(user_op::OpArg("softmax_y", 0), axis) .Split(user_op::OpArg("dy", 0), axis) .Broadcast(user_op::OpArg("mask", 0)) .Split(user_op::OpArg("dropout_mask", 0), axis) .Split(user_op::OpArg("dx", 0), axis) .Build(); } else { ctx->NewBuilder() .Split(user_op::OpArg("softmax_y", 0), axis) .Split(user_op::OpArg("dy", 0), axis) .Split(user_op::OpArg("mask", 0), axis) .Split(user_op::OpArg("dropout_mask", 0), axis) .Split(user_op::OpArg("dx", 0), axis) .Build(); } } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_scale_mask_softmax_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto FusedScaleMaskSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); const auto x_shape = x_desc.shape(); const auto mask_shape = mask_desc.shape(); CHECK_EQ_OR_RETURN(x_desc.shape().At(x_shape.NumAxes() - 1), mask_desc.shape().At(mask_shape.NumAxes() - 1)) << " last dim of x and mask is not equal."; ctx->SetOutputShape("y", 0, x_desc.shape()); ctx->SetOutputIsDynamic("y", 0, x_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return FusedScaleMaskSoftmaxOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedScaleMaskSoftmaxOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << " mask dtype only support bool."; ctx->SetOutputDType("y", 0, x_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxOp::ModifyInputArg( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) -> Maybe { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); CHECK_OR_RETURN(mask_modifier != nullptr) << " cannot find mask input."; mask_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2) << " x num axes at least 2."; const user_op::TensorDesc& mask_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("mask", 0); CHECK_EQ_OR_RETURN(x_tensor.shape().NumAxes(), mask_tensor.shape().NumAxes()) << " x num axes must equal with mask."; FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { if (mask_tensor.shape().At(axis) == 1) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), axis) .Broadcast(user_op::OpArg("mask", 0)) .Split(user_op::OpArg("y", 0), axis) .Build(); } else { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), axis) .Split(user_op::OpArg("mask", 0), axis) .Split(user_op::OpArg("y", 0), axis) .Build(); } } return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); CHECK_EQ_OR_RETURN(dy_desc.shape(), y_desc.shape()) << " dy and y shape must equal."; CHECK_EQ_OR_RETURN(y_desc.shape().At(y_desc.shape().NumAxes() - 1), mask_desc.shape().At(mask_desc.shape().NumAxes() - 1)) << " last dim of y and mask is not equal."; user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_shape(dy_desc.shape()); dx_desc->set_is_dynamic(dy_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return FusedScaleMaskSoftmaxGradOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedScaleMaskSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe 
{ const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); CHECK_EQ_OR_RETURN(dy_desc.data_type(), y_desc.data_type()) << " dy and y dtype must equal"; CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << " mask dtype only support bool."; user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_data_type(dy_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2) << " dy num axes at least 2."; const user_op::TensorDesc& mask_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("mask", 0); CHECK_EQ_OR_RETURN(dy_tensor.shape().NumAxes(), mask_tensor.shape().NumAxes()) << " dy num axes must equal with mask."; FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { if (mask_tensor.shape().At(axis) == 1) { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), axis) .Split(user_op::OpArg("dy", 0), axis) .Broadcast(user_op::OpArg("mask", 0)) .Split(user_op::OpArg("dx", 0), axis) .Build(); } else { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), axis) .Split(user_op::OpArg("dy", 0), axis) .Split(user_op::OpArg("mask", 0), axis) .Split(user_op::OpArg("dx", 0), axis) .Build(); } } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); ctx->SetOutputShape("y", 0, x_desc.shape()); ctx->SetOutputIsDynamic("y", 0, x_desc.is_dynamic()); ctx->SetOutputShape("softmax_y", 0, x_desc.shape()); ctx->SetOutputIsDynamic("softmax_y", 0, x_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) -> Maybe { return FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); ctx->SetOutputDType("y", 0, x_desc.data_type()); ctx->SetOutputDType("softmax_y", 0, x_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::ModifyInputArg( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) -> Maybe { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); CHECK_OR_RETURN(mask_modifier != nullptr); mask_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), axis) .Split(user_op::OpArg("mask", 0), axis) .Split(user_op::OpArg("y", 0), axis) .Split(user_op::OpArg("softmax_y", 0), axis) .Build(); } return Maybe::Ok(); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); CHECK_OR_RETURN(dy_desc.shape() == softmax_y_desc.shape()); dx_desc->set_shape(dy_desc.shape()); dx_desc->set_is_dynamic(dy_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) -> Maybe { return FusedTrilScaleSoftmaxMaskScaleGradOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); dx_desc->set_data_type(dy_desc.data_type()); return Maybe::Ok(); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { ctx->NewBuilder() .Split(user_op::OpArg("softmax_y", 0), axis) .Split(user_op::OpArg("dy", 0), axis) .Split(user_op::OpArg("mask", 0), axis) .Split(user_op::OpArg("dx", 0), axis) .Build(); } return Maybe::Ok(); } } // 
namespace oneflow ================================================ FILE: oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferDataType(user_op::InferContext* ctx) -> Maybe { DataType dtype = ctx->InputDType("hidden_states", 0); ctx->SetOutputDType("query_mul_key", 0, dtype); ctx->SetOutputDType("value", 0, dtype); return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { CHECK_OR_RETURN(!(ctx->InputIsDynamic("hidden_states", 0))); int64_t head_size = ctx->Attr("head_size"); const Shape& hidden_states_shape = ctx->InputShape("hidden_states", 0); // hidden_states_shape (seq_len, batch_size, hidden_size) // layout is (seq_len, batch_size, num_heads, 3, head_size) // for example shape (1024, 4, 12, 3, 64) -> (1024, 4, 12, 192) which stride is (9216, 2304, // 192, 1) CHECK_EQ_OR_RETURN(hidden_states_shape.NumAxes(), 3); int64_t seq_len = hidden_states_shape.At(0); int64_t batch_size = hidden_states_shape.At(1); int64_t hidden_size = hidden_states_shape.At(2); CHECK_EQ_OR_RETURN(hidden_size % (head_size * 3), 0); int64_t num_heads = hidden_size / (head_size * 3); ctx->SetOutputShape("query_mul_key", 0, Shape({batch_size, num_heads, seq_len, seq_len})); ctx->SetOutputShape("value", 0, Shape({batch_size, num_heads, seq_len, head_size})); return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) -> Maybe { return FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { ctx->NewBuilder() .Split(user_op::OpArg("hidden_states", 0), 1) .Split(user_op::OpArg("query_mul_key", 0), 0) .Split(user_op::OpArg("value", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("hidden_states", 0), 2) .Split(user_op::OpArg("query_mul_key", 0), 1) .Split(user_op::OpArg("value", 0), 1) .Build(); return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferDataType( user_op::InferContext* ctx) -> Maybe { DataType dtype = ctx->InputDType("query_mul_key_grad", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("value_grad", 0), dtype) << "InferDataType Failed. 
Expected " << DataType_Name(dtype) << ", but got " << DataType_Name(ctx->InputDType("value_grad", 0)); ctx->SetOutputDType("hidden_states_grad", 0, dtype); return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { CHECK_OR_RETURN(!(ctx->InputIsDynamic("query_mul_key_grad", 0))); CHECK_OR_RETURN(!(ctx->InputIsDynamic("value_grad", 0))); const Shape& h_shape = ctx->InputShape("hidden_states", 0); const Shape& qmk_grad_shape = ctx->InputShape("query_mul_key_grad", 0); const Shape& v_grad_shape = ctx->InputShape("value_grad", 0); CHECK_EQ_OR_RETURN(h_shape.NumAxes(), 3); CHECK_EQ_OR_RETURN(qmk_grad_shape.NumAxes(), 4); CHECK_EQ_OR_RETURN(v_grad_shape.NumAxes(), 4); // hidden_states shape (s, b, H) int64_t seq_len = h_shape.At(0); int64_t batch_size = h_shape.At(1); int64_t hidden_size = h_shape.At(2); // value grad shape (b, n, s, h) int64_t num_heads = v_grad_shape.At(1); int64_t head_size = v_grad_shape.At(3); CHECK_EQ_OR_RETURN(v_grad_shape.At(0), batch_size); CHECK_EQ_OR_RETURN(v_grad_shape.At(2), seq_len); CHECK_EQ_OR_RETURN(hidden_size, num_heads * 3 * head_size); // qmk grad shape (b, n, sq, sk) CHECK_EQ_OR_RETURN(qmk_grad_shape.At(0), batch_size); CHECK_EQ_OR_RETURN(qmk_grad_shape.At(1), num_heads); CHECK_EQ_OR_RETURN(qmk_grad_shape.At(2), seq_len); CHECK_EQ_OR_RETURN(qmk_grad_shape.At(3), seq_len); ctx->SetOutputShape("hidden_states_grad", 0, h_shape); return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) -> Maybe { return FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc(ctx); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { ctx->NewBuilder() .Split(user_op::OpArg("query_mul_key_grad", 0), 0) .Split(user_op::OpArg("value_grad", 0), 0) .Split(user_op::OpArg("hidden_states", 0), 1) .Split(user_op::OpArg("hidden_states_grad", 0), 1) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("query_mul_key_grad", 0), 1) .Split(user_op::OpArg("value_grad", 0), 1) .Split(user_op::OpArg("hidden_states", 0), 2) .Split(user_op::OpArg("hidden_states_grad", 0), 2) .Build(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/fused_weighted_sum_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe FusedWeightedSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& in_0 = ctx->InputTensorDesc("in", 0); auto* out = ctx->MutOutputTensorDesc("out", 0); for (int64_t i = 1; i < ctx->input_size("in"); ++i) { const auto& cur_in = ctx->InputTensorDesc("in", i); CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape()) << Error::RuntimeError() << "inconsistent tensor size, expected all tensor to have the same shape, " << "but got " << in_0.shape().DebugStr() << " and " << cur_in.shape().DebugStr(); } out->set_shape(in_0.shape()); return Maybe::Ok(); } /*static*/ Maybe FusedWeightedSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FusedWeightedSumOp::GetSbp(user_op::SbpContext* ctx) { const int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); for (int64_t i = 0; i < num_axes; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg("out", 0)).Build(); return Maybe::Ok(); } /* static */ Maybe FusedWeightedSumOp::InferDataType(user_op::InferContext* ctx) { const auto& in_0 = ctx->InputTensorDesc("in", 0); auto* out = ctx->MutOutputTensorDesc("out", 0); const DataType data_type = in_0.data_type(); for (int64_t i = 1; i < ctx->input_size("in"); ++i) { const auto& cur_in = ctx->InputTensorDesc("in", i); CHECK_EQ_OR_RETURN(cur_in.data_type(), data_type) << Error::RuntimeError() << ctx->op_name() << " expected all tenser to have same type, but found " << DataType_Name(cur_in.data_type()) << " and " << DataType_Name(data_type); } out->set_data_type(data_type); return Maybe::Ok(); } /*static*/ Maybe FusedWeightedSumOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("in") >= 2) << Error::RuntimeError() << "The number of input tensors should be greater than or equal to 2"; return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/gather_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto GatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); const int64_t axis = ctx->Attr("axis"); const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); // For 0-dim Tensor CHECK_GE_OR_RETURN(indices.shape().NumAxes(), 0); // NOLINT user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); DimVector dim_vec; dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cbegin() + axis); dim_vec.insert(dim_vec.end(), indices.shape().dim_vec().cbegin(), indices.shape().dim_vec().cend()); dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + axis + 1, in.shape().dim_vec().end()); out->set_shape(Shape(dim_vec)); out->set_is_dynamic(indices.is_dynamic() || in.is_dynamic()); return Maybe::Ok(); } /*static*/ auto GatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return GatherOp::InferLogicalTensorDesc(ctx); } /*static*/ auto GatherOp::ModifyInputArg(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) -> Maybe { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ auto GatherOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { const int64_t in_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); const int64_t indices_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); const int64_t gather_axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(gather_axis, 0); CHECK_LT_OR_RETURN(gather_axis, in_num_axes); FOR_RANGE(int64_t, i, 0, indices_num_axes) { ctx->NewBuilder() .Split(user_op::OpArg("indices", 0), i) .Broadcast(user_op::OpArg("in", 0)) .Split(user_op::OpArg("out", 0), gather_axis + i) .Build(); } FOR_RANGE(int64_t, i, 0, in_num_axes) { if (i == gather_axis) { ctx->NewBuilder() .Broadcast(user_op::OpArg("indices", 0)) .Split(user_op::OpArg("in", 0), i) .PartialSum(user_op::OpArg("out", 0)) .Build(); } else { ctx->NewBuilder() .Broadcast(user_op::OpArg("indices", 0)) .Split(user_op::OpArg("in", 0), i) .Split(user_op::OpArg("out", 0), i < gather_axis ? i : i + indices_num_axes - 1) .Build(); } } return Maybe::Ok(); } /*static*/ auto GatherOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); out->set_data_type(in.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/gelu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferGeluTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } Maybe InferGeluDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } Maybe GetGeluSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } } // namespace /*static*/ auto GeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferGeluTensorDesc(ctx); } /*static*/ auto GeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferGeluTensorDesc(ctx); } /*static*/ auto GeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { return InferGeluDataType(ctx); } /*static*/ auto GeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { return GetGeluSbp(ctx); } /*static*/ auto FastGeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferGeluTensorDesc(ctx); } /*static*/ auto FastGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferGeluTensorDesc(ctx); } /*static*/ auto FastGeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { return InferGeluDataType(ctx); } /*static*/ auto FastGeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { return GetGeluSbp(ctx); } namespace { Maybe InferGeluGradTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape) << "InferTensorDesc failed (" << ctx->op_name() << "). Expected x shape " << x_shape.ToString() << " to be equal to dy shape " << dy_shape.ToString(); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } Maybe InferGeluGradDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe GetGeluGradSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } } // namespace /*static*/ auto GeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferGeluGradTensorDesc(ctx); } /*static*/ auto GeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferGeluGradTensorDesc(ctx); } /*static*/ auto GeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { return InferGeluGradDataType(ctx); } /*static*/ auto GeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { return GetGeluGradSbp(ctx); } /*static*/ auto FastGeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferGeluGradTensorDesc(ctx); } /*static*/ auto FastGeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferGeluGradTensorDesc(ctx); } /*static*/ auto FastGeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { return InferGeluGradDataType(ctx); } /*static*/ auto FastGeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { return GetGeluGradSbp(ctx); } /*static*/ Maybe FusedFastGeluMulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const Shape& m_shape = ctx->InputShape("multiplier", 0); CHECK_OR_RETURN(ctx->InputShape("multiplier", 0) == in_shape) << "Expected multiplier shape " << in_shape.ToString() << ", but got " << m_shape.ToString(); ctx->SetOutputShape("out", 0, in_shape); return Maybe::Ok(); } /*static*/ Maybe FusedFastGeluMulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe FusedFastGeluMulOp::InferDataType(user_op::InferContext* ctx) { const DataType in_dtype = ctx->InputDType("in", 0); const DataType m_dtype = ctx->InputDType("multiplier", 0); CHECK_EQ_OR_RETURN(m_dtype, in_dtype) << "Expected multiplier data type " << DataType_Name(in_dtype) << ", but got " << DataType_Name(m_dtype); ctx->SetOutputDType("out", 0, in_dtype); return Maybe::Ok(); } /*static*/ Maybe FusedFastGeluMulOp::GetSbp(user_op::SbpContext* ctx) { return GetGeluSbp(ctx); } /*static*/ Maybe FusedFastGeluMulGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const Shape& out_diff_shape = ctx->InputShape("out_diff", 0); const Shape& m_shape = ctx->InputShape("multiplier", 0); CHECK_EQ_OR_RETURN(out_diff_shape, in_shape); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(m_shape, in_shape); // NOLINT(maybe-need-error-msg) ctx->SetOutputShape("in_diff", 0, in_shape); ctx->SetOutputShape("multiplier_diff", 0, m_shape); return Maybe::Ok(); } /*static*/ Maybe FusedFastGeluMulGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe FusedFastGeluMulGradOp::InferDataType(user_op::InferContext* ctx) { const DataType in_dtype = ctx->InputDType("in", 0); const DataType out_diff_dtype = 
ctx->InputDType("out_diff", 0); const DataType m_dtype = ctx->InputDType("multiplier", 0); CHECK_EQ_OR_RETURN(out_diff_dtype, in_dtype); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(m_dtype, in_dtype); // NOLINT(maybe-need-error-msg) ctx->SetOutputDType("in_diff", 0, in_dtype); ctx->SetOutputDType("multiplier_diff", 0, m_dtype); return Maybe::Ok(); } /*static*/ Maybe FusedFastGeluMulGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder() .Broadcast(user_op::OpArg("in", 0)) .Broadcast(user_op::OpArg("multiplier", 0)) .PartialSum(user_op::OpArg("out_diff", 0)) .PartialSum(user_op::OpArg("in_diff", 0)) .PartialSum(user_op::OpArg("multiplier_diff", 0)) .Build(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { ctx->SetOutputShape("y", 0, Shape({ctx->InputShape("x", 0).At(0)})); return Maybe::Ok(); } /*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) -> Maybe { return GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc(ctx); } /*static*/ auto GenerateRandomBatchPermutationIndicesOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).Broadcast(user_op::OpArg("y", 0)).Build(); const auto& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Broadcast(user_op::OpArg("y", 0)).Build(); } return Maybe::Ok(); } /*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferDataType(user_op::InferContext* ctx) -> Maybe { ctx->SetOutputDType("y", 0, DataType::kInt32); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/gpt_data_loader_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ auto MegatronGptMmapDataLoaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { int64_t batch_size = ctx->Attr("batch_size"); int64_t sample_len = ctx->Attr("seq_length") + ctx->Attr("label_length"); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_shape(Shape({batch_size, sample_len})); return Maybe::Ok(); } /*static*/ auto MegatronGptMmapDataLoaderOp::InferDataType(user_op::InferContext* ctx) -> Maybe { ctx->MutOutputTensorDesc("out", 0)->set_data_type(ctx->Attr("dtype")); return Maybe::Ok(); } /*static*/ auto MegatronGptMmapDataLoaderOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /*static*/ auto MegatronGptMmapDataLoaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) -> Maybe { SbpParallel default_sbp; default_sbp.mutable_split_parallel()->set_axis(0); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /*static*/ auto MegatronGptMmapDataLoaderOp::ModifyInputArg( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) -> Maybe { if (!conf.has_input("iteration", 0)) { return Maybe::Ok(); } user_op::InputArgModifier* input_modifier = GetInputArgModifierFn("iteration", 0); CHECK_OR_RETURN(input_modifier != nullptr); input_modifier->set_is_mutable(true); input_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/greater_inplace_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
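// Note on MegatronGptMmapDataLoaderOp above: the output is a token matrix of shape
// (batch_size, seq_length + label_length) with its dtype taken from the "dtype" attr, and
// as a source op it is split along the batch axis. Hypothetical example: batch_size = 8,
// seq_length = 1024, label_length = 1 gives an (8, 1025) batch, leaving one extra column
// so inputs and right-shifted labels can presumably be sliced from the same row. The
// optional "iteration" input is marked mutable so the loader can advance it in place.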
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { bool CheckBroadCastAble(const Shape& shape, const Shape& broadcast_shape) { int left_pad = broadcast_shape.size() - shape.size(); if (left_pad < 0) { return false; } for (int i = 0; i < shape.size(); ++i) { int j = i + left_pad; if (shape[i] != 1 && shape[i] != broadcast_shape[j]) { return false; } } return true; } } // namespace /*static*/ Maybe BroadCastInplaceGreaterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const auto& x_desc = ctx->InputTensorDesc("x", 0); const auto& y_desc = ctx->InputTensorDesc("y", 0); auto x_shape = x_desc.shape(); auto y_shape = y_desc.shape(); bool broadcast_status = CheckBroadCastAble(y_shape, x_shape); CHECK_OR_RETURN(broadcast_status); ctx->SetOutputShape("out", 0, x_shape); return Maybe::Ok(); } /*static*/ Maybe BroadCastInplaceGreaterOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return BroadCastInplaceGreaterOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe BroadCastInplaceGreaterOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe BroadCastInplaceGreaterOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ScalarLogicalInplaceGreaterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe ScalarLogicalInplaceGreaterOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return ScalarLogicalInplaceGreaterOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ScalarLogicalInplaceGreaterOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe ScalarLogicalInplaceGreaterOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/grid_sample_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe GridSampleOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { bool pass_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; const auto& interpolation_mode = conf.attr("interpolation_mode"); if (!(interpolation_mode == "bilinear" || interpolation_mode == "nearest" || interpolation_mode == "bicubic")) { err << " interpolation_mode:" << interpolation_mode; pass_checked = false; } const auto& padding_mode = conf.attr("padding_mode"); if (!(padding_mode == "zeros" || padding_mode == "border" || padding_mode == "reflection")) { err << " padding_mode:" << padding_mode; pass_checked = false; } if (pass_checked) { return Maybe::Ok(); } else { return oneflow::Error::CheckFailedError() << err.str(); } } /*static*/ auto GridSampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& grid = ctx->InputTensorDesc("grid", 0); user_op::TensorDesc& output = *(ctx->MutOutputTensorDesc("output", 0)); // Only support 4D or 5D input with NCHW layout // For 4D grid: input = { N, C, H_in, W_in }, // grid = { N, H_out, W_out, 2 } // output = { N, C, H_out, W_out } // For 5D grid: input = { N, C, D_in, H_in, W_in }, // grid = { N, D_out, H_out, W_out, 3 } // output = { N, C, D_out, H_out, W_out } const Shape& input_shape = input.shape(); const Shape& grid_shape = grid.shape(); bool is_4d_input = true; if (input_shape.NumAxes() == 4) { CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 4) << "Grid and input MUST have same dimention"; CHECK_EQ_OR_RETURN(grid_shape.At(3), 2) << "Grid shape MUST (N, H_out, W_out, 2)"; is_4d_input = true; } else if (input_shape.NumAxes() == 5) { CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 5) << "Grid and input MUST have same dimention"; CHECK_EQ_OR_RETURN(grid_shape.At(4), 3) << "Grid shape MUST (N, H_out, W_out, 3)"; if (ctx->Attr("interpolation_mode") == "bicubic") { oneflow::Error::CheckFailedError() << "Mode='bicubic' supports only 4-D input"; } is_4d_input = false; } else { CHECK_OR_RETURN(false) << "MUST be 4D or 5D input"; } output.set_is_dynamic(grid.is_dynamic()); if (is_4d_input) { output.set_shape( Shape({input_shape.At(0), input_shape.At(1), grid_shape.At(1), grid_shape.At(2)})); } else { output.set_shape(Shape({input_shape.At(0), input_shape.At(1), grid_shape.At(1), grid_shape.At(2), grid_shape.At(3)})); } return Maybe::Ok(); } /*static*/ auto GridSampleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return GridSampleOp::InferLogicalTensorDesc(ctx); } /*static*/ auto GridSampleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("grid", 0), 0) .Split(user_op::OpArg("output", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 1) .Broadcast(user_op::OpArg("grid", 0)) .Split(user_op::OpArg("output", 0), 1) .Build(); return Maybe::Ok(); } /*static*/ auto GridSampleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } Maybe GridSampleGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return GridSampleOp::CheckAttr(def, conf); } /*static*/ auto GridSampleGradOp::InferLogicalTensorDesc(user_op::InferContext* 
ctx) -> Maybe { ctx->MutOutputTensorDesc("dinput", 0)->set_shape(ctx->InputTensorDesc("input", 0).shape()); ctx->MutOutputTensorDesc("dgrid", 0)->set_shape(ctx->InputTensorDesc("grid", 0).shape()); return Maybe::Ok(); } /*static*/ auto GridSampleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return GridSampleGradOp::InferLogicalTensorDesc(ctx); } /*static*/ auto GridSampleGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { ctx->NewBuilder() .Split(user_op::OpArg("doutput", 0), 0) .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("grid", 0), 0) .Split(user_op::OpArg("dinput", 0), 0) .Split(user_op::OpArg("dgrid", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("doutput", 0), 1) .Split(user_op::OpArg("input", 0), 1) .Broadcast(user_op::OpArg("grid", 0)) .Split(user_op::OpArg("dinput", 0), 1) .PartialSum(user_op::OpArg("dgrid", 0)) .Build(); return Maybe::Ok(); } /*static*/ auto GridSampleGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { ctx->SetOutputDType("dinput", 0, ctx->InputDType("input", 0)); ctx->SetOutputDType("dgrid", 0, ctx->InputDType("grid", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/group_norm_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { DEFINE_ENV_BOOL(ONEFLOW_GROUP_NORM_USE_FP16_DIRECTLY, false); namespace { oneflow::DataType InferGnParamDataType(const DataType x_data_type) { if (EnvBool()) { return x_data_type; } return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16) ? 
DataType::kFloat : x_data_type;
}

}  // namespace

/* static */ Maybe<void> GroupNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
  user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0);
  user_op::TensorDesc* mean = ctx->MutOutputTensorDesc("mean", 0);
  user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc("inv_variance", 0);
  const bool affine = ctx->Attr<bool>("affine");
  const int32_t num_groups = ctx->Attr<int32_t>("num_groups");
  const int64_t batch_size = x.shape().At(0);
  const std::string& data_format = ctx->Attr<std::string>("data_format");
  CHECK_GT_OR_RETURN(x.shape().NumAxes(), 2);
  int64_t channel_size = 0;
  if (data_format == "channels_first") {
    channel_size = x.shape().At(1);
  } else if (data_format == "channels_last") {
    channel_size = x.shape().At(x.shape().NumAxes() - 1);
  } else {
    UNIMPLEMENTED_THEN_RETURN();
  }
  y->set_shape(x.shape());
  y->set_is_dynamic(x.is_dynamic());
  if (affine) {
    const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0);
    CHECK_EQ_OR_RETURN(gamma.shape().At(0), channel_size);
    const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0);
    CHECK_EQ_OR_RETURN(beta.shape().At(0), channel_size);
  }
  CHECK_EQ_OR_RETURN(channel_size % num_groups, 0)
      << "Channels should be divisible by num_groups. ";
  mean->set_shape(Shape({batch_size, num_groups}));
  *inv_variance = *mean;
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> GroupNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> GroupNormOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder()
      .Split(ctx->inputs(), 0)
      .Split(ctx->outputs(), 0)
      .Broadcast(user_op::OpArg("gamma", 0))
      .Broadcast(user_op::OpArg("beta", 0))
      .Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> GroupNormOp::InferDataType(user_op::InferContext* ctx) {
  const bool affine = ctx->Attr<bool>("affine");
  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
  user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0);
  y->set_data_type(x.data_type());
  if (affine) {
    const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0);
    CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type())
        << "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got "
        << DataType_Name(gamma.data_type());
    const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0);
    CHECK_EQ_OR_RETURN(beta.data_type(), x.data_type()) << "InferDataType Failed.
Expected " << DataType_Name(x.data_type()) << ", but got " << DataType_Name(beta.data_type()); } user_op::TensorDesc* mean = ctx->MutOutputTensorDesc("mean", 0); user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc("inv_variance", 0); mean->set_data_type(InferGnParamDataType(x.data_type())); inv_variance->set_data_type(mean->data_type()); return Maybe::Ok(); } // GroupNorm Grad /* static */ Maybe GroupNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); const int32_t num_groups = ctx->Attr("num_groups"); user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0); CHECK_EQ_OR_RETURN(dy.shape(), x.shape()); const Shape& gn_param_shape = Shape({x.shape().At(0), num_groups}); CHECK_EQ_OR_RETURN(mean.shape(), gn_param_shape); CHECK_EQ_OR_RETURN(inv_variance.shape(), gn_param_shape); dx->set_shape(dy.shape()); dx->set_is_dynamic(dy.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe GroupNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe GroupNormGradOp::GetSbp(user_op::SbpContext* ctx) { std::vector broadcast_args; if (ctx->user_op_conf().has_input("gamma", 0)) { broadcast_args.emplace_back(user_op::OpArg("gamma", 0)); } ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("mean", 0), 0) .Split(user_op::OpArg("inv_variance", 0), 0) .Split(ctx->outputs(), 0) .Broadcast(broadcast_args) .Build(); return Maybe::Ok(); } /* static */ Maybe GroupNormGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type()) << "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got " << DataType_Name(dy.data_type()); const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); const DataType& gn_param_data_type = InferGnParamDataType(x.data_type()); CHECK_EQ_OR_RETURN(mean.data_type(), gn_param_data_type) << "InferDataType Failed. Expected " << DataType_Name(gn_param_data_type) << ", but got " << DataType_Name(mean.data_type()); CHECK_EQ_OR_RETURN(inv_variance.data_type(), gn_param_data_type) << "InferDataType Failed. 
Expected " << DataType_Name(gn_param_data_type) << ", but got " << DataType_Name(inv_variance.data_type()); user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0); dx->set_data_type(dy.data_type()); return Maybe::Ok(); } // GroupNorm Param Grad /* static */ Maybe GroupNormParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* dgamma = ctx->MutOutputTensorDesc("dgamma", 0); user_op::TensorDesc* dbeta = ctx->MutOutputTensorDesc("dbeta", 0); const int64_t channel_size = x.shape().At(1); dgamma->set_shape(Shape{channel_size}); dbeta->set_shape(Shape{channel_size}); return Maybe::Ok(); } /*static*/ Maybe GroupNormParamGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe GroupNormParamGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("mean", 0), 0) .Split(user_op::OpArg("inv_variance", 0), 0) .PartialSum(ctx->outputs()) .Build(); return Maybe::Ok(); } /* static */ Maybe GroupNormParamGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* dgamma = ctx->MutOutputTensorDesc("dgamma", 0); user_op::TensorDesc* dbeta = ctx->MutOutputTensorDesc("dbeta", 0); dgamma->set_data_type(dy.data_type()); dbeta->set_data_type(dy.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/grouped_matmul_bias_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe GroupedMatmulBiasOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const int64_t input_size = ctx->input_size("xs"); CHECK_EQ_OR_RETURN(ctx->input_size("weights"), input_size); const bool has_biases = ctx->has_input("biases", 0); if (has_biases) { CHECK_EQ_OR_RETURN(ctx->input_size("biases"), input_size); } CHECK_EQ_OR_RETURN(ctx->output_size("ys"), input_size); const DataType data_type = ctx->InputTensorDesc("xs", 0).data_type(); for (int64_t i = 0; i < input_size; ++i) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("xs", i); CHECK_EQ_OR_RETURN(x_desc.data_type(), data_type); CHECK_GE_OR_RETURN(x_desc.shape().NumAxes(), 2); const int64_t k = x_desc.shape().At(x_desc.shape().NumAxes() - 1); const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weights", i); CHECK_EQ_OR_RETURN(weight_desc.shape().NumAxes(), 2); CHECK_EQ_OR_RETURN(weight_desc.shape().At(1), k); const int64_t n = weight_desc.shape().At(0); if (has_biases) { const user_op::TensorDesc& bias_desc = ctx->InputTensorDesc("biases", i); CHECK_EQ_OR_RETURN(bias_desc.shape().NumAxes(), 1); CHECK_EQ_OR_RETURN(bias_desc.shape().At(0), n); } user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("ys", i); y_desc->set_data_type(data_type); DimVector out_dim_vec = x_desc.shape().dim_vec(); out_dim_vec.back() = n; y_desc->set_shape(Shape(out_dim_vec)); } return Maybe::Ok(); } /*static*/ Maybe GroupedMatmulBiasOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe GroupedMatmulBiasOp::GetSbp(user_op::SbpContext* ctx) { { // s0 x b auto builder = ctx->NewBuilder(); for (int64_t i = 0; i < ctx->user_op_conf().input_size("xs"); ++i) { builder.Split(user_op::OpArg("xs", i), 0); } for (int i = 0; i < ctx->user_op_conf().input_size("weights"); ++i) { builder.Broadcast(user_op::OpArg("weights", i)); } for (int i = 0; i < ctx->user_op_conf().input_size("biases"); ++i) { builder.Broadcast(user_op::OpArg("biases", i)); } for (int i = 0; i < ctx->user_op_conf().output_size("ys"); ++i) { builder.Split(user_op::OpArg("ys", i), 0); } builder.Build(); } { // b x s0 auto builder = ctx->NewBuilder(); for (int64_t i = 0; i < ctx->user_op_conf().input_size("xs"); ++i) { builder.Broadcast(user_op::OpArg("xs", i)); } for (int i = 0; i < ctx->user_op_conf().input_size("weights"); ++i) { builder.Split(user_op::OpArg("weights", i), 0); } for (int i = 0; i < ctx->user_op_conf().input_size("biases"); ++i) { builder.Split(user_op::OpArg("biases", i), 0); } for (int i = 0; i < ctx->user_op_conf().output_size("ys"); ++i) { builder.Split(user_op::OpArg("ys", i), ctx->LogicalTensorDesc4InputArgNameAndIndex("xs", i).shape().NumAxes() - 1); } builder.Build(); } return Maybe::Ok(); } /* static */ Maybe GroupedMatmulBiasOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("xs", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(first_in_desc.data_type())
        << ", but got " << DataType_Name(in_desc.data_type());
  }
  for (int32_t i = 0; i < ctx->output_size("ys"); i++) {
    user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("ys", i);
    y_desc->set_data_type(first_in_desc.data_type());
  }
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/user/ops/groupwise_dequantize_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/*static*/ Maybe<void> GroupwiseDequantizeOp::GetSbp(user_op::SbpContext* ctx) {
  const Shape& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape();
  const Shape& scale_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape();
  std::vector<user_op::OpArg> scale_zero_args;
  scale_zero_args.emplace_back(user_op::OpArg("scale", 0));
  if (ctx->user_op_conf().has_input("zero", 0)) {
    scale_zero_args.emplace_back(user_op::OpArg("zero", 0));
  }
  for (int32_t i = 0; i < in_shape.NumAxes(); ++i) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("in", 0), i)
        .Split(scale_zero_args, i)
        .Split(user_op::OpArg("out", 0), i)
        .Build();
  }
  const int64_t group_dim = ctx->Attr<int64_t>("group_dim");
  if (scale_shape.At(group_dim) == 1) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("in", 0), group_dim)
        .Broadcast(scale_zero_args)
        .Split(user_op::OpArg("out", 0), group_dim)
        .Build();
  }
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> GroupwiseDequantizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& in_shape = ctx->InputShape("in", 0);
  const Shape& scale_shape = ctx->InputShape("scale", 0);
  const int32_t num_bits = ctx->Attr<int32_t>("num_bits");
  const int64_t group_dim = ctx->Attr<int64_t>("group_dim");
  const int64_t group_size = ctx->Attr<int64_t>("group_size");
  CHECK_OR_RETURN(num_bits == 4 || num_bits == 8);
  CHECK_GE_OR_RETURN(in_shape.NumAxes(), 1);
  CHECK_OR_RETURN(group_dim >= 0 && group_dim < in_shape.NumAxes());
  // Each input byte packs 8 / num_bits quantized values, so the last axis expands accordingly.
  Shape out_shape = in_shape;
  out_shape.Set(out_shape.NumAxes() - 1, out_shape.At(out_shape.NumAxes() - 1) * (8 / num_bits));
  const int64_t group_dim_size = out_shape.At(group_dim);
  CHECK_GE_OR_RETURN(group_size, 0);
  CHECK_EQ_OR_RETURN(group_dim_size % group_size, 0);
  const int64_t num_groups = group_dim_size / group_size;
  CHECK_EQ_OR_RETURN(scale_shape.NumAxes(), in_shape.NumAxes());
  if (ctx->has_input("zero", 0)) {
    CHECK_EQ_OR_RETURN(ctx->InputShape("zero", 0).NumAxes(), in_shape.NumAxes());
  }
  for (int64_t i = 0; i < out_shape.NumAxes(); ++i) {
    if (i == group_dim) {
      CHECK_EQ_OR_RETURN(scale_shape.At(i), num_groups);
      if (ctx->has_input("zero", 0)) {
        CHECK_EQ_OR_RETURN(ctx->InputShape("zero", 0).At(i), num_groups);
      }
    } else {
      CHECK_EQ_OR_RETURN(scale_shape.At(i), out_shape.At(i));
      if (ctx->has_input("zero", 0)) {
        CHECK_EQ_OR_RETURN(ctx->InputShape("zero", 0).At(i), out_shape.At(i));
      }
    }
  }
  ctx->SetOutputShape("out", 0, out_shape);
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void>
GroupwiseDequantizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe GroupwiseDequantizeOp::InferDataType(user_op::InferContext* ctx) { const DataType data_type = ctx->InputDType("scale", 0); if (ctx->has_input("zero", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("zero", 0), data_type); } if (ctx->Attr("symmetric")) { CHECK_OR_RETURN(ctx->InputDType("in", 0) == DataType::kUInt8 || ctx->InputDType("in", 0) == DataType::kInt8); } else { CHECK_EQ_OR_RETURN(ctx->InputDType("in", 0), DataType::kUInt8); } ctx->SetOutputDType("out", 0, data_type); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/hardshrink_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe HardShrinkOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /* static */ Maybe HardShrinkOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe HardShrinkOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe HardShrinkOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe HardShrinkGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == y_shape) << "The shape of y_grad and y must be same."; ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /* static */ Maybe HardShrinkGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe HardShrinkGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe HardShrinkGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("y", 0)) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("y", 0)) << ", but got " << DataType_Name(ctx->InputDType("dy", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("y", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/hardsigmoid_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe HardsigmoidOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe HardsigmoidOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe HardsigmoidOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe HardsigmoidOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe HardsigmoidGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe HardsigmoidGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe HardsigmoidGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe HardsigmoidGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/hardswish_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe HardswishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe HardswishOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe HardswishOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe HardswishOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe HardswishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe HardswishGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe HardswishGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe HardswishGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/hardtanh_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe HardtanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); double min_val = ctx->Attr("min_val"); double max_val = ctx->Attr("max_val"); CHECK_LE_OR_RETURN(min_val, max_val); return Maybe::Ok(); } /*static*/ Maybe HardtanhOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe HardtanhOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe HardtanhOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe HardtanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == y_shape); ctx->SetOutputShape("dx", 0, dy_shape); double min_val = ctx->Attr("min_val"); double max_val = ctx->Attr("max_val"); CHECK_LE_OR_RETURN(min_val, max_val); return Maybe::Ok(); } /*static*/ Maybe HardtanhGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe HardtanhGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe HardtanhGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("y", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("y", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/hierarchical_parallel_cast_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> HierarchicalParallelCastOp::InferLogicalTensorDesc(
    user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0));
  ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> HierarchicalParallelCastOp::InferPhysicalTensorDesc(
    user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> HierarchicalParallelCastOp::GetSbp(user_op::SbpContext* ctx) {
  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);
}

/* static */ Maybe<void> HierarchicalParallelCastOp::InferNdSbp(
    user_op::InferNdSbpFnContext* ctx) {
  NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0);
  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0);
  const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
  const auto& conf = ctx->user_op_conf().attr<std::vector<std::string>>("nd_sbp");
  CHECK_EQ_OR_RETURN(conf.size(), parallel_hierarchy.NumAxes());
  for (const std::string& sbp_str : conf) {
    SbpParallel sbp_parallel;
    CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel));
    *in_distribution->add_sbp_parallel() = sbp_parallel;
    *out_distribution->add_sbp_parallel() = sbp_parallel;
  }
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> HierarchicalParallelCastOp::GetNdSbpSignatureList(
    user_op::GetNdSbpSignatureListContext* ctx) {
  const auto& conf = ctx->Attr<std::vector<std::string>>("nd_sbp");
  NdSbpSignature nd_sbp_signature;
  for (const std::string& sbp_str : conf) {
    SbpParallel sbp_parallel;
    CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel));
    *(*nd_sbp_signature.mutable_bn_in_op2nd_sbp())[GenRepeatedBn("in", 0)].add_sbp_parallel() =
        sbp_parallel;
    *(*nd_sbp_signature.mutable_bn_in_op2nd_sbp())[GenRepeatedBn("out", 0)].add_sbp_parallel() =
        sbp_parallel;
  }
  ctx->AddNdSbpSignature(nd_sbp_signature);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> HierarchicalParallelCastOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> HierarchicalParallelCastLikeOp::InferLogicalTensorDesc(
    user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0));
  ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> HierarchicalParallelCastLikeOp::InferPhysicalTensorDesc(
    user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> HierarchicalParallelCastLikeOp::GetSbp(user_op::SbpContext* ctx) {
  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);
}

/* static */ Maybe<void> HierarchicalParallelCastLikeOp::InferNdSbp(
    user_op::InferNdSbpFnContext* ctx) {
  NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0);
  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0);
  NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0);
  const NdSbp& hint_distribution = ctx->NdSbpHint4InputArgNameAndIndex("like", 0);
  *in_distribution = hint_distribution;
  *out_distribution = hint_distribution;
  *like_distribution = hint_distribution;
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> HierarchicalParallelCastLikeOp::InferDataType(
    user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/user/ops/identity_op.cpp
================================================
/*
Copyright 2020 The
OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe IdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe IdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IdentityOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe IdentityOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/image_batch_align_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

namespace {

template<typename T>
bool PowerOfTwo(T x) {
  static_assert(std::is_integral<T>::value, "T must be integral");
  return x != 0 && (x & (x - 1)) == 0;
}

}  // namespace

/* static */ Maybe<void> ImageBatchAlignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0);
  CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1);
  const Shape& shape_attr = ctx->Attr<Shape>("shape");
  const bool dynamic_out = ctx->Attr<bool>("dynamic_out");
  DimVector dim_vec(shape_attr.NumAxes() + 1);
  dim_vec.at(0) = in_desc.shape().elem_cnt();
  FOR_RANGE(int64_t, i, 0, shape_attr.NumAxes()) { dim_vec.at(i + 1) = shape_attr.At(i); }
  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0);
  out_desc->set_shape(Shape(dim_vec));
  out_desc->set_is_dynamic(dynamic_out);
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> ImageBatchAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> ImageBatchAlignOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> ImageBatchAlignOp::ModifyOutputArg(
    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
  user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0);
  CHECK_OR_RETURN(out_modifier != nullptr);
  out_modifier->set_header_infered_before_compute(false);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> ImageBatchAlignOp::CheckAttr(const user_op::UserOpDefWrapper& def,
                                                      const user_op::UserOpConfWrapper& conf) {
  bool check_failed = false;
  std::stringstream err;
  err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name();
  const Shape& shape = conf.attr<Shape>("shape");
  if (shape.NumAxes() != 3) {
    err << ", shape: " << shape.ToString() << " (image shape must have 3 axes)";
    check_failed = true;
  }
  DataType data_type = conf.attr<DataType>("data_type");
  if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) {
    err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)";
    check_failed = true;
  }
  int32_t alignment = conf.attr<int32_t>("alignment");
  if (alignment < 0) {
    err << ", alignment: " << alignment << " (alignment must be greater than or equal to 0)";
    check_failed = true;
  } else if (alignment != 0 && !PowerOfTwo(alignment)) {
    err << ", alignment: " << alignment
        << " (alignment must be power of 2 when it's not equal to 0)";
    check_failed = true;
  }
  if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); }
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> ImageBatchAlignOp::InferDataType(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0);
  CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer);
  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0);
  out_desc->set_data_type(ctx->Attr<DataType>("data_type"));
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/user/ops/image_decode_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe ImageDecodeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_shape(in_desc.shape()); return Maybe::Ok(); } /*static*/ Maybe ImageDecodeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ImageDecodeOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe ImageDecodeOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { bool check_failed = false; std::stringstream err; err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); const std::string& color_space = conf.attr("color_space"); if (color_space != "BGR" && color_space != "RGB" && color_space != "GRAY") { err << ", color_space: " << color_space << " (color_space can only be one of BGR, RGB and GRAY)"; check_failed = true; } DataType data_type = conf.attr("data_type"); if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; check_failed = true; } if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } return Maybe::Ok(); } /* static */ Maybe ImageDecodeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(DataType::kTensorBuffer); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/image_object_preprocess_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } } // namespace /* static */ Maybe ImageFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); const int N = in_desc.shape().elem_cnt(); const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe ImageFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ImageFlipOp::GetSbp(user_op::SbpContext* ctx) { return ImageObjectGetSbp(ctx); } /* static */ Maybe ImageFlipOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(in_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe ObjectBboxFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); const int N = bbox_desc.shape().elem_cnt(); const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); ctx->SetOutputShape("out", 0, ctx->InputShape("bbox", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("bbox", 0)); return Maybe::Ok(); } /*static*/ Maybe ObjectBboxFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ObjectBboxFlipOp::GetSbp(user_op::SbpContext* ctx) { return ImageObjectGetSbp(ctx); } /* static */ Maybe ObjectBboxFlipOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(bbox_desc.data_type()); const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32) << "InferDataType Failed. Expected " << DataType_Name(DataType::kInt32) << ", but got " << DataType_Name(image_size_desc.data_type()); const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8) << "InferDataType Failed. 
Expected " << DataType_Name(DataType::kInt8) << ", but got " << DataType_Name(flip_code_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("bbox", 0)); return Maybe::Ok(); } /* static */ Maybe ObjectBboxScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); const int N = bbox_desc.shape().elem_cnt(); const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); ctx->SetOutputShape("out", 0, ctx->InputShape("bbox", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("bbox", 0)); return Maybe::Ok(); } /*static*/ Maybe ObjectBboxScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ObjectBboxScaleOp::GetSbp(user_op::SbpContext* ctx) { return ImageObjectGetSbp(ctx); } /* static */ Maybe ObjectBboxScaleOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(bbox_desc.data_type()); const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat) << "InferDataType Failed. Expected " << DataType_Name(DataType::kFloat) << ", but got " << DataType_Name(scale_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("bbox", 0)); return Maybe::Ok(); } /* static */ Maybe ObjectSegmentationPolygonFlipOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); const int N = poly_desc.shape().elem_cnt(); const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); ctx->SetOutputShape("out", 0, ctx->InputShape("poly", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("poly", 0)); return Maybe::Ok(); } /*static*/ Maybe ObjectSegmentationPolygonFlipOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ObjectSegmentationPolygonFlipOp::GetSbp(user_op::SbpContext* ctx) { return ImageObjectGetSbp(ctx); } /* static */ Maybe ObjectSegmentationPolygonFlipOp::InferDataType( user_op::InferContext* ctx) { const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(poly_desc.data_type()); const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32) << "InferDataType Failed. Expected " << DataType_Name(DataType::kInt32) << ", but got " << DataType_Name(image_size_desc.data_type()); const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8) << "InferDataType Failed. 
Expected " << DataType_Name(DataType::kInt8) << ", but got " << DataType_Name(flip_code_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("poly", 0)); return Maybe::Ok(); } /* static */ Maybe ObjectSegmentationPolygonScaleOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); const int N = poly_desc.shape().elem_cnt(); const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); ctx->SetOutputShape("out", 0, ctx->InputShape("poly", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("poly", 0)); return Maybe::Ok(); } /*static*/ Maybe ObjectSegmentationPolygonScaleOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ObjectSegmentationPolygonScaleOp::GetSbp(user_op::SbpContext* ctx) { return ImageObjectGetSbp(ctx); } /* static */ Maybe ObjectSegmentationPolygonScaleOp::InferDataType( user_op::InferContext* ctx) { const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(poly_desc.data_type()); const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat) << "InferDataType Failed. Expected " << DataType_Name(DataType::kFloat) << ", but got " << DataType_Name(scale_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("poly", 0)); return Maybe::Ok(); } /* static */ Maybe ImageNormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe ImageNormalizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ImageNormalizeOp::GetSbp(user_op::SbpContext* ctx) { return ImageObjectGetSbp(ctx); } /* static */ Maybe ImageNormalizeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. 
Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(in_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe ObjectSegmentationPolygonToMaskOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); const int N = poly_desc.shape().elem_cnt(); const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc("poly_index", 0); CHECK_EQ_OR_RETURN(poly_index_desc.shape().NumAxes(), 1); CHECK_EQ_OR_RETURN(poly_index_desc.shape().elem_cnt(), N); const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); ctx->SetOutputShape("out", 0, ctx->InputShape("poly", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("poly", 0)); return Maybe::Ok(); } /*static*/ Maybe ObjectSegmentationPolygonToMaskOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ObjectSegmentationPolygonToMaskOp::GetSbp(user_op::SbpContext* ctx) { return ImageObjectGetSbp(ctx); } /* static */ Maybe ObjectSegmentationPolygonToMaskOp::InferDataType( user_op::InferContext* ctx) { const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(poly_desc.data_type()); const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc("poly_index", 0); CHECK_EQ_OR_RETURN(poly_index_desc.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(poly_desc.data_type()); const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32) << "InferDataType Failed. Expected " << DataType_Name(DataType::kInt32) << ", but got " << DataType_Name(image_size_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("poly", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/image_preprocess_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/user/image/image_util.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe CropMirrorNormalizeFromTensorbufferOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); bool has_mirror = ctx->has_input("mirror", 0); if (has_mirror) { const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); } user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int64_t N = in_tensor.shape().At(0); int64_t H = ctx->Attr("crop_h"); int64_t W = ctx->Attr("crop_w"); std::string color_space = ctx->Attr("color_space"); int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1; CHECK_OR_RETURN(H != 0 && W != 0); CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1); std::string output_layout = ctx->Attr("output_layout"); if (output_layout == "NCHW") { out_tensor->set_shape(Shape({N, C, H, W})); } else if (output_layout == "NHWC") { out_tensor->set_shape(Shape({N, H, W, C})); } else { return Error::CheckFailedError() << "output_layout: " << output_layout << " is not supported"; } return Maybe::Ok(); } /*static*/ Maybe CropMirrorNormalizeFromTensorbufferOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CropMirrorNormalizeFromTensorbufferOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe CropMirrorNormalizeFromTensorbufferOp::InferDataType( user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(in_tensor.data_type()); bool has_mirror = ctx->has_input("mirror", 0); if (has_mirror) { const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8) << "InferDataType Failed. Expected " << DataType_Name(DataType::kInt8) << ", but got " << DataType_Name(mirror_tensor.data_type()); } user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); DataType output_dtype = ctx->Attr("output_dtype"); CHECK_EQ_OR_RETURN(output_dtype, DataType::kFloat) << "InferDataType Failed. 
Expected " << DataType_Name(DataType::kFloat) << ", but got " << DataType_Name(output_dtype); // only support float now; for float16 in future out_tensor->set_data_type(output_dtype); return Maybe::Ok(); } /* static */ Maybe CropMirrorNormalizeFromUint8Op::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); bool has_mirror = ctx->has_input("mirror", 0); if (has_mirror) { const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); } user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int64_t N = in_tensor.shape().At(0); int64_t H = ctx->Attr("crop_h"); int64_t W = ctx->Attr("crop_w"); std::string color_space = ctx->Attr("color_space"); int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1; CHECK_EQ_OR_RETURN(in_tensor.shape().NumAxes(), 4); // {N, H, W, C} CHECK_EQ_OR_RETURN(in_tensor.shape().At(3), C); if (H == 0 || W == 0) { H = in_tensor.shape().At(1); W = in_tensor.shape().At(2); } else { H = std::min(H, in_tensor.shape().At(1)); W = std::min(W, in_tensor.shape().At(2)); } std::string output_layout = ctx->Attr("output_layout"); if (output_layout == "NCHW") { out_tensor->set_shape(Shape({N, C, H, W})); } else if (output_layout == "NHWC") { out_tensor->set_shape(Shape({N, H, W, C})); } else { return Error::CheckFailedError() << "output_layout: " << output_layout << " is not supported"; } return Maybe::Ok(); } /*static*/ Maybe CropMirrorNormalizeFromUint8Op::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CropMirrorNormalizeFromUint8Op::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe CropMirrorNormalizeFromUint8Op::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kUInt8) << "InferDataType Failed. Expected " << DataType_Name(DataType::kUInt8) << ", but got " << DataType_Name(in_tensor.data_type()); bool has_mirror = ctx->has_input("mirror", 0); if (has_mirror) { const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8) << "InferDataType Failed. Expected " << DataType_Name(DataType::kInt8) << ", but got " << DataType_Name(mirror_tensor.data_type()); } user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); DataType output_dtype = ctx->Attr("output_dtype"); CHECK_EQ_OR_RETURN(output_dtype, DataType::kFloat) << "InferDataType Failed. 
Expected " << DataType_Name(DataType::kFloat) << ", but got "
      << DataType_Name(output_dtype);  // only support float now; for float16 in future
  out_tensor->set_data_type(output_dtype);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> CoinFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0);
  int64_t batch_size = ctx->Attr<int64_t>("batch_size");
  out_tensor->set_shape(Shape({batch_size}));
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> CoinFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();
  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0);
  int64_t batch_size = ctx->Attr<int64_t>("batch_size");
  const Shape logical_shape = Shape({batch_size});
  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
  const auto tensor_slice_view =
      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);
  const Shape& physical_shape = tensor_slice_view.shape();
  ctx->SetOutputShape("out", 0, physical_shape);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> CoinFlipOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder().Split(user_op::OpArg("out", 0), 0).Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> CoinFlipOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
  const Shape& hierarchy = ctx->parallel_hierarchy();
  NdSbp* output_dist = ctx->NdSbp4ArgNameAndIndex("out", 0);
  // the input may be produced by tick which should be broadcast parallel dist
  std::vector<NdSbp*> inputs_dist;
  for (const auto& arg_pair : ctx->inputs()) {
    inputs_dist.emplace_back(ctx->NdSbp4ArgNameAndIndex(arg_pair.first, arg_pair.second));
  }
  const auto& dist_conf = ctx->user_op_conf().attr<std::vector<std::string>>("nd_sbp");
  if (dist_conf.size() == 0) {
    FOR_RANGE(int, i, 0, hierarchy.NumAxes()) {
      output_dist->add_sbp_parallel()->mutable_split_parallel()->set_axis(0);
      for (auto* input_dist : inputs_dist) {
        input_dist->add_sbp_parallel()->mutable_broadcast_parallel();
      }
    }
  } else {
    CHECK_EQ_OR_RETURN(dist_conf.size(), hierarchy.NumAxes());
    for (const std::string& sbp_str : dist_conf) {
      SbpParallel sbp_parallel;
      CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel));
      CHECK_OR_RETURN(
          (sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0)
          || sbp_parallel.has_broadcast_parallel());
      *output_dist->add_sbp_parallel() = sbp_parallel;
      for (auto* input_dist : inputs_dist) {
        input_dist->add_sbp_parallel()->mutable_broadcast_parallel();
      }
    }
  }
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> CoinFlipOp::InferDataType(user_op::InferContext* ctx) {
  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0);
  out_tensor->set_data_type(DataType::kInt8);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> ImageRandomCropOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0);
  out_tensor->set_shape(in_tensor.shape());
  out_tensor->set_is_dynamic(in_tensor.is_dynamic());
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> ImageRandomCropOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> ImageRandomCropOp::GetSbp(user_op::SbpContext* ctx) {
  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);
}

/* static */ Maybe<void> ImageRandomCropOp::ModifyInputArg(
    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0);
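  // Note: the "in" argument holds kTensorBuffer image data (checked in InferDataType below);
  // such buffers are not differentiated, which is presumably why requires_grad is disabled next.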
CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe ImageRandomCropOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); ctx->SetOutputDType("out", 0, in_tensor.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/image_resize_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/image/image_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe ImageResizeToFixedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().elem_cnt() > 0); int64_t batch_size = in_tensor.shape().elem_cnt(); int64_t target_width = ctx->Attr("target_width"); int64_t target_height = ctx->Attr("target_height"); int64_t channels = ctx->Attr("channels"); user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); out_tensor->set_shape(Shape({batch_size, target_height, target_width, channels})); out_tensor->set_is_dynamic(in_tensor.is_dynamic()); user_op::TensorDesc* scale_tensor = ctx->MutOutputTensorDesc("scale", 0); scale_tensor->set_shape(Shape({batch_size, 2})); scale_tensor->set_is_dynamic(in_tensor.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe ImageResizeToFixedOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ImageResizeToFixedOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe ImageResizeToFixedOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { bool check_failed = false; std::ostringstream err; err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); int64_t target_width = conf.attr("target_width"); int64_t target_height = conf.attr("target_height"); if (target_width <= 0 || target_height <= 0) { err << ", target_width: " << target_width << ", target_height: " << target_height; check_failed = true; } int64_t channels = conf.attr("channels"); if (channels != 1 && channels != 3) { err << ", channels: " << channels << " (channels can only be 1 or 3)"; check_failed = true; } DataType data_type = conf.attr("data_type"); if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; check_failed = true; } const std::string& interp_type = conf.attr("interpolation_type"); if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } if (check_failed) { return 
oneflow::Error::CheckFailedError() << err.str(); } return Maybe::Ok(); } /* static */ Maybe ImageResizeToFixedOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); out_tensor->set_data_type(ctx->Attr("data_type")); user_op::TensorDesc* scale_tensor = ctx->MutOutputTensorDesc("scale", 0); scale_tensor->set_data_type(DataType::kFloat); return Maybe::Ok(); } /* static */ Maybe ImageResizeKeepAspectRatioOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) > 0); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_shape(in_desc.shape()); user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc("size", 0); size_desc->set_shape(in_desc.shape()); user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc("scale", 0); scale_desc->set_shape(in_desc.shape()); return Maybe::Ok(); } /*static*/ Maybe ImageResizeKeepAspectRatioOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ImageResizeKeepAspectRatioOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe ImageResizeKeepAspectRatioOp::CheckAttr( const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { bool check_failed = false; std::ostringstream err; err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); const int32_t target_size = conf.attr("target_size"); const int32_t max_size = conf.attr("max_size"); if (target_size <= 0) { err << ", target_size: " << target_size << " (target_size must be greater than 0)"; check_failed = true; } if (max_size < target_size && max_size > 0) { err << ", max_size: " << max_size << " (max_size must be greater than target_size or equal to 0)"; check_failed = true; } const std::string& interp_type = conf.attr("interpolation_type"); if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } return Maybe::Ok(); } /* static */ Maybe ImageResizeKeepAspectRatioOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(DataType::kTensorBuffer); user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc("size", 0); size_desc->set_data_type(DataType::kTensorBuffer); user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc("scale", 0); scale_desc->set_data_type(DataType::kTensorBuffer); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/image_target_resize_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe ImageTargetResizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_shape(in_desc.shape()); user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc("size", 0); size_desc->set_shape(Shape({in_desc.shape().elem_cnt(), 2})); user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc("scale", 0); scale_desc->set_shape(Shape({in_desc.shape().elem_cnt(), 2})); return Maybe::Ok(); } /*static*/ Maybe ImageTargetResizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ImageTargetResizeOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe ImageTargetResizeOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { bool check_failed = false; std::stringstream err; err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); const int32_t target_size = conf.attr("target_size"); const int32_t max_size = conf.attr("max_size"); if (target_size <= 0) { err << ", target_size: " << target_size << " (target_size must be greater than 0)"; check_failed = true; } if (max_size < target_size) { err << ", max_size: " << max_size << " (max_size must be greater than 0)"; check_failed = true; } if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } return Maybe::Ok(); } /* static */ Maybe ImageTargetResizeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(DataType::kTensorBuffer); user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc("size", 0); size_desc->set_data_type(DataType::kInt32); user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc("scale", 0); scale_desc->set_data_type(DataType::kFloat); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/in_top_k_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe InTopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_EQ_OR_RETURN(targets.shape().NumAxes(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(predictions.shape().NumAxes(), 2); // NOLINT(maybe-need-error-msg) const bool is_dynamic = targets.is_dynamic(); CHECK_EQ_OR_RETURN(is_dynamic, predictions.is_dynamic()); // NOLINT(maybe-need-error-msg) out->set_is_dynamic(is_dynamic); out->set_shape(targets.shape()); return Maybe::Ok(); } /*static*/ Maybe InTopKOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe InTopKOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe InTopKOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); CHECK_OR_RETURN(IsIndexDataType(targets.data_type())) << " targets data type must be index type"; const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); CHECK_EQ_OR_RETURN(predictions.data_type(), DataType::kFloat) << "InferDataType Failed. Expected " << DataType_Name(DataType::kFloat) << ", but got " << DataType_Name(predictions.data_type()); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(kBool); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/index_add_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe IndexAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); ctx->SetOutputShape("output", 0, input_shape); ctx->SetOutputStride("output", 0, ctx->InputStride("input", 0)); return Maybe::Ok(); } /*static*/ Maybe IndexAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IndexAddOp::GetSbp(user_op::SbpContext* ctx) { // TODO(yangzhimin): support more valid sbp signature. return Maybe::Ok(); } /* static */ Maybe IndexAddOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/indexed_slices_reduce_sum_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe IndexedSlicesReduceSumOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); const user_op::TensorDesc& x_values = ctx->InputTensorDesc("x_values", 0); CHECK_LT_OR_RETURN(x_indices.shape().NumAxes(), x_values.shape().NumAxes()); FOR_RANGE(int64_t, i, 0, x_indices.shape().NumAxes()) { CHECK_EQ_OR_RETURN(x_indices.shape().At(i), x_values.shape().At(i)); } const int64_t n = x_indices.shape().elem_cnt(); const int64_t m = x_values.shape().elem_cnt() / n; user_op::TensorDesc* y_indices = ctx->MutOutputTensorDesc("y_indices", 0); user_op::TensorDesc* y_values = ctx->MutOutputTensorDesc("y_values", 0); *y_indices = x_indices; y_indices->set_shape(Shape({n})); *y_values = x_values; y_values->set_shape(Shape({n, m})); user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); num_unique->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe IndexedSlicesReduceSumOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IndexedSlicesReduceSumOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe IndexedSlicesReduceSumOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); CHECK_OR_RETURN(IsIndexDataType(x_indices.data_type())); user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); num_unique->set_data_type(DataType::kInt64); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/inv_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe InvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe InvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe InvOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x.shape().NumAxes() - 2) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe InvOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/kl_div_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe KlInferTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(input_desc.is_dynamic()); out_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } Maybe KlInferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(target_desc.data_type()) << ", but got " << DataType_Name(input_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); const auto& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()); CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape()); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_is_dynamic(input_desc.is_dynamic()); dx_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } Maybe InferGradDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()) << "InferDataType Failed. Expected " << DataType_Name(target_desc.data_type()) << ", but got " << DataType_Name(input_desc.data_type()); ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace /* static */ Maybe KlDivLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return KlInferTensorDescFn(ctx); } /*static*/ Maybe KlDivLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe KlDivLossOp::GetSbp(user_op::SbpContext* ctx) { const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe KlDivLossOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); CHECK_OR_RETURN(target_modifier != nullptr); target_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe KlDivLossOp::InferDataType(user_op::InferContext* ctx) { return KlInferDataType(ctx); } /* static */ Maybe KlDivLossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferGradTensorDescFn(ctx); } /*static*/ Maybe KlDivLossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe KlDivLossGradOp::GetSbp(user_op::SbpContext* ctx) { const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("input", 0), i) .Split(user_op::OpArg("target", 0), i) .Split(user_op::OpArg("dx", 0), i) .Split(user_op::OpArg("dy", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe KlDivLossGradOp::InferDataType(user_op::InferContext* ctx) { return InferGradDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/l1_l2_regularize_gradient_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()); ctx->SetOutputShape("out", 0, ctx->InputShape("model", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("model", 0)); return Maybe::Ok(); } Maybe GetSbpSignatures(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build(); } return Maybe::Ok(); } } // namespace /* static */ Maybe L1L2RegularizeGradientOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc(ctx); } /*static*/ Maybe L1L2RegularizeGradientOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe L1L2RegularizeGradientOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpSignatures(ctx); } /* static */ Maybe L1L2RegularizeGradientOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.data_type(), model.data_type()) << "InferDataType Failed. Expected " << DataType_Name(model.data_type()) << ", but got " << DataType_Name(model_diff.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("model", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/l2_normalize_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe L2NormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const int32_t axis = ctx->Attr("axis"); const float epsilon = ctx->Attr("epsilon"); CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()); CHECK_GT_OR_RETURN(epsilon, 0); ctx->SetOutputShape("y", 0, x_shape); Shape square_x_sum_shape = x_shape; square_x_sum_shape.Set(axis, 1); ctx->SetOutputShape("square_x_sum", 0, square_x_sum_shape); return Maybe::Ok(); } /*static*/ Maybe L2NormalizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe L2NormalizeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const int32_t axis = ctx->Attr("axis"); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { if (i != axis) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("square_x_sum", 0), i) .Build(); } } return Maybe::Ok(); } /* static */ Maybe L2NormalizeOp::InferDataType(user_op::InferContext* ctx) { DataType x_dtype = ctx->InputDType("x", 0); DataType square_x_sum_dtype = x_dtype; if (x_dtype == DataType::kFloat16 || x_dtype == DataType::kBFloat16) { square_x_sum_dtype = DataType::kFloat; } ctx->SetOutputDType("square_x_sum", 0, square_x_sum_dtype); ctx->SetOutputDType("y", 0, x_dtype); return Maybe::Ok(); } /* static */ Maybe L2NormalizeGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const Shape& y_shape = ctx->InputShape("y", 0); const Shape& square_x_sum_shape = ctx->InputShape("square_x_sum", 0); const int32_t axis = ctx->Attr("axis"); const float epsilon = ctx->Attr("epsilon"); CHECK_EQ_OR_RETURN(dy_shape, y_shape); CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, dy_shape.NumAxes()); CHECK_GT_OR_RETURN(epsilon, 0); FOR_RANGE(int32_t, i, 0, dy_shape.NumAxes()) { if (i == axis) { CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), 1); } else { CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), dy_shape.At(i)); } } ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe L2NormalizeGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe L2NormalizeGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); const int32_t axis = ctx->Attr("axis"); FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { if (i != axis) { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("square_x_sum", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } } return Maybe::Ok(); } /* static */ Maybe L2NormalizeGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("y", 0)); CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("square_x_sum", 0)) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("square_x_sum", 0)) << ", but got " << DataType_Name(ctx->InputDType("y", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/layer_norm_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { DEFINE_ENV_BOOL(ONEFLOW_LAYER_NORM_PARAM_KEEP_DIM, false); namespace { int64_t ShiftNegativeAxisIfNeed(const Shape& shape, int64_t axis) { const int64_t shifted = axis < 0 ? axis + shape.NumAxes() : axis; CHECK_GE(shifted, 0); CHECK_LT(shifted, shape.NumAxes()); return shifted; } Shape InferBnParamShape(const Shape& x_shape, const int64_t begin_norm_axis) { DimVector bn_param_shape_dim_vec; bn_param_shape_dim_vec.insert(bn_param_shape_dim_vec.end(), x_shape.dim_vec().cbegin(), x_shape.dim_vec().cbegin() + begin_norm_axis); if (EnvBool()) { while (bn_param_shape_dim_vec.size() < x_shape.dim_vec().size()) { bn_param_shape_dim_vec.push_back(1); } } const Shape bn_param_shape(bn_param_shape_dim_vec); return bn_param_shape; } oneflow::DataType InferBnParamDataType(const DataType x_data_type) { return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16) ? 
DataType::kFloat : x_data_type; } } // namespace /* static */ Maybe LayerNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); user_op::TensorDesc* mean = ctx->MutOutputTensorDesc("mean", 0); user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc("inv_variance", 0); const bool center = ctx->Attr("center"); const bool scale = ctx->Attr("scale"); const int64_t begin_params_axis = ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_params_axis")); y->set_shape(x.shape()); y->set_is_dynamic(x.is_dynamic()); DimVector param_shape_dim_vec; param_shape_dim_vec.insert(param_shape_dim_vec.end(), x.shape().dim_vec().cbegin() + begin_params_axis, x.shape().dim_vec().cend()); const Shape param_shape(param_shape_dim_vec); if (center) { const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); CHECK_EQ_OR_RETURN(beta.shape(), param_shape); } if (scale) { const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); } const int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_norm_axis")); if (begin_norm_axis != begin_params_axis) { return Error::RuntimeError() << "begin_norm_axis must equal to begin_params_axis, but got " << begin_norm_axis << " vs " << begin_params_axis; } mean->set_shape(InferBnParamShape(x.shape(), begin_norm_axis)); *inv_variance = *mean; return Maybe::Ok(); } /*static*/ Maybe LayerNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LayerNormOp::GetSbp(user_op::SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_norm_axis")); int64_t begin_params_axis = ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_params_axis")); for (int i = 0; i < std::min(begin_norm_axis, begin_params_axis); ++i) { ctx->NewBuilder() .Split(ctx->inputs(), i) .Split(ctx->outputs(), i) .Broadcast(user_op::OpArg("gamma", 0)) .Broadcast(user_op::OpArg("beta", 0)) .Build(); } return Maybe::Ok(); } /* static */ Maybe LayerNormOp::InferDataType(user_op::InferContext* ctx) { const bool center = ctx->Attr("center"); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_data_type(x.data_type()); if (center) { const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); CHECK_EQ_OR_RETURN(beta.data_type(), x.data_type()) << "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got " << DataType_Name(beta.data_type()); } const bool scale = ctx->Attr("scale"); if (scale) { const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(x.data_type()) << ", but got " << DataType_Name(gamma.data_type()); } user_op::TensorDesc* mean = ctx->MutOutputTensorDesc("mean", 0); user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc("inv_variance", 0); mean->set_data_type(InferBnParamDataType(x.data_type())); inv_variance->set_data_type(mean->data_type()); return Maybe::Ok(); } /* static */ Maybe LayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0); CHECK_EQ_OR_RETURN(dy.shape(), x.shape()); const int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); CHECK_GT_OR_RETURN(begin_norm_axis, 0); const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis); CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape); CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape); dx->set_shape(dy.shape()); dx->set_is_dynamic(dy.is_dynamic()); if (ctx->has_input("_add_to_output", 0)) { const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape()); } return Maybe::Ok(); } /*static*/ Maybe LayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LayerNormGradOp::GetSbp(user_op::SbpContext* ctx) { std::vector broadcast_args; if (ctx->user_op_conf().has_input("gamma", 0)) { broadcast_args.emplace_back(user_op::OpArg("gamma", 0)); } int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); for (int i = 0; i < begin_norm_axis; ++i) { ctx->NewBuilder() .Split(ctx->inputs(), i) .Split(ctx->outputs(), i) .Broadcast(broadcast_args) .Build(); } return Maybe::Ok(); } /* static */ Maybe LayerNormGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type()) << "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got " << DataType_Name(dy.data_type()); const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); DataType bn_param_data_type = InferBnParamDataType(x.data_type()); CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type) << "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got " << DataType_Name(mean.data_type()); CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type) << "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got " << DataType_Name(inv_variance.data_type()); user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0); dx->set_data_type(dy.data_type()); if (ctx->has_input("_add_to_output", 0)) { const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(dx->data_type()) << ", but got " << DataType_Name(add_to_output.data_type()); } return Maybe::Ok(); } /* static */ Maybe LayerNormParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { // TODO: tsai: replace lambda with user op if auto has_tensor = [ctx](const std::string& bn) -> bool { bool ret = false; for (const auto& t : ctx->inputs()) { if (bn == t.first) { return true; } } for (const auto& t : ctx->outputs()) { if (bn == t.first) { return true; } } return ret; }; const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); const bool has_beta_diff = has_tensor("beta_diff"); const bool has_gamma_diff = has_tensor("gamma_diff"); CHECK_GE_OR_RETURN(begin_params_axis, 1); CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes()); DimVector param_shape_dim_vec; param_shape_dim_vec.insert(param_shape_dim_vec.end(), dy.shape().dim_vec().cbegin() + begin_params_axis, dy.shape().dim_vec().cend()); const Shape param_shape(param_shape_dim_vec); if (has_beta_diff) { user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0); beta_diff->set_shape(param_shape); } if (has_gamma_diff) { user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0); gamma_diff->set_shape(param_shape); } return Maybe::Ok(); } /*static*/ Maybe LayerNormParamGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LayerNormParamGradOp::GetSbp(user_op::SbpContext* ctx) { int64_t begin_params_axis = ctx->Attr("begin_params_axis"); for (int i = 0; i < begin_params_axis; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(ctx->outputs()).Build(); } return Maybe::Ok(); } /* static */ Maybe LayerNormParamGradOp::InferDataType(user_op::InferContext* ctx) { auto has_tensor = [ctx](const std::string& bn) -> bool { bool ret = false; for (auto& t : ctx->inputs()) { if (bn == t.first) { return true; } } for (auto& t : ctx->outputs()) { if (bn == t.first) { return true; } } return ret; }; const bool has_beta_diff = has_tensor("beta_diff"); const bool has_gamma_diff = has_tensor("gamma_diff"); const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); if (has_beta_diff) { user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0); beta_diff->set_data_type(dy.data_type()); } if (has_gamma_diff) { user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0); gamma_diff->set_data_type(dy.data_type()); } return Maybe::Ok(); } /* static */ Maybe FuseLayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0); CHECK_EQ_OR_RETURN(dy.shape(), x.shape()) << "dy and x shapes should be equal."; const int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); CHECK_GT_OR_RETURN(begin_norm_axis, 0) << "begin_norm_axis must be greater than 0."; const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis); CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape) << "mean shape must match bn_param_shape."; CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape) << "inv_variance shape must match bn_param_shape."; dx->set_shape(dy.shape()); 
dx->set_is_dynamic(dy.is_dynamic()); if (ctx->has_input("_add_to_output", 0)) { const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape()) << "add_to_output shape must match dx shape."; } auto has_tensor = [ctx](const std::string& bn) -> bool { bool ret = false; for (const auto& t : ctx->inputs()) { if (bn == t.first) { return true; } } for (const auto& t : ctx->outputs()) { if (bn == t.first) { return true; } } return ret; }; const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); const bool has_beta_diff = has_tensor("beta_diff"); const bool has_gamma_diff = has_tensor("gamma_diff"); CHECK_GE_OR_RETURN(begin_params_axis, 1) << "begin_params_axis must be greater than or equal to 1."; CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes()) << "begin_params_axis must be less than the number of axes in dy shape."; DimVector param_shape_dim_vec; param_shape_dim_vec.insert(param_shape_dim_vec.end(), dy.shape().dim_vec().cbegin() + begin_params_axis, dy.shape().dim_vec().cend()); const Shape param_shape(param_shape_dim_vec); if (has_beta_diff) { user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0); beta_diff->set_shape(param_shape); } if (has_gamma_diff) { user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0); gamma_diff->set_shape(param_shape); } return Maybe::Ok(); } /*static*/ Maybe FuseLayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FuseLayerNormGradOp::GetSbp(user_op::SbpContext* ctx) { std::vector broadcast_args; if (ctx->user_op_conf().has_input("gamma", 0)) { broadcast_args.emplace_back("gamma", 0); } int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); int64_t begin_params_axis = ctx->Attr("begin_params_axis"); CHECK_EQ(begin_norm_axis, begin_params_axis) << "begin_norm_axis and begin_params_axis must be equal, but got " << begin_norm_axis << " and " << begin_params_axis; for (int i = 0; i < begin_norm_axis; ++i) { ctx->NewBuilder() .Split(ctx->inputs(), i) .Split(user_op::OpArg("dx", 0), i) .PartialSum(user_op::OpArg("gamma_diff", 0)) .PartialSum(user_op::OpArg("beta_diff", 0)) .Broadcast(broadcast_args) .Build(); } return Maybe::Ok(); } /* static */ Maybe FuseLayerNormGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type()) << "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got " << DataType_Name(dy.data_type()); const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); DataType bn_param_data_type = InferBnParamDataType(x.data_type()); CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type) << "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got " << DataType_Name(mean.data_type()); CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type) << "InferDataType Failed. 
Expected " << DataType_Name(bn_param_data_type) << ", but got " << DataType_Name(inv_variance.data_type()); user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0); dx->set_data_type(dy.data_type()); if (ctx->has_input("_add_to_output", 0)) { const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type()) << "InferDataType Failed. Expected " << DataType_Name(dx->data_type()) << ", but got " << DataType_Name(add_to_output.data_type()); } auto has_tensor = [ctx](const std::string& bn) -> bool { bool ret = false; for (auto& t : ctx->inputs()) { if (bn == t.first) { return true; } } for (auto& t : ctx->outputs()) { if (bn == t.first) { return true; } } return ret; }; const bool has_beta_diff = has_tensor("beta_diff"); const bool has_gamma_diff = has_tensor("gamma_diff"); if (has_beta_diff) { user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0); beta_diff->set_data_type(dy.data_type()); } if (has_gamma_diff) { user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0); gamma_diff->set_data_type(dy.data_type()); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/leaky_relu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe LeakyReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe LeakyReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LeakyReluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe LeakyReluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /* static */ Maybe LeakyReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe LeakyReluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LeakyReluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe LeakyReluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/lerp_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe LerpOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& start = ctx->InputTensorDesc("start", 0); const user_op::TensorDesc& end = ctx->InputTensorDesc("end", 0); const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(start.shape(), end.shape()) << "The size of tensor start" << start.shape() << "must match the size of tensor end" << end.shape(); if (weight.shape().elem_cnt() != 1) { CHECK_EQ_OR_RETURN(start.shape(), weight.shape()) << "The size of tensor start" << start.shape() << "must match the size of tensor weight" << weight.shape(); } user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_is_dynamic(start.is_dynamic()); out->set_shape(start.shape()); return Maybe::Ok(); } Maybe LerpOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return LerpOp::InferLogicalTensorDesc(ctx); } Maybe LerpOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& start = ctx->InputTensorDesc("start", 0); const user_op::TensorDesc& end = ctx->InputTensorDesc("end", 0); const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(start.data_type(), end.data_type()) << Error::RuntimeError() << "expected dtype " << start.data_type() << " for `end` but got dtype " << end.data_type(); CHECK_EQ_OR_RETURN(start.data_type(), weight.data_type()) << Error::RuntimeError() << "expected dtype " << start.data_type() << " for `weight` but got dtype " << weight.data_type(); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(start.data_type()); return Maybe::Ok(); } Maybe LerpOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& start = ctx->LogicalTensorDesc4InputArgNameAndIndex("start", 0); FOR_RANGE(int64_t, i, 0, start.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("start", 0), i) .Split(user_op::OpArg("end", 0), i) .Split(user_op::OpArg("weight", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } return Maybe::Ok(); } Maybe LerpGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& start = ctx->InputTensorDesc("start", 0); const user_op::TensorDesc& end = ctx->InputTensorDesc("end", 0); const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& out_diff = ctx->InputTensorDesc("out_diff", 0); CHECK_EQ_OR_RETURN(start.shape(), end.shape()) << "The size of tensor start" << start.shape() << "must match the size of tensor end" << end.shape(); CHECK_EQ_OR_RETURN(start.shape(), weight.shape()) << "The size of tensor start" << start.shape() << "must match the size of tensor weight" << weight.shape(); CHECK_EQ_OR_RETURN(start.shape(), out_diff.shape()) << "The size of tensor start" << start.shape() << "must match the size of tensor out_diff" << out_diff.shape(); user_op::TensorDesc* start_diff = ctx->MutOutputTensorDesc("start_diff", 0); user_op::TensorDesc* end_diff = ctx->MutOutputTensorDesc("end_diff", 0); user_op::TensorDesc* weight_diff = ctx->MutOutputTensorDesc("weight_diff", 0); start_diff->set_is_dynamic(start.is_dynamic()); start_diff->set_shape(start.shape()); end_diff->set_is_dynamic(end.is_dynamic()); end_diff->set_shape(end.shape()); weight_diff->set_is_dynamic(weight.is_dynamic()); weight_diff->set_shape(weight.shape()); return Maybe::Ok(); } Maybe LerpGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return 
LerpGradOp::InferLogicalTensorDesc(ctx); } Maybe LerpGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& start = ctx->InputTensorDesc("start", 0); const user_op::TensorDesc& end = ctx->InputTensorDesc("end", 0); const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& out_diff = ctx->InputTensorDesc("out_diff", 0); CHECK_EQ_OR_RETURN(start.data_type(), end.data_type()) << Error::RuntimeError() << "expected dtype " << start.data_type() << " for `end` but got dtype " << end.data_type(); CHECK_EQ_OR_RETURN(start.data_type(), weight.data_type()) << Error::RuntimeError() << "expected dtype " << start.data_type() << " for `weight` but got dtype " << weight.data_type(); CHECK_EQ_OR_RETURN(start.data_type(), out_diff.data_type()) << Error::RuntimeError() << "expected dtype " << start.data_type() << " for `out_diff` but got dtype " << out_diff.data_type(); user_op::TensorDesc* start_diff = ctx->MutOutputTensorDesc("start_diff", 0); user_op::TensorDesc* end_diff = ctx->MutOutputTensorDesc("end_diff", 0); user_op::TensorDesc* weight_diff = ctx->MutOutputTensorDesc("weight_diff", 0); start_diff->set_data_type(start.data_type()); end_diff->set_data_type(end.data_type()); weight_diff->set_data_type(weight.data_type()); return Maybe::Ok(); } Maybe LerpGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& start = ctx->LogicalTensorDesc4InputArgNameAndIndex("start", 0); FOR_RANGE(int64_t, i, 0, start.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("start", 0), i) .Split(user_op::OpArg("end", 0), i) .Split(user_op::OpArg("weight", 0), i) .Split(user_op::OpArg("out_diff", 0), i) .Split(user_op::OpArg("start_diff", 0), i) .Split(user_op::OpArg("end_diff", 0), i) .Split(user_op::OpArg("weight_diff", 0), i) .Build(); } return Maybe::Ok(); } Maybe ScalarLerpOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& start = ctx->InputTensorDesc("start", 0); const user_op::TensorDesc& end = ctx->InputTensorDesc("end", 0); CHECK_EQ_OR_RETURN(start.shape(), end.shape()) << "The size of tensor start" << start.shape() << "must match the size of tensor end" << end.shape(); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_is_dynamic(start.is_dynamic()); out->set_shape(start.shape()); return Maybe::Ok(); } Maybe ScalarLerpOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ScalarLerpOp::InferLogicalTensorDesc(ctx); } Maybe ScalarLerpOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& start = ctx->InputTensorDesc("start", 0); const user_op::TensorDesc& end = ctx->InputTensorDesc("end", 0); CHECK_EQ_OR_RETURN(start.data_type(), end.data_type()) << Error::RuntimeError() << "expected dtype " << start.data_type() << " for `end` but got dtype " << end.data_type(); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(start.data_type()); return Maybe::Ok(); } Maybe ScalarLerpOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& start = ctx->LogicalTensorDesc4InputArgNameAndIndex("start", 0); FOR_RANGE(int64_t, i, 0, start.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("start", 0), i) .Split(user_op::OpArg("end", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } return Maybe::Ok(); } Maybe ScalarLerpGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& start = ctx->InputTensorDesc("start", 0); const user_op::TensorDesc& end = ctx->InputTensorDesc("end", 0); 
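// Editorial note (descriptive comment, not in the original source): start, end, and out_diff must all share one shape; start_diff and end_diff both take start's shape and dynamic flag.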
const user_op::TensorDesc& out_diff = ctx->InputTensorDesc("out_diff", 0); CHECK_EQ_OR_RETURN(start.shape(), end.shape()) << "The size of tensor start" << start.shape() << "must match the size of tensor end" << end.shape(); CHECK_EQ_OR_RETURN(start.shape(), out_diff.shape()) << "The size of tensor start" << start.shape() << "must match the size of tensor out_diff" << out_diff.shape(); user_op::TensorDesc* start_diff = ctx->MutOutputTensorDesc("start_diff", 0); user_op::TensorDesc* end_diff = ctx->MutOutputTensorDesc("end_diff", 0); start_diff->set_is_dynamic(start.is_dynamic()); start_diff->set_shape(start.shape()); end_diff->set_is_dynamic(start.is_dynamic()); end_diff->set_shape(start.shape()); return Maybe::Ok(); } Maybe ScalarLerpGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ScalarLerpGradOp::InferLogicalTensorDesc(ctx); } Maybe ScalarLerpGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& start = ctx->InputTensorDesc("start", 0); const user_op::TensorDesc& end = ctx->InputTensorDesc("end", 0); const user_op::TensorDesc& out_diff = ctx->InputTensorDesc("out_diff", 0); CHECK_EQ_OR_RETURN(start.data_type(), end.data_type()) << Error::RuntimeError() << "expected dtype " << start.data_type() << " for `end` but got dtype " << end.data_type(); CHECK_EQ_OR_RETURN(start.data_type(), out_diff.data_type()) << Error::RuntimeError() << "expected dtype " << start.data_type() << " for `out_diff` but got dtype " << out_diff.data_type(); user_op::TensorDesc* start_diff = ctx->MutOutputTensorDesc("start_diff", 0); user_op::TensorDesc* end_diff = ctx->MutOutputTensorDesc("end_diff", 0); start_diff->set_data_type(start.data_type()); end_diff->set_data_type(start.data_type()); return Maybe::Ok(); } Maybe ScalarLerpGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& start = ctx->LogicalTensorDesc4InputArgNameAndIndex("start", 0); FOR_RANGE(int64_t, i, 0, start.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("start", 0), i) .Split(user_op::OpArg("end", 0), i) .Split(user_op::OpArg("out_diff", 0), i) .Split(user_op::OpArg("start_diff", 0), i) .Split(user_op::OpArg("end_diff", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/linalg_cross_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe LinalgCrossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("input", 0)); return Maybe::Ok(); } /*static*/ Maybe LinalgCrossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe LinalgCrossOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& input = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); const int64_t num_axes = input.shape().NumAxes(); const int64_t dim = ctx->Attr("dim"); FOR_RANGE(int64_t, i, 0, num_axes) { if (i == dim) continue; ctx->NewBuilder() .Split(user_op::OpArg("input", 0), i) .Split(user_op::OpArg("other", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe LinalgCrossOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/log_softmax_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe LogSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("prob", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe LogSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LogSoftmaxOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), axis) .Split(user_op::OpArg("prob", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe LogSoftmaxOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("prob", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe LogSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("prob", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == y_shape); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe LogSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LogSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("prob", 0); FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) { ctx->NewBuilder() .Split(user_op::OpArg("prob", 0), axis) .Split(user_op::OpArg("dy", 0), axis) .Split(user_op::OpArg("dx", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe LogSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("prob", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("prob", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("prob", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/logical_not_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferDataTypeLogicalNot(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, DataType::kBool); return Maybe::Ok(); } } // namespace /* static */ Maybe LogicalNotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe LogicalNotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LogicalNotOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /* static */ Maybe LogicalNotOp::InferDataType(user_op::InferContext* ctx) { return InferDataTypeLogicalNot(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/loss_op_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/ops/loss_op_util.h" #include "oneflow/core/common/just.h" namespace oneflow { user_op::GetSbpFn GenLossForwardDefaultGetSbpFn( const std::function& f) { return [=](user_op::SbpContext* ctx) -> Maybe { auto builder = ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) .Split(user_op::OpArg("out", 0), 0); if (ctx->user_op_conf().has_input("weight", 0)) { builder.Split(user_op::OpArg("weight", 0), 0); } f(builder, ctx); builder.Build(); return Maybe::Ok(); }; } user_op::GetSbpFn GenLossBackwardDefaultGetSbpFn( const std::function& f) { return [=](user_op::SbpContext* ctx) -> Maybe { auto builder = ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Split(user_op::OpArg("dy", 0), 0); if (ctx->user_op_conf().has_input("weight", 0)) { builder.Split(user_op::OpArg("weight", 0), 0); } f(builder, ctx); builder.Build(); return Maybe::Ok(); }; } } // namespace oneflow ================================================ FILE: oneflow/user/ops/loss_op_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_OPS_LOSS_OP_UTIL_H_ #define ONEFLOW_USER_OPS_LOSS_OP_UTIL_H_ #include #include "oneflow/core/framework/framework.h" namespace oneflow { user_op::GetSbpFn GenLossForwardDefaultGetSbpFn( const std::function& f = [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {}); user_op::GetSbpFn GenLossBackwardDefaultGetSbpFn( const std::function& f = [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {}); } // namespace oneflow #endif // ONEFLOW_USER_OPS_LOSS_OP_UTIL_H_ ================================================ FILE: oneflow/user/ops/lu_composition_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe LUDecompositionOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& x_desc = ctx->InputTensorDesc("x", 0); auto x_shape = x_desc.shape(); ctx->SetOutputShape("pivot", 0, Shape(x_shape.begin(), x_shape.end() - 1)); ctx->SetOutputShape("LU", 0, x_shape); return Maybe::Ok(); } /*static*/ Maybe LUDecompositionOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe LUDecompositionOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x.shape().NumAxes() - 2) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("LU", 0), i) .Split(user_op::OpArg("pivot", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe LUDecompositionOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("LU", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("pivot", 0, DataType::kInt32); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/masked_fill_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferMaskedFillTensorDesc(user_op::InferContext* ctx) { const Shape& mask_shape = ctx->InputShape("mask", 0); ctx->SetOutputShape("out", 0, mask_shape); return Maybe::Ok(); } Maybe InferMaskedFillDataType(user_op::InferContext* ctx) { DataType mask_dtype = ctx->InputDType("mask", 0); CHECK_OR_RETURN(IsIntegralDataType(mask_dtype) || IsBoolDataType(mask_dtype)) << " mask type must be integral or bool"; ctx->SetOutputDType("out", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe GetMaskedFillSbpSignatures(user_op::SbpContext* ctx) { const Shape& mask_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("mask", 0).shape(); const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); FOR_RANGE(int64_t, i, 0, mask_shape.NumAxes()) { if (mask_shape.At(i) == 1 && x_shape.At(i) == 1) { continue; } if (mask_shape.At(i) == x_shape.At(i)) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } else if (mask_shape.At(i) == 1) { ctx->NewBuilder() .Broadcast(user_op::OpArg("mask", 0)) .Split(user_op::OpArg("x", 0), i) .Split(ctx->outputs(), i) .Build(); } else if (x_shape.At(i) == 1) { ctx->NewBuilder() .Split(user_op::OpArg("mask", 0), i) .Broadcast(user_op::OpArg("x", 0)) .Split(ctx->outputs(), i) .Build(); } else { UNIMPLEMENTED(); } } return Maybe::Ok(); } Maybe GetMaskedFillInputArgModify(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* mask_arg_modifier = GetInputArgModifierFn("mask", 0); mask_arg_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace /* static */ Maybe MaskedFillOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferMaskedFillTensorDesc(ctx); } /*static*/ Maybe MaskedFillOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MaskedFillOp::GetSbp(user_op::SbpContext* ctx) { return GetMaskedFillSbpSignatures(ctx); } /* static */ Maybe MaskedFillOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return GetMaskedFillInputArgModify(GetInputArgModifierFn, conf); } /* static */ Maybe MaskedFillOp::InferDataType(user_op::InferContext* ctx) { return InferMaskedFillDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/math_binary_broadcast_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/user/ops/math_binary_broadcast_seq.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { bool IsScalarTensor(const user_op::TensorDesc* tensor) { return tensor->shape().NumAxes() == 1 && tensor->shape().At(0) == 1; } bool IsZeroDimTensor(const user_op::TensorDesc* tensor) { return tensor->shape().NumAxes() == 0; } Maybe InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) { const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); user_op::TensorDesc* tensor_z = ctx->MutOutputTensorDesc("z", 0); size_t output_num_axes = std::max(tensor_x.shape().NumAxes(), tensor_y.shape().NumAxes()); if (IsZeroDimTensor(&tensor_x)) { ctx->SetOutputShape("z", 0, ctx->InputShape("y", 0)); ctx->SetOutputIsDynamic("z", 0, ctx->InputIsDynamic("y", 0)); } else if (IsZeroDimTensor(&tensor_y)) { ctx->SetOutputShape("z", 0, ctx->InputShape("x", 0)); ctx->SetOutputIsDynamic("z", 0, ctx->InputIsDynamic("x", 0)); } else if (IsScalarTensor(&tensor_x)) { ctx->SetOutputShape("z", 0, ctx->InputShape("y", 0)); ctx->SetOutputIsDynamic("z", 0, ctx->InputIsDynamic("y", 0)); } else if (IsScalarTensor(&tensor_y)) { ctx->SetOutputShape("z", 0, ctx->InputShape("x", 0)); ctx->SetOutputIsDynamic("z", 0, ctx->InputIsDynamic("x", 0)); } else { const auto& x_shape = CreateLeftExtendedShape(ShapeView(tensor_x.shape()), output_num_axes); const auto& y_shape = CreateLeftExtendedShape(ShapeView(tensor_y.shape()), output_num_axes); ctx->SetOutputShape("z", 0, ctx->InputShape("x", 0)); ctx->SetOutputIsDynamic("z", 0, ctx->InputIsDynamic("x", 0)); Shape out_shape(x_shape); FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { if (x_shape.At(i) != 1 && y_shape.At(i) != 1 && x_shape.At(i) != y_shape.At(i)) { return Error::RuntimeError() << "The size of tensor a (" << x_shape.At(i) << ") must match the size of tensor b (" << y_shape.At(i) << ") at non-singleton dimension " << i; } out_shape.Set(i, (x_shape.At(i) == 0 || y_shape.At(i) == 0) ? 0 : std::max(x_shape.At(i), y_shape.At(i))); } tensor_z->set_shape(out_shape); } tensor_z->set_is_dynamic(tensor_x.is_dynamic() || tensor_y.is_dynamic()); return Maybe::Ok(); } Maybe InferTensorDescBinaryBroadcastLogical(user_op::InferContext* ctx) { return InferTensorDescBinaryBroadcastNormal(ctx); } Maybe InferDataTypeBinaryBroadcastNormal(user_op::InferContext* ctx) { const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(tensor_x.data_type(), tensor_y.data_type()) << "InferDataType Failed. Expected " << DataType_Name(tensor_x.data_type()) << ", but got " << DataType_Name(tensor_y.data_type()); ctx->SetOutputDType("z", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe InferDataTypeBinaryBroadcastLogical(user_op::InferContext* ctx) { const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(tensor_x.data_type(), tensor_y.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(tensor_x.data_type()) << ", but got " << DataType_Name(tensor_y.data_type()); ctx->SetOutputDType("z", 0, DataType::kBool); return Maybe::Ok(); } template class binary_func> void GenPartialSbpSign(user_op::SbpContext* ctx) {} template<> void GenPartialSbpSign(user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("y", 0)) .PartialSum(user_op::OpArg("z", 0)) .Build(); } template<> void GenPartialSbpSign(user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("y", 0)) .PartialSum(user_op::OpArg("z", 0)) .Build(); } template<> void GenPartialSbpSign(user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("y", 0)) .PartialSum(user_op::OpArg("z", 0)) .Build(); } template<> void GenPartialSbpSign(user_op::SbpContext* ctx) { ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("y", 0)) .PartialSum(user_op::OpArg("z", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("y", 0)) .PartialSum(user_op::OpArg("z", 0)) .Build(); } template<> void GenPartialSbpSign(user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("y", 0)) .PartialSum(user_op::OpArg("z", 0)) .Build(); } template class binary_func> Maybe GetBinaryBroadcastSbpSignature(user_op::SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); if (x_shape.NumAxes() < y_shape.NumAxes()) { FOR_RANGE(int64_t, i, 0, y_shape.NumAxes() - x_shape.NumAxes()) { ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("z", 0), i) .Build(); } FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), x_shape.NumAxes() - 1 - i) .Split(user_op::OpArg("y", 0), y_shape.NumAxes() - 1 - i) .Split(ctx->outputs(), y_shape.NumAxes() - 1 - i) .Build(); } } else if (x_shape.NumAxes() > y_shape.NumAxes()) { FOR_RANGE(int64_t, i, 0, x_shape.NumAxes() - y_shape.NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("y", 0)) .Split(user_op::OpArg("z", 0), i) .Build(); } FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), x_shape.NumAxes() - 1 - i) .Split(user_op::OpArg("y", 0), y_shape.NumAxes() - 1 - i) .Split(ctx->outputs(), x_shape.NumAxes() - 1 - i) .Build(); } } else { FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { if (x_shape.At(i) == 1 && y_shape.At(i) == 1) { continue; } if (x_shape.At(i) == y_shape.At(i)) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } else if (x_shape.At(i) == 1) { ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("y", 0), i) .Split(ctx->outputs(), i) .Build(); } else if (y_shape.At(i) == 1) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("y", 0)) .Split(ctx->outputs(), i) .Build(); } else { UNIMPLEMENTED(); } } } GenPartialSbpSign(ctx); return Maybe::Ok(); } } // namespace #define REGISTER_BINARY_BROADCAST_NORMAL_USER_OP(op_name, suffix) \ /* static */ Maybe op_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferTensorDescBinaryBroadcastNormal(ctx); \ } \ /*static*/ Maybe op_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ 
return InferLogicalTensorDesc(ctx); \ } \ /* static */ Maybe op_name::GetSbp(user_op::SbpContext* ctx) { \ return GetBinaryBroadcastSbpSignature(ctx); \ } \ /* static */ Maybe op_name::InferDataType(user_op::InferContext* ctx) { \ return InferDataTypeBinaryBroadcastNormal(ctx); \ } #define REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP(op_name, suffix) \ /* static */ Maybe op_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferTensorDescBinaryBroadcastLogical(ctx); \ } \ /*static*/ Maybe op_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /* static */ Maybe op_name::GetSbp(user_op::SbpContext* ctx) { \ return GetBinaryBroadcastSbpSignature(ctx); \ } \ /* static */ Maybe op_name::InferDataType(user_op::InferContext* ctx) { \ return InferDataTypeBinaryBroadcastLogical(ctx); \ } OF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_NORMAL_USER_OP, MATH_BINARY_BROADCAST_FUNC_SEQ_ODS) OF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP, MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ_ODS) } // namespace oneflow ================================================ FILE: oneflow/user/ops/math_binary_broadcast_seq.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_OPS_MATH_BINARY_BROADCAST_SEQ_H_ #define ONEFLOW_USER_OPS_MATH_BINARY_BROADCAST_SEQ_H_ #include "oneflow/core/common/util.h" namespace oneflow { #define MATH_BINARY_BROADCAST_FUNC_SEQ \ OF_PP_MAKE_TUPLE_SEQ("broadcast_add", Add) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_sub", Sub) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_mul", Mul) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_div", Div) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_minimum", Min) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_maximum", Max) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_bitwise_and", BitwiseAnd) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_bitwise_or", BitwiseOr) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_bitwise_xor", BitwiseXor) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_floor_mod", FloorMod) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_fmod", FMod) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_pow", Pow) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_zeta", Zeta) #define MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ \ OF_PP_MAKE_TUPLE_SEQ("broadcast_equal", EQ) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_not_equal", NE) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_greater", GT) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_greater_equal", GE) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_less", LT) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_less_equal", LE) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_logical_and", AND) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_logical_or", OR) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_logical_xor", XOR) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_isclose_eq_nan", IEN) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_isclose_neq_nan", INN) #define MATH_BINARY_BROADCAST_FUNC_SEQ_ODS \ OF_PP_MAKE_TUPLE_SEQ(BroadcastAddOp, Add) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastSubOp, Sub) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastMulOp, Mul) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastDivOp, Div) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastMinimumOp, Min) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastMaximumOp, Max) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastBitwiseAndOp, BitwiseAnd) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastBitwiseOrOp, BitwiseOr) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastBitwiseXorOp, BitwiseXor) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastFloorModOp, FloorMod) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastFmodOp, FMod) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastPowOp, Pow) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastZetaOp, Zeta) #define MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ_ODS \ OF_PP_MAKE_TUPLE_SEQ(BroadcastEqualOp, EQ) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastNotEqualOp, NE) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastGreaterOp, GT) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastGreaterEqualOp, GE) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastLessOp, LT) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastLessEqualOp, LE) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalAndOp, AND) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalOrOp, OR) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalXorOp, XOR) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastIsCloseEqualNanOp, IEN) \ OF_PP_MAKE_TUPLE_SEQ(BroadcastIsCloseNotEqualNanOp, INN) } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_BINARY_BROADCAST_SEQ_H_ ================================================ FILE: oneflow/user/ops/math_binary_elementwise_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/math_binary_elementwise_seq.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { #define MATH_ELEMENTWISE_DEFAULT_SET_FUNC(op_type) \ /* static */ Maybe op_type::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ } \ /*static*/ Maybe op_type::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /* static */ Maybe op_type::GetSbp(user_op::SbpContext* ctx) { \ return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ } \ /* static */ Maybe op_type::InferDataType(user_op::InferContext* ctx) { \ return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ } #define REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD(math_binary_elementwise_type, func_prefix) \ MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op); \ \ MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##XGradOp); \ \ MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##YGradOp); OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD, MATH_BINARY_ELEMENTWISE_FUNC_SEQ_ODS) } // namespace oneflow ================================================ FILE: oneflow/user/ops/math_binary_elementwise_seq.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_OPS_MATH_BINARY_ELEMENTWISE_SEQ_H_ #define ONEFLOW_USER_OPS_MATH_BINARY_ELEMENTWISE_SEQ_H_ #include "oneflow/core/common/util.h" namespace oneflow { #define MATH_BINARY_ELEMENTWISE_FUNC_SEQ \ OF_PP_MAKE_TUPLE_SEQ("pow", Pow) \ OF_PP_MAKE_TUPLE_SEQ("atan2", Atan2) \ OF_PP_MAKE_TUPLE_SEQ("floordiv", FloorDiv) \ OF_PP_MAKE_TUPLE_SEQ("truncdiv", TruncDiv) \ OF_PP_MAKE_TUPLE_SEQ("xdivy", Xdivy) \ OF_PP_MAKE_TUPLE_SEQ("xlogy", Xlogy) #define MATH_BINARY_ELEMENTWISE_FUNC_SEQ_ODS \ OF_PP_MAKE_TUPLE_SEQ("pow", Pow) \ OF_PP_MAKE_TUPLE_SEQ("atan2", Atan2) \ OF_PP_MAKE_TUPLE_SEQ("floordiv", Floordiv) \ OF_PP_MAKE_TUPLE_SEQ("truncdiv", Truncdiv) \ OF_PP_MAKE_TUPLE_SEQ("xdivy", Xdivy) \ OF_PP_MAKE_TUPLE_SEQ("xlogy", Xlogy) } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_BINARY_ELEMENTWISE_SEQ_H_ ================================================ FILE: oneflow/user/ops/math_unary_elementwise_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/math_unary_elementwise_seq.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { #define MATH_ELEMENTWISE_DEFAULT_SET_FUNC(op_type) \ /* static */ Maybe op_type::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ } \ /*static*/ Maybe op_type::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /* static */ Maybe op_type::GetSbp(user_op::SbpContext* ctx) { \ return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ } \ /* static */ Maybe op_type::InferDataType(user_op::InferContext* ctx) { \ return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ } #define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_DY_X(math_unary_elementwise_type, \ func_prefix) \ MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op) \ MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##GradOp) OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_DY_X, MATH_UNARY_ELEMENTWISE_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ) #define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_DY_Y(math_unary_elementwise_type, \ func_prefix) \ MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op) \ MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##GradOp) OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_DY_Y, MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_DY_Y_SEQ) #define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_FILL(math_unary_elementwise_type, \ func_prefix) \ MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op) OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_FILL, MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_FILL_SEQ) // Negative's grad function = negative(dy), so here register negative op separately. MATH_ELEMENTWISE_DEFAULT_SET_FUNC(NegativeOp) MATH_ELEMENTWISE_DEFAULT_SET_FUNC(BitwiseNotOp) MATH_ELEMENTWISE_DEFAULT_SET_FUNC(TrigammaOp) } // namespace oneflow ================================================ FILE: oneflow/user/ops/math_unary_elementwise_seq.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_OPS_MATH_UNARY_ELEMENTWISE_SEQ_H_ #define ONEFLOW_USER_OPS_MATH_UNARY_ELEMENTWISE_SEQ_H_ #include "oneflow/core/common/util.h" namespace oneflow { #define MATH_UNARY_ELEMENTWISE_FUNC_SEQ \ OF_PP_MAKE_TUPLE_SEQ("abs", Abs) \ OF_PP_MAKE_TUPLE_SEQ("acos", Acos) \ OF_PP_MAKE_TUPLE_SEQ("acosh", Acosh) \ OF_PP_MAKE_TUPLE_SEQ("asin", Asin) \ OF_PP_MAKE_TUPLE_SEQ("asinh", Asinh) \ OF_PP_MAKE_TUPLE_SEQ("atan", Atan) \ OF_PP_MAKE_TUPLE_SEQ("atanh", Atanh) \ OF_PP_MAKE_TUPLE_SEQ("ceil", Ceil) \ OF_PP_MAKE_TUPLE_SEQ("cos", Cos) \ OF_PP_MAKE_TUPLE_SEQ("cosh", Cosh) \ OF_PP_MAKE_TUPLE_SEQ("digamma", Digamma) \ OF_PP_MAKE_TUPLE_SEQ("trigamma", Trigamma) \ OF_PP_MAKE_TUPLE_SEQ("erf", Erf) \ OF_PP_MAKE_TUPLE_SEQ("erfc", Erfc) \ OF_PP_MAKE_TUPLE_SEQ("exp", Exp) \ OF_PP_MAKE_TUPLE_SEQ("exp2", Exp2) \ OF_PP_MAKE_TUPLE_SEQ("expm1", Expm1) \ OF_PP_MAKE_TUPLE_SEQ("floor", Floor) \ OF_PP_MAKE_TUPLE_SEQ("lgamma", Lgamma) \ OF_PP_MAKE_TUPLE_SEQ("log", Log) \ OF_PP_MAKE_TUPLE_SEQ("log2", Log2) \ OF_PP_MAKE_TUPLE_SEQ("log10", Log10) \ OF_PP_MAKE_TUPLE_SEQ("log1p", Log1p) \ OF_PP_MAKE_TUPLE_SEQ("log_sigmoid", LogSigmoid) \ OF_PP_MAKE_TUPLE_SEQ("negative", Negative) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal", Reciprocal) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan", ReciprocalNoNan) \ OF_PP_MAKE_TUPLE_SEQ("rint", Rint) \ OF_PP_MAKE_TUPLE_SEQ("round", Round) \ OF_PP_MAKE_TUPLE_SEQ("rsqrt", Rsqrt) \ OF_PP_MAKE_TUPLE_SEQ("sigmoid_v2", Sigmoid) \ OF_PP_MAKE_TUPLE_SEQ("sign", Sign) \ OF_PP_MAKE_TUPLE_SEQ("sin", Sin) \ OF_PP_MAKE_TUPLE_SEQ("sinh", Sinh) \ OF_PP_MAKE_TUPLE_SEQ("sqrt", Sqrt) \ OF_PP_MAKE_TUPLE_SEQ("square", Square) \ OF_PP_MAKE_TUPLE_SEQ("tan", Tan) \ OF_PP_MAKE_TUPLE_SEQ("not_equal_zero", NotEqualZero) #define MATH_UNARY_ELEMENTWISE_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ \ OF_PP_MAKE_TUPLE_SEQ("abs", Abs) \ OF_PP_MAKE_TUPLE_SEQ("acos", Acos) \ OF_PP_MAKE_TUPLE_SEQ("acosh", Acosh) \ OF_PP_MAKE_TUPLE_SEQ("asin", Asin) \ OF_PP_MAKE_TUPLE_SEQ("asinh", Asinh) \ OF_PP_MAKE_TUPLE_SEQ("atan", Atan) \ OF_PP_MAKE_TUPLE_SEQ("atanh", Atanh) \ OF_PP_MAKE_TUPLE_SEQ("cos", Cos) \ OF_PP_MAKE_TUPLE_SEQ("cosh", Cosh) \ OF_PP_MAKE_TUPLE_SEQ("erf", Erf) \ OF_PP_MAKE_TUPLE_SEQ("erfc", Erfc) \ OF_PP_MAKE_TUPLE_SEQ("exp", Exp) \ OF_PP_MAKE_TUPLE_SEQ("exp2", Exp2) \ OF_PP_MAKE_TUPLE_SEQ("expm1", Expm1) \ OF_PP_MAKE_TUPLE_SEQ("log", Log) \ OF_PP_MAKE_TUPLE_SEQ("lgamma", Lgamma) \ OF_PP_MAKE_TUPLE_SEQ("digamma", Digamma) \ OF_PP_MAKE_TUPLE_SEQ("log2", Log2) \ OF_PP_MAKE_TUPLE_SEQ("log10", Log10) \ OF_PP_MAKE_TUPLE_SEQ("log1p", Log1p) \ OF_PP_MAKE_TUPLE_SEQ("log_sigmoid", LogSigmoid) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal", Reciprocal) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan", ReciprocalNoNan) \ OF_PP_MAKE_TUPLE_SEQ("rsqrt", Rsqrt) \ OF_PP_MAKE_TUPLE_SEQ("sin", Sin) \ OF_PP_MAKE_TUPLE_SEQ("sinh", Sinh) \ OF_PP_MAKE_TUPLE_SEQ("sqrt", Sqrt) \ OF_PP_MAKE_TUPLE_SEQ("square", Square) \ OF_PP_MAKE_TUPLE_SEQ("tan", Tan) #define MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_DY_Y_SEQ OF_PP_MAKE_TUPLE_SEQ("sigmoid", Sigmoid) #define MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_FILL_SEQ \ OF_PP_MAKE_TUPLE_SEQ("not_equal_zero", NotEqualZero) \ OF_PP_MAKE_TUPLE_SEQ("sign", Sign) \ OF_PP_MAKE_TUPLE_SEQ("rint", Rint) \ OF_PP_MAKE_TUPLE_SEQ("round", Round) \ OF_PP_MAKE_TUPLE_SEQ("floor", Floor) \ OF_PP_MAKE_TUPLE_SEQ("ceil", Ceil) } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_UNARY_ELEMENTWISE_SEQ_H_ ================================================ FILE: oneflow/user/ops/matmul_op.cpp ================================================ /* Copyright 2020 
The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4Matmul(user_op::InferContext* ctx) { bool transpose_a = ctx->Attr("transpose_a"); bool transpose_b = ctx->Attr("transpose_b"); const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); CHECK_GE_OR_RETURN(a.shape().NumAxes(), 2); size_t num_axes = a.shape().NumAxes(); if (num_axes > 2) { for (int i = 0; i < num_axes - 2; ++i) { CHECK_EQ_OR_RETURN(a.shape().At(i), b.shape().At(i)); } } user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); Shape output = ctx->InputShape("a", 0); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("a", 0)); int64_t m, n, k; // tensor a (no trans): m*k, tensor b (no trans): k*n if (!transpose_a) { m = a.shape().At(num_axes - 2); k = a.shape().At(num_axes - 1); } else { m = a.shape().At(num_axes - 1); k = a.shape().At(num_axes - 2); } if (!transpose_b) { CHECK_EQ_OR_RETURN(k, b.shape().At(num_axes - 2)); n = b.shape().At(num_axes - 1); } else { CHECK_EQ_OR_RETURN(k, b.shape().At(num_axes - 1)); n = b.shape().At(num_axes - 2); } output.Set(num_axes - 2, m); output.Set(num_axes - 1, n); out->set_shape(output); if (ctx->has_input("_add_to_output", 0)) { const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); } return Maybe::Ok(); } Maybe InferDataType4Matmul(user_op::InferContext* ctx) { DataType dtype = ctx->InputDType("a", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), dtype) << "InferDataType Failed. Expected " << DataType_Name(dtype) << ", but got " << DataType_Name(ctx->InputDType("b", 0)); if (ctx->has_input("_add_to_output", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("_add_to_output", 0), dtype) << "InferDataType Failed. Expected " << DataType_Name(dtype) << ", but got " << DataType_Name(ctx->InputDType("_add_to_output", 0)); } ctx->SetOutputDType("out", 0, dtype); return Maybe::Ok(); } // Theoretically computation cost of matrix multiplication is the products of the number of matrix // and first dimension of matrix a, second dimension of matrix a, second dimension of matrix // b. If there is any splitting sbp parallel, the computation cost will be divided by number of // machines. If we use S(1) at matrix a and S(0) at matrix b, then it will be P at output matrix. // This is why we don't use SbpParallel at output matrix. 
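// Illustrative example (editor's sketch with hypothetical shapes, not part of the original
// source): with a = (1024, 512), b = (512, 256) and no transposes, the logical cost computed
// below is 2 * elem_cnt(a) * n = 2 * 1024 * 512 * 256 = 268,435,456. If the 1-D parallel
// hierarchy has 4 devices and a is S(0) while b is B, that hierarchy dimension carries a
// split, so the reported cost becomes 268,435,456 / 4 = 67,108,864 -- matching the division
// by parallel_hierarchy->At(sbp_dim) performed in the loop below.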
Maybe GetComputationCost(user_op::ComputeComplexityFnContext* ctx) { bool transpose_b = ctx->Attr("transpose_b"); const Shape& shape_b = ctx->Shape4ArgNameAndIndex("b", 0); int64_t n = 0; if (!transpose_b) { n = shape_b.At(shape_b.NumAxes() - 1); } else { n = shape_b.At(shape_b.NumAxes() - 2); } double logical_computation_cost = 2 * ctx->Shape4ArgNameAndIndex("a", 0).elem_cnt() * n; const auto& nd_sbp_a = ctx->NdSbp4ArgNameAndIndex("a", 0); const auto& nd_sbp_b = ctx->NdSbp4ArgNameAndIndex("b", 0); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); for (int32_t sbp_dim = 0; sbp_dim < nd_sbp_a.sbp_parallel_size(); sbp_dim++) { if (nd_sbp_a.sbp_parallel(sbp_dim).has_split_parallel() || nd_sbp_b.sbp_parallel(sbp_dim).has_split_parallel()) { logical_computation_cost /= parallel_hierarchy->At(sbp_dim); } } return logical_computation_cost; } } // namespace /* static */ Maybe MatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4Matmul(ctx); } /*static*/ Maybe MatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe MatmulOp::GetComputeComplexity(user_op::ComputeComplexityFnContext* ctx) { return GetComputationCost(ctx); } /* static */ Maybe MatmulOp::GetSbp(user_op::SbpContext* ctx) { // (m, k_a) * (k_b, n) where k_a == k_b int32_t m_axis = -1; int32_t k_a_axis = -1; int32_t k_b_axis = -1; int32_t n_axis = -1; if (ctx->Attr("transpose_a")) { m_axis = 1; k_a_axis = 0; } else { m_axis = 0; k_a_axis = 1; } if (ctx->Attr("transpose_b")) { k_b_axis = 1; n_axis = 0; } else { k_b_axis = 0; n_axis = 1; } std::vector out_and_add_to_output_args; out_and_add_to_output_args.emplace_back("out", 0); if (ctx->user_op_conf().has_input("_add_to_output", 0)) { out_and_add_to_output_args.emplace_back("_add_to_output", 0); } ctx->NewBuilder() .Split(user_op::OpArg("a", 0), m_axis) .Broadcast(user_op::OpArg("b", 0)) .Split(out_and_add_to_output_args, 0) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .Split(user_op::OpArg("b", 0), n_axis) .Split(out_and_add_to_output_args, 1) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("a", 0), k_a_axis) .Split(user_op::OpArg("b", 0), k_b_axis) .PartialSum(out_and_add_to_output_args) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("a", 0)) .Broadcast(user_op::OpArg("b", 0)) .PartialSum(out_and_add_to_output_args) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .PartialSum(user_op::OpArg("b", 0)) .PartialSum(out_and_add_to_output_args) .Build(); return Maybe::Ok(); } /* static */ Maybe MatmulOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Matmul(ctx); } // BatchMatmul /* static */ Maybe BatchMatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4Matmul(ctx); } /*static*/ Maybe BatchMatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BatchMatmulOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& a_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0); std::vector out_and_add_to_output_args; out_and_add_to_output_args.emplace_back("out", 0); if (ctx->user_op_conf().has_input("_add_to_output", 0)) { out_and_add_to_output_args.emplace_back("_add_to_output", 0); } int32_t num_axes = a_tensor.shape().NumAxes(); FOR_RANGE(int64_t, i, 0, num_axes - 2) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(out_and_add_to_output_args, i).Build(); } int32_t m_axis = -1; int32_t k_a_axis = -1; int32_t k_b_axis 
= -1; int32_t n_axis = -1; if (ctx->Attr("transpose_a")) { m_axis = num_axes - 1; k_a_axis = num_axes - 2; } else { m_axis = num_axes - 2; k_a_axis = num_axes - 1; } if (ctx->Attr("transpose_b")) { k_b_axis = num_axes - 1; n_axis = num_axes - 2; } else { k_b_axis = num_axes - 2; n_axis = num_axes - 1; } ctx->NewBuilder() .Split(user_op::OpArg("a", 0), m_axis) .Broadcast(user_op::OpArg("b", 0)) .Split(out_and_add_to_output_args, num_axes - 2) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .Split(user_op::OpArg("b", 0), n_axis) .Split(out_and_add_to_output_args, num_axes - 1) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("a", 0), k_a_axis) .Split(user_op::OpArg("b", 0), k_b_axis) .PartialSum(out_and_add_to_output_args) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("a", 0)) .Broadcast(user_op::OpArg("b", 0)) .PartialSum(out_and_add_to_output_args) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .PartialSum(user_op::OpArg("b", 0)) .PartialSum(out_and_add_to_output_args) .Build(); return Maybe::Ok(); } /*static*/ Maybe BatchMatmulOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { return GetComputationCost(ctx); } /* static */ Maybe BatchMatmulOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Matmul(ctx); } // BroadcastMatmul /* static */ Maybe BroadcastMatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { bool transpose_a = ctx->Attr("transpose_a"); bool transpose_b = ctx->Attr("transpose_b"); const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); const int64_t num_a_dims = a.shape().NumAxes(); const int64_t num_b_dims = b.shape().NumAxes(); const size_t num_max_batch_dims = std::max(num_a_dims, num_b_dims) - 2; auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) { const int64_t num_batch_dims = num_dims - 2; const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims; return [num_padding_dims, shape_dim](size_t index) { return index < num_padding_dims ? 1 : shape_dim.At(index - num_padding_dims); }; }; auto GetABatchDim = MakeGetBatchDim(num_a_dims, a.shape()); auto GetBBatchDim = MakeGetBatchDim(num_b_dims, b.shape()); DimVector out_dim_vec(std::max(num_a_dims, num_b_dims)); FOR_RANGE(int64_t, i, 0, out_dim_vec.size() - 2) { // Set broadcast shape // m k k n // For example: A(16, 1, 4, 8) B(1, 8, 8, 6) // We First set the previous batch dims to broadcasted shape: C(16, 8) // Then we emplace back m, n -> C(16, 8, 4, 6) const int64_t a_batch_dim = GetABatchDim(i); const int64_t b_batch_dim = GetBBatchDim(i); CHECK(((a_batch_dim != 1 && b_batch_dim == 1) || (a_batch_dim == 1 && b_batch_dim != 1) || (a_batch_dim == b_batch_dim))) << "Batch Dims could not broadcast, please check. "; out_dim_vec[i] = std::max(a_batch_dim, b_batch_dim); } int64_t m = 0; int64_t n = 0; int64_t k = 0; // tensor a (no trans): batch_dims*m*k, tensor b (no trans): batch_dims*k*n if (!transpose_a) { m = a.shape().At(num_a_dims - 2); k = a.shape().At(num_a_dims - 1); } else { m = a.shape().At(num_a_dims - 1); k = a.shape().At(num_a_dims - 2); } if (!transpose_b) { CHECK_EQ_OR_RETURN(k, b.shape().At(num_b_dims - 2)) << "K dim should be equal to b.shape().At(num_b_dims - 2). "; n = b.shape().At(num_b_dims - 1); } else { CHECK_EQ_OR_RETURN(k, b.shape().At(num_b_dims - 1)) << "K dim should be equal to b.shape().At(num_b_dims - 1). 
"; n = b.shape().At(num_b_dims - 2); } out_dim_vec.at(num_max_batch_dims) = m; out_dim_vec.at(num_max_batch_dims + 1) = n; out->set_shape(Shape(out_dim_vec)); if (ctx->has_input("_add_to_output", 0)) { const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); } return Maybe::Ok(); } /*static*/ Maybe BroadcastMatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe BroadcastMatmulOp::GetSbp(user_op::SbpContext* ctx) { // (b, m, k) * (k, n) when transpose_b is false // (b, m, k) * (n, k) when transpose_b is true bool transpose_a = ctx->Attr("transpose_a"); bool transpose_b = ctx->Attr("transpose_b"); const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); const auto& b_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("b", 0).shape(); const int64_t a_num_axes = a_shape.NumAxes(); const int64_t b_num_axes = b_shape.NumAxes(); int32_t m_a_axis = -1; int32_t k_a_axis = -1; int32_t k_b_axis = -1; int32_t n_axis = -1; if (transpose_a) { m_a_axis = a_num_axes - 1; k_a_axis = a_num_axes - 2; } else { m_a_axis = a_num_axes - 2; k_a_axis = a_num_axes - 1; } if (transpose_b) { k_b_axis = b_num_axes - 1; n_axis = b_num_axes - 2; } else { k_b_axis = b_num_axes - 2; n_axis = b_num_axes - 1; } std::vector out_and_add_to_output_args; out_and_add_to_output_args.emplace_back("out", 0); if (ctx->user_op_conf().has_input("_add_to_output", 0)) { out_and_add_to_output_args.emplace_back("_add_to_output", 0); } const int64_t a_batch_dims = a_num_axes - 2; const int64_t b_batch_dims = b_num_axes - 2; const int64_t max_num_axes = std::max(a_num_axes, b_num_axes); const size_t num_max_batch_dims = max_num_axes - 2; auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) { const int64_t num_batch_dims = num_dims - 2; const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims; return [num_padding_dims, shape_dim](size_t index) { return index < num_padding_dims ? 
1 : shape_dim.At(index - num_padding_dims); }; }; auto GetABatchDim = MakeGetBatchDim(a_num_axes, a_shape); auto GetBBatchDim = MakeGetBatchDim(b_num_axes, b_shape); for (int i = 0; i < num_max_batch_dims; i++) { const int64_t a_batch_dim = GetABatchDim(i); const int64_t b_batch_dim = GetBBatchDim(i); if (a_batch_dim == b_batch_dim && a_batch_dim != 1) { // S(b axis) x S(b axis) -> S(b axis) ctx->NewBuilder() .Split(user_op::OpArg("a", 0), i - (num_max_batch_dims - a_batch_dims)) .Split(user_op::OpArg("b", 0), i - (num_max_batch_dims - b_batch_dims)) .Split(out_and_add_to_output_args, i) .Build(); } else if (a_batch_dim == 1 && b_batch_dim != 1) { // B x S(b axis) -> S(b axis) ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .Split(user_op::OpArg("b", 0), i - (num_max_batch_dims - b_batch_dims)) .Split(out_and_add_to_output_args, i) .Build(); } else if (b_batch_dim == 1 && a_batch_dim != 1) { // S(b axis) x B -> S(b axis) ctx->NewBuilder() .Split(user_op::OpArg("a", 0), i - (num_max_batch_dims - a_batch_dims)) .Broadcast(user_op::OpArg("b", 0)) .Split(out_and_add_to_output_args, i) .Build(); } } // S(m axis) x B -> S(m axis) ctx->NewBuilder() .Split(user_op::OpArg("a", 0), m_a_axis) .Broadcast(user_op::OpArg("b", 0)) .Split(out_and_add_to_output_args, max_num_axes - 2) .Build(); // B x S(n_axis) -> S(n_axis) ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .Split(user_op::OpArg("b", 0), n_axis) .Split(out_and_add_to_output_args, max_num_axes - 1) .Build(); // S(a_k_axis) x S(b_k_axis) -> P ctx->NewBuilder() .Split(user_op::OpArg("a", 0), k_a_axis) .Split(user_op::OpArg("b", 0), k_b_axis) .PartialSum(out_and_add_to_output_args) .Build(); // P x B -> P ctx->NewBuilder() .PartialSum(user_op::OpArg("a", 0)) .Broadcast(user_op::OpArg("b", 0)) .PartialSum(out_and_add_to_output_args) .Build(); // B x P -> P ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .PartialSum(user_op::OpArg("b", 0)) .PartialSum(out_and_add_to_output_args) .Build(); return Maybe::Ok(); } /* static */ Maybe BroadcastMatmulOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Matmul(ctx); } /*static*/ Maybe BroadcastMatmulOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { return GetComputationCost(ctx); } /* static */ Maybe BroadcastMatmulGradBOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); for (int i = 0; i < a.shape().NumAxes() - 1; ++i) { CHECK_EQ_OR_RETURN(a.shape().At(i), b.shape().At(i)); } out->set_shape( Shape({a.shape().At(a.shape().NumAxes() - 1), b.shape().At(b.shape().NumAxes() - 1)})); if (ctx->has_input("_add_to_output", 0)) { const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); } return Maybe::Ok(); } /*static*/ Maybe BroadcastMatmulGradBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe BroadcastMatmulGradBOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { const Shape& shape_a = ctx->Shape4ArgNameAndIndex("a", 0); int64_t n = shape_a.At(shape_a.NumAxes() - 2); double logical_computation_cost = 2 * ctx->Shape4ArgNameAndIndex("b", 0).elem_cnt() * n; const auto& nd_sbp_a = ctx->NdSbp4ArgNameAndIndex("a", 0); const auto& nd_sbp_b = 
ctx->NdSbp4ArgNameAndIndex("b", 0); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); for (int32_t sbp_dim = 0; sbp_dim < nd_sbp_a.sbp_parallel_size(); sbp_dim++) { if (nd_sbp_a.sbp_parallel(sbp_dim).has_split_parallel() || nd_sbp_b.sbp_parallel(sbp_dim).has_split_parallel()) { logical_computation_cost /= parallel_hierarchy->At(sbp_dim); } } return logical_computation_cost; } /* static */ Maybe BroadcastMatmulGradBOp::GetSbp(user_op::SbpContext* ctx) { const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); int64_t last_axis = a_shape.NumAxes() - 1; std::vector out_and_add_to_output_args; out_and_add_to_output_args.emplace_back("out", 0); if (ctx->user_op_conf().has_input("_add_to_output", 0)) { out_and_add_to_output_args.emplace_back("_add_to_output", 0); } // S(b or m axis) x S(b or m axis) -> P for (int64_t i = 0; i < last_axis; ++i) { ctx->NewBuilder() .Split(user_op::OpArg("a", 0), i) .Split(user_op::OpArg("b", 0), i) .PartialSum(out_and_add_to_output_args) .Build(); } // (b, m, k) * (b, m, n) -> (k, n) [transpose a] // S(k) x B -> S(0) or B x S(n) -> S(1) // (b, m, n) * (b, m, k) -> (n, k) [transpose a] // S(n) x B -> S(0) or B x S(k) -> S(1) ctx->NewBuilder() .Split(user_op::OpArg("a", 0), last_axis) .Broadcast(user_op::OpArg("b", 0)) .Split(out_and_add_to_output_args, 0) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .Split(user_op::OpArg("b", 0), last_axis) .Split(out_and_add_to_output_args, 1) .Build(); return Maybe::Ok(); } /* static */ Maybe BroadcastMatmulGradBOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Matmul(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/matrix_vector_product_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4MatrixVectorProduct(user_op::InferContext* ctx) { const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); int64_t m = a.shape().At(0); int64_t k = a.shape().At(1); CHECK_EQ_OR_RETURN(k, b.shape().At(0)) << "Dim K should be equal to vector b's dim0. "; ctx->SetOutputShape("out", 0, Shape({m})); return Maybe::Ok(); } Maybe InferDataType4MatrixVectorProduct(user_op::InferContext* ctx) { DataType dtype = ctx->InputDType("a", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), dtype) << "InferDataType Failed. 
Expected " << DataType_Name(dtype) << ", but got " << DataType_Name(ctx->InputDType("b", 0)); ctx->SetOutputDType("out", 0, dtype); return Maybe::Ok(); } Maybe InferTensorDesc4MatrixVectorProductGradA(user_op::InferContext* ctx) { /* A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m) GradA = dy (m) matmul B(k) -> (m, 1) (k, 1)_transpose */ const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); int64_t m = dy.shape().At(0); int64_t n = b.shape().At(0); ctx->SetOutputShape("dx", 0, Shape({m, n})); return Maybe::Ok(); } Maybe InferTensorDesc4MatrixVectorProductGradB(user_op::InferContext* ctx) { /* A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m) GradB = dy_transpose (1, m) matmul A(m, k) */ const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); int64_t n = a.shape().At(1); ctx->SetOutputShape("dx", 0, Shape({n})); return Maybe::Ok(); } Maybe InferDataType4Grad(user_op::InferContext* ctx) { DataType dtype = ctx->InputDType("dy", 0); ctx->SetOutputDType("dx", 0, dtype); return Maybe::Ok(); } } // namespace /* static */ Maybe MatrixVectorProductOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4MatrixVectorProduct(ctx); } /*static*/ Maybe MatrixVectorProductOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MatrixVectorProductOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("a", 0), 0) .Broadcast(user_op::OpArg("b", 0)) .Split(user_op::OpArg("out", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("a", 0), 1) .Split(user_op::OpArg("b", 0), 0) .PartialSum(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("a", 0)) .Broadcast(user_op::OpArg("b", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .PartialSum(user_op::OpArg("b", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe MatrixVectorProductOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4MatrixVectorProduct(ctx); } /* static */ Maybe MatrixVectorProductGradAOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4MatrixVectorProductGradA(ctx); } /*static*/ Maybe MatrixVectorProductGradAOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MatrixVectorProductGradAOp::GetSbp(user_op::SbpContext* ctx) { /* A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m) GradA = dy (m) matmul B(k) -> (m, 1) (k, 1)_transpose */ ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Broadcast(user_op::OpArg("b", 0)) .Split(user_op::OpArg("dx", 0), 0) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("dy", 0)) .Broadcast(user_op::OpArg("b", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("b", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe MatrixVectorProductGradAOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Grad(ctx); } /* static */ Maybe MatrixVectorProductGradBOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4MatrixVectorProductGradB(ctx); } /*static*/ Maybe MatrixVectorProductGradBOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe 
MatrixVectorProductGradBOp::GetSbp(user_op::SbpContext* ctx) { /* A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m) dy = (m, ) GradB = dy_transpose (1, m) matmul A(m, k) */ ctx->NewBuilder() .Broadcast(user_op::OpArg("dy", 0)) .Split(user_op::OpArg("a", 0), 1) .Split(user_op::OpArg("dx", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("a", 0), 0) .PartialSum(user_op::OpArg("dx", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("dy", 0)) .Broadcast(user_op::OpArg("a", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("a", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe MatrixVectorProductGradBOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Grad(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/max_pool_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/max_pool_kernel_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { typedef std::function(user_op::InferContext* ctx)> TensorDescInferFn; TensorDescInferFn MaxPoolMakeForwardTensorDescInferFn(const int32_t dim) { return [dim](user_op::InferContext* ctx) -> Maybe { const Shape& x_shape = ctx->InputShape("x", 0); const std::string& data_format = ctx->Attr("data_format"); const std::vector& padding = ctx->Attr>("padding"); const std::vector& kernel_size = ctx->Attr>("kernel_size"); const std::vector& stride = ctx->Attr>("stride"); const std::vector& dilation = ctx->Attr>("dilation"); const bool return_indices = ctx->Attr("return_indices"); const bool ceil_mode = ctx->Attr("ceil_mode"); CHECK_EQ_OR_RETURN(kernel_size.size(), dim); for (int32_t pool_dim : kernel_size) { CHECK_GT_OR_RETURN(pool_dim, 0); } CHECK_EQ_OR_RETURN(stride.size(), dim); for (int32_t stride_dim : stride) { CHECK_GT_OR_RETURN(stride_dim, 0); } for (int32_t i = 0; i < padding.size(); i++) { CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i]) << "pad should be smaller than half of kernel size"; } const MaxPoolParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride, dilation, return_indices, ceil_mode); user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc = ctx->InputTensorDesc("x", 0); y_desc->set_shape(params_3d.GetYShape()); user_op::TensorDesc* indice_desc = ctx->MutOutputTensorDesc("indice", 0); *indice_desc = *ctx->MutOutputTensorDesc("y", 0); indice_desc->set_shape(y_desc->shape()); indice_desc->set_data_type(kInt64); return Maybe::Ok(); }; } Maybe MaxPoolForwardGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes() - 2)) { ctx->NewBuilder() 
.Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("indice", 0), i) .Build(); } return Maybe::Ok(); } Maybe MaxPoolBackwardGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("indice", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } // Logically computation cost of pool op is the product of output data amount and pool kernal data // amount. After adding sbp, we just divide it by parallel number if output data is splitted because // splitting input and using partial sum for output is not a valid sbp for this op for now. Maybe GetComputationCost(user_op::ComputeComplexityFnContext* ctx, const std::string& blob_name) { const std::vector& pool_size = ctx->Attr>("kernel_size"); double logical_computation_cost = std::accumulate( pool_size.begin(), pool_size.end(), ctx->Shape4ArgNameAndIndex(blob_name, 0).elem_cnt(), std::multiplies()); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); const auto& nd_sbp_y = ctx->NdSbp4ArgNameAndIndex(blob_name, 0); for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_y.sbp_parallel_size(); dim_sbp++) { if (nd_sbp_y.sbp_parallel(dim_sbp).has_split_parallel()) { logical_computation_cost /= parallel_hierarchy->At(dim_sbp); } } return logical_computation_cost; } Maybe BackwardTensorDescInferFn(user_op::InferContext* ctx) { *ctx->MutOutputTensorDesc("dx", 0) = ctx->InputTensorDesc("x", 0); return Maybe::Ok(); } Maybe FwInferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe BwInferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace #define IMPLEMENT_MAXPOOL_FUNCS(name, dim) \ /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ return MaxPoolForwardGetSbpFn(ctx); \ } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return MaxPoolMakeForwardTensorDescInferFn(dim)(ctx); \ } \ /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ return FwInferDataType(ctx); \ } \ /*static*/ Maybe name##Op::GetComputeComplexity( \ user_op::ComputeComplexityFnContext* ctx) { \ return GetComputationCost(ctx, "y"); \ } IMPLEMENT_MAXPOOL_FUNCS(MaxPool1D, 1) IMPLEMENT_MAXPOOL_FUNCS(MaxPool2D, 2) IMPLEMENT_MAXPOOL_FUNCS(MaxPool3D, 3) #undef IMPLEMENT_MAXPOOL_FUNCS #define IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(name) \ /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return MaxPoolBackwardGetSbpFn(ctx); \ } \ /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return BackwardTensorDescInferFn(ctx); \ } \ /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ return BwInferDataType(ctx); \ } \ /*static*/ Maybe name##GradOp::GetComputeComplexity( \ user_op::ComputeComplexityFnContext* ctx) { \ return GetComputationCost(ctx, "dy"); \ } IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool1D) IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool2D) 
IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool3D) #undef IMPLEMENT_MAXPOOL_BACKWARD_FUNCS } // namespace oneflow ================================================ FILE: oneflow/user/ops/max_unpool_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/shape.h" #include "oneflow/user/kernels/max_unpool_kernel_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { typedef std::function(user_op::InferContext* ctx)> TensorDescInferFn; TensorDescInferFn MaxUnpoolMakeForwardTensorDescInferFn(const int32_t dim) { return [dim](user_op::InferContext* ctx) -> Maybe { const Shape& x_shape = ctx->InputShape("x", 0); const std::vector& padding = ctx->Attr>("padding"); const std::vector& kernel_size = ctx->Attr>("kernel_size"); const std::vector& stride = ctx->Attr>("stride"); Shape output_shape = Shape(); if (ctx->Attr("has_output_size")) { output_shape = ctx->Attr("output_size"); } else { const MaxUnpoolParams3D params_3d(dim, x_shape, padding, kernel_size, stride); output_shape = params_3d.GetYShape(); } user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc = ctx->InputTensorDesc("x", 0); y_desc->set_shape(output_shape); return Maybe::Ok(); }; } Maybe MaxUnpoolForwardGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, std::min(2L, tensor.shape().NumAxes() - 2)) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("indices", 0), i) .Split(user_op::OpArg("y", 0), i) .Build(); } return Maybe::Ok(); } Maybe MaxUnpoolBackwardGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, std::min(2L, tensor.shape().NumAxes())) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("indices", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } Maybe BackwardTensorDescInferFn(user_op::InferContext* ctx) { *ctx->MutOutputTensorDesc("dx", 0) = ctx->InputTensorDesc("x", 0); return Maybe::Ok(); } Maybe FwInferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe BwInferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace #define IMPLEMENT_MAXUNPOOL_FUNCS(name, dim) \ /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ return MaxUnpoolForwardGetSbpFn(ctx); \ } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return MaxUnpoolMakeForwardTensorDescInferFn(dim)(ctx); \ } \ /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe 
name##Op::InferDataType(user_op::InferContext* ctx) { \ return FwInferDataType(ctx); \ } IMPLEMENT_MAXUNPOOL_FUNCS(MaxUnpool1D, 1) IMPLEMENT_MAXUNPOOL_FUNCS(MaxUnpool2D, 2) IMPLEMENT_MAXUNPOOL_FUNCS(MaxUnpool3D, 3) #undef IMPLEMENT_MAXUNPOOL_FUNCS #define IMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS(name) \ /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return MaxUnpoolBackwardGetSbpFn(ctx); \ } \ /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return BackwardTensorDescInferFn(ctx); \ } \ /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ return BwInferDataType(ctx); \ } IMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS(MaxUnpool1D) IMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS(MaxUnpool2D) IMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS(MaxUnpool3D) #undef IMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS } // namespace oneflow ================================================ FILE: oneflow/user/ops/median_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe MedianOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); int64_t num_axes = in_tensor.shape().NumAxes(); if (num_axes == 0) { ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); } return Maybe::Ok(); } /*static*/ Maybe MedianOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& ones_shape = {1}; ctx->SetOutputShape("output", 0, ones_shape.RemoveOnes({0})); return Maybe::Ok(); } /*static*/ Maybe MedianOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe MedianOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/median_with_indices_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe MedianWithIndicesOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); int64_t num_axes = in_tensor.shape().NumAxes(); FOR_RANGE(int64_t, i, 0, num_axes - 1) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } if (num_axes == 0) { ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); } return Maybe::Ok(); } /*static*/ Maybe MedianWithIndicesOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); const Shape& reduce_shape = CreateReducedShape(input_shape, {-1}); ctx->SetOutputShape("values", 0, reduce_shape.RemoveOnes({-1})); ctx->SetOutputShape("indices", 0, reduce_shape.RemoveOnes({-1})); return Maybe::Ok(); } /*static*/ Maybe MedianWithIndicesOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe MedianWithIndicesOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("values", 0, ctx->InputDType("input", 0)); ctx->SetOutputDType("indices", 0, DataType::kInt64); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/min_max_observer_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe MinMaxObserverOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); if (ctx->Attr("quantization_formula") == "google") { if (ctx->Attr("per_layer_quantization") == true) { ctx->SetOutputShape("scale", 0, Shape({1})); ctx->SetOutputShape("zero_point", 0, Shape({1})); } else { // NOTE(Liang Depeng): For now per-channel quantization only support axis 0 ctx->SetOutputShape("scale", 0, Shape({in_shape.At(0)})); ctx->SetOutputShape("zero_point", 0, Shape({in_shape.At(0)})); } } else { // quantization_formula == "cambricon" ctx->SetOutputShape("scale", 0, Shape({1})); ctx->SetOutputShape("zero_point", 0, Shape({1})); } return Maybe::Ok(); } /*static*/ Maybe MinMaxObserverOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MinMaxObserverOp::GetSbp(user_op::SbpContext* ctx) { // NOTE(Liang Depeng): input needs to be broadcast in order to accurately calculate the // global scale and zero_point return Maybe::Ok(); } /* static */ Maybe MinMaxObserverOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); CHECK_OR_RETURN(in != nullptr); in->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe MinMaxObserverOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& op_conf) { int32_t quantization_bit = op_conf.attr("quantization_bit"); CHECK_GT_OR_RETURN(quantization_bit, 1); CHECK_LE_OR_RETURN(quantization_bit, 8); std::string quantization_scheme = op_conf.attr("quantization_scheme"); CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); std::string quantization_formula = op_conf.attr("quantization_formula"); CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); return Maybe::Ok(); } /* static */ Maybe MinMaxObserverOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("scale", 0, ctx->InputDType("in", 0)); ctx->SetOutputDType("zero_point", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/mish_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe MishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe MishOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MishOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe MishOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe MishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe MishGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MishGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe MishGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/mode_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe ModeOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); int64_t num_axes = in_tensor.shape().NumAxes(); FOR_RANGE(int64_t, i, 0, num_axes - 1) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } if (num_axes == 0) { ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); } return Maybe::Ok(); } /*static*/ Maybe ModeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); const Shape& reduce_shape = CreateReducedShape(input_shape, {-1}); ctx->SetOutputShape("values", 0, reduce_shape.RemoveOnes({-1})); ctx->SetOutputShape("indices", 0, reduce_shape.RemoveOnes({-1})); return Maybe::Ok(); } /*static*/ Maybe ModeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ModeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("values", 0, ctx->InputDType("input", 0)); ctx->SetOutputDType("indices", 0, DataType::kInt64); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/model_update_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe CheckShapeLike(const user_op::TensorDesc* tensor_desc, const user_op::TensorDesc* like) { CHECK_EQ_OR_RETURN(tensor_desc->shape(), like->shape()); return Maybe::Ok(); } Maybe CheckDataTypeLike(const user_op::TensorDesc* tensor_desc, const user_op::TensorDesc* like) { CHECK_EQ_OR_RETURN(tensor_desc->data_type(), like->data_type()) << "InferDataType Failed. Expected " << DataType_Name(tensor_desc->data_type()) << ", but got " << DataType_Name(like->data_type()); return Maybe::Ok(); } Maybe CheckScalarShape(const user_op::TensorDesc* tensor_desc) { CHECK_OR_RETURN(tensor_desc->shape().NumAxes() == 0 || (tensor_desc->shape().NumAxes() == 1 && tensor_desc->shape().At(0) == 1)) << tensor_desc->shape().DebugStr(); return Maybe::Ok(); } Maybe CheckScalarDataType(const user_op::TensorDesc* tensor_desc, const DataType data_type) { CHECK_EQ_OR_RETURN(tensor_desc->data_type(), data_type) << "InferDataType Failed. 
Expected " << DataType_Name(tensor_desc->data_type()) << ", but got " << DataType_Name(data_type); return Maybe::Ok(); } Maybe CheckLearningRateShape(user_op::InferContext* ctx) { if (ctx->has_input("learning_rate", 0)) { const user_op::TensorDesc& learning_rate = ctx->InputTensorDesc("learning_rate", 0); JUST(CheckScalarShape(&learning_rate)); } return Maybe::Ok(); } Maybe CheckLearningRateDataType(user_op::InferContext* ctx) { if (ctx->has_input("learning_rate", 0)) { const user_op::TensorDesc& learning_rate = ctx->InputTensorDesc("learning_rate", 0); JUST(CheckScalarDataType(&learning_rate, DataType::kFloat)); } return Maybe::Ok(); } Maybe CheckIndexedSlicesModelDiffDesc(const user_op::TensorDesc* model, const user_op::TensorDesc* model_diff_indices, const user_op::TensorDesc* model_diff_values) { const int64_t num_indices_axes = model_diff_indices->shape().NumAxes(); const int64_t num_values_axes = model_diff_values->shape().NumAxes(); CHECK_GE_OR_RETURN(num_values_axes, num_indices_axes); FOR_RANGE(int64_t, i, 0, num_indices_axes) { CHECK_EQ_OR_RETURN(model_diff_values->shape().At(i), model_diff_indices->shape().At(i)); } const int64_t num_model_axes = model->shape().NumAxes(); CHECK_EQ_OR_RETURN(num_model_axes, num_values_axes - num_indices_axes + 1); FOR_RANGE(int64_t, i, 1, num_model_axes) { CHECK_EQ_OR_RETURN(model->shape().At(i), model_diff_values->shape().At(num_indices_axes + i - 1)); } return Maybe::Ok(); } Maybe CheckIndexedSlicesModelDiffDataType(const user_op::TensorDesc* model, const user_op::TensorDesc* model_diff_indices, const user_op::TensorDesc* model_diff_values) { CHECK_OR_RETURN(IsIndexDataType(model_diff_indices->data_type())); CHECK_EQ_OR_RETURN(model->data_type(), model_diff_values->data_type()) << "InferDataType Failed. Expected " << DataType_Name(model->data_type()) << ", but got " << DataType_Name(model_diff_values->data_type()); return Maybe::Ok(); } Maybe InferSGDUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const Shape& shape = model.shape(); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); if (shape.NumAxes() > 0 && model_diff.shape().NumAxes() > 0) { CHECK_EQ_OR_RETURN(model_diff.shape(), shape); } JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("model_copy", 0)) { CHECK_EQ_OR_RETURN(ctx->InputTensorDesc("model_copy", 0).shape(), shape) << "Model copy shape should be equal to Model shape. 
"; } if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe InferSGDUpdateDataType(user_op::InferContext* ctx) { JUST(CheckLearningRateDataType(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); JUST(CheckScalarDataType(&scale_by_tensor, model.data_type())); } return Maybe::Ok(); } Maybe InferIndexedSlicesSGDUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc("model_diff_indices", 0); const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc("model_diff_values", 0); JUST(CheckIndexedSlicesModelDiffDesc(&model, &model_diff_indices, &model_diff_values)); JUST(CheckLearningRateShape(ctx)); return Maybe::Ok(); } Maybe InferIndexedSlicesSGDUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc("model_diff_indices", 0); const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc("model_diff_values", 0); JUST(CheckIndexedSlicesModelDiffDataType(&model, &model_diff_indices, &model_diff_values)); JUST(CheckLearningRateDataType(ctx)); return Maybe::Ok(); } Maybe InferMomentumUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()); const user_op::TensorDesc& momentum = ctx->InputTensorDesc("momentum", 0); JUST(CheckShapeLike(&momentum, &model)); JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe InferMomentumUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& momentum = ctx->InputTensorDesc("momentum", 0); JUST(CheckDataTypeLike(&momentum, &model)); JUST(CheckLearningRateDataType(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarDataType(&scale_by_tensor, model.data_type())); } return Maybe::Ok(); } Maybe InferIndexedSlicesMomentumUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc("model_diff_indices", 0); const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc("model_diff_values", 0); JUST(CheckIndexedSlicesModelDiffDesc(&model, &model_diff_indices, &model_diff_values)); const user_op::TensorDesc& momentum = ctx->InputTensorDesc("momentum", 0); JUST(CheckShapeLike(&momentum, &model)); JUST(CheckLearningRateShape(ctx)); return Maybe::Ok(); } Maybe InferIndexedSlicesMomentumUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc("model_diff_indices", 0); const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc("model_diff_values", 0); 
JUST(CheckIndexedSlicesModelDiffDataType(&model, &model_diff_indices, &model_diff_values)); const user_op::TensorDesc& momentum = ctx->InputTensorDesc("momentum", 0); JUST(CheckDataTypeLike(&momentum, &model)); JUST(CheckLearningRateDataType(ctx)); return Maybe::Ok(); } Maybe InferAdamUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const Shape& shape = model.shape(); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), shape); const user_op::TensorDesc& m = ctx->InputTensorDesc("m", 0); JUST(CheckShapeLike(&m, &model)); const user_op::TensorDesc& v = ctx->InputTensorDesc("v", 0); JUST(CheckShapeLike(&v, &model)); JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("model_copy", 0)) { CHECK_EQ_OR_RETURN(ctx->InputTensorDesc("model_copy", 0).shape(), shape) << "Model copy shape should be equal to Model shape. "; } if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe InferAdamUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& m = ctx->InputTensorDesc("m", 0); JUST(CheckDataTypeLike(&m, &model)); const user_op::TensorDesc& v = ctx->InputTensorDesc("v", 0); JUST(CheckDataTypeLike(&v, &model)); JUST(CheckLearningRateDataType(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarDataType(&scale_by_tensor, model.data_type())); } return Maybe::Ok(); } Maybe InferAdagradUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const Shape& shape = model.shape(); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), shape); const user_op::TensorDesc& sum = ctx->InputTensorDesc("sum", 0); JUST(CheckShapeLike(&sum, &model)); JUST(CheckLearningRateShape(ctx)); return Maybe::Ok(); } Maybe InferAdagradUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& sum = ctx->InputTensorDesc("sum", 0); JUST(CheckDataTypeLike(&sum, &model)); JUST(CheckLearningRateDataType(ctx)); return Maybe::Ok(); } Maybe InferIndexedSlicesAdamUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc("model_diff_indices", 0); const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc("model_diff_values", 0); JUST(CheckIndexedSlicesModelDiffDesc(&model, &model_diff_indices, &model_diff_values)); JUST(CheckLearningRateShape(ctx)); return Maybe::Ok(); } Maybe InferIndexedSlicesAdamUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc("model_diff_indices", 0); const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc("model_diff_values", 0); JUST(CheckIndexedSlicesModelDiffDataType(&model, &model_diff_indices, &model_diff_values)); JUST(CheckLearningRateDataType(ctx)); return Maybe::Ok(); } Maybe InferLambUpdateTensorDesc(user_op::InferContext* ctx) { const float beta1 = ctx->Attr("beta1"); const float beta2 = 
ctx->Attr("beta2"); CHECK_GE_OR_RETURN(beta1, 0); CHECK_LT_OR_RETURN(beta1, 1); CHECK_GE_OR_RETURN(beta2, 0); CHECK_LT_OR_RETURN(beta2, 1); const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const Shape& shape = model.shape(); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), shape); const user_op::TensorDesc& m = ctx->InputTensorDesc("m", 0); JUST(CheckShapeLike(&m, &model)); const user_op::TensorDesc& v = ctx->InputTensorDesc("v", 0); JUST(CheckShapeLike(&v, &model)); JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe InferLambUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& m = ctx->InputTensorDesc("m", 0); JUST(CheckDataTypeLike(&m, &model)); const user_op::TensorDesc& v = ctx->InputTensorDesc("v", 0); JUST(CheckDataTypeLike(&v, &model)); JUST(CheckLearningRateDataType(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarDataType(&scale_by_tensor, model.data_type())); } return Maybe::Ok(); } Maybe InferFtrlUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const Shape& shape = model.shape(); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), shape) << "Model Diff shape is not consistent with Weight shape. "; const user_op::TensorDesc& accumulate = ctx->InputTensorDesc("accumulate", 0); const user_op::TensorDesc& z = ctx->InputTensorDesc("z", 0); JUST(CheckShapeLike(&accumulate, &model)); JUST(CheckShapeLike(&z, &model)); JUST(CheckLearningRateShape(ctx)); return Maybe::Ok(); } Maybe InferFtrlUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& accumulate = ctx->InputTensorDesc("accumulate", 0); const user_op::TensorDesc& z = ctx->InputTensorDesc("z", 0); JUST(CheckDataTypeLike(&accumulate, &model)); JUST(CheckDataTypeLike(&z, &model)); JUST(CheckLearningRateDataType(ctx)); return Maybe::Ok(); } Maybe InferAdadeltaUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); const user_op::TensorDesc& square_avgs = ctx->InputTensorDesc("square_avgs", 0); const user_op::TensorDesc& acc_deltas = ctx->InputTensorDesc("acc_deltas", 0); JUST(CheckShapeLike(&model_diff, &model)); JUST(CheckShapeLike(&square_avgs, &model)); JUST(CheckShapeLike(&acc_deltas, &model)); JUST(CheckLearningRateShape(ctx)); return Maybe::Ok(); } Maybe InferAdadeltaUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& square_avgs = ctx->InputTensorDesc("square_avgs", 0); const user_op::TensorDesc& acc_deltas = ctx->InputTensorDesc("acc_deltas", 0); JUST(CheckDataTypeLike(&square_avgs, &model)); JUST(CheckDataTypeLike(&acc_deltas, &model)); JUST(CheckLearningRateDataType(ctx)); return Maybe::Ok(); } Maybe SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn, const std::string& arg_name, int32_t arg_index) { 
user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index); CHECK_NOTNULL_OR_RETURN(arg_modifier); arg_modifier->set_is_mutable(true); return Maybe::Ok(); } Maybe AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0)); if (conf.has_input("max_v", 0)) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "max_v", 0)); } if (conf.has_input("model_copy", 0)) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model_copy", 0)); } return Maybe::Ok(); } Maybe AdagradInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "sum", 0)); return Maybe::Ok(); } Maybe LambInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0)); return Maybe::Ok(); } Maybe SgdInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); if (conf.has_input("model_copy", 0)) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model_copy", 0)); } return Maybe::Ok(); } Maybe IndexedSlicesSgdInputArgModifyFn( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); return Maybe::Ok(); } Maybe MomentumInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); return Maybe::Ok(); } Maybe IndexedSlicesMomentumInputArgModifyFn( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); return Maybe::Ok(); } Maybe RmsPropUpdateInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "mean_square", 0)); if (conf.attr("centered")) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "mean_gradient", 0)); } return Maybe::Ok(); } Maybe LarsUpdateInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); return Maybe::Ok(); } Maybe FtrlInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "accumulate", 0)); 
JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "z", 0)); return Maybe::Ok(); } Maybe AdadeltaInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "square_avgs", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "acc_deltas", 0)); return Maybe::Ok(); } Maybe InferRmsPropUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const Shape& shape = model.shape(); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), shape); const user_op::TensorDesc& mean_square = ctx->InputTensorDesc("mean_square", 0); JUST(CheckShapeLike(&mean_square, &model)); JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } if (ctx->Attr("centered")) { CHECK_OR_RETURN(ctx->has_input("mean_gradient", 0)); const user_op::TensorDesc& mean_gradient = ctx->InputTensorDesc("mean_gradient", 0); JUST(CheckShapeLike(&mean_gradient, &model)); } else { CHECK_OR_RETURN(!ctx->has_input("mean_gradient", 0)); } return Maybe::Ok(); } Maybe InferRmsPropUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& mean_square = ctx->InputTensorDesc("mean_square", 0); JUST(CheckDataTypeLike(&mean_square, &model)); JUST(CheckLearningRateDataType(ctx)); const DataType data_type = model.data_type(); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarDataType(&scale_by_tensor, data_type)); } if (ctx->Attr("centered")) { CHECK_OR_RETURN(ctx->has_input("mean_gradient", 0)); const user_op::TensorDesc& mean_gradient = ctx->InputTensorDesc("mean_gradient", 0); JUST(CheckDataTypeLike(&mean_gradient, &model)); } return Maybe::Ok(); } Maybe InferLarsUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const Shape& shape = model.shape(); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), shape); const user_op::TensorDesc& momentum = ctx->InputTensorDesc("momentum", 0); JUST(CheckShapeLike(&momentum, &model)); JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe InferLarsUpdateDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& momentum = ctx->InputTensorDesc("momentum", 0); JUST(CheckDataTypeLike(&momentum, &model)); JUST(CheckLearningRateDataType(ctx)); const DataType data_type = model.data_type(); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarDataType(&scale_by_tensor, data_type)); } return Maybe::Ok(); } } // namespace /* static */ Maybe SgdUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferSGDUpdateTensorDesc(ctx); } /*static*/ Maybe SgdUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); 
} /* static */ Maybe SgdUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { auto builder = ctx->NewBuilder() .Broadcast(ctx->inputs()) .Split(user_op::OpArg("model", 0), axis) .Split(user_op::OpArg("model_diff", 0), axis); if (ctx->user_op_conf().has_input("model_copy", 0)) { builder.Split(user_op::OpArg("model_copy", 0), axis); } builder.Build(); } return Maybe::Ok(); } /* static */ Maybe SgdUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return SgdInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe SgdUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferSGDUpdateDataType(ctx); } /* static */ Maybe IndexedSlicesSgdUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferIndexedSlicesSGDUpdateTensorDesc(ctx); } /*static*/ Maybe IndexedSlicesSgdUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IndexedSlicesSgdUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); ctx->NewBuilder() .Broadcast(user_op::OpArg("learning_rate", 0)) .Broadcast(user_op::OpArg("model_diff_indices", 0)) .Broadcast(user_op::OpArg("model_diff_values", 0)) .Split(user_op::OpArg("model", 0), 0) .Build(); FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { ctx->NewBuilder() .Broadcast(user_op::OpArg("learning_rate", 0)) .Broadcast(user_op::OpArg("model_diff_indices", 0)) .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) .Split(user_op::OpArg("model", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe IndexedSlicesSgdUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return IndexedSlicesSgdInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe IndexedSlicesSgdUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferIndexedSlicesSGDUpdateDataType(ctx); } /* static */ Maybe MomentumUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferMomentumUpdateTensorDesc(ctx); } /*static*/ Maybe MomentumUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { ctx->NewBuilder() .Broadcast(ctx->inputs()) .Split(user_op::OpArg("model", 0), axis) .Split(user_op::OpArg("model_diff", 0), axis) .Split(user_op::OpArg("momentum", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe MomentumUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return MomentumInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe MomentumUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferMomentumUpdateDataType(ctx); } /* static */ Maybe IndexedSlicesMomentumUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferIndexedSlicesMomentumUpdateTensorDesc(ctx); } /*static*/ Maybe 
IndexedSlicesMomentumUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IndexedSlicesMomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); ctx->NewBuilder() .Broadcast(user_op::OpArg("learning_rate", 0)) .Broadcast(user_op::OpArg("model_diff_indices", 0)) .Broadcast(user_op::OpArg("model_diff_values", 0)) .Split(user_op::OpArg("model", 0), 0) .Split(user_op::OpArg("momentum", 0), 0) .Build(); FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { ctx->NewBuilder() .Broadcast(user_op::OpArg("learning_rate", 0)) .Broadcast(user_op::OpArg("model_diff_indices", 0)) .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) .Split(user_op::OpArg("model", 0), i) .Split(user_op::OpArg("momentum", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe IndexedSlicesMomentumUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return IndexedSlicesMomentumInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe IndexedSlicesMomentumUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferIndexedSlicesMomentumUpdateDataType(ctx); } /* static */ Maybe AdamUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferAdamUpdateTensorDesc(ctx); } /*static*/ Maybe AdamUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe AdamUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { std::vector split_args; split_args.emplace_back("model", 0); split_args.emplace_back("model_diff", 0); split_args.emplace_back("m", 0); split_args.emplace_back("v", 0); if (ctx->user_op_conf().has_input("max_v", 0)) { split_args.emplace_back("max_v", 0); } auto builder = ctx->NewBuilder().Broadcast(ctx->inputs()).Split(split_args, axis); if (ctx->user_op_conf().has_input("model_copy", 0)) { builder.Split(user_op::OpArg("model_copy", 0), axis); } builder.Build(); } return Maybe::Ok(); } /* static */ Maybe AdamUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return AdamInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe AdamUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferAdamUpdateDataType(ctx); } /* static */ Maybe AdagradUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferAdagradUpdateTensorDesc(ctx); } /*static*/ Maybe AdagradUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe AdagradUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { ctx->NewBuilder() .Broadcast(ctx->inputs()) .Split(user_op::OpArg("model", 0), axis) .Split(user_op::OpArg("model_diff", 0), axis) .Split(user_op::OpArg("sum", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe AdagradUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return 
AdagradInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe AdagradUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferAdagradUpdateDataType(ctx); } /* static */ Maybe IndexedSlicesAdamUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferIndexedSlicesAdamUpdateTensorDesc(ctx); } /*static*/ Maybe IndexedSlicesAdamUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IndexedSlicesAdamUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); const user_op::TensorDesc& model_diff_indices = ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); std::vector broadcast_args; broadcast_args.emplace_back("learning_rate", 0); broadcast_args.emplace_back("model_diff_indices", 0); std::vector split_args; split_args.emplace_back("model", 0); split_args.emplace_back("m", 0); split_args.emplace_back("v", 0); if (ctx->user_op_conf().has_input("max_v", 0)) { split_args.emplace_back("max_v", 0); } ctx->NewBuilder() .Broadcast(broadcast_args) .Broadcast(user_op::OpArg("model_diff_values", 0)) .Split(split_args, 0) .Build(); FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { ctx->NewBuilder() .Broadcast(broadcast_args) .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) .Split(split_args, i) .Build(); } return Maybe::Ok(); } /* static */ Maybe IndexedSlicesAdamUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return AdamInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe IndexedSlicesAdamUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferIndexedSlicesAdamUpdateDataType(ctx); } /* static */ Maybe LambUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferLambUpdateTensorDesc(ctx); } /*static*/ Maybe LambUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LambUpdateOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe LambUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return LambInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe LambUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferLambUpdateDataType(ctx); } /* static */ Maybe AdamBiasCorrectionFactorOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("train_step", 0)); return Maybe::Ok(); } /*static*/ Maybe AdamBiasCorrectionFactorOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe AdamBiasCorrectionFactorOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe AdamBiasCorrectionFactorOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kFloat); return Maybe::Ok(); } /* static */ Maybe RmspropUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferRmsPropUpdateTensorDesc(ctx); } /*static*/ Maybe RmspropUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe RmspropUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = 
ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); bool centered = ctx->Attr("centered"); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { if (centered) { ctx->NewBuilder() .Broadcast(ctx->inputs()) .Split(user_op::OpArg("model", 0), axis) .Split(user_op::OpArg("model_diff", 0), axis) .Split(user_op::OpArg("mean_square", 0), axis) .Split(user_op::OpArg("mean_gradient", 0), axis) .Build(); } else { ctx->NewBuilder() .Broadcast(ctx->inputs()) .Split(user_op::OpArg("model", 0), axis) .Split(user_op::OpArg("model_diff", 0), axis) .Split(user_op::OpArg("mean_square", 0), axis) .Build(); } } return Maybe::Ok(); } /* static */ Maybe RmspropUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return RmsPropUpdateInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe RmspropUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferRmsPropUpdateDataType(ctx); } /* static */ Maybe LarsUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferLarsUpdateTensorDesc(ctx); } /*static*/ Maybe LarsUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe LarsUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { ctx->NewBuilder() .Broadcast(ctx->inputs()) .Split(user_op::OpArg("model", 0), axis) .Split(user_op::OpArg("model_diff", 0), axis) .Split(user_op::OpArg("momentum", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe LarsUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return LarsUpdateInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe LarsUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferLarsUpdateDataType(ctx); } /* static */ Maybe FtrlUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return FtrlInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe FtrlUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferFtrlUpdateTensorDesc(ctx); } /*static*/ Maybe FtrlUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe FtrlUpdateOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { ctx->NewBuilder() .Broadcast(ctx->inputs()) .Split(user_op::OpArg("model", 0), axis) .Split(user_op::OpArg("model_diff", 0), axis) .Split(user_op::OpArg("accumulate", 0), axis) .Split(user_op::OpArg("z", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe FtrlUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferFtrlUpdateDataType(ctx); } /* static */ Maybe AdadeltaUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return AdadeltaInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe AdadeltaUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferAdadeltaUpdateTensorDesc(ctx); } /*static*/ Maybe AdadeltaUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe AdadeltaUpdateOp::GetSbp(user_op::SbpContext* ctx) { const 
user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { ctx->NewBuilder() .Broadcast(ctx->inputs()) .Split(user_op::OpArg("model", 0), axis) .Split(user_op::OpArg("model_diff", 0), axis) .Split(user_op::OpArg("square_avgs", 0), axis) .Split(user_op::OpArg("acc_deltas", 0), axis) .Build(); } return Maybe::Ok(); } /* static */ Maybe AdadeltaUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferAdadeltaUpdateDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/moving_average_min_max_observer_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe MovingAverageMinMaxObserverOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& moving_max_shape = ctx->InputShape("moving_max", 0); const Shape& moving_min_shape = ctx->InputShape("moving_min", 0); const Shape& current_train_step = ctx->InputShape("current_train_step", 0); // NOTE(Liang Depeng): for now only support per-layer quantization // TODO(Liang Depeng): depthwise convolution support per-channel quantization CHECK_OR_RETURN(moving_max_shape.NumAxes() == 1 && moving_max_shape.At(0) == 1); CHECK_OR_RETURN(moving_min_shape.NumAxes() == 1 && moving_min_shape.At(0) == 1); CHECK_OR_RETURN(current_train_step.NumAxes() == 1 && current_train_step.At(0) == 1); ctx->SetOutputShape("scale", 0, Shape({1})); ctx->SetOutputShape("zero_point", 0, Shape({1})); return Maybe::Ok(); } /*static*/ Maybe MovingAverageMinMaxObserverOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MovingAverageMinMaxObserverOp::GetSbp(user_op::SbpContext* ctx) { // NOTE(Liang Depeng): all inputs need to be broadcast in order to accuratly calculate the // global scale and zero_point return Maybe::Ok(); } /* static */ Maybe MovingAverageMinMaxObserverOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); CHECK_OR_RETURN(in != nullptr); in->set_requires_grad(false); user_op::InputArgModifier* current_train_step = GetInputArgModifierFn("current_train_step", 0); CHECK_OR_RETURN(current_train_step != nullptr); current_train_step->set_requires_grad(false); user_op::InputArgModifier* moving_max = GetInputArgModifierFn("moving_max", 0); CHECK_OR_RETURN(moving_max != nullptr); moving_max->set_requires_grad(false); moving_max->set_is_mutable(true); user_op::InputArgModifier* moving_min = GetInputArgModifierFn("moving_min", 0); CHECK_OR_RETURN(moving_min != nullptr); moving_min->set_requires_grad(false); moving_min->set_is_mutable(true); return Maybe::Ok(); } /* static */ Maybe MovingAverageMinMaxObserverOp::CheckAttr( const 
/* static */ Maybe<void> MovingAverageMinMaxObserverOp::CheckAttr(
    const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& op_conf) {
  int32_t quantization_bit = op_conf.attr<int32_t>("quantization_bit");
  CHECK_GT_OR_RETURN(quantization_bit, 1);
  CHECK_LE_OR_RETURN(quantization_bit, 8);
  std::string quantization_scheme = op_conf.attr<std::string>("quantization_scheme");
  CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine");
  int64_t stop_update_after_iters = op_conf.attr<int64_t>("stop_update_after_iters");
  CHECK_GT_OR_RETURN(stop_update_after_iters, 0);
  std::string quantization_formula = op_conf.attr<std::string>("quantization_formula");
  CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon");
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> MovingAverageMinMaxObserverOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("scale", 0, ctx->InputDType("in", 0));
  ctx->SetOutputDType("zero_point", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/user/ops/multi_reduce_ops.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"
#include "oneflow/core/framework/nd_sbp.h"

namespace oneflow {

namespace {

Maybe<void> InferMultiReduceOpShape(user_op::InferContext* ctx) {
  CHECK_GT_OR_RETURN(ctx->input_size("x"), 0) << ctx->op_name() << " must have at least 1 input";
  ctx->SetOutputShape("y", 0, Shape({}));
  return Maybe<void>::Ok();
}

Maybe<void> InferMultiReduceOpDataType(user_op::InferContext* ctx) {
  DataType x_0_dtype = ctx->InputDType("x", 0);
  for (size_t i = 1; i < ctx->input_size("x"); ++i) {
    CHECK_EQ_OR_RETURN(ctx->InputDType("x", i), x_0_dtype)
        << ctx->op_name() << ": the " << i
        << " th input has a different data type from the others";
  }
  ctx->SetOutputDType("y", 0, x_0_dtype);
  return Maybe<void>::Ok();
}

Maybe<void> GetMultiReduceOpSbp(user_op::SbpContext* ctx) {
  const auto& x_0 = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
  int64_t min_num_axes = x_0.shape().NumAxes();
  for (size_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) {
    const auto& x_i = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i);
    min_num_axes = std::min(min_num_axes, x_i.shape().NumAxes());
  }
  for (int64_t i = 0; i < min_num_axes; ++i) {
    ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build();
  }
  return Maybe<void>::Ok();
}

Maybe<void> InferLocalMultiReduceOpLogicalShape(user_op::InferContext* ctx) {
  CHECK_GT_OR_RETURN(ctx->input_size("x"), 0) << ctx->op_name() << " must have at least 1 input";
  const NdSbp& any_nd_sbp = ctx->NdSbp4ArgNameAndIndex("x", 0);
  for (int32_t i = 1; i < ctx->input_size("x"); ++i) {
    const NdSbp& input_i_sbp = ctx->NdSbp4ArgNameAndIndex("x", i);
    CHECK_OR_RETURN(input_i_sbp == any_nd_sbp)
        << ctx->op_name() << ": the " << i << " th arg has a different sbp from the others, "
        << NdSbpToString(input_i_sbp) << " vs. 
" << NdSbpToString(any_nd_sbp); } auto rank_mesh = ctx->parallel_desc().hierarchy(); CHECK_EQ_OR_RETURN(rank_mesh->NumAxes(), any_nd_sbp.sbp_parallel_size()) << ctx->op_name() << ": ndim of ranks of " << *JUST(PlacementToString(ctx->parallel_desc())) << " is mismatched with the size of sbp " << NdSbpToString(any_nd_sbp); int64_t split_num = 1; for (int64_t i = 0; i < rank_mesh->NumAxes(); ++i) { if (any_nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= rank_mesh->At(i); } } ctx->SetOutputShape("y", 0, Shape({split_num})); return Maybe::Ok(); } Maybe InferLocalMultiReduceOpPhysicalShape(user_op::InferContext* ctx) { CHECK_GT_OR_RETURN(ctx->input_size("x"), 0) << ctx->op_name() << "must have at least 1 input"; ctx->SetOutputShape("y", 0, Shape({1})); return Maybe::Ok(); } Maybe GetLocalMultiReduceOpSbp(user_op::SbpContext* ctx) { const auto& x_0 = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); int64_t min_num_axes = x_0.shape().NumAxes(); for (size_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) { const auto& x_i = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i); min_num_axes = std::min(min_num_axes, x_i.shape().NumAxes()); } for (int64_t i = 0; i < min_num_axes; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("y", 0), 0).Build(); } return Maybe::Ok(); } } // namespace #define DEFINE_MULTI_REDUCE_OP_METHODS(op) \ Maybe op##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferMultiReduceOpShape(ctx); \ } \ Maybe op##Op::InferDataType(user_op::InferContext* ctx) { \ return InferMultiReduceOpDataType(ctx); \ } \ Maybe op##Op::GetSbp(user_op::SbpContext* ctx) { return GetMultiReduceOpSbp(ctx); } DEFINE_MULTI_REDUCE_OP_METHODS(MultiReduceSumPowAbs) DEFINE_MULTI_REDUCE_OP_METHODS(MultiReduceMaxAbs) DEFINE_MULTI_REDUCE_OP_METHODS(MultiReduceMinAbs) #undef DEFINE_MULTI_REDUCE_OP_METHODS #define DEFINE_LOCAL_MULTI_REDUCE_OP_METHODS(op) \ Maybe op##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferLocalMultiReduceOpLogicalShape(ctx); \ } \ Maybe op##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLocalMultiReduceOpPhysicalShape(ctx); \ } \ Maybe op##Op::InferDataType(user_op::InferContext* ctx) { \ return InferMultiReduceOpDataType(ctx); \ } \ Maybe op##Op::GetSbp(user_op::SbpContext* ctx) { return GetLocalMultiReduceOpSbp(ctx); } DEFINE_LOCAL_MULTI_REDUCE_OP_METHODS(LocalMultiReduceMaxAbs) DEFINE_LOCAL_MULTI_REDUCE_OP_METHODS(LocalMultiReduceMinAbs) #undef DEFINE_LOCAL_MULTI_REDUCE_OP_METHODS } // namespace oneflow ================================================ FILE: oneflow/user/ops/multi_tensor_model_update_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe CheckShapeLike(const user_op::TensorDesc* tensor_desc, const user_op::TensorDesc* like) { CHECK_EQ_OR_RETURN(tensor_desc->shape(), like->shape()) << "Tensordesc shape should be equal to Like shape. "; return Maybe::Ok(); } Maybe CheckDataTypeLike(const user_op::TensorDesc* tensor_desc, const user_op::TensorDesc* like) { CHECK_EQ_OR_RETURN(tensor_desc->data_type(), like->data_type()) << "Tensordesc DataType should be equal to Like DataType. "; return Maybe::Ok(); } Maybe CheckScalarShape(const user_op::TensorDesc* tensor_desc) { CHECK_OR_RETURN(tensor_desc->shape().NumAxes() == 0 || (tensor_desc->shape().NumAxes() == 1 && tensor_desc->shape().At(0) == 1)) << tensor_desc->shape().DebugStr(); return Maybe::Ok(); } Maybe CheckScalarDataType(const user_op::TensorDesc* tensor_desc, const DataType data_type) { CHECK_EQ_OR_RETURN(tensor_desc->data_type(), data_type) << "TensorDesc DataType should be equal to Scalar DataType. "; return Maybe::Ok(); } Maybe CheckLearningRateShape(user_op::InferContext* ctx) { if (ctx->has_input("learning_rate", 0)) { const user_op::TensorDesc& learning_rate = ctx->InputTensorDesc("learning_rate", 0); JUST(CheckScalarShape(&learning_rate)); } return Maybe::Ok(); } Maybe CheckLearningRateDataType(user_op::InferContext* ctx) { if (ctx->has_input("learning_rate", 0)) { const user_op::TensorDesc& learning_rate = ctx->InputTensorDesc("learning_rate", 0); JUST(CheckScalarDataType(&learning_rate, DataType::kFloat)); } return Maybe::Ok(); } Maybe SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn, const std::string& arg_name, int32_t arg_index) { user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index); CHECK_NOTNULL_OR_RETURN(arg_modifier) << "Arg Modifier should not be null. "; arg_modifier->set_is_mutable(true); return Maybe::Ok(); } Maybe InferSGDUpdateTensorDesc(user_op::InferContext* ctx) { const int64_t weight_size = ctx->input_size("model"); for (int i = 0; i < weight_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", i); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()) << "Model Diff shape should be equal to Model shape. "; } JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe InferSGDUpdateDataType(user_op::InferContext* ctx) { JUST(CheckLearningRateDataType(ctx)); const user_op::TensorDesc& first_model_desc = ctx->InputTensorDesc("model", 0); const int64_t input_size = ctx->input_size("model"); for (int64_t i = 0; i < input_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); CHECK_EQ(model.data_type(), first_model_desc.data_type()) << "Model DataType should be equal. 
"; } if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarDataType(&scale_by_tensor, first_model_desc.data_type())); } return Maybe::Ok(); } Maybe SgdInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { for (int64_t i = 0; i < conf.input_size("model"); i++) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", i)); } return Maybe::Ok(); } Maybe InferMomentumUpdateTensorDesc(user_op::InferContext* ctx) { const int64_t weight_size = ctx->input_size("model"); for (int i = 0; i < weight_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", i); const user_op::TensorDesc& momentum_buf = ctx->InputTensorDesc("momentum_buf", i); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()) << "Model Diff shape should be equal to Model shape. "; CHECK_EQ_OR_RETURN(momentum_buf.shape(), model.shape()) << "Momentum buf shape should be equal to Model shape. "; } JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe InferMomentumUpdateDataType(user_op::InferContext* ctx) { JUST(CheckLearningRateDataType(ctx)); const user_op::TensorDesc& first_model_desc = ctx->InputTensorDesc("model", 0); const int64_t input_size = ctx->input_size("model"); for (int64_t i = 0; i < input_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& momentum_buf = ctx->InputTensorDesc("momentum_buf", i); CHECK_EQ(model.data_type(), first_model_desc.data_type()) << "Model DataType should be equal. "; CHECK_EQ(momentum_buf.data_type(), first_model_desc.data_type()) << "Momentum buf DataType should be equal. "; } if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarDataType(&scale_by_tensor, first_model_desc.data_type())); } return Maybe::Ok(); } Maybe MomentumInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { for (int64_t i = 0; i < conf.input_size("model"); i++) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum_buf", i)); } return Maybe::Ok(); } Maybe InferAdamUpdateTensorDesc(user_op::InferContext* ctx) { const int64_t weight_size = ctx->input_size("model"); for (int i = 0; i < weight_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", i); const user_op::TensorDesc& m = ctx->InputTensorDesc("m", i); const user_op::TensorDesc& v = ctx->InputTensorDesc("v", i); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()) << "Model Diff shape should be equal to Model shape. "; CHECK_EQ_OR_RETURN(m.shape(), model.shape()) << "m shape should be equal to Model shape. "; CHECK_EQ_OR_RETURN(v.shape(), model.shape()) << "v shape should be equal to Model shape. 
"; } JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe InferAdamUpdateDataType(user_op::InferContext* ctx) { // todo JUST(CheckLearningRateDataType(ctx)); const user_op::TensorDesc& first_model_desc = ctx->InputTensorDesc("model", 0); const int64_t input_size = ctx->input_size("model"); for (int64_t i = 0; i < input_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& m = ctx->InputTensorDesc("m", i); const user_op::TensorDesc& v = ctx->InputTensorDesc("v", i); CHECK_EQ(model.data_type(), first_model_desc.data_type()) << "Model DataType should be equal. "; CHECK_EQ(m.data_type(), first_model_desc.data_type()) << "m DataType should be equal. "; CHECK_EQ(v.data_type(), first_model_desc.data_type()) << "v DataType should be equal. "; } if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarDataType(&scale_by_tensor, first_model_desc.data_type())); } return Maybe::Ok(); } Maybe AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { for (int64_t i = 0; i < conf.input_size("model"); i++) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", i)); } return Maybe::Ok(); } Maybe InferSGDUpdateWithCastTensorDesc(user_op::InferContext* ctx) { const int64_t weight_size = ctx->input_size("model"); for (int i = 0; i < weight_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& model_copy = ctx->InputTensorDesc("model_copy", i); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", i); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()) << "Model diff shape should be equal to Model shape. "; CHECK_EQ_OR_RETURN(model_copy.shape(), model.shape()) << "Model copy shape should be equal to Model shape. "; } JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe SgdWithCastInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { for (int64_t i = 0; i < conf.input_size("model"); i++) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model_copy", i)); } return Maybe::Ok(); } Maybe InferMomentumUpdateWithCastTensorDesc(user_op::InferContext* ctx) { const int64_t weight_size = ctx->input_size("model"); for (int i = 0; i < weight_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& model_copy = ctx->InputTensorDesc("model_copy", i); const user_op::TensorDesc& momentum_buf = ctx->InputTensorDesc("momentum_buf", i); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", i); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()) << "Model diff shape should be equal to Model shape. "; CHECK_EQ_OR_RETURN(momentum_buf.shape(), model.shape()) << "Momentum buf shape should be equal to Model shape. 
"; CHECK_EQ_OR_RETURN(model_copy.shape(), model.shape()) << "Model copy shape should be equal to Model shape. "; } JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe MomentumWithCastInputArgModifyFn( const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { for (int64_t i = 0; i < conf.input_size("model"); i++) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum_buf", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model_copy", i)); } return Maybe::Ok(); } Maybe InferAdamUpdateWithCastTensorDesc(user_op::InferContext* ctx) { const int64_t weight_size = ctx->input_size("model"); for (int i = 0; i < weight_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", i); const user_op::TensorDesc& model_copy = ctx->InputTensorDesc("model_copy", i); const user_op::TensorDesc& m = ctx->InputTensorDesc("m", i); const user_op::TensorDesc& v = ctx->InputTensorDesc("v", i); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()) << "Model diff shape should be equal to Model shape. "; CHECK_EQ_OR_RETURN(model_copy.shape(), model.shape()) << "Model copy shape should be equal to Model shape. "; CHECK_EQ_OR_RETURN(m.shape(), model.shape()) << "m shape should be equal to Model shape. "; CHECK_EQ_OR_RETURN(v.shape(), model.shape()) << "v shape should be equal to Model shape. "; } JUST(CheckLearningRateShape(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); } return Maybe::Ok(); } Maybe AdamWithCastInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { for (int64_t i = 0; i < conf.input_size("model"); i++) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model_copy", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", i)); } return Maybe::Ok(); } Maybe InferYoloV5WeightUpdateTensorDesc(user_op::InferContext* ctx) { const int64_t weight_size = ctx->input_size("model"); for (int i = 0; i < weight_size; i++) { const user_op::TensorDesc& model_i = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& model_update_i = ctx->InputTensorDesc("model_update", i); CHECK_EQ_OR_RETURN(model_update_i.shape(), model_i.shape()) << "All Model shape should be equal to model_update shape."; } return Maybe::Ok(); } Maybe InferYoloV5WeightUpdateDataType(user_op::InferContext* ctx) { JUST(CheckLearningRateDataType(ctx)); const user_op::TensorDesc& first_model_desc = ctx->InputTensorDesc("model", 0); const int64_t input_size = ctx->input_size("model"); for (int64_t i = 0; i < input_size; i++) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", i); const user_op::TensorDesc& model_update_i = ctx->InputTensorDesc("model_update", i); CHECK_EQ(model.data_type(), first_model_desc.data_type()) << "Model DataType should be equal. 
"; CHECK_EQ(model_update_i.data_type(), first_model_desc.data_type()) << "Model DataType should be equal to model_update DataType."; } return Maybe::Ok(); } Maybe YoloV5WeightInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { for (int64_t i = 0; i < conf.input_size("model"); i++) { JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", i)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model_update", i)); } return Maybe::Ok(); } } // namespace /* static */ Maybe MultiTensorSgdUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferSGDUpdateTensorDesc(ctx); } /*static*/ Maybe MultiTensorSgdUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultiTensorSgdUpdateOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Build(); return Maybe::Ok(); } /* static */ Maybe MultiTensorSgdUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return SgdInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe MultiTensorSgdUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferSGDUpdateDataType(ctx); } /* static */ Maybe MultiTensorMomentumUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferMomentumUpdateTensorDesc(ctx); } /*static*/ Maybe MultiTensorMomentumUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultiTensorMomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Build(); return Maybe::Ok(); } /* static */ Maybe MultiTensorMomentumUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return MomentumInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe MultiTensorMomentumUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferMomentumUpdateDataType(ctx); } /* static */ Maybe MultiTensorAdamUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferAdamUpdateTensorDesc(ctx); } /*static*/ Maybe MultiTensorAdamUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultiTensorAdamUpdateOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Build(); return Maybe::Ok(); } /* static */ Maybe MultiTensorAdamUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return AdamInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe MultiTensorAdamUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferAdamUpdateDataType(ctx); } /* static */ Maybe MultiTensorSgdUpdateWithCastOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferSGDUpdateTensorDesc(ctx); } /*static*/ Maybe MultiTensorSgdUpdateWithCastOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultiTensorSgdUpdateWithCastOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Build(); return Maybe::Ok(); } /* static */ Maybe MultiTensorSgdUpdateWithCastOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return SgdWithCastInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe 
MultiTensorSgdUpdateWithCastOp::InferDataType(user_op::InferContext* ctx) { return InferSGDUpdateDataType(ctx); } /* static */ Maybe MultiTensorMomentumUpdateWithCastOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferMomentumUpdateTensorDesc(ctx); } /*static*/ Maybe MultiTensorMomentumUpdateWithCastOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultiTensorMomentumUpdateWithCastOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Build(); return Maybe::Ok(); } /* static */ Maybe MultiTensorMomentumUpdateWithCastOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return MomentumWithCastInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe MultiTensorMomentumUpdateWithCastOp::InferDataType( user_op::InferContext* ctx) { return InferMomentumUpdateDataType(ctx); } /* static */ Maybe MultiTensorAdamUpdateWithCastOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferAdamUpdateWithCastTensorDesc(ctx); } /*static*/ Maybe MultiTensorAdamUpdateWithCastOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultiTensorAdamUpdateWithCastOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Build(); return Maybe::Ok(); } /* static */ Maybe MultiTensorAdamUpdateWithCastOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return AdamWithCastInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe MultiTensorAdamUpdateWithCastOp::InferDataType( user_op::InferContext* ctx) { return InferAdamUpdateDataType(ctx); } /* static */ Maybe MultiTensorYoloV5WeightUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferYoloV5WeightUpdateTensorDesc(ctx); } /*static*/ Maybe MultiTensorYoloV5WeightUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe MultiTensorYoloV5WeightUpdateOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Build(); return Maybe::Ok(); } /* static */ Maybe MultiTensorYoloV5WeightUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return YoloV5WeightInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe MultiTensorYoloV5WeightUpdateOp::InferDataType( user_op::InferContext* ctx) { return InferYoloV5WeightUpdateDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/mutable_cast_once_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> MutableCastOnceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& input_tensor_desc = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0);
  output_tensor_desc->set_shape(input_tensor_desc.shape());
  output_tensor_desc->set_is_dynamic(input_tensor_desc.is_dynamic());
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> MutableCastOnceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> MutableCastOnceOp::GetSbp(user_op::SbpContext* ctx) {
  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) {
    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
  }
  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> MutableCastOnceOp::InferDataType(user_op::InferContext* ctx) {
  user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0);
  output_tensor_desc->set_data_type(ctx->Attr<DataType>("dtype"));
  return Maybe<void>::Ok();
}

}  // namespace oneflow
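// A concrete reading of the SBP signatures registered just above: for a 2-D "in",
// MutableCastOnceOp declares S(0)->S(0), S(1)->S(1), and P->P, i.e. the cast is treated as a
// purely elementwise op that never forces data to be rearranged across ranks.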
================================================
FILE: oneflow/user/ops/narrow_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> NarrowOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0);
  const int64_t& dim = ctx->Attr<int64_t>("dim");
  const int64_t& start = ctx->Attr<int64_t>("start");
  int64_t length = ctx->Attr<int64_t>("length");
  CHECK_GE_OR_RETURN(dim, 0);
  CHECK_GE_OR_RETURN(start, 0);
  CHECK_GE_OR_RETURN(length, 0);
  // clamp length to the input size when the whole slice dimension is taken
  if (start == 0 && length > in.shape().At(dim)) { length = in.shape().At(dim); }
  user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0);
  DimVector dim_vec;
  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(),
                 in.shape().dim_vec().cbegin() + dim);
  dim_vec.insert(dim_vec.end(), length);
  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + dim + 1,
                 in.shape().dim_vec().end());
  out->set_shape(Shape(dim_vec));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> NarrowOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> NarrowOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  const int64_t& dim = ctx->Attr<int64_t>("dim");
  const int64_t& length = ctx->Attr<int64_t>("length");
  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
    if (i != dim) {
      ctx->NewBuilder()
          .Split(user_op::OpArg("in", 0), i)
          .Split(user_op::OpArg("out", 0), i)
          .Build();
    } else {
      if (length == in_tensor.shape().At(i)) {
        ctx->NewBuilder()
            .Split(user_op::OpArg("in", 0), i)
            .Split(user_op::OpArg("out", 0), i)
            .Build();
      }
    }
  }
  ctx->NewBuilder()
      .PartialSum(user_op::OpArg("in", 0))
      .PartialSum(user_op::OpArg("out", 0))
      .Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> NarrowOp::InferDataType(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0);
  out->set_data_type(in.data_type());
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> NarrowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& like_shape = ctx->InputShape("like", 0);
  const Shape& dy_shape = ctx->InputShape("dy", 0);
  const int64_t ndim = dy_shape.NumAxes();
  CHECK_EQ_OR_RETURN(like_shape.NumAxes(), ndim);
  ctx->SetOutputShape("dx", 0, like_shape);
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> NarrowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> NarrowGradOp::GetSbp(user_op::SbpContext* ctx) {
  const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape();
  const int64_t ndim = like_shape.NumAxes();
  const int64_t& dim = ctx->Attr<int64_t>("dim");
  const int64_t& length = ctx->Attr<int64_t>("length");
  FOR_RANGE(int64_t, i, 0, ndim) {
    if (i != dim) {
      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
    } else {
      if (length == like_shape.At(i)) {
        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
      }
    }
  }
  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
  ctx->NewBuilder()
      .PartialSum(user_op::OpArg("dy", 0))
      .Broadcast(user_op::OpArg("like", 0))
      .PartialSum(user_op::OpArg("dx", 0))
      .Build();
  ctx->NewBuilder()
      .Broadcast(user_op::OpArg("dy", 0))
      .PartialSum(user_op::OpArg("like", 0))
      .Broadcast(user_op::OpArg("dx", 0))
      .Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> NarrowGradOp::ModifyInputArg( const
GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0); CHECK_NOTNULL_OR_RETURN(dy_modifier); dy_modifier->set_requires_grad(false); user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe NarrowGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" #include "oneflow/user/ops/nccl_logical_util.h" namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferNdSbp( user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // (*, P) -> (*, B) CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0) == output_nd_sbp->sbp_parallel(0)); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1).has_partial_sum_parallel()); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(1).has_broadcast_parallel()); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogical_2DSameDim0AllReduceOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe 
_ncclLogical_2DSameDim1AllReduceOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferNdSbp( user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // (P, *) -> (B, *) CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_partial_sum_parallel()); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_broadcast_parallel()); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1) == output_nd_sbp->sbp_parallel(1)); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogical_2DSameDim1AllReduceOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferNdSbp( user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // (*, S(0)) -> (*, B) CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0) == output_nd_sbp->sbp_parallel(0)); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1).has_split_parallel()); CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel(1).split_parallel().axis(), 0); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(1).has_broadcast_parallel()); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogical_2DSameDim0AllGatherOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::GetSbp( user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ 
Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferNdSbp( user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // (*, S(>=1)) -> (*, B) CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0) == output_nd_sbp->sbp_parallel(0)); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1).has_split_parallel()); CHECK_GE_OR_RETURN(input_nd_sbp->sbp_parallel(1).split_parallel().axis(), 1); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(1).has_broadcast_parallel()); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim0All2allOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferNdSbp( user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // (*, S) -> (*, S) CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0) == output_nd_sbp->sbp_parallel(0)); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1).has_split_parallel()); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(1).has_split_parallel()); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2); return Maybe::Ok(); } /* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogical_2DSameDim0All2allOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/nccl_logical_fusion_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/nccl_logical_util.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" #include "oneflow/core/common/container_util.h" namespace oneflow { /* static */ Maybe _ncclLogicalFusionOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { int32_t nccl_size = ctx->input_size("in"); CHECK_EQ_OR_RETURN(nccl_size, ctx->output_size("out")); // NOLINT for (int32_t i = 0; i < nccl_size; ++i) { ctx->SetOutputShape("out", i, ctx->InputShape("in", i)); ctx->SetOutputIsDynamic("out", i, ctx->InputIsDynamic("in", i)); } return Maybe::Ok(); } /* static */ Maybe _ncclLogicalFusionOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogicalFusionOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { int32_t nccl_size = ctx->inputs().size(); CHECK_EQ_OR_RETURN(nccl_size, ctx->outputs().size()); // NOLINT const std::vector& src_nd_sbp_str_list = ctx->user_op_conf().attr>("src_nd_sbp_str_list"); const std::vector& dst_nd_sbp_str_list = ctx->user_op_conf().attr>("dst_nd_sbp_str_list"); CHECK_EQ_OR_RETURN(nccl_size, src_nd_sbp_str_list.size()); // NOLINT CHECK_EQ_OR_RETURN(nccl_size, dst_nd_sbp_str_list.size()); // NOLINT for (int32_t i = 0; i < nccl_size; ++i) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", i); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", i); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); CHECK_OR_RETURN(ParseNdSbpFromLongString(JUST(VectorAt(src_nd_sbp_str_list, i)), input_nd_sbp)) << Error::RuntimeError() << " Cannot parse str: " << JUST(VectorAt(src_nd_sbp_str_list, i)) << " to input nd_sbp attr of op : " << ctx->user_op_conf().op_name(); CHECK_OR_RETURN(ParseNdSbpFromLongString(JUST(VectorAt(dst_nd_sbp_str_list, i)), output_nd_sbp)) << Error::RuntimeError() << " Cannot parse str: " << JUST(VectorAt(dst_nd_sbp_str_list, i)) << " to output nd_sbp attr of op : " << ctx->user_op_conf().op_name(); } return Maybe::Ok(); } /* static */ Maybe _ncclLogicalFusionOp::InferDataType(user_op::InferContext* ctx) { int32_t nccl_size = ctx->input_size("in"); CHECK_EQ_OR_RETURN(nccl_size, ctx->output_size("out")); // NOLINT for (int32_t i = 0; i < nccl_size; ++i) { ctx->SetOutputDType("out", i, ctx->InputDType("in", i)); } return Maybe::Ok(); } /* static */ Maybe> _ncclLogicalFusionOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/nccl_logical_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/nccl_logical_util.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" namespace oneflow { /* static */ Maybe _ncclLogicalAllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalAllReduceOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogicalAllReduceOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // P->B CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_partial_sum_parallel()); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_broadcast_parallel()); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalAllReduceOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogicalAllReduceOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogicalReduceScatterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalReduceScatterOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogicalReduceScatterOp::InferNdSbp( user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // P->S(0) CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_partial_sum_parallel()); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_split_parallel()); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel(0).split_parallel().axis(), 0); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalReduceScatterOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, 
ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogicalReduceScatterOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogicalAllGatherOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalAllGatherOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogicalAllGatherOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // S(0)->B CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_split_parallel()); CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel(0).split_parallel().axis(), 0); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_broadcast_parallel()); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalAllGatherOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogicalAllGatherOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferNdSbp( user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // S(>=1)->B CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_split_parallel()); CHECK_GE_OR_RETURN(input_nd_sbp->sbp_parallel(0).split_parallel().axis(), 1); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_broadcast_parallel()); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogicalAllGatherNoncontinuousOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogicalReduceScatterNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { 
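  // Like the other _nccl_logical ops in this file, this op is shape- and dtype-preserving at the
  // logical level; the collective it represents (a reduce-scatter whose destination sbp is split
  // on an axis >= 1, i.e. P -> S(>=1)) is encoded entirely in the "src_reduced_nd_sbp" /
  // "dst_reduced_nd_sbp" attributes that InferNdSbp below validates.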
ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalReduceScatterNoncontinuousOp::GetSbp( user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogicalReduceScatterNoncontinuousOp::InferNdSbp( user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // P->S(0) CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1) << "input_nd_sbp should be 1d."; CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1) << "output_nd_sbp should be 1d."; CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_partial_sum_parallel()) << "input_nd_sbp should be partial_sum_parallel."; CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_split_parallel()) << "output_nd_sbp should be split parallel."; CHECK_GE_OR_RETURN(output_nd_sbp->sbp_parallel(0).split_parallel().axis(), 1) << "output_nd_sbp split axis should greater equal 1."; CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1) << "parallel_hierarchy should be 1d."; return Maybe::Ok(); } /* static */ Maybe _ncclLogicalReduceScatterNoncontinuousOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogicalReduceScatterNoncontinuousOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogicalS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalS2sOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogicalS2sOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); // S->S CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1); CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1); CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_split_parallel()); CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_split_parallel()); CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalS2sOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogicalS2sOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } /* static */ Maybe _ncclLogicalSendRecvOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /* static */ Maybe 
_ncclLogicalSendRecvOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe _ncclLogicalSendRecvOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); input_nd_sbp->clear_sbp_parallel(); output_nd_sbp->clear_sbp_parallel(); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", input_nd_sbp)); JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", output_nd_sbp)); return Maybe::Ok(); } /* static */ Maybe _ncclLogicalSendRecvOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe> _ncclLogicalSendRecvOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { return DeviceAndStreamInferFn(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/nccl_logical_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/ops/nccl_logical_util.h" namespace oneflow { std::string GetCommKeyFromNcclType(const std::string& op_type_name) { if (op_type_name == "_nccl_logical_2D_same_dim0_all_reduce" || op_type_name == "_nccl_logical_2D_same_dim0_all_gather" || op_type_name == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous" || op_type_name == "_nccl_logical_2D_same_dim0_all2all") { return "SameDim0"; } if (op_type_name == "_nccl_logical_2D_same_dim1_all_reduce") { return "SameDim1"; } return ""; } } // namespace oneflow ================================================ FILE: oneflow/user/ops/nccl_logical_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
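// A minimal, self-contained sketch (not taken from the OneFlow sources) of the kind of
// parsing that GetNcclLogicalNdSbpFromAttr in the header below delegates to
// ParseNdSbpFromStringList: the "src_reduced_nd_sbp" / "dst_reduced_nd_sbp" attributes
// are lists of SBP strings such as "S(0)", "B" or "P", one entry per hierarchy axis.
// SbpToken, ParseSbpToken and ParseNdSbpTokens are illustrative names, not OneFlow APIs.

#include <optional>
#include <string>
#include <vector>

enum class SbpKind { kSplit, kBroadcast, kPartialSum };
struct SbpToken {
  SbpKind kind;
  int split_axis = -1;  // only meaningful for kSplit
};

// Parses one token: "B" -> broadcast, "P" -> partial-sum, "S(axis)" -> split on axis.
std::optional<SbpToken> ParseSbpToken(const std::string& s) {
  if (s == "B") { return SbpToken{SbpKind::kBroadcast}; }
  if (s == "P") { return SbpToken{SbpKind::kPartialSum}; }
  if (s.size() >= 4 && s[0] == 'S' && s[1] == '(' && s.back() == ')') {
    return SbpToken{SbpKind::kSplit, std::stoi(s.substr(2, s.size() - 3))};
  }
  return std::nullopt;  // malformed token; the real helper reports a RuntimeError instead
}

// e.g. {"S(0)"} for a 1D split placement, {"S(0)", "B"} for a 2D hierarchy.
std::vector<SbpToken> ParseNdSbpTokens(const std::vector<std::string>& attr) {
  std::vector<SbpToken> out;
  for (const auto& s : attr) {
    if (auto token = ParseSbpToken(s)) { out.push_back(*token); }
  }
  return out;
}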
*/ #ifndef ONEFLOW_USER_OPS_NCCL_LOGICAL_UTIL_H_ #define ONEFLOW_USER_OPS_NCCL_LOGICAL_UTIL_H_ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/sbp_parallel.h" namespace oneflow { template struct AttrFromContext { const AttrT& operator()(ContextT*, const std::string&); }; template struct AttrFromContext { const AttrT& operator()(user_op::InferNdSbpFnContext* ctx, const std::string& attr_name) { return ctx->user_op_conf().template attr(attr_name); } }; template struct AttrFromContext { const AttrT& operator()(user_op::KernelInitContext* ctx, const std::string& attr_name) { return ctx->Attr(attr_name); } }; template struct AttrFromContext { const AttrT& operator()(user_op::InferContext* ctx, const std::string& attr_name) { return ctx->Attr(attr_name); } }; template struct OpTypeNameFromContext { const std::string& operator()(ContextT*); }; template<> struct OpTypeNameFromContext { const std::string& operator()(user_op::InferNdSbpFnContext* ctx) { return ctx->user_op_conf().op_type_name(); } }; template<> struct OpTypeNameFromContext { const std::string& operator()(user_op::KernelInitContext* ctx) { return ctx->op_type_name(); } }; template<> struct OpTypeNameFromContext { const std::string& operator()(user_op::InferContext* ctx) { return ctx->op_type_name(); } }; template Maybe GetNcclLogicalNdSbpFromAttr(ContextT* ctx, const std::string& attr_name, NdSbp* nd_sbp) { const auto& sbp_str_list = AttrFromContext>()(ctx, attr_name); if (!ParseNdSbpFromStringList(sbp_str_list, nd_sbp)) { std::ostringstream err; err << "invalid " << attr_name << ": ["; for (size_t i = 0; i < sbp_str_list.size(); ++i) { const auto& sbp_str = sbp_str_list[i]; if (i == 0) { err << sbp_str; } else { err << ", " << sbp_str; } } err << "] for " << OpTypeNameFromContext()(ctx); return Error::RuntimeError() << err.str(); } return Maybe::Ok(); } std::string GetCommKeyFromNcclType(const std::string& op_type_name); } // namespace oneflow #endif // ONEFLOW_USER_OPS_NCCL_LOGICAL_UTIL_H_ ================================================ FILE: oneflow/user/ops/nd_index_slice_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
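// A minimal, self-contained sketch (not taken from the OneFlow sources) of the shape
// contract checked by CheckScatterNdShape and GatherNdOp below: with indices of shape
// (b_0, ..., b_{m-1}, index_ndims), both the gather_nd output and the scatter_nd
// "updates" input have shape (b_0, ..., b_{m-1}) followed by the trailing
// (params_ndims - index_ndims) axes of params. ExpectedSliceShape is an illustrative
// helper name, not an OneFlow API.

#include <cstdint>
#include <vector>

std::vector<int64_t> ExpectedSliceShape(const std::vector<int64_t>& params_shape,
                                        const std::vector<int64_t>& indices_shape) {
  const auto index_ndims = static_cast<size_t>(indices_shape.back());
  // Batch dims come from indices (all but its last axis)...
  std::vector<int64_t> shape(indices_shape.begin(), indices_shape.end() - 1);
  // ...followed by the slice dims, i.e. the trailing axes of params.
  for (size_t i = index_ndims; i < params_shape.size(); ++i) { shape.push_back(params_shape[i]); }
  return shape;
}

// Example: params (4, 5, 6) and indices (3, 2) give index_ndims = 2, so gather_nd's
// output and scatter_nd's updates are both expected to be (3, 6).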
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe CheckScatterNdShape(const Shape& params_shape, const Shape& indices_shape, const Shape& updates_shape) { int64_t batch_ndims = indices_shape.NumAxes() - 1; int64_t index_ndims = indices_shape.At(batch_ndims); CHECK_LE_OR_RETURN(batch_ndims, updates_shape.NumAxes()); CHECK_LE_OR_RETURN(index_ndims, params_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, batch_ndims) { CHECK_EQ_OR_RETURN(updates_shape.At(i), indices_shape.At(i)); } int64_t slice_ndims = params_shape.NumAxes() - index_ndims; CHECK_EQ_OR_RETURN(slice_ndims, updates_shape.NumAxes() - batch_ndims); FOR_RANGE(int64_t, i, 0, slice_ndims) { CHECK_EQ_OR_RETURN(updates_shape.At(i + batch_ndims), params_shape.At(i + index_ndims)); } return Maybe::Ok(); } Maybe InferScatterNdTensorDesc(user_op::InferContext* ctx) { const Shape& indices_shape = ctx->InputShape("indices", 0); const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& params_shape = ctx->Attr("shape"); JUST(CheckScatterNdShape(params_shape, indices_shape, updates_shape)); ctx->SetOutputShape("out", 0, params_shape); return Maybe::Ok(); } Maybe InferScatterNdDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("updates", 0)); return Maybe::Ok(); } Maybe InferScatterNdLikeTensorDesc(user_op::InferContext* ctx) { const Shape& indices_shape = ctx->InputShape("indices", 0); const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& like_shape = ctx->InputShape("like", 0); JUST(CheckScatterNdShape(like_shape, indices_shape, updates_shape)); ctx->SetOutputShape("out", 0, like_shape); return Maybe::Ok(); } Maybe InferScatterNdLikeDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("updates", 0)); return Maybe::Ok(); } Maybe InferTensorScatterNdOptTensorDesc(user_op::InferContext* ctx) { const Shape& params_shape = ctx->InputShape("params", 0); const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& indices_shape = ctx->InputShape("indices", 0); JUST(CheckScatterNdShape(params_shape, indices_shape, updates_shape)); ctx->SetOutputShape("out", 0, params_shape); ctx->SetOutputStride("out", 0, ctx->InputStride("params", 0)); return Maybe::Ok(); } Maybe InferTensorScatterNdOptDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("params", 0)); return Maybe::Ok(); } Maybe GetTensorScatterNdOptSbpSignatures(user_op::SbpContext* ctx) { const user_op::TensorDesc& params_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("params", 0); const user_op::TensorDesc& indices_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); int64_t indices_num_axes = indices_tensor.shape().NumAxes(); FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { ctx->NewBuilder() .Broadcast(user_op::OpArg("params", 0)) .Split(user_op::OpArg("indices", 0), i) .Split(user_op::OpArg("updates", 0), i) .Broadcast(user_op::OpArg("out", 0)) .Build(); } int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); FOR_RANGE(int64_t, i, index_ndims, params_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("params", 0), i) .Broadcast(user_op::OpArg("indices", 0)) .Split(user_op::OpArg("updates", 0), i - index_ndims + indices_num_axes - 1) .Split(user_op::OpArg("out", 0), i) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("params", 0)) .Broadcast(user_op::OpArg("indices", 0)) .PartialSum(user_op::OpArg("updates", 0)) 
.PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } } // namespace /* static */ Maybe GatherNdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& params_shape = ctx->InputShape("params", 0); const Shape& indices_shape = ctx->InputShape("indices", 0); int64_t index_ndims = indices_shape.At(indices_shape.NumAxes() - 1); CHECK_LE_OR_RETURN(index_ndims, params_shape.NumAxes()); DimVector out_shape_vec(indices_shape.dim_vec().cbegin(), indices_shape.dim_vec().cend() - 1); FOR_RANGE(int64_t, i, index_ndims, params_shape.NumAxes()) { out_shape_vec.emplace_back(params_shape.At(i)); } const Shape& out_shape = Shape(out_shape_vec); bool is_out_of_bounds = params_shape.Count(0) == 0 && out_shape.Count(0) != 0; CHECK_OR_RETURN(!is_out_of_bounds) << Error::IndexError() << "The index is out of bounds for dimension with size 0"; ctx->SetOutputShape("out", 0, out_shape); return Maybe::Ok(); } /*static*/ Maybe GatherNdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe GatherNdOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& params_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("params", 0); const user_op::TensorDesc& indices_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); int64_t indices_num_axes = indices_tensor.shape().NumAxes(); FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { ctx->NewBuilder() .Broadcast(user_op::OpArg("params", 0)) .Split(user_op::OpArg("indices", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); FOR_RANGE(int64_t, i, index_ndims, params_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("params", 0), i) .Broadcast(user_op::OpArg("indices", 0)) .Split(user_op::OpArg("out", 0), i - index_ndims + indices_num_axes - 1) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("params", 0)) .Broadcast(user_op::OpArg("indices", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe GatherNdOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe GatherNdOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("params", 0)); return Maybe::Ok(); } /* static */ Maybe ScatterNdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferScatterNdTensorDesc(ctx); } /*static*/ Maybe ScatterNdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ScatterNdOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& indices_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); int64_t indices_num_axes = indices_desc.shape().NumAxes(); FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { ctx->NewBuilder() .Split(user_op::OpArg("indices", 0), i) .Split(user_op::OpArg("updates", 0), i) .Broadcast(user_op::OpArg("out", 0)) .Build(); } const Shape& out_shape = ctx->Attr("shape"); int64_t index_ndims = indices_desc.shape().At(indices_num_axes - 1); int64_t slice_ndims = out_shape.NumAxes() - index_ndims; FOR_RANGE(int64_t, i, 0, slice_ndims) { ctx->NewBuilder() .Broadcast(user_op::OpArg("indices", 0)) .Split(user_op::OpArg("updates", 0), i + 
indices_num_axes - 1) .Split(user_op::OpArg("out", 0), i + index_ndims) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("updates", 0)) .Broadcast(user_op::OpArg("indices", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe ScatterNdOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe ScatterNdOp::InferDataType(user_op::InferContext* ctx) { return InferScatterNdDataType(ctx); } /* static */ Maybe ScatterNdLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferScatterNdLikeTensorDesc(ctx); } /*static*/ Maybe ScatterNdLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ScatterNdLikeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& indices_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); int64_t indices_num_axes = indices_tensor.shape().NumAxes(); FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { ctx->NewBuilder() .Broadcast(user_op::OpArg("like", 0)) .Split(user_op::OpArg("indices", 0), i) .Split(user_op::OpArg("updates", 0), i) .Broadcast(user_op::OpArg("out", 0)) .Build(); } const Shape& out_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); int64_t slice_ndims = out_shape.NumAxes() - index_ndims; FOR_RANGE(int64_t, i, 0, slice_ndims) { ctx->NewBuilder() .Split(user_op::OpArg("like", 0), i + index_ndims) .Broadcast(user_op::OpArg("indices", 0)) .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) .Split(user_op::OpArg("out", 0), i + index_ndims) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("updates", 0)) .Broadcast(user_op::OpArg("indices", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe ScatterNdLikeOp::InferDataType(user_op::InferContext* ctx) { return InferScatterNdLikeDataType(ctx); } /* static */ Maybe TensorScatterNdUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorScatterNdOptTensorDesc(ctx); } /*static*/ Maybe TensorScatterNdUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe TensorScatterNdUpdateOp::GetSbp(user_op::SbpContext* ctx) { return GetTensorScatterNdOptSbpSignatures(ctx); } /* static */ Maybe TensorScatterNdUpdateOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe TensorScatterNdUpdateOp::InferDataType(user_op::InferContext* ctx) { return InferTensorScatterNdOptDataType(ctx); } /* static */ Maybe TensorScatterNdAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorScatterNdOptTensorDesc(ctx); } /*static*/ Maybe TensorScatterNdAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe TensorScatterNdAddOp::GetSbp(user_op::SbpContext* ctx) { return GetTensorScatterNdOptSbpSignatures(ctx); } /* 
static */ Maybe<void> TensorScatterNdAddOp::ModifyInputArg(
    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0);
  CHECK_OR_RETURN(indices_modifier != nullptr);
  indices_modifier->set_requires_grad(false);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> TensorScatterNdAddOp::InferDataType(user_op::InferContext* ctx) {
  return InferTensorScatterNdOptDataType(ctx);
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/nll_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> NLLOp::InferDataType(user_op::InferContext* ctx) {
  CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("target", 0)))
      << ctx->op_name() << ": expected target to be an integer type";

  DataType input_dtype = ctx->InputDType("input", 0);
  if (ctx->has_input("weight", 0)) {
    DataType weight_dtype = ctx->InputDType("weight", 0);
    CHECK_EQ_OR_RETURN(weight_dtype, input_dtype) << ctx->op_name() << ": expected weight dtype "
                                                  << input_dtype << ", but got " << weight_dtype;
  }

  ctx->SetOutputDType("output", 0, input_dtype);
  ctx->SetOutputDType("out_weight", 0, input_dtype);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> NLLOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const auto& input_desc = ctx->InputTensorDesc("input", 0);
  const auto& target_desc = ctx->InputTensorDesc("target", 0);
  const bool is_dynamic = input_desc.is_dynamic();
  CHECK_EQ_OR_RETURN(target_desc.is_dynamic(), is_dynamic)
      << ctx->op_name() << ": expected target to have the same dynamic flag as input";

  const int64_t K = input_desc.shape().NumAxes();
  CHECK_GE_OR_RETURN(K, 2) << ctx->op_name() << ": expected 2 or more dimensions for input";
  CHECK_EQ_OR_RETURN(target_desc.shape().NumAxes(), K - 1)
      << ctx->op_name() << ": expected 1 less dimension than input for target";

  const int64_t N = target_desc.shape().elem_cnt();
  const int64_t C = input_desc.shape().At(input_desc.shape().NumAxes() - 1);
  CHECK_EQ_OR_RETURN(input_desc.shape().elem_cnt(), N * C)
      << ctx->op_name() << ": expected input size " << input_desc.shape().ToString()
      << " to match target size " << target_desc.shape().ToString();

  if (ctx->has_input("weight", 0)) {
    const auto& weight_desc = ctx->InputTensorDesc("weight", 0);
    CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), is_dynamic)
        << ctx->op_name() << ": expected weight to have the same dynamic flag as input";
    CHECK_EQ_OR_RETURN(weight_desc.shape().elem_cnt(), C)
        << ctx->op_name() << ": expected weight size " << C << ", got "
        << weight_desc.shape().ToString();
  }

  user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0);
  output_desc->set_is_dynamic(is_dynamic);
  output_desc->set_shape(Shape({N}));

  user_op::TensorDesc* out_weight_desc = ctx->MutOutputTensorDesc("out_weight", 0);
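  // Worked example of the shape contract enforced above (a reading aid, not original
  // source): for input (8, 4, 4, 10) and target (8, 4, 4), K = 4, N = 8 * 4 * 4 = 128 and
  // C = 10, so input must hold N * C = 1280 elements; "output" and "out_weight" are then
  // inferred as 1-D tensors of shape (128,), and an optional "weight" input must hold
  // exactly C = 10 elements.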
out_weight_desc->set_is_dynamic(is_dynamic); out_weight_desc->set_shape(Shape({N})); return Maybe::Ok(); } /* static */ Maybe NLLOp::GetSbp(user_op::SbpContext* ctx) { // split batch dim auto builder1 = ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) .Split(user_op::OpArg("output", 0), 0) .Split(user_op::OpArg("out_weight", 0), 0); if (ctx->user_op_conf().has_input("weight", 0)) { builder1.Broadcast(user_op::OpArg("weight", 0)); } builder1.Build(); // split class dim const auto& shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); auto builder2 = ctx->NewBuilder() .Split(user_op::OpArg("input", 0), shape.NumAxes() - 1) .Broadcast(user_op::OpArg("target", 0)) .PartialSum(user_op::OpArg("output", 0)) .PartialSum(user_op::OpArg("out_weight", 0)); if (ctx->user_op_conf().has_input("weight", 0)) { builder2.Split(user_op::OpArg("weight", 0), 0); } builder2.Build(); return Maybe::Ok(); } /* static */ Maybe NLLOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); CHECK_OR_RETURN(target_modifier != nullptr); target_modifier->set_requires_grad(false); if (conf.has_input("weight", 0)) { auto* weight_modifier = GetInputArgModifierFn("weight", 0); if (weight_modifier) { weight_modifier->set_requires_grad(false); } } return Maybe::Ok(); } /* static */ Maybe NLLGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("target", 0))) << ctx->op_name() << ": expected target being integer type"; DataType input_dtype = ctx->InputDType("input", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("out_grad", 0), input_dtype) << ctx->op_name() << ": expected out_grad dtype " << input_dtype << ", got " << ctx->InputDType("out_grad", 0); if (ctx->has_input("weight", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("weight", 0), input_dtype) << ctx->op_name() << ": expected weight dtype " << input_dtype << ", got " << ctx->InputDType("weight", 0); } ctx->SetOutputDType("in_grad", 0, input_dtype); return Maybe::Ok(); } /* static */ Maybe NLLGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); const auto& out_grad_desc = ctx->InputTensorDesc("out_grad", 0); bool is_dynamic = input_desc.is_dynamic(); CHECK_EQ_OR_RETURN(target_desc.is_dynamic(), is_dynamic) << ctx->op_name() << ": expected target dynamic " << is_dynamic; CHECK_EQ_OR_RETURN(out_grad_desc.is_dynamic(), is_dynamic) << ctx->op_name() << ": expected out_grad dynamic " << is_dynamic; const int64_t N = target_desc.shape().elem_cnt(); CHECK_EQ_OR_RETURN(out_grad_desc.shape().elem_cnt(), N) << ctx->op_name() << ": expected out_grad size " << N << ", got " << out_grad_desc.shape().ToString(); const int64_t C = input_desc.shape().At(input_desc.shape().NumAxes() - 1); CHECK_EQ_OR_RETURN(input_desc.shape().elem_cnt(), N * C) << ctx->op_name() << ": expected input size " << N << ", got " << input_desc.shape().ToString(); if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight_desc.shape().elem_cnt(), C) << ctx->op_name() << ": expected weight size " << C << ", got " << weight_desc.shape().ToString(); } user_op::TensorDesc* in_grad_desc = ctx->MutOutputTensorDesc("in_grad", 0); in_grad_desc->set_is_dynamic(is_dynamic); 
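  // Reading aid for the SBP signatures declared in NLLOp::GetSbp above and
  // NLLGradOp::GetSbp below (a summary, not original source): the first builder splits
  // input/target/output along batch axis 0 and broadcasts weight (plain data parallelism);
  // the second splits input along the class axis, broadcasts target, and marks the
  // per-sample outputs as PartialSum, since each rank only holds a slice of the classes
  // and the per-rank results are summed to obtain the full logical value.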
in_grad_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } /* static */ Maybe NLLGradOp::GetSbp(user_op::SbpContext* ctx) { // split batch dim auto builder1 = ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) .Split(user_op::OpArg("out_grad", 0), 0) .Split(user_op::OpArg("in_grad", 0), 0); if (ctx->user_op_conf().has_input("weight", 0)) { builder1.Broadcast(user_op::OpArg("weight", 0)); } builder1.Build(); // split class dim const auto& shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); auto builder2 = ctx->NewBuilder() .Split(user_op::OpArg("input", 0), shape.NumAxes() - 1) .Broadcast(user_op::OpArg("target", 0)) .Broadcast(user_op::OpArg("out_grad", 0)) .Split(user_op::OpArg("in_grad", 0), shape.NumAxes() - 1); if (ctx->user_op_conf().has_input("weight", 0)) { builder2.Split(user_op::OpArg("weight", 0), 0); } builder2.Build(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/nms_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferNmsTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, Shape({ctx->InputShape("in", 0).At(0)})); return Maybe::Ok(); } Maybe InferNmsDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kInt8); return Maybe::Ok(); } } // namespace /* static */ Maybe NmsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferNmsTensorDesc(ctx); } /*static*/ Maybe NmsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe NmsOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe NmsOp::InferDataType(user_op::InferContext* ctx) { return InferNmsDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/nn_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
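// A minimal, standalone sketch (not taken from the OneFlow sources) of the "same"
// padding arithmetic implemented by CalcOutAndPadding / CalcSamePadding below: the
// output size is ceil(input / stride), the total padding is whatever is needed to
// realize that many windows, and the odd unit of padding goes to the "after" side.
// SamePaddingSplit is an illustrative helper name, not an OneFlow API.

#include <algorithm>
#include <cstdint>
#include <utility>

// Returns {padding_before, padding_after} for "same" padding.
std::pair<int32_t, int32_t> SamePaddingSplit(int64_t input_size, int32_t filter_size,
                                             int32_t dilation_rate, int32_t stride) {
  const int32_t effective_filter = (filter_size - 1) * dilation_rate + 1;
  const int64_t output_size = (input_size + stride - 1) / stride;  // ceil division
  const int32_t needed = static_cast<int32_t>(
      std::max<int64_t>(0, (output_size - 1) * stride + effective_filter - input_size));
  return {needed / 2, needed - needed / 2};
}

// Example: input 224, filter 3, dilation 1, stride 2 -> output 112, needed = 1,
// so padding_before = 0 and padding_after = 1.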
*/ #include "oneflow/user/ops/nn_util.h" namespace oneflow { Maybe CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, const std::string& padding_type, int64_t* output_size, int32_t* padding_before, int32_t* padding_after) { CHECK_GT_OR_RETURN(stride, 0); CHECK_GE_OR_RETURN(dilation_rate, 1); int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; if (padding_type == "valid") { if (output_size) { *output_size = (input_size - effective_filter_size + stride) / stride; } if (padding_before) { *padding_before = 0; } if (padding_after) { *padding_after = 0; } } else if (padding_type == "same") { int64_t tmp_output_size = (input_size + stride - 1) / stride; if (output_size) { *output_size = tmp_output_size; } const int32_t padding_needed = std::max( 0, static_cast((tmp_output_size - 1) * stride + effective_filter_size - input_size)); // For odd values of total padding, add more padding at the 'right' // side of the given dimension. if (padding_before) { *padding_before = padding_needed / 2; } if (padding_after) { *padding_after = padding_needed - padding_needed / 2; } } else { UNIMPLEMENTED(); } if (output_size) { CHECK_GE_OR_RETURN((*output_size), 0); } return Maybe::Ok(); } Maybe CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, int32_t* padding_small, int32_t* padding_large) { CHECK_GT_OR_RETURN(stride, 0); CHECK_GE_OR_RETURN(dilation_rate, 1); int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; int64_t tmp_output_size = (input_size + stride - 1) / stride; const int32_t padding_needed = std::max( 0, static_cast((tmp_output_size - 1) * stride + effective_filter_size - input_size)); if (padding_small) { *padding_small = padding_needed / 2; } if (padding_large) { *padding_large = padding_needed - padding_needed / 2; } return Maybe::Ok(); } Maybe CalcConvOut(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, int32_t padding_before, int64_t* output_size) { CHECK_GT_OR_RETURN(stride, 0); CHECK_GE_OR_RETURN(dilation_rate, 1); int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; if (output_size) { *output_size = (input_size + 2 * padding_before - effective_filter_size + stride) / stride; CHECK_GE_OR_RETURN((*output_size), 0); } return Maybe::Ok(); } const size_t IdxOffset(const std::string& data_format) { if (data_format == "channels_first") { return 2; } else if (data_format == "channels_last") { return 1; } else { UNIMPLEMENTED(); } } const int32_t ChannelIdx(const std::string& data_format, int32_t num_axes) { if (data_format == "channels_first") { return 1; } else if (data_format == "channels_last") { return num_axes - 1; } else { UNIMPLEMENTED(); } } } // namespace oneflow ================================================ FILE: oneflow/user/ops/nn_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_OPS_NN_UTIL_H_ #define ONEFLOW_USER_OPS_NN_UTIL_H_ #include "oneflow/core/framework/framework.h" namespace oneflow { Maybe CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, const std::string& padding_type, int64_t* output_size, int32_t* padding_before, int32_t* padding_after); Maybe CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, int32_t* padding_small, int32_t* padding_large); Maybe CalcConvOut(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride, int32_t padding_before, int64_t* output_size); const size_t IdxOffset(const std::string& data_format); const int32_t ChannelIdx(const std::string& data_format, int32_t num_axes); } // namespace oneflow #endif // ONEFLOW_USER_OPS_NN_UTIL_H_ ================================================ FILE: oneflow/user/ops/noncontiguous_binary_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe NonContiguousBinaryOp::GetSbp(user_op::SbpContext* ctx) { // only support broadcast ctx->NewBuilder() .Broadcast(user_op::OpArg("lhs", 0)) .Broadcast(user_op::OpArg("rhs", 0)) .Broadcast(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe NonContiguousBinaryOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& lhs = ctx->InputShape("lhs", 0); const Shape& rhs = ctx->InputShape("rhs", 0); CHECK_EQ(lhs.NumAxes(), rhs.NumAxes()); for (int i = 0; i < lhs.NumAxes(); i++) CHECK_EQ(lhs.At(i), rhs.At(i)); ctx->SetOutputShape("y", 0, lhs); const bool inplace = ctx->Attr("inplace"); if (inplace) { ctx->SetOutputStride("y", 0, ctx->InputStride("lhs", 0)); } else { // set contiguous for y if not inplace ctx->SetOutputStride("y", 0, Stride(lhs)); } return Maybe::Ok(); } /*static*/ Maybe NonContiguousBinaryOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe NonContiguousBinaryOp::InferDataType(user_op::InferContext* ctx) { auto lhs = ctx->InputDType("lhs", 0); auto rhs = ctx->InputDType("rhs", 0); ctx->SetOutputDType("y", 0, GetSizeOfDataType(lhs) >= GetSizeOfDataType(rhs) ? 
lhs : rhs); return Maybe::Ok(); } /*static*/ Maybe NonContiguousBinaryOpGrad::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Broadcast(user_op::OpArg("lhs", 0)) .Broadcast(user_op::OpArg("rhs", 0)) .Broadcast(user_op::OpArg("dy", 0)) .Broadcast(user_op::OpArg("dlhs", 0)) .Broadcast(user_op::OpArg("drhs", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe NonContiguousBinaryOpGrad::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& lhs = ctx->InputShape("lhs", 0); const Shape& rhs = ctx->InputShape("rhs", 0); CHECK_EQ(lhs.NumAxes(), rhs.NumAxes()); for (int i = 0; i < lhs.NumAxes(); i++) CHECK_EQ(lhs.At(i), rhs.At(i)); ctx->SetOutputShape("dlhs", 0, lhs); ctx->SetOutputStride("dlhs", 0, ctx->InputStride("lhs", 0)); ctx->SetOutputShape("drhs", 0, rhs); ctx->SetOutputStride("drhs", 0, ctx->InputStride("rhs", 0)); return Maybe::Ok(); } /*static*/ Maybe NonContiguousBinaryOpGrad::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe NonContiguousBinaryOpGrad::InferDataType(user_op::InferContext* ctx) { auto lhs = ctx->InputDType("lhs", 0); auto rhs = ctx->InputDType("rhs", 0); ctx->SetOutputDType("dlhs", 0, lhs); ctx->SetOutputDType("drhs", 0, rhs); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/normalization_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cudnn_util.h" #endif namespace oneflow { namespace { std::function(const std::string&)> MakeCheckParamTensorDescFn( user_op::InferContext* ctx, const Shape& shape) { return [=](const std::string& bn) -> Maybe { if (ctx->has_input(bn, 0)) { const auto& tensor_desc = ctx->InputTensorDesc(bn, 0); CHECK_EQ_OR_RETURN(tensor_desc.shape(), shape); } return Maybe::Ok(); }; } std::function(const std::string&)> MakeCheckParamDataTypeFn(user_op::InferContext* ctx, DataType data_type) { return [=](const std::string& bn) -> Maybe { if (ctx->has_input(bn, 0)) { const auto& tensor_desc = ctx->InputTensorDesc(bn, 0); CHECK_EQ_OR_RETURN(tensor_desc.data_type(), data_type) << "InferDataType Failed. 
Expected " << DataType_Name(tensor_desc.data_type()) << ", but got " << DataType_Name(data_type); } return Maybe::Ok(); }; } std::function(const std::string&)> MakeSetParamTensorDescFn(user_op::InferContext* ctx, const Shape& shape) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr); tensor_desc->set_shape(shape); } return Maybe::Ok(); }; } std::function(const std::string&)> MakeSetParamDataTypeFn(user_op::InferContext* ctx, DataType data_type) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr); tensor_desc->set_data_type(data_type); } return Maybe::Ok(); }; } Maybe FwInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { bool training = true; if (conf.op_type_name() == "normalization" || conf.op_type_name() == "normalization_add_relu") { training = conf.attr("training"); } if (conf.has_input("moving_mean", 0)) { CHECK_OR_RETURN(conf.has_input("moving_variance", 0)); user_op::InputArgModifier* moving_mean_modifier = GetInputArgModifierFn("moving_mean", 0); CHECK_OR_RETURN(moving_mean_modifier != nullptr); moving_mean_modifier->set_is_mutable(training); moving_mean_modifier->set_requires_grad(false); user_op::InputArgModifier* moving_variance_modifier = GetInputArgModifierFn("moving_variance", 0); CHECK_OR_RETURN(moving_variance_modifier != nullptr); moving_variance_modifier->set_is_mutable(training); moving_variance_modifier->set_requires_grad(false); } else { CHECK_OR_RETURN(training) << "Must have moving mean and moving variance for normalization in inference mode."; } return Maybe::Ok(); } Maybe FwGetSbpFn(user_op::SbpContext* ctx) { std::vector split_args; split_args.emplace_back("x", 0); split_args.emplace_back("y", 0); if (ctx->user_op_conf().has_input("addend", 0)) { split_args.emplace_back("addend", 0); } if (ctx->user_op_conf().has_input("_add_to_output", 0)) { split_args.emplace_back("_add_to_output", 0); } std::vector broadcast_args; broadcast_args.emplace_back("moving_mean", 0); broadcast_args.emplace_back("moving_variance", 0); broadcast_args.emplace_back("gamma", 0); broadcast_args.emplace_back("beta", 0); if (ctx->user_op_conf().has_output("mean", 0)) { broadcast_args.emplace_back("mean", 0); } if (ctx->user_op_conf().has_output("inv_variance", 0)) { broadcast_args.emplace_back("inv_variance", 0); } if (ctx->user_op_conf().has_output("reserve_space", 0)) { broadcast_args.emplace_back("reserve_space", 0); } ctx->NewBuilder().Broadcast(broadcast_args).Split(split_args, 0).Build(); return Maybe::Ok(); } user_op::TensorDescInferFn MakeFwTensorDescInferFn( const std::function(user_op::InferContext* ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space)>& reserve_space_infer_fn) { return [reserve_space_infer_fn](user_op::InferContext* ctx) -> Maybe { #ifdef WITH_CUDA // assume cudnn is enabled CHECK_GE_OR_RETURN(ctx->Attr("epsilon"), CUDNN_BN_MIN_EPSILON); #endif const auto& x = ctx->InputTensorDesc("x", 0); const auto data_type = x.data_type(); const Shape& x_shape = x.shape(); if (ctx->has_input("addend", 0)) { const auto& addend = ctx->InputTensorDesc("addend", 0); CHECK_EQ_OR_RETURN(addend.data_type(), data_type) << "InferDataType Failed. 
Expected " << DataType_Name(addend.data_type()) << ", but got " << DataType_Name(data_type); CHECK_EQ_OR_RETURN(addend.shape(), x_shape); } if (ctx->has_input("_add_to_output", 0)) { const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.data_type(), data_type) << "InferDataType Failed. Expected " << DataType_Name(add_to_output.data_type()) << ", but got " << DataType_Name(data_type); CHECK_EQ_OR_RETURN(add_to_output.shape(), x_shape); } *ctx->MutOutputTensorDesc("y", 0) = x; const auto axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()); const Shape param_shape({x_shape.At(axis)}); const auto CheckParamTensorDesc = MakeCheckParamTensorDescFn(ctx, param_shape); const auto SetParamTensorDesc = MakeSetParamTensorDescFn(ctx, param_shape); if (ctx->has_input("moving_mean", 0)) { CHECK_OR_RETURN(ctx->has_input("moving_variance", 0)); JUST(CheckParamTensorDesc("moving_mean")); JUST(CheckParamTensorDesc("moving_variance")); } JUST(CheckParamTensorDesc("beta")); JUST(CheckParamTensorDesc("gamma")); JUST(SetParamTensorDesc("mean")); JUST(SetParamTensorDesc("inv_variance")); if (ctx->has_output("reserve_space", 0)) { CHECK_OR_RETURN(reserve_space_infer_fn); reserve_space_infer_fn(ctx, &x, ctx->MutOutputTensorDesc("reserve_space", 0)); } return Maybe::Ok(); }; } user_op::DataTypeInferFn MakeFwDataTypeInferFn( const std::function(user_op::InferContext* ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space)>& reserve_space_infer_fn) { return [reserve_space_infer_fn](user_op::InferContext* ctx) -> Maybe { const auto& x = ctx->InputTensorDesc("x", 0); const auto data_type = x.data_type(); if (ctx->has_input("addend", 0)) { const auto& addend = ctx->InputTensorDesc("addend", 0); CHECK_EQ_OR_RETURN(addend.data_type(), data_type) << "InferDataType Failed. Expected " << DataType_Name(data_type) << ", but got " << DataType_Name(addend.data_type()); } if (ctx->has_input("_add_to_output", 0)) { const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.data_type(), data_type) << "InferDataType Failed. Expected " << DataType_Name(data_type) << ", but got " << DataType_Name(add_to_output.data_type()); } *ctx->MutOutputTensorDesc("y", 0) = x; const DataType param_data_type = (data_type == DataType::kFloat16 || data_type == DataType::kBFloat16) ? 
DataType::kFloat : data_type; const auto CheckParamDataType = MakeCheckParamDataTypeFn(ctx, param_data_type); const auto SetParamDataType = MakeSetParamDataTypeFn(ctx, param_data_type); if (ctx->has_input("moving_mean", 0)) { CHECK_OR_RETURN(ctx->has_input("moving_variance", 0)); JUST(CheckParamDataType("moving_mean")); JUST(CheckParamDataType("moving_variance")); } CHECK_OR_RETURN(ctx->has_input("gamma", 0)); JUST(CheckParamDataType("beta")); JUST(CheckParamDataType("gamma")); JUST(SetParamDataType("mean")); JUST(SetParamDataType("inv_variance")); if (ctx->has_output("reserve_space", 0)) { CHECK_OR_RETURN(reserve_space_infer_fn); reserve_space_infer_fn(ctx, &x, ctx->MutOutputTensorDesc("reserve_space", 0)); } return Maybe::Ok(); }; } user_op::TensorDescInferFn MakeFwTensorDescInferFn() { return MakeFwTensorDescInferFn( std::function(user_op::InferContext * ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space)>()); } user_op::DataTypeInferFn MakeFwDataTypeInferFn() { return MakeFwDataTypeInferFn( std::function(user_op::InferContext * ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space)>()); } } // namespace /* static */ Maybe NormalizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return MakeFwTensorDescInferFn()(ctx); } /*static*/ Maybe NormalizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe NormalizationOp::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } /* static */ Maybe NormalizationOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return FwInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe NormalizationOp::InferDataType(user_op::InferContext* ctx) { return MakeFwDataTypeInferFn()(ctx); } /* static */ Maybe NormalizationAddReluOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space) -> Maybe { const auto& x_desc = ctx->InputTensorDesc("x", 0); size_t reserve_space_bits = x_desc.shape().elem_cnt(); int64_t parallel_num = ctx->parallel_num(); if (parallel_num != 1) { // There no need to call NdSbp4ArgNameAndIndex when parallel_num = 1 in local. 
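      // Worked example of the reserve_space sizing below (a reading aid, not original
      // source): normalization_add_relu keeps roughly one bit per element of x, packed
      // into int32 words. For a logical x with 1000 elements split S(0) across 4 devices,
      // each rank covers 1000 / 4 = 250 elements, and RoundUp(250, 32) / 32 = 8, so
      // reserve_space is inferred as a tensor of shape (8,) with data type kInt32.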
const NdSbp& x_nd_sbp = ctx->NdSbp4ArgNameAndIndex("x", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); int64_t split_num = 1; for (int32_t i = 0; i < x_nd_sbp.sbp_parallel_size(); ++i) { if (x_nd_sbp.sbp_parallel(i).has_split_parallel()) { CHECK_EQ_OR_RETURN(x_nd_sbp.sbp_parallel(i).split_parallel().axis(), 0) << "blob x in NormalizationAddReluOp only support B or S(0)"; split_num *= hierarchy.At(i); } } CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0); reserve_space_bits = reserve_space_bits / split_num; } reserve_space->set_shape(Shape({static_cast(RoundUp(reserve_space_bits, 32) / 32)})); return Maybe::Ok(); })(ctx); } /* static */ Maybe NormalizationAddReluOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space) -> Maybe { const auto& x_desc = ctx->InputTensorDesc("x", 0); reserve_space->set_shape( Shape({static_cast(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)})); return Maybe::Ok(); })(ctx); } /* static */ Maybe NormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } /* static */ Maybe NormalizationAddReluOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return FwInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) { return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space) -> Maybe { reserve_space->set_data_type(DataType::kInt32); return Maybe::Ok(); })(ctx); } #if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401) namespace { void InferCudnnReserveSpaceSize(DataType data_type, cudnnBatchNormOps_t ops, int64_t n, int64_t c, int64_t h, int64_t w, size_t* reserve_space_size) { cudnnHandle_t cudnn_handle = Singleton::Get()->Get(); CudnnTensorDesc xy_desc(CUDNN_TENSOR_NHWC, data_type, n, c, h, w); CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0); OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, activation_desc.Get(), xy_desc.Get(), reserve_space_size)); Singleton::Get()->Put(cudnn_handle); } } // namespace /* static */ Maybe CudnnFusedNormalizationAddReluOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space) -> Maybe { const Shape& x_shape = x->shape(); const auto axis = ctx->Attr("axis"); CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); int64_t n = x_shape.At(0); { const auto& x_nd_sbp = ctx->NdSbp4ArgNameAndIndex("x", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); int64_t split_num = 1; for (int32_t i = 0; i < x_nd_sbp.sbp_parallel_size(); ++i) { if (x_nd_sbp.sbp_parallel(i).has_split_parallel()) { CHECK_EQ_OR_RETURN(x_nd_sbp.sbp_parallel(i).split_parallel().axis(), 0) << "blob x in CudnnFusedNormalizationAddReluOp only support B or S(0)"; split_num *= hierarchy.At(i); } } CHECK_EQ_OR_RETURN(n % split_num, 0); n = n / split_num; } int64_t h = x_shape.Count(1, axis); int64_t w = 1; int64_t c = x_shape.At(axis); cudnnBatchNormOps_t ops; if (ctx->has_input("addend", 0)) { ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; } else { ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; } size_t reserve_space_size; InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, 
&reserve_space_size); reserve_space_size = std::max(reserve_space_size, GetOneVal()); reserve_space->set_shape(Shape({static_cast(reserve_space_size)})); return Maybe::Ok(); })(ctx); } /* static */ Maybe CudnnFusedNormalizationAddReluOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space) -> Maybe { const Shape& x_shape = x->shape(); const auto axis = ctx->Attr("axis"); CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); int64_t n = x_shape.At(0); int64_t h = x_shape.Count(1, axis); int64_t w = 1; int64_t c = x_shape.At(axis); cudnnBatchNormOps_t ops; if (ctx->has_input("addend", 0)) { ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; } else { ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; } size_t reserve_space_size; InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); reserve_space_size = std::max(reserve_space_size, GetOneVal()); reserve_space->set_shape(Shape({static_cast(reserve_space_size)})); return Maybe::Ok(); })(ctx); } /* static */ Maybe CudnnFusedNormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } /* static */ Maybe CudnnFusedNormalizationAddReluOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return FwInputArgModifyFn(GetInputArgModifierFn, conf); } /* static */ Maybe CudnnFusedNormalizationAddReluOp::InferDataType( user_op::InferContext* ctx) { return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, user_op::TensorDesc* reserve_space) -> Maybe { reserve_space->set_data_type(DataType::kChar); return Maybe::Ok(); })(ctx); } #else /* static */ Maybe CudnnFusedNormalizationAddReluOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } /* static */ Maybe CudnnFusedNormalizationAddReluOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } /* static */ Maybe CudnnFusedNormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } /* static */ Maybe CudnnFusedNormalizationAddReluOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } /* static */ Maybe CudnnFusedNormalizationAddReluOp::InferDataType( user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } #endif // WITH_CUDA namespace { Maybe BwTensorDescInferFn(user_op::InferContext* ctx) { #ifdef WITH_CUDA // assume cudnn is enabled CHECK_GE_OR_RETURN(ctx->Attr("epsilon"), CUDNN_BN_MIN_EPSILON); #endif const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const Shape& x_shape = x.shape(); const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(dy.shape(), x_shape); if (ctx->has_input("y", 0)) { const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(y.shape(), x_shape); } *ctx->MutOutputTensorDesc("dx", 0) = x; if (ctx->has_output("addend_diff", 0)) { *ctx->MutOutputTensorDesc("addend_diff", 0) = x; } const Shape param_shape({x_shape.At(ctx->Attr("axis"))}); const auto CheckParamTensorDesc = MakeCheckParamTensorDescFn(ctx, param_shape); const auto SetParamTensorDesc = MakeSetParamTensorDescFn(ctx, param_shape); 
JUST(CheckParamTensorDesc("mean")); JUST(CheckParamTensorDesc("inv_variance")); JUST(CheckParamTensorDesc("gamma")); JUST(CheckParamTensorDesc("beta")); JUST(SetParamTensorDesc("gamma_diff")); JUST(SetParamTensorDesc("beta_diff")); return Maybe::Ok(); } Maybe BwDataTypeInferFn(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const DataType x_type = x.data_type(); const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(dy.data_type(), x_type) << "InferDataType Failed. Expected " << DataType_Name(x_type) << ", but got " << DataType_Name(dy.data_type()); if (ctx->has_input("y", 0)) { const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(y.data_type(), x_type) << "InferDataType Failed. Expected " << DataType_Name(x_type) << ", but got " << DataType_Name(y.data_type()); } *ctx->MutOutputTensorDesc("dx", 0) = x; if (ctx->has_output("addend_diff", 0)) { *ctx->MutOutputTensorDesc("addend_diff", 0) = x; } const DataType param_data_type = (x_type == DataType::kFloat16 || x_type == DataType::kBFloat16) ? DataType::kFloat : x_type; const auto CheckParamDataType = MakeCheckParamDataTypeFn(ctx, param_data_type); const auto SetParamDataType = MakeSetParamDataTypeFn(ctx, param_data_type); JUST(CheckParamDataType("mean")); JUST(CheckParamDataType("inv_variance")); JUST(CheckParamDataType("gamma")); JUST(CheckParamDataType("beta")); JUST(SetParamDataType("gamma_diff")); JUST(SetParamDataType("beta_diff")); return Maybe::Ok(); } Maybe BwGetSbpFn(user_op::SbpContext* ctx) { std::vector broadcast_args; broadcast_args.emplace_back("mean", 0); broadcast_args.emplace_back("inv_variance", 0); broadcast_args.emplace_back("gamma", 0); if (ctx->user_op_conf().has_input("beta", 0)) { broadcast_args.emplace_back("beta", 0); } if (ctx->user_op_conf().has_input("reserve_space", 0)) { broadcast_args.emplace_back("reserve_space", 0); } std::vector partial_sum_args; partial_sum_args.emplace_back("gamma_diff", 0); partial_sum_args.emplace_back("beta_diff", 0); std::vector split_args; split_args.emplace_back("x", 0); split_args.emplace_back("dy", 0); split_args.emplace_back("dx", 0); if (ctx->user_op_conf().has_input("y", 0)) { split_args.emplace_back("y", 0); } if (ctx->user_op_conf().has_output("addend_diff", 0)) { split_args.emplace_back("addend_diff", 0); } ctx->NewBuilder() .Broadcast(broadcast_args) .PartialSum(partial_sum_args) .Split(split_args, 0) .Build(); return Maybe::Ok(); } } // namespace /* static */ Maybe NormalizationGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return BwTensorDescInferFn(ctx); } /*static*/ Maybe NormalizationGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe NormalizationGradOp::GetSbp(user_op::SbpContext* ctx) { return BwGetSbpFn(ctx); } /* static */ Maybe NormalizationGradOp::InferDataType(user_op::InferContext* ctx) { return BwDataTypeInferFn(ctx); } /* static */ Maybe NormalizationAddReluGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return BwTensorDescInferFn(ctx); } /*static*/ Maybe NormalizationAddReluGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe NormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { return BwGetSbpFn(ctx); } /* static */ Maybe NormalizationAddReluGradOp::InferDataType(user_op::InferContext* ctx) { return BwDataTypeInferFn(ctx); } #if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401) /* static */ Maybe 
CudnnFusedNormalizationAddReluGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return BwTensorDescInferFn(ctx); } /*static*/ Maybe CudnnFusedNormalizationAddReluGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe CudnnFusedNormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { return BwGetSbpFn(ctx); } /* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferDataType( user_op::InferContext* ctx) { return BwDataTypeInferFn(ctx); } #else /* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } /*static*/ Maybe CudnnFusedNormalizationAddReluGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } /* static */ Maybe CudnnFusedNormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } /* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferDataType( user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; } #endif } // namespace oneflow ================================================ FILE: oneflow/user/ops/nvtx_range_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { #ifdef WITH_CUDA /* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe NvtxStartOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe NvtxStartOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe NvtxStartOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe NvtxEndOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe NvtxEndOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe NvtxEndOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } #else /* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA to use NVTX"; } /*static*/ Maybe NvtxStartOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe NvtxStartOp::GetSbp(user_op::SbpContext* ctx) { return Error::UnimplementedError() << "require CUDA to use NVTX"; } /* static */ Maybe NvtxStartOp::InferDataType(user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA to use NVTX"; } /* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA to use NVTX"; } /*static*/ Maybe NvtxEndOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA to use NVTX"; } /* static */ Maybe NvtxEndOp::GetSbp(user_op::SbpContext* ctx) { return Error::UnimplementedError() << "require CUDA to use NVTX"; } /* static */ Maybe NvtxEndOp::InferDataType(user_op::InferContext* ctx) { return Error::UnimplementedError() << "require CUDA to use NVTX"; } #endif // WITH_CUDA } // namespace oneflow ================================================ FILE: oneflow/user/ops/ofrecord_decoder_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> OfrecordRawDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0);
  CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1);
  Shape conf_shape = ctx->Attr<Shape>("shape");
  DimVector dim_vec(1 + conf_shape.NumAxes());
  dim_vec[0] = in_tensor.shape().At(0);
  for (int i = 1; i < dim_vec.size(); ++i) { dim_vec[i] = conf_shape.At(i - 1); }
  out_tensor->set_shape(Shape(dim_vec));
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> OfrecordRawDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/* static */ Maybe<void> OfrecordRawDecoderOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build();
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> OfrecordRawDecoderOp::ModifyInputArg(
    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0);
  CHECK_NOTNULL_OR_RETURN(in_modifier);
  in_modifier->set_requires_grad(false);
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> OfrecordRawDecoderOp::InferDataType(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0);
  CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord);
  out_tensor->set_data_type(ctx->Attr<DataType>("data_type"));
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> OfrecordBytesDecoderOp::InferLogicalTensorDesc(
    user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0);
  out->set_is_dynamic(in.is_dynamic());
  out->set_shape(in.shape());
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> OfrecordBytesDecoderOp::InferPhysicalTensorDesc(
    user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/* static */ Maybe<void> OfrecordBytesDecoderOp::GetSbp(user_op::SbpContext* ctx) {
  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);
}
/* static */ Maybe<void> OfrecordBytesDecoderOp::ModifyInputArg(
    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0);
  CHECK_NOTNULL_OR_RETURN(in_modifier);
  in_modifier->set_requires_grad(false);
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> OfrecordBytesDecoderOp::InferDataType(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0);
  CHECK_OR_RETURN(in.data_type() == DataType::kOFRecord);
  out->set_data_type(DataType::kTensorBuffer);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> OfrecordImageDecoderOp::InferLogicalTensorDesc(
    user_op::InferContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); out_tensor->set_shape(in_tensor.shape()); return Maybe::Ok(); } /*static*/ Maybe OfrecordImageDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OfrecordImageDecoderOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); return Maybe::Ok(); } /* static */ Maybe OfrecordImageDecoderOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe OfrecordImageDecoderOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); out_tensor->set_data_type(DataType::kTensorBuffer); return Maybe::Ok(); } /* static */ Maybe OfrecordImageDecoderRandomCropOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); out_tensor->set_shape(in_tensor.shape()); return Maybe::Ok(); } /*static*/ Maybe OfrecordImageDecoderRandomCropOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OfrecordImageDecoderRandomCropOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); return Maybe::Ok(); } /* static */ Maybe OfrecordImageDecoderRandomCropOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe OfrecordImageDecoderRandomCropOp::InferDataType( user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); out_tensor->set_data_type(DataType::kTensorBuffer); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/ofrecord_image_classification_reader_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe OfrecordImageClassificationReaderOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { user_op::TensorDesc* image_tensor = ctx->MutOutputTensorDesc("image", 0); user_op::TensorDesc* label_tensor = ctx->MutOutputTensorDesc("label", 0); int32_t batch_size = ctx->Attr("batch_size"); image_tensor->set_shape(Shape({batch_size})); label_tensor->set_shape(Shape({batch_size})); return Maybe::Ok(); } /* static */ Maybe OfrecordImageClassificationReaderOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { user_op::TensorDesc* image_tensor = ctx->MutOutputTensorDesc("image", 0); user_op::TensorDesc* label_tensor = ctx->MutOutputTensorDesc("label", 0); int32_t local_batch_size = ctx->Attr("batch_size"); int64_t parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { int64_t split_num = 1; const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("image", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= hierarchy.At(i); } } CHECK_EQ_OR_RETURN(local_batch_size % split_num, 0); local_batch_size /= split_num; } image_tensor->set_shape(Shape({local_batch_size})); label_tensor->set_shape(Shape({local_batch_size})); return Maybe::Ok(); } /* static */ Maybe OfrecordImageClassificationReaderOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe OfrecordImageClassificationReaderOp::ModifyOutputArg( const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); CHECK_OR_RETURN(image_modifier != nullptr); image_modifier->set_header_infered_before_compute(false); user_op::OutputArgModifier* label_modifier = GetOutputArgModifierFn("label", 0); CHECK_OR_RETURN(label_modifier != nullptr); label_modifier->set_header_infered_before_compute(false); return Maybe::Ok(); } /* static */ Maybe OfrecordImageClassificationReaderOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("image", 0, DataType::kTensorBuffer); ctx->SetOutputDType("label", 0, DataType::kTensorBuffer); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/ofrecord_reader_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe OFRecordReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); out_tensor->set_shape(Shape({ctx->Attr("batch_size")})); return Maybe::Ok(); } /* static */ Maybe OFRecordReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int32_t batch_size = ctx->Attr("batch_size"); int64_t parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { int64_t split_num = 1; const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= hierarchy.At(i); } } CHECK_EQ_OR_RETURN(batch_size % split_num, 0); batch_size /= split_num; } out_tensor->set_shape(Shape({batch_size})); return Maybe::Ok(); } /* static */ Maybe OFRecordReaderOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe OFRecordReaderOp::GetNdSbpSignatureList( user_op::GetNdSbpSignatureListContext* ctx) { NdSbpSignature nd_sbp_signature; SbpParallel split_sbp_parallel; split_sbp_parallel.mutable_split_parallel()->set_axis(0); for (int32_t dim_sbp = 0; dim_sbp < ctx->parallel_hierarchy().NumAxes(); dim_sbp++) { *(*nd_sbp_signature.mutable_bn_in_op2nd_sbp())[GenRepeatedBn("out", 0)].add_sbp_parallel() = split_sbp_parallel; } ctx->AddNdSbpSignature(nd_sbp_signature); return Maybe::Ok(); } /* static */ Maybe OFRecordReaderOp::GetComputeComplexity( user_op::ComputeComplexityFnContext* ctx) { // Don't support broadcast. return double(ctx->Shape4ArgNameAndIndex("out", 0).elem_cnt() * GetSizeOfDataType(DataType::kOFRecord)) / ctx->parallel_desc().hierarchy()->elem_cnt(); } /* static */ Maybe OFRecordReaderOp::ModifyOutputArg( const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); CHECK_OR_RETURN(out_modifier != nullptr); // NOTE(chengcheng): OFRecordReader Only support static shape infer which will read all batch // size data with output shape (batch_size,) // out_modifier->set_header_infered_before_compute(false); return Maybe::Ok(); } /* static */ Maybe OFRecordReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_split_parallel()->set_axis(0); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe OFRecordReaderOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kOFRecord); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/one_embedding_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/embedding/embedding_manager.h" namespace oneflow { /* static */ Maybe OneEmbeddingFusedLookupOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& ids_shape = ctx->InputShape("ids", 0); if (ctx->has_input("table_ids", 0)) { const Shape& table_ids_shape = ctx->InputShape("table_ids", 0); CHECK_EQ_OR_RETURN(ids_shape, table_ids_shape) << "table_ids shape must equal to ids shape"; } DimVector out_dim_vec = ids_shape.dim_vec(); const int64_t embedding_size = ctx->Attr("embedding_size"); out_dim_vec.push_back(embedding_size); ctx->SetOutputShape("embeddings", 0, Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingFusedLookupOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingFusedLookupOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder() .Broadcast(user_op::OpArg("shadow", 0)) .Split(user_op::OpArg("ids", 0), 0) .Split(user_op::OpArg("embeddings", 0), 0); if (ctx->user_op_conf().has_input("table_ids", 0)) { builder.Split(user_op::OpArg("table_ids", 0), 0); } builder.Build(); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingFusedLookupOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* shadow = GetInputArgModifierFn("shadow", 0); CHECK_OR_RETURN(shadow != nullptr) << "shadow is nullptr"; shadow->set_requires_grad(false); user_op::InputArgModifier* ids = GetInputArgModifierFn("ids", 0); CHECK_OR_RETURN(ids != nullptr); ids->set_requires_grad(false); if (conf.has_input("table_ids", 0)) { user_op::InputArgModifier* table_ids = GetInputArgModifierFn("table_ids", 0); CHECK_OR_RETURN(table_ids != nullptr) << "table_ids is nullptr"; table_ids->set_requires_grad(false); } return Maybe::Ok(); } /* static */ Maybe OneEmbeddingFusedLookupOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("embeddings", 0, ctx->InputDType("shadow", 0)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingFusedLookupGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingFusedLookupGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingFusedLookupGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("ids", 0), 0) .Split(user_op::OpArg("embedding_grad", 0), 0) .Build(); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingFusedLookupGradOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } /* static */ Maybe EmbeddingPrefetchOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& num_unique_ids_shape = ctx->InputShape("num_unique_ids", 0); const Shape& unique_ids_shape = ctx->InputShape("unique_ids", 0); const Shape& table_ids_shape = ctx->InputShape("table_ids", 0); CHECK_EQ_OR_RETURN(unique_ids_shape, table_ids_shape) << "table_ids shape must equal to ids shape"; 
CHECK_EQ_OR_RETURN(num_unique_ids_shape.elem_cnt(), 1); ctx->SetOutputShape("context", 0, num_unique_ids_shape); return Maybe::Ok(); } /*static*/ Maybe EmbeddingPrefetchOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EmbeddingPrefetchOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Broadcast(user_op::OpArg("num_unique_ids", 0)) .Split(user_op::OpArg("unique_ids", 0), 0) .Split(user_op::OpArg("table_ids", 0), 0) .Broadcast(user_op::OpArg("context", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe EmbeddingPrefetchOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("context", 0, ctx->InputDType("num_unique_ids", 0)); return Maybe::Ok(); } /* static */ Maybe EmbeddingLookupOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& num_unique_ids_shape = ctx->InputShape("num_unique_ids", 0); const Shape& unique_ids_shape = ctx->InputShape("unique_ids", 0); const Shape& table_ids_shape = ctx->InputShape("table_ids", 0); CHECK_EQ_OR_RETURN(unique_ids_shape, table_ids_shape) << "table_ids shape must equal to ids shape"; CHECK_EQ_OR_RETURN(num_unique_ids_shape.elem_cnt(), 1); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); CHECK_NE_OR_RETURN(embedding_size, 0); CHECK_NE_OR_RETURN(line_size, 0); CHECK_GE_OR_RETURN(line_size, embedding_size); const bool use_dynamic_memory_allocation = embedding::UseDynamicMemoryAllocation(); if (ctx->has_output("embeddings", 0)) { if (use_dynamic_memory_allocation) { ctx->SetOutputShape("embeddings", 0, Shape({1})); } else { DimVector embeddings_dim_vec = unique_ids_shape.dim_vec(); embeddings_dim_vec.push_back(embedding_size); ctx->SetOutputShape("embeddings", 0, Shape(embeddings_dim_vec)); } } if (use_dynamic_memory_allocation) { ctx->SetOutputShape("unique_values", 0, Shape({1})); } else { DimVector unique_values_dim_vec = unique_ids_shape.dim_vec(); unique_values_dim_vec.push_back(line_size); ctx->SetOutputShape("unique_values", 0, Shape(unique_values_dim_vec)); } return Maybe::Ok(); } /*static*/ Maybe EmbeddingLookupOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe EmbeddingLookupOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder() .Broadcast(user_op::OpArg("num_unique_ids", 0)) .Split(user_op::OpArg("unique_ids", 0), 0) .Split(user_op::OpArg("table_ids", 0), 0); if (ctx->user_op_conf().has_input("context", 0)) { builder.Broadcast(user_op::OpArg("context", 0)); } const bool use_dynamic_memory_allocation = embedding::UseDynamicMemoryAllocation(); if (use_dynamic_memory_allocation) { builder.Broadcast(user_op::OpArg("unique_values", 0)); } else { builder.Split(user_op::OpArg("unique_values", 0), 0); } if (ctx->user_op_conf().has_output("embeddings", 0)) { if (use_dynamic_memory_allocation) { builder.Broadcast(user_op::OpArg("embeddings", 0)); } else { builder.Split(user_op::OpArg("embeddings", 0), 0); } } builder.Build(); return Maybe::Ok(); } /* static */ Maybe EmbeddingLookupOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("unique_values", 0, ctx->Attr("dtype")); if (ctx->has_output("embeddings", 0)) { ctx->SetOutputDType("embeddings", 0, ctx->Attr("embeddings_dtype")); } return Maybe::Ok(); } /* static */ Maybe EmbeddingPutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe EmbeddingPutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { 
return InferLogicalTensorDesc(ctx); } /* static */ Maybe EmbeddingPutOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder() .Broadcast(user_op::OpArg("num_unique_ids", 0)) .Split(user_op::OpArg("unique_ids", 0), 0); if (embedding::UseDynamicMemoryAllocation()) { builder.Broadcast(user_op::OpArg("unique_embeddings", 0)).Build(); } else { builder.Split(user_op::OpArg("unique_embeddings", 0), 0).Build(); } return Maybe::Ok(); } /* static */ Maybe EmbeddingPutOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } Maybe CheckDataShape(user_op::InferContext* ctx) { if (ctx->has_input("learning_rate", 0)) { CHECK_EQ_OR_RETURN(ctx->InputShape("learning_rate", 0), Shape({1})); } if (ctx->has_input("down_scale_by_tensor", 0)) { CHECK_EQ_OR_RETURN(ctx->InputShape("down_scale_by_tensor", 0), Shape({1})); } CHECK_EQ_OR_RETURN(ctx->InputShape("num_unique_ids", 0), Shape({1})); const Shape& embedding_grad_shape = ctx->InputShape("embedding_grad", 0); CHECK_EQ_OR_RETURN(embedding_grad_shape.NumAxes(), 2); const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); if (embedding::UseDynamicMemoryAllocation()) { CHECK_EQ_OR_RETURN(unique_embeddings_shape.elem_cnt(), 1) << "if use dynamic memory allocation, unique_embeddings elem_cnt should be 1."; } else { CHECK_EQ_OR_RETURN(unique_embeddings_shape.NumAxes(), 2) << "unique_embeddings num_axes should be 2."; CHECK_EQ_OR_RETURN(unique_embeddings_shape.At(0), embedding_grad_shape.At(0)) << "got " << unique_embeddings_shape.At(0) << " and " << embedding_grad_shape.At(0); } return Maybe::Ok(); } Maybe CheckDataType(user_op::InferContext* ctx) { if (ctx->has_input("learning_rate", 0)) { const DataType learning_rate_dtype = ctx->InputDType("learning_rate", 0); CHECK_EQ_OR_RETURN(learning_rate_dtype, DataType::kFloat) << "InferDataType Failed. Expected " << DataType_Name(DataType::kFloat) << ", but got " << DataType_Name(learning_rate_dtype); } if (ctx->has_input("down_scale_by_tensor", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("down_scale_by_tensor", 0), ctx->InputDType("unique_embeddings", 0)) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("unique_embeddings", 0)) << ", but got " << DataType_Name(ctx->InputDType("down_scale_by_tensor", 0)); } return Maybe::Ok(); } Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder() .Broadcast(ctx->inputs()) .Broadcast(user_op::OpArg("num_unique_ids", 0)) .Split(user_op::OpArg("embedding_grad", 0), 0); if (embedding::UseDynamicMemoryAllocation()) { builder.Broadcast(user_op::OpArg("unique_embeddings", 0)) .Broadcast(user_op::OpArg("updated_unique_embeddings", 0)) .Build(); } else { builder.Split(user_op::OpArg("unique_embeddings", 0), 0) .Split(user_op::OpArg("updated_unique_embeddings", 0), 0) .Build(); } return Maybe::Ok(); } /* static */ Maybe OneEmbeddingFusedSgdUpdatePutOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingFusedSgdUpdatePutOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingFusedSgdUpdatePutOp::GetSbp(user_op::SbpContext* ctx) { auto builder = ctx->NewBuilder() .Broadcast(user_op::OpArg("learning_rate", 0)) .Broadcast(user_op::OpArg("num_unique_ids", 0)) .Split(user_op::OpArg("unique_ids", 0), 0) .Split(user_op::OpArg("embedding_grad", 0), 0); if (embedding::UseDynamicMemoryAllocation()) { builder.Broadcast(user_op::OpArg("unique_embeddings", 0)).Build(); } else { builder.Split(user_op::OpArg("unique_embeddings", 0), 0).Build(); } return Maybe::Ok(); } /* static */ Maybe OneEmbeddingFusedSgdUpdatePutOp::InferDataType( user_op::InferContext* ctx) { return Maybe::Ok(); } /* static */ Maybe OneEmbeddingSgdUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { JUST(CheckDataShape(ctx)); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); CHECK_NE_OR_RETURN(embedding_size, 0) << "should set attr embedding_size"; CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size) << "when use SGD optimizer, line_size should equals to embedding_size, but get line_size: " << line_size << " embedding_size: " << embedding_size << ", please set size_factor of store_options to 1."; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); ctx->SetOutputShape("updated_unique_embeddings", 0, unique_embeddings_shape); return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingSgdUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingSgdUpdateOp::GetSbp(user_op::SbpContext* ctx) { JUST(GetEmbeddingUpdateSbp(ctx)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingSgdUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); ctx->SetOutputDType("updated_unique_embeddings", 0, ctx->InputDType("unique_embeddings", 0)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingMomentumUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { JUST(CheckDataShape(ctx)); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); CHECK_NE_OR_RETURN(embedding_size, 0) << "should set attr embedding_size"; CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "when using Momentum optimizer, line_size should equals to embedding_size * 2, but get " "line_size: " << line_size << " embedding_size: " << embedding_size << ", please set size_factor of 
store_options to 2."; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); ctx->SetOutputShape("updated_unique_embeddings", 0, unique_embeddings_shape); return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingMomentumUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingMomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) { JUST(GetEmbeddingUpdateSbp(ctx)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingMomentumUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); ctx->SetOutputDType("updated_unique_embeddings", 0, ctx->InputDType("unique_embeddings", 0)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingAdamUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { JUST(CheckDataShape(ctx)); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); CHECK_NE_OR_RETURN(embedding_size, 0) << "should set attr embedding_size"; CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "when using Adam optimizer, line_size should equals to embedding_size * 3, but get " "line_size: " << line_size << " embedding_size: " << embedding_size << ", please set size_factor of store_options to 3."; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); ctx->SetOutputShape("updated_unique_embeddings", 0, unique_embeddings_shape); return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingAdamUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingAdamUpdateOp::GetSbp(user_op::SbpContext* ctx) { JUST(GetEmbeddingUpdateSbp(ctx)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingAdamUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); ctx->SetOutputDType("updated_unique_embeddings", 0, ctx->InputDType("unique_embeddings", 0)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingSmartDecaySparseAdamUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { JUST(CheckDataShape(ctx)); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); CHECK_NE_OR_RETURN(embedding_size, 0) << "should set attr embedding_size"; CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; const int64_t value_dtype_size = GetSizeOfDataType(ctx->InputDType("unique_embeddings", 0)); const int64_t step_dtype_size = sizeof(int64_t); const int64_t model_and_states_bytes = embedding_size * 3 * value_dtype_size; const int64_t align_to_step_size_bytes = (model_and_states_bytes + step_dtype_size - 1) / step_dtype_size * step_dtype_size; const int64_t smart_decay_sparse_adam_line_size = (align_to_step_size_bytes + step_dtype_size) / value_dtype_size; CHECK_EQ_OR_RETURN(line_size, smart_decay_sparse_adam_line_size) << "when using SmartDecayAdam optimizer with embedding_size " << embedding_size << ", storage_dim should equals to " << smart_decay_sparse_adam_line_size << ", but got " "storage_dim: " << line_size << ", please set storage_dim of store_options to " << smart_decay_sparse_adam_line_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); ctx->SetOutputShape("updated_unique_embeddings", 0, unique_embeddings_shape); return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingSmartDecaySparseAdamUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return 
InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingSmartDecaySparseAdamUpdateOp::GetSbp( user_op::SbpContext* ctx) { JUST(GetEmbeddingUpdateSbp(ctx)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingSmartDecaySparseAdamUpdateOp::InferDataType( user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); ctx->SetOutputDType("updated_unique_embeddings", 0, ctx->InputDType("unique_embeddings", 0)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingAdagradUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { JUST(CheckDataShape(ctx)); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); CHECK_NE_OR_RETURN(embedding_size, 0) << "should set attr embedding_size"; CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "when using Adagrad optimizer, line_size should equals to embedding_size * 2, but get " "line_size: " << line_size << " embedding_size: " << embedding_size << ", please set size_factor of store_options to 2."; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); ctx->SetOutputShape("updated_unique_embeddings", 0, unique_embeddings_shape); return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingAdagradUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingAdagradUpdateOp::GetSbp(user_op::SbpContext* ctx) { JUST(GetEmbeddingUpdateSbp(ctx)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingAdagradUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); ctx->SetOutputDType("updated_unique_embeddings", 0, ctx->InputDType("unique_embeddings", 0)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingFtrlUpdateOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { JUST(CheckDataShape(ctx)); const int64_t embedding_size = ctx->Attr("embedding_size"); const int64_t line_size = ctx->Attr("line_size"); CHECK_NE_OR_RETURN(embedding_size, 0) << "should set attr embedding_size"; CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "when using Ftrl optimizer, line_size should equals to embedding_size * 3, but get " "line_size: " << line_size << " embedding_size: " << embedding_size << ", please set size_factor of store_options to 3."; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); ctx->SetOutputShape("updated_unique_embeddings", 0, unique_embeddings_shape); return Maybe::Ok(); } /*static*/ Maybe OneEmbeddingFtrlUpdateOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneEmbeddingFtrlUpdateOp::GetSbp(user_op::SbpContext* ctx) { JUST(GetEmbeddingUpdateSbp(ctx)); return Maybe::Ok(); } /* static */ Maybe OneEmbeddingFtrlUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); ctx->SetOutputDType("updated_unique_embeddings", 0, ctx->InputDType("unique_embeddings", 0)); return Maybe::Ok(); } /*static*/ Maybe IdShuffleCopyOutOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(ctx->inputs(), 0) .Split(ctx->outputs(), 0) .Broadcast(user_op::OpArg("num_unique_matrix", 0)) .Broadcast(user_op::OpArg("out_num_unique_matrix", 0)) .Broadcast(user_op::OpArg("cur_rank_num_unique", 0)) .Broadcast(user_op::OpArg("out_cur_rank_num_unique", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe 
IdShuffleCopyOutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out_num_unique_matrix", 0, ctx->InputShape("num_unique_matrix", 0)); ctx->SetOutputShape("out_inverse_unique_partition_indices", 0, ctx->InputShape("inverse_unique_partition_indices", 0)); ctx->SetOutputShape("out_cur_rank_num_unique", 0, ctx->InputShape("cur_rank_num_unique", 0)); ctx->SetOutputShape("out_cur_rank_unique_ids", 0, ctx->InputShape("cur_rank_unique_ids", 0)); ctx->SetOutputShape("out_cur_rank_unique_table_ids", 0, ctx->InputShape("cur_rank_unique_table_ids", 0)); ctx->SetOutputShape("out_cur_rank_inverse_indices", 0, ctx->InputShape("cur_rank_inverse_indices", 0)); return Maybe::Ok(); } /*static*/ Maybe IdShuffleCopyOutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe IdShuffleCopyOutOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out_num_unique_matrix", 0, ctx->InputDType("num_unique_matrix", 0)); ctx->SetOutputDType("out_inverse_unique_partition_indices", 0, ctx->InputDType("inverse_unique_partition_indices", 0)); ctx->SetOutputDType("out_cur_rank_num_unique", 0, ctx->InputDType("cur_rank_num_unique", 0)); ctx->SetOutputDType("out_cur_rank_unique_ids", 0, ctx->InputDType("cur_rank_unique_ids", 0)); ctx->SetOutputDType("out_cur_rank_unique_table_ids", 0, ctx->InputDType("cur_rank_unique_table_ids", 0)); ctx->SetOutputDType("out_cur_rank_inverse_indices", 0, ctx->InputDType("cur_rank_inverse_indices", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/one_hot_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe OneHotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const int64_t depth = ctx->Attr("depth"); CHECK_GT_OR_RETURN(depth, 0); const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0); // For 0-dim Tensor CHECK_GE_OR_RETURN(indices_desc.shape().NumAxes(), 0) << "indices dim must be great or equal than 0"; user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(indices_desc.is_dynamic()); DimVector dim_vec = indices_desc.shape().dim_vec(); dim_vec.emplace_back(depth); out_desc->set_shape(Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe OneHotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe OneHotOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& indices_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); FOR_RANGE(int64_t, i, 0, indices_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("indices", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe OneHotOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); return Maybe::Ok(); } /* static */ Maybe OneHotOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0); CHECK_OR_RETURN(IsIndexDataType(indices_desc.data_type())); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); DataType dtype = ctx->Attr("dtype"); out_desc->set_data_type(dtype); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/ones_like_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe OnesLikeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("like", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("like", 0)) .Broadcast(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe OnesLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("like", 0)); ctx->SetOutputStride("out", 0, ctx->InputStride("like", 0)); return Maybe::Ok(); } /*static*/ Maybe OnesLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return OnesLikeOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe OnesLikeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("like", 0)); return Maybe::Ok(); } /*static*/ Maybe OnesLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const NdSbp& in_sbp = ctx->NdSbpHint4InputArgNameAndIndex("like", 0); NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0); NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); *like_distribution = in_sbp; *out_distribution = in_sbp; for (auto& sbp : *out_distribution->mutable_sbp_parallel()) { if (sbp.has_partial_sum_parallel()) { sbp.Clear(); *sbp.mutable_broadcast_parallel() = BroadcastParallel(); } } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/p2p_comm_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SendOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); } /*static*/ Maybe SendOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { // Do nothing. return Maybe::Ok(); } /*static*/ Maybe SendOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return SendOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe SendOp::InferDataType(user_op::InferContext* ctx) { // Do nothing. 
return Maybe<void>::Ok();
}
/*static*/ Maybe<Symbol<Stream>> SendOp::InferDeviceAndStream(
    user_op::DeviceAndStreamInferContext* ctx) {
  return DeviceAndStreamInferFn(ctx);
}

namespace {

Maybe<Symbol<Device>> GetRecvOutputDeivce(user_op::DeviceAndStreamInferContext* ctx) {
  const std::string& device_type = ctx->Attr<std::string>("device_type");
  const int device_id = ctx->Attr<int64_t>("device_id");
  return Device::New(device_type, device_id);
}

}  // namespace

/*static*/ Maybe<void> RecvOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); }
/*static*/ Maybe<void> RecvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, ctx->Attr<Shape>("shape"));
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> RecvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return RecvOp::InferLogicalTensorDesc(ctx);
}
/*static*/ Maybe<void> RecvOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->Attr<DataType>("dtype"));
  return Maybe<void>::Ok();
}
/*static*/ Maybe<Symbol<Stream>> RecvOp::InferDeviceAndStream(
    user_op::DeviceAndStreamInferContext* ctx) {
  return DeviceAndStreamInferFn<&GetRecvOutputDeivce>(ctx);
}

}  // namespace oneflow

================================================ FILE: oneflow/user/ops/pack_op.cpp ================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe PackOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe PackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); const int32_t pack_num = ctx->Attr("pack_num"); CHECK_GT_OR_RETURN(pack_num, 0); Shape out_shape = in_desc.shape(); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(in_desc.is_dynamic()); if (out_shape.NumAxes() > 0) { out_shape.Set(0, out_shape.At(0) * pack_num); out_desc->set_shape(out_shape); } else { // NOTE(chengcheng): for Scalar input pack CHECK_EQ_OR_RETURN(out_shape.elem_cnt(), 1); out_desc->set_shape(Shape({pack_num})); } return Maybe::Ok(); } /*static*/ Maybe PackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return PackOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe PackOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe PackOp::InferOutputBlobTimeShape( user_op::InferOutputBlobTimeShapeFnContext* ctx) { const int32_t pack_num = ctx->user_op_conf().attr("pack_num"); DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec(); CHECK_OR_RETURN(!time_shape_dim_vec.empty()); CHECK_EQ_OR_RETURN(time_shape_dim_vec.back(), pack_num); time_shape_dim_vec.pop_back(); if (time_shape_dim_vec.empty()) { time_shape_dim_vec.emplace_back(1); } *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/pad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe PadOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const auto& padding_before = ctx->Attr>("padding_before"); const auto& padding_after = ctx->Attr>("padding_after"); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { if (padding_before[i] == 0 && padding_after[i] == 0) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } } return Maybe::Ok(); } /*static*/ Maybe PadOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const auto& padding_before = ctx->Attr>("padding_before"); const auto& padding_after = ctx->Attr>("padding_after"); CHECK_EQ_OR_RETURN(padding_before.size(), x_shape.NumAxes()); CHECK_EQ_OR_RETURN(padding_after.size(), x_shape.NumAxes()); DimVector y_dim_vec(x_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i]; } ctx->SetOutputShape("y", 0, Shape(y_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe PadOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return PadOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe PadOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/parallel_cast_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe ParallelCastOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe ParallelCastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe ParallelCastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ParallelCastOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ParallelCastOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe ParallelCastOp::InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx) { auto* bn2sbp = ctx->mutable_sbp_signature()->mutable_bn_in_op2sbp_parallel(); const std::string& ibn = GenRepeatedBn("in", 0); const std::string& obn = GenRepeatedBn("out", 0); const auto& sbp_parallel_str = ctx->Attr("sbp_parallel"); if (sbp_parallel_str.empty()) { const auto& sbp_parallel = ctx->SbpParallelHint4InputArgNameAndIndex("in", 0); (*bn2sbp)[ibn] = sbp_parallel; (*bn2sbp)[obn] = sbp_parallel; } else { SbpParallel sbp_parallel; CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_parallel_str, &sbp_parallel)) << "invalid sbp_parallel: " << sbp_parallel_str; if (sbp_parallel.has_split_parallel()) { int64_t split_axis = sbp_parallel.split_parallel().axis(); const auto& in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); int64_t num_axes = in_desc.shape().NumAxes(); CHECK_GE_OR_RETURN(split_axis, 0); CHECK_LT_OR_RETURN(split_axis, num_axes); } (*bn2sbp)[ibn] = sbp_parallel; (*bn2sbp)[obn] = sbp_parallel; } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/partial_fc_sample_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe DistributedPartialFcSampleOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("weight", 0), 0) .Broadcast(user_op::OpArg("label", 0)) .Broadcast(user_op::OpArg("mapped_label", 0)) .Split(user_op::OpArg("sampled_label", 0), 0) .Split(user_op::OpArg("sampled_weight", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const int64_t num_sample = ctx->Attr("num_sample"); const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); user_op::TensorDesc* mapped_label = ctx->MutOutputTensorDesc("mapped_label", 0); user_op::TensorDesc* sampled_weight = ctx->MutOutputTensorDesc("sampled_weight", 0); user_op::TensorDesc* sampled_label = ctx->MutOutputTensorDesc("sampled_label", 0); mapped_label->set_shape(label.shape()); mapped_label->set_is_dynamic(label.is_dynamic()); Shape sampled_weight_shape = weight.shape(); sampled_weight_shape.Set(0, num_sample); sampled_weight->set_shape(sampled_weight_shape); sampled_weight->set_is_dynamic(weight.is_dynamic()); Shape sampled_label_shape = label.shape(); sampled_label_shape.Set(0, num_sample); sampled_label->set_shape(sampled_label_shape); sampled_label->set_is_dynamic(label.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { const int64_t num_sample = ctx->Attr("num_sample"); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); CHECK_EQ_OR_RETURN(num_sample % parallel_num, 0); const int64_t num_sample_per_rank = num_sample / parallel_num; const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); user_op::TensorDesc* mapped_label = ctx->MutOutputTensorDesc("mapped_label", 0); user_op::TensorDesc* sampled_weight = ctx->MutOutputTensorDesc("sampled_weight", 0); user_op::TensorDesc* sampled_label = ctx->MutOutputTensorDesc("sampled_label", 0); mapped_label->set_shape(label.shape()); mapped_label->set_is_dynamic(label.is_dynamic()); Shape sampled_weight_shape = weight.shape(); sampled_weight_shape.Set(0, num_sample_per_rank); sampled_weight->set_shape(sampled_weight_shape); sampled_weight->set_is_dynamic(weight.is_dynamic()); Shape sampled_label_shape = label.shape(); sampled_label_shape.Set(0, num_sample_per_rank); sampled_label->set_shape(sampled_label_shape); sampled_label->set_is_dynamic(label.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("mapped_label", 0, ctx->InputDType("label", 0)); ctx->SetOutputDType("sampled_weight", 0, ctx->InputDType("weight", 0)); ctx->SetOutputDType("sampled_label", 0, ctx->InputDType("label", 0)); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); CHECK_NOTNULL_OR_RETURN(label_modifier); label_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("sampled_weight_diff", 0), 0) 
.Split(user_op::OpArg("sampled_label", 0), 0) .Broadcast(user_op::OpArg("boxing_disabled_sampled_weight_diff", 0)) .Broadcast(user_op::OpArg("boxing_disabled_sampled_label", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { user_op::TensorDesc* boxing_disabled_sampled_weight_diff = ctx->MutOutputTensorDesc("boxing_disabled_sampled_weight_diff", 0); Shape boxing_disabled_sampled_weight_diff_shape = ctx->InputShape("sampled_weight_diff", 0); CHECK_EQ_OR_RETURN(boxing_disabled_sampled_weight_diff_shape.At(0) % ctx->parallel_num(), 0); boxing_disabled_sampled_weight_diff_shape.Set( 0, boxing_disabled_sampled_weight_diff_shape.At(0) / ctx->parallel_num()); boxing_disabled_sampled_weight_diff->set_shape(boxing_disabled_sampled_weight_diff_shape); boxing_disabled_sampled_weight_diff->set_is_dynamic( ctx->InputIsDynamic("sampled_weight_diff", 0)); user_op::TensorDesc* boxing_disabled_sampled_label = ctx->MutOutputTensorDesc("boxing_disabled_sampled_label", 0); Shape boxing_disabled_sampled_label_shape = ctx->InputShape("sampled_label", 0); ; CHECK_EQ_OR_RETURN(boxing_disabled_sampled_label_shape.At(0) % ctx->parallel_num(), 0); boxing_disabled_sampled_label_shape.Set( 0, boxing_disabled_sampled_label_shape.At(0) / ctx->parallel_num()); boxing_disabled_sampled_label->set_shape(boxing_disabled_sampled_label_shape); boxing_disabled_sampled_label->set_is_dynamic(ctx->InputIsDynamic("sampled_label", 0)); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { ctx->SetOutputShape("boxing_disabled_sampled_weight_diff", 0, ctx->InputShape("sampled_weight_diff", 0)); ctx->SetOutputIsDynamic("boxing_disabled_sampled_weight_diff", 0, ctx->InputIsDynamic("sampled_weight_diff", 0)); ctx->SetOutputShape("boxing_disabled_sampled_label", 0, ctx->InputShape("sampled_label", 0)); ctx->SetOutputIsDynamic("boxing_disabled_sampled_label", 0, ctx->InputIsDynamic("sampled_label", 0)); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferDataType( user_op::InferContext* ctx) { ctx->SetOutputDType("boxing_disabled_sampled_weight_diff", 0, ctx->InputDType("sampled_weight_diff", 0)); ctx->SetOutputDType("boxing_disabled_sampled_label", 0, ctx->InputDType("sampled_label", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/pinned_identity_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe PinnedIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe PinnedIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe PinnedIdentityOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe PinnedIdentityOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/prelu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe PreluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); if (alpha_tensor.shape().At(0) != 1) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), 1) .Split(user_op::OpArg("alpha", 0), 0) .Split(user_op::OpArg("y", 0), 1) .Build(); } FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { if (i == 1) continue; ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("alpha", 0)) .Split(user_op::OpArg("y", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe PreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& alpha_shape = ctx->InputShape("alpha", 0); CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); ctx->SetOutputShape("y", 0, x_shape); return Maybe::Ok(); } /*static*/ Maybe PreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe PreluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe PreluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Broadcast(user_op::OpArg("alpha", 0)) .Split(user_op::OpArg("dx", 0), 0) .PartialSum(user_op::OpArg("alpha_diff", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("dy", 0)) .Broadcast(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("alpha", 0)) .PartialSum(user_op::OpArg("dx", 0)) .PartialSum(user_op::OpArg("alpha_diff", 0)) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 1) .Split(user_op::OpArg("x", 0), 1) .Split(user_op::OpArg("alpha", 0), 0) .Split(user_op::OpArg("dx", 0), 1) .Split(user_op::OpArg("alpha_diff", 0), 0) .Build(); FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("alpha", 0), 0) .Split(user_op::OpArg("dx", 0), i) .Split(user_op::OpArg("alpha_diff", 0), 0) .Build(); } return Maybe::Ok(); } /*static*/ Maybe PreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); const Shape& alpha_shape = ctx->InputShape("alpha", 0); CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); CHECK_OR_RETURN((alpha_shape.At(0) == x_shape.At(1)) || (alpha_shape.At(0) == 1)); CHECK_EQ_OR_RETURN(dy_shape, x_shape); ctx->SetOutputShape("dx", 0, x_shape); ctx->SetOutputShape("alpha_diff", 0, alpha_shape); return Maybe::Ok(); } /*static*/ Maybe PreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe PreluGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("alpha_diff", 0, ctx->InputDType("alpha", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/quantization_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe QuantizationOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); const Shape& logical_scale_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape(); ctx->NewBuilder() .Broadcast(user_op::OpArg("in", 0)) .Broadcast(user_op::OpArg("scale", 0)) .Broadcast(user_op::OpArg("zero_point", 0)) .Broadcast(user_op::OpArg("out", 0)) .Build(); if (logical_scale_shape.elem_cnt() > 1) { // NOTE(Liang Depeng): only consider convolution weight per-channel quantization ctx->NewBuilder() .Split(user_op::OpArg("in", 0), 0) .Split(user_op::OpArg("scale", 0), 0) .Split(user_op::OpArg("zero_point", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Build(); } else { // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise // ops ctx->NewBuilder() .Split(user_op::OpArg("in", 0), 0) .Broadcast(user_op::OpArg("scale", 0)) .Broadcast(user_op::OpArg("zero_point", 0)) .Split(user_op::OpArg("out", 0), 0) .Build(); } FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Broadcast(user_op::OpArg("scale", 0)) .Broadcast(user_op::OpArg("zero_point", 0)) .Split(user_op::OpArg("out", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe QuantizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const Shape& scale_shape = ctx->InputShape("scale", 0); const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for // convolution weights. 
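  // e.g. a conv weight of shape (64, 3, 3, 3) quantized per-channel requires scale and
  // zero_point to carry 64 values (one per output channel); the checks below enforce that
  // their element counts match in_shape.At(0).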
if (scale_shape.elem_cnt() > 1) { CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0)); CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); } ctx->SetOutputShape("out", 0, in_shape); return Maybe::Ok(); } /*static*/ Maybe QuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe QuantizationOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe QuantizationOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); CHECK_OR_RETURN(scale != nullptr); scale->set_requires_grad(false); user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); CHECK_OR_RETURN(zero_point != nullptr); zero_point->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe QuantizationOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { const int32_t quantization_bit = op_conf.attr("quantization_bit"); CHECK_GT_OR_RETURN(quantization_bit, 1); CHECK_LE_OR_RETURN(quantization_bit, 8); std::string quantization_scheme = op_conf.attr("quantization_scheme"); CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); std::string quantization_formula = op_conf.attr("quantization_formula"); CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/quick_gelu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe QuickGeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe QuickGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe QuickGeluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe QuickGeluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe QuickGeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape) << "InferTensorDesc failed (" << ctx->op_name() << "). 
Expected x shape " << x_shape.ToString() << " to be equal to dy shape " << dy_shape.ToString(); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe QuickGeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe QuickGeluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe QuickGeluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/randperm_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /*static*/ Maybe RandpermOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /*static*/ Maybe RandpermOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe RandpermOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { int32_t n = ctx->Attr("n"); CHECK_GE_OR_RETURN(n, 0) << Error::RuntimeError() << "Trying to create tensor with negative dimension " << n << ":" << " [" << n << "]"; ctx->SetOutputShape("out", 0, Shape({n})); return Maybe::Ok(); } /*static*/ Maybe RandpermOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); int32_t n = ctx->Attr("n"); const Shape& logical_shape = Shape({n}); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); return Maybe::Ok(); } /*static*/ Maybe RandpermOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kInt32); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/raw_reader_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe RawReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& instance_shape = ctx->Attr("shape"); const int32_t batch_size = ctx->Attr("batch_size"); DimVector dim_vec; dim_vec.push_back(batch_size); for (int64_t i = 0; i < instance_shape.NumAxes(); ++i) { dim_vec.push_back(instance_shape.At(i)); } user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); out_tensor->set_shape(Shape(dim_vec)); return Maybe::Ok(); } /* static */ Maybe RawReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int32_t batch_size = ctx->Attr("batch_size"); int64_t parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { int64_t split_num = 1; const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { if (nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= hierarchy.At(i); } } CHECK_EQ_OR_RETURN(batch_size % split_num, 0) << "batch_size must be a multiple of shard num"; batch_size /= split_num; } const Shape& instance_shape = ctx->Attr("shape"); DimVector dim_vec; dim_vec.push_back(batch_size); for (int64_t i = 0; i < instance_shape.NumAxes(); ++i) { dim_vec.push_back(instance_shape.At(i)); } out_tensor->set_shape(Shape({dim_vec})); return Maybe::Ok(); } /* static */ Maybe RawReaderOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } /* static */ Maybe RawReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_split_parallel()->set_axis(0); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe RawReaderOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->Attr("data_type")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/reduce_like_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe ReduceSumLikeOp::GetSbp(user_op::SbpContext* ctx) { int32_t num_axes = 0; HashSet conf_axes; const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); num_axes = in_tensor.shape().NumAxes(); const auto& reduced_axes = ctx->Attr>("axis"); ReduceSbpUtil::GetRegularAxes(num_axes, reduced_axes, &conf_axes); const auto& like_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); const bool keep_dims = (num_axes == like_num_axes); auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes); int64_t num_reduced_axes = 0; FOR_RANGE(int64_t, i, 0, num_axes) { if (in_tensor.shape().at(i) == 1) { num_reduced_axes += 1; } else if (IsReducedAxis(i)) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .PartialSum(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); num_reduced_axes += 1; } else { const int64_t out_split_axis = keep_dims ? i : i - num_reduced_axes; ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("like", 0), out_split_axis) .Split(user_op::OpArg("y", 0), out_split_axis) .Build(); } } ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("like", 0)) .Broadcast(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe ReduceSumLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); const auto& axis = ctx->Attr>("axis"); if (axis.empty()) { CHECK_EQ_OR_RETURN(x_tensor.shape(), like_tensor.shape()) << Error::RuntimeError() << "The shape of the x tensor must be consistent to the shape of the like tensor" << " when the input axis list is empty"; } user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc("y", 0); y_tensor->set_shape(like_tensor.shape()); y_tensor->set_is_dynamic(like_tensor.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe ReduceSumLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReduceSumLikeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ReduceSumLikeOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); CHECK_OR_RETURN(like_arg_modifier != nullptr); // NOLINT(maybe-need-error-msg) like_arg_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/reduce_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe InferTensorDescFn(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input_tensor", 0); const auto& reduce_axes = ctx->Attr>("axis"); Shape output_shape; // For 0-dim Tensor if (reduce_axes.empty()) { output_shape = input_shape; } else { const AxisVector reduce_axes_vec = {reduce_axes.begin(), reduce_axes.end()}; const Shape& reduce_shape = CreateReducedShape(input_shape, reduce_axes_vec); const bool keepdims = ctx->Attr("keepdims"); if (keepdims) { output_shape = reduce_shape; } else { output_shape = reduce_shape.RemoveOnes(reduce_axes_vec); } } ctx->SetOutputShape("output_tensor", 0, output_shape); ctx->SetOutputStride("output_tensor", 0, Stride(output_shape)); return Maybe::Ok(); } Maybe InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("output_tensor", 0, ctx->InputDType("input_tensor", 0)); return Maybe::Ok(); } Maybe InferLogicalDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("output_tensor", 0, DataType::kBool); return Maybe::Ok(); } template class binary_func> void GeneratePartialSbp(user_op::SbpContext* ctx, int64_t axis) { // TODO(lixinqi) } template<> void GeneratePartialSbp(user_op::SbpContext* ctx, int64_t axis) { ctx->NewBuilder().Split(ctx->inputs(), axis).PartialSum(ctx->outputs()).Build(); ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); } template class binary_func> Maybe GetSbpFn(user_op::SbpContext* ctx) { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("input_tensor", 0); int64_t num_axes = in_tensor.shape().NumAxes(); bool keep_dims = ctx->Attr("keepdims"); const auto& reduce_axes = ctx->Attr>("axis"); HashSet conf_axes; ReduceSbpUtil::GetRegularAxes(num_axes, reduce_axes, &conf_axes); auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes); int32_t num_reduced_axes = 0; FOR_RANGE(int64_t, i, 0, num_axes) { if (IsReducedAxis(i)) { GeneratePartialSbp(ctx, i); num_reduced_axes += 1; } else { ctx->NewBuilder() .Split(ctx->inputs(), i) .Split(ctx->outputs(), keep_dims ? 
i : i - num_reduced_axes) .Build(); } } if (num_axes == 0) { ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); } return Maybe::Ok(); } #define IMPLEMENT_REDUCE_OP_FUNCS(name, binary_func, infer_dtype_func) \ /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ return GetSbpFn(ctx); \ } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferTensorDescFn(ctx); \ } \ /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ return infer_dtype_func(ctx); \ } IMPLEMENT_REDUCE_OP_FUNCS(ReduceAny, BinaryFuncAny, InferLogicalDataType) IMPLEMENT_REDUCE_OP_FUNCS(ReduceAll, BinaryFuncAll, InferLogicalDataType) IMPLEMENT_REDUCE_OP_FUNCS(ReduceMin, BinaryFuncMin, oneflow::InferDataType) IMPLEMENT_REDUCE_OP_FUNCS(ReduceMax, BinaryFuncMax, oneflow::InferDataType) IMPLEMENT_REDUCE_OP_FUNCS(ReduceSum, BinaryFuncSum, oneflow::InferDataType) IMPLEMENT_REDUCE_OP_FUNCS(ReduceProd, BinaryFuncProd, oneflow::InferDataType) IMPLEMENT_REDUCE_OP_FUNCS(ReduceNanSum, BinaryFuncNanSum, oneflow::InferDataType) #undef IMPLEMENT_REDUCE_OP_FUNCS } // namespace oneflow ================================================ FILE: oneflow/user/ops/reflection_pad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { template Maybe GetOpSbpSignature(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const int64_t input_dims = x_tensor.shape().NumAxes(); const int64_t split_dims = input_dims - (ndim - 2); FOR_RANGE(int64_t, i, 0, split_dims) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } template Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); const int64_t grad_dims = dy_tensor.shape().NumAxes(); const int64_t split_dims = grad_dims - (ndim - 2); FOR_RANGE(int64_t, i, 0, split_dims) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } } // namespace /*static*/ Maybe ReflectionPad1DOp::GetSbp(user_op::SbpContext* ctx) { return GetOpSbpSignature<3>(ctx); } /*static*/ Maybe ReflectionPad1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t w_idx = 2; DimVector y_dim_vec(x_shape.NumAxes()); const int64_t w_x = x_shape.At(w_idx); y_dim_vec[n_idx] = x_shape.At(n_idx); y_dim_vec[c_idx] = x_shape.At(c_idx); y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; ctx->SetOutputShape("y", 0, Shape(y_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ReflectionPad1DOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReflectionPad1DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad1DOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); CHECK_NOTNULL_OR_RETURN(x_modifier); // NOLINT x_modifier->set_requires_grad(true); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad1DGradOp::GetSbp(user_op::SbpContext* ctx) { return GetOpGradSbpSignature<3>(ctx); } /*static*/ Maybe ReflectionPad1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t w_idx = 2; DimVector dx_dim_vec(dy_shape.NumAxes()); int64_t w_dy = dy_shape.At(w_idx); dx_dim_vec[n_idx] = dy_shape.At(0); dx_dim_vec[c_idx] = dy_shape.At(1); dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; ctx->SetOutputShape("dx", 0, Shape(dx_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ReflectionPad1DGradOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReflectionPad1DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DOp::GetSbp(user_op::SbpContext* ctx) { return GetOpSbpSignature<4>(ctx); } /*static*/ Maybe 
ReflectionPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t h_idx = 2; const int64_t w_idx = 3; DimVector y_dim_vec(x_shape.NumAxes()); const int64_t h_x = x_shape.At(h_idx); const int64_t w_x = x_shape.At(w_idx); y_dim_vec[n_idx] = x_shape.At(n_idx); y_dim_vec[c_idx] = x_shape.At(c_idx); y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; ctx->SetOutputShape("y", 0, Shape(y_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ReflectionPad2DOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReflectionPad2DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); CHECK_NOTNULL_OR_RETURN(x_modifier); // NOLINT x_modifier->set_requires_grad(true); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DGradOp::GetSbp(user_op::SbpContext* ctx) { return GetOpGradSbpSignature<4>(ctx); } /*static*/ Maybe ReflectionPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t h_idx = 2; const int64_t w_idx = 3; DimVector dx_dim_vec(dy_shape.NumAxes()); int64_t h_dy = dy_shape.At(h_idx); int64_t w_dy = dy_shape.At(w_idx); dx_dim_vec[n_idx] = dy_shape.At(0); dx_dim_vec[c_idx] = dy_shape.At(1); dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; ctx->SetOutputShape("dx", 0, Shape(dx_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ReflectionPad2DGradOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReflectionPad2DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/relu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe ReluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe ReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ReluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe ReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << "Tensors y and dy must have the same shape"; ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe ReluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReluGradOp::InferDataType(user_op::InferContext* ctx) { DataType data_type = ctx->InputDType("y", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), data_type) << "InferDataType Failed. Expected " << DataType_Name(data_type) << ", but got " << DataType_Name(ctx->InputDType("dy", 0)); ctx->SetOutputDType("dx", 0, data_type); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/repeat_interleave_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe Repeat_InterLeaveOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Split(user_op::OpArg("cumsum", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("cumsum", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe Repeat_InterLeaveOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const int64_t repeat_num = ctx->Attr("repeat_num"); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_shape(Shape({repeat_num})); return Maybe::Ok(); } /*static*/ Maybe Repeat_InterLeaveOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe Repeat_InterLeaveOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/repeat_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" namespace oneflow { /*static*/ Maybe RepeatOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe RepeatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); return Maybe::Ok(); } /*static*/ Maybe RepeatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RepeatOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe RepeatOp::InferOutputBlobTimeShape( user_op::InferOutputBlobTimeShapeFnContext* ctx) { DimVector dim_vec(ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec()); dim_vec.emplace_back(ctx->user_op_conf().attr("repeat_num")); *ctx->mut_output_blob_time_shape() = Shape(dim_vec); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/replication_pad_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { template Maybe GetOpSbpSignature(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const int64_t input_dims = x_tensor.shape().NumAxes(); const int64_t first_two_dims = input_dims - (ndim - 2); FOR_RANGE(int64_t, i, 0, first_two_dims) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } template Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); const int64_t grad_dims = dy_tensor.shape().NumAxes(); CHECK_EQ_OR_RETURN(grad_dims, ndim); // NOLINT const int64_t first_two_dims = grad_dims - (ndim - 2); FOR_RANGE(int64_t, i, 0, first_two_dims) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); return Maybe::Ok(); } } // namespace /*static*/ Maybe ReplicationPad1DOp::GetSbp(user_op::SbpContext* ctx) { return GetOpSbpSignature<3>(ctx); } /*static*/ Maybe ReplicationPad1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const auto& padding = ctx->Attr>("padding"); const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t w_idx = 2; DimVector y_dim_vec(x_shape.NumAxes()); const int64_t w_x = x_shape.At(w_idx); y_dim_vec[n_idx] = x_shape.At(n_idx); y_dim_vec[c_idx] = x_shape.At(c_idx); y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; ctx->SetOutputShape("y", 0, Shape(y_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ReplicationPad1DOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReplicationPad1DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad1DOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); CHECK_NOTNULL_OR_RETURN(x_modifier); x_modifier->set_requires_grad(true); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad1DGradOp::GetSbp(user_op::SbpContext* ctx) { return GetOpGradSbpSignature<3>(ctx); } /*static*/ Maybe ReplicationPad1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const auto& padding = ctx->Attr>("padding"); CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes() - 1); // NOLINT const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t w_idx = 2; DimVector dx_dim_vec(dy_shape.NumAxes()); int64_t w_dy = dy_shape.At(w_idx); 
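  // The grad op undoes the forward padding on the width axis: for dy of shape
  // (N, C, W + padding[0] + padding[1]), the recovered dx shape is (N, C, W).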
dx_dim_vec[n_idx] = dy_shape.At(0); dx_dim_vec[c_idx] = dy_shape.At(1); dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; ctx->SetOutputShape("dx", 0, Shape(dx_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ReplicationPad1DGradOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReplicationPad1DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DOp::GetSbp(user_op::SbpContext* ctx) { return GetOpSbpSignature<4>(ctx); } /*static*/ Maybe ReplicationPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const auto& padding = ctx->Attr>("padding"); CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); // NOLINT const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t h_idx = 2; const int64_t w_idx = 3; DimVector y_dim_vec(x_shape.NumAxes()); const int64_t h_x = x_shape.At(h_idx); const int64_t w_x = x_shape.At(w_idx); y_dim_vec[n_idx] = x_shape.At(n_idx); y_dim_vec[c_idx] = x_shape.At(c_idx); y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; ctx->SetOutputShape("y", 0, Shape(y_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ReplicationPad2DOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReplicationPad2DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); CHECK_NOTNULL_OR_RETURN(x_modifier); // NOLINT x_modifier->set_requires_grad(true); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DGradOp::GetSbp(user_op::SbpContext* ctx) { return GetOpGradSbpSignature<4>(ctx); } /*static*/ Maybe ReplicationPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const auto& padding = ctx->Attr>("padding"); CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); // NOLINT const int64_t n_idx = 0; const int64_t c_idx = 1; const int64_t h_idx = 2; const int64_t w_idx = 3; DimVector dx_dim_vec(dy_shape.NumAxes()); int64_t h_dy = dy_shape.At(h_idx); int64_t w_dy = dy_shape.At(w_idx); dx_dim_vec[n_idx] = dy_shape.At(0); dx_dim_vec[c_idx] = dy_shape.At(1); dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; ctx->SetOutputShape("dx", 0, Shape(dx_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ReplicationPad2DGradOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReplicationPad2DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/reshape_like_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/reshape_user_op_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe ReshapeLikeOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const auto& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); ctx->NewBuilder() .PartialSum(user_op::OpArg("like", 0)) .Broadcast(user_op::OpArg("in", 0)) .Broadcast(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder(); return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(in_shape, like_shape, {{"in", 0}}, {{"like", 0}, {"out", 0}}, ctx->hierarchy_value(), &builder); } /*static*/ Maybe ReshapeLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const Shape& like_shape = ctx->InputShape("like", 0); CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), like_shape.elem_cnt()) << Error::RuntimeError() << "The element number of the in tensor must be equal to the element number of the " "like tensor, " << "but got " << in_shape.elem_cnt() << " and " << like_shape.elem_cnt(); ctx->SetOutputShape("out", 0, like_shape); return Maybe::Ok(); } /*static*/ Maybe ReshapeLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReshapeLikeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe ReshapeLikeOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); CHECK_NOTNULL_OR_RETURN(like_modifier); // NOLINT(maybe-need-error-msg) like_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/reshape_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/reshape_user_op_util.h" namespace oneflow { /*static*/ Maybe ReshapeOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const Shape& shape = ctx->Attr("shape"); const auto& outshape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder(); return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->hierarchy_value(), &builder); } /*static*/ Maybe ReshapeOp::EnumerateNdSbpSignatures( user_op::GetNdSbpSignatureListContext* ctx) { const Shape& in_shape = ctx->BlobShape4InputArgNameAndIndex("in", 0); const Shape& shape_attr = ctx->Attr("shape"); std::shared_ptr out_shape_ptr = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape_attr)); std::vector* nd_sbp_sig_list = ctx->MutNdSbpSignatureList(); JUST(ReshapeUserOpUtil::EnumerateNdSbpSignatures({{"in", 0}}, in_shape, {{"out", 0}}, *out_shape_ptr, ctx->parallel_hierarchy(), nd_sbp_sig_list)); // Go down from the tail to the head, since we might drop the tail. for (int32_t sbp_id = nd_sbp_sig_list->size() - 1; sbp_id >= 0; sbp_id--) { auto& nd_sbp_sig = (*nd_sbp_sig_list)[sbp_id]; const auto& out_nd_sbp_it = nd_sbp_sig.bn_in_op2nd_sbp().find("out_0"); CHECK_OR_RETURN(out_nd_sbp_it != nd_sbp_sig.bn_in_op2nd_sbp().end()) << "can't get sbp for out_0"; Shape out_logical_shape = *out_shape_ptr; // filter by output only be needed here // filter by input will be done in Operator::FilterNdSbpSignatureListByLogicalShape if (JUST(FilterNdSbpByLogicalShape(out_nd_sbp_it->second, out_logical_shape, ctx->parallel_hierarchy()))) { // Remove the Nd SBP candidate std::swap(nd_sbp_sig, nd_sbp_sig_list->back()); nd_sbp_sig_list->pop_back(); } } DeduplicateNdSbpSignatureList(nd_sbp_sig_list, {"in_0", "out_0"}); return Maybe::Ok(); } /*static*/ Maybe ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { Shape shape = ctx->Attr("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc("out", 0); const Shape& in_shape = in_tensor_desc.shape(); CHECK_OR_RETURN(in_tensor_desc.is_dynamic() == false); // NOLINT(maybe-need-error-msg) out_tensor_desc->set_data_type(in_tensor_desc.data_type()); if (in_shape.NumAxes() == 0 || shape.NumAxes() == 0) { // NOTE(chengcheng): input/output Scalar // do nothing } else { CHECK_GE_OR_RETURN(shape.NumAxes(), 1); // NOLINT(maybe-need-error-msg) CHECK_GE_OR_RETURN(in_shape.NumAxes(), 1); // NOLINT(maybe-need-error-msg) int need_infer_axis = -1; size_t count = 1; for (int i = 0; i < shape.NumAxes(); ++i) { if (shape.At(i) == -1) { CHECK_EQ_OR_RETURN(need_infer_axis, -1) << Error::RuntimeError() << "Shape " << shape.ToString() << " has more than 1 axis that needs to be infered"; need_infer_axis = i; } else { count *= shape.At(i); } } if (need_infer_axis != -1) { shape.Set(need_infer_axis, in_shape.elem_cnt() / count); } } out_tensor_desc->set_shape(shape); out_tensor_desc->set_stride(Stride(shape)); // For 0-size tensor, we don't need to check whether the input and output tensors have the same // element size. 
if (in_shape.elem_cnt() > 0) { CHECK_EQ_OR_RETURN(shape.elem_cnt(), in_shape.elem_cnt()) << Error::RuntimeError() << "Reshape infer ERROR! in op_name: " << ctx->op_name() << " input shape is : " << in_shape.ToString() << " , output shape is : " << shape.ToString() << " , and reshape shape conf is : " << ctx->Attr("shape").ToString() << " op_loc: " << ctx->op_loc(); } return Maybe::Ok(); } /*static*/ Maybe ReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { Shape logical_shape = ctx->Attr("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc("out", 0); const Shape& in_shape = in_tensor_desc.shape(); out_tensor_desc->set_stride(Stride(in_tensor_desc.shape())); out_tensor_desc->set_is_dynamic(in_tensor_desc.is_dynamic()); if (in_shape.NumAxes() == 0 || logical_shape.NumAxes() == 0) { // NOTE(chengcheng): input/output Scalar // do nothing } else { CHECK_GE_OR_RETURN(logical_shape.NumAxes(), 1); // NOLINT(maybe-need-error-msg) CHECK_GE_OR_RETURN(in_shape.NumAxes(), 1); // NOLINT(maybe-need-error-msg) const auto& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); const Shape in_logical_shape = *JUST(GetLogicalShape(in_shape, in_nd_sbp, ctx->parallel_desc())); int need_infer_axis = -1; size_t count = 1; for (int i = 0; i < logical_shape.NumAxes(); ++i) { if (logical_shape.At(i) == -1) { CHECK_EQ_OR_RETURN(need_infer_axis, -1) << Error::RuntimeError() << "Shape " << logical_shape.ToString() << " has more than 1 axis that needs to be infered"; need_infer_axis = i; } else { count *= logical_shape.At(i); } } if (need_infer_axis != -1) { logical_shape.Set(need_infer_axis, in_logical_shape.elem_cnt() / count); } } const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); out_tensor_desc->set_shape( *JUST(GetPhysicalShape(logical_shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx()))); out_tensor_desc->set_stride(Stride(out_tensor_desc->shape())); CHECK_EQ_OR_RETURN(out_tensor_desc->shape().elem_cnt(), in_shape.elem_cnt()) << Error::RuntimeError() << " Reshape infer ERROR! in op_name: " << ctx->op_name() << " input shape is : " << in_shape.ToString() << " , output shape is : " << out_tensor_desc->shape().ToString() << " , output logical shape is " << logical_shape.ToString() << " , and reshape shape conf is : " << ctx->Attr("shape").ToString() << " op_loc: " << ctx->op_loc(); return Maybe::Ok(); } /*static*/ Maybe ReshapeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/reshape_user_op_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/ops/reshape_user_op_util.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/common/container_util.h" namespace oneflow { Maybe ReshapeUserOpUtil::GetLogicalOutBlobShape(const Shape& in_shape, const Shape& reshape) { if (unlikely(in_shape.elem_cnt() == 0)) { FOR_RANGE(int, axis, 0, reshape.NumAxes()) { int64_t dim = reshape.At(axis); if (dim == -1) { return Error::RuntimeError() << "Cannot reshape tensor of 0 elements into shape " << reshape.DebugStr() << " because the unspecified dimension size -1 can be any value and is ambiguous"; } else if (dim < 0) { return Error::RuntimeError() << "Invalid shape dimension " << dim << ", the shape dimension can not to be less than 0"; } } return std::make_shared(reshape); } size_t total_elem_dim_exclude_minus_1 = 1; bool has_minus_1 = false; bool minus_1_axis = -1; DimVector dim_vec; FOR_RANGE(int, axis, 0, reshape.NumAxes()) { int64_t dim = reshape.At(axis); dim_vec.emplace_back(dim); if (dim == -1) { CHECK_OR_RETURN(has_minus_1 == false) << Error::RuntimeError() << "There are multiple '-1' in the shape list, only one '-1' can be inferred"; has_minus_1 = true; minus_1_axis = axis; } else if (dim > 0) { CHECK_LE_OR_RETURN(dim, in_shape.elem_cnt()) << Error::RuntimeError() << "Invalid axis: " << axis << ", dim: " << dim; total_elem_dim_exclude_minus_1 *= dim; CHECK_LE_OR_RETURN(total_elem_dim_exclude_minus_1, in_shape.elem_cnt()) << Error::RuntimeError() << "Element number in reshape_conf must be less than or equal to input blob, " << "but got " << total_elem_dim_exclude_minus_1 << " and " << in_shape.elem_cnt(); } else { OF_UNIMPLEMENTED() << "only positive number or -1 supported"; } } CHECK_EQ_OR_RETURN(in_shape.elem_cnt() % total_elem_dim_exclude_minus_1, 0) << Error::RuntimeError() << "Element number in input blob must be an integer multiple of reshape_conf, " << "but got " << in_shape.elem_cnt() << " and " << total_elem_dim_exclude_minus_1; if (has_minus_1) { dim_vec[minus_1_axis] = in_shape.elem_cnt() / total_elem_dim_exclude_minus_1; } else { CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), total_elem_dim_exclude_minus_1) << "Element number in input blob must be equal to reshape_conf, " << "but got " << in_shape.elem_cnt() << " and " << total_elem_dim_exclude_minus_1; } return std::make_shared(dim_vec); } Maybe ReshapeUserOpUtil::Squeeze(const Shape& origin, Shape* shape, HashMap* squeezed_axis2origin_axis) { DimVector dim_vec; FOR_RANGE(int, axis, 0, origin.NumAxes()) { int64_t dim = origin.At(axis); CHECK_GE_OR_RETURN(dim, 0) << Error::RuntimeError() << "Trying to suqeeze tensor with negative dimension " << dim << " : " << origin.DebugStr(); if (dim == 1) { continue; } CHECK_OR_RETURN(squeezed_axis2origin_axis->emplace(dim_vec.size(), axis).second) << "emplace error"; // NOLINT(maybe-need-error-msg) dim_vec.emplace_back(dim); } *shape = Shape(dim_vec); return Maybe::Ok(); } Maybe ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis( const Shape& in_shape, const Shape& out_shape, const int64_t hierarchy_value, HashMap* group_start_in_axis2out_axis) { CHECK_GE_OR_RETURN(in_shape.NumAxes(), 0) << Error::RuntimeError() << "The dimension of input tensor must be greater than or equal to zero, " << "but got " << in_shape.NumAxes(); // support 0D tensor CHECK_GE_OR_RETURN(out_shape.NumAxes(), 0) << Error::RuntimeError() << "The dimension of output tensor must be greater than or equal to zero, " << "but got " << out_shape.NumAxes(); // support 0D tensor 
Maybe<void> ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(
    const Shape& in_shape, const Shape& out_shape, const int64_t hierarchy_value,
    HashMap<int, int>* group_start_in_axis2out_axis) {
  CHECK_GE_OR_RETURN(in_shape.NumAxes(), 0)
      << Error::RuntimeError()
      << "The dimension of input tensor must be greater than or equal to zero, "
      << "but got " << in_shape.NumAxes();  // support 0D tensor
  CHECK_GE_OR_RETURN(out_shape.NumAxes(), 0)
      << Error::RuntimeError()
      << "The dimension of output tensor must be greater than or equal to zero, "
      << "but got " << out_shape.NumAxes();  // support 0D tensor
  CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt())
      << Error::RuntimeError()
      << "The element number of input tensor must be equal to output tensor, "
      << "but got " << in_shape.elem_cnt() << " and " << out_shape.elem_cnt();
  // Initialization
  // shape_count is the product of the axis length in [start_axis, end)
  int64_t in_shape_count = 1;
  int64_t out_shape_count = 1;
  int64_t in_axis = in_shape.NumAxes();
  int64_t out_axis = out_shape.NumAxes();
  // Move forward functions
  auto Move2NextAxis = [](const Shape& shape, int64_t* axis, int64_t* shape_count) {
    (*axis)--;
    if (*axis >= 0) { *shape_count *= shape.At(*axis); }
  };
  auto MoveInAxis = [&] { Move2NextAxis(in_shape, &in_axis, &in_shape_count); };
  auto MoveOutAxis = [&] { Move2NextAxis(out_shape, &out_axis, &out_shape_count); };
  // Move the first step
  MoveInAxis();
  MoveOutAxis();
  // At the last step, both in_axis == out_axis == 0
  // Then they would move to -1 simultaneously.
  while (in_axis >= 0) {
    if (in_shape_count == out_shape_count) {
      // Record split axes
      if (in_shape.At(in_axis) == out_shape.At(out_axis)
          || (in_shape.At(in_axis) % hierarchy_value == 0
              && out_shape.At(out_axis) % hierarchy_value == 0)) {
        (*group_start_in_axis2out_axis)[in_axis] = out_axis;
      }
      // Move forward
      MoveInAxis();
      MoveOutAxis();
    } else if (in_shape_count < out_shape_count) {
      MoveInAxis();
    } else {
      // in_shape_count > out_shape_count
      MoveOutAxis();
    }
  }
  return Maybe<void>::Ok();
}

Maybe<void> ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(
    const Shape& in_shape, const Shape& out_shape, const std::vector<user_op::OpArg>& in_args,
    const std::vector<user_op::OpArg>& out_args, const int64_t hierarchy_value,
    user_op::UserOpSbpSignatureBuilder* builder) {
  if (in_shape.NumAxes() == 0 || in_shape.elem_cnt() == 0) {
    return Maybe<void>::Ok();
  }  // 0D/0Size tensor only support b2b
  HashMap<int, int> squeezed_group_start_in_axis2out_axis;
  HashMap<int, int> in_squeezed_axis2original_axis;
  HashMap<int, int> out_squeezed_axis2original_axis;
  {
    Shape squeezed_in_shape;
    Shape squeezed_out_shape;
    JUST(ReshapeUserOpUtil::Squeeze(in_shape, &squeezed_in_shape, &in_squeezed_axis2original_axis));
    JUST(ReshapeUserOpUtil::Squeeze(out_shape, &squeezed_out_shape,
                                    &out_squeezed_axis2original_axis));
    JUST(ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(squeezed_in_shape, squeezed_out_shape,
                                                        hierarchy_value,
                                                        &squeezed_group_start_in_axis2out_axis));
  }
  for (const auto& pair : squeezed_group_start_in_axis2out_axis) {
    int64_t start_in_axis = in_squeezed_axis2original_axis.at(pair.first);
    int64_t start_out_axis = out_squeezed_axis2original_axis.at(pair.second);
    builder->Split(in_args, start_in_axis).Split(out_args, start_out_axis).Build();
  }
  builder->PartialSum(in_args).PartialSum(out_args).Build();
  return Maybe<void>::Ok();
}

namespace {

void FowardRankMesh(size_t depth, size_t max_depth, std::deque<int>& rank_axes_queue,
                    std::vector<std::vector<int>>& rank_axes_subset) {
  if (depth == max_depth) {
    // skip empty subset
    if (rank_axes_queue.empty()) { return; }
    rank_axes_subset.emplace_back();
    auto& rank_axes = rank_axes_subset.back();
    for (int rank_axis : rank_axes_queue) { rank_axes.push_back(rank_axis); }
  } else {
    // forward by skipping the current depth axis
    FowardRankMesh(depth + 1, max_depth, rank_axes_queue, rank_axes_subset);
    // forward by keeping the current depth axis
    rank_axes_queue.push_back(depth);
    FowardRankMesh(depth + 1, max_depth, rank_axes_queue, rank_axes_subset);
    rank_axes_queue.pop_back();
  }
}

void GenRankMeshSubset(size_t mesh_depth, std::vector<std::vector<int>>& rank_axes_subset) {
  std::deque<int> rank_axes_queue;
  FowardRankMesh(0, mesh_depth, rank_axes_queue, rank_axes_subset);
}

}  // namespace

Maybe<void>
ReshapeUserOpUtil::EnumerateNdSplitIn2OutAxis( const Shape& in_shape, const std::vector& origin_in_axes, const Shape& out_shape, const std::vector& origin_out_axes, const Shape& rank_mesh, std::vector>>* nd_split_groups) { CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt()); CHECK_EQ_OR_RETURN(in_shape.size(), origin_in_axes.size()); CHECK_EQ_OR_RETURN(out_shape.size(), origin_out_axes.size()); // generate all subset of rank_mesh (keep order) // for example rank_mesh=(2, 3, 5), subset include: // (2, 3, 5) // (2, 3) // (2, 5) // (2,) // (3, 5) // (3,) // (5,) std::vector> rank_axes_subset; GenRankMeshSubset(rank_mesh.size(), rank_axes_subset); // traverse all subset to detect contiguous nd-split signatures // for example (6,) reshape to (2, 3) with rank_mesh=(2, 3) // nd-split signatures include: // S(0) -> S(0) with rank_axis=0 (1d) // S(0) -> S(1) with rank_axis=1 (1d) // [S(0), S(0)] -> [S(0), S(1)] with rank_mesh=(2,3) (2d) for (const std::vector& rank_axes : rank_axes_subset) { int rank_axis_idx = 0; int in_axis = in_shape.size() - 1; int out_axis = out_shape.size() - 1; int64_t in_dim_size = in_shape[in_axis]; int64_t out_dim_size = out_shape[out_axis]; // rank_axis -> {in_axis, out_axis} std::map> rank_in2out_axis; // go down from tail to head axis, since the dimensions // in the in_shape and the out_shape passed in // are reverse order while (in_axis >= 0 && out_axis >= 0 && rank_axis_idx < rank_axes.size()) { // dim_size == 1 then move to next axis to find contiguous split axis if (in_dim_size == 1) { in_axis--; in_dim_size = in_shape[in_axis]; continue; } if (out_dim_size == 1) { out_axis--; out_dim_size = out_shape[out_axis]; continue; } int rank_axis = rank_axes[rank_axis_idx]; int64_t rank_num = rank_mesh[rank_axis]; // dim_size is indivisible by rank_num indicate split can't continue if (in_dim_size % rank_num != 0 || out_dim_size % rank_num != 0) { break; } // divide dim_size by rank_num both at in_axis and out_axis till dim_size == 1 in_dim_size /= rank_num; out_dim_size /= rank_num; int origin_in_axis = origin_in_axes[in_axis]; int origin_out_axis = origin_out_axes[out_axis]; // mark rank_axis that can be splited by in_axis and out_axis both rank_in2out_axis.emplace(rank_axis, std::make_pair(origin_in_axis, origin_out_axis)); rank_axis_idx++; } // ensure all rank axes are marked splitable with some axis (in and out) if (rank_in2out_axis.size() == rank_axes.size()) { nd_split_groups->emplace_back(std::move(rank_in2out_axis)); } } return Maybe::Ok(); } Maybe ReshapeUserOpUtil::EnumerateNdSplitIn2OutAxisGroups( const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh, std::vector>>* nd_sbp_in2out_sig_groups) { int in_axis = in_shape.size(); int out_axis = out_shape.size(); int64_t in_count = 1; int64_t out_count = 1; auto MoveAxis = [](const Shape& shape, int& axis, int64_t& count) { axis--; if (axis >= 0 && axis < shape.size()) { count *= shape[axis]; } }; auto MoveInAxis = [&]() { MoveAxis(in_shape, in_axis, in_count); }; auto MoveOutAxis = [&]() { MoveAxis(out_shape, out_axis, out_count); }; MoveInAxis(); MoveOutAxis(); DimVector group_in_dim_vec; DimVector group_out_dim_vec; std::vector group_in_axes; std::vector group_out_axes; group_in_axes.reserve(rank_mesh.size()); group_out_axes.reserve(rank_mesh.size()); // group reshape dimensions // for example: // (4, 5, 2, 3) reshape to (2, 2, 5, 6) will be divided to 3 groups: // ( 4,| 5, | 2, 3) // (2, 2,| 5, | 6) // group1: (2, 3) -> (6) // group2: (5,) -> (5) // group3: (4,) -> (2, 2) while (in_axis >= 
0 && out_axis >= 0) { // move in_axis when in_count < out_count // move out_axis when out_count < in_count // move both when in_count == out_count if (in_count < out_count) { // skip dim_size == 1 if (in_shape[in_axis] != 1) { group_in_dim_vec.push_back(in_shape[in_axis]); group_in_axes.push_back(in_axis); } MoveInAxis(); } else if (in_count > out_count) { if (out_shape[out_axis] != 1) { group_out_dim_vec.push_back(out_shape[out_axis]); group_out_axes.push_back(out_axis); } MoveOutAxis(); } else { // in_count == out_count if (in_shape[in_axis] == out_shape[out_axis]) { // group2: (5, 5) in the example will reach this branch for (int rank_axis = 0; rank_axis < rank_mesh.size(); ++rank_axis) { int64_t rank_num = rank_mesh[rank_axis]; if (in_shape[in_axis] % rank_num == 0) { std::map> rank_in2out_split_axis{ {rank_axis, std::make_pair(in_axis, out_axis)}}; nd_sbp_in2out_sig_groups->emplace_back(std::move(rank_in2out_split_axis)); } } } else { // the reshape group (group1 and group3 in the example) finish group_in_dim_vec.push_back(in_shape[in_axis]); group_in_axes.push_back(in_axis); group_out_dim_vec.push_back(out_shape[out_axis]); group_out_axes.push_back(out_axis); // enumerate all nd-split signatures for one group JUST(EnumerateNdSplitIn2OutAxis(Shape(group_in_dim_vec), group_in_axes, Shape(group_out_dim_vec), group_out_axes, rank_mesh, nd_sbp_in2out_sig_groups)); group_in_dim_vec.clear(); group_out_dim_vec.clear(); group_in_axes.clear(); group_out_axes.clear(); } MoveInAxis(); MoveOutAxis(); } } return Maybe::Ok(); } Maybe ReshapeUserOpUtil::DfsCombineNdSbpSignatureGroups( const std::vector>>& nd_sbp_sig_groups, size_t rank_num_axes, std::vector>>* nd_sbp_sig_list) { std::map> nd_sbp_sig_group; std::set>> nd_sbp_sig_set; JUST(DfsCombineNdSbpSignatureGroups(nd_sbp_sig_groups, rank_num_axes, nd_sbp_sig_group, nd_sbp_sig_set)); std::copy(nd_sbp_sig_set.begin(), nd_sbp_sig_set.end(), back_inserter(*nd_sbp_sig_list)); return Maybe::Ok(); } Maybe ReshapeUserOpUtil::DfsCombineNdSbpSignatureGroups( const std::vector>>& nd_sbp_sig_groups, size_t rank_num_axes, const std::map>& nd_sbp_sig_group, std::set>>& nd_sbp_sig_set) { if (nd_sbp_sig_group.size() == rank_num_axes) { std::vector> nd_sbp_sig; for (int i = 0; i < rank_num_axes; ++i) { nd_sbp_sig.emplace_back(JUST(MapAt(nd_sbp_sig_group, i))); } nd_sbp_sig_set.emplace(nd_sbp_sig); } else { for (const auto& nd_sbp_sig_group_to_combine : nd_sbp_sig_groups) { std::map> new_nd_sbp_sig_group = nd_sbp_sig_group; bool combine_failed = false; for (const auto& rank_in2out_pair : nd_sbp_sig_group_to_combine) { int rank_axis = rank_in2out_pair.first; if (nd_sbp_sig_group.find(rank_axis) != nd_sbp_sig_group.end()) { combine_failed = true; break; } CHECK_OR_RETURN(new_nd_sbp_sig_group.emplace(rank_axis, rank_in2out_pair.second).second); } if (!combine_failed) { JUST(DfsCombineNdSbpSignatureGroups(nd_sbp_sig_groups, rank_num_axes, new_nd_sbp_sig_group, nd_sbp_sig_set)); } } } return Maybe::Ok(); } Maybe ReshapeUserOpUtil::EnumerateNdSbpIn2OutSignatures( const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh, std::vector>>* nd_sbp_in2out_signatures) { CHECK_GT_OR_RETURN(in_shape.size(), 0) << Error::RuntimeError() << "The dimension of input tensor must be greater than zero, " << "but got " << in_shape.size(); CHECK_GT_OR_RETURN(out_shape.size(), 0) << Error::RuntimeError() << "The dimension of output tensor must be greater than zero, " << "but got " << out_shape.size(); CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt()) << 
Error::RuntimeError() << "The element number of input tensor must be equal to output tensor, " << "but got " << in_shape.elem_cnt() << " and " << out_shape.elem_cnt(); // groups of nd of rank_axis -> (in_axis, out_axis) std::vector>> nd_sbp_signature_groups; JUST(EnumerateNdSplitIn2OutAxisGroups(in_shape, out_shape, rank_mesh, &nd_sbp_signature_groups)); std::map> nd_sbp_in2out_group; for (int rank_axis = 0; rank_axis < rank_mesh.size(); ++rank_axis) { // -1 indicate broadcaste, -2 indicate partial sum nd_sbp_in2out_group.emplace(rank_axis, std::make_pair(-1, -1)); nd_sbp_signature_groups.emplace_back(nd_sbp_in2out_group); nd_sbp_in2out_group.clear(); nd_sbp_in2out_group.emplace(rank_axis, std::make_pair(-2, -2)); nd_sbp_signature_groups.emplace_back(nd_sbp_in2out_group); nd_sbp_in2out_group.clear(); } JUST(DfsCombineNdSbpSignatureGroups(nd_sbp_signature_groups, rank_mesh.size(), nd_sbp_in2out_signatures)); return Maybe::Ok(); } Maybe ReshapeUserOpUtil::FilterNdSbpIn2OutSignatures( const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh, std::vector>>* nd_sbp_in2out_signatures) { // filter the Nd SBP candidates // Go down from the tail to the head, since we might drop the tail. for (int i = nd_sbp_in2out_signatures->size() - 1; i >= 0; --i) { auto& nd_sbp_sig = (*nd_sbp_in2out_signatures)[i]; CHECK_EQ_OR_RETURN(nd_sbp_sig.size(), rank_mesh.size()); bool match_failed = false; DimVector in_dim_vec = in_shape.dim_vec(); DimVector out_dim_vec = out_shape.dim_vec(); for (int rank_axis = 0; rank_axis < nd_sbp_sig.size(); ++rank_axis) { int64_t rank_num = rank_mesh[rank_axis]; int in_sig = nd_sbp_sig[rank_axis].first; int out_sig = nd_sbp_sig[rank_axis].second; if (in_sig >= 0) { if (in_dim_vec[in_sig] % rank_num == 0) { in_dim_vec[in_sig] /= rank_num; } else { match_failed = true; break; } } if (out_sig >= 0) { if (out_dim_vec[out_sig] % rank_num == 0) { out_dim_vec[out_sig] /= rank_num; } else { match_failed = true; break; } } } if (match_failed) { // swap the invalid Nd SBP with the tail and drop it std::swap(nd_sbp_sig, nd_sbp_in2out_signatures->back()); nd_sbp_in2out_signatures->pop_back(); } } return Maybe::Ok(); } Maybe ReshapeUserOpUtil::EnumerateNdSbpSignatures( const std::vector& in_args, const Shape& in_shape, const std::vector& out_args, const Shape& out_shape, const Shape& rank_mesh, std::vector* nd_sbp_sig_list) { CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt()); if (in_shape.elem_cnt() == 0) { return Maybe::Ok(); } if (in_shape.size() == 0 || out_shape.size() == 0) { return Maybe::Ok(); } std::vector>> nd_sbp_in2out_sig_list; JUST(EnumerateNdSbpIn2OutSignatures(in_shape, out_shape, rank_mesh, &nd_sbp_in2out_sig_list)); for (const auto& nd_sbp_in2out_axis : nd_sbp_in2out_sig_list) { nd_sbp_sig_list->emplace_back(); auto& nd_sbp_sig = nd_sbp_sig_list->back(); for (const auto& in2out_axis : nd_sbp_in2out_axis) { for (const auto& in_arg : in_args) { const auto& ibn = in_arg.name() + "_" + std::to_string(in_arg.index()); auto& in_nd_sbp = (*nd_sbp_sig.mutable_bn_in_op2nd_sbp())[ibn]; auto* in_sbp = in_nd_sbp.add_sbp_parallel(); if (in2out_axis.first == -1) { in_sbp->mutable_broadcast_parallel(); } else if (in2out_axis.first == -2) { in_sbp->mutable_partial_sum_parallel(); } else { in_sbp->mutable_split_parallel()->set_axis(in2out_axis.first); } } for (const auto& out_arg : out_args) { const auto& obn = out_arg.name() + "_" + std::to_string(out_arg.index()); auto& out_nd_sbp = (*nd_sbp_sig.mutable_bn_in_op2nd_sbp())[obn]; auto* out_sbp = 
out_nd_sbp.add_sbp_parallel(); if (in2out_axis.second == -1) { out_sbp->mutable_broadcast_parallel(); } else if (in2out_axis.second == -2) { out_sbp->mutable_partial_sum_parallel(); } else { out_sbp->mutable_split_parallel()->set_axis(in2out_axis.second); } } } } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/reshape_user_op_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_OPS_RESHAPE_USER_OP_UTIL #define ONEFLOW_USER_OPS_RESHAPE_USER_OP_UTIL #include "oneflow/core/framework/sbp_context.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/framework.h" namespace oneflow { struct ReshapeUserOpUtil { static Maybe GetLogicalOutBlobShape(const Shape& in_shape, const Shape& reshape); static Maybe Squeeze(const Shape& origin, Shape* shape, HashMap* squeezed_axis2origin_axis); static Maybe GetGroupStartInAxis2OutAxis(const Shape& in_shape, const Shape& out_shape, const int64_t hierarchy_value, HashMap* group_start_in_axis2out_axis); static Maybe GetReshapeUserOpSbpSignatures(const Shape& in_shape, const Shape& out_shape, const std::vector& in_args, const std::vector& out_args, const int64_t hierarchy_value, user_op::UserOpSbpSignatureBuilder* builder); static Maybe DfsCombineNdSbpSignatureGroups( const std::vector>>& nd_sbp_sig_groups, size_t rank_num_axes, std::vector>>* nd_sbp_sig_list); static Maybe DfsCombineNdSbpSignatureGroups( const std::vector>>& nd_sbp_sig_groups, size_t rank_num_axes, const std::map>& nd_sbp_sig_group, std::set>>& nd_sbp_sig_set); static Maybe EnumerateNdSplitIn2OutAxis( const Shape& in_shape, const std::vector& origin_in_axes, const Shape& out_shape, const std::vector& origin_out_axes, const Shape& rank_mesh, std::vector>>* nd_split_groups); static Maybe EnumerateNdSplitIn2OutAxisGroups( const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh, std::vector>>* nd_sbp_in2out_sig_groups); static Maybe EnumerateNdSbpIn2OutSignatures( const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh, std::vector>>* nd_sbp_in2out_signatures); static Maybe FilterNdSbpIn2OutSignatures( const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh, std::vector>>* nd_sbp_in2out_signatures); static Maybe EnumerateNdSbpSignatures(const std::vector& in_args, const Shape& in_shape, const std::vector& out_args, const Shape& out_shape, const Shape& rank_mesh, std::vector* nd_sbp_sig_list); }; } // namespace oneflow #endif // ONEFLOW_USER_OPS_RESHAPE_USER_OP_UTIL ================================================ FILE: oneflow/user/ops/reshape_user_op_util_test.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/ops/reshape_user_op_util.h" #include namespace oneflow { namespace test { namespace { std::string NdSbpIn2OutSignaturesToString( const std::vector>>& nd_split_in2out_axis_list) { std::ostringstream ss; ss << "{"; int i = 0; for (const auto& nd_split_axis : nd_split_in2out_axis_list) { if (i > 0) { ss << ", "; } ss << "{"; int j = 0; for (const auto& split_in2out_axis : nd_split_axis) { if (j > 0) { ss << ", "; } ss << "{" << split_in2out_axis.first << ", " << split_in2out_axis.second << "}"; j++; } ss << "}"; i++; } ss << "}"; return ss.str(); } std::string NdSbpSignatureGroupsToString( const std::vector>>& nd_sbp_signature_groups) { std::ostringstream ss; ss << "{"; int i = 0; for (const auto& nd_sbp_sig_group : nd_sbp_signature_groups) { if (i > 0) { ss << ", "; } ss << "{"; int j = 0; for (const auto& nd_sbp_sig_pair : nd_sbp_sig_group) { if (j > 0) { ss << ", "; } ss << nd_sbp_sig_pair.first << ": {" << nd_sbp_sig_pair.second.first << ", " << nd_sbp_sig_pair.second.second << "}"; j++; } ss << "}"; i++; } ss << "}"; return ss.str(); } void TestEnumerateNdSbpIn2OutSignatures( const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh, const std::vector>>& expected_nd_sbp_in2out_sig_groups, const std::vector>>& expected_nd_sbp_in2out_sig_list) { std::vector>> actual_nd_sbp_in2out_sig_groups; CHECK_JUST(ReshapeUserOpUtil::EnumerateNdSplitIn2OutAxisGroups(in_shape, out_shape, rank_mesh, &actual_nd_sbp_in2out_sig_groups)); std::sort(actual_nd_sbp_in2out_sig_groups.begin(), actual_nd_sbp_in2out_sig_groups.end()); ASSERT_EQ(expected_nd_sbp_in2out_sig_groups.size(), actual_nd_sbp_in2out_sig_groups.size()); for (size_t i = 0; i < actual_nd_sbp_in2out_sig_groups.size(); ++i) { const auto& exp_nd_sbp_sig_group = expected_nd_sbp_in2out_sig_groups[i]; const auto& act_nd_sbp_sig_group = actual_nd_sbp_in2out_sig_groups[i]; ASSERT_EQ(exp_nd_sbp_sig_group.size(), act_nd_sbp_sig_group.size()); for (const auto& act_pair : act_nd_sbp_sig_group) { auto exp_it = exp_nd_sbp_sig_group.find(act_pair.first); ASSERT_TRUE(exp_it != exp_nd_sbp_sig_group.end()); ASSERT_EQ(exp_it->second.first, act_pair.second.first); ASSERT_EQ(exp_it->second.second, act_pair.second.second); } } std::vector>> actual_nd_sbp_in2out_sig_list; CHECK_JUST(ReshapeUserOpUtil::EnumerateNdSbpIn2OutSignatures(in_shape, out_shape, rank_mesh, &actual_nd_sbp_in2out_sig_list)); CHECK_JUST(ReshapeUserOpUtil::FilterNdSbpIn2OutSignatures(in_shape, out_shape, rank_mesh, &actual_nd_sbp_in2out_sig_list)); std::sort(actual_nd_sbp_in2out_sig_list.begin(), actual_nd_sbp_in2out_sig_list.end()); ASSERT_EQ(expected_nd_sbp_in2out_sig_list.size(), actual_nd_sbp_in2out_sig_list.size()); for (size_t i = 0; i < actual_nd_sbp_in2out_sig_list.size(); ++i) { const auto& exp_nd_sbp_sig = expected_nd_sbp_in2out_sig_list[i]; const auto& act_nd_sbp_sig = actual_nd_sbp_in2out_sig_list[i]; ASSERT_EQ(exp_nd_sbp_sig.size(), act_nd_sbp_sig.size()); for (size_t j = 0; j < act_nd_sbp_sig.size(); ++j) { ASSERT_EQ(exp_nd_sbp_sig[j].first, act_nd_sbp_sig[j].first); ASSERT_EQ(exp_nd_sbp_sig[j].second, act_nd_sbp_sig[j].second); } } } } // namespace using 
std::pair; TEST(ReshapeUserOpUtil, EnumerateNdSbpIn2OutSignatures) { // clang-format off // 2D-split TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {4}, /*out_shape*/ {2, 2}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{1, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {0, 1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {12}, /*out_shape*/ {2, 2, 3}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{1, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {0, 1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {2, 4}, /*out_shape*/ {8}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{1, 0}}}, {{1, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {1, 0}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {2, 1, 4}, /*out_shape*/ {8}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{2, 0}}}, {{1, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {2, 0}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {8, 2}, /*out_shape*/ {2, 4, 2}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{0, pair{1, 2}}}, {{1, pair{0, 0}}}, {{1, pair{1, 2}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-2, -2}, {1, 2}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{-1, -1}, {1, 2}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {0, 1}}, {{0, 0}, {1, 2}}, {{1, 2}, {-2, -2}}, {{1, 2}, {-1, -1}}, {{1, 2}, {0, 0}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {8, 1, 2}, /*out_shape*/ {2, 1, 4, 2}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 2}}}, {{0, pair{2, 3}}}, {{1, pair{0, 0}}}, {{1, pair{2, 3}}},}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-2, -2}, {2, 3}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{-1, -1}, {2, 3}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {0, 2}}, {{0, 0}, {2, 3}}, {{2, 3}, {-2, -2}}, {{2, 3}, {-1, -1}}, {{2, 3}, {0, 0}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {3, 2, 3, 5}, /*out_shape*/ {3, 30}, /*rank_mesh*/ {2, 3}, {{{0, pair{1, 1}}}, {{0, pair{1, 1}}, {1, pair{2, 1}}}, {{1, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{1, 1}, {-2, -2}}, {{1, 1}, {-1, -1}}, {{1, 1}, {0, 0}}, {{1, 1}, {2, 1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {2, 4}, /*out_shape*/ {4, 2}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{1, 0}}}, {{1, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {1, 0}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {4, 2}, /*out_shape*/ {2, 4}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{1, pair{0, 0}}}}, {{{-2, 
-2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {0, 1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {4, 3}, /*out_shape*/ {3, 4}, /*rank_mesh*/ {2, 3}, {}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {2, 6}, /*out_shape*/ {4, 3}, /*rank_mesh*/ {2, 3}, {{{0, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {2, 2, 5, 4}, /*out_shape*/ {4, 5, 2, 2}, /*rank_mesh*/ {2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{1, 0}}}, {{0, pair{3, 2}}}, {{0, pair{3, 2}}, {1, pair{3, 3}}}, {{1, pair{0, 0}}}, {{1, pair{3, 2}}}}, {{{-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}}, {{-2, -2}, {3, 2}}, {{-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}}, {{-1, -1}, {3, 2}}, {{0, 0}, {-2, -2}}, {{0, 0}, {-1, -1}}, {{0, 0}, {1, 0}}, {{0, 0}, {3, 2}}, {{3, 2}, {-2, -2}}, {{3, 2}, {-1, -1}}, {{3, 2}, {0, 0}}, {{3, 2}, {3, 3}}}); // 3D-split TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {24}, /*out_shape*/ {2, 4, 3}, /*rank_mesh*/ {2, 2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}, {2, pair{0, 1}}}, {{0, pair{0, 0}}, {2, pair{0, 1}}}, {{1, pair{0, 0}}}, {{1, pair{0, 0}}, {2, pair{0, 1}}}, {{2, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {0, 0}, {-2, -2}}, {{-2, -2}, {0, 0}, {-1, -1}}, {{-2, -2}, {0, 0}, {0, 1}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {0, 0}, {-2, -2}}, {{-1, -1}, {0, 0}, {-1, -1}}, {{-1, -1}, {0, 0}, {0, 1}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {0, 1}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {0, 1}}, {{0, 0}, {0, 1}, {-2, -2}}, {{0, 0}, {0, 1}, {-1, -1}}, {{0, 0}, {0, 1}, {0, 1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {3, 24}, /*out_shape*/ {3, 2, 2, 6}, /*rank_mesh*/ {2, 2, 2}, {{{0, pair{1, 1}}}, {{0, pair{1, 1}}, {1, pair{1, 2}}}, {{0, pair{1, 1}}, {1, pair{1, 2}}, {2, pair{1, 3}}}, {{0, pair{1, 1}}, {2, pair{1, 2}}}, {{1, pair{1, 1}}}, {{1, pair{1, 1}}, {2, pair{1, 2}}}, {{2, pair{1, 1}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {1, 1}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {1, 1}}, {{-2, -2}, {1, 1}, {-2, -2}}, {{-2, -2}, {1, 1}, {-1, -1}}, {{-2, -2}, {1, 1}, {1, 2}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {1, 1}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {1, 1}}, {{-1, -1}, {1, 1}, {-2, -2}}, {{-1, -1}, {1, 1}, {-1, -1}}, {{-1, -1}, {1, 1}, {1, 2}}, {{1, 1}, {-2, -2}, {-2, -2}}, {{1, 1}, {-2, -2}, {-1, -1}}, {{1, 1}, {-2, -2}, {1, 2}}, {{1, 1}, {-1, -1}, {-2, -2}}, {{1, 1}, {-1, -1}, {-1, -1}}, {{1, 1}, {-1, -1}, {1, 2}}, {{1, 1}, {1, 2}, {-2, -2}}, {{1, 1}, {1, 2}, {-1, -1}}, {{1, 1}, {1, 2}, {1, 3}}}); TestEnumerateNdSbpIn2OutSignatures( 
/*in_shape*/ {4, 77, 3}, /*out_shape*/ {2, 2, 77, 3}, /*rank_mesh*/ {2, 2, 3}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{1, pair{0, 0}}}, {{2, pair{2, 3}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {2, 3}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {2, 3}}, {{-2, -2}, {0, 0}, {-2, -2}}, {{-2, -2}, {0, 0}, {-1, -1}}, {{-2, -2}, {0, 0}, {2, 3}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {2, 3}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {2, 3}}, {{-1, -1}, {0, 0}, {-2, -2}}, {{-1, -1}, {0, 0}, {-1, -1}}, {{-1, -1}, {0, 0}, {2, 3}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {2, 3}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {2, 3}}, {{0, 0}, {0, 1}, {-2, -2}}, {{0, 0}, {0, 1}, {-1, -1}}, {{0, 0}, {0, 1}, {2, 3}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {2, 3, 2, 5}, /*out_shape*/ {12, 5}, /*rank_mesh*/ {2, 3, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{1, 0}}}, {{0, pair{0, 0}}, {1, pair{1, 0}}, {2, pair{2, 0}}}, {{2, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {1, 0}, {-2, -2}}, {{0, 0}, {1, 0}, {-1, -1}}, {{0, 0}, {1, 0}, {2, 0}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {2, 1, 3, 2, 5}, /*out_shape*/ {12, 1, 5}, /*rank_mesh*/ {2, 3, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{2, 0}}}, {{0, pair{0, 0}}, {1, pair{2, 0}}, {2, pair{3, 0}}}, {{2, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {2, 0}, {-2, -2}}, {{0, 0}, {2, 0}, {-1, -1}}, {{0, 0}, {2, 0}, {3, 0}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {8, 4}, /*out_shape*/ {2, 2, 8}, /*rank_mesh*/ {2, 2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}, {2, pair{0, 2}}}, {{0, pair{0, 0}}, {2, pair{0, 1}}}, {{1, pair{0, 0}}}, {{1, pair{0, 0}}, {2, pair{0, 1}}}, {{2, pair{0, 0}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {0, 0}, {-2, -2}}, {{-2, -2}, {0, 0}, {-1, -1}}, {{-2, -2}, {0, 0}, {0, 1}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {0, 0}, {-2, -2}}, {{-1, -1}, {0, 0}, {-1, -1}}, {{-1, -1}, {0, 0}, {0, 1}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, 
{-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {0, 1}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {0, 1}}, {{0, 0}, {0, 1}, {-2, -2}}, {{0, 0}, {0, 1}, {-1, -1}}, {{0, 0}, {0, 1}, {0, 2}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {8, 2, 2}, /*out_shape*/ {2, 2, 4, 2}, /*rank_mesh*/ {2, 2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}, {2, pair{0, 2}}}, {{0, pair{0, 0}}, {2, pair{0, 1}}}, {{0, pair{2, 3}}}, {{1, pair{0, 0}}}, {{1, pair{0, 0}}, {2, pair{0, 1}}}, {{1, pair{2, 3}}}, {{2, pair{0, 0}}}, {{2, pair{2, 3}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-2, -2}, {2, 3}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {-1, -1}, {2, 3}}, {{-2, -2}, {0, 0}, {-2, -2}}, {{-2, -2}, {0, 0}, {-1, -1}}, {{-2, -2}, {0, 0}, {0, 1}}, {{-2, -2}, {0, 0}, {2, 3}}, {{-2, -2}, {2, 3}, {-2, -2}}, {{-2, -2}, {2, 3}, {-1, -1}}, {{-2, -2}, {2, 3}, {0, 0}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}, {2, 3}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {-1, -1}, {2, 3}}, {{-1, -1}, {0, 0}, {-2, -2}}, {{-1, -1}, {0, 0}, {-1, -1}}, {{-1, -1}, {0, 0}, {0, 1}}, {{-1, -1}, {0, 0}, {2, 3}}, {{-1, -1}, {2, 3}, {-2, -2}}, {{-1, -1}, {2, 3}, {-1, -1}}, {{-1, -1}, {2, 3}, {0, 0}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {0, 1}}, {{0, 0}, {-2, -2}, {2, 3}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {0, 1}}, {{0, 0}, {-1, -1}, {2, 3}}, {{0, 0}, {0, 1}, {-2, -2}}, {{0, 0}, {0, 1}, {-1, -1}}, {{0, 0}, {0, 1}, {0, 2}}, {{0, 0}, {0, 1}, {2, 3}}, {{0, 0}, {2, 3}, {-2, -2}}, {{0, 0}, {2, 3}, {-1, -1}}, {{0, 0}, {2, 3}, {0, 1}}, {{2, 3}, {-2, -2}, {-2, -2}}, {{2, 3}, {-2, -2}, {-1, -1}}, {{2, 3}, {-2, -2}, {0, 0}}, {{2, 3}, {-1, -1}, {-2, -2}}, {{2, 3}, {-1, -1}, {-1, -1}}, {{2, 3}, {-1, -1}, {0, 0}}, {{2, 3}, {0, 0}, {-2, -2}}, {{2, 3}, {0, 0}, {-1, -1}}, {{2, 3}, {0, 0}, {0, 1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {8, 2, 1, 2}, /*out_shape*/ {2, 2, 1, 4, 2}, /*rank_mesh*/ {2, 2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}, {2, pair{0, 3}}}, {{0, pair{0, 0}}, {2, pair{0, 1}}}, {{0, pair{3, 4}}}, {{1, pair{0, 0}}}, {{1, pair{0, 0}}, {2, pair{0, 1}}}, {{1, pair{3, 4}}}, {{2, pair{0, 0}}}, {{2, pair{3, 4}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-2, -2}, {3, 4}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {-1, -1}, {3, 4}}, {{-2, -2}, {0, 0}, {-2, -2}}, {{-2, -2}, {0, 0}, {-1, -1}}, {{-2, -2}, {0, 0}, {0, 1}}, {{-2, -2}, {0, 0}, {3, 4}}, {{-2, -2}, {3, 4}, {-2, -2}}, {{-2, -2}, {3, 4}, {-1, -1}}, {{-2, -2}, {3, 4}, {0, 0}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}, {3, 4}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {-1, -1}, {3, 4}}, {{-1, -1}, {0, 0}, {-2, -2}}, {{-1, -1}, {0, 0}, {-1, -1}}, {{-1, -1}, {0, 0}, {0, 1}}, {{-1, -1}, {0, 0}, {3, 4}}, {{-1, -1}, {3, 4}, {-2, -2}}, {{-1, -1}, {3, 4}, {-1, -1}}, {{-1, -1}, {3, 4}, {0, 0}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, 
-1}}, {{0, 0}, {-2, -2}, {0, 1}}, {{0, 0}, {-2, -2}, {3, 4}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {0, 1}}, {{0, 0}, {-1, -1}, {3, 4}}, {{0, 0}, {0, 1}, {-2, -2}}, {{0, 0}, {0, 1}, {-1, -1}}, {{0, 0}, {0, 1}, {0, 3}}, {{0, 0}, {0, 1}, {3, 4}}, {{0, 0}, {3, 4}, {-2, -2}}, {{0, 0}, {3, 4}, {-1, -1}}, {{0, 0}, {3, 4}, {0, 1}}, {{3, 4}, {-2, -2}, {-2, -2}}, {{3, 4}, {-2, -2}, {-1, -1}}, {{3, 4}, {-2, -2}, {0, 0}}, {{3, 4}, {-1, -1}, {-2, -2}}, {{3, 4}, {-1, -1}, {-1, -1}}, {{3, 4}, {-1, -1}, {0, 0}}, {{3, 4}, {0, 0}, {-2, -2}}, {{3, 4}, {0, 0}, {-1, -1}}, {{3, 4}, {0, 0}, {0, 1}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {6, 4}, /*out_shape*/ {2, 3, 2, 2}, /*rank_mesh*/ {2, 3, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{0, pair{1, 2}}}, {{0, pair{1, 2}}, {2, pair{1, 3}}}, {{2, pair{0, 0}}}, {{2, pair{1, 2}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-2, -2}, {1, 2}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {-1, -1}, {1, 2}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}, {1, 2}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {-1, -1}, {1, 2}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {1, 2}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {1, 2}}, {{0, 0}, {0, 1}, {-2, -2}}, {{0, 0}, {0, 1}, {-1, -1}}, {{0, 0}, {0, 1}, {1, 2}}, {{1, 2}, {-2, -2}, {-2, -2}}, {{1, 2}, {-2, -2}, {-1, -1}}, {{1, 2}, {-2, -2}, {0, 0}}, {{1, 2}, {-2, -2}, {1, 3}}, {{1, 2}, {-1, -1}, {-2, -2}}, {{1, 2}, {-1, -1}, {-1, -1}}, {{1, 2}, {-1, -1}, {0, 0}}, {{1, 2}, {-1, -1}, {1, 3}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {6, 4, 1}, /*out_shape*/ {2, 1, 3, 2, 2}, /*rank_mesh*/ {2, 3, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 2}}}, {{0, pair{1, 3}}}, {{0, pair{1, 3}}, {2, pair{1, 4}}}, {{2, pair{0, 0}}}, {{2, pair{1, 3}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-2, -2}, {1, 3}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {-1, -1}, {1, 3}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}, {1, 3}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {-1, -1}, {1, 3}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {1, 3}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {1, 3}}, {{0, 0}, {0, 2}, {-2, -2}}, {{0, 0}, {0, 2}, {-1, -1}}, {{0, 0}, {0, 2}, {1, 3}}, {{1, 3}, {-2, -2}, {-2, -2}}, {{1, 3}, {-2, -2}, {-1, -1}}, {{1, 3}, {-2, -2}, {0, 0}}, {{1, 3}, {-2, -2}, {1, 4}}, {{1, 3}, {-1, -1}, {-2, -2}}, {{1, 3}, {-1, -1}, {-1, -1}}, {{1, 3}, {-1, -1}, {0, 0}}, {{1, 3}, {-1, -1}, {1, 4}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {6, 3, 4}, /*out_shape*/ {2, 3, 3, 2, 2}, /*rank_mesh*/ {2, 3, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{0, pair{2, 3}}}, {{0, pair{2, 3}}, {2, pair{2, 4}}}, {{1, pair{1, 2}}}, {{2, pair{0, 0}}}, {{2, pair{2, 3}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-2, -2}, {2, 3}}, {{-2, -2}, 
{-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {-1, -1}, {2, 3}}, {{-2, -2}, {1, 2}, {-2, -2}}, {{-2, -2}, {1, 2}, {-1, -1}}, {{-2, -2}, {1, 2}, {0, 0}}, {{-2, -2}, {1, 2}, {2, 3}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}, {2, 3}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {-1, -1}, {2, 3}}, {{-1, -1}, {1, 2}, {-2, -2}}, {{-1, -1}, {1, 2}, {-1, -1}}, {{-1, -1}, {1, 2}, {0, 0}}, {{-1, -1}, {1, 2}, {2, 3}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {2, 3}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {2, 3}}, {{0, 0}, {0, 1}, {-2, -2}}, {{0, 0}, {0, 1}, {-1, -1}}, {{0, 0}, {0, 1}, {2, 3}}, {{0, 0}, {1, 2}, {-2, -2}}, {{0, 0}, {1, 2}, {-1, -1}}, {{0, 0}, {1, 2}, {2, 3}}, {{2, 3}, {-2, -2}, {-2, -2}}, {{2, 3}, {-2, -2}, {-1, -1}}, {{2, 3}, {-2, -2}, {0, 0}}, {{2, 3}, {-2, -2}, {2, 4}}, {{2, 3}, {-1, -1}, {-2, -2}}, {{2, 3}, {-1, -1}, {-1, -1}}, {{2, 3}, {-1, -1}, {0, 0}}, {{2, 3}, {-1, -1}, {2, 4}}, {{2, 3}, {1, 2}, {-2, -2}}, {{2, 3}, {1, 2}, {-1, -1}}, {{2, 3}, {1, 2}, {0, 0}}, {{2, 3}, {1, 2}, {2, 4}}}); TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {2, 8}, /*out_shape*/ {2, 2, 2, 2}, /*rank_mesh*/ {2, 2, 2}, {{{0, pair{0, 0}}}, {{0, pair{1, 1}}}, {{0, pair{1, 1}}, {1, pair{1, 2}}}, {{0, pair{1, 1}}, {1, pair{1, 2}}, {2, pair{1, 3}}}, {{0, pair{1, 1}}, {2, pair{1, 2}}}, {{1, pair{0, 0}}}, {{1, pair{1, 1}}}, {{1, pair{1, 1}}, {2, pair{1, 2}}}, {{2, pair{0, 0}}}, {{2, pair{1, 1}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-2, -2}, {1, 1}}, {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {-1, -1}, {1, 1}}, {{-2, -2}, {0, 0}, {-2, -2}}, {{-2, -2}, {0, 0}, {-1, -1}}, {{-2, -2}, {0, 0}, {1, 1}}, {{-2, -2}, {1, 1}, {-2, -2}}, {{-2, -2}, {1, 1}, {-1, -1}}, {{-2, -2}, {1, 1}, {0, 0}}, {{-2, -2}, {1, 1}, {1, 2}}, {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}, {1, 1}}, {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {-1, -1}, {1, 1}}, {{-1, -1}, {0, 0}, {-2, -2}}, {{-1, -1}, {0, 0}, {-1, -1}}, {{-1, -1}, {0, 0}, {1, 1}}, {{-1, -1}, {1, 1}, {-2, -2}}, {{-1, -1}, {1, 1}, {-1, -1}}, {{-1, -1}, {1, 1}, {0, 0}}, {{-1, -1}, {1, 1}, {1, 2}}, {{0, 0}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {1, 1}}, {{0, 0}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {1, 1}}, {{0, 0}, {1, 1}, {-2, -2}}, {{0, 0}, {1, 1}, {-1, -1}}, {{0, 0}, {1, 1}, {1, 2}}, {{1, 1}, {-2, -2}, {-2, -2}}, {{1, 1}, {-2, -2}, {-1, -1}}, {{1, 1}, {-2, -2}, {0, 0}}, {{1, 1}, {-2, -2}, {1, 2}}, {{1, 1}, {-1, -1}, {-2, -2}}, {{1, 1}, {-1, -1}, {-1, -1}}, {{1, 1}, {-1, -1}, {0, 0}}, {{1, 1}, {-1, -1}, {1, 2}}, {{1, 1}, {0, 0}, {-2, -2}}, {{1, 1}, {0, 0}, {-1, -1}}, {{1, 1}, {0, 0}, {1, 2}}, {{1, 1}, {1, 2}, {-2, -2}}, {{1, 1}, {1, 2}, {-1, -1}}, {{1, 1}, {1, 2}, {0, 0}}, {{1, 1}, {1, 2}, {1, 3}}}); // 4D-split TestEnumerateNdSbpIn2OutSignatures( /*in_shape*/ {4, 77, 8}, /*out_shape*/ {2, 2, 77, 2, 4}, /*rank_mesh*/ {2, 2, 2, 2}, {{{0, pair{0, 0}}}, {{0, pair{0, 0}}, {1, pair{0, 1}}}, {{0, pair{0, 0}}, {2, pair{0, 1}}}, {{0, pair{0, 0}}, {3, pair{0, 1}}}, {{0, pair{2, 3}}}, {{0, pair{2, 3}}, {1, 
pair{2, 4}}}, {{0, pair{2, 3}}, {1, pair{2, 4}}, {2, pair{2, 4}}}, {{0, pair{2, 3}}, {1, pair{2, 4}}, {3, pair{2, 4}}}, {{0, pair{2, 3}}, {2, pair{2, 4}}}, {{0, pair{2, 3}}, {2, pair{2, 4}}, {3, pair{2, 4}}}, {{0, pair{2, 3}}, {3, pair{2, 4}}}, {{1, pair{0, 0}}}, {{1, pair{0, 0}}, {2, pair{0, 1}}}, {{1, pair{0, 0}}, {3, pair{0, 1}}}, {{1, pair{2, 3}}}, {{1, pair{2, 3}}, {2, pair{2, 4}}}, {{1, pair{2, 3}}, {2, pair{2, 4}}, {3, pair{2, 4}}}, {{1, pair{2, 3}}, {3, pair{2, 4}}}, {{2, pair{0, 0}}}, {{2, pair{0, 0}}, {3, pair{0, 1}}}, {{2, pair{2, 3}}}, {{2, pair{2, 3}}, {3, pair{2, 4}}}, {{3, pair{0, 0}}}, {{3, pair{2, 3}}}}, {{{-2, -2}, {-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {-2, -2}, {0, 0}}, {{-2, -2}, {-2, -2}, {-2, -2}, {2, 3}}, {{-2, -2}, {-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-2, -2}, {-1, -1}, {0, 0}}, {{-2, -2}, {-2, -2}, {-1, -1}, {2, 3}}, {{-2, -2}, {-2, -2}, {0, 0}, {-2, -2}}, {{-2, -2}, {-2, -2}, {0, 0}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}, {0, 1}}, {{-2, -2}, {-2, -2}, {0, 0}, {2, 3}}, {{-2, -2}, {-2, -2}, {2, 3}, {-2, -2}}, {{-2, -2}, {-2, -2}, {2, 3}, {-1, -1}}, {{-2, -2}, {-2, -2}, {2, 3}, {0, 0}}, {{-2, -2}, {-2, -2}, {2, 3}, {2, 4}}, {{-2, -2}, {-1, -1}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-1, -1}, {-2, -2}, {0, 0}}, {{-2, -2}, {-1, -1}, {-2, -2}, {2, 3}}, {{-2, -2}, {-1, -1}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {-1, -1}, {0, 0}}, {{-2, -2}, {-1, -1}, {-1, -1}, {2, 3}}, {{-2, -2}, {-1, -1}, {0, 0}, {-2, -2}}, {{-2, -2}, {-1, -1}, {0, 0}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}, {0, 1}}, {{-2, -2}, {-1, -1}, {0, 0}, {2, 3}}, {{-2, -2}, {-1, -1}, {2, 3}, {-2, -2}}, {{-2, -2}, {-1, -1}, {2, 3}, {-1, -1}}, {{-2, -2}, {-1, -1}, {2, 3}, {0, 0}}, {{-2, -2}, {-1, -1}, {2, 3}, {2, 4}}, {{-2, -2}, {0, 0}, {-2, -2}, {-2, -2}}, {{-2, -2}, {0, 0}, {-2, -2}, {-1, -1}}, {{-2, -2}, {0, 0}, {-2, -2}, {0, 1}}, {{-2, -2}, {0, 0}, {-2, -2}, {2, 3}}, {{-2, -2}, {0, 0}, {-1, -1}, {-2, -2}}, {{-2, -2}, {0, 0}, {-1, -1}, {-1, -1}}, {{-2, -2}, {0, 0}, {-1, -1}, {0, 1}}, {{-2, -2}, {0, 0}, {-1, -1}, {2, 3}}, {{-2, -2}, {0, 0}, {0, 1}, {-2, -2}}, {{-2, -2}, {0, 0}, {0, 1}, {-1, -1}}, {{-2, -2}, {0, 0}, {0, 1}, {2, 3}}, {{-2, -2}, {0, 0}, {2, 3}, {-2, -2}}, {{-2, -2}, {0, 0}, {2, 3}, {-1, -1}}, {{-2, -2}, {0, 0}, {2, 3}, {0, 1}}, {{-2, -2}, {0, 0}, {2, 3}, {2, 4}}, {{-2, -2}, {2, 3}, {-2, -2}, {-2, -2}}, {{-2, -2}, {2, 3}, {-2, -2}, {-1, -1}}, {{-2, -2}, {2, 3}, {-2, -2}, {0, 0}}, {{-2, -2}, {2, 3}, {-2, -2}, {2, 4}}, {{-2, -2}, {2, 3}, {-1, -1}, {-2, -2}}, {{-2, -2}, {2, 3}, {-1, -1}, {-1, -1}}, {{-2, -2}, {2, 3}, {-1, -1}, {0, 0}}, {{-2, -2}, {2, 3}, {-1, -1}, {2, 4}}, {{-2, -2}, {2, 3}, {0, 0}, {-2, -2}}, {{-2, -2}, {2, 3}, {0, 0}, {-1, -1}}, {{-2, -2}, {2, 3}, {0, 0}, {0, 1}}, {{-2, -2}, {2, 3}, {0, 0}, {2, 4}}, {{-2, -2}, {2, 3}, {2, 4}, {-2, -2}}, {{-2, -2}, {2, 3}, {2, 4}, {-1, -1}}, {{-2, -2}, {2, 3}, {2, 4}, {0, 0}}, {{-2, -2}, {2, 3}, {2, 4}, {2, 4}}, {{-1, -1}, {-2, -2}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {-2, -2}, {0, 0}}, {{-1, -1}, {-2, -2}, {-2, -2}, {2, 3}}, {{-1, -1}, {-2, -2}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-2, -2}, {-1, -1}, {0, 0}}, {{-1, -1}, {-2, -2}, {-1, -1}, {2, 3}}, {{-1, -1}, {-2, -2}, {0, 0}, {-2, -2}}, {{-1, -1}, {-2, -2}, {0, 0}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}, {0, 1}}, {{-1, -1}, {-2, 
-2}, {0, 0}, {2, 3}}, {{-1, -1}, {-2, -2}, {2, 3}, {-2, -2}}, {{-1, -1}, {-2, -2}, {2, 3}, {-1, -1}}, {{-1, -1}, {-2, -2}, {2, 3}, {0, 0}}, {{-1, -1}, {-2, -2}, {2, 3}, {2, 4}}, {{-1, -1}, {-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-1, -1}, {-2, -2}, {0, 0}}, {{-1, -1}, {-1, -1}, {-2, -2}, {2, 3}}, {{-1, -1}, {-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {-1, -1}, {0, 0}}, {{-1, -1}, {-1, -1}, {-1, -1}, {2, 3}}, {{-1, -1}, {-1, -1}, {0, 0}, {-2, -2}}, {{-1, -1}, {-1, -1}, {0, 0}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}, {0, 1}}, {{-1, -1}, {-1, -1}, {0, 0}, {2, 3}}, {{-1, -1}, {-1, -1}, {2, 3}, {-2, -2}}, {{-1, -1}, {-1, -1}, {2, 3}, {-1, -1}}, {{-1, -1}, {-1, -1}, {2, 3}, {0, 0}}, {{-1, -1}, {-1, -1}, {2, 3}, {2, 4}}, {{-1, -1}, {0, 0}, {-2, -2}, {-2, -2}}, {{-1, -1}, {0, 0}, {-2, -2}, {-1, -1}}, {{-1, -1}, {0, 0}, {-2, -2}, {0, 1}}, {{-1, -1}, {0, 0}, {-2, -2}, {2, 3}}, {{-1, -1}, {0, 0}, {-1, -1}, {-2, -2}}, {{-1, -1}, {0, 0}, {-1, -1}, {-1, -1}}, {{-1, -1}, {0, 0}, {-1, -1}, {0, 1}}, {{-1, -1}, {0, 0}, {-1, -1}, {2, 3}}, {{-1, -1}, {0, 0}, {0, 1}, {-2, -2}}, {{-1, -1}, {0, 0}, {0, 1}, {-1, -1}}, {{-1, -1}, {0, 0}, {0, 1}, {2, 3}}, {{-1, -1}, {0, 0}, {2, 3}, {-2, -2}}, {{-1, -1}, {0, 0}, {2, 3}, {-1, -1}}, {{-1, -1}, {0, 0}, {2, 3}, {0, 1}}, {{-1, -1}, {0, 0}, {2, 3}, {2, 4}}, {{-1, -1}, {2, 3}, {-2, -2}, {-2, -2}}, {{-1, -1}, {2, 3}, {-2, -2}, {-1, -1}}, {{-1, -1}, {2, 3}, {-2, -2}, {0, 0}}, {{-1, -1}, {2, 3}, {-2, -2}, {2, 4}}, {{-1, -1}, {2, 3}, {-1, -1}, {-2, -2}}, {{-1, -1}, {2, 3}, {-1, -1}, {-1, -1}}, {{-1, -1}, {2, 3}, {-1, -1}, {0, 0}}, {{-1, -1}, {2, 3}, {-1, -1}, {2, 4}}, {{-1, -1}, {2, 3}, {0, 0}, {-2, -2}}, {{-1, -1}, {2, 3}, {0, 0}, {-1, -1}}, {{-1, -1}, {2, 3}, {0, 0}, {0, 1}}, {{-1, -1}, {2, 3}, {0, 0}, {2, 4}}, {{-1, -1}, {2, 3}, {2, 4}, {-2, -2}}, {{-1, -1}, {2, 3}, {2, 4}, {-1, -1}}, {{-1, -1}, {2, 3}, {2, 4}, {0, 0}}, {{-1, -1}, {2, 3}, {2, 4}, {2, 4}}, {{0, 0}, {-2, -2}, {-2, -2}, {-2, -2}}, {{0, 0}, {-2, -2}, {-2, -2}, {-1, -1}}, {{0, 0}, {-2, -2}, {-2, -2}, {0, 1}}, {{0, 0}, {-2, -2}, {-2, -2}, {2, 3}}, {{0, 0}, {-2, -2}, {-1, -1}, {-2, -2}}, {{0, 0}, {-2, -2}, {-1, -1}, {-1, -1}}, {{0, 0}, {-2, -2}, {-1, -1}, {0, 1}}, {{0, 0}, {-2, -2}, {-1, -1}, {2, 3}}, {{0, 0}, {-2, -2}, {0, 1}, {-2, -2}}, {{0, 0}, {-2, -2}, {0, 1}, {-1, -1}}, {{0, 0}, {-2, -2}, {0, 1}, {2, 3}}, {{0, 0}, {-2, -2}, {2, 3}, {-2, -2}}, {{0, 0}, {-2, -2}, {2, 3}, {-1, -1}}, {{0, 0}, {-2, -2}, {2, 3}, {0, 1}}, {{0, 0}, {-2, -2}, {2, 3}, {2, 4}}, {{0, 0}, {-1, -1}, {-2, -2}, {-2, -2}}, {{0, 0}, {-1, -1}, {-2, -2}, {-1, -1}}, {{0, 0}, {-1, -1}, {-2, -2}, {0, 1}}, {{0, 0}, {-1, -1}, {-2, -2}, {2, 3}}, {{0, 0}, {-1, -1}, {-1, -1}, {-2, -2}}, {{0, 0}, {-1, -1}, {-1, -1}, {-1, -1}}, {{0, 0}, {-1, -1}, {-1, -1}, {0, 1}}, {{0, 0}, {-1, -1}, {-1, -1}, {2, 3}}, {{0, 0}, {-1, -1}, {0, 1}, {-2, -2}}, {{0, 0}, {-1, -1}, {0, 1}, {-1, -1}}, {{0, 0}, {-1, -1}, {0, 1}, {2, 3}}, {{0, 0}, {-1, -1}, {2, 3}, {-2, -2}}, {{0, 0}, {-1, -1}, {2, 3}, {-1, -1}}, {{0, 0}, {-1, -1}, {2, 3}, {0, 1}}, {{0, 0}, {-1, -1}, {2, 3}, {2, 4}}, {{0, 0}, {0, 1}, {-2, -2}, {-2, -2}}, {{0, 0}, {0, 1}, {-2, -2}, {-1, -1}}, {{0, 0}, {0, 1}, {-2, -2}, {2, 3}}, {{0, 0}, {0, 1}, {-1, -1}, {-2, -2}}, {{0, 0}, {0, 1}, {-1, -1}, {-1, -1}}, {{0, 0}, {0, 1}, {-1, -1}, {2, 3}}, {{0, 0}, {0, 1}, {2, 3}, {-2, -2}}, {{0, 0}, {0, 1}, {2, 3}, {-1, -1}}, {{0, 0}, {0, 1}, {2, 3}, {2, 4}}, {{0, 0}, {2, 3}, {-2, -2}, {-2, -2}}, {{0, 0}, {2, 3}, {-2, -2}, {-1, -1}}, {{0, 
0}, {2, 3}, {-2, -2}, {0, 1}}, {{0, 0}, {2, 3}, {-2, -2}, {2, 4}}, {{0, 0}, {2, 3}, {-1, -1}, {-2, -2}}, {{0, 0}, {2, 3}, {-1, -1}, {-1, -1}}, {{0, 0}, {2, 3}, {-1, -1}, {0, 1}}, {{0, 0}, {2, 3}, {-1, -1}, {2, 4}}, {{0, 0}, {2, 3}, {0, 1}, {-2, -2}}, {{0, 0}, {2, 3}, {0, 1}, {-1, -1}}, {{0, 0}, {2, 3}, {0, 1}, {2, 4}}, {{0, 0}, {2, 3}, {2, 4}, {-2, -2}}, {{0, 0}, {2, 3}, {2, 4}, {-1, -1}}, {{0, 0}, {2, 3}, {2, 4}, {0, 1}}, {{0, 0}, {2, 3}, {2, 4}, {2, 4}}, {{2, 3}, {-2, -2}, {-2, -2}, {-2, -2}}, {{2, 3}, {-2, -2}, {-2, -2}, {-1, -1}}, {{2, 3}, {-2, -2}, {-2, -2}, {0, 0}}, {{2, 3}, {-2, -2}, {-2, -2}, {2, 4}}, {{2, 3}, {-2, -2}, {-1, -1}, {-2, -2}}, {{2, 3}, {-2, -2}, {-1, -1}, {-1, -1}}, {{2, 3}, {-2, -2}, {-1, -1}, {0, 0}}, {{2, 3}, {-2, -2}, {-1, -1}, {2, 4}}, {{2, 3}, {-2, -2}, {0, 0}, {-2, -2}}, {{2, 3}, {-2, -2}, {0, 0}, {-1, -1}}, {{2, 3}, {-2, -2}, {0, 0}, {0, 1}}, {{2, 3}, {-2, -2}, {0, 0}, {2, 4}}, {{2, 3}, {-2, -2}, {2, 4}, {-2, -2}}, {{2, 3}, {-2, -2}, {2, 4}, {-1, -1}}, {{2, 3}, {-2, -2}, {2, 4}, {0, 0}}, {{2, 3}, {-2, -2}, {2, 4}, {2, 4}}, {{2, 3}, {-1, -1}, {-2, -2}, {-2, -2}}, {{2, 3}, {-1, -1}, {-2, -2}, {-1, -1}}, {{2, 3}, {-1, -1}, {-2, -2}, {0, 0}}, {{2, 3}, {-1, -1}, {-2, -2}, {2, 4}}, {{2, 3}, {-1, -1}, {-1, -1}, {-2, -2}}, {{2, 3}, {-1, -1}, {-1, -1}, {-1, -1}}, {{2, 3}, {-1, -1}, {-1, -1}, {0, 0}}, {{2, 3}, {-1, -1}, {-1, -1}, {2, 4}}, {{2, 3}, {-1, -1}, {0, 0}, {-2, -2}}, {{2, 3}, {-1, -1}, {0, 0}, {-1, -1}}, {{2, 3}, {-1, -1}, {0, 0}, {0, 1}}, {{2, 3}, {-1, -1}, {0, 0}, {2, 4}}, {{2, 3}, {-1, -1}, {2, 4}, {-2, -2}}, {{2, 3}, {-1, -1}, {2, 4}, {-1, -1}}, {{2, 3}, {-1, -1}, {2, 4}, {0, 0}}, {{2, 3}, {-1, -1}, {2, 4}, {2, 4}}, {{2, 3}, {0, 0}, {-2, -2}, {-2, -2}}, {{2, 3}, {0, 0}, {-2, -2}, {-1, -1}}, {{2, 3}, {0, 0}, {-2, -2}, {0, 1}}, {{2, 3}, {0, 0}, {-2, -2}, {2, 4}}, {{2, 3}, {0, 0}, {-1, -1}, {-2, -2}}, {{2, 3}, {0, 0}, {-1, -1}, {-1, -1}}, {{2, 3}, {0, 0}, {-1, -1}, {0, 1}}, {{2, 3}, {0, 0}, {-1, -1}, {2, 4}}, {{2, 3}, {0, 0}, {0, 1}, {-2, -2}}, {{2, 3}, {0, 0}, {0, 1}, {-1, -1}}, {{2, 3}, {0, 0}, {0, 1}, {2, 4}}, {{2, 3}, {0, 0}, {2, 4}, {-2, -2}}, {{2, 3}, {0, 0}, {2, 4}, {-1, -1}}, {{2, 3}, {0, 0}, {2, 4}, {0, 1}}, {{2, 3}, {0, 0}, {2, 4}, {2, 4}}, {{2, 3}, {2, 4}, {-2, -2}, {-2, -2}}, {{2, 3}, {2, 4}, {-2, -2}, {-1, -1}}, {{2, 3}, {2, 4}, {-2, -2}, {0, 0}}, {{2, 3}, {2, 4}, {-2, -2}, {2, 4}}, {{2, 3}, {2, 4}, {-1, -1}, {-2, -2}}, {{2, 3}, {2, 4}, {-1, -1}, {-1, -1}}, {{2, 3}, {2, 4}, {-1, -1}, {0, 0}}, {{2, 3}, {2, 4}, {-1, -1}, {2, 4}}, {{2, 3}, {2, 4}, {0, 0}, {-2, -2}}, {{2, 3}, {2, 4}, {0, 0}, {-1, -1}}, {{2, 3}, {2, 4}, {0, 0}, {0, 1}}, {{2, 3}, {2, 4}, {0, 0}, {2, 4}}, {{2, 3}, {2, 4}, {2, 4}, {-2, -2}}, {{2, 3}, {2, 4}, {2, 4}, {-1, -1}}, {{2, 3}, {2, 4}, {2, 4}, {0, 0}}}); // clang-format on } } // namespace test } // namespace oneflow ================================================ FILE: oneflow/user/ops/rms_norm_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe RmsNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& normalized_shape = ctx->Attr("normalized_shape"); if (ctx->has_input("weight", 0)) { const Shape& w_shape = ctx->InputShape("weight", 0); CHECK_EQ_OR_RETURN(w_shape, normalized_shape) << "expected weight shape " << normalized_shape.ToString() << ", got " << w_shape.ToString(); } CHECK_LE_OR_RETURN(normalized_shape.size(), x_shape.size()) << "invalid normalized shape " << normalized_shape.ToString() << " with input shape " << x_shape.ToString(); size_t batch_ndim = x_shape.size() - normalized_shape.size(); DimVector batch_dims(batch_ndim); for (int i = 0; i < x_shape.size(); ++i) { if (i < batch_ndim) { batch_dims[i] = x_shape[i]; } else { CHECK_EQ_OR_RETURN(normalized_shape[i - batch_ndim], x_shape[i]) << "invalid normalized shape " << normalized_shape.ToString() << " with input shape " << x_shape.ToString(); } } ctx->SetOutputShape("y", 0, x_shape); ctx->SetOutputShape("inv_rms", 0, Shape{batch_dims}); return Maybe::Ok(); } /*static*/ Maybe RmsNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe RmsNormOp::InferDataType(user_op::InferContext* ctx) { DataType x_dtype = ctx->InputDType("x", 0); if (ctx->has_input("weight", 0)) { DataType w_dtype = ctx->InputDType("weight", 0); CHECK_EQ_OR_RETURN(w_dtype, x_dtype) << "RmsNormOp " << ctx->op_name() << " has different input dtype " << DataType_Name(x_dtype) << " and param dtype " << DataType_Name(w_dtype); } ctx->SetOutputDType("y", 0, x_dtype); DataType rms_dtype = x_dtype; if (x_dtype == DataType::kFloat16 || x_dtype == DataType::kBFloat16) { rms_dtype = DataType::kFloat; } ctx->SetOutputDType("inv_rms", 0, rms_dtype); return Maybe::Ok(); } /* static */ Maybe RmsNormOp::GetSbp(user_op::SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const Shape& normalized_shape = ctx->Attr("normalized_shape"); size_t batch_ndim = x_shape.size() - normalized_shape.size(); for (int i = 0; i < batch_ndim; ++i) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Broadcast(user_op::OpArg("weight", 0)) .Split(ctx->outputs(), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe RmsNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->InputShape("dy", 0); CHECK_EQ_OR_RETURN(ctx->InputShape("x", 0), shape); // NOLINT(maybe-need-error-msg) // No need to check weight and inv_rms legality which should be guaranteed by forward op ctx->SetOutputShape("dx", 0, shape); return Maybe::Ok(); } /*static*/ Maybe RmsNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe RmsNormGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /* static */ Maybe RmsNormGradOp::GetSbp(user_op::SbpContext* ctx) { std::vector split_args = {user_op::OpArg("dy", 0), user_op::OpArg("x", 0), user_op::OpArg("inv_rms", 0)}; std::vector broadcast_args; if (ctx->user_op_conf().has_input("weight", 0)) { broadcast_args.emplace_back("weight", 0); } const Shape& b_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("inv_rms", 0).shape(); for (int i = 0; i < b_shape.size(); ++i) { ctx->NewBuilder() .Split(split_args, i) .Broadcast(broadcast_args) 
.Split(ctx->outputs(), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe RmsNormParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->InputShape("dy", 0); CHECK_EQ_OR_RETURN(ctx->InputShape("x", 0), shape); // NOLINT(maybe-need-error-msg) const Shape& b_shape = ctx->InputShape("inv_rms", 0); CHECK_LE_OR_RETURN(b_shape.size(), shape.size()) << "invalid inv_rms shape " << b_shape.ToString() << " with dy shape " << shape.ToString(); size_t n_ndim = shape.size() - b_shape.size(); DimVector n_shape_vec(n_ndim); for (int i = 0; i < shape.size(); ++i) { if (i < b_shape.size()) { CHECK_EQ_OR_RETURN(b_shape[i], shape[i]) << "invalid inv_rms shape " << b_shape.ToString() << " with dy shape " << shape.ToString(); } else { n_shape_vec[i - b_shape.size()] = shape[i]; } } ctx->SetOutputShape("weight_grad", 0, Shape{n_shape_vec}); return Maybe::Ok(); } /*static*/ Maybe RmsNormParamGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe RmsNormParamGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("weight_grad", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /* static */ Maybe RmsNormParamGradOp::GetSbp(user_op::SbpContext* ctx) { const Shape& b_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("inv_rms", 0).shape(); for (int i = 0; i < b_shape.size(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(ctx->outputs()).Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/roc_auc_score_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe RocAucScoreOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); const Shape& pred_shape = ctx->InputTensorDesc("pred", 0).shape(); const Shape& label_shape = ctx->InputTensorDesc("label", 0).shape(); CHECK_EQ_OR_RETURN(pred_shape.elem_cnt(), label_shape.elem_cnt()) << "pred and label MUST have same element count."; out_desc->set_is_dynamic(false); out_desc->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe RocAucScoreOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe RocAucScoreOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /* static */ Maybe RocAucScoreOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kFloat); const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); CHECK_OR_RETURN(IsFloatingDataType(label.data_type()) || IsIntegralDataType(label.data_type())) << "Input `label` data type " << DataType_Name(label.data_type()) << " is not supported."; const user_op::TensorDesc& pred = ctx->InputTensorDesc("pred", 0); CHECK_OR_RETURN(pred.data_type() == DataType::kFloat) << "Input `pred` data type " << DataType_Name(pred.data_type()) << " is not supported."; return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/roi_align_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe RoiAlignOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .Split(user_op::OpArg("rois", 0), 0) .Split(user_op::OpArg("y", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe RoiAlignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& rois_shape = ctx->InputShape("rois", 0); const int32_t pooled_h = ctx->Attr("pooled_h"); const int32_t pooled_w = ctx->Attr("pooled_w"); // x: feature map (N, C, H, W) CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 4) << Error::RuntimeError() << "The dimension of x tensor must be equal to 4, " << "but got " << x_shape.NumAxes(); // rois: (R, 5) CHECK_EQ_OR_RETURN(rois_shape.NumAxes(), 2) << Error::RuntimeError() << "The dimension of rois tensor must be equal to 2, " << "but got " << rois_shape.NumAxes(); CHECK_EQ_OR_RETURN(rois_shape.At(1), 5) << Error::RuntimeError() << "The size of rois tensor must be equal to 5 at dimension 1, " << "but got " << rois_shape.At(1); // y: (R, C, pool_h, pool_w) ctx->SetOutputShape("y", 0, Shape({rois_shape.At(0), x_shape.At(1), pooled_h, pooled_w})); return Maybe::Ok(); } /*static*/ Maybe RoiAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RoiAlignOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe RoiAlignOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* roi_modifier = GetInputArgModifierFn("rois", 0); CHECK_OR_RETURN(roi_modifier != nullptr); // NOLINT(maybe-need-error-msg) roi_modifier->set_requires_grad(false); user_op::InputArgModifier* feat_modifier = GetInputArgModifierFn("x", 0); CHECK_OR_RETURN(feat_modifier != nullptr); // NOLINT(maybe-need-error-msg) feat_modifier->set_requires_grad(true); return Maybe::Ok(); } /*static*/ Maybe RoiAlignGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Broadcast(user_op::OpArg("x_like", 0)) .Split(user_op::OpArg("rois", 0), 0) .Broadcast(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe RoiAlignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const Shape& x_like_shape = ctx->InputShape("x_like", 0); const Shape& rois_shape = ctx->InputShape("rois", 0); const int32_t pooled_h = ctx->Attr("pooled_h"); const int32_t pooled_w = ctx->Attr("pooled_w"); // x: feature map (N, C, H, W) CHECK_EQ_OR_RETURN(x_like_shape.NumAxes(), 4) << Error::RuntimeError() << "The dimension of x_like tensor must be equal to 4, " << "but got " << x_like_shape.NumAxes(); // rois: (R, 5) CHECK_EQ_OR_RETURN(rois_shape.NumAxes(), 2) << Error::RuntimeError() << "The dimension of rois tensor must be equal to 2, " << "but got " << rois_shape.NumAxes(); CHECK_EQ_OR_RETURN(rois_shape.At(1), 5) << Error::RuntimeError() << "The size of rois tensor must be equal to 5 " << "at dimension 1, " << "but got " << rois_shape.At(1); // y: (R, C, pool_h, pool_w) const Shape& y_shape = Shape({rois_shape.At(0), x_like_shape.At(1), pooled_h, pooled_w}); CHECK_EQ_OR_RETURN(y_shape, dy_shape) << Error::RuntimeError() << "Tensors y and dy must have same shape"; ctx->SetOutputShape("dx", 0, x_like_shape); return Maybe::Ok(); } /*static*/ 
Maybe RoiAlignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RoiAlignGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x_like", 0)) << Error::TypeError() << "The dy tensor and x_like tensor must have same type"; ctx->SetOutputDType("dx", 0, ctx->InputDType("x_like", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/roll_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe RollOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); const std::vector& dims = ctx->Attr>("dims"); CHECK_GT_OR_RETURN(dims.size(), 0) << Error::RuntimeError() << "The input list of dims doesn't allow to be empty"; // NOTE(Liang Depeng): (dims.size == 1 && dims[0] == -1) means that user call flow.roll with // dims == None if (dims[0] != -1) { FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { if (std::find(dims.begin(), dims.end(), i) == dims.end()) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } } } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe RollOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); ctx->SetOutputShape("out", 0, in_shape); return Maybe::Ok(); } /*static*/ Maybe RollOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RollOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/rrelu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe<void> RReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); ctx->SetOutputShape("output", 0, in_shape); ctx->SetOutputShape("noise_data", 0, in_shape); return Maybe<void>::Ok(); } /*static*/ Maybe<void> RReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe<void> RReluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build(); } return Maybe<void>::Ok(); } /* static */ Maybe<void> RReluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("output", 0, ctx->InputDType("in", 0)); ctx->SetOutputDType("noise_data", 0, ctx->InputDType("in", 0)); return Maybe<void>::Ok(); } } // namespace oneflow
================================================ FILE: oneflow/user/ops/same_padding_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe<void> SamePaddingOp::GetSbp(user_op::SbpContext* ctx) { const int32_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); const std::string& data_format = ctx->Attr<std::string>("data_format"); ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); const int32_t channel_idx = ChannelIdx(data_format, num_axes); ctx->NewBuilder() .Split(user_op::OpArg("x", 0), channel_idx) .Split(user_op::OpArg("y", 0), channel_idx) .Build(); ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).PartialSum(user_op::OpArg("y", 0)).Build(); return Maybe<void>::Ok(); } /*static*/ Maybe<void> SamePaddingOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); y_desc->set_shape(x_desc.shape()); y_desc->set_is_dynamic(x_desc.is_dynamic()); const std::string& data_format = ctx->Attr<std::string>("data_format"); const auto& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size"); const auto& strides = ctx->Attr<std::vector<int32_t>>("strides"); const auto& dilation_rate = ctx->Attr<std::vector<int32_t>>("dilation_rate"); const size_t idx_offset = IdxOffset(data_format); const int32_t num_spatial_dims = x_desc.shape().NumAxes() - 2; CHECK_EQ_OR_RETURN(num_spatial_dims, kernel_size.size()) << Error::RuntimeError() << "The dimension of x tensor must be equal to the size of kernel_size array plus 2, " << "but got " << num_spatial_dims << " and " << kernel_size.size(); CHECK_EQ_OR_RETURN(num_spatial_dims, strides.size()) << Error::RuntimeError() << "The dimension of x tensor 
must be equal to the size of strides array plus 2, " << "but got " << num_spatial_dims << " and " << strides.size(); CHECK_EQ_OR_RETURN(num_spatial_dims, dilation_rate.size()) << Error::RuntimeError() << "The dimension of x tensor must be equal to the size of dilation_rate array plus 2, " << "but got " << num_spatial_dims << " and " << dilation_rate.size(); DimVector y_dim_vec(x_desc.shape().dim_vec()); for (int32_t i = 0; i < num_spatial_dims; ++i) { int32_t padding_small = 0; int32_t padding_large = 0; JUST(CalcSamePadding(x_desc.shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i), strides.at(i), &padding_small, &padding_large)); y_dim_vec[idx_offset + i] = x_desc.shape().At(idx_offset + i) + padding_small + padding_large; } y_desc->set_shape(Shape(y_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe SamePaddingOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SamePaddingOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe SamePaddingGradOp::GetSbp(user_op::SbpContext* ctx) { const int32_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); const std::string& data_format = ctx->Attr("data_format"); ctx->NewBuilder() .Split(user_op::OpArg("x_like", 0), 0) .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); const int32_t channel_idx = ChannelIdx(data_format, num_axes); ctx->NewBuilder() .Split(user_op::OpArg("x_like", 0), channel_idx) .Split(user_op::OpArg("dy", 0), channel_idx) .Split(user_op::OpArg("dx", 0), channel_idx) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("x_like", 0)) .PartialSum(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("x_like", 0)) .PartialSum(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("x_like", 0)) .Broadcast(user_op::OpArg("dy", 0)) .Broadcast(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe SamePaddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("x_like", 0)); ctx->SetOutputIsDynamic("dx", 0, ctx->InputIsDynamic("x_like", 0)); return Maybe::Ok(); } /*static*/ Maybe SamePaddingGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SamePaddingGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x_like", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/scalar_bitwise_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { #define IMPLEMENT_SCALAR_BITWISE_OP_FUNCS(name) \ /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); \ FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { \ ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); \ } \ return Maybe::Ok(); \ } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); \ ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); \ return Maybe::Ok(); \ } \ /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); \ return Maybe::Ok(); \ } IMPLEMENT_SCALAR_BITWISE_OP_FUNCS(ScalarBitwiseAnd); IMPLEMENT_SCALAR_BITWISE_OP_FUNCS(ScalarBitwiseOr); IMPLEMENT_SCALAR_BITWISE_OP_FUNCS(ScalarBitwiseXor); } // namespace oneflow ================================================ FILE: oneflow/user/ops/scalar_by_tensor_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe TensorDescInferFn(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& scalar = ctx->InputTensorDesc("scalar", 0); CHECK_EQ_OR_RETURN(scalar.shape().elem_cnt(), 1) << Error::RuntimeError() << "The input scalar tensor is not a scalar"; user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_shape(x.shape()); y->set_is_dynamic(x.is_dynamic()); return Maybe::Ok(); } Maybe DataTypeInferFn(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& scalar = ctx->InputTensorDesc("scalar", 0); CHECK_EQ_OR_RETURN(x.data_type(), scalar.data_type()) << Error::TypeError() << "Tensors x and scalar have different type"; user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_data_type(x.data_type()); return Maybe::Ok(); } Maybe GetBasicSbpSignature(user_op::SbpContext* ctx) { const auto& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("y", 0), i) .Broadcast(user_op::OpArg("scalar", 0)) .Build(); } return Maybe::Ok(); } using GetSbpFn = std::function(user_op::SbpContext*)>; GetSbpFn MakeGetSbpFn(GetSbpFn extra) { return [extra](user_op::SbpContext* ctx) -> Maybe { JUST(extra(ctx)); JUST(GetBasicSbpSignature(ctx)); return Maybe::Ok(); }; } } // namespace /*static*/ Maybe ScalarAddByTensorOp::GetSbp(user_op::SbpContext* ctx) { return MakeGetSbpFn([](user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("scalar", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); })(ctx); } /*static*/ Maybe ScalarAddByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return TensorDescInferFn(ctx); } /*static*/ Maybe ScalarAddByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ScalarAddByTensorOp::InferDataType(user_op::InferContext* ctx) { return DataTypeInferFn(ctx); } /*static*/ Maybe HostScalarAddByTensorOp::GetSbp(user_op::SbpContext* ctx) { return MakeGetSbpFn([](user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("scalar", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); })(ctx); } /*static*/ Maybe HostScalarAddByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return TensorDescInferFn(ctx); } /*static*/ Maybe HostScalarAddByTensorOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe HostScalarAddByTensorOp::InferDataType(user_op::InferContext* ctx) { return DataTypeInferFn(ctx); } REGISTER_OP_HOST_MEMORY_INPUT("host_scalar_add_by_tensor", "scalar", 0); /*static*/ Maybe ScalarSubByTensorOp::GetSbp(user_op::SbpContext* ctx) { return MakeGetSbpFn([](user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("scalar", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); })(ctx); } /*static*/ Maybe ScalarSubByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return TensorDescInferFn(ctx); } /*static*/ Maybe ScalarSubByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe 
ScalarSubByTensorOp::InferDataType(user_op::InferContext* ctx) { return DataTypeInferFn(ctx); } /*static*/ Maybe ScalarMulByTensorOp::GetSbp(user_op::SbpContext* ctx) { return MakeGetSbpFn([](user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("scalar", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("scalar", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); })(ctx); } /*static*/ Maybe ScalarMulByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return TensorDescInferFn(ctx); } /*static*/ Maybe ScalarMulByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ScalarMulByTensorOp::InferDataType(user_op::InferContext* ctx) { return DataTypeInferFn(ctx); } /*static*/ Maybe ScalarDivByTensorOp::GetSbp(user_op::SbpContext* ctx) { return MakeGetSbpFn([](user_op::SbpContext* ctx) { ctx->NewBuilder() .PartialSum(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("scalar", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); })(ctx); } /*static*/ Maybe ScalarDivByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return TensorDescInferFn(ctx); } /*static*/ Maybe ScalarDivByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ScalarDivByTensorOp::InferDataType(user_op::InferContext* ctx) { return DataTypeInferFn(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/scalar_logical_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { #define IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(name) \ /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); \ FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { \ ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); \ } \ return Maybe::Ok(); \ } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); \ ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); \ return Maybe::Ok(); \ } \ /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ ctx->SetOutputDType("out", 0, DataType::kBool); \ return Maybe::Ok(); \ } IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalEqual); IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalNotEqual); IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalGreater); IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalGreaterEqual); IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalLess); IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalLessEqual); IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalAnd); IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalOr); IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalXor); } // namespace oneflow ================================================ FILE: oneflow/user/ops/scalar_math_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe GetSbp4ScalarMath(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } Maybe GetSbp4ScalarMul(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } } // namespace #define IMPLEMENT_SCALAR_MATH_OP_FUNCS(op_name, get_sbp_fn) \ /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { return get_sbp_fn(ctx); } \ /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); \ ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0)); \ return Maybe::Ok(); \ } \ /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); \ return Maybe::Ok(); \ } IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarAdd, GetSbp4ScalarMath) IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarFloordiv, GetSbp4ScalarMath) IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarTruncdiv, GetSbp4ScalarMath) IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarFmod, GetSbp4ScalarMath) IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarMul, GetSbp4ScalarMul) IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarDiv, GetSbp4ScalarMul) IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarPow, GetSbp4ScalarMath) IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarReversePow, GetSbp4ScalarMath) #undef IMPLEMENT_SCALAR_MATH_OP_FUNCS /*static*/ Maybe ScalarPowGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe ScalarPowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ScalarPowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ScalarPowGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << Error::TypeError() << "Tensors dy and x must have same type"; ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ScalarReversePowGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe ScalarReversePowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe ScalarReversePowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return 
InferLogicalTensorDesc(ctx); } /*static*/ Maybe ScalarReversePowGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << Error::TypeError() << "Tensors dy and x must have same type"; ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe ScaledDotProductFlashAttentionOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& q_shape = ctx->InputShape("query", 0); const Shape& k_shape = ctx->InputShape("key", 0); const Shape& v_shape = ctx->InputShape("value", 0); auto batch_size = q_shape.At(0); auto seqlen_q = q_shape.At(1); auto num_heads = q_shape.At(2); auto head_size_og = q_shape.At(3); auto seqlen_k = k_shape.At(1); auto num_heads_k = k_shape.At(2); // check input tensor shape. CHECK_EQ_OR_RETURN(batch_size, k_shape.At(0)) << "query has different batch size from key."; CHECK_EQ_OR_RETURN(batch_size, v_shape.At(0)) << "query has different batch size from value."; CHECK_EQ_OR_RETURN(seqlen_k, v_shape.At(1)) << "key has different seqlen from value."; CHECK_EQ_OR_RETURN(num_heads_k, v_shape.At(2)) << "key has different num_heads from value."; CHECK_EQ_OR_RETURN(head_size_og, k_shape.At(3)) << "query has different head_size from key"; CHECK_EQ_OR_RETURN(head_size_og, v_shape.At(3)) << "query has different head_size from value"; // batch size must be positive. CHECK_GT_OR_RETURN(batch_size, 0) << "batch size must be positive"; // only support head dimensions at most 256. CHECK_LE_OR_RETURN(head_size_og, 256) << "only support head dimensions at most 256"; // number of heads in key/value must devide number of heads in query. CHECK_EQ_OR_RETURN(num_heads % num_heads_k, 0) << "number of heads in key/value must devide number of heads in query."; ctx->SetOutputShape("out", 0, Shape({batch_size, seqlen_q, num_heads, head_size_og})); // save for backward ctx->SetOutputShape("softmax_lse", 0, Shape({batch_size, num_heads, seqlen_q})); // save seed and offset for backward. 
ctx->SetOutputShape("rng_state", 0, Shape({2})); return Maybe::Ok(); } Maybe ScaledDotProductFlashAttentionOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ScaledDotProductFlashAttentionOp::InferLogicalTensorDesc(ctx); } Maybe ScaledDotProductFlashAttentionOp::GetSbp(user_op::SbpContext* ctx) { auto parallel_num = ctx->parallel_num(); const Shape& q_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("query", 0).shape(); const Shape& k_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("key", 0).shape(); auto num_heads = q_shape.At(2); auto num_heads_k = k_shape.At(2); bool can_spilt_num_heads = num_heads == num_heads_k || (!(num_heads % parallel_num) && !(num_heads_k % parallel_num)); if (can_spilt_num_heads) { // prior to split on num_heads. ctx->NewBuilder() .Split(user_op::OpArg("query", 0), 2) .Split(user_op::OpArg("key", 0), 2) .Split(user_op::OpArg("value", 0), 2) .Split(user_op::OpArg("out", 0), 2) .Split(user_op::OpArg("softmax", 0), 1) .Broadcast(user_op::OpArg("rng_state", 0)) .Build(); } else { // otherwise split on batch_size. ctx->NewBuilder() .Split(user_op::OpArg("query", 0), 0) .Split(user_op::OpArg("key", 0), 0) .Split(user_op::OpArg("value", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Split(user_op::OpArg("softmax", 0), 0) .Broadcast(user_op::OpArg("rng_state", 0)) .Build(); } return Maybe::Ok(); } Maybe ScaledDotProductFlashAttentionOp::InferDataType(user_op::InferContext* ctx) { auto q_datatype = ctx->InputDType("query", 0); auto k_datatype = ctx->InputDType("key", 0); auto v_datatype = ctx->InputDType("value", 0); CHECK_EQ_OR_RETURN(q_datatype, k_datatype) << "query has different data type from key."; CHECK_EQ_OR_RETURN(q_datatype, v_datatype) << "query has different data type from value."; ctx->SetOutputDType("out", 0, q_datatype); ctx->SetOutputDType("softmax_lse", 0, DataType::kFloat); ctx->SetOutputDType("rng_state", 0, DataType::kUInt64); return Maybe::Ok(); } Maybe ScaledDotProductFlashAttentionGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& dout_shape = ctx->InputShape("grad_out", 0); const Shape& q_shape = ctx->InputShape("query", 0); const Shape& k_shape = ctx->InputShape("key", 0); const Shape& v_shape = ctx->InputShape("value", 0); const Shape& out_shape = ctx->InputShape("out", 0); const Shape& softmax_lse_shape = ctx->InputShape("softmax_lse", 0); auto batch_size = q_shape.At(0); auto seqlen_q = q_shape.At(1); auto num_heads = q_shape.At(2); auto head_size = q_shape.At(3); auto seqlen_k = k_shape.At(1); auto num_heads_k = k_shape.At(2); auto head_size_og = dout_shape.At(3); // check input tensor shape. CHECK_EQ_OR_RETURN(batch_size, k_shape.At(0)) << "query has different batch size from key."; CHECK_EQ_OR_RETURN(batch_size, v_shape.At(0)) << "query has different batch size from value."; CHECK_EQ_OR_RETURN(batch_size, dout_shape.At(0)) << "query has different batch size from grad_out."; CHECK_EQ_OR_RETURN(batch_size, out_shape.At(0)) << "query has different batch size from out."; CHECK_EQ_OR_RETURN(batch_size, softmax_lse_shape.At(0)) << "query has different batch size from softmax_lse."; CHECK_EQ_OR_RETURN(seqlen_k, v_shape.At(1)) << "key has different seqlen from value."; CHECK_EQ_OR_RETURN(num_heads_k, v_shape.At(2)) << "key has different num_heads from value."; // dout should be padded in functional layer if needed. 
CHECK_EQ_OR_RETURN(head_size_og, head_size) << "grad_out has different head_size from query"; CHECK_EQ_OR_RETURN(head_size, k_shape.At(3)) << "query has different head_size from key"; CHECK_EQ_OR_RETURN(head_size, v_shape.At(3)) << "query has different head_size from value"; // batch size must be positive.
CHECK_GT_OR_RETURN(batch_size, 0) << "batch size must be positive"; // only support head dimensions at most 256.
CHECK_LE_OR_RETURN(head_size_og, 256) << "only support head dimensions at most 256"; CHECK_EQ_OR_RETURN(num_heads % num_heads_k, 0) << "number of heads in key/value must divide number of heads in query."; // grad_k/v should be expanded if needed (when num_heads != num_heads_k && num_heads % num_heads_k == 0).
ctx->SetOutputShape("grad_q", 0, Shape({batch_size, seqlen_q, num_heads, head_size})); ctx->SetOutputShape("grad_k", 0, Shape({batch_size, seqlen_k, num_heads, head_size})); ctx->SetOutputShape("grad_v", 0, Shape({batch_size, seqlen_k, num_heads, head_size})); return Maybe<void>::Ok(); } Maybe<void> ScaledDotProductFlashAttentionGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return ScaledDotProductFlashAttentionGradOp::InferLogicalTensorDesc(ctx); } Maybe<void> ScaledDotProductFlashAttentionGradOp::GetSbp(user_op::SbpContext* ctx) { auto parallel_num = ctx->parallel_num(); const Shape& q_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("query", 0).shape(); const Shape& k_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("key", 0).shape(); auto num_heads = q_shape.At(2); auto num_heads_k = k_shape.At(2); bool can_split_num_heads = num_heads == num_heads_k || (!(num_heads % parallel_num) && !(num_heads_k % parallel_num)); if (can_split_num_heads) { // prefer to split on num_heads.
ctx->NewBuilder() .Split(user_op::OpArg("grad_out", 0), 2) .Split(user_op::OpArg("query", 0), 2) .Split(user_op::OpArg("key", 0), 2) .Split(user_op::OpArg("value", 0), 2) .Split(user_op::OpArg("out", 0), 2) .Split(user_op::OpArg("softmax", 0), 1) .Broadcast(user_op::OpArg("rng_state", 0)) .Split(user_op::OpArg("grad_q", 0), 2) .Split(user_op::OpArg("grad_k", 0), 2) .Split(user_op::OpArg("grad_v", 0), 2) .Build(); } else { // otherwise split on batch_size.
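// NOTE(editor): illustrative comment, not part of the original source. This fallback is taken when
// the query and key/value head counts differ and at least one of them is not divisible by
// parallel_num; splitting on the head axis would then leave ranks with unequal shards, so every
// tensor is split on axis 0 (batch) instead. For example, with parallel_num = 4, num_heads = 8 and
// num_heads_k = 2, a query of shape (8, Sq, 8, 64) is sharded into four (2, Sq, 8, 64) pieces.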
ctx->NewBuilder() .Split(user_op::OpArg("grad_out", 0), 0) .Split(user_op::OpArg("query", 0), 0) .Split(user_op::OpArg("key", 0), 0) .Split(user_op::OpArg("value", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Split(user_op::OpArg("softmax", 0), 0) .Broadcast(user_op::OpArg("rng_state", 0)) .Split(user_op::OpArg("grad_q", 0), 0) .Split(user_op::OpArg("grad_k", 0), 0) .Split(user_op::OpArg("grad_v", 0), 0) .Build(); } return Maybe::Ok(); } Maybe ScaledDotProductFlashAttentionGradOp::InferDataType(user_op::InferContext* ctx) { auto dout_datatype = ctx->InputDType("grad_out", 0); auto q_datatype = ctx->InputDType("query", 0); auto k_datatype = ctx->InputDType("key", 0); auto v_datatype = ctx->InputDType("value", 0); auto out_datatype = ctx->InputDType("out", 0); CHECK_EQ_OR_RETURN(q_datatype, k_datatype) << "query has different data type from key."; CHECK_EQ_OR_RETURN(q_datatype, v_datatype) << "query has different data type from value."; CHECK_EQ_OR_RETURN(q_datatype, dout_datatype) << "query has different data type from grad_out."; CHECK_EQ_OR_RETURN(q_datatype, out_datatype) << "query has different data type from out."; ctx->SetOutputDType("grad_q", 0, q_datatype); ctx->SetOutputDType("grad_k", 0, q_datatype); ctx->SetOutputDType("grad_v", 0, q_datatype); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/search_sorted_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe SearchSortedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("values", 0)); return Maybe::Ok(); } /*static*/ Maybe SearchSortedOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe SearchSortedOp::GetSbp(user_op::SbpContext* ctx) { // The current implementation can only do arg_sort in the last dimension and should use // Broadcast (by default) instead of Split for that dimension const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("sorted_sequence", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } /* static */ Maybe SearchSortedOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return Maybe::Ok(); } /* static */ Maybe SearchSortedOp::InferDataType(user_op::InferContext* ctx) { const bool& out_int32 = ctx->Attr("out_int32"); if (out_int32) { ctx->SetOutputDType("out", 0, DataType::kInt32); } else { ctx->SetOutputDType("out", 0, DataType::kInt64); } return Maybe::Ok(); } /* static */ Maybe SearchSortedScalarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, Shape({})); return Maybe::Ok(); } /*static*/ Maybe SearchSortedScalarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe SearchSortedScalarOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe SearchSortedScalarOp::CheckAttr(const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { return Maybe::Ok(); } /* static */ Maybe SearchSortedScalarOp::InferDataType(user_op::InferContext* ctx) { const bool& out_int32 = ctx->Attr("out_int32"); if (out_int32) { ctx->SetOutputDType("out", 0, DataType::kInt32); } else { ctx->SetOutputDType("out", 0, DataType::kInt64); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/selu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SeluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe SeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe SeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SeluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe SeluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "Tensors dy and x must be the same shape"; ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe SeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SeluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)) << Error::TypeError() << "Tensors dy and x must have same type"; ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/sigmoid_cross_entropy_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SigmoidCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { const auto num_out_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0).shape().NumAxes(); FOR_RANGE(int64_t, i, 0, num_out_axes) { ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), i) .Split(user_op::OpArg("label", 0), i) .Split(user_op::OpArg("loss", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SigmoidCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()) << Error::RuntimeError() << "The size of label " << label_desc.shape() << " must match the size of prediction " << prediction_desc.shape(); user_op::TensorDesc* loss_desc = ctx->MutOutputTensorDesc("loss", 0); loss_desc->set_shape(prediction_desc.shape()); loss_desc->set_is_dynamic(prediction_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe SigmoidCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SigmoidCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("loss", 0, ctx->InputDType("prediction", 0)); return Maybe::Ok(); } /*static*/ Maybe SigmoidCrossEntropyOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe SigmoidCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { const auto num_dy_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("loss_diff", 0).shape().NumAxes(); FOR_RANGE(int64_t, i, 0, num_dy_axes) { ctx->NewBuilder() .Split(user_op::OpArg("loss_diff", 0), i) .Split(user_op::OpArg("label", 0), i) .Split(user_op::OpArg("prediction", 0), i) .Split(user_op::OpArg("prediction_diff", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SigmoidCrossEntropyGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); const user_op::TensorDesc& loss_diff_desc = ctx->InputTensorDesc("loss_diff", 0); CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()) << Error::RuntimeError() << "The size of label " << label_desc.shape() << " must match the size of prediction " << prediction_desc.shape(); CHECK_EQ_OR_RETURN(loss_diff_desc.shape(), prediction_desc.shape()) << Error::RuntimeError() << "The size of loss_diff " << loss_diff_desc.shape() << " must match the size of prediction " << prediction_desc.shape(); user_op::TensorDesc* prediction_diff = ctx->MutOutputTensorDesc("prediction_diff", 0); prediction_diff->set_shape(prediction_desc.shape()); prediction_diff->set_is_dynamic(prediction_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe SigmoidCrossEntropyGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SigmoidCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("prediction_diff", 0, ctx->InputDType("prediction", 0)); return Maybe::Ok(); } /*static*/ Maybe 
SigmoidCrossEntropyGradOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/silu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SiluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe SiluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe SiluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SiluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe SiluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SiluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe SiluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SiluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)) << Error::TypeError() << "dy and x are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/skip_layer_norm_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { oneflow::DataType InferParamDataType(const DataType x_data_type) { return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16) ? DataType::kFloat : x_data_type; } } // namespace /* static */ auto SkipLayerNormOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes() - 1; ++i) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("skip", 0), i) .Broadcast(user_op::OpArg("bias", 0)) .Broadcast(user_op::OpArg("gamma", 0)) .Broadcast(user_op::OpArg("beta", 0)) .Split(ctx->outputs(), i) .Build(); } return Maybe::Ok(); } /* static */ auto SkipLayerNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { // check shape of x const Shape& x_shape = ctx->InputShape("x", 0); CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2) << "number of axes of \'x\' should have be greater than or equal to 2, yet get " << x_shape.NumAxes(); // check shape of gamma, beta and bias if (ctx->has_input("gamma", 0)) { const Shape& gamma_shape = ctx->InputShape("gamma", 0); CHECK_EQ_OR_RETURN(gamma_shape.NumAxes(), 1) << "number of axes of \'gamma\' should be equal to 1, yet get " << gamma_shape.NumAxes(); CHECK_EQ_OR_RETURN(gamma_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'gamma\'(" << gamma_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } if (ctx->has_input("beta", 0)) { const Shape& beta_shape = ctx->InputShape("beta", 0); CHECK_EQ_OR_RETURN(beta_shape.NumAxes(), 1) << "number of axes of \'beta\' should be equal to 1, yet get " << beta_shape.NumAxes(); CHECK_EQ_OR_RETURN(beta_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'beta\'(" << beta_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } if (ctx->has_input("bias", 0)) { const Shape& bias_shape = ctx->InputShape("bias", 0); CHECK_EQ_OR_RETURN(bias_shape.NumAxes(), 1) << "number of axes of \'bias\' should be equal to 1, yet get " << bias_shape.NumAxes(); CHECK_EQ_OR_RETURN(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'bias\'(" << bias_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } // check shape of skip if (ctx->has_input("skip", 0)) { const Shape& skip_shape = ctx->InputShape("skip", 0); CHECK_EQ_OR_RETURN(skip_shape, x_shape) << "shape of \'skip\' is not the same as \'x\'"; } // set output shape of y user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc("y", 0); y_tensor->set_shape(x_shape); // set output shape of mean and varience DimVector mean_dim_vec; mean_dim_vec.push_back(x_shape.Count(0, x_shape.NumAxes() - 1)); Shape mean_shape(mean_dim_vec); user_op::TensorDesc* mean_tensor = ctx->MutOutputTensorDesc("mean", 0); user_op::TensorDesc* varience_tensor = ctx->MutOutputTensorDesc("inv_variance", 0); mean_tensor->set_shape(mean_shape); 
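// NOTE(editor): illustrative comment, not part of the original source. The reduce statistics are
// flattened over all leading axes: for x of shape (B, S, H), Count(0, NumAxes() - 1) = B * S, so
// mean and inv_variance are each given shape (B * S,), one statistic per normalized row of x.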
varience_tensor->set_shape(mean_shape); return Maybe::Ok(); } /* static */ auto SkipLayerNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferLogicalTensorDesc(ctx); } /* static */ auto SkipLayerNormOp::InferDataType(user_op::InferContext* ctx) -> Maybe { // obtain input data types DataType x_dtype = ctx->InputDType("x", 0); // check data type of gamma if (ctx->has_input("gamma", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("gamma", 0), x_dtype) << "data type of \'gamma\' is not consitant with \'x\'"; } // check data type of bias if (ctx->has_input("bias", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("bias", 0), x_dtype) << "data type of \'bias\' is not consitant with \'x\'"; } // check data types of beta if (ctx->has_input("beta", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("beta", 0), x_dtype) << "data type of \'beta\' is not consitant with \'x\'"; } // check data types of skip if (ctx->has_input("skip", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("skip", 0), x_dtype) << "data type of \'skip\' is not consitant with \'x\'"; } // set output data type ctx->SetOutputDType("y", 0, x_dtype); ctx->SetOutputDType("mean", 0, InferParamDataType(x_dtype)); ctx->SetOutputDType("inv_variance", 0, InferParamDataType(x_dtype)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/skip_rms_norm_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { oneflow::DataType InferParamDataType(const DataType x_data_type) { return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16) ? 
DataType::kFloat : x_data_type; } } // namespace /* static */ auto SkipRmsNormOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes() - 1; ++i) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("skip", 0), i) .Broadcast(user_op::OpArg("bias", 0)) .Broadcast(user_op::OpArg("weight", 0)) .Split(ctx->outputs(), i) .Build(); } return Maybe::Ok(); } /* static */ auto SkipRmsNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { // check shape of x const Shape& x_shape = ctx->InputShape("x", 0); CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2) << "number of axes of \'x\' should have be greater than or equal to 2, yet get " << x_shape.NumAxes(); // check shape of weight and bias if (ctx->has_input("weight", 0)) { const Shape& weight_shape = ctx->InputShape("weight", 0); CHECK_EQ_OR_RETURN(weight_shape.NumAxes(), 1) << "number of axes of \'weight\' should be equal to 1, yet get " << weight_shape.NumAxes(); CHECK_EQ_OR_RETURN(weight_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'weight\'(" << weight_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } if (ctx->has_input("bias", 0)) { const Shape& bias_shape = ctx->InputShape("bias", 0); CHECK_EQ_OR_RETURN(bias_shape.NumAxes(), 1) << "number of axes of \'bias\' should be equal to 1, yet get " << bias_shape.NumAxes(); CHECK_EQ_OR_RETURN(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1)) << "the size of \'bias\'(" << bias_shape.At(0) << ") is not consistant with the last dimension of \'x\'(" << x_shape.At(x_shape.NumAxes() - 1) << ")"; } // check shape of skip if (ctx->has_input("skip", 0)) { const Shape& skip_shape = ctx->InputShape("skip", 0); CHECK_EQ_OR_RETURN(skip_shape, x_shape) << "shape of \'skip\' is not the same as \'x\'"; } // set output shape of y user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc("y", 0); y_tensor->set_shape(x_shape); // set output shape of inv_rms DimVector inv_rms_dim_vec; inv_rms_dim_vec.push_back(x_shape.Count(0, x_shape.NumAxes() - 1)); Shape inv_rms_shape(inv_rms_dim_vec); user_op::TensorDesc* inv_rms_tensor = ctx->MutOutputTensorDesc("inv_rms", 0); inv_rms_tensor->set_shape(inv_rms_shape); return Maybe::Ok(); } /* static */ auto SkipRmsNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { return InferLogicalTensorDesc(ctx); } /* static */ auto SkipRmsNormOp::InferDataType(user_op::InferContext* ctx) -> Maybe { // obtain input data types DataType x_dtype = ctx->InputDType("x", 0); // check data type of bias if (ctx->has_input("bias", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("bias", 0), x_dtype) << "data type of \'bias\' is not consitant with \'x\'"; } // check data types of weight if (ctx->has_input("weight", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("weight", 0), x_dtype) << "data type of \'weight\' is not consitant with \'x\'"; } // check data types of skip if (ctx->has_input("skip", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("skip", 0), x_dtype) << "data type of \'skip\' is not consitant with \'x\'"; } // set output data type ctx->SetOutputDType("y", 0, x_dtype); ctx->SetOutputDType("inv_rms", 0, InferParamDataType(x_dtype)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/slice_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/user/kernels/slice_util.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/operator/operator.h" namespace oneflow { namespace { bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { if (step != 1) { return false; } if (start != 0) { return false; } if (stop != size) { return false; } return true; } } // namespace /*static*/ Maybe SliceUpdateOp::GetSbp(user_op::SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0).shape(); const int64_t ndim = x_shape.NumAxes(); const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); const auto& step_vec = ctx->Attr>("step"); CHECK_EQ_OR_RETURN(start_vec.size(), ndim) << Error::RuntimeError() << "The size of start list must be equal to the dimension of ref tensor, " << "but got " << start_vec.size() << " and " << ndim; CHECK_EQ_OR_RETURN(stop_vec.size(), ndim) << Error::RuntimeError() << "The size of stop list must be equal to the dimension of ref tensor, " << "but got " << stop_vec.size() << " and " << ndim; CHECK_EQ_OR_RETURN(step_vec.size(), ndim) << Error::RuntimeError() << "The size of step list must be equal to the dimension of ref tensor, " << "but got " << step_vec.size() << " and " << ndim; FOR_RANGE(int64_t, axis, 0, ndim) { ctx->NewBuilder() .Split(user_op::OpArg("ref", 0), axis) .Broadcast(user_op::OpArg("value", 0)) .Split(user_op::OpArg("y", 0), axis) .Build(); // FullSlice support S+S->S if (IsFullSlice(start_vec[axis], stop_vec[axis], step_vec[axis], x_shape.At(axis))) { ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build(); } } ctx->NewBuilder() .PartialSum(user_op::OpArg("ref", 0)) .PartialSum(user_op::OpArg("value", 0)) .PartialSum(user_op::OpArg("y", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe SliceUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); const Shape& value_shape = ctx->InputTensorDesc("value", 0).shape(); const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); const auto& step_vec = ctx->Attr>("step"); CHECK_OR_RETURN(!ref_desc.is_dynamic()) << Error::RuntimeError() << "The ref tensor is not dynamic"; FOR_RANGE(size_t, i, 0, step_vec.size()) { const int64_t step = step_vec.at(i); const int64_t start = start_vec.at(i); const int64_t stop = stop_vec.at(i); CHECK_GT_OR_RETURN(step, 0) << Error::RuntimeError() << "The step list elements must be greater than 0, " << "but got " << step << " at index " << i; CHECK_GE_OR_RETURN(start, 0) << Error::RuntimeError() << "The start list elements must be greater than or equal to 0, " << "but got " << start << " at index " << i; CHECK_GE_OR_RETURN(stop, 0) << Error::RuntimeError() << "The stop list elements must be greater than or equal to 0, " << "but got " << stop << " at index " << i; CHECK_LE_OR_RETURN(start, 
stop)
        << Error::RuntimeError()
        << "The element in start list must be less than or equal to "
           "the element in stop list at index "
        << i << ", but got " << start << " and " << stop;
    CHECK_EQ_OR_RETURN((stop - start + step - 1) / step, value_shape.At(i))
        << Error::RuntimeError()
        << "The size of slice tuple must be equal to the size of value tensor at dimension " << i
        << ", but got " << (stop - start + step - 1) / step << " and " << value_shape.At(i);
  }
  auto* y_desc = ctx->MutOutputTensorDesc("y", 0);
  y_desc->set_shape(ref_desc.shape());
  y_desc->set_is_dynamic(ref_desc.is_dynamic());
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SliceUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
  auto* y_desc = ctx->MutOutputTensorDesc("y", 0);
  y_desc->set_shape(ref_desc.shape());
  y_desc->set_is_dynamic(ref_desc.is_dynamic());
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SliceUpdateOp::InferDataType(user_op::InferContext* ctx) {
  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
  const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0);
  CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type())
      << Error::TypeError() << "Tensors ref and value must have same type";
  auto* y_desc = ctx->MutOutputTensorDesc("y", 0);
  y_desc->set_data_type(ref_desc.data_type());
  y_desc->set_stride(ref_desc.stride());
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SliceOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
  const Shape& in_shape = input_desc.shape();
  int32_t ndim = in_shape.NumAxes();
  const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
  const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
  CHECK_EQ_OR_RETURN(start_vec.size(), ndim)
      << "start_vec's dim not equal to x shape's dim: " << start_vec.size() << " vs " << ndim;
  CHECK_EQ_OR_RETURN(stop_vec.size(), ndim)
      << "stop_vec's dim not equal to x shape's dim: " << stop_vec.size() << " vs " << ndim;
  CHECK_EQ_OR_RETURN(step_vec.size(), ndim)
      << "step_vec's dim not equal to x shape's dim: " << step_vec.size() << " vs " << ndim;
  FOR_RANGE(int64_t, axis, 0, input_desc.shape().NumAxes()) {
    if (IsFullSlice(start_vec[axis], stop_vec[axis], step_vec[axis], in_shape.At(axis))) {
      ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();
    } else {
      ctx->NewBuilder()
          .Split(user_op::OpArg("x", 0), axis)
          .PartialSum(user_op::OpArg("y", 0))
          .Build();
    }
  }
  ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).PartialSum(user_op::OpArg("y", 0)).Build();
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SliceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& x_shape = ctx->InputShape("x", 0);
  const int64_t ndim = x_shape.NumAxes();
  const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
  const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
  DimVector dim_vec(ndim);
  FOR_RANGE(size_t, i, 0, dim_vec.size()) {
    const int64_t step = step_vec.at(i);
    const int64_t start = start_vec.at(i);
    const int64_t stop = stop_vec.at(i);
    CHECK_GT_OR_RETURN(step, 0) << Error::RuntimeError()
                                << "The step list elements must be greater than 0, "
                                << "but got " << step << " at index " << i;
    CHECK_GE_OR_RETURN(start, 0) << Error::RuntimeError()
                                 << "The start list elements must be greater than or equal to 0, "
                                 << "but got " << start << " at index " << i;
    CHECK_GE_OR_RETURN(stop, 0) << Error::RuntimeError() << "The stop list elements must be greater than or 
equal to 0, " << "but got " << stop << " at index " << i; CHECK_LE_OR_RETURN(start, stop) << Error::RuntimeError() << "The element in start list must be less than or equal to " "the element in stop list at index " << i << ", but got " << start << " and " << stop; const int64_t diff = stop - start - 1; dim_vec[i] = diff / step + 1; } ctx->SetOutputShape("y", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe SliceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const int64_t ndim = x_shape.NumAxes(); const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); const auto& step_vec = ctx->Attr>("step"); DimVector dim_vec(ndim); // logical shape in slice attributes FOR_RANGE(size_t, i, 0, dim_vec.size()) { const int64_t step = step_vec[i]; const int64_t start = start_vec[i]; const int64_t stop = stop_vec[i]; CHECK_GT_OR_RETURN(step, 0) << "Slice step must be greater than 0"; CHECK_GE_OR_RETURN(start, 0) << "Slice start must be greater or equal to 0"; CHECK_GE_OR_RETURN(stop, 0) << "Slice stop must be greater or equal to 0"; CHECK_LE_OR_RETURN(start, stop) << "Slice start must be less or equal to stop"; const int64_t diff = stop - start - 1; dim_vec[i] = diff / step + 1; } // Get physical shape with TensorSliceView const NdSbp& y_nd_sbp = ctx->NdSbp4ArgNameAndIndex("y", 0); const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const Shape& logical_shape = Shape(dim_vec); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const TensorSliceView& slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, y_nd_sbp, logical_shape, parallel_id); ctx->SetOutputShape("y", 0, Shape(slice_view.shape())); return Maybe::Ok(); } /*static*/ Maybe SliceOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe SliceGradOp::GetSbp(user_op::SbpContext* ctx) { const Shape& like_shape = ctx->Attr("like_shape"); const int64_t ndim = like_shape.NumAxes(); const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); const auto& step_vec = ctx->Attr>("step"); CHECK_EQ_OR_RETURN(start_vec.size(), ndim) << Error::RuntimeError() << "The size of start list must be equal to the dimension of ref tensor, " << "but got " << start_vec.size() << " and " << ndim; CHECK_EQ_OR_RETURN(stop_vec.size(), ndim) << Error::RuntimeError() << "The size of stop list must be equal to the dimension of ref tensor, " << "but got " << stop_vec.size() << " and " << ndim; CHECK_EQ_OR_RETURN(step_vec.size(), ndim) << Error::RuntimeError() << "The size of step list must be equal to the dimension of ref tensor, " << "but got " << step_vec.size() << " and " << ndim; FOR_RANGE(int, i, 0, ndim) { if (IsFullSlice(start_vec[i], stop_vec[i], step_vec[i], like_shape.At(i))) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } } ctx->NewBuilder().PartialSum(user_op::OpArg("dy", 0)).PartialSum(user_op::OpArg("dx", 0)).Build(); ctx->NewBuilder().Broadcast(user_op::OpArg("dy", 0)).Broadcast(user_op::OpArg("dx", 0)).Build(); return Maybe::Ok(); } /*static*/ Maybe SliceGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& like_shape = ctx->Attr("like_shape"); const Shape& dy_shape = ctx->InputShape("dy", 0); const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); const auto& step_vec = ctx->Attr>("step"); const int64_t ndim = dy_shape.NumAxes(); 
CHECK_EQ_OR_RETURN(start_vec.size(), ndim) << Error::RuntimeError() << "The size of start list must be equal to the dimension of ref tensor, " << "but got " << start_vec.size() << " and " << ndim; CHECK_EQ_OR_RETURN(stop_vec.size(), ndim) << Error::RuntimeError() << "The size of stop list must be equal to the dimension of ref tensor, " << "but got " << stop_vec.size() << " and " << ndim; CHECK_EQ_OR_RETURN(step_vec.size(), ndim) << Error::RuntimeError() << "The size of step list must be equal to the dimension of ref tensor, " << "but got " << step_vec.size() << " and " << ndim; ctx->SetOutputShape("dx", 0, like_shape); return Maybe::Ok(); } /*static*/ Maybe SliceGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { Shape logical_shape = ctx->Attr("like_shape"); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_is_dynamic(dy_desc.is_dynamic()); const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex("dx", 0); dx_desc->set_shape( *JUST(GetPhysicalShape(logical_shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx()))); int dx_ndim = dx_desc->shape().NumAxes(); int dy_ndim = dy_desc.shape().NumAxes(); CHECK_EQ_OR_RETURN(dx_ndim, dy_ndim) << Error::RuntimeError() << "The output dimension (" << dx_ndim << ") should be equal to the input dimension (" << dy_ndim << ") for slice backward"; return Maybe::Ok(); } /*static*/ Maybe SliceGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe SliceGradOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0); dy_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/smooth_l1_loss_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SmoothL1LossOp::GetSbp(user_op::SbpContext* ctx) { const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe SmoothL1LossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()) << Error::RuntimeError() << "input and target are expected to have the same dynamic property, but found " << input_desc.is_dynamic() << " and " << target_desc.is_dynamic(); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << Error::RuntimeError() << "The size of input " << input_desc.shape() << " must match the size of target " << target_desc.shape(); CHECK_GE_OR_RETURN(ctx->Attr("beta"), 0) << Error::RuntimeError() << "beta must be greater than or equal to 0, but found it to be " << ctx->Attr("beta"); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(input_desc.is_dynamic()); out_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } /*static*/ Maybe SmoothL1LossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SmoothL1LossOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()) << Error::TypeError() << "input and target are expected to have the same dtype, but found " << DataType_Name(input_desc.data_type()) << " and " << DataType_Name(target_desc.data_type()); ctx->SetOutputDType("out", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } /*static*/ Maybe SmoothL1LossOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); CHECK_OR_RETURN(target_modifier != nullptr); // NOLINT(maybe-need-error-msg) target_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe SmoothL1LossGradOp::GetSbp(user_op::SbpContext* ctx) { const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("input", 0), i) .Split(user_op::OpArg("target", 0), i) .Split(user_op::OpArg("dx", 0), i) .Split(user_op::OpArg("dy", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SmoothL1LossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); const auto& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()) << Error::RuntimeError() << "input and target are expected to have the same dynamic property, but found " << input_desc.is_dynamic() << " and " << target_desc.is_dynamic(); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << Error::RuntimeError() << "The size of input " << input_desc.shape() 
<< " must match the size of target " << target_desc.shape(); CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape()) << Error::RuntimeError() << "The size of dy " << dy_desc.shape() << " must match the size of target " << target_desc.shape(); CHECK_GE_OR_RETURN(ctx->Attr("beta"), 0) << Error::RuntimeError() << "beta must be greater than or equal to 0, but found it to be " << ctx->Attr("beta"); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_is_dynamic(input_desc.is_dynamic()); dx_desc->set_shape(input_desc.shape()); return Maybe::Ok(); } /*static*/ Maybe SmoothL1LossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SmoothL1LossGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()) << Error::TypeError() << "input and target are expected to have the same dtype, but found " << DataType_Name(input_desc.data_type()) << " and " << DataType_Name(target_desc.data_type()); ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/softmax_cross_entropy_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SoftmaxCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { // ctx->LogicalTensorDesc4InputArgNameAndIndex("out", 0) is not initialized here const auto num_out_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0).shape().NumAxes() - 1; FOR_RANGE(int64_t, i, 0, num_out_axes) { ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), i) .Split(user_op::OpArg("label", 0), i) .Split(user_op::OpArg("prob", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SoftmaxCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); CHECK_EQ_OR_RETURN(prediction_desc.is_dynamic(), label_desc.is_dynamic()) << Error::RuntimeError() << "prediction and label are expected to have the same dynamic property, but found " << prediction_desc.is_dynamic() << " and " << label_desc.is_dynamic(); CHECK_GE_OR_RETURN(prediction_desc.shape().NumAxes(), 2) << Error::RuntimeError() << "The dimension of prediction must be greater than or equal to 2, but found " << prediction_desc.shape().NumAxes(); CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()) << Error::RuntimeError() << "The size of label " << label_desc.shape() << " must match the size of prediction " << prediction_desc.shape(); const int64_t num_out_axes = prediction_desc.shape().NumAxes() - 1; DimVector out_dim_vector; FOR_RANGE(int64_t, i, 0, num_out_axes) { out_dim_vector.emplace_back(prediction_desc.shape().At(i)); } ctx->SetOutputShape("prob", 0, ctx->InputShape("prediction", 0)); ctx->SetOutputIsDynamic("prob", 0, ctx->InputIsDynamic("prediction", 0)); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(prediction_desc.is_dynamic()); out_desc->set_shape(Shape(out_dim_vector)); return Maybe::Ok(); } /*static*/ Maybe SoftmaxCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SoftmaxCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); CHECK_EQ_OR_RETURN(label_desc.data_type(), prediction_desc.data_type()) << Error::TypeError() << "label and prediction are expected to have the same dtype, but found " << DataType_Name(label_desc.data_type()) << " and " << DataType_Name(prediction_desc.data_type()); ctx->SetOutputDType("prob", 0, ctx->InputDType("prediction", 0)); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(prediction_desc.data_type()); return Maybe::Ok(); } /*static*/ Maybe SoftmaxCrossEntropyOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe SoftmaxCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { const auto num_dy_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0).shape().NumAxes(); FOR_RANGE(int64_t, i, 0, num_dy_axes) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("label", 0), i) .Split(user_op::OpArg("prob", 0), i) 
.Split(user_op::OpArg("prediction_diff", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SoftmaxCrossEntropyGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(prob_desc.is_dynamic(), label_desc.is_dynamic()) << Error::RuntimeError() << "prob and label are expected to have the same dynamic property, but found " << prob_desc.is_dynamic() << " and " << label_desc.is_dynamic(); CHECK_GE_OR_RETURN(prob_desc.shape().NumAxes(), 2) << Error::RuntimeError() << "The dimension of prob must be greater than or equal to 2, but found " << prob_desc.shape().NumAxes(); CHECK_EQ_OR_RETURN(dy_desc.shape().NumAxes(), prob_desc.shape().NumAxes() - 1) << Error::RuntimeError() << "The dimension of dy is expected to be less than that of prob by 1, but found " << dy_desc.shape().NumAxes() << " and " << prob_desc.shape().NumAxes() - 1; FOR_RANGE(int64_t, i, 0, dy_desc.shape().NumAxes()) { CHECK_EQ_OR_RETURN(dy_desc.shape().At(i), label_desc.shape().At(i)) << Error::RuntimeError() << "The size of dy (" << dy_desc.shape().At(i) << ") must match the size of label (" << label_desc.shape().At(i) << ") at dimension " << i; } CHECK_EQ_OR_RETURN(label_desc.shape(), prob_desc.shape()) << Error::RuntimeError() << "The size of label " << label_desc.shape() << " must match the size of prob " << prob_desc.shape(); ctx->SetOutputShape("prediction_diff", 0, ctx->InputShape("prob", 0)); ctx->SetOutputIsDynamic("prediction_diff", 0, ctx->InputIsDynamic("prob", 0)); return Maybe::Ok(); } /*static*/ Maybe SoftmaxCrossEntropyGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SoftmaxCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(label_desc.data_type(), prob_desc.data_type()) << Error::TypeError() << "label and prob are expected to have the same dtype, but found " << DataType_Name(label_desc.data_type()) << " and " << DataType_Name(prob_desc.data_type()); CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type()) << Error::TypeError() << "dy and prob are expected to have the same dtype, but found " << DataType_Name(dy_desc.data_type()) << " and " << DataType_Name(prob_desc.data_type()); ctx->SetOutputDType("prediction_diff", 0, ctx->InputDType("prob", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/softmax_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/*static*/ Maybe<void> SoftmaxOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("in", 0), axis)
        .Split(user_op::OpArg("out", 0), axis)
        .Build();
  }
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/*static*/ Maybe<void> SoftmaxOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  return Maybe<void>::Ok();
}

// The logical computation cost of softmax is estimated as the input element count times a
// constant factor. After applying SBP, the cost is divided by the parallel number along each
// split dimension of the input, because splitting the input and using partial sum for the
// output is not a valid SBP for this op for now.
/*static*/ Maybe<double> SoftmaxOp::GetComputeComplexity(
    user_op::ComputeComplexityFnContext* ctx) {
  double logical_computation_cost = ctx->Shape4ArgNameAndIndex("in", 0).elem_cnt() * 10;
  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();
  const auto& nd_sbp_in = ctx->NdSbp4ArgNameAndIndex("in", 0);
  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_in.sbp_parallel_size(); dim_sbp++) {
    if (nd_sbp_in.sbp_parallel(dim_sbp).has_split_parallel()) {
      logical_computation_cost /= parallel_hierarchy->At(dim_sbp);
    }
  }
  return logical_computation_cost;
}

/*static*/ Maybe<void> SoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0);
  FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("y", 0), axis)
        .Split(user_op::OpArg("dy", 0), axis)
        .Split(user_op::OpArg("dx", 0), axis)
        .Build();
  }
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& y_shape = ctx->InputShape("y", 0);
  const Shape& dy_shape = ctx->InputShape("dy", 0);
  CHECK_OR_RETURN(dy_shape == y_shape)
      << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of y "
      << y_shape;
  ctx->SetOutputShape("dx", 0, dy_shape);
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/*static*/ Maybe<void> SoftmaxGradOp::InferDataType(user_op::InferContext* ctx) {
  CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("y", 0))
      << Error::TypeError() << "dy and y are expected to have the same dtype, but found "
      << DataType_Name(ctx->InputDType("dy", 0)) << " and "
      << DataType_Name(ctx->InputDType("y", 0));
  ctx->SetOutputDType("dx", 0, ctx->InputDType("y", 0));
  return Maybe<void>::Ok();
}

}  // namespace oneflow


================================================
FILE: oneflow/user/ops/softplus_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe SoftplusOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /* static */ Maybe SoftplusOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe SoftplusOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe SoftplusOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe SoftplusGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /* static */ Maybe SoftplusGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe SoftplusGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe SoftplusGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)) << Error::TypeError() << "dy and x are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/softshrink_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe SoftShrinkOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /* static */ Maybe SoftShrinkOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe SoftShrinkOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe SoftShrinkOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe SoftShrinkGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of y " << y_shape; ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /* static */ Maybe SoftShrinkGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe SoftShrinkGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe SoftShrinkGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("y", 0)) << Error::TypeError() << "dy and y are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("y", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("y", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/softsign_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SoftsignOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe SoftsignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe SoftsignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SoftsignOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe SoftsignGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SoftsignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe SoftsignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SoftsignGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)) << Error::TypeError() << "dy and x are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/sort_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SortOp::GetSbp(user_op::SbpContext* ctx) { // The current implementation can only do sort in the last dimension and should use Broadcast // (by default) instead of Split for that dimension const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe SortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe SortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SortOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /*static*/ Maybe SortOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { const std::string& direction = op_conf.attr("direction"); CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING") << Error::RuntimeError() << "The input direction parameter value is expected to be ASCENDING or DESCENDING, " << "but found it to be " << direction; return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/sparse_cross_entropy_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe CheckPredictionLabelDesc(const user_op::TensorDesc* prediction_desc, const user_op::TensorDesc* label_desc) { CHECK_EQ_OR_RETURN(prediction_desc->is_dynamic(), label_desc->is_dynamic()) << Error::RuntimeError() << "prediction and label are expected to have the same dynamic property, but found " << prediction_desc->is_dynamic() << " and " << label_desc->is_dynamic(); CHECK_GE_OR_RETURN(prediction_desc->shape().NumAxes(), 2) << Error::RuntimeError() << "The dimension of prediction must be greater than or equal to 2, but found " << prediction_desc->shape().NumAxes(); const int64_t num_out_axes = prediction_desc->shape().NumAxes() - 1; CHECK_EQ_OR_RETURN(label_desc->shape().NumAxes(), num_out_axes) << Error::RuntimeError() << "The dimension of label is expected to be less than that of prediction by 1, but found " << label_desc->shape().NumAxes() << " and " << num_out_axes; FOR_RANGE(int64_t, i, 0, num_out_axes) { CHECK_EQ_OR_RETURN(prediction_desc->shape().At(i), label_desc->shape().At(i)) << Error::RuntimeError() << "The size of prediction (" << prediction_desc->shape().At(i) << ") must match the size of label (" << label_desc->shape().At(i) << ") at dimension " << i; } return Maybe::Ok(); } Maybe InferTensorDescFn(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); JUST(CheckPredictionLabelDesc(&prediction_desc, &label_desc)); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(prediction_desc.is_dynamic()); out_desc->set_shape(label_desc.shape()); return Maybe::Ok(); } Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); JUST(CheckPredictionLabelDesc(&prediction_desc, &label_desc)); CHECK_EQ_OR_RETURN(dy_desc.shape(), label_desc.shape()) << Error::RuntimeError() << "The size of dy " << dy_desc.shape() << " must match the size of label " << label_desc.shape(); ctx->SetOutputShape("prediction_diff", 0, prediction_desc.shape()); ctx->SetOutputIsDynamic("prediction_diff", 0, prediction_desc.is_dynamic()); return Maybe::Ok(); } Maybe InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type())) << Error::TypeError() << "The dtype of label must be integer, but found " << DataType_Name(label_desc.data_type()); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(prediction_desc.data_type()); return Maybe::Ok(); } Maybe InferDataTypeGrad(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type())) << Error::TypeError() << "The dtype of label must be integer, but found " << DataType_Name(label_desc.data_type()); CHECK_EQ_OR_RETURN(dy_desc.data_type(), prediction_desc.data_type()) << 
Error::TypeError() << "dy and prediction are expected to have the same dtype, but found " << DataType_Name(dy_desc.data_type()) << " and " << DataType_Name(prediction_desc.data_type()); ctx->SetOutputDType("prediction_diff", 0, prediction_desc.data_type()); return Maybe::Ok(); } } // namespace /*static*/ Maybe SparseCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe SparseCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDescFn(ctx); } /*static*/ Maybe SparseCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SparseCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { return oneflow::InferDataType(ctx); } /*static*/ Maybe SparseCrossEntropyOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); CHECK_OR_RETURN(label_modifier != nullptr); // NOLINT(maybe-need-error-msg) label_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe SparseCrossEntropyMsOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& prediction = ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0); ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), prediction.shape().NumAxes() - 1) .Broadcast(user_op::OpArg("label", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe SparseCrossEntropyMsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDescFn(ctx); } /*static*/ Maybe SparseCrossEntropyMsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SparseCrossEntropyMsOp::InferDataType(user_op::InferContext* ctx) { return oneflow::InferDataType(ctx); } /*static*/ Maybe SparseCrossEntropyMsOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); CHECK_OR_RETURN(label_modifier != nullptr); // NOLINT(maybe-need-error-msg) label_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe SparseCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("prediction_diff", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe SparseCrossEntropyGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferGradTensorDescFn(ctx); } /*static*/ Maybe SparseCrossEntropyGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SparseCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { return InferDataTypeGrad(ctx); } /*static*/ Maybe SparseCrossEntropyMsGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& prediction = ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0); ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("dy", 0), 0) 
.Split(user_op::OpArg("prediction_diff", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), prediction.shape().NumAxes() - 1) .Broadcast(user_op::OpArg("label", 0)) .Broadcast(user_op::OpArg("dy", 0)) .Split(user_op::OpArg("prediction_diff", 0), prediction.shape().NumAxes() - 1) .Build(); return Maybe::Ok(); } /*static*/ Maybe SparseCrossEntropyMsGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferGradTensorDescFn(ctx); } /*static*/ Maybe SparseCrossEntropyMsGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SparseCrossEntropyMsGradOp::InferDataType(user_op::InferContext* ctx) { return InferDataTypeGrad(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDescFn(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); CHECK_EQ_OR_RETURN(prediction_desc.is_dynamic(), label_desc.is_dynamic()) << Error::RuntimeError() << "prediction and label are expected to have the same dynamic property, but found " << prediction_desc.is_dynamic() << " and " << label_desc.is_dynamic(); CHECK_GE_OR_RETURN(prediction_desc.shape().NumAxes(), 2) << Error::RuntimeError() << "The dimension of prediction must be greater than or equal to 2, but found " << prediction_desc.shape().NumAxes(); const int64_t num_out_axes = prediction_desc.shape().NumAxes() - 1; CHECK_EQ_OR_RETURN(label_desc.shape().NumAxes(), num_out_axes) << Error::RuntimeError() << "The dimension of label is expected to be less than that of prediction by 1, but found " << label_desc.shape().NumAxes() << " and " << num_out_axes; FOR_RANGE(int64_t, i, 0, num_out_axes) { CHECK_EQ_OR_RETURN(prediction_desc.shape().At(i), label_desc.shape().At(i)) << Error::RuntimeError() << "The size of prediction (" << prediction_desc.shape().At(i) << ") must match the size of label (" << label_desc.shape().At(i) << ") at dimension " << i; } ctx->SetOutputIsDynamic("prob", 0, prediction_desc.is_dynamic()); // 'prob' is just for compute prediction's grad, prob's grad will be ignored ctx->SetOutputShape("prob", 0, prediction_desc.shape()); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(prediction_desc.is_dynamic()); out_desc->set_shape(label_desc.shape()); return Maybe::Ok(); } Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); 
CHECK_EQ_OR_RETURN(prob_desc.is_dynamic(), label_desc.is_dynamic()) << Error::RuntimeError() << "prob and label are expected to have the same dynamic property, but found " << prob_desc.is_dynamic() << " and " << label_desc.is_dynamic(); CHECK_GE_OR_RETURN(prob_desc.shape().NumAxes(), 2) << Error::RuntimeError() << "The dimension of prob must be greater than or equal to 2, but found " << prob_desc.shape().NumAxes(); const int64_t num_out_axes = prob_desc.shape().NumAxes() - 1; CHECK_EQ_OR_RETURN(label_desc.shape().NumAxes(), num_out_axes) << Error::RuntimeError() << "The dimension of label is expected to be less than that of prediction by 1, but found " << label_desc.shape().NumAxes() << " and " << num_out_axes; FOR_RANGE(int64_t, i, 0, num_out_axes) { CHECK_EQ_OR_RETURN(prob_desc.shape().At(i), label_desc.shape().At(i)) << Error::RuntimeError() << "The size of prob (" << prob_desc.shape().At(i) << ") must match the size of label (" << label_desc.shape().At(i) << ") at dimension " << i; } CHECK_EQ_OR_RETURN(dy_desc.shape(), label_desc.shape()) << Error::RuntimeError() << "The size of dy " << dy_desc.shape() << " must match the size of label " << label_desc.shape(); ctx->SetOutputShape("prediction_diff", 0, prob_desc.shape()); ctx->SetOutputIsDynamic("prediction_diff", 0, prob_desc.is_dynamic()); return Maybe::Ok(); } Maybe InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type())) << Error::TypeError() << "The dtype of label must be integer, but found " << DataType_Name(label_desc.data_type()); ctx->SetOutputDType("prob", 0, ctx->InputDType("prediction", 0)); ctx->SetOutputDType("out", 0, ctx->InputDType("prediction", 0)); return Maybe::Ok(); } Maybe InferDataTypeGrad(user_op::InferContext* ctx) { const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type())) << Error::TypeError() << "The dtype of label must be integer, but found " << DataType_Name(label_desc.data_type()); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type()) << Error::TypeError() << "dy and prob are expected to have the same dtype, but found " << DataType_Name(dy_desc.data_type()) << " and " << DataType_Name(prob_desc.data_type()); ctx->SetOutputDType("prediction_diff", 0, prob_desc.data_type()); return Maybe::Ok(); } Maybe AddSignature(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("prob", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Build(); return Maybe::Ok(); } Maybe AddMsSignature(user_op::SbpContext* ctx) { const user_op::TensorDesc& prediction = ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0); ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), 0) .Split(user_op::OpArg("prob", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("out", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), prediction.shape().NumAxes() - 1) .Split(user_op::OpArg("prob", 0), prediction.shape().NumAxes() - 1) .Broadcast(user_op::OpArg("label", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } Maybe AddGradSignature(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("label", 0), 0) 
.Split(user_op::OpArg("prob", 0), 0) .Split(user_op::OpArg("prediction_diff", 0), 0) .Build(); return Maybe::Ok(); } Maybe AddGradMsSignature(user_op::SbpContext* ctx) { const user_op::TensorDesc& prob = ctx->LogicalTensorDesc4InputArgNameAndIndex("prob", 0); ctx->NewBuilder() .Split(user_op::OpArg("prob", 0), 0) .Split(user_op::OpArg("label", 0), 0) .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("prediction_diff", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("prob", 0), prob.shape().NumAxes() - 1) .Broadcast(user_op::OpArg("label", 0)) .Broadcast(user_op::OpArg("dy", 0)) .Split(user_op::OpArg("prediction_diff", 0), prob.shape().NumAxes() - 1) .Build(); return Maybe::Ok(); } template (*GetSbpSignature)(user_op::SbpContext*)> Maybe GetSbpFn(user_op::SbpContext* ctx) { JUST(GetSbpSignature(ctx)); return Maybe::Ok(); } } // namespace #define IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(op_name, sbp_sig) \ /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { return sbp_sig(ctx); } \ /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferTensorDescFn(ctx); \ } \ /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ return oneflow::InferDataType(ctx); \ } \ /*static*/ Maybe op_name##Op::ModifyInputArg( \ const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { \ user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \ CHECK_OR_RETURN(label_modifier != nullptr); /* NOLINT(maybe-need-error-msg) */ \ label_modifier->set_requires_grad(false); \ return Maybe::Ok(); \ } IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(SparseSoftmaxCrossEntropy, AddSignature); IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(SparseSoftmaxCrossEntropyMs, AddMsSignature); #undef IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS #define IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(op_name, sbp_sig) \ /*static*/ Maybe op_name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return sbp_sig(ctx); \ } \ /*static*/ Maybe op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferGradTensorDescFn(ctx); \ } \ /*static*/ Maybe op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe op_name##GradOp::InferDataType(user_op::InferContext* ctx) { \ return InferDataTypeGrad(ctx); \ } IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(SparseSoftmaxCrossEntropy, AddGradSignature); IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(SparseSoftmaxCrossEntropyMs, AddGradMsSignature); #undef IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS } // namespace oneflow ================================================ FILE: oneflow/user/ops/split_like_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SplitLikeOp::GetSbp(user_op::SbpContext* ctx) { const auto axis = ctx->Attr("axis"); const int64_t in_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); const int64_t like_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); FOR_RANGE(int64_t, i, 0, like_num_axes) { if (i == axis) { continue; } ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } std::vector like_arg_vec; const size_t like_arg_size = ctx->outputs().size(); like_arg_vec.reserve(like_arg_size); FOR_RANGE(int32_t, i, 0, like_arg_size) { like_arg_vec.emplace_back("like", i); } FOR_RANGE(int64_t, i, like_num_axes, in_num_axes) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Broadcast(like_arg_vec) .Split(ctx->outputs(), i) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .PartialSum(like_arg_vec) .Split(ctx->outputs(), i) .Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(like_arg_vec) .PartialSum(ctx->outputs()) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .Broadcast(like_arg_vec) .PartialSum(ctx->outputs()) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("in", 0)) .PartialSum(like_arg_vec) .Broadcast(ctx->outputs()) .Build(); return Maybe::Ok(); } /*static*/ Maybe SplitLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto axis = ctx->Attr("axis"); const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); int64_t dynamic_dim_size = 0; int64_t static_dim_size = 0; const int64_t in_num_axes = ctx->InputTensorDesc("in", 0).shape().NumAxes(); const int64_t like_num_axes = ctx->InputTensorDesc("like", 0).shape().NumAxes(); CHECK_LE_OR_RETURN(like_num_axes, in_num_axes) << Error::RuntimeError() << "The dimension of like (" << like_num_axes << ") should be less than or equal to input (" << in_num_axes << ")"; CHECK_LT_OR_RETURN(axis, like_num_axes) << Error::RuntimeError() << "The axis (" << axis << ") should be less than the dimension of like (" << like_num_axes << ")"; FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { const user_op::TensorDesc& like_i_desc = ctx->InputTensorDesc("like", i); user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc("out", i); CHECK_EQ_OR_RETURN(like_i_desc.shape().NumAxes(), like_num_axes) << Error::RuntimeError() << "The dimension of like_i (" << like_i_desc.shape().NumAxes() << ") must match the dimension of the first like (" << like_num_axes << ")"; FOR_RANGE(int64_t, j, 0, like_num_axes) { if (j == axis) { if (like_i_desc.is_dynamic()) { dynamic_dim_size += like_i_desc.shape().At(j); } else { static_dim_size += like_i_desc.shape().At(j); } } else { CHECK_EQ_OR_RETURN(in_desc.shape().At(j), like_i_desc.shape().At(j)) << Error::RuntimeError() << "The size of input (" << in_desc.shape().At(j) << ") must match the size of like_i (" << like_i_desc.shape().At(j) << ") at dimension " << j; } } DimVector out_i_dim_vec = like_i_desc.shape().dim_vec(); FOR_RANGE(int64_t, j, like_num_axes, in_num_axes) { out_i_dim_vec.emplace_back(in_desc.shape().At(j)); } out_i_desc->set_shape(Shape(out_i_dim_vec)); out_i_desc->set_is_dynamic(like_i_desc.is_dynamic()); } if (dynamic_dim_size == 0) { CHECK_EQ_OR_RETURN(static_dim_size, in_desc.shape().At(axis)) << Error::RuntimeError() << "In 
non-dynamic shape situation, the total size of like (" << static_dim_size << ") should be equal to the size of input (" << in_desc.shape().At(axis) << ") at dimension " << axis; } else { CHECK_LE_OR_RETURN(static_dim_size, in_desc.shape().At(axis)) << Error::RuntimeError() << "In dynamic shape situation, the total size of like (" << static_dim_size << ") should be less than or equal to the size of input (" << in_desc.shape().At(axis) << ") at dimension " << axis; } return Maybe::Ok(); } /*static*/ Maybe SplitLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SplitLikeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc("out", i); out_i_desc->set_data_type(in_desc.data_type()); } return Maybe::Ok(); } /*static*/ Maybe SplitLikeOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& user_op_conf) { FOR_RANGE(int32_t, i, 0, user_op_conf.input_size("like")) { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", i); CHECK_NOTNULL_OR_RETURN(like_modifier); // NOLINT(maybe-need-error-msg) like_modifier->set_requires_grad(false); } return Maybe::Ok(); } /*static*/ Maybe SplitLikeOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("like") >= 1) << Error::RuntimeError() << "The number of like should be greater than or equal to 1"; CHECK_OR_RETURN(op_conf.output_size("out") >= 1) << Error::RuntimeError() << "The number of output should be greater than or equal to 1"; return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/sqrt_square_sum_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SqrtSquareSumOp::GetSbp(user_op::SbpContext* ctx) { const int64_t num_x_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); FOR_RANGE(int64_t, i, 0, num_x_axes) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).PartialSum(user_op::OpArg("y", 0)).Build(); } return Maybe::Ok(); } /*static*/ Maybe SqrtSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_shape(Shape({})); return Maybe::Ok(); } /*static*/ Maybe SqrtSquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SqrtSquareSumOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/square_relu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SquareReLUOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("y", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe SquareReLUOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SquareReLUOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe SquareReLUOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe SquareReLUGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape) << "InferTensorDesc failed (" << ctx->op_name() << "). Expected x shape " << x_shape.ToString() << " to be equal to dy shape " << dy_shape.ToString(); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /*static*/ Maybe SquareReLUGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SquareReLUGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe SquareReLUGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/square_sum_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SquareSumOp::GetSbp(user_op::SbpContext* ctx) { const int64_t num_x_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); FOR_RANGE(int64_t, i, 0, num_x_axes) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).PartialSum(user_op::OpArg("y", 0)).Build(); } return Maybe::Ok(); } /*static*/ Maybe SquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe SquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SquareSumOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe MultiSquareSumOp::GetSbp(user_op::SbpContext* ctx) { int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) { min_num_axes = std::min(min_num_axes, ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes()); } for (int64_t i = 0; i < min_num_axes; ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build(); } return Maybe::Ok(); } /*static*/ Maybe MultiSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe MultiSquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe MultiSquareSumOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x_0 = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); for (int64_t i = 1; i < ctx->input_size("x"); ++i) { const user_op::TensorDesc& x_i = ctx->InputTensorDesc("x", i); CHECK_EQ_OR_RETURN(x_i.data_type(), x_0.data_type()) << Error::TypeError() << "All tensors are expected to have the same dtype, but found at least two dtypes, " << DataType_Name(x_i.data_type()) << " and " << 
DataType_Name(x_0.data_type()); } y->set_data_type(x_0.data_type()); return Maybe::Ok(); } /*static*/ Maybe MultiSquareSumOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("x") >= 1) << Error::RuntimeError() << "The number of x should be greater than or equal to 1"; return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/squeeze_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe TransformNegativeAxesToPositive(const std::vector& axes_vec, const int32_t num_axes, AxisVector* fixed_axes_vec) { fixed_axes_vec->resize(axes_vec.size()); FOR_RANGE(size_t, i, 0, fixed_axes_vec->size()) { CHECK_GE_OR_RETURN(axes_vec[i], -num_axes); CHECK_LT_OR_RETURN(axes_vec[i], num_axes); fixed_axes_vec->at(i) = axes_vec[i] >= 0 ? axes_vec[i] : axes_vec[i] + num_axes; } return Maybe::Ok(); } Maybe CheckAndLabelAxesToSqueezeMinusOne(const AxisVector& axes, DimVector* dim_vec) { for (const auto& axis : axes) { CHECK_EQ_OR_RETURN(dim_vec->at(axis), 1); dim_vec->at(axis) = -1; } return Maybe::Ok(); } } // namespace /*static*/ Maybe SqueezeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); AxisVector fixed_axes_vec; JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), in_tensor.shape().NumAxes(), &fixed_axes_vec)); DimVector dim_vec = in_tensor.shape().dim_vec(); JUST(CheckAndLabelAxesToSqueezeMinusOne(fixed_axes_vec, &dim_vec)); int32_t out_axis = 0; FOR_RANGE(int32_t, in_axis, 0, dim_vec.size()) { if (dim_vec.at(in_axis) != -1) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), in_axis) .Split(user_op::OpArg("out", 0), out_axis) .Build(); ++out_axis; } } return Maybe::Ok(); } /*static*/ Maybe SqueezeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); AxisVector fixed_axes_vec; JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), in_shape.NumAxes(), &fixed_axes_vec)); DimVector dim_vec = in_shape.dim_vec(); JUST(CheckAndLabelAxesToSqueezeMinusOne(fixed_axes_vec, &dim_vec)); dim_vec.erase(std::remove(dim_vec.begin(), dim_vec.end(), -1), dim_vec.end()); ctx->SetOutputShape("out", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe SqueezeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SqueezeOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/ssp_variable_proxy_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe SspVariableProxyOp::GetSbp(user_op::SbpContext* ctx) { const auto& var_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("var", 0); FOR_RANGE(int64_t, i, 0, var_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("var", 0), i) .Split(user_op::OpArg("ref", 0), i) .Split(user_op::OpArg("value", 0), i) .Build(); } return Maybe::Ok(); } /*static*/ Maybe SspVariableProxyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& var_shape = ctx->InputShape("var", 0); ctx->SetOutputShape("ref", 0, var_shape); ctx->SetOutputShape("value", 0, var_shape); return Maybe::Ok(); } /*static*/ Maybe SspVariableProxyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SspVariableProxyOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("ref", 0, ctx->InputDType("var", 0)); ctx->SetOutputDType("value", 0, ctx->InputDType("var", 0)); return Maybe::Ok(); } /*static*/ Maybe SspVariableProxyOp::ModifyOutputArg( const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("ref", 0); CHECK_OR_RETURN(out_modifier != nullptr); out_modifier->set_is_mutable(true); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/stack_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe StackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("in", 0); const int64_t axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0) << "The axis should be greater than or equal to 0."; const int64_t in_num_axes = first_in_desc.shape().NumAxes(); CHECK_LE_OR_RETURN(axis, in_num_axes) << "The axis should be less than or equal to input num axes."; DimVector out_dim_vec(in_num_axes + 1); for (int i = 0; i < in_num_axes + 1; i++) { if (i == axis) { continue; } else if (i < axis) { out_dim_vec.at(i) = first_in_desc.shape().At(i); } else { out_dim_vec.at(i) = first_in_desc.shape().At(i - 1); } } int64_t dynamic_dim_size = 0; for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), first_in_desc.shape().NumAxes()) << "The num axes of input should be equal to first input's num axes. "; FOR_RANGE(int64_t, i, 0, in_num_axes + 1) { if (i == axis) { if (in_desc.is_dynamic()) { dynamic_dim_size += 1; } else { out_dim_vec.at(axis) += 1; } } else if (i < axis) { CHECK_EQ_OR_RETURN(in_desc.shape().At(i), out_dim_vec.at(i)) << "The input shape at axis " << i << " is not equal to out shape at axis " << i; } else { CHECK_EQ_OR_RETURN(in_desc.shape().At(i - 1), out_dim_vec.at(i)) << "The input shape at axis " << i - 1 << " is not equal to out shape at axis " << i; } } } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); const int64_t max_dim_size = ctx->Attr("max_dim_size"); CHECK_LE_OR_RETURN(out_dim_vec.at(axis), max_dim_size) << "The out shape at axis " << axis << " should be less equal to " << max_dim_size; if (dynamic_dim_size == 0) { out_desc->set_is_dynamic(false); } else { out_desc->set_is_dynamic(true); out_dim_vec.at(axis) = max_dim_size; } out_desc->set_shape(Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe StackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe StackOp::GetSbp(user_op::SbpContext* ctx) { const int64_t axis = ctx->Attr("axis"); const user_op::TensorDesc& first_in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, first_in_desc.shape().NumAxes()) { /* Stack can be view as expand_dims + concat. For stack([(2, 4, 6), (2, 4, 6), axis=1]), it equals to [2, 4, 6]->[2, 1, 4, 6]. concat([2, 1, 4, 6], [2, 1, 4, 6], concat_dim=1) Concat split all the axis except the concat_dim. */ if (i >= axis) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i + 1).Build(); } else { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe StackOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("in", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(first_in_desc.data_type()) << ", but got " << DataType_Name(in_desc.data_type()); } user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_data_type(first_in_desc.data_type()); return Maybe::Ok(); } /*static*/ Maybe StackOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("in") >= 1) << "The size of input should be greater than or equal to 1. "; return Maybe::Ok(); } /*static*/ Maybe StackGradOp::GetSbp(user_op::SbpContext* ctx) { const auto axis = ctx->Attr("axis"); const int64_t like_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); std::vector like_arg_vec; const size_t like_arg_size = ctx->outputs().size(); like_arg_vec.reserve(like_arg_size); FOR_RANGE(int32_t, i, 0, like_arg_size) { like_arg_vec.emplace_back("like", i); } FOR_RANGE(int64_t, i, 0, like_num_axes) { if (i >= axis) { ctx->NewBuilder() .Split(like_arg_vec, i) .Split(ctx->outputs(), i) .Split(user_op::OpArg("in", 0), i + 1) .Build(); } else { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(like_arg_vec) .PartialSum(ctx->outputs()) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .Broadcast(like_arg_vec) .PartialSum(ctx->outputs()) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("in", 0)) .PartialSum(like_arg_vec) .Broadcast(ctx->outputs()) .Build(); return Maybe::Ok(); } /*static*/ Maybe StackGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto axis = ctx->Attr("axis"); const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); int64_t dynamic_dim_size = 0; int64_t static_dim_size = 0; const int64_t in_num_axes = ctx->InputTensorDesc("in", 0).shape().NumAxes(); const int64_t like_num_axes = ctx->InputTensorDesc("like", 0).shape().NumAxes(); CHECK_LE_OR_RETURN(like_num_axes, in_num_axes) << "The num axes of `like` tensor should be less equal to num axes of `in` tensor. "; CHECK_LE_OR_RETURN(axis, like_num_axes) << "The axis should be less equal than num axes of `like` tensor. "; FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { const user_op::TensorDesc& like_i_desc = ctx->InputTensorDesc("like", i); user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc("out", i); CHECK_EQ_OR_RETURN(like_i_desc.shape().NumAxes(), like_num_axes) << "The num axes of `like` tensor at index " << i << " should be equal to first `like` tensor. "; FOR_RANGE(int64_t, j, 0, like_num_axes + 1) { if (j == axis) { if (like_i_desc.is_dynamic()) { dynamic_dim_size += like_i_desc.shape().Count(j); } else { static_dim_size += like_i_desc.shape().Count(j); } } else if (j < axis) { CHECK_EQ_OR_RETURN(in_desc.shape().At(j), like_i_desc.shape().At(j)) << " Stack Grad expects the shape of input tensor is equal to like tensor's. " ", but got " << in_desc.shape().ToString() << " at input and " << like_i_desc.shape().ToString() << "at like "; } else { CHECK_EQ_OR_RETURN(in_desc.shape().At(j), like_i_desc.shape().At(j - 1)) << " Stack Grad expects the shape of input tensor is equal to like tensor's. 
" ", but got " << in_desc.shape().ToString() << " at input and " << like_i_desc.shape().ToString() << "at like "; } } DimVector out_i_dim_vec = like_i_desc.shape().dim_vec(); out_i_desc->set_shape(Shape(out_i_dim_vec)); out_i_desc->set_is_dynamic(like_i_desc.is_dynamic()); } if (dynamic_dim_size == 0) { CHECK_EQ_OR_RETURN(static_dim_size, in_desc.shape().Count(axis)) << "In non dynamic shape situation, the static dim size should be equal to input tensor " "size. "; } else { CHECK_LE_OR_RETURN(static_dim_size, in_desc.shape().Count(axis)) << "In dynamic shape situation, the static dim size should be less equal to input tensor " "size. "; } return Maybe::Ok(); } /*static*/ Maybe StackGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe StackGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc("out", i); out_i_desc->set_data_type(in_desc.data_type()); } return Maybe::Ok(); } /*static*/ Maybe StackGradOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& user_op_conf) { FOR_RANGE(int32_t, i, 0, user_op_conf.input_size("like")) { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", i); CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); } return Maybe::Ok(); } /*static*/ Maybe StackGradOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("like") >= 1) << "The count of like tensor should be greater than or equal to 1. "; CHECK_OR_RETURN(op_conf.output_size("out") >= 1) << "The count of out tensor should be greater than or equal to 1. "; return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/stft_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { const Stride InferOutputStride(const Shape& in_shape, bool onesided = true, bool return_complex = false) { // TODO(yzm):support return_complex int last_dim_size = in_shape.At(2); if (onesided) { last_dim_size = last_dim_size / 2 + 1; } Stride out_stride(in_shape.NumAxes(), 0); if (in_shape.At(0) == 1) { out_stride = {2, 2 * last_dim_size, 1}; } else { out_stride = {last_dim_size * 2 * in_shape.At(1), 2, 2 * last_dim_size, 1}; } return out_stride; } const Shape InferOutputShape(const Shape& in_shape, bool onesided = true, bool return_complex = false) { // TODO(yzm):support return_complex Shape out_shape; int last_dim_size = in_shape.At(2); if (onesided) { last_dim_size = last_dim_size / 2 + 1; } if (in_shape.At(0) == 1) { out_shape = {last_dim_size, in_shape.At(1), 2}; } else { out_shape = {in_shape.At(0), last_dim_size, in_shape.At(1), 2}; } return out_shape; } /* static */ Maybe StftOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("input", 0); const bool onesided = ctx->Attr("onesided"); const Stride& out_stride = InferOutputStride(in_shape, onesided); const Shape& out_shape = InferOutputShape(in_shape, onesided); ctx->SetOutputStride("output", 0, out_stride); ctx->SetOutputShape("output", 0, out_shape); ctx->SetOutputIsDynamic("output", 0, ctx->InputIsDynamic("input", 0)); return Maybe::Ok(); } /*static*/ Maybe StftOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe StftOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe StftOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0)); return Maybe::Ok(); } /*static*/ Maybe StftOp::GetComputeComplexity(user_op::ComputeComplexityFnContext* ctx) { // TODO: add ComputeComplexityFun return 0.0; } } // namespace oneflow ================================================ FILE: oneflow/user/ops/summary_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe CheckStepShape(const Shape* step) { CHECK_OR_RETURN(step->elem_cnt() == 1); return Maybe::Ok(); } Maybe CheckStepShapeInCtx(user_op::InferContext* ctx) { JUST(CheckStepShape(&ctx->InputShape("step", 0))); return Maybe::Ok(); } Maybe CheckInAndStepScalar(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); const Shape& step_shape = ctx->InputShape("step", 0); CHECK_OR_RETURN(in_shape.elem_cnt() == 1 && step_shape.elem_cnt() == 1); return Maybe::Ok(); } } // namespace /*static*/ Maybe CreateSummaryWriterOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe CreateSummaryWriterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe CreateSummaryWriterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe CreateSummaryWriterOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe FlushSummaryWriterOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe FlushSummaryWriterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe FlushSummaryWriterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe FlushSummaryWriterOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe SummaryWriteScalarOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe SummaryWriteScalarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return CheckInAndStepScalar(ctx); } /*static*/ Maybe SummaryWriteScalarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SummaryWriteScalarOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe SummaryWriteHistogramOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe SummaryWriteHistogramOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return CheckStepShapeInCtx(ctx); } /*static*/ Maybe SummaryWriteHistogramOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SummaryWriteHistogramOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe SummaryWritePbOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe SummaryWritePbOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return CheckStepShapeInCtx(ctx); } /*static*/ Maybe SummaryWritePbOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SummaryWritePbOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe SummaryWriteImageOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe SummaryWriteImageOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return CheckStepShapeInCtx(ctx); } /*static*/ Maybe SummaryWriteImageOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe 
SummaryWriteImageOp::InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/tanh_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe TanhOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /*static*/ Maybe TanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe TanhOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TanhOp::InferDataType(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); } /*static*/ Maybe TanhGradOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /*static*/ Maybe TanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe TanhGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TanhGradOp::InferDataType(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/tensor_buffer_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe TensorBufferToTensorOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe TensorBufferToTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_is_dynamic(in.is_dynamic()); const auto& instance_shape = ctx->Attr("instance_shape"); DimVector dim_vec; dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cend()); dim_vec.insert(dim_vec.end(), instance_shape.dim_vec().cbegin(), instance_shape.dim_vec().cend()); out->set_shape(Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe TensorBufferToTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TensorBufferToTensorOp::InferDataType(user_op::InferContext* ctx) { const auto data_type = ctx->Attr("dtype"); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(IsTriviallyCopyableDataType(data_type)); out->set_data_type(data_type); return Maybe::Ok(); } /*static*/ Maybe TensorToTensorBufferOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); const auto& instance_dims = ctx->Attr("instance_dims"); CHECK_LE_OR_RETURN(instance_dims, in.shape().NumAxes()); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - instance_dims) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /*static*/ Maybe TensorToTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); const Shape& in_shape = in.shape(); const auto& instance_dims = ctx->Attr("instance_dims"); CHECK_LT_OR_RETURN(instance_dims, in_shape.NumAxes()); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_is_dynamic(in.is_dynamic()); DimVector out_dim_vec; out_dim_vec.insert(out_dim_vec.end(), in_shape.dim_vec().cbegin(), in_shape.dim_vec().cend() - instance_dims); out->set_shape(Shape(out_dim_vec)); return Maybe::Ok(); } /*static*/ Maybe TensorToTensorBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TensorToTensorBufferOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(IsTriviallyCopyableDataType(in.data_type())); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(DataType::kTensorBuffer); return Maybe::Ok(); } /*static*/ Maybe GenTensorBufferOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe GenTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); const Shape& shape = ctx->Attr("shape"); const int64_t num_tensor_buffers = shape.elem_cnt(); const std::vector& shape_list = ctx->Attr>("shape_list"); const std::vector& value_list = ctx->Attr>("value_list"); CHECK_EQ_OR_RETURN(num_tensor_buffers, shape_list.size()); 
CHECK_EQ_OR_RETURN(num_tensor_buffers, value_list.size()); out->set_shape(shape); out->set_is_dynamic(ctx->Attr("dynamic_out")); return Maybe::Ok(); } /*static*/ Maybe GenTensorBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe GenTensorBufferOp::InferDataType(user_op::InferContext* ctx) { user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(DataType::kTensorBuffer); return Maybe::Ok(); } /*static*/ Maybe TensorBufferToListOfTensorsOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe TensorBufferToListOfTensorsOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); CHECK_OR_RETURN(!in.is_dynamic()); const Shape& out_shape = ctx->Attr("out_shape"); const bool dynamic_out = ctx->Attr("dynamic_out"); int64_t num_tensor_buffers = in.shape().elem_cnt(); for (int64_t i = 0; i < num_tensor_buffers; ++i) { user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc("out", i); out_i->set_shape(out_shape); out_i->set_is_dynamic(dynamic_out); } return Maybe::Ok(); } /*static*/ Maybe TensorBufferToListOfTensorsOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TensorBufferToListOfTensorsOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(in.data_type()); const DataType out_dtype = ctx->Attr("out_dtype"); CHECK_OR_RETURN(IsTriviallyCopyableDataType(out_dtype)); int64_t num_tensor_buffers = ctx->outputs().size(); for (int64_t i = 0; i < num_tensor_buffers; ++i) { user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc("out", i); out_i->set_data_type(out_dtype); } return Maybe::Ok(); } /*static*/ Maybe TensorBufferToListOfTensorsOp::ModifyOutputArg( const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { if (conf.attr("dynamic_out")) { FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); CHECK_OR_RETURN(out_i_modifier != nullptr); out_i_modifier->set_header_infered_before_compute(false); } } return Maybe::Ok(); } /*static*/ Maybe TensorBufferToListOfTensorsOp::CheckAttr( const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.output_size("out") >= 1); return Maybe::Ok(); } /*static*/ Maybe TensorBufferToListOfTensorsV2Op::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe TensorBufferToListOfTensorsV2Op::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); CHECK_OR_RETURN(!in.is_dynamic()); const std::vector& out_shapes = ctx->Attr>("out_shapes"); const bool dynamic_out = ctx->Attr("dynamic_out"); int64_t num_tensor_buffers = in.shape().elem_cnt(); for (int64_t i = 0; i < num_tensor_buffers; ++i) { user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc("out", i); out_i->set_shape(out_shapes[i]); out_i->set_is_dynamic(dynamic_out); } return Maybe::Ok(); } /*static*/ Maybe 
TensorBufferToListOfTensorsV2Op::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TensorBufferToListOfTensorsV2Op::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer) << "InferDataType Failed. Expected " << DataType_Name(DataType::kTensorBuffer) << ", but got " << DataType_Name(in.data_type()); const std::vector& out_dtypes = ctx->Attr>("out_dtypes"); int64_t num_tensor_buffers = ctx->outputs().size(); for (int64_t i = 0; i < num_tensor_buffers; ++i) { CHECK_OR_RETURN(IsTriviallyCopyableDataType(out_dtypes[i])); user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc("out", i); out_i->set_data_type(out_dtypes[i]); } return Maybe::Ok(); } /*static*/ Maybe TensorBufferToListOfTensorsV2Op::ModifyOutputArg( const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { if (conf.attr("dynamic_out")) { FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); CHECK_OR_RETURN(out_i_modifier != nullptr); out_i_modifier->set_header_infered_before_compute(false); } } return Maybe::Ok(); } /*static*/ Maybe TensorBufferToListOfTensorsV2Op::CheckAttr( const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.output_size("out") >= 1); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/tensor_constant_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/job/nd_sbp_util.h" namespace oneflow { /* static */ Maybe TensorConstantOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, Shape(ctx->Attr("shape").dim_vec())); return Maybe::Ok(); } /*static*/ Maybe TensorConstantOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); const Shape& logical_shape = ctx->Attr("shape"); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); ctx->SetOutputShape("out", 0, physical_shape); return Maybe::Ok(); } /* static */ Maybe TensorConstantOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /* static */ Maybe TensorConstantOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { SbpParallel default_sbp; default_sbp.mutable_broadcast_parallel(); return user_op::InferNdSbp4SrcOp(ctx, default_sbp); } /* static */ Maybe TensorConstantOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->Attr("dtype")); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/tf_pool_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/utils/pool_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { // Logically computation cost of pool op is the product of output data amount and pool kernal data // amount. After adding sbp, we just divide it by parallel number if output data is splitted because // splitting input and using partial sum for output is not a valid sbp for this op for now. 
Maybe GetComputationCost(user_op::ComputeComplexityFnContext* ctx) { const std::vector pool_size = ctx->Attr>("pool_size"); double logical_computation_cost = std::accumulate(pool_size.begin(), pool_size.end(), ctx->Shape4ArgNameAndIndex("y", 0).elem_cnt(), std::multiplies()); const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy(); const auto& nd_sbp_y = ctx->NdSbp4ArgNameAndIndex("y", 0); for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_y.sbp_parallel_size(); dim_sbp++) { if (nd_sbp_y.sbp_parallel(dim_sbp).has_split_parallel()) { logical_computation_cost /= parallel_hierarchy->At(dim_sbp); } } return logical_computation_cost; } typedef std::function(user_op::InferContext* ctx)> TensorDescInferFn; TensorDescInferFn MakeFwTensorDescInferFn(const int32_t dim) { return [dim](user_op::InferContext* ctx) -> Maybe { const Shape& x_shape = ctx->InputShape("x", 0); const std::string& data_format = ctx->Attr("data_format"); const std::string& padding = ctx->Attr("padding"); const auto& padding_before = ctx->Attr>("padding_before"); const auto& padding_after = ctx->Attr>("padding_after"); const std::vector pool_size = ctx->Attr>("pool_size"); const std::vector strides = ctx->Attr>("strides"); const bool ceil_mode = ctx->Attr("ceil_mode"); CHECK_EQ_OR_RETURN(pool_size.size(), dim); for (int32_t pool_dim : pool_size) { CHECK_GT_OR_RETURN(pool_dim, 0); } CHECK_EQ_OR_RETURN(strides.size(), dim); for (int32_t stride_dim : strides) { CHECK_GT_OR_RETURN(stride_dim, 0); } const Params3D params_3d(dim, x_shape, data_format, padding, padding_before, padding_after, pool_size, strides, ceil_mode); user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); y_desc->set_shape(params_3d.GetYShape()); y_desc->set_is_dynamic(ctx->InputIsDynamic("x", 0)); return Maybe::Ok(); }; } Maybe BwTensorDescInferFn(user_op::InferContext* ctx) { ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); ctx->SetOutputIsDynamic("dx", 0, ctx->InputIsDynamic("x", 0)); return Maybe::Ok(); } Maybe FwInferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe BwInferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } Maybe FwGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } return Maybe::Ok(); } Maybe BwGetSbpFn(user_op::SbpContext* ctx) { const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("y", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } } // namespace #define IMPLEMENT_TF_POOL_FUNCS(name, dim) \ /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return MakeFwTensorDescInferFn(dim)(ctx); \ } \ /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ return FwInferDataType(ctx); \ } \ /*static*/ Maybe name##Op::GetComputeComplexity( \ user_op::ComputeComplexityFnContext* ctx) { \ return 
GetComputationCost(ctx); \ } IMPLEMENT_TF_POOL_FUNCS(TfAvgPool1D, 1) IMPLEMENT_TF_POOL_FUNCS(TfAvgPool2D, 2) IMPLEMENT_TF_POOL_FUNCS(TfAvgPool3D, 3) IMPLEMENT_TF_POOL_FUNCS(TfMaxPool1D, 1) IMPLEMENT_TF_POOL_FUNCS(TfMaxPool2D, 2) IMPLEMENT_TF_POOL_FUNCS(TfMaxPool3D, 3) #undef IMPLEMENT_TF_POOL_FUNCS #define IMPLEMENT_TF_POOL_BACKWARD_FUNCS(name) \ /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return BwGetSbpFn(ctx); \ } \ /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return BwTensorDescInferFn(ctx); \ } \ /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ return BwInferDataType(ctx); \ } \ /*static*/ Maybe name##GradOp::GetComputeComplexity( \ user_op::ComputeComplexityFnContext* ctx) { \ return GetComputationCost(ctx); \ } IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool1D) IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool2D) IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool3D) IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool1D) IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool2D) IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool3D) #undef IMPLEMENT_TF_POOL_BACKWARD_FUNCS } // namespace oneflow ================================================ FILE: oneflow/user/ops/tf_prelu_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe TfPreluOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); ctx->NewBuilder() .Split(user_op::OpArg("x", 0), 0) .Broadcast(user_op::OpArg("alpha", 0)) .Split(user_op::OpArg("y", 0), 0) .Build(); FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("alpha", 0), i - 1) .Split(user_op::OpArg("y", 0), i) .Build(); } } return Maybe::Ok(); } /*static*/ Maybe TfPreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const Shape& alpha_shape = ctx->InputShape("alpha", 0); CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_shape.NumAxes() + 1); FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) { CHECK_OR_RETURN((alpha_shape.At(i - 1) == x_desc.shape().At(i)) || (alpha_shape.At(i - 1) == 1)); } y_desc->set_shape(x_desc.shape()); y_desc->set_is_dynamic(x_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe TfPreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TfPreluOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe TfPreluGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Broadcast(user_op::OpArg("alpha", 0)) .Split(user_op::OpArg("dx", 0), 0) .PartialSum(user_op::OpArg("alpha_diff", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("dy", 0)) .Broadcast(user_op::OpArg("x", 0)) .Broadcast(user_op::OpArg("alpha", 0)) .PartialSum(user_op::OpArg("dx", 0)) .PartialSum(user_op::OpArg("alpha_diff", 0)) .Build(); FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("alpha", 0), i - 1) .Split(user_op::OpArg("dx", 0), i) .Split(user_op::OpArg("alpha_diff", 0), i - 1) .Build(); } } return Maybe::Ok(); } /*static*/ Maybe TfPreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); const user_op::TensorDesc& alpha_desc = ctx->InputTensorDesc("alpha", 0); CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_desc.shape().NumAxes() + 1); FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) { CHECK_OR_RETURN((alpha_desc.shape().At(i - 1) == x_desc.shape().At(i)) || (alpha_desc.shape().At(i - 1) == 1)); } CHECK_EQ_OR_RETURN(dy_desc.shape(), x_desc.shape()); CHECK_EQ_OR_RETURN(dy_desc.data_type(), x_desc.data_type()) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(x_desc.data_type()); dx_desc->set_shape(x_desc.shape()); dx_desc->set_is_dynamic(x_desc.is_dynamic()); ctx->SetOutputShape("alpha_diff", 0, alpha_desc.shape()); ctx->SetOutputIsDynamic("alpha_diff", 0, alpha_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe TfPreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TfPreluGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); ctx->SetOutputDType("alpha_diff", 0, ctx->InputDType("alpha", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/threshold_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe ThresholdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /* static */ Maybe ThresholdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ThresholdOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } return Maybe::Ok(); } /* static */ Maybe ThresholdOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } /* static */ Maybe ThresholdGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(dy_shape == x_shape); ctx->SetOutputShape("dx", 0, dy_shape); return Maybe::Ok(); } /* static */ Maybe ThresholdGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe ThresholdGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { ctx->NewBuilder() .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } return Maybe::Ok(); } /* static */ Maybe ThresholdGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)) << "InferDataType Failed. 
Expected " << DataType_Name(ctx->InputDType("dy", 0)) << ", but got " << DataType_Name(ctx->InputDType("x", 0)); ctx->SetOutputDType("dx", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/throw_error_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { Maybe ThrowErrorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } Maybe ThrowErrorOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } Maybe ThrowErrorOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/to_contiguous_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe ToContiguousOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /*static*/ Maybe ToContiguousOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); ctx->SetOutputStride("out", 0, Stride(in_desc.shape())); return Maybe::Ok(); } /*static*/ Maybe ToContiguousOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ToContiguousOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/top_k_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/*static*/ Maybe<void> TopKOp::GetSbp(user_op::SbpContext* ctx) {
  // The current implementation can only do top_k in the last dimension and should use Broadcast
  // (by default) instead of Split for that dimension
  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {
    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
  }
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> TopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  Shape out_shape = ctx->InputShape("in", 0);
  out_shape.Set(
      out_shape.NumAxes() - 1,
      std::min(ctx->Attr<int32_t>("k"), static_cast<int32_t>(out_shape.dim_vec().back())));
  ctx->SetOutputShape("out", 0, out_shape);
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> TopKOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/*static*/ Maybe<void> TopKOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, DataType::kInt64);
  return Maybe<void>::Ok();
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/transpose_ops.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

void CheckIsPerm(const std::vector<int32_t>& perm) {
  std::vector<bool> is_used(perm.size(), false);
  FOR_RANGE(size_t, i, 0, perm.size()) {
    CHECK_GE(perm[i], 0);
    CHECK_LT(perm[i], perm.size());  // perm[i] must be a valid index into is_used
    CHECK_EQ(is_used[perm[i]], false);
    is_used[perm[i]] = true;
  }
}

/*static*/ Maybe<void> TransposeOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& input_tensor =
      ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0);
  const auto& perm = ctx->Attr<std::vector<int32_t>>("perm");
  CHECK_EQ_OR_RETURN(perm.size(), input_tensor.shape().NumAxes());
  FOR_RANGE(int32_t, i, 0, perm.size()) {
    int32_t axis = perm.at(i);
    if (axis < 0) { axis += perm.size(); }
    CHECK_GE_OR_RETURN(axis, 0);
    CHECK_LT_OR_RETURN(axis, perm.size());
    ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build();
  }
  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> TransposeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("input", 0);
  user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc("output", 0);
  const Shape& in_shape = in_tensor_desc.shape();
  Shape out_shape = in_tensor_desc.shape();
  const auto& perm = ctx->Attr<std::vector<int32_t>>("perm");
  CHECK_EQ_OR_RETURN(perm.size(), in_shape.NumAxes());
  CheckIsPerm(perm);
  // if (perm.at(0) != 0) { CHECK_OR_RETURN(!in_tensor_desc->is_dynamic()); }
  out_tensor_desc->set_is_dynamic(in_tensor_desc.is_dynamic());
  FOR_RANGE(size_t, i, 0, perm.size()) { out_shape.Set(i, in_shape.At(perm[i])); }
  out_tensor_desc->set_shape(out_shape);
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> TransposeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/*static*/ Maybe<void> TransposeOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0));
  return Maybe<void>::Ok();
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/tril_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/*static*/ Maybe<void> TrilOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) {
    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
  }
  bool fill_zero = ctx->Attr<bool>("is_floating_fill_value")
                       ? (ctx->Attr<double>("floating_fill_value") == static_cast<double>(0))
                       : (ctx->Attr<int64_t>("integer_fill_value") == static_cast<int64_t>(0));
  if (fill_zero) {
    ctx->NewBuilder()
        .PartialSum(user_op::OpArg("in", 0))
        .PartialSum(user_op::OpArg("out", 0))
        .Build();
  }
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> TrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0);
  CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2);
  out->set_shape(in.shape());
  out->set_is_dynamic(in.is_dynamic());
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> TrilOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/*static*/ Maybe<void> TrilOp::InferDataType(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0);
  out->set_data_type(in.data_type());
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> FusedScaleTrilOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) {
    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
  }
  bool fill_zero = ctx->Attr<bool>("is_floating_fill_value")
                       ? (ctx->Attr<double>("floating_fill_value") == static_cast<double>(0))
                       : (ctx->Attr<int64_t>("integer_fill_value") == static_cast<int64_t>(0));
  if (fill_zero) {
    ctx->NewBuilder()
        .PartialSum(user_op::OpArg("in", 0))
        .PartialSum(user_op::OpArg("out", 0))
        .Build();
  }
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> FusedScaleTrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0);
  CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2);
  out->set_shape(in.shape());
  out->set_is_dynamic(in.is_dynamic());
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> FusedScaleTrilOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/*static*/ Maybe<void> FusedScaleTrilOp::InferDataType(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0);
  out->set_data_type(in.data_type());
  return Maybe<void>::Ok();
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/triu_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe TriuOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe TriuOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); out->set_shape(in.shape()); out->set_is_dynamic(in.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe TriuOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TriuOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_data_type(in.data_type()); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/trunc_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe TruncOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /*static*/ Maybe TruncOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe TruncOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TruncOp::InferDataType(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); } /* static */ Maybe TruncGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::Unchanged(ctx); } /*static*/ Maybe TruncGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe TruncGradOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); } /* static */ Maybe TruncGradOp::InferDataType(user_op::InferContext* ctx) { return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/tuple_identity_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe TupleIdentityOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe TupleIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const int64_t in_size = ctx->input_size("in"); CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); for (int64_t i = 0; i < in_size; ++i) { ctx->SetOutputShape("out", i, ctx->InputShape("in", i)); ctx->SetIsDynamic4ArgNameAndIndex("out", i, ctx->InputIsDynamic("in", i)); } return Maybe::Ok(); } /*static*/ Maybe TupleIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TupleIdentityOp::InferDataType(user_op::InferContext* ctx) { const int64_t in_size = ctx->input_size("in"); CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); for (int64_t i = 0; i < in_size; ++i) { ctx->SetOutputDType("out", i, ctx->InputDType("in", i)); } return Maybe::Ok(); } /*static*/ Maybe TupleIdentityOp::InferSbpSignature( user_op::InferSbpSignatureFnContext* ctx) { SbpSignature* signature = ctx->mutable_sbp_signature(); const SbpSignature& sbp_signature_conf = ctx->sbp_signature_conf(); auto* bn2sbp = signature->mutable_bn_in_op2sbp_parallel(); const auto& bn2conf_sbp = sbp_signature_conf.bn_in_op2sbp_parallel(); const int64_t in_size = ctx->user_op_conf().input_size("in"); CHECK_EQ_OR_RETURN(ctx->user_op_conf().output_size("out"), in_size); for (int64_t i = 0; i < in_size; ++i) { const SbpParallel* sbp_parallel = nullptr; const std::string ibn = GenRepeatedBn("in", i); const std::string& obn = GenRepeatedBn("out", i); const auto& conf_sbp_it = bn2conf_sbp.find(obn); if (conf_sbp_it == bn2conf_sbp.end()) { sbp_parallel = &ctx->SbpParallelHint4InputArgNameAndIndex("in", i); } else { sbp_parallel = &conf_sbp_it->second; } (*bn2sbp)[ibn] = *sbp_parallel; (*bn2sbp)[obn] = *sbp_parallel; } return Maybe::Ok(); } /*static*/ Maybe TupleIdentityOp::CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { CHECK_OR_RETURN(op_conf.input_size("in") >= 1); CHECK_OR_RETURN(op_conf.output_size("out") >= 1); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/two_stage_reduce_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferReduceDeviceStageDtypeFn(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); ctx->SetOutputDType("mask", 0, DataType::kBool); ctx->SetOutputDType("count", 0, DataType::kInt32); return Maybe::Ok(); } Maybe InferReduceDeviceStageLogicalTensorDescFn(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("in", 0); const auto& axis = ctx->Attr>("axis"); const int64_t num_axes = input_shape.NumAxes(); Shape output_shape; if (axis.empty()) { output_shape = Shape::Ones(num_axes); } else { const ParallelDesc& parallel_desc = ctx->parallel_desc(); const NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); DimVector dim_vec = input_shape.dim_vec(); if (parallel_desc.hierarchy()->NumAxes() == 1) { const auto& input_sbp = in_nd_sbp.sbp_parallel(0); for (auto i : axis) { const int64_t regular_axis = ShiftNegativeAxis(i, num_axes); dim_vec.at(regular_axis) = (input_sbp.has_split_parallel() && input_sbp.split_parallel().axis() == regular_axis) ? parallel_desc.parallel_num() : 1; } } else { CHECK_EQ_OR_RETURN(axis.size(), 1); const int64_t regular_axis = ShiftNegativeAxis(axis.at(0), num_axes); dim_vec.at(regular_axis) = 1; for (int64_t i = 0; i < parallel_desc.hierarchy()->NumAxes(); ++i) { const auto& input_sbp = in_nd_sbp.sbp_parallel(i); if (input_sbp.has_split_parallel() && input_sbp.split_parallel().axis() == regular_axis) { dim_vec.at(regular_axis) *= parallel_desc.hierarchy()->At(i); } } } output_shape = Shape(dim_vec); } ctx->SetOutputShape("out", 0, output_shape); ctx->SetOutputShape("mask", 0, input_shape); ctx->SetOutputShape("count", 0, output_shape); return Maybe::Ok(); } Maybe InferReduceDeviceStagePhysicalTensorDescFn(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("in", 0); const auto& axis = ctx->Attr>("axis"); Shape output_shape; if (axis.empty()) { output_shape = Shape::Ones(input_shape.NumAxes()); } else { const AxisVector axis_vec = {axis.begin(), axis.end()}; const Shape& reduced_shape = CreateReducedShape(input_shape, axis_vec); output_shape = reduced_shape; } ctx->SetOutputShape("out", 0, output_shape); ctx->SetOutputShape("mask", 0, input_shape); ctx->SetOutputShape("count", 0, output_shape); ; return Maybe::Ok(); } Maybe InferReduceDeviceStageGradDtypeFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kBool) << "InferDataType Failed. Expected " << DataType_Name(DataType::kBool) << ", but got " << DataType_Name(ctx->InputDType("mask", 0)); CHECK_EQ_OR_RETURN(ctx->InputDType("count", 0), DataType::kInt32) << "InferDataType Failed. Expected " << DataType_Name(DataType::kInt32) << ", but got " << DataType_Name(ctx->InputDType("count", 0)); ctx->SetOutputDType("in_diff", 0, ctx->InputDType("out_diff", 0)); return Maybe::Ok(); } Maybe InferReduceDeviceStageGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputShape("out_diff", 0), ctx->InputShape("count", 0)); ctx->SetOutputShape("in_diff", 0, ctx->InputShape("mask", 0)); return Maybe::Ok(); } Maybe InferReduceGlobalStageDtypeFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("device_count", 0), DataType::kInt32) << "InferDataType Failed. 
Expected " << DataType_Name(DataType::kInt32) << ", but got " << DataType_Name(ctx->InputDType("device_count", 0)); ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); ctx->SetOutputDType("mask", 0, DataType::kBool); return Maybe::Ok(); } Maybe InferReduceGlobalStageTensorDescFn(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("in", 0); const Shape& device_count_shape = ctx->InputShape("device_count", 0); CHECK_EQ_OR_RETURN(input_shape, device_count_shape); const auto& axis = ctx->Attr>("axis"); bool keepdims = ctx->Attr("keepdims"); Shape output_shape; if (axis.empty()) { if (keepdims) { output_shape = Shape::Ones(input_shape.NumAxes()); } else { output_shape = Shape({1}); } } else { const AxisVector axis_vec = {axis.begin(), axis.end()}; const Shape& reduced_shape = CreateReducedShape(input_shape, axis_vec); if (keepdims) { output_shape = reduced_shape; } else { output_shape = reduced_shape.RemoveOnes(axis_vec); } } ctx->SetOutputShape("out", 0, output_shape); ctx->SetOutputShape("mask", 0, input_shape); return Maybe::Ok(); } Maybe InferReduceGlobalStageGradDtypeFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kBool) << "InferDataType Failed. Expected " << DataType_Name(DataType::kBool) << ", but got " << DataType_Name(ctx->InputDType("mask", 0)); CHECK_EQ_OR_RETURN(ctx->InputDType("device_count", 0), DataType::kInt32) << "InferDataType Failed. Expected " << DataType_Name(DataType::kInt32) << ", but got " << DataType_Name(ctx->InputDType("device_count", 0)); ctx->SetOutputDType("in_diff", 0, ctx->InputDType("out_diff", 0)); return Maybe::Ok(); } Maybe InferReduceGlobalStageGradTensorDescFn(user_op::InferContext* ctx) { const Shape& mask_shape = ctx->InputShape("mask", 0); const Shape& device_count_shape = ctx->InputShape("device_count", 0); CHECK_EQ_OR_RETURN(device_count_shape, mask_shape); ctx->SetOutputShape("in_diff", 0, mask_shape); return Maybe::Ok(); } Maybe GetReduceDeviceStageSbpFn(user_op::SbpContext* ctx) { int32_t num_axes = 0; HashSet conf_axes; { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); num_axes = in_tensor.shape().NumAxes(); const auto& reduced_axes = ctx->Attr>("axis"); conf_axes = {reduced_axes.begin(), reduced_axes.end()}; } auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes); FOR_RANGE(int64_t, i, 0, num_axes) { ctx->NewBuilder() .Split(user_op::OpArg("in", 0), i) .Split(user_op::OpArg("out", 0), i) .Split(user_op::OpArg("mask", 0), i) .Split(user_op::OpArg("count", 0), i) .Build(); } return Maybe::Ok(); } Maybe GetReduceDeviceStageGradSbpFn(user_op::SbpContext* ctx) { int32_t num_axes = 0; HashSet conf_axes; { const auto& output_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("out_diff", 0); num_axes = output_tensor.shape().NumAxes(); const auto& reduced_axes = ctx->Attr>("axis"); conf_axes = {reduced_axes.begin(), reduced_axes.end()}; } auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes); FOR_RANGE(int64_t, i, 0, num_axes) { if (IsReducedAxis(i) || i == 0) { ctx->NewBuilder() .Split(user_op::OpArg("out_diff", 0), i) .Split(user_op::OpArg("count", 0), i) .Split(user_op::OpArg("mask", 0), i) .Split(user_op::OpArg("in_diff", 0), i) .Build(); } } return Maybe::Ok(); } } // namespace #define IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(op_name) \ /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { \ return GetReduceDeviceStageSbpFn(ctx); \ } \ /*static*/ Maybe 
op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferReduceDeviceStageLogicalTensorDescFn(ctx); \ } \ /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferReduceDeviceStagePhysicalTensorDescFn(ctx); \ } \ /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ return InferReduceDeviceStageDtypeFn(ctx); \ } IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(ReduceMinDeviceStage) IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(ReduceMaxDeviceStage) #undef IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS #define IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(op_name) \ /*static*/ Maybe op_name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ return GetReduceDeviceStageGradSbpFn(ctx); \ } \ /*static*/ Maybe op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferReduceDeviceStageGradTensorDescFn(ctx); \ } \ /*static*/ Maybe op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe op_name##GradOp::InferDataType(user_op::InferContext* ctx) { \ return InferReduceDeviceStageGradDtypeFn(ctx); \ } IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(ReduceMinDeviceStage) IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(ReduceMaxDeviceStage) #undef IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS #define IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(op_name) \ /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { \ ctx->NewBuilder() \ .Split(user_op::OpArg("in", 0), 0) \ .Split(user_op::OpArg("device_count", 0), 0) \ .Split(user_op::OpArg("out", 0), 0) \ .Split(user_op::OpArg("mask", 0), 0) \ .Build(); \ return Maybe::Ok(); \ } \ /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferReduceGlobalStageTensorDescFn(ctx); \ } \ /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ return InferReduceGlobalStageDtypeFn(ctx); \ } \ /*static*/ Maybe op_name##Op::ModifyInputArg( \ const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { \ user_op::InputArgModifier* device_count_modifier = GetInputArgModifierFn("device_count", 0); \ device_count_modifier->set_requires_grad(false); \ return Maybe::Ok(); \ } IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(ReduceMinGlobalStage) IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(ReduceMaxGlobalStage) #undef IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS #define IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(op_name) \ /*static*/ Maybe op_name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ ctx->NewBuilder() \ .Split(user_op::OpArg("out_diff", 0), 0) \ .Split(user_op::OpArg("mask", 0), 0) \ .Split(user_op::OpArg("device_count", 0), 0) \ .Split(user_op::OpArg("in_diff", 0), 0) \ .Build(); \ return Maybe::Ok(); \ } \ /*static*/ Maybe op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ return InferReduceGlobalStageGradTensorDescFn(ctx); \ } \ /*static*/ Maybe op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe op_name##GradOp::InferDataType(user_op::InferContext* ctx) { \ return InferReduceGlobalStageGradDtypeFn(ctx); \ } IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(ReduceMinGlobalStage) IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(ReduceMaxGlobalStage) #undef IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS } // 
namespace oneflow

================================================
FILE: oneflow/user/ops/unfold_fold_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/ops/nn_util.h"
#include "oneflow/core/operator/operator_util.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

namespace {

Maybe<void> UnfoldTensorDescInferFn(user_op::InferContext* ctx) {
  const Shape& x_shape = ctx->InputShape("x", 0);
  const int32_t spatial_ndim = x_shape.NumAxes() - 2;
  std::string data_format = ctx->Attr<std::string>("data_format");
  std::vector<int32_t> padding = ctx->Attr<std::vector<int32_t>>("padding");
  const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
  const std::vector<int32_t>& strides = ctx->Attr<std::vector<int32_t>>("strides");
  const std::vector<int32_t>& dilation_rate = ctx->Attr<std::vector<int32_t>>("dilation_rate");
  const int32_t idx_offset = IdxOffset(data_format);
  const size_t c_dim = data_format == "channels_first" ? 1 : spatial_ndim + 1;

  CHECK_EQ_OR_RETURN(spatial_ndim, 2);  // only support 4-D tensor now.
  CHECK_EQ_OR_RETURN(padding.size(), spatial_ndim);
  for (int32_t pad : padding) { CHECK_GE_OR_RETURN(pad, 0); }
  CHECK_EQ_OR_RETURN(kernel_size.size(), spatial_ndim);
  for (int32_t kernel : kernel_size) { CHECK_GT_OR_RETURN(kernel, 0); }
  CHECK_EQ_OR_RETURN(strides.size(), spatial_ndim);
  for (int32_t stride : strides) { CHECK_GT_OR_RETURN(stride, 0); }
  CHECK_EQ_OR_RETURN(dilation_rate.size(), spatial_ndim);
  for (int32_t dilation : dilation_rate) { CHECK_GE_OR_RETURN(dilation, 1); }

  std::vector<int64_t> dhw_shape(spatial_ndim);
  for (int32_t i = 0; i < spatial_ndim; ++i) {
    dhw_shape[i] = (x_shape.At(idx_offset + i) + 2 * padding[i]
                    - dilation_rate[i] * (kernel_size[i] - 1) - 1)
                       / strides[i]
                   + 1;
  }
  DimVector y_shape(3);
  y_shape.at(0) = x_shape.At(0);
  y_shape.at(1) = x_shape.At(c_dim)
                  * std::accumulate(kernel_size.begin(), kernel_size.end(), 1, std::multiplies<>());
  y_shape.at(2) = std::accumulate(dhw_shape.begin(), dhw_shape.end(), 1, std::multiplies<>());
  ctx->SetOutputShape("y", 0, Shape(y_shape));
  return Maybe<void>::Ok();
}

Maybe<void> SetUnfoldDTypeFn(user_op::InferContext* ctx) {
  ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0));
  return Maybe<void>::Ok();
}

Maybe<void> GetUnfoldSbpFn(user_op::SbpContext* ctx) {
  ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
  ctx->NewBuilder().Split(user_op::OpArg("x", 0), 1).Split(user_op::OpArg("y", 0), 1).Build();
  return Maybe<void>::Ok();
}

Maybe<void> FoldTensorDescInferFn(user_op::InferContext* ctx) {
  const Shape& x_shape = ctx->InputShape("x", 0);
  const int32_t spatial_ndim = x_shape.NumAxes() - 1;  // (n, c*K*K, h*w)
  std::string data_format = ctx->Attr<std::string>("data_format");
  std::vector<int32_t> output_size = ctx->Attr<std::vector<int32_t>>("output_size");
  std::vector<int32_t> padding = ctx->Attr<std::vector<int32_t>>("padding");
  const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
  const std::vector<int32_t>& strides = ctx->Attr<std::vector<int32_t>>("strides");
  const std::vector<int32_t>& dilation_rate = ctx->Attr<std::vector<int32_t>>("dilation_rate");
  const size_t c_dim = data_format == "channels_first" ? 1 : spatial_ndim;
  const size_t length_dim = data_format == "channels_first" ? spatial_ndim : 1;
  const int32_t input_planes = x_shape.At(c_dim);
  const int32_t input_length = x_shape.At(length_dim);

  CHECK_EQ_OR_RETURN(spatial_ndim, 2);  // only support 4-D tensor now.
  CHECK_EQ_OR_RETURN(output_size.size(), spatial_ndim);
  CHECK_EQ_OR_RETURN(padding.size(), spatial_ndim);
  for (int32_t pad : padding) { CHECK_GE_OR_RETURN(pad, 0); }
  CHECK_EQ_OR_RETURN(kernel_size.size(), spatial_ndim);
  for (int32_t kernel : kernel_size) { CHECK_GT_OR_RETURN(kernel, 0); }
  CHECK_EQ_OR_RETURN(strides.size(), spatial_ndim);
  for (int32_t stride : strides) { CHECK_GT_OR_RETURN(stride, 0); }
  CHECK_EQ_OR_RETURN(dilation_rate.size(), spatial_ndim);
  for (int32_t dilation : dilation_rate) { CHECK_GE_OR_RETURN(dilation, 1); }
  CHECK_EQ_OR_RETURN(input_planes % (kernel_size[0] * kernel_size[1]),
                     0);  // C*K*K should be divisible by K*K

  const int32_t output_height =
      (output_size[0] + 2 * padding[0] - dilation_rate[0] * (kernel_size[0] - 1) - 1) / strides[0]
      + 1;
  const int32_t output_width =
      (output_size[1] + 2 * padding[1] - dilation_rate[1] * (kernel_size[1] - 1) - 1) / strides[1]
      + 1;
  CHECK_EQ_OR_RETURN(output_height * output_width, input_length);  // input_length == OH*OW

  DimVector y_shape(4);
  y_shape.at(0) = x_shape.At(0);
  y_shape.at(1) = input_planes / (kernel_size[0] * kernel_size[1]);
  y_shape.at(2) = output_size[0];
  y_shape.at(3) = output_size[1];
  ctx->SetOutputShape("y", 0, Shape(y_shape));
  return Maybe<void>::Ok();
}

Maybe<void> FoldDTypeFn(user_op::InferContext* ctx) {
  ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0));
  return Maybe<void>::Ok();
}

Maybe<void> GetFoldSbpFn(user_op::SbpContext* ctx) {
  ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
  return Maybe<void>::Ok();
}

}  // namespace

/*static*/ Maybe<void> UnfoldOp::GetSbp(user_op::SbpContext* ctx) { return GetUnfoldSbpFn(ctx); }
/*static*/ Maybe<void> UnfoldOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  return UnfoldTensorDescInferFn(ctx);
}
/*static*/ Maybe<void> UnfoldOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/*static*/ Maybe<void> UnfoldOp::InferDataType(user_op::InferContext* ctx) {
  return SetUnfoldDTypeFn(ctx);
}

/*static*/ Maybe<void> FoldOp::GetSbp(user_op::SbpContext* ctx) { return GetFoldSbpFn(ctx); }
/*static*/ Maybe<void> FoldOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  return FoldTensorDescInferFn(ctx);
}
/*static*/ Maybe<void> FoldOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/*static*/ Maybe<void> FoldOp::InferDataType(user_op::InferContext* ctx) {
  return FoldDTypeFn(ctx);
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/unfold_tensor_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/unfold_tensor_kernel_utils.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe UnfoldTensorOp::GetSbp(user_op::SbpContext* ctx) { const int32_t dimension = ctx->Attr("dimension"); const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { if (i != dimension) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); } } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /*static*/ Maybe UnfoldTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); const int32_t dimension = ctx->Attr("dimension"); const int32_t size = ctx->Attr("size"); const int32_t step = ctx->Attr("step"); const Shape& in_shape = ctx->InputShape("x", 0); const int32_t in_dim = in_shape.NumAxes(); CHECK_GE_OR_RETURN(dimension, 0); // NOTE(lixiang): remove -1 for 0-dim tensor CHECK_LE_OR_RETURN(dimension, in_dim); const int32_t max_size = in_dim == 0 ? 1 : in_shape.At(dimension); CHECK_GT_OR_RETURN(size, 0); CHECK_LE_OR_RETURN(size, max_size); CHECK_GT_OR_RETURN(step, 0); DimVector out_shape(in_dim + 1); out_shape[in_dim] = size; FOR_RANGE(int32_t, d, 0, in_dim) { int32_t in_size_at_d = in.shape().At(d); if (d == dimension) { out_shape.at(d) = (in_size_at_d - size) / step + 1; } else { out_shape.at(d) = in_size_at_d; } } ctx->SetOutputShape("y", 0, Shape(out_shape)); return Maybe::Ok(); } /*static*/ Maybe UnfoldTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnfoldTensorOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UnfoldTensorGradOp::GetSbp(user_op::SbpContext* ctx) { const int32_t dimension = ctx->Attr("dimension"); const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { if (i != dimension) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), i) .Split(user_op::OpArg("x", 0), i) .Split(user_op::OpArg("dx", 0), i) .Build(); } } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /*static*/ Maybe UnfoldTensorGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); const Shape& in_shape = in.shape(); user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); dx_desc->set_shape(Shape(in_shape.dim_vec())); return Maybe::Ok(); } /*static*/ Maybe UnfoldTensorGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnfoldTensorGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/unique_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe UniqueOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe UniqueOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_shape(x.shape()); y->set_is_dynamic(x.is_dynamic()); user_op::TensorDesc* idx = ctx->MutOutputTensorDesc("idx", 0); idx->set_shape(x.shape()); idx->set_is_dynamic(x.is_dynamic()); user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); num_unique->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe UniqueOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UniqueOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); auto out_idx = ctx->Attr("out_idx"); CHECK_OR_RETURN(IsIndexDataType(out_idx)); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_data_type(x.data_type()); user_op::TensorDesc* idx = ctx->MutOutputTensorDesc("idx", 0); idx->set_data_type(out_idx); user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); num_unique->set_data_type(out_idx); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/unique_with_counts_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe UniqueWithCountsOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe UniqueWithCountsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_shape(x.shape()); y->set_is_dynamic(x.is_dynamic()); user_op::TensorDesc* idx = ctx->MutOutputTensorDesc("idx", 0); idx->set_shape(x.shape()); idx->set_is_dynamic(x.is_dynamic()); user_op::TensorDesc* count = ctx->MutOutputTensorDesc("count", 0); count->set_shape(x.shape()); count->set_is_dynamic(x.is_dynamic()); user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); num_unique->set_shape(Shape({1})); return Maybe::Ok(); } /*static*/ Maybe UniqueWithCountsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UniqueWithCountsOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); auto out_idx = ctx->Attr("out_idx"); CHECK_OR_RETURN(IsIndexDataType(out_idx)); user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); y->set_data_type(x.data_type()); user_op::TensorDesc* idx = ctx->MutOutputTensorDesc("idx", 0); idx->set_data_type(out_idx); user_op::TensorDesc* count = ctx->MutOutputTensorDesc("count", 0); count->set_data_type(out_idx); user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); num_unique->set_data_type(out_idx); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/unpack_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe UnpackOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); } ctx->NewBuilder() .PartialSum(user_op::OpArg("in", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe UnpackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); const Shape& in_shape = in_desc.shape(); CHECK_GT_OR_RETURN(in_shape.NumAxes(), 0); const auto unpack_num = ctx->Attr("unpack_num"); CHECK_EQ_OR_RETURN(in_shape.At(0) % unpack_num, 0); user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); Shape out_shape = in_desc.shape(); out_shape.Set(0, in_shape.At(0) / unpack_num); out_desc->set_shape(out_shape); out_desc->set_is_dynamic(in_desc.is_dynamic()); return Maybe::Ok(); } /*static*/ Maybe UnpackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnpackOp::InferDataType(user_op::InferContext* ctx) { user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); out_desc->set_data_type(in_desc.data_type()); return Maybe::Ok(); } /*static*/ Maybe UnpackOp::InferOutputBlobTimeShape( user_op::InferOutputBlobTimeShapeFnContext* ctx) { const int32_t unpack_num = ctx->user_op_conf().attr("unpack_num"); DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec(); time_shape_dim_vec.emplace_back(unpack_num); *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/unsorted_batch_segment_sum_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe UnsortedBatchSegmentSumOp::GetSbp(user_op::SbpContext* ctx) { const int64_t segment_ids_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); if (segment_ids_num_axes > 1) { FOR_RANGE(int64_t, i, 0, segment_ids_num_axes - 1) { ctx->NewBuilder() .Split(user_op::OpArg("segment_ids", 0), i) .Split(user_op::OpArg("data", 0), i) .Split(user_op::OpArg("out", 0), i) .Build(); } } ctx->NewBuilder() .Broadcast(user_op::OpArg("segment_ids", 0)) .PartialSum(user_op::OpArg("data", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe UnsortedBatchSegmentSumOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); CHECK_GE_OR_RETURN(segment_ids.shape().NumAxes(), 1); CHECK_GE_OR_RETURN(data.shape().NumAxes(), segment_ids.shape().NumAxes()); CHECK_EQ_OR_RETURN(segment_ids.is_dynamic(), data.is_dynamic()); const int64_t num_segments = ctx->Attr("num_segments"); CHECK_GE_OR_RETURN(num_segments, 1); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); FOR_RANGE(int64_t, i, 0, segment_ids.shape().NumAxes() - 1) { CHECK_EQ_OR_RETURN(segment_ids.shape().At(i), data.shape().At(i)); } DimVector dim_vec(data.shape().dim_vec()); dim_vec.at(segment_ids.shape().NumAxes() - 1) = num_segments; out->set_shape(Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe UnsortedBatchSegmentSumOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnsortedBatchSegmentSumOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(IsIndexDataType(segment_ids.data_type())); out->set_data_type(data.data_type()); return Maybe::Ok(); } /*static*/ Maybe UnsortedBatchSegmentSumOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/unsorted_segment_sum_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /*static*/ Maybe UnsortedSegmentSumOp::GetSbp(user_op::SbpContext* ctx) { const int64_t data_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); const int64_t segment_ids_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); const int64_t axis = ctx->Attr("axis"); FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { ctx->NewBuilder() .Split(user_op::OpArg("segment_ids", 0), i) .Split(user_op::OpArg("data", 0), i + axis) .PartialSum(user_op::OpArg("out", 0)) .Build(); } FOR_RANGE(int64_t, i, 0, data_num_axes) { if (i >= axis && i < axis + segment_ids_num_axes) { continue; } const int64_t out_split_axis = (i < axis) ? i : i - segment_ids_num_axes + 1; if (out_split_axis == axis) { continue; } ctx->NewBuilder() .Broadcast(user_op::OpArg("segment_ids", 0)) .Split(user_op::OpArg("data", 0), i) .Split(user_op::OpArg("out", 0), out_split_axis) .Build(); } ctx->NewBuilder() .Broadcast(user_op::OpArg("segment_ids", 0)) .PartialSum(user_op::OpArg("data", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& data_shape = ctx->InputShape("data", 0); const int64_t axis = ctx->Attr("axis"); const int64_t num_segments = ctx->Attr("num_segments"); const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); DimVector dim_vec; dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin(), data_shape.dim_vec().cbegin() + axis); dim_vec.emplace_back(num_segments); dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin() + axis + segment_ids_shape.NumAxes(), data_shape.dim_vec().end()); ctx->SetOutputShape("out", 0, Shape(dim_vec)); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnsortedSegmentSumOp::InferDataType(user_op::InferContext* ctx) { CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); ctx->SetOutputDType("out", 0, ctx->InputDType("data", 0)); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumLikeOp::GetSbp(user_op::SbpContext* ctx) { const int64_t data_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); const int64_t segment_ids_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); const int64_t axis = ctx->Attr("axis"); FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { ctx->NewBuilder() .Split(user_op::OpArg("segment_ids", 0), i) .Split(user_op::OpArg("data", 0), i + axis) .Broadcast(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("segment_ids", 0), i) .Split(user_op::OpArg("data", 0), i + axis) .PartialSum(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); } FOR_RANGE(int64_t, i, 0, data_num_axes) { if (i >= axis && i < axis + segment_ids_num_axes) { continue; } const int64_t out_split_axis = (i < axis) ? 
i : i - segment_ids_num_axes + 1; if (out_split_axis == axis) { continue; } ctx->NewBuilder() .Broadcast(user_op::OpArg("segment_ids", 0)) .Split(user_op::OpArg("data", 0), i) .Split(user_op::OpArg("like", 0), out_split_axis) .Split(user_op::OpArg("out", 0), out_split_axis) .Build(); } ctx->NewBuilder() .Broadcast(user_op::OpArg("segment_ids", 0)) .PartialSum(user_op::OpArg("data", 0)) .Broadcast(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("segment_ids", 0)) .PartialSum(user_op::OpArg("data", 0)) .PartialSum(user_op::OpArg("like", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("segment_ids", 0)) .Broadcast(user_op::OpArg("data", 0)) .Split(user_op::OpArg("like", 0), axis) .Split(user_op::OpArg("out", 0), axis) .Build(); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumLikeOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& data_shape = ctx->InputShape("data", 0); const Shape& like_shape = ctx->InputShape("like", 0); const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); const int64_t axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0); CHECK_LE_OR_RETURN(axis, like_shape.NumAxes()); FOR_RANGE(int64_t, i, 0, axis) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i)); } CHECK_EQ_OR_RETURN(data_shape.NumAxes() - segment_ids_shape.NumAxes() + 1, like_shape.NumAxes()); FOR_RANGE(int64_t, i, axis + 1, like_shape.NumAxes()) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1)); } ctx->SetOutputShape("out", 0, ctx->InputShape("like", 0)); ctx->SetIsDynamic4ArgNameAndIndex("out", 0, ctx->InputIsDynamic("like", 0)); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumLikeOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnsortedSegmentSumLikeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); const user_op::TensorDesc& like = ctx->InputTensorDesc("like", 0); CHECK_EQ_OR_RETURN(data.data_type(), like.data_type()) << "InferDataType Failed. Expected " << DataType_Name(like.data_type()) << ", but got " << DataType_Name(data.data_type()); CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); ctx->SetOutputDType("out", 0, ctx->InputDType("data", 0)); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumLikeOp::ModifyInputArg( const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/upsample_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace { using namespace oneflow; template typename std::enable_if<(N <= 3), Maybe>::type UpsamplingInferLogicalDesc( user_op::InferContext* ctx, const std::string& func_name) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); if (N == 1) { CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == (N + 2)) << func_name << " only supports NCH"; int64_t input_width = x_desc.shape().At(2); int64_t output_width = 0; const double scale_factor = ctx->Attr("scale_factor"); std::vector output_size = ctx->Attr>("output_size"); if (output_size.size()) { output_width = output_size[0]; } else { output_width = static_cast(scale_factor * input_width); } CHECK_OR_RETURN(input_width > 0 && output_width > 0) << func_name << ": Input and output sizes should be greater than 0, but got input (W: " << input_width << ") output (W: " << output_width << ")"; y_desc->set_shape(Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_width})); } else if (N == 2) { CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == (N + 2)) << func_name << " only supports NCHW"; const double height_scale = ctx->Attr("height_scale"); const double width_scale = ctx->Attr("width_scale"); std::vector output_size = ctx->Attr>("output_size"); int64_t input_height = x_desc.shape().At(2); int64_t input_width = x_desc.shape().At(3); int64_t output_height = 0; int64_t output_width = 0; if (output_size.size()) { output_height = output_size[0]; output_width = output_size[1]; } else { output_height = static_cast(height_scale * input_height); output_width = static_cast(width_scale * input_width); } CHECK_OR_RETURN(input_height > 0 && input_width > 0 && output_height > 0 && output_width > 0) << func_name << ": Input and output sizes should be greater than 0, but got input (H: " << input_height << ", W: " << input_width << ") output (H: " << output_height << ", W: " << output_width << ")"; y_desc->set_shape( Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_height, output_width})); } else if (N == 3) { CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 5) << func_name << " only supports NCDHW"; const double depth_scale = ctx->Attr("depth_scale"); const double height_scale = ctx->Attr("height_scale"); const double width_scale = ctx->Attr("width_scale"); std::vector output_size = ctx->Attr>("output_size"); int64_t input_depth = x_desc.shape().At(2); int64_t input_height = x_desc.shape().At(3); int64_t input_width = x_desc.shape().At(4); int64_t output_depth = 0; int64_t output_height = 0; int64_t output_width = 0; if (output_size.size()) { output_depth = output_size[0]; output_height = output_size[1]; output_width = output_size[2]; } else { output_depth = static_cast(depth_scale * input_depth); output_height = static_cast(height_scale * input_height); output_width = static_cast(width_scale * input_width); } CHECK_OR_RETURN(input_depth > 0 && input_height > 0 && 
input_width > 0 && output_depth > 0 && output_height > 0 && output_width > 0) << func_name << ": Input and output sizes should be greater than 0, but got input (D: " << input_depth << ", H: " << input_height << ", W: " << input_width << ") output (D: " << output_depth << "H: " << output_height << ", W: " << output_width << ")"; y_desc->set_shape(Shape( {x_desc.shape().At(0), x_desc.shape().At(1), output_depth, output_height, output_width})); } return Maybe::Ok(); } } // namespace namespace oneflow { /*static*/ Maybe UpsampleLinear1DOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleLinear1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return UpsamplingInferLogicalDesc<1>(ctx, "upsample_linear_1d"); } /*static*/ Maybe UpsampleLinear1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleLinear1DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest1DOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return UpsamplingInferLogicalDesc<1>(ctx, "upsample_nearest_1d"); } /*static*/ Maybe UpsampleNearest1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest1DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest2DOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return UpsamplingInferLogicalDesc<2>(ctx, "upsample_nearest_2d"); } /*static*/ Maybe UpsampleNearest2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest2DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleBilinear2DOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleBilinear2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return UpsamplingInferLogicalDesc<2>(ctx, "upsample_bilinear_2d"); } /*static*/ Maybe UpsampleBilinear2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleBilinear2DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleBicubic2DOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleBicubic2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return UpsamplingInferLogicalDesc<2>(ctx, "upsample_bicubic_2d"); } /*static*/ Maybe UpsampleBicubic2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe 
UpsampleBicubic2DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest3DOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return UpsamplingInferLogicalDesc<3>(ctx, "upsample_nearest_3d"); } /*static*/ Maybe UpsampleNearest3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest3DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleTrilinear3DOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleTrilinear3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return UpsamplingInferLogicalDesc<3>(ctx, "upsample_trilinear_3d"); } /*static*/ Maybe UpsampleTrilinear3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleTrilinear3DOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("y", 0, ctx->InputDType("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleLinear1DGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleLinear1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 3) << "upsample_linear_1d_grad only supports NCH"; ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleLinear1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleLinear1DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest1DGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 3) << "upsample_nearest_1d_grad only supports NCH"; ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest1DGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest1DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest2DGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = 
ctx->InputShape("dy", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_nearest_2d_grad only supports NCHW"; ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest2DGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest2DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleBilinear2DGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleBilinear2DGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_bilinear_2d_grad only supports NCHW"; ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleBilinear2DGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleBilinear2DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleBicubic2DGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleBicubic2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_bicubic_2d_grad only supports NCHW"; ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleBicubic2DGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleBicubic2DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest3DGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 5) << "upsample_nearest_3d_grad only supports NCDHW"; ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest3DGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest3DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleTrilinear3DGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("x", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleTrilinear3DGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { 
const Shape& dy_shape = ctx->InputShape("dy", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 5) << "upsample_trilinear_3d_grad only supports NCDHW"; ctx->SetOutputShape("dx", 0, ctx->InputShape("x", 0)); return Maybe::Ok(); } /*static*/ Maybe UpsampleTrilinear3DGradOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleTrilinear3DGradOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0)); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/util_ops.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { /* static */ Maybe IsNanOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe IsNanOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IsNanOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe IsNanOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kBool); return Maybe::Ok(); } /* static */ Maybe IsInfOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe IsInfOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IsInfOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } /* static */ Maybe IsInfOp::InferDataType(user_op::InferContext* ctx) { ctx->SetOutputDType("out", 0, DataType::kBool); return Maybe::Ok(); } /* static */ Maybe IsFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0)); return Maybe::Ok(); } /*static*/ Maybe IsFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe IsFiniteOp::GetSbp(user_op::SbpContext* ctx) { 
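  // SBP signatures for the elementwise finiteness check: broadcast in/out,
  // split in/out on any matching axis, or partial-sum in/out, mirroring the
  // IsNanOp and IsInfOp registrations above.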
  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) {
    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
  }
  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> IsFiniteOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, DataType::kBool);
  return Maybe<void>::Ok();
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/variance_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/framework/framework.h"
#include "oneflow/core/operator/reduce_sbp_util.h"
#include "oneflow/core/ndarray/binary_func.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

Maybe<void> VarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& input_shape = ctx->InputShape("input", 0);
  const auto& reduce_axes = ctx->Attr<std::vector<int32_t>>("dim");
  CHECK_OR_RETURN(!reduce_axes.empty());
  const AxisVector reduce_axes_vec = {reduce_axes.begin(), reduce_axes.end()};
  const Shape& reduce_shape = CreateReducedShape(input_shape, reduce_axes_vec);
  const bool keepdim = ctx->Attr<bool>("keepdim");
  // With keepdim the reduced axes stay as size-1 dimensions; otherwise they are
  // removed from the output shape.
  Shape output_shape;
  if (keepdim) {
    output_shape = reduce_shape;
  } else {
    output_shape = reduce_shape.RemoveOnes(reduce_axes_vec);
  }
  ctx->SetOutputShape("output", 0, output_shape);
  return Maybe<void>::Ok();
}

Maybe<void> VarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

Maybe<void> VarOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("output", 0, ctx->InputDType("input", 0));
  return Maybe<void>::Ok();
}

Maybe<void> VarOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
  const Shape& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape();
  const int64_t ndim = input_shape.NumAxes();
  const std::vector<int32_t> axis = ctx->Attr<std::vector<int32_t>>("dim");
  const bool keepdim = ctx->Attr<bool>("keepdim");
  // Splitting is only allowed on axes that are not reduced; without keepdim the
  // output split axis is shifted left by the number of reduced axes before it.
  if (keepdim) {
    for (int i = 0; i < ndim; i++) {
      if (std::find(axis.begin(), axis.end(), i) == axis.end()) {
        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
      }
    }
  } else {
    int offset = 0;
    for (int i = 0; i < ndim; i++) {
      if (std::find(axis.begin(), axis.end(), i) == axis.end()) {
        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i - offset).Build();
      } else {
        offset += 1;
      }
    }
  }
  return Maybe<void>::Ok();
}

}  // namespace oneflow

================================================
FILE: oneflow/user/ops/vector_matrix_product_op.cpp
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { Maybe InferTensorDesc4VectorMatrixProduct(user_op::InferContext* ctx) { const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); int64_t k = a.shape().At(0); CHECK_EQ_OR_RETURN(k, b.shape().At(0)) << "Dim K should be equal to vector b's dim0. "; int64_t n = b.shape().At(1); ctx->SetOutputShape("out", 0, Shape({n})); return Maybe::Ok(); } Maybe InferDataType4VectorMatrixProduct(user_op::InferContext* ctx) { DataType dtype = ctx->InputDType("a", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), dtype) << "Matrix A datatype should be equal to Vector B. "; ctx->SetOutputDType("out", 0, dtype); return Maybe::Ok(); } Maybe InferTensorDesc4VectorMatrixProductGradA(user_op::InferContext* ctx) { /* A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n) GradA = dy (n) matmul B_transpose(n, k) -> (1, n) matmul (n, k) */ const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); int64_t k = b.shape().At(0); ctx->SetOutputShape("dx", 0, Shape({k})); return Maybe::Ok(); } Maybe InferTensorDesc4VectorMatrixProductGradB(user_op::InferContext* ctx) { /* A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n) GradB = a (k, 1) matmul dy (1, n) */ const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); int64_t k = a.shape().At(0); int64_t n = dy.shape().At(0); ctx->SetOutputShape("dx", 0, Shape({k, n})); return Maybe::Ok(); } Maybe InferDataType4Grad(user_op::InferContext* ctx) { DataType dtype = ctx->InputDType("dy", 0); ctx->SetOutputDType("dx", 0, dtype); return Maybe::Ok(); } } // namespace /* static */ Maybe VectorMatrixProductOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return InferTensorDesc4VectorMatrixProduct(ctx); } /*static*/ Maybe VectorMatrixProductOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe VectorMatrixProductOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .Split(user_op::OpArg("b", 0), 1) .Split(user_op::OpArg("out", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("a", 0), 0) .Split(user_op::OpArg("b", 0), 0) .PartialSum(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("a", 0)) .Broadcast(user_op::OpArg("b", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .PartialSum(user_op::OpArg("b", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe VectorMatrixProductOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4VectorMatrixProduct(ctx); } /* static */ Maybe VectorMatrixProductGradAOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4VectorMatrixProductGradA(ctx); } /*static*/ Maybe VectorMatrixProductGradAOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe 
VectorMatrixProductGradAOp::GetSbp(user_op::SbpContext* ctx) { /* A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n) GradA = dy (n) matmul B_transpose(n, k) -> (1, n) matmul (n, k) */ ctx->NewBuilder() .Broadcast(user_op::OpArg("dy", 0)) .Split(user_op::OpArg("b", 0), 0) .Split(user_op::OpArg("dx", 0), 0) .Build(); ctx->NewBuilder() .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("b", 0), 1) .PartialSum(user_op::OpArg("dx", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("dy", 0)) .Broadcast(user_op::OpArg("b", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("b", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe VectorMatrixProductGradAOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Grad(ctx); } /* static */ Maybe VectorMatrixProductGradBOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { return InferTensorDesc4VectorMatrixProductGradB(ctx); } /*static*/ Maybe VectorMatrixProductGradBOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /* static */ Maybe VectorMatrixProductGradBOp::GetSbp(user_op::SbpContext* ctx) { /* A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n) A(k, ) -> (1, k) GradB = a_transpose (k, 1) matmul dy (1, n) */ ctx->NewBuilder() .Split(user_op::OpArg("a", 0), 0) .Broadcast(user_op::OpArg("dy", 0)) .Split(user_op::OpArg("dx", 0), 0) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .Split(user_op::OpArg("dy", 0), 0) .Split(user_op::OpArg("dx", 0), 1) .Build(); ctx->NewBuilder() .Broadcast(user_op::OpArg("a", 0)) .PartialSum(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); ctx->NewBuilder() .PartialSum(user_op::OpArg("a", 0)) .Broadcast(user_op::OpArg("dy", 0)) .PartialSum(user_op::OpArg("dx", 0)) .Build(); return Maybe::Ok(); } /* static */ Maybe VectorMatrixProductGradBOp::InferDataType(user_op::InferContext* ctx) { return InferDataType4Grad(ctx); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/where_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/op_generated.h" #include "oneflow/core/framework/dtype.h" namespace oneflow { namespace { Maybe GetBroadcastShape(const Shape& cond_shape, const Shape& x_shape, const Shape& y_shape) { size_t ndim = std::max(x_shape.size(), y_shape.size()); ndim = std::max(ndim, cond_shape.size()); DimVector broadcast_dim_vec(ndim); for (size_t i = 0; i < ndim; ++i) { size_t cond_lpad = ndim - cond_shape.size(); size_t x_lpad = ndim - x_shape.size(); size_t y_lpad = ndim - y_shape.size(); int64_t cond_dim = (i < cond_lpad) ? 1 : cond_shape[i - cond_lpad]; int64_t x_dim = (i < x_lpad) ? 1 : x_shape[i - x_lpad]; int64_t y_dim = (i < y_lpad) ? 
1 : y_shape[i - y_lpad]; int64_t max_dim = std::max(x_dim, y_dim); max_dim = std::max(max_dim, cond_dim); broadcast_dim_vec[i] = max_dim; if ((cond_dim != 1 && cond_dim != max_dim) || (x_dim != 1 && x_dim != max_dim) || (y_dim != 1 && y_dim != max_dim)) { return Error::RuntimeError() << "The tensor cond with size " << cond_shape.ToString() << ", x with size " << x_shape.ToString() << " and y with size " << y_shape.ToString() << " are not broadcastable."; } } return Shape(broadcast_dim_vec); } } // namespace /*static*/ Maybe WhereOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& cond_shape = ctx->InputShape("condition", 0); const Shape& x_shape = ctx->InputShape("x", 0); const Shape& y_shape = ctx->InputShape("y", 0); ctx->SetOutputShape("out", 0, *JUST(GetBroadcastShape(cond_shape, x_shape, y_shape))); return Maybe::Ok(); } /*static*/ Maybe WhereOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe WhereOp::InferDataType(user_op::InferContext* ctx) { DataType cond_dtype = ctx->InputDType("condition", 0); CHECK_OR_RETURN(IsBoolDataType(cond_dtype) || IsIntegralDataType(cond_dtype)); DataType x_dtype = ctx->InputDType("x", 0); CHECK_EQ_OR_RETURN(x_dtype, ctx->InputDType("y", 0)) << "InferDataType Failed. Expected " << DataType_Name(ctx->InputDType("y", 0)) << ", but got " << DataType_Name(x_dtype); ctx->SetOutputDType("out", 0, x_dtype); return Maybe::Ok(); } /*static*/ Maybe WhereOp::GetSbp(user_op::SbpContext* ctx) { const Shape& cond_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0).shape(); const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); Shape broadcast_shape = *JUST(GetBroadcastShape(cond_shape, x_shape, y_shape)); const size_t ndim = broadcast_shape.size(); std::vector broadcast_args; std::vector split_args; std::vector split_dims; broadcast_args.reserve(3); split_args.reserve(3); split_dims.reserve(3); auto CheckArgCanSplit = [&](std::string&& arg_name, const int dim, const Shape& shape) { size_t ddiff = ndim - shape.size(); int dim_size = (dim >= ddiff) ? shape[dim - ddiff] : 1; if (dim_size == 1) { broadcast_args.emplace_back(std::forward(arg_name), 0); } else { split_args.emplace_back(std::forward(arg_name), 0); split_dims.push_back(dim - ddiff); } }; for (int i = 0; i < ndim; ++i) { if (broadcast_shape[i] == 1) { continue; } broadcast_args.clear(); split_args.clear(); split_dims.clear(); CheckArgCanSplit("x", i, x_shape); CheckArgCanSplit("y", i, y_shape); CheckArgCanSplit("condition", i, cond_shape); auto builder = ctx->NewBuilder(); builder.Broadcast(broadcast_args); for (int i = 0; i < split_args.size(); ++i) { builder.Split(split_args[i], split_dims[i]); } builder.Split(user_op::OpArg("out", 0), i); builder.Build(); } ctx->NewBuilder() .Broadcast(user_op::OpArg("condition", 0)) .PartialSum(user_op::OpArg("x", 0)) .PartialSum(user_op::OpArg("y", 0)) .PartialSum(user_op::OpArg("out", 0)) .Build(); return Maybe::Ok(); } /*static*/ Maybe WhereOp::ModifyInputArg(const GetInputArgModifier& fn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* cond_arg_modifier = fn("condition", 0); cond_arg_modifier->set_requires_grad(false); return Maybe::Ok(); } } // namespace oneflow ================================================ FILE: oneflow/user/ops/zero_like_op.cpp ================================================ /* Copyright 2020 The OneFlow Authors. 
All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/*static*/ Maybe<void> ZeroLikeOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0);
  // "like" and "out" can be split on any matching axis; a partial-sum "like"
  // still yields a broadcast "out", since the output is all zeros anyway.
  FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("like", 0), i)
        .Split(user_op::OpArg("out", 0), i)
        .Build();
  }
  ctx->NewBuilder()
      .PartialSum(user_op::OpArg("like", 0))
      .Broadcast(user_op::OpArg("out", 0))
      .Build();
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> ZeroLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  ctx->SetOutputShape("out", 0, ctx->InputShape("like", 0));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> ZeroLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}

/*static*/ Maybe<void> ZeroLikeOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("like", 0));
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> ZeroLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
  // Both "like" and "out" adopt the SBP hint of the "like" input.
  const NdSbp& in_sbp = ctx->NdSbpHint4InputArgNameAndIndex("like", 0);
  NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0);
  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0);
  *like_distribution = in_sbp;
  *out_distribution = in_sbp;
  return Maybe<void>::Ok();
}

}  // namespace oneflow

================================================
FILE: oneflow/user/summary/crc32c.h
================================================
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ #ifndef ONEFLOW_USER_SUMMARY_CRC32C_H_ #define ONEFLOW_USER_SUMMARY_CRC32C_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace summary { static const uint32_t table[256] = { 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829, 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351}; inline uint32_t GetCrc32(const char* buf, size_t size) { const uint8_t* uchar_buf = reinterpret_cast(buf); uint32_t crc = 0 ^ 0xffffffffu; for (int i = 0; i < size; ++i) { crc = table[(crc & 0xff) ^ uchar_buf[i]] ^ (crc >> 8); } return crc ^ 0xffffffffu; } inline uint32_t 
MaskCrc32(uint32_t crc) { return ((crc >> 15) | (crc << 17)) + 0xa282ead8ul; } } // namespace summary } // namespace oneflow #endif // ONEFLOW_USER_SUMMARY_CRC32C_H_ ================================================ FILE: oneflow/user/summary/env_time.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_SUMMARY_ENV_TIME_H_ #define ONEFLOW_USER_SUMMARY_ENV_TIME_H_ #include "oneflow/core/common/util.h" namespace oneflow { namespace summary { static constexpr uint64_t kMicroTimeToNanoTime = 1000ULL; static constexpr uint64_t kSecondToNanoTime = 1000ULL * 1000ULL * 1000ULL; static constexpr uint64_t kMircoTimeToSecondTime = 1000ULL * 1000ULL; inline uint64_t CurrentNanoTime() { struct timespec ts; clock_gettime(CLOCK_REALTIME, &ts); return (static_cast(ts.tv_sec) * kSecondToNanoTime + static_cast(ts.tv_nsec)); } inline uint64_t CurrentMircoTime() { return CurrentNanoTime() / kMicroTimeToNanoTime; } inline uint64_t CurrentSecondTime() { return CurrentMircoTime() / kMircoTimeToSecondTime; } inline double GetWallTime() { return static_cast(CurrentNanoTime() / kMicroTimeToNanoTime) / 1.0e6; } } // namespace summary } // namespace oneflow #endif // ONEFLOW_USER_SUMMARY_ENV_TIME_H_ ================================================ FILE: oneflow/user/summary/event_writer_helper.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/summary/event_writer_helper.h" #include "oneflow/user/summary/env_time.h" #include "oneflow/user/summary/events_writer.h" #include "oneflow/user/summary/histogram.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/summary/summary.pb.h" #include "oneflow/core/summary/event.pb.h" #include #include #include #include #define USER_LIBPNG_VER_STRING "1.6.24" namespace oneflow { namespace summary { const char* kScalarPluginName = "scalars"; const char* kHistogramPluginName = "histograms"; const char* kImagePluginName = "images"; void SetPluginData(SummaryMetadata* metadata, const char* name) { if (metadata->plugin_data().plugin_name().empty()) { metadata->mutable_plugin_data()->set_plugin_name(name); } } Maybe FillScalarInSummary(const float& value, const std::string& tag, Summary* s) { SummaryMetadata metadata; SetPluginData(&metadata, kScalarPluginName); Summary::Value* v = s->add_value(); v->set_tag(tag); *v->mutable_metadata() = metadata; v->set_simple_value(value); return Maybe::Ok(); } template Maybe FillHistogramInSummary(const user_op::Tensor& value, const std::string& tag, Summary* s) { SummaryMetadata metadata; SetPluginData(&metadata, kHistogramPluginName); Summary::Value* v = s->add_value(); v->set_tag(tag); *v->mutable_metadata() = metadata; summary::Histogram histo; for (int64_t i = 0; i < value.shape_view().elem_cnt(); i++) { double double_val = value.dptr()[i]; histo.AppendValue(double_val); } histo.AppendToProto(v->mutable_histo()); return Maybe::Ok(); } void WriteImageDataFn(png_structp png_ptr, png_bytep data, png_size_t length) { std::string* const s = reinterpret_cast(png_get_io_ptr(png_ptr)); s->append(reinterpret_cast(data), length); } bool WriteImageToBuffer(const uint8_t* image, int width, int height, int depth, std::string* png_string) { CHECK_NOTNULL(image); CHECK_NOTNULL(png_string); if (width == 0 || height == 0) return false; png_string->resize(0); png_infop info_ptr = nullptr; png_structp png_ptr = png_create_write_struct(USER_LIBPNG_VER_STRING, 0, 0, 0); if (png_ptr == nullptr) return false; if (setjmp(png_jmpbuf(png_ptr))) { png_destroy_write_struct(&png_ptr, info_ptr ? 
&info_ptr : nullptr); return false; } info_ptr = png_create_info_struct(png_ptr); if (info_ptr == nullptr) { png_destroy_write_struct(&png_ptr, nullptr); return false; } int color_type = -1; switch (depth) { case 1: color_type = PNG_COLOR_TYPE_GRAY; break; case 2: color_type = PNG_COLOR_TYPE_GRAY_ALPHA; break; case 3: color_type = PNG_COLOR_TYPE_RGB; break; case 4: color_type = PNG_COLOR_TYPE_RGB_ALPHA; break; default: png_destroy_write_struct(&png_ptr, &info_ptr); return false; } const int bit_depth = 8; png_set_write_fn(png_ptr, png_string, WriteImageDataFn, nullptr); png_set_compression_level(png_ptr, Z_DEFAULT_COMPRESSION); png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL); png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth, color_type, PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, PNG_FILTER_TYPE_DEFAULT); png_write_info(png_ptr, info_ptr); png_byte* row = reinterpret_cast(const_cast(image)); int row_bytes = width * depth; for (; height--; row += row_bytes) png_write_row(png_ptr, row); png_write_end(png_ptr, nullptr); png_destroy_write_struct(&png_ptr, &info_ptr); return true; } Maybe FillImageInSummary(const user_op::Tensor& tensor, const std::string& tag, Summary* s) { SummaryMetadata metadata; SetPluginData(&metadata, kImagePluginName); if (!(tensor.shape_view().NumAxes() == 4 && (tensor.shape_view().At(3) == 1 || tensor.shape_view().At(3) == 3 || tensor.shape_view().At(3) == 4))) { UNIMPLEMENTED(); } if (!(tensor.shape_view().At(0) < (1LL << 31) && tensor.shape_view().At(1) < (1LL << 31) && tensor.shape_view().At(2) < (1LL << 31) && (tensor.shape_view().At(1) * tensor.shape_view().At(2)) < (1LL << 29))) { UNIMPLEMENTED(); } const int64_t batch_size = static_cast(tensor.shape_view().At(0)); const int64_t h = static_cast(tensor.shape_view().At(1)); const int64_t w = static_cast(tensor.shape_view().At(2)); const int64_t hw = h * w; const int64_t depth = static_cast(tensor.shape_view().At(3)); if (tensor.data_type() == DataType::kUInt8) { auto ith_image = [&tensor, hw, depth](int i) { auto images = tensor.dptr(); auto image_i = std::unique_ptr{new uint8_t[hw * depth]}; memcpy(image_i.get(), images + i * hw * depth, hw * depth); return image_i; }; for (int i = 0; i < batch_size; ++i) { Summary::Value* v = s->add_value(); *v->mutable_metadata() = metadata; if (batch_size == 1) { v->set_tag(tag); } else { v->set_tag(tag + std::to_string(i)); } Image* si = v->mutable_image(); si->set_height(h); si->set_width(w); si->set_colorspace(depth); auto image = ith_image(i); if (!WriteImageToBuffer(image.get(), w, h, depth, si->mutable_encoded_image_string())) UNIMPLEMENTED(); } } return Maybe::Ok(); } template struct EventWriterHelper { static void WritePbToFile(int64_t step, const std::string& value) { std::unique_ptr e{new Event}; Summary sum; TxtString2PbMessage(value, &sum); e->set_step(step); e->set_wall_time(GetWallTime()); *e->mutable_summary() = sum; Singleton::Get()->AppendQueue(std::move(e)); } static void WriteScalarToFile(int64_t step, float value, const std::string& tag) { std::unique_ptr e{new Event}; e->set_step(step); e->set_wall_time(GetWallTime()); CHECK_JUST(FillScalarInSummary(value, tag, e->mutable_summary())); Singleton::Get()->AppendQueue(std::move(e)); } static void WriteHistogramToFile(int64_t step, const user_op::Tensor& value, const std::string& tag) { std::unique_ptr e{new Event}; e->set_step(step); e->set_wall_time(GetWallTime()); CHECK_JUST(FillHistogramInSummary(value, tag, e->mutable_summary())); Singleton::Get()->AppendQueue(std::move(e)); } static 
void WriteImageToFile(int64_t step, const user_op::Tensor& tensor, const std::string& tag) { std::unique_ptr e{new Event}; e->set_step(step); e->set_wall_time(GetWallTime()); CHECK_JUST(FillImageInSummary(tensor, tag, e->mutable_summary())); Singleton::Get()->AppendQueue(std::move(e)); } }; #define INSTANTIATE_EVENT_WRITE_HELPER_CPU(dtype) \ template struct EventWriterHelper; INSTANTIATE_EVENT_WRITE_HELPER_CPU(float) INSTANTIATE_EVENT_WRITE_HELPER_CPU(double) INSTANTIATE_EVENT_WRITE_HELPER_CPU(int32_t) INSTANTIATE_EVENT_WRITE_HELPER_CPU(int64_t) INSTANTIATE_EVENT_WRITE_HELPER_CPU(uint8_t) INSTANTIATE_EVENT_WRITE_HELPER_CPU(int8_t) } // namespace summary } // namespace oneflow ================================================ FILE: oneflow/user/summary/event_writer_helper.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_SUMMARY_EVENT_WRITER_HELPER_H_ #define ONEFLOW_USER_SUMMARY_EVENT_WRITER_HELPER_H_ #include "oneflow/core/framework/framework.h" namespace oneflow { namespace summary { class EventsWriter; template struct EventWriterHelper { static void WritePbToFile(int64_t step, const std::string& value); static void WriteScalarToFile(int64_t step, float value, const std::string& tag); static void WriteHistogramToFile(int64_t step, const user_op::Tensor& value, const std::string& tag); static void WriteImageToFile(int64_t step, const user_op::Tensor& tensor, const std::string& tag); }; } // namespace summary } // namespace oneflow #endif // ONEFLOW_USER_SUMMARY_EVENT_WRITER_HELPER_H_ ================================================ FILE: oneflow/user/summary/events_writer.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/summary/events_writer.h" #include "oneflow/core/common/str_util.h" #include "oneflow/user/summary/env_time.h" namespace oneflow { namespace summary { EventsWriter::EventsWriter() : is_inited_(false) {} EventsWriter::~EventsWriter() { Close(); } Maybe EventsWriter::Init(const std::string& logdir) { file_system_ = std::make_unique(); log_dir_ = logdir + "/event"; file_system_->RecursivelyCreateDirIfNotExist(log_dir_); JUST(TryToInit()); is_inited_ = true; last_flush_time_ = CurrentMircoTime(); return Maybe::Ok(); } Maybe EventsWriter::TryToInit() { if (!filename_.empty()) { if (!file_system_->FileExists(filename_)) { LOG(WARNING) << "Event log file was lost, attempting create a new log file!"; } else { return Maybe::Ok(); } } int32_t current_time = CurrentSecondTime(); char fname[100] = {'\0'}; snprintf(fname, 100, "event.%d.log", current_time); filename_ = JoinPath(log_dir_, fname); file_system_->NewWritableFile(filename_, &writable_file_); CHECK_OR_RETURN(writable_file_ != nullptr); { Event event; event.set_wall_time(current_time); event.set_file_version(FILE_VERSION); WriteEvent(event); Flush(); } return Maybe::Ok(); } void EventsWriter::AppendQueue(std::unique_ptr event) { queue_mutex.lock(); event_queue_.emplace_back(std::move(event)); queue_mutex.unlock(); if (event_queue_.size() > MAX_QUEUE_NUM || CurrentMircoTime() - last_flush_time_ > FLUSH_TIME) { Flush(); } } void EventsWriter::Flush() { queue_mutex.lock(); for (const std::unique_ptr& e : event_queue_) { WriteEvent(*e); } event_queue_.clear(); queue_mutex.unlock(); FileFlush(); last_flush_time_ = CurrentMircoTime(); } void EventsWriter::WriteEvent(const Event& event) { std::string event_str; event.AppendToString(&event_str); if (!TryToInit().IsOk()) { LOG(ERROR) << "Write failed because file could not be opened."; return; } if (writable_file_ == nullptr) { LOG(WARNING) << "Log file is closed!"; return; } char head[kHeadSize]; char tail[kTailSize]; EncodeHead(head, event_str.size()); EncodeTail(tail, event_str.data(), event_str.size()); writable_file_->Append(head, sizeof(head)); writable_file_->Append(event_str.data(), event_str.size()); writable_file_->Append(tail, sizeof(tail)); FileFlush(); } void EventsWriter::FileFlush() { if (writable_file_ == nullptr) { return; } writable_file_->Flush(); } void EventsWriter::Close() { if (!is_inited_) { return; } queue_mutex.unlock(); Flush(); if (writable_file_ != nullptr) { writable_file_->Close(); writable_file_.reset(nullptr); } } } // namespace summary } // namespace oneflow ================================================ FILE: oneflow/user/summary/events_writer.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_SUMMARY_EVENTS_WRITER_H_ #define ONEFLOW_USER_SUMMARY_EVENTS_WRITER_H_ #include "oneflow/core/persistence/posix/posix_file_system.h" #include "oneflow/core/common/util.h" #include "oneflow/user/summary/crc32c.h" #include "oneflow/core/summary/event.pb.h" #include #include namespace oneflow { namespace summary { #define MAX_QUEUE_NUM 10 #define FLUSH_TIME 3 * 60 * 1000 * 1000 #define FILE_VERSION "brain.Event:3" const size_t kHeadSize = sizeof(uint64_t) + sizeof(uint32_t); const size_t kTailSize = sizeof(uint32_t); class EventsWriter { public: EventsWriter(); ~EventsWriter(); Maybe Init(const std::string& logdir); void WriteEvent(const Event& event); void Flush(); void Close(); void AppendQueue(std::unique_ptr event); void FileFlush(); private: Maybe TryToInit(); inline static void EncodeHead(char* head, size_t size); inline static void EncodeTail(char* tail, const char* data, size_t size); bool is_inited_; std::string log_dir_; std::string filename_; std::unique_ptr file_system_; std::unique_ptr writable_file_; uint64_t last_flush_time_; std::vector> event_queue_; std::mutex queue_mutex; OF_DISALLOW_COPY(EventsWriter); }; void EventsWriter::EncodeHead(char* head, size_t size) { memcpy(head, &size, sizeof(size)); uint32_t value = MaskCrc32(GetCrc32(head, sizeof(uint64_t))); memcpy(head + sizeof(uint64_t), &value, sizeof(value)); } void EventsWriter::EncodeTail(char* tail, const char* data, size_t size) { uint32_t value = MaskCrc32(GetCrc32(data, size)); memcpy(tail, &value, sizeof(value)); } } // namespace summary } // namespace oneflow #endif // ONEFLOW_USER_SUMMARY_EVENTS_WRITER_H_ ================================================ FILE: oneflow/user/summary/histogram.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/summary/histogram.h" #include "oneflow/core/common/maybe.h" #include #include namespace oneflow { namespace summary { static std::vector defalut_container = {-DBL_MAX, -451872326.521804, -410793024.1107308, -373448203.737028, -339498367.0336618, -308634879.1215107, -280577162.83773696, -255070148.03430632, -231881952.75846028, -210801775.23496386, -191637977.48633078, -174216343.1693916, -158378493.79035598, -143980448.9003236, -130891317.18211237, -118992106.52919304, -108174642.2992664, -98340583.90842399, -89400530.82583998, -81273209.8416727, -73884736.21970245, -67167942.01791131, -61061765.47082846, -55510695.882571325, -50464268.984155744, -45876608.16741431, -41706007.424922094, -37914552.20447463, -34467774.731340565, -31334340.664855056, -28485764.24077732, -25896149.309797563, -23541953.91799778, -21401776.28908889, -19456160.26280808, -17687418.420734618, -16079471.291576924, -14617701.174160838, -13288819.249237124, -12080744.772033747, -10982495.247303406, -9984086.58845764, -9076442.353143308, -8251311.230130279, -7501192.027391163, -6819265.479446511, -6199332.254042282, -5635756.594583892, -5123415.085985356, -4657650.078168505, -4234227.34378955, -3849297.5852632266, -3499361.4411483877, -3181237.6737712612, -2892034.2488829647, -2629122.0444390588, -2390110.949490053, -2172828.135900048, -1975298.30536368, -1795725.7321487998, -1632477.9383170905, -1484070.8530155367, -1349155.320923215, -1226504.8372029227, -1115004.3974572024, -1013640.3613247294, -921491.2375679357, -837719.3068799415, -761563.0062544922, -692330.005685902, -629390.9142599108, -572173.5584181007, -520157.7803800915, -472870.7094364468, -429882.4631240425, -390802.23920367495, -355274.76291243173, -322977.0571931197, -293615.50653919973, -266923.1877629088, -242657.4434208262, -220597.67583711472, -200543.34167010427, -182312.12879100387, -165738.2989009126, -150671.18081901144, -136973.80074455583, -124521.63704050529, -113201.48821864116, -102910.44383512832, -93554.94894102574, -85049.95358275066, -77318.13962068241, -70289.21783698401, -63899.28894271274, -58090.26267519339, -52809.32970472126, -48008.4815497466, -43644.07413613327, -39676.43103284842, -36069.48275713492, -32790.438870122656, -29809.489881929687, -27099.536256299714, -24635.942051181555, -22396.310955619592, -20360.2826869269, -18509.347897206273, -16826.679906551155, -15296.98173322832, -13906.347030207562, -12642.133663825056, -11492.848785295504, -10448.04435026864, -9498.222136607854, -8634.74739691623, -7849.770360832936, -7136.154873484486, -6487.413521349533, -5897.648655772302, -5361.49877797482, -4874.089798158927, -4430.990725599024, -4028.173386908203, -3661.9758062801843, -3329.0689148001675, -3026.42628618197, -2751.2966238017907, -2501.1787489107187, -2273.798862646108, -2067.089875132825, -1879.1726137571134, -1708.338739779194, -1553.0352179810852, -1411.8501981646227, -1283.500180149657, -1166.8183455905971, -1060.7439505369064, -964.3126823062785, -876.6478930057076, -796.9526300051887, -724.5023909138079, -658.6385371943708, -598.762306540337, -544.3293695821245, -494.844881438295, -449.85898312572266, -408.9627119324751, -371.78428357497734, -337.9857123408885, -307.2597384917168, -279.3270349924698, -253.93366817497255, -230.84878924997503, -209.86253568179546, -190.78412334708676, -173.44011213371522, -157.67282921246837, -143.33893564769852, -130.30812331608956, -118.46193028735415, -107.69266389759467, -97.90242172508606, -89.00220156826005, -80.91109233478186, -73.55553848616532, 
-66.86867135105938, -60.78970122823579, -55.26336475294163, -50.2394225026742, -45.67220227515836, -41.520183886507596, -37.745621715006905, -34.314201559097185, -31.19472869008835, -28.35884426371668, -25.78076751246971, -23.437061374972462, -21.306419431793145, -19.36947221072104, -17.608611100655487, -16.00782827332317, -14.552571157566518, -13.229610143242288, -12.026918312038443, -10.933562101853129, -9.93960191077557, -9.0360017370687, -8.214547033698818, -7.467770030635288, -6.788881846032079, -6.171710769120072, -5.6106461537455194, -5.100587412495926, -4.636897647723569, -4.215361497930517, -3.8321468163004693, -3.4837698330004265, -3.167063484545842, -2.8791486223144016, -2.6174078384676376, -2.379461671334216, -2.163146973940196, -1.9664972490365416, -1.7877247718514013, -1.6252043380467283, -1.4774584891333893, -1.3431440810303539, -1.221040073663958, -1.1100364306035981, -1.0091240278214528, -0.9173854798376843, -0.8339867998524402, -0.7581698180476728, -0.689245289134248, -0.62658662648568, -0.5696242058960727, -0.5178401871782479, -0.47076380652567984, -0.4279670968415271, -0.389060997128661, -0.35369181557150997, -0.32153801415591815, -0.2923072855962892, -0.26573389599662656, -0.2415762690878423, -0.21961479007985663, -0.199649809163506, -0.18149982651227817, -0.16499984228388923, -0.14999985662171747, -0.13636350601974315, -0.12396682365431194, -0.11269711241301085, -0.1024519203754644, -0.09313810943224035, -0.08467100857476395, -0.07697364415887631, -0.069976040144433, -0.06361458194948454, -0.05783143813589502, -0.05257403466899547, -0.04779457697181406, -0.04344961542892187, -0.03949965038992897, -0.0359087730817536, -0.032644339165230546, -0.029676671968391403, -0.02697879269853764, -0.02452617518048876, -0.022296522891353417, -0.020269566264866742, -0.018426878422606128, -0.01675170765691466, -0.01522882514264969, -0.013844386493317899, -0.012585805903016271, -0.01144164173001479, -0.010401492481831627, -0.009455902256210569, -0.008596274778373244, -0.007814795253066584, -0.007104359320969622, -0.006458508473608747, -0.005871371339644315, -0.005337610308767558, -0.004852373007970507, -0.004411248189064097, -0.004010225626421907, -0.0036456596603835515, -0.0033142360548941373, -0.003012941868085579, -0.0027390380618959806, -0.0024900346017236183, -0.0022636678197487437, -0.0020578798361352213, -0.0018707998510320192, -0.0017007271373018355, -0.001546115579365305, -0.0014055596176048226, -0.0012777814705498386, -0.0011616195186816714, -0.0010560177442560648, -0.000960016131141877, -0.0008727419374017063, -0.0007934017612742784, -0.0007212743284311622, -0.0006557039349374201, -0.0005960944863067456, -0.0005419040784606777, -0.0004926400713278887, -0.0004478546102980806, -0.0004071405548164369, -0.00037012777710585166, -0.000336479797368956, -0.00030589072488086905, -0.0002780824771644264, -0.00025280225196766033, -0.0002298202290615094, -0.00020892748096500852, -0.00018993407360455317, -0.00017266733964050286, -0.00015697030876409349, -0.00014270028069463043, -0.00012972752790420947, -0.00011793411627655406, -0.0001072128329786855, -9.7466211798805e-05, -8.860564708982272e-05, -8.05505882634752e-05, -7.322780751225018e-05, -6.657073410204561e-05, -6.051884918367783e-05, -5.5017135621525293e-05, -5.0015577837750266e-05, -4.546870712522751e-05, -4.1335188295661374e-05, -3.75774439051467e-05, -3.416131264104245e-05, -3.105573876458404e-05, -2.8232489785985488e-05, -2.566589980544135e-05, -2.3332636186764862e-05, -2.1211487442513508e-05, -1.9283170402285007e-05, 
-1.7530154911168186e-05, -1.5936504464698348e-05, -1.448773133154395e-05, -1.3170664846858136e-05, -1.197333167896194e-05, -1.088484698087449e-05, -9.895315437158626e-06, -8.995741306507842e-06, -8.177946642279855e-06, -7.43449694752714e-06, -6.758633588661036e-06, -6.144212353328214e-06, 0.0, 6.144212353328214e-06, 6.758633588661036e-06, 7.43449694752714e-06, 8.177946642279855e-06, 8.995741306507842e-06, 9.895315437158626e-06, 1.088484698087449e-05, 1.197333167896194e-05, 1.3170664846858136e-05, 1.448773133154395e-05, 1.5936504464698348e-05, 1.7530154911168186e-05, 1.9283170402285007e-05, 2.1211487442513508e-05, 2.3332636186764862e-05, 2.566589980544135e-05, 2.8232489785985488e-05, 3.105573876458404e-05, 3.416131264104245e-05, 3.75774439051467e-05, 4.1335188295661374e-05, 4.546870712522751e-05, 5.0015577837750266e-05, 5.5017135621525293e-05, 6.051884918367783e-05, 6.657073410204561e-05, 7.322780751225018e-05, 8.05505882634752e-05, 8.860564708982272e-05, 9.7466211798805e-05, 0.0001072128329786855, 0.00011793411627655406, 0.00012972752790420947, 0.00014270028069463043, 0.00015697030876409349, 0.00017266733964050286, 0.00018993407360455317, 0.00020892748096500852, 0.0002298202290615094, 0.00025280225196766033, 0.0002780824771644264, 0.00030589072488086905, 0.000336479797368956, 0.00037012777710585166, 0.0004071405548164369, 0.0004478546102980806, 0.0004926400713278887, 0.0005419040784606777, 0.0005960944863067456, 0.0006557039349374201, 0.0007212743284311622, 0.0007934017612742784, 0.0008727419374017063, 0.000960016131141877, 0.0010560177442560648, 0.0011616195186816714, 0.0012777814705498386, 0.0014055596176048226, 0.001546115579365305, 0.0017007271373018355, 0.0018707998510320192, 0.0020578798361352213, 0.0022636678197487437, 0.0024900346017236183, 0.0027390380618959806, 0.003012941868085579, 0.0033142360548941373, 0.0036456596603835515, 0.004010225626421907, 0.004411248189064097, 0.004852373007970507, 0.005337610308767558, 0.005871371339644315, 0.006458508473608747, 0.007104359320969622, 0.007814795253066584, 0.008596274778373244, 0.009455902256210569, 0.010401492481831627, 0.01144164173001479, 0.012585805903016271, 0.013844386493317899, 0.01522882514264969, 0.01675170765691466, 0.018426878422606128, 0.020269566264866742, 0.022296522891353417, 0.02452617518048876, 0.02697879269853764, 0.029676671968391403, 0.032644339165230546, 0.0359087730817536, 0.03949965038992897, 0.04344961542892187, 0.04779457697181406, 0.05257403466899547, 0.05783143813589502, 0.06361458194948454, 0.069976040144433, 0.07697364415887631, 0.08467100857476395, 0.09313810943224035, 0.1024519203754644, 0.11269711241301085, 0.12396682365431194, 0.13636350601974315, 0.14999985662171747, 0.16499984228388923, 0.18149982651227817, 0.199649809163506, 0.21961479007985663, 0.2415762690878423, 0.26573389599662656, 0.2923072855962892, 0.32153801415591815, 0.35369181557150997, 0.389060997128661, 0.4279670968415271, 0.47076380652567984, 0.5178401871782479, 0.5696242058960727, 0.62658662648568, 0.689245289134248, 0.7581698180476728, 0.8339867998524402, 0.9173854798376843, 1.0091240278214528, 1.1100364306035981, 1.221040073663958, 1.3431440810303539, 1.4774584891333893, 1.6252043380467283, 1.7877247718514013, 1.9664972490365416, 2.163146973940196, 2.379461671334216, 2.6174078384676376, 2.8791486223144016, 3.167063484545842, 3.4837698330004265, 3.8321468163004693, 4.215361497930517, 4.636897647723569, 5.100587412495926, 5.6106461537455194, 6.171710769120072, 6.788881846032079, 7.467770030635288, 8.214547033698818, 9.0360017370687, 
9.93960191077557, 10.933562101853129, 12.026918312038443, 13.229610143242288, 14.552571157566518, 16.00782827332317, 17.608611100655487, 19.36947221072104, 21.306419431793145, 23.437061374972462, 25.78076751246971, 28.35884426371668, 31.19472869008835, 34.314201559097185, 37.745621715006905, 41.520183886507596, 45.67220227515836, 50.2394225026742, 55.26336475294163, 60.78970122823579, 66.86867135105938, 73.55553848616532, 80.91109233478186, 89.00220156826005, 97.90242172508606, 107.69266389759467, 118.46193028735415, 130.30812331608956, 143.33893564769852, 157.67282921246837, 173.44011213371522, 190.78412334708676, 209.86253568179546, 230.84878924997503, 253.93366817497255, 279.3270349924698, 307.2597384917168, 337.9857123408885, 371.78428357497734, 408.9627119324751, 449.85898312572266, 494.844881438295, 544.3293695821245, 598.762306540337, 658.6385371943708, 724.5023909138079, 796.9526300051887, 876.6478930057076, 964.3126823062785, 1060.7439505369064, 1166.8183455905971, 1283.500180149657, 1411.8501981646227, 1553.0352179810852, 1708.338739779194, 1879.1726137571134, 2067.089875132825, 2273.798862646108, 2501.1787489107187, 2751.2966238017907, 3026.42628618197, 3329.0689148001675, 3661.9758062801843, 4028.173386908203, 4430.990725599024, 4874.089798158927, 5361.49877797482, 5897.648655772302, 6487.413521349533, 7136.154873484486, 7849.770360832936, 8634.74739691623, 9498.222136607854, 10448.04435026864, 11492.848785295504, 12642.133663825056, 13906.347030207562, 15296.98173322832, 16826.679906551155, 18509.347897206273, 20360.2826869269, 22396.310955619592, 24635.942051181555, 27099.536256299714, 29809.489881929687, 32790.438870122656, 36069.48275713492, 39676.43103284842, 43644.07413613327, 48008.4815497466, 52809.32970472126, 58090.26267519339, 63899.28894271274, 70289.21783698401, 77318.13962068241, 85049.95358275066, 93554.94894102574, 102910.44383512832, 113201.48821864116, 124521.63704050529, 136973.80074455583, 150671.18081901144, 165738.2989009126, 182312.12879100387, 200543.34167010427, 220597.67583711472, 242657.4434208262, 266923.1877629088, 293615.50653919973, 322977.0571931197, 355274.76291243173, 390802.23920367495, 429882.4631240425, 472870.7094364468, 520157.7803800915, 572173.5584181007, 629390.9142599108, 692330.005685902, 761563.0062544922, 837719.3068799415, 921491.2375679357, 1013640.3613247294, 1115004.3974572024, 1226504.8372029227, 1349155.320923215, 1484070.8530155367, 1632477.9383170905, 1795725.7321487998, 1975298.30536368, 2172828.135900048, 2390110.949490053, 2629122.0444390588, 2892034.2488829647, 3181237.6737712612, 3499361.4411483877, 3849297.5852632266, 4234227.34378955, 4657650.078168505, 5123415.085985356, 5635756.594583892, 6199332.254042282, 6819265.479446511, 7501192.027391163, 8251311.230130279, 9076442.353143308, 9984086.58845764, 10982495.247303406, 12080744.772033747, 13288819.249237124, 14617701.174160838, 16079471.291576924, 17687418.420734618, 19456160.26280808, 21401776.28908889, 23541953.91799778, 25896149.309797563, 28485764.24077732, 31334340.664855056, 34467774.731340565, 37914552.20447463, 41706007.424922094, 45876608.16741431, 50464268.984155744, 55510695.882571325, 61061765.47082846, 67167942.01791131, 73884736.21970245, 81273209.8416727, 89400530.82583998, 98340583.90842399, 108174642.2992664, 118992106.52919304, 130891317.18211237, 143980448.9003236, 158378493.79035598, 174216343.1693916, 191637977.48633078, 210801775.23496386, 231881952.75846028, 255070148.03430632, 280577162.83773696, 308634879.1215107, 339498367.0336618, 
373448203.737028, 410793024.1107308, 451872326.521804, DBL_MAX}; Histogram::Histogram() { max_constainers_ = defalut_container; containers_.resize(max_constainers_.size()); for (size_t idx = 0; idx < max_constainers_.size(); idx++) { containers_.at(idx) = 0; } value_sum_ = 0; sum_value_squares_ = 0; max_value_ = -DBL_MAX; value_count_ = 0; min_value_ = DBL_MAX; } void Histogram::AppendValue(double value) { value_sum_ += value; value_count_++; sum_value_squares_ += value * value; if (max_value_ < value) { max_value_ = value; } if (min_value_ > value) { min_value_ = value; } int idx = std::upper_bound(max_constainers_.begin(), max_constainers_.end(), value) - max_constainers_.begin(); CHECK_GT(containers_.size(), idx); containers_.at(idx) += 1.0; } void Histogram::AppendToProto(HistogramProto* hist_proto) { hist_proto->Clear(); hist_proto->set_num(value_count_); hist_proto->set_sum(value_sum_); hist_proto->set_min(min_value_); hist_proto->set_max(max_value_); hist_proto->set_sum_squares(sum_value_squares_); for (size_t idx = 0; idx < containers_.size();) { double num = containers_.at(idx); double last = max_constainers_.at(idx); idx++; if (num <= 0.0) { while (idx < containers_.size() && containers_.at(idx) <= 0.0) { last = max_constainers_.at(idx); num = containers_.at(idx); idx++; } } hist_proto->add_bucket_limit(last); hist_proto->add_bucket(num); } } } // namespace summary } // namespace oneflow ================================================ FILE: oneflow/user/summary/histogram.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_SUMMARY_HISTOGRAM_H_ #define ONEFLOW_USER_SUMMARY_HISTOGRAM_H_ #include #include "oneflow/core/summary/summary.pb.h" namespace oneflow { namespace summary { class Histogram { public: Histogram(); ~Histogram() {} void AppendValue(double value); void AppendToProto(HistogramProto* proto); private: double value_count_; double value_sum_; double sum_value_squares_; double min_value_; double max_value_; std::vector max_constainers_; std::vector containers_; }; } // namespace summary } // namespace oneflow #endif // ONEFLOW_USER_SUMMARY_HISTOGRAM_H_ ================================================ FILE: oneflow/user/summary/plan_to_physical_graph.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/summary/plan_to_physical_graph.h" #include "oneflow/core/summary/graph.pb.h" #include "oneflow/core/common/util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/job/plan_util.h" namespace oneflow { namespace summary { void PlanToPhysicalGraphFile(const Plan& plan) { GraphDef physical_graph; physical_graph.set_version(3); // "compute graph version number = 3" HashMap regst_desc_id2produce_op_name; HashMap task_id2op_name; HashSet ctrl_regst_desc_id_set; for (const TaskProto& task : plan.task()) { std::string op_name = ""; for (const ExecNodeProto& exec_node : task.exec_sequence().exec_node()) { if (op_name != "") { op_name += " && "; } op_name += (exec_node.kernel_conf().op_attribute().op_conf().name()); } if (op_name == "") { continue; } task_id2op_name.insert({task.task_id(), op_name}); for (const auto& pair : task.produced_regst_desc()) { const RegstDescProto& regst = pair.second; int64_t regst_desc_id = regst.regst_desc_id(); regst_desc_id2produce_op_name.insert({regst_desc_id, op_name}); if (regst.regst_desc_type().has_ctrl_regst_desc()) { ctrl_regst_desc_id_set.insert(regst_desc_id); } } } for (const TaskProto& task : plan.task()) { if (task_id2op_name.find(task.task_id()) == task_id2op_name.end()) { continue; } NodeDef* node = physical_graph.add_node(); node->set_name(task_id2op_name.at(task.task_id())); const OperatorConf& op_conf = task.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); DeviceType device_type = PlanUtil::GetStreamId(task).device_id().device_type(); node->set_device(*CHECK_JUST(DeviceTag4DeviceType(device_type))); if (op_conf.has_user_conf()) { const UserOpConf& user_op = op_conf.user_conf(); node->set_op(user_op.op_type_name()); node->mutable_attr()->insert(user_op.attr().begin(), user_op.attr().end()); } else { // maybe need get op / attr by every different op_type_case node->set_op("system_op"); } for (const auto& pair : task.consumed_regst_desc_id()) { for (int64_t regst_desc_id : pair.second.regst_desc_id()) { if (regst_desc_id2produce_op_name.find(regst_desc_id) != regst_desc_id2produce_op_name.end()) { std::string input_name = regst_desc_id2produce_op_name.at(regst_desc_id); if (ctrl_regst_desc_id_set.find(regst_desc_id) != ctrl_regst_desc_id_set.end()) { input_name = "^" + input_name; // control edge } node->add_input(input_name); } } } } TeePersistentLogStream::Create("physical_graph")->Write(physical_graph); } } // namespace summary } // namespace oneflow ================================================ FILE: oneflow/user/summary/plan_to_physical_graph.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_SUMMARY_PLAN_TO_PHYSICAL_GRAPH_H_ #define ONEFLOW_USER_SUMMARY_PLAN_TO_PHYSICAL_GRAPH_H_ #include "oneflow/core/job/plan.pb.h" namespace oneflow { namespace summary { void PlanToPhysicalGraphFile(const Plan& plan); } } // namespace oneflow #endif // ONEFLOW_USER_SUMMARY_PLAN_TO_PHYSICAL_GRAPH_H_ ================================================ FILE: oneflow/user/summary/summary_converter.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_USER_SUMMARY_SUMMARY_CONVERTER_H_ #define ONEFLOW_USER_SUMMARY_SUMMARY_CONVERTER_H_ #include "nlohmann/json.hpp" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { namespace summary { static void ConvertProtobufMsg2Json(nlohmann::json& json_value, const PbMessage& pb_msg); static void ConvertRepeatedField2Json(nlohmann::json& json_value, const PbMessage& pb_msg, const PbFd* pb_field, const google::protobuf::Reflection* pb_reflection) { if (NULL == pb_field || NULL == pb_reflection) { ConvertProtobufMsg2Json(json_value, pb_msg); } for (int i = 0; i < pb_reflection->FieldSize(pb_msg, pb_field); ++i) { nlohmann::json tmp_json_value; switch (pb_field->type()) { case PbFd::TYPE_MESSAGE: { const PbMessage& msg = pb_reflection->GetRepeatedMessage(pb_msg, pb_field, i); if (0 != msg.ByteSize()) { ConvertProtobufMsg2Json(tmp_json_value, msg); } } break; case PbFd::TYPE_INT32: tmp_json_value = pb_reflection->GetRepeatedInt32(pb_msg, pb_field, i); break; case PbFd::TYPE_UINT32: tmp_json_value = pb_reflection->GetRepeatedUInt32(pb_msg, pb_field, i); break; case PbFd::TYPE_INT64: { static char int64_str[25]; memset(int64_str, 0, sizeof(int64_str)); snprintf(int64_str, sizeof(int64_str), "%lld", (long long)pb_reflection->GetRepeatedInt64(pb_msg, pb_field, i)); tmp_json_value = int64_str; } break; case PbFd::TYPE_UINT64: { static char uint64str[25]; memset(uint64str, 0, sizeof(uint64str)); snprintf(uint64str, sizeof(uint64str), "%llu", (unsigned long long)pb_reflection->GetRepeatedUInt64(pb_msg, pb_field, i)); tmp_json_value = uint64str; } break; case PbFd::TYPE_STRING: case PbFd::TYPE_BYTES: tmp_json_value = pb_reflection->GetRepeatedString(pb_msg, pb_field, i); break; case PbFd::TYPE_BOOL: tmp_json_value = pb_reflection->GetRepeatedBool(pb_msg, pb_field, i); break; case PbFd::TYPE_ENUM: tmp_json_value = pb_reflection->GetRepeatedEnum(pb_msg, pb_field, i)->name(); break; case PbFd::TYPE_FLOAT: tmp_json_value = pb_reflection->GetRepeatedFloat(pb_msg, pb_field, i); break; case PbFd::TYPE_DOUBLE: tmp_json_value = pb_reflection->GetRepeatedDouble(pb_msg, pb_field, i); break; default: break; } json_value.emplace_back(tmp_json_value); } } static void ConvertProtobufMsg2Json(nlohmann::json& json_value, const PbMessage& pb_msg) { const google::protobuf::Descriptor* pb_descriptor = pb_msg.GetDescriptor(); const google::protobuf::Reflection* pb_reflection = pb_msg.GetReflection(); const int count = 
pb_descriptor->field_count(); for (int i = 0; i < count; ++i) { const PbFd* pb_field = pb_descriptor->field(i); if (pb_field->is_repeated()) { if (pb_reflection->FieldSize(pb_msg, pb_field) > 0) { ConvertRepeatedField2Json(json_value[pb_field->name()], pb_msg, pb_field, pb_reflection); } continue; } if (!pb_reflection->HasField(pb_msg, pb_field)) { continue; } switch (pb_field->type()) { case PbFd::TYPE_MESSAGE: { const PbMessage& msg = pb_reflection->GetMessage(pb_msg, pb_field); if (0 != msg.ByteSize()) { ConvertProtobufMsg2Json(json_value[pb_field->name()], msg); } } break; case PbFd::TYPE_INT32: json_value[pb_field->name()] = pb_reflection->GetInt32(pb_msg, pb_field); break; case PbFd::TYPE_UINT32: json_value[pb_field->name()] = pb_reflection->GetUInt32(pb_msg, pb_field); break; case PbFd::TYPE_INT64: { static char int64_str[25]; memset(int64_str, 0, sizeof(int64_str)); snprintf(int64_str, sizeof(int64_str), "%lld", (long long)pb_reflection->GetInt64(pb_msg, pb_field)); json_value[pb_field->name()] = int64_str; } break; case PbFd::TYPE_UINT64: { static char uint64_str[25]; memset(uint64_str, 0, sizeof(uint64_str)); snprintf(uint64_str, sizeof(uint64_str), "%llu", (unsigned long long)pb_reflection->GetUInt64(pb_msg, pb_field)); json_value[pb_field->name()] = uint64_str; } break; case PbFd::TYPE_STRING: case PbFd::TYPE_BYTES: { json_value[pb_field->name()] = pb_reflection->GetString(pb_msg, pb_field); } break; case PbFd::TYPE_BOOL: { json_value[pb_field->name()] = pb_reflection->GetBool(pb_msg, pb_field); } break; case PbFd::TYPE_ENUM: { json_value[pb_field->name()] = pb_reflection->GetEnum(pb_msg, pb_field)->name(); } break; case PbFd::TYPE_FLOAT: { json_value[pb_field->name()] = pb_reflection->GetFloat(pb_msg, pb_field); } break; case PbFd::TYPE_DOUBLE: { json_value[pb_field->name()] = pb_reflection->GetDouble(pb_msg, pb_field); } break; default: break; } } } } // namespace summary } // namespace oneflow #endif // ONEFLOW_USER_SUMMARY_SUMMARY_CONVERTER_H_ ================================================ FILE: oneflow/user/utils/pool_util.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/user/utils/pool_util.h" #include "oneflow/core/operator/operator_util.h" namespace oneflow { Params3D::Params3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format, const std::string& padding, const std::vector& padding_before, const std::vector& padding_after, const std::vector& pool_size, const std::vector& strides, const bool ceil_mode) : dim_(dim), pool_size_3d_(Get3DVec(pool_size, dim)), strides_3d_(Get3DVec(strides, dim)), padding_before_3d_(Get3DVec(padding_before, dim)), padding_after_3d_(Get3DVec(padding_after, dim)), data_format_(data_format), padding_(padding), ceil_mode_(ceil_mode) { x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim), GetInDim(x_shape, data_format, 2, dim)}; Get3DOutputSize(x_3d_, pool_size_3d_, strides_3d_, padding_, ceil_mode_, nullptr, &y_3d_, &padding_before_3d_, &padding_after_3d_); if (data_format == "channels_first") { channel_num_ = x_shape.At(1); } else { CHECK_EQ(data_format_, "channels_last") << "data_format must be 'channels_first' or 'channels_last'"; channel_num_ = x_shape.At(x_shape.NumAxes() - 1); } batch_num_ = x_shape.At(0); } void Params3D::Reset(const ShapeView& x_shape) { x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_), GetInDim(x_shape, data_format_, 2, dim_)}; Get3DOutputSize(x_3d_, pool_size_3d_, strides_3d_, padding_, ceil_mode_, nullptr, &y_3d_, &padding_before_3d_, &padding_after_3d_); } Shape Params3D::GetYShape() const { DimVector y_dim_vec; if (dim_ == 1) { y_dim_vec = {y_3d_.at(2)}; } else if (dim_ == 2) { y_dim_vec = {y_3d_.at(1), y_3d_.at(2)}; } else if (dim_ == 3) { y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}; } else { UNIMPLEMENTED(); } if (data_format_ == "channels_first") { y_dim_vec.insert(y_dim_vec.begin(), channel_num_); } else { CHECK_EQ(data_format_, "channels_last") << "data_format must be 'channels_first' or 'channels_last'"; y_dim_vec.insert(y_dim_vec.end(), channel_num_); } y_dim_vec.insert(y_dim_vec.begin(), batch_num_); return Shape(y_dim_vec); } Shape Params3D::GetXShape5D() const { return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)}); } Shape Params3D::GetYShape5D() const { return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}); } } // namespace oneflow ================================================ FILE: oneflow/user/utils/pool_util.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #ifndef ONEFLOW_USER_UTILS_POOL_UTIL_H_ #define ONEFLOW_USER_UTILS_POOL_UTIL_H_ #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { typedef small_vector FixedDimVector; typedef small_vector FixedVector; class Params3D { public: Params3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format, const std::string& padding, const std::vector& padding_before, const std::vector& padding_after, const std::vector& pool_size, const std::vector& strides, const bool ceil_mode); ~Params3D() = default; void Reset(const ShapeView& x_shape); Shape GetYShape() const; Shape GetXShape5D() const; Shape GetYShape5D() const; const std::vector& pool_size_3d() const { return pool_size_3d_; } const std::vector& strides_3d() const { return strides_3d_; } const std::vector& padding_before_3d() const { return padding_before_3d_; } const std::vector& padding_after_3d() const { return padding_after_3d_; } private: int32_t dim_; FixedDimVector x_3d_; FixedDimVector y_3d_; std::vector pool_size_3d_; std::vector strides_3d_; std::vector padding_before_3d_; std::vector padding_after_3d_; std::string data_format_; std::string padding_; bool ceil_mode_; int64_t batch_num_; int64_t channel_num_; }; enum class Get3DVecType { kPad, kNonPad }; template std::vector Get3DVec(const std::vector& original_vec, int32_t NDims) { std::vector vec; FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - NDims); if (index < 0) { vec.emplace_back(static_cast(get_3d_vec_type)); // kPad -> 0, kNonPad -> 1 } else { vec.emplace_back(original_vec.at(index)); } } return vec; } } // namespace oneflow #endif // ONEFLOW_USER_UTILS_POOL_UTIL_H_ ================================================ FILE: python/.gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ /oneflow/include /oneflow/core /oneflow/compatible/single_client/core /oneflow/version.py lib.py *.ast.py unittest-log-* log output ================================================ FILE: python/oneflow/_C/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow._oneflow_internal._C import * import oneflow._C._nn as _nn import warnings def allclose(input, other, atol=1e-08, rtol=1e-05, equal_nan=False): return isclose(input, other, atol, rtol, equal_nan).all().item() def _log_api_usage_once(event): warnings.warn("_log_api_usage_once is not implemented in oneflow") ================================================ FILE: python/oneflow/_C/_nn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow import builtins from oneflow.framework.tensor import Tensor from typing import overload, Tuple, Any _device = flow.device _bool = builtins.bool _dtype = flow.dtype @overload def _parse_to( device: _device, dtype: _dtype, non_blocking: _bool, copy: _bool, *, memory_format: Any, ) -> Tuple[_device, _dtype, _bool, Any]: ... @overload def _parse_to( dtype: _dtype, non_blocking: _bool, copy: _bool, *, memory_format: Any ) -> Tuple[_device, _dtype, _bool, Any]: ... @overload def _parse_to( tensor: Tensor, non_blocking: _bool, copy: _bool, *, memory_format: Any ) -> Tuple[_device, _dtype, _bool, Any]: ... def _parse_to(*args, **kwargs): # TODO: implement _parse_to natively result = flow.tensor([]).to(*args, **kwargs) return (result.device, result.dtype, False, None) ================================================ FILE: python/oneflow/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import sys import collections import warnings # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-environment-variables if "CUDA_MODULE_LOADING" not in os.environ: os.environ["CUDA_MODULE_LOADING"] = "LAZY" import oneflow._oneflow_internal oneflow._oneflow_internal.RegisterSignalHandler() oneflow_python_base_dir = os.path.dirname(os.path.realpath(__file__)) oneflow._oneflow_internal.InitPythonPathsToBeKeptAndFilteredForDebugging( oneflow_python_base_dir ) oneflow._oneflow_internal.InitNumpyCAPI() oneflow._oneflow_internal.CheckAndClearRegistryFlag() Size = oneflow._oneflow_internal.Size device = oneflow._oneflow_internal.device placement = oneflow._oneflow_internal.placement locals()["dtype"] = oneflow._oneflow_internal.dtype locals()["bool"] = oneflow._oneflow_internal.bool locals()["float16"] = oneflow._oneflow_internal.float16 locals()["half"] = oneflow._oneflow_internal.float16 locals()["float32"] = oneflow._oneflow_internal.float32 locals()["float"] = oneflow._oneflow_internal.float locals()["double"] = oneflow._oneflow_internal.double locals()["float64"] = oneflow._oneflow_internal.float64 locals()["int8"] = oneflow._oneflow_internal.int8 locals()["int"] = oneflow._oneflow_internal.int32 locals()["int32"] = oneflow._oneflow_internal.int32 locals()["int64"] = oneflow._oneflow_internal.int64 locals()["long"] = oneflow._oneflow_internal.int64 locals()["uint8"] = oneflow._oneflow_internal.uint8 locals()["record"] = oneflow._oneflow_internal.record locals()["tensor_buffer"] = oneflow._oneflow_internal.tensor_buffer locals()["bfloat16"] = oneflow._oneflow_internal.bfloat16 locals()["char"] = oneflow._oneflow_internal.char locals()["short"] = oneflow._oneflow_internal.int16 locals()["int16"] = oneflow._oneflow_internal.int16 locals()["cfloat"] = oneflow._oneflow_internal.cfloat locals()["complex64"] = oneflow._oneflow_internal.complex64 locals()["cdouble"] = oneflow._oneflow_internal.cdouble locals()["complex128"] = oneflow._oneflow_internal.complex128 locals()["layout"] = oneflow._oneflow_internal.layout locals()["strided"] = oneflow._oneflow_internal.strided locals()["memory_format"] = oneflow._oneflow_internal.memory_format locals()["contiguous_format"] = oneflow._oneflow_internal.contiguous_format locals()["channels_last"] = oneflow._oneflow_internal.channels_last locals()["preserve_format"] = oneflow._oneflow_internal.preserve_format from oneflow.version import __version__ from oneflow.version import __git_commit__ _DEPRECATED = set() def oneflow_deprecate(*api_names, **kwargs): def Decorator(func_or_class): _DEPRECATED.add(func_or_class) return func_or_class return Decorator def is_deprecated(func_or_class): return ( isinstance(func_or_class, collections.abc.Hashable) and func_or_class in _DEPRECATED ) def use_deterministic_algorithms(mode, *, warn_only=False): # register a empty method warnings.warn("Oneflow temporarily does not support use_deterministic_algorithms.") from oneflow._C import abs from oneflow._C import exp from oneflow._C import exp2 from oneflow._C import acos from oneflow._C import acos as arccos from oneflow._C import acosh from oneflow._C import 
acosh as arccosh from oneflow._C import amin from oneflow._C import atanh from oneflow._C import atanh as arctanh from oneflow._C import batch_matmul as bmm from oneflow._C import baddbmm from oneflow._C import broadcast_like from oneflow._C import chunk from oneflow._C import digamma from oneflow._C import split from oneflow._C import sign from oneflow._C import sinh from oneflow._C import tan from oneflow._C import greater from oneflow._C import greater as gt from oneflow._C import greater_ as gt_ from oneflow._C import greater_equal from oneflow._C import greater_equal as ge from oneflow._C import log from oneflow._C import log2 from oneflow._C import log10 from oneflow._C import logical_and from oneflow._C import logical_or from oneflow._C import logical_xor from oneflow._C import logical_not from oneflow._C import logaddexp from oneflow._C import quantile from oneflow._C import gelu_with_approximate as gelu from oneflow._C import quick_gelu from oneflow._C import square_relu from oneflow._C import mish from oneflow._C import repeat from oneflow._C import repeat_interleave from oneflow._C import tile from oneflow._C import sigmoid from oneflow._C import tanh from oneflow._C import as_strided from oneflow._C import as_strided_ from oneflow._C import silu from oneflow._C import selu from oneflow._C import softshrink from oneflow._C import softsign from oneflow._C import cast from oneflow._C import diag from oneflow._C import log1p from oneflow._C import add from oneflow._C import addcdiv from oneflow._C import div, div_ from oneflow._C import addcmul from oneflow._C import floor, floor_ from oneflow._C import floor_divide from oneflow._C import frac, frac_ from oneflow._C import mul from oneflow._C import negative from oneflow._C import negative as neg from oneflow._C import reciprocal from oneflow._C import sub from oneflow._C import sin, sin_ from oneflow._C import asin from oneflow._C import asin as arcsin from oneflow._C import asinh from oneflow._C import asinh as arcsinh from oneflow._C import atan from oneflow._C import atan as arctan from oneflow._C import atan2 from oneflow._C import ceil, ceil_ from oneflow._C import clamp, clamp_, clamp_min, clamp_min_, clamp_max, clamp_max_ from oneflow._C import clip, clip_ from oneflow._C import cos from oneflow._C import cosh from oneflow._C import diagonal from oneflow._C import erf from oneflow._C import erfc from oneflow._C import expm1 from oneflow._C import fmod from oneflow._C import flatten from oneflow._C import topk from oneflow._C import in_top_k from oneflow._C import lgamma from oneflow._C import minimum from oneflow._C import maximum from oneflow._C import max from oneflow._C import min from oneflow._C import median from oneflow._C import mode from oneflow._C import pow from oneflow._C import reduce_prod as prod from oneflow._C import reduce_sum as sum from oneflow._C import reduce_mean as mean from oneflow._C import reduce_all as all from oneflow._C import reduce_any as any from oneflow._C import reduce_nansum as nansum from oneflow._C import logsumexp from oneflow._C import rsqrt from oneflow._C import sqrt from oneflow._C import square from oneflow._C import matmul from oneflow._C import mm from oneflow._C import matrix_vector_product as mv from oneflow._C import bernoulli from oneflow._C import round, round_ from oneflow._C import softplus from oneflow._C import threshold from oneflow._C import tril from oneflow._C import triu from oneflow._C import trunc from oneflow._C import pad from oneflow._C import transpose from 
oneflow._C import relu from oneflow._C import roc_auc_score from oneflow._C import softmax from oneflow._C import log_softmax from oneflow._C import argmax from oneflow._C import argmin from oneflow._C import std from oneflow._C import stft from oneflow._C import var from oneflow._C import stack, hstack, vstack, dstack, column_stack, row_stack from oneflow._C import atleast_1d, atleast_2d, atleast_3d from oneflow._C import squeeze from oneflow._C import narrow from oneflow._C import unsqueeze from oneflow._C import permute from oneflow._C import select from oneflow._C import unbind from oneflow._C import tensor_split from oneflow._C import hann_window from oneflow._C import hsplit from oneflow._C import vsplit from oneflow._C import concat from oneflow._C import concat as cat from oneflow._C import dim_gather as gather from oneflow._C import deform_conv2d from oneflow._C import gather_nd from oneflow._C import roi_align from oneflow._C import dot from oneflow._C import eye from oneflow._C import erfinv, erfinv_ from oneflow._C import cumsum from oneflow._C import contiguous from oneflow._C import cumprod from oneflow._C import swapaxes from oneflow._C import amax from oneflow._C import swapdims from oneflow._C import t from oneflow._C import masked_fill from oneflow._C import masked_fill_ from oneflow._C import equal from oneflow._C import broadcast_equal as eq from oneflow._C import not_equal from oneflow._C import not_equal as ne from oneflow._C import less as lt from oneflow._C import less_equal as le from oneflow._C import searchsorted from oneflow._C import flip from oneflow._C import index_select from oneflow._C import isnan from oneflow._C import isinf from oneflow._C import isfinite from oneflow._C import inv as inverse from oneflow._C import det from oneflow._C import iinfo, finfo from oneflow._C import multinomial from oneflow._C import linalg_cross as cross from oneflow._C import bincount from oneflow._C import isclose from oneflow._C import allclose from oneflow._C import lerp, lerp_ from oneflow._C import index_add, index_add_ from oneflow._C import sort from oneflow._C import clone from oneflow._C import bitwise_and, bitwise_or, bitwise_xor, bitwise_not from oneflow._C import real, imag, conj, conj_physical from oneflow._oneflow_internal import _set_num_threads as set_num_threads from . 
import sbp sbp.sbp.__call__ = lambda self: self import atexit import oneflow.framework.c_api_util import oneflow.framework.register_class_method_util as register_class_method_util register_class_method_util.RegisterMethod4Class() import oneflow.framework.env_util as env_util import oneflow.framework.scope_util as scope_util import oneflow.framework.session_context as session_ctx from oneflow.framework.tensor_str import set_printoptions _oneflow_global_unique_env = env_util.GetEnv() session_ctx.NewDefaultSession(_oneflow_global_unique_env) oneflow._oneflow_internal.RegisterGILForeignLockHelper() oneflow._oneflow_internal.autograd.graph.register_saved_tensors_hook_manager() oneflow._oneflow_internal.RegisterStackGetter() class ExitHook: def __init__(self): self.exit_code = None self.exception = None self._orig_exit = sys.exit self._orig_excepthook = sys.excepthook def exit(code=0): self.exit_code = code self._orig_exit(code) sys.exit = exit def exc_handler(exc_type, exc, *args): self.exception = exc self._orig_excepthook(exc_type, exc, *args) sys.excepthook = exc_handler def is_normal_exit(self): if self.exit_code is not None: return self.exit_code == 0 return self.exception is None hook = ExitHook() def atexit_hook(hook): _oneflow_global_unique_env.switch_to_shutting_down(hook.is_normal_exit()) oneflow.framework.session_context.TryCloseDefaultSession() atexit.register(atexit_hook, hook) del atexit_hook del hook del ExitHook del atexit del oneflow # default dtype from oneflow.framework.dtype import ( set_default_dtype, set_default_tensor_type, get_default_dtype, is_floating_point, ) import oneflow._C from oneflow._C import tensor, batch_gather from oneflow._C import from_numpy, from_dlpack from oneflow.autograd import ( enable_grad, set_grad_enabled, no_grad, inference_mode, is_grad_enabled, ) import oneflow.nn.image from oneflow.framework.check_point_v2 import load from oneflow.framework.check_point_v2 import save from oneflow.framework.check_point_v2 import frombuffer from oneflow.framework.dtype import convert_oneflow_dtype_to_numpy_dtype, dtypes from oneflow.framework.function_util import FunctionConfig from oneflow.framework.function_util import FunctionConfig as function_config from oneflow.framework.generator import create_generator as Generator from oneflow.framework.generator import ( default_generator, seed, manual_seed, initial_seed, get_rng_state, set_rng_state, ) # NOTE(chengcheng) oneflow.Model is unavailable now. 
# from oneflow.framework.model import Model import oneflow.utils.tensor import oneflow.utils.global_view import oneflow.utils.model_zoo from oneflow.framework.tensor import Tensor from oneflow.framework.tensor import is_nonzero from oneflow._oneflow_internal import to_dlpack from oneflow.framework.type_tensor import * from oneflow.framework.tensor import zero_ from oneflow.nn.modules.pooling import ( adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d, ) from oneflow.nn.modules.einsum import einsum_op as einsum from oneflow.nn.modules.is_tensor import is_tensor_op as is_tensor from oneflow.nn.modules.arange import arange_op as arange from oneflow.nn.modules.linspace import linspace_op as linspace from oneflow.nn.modules.logspace import logspace_op as logspace from oneflow.nn.modules.argsort import argsort_op as argsort from oneflow.nn.modules.argwhere import argwhere_op as argwhere from oneflow.nn.modules.constant import ones_op as ones from oneflow.nn.modules.constant import zeros_op as zeros from oneflow.nn.modules.constant import zeros_like_op as zeros_like from oneflow.nn.modules.constant import ones_like_op as ones_like from oneflow.nn.modules.constant import full_op as full from oneflow.nn.modules.constant import full_like_op as full_like from oneflow.nn.modules.constant import new_ones_op as new_ones from oneflow.nn.modules.constant import new_zeros_op as new_zeros from oneflow.nn.modules.constant import new_full_op as new_full from oneflow.nn.modules.empty import empty_op as empty from oneflow.nn.modules.empty import new_empty_op as new_empty from oneflow.nn.modules.empty import empty_like_op as empty_like from oneflow._C import empty_strided from oneflow.nn.modules.dataset import tensor_buffer_to_list_of_tensors from oneflow._C import movedim from oneflow.nn.modules.expand import expand_op as expand from oneflow.nn.modules.distributed_partial_fc_sample import ( distributed_partial_fc_sample_op as distributed_partial_fc_sample, ) from oneflow.nn.modules.roll import roll_op as roll from oneflow.nn.modules.masked_select import masked_select_op as masked_select from oneflow.nn.modules.math_ops import addmm_op as addmm from oneflow.nn.modules.nonzero import nonzero_op as nonzero from oneflow.nn.modules.nms import nms_op as nms from oneflow.nn.modules.numel import numel_op as numel from oneflow.nn.modules.meshgrid import meshgrid_op as meshgrid from oneflow.nn.modules.unique import unique_op as unique from oneflow._C import normal from oneflow._C import normal_ from oneflow._C import rand from oneflow._C import randn from oneflow._C import randn_like from oneflow._C import randint from oneflow._C import randint_like from oneflow._C import randperm from oneflow.nn.modules.reshape import reshape_op as reshape from oneflow.nn.modules.reshape import view_op as view from oneflow.nn.modules.slice import slice_op as slice from oneflow.nn.modules.slice import slice_update_op as slice_update from oneflow.nn.modules.tensor_buffer import gen_tensor_buffer from oneflow.nn.modules.tensor_buffer import ( tensor_buffer_to_tensor_op as tensor_buffer_to_tensor, ) from oneflow.nn.modules.tensordot import tensordot from oneflow.nn.modules.norm import norm from oneflow.nn.modules.as_tensor import as_tensor from oneflow.nn.modules.tensor_buffer import tensor_to_tensor_buffer from oneflow.nn.modules.global_cast import local_to_global_op as local_to_global from oneflow.nn.modules.global_cast import global_to_global_op as global_to_global from oneflow.nn.modules.global_cast import to_global_op as 
to_global from oneflow.nn.modules.global_cast import to_local_op as to_local from oneflow.nn.modules.where import where_op as where from oneflow.nn.modules.scatter import * from oneflow.nn.modules.broadcast_ops import ( broadcast_tensors, broadcast_shapes, broadcast_to, ) from oneflow.ops.stateful_ops import StatefulOp as stateful_op # autocast from oneflow._oneflow_internal import ( is_autocast_enabled, set_autocast_enabled, get_autocast_gpu_dtype, get_autocast_cpu_dtype, set_autocast_gpu_dtype, set_autocast_cpu_dtype, is_autocast_cache_enabled, set_autocast_cache_enabled, clear_autocast_cache, ) from oneflow.amp.autocast_mode import * from oneflow.jit import * from . import ( autograd, distributed, distributions, linalg, optim, comm, boxing, backends, amp, hub, fx, fft, special, ) import oneflow.utils.data import oneflow.framework.docstr as docstr import oneflow.cuda import oneflow.multiprocessing import oneflow.asyncs import oneflow.one_embedding import oneflow.profiler import oneflow.mock_torch import oneflow.remat if oneflow._oneflow_internal.flags.with_mlir(): oneflow_internal_path = oneflow._oneflow_internal.__file__ oneflow._oneflow_internal.ir.load_jit_shared_lib(oneflow_internal_path) ================================================ FILE: python/oneflow/__main__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import argparse import os parser = argparse.ArgumentParser() parser.add_argument("--doctor", default=False, action="store_true", required=False) args = parser.parse_args() def main(): if args.doctor: import oneflow import oneflow.sysconfig print("path:", oneflow.__path__) print("version:", oneflow.__version__) print("git_commit:", oneflow.__git_commit__) print("cmake_build_type:", oneflow.sysconfig.cmake_build_type()) print("rdma:", oneflow.sysconfig.with_rdma()) print("mlir:", oneflow.sysconfig.with_mlir()) if __name__ == "__main__": main() ================================================ FILE: python/oneflow/_dynamo/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import warnings # Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/torch/_dynamo/__init__.py __all__ = [ "allow_in_graph", ] def allow_in_graph(fn): """ """ if isinstance(fn, (list, tuple)): return [allow_in_graph(x) for x in fn] assert callable(fn), "allow_in_graph expects a callable" warnings.warn( "The oneflow._dynamo.allow_in_graph interface is just to align the torch._dynamo.allow_in_graph interface and has no practical significance." ) return fn ================================================ FILE: python/oneflow/_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import sys import traceback import oneflow as flow class KeyErrorMessage(str): r"""str subclass that returns itself in repr""" def __repr__(self): return self class ExceptionWrapper(object): r"""Wraps an exception plus traceback to communicate across threads""" def __init__(self, exc_info=None, where="in background"): # It is important that we don't store exc_info, see # NOTE [ Python Traceback Reference Cycle Problem ] if exc_info is None: exc_info = sys.exc_info() self.exc_type = exc_info[0] self.exc_msg = "".join(traceback.format_exception(*exc_info)) self.where = where def reraise(self): r"""Reraises the wrapped exception in the current thread""" # Format a message such as: "Caught ValueError in DataLoader worker # process 2. Original Traceback:", followed by the traceback. msg = "Caught {} {}.\nOriginal {}".format( self.exc_type.__name__, self.where, self.exc_msg ) if self.exc_type == KeyError: # KeyError calls repr() on its argument (usually a dict key). This # makes stack traces unreadable. It will not be changed in Python # (https://bugs.python.org/issue2651), so we work around it. msg = KeyErrorMessage(msg) elif getattr(self.exc_type, "message", None): # Some exceptions have first argument as non-str but explicitly # have message field raise self.exc_type(message=msg) raise self.exc_type(msg) def _flatten_dense_tensors(tensors): """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of same dense type. The api is referenced from https://github.com/pytorch/pytorch/blob/master/torch/_utils.py#L437 Since inputs are dense, the resulting tensor will be a concatenated 1D buffer. Element-wise operation on this buffer will be equivalent to operating individually. Args: tensors (Iterable[Tensor]): dense tensors to flatten. Returns: A contiguous 1D buffer containing input tensors. """ if len(tensors) == 1: return flow._C.flatten(tensors[0]) else: flatten_tensors = [] for tensor in tensors: flatten_tensors.append(flow.flatten(tensor)) return flow.cat(flatten_tensors, 0) def _unflatten_dense_tensors(flat, tensors): """View a flat buffer using the sizes of tensors. Assume that tensors are of same dense type, and that flat is given by _flatten_dense_tensors. The api is referenced from https://github.com/pytorch/pytorch/blob/master/torch/_utils.py#L474 Args: flat (Tensor): flattened dense tensors to unflatten. 
tensors (Iterable[Tensor]): dense tensors whose sizes will be used to unflatten flat. Returns: Unflattened dense tensors with sizes same as tensors and values from flat. """ outputs = [] offset = 0 for tensor in tensors: numel = tensor.numel() if numel == 0: outputs.append(flow.zeros_like(tensor)) else: outputs.append(flow.narrow(flat, 0, offset, numel).view(tensor.size())) offset += numel return outputs ================================================ FILE: python/oneflow/amp/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from .grad_scaler import GradScaler from .grad_scaler import StaticGradScaler from .autocast_mode import * ================================================ FILE: python/oneflow/amp/autocast_mode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import functools import warnings from typing import Any, Optional import oneflow as flow import oneflow._oneflow_internal.lazy_mode as lazy_mode __all__ = ["autocast_decorator", "autocast"] def autocast_decorator(autocast_instance, func): @functools.wraps(func) def decorate_autocast(*args, **kwargs): with autocast_instance: return func(*args, **kwargs) return decorate_autocast class autocast(object): r""" Note: The following doc was origined by pytorch, see https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py#L19-L179 Instances of :class:`autocast` serve as context managers or decorators that allow regions of your script to run in mixed precision. In these regions, ops run in an op-specific dtype chosen by autocast to improve performance while maintaining accuracy. When entering an autocast-enabled region, Tensors may be any type. You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting. :class:`autocast` should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops. Example for CUDA Devices:: # Creates model and optimizer in default precision model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) 
for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass (model + loss) with oneflow.autocast(device_type="cuda"): output = model(input) loss = loss_fn(output, target) # Exits the context manager before backward() loss.backward() optimizer.step() :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model:: class AutocastModel(nn.Module): ... @oneflow.autocast(device_type="cuda") def forward(self, input): ... Floating-point Tensors produced in an autocast-enabled region may be ``float16``. After returning to an autocast-disabled region, using them with floating-point Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s) produced in the autocast region back to ``float32`` (or other dtype if desired). If a Tensor from the autocast region is already ``float32``, the cast is a no-op, and incurs no additional overhead. CUDA Example:: # Creates some tensors in default dtype (here assumed to be float32) a_float32 = oneflow.rand((8, 8), device="cuda") b_float32 = oneflow.rand((8, 8), device="cuda") c_float32 = oneflow.rand((8, 8), device="cuda") d_float32 = oneflow.rand((8, 8), device="cuda") with oneflow.autocast(device_type="cuda"): # oneflow.mm is on autocast's list of ops that should run in float16. # Inputs are float32, but the op runs in float16 and produces float16 output. # No manual casts are required. e_float16 = oneflow.mm(a_float32, b_float32) # Also handles mixed input types f_float16 = oneflow.mm(d_float32, e_float16) # After exiting autocast, calls f_float16.float() to use with d_float32 g_float32 = oneflow.mm(d_float32, f_float16.float()) CPU Training Example:: # Creates model and optimizer in default precision model = Net() optimizer = optim.SGD(model.parameters(), ...) for epoch in epochs: for input, target in data: optimizer.zero_grad() # Runs the forward pass with autocasting. with oneflow.autocast(device_type="cpu", dtype=oneflow.bfloat16): output = model(input) loss = loss_fn(output, target) loss.backward() optimizer.step() CPU Inference Example:: # Creates model in default precision model = Net().eval() with oneflow.autocast(device_type="cpu", dtype=oneflow.bfloat16): for input in data: # Runs the forward pass with autocasting. output = model(input) The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator must be invoked in that thread. Args: device_type(str, required): Whether to use 'cuda' or 'cpu' device enabled(bool, optional): Whether autocasting should be enabled in the region. Default: ``True`` dtype(oneflow_dtype, optional): Whether to use oneflow.float16 or oneflow.bfloat16. cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled. Default: ``True`` """ def __init__( self, device_type: str, dtype: Optional[flow.dtype] = None, enabled: bool = True, cache_enabled: Optional[bool] = None, ): self.device = device_type if self.device == "cuda": self.fast_dtype = flow.get_autocast_gpu_dtype() elif self.device == "cpu": self.fast_dtype = flow.get_autocast_cpu_dtype() else: raise RuntimeError( "User specified autocast device_type must be 'cuda' or 'cpu'" ) self.cache_enabled = flow.is_autocast_cache_enabled() if dtype is not None: self.fast_dtype = dtype if cache_enabled is not None: self.cache_enabled = cache_enabled if self.device == "cpu": warnings.warn( "CPU autocast is not supported currently. Disabling autocast." 
) enabled = False if lazy_mode.is_enabled(): warnings.warn( "Autocast is not supported for lazy mode. Disabling autocast." ) enabled = False self.enabled = enabled def __enter__(self): self.autocast_mode = flow._oneflow_internal.AutoCastMode( self.device, self.fast_dtype, self.enabled, self.cache_enabled ) return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): del self.autocast_mode def __call__(self, func): return autocast_decorator(self, func) ================================================ FILE: python/oneflow/amp/grad_scaler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ class GradScaler(object): def __init__( self, init_scale=2.0 ** 16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, ): self._init_scale = init_scale self._growth_factor = growth_factor self._backoff_factor = backoff_factor if self._backoff_factor != 1.0 / self._growth_factor: raise ValueError( "Only support 1.0/growth_factor as backoff_factor at the moment, " "got {}".format(backoff_factor) ) self._growth_interval = growth_interval def _generate_conf_for_graph(self, train_conf): train_conf.dynamic_loss_scale_policy.initial_loss_scale = self._init_scale train_conf.dynamic_loss_scale_policy.increment_period = self._growth_interval train_conf.dynamic_loss_scale_policy.multiplier = self._growth_factor class StaticGradScaler(object): def __init__(self, scale_factor): if scale_factor <= 0.0: raise ValueError("StaticGradScaler's scale_factor must > 0.0") self._scale_factor = scale_factor def _generate_conf_for_graph(self, train_conf): train_conf.loss_scale_factor = self._scale_factor ================================================ FILE: python/oneflow/ao/quantization.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ class DeQuantStub: def __init__(self, *args, **kwargs): raise NotImplementedError( "The oneflow.ao.DeQuantStub interface is just to align the torch.ao.DeQuantStub interface and has no practical significance." ) class QuantStub: def __init__(self, *args, **kwargs): raise NotImplementedError( "The oneflow.ao.QuantStub interface is just to align the torch.ao.QuantStub interface and has no practical significance." ) ================================================ FILE: python/oneflow/asyncs/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from .thread import Thread, thread ================================================ FILE: python/oneflow/asyncs/thread.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal Thread = oneflow._oneflow_internal.AsyncThread class thread: r"""Context-manager to pick a worker thread. By default, all opkernels are executed/launched in worker thread 0. Within this context, opkernels can be executed/launched in the worker thread indicated by `thread_global_id`. This context manager is thread local; it will not affect ops in other threads. Also functions as a decorator. (Make sure to instantiate with parenthesis.) Args: worker_thread: a worker thread created with oneflow.asyncs.Thread. For example: .. code-block:: python >>> import oneflow as flow >>> with flow.asyncs.thread(flow.asyncs.Thread()): ... print(flow.ones(2, 2)) ... tensor([[1., 1.], [1., 1.]], dtype=oneflow.float32) """ def __init__(self, worker_thread: Thread): self.stream_set_ = oneflow._oneflow_internal.StreamSet(worker_thread) self.worker_thread_ = worker_thread def __enter__(self): self.guard_ = oneflow._oneflow_internal.StreamGuard(self.stream_set_) def __exit__(self, type, value, traceback): del self.guard_ ================================================ FILE: python/oneflow/autograd/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.autograd.autograd import backward, grad from oneflow.autograd.autograd_function import Function from oneflow.autograd.autograd_mode import ( set_grad_enabled, enable_grad, inference_mode, is_grad_enabled, no_grad, ) from oneflow.autograd.functional import vjp, jvp, jacobian, hessian, hvp, vhp from . import graph
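# Illustrative sketch (not part of the original file): the two common entry
# points re-exported above differ in where gradients end up.  Assuming a
# scalar loss computed from a leaf tensor ``x``:
#
#     import oneflow as flow
#
#     x = flow.ones(2, 3, requires_grad=True)
#     loss = (x * x).sum()
#
#     loss.backward()                      # accumulates d(loss)/dx into x.grad
#     (dx,) = flow.autograd.grad(loss, x)  # returns d(loss)/dx; x.grad untouched
#
# Running both calls on the same graph, as written, would additionally need
# ``retain_graph=True`` on the first call; the snippet only illustrates how the
# two APIs deliver their results.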
__all__ = [ "backward", "grad", "Function", "set_grad_enabled", "enable_grad", "inference_mode", "is_grad_enabled", "no_grad", "vjp", "jvp", "jacobian", "hessian", "hvp", "vhp", ] ================================================ FILE: python/oneflow/autograd/autograd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Sequence, Tuple, Union from oneflow._oneflow_internal import TensorTuple from oneflow._oneflow_internal.autograd import backward as backward_api from oneflow._oneflow_internal.autograd import grad as grad_api from oneflow.framework.tensor import Tensor from oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple def grad( outputs: Union[Tensor, Sequence[Tensor]], inputs: Union[Tensor, Sequence[Tensor]], grad_outputs: Union[Tensor, Sequence[Tensor], None] = None, retain_graph: bool = False, create_graph: bool = False, allow_unused: bool = False, is_grads_batched: bool = False, ) -> Tuple[Tensor]: r""" Computes and returns the sum of gradients of outputs with respect to the inputs. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.autograd.grad.html. The graph is differentiated using the chain rule. ``grad_outputs`` should be a sequence of length matching ``outputs``, containing the "vector" in the Jacobian-vector product. (``None`` is an acceptable value for tensors that don't require gradient.) Args: outputs (Sequence[Tensor]): Tensors of which the derivative will be computed. inputs (Sequence[Tensor]): Inputs w.r.t. which the derivative will be returned (and not accumulated into ``.grad``). grad_outputs (Sequence[Tensor], optional): The "vector" in the Jacobian-vector product. Usually gradients w.r.t. each output. None values can be specified for scalar Tensors or ones that don't require grad. Defaults to None. retain_graph (bool, optional): If ``False``, the graph used to compute the grads will be reset after backward is complete. Defaults to ``False``. Note that in nearly all cases setting this option to ``True`` is not needed and often can be worked around in a much more efficient way. create_graph (bool, optional): If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to ``False``. allow_unused (bool, optional): If ``False``, specifying inputs that were not used when computing outputs (and therefore their grad is always zero) is an error. Defaults to ``False``. is_grads_batched (bool, optional): If True, the first dimension of each tensor in grad_outputs will be interpreted as the batch dimension. Instead of computing a single vector-Jacobian product, we compute a batch of vector-Jacobian products for each “vector” in the batch. This should lead to performance improvements when compared to manually looping and performing backward multiple times. Defaults to ``False``.
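        For illustration, a minimal sketch of typical usage (shapes and values
        are indicative only; this example is not taken from the referenced
        upstream documentation):

        .. code-block:: python

            >>> import oneflow as flow
            >>> x = flow.ones(2, 3, requires_grad=True)
            >>> y = (x * x).sum()
            >>> (dx,) = flow.autograd.grad(y, x)
            >>> dx.shape
            oneflow.Size([2, 3])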
Returns: Tuple(Tensor): A tuple of tensors containing the gradients for each ``inputs``. """ in_grads = grad_api( convert_to_tensor_tuple(outputs), convert_to_tensor_tuple(inputs), convert_to_tensor_tuple(grad_outputs), retain_graph, create_graph, allow_unused, is_grads_batched, ) return tuple([x for x in in_grads]) def backward( tensors: Union[Tensor, Sequence[Tensor]], grad_tensors: Union[Tensor, Sequence[Tensor], None], retain_graph: bool = False, create_graph: bool = False, ) -> None: r""" Computes the sum of gradients of given tensors with respect to graph leaves. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.autograd.backward.html. The graph is differentiated using the chain rule. If any of ``tensors`` are non-scalar (i.e. their data has more than one element) and require gradient, then the Jacobian-vector product would be computed, in this case the function additionally requires specifying ``grad_tensors``. It should be a sequence of matching length, that contains the "vector" in the Jacobian-vector product, usually the gradient of the differentiated function w.r.t. corresponding tensors. (``None`` is an acceptable value for all tensors that don't need gradient.) This function accumulates gradients in the leaves - you might need to zero ``.grad`` attributes or set them to ``None`` before calling it. Note: Using this method with ``create_graph=True`` will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using ``autograd.grad`` when creating the graph to avoid this. If you have to use this function, make sure to reset the ``.grad`` fields of your parameters to ``None`` after use to break the cycle and avoid the leak. Args: tensors (Tensor or Sequence[Tensor]): Tensors of which the derivative will be computed. grad_tensors (Tensor or Sequence[Tensor], optional): The "vector" in the Jacobian-vector product, usually gradients each element of corresponding tensors. (None values can be specified for scalar Tensors or ones that don't require grad.) retain_graph (bool, optional): If ``False``, the graph used to compute the grads will be reset after backward is complete. Defaults to ``False``. Note that in nearly all cases setting this option to ``True`` is not needed and often can be worked around in a much more efficient way. Defaults to the value of ``create_graph``. create_graph (bool, optional): If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to ``False``. """ backward_api( convert_to_tensor_tuple(tensors), convert_to_tensor_tuple(grad_tensors), retain_graph, create_graph, ) ================================================ FILE: python/oneflow/autograd/autograd_function.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from oneflow._oneflow_internal import TensorTuple from oneflow._oneflow_internal.autograd import AutogradFunctionBase class Function(AutogradFunctionBase): r""" Function(self) Base class to create custom autograd.Function. To create a custom autograd.Function, subclass this class and implement the ``forward()`` and ``backward()`` static methods. Then, to use your custom op in the forward pass, call the class method ``apply()`` or ``__call__()``. Do not call ``forward()`` directly. For example: .. code-block:: python class Exp(Function): @staticmethod def forward(ctx, i): result = i.exp() ctx.save_for_backward(result) return result @staticmethod def backward(ctx, grad_output): result, = ctx.saved_tensors return grad_output * result # Use it by calling the apply method or __call__ method output = Exp.apply(input) # output = Exp()(input) """ def __init__(self): super().__init__() def __call__(self, *inputs): r""" See :meth:`self.apply`. """ return self.apply(*inputs) @classmethod def apply(cls, *inputs): r""" Calculate output tensors and build backward graph. """ return AutogradFunctionBase.apply( cls.__name__, cls.forward, cls.backward, *inputs ) @staticmethod def forward(ctx, *inputs): r""" Override this function for custom forward calculation. """ raise NotImplementedError( "You must implement the forward function for custom autograd.Function." ) @staticmethod def backward(ctx, *out_grads): r""" Override this function for custom backward calculation. """ raise NotImplementedError( "You must implement the backward function for custom autograd.Function." ) ================================================ FILE: python/oneflow/autograd/autograd_mode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal from oneflow._oneflow_internal.autograd import AutoGradMode def is_grad_enabled(): r""" Returns True if grad mode is currently enabled. """ return oneflow._oneflow_internal.autograd.is_grad_enabled() class inference_mode: r""" Context-manager that enables or disables inference mode InferenceMode is a new context manager analogous to no_grad to be used when you arecertain your operations will have no interactions with autograd (e.g., model training). Code run under this mode gets better performance by disabling view tracking and version counter bumps. This context manager is thread local; it will not affect computation in other threads. Also functions as a decorator. (Make sure to instantiate with parenthesis.) Args: mode (bool): Flag whether to enable or disable inference mode. (default: True) .. code-block:: python >>> import oneflow as flow >>> x = flow.ones(2, 3, requires_grad=True) >>> with flow.inference_mode(): ... y = x * x >>> y.requires_grad False >>> @flow.inference_mode() ... def no_grad_func(x): ... 
return x * x >>> y = no_grad_func(x) >>> y.requires_grad False """ def __init__(self, mode=True): self.infer_mode = mode def __call__(self, func): def wrapper(*args, **kwargs): with AutoGradMode(not self.infer_mode): return func(*args, **kwargs) return wrapper def __enter__(self): self.grad_mode = AutoGradMode(not self.infer_mode) return self def __exit__(self, exc_type, exc_val, exc_tb): pass class enable_grad: r""" Context-manager that enables gradient calculation. Enables gradient calculation, if it has been disabled via no_grad. This context manager is thread local; it will not affect computation in other threads. Also functions as a decorator. (Make sure to instantiate with parenthesis.) .. code-block:: python >>> import oneflow as flow >>> x = flow.ones(2, 3, requires_grad=True) >>> with flow.no_grad(): ... with flow.enable_grad(): ... y = x * x >>> y.requires_grad True >>> @flow.enable_grad() ... def no_grad_func(x): ... return x * x >>> with flow.no_grad(): ... y = no_grad_func(x) >>> y.requires_grad True """ def __call__(self, func): def wrapper(*args, **kwargs): with AutoGradMode(True): return func(*args, **kwargs) return wrapper def __enter__(self): self.grad_mode = AutoGradMode(True) return self def __exit__(self, exc_type, exc_val, exc_tb): pass class no_grad: r""" Context-manager that disables gradient calculation. Disabling gradient calculation is useful for inference, when you are sure that you will not call Tensor.backward(). It will reduce memory consumption for computations that would otherwise have requires_grad=True. In this mode, the result of every computation will have requires_grad=False, even when the inputs have requires_grad=True. This context manager is thread local; it will not affect computation in other threads. Also functions as a decorator. (Make sure to instantiate with parenthesis.) .. code-block:: python >>> import oneflow as flow >>> x = flow.ones(2, 3, requires_grad=True) >>> with flow.no_grad(): ... y = x * x >>> y.requires_grad False >>> @flow.no_grad() ... def no_grad_func(x): ... return x * x >>> y = no_grad_func(x) >>> y.requires_grad False """ def __call__(self, func): def wrapper(*args, **kwargs): with AutoGradMode(False): return func(*args, **kwargs) return wrapper def __enter__(self): self.grad_mode = AutoGradMode(False) return self def __exit__(self, exc_type, exc_val, exc_tb): pass class set_grad_enabled: r""" Context-manager that sets gradient calculation on or off. Gradient calculation is enabled or disabled according to its argument ``mode``, so it can be used to conditionally enable gradients. This context manager is thread local; it will not affect computation in other threads. Also functions as a decorator. (Make sure to instantiate with parenthesis.) Args: mode (bool): Flag whether to enable or disable gradient calculation. (default: True) .. code-block:: python >>> import oneflow as flow >>> x = flow.ones(2, 3, requires_grad=True) >>> with flow.set_grad_enabled(True): ... y = x * x >>> y.requires_grad True >>> @flow.set_grad_enabled(False) ... def no_grad_func(x): ...
return x * x >>> y = no_grad_func(x) >>> y.requires_grad False """ def __init__(self, is_train=True): self.is_train = is_train self.prev_mode = is_grad_enabled() oneflow._oneflow_internal.autograd.set_grad_enabled(is_train) def __call__(self, func): # recover grad mode set in __init__ oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode) def wrapper(*args, **kwargs): with AutoGradMode(self.is_train): return func(*args, **kwargs) return wrapper def __enter__(self): # recover grad mode set in __init__ oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode) self.grad_mode = AutoGradMode(self.is_train) return self def __exit__(self, exc_type, exc_val, exc_tb): pass if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/autograd/functional.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # This code is referenced from https://github.com/pytorch/pytorch/blob/master/torch/autograd/functional.py and consistent with oneflow. from typing import List, Tuple import oneflow as flow __all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"] # Utility functions def _as_tuple_nocheck(x): if isinstance(x, tuple): return x elif isinstance(x, list): return tuple(x) else: return (x,) def _as_tuple(inp, arg_name=None, fn_name=None): # Ensures that inp is a tuple of Tensors # Returns whether or not the original inp was a tuple and the tupled version of the input if arg_name is None and fn_name is None: return _as_tuple_nocheck(inp) is_inp_tuple = True if not isinstance(inp, tuple): inp = (inp,) is_inp_tuple = False for i, el in enumerate(inp): if not isinstance(el, flow.Tensor): if is_inp_tuple: raise TypeError( f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the" f" value at index {i} has type {type(el)}." ) else: raise TypeError( f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the" f" given {arg_name} has type {type(el)}." ) return is_inp_tuple, inp def _tuple_postprocess(res, to_unpack): # Unpacks a potentially nested tuple of Tensors # to_unpack should be a single boolean or a tuple of two booleans. 
# It is used to: # - invert _as_tuple when res should match the inp given to _as_tuple # - optionally remove nesting of two tuples created by multiple calls to _as_tuple if isinstance(to_unpack, tuple): assert len(to_unpack) == 2 if not to_unpack[1]: res = tuple(el[0] for el in res) if not to_unpack[0]: res = res[0] else: if not to_unpack: res = res[0] return res def _grad_preprocess(inputs, create_graph, need_graph): # Preprocess the inputs to make sure they require gradient # inputs is a tuple of Tensors to preprocess # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs # need_graph specifies if we internally want gradients to flow back to the Tensors in res # Note that we *always* create a new Tensor object to be able to see the difference between # inputs given as arguments and the same Tensors automatically captured by the user function. res = [] for inp in inputs: if create_graph and inp.requires_grad: # Create at least a new Tensor object in a differentiable way # oneflow.torch has no is_sparse attribute. https://github.com/Oneflow-Inc/oneflow/issues/10401 res.append(inp.view_as(inp)) else: res.append(inp.detach().requires_grad_(need_graph)) return tuple(res) def _grad_postprocess(inputs, create_graph): # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not # request it. if isinstance(inputs[0], flow.Tensor): if not create_graph: return tuple(inp.detach() for inp in inputs) else: return inputs else: return tuple(_grad_postprocess(inp, create_graph) for inp in inputs) def _validate_v(v, other, is_other_tuple): # This assumes that other is the correct shape, and v should match # Both are assumed to be tuples of Tensors if len(other) != len(v): if is_other_tuple: raise RuntimeError( f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}." ) else: raise RuntimeError("The given v should contain a single Tensor.") for idx, (el_v, el_other) in enumerate(zip(v, other)): if el_v.size() != el_other.size(): prepend = "" if is_other_tuple: prepend = f"Entry {idx} in " raise RuntimeError( f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}." ) def _check_requires_grad(inputs, input_type, strict): # Used to make all the necessary checks to raise nice errors in strict mode. if not strict: return if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]: raise RuntimeError("Invalid input_type to _check_requires_grad") for i, inp in enumerate(inputs): if inp is None: # This can only be reached for grad_inputs. raise RuntimeError( f"The output of the user-provided function is independent of input {i}." " This is not allowed in strict mode." ) if not inp.requires_grad: if input_type == "hessian": raise RuntimeError( f"The hessian of the user-provided function with respect to input {i}" " is independent of the input. This is not allowed in strict mode." " You should ensure that your function is thrice differentiable and that" " the hessian depends on the inputs." ) elif input_type == "jacobian": raise RuntimeError( "While computing the hessian, found that the jacobian of the user-provided" f" function with respect to input {i} is independent of the input. This is not" " allowed in strict mode. You should ensure that your function is twice" " differentiable and that the jacobian depends on the inputs (this would be" " violated by a linear function for example)." 
) elif input_type == "grad_inputs": raise RuntimeError( f"The gradient with respect to input {i} is independent of the inputs of the" " user-provided function. This is not allowed in strict mode." ) else: raise RuntimeError( f"Output {i} of the user-provided function does not require gradients." " The outputs must be computed in a differentiable manner from the input" " when running in strict mode." ) def _autograd_grad( outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None, ): # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them. # This has the extra constraint that inputs has to be a tuple assert isinstance(outputs, tuple) if grad_outputs is None: grad_outputs = (None,) * len(outputs) assert isinstance(grad_outputs, tuple) assert len(outputs) == len(grad_outputs) new_outputs: Tuple[flow.Tensor, ...] = tuple() new_grad_outputs: Tuple[flow.Tensor, ...] = tuple() for out, grad_out in zip(outputs, grad_outputs): if out is not None and out.requires_grad: new_outputs += (out,) new_grad_outputs += (grad_out,) if len(new_outputs) == 0: # No differentiable output, we don't need to call the autograd engine return (None,) * len(inputs) else: return flow.autograd.grad( new_outputs, inputs, new_grad_outputs, allow_unused=True, create_graph=create_graph, retain_graph=retain_graph, ) def _fill_in_zeros(grads, refs, strict, create_graph, stage): # Used to detect None in the grads and depending on the flags, either replace them # with Tensors full of 0s of the appropriate size based on the refs or raise an error. # strict and create graph allow us to detect when it is appropriate to raise an error # stage gives us information of which backward call we consider to give good error message if stage not in ["back", "back_trick", "double_back", "double_back_trick"]: raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros") res: Tuple[flow.Tensor, ...] = tuple() for i, grads_i in enumerate(grads): if grads_i is None: if strict: if stage == "back": raise RuntimeError( "The output of the user-provided function is independent of " f"input {i}. This is not allowed in strict mode." ) elif stage == "back_trick": raise RuntimeError( f"The gradient with respect to the input is independent of entry {i}" " in the grad_outputs when using the double backward trick to compute" " forward mode gradients. This is not allowed in strict mode." ) elif stage == "double_back": raise RuntimeError( "The jacobian of the user-provided function is independent of " f"input {i}. This is not allowed in strict mode." ) else: raise RuntimeError( "The hessian of the user-provided function is independent of " f"entry {i} in the grad_jacobian. This is not allowed in strict " "mode as it prevents from using the double backward trick to " "replace forward mode AD." ) grads_i = flow.zeros_like(refs[i]) else: if strict and create_graph and not grads_i.requires_grad: if "double" not in stage: raise RuntimeError( "The jacobian of the user-provided function is independent of " f"input {i}. This is not allowed in strict mode when create_graph=True." ) else: raise RuntimeError( "The hessian of the user-provided function is independent of " f"input {i}. This is not allowed in strict mode when create_graph=True." ) res += (grads_i,) return res # Public API def vjp(func, inputs, v=None, create_graph=False, strict=False): r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs. 
The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.vjp.html Args: func (function): a Python function that takes Tensor inputs and returns a tuple of Tensors or a Tensor. inputs (tuple of Tensors or Tensor): inputs to the function ``func``. v (tuple of Tensors or Tensor): The vector for which the vector Jacobian product is computed. Must be the same size as the output of ``func``. This argument is optional when the output of ``func`` contains a single element and (if it is not provided) will be set as a Tensor containing a single ``1``. create_graph (bool, optional): If ``True``, both the output and result will be computed in a differentiable way. Note that when ``strict`` is ``False``, the result can not require gradients or be disconnected from the inputs. Defaults to ``False``. strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the vjp for said inputs, which is the expected mathematical value. Defaults to ``False``. Returns: output (tuple): tuple with: func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` vjp (tuple of Tensors or Tensor): result of the dot product with the same shape as the inputs. Example: >>> def exp_reducer(x): ... return x.exp().sum(dim=1) >>> inputs = flow.rand(4, 4) >>> v = flow.ones(4) >>> vjp(exp_reducer, inputs, v) # doctest: +ELLIPSIS (tensor([5.7817, 7.2458, 5.7830, 6.7782]), tensor([[1.4458, 1.3962, 1.3042, 1.6354], [2.1288, 1.0652, 1.5483, 2.5035], [2.2046, 1.1292, 1.1432, 1.3059], [1.3225, 1.6652, 1.7753, 2.0152]])) >>> vjp(exp_reducer, inputs, v, create_graph=True) # doctest: +ELLIPSIS (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=), tensor([[1.4458, 1.3962, 1.3042, 1.6354], [2.1288, 1.0652, 1.5483, 2.5035], [2.2046, 1.1292, 1.1432, 1.3059], [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=)) >>> def adder(x, y): ... return 2 * x + 3 * y >>> inputs = (flow.rand(2), flow.rand(2)) >>> v = flow.ones(2) >>> vjp(adder, inputs, v) # doctest: +ELLIPSIS (tensor([2.4225, 2.3340]), (tensor([2., 2.]), tensor([3., 3.]))) """ with flow.enable_grad(): is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp") inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) outputs = func(*inputs) is_outputs_tuple, outputs = _as_tuple( outputs, "outputs of the user-provided function", "vjp" ) _check_requires_grad(outputs, "outputs", strict=strict) if v is not None: _, v = _as_tuple(v, "v", "vjp") v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) _validate_v(v, outputs, is_outputs_tuple) else: if len(outputs) != 1 or outputs[0].nelement() != 1: raise RuntimeError( "The vector v can only be None if the " "user-provided function returns " "a single Tensor with a single element." 
) enable_grad = True if create_graph else flow.is_grad_enabled() with flow.set_grad_enabled(enable_grad): grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph) vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back") # Cleanup objects and return them to the user outputs = _grad_postprocess(outputs, create_graph) vjp = _grad_postprocess(vjp, create_graph) return ( _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vjp, is_inputs_tuple), ) def jvp(func, inputs, v=None, create_graph=False, strict=False): r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html Args: func (function): a Python function that takes Tensor inputs and returns a tuple of Tensors or a Tensor. inputs (tuple of Tensors or Tensor): inputs to the function ``func``. v (tuple of Tensors or Tensor): The vector for which the Jacobian vector product is computed. Must be the same size as the input of ``func``. This argument is optional when the input to ``func`` contains a single element and (if it is not provided) will be set as a Tensor containing a single ``1``. create_graph (bool, optional): If ``True``, both the output and result will be computed in a differentiable way. Note that when ``strict`` is ``False``, the result can not require gradients or be disconnected from the inputs. Defaults to ``False``. strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the jvp for said inputs, which is the expected mathematical value. Defaults to ``False``. Returns: output (tuple): tuple with: func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` jvp (tuple of Tensors or Tensor): result of the dot product with the same shape as the output. Note: ``autograd.functional.jvp`` computes the jvp by using the backward of the backward (sometimes called the double backwards trick). This is not the most performant way of computing the jvp. Please consider using :func:`flow.func.jvp` instead. Example: >>> def exp_reducer(x): ... return x.exp().sum(dim=1) >>> inputs = flow.rand(4, 4) >>> v = flow.ones(4, 4) >>> jvp(exp_reducer, inputs, v) # doctest: +ELLIPSIS (tensor([6.3090, 4.6742, 7.9114, 8.2106]), tensor([6.3090, 4.6742, 7.9114, 8.2106])) >>> jvp(exp_reducer, inputs, v, create_graph=True) # doctest: +ELLIPSIS (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=), tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=)) >>> def adder(x, y): ... return 2 * x + 3 * y >>> inputs = (flow.rand(2), flow.rand(2)) >>> v = (flow.ones(2), flow.ones(2)) >>> jvp(adder, inputs, v) # doctest: +ELLIPSIS (tensor([2.2399, 2.5005]), tensor([5., 5.])) """ with flow.enable_grad(): is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp") inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) if v is not None: _, v = _as_tuple(v, "v", "jvp") v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) _validate_v(v, inputs, is_inputs_tuple) else: if len(inputs) != 1 or inputs[0].nelement() != 1: raise RuntimeError( "The vector v can only be None if the input to " "the user-provided function is a single Tensor " "with a single element." 
) outputs = func(*inputs) is_outputs_tuple, outputs = _as_tuple( outputs, "outputs of the user-provided function", "jvp" ) _check_requires_grad(outputs, "outputs", strict=strict) # The backward is linear so the value of grad_outputs is not important as # it won't appear in the double backward graph. We only need to ensure that # it does not contain inf or nan. grad_outputs = tuple( flow.zeros_like(out, requires_grad=True) for out in outputs ) grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True) _check_requires_grad(grad_inputs, "grad_inputs", strict=strict) if create_graph: with flow.enable_grad(): grad_res = _autograd_grad( grad_inputs, grad_outputs, v, create_graph=create_graph ) jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") else: grad_res = _autograd_grad( grad_inputs, grad_outputs, v, create_graph=create_graph ) jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") # Cleanup objects and return them to the user outputs = _grad_postprocess(outputs, create_graph) jvp = _grad_postprocess(jvp, create_graph) return ( _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(jvp, is_outputs_tuple), ) def _construct_standard_basis_for(tensors, tensor_numels: Tuple[int, ...]): # This function: # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix. # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`. # - Each chunk corresponds to one tensor. The chunk has the same dtype and # device as the tensor # # For example, with tensor_numels = [1, 2, 1], this function returns: # ( tensor([[1], tensor([[0, 0], tensor([[0], # [0], [1, 0], [0], # [0], [0, 1], [0], # [0]]) , [0, 0]]) , [1]]) ) # # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors) # Precondition: tensors always has at least one element. # # See NOTE: [Computing jacobian with vmap and grad for multiple tensors] # for context behind this function. All the pre-conditions are guarded for # in flow.autograd.functional.jacobian. assert len(tensors) == len(tensor_numels) assert len(tensors) > 0 total_numel = sum(tensor_numels) chunks = tuple( tensor.new_zeros(total_numel, tensor_numel) for tensor, tensor_numel in zip(tensors, tensor_numels) ) diag_start_idx = 0 for chunk, numel in zip(chunks, tensor_numels): # fill_ does not support NonContiguous.https://github.com/Oneflow-Inc/oneflow/issues/10394 # chunk.diagonal(diag_start_idx).fill_(1) for i in range(numel): chunk[diag_start_idx + i][i] = 1 diag_start_idx += numel return chunks def _jacfwd(func, inputs, strict=False, vectorize=False): if strict: raise RuntimeError( "flow.autograd.functional.jacobian: `strict=True` " 'and `strategy="forward-mode"` are not supported together (yet). ' "Please either set `strict=False` or " '`strategy="reverse-mode"`.' ) is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian") output_info = [] if vectorize: # Computing Jacobian does not support vectorize. see issue 10397. https://github.com/Oneflow-Inc/oneflow/issues/10397 raise NotImplementedError("Computing Jacobian does not support vectorize. ") else: raise NotImplementedError( "Computing Jacobian using forward-AD or forward-over-reverse Hessian is" "only implemented for `vectorize=True`." ) def jacobian( func, inputs, create_graph=False, strict=False, vectorize=False, strategy="reverse-mode", ): r"""Compute the Jacobian of a given function. 
The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.jacobian.html Args: func (function): a Python function that takes Tensor inputs and returns a tuple of Tensors or a Tensor. inputs (tuple of Tensors or Tensor): inputs to the function ``func``. create_graph (bool, optional): If ``True``, the Jacobian will be computed in a differentiable manner. Note that when ``strict`` is ``False``, the result can not require gradients or be disconnected from the inputs. Defaults to ``False``. strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the jacobian for said inputs, which is the expected mathematical value. Defaults to ``False``. vectorize (bool, optional): This feature is experimental. Please consider using :func:`flow.func.jacrev` or :func:`flow.func.jacfwd` instead if you are looking for something less experimental and more performant. When computing the jacobian, usually we invoke ``autograd.grad`` once per row of the jacobian. If this flag is ``True``, we perform only a single ``autograd.grad`` call with ``batched_grad=True`` which uses the vmap prototype feature. Though this should lead to performance improvements in many cases, because this feature is still experimental, there may be performance cliffs. See :func:`flow.autograd.grad`'s ``batched_grad`` parameter for more information. strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to determine whether the Jacobian will be computed with forward or reverse mode AD. Currently, ``"forward-mode"`` requires ``vectorized=True``. Defaults to ``"reverse-mode"``. If ``func`` has more outputs than inputs, ``"forward-mode"`` tends to be more performant. Otherwise, prefer to use ``"reverse-mode"``. Returns: Jacobian (Tensor or nested tuple of Tensors): if there is a single input and output, this will be a single Tensor containing the Jacobian for the linearized inputs and output. If one of the two is a tuple, then the Jacobian will be a tuple of Tensors. If both of them are tuples, then the Jacobian will be a tuple of tuple of Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the ``i``\th output and ``j``\th input and will have as size the concatenation of the sizes of the corresponding output and the corresponding input and will have same dtype and device as the corresponding input. If strategy is ``forward-mode``, the dtype will be that of the output; otherwise, the input. Example: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) >>> def exp_reducer(x): ... return x.exp().sum(dim=1) >>> inputs = flow.rand(2, 2) >>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> jacobian(exp_reducer, inputs) tensor([[[1.4917, 2.4352], [0.0000, 0.0000]], [[0.0000, 0.0000], [2.4369, 2.3799]]]) >>> jacobian(exp_reducer, inputs, create_graph=True) tensor([[[1.4917, 2.4352], [0.0000, 0.0000]], [[0.0000, 0.0000], [2.4369, 2.3799]]], grad_fn=) >>> def exp_adder(x, y): ... return 2 * x.exp() + 3 * y >>> inputs = (flow.rand(2), flow.rand(2)) >>> jacobian(exp_adder, inputs) (tensor([[2.8052, 0.0000], [0.0000, 3.3963]]), tensor([[3., 0.], [0., 3.]])) """ assert strategy in ("forward-mode", "reverse-mode"), ( 'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your ' 'function has more outputs than inputs, "forward-mode" tends to be more performant. ' 'Otherwise, prefer to use "reverse-mode".' 
) if strategy == "forward-mode": if create_graph: raise NotImplementedError( "flow.autograd.functional.jacobian: `create_graph=True` " 'and `strategy="forward-mode"` are not supported together (yet). ' "Please either set `create_graph=False` or " '`strategy="reverse-mode"`.' ) return _jacfwd(func, inputs, strict, vectorize) with flow.enable_grad(): is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian") inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) outputs = func(*inputs) is_outputs_tuple, outputs = _as_tuple( outputs, "outputs of the user-provided function", "jacobian" ) _check_requires_grad(outputs, "outputs", strict=strict) if vectorize: if strict: raise RuntimeError( "flow.autograd.functional.jacobian: `strict=True` " "and `vectorized=True` are not supported together. " "Please either set `strict=False` or " "`vectorize=False`." ) # NOTE: [Computing jacobian with vmap and grad for multiple outputs] # # Let's consider f(x) = (x**2, x.sum()) and let x = flow.randn(3). # It turns out we can compute the jacobian of this function with a single # call to autograd.grad by using vmap over the correct grad_outputs. # # Firstly, one way to compute the jacobian is to stack x**2 and x.sum() # into a 4D vector. E.g., use g(x) = flow.stack([x**2, x.sum()]) # # To get the first row of the jacobian, we call # >>> autograd.grad(g(x), x, grad_outputs=flow.tensor([1, 0, 0, 0])) # To get the 2nd row of the jacobian, we call # >>> autograd.grad(g(x), x, grad_outputs=flow.tensor([0, 1, 0, 0])) # and so on. # # Using vmap, we can vectorize all 4 of these computations into one by # passing the standard basis for R^4 as the grad_output. # vmap(partial(autograd.grad, g(x), x))(flow.eye(4)). # # Now, how do we compute the jacobian *without stacking the output*? # We can just split the standard basis across the outputs. So to # compute the jacobian of f(x), we'd use # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...)) # The grad_outputs looks like the following: # ( flow.tensor([[1, 0, 0], # [0, 1, 0], # [0, 0, 1], # [0, 0, 0]]), # flow.tensor([[0], # [0], # [0], # [1]]) ) # # But we're not done yet! # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...))) # returns a Tensor of shape [4, 3]. We have to remember to split the # jacobian of shape [4, 3] into two: # - one of shape [3, 3] for the first output # - one of shape [ 3] for the second output # Step 1: Construct grad_outputs by splitting the standard basis output_numels = tuple(output.numel() for output in outputs) grad_outputs = _construct_standard_basis_for(outputs, output_numels) flat_outputs = tuple(output.reshape(-1) for output in outputs) # Step 2: Call vmap + autograd.grad def vjp(grad_output): vj = list( _autograd_grad( flat_outputs, inputs, grad_output, create_graph=create_graph, ) ) for el_idx, vj_el in enumerate(vj): if vj_el is not None: continue vj[el_idx] = flow.zeros_like(inputs[el_idx]).expand( (sum(output_numels),) + inputs[el_idx].shape ) return tuple(vj) jacobians_of_flat_output = vjp(grad_outputs) # Step 3: The returned jacobian is one big tensor per input. In this step, # we split each Tensor by output. 
jacobian_input_output = [] for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs): jacobian_input_i_output = [] for jac, output_j in zip( jac_input_i.split(output_numels, dim=0), outputs ): jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape) jacobian_input_i_output.append(jacobian_input_i_output_j) jacobian_input_output.append(jacobian_input_i_output) # Step 4: Right now, `jacobian` is a List[List[Tensor]]. # The outer List corresponds to the number of inputs, # the inner List corresponds to the number of outputs. # We need to exchange the order of these and convert to tuples # before returning. jacobian_output_input = tuple(zip(*jacobian_input_output)) jacobian_output_input = _grad_postprocess( jacobian_output_input, create_graph ) return _tuple_postprocess( jacobian_output_input, (is_outputs_tuple, is_inputs_tuple) ) jacobian: Tuple[flow.Tensor, ...] = tuple() for i, out in enumerate(outputs): # mypy complains that expression and variable have different types due to the empty list jac_i: Tuple[List[flow.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore[assignment] for j in range(out.nelement()): vj = _autograd_grad( (out.reshape(-1)[j],), inputs, retain_graph=True, create_graph=create_graph, ) for el_idx, (jac_i_el, vj_el, inp_el) in enumerate( zip(jac_i, vj, inputs) ): if vj_el is not None: if strict and create_graph and not vj_el.requires_grad: msg = ( "The jacobian of the user-provided function is " f"independent of input {i}. This is not allowed in " "strict mode when create_graph=True." ) raise RuntimeError(msg) jac_i_el.append(vj_el) else: if strict: msg = ( f"Output {i} of the user-provided function is " f"independent of input {el_idx}. This is not allowed in " "strict mode." ) raise RuntimeError(msg) jac_i_el.append(flow.zeros_like(inp_el)) jacobian += ( tuple( flow.stack(jac_i_el, dim=0).view( out.size() + inputs[el_idx].size() # type: ignore[operator] ) for (el_idx, jac_i_el) in enumerate(jac_i) ), ) jacobian = _grad_postprocess(jacobian, create_graph) return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple)) def hessian( func, inputs, create_graph=False, strict=False, vectorize=False, outer_jacobian_strategy="reverse-mode", ): r"""Compute the Hessian of a given scalar function. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.hessian.html Args: func (function): a Python function that takes Tensor inputs and returns a Tensor with a single element. inputs (tuple of Tensors or Tensor): inputs to the function ``func``. create_graph (bool, optional): If ``True``, the Hessian will be computed in a differentiable manner. Note that when ``strict`` is ``False``, the result can not require gradients or be disconnected from the inputs. Defaults to ``False``. strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the hessian for said inputs, which is the expected mathematical value. Defaults to ``False``. vectorize (bool, optional): This feature is experimental. Please consider using :func:`flow.func.hessian` instead if you are looking for something less experimental and more performant. When computing the hessian, usually we invoke ``autograd.grad`` once per row of the hessian. 
If this flag is ``True``, we use the vmap prototype feature as the backend to vectorize calls to ``autograd.grad`` so we only invoke it once instead of once per row. This should lead to performance improvements in many use cases, however, due to this feature being incomplete, there may be performance cliffs. Please use `flow._C._debug_only_display_vmap_fallback_warnings(True)` to show any performance warnings and file us issues if warnings exist for your use case. Defaults to ``False``. outer_jacobian_strategy (str, optional): The Hessian is computed by computing the Jacobian of a Jacobian. The inner Jacobian is always computed in reverse-mode AD. Setting strategy to ``"forward-mode"`` or ``"reverse-mode"`` determines whether the outer Jacobian will be computed with forward or reverse mode AD. Currently, computing the outer Jacobian in ``"forward-mode"`` requires ``vectorized=True``. Defaults to ``"reverse-mode"``. Returns: Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input, this will be a single Tensor containing the Hessian for the input. If it is a tuple, then the Hessian will be a tuple of tuples where ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input and ``j``\th input with size the sum of the size of the ``i``\th input plus the size of the ``j``\th input. ``Hessian[i][j]`` will have the same dtype and device as the corresponding ``i``\th input. Example: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) >>> def pow_reducer(x): ... return x.pow(3).sum() >>> inputs = flow.rand(2, 2) >>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> hessian(pow_reducer, inputs) tensor([[[[5.2265, 0.0000], [0.0000, 0.0000]], [[0.0000, 4.8221], [0.0000, 0.0000]]], [[[0.0000, 0.0000], [1.9456, 0.0000]], [[0.0000, 0.0000], [0.0000, 3.2550]]]]) >>> hessian(pow_reducer, inputs, create_graph=True) tensor([[[[5.2265, 0.0000], [0.0000, 0.0000]], [[0.0000, 4.8221], [0.0000, 0.0000]]], [[[0.0000, 0.0000], [1.9456, 0.0000]], [[0.0000, 0.0000], [0.0000, 3.2550]]]], grad_fn=) >>> def pow_adder_reducer(x, y): ... return (2 * x.pow(2) + 3 * y.pow(2)).sum() >>> inputs = (flow.rand(2), flow.rand(2)) >>> hessian(pow_adder_reducer, inputs) ((tensor([[4., 0.], [0., 4.]]), tensor([[0., 0.], [0., 0.]])), (tensor([[0., 0.], [0., 0.]]), tensor([[6., 0.], [0., 6.]]))) """ is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian") assert outer_jacobian_strategy in ( "forward-mode", "reverse-mode", ), 'Expected strategy to be either "forward-mode" or "reverse-mode".' 
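    # The Hessian is computed as a Jacobian of a Jacobian: `jac_func` below wraps an
    # inner reverse-mode `jacobian` call with create_graph=True so that its result stays
    # differentiable, and the outer `jacobian(jac_func, ...)` call at the end of this
    # function differentiates it again using `outer_jacobian_strategy`.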
def ensure_single_output_function(*inp): out = func(*inp) is_out_tuple, t_out = _as_tuple( out, "outputs of the user-provided function", "hessian" ) _check_requires_grad(t_out, "outputs", strict=strict) if is_out_tuple or not isinstance(out, flow.Tensor): raise RuntimeError( "The function given to hessian should return a single Tensor" ) if out.nelement() != 1: raise RuntimeError( "The Tensor returned by the function given to hessian should contain a single element" ) return out.squeeze() def jac_func(*inp): if outer_jacobian_strategy == "forward-mode": # _grad_preprocess requires create_graph=True and input to require_grad # or else the input will be detached inp = tuple(t.requires_grad_(True) for t in inp) jac = jacobian(ensure_single_output_function, inp, create_graph=True) _check_requires_grad(jac, "jacobian", strict=strict) return jac res = jacobian( jac_func, inputs, create_graph=create_graph, strict=strict, vectorize=vectorize, strategy=outer_jacobian_strategy, ) return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple)) def vhp(func, inputs, v=None, create_graph=False, strict=False): r"""Compute the dot product between vector ``v`` and Hessian of a given scalar function at a specified point. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.vhp.html Args: func (function): a Python function that takes Tensor inputs and returns a Tensor with a single element. inputs (tuple of Tensors or Tensor): inputs to the function ``func``. v (tuple of Tensors or Tensor): The vector for which the vector Hessian product is computed. Must be the same size as the input of ``func``. This argument is optional when ``func``'s input contains a single element and (if it is not provided) will be set as a Tensor containing a single ``1``. create_graph (bool, optional): If ``True``, both the output and result will be computed in a differentiable way. Note that when ``strict`` is ``False``, the result can not require gradients or be disconnected from the inputs. Defaults to ``False``. strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the vhp for said inputs, which is the expected mathematical value. Defaults to ``False``. Returns: output (tuple): tuple with: func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` vhp (tuple of Tensors or Tensor): result of the dot product with the same shape as the inputs. Example: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) >>> def pow_reducer(x): ... return x.pow(3).sum() >>> inputs = flow.rand(2, 2) >>> v = flow.ones(2, 2) >>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> vhp(pow_reducer, inputs, v) (tensor(0.5591), tensor([[1.0689, 1.2431], [3.0989, 4.4456]])) >>> vhp(pow_reducer, inputs, v, create_graph=True) (tensor(0.5591, grad_fn=), tensor([[1.0689, 1.2431], [3.0989, 4.4456]], grad_fn=)) >>> def pow_adder_reducer(x, y): ... 
return (2 * x.pow(2) + 3 * y.pow(2)).sum() >>> inputs = (flow.rand(2), flow.rand(2)) >>> v = (flow.zeros(2), flow.ones(2)) >>> vhp(pow_adder_reducer, inputs, v) (tensor(4.8053), (tensor([0., 0.]), tensor([6., 6.]))) """ with flow.enable_grad(): is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp") inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) if v is not None: _, v = _as_tuple(v, "v", "vhp") v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) _validate_v(v, inputs, is_inputs_tuple) else: if len(inputs) != 1 or inputs[0].nelement() != 1: raise RuntimeError( "The vector v can only be None if the input to the user-provided function " "is a single Tensor with a single element." ) outputs = func(*inputs) is_outputs_tuple, outputs = _as_tuple( outputs, "outputs of the user-provided function", "vhp" ) _check_requires_grad(outputs, "outputs", strict=strict) if is_outputs_tuple or not isinstance(outputs[0], flow.Tensor): raise RuntimeError( "The function given to vhp should return a single Tensor" ) if outputs[0].nelement() != 1: raise RuntimeError( "The Tensor returned by the function given to vhp should contain a single element" ) jac = _autograd_grad(outputs, inputs, create_graph=True) _check_requires_grad(jac, "jacobian", strict=strict) enable_grad = True if create_graph else flow.is_grad_enabled() with flow.set_grad_enabled(enable_grad): grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph) vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back") outputs = _grad_postprocess(outputs, create_graph) vhp = _grad_postprocess(vhp, create_graph) return ( _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vhp, is_inputs_tuple), ) def hvp(func, inputs, v=None, create_graph=False, strict=False): r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.hvp.html Args: func (function): a Python function that takes Tensor inputs and returns a Tensor with a single element. inputs (tuple of Tensors or Tensor): inputs to the function ``func``. v (tuple of Tensors or Tensor): The vector for which the Hessian vector product is computed. Must be the same size as the input of ``func``. This argument is optional when ``func``'s input contains a single element and (if it is not provided) will be set as a Tensor containing a single ``1``. create_graph (bool, optional): If ``True``, both the output and result will be computed in a differentiable way. Note that when ``strict`` is ``False``, the result can not require gradients or be disconnected from the inputs. Defaults to ``False``. strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the hvp for said inputs, which is the expected mathematical value. Defaults to ``False``. Returns: output (tuple): tuple with: func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` hvp (tuple of Tensors or Tensor): result of the dot product with the same shape as the inputs. Example: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) >>> def pow_reducer(x): ... 
return x.pow(3).sum() >>> inputs = flow.rand(2, 2) >>> v = flow.ones(2, 2) >>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> hvp(pow_reducer, inputs, v) (tensor(0.1448), tensor([[2.0239, 1.6456], [2.4988, 1.4310]])) >>> hvp(pow_reducer, inputs, v, create_graph=True) (tensor(0.1448, grad_fn=), tensor([[2.0239, 1.6456], [2.4988, 1.4310]], grad_fn=)) >>> def pow_adder_reducer(x, y): ... return (2 * x.pow(2) + 3 * y.pow(2)).sum() >>> inputs = (flow.rand(2), flow.rand(2)) >>> v = (flow.zeros(2), flow.ones(2)) >>> hvp(pow_adder_reducer, inputs, v) (tensor(2.3030), (tensor([0., 0.]), tensor([6., 6.]))) Note: This function is significantly slower than `vhp` due to backward mode AD constraints. If your functions is twice continuously differentiable, then hvp = vhp.t(). So if you know that your function satisfies this condition, you should use vhp instead that is much faster with the current implementation. """ with flow.enable_grad(): is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp") inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) if v is not None: _, v = _as_tuple(v, "v", "hvp") v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) _validate_v(v, inputs, is_inputs_tuple) else: if len(inputs) != 1 or inputs[0].nelement() != 1: raise RuntimeError( "The vector v can only be None if the input to the user-provided function " "is a single Tensor with a single element." ) outputs = func(*inputs) is_outputs_tuple, outputs = _as_tuple( outputs, "outputs of the user-provided function", "hvp" ) _check_requires_grad(outputs, "outputs", strict=strict) if is_outputs_tuple or not isinstance(outputs[0], flow.Tensor): raise RuntimeError( "The function given to hvp should return a single Tensor" ) if outputs[0].nelement() != 1: raise RuntimeError( "The Tensor returned by the function given to hvp should contain a single element" ) jac = _autograd_grad(outputs, inputs, create_graph=True) _check_requires_grad(jac, "jacobian", strict=strict) grad_jac = tuple(flow.zeros_like(inp, requires_grad=True) for inp in inputs) double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True) _check_requires_grad(jac, "hessian", strict=strict) enable_grad = True if create_graph else flow.is_grad_enabled() with flow.set_grad_enabled(enable_grad): grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph) hvp = _fill_in_zeros( grad_res, inputs, strict, create_graph, "double_back_trick" ) outputs = _grad_postprocess(outputs, create_graph) hvp = _grad_postprocess(hvp, create_graph) return ( _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(hvp, is_inputs_tuple), ) ================================================ FILE: python/oneflow/autograd/graph.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # This file is mostly copied from PyTorch import oneflow as flow from typing import Callable, Any class saved_tensors_hooks: """Context-manager that sets a pair of pack / unpack hooks for saved tensors. Use this context-manager to define how intermediary results of an operation should be packed before saving, and unpacked on retrieval. In that context, the ``pack_hook`` function will be called everytime an operation saves a tensor for backward (this includes intermediary results saved using :func:`~oneflow.autograd.function.save_for_backward` but also those recorded by a OneFlow-defined operation). The output of ``pack_hook`` is then stored in the computation graph instead of the original tensor. The ``unpack_hook`` is called when the saved tensor needs to be accessed, namely when executing :func:`oneflow.Tensor.backward()` or :func:`oneflow.autograd.grad()`. It takes as argument the *packed* object returned by ``pack_hook`` and should return a tensor which has the same content as the original tensor (passed as input to the corresponding ``pack_hook``). The hooks should have the following signatures: pack_hook(tensor: Tensor) -> Any unpack_hook(Any) -> Tensor where the return value of ``pack_hook`` is a valid input to ``unpack_hook``. In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms of value, size, dtype and device. Example:: >>> def pack_hook(x): ... print("Packing", x) ... return x >>> >>> def unpack_hook(x): ... print("Unpacking", x) ... return x >>> >>> a = flow.ones(5, requires_grad=True) >>> b = flow.ones(5, requires_grad=True) * 2 >>> with flow.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): ... y = a * b Packing tensor([1., 1., 1., 1., 1.]) Packing tensor([2., 2., 2., 2., 2.]) >>> y.sum().backward() Unpacking tensor([1., 1., 1., 1., 1.]) Unpacking tensor([2., 2., 2., 2., 2.]) .. warning :: Performing an inplace operation on the input to either hooks may lead to undefined behavior. .. warning :: Only one pair of hooks is allowed at a time. When recursively nesting this context-manager, only the inner-most pair of hooks will be applied. """ def __init__( self, pack_hook: Callable[["flow.Tensor"], Any], unpack_hook: Callable[[Any], "flow.Tensor"], ): self.pack_hook = pack_hook self.unpack_hook = unpack_hook def __enter__(self): flow._oneflow_internal.autograd.graph.append_new_hooks( self.pack_hook, self.unpack_hook ) def __exit__(self, *args: Any): flow._oneflow_internal.autograd.graph.pop_hooks() ================================================ FILE: python/oneflow/autograd/profiler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from oneflow.profiler.profiler import profile def record_function(): raise NotImplementedError() class emit_nvtx: def __init__(self): raise NotImplementedError() ================================================ FILE: python/oneflow/autoprof/__init__.py ================================================ ================================================ FILE: python/oneflow/autoprof/__main__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import atexit import csv import unittest import os import sys import subprocess import tempfile import oneflow as flow import oneflow.test_utils.automated_test_util.profiler as auto_profiler from oneflow.autoprof.util import * csv_filename = os.getenv("ONEFLOW_PROFILE_CSV", "op_prof") if csv_filename[-4:] != ".csv": csv_filename += ".csv" f = open(csv_filename, "w") # all functions registered are called in last in, first out order if flow.support.env_var_util.parse_boolean_from_env( "ONEFLOW_PROFILE_PRINT_SUMMARY", True ): atexit.register(print_summary_from_csv, csv_filename) atexit.register(lambda f: f.close(), f) writer = csv.writer(f) ONLY_ONEFLOW = flow.support.env_var_util.parse_boolean_from_env( "ONEFLOW_PROFILE_ONLY_ONEFLOW", False ) ONLY_PYTORCH = flow.support.env_var_util.parse_boolean_from_env( "ONEFLOW_PROFILE_ONLY_PYTORCH", False ) assert not (ONLY_ONEFLOW and ONLY_PYTORCH) if not ONLY_ONEFLOW and not ONLY_PYTORCH: env = os.environ.copy() env.update({"ONEFLOW_PROFILE_ONLY_ONEFLOW": "1"}) temp_f = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) env.update({"ONEFLOW_PROFILE_CSV": temp_f.name}) env.update({"ONEFLOW_PROFILE_PRINT_SUMMARY": "0"}) subprocess.run([sys.executable, "-m", "oneflow.autoprof", *sys.argv[1:]], env=env) temp_f.close() temp_f = open(temp_f.name, "r") rows = list(csv.reader(temp_f)) temp_f.close() os.remove(temp_f.name) env = os.environ.copy() env.update({"ONEFLOW_PROFILE_ONLY_PYTORCH": "1"}) temp_f = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) env.update({"ONEFLOW_PROFILE_CSV": temp_f.name}) env.update({"ONEFLOW_PROFILE_PRINT_SUMMARY": "0"}) subprocess.run([sys.executable, "-m", "oneflow.autoprof", *sys.argv[1:]], env=env) temp_f.close() temp_f = open(temp_f.name, "r") rows.extend(list(csv.reader(temp_f))[1:]) temp_f.close() os.remove(temp_f.name) writer.writerows(rows) exit(0) writer.writerow( [ "OP", "Args", "Library", "Kernel Time (us, GPU)", "Kernel Bandwidth (GB/s, GPU)", "Kernel Time (us, 1 CPU)", "End-to-end Time (us, 1 CPU)", "Kernel Time (us, 32 CPUs)", "End-to-end Time (us, 32 CPUs)", "Description", ] ) auto_profiler.set_hardware_info_list([("cuda", None), ("cpu", 1), ("cpu", 32)]) if ONLY_ONEFLOW: auto_profiler.profiled_framework = ["oneflow"] if ONLY_PYTORCH: auto_profiler.profiled_framework = ["pytorch"] auto_profiler.set_profiler_hook(lambda profs: add_row(profs, writer, f)) # Align with https://github.com/python/cpython/blob/3.10/Lib/unittest/__main__.py __unittest = True from unittest.main 
import main loader = unittest.TestLoader() loader.testMethodPrefix = "profile_" main(module=None, testLoader=loader) ================================================ FILE: python/oneflow/autoprof/util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Iterable, Union, TypeVar from rich import box from rich.console import Console from rich.table import Table import csv import oneflow.test_utils.automated_test_util.profiler as auto_profiler T = TypeVar("T") def get_sole_value(x: Iterable[T]) -> T: s = set(x) assert len(s) == 1 return list(s)[0] def get_pytorch_cpu_kernel_time(prof) -> Union[str, float]: assert prof.num > 1 cpu_kernel_items = list(filter(lambda x: x.count >= prof.num, prof.key_averages())) if len(cpu_kernel_items) == 0: return "-" kernel_cpu_time = ( sum(map(lambda x: x.self_cpu_time_total, cpu_kernel_items)) / prof.num ) return round(kernel_cpu_time, 1) def get_oneflow_cpu_kernel_time(prof) -> Union[str, float]: assert prof.num > 1 cpu_kernel_items = list(filter(lambda x: x.count >= prof.num, prof.key_averages())) if len(cpu_kernel_items) == 0: return "-" kernel_cpu_time = sum(map(lambda x: x.cpu_time_total, cpu_kernel_items)) / prof.num return round(kernel_cpu_time, 1) def get_pytorch_gpu_kernel_time(prof) -> Union[str, float]: gpu_kernel_items = list(filter(lambda x: x.count >= prof.num, prof.key_averages())) if len(gpu_kernel_items) == 0: return "-" kernel_gpu_time = ( sum(map(lambda x: x.self_cuda_time_total, gpu_kernel_items)) / prof.num ) return round(kernel_gpu_time, 1) def get_oneflow_gpu_kernel_time(prof) -> Union[str, float]: gpu_kernel_items = list( filter(lambda x: x.cuda_time_total is not None, prof.key_averages()) ) if len(gpu_kernel_items) == 0: return "-" kernel_gpu_time = sum(map(lambda x: x.cuda_time_total, gpu_kernel_items)) / prof.num return round(kernel_gpu_time, 1) def get_oneflow_gpu_kernel_bandwidth(prof) -> str: gpu_kernel_items = list( filter(lambda x: x.cuda_time_total is not None, prof.key_averages()) ) if len(gpu_kernel_items) == 0: return "-" if len(gpu_kernel_items) == 1: return gpu_kernel_items[0].bandwidth return ", ".join([f"{x.name}: {x.bandwidth}" for x in gpu_kernel_items]) def get_pytorch_cpu_end_to_end_time(prof) -> float: total = get_sole_value( filter(lambda x: x.key == auto_profiler.END_TO_END, prof.key_averages()) ) assert total.count == 1 return round(total.cpu_time / prof.num, 1) def get_oneflow_cpu_end_to_end_time(prof) -> float: total = list( filter(lambda x: x.name == auto_profiler.END_TO_END, prof.key_averages()) )[0] assert total.count == 1 return round(total.cpu_time / prof.num, 1) def add_row(profs, writer, f): non_none_profs = list(filter(lambda x: x is not None, profs)) op_name = get_sole_value([prof.op_name for prof in non_none_profs]) args_description = get_sole_value( [prof.args_description for prof in non_none_profs] ) additional_description = get_sole_value( [prof.additional_description for prof in non_none_profs] ) if 
"oneflow" in auto_profiler.profiled_framework: writer.writerow( [ op_name, args_description, "OneFlow", get_oneflow_gpu_kernel_time(profs[0]), get_oneflow_gpu_kernel_bandwidth(profs[0]), get_oneflow_cpu_kernel_time(profs[1]), get_oneflow_cpu_end_to_end_time(profs[1]), get_oneflow_cpu_kernel_time(profs[2]), get_oneflow_cpu_end_to_end_time(profs[2]), additional_description, ] ) if "pytorch" in auto_profiler.profiled_framework: writer.writerow( [ op_name, args_description, "PyTorch", get_pytorch_gpu_kernel_time(profs[3]), "-", get_pytorch_cpu_kernel_time(profs[4]), get_pytorch_cpu_end_to_end_time(profs[4]), get_pytorch_cpu_kernel_time(profs[5]), get_pytorch_cpu_end_to_end_time(profs[5]), additional_description, ] ) f.flush() def print_summary_from_csv(filename) -> None: print("----------------------------------------------------------------------") print( 'Summary ("KT" means "Kernel Time", "ET" means "End-to-end Time", in microseconds; "BW" means "Bandwidth" in GB/s):' ) with open(filename, "r") as f: table = Table( "OP", "Args", "Lib", "KT(GPU)", "BW(GPU)", "KT(1 CPU)", "ET(1 CPU)", "KT(32 CPU)", "ET(32 CPU)", box=box.SIMPLE, ) for row in list(csv.reader(f))[1:]: row[2] = {"PyTorch": "PT", "OneFlow": "OF"}[row[2]] table.add_row(*row[:-1]) Console().print(table) ================================================ FILE: python/oneflow/backends/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from . import cuda from . import cudnn from . import mps ================================================ FILE: python/oneflow/backends/cuda/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow._oneflow_internal class cuMatmulMode: def __getattr__(self, name): if name == "allow_tf32": return oneflow._oneflow_internal.ep.is_matmul_allow_tf32() elif name == "allow_fp16_reduced_precision_reduction": return ( oneflow._oneflow_internal.ep.is_matmul_allow_fp16_reduced_precision_reduction() ) raise AssertionError("Unknown attribute " + name) def __setattr__(self, name, value): if name == "allow_tf32": return oneflow._oneflow_internal.ep.set_matmul_allow_tf32(value) elif name == "allow_fp16_reduced_precision_reduction": return oneflow._oneflow_internal.ep.set_matmul_allow_fp16_reduced_precision_reduction( value ) raise AssertionError("Unknown attribute " + name) matmul = cuMatmulMode() ================================================ FILE: python/oneflow/backends/cudnn/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.config_util import ( api_reserved_device_mem_mbyte as set_reserved_mem_mbytes, ) from oneflow.framework.config_util import ( api_enable_cudnn_fused_normalization_add_relu as enable_fused_normalization_add_relu, ) from oneflow.framework.config_util import ( api_enable_cudnn_conv_heuristic_search_algo as enable_conv_heuristic_search_algo, ) benchmark = False ================================================ FILE: python/oneflow/backends/mps/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ __all__ = ["is_available"] def is_available() -> bool: return False ================================================ FILE: python/oneflow/boxing/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.config_util import api_enable_fusion as enable_fusion from . 
import nccl ================================================ FILE: python/oneflow/boxing/nccl/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.config_util import ( api_nccl_fusion_threshold_mb as set_fusion_threshold_mbytes, api_nccl_fusion_max_ops as set_fusion_max_ops_num, api_nccl_fusion_all_reduce as allow_fuse_all_reduce, api_nccl_fusion_reduce_scatter as allow_fuse_reduce_scatter, api_nccl_fusion_all_gather as allow_fuse_all_gather, api_nccl_fusion_reduce as allow_fuse_reduce, api_nccl_fusion_broadcast as allow_fuse_broadcast, api_nccl_enable_mixed_fusion as allow_fuse_mixed_ops, api_nccl_fusion_all_reduce_use_buffer as enable_use_buffer_to_fuse_all_reduce, ) from oneflow.framework.config_util import api_nccl_num_streams as set_stream_num from oneflow.framework.config_util import ( api_nccl_enable_all_to_all as enable_all_to_all, ) from oneflow.framework.config_util import ( api_nccl_use_compute_stream as enable_use_compute_stream, ) from oneflow.framework.config_util import ( api_disable_group_boxing_by_dst_parallel as disable_group_boxing_by_dst_parallel, ) ================================================ FILE: python/oneflow/comm/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.comm.comm_ops import all_reduce from oneflow.comm.comm_ops import all_gather from oneflow.comm.comm_ops import all_gather_into_tensor from oneflow.comm.comm_ops import reduce_scatter_tensor from oneflow.comm.comm_ops import broadcast from oneflow.comm.comm_ops import scatter from oneflow.comm.comm_ops import reduce from oneflow.comm.comm_ops import all_to_all from oneflow.comm.comm_ops import barrier from oneflow.comm.comm_ops import reduce_scatter from oneflow.comm.comm_ops import gather from oneflow._C import send, recv ================================================ FILE: python/oneflow/comm/comm_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow import numpy as np def all_reduce(tensor): """ Reduces the tensor data across all machines in such a way that all get the final result. After the call ``tensor`` is going to be bitwise identical in all processes. Args: tensor (Tensor): the input tensor For example: .. code-block:: python >>> # We have 1 process groups, 2 ranks. >>> import oneflow as flow >>> tensor = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank() >>> # tensor on rank0 >>> tensor # doctest: +ONLY_CHECK_RANK_0 tensor([[1, 2], [3, 4]], device='cuda:0', dtype=oneflow.int64) >>> # tensor on rank1 >>> tensor # doctest: +ONLY_CHECK_RANK_1 tensor([[2, 3], [4, 5]], device='cuda:1', dtype=oneflow.int64) >>> flow.comm.all_reduce(tensor) >>> tensor.numpy() array([[3, 5], [7, 9]], dtype=int64) """ assert isinstance(tensor, flow._oneflow_internal.Tensor) assert tensor.device.index == flow.env.get_local_rank() assert tensor.is_local flow._C.local_all_reduce(tensor, inplace=True) def all_gather(tensor_list, tensor): """ Gathers tensors from the whole group in a list. Args: tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. tensor (Tensor): Tensor to be broadcast from current process. For example: .. code-block:: python >>> # We have 1 process groups, 2 ranks. >>> import oneflow as flow >>> input = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank() >>> # input on rank0 >>> input # doctest: +ONLY_CHECK_RANK_0 tensor([[1, 2], [3, 4]], device='cuda:0', dtype=oneflow.int64) >>> # input on rank1 >>> input # doctest: +ONLY_CHECK_RANK_1 tensor([[2, 3], [4, 5]], device='cuda:1', dtype=oneflow.int64) >>> tensor_list = [flow.zeros(2, 2, dtype=flow.int64) for _ in range(2)] >>> flow.comm.all_gather(tensor_list, input) >>> # result on rank0 >>> tensor_list # doctest: +ONLY_CHECK_RANK_0 [tensor([[1, 2], [3, 4]], device='cuda:0', dtype=oneflow.int64), tensor([[2, 3], [4, 5]], device='cuda:0', dtype=oneflow.int64)] >>> # result on rank1 >>> tensor_list # doctest: +ONLY_CHECK_RANK_1 [tensor([[1, 2], [3, 4]], device='cuda:1', dtype=oneflow.int64), tensor([[2, 3], [4, 5]], device='cuda:1', dtype=oneflow.int64)] """ assert isinstance(tensor, flow._oneflow_internal.Tensor) assert isinstance(tensor_list, list) assert len(tensor_list) == flow.env.get_world_size() assert tensor.device.index == flow.env.get_local_rank() assert tensor.is_local tensor = tensor.expand(*([1] + list(tensor.shape))) device_type = tensor.device.type placement = flow.placement.all(device_type) tensor = ( tensor.to_global(placement=placement, sbp=flow.sbp.split(0)) .to_global(placement=placement, sbp=flow.sbp.broadcast) .to_local() ) assert len(tensor_list) == flow.env.get_world_size() # TODO(): getitem has bug on global tensor with size = [2, 1]. for i in range(tensor.shape[0]): tensor_list[i] = tensor[i] def all_gather_into_tensor(output_tensor, input_tensor): """ Gather tensors from all ranks and put them in a single output tensor. Args: output_tensor (Tensor): Output tensor to accommodate tensor elements from all ranks. 
It must be correctly sized to have one of the following forms: (i) a concatenation of all the input tensors along the primary dimension; for definition of "concatenation", see ``oneflow.cat()``; (ii) a stack of all the input tensors along the primary dimension; for definition of "stack", see ``oneflow.stack()``. Examples below may better explain the supported output forms. input_tensor (Tensor): Tensor to be gathered from current rank. The input tensors in this API must have the same size across all ranks. For example: .. code-block:: python >>> # We have 1 process groups, 2 ranks. >>> # All tensors below are of flow.int64 dtype and on CUDA devices. >>> import oneflow as flow >>> tensor_in = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.int64, device="cuda") + flow.env.get_rank() * 6 >>> tensor_in # doctest: +ONLY_CHECK_RANK_0 tensor([[1, 2, 3], [4, 5, 6]], device='cuda:0', dtype=oneflow.int64) >>> # Output in concatenation form >>> tensor_out = flow.zeros(4, 3, dtype=flow.int64, device="cuda") >>> flow.comm.all_gather_into_tensor(tensor_out, tensor_in) >>> # result on rank0 >>> tensor_out # doctest: +ONLY_CHECK_RANK_0 tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]], device='cuda:0', dtype=oneflow.int64) >>> # result on rank1 >>> tensor_out # doctest: +ONLY_CHECK_RANK_1 tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]], device='cuda:1', dtype=oneflow.int64) >>> # Output in stack form >>> tensor_out2 = flow.zeros(2, 3, 2, dtype=flow.int64, device="cuda") >>> flow.comm.all_gather_into_tensor(tensor_out2, tensor_in) >>> # result on rank0 >>> tensor_out2 # doctest: +ONLY_CHECK_RANK_0 tensor([[[ 1, 2], [ 3, 4], [ 5, 6]], [[ 7, 8], [ 9, 10], [11, 12]]], device='cuda:0', dtype=oneflow.int64) >>> # result on rank1 >>> tensor_out2 # doctest: +ONLY_CHECK_RANK_1 tensor([[[ 1, 2], [ 3, 4], [ 5, 6]], [[ 7, 8], [ 9, 10], [11, 12]]], device='cuda:1', dtype=oneflow.int64) """ assert output_tensor.is_local assert input_tensor.is_local flow._C.local_all_gather(output_tensor, input_tensor) def broadcast(tensor, src): """ Broadcasts the tensor to the whole group. ``tensor`` must have the same number of elements in all processes participating in the collective. Args: tensor (Tensor): Data to be sent if ``src`` is the rank of current process, and tensor to be used to save received data otherwise. src (int): Source rank. .. code-block:: python >>> # We have 1 process groups, 2 ranks. >>> import oneflow as flow >>> tensor = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank() >>> # input on rank0 >>> tensor # doctest: +ONLY_CHECK_RANK_0 tensor([[1, 2], [3, 4]], device='cuda:0', dtype=oneflow.int64) >>> # input on rank1 >>> tensor # doctest: +ONLY_CHECK_RANK_1 tensor([[2, 3], [4, 5]], device='cuda:1', dtype=oneflow.int64) >>> flow.comm.broadcast(tensor, 0) >>> # result on rank0 >>> tensor # doctest: +ONLY_CHECK_RANK_0 tensor([[1, 2], [3, 4]], device='cuda:0', dtype=oneflow.int64) """ assert isinstance(src, int) assert isinstance(tensor, flow._oneflow_internal.Tensor) assert tensor.is_local flow._C.comm_broadcast(tensor, src_rank=src, inplace=True) def scatter(tensor, scatter_list=None, src=0): """ Scatters a list of tensors to all processes in a group. Each process will receive exactly one tensor and store its data in the ``tensor`` argument. Args: tensor (Tensor): Output tensor. 
scatter_list (list[Tensor]): List of tensors to scatter (default is None, must be specified on the source rank) src (int): Source rank (default is 0) """ assert isinstance(src, int) assert isinstance(tensor, flow._oneflow_internal.Tensor) assert tensor.is_local out_shape = tensor.shape if flow.env.get_rank() == src: tensor.data = scatter_list[src] assert isinstance(scatter_list, list) assert len(scatter_list) == flow.env.get_world_size() for i in range(len(scatter_list)): if i == src: continue assert isinstance(scatter_list[i], flow._oneflow_internal.Tensor) assert scatter_list[i].is_local assert ( scatter_list[i].shape == out_shape ), f"invalid tensor size at index {i}: {out_shape} vs {scatter_list[i].shape}" flow.comm.send(scatter_list[i], i) # send/recv on the same rank is invalid if flow.env.get_rank() != src: flow.comm.recv(src, out=tensor) def reduce(tensor, dst): """ Reduces the tensor data across all machines. Only the process with rank ``dst`` is going to receive the final result. Args: tensor (Tensor): Input and output of the collective. The function operates in-place. dst (int): Destination rank """ assert isinstance(tensor, flow._oneflow_internal.Tensor) assert tensor.is_local assert isinstance(dst, int) original_tensor = flow._C.identity(tensor) flow.comm.all_reduce(tensor) if flow.env.get_rank() != dst: tensor.data = original_tensor def all_to_all(output_tensor_list, input_tensor_list): """ Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. Args: output_tensor_list (list[Tensor]): List of tensors to be gathered one per rank. input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. """ def _check_list(tensor_list): assert isinstance(tensor_list, list) assert len(tensor_list) == flow.env.get_world_size() shape = tensor_list[0].shape dtype = tensor_list[0].dtype device = tensor_list[0].device for tensor in tensor_list: assert isinstance(tensor, flow._oneflow_internal.Tensor) assert tensor.is_local assert shape == tensor.shape assert dtype == tensor.dtype assert device == tensor.device _check_list(output_tensor_list) _check_list(input_tensor_list) assert input_tensor_list[0].shape == output_tensor_list[0].shape assert input_tensor_list[0].dtype == output_tensor_list[0].dtype assert input_tensor_list[0].device == output_tensor_list[0].device for i in range(flow.env.get_world_size()): flow.comm.scatter( output_tensor_list[i], input_tensor_list if i == flow.env.get_rank() else [], src=i, ) def barrier(): """ Synchronizes all processes. """ flow._oneflow_internal.eager.ClusterSync() def reduce_scatter(output, input_list): """ Reduces, then scatters a list of tensors to all processes in a group. Args: output (Tensor): Output tensor. input_list (list[Tensor]): List of tensors to reduce and scatter. 
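    For example, with 2 ranks (an illustrative sketch; the tensor names are
    hypothetical): if rank 0 passes ``input_list = [a0, b0]`` and rank 1 passes
    ``input_list = [a1, b1]``, then after the call ``output`` holds ``a0 + a1``
    on rank 0 and ``b0 + b1`` on rank 1.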
""" assert isinstance(output, flow._oneflow_internal.Tensor) assert output.is_local assert isinstance(input_list, list) assert len(input_list) == flow.env.get_world_size() output_shape = output.shape device_type = output.device.type placement = flow.placement.all(device_type) reduced_tensor_list = [] for tensor in input_list: assert tensor.is_local assert tensor.shape == output_shape tensor = tensor.to_global( placement=placement, sbp=flow.sbp.partial_sum ).to_global(placement=placement, sbp=flow.sbp.broadcast) reduced_tensor_list.append(tensor.to_local()) output.data = reduced_tensor_list[flow.env.get_rank()] def reduce_scatter_tensor(output_tensor, input_tensor): """ Reduces, then scatters a tensor to all ranks. Args: output (Tensor): Output tensor. It should have the same size across all ranks. input (Tensor): Input tensor to be reduced and scattered. Its size should be output tensor size times the world size. The input tensor can have one of the following shapes: (i) a concatenation of the output tensors along the primary dimension, or (ii) a stack of the output tensors along the primary dimension. For definition of "concatenation", see ``oneflow.cat()``. For definition of "stack", see ``oneflow.stack()``. For example: .. code-block:: python >>> # We have 1 process groups, 2 ranks. >>> # All tensors below are of flow.int64 dtype and on CUDA devices. >>> import oneflow as flow >>> tensor_in = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=flow.int64, device="cuda") >>> tensor_in # doctest: +ONLY_CHECK_RANK_0 tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]], device='cuda:0', dtype=oneflow.int64) >>> # Output in concatenation form >>> tensor_out = flow.zeros(2, 3, dtype=flow.int64, device="cuda") >>> flow.comm.reduce_scatter_tensor(tensor_out, tensor_in) >>> # result on rank0 >>> tensor_out # doctest: +ONLY_CHECK_RANK_0 tensor([[ 2, 4, 6], [ 8, 10, 12]], device='cuda:0', dtype=oneflow.int64) >>> # result on rank1 >>> tensor_out # doctest: +ONLY_CHECK_RANK_1 tensor([[14, 16, 18], [20, 22, 24]], device='cuda:1', dtype=oneflow.int64) >>> # Output in stack form >>> tensor_in2 = tensor_in.reshape(2, 3, 2) >>> tensor_out2 = flow.zeros(2, 3, dtype=flow.int64, device="cuda") >>> flow.comm.reduce_scatter_tensor(tensor_out2, tensor_in2) >>> # result on rank0 >>> tensor_out2 # doctest: +ONLY_CHECK_RANK_0 tensor([[ 2, 4, 6], [ 8, 10, 12]], device='cuda:0', dtype=oneflow.int64) >>> # result on rank1 >>> tensor_out2 # doctest: +ONLY_CHECK_RANK_1 tensor([[14, 16, 18], [20, 22, 24]], device='cuda:1', dtype=oneflow.int64) """ assert output_tensor.is_local assert input_tensor.is_local flow._C.local_reduce_scatter(output_tensor, input_tensor) def gather(tensor, gather_list=None, dst=0): """ Gathers a list of tensors in a single process. Args: tensor (Tensor): Input tensor. 
gather_list (list[Tensor], optional): List of appropriately-sized tensors to use for gathered data (default is None, must be specified on the destination rank) dst (int, optional): Destination rank (default is 0) """ assert isinstance(tensor, flow._oneflow_internal.Tensor) assert tensor.is_local shape = tensor.shape dtype = tensor.dtype tensor = tensor.expand(*([1] + list(shape))) device_type = tensor.device.type placement = flow.placement.all(device_type) tensor = tensor.to_global(placement=placement, sbp=flow.sbp.split(0)).to_global( placement=placement, sbp=flow.sbp.broadcast ) if gather_list is None: gather_list = [ flow.empty(shape, dtype=dtype) for _ in range(flow.env.get_world_size()) ] assert gather_list is not None assert isinstance(gather_list, list) assert len(gather_list) == flow.env.get_world_size() for i in range(tensor.shape[0]): gather_list[i] = tensor[i].to_local() ================================================ FILE: python/oneflow/cuda/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.cuda.type_tensor import * from oneflow.cuda._utils import _get_device_index from typing import Optional, Tuple, Union, Any default_generators = flow._oneflow_internal.default_generators() def is_available() -> bool: r"""Returns a bool indicating if CUDA is currently available.""" # This function never throws and returns 0 if driver is missing or can't # be initialized return device_count() > 0 def device_count() -> int: r"""Returns the number of GPUs available.""" return flow._oneflow_internal.CudaGetDeviceCount() def current_device() -> int: r"""Returns local rank as device index.""" return flow._oneflow_internal.GetCudaDeviceIndex() def get_device_properties(device: Union[flow.device, str, int] = None): r"""Gets the properties of a device. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.cuda.get_device_properties.html. Args: device(oneflow.device or str or int): device for which to return the properties of the device. Returns: the properties of the device. """ device = _get_device_index(device, optional=True) return flow._oneflow_internal._get_device_properties(device) def get_device_capability( device: Optional[Union[flow.device, str, int]] = None ) -> Tuple[int, int]: r"""Gets the cuda capability of a device. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.cuda.get_device_capability.html. Args: device (oneflow.device or int or str, optional): device for which to return the device capability. It uses the current device, given by :func:`~oneflow.cuda.current_device`, if :attr:`device` is ``None`` (default). Returns: tuple(int, int): the major and minor cuda capability of the device """ device_prop = get_device_properties(device) return device_prop.major, device_prop.minor def get_device_name(device: Optional[Union[flow.device, str, int]] = None) -> str: r"""Gets the name of a device. 
The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.cuda.get_device_name.html. Args: device (oneflow.device or int or str, optional): device for which to return the name. It uses the current device, given by :func:`~oneflow.cuda.current_device`, if :attr:`device` is ``None`` (default). Returns: str: the name of the device """ return get_device_properties(device).name def manual_seed_all(seed) -> None: r"""Sets the seed for generating random numbers on all GPUs. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.cuda.manual_seed_all.html. It's safe to call this function if CUDA is not available; in that case, it is silently ignored. Args: seed (int): The desired seed. """ seed = int(seed) flow._oneflow_internal.ManualSeedAllCudaGenerator(seed) def manual_seed(seed: int) -> None: r"""Sets the seed for generating random numbers for the current GPU. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.cuda.manual_seed.html. It's safe to call this function if CUDA is not available; in that case, it is silently ignored. Args: seed (int): The desired seed. .. warning:: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use :func:`manual_seed_all`. """ seed = int(seed) idx = current_device() flow._oneflow_internal.manual_seed(seed, "cuda", idx) def set_device(device: Union[flow.device, str, int]) -> None: r"""Sets the current device. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.cuda.set_device.html. Usage of this function is discouraged in favor of :attr:`device`. In most cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable. Args: device (flow.device or int): selected device. This function is a no-op if this argument is negative. """ device_idx = _get_device_index(device) if device_idx < 0: return if flow.env.get_world_size() > 0: if device_idx == flow.env.get_local_rank(): return raise ValueError( "Setting cuda device to a device whose index does not equal to the local rank is not supported." ) flow._oneflow_internal.SetCudaDeviceIndex(device_idx) def synchronize(device: Union[flow.device, str, int, None] = None) -> None: r""" Waits for all kernels in all streams on a CUDA device to complete. Note: In the eager mode of oneflow, all operations will be converted into instructions executed in the virtual machine, so in order to comply with the semantics of synchronization, this function will call the `eager.Sync()` function before the device is synchronized, which may affect the operations executed in other devices. Args: device (flow.device or int, optional): device for which to synchronize. It uses the current device, given by :func:`~oneflow.cuda.current_device`, if :attr:`device` is ``None`` (default). """ device_idx = _get_device_index(device, optional=True) if device_idx >= 0: flow._oneflow_internal.eager.Sync() flow._oneflow_internal.CudaSynchronize(device_idx) def empty_cache() -> None: r""" Releases all unoccupied cached memory currently held by the caching allocators of all OneFlow streams so those can be re-allocated in OneFlow streams or other GPU application and visible in `nvidia-smi`. Note: :func:`~flow.cuda.empty_cache` may enable one stream to release memory and then freed memory can be used by another stream. It may also help reduce fragmentation of GPU memory in certain cases. 
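    Example (a minimal usage sketch):

    .. code-block:: python

        import oneflow as flow

        x = flow.randn(1024, 1024, device="cuda")
        del x                    # the memory backing ``x`` stays in the cache
        flow.cuda.empty_cache()  # hand the cached blocks back to the device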
""" return flow._oneflow_internal.EmptyCache() def mem_get_info(device: Any = None) -> Tuple[int, int]: r"""Returns the global free and total GPU memory for a given device using cudaMemGetInfo. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.cuda.mem_get_info.html Args: device (flow.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~flow.cuda.current_device`, if :attr:`device` is ``None`` (default). """ if device is None: device = current_device() device = _get_device_index(device) return flow._oneflow_internal.CudaMemGetInfo(device) from .random import * # noqa: F403 class Event: def __init__(self): raise NotImplementedError() ================================================ FILE: python/oneflow/cuda/_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from typing import Any, Optional def _get_current_device_index() -> int: r"""Checks if there are CUDA devices available and returns the device index of the current default CUDA device. Returns -1 in case there are no CUDA devices available. Arguments: ``None`` """ if flow.cuda.is_available(): return flow.cuda.current_device() return -1 def _get_device_index( device: Any, optional: bool = False, allow_cpu: bool = False ) -> int: r"""Gets the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. If :attr:`device` is a flow.device object, returns the device index if it is a CUDA device. Note that for a CUDA device without a specified index, i.e., ``flow.device('cuda')``, this will return the current default CUDA device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, CPU devices will be accepted and ``-1`` will be returned in this case. If :attr:`device` is a Python integer, it is returned as is. If :attr:`device` is ``None``, this will return the current default CUDA device if :attr:`optional` is ``True``. """ device_idx: Optional[int] = None if isinstance(device, str): device = flow.device(device) if isinstance(device, flow.device): if allow_cpu: if device.type not in ["cuda", "cpu"]: raise ValueError( "Expected a cuda or cpu device, but got: {}".format(device) ) elif device.type != "cuda": raise ValueError("Expected a cuda device, but got: {}".format(device)) device_idx = -1 if device.type == "cpu" else device.index if isinstance(device, int): device_idx = device if device_idx is None: if optional: device_idx = _get_current_device_index() else: raise ValueError( "Expected a flow.device with a specified index " "or an integer, but got:{}".format(device) ) return device_idx ================================================ FILE: python/oneflow/cuda/amp/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from .autocast_mode import autocast from oneflow.amp import GradScaler ================================================ FILE: python/oneflow/cuda/amp/autocast_mode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from typing import Any, Optional class autocast(flow.amp.autocast_mode.autocast): r""" See :class:`oneflow.autocast`. ``oneflow.cuda.amp.autocast(args...)`` is equivalent to ``oneflow.autocast("cuda", args...)`` """ def __init__( self, enabled: bool = True, dtype: Optional[flow.dtype] = None, cache_enabled: Optional[bool] = None, ): super().__init__( "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled ) def __enter__(self): return super().__enter__() def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] return super().__exit__(exc_type, exc_val, exc_tb) def __call__(self, func): return super().__call__(func) ================================================ FILE: python/oneflow/cuda/random.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow import Tensor from typing import cast, Iterable, List, Union from . import current_device, device_count def get_rng_state(device: Union[int, str, flow.device] = "cuda") -> Tensor: r"""Returns the random number generator state of the specified GPU as a ByteTensor. Args: device (flow.device or int, optional): The device to return the RNG state of. Default: ``'cuda'`` (i.e., ``flow.device('cuda')``, the current CUDA device). 
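    Example (an illustrative sketch; the actual state bytes depend on the seed):

    .. code-block:: python

        import oneflow as flow

        state = flow.cuda.get_rng_state()  # snapshot the current CUDA RNG state
        _ = flow.rand(2, device="cuda")    # consume some random numbers
        flow.cuda.set_rng_state(state)     # restore, so the next draws repeat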
""" # TODO (add lazy initialization mechanism in OneFlow) # _lazy_init() if isinstance(device, str): device = flow.device(device) elif isinstance(device, int): device = flow.device("cuda", device) idx = device.index if idx is None: idx = current_device() default_generator = flow.cuda.default_generators[idx] return default_generator.get_state() def get_rng_state_all() -> List[Tensor]: r"""Returns a list of ByteTensor representing the random number states of all devices.""" results = [] for i in range(device_count()): results.append(get_rng_state(i)) return results def set_rng_state( new_state: Tensor, device: Union[int, str, flow.device] = "cuda" ) -> None: r"""Sets the random number generator state of the specified GPU. Args: new_state (flow.ByteTensor): The desired state device (flow.device or int, optional): The device to set the RNG state. Default: ``'cuda'`` (i.e., ``flow.device('cuda')``, the current CUDA device). """ new_state_copy = new_state.clone() if isinstance(device, str): device = flow.device(device) elif isinstance(device, int): device = flow.device("cuda", device) if device.type == "cpu": raise ValueError( "Cannot set RNG state for CPU device in flow.cuda.set_rng_state func!" ) idx = cast(flow.device, device).index if idx is None: idx = current_device() default_generator = flow.cuda.default_generators[idx] default_generator.set_state(new_state_copy) def set_rng_state_all(new_states: Iterable[Tensor]) -> None: r"""Sets the random number generator state of all devices. Args: new_states (Iterable of flow.ByteTensor): The desired state for each device""" for i, state in enumerate(new_states): set_rng_state(state, i) ================================================ FILE: python/oneflow/cuda/type_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow import oneflow as flow from oneflow._C import cuda HalfTensor = cuda.HalfTensor FloatTensor = cuda.FloatTensor DoubleTensor = cuda.DoubleTensor BoolTensor = cuda.BoolTensor ByteTensor = cuda.ByteTensor CharTensor = cuda.CharTensor IntTensor = cuda.IntTensor LongTensor = cuda.LongTensor ComplexFloatTensor = cuda.ComplexFloatTensor ComplexDoubleTensor = cuda.ComplexDoubleTensor __all__ = [ "HalfTensor", "FloatTensor", "DoubleTensor", "BoolTensor", "ByteTensor", "CharTensor", "IntTensor", "LongTensor", "ComplexFloatTensor", "ComplexDoubleTensor", # TODO: Add support for BFloat16Tensor, ComplexHalfTensor ] ================================================ FILE: python/oneflow/data.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.experimental.load_mnist import load_mnist ================================================ FILE: python/oneflow/distributed/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # Just for alignment with pytorch, not really useful from .constants import default_pg_timeout from typing import List, Optional import oneflow as flow class ReduceOp: """Reduce operation enum. Mainly for PyTorch compatibility. Currently only support SUM. See also :func:`oneflow.comm.all_reduce()` """ SUM = "sum" def is_initialized() -> bool: """Always returns True. This function is only for PyTorch compatibility. Returns: True """ return True # PyTorch doesn't have torch.distributed.get_local_rank, # we add it for the consistency between flow.env and flow.distributed get_local_rank = flow.env.get_local_rank def get_rank(group=None) -> int: """Alias of `oneflow.env.get_rank()` for PyTorch compatibility. See also :func:`oneflow.env.get_rank()` """ assert group is None, "group is not supported yet" return flow.env.get_rank() def get_world_size(group=None) -> int: """Alias of `oneflow.env.get_world_size()` for PyTorch compatibility. See also :func:`oneflow.env.get_world_size()` """ assert group is None, "group is not supported yet" return flow.env.get_world_size() def send(tensor: flow.Tensor, dst: int, group=None, tag: int = 0) -> None: """Alias of `oneflow.comm.send()` for PyTorch compatibility. See also :func:`oneflow.comm.send()` """ assert group is None, "group is not supported yet" assert tag == 0, "tag is not supported yet" return flow.comm.send(tensor, dst) def recv(tensor: flow.Tensor, src: int, group=None, tag: int = 0) -> None: """Alias of `oneflow.comm.recv()` for PyTorch compatibility. See also :func:`oneflow.comm.recv()` """ assert group is None, "group is not supported yet" assert tag == 0, "tag is not supported yet" return flow.comm.recv(tensor, src) def broadcast( tensor: flow.Tensor, src: int, group=None, async_op: bool = False ) -> None: """Alias of `oneflow.comm.broadcast()` for PyTorch compatibility. See also :func:`oneflow.comm.broadcast()` """ assert group is None, "group is not supported yet" assert async_op is False, "async_op is not supported yet" return flow.comm.broadcast(tensor, src) def barrier(group=None, async_op=False, device_ids=None) -> None: """Alias of `oneflow.comm.barrier()` for PyTorch compatibility. 
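To make this compatibility layer concrete, a minimal sketch (the launch command and the script name ``train.py`` are illustrative; it assumes the workers are spawned with ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py`` so that rank and world size are populated):

>>> import oneflow as flow
>>> import oneflow.distributed as dist
>>> dist.get_rank(), dist.get_world_size()      # thin wrappers over flow.env
>>> t = flow.ones(2)
>>> dist.all_reduce(t, op=dist.ReduceOp.SUM)    # in-place sum across ranks via flow.comm.all_reduce
>>> dist.barrier()                              # wait for every rank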
See also :func:`oneflow.comm.barrier()` """ assert group is None, "group is not supported yet" assert async_op is False, "async_op is not supported yet" assert device_ids is None, "device_ids is not supported yet" return flow.comm.barrier() def all_reduce( tensor: flow.Tensor, op: ReduceOp, group=None, async_op: bool = False ) -> None: """Alias of `oneflow.comm.all_reduce()` for PyTorch compatibility. See also :func:`oneflow.comm.all_reduce()` """ assert op == ReduceOp.SUM, "only ReduceOp.SUM is supported" assert group is None, "group is not supported yet" assert async_op is False, "async_op is not supported yet" return flow.comm.all_reduce(tensor) def all_gather( tensor_list: List[flow.Tensor], tensor: flow.Tensor, group=None, async_op: bool = False, ) -> None: """Alias of `oneflow.comm.all_gather()` for PyTorch compatibility. See also :func:`oneflow.comm.all_gather()` """ assert group is None, "group is not supported yet" assert async_op is False, "async_op is not supported yet" return flow.comm.all_gather(tensor_list, tensor) def reduce( tensor: flow.Tensor, dst: int, op: ReduceOp, group=None, async_op: bool = False ) -> None: """Alias of `oneflow.comm.reduce()` for PyTorch compatibility. See also :func:`oneflow.comm.reduce()` """ assert op == ReduceOp.SUM, "only ReduceOp.SUM is supported" assert group is None, "group is not supported yet" assert async_op is False, "async_op is not supported yet" return flow.comm.reduce(tensor, dst) def all_to_all( output_tensor_list: List[flow.Tensor], input_tensor_list: List[flow.Tensor], group=None, async_op: bool = False, ) -> None: """Alias of `oneflow.comm.all_to_all()` for PyTorch compatibility. See also :func:`oneflow.comm.all_to_all()` """ assert group is None, "group is not supported yet" assert async_op is False, "async_op is not supported yet" return flow.comm.all_to_all(output_tensor_list, input_tensor_list) def reduce_scatter( output: flow.Tensor, input_list: List[flow.Tensor], op: ReduceOp, group=None, async_op: bool = False, ) -> None: """Alias of `oneflow.comm.reduce_scatter()` for PyTorch compatibility. See also :func:`oneflow.comm.reduce_scatter()` """ assert op == ReduceOp.SUM, "only ReduceOp.SUM is supported" assert group is None, "group is not supported yet" assert async_op is False, "async_op is not supported yet" return flow.comm.reduce_scatter(output, input_list) def gather( tensor: flow.Tensor, gather_list: Optional[List[flow.Tensor]] = None, dst: int = 0, group=None, async_op: bool = False, ) -> None: """Alias of `oneflow.comm.gather()` for PyTorch compatibility. See also :func:`oneflow.comm.gather()` """ assert group is None, "group is not supported yet" assert async_op is False, "async_op is not supported yet" return flow.comm.gather(tensor, gather_list, dst) def is_available(): return True ================================================ FILE: python/oneflow/distributed/constants.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # Just for alignment with pytorch, not really useful from datetime import timedelta default_pg_timeout = timedelta(milliseconds=30 * 60 * 1000) ================================================ FILE: python/oneflow/distributed/launch.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """ This file is mostly copied from PyTorch v1.8.1 torch/distributed/launch.py """ import os import signal import subprocess import sys import time from argparse import REMAINDER, ArgumentParser from typing import IO, Any, List, Optional stdout_filename = "stdout" stderr_filename = "stderr" def parse_args(): """ Helper function parsing the command line options @retval ArgumentParser """ parser = ArgumentParser( description="OneFlow distributed training launch helper utility that will spawn up multiple distributed processes" ) parser.add_argument( "--nnodes", type=int, default=1, help="The number of nodes to use for distributed training", ) parser.add_argument( "--node_rank", type=int, default=0, help="The rank of the node for multi-node distributed training", ) parser.add_argument( "--nproc_per_node", type=int, default=1, help="The number of processes to launch on each node, for GPU training, this is recommended to be set to the number of GPUs in your system so that each process can be bound to a single GPU.", ) parser.add_argument( "--master_addr", default="127.0.0.1", type=str, help="Master node (rank 0)'s address, should be either the IP address or the hostname of node 0, for single node multi-proc training, the --master_addr can simply be 127.0.0.1", ) parser.add_argument( "--master_port", default=29500, type=int, help="Master node (rank 0)'s free port that needs to be used for communication during distributed training", ) parser.add_argument( "-m", "--module", default=False, action="store_true", help="Changes each process to interpret the launch script as a python module, executing with the same behavior as'python -m'.", ) parser.add_argument( "--no_python", default=False, action="store_true", help='Do not prepend the training script with "python" - just exec it directly. Useful when the script is not a Python script.', ) parser.add_argument( "--redirect_stdout_and_stderr", default=False, action="store_true", help=f"write the stdout and stderr to files\n '{stdout_filename}' and '{stderr_filename}' in logdir.", ) parser.add_argument( "--logdir", default="log", type=str, help=f"Relative path to write subprocess logs to. Passing in a relative\n path will create a directory if needed. 
Note that\n successive runs with the same path to write logs to will overwrite existing logs,\n so be sure to save logs as needed.", ) parser.add_argument( "training_script", type=str, help="The full path to the single GPU training program/script to be launched in parallel, followed by all the arguments for the training script", ) parser.add_argument("training_script_args", nargs=REMAINDER) return parser.parse_args() def main(): args = parse_args() dist_world_size = args.nproc_per_node * args.nnodes current_env = os.environ.copy() current_env["MASTER_ADDR"] = args.master_addr current_env["MASTER_PORT"] = str(args.master_port) current_env["WORLD_SIZE"] = str(dist_world_size) if args.master_port is None or args.master_port >= 2 ** 16: raise ValueError( f"The port number of the master endpoint '{args.master_addr}:{args.master_port}' must be an integer " "between 0 and 65536." ) if "OMP_NUM_THREADS" not in os.environ and args.nproc_per_node > 1: current_env["OMP_NUM_THREADS"] = str(1) print( "*****************************************\n" "Setting OMP_NUM_THREADS environment variable for each process " "to be {} in default, to avoid your system being overloaded, " "please further tune the variable for optimal performance in " "your application as needed. \n" "*****************************************".format( current_env["OMP_NUM_THREADS"] ) ) processes: List[Any] = [] if ( args.redirect_stdout_and_stderr and os.path.exists(args.logdir) and not os.path.isdir(args.logdir) ): raise ValueError("argument --logdir must be a path to a directory.") subprocess_file_handles = [] for local_rank in range(0, args.nproc_per_node): dist_rank = args.nproc_per_node * args.node_rank + local_rank current_env["RANK"] = str(dist_rank) current_env["LOCAL_RANK"] = str(local_rank) with_python = not args.no_python cmd = [] if with_python: cmd = [sys.executable, "-u"] if args.module: cmd.append("-m") elif args.module: raise ValueError( "Don't use both the '--no_python' flag and the '--module' flag at the same time." ) cmd.append(args.training_script) cmd.extend(args.training_script_args) stdout_handle: Optional[IO] stderr_handle: Optional[IO] log_directory_path = os.path.join( os.getcwd(), args.logdir, f"local_rank_{local_rank}" ) current_env["GLOG_log_dir"] = log_directory_path if args.redirect_stdout_and_stderr: os.makedirs(log_directory_path, exist_ok=True) node_rank = args.node_rank stdout_handle = open(os.path.join(log_directory_path, stdout_filename), "w") stderr_handle = open(os.path.join(log_directory_path, stderr_filename), "w") subprocess_file_handles.append((stdout_handle, stderr_handle)) stdout_name = stdout_handle.name stderr_name = stderr_handle.name print( f"Note: Stdout and stderr for node {node_rank} rank {local_rank} will\n be written to {stdout_name}, {stderr_name} respectively." ) sig_names = {2: "SIGINT", 15: "SIGTERM"} last_return_code = None # set killing flag to make sure killing signal only executed once kill_flag = True def sigkill_handler(signum, frame): nonlocal kill_flag if not kill_flag: return for process in processes: print(f"Killing subprocess {process.pid}") kill_flag = False try: # Note: use os.kill or process.kill() may only kill current process # use killpg will kill(use signal) this process and all sub-processes # # Note: Worker processes launched by data loader will exit automatically # when its parent process exits because of `_prctl_pr_set_pdeathsig`. 
os.killpg(os.getpid(), signal.SIGTERM) except Exception: pass if last_return_code is not None: raise subprocess.CalledProcessError( returncode=last_return_code, cmd=cmd ) if signum in sig_names: print(f"Main process received {sig_names[signum]}, exiting") sys.exit(1) signal.signal(signal.SIGINT, sigkill_handler) signal.signal(signal.SIGTERM, sigkill_handler) stdout_handle = ( None if not subprocess_file_handles else subprocess_file_handles[local_rank][0] ) stderr_handle = ( None if not subprocess_file_handles else subprocess_file_handles[local_rank][1] ) process = subprocess.Popen( cmd, env=current_env, stdout=stdout_handle, stderr=stderr_handle ) processes.append(process) try: alive_processes = set(processes) while len(alive_processes): finished_processes = [] for process in alive_processes: if process.poll() is None: continue elif process.returncode != 0: last_return_code = process.returncode sigkill_handler(signal.SIGTERM, None) else: finished_processes.append(process) alive_processes = set(alive_processes) - set(finished_processes) time.sleep(1) finally: for (stdout_handle, stderr_handle) in subprocess_file_handles: stdout_handle.close() stderr_handle.close() if __name__ == "__main__": main() ================================================ FILE: python/oneflow/distributions/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r""" The documentation is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/distributions/__init__.py The ``distributions`` package contains parameterizable probability distributions and sampling functions. This allows the construction of stochastic computation graphs and stochastic gradient estimators for optimization. This package generally follows the design of the `TensorFlow Distributions`_ package. .. _`TensorFlow Distributions`: https://arxiv.org/abs/1711.10604 It is not possible to directly backpropagate through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through. These are the score function estimator/likelihood ratio estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly seen as the basis for policy gradient methods in reinforcement learning, and the pathwise derivative estimator is commonly seen in the reparameterization trick in variational autoencoders. Whilst the score function only requires the value of samples :math:`f(x)`, the pathwise derivative requires the derivative :math:`f'(x)`. The next sections discuss these two in a reinforcement learning example. For more details see `Gradient Estimation Using Stochastic Computation Graphs`_ . .. 
_`Gradient Estimation Using Stochastic Computation Graphs`: https://arxiv.org/abs/1506.05254 Score function ^^^^^^^^^^^^^^ When the probability density function is differentiable with respect to its parameters, we only need :meth:`~oneflow.distributions.Distribution.sample` and :meth:`~oneflow.distributions.Distribution.log_prob` to implement REINFORCE: .. math:: \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta} where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate, :math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`. In practice we would sample an action from the output of a network, apply this action in an environment, and then use ``log_prob`` to construct an equivalent loss function. Note that we use a negative because optimizers use gradient descent, whilst the rule above assumes gradient ascent. With a categorical policy, the code for implementing REINFORCE would be as follows:: probs = policy_network(state) # Note that this is equivalent to what used to be called multinomial m = Categorical(probs) action = m.sample() next_state, reward = env.step(action) loss = -m.log_prob(action) * reward loss.backward() Pathwise derivative ^^^^^^^^^^^^^^^^^^^ The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the :meth:`~oneflow.distributions.Distribution.rsample` method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable. The code for implementing the pathwise derivative would be as follows:: params = policy_network(state) m = Normal(*params) # Any distribution with .has_rsample == True could work based on the application action = m.rsample() next_state, reward = env.step(action) # Assuming that reward is differentiable loss = -reward loss.backward() """ from .distribution import Distribution from .categorical import Categorical ================================================ FILE: python/oneflow/distributions/categorical.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.distributions.distribution import Distribution from oneflow.distributions.utils import probs_to_logits, logits_to_probs # NOTE(Liang Depeng): modified from # https://github.com/pytorch/pytorch/blob/master/torch/distributions/categorical.py __all__ = ["Categorical"] class Categorical(Distribution): r""" Creates a categorical distribution parameterized by either :attr:`probs` or :attr:`logits` (but not both). .. note:: It is equivalent to the distribution that :func:`oneflow.multinomial` samples from. Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``. 
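A hedged sketch of the batch behaviour described just below (drawn values are random, so only the shape is commented):

>>> import oneflow as flow
>>> m = flow.distributions.categorical.Categorical(flow.tensor([[0.1, 0.2, 0.7], [0.5, 0.25, 0.25]]))
>>> m.sample()    # one draw per batch row, i.e. a shape (2,) tensor of class indices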
If `probs` is 1-dimensional with length-`K`, each element is the relative probability of sampling the class at that index. If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors. .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. :attr:`probs` will return this normalized value. The `logits` argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. :attr:`logits` will return this normalized value. Args: probs (Tensor): event probabilities logits (Tensor): event log probabilities (unnormalized) See also: :func:`oneflow.multinomial` For example: .. code-block:: python >>> import oneflow as flow >>> gen = flow.manual_seed(0) >>> m = flow.distributions.categorical.Categorical(flow.tensor([ 0.25, 0.25, 0.25, 0.25 ])) >>> m.sample() # equal probability of 0, 1, 2, 3 tensor(3, dtype=oneflow.int64) """ has_enumerate_support = True def __init__(self, probs=None, logits=None, validate_args=None): if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." ) assert validate_args is None if probs is not None: if probs.dim() < 1: raise ValueError("`probs` parameter must be at least one-dimensional.") self.probs = probs / probs.sum(-1, keepdim=True) else: if logits.dim() < 1: raise ValueError("`logits` parameter must be at least one-dimensional.") self.logits = logits # Normalize import math def logsumexp(t): if t.numel() != 0: maxes = flow.max(t, dim=-1, keepdim=True)[0] maxes.masked_fill_(flow.abs(maxes) == math.inf, 0) result = flow.sum(flow.exp(t - maxes), dim=-1, keepdim=True) return flow.log(result) + maxes else: return flow.log(flow.sum(t, dim=-1, keepdim=True)) self.probs = logits_to_probs(logits - logsumexp(logits)) self._param = self.probs if probs is not None else self.logits self._num_events = self._param.size()[-1] batch_shape = ( self._param.size()[:-1] if self._param.ndimension() > 1 else flow.Size() ) super(Categorical, self).__init__(batch_shape, validate_args=validate_args) def logits(self): return probs_to_logits(self.probs) def probs(self): return logits_to_probs(self.logits) def sample(self, sample_shape=flow.Size()): if not isinstance(sample_shape, flow.Size): sample_shape = flow.Size(sample_shape) probs_2d = self.probs.reshape(-1, self._num_events) samples_2d = flow.multinomial(probs_2d, sample_shape.numel(), True).T return samples_2d.reshape(self._extended_shape(sample_shape)) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/distributions/distribution.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow import warnings from typing import Dict, Optional, Any # NOTE(Liang Depeng): Modified from # https://github.com/pytorch/pytorch/blob/master/torch/distributions/distribution.py __all__ = ["Distribution"] class Distribution(object): r""" Distribution is the abstract base class for probability distributions. """ has_rsample = False has_enumerate_support = False _validate_args = __debug__ @staticmethod def set_default_validate_args(value): """ Sets whether validation is enabled or disabled. The default behavior mimics Python's ``assert`` statement: validation is on by default, but is disabled if Python is run in optimized mode (via ``python -O``). Validation may be expensive, so you may want to disable it once a model is working. Args: value (bool): Whether to enable validation. """ if value not in [True, False]: raise ValueError Distribution._validate_args = value def __init__( self, batch_shape=oneflow.Size(), event_shape=oneflow.Size(), validate_args=None ): self._batch_shape = batch_shape self._event_shape = event_shape assert validate_args is None, "only support validate_args=None for now." super(Distribution, self).__init__() def expand(self, batch_shape, _instance=None): """ Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to `batch_shape`. This method calls :class:`~oneflow.Tensor.expand` on the distribution's parameters. As such, this does not allocate new memory for the expanded distribution instance. Additionally, this does not repeat any args checking or parameter broadcasting in `__init__.py`, when an instance is first created. Args: batch_shape (oneflow.Size): the desired expanded size. _instance: new instance provided by subclasses that need to override `.expand`. Returns: New distribution instance with batch dimensions expanded to `batch_size`. """ raise NotImplementedError @property def batch_shape(self): """ Returns the shape over which parameters are batched. """ return self._batch_shape @property def event_shape(self): """ Returns the shape of a single sample (without batching). """ return self._event_shape @property def mean(self): """ Returns the mean of the distribution. """ raise NotImplementedError @property def mode(self): """ Returns the mode of the distribution. """ raise NotImplementedError(f"{self.__class__} does not implement mode") @property def variance(self): """ Returns the variance of the distribution. """ raise NotImplementedError @property def stddev(self): """ Returns the standard deviation of the distribution. """ return self.variance.sqrt() def sample(self, sample_shape=oneflow.Size()): """ Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. """ with oneflow.no_grad(): return self.rsample(sample_shape) def rsample(self, sample_shape=oneflow.Size()): """ Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. """ raise NotImplementedError def sample_n(self, n): """ Generates n samples or n batches of samples if the distribution parameters are batched. """ warnings.warn( "sample_n will be deprecated. Use .sample((n,)) instead", UserWarning ) return self.sample(oneflow.Size((n,))) def log_prob(self, value): """ Returns the log of the probability density/mass function evaluated at `value`. 
Args: value (Tensor): """ raise NotImplementedError def cdf(self, value): """ Returns the cumulative density/mass function evaluated at `value`. Args: value (Tensor): """ raise NotImplementedError def icdf(self, value): """ Returns the inverse cumulative density/mass function evaluated at `value`. Args: value (Tensor): """ raise NotImplementedError def enumerate_support(self, expand=True): """ Returns tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be `(cardinality,) + batch_shape + event_shape` (where `event_shape = ()` for univariate distributions). Note that this enumerates over all batched tensors in lock-step `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens along dim 0, but with the remaining batch dimensions being singleton dimensions, `[[0], [1], ..`. To iterate over the full Cartesian product use `itertools.product(m.enumerate_support())`. Args: expand (bool): whether to expand the support over the batch dims to match the distribution's `batch_shape`. Returns: Tensor iterating over dimension 0. """ raise NotImplementedError def entropy(self): """ Returns entropy of distribution, batched over batch_shape. Returns: Tensor of shape batch_shape. """ raise NotImplementedError def perplexity(self): """ Returns perplexity of distribution, batched over batch_shape. Returns: Tensor of shape batch_shape. """ return oneflow.exp(self.entropy()) def _extended_shape(self, sample_shape=oneflow.Size()): """ Returns the size of the sample returned by the distribution, given a `sample_shape`. Note, that the batch and event shapes of a distribution instance are fixed at the time of construction. If this is empty, the returned shape is upcast to (1,). Args: sample_shape (oneflow.Size): the size of the sample to be drawn. """ if not isinstance(sample_shape, oneflow.Size): sample_shape = oneflow.Size(sample_shape) return sample_shape + self._batch_shape + self._event_shape def _validate_sample(self, value): """ Argument validation for distribution methods such as `log_prob`, `cdf` and `icdf`. The rightmost dimensions of a value to be scored via these methods must agree with the distribution's batch and event shapes. Args: value (Tensor): the tensor whose log probability is to be computed by the `log_prob` method. Raises ValueError: when the rightmost dimensions of `value` do not match the distribution's batch and event shapes. """ if not isinstance(value, oneflow.Tensor): raise ValueError("The value argument to log_prob must be a Tensor") event_dim_start = len(value.size()) - len(self._event_shape) if value.size()[event_dim_start:] != self._event_shape: raise ValueError( "The right-most size of value must match event_shape: {} vs {}.".format( value.size(), self._event_shape ) ) actual_shape = value.size() expected_shape = self._batch_shape + self._event_shape for i, j in zip(reversed(actual_shape), reversed(expected_shape)): if i != 1 and j != 1 and i != j: raise ValueError( "Value is not broadcastable with batch_shape+event_shape: {} vs {}.".format( actual_shape, expected_shape ) ) try: support = self.support except NotImplementedError: warnings.warn( f"{self.__class__} does not define `support` to enable " + "sample validation. Please initialize the distribution with " + "`validate_args=False` to turn off validation." 
) return assert support is not None valid = support.check(value) if not valid.all(): raise ValueError( "Expected value argument " f"({type(value).__name__} of shape {tuple(value.shape)}) " f"to be within the support ({repr(support)}) " f"of the distribution {repr(self)}, " f"but found invalid values:\n{value}" ) def _get_checked_instance(self, cls, _instance=None): if _instance is None and type(self).__init__ != cls.__init__: raise NotImplementedError( "Subclass {} of {} that defines a custom __init__ method " "must also define a custom .expand() method.".format( self.__class__.__name__, cls.__name__ ) ) return self.__new__(type(self)) if _instance is None else _instance def __repr__(self): return self.__class__.__name__ if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/distributions/utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from functools import update_wrapper from numbers import Number import oneflow as flow import oneflow.nn.functional as F from typing import Dict, Any # NOTE(Liang Depeng): modified from # https://github.com/pytorch/pytorch/blob/master/torch/distributions/utils.py euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant def logits_to_probs(logits, is_binary=False): r""" Converts a tensor of logits into probabilities. Note that for the binary case, each value denotes log odds, whereas for the multi-dimensional case, the values along the last dimension denote the log probabilities (possibly unnormalized) of the events. """ if is_binary: return flow.sigmoid(logits) return F.softmax(logits, dim=-1) def clamp_probs(probs): eps = flow.finfo(probs.dtype).eps return probs.clamp(min=eps, max=1 - eps) def probs_to_logits(probs, is_binary=False): r""" Converts a tensor of probabilities into logits. For the binary case, this denotes the probability of occurrence of the event indexed by `1`. For the multi-dimensional case, the values along the last dimension denote the probabilities of occurrence of each of the events. """ ps_clamped = clamp_probs(probs) if is_binary: return flow.log(ps_clamped) - flow.log1p(-ps_clamped) return flow.log(ps_clamped) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/env.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.env_util import api_all_device_placement as all_device_placement import oneflow._oneflow_internal def get_local_rank(): """Returns the local rank of the current machine. Local rank is not globally unique. It is only unique per process on a machine. Returns: The local rank of the process on the current machine. """ return oneflow._oneflow_internal.GetLocalRank() def get_rank(): """Returns the rank of the current process in the current process group. Rank is globally unique; its range is [0, world_size). Returns: The rank of the current process. """ return oneflow._oneflow_internal.GetRank() def get_node_size(): """Returns the number of machines in the current process group. Returns: The number of machines in the process group. """ return oneflow._oneflow_internal.GetNodeSize() def get_world_size(): """Returns the number of processes in the current process group. Returns: The world size of the process group. """ return oneflow._oneflow_internal.GetWorldSize() def init_rdma(): """ Initialize RDMA in the current environment. If the current environment supports RDMA, turning it on by calling oneflow.env.init_rdma() can speed up data transfer. Note: - Make sure to avoid using fork() after oneflow.env.init_rdma() is invoked. Otherwise, data corruption or segmentation fault may result! - Requires all devices to execute oneflow.env.init_rdma() simultaneously. Otherwise, deadlock may result! """ oneflow._oneflow_internal.InitRDMA() def rdma_is_initialized(): """Returns whether RDMA is initialized in the current environment or not. Returns: Whether RDMA is initialized or not. """ return oneflow._oneflow_internal.RDMAIsInitialized() def destory_rdma(): """Destroy RDMA in the current environment. """ return oneflow._oneflow_internal.DestoryRDMA() ================================================ FILE: python/oneflow/experimental/load_mnist.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import hashlib import os import numpy as np import requests from tqdm import tqdm def get_sha256hash(file_path, Bytes=1024): sha256hash = hashlib.sha256() with open(file_path, "rb") as f: while True: data = f.read(Bytes) if data: sha256hash.update(data) else: break ret = sha256hash.hexdigest() return ret def download_mnist_file(out_path, url): resp = requests.get(url=url, stream=True) size = int(resp.headers["Content-Length"]) / 1024 print("File size: %.4f kb, downloading..."
% size) with open(out_path, "wb") as f: for data in tqdm( iterable=resp.iter_content(1024), total=size, unit="k", desc=out_path ): f.write(data) print("Done!") def get_mnist_file(sha256, url, out_dir): path = os.path.join(out_dir, "mnist.npz") if not os.path.isfile(path): download_mnist_file(path, url) print("File mnist.npz already exists, path:", path) if not get_sha256hash(path) == sha256: checksum_fail = "sha256 verification failed, remove {0} and try again".format( path ) raise Exception(checksum_fail) return path def load_mnist( train_batch_size=100, test_batch_size=100, data_format="NCHW", url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist.npz", hash_check="63d4344077849053dc3036b247fa012b2b381de53fd055a66b539dffd76cf08e", out_dir=".", ): """Load the MNIST dataset and return images and labels. If the dataset doesn't exist locally, it is downloaded to the directory specified by out_dir. Args: train_batch_size (int, optional): batch size for training. Defaults to 100. test_batch_size (int, optional): batch size for testing or evaluation. Defaults to 100. data_format (str, optional): data format. Defaults to "NCHW". url (str, optional): url to get mnist.npz. Defaults to "https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist.npz". hash_check (str, optional): file hash value. Defaults to "63d4344077849053dc3036b247fa012b2b381de53fd055a66b539dffd76cf08e". out_dir (str, optional): dir to save the downloaded file. Defaults to ".". Returns: (train_images, train_labels), (test_images, test_labels) """ path = get_mnist_file(hash_check, url, out_dir) with np.load(path, allow_pickle=True) as f: (x_train, y_train) = (f["x_train"], f["y_train"]) (x_test, y_test) = (f["x_test"], f["y_test"]) def normalize(x, y, batch_size): x = x.astype(np.float32) / 255.0 y = y.astype(np.int32) if data_format == "NCHW": images = x.reshape((-1, batch_size, 1, x.shape[1], x.shape[2])) else: images = x.reshape((-1, batch_size, x.shape[1], x.shape[2], 1)) labels = y.reshape((-1, batch_size)) return (images, labels) (train_images, train_labels) = normalize(x_train, y_train, train_batch_size) (test_images, test_labels) = normalize(x_test, y_test, test_batch_size) return ((train_images, train_labels), (test_images, test_labels)) ================================================ FILE: python/oneflow/fft/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.tensor import Tensor import oneflow as flow def fft(input, n=None, dim=-1, norm=None) -> Tensor: r""" Computes the one dimensional discrete Fourier transform of :attr:`input`. Note: The Fourier domain representation of any real signal satisfies the Hermitian property: `X[i] = conj(X[-i])`. This function always returns both the positive and negative frequency terms even though, for real inputs, the negative frequencies are redundant. :func:`oneflow.fft.rfft` returns the more compact one-sided representation where only the positive frequencies are returned.
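For instance (a hedged sketch; the result is complex-valued, so only the call pattern is shown), a forward/backward pair with matching ``norm`` modes, as described under :attr:`norm` below, recovers the input exactly:

>>> t = oneflow.arange(4, dtype=oneflow.float32)
>>> oneflow.fft.ifft(oneflow.fft.fft(t, norm="ortho"), norm="ortho")   # equals t, as a complex tensor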
Args: input (Tensor): the input tensor n (int, optional): Signal length. If given, the input will either be zero-padded or trimmed to this length before computing the FFT. dim (int, optional): The dimension along which to take the one dimensional FFT. norm (str, optional): Normalization mode. For the forward transform (:func:`oneflow.fft.fft`), these correspond to: * ``"forward"`` - normalize by ``1/n`` * ``"backward"`` - no normalization * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) Calling the backward transform (:func:`oneflow.fft.ifft`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ifft` the exact inverse. Default is ``"backward"`` (no normalization). Example: >>> t = oneflow.arange(4) >>> t tensor([0, 1, 2, 3]) >>> oneflow.fft.fft(t) tensor([ 6+0j, -2+2j, -2+0j, -2-2j], dtype=oneflow.complex64) >>> t = oneflow.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j]) >>> oneflow.fft.fft(t) tensor([12+16j, -8+0j, -4-4j, -8j], dtype=oneflow.complex128) """ if n is None: n = -1 return flow._C.fft(input, n, dim, norm) def ifft(input, n=None, dim=-1, norm=None) -> Tensor: r""" Computes the one dimensional inverse discrete Fourier transform of :attr:`input`. Args: input (Tensor): the input tensor n (int, optional): Signal length. If given, the input will either be zero-padded or trimmed to this length before computing the IFFT. dim (int, optional): The dimension along which to take the one dimensional IFFT. norm (str, optional): Normalization mode. For the backward transform (:func:`oneflow.fft.ifft`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) Calling the forward transform (:func:`~oneflow.fft.fft`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ifft` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). Example: >>> t = oneflow.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) >>> oneflow.fft.ifft(t) tensor([0j, (1+0j), (2+0j), (3+0j)], dtype=oneflow.complex128) """ if n is None: n = -1 return flow._C.ifft(input, n, dim, norm) def fft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor: r""" Computes the 2 dimensional discrete Fourier transform of :attr:`input`. Equivalent to :func:`~oneflow.fft.fftn` but FFTs only the last two dimensions by default. Note: The Fourier domain representation of any real signal satisfies the Hermitian property: ``X[i, j] = conj(X[-i, -j])``. This function always returns all positive and negative frequency terms even though, for real inputs, half of these values are redundant. :func:`~oneflow.fft.rfft2` returns the more compact one-sided representation where only the positive frequencies of the last dimension are returned. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the FFT. If a length ``-1`` is specified, no padding is done in that dimension. Default: ``s = [input.size(d) for d in dim]`` dim (Tuple[int], optional): Dimensions to be transformed. Default: last two dimensions. norm (str, optional): Normalization mode. 
For the forward transform (:func:`oneflow.fft.fft2`), these correspond to: * ``"forward"`` - normalize by ``1/n`` * ``"backward"`` - no normalization * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) Where ``n = prod(s)`` is the logical FFT size. Calling the backward transform (:func:`oneflow.fft.ifft2`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`~oneflow.fft.ifft2` the exact inverse. Default is ``"backward"`` (no normalization). """ return flow._C.fft2(input, s, dim, norm) def ifft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor: r""" Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`. Equivalent to :func:`oneflow.fft.ifftn` but IFFTs only the last two dimensions by default. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the IFFT. If a length ``-1`` is specified, no padding is done in that dimension. Default: ``s = [input.size(d) for d in dim]`` dim (Tuple[int], optional): Dimensions to be transformed. Default: last two dimensions. norm (str, optional): Normalization mode. For the backward transform (:func:`oneflow.fft.ifft2`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) Where ``n = prod(s)`` is the logical IFFT size. Calling the forward transform (:func:`oneflow.fft.fft2`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ifft2` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). """ return flow._C.ifft2(input, s, dim, norm) def fftn(input, s=None, dim=None, norm=None) -> Tensor: r""" Computes the N dimensional discrete Fourier transform of :attr:`input`. Note: The Fourier domain representation of any real signal satisfies the Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This function always returns all positive and negative frequency terms even though, for real inputs, half of these values are redundant. :func:`oneflow.fft.rfftn` returns the more compact one-sided representation where only the positive frequencies of the last dimension are returned. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the FFT. If a length ``-1`` is specified, no padding is done in that dimension. Default: ``s = [input.size(d) for d in dim]`` dim (Tuple[int], optional): Dimensions to be transformed. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. norm (str, optional): Normalization mode. For the forward transform (:func:`oneflow.fft.fftn`), these correspond to: * ``"forward"`` - normalize by ``1/n`` * ``"backward"`` - no normalization * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) Where ``n = prod(s)`` is the logical FFT size. Calling the backward transform (:func:`oneflow.fft.ifftn`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ifftn` the exact inverse. 
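For instance (a sketch, outputs omitted), transforming only selected dimensions round-trips the same way:

>>> x = oneflow.rand(2, 3, 4)
>>> oneflow.fft.ifftn(oneflow.fft.fftn(x, dim=(1, 2)), dim=(1, 2))   # round-trips over the chosen dims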
Default is ``"backward"`` (no normalization). """ return flow._C.fftn(input, s, dim, norm) def ifftn(input, s=None, dim=None, norm=None) -> Tensor: r""" Computes the N dimensional inverse discrete Fourier transform of :attr:`input`. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the IFFT. If a length ``-1`` is specified, no padding is done in that dimension. Default: ``s = [input.size(d) for d in dim]`` dim (Tuple[int], optional): Dimensions to be transformed. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. norm (str, optional): Normalization mode. For the backward transform (:func:`oneflow.fft.ifftn`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) Where ``n = prod(s)`` is the logical IFFT size. Calling the forward transform (:func:`oneflow.fft.fftn`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ifftn` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). """ return flow._C.ifftn(input, s, dim, norm) def rfft(input, n=None, dim=-1, norm=None) -> Tensor: r""" Computes the one dimensional Fourier transform of real-valued :attr:`input`. The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so the output contains only the positive frequencies below the Nyquist frequency. To compute the full output, use :func:`oneflow.fft.fft` Args: input (Tensor): the real input tensor n (int, optional): Signal length. If given, the input will either be zero-padded or trimmed to this length before computing the real FFT. dim (int, optional): The dimension along which to take the one dimensional real FFT. norm (str, optional): Normalization mode. For the forward transform (:func:`oneflow.fft.rfft`), these correspond to: * ``"forward"`` - normalize by ``1/n`` * ``"backward"`` - no normalization * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) Calling the backward transform (:func:`oneflow.fft.irfft`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.irfft` the exact inverse. Default is ``"backward"`` (no normalization). Example: >>> t = oneflow.arange(4) >>> t tensor([0, 1, 2, 3], dtype=oneflow.int64) >>> oneflow.fft.rfft(t) tensor([ (6+0j), (-2+2j), (-2+0j)], dtype=oneflow.complex64) Compare against the full output from :func:`oneflow.fft.fft`: >>> oneflow.fft.fft(t) tensor([ (6+0j), (-2+2j), (-2+0j), (-2-2j)], dtype=oneflow.complex64) Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted. At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair, and therefore must always be real-valued. """ if n is None: n = -1 return flow._C.rfft(input, n, dim, norm) def irfft(input, n=None, dim=-1, norm=None) -> Tensor: r""" Computes the inverse of :func:`oneflow.fft.rfft`. :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier domain, as produced by :func:`oneflow.fft.rfft`. By the Hermitian property, the output will be real-valued. Note: Some input frequencies must be real-valued to satisfy the Hermitian property. In these cases the imaginary component will be ignored. 
For example, any imaginary component in the zero-frequency term cannot be represented in a real output and so will always be ignored. Note: The correct interpretation of the Hermitian input depends on the length of the original data, as given by :attr:`n`. This is because each input shape could correspond to either an odd or even length signal. By default, the signal is assumed to be even length and odd signals will not round-trip properly. So, it is recommended to always pass the signal length :attr:`n`. Args: input (Tensor): the input tensor representing a half-Hermitian signal n (int, optional): Output signal length. This determines the length of the output signal. If given, the input will either be zero-padded or trimmed to this length before computing the real IFFT. Defaults to even output: ``n=2*(input.size(dim) - 1)``. dim (int, optional): The dimension along which to take the one dimensional real IFFT. norm (str, optional): Normalization mode. For the backward transform (:func:`oneflow.fft.irfft`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) Calling the forward transform (:func:`oneflow.fft.rfft`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.irfft` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). """ if n is None: n = -1 return flow._C.irfft(input, n, dim, norm) def rfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor: r""" Computes the 2-dimensional discrete Fourier transform of real :attr:`input`. Equivalent to :func:`oneflow.fft.rfftn` but FFTs only the last two dimensions by default. The FFT of a real signal is Hermitian-symmetric, ``X[i, j] = conj(X[-i, -j])``, so the full :func:`oneflow.fft.fft2` output contains redundant information. :func:`oneflow.fft.rfft2` instead omits the negative frequencies in the last dimension. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the real FFT. If a length ``-1`` is specified, no padding is done in that dimension. Default: ``s = [input.size(d) for d in dim]`` dim (Tuple[int], optional): Dimensions to be transformed. Default: last two dimensions. norm (str, optional): Normalization mode. For the forward transform (:func:`oneflow.fft.rfft2`), these correspond to: * ``"forward"`` - normalize by ``1/n`` * ``"backward"`` - no normalization * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) Where ``n = prod(s)`` is the logical FFT size. Calling the backward transform (:func:`oneflow.fft.irfft2`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.irfft2` the exact inverse. Default is ``"backward"`` (no normalization). """ return flow._C.rfft2(input, s, dim, norm) def irfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor: r""" Computes the inverse of :func:`oneflow.fft.rfft2`. Equivalent to :func:`oneflow.fft.irfftn` but IFFTs only the last two dimensions by default. :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier domain, as produced by :func:`oneflow.fft.rfft2`. By the Hermitian property, the output will be real-valued. 
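A hedged sketch (outputs omitted): passing the original spatial size via :attr:`s` guarantees the round trip even when the last dimension has odd length:

>>> x = oneflow.rand(4, 5)
>>> oneflow.fft.irfft2(oneflow.fft.rfft2(x), s=(4, 5))   # recovers x, including the odd-length last dim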
Note: Some input frequencies must be real-valued to satisfy the Hermitian property. In these cases the imaginary component will be ignored. For example, any imaginary component in the zero-frequency term cannot be represented in a real output and so will always be ignored. Note: The correct interpretation of the Hermitian input depends on the length of the original data, as given by :attr:`s`. This is because each input shape could correspond to either an odd or even length signal. By default, the signal is assumed to be even length and odd signals will not round-trip properly. So, it is recommended to always pass the signal shape :attr:`s`. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the real FFT. If a length ``-1`` is specified, no padding is done in that dimension. Defaults to even output in the last dimension: ``s[-1] = 2*(input.size(dim[-1]) - 1)``. dim (Tuple[int], optional): Dimensions to be transformed. The last dimension must be the half-Hermitian compressed dimension. Default: last two dimensions. norm (str, optional): Normalization mode. For the backward transform (:func:`oneflow.fft.irfft2`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) Where ``n = prod(s)`` is the logical IFFT size. Calling the forward transform (:func:`oneflow.fft.rfft2`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.irfft2` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). """ return flow._C.irfft2(input, s, dim, norm) def rfftn(input, s=None, dim=None, norm=None) -> Tensor: r""" Computes the N-dimensional discrete Fourier transform of real :attr:`input`. The FFT of a real signal is Hermitian-symmetric, ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])`` so the full :func:`oneflow.fft.fftn` output contains redundant information. :func:`oneflow.fft.rfftn` instead omits the negative frequencies in the last dimension. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the real FFT. If a length ``-1`` is specified, no padding is done in that dimension. Default: ``s = [input.size(d) for d in dim]`` dim (Tuple[int], optional): Dimensions to be transformed. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. norm (str, optional): Normalization mode. For the forward transform (:func:`oneflow.fft.rfftn`), these correspond to: * ``"forward"`` - normalize by ``1/n`` * ``"backward"`` - no normalization * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) Where ``n = prod(s)`` is the logical FFT size. Calling the backward transform (:func:`oneflow.fft.irfftn`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.irfftn` the exact inverse. Default is ``"backward"`` (no normalization). """ return flow._C.rfftn(input, s, dim, norm) def irfftn(input, s=None, dim=None, norm=None) -> Tensor: r""" Computes the inverse of :func:`oneflow.fft.rfftn`. 
:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier domain, as produced by :func:`oneflow.fft.rfftn`. By the Hermitian property, the output will be real-valued. Note: Some input frequencies must be real-valued to satisfy the Hermitian property. In these cases the imaginary component will be ignored. For example, any imaginary component in the zero-frequency term cannot be represented in a real output and so will always be ignored. Note: The correct interpretation of the Hermitian input depends on the length of the original data, as given by :attr:`s`. This is because each input shape could correspond to either an odd or even length signal. By default, the signal is assumed to be even length and odd signals will not round-trip properly. So, it is recommended to always pass the signal shape :attr:`s`. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the real FFT. If a length ``-1`` is specified, no padding is done in that dimension. Defaults to even output in the last dimension: ``s[-1] = 2*(input.size(dim[-1]) - 1)``. dim (Tuple[int], optional): Dimensions to be transformed. The last dimension must be the half-Hermitian compressed dimension. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. norm (str, optional): Normalization mode. For the backward transform (:func:`oneflow.fft.irfftn`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) Where ``n = prod(s)`` is the logical IFFT size. Calling the forward transform (:func:`oneflow.fft.rfftn`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.irfftn` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). """ return flow._C.irfftn(input, s, dim, norm) def hfft(input, n=None, dim=-1, norm=None) -> Tensor: r""" hfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor Computes the one dimensional discrete Fourier transform of a Hermitian symmetric :attr:`input` signal. Note: :func:`oneflow.fft.hfft`/:func:`oneflow.fft.ihfft` are analogous to :func:`oneflow.fft.rfft`/:func:`oneflow.fft.irfft`. The real FFT expects a real signal in the time-domain and gives a Hermitian symmetry in the frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in the time-domain and real-valued in the frequency-domain. For this reason, special care needs to be taken with the length argument :attr:`n`, in the same way as with :func:`oneflow.fft.irfft`. Note: Because the signal is Hermitian in the time-domain, the result will be real in the frequency domain. Note that some input frequencies must be real-valued to satisfy the Hermitian property. In these cases the imaginary component will be ignored. For example, any imaginary component in ``input[0]`` would result in one or more complex frequency terms which cannot be represented in a real output and so will always be ignored. Note: The correct interpretation of the Hermitian input depends on the length of the original data, as given by :attr:`n`. This is because each input shape could correspond to either an odd or even length signal. By default, the signal is assumed to be even length and odd signals will not round-trip properly. 
So, it is recommended to always pass the signal length :attr:`n`. Args: input (Tensor): the input tensor representing a half-Hermitian signal n (int, optional): Output signal length. This determines the length of the real output. If given, the input will either be zero-padded or trimmed to this length before computing the Hermitian FFT. Defaults to even output: ``n=2*(input.size(dim) - 1)``. dim (int, optional): The dimension along which to take the one dimensional Hermitian FFT. norm (str, optional): Normalization mode. For the forward transform (:func:`oneflow.fft.hfft`), these correspond to: * ``"forward"`` - normalize by ``1/n`` * ``"backward"`` - no normalization * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) Calling the backward transform (:func:`oneflow.fft.ihfft`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ihfft` the exact inverse. Default is ``"backward"`` (no normalization). Example: Taking a real-valued frequency signal and bringing it into the time domain gives Hermitian symmetric output: >>> t = oneflow.linspace(0, 1, 5) >>> t tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000], dtype=oneflow.float32) >>> T = oneflow.fft.ifft(t) >>> T tensor([ (0.5000-0.0000j), (-0.1250-0.1720j), (-0.1250-0.0406j), (-0.1250+0.0406j), (-0.1250+0.1720j)], dtype=oneflow.complex64) Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()`` is redundant. We can thus compute the forward transform without considering negative frequencies: >>> oneflow.fft.hfft(T[:3], n=5) tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000], dtype=oneflow.float32) Like with :func:`oneflow.fft.irfft`, the output length must be given in order to recover an even length output: >>> oneflow.fft.hfft(T[:3]) tensor([0.1250, 0.2809, 0.6250, 0.9691], dtype=oneflow.float32) """ if n is None: n = -1 return flow._C.hfft(input, n, dim, norm) def ihfft(input, n=None, dim=-1, norm=None) -> Tensor: r""" Computes the inverse of :func:`oneflow.fft.hfft`. :attr:`input` must be a real-valued signal, interpreted in the Fourier domain. The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``. :func:`oneflow.fft.ihfft` represents this in the one-sided form where only the positive frequencies below the Nyquist frequency are included. To compute the full output, use :func:`oneflow.fft.ifft`. Args: input (Tensor): the real input tensor n (int, optional): Signal length. If given, the input will either be zero-padded or trimmed to this length before computing the Hermitian IFFT. dim (int, optional): The dimension along which to take the one dimensional Hermitian IFFT. norm (str, optional): Normalization mode. For the backward transform (:func:`oneflow.fft.ihfft`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) Calling the forward transform (:func:`oneflow.fft.hfft`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ihfft` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). 
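    For instance, with ``norm="ortho"`` the forward/backward pair is unitary and
    round-trips up to floating-point error (a brief sketch; the reference example
    follows below):

        >>> x = oneflow.rand(8)
        >>> X = oneflow.fft.ihfft(x, norm="ortho")
        >>> oneflow.allclose(oneflow.fft.hfft(X, n=8, norm="ortho"), x, atol=1e-5)
        True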
    Example:

        >>> t = oneflow.arange(5)
        >>> t
        tensor([0, 1, 2, 3, 4], dtype=oneflow.int64)
        >>> oneflow.fft.ihfft(t)
        tensor([ (2.0000-0.0000j), (-0.5000-0.6882j), (-0.5000-0.1625j)], dtype=oneflow.complex64)

        Compare against the full output from :func:`oneflow.fft.ifft`:

        >>> oneflow.fft.ifft(t)
        tensor([ (2.0000-0.0000j), (-0.5000-0.6882j), (-0.5000-0.1625j), (-0.5000+0.1625j), (-0.5000+0.6882j)], dtype=oneflow.complex64)
    """
    if n is None:
        n = -1
    return flow._C.ihfft(input, n, dim, norm)


def hfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:
    r"""
    Computes the 2-dimensional discrete Fourier transform of a Hermitian symmetric
    :attr:`input` signal. Equivalent to :func:`oneflow.fft.hfftn` but only transforms
    the last two dimensions by default.

    :attr:`input` is interpreted as a one-sided Hermitian signal in the time domain.
    By the Hermitian property, the Fourier transform will be real-valued.

    Args:
        input (Tensor): the input tensor
        s (Tuple[int], optional): Signal size in the transformed dimensions.
            If given, each dimension ``dim[i]`` will either be zero-padded or trimmed
            to the length ``s[i]`` before computing the Hermitian FFT.
            If a length ``-1`` is specified, no padding is done in that dimension.
            Defaults to even output in the last dimension:
            ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
        dim (Tuple[int], optional): Dimensions to be transformed.
            The last dimension must be the half-Hermitian compressed dimension.
            Default: last two dimensions.
        norm (str, optional): Normalization mode. For the forward transform
            (:func:`oneflow.fft.hfft2`), these correspond to:

            * ``"forward"`` - normalize by ``1/n``
            * ``"backward"`` - no normalization
            * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)

            Where ``n = prod(s)`` is the logical FFT size.
            Calling the backward transform (:func:`oneflow.fft.ihfft2`) with the same
            normalization mode will apply an overall normalization of ``1/n`` between
            the two transforms. This is required to make :func:`oneflow.fft.ihfft2`
            the exact inverse.

            Default is ``"backward"`` (no normalization).

    Example:

        Starting from a real frequency-space signal, we can generate a
        Hermitian-symmetric time-domain signal:

        >>> T = oneflow.rand(10, 9)
        >>> t = oneflow.fft.ihfft2(T)

        Without specifying the output length to :func:`oneflow.fft.hfft2`, the output
        will not round-trip properly because the input is odd-length in the last
        dimension:

        >>> oneflow.fft.hfft2(t).size()
        oneflow.Size([10, 10])

        So, it is recommended to always pass the signal shape :attr:`s`.

        >>> roundtrip = oneflow.fft.hfft2(t, T.size())
        >>> roundtrip.size()
        oneflow.Size([10, 9])
        >>> oneflow.allclose(roundtrip, T)
        True
    """
    return flow._C.hfft2(input, s, dim, norm)


def ihfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:
    r"""
    Computes the 2-dimensional inverse discrete Fourier transform of real
    :attr:`input`. Equivalent to :func:`oneflow.fft.ihfftn` but transforms only the
    last two dimensions by default.

    Args:
        input (Tensor): the input tensor
        s (Tuple[int], optional): Signal size in the transformed dimensions.
            If given, each dimension ``dim[i]`` will either be zero-padded or trimmed
            to the length ``s[i]`` before computing the Hermitian IFFT.
            If a length ``-1`` is specified, no padding is done in that dimension.
            Default: ``s = [input.size(d) for d in dim]``
        dim (Tuple[int], optional): Dimensions to be transformed.
            Default: last two dimensions.
        norm (str, optional): Normalization mode.
For the backward transform (:func:`oneflow.fft.ihfft2`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal) Where ``n = prod(s)`` is the logical IFFT size. Calling the forward transform (:func:`oneflow.fft.hfft2`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ihfft2` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). """ return flow._C.ihfft2(input, s, dim, norm) def hfftn(input, s=None, dim=None, norm=None) -> Tensor: r""" Computes the n-dimensional discrete Fourier transform of a Hermitian symmetric :attr:`input` signal. :attr:`input` is interpreted as a one-sided Hermitian signal in the time domain. By the Hermitian property, the Fourier transform will be real-valued. Note: :func:`oneflow.fft.hfftn`/:func:`oneflow.fft.ihfftn` are analogous to :func:`oneflow.fft.rfftn`/:func:`oneflow.fft.irfftn`. The real FFT expects a real signal in the time-domain and gives Hermitian symmetry in the frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in the time-domain and real-valued in the frequency-domain. For this reason, special care needs to be taken with the shape argument :attr:`s`, in the same way as with :func:`oneflow.fft.irfftn`. Note: Some input frequencies must be real-valued to satisfy the Hermitian property. In these cases the imaginary component will be ignored. For example, any imaginary component in the zero-frequency term cannot be represented in a real output and so will always be ignored. Note: The correct interpretation of the Hermitian input depends on the length of the original data, as given by :attr:`s`. This is because each input shape could correspond to either an odd or even length signal. By default, the signal is assumed to be even length and odd signals will not round-trip properly. It is recommended to always pass the signal shape :attr:`s`. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the real FFT. If a length ``-1`` is specified, no padding is done in that dimension. Defaults to even output in the last dimension: ``s[-1] = 2*(input.size(dim[-1]) - 1)``. dim (Tuple[int], optional): Dimensions to be transformed. The last dimension must be the half-Hermitian compressed dimension. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. norm (str, optional): Normalization mode. For the forward transform (:func:`oneflow.fft.hfftn`), these correspond to: * ``"forward"`` - normalize by ``1/n`` * ``"backward"`` - no normalization * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) Where ``n = prod(s)`` is the logical FFT size. Calling the backward transform (:func:`oneflow.fft.ihfftn`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ihfftn` the exact inverse. Default is ``"backward"`` (no normalization). """ return flow._C.hfftn(input, s, dim, norm) def ihfftn(input, s=None, dim=None, norm=None) -> Tensor: r""" Computes the N-dimensional inverse discrete Fourier transform of real :attr:`input`. :attr:`input` must be a real-valued signal, interpreted in the Fourier domain. 
The n-dimensional IFFT of a real signal is Hermitian-symmetric, ``X[i, j, ...] = conj(X[-i, -j, ...])``. :func:`oneflow.fft.ihfftn` represents this in the one-sided form where only the positive frequencies below the Nyquist frequency are included in the last signal dimension. To compute the full output, use :func:`oneflow.fft.ifftn`. Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension ``dim[i]`` will either be zero-padded or trimmed to the length ``s[i]`` before computing the Hermitian IFFT. If a length ``-1`` is specified, no padding is done in that dimension. Default: ``s = [input.size(d) for d in dim]`` dim (Tuple[int], optional): Dimensions to be transformed. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. norm (str, optional): Normalization mode. For the backward transform (:func:`oneflow.fft.ihfftn`), these correspond to: * ``"forward"`` - no normalization * ``"backward"`` - normalize by ``1/n`` * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal) Where ``n = prod(s)`` is the logical IFFT size. Calling the forward transform (:func:`oneflow.fft.hfftn`) with the same normalization mode will apply an overall normalization of ``1/n`` between the two transforms. This is required to make :func:`oneflow.fft.ihfftn` the exact inverse. Default is ``"backward"`` (normalize by ``1/n``). """ return flow._C.ihfftn(input, s, dim, norm) ================================================ FILE: python/oneflow/framework/__init__.py ================================================ ================================================ FILE: python/oneflow/framework/args_tree.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Union, List, Tuple, Dict, Callable from collections import OrderedDict from oneflow.framework.tensor import Tensor def _is_raw_type(value, raw_type): # Special case for namedtuple return types # For example, max(x, dim=1) return oneflow.return_types.max(values=..., indices=...) if ( raw_type == tuple and isinstance(value, tuple) and type(value).__module__ == "oneflow.return_types" ): return True return type(value) is raw_type class NamedArg(object): r""" The class for wrapping over the input/output argument and associating each input/output argument with a prefix and name. The input/output argument can be viewed as a tree. NamedArg basically wraps over each tree node on this tree. 
The recursive structure of the input/output arguments are kept, for example: input = [1, {key: "value" }] will be constructed into: named_input = NamedArg([NamedArg(1), NamedArg({key: NamedArg("value")})]) """ def __init__( self, prefix="", name=None, global_index=0, tensor_type=Tensor ) -> None: self._name = name if name is not None else str(global_index) self._prefix = prefix self._global_index = global_index self._is_value_set = False self._value = None self._tensor_type = tensor_type def prefix(self): return self._prefix def name(self): return self._name def global_index(self): return self._global_index def value(self): assert self._is_value_set, "self._value is not set yet" return self._value def is_leaf(self): assert self._is_value_set, "self._value is not set yet" return not ( _is_raw_type(self._value, dict) or _is_raw_type(self._value, OrderedDict) or _is_raw_type(self._value, tuple) or _is_raw_type(self._value, list) ) def set_value(self, value): assert not _is_raw_type(value, NamedArg), "cannot accept value of type NamedArg" self._value = value self._is_value_set = True def __repr__(self): repr_str = "" repr_str += "(name: " + self._name repr_str += ", idx: " + str(self._global_index) repr_str += ", type: " if _is_raw_type(self._value, tuple): repr_str += "TUPLE" elif _is_raw_type(self._value, list): repr_str += "LIST" elif _is_raw_type(self._value, dict) or _is_raw_type(self._value, OrderedDict): repr_str += "DICT" elif isinstance(self._value, self._tensor_type): repr_str += "TENSOR" elif self._value is None: repr_str += "NONE" else: repr_str += "OPAQUE" if isinstance(self._value, self._tensor_type): repr_str += ( ", value: tensor(" + str(self._value.shape) + ", " + str(self._value.dtype) + ")" ) elif ( _is_raw_type(self._value, dict) or _is_raw_type(self._value, OrderedDict) or _is_raw_type(self._value, list) or _is_raw_type(self._value, tuple) ): repr_str += ", value: " + repr(self._value) else: repr_str += ", value: " + repr(self._value) repr_str += ")" return repr_str class ArgsTree(object): def __init__( self, io_args: Union[Tuple, List, Dict], gen_name: bool = False, root_prefix: str = "", root_name: str = None, tensor_type=Tensor, ) -> None: self._io_args = io_args self._gen_name = gen_name self._root_prefix = root_prefix self._root_name = root_name self._named_io_args = None self._next_global_index = 0 self._tensor_type = tensor_type if self._gen_name: self._named_io_args = self._construct_named_io_args( self._io_args, self._root_prefix, self._root_name ) def gen_name(self): return self._gen_name def iter_nodes(self): r""" return a generator of the args tree nodes in the DFS manner. The node returned can be of type NamedArg or non-NamedArg depending on whether gen_name is set. If gen_name is set, the node will be NamedArg. """ if self._gen_name: args_to_iter = self._named_io_args else: args_to_iter = self._io_args # NOTE(lixiang): Generator expression and iterator are used. # This avoids generating the full list in memory and only processes the nodes that need to be processed, # reducing time and space consumption. 
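        # Implementation note: an explicit stack of iterators replaces recursion here.
        # The iterator on top of the stack yields the next sibling; when that node is
        # a container (list/tuple/dict), an iterator over its children is pushed so
        # they are visited next. Every node (containers and leaves alike) is yielded
        # in pre-order DFS, and an exhausted iterator is popped on StopIteration.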
stack = [iter([args_to_iter])] while len(stack) > 0: try: curr = next(stack[-1]) if _is_raw_type(curr, NamedArg): curr_value = curr.value() else: curr_value = curr if _is_raw_type(curr_value, list) or _is_raw_type(curr_value, tuple): children = curr_value elif _is_raw_type(curr_value, dict) or _is_raw_type( curr_value, OrderedDict ): children = curr_value.values() else: children = None if children: stack.append(iter(children)) yield curr except StopIteration: stack.pop() def iter_named_nodes(self): assert self._gen_name, "Only use this if gen_name is set!" for named_node in self.iter_nodes(): yield (named_node.prefix() + "_" + named_node.name(), named_node) def _construct_named_io_args(self, value, prefix: str, name: str) -> NamedArg: arg = NamedArg(prefix, name, self._next_global_index, self._tensor_type) self._next_global_index += 1 if _is_raw_type(value, list) or _is_raw_type(value, tuple): def construct_func(enum): (i, v) = enum next_prefix = prefix + ("." if prefix else "") + str(i) new_arg = self._construct_named_io_args(v, next_prefix, None) return new_arg arg.set_value(value.__class__(map(construct_func, enumerate(value)))) elif _is_raw_type(value, dict) or _is_raw_type(value, OrderedDict): def construct_func(enum): i, (key, v) = enum next_prefix = prefix + ("." if prefix else "") + str(i) new_arg = self._construct_named_io_args(v, next_prefix, key) return key, new_arg arg.set_value( value.__class__(map(construct_func, enumerate(value.items()))) ) else: arg.set_value(value) return arg def map_tuple_leaf(self, map_function: Callable): r""" When the type of io args is tuple or list, map the leaf of the arguments into map_function(leaf). """ assert map_function != None, "map function cannot be None" assert isinstance( self._io_args, (tuple, list) ), "only used when io args is a tuple or list of tensors" stack = [] # Cases handled: tuple(tensor, ...), such as input args. if len(self._io_args) > 0 and isinstance(self._io_args[0], self._tensor_type): for i in self._io_args: mapped_value = map_function(i) stack.append(mapped_value) if isinstance(self._io_args, tuple): return tuple(stack) elif isinstance(self._io_args, list): return stack # Cases handled: tuple(tuple(tensor, ...), ), such as the output args of return. elif ( len(self._io_args) > 0 and isinstance(self._io_args[0], (tuple, list)) and all(isinstance(arg, self._tensor_type) for arg in self._io_args[0]) ): for i in self._io_args[0]: mapped_value = map_function(i) stack.append(mapped_value) if isinstance(self._io_args[0], tuple): return (tuple(stack),) elif isinstance(self._io_args[0], list): return (stack,) # Other cases. # Do not loop optimize, and continue to execute the recursive code (`_execute_mapping`). else: return self._execute_mapping(self._io_args, map_function) def map_leaf(self, map_function: Callable): r""" Map the leaf of the arguments into map_function(leaf). 
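        A minimal usage sketch (assuming ``oneflow`` is imported in the calling scope
        and ``gen_name`` is not set, so the raw leaf values are passed to
        ``map_function``; with ``gen_name=True`` the leaves arrive wrapped as
        ``NamedArg`` instead)::

            tree = ArgsTree(([oneflow.ones(2)], {"w": 3}))
            tree.map_leaf(lambda x: x.shape if isinstance(x, Tensor) else x)
            # -> ([oneflow.Size([2])], {"w": 3})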
""" assert map_function != None, "map function cannot be None" if self._gen_name: args_to_map = self._named_io_args else: args_to_map = self._io_args return self._execute_mapping(args_to_map, map_function) def _execute_mapping(self, value, map_function): if _is_raw_type(value, tuple) or _is_raw_type(value, list): mapped_value = value.__class__( map(lambda x: self._execute_mapping(x, map_function), value) ) elif _is_raw_type(value, dict) or _is_raw_type(value, OrderedDict): mapped_value = value.__class__( map( lambda x: (x[0], self._execute_mapping(x[1], map_function)), value.items(), ) ) elif _is_raw_type(value, NamedArg): if value.is_leaf(): # only map the leaf: TENSOR/NONE/OPAQUE mapped_value = map_function(value) else: mapped_value = self._execute_mapping(value.value(), map_function) else: mapped_value = map_function(value) return mapped_value def __repr__(self): if self._named_io_args: return self._named_io_args.__repr__() else: return str(self.__class__) ================================================ FILE: python/oneflow/framework/attr_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r""" Get the nested attribute given the owning object and attribute chain. For example, if we want to get `resource.collective_boxing_conf.nccl_num_streams` we can call `get_nested_attribute(resource, ["collective_boxing_conf", "nccl_num_streams"]) """ def get_nested_attribute(owning_object, attrs_chain): if not isinstance(attrs_chain, list): if isinstance(attrs_chain, str): attrs_chain = [attrs_chain] else: assert False, ( "attrs_chain should be either a string or a list, but get " + str(type(attrs_chain)) ) last_attr = owning_object for att in attrs_chain: assert hasattr(last_attr, att), ( repr(last_attr) + " does not have attribute " + att + " !" ) last_attr = getattr(last_attr, att) return last_attr def SetProtoAttrValue(attr_value, py_value, default_attr_value): if default_attr_value.HasField("at_bool"): if py_value is None: py_value = True assert type(py_value) is bool attr_value.at_bool = py_value elif default_attr_value.HasField("at_int64"): assert type(py_value) is int attr_value.at_int64 = py_value elif default_attr_value.HasField("at_double"): assert type(py_value) is float attr_value.at_double = py_value elif default_attr_value.HasField("at_string"): assert type(py_value) is str attr_value.at_string = py_value else: raise ValueError( "config with type %s is invalid. supported types: [bool, int, float, str]" % type(py_value) ) ================================================ FILE: python/oneflow/framework/balanced_splitter.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ def BalancedPartNums(total, part_size): base = int(total / part_size) remainder = total % part_size return [base + int(i < remainder) for i in range(part_size)] def BalancedRanges(total, part_size): balanced_part_nums = BalancedPartNums(total, part_size) ranges = [] start = 0 for part_num in balanced_part_nums: end = start + part_num ranges.append((start, end)) start = end return ranges ================================================ FILE: python/oneflow/framework/c_api_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from google.protobuf import text_format import oneflow import oneflow.core.common.data_type_pb2 as dtype_util import oneflow.core.common.error_pb2 as error_util import oneflow.core.job.env_pb2 as env_pb2 import oneflow.core.job.job_pb2 as job_pb import oneflow.core.job.job_conf_pb2 as job_conf_pb import oneflow.core.job.job_set_pb2 as job_set_pb import oneflow.core.job.placement_pb2 as placement_pb import oneflow.core.job.resource_pb2 as resource_util import oneflow.core.operator.op_attribute_pb2 as op_attribute_pb import oneflow.core.operator.op_conf_pb2 as op_conf_util import oneflow.core.record.record_pb2 as record_util import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util from oneflow.core.framework.config_def_pb2 import ConfigDef from oneflow.core.job.inter_user_job_info_pb2 import InterUserJobInfo def CurrentResource(): resource = oneflow._oneflow_internal.CurrentResource() return text_format.Parse(resource, resource_util.Resource()) def EnvResource(): resource = oneflow._oneflow_internal.EnvResource() return text_format.Parse(resource, resource_util.Resource()) def GetEnvContext(env_proto): assert type(env_proto) is env_pb2.EnvProto env_proto_str = text_format.MessageToString(env_proto) env_ctx = oneflow._oneflow_internal.EnvContext(env_proto_str) return env_ctx def JobBuildAndInferCtx_Open(job_name): job_name = str(job_name) oneflow._oneflow_internal.JobBuildAndInferCtx_Open(job_name) def CurJobBuildAndInferCtx_SetJobConf(job_config_proto): assert type(job_config_proto) is job_conf_pb.JobConfigProto, type(job_config_proto) job_config_proto_str = text_format.MessageToString(job_config_proto) oneflow._oneflow_internal.CurJobBuildAndInferCtx_SetJobConf(job_config_proto_str) def InferOpConf(op_conf_proto, upstream_signature): serialized_op_conf = str(text_format.MessageToString(op_conf_proto)) serialized_upstream_sig = str(text_format.MessageToString(upstream_signature)) op_attribute_str = oneflow._oneflow_internal.InferOpConf( serialized_op_conf, serialized_upstream_sig ) return 
text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute()) def IsInterfaceOpConf(op_conf): op_type_field = op_conf.WhichOneof("op_type") field_number = op_conf_util.OperatorConf.DESCRIPTOR.fields_by_name[ op_type_field ].number return oneflow._oneflow_internal.IsInterfaceOpTypeCase(field_number) def GetOpParallelSymbolId(op_conf_proto): serialized_op_conf = str(text_format.MessageToString(op_conf_proto)) return oneflow._oneflow_internal.GetOpParallelSymbolId(serialized_op_conf) def CheckAndCompleteUserOpConf(op_conf_proto): serialized_op_conf = str(text_format.MessageToString(op_conf_proto)) new_op_conf = oneflow._oneflow_internal.CheckAndCompleteUserOpConf( serialized_op_conf ) return text_format.Parse(new_op_conf, op_conf_util.OperatorConf()) def GetFunctionConfigDef(): func_config_def = oneflow._oneflow_internal.GetFunctionConfigDef() return text_format.Parse(func_config_def, ConfigDef()) def GetScopeConfigDef(): scope_config_def = oneflow._oneflow_internal.GetScopeConfigDef() return text_format.Parse(scope_config_def, ConfigDef()) def GetCurrentJob(): serialized_job = oneflow._oneflow_internal.GetSerializedCurrentJob() ret = job_pb.Job() ret.ParseFromString(serialized_job) return ret ================================================ FILE: python/oneflow/framework/check_point_v2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import contextlib import os import warnings from typing import ( Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, IO, BinaryIO, ) from pathlib import Path import pickle import json from collections import OrderedDict import io import numpy as np from google.protobuf import text_format import oneflow import oneflow as flow import oneflow._oneflow_internal import oneflow.core.framework.variable_meta_info_pb2 as variable_meta_info_pb import oneflow.framework.dtype as dtype_util import oneflow.framework.id_util as id_util from oneflow.framework.tensor import Tensor import oneflow.nn.graph.graph as graph_util from oneflow.framework.args_tree import ArgsTree import pickle from oneflow.nn.graph import GraphTensor SNAPSHOT_DONE_FILENAME = "snapshot_done" META_INFO_FILENAME = "meta" PICKLE_FILENAME = "pickled_data" DATA_FILENAME = "out" PROTOCOL_VERSION = 1 ONEFLOW_MAGIC_KEY = "__oneflow__" MAP_LOCATION = Optional[ Union[Callable[[Tensor, str], Tensor], flow.device, str, flow.placement] ] FILE_LIKE = Union[os.PathLike, BinaryIO, IO[bytes], Path] class _opener(object): def __init__(self, file_like): self.file_like = file_like def __enter__(self): return self.file_like def __exit__(self, *args): pass class _open_file(_opener): def __init__(self, path, mode): super(_open_file, self).__init__(open(path, mode)) def __exit__(self, *args): self.file_like.close() class _open_buffer_reader(_opener): def __init__(self, buffer): super(_open_buffer_reader, self).__init__(buffer) _check_seekable(buffer) class _open_buffer_writer(_opener): def __exit__(self, *args): self.file_like.flush() def _open_file_like(path_or_buffer, mode): if _is_path(path_or_buffer): return _open_file(path_or_buffer, mode) else: if "w" in mode: return _open_buffer_writer(path_or_buffer) elif "r" in mode: return _open_buffer_reader(path_or_buffer) else: raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") def _is_path(path_or_buffer): return isinstance(path_or_buffer, Path) def _check_seekable(f) -> bool: def raise_err_msg(patterns, e): for p in patterns: if p in str(e): msg = ( str(e) + ". You can only oneflow.load from a file that is seekable." + " Please pre-load the data into a buffer like io.BytesIO and" + " try to load from it instead." 
) raise type(e)(msg) raise e try: f.seek(f.tell()) return True except (io.UnsupportedOperation, AttributeError) as e: raise_err_msg(["seek", "tell"], e) return False class FileBackendVariableBlob: def __init__( self, var_dir: str, dtype: Optional[oneflow.dtype] = None, shape: Optional[Sequence[int]] = None, ): data_path = os.path.join(var_dir, DATA_FILENAME) if not os.path.isfile(data_path): raise FileNotFoundError() self.var_dir_ = var_dir meta_info_path = os.path.join(self.var_dir_, META_INFO_FILENAME) if os.path.exists(meta_info_path): meta_info = variable_meta_info_pb.VariableMetaInfo() with open(meta_info_path) as f: text_format.Parse(f.read(), meta_info) self.has_meta_info_ = True else: self.has_meta_info_ = False if self.has_meta_info_: assert dtype is None and shape is None self.shape_ = tuple(meta_info.shape.dim) self.dtype_ = dtype_util.convert_proto_dtype_to_oneflow_dtype( meta_info.data_type ) elif shape is not None and dtype is not None: self.shape_ = shape self.dtype_ = dtype self.has_meta_info_ = True elif shape is not None or dtype is not None: raise RuntimeError("both or neither of shape and dtype should be None") else: pass if self.has_meta_info_: itemsize = np.dtype( dtype_util.convert_oneflow_dtype_to_numpy_dtype(self.dtype_) ).itemsize assert os.path.getsize(data_path) == np.prod(self.shape).item() * itemsize @property def file_path(self) -> str: return os.path.join(self.var_dir_, DATA_FILENAME) @property def shape(self) -> Tuple[int]: return self.shape_ @property def quant_info(self): raise NotImplementedError() @property def dtype(self) -> oneflow.dtype: return self.dtype_ def numpy(self) -> np.ndarray: if not self.has_meta_info_: raise RuntimeError("This variable does not have meta info") return np.fromfile( self.file_path, dtype=dtype_util.convert_oneflow_dtype_to_numpy_dtype(self.dtype), ).reshape(self.shape) def _save_tensor_to_disk(tensor: "oneflow.Tensor", dir_name: Union[str, Path]) -> None: os.makedirs(dir_name, exist_ok=True) meta_info = variable_meta_info_pb.VariableMetaInfo() meta_info.shape.dim[:] = tensor.shape meta_info.data_type = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype( tensor.dtype ) data_path = os.path.join(dir_name, DATA_FILENAME) with open(data_path, "wb") as f: f.write(tensor.numpy().tobytes()) with open(os.path.join(dir_name, META_INFO_FILENAME), "w") as f: f.write(text_format.MessageToString(meta_info)) ValueContainer = Union[FileBackendVariableBlob, np.ndarray, "oneflow.Tensor"] def _default_restore_location(storage, location=None): return storage def _get_restore_location(map_location): if map_location is None: restore_location = _default_restore_location elif isinstance(map_location, (str, flow.device)): def restore_location(storage, location=None): return storage.to(device=map_location) elif isinstance(map_location, flow.placement): def restore_location(storage, location=None): return storage.to_global(placement=map_location) else: def restore_location(storage, location=None): result = map_location(storage, location) if result is None: result = _default_restore_location(storage, location) return result return restore_location def smart_to(obj: Any, dest: MAP_LOCATION) -> "oneflow.Tensor": if not isinstance(obj, flow.Tensor): return obj tensor = obj restore_location = _get_restore_location(dest) return restore_location(tensor, None) def module_to(obj: flow.nn.Module, dest: MAP_LOCATION) -> "oneflow.nn.Module": restore_location = _get_restore_location(dest) # for nn.Module object, we will use a tensor to get the device # to 
support dest with a Callable type device = restore_location(flow.tensor([0])).device obj.to(device) return obj def _map_location(obj: Any, map_location: MAP_LOCATION): if isinstance(obj, flow.nn.Module): return module_to(obj, map_location) else: res = ArgsTree(obj).map_leaf(lambda x: smart_to(x, map_location)) return res def _LoadSingleVariable( path: Optional[str], global_src_rank: Optional[int] = None, map_location: MAP_LOCATION = None, ) -> "flow.Tensor": if global_src_rank is not None: rank = flow.env.get_rank() if rank == global_src_rank: file_backed_blob = FileBackendVariableBlob(path) loaded = flow.tensor(file_backed_blob.numpy(), dtype=file_backed_blob.dtype) else: loaded = flow.tensor([]) loaded = loaded.to_global( flow.placement("cpu", [global_src_rank]), flow.sbp.broadcast ) else: loaded = flow.tensor(FileBackendVariableBlob(path).numpy()) return smart_to(loaded, map_location) def _broadcast_py_object(obj, src: int = 0): rank = flow.env.get_rank() if src == rank: obj_bytes = pickle.dumps(obj) return pickle.loads(flow._oneflow_internal.cpu_broadcast(obj_bytes, src)) else: return pickle.loads(flow._oneflow_internal.cpu_broadcast(None, src)) # NOTE(jianhao): # (de)serializing a container of global tensors requires the order # of those tensors are the same across all ranks. def tensor_getstate(self): # context_data is not None means setstate/getstate is called inside # flow.save or flow.load if context_data is not None: if context_data.global_rank is None: assert ( self.is_local ), "Please set global_dst_rank in `flow.save` to save global tensor" tensor = self else: assert not self.is_local # Boxing to cpu firstly to avoid extra gpu memory usage tensor = ( self.to_global( sbp=self.sbp, placement=flow.placement("cpu", self.placement.ranks) ) .to_global( sbp=flow.sbp.broadcast, placement=flow.placement("cpu", [context_data.global_rank]), ) .to_local() ) if context_data.save_as_external_data: if context_data.global_rank is None: rel_dir_name = id_util.UniqueStr("tensor_") else: rel_dir_name = f"global_tensor_{self.global_id()}" abs_dir_name = context_data.path / rel_dir_name if ( context_data.global_rank is None or context_data.global_rank == flow.env.get_rank() ): _save_tensor_to_disk(tensor, abs_dir_name) return {"path": rel_dir_name} else: return { "data": tensor.numpy(), "dtype": tensor.dtype, "device": "cpu", } else: if self.is_local: if self.is_cuda: device = "cuda" else: device = "cpu" return {"data": self.numpy(), "dtype": self.dtype, "device": device} else: return { "data": self.numpy(), "dtype": self.dtype, "placement": self.placement, "sbp": self.sbp, } def tensor_setstate(self, pickle_dict): if context_data is not None: if context_data.save_as_external_data: rel_dir_name = pickle_dict["path"] abs_dir_name = context_data.path / rel_dir_name tmp_tensor = _LoadSingleVariable( str(abs_dir_name), context_data.global_rank, context_data.map_location ) self.__init__(tmp_tensor) else: self.__init__(flow.tensor(pickle_dict["data"], dtype=pickle_dict["dtype"])) else: if "placement" in pickle_dict: return self.__init__( flow.tensor( pickle_dict["data"], dtype=pickle_dict["dtype"], placement=pickle_dict["placement"], sbp=pickle_dict["sbp"], ) ) else: return self.__init__( flow.tensor( pickle_dict["data"], dtype=pickle_dict["dtype"], device=pickle_dict["device"], ) ) def placement_getstate(self): return { "type": self.type, "ranks": self.ranks, } def placement_setstate(self, state): return self.__init__(state["type"], state["ranks"]) def RegisterMethods(): Tensor.__setstate__ = 
tensor_setstate Tensor.__getstate__ = tensor_getstate flow._oneflow_internal.placement.__getstate__ = placement_getstate flow._oneflow_internal.placement.__setstate__ = placement_setstate load_methods = [] def load_if(condition): def decorator(func): def condition_always_returning_extra_data(*args, **kwargs): res = condition(*args, **kwargs) if isinstance(res, tuple): assert len(res) == 2 assert isinstance(res[1], tuple) return res else: return res, () load_methods.append((condition_always_returning_extra_data, func)) return func return decorator def is_dir_and_no_pickle_file(path: FILE_LIKE, support_pytorch_format: bool): if _is_path(path) and path.is_dir(): pickle_path = path / PICKLE_FILENAME return not pickle_path.exists() return False @load_if(is_dir_and_no_pickle_file) def legacy_load( path: Path, global_src_rank: Optional[int], map_location: MAP_LOCATION, ) -> Dict[str, "flow.Tensor"]: assert os.path.isdir(path), "Directory {} doesn't exist!".format(path) rank = flow.env.get_rank() var_dict = {} if global_src_rank is None or rank == global_src_rank: all_files = os.listdir(path) assert SNAPSHOT_DONE_FILENAME in all_files all_files.remove(SNAPSHOT_DONE_FILENAME) if global_src_rank is not None: _broadcast_py_object(all_files, global_src_rank) else: all_files = _broadcast_py_object(None, global_src_rank) for f in all_files: var_dir = os.path.join(path, f) try: var_dict[f] = _LoadSingleVariable(var_dir, global_src_rank, map_location) except FileNotFoundError: warnings.warn( f"'{var_dir}' does not have valid tensor data. Please check it if it is unexpected.", stacklevel=2, ) return var_dict @contextlib.contextmanager def tensor_pickling_context( path: Path, global_rank: Optional[int], mp: MAP_LOCATION, save_as_external_data: bool, ): global context_data context_data = ContextData(path, global_rank, mp, save_as_external_data) try: yield finally: context_data = None def is_oneflow_pickle_file(path: FILE_LIKE, support_pytorch_format: bool) -> bool: if _is_path(path) and not path.is_file(): return False try: with _open_file_like(path, "rb") as f: content = pickle.load(f) if ONEFLOW_MAGIC_KEY in content: return True, (content,) else: return False except: return False # `path` is not used in this function, because the file is already loaded # and deserialized in `is_oneflow_pickle_file`, and the content is passed # as `content`. 
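# How the loaders are dispatched (summary of the mechanism above): each loader below is
# registered with @load_if(condition). `load()` tries the registered conditions in order
# and calls the first loader whose condition returns a truthy result; if the condition
# also returns an extra tuple (such as the already-deserialized content here), it is
# forwarded to the loader so the file does not need to be read and parsed a second time.
# When `global_src_rank` is set, only that rank evaluates the conditions and the chosen
# loader index is broadcast to the other ranks.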
@load_if(is_oneflow_pickle_file) def load_from_oneflow_single_file( path: FILE_LIKE, global_src_rank, map_location: MAP_LOCATION, content: Any = None, ): rank = flow.env.get_rank() if global_src_rank is None or rank == global_src_rank: assert content["protocol_version"] == PROTOCOL_VERSION res = content["data"] else: res = None if global_src_rank is not None: res = flow.utils.global_view.to_global( res, placement=flow.placement("cpu", [global_src_rank]), sbp=flow.sbp.broadcast, warn_on_non_tensor_leaf=False, ) res = _map_location(res, map_location) return res def is_file_and_support_pytorch_format( path: FILE_LIKE, support_pytorch_format: bool ) -> bool: if not support_pytorch_format: return False if _is_path(path) and not path.is_file(): return False try: with flow.mock_torch.disable(): import torch content = torch.load(path, map_location="cpu") return True, (content,) except: if os.getenv("ONEFLOW_DEBUG_CHECKPOINT") == "1": import traceback traceback.print_exc() return False @load_if(is_file_and_support_pytorch_format) def load_from_pytorch_file( path: FILE_LIKE, global_src_rank, map_location: MAP_LOCATION, torch_obj: Any = None ): if torch_obj is not None: with flow.mock_torch.disable(): import torch def torch_tensor_to_flow(x): if isinstance(x, torch.Tensor): return flow.utils.tensor.from_torch(x) else: return x flow_obj = ArgsTree(torch_obj).map_leaf(torch_tensor_to_flow) else: flow_obj = None if global_src_rank is not None: flow_obj = flow.utils.global_view.to_global( flow_obj, placement=flow.placement("cpu", [global_src_rank]), sbp=flow.sbp.broadcast, warn_on_non_tensor_leaf=False, ) flow_obj = _map_location(flow_obj, map_location) return flow_obj def is_dir_and_has_pickle_file(path: FILE_LIKE, support_pytorch_format: bool) -> bool: if _is_path(path) and path.is_dir(): pickle_path = path / PICKLE_FILENAME return pickle_path.exists() return False @load_if(is_dir_and_has_pickle_file) def load_from_oneflow_pickle_dir( path: Path, global_src_rank: Optional[int], map_location: MAP_LOCATION, ): rank = flow.env.get_rank() pickle_path = path / PICKLE_FILENAME if global_src_rank is not None: if rank == global_src_rank: pickle_bytes = pickle_path.read_bytes() _broadcast_py_object(pickle_bytes, global_src_rank) else: pickle_bytes = _broadcast_py_object(None, global_src_rank) else: pickle_bytes = pickle_path.read_bytes() if map_location is not None: assert isinstance( map_location, (str, flow.device, flow.placement) ), "'map_location' only supports str, device or placement." with tensor_pickling_context(path, global_src_rank, map_location, True): res = pickle.loads(pickle_bytes) assert res["protocol_version"] == PROTOCOL_VERSION return res["data"] def load( path: Union[FILE_LIKE, str], global_src_rank: Optional[int] = None, map_location: MAP_LOCATION = None, *, support_pytorch_format: bool = True, ) -> Any: r"""Loads an object saved with oneflow.save() from a directory. Args: path: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), or a string or os.PathLike object containing a file name global_src_rank (int, optional): The source rank for loading global tensors. When specified, only the process whose rank == global_src_rank will really read the files in `path`, and tensors in the loaded object will be consistent with placement = `flow.placement('cuda', [global_src_rank])` map_location (str, flow.device or flow.placement, callable, optional): indicates the location where all tensors should be loaded. 
support_pytorch_format (bool, optional): whether to support loading the file saved by `torch.save`. Default: True Returns: The loaded object """ if isinstance(path, str): path = Path(path) rank = flow.env.get_rank() if global_src_rank is None or global_src_rank == rank: for i, (condition, load) in enumerate(load_methods): is_ok, extra_data = condition(path, support_pytorch_format) if is_ok: if global_src_rank is not None: _broadcast_py_object(i, global_src_rank) break else: if _is_path(path): err_msg = f'Cannot load file "{path}"' else: err_msg = "Cannot load the data" raise ValueError(err_msg) else: i = _broadcast_py_object(None, global_src_rank) load = load_methods[i][1] extra_data = () return load(path, global_src_rank, map_location, *extra_data) # type: ignore def save_one_embedding_info(state_dict: Any, path: Union[str, Path]) -> None: path: Path = Path(path) _embedding_info_dict = {"embedding": []} os.makedirs(path, exist_ok=True) _save_one_embedding_info_flag = False for module in state_dict.keys(): if not isinstance(state_dict[module], OrderedDict): continue for module_key in state_dict[module].keys(): _info_dict = {} if "OneEmbeddingKeyValueOptions" in module_key: if not _save_one_embedding_info_flag: _save_one_embedding_info_flag = True module_key_prefix = module_key.rstrip("OneEmbeddingKeyValueOptions") _embedding_info_dict["embedding"].append( { "snapshot": state_dict["module"][ module_key_prefix + "OneEmbeddingSnapshot" ], "kv_options": json.loads( state_dict["module"][ module_key_prefix + "OneEmbeddingKeyValueOptions" ] ), } ) if _save_one_embedding_info_flag: with open(os.path.join(path, "one_embedding_options.json"), "w") as f: f.write(json.dumps(_embedding_info_dict, indent=4)) def save( obj: Any, path_or_buffer: FILE_LIKE, global_dst_rank: Optional[int] = None, save_as_external_data: bool = False, ) -> None: r"""Save an object to a directory. Args: obj: The object to be saved path_or_buffer: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a file name global_dst_rank (int, optional): The destination rank for saving global tensors. When specified, whole tensors will be saved by the process whose rank == global_src_rank, while other processes will not do any disk I/O. save_as_external_data (bool): useful only if path_or_buffer is a string or os.PathLike object containing a file name """ if isinstance(path_or_buffer, str): path_or_buffer = Path(path_or_buffer) if isinstance(obj, graph_util.Graph): if not _is_path(path_or_buffer): raise ValueError( "path_or_buffer must be the type of {`str`, `pathlib.Path`} while obj is Graph" ) _save_graph(obj, path_or_buffer) return # this `path` is only used for `ContextData` and is set to empty when `path_or_buffer` is IO[bytes] or BinaryIO path: Path = Path(path_or_buffer if _is_path(path_or_buffer) else "") obj = {"protocol_version": PROTOCOL_VERSION, ONEFLOW_MAGIC_KEY: None, "data": obj} with tensor_pickling_context(path, global_dst_rank, None, save_as_external_data): pickled_bytes = pickle.dumps(obj) if _is_path(path_or_buffer) and save_as_external_data: path_or_buffer.mkdir(exist_ok=True) path_or_buffer = path_or_buffer / PICKLE_FILENAME def write_file(): with _open_file_like(path_or_buffer, "wb") as f: f.write(pickled_bytes) if global_dst_rank is not None: assert isinstance( global_dst_rank, int ), f"global_dst_rank expected type int, but got {type(global_dst_rank)}." 
assert ( global_dst_rank >= 0 and global_dst_rank < flow.env.get_world_size() ), f"out of range (expected to be in range of [0, {flow.env.get_world_size()}), but got {global_dst_rank})." if flow.env.get_rank() == global_dst_rank: write_file() else: # global_dst_rank is None write_file() def _save_graph(obj: graph_util.Graph, path: Union[str, Path]): path: Path = Path(path) graph: graph_util.Graph = obj if not graph._is_compiled: raise RuntimeError("graph must be compiled first.") path.mkdir(exist_ok=True) serialized_job = graph._forward_job_proto.SerializeToString() oneflow._oneflow_internal.nn.graph.SaveJobToIR(serialized_job, str(path)) for x in graph._state(): _save_tensor_to_disk( x.to(Tensor), path / f"{x.to(GraphTensor).name_prefix}{x.to(GraphTensor).name}", ) save_one_embedding_info(obj.state_dict(), path) def frombuffer( buffer: object, dtype: oneflow.dtype, count: int = -1, offset: int = 0, requires_grad: bool = False, ): return oneflow.tensor( np.frombuffer( buffer, dtype_util.convert_oneflow_dtype_to_numpy_dtype(dtype), count, offset, ), dtype=dtype, requires_grad=requires_grad, ) class ContextData: def __init__( self, path: Path, global_rank: Optional[int], map_location: Optional[Union[str, flow.device, flow.placement]], save_as_external_data: bool, ): self.path = path self.global_rank = global_rank self.map_location = map_location self.save_as_external_data = save_as_external_data context_data = None ================================================ FILE: python/oneflow/framework/config_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import sys import traceback from typing import Callable, List, Union import oneflow._oneflow_internal import oneflow.core.job.resource_pb2 as resource_util import oneflow.framework.session_context as session_ctx import oneflow.framework.attr_util as attr_util def _set_resource_attr(attrs_chain: Union[List[str], str], attr_value, type_): r""" set the attribute of config_proto.resource to attr_value. the attribute is specified as a string or a list of string. for example, if we want to do this: `config_proto.resource.machine_num = 1` we can call `_set_resource_attr("machine_num", 1)` if we want to do: `config_proto.resource.collective_boxing_conf.nccl_num_streams = 1` we can call `_set_resource_attr(["collective_boxing_conf", "nccl_num_streams"], 1)` ` """ assert isinstance(attr_value, type_), ( "Attribute " + repr(attrs_chain) + " type unmatched! 
Expected: " + str(type_) + " but get: " + str(type(attr_value)) ) if isinstance(attrs_chain, str): attrs_chain = [attrs_chain] session = session_ctx.GetDefaultSession() # get the current resource config resource_config = ( session.config_proto.resource if session.status_ != session.Status.INITED else session.resource ) # update the current resource config setattr( attr_util.get_nested_attribute( resource_config, attrs_chain[0:-1] ), # the owning object of the attribute to be updated attrs_chain[-1], # the attribute needs to be updated attr_value, ) # update the resource config eagerly if the session is already initialized if session.status_ == session.Status.INITED: session.update_resource_eagerly(resource_config) def api_load_library(val: str) -> None: """Load necessary library for job now Args: val (str): path to shared object file """ assert type(val) is str oneflow._oneflow_internal.LoadLibrary(val) def api_numa_aware_cuda_malloc_host(val: bool = True) -> None: """Whether or not let numa know that cuda allocated host's memory. Args: val (bool, optional): True or False. Defaults to True. """ print( "'enable_numa_aware_cuda_malloc_host' has been deprecated, has no effect and will be removed in the future." ) def api_reserved_device_mem_mbyte(val: int) -> None: """Set up the memory size of reserved device Args: val (int): memory size, e.g. 1024(mb) """ attrs, type_ = api_attrs_and_type[api_reserved_device_mem_mbyte] _set_resource_attr(attrs, val, type_) def api_enable_cudnn_fused_normalization_add_relu(val: bool) -> None: """Whether enable cudnn_fused_normalization_add_relu. Args: val (bool): whether enable or not """ attrs, type_ = api_attrs_and_type[api_enable_cudnn_fused_normalization_add_relu] _set_resource_attr(attrs, val, type_) def api_enable_cudnn_conv_heuristic_search_algo(val: bool) -> None: """Whether enable cudnn conv operatioin to use heuristic search algorithm. Args: val (bool): whether enable or not, the default value is true. """ attrs, type_ = api_attrs_and_type[api_enable_cudnn_conv_heuristic_search_algo] _set_resource_attr(attrs, val, type_) def api_enable_fusion(val: bool = True) -> None: """Whether or not allow fusion the operators Args: val (bool, optional): True or False. Defaults to True. """ attrs, type_ = api_attrs_and_type[api_enable_fusion] _set_resource_attr(attrs, val, type_) def api_nccl_use_compute_stream(val: bool = False) -> None: """Whether or not nccl use compute stream to reuse nccl memory and speedup Args: val (bool, optional): True or False. Defaults to False. """ attrs, type_ = api_attrs_and_type[api_nccl_use_compute_stream] _set_resource_attr(attrs, val, type_) def api_disable_group_boxing_by_dst_parallel(val: bool = False) -> None: """Whether or not disable group boxing by dst parallel pass to reduce boxing memory life cycle. Args: val (bool, optional): True or False. Defaults to False. """ attrs, type_ = api_attrs_and_type[api_disable_group_boxing_by_dst_parallel] _set_resource_attr(attrs, val, type_) def api_nccl_num_streams(val: int) -> None: """Set up the number of nccl parallel streams while use boxing Args: val (int): number of streams """ attrs, type_ = api_attrs_and_type[api_nccl_num_streams] _set_resource_attr(attrs, val, type_) def api_nccl_fusion_threshold_mb(val: int) -> None: """Set up threshold for oprators fusion Args: val (int): int number, e.g. 
10(mb) """ attrs, type_ = api_attrs_and_type[api_nccl_fusion_threshold_mb] _set_resource_attr(attrs, val, type_) def api_nccl_fusion_all_reduce_use_buffer(val: bool) -> None: """Whether or not use buffer during nccl fusion progress Args: val (bool): True or False """ attrs, type_ = api_attrs_and_type[api_nccl_fusion_all_reduce_use_buffer] _set_resource_attr(attrs, val, type_) def api_nccl_fusion_all_reduce(val: bool) -> None: """Whether or not use nccl fusion during all reduce progress Args: val (bool): True or False """ attrs, type_ = api_attrs_and_type[api_nccl_fusion_all_reduce] _set_resource_attr(attrs, val, type_) def api_nccl_fusion_reduce_scatter(val: bool) -> None: """Whether or not use nccl fusion during reduce scatter progress Args: val (bool): True or False """ attrs, type_ = api_attrs_and_type[api_nccl_fusion_reduce_scatter] _set_resource_attr(attrs, val, type_) def api_nccl_fusion_all_gather(val: bool) -> None: """Whether or not use nccl fusion during all gather progress Args: val (bool): True or False """ attrs, type_ = api_attrs_and_type[api_nccl_fusion_all_gather] _set_resource_attr(attrs, val, type_) def api_nccl_fusion_reduce(val: bool) -> None: """Whether or not use nccl fusion during reduce progress Args: val (bool): True or False """ attrs, type_ = api_attrs_and_type[api_nccl_fusion_reduce] _set_resource_attr(attrs, val, type_) def api_nccl_fusion_broadcast(val: bool) -> None: """Whether or not use nccl fusion during broadcast progress Args: val (bool): True or False """ attrs, type_ = api_attrs_and_type[api_nccl_fusion_broadcast] _set_resource_attr(attrs, val, type_) def api_nccl_fusion_max_ops(val: int) -> None: """Maximum number of ops for nccl fusion. Args: val (int): Maximum number of ops """ attrs, type_ = api_attrs_and_type[api_nccl_fusion_max_ops] _set_resource_attr(attrs, val, type_) def api_nccl_enable_all_to_all(val: bool) -> None: """Whether or not use nccl all2all during s2s boxing Args: val (bool): True or False """ attrs, type_ = api_attrs_and_type[api_nccl_enable_all_to_all] _set_resource_attr(attrs, val, type_) def api_nccl_enable_mixed_fusion(val: bool) -> None: """Whether or not use nccl mixed fusion Args: val (bool): True or False """ attrs, type_ = api_attrs_and_type[api_nccl_enable_mixed_fusion] _set_resource_attr(attrs, val, type_) api_attrs_and_type = { api_reserved_device_mem_mbyte: ("reserved_device_mem_mbyte", int), api_enable_cudnn_fused_normalization_add_relu: ( ["cudnn_conf", "enable_cudnn_fused_normalization_add_relu"], bool, ), api_enable_cudnn_conv_heuristic_search_algo: ( ["cudnn_conf", "cudnn_conv_heuristic_search_algo"], bool, ), api_enable_fusion: (["collective_boxing_conf", "enable_fusion"], bool), api_nccl_use_compute_stream: ("nccl_use_compute_stream", bool), api_disable_group_boxing_by_dst_parallel: ( "disable_group_boxing_by_dst_parallel", bool, ), api_nccl_num_streams: (["collective_boxing_conf", "nccl_num_streams"], int), api_nccl_fusion_threshold_mb: ( ["collective_boxing_conf", "nccl_fusion_threshold_mb"], int, ), api_nccl_fusion_all_reduce_use_buffer: ( ["collective_boxing_conf", "nccl_fusion_all_reduce_use_buffer"], bool, ), api_nccl_fusion_all_reduce: ( ["collective_boxing_conf", "nccl_fusion_all_reduce"], bool, ), api_nccl_fusion_reduce_scatter: ( ["collective_boxing_conf", "nccl_fusion_reduce_scatter"], bool, ), api_nccl_fusion_all_gather: ( ["collective_boxing_conf", "nccl_fusion_all_gather"], bool, ), api_nccl_fusion_reduce: (["collective_boxing_conf", "nccl_fusion_reduce"], bool), api_nccl_fusion_broadcast: ( 
["collective_boxing_conf", "nccl_fusion_broadcast"], bool, ), api_nccl_fusion_max_ops: (["collective_boxing_conf", "nccl_fusion_max_ops"], int), api_nccl_enable_all_to_all: ( ["collective_boxing_conf", "nccl_enable_all_to_all"], bool, ), api_nccl_enable_mixed_fusion: ( ["collective_boxing_conf", "nccl_enable_mixed_fusion"], bool, ), } ================================================ FILE: python/oneflow/framework/distribute.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import traceback import warnings from contextlib import contextmanager import oneflow._oneflow_internal def split_sbp(dim=None, **kwargs) -> oneflow._oneflow_internal.sbp.sbp: """ Generate a split signature which indicates the tensor will be split along `dim`. Args: dim (int): The dimension in which the tensor is split. Returns: SbpParallel: Split scheme object, often required by `to_global` method of `Tensor` Example:: array = numpy.array([[1.0, 2.0], [3.0, 4.0]]) t1 = flow.tensor(array) ct2 = t1.to_global(sbp=flow.sbp.split(0), placement=("cuda", ranks=[0, 1, 2, 3])) """ if dim is None: for key, value in kwargs.items(): if key == "axis": if not isinstance(value, int): raise TypeError( "split_sbp(): parameter must be int, not {}.".format( type(value) ) ) warnings.warn( "This 'axis' parameter of oneflow.sbp.split() has been updated to 'dim' since OneFlow version 0.8." ) dim = value else: raise TypeError( "split_sbp() got an unexpected keyword argument '%s'." % key ) if dim is None: raise TypeError("split_sbp() missing 1 required argument: 'dim'.") else: for key, value in kwargs.items(): if key == "axis": raise TypeError( "split_sbp() received an invalid combination of arguments - duplicate argument `axis`" ) else: raise TypeError( "split_sbp() got an unexpected keyword argument '%s'." % key ) assert isinstance(dim, int) return oneflow._oneflow_internal.sbp.split(dim) ================================================ FILE: python/oneflow/framework/docstr/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from .math_ops import * from .random import * from .conv import * from .as_tensor import * from .pooling import * from .activation import * from .dropout import * from .vision import * from .norm import * from .normalization import * from .loss import * from .onehot import * from .comparison import * from .cast import * from .constant import * from .array_ops import * from .tensor import * from .tensor_attributes import * from .comm import * from .ctc_decode import * from .trigonometric_ops import * from .tensor_ops import * from .meshgrid import * from .dataset import * from .bmm import * from .flatten import * from .chunk import * from .broadcast_like import * from .arange import * from .split import * from .clamp import * from .erfinv import * from .swapaxes import * from .amax import * from .unbind import * from .repeat import * from .repeat_interleave import * from .tile import * from .tensor_t import * from .topk import * from .nms import * from .nonzero import * from .reduce_ops import * from .masked_fill import * from .expand import * from .flip import * from .in_top_k import * from .index_select import * from .sort import * from .is_floating_point import * from .swapdims import * from .where import * from .einsum import * from .oneflow import * from .argsort import * from .module import * from .util_ops import * from .tensordot import * from .searchsorted import * from .amin import * from .deconv import * from .inv import * from .logical_ops import * from .bitwise_ops import * from .distance import * from .addcdiv import * from .hann_window import * from .convolution import * from .linalg import * from .index_add import * from .baddbmm import * from .lerp import * from .quantile import * from .depend import * from .special_ops import * ================================================ FILE: python/oneflow/framework/docstr/activation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.prelu, """ prelu(x: Tensor, alpha: Tensor) -> Tensor Applies the element-wise function: .. math:: prelu(x) = max(0,x) + alpha * min(0,x) For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.tensor(np.asarray([[[[1, -2], [3, 4]]]]), dtype=flow.float32) >>> alpha = flow.nn.Parameter(flow.tensor([1], dtype=flow.float32).fill_(0.25)) >>> flow.nn.functional.prelu(x, alpha) tensor([[[[ 1.0000, -0.5000], [ 3.0000, 4.0000]]]], dtype=oneflow.float32, grad_fn=) See :class:`~oneflow.nn.PReLU` for more details. """, ) add_docstr( oneflow.relu, """ Applies the rectified linear unit function element-wise. See :class:`~oneflow.nn.ReLU` for more details. Args: inplace: If set to ``True``, will do this operation in-place. Default: ``False`` For examples: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> ndarr = np.asarray([1, -2, 3]) >>> input = flow.Tensor(ndarr) >>> output = flow.relu(input) >>> output tensor([1., 0., 3.], dtype=oneflow.float32) """, ) add_docstr( oneflow.gelu, r""" gelu(x: Tensor) -> Tensor Applies the Gaussian Error Linear Units function: .. math:: \\text{GELU}(x) = x * \Phi(x) where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. When the approximate argument is 'tanh', Gelu is estimated with: .. math:: \\text{GELU}(x) = 0.5 * x * (1 + \\text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) Args: input (oneflow.Tensor): Input Tensor approximate (string, optional): the gelu approximation algorithm to use: ``'none'`` | ``'tanh'``. Default: ``'none'`` Returns: oneflow.Tensor: A Tensor has same shape as the input. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.tensor(x) >>> out = flow.gelu(input) >>> out tensor([-0.1543, 0.0000, 0.3457], dtype=oneflow.float32) See :class:`~oneflow.nn.GELU` for more details. """, ) add_docstr( oneflow._C.quick_gelu, r""" quick_gelu(x: Tensor) -> Tensor Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs .. math:: \\text{QuickGELU}(x) = x * \\sigma(1.702x) = x * \\frac{1}{1 + \\exp(-1.702x)} Args: input (oneflow.Tensor): Input Tensor Returns: oneflow.Tensor: A Tensor has same shape as the input. See :class:`~oneflow.nn.QuickGELU` for more details. """, ) add_docstr( oneflow._C.square_relu, r""" square_relu(x: Tensor) -> Tensor Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 .. math:: \\text{ReLU}(x) = \\max(0, x) * \\max(0, x) Args: input (oneflow.Tensor): Input Tensor Returns: oneflow.Tensor: A Tensor has same shape as the input. See :class:`~oneflow.nn.SquareReLU` for more details. """, ) add_docstr( oneflow._C.softmax, r""" softmax(x: Tensor, dim: int) -> Tensor Softmax is defined as: .. math:: \text{Softmax}(x_{i}) = \frac{\\exp(x_i)}{\sum_j \exp(x_j)} See :class:`~oneflow.nn.Softmax` for more details. """, ) add_docstr( oneflow._C.log_softmax, r""" log_softmax(x: Tensor, dim: int) -> Tensor LogSoftmax is defined as: .. math:: \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) = x_i - \log({ \sum_j \exp(x_j)}) See :class:`~oneflow.nn.LogSoftmax` for more details. """, ) add_docstr( oneflow._C.gumbel_softmax, r""" gumbel_softmax(x: Tensor, dim: int, tau: float = 1.0, hard: bool = False) -> Tensor Solve the problem that the output values of argmax do not reflect the probability distribution of the model's output. Compensates for the fact that the argmax cannot participate in gradient back-propagation. Gumbel is defined as: .. math:: Gumbel_i = -log(-log(U_i)),\ U_i \sim U(0,1) Add Noise ~ Gumbel: .. math:: In = (In + Noise) / tau Calculate Softmax value: .. math:: gumbel\_softmax(In)=\frac{e^{In_i/tau}}{\sum_{j=1}^n{e^{In_j/tau}}},i=1,2,3...n Parameters: x (oneflow.Tensor): the input Tensor. dim (int, Tuple[int]): the dimension to softmax. tau (double): the input tensor of Softmax should obey the Gumbel(x, tau). hard (bool): if `hard=True`, the output tensor will be one-hot. """, ) add_docstr( oneflow.softplus, r""" softplus(x: Tensor, beta: double = 1, threshold: double = 20) -> Tensor Applies the element-wise function: .. 
math:: \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) For numerical stability the implementation reverts to the linear function when :math:`input \times \beta > threshold`. See :class:`~oneflow.nn.Softplus` for more details. """, ) add_docstr( oneflow.tanh, r""" tanh(x: Tensor) -> Tensor The equation is: .. math:: out = \frac{e^x-e^{-x}}{e^x+e^{-x}} See :class:`~oneflow.nn.Tanh` for more details. """, ) add_docstr( oneflow._C.logsigmoid, r""" logsigmoid(x: Tensor) -> Tensor Applies the element-wise function: .. math:: \text{logsigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right) For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.tensor(x) >>> out = flow.nn.functional.logsigmoid(input) >>> out tensor([-0.9741, -0.6931, -0.4741], dtype=oneflow.float32) See :class:`~oneflow.nn.LogSigmoid` for more details. """, ) add_docstr( oneflow._C.softsign, r""" softsign(x: Tensor) -> Tensor The formula is: .. math:: softsign(x) = \frac{x}{1 + |x|} For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([1, 2, 3]).astype(np.float32) >>> input = flow.tensor(x) >>> out = flow.nn.functional.softsign(input) >>> out tensor([0.5000, 0.6667, 0.7500], dtype=oneflow.float32) See :class:`~oneflow.nn.Softsign` for more details. """, ) add_docstr( oneflow.silu, """ silu(x: Tensor) -> Tensor The formula is: .. math:: \text{silu}(x) = x * sigmoid(x) For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([1, 2, 3]).astype(np.float32) >>> input = flow.tensor(x) >>> out = flow.silu(input) >>> out tensor([0.7311, 1.7616, 2.8577], dtype=oneflow.float32) See :class:`~oneflow.nn.SiLU` for more details. """, ) add_docstr( oneflow.mish, """ mish(x: Tensor) -> Tensor Applies the element-wise function: .. math:: \text{mish}(x) = x * \text{tanh}(\text{softplus}(x)) For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([1, 2, 3]).astype(np.float32) >>> input = flow.tensor(x) >>> out = flow.mish(input) >>> out tensor([0.8651, 1.9440, 2.9865], dtype=oneflow.float32) See :class:`~oneflow.nn.Mish` for more details. """, ) add_docstr( oneflow._C.hardsigmoid, """ hardsigmoid(x: Tensor)-> Tensor Applies the element-wise function .. math:: \text{Hardsigmoid}(x) = \begin{cases} 0 & \text{if~} x \le -3, \\ 1 & \text{if~} x \ge +3, \\ x / 6 + 1 / 2 & \text{otherwise} \end{cases} See :class:`~oneflow.nn.Hardsigmoid` for more details. """, ) add_docstr( oneflow._C.hardswish, """ hardswish(x: Tensor)-> Tensor Applies the hardswish function, element-wise, as described in the paper: `Searching for MobileNetV3`_. .. math:: \text{Hardswish}(x) = \begin{cases} 0 & \text{if~} x \le -3, \\ x & \text{if~} x \ge +3, \\ x \cdot (x + 3) /6 & \text{otherwise} \end{cases} See :class:`~oneflow.nn.Hardswish` for more details. .. _`Searching for MobileNetV3`: https://arxiv.org/abs/1905.02244 """, ) add_docstr( oneflow.sigmoid, r""" sigmoid(input) -> Tensor Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}` See :class:`~oneflow.nn.Sigmoid` for more details. For examples: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([0.81733328, 0.43621480, 0.10351428]) >>> input = flow.tensor(x, dtype=flow.float32) >>> out = flow.nn.functional.sigmoid(input) >>> out tensor([0.6937, 0.6074, 0.5259], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.hardtanh, """ hardtanh(input, min_val=-1., max_val=1.) -> Tensor Applies the HardTanh function element-wise. See :class:`~oneflow.nn.Hardtanh` for more details. """, ) add_docstr( oneflow._C.leaky_relu, """ leaky_relu(x: Tensor, alpha :Float) -> Tensor Applies element-wise, :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative_slope} * \min(0, x)` See :class:`~oneflow.nn.LeakyReLU` for more details. """, ) add_docstr( oneflow._C.rrelu, """ rrelu(x: Tensor, lower: Float = 1.0 / 8, upper: Float = 1.0 / 3, training: bool = False, inplace: bool = False) -> Tensor Applies the randomized leaky rectified liner unit function, element-wise :math:`\text{RReLU}(x) = \max(0, x) + a * \min(0, x)` where :math:`a` is randomly sampled from uniform distribution :math:`\mathcal{U}(\text{lower}, \text{upper})`. See :class:`~oneflow.nn.RReLU` for more details. """, ) add_docstr( oneflow._C.rrelu_, """ rrelu(x: Tensor, lower: Float = 1.0 / 8, upper: Float = 1.0 / 3, training: bool = False) -> Tensor In-place version of :func:`rrelu`. """, ) add_docstr( oneflow._C.elu, """ elu(x: Tensor, alpha :Float) -> Tensor Applies element-wise, :math:`\text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))`. See :class:`~oneflow.nn.ELU` for more details. For examples: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.tensor(x) >>> out = flow.nn.functional.elu(input, alpha=1.0) >>> out tensor([-0.3935, 0.0000, 0.5000], dtype=oneflow.float32) """, ) add_docstr( oneflow.selu, """ selu(x: Tensor) -> Tensor Applies element-wise function .. math:: \text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`, with :math:`\alpha=1.6732632423543772848170429916717` and :math:`scale=1.0507009873554804934193349852946`. See :class:`~oneflow.nn.SELU` for more details. For examples: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([1, 2, 3]).astype(np.float32) >>> input = flow.tensor(x) >>> out = flow.nn.functional.selu(input) >>> out tensor([1.0507, 2.1014, 3.1521], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.glu, """ glu(input: Tensor, dim: int) -> Tensor The equation is: .. math:: GLU(input) = GLU(a, b) = a \otimes sigmoid(b) .. note:: where input is split in half along dim to form a and b, ⊗ is the element-wise product between matrices. For example: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> x = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=flow.float32) >>> y = nn.functional.glu(x) >>> y tensor([[0.9526, 1.9640], [4.9954, 5.9980]], dtype=oneflow.float32) See :class:`~oneflow.nn.GLU` for more details. """, ) add_docstr( oneflow._C.celu, r""" celu(x: Tensor, alpha: Float=1.0, inplace: bool=False) -> Tensor Applies the element-wise function: .. math:: \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)) See :class:`~oneflow.nn.CELU` for more details. For examples: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.tensor(x) >>> out = flow.nn.functional.celu(input, alpha=0.5) >>> out tensor([-0.3161, 0.0000, 0.5000], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.threshold, """ threshold(input: Tensor, threshold: float, value: float) -> Tensor Thresholds each element of the input Tensor. See :class:`~oneflow.nn.Threshold` for more details. """, ) add_docstr( oneflow._C.hardshrink, """ hardshrink(input: Tensor, lambd: float=0.5, inplace: bool=False) -> Tensor Applies the hard shrinkage function in an element-wise manner. See :class:`~oneflow.nn.Hardshrink` for more details. """, ) add_docstr( oneflow._C.softshrink, """ softshrink(input: Tensor, lambd: float=0.5, inplace: bool=False) -> Tensor Applies the soft shrinkage function in an element-wise manner. See :class:`~oneflow.nn.Softshrink` for more details. """, ) ================================================ FILE: python/oneflow/framework/docstr/addcdiv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.addcdiv, r""" addcdiv(input, tensor1, tensor2, *, value=1) -> Tensor This function is equivalent to PyTorch’s addcdiv function. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.addcdiv.html. Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, multiply the result by the scalar :attr:`value` and add it to :attr:`input`. .. math:: \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be `broadcastable`. For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be a real number, otherwise an integer. Args: input (Tensor): the tensor to be added tensor1 (Tensor): the numerator tensor tensor2 (Tensor): the denominator tensor Keyword args: value (Number, optional): multiplier for :math:`\text{{tensor1}} / \text{{tensor2}}` Example:: >>> import oneflow as flow >>> input = flow.tensor([ 0.3810, 1.2774, -0.2972, -0.3719]) >>> tensor1 = flow.tensor([0.8032, 0.2930, -0.8113, -0.2308]) >>> tensor2 = flow.tensor([[0.5], [1]]) >>> output = flow.addcdiv(input, tensor1, tensor2) >>> output.shape oneflow.Size([2, 4]) """, ) ================================================ FILE: python/oneflow/framework/docstr/amax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.amax, """ oneflow.amax(input, dim=None, keepdim=False) -> Tensor Returns the maximum along a dimension. This function is equivalent to PyTorch’s amax function. Args: input (oneflow.Tensor): the input Tensor. dim (int or List of int, optional): the dimension or the dimensions to reduce. Dim is None by default. keepdim (bool, optional): whether to retain the dimension. keepdim is False by default. Returns: oneflow.Tensor: Maximum of the input tensor For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) >>> flow.amax(x, 1) tensor([[2, 3], [6, 7]], dtype=oneflow.int64) >>> flow.amax(x, 0) tensor([[4, 5], [6, 7]], dtype=oneflow.int64) >>> flow.amax(x) tensor(7, dtype=oneflow.int64) >>> flow.amax(x, 0, True) tensor([[[4, 5], [6, 7]]], dtype=oneflow.int64) """, ) ================================================ FILE: python/oneflow/framework/docstr/amin.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.amin, """ amin(input, dim, keepdim=False) -> Tensor Returns the minimum value of each slice of the `input` tensor in the given dimension(s) `dim`. If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). This function is equivalent to PyTorch’s amin function. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.amin.html. Parameters: input (oneflow.Tensor): the input Tensor. dim (int, Tuple[int]): the dimension or dimensions to reduce. keepdim (bool): whether the output tensor has `dim` retained or not. Example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) >>> flow.amin(x, 1) tensor([[0, 1], [4, 5]], dtype=oneflow.int64) >>> flow.amin(x, 0) tensor([[0, 1], [2, 3]], dtype=oneflow.int64) >>> flow.amin(x) tensor(0, dtype=oneflow.int64) >>> flow.amin(x, 0, True) tensor([[[0, 1], [2, 3]]], dtype=oneflow.int64) """, ) ================================================ FILE: python/oneflow/framework/docstr/arange.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.arange, """ oneflow.arange(start: int = 0, end, step: int = 1, dtype: Optional[oneflow._oneflow_internal.dtype] = None, device: Optional[Union[oneflow._oneflow_internal.device, str]] = None, placement: Optional[oneflow._oneflow_internal.placement] = None, sbp: Optional[Union[oneflow._oneflow_internal.sbp.sbp, List[oneflow._oneflow_internal.sbp.sbp]]] = None, requires_grad: bool = False) Returns a 1-D tensor of size :math:`\\left\\lfloor \\frac{\\text{end} - \\text{start}}{\\text{step}} \\right\\rfloor + 1` with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is the gap between two values in the tensor. .. math:: \\text{out}_{i+1} = \\text{out}_i + \\text{step}. Args: start (int): the starting value for the set of points. Default: ``0``. end (int): the ending value for the set of points step (int): the gap between each pair of adjacent points. Default: ``1``. Keyword args: dtype(flow.dtype, optional): If `dtype` is not given, infer the `dtype` from the other input arguments. If any of start, end, or step are floating-point, the `dtype` is inferred to be the floating-point data type. Otherwise, the `dtype` is inferred to be `flow.int64`. device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. For example: .. code-block:: python >>> import oneflow as flow >>> y = flow.arange(0, 5) >>> y tensor([0, 1, 2, 3, 4], dtype=oneflow.int64) """, ) ================================================ FILE: python/oneflow/framework/docstr/argsort.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.argsort, """ argsort() -> Tensor This operator sorts the input Tensor at specified dim and returns the indices of the sorted Tensor. Args: input (oneflow.Tensor): the input Tensor. dim (int, optional): the dimension to be sorted. Defaults to the last dim (-1). descending (bool, optional): controls the sorting order (ascending or descending). Returns: oneflow.Tensor: The indices of the sorted Tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([[10, 2, 9, 3, 7], ... 
[1, 9, 4, 3, 2]]).astype("float32") >>> input = flow.Tensor(x) >>> output = flow.argsort(input) >>> output tensor([[1, 3, 4, 2, 0], [0, 4, 3, 2, 1]], dtype=oneflow.int32) >>> output = flow.argsort(input, descending=True) >>> output tensor([[0, 2, 4, 3, 1], [1, 2, 3, 4, 0]], dtype=oneflow.int32) >>> output = flow.argsort(input, dim=0) >>> output tensor([[1, 0, 1, 0, 1], [0, 1, 0, 1, 0]], dtype=oneflow.int32) """, ) ================================================ FILE: python/oneflow/framework/docstr/array_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.diagonal, r""" oneflow.diagonal(input, offset, dim1, dim2) -> Tensor Returns a partial view of input with the its diagonal elements with respect to dim1 and dim2 appended as a dimension at the end of the shape. Args: input (Tensor): the input tensor.Must be at least 2-dimensional. offset (Optional[int], 0): which diagonal to consider. Default: 0 (main diagonal) dim1 (Optional[int], 0): first dimension with respect to which to take diagonal. Default: 0 dim2 (Optional[int], 1): second dimension with respect to which to take diagonal. Default: 1 Returns: oneflow.Tensor: the output Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.randn(2, 3, 4) >>> output = flow.diagonal(input, offset=1, dim1=1, dim2=0) >>> output.shape oneflow.Size([4, 1]) """, ) add_docstr( oneflow.diag, r""" If input is a vector (1-D tensor), then returns a 2-D square tensor with the elements of input as the diagonal. If input is a matrix (2-D tensor), then returns a 1-D tensor with diagonal elements of input. Args: input (Tensor): the input tensor. diagonal (Optional[int], 0): The diagonal to consider. If diagonal = 0, it is the main diagonal. If diagonal > 0, it is above the main diagonal. If diagonal < 0, it is below the main diagonal. Defaults to 0. Returns: oneflow.Tensor: the output Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> arr = np.array( ... [ ... [1.0, 2.0, 3.0], ... [4.0, 5.0, 6.0], ... [7.0, 8.0, 9.0], ... ] ... ) >>> input = flow.tensor(arr, dtype=flow.float32) >>> flow.diag(input) tensor([1., 5., 9.], dtype=oneflow.float32) """, ) add_docstr( oneflow.tril, r"""Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input along the specified diagonal, the other elements of the result tensor out are set to 0. .. note:: - if diagonal = 0, the diagonal of the returned tensor will be the main diagonal, - if diagonal > 0, the diagonal of the returned tensor will be above the main diagonal, - if diagonal < 0, the diagonal of the returned tensor will be below the main diagonal. Args: input (Tensor): the input tensor. diagonal (int, optional): the diagonal to specify. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.ones(shape=(3, 3)).astype(np.float32)) >>> flow.tril(x) tensor([[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]], dtype=oneflow.float32) """, ) add_docstr( oneflow.triu, r"""Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0. Args: input (Tensor): the input tensor. diagonal (int, optional): the diagonal to consider For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.ones(shape=(3, 3)).astype(np.float32)) >>> flow.triu(x) tensor([[1., 1., 1.], [0., 1., 1.], [0., 0., 1.]], dtype=oneflow.float32) """, ) add_docstr( oneflow.argmax, r"""The op computes the index with the largest value of a Tensor at specified axis. Args: input (oneflow.Tensor): Input Tensor dim (int, optional): dimension to be calculated. Defaults to the last dim (-1) keepdim (bool, optional): whether the output tensor has dim retained or not. Ignored if dim=None. Returns: oneflow.Tensor: A Tensor(dtype=int64) contains the index with the largest value of `input` For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.tensor([[1, 3, 8, 7, 2], ... [1, 9, 4, 3, 2]], dtype=flow.float32) >>> output = flow.argmax(input) >>> output tensor(6, dtype=oneflow.int64) >>> output = flow.argmax(input, dim=1) >>> output tensor([2, 1], dtype=oneflow.int64) """, ) add_docstr( oneflow.argmin, r"""The op computes the index with the smallest value of a Tensor at specified axis. Args: input (oneflow.Tensor): Input Tensor dim (int, optional): dimension to be calculated. Defaults to the last dim (-1) keepdim (bool, optional): whether the output tensor has dim retained or not. Ignored if dim=None. Returns: oneflow.Tensor: A Tensor(dtype=int64) contains the index with the smallest value of `input` For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.tensor([[4, 3, 1, 0, 2], ... [5, 9, 7, 6, 8]], dtype=flow.float32) >>> output = flow.argmin(input) >>> output tensor(3, dtype=oneflow.int64) >>> output = flow.argmin(input, dim=1) >>> output tensor([3, 0], dtype=oneflow.int64) """, ) add_docstr( oneflow.batch_gather, r"""Gather the elements in batch dims. Args: in (Tensor): the input tensor. indices (Tensor): the indices tensor, its dtype must be int32/64. For example: Example 1: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.Tensor(np.array([[1, 2, 3], ... [4, 5, 6]])) >>> indices = flow.tensor(np.array([1, 0]).astype(np.int64)) >>> out = flow.batch_gather(x, indices) >>> out tensor([[4., 5., 6.], [1., 2., 3.]], dtype=oneflow.float32) Example 2: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.Tensor(np.array([[[1, 2, 3], [4, 5, 6]], ... [[1, 2, 3], [4, 5, 6]]])) >>> indices = flow.tensor(np.array([[1, 0], ... [0, 1]]).astype(np.int64)) >>> out = flow.batch_gather(x, indices) >>> out tensor([[[4., 5., 6.], [1., 2., 3.]], [[1., 2., 3.], [4., 5., 6.]]], dtype=oneflow.float32) """, ) add_docstr( oneflow.transpose, r"""Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped. The resulting out tensor shares its underlying storage with the input tensor, so changing the content of one would change the content of the other. Args: input (oneflow.Tensor): the input tensor. dim0 (int): the first dimension to be transposed. dim1 (int): the second dimension to be transposed.
Returns: Tensor: A transposed tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32) >>> out = flow.transpose(input, 0, 1).shape >>> out oneflow.Size([6, 2, 5, 3]) """, ) add_docstr( oneflow.atleast_1d, r""" oneflow.atleast_1d(*tensors) -> Tensor or List[Tensor] Returns a 1-dimensional view of each input tensor with zero dimensions. Input tensors with one or more dimensions are returned as-is. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.atleast_1d.html. Args: tensors (List[oneflow.Tensor] or oneflow.Tensor): Tensor or list of tensors to be reshaped Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.randn(1) >>> flow.atleast_1d(x).shape oneflow.Size([1]) >>> x = flow.tensor(0) >>> x.shape oneflow.Size([]) >>> flow.atleast_1d(x).shape oneflow.Size([1]) """, ) add_docstr( oneflow.atleast_2d, r""" oneflow.atleast_2d(*tensors) -> Tensor or List[Tensor] Returns a 2-dimensional view of each input tensor with zero dimensions. Input tensors with two or more dimensions are returned as-is. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.atleast_2d.html. Args: tensors (List[oneflow.Tensor] or oneflow.Tensor): Tensor or list of tensors to be reshaped Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor(0) >>> x.shape oneflow.Size([]) >>> flow.atleast_2d(x).shape oneflow.Size([1, 1]) >>> x = flow.randn(3) >>> flow.atleast_2d(x).shape oneflow.Size([1, 3]) >>> x = flow.randn(3, 3) >>> flow.atleast_2d(x).shape oneflow.Size([3, 3]) """, ) add_docstr( oneflow.atleast_3d, r""" oneflow.atleast_3d(*tensors) -> Tensor or List[Tensor] Returns a 3-dimensional view of each input tensor with zero dimensions. Input tensors with three or more dimensions are returned as-is. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.atleast_3d.html. Args: tensors (List[oneflow.Tensor] or oneflow.Tensor): Tensor or list of tensors to be reshaped Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor(0) >>> flow.atleast_3d(x).shape oneflow.Size([1, 1, 1]) >>> x = flow.randn(3) >>> flow.atleast_3d(x).shape oneflow.Size([1, 3, 1]) >>> x = flow.randn(3, 4) >>> flow.atleast_3d(x).shape oneflow.Size([3, 4, 1]) >>> x = flow.randn(3, 4, 5) >>> flow.atleast_3d(x).shape oneflow.Size([3, 4, 5]) """, ) add_docstr( oneflow.stack, r"""Concatenates a sequence of tensors along a new dimension. The returned tensor shares the same underlying data with input tensors. A :attr:`dim` value within the range `[-input.ndimension() - 1, input.ndimension() + 1]` can be used. Negative :attr:`dim` will correspond to :meth:`stack` applied at :attr:`dim` = ``dim + input.ndimension() + 1``. Args: inputs (List[oneflow.Tensor]): the list of input tensors. Each tensor should have the same shape. dim (int): the index at which to insert the concatenated dimension. Returns: A `Tensor` For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.tensor(np.random.rand(1, 3, 5)) >>> x2 = flow.tensor(np.random.rand(1, 3, 5)) >>> y = flow.stack([x1, x2], dim = -1) >>> y.shape oneflow.Size([1, 3, 5, 2]) """, ) add_docstr( oneflow.hstack, r""" oneflow.hstack(tensors) -> Tensor Stack tensors in :attr:`tensors` horizontally (column wise). This is equivalent to concatenation tensors in :attr:`tensors` along the first axis for 1-D tensors, and along the second axis for all other tensors. When there are tensors with dimension less than 1, these tensors will be reshaped by ``oneflow.atleast_1d()`` to 1-dims tensors before stacking. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.hstack.html. Args: tensors: (List[oneflow.Tensor]): sequence of tensors to stack Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> x1 = flow.randn(5, 2) >>> x2 = flow.randn(5, 3) >>> flow.hstack([x1, x2]).shape oneflow.Size([5, 5]) >>> x = flow.randn(5) >>> flow.hstack([x, x]).shape oneflow.Size([10]) """, ) add_docstr( oneflow.vstack, r""" oneflow.vstack(tensors) -> Tensor Stack tensors in :attr:`tensors` vertically (row wise). This is equivalent to concatenation tensors in :attr:`tensors` along the first axis. When there are tensors with dimension less than 2, these tensors will be reshaped by ``oneflow.atleast_2d()`` to 2-D tensors before stacking. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.vstack.html. Args: tensors: (List[oneflow.Tensor]): sequence of tensors to stack Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> x1 = flow.randn(2, 5) >>> x2 = flow.randn(3, 5) >>> flow.vstack([x1, x2]).shape oneflow.Size([5, 5]) >>> x = flow.randn(5) >>> flow.vstack([x, x]).shape oneflow.Size([2, 5]) """, ) add_docstr( oneflow.dstack, r""" oneflow.dstack(tensors) -> Tensor Stack tensors in :attr:`tensors` depthwish (along third axis). This is equivalent to concatenation tensors in :attr:`tensors` along the third axis after 1-D and 2-D tensors have been reshaped by ``oneflow.atleast_3d()``. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.dstack.html. Args: tensors: (List[oneflow.Tensor]): sequence of tensors to stack Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> x1 = flow.randn(2, 3, 4) >>> x2 = flow.randn(2, 3, 2) >>> flow.dstack([x1, x2]).shape oneflow.Size([2, 3, 6]) >>> x = flow.randn(6, 4) >>> flow.dstack([x, x]).shape oneflow.Size([6, 4, 2]) """, ) add_docstr( oneflow.column_stack, r""" oneflow.column_stack(tensors) -> Tensor Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. Equivalent to :code:`oneflow.hstack(tensors)`, tensors with dimensions less than 2 will be reshaped to :code:`(t.numel(), 1)` before being stacked horizontally. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.column_stack.html. Args: tensors: (List[oneflow.Tensor]): sequence of tensors to stack Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> x1 = flow.randn(5) >>> x2 = flow.randn(5) >>> flow.column_stack([x1, x2]).shape oneflow.Size([5, 2]) >>> x1 = flow.randn(2, 5) >>> x2 = flow.randn(2, 2) >>> flow.column_stack([x1, x2]).shape oneflow.Size([2, 7]) """, ) add_docstr( oneflow.row_stack, r""" oneflow.row_stack(tensors) -> Tensor Alias of ``oneflow.vstack()``. 
Stack tensors in :attr:`tensors` vertically (row wise). This is equivalent to concatenation tensors in :attr:`tensors` along the first axis. When there are tensors with dimension less than 2, these tensors will be reshaped by ``oneflow.atleast_2d()`` to 2-D tensors before stacking. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.row_stack.html. Args: tensors: (List[oneflow.Tensor]): sequence of tensors to stack Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> x1 = flow.randn(2, 5) >>> x2 = flow.randn(3, 5) >>> flow.vstack([x1, x2]).shape oneflow.Size([5, 5]) >>> x = flow.randn(5) >>> flow.vstack([x, x]).shape oneflow.Size([2, 5]) """, ) add_docstr( oneflow.squeeze, r"""This operator removes the specified dimention which size is 1 of the input Tensor. If the `dim` is not specified, this operator will remove all the dimention which size is 1 of the input Tensor. The amount of element in return value is the same as Tensor `input`. Args: input (oneflow.Tensor): the input Tensor. dim (int, optinal): Defaults to None, if given, the input will be squeezed only in this dimension. Returns: Tensor: The result Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32)) >>> input.shape oneflow.Size([1, 1, 1, 3]) >>> out = flow.squeeze(input, dim=[1, 2]).shape >>> out oneflow.Size([1, 3]) """, ) add_docstr( oneflow.cat, r""" cat(tensors, dim=0) -> Tensor Concatenate two or more `Tensor` s at specified dim. Analogous to `numpy.concatenate `_ Args: inputs: a `list` of `Tensor` dim: a `int`. Returns: A `Tensor` For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input1 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32) >>> input2 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32) >>> input3 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32) >>> out = flow.cat([input1, input2, input3], dim=1) # equal to using flow.concat() >>> out.shape oneflow.Size([2, 18, 5, 3]) """, ) add_docstr( oneflow.gather, """ oneflow.gather(input, dim, index, sparse_grad=False) -> Tensor Gathers values along an axis specified by `dim`. For a 3-D tensor the output is specified by:: out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 :attr:`input` and :attr:`index` must have the same number of dimensions. It is also required that ``index.size(d) <= input.size(d)`` for all dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. Note that ``input`` and ``index`` do not broadcast against each other. Args: input (Tensor): the source tensor dim (int): the axis along which to index index (LongTensor): the indices of elements to gather For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = np.random.randn(3, 4, 3, 5) >>> index = np.random.choice(np.arange(3), size=180, replace=True).reshape((3, 4, 3, 5)) >>> output = flow.gather(flow.Tensor(input), 1, flow.tensor(index, dtype=flow.int64)) >>> output.shape oneflow.Size([3, 4, 3, 5]) """, ) add_docstr( oneflow.gather_nd, r""" oneflow.gather_nd(input, index) -> Tensor This operator is a high-dimensional extension of `gather`, `index` is a K-dimensional tensor, which is regarded as a index of input Tensor `input`. Each element defines a slice of `input`: .. 
math:: output[i_{0},i_{1},...,i_{K-2}] = input[index(i_{0},i_{1},...,i_{K-2})] Args: input: The input Tensor. index: The slice indices. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([[1, 2,3], [4, 5,6],[7,8,9]]), dtype=flow.float) >>> index_1 = flow.tensor(np.array([[0], [2]]), dtype=flow.int) >>> out_1 = flow.gather_nd(input,index_1) >>> print(out_1.shape) oneflow.Size([2, 3]) >>> out_1 tensor([[1., 2., 3.], [7., 8., 9.]], dtype=oneflow.float32) >>> index_2 = flow.tensor(np.array([[0,2], [2,1]]), dtype=flow.int) >>> out_2 = flow.gather_nd(input,index_2) >>> out_2 tensor([3., 8.], dtype=oneflow.float32) """, ) add_docstr( oneflow.bincount, r"""oneflow.bincount(input, weights=None, minlength=0) → Tensor The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bincount.html. Count the frequency of each value in an array of non-negative ints. The number of bins (size 1) is one larger than the largest value in ``input`` unless ``input`` is empty, in which case the result is a tensor of size 0. If ``minlength`` is specified, the number of bins is at least ``minlength`` and if ``input`` is empty, then the result is tensor of size ``minlength`` filled with zeros. If ``n`` is the value at position ``i``, ``out[n] += weights[i]`` if ``weights`` is specified else ``out[n] += 1``. Args: input (oneflow.Tensor): 1-d int Tensor weights (oneflow.Tensor): optional, weight for each value in the input tensor. Should be of same size as input tensor. minlength (int): optional, minimum number of bins. Should be non-negative. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 4, 6]) >>> flow.bincount(x) tensor([0, 1, 1, 0, 1, 0, 1], dtype=oneflow.int64) >>> x = flow.tensor([1, 2, 1]) >>> weights = flow.tensor([0.1, 0.2, 0.15]) >>> flow.bincount(x, weights=weights) tensor([0.0000, 0.2500, 0.2000], dtype=oneflow.float32) >>> flow.bincount(x, weights=weights, minlength=4) tensor([0.0000, 0.2500, 0.2000, 0.0000], dtype=oneflow.float32) """, ) add_docstr( oneflow.clone, r"""oneflow.clone(input) → Tensor Returns a copy of input. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.clone.html .. note:: This function is differentiable, so gradients will flow back from the result of this operation to ``input``. To create a tensor without an autograd relationship to ``input`` see :meth:`detach`. Args: input (oneflow.Tensor): input Tensor to be cloned For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.Tensor([1, 2, 3]) >>> y = flow.clone(x) >>> y tensor([1., 2., 3.], dtype=oneflow.float32) """, ) add_docstr( oneflow.frac, r"""frac(input) → Tensor Computes the fractional portion of each element in :attr:`input`. .. math:: \text{out}_{i} = \text{input}_{i} - \left\lfloor |\text{input}_{i}| \right\rfloor * \operatorname{sgn}(\text{input}_{i}) Args: input: The input Tensor. Returns: Tensor: The fractional part of the argument. For example: >>> import oneflow as flow >>> flow.frac(flow.Tensor([1, 2.50, -3.21])) tensor([ 0.0000, 0.5000, -0.2100], dtype=oneflow.float32) """, ) add_docstr( oneflow.frac_, r""" In-place version of :func:`oneflow.frac`. """, ) ================================================ FILE: python/oneflow/framework/docstr/as_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.as_tensor, r""" as_tensor(data, dtype=None, device=None) -> Tensor Converts data into a tensor, sharing data and preserving autograd history if possible. If data is already a tensor with the requeseted dtype and device then data itself is returned, but if data is a tensor with a different dtype or device then it’s copied as if using data.to(dtype=dtype, device=device). If data is a NumPy array (an ndarray) with the same dtype and device then a tensor is constructed using oneflow.from_numpy. The interface is consistent with PyTorch. Args: data (array_like): Initial data for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar, and other types. dtype (oneflow.dtype, optional): the desired data type of returned tensor. Default: if ``None``, infers data type from data. device (oneflow.device, optional): the device of the constructed tensor. If ``None`` and data is a tensor then the device of data is used. If None and data is not a tensor then the result tensor is constructed on the CPU. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> a = np.array([1, 2, 3]) >>> t = flow.as_tensor(a, device=flow.device('cuda')) >>> t tensor([1, 2, 3], device='cuda:0', dtype=oneflow.int64) >>> t[0] = -1 >>> a array([1, 2, 3]) """, ) ================================================ FILE: python/oneflow/framework/docstr/autograd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr from oneflow._oneflow_internal.autograd.Function import FunctionCtx add_docstr( FunctionCtx.saved_tensors, "Get saved tensors in ctx.", ) add_docstr( FunctionCtx.save_for_backward, "Saves given tensors for a future call to ``backward()``.", ) add_docstr( FunctionCtx.mark_non_differentiable, "Marks outputs as non-differentiable.", ) ================================================ FILE: python/oneflow/framework/docstr/baddbmm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
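# ---------------------------------------------------------------------------
# Editor's note: the FunctionCtx docstrings registered above in
# docstr/autograd.py (saved_tensors, save_for_backward, mark_non_differentiable)
# are easiest to read next to a concrete custom-Function sketch. The class
# below is illustrative only and assumes the PyTorch-style
# `oneflow.autograd.Function` interface (staticmethod forward/backward,
# `.apply`, and a `saved_tensors` property); it is not part of this repository.
import oneflow as flow


class Square(flow.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stash the input so backward() can use it.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve what forward() saved (assuming a PyTorch-style
        # `saved_tensors` property) and apply the chain rule: d(x^2)/dx = 2x.
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = flow.tensor([1.0, 2.0, 3.0], requires_grad=True)
Square.apply(x).sum().backward()
# ---------------------------------------------------------------------------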
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.baddbmm, r""" baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.baddbmm.html. Performs a batch matrix-matrix product of matrices in :attr:`batch1` and :attr:`batch2`. :attr:`input` is added to the final result. :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same number of matrices. If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a :math:`(b \times m \times p)` tensor, then :attr:`input` must be broadcastable with a :math:`(b \times n \times p)` tensor and :attr:`out` will be a :math:`(b \times n \times p)` tensor. .. math:: \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in it will not be propagated. For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers. Args: input (Tensor): the tensor to be added batch1 (Tensor): the first batch of matrices to be multiplied batch2 (Tensor): the second batch of matrices to be multiplied Keyword args: beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) alpha (Number, optional): multiplier for :math:`\text{{batch1}} \mathbin{{@}} \text{{batch2}}` (:math:`\alpha`) For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.randn(10, 3, 5) >>> batch1 = flow.randn(10, 3, 4) >>> batch2 = flow.randn(10, 4, 5) >>> of_out = flow.baddbmm(input, batch1, batch2) >>> of_out.shape oneflow.Size([10, 3, 5]) """, ) ================================================ FILE: python/oneflow/framework/docstr/bitwise_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.bitwise_and, """ Computes the bitwise AND of input and other. The input tensor must be of integral or Boolean types. For bool tensors, it computes the logical AND. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bitwise_and.html Args: input (oneflow.Tensor): The input Tensor other (oneflow.Tensor): The Tensor to compute bitwise AND with Returns: oneflow.Tensor: The output Tensor For example: .. 
code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> flow.bitwise_and(x, 2) tensor([0, 2, 2], dtype=oneflow.int64) >>> y = flow.tensor([5, 6, 7]) >>> flow.bitwise_and(x, y) tensor([1, 2, 3], dtype=oneflow.int64) """, ) add_docstr( oneflow.bitwise_or, """ Computes the bitwise OR of input and other. The input tensor must be of integral or Boolean types. For bool tensors, it computes the logical OR. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bitwise_or.html Args: input (oneflow.Tensor): The input Tensor other (oneflow.Tensor): The Tensor to compute OR with Returns: oneflow.Tensor: The output Tensor For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> flow.bitwise_or(x, 4) tensor([5, 6, 7], dtype=oneflow.int64) >>> y = flow.tensor([5, 6, 7]) >>> flow.bitwise_or(x, y) tensor([5, 6, 7], dtype=oneflow.int64) """, ) add_docstr( oneflow.bitwise_xor, """ Computes the bitwise XOR of input and other. The input tensor must be of integral or Boolean types. For bool tensors, it computes the logical XOR. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bitwise_xor.html Args: input (oneflow.Tensor): The input Tensor other (oneflow.Tensor): The Tensor to compute XOR with Returns: oneflow.Tensor: The output Tensor For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> flow.bitwise_xor(x, 2) tensor([3, 0, 1], dtype=oneflow.int64) >>> y = flow.tensor([5, 6, 7]) >>> flow.bitwise_xor(x, y) tensor([4, 4, 4], dtype=oneflow.int64) """, ) add_docstr( oneflow.bitwise_not, """ Computes the bitwise NOT of input. The input tensor must be of integral or Boolean types. For bool tensors, it computes the logical NOT. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bitwise_not.html Args: input (oneflow.Tensor): The input Tensor Returns: oneflow.Tensor: The output Tensor For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> flow.bitwise_not(x) tensor([-2, -3, -4], dtype=oneflow.int64) >>> x = flow.tensor([0, 0, 1]).bool() >>> flow.bitwise_not(x) tensor([ True, True, False], dtype=oneflow.bool) """, ) ================================================ FILE: python/oneflow/framework/docstr/bmm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.bmm, """ Performs a batch matrix-matrix product of matrices stored in input and mat2. `input` and `mat2` must be 3-D tensors each containing the same number of matrices. If input is a (b x n x m) tensor, mat2 is a (b x m x p) tensor, out will be a (b x n x p) tensor. 
Args: input(oneflow.Tensor): the first batch of matrices to be multiplied mat2(oneflow.Tensor): the second batch of matrices to be multiplied For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input1 = flow.randn(10, 3, 4) >>> input2 = flow.randn(10, 4, 5) >>> of_out = flow.bmm(input1, input2) >>> of_out.shape oneflow.Size([10, 3, 5]) """, ) ================================================ FILE: python/oneflow/framework/docstr/broadcast_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.broadcast_like, """ This operator broadcast tensor `x` to `like_tensor` according to the broadcast_axes. Args: x (Tensor): The input Tensor. like_tensor (Tensor): The like Tensor. broadcast_axes (Optional[Sequence], optional): The axes you want to broadcast. Defaults to None. Returns: [Tensor]: Broadcasted input Tensor. For example: .. code:: python >>> import oneflow as flow >>> x = flow.randn(3, 1, 1) >>> like_tensor = flow.randn(3, 4, 5) >>> broadcast_tensor = flow.broadcast_like(x, like_tensor, broadcast_axes=[1, 2]) >>> broadcast_tensor.shape oneflow.Size([3, 4, 5]) """, ) ================================================ FILE: python/oneflow/framework/docstr/cast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.cast, """ The operation takes input tensor `x` and casts it to the output with `dtype` Args: x (oneflow.Tensor): A Tensor dtype (flow.dtype): Data type of the output tensor Returns: oneflow.Tensor: A Tensor with specific dtype. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> np_arr = np.random.randn(2, 3, 4, 5).astype(np.float32) >>> input = flow.tensor(np_arr, dtype=flow.float32) >>> output = flow.cast(input, flow.int8) >>> np.array_equal(output.numpy(), np_arr.astype(np.int8)) True """, ) ================================================ FILE: python/oneflow/framework/docstr/chunk.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.chunk, """Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor. Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks. Args: input (oneflow.Tensor): The tensor to split. chunks (int): Number of chunks to return. dim (int): Dimension along which to split the tensor. Returns: List of Tensors. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> arr = np.random.randn(5, 3, 6, 9).astype(np.float32) >>> input = flow.tensor(arr) >>> output = [] >>> chunks = 3 >>> output = flow.chunk(input, chunks=chunks, dim=2) >>> out_shape = [] >>> for i in range(0, chunks): ... out_shape.append(output[i].numpy().shape) >>> out_shape [(5, 3, 2, 9), (5, 3, 2, 9), (5, 3, 2, 9)] """, ) ================================================ FILE: python/oneflow/framework/docstr/clamp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.clamp, """ Clamp all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]` and return a resulting tensor: .. math:: y_i = \\begin{cases} \\text{min} & \\text{if } x_i < \\text{min} \\\\ x_i & \\text{if } \\text{min} \\leq x_i \\leq \\text{max} \\\\ \\text{max} & \\text{if } x_i > \\text{max} \\end{cases} If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, args :attr:`min` and :attr:`max` must be real numbers, otherwise they should be integers. Args: input (Tensor): the input tensor. min (Number): lower-bound of the range to be clamped to. Defaults to None. max (Number): upper-bound of the range to be clamped to. Defaults to None. out (Tensor, optional): the output tensor. For example: ..
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> arr = np.array([0.2, 0.6, -1.5, -0.3]) >>> input = flow.Tensor(arr) >>> output = flow.clamp(input, min=-0.5, max=0.5) >>> output tensor([ 0.2000, 0.5000, -0.5000, -0.3000], dtype=oneflow.float32) >>> arr = np.array([0.2, 0.6, -1.5, -0.3]) >>> input = flow.Tensor(arr) >>> output = flow.clamp(input, min=None, max=0.5) >>> output tensor([ 0.2000, 0.5000, -1.5000, -0.3000], dtype=oneflow.float32) >>> arr = np.array([0.2, 0.6, -1.5, -0.3]) >>> input = flow.Tensor(arr) >>> output = flow.clamp(input, min=-0.5, max=None) >>> output tensor([ 0.2000, 0.6000, -0.5000, -0.3000], dtype=oneflow.float32) """, ) add_docstr( oneflow.clamp_min, """ Clamp all elements in :attr:`input` which are less than :attr:`min` to :attr:`min` and return a resulting tensor: .. math:: y_i = \max(min, x_i) If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, args :attr:`min` must be real numbers, otherwise they should be integers. Args: input (Tensor): the input tensor. min (Number): lower-bound of the range to be clamped to. out (Tensor, optional): the output tensor. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3]) >>> output = flow.clamp_min(input, min=-0.5) >>> output tensor([ 0.2000, 0.6000, -0.5000, -0.3000], dtype=oneflow.float32) >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3]) >>> output = flow.clamp_min(input, min=-2) >>> output tensor([ 0.2000, 0.6000, -1.5000, -0.3000], dtype=oneflow.float32) >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3]) >>> output = flow.clamp_min(input, min=1) >>> output tensor([1., 1., 1., 1.], dtype=oneflow.float32) """, ) add_docstr( oneflow.clamp_max, """ Clamp all elements in :attr:`input` which are greater than :attr:`max` to :attr:`max` and return a resulting tensor: .. math:: y_i = \min(max, x_i) If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, args :attr:`max` must be real numbers, otherwise they should be integers. Args: input (Tensor): the input tensor. max (Number): upper-bound of the range to be clamped to. out (Tensor, optional): the output tensor. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3]) >>> output = flow.clamp_max(input, max=-0.5) >>> output tensor([-0.5000, -0.5000, -1.5000, -0.5000], dtype=oneflow.float32) >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3]) >>> output = flow.clamp_max(input, max=-2) >>> output tensor([-2., -2., -2., -2.], dtype=oneflow.float32) >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3]) >>> output = flow.clamp_max(input, max=1) >>> output tensor([ 0.2000, 0.6000, -1.5000, -0.3000], dtype=oneflow.float32) """, ) add_docstr( oneflow.clip, """ Alias for :func:`oneflow.clamp`. """, ) ================================================ FILE: python/oneflow/framework/docstr/comm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.comm.send, """Sends a tensor synchronously. Args: tensor (Tensor): Tensor to send. dst (int): Destination rank. send_meta (Bool): Whether to send meta information (default is True) """, ) add_docstr( oneflow.comm.recv, """Receives a tensor synchronously. All(send_meta is False) or none of shape, dtype and device should have value. Args: src (int, optional): Source rank. Will receive from any process if unspecified. shape (optional): output tensor shape. dataType (optional): output tensor data type. device (optional): output tensor device. out (Tensor, optional): Tensor to fill with received data. Returns: if out is None, return received tensor. otherwise got data from out self without return. """, ) ================================================ FILE: python/oneflow/framework/docstr/comparison.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.greater, """Returns the truth value of :math:`input > other` element-wise. Args: input (oneflow.Tensor): A Tensor other (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: A Tensor with bool type. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input1 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32) >>> input2 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32) >>> out = flow.gt(input1, input2).shape >>> out oneflow.Size([2, 6, 5, 3]) """, ) add_docstr( oneflow.greater_equal, """Returns the truth value of :math:`input >= other` element-wise. Args: input (oneflow.Tensor): A Tensor other (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: A Tensor with bool type. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input1 = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32) >>> input2 = flow.tensor(np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32) >>> out = flow.ge(input1, input2) >>> out tensor([ True, True, False], dtype=oneflow.bool) """, ) add_docstr( oneflow.eq, """eq(input, other) -> Tensor Computes element-wise equality. The second argument can be a number or a tensor whose shape is broadcastable with the first argument. Args: input (oneflow.Tensor): the tensor to compare other (oneflow.Tensor, float or int): the target to compare Returns: - A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([2, 3, 4, 5]), dtype=flow.float32) >>> other = flow.tensor(np.array([2, 3, 4, 1]), dtype=flow.float32) >>> y = flow.eq(input, other) >>> y tensor([ True, True, True, False], dtype=oneflow.bool) """, ) add_docstr( oneflow.equal, """equal(input, other) -> bool `True` if two tensors have the same size and elements, `False` otherwise. Args: input (oneflow.Tensor): the tensor to compare other (oneflow.Tensor): the target to compare Returns: A boolean value For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([2, 3, 4, 5]), dtype=flow.float32) >>> other = flow.tensor(np.array([2, 3, 4, 1]), dtype=flow.float32) >>> y = flow.equal(input, other) >>> y False >>> y = flow.equal(input, input) >>> y True """, ) add_docstr( oneflow.ne, """ne(input, other) -> Tensor Computes element-wise not equality. The second argument can be a number or a tensor whose shape is broadcastable with the first argument. Args: input (oneflow.Tensor): the tensor to compare other (oneflow.Tensor, float or int): the target to compare Returns: - A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([2, 3, 4, 5]), dtype=flow.float32) >>> other = flow.tensor(np.array([2, 3, 4, 1]), dtype=flow.float32) >>> y = flow.ne(input, other) >>> y tensor([False, False, False, True], dtype=oneflow.bool) """, ) add_docstr( oneflow.lt, """lt(input, other) -> Tensor Returns the truth value of :math:`input < other` element-wise. Args: input (oneflow.Tensor): A Tensor other (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: A Tensor with bool type. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input1 = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32) >>> input2 = flow.tensor(np.array([1, 2, 4]).astype(np.float32), dtype=flow.float32) >>> out = flow.lt(input1, input2) >>> out tensor([False, False, True], dtype=oneflow.bool) """, ) add_docstr( oneflow.le, """le(input, other) -> Tensor Returns the truth value of :math:`input <= other` element-wise. Args: input (oneflow.Tensor): A Tensor other (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: A Tensor with bool type. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input1 = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32) >>> input2 = flow.tensor(np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32) >>> out = flow.le(input1, input2) >>> out tensor([ True, False, True], dtype=oneflow.bool) """, ) add_docstr( oneflow.isclose, r"""isclose(input, other, atol=1e-08, rtol=1e-05, equal_nan=False) -> Tensor The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.isclose.html Returns a new tensor with boolean elements representing if each element of :attr:`input` is "close" to the corresponding element of :attr:`other`. Closeness is defined as: .. math:: \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert Args: input (oneflow.Tensor): first tensor to compare other (oneflow.Tensor): second tensor to compare atol (float, optional): absolute tolerance. Default: 1e-08 rtol (float, optional): relative tolerance. 
Default: 1e-05 equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False`` Returns: oneflow.Tensor: A Tensor with bool type. For example: .. code-block:: python >>> import oneflow as flow >>> flow.isclose(flow.tensor((1., 2, 3)), flow.tensor((1 + 1e-10, 3, 4))) tensor([ True, False, False], dtype=oneflow.bool) >>> flow.isclose(flow.tensor((float('inf'), 4)), flow.tensor((float('inf'), 6)), rtol=.5) tensor([True, True], dtype=oneflow.bool) """, ) add_docstr( oneflow.allclose, r"""allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.allclose.html This function checks if :attr:`input` and :attr:`other` satisfy the condition: .. math:: \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to `numpy.allclose <https://numpy.org/doc/stable/reference/generated/numpy.allclose.html>`_ Args: input (oneflow.Tensor): first tensor to compare other (oneflow.Tensor): second tensor to compare atol (float, optional): absolute tolerance. Default: 1e-08 rtol (float, optional): relative tolerance. Default: 1e-05 equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False`` Returns: A boolean value. For example: .. code-block:: python >>> import oneflow as flow >>> flow.allclose(flow.tensor([10000., 1e-07]), flow.tensor([10000.1, 1e-08])) False >>> flow.allclose(flow.tensor([10000., 1e-08]), flow.tensor([10000.1, 1e-09])) True >>> flow.allclose(flow.tensor([1.0, float('nan')]), flow.tensor([1.0, float('nan')])) False >>> flow.allclose(flow.tensor([1.0, float('nan')]), flow.tensor([1.0, float('nan')]), equal_nan=True) True """, ) ================================================ FILE: python/oneflow/framework/docstr/constant.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.ones_like, """ ones_like(input, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.ones_like.html. Returns a tensor filled with the scalar value 1, with the same size as input. flow.ones_like(input) is equivalent to flow.ones(input.shape, dtype=input.dtype) Args: input(Tensor): The size of input will determine size of the output tensor. dtype (flow.dtype, optional): the desired type of returned tensor. Default: if None, same flow.dtype as this tensor. device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor. placement (flow.placement, optional): the desired placement of returned global tensor.
Default: if None, the returned tensor is local one using the argument `device`. sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.random.rand(5), dtype=flow.float32) >>> y = flow.ones_like(x) >>> y tensor([1., 1., 1., 1., 1.], dtype=oneflow.float32) """, ) add_docstr( oneflow.zeros_like, """ zeros_like(input, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.zeros_like.html. Returns a tensor filled with the scalar value 0, with the same size as input. flow.zeros_like(input) is equivalent to flow.zeros(input.shape, dtype=input.dtype) Args: input(Tensor): The size of input will determine size of the output tensor. dtype (flow.dtype, optional): the desired type of returned tensor. Default: if None, same flow.dtype as this tensor. device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor. placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.random.rand(5), dtype=flow.float32) >>> y = flow.zeros_like(x) >>> y tensor([0., 0., 0., 0., 0.], dtype=oneflow.float32) """, ) add_docstr( oneflow.new_ones, """ new_ones(x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.new_ones.html. Returns a Tensor of size size filled with 1. By default, the returned Tensor has the same oneflow.dtype and oneflow.device as this tensor. Args: size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor. dtype (flow.dtype, optional): the desired type of returned tensor. Default: if None, same flow.dtype as this tensor. device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor. placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.Tensor(np.ones((1, 2, 3))) >>> y = x.new_ones((2, 2)) >>> y tensor([[1., 1.], [1., 1.]], dtype=oneflow.float32) """, ) add_docstr( oneflow.empty, """ empty(*size, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False, pin_memory=False) -> Tensor The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.empty.html. Returns a tensor filled with uninitialized data. The shape of the tensor is defined by the variable argument ``size``. Args: size (int... or oneflow.Size): Defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size. dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``. device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the current device. placement (flow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (flow.sbp or List[flow.sbp], optional): The desired sbp of returned global tensor. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. pin_memory (bool, optional) – If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> y = flow.empty(4, 5) # construct local empty tensor >>> y.shape oneflow.Size([4, 5]) >>> y.is_global False >>> placement = flow.placement("cpu", ranks=[0]) >>> y = flow.empty(4, 5, placement=placement, sbp=flow.sbp.broadcast) # construct consistent empty tensor >>> y.is_global True """, ) add_docstr( oneflow.empty_like, """ empty_like(input, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.empty_like.html. Returns an uninitialized tensor with the same size as :attr:`input`. ``oneflow.empty_like(input)`` is equivalent to ``oneflow.empty(input.size(), dtype=input.dtype, device=input.device)``. Args: input(Tensor): The size of input will determine size of the output tensor. dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``. device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the current device. placement (flow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (flow.sbp or List[flow.sbp], optional): The desired sbp of returned global tensor. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.randn(2, 3) >>> y = flow.empty_like(x) >>> y.shape oneflow.Size([2, 3]) """, ) ================================================ FILE: python/oneflow/framework/docstr/conv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.conv1d, r""" conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 1D convolution over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv1d.html. See :class:`~oneflow.nn.Conv1d` for details and output shape. Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in_channels} , iW)` weight: filters of shape :math:`(\text{out_channels} , \frac{\text{in_channels}}{\text{groups}} , kW)` bias: optional bias of shape :math:`(\text{out_channels})`. Default: None. stride: the stride of the convolving kernel. Can be a single number or a tuple `(sW,)`. Default: 1 padding: implicit paddings on both sides of the input. Can be a single number or a tuple `(padW,)`. Default: 0 dilation: the spacing between kernel elements. Can be a single number or a tuple `(dW,)`. Default: 1 groups: split input into groups, :math:`\text{in_channels}` should be divisible by the number of groups. Default: 1 For examples: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> inputs = flow.randn(33, 16, 30) >>> filters = flow.randn(20, 16, 5) >>> outputs = F.conv1d(inputs, filters) """, ) add_docstr( oneflow._C.conv2d, r""" conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 2D convolution over an input image composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv2d.html. See :class:`~oneflow.nn.Conv2d` for details and output shape. Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in_channels} , iH , iW)` weight: filters of shape :math:`(\text{out_channels} , \frac{\text{in_channels}}{\text{groups}} , kH , kW)` bias: optional bias of shape :math:`(\text{out_channels})`. Default: None. stride: the stride of the convolving kernel. Can be a single number or a tuple `(sH, sW)`. Default: 1 padding: implicit paddings on both sides of the input. Can be a single number or a tuple `(padH, padW)`. Default: 0 dilation: the spacing between kernel elements. Can be a single number or a tuple `(dH, dW)`. Default: 1 groups: split input into groups, :math:`\text{in_channels}` should be divisible by the number of groups. Default: 1 For examples: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> inputs = flow.randn(8, 4, 3, 3) >>> filters = flow.randn(1, 4, 5, 5) >>> outputs = F.conv2d(inputs, filters, padding=1) """, ) add_docstr( oneflow._C.conv3d, r""" conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 3D convolution over an input image composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv3d.html. See :class:`~oneflow.nn.Conv3d` for details and output shape.
Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in_channels} , iD , iH , iW)` weight: filters of shape :math:`(\text{out_channels} , \frac{\text{in_channels}}{\text{groups}} , kD , kH , kW)` bias: optional bias of shape :math:`(\text{out_channels})`. Default: None. stride: the stride of the convolving kernel. Can be a single number or a tuple `(sD, sH, sW)`. Default: 1 padding: implicit paddings on both sides of the input. Can be a single number or a tuple `(padD, padH, padW)`. Default: 0 dilation: the spacing between kernel elements. Can be a single number or a tuple `(dD, dH, dW)`. Default: 1 groups: split input into groups, :math:`\text{in_channels}` should be divisible by the number of groups. Default: 1 For examples: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> inputs = flow.randn(20, 16, 50, 10, 20) >>> filters = flow.randn(33, 16, 3, 3, 3) >>> outputs = F.conv3d(inputs, filters) """, ) ================================================ FILE: python/oneflow/framework/docstr/convolution.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.nn.functional.fold, r""" fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1) The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.fold.html. Combines an array of sliding local blocks into a large containing tensor. .. warning:: Currently, only 3-D input tensors (batched image-like tensors) are supported, and only unbatched (3D) or batched (4D) image-like output tensors are supported. See :class:`oneflow.nn.Fold` for details. """, ) add_docstr( oneflow.nn.functional.unfold, r""" unfold(input, kernel_size, dilation=1, padding=0, stride=1) The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.unfold.html. Extracts sliding local blocks from a batched input tensor. .. warning:: Currently, only 4-D input tensors (batched image-like tensors) are supported. .. warning:: More than one element of the unfolded tensor may refer to a single memory location. As a result, in-place operations (especially ones that are vectorized) may result in incorrect behavior. If you need to write to the tensor, please clone it first. See :class:`oneflow.nn.Unfold` for details. """, ) ================================================ FILE: python/oneflow/framework/docstr/ctc_decode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
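# --- Worked shape check (a hedged sketch; the values match the F.conv2d example
# above and the formula is the standard convolution output-size rule, not an
# OneFlow-specific API). ---
# H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
import oneflow as flow
import oneflow.nn.functional as F

inputs = flow.randn(8, 4, 3, 3)            # (N, C_in, H, W)
filters = flow.randn(1, 4, 5, 5)           # (C_out, C_in/groups, kH, kW)
out = F.conv2d(inputs, filters, padding=1)

h_out = (3 + 2 * 1 - 1 * (5 - 1) - 1) // 1 + 1   # = 1
print(out.shape)                           # expected: oneflow.Size([8, 1, 1, 1])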
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.ctc_greedy_decoder, """ ctc_greedy_decoder(log_probs: Tensor, input_lengths: Tensor, merge_repeated: bool=True) -> Tensor Performs greedy decoding on the logits given in input (best path). Args: log_probs(oneflow.Tensor): A Tensor of shape [input_length, batch_size, num_labels]. The logarithmized probabilities of the outputs (e.g. obtained with flow.nn.logsoftmax()). input_lengths(oneflow.Tensor): A Tensor of shape [batch_size]. It represent the lengths of the inputs. And the lengths are specified for each sequence to achieve masking under the assumption that sequences are padded to equal lengths. merge_repeated (bool, optional): If merge_repeated is True, merge repeated classes in output. This means that if consecutive logits' maximum indices are the same, only the first of these is emitted. Defaults to True. Returns: decoded(oneflow.Tensor): A Tensor of shape [batch_size, input_length], The decoded outputs. neg_sum_logits(oneflow.Tensor): A float matrix (batch_size x 1) containing, for the sequence found, the negative of the sum of the greatest logit at each timeframe. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> log_probs = flow.tensor( ... [ ... [[-1.54, -1.20, -1.95, -1.65, -1.81], [-1.84, -1.74, -1.58, -1.55, -1.12]], ... [[-1.68, -1.48, -1.89, -1.30, -2.07], [-1.13, -1.45, -1.24, -1.61, -1.66]], ... [[-1.56, -1.40, -2.83, -1.67, -1.48], [-1.20, -2.01, -2.05, -1.95, -1.24]], ... [[-2.09, -1.76, -1.36, -1.67, -1.45], [-1.85, -1.48, -1.34, -2.16, -1.55]], ... ] ... ) >>> input_lengths = flow.tensor([4, 4]) >>> decoded, neg_sum_logits = flow.nn.functional.ctc_greedy_decoder(log_probs, input_lengths) >>> decoded tensor([[1, 3, 1, 2], [0, 2, 0, 0]], dtype=oneflow.int64) >>> neg_sum_logits tensor([[5.2600], [4.7900]], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr ================================================ FILE: python/oneflow/framework/docstr/deconv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
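# --- Illustrative round trip for the fold/unfold pair documented above (hedged
# sketch; shapes are assumed, not taken from the OneFlow docs). unfold extracts
# sliding local blocks from a batched 4-D tensor, fold assembles such blocks
# back into an image-shaped tensor; with non-overlapping blocks the round trip
# reproduces the input. ---
import oneflow as flow
import oneflow.nn.functional as F

x = flow.randn(1, 3, 8, 8)                       # (N, C, H, W)
blocks = F.unfold(x, kernel_size=2, stride=2)    # -> (1, 3*2*2, 16) local blocks
y = F.fold(blocks, output_size=(8, 8), kernel_size=2, stride=2)
print(blocks.shape, y.shape)                     # (1, 12, 16) and (1, 3, 8, 8); here y equals x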
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.deconv1d, r""" conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 1D transposed convolution operator over an input signal composed of several input planes, sometimes also called “deconvolution”. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv_transpose1d.html See :class:`~oneflow.nn.ConvTranspose1d` for details and output shape. Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in_channels} , iW)` weight: filters of shape :math:`(\text{in_channels} , \frac{\text{out_channels}}{\text{groups}} , kW)` bias: optional bias of shape :math:`(\text{out_channels})`. Default: None. stride: the stride of the convolving kernel. Can be a single number or a tuple `(sW,)`. Default: 1 padding: `dilation * (kernel_size - 1) - padding` zero-padding will be added to both sides of each dimension in the input. Can be a single number or a tuple `(padW,)`. Default: 0 output_padding: additional size added to one side of each dimension in the output shape. Can be a single number or a tuple `(out_padW)`. Default: 0 groups: split input into groups, :math:`\text{in_channels}` should be divisible by the number of groups. Default: 1 dilation: the spacing between kernel elements. Can be a single number or a tuple `(dW,)`. Default: 1 For examples: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> inputs = flow.randn(20, 16, 50) >>> weights = flow.randn(16, 33, 5) >>> outputs = F.conv_transpose1d(inputs, weights) """, ) add_docstr( oneflow._C.deconv2d, r""" conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 2D transposed convolution operator over an input image composed of several input planes, sometimes also called “deconvolution”. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv_transpose2d.html See :class:`~oneflow.nn.ConvTranspose2d` for details and output shape. Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in_channels} , iH , iW)` weight: filters of shape :math:`(\text{in_channels} , \frac{\text{out_channels}}{\text{groups}} , kH , kW)` bias: optional bias of shape :math:`(\text{out_channels})`. Default: None. stride: the stride of the convolving kernel. Can be a single number or a tuple `(sH, sW)`. Default: 1 padding: `dilation * (kernel_size - 1) - padding` zero-padding will be added to both sides of each dimension in the input. Can be a single number or a tuple `(padH, padW)`. Default: 0 output_padding: additional size added to one side of each dimension in the output shape. Can be a single number or a tuple `(out_padH, out_padW)`. Default: 0 groups: split input into groups, :math:`\text{in_channels}` should be divisible by the number of groups. Default: 1 dilation: the spacing between kernel elements. Can be a single number or a tuple `(dH, dW)`. Default: 1 For examples: ..
code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> inputs = flow.randn(1, 4, 5, 5) >>> weights = flow.randn(4, 8, 3, 3) >>> outputs = F.conv_transpose2d(inputs, weights, padding=1) """, ) add_docstr( oneflow._C.deconv3d, r""" conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 3D transposed convolution operator over an input image composed of several input planes, sometimes also called “deconvolution”. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv_transpose3d.html See :class:`~oneflow.nn.ConvTranspose3d` for details and output shape. Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in_channels} , iT , iH , iW)` weight: filters of shape :math:`(\text{in_channels} , \frac{\text{out_channels}}{\text{groups}} , kT , kH , kW)` bias: optional bias of shape :math:`(\text{out_channels})`. Default: None. stride: the stride of the convolving kernel. Can be a single number or a tuple `(sD, sH, sW)`. Default: 1 padding: `dilation * (kernel_size - 1) - padding` zero-padding will be added to both sides of each dimension in the input. Can be a single number or a tuple `(padT, padH, padW)`. Default: 0 output_padding: additional size added to one side of each dimension in the output shape. Can be a single number or a tuple `(out_padT, out_padH, out_padW)`. Default: 0 groups: split input into groups, :math:`\text{in_channels}` should be divisible by the number of groups. Default: 1 dilation: the spacing between kernel elements. Can be a single number or a tuple `(dT, dH, dW)`. Default: 1 For examples: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> inputs = flow.randn(20, 16, 50, 10, 20) >>> weights = flow.randn(16, 33, 3, 3, 3) >>> outputs = F.conv_transpose3d(inputs, weights) """, ) ================================================ FILE: python/oneflow/framework/docstr/depend.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.depend, r""" Add control dependency to guarantee OP A is executed before OP B. Used to prevent OPs from being rearranged or eliminated during graph compilation. Args: input (Tensor): a tensor intended to be the input of OP B depend (Tensor or List[Tensor]): one of the output tensors of OP A (supports passing in multiple tensors from different OPs) Returns: Tensor: the identity of "input" tensor Examples: >>> import oneflow as flow >>> import oneflow.nn as nn >>> import oneflow.nn.functional as F >>> class Model(nn.Module): ... def __init__(self): ... super().__init__() ... self.OP_A = nn.Linear(128, 128) ... self.OP_B = nn.Linear(128, 128) ... ... def forward(self, x): ... x1 = self.OP_A(x) ... x = F.depend(x, x1) ... return self.OP_B(x) ... >>> model = Model() >>> class Graph(nn.Graph): ...
def __init__(self) -> None: ... super().__init__() ... self.model = model ... ... def build(self, x): ... return self.model(x) ... >>> graph = Graph() >>> x = flow.randn([1, 128], dtype=flow.float32) >>> y = graph(x) """, ) ================================================ FILE: python/oneflow/framework/docstr/distance.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.cosine_similarity, r""" cosine_similarity(x1: Tensor, x2: Tensor, dim: int=1, eps: float=1e-8) -> Tensor Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is squeezed (see :func:`oneflow.squeeze`), resulting in the output tensor having 1 fewer dimension. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.cosine_similarity.html .. math :: \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)} Args: x1 (Tensor): First input. x2 (Tensor): Second input. dim (int, optional): Dimension along which cosine similarity is computed. Default: 1 eps (float, optional): Small value to avoid division by zero. Default: 1e-8 For examples: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> input1 = flow.randn(100, 128) >>> input2 = flow.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) """, ) add_docstr( oneflow._C.pairwise_distance, r""" pairwise_distance(x1: Tensor, x2: Tensor, p: float=2.0, eps: float=1e-6, keepdim: bool=False) -> Tensor Computes the pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm: .. math :: \left \| x \right \| _p = (\sum_{i=1}^n \left | x_i \right |^p )^{\frac{1}{p}} The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.PairwiseDistance.html. Args: x1 (Tensor): First input. x2 (Tensor): Second input. p (real): the norm degree. Default: 2 eps (float, optional): Small value to avoid division by zero. Default: 1e-6 keepdim (bool, optional): Determines whether or not to keep the vector dimension. Default: False For example: .. code-block:: python >>> import oneflow as flow >>> x1 = flow.arange(12).reshape(3, 4) >>> x2 = flow.arange(12).reshape(3, 4) >>> output = flow.nn.functional.pairwise_distance(x1, x2, p=2) >>> output tensor([2.0000e-06, 2.0000e-06, 2.0000e-06], dtype=oneflow.float32) >>> output.shape oneflow.Size([3]) """, ) ================================================ FILE: python/oneflow/framework/docstr/dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.dropout, """ dropout(x: Tensor, p: float = 0.5, training: bool = True, generator: Generator = None, *, addend: Tensor) -> Tensor During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli distribution. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.dropout.html. Args: x(Tensor): The Tensor to which dropout will be applied. p(float): probability of an element to be zeroed. Default: 0.5 training(bool): If True, dropout will be applied. Default: True generator(Generator, optional): A pseudorandom number generator for sampling addend(Tensor, optional): A Tensor added to the result after dropout; it can be used in a model's residual connection structure. Default: None Shape: - Input: :math:`(*)`. Input can be of any shape - Output: :math:`(*)`. Output is of the same shape as input For example: Example 1: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> arr = np.array( ... [ ... [-0.7797, 0.2264, 0.2458, 0.4163], ... [0.4299, 0.3626, -0.4892, 0.4141], ... [-1.4115, 1.2183, -0.5503, 0.6520], ... ] ... ) >>> x = flow.tensor(arr, dtype=flow.float32) >>> y = flow.nn.functional.dropout(x, p=0) >>> arr = np.array( ... [ ... [-0.7797, 0.2264, 0.2458, 0.4163], ... [0.4299, 0.3626, -0.4892, 0.4141], ... [-1.4115, 1.2183, -0.5503, 0.6520], ... ] ... ) >>> x = flow.tensor(arr, dtype=flow.float32) >>> generator = flow.Generator() >>> y = flow.nn.functional.dropout(x, p=0.5, generator=generator) Example 2: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> arr = np.array( ... [ ... [-0.7797, 0.2264, 0.2458, 0.4163], ... [0.4299, 0.3626, -0.4892, 0.4141], ... [-1.4115, 1.2183, -0.5503, 0.6520], ... ] ... ) >>> x = flow.tensor(arr, dtype=flow.float32) >>> addend = flow.ones((3, 4), dtype=flow.float32) >>> y = flow.nn.functional.dropout(x, p=0, addend=addend) >>> y #doctest: +ELLIPSIS tensor([[ 0.2203, 1.2264, 1.2458, 1.4163], [ 1.4299, 1.3626, 0.5108, 1.4141], [-0.4115, 2.2183, 0.4497, 1.6520]], dtype=oneflow.float32) See :class:`~oneflow.nn.Dropout` for details. """, ) add_docstr( oneflow._C.dropout1d, r""" dropout1d(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.dropout1d.html. Randomly zero out entire channels (a channel is a 1D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the batched input is a 1D tensor :math:`\text{input}[i, j]`) of the input tensor. Each channel will be zeroed out independently on every forward call with probability :attr:`p` using samples from a Bernoulli distribution. See :class:`~oneflow.nn.Dropout1d` for details. Args: p: probability of a channel to be zeroed. Default: 0.5 training: apply dropout if ``True``.
Default: ``True`` """, ) add_docstr( oneflow._C.dropout2d, r""" dropout2d(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.dropout2d.html. Randomly zero out entire channels (a channel is a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the batched input is a 2D tensor :math:`\text{input}[i, j]`) of the input tensor. Each channel will be zeroed out independently on every forward call with probability :attr:`p` using samples from a Bernoulli distribution. See :class:`~oneflow.nn.Dropout2d` for details. Args: p: probability of a channel to be zeroed. Default: 0.5 training: apply dropout if ``True``. Default: ``True`` """, ) add_docstr( oneflow._C.dropout3d, r""" dropout3d(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.dropout3d.html. Randomly zero out entire channels (a channel is a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the batched input is a 3D tensor :math:`\text{input}[i, j]`) of the input tensor. Each channel will be zeroed out independently on every forward call with probability :attr:`p` using samples from a Bernoulli distribution. See :class:`~oneflow.nn.Dropout3d` for details. Args: p: probability of a channel to be zeroed. Default: 0.5 training: apply dropout if ``True``. Default: ``True`` """, ) add_docstr( oneflow.nn.Dropout, """ During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli distribution. Each channel will be zeroed out independently on every forward call. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout.html. This has proven to be an effective technique for regularization and preventing the co-adaptation of neurons as described in the paper "Improving neural networks by preventing co-adaptation of feature detectors". Furthermore, the outputs are scaled by a factor of :math:`\\frac{1}{1-p}` during training. This means that during evaluation the module simply computes an identity function. Additionally, we can pass an extra Tensor `addend` whose shape is consistent with the input Tensor. The `addend` Tensor will be added to the result after dropout, which is very useful in a model's residual connection structure. Args: p: probability of an element to be zeroed. Default: 0.5 inplace: If set to ``True``, will do this operation in-place. Default: ``False`` generator: A pseudorandom number generator for sampling Shape: - Input: :math:`(*)`. Input can be of any shape - Output: :math:`(*)`. Output is of the same shape as input For example: example 1: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.Dropout(p=0) >>> arr = np.array( ... [ ... [-0.7797, 0.2264, 0.2458, 0.4163], ... [0.4299, 0.3626, -0.4892, 0.4141], ... [-1.4115, 1.2183, -0.5503, 0.6520], ... ] ... ) >>> x = flow.Tensor(arr) >>> y = m(x) >>> y #doctest: +ELLIPSIS tensor([[-0.7797, 0.2264, 0.2458, 0.4163], [ 0.4299, 0.3626, -0.4892, 0.4141], [-1.4115, 1.2183, -0.5503, 0.6520]], dtype=oneflow.float32) example 2: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.Dropout(p=0) >>> arr = np.array( ... [ ... [-0.7797, 0.2264, 0.2458, 0.4163], ... [0.4299, 0.3626, -0.4892, 0.4141], ... [-1.4115, 1.2183, -0.5503, 0.6520], ... ] ...
) >>> x = flow.Tensor(arr) >>> addend = flow.ones((3, 4), dtype=flow.float32) >>> y = m(x, addend=addend) >>> y #doctest: +ELLIPSIS tensor([[ 0.2203, 1.2264, 1.2458, 1.4163], [ 1.4299, 1.3626, 0.5108, 1.4141], [-0.4115, 2.2183, 0.4497, 1.6520]], dtype=oneflow.float32) .. _Improving neural networks by preventing co-adaptation of feature detectors: https://arxiv.org/abs/1207.0580 """, ) add_docstr( oneflow.nn.Dropout1d, """ Randomly zero out entire channels (a channel is a 1D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the batched input is a 1D tensor :math:`\text{input}[i, j]`). Each channel will be zeroed out independently on every forward call with probability :attr:`p` using samples from a Bernoulli distribution. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout1d.html. Usually the input comes from :class:`nn.Conv1d` modules. As described in the paper `Efficient Object Localization Using Convolutional Networks`_ , if adjacent pixels within feature maps are strongly correlated (as is normally the case in early convolution layers) then i.i.d. dropout will not regularize the activations and will otherwise just result in an effective learning rate decrease. In this case, :func:`oneflow.nn.Dropout1d` will help promote independence between feature maps and should be used instead. Args: p (float, optional): probability of an element to be zero-ed. inplace (bool, optional): If set to ``True``, will do this operation in-place Shape: - Input: :math:`(N, C, L)` or :math:`(C, L)`. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input). For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.Dropout1d(p=0) >>> arr = np.array( ... [ ... [-0.7797, 0.2264, 0.2458, 0.4163], ... [0.4299, 0.3626, -0.4892, 0.4141], ... [-1.4115, 1.2183, -0.5503, 0.6520], ... ] ... ) >>> x = flow.Tensor(arr) >>> y = m(x) >>> y #doctest: +ELLIPSIS tensor([[-0.7797, 0.2264, 0.2458, 0.4163], [ 0.4299, 0.3626, -0.4892, 0.4141], [-1.4115, 1.2183, -0.5503, 0.6520]], dtype=oneflow.float32) .. _Efficient Object Localization Using Convolutional Networks: https://arxiv.org/abs/1411.4280 """, ) add_docstr( oneflow.nn.Dropout2d, """ Randomly zero out entire channels (a channel is a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the batched input is a 2D tensor :math:`\text{input}[i, j]`). Each channel will be zeroed out independently on every forward call with probability :attr:`p` using samples from a Bernoulli distribution. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout2d.html. Usually the input comes from :class:`nn.Conv2d` modules. As described in the paper `Efficient Object Localization Using Convolutional Networks`_ , if adjacent pixels within feature maps are strongly correlated (as is normally the case in early convolution layers) then i.i.d. dropout will not regularize the activations and will otherwise just result in an effective learning rate decrease. In this case, :func:`oneflow.nn.Dropout2d` will help promote independence between feature maps and should be used instead. Args: p (float, optional): probability of an element to be zero-ed. inplace (bool, optional): If set to ``True``, will do this operation in-place Shape: - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input). For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.Dropout2d(p=0) >>> arr = np.array( ... [ ... [-0.7797, 0.2264, 0.2458, 0.4163], ... [0.4299, 0.3626, -0.4892, 0.4141], ... [-1.4115, 1.2183, -0.5503, 0.6520], ... ] ... ) >>> x = flow.Tensor(arr) >>> y = m(x) >>> y #doctest: +ELLIPSIS tensor([[-0.7797, 0.2264, 0.2458, 0.4163], [ 0.4299, 0.3626, -0.4892, 0.4141], [-1.4115, 1.2183, -0.5503, 0.6520]], dtype=oneflow.float32) .. _Efficient Object Localization Using Convolutional Networks: https://arxiv.org/abs/1411.4280 """, ) add_docstr( oneflow.nn.Dropout3d, """ Randomly zero out entire channels (a channel is a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the batched input is a 3D tensor :math:`\text{input}[i, j]`). Each channel will be zeroed out independently on every forward call with probability :attr:`p` using samples from a Bernoulli distribution. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout2d.html. Usually the input comes from :class:`nn.Conv3d` modules. As described in the paper `Efficient Object Localization Using Convolutional Networks`_ , if adjacent pixels within feature maps are strongly correlated (as is normally the case in early convolution layers) then i.i.d. dropout will not regularize the activations and will otherwise just result in an effective learning rate decrease. In this case, :func:`oneflow.nn.Dropout3d` will help promote independence between feature maps and should be used instead. Args: p (float, optional): probability of an element to be zeroed. inplace (bool, optional): If set to ``True``, will do this operation in-place Shape: - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input). For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.Dropout3d(p=0) >>> arr = np.array( ... [ ... [-0.7797, 0.2264, 0.2458, 0.4163], ... [0.4299, 0.3626, -0.4892, 0.4141], ... [-1.4115, 1.2183, -0.5503, 0.6520], ... ] ... ) >>> x = flow.Tensor(arr) >>> y = m(x) >>> y #doctest: +ELLIPSIS tensor([[-0.7797, 0.2264, 0.2458, 0.4163], [ 0.4299, 0.3626, -0.4892, 0.4141], [-1.4115, 1.2183, -0.5503, 0.6520]], dtype=oneflow.float32) .. _Efficient Object Localization Using Convolutional Networks: https://arxiv.org/abs/1411.4280 """, ) ================================================ FILE: python/oneflow/framework/docstr/einsum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.einsum, """ einsum(equation, *operands) -> oneflow.Tensor Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation based on the Einstein summation convention. 
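For instance (a brief illustrative sketch; the equation format is described in detail below), ``flow.einsum("ij,jk->ik", A, B)`` sums over the shared subscript ``j`` and therefore matches an ordinary matrix product:

.. code-block:: python

    >>> import oneflow as flow
    >>> A = flow.arange(6).reshape(2, 3).to(flow.float32)
    >>> B = flow.arange(12).reshape(3, 4).to(flow.float32)
    >>> out = flow.einsum("ij,jk->ik", A, B)  # same values as flow.matmul(A, B)
    >>> out.shape
    oneflow.Size([2, 4])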
Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of this format are described below, but the general idea is to label every dimension of the input :attr:`operands` with some subscript and define which subscripts are part of the output. The output is then computed by summing the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the output. For example, matrix multiplication can be computed using einsum as `flow.einsum("ij,jk->ik", A, B)`. Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why). Equation: The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of the input :attr:`operands` in the same order as the dimensions, separating subcripts for each operand by a comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order. The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based on the subscripts, and then summing out the dimensions whose subscripts are not part of the output. Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation followed by the subscripts for the output. For instance, the following equation computes the transpose of a matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and at most once for the output. Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts, e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions), before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements batch matrix multiplication `'...ij,...jk'`. A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis, arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands. .. note:: ``flow.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output. .. note:: This function does not optimize the given expression, so a different formula for the same computation may run faster or consume less memory. 
Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) can optimize the formula for you. Args: equation (String): The subscripts for the Einstein summation. *operands (oneflow.Tensor): The tensors to compute the Einstein summation of. For example: .. code-block:: python >>> import oneflow as flow # trace >>> flow.einsum('ii', flow.arange(4*4).reshape(4,4).to(flow.float32)) tensor(30., dtype=oneflow.float32) # diagonal >>> flow.einsum('ii->i', flow.arange(4*4).reshape(4,4).to(flow.float32)) tensor([ 0., 5., 10., 15.], dtype=oneflow.float32) # outer product >>> x = flow.arange(5).to(flow.float32) >>> y = flow.arange(4).to(flow.float32) >>> flow.einsum('i,j->ij', x, y) tensor([[ 0., 0., 0., 0.], [ 0., 1., 2., 3.], [ 0., 2., 4., 6.], [ 0., 3., 6., 9.], [ 0., 4., 8., 12.]], dtype=oneflow.float32) # batch matrix multiplication >>> As = flow.arange(3*2*5).reshape(3,2,5).to(flow.float32) >>> Bs = flow.arange(3*5*4).reshape(3,5,4).to(flow.float32) >>> flow.einsum('bij,bjk->bik', As, Bs).shape oneflow.Size([3, 2, 4]) # batch permute >>> A = flow.randn(2, 3, 4, 5) >>> flow.einsum('...ij->...ji', A).shape oneflow.Size([2, 3, 5, 4]) # bilinear >>> A = flow.randn(3,5,4) >>> l = flow.randn(2,5) >>> r = flow.randn(2,4) >>> flow.einsum('bn,anm,bm->ba', l, A, r).shape oneflow.Size([2, 3]) """, ) ================================================ FILE: python/oneflow/framework/docstr/erfinv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.erfinv, """Computes the inverse error function of :attr:`input`. The inverse error function is defined in the range :math:`(-1, 1)` as: .. math:: \mathrm{erfinv}(\mathrm{erf}(x)) = x Args: input (oneflow.Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input=flow.tensor(np.random.randn(3,3).astype(np.float32)) >>> of_out=flow.erfinv(input) >>> of_out.shape oneflow.Size([3, 3]) """, ) ================================================ FILE: python/oneflow/framework/docstr/expand.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.expand, """ oneflow.expand(input, *sizes) -> Tensor, This operator expand the input tensor to a larger size. 
Passing -1 as the size for a dimension means not changing the size of that dimension. Tensor can be also expanded to a larger number of dimensions and the new ones will be appended at the front. For the new dimensions, the size cannot be set to -1. Args: input (oneflow.Tensor): the input Tensor. *sizes (oneflow.Size or int): The desired expanded size. Returns: oneflow.Tensor: The result Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = np.array([[[[0, 1]], ... [[2, 3]], ... [[4, 5]]]]).astype(np.int32) >>> input = flow.Tensor(x) >>> input.shape oneflow.Size([1, 3, 1, 2]) >>> out = input.expand(1, 3, 2, 2) >>> out.shape oneflow.Size([1, 3, 2, 2]) """, ) ================================================ FILE: python/oneflow/framework/docstr/flatten.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.flatten, """Flattens a contiguous range of dims into a tensor. Args: start_dim: first dim to flatten (default = 0). end_dim: last dim to flatten (default = -1). For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.randn(32, 1, 5, 5) >>> output = flow.flatten(input, start_dim=1) >>> output.shape oneflow.Size([32, 25]) """, ) ================================================ FILE: python/oneflow/framework/docstr/flip.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.flip, """ flip(input, dims) -> Tensor Reverse the order of a n-D tensor along given axis in dims. .. note:: `flow.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, `flow.flip` is expected to be slower than `np.flip`. Args: input (Tensor): the input tensor dims (a list or tuple): axis to flip on For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> np_arr = np.arange(0, 8).reshape((2, 2, 2)).astype(np.float32) >>> input = flow.Tensor(np_arr) >>> input.shape oneflow.Size([2, 2, 2]) >>> out = flow.flip(input, [0, 1]) >>> out tensor([[[6., 7.], [4., 5.]], [[2., 3.], [0., 1.]]], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/hann_window.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.hann_window, r""" hann_window(window_length, periodic=True, *, device=None, placement=None, sbp=None, dtype=None, requires_grad=False) -> Tensor This function is equivalent to PyTorch’s hann_window function. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.hann_window.html. Hann window function. .. math:: w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = \sin^2 \left( \frac{\pi n}{N - 1} \right), where :math:`N` is the full window size. The input :attr:`window_length` is a positive integer controlling the returned window size. :attr:`periodic` flag determines whether the returned window trims off the last duplicate value from the symmetric window. Therefore, if :attr:`periodic` is true, the :math:`N` in above formula is in fact :math:`\text{window_length} + 1`. Also, we always have ``oneflow.hann_window(L, periodic=True)`` equal to ``oneflow.hann_window(L + 1, periodic=False)[:-1])``. .. note:: If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. Arguments: window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic function. If False, return a symmetric window. Keyword args: dtype (oneflow.dtype, optional): the data type to perform the computation in. Default: if None, uses the global default dtype (see oneflow.get_default_dtype()) when both :attr:`start` and :attr:`end` are real, and corresponding complex dtype when either is complex. device (oneflow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type placement (oneflow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. 
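For example (an illustrative sketch; the exact printed formatting of the result may differ):

.. code-block:: python

    >>> import oneflow as flow
    >>> flow.hann_window(5, periodic=False)
    tensor([0.0000, 0.5000, 1.0000, 0.5000, 0.0000], dtype=oneflow.float32)
    >>> flow.hann_window(4)  # periodic=True, equals hann_window(5, periodic=False)[:-1]
    tensor([0.0000, 0.5000, 1.0000, 0.5000], dtype=oneflow.float32)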
Returns: Tensor: A 1-D tensor of size :math:`(\text{{window_length}},)` containing the window """, ) ================================================ FILE: python/oneflow/framework/docstr/in_top_k.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.in_top_k, """ in_top_k(targets, predictions, k) -> Tensor Says whether the targets are in the top K predictions. Args: targets (Tensor): the target tensor of type int32 or int64. predictions (Tensor): the predictions tensor of type float32 . k (int): Number of top elements to look at for computing precision. Returns: oneflow.Tensor: A Tensor of type bool. Computed Precision at k as a bool Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> targets1 = flow.tensor(np.array([3, 1]), dtype=flow.int32) >>> predictions1 = flow.tensor(np.array([[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0],]), dtype=flow.float32) >>> out1 = flow.in_top_k(targets1, predictions1, k=1) >>> out1 tensor([ True, False], dtype=oneflow.bool) >>> out2 = flow.in_top_k(targets1, predictions1, k=2) >>> out2 tensor([True, True], dtype=oneflow.bool) >>> targets2 = flow.tensor(np.array([3, 1]), dtype=flow.int32, device=flow.device('cuda')) >>> predictions2 = flow.tensor(np.array([[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0],]), dtype=flow.float32, device=flow.device('cuda')) >>> out3 = flow.in_top_k(targets2, predictions2, k=1) >>> out3 tensor([ True, False], device='cuda:0', dtype=oneflow.bool) """, ) ================================================ FILE: python/oneflow/framework/docstr/index_add.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.Tensor.index_add_, r""" index_add_(dim, index, source, *, alpha=1) -> Tensor The interface is consistent with PyTorch. Accumulate the elements of :attr:`alpha` times ``source`` into the :attr:`self` tensor by adding to the indices in the order given in :attr:`index`. For example, if ``dim == 0``, ``index[i] == j``, and ``alpha=-1``, then the ``i``\ th row of ``source`` is subtracted from the ``j``\ th row of :attr:`self`. 
The :attr:`dim`\ th dimension of ``source`` must have the same size as the length of :attr:`index` (which must be a vector), and all other dimensions must match :attr:`self`, or an error will be raised. For a 3-D tensor the output is given as:: self[index[i], :, :] += alpha * src[i, :, :] # if dim == 0 self[:, index[i], :] += alpha * src[:, i, :] # if dim == 1 self[:, :, index[i]] += alpha * src[:, :, i] # if dim == 2 Args: dim (int): dimension along which to index index (Tensor): indices of ``source`` to select from, should have dtype either `oneflow.int64` or `oneflow.int32` source (Tensor): the tensor containing values to add Keyword args: alpha (Number): the scalar multiplier for ``source`` .. code-block:: python >>> import oneflow as flow >>> x = flow.ones(5, 3) >>> t = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=flow.float) >>> index = flow.tensor([0, 4, 2]) >>> x.index_add_(0, index, t) tensor([[ 2., 3., 4.], [ 1., 1., 1.], [ 8., 9., 10.], [ 1., 1., 1.], [ 5., 6., 7.]], dtype=oneflow.float32) >>> x.index_add_(0, index, t, alpha=-1) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.index_add, r""" index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor See :meth:`oneflow.Tensor.index_add_` for function description. """, ) add_docstr( oneflow._C.index_add_, r""" index_add_(dim, index, source, *, alpha=1) -> Tensor In-place version of :func:`oneflow.index_add`. """, ) ================================================ FILE: python/oneflow/framework/docstr/index_select.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.index_select, """ index_select(input, dim, index) -> Tensor Selects values along an axis specified by `dim`. :attr:`index` must be a 1-D integer tensor. :attr:`dim` must be within the range of the input's dimensions, and every value of :attr:`index` must be within the range of the input's ``dim``-th dimension. Note that ``input`` and ``index`` do not broadcast against each other. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.index_select.html. Args: input (Tensor): the source tensor dim (int): the axis along which to index index (Tensor): the 1-D tensor containing the indices to index For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.tensor([[1,2,3],[4,5,6]], dtype=flow.int32) >>> input tensor([[1, 2, 3], [4, 5, 6]], dtype=oneflow.int32) >>> index = flow.tensor([0,1], dtype=flow.int64) >>> output = flow.index_select(input, 1, index) >>> output tensor([[1, 2], [4, 5]], dtype=oneflow.int32) >>> output = input.index_select(1, index) >>> output tensor([[1, 2], [4, 5]], dtype=oneflow.int32) .. Feature Stage of Operator [index_select].
- Maintainer List [@QiangX-man, @hjchen2, @strint] - Current Stage [ ] - Alpha Stage Check List [ ] - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes] - Doc(API Doc must be provided and showed normally on the web page.)[Yes] - Functionality and its' Test [ ] - Functionality is highly compatiable with PyTorch 1.11. [Yes] - eager local [Yes] [@QiangX-man, @hjchen2] - forward [Yes] - backward [Yes] - gpu [Yes] - cpu [Yes] - graph local [ ] [@BBuf, @strint, @hjchen2] - forward [Yes] - backward [ ] - gpu [Yes] - cpu [Yes] - Exception Handling - Exception Message and Hint must be provided [ ] - Beta Stage Check List [ ] - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ] - Doc(Same standard as Alpha Stage)[ ] - Functionality and its' Test [ ] - eager global [ ] - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - graph gloal [ ] - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - Performance and Scalability(Must be evaluated.)[ ] - CUDA kernel [ ] - CPU kernel [ ] - N nodes M devices [ ] - Exception Handling [ ] - Exception Message and Hint must be provided [ ] - Try you best to do Exception Recovery [ ] - Stable Stage Check List [ ] - API(Same standard as Beta Stage)[ ] - Doc(Same standard as Beta Stage)[ ] - Functionality and its' Test [ ] - fp16 and AMP [ ] - NHWC [ ] - Performance and Scalability(Must be evaluated.)[ ] - Exception Handling [ ] """, ) ================================================ FILE: python/oneflow/framework/docstr/inv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.linalg.inv, """linalg.inv(A) -> Tensor Computes the inverse of a square matrix if it exists. Throws a `RuntimeError` if the matrix is not invertible. Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, for a matrix :math:`A \in \mathbb{K}^{n \times n}`, its **inverse matrix** :math:`A^{-1} \in \mathbb{K}^{n \times n}` (if it exists) is defined as .. math:: A^{-1}A = AA^{-1} = \mathrm{I}_n where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. The inverse matrix exists if and only if :math:`A` is `invertible`_. In this case, the inverse is unique. Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if :attr:`A` is a batch of matrices then the output has the same batch dimensions. Args: A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions consisting of invertible matrices. Raises: RuntimeError: if the matrix :attr:`A` or any matrix in the batch of matrices :attr:`A` is not invertible. Examples: .. code-block:: python >>> import oneflow as flow >>> A = flow.tensor([[ 1.3408, -0.7788, 1.0551, -0.5866], ... [ 0.8480, 0.8350, 0.9781, -0.1297], ... [-0.0881, -0.6142, -0.3833, 0.3232], ... 
[ 1.2841, 0.7517, -0.3849, 0.2515]]) >>> flow.linalg.inv(A) tensor([[ 0.3105, -0.0811, 0.1288, 0.5169], ... [-0.3457, 0.1716, -0.7133, 0.1987], ... [-0.0593, 1.1706, 0.8694, -0.6516], ... [-0.6427, 1.6923, 2.8049, -0.2541]], dtype=oneflow.float32) >>> A = flow.tensor([[[ 0.6144, 0.1027, -0.1353], ... [-1.4415, -0.6731, 0.3723], ... [ 0.4069, -0.8940, 1.4056]], ... [[-1.1891, -0.3897, -1.5015], ... [ 0.3028, 1.1040, 0.2600], ... [-1.6970, 0.4238, 0.9146]]]) >>> flow.linalg.inv(A) tensor([[[ 1.6830, 0.0644, 0.1449], ... [-5.9755, -2.5206, 0.0925], ... [-4.2879, -1.6219, 0.7283]], ... ... [[-0.2370, 0.0737, -0.4100], ... [ 0.1892, 0.9579, 0.0384], ... [-0.5274, -0.3070, 0.3148]]], dtype=oneflow.float32) .. _invertible: https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem .. Feature Stage of Operator [linalg.inv]. - Maintainer List [@simonJJJ] - Current Stage [pre Alpha] - Alpha Stage Check List [ ] - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes] - Doc(API Doc must be provided and showed normally on the web page.)[Yes] - Functionality and its' Test [ ] - Functionality is highly compatiable with PyTorch 1.11. [Yes] - eager local [Yes] [@simonJJJ] - forward [Yes] - backward [Yes] - gpu [Yes] - cpu [Yes] - graph local [ ] [@simonJJJ] - forward [Yes] - backward [ ] - gpu [Yes] - cpu [Yes] - Exception Handling - Exception Message and Hint must be provided [Yes] - Beta Stage Check List [ ] - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ] - Doc(Same standard as Alpha Stage)[Yes] - Functionality and its' Test [ ] - eager global [Yes] [@simonJJJ] - forward [Yes] - backward [Yes] - gpu [Yes] - cpu [Yes] - graph gloal [Yes] - forward [Yes] - backward [ ] - gpu [Yes] - cpu [Yes] - Performance and Scalability(Must be evaluated.)[ ] - CUDA kernel [ ] - CPU kernel [ ] - N nodes M devices [ ] - Exception Handling [Yes] - Exception Message and Hint must be provided [Yes] - Try you best to do Exception Recovery [Yes] - Stable Stage Check List [ ] - API(Same standard as Beta Stage)[ ] - Doc(Same standard as Beta Stage)[ ] - Functionality and its' Test [ ] - fp16 and AMP [ ] - NHWC [ ] - Performance and Scalability(Must be evaluated.)[ ] - Exception Handling [ ] """, ) ================================================ FILE: python/oneflow/framework/docstr/is_floating_point.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.is_floating_point, r"""Returns True if the data type of input is a floating point data type i.e., one of `oneflow.float64` , `oneflow.float32` , `oneflow.float16`, and `oneflow.bfloat16`. Args: input (Tensor): the input tensor. For example: .. 
code-block:: python >>> import oneflow as flow >>> input = flow.tensor([1, 2, 3, 4, 5], dtype=flow.int) >>> output = flow.is_floating_point(input) >>> output False """, ) ================================================ FILE: python/oneflow/framework/docstr/lerp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.lerp, """ lerp(start, end, weight) -> Tensor The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.lerp.html. Does a linear interpolation of two tensors `start` and `end` based on a scalar or tensor `weight` and returns the result. The shapes of start` and `end` must be broadcastable. If `weight` is a tensor, then the shapes of `weight`, `start`, and `end` must be broadcastable. .. math:: out_{i} = start_{i} + weight_{i} * (end_{i} - start_{i}) Args: start (oneflow.Tensor): the tensor with the starting points. end (oneflow.Tensor): the tensor with the ending points. weight (float or oneflow.Tensor): the weight for the interpolation formula. For example: .. code-block:: python >>> import oneflow as flow >>> start = flow.arange(1., 5.) >>> end = flow.empty(4).fill_(10) >>> flow.lerp(start, end, 0.5) tensor([5.5000, 6.0000, 6.5000, 7.0000], dtype=oneflow.float32) >>> flow.lerp(start, end, flow.full_like(start, 0.5)) tensor([5.5000, 6.0000, 6.5000, 7.0000], dtype=oneflow.float32) """, ) add_docstr( oneflow.lerp_, """ In-place version of :func:`oneflow.lerp` """, ) ================================================ FILE: python/oneflow/framework/docstr/linalg.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.linalg.cross, """linalg.cross(input, other, dim=-1) -> Tensor Computes the cross product of two 3-dimensional vectors. Supports input of float and double dtypes. Also supports batches of vectors, for which it computes the product along the dimension dim. In this case, the output has the same batch dimensions as the inputs broadcast to a common shape. The documentation is referenced from: https://pytorch.org/docs/1.11/generated/torch.linalg.cross.html Args: input (Tensor): the first input tensor. other (Tensor): the second input tensor. dim (int, optional): the dimension along which to take the cross-product. 
Default: `-1` Raises: RuntimeError: If after broadcasting ``input.size(dim) != 3`` or ``other.size(dim) != 3``. Examples: .. code-block:: python >>> import oneflow as flow >>> a = flow.tensor([[ -0.3956, 1.1455, 1.6895], ... [ -0.5849, 1.3672, 0.3599], ... [ -1.1626, 0.7180, -0.0521], ... [ -0.1339, 0.9902, -2.0225]]) >>> b = flow.tensor([[ -0.0257, -1.4725, -1.2251], ... [ -1.1479, -0.7005, -1.9757], ... [ -1.3904, 0.3726, -1.1836], ... [ -0.9688, -0.7153, 0.2159]]) >>> flow.linalg.cross(a, b) tensor([[ 1.0844, -0.5281, 0.6120], [-2.4491, -1.5687, 1.9791], [-0.8304, -1.3036, 0.5651], [-1.2329, 1.9883, 1.0551]], dtype=oneflow.float32) """, ) add_docstr( oneflow.cross, """cross(input, other, dim=None) -> Tensor Returns the cross product of vectors in dimension `dim` of `input` and `other`. Supports input of float and double dtypes. Also supports batches of vectors, for which it computes the product along the dimension `dim`. In this case, the output has the same batch dimensions as the inputs. If `dim` is not given, it defaults to the first dimension found with the size 3. Note that this might be unexpected. The documentation is referenced from: https://pytorch.org/docs/1.11/generated/torch.cross.html .. warning:: This function may change in a future PyTorch release to match the default behaviour in :func:`oneflow.linalg.cross`. We recommend using :func:`oneflow.linalg.cross`. Args: input (Tensor): the first input tensor. other (Tensor): the second input tensor. dim (int, optional): the dimension to take the cross-product in. Default: `None` Examples: .. code-block:: python >>> import oneflow as flow >>> a = flow.tensor([[ -0.3956, 1.1455, 1.6895], ... [ -0.5849, 1.3672, 0.3599], ... [ -1.1626, 0.7180, -0.0521], ... [ -0.1339, 0.9902, -2.0225]]) >>> b = flow.tensor([[ -0.0257, -1.4725, -1.2251], ... [ -1.1479, -0.7005, -1.9757], ... [ -1.3904, 0.3726, -1.1836], ... [ -0.9688, -0.7153, 0.2159]]) >>> flow.cross(a, b) tensor([[ 1.0844, -0.5281, 0.6120], [-2.4491, -1.5687, 1.9791], [-0.8304, -1.3036, 0.5651], [-1.2329, 1.9883, 1.0551]], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/logaddexp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.logaddexp, """ logaddexp(input, other, *, out=None) -> Tensor The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.logaddexp.html. Logarithm of the sum of exponentiations of the inputs. Calculates pointwise :math:`\log\left(e^x + e^y\right)`. This function is useful in statistics where the calculated probabilities of events may be so small as to exceed the range of normal floating point numbers. In such cases the logarithm of the calculated probability is stored. This function allows adding probabilities stored in such a fashion. Args: input (oneflow.Tensor): the input Tensor. 
other (oneflow.Tensor): the second input Tensor. out (oneflow.Tensor, optional): the output Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> flow.logaddexp(flow.tensor([-1.0]), flow.tensor([-1.0, -2, -3])) tensor([-0.3069, -0.6867, -0.8731], dtype=oneflow.float32) >>> flow.logaddexp(flow.tensor([-100.0, -200, -300]), flow.tensor([-1.0, -2, -3])) tensor([-1., -2., -3.], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/logical_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.logical_and, """ Computes the element-wise logical AND of the given input tensors. Zeros are treated as False and nonzeros are treated as True. Args: input (oneflow.Tensor): The input Tensor other (oneflow.Tensor): The Tensor to compute AND with Returns: oneflow.Tensor: The output Tensor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input1 = flow.tensor(np.array([1, 0, 1]).astype(np.float32), dtype=flow.float32) >>> input2 = flow.tensor(np.array([1, 1, 0]).astype(np.float32), dtype=flow.float32) >>> out = flow.logical_and(input1, input2) >>> out tensor([ True, False, False], dtype=oneflow.bool) """, ) add_docstr( oneflow.logical_or, """ Computes the element-wise logical OR of the given input tensors. Zeros are treated as False and nonzeros are treated as True. Args: input (oneflow.Tensor): The input Tensor other (oneflow.Tensor): The Tensor to compute OR with Returns: oneflow.Tensor: The output Tensor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input1 = flow.tensor(np.array([1, 0, 1]).astype(np.float32), dtype=flow.float32) >>> input2 = flow.tensor(np.array([1, 0, 0]).astype(np.float32), dtype=flow.float32) >>> out = flow.logical_or(input1, input2) >>> out tensor([ True, False, True], dtype=oneflow.bool) """, ) add_docstr( oneflow.logical_xor, """ Computes the element-wise logical XOR of the given input tensors. Zeros are treated as False and nonzeros are treated as True. Args: input (oneflow.Tensor): The input Tensor other (oneflow.Tensor): The Tensor to compute XOR with Returns: oneflow.Tensor: The output Tensor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input1 = flow.tensor(np.array([1, 0, 1]).astype(np.float32), dtype=flow.float32) >>> input2 = flow.tensor(np.array([1, 0, 0]).astype(np.float32), dtype=flow.float32) >>> out = flow.logical_xor(input1, input2) >>> out tensor([False, False, True], dtype=oneflow.bool) """, ) add_docstr( oneflow.logical_not, r""" Computes the element-wise logical NOT of the given input tensors. Zeros are treated as False and nonzeros are treated as True. 
Args: input (oneflow.Tensor): The input Tensor other (oneflow.Tensor): The Tensor to compute NOT with Returns: oneflow.Tensor: The output Tensor For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.tensor([1, 0, -1], dtype=flow.float32) >>> out = flow.logical_not(input) >>> out tensor([False, True, False], dtype=oneflow.bool) """, ) ================================================ FILE: python/oneflow/framework/docstr/loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.triplet_margin_loss, r""" Creates a criterion that measures the triplet loss given an input tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. This is used for measuring a relative similarity between samples. A triplet is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative examples` respectively). The shapes of all input tensors should be :math:`(N, D)`. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.triplet_margin_loss.html. The distance swap is described in detail in the paper `Learning shallow convolutional feature descriptors with triplet losses `__ by V. Balntas, E. Riba et al. The loss function for each sample in the mini-batch is: .. math:: L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} where .. math:: d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p Args: margin (float, optional): Default: :math:`1`. p (float, optional): The norm degree for pairwise distance. Default: :math:`2.0`. swap (bool, optional): The distance swap is described in detail in the paper `Learning shallow convolutional feature descriptors with triplet losses` by V. Balntas, E. Riba et al. Default: ``False``. reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` Shape: - Input: :math:`(N, D)` where :math:`D` is the vector dimension. - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar otherwise. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> triplet_loss = flow.nn.TripletMarginLoss(margin=1.0, p=2) >>> anchor = np.array([[1, -1, 1],[-1, 1, -1], [1, 1, 1]]) >>> positive = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) >>> negative = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]]) >>> output = triplet_loss(flow.Tensor(anchor), flow.Tensor(positive), flow.Tensor(negative)) >>> output tensor(6.2971, dtype=oneflow.float32) """, ) add_docstr( oneflow._C.cross_entropy, r""" cross_entropy(input, target, weight=None, ignore_index=-100, reduction="mean", label_smoothing=0.0) See :class:`~oneflow.nn.CrossEntropyLoss` for details. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.cross_entropy.html. Args: input (Tensor) : :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1` in the case of K-dimensional loss. `input` is expected to contain unnormalized scores (often referred to as logits). target (Tensor) : If containing class indices, shape :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. If containing class probabilities, same shape as the input. weight (Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size `C` ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. When :attr:`size_average` is ``True``, the loss is averaged over non-ignored targets. Note that :attr:`ignore_index` is only applicable when the target contains class indices. Default: -100 reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` label_smoothing (float, optinoal): A float in [0.0, 1.0]. Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in `Rethinking the Inception Architecture for Computer Vision `_. Default: :math:`0.0`. For example: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> input = flow.randn(3, 5, requires_grad=True) >>> target = flow.ones(3, dtype=flow.int64) >>> loss = F.cross_entropy(input, target) >>> loss.backward() """, ) add_docstr( oneflow._C.l1_loss, r""" l1_loss(input, target, reduction="mean") -> Tensor This operator computes the L1 loss between each element in input and target. see :class:`~oneflow.nn.L1Loss` for details. Args: input (Tensor): The input Tensor. target (Tensor): The target Tensor. reduction (string, optional): The reduce type, it can be one of "none", "mean", "sum". Defaults to "mean". 
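With the default ``reduction="mean"`` the result is the mean of the element-wise absolute differences. A small illustrative sketch (the printed formatting may vary):

.. code-block:: python

    >>> import oneflow as flow
    >>> import oneflow.nn.functional as F
    >>> input = flow.tensor([1.0, 2.0, 3.0])
    >>> target = flow.tensor([1.0, 0.0, 5.0])
    >>> F.l1_loss(input, target)  # mean of |1-1|, |2-0|, |3-5| = (0 + 2 + 2) / 3
    tensor(1.3333, dtype=oneflow.float32)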
Examples:: >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> input = flow.randn(3, 4, requires_grad=True) >>> target = flow.rand(3, 4, requires_grad=False) >>> loss = F.l1_loss(input, target) >>> loss.backward() """, ) add_docstr( oneflow._C.mse_loss, r""" mse_loss(input, target, reduction="mean") -> Tensor This operator computes the mean squared error (squared L2 norm) loss between each element in input and target. see :class:`~oneflow.nn.MSELoss` for details. Args: input (Tensor): The input Tensor. target (Tensor): The target Tensor. reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` Examples:: >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> input = flow.randn(3, 4, requires_grad=True) >>> target = flow.rand(3, 4, requires_grad=False) >>> loss = F.mse_loss(input, target) >>> loss.backward() """, ) add_docstr( oneflow._C.smooth_l1_loss, """ smooth_l1_loss(input: Tensor, target: Tensor, size_average: bool=True, reduce: bool=True, reduction: str='mean', beta: float=1.0) -> Tensor Function that uses a squared term if the absolute element-wise error falls below beta and an L1 term otherwise. See :class:`~oneflow.nn.SmoothL1Loss` for details. """, ) add_docstr( oneflow._C.binary_cross_entropy_loss, r""" binary_cross_entropy(input, target, weight=None, reduction='mean') The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.binary_cross_entropy.html. Function that measures the Binary Cross Entropy between the target and input probabilities. See :class:`~oneflow.nn.BCELoss` for details. Args: input: Tensor of arbitrary shape as probabilities. target: Tensor of the same shape as input with values between 0 and 1. weight (Tensor, optional): a manual rescaling weight if provided it's repeated to match input tensor shape reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` Examples:: >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> input = flow.randn(3, 2, requires_grad=True) >>> target = flow.rand(3, 2, requires_grad=False) >>> loss = F.binary_cross_entropy(flow.sigmoid(input), target) >>> loss.backward() """, ) add_docstr( oneflow._C.binary_cross_entropy_with_logits_loss, r""" binary_cross_entropy_with_logits(input, target, weight=None, reduction='mean', pos_weight=None) The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.binary_cross_entropy_with_logits.html. Function that measures Binary Cross Entropy between target and input logits. See :class:`~oneflow.nn.BCEWithLogitsLoss` for details. Args: input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits). 
target: Tensor of the same shape as input with values between 0 and 1 weight (Tensor, optional): a manual rescaling weight if provided it's repeated to match input tensor shape reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` pos_weight (Tensor, optional): a weight of positive examples. Must be a vector with length equal to the number of classes. Examples:: >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> input = flow.randn(3, requires_grad=True) >>> target = flow.randn(3) >>> target[target >= 0] = 1 >>> target[target < 0] = 0 >>> loss = F.binary_cross_entropy_with_logits(input, target) >>> loss.backward() """, ) add_docstr( oneflow._C.kl_div_loss, r""" kl_div_loss(input, target, reduction="mean", log_target=False) `The Kullback-Leibler divergence loss measure `_ See :class:`~oneflow.nn.KLDivLoss` for details. Args: reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. ``'none'``: no reduction will be applied. ``'batchmean'``: the sum of the output will be divided by batchsize. ``'sum'``: the output will be summed. ``'mean'``: the output will be divided by the number of elements in the output. Default: ``'mean'`` log_target (bool, optional): Specifies whether `target` is passed in the log space. Default: ``False`` .. note:: :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition. In the next major release, ``'mean'`` will be changed to be the same as ``'batchmean'``. Shape: - Input: :math:`(N, *)` where :math:`*` means, any number of additional dimensions - Target: :math:`(N, *)`, same shape as the input - Output: scalar by default. If :attr:``reduction`` is ``'none'``, then :math:`(N, *)`, the same shape as the input For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor([-0.9021705, 0.08798598, 1.04686249], dtype=flow.float32) >>> target = flow.tensor([1.22386942, -0.89729659, 0.01615712], dtype=flow.float32) >>> out = flow.nn.functional.kl_div(input, target, reduction="none", log_target=False) >>> out tensor([ 1.3514, 0.0000, -0.0836], dtype=oneflow.float32) >>> out = flow.nn.functional.kl_div(input, target, reduction="mean", log_target=False) >>> out tensor(0.4226, dtype=oneflow.float32) >>> out = flow.nn.functional.kl_div(input, target, reduction="sum", log_target=True) >>> out tensor(5.7801, dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/masked_fill.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.masked_fill, """ Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is True. The shape of :attr:`mask` must be broadcastable with the shape of the underlying tensor. Args: mask (BoolTensor): the boolean mask value (float): the value to fill in with For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> in_arr = np.array( ... [[[-0.13169311, 0.97277078, 1.23305363, 1.56752789], ... [-1.51954275, 1.87629473, -0.53301206, 0.53006478], ... [-1.38244183, -2.63448052, 1.30845795, -0.67144869]], ... [[ 0.41502161, 0.14452418, 0.38968 , -1.76905653], ... [ 0.34675095, -0.7050969 , -0.7647731 , -0.73233418], ... [-1.90089858, 0.01262963, 0.74693893, 0.57132389]]] ... ) >>> fill_value = 8.7654321 # random value e.g. -1e9 3.1415 >>> input = flow.tensor(in_arr, dtype=flow.float32) >>> mask = flow.tensor((in_arr > 0).astype(np.int8), dtype=flow.int) >>> output = flow.masked_fill(input, mask, fill_value) # tensor([[[-0.1317, 8.7654, 8.7654, 8.7654], # [-1.5195, 8.7654, -0.533 , 8.7654], # [-1.3824, -2.6345, 8.7654, -0.6714]], # [[ 8.7654, 8.7654, 8.7654, -1.7691], # [ 8.7654, -0.7051, -0.7648, -0.7323], # [-1.9009, 8.7654, 8.7654, 8.7654]]], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/math_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.abs, r"""Return the absolute value of each element in input tensor:math:`y = |x|` element-wise. Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.array([-1, 2, -3, 4]).astype(np.float32)) >>> flow.abs(x) tensor([1., 2., 3., 4.], dtype=oneflow.float32) """, ) add_docstr( oneflow.add, r""" oneflow.add(input, other, *, alpha=1) -> Tensor Adds `other`, scaled by `alpha`, to `input`. Scalar and broadcast promotation are supported. .. math:: out = input + alpha \times other Args: input (Union[int, float, oneflow.Tensor]): the input tensor. other (Union[int, float, oneflow.Tensor]): the tensor or number to add to input. Keyword args: alpha (Number, optional): the multiplier for `other`. Returns: oneflow.Tensor: the output Tensor. For example: .. 
code-block:: python

    >>> import numpy as np
    >>> import oneflow as flow

    # element-wise add
    >>> x = flow.tensor(np.random.randn(2,3), dtype=flow.float32)
    >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)
    >>> out = flow.add(x, y).numpy()
    >>> out.shape
    (2, 3)

    # scalar add
    >>> x = 5
    >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)
    >>> out = flow.add(x, y).numpy()
    >>> out.shape
    (2, 3)

    # broadcast add
    >>> x = flow.tensor(np.random.randn(1,1), dtype=flow.float32)
    >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)
    >>> out = flow.add(x, y).numpy()
    >>> out.shape
    (2, 3)

    # use alpha
    >>> x = flow.zeros(2, 3)
    >>> y = flow.ones(2, 3)
    >>> out = flow.add(x, y, alpha=10)
    >>> out
    tensor([[10., 10., 10.], [10., 10., 10.]], dtype=oneflow.float32)
""",
)
add_docstr(
    oneflow.floor,
    """
Returns a new tensor with the floor of the elements of :attr:`input`, the largest integer less than or equal to each element.

.. math::
    \\text{out}_{i} = \\lfloor \\text{input}_{i} \\rfloor

Args:
    input (Tensor): the input tensor.

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> input = flow.tensor(np.array([-0.5, 1.5, 0, 0.8]), dtype=flow.float32)
    >>> output = flow.floor(input)
    >>> output.shape
    oneflow.Size([4])
    >>> output.numpy()
    array([-1., 1., 0., 0.], dtype=float32)
    >>> input1 = flow.tensor(np.array([[0.8, 1.0], [-0.6, 2.5]]), dtype=flow.float32)
    >>> output1 = input1.floor()
    >>> output1.shape
    oneflow.Size([2, 2])
    >>> output1.numpy()
    array([[ 0., 1.], [-1., 2.]], dtype=float32)
""",
)
add_docstr(
    oneflow.floor_,
    r"""
In-place version of :func:`oneflow.floor`
""",
)
add_docstr(
    oneflow.div,
    r"""
div(input, other, *, rounding_mode=None)

Computes the division of input by other for each element; scalar and broadcast promotion are supported.

The formula is:

.. math::
    out = \frac{input}{other}

Args:
    input (Union[int, float, oneflow.Tensor]): input.
    other (Union[int, float, oneflow.Tensor]): other.

Keyword Arguments:
    rounding_mode (str, optional): It can be set as ``"floor"`` (rounding the results down) or ``"trunc"`` (rounding the results towards zero). None for default (no rounding).

For example:

.. code-block:: python

    >>> import numpy as np
    >>> import oneflow as flow

    # element-wise divide
    >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32)
    >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)
    >>> out = flow.div(input,other).numpy()
    >>> out.shape
    (2, 3)

    # scalar divide
    >>> input = 5
    >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)
    >>> out = flow.div(input,other).numpy()
    >>> out.shape
    (2, 3)

    # broadcast divide
    >>> input = flow.tensor(np.random.randn(1,1), dtype=flow.float32)
    >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)
    >>> out = flow.div(input,other).numpy()
    >>> out.shape
    (2, 3)

    # rounding_mode
    >>> x = flow.tensor([ 0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
    >>> flow.div(x, 0.5)
    tensor([ 0.7620, 2.5548, -0.5944, -0.7438, 0.9274], dtype=oneflow.float32)
    >>> flow.div(x, 0.5, rounding_mode="floor")
    tensor([ 0., 2., -1., -1., 0.], dtype=oneflow.float32)
    >>> flow.div(x, 0.5, rounding_mode="trunc")
    tensor([0., 2., -0., -0., 0.], dtype=oneflow.float32)
""",
)
add_docstr(
    oneflow.mul,
    r"""Computes the multiplication of input by other for each element; scalar and broadcast promotion are supported.

The formula is:

.. math::
    \text{out}_i = \text{input}_i \times \text{other}_i

For example: ..
code-block:: python >>> import numpy as np >>> import oneflow as flow # element-wise multiply >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.mul(input,other).numpy() >>> out.shape (2, 3) # scalar mutiply >>> input = 5 >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.mul(input,other).numpy() >>> out.shape (2, 3) # broadcast mutiply >>> input = flow.tensor(np.random.randn(1,1), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.mul(input,other).numpy() >>> out.shape (2, 3) """, ) add_docstr( oneflow.reciprocal, r"""Computes the safe reciprocal of x. If x is zero, the reciprocal will be also set to zero. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.tensor(np.array([[1, 2, 3], [4, 5, 6]]), dtype=flow.float32) >>> out = flow.reciprocal(x) >>> out.numpy() array([[1. , 0.5 , 0.33333334], [0.25 , 0.2 , 0.16666667]], dtype=float32) """, ) add_docstr( oneflow.sub, r"""Computes the subtraction of input by other for each element, scalar and broadcast promotation are supported. The formula is: .. math:: out = input - other For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow # element-wise subtract >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.sub(input,other).numpy() >>> out.shape (2, 3) # scalar subtract >>> input = 5 >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.sub(input,other).numpy() >>> out.shape (2, 3) # broadcast subtract >>> input = flow.tensor(np.random.randn(1,1), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.sub(input,other).numpy() >>> out.shape (2, 3) """, ) add_docstr( oneflow.asin, r""" Returns a new tensor with the arcsine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \sin^{-1}(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([-0.5, 0.8, 1.0, -0.8]), dtype=flow.float32) >>> output = flow.asin(input) >>> output.shape oneflow.Size([4]) >>> output tensor([-0.5236, 0.9273, 1.5708, -0.9273], dtype=oneflow.float32) >>> input1 = flow.tensor(np.array([[0.8, 1.0], [-0.6, -1.0]]), dtype=flow.float32) >>> output1 = input1.asin() >>> output1.shape oneflow.Size([2, 2]) >>> output1 tensor([[ 0.9273, 1.5708], [-0.6435, -1.5708]], dtype=oneflow.float32) """, ) add_docstr( oneflow.asinh, r""" Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \sinh^{-1}(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([2, 3, 4]), dtype=flow.float32) >>> output = flow.asinh(input) >>> output.shape oneflow.Size([3]) >>> output tensor([1.4436, 1.8184, 2.0947], dtype=oneflow.float32) >>> input1 = flow.tensor(np.array([[-1, 0, -0.4], [5, 7, 0.8]]), dtype=flow.float32) >>> output1 = input1.asinh() >>> output1.shape oneflow.Size([2, 3]) >>> output1 tensor([[-0.8814, 0.0000, -0.3900], [ 2.3124, 2.6441, 0.7327]], dtype=oneflow.float32) """, ) add_docstr( oneflow.atan, r""" Returns a new tensor with the arctangent of the elements of :attr:`input`. 
.. math:: \text{out}_{i} = \tan^{-1}(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([0.5, 0.6, 0.7]), dtype=flow.float32) >>> output = flow.atan(input) >>> output.shape oneflow.Size([3]) """, ) add_docstr( oneflow.ceil, r"""Returns a new tensor with the ceil of the elements of :attr:`input`, the smallest integer greater than or equal to each element. The equation is: .. math:: \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil = \left\lfloor \text{input}_{i} \right\rfloor + 1 Args: input (oneflow.Tensor): A Tensor. Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.array([0.1, -2, 3.4]).astype(np.float32)) >>> y = flow.ceil(x) >>> y.shape oneflow.Size([3]) >>> y tensor([ 1., -2., 4.], dtype=oneflow.float32) >>> x = flow.tensor(np.array([[2.5, 4.6, 0.6],[7.8, 8.3, 9.2]]).astype(np.float32)) >>> y = x.ceil() >>> y.shape oneflow.Size([2, 3]) >>> y tensor([[ 3., 5., 1.], [ 8., 9., 10.]], dtype=oneflow.float32) >>> x = flow.tensor(np.array([[[2.2, 4.4, 6.5],[7.1, 8.2, 9.3]],[[10.6,11.2,12.2],[13.5,14.8,15.9]]]).astype(np.float32)) >>> y = flow.ceil(x) >>> y.shape oneflow.Size([2, 2, 3]) >>> y tensor([[[ 3., 5., 7.], [ 8., 9., 10.]], [[11., 12., 13.], [14., 15., 16.]]], dtype=oneflow.float32) """, ) add_docstr(oneflow.ceil_, r"""In-place version of :func:`oneflow.ceil`""") add_docstr( oneflow.negative, r"""This operator computes the negative value of Tensor. Args: input (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor( ... np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32 ... ) >>> out = flow.negative(input) >>> out tensor([-1.0000, 1.0000, -2.3000], dtype=oneflow.float32) """, ) add_docstr( oneflow.log1p, r"""Returns a new tensor with the natural logarithm of (1 + input). .. math:: \text{out}_{i}=\log_e(1+\text{input}_{i}) For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.array([1.3, 1.5, 2.7]), dtype=flow.float32) >>> out = flow.log1p(x) >>> out tensor([0.8329, 0.9163, 1.3083], dtype=oneflow.float32) """, ) add_docstr( oneflow.exp, r""" This operator computes the exponential of Tensor. The equation is: .. math:: out = e^x Args: x (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32) >>> y = flow.exp(x) >>> y tensor([ 2.7183, 7.3891, 20.0855], dtype=oneflow.float32) """, ) add_docstr( oneflow.exp2, r""" This operator computes the base two exponential of Tensor. The equation is: .. math:: out = 2^x Args: x (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32) >>> y = flow.exp2(x) >>> y tensor([2., 4., 8.], dtype=oneflow.float32) """, ) add_docstr( oneflow.acos, r""" Returns a new tensor with the inverse cosine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \arccos(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> arr = np.array([0.5, 0.6, 0.7]) >>> input = flow.tensor(arr, dtype=flow.float32) >>> output = flow.acos(input) >>> output tensor([1.0472, 0.9273, 0.7954], dtype=oneflow.float32) """, ) add_docstr( oneflow.acosh, r""" Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \cosh^{-1}(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.tensor(np.array([2, 3, 4]).astype(np.float32)) >>> out1 = flow.acosh(x1) >>> out1 tensor([1.3170, 1.7627, 2.0634], dtype=oneflow.float32) >>> x2 = flow.tensor(np.array([1.5, 2.6, 3.7]).astype(np.float32),device=flow.device('cuda')) >>> out2 = flow.acosh(x2) >>> out2 tensor([0.9624, 1.6094, 1.9827], device='cuda:0', dtype=oneflow.float32) """, ) add_docstr( oneflow.atanh, r"""Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. .. math:: \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> np_arr = np.array([0.5, 0.6, 0.7]).astype(np.float32) >>> input = flow.tensor(np_arr, dtype=flow.float32) >>> output = flow.atanh(input) >>> output tensor([0.5493, 0.6931, 0.8673], dtype=oneflow.float32) """, ) add_docstr( oneflow.sign, r"""Computes the sign of Tensor. .. math:: \text{out}_{i} = \text{sgn}(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.tensor(np.array([-2, 0, 2]).astype(np.float32)) >>> out1 = flow.sign(x1) >>> out1.numpy() array([-1., 0., 1.], dtype=float32) >>> x2 = flow.tensor(np.array([-3.2, -4.5, 5.8]).astype(np.float32),device=flow.device('cuda')) >>> out2 = flow.sign(x2) >>> out2.numpy() array([-1., -1., 1.], dtype=float32) """, ) add_docstr( oneflow.sin, r"""Returns a new tensor with the sine of the elements of :attr:`input`. sin(x: Tensor) -> Tensor .. math:: \text{y}_{i} = \sin(\text{x}_{i}) Args: x (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.tensor(np.array([-0.5461, 0.1347, -2.7266, -0.2746]).astype(np.float32)) >>> y1 = flow.sin(x1) >>> y1 tensor([-0.5194, 0.1343, -0.4032, -0.2712], dtype=oneflow.float32) >>> x2 = flow.tensor(np.array([-1.4, 2.6, 3.7]).astype(np.float32), device=flow.device('cuda')) >>> y2 = flow.sin(x2) >>> y2 tensor([-0.9854, 0.5155, -0.5298], device='cuda:0', dtype=oneflow.float32) """, ) add_docstr( oneflow.sin_, r""" In-place version of :func:`oneflow.sin` """, ) add_docstr( oneflow.sinh, r"""Returns a new tensor with the hyperbolic sine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \sinh(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x1 = flow.tensor(np.array([1, 2, 3]), dtype=flow.float32) >>> x2 = flow.tensor(np.array([1.53123589,0.54242598,0.15117185]), dtype=flow.float32) >>> x3 = flow.tensor(np.array([1,0,-1]), dtype=flow.float32) >>> flow.sinh(x1).numpy() array([ 1.1752012, 3.6268604, 10.017875 ], dtype=float32) >>> flow.sinh(x2).numpy() array([2.20381 , 0.5694193, 0.1517483], dtype=float32) >>> flow.sinh(x3).numpy() array([ 1.1752012, 0. 
, -1.1752012], dtype=float32) """, ) add_docstr( oneflow.tan, r"""Returns the tan value of the elements of :attr:`input`. .. math:: \text{out}_{i} = \tan(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> np_arr = np.array([-1/4*np.pi, 0, 1/4*np.pi]).astype(np.float32) >>> input = flow.tensor(np_arr, dtype=flow.float32) >>> output = flow.tan(input) >>> output tensor([-1., 0., 1.], dtype=oneflow.float32) """, ) add_docstr( oneflow.cos, r""" Returns a new tensor with the cosine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \cos(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> arr = np.array([1.4309, 1.2706, -0.8562, 0.9796]) >>> input = flow.tensor(arr, dtype=flow.float32) >>> output = flow.cos(input).numpy() """, ) add_docstr( oneflow.cosh, r""" Returns a new tensor with the hyperbolic cosine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \cosh(\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> arr = np.array([ 0.1632, 1.1835, -0.6979, -0.7325]) >>> input = flow.tensor(arr, dtype=flow.float32) >>> output = flow.cosh(input).numpy() >>> output array([1.0133467, 1.7859949, 1.2535787, 1.2804903], dtype=float32) """, ) add_docstr( oneflow.erf, r"""Computes the error function of each element. The error function is defined as follows: .. math:: \operatorname{erf}(x)=\frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}} d t Args: x (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.array([0, -1., 10.]), dtype=flow.float32) >>> out = flow.erf(x) >>> out.shape oneflow.Size([3]) >>> out.numpy() array([ 0. , -0.8427008, 1. ], dtype=float32) >>> x = flow.tensor(np.array([[0, -1., 10.], [5, 7, 0.8]]), dtype=flow.float32) >>> out = flow.erf(x) >>> out.shape oneflow.Size([2, 3]) >>> out.numpy() array([[ 0. , -0.8427008 , 1. ], [ 1. , 1. , 0.74210095]], dtype=float32) >>> x = flow.tensor(np.array([[0, -1., 10.], [5, 7, 0.8], [2, 3, 4]]), dtype=flow.float32) >>> out = x.erf() >>> out.shape oneflow.Size([3, 3]) >>> out.numpy() array([[ 0. , -0.8427008 , 1. ], [ 1. , 1. , 0.74210095], [ 0.9953223 , 0.9999779 , 1. ]], dtype=float32) """, ) add_docstr( oneflow.erfc, r"""Computes the complementary error function of each element of input. The complementary error function is defined as follows: .. math:: \operatorname{erfc}(x)=1-\frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}} d t Args: x (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.array([0, -1., 10.]), dtype=flow.float32) >>> out = flow.erfc(x) >>> out tensor([1.0000e+00, 1.8427e+00, 2.8026e-45], dtype=oneflow.float32) >>> x = flow.tensor(np.array([[0, -1., 10.], [5, 7, 0.8]]), dtype=flow.float32) >>> out = flow.erfc(x) >>> out tensor([[1.0000e+00, 1.8427e+00, 2.8026e-45], [1.5375e-12, 4.1838e-23, 2.5790e-01]], dtype=oneflow.float32) """, ) add_docstr( oneflow.expm1, r"""Returns a new tensor with the exponential of the elements minus 1 of :attr:`input`. The equation is: .. math:: y_{i} = e^{x_{i}} - 1 Args: input (oneflow.Tensor): A Tensor. Returns: oneflow.Tensor: The result Tensor For example: .. 
code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> x = flow.tensor(np.array([1, 2, 3]).astype(np.float32))
    >>> y = flow.expm1(x)
    >>> y.shape
    oneflow.Size([3])
    >>> y
    tensor([ 1.7183, 6.3891, 19.0855], dtype=oneflow.float32)
    >>> x = flow.tensor(np.array([[[2, 4, 6],[7, 8, 9]],[[10,11,12],[13,14,15]]]).astype(np.float32))
    >>> y = flow.expm1(x)
    >>> print(y.shape)
    oneflow.Size([2, 2, 3])
    >>> print(y.numpy())
    [[[6.3890562e+00 5.3598152e+01 4.0242880e+02]
      [1.0956332e+03 2.9799580e+03 8.1020840e+03]]
     [[2.2025465e+04 5.9873141e+04 1.6275380e+05]
      [4.4241238e+05 1.2026032e+06 3.2690165e+06]]]
""",
)
add_docstr(
    oneflow.fmod,
    r"""
fmod(input, other, *, out=None) -> Tensor

Computes the element-wise remainder of division.

The dividend and divisor may contain both integer and floating point numbers. The remainder has the same sign as the dividend :attr:`input`.

Supports broadcasting to a common shape, integer and float inputs.

Args:
    input (Tensor): the dividend
    other (Tensor or Scalar): the divisor

Keyword args:
    out (Tensor, optional): the output tensor.

Example::

    >>> import oneflow as flow
    >>> flow.fmod(flow.tensor([-3., -2, -1, 1, 2, 3], dtype=flow.float32), 2.)
    tensor([-1., -0., -1., 1., 0., 1.], dtype=oneflow.float32)
    >>> flow.fmod(flow.tensor([1, 2, 3, 4, 5.], dtype=flow.float32), 1.5)
    tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000], dtype=oneflow.float32)
    >>> flow.fmod(flow.tensor([1, 2, 3, 4., -5]), flow.tensor([4, 2, 1, 3., 1]))
    tensor([1., 0., 0., 1., -0.], dtype=oneflow.float32)
""",
)
add_docstr(
    oneflow.log,
    r"""
Returns a new tensor with the natural logarithm of the elements of :attr:`input`.

.. math::
    y_{i} = \log_{e} (x_{i})

Args:
    input (Tensor): the input tensor.

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> arr = np.random.randn(2, 3, 4, 5)
    >>> input = flow.tensor(arr, dtype=flow.float32)
    >>> output = flow.log(input)
""",
)
add_docstr(
    oneflow.log2,
    """
oneflow.log2(input) -> Tensor

Returns a new tensor with the logarithm to the base 2 of the elements of :attr:`input`.

.. math::
    y_{i} = \\log_{2} (x_{i})

Args:
    input (Tensor): the input tensor.

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> arr = np.random.randn(2, 3, 4, 5)
    >>> input = flow.tensor(arr, dtype=flow.float32)
    >>> output = flow.log2(input)
""",
)
add_docstr(
    oneflow.log10,
    """
oneflow.log10(input) -> Tensor

Returns a new tensor with the logarithm to the base 10 of the elements of :attr:`input`.

.. math::
    y_{i} = \\log_{10} (x_{i})

Args:
    input (Tensor): the input tensor.

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> x = flow.ones(3, 3) * 10
    >>> output = flow.log10(x)
    >>> output
    tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], dtype=oneflow.float32)
""",
)
add_docstr(
    oneflow.minimum,
    r"""Computes the element-wise minimum of x and y.

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> x = flow.tensor((1, 2, -1), dtype=flow.float32)
    >>> y = flow.tensor((3, 0, 4), dtype=flow.float32)
    >>> flow.minimum(x, y)
    tensor([ 1., 0., -1.], dtype=oneflow.float32)
    >>> x = flow.tensor((1,), dtype=flow.float32)
    >>> y = flow.tensor((3, 0, 4), dtype=flow.float32)
    >>> flow.minimum(x, y)
    tensor([1., 0., 1.], dtype=oneflow.float32)
""",
)
add_docstr(
    oneflow.maximum,
    r"""Computes the element-wise maximum of x and y.

For example: ..
code-block:: python >>> import oneflow as flow >>> x = flow.tensor((1, 2, -1), dtype=flow.float32) >>> y = flow.tensor((3, 0, 4), dtype=flow.float32) >>> flow.maximum(x, y) tensor([3., 2., 4.], dtype=oneflow.float32) >>> x = flow.tensor((1,), dtype=flow.float32) >>> y = flow.tensor((3, 0, 4), dtype=flow.float32) >>> flow.maximum(x, y) tensor([3., 1., 4.], dtype=oneflow.float32) """, ) add_docstr( oneflow.median, r""" median(input) -> Tensor Returns the median of the values in input. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.median.html#torch.median .. note:: The median is not unique for :attr:`input` tensors with an even number of elements. In this case the lower of the two medians is returned. Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor((1, 2, -1), dtype=flow.float32) >>> flow.median(x) tensor(1., dtype=oneflow.float32) .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) :noindex: Returns a tuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. If :attr:`keepdim` is ``True``, the output tensors are of the same size as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. Otherwise, :attr:`dim` is squeezed (see :func:`flow.squeeze`), resulting in the outputs tensor having 1 fewer dimension than :attr:`input`. .. note:: The median is not unique for :attr:`input` tensors with an even number of elements in the dimension :attr:`dim`. In this case the lower of the two medians is returned. Args: input (Tensor): the input tensor. dim (int): the dimension to reduce. keepdim (bool): whether the output tensor has :attr:`dim` retained or not. For example: .. code-block:: python >>> import oneflow as flow >>> a = flow.tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], ... [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], ... [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], ... [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) >>> result=flow.median(a, 1) >>> result.values tensor([-0.3982, 0.2270, 0.2488, 0.4742], dtype=oneflow.float32) >>> result.indices tensor([1, 4, 4, 3], dtype=oneflow.int64) .. Feature Stage of Operator [index_select]. - Maintainer List [@simonJJJ] - Current Stage [pre Alpha] - Alpha Stage Check List [ ] - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes] - Doc(API Doc must be provided and showed normally on the web page.)[Yes] - Functionality and its' Test [ ] - Functionality is highly compatiable with PyTorch 1.11. 
[Yes]
- eager local [Yes] [@simonJJJ]
- forward [Yes]
- backward [Yes]
- gpu [Yes]
- cpu [Yes]
- graph local [ ] [@simonJJJ]
- forward [Yes]
- backward [ ]
- gpu [Yes]
- cpu [Yes]
- Exception Handling
- Exception Message and Hint must be provided [Yes]
- Beta Stage Check List [ ]
- API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ]
- Doc(Same standard as Alpha Stage)[Yes]
- Functionality and its Test [ ]
- eager global [Yes] [@simonJJJ]
- forward [Yes]
- backward [Yes]
- gpu [Yes]
- cpu [Yes]
- graph global [ ]
- forward [ ]
- backward [ ]
- gpu [ ]
- cpu [ ]
- Performance and Scalability(Must be evaluated.)[ ]
- CUDA kernel [ ]
- CPU kernel [ ]
- N nodes M devices [ ]
- Exception Handling [ ]
- Exception Message and Hint must be provided [ ]
- Try your best to do Exception Recovery [ ]
- Stable Stage Check List [ ]
- API(Same standard as Beta Stage)[ ]
- Doc(Same standard as Beta Stage)[ ]
- Functionality and its Test [ ]
- fp16 and AMP [ ]
- NHWC [ ]
- Performance and Scalability(Must be evaluated.)[ ]
- Exception Handling [ ]
""",
)
add_docstr(
    oneflow.mode,
    r"""
oneflow.mode(input, dim=-1, keepdim=False)

Returns a namedtuple (values, indices) where values is the mode value of each row of the input tensor in the given dimension dim, i.e. a value which appears most often in that row, and indices is the index location of each mode value found.

By default, :attr:`dim` is the last dimension of the :attr:`input` tensor.

If :attr:`keepdim` is ``True``, the output tensors are of the same size as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. Otherwise, :attr:`dim` is squeezed (see :func:`flow.squeeze`), resulting in the output tensor having 1 fewer dimension than :attr:`input`.

Args:
    input (Tensor): the input tensor.
    dim (int): the dimension to reduce. Default: `-1`
    keepdim (bool): whether the output tensor has dim retained or not. Default: `False`

Returns:
    Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int64)): the result tuple of two output tensors (values, indices)

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> x = flow.tensor([6, 2, 5, 3, 3, 5, 4, 3])
    >>> result = flow.mode(x)
    >>> result.values
    tensor(3, dtype=oneflow.int64)
    >>> result.indices
    tensor(7, dtype=oneflow.int64)
    >>> x = flow.Tensor([[2, 1, 2, 3], [2, 4, 3, 3]])
    >>> result = flow.mode(x, dim=0)
    >>> result.values
    tensor([2., 1., 2., 3.], dtype=oneflow.float32)
    >>> result.indices
    tensor([1, 0, 0, 1], dtype=oneflow.int64)
""",
)
add_docstr(
    oneflow.pow,
    r"""Takes the power of each element in input with exponent and returns a tensor with the result. Exponent can be either a single float number, a single int number, or a tensor with the same shape as input.

When exponent is a scalar value, the operation applied is:

.. math::
    \text{out}_i = x_i ^ \text{exponent}

When exponent is a tensor, the operation applied is:

.. math::
    \text{out}_i = x_i ^ {\text{exponent}_i}

Args:
    input (Tensor): the input tensor.
    exponent (int, float, Tensor): the exponent.

Returns:
    Tensor: the result of raising each element of `input` to the power of `exponent`

For example:
.. code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> x = flow.tensor(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), dtype=flow.float32)
    >>> out = flow.pow(x, 2)
    >>> out
    tensor([ 1., 4., 9., 16., 25., 36.], dtype=oneflow.float32)
    >>> x = flow.tensor(np.array([1.0, 2.0, 3.0, 4.0]), dtype=flow.float32)
    >>> y = flow.tensor(np.array([1.0, 2.0, 3.0, 4.0]), dtype=flow.float32)
    >>> out = flow.pow(x, y)
    >>> out
    tensor([ 1., 4., 27., 256.], dtype=oneflow.float32)
""",
)
add_docstr(
    oneflow.rsqrt,
    r"""Returns a new tensor with the reciprocal of the square-root of each of the elements of :attr:`input`.

.. math::
    \text{out}_{i} = \frac{1}{\sqrt{\text{input}_{i}}}

Args:
    input (Tensor): the input tensor.

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> a = flow.tensor(np.array([1.0, 2.0, 3.0]), dtype=flow.float32)
    >>> out = flow.rsqrt(a).numpy()
    >>> out
    array([1. , 0.70710677, 0.57735026], dtype=float32)
""",
)
add_docstr(
    oneflow.sqrt,
    r"""Returns a new tensor with the square-root of the elements of :attr:`input`.

.. math::
    \text{out}_{i} = \sqrt{\text{input}_{i}}

Args:
    input (Tensor): the input tensor.

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> arr = np.array([1.0, 2.0, 3.0])
    >>> input = flow.tensor(arr, dtype=flow.float32)
    >>> output = flow.sqrt(input).numpy()
    >>> output
    array([1. , 1.4142135, 1.7320508], dtype=float32)
""",
)
add_docstr(
    oneflow.square,
    r"""Returns a new tensor with the square of the elements of :attr:`input`.

.. math::
    \text{out}_{i} = (\text{input}_{i})^2

Args:
    input (Tensor): the input tensor.

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> arr = np.array([1.0, 2.0, 3.0])
    >>> input = flow.tensor(arr, dtype=flow.float32)
    >>> output = flow.square(input).numpy()
    >>> output
    array([1., 4., 9.], dtype=float32)
""",
)
add_docstr(
    oneflow.matmul,
    r"""
matmul(input, other) -> Tensor

This operator applies matrix multiplication to two Tensors.

Args:
    input (oneflow.Tensor): the first Tensor
    other (oneflow.Tensor): the second Tensor

Returns:
    oneflow.Tensor: The result Tensor

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> import numpy as np
    >>> input1 = flow.tensor(np.random.randn(2, 6), dtype=flow.float32)
    >>> input2 = flow.tensor(np.random.randn(6, 5), dtype=flow.float32)
    >>> of_out = flow.matmul(input1, input2)
    >>> of_out.shape
    oneflow.Size([2, 5])
""",
)
add_docstr(
    oneflow.mv,
    r"""
mv(input, vec) -> Tensor

Performs a matrix-vector product of the matrix :attr:`input` and the vector :attr:`vec`.

If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of size `m`, :attr:`out` will be a 1-D tensor of size `n`.

.. note:: This function does not broadcast.

The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.mv.html.

Args:
    input (oneflow.Tensor): matrix to be matrix multiplied
    vec (oneflow.Tensor): vector to be matrix multiplied

Returns:
    oneflow.Tensor: the output Tensor

For example:

.. code-block:: python

    >>> import oneflow as flow
    >>> mat = flow.randn(2, 3)
    >>> vec = flow.randn(3)
    >>> out = flow.mv(mat, vec)
    >>> out.shape
    oneflow.Size([2])
""",
)
add_docstr(
    oneflow.mm,
    r"""
mm(input, mat2) -> Tensor

Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`.

If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor.

.. note:: This function does not broadcast.
For broadcasting matrix products, see :func:`oneflow.matmul`. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.mm.html. Args: input (oneflow.Tensor): the first matrix to be matrix multiplied mat2 (oneflow.Tensor): the second matrix to be matrix multiplied Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import oneflow as flow >>> mat1 = flow.randn(2, 3) >>> mat2 = flow.randn(3, 3) >>> of_out = flow.mm(mat1, mat2) >>> of_out.shape oneflow.Size([2, 3]) """, ) add_docstr( oneflow.round, r"""This operator rounds the value of Blob to the nearest integer. .. note:: This function implements the "round half to even" to break ties when a number is equidistant from two integers (e.g. `round(2.5)` is 2). Args: input (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.tensor(np.array([1.49999, 1.500001, 2.7]).astype(np.float32)) >>> out1 = flow.round(x1) >>> out1.numpy() array([1., 2., 3.], dtype=float32) >>> x2 = flow.tensor(np.array([2.499999, 7.5000001, 5.3, 6.8]).astype(np.float32)) >>> out2 = flow.round(x2) >>> out2.numpy() array([2., 8., 5., 7.], dtype=float32) """, ) add_docstr(oneflow.round_, r"""In-place version of :func:`oneflow.round`.""") add_docstr( oneflow.std, r""" Returns the standard-deviation of each row of the :attr:`input` tensor in the dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, reduce over all of them. If keepdim is True, the output tensor is of the same size as input except in the dimension(s) dim where it is of size 1. Otherwise, dim is squeezed, resulting in the output tensor having 1 (or len(dim)) fewer dimension(s). If :attr:`unbiased` is ``False``, then the standard-deviation will be calculated via the biased estimator. Otherwise, Bessel's correction will be used. Args: input (Tensor): the input tensor. dim (int or tuple of ints): the dimension or dimensions to reduce. unbiased (bool): whether to use the unbiased estimation or not keepdim (bool): whether the output tensor has `dim` retained or not. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> arr = np.array([1.0, 2.0, 3.0]) >>> input = flow.tensor(arr) >>> output = flow.std(input, dim=0).numpy() >>> output array(1.) """, ) add_docstr( oneflow.var, r"""Returns the variance of each row of the `input` tensor in the given dimension `dim`. If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, dim is squeezed (see `flow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). Args: input (Tensor): the input tensor. dim (int or tuple of ints): the dimension or dimensions to reduce. Defaults to None. unbiased (bool, optional): whether to use Bessel’s correction (:math:`\delta N = 1`). Defaults to True. keepdim (bool, optional): whether the output tensor has dim retained or not. Defaults to False. Returns: Tensor: The result of variance on the specified axis of input Tensor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor(np.random.randn(2, 3, 4, 5)) >>> output = flow.var(input, 1, True) """, ) add_docstr( oneflow.dot, r"""This operator computes the dot product of tensor input and other. The equation is: $$ \\sum_{i=1}^{n}(x[i] * y[i]) $$ Args: input (Tensor): first tensor in the dot product. 
other (Tensor): second tensor in the dot product. Shape: - input: Input must be 1D. - other: Other must be 1D. For example: .. code-block:: python >>> import oneflow as flow >>> flow.dot(flow.Tensor([2, 3]), flow.Tensor([2, 1])) tensor(7., dtype=oneflow.float32) """, ) add_docstr( oneflow.select, r""" Slices the self tensor along the selected dimension at the given index. This function returns a view of the original tensor with the given dimension removed. Args: input (Tensor): the input tensor. dim (int): the dimension to slice. select (int): the index to select with. Returns: oneflow.Tensor: the output Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.rand(3, 4, 5) >>> out = flow.select(input, 0, 1) >>> out.size() oneflow.Size([4, 5]) >>> out = flow.select(input, 1, 1) >>> out.size() oneflow.Size([3, 5]) """, ) add_docstr( oneflow.movedim, r""" Moves the dimension(s) of input at the position(s) in source to the position(s) in destination. Other dimensions of input that are not explicitly moved remain in their original order and appear at the positions not specified in destination. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.movedim.html. Args: input (Tensor): the input tensor. source (int or a list): Original positions of the dims to move. These must be unique. destination (int or a list): Destination positions for each of the original dims. These must also be unique. Returns: oneflow.Tensor: the output Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32) >>> output = flow.movedim(input, 1, 0) >>> output.shape oneflow.Size([3, 2, 4, 5]) >>> output = flow.movedim(input, (1, 2), (0, 1)) >>> output.shape oneflow.Size([3, 4, 2, 5]) """, ) add_docstr( oneflow.as_strided, r""" as_strided(input, size, stride, storage_offset=None) -> Tensor Create a view of an existing oneflow.Tensor input with specified size, stride and storage_offset. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.as_strided.html. Args: input (Tensor): the input tensor. size (tuple or ints): the shape of the output tensor. stride (tuple or ints): the stride of the output tensor. storage_offset (int): the offset in the underlying storage of the output tensor Returns: oneflow.Tensor: the output Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.rand(2,3,5) >>> output = flow.as_strided(input, (2,3,3), (1,2,3), 1) >>> output.size() oneflow.Size([2, 3, 3]) """, ) add_docstr( oneflow.addcmul, r""" oneflow.addcmul(input, tensor1, tensor2, *, value=1) -> Tensor Performs the element-wise multiplication of tensor1 by tensor2, multiply the result by the scalar value and add it to input. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.addcmul.html .. math:: \text{out}_i = \text{input}_i + value \times\ \text{tensor1}_i \times\ \text{tensor2}_i Args: input (Tensor): the tensor to be added. tensor1 (Tensor): the tensor to be multiplied. tensor2 (Tensor): the tensor to be multiplied. Keyword args: value (Number, optional): multiplier for :math:`tensor1 * tensor2`. Returns: oneflow.Tensor: the output Tensor. For example: .. 
code-block:: python >>> import oneflow as flow >>> input = flow.rand(2, 3, 4) >>> tensor1 = flow.rand(2, 3, 4) >>> tensor2 = flow.rand(2, 3, 4) >>> out = flow.addcmul(input, tensor1, tensor2, value=2) >>> out.size() oneflow.Size([2, 3, 4]) """, ) add_docstr( oneflow.eye, """oneflow.eye(n, m, *, device=None, requires_grad=False, placement=None, sbp) -> Tensor This operator creates a 2-D Tensor with ones on the diagonal and zeros elsewhere. Args: n (int): the number of rows. m (int, optional): the number of colums with default being n. Defaults to None. Keyword args: device(Union[flow.device, str], optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. placement(oneflow._oneflow_internal.placement, optional): The placement attribute allows you to specify which physical device the tensor is stored on. sbp(Union[oneflow._oneflow_internal.sbp.sbp, List[oneflow._oneflow_internal.sbp.sbp]], optional): When creating a global tensor, specify the SBP of the tensor. Returns: oneflow.Tensor: The result tensor with ones on the diagonal and zeros elsewhere. For example: .. code-block:: python >>> import oneflow as flow >>> out = flow.eye(3, 3) >>> out tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=oneflow.float32) >>> out = flow.eye(3, 3, device="cuda") >>> out tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], device='cuda:0', dtype=oneflow.float32) """, ) add_docstr( oneflow.tensor_split, r""" Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections . The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.tensor_split.html. Args: input (Tensor): the input tensor. indices_or_sections (int or a list): If indices_or_sections is an integer n , input is split into n sections along dimension dim.If input is divisible by n along dimension dim, each section will be of equal size, input.size (dim) / n. If input is not divisible by n, the sizes of the first int(input.size(dim) % n). sections will have size int(input.size(dim) / n) + 1, and the rest will have size int(input.size(dim) / n). If indices_or_sections is a list or tuple of ints, then input is split along dimension dim at each of the indices in the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors input[:2], input[2:3], and input[3:].If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU. dim (int): dimension along which to split the tensor. Returns: oneflow.TensorTuple: the output TensorTuple. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.rand(3,4,5) >>> output = flow.tensor_split(input,(2,3),2) >>> output[0].size() oneflow.Size([3, 4, 2]) >>> output[1].size() oneflow.Size([3, 4, 1]) >>> output[2].size() oneflow.Size([3, 4, 2]) """, ) add_docstr( oneflow.hsplit, r""" hsplit(input, indices_or_sections) -> List of Tensors The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.hsplit.html. Splits `input`, a tensor with one or more dimensions, into multiple tensors horizontally according to `indices_or_sections`. Each split is a view of `input`. 
If `input` is one dimensional this is equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=0) (the split dimension is zero), and if `input` has two or more dimensions it’s equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), except that if `indices_or_sections` is an integer it must evenly divide the split dimension or a runtime error will be thrown. Args: input (Tensor): the input tensor. indices_or_sections (int or a list): See argument in :func:`oneflow.tensor_split()`. Returns: oneflow.TensorTuple: the output TensorTuple. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.rand(3,4,5,6) >>> output = flow.hsplit(input,(1,3)) >>> output[0].size() oneflow.Size([3, 1, 5, 6]) >>> output[1].size() oneflow.Size([3, 2, 5, 6]) >>> output[2].size() oneflow.Size([3, 1, 5, 6]) """, ) add_docstr( oneflow.vsplit, r""" Splits input, a tensor with two or more dimensions, into multiple tensors vertically according to indices_or_sections. Each split is a view of input. This is equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=0) (the split dimension is 0), except that if indices_or_sections is an integer it must evenly divide the split dimension or a runtime error will be thrown. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.vsplit.html. Args: input (Tensor): the input tensor. indices_or_sections (int or a list): If indices_or_sections is an integer n , input is split into n sections along dimension dim.If input is divisible by n along dimension dim, each section will be of equal size, input.size (dim) / n. If input is not divisible by n, the sizes of the first int(input.size(dim) % n). sections will have size int(input.size(dim) / n) + 1, and the rest will have size int(input.size(dim) / n). If indices_or_sections is a list or tuple of ints, then input is split along dimension dim at each of the indices in the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors input[:2], input[2:3], and input[3:].If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU. Returns: oneflow.TensorTuple: the output TensorTuple. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.rand(4, 4, 5, 6) >>> output = flow.vsplit(input, (1, 3)) >>> output[0].size() oneflow.Size([1, 4, 5, 6]) >>> output[1].size() oneflow.Size([2, 4, 5, 6]) >>> output[2].size() oneflow.Size([1, 4, 5, 6]) """, ) add_docstr( oneflow.cumsum, r"""oneflow.cumsum(input, dim) -> Tensor This operator computes the cumulative sum of input elements in the given dimension. The equation is: $$ y_{i}=x_{0}+x_{1}+...+x_{i} $$ Args: input (Tensor): the input ND tensor. dim (int): the dimension to do cumsum, valid range is [-N, N-1), N is tensor's dimensions Returns: oneflow.Tensor: The result tensor with cumsum result. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.ones(3, 3) >>> dim = 1 >>> flow.cumsum(input, dim) tensor([[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]], dtype=oneflow.float32) """, ) add_docstr( oneflow.cumprod, """oneflow.cumprod(input, dim) -> Tensor This operator computes the cumulative product of input elements in the given dimension. The equation is: $$ y_{i}=x_{0}*x_{1}*...*x_{i} $$ Args: input (Tensor): the input tensor. 
dim (int): the dimension to do cumsum whose valid range is [-N, N-1), and the N is tensor's dimensions Returns: oneflow.Tensor: The result tensor with cumprod result. For example: .. code-block:: python >>> import oneflow as flow >>> input=flow.tensor([1, 2, 3]) >>> flow.cumprod(input, dim=0) tensor([1, 2, 6], dtype=oneflow.int64) """, ) add_docstr( oneflow.trunc, r"""trunc(input) -> Tensor The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.trunc.html Returns a new tensor with the truncated integer values of the elements of :attr:`input`. Args: input(Tensor): the input tensor. Example:: >>> import oneflow as flow >>> a = flow.tensor([ 3.4742, 0.5466, -0.8008, -0.9079]) >>> flow.trunc(a) tensor([3., 0., -0., -0.], dtype=oneflow.float32) """, ) add_docstr( oneflow.digamma, r"""digamma(input) -> Tensor .. math:: \digamma(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)} Args: input (Tensor): the tensor to compute the digamma function on .. note:: This function is similar to SciPy's `scipy.special.digamma`. Example:: >>> import oneflow as flow >>> a = flow.tensor([1, 0.5]) >>> flow.digamma(a) tensor([-0.5772, -1.9635], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/meshgrid.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.meshgrid, """ Take :math:`N` tensors, each of which can be either scalar or 1-dimensional vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.meshgrid.html#torch.meshgrid Args: tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be treated as tensors of size :math:`(1,)` automatically. indexing ((string, optional): the indexing mode, either "xy" or "ij", defaults to "ij". If "ij" is selected, the dimensions are in the same order as the cardinality of the inputs. If "xy" is selected, the first dimension corresponds to the cardinality of the second input and the second dimension corresponds to the cardinality of the first input. Returns: seq (sequence of Tensors): If the input has :math:`k` tensors of size :math:`(N_1,), (N_2,), \\ldots , (N_k,)`, then the output would also have :math:`k` tensors, where all tensors are of size :math:`(N_1, N_2, \\ldots , N_k)`. For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input1 = flow.tensor(np.array([2, 2, 3]), dtype=flow.float32) >>> input2 = flow.tensor(np.array([4, 5, 6]), dtype=flow.float32) >>> of_x, of_y = flow.meshgrid(input1, input2) >>> of_x tensor([[2., 2., 2.], [2., 2., 2.], [3., 3., 3.]], dtype=oneflow.float32) >>> of_y tensor([[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/module.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.nn.Module.to_consistent, """ This interface is no longer available, please use :func:`oneflow.nn.Module.to_global` instead. """, ) add_docstr( oneflow.nn.Module.to_global, """ Convert the parameters and buffers to global. It performs the same :func:`oneflow.Tensor.to_global` conversion to each parameter and buffer in this module. Note: This method modifies the module in-place. Both placement and sbp are required if the parameters and buffers of this module are local, otherwise at least one of placement and sbp is required. Args: placement (flow.placement, optional): the desired placement of the parameters and buffers in this module. Default: None sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp of the parameters and buffers in this module. Default: None For example: .. code-block:: python >>> import oneflow as flow >>> m = flow.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3) >>> m.to_global(placement=flow.placement("cpu", ranks=[0]), sbp=[flow.sbp.split(0)]) >>> m.weight.is_global True >>> m.bias.is_global True """, ) ================================================ FILE: python/oneflow/framework/docstr/nms.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.nms, """ Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). NMS iteratively removes lower scoring boxes which have an IoU greater than iou_threshold with another (higher scoring) box. Args: boxes (Tensor[N, 4]): boxes to perform NMS on. They are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
scores (Tensor[N]): scores for each one of the boxes iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold Returns: Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores """, ) ================================================ FILE: python/oneflow/framework/docstr/nonzero.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.nonzero, """nonzero(input, *, out=None, as_tuple=False) -> Tensor or tuple of Tensors .. note:: When :attr:`as_tuple` is ``False`` (default): returns a 2-D tensor where each row is the index for a nonzero value. When :attr:`as_tuple` is ``True``: returns a tuple of 1-D index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor contains nonzero indices for a certain dimension. See below for more details on the two behaviors. **When** :attr:`as_tuple` **is** ``False`` **(default)**: Returns a tensor containing the indices of all non-zero elements of :attr:`input`. Each row in the result contains the indices of a non-zero element in :attr:`input`. The result is sorted lexicographically, with the last index changing the fastest (C-style). If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor :attr:`out` is of size :math:`(z \\times n)`, where :math:`z` is the total number of non-zero elements in the :attr:`input` tensor. **When** :attr:`as_tuple` **is** ``True``: Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, each containing the indices (in that dimension) of all non-zero elements of :attr:`input` . If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` tensors of size :math:`z`, where :math:`z` is the total number of non-zero elements in the :attr:`input` tensor. As a special case, when :attr:`input` has zero dimensions and a nonzero scalar value, it is treated as a one-dimensional tensor with one element. Args: input(Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor containing indices Returns: Tensor or tuple of Tensors: If :attr:`as_tuple` is ``False``, the output tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for each dimension, containing the indices of each nonzero element along that dimension. Example:: >>> import oneflow as flow >>> flow.nonzero(flow.tensor([1, 1, 1, 0, 1])) tensor([[0], [1], [2], [4]], dtype=oneflow.int64) >>> flow.nonzero(flow.tensor([[0.6, 0.0, 0.0, 0.0], ... [0.0, 0.4, 0.0, 0.0], ... [0.0, 0.0, 1.2, 0.0], ... [0.0, 0.0, 0.0,-0.4]])) tensor([[0, 0], [1, 1], [2, 2], [3, 3]], dtype=oneflow.int64) >>> flow.nonzero(flow.tensor([1, 1, 1, 0, 1]), as_tuple=True) (tensor([0, 1, 2, 4], dtype=oneflow.int64),) >>> flow.nonzero(flow.tensor([[0.6, 0.0, 0.0, 0.0], ... 
[0.0, 0.4, 0.0, 0.0], ... [0.0, 0.0, 1.2, 0.0], ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) (tensor([0, 1, 2, 3], dtype=oneflow.int64), tensor([0, 1, 2, 3], dtype=oneflow.int64)) >>> flow.nonzero(flow.tensor(5), as_tuple=True) (tensor([0], dtype=oneflow.int64),) """, ) ================================================ FILE: python/oneflow/framework/docstr/norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.linalg.vector_norm, """linalg.vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor Computes a vector norm. Supports input of float, double dtypes. This function does not necessarily treat multidimensonal attr:`input` as a batch of vectors, instead: - If :attr:`dim`\\ `= None`, :attr:`input` will be flattened before the norm is computed. - If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions and the other dimensions will be treated as batch dimensions. This behavior is for consistency with :func:`flow.linalg.norm`. :attr:`ord` defines the vector norm that is computed. The following norms are supported: ====================== ======================================================== :attr:`ord` vector norm ====================== ======================================================== `2` (default) `2`-norm (see below) `inf` `max(abs(x))` `-inf` `min(abs(x))` `0` `sum(x != 0)` other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}` ====================== ======================================================== where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. Args: input (Tensor): tensor, flattened by default, but this behavior can be controlled using :attr:`dim`. ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2` dim (int, Tuple[int], optional): dimensions over which to compute the norm. See above for the behavior when :attr:`dim`\\ `= None`. Default: `None` keepdim (bool, optional): If set to `True`, the reduced dimensions are retained in the result as dimensions with size one. Default: `False` Returns: A real-valued tensor. Examples: .. code-block:: python >>> import oneflow as flow >>> from oneflow import linalg as LA >>> import numpy as np >>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4) >>> a tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32) >>> b = a.reshape(3, 3) >>> b tensor([[-4., -3., -2.], [-1., 0., 1.], [ 2., 3., 4.]], dtype=oneflow.float32) >>> LA.vector_norm(a, ord=3.5) tensor(5.4345, dtype=oneflow.float32) >>> LA.vector_norm(b, ord=3.5) tensor(5.4345, dtype=oneflow.float32) """, ) add_docstr( oneflow.linalg.matrix_norm, """linalg.matrix_norm(input, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor Computes a matrix norm. Support input of float, double, cfloat and cdouble dtypes. 
Also supports batches of matrices: the norm will be computed over the dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will be treated as batch dimensions. The output will have the same batch dimensions. :attr:`ord` defines the matrix norm that is computed. The following norms are supported: ====================== ======================================================== :attr:`ord` matrix norm ====================== ======================================================== `'fro'` (default) Frobenius norm `'nuc'` -- not supported yet -- `inf` `max(sum(abs(x), dim=1))` `-inf` `min(sum(abs(x), dim=1))` `1` `max(sum(abs(x), dim=0))` `-1` `min(sum(abs(x), dim=0))` `2` -- not supported yet -- `-2` -- not supported yet -- ====================== ======================================================== where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. Args: input (Tensor): tensor with two or more dimensions. By default its shape is interpreted as `(*, m, n)` where `*` is zero or more batch dimensions, but this behavior can be controlled using :attr:`dim`. ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'` dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)` keepdim (bool, optional): If set to `True`, the reduced dimensions are retained in the result as dimensions with size one. Default: `False` Returns: A real-valued tensor. Examples: .. code-block:: python >>> import oneflow as flow >>> from oneflow import linalg as LA >>> import numpy as np >>> a = flow.tensor(np.arange(9, dtype=np.float32)).reshape(3,3) >>> a tensor([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], dtype=oneflow.float32) >>> LA.matrix_norm(a) tensor(14.2829, dtype=oneflow.float32) >>> LA.matrix_norm(a, ord=-1) tensor(9., dtype=oneflow.float32) >>> b = a.expand(2, -1, -1) >>> b tensor([[[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]], dtype=oneflow.float32) >>> LA.matrix_norm(b, dim=(0, 2)) tensor([ 3.1623, 10.0000, 17.2627], dtype=oneflow.float32) """, ) add_docstr( oneflow.linalg.norm, """linalg.norm(input, ord=None, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor Returns the matrix norm or vector norm of a given tensor. This function can calculate one of eight different types of matrix norms, or one of an infinite number of vector norms, depending on both the number of reduction dimensions and the value of the `ord` parameter. Args: input (Tensor): The input tensor. If dim is None, input must be 1-D or 2-D, unless :attr:`ord` is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D will be returned. Its data type must be either a floating point or complex type. For complex inputs, the norm is calculated on of the absolute values of each element. If the input is complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat). ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. 
Default: `'None'` The following norms can be calculated: ============== ============================ ================================= :attr:`ord` norm for matrices norm for vectors ============== ============================ ================================= None Frobenius norm `2`-norm `'fro'` Frobenius norm -- not supported -- `'nuc'` -- not supported yet -- -- not supported -- `inf` `max(sum(abs(x), dim=1))` `max(abs(x))` `-inf` `min(sum(abs(x), dim=1))` `min(abs(x))` `0` -- not supported -- `sum(x != 0)` `1` `max(sum(abs(x), dim=0))` as below `-1` `min(sum(abs(x), dim=0))` as below `2` -- not supported yet -- as below `-2` -- not supported yet -- as below other -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}` ============== ============================ ================================= where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. dim (int, 2-tuple of ints, 2-list of ints, optional): If :attr:`dim` is an int, vector norm will be calculated over the specified dimension. If :attr:`dim` is a 2-tuple of ints, matrix norm will be calculated over the specified dimensions. If :attr:`dim` is None, matrix norm will be calculated when the input tensor has two dimensions, and vector norm will be calculated when the input tensor has one dimension. Default: ``None`` keepdim (bool, optional): If set to True, the reduced dimensions are retained in the result as dimensions with size one. Default: ``False`` out (Tensor, optional): The output tensor. For example: .. code-block:: python >>> import oneflow as flow >>> from oneflow import linalg as LA >>> import numpy as np >>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4) >>> a tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32) >>> b = a.reshape(3, 3) >>> b tensor([[-4., -3., -2.], [-1., 0., 1.], [ 2., 3., 4.]], dtype=oneflow.float32) >>> LA.norm(a) tensor(7.7460, dtype=oneflow.float32) >>> LA.norm(b) tensor(7.7460, dtype=oneflow.float32) >>> LA.norm(b, 'fro') tensor(7.7460, dtype=oneflow.float32) >>> LA.norm(a, float('inf')) tensor(4., dtype=oneflow.float32) >>> LA.norm(b, float('inf')) tensor(9., dtype=oneflow.float32) >>> LA.norm(a, -float('inf')) tensor(0., dtype=oneflow.float32) >>> LA.norm(b, -float('inf')) tensor(2., dtype=oneflow.float32) >>> LA.norm(a, 1) tensor(20., dtype=oneflow.float32) >>> LA.norm(b, 1) tensor(7., dtype=oneflow.float32) >>> LA.norm(a, -1) tensor(0., dtype=oneflow.float32) >>> LA.norm(b, -1) tensor(6., dtype=oneflow.float32) >>> LA.norm(a, 2) tensor(7.7460, dtype=oneflow.float32) >>> LA.norm(a, -2) tensor(0., dtype=oneflow.float32) >>> LA.norm(a, 3) tensor(5.8480, dtype=oneflow.float32) >>> LA.norm(a, -3) tensor(0., dtype=oneflow.float32) >>> c = flow.tensor([[1., 2., 3.], ... [-1, 1, 4]]) >>> LA.norm(c, dim=0) tensor([1.4142, 2.2361, 5.0000], dtype=oneflow.float32) >>> LA.norm(c, dim=1, keepdim = True) tensor([[3.7417], [4.2426]], dtype=oneflow.float32) >>> LA.norm(c, ord=1, dim=1) tensor([6., 6.], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.normalize, """nn.functional.normalize(input: Tensor, p: float=2.0, dim: int=0, epsilon: float=1e-12) -> Tensor Performs :math:`L_p` normalization of inputs over specified dimension For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as: .. math:: v = \\frac{v}{\max(\\lVert v \\rVert_p, \\epsilon)}. 
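For instance, for :math:`v = (3, 4)` with :math:`p = 2`, :math:`\\lVert v \\rVert_2 = \\sqrt{3^2 + 4^2} = 5`, so the normalized vector is :math:`(3/5, 4/5) = (0.6, 0.8)`.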
With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization. But note that the gradient calculation of the input tensor has different results on different frameworks when `input.shape[dim] = 1`. Args: input (oneflow.Tensor): input tensor of any shape p (float): the exponent value in the norm formulation. Default: 2 dim (int): the dimension to reduce. Default: 1 eps (float): small value to avoid division by zero. Default: 1e-12 For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([[1, 2], [3, 4]], dtype=flow.float32) >>> out = flow.nn.functional.normalize(x, 2, 0) >>> out tensor([[0.3162, 0.4472], [0.9487, 0.8944]], dtype=oneflow.float32) >>> out = flow.nn.functional.normalize(x, 2, 1) >>> out tensor([[0.4472, 0.8944], [0.6000, 0.8000]], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/normalization.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.nn.functional.layer_norm, """nn.functional.layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05) -> Tensor Applies Layer Normalization for last certain number of dimensions. See :class:`~oneflow.nn.LayerNorm` for details. """, ) ================================================ FILE: python/oneflow/framework/docstr/oneflow.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.set_num_threads, """ Sets the number of threads used for intraop parallelism on CPU. .. WARNING:: To ensure that the correct number of threads is used, set_num_threads must be called before running eager, eager globe or ddp. """, ) add_docstr( oneflow.get_default_dtype, """oneflow.get_default_dtype() -> oneflow._oneflow_internal.dtype Returns the default floating point dtype. Returns: oneflow.dtype: The default floating point dtype. For example: .. 
code-block:: python >>> import oneflow as flow >>> flow.set_default_dtype(flow.float32) >>> flow.get_default_dtype() oneflow.float32 >>> flow.set_default_dtype(flow.float64) >>> flow.get_default_dtype() oneflow.float64 >>> flow.set_default_tensor_type(flow.FloatTensor) >>> flow.get_default_dtype() oneflow.float32 """, ) add_docstr( oneflow.set_default_dtype, """oneflow.set_default_dtype() -> None Sets the default floating point type for those source operators which create Tensor. The default floating point type is ``oneflow.float32``. Args: dtype (oneflow.dtype): The floating point dtype. For example: .. code-block:: python >>> import oneflow >>> oneflow.set_default_dtype(oneflow.float64) >>> x = oneflow.randn(2, 3) >>> x.dtype oneflow.float64 >>> oneflow.set_default_dtype(oneflow.float32) >>> x = oneflow.randn(2, 3) >>> x.dtype oneflow.float32 """, ) ================================================ FILE: python/oneflow/framework/docstr/onehot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.one_hot, r""" one_hot(input, num_classes=-1, on_value=1, off_value=0) This operator generates a onehot Tensor from input Tensor. If input Tensor's rank is `N`, the corresponding onehot Tensor's rank is `N+1`. Args: input (Tensor): The input Tensor. num_classes (int): The length of onehot Tensor. on_value (Union[int, float], optional): The fill value when `x[i] == i`. Defaults to 1. off_value (Union[int, float], optional): The fill value when `x[i] != i`. Defaults to 0. Note: The data type of input tensor should be `int32` or `int64`. Returns: oneflow.Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input=flow.tensor(np.array([0, 3, 1, 2]).astype(np.int64), dtype=flow.int64) >>> out = flow.nn.functional.one_hot(input, num_classes=5) >>> out tensor([[1, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]], dtype=oneflow.int64) """, ) ================================================ FILE: python/oneflow/framework/docstr/pooling.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.adaptive_avg_pool1d, """ adaptive_avg_pool1d(input, output_size) -> Tensor Applies a 1D adaptive average pooling over an input signal composed of several input planes. See :class:`~oneflow.nn.AdaptiveAvgPool1d` for details and output shape. Args: input: the input tensor output_size: the target output size (single integer) For examples: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> arr = np.array([[[ 0.0558, -0.6875, -1.6544, -0.6226, 0.1018, 0.0502, -1.2538, 0.1491]]]) >>> input = flow.tensor(arr, dtype=flow.float32) >>> flow.nn.functional.adaptive_avg_pool1d(input, output_size=[4]) tensor([[[-0.3158, -1.1385, 0.0760, -0.5524]]], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.adaptive_avg_pool2d, """ adaptive_avg_pool2d(input, output_size) -> Tensor Applies a 2D adaptive average pooling over an input signal composed of several input planes. See :class:`~oneflow.nn.AdaptiveAvgPool2d` for details and output shape. Args: input: the input tensor output_size: the target output size (single integer or double-integer tuple) For examples: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> arr = np.array([[[[ 0.1004, 0.0488, -1.0515, 0.9466],[ 0.4538, 0.2361, 1.3437, 0.398 ],[ 0.0558, -0.6875, -1.6544, -0.6226],[ 0.1018, 0.0502, -1.2538, 0.1491]]]]) >>> input = flow.tensor(arr, dtype=flow.float32) >>> outputs = flow.nn.functional.adaptive_avg_pool2d(input, (2, 2)) """, ) add_docstr( oneflow._C.adaptive_avg_pool3d, """ adaptive_avg_pool3d(input, output_size) -> Tensor Applies a 3D adaptive average pooling over an input signal composed of several input planes. See :class:`~oneflow.nn.AdaptiveAvgPool3d` for details and output shape. Args: input: the input tensor output_size: the target output size (single integer or triple-integer tuple) For examples: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.random.randn(1, 1, 4, 4, 4), dtype=flow.float32) >>> output = flow.nn.functional.adaptive_avg_pool3d(input, (2, 2, 2)) """, ) add_docstr( oneflow._C.avg_pool1d, """ avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor Applies a 1D average pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.avg_pool1d.html See :class:`~oneflow.nn.AvgPool1d` for details and output shape. Args: input: input tensor of shape :math:`(\\text{minibatch} , \\text{in_channels} , iW)` kernel_size: the size of the window. Can be a single number or a tuple `(kW,)` stride: the stride of the window. Can be a single number or a tuple `(sW,)`. Default: :attr:`kernel_size` padding: implicit zero paddings on both sides of the input. Can be a single number or a tuple `(padW,)`. Default: 0 ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape. Default: ``False`` count_include_pad: when True, will include the zero-padding in the averaging calculation. 
Default: ``True`` Examples:: >>> # pool of square window of size=3, stride=2 >>> import oneflow >>> input = oneflow.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=oneflow.float32) >>> oneflow.nn.functional.avg_pool1d(input, kernel_size=3, stride=2) tensor([[[2., 4., 6.]]], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.avg_pool2d, """ avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=0) -> Tensor Applies 2D average-pooling operation in :math:`kH \\times kW` regions by step size :math:`sH \\times sW` steps. The number of output features is equal to the number of input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.avg_pool2d.html. See :class:`~oneflow.nn.AvgPool2d` for details and output shape. Args: input: input tensor :math:`(\\text{minibatch} , \\text{in_channels} , iH , iW)` kernel_size: size of the pooling region. Can be a single number or a tuple `(kH, kW)` stride: stride of the pooling operation. Can be a single number or a tuple `(sH, sW)`. Default: :attr:`kernel_size` padding: implicit zero paddings on both sides of the input. Can be a single number or a tuple `(padH, padW)`. Default: 0 ceil_mode: when True, will use `ceil` instead of `floor` in the formula to compute the output shape. Default: ``False`` count_include_pad: when True, will include the zero-padding in the averaging calculation. Default: ``True`` divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: 0 """, ) add_docstr( oneflow._C.avg_pool3d, """ avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=0) -> Tensor Applies 3D average-pooling operation in :math:`kT \\times kH \\times kW` regions by step size :math:`sT \\times sH \\times sW` steps. The number of output features is equal to :math:`\\lfloor\\frac{\\text{input planes}}{sT}\\rfloor`. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.avg_pool3d.html See :class:`~oneflow.nn.AvgPool3d` for details and output shape. Args: input: input tensor :math:`(\\text{minibatch} , \\text{in_channels} , iT \\times iH , iW)` kernel_size: size of the pooling region. Can be a single number or a tuple `(kT, kH, kW)` stride: stride of the pooling operation. Can be a single number or a tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` padding: implicit zero paddings on both sides of the input. Can be a single number or a tuple `(padT, padH, padW)`, Default: 0 ceil_mode: when True, will use `ceil` instead of `floor` in the formula to compute the output shape count_include_pad: when True, will include the zero-padding in the averaging calculation divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: 0 """, ) add_docstr( oneflow._C.max_unpool1d, """ max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_size=None) -> Tensor Computes a partial inverse of ``MaxPool1d``. See :class:`MaxUnpool1d` for details. """, ) add_docstr( oneflow._C.max_unpool2d, """ max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_size=None) -> Tensor Computes a partial inverse of ``MaxPool2d``. See :class:`MaxUnpool2d` for details. """, ) add_docstr( oneflow._C.max_unpool3d, """ max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_size=None) -> Tensor Computes a partial inverse of ``MaxPool3d``. 
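A minimal usage sketch (illustrative only; it assumes the indices come from :class:`oneflow.nn.MaxPool3d` with ``return_indices=True``, mirroring PyTorch): .. code-block:: python >>> import oneflow as flow >>> x = flow.arange(8.).reshape(1, 1, 2, 2, 2) >>> out, indices = flow.nn.MaxPool3d(kernel_size=2, return_indices=True)(x) # assumes return_indices=True is supported, as in PyTorch >>> flow.nn.functional.max_unpool3d(out, indices, kernel_size=2).shape # places the maxima back into a zero-filled tensor of the unpooled shape oneflow.Size([1, 1, 2, 2, 2])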
See :class:`MaxUnpool3d` for details. """, ) ================================================ FILE: python/oneflow/framework/docstr/quantile.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.quantile, """ quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.quantile.html. Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with indices ``i`` and ``j`` in the sorted order, result is computed according to the given :attr:`interpolation` method as follows: - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. - ``lower``: ``a``. - ``higher``: ``b``. - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). - ``midpoint``: ``(a + b) / 2``. If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. .. note:: By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. Args: input (oneflow.Tensor): the input Tensor. q (float or oneflow.Tensor): a scalar or 1D tensor of values in the range [0, 1]. dim (int, optional): the dimension to reduce. Default is None. keepdim (bool, optional): whether the output tensor has dim retained or not. Default is False interpolation (str, optional): interpolation method to use when the desired quantile lies between two data points. Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. Default is ``linear``. out (oneflow.Tensor, optional): the output Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> a = flow.arange(8.) >>> q = flow.tensor([0.25, 0.5, 0.75]) >>> flow.quantile(a, q, dim=0, keepdim=True) tensor([[1.7500], [3.5000], [5.2500]], dtype=oneflow.float32) >>> a = flow.arange(4.) 
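>>> # for [0., 1., 2., 3.], the 0.6 quantile sits at fractional index 0.6 * 3 = 1.8, between the values 1. and 2.; the calls below show how each interpolation mode resolves that position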
>>> flow.quantile(a, 0.6, interpolation="linear") tensor(1.8000, dtype=oneflow.float32) >>> flow.quantile(a, 0.6, interpolation="lower") tensor(1., dtype=oneflow.float32) >>> flow.quantile(a, 0.6, interpolation="higher") tensor(2., dtype=oneflow.float32) >>> flow.quantile(a, 0.6, interpolation="midpoint") tensor(1.5000, dtype=oneflow.float32) >>> flow.quantile(a, 0.6, interpolation="nearest") tensor(2., dtype=oneflow.float32) >>> flow.quantile(a, 0.4, interpolation="nearest") tensor(1., dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/random.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.bernoulli, """ bernoulli(input, p, *, generator=None, out=None) This operator returns a Tensor with binary random numbers (0 / 1) from a Bernoulli distribution. Args: input (Tensor): the input tensor of probability values for the Bernoulli distribution p (float, optional): the probability for the Bernoulli distribution. If specified, Bernoulli distribution will use p for sampling, not input generator (Generator, optional): a pseudorandom number generator for sampling out (Tensor, optional): the output tensor. Shape: - Input: :math:`(*)`. Input can be of any shape - Output: :math:`(*)`. Output is of the same shape as input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> arr = np.array( ... [ ... [1.0, 1.0, 1.0], ... [1.0, 1.0, 1.0], ... [1.0, 1.0, 1.0], ... ] ... ) >>> x = flow.tensor(arr, dtype=flow.float32) >>> y = flow.bernoulli(x) >>> y tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], dtype=oneflow.float32) >>> y = flow.bernoulli(x, 1) >>> y tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], dtype=oneflow.float32) >>> y = flow.bernoulli(x, p=0) >>> y tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.randn, """ randn(*size, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 (also called the standard normal distribution). The shape of the tensor is defined by the variable argument ``size``. Args: size (int... or oneflow.Size): Defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size. dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``. generator (flow.Generator, optional): a pseudorandom number generator for sampling device (flow.device, optional): The desired device of returned local tensor. If None, uses the current device. placement (flow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (flow.sbp, optional): The desired sbp of returned global tensor.
It must be equal with the numbers of placement. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.randn(3,3) # construct local tensor >>> x.shape oneflow.Size([3, 3]) >>> x.is_global False >>> placement = flow.placement("cpu", ranks=[0]) >>> sbp = flow.sbp.broadcast >>> x = flow.randn(3,3,placement=placement,sbp=sbp) # construct global tensor >>> x.is_global True """, ) add_docstr( oneflow._C.randn_like, """ randn_like(input, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a tensor with the same size as `input` that is filled with random numbers from a normal distribution with mean 0 and variance 1. flow.randn_like(input) is equivalent to flow.randn(input.size(), dtype=input.dtype, device=input.device). Args: input (oneflow.Tensor): the size of ``input`` will determine size of the output tensor. dtype (flow.dtype, optional): The desired data type of returned tensor. defaults to the dtype of `input`. generator (flow.Generator, optional): a pseudorandom number generator for sampling device (flow.device, optional): The desired device of returned local tensor. If None, defaults to the device of `input`. placement (flow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (flow.sbp, optional): The desired sbp of returned global tensor. It must be equal with the numbers of placement, If None, will construct local tensor. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.randn(3,3) # construct local tensor >>> y = flow.randn_like(x) >>> y.shape oneflow.Size([3, 3]) >>> y.is_global False >>> placement = flow.placement("cpu", ranks=[0]) >>> sbp = flow.sbp.broadcast >>> z = flow.randn_like(y, placement=placement, sbp=sbp) # construct global tensor >>> z.is_global True """, ) add_docstr( oneflow._C.rand, """ rand(*size, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1) The shape of the tensor is defined by the variable argument ``size``. Args: size (int... or oneflow.Size): Defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size. dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``. generator (flow.Generator, optional): a pseudorandom number generator for sampling device (flow.device, optional): The desired device of returned local tensor. If None, uses the current device. placement (flow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (flow.sbp, optional): The desired sbp of returned global tensor. It must be equal with the numbers of placement. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. 
code-block:: python >>> import oneflow as flow >>> x = flow.rand(3,3) # construct local tensor >>> x.shape oneflow.Size([3, 3]) >>> x.is_global False >>> placement = flow.placement("cpu", ranks=[0]) >>> sbp = flow.sbp.broadcast >>> x = flow.rand(3, 3, placement=placement, sbp=sbp) # construct global tensor >>> x.is_global True """, ) add_docstr( oneflow._C.normal, r""" The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.normal.html normal(mean, std, *, generator=None, out=None) -> Tensor Returns a tensor of random numbers drawn from separate normal distributions whose mean and standard deviation are given. The :attr:`mean` is a tensor with the mean of each output element's normal distribution The :attr:`std` is a tensor with the standard deviation of each output element's normal distribution The shapes of :attr:`mean` and :attr:`std` don't need to match, but the total number of elements in each tensor need to be the same. .. note:: Infers the output shape from input arrays :attr:`mean` and :attr:`std`. The output shape will have a dimensionality equal to the max of :attr:`mean` and :attr:`std`. Dimensions with size 1 in either :attr:`mean` or :attr:`std` are expanded to match the other. Args: mean (Tensor): the tensor of per-element means std (Tensor): the tensor of per-element standard deviations Keyword args: generator (Generator, optional): Random number generator. Defaults to `oneflow::DefaultGenerator` if not provided. out (Tensor, optional): Output tensor, will be resized and filled with the result. If not provided, a new tensor is created. Example: .. code-block:: python >>> import oneflow as flow >>> generator = flow.Generator() >>> generator.manual_seed(0) #doctest: +ELLIPSIS >>> z = flow.normal(mean=flow.arange(1., 11.), std=flow.arange(1, 0, -0.1), generator=generator) >>> z[:5] tensor([3.2122, 3.0468, 3.6192, 4.3387, 5.6261], dtype=oneflow.float32) normal(mean=0.0, std, `*`, generator=None, out=None) -> Tensor. Similar to the function above, but the means are shared among all drawn elements. Args: mean (float, optional) : the mean for all distributions std (Tensor) : the tensor of per-element standard deviations Keyword args: generator (Generator, optional): Random number generator. Defaults to `oneflow::DefaultGenerator` if not provided. out (Tensor, optional): Output tensor, will be resized and filled with the result. If not provided, a new tensor is created. Example: .. code-block:: python >>> import oneflow as flow >>> flow.normal(mean=0.5, std=flow.arange(1., 6.)).shape oneflow.Size([5]) normal(mean, std=1.0, `*`, generator=None, out=None) -> Tensor Similar to the function above, but the standard deviations are shared among all drawn elements. Args: mean (Tensor): the tensor of per-element means std (float, optional): the standard deviation Keyword args: generator (Generator, optional): Random number generator. Defaults to `oneflow::DefaultGenerator` if not provided. out (Tensor): The output tensor Returns: Tensor: The output tensor, with random normal values. Example: .. code-block:: python >>> import oneflow as flow >>> flow.normal(mean=flow.arange(1., 6.)).shape oneflow.Size([5]) normal(mean, std, size, `*`, out=None, placement=None, sbp=None, generator=None, dtype=None, device=None, requires_grad=False) -> Tensor Returns a tensor of random numbers drawn from separate normal distributions whose mean and standard deviation are given. 
Args: mean (float): the mean for all distributions std (float): the standard deviation for all distributions size (int...): a sequence of integers defining the shape of the output tensor. Keyword args: out (Tensor, optional): the output tensor. placement (flow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (flow.sbp, optional): The desired sbp of returned global tensor. It must be equal with the numbers of placement. generator(:class:`oneflow.Generator`, optional): a pseudorandom number generator for sampling dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor. Default: `oneflow.float32`. device: the desired device of returned tensor. Default: cpu. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Example: .. code-block:: python >>> import oneflow as flow >>> generator = flow.Generator() >>> generator.manual_seed(0) #doctest: +ELLIPSIS >>> y = flow.normal(0, 1, 5, generator=generator) >>> y tensor([2.2122, 1.1631, 0.7740, 0.4838, 1.0434], dtype=oneflow.float32) """, ) add_docstr( oneflow._C.randint, """ randint(low=0, high, size, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive). The shape of the tensor is defined by the variable argument ``size``. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.randint.html. Args: low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. high (int): One above the highest integer to be drawn from the distribution. size (tuple or oneflow.Size): Defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size. Keyword args: dtype (oneflow.dtype, optional): The desired data type of returned tensor. Default: ``flow.int64``. generator (oneflow.Generator, optional) – a pseudorandom number generator for sampling device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the current device. placement (oneflow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (oneflow.sbp, optional): The desired sbp of returned global tensor. It must be equal with the numbers of placement. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> generator = flow.Generator() >>> generator.manual_seed(0) #doctest: +ELLIPSIS >>> y = flow.randint(0, 5, (3,3), generator=generator) # construct local tensor >>> y tensor([[2, 2, 3], [4, 3, 4], [2, 4, 2]], dtype=oneflow.int64) >>> y.is_global False >>> placement = flow.placement("cpu", ranks=[0]) >>> y = flow.randint(0, 5, (3,3), generator=generator, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor >>> y.is_global True """, ) add_docstr( oneflow._C.randint_like, """ randint_like(input, low=0, high, size, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive). The interface is consistent with PyTorch. 
The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.randint_like.html. Args: input (oneflow.Tensor): the size of ``input`` will determine size of the output tensor. low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. high (int): One above the highest integer to be drawn from the distribution. Keyword args: dtype (oneflow.dtype, optional): The desired data type of returned tensor. Default: ``flow.int64``. generator (oneflow.Generator, optional): a pseudorandom number generator for sampling device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the current device. placement (oneflow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (oneflow.sbp, optional): The desired sbp of returned global tensor. It must be equal with the numbers of placement. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> generator = flow.Generator() >>> generator.manual_seed(0) #doctest: +ELLIPSIS >>> x = flow.randn(2, 2, generator=generator) >>> y = flow.randint_like(x, 0, 5, generator=generator) # construct local tensor >>> y tensor([[3, 4], [2, 4]], dtype=oneflow.int64) >>> y.is_global False >>> placement = flow.placement("cpu", ranks=[0]) >>> y = flow.randint_like(x, 0, 5, generator=generator, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor >>> y.is_global True """, ) add_docstr( oneflow._C.randperm, r""" randperm(n, *, generator=None, dtype=oneflow.int64, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a random permutation of integers from ``0`` to ``n - 1``. Args: n (int): the upper bound (exclusive) Keyword args: generator(:class:`oneflow.Generator`, optional): a pseudorandom number generator for sampling dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor. Default: ``oneflow.int64``. device: the desired device of returned tensor. Default: cpu. placement (:class:`flow.placement`, optional): The desired device of returned global tensor. If None, will construct local tensor. sbp (:class:`flow.sbp`, optional): The desired sbp of returned global tensor. It must be equal with the numbers of placement. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Example: .. code-block:: python >>> import oneflow as flow >>> generator = flow.Generator() >>> generator.manual_seed(0) #doctest: +ELLIPSIS >>> y = flow.randperm(5, generator=generator) # construct local tensor >>> y tensor([2, 4, 3, 0, 1], dtype=oneflow.int64) >>> y.is_global False >>> placement = flow.placement("cpu", ranks=[0]) >>> y = flow.randperm(5, generator=generator, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor >>> y.is_global True """, ) add_docstr( oneflow.multinomial, """ multinomial(input, num_samples, replacement=False, generator=None) -> LongTensor Returns a tensor where each row contains :attr:`num_samples` indices sampled from the multinomial probability distribution located in the corresponding row of tensor :attr:`input`. .. note:: The rows of :attr:`input` do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum. Indices are ordered from left to right according to when each was sampled (first samples are placed in first column).
If :attr:`input` is a vector, :attr:`out` is a vector of size :attr:`num_samples`. If :attr:`input` is a matrix with `m` rows, :attr:`out` is a matrix of shape :math:`(m \\times num\_samples)`. If replacement is ``True``, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row. .. note:: When drawn without replacement, :attr:`num_samples` must be lower than number of non-zero elements in :attr:`input` (or the min number of non-zero elements in each row of :attr:`input` if it is a matrix). Args: input (Tensor): the input tensor containing probabilities num_samples (int): number of samples to draw replacement (bool, optional): whether to draw with replacement or not For example: .. code-block:: python >>> import oneflow as flow >>> gen = flow.manual_seed(0) >>> weights = flow.tensor([0, 10, 3, 0], dtype=flow.float) # create a tensor of weights >>> flow.multinomial(weights, 2) tensor([1, 2], dtype=oneflow.int64) >>> flow.multinomial(weights, 4, replacement=True) tensor([1, 2, 1, 1], dtype=oneflow.int64) """, ) ================================================ FILE: python/oneflow/framework/docstr/reduce_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.max, """ oneflow.max(input, dim=None, keepdim=False) Computes the maximum value of all elements in the input tensor. Args: input (oneflow.Tensor): the Input Tensor dim (int, optional): the dimension to reduce. Default: `None` keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False` Returns: Tensor or Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int64)): If :attr:`dim` is `None`, returns the maximum value of all elements in the `input` tensor. Otherwise, returns a tuple of Tensor (values, indices), where the `values` are the maximum value of each row of the `input` tensor in the given dimension `dim`, and the `indices` are the indices of the maximum elements in the original input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([[4, 1, 5], [2, 6, 3]]) >>> flow.max(input) tensor(6., dtype=oneflow.float32) >>> result = flow.max(input, dim=1) >>> result.values tensor([5., 6.], dtype=oneflow.float32) >>> result.indices tensor([2, 1], dtype=oneflow.int64) """, ) add_docstr( oneflow.min, """ oneflow.min(input, dim=None, keepdim=False) Computes the minimum value of all elements in the input tensor. Args: input (oneflow.Tensor): the Input Tensor dim (int, optional): the dimension to reduce. Default: `None` keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False` Returns: Tensor or Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int64)): If :attr:`dim` is `None`, returns the minimum value of all elements in the `input` tensor.
Otherwise, returns a tuple of Tensor (values, indices), where the `values` are the minimum value of each row of the `input` tensor in the given dimension `dim`, and the `indices` are the indices of the minimum elements in the original input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([[4, 1, 5], [2, 6, 3]]) >>> flow.min(input) tensor(1., dtype=oneflow.float32) >>> result = flow.min(input, dim=1) >>> result.values tensor([1., 2.], dtype=oneflow.float32) >>> result.indices tensor([1, 0], dtype=oneflow.int64) """, ) add_docstr( oneflow.sum, """ oneflow.sum(input, dim=None, keepdim=False) -> Tensor Computes the sum of the elements of a tensor in the given dimension. If the dimension is None, the sum of all elements will be calculated. If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). Args: input (oneflow.Tensor): the Input Tensor dim (int or tuple of ints, optional): the dimension to reduce. Default: `None` keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False` For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) >>> flow.sum(input) tensor(21., dtype=oneflow.float32) >>> flow.sum(input, dim=0) tensor([5., 7., 9.], dtype=oneflow.float32) >>> flow.sum(input, dim=1) tensor([ 6., 15.], dtype=oneflow.float32) """, ) add_docstr( oneflow.mean, """ oneflow.mean(input, dim=None, keepdim=False) -> Tensor Computes the mean of the elements of a tensor in the given dimension. If the dimension is None, the mean of all elements will be calculated. If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). Args: input (oneflow.Tensor): the Input Tensor dim (int or tuple of ints, optional): the dimension to reduce. Default: `None` keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False` For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) >>> flow.mean(input) tensor(3.5000, dtype=oneflow.float32) >>> flow.mean(input, dim=0) tensor([2.5000, 3.5000, 4.5000], dtype=oneflow.float32) >>> flow.mean(input, dim=1) tensor([2., 5.], dtype=oneflow.float32) """, ) add_docstr( oneflow.prod, """ oneflow.prod(input, dim=None, keepdim=False) -> Tensor Computes the product of the elements of a tensor in the given dimension. If the dimension is None, the product of all elements will be calculated. If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). Args: input (oneflow.Tensor): the Input Tensor dim (int or tuple of ints, optional): the dimension to reduce. Default: `None` keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False` For example: ..
code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) >>> flow.prod(input) tensor(720., dtype=oneflow.float32) >>> flow.prod(input, dim=0) tensor([ 4., 10., 18.], dtype=oneflow.float32) >>> flow.prod(input, dim=1) tensor([ 6., 120.], dtype=oneflow.float32) """, ) add_docstr( oneflow.all, """ oneflow.all(input, dim=None, keepdim=False) -> Tensor For each row of `input` in the given dimension `dim`, returns True if all elements in the row evaluate to True, and False otherwise. If the dimension is None, computes whether all elements in the input tensor evaluate to True. If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). Args: input (oneflow.Tensor): the Input Tensor dim (int, optional): the dimension to reduce. Default: `None` keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False` For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) < 4 >>> input tensor([[ True, True, True], [False, False, False]], dtype=oneflow.bool) >>> flow.all(input) tensor(False, dtype=oneflow.bool) >>> flow.all(input, 1) tensor([ True, False], dtype=oneflow.bool) >>> flow.all(input, 1, True) tensor([[ True], [False]], dtype=oneflow.bool) """, ) add_docstr( oneflow.any, """ oneflow.any(input, dim=None, keepdim=False) -> Tensor For each row of `input` in the given dimension `dim`, returns True if any element in the row evaluates to True, and False otherwise. If the dimension is None, computes whether any element in the input tensor evaluates to True. If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). Args: input (oneflow.Tensor): the Input Tensor dim (int, optional): the dimension to reduce. Default: `None` keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False` For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) < 4 >>> input tensor([[ True, True, True], [False, False, False]], dtype=oneflow.bool) >>> flow.any(input) tensor(True, dtype=oneflow.bool) >>> flow.any(input, 0) tensor([True, True, True], dtype=oneflow.bool) >>> flow.any(input, 0, True) tensor([[True, True, True]], dtype=oneflow.bool) """, ) add_docstr( oneflow.nansum, r"""oneflow.nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor Returns the sum of each row of the ``input`` tensor in the given dimension ``dim``, treating Not a Numbers (NaNs) as zero. If ``dim`` is a list of dimensions, reduce over all of them. If ``keepdim`` is ``True``, the output tensor is of the same size as ``input`` except in the dimension(s) ``dim`` where it is of size 1. Otherwise, ``dim`` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or ``len(dim)``) fewer dimension(s). The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nansum.html. Args: input (oneflow.Tensor): the Input Tensor dim (int, optional): the dimension to reduce. Default: ``None`` keepdim (bool, optional): whether the output tensor has ``dim`` retained or not.
Default: `False` dtype (oneflow.dtype, optional): the desired data type of returned tensor. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. Default: ``None``. Example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1., 2., float("nan")]) >>> flow.nansum(x) tensor(3., dtype=oneflow.float32) >>> x = flow.tensor([[1., float("nan")], [float("nan"), 2]]) >>> flow.nansum(x, dim=1) tensor([1., 2.], dtype=oneflow.float32) >>> x = flow.tensor([float("nan") for i in range(3)]) >>> flow.nansum(x) tensor(0., dtype=oneflow.float32) """, ) add_docstr( oneflow.logsumexp, r""" oneflow.logsumexp(input, dim, keepdim=False) -> Tensor Returns the log of summed exponentials of each row of the :attr:`input` tensor in the given dimension :attr:`dim`. The computation is numerically stabilized. For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is .. math:: \text{logsumexp}(x)_{{i}} = \log \sum_j \exp(x_{{ij}}) The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.12/generated/torch.logsumexp.html. Args: input (oneflow.Tensor): the Input Tensor dim (int or tuple of ints): the dimension or dimensions to reduce. keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False` For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) >>> flow.logsumexp(input, 0) tensor([4.0486, 5.0486, 6.0486], dtype=oneflow.float32) >>> flow.logsumexp(input, 1) tensor([3.4076, 6.4076], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/repeat.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.repeat, """ repeat(input, sizes) -> Tensor This operator repeat the input tensor to a larger size along the specified dimensions. Args: input (oneflow.Tensor): the input Tensor. sizes (flow.Shape or List): The number of times to repeat this tensor along each dimension. Returns: oneflow.Tensor: The result Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32) >>> input = flow.Tensor(np_arr) >>> out = input.repeat(1, 1, 2, 2) >>> out.shape oneflow.Size([5, 3, 12, 18]) >>> out = input.repeat(2, 1, 1, 2, 2) >>> out.shape oneflow.Size([2, 5, 3, 12, 18]) """, ) ================================================ FILE: python/oneflow/framework/docstr/repeat_interleave.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.repeat_interleave, """ repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor Repeat elements of a tensor. .. warning:: This is different from :meth:`oneflow.Tensor.repeat` but similar to ``numpy.repeat``. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.repeat_interleave.html Args: input (oneflow.Tensor): the input Tensor. repeats (Tensor or int): The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis. dim (int, optional): The dimension along which to repeat values. By default, use the flattened input array, and return a flat output array. Keyword args: output_size (int, optional): Total output size for the given axis ( e.g. sum of repeats). If given, it will avoid stream syncronization needed to calculate output shape of the tensor. Returns: oneflow.Tensor: Repeated tensor which has the same shape as input, except along the given axis. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> y = flow.tensor([[1, 2], [3, 4]]) >>> flow.repeat_interleave(y, 2) tensor([1, 1, 2, 2, 3, 3, 4, 4], dtype=oneflow.int64) >>> flow.repeat_interleave(y, 3, dim=1) tensor([[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]], dtype=oneflow.int64) >>> flow.repeat_interleave(y, flow.tensor([1, 2]), dim=0) tensor([[1, 2], [3, 4], [3, 4]], dtype=oneflow.int64) >>> flow.repeat_interleave(y, flow.tensor([1, 2]), dim=0, output_size=3) tensor([[1, 2], [3, 4], [3, 4]], dtype=oneflow.int64) .. Feature Stage of Operator [repeat_interleave]. - Maintainer List [@BBuf] - Current Stage [ ] - Alpha Stage Check List [ ] - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes] - Doc(API Doc must be provided and showed normally on the web page.)[Yes] - Functionality and its' Test [ ] - Functionality is highly compatiable with PyTorch 1.11. 
[Yes] - eager local [Yes] [@QiangX-man, @hjchen2] - forward [Yes] - backward [Yes] - gpu [Yes] - cpu [Yes] - graph local [ ] [@BBuf, @strint, @hjchen2] - forward [Yes] - backward [ ] - gpu [Yes] - cpu [Yes] - Exception Handling - Exception Message and Hint must be provided [Yes] - Beta Stage Check List [ ] - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[Yes] - Doc(Same standard as Alpha Stage)[ ] - Functionality and its Test [ ] - eager global [ ] - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - graph global [ ] - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - Performance and Scalability(Must be evaluated.)[ ] - CUDA kernel [ ] - CPU kernel [ ] - N nodes M devices [ ] - Exception Handling [ ] - Exception Message and Hint must be provided [ ] - Try your best to do Exception Recovery [ ] - Stable Stage Check List [ ] - API(Same standard as Beta Stage)[ ] - Doc(Same standard as Beta Stage)[ ] - Functionality and its Test [ ] - fp16 and AMP [ ] - NHWC [ ] - Performance and Scalability(Must be evaluated.)[ ] - Exception Handling [ ] """, ) ================================================ FILE: python/oneflow/framework/docstr/roc_auc_score.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.roc_auc_score, """ oneflow.roc_auc_score(label, pred) -> Tensor Computes the Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores. Note: Currently this implementation can only be used on CPU. Args: label (Tensor[N, 1]): True label of the samples pred (Tensor[N, 1]): Predicted probability of each sample being true Returns: Tensor[1, ]: float32 tensor of the AUC score For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> label = flow.Tensor([0, 0, 1, 1]) >>> pred = flow.Tensor([0.1, 0.4, 0.35, 0.8]) >>> score = flow.roc_auc_score(label, pred) >>> score tensor([0.7500], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/searchsorted.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.searchsorted, """ searchsorted() -> oneflow.Tensor Find the indices from the innermost dimension of sorted_sequence such that, if the corresponding values in values were inserted before the indices, the order of the corresponding innermost dimension within sorted_sequence would be preserved. Return a new tensor with the same size as values. If right is False (default), then the left boundary of sorted_sequence is closed. More formally, the returned index satisfies the following rules: ================= ========= ========================================================================== sorted_sequence right returned index satisfies ================= ========= ========================================================================== 1-D False sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i] 1-D True sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i] N-D False sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i] N-D True sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] sorted_sequence[m][n]...[l][i] ================= ========= ========================================================================== The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.searchsorted.html Args: sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). out_int32 (bool optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. Default value is False, i.e. default output data type is torch.int64. right (bool optional): if False, return the first suitable location that is found. If True, return the last such index. If no suitable index found, return 0 for non-numerical value (eg. nan, inf) or the size of innermost dimension within sorted_sequence (one pass the last index of the innermost dimension). In other words, if False, gets the lower bound index for each value in values on the corresponding innermost dimension of the sorted_sequence. If True, gets the upper bound index instead. Default value is False. For example: .. code-block:: python >>> import oneflow as flow >>> sorted_sequence = flow.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) >>> sorted_sequence tensor([[ 1, 3, 5, 7, 9], [ 2, 4, 6, 8, 10]], dtype=oneflow.int64) >>> values = flow.tensor([[3, 6, 9], [3, 6, 9]]) >>> values tensor([[3, 6, 9], [3, 6, 9]], dtype=oneflow.int64) >>> flow.searchsorted(sorted_sequence, values) tensor([[1, 3, 4], [1, 2, 4]], dtype=oneflow.int64) >>> flow.searchsorted(sorted_sequence, values, right=True) tensor([[2, 3, 5], [1, 3, 4]], dtype=oneflow.int64) >>> sorted_sequence_1d = flow.tensor([1, 3, 5, 7, 9]) >>> sorted_sequence_1d tensor([1, 3, 5, 7, 9], dtype=oneflow.int64) >>> flow.searchsorted(sorted_sequence_1d, values) tensor([[1, 3, 4], [1, 3, 4]], dtype=oneflow.int64) """, ) ================================================ FILE: python/oneflow/framework/docstr/sort.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.sort, """Sorts the elements of the input tensor along a given dimension in ascending order by value. Args: input (oneflow.Tensor): the input Tensor. dim (int, optional): the dimension to be sorted. Defaults to the last dim (-1). descending (bool, optional): controls the sorting order (ascending or descending). Returns: Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int32)): A tuple of (values, indices), where where the values are the sorted values and the indices are the indices of the elements in the original input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = np.array([[1, 3, 8, 7, 2], [1, 9, 4, 3, 2]], dtype=np.float32) >>> input = flow.Tensor(x) >>> result = flow.sort(input) >>> result.values tensor([[1., 2., 3., 7., 8.], [1., 2., 3., 4., 9.]], dtype=oneflow.float32) >>> result.indices tensor([[0, 4, 1, 3, 2], [0, 4, 3, 2, 1]], dtype=oneflow.int32) >>> result = flow.sort(input, descending=True) >>> result.values tensor([[8., 7., 3., 2., 1.], [9., 4., 3., 2., 1.]], dtype=oneflow.float32) >>> result.indices tensor([[2, 3, 1, 4, 0], [1, 2, 3, 4, 0]], dtype=oneflow.int32) >>> result = flow.sort(input, dim=0) >>> result.values tensor([[1., 3., 4., 3., 2.], [1., 9., 8., 7., 2.]], dtype=oneflow.float32) >>> result.indices tensor([[0, 0, 1, 1, 0], [1, 1, 0, 0, 1]], dtype=oneflow.int32) """, ) ================================================ FILE: python/oneflow/framework/docstr/special_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.special.digamma, """ Alias for :func:`oneflow.digamma`. """, ) add_docstr( oneflow.special.erf, """ Alias for :func:`oneflow.erf`. """, ) add_docstr( oneflow.special.erfc, """ Alias for :func:`oneflow.erfc`. """, ) add_docstr( oneflow.special.erfinv, """ Alias for :func:`oneflow.erfinv`. """, ) add_docstr( oneflow.special.exp2, """ Alias for :func:`oneflow.exp2`. """, ) add_docstr( oneflow.special.expm1, """ Alias for :func:`oneflow.expm1`. """, ) add_docstr( oneflow.special.log1p, """ Alias for :func:`oneflow.log1p`. """, ) add_docstr( oneflow.special.log_softmax, """ Alias for :func:`oneflow.nn.functional.log_softmax`. """, ) add_docstr( oneflow.special.logsumexp, """ Alias for :func:`oneflow.logsumexp`. """, ) add_docstr( oneflow.special.round, """ Alias for :func:`oneflow.round`. """, ) add_docstr( oneflow.special.softmax, """ Alias for :func:`oneflow.softmax`. 
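For example (an illustrative use of the alias on a small 1-D float tensor; the printed values follow the standard softmax definition and the output formatting used elsewhere in these docs): .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1., 2., 3.]) >>> flow.special.softmax(x, dim=-1) tensor([0.0900, 0.2447, 0.6652], dtype=oneflow.float32)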
""", ) add_docstr( oneflow.special.psi, """ Alias for :func:`oneflow.special.digamma`. """, ) add_docstr( oneflow.special.zeta, r""" zeta(input, other) -> Tensor Computes the Hurwitz zeta function, elementwise. .. math:: \zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x} Args: input (Tensor): the input tensor corresponding to `x`. other (Tensor): the input tensor corresponding to `q`. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([2., 4.]) >>> flow.special.zeta(x, 1) tensor([1.6449, 1.0823], dtype=oneflow.float32) >>> flow.special.zeta(x, flow.tensor([1., 2.])) tensor([1.6449, 0.0823], dtype=oneflow.float32) >>> flow.special.zeta(2,flow.tensor([1., 2.])) tensor([1.6449, 0.6449], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/split.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.split, """Splits the tensor into chunks. If `split_size_or_sections` is an integer type, then x will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension `dim` is not divisible by split_size. If `split_size_or_sections` is a list, then x will be split into `len(split_size_or_sections)` chunks with sizes in `dim` according to `split_size_or_sections`. Args: x: tensor to split. split_size_or_sections: size of a single chunk or list of sizes for each chunk. dim: dimension along which to split the tensor. For example: .. code-block:: python >>> import oneflow as flow >>> a = flow.arange(10).view(5, 2) >>> flow.split(a, 2) (tensor([[0, 1], [2, 3]], dtype=oneflow.int64), tensor([[4, 5], [6, 7]], dtype=oneflow.int64), tensor([[8, 9]], dtype=oneflow.int64)) >>> flow.split(a, [1, 4]) (tensor([[0, 1]], dtype=oneflow.int64), tensor([[2, 3], [4, 5], [6, 7], [8, 9]], dtype=oneflow.int64)) """, ) ================================================ FILE: python/oneflow/framework/docstr/swapaxes.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.swapaxes, """swapaxes(input, axis0, axis1) -> Tensor This function is equivalent to NumPy’s swapaxes function. For example: .. 
code-block:: python >>> import oneflow as flow >>> x = flow.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) >>> x.shape oneflow.Size([2, 2, 2]) >>> flow.swapaxes(x, 0, 1).shape oneflow.Size([2, 2, 2]) >>> flow.swapaxes(x, 0, 2).shape oneflow.Size([2, 2, 2]) """, ) ================================================ FILE: python/oneflow/framework/docstr/swapdims.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.swapdims, """ swapdims(input, dim0, dim1) -> Tensor This function is equivalent to torch’s swapdims function. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) >>> x tensor([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=oneflow.int64) >>> flow.swapdims(x, 0, 1) tensor([[[0, 1], [4, 5]], [[2, 3], [6, 7]]], dtype=oneflow.int64) >>> flow.swapdims(x, 0, 2) tensor([[[0, 4], [2, 6]], [[1, 5], [3, 7]]], dtype=oneflow.int64) """, ) ================================================ FILE: python/oneflow/framework/docstr/tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.tensor, r""" Constructs a tensor with data, return a global tensor if placement and sbp are in kwargs, otherwise return a local tensor. Arguments: data: Initial data for the tensor. Can be a list, tuple, NumPy ndarray, scalar or tensor. Keyword Arguments: dtype (oneflow.dtype, optional) – the desired data type of returned tensor. Default: if None, infers data type from data. device (oneflow.device, optional): the desired device of returned tensor. If placement and sbp is None, uses the current cpu for the default tensor type. placement (oneflow.placement, optional): the desired placement of returned tensor. sbp (oneflow.sbp or tuple of oneflow.sbp, optional): the desired sbp of returned tensor. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False pin_memory(bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False. Note: The Keyword Argument device is mutually exclusive with placement and sbp. For example: .. 
code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1,2,3]) >>> x tensor([1, 2, 3], dtype=oneflow.int64) """, ) add_docstr( oneflow.from_numpy, r""" Creates a ``Tensor`` from a ``numpy.ndarray``. The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. It currently accepts ndarray with dtypes of numpy.float64, numpy.float32, numpy.float16, numpy.int64, numpy.int32, numpy.int8, numpy.uint8. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> np_arr = np.arange(6).reshape(2, 3) >>> t = flow.from_numpy(np_arr) >>> t tensor([[0, 1, 2], [3, 4, 5]], dtype=oneflow.int64) >>> np_arr[0, 0] = -1 >>> t tensor([[-1, 1, 2], [ 3, 4, 5]], dtype=oneflow.int64) """, ) add_docstr( oneflow.Tensor.device, r""" Is the :class:`oneflow.device` where this Tensor is, which is invalid for global tensor. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.device.html. """, ) add_docstr( oneflow.Tensor.placement, r""" Is the :class:`oneflow.placement` where this Tensor is, which is invalid for local tensor. """, ) add_docstr( oneflow.Tensor.sbp, r""" Is the ``oneflow.sbp`` representing that how the data of the global tensor is distributed, which is invalid for local tensor. """, ) add_docstr( oneflow.Tensor.is_global, r""" Return whether this Tensor is a global tensor. """, ) add_docstr( oneflow.Tensor.is_lazy, r""" Return whether this Tensor is a lazy tensor. """, ) add_docstr( oneflow.Tensor.atan2, r""" See :func:`oneflow.atan2` """, ) add_docstr( oneflow.Tensor.expand, """ Tensor.expand() -> Tensor See :func:`oneflow.expand` """, ) add_docstr( oneflow.Tensor.expand_as, """ expand_as(other) -> Tensor Expand this tensor to the same size as :attr:`other`. ``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``. Please see :meth:`~Tensor.expand` for more information about ``expand``. Args: other (:class:`oneflow.Tensor`): The result tensor has the same size as :attr:`other`. """, ) add_docstr( oneflow.Tensor.flatten, """ See :func:`oneflow.flatten` """, ) add_docstr( oneflow.Tensor.floor, """ See :func:`oneflow.floor` """, ) add_docstr( oneflow.Tensor.floor_, """ See :func:`oneflow.floor_` """, ) add_docstr( oneflow.Tensor.flip, """ See :func:`oneflow.flip` """, ) add_docstr( oneflow.Tensor.in_top_k, """ Tensor.in_top_k(targets, predictions, k) -> Tensor See :func:`oneflow.in_top_k` """, ) add_docstr( oneflow.Tensor.index_select, """ Tensor.index_select(dim, index) -> Tensor See :func:`oneflow.index_select` """, ) add_docstr( oneflow.Tensor.numel, """ See :func:`oneflow.numel` """, ) add_docstr( oneflow.Tensor.offload, """ Transfer tensor data from GPU memory back to host (CPU) memory. If the tensor is already in host (CPU) memory, the operation does nothing and gives a warning. Note that this operation only changes the storage of the tensor, and the tensor id will not change. Note: Both global tensor and local tensor of oneflow are applicable to this operation. Use with :func:`oneflow.Tensor.load` and :func:`oneflow.Tensor.is_offloaded`. The behavior of load() is the opposite of offload(), is_offloaded() returns a boolean indicating whether the tensor has been moved to CPU memory. In addition, support for offloading elements of :func:`oneflow.nn.Module.parameters` is provided. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> # local tensor >>> x = flow.tensor(np.random.randn(1024, 1024, 100), dtype=flow.float32, device=flow.device("cuda"), ) >>> before_id = id(x) >>> x.offload() # Move the Tensor from the GPU to the CPU >>> after_id = id(x) >>> after_id == before_id True >>> x.is_offloaded() True >>> x.load() # Move the Tensor from the cpu to the gpu >>> x.is_offloaded() False .. code-block:: python >>> import oneflow as flow >>> # global tensor >>> # Run on 2 ranks respectively >>> placement = flow.placement("cuda", ranks=[0, 1]) >>> sbp = flow.sbp.broadcast >>> x = flow.randn(1024, 1024, 100, dtype=flow.float32, placement=placement, sbp=sbp) # doctest: +SKIP >>> before_id = id(x) # doctest: +SKIP >>> x.offload() # doctest: +SKIP >>> after_id = id(x) # doctest: +SKIP >>> print(after_id == before_id) # doctest: +SKIP >>> print(x.is_offloaded()) # doctest: +SKIP >>> x.load() # doctest: +SKIP >>> print(x.is_offloaded()) # doctest: +SKIP """, ) add_docstr( oneflow.Tensor.load, """ Load tensor data stored on the host (CPU) back to GPU memory. If the tensor is already in GPU memory, the operation does nothing and gives a warning. """, ) add_docstr( oneflow.Tensor.is_offloaded, """ Tensor.is_offloaded() -> bool Determine whether the tensor has been moved to CPU memory and the CUDA device memory has been released. """, ) add_docstr( oneflow.Tensor.new_empty, """ Tensor.new_empty(*size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a Tensor of size :attr:`size` filled with uninitialized data. By default, the returned Tensor has the same :attr:`flow.dtype` and :attr:`flow.device` as this tensor. Args: size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor. dtype (flow.dtype, optional): the desired type of returned tensor. Default: if None, same flow.dtype as this tensor. device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor. placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.ones(()) >>> y = x.new_empty((2, 2)) >>> y.shape oneflow.Size([2, 2]) """, ) add_docstr( oneflow.Tensor.new_ones, """ Tensor.new_ones() -> Tensor See :func:`oneflow.new_ones` """, ) add_docstr( oneflow.Tensor.new_zeros, """ Tensor.new_zeros(size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a Tensor of size size filled with 0. By default, the returned Tensor has the same oneflow.dtype, oneflow.device or oneflow.placement and oneflow.sbp as this tensor. Args: size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor. dtype (flow.dtype, optional): the desired type of returned tensor. Default: if None, same flow.dtype as this tensor. device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor. placement (flow.placement, optional): the desired placement of returned global tensor. 
Default: if None, the returned tensor is local one using the argument `device`. sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.Tensor(np.ones((1, 2, 3))) >>> y = x.new_zeros((2, 2)) >>> y tensor([[0., 0.], [0., 0.]], dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.new_full, """ Tensor.new_full(size, fill_value, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a Tensor of size size filled with fill_value. By default, the returned Tensor has the same oneflow.dtype, oneflow.device or oneflow.placement and oneflow.sbp as this tensor. Args: fill_value (scalar): the number to fill the output tensor with. size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor. dtype (flow.dtype, optional): the desired type of returned tensor. Default: if None, same flow.dtype as this tensor. device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor. placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> tensor = flow.ones((2,), dtype=flow.float64) >>> tensor.new_full((3, 4), 3.141592) tensor([[3.1416, 3.1416, 3.1416, 3.1416], [3.1416, 3.1416, 3.1416, 3.1416], [3.1416, 3.1416, 3.1416, 3.1416]], dtype=oneflow.float64) """, ) add_docstr( oneflow.Tensor.storage_offset, """ Tensor.storage_offset() -> Tensor Returns self tensor’s offset in the underlying storage in terms of number of storage elements (not bytes). Example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3, 4, 5]) >>> x.storage_offset() 0 """, ) add_docstr( oneflow.Tensor.local_to_global, """ Tensor.local_to_global(placement=None, sbp=None, *, check_meta=True, copy=False) -> Tensor Creates a global tensor from a local tensor. Note: This tensor must be local tensor. Both placement and sbp are required. The returned global tensor takes this tensor as its local component in the current rank. There is no data communication usually, but when sbp is ``oneflow.sbp.broadcast``, the data on rank 0 will be broadcast to other ranks. .. warning:: When the sbp is ``oneflow.sbp.broadcast``, the data on the non-0 rank will be modified. If you want to keep the input local tensor unchanged, please set the arg copy to True. Args: placement (flow.placement, optional): the desired placement of returned global tensor. Default: None sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp of returned global tensor. Default: None Keyword Args: check_meta (bool, optional): indicates whether to check meta information when createing global tensor from local tensor. 
Only can be set to False when the shape and dtype of the input local tensor on each rank are the same. If set to False, the execution of local_to_global can be accelerated. Default: True copy (bool, optional): When copy is set, the returned global tensor takes the replication of this tensor as its local component in the current rank. Default: False .. code-block:: python >>> # Run on 2 ranks respectively >>> import oneflow as flow >>> input = flow.tensor([0., 1.], dtype=flow.float32) # doctest: +SKIP >>> output = input.local_to_global(placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.split(0)], check_meta=False) # doctest: +SKIP >>> print(output.size()) # doctest: +SKIP >>> print(output) # doctest: +SKIP .. code-block:: python >>> # results on rank 0 oneflow.Size([4]) tensor([0., 1., 0., 1.], placement=oneflow.placement(type="cpu", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32) .. code-block:: python >>> # results on rank 1 oneflow.Size([4]) tensor([0., 1., 0., 1.], placement=oneflow.placement(type="cpu", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.global_to_global, """ Tensor.global_to_global(placement=None, sbp=None, *, grad_sbp=None, check_meta=False, copy=False) -> Tensor Performs Tensor placement and/or sbp conversion. Note: This tensor must be global tensor. At least one of placement and sbp is required. If placement and sbp are all the same as this tensor's own placement and sbp, then returns this tensor own. Args: placement (flow.placement, optional): the desired placement of returned global tensor. Default: None sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp of returned global tensor. Default: None Keyword Args: grad_sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): manually specify the sbp of this tensor's grad tensor in the backward pass. If None, the grad tensor sbp will be infered automatically. Default: None check_meta (bool, optional): indicates whether to check meta information. If set to True, check the consistency of the input meta information (placement and sbp) on each rank. Default: False copy (bool, optional): When copy is set, a new Tensor is created even when the Tensor already matches the desired conversion. Default: False .. code-block:: python >>> # Run on 2 ranks respectively >>> import oneflow as flow >>> input = flow.tensor([0., 1.], dtype=flow.float32, placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.broadcast]) # doctest: +SKIP >>> output = input.global_to_global(placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.split(0)]) # doctest: +SKIP >>> print(output.size()) # doctest: +SKIP >>> print(output) # doctest: +SKIP .. code-block:: python >>> # results on rank 0 oneflow.Size([2]) tensor([0., 1.], placement=oneflow.placement(type="cpu", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32) .. code-block:: python >>> # results on rank 1 oneflow.Size([2]) tensor([0., 1.], placement=oneflow.placement(type="cpu", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.to_global, """ Tensor.to_global(placement=None, sbp=None, **kwargs) -> Tensor Creates a global tensor if this tensor is a local tensor, otherwise performs Tensor placement and/or sbp conversion. Note: This tensor can be local tensor or global tensor. - For local tensor Both placement and sbp are required. 
The returned global tensor takes this tensor as its local component in the current rank. There is no data communication usually, but when sbp is ``oneflow.sbp.broadcast``, the data on rank 0 will be broadcast to other ranks. - For global tensor At least one of placement and sbp is required. If placement and sbp are all the same as this tensor's own placement and sbp, then returns this tensor own. .. warning:: When the input tensor is a local tensor and sbp is ``oneflow.sbp.broadcast``, the data on the non-0 rank will be modified. If you want to keep the input local tensor unchanged, please set the arg copy to True. Args: placement (flow.placement, optional): the desired placement of returned global tensor. Default: None sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp of returned global tensor. Default: None Keyword Args: grad_sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): manually specify the sbp of this tensor's grad tensor in the backward pass. If None, the grad tensor sbp will be infered automatically. It is only used if this tensor is a global tensor. Default: None check_meta (bool, optional): indicates whether to check meta information. If set to True, check the input meta information on each rank. Default: True if this tensor is a local tensor, False if this tensor is a global tensor copy (bool, optional): When copy is set, copy occurres in this operation. For local tensor, the returned global tensor takes the replication of this tensor as its local component in the current rank. For global tensor, a new Tensor is created even when the Tensor already matches the desired conversion. Default: False For local tensor: .. code-block:: python >>> # Run on 2 ranks respectively >>> import oneflow as flow >>> input = flow.tensor([0., 1.], dtype=flow.float32) # doctest: +SKIP >>> output = input.to_global(placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.split(0)], check_meta=False) # doctest: +SKIP >>> print(output.size()) # doctest: +SKIP >>> print(output) # doctest: +SKIP .. code-block:: python >>> # results on rank 0 oneflow.Size([4]) tensor([0., 1., 0., 1.], placement=oneflow.placement(type="cpu", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32) .. code-block:: python >>> # results on rank 1 oneflow.Size([4]) tensor([0., 1., 0., 1.], placement=oneflow.placement(type="cpu", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32) For global tensor: .. code-block:: python >>> # Run on 2 ranks respectively >>> import oneflow as flow >>> input = flow.tensor([0., 1.], dtype=flow.float32, placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.broadcast]) # doctest: +SKIP >>> output = input.to_global(placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.split(0)]) # doctest: +SKIP >>> print(output.size()) # doctest: +SKIP >>> print(output) # doctest: +SKIP .. code-block:: python >>> # results on rank 0 oneflow.Size([2]) tensor([0., 1.], placement=oneflow.placement(type="cpu", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32) .. code-block:: python >>> # results on rank 1 oneflow.Size([2]) tensor([0., 1.], placement=oneflow.placement(type="cpu", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.to_consistent, """ This interface is no longer available, please use :func:`oneflow.Tensor.to_global` instead. 
""", ) add_docstr( oneflow.Tensor.to_local, """ Tensor.to_local(**kwargs) -> Tensor Returns the local component of this global tensor in the current rank. Keyword Args: copy (bool, optional): When copy is set, a new replicated tensor of the local component of this global tensor in the current rank is returned. Default: False Note: This tensor should be a global tensor, and it returns a empty tensor if there is no local component in the current rank. No copy occurred in this operation if copy is not set. For example: .. code-block:: python >>> # Run on 2 ranks respectively >>> import oneflow as flow >>> x = flow.tensor([0., 1.], dtype=flow.float32, placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.split(0)]) # doctest: +SKIP >>> y = x.to_local() # doctest: +SKIP >>> print(y.size()) # doctest: +SKIP >>> print(y) # doctest: +SKIP .. code-block:: python >>> # results on rank 0 oneflow.Size([1]) tensor([0.], dtype=oneflow.float32) .. code-block:: python >>> # results on rank 1 oneflow.Size([1]) tensor([1.], dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.transpose, """ See :func:`oneflow.transpose` """, ) add_docstr( oneflow.Tensor.logical_not, """ logical_not() -> Tensor See :func:`oneflow.logical_not` """, ) add_docstr( oneflow.Tensor.lerp, """ See :func:`oneflow.lerp` """, ) add_docstr( oneflow.Tensor.lerp_, """ See :func:`oneflow.lerp_` """, ) add_docstr( oneflow.Tensor.quantile, """ See :func:`oneflow.quantile` """, ) add_docstr( oneflow.Tensor.sqrt, """ See :func:`oneflow.sqrt` """, ) add_docstr( oneflow.Tensor.square, """ See :func:`oneflow.square` """, ) add_docstr( oneflow.Tensor.std, """ See :func:`oneflow.std` """, ) add_docstr( oneflow.Tensor.var, """ See :func:`oneflow.var` """, ) add_docstr( oneflow.Tensor.squeeze, """ Tensor.squeeze(dim=None) -> Tensor See :func:`oneflow.squeeze` """, ) add_docstr( oneflow.Tensor.squeeze_, """ Tensor.squeeze_(dim=None) -> Tensor In-place version of :func:`oneflow.Tensor.squeeze` """, ) add_docstr( oneflow.Tensor.unfold, """ Returns a view of the original tensor which contains all slices of `size` size from `self` tensor in the dimension `dimension`. Step between two slices is given by `step`. If sizedim is the size of dimension `dimension` for `self`, the size of dimension dimension in the returned tensor will be (sizedim - size) / step + 1. An additional dimension of size `size` is appended in the returned tensor. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.unfold.html. Args: dimension (int): dimension in which unfolding happens size (int): the size of each slice that is unfolded step (int): the step between each slice For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.arange(1, 8) >>> x tensor([1, 2, 3, 4, 5, 6, 7], dtype=oneflow.int64) >>> x.unfold(0, 2, 1) tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]], dtype=oneflow.int64) >>> x.unfold(0, 2, 2) tensor([[1, 2], [3, 4], [5, 6]], dtype=oneflow.int64) """, ) add_docstr( oneflow.Tensor.matmul, """ See :func:`oneflow.matmul` """, ) add_docstr( oneflow.Tensor.mv, """ See :func:`oneflow.mv` """, ) add_docstr( oneflow.Tensor.mm, """ See :func:`oneflow.mm` """, ) add_docstr( oneflow.Tensor.narrow, """ See :func:`oneflow.narrow` """, ) add_docstr( oneflow.Tensor.unsqueeze, """ Tensor.unsqueeze(dim) -> Tensor See :func:`oneflow.unsqueeze` """, ) add_docstr( oneflow.Tensor.unsqueeze_, """ Tensor.unsqueeze_(dim) -> Tensor In-place version of :func:`oneflow.Tensor.unsqueeze` """, ) add_docstr( oneflow.Tensor.as_strided, """ Tensor.as_strided(size, stride, storage_offset=None) -> Tensor See :func:`oneflow.as_strided` """, ) add_docstr( oneflow.Tensor.as_strided_, """ Tensor.as_strided_(size, stride, storage_offset=None) -> Tensor In-place version of :func:`oneflow.Tensor.as_strided` """, ) add_docstr( oneflow.Tensor.permute, """ See :func:`oneflow.permute` """, ) add_docstr( oneflow.Tensor.abs, """ See :func:`oneflow.abs` """, ) add_docstr( oneflow.Tensor.acos, """ See :func:`oneflow.acos` """, ) add_docstr( oneflow.Tensor.arccos, """ See :func:`oneflow.arccos` """, ) add_docstr( oneflow.Tensor.acosh, """ See :func:`oneflow.acosh` """, ) add_docstr( oneflow.Tensor.arccosh, """ See :func:`oneflow.arccosh` """, ) add_docstr( oneflow.Tensor.arctanh, """ See :func:`oneflow.arctanh` """, ) add_docstr( oneflow.Tensor.argmax, """ See :func:`oneflow.argmax` """, ) add_docstr( oneflow.Tensor.argmin, """ See :func:`oneflow.argmin` """, ) add_docstr( oneflow.Tensor.argsort, """ See :func:`oneflow.argsort` """, ) add_docstr( oneflow.Tensor.argwhere, """ See :func:`oneflow.argwhere` """, ) add_docstr( oneflow.Tensor.atanh, """ See :func:`oneflow.atanh` """, ) add_docstr( oneflow.Tensor.backward, """ Computes the gradient of current tensor `w.r.t.` graph leaves. The graph is differentiated using the chain rule. If the tensor is non-scalar (i.e. its data has more than one element) and requires gradient, the function additionally requires specifying gradient. It should be a tensor of matching type and location, that contains the gradient of the differentiated function w.r.t. self. This function accumulates gradients in the leaves - you might need to zero .grad attributes or set them to None before calling it. See Default gradient layouts for details on the memory layout of accumulated gradients. Note: If you run any forward ops, create gradient, and/or call backward in a user-specified CUDA stream context, see Stream semantics of backward passes. Note: When inputs are provided and a given input is not a leaf, the current implementation will call its grad_fn (though it is not strictly needed to get this gradients). It is an implementation detail on which the user should not rely. See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.backward.html. Args: gradient (Tensor or None): Gradient w.r.t. the tensor. If it is a tensor, it will be automatically converted to a Tensor that does not require grad unless create_graph is True. 
None values can be specified for scalar Tensors or ones that don’t require grad. If a None value would be acceptable then this argument is optional. retain_graph (bool, optional): If False, the graph used to compute the grads will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph. create_graph (bool, optional): If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False. """, ) add_docstr( oneflow.Tensor.grad, r""" Return the gradient calculated by autograd functions. This property is None by default. """, ) add_docstr( oneflow.Tensor.grad_fn, r""" Return the function that created this tensor if it's ``requires_grad`` is True. """, ) add_docstr( oneflow.Tensor.inverse, """ See :func:`oneflow.linalg.inv` """, ) add_docstr( oneflow.Tensor.trunc, """ See :func:`oneflow.trunc` """, ) add_docstr( oneflow.Tensor.is_leaf, r""" All Tensors that have ``requires_grad`` which is ``False`` will be leaf Tensors by convention. For Tensor that have ``requires_grad`` which is ``True``, they will be leaf Tensors if they were created by source operations. Only leaf Tensors will have their ``grad`` populated during a call to ``backward()``. To get ``grad`` populated for non-leaf Tensors, you can use ``retain_grad()``. Compatible with PyTorch. For example: .. code-block:: python >>> import oneflow as flow >>> a = flow.rand(10, requires_grad=False) >>> a.is_leaf True >>> a = flow.rand(10, requires_grad=True) >>> a.is_leaf True >>> b = a.cuda() >>> b.is_leaf False >>> c = a + 2 >>> c.is_leaf False """, ) add_docstr( oneflow.Tensor.requires_grad, r""" Is ``True`` if gradient need to be computed for this Tensor, ``False`` otherwise. Compatible with PyTorch. """, ) add_docstr( oneflow.Tensor.requires_grad_, r"""oneflow.Tensor.requires_grad_(requires_grad=True) -> Tensor Sets this tensor’s requires_grad attribute in-place. Returns this tensor. Compatible with PyTorch. Args: requires_grad (bool): Change the requires_grad flag for this Tensor. Default is ``True``. For example: .. code-block:: python >>> import oneflow as flow >>> a = flow.rand(10, requires_grad=False) >>> a.requires_grad False >>> a = a.requires_grad_(requires_grad=True) >>> a.requires_grad True """, ) add_docstr( oneflow.Tensor.register_hook, r"""oneflow.Tensor.register_hook(hook) Registers a backward hook. The hook will be called every time a gradient with respect to the Tensor is computed. The hook should have the following signature: .. code-block:: hook(grad) -> Tensor or None The hook should not modify its argument, but it can optionally return a new gradient which will be used in place of ``grad``. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.ones(5, requires_grad=True) >>> def hook(grad): ... return grad * 2 >>> x.register_hook(hook) >>> y = x * 2 >>> y.sum().backward() >>> x.grad tensor([4., 4., 4., 4., 4.], dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.retain_grad, r""" Enables this Tensor to have their ``grad`` populated during ``backward()``. This is a no-op for leaf tensors. Compatible with PyTorch. 
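For example (a minimal sketch showing ``grad`` being populated for a non-leaf tensor after calling ``retain_grad()``): .. code-block:: python >>> import oneflow as flow >>> x = flow.ones(3, requires_grad=True) >>> y = x * 2 >>> y.retain_grad() >>> y.sum().backward() >>> y.grad tensor([1., 1., 1.], dtype=oneflow.float32)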
""", ) add_docstr( oneflow.Tensor.bmm, """ See :func:`oneflow.bmm` """, ) add_docstr( oneflow.Tensor.chunk, """ See :func:`oneflow.chunk` """, ) add_docstr( oneflow.Tensor.split, """ See :func:`oneflow.split` """, ) add_docstr( oneflow.Tensor.unbind, """ See :func:`oneflow.unbind` """, ) add_docstr( oneflow.Tensor.swapaxes, """ See :func:`oneflow.swapaxes` """, ) add_docstr( oneflow.Tensor.amax, """ See :func:`oneflow.amax` """, ) add_docstr( oneflow.Tensor.swapdims, """ See :func:`oneflow.swapdims` """, ) add_docstr( oneflow.Tensor.cast, """ See :func:`oneflow.cast` """, ) add_docstr( oneflow.Tensor.diag, """ See :func:`oneflow.diag` """, ) add_docstr( oneflow.Tensor.addcdiv, """ See :func:`oneflow.addcdiv` """, ) add_docstr( oneflow.Tensor.addcdiv_, """ In-place version of :func:`oneflow.Tensor.addcdiv` """, ) add_docstr( oneflow.Tensor.dim, """ Tensor.dim() → int Returns the number of dimensions of self tensor. """, ) add_docstr( oneflow.Tensor.element_size, """ Tensor.element_size() → int Returns the size in bytes of an individual element. """, ) add_docstr( oneflow.Tensor.exp, """ See :func:`oneflow.exp` """, ) add_docstr( oneflow.Tensor.exp2, """ See :func:`oneflow.exp2` """, ) add_docstr( oneflow.Tensor.erf, """ Tensor.erf() -> Tensor See :func:`oneflow.erf` """, ) add_docstr( oneflow.Tensor.erfc, """ Tensor.erfc() -> Tensor See :func:`oneflow.erfc` """, ) add_docstr( oneflow.Tensor.erfinv, """ See :func:`oneflow.erfinv` """, ) add_docstr( oneflow.Tensor.erfinv_, """ Inplace version of :func:`oneflow.erfinv` """, ) add_docstr( oneflow.Tensor.eq, """ See :func:`oneflow.eq` """, ) add_docstr( oneflow.Tensor.equal, """ See :func:`oneflow.equal` """, ) add_docstr( oneflow.Tensor.lt, """ See :func:`oneflow.lt` """, ) add_docstr( oneflow.Tensor.le, """ See :func:`oneflow.le` """, ) add_docstr( oneflow.Tensor.ne, """ See :func:`oneflow.ne` """, ) add_docstr( oneflow.Tensor.neg, """ See :func:`oneflow.neg` """, ) add_docstr( oneflow.Tensor.norm, """ See :func:`oneflow.norm` """, ) add_docstr( oneflow.Tensor.fill_, """ Tensor.fill_(value) → Tensor Fills `self` tensor with the specified value. """, ) add_docstr( oneflow.Tensor.ge, """ See :func:`oneflow.ge` """, ) add_docstr( oneflow.Tensor.get_device, """ Tensor.get_device() -> Device ordinal (Integer) For CUDA tensors, this function returns the device ordinal of the GPU on which the tensor resides. For CPU tensors, an error is thrown. """, ) add_docstr( oneflow.Tensor.gt, """ See :func:`oneflow.gt` """, ) add_docstr( oneflow.Tensor.gt_, """Tensor.gt_(value) -> Tensor In-place version of :func:`oneflow.Tensor.gt`. """, ) add_docstr( oneflow.Tensor.log1p, """ See :func:`oneflow.log1p` """, ) add_docstr( oneflow.Tensor.mish, """ See :func:`oneflow.mish` """, ) add_docstr( oneflow.Tensor.mul, """Tensor.mul(value) -> Tensor See :func:`oneflow.mul` """, ) add_docstr( oneflow.Tensor.mul_, """Tensor.mul_(value) -> Tensor In-place version of :func:`oneflow.Tensor.mul`. """, ) add_docstr( oneflow.Tensor.div_, """Tensor.div_(value) -> Tensor In-place version of :func:`oneflow.Tensor.div`. """, ) add_docstr( oneflow.Tensor.sub_, """Tensor.sub_(value) -> Tensor In-place version of :func:`oneflow.Tensor.sub`. 
""", ) add_docstr( oneflow.Tensor.negative, """ See :func:`oneflow.negative` """, ) add_docstr( oneflow.Tensor.nelement, """ Tensor.nelement() → int Alias for numel() """, ) add_docstr( oneflow.Tensor.normal_, """ normal_(mean=0, std=1, *, generator=None) -> Tensor Fills :attr:`self` tensor with elements samples from the normal distribution parameterized by :attr:`mean` and :attr:`std`. """, ) add_docstr( oneflow.Tensor.numpy, """ Tensor.numpy() → numpy.ndarray Returns self tensor as a NumPy ndarray. This tensor and the returned ndarray share the same underlying storage. Changes to self tensor will be reflected in the ndarray and vice versa. """, ) add_docstr( oneflow.Tensor.pow, """ See :func:`oneflow.pow` """, ) add_docstr( oneflow.Tensor.relu, """ See :func:`oneflow.relu` """, ) add_docstr( oneflow.Tensor.roll, """ See :func:`oneflow.roll` """, ) add_docstr( oneflow.Tensor.round, """ See :func:`oneflow.round` """, ) add_docstr( oneflow.Tensor.round_, """ See :func:`oneflow.round_` """, ) add_docstr( oneflow.Tensor.reciprocal, """ See :func:`oneflow.reciprocal` """, ) add_docstr( oneflow.Tensor.add, """ See :func:`oneflow.add` """, ) add_docstr( oneflow.Tensor.addmm, """ See :func:`oneflow.addmm` """, ) add_docstr( oneflow.Tensor.add_, """ In-place version of :func:`oneflow.Tensor.add`. """, ) add_docstr( oneflow.Tensor.addcmul, """ See :func:`oneflow.addcmul` """, ) add_docstr( oneflow.Tensor.addcmul_, """ In-place version of :func:`oneflow.Tensor.addcmul`. """, ) add_docstr( oneflow.Tensor.asin, """ See :func:`oneflow.asin` """, ) add_docstr( oneflow.Tensor.asinh, """ See :func:`oneflow.asinh` """, ) add_docstr( oneflow.Tensor.arcsin, """ See :func:`oneflow.arcsin` """, ) add_docstr( oneflow.Tensor.arcsinh, """ See :func:`oneflow.arcsinh` """, ) add_docstr( oneflow.Tensor.sin, """ sin() -> Tensor See :func:`oneflow.sin` """, ) add_docstr( oneflow.Tensor.sin_, """ See :func:`oneflow.sin_` """, ) add_docstr( oneflow.Tensor.cos, """ See :func:`oneflow.cos` """, ) add_docstr( oneflow.Tensor.diagonal, """ See :func:`oneflow.diagonal` """, ) add_docstr( oneflow.Tensor.log, """ See :func:`oneflow.log` """, ) add_docstr( oneflow.Tensor.log2, """ See :func:`oneflow.log2` """, ) add_docstr( oneflow.Tensor.log10, """ See :func:`oneflow.log10` """, ) add_docstr( oneflow.Tensor.ndim, """ See :func:`oneflow.Tensor.dim` """, ) add_docstr( oneflow.Tensor.rsqrt, """ See :func:`oneflow.rsqrt` """, ) add_docstr( oneflow.Tensor.cosh, """ See :func:`oneflow.cosh` """, ) add_docstr( oneflow.Tensor.atan, """ See :func:`oneflow.atan` """, ) add_docstr( oneflow.Tensor.arctan, """ See :func:`oneflow.arctan` """, ) add_docstr( oneflow.Tensor.dot, """ See :func:`oneflow.dot` """, ) add_docstr( oneflow.Tensor.selu, """ See :func:`oneflow.selu` """, ) add_docstr( oneflow.Tensor.sigmoid, """ See :func:`oneflow.sigmoid` """, ) add_docstr( oneflow.Tensor.sign, """ See :func:`oneflow.sign` """, ) add_docstr( oneflow.Tensor.silu, """ See :func:`oneflow.silu` """, ) add_docstr( oneflow.Tensor.sinh, """ See :func:`oneflow.sinh` """, ) add_docstr( oneflow.Tensor.size, """ Returns the size of the self tensor. If dim is not specified, the returned value is a oneflow.Size, a subclass of tuple. If dim is specified, returns an int holding the size of that dimension. The interface is consistent with PyTorch. Args: idx (int, optional): The dimension for which to retrieve the size. 
""", ) add_docstr( oneflow.Tensor.softmax, """ See :func:`oneflow.softmax` """, ) add_docstr( oneflow.Tensor.softplus, """ See :func:`oneflow.softplus` """, ) add_docstr( oneflow.Tensor.softsign, """ See :func:`oneflow.softsign` """, ) add_docstr( oneflow.Tensor.tan, """ See :func:`oneflow.tan` """, ) add_docstr( oneflow.Tensor.tanh, """ See :func:`oneflow.tanh` """, ) add_docstr( oneflow.Tensor.tril, """ See :func:`oneflow.tril` """, ) add_docstr( oneflow.Tensor.triu, """ See :func:`oneflow.triu` """, ) add_docstr( oneflow.Tensor.uniform_, """ Tensor.uniform_(from=0, to=1) → Tensor Fills self tensor with numbers sampled from the continuous uniform distribution: .. math:: P(x)=1/(to-from) """, ) add_docstr( oneflow.Tensor.copy_, """ Copies the elements from src into self tensor and returns self. The src tensor must be broadcastable with the self tensor. It may be of a different data type or reside on a different device. The interface is consistent with PyTorch. Args: src (Tensor): the source tensor to copy from non_blocking (bool): if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. """, ) add_docstr( oneflow.Tensor.to, """Performs Tensor dtype and/or device conversion. A flow.dtype and flow.device are inferred from the arguments of `input.to(*args, **kwargs)`. .. note:: If the ``input`` Tensor already has the correct :class:`flow.dtype` and :class:`flow.device`, then ``input`` is returned. Otherwise, the returned tensor is a copy of ``input`` with the desired. Args: input (oneflow.Tensor): An input tensor. *args (oneflow.Tensor or oneflow.device or oneflow.dtype): Positional arguments **kwargs (oneflow.device or oneflow.dtype) : Key-value arguments Returns: oneflow.Tensor: A Tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> arr = np.random.randint(1, 9, size=(1, 2, 3, 4)) >>> input = flow.Tensor(arr) >>> output = input.to(dtype=flow.float32) >>> np.array_equal(arr.astype(np.float32), output.numpy()) True """, ) add_docstr( oneflow.Tensor.half, """ self.half() is equivalent to self.to(dtype=oneflow.float16). See :func:`oneflow.Tensor.to` """, ) add_docstr( oneflow.Tensor.gather, """ oneflow.Tensor.gather(dim, index) -> Tensor See :func:`oneflow.gather` """, ) add_docstr( oneflow.Tensor.clamp, """ See :func:`oneflow.clamp`. """, ) add_docstr( oneflow.Tensor.clamp_, """ Inplace version of :func:`oneflow.Tensor.clamp`. """, ) add_docstr( oneflow.Tensor.clip, """ Alias for :func:`oneflow.Tensor.clamp`. """, ) add_docstr( oneflow.Tensor.clip_, """ Alias for :func:`oneflow.Tensor.clamp_`. """, ) add_docstr( oneflow.Tensor.cpu, r"""Returns a copy of this object in CPU memory. If this object is already in CPU memory and on the correct device, then no copy is performed and the original object is returned. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.tensor([1, 2, 3, 4, 5], device=flow.device("cuda")) >>> output = input.cpu() >>> output.device device(type='cpu', index=0) """, ) add_docstr( oneflow.Tensor.cuda, r"""Returns a copy of this object in CUDA memory. If this object is already in CUDA memory and on the correct device, then no copy is performed and the original object is returned. Args: device (flow.device): The destination GPU device. Defaults to the current CUDA device. For example: .. 
code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([1, 2, 3, 4, 5]) >>> output = input.cuda() >>> output.device device(type='cuda', index=0) """, ) add_docstr( oneflow.Tensor.cumprod, """ See :func:`oneflow.cumprod` """, ) add_docstr( oneflow.Tensor.cumsum, """ See :func:`oneflow.cumsum` """, ) add_docstr( oneflow.Tensor.repeat, """ Tensor.repeat(*size) -> Tensor See :func:`oneflow.repeat` """, ) add_docstr( oneflow.Tensor.repeat_interleave, """ Tensor.repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor See :func:`oneflow.repeat_interleave` """, ) add_docstr( oneflow.Tensor.t, """ See :func:`oneflow.t` Tensor.t() → Tensor """, ) add_docstr( oneflow.Tensor.tile, """ Tensor.tile(*dims) -> Tensor See :func:`oneflow.tile` """, ) add_docstr( oneflow.Tensor.T, """ Is this Tensor with its dimensions reversed. If `n` is the number of dimensions in `x`, `x.T` is equivalent to `x.permute(n-1, n-2, ..., 0)`. """, ) add_docstr( oneflow.Tensor.fmod, """ Tensor.fmod(other) -> Tensor See :func:`oneflow.fmod` """, ) add_docstr( oneflow.Tensor.logical_and, """ logical_and() -> Tensor See :func:`oneflow.logical_and` """, ) add_docstr( oneflow.Tensor.logical_or, """ logical_or() -> Tensor See :func:`oneflow.logical_or` """, ) add_docstr( oneflow.Tensor.logical_xor, """ logical_xor() -> Tensor See :func:`oneflow.logical_xor` """, ) add_docstr( oneflow.Tensor.logsumexp, """ See :func:`oneflow.logsumexp` """, ) add_docstr( oneflow.Tensor.masked_fill, """ See :func:`oneflow.masked_fill` """, ) add_docstr( oneflow.Tensor.masked_fill_, """ In-place version of :meth:`oneflow.Tensor.masked_fill`. """, ) add_docstr( oneflow.Tensor.masked_select, """ See :func:`oneflow.masked_select` """, ) add_docstr( oneflow.Tensor.sub, """ See :func:`oneflow.sub` """, ) add_docstr( oneflow.Tensor.div, """ See :func:`oneflow.div` """, ) add_docstr( oneflow.Tensor.ceil, """ See :func:`oneflow.ceil` """, ) add_docstr( oneflow.Tensor.ceil_, """ See :func:`oneflow.ceil_` """, ) add_docstr( oneflow.Tensor.expm1, """ See :func:`oneflow.expm1` """, ) add_docstr( oneflow.Tensor.topk, """ See :func:`oneflow.topk` """, ) add_docstr( oneflow.Tensor.nms, """ See :func:`oneflow.nms` """, ) add_docstr( oneflow.Tensor.nonzero, """ nonzero(input, as_tuple=False) -> Tensor See :func:`oneflow.nonzero` """, ) add_docstr( oneflow.Tensor.max, """ input.max(dim, index) -> Tensor See :func:`oneflow.max` """, ) add_docstr( oneflow.Tensor.min, """ input.min(dim, index) -> Tensor See :func:`oneflow.min` """, ) add_docstr( oneflow.Tensor.maximum, """ See :func:`oneflow.maximum` """, ) add_docstr( oneflow.Tensor.median, """ See :func:`oneflow.median` """, ) add_docstr( oneflow.Tensor.minimum, """ See :func:`oneflow.minimum` """, ) add_docstr( oneflow.Tensor.mode, """ See :func:`oneflow.mode` """, ) add_docstr( oneflow.Tensor.sum, """ input.sum(dim=None, keepdim=False) -> Tensor See :func:`oneflow.sum` """, ) add_docstr( oneflow.Tensor.all, """ input.all(dim=None, keepdim=False) -> Tensor See :func:`oneflow.all` """, ) add_docstr( oneflow.Tensor.any, """ input.any(dim=None, keepdim=False) -> Tensor See :func:`oneflow.any` """, ) add_docstr( oneflow.Tensor.mean, """ input.mean(dim=None, keepdim=False) -> Tensor See :func:`oneflow.mean` """, ) add_docstr( oneflow.Tensor.prod, """ input.prod(dim=None, keepdim=False) -> Tensor See :func:`oneflow.prod` """, ) add_docstr( oneflow.Tensor.reshape, """ See :func:`oneflow.reshape` """, ) add_docstr( oneflow.Tensor.reshape_as, """ Tensor.reshape_as(other) -> Tensor Returns this tensor 
as the same shape as other. self.reshape_as(other) is equivalent to self.reshape(other.sizes()). This method returns a view if other.sizes() is compatible with the current shape. See :func:`oneflow.Tensor.view` on when it is possible to return a view. Please see reshape() for more information about reshape. See :func:`oneflow.reshape` Parameters other (oneflow.Tensor) – The result tensor has the same shape as other. """, ) add_docstr( oneflow.Tensor.view, """ Returns a new tensor with the same data as the :attr:`self` tensor but of a different :attr:`shape`. The returned tensor shares the same data and must have the same number of elements, but may have a different size. For a tensor to be viewed, the new view size must be compatible with its original size and stride, i.e., each new view dimension must either be a subspace of an original dimension, or only span across original dimensions :math:`d, d+1, \\dots, d+k` that satisfy the following contiguity-like condition that :math:`\\forall i = d, \\dots, d+k-1`, .. math:: \\text{stride}[i] = \\text{stride}[i+1] \\times \\text{size}[i+1] Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which returns a view if the shapes are compatible, and copies (equivalent to calling :meth:`contiguous`) otherwise. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.view.html. Args: input: A Tensor. *shape: flow.Size or int... Returns: A Tensor has the same type as `input`. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array( ... [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ... ).astype(np.float32) >>> input = flow.Tensor(x) >>> y = input.view(2, 2, 2, -1).numpy().shape >>> y (2, 2, 2, 2) """, ) add_docstr( oneflow.Tensor.view_as, """ Tensor.view_as(other) -> Tensor Expand this tensor to the same size as :attr:`other`. ``self.view_as(other)`` is equivalent to ``self.view(other.size())``. Please see :meth:`~Tensor.view` for more information about ``view``. Args: other (:class:`oneflow.Tensor`): The result tensor has the same size as :attr:`other`. """, ) add_docstr( oneflow.Tensor.sort, """ See :func:`oneflow.sort` """, ) add_docstr( oneflow.Tensor.type_as, r"""Returns this tensor cast to the type of the given tensor. This is a no-op if the tensor is already of the correct type. Args: input (Tensor): the input tensor. target (Tensor): the tensor which has the desired type. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.float32) >>> target = flow.tensor(np.random.randn(4, 5, 6), dtype = flow.int32) >>> input = input.type_as(target) >>> input.dtype oneflow.int32 """, ) add_docstr( oneflow.Tensor.bool, r"""``Tensor.bool()`` is equivalent to ``Tensor.to(oneflow.bool)``. See :class:`oneflow.Tensor.to()`. Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.float32) >>> input = input.bool() >>> input.dtype oneflow.bool """, ) add_docstr( oneflow.Tensor.int, r"""``Tensor.int()`` is equivalent to ``Tensor.to(flow.int32)``. See :class:`oneflow.Tensor.to()`. Args: input (Tensor): the input tensor. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.float32) >>> input = input.int() >>> input.dtype oneflow.int32 """, ) add_docstr( oneflow.Tensor.long, r"""``Tensor.long()`` is equivalent to ``Tensor.to(flow.int64)``. See :class:`oneflow.Tensor.to()`. Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.float32) >>> input = input.long() >>> input.dtype oneflow.int64 """, ) add_docstr( oneflow.Tensor.float, r"""``Tensor.float()`` is equivalent to ``Tensor.to(flow.float32)``. See :class:`oneflow.Tensor.to()`. Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.int) >>> input = input.float() >>> input.dtype oneflow.float32 """, ) add_docstr( oneflow.Tensor.double, r"""``Tensor.double()`` is equivalent to ``Tensor.to(flow.float64)``. See :class:`oneflow.Tensor.to()`. Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.int) >>> input = input.double() >>> input.dtype oneflow.float64 """, ) add_docstr( oneflow.Tensor.is_contiguous, r""" Tensor.is_contiguous() -> bool Returns True if `self` tensor is contiguous in memory. """, ) add_docstr( oneflow.Tensor.is_cuda, r""" Tensor.is_cuda() -> bool Is `True` if the Tensor is stored on the GPU, `False` otherwise. """, ) add_docstr( oneflow.Tensor.is_floating_point, """ See :func:`oneflow.is_floating_point` """, ) add_docstr( oneflow.Tensor.item, r"""Returns the value of this tensor as a standard Python number. This only works for tensors with one element. For other cases, see tolist(). This operation is not differentiable. Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1.0]) >>> x.item() 1.0 """, ) add_docstr( oneflow.Tensor.tolist, r"""Returns the tensor as a (nested) list. For scalars, a standard Python number is returned, just like with `item()`. Tensors are automatically moved to the CPU first if necessary. This operation is not differentiable. Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.tensor([[1,2,3], [4,5,6]]) >>> input.tolist() [[1, 2, 3], [4, 5, 6]] """, ) add_docstr( oneflow.Tensor.where, """ See :func:`oneflow.where` """, ) add_docstr( oneflow.Tensor.zero_, r""" Tensor.zero_() -> Tensor Fills `self` tensor with zeros. """, ) add_docstr( oneflow.Tensor.isnan, """ See :func:`oneflow.isnan` """, ) add_docstr( oneflow.Tensor.isinf, """ See :func:`oneflow.isinf` """, ) add_docstr( oneflow.Tensor.byte, """ self.byte() is equivalent to self.to(oneflow.uint8). See :func:`oneflow.Tensor.to` """, ) add_docstr( oneflow.Tensor.amin, """ See :func:`oneflow.amin` """, ) add_docstr( oneflow.Tensor.pin_memory, r""" Tensor.pin_memory() -> Tensor Copies the tensor to pinned memory, if it’s not already pinned. """, ) add_docstr( oneflow.Tensor.is_pinned, r""" Tensor.is_pinned() -> bool Returns true if this tensor resides in pinned memory. """, ) add_docstr( oneflow.Tensor.type, r""" type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor Returns the type if dtype is not provided, else casts this object to the specified type. 
If this is already of the correct type, no copy is performed and the original object is returned. Args: dtype (oneflow.dtype or oneflow.tensortype or string, optional): The desired type. non_blocking (bool): (**Not Implemented yet**) If True, and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect. For example: .. code-block:: python >>> import oneflow as flow >>> a = flow.tensor([1, 2], dtype=flow.float32) >>> a.type() 'oneflow.FloatTensor' >>> a.type(flow.int8) # dtype input tensor([1, 2], dtype=oneflow.int8) >>> a.type(flow.cuda.DoubleTensor) # tensortype input tensor([1., 2.], device='cuda:0', dtype=oneflow.float64) >>> a.type("oneflow.HalfTensor") # string input tensor([1., 2.], dtype=oneflow.float16) """, ) add_docstr( oneflow.Tensor.scatter, """ See :func:`oneflow.scatter` """, ) add_docstr( oneflow.Tensor.scatter_, """ Inplace version of :func:`oneflow.Tensor.scatter` """, ) add_docstr( oneflow.Tensor.scatter_add, """ See :func:`oneflow.scatter_add` """, ) add_docstr( oneflow.Tensor.scatter_add_, """ Inplace version of :func:`oneflow.Tensor.scatter_add` """, ) add_docstr( oneflow.Tensor.cross, """ See :func:`oneflow.cross` """, ) add_docstr( oneflow.Tensor.nansum, """ See :func:`oneflow.nansum` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1., 2., float("nan")]) >>> x.nansum() tensor(3., dtype=oneflow.float32) >>> x = flow.tensor([[1., float("nan")], [float("nan"), 2]]) >>> x.nansum(dim=1, keepdim=True) tensor([[1.], [2.]], dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.bincount, """ See :func:`oneflow.bincount` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.Tensor([0, 2, 3]).int() >>> x.bincount() tensor([1, 0, 1, 1], dtype=oneflow.int64) >>> weight = flow.Tensor([0.1, 0.2, 0.3]) >>> x.bincount(weight) tensor([0.1000, 0.0000, 0.2000, 0.3000], dtype=oneflow.float32) >>> x.bincount(weight, minlength=5) tensor([0.1000, 0.0000, 0.2000, 0.3000, 0.0000], dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.bernoulli, """ See :func:`oneflow.bernoulli` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.Tensor([1, 1, 1]) >>> x.bernoulli() tensor([1., 1., 1.], dtype=oneflow.float32) >>> x.bernoulli(p=0.0) tensor([0., 0., 0.], dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.bernoulli_, """ The inplace version of :func:`oneflow.Tensor.bernoulli_`. See :func:`oneflow.Tensor.bernoulli` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.Tensor([1, 1, 1]) >>> x.bernoulli_(p=0.0) tensor([0., 0., 0.], dtype=oneflow.float32) >>> x tensor([0., 0., 0.], dtype=oneflow.float32) """, ) add_docstr( oneflow.Tensor.broadcast_to, """ See :func:`oneflow.broadcast_to` """, ) add_docstr( oneflow.Tensor.unique, """ See :func:`oneflow.unique` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([3, 1, 2, 0 ,2]) >>> x.unique() tensor([0, 1, 2, 3], dtype=oneflow.int64) >>> x, indices = x.unique(return_inverse=True) >>> indices tensor([3, 1, 2, 0, 2], dtype=oneflow.int32) >>> x, counts = x.unique(return_counts=True) >>> counts tensor([1, 1, 1, 1], dtype=oneflow.int32) """, ) add_docstr( oneflow.Tensor.clone, """ See :func:`oneflow.clone` For example: .. 
code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> x.clone() tensor([1, 2, 3], dtype=oneflow.int64) """, ) add_docstr( oneflow.Tensor.bitwise_and, """ See :func:`oneflow.bitwise_and` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> x.bitwise_and(4) tensor([0, 0, 0], dtype=oneflow.int64) >>> y = flow.tensor([2, 1, 0]) >>> x.bitwise_and(y) tensor([0, 0, 0], dtype=oneflow.int64) """, ) add_docstr( oneflow.Tensor.bitwise_or, """ See :func:`oneflow.bitwise_or` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> x.bitwise_or(4) tensor([5, 6, 7], dtype=oneflow.int64) >>> y = flow.tensor([2, 1, 0]) >>> x.bitwise_or(y) tensor([3, 3, 3], dtype=oneflow.int64) """, ) add_docstr( oneflow.Tensor.bitwise_xor, """ See :func:`oneflow.bitwise_xor` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> x.bitwise_xor(4) tensor([5, 6, 7], dtype=oneflow.int64) >>> y = flow.tensor([2, 1, 0]) >>> x.bitwise_xor(y) tensor([3, 3, 3], dtype=oneflow.int64) """, ) add_docstr( oneflow.Tensor.new, """ Constructs a new tensor of the same data type and device (or placement and sbp) as self tensor. Any valid argument combination to the tensor constructor is accepted by this method, including sizes, NumPy ndarray, Python Sequence, etc. See :func:`oneflow.Tensor` for more details. .. code-block:: python >>> import oneflow as flow >>> x = flow.randn(3, 2) >>> x.new() tensor([], dtype=oneflow.float32) >>> x.new(1, 2).shape oneflow.Size([1, 2]) >>> x.new([1, 2]) tensor([1., 2.], dtype=oneflow.float32) >>> y = flow.randn(3, 3) >>> x.new(y).shape oneflow.Size([3, 3]) .. warning:: When ``y`` is a global tensor, invoking ``Tensor.new(y)`` will raise an error. Consider using ``Tensor.new(y.size())`` to create a tensor that has the same placement and sbp as this tensor and the same size as ``y``. """, ) add_docstr( oneflow.Tensor.baddbmm, """ See :func:`oneflow.baddbmm` For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.randn(2, 3, 4) >>> batch1 = flow.randn(2, 3, 5) >>> batch2 = flow.randn(2, 5, 4) >>> x.baddbmm(batch1, batch2, alpha=2, beta=2) # doctest: +SKIP """, ) add_docstr( oneflow.Tensor.frac, r""" See :func:`oneflow.frac`. """, ) add_docstr( oneflow.Tensor.frac_, r""" In-place version of :func:`oneflow.Tensor.frac`. """, ) add_docstr( oneflow.Tensor.digamma, """ See :func:`oneflow.digamma` """, ) ================================================ FILE: python/oneflow/framework/docstr/tensor_attributes.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr, reset_docstr oneflow.device.__doc__ = r""" A :class:`oneflow.device` is an object representing the device on which a :class:`oneflow.Tensor` is or will be allocated.
The documentation is referenced from: https://pytorch.org/docs/1.10/tensor_attributes.html#torch.torch.device. The :class:`oneflow.device` contains a device type ('cpu' or 'cuda') and an optional device ordinal for the device type. If the device ordinal is not present, this object will always represent the current device for the device type. A :class:`oneflow.Tensor`'s device can be accessed via the Tensor.device property. A :class:`oneflow.device` can be constructed via a string or via a string and device ordinal. Via a string: .. code-block:: python >>> import oneflow as flow >>> flow.device('cuda:0') device(type='cuda', index=0) >>> flow.device('cpu') device(type='cpu', index=0) >>> flow.device('cuda') # current cuda device device(type='cuda', index=0) Via a string and device ordinal: .. code-block:: python >>> import oneflow as flow >>> flow.device('cuda', 0) device(type='cuda', index=0) >>> flow.device('cpu', 0) device(type='cpu', index=0) Note: The :class:`oneflow.device` argument in functions can generally be substituted with a string. This allows for fast prototyping of code. .. code-block:: python >>> import oneflow as flow >>> # Example of a function that takes in a oneflow.device >>> cuda0 = flow.device('cuda:0') >>> x = flow.randn(2,3, device=cuda0) .. code-block:: python >>> # You can substitute the flow.device with a string >>> x = flow.randn(2,3, device='cuda:0') """ oneflow.placement.__doc__ = r""" A ``oneflow.placement`` is an object representing the device group on which a :class:`oneflow.Tensor` is or will be allocated. The ``oneflow.placement`` contains a device type ('cpu' or 'cuda') and corresponding device sequence. A :class:`oneflow.Tensor`'s placement can be accessed via the Tensor.placement property. A oneflow.placement can be constructed in several ways: .. code-block:: python >>> import oneflow as flow >>> p = flow.placement(type="cuda", ranks=[0, 1, 2, 3]) >>> p oneflow.placement(type="cuda", ranks=[0, 1, 2, 3]) >>> p = flow.placement(type="cuda", ranks=[[0, 1], [2, 3]]) >>> p oneflow.placement(type="cuda", ranks=[[0, 1], [2, 3]]) """ reset_docstr( oneflow.placement.all, r""" oneflow.placement.all(device_type) -> oneflow.placement Returns a placement that contains all available devices. Args: device_type (str): cuda or cpu For examples: .. code-block:: python # Runs on 4 ranks import oneflow as flow p = flow.placement.all("cuda") # oneflow.placement(type="cuda", ranks=[0, 1, 2, 3]) p = flow.placement.all("cpu") # oneflow.placement(type="cpu", ranks=[0, 1, 2, 3]) """, ) oneflow.sbp.sbp.__doc__ = r""" A ``oneflow.sbp`` is an object representing how the data of the global tensor is distributed across the ranks of the ``Tensor`` placement. ``oneflow.sbp`` includes three types: - oneflow.sbp.split(dim) Indicates that the global tensor is evenly divided according to the dimension `dim` and distributed on each rank. - oneflow.sbp.broadcast() Indicates that the global tensor is replicated on each rank. - oneflow.sbp.partial_sum() Indicates that the value of the global tensor is the element-wise sum of the local tensors distributed in each rank. A :class:`oneflow.Tensor`'s sbp can be accessed via the Tensor.sbp property. A ``oneflow.sbp`` can be constructed in several ways: ..
code-block:: python >>> import oneflow as flow >>> s = flow.sbp.split(0) >>> s oneflow.sbp.split(dim=0) >>> b = flow.sbp.broadcast() >>> b oneflow.sbp.broadcast >>> p = flow.sbp.partial_sum() >>> p oneflow.sbp.partial_sum """ ================================================ FILE: python/oneflow/framework/docstr/tensor_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.narrow, r""" narrow(x, dim: int, start: int, length: int) -> Tensor Returns a new tensor that is a narrowed version of `input` tensor. The dimension `dim` is input from `start` to `start + length`. Args: input: the tensor to narrow. dim: the dimension along which to narrow. start: the starting dimension. length: the distance to the ending dimension. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) >>> flow.narrow(input, 0, 0, 2) tensor([[1, 2, 3], [4, 5, 6]], dtype=oneflow.int64) >>> flow.narrow(input, 1, 1, 2) tensor([[2, 3], [5, 6], [8, 9]], dtype=oneflow.int64) """, ) add_docstr( oneflow.unsqueeze, r""" unsqueeze(input, dim) -> Tensor Returns a new tensor with a dimension of size one inserted at the specified position. The returned tensor shares the same underlying data with this tensor. A :attr:`dim` value within the range `[-input.ndimension() - 1, input.ndimension() + 1)` can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze` applied at :attr:`dim` = ``dim + input.ndimension() + 1``. Args: input (Tensor): the input tensor. dim (int): the index at which to insert the singleton dimension For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.randn(2, 3, 4) >>> y = x.unsqueeze(2) >>> y.shape oneflow.Size([2, 3, 1, 4]) """, ) add_docstr( oneflow.permute, r""" permute(input, *dims) -> Tensor Returns a view of the original tensor with its dimensions permuted. Args: dims (tuple of ints): The desired ordering of dimensions For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32) >>> output = flow.permute(input, (1, 0, 2, 3)).shape >>> output oneflow.Size([6, 2, 5, 3]) """, ) ================================================ FILE: python/oneflow/framework/docstr/tensor_t.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.t, """ oneflow.t(input) → Tensor. Expects `input` to be <= 2-D tensor and transposes dimensions 0 and 1. 0-D and 1-D tensors are returned as is. When input is a 2-D tensor this is equivalent to `transpose(input, 0, 1)`. Args: input (oneflow.Tensor): An input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.random.randn(), dtype=flow.float32) >>> flow.t(x).shape oneflow.Size([]) >>> x = flow.tensor(np.random.randn(3), dtype=flow.float32) >>> flow.t(x).shape oneflow.Size([3]) >>> x = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> flow.t(x).shape oneflow.Size([3, 2]) """, ) ================================================ FILE: python/oneflow/framework/docstr/tensordot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.tensordot, r""" tensordot(a, b, dims=Union[int, Tensor, Tuple[List[int], List[int]], List[List[int]]], out=None) -> Tensor Compute tensor dot along given dimensions. Given two tensors a and b, and dims which represent two lists containing dim indices, `tensordot` traverses the two lists and calculate the tensor dot along every dim pair. Args: a (oneflow.Tensor): The input tensor to compute tensordot b (oneflow.Tensor): The input tensor to compute tensordot dims (int or list or tuple or oneflow.Tensor): The dims to calculate tensordot. If it's an integer or oneflow.Tensor with only one element, the last dims of tensor `a` and the first dims of tensor `b` will be calculated. If it's a list or tuple or oneflow.Tensor with more than one element, it must contain two array-like object, which represent the dims of tensor a and tensor b to be calculated. out (oneflow.Tensor): The tensor to save result (NOT IMPLEMENTED YET) Returns: oneflow.Tensor: The result tensor For example: .. code-block:: python >>> import oneflow as flow >>> a = flow.randn(3, 4, 5) >>> b = flow.randn(4, 5, 6) >>> flow.tensordot(a, b, dims=2).shape oneflow.Size([3, 6]) >>> b = flow.randn(5, 6, 7) >>> flow.tensordot(a, b, dims=1).shape oneflow.Size([3, 4, 6, 7]) >>> b = flow.randn(3, 4, 7) >>> flow.tensordot(a, b, dims=[[0, 1], [0, 1]]).shape oneflow.Size([5, 7]) Note: Three common use cases are: - dims = 0 : tensor product :math:`a \otimes b` - dims = 1 : tensor dot product :math:`a \cdot b` - dims = 2 : (default) tensor double contraction :math:`a : b` The part of documentation is referenced from https://numpy.org/doc/stable/reference/generated/numpy.tensordot.html. 
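For instance, a minimal sketch of the ``dims = 0`` case (the example shapes are arbitrary; with no dimensions contracted, the result shape is the concatenation of the two input shapes, as implied by the tensor product):

.. code-block:: python

    >>> import oneflow as flow
    >>> a = flow.randn(2, 3)
    >>> b = flow.randn(4)
    >>> flow.tensordot(a, b, dims=0).shape  # outer product: no dimensions are contracted
    oneflow.Size([2, 3, 4])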
Note: The operation is equivalent to the series of operations: - Permute the dimensions of the tensor A that require tensordot to the end - Permute the dimensions of the tensor B that require tensordot to the start - Reshape the permuted tensor A into a 2-dimensional tensor, where the size of the 0th dimension is the product of the dimensions that do not require dot product, and the size of the 1st dimension is the product of the dimensions that require dot product - Reshape the permuted tensor B into a 2-dimensional tensor, where the size of the 0th dimension is the product of the dimensions that require dot product, and the size of the 1st dimension is the product of the dimensions that do not require dot product - Calculate the matrix multiplication of reshaped tensor A and reshaped tensor B - Reshape the result of matrix multiplication, the target shape is the concatenation of the dimensions that do not require tensordot of tensor A and B This series of operations can be equivalently represented by the following code: .. code-block:: python >>> import oneflow as flow >>> a = flow.randn(2, 4, 3) >>> b = flow.randn(3, 4, 2) >>> dims = [[0, 2], [2, 0]] >>> permuted_a = a.permute(1, 0, 2) # 0, 2 are the dimensions requiring tensordot and are placed in the end in permuting >>> permuted_b = b.permute(2, 0, 1) # 2, 0 are the dimensions requiring tensordot and are placed at the beginning in permuting >>> reshaped_a = permuted_a.reshape(4, 2 * 3) # 4 is the size of the dimension of a that does not require tensordot >>> reshaped_b = permuted_b.reshape(2 * 3, 4) # 4 is the size of the dimension of b that does not require tensordot >>> matmul_result = flow.matmul(reshaped_a, reshaped_b) >>> result = matmul_result.reshape(4, 4) # 4, 4 are the concatenation of dimensions that do not require tensordot of a and b >>> flow.all(result == flow.tensordot(a, b, dims)) tensor(True, dtype=oneflow.bool) .. Feature Stage of Operator [tensordot]. - Maintainer List [@marigoold] - Current Stage [ ] - Alpha Stage Check List [ ] - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes] - Doc(API Doc must be provided and shown normally on the web page.)[Yes] - Functionality and its Test [ ] - Functionality is highly compatible with PyTorch 1.11.
[ ] (out parameter is not implemented yet) - eager local [Yes] - forward [Yes] - backward [Yes] - gpu [Yes] - cpu [Yes] - graph local [ ] (when the type of param `dims` is oneflow.Tensor, the tensor.item() will make the graph fail) - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - Exception Handling - Exception Message and Hint must be provided [Yes] - Beta Stage Check List [ ] - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ] - Doc(Same standard as Alpha Stage)[ ] - Functionality and its Test [ ] - eager global [ ] - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - graph global [ ] - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - Performance and Scalability(Must be evaluated.)[ ] - CUDA kernel [ ] - CPU kernel [ ] - N nodes M devices [ ] - Exception Handling [ ] - Exception Message and Hint must be provided [ ] - Try your best to do Exception Recovery [ ] - Stable Stage Check List [ ] - API(Same standard as Beta Stage)[ ] - Doc(Same standard as Beta Stage)[ ] - Functionality and its Test [ ] - fp16 and AMP [ ] - NHWC [ ] - Performance and Scalability(Must be evaluated.)[ ] - Exception Handling [ ] """, ) ================================================ FILE: python/oneflow/framework/docstr/tile.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.tile, """ tile(input, dims) -> Tensor Constructs a tensor by repeating the elements of ``input``. The ``dims`` argument specifies the number of repetitions in each dimension. If ``dims`` specifies fewer dimensions than ``input`` has, then ones are prepended to ``dims`` until all dimensions are specified. For example, if ``input`` has shape (8, 6, 4, 2) and ``dims`` is (2, 2), then ``dims`` is treated as (1, 1, 2, 2). Analogously, if ``input`` has fewer dimensions than ``dims`` specifies, then ``input`` is treated as if it were unsqueezed at dimension zero until it has as many dimensions as ``dims`` specifies. For example, if ``input`` has shape (4, 2) and ``dims`` is (3, 3, 2, 2), then ``input`` is treated as if it had the shape (1, 1, 4, 2). .. note:: This function is similar to NumPy’s tile function. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.tile.html. Args: input (oneflow.Tensor): the tensor whose elements to repeat. dims (tuple): the number of repetitions per dimension. For example: ..
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32) >>> input = flow.Tensor(np_arr) >>> out = input.tile(2,1,2,1) >>> out.shape oneflow.Size([10, 3, 12, 9]) >>> x = np.random.randn(5, 2, 1) >>> input = flow.Tensor(x) >>> out = input.tile(3,4) >>> out.shape oneflow.Size([5, 6, 4]) """, ) ================================================ FILE: python/oneflow/framework/docstr/topk.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.topk, """Finds the values and indices of the k largest entries at specified axis. Args: input (oneflow.Tensor): Input Tensor k (int): the k in “top-k” dim (int, optional): the dimension to sort along. Defaults to the last dim (-1) largest (bool, optional): controls whether to return largest or smallest elements sorted (bool, optional): controls whether to return the elements in sorted order (Only Support True Now!) Returns: Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int32)): A tuple of (values, indices), where the indices are the indices of the elements in the original input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = np.array([[1, 3, 8, 7, 2], [1, 9, 4, 3, 2]], dtype=np.float32) >>> result = flow.topk(flow.Tensor(x), k=3, dim=1) >>> result.values tensor([[8., 7., 3.], [9., 4., 3.]], dtype=oneflow.float32) >>> result.indices tensor([[2, 3, 1], [1, 2, 3]], dtype=oneflow.int64) >>> result.values.shape oneflow.Size([2, 3]) >>> result.indices.shape oneflow.Size([2, 3]) >>> result = flow.topk(flow.Tensor(x), k=2, dim=1, largest=False) >>> result.values tensor([[1., 2.], [1., 2.]], dtype=oneflow.float32) >>> result.indices tensor([[0, 4], [0, 4]], dtype=oneflow.int64) >>> result.values.shape oneflow.Size([2, 2]) >>> result.indices.shape oneflow.Size([2, 2]) """, ) ================================================ FILE: python/oneflow/framework/docstr/trigonometric_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.atan2, """Element-wise arctangent of input{i}/other{i} with consideration of the quadrant. Returns a new tensor with the signed angles in radians between vector (other{i},input{i}) and vector (1, 0). 
The shapes of input and other must be broadcastable. Args: input (Tensor): the first input tensor. other (Tensor): the second input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.Tensor(np.array([1,2,3])) >>> y1 = flow.Tensor(np.array([3,2,1])) >>> x2 = flow.Tensor(np.array([1.53123589,0.54242598,0.15117185])) >>> y2 = flow.Tensor(np.array([-0.21906378,0.09467151,-0.75562878])) >>> x3 = flow.Tensor(np.array([1,0,-1])) >>> y3 = flow.Tensor(np.array([0,1,0])) >>> flow.atan2(x1,y1).numpy() array([0.32175055, 0.7853982 , 1.2490457 ], dtype=float32) >>> flow.atan2(x2,y2).numpy() array([1.7128955, 1.3980033, 2.9441385], dtype=float32) >>> flow.atan2(x3,y3).numpy() array([ 1.5707964, 0. , -1.5707964], dtype=float32) """, ) ================================================ FILE: python/oneflow/framework/docstr/unbind.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.unbind, """ Removes a tensor dimension. Returns a tuple of all slices along a given dimension, already without it. This function is equivalent to PyTorch's unbind function. Args: x(Tensor): the tensor to unbind dim(int): dimension to remove For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor(range(12)).reshape([3,4]) >>> flow.unbind(x) (tensor([0, 1, 2, 3], dtype=oneflow.int64), tensor([4, 5, 6, 7], dtype=oneflow.int64), tensor([ 8, 9, 10, 11], dtype=oneflow.int64)) >>> flow.unbind(x, 1) (tensor([0, 4, 8], dtype=oneflow.int64), tensor([1, 5, 9], dtype=oneflow.int64), tensor([ 2, 6, 10], dtype=oneflow.int64), tensor([ 3, 7, 11], dtype=oneflow.int64)) """, ) ================================================ FILE: python/oneflow/framework/docstr/util_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.isnan, """ isnan(input) -> Tensor This function is equivalent to PyTorch’s isnan function. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.isnan.html?highlight=isnan#torch.isnan Returns a new tensor with boolean elements representing if each element of input is NaN or not. Args: input(Tensor): the input tensor. Returns: A boolean tensor that is True where input is NaN and False elsewhere. 
Example:: >>> import oneflow as flow >>> flow.isnan(flow.tensor([1, float('nan'), 2])) tensor([False, True, False], dtype=oneflow.bool) """, ) add_docstr( oneflow.isinf, """ isinf(input) -> Tensor This function is equivalent to PyTorch’s isinf function. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.isinf.html?highlight=isinf#torch.isinf Tests if each element of input is infinite (positive or negative infinity) or not. Args: input(Tensor): the input tensor. Returns: A boolean tensor that is True where input is infinite and False elsewhere. Example:: >>> import oneflow as flow >>> flow.isinf(flow.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) tensor([False, True, False, True, False], dtype=oneflow.bool) """, ) add_docstr( oneflow.isfinite, """ isfinite(input) -> Tensor This function is equivalent to PyTorch’s isfinite function. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.isfinite.html?highlight=isfinite#torch.isfinite Returns a new tensor with boolean elements representing if each element is finite or not. Args: input(Tensor): the input tensor. Returns: A boolean tensor that is True where input is finite and False elsewhere. Example:: >>> import oneflow as flow >>> flow.isfinite(flow.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) tensor([ True, False, True, False, False], dtype=oneflow.bool) """, ) ================================================ FILE: python/oneflow/framework/docstr/utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal from doctest import DocTestParser, DebugRunner, DocTestRunner def _test_docstr(docstr, verbose=True, optionflags=0, raise_on_error=True): parser = DocTestParser() if raise_on_error: runner = DebugRunner(verbose=verbose, optionflags=optionflags) else: runner = DocTestRunner(verbose=verbose, optionflags=optionflags) test = parser.get_doctest(docstr, {}, __name__, __file__, 0) runner.run(test) def add_docstr(fun, docstr: str): return oneflow._oneflow_internal.add_doc(fun, docstr) def reset_docstr(o, docstr): if type(o) == type: assert hasattr(o, "__doc__"), str(o) + " does not have a docstring!" setattr(o, "__doc__", docstr) return o else: return oneflow._oneflow_internal.reset_doc(o, docstr) ================================================ FILE: python/oneflow/framework/docstr/vision.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow._C.pad, r""" Pads tensor. Padding size: The padding size by which to pad some dimensions of :attr:`input` are described starting from the last dimension and moving forward. :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions of ``input`` will be padded. For example, to pad only the last dimension of the input tensor, then :attr:`pad` has the form :math:`(\text{padding_left}, \text{padding_right})`; to pad the last 2 dimensions of the input tensor, then use :math:`(\text{padding_left}, \text{padding_right},` :math:`\text{padding_top}, \text{padding_bottom})`; to pad the last 3 dimensions, use :math:`(\text{padding_left}, \text{padding_right},` :math:`\text{padding_top}, \text{padding_bottom}` :math:`\text{padding_front}, \text{padding_back})`. Padding mode: See :class:`oneflow.nn.ConstantPad2d`, :class:`oneflow.nn.ReflectionPad2d`, and :class:`oneflow.nn.ReplicationPad2d` for concrete examples on how each of the padding modes works. Constant padding is implemented for arbitrary dimensions. Replicate padding is implemented for padding the last 3 dimensions of 5D input tensor, or the last 2 dimensions of 4D input tensor, or the last dimension of 3D input tensor. Reflect padding is only implemented for padding the last 2 dimensions of 4D input tensor, or the last dimension of 3D input tensor. Args: input (Tensor): N-dimensional tensor pad (tuple): m-elements tuple, where :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even. mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'constant'`` value: fill value for ``'constant'`` padding. Default: ``0`` For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> pad = [2, 2, 1, 1] >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32)) >>> output = flow.nn.functional.pad(input, pad, mode = "replicate") >>> output.shape oneflow.Size([1, 2, 5, 7]) >>> output tensor([[[[ 0., 0., 0., 1., 2., 2., 2.], [ 0., 0., 0., 1., 2., 2., 2.], [ 3., 3., 3., 4., 5., 5., 5.], [ 6., 6., 6., 7., 8., 8., 8.], [ 6., 6., 6., 7., 8., 8., 8.]], [[ 9., 9., 9., 10., 11., 11., 11.], [ 9., 9., 9., 10., 11., 11., 11.], [12., 12., 12., 13., 14., 14., 14.], [15., 15., 15., 16., 17., 17., 17.], [15., 15., 15., 16., 17., 17., 17.]]]], dtype=oneflow.float32) See :class:`oneflow.nn.ConstantPad2d`, :class:`oneflow.nn.ReflectionPad2d`, and :class:`oneflow.nn.ReplicationPad2d` for concrete examples on how each of the padding modes works. """, ) ================================================ FILE: python/oneflow/framework/docstr/where.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow from oneflow.framework.docstr.utils import add_docstr add_docstr( oneflow.where, """Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`. If the element in condition is larger than 0, it will take the `x` element, else it will take the `y` element .. note:: If :attr:`x` is None and :attr:`y` is None, flow.where(condition) is identical to flow.nonzero(condition, as_tuple=True). The tensors :attr:`condition`, :attr:`x`, :attr:`y` must be broadcastable. Args: condition (IntTensor): When 1 (nonzero), yield x, otherwise yield y x (Tensor or Scalar): value (if :attr:x is a scalar) or values selected at indices where :attr:`condition` is True y (Tensor or Scalar): value (if :attr:x is a scalar) or values selected at indices where :attr:`condition` is False Returns: Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`x`, :attr:`y` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.tensor( ... np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), ... dtype=flow.float32, ... ) >>> y = flow.tensor(np.ones(shape=(3, 2)), dtype=flow.float32) >>> condition = flow.tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32) >>> out = condition.where(x, y) >>> out #doctest: +ELLIPSIS tensor([[1.0000, 0.3139], ... [0.0478, 1.0000]], dtype=oneflow.float32) """, ) ================================================ FILE: python/oneflow/framework/dtype.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import numpy as np import oneflow import oneflow._oneflow_internal import oneflow.core.common.data_type_pb2 as data_type_pb2 from oneflow._oneflow_internal import ( set_default_dtype, get_default_dtype, ) _dtypes = [ oneflow.bool, oneflow.float, oneflow.float32, oneflow.double, oneflow.float64, oneflow.float16, oneflow.int8, oneflow.int16, oneflow.int32, oneflow.int64, oneflow.uint8, oneflow.record, oneflow.tensor_buffer, oneflow.bfloat16, oneflow.complex64, oneflow.cfloat, oneflow.complex128, oneflow.cdouble, ] def dtypes(): return _dtypes def convert_proto_dtype_to_oneflow_dtype(proto_dtype): return oneflow._oneflow_internal.deprecated.GetDTypeByDataType(proto_dtype) _ONEFLOW_DTYPE_TO_NUMPY_DTYPE = { # >> np_bool = np.array([1,2], dtype=bool).dtype # >> np_bool == bool # True oneflow.bool: bool, oneflow.float: np.float32, oneflow.float16: np.float16, oneflow.float32: np.float32, oneflow.float64: np.double, oneflow.double: np.double, oneflow.int8: np.int8, oneflow.char: np.int8, oneflow.int16: np.int16, oneflow.int32: np.int32, oneflow.int64: np.int64, oneflow.uint8: np.uint8, oneflow.complex64: np.complex64, oneflow.cfloat: np.complex64, oneflow.complex128: np.complex128, oneflow.cdouble: np.complex128, } def convert_oneflow_dtype_to_numpy_dtype(oneflow_dtype: oneflow.dtype): if oneflow_dtype not in _ONEFLOW_DTYPE_TO_NUMPY_DTYPE: raise NotImplementedError return _ONEFLOW_DTYPE_TO_NUMPY_DTYPE[oneflow_dtype] def convert_numpy_dtype_to_oneflow_dtype(numpy_dtype: np.dtype): for (k, v) in _ONEFLOW_DTYPE_TO_NUMPY_DTYPE.items(): if v == numpy_dtype: return k raise NotImplementedError del data_type_pb2 del np def set_default_tensor_type(tensor_type): """Sets the default floating point type for those source operators which create Tensor. The default floating point type is ``oneflow.FloatTensor``. Args: tensor_type (type or string): The floating point tensor type or its name. For example: .. code-block:: python >>> import oneflow >>> oneflow.set_default_tensor_type(oneflow.FloatTensor) >>> x = oneflow.ones(2, 3) >>> x.dtype oneflow.float32 >>> oneflow.set_default_tensor_type("oneflow.DoubleTensor") >>> x = oneflow.ones(2, 3) >>> x.dtype oneflow.float64 >>> oneflow.set_default_tensor_type(oneflow.FloatTensor) >>> x = oneflow.tensor([1.0, 2]) >>> x.dtype oneflow.float32 """ def _import_dotted_name(name): """ This function quotes from: https://github.com/pytorch/pytorch/blob/master/torch/_utils.py """ components = name.split(".") obj = __import__(components[0]) for component in components[1:]: obj = getattr(obj, component) return obj if isinstance(tensor_type, str): tensor_type = _import_dotted_name(tensor_type) oneflow._oneflow_internal.set_default_tensor_type(tensor_type) def is_floating_point(input): return input.is_floating_point() ================================================ FILE: python/oneflow/framework/env_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import socket import traceback from contextlib import closing import warnings import oneflow._oneflow_internal import oneflow.core.control.ctrl_bootstrap_pb2 as ctrl_bootstrap_pb import oneflow.core.job.env_pb2 as env_pb import oneflow.core.job.resource_pb2 as resource_util import oneflow.framework.c_api_util as c_api_util def api_all_device_placement(device_type: str) -> oneflow._oneflow_internal.placement: r""" oneflow.env.all_device_placement(device_type) -> oneflow.placement Returns a placement that contains all available devices. Note: It is recommended to use `oneflow.placement.all` instead of this function. Args: device_type (str): cuda or cpu For examples: .. code-block:: python # Runs on 4 ranks import oneflow as flow p = flow.env.all_device_placement("cuda") # oneflow.placement(type="cuda", ranks=[0, 1, 2, 3]) p = flow.env.all_device_placement("cpu") # oneflow.placement(type="cpu", ranks=[0, 1, 2, 3]) """ return oneflow.placement.all(device_type) def check_non_localhost_proxy_and_print_warning(): for env_var_name in ["http_proxy", "HTTP_PROXY", "https_proxy", "HTTPS_PROXY"]: env_var_value = os.getenv(env_var_name) if ( env_var_value is not None and (not "://localhost" in env_var_value) and (not "://127.0.0.1" in env_var_value) and (not env_var_value.startswith("localhost")) and (not env_var_value.startswith("127.0.0.1")) ): print( f"Proxy through another machine ({env_var_value}) is incompatible with OneFlow. Please unset them by `unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY`" ) break def create_env(): """create environment Returns: Env: [description] """ global default_env_proto assert len(default_env_proto.machine) > 0 CompleteEnvProto(default_env_proto) if default_env_proto.ctrl_bootstrap_conf.world_size > 1: check_non_localhost_proxy_and_print_warning() return c_api_util.GetEnvContext(default_env_proto) def CompleteEnvProto(env_proto): _UpdateDefaultEnvProtoByMultiClientEnvVars(env_proto) if env_proto.HasField("ctrl_port") == False: if len(env_proto.machine) == 1: env_proto.ctrl_port = _FindFreePort() else: raise ValueError( "a ctrl_port is required if running multi-node, set it with 'oneflow.env.ctrl_port([YOUR PORT])'" ) def _MakeMachine(machines): if isinstance(machines, str): machines = [machines] rp_machine = env_pb.EnvProto().machine for m_data in machines: m = rp_machine.add() if isinstance(m_data, str): m.addr = m_data elif isinstance(m_data, dict): if "addr" in m_data: m.addr = m_data["addr"] if "ctrl_port_agent" in m_data: m.ctrl_port_agent = m_data["ctrl_port_agent"] if "data_port_agent" in m_data: m.data_port_agent = m_data["data_port_agent"] else: raise NotImplementedError id = 0 addrs_for_check = set() for m in rp_machine: m.id = id id += 1 assert m.addr not in addrs_for_check addrs_for_check.add(m.addr) return rp_machine def _MakeBootstrapConf(bootstrap_info: dict): global config_master_addr assert config_master_addr.HasField("host"), "must config master host first" assert config_master_addr.HasField("port"), "must config master port first" assert config_world_size != 0, "must config world size first" bootstrap_conf = ctrl_bootstrap_pb.BootstrapConf() bootstrap_conf.master_addr.CopyFrom(config_master_addr) bootstrap_conf.world_size = config_world_size assert "rank" in bootstrap_info bootstrap_conf.rank = bootstrap_info["rank"] if "host" in bootstrap_info: bootstrap_conf.host = bootstrap_info["host"] global config_bootstrap_ctrl_port if config_bootstrap_ctrl_port != 0: bootstrap_conf.ctrl_port = config_bootstrap_ctrl_port global 
config_node_size if config_node_size != 0: bootstrap_conf.node_size = config_node_size return bootstrap_conf def _DefaultEnvProto(): env_proto = env_pb.EnvProto() machine = env_proto.machine.add() machine.id = 0 machine.addr = "127.0.0.1" return env_proto def _FindFreePort(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("localhost", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] def CheckAndWarnAbnormalEnvVars(): env_var_names = ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] env_var_without_value = [x for x in env_var_names if os.getenv(x) is None] env_var_with_value = [x for x in env_var_names if os.getenv(x) is not None] if len(env_var_with_value) != 0 and len(env_var_without_value) != 0: warnings.warn( f"Among four environment variables required for distributed training, only {', '.join('`{0}`'.format(x) for x in env_var_with_value)} are set, but {', '.join('`{0}`'.format(x) for x in env_var_without_value)} are not set." ) def _UpdateDefaultEnvProtoByMultiClientEnvVars(env_proto): def str2int(env_config): return int(env_config) bootstrap_conf = ctrl_bootstrap_pb.BootstrapConf() master_addr = ctrl_bootstrap_pb.Address() master_addr.host = os.getenv("MASTER_ADDR", "127.0.0.1") master_addr.port = str2int(os.getenv("MASTER_PORT", _FindFreePort())) bootstrap_conf.master_addr.CopyFrom(master_addr) bootstrap_conf.world_size = str2int(os.getenv("WORLD_SIZE", 1)) bootstrap_conf.rank = str2int(os.getenv("RANK", 0)) env_proto.ctrl_bootstrap_conf.CopyFrom(bootstrap_conf) cpp_logging_conf = env_pb.CppLoggingConf() if os.getenv("GLOG_log_dir"): cpp_logging_conf.log_dir = os.getenv("GLOG_log_dir") if os.getenv("GLOG_logtostderr"): cpp_logging_conf.logtostderr = str2int(os.getenv("GLOG_logtostderr")) if os.getenv("GLOG_logbuflevel"): cpp_logging_conf.logbuflevel = str2int(os.getenv("GLOG_logbuflevel")) if os.getenv("GLOG_minloglevel"): cpp_logging_conf.minloglevel = str2int(os.getenv("GLOG_minloglevel")) env_proto.cpp_logging_conf.CopyFrom(cpp_logging_conf) class EnvHolder(object): def __init__(self): CheckAndWarnAbnormalEnvVars() self._env_cxt = create_env() self._shutting_down = [False] def is_shutting_down(self): """ Whether the interpreter is currently shutting down. For use in finalizers, __del__ methods, and similar; it is advised to early bind this function rather than look it up when calling it, since at shutdown module globals may be cleared. Please refer to: https://github.com/Oneflow-Inc/OneTeam/issues/1219#issuecomment-1092370402 This solution is obtained from cupy code: https://github.com/cupy/cupy/pull/2809 """ return self._shutting_down[0] def switch_to_shutting_down(self, is_normal_exit=True): self._shutting_down[0] = True self._env_cxt.SwitchToShuttingDownPhase(is_normal_exit) def GetEnv(): return EnvHolder() device_tag2default_parallel_conf = {} default_env_proto = _DefaultEnvProto() config_master_addr = ctrl_bootstrap_pb.Address() config_world_size = 0 config_bootstrap_ctrl_port = 0 config_node_size = 0 global_ctrl_bootstrap_confs = [] ================================================ FILE: python/oneflow/framework/function_desc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal import oneflow.core.job.job_conf_pb2 as job_conf_pb import oneflow.framework.hob as hob import oneflow.framework.session_context as session_ctx import oneflow.support.enable_if as enable_if class FunctionAttribute(object): def __init__(self): self.default_placement_scope = None self.default_distribute_strategy = None self.allow_cpu_return_op = True class FunctionDesc(object): def __init__(self, job_func=None, job_config_proto=None, function_attribute=None): if job_config_proto is None: job_config_proto = job_conf_pb.JobConfigProto() if function_attribute is None: function_attribute = FunctionAttribute() self.job_func = job_func self.job_config_proto = job_config_proto self.job_config_proto.predict_conf.SetInParent() self.function_attribute = function_attribute def IsTrainable(self): if self.job_config_proto.HasField("train_conf"): return True if self.job_config_proto.HasField("predict_conf"): return False raise NotImplementedError def HasAttr(self, attr_name): if attr_name == "flag_name2flag_value": return False name2default = session_ctx.GetDefaultSession().function_flag_name2default_val if attr_name in self.job_config_proto.flag_name2flag_value: return True return self.job_config_proto.HasField(attr_name) def __getattr__(self, attr_name): assert attr_name != "flag_name2flag_value" flag_name2flag_value = self.job_config_proto.flag_name2flag_value name2default = session_ctx.GetDefaultSession().function_flag_name2default_val if attr_name not in name2default: assert self.job_config_proto.HasField(attr_name) return getattr(self.job_config_proto, attr_name) attr_value = name2default[attr_name] if attr_name in flag_name2flag_value: attr_value = flag_name2flag_value[attr_name] if attr_value.HasField("at_bool"): return attr_value.at_bool elif attr_value.HasField("at_int64"): return attr_value.at_int64 elif attr_value.HasField("at_double"): return attr_value.at_double elif attr_value.HasField("at_string"): return attr_value.at_string else: raise NotImplementedError() ================================================ FILE: python/oneflow/framework/function_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import copy import functools import inspect import re import sys import traceback from typing import Any, Callable, Optional, Union import oneflow._oneflow_internal import oneflow.core.common.data_type_pb2 as data_type_pb import oneflow.framework.session_context as session_ctx import oneflow.support.enable_if as enable_if import oneflow.support.pb_util as pb_util from oneflow import oneflow_deprecate from oneflow.framework.function_desc import FunctionDesc class FunctionConfig(object): """OneFlow function's configurations. """ def __init__(self) -> None: self.function_desc = FunctionDesc() def __getattr__( self, attr_name: str ) -> Callable[[Optional[Union[bool, int, float, str]]], None]: name2default = session_ctx.GetDefaultSession().function_flag_name2default_val assert attr_name in name2default flag_name2flag_value = self.function_desc.job_config_proto.flag_name2flag_value default_val = name2default[attr_name] def FunctionConfigSetter( attr_value: Optional[Union[bool, int, float, str]] = None ) -> None: if default_val.HasField("at_bool"): if attr_value is None: attr_value = True assert type(attr_value) is bool flag_name2flag_value[attr_name].at_bool = attr_value elif default_val.HasField("at_int64"): assert type(attr_value) is int flag_name2flag_value[attr_name].at_int64 = attr_value elif default_val.HasField("at_double"): assert type(attr_value) is float flag_name2flag_value[attr_name].at_double = attr_value elif default_val.HasField("at_string"): assert type(attr_value) is str flag_name2flag_value[attr_name].at_string = attr_value else: raise NotImplementedError( "config_flag `%s' with type %s is not supported" % (attr_name, type(attr_value)) ) return FunctionConfigSetter def _CloneFunctionDesc(func_desc, job_func): new_func_desc = FunctionDesc(job_func=job_func) new_func_desc.job_config_proto.CopyFrom(func_desc.job_config_proto) new_func_desc.function_attribute = copy.deepcopy(func_desc.function_attribute) return new_func_desc def oneflow_function_config(*field_paths): def Decorator(func): global _class_property2return_obj_class for field_path in field_paths: fields = field_path.split(".") assert len(fields) > 0 cls = FunctionConfig for (index, field) in enumerate(fields): assert field != "function_desc" assert re.match("^[_\\w]+[_\\w\\d]*$", field) if (cls, field) not in _class_property2return_obj_class: class_name = ".".join(["function_config"] + fields[: index + 1]) def Init(self, function_desc): self.function_desc = function_desc config_class = type(class_name, (object,), dict(__init__=Init)) setattr(cls, field, _MakeInnerJobConfigClassProperty(config_class)) _class_property2return_obj_class[cls, field] = config_class cls = _class_property2return_obj_class[cls, field] cls.__call__ = _MakeLeafJobConfigCall(func) return func return Decorator _class_property2return_obj_class = {} def _MakeInnerJobConfigClassProperty(return_obj_class): return property(lambda self: return_obj_class(self.function_desc)) def _MakeLeafJobConfigCall(method): return lambda self, *argv, **kwarg: method(self.function_desc, *argv, **kwarg) @oneflow_function_config("default_data_type") def set_default_data_type(func_desc, value): """Set default data type for job Args: func_desc ([type]): job function value ([type]): data type. e.g. 
flow.float """ func_desc.job_config_proto.default_data_type = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype( value ) @oneflow_function_config("default_initializer_conf") def set_default_initializer_conf(func_desc, value): """Set default initial configuration for job Args: func_desc ([type]): [description] value ([type]): [description] """ assert type(value) is dict pb_util.PythonDict2PbMessage( value, func_desc.job_config_proto.default_initializer_conf ) @oneflow_function_config("exp_run_conf") def set_exp_run_conf(value): """Set experimental configuration for job Args: value ([type]): [description] """ assert type(func_desc, value) is dict pb_util.PythonDict2PbMessage(value, func_desc.job_config_proto.exp_run_conf) @oneflow_function_config("static_mem_alloc_policy_white_list.has") def static_mem_alloc_policy_white_list_has_policy(func_desc, policy): """Get items from white list related to static memory allocation policy Args: func_desc ([type]): [description] policy ([type]): [description] Returns: [type]: [description] """ return getattr(func_desc.job_config_proto.memory_allocation_algorithm_conf, policy) @oneflow_function_config("static_mem_alloc_policy_white_list.add") def static_mem_alloc_policy_white_list_add_policy(func_desc, policy): """Add item to white list related to static memory allocation policy Args: func_desc ([type]): [description] policy ([type]): [description] """ setattr(func_desc.job_config_proto.memory_allocation_algorithm_conf, policy, True) @oneflow_function_config("static_mem_alloc_policy_white_list.remove") def static_mem_alloc_policy_white_list_remove_policy(func_desc, policy): """Remove item of white list related to static memory allocation policy Args: func_desc ([type]): [description] policy ([type]): [description] """ setattr(func_desc.job_config_proto.memory_allocation_algorithm_conf, policy, False) @oneflow_function_config("static_mem_alloc_policy_white_list.policy_mem_size_first") def policy_mem_size_first(func_desc): """A static memory allocation policy called: mem_size_first Args: func_desc ([type]): [description] Returns: [type]: [description] """ return "use_mem_size_first_algo" @oneflow_function_config( "static_mem_alloc_policy_white_list.policy_mutual_exclusion_first" ) def policy_mutual_exclusion_first(func_desc): """A static memory allocation policy called: mutual_exclusion_first Args: func_desc ([type]): [description] Returns: [type]: [description] """ return "use_mutual_exclusion_first_algo" @oneflow_function_config("static_mem_alloc_policy_white_list.policy_time_line") def policy_time_line(func_desc): """A static memory allocation policy called: time_line Args: func_desc ([type]): [description] Returns: [type]: [description] """ return "use_time_line_algo" @oneflow_function_config("static_mem_alloc_algo_white_list.show") def show_static_mem_alloc_algo_white_list(func_desc): """Show configuration of static memory allocation policy, including: "use_mem_size_first_algo", "use_mutual_exclusion_first_algo", "use_time_line_algo" Args: func_desc ([type]): [description] Returns: [type]: [description] """ return [ "use_mem_size_first_algo", "use_mutual_exclusion_first_algo", "use_time_line_algo", ] @oneflow_function_config("enable_cudnn") def set_enable_cudnn(func_desc, value=True): """Whether use cudnn to accelerate job or not. Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. 
""" func_desc.job_config_proto.enable_cudnn = value @oneflow_function_config("cudnn_buf_limit_mbyte") def set_cudnn_buf_limit_mbyte(func_desc, value): """Set cudnn buffer limit, e.g. 1024mb Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.cudnn_buf_limit_mbyte = value @oneflow_function_config("cudnn_conv_force_fwd_algo") def set_cudnn_conv_force_fwd_algo(func_desc, value): """Set value to cudnn conv_force_forward algorithm Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.cudnn_conv_force_fwd_algo = value @oneflow_function_config("cudnn_conv_force_bwd_data_algo") def set_cudnn_conv_force_bwd_data_algo(func_desc, value): """Set value to cudnn conv_force_backward_data algorithm Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.cudnn_conv_force_bwd_data_algo = value @oneflow_function_config("cudnn_conv_force_bwd_filter_algo") def set_cudnn_conv_force_bwd_filter_algo(func_desc, value): """Set value to cudnn conv_force_backward_filter algorithm Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.cudnn_conv_force_bwd_filter_algo = value @oneflow_function_config("cudnn_conv_heuristic_search_algo") def set_cudnn_conv_heuristic_search_algo(func_desc, value): """Set value to cudnn conv_heuristic_search algorithm Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.cudnn_conv_heuristic_search_algo = value @oneflow_function_config("enable_cudnn_fused_normalization_add_relu") def set_enable_cudnn_fused_normalization_add_relu(func_desc, value): """Whether enable cudnn_fused_normalization_add_relu. Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.enable_cudnn_fused_normalization_add_relu = value @oneflow_function_config("enable_fuse_add_to_output") def set_enable_fuse_add_to_output(func_desc, value): """Whether enable fuse_add_to_output. If enabled, try to fuse a binary element-wise add to one of the predecessors to improve performance. Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.enable_fuse_add_to_output = value @oneflow_function_config("enable_fuse_cast_scale") def set_enable_fuse_cast_scale(func_desc, value=True): """Whether enable fuse_cast_scale. If enabled, try to fuse cast and scalar_mul_by_tensor to improve performance. Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.enable_fuse_cast_scale = value @oneflow_function_config("cudnn_conv_use_deterministic_algo_only") def set_cudnn_conv_use_deterministic_algo_only(func_desc, value): """Set value to cudnn conv_use_deterministic_only algorithm Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.cudnn_conv_use_deterministic_algo_only = value @oneflow_function_config("enable_reuse_mem") def set_enable_reused_mem(func_desc, value=True): """Whether enable reuse memory or not Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ func_desc.job_config_proto.enable_reuse_mem = value @oneflow_function_config("enable_inplace") def set_enable_inplace(func_desc, value=True): """Whether enable inplace or not Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. 
""" func_desc.job_config_proto.enable_inplace = value @oneflow_function_config("enable_inplace_in_reduce_struct") def set_enable_inplace_in_reduce_struct(func_desc, value=True): print( "'enable_inplace_in_reduce_struct' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("enable_nccl") def set_enable_nccl(func_desc, value=True): print( "'enable_nccl' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("use_nccl_inter_node_communication") def set_use_nccl_inter_node_communication(func_desc, value=True): print( "'use_nccl_inter_node_communication' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("use_boxing_v2") def set_use_boxing_v2(func_desc, value=True): print( "'use_boxing_v2' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("do_parallel_cast_before_widening_type_cast") def set_do_parallel_cast_before_widening_type_cast(func_desc, value=True): func_desc.job_config_proto.do_parallel_cast_before_widening_type_cast = value @oneflow_function_config("enable_all_reduce_group") def set_enable_all_reduce_group(func_desc, value=True): print( "'enable_all_reduce_group' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("all_reduce_group_num") def set_all_reduce_group_num(func_desc, value): print( "'all_reduce_group_num' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("all_reduce_lazy_ratio") def set_all_reduce_lazy_ratio(func_desc, value): print( "'all_reduce_lazy_ratio' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("all_reduce_group_min_mbyte") def set_all_reduce_group_min_mbyte(func_desc, value): print( "'all_reduce_group_min_mbyte' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("all_reduce_group_size_warmup") def set_all_reduce_group_size_warmup(func_desc, value): print( "'all_reduce_group_size_warmup' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("all_reduce_fp16") def set_all_reduce_fp16(func_desc, value=True): print( "'all_reduce_fp16' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config( "optimizer_placement_optimization_mode", "train.optimizer_placement_optimization_mode", ) def set_optimizer_placement_optimization_mode(func_desc, mode): """Enable optimizer_placement_optimization with mode 'mode' Args: func_desc ([type]): [description] mode (str): [description]. """ assert mode in ["non_distributed", "distributed_split"] func_desc.job_config_proto.optimizer_placement_optimization_mode = mode @oneflow_function_config( "optimizer_placement_optimization_threshold", "train.optimizer_placement_optimization_threshold", ) def set_optimizer_placement_optimization_threshold(func_desc, value): func_desc.job_config_proto.optimizer_placement_optimization_threshold = value @oneflow_function_config("enable_non_distributed_optimizer") def set_enable_non_distributed_optimizer(func_desc, value=True): """Whether enable non_distributed optimizer or not Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. 
""" if value: set_optimizer_placement_optimization_mode(func_desc, "non_distributed") @oneflow_function_config("disable_all_reduce_sequence") def set_disable_all_reduce_sequence(func_desc, value=True): print( "'disable_all_reduce_sequence' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config("prune_parallel_cast_ops") def set_prune_parallel_cast_ops(func_desc, value=True): """Whether prune parallel cast operations or not. Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ func_desc.job_config_proto.prune_parallel_cast_ops = value @oneflow_function_config("prune_cast_to_static_shape_ops") def set_prune_cast_to_static_shape_ops(func_desc, value=True): """Whether or not set prune_cast to static shape opretions Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ func_desc.job_config_proto.prune_cast_to_static_shape_ops = value @oneflow_function_config("prune_amp_white_identity_ops") def set_prune_amp_white_identity_ops(func_desc, value=True): """Whether prune amp_white_identity operations or not. Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ func_desc.job_config_proto.prune_amp_white_identity_ops = value @oneflow_function_config("prune_depend_ops") def set_prune_depend_ops(func_desc, value=True): """Whether prune depend operations or not. Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ func_desc.job_config_proto.prune_depend_ops = value @oneflow_function_config("non_distributed_optimizer_group_size_mbyte") def set_non_distributed_optimizer_group_size_mbyte(func_desc, value): print( "'non_distributed_optimizer_group_size_mbyte' has been deprecated, has no effect and will be removed in the future." ) @oneflow_function_config( "enable_true_half_config_when_conv", "cudnn_conv_enable_true_half" ) def set_cudnn_conv_enable_true_half(func_desc, value=True): """Whether use true_half mode or not during convolution calculation process while using cudnn. Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ func_desc.job_config_proto.cudnn_conv_enable_pseudo_half = not value @oneflow_function_config( "cudnn_conv_enable_pseudo_half", "enable_cudnn_conv_pseudo_half" ) def set_cudnn_conv_enable_pseudo_half(func_desc, value): """Whether enable pseudo_half mode or not during convolution calculation process while using cudnn Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.cudnn_conv_enable_pseudo_half = value @oneflow_function_config("enable_float_compute_for_half_gemm") def set_enable_float_compute_for_half_gemm(func_desc, value=True): """Whether enable float_compute or not , if True, means that the type of intermedia value is float when compute half gemm. Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ print( "WARNING: enable_float_compute_for_half_gemm has been deprecated, because we always use float compute for half gemm. Please remove it.\n " ) print(traceback.format_stack()[-3]) @oneflow_function_config("enable_quantization_aware_training") @oneflow_function_config("enable_qat") def set_enable_quantization_aware_training(func_desc, value=True): """If true, then job will use quantization aware training Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. 
""" func_desc.job_config_proto.enable_quantization_aware_training = value @oneflow_function_config("qat.per_channel_weight_quantization") def set_qat_per_channel(func_desc, value=True): func_desc.job_config_proto.qat_config.per_channel_weight_quantization = value @oneflow_function_config("qat.symmetric") def set_qat_symmetric(func_desc, value=True): func_desc.job_config_proto.qat_config.symmetric = value @oneflow_function_config("qat.moving_min_max_momentum") def set_qat_moving_min_max_momentum(func_desc, value: float): func_desc.job_config_proto.qat_config.moving_min_max_momentum = value @oneflow_function_config("qat.moving_min_max_stop_update_after_iters") def set_qat_moving_min_max_momentum(func_desc, value: float): func_desc.job_config_proto.qat_config.moving_min_max_stop_update_after_iters = value @oneflow_function_config("qat.target_backend") def set_qat_symmetric(func_desc, value: str): func_desc.job_config_proto.qat_config.target_backend = value @oneflow_function_config("enable_auto_mixed_precision") def set_enable_auto_mixed_precision(func_desc, value=True): """If true, then job will use mixed precision mode, it means use both float16 and float32 during model training. Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ func_desc.job_config_proto.enable_auto_mixed_precision = value @oneflow_function_config("enable_keep_header_only") def set_enable_keep_header_only(func_desc, value=True): """deprecated api. Args: func_desc ([type]): [description] value (bool, optional): [description]. Defaults to True. """ print("Sorry! enable_keep_header_only is deprecated and it doesn't work.\n") @oneflow_function_config("concurrency_width") def set_concurrency_width(func_desc, value): """Set up concurrency width Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.concurrency_width = value @oneflow_function_config("train.model_update_conf") def set_model_update_conf(func_desc, value): """Set up optimizer and update method of learning rate for job Args: func_desc ([type]): [description] value ([type]): [description] """ print( "WARNING: func_config.train.* has been deprecated. Please replace it by the new optimizer api.\n " ) print(traceback.format_stack()[-3]) assert type(value) is dict func_desc.job_config_proto.train_conf.model_update_conf.SetInParent() pb_util.PythonDict2PbMessage( value, func_desc.job_config_proto.train_conf.model_update_conf ) @oneflow_function_config("indexed_slices_optimizer_conf") def set_indexed_slices_optimizer_conf(func_desc, value): """Set indexed slices configuration of optimizer Args: func_desc ([type]): [description] value ([type]): [description] """ assert type(value) is dict func_desc.job_config_proto.indexed_slices_optimizer_conf.SetInParent() pb_util.PythonDict2PbMessage( value, func_desc.job_config_proto.indexed_slices_optimizer_conf ) @oneflow_function_config("enable_fuse_model_update_ops") def set_enable_fuse_model_update_ops(func_desc, value=True): """Whether enable fuse_model_update_ops. If enabled, try to fuse cast + scale + l1_l2_regularize_gradient + model_update to one op to improve performance. Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.job_config_proto.enable_fuse_model_update_ops = value @oneflow_function_config("enable_gradients_stats_aggregation") def set_enable_gradients_stats_aggregation(func_desc, value=True): """Whether enable gradients_stats_aggregation. If enabled, gradients stats ops (norm, finite, ...) 
@oneflow_function_config("train.loss_scale_factor") def set_loss_scale_factor(func_desc, value): """Set scale factor for loss Args: func_desc ([type]): [description] value ([type]): [description] """ print( "WARNING: func_config.train.* has been deprecated. Please replace it by the new optimizer api.\n " ) print(traceback.format_stack()[-3]) func_desc.job_config_proto.train_conf.loss_scale_factor = value @oneflow_function_config("train.primary_lr") def set_primary_lr(func_desc, value): """Set the primary learning rate for job Args: func_desc ([type]): [description] value ([type]): [description] """ print( "WARNING: func_config.train.* has been deprecated. Please replace it by the new optimizer api.\n " ) print(traceback.format_stack()[-3]) func_desc.job_config_proto.train_conf.primary_lr = value @oneflow_function_config("train.secondary_lr") def set_secondary_lr(func_desc, value): """Set the secondary learning rate for job Args: func_desc ([type]): [description] value ([type]): [description] """ print( "WARNING: func_config.train.* has been deprecated. Please replace it by the new optimizer api.\n " ) print(traceback.format_stack()[-3]) func_desc.job_config_proto.train_conf.secondary_lr = value @oneflow_function_config("train.num_gradient_accumulation_steps") def set_num_gradient_accumulation_steps(func_desc, value): func_desc.job_config_proto.num_gradient_accumulation_steps = value @oneflow_function_config("default_logical_view") def set_default_distribute_strategy(func_desc, value): """Set up default distribute strategy for job Args: func_desc ([type]): [description] value ([type]): [description] """ assert isinstance(value, distribute_ctx.DistributeStrategy) func_desc.function_attribute.default_distribute_strategy = value @oneflow_function_config("allow_cpu_return_op") def allow_cpu_return_op(func_desc, value): """Whether to allow operations returned from cpu or not Args: func_desc ([type]): [description] value ([type]): [description] """ func_desc.function_attribute.allow_cpu_return_op = value @oneflow_function_config("default_distribute_strategy") @oneflow_deprecate() def deprecated_set_default_distribute_strategy(*args, **kwargs): print( "WARNING:", "function_config.default_distribute_strategy", "has been deprecated. Please use {} instead.".format( "function_config.default_logical_view" ), ) print(traceback.format_stack()[-3], file=sys.stderr) set_default_distribute_strategy(*args, **kwargs) ================================================ FILE: python/oneflow/framework/generator.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """
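# --- Illustrative usage sketch (re-export points are assumed from the function names below) ---
# The helpers in this module back oneflow's top-level seeding API, mirroring the
# corresponding torch.random functions. Typical usage is expected to look roughly like:
#
#     import oneflow as flow
#     flow.manual_seed(0)                 # seed the default generators
#     state = flow.get_rng_state()        # CPU generator state as a ByteTensor
#     flow.set_rng_state(state)           # restore it later
#
# The exact names under which these functions are re-exported (flow.manual_seed,
# flow.get_rng_state, ...) are assumptions based on the definitions in this file.
# -----------------------------------------------------------------------------------------------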
""" import oneflow import oneflow._oneflow_internal def create_generator(device=None): if device is None: device = "auto" return oneflow._oneflow_internal.create_generator(device) def seed() -> int: r""" Sets the seed for generating random numbers to a non-deterministic random number. Returns a 64 bit number used to seed the RNG. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.seed.html. """ seed = default_generator.seed() oneflow._oneflow_internal.manual_seed(seed) return seed def manual_seed(seed): r""" Sets the seed for generating random numbers. Returns a `oneflow.Generator` object. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.manual_seed.html. Args: seed (int): The desired seed. Value must be within the inclusive range `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError is raised. Negative inputs are remapped to positive values with the formula `0xffff_ffff_ffff_ffff + seed`. """ seed = int(seed) return oneflow._oneflow_internal.manual_seed(seed) def initial_seed() -> int: r""" Returns the initial seed for generating random numbers as a Python `long`. The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/random.html. """ return default_generator.initial_seed() def _getstate(self): return {"device": str(self.device), "state": self.get_state()} def _setstate(self, state_dict): self.__init__(state_dict["device"]) self.set_state(state_dict["state"]) def get_rng_state(): r""" Sets the random number generator state. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.get_rng_state.html. .. note: This function only works for CPU. For CUDA, please use oneflow.manual_seed(seed), which works for both CPU and CUDA. Args: new_state (oneflow.ByteTensor): The desired state """ return oneflow.default_generator.get_state() def set_rng_state(state): """ Returns the random number generator state as a `oneflow.ByteTensor`. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.set_rng_state.html. """ return oneflow.default_generator.set_state(state) default_generator = oneflow._oneflow_internal.default_generator("cpu") oneflow._oneflow_internal.Generator.__getstate__ = _getstate oneflow._oneflow_internal.Generator.__setstate__ = _setstate ================================================ FILE: python/oneflow/framework/graph_build_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from contextlib import contextmanager import os from google.protobuf import text_format import oneflow import oneflow._oneflow_internal import oneflow.core.job.scope_pb2 as scope_pb2_util import oneflow.core.job.job_conf_pb2 as job_conf_pb import oneflow.framework.attr_util as attr_util import oneflow.framework.c_api_util as c_api_util import oneflow.framework.scope_util as scope_util import oneflow.framework.session_context as session_context from oneflow.framework.tensor import Tensor from oneflow.nn.graph.proxy import GraphBlockType import oneflow._oneflow_internal._C as _C lazy_mode = oneflow._oneflow_internal.lazy_mode @contextmanager def graph_build_context(config_proto, session): prev_scope = oneflow._oneflow_internal.GetCurrentScope() assert type(config_proto) is job_conf_pb.JobConfigProto, type(config_proto) config_proto_str = text_format.MessageToString(config_proto) new_scope = oneflow._oneflow_internal.MakeInitialScope( config_proto_str, oneflow.placement("cpu", [0]), False, # is_local ) graph_scope = _make_new_graph_scope(new_scope, config_proto.job_name) oneflow._oneflow_internal.eager.Sync() with lazy_mode.guard(True): with JobBuildAndInferCtx(config_proto): with BlockScopeContext(prev_scope, graph_scope): yield class JobBuildAndInferCtx(object): def __init__(self, config_proto): self._job_conf = config_proto def __enter__(self): c_api_util.JobBuildAndInferCtx_Open(self._job_conf.job_name) c_api_util.CurJobBuildAndInferCtx_SetJobConf(self._job_conf) def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: oneflow._oneflow_internal.JobBuildAndInferCtx_Close() return True else: oneflow._oneflow_internal.JobBuildAndInferCtx_Close() return False class BlockScopeContext(object): def __init__(self, prev_scope, new_scope): assert prev_scope is not None assert new_scope is not None self._prev_scope = prev_scope self._new_scope = new_scope def __enter__(self): assert oneflow._oneflow_internal.GetCurrentScope() is self._prev_scope oneflow._oneflow_internal.GlobalScopeStackPush(self._new_scope) def __exit__(self, exc_type, exc_val, exc_tb): assert oneflow._oneflow_internal.GetCurrentScope() is self._new_scope oneflow._oneflow_internal.GlobalScopeStackPop() assert oneflow._oneflow_internal.GetCurrentScope() is self._prev_scope if exc_type is None: return True else: return False class DebugScopeContext(object): def __init__( self, s_level, v_level=0, mode=False, max_py_stack_depth=2, only_user_py_stack=True, ): self._prev_v = oneflow._oneflow_internal.GetFLAGS_v() self._prev_logtostderr = oneflow._oneflow_internal.GetFLAGS_alsologtostderr() self._prev_mode = oneflow._oneflow_internal.GetGraphDebugMode() self._prev_max_py_stack_depth = ( oneflow._oneflow_internal.GetGraphDebugMaxPyStackDepth() ) self._prev_only_user_py_stack = ( oneflow._oneflow_internal.GetGraphDebugOnlyUserPyStack() ) self._v = max(v_level, self._prev_v) self._mode = mode self._s = s_level self._max_py_stack_depth = max( max_py_stack_depth, self._prev_max_py_stack_depth ) self._only_user_py_stack = only_user_py_stack def __enter__(self): oneflow._oneflow_internal.SetFLAGS_v(self._v) oneflow._oneflow_internal.SetGraphDebugMode(self._mode) if self._s == 0 and self._v >= 1: oneflow._oneflow_internal.SetFLAGS_alsologtostderr(True) oneflow._oneflow_internal.SetGraphDebugMaxPyStackDepth(self._max_py_stack_depth) oneflow._oneflow_internal.SetGraphDebugOnlyUserPyStack(self._only_user_py_stack) def __exit__(self, exc_type, exc_val, exc_tb): if self._s == 0 and self._v >= 1: 
oneflow._oneflow_internal.SetFLAGS_alsologtostderr(self._prev_logtostderr) oneflow._oneflow_internal.SetFLAGS_v(self._prev_v) oneflow._oneflow_internal.SetGraphDebugMode(self._prev_mode) oneflow._oneflow_internal.SetGraphDebugMaxPyStackDepth( self._prev_max_py_stack_depth ) oneflow._oneflow_internal.SetGraphDebugOnlyUserPyStack( self._prev_only_user_py_stack ) def _make_new_scope(prev_scope, scope_proto_str_setter): new_scope = None def build_scope(builder): nonlocal new_scope new_scope = builder.BuildScopeByProtoStrSetter( prev_scope, scope_proto_str_setter ) assert new_scope is not None oneflow._oneflow_internal.deprecated.PhysicalRun(build_scope) oneflow._oneflow_internal.eager.Sync() return new_scope def _make_new_graph_scope(prev_scope, graph_name): assert prev_scope is not None attr_dict = dict() name2default = session_context.GetDefaultSession().scope_attr_name2default_val def scope_proto_str_setter(serialized_scope_proto: str): scope_proto = text_format.Parse( serialized_scope_proto, scope_pb2_util.ScopeProto() ) scope_proto.module_name = graph_name return str(text_format.MessageToString(scope_proto)) return _make_new_scope(prev_scope, scope_proto_str_setter) def make_new_blockgraph_scope(prev_scope, graph_block): assert prev_scope is not None assert graph_block is not None attr_dict = dict() if graph_block.stage_id is not None: attr_dict["pipeline_stage_id_hint"] = graph_block.stage_id if graph_block.type == GraphBlockType.MODULE: if graph_block.activation_checkpointing is not None: attr_dict["checkpointing"] = graph_block.activation_checkpointing name2default = session_context.GetDefaultSession().scope_attr_name2default_val def scope_proto_str_setter(serialized_scope_proto: str): scope_proto = text_format.Parse( serialized_scope_proto, scope_pb2_util.ScopeProto() ) # set attr for attr_name, py_value in attr_dict.items(): assert attr_name in name2default attr_util.SetProtoAttrValue( scope_proto.attr_name2attr_value[attr_name], py_value, name2default[attr_name], ) # append name prefix scope_proto.ClearField("scope_op_name_prefixes") scope_proto.scope_op_name_prefixes.append( graph_block.name_prefix + graph_block.name ) # set module name if graph_block.type == GraphBlockType.MODULE: scope_proto.module_name = graph_block.name_prefix + graph_block.name return str(text_format.MessageToString(scope_proto)) return _make_new_scope(prev_scope, scope_proto_str_setter) def make_new_name_scope(prev_scope, name): assert prev_scope is not None def scope_proto_str_setter(serialized_scope_proto: str): scope_proto = text_format.Parse( serialized_scope_proto, scope_pb2_util.ScopeProto() ) # append name prefix scope_proto.ClearField("scope_op_name_prefixes") scope_proto.scope_op_name_prefixes.append(name) scope_proto.module_name = name return str(text_format.MessageToString(scope_proto)) return _make_new_scope(prev_scope, scope_proto_str_setter) def scope_to_proto(scope): return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto()) def build_graph_input_arg(op_name, arg): assert isinstance(arg, Tensor) input_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf() input_conf.in_0 = "in_0" # Set the default value, otherwise the parsing fails input_conf.out_0 = "out_0" input_conf_str = text_format.MessageToString(input_conf) input_op = oneflow._oneflow_internal.one.FeedInputOpExpr( op_name, input_conf_str, ["in_0"], ["out_0"] ) lazy_arg = _C.dispatch_feed_input(input_op, arg) return lazy_arg def build_graph_state(op_name, state_tensor, state_config): var_conf = 
oneflow.core.operator.op_conf_pb2.FeedVariableOpConf() var_conf.in_0 = "in_0" # Set the default value, otherwise the parsing fails var_conf.out_0 = "out_0" var_conf_str = text_format.MessageToString(var_conf) var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr( op_name, var_conf_str, ["in_0"], ["out_0"] ) l2 = 0.0 if state_config is not None: l2 = state_config.l2 elif state_tensor.requires_grad: l2 = 0.0 assert isinstance(state_tensor, Tensor) lazy_tensor = _C.dispatch_feed_variable(var_op, state_tensor, l2=l2) return lazy_tensor def build_graph_output(op_name, out): assert isinstance(out, Tensor) output_conf = oneflow.core.operator.op_conf_pb2.FetchOutputOpConf() output_conf.in_0 = "in_0" # Set the default value, otherwise the parsing fails output_conf.out_0 = "out_0" output_conf_str = text_format.MessageToString(output_conf) output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr( op_name, output_conf_str, ["in_0"], ["out_0"] ) fake_eager_out = _C.dispatch_fetch_output(output_op, out) return fake_eager_out ================================================ FILE: python/oneflow/framework/hob.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.support.high_order_bool import bool_functor """Example: @bool_functor("Current mode is %s" % rt_mode.GLOBAL_MODE) def in_global_mode(ctx): return rt_mode.CurrentMode() == rt_mode.GLOBAL_MODE """ ================================================ FILE: python/oneflow/framework/id_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal def UniqueStr(prefix): return oneflow._oneflow_internal.UniqueStr(prefix) ================================================ FILE: python/oneflow/framework/infer_compiler/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" try: import torch except ImportError: print("You should install torch also when use `oneflow.framework.infer_compiler`.") from .transform.custom_transform import register from .utils.patch_for_compiler import * from .with_fx_graph import fx_node_tranform from .with_fx_interpreter import OneFlowInterpreter from .with_oneflow_compile import compile_from_torch from .with_oneflow_backend import oneflow_backend ================================================ FILE: python/oneflow/framework/infer_compiler/import_tools/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """ Tools for importing modules and packages""" from .importer import LazyMocker, import_module_from_path ================================================ FILE: python/oneflow/framework/infer_compiler/import_tools/format_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import inspect from types import FunctionType from typing import Union class MockEntityNameFormatter: def __init__(self, prefix: str = "mock_", suffix: str = "_oflow"): self.prefix = prefix self.suffix = suffix def _format_pkg_name(self, pkg_name: str) -> str: if pkg_name.startswith(self.prefix) and pkg_name.endswith(self.suffix): return pkg_name return self.prefix + pkg_name + self.suffix def _reverse_pkg_name(self, pkg_name: str) -> str: assert pkg_name.startswith(self.prefix) and pkg_name.endswith( self.suffix ), f"Package name must start with {self.prefix} and end with {self.suffix}, but got {pkg_name}" return pkg_name[len(self.prefix) : -len(self.suffix)] def _format_full_class_name(self, obj: Union[str, type, FunctionType]): if isinstance(obj, type): obj = f"{obj.__module__}.{obj.__qualname__}" elif isinstance(obj, FunctionType): module = inspect.getmodule(obj) obj = f"{module.__name__}.{obj.__qualname__}" assert isinstance(obj, str), f"obj must be str, but got {type(obj)}" if "." in obj: pkg_name, cls_name = obj.split(".", 1) return f"{self._format_pkg_name(pkg_name)}.{cls_name}" else: return self._format_pkg_name(obj) def format(self, entity: Union[str, type, FunctionType]) -> str: return self._format_full_class_name(entity) def unformat(self, mock_entity_name: str) -> str: if "." 
in mock_entity_name: pkg_name, cls_name = mock_entity_name.split(".", 1) return f"{self._reverse_pkg_name(pkg_name)}.{cls_name}" else: # mock_entity_name is a pkg_name return self._reverse_pkg_name(mock_entity_name) ================================================ FILE: python/oneflow/framework/infer_compiler/import_tools/importer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import importlib import os import sys from pathlib import Path from types import FunctionType, ModuleType from typing import Optional, Union from oneflow.mock_torch import DynamicMockModule from .format_utils import MockEntityNameFormatter if sys.version_info < (3, 8): try: from importlib_metadata import requires except ImportError: import subprocess subprocess.check_call("pip install importlib_metadata", shell=True) subprocess.check_call("pip install packaging", shell=True) else: from importlib.metadata import requires __all__ = ["import_module_from_path", "LazyMocker", "is_need_mock"] def is_need_mock(cls) -> bool: assert isinstance(cls, (type, str)) main_pkg = cls.__module__.split(".")[0] try: pkgs = requires(main_pkg) except Exception as e: return True if pkgs: for pkg in pkgs: pkg = pkg.split(" ")[0] if pkg == "torch": return True return False return True def import_module_from_path(module_path: Union[str, Path]) -> ModuleType: if isinstance(module_path, Path): module_path = str(module_path) module_name = os.path.basename(module_path) if os.path.isfile(module_path): sp = os.path.splitext(module_path) module_name = sp[0] if os.path.isfile(module_path): module_spec = importlib.util.spec_from_file_location(module_name, module_path) module_dir = os.path.split(module_path)[0] else: module_spec = importlib.util.spec_from_file_location( module_name, os.path.join(module_path, "__init__.py") ) module_dir = module_path module = importlib.util.module_from_spec(module_spec) sys.modules[module_name] = module module_spec.loader.exec_module(module) return module class LazyMocker: def __init__(self, prefix: str, suffix: str, tmp_dir: Optional[Union[str, Path]]): self.prefix = prefix self.suffix = suffix self.tmp_dir = tmp_dir self.mocked_packages = set() self.cleanup_list = [] def mock_package(self, package: str): pass def cleanup(self): pass def get_mock_entity_name(self, entity: Union[str, type, FunctionType]): formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix) full_obj_name = formatter.format(entity) return full_obj_name def mock_entity(self, entity: Union[str, type, FunctionType]): """Mock the entity and return the mocked entity Example: >>> mocker = LazyMocker(prefix="mock_", suffix="_of", tmp_dir="tmp") >>> mocker.mock_entity("models.DemoModel") >>> cls_obj = models.DemoModel >>> mocker.mock_entity(cls_obj) """ return self.load_entity_with_mock(entity) def add_mocked_package(self, package: str): if package in self.mocked_packages: return self.mocked_packages.add(package) package = sys.modules.get(package, None) # TODO 
remove code below # fix the mock error in https://github.com/siliconflow/oneflow/blob/main/python/oneflow/mock_torch/mock_importer.py#L105-L118 if package and getattr(package, "__file__", None) is not None: pkg_path = Path(package.__file__).parents[1] if pkg_path not in sys.path: sys.path.append(str(pkg_path)) def load_entity_with_mock(self, entity: Union[str, type, FunctionType]): formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix) full_obj_name = formatter.format(entity) attrs = full_obj_name.split(".") # add package path to sys.path to avoid mock error self.add_mocked_package(attrs[0]) mock_pkg = DynamicMockModule.from_package(attrs[0], verbose=False) for name in attrs[1:]: mock_pkg = getattr(mock_pkg, name) return mock_pkg ================================================ FILE: python/oneflow/framework/infer_compiler/transform/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """Module to convert PyTorch code to OneFlow.""" from .builtin_transform import ( ProxySubmodule, default_converter, get_attr, map_args, proxy_class, torch2oflow, ) from .custom_transform import register from .manager import transform_mgr ================================================ FILE: python/oneflow/framework/infer_compiler/transform/builtin_transform.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" """Convert torch object to oneflow object.""" import importlib import os import types from collections import OrderedDict from collections.abc import Iterable from functools import partial, singledispatch from typing import Any, Union import oneflow as flow import torch from oneflow.framework.infer_compiler.import_tools.importer import is_need_mock from oneflow.framework.infer_compiler.utils.log_utils import logger from oneflow.framework.infer_compiler.utils.patch_for_diffusers import diffusers_checker from .manager import transform_mgr __all__ = [ "proxy_class", "ProxySubmodule", "map_args", "get_attr", "torch2oflow", "default_converter", ] def singledispatch_proxy(func): dispatcher = singledispatch(func) _warning_set = set() def wrapper(first_param, *args, **kwargs): nonlocal _warning_set before = first_param.__class__.__name__ result = dispatcher(first_param, *args, **kwargs) after = result.__class__.__name__ description = f"{before} transformed to {after}" if before not in after and description not in _warning_set: _warning_set.add(description) logger.info(f"instance_name: {description}") return result wrapper.register = dispatcher.register wrapper.dispatch = dispatcher.dispatch return wrapper def proxy_class(cls: type): if cls.__module__.startswith("torch."): mod_name = cls.__module__.replace("torch.", "oneflow.") mod = importlib.import_module(mod_name) return getattr(mod, cls.__name__) full_qualified_name = cls.__module__ + "." + cls.__qualname__ result = transform_mgr.transform_cls(full_qualified_name) return result class ProxySubmodule: def __init__(self, submod): self._oflow_proxy_submod = submod self._oflow_proxy_parameters = {} self._oflow_proxy_children = {} def __getitem__(self, index): # __getitem__ if isinstance(self._oflow_proxy_submod, Iterable): submod = self._oflow_proxy_submod[index] return torch2oflow(submod) raise RuntimeError(f"can't getitem for: {type(self._oflow_proxy_submod)}") def __repr__(self) -> str: return " oflow_proxy: " + self._oflow_proxy_submod.__repr__() def __getattribute__(self, attribute): if attribute.startswith("_oflow_proxy"): return object.__getattribute__(self, attribute) elif attribute in ["forward", "_conv_forward"]: replacement = proxy_class(type(self._oflow_proxy_submod)) return lambda *args, **kwargs: getattr(replacement, attribute)( self, *args, **kwargs ) elif ( diffusers_checker.is_attention_instance(self._oflow_proxy_submod) and attribute == "get_attention_scores" ): replacement = proxy_class(type(self._oflow_proxy_submod)) return lambda *args, **kwargs: getattr(replacement, attribute)( self, *args, **kwargs ) elif ( isinstance(self._oflow_proxy_submod, torch.nn.Linear) and attribute == "use_fused_matmul_bias" ): return ( self.bias is not None and os.getenv("ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR") == "1" ) elif ( isinstance(self._oflow_proxy_submod, torch.nn.Dropout) and attribute == "generator" ): return flow.Generator() elif ( isinstance(self._oflow_proxy_submod, (torch.nn.Conv2d, torch.nn.Conv3d)) and attribute == "channel_pos" ): return "channels_first" else: a = getattr(self._oflow_proxy_submod, attribute) if isinstance(a, (torch.nn.parameter.Parameter, torch.Tensor)): # TODO(oneflow): assert a.requires_grad == False if attribute not in self._oflow_proxy_parameters: a = torch2oflow(a) self._oflow_proxy_parameters[attribute] = a else: a = self._oflow_proxy_parameters[attribute] elif isinstance( a, (torch.nn.Module, torch.nn.ModuleList, torch.nn.Sequential) ): if attribute not in self._oflow_proxy_children: a = torch2oflow(a) 
self._oflow_proxy_children[attribute] = a else: a = self._oflow_proxy_children[attribute] return a def __call__(self, *args: Any, **kwargs: Any) -> Any: replacement = proxy_class(type(self._oflow_proxy_submod)) if replacement is not None: return replacement.__call__(self, *args, **kwargs) else: raise RuntimeError( "can't find oneflow module for: " + str(type(self._oflow_proxy_submod)) ) @singledispatch_proxy def torch2oflow(mod, *args, **kwargs): return default_converter(mod, *args, **kwargs) def default_converter(obj, verbose=False, *, proxy_cls=None): if not is_need_mock(type(obj)): return obj try: new_obj_cls = proxy_class(type(obj)) if proxy_cls is None else proxy_cls def init(self): for k, _ in obj.__dict__.items(): attr = getattr(obj, k) self.__dict__[k] = torch2oflow(attr) of_obj_cls = type(str(new_obj_cls), (new_obj_cls,), {"__init__": init}) of_obj = of_obj_cls() if verbose: logger.info(f"convert {type(obj)} to {type(of_obj)}") return of_obj except Exception as e: logger.warning(f"Unsupported type: {type(obj)} {e}") # raise NotImplementedError(f"Unsupported type: {obj}") return obj @torch2oflow.register def _(mod: torch.nn.Module, verbose=False): proxy_md = ProxySubmodule(mod) new_md_cls = proxy_class(type(mod)) def init(self): nonlocal proxy_md flow.nn.Module.__init__(self) self._parameters = OrderedDict() self._buffers = OrderedDict() self._modules = OrderedDict() for n, p in list(proxy_md.named_parameters("", False)): self._parameters[n] = torch2oflow(p) for n, b in list(proxy_md.named_buffers("", False)): self._buffers[n] = flow.utils.tensor.from_torch(b.data) for n, m in proxy_md._modules.items(): self._modules[n] = torch2oflow(m) for k, _ in proxy_md.__dict__.items(): if k not in self.__dict__: attr = getattr(proxy_md, k) try: self.__dict__[k] = torch2oflow(attr) except Exception as e: logger.error(f"convert {type(attr)} failed: {e}") raise NotImplementedError(f"Unsupported type: {type(attr)}") def proxy_getattr(self, attr): nonlocal proxy_md try: return super().__getattribute__(attr) except: if attr in self._modules: return self._modules[attr] if attr in self._parameters: return self._parameters[attr] elif attr in self._buffers: return self._buffers[attr] else: return getattr(proxy_md, attr) of_mod_cls = type( str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr} ) of_mod = of_mod_cls() if of_mod.training: of_mod.training = False if verbose: logger.info( f""" Warning: {type(of_mod)} is in training mode and is turned into eval mode, which is good for inference optimization.
""" ) if verbose: logger.info(f"convert {type(mod)} to {type(of_mod)}") return of_mod @torch2oflow.register def _(mod: torch.nn.BatchNorm1d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) of_mod.channel_axis = 1 return of_mod @torch2oflow.register def _(mod: torch.nn.BatchNorm2d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) if os.getenv("ONEFLOW_ENABLE_NHWC"): of_mod.channel_axis = 3 else: of_mod.channel_axis = 1 return of_mod @torch2oflow.register def _(mod: torch.nn.BatchNorm3d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) of_mod.channel_axis = 1 return of_mod @torch2oflow.register def _(mod: torch.nn.MaxPool1d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) of_mod.channel_pos = "channels_first" return of_mod @torch2oflow.register def _(mod: torch.nn.MaxPool2d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) if os.getenv("ONEFLOW_ENABLE_NHWC"): of_mod.channel_pos = "channels_last" else: of_mod.channel_pos = "channels_first" return of_mod @torch2oflow.register def _(mod: torch.nn.MaxPool3d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) of_mod.channel_pos = "channels_first" return of_mod @torch2oflow.register def _(mod: torch.nn.AvgPool1d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) of_mod.channel_pos = "channels_first" return of_mod @torch2oflow.register def _(mod: torch.nn.AvgPool2d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) if os.getenv("ONEFLOW_ENABLE_NHWC"): of_mod.channel_pos = "channels_last" else: of_mod.channel_pos = "channels_first" return of_mod @torch2oflow.register def _(mod: torch.nn.AvgPool3d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) of_mod.channel_pos = "channels_first" return of_mod @torch2oflow.register def _(mod: torch.nn.AdaptiveAvgPool2d, verbose=False): of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose) if os.getenv("ONEFLOW_ENABLE_NHWC"): of_mod.channel_pos = "channels_last" else: of_mod.channel_pos = "channels_first" return of_mod try: from torchvision.ops import Conv2dNormActivation @torch2oflow.register def _(mod: Conv2dNormActivation, verbose=False): return flow.nn.Sequential(*[torch2oflow(layer) for layer in mod]) except ImportError: logger.warning("Failed to import torchvision") @torch2oflow.register def _(mod: torch.nn.ModuleList, verbose=False): of_mod_list = flow.nn.ModuleList() for original_submod in mod: submod = torch2oflow(original_submod, verbose) of_mod_list.append(submod) return of_mod_list @torch2oflow.register def _(mod: torch.nn.Sequential, verbose=False): of_mod_list = [] for original_submod in mod: submod = torch2oflow(original_submod, verbose) of_mod_list.append(submod) of_mod_seq = proxy_class(type(mod))(*of_mod_list) return of_mod_seq @torch2oflow.register def _(mod: torch.nn.parameter.Parameter, verbose=False) -> flow.nn.Parameter: data = flow.utils.tensor.from_torch(mod.data) if mod.data.dtype == torch.int8: mod.requires_grad_(False) return flow.nn.Parameter(data.to(flow.int8), requires_grad=False) return flow.nn.Parameter(data, requires_grad=mod.requires_grad) @torch2oflow.register def _(mod: torch.Tensor, verbose=False) -> flow.Tensor: return flow.utils.tensor.from_torch(mod) _dtype_map = { "torch.float16": flow.float16, "torch.float32": flow.float32, "torch.double": flow.double, "torch.int8": flow.int8, "torch.int32": flow.int32, "torch.int64": 
flow.int64, "torch.uint8": flow.uint8, } @torch2oflow.register def _(mod: torch.dtype, verbose=False) -> flow.dtype: return _dtype_map[str(mod)] @torch2oflow.register def _(mod: list, verbose=False) -> list: return [torch2oflow(m, verbose) for m in mod] @torch2oflow.register def _(mod: tuple, verbose=False) -> tuple: return tuple(torch2oflow(m, verbose) for m in mod) @torch2oflow.register def _(mod: OrderedDict, verbose=False) -> OrderedDict: if "OrderedDict" not in f"{mod}": return default_converter(mod, verbose) else: return default_converter(mod, verbose, proxy_cls=OrderedDict) @torch2oflow.register def _(mod: set, verbose=False) -> set: return set(torch2oflow(m, verbose) for m in mod) @torch2oflow.register(int) @torch2oflow.register(float) @torch2oflow.register(str) @torch2oflow.register(bool) def _(mod, verbose=False) -> Union[int, float, str, bool]: return mod @torch2oflow.register def _(mod: None, verbose=False): return mod @torch2oflow.register def _(mod: types.BuiltinFunctionType, verbose=False): if hasattr(mod, "__module__"): mod_name = None if mod.__module__.startswith("torch._C._nn"): # The equivalence of mod inside torch._C._nn may be # defined in flow.nn.functional if getattr(flow.nn.functional, mod.__name__): mod_name = "oneflow.nn.functional" else: mod_name = mod.__module__.replace( "torch._C._nn", "oneflow._oneflow_internal._C" ) elif mod.__module__.startswith("torch"): try: if getattr(torch.nn.functional, mod.__name__) == mod: mod_name = "oneflow.nn.functional" except: mod_name = mod.__module__.replace("torch", "oneflow") if mod_name is not None: m = importlib.import_module(mod_name) return getattr(m, mod.__name__) return default_converter(mod, verbose) @torch2oflow.register def _(mod: torch.device, verbose=False): index = mod.index if mod.index is not None else 0 return flow.device(mod.type, index) @torch2oflow.register def _(mod: dict, verbose=False) -> dict: return {torch2oflow(k): torch2oflow(v, verbose) for k, v in mod.items()} @torch2oflow.register def _(func: types.FunctionType, verbose=False): return transform_mgr.transform_func(func) @torch2oflow.register def _(mod: partial, verbose=False): # https://docs.python.org/3/library/functools.html?highlight=partial#functools.partial func = torch2oflow(mod.func) args = torch2oflow(mod.args) keywords = torch2oflow(mod.keywords) return partial(func, *args, **keywords) ############################################## Code For Onefx ############################################## def map_args(args, kwargs): args = [torch2oflow(a) for a in args] kwargs = dict((k, torch2oflow(v)) for (k, v) in kwargs.items()) return (args, kwargs) def get_attr(gm, node, torch2flow={}): attr = getattr(gm, node.target) if attr in torch2flow: return torch2flow[attr] of_attr = torch2oflow(attr) torch2flow[attr] = of_attr return of_attr ================================================ FILE: python/oneflow/framework/infer_compiler/transform/custom_transform.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ """A module for registering custom torch2oflow functions and classes.""" import inspect from pathlib import Path from typing import Callable, Dict, List, Optional, Union from oneflow.framework.infer_compiler.utils.log_utils import logger from .builtin_transform import torch2oflow from .manager import transform_mgr __all__ = ["register"] def register_torch2oflow_class(cls: type, replacement: type, verbose=True): try: key = transform_mgr.get_transformed_entity_name(cls) transform_mgr.update_class_proxies({key: replacement}, verbose=verbose) except Exception as e: logger.warning(f"Cannot register {cls} {replacement}. {e}") def register_torch2oflow_func(func, first_param_type=None, verbose=False): if first_param_type is None: params = inspect.signature(func).parameters first_param_type = params[list(params.keys())[0]].annotation if first_param_type == inspect._empty: logger.warning(f"Cannot register {func} {first_param_type}.") try: torch2oflow.register(first_param_type)(func) logger.debug(f"Register {func} {first_param_type}") if verbose: logger.info(f"Register {func} {first_param_type}") except Exception as e: logger.warning(f"Cannot register {func} {first_param_type}. {e}") def ensure_list(obj): if isinstance(obj, list): return obj return [obj] def register( *, package_names: Optional[List[Union[Path, str]]] = None, torch2oflow_class_map: Optional[Dict[type, type]] = None, torch2oflow_funcs: Optional[List[Callable]] = None, ): if package_names: package_names = ensure_list(package_names) transform_mgr.load_class_proxies_from_packages(package_names) if torch2oflow_class_map: for torch_cls, of_cls in torch2oflow_class_map.items(): register_torch2oflow_class(torch_cls, of_cls) if torch2oflow_funcs: torch2oflow_funcs = ensure_list(torch2oflow_funcs) for func in torch2oflow_funcs: register_torch2oflow_func(func) ================================================ FILE: python/oneflow/framework/infer_compiler/transform/manager.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import logging import os import types from pathlib import Path from typing import Dict, List, Union from oneflow.framework.infer_compiler.import_tools.importer import LazyMocker from oneflow.framework.infer_compiler.utils.log_utils import logger __all__ = ["transform_mgr"] class TransformManager: """TransformManager __init__ args: `debug_mode`: Whether to print debug info. `tmp_dir`: The temp dir to store mock files. 
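
    Example (an illustrative sketch; the mapping below mirrors the
    `update_class_proxies` docstring example and is not a complete setup):

        import oneflow as flow
        from oneflow.framework.infer_compiler.transform.manager import transform_mgr

        # Register an extra torch -> oneflow class proxy; transform_cls will
        # then resolve the mocked torch class name to flow.nn.Conv2d.
        transform_mgr.update_class_proxies({"mock_torch.nn.Conv2d": flow.nn.Conv2d})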
""" def __init__(self, debug_mode=False, tmp_dir="./output"): self.debug_mode = debug_mode self._torch_to_oflow_cls_map = {} self._setup_logger() self.mocker = LazyMocker(prefix="", suffix="", tmp_dir=None) def _setup_logger(self): name = "ONEDIFF" level = logging.DEBUG if self.debug_mode else logging.ERROR logger.configure_logging(name=name, file_name=None, level=level, log_dir=None) self.logger = logger def get_mocked_packages(self): return self.mocker.mocked_packages def load_class_proxies_from_packages(self, package_names: List[Union[Path, str]]): self.logger.debug(f"Loading modules: {package_names}") for package_name in package_names: self.mocker.mock_package(package_name) self.logger.info(f"Loaded Mock Torch Package: {package_name} successfully") def update_class_proxies(self, class_proxy_dict: Dict[str, type], verbose=True): """Update `_torch_to_oflow_cls_map` with `class_proxy_dict`. example: `class_proxy_dict = {"mock_torch.nn.Conv2d": flow.nn.Conv2d}` """ self._torch_to_oflow_cls_map.update(class_proxy_dict) debug_message = f"Updated class proxies: {len(class_proxy_dict)}" debug_message += f"\n{class_proxy_dict}\n" self.logger.debug(debug_message) def _transform_entity(self, entity): result = self.mocker.mock_entity(entity) if result is None: RuntimeError(f"Failed to transform entity: {entity}") return result def get_transformed_entity_name(self, entity): return self.mocker.get_mock_entity_name(entity) def transform_cls(self, full_cls_name: str): """Transform a class name to a mock class .""" mock_full_cls_name = self.get_transformed_entity_name(full_cls_name) if mock_full_cls_name in self._torch_to_oflow_cls_map: use_value = self._torch_to_oflow_cls_map[mock_full_cls_name] return use_value mock_cls = self._transform_entity(mock_full_cls_name) self._torch_to_oflow_cls_map[mock_full_cls_name] = mock_cls return mock_cls def transform_func(self, func: types.FunctionType): # TODO: support transform function cache return self._transform_entity(func) def transform_package(self, package_name): return self._transform_entity(package_name) debug_mode = os.getenv("ONEDIFF_DEBUG", "0") == "1" transform_mgr = TransformManager(debug_mode=debug_mode, tmp_dir=None) try: import pydantic if pydantic.VERSION < "2.5.2": logger.warning( f"Pydantic version {pydantic.VERSION} is too low, please upgrade to 2.5.2 or higher." ) from oneflow.mock_torch.mock_utils import MockEnableDisableMixin MockEnableDisableMixin.hazard_list.append( "huggingface_hub.inference._text_generation" ) except ImportError: pass ================================================ FILE: python/oneflow/framework/infer_compiler/utils/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled ================================================ FILE: python/oneflow/framework/infer_compiler/utils/args_tree_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow import torch from oneflow.framework.args_tree import ArgsTree def input_output_processor(func): def process_input(*args, **kwargs): def input_fn(value): if isinstance(value, torch.Tensor): # TODO: https://github.com/siliconflow/sd-team/issues/109 return flow.utils.tensor.from_torch(value.contiguous()) else: return value args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) out = args_tree.map_leaf(input_fn) mapped_args = out[0] mapped_kwargs = out[1] return mapped_args, mapped_kwargs def process_output(output): def output_fn(value): if isinstance(value, flow.Tensor): return flow.utils.tensor.to_torch(value) else: return value out_tree = ArgsTree((output, None), False) out = out_tree.map_leaf(output_fn) return out[0] def wrapper(cls, *args, **kwargs): mapped_args, mapped_kwargs = process_input(*args, **kwargs) output = func(cls, *mapped_args, **mapped_kwargs) return process_output(output) return wrapper ================================================ FILE: python/oneflow/framework/infer_compiler/utils/cost_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import inspect import time from functools import wraps import oneflow as flow from .log_utils import logger class cost_cnt: def __init__(self, debug=False, message="\t"): self.debug = debug self.message = message def __enter__(self): if not self.debug: return flow._oneflow_internal.eager.Sync() before_used = flow._oneflow_internal.GetCUDAMemoryUsed() before_host_used = flow._oneflow_internal.GetCPUMemoryUsed() logger.debug(f"====> {self.message} try to run...") logger.debug(f"{self.message} cuda mem before {before_used} MB") logger.debug(f"{self.message} host mem before {before_host_used} MB") self.before_used = before_used self.before_host_used = before_host_used self.start_time = time.time() def __exit__(self, exc_type, exc_val, exc_tb): if not self.debug: return flow._oneflow_internal.eager.Sync() end_time = time.time() after_used = flow._oneflow_internal.GetCUDAMemoryUsed() after_host_used = flow._oneflow_internal.GetCPUMemoryUsed() logger.debug(f"{self.message} run time {end_time - self.start_time} seconds") logger.debug(f"{self.message} cuda mem after {after_used} MB") logger.debug(f"{self.message} cuda mem diff {after_used - self.before_used} MB") logger.debug(f"{self.message} host mem after {after_host_used} MB") logger.debug( f"{self.message} host mem diff {after_host_used - self.before_host_used} MB" ) logger.debug(f"<==== {self.message} finish run.") def __call__(self, func): @wraps(func) def clocked(*args, **kwargs): if not self.debug: return func(*args, **kwargs) module = inspect.getmodule(func) logger.debug( f"==> function {module.__name__}.{func.__name__} try to run..." ) flow._oneflow_internal.eager.Sync() before_used = flow._oneflow_internal.GetCUDAMemoryUsed() logger.debug(f"{func.__name__} cuda mem before {before_used} MB") before_host_used = flow._oneflow_internal.GetCPUMemoryUsed() logger.debug(f"{func.__name__} host mem before {before_host_used} MB") start_time = time.time() out = func(*args, **kwargs) flow._oneflow_internal.eager.Sync() end_time = time.time() logger.debug(f"{func.__name__} run time {end_time - start_time} seconds") after_used = flow._oneflow_internal.GetCUDAMemoryUsed() logger.debug(f"{func.__name__} cuda mem after {after_used} MB") logger.debug(f"{func.__name__} cuda mem diff {after_used - before_used} MB") after_host_used = flow._oneflow_internal.GetCPUMemoryUsed() logger.debug(f"{func.__name__} host mem after {after_host_used} MB") logger.debug( f"{func.__name__} host mem diff {after_host_used - before_host_used} MB" ) logger.debug(f"<== function {func.__name__} finish run.") logger.debug("") return out return clocked ================================================ FILE: python/oneflow/framework/infer_compiler/utils/log_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import logging import os import time from pathlib import Path class ColorFormatter(logging.Formatter): COLORS = { "DEBUG": "\033[34m", # Blue "INFO": "\033[92m", # green "WARNING": "\033[93m", # Yellow "ERROR": "\033[91m", # Red "CRITICAL": "\033[91m", # Red } def format(self, record): log_message = super().format(record) color = self.COLORS.get(record.levelname, "\033[0m") # Default to Reset color return f"{color}{log_message}\033[0m" class ConfigurableLogger: def __init__(self) -> None: self.logger = logging.getLogger(__name__) def __getattr__(self, name): return getattr(self.logger, name) def configure_logging(self, name, level, log_dir=None, file_name=None): logger = logging.getLogger(name) if logger.hasHandlers(): logger.warning("Logging handlers already exist for %s", name) return logger.setLevel(level) # Create a console formatter and add it to a console handler console_formatter = ColorFormatter( fmt="%(levelname)s [%(asctime)s] - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) console_handler = logging.StreamHandler() console_handler.setFormatter(console_formatter) logger.addHandler(console_handler) # Create a file formatter and add it to a file handler if log_dir is provided if log_dir: log_dir = Path(log_dir) os.makedirs(log_dir, exist_ok=True) file_prefix = "{}_".format( time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) ) if file_name: log_file_name = file_prefix + file_name else: log_file_name = file_prefix + name + ".log" log_file = log_dir / log_file_name file_formatter = logging.Formatter( fmt="%(levelname)s [%(asctime)s] - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) file_handler = logging.FileHandler(log_file, encoding="utf-8") file_handler.setFormatter(file_formatter) logger.addHandler(file_handler) self.logger = logger logger = ConfigurableLogger() ================================================ FILE: python/oneflow/framework/infer_compiler/utils/oneflow_exec_mode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow _ONEFLOW_EXEC_MODE = False class oneflow_exec_mode(object): def __init__(self, enabled=None): if enabled is not None: self.enabled = enabled else: self.enabled = True def __enter__(self): global _ONEFLOW_EXEC_MODE self.prev_mode = _ONEFLOW_EXEC_MODE _ONEFLOW_EXEC_MODE = self.enabled self.prev_grad_mode = flow.is_grad_enabled() _ = flow.set_grad_enabled(False) def __exit__(self, exc_type, exc_val, exc_tb): global _ONEFLOW_EXEC_MODE _ONEFLOW_EXEC_MODE = self.prev_mode _ = flow.set_grad_enabled(self.prev_grad_mode) def oneflow_exec_mode_enabled(): global _ONEFLOW_EXEC_MODE return _ONEFLOW_EXEC_MODE ================================================ FILE: python/oneflow/framework/infer_compiler/utils/param_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Any, Dict, List import oneflow as flow import torch def parse_device(args: List[Any], kwargs: Dict[str, Any]): if "device" in kwargs: return kwargs["device"] for x in args: if isinstance(x, (flow.device, torch.device)): return x if x in ["cpu", "cuda"]: return x return None def check_device(current_device, target_device) -> bool: def _convert(device): assert isinstance(device, (str, torch.device, flow.device)) if isinstance(device, torch.device): index = device.index if device.index is not None else 0 return flow.device(device.type, index) if isinstance(device, str): return flow.device(device) return device return _convert(current_device) == _convert(target_device) ================================================ FILE: python/oneflow/framework/infer_compiler/utils/patch_for_compiler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import oneflow as flow import oneflow.nn.functional as F class FakeCuda: @staticmethod def current_device(): return "cuda:0" @staticmethod def mem_get_info(dev): return 1024 * 1024 * 1024, 1024 * 1024 * 1024 @staticmethod def _scaled_dot_product_attention_math( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False ): d_k = query.size(-1) if is_causal: assert attn_mask is None, "Cannot use both attn_mask and is_causal=True" L, S = query.size(-2), key.size(-2) attn_mask = flow.ones((L, S), dtype=flow.bool).tril() if attn_mask is not None: if attn_mask.dtype == flow.bool: new_attn_mask = flow.empty( attn_mask.shape, dtype=query.dtype, device=query.device ) mask = flow.logical_not(attn_mask) new_attn_mask.masked_fill_(mask, float("-inf")) attn_mask = new_attn_mask scores = flow.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) if attn_mask is not None: scores.add_(attn_mask) p_attn = F.softmax(scores, dim=-1) if dropout_p > 0.0: generator = flow.Generator() p_attn = flow.nn.functional.dropout( p_attn, p=dropout_p, generator=generator ) return flow.matmul(p_attn, value) @staticmethod def scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False ): """Scaled Dot-Product Attention Args: query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`. key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`. value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`. attn_mask (optional Tensor): Attention mask; shape :math:`(N, ..., L, S)`. Two types of masks are supported. A boolean mask where a value of True indicates that the element *should* take part in attention. 
A float mask of the same type as query, key, value that is added to the attention score. dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied is_causal (bool): If true, assumes causal attention masking and errors if both attn_mask and is_causal are set. Returns: output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`. Shape legend: - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` - :math:`S: \text{Source sequence length}` - :math:`L: \text{Target sequence length}` - :math:`E: \text{Embedding dimension of the query and key}` - :math:`Ev: \text{Embedding dimension of the value}` """ if attn_mask is not None or dropout_p > 0.0: return FakeCuda._scaled_dot_product_attention_math( query, key, value, attn_mask, dropout_p, is_causal ) batch_size, num_heads, target_seq_len, head_dim = query.shape out = flow._C.fused_multi_head_attention_inference_v2( query=query, query_layout="BHMK", query_head_size=head_dim, key=key, key_layout="BHMK", value=value, value_layout="BHMK", output_layout="BM(HK)", causal=is_causal, ) # (N, L, H x Ev) -> (N, H, L, Ev) value_embed_dim = value.shape[-1] out = out.view(batch_size, target_seq_len, num_heads, value_embed_dim).permute( 0, 2, 1, 3 ) return out flow.cuda.current_device = FakeCuda.current_device flow.cuda.mem_get_info = FakeCuda.mem_get_info flow.nn.functional.scaled_dot_product_attention = FakeCuda.scaled_dot_product_attention F.scaled_dot_product_attention = FakeCuda.scaled_dot_product_attention ================================================ FILE: python/oneflow/framework/infer_compiler/utils/patch_for_diffusers.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # TODO: remove this file to diffusers/src/infer_compiler_registry/register_diffusers from abc import ABC, abstractmethod from .log_utils import logger try: import diffusers from diffusers.models.attention_processor import Attention except ImportError: diffusers = None logger.warning("diffusers not found, some features will be disabled.") _IS_DIFFUSERS_AVAILABLE = diffusers is not None class InstanceChecker(ABC): @abstractmethod def is_attention_instance(self, instance): pass class DiffusersChecker(InstanceChecker): def is_attention_instance(self, instance): if not _IS_DIFFUSERS_AVAILABLE: return False return isinstance(instance, Attention) diffusers_checker = DiffusersChecker() ================================================ FILE: python/oneflow/framework/infer_compiler/with_fx_graph.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import oneflow as flow import torch import torch.fx as fx from torch.fx.node import map_aggregate from .transform import get_attr, torch2oflow def fx_node_tranform(gm): of_gm = to_of_transform(gm) enable_graph = os.getenv("ONEDIFF_INFER_COMPILER_USE_GRAPH", "True").lower() in ( "true", "1", "t", ) if not enable_graph: oneflow_fn = of_gm.forward else: # Align this with env setting in `with_oneflow_compile`. # Otherwise, infererence using PyTorch with OneFlow backend on # multiple input shapes may crash os.environ.setdefault("ONEFLOW_RUN_GRAPH_BY_VM", "1") os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1") os.environ.setdefault("ONEFLOW_MLIR_CSE", "1") os.environ.setdefault("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1") os.environ.setdefault("ONEFLOW_MLIR_ENABLE_ROUND_TRIP", "1") os.environ.setdefault("ONEFLOW_MLIR_FUSE_FORWARD_OPS", "1") os.environ.setdefault("ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL", "1") os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL", "1") os.environ.setdefault("ONEFLOW_MLIR_PREFER_NHWC", "0") os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS", "1") os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR", "1") os.environ.setdefault( "ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", "1" ) os.environ.setdefault( "ONEFLOW_KERNEL_GEMM_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", "1" ) os.environ.setdefault("ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL", "1") os.environ.setdefault("ONEFLOW_KERNEL_GEMM_ENABLE_CUTLASS_IMPL", "1") os.environ.setdefault("ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION", "1") os.environ.setdefault("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", "1") os.environ.setdefault("ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT", "1") os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL_QUANT", "1") class OfGraph(flow.nn.Graph): def __init__(self): super().__init__() self.fx_md = of_gm self.config.enable_cudnn_conv_heuristic_search_algo(False) self.config.allow_fuse_add_to_output(True) def build(self, *args, **kwargs): return self.fx_md(*args, **kwargs) of_g = OfGraph() oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs) return oneflow_fn def to_of_transform( gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer ) -> torch.fx.GraphModule: name2node = {} name2obj = {} torch2flow = {} of_g = flow.fx.Graph() modules = dict(gm.named_modules()) for node in gm.graph.nodes: if node.op == "placeholder": of_node = of_g.create_node("placeholder", node.target) name2node[node.name] = of_node elif node.op == "output": of_node = of_g.output(node_replace_args(node.args, name2node)[0]) name2node[node.name] = of_node elif node.op == "call_function": of_node = of_g.create_node( "call_function", torch2oflow(node.target), args=node_replace_args(node.args, name2node), kwargs=node_replace_args(node.kwargs, name2node), ) name2node[node.name] = of_node elif node.op == "call_method": of_node = of_g.create_node( "call_method", node.target, args=node_replace_args(node.args, name2node), kwargs=node_replace_args(node.kwargs, name2node), ) name2node[node.name] = of_node elif node.op == "call_module": torch_md = modules[node.target] name2obj[node.target] = 
torch2oflow(torch_md)
            of_node = of_g.create_node(
                "call_module",
                node.target,
                args=node_replace_args(node.args, name2node),
                kwargs=node_replace_args(node.kwargs, name2node),
            )
            name2node[node.name] = of_node
        elif node.op == "get_attr":
            of_node = of_g.create_node("get_attr", node.target)
            name2node[node.name] = of_node
            name2obj[node.target] = get_attr(gm, node, torch2flow)
        else:
            raise ValueError(f"not valid node type: {node.format_node()}")

    of_gm = flow.fx.GraphModule(name2obj, of_g)
    of_gm.training = False
    of_gm.graph.lint()
    of_gm.recompile()
    return of_gm


def replace_node(node, name2node):
    if isinstance(node, torch.fx.Node):
        return name2node[node.name]
    else:
        return torch2oflow(node)


def node_replace_args(args, name2node):
    return map_aggregate(args, lambda node: replace_node(node, name2node))

================================================ FILE: python/oneflow/framework/infer_compiler/with_fx_interpreter.py ================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Dict, Tuple

import torch

from oneflow.framework.infer_compiler.transform.builtin_transform import torch2oflow

from .transform import ProxySubmodule, map_args


class OneFlowInterpreter(torch.fx.Interpreter):
    from torch.fx.node import Argument, Target

    def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
        args, kwargs = map_args(args, kwargs)
        target = torch2oflow(target)
        return super().call_function(target, args, kwargs)

    def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
        args, kwargs = map_args(args, kwargs)
        return super().call_method(target, args, kwargs)

    def call_module(
        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ) -> Any:
        submod = self.fetch_attr(target)
        submod = ProxySubmodule(submod)
        return submod(*args, **kwargs)

================================================ FILE: python/oneflow/framework/infer_compiler/with_oneflow_backend.py ================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import os import torch import oneflow as flow from oneflow.framework.args_tree import ArgsTree from .with_fx_graph import fx_node_tranform from .with_fx_interpreter import OneFlowInterpreter def oneflow_backend(gm, example_inputs, *args, **kwargs): with_interp = os.getenv( "ONEDIFF_INFER_COMPILER_USE_INTERPRETER", "False" ).lower() in ("true", "1", "t",) if not with_interp: transformed_fn = fx_node_tranform(gm) def wrapped_forward(*args, **kwargs): def input_fn(value): if isinstance(value, torch.Tensor): return flow.utils.tensor.from_torch(value.contiguous()) else: return value args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) out = args_tree.map_leaf(input_fn) args = out[0] if with_interp: output = OneFlowInterpreter(gm, garbage_collect_values=False).run( *args, **kwargs ) else: output = transformed_fn(*args, **kwargs) if isinstance(output, tuple): return tuple(flow.utils.tensor.to_torch(i) for i in output) return flow.utils.tensor.to_torch(output) return wrapped_forward ================================================ FILE: python/oneflow/framework/infer_compiler/with_oneflow_compile.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import types from functools import wraps from itertools import chain from typing import Any import oneflow as flow import torch from oneflow.utils.tensor import to_torch from .transform.builtin_transform import torch2oflow from .transform.manager import transform_mgr from .utils.args_tree_util import input_output_processor from .utils.cost_util import cost_cnt from .utils.log_utils import logger from .utils.oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled from .utils.param_utils import check_device, parse_device class DualModule(torch.nn.Module): def __init__(self, torch_module, oneflow_module): torch.nn.Module.__init__(self) self._torch_module = torch_module self._oneflow_module = oneflow_module @property def oneflow_module(self): if self._oneflow_module is not None: return self._oneflow_module logger.debug(f"Convert {type(self._torch_module)} ...") self._oneflow_module = torch2oflow(self._torch_module) logger.debug(f"Convert {id(self._torch_module)} done!") return self._oneflow_module @oneflow_module.deleter def oneflow_module(self): if self._oneflow_module: del self._oneflow_module setattr(self, "_oneflow_module", None) def to(self, *args, **kwargs): if oneflow_exec_mode_enabled(): self._oneflow_module.to(*args, **kwargs) else: if self._oneflow_module is not None: args = [torch2oflow(v) for v in args] kwargs = {k: torch2oflow(v) for k, v in kwargs.items()} self._oneflow_module.to(*args, **kwargs) self._torch_module_to_with_check(*args, **kwargs) else: self._torch_module.to(*args, **kwargs) def _torch_module_to_with_check(self, *args, **kwargs): def _align_tensor(torch_module, oneflow_module): oneflow_tensor_list = set( [x for x, _ in oneflow_module.named_parameters()] + [x for x, _ in oneflow_module.named_buffers()] ) for name, 
tensor in chain.from_iterable( [torch_module.named_parameters(), torch_module.named_buffers(),] ): if name not in oneflow_tensor_list: tensor.data = tensor.to(*args, **kwargs) else: oneflow_tensor = oneflow_module.get_parameter(name) if oneflow_tensor is None: tensor.data = tensor.to(*args, **kwargs) elif tensor.data_ptr() != oneflow_tensor.data_ptr(): tensor.data = to_torch(oneflow_tensor.data) oneflow_module_list = set([x for x, _ in self._oneflow_module.named_modules()]) for name, module in self._torch_module.named_modules(): if name not in oneflow_module_list: module.to(*args, **kwargs) else: _align_tensor(module, self._oneflow_module.get_submodule(name)) def __getattr__(self, name): if name == "_torch_module": return self._modules[name] if name == "_oneflow_module": return super().__getattribute__(name) torch_attr = getattr(self._torch_module, name) oneflow_attr = ( None if self._oneflow_module is None else getattr(self._oneflow_module, name) ) if isinstance(torch_attr, torch.nn.ModuleList): oneflow_attr = ( [None] * len(torch_attr) if oneflow_attr is None else oneflow_attr ) return DualModuleList(torch_attr, oneflow_attr) elif isinstance(torch_attr, torch.nn.Module): return get_mixed_dual_module(torch_attr.__class__)(torch_attr, oneflow_attr) else: return oneflow_attr if oneflow_exec_mode_enabled() else torch_attr def __setattr__(self, name: str, value: Any) -> None: if name in ["_torch_module", "_oneflow_module"]: super().__setattr__(name, value) else: # TODO: aviod memory up when set attr try: setattr(self._torch_module, name, value) value = torch2oflow(value) if isinstance(value, flow.Tensor): obj = getattr(self._oneflow_module, name) obj.copy_(value) else: setattr(self._oneflow_module, name, value) except: super().__setattr__(name, value) class DualModuleList(torch.nn.ModuleList): def __init__(self, torch_modules, oneflow_modules): super().__init__() assert len(torch_modules) == len(oneflow_modules) self._torch_modules = torch_modules self._oneflow_modules = oneflow_modules dual_modules = [] for torch_module, oneflow_module in zip( self._torch_modules, self._oneflow_modules ): dual_modules.append( get_mixed_dual_module(torch_module.__class__)( torch_module, oneflow_module ) ) # clear self._modules since `self._torch_modules = torch_modules` will append a module to self._modules self._modules.clear() self += dual_modules def __setitem__(self, idx: int, module: DualModule): idx = self._get_abs_string_index(idx) setattr(self._torch_modules, str(idx), module._torch_module) setattr(self._oneflow_modules, str(idx), module._oneflow_module) return setattr(self, str(idx), module) def __setattr__(self, name, value): if name in ("_torch_modules", "_oneflow_modules"): return object.__setattr__(self, name, value) try: if isinstance(value, DualModule): setattr(self._torch_modules, name, value._torch_module) setattr(self._oneflow_modules, name, value._oneflow_module) else: setattr(self._torch_modules, name, value) value = torch2oflow(value) setattr(self._oneflow_modules, name, value) except: super().__setattr__(name, value) def get_mixed_dual_module(module_cls): class MixedDualModule(DualModule, module_cls): def __init__(self, torch_module, oneflow_module): DualModule.__init__(self, torch_module, oneflow_module) return MixedDualModule def graph_file_management(func): @wraps(func) def wrapper(self: "DeployableModule", *args, **kwargs): graph_file = self._deployable_module_options.get("graph_file", None) # Load graph file if graph_file is not None: try: if not os.path.exists(graph_file): 
logger.warning( f"Graph file {graph_file} not exists!, will generate graph." ) else: graph_device = self._deployable_module_options.get( "graph_file_device", None ) self.load_graph(graph_file, torch2oflow(graph_device)) logger.info(f"Load graph file: {graph_file}") graph_file = None self._deployable_module_options["graph_file"] = None except Exception as e: logger.error(f"Load graph file: {graph_file} failed! {e}") ret = func(self, *args, **kwargs) # Save graph file if graph_file is not None: try: if graph_file is not None: os.makedirs(os.path.dirname(graph_file), exist_ok=True) self.save_graph(graph_file) logger.info(f"Save graph file: {graph_file} done!") except Exception as e: logger.error(f"Save graph file: {graph_file} failed! {e}") finally: self._deployable_module_options["graph_file"] = None return ret return wrapper def handle_deployable_exception(func): @wraps(func) def wrapper(self, *args, **kwargs): if transform_mgr.debug_mode: return func(self, *args, **kwargs) else: try: return func(self, *args, **kwargs) except Exception as e: logger.error(f"Exception in {func.__name__}: {e}") logger.warning("Recompile oneflow module ...") del self._deployable_module_model.oneflow_module self._deployable_module_dpl_graph = None return func(self, *args, **kwargs) return wrapper class DeployableModule(torch.nn.Module): def __init__( self, torch_module, oneflow_module, use_graph=True, options={}, graph_path=None, graph_device=None, ): torch.nn.Module.__init__(self) self._deployable_module_model = get_mixed_dual_module(torch_module.__class__)( torch_module, oneflow_module ) self._deployable_module_use_graph = use_graph self._deployable_module_options = options self._deployable_module_dpl_graph = None self._is_raw_deployable_module = True @classmethod def from_existing(cls, existing_module, use_graph=None, options=None): torch_module = existing_module._deployable_module_model._torch_module oneflow_module = existing_module._deployable_module_model._oneflow_module instance = cls(torch_module, oneflow_module, use_graph, options) instance._deployable_module_dpl_graph = ( existing_module._deployable_module_dpl_graph if use_graph else None ) return instance def get_graph(self): if self._deployable_module_dpl_graph is not None: return self._deployable_module_dpl_graph if "size" in self._deployable_module_options: size = self._deployable_module_options["size"] else: size = 9 if "dynamic" in self._deployable_module_options: dynamic = self._deployable_module_options["dynamic"] else: dynamic = True self._deployable_module_dpl_graph = get_oneflow_graph( self._deployable_module_model.oneflow_module, size, dynamic ) if "debug" in self._deployable_module_options: self._deployable_module_dpl_graph.debug( self._deployable_module_options["debug"] ) return self._deployable_module_dpl_graph @input_output_processor @handle_deployable_exception @graph_file_management def apply_model(self, *args, **kwargs): if self._deployable_module_use_graph: dpl_graph = self.get_graph() with oneflow_exec_mode(): output = dpl_graph(*args, **kwargs) else: with oneflow_exec_mode(): output = self._deployable_module_model.oneflow_module.apply_model( *args, **kwargs ) return output @input_output_processor @handle_deployable_exception @graph_file_management def __call__(self, *args, **kwargs): if self._deployable_module_use_graph: dpl_graph = self.get_graph() with oneflow_exec_mode(): output = dpl_graph(*args, **kwargs) else: with oneflow_exec_mode(): output = self._deployable_module_model.oneflow_module(*args, **kwargs) return output def 
to(self, *args, **kwargs): if self._deployable_module_dpl_graph is None: self._deployable_module_model.to(*args, **kwargs) return self # assert the target device is same as graph device target_device = parse_device(args, kwargs) if ( target_device is not None and len(self._deployable_module_dpl_graph._blocks) > 0 ): current_device = next(self._deployable_module_dpl_graph._state()).device if not check_device(current_device, target_device): raise RuntimeError( f"After graph built, the device of graph can't be modified, current device: {current_device}, target device: {target_device}" ) self._deployable_module_model.to(*args, **kwargs) return self # TODO(): Just for transformers VAE decoder @input_output_processor @handle_deployable_exception @graph_file_management def decode(self, *args, **kwargs): if self._deployable_module_use_graph: def _build(graph, *args, **kwargs): return graph.model.decode(*args, **kwargs) dpl_graph = self.get_graph() dpl_graph.build = types.MethodType(_build, dpl_graph) with oneflow_exec_mode(): output = dpl_graph(*args, **kwargs) else: with oneflow_exec_mode(): output = self._deployable_module_model.oneflow_module.decode( *args, **kwargs ) return output def __getattr__(self, name): if name in self._modules: return self._modules[name] return getattr(self._deployable_module_model, name) def load_graph(self, file_path, device=None, run_warmup=True): self.get_graph().warmup_with_load(file_path, device, run_warmup) def warmup_with_load(self, file_path, device=None, run_warmup=True): self.get_graph().warmup_with_load(file_path, device, run_warmup) def save_graph(self, file_path): self.get_graph().save_graph(file_path) class OneflowGraph(flow.nn.Graph): @flow.nn.Graph.with_dynamic_input_shape() def __init__(self, model): super().__init__(enable_get_runtime_state_dict=True) self.model = model self.config.enable_cudnn_conv_heuristic_search_algo(False) self.config.allow_fuse_add_to_output(True) os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1") os.environ.setdefault("ONEFLOW_MLIR_CSE", "1") os.environ.setdefault("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1") os.environ.setdefault("ONEFLOW_MLIR_ENABLE_ROUND_TRIP", "1") os.environ.setdefault("ONEFLOW_MLIR_FUSE_FORWARD_OPS", "1") os.environ.setdefault("ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL", "1") os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL", "1") # TODO(lml): enable ONEFLOW_MLIR_PREFER_NHWC when related bug fix. 
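        # Note: os.environ.setdefault only takes effect when a variable is not
        # already set, so values exported by the user still win. This block opts
        # the graph into OneFlow's MLIR inference optimizations (CSE, forward-op
        # fusion, grouped matmul) and the CUTLASS-backed conv/GEMM kernels with
        # tuning warmup.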
os.environ.setdefault("ONEFLOW_MLIR_PREFER_NHWC", "0") os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS", "1") os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR", "1") os.environ.setdefault( "ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", "1" ) os.environ.setdefault( "ONEFLOW_KERNEL_GEMM_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", "1" ) os.environ.setdefault("ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL", "1") os.environ.setdefault("ONEFLOW_KERNEL_GEMM_ENABLE_CUTLASS_IMPL", "1") os.environ.setdefault("ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION", "1") os.environ.setdefault("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", "1") os.environ.setdefault("ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT", "1") os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL_QUANT", "1") def build(self, *args, **kwargs): return self.model(*args, **kwargs) @cost_cnt(transform_mgr.debug_mode) def warmup_with_load(self, file_path, device=None, run_warmup=True): state_dict = flow.load(file_path) if device is not None: state_dict = flow.nn.Graph.runtime_state_dict_to(state_dict, device) self.load_runtime_state_dict(state_dict, warmup_with_run=run_warmup) @cost_cnt(transform_mgr.debug_mode) def save_graph(self, file_path): state_dict = self.runtime_state_dict() flow.save(state_dict, file_path) def get_oneflow_graph(model, size=9, dynamic=True): g = OneflowGraph(model) g._dynamic_input_graph_cache.set_cache_size(size) g._dynamic_input_graph_cache.enable_shared(dynamic) return g def state_dict_hook(module, state_dict, prefix, local_metadata): pytorch_key_prefix = "_deployable_module_model._torch_module." new_state_dict = type(state_dict)() for k, v in state_dict.items(): # _deployable_module_model._torch_module.out.2.weight => out.2.weight if k.startswith(pytorch_key_prefix): new_k = k[len(pytorch_key_prefix) :] new_state_dict[new_k] = v else: new_state_dict[k] = v return new_state_dict # Return a DeployableModule that using module_cls as it's parent class. def get_mixed_deployable_module(module_cls): class MixedDeployableModule(DeployableModule, module_cls): def __init__( self, torch_module, oneflow_module, use_graph=True, options={}, graph_path=None, graph_device=None, ): DeployableModule.__init__( self, torch_module, oneflow_module, use_graph, options, graph_path, graph_device, ) self._is_raw_deployable_module = False @classmethod def from_existing(cls, existing_module, use_graph=None, options=None): torch_module = existing_module._deployable_module_model._torch_module oneflow_module = existing_module._deployable_module_model._oneflow_module instance = cls(torch_module, oneflow_module, use_graph, options) instance._deployable_module_dpl_graph = ( existing_module._deployable_module_dpl_graph if use_graph else None ) return instance return MixedDeployableModule def compile_from_torch( torch_module: torch.nn.Module, *, use_graph=True, options={}, ): """ Converts torch module to oneflow module. Note: Map from torch to oneflow should be registered by `infer_compiler.register(torch2oflow_class_map={TorchModule: OneflowModule})` before `compile_from_torch` be called. Args: torch_module (torch.nn.Module): Torch module to be compiled. use_graph (bool, optional): If `True`, graph of compiled module can be saved and loaded to speedup the compile process. Defaults to `True`. options (dict, optional): size (int, optional): graph cache size. Defaults to `9`. dynamic (bool, optional): If `True`, graph of compiled module can be shared with other modules. Defaults to `True`. debug (int, optional): debug level. 
Defaults to `-1`. Returns: DeployableModule: Compiled oneflow module. """ def wrap_module(module): if isinstance(module, DeployableModule): assert not module._is_raw_deployable_module return module.__class__.from_existing(module, use_graph, options) else: return get_mixed_deployable_module(module.__class__)( module, None, use_graph, options ) model = wrap_module(torch_module) assert isinstance(model, DeployableModule) assert isinstance(model, torch_module.__class__) model._register_state_dict_hook(state_dict_hook) return model ================================================ FILE: python/oneflow/framework/job_set_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional, TypeVar from oneflow.core.job.job_set_pb2 import JobSet _VT = TypeVar("_VT") _default_job_set = JobSet() ================================================ FILE: python/oneflow/framework/model.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ __all__ = [ "DataModule", "NumpyDataModule", "TrainingConfig", "ValidationConfig", "CheckpointConfig", "Callback", "Model", ] import inspect from abc import ABC from typing import Any, List, Optional, Tuple, Union import numpy as np import oneflow as flow import oneflow._oneflow_internal import oneflow.framework.dtype as dtype_util from oneflow.framework.function_util import FunctionConfig as ExecutionConfig from oneflow.framework.tensor import Tensor from oneflow.nn.modules.module import Module from oneflow.optim.optimizer import Optimizer as OOPOptimizer class DataModule(Module): def __init__(self, *args, **kwargs): super().__init__() def forward(self, step_idx: int = 0, optimizer_idx: int = 0): pass def infer_oneflow_data_placeholder( self, batch: Tuple[Any] = None, optimizer_idx: int = 0 ): return None class NumpyDataModule(DataModule): def __init__(self, *args, **kwargs): super().__init__() def forward(self, step_idx: int = 0, optimizer_idx: int = 0): pass def __call__(self, *args): ret = self.forward(*args) return ret def infer_oneflow_data_placeholder( self, batch: Tuple[np.ndarray, ...] = None, optimizer_idx: int = 0 ): assert isinstance(batch, tuple), "model.NumpyDataModule must return a tuple." data_placeholder_list = [] for item in batch: assert isinstance( item, np.ndarray ), "model.NumpyDataModule must return a tuple of numpy." 
of_dtype = dtype_util.convert_numpy_dtype_to_oneflow_dtype(item.dtype) # numpy_placeholder = oneflow_typing.Numpy.Placeholder( # shape=item.shape, dtype=of_dtype # ) data_placeholder_list.append(numpy_placeholder) return data_placeholder_list class TrainingConfig: def __init__(self): super().__init__() self.exe_cfg = ExecutionConfig() self.data = None self.error_msg = "" def config_execution(self, exe_cfg: ExecutionConfig = None): self.exe_cfg = exe_cfg def config_data(self, data: DataModule = None): self.data = data def check_valid(self): is_valid = True self.error_msg = "" if not isinstance(self.exe_cfg, ExecutionConfig): self.error_msg += "model.TrainingConfig exe_cfg is not ExecutionConfig;" is_valid = False if self.data is None: self.error_msg += "model.TrainingConfig data is None;" is_valid = False if not isinstance(self.data, DataModule): self.error_msg += "model.TrainingConfig data is not DataModule;" is_valid = False return is_valid class ValidationConfig: def __init__(self): super().__init__() self.exe_cfg = ExecutionConfig() self.data = None self.step_interval = 10 self.error_msg = "" def config_execution(self, exe_cfg: ExecutionConfig = None): self.exe_cfg = exe_cfg def config_data(self, data: DataModule = None): self.data = data def config_step_interval(self, step_interval: int = 1): self.step_interval = step_interval def check_valid(self): is_valid = True self.error_msg = "" if self.data is None: self.error_msg += "model.ValidationConfig data is None;" is_valid = False if not isinstance(self.data, DataModule): self.error_msg += "model.ValidationConfig data is not DataModule;" is_valid = False if self.step_interval <= 0 or not isinstance(self.step_interval, int): self.error_msg += ( "model.ValidationConfig step_interval is <= 0 or is not int;" ) is_valid = False return is_valid class CheckpointConfig(object): def __init__(self): self.need_load = False self.load_dirpath = None self.need_save = False self.save_dirpath = None self.save_step_interval = 1 self.error_msg = "" def config_load(self, dirpath: str = None): self.need_load = True assert dirpath is not None, "dirpath should not be None" self.load_dirpath = dirpath def config_save(self, dirpath: str = None, step_interval: int = 1): self.need_save = True self.save_dirpath = dirpath assert dirpath is not None, "dirpath should not be None" self.save_step_interval = step_interval assert step_interval > 0, "step_interval should not <= 0" assert isinstance(step_interval, int), "step_interval should be int" def check_valid(self): is_valid = True self.error_msg = "" return is_valid class Callback(ABC): """ Abstract base class used to build new callbacks. """ def on_training_step_end( self, outputs: Optional[Union[Tensor, Tuple[Tensor, ...]]], step_idx: int = 0, optimizer_idx: int = 0, ): pass def on_validation_step_end( self, outputs: Optional[Union[Tensor, Tuple[Tensor, ...]]], step_idx: int = 0, ): pass class Model(ABC, Module): """A high level API for model training and validation. """ def __init__(self, *args, **kwargs): super().__init__() self._is_deprecated_function_style = ( kwargs["is_deprecated_function_style"] if "is_deprecated_function_style" in kwargs else False ) def forward(self, *args, **kwargs): """Same as `nn.Module.forward()`, here is to define the operations you want to use for prediction. """ raise NotImplementedError def training_step(self, *args, **kwargs): """Operates on a single batch of data from the training set and return loss. 
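
        Example (an illustrative sketch only; `MyModel`, the layer sizes and the
        loss function are made up and not part of this module):

            class MyModel(Model):
                def __init__(self):
                    super().__init__()
                    self.linear = flow.nn.Linear(10, 1)

                def forward(self, x):
                    return self.linear(x)

                def training_step(self, batch, optimizer_idx=0):
                    x, y = batch
                    loss = flow.nn.functional.mse_loss(self.forward(x), y)
                    return loss

                def configure_optimizers(self):
                    return flow.optim.SGD(self.parameters(), lr=0.1)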
""" raise NotImplementedError() def validation_step(self, *args, **kwargs): """Operates on a single batch of data from the validation set. """ raise NotImplementedError() def configure_optimizers(self): """Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. """ raise NotImplementedError() def fit( self, training_config: Optional[TrainingConfig] = None, validation_config: Optional[ValidationConfig] = None, checkpoint_config: Optional[CheckpointConfig] = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, max_steps: int = 100, ): """ Runs the full training and validation routine. """ self._max_steps = max_steps self._sub_models = self._get_and_check_sub_models( training_config, validation_config, checkpoint_config, callbacks ) if len(self._sub_models) == 0: return if self._checkpoint_model.is_valid: self._checkpoint_model.load() for step_idx in range(0, self._max_steps): for sub_model in self._sub_models: try: sub_model.step(step_idx) except Exception as e: print( "Model step_idx {} {} failed.".format(step_idx, sub_model.name) ) raise e def method_overrided(self, method_name: str = None) -> bool: return getattr(self.__class__, method_name) != getattr(Model, method_name) def _get_and_check_sub_models( self, training_config: Optional[TrainingConfig] = None, validation_config: Optional[ValidationConfig] = None, checkpoint_config: Optional[CheckpointConfig] = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, ): sub_models = [] self._train_model = ( TrainModel(training_config, self, callbacks) if self._is_deprecated_function_style else TrainModelOOPStyle(training_config, self, callbacks) ) if self._train_model.is_valid: sub_models.append(self._train_model) elif training_config is not None: print( self._train_model.error_msg, "{}'s fit() will not do training.".format(self.__class__.__name__), ) self._val_model = ( ValidateModel(validation_config, self, callbacks) if self._is_deprecated_function_style else ValidateModelOOPStyle(validation_config, self, callbacks) ) if self._val_model.is_valid: sub_models.append(self._val_model) elif validation_config is not None: print( self._val_model.error_msg, "{}'s fit() will not do validation.".format(self.__class__.__name__), ) if len(sub_models) == 0: print( "{}'s fit() will do nothing because there has no valid configuration.".format( self.__class__.__name__ ) ) return sub_models self._checkpoint_model = ( CheckpointModel(checkpoint_config, self, callbacks) if self._is_deprecated_function_style else CheckpointModelOOPStyle(checkpoint_config, self, callbacks) ) if self._checkpoint_model.is_valid: sub_models.append(self._checkpoint_model) elif checkpoint_config is not None: print( self._checkpoint_model.error_msg, "{}'s fit() will not do checkpoint.".format(self.__class__.__name__), ) return sub_models class SubModel(ABC): def __init__(self, name, cfg, model, callbacks): self._cfg = cfg assert isinstance(model, Model) self._model = model self._cbs = callbacks self.name = name self.is_valid = True self.error_msg = ( self._model.__class__.__name__ + " " + self.name + " error message: " ) if not self._get_and_check_cfg(): self.is_valid = False if not self._get_and_check_cbs(): self.is_valid = False def step(self, step_idx: int = 0): raise NotImplementedError def _get_and_check_cfg(self): if self._cfg is None: self.error_msg += "config is None;" return False if not self._cfg.check_valid(): self.error_msg += 
self._cfg.error_msg return False else: return True def _get_and_check_cbs(self): if self._cbs is None: self._cbs = [] return True if isinstance(self._cbs, Callback): self._cbs = [self._cbs] return True if isinstance(self._cbs, list): for cb in self._cbs: assert isinstance( cb, Callback ), "model callbacks' type must be model.Callback or List[model.Callback]." return True assert ( False ), "model callbacks' type must be model.Callback or List[model.Callback]." def _method_callback(self, method_name: str = None, *args, **kwargs): for cb in self._cbs: method = getattr(cb, method_name) method(*args, **kwargs) class TrainModel(SubModel): def __init__( self, cfg: TrainingConfig = None, model: Model = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, ): super().__init__("training", cfg, model, callbacks) if not self._get_and_check_step(): self.is_valid = False if not self._get_and_check_opts(): self.is_valid = False if self.is_valid and (not self._get_and_check_jobs()): self.is_valid = False def step(self, step_idx: int = 0): assert self.is_valid, self.error_msg for optimizer_idx in range(0, len(self._opts)): outputs = None if self._is_numpy_input: batch = None if step_idx == 0: batch = self._first_numpy_batch[optimizer_idx] else: batch = self._cfg.data(step_idx, optimizer_idx) outputs = self._jobs[optimizer_idx](*batch).get() else: outputs = self._jobs[optimizer_idx]().get() self._method_callback( "on_training_step_end", outputs=outputs, step_idx=step_idx, optimizer_idx=optimizer_idx, ) def _get_and_check_step(self): if not self._model.method_overrided("training_step"): self.error_msg += "model.training_step() is empty;" return False else: return True def _get_and_check_opts(self): self._opts = [] if not self._model.method_overrided("configure_optimizers"): self.error_msg += "model.configure_optimizers() is empty;" return False opt_conf = self._model.configure_optimizers() if isinstance(opt_conf, Optimizer): self._opts = [opt_conf] elif isinstance(opt_conf, (list, tuple)): for opt in opt_conf: assert isinstance( opt, Optimizer ), "model.configure_optimizers() must return Optimizer or List[Optimizer, ...] or Tuple[Optimizer, ...]" self._opts = opt_conf else: assert ( False ), "model.configure_optimizers() must return Optimizer or List[Optimizer, ...] 
or Tuple[Optimizer, ...]" return True def _get_and_check_jobs(self): self._is_numpy_input = ( True if isinstance(self._cfg.data, NumpyDataModule) else False ) self._jobs = [] if self._is_numpy_input: self._first_numpy_batch = [] for optimizer_idx in range(0, len(self._opts)): batch = self._cfg.data(0, optimizer_idx) self._first_numpy_batch.insert(optimizer_idx, batch) self._jobs.insert( optimizer_idx, self._construct_numpy_job(batch, optimizer_idx) ) else: for optimizer_idx in range(0, len(self._opts)): self._jobs.insert(optimizer_idx, self._construct_job(optimizer_idx)) return True def _construct_job(self, optimizer_idx: int = 0): def job(): batch = self._cfg.data(0, optimizer_idx) outputs = self._model.training_step( batch=batch, optimizer_idx=optimizer_idx ) loss = None if isinstance(outputs, tuple) and len(outputs) > 0: loss = outputs[0] else: loss = outputs self._opts[optimizer_idx].minimize(loss) return outputs job.__name__ = ( self._model.__class__.__name__ + "_Model_train_job_" + str(optimizer_idx) ) deco # = api_oneflow_function(type="train", function_config=self._cfg.exe_cfg) return deco(job) def _construct_numpy_job(self, batch, optimizer_idx): def job(*input_batch): outputs = self._model.training_step( batch=input_batch, optimizer_idx=optimizer_idx ) loss = None if isinstance(outputs, tuple) and len(outputs) > 0: loss = outputs[0] else: loss = outputs self._opts[optimizer_idx].minimize(loss) return outputs _infer_job_signature(self._cfg.data, batch, optimizer_idx, job) job.__name__ = ( self._model.__class__.__name__ + "_Model_train_numpy_job_" + str(optimizer_idx) ) deco # = api_oneflow_function(type="train", function_config=self._cfg.exe_cfg) return deco(job) class ValidateModel(SubModel): def __init__( self, cfg: ValidationConfig = None, model: Model = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, ): super().__init__("validation", cfg, model, callbacks) if not self._get_and_check_step(): self.is_valid = False if self.is_valid and (not self._get_and_check_job()): self.is_valid = False def step(self, step_idx: int = 0): assert self.is_valid if (step_idx + 1) % self._cfg.step_interval == 0: outputs = None if self._is_numpy_input: batch = None if step_idx == 0: batch = self._first_numpy_batch else: batch = self._cfg.data(step_idx, 0) outputs = self._job(*batch).get() else: outputs = self._job().get() self._method_callback( "on_validation_step_end", step_idx=step_idx, outputs=outputs ) def _get_and_check_step(self): if not self._model.method_overrided("validation_step"): self.error_msg += "model.validation_step() is empty;" return False else: return True def _get_and_check_job(self): self._is_numpy_input = ( True if isinstance(self._cfg.data, NumpyDataModule) else False ) self._job = None if not self._is_numpy_input: self._job = self._construct_job() else: batch = self._cfg.data(0, 0) self._first_numpy_batch = batch self._job = self._construct_numpy_job(batch) return True def _construct_job(self): def job(): batch = self._cfg.data(0, 0) return self._model.validation_step(batch) job.__name__ = self._model.__class__.__name__ + "_Model_eval_job" deco # = api_oneflow_function(type="predict", function_config=self._cfg.exe_cfg) return deco(job) def _construct_numpy_job(self, batch: Tuple[np.ndarray, ...] 
= None): def job(*input_batch): return self._model.validation_step(batch=input_batch) _infer_job_signature(self._cfg.data, batch, 0, job) job.__name__ = self._model.__class__.__name__ + "_Model_eval_numpy_job" deco # = api_oneflow_function(type="predict", function_config=self._cfg.exe_cfg) return deco(job) class CheckpointModel(SubModel): def __init__( self, cfg: CheckpointConfig = None, model: Model = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, ): super().__init__("checkpointing", cfg, model, callbacks) def load(self): assert self.is_valid if self._cfg.need_load: self._load_checkpoint(self._cfg.load_dirpath) def step(self, step_idx: int = 0): assert self.is_valid if self._cfg.need_save: if (step_idx + 1) % self._cfg.save_step_interval == 0: self._save_checkpoint( dirpath=self._cfg.save_dirpath + "-" + str(step_idx) ) def _load_checkpoint(self, dirpath: str): """Load model states from a checkpoint. """ stat_dict = flow.load(path=dirpath) self._model.load_state_dict(stat_dict) def _save_checkpoint(self, dirpath: str): """Save model states as a checkpoint. """ stat_dict = self._model.state_dict() flow.save(stat_dict, dirpath) class TrainModelOOPStyle(SubModel): def __init__( self, cfg: TrainingConfig = None, model: Model = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, ): super().__init__("training", cfg, model, callbacks) if not self._get_and_check_step(): self.is_valid = False if not self._get_and_check_opts(): self.is_valid = False def step(self, step_idx: int = 0): assert self.is_valid, self.error_msg for optimizer_idx in range(0, len(self._opts)): batch = self._cfg.data(step_idx, optimizer_idx) outputs = self._model.training_step( batch=batch, optimizer_idx=optimizer_idx ) loss = None if isinstance(outputs, tuple) and len(outputs) > 0: loss = outputs[0] else: loss = outputs loss.backward() opt = self._opts[optimizer_idx] opt.step() opt.zero_grad() self._method_callback( "on_training_step_end", outputs=outputs, step_idx=step_idx, optimizer_idx=optimizer_idx, ) def _get_and_check_step(self): if not self._model.method_overrided("training_step"): self.error_msg += "model.training_step() is empty;" return False else: return True def _get_and_check_opts(self): self._opts = [] if not self._model.method_overrided("configure_optimizers"): self.error_msg += "model.configure_optimizers() is empty;" return False opt_conf = self._model.configure_optimizers() if isinstance(opt_conf, OOPOptimizer): self._opts = [opt_conf] elif isinstance(opt_conf, (list, tuple)): for opt in opt_conf: assert isinstance( opt, OOPOptimizer ), "model.configure_optimizers() must return Optimizer or List[Optimizer, ...] or Tuple[Optimizer, ...]" self._opts = opt_conf else: assert ( False ), "model.configure_optimizers() must return Optimizer or List[Optimizer, ...] 
or Tuple[Optimizer, ...]" return True class ValidateModelOOPStyle(SubModel): def __init__( self, cfg: ValidationConfig = None, model: Model = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, ): super().__init__("validation", cfg, model, callbacks) if not self._get_and_check_step(): self.is_valid = False def step(self, step_idx: int = 0): assert self.is_valid if (step_idx + 1) % self._cfg.step_interval == 0: outputs = None with oneflow._oneflow_internal.autograd.no_grad(): inputs = self._cfg.data(step_idx, 0) model_previous_mode = self._model.training self._model.train() outputs = self._model.validation_step(inputs) self._model.train(model_previous_mode) self._method_callback( "on_validation_step_end", step_idx=step_idx, outputs=outputs ) def _get_and_check_step(self): if not self._model.method_overrided("validation_step"): self.error_msg += "model.validation_step() is empty;" return False else: return True class CheckpointModelOOPStyle(SubModel): def __init__( self, cfg: CheckpointConfig = None, model: Model = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, ): super().__init__("checkpointing", cfg, model, callbacks) def load(self): assert self.is_valid if self._cfg.need_load: self._load_checkpoint(self._cfg.load_dirpath) def step(self, step_idx: int = 0): assert self.is_valid if self._cfg.need_save: if (step_idx + 1) % self._cfg.save_step_interval == 0: self._save_checkpoint( dirpath=self._cfg.save_dirpath + "-" + str(step_idx) ) def _load_checkpoint(self, dirpath: str): """Load model states from a checkpoint. """ stat_dict = flow.load(path=dirpath) self._model.load_state_dict(stat_dict) def _save_checkpoint(self, dirpath: str): """Save model states as a checkpoint. """ stat_dict = self._model.state_dict() flow.save(stat_dict, dirpath) def _infer_job_signature(data_module, batch, optimizer_idx, job): para_list = [] placeholder_list = data_module.infer_oneflow_data_placeholder(batch, optimizer_idx) for (i, placeholder) in enumerate(placeholder_list): para_name = ( data_module.__class__.__name__ + "_opt_" + str(optimizer_idx) + "_para_" + str(i) ) para_list.append( inspect.Parameter( name=para_name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=placeholder, ) ) origin_sig = inspect.signature(job) new_sig = origin_sig.replace(parameters=para_list) job.__oneflow_function_signature__ = new_sig ================================================ FILE: python/oneflow/framework/multi_client_session.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import enum import inspect from google.protobuf import text_format import oneflow._oneflow_internal import oneflow.core.job.job_set_pb2 as job_set_util import oneflow.framework.c_api_util as c_api_util import oneflow.framework.env_util as env_util import oneflow.core.job.resource_pb2 as resource_pb class MultiClientSession(object): class Status(enum.Enum): CREATED = 1 INITED = 2 CLOSED = 3 def __init__(self, env, sess_id): self._id = sess_id self._env = env assert self._env is not None # New a MultiClientSessionContext self._session_ctx = oneflow._oneflow_internal.SessionContext(self._env._env_cxt) self.config_proto_ = self._make_config_proto() self.function_flag_name2default_val_ = {} self._update_function_flag_name2defaultVal() self.scope_attr_name2default_val_ = {} self._update_scope_attr_name2defaultVal() self.status_ = self.Status.CREATED def __del__(self): if self._env.is_shutting_down(): # After python shutting down, it's not safe to call oneflow return self._TryClose() def TryInit(self): self._check_status(self.Status.CREATED, self.Status.INITED) if self.status_ == self.Status.CREATED: config_proto_str = text_format.MessageToString(self.config_proto) self._session_ctx.try_init(config_proto_str) self.status_ = self.Status.INITED def _TryClose(self): if self.status_ != self.Status.CLOSED: oneflow._oneflow_internal.ClearSessionId(self.id) self.status_ = self.Status.CLOSED @property def status(self): return self.status_ @property def id(self): return self._id @property def config_proto(self): return self.config_proto_ @property def resource(self): self._check_status(self.Status.INITED) return c_api_util.CurrentResource() @property def function_flag_name2default_val(self): return self.function_flag_name2default_val_ @property def scope_attr_name2default_val(self): return self.scope_attr_name2default_val_ @property def is_running(self): return self.status_ == self.Status.INITED def _check_status(self, *status): check_success = False for stat in status: if self.status_ == stat: check_success = True break if check_success is False: caller_func_name = inspect.stack()[1].function allowed_status = " or ".join([str(stat) for stat in status]) raise ValueError( "The calling to {} is only allowed when status is {}, but current status is {}".format( caller_func_name, allowed_status, self.status_ ) ) def _make_config_proto(self): config_proto = job_set_util.ConfigProto() config_proto.resource.SetInParent() config_proto.session_id = self.id return config_proto def _update_function_flag_name2defaultVal(self): items = c_api_util.GetFunctionConfigDef().attr_name2attr_def.items() self.function_flag_name2default_val_ = {k: v.default_val for (k, v) in items} def _update_scope_attr_name2defaultVal(self): items = c_api_util.GetScopeConfigDef().attr_name2attr_def.items() self.scope_attr_name2default_val_ = {k: v.default_val for (k, v) in items} def update_resource_eagerly(self, resource_config): self._check_status(self.Status.INITED) config_proto_str = text_format.MessageToString(resource_config) self._session_ctx.update_resource(config_proto_str) ================================================ FILE: python/oneflow/framework/register_class_method_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal import oneflow.framework.check_point_v2 as check_point_v2 import oneflow.framework.tensor as tensor_util def RegisterMethod4Class(): tensor_util.RegisterMethods() check_point_v2.RegisterMethods() ================================================ FILE: python/oneflow/framework/scope_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import traceback from contextlib import contextmanager from google.protobuf import text_format import oneflow._oneflow_internal import oneflow.core.job.scope_pb2 as scope_pb2_util import oneflow.framework.attr_util as attr_util import oneflow.framework.session_context as session_ctx from oneflow import oneflow_deprecate def api_scope_config(**kwargs): name2default = session_ctx.GetDefaultSession().scope_attr_name2default_val def SetScopeProtoStr(serialized_scope_proto: str): scope_proto = text_format.Parse( serialized_scope_proto, scope_pb2_util.ScopeProto() ) for (attr_name, py_value) in kwargs.items(): assert attr_name in name2default attr_util.SetProtoAttrValue( scope_proto.attr_name2attr_value[attr_name], py_value, name2default[attr_name], ) return str(text_format.MessageToString(scope_proto)) sess = session_ctx.GetDefaultSession() scope = MakeScope( lambda old_scope, builder: builder.BuildScopeByProtoStrSetter( old_scope, SetScopeProtoStr ) ) return ScopeContext(scope) def current_scope(): """ Return current scope """ return oneflow._oneflow_internal.GetCurrentScope() from oneflow import oneflow_deprecate def MakeScope(build_func): scope = None old_scope = oneflow._oneflow_internal.GetCurrentScope() assert old_scope is not None def BuildScope(builder): nonlocal scope scope = build_func(old_scope, builder) assert scope is not None oneflow._oneflow_internal.deprecated.PhysicalRun(BuildScope) return scope @contextmanager def ScopeContext(scope): old_scope = oneflow._oneflow_internal.GetCurrentScope() oneflow._oneflow_internal.GlobalScopeStackPush(scope) try: yield finally: assert oneflow._oneflow_internal.GetCurrentScope() is scope oneflow._oneflow_internal.GlobalScopeStackPop() assert oneflow._oneflow_internal.GetCurrentScope() is old_scope ================================================ FILE: python/oneflow/framework/session_context.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import functools import oneflow import oneflow._oneflow_internal from oneflow.framework.multi_client_session import MultiClientSession class SessionStatus: OPEN = "OPEN" RUNNING = "RUNNING" CLOSED = "CLOSED" def GetDefaultSession(): global _sess_id2sess default_sess_id = oneflow._oneflow_internal.GetDefaultSessionId() assert default_sess_id in _sess_id2sess return _sess_id2sess[default_sess_id] def NewDefaultSession(env): session_id = oneflow._oneflow_internal.NewSessionId() assert oneflow._oneflow_internal.RegsterSessionId(session_id) new_default_sess = MultiClientSession(env, session_id) global _sess_id2sess assert new_default_sess.id not in _sess_id2sess _sess_id2sess[new_default_sess.id] = new_default_sess def TryCloseDefaultSession(): global _sess_id2sess default_sess_id = oneflow._oneflow_internal.GetDefaultSessionId() assert default_sess_id in _sess_id2sess if default_sess_id in _sess_id2sess: del _sess_id2sess[default_sess_id] # Try clear to avoid using this outdated session. oneflow._oneflow_internal.ClearSessionId(default_sess_id) def try_init_default_session(func): @functools.wraps(func) def Func(*args, **kwargs): GetDefaultSession().TryInit() return func(*args, **kwargs) return Func _sess_id2sess = {} ================================================ FILE: python/oneflow/framework/sysconfig.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import imp import importlib.util import os from typing import List import oneflow import oneflow._oneflow_internal def get_include() -> str: return os.path.join(os.path.dirname(oneflow.__file__), "include") def get_lib() -> str: return os.path.dirname(oneflow.__file__) def get_compile_flags() -> List[str]: flags = [] flags.append("-I{}".format(get_include())) flags.append("-DHALF_ENABLE_CPP11_USER_LITERALS=0") if oneflow._oneflow_internal.flags.with_cuda(): flags.append("-DWITH_CUDA") if oneflow._oneflow_internal.flags.use_cxx11_abi(): flags.append("-D_GLIBCXX_USE_CXX11_ABI=1") else: flags.append("-D_GLIBCXX_USE_CXX11_ABI=0") return flags def get_liboneflow_link_flags() -> List[str]: oneflow_python_module_path = get_lib() # path in a pip release oneflow_python_libs_path = f"{oneflow_python_module_path}.libs" # path in a cmake build dir if not os.path.exists(oneflow_python_libs_path): from oneflow.version import __cmake_project_binary_dir__ oneflow_python_libs_path = __cmake_project_binary_dir__ return [ f"-L{oneflow_python_libs_path}", f"-l:oneflow", f"-l:of_protoobj", ] def get_link_flags() -> List[str]: flags = [] flags.append("-L{}".format(get_lib())) (file, oneflow_internal_lib_path, _) = imp.find_module( "_oneflow_internal", [get_lib()] ) if file: file.close() flags.append("-l:{}".format(os.path.basename(oneflow_internal_lib_path))) return flags def with_cuda() -> bool: return oneflow._oneflow_internal.flags.with_cuda() def get_cuda_version() -> int: return oneflow._oneflow_internal.flags.cuda_version() def has_rpc_backend_grpc() -> bool: return oneflow._oneflow_internal.flags.has_rpc_backend_grpc() def has_rpc_backend_local() -> bool: return oneflow._oneflow_internal.flags.has_rpc_backend_local() def cmake_build_type() -> str: return oneflow._oneflow_internal.flags.cmake_build_type() def with_rdma() -> bool: return oneflow._oneflow_internal.flags.with_rdma() ================================================ FILE: python/oneflow/framework/tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from numbers import Number import oneflow as flow import oneflow.framework.tensor_str as tensor_str import oneflow._oneflow_internal.lazy_mode as lazy_mode import numpy as np from typing import Union Tensor = flow._oneflow_internal.Tensor TensorTuple = flow._oneflow_internal.TensorTuple def _ndim(self): return len(self.shape) def _backward(self, gradient=None, retain_graph=False, create_graph=False): if lazy_mode.is_enabled(): assert ( self.is_lazy ), "nn.Graph only accept lazy tensor to call backward() in lazy mode." assert ( not retain_graph ), "nn.Graph donot accept 'retain_graph' argument in backward() at the moment." assert ( not create_graph ), "nn.Graph donot accept 'create_graph' argument in backward() at the moment." 
flow._oneflow_internal.nn.graph.AddTensorAsGraphLoss(self) flow.autograd.backward(self, gradient, retain_graph, create_graph) def _str(self): return self.__repr__() def _repr(self): return tensor_str._gen_tensor_str(self) def _meta_repr(self): return tensor_str._gen_tensor_meta_str(self) def _eq(self, other): if self is None and other is None: return True elif self is None or other is None: return False else: return flow._C.broadcast_equal(self, other) def _cuda(self, device: Union[int, str, flow.device] = None): if device is None: device = "cuda" elif isinstance(device, int): device = "cuda:" + str(device) return self.to(device=device) def _norm(self, p=None, dim=None, keepdim=False, dtype=None): if type(p) == str or dim != None: return flow._C.norm(self, p, dim, keepdim, dtype=dtype) return flow._C.norm(self, p, dim, keepdim, dtype=dtype, for_norm=True) def is_nonzero(input): r""" is_nonzero(input) -> (bool) Returns True if the :attr:`input` is a single element tensor which is not equal to zero after type conversions. i.e. not equal to ``flow.tensor([0.])`` or ``flow.tensor([0])``. Throws a ``RuntimeError`` if ``input.shape.numel() != 1`` For Example: .. code-block:: python >>> import oneflow as flow >>> flow.is_nonzero(flow.tensor([0.])) False >>> flow.is_nonzero(flow.tensor([1.5])) True >>> flow.is_nonzero(flow.tensor([3])) True """ shape = input.shape if shape.numel() == 0: raise RuntimeError("bool value of Tensor with no values is ambiguous") if shape.numel() > 1: raise RuntimeError("bool value of Tensor with more than one value is ambiguous") value = input.numpy().item() return bool(value) def _add(self, other, *, alpha=1): return flow._C.add(self, other, alpha=alpha) def _addmm(self, mat1, mat2, alpha=1, beta=1): return flow.addmm(self, mat1, mat2, alpha, beta) def _add_inplace(self, other, *, alpha=1): return flow._C.add(self, other, alpha=alpha, inplace=True) def _iadd(self, other): return self.add_(other) def _sub_inplace(self, other): return flow._C.sub(self, other, inplace=True) def _expand(self, *size): return flow.expand(self, *size) def _expand_as(input, other): return flow.expand(input, *other.size()) def _argwhere(self): return flow.argwhere(self) def _index(self): assert self.numel() == 1 and self.dtype in ( flow.uint8, flow.int8, flow.int32, flow.int64, flow.bool, ), "Only integer tensors of a single element can be converted to an index" return self.numpy().item() def _scalar_float(self): assert ( self.numel() == 1 ), "only one element tensors can be converted to Python scalars" return self.numpy().astype(np.float64).item() def _scalar_int(self): assert ( self.numel() == 1 ), "only one element tensors can be converted to Python scalars" return self.numpy().astype(np.int64).item() def _new_empty( self, *size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False, ): return flow.new_empty(self, size, dtype, device, placement, sbp, requires_grad) def _new_ones( self, *size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False, ): return flow.new_ones(self, size, dtype, device, placement, sbp, requires_grad) def _new_zeros( self, *size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False, ): return flow.new_zeros(self, size, dtype, device, placement, sbp, requires_grad) def _squeeze_inplace(self, dim=None): return flow._C.squeeze_(self, dim=dim) def _unsqueeze_inplace(self, dim=None): return flow._C.unsqueeze_(self, dim=dim) def _new_full( self, size, fill_value, dtype=None, device=None, placement=None, sbp=None, 
requires_grad=False, ): return flow.new_full( self, size, fill_value, dtype, device, placement, sbp, requires_grad ) def _argsort(self, dim=-1, descending=None): return flow.argsort(self, dim=dim, descending=descending) def _uniform(self, a=0, b=1): return flow.nn.init.uniform_(self, a, b) def _exponential(self, lambd=1.0, generator=None): return flow._C.exponential_(self, lambd, generator) def _trunc_normal_( self, mean=0.0, std=1.0, a=-2.0, b=2.0, ): return flow.nn.init.trunc_normal_(self, mean=mean, std=std, a=a, b=b) def _kaiming_uniform( self, a=0, mode="fan_in", nonlinearity="leaky_relu", *, data_format="NCHW" ): return flow.nn.init.kaiming_uniform_( self, a=a, mode=mode, nonlinearity=nonlinearity, data_format=data_format ) def _kaiming_normal( self, a=0, mode="fan_in", nonlinearity="leaky_relu", *, data_format="NCHW" ): return flow.nn.init.kaiming_normal_( self, a=a, mode=mode, nonlinearity=nonlinearity, data_format=data_format ) def _xavier_normal(self, gain=1.0): return flow.nn.init.xavier_normal_(self, gain=gain, data_format=data_format) def _xavier_uniform(self, gain=1.0): return flow.nn.init.xavier_uniform_(self, gain=gain, data_format=data_format) def _orthogonal(self, gain=1.0): if self.ndimension() < 2: raise ValueError("Only tensors with 2 or more dimensions are supported") rows = self.shape[0] cols = np.prod(self.shape[1:]) flattened = np.random.normal(0.0, 1.0, size=(rows, cols)) if rows < cols: flattened = flattened.T # TODO q, r = np.linalg.qr(flattened) d = np.diag(r, 0) d = np.sign(d) q *= d if rows < cols: q = q.T self = gain * flow.tensor(q.reshape(self.shape)) return self def _normal(self, mean=0, std=1): return flow.nn.init.normal_(self, mean=mean, std=std) def _copy_from_numpy_to_eager_local_tensor(eager_local_tensor, np_arr): assert np_arr.dtype == flow.convert_oneflow_dtype_to_numpy_dtype( eager_local_tensor.dtype ) assert np_arr.shape == tuple(eager_local_tensor.shape) eager_local_tensor._copy_from_numpy(np_arr) def _copy(self, other: Union[Tensor, np.ndarray]): if isinstance(other, np.ndarray): other = flow.from_numpy(other) elif not isinstance(other, Tensor): other = flow.tensor(other) other = other.to(self.dtype) if self.is_global: assert other.is_global, "Only global tensor can be assigned to global tensor." if not (self.sbp == other.sbp and self.placement == other.placement): other_cpu_placement = flow.placement("cpu", other.placement.ranks) other = other.to_global(placement=other_cpu_placement) self_cpu_placement = flow.placement("cpu", self.placement.ranks) other = other.to_global(placement=self_cpu_placement, sbp=self.sbp) flow._C.assign_local_tensor(self.to_local(), other.to_local()) else: assert other.is_local, "Only local tensor can be assigned to local tensor." other = flow._C.broadcast_like(other, self) if not self.is_contiguous(): # NOTE: slice_update support non-contiguous input tensor with flow.no_grad(): self[...] 
= other else: flow._C.assign_local_tensor(self, other) def _format(self, format_spec): if self.dim() == 0: return self.numpy().tolist().__format__(format_spec) return object.__format__(self, format_spec) def _to(self, *args, **kwargs): new_args = list() # If device is single int, replace it with flow.device("cuda:{device}") if len(args) > 0 and isinstance(args[0], int): new_args.append(flow.device(f"cuda:{args[0]}")) for i in range(1, len(args)): new_args.append(args[i]) else: new_args = args if ("device" in kwargs) and isinstance(kwargs["device"], int): kwargs["device"] = flow.device(f"cuda:{kwargs['device']}") return flow._C.to(self, *new_args, **kwargs) def _tolist(self): if self.numel() == 1 and self.ndim == 0: return self.item() return self.numpy().tolist() def _repeat(self, *sizes): if len(sizes) == 1: new_sizes = sizes[0] if isinstance(new_sizes, int): new_sizes = (new_sizes,) else: new_sizes = sizes return flow._C.repeat(self, new_sizes) def _tile(self, *dims): if len(dims) == 1: new_dims = dims[0] if isinstance(new_dims, int): new_dims = (new_dims,) else: new_dims = dims return flow._C.tile(self, new_dims) def _T(self): return flow._C.T(self) def _nms(boxes, scores, iou_threshold: float): return flow.nms(boxes, scores, iou_threshold) def _nonzero(self, as_tuple=False): return flow.nonzero(self, as_tuple) def _prod(self, dim=[], keepdim=False): return flow.prod(self, dim, keepdim) def _masked_select(self, mask): return flow.masked_select(self, mask) def _sort(self, dim: int = -1, descending: bool = False): return flow.sort(self, dim, descending) def _where(self, x=None, y=None): return flow.where(self, x, y) def _numpy(self, dtype=None): assert ( not self.is_lazy ), "tensor.numpy() is not allowed to be called in nn.Graph.build(*args) or be called by lazy tensor." if self.is_global: if self.placement.type == "meta": raise TypeError("can't convert meta device type global tensor to numpy.") else: if self.device.type == "meta": raise TypeError("can't convert meta device type local tensor to numpy.") if self.dtype == flow.tensor_buffer: shapes, dtypes = self._tensor_buffer_shapes_and_dtypes tensors = flow.tensor_buffer_to_list_of_tensors(self, shapes, dtypes) return [t.numpy() for t in tensors] # TODO: support bfloat16 to numpy in C++ if self.dtype == flow.bfloat16: self = self.to(flow.float32) if self.is_global: self_cpu_placement = flow.placement("cpu", self.placement.ranks) self = ( self.to_global(placement=self_cpu_placement) .to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast) .to_local() ) assert self.is_local if self.device != flow.device("cpu"): self = self.cpu() result = self.to_numpy() if dtype is None: return result return result.astype(dtype) def zero_(self): self.zero_() return self def _is_consistent(self): raise RuntimeError(".is_consistent has been removed, please use .is_global instead") def _to_consistent(self, *args, **kwargs): raise RuntimeError(".to_consistent has been removed, please use .to_global instead") def _new_tensor( self, data, dtype=None, device=None, requires_grad=False, placement=None, sbp=None ): if dtype is None: dtype = self.dtype if self.is_local: assert ( placement is None and sbp is None ), "self is local tensor, placement and sbp are expected to be None." if device is None: device = self.device return flow.tensor( data, dtype=dtype, device=device, requires_grad=requires_grad ) else: assert device is None, "self is global tensor, device is expected to be None." 
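# Global source tensor: placement and sbp default to those of the source tensor when not supplied.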
if placement is None: placement = self.placement if sbp is None: sbp = self.sbp return flow.tensor( data, dtype=dtype, placement=placement, sbp=sbp, requires_grad=requires_grad ) def _cumsum(self, dim, dtype=None): return flow._C.cumsum(self, dim, dtype=dtype) def _cumprod(self, dim, dtype=None): return flow._C.cumprod(self, dim, dtype=dtype) def _cross(self, other, dim=None): return flow._C.cross(self, other, dim) def _scatter(self, dim, index, src, *, reduce=None): return flow._C.scatter(self, dim, index, src, reduce=reduce, inplace=False) def _scatter_inplace(self, dim, index, src, *, reduce=None): return flow._C.scatter(self, dim, index, src, reduce=reduce, inplace=True) def _scatter_add_inplace(self, dim, index, src): return flow._C.scatter_add(self, dim, index, src, inplace=True) def _contains(self, element): r"""Check if `element` is present in tensor Args: element (Tensor or scalar): element to be checked for presence in current tensor" """ if isinstance(element, (flow.Tensor, Number)): # type hint doesn't understand the __contains__ result array return (element == self).any().item() # type: ignore[union-attr] raise RuntimeError( "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." % type(element) ) def _allclose(self, other, atol=1e-08, rtol=1e-05, equal_nan=False): return flow._C.allclose(self, other, atol, rtol, equal_nan) def _index_add(self, dim, index, source, alpha=1): return flow._C.index_add(self, dim, index, source, alpha) def _index_add_inplace(self, dim, index, source, alpha=1): return flow._C.index_add_(self, dim, index, source, alpha) def _as_strided(self, size, stride, storage_offset=0): return flow._C.as_strided(self, size, stride, storage_offset) def _as_strided_inplace(self, size, stride, storage_offset=0): return flow._C.as_strided_(self, size, stride, storage_offset) def _logaddexp(self, other): return flow._C.logaddexp(self, other) def _real(self): return flow._C.real(self) def _imag(self): return flow._C.imag(self) def _conj(self): return flow._C.conj(self) def _conj_physical(self): return flow._C.conj_physical(self) def _storage(self): return self @property def _layout(self): return flow.strided def RegisterMethods(): Tensor.ndim = property(_ndim) Tensor.numpy = _numpy Tensor.add = _add Tensor.add_ = _add_inplace Tensor.sub_ = _sub_inplace Tensor.backward = _backward Tensor.__str__ = _str Tensor.__repr__ = _repr Tensor.__contains__ = _contains Tensor.__bool__ = is_nonzero Tensor.__iadd__ = _iadd Tensor.addmm = _addmm Tensor.__format__ = _format Tensor.__index__ = _index Tensor.__float__ = _scalar_float Tensor.__int__ = _scalar_int Tensor.__array__ = _numpy Tensor.uniform_ = _uniform Tensor.exponential_ = _exponential Tensor.trunc_normal_ = _trunc_normal_ Tensor.kaiming_uniform_ = _kaiming_uniform Tensor.kaiming_normal_ = _kaiming_normal Tensor.xavier_normal_ = _xavier_normal Tensor.xavier_uniform_ = _xavier_uniform Tensor.orthogonal_ = _orthogonal Tensor.normal_ = _normal Tensor.copy_ = _copy Tensor._meta_repr = _meta_repr Tensor.argsort = _argsort Tensor.argwhere = _argwhere Tensor.expand = _expand Tensor.expand_as = _expand_as Tensor.new_empty = _new_empty Tensor.new_ones = _new_ones Tensor.new_zeros = _new_zeros Tensor.new_full = _new_full Tensor.squeeze_ = _squeeze_inplace Tensor.unsqueeze_ = _unsqueeze_inplace Tensor.where = _where Tensor.norm = _norm Tensor.repeat = _repeat Tensor.tile = _tile Tensor.to = _to Tensor.T = property(_T) Tensor.masked_select = _masked_select Tensor.eq = _eq Tensor.sort = _sort Tensor.tolist = _tolist 
Tensor.nms = _nms Tensor.nonzero = _nonzero Tensor.prod = _prod Tensor.is_consistent = _is_consistent Tensor.to_consistent = _to_consistent Tensor.new_tensor = _new_tensor Tensor.cumsum = _cumsum Tensor.cumprod = _cumprod Tensor.cross = _cross Tensor.scatter = _scatter Tensor.scatter_ = _scatter_inplace Tensor.scatter_add_ = _scatter_add_inplace Tensor.allclose = _allclose Tensor.index_add = _index_add Tensor.index_add_ = _index_add_inplace Tensor.as_strided = _as_strided Tensor.as_strided_ = _as_strided_inplace Tensor.logaddexp = _logaddexp Tensor.real = _real Tensor.imag = _imag Tensor.conj = _conj Tensor.conj_physical = _conj_physical Tensor.layout = _layout Tensor.storage = _storage def register_tensor_op(op_name): def set_tensor_op(method): setattr(Tensor, op_name, method) return method return set_tensor_op ================================================ FILE: python/oneflow/framework/tensor_str.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """ This file is mostly referenced from PyTorch v1.8.1 torch/_tensor_str.py """ import math import numpy as np from typing import Optional import oneflow as flow from oneflow.framework.tensor_str_util import _autoset_linewidth from oneflow.framework.tensor_str_util import _try_convert_to_local_tensor class __PrinterOptions(object): precision: int = 4 threshold: float = 1000 edgeitems: int = 3 userset_linewidth: int = None sci_mode: Optional[bool] = None autoset_linewidth: bool = True @property def linewidth(self): return ( _autoset_linewidth() if self.autoset_linewidth else self.userset_linewidth ) @linewidth.setter def linewidth(self, value): self.userset_linewidth = value PRINT_OPTS = __PrinterOptions() def set_printoptions( precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None, ): r"""Set options for printing. Items shamelessly taken from NumPy Args: precision: Number of digits of precision for floating point output (default = 4). threshold: Total number of array elements which trigger summarization rather than full `repr` (default = 1000). edgeitems: Number of array items in summary at beginning and end of each dimension (default = 3). linewidth: The number of characters per line for the purpose of inserting line breaks (default = terminal_columns). profile: Sane defaults for pretty printing. Can override with any of the above options. (any one of `default`, `short`, `full`) sci_mode: Enable (True) or disable (False) scientific notation. If None (default) is specified, the value is defined by `oneflow._tensor_str._Formatter`. This value is automatically chosen by the framework. .. note:: linewidth equals to terminal columns, manual setting will invalidate the default automatic setting. 
""" if profile is not None: if profile == "default": PRINT_OPTS.precision = 4 PRINT_OPTS.threshold = 1000 PRINT_OPTS.edgeitems = 3 PRINT_OPTS.linewidth = 80 elif profile == "short": PRINT_OPTS.precision = 2 PRINT_OPTS.threshold = 1000 PRINT_OPTS.edgeitems = 2 PRINT_OPTS.linewidth = 80 elif profile == "full": PRINT_OPTS.precision = 4 PRINT_OPTS.threshold = math.inf PRINT_OPTS.edgeitems = 3 PRINT_OPTS.linewidth = 80 if precision is not None: PRINT_OPTS.precision = precision if threshold is not None: PRINT_OPTS.threshold = threshold if edgeitems is not None: PRINT_OPTS.edgeitems = edgeitems if linewidth is not None: PRINT_OPTS.linewidth = linewidth PRINT_OPTS.sci_mode = sci_mode if profile is not None or linewidth is not None: PRINT_OPTS.autoset_linewidth = False class _Formatter(object): def __init__(self, tensor): self.floating_dtype = tensor.dtype.is_floating_point self.int_mode = True self.sci_mode = False self.max_width = 1 self.random_sample_num = 50 tensor = _try_convert_to_local_tensor(tensor) with flow.no_grad(): tensor_view = tensor.reshape(-1) if not self.floating_dtype: for value in tensor_view: value_str = "{}".format(value) self.max_width = max(self.max_width, len(value_str)) else: nonzero_finite_vals = flow.masked_select(tensor_view, tensor_view.ne(0)) if nonzero_finite_vals.numel() == 0: # no valid number, do nothing return nonzero_finite_abs = nonzero_finite_vals.abs() nonzero_finite_min = nonzero_finite_abs.min().numpy().astype(np.float64) nonzero_finite_max = nonzero_finite_abs.max().numpy().astype(np.float64) for value in nonzero_finite_abs.numpy(): if value != np.ceil(value): self.int_mode = False break if self.int_mode: # Check if scientific representation should be used. if ( nonzero_finite_max / nonzero_finite_min > 1000.0 or nonzero_finite_max > 1.0e8 ): self.sci_mode = True for value in nonzero_finite_vals: value_str = ( ("{{:.{}e}}").format(PRINT_OPTS.precision).format(value) ) self.max_width = max(self.max_width, len(value_str)) else: for value in nonzero_finite_vals: value_str = ("{:.0f}").format(value) self.max_width = max(self.max_width, len(value_str) + 1) else: if ( nonzero_finite_max / nonzero_finite_min > 1000.0 or nonzero_finite_max > 1.0e8 or nonzero_finite_min < 1.0e-4 ): self.sci_mode = True for value in nonzero_finite_vals: value_str = ( ("{{:.{}e}}").format(PRINT_OPTS.precision).format(value) ) self.max_width = max(self.max_width, len(value_str)) else: for value in nonzero_finite_vals: value_str = ( ("{{:.{}f}}").format(PRINT_OPTS.precision).format(value) ) self.max_width = max(self.max_width, len(value_str)) if PRINT_OPTS.sci_mode is not None: self.sci_mode = PRINT_OPTS.sci_mode def width(self): return self.max_width def format(self, value): if self.floating_dtype: if self.sci_mode: ret = ( ("{{:{}.{}e}}") .format(self.max_width, PRINT_OPTS.precision) .format(value) ) elif self.int_mode: ret = "{:.0f}".format(value) if not (math.isinf(value) or math.isnan(value)): ret += "." 
else: ret = ("{{:.{}f}}").format(PRINT_OPTS.precision).format(value) else: ret = "{}".format(value) return (self.max_width - len(ret)) * " " + ret def _scalar_str(self, formatter1): return formatter1.format(_try_convert_to_local_tensor(self).tolist()) def _vector_str(self, indent, summarize, formatter1): # length includes spaces and comma between elements element_length = formatter1.width() + 2 elements_per_line = max( 1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))) ) def _val_formatter(val, formatter1=formatter1): return formatter1.format(val) if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: left_values = _try_convert_to_local_tensor( self[: PRINT_OPTS.edgeitems] ).tolist() right_values = _try_convert_to_local_tensor( self[-PRINT_OPTS.edgeitems :] ).tolist() data = ( [_val_formatter(val) for val in left_values] + [" ..."] + [_val_formatter(val) for val in right_values] ) else: values = _try_convert_to_local_tensor(self).tolist() data = [_val_formatter(val) for val in values] data_lines = [ data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line) ] lines = [", ".join(line) for line in data_lines] return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]" def _tensor_str_with_formatter(self, indent, summarize, formatter1): dim = self.dim() if dim == 0: return _scalar_str(self, formatter1) if dim == 1: return _vector_str(self, indent, summarize, formatter1) if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: slices = ( [ _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1,) for i in range(0, PRINT_OPTS.edgeitems) ] + ["..."] + [ _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1,) for i in range(self.shape[0] - PRINT_OPTS.edgeitems, self.shape[0]) ] ) else: slices = [ _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1) for i in range(0, self.size(0)) ] tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices) return "[" + tensor_str + "]" def _tensor_str(self, indent): summarize = self.numel() > PRINT_OPTS.threshold if self.dtype is flow.float16: self = self.float() with flow.no_grad(): formatter = _Formatter(get_summarized_data(self) if summarize else self) return _tensor_str_with_formatter(self, indent, summarize, formatter) def _add_suffixes(tensor_str, suffixes, indent): tensor_strs = [tensor_str] last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1 for suffix in suffixes: suffix_len = len(suffix) if last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth: tensor_strs.append(",\n" + " " * indent + suffix) last_line_len = indent + suffix_len else: tensor_strs.append(", " + suffix) last_line_len += suffix_len + 2 tensor_strs.append(")") return "".join(tensor_strs) def get_summarized_data(self): dim = self.dim() if dim == 0: return self if dim == 1: if self.size(0) > 2 * PRINT_OPTS.edgeitems: return flow.cat( (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :]) ) else: return self if self.size(0) > 2 * PRINT_OPTS.edgeitems: start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)] end = [ self[i] for i in range(self.shape[0] - PRINT_OPTS.edgeitems, self.shape[0]) ] return flow.stack([get_summarized_data(x) for x in (start + end)]) else: return flow.stack([get_summarized_data(x) for x in self]) def _format_tensor_on_cpu(tensor): if tensor.is_global: device = tensor.placement.type else: device = tensor.device.type return device != "cpu" and device != "cuda" def _gen_tensor_str_template(tensor, is_meta): is_meta = is_meta or tensor.is_lazy 
prefix = "tensor(" indent = len(prefix) suffixes = [] meta_device_flag = False # tensor is local or global if tensor.is_global: if tensor.placement.type == "meta": meta_device_flag = True suffixes.append(f"placement={str(tensor.placement)}") suffixes.append(f"sbp={str(tensor.sbp)}") elif tensor.device.type != "cpu": if tensor.device.type == "meta": meta_device_flag = True suffixes.append("device='" + str(tensor.device) + "'") if tensor.is_lazy: suffixes.append("is_lazy='True'") # tensor is empty, meta or normal if tensor.numel() == 0: # Explicitly print the shape if it is not (0,), to match NumPy behavior if tensor.dim() != 1: suffixes.append("size=" + str(tuple(tensor.shape))) tensor_str = "[]" elif is_meta or meta_device_flag: tensor_str = "..." suffixes.append("size=" + str(tuple(tensor.shape))) else: if _format_tensor_on_cpu(tensor): tensor_str = _tensor_str(tensor.detach().to("cpu"), indent) else: tensor_str = _tensor_str(tensor, indent) suffixes.append("dtype=" + str(tensor.dtype)) if tensor.grad_fn is not None: name = tensor.grad_fn.name() suffixes.append("grad_fn=<{}>".format(name)) elif tensor.requires_grad: suffixes.append("requires_grad=True") return _add_suffixes(prefix + tensor_str, suffixes, indent) def _gen_tensor_str(tensor): return _gen_tensor_str_template(tensor, False) def _gen_tensor_meta_str(tensor): # meta return _gen_tensor_str_template(tensor, True) ================================================ FILE: python/oneflow/framework/tensor_str_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import oneflow as flow from typing import Optional, Tuple def _autoset_linewidth(): # os.terminal_size(columns, lines), # columns represents width of the terminal window in characters # and lines represents height of the terminal window in characters. try: linewidth = os.get_terminal_size()[0] except OSError: linewidth = 80 return linewidth def _try_convert_to_local_tensor(tensor): if tensor.is_global: tensor = tensor.to_global( placement=flow.placement.all(tensor.placement.type), sbp=flow.sbp.broadcast, ).to_local() return tensor ================================================ FILE: python/oneflow/framework/tensor_tuple_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import collections from typing import Optional, Sequence, Union from oneflow._oneflow_internal import Tensor, TensorTuple def convert_to_tensor_tuple(args: Optional[Union[Tensor, Sequence[Tensor]]]): if args is None: return TensorTuple() elif isinstance(args, collections.abc.Sequence): return TensorTuple(args) else: tensor_tuple = TensorTuple() tensor_tuple.append(args) return tensor_tuple ================================================ FILE: python/oneflow/framework/type_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow import oneflow as flow from oneflow._C import ( HalfTensor, FloatTensor, DoubleTensor, BoolTensor, ByteTensor, CharTensor, IntTensor, LongTensor, ComplexFloatTensor, ComplexDoubleTensor, ) __all__ = [ "HalfTensor", "FloatTensor", "DoubleTensor", "BoolTensor", "ByteTensor", "CharTensor", "IntTensor", "LongTensor", "ComplexFloatTensor", "ComplexDoubleTensor", # TODO: Add support for BFloat16Tensor, ComplexHalfTensor ] ================================================ FILE: python/oneflow/framework/unittest.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import atexit import imp import os import socket import subprocess import sys import unittest import uuid import doctest from contextlib import closing from tempfile import NamedTemporaryFile from typing import Any, Callable, Dict import google.protobuf.text_format as pbtxt import oneflow import oneflow.env import oneflow.sysconfig from oneflow.core.job.env_pb2 import EnvProto def register_test_cases( scope: Dict[str, Any], directory: str, filter_by_num_nodes: Callable[[bool], int], base_class: unittest.TestCase = unittest.TestCase, ) -> None: def FilterTestPyFile(f): return ( os.path.isfile(os.path.join(directory, f)) and f.endswith(".py") and f.startswith("test") ) def FilterMethodName(module, name): method = getattr(module, name) return ( name.startswith("test") and callable(method) and filter_by_num_nodes(_GetNumOfNodes(method)) ) onlytest_files = [f for f in os.listdir(directory) if FilterTestPyFile(f)] for f in onlytest_files: class_name = f[0:-3] module = imp.load_source(class_name, os.path.join(directory, f)) test_func_names = [ name for name in dir(module) if FilterMethodName(module, name) ] method_dict = {k: getattr(module, k) for k in test_func_names} scope[class_name] = type(class_name, (test_case_mixin, base_class), method_dict) def num_nodes_required(num_nodes: int) -> Callable[[Callable], Callable]: def Decorator(f): f.__oneflow_test_case_num_nodes_required__ = num_nodes return f return Decorator def _GetNumOfNodes(func): if hasattr(func, "__oneflow_test_case_num_nodes_required__") == False: return 1 return getattr(func, "__oneflow_test_case_num_nodes_required__") def eager_execution_enabled(): return os.getenv("ONEFLOW_TEST_ENABLE_EAGER") == "1" def typing_check_enabled(): return os.getenv("ONEFLOW_TEST_ENABLE_TYPING_CHECK") == "1" def node_list(): node_list_str = os.getenv("ONEFLOW_TEST_NODE_LIST") assert node_list_str return node_list_str.split(",") def has_node_list(): if os.getenv("ONEFLOW_TEST_NODE_LIST"): return True else: return False def node_size(): node_num_from_env = os.getenv("ONEFLOW_TEST_NODE_NUM", None) if node_num_from_env: return int(node_num_from_env) elif has_node_list(): node_list_from_env = node_list() return len(node_list_from_env) else: return 1 def has_world_size(): return True def world_size(): return oneflow.env.get_world_size() def device_num(): device_num_str = os.getenv("ONEFLOW_TEST_DEVICE_NUM") if device_num_str: return int(device_num_str) else: return 1 def enable_init_by_host_list(): return os.getenv("ONEFLOW_TEST_ENABLE_INIT_BY_HOST_LIST") == "1" def enable_multi_process(): return os.getenv("ONEFLOW_TEST_MULTI_PROCESS") == "1" def find_free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("localhost", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] _unittest_worker_initilized = False def worker_agent_port(): port_txt = os.getenv("ONEFLOW_TEST_WORKER_AGENT_PORT") if port_txt: return int(port_txt) else: return None def worker_agent_authkey(): key = os.getenv("ONEFLOW_TEST_WORKER_AGENT_AUTHKEY") assert key return key def use_worker_agent(): return worker_agent_port() is not None def cast(conn=None, cmd=None, msg=None): cmd = "cast/" + cmd print("[unittest]", f"[{cmd}]", msg) conn.send(cmd.encode()) conn.send(msg.encode()) def call(conn=None, cmd=None, msg=None): cmd = "call/" + cmd print("[unittest]", f"[{cmd}]", msg) conn.send(cmd.encode()) msg_ = "" if msg is not None: msg_ = msg conn.send(msg_.encode()) return conn.recv().decode() TestCase = unittest.TestCase def 
skip_unless(n, d): if (n > 1 or d > 1) and oneflow.sysconfig.has_rpc_backend_grpc() == False: return unittest.skip( "requires multi node rpc backend when node_size > 1 and device_num > 1" ) if node_size() == n and device_num() == d: return lambda func: func else: return unittest.skip( "only runs when node_size is {} and device_num is {}".format(n, d) ) def skip_unless_1n1d(): return skip_unless(1, 1) def skip_unless_1n2d(): return skip_unless(1, 2) def skip_unless_1n4d(): return skip_unless(1, 4) def skip_unless_2n1d(): return skip_unless(2, 1) def skip_unless_2n2d(): return skip_unless(2, 2) def skip_unless_2n4d(): return skip_unless(2, 4) class CondSkipChecker(doctest.OutputChecker): def __init__(self, check_flags): self._check_flags = check_flags def check_output(self, want, got, optionflags): # default check_output without flag if optionflags == 0: return super(CondSkipChecker, self).check_output(want, got, optionflags) target_rank_list = [bool(flag & optionflags) for flag in self._check_flags] # wrong flag will be handled before here, so any(target_rank_list) is True # not target rank if target_rank_list.index(True) != oneflow.env.get_rank(): return True elif target_rank_list.index(True) == oneflow.env.get_rank(): return super(CondSkipChecker, self).check_output(want, got, optionflags) def check_multi_rank_docstr(module): # supply customized flag ONLY_CHECK_RANK_{x} for docstr check_flags = [ doctest.register_optionflag(f"ONLY_CHECK_RANK_{i}") for i in range(oneflow.env.get_world_size()) ] finder = doctest.DocTestFinder() runner = doctest.DebugRunner(CondSkipChecker(check_flags)) for test in finder.find(module, module.__name__): runner.run(test) ================================================ FILE: python/oneflow/fx/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ try: from onefx import * except: class Proxy: def __init__(self): raise NotImplementedError( "oneflow.fx.Proxy is only for compatibility with PyTorch and is not actually implemented." ) ================================================ FILE: python/oneflow/hub.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # This file was copyed from https://github.com/pytorch/pytorch/blob/master/torch/hub.py and consistent with oneflow. 
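# A minimal usage sketch of this module (hedged: the repository and entrypoint
# names below are hypothetical, and the target repo must expose a hubconf.py):
#
#   import oneflow as flow
#   entrypoints = flow.hub.list("owner/repo")            # enumerate hubconf entrypoints
#   model = flow.hub.load("owner/repo", "model_name")    # download the repo and build the model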
import errno import hashlib import json import os import re import shutil import sys import tempfile import oneflow as flow import warnings import zipfile from pathlib import Path from typing import Dict, Optional, Any from urllib.error import HTTPError from urllib.request import urlopen, Request from urllib.parse import urlparse # noqa: F401 try: from tqdm.auto import ( tqdm, ) # automatically select proper tqdm submodule if available except ImportError: try: from tqdm import tqdm except ImportError: # fake tqdm if it's not installed class tqdm(object): # type: ignore[no-redef] def __init__( self, total=None, disable=False, unit=None, unit_scale=None, unit_divisor=None, ): self.total = total self.disable = disable self.n = 0 # ignore unit, unit_scale, unit_divisor; they're just for real tqdm def update(self, n): if self.disable: return self.n += n if self.total is None: sys.stderr.write("\r{0:.1f} bytes".format(self.n)) else: sys.stderr.write( "\r{0:.1f}%".format(100 * self.n / float(self.total)) ) sys.stderr.flush() def close(self): self.disable = True def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): if self.disable: return sys.stderr.write("\n") __all__ = [ "download_url_to_file", "get_dir", "help", "list", "load", "load_state_dict_from_url", "set_dir", ] # matches bfd8deac from resnet18-bfd8deac.pth HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") _TRUSTED_REPO_OWNERS = "oneflow" ENV_GITHUB_TOKEN = "GITHUB_TOKEN" ENV_ONEFLOW_HOME = "ONEFLOW_HOME" ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" DEFAULT_CACHE_DIR = "~/.cache" VAR_DEPENDENCY = "dependencies" MODULE_HUBCONF = "hubconf.py" READ_DATA_CHUNK = 8192 _hub_dir = None # Copied from tools/shared/module_loader to be included in oneflow package def _import_module(name, path): import importlib.util from importlib.abc import Loader spec = importlib.util.spec_from_file_location(name, path) assert spec is not None module = importlib.util.module_from_spec(spec) assert isinstance(spec.loader, Loader) spec.loader.exec_module(module) return module def _remove_if_exists(path): if os.path.exists(path): if os.path.isfile(path): os.remove(path) else: shutil.rmtree(path) def _git_archive_link(repo_owner, repo_name, ref): # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}" def _load_attr_from_module(module, func_name): # Check if callable is defined in the module if func_name not in dir(module): return None return getattr(module, func_name) def _get_oneflow_home(): oneflow_home = os.path.expanduser( os.getenv( ENV_ONEFLOW_HOME, os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "oneflow"), ) ) return oneflow_home _get_torch_home = _get_oneflow_home def _parse_repo_info(github): if ":" in github: repo_info, ref = github.split(":") else: repo_info, ref = github, None repo_owner, repo_name = repo_info.split("/") if ref is None: # The ref wasn't specified by the user, so we need to figure out the # default branch: main or master. Our assumption is that if main exists # then it's the default branch, otherwise it's master. try: with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"): ref = "main" except HTTPError as e: if e.code == 404: ref = "master" else: raise return repo_owner, repo_name, ref def _read_url(url): with urlopen(url) as r: return r.read().decode(r.headers.get_content_charset("utf-8")) def _validate_not_a_forked_repo(repo_owner, repo_name, ref): # Use urlopen to avoid depending on local git. 
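For reference, the ref parsing above is purely string based when an explicit ref is given; only the no-ref case issues a network request to decide between main and master. A small illustrative check of these private helpers (assuming this module is importable as oneflow.hub; the repo string follows the docstring examples below):

from oneflow.hub import _git_archive_link, _parse_repo_info

# With an explicit ref, no GitHub request is made.
owner, name, ref = _parse_repo_info("Oneflow-Inc/vision:0.2.0")
assert (owner, name, ref) == ("Oneflow-Inc", "vision", "0.2.0")

# The archive URL used for the actual download:
print(_git_archive_link(owner, name, ref))
# https://github.com/Oneflow-Inc/vision/zipball/0.2.0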
headers = {"Accept": "application/vnd.github.v3+json"} token = os.environ.get(ENV_GITHUB_TOKEN) if token is not None: headers["Authorization"] = f"token {token}" for url_prefix in ( f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches", f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags", ): page = 0 while True: page += 1 url = f"{url_prefix}?per_page=100&page={page}" response = json.loads(_read_url(Request(url, headers=headers))) # Empty response means no more data to process if not response: break for br in response: if br["name"] == ref or br["commit"]["sha"].startswith(ref): return raise ValueError( f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. " "If it's a commit from a forked repo, please call hub.load() with forked repo directly." ) def _get_cache_or_reload( github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False ): # Setup hub_dir to save downloaded files hub_dir = get_dir() if not os.path.exists(hub_dir): os.makedirs(hub_dir) # Parse github repo information repo_owner, repo_name, ref = _parse_repo_info(github) # Github allows branch name with slash '/', # this causes confusion with path on both Linux and Windows. # Backslash is not allowed in Github branch name so no need to # to worry about it. normalized_br = ref.replace("/", "_") # Github renames folder repo-v1.x.x to repo-1.x.x # We don't know the repo name before downloading the zip file # and inspect name from it. # To check if cached repo exists, we need to normalize folder names. owner_name_branch = "_".join([repo_owner, repo_name, normalized_br]) repo_dir = os.path.join(hub_dir, owner_name_branch) # Check that the repo is in the trusted list _check_repo_is_trusted( repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn, ) use_cache = (not force_reload) and os.path.exists(repo_dir) if use_cache: if verbose: sys.stderr.write("Using cache found in {}\n".format(repo_dir)) else: # Validate the tag/branch is from the original repo instead of a forked repo if not skip_validation: _validate_not_a_forked_repo(repo_owner, repo_name, ref) cached_file = os.path.join(hub_dir, normalized_br + ".zip") _remove_if_exists(cached_file) try: url = _git_archive_link(repo_owner, repo_name, ref) sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) download_url_to_file(url, cached_file, progress=False) except HTTPError as err: if err.code == 300: # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch # in the repo. This can be disambiguated by explicitely using refs/heads/ or refs/tags # See https://git-scm.com/book/en/v2/Git-Internals-Git-References # Here, we do the same as git: we throw a warning, and assume the user wanted the branch warnings.warn( f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? " "OneFlowhub will now assume that it's a branch. " "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or " "refs/tags/tag_name as the ref. That might require using skip_validation=True." 
) disambiguated_branch_ref = f"refs/heads/{ref}" url = _git_archive_link( repo_owner, repo_name, ref=disambiguated_branch_ref ) download_url_to_file(url, cached_file, progress=False) else: raise with zipfile.ZipFile(cached_file) as cached_zipfile: extraced_repo_name = cached_zipfile.infolist()[0].filename extracted_repo = os.path.join(hub_dir, extraced_repo_name) _remove_if_exists(extracted_repo) # Unzip the code and rename the base folder cached_zipfile.extractall(hub_dir) _remove_if_exists(cached_file) _remove_if_exists(repo_dir) shutil.move(extracted_repo, repo_dir) # rename the repo return repo_dir def _check_repo_is_trusted( repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load" ): hub_dir = get_dir() filepath = os.path.join(hub_dir, "trusted_list") if not os.path.exists(filepath): Path(filepath).touch() with open(filepath, "r") as file: trusted_repos = tuple(line.strip() for line in file) # To minimize friction of introducing the new trust_repo mechanism, we consider that # if a repo was already downloaded by oneflowhub, then it is already trusted (even if it's not in the allowlist) trusted_repos_legacy = next(os.walk(hub_dir))[1] owner_name = "_".join([repo_owner, repo_name]) is_trusted = ( owner_name in trusted_repos or owner_name_branch in trusted_repos_legacy or repo_owner in _TRUSTED_REPO_OWNERS ) # TODO: Remove `None` option in 1.14 and change the default to "check" if trust_repo is None: if not is_trusted: warnings.warn( "You are about to download and run code from an untrusted repository. In a future release, this won't " "be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., " "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, " f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with " f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for " f"confirmation if the repo is not already trusted. This will eventually be the default behaviour" ) return if (trust_repo is False) or (trust_repo == "check" and not is_trusted): response = input( f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. " "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?" ) if response.lower() in ("y", "yes"): if is_trusted: print("The repository is already trusted.") elif response.lower() in ("n", "no", ""): raise Exception("Untrusted repository.") else: raise ValueError(f"Unrecognized response {response}.") # At this point we're sure that the user trusts the repo (or wants to trust it) if not is_trusted: with open(filepath, "a") as file: file.write(owner_name + "\n") def _check_module_exists(name): import importlib.util return importlib.util.find_spec(name) is not None def _check_dependencies(m): dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) if dependencies is not None: missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] if len(missing_deps): raise RuntimeError( "Missing dependencies: {}".format(", ".join(missing_deps)) ) def _load_entry_from_hubconf(m, model): if not isinstance(model, str): raise ValueError("Invalid input: model should be a string of function name") # Note that if a missing dependency is imported at top level of hubconf, it will # throw before this function. 
It's a chicken and egg situation where we have to # load hubconf to know what're the dependencies, but to import hubconf it requires # a missing package. This is fine, Python will throw proper error message for users. _check_dependencies(m) func = _load_attr_from_module(m, model) if func is None or not callable(func): raise RuntimeError("Cannot find callable {} in hubconf".format(model)) return func def get_dir(): """ Get the OneFlow Hub cache directory used for storing downloaded models & weights. If :func:`~oneflow.hub.set_dir` is not called, default path is ``$ONEFLOW_HOME/hub`` where environment variable ``$ONEFLOW_HOME`` defaults to ``$XDG_CACHE_HOME/oneflow``. ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux filesystem layout, with a default value ``~/.cache`` if the environment variable is not set. """ # Issue warning to move data if old env is set if os.getenv("ONEFLOW_HUB"): warnings.warn("ONEFLOW_HUB is deprecated, please use env ONEFLOW_HOME instead") if _hub_dir is not None: return _hub_dir return os.path.join(_get_oneflow_home(), "hub") def set_dir(d): """ Optionally set the OneFlow Hub directory used to save downloaded models & weights. Args: d (str): path to a local folder to save downloaded models & weights. """ global _hub_dir _hub_dir = os.path.expanduser(d) def list(github, force_reload=False, skip_validation=False, trust_repo=None): """ List all callable entrypoints available in the repo specified by ``github``. Args: github (str): a string with format "repo_owner/repo_name[:ref]" with an optional ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. Example: ' Oneflow-Inc/vision:0.2.0' force_reload (bool, optional): whether to discard the existing cache and force a fresh download. Default is ``False``. skip_validation (bool, optional): if ``False``, oneflowhub will check that the branch or commit specified by the ``github`` argument properly belongs to the repo owner. This will make requests to the GitHub API; you can specify a non-default GitHub token by setting the ``GITHUB_TOKEN`` environment variable. Default is ``False``. trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. This parameter was introduced in v1.12 and helps ensuring that users only run code from repos that they trust. - If ``False``, a prompt will ask the user whether the repo should be trusted. - If ``True``, the repo will be added to the trusted list and loaded without requiring explicit confirmation. - If ``"check"``, the repo will be checked against the list of trusted repos in the cache. If it is not present in that list, the behaviour will fall back onto the ``trust_repo=False`` option. - If ``None``, this will raise a warning, inviting the user to set ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This is only present for backward compatibility and will be removed in v1.14. Default is ``None`` and will eventually change to ``"check"`` in v1.14. 
Returns: list: The available callables entrypoint For example: >>> entrypoints = oneflow.hub.list('Oneflow-Inc/vision', force_reload=True) """ repo_dir = _get_cache_or_reload( github, force_reload, trust_repo, "list", verbose=True, skip_validation=skip_validation, ) sys.path.insert(0, repo_dir) hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) hub_module = _import_module(MODULE_HUBCONF, hubconf_path) sys.path.remove(repo_dir) # We take functions starts with '_' as internal helper functions entrypoints = [ f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith("_") ] return entrypoints def help(github, model, force_reload=False, skip_validation=False, trust_repo=None): """ Show the docstring of entrypoint ``model``. Args: github (str): a string with format with an optional ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. Example: 'Oneflow-Inc/vision:0.2.0' model (str): a string of entrypoint name defined in repo's ``hubconf.py`` force_reload (bool, optional): whether to discard the existing cache and force a fresh download. Default is ``False``. skip_validation (bool, optional): if ``False``, oneflowhub will check that the ref specified by the ``github`` argument properly belongs to the repo owner. This will make requests to the GitHub API; you can specify a non-default GitHub token by setting the ``GITHUB_TOKEN`` environment variable. Default is ``False``. trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. This parameter was introduced in v1.12 and helps ensuring that users only run code from repos that they trust. - If ``False``, a prompt will ask the user whether the repo should be trusted. - If ``True``, the repo will be added to the trusted list and loaded without requiring explicit confirmation. - If ``"check"``, the repo will be checked against the list of trusted repos in the cache. If it is not present in that list, the behaviour will fall back onto the ``trust_repo=False`` option. - If ``None``: this will raise a warning, inviting the user to set ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This is only present for backward compatibility and will be removed in v1.14. Default is ``None`` and will eventually change to ``"check"`` in v1.14. For example: >>> print(oneflow.hub.help('Oneflow-Inc/vision', 'resnet18', force_reload=True)) """ repo_dir = _get_cache_or_reload( github, force_reload, trust_repo, "help", verbose=True, skip_validation=skip_validation, ) sys.path.insert(0, repo_dir) hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) hub_module = _import_module(MODULE_HUBCONF, hubconf_path) sys.path.remove(repo_dir) entry = _load_entry_from_hubconf(hub_module, model) return entry.__doc__ def load( repo_or_dir, model, *args, source="github", trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs, ): """ Load a model from a github repo or a local directory. Note: Loading a model is the typical use case, but this can also be used to for loading other objects such as tokenizers, loss functions, etc. If ``source`` is 'github', ``repo_or_dir`` is expected to be of the form ``repo_owner/repo_name[:ref]`` with an optional ref (a tag or a branch). If ``source`` is 'local', ``repo_or_dir`` is expected to be a path to a local directory. 
Args: repo_or_dir (str): If ``source`` is 'github', this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with an optional ref (tag or branch), for example 'Oneflow-Inc/vision:0.2.0'. If ``ref`` is not specified, the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. If ``source`` is 'local' then it should be a path to a local directory. model (str): the name of a callable (entrypoint) defined in the repo/dir's ``hubconf.py``. *args (optional): the corresponding args for callable ``model``. source (str, optional): 'github' or 'local'. Specifies how ``repo_or_dir`` is to be interpreted. Default is 'github'. trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. This parameter was introduced in v1.12 and helps ensuring that users only run code from repos that they trust. - If ``False``, a prompt will ask the user whether the repo should be trusted. - If ``True``, the repo will be added to the trusted list and loaded without requiring explicit confirmation. - If ``"check"``, the repo will be checked against the list of trusted repos in the cache. If it is not present in that list, the behaviour will fall back onto the ``trust_repo=False`` option. - If ``None``: this will raise a warning, inviting the user to set ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This is only present for backward compatibility and will be removed in v1.14. Default is ``None`` and will eventually change to ``"check"`` in v1.14. force_reload (bool, optional): whether to force a fresh download of the github repo unconditionally. Does not have any effect if ``source = 'local'``. Default is ``False``. verbose (bool, optional): If ``False``, mute messages about hitting local caches. Note that the message about first download cannot be muted. Does not have any effect if ``source = 'local'``. Default is ``True``. skip_validation (bool, optional): if ``False``, oneflowhub will check that the branch or commit specified by the ``github`` argument properly belongs to the repo owner. This will make requests to the GitHub API; you can specify a non-default GitHub token by setting the ``GITHUB_TOKEN`` environment variable. Default is ``False``. **kwargs (optional): the corresponding kwargs for callable ``model``. Returns: The output of the ``model`` callable when called with the given ``*args`` and ``**kwargs``. For example: >>> # from a github repo >>> repo = 'Oneflow-Inc/vision' >>> model = oneflow.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1') >>> # from a local directory >>> path = '/some/local/path/oneflow/vision' >>> # xdoctest: +SKIP >>> model = oneflow.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT') """ source = source.lower() if source not in ("github", "local"): raise ValueError( f'Unknown source: "{source}". Allowed values: "github" | "local".' ) if source == "github": repo_or_dir = _get_cache_or_reload( repo_or_dir, force_reload, trust_repo, "load", verbose=verbose, skip_validation=skip_validation, ) model = _load_local(repo_or_dir, model, *args, **kwargs) return model def _load_local(hubconf_dir, model, *args, **kwargs): """ Load a model from a local directory with a ``hubconf.py``. Args: hubconf_dir (str): path to a local directory that contains a ``hubconf.py``. model (str): name of an entrypoint defined in the directory's ``hubconf.py``. *args (optional): the corresponding args for callable ``model``. **kwargs (optional): the corresponding kwargs for callable ``model``. 
Returns: a single model with corresponding pretrained weights. For example: >>> # xdoctest: +SKIP("stub local path") >>> path = '/some/local/path/oneflow/vision' >>> model = _load_local(path, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1') """ sys.path.insert(0, hubconf_dir) hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) hub_module = _import_module(MODULE_HUBCONF, hubconf_path) entry = _load_entry_from_hubconf(hub_module, model) model = entry(*args, **kwargs) sys.path.remove(hubconf_dir) return model def download_url_to_file(url, dst, hash_prefix=None, progress=True): """Download object at the given URL to a local path. Args: url (str): URL of the object to download dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file`` hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. Default: None progress (bool, optional): whether or not to display a progress bar to stderr Default: True For example: >>> # xdoctest: +REQUIRES(POSIX) >>> oneflow.hub.download_url_to_file('https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip', '/tmp/temporary_file') """ file_size = None req = Request(url, headers={"User-Agent": "oneflow.hub"}) u = urlopen(req) meta = u.info() if hasattr(meta, "getheaders"): content_length = meta.getheaders("Content-Length") else: content_length = meta.get_all("Content-Length") if content_length is not None and len(content_length) > 0: file_size = int(content_length[0]) # We deliberately save it in a temp file and move it after # download is complete. This prevents a local working checkpoint # being overridden by a broken download. dst = os.path.expanduser(dst) dst_dir = os.path.dirname(dst) f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) try: if hash_prefix is not None: sha256 = hashlib.sha256() with tqdm( total=file_size, disable=not progress, unit="B", unit_scale=True, unit_divisor=1024, ) as pbar: while True: buffer = u.read(8192) if len(buffer) == 0: break f.write(buffer) if hash_prefix is not None: sha256.update(buffer) pbar.update(len(buffer)) f.close() if hash_prefix is not None: digest = sha256.hexdigest() if digest[: len(hash_prefix)] != hash_prefix: raise RuntimeError( 'invalid hash value (expected "{}", got "{}")'.format( hash_prefix, digest ) ) shutil.move(f.name, dst) finally: f.close() if os.path.exists(f.name): os.remove(f.name) # Hub used to support automatically extracts from zipfile manually compressed by users. # We should remove this support since zipfile is now default zipfile format for oneflow.save(). def _is_legacy_zip_format(filename): if zipfile.is_zipfile(filename): return True else: return False def _legacy_zip_load(filename, model_dir, map_location): # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. # We deliberately don't handle tarfile here since our legacy serialization format was in tar. # E.g. resnet18-5c106cde.pth which is widely used. with zipfile.ZipFile(filename) as f: members = f.infolist() f.extractall(model_dir) extraced_name = members[0].filename extracted_file = os.path.join(model_dir, extraced_name) return flow.load(extracted_file, map_location=map_location) def load_state_dict_from_url( url: str, model_dir: Optional[str] = None, map_location=None, progress: bool = True, check_hash: bool = False, file_name: Optional[str] = None, ) -> Dict[str, Any]: """Loads the OneFlow serialized object at the given URL. 
If downloaded file is a zip file, it will be automatically decompressed. If the object is already present in `model_dir`, it's deserialized and returned. The default value of ``model_dir`` is ``/checkpoints`` where ``hub_dir`` is the directory returned by :func:`~oneflow.hub.get_dir`. Args: url (str): URL of the object to download model_dir (str, optional): directory in which to save the object map_location (optional): a function or a dict specifying how to remap storage locations (see oneflow.load) progress (bool, optional): whether or not to display a progress bar to stderr. Default: True check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention ``filename-.ext`` where ```` is the first eight or more digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set. For example: >>> state_dict = oneflow.hub.load_state_dict_from_url('https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip') """ # Issue warning to move data if old env is set if os.getenv("ONEFLOW_MODEL_ZOO"): warnings.warn( "ONEFLOW_MODEL_ZOO is deprecated, please use env ONEFLOW_HOME instead" ) if model_dir is None: hub_dir = get_dir() model_dir = os.path.join(hub_dir, "checkpoints") try: os.makedirs(model_dir) except OSError as e: if e.errno == errno.EEXIST: # Directory already exists, ignore. pass else: # Unexpected OSError, re-raise. raise parts = urlparse(url) filename = os.path.basename(parts.path) if file_name is not None: filename = file_name cached_file = os.path.join(model_dir, filename) if not os.path.exists(cached_file): sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) hash_prefix = None if check_hash: r = HASH_REGEX.search(filename) # r is Optional[Match[str]] hash_prefix = r.group(1) if r else None download_url_to_file(url, cached_file, hash_prefix, progress=progress) if _is_legacy_zip_format(cached_file): return _legacy_zip_load(cached_file, model_dir, map_location) return flow.load(cached_file, map_location=map_location) ================================================ FILE: python/oneflow/ir/__main__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import argparse import oneflow parser = argparse.ArgumentParser() parser.add_argument("--gen_ods", default=False, action="store_true", required=False) args = parser.parse_args() if __name__ == "__main__": oneflow._oneflow_internal.ir.gen_ods() ================================================ FILE: python/oneflow/ir/ast_gen_transformer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow import ast class ASTTransformer(ast.NodeTransformer): def visit_arg(self, node: ast.arg): node.ast = oneflow._oneflow_internal.ir.arg_(node.arg) return node def visit_arguments(self, node: ast.arguments): for arg in node.args: self.visit(arg) list = [arg.ast for arg in node.args] node.ast = oneflow._oneflow_internal.ir.arguments_(list) return node def visit_FunctionDef(self, node: ast.FunctionDef): for arg in node.body: self.visit(arg) body = [arg.ast for arg in node.body] self.visit(node.args) node.ast = oneflow._oneflow_internal.ir.FunctionDef_( "get_lr", node.args.ast, body ) return node def visit_Return(self, node: ast.Return): self.visit(node.value) node.ast = oneflow._oneflow_internal.ir.Return_(node.value.ast) return node def visit_Assign(self, node: ast.Assign): self.visit(node.value) for arg in node.targets: self.visit(arg) targets = [arg.ast for arg in node.targets] node.ast = oneflow._oneflow_internal.ir.Assign_(targets, node.value.ast) return node def visit_If(self, node: ast.If): self.visit(node.test) for arg in node.body: self.visit(arg) if node.orelse: for arg in node.orelse: self.visit(arg) test = node.test.ast body = [arg.ast for arg in node.body] orelse = [arg.ast for arg in node.orelse] node.ast = oneflow._oneflow_internal.ir.If_(test, body, orelse) return node def visit_Raise(self, node: ast.Raise): print(ast.dump(node)) raise NotImplementedError("not supported yet") def visit_Assert(self, node: ast.Assert): print(ast.dump(node)) raise NotImplementedError("not supported yet") def visit_Expr(self, node: ast.Expr): print(ast.dump(node)) raise NotImplementedError("not supported yet") def visit_BoolOp(self, node: ast.BoolOp): print(ast.dump(node)) raise NotImplementedError("not supported yet") def visit_BinOp(self, node: ast.BinOp): self.visit(node.left) self.visit(node.right) left = node.left.ast right = node.right.ast def get_op(op: ast.operator): list = [ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow] res = 1 for elem in list: if isinstance(op, elem): return res res += 1 op = get_op(node.op) node.ast = oneflow._oneflow_internal.ir.BinOp_(left, op, right) return node def visit_Lambda(self, node: ast.Lambda): print(ast.dump(node)) raise NotImplementedError("not supported yet") def visit_Compare(self, node: ast.Compare): self.visit(node.left) for arg in node.comparators: self.visit(arg) left = node.left.ast comparators = [arg.ast for arg in node.comparators] def get_op(op: ast.operator): list = [ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE] res = 1 for elem in list: if isinstance(op, elem): return res res += 1 ops = [get_op(arg) for arg in node.ops] node.ast = oneflow._oneflow_internal.ir.Compare_(left, ops, comparators) return node def visit_Call(self, node: ast.Call): self.visit(node.func) for arg in node.args: self.visit(arg) func = node.func.ast args = [arg.ast for arg in node.args] node.ast = oneflow._oneflow_internal.ir.Call_(func, args) return node def visit_Constant(self, node: ast.Constant): node.ast = oneflow._oneflow_internal.ir.Constant_(node.value) return node def visit_Num(self, node: ast.Num): node.ast = oneflow._oneflow_internal.ir.Num_(node.value) return node def visit_Attribute(self, node: ast.Attribute): self.visit(node.value) value =
node.value.ast node.ast = oneflow._oneflow_internal.ir.Attribute_(value, node.attr) return node def visit_Name(self, node: ast.Name): node.ast = oneflow._oneflow_internal.ir.Name_(node.id) return node ================================================ FILE: python/oneflow/ir/bisect_transformer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import ast from bisect import bisect class BisectTransformer(ast.NodeTransformer): def visit_FunctionDef(self, node: ast.FunctionDef): self.body_index = 0 self.body = node.body for stmt in node.body: self.visit(stmt) self.body_index += 1 return node def visit_Call(self, node: ast.Call): if isinstance(node.func, ast.Attribute): func: ast.Attribute = node.func if func.value.id == "bisect": bisect_x_list = ["bisect_right", "bisect_left"] if func.attr in bisect_x_list: op = ast.LtE if func.attr == "bisect_right": op = ast.Lt if not isinstance(node.args[0], ast.List): raise "only support bisect.bisect_right(list, x)" ls = node.args[0].elts cmp = node.args[1] index = 0 for i in ls[::-1]: test = ast.Compare(cmp, [op()], [i]) assign = ast.Assign( [ast.Name("tmp")], ast.Constant(len(ls) - index - 1, None) ) if "orelse" in locals(): orelse = ast.If(test, [assign], [orelse]) else: orelse = ast.If(test, [assign], []) index += 1 self.body.insert(self.body_index, orelse) return ast.Name("tmp") return node ================================================ FILE: python/oneflow/ir/lr_jit.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import ast import textwrap import inspect import oneflow import unittest import oneflow.unittest from ast_gen_transformer import ASTTransformer from math_params_transformer import MathParamsTransformer from self_params_transformer import SelfParamsTransformer from bisect_transformer import BisectTransformer def lr_jit_register(lr_obj, is_dump=False): _id = lr_obj.__class__.__name__ # load source txt _src = textwrap.dedent(inspect.getsource(lr_obj.get_lr)) _ast = ast.parse(_src).body[0] # transform param self transformer = SelfParamsTransformer(lr_obj) transformer.visit(_ast) # transform for bisect lib transformer = BisectTransformer() transformer.visit(_ast) # transform for math lib transformer = MathParamsTransformer() transformer.visit(_ast) # feed transformed as to C++ transformer = ASTTransformer() transformer.visit(_ast) oneflow._oneflow_internal.ir.compile_and_register_lr_jit(_id, _ast.ast, is_dump) return _id def _test_current_lr_jit(test_case): from oneflow.nn.optimizer.constant_lr import ConstantLR from oneflow.nn.optimizer.cosine_annealing_lr import CosineAnnealingLR from oneflow.nn.optimizer.cosine_decay_lr import CosineDecayLR from oneflow.nn.optimizer.exponential_lr import ExponentialLR from oneflow.nn.optimizer.lambda_lr import LambdaLR from oneflow.nn.optimizer.linear_lr import LinearLR from oneflow.nn.optimizer.multistep_lr import MultiStepLR from oneflow.nn.optimizer.polynomial_lr import PolynomialLR from oneflow.nn.optimizer.sequential_lr import SequentialLR from oneflow.nn.optimizer.step_lr import StepLR from oneflow.nn.optimizer.warmup_lr import WarmupLR from oneflow.optim import SGD from oneflow.nn import Parameter import numpy as np param = Parameter(oneflow.ones(3, 4)) optimizer = SGD([param], lr=0.001) lr_jit = oneflow._oneflow_internal.ir.create_global_lr_jit() lr_obj_list = [ # WarmupLR(optimizer), StepLR(optimizer, 5), # SequentialLR(optimizer), PolynomialLR(optimizer, 5), MultiStepLR(optimizer, [10, 20, 30]), LinearLR(optimizer), # LambdaLR(optimizer, [lambda step: 0.95 * step]), ExponentialLR(optimizer, 1.1), CosineDecayLR(optimizer, 10), CosineAnnealingLR(optimizer, 50), ConstantLR(optimizer), ] for lr_obj in lr_obj_list: id_ = lr_jit_register(lr_obj, False) ls = [[0.005, 5], [0.01, 10], [0.02, 21]] for elem in ls: base_lr = elem[0] step = elem[1] lr = lr_obj.get_lr(base_lr, step) lr_jit = oneflow._oneflow_internal.ir.get_lr(id_, base_lr, step) test_case.assertTrue(np.isclose(lr, lr_jit)) @oneflow.unittest.skip_unless_1n1d() class TestCurrentLRJIT(oneflow.unittest.MLIRTestCase): def test_current_lr_jit(test_case): _test_current_lr_jit(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/ir/math_params_transformer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import ast class MathParamsTransformer(ast.NodeTransformer): def visit_Attribute(self, node): import math list = ["pi"] if node.value.id == "math": if node.attr in list: _name = node.attr _attr = getattr(math, _name) return ast.Constant(_attr, None) return node ================================================ FILE: python/oneflow/ir/self_params_transformer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import ast class SelfParamsTransformer(ast.NodeTransformer): def __init__(self, lr_obj): super().__init__() self.lr_obj = lr_obj def visit_Attribute(self, node): if node.value.id == "self": _name = node.attr _attr = getattr(self.lr_obj, _name) if isinstance(_attr, list): ls = [ast.Constant(elem, None) for elem in _attr] return ast.List(ls) return ast.Constant(_attr, None) return node def visit_arguments(self, node: ast.arguments): for index, item in enumerate(node.args): if item.arg == "self": node.args.pop(index) return node ================================================ FILE: python/oneflow/jit/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings from typing import Any, Dict, List, Set, Tuple, Union, Callable def script( obj, optimize=None, _frames_up=0, _rcb=None, example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None, ): warnings.warn( "The oneflow.jit.script interface is just to align the torch.jit.script interface and has no practical significance." ) return obj def ignore(drop=False, **kwargs): warnings.warn( "The oneflow.jit.ignore interface is just to align the torch.jit.ignore interface and has no practical significance." ) def decorator(fn): return fn return decorator def unused(fn): warnings.warn( "The oneflow.jit.unused interface is just to align the torch.jit.unused interface and has no practical significance." ) return fn def _script_if_tracing(fn): warnings.warn( "The oneflow.jit._script_if_tracing interface is just to align the torch.jit._script_if_tracing interface and has no practical significance." ) return fn def _overload_method(fn): warnings.warn( "The oneflow.jit._overload_method interface is just to align the torch.jit._overload_method interface and has no practical significance." 
) return fn def is_scripting(): return False def is_tracing(): return False class _Final: """Mixin to prohibit subclassing""" __slots__ = ("__weakref__",) def __init_subclass__(self, *args, **kwds): if "_root" not in kwds: raise TypeError("Cannot subclass special typing classes") class _SpecialForm(_Final, _root=True): __slots__ = ("_name", "__doc__", "_getitem") def __init__(self, getitem): self._getitem = getitem self._name = getitem.__name__ self.__doc__ = getitem.__doc__ def __getattr__(self, item): if item in {"__name__", "__qualname__"}: return self._name raise AttributeError(item) def __mro_entries__(self, bases): raise TypeError(f"Cannot subclass {self!r}") def __repr__(self): return "typing." + self._name def __reduce__(self): return self._name def __call__(self, *args, **kwds): raise TypeError(f"Cannot instantiate {self!r}") def __or__(self, other): return Union[self, other] def __ror__(self, other): return Union[other, self] def __instancecheck__(self, obj): raise TypeError(f"{self} cannot be used with isinstance()") def __subclasscheck__(self, cls): raise TypeError(f"{self} cannot be used with issubclass()") def __getitem__(self, parameters): return self._getitem(self, parameters) @_SpecialForm def Final(*args, **kwargs): warnings.warn( "The oneflow.jit.Final interface is just to align the torch.jit.Final interface and has no practical significance." ) def interface(fn): warnings.warn( "The oneflow.jit.interface interface is just to align the torch.jit.interface interface and has no practical significance." ) return fn ================================================ FILE: python/oneflow/jit/annotations.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Tuple, List BroadcastingList2 = Tuple ================================================ FILE: python/oneflow/library.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings warnings.warn( "The oneflow.library interface is just to align the torch.library interface and has no practical significance." ) ================================================ FILE: python/oneflow/linalg.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def norm(self, ord=None, dim=None, keepdim=False, dtype=None): return flow._C.norm(self, ord, dim, keepdim, dtype=dtype) def vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): return flow._C.vector_norm(self, ord, dim, keepdim, dtype=dtype) def matrix_norm(self, ord="fro", dim=(-2, -1), keepdim=False, dtype=None): return flow._C.matrix_norm(self, ord, dim, keepdim, dtype=dtype) def inv(self): return flow._C.inv(self) def diagonal(self, input, offset=0, dim1=-2, dim2=-1): """ Alias for :func:`oneflow.diagonal` with defaults :attr:`dim1`\ `= -2`, :attr:`dim2`\ `= -1`. """ return flow._C.diagonal(self, input, offset=offset, dim1=dim1, dim2=dim2) def cross(input, other, dim=-1): return flow._C.linalg_cross(input, other, dim=dim) def det(A): """ Computes the determinant of a square matrix. Supports input of float, double dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.linalg.det.html Args: A (Tensor): tensor of shape (\*, n, n) where \* is zero or more batch dimensions. Returns: oneflow.Tensor: the output Tensor. .. warning:: Currently, only CUDA11 and above versions are supported. """ return flow._C.det(A) def solve(): raise NotImplementedError() ================================================ FILE: python/oneflow/mock_torch/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from .mock_importer import ModuleWrapper, enable, disable from .mock_modules import DummyModule from .dyn_mock_mod import DynamicMockModule ================================================ FILE: python/oneflow/mock_torch/__main__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
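The wrappers in python/oneflow/linalg.py above are thin forwards to the functional C API (flow._C.*). A short sketch of how they are called (values illustrative; per its docstring, det additionally requires CUDA 11+, so it is omitted here):

import oneflow as flow

a = flow.randn(3, 3)
print(flow.linalg.norm(a))                          # Frobenius norm by default
print(flow.linalg.matrix_norm(a, ord="fro"))        # norm over dims (-2, -1)
print(flow.linalg.vector_norm(a.flatten(), ord=2))  # 2-norm of the flattened tensor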
""" import argparse from pathlib import Path import os import sys if sys.version_info < (3, 8): try: from importlib_metadata import requires except ImportError: import subprocess subprocess.check_call("pip install importlib_metadata", shell=True) subprocess.check_call("pip install packaging", shell=True) parser = argparse.ArgumentParser() parser.add_argument( "mock", choices=["enable", "disable"], help="enable/disable mocking 'import torch', default is enable", nargs="?", default="enable", ) parser.add_argument("--lazy", action="store_true") parser.add_argument("--verbose", action="store_true") args = parser.parse_args() torch_env = Path(__file__).parent def main(): def is_torch_env(s): if s.endswith("oneflow/mock_torch"): return True return False if args.mock == "enable": print( f"export ONEFLOW_MOCK_TORCH_LAZY={args.lazy}; export ONEFLOW_MOCK_TORCH_VERBOSE={args.verbose}; export PYTHONPATH={str(torch_env)}:$PYTHONPATH" ) elif args.mock == "disable" and "PYTHONPATH" in os.environ: paths = os.environ["PYTHONPATH"].rstrip(":").split(":") paths = [p for p in paths if not is_torch_env(p)] if len(paths) == 0: print( "unset PYTHONPATH; unset ONEFLOW_MOCK_TORCH_LAZY; unset ONEFLOW_MOCK_TORCH_VERBOSE" ) return path = ":".join(paths) print( f"export PYTHONPATH={path}; unset ONEFLOW_MOCK_TORCH_LAZY; unset ONEFLOW_MOCK_TORCH_VERBOSE" ) if __name__ == "__main__": main() ================================================ FILE: python/oneflow/mock_torch/dyn_mock_mod.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from inspect import ismodule import importlib from contextlib import contextmanager from types import ModuleType from typing import Dict, List from .mock_importer import enable class DynamicMockModule(ModuleType): def __init__( self, pkg_name: str, obj_entity: ModuleType, main_pkg_enable: callable, ): self._pkg_name = pkg_name self._obj_entity = obj_entity # ModuleType or _LazyModule self._main_pkg_enable = main_pkg_enable self._intercept_dict = {} def __repr__(self) -> str: return f"" def hijack(self, module_name: str, obj: object): self._intercept_dict[module_name] = obj @classmethod def from_package( cls, main_pkg: str, *, lazy: bool = True, verbose: bool = False, extra_dict: Dict[str, str] = None, required_dependencies: List[str] = [], ): assert isinstance(main_pkg, str) @contextmanager def main_pkg_enable(): with enable( lazy=lazy, verbose=verbose, extra_dict=extra_dict, main_pkg=main_pkg, mock_version=True, required_dependencies=required_dependencies, ): yield with main_pkg_enable(): obj_entity = importlib.import_module(main_pkg) return cls(main_pkg, obj_entity, main_pkg_enable) def _get_module(self, _name: str): # Fix Lazy import # https://github.com/huggingface/diffusers/blob/main/src/diffusers/__init__.py#L728-L734 module_name = f"{self._obj_entity.__name__}.{_name}" try: return importlib.import_module(module_name) except Exception as e: raise RuntimeError( f"Failed to import {module_name} because of the following error (look up to see its" f" traceback):\n{e}" ) from e def __getattr__(self, name: str): fullname = f"{self._obj_entity.__name__}.{name}" if fullname in self._intercept_dict: return self._intercept_dict[fullname] with self._main_pkg_enable(): obj_entity = getattr(self._obj_entity, name, None) if obj_entity is None: obj_entity = self._get_module(name) if ismodule(obj_entity): return DynamicMockModule(self._pkg_name, obj_entity, self._main_pkg_enable) return obj_entity def __all__(self): with self._main_pkg_enable(): return dir(self._obj_entity) ================================================ FILE: python/oneflow/mock_torch/mock_importer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import builtins from functools import partial import types from inspect import ismodule, currentframe from types import ModuleType from typing import Any, Dict, Optional from importlib.abc import MetaPathFinder, Loader from importlib.machinery import ModuleSpec from importlib.util import find_spec, module_from_spec import sys from typing import List from zipimport import zipimporter import oneflow.support.env_var_util as env_var_util from .mock_modules import MockModuleDict, DummyModule from .mock_utils import MockEnableDisableMixin error_msg = """ is not implemented, please submit an issue at 'https://github.com/Oneflow-Inc/oneflow/issues' including the log information of the error, the minimum reproduction code, and the system information.""" # patch hasattr so that # 1. 
torch.not_exist returns DummyModule object, but # 2. hasattr(torch, "not_exist") still returns False _builtin_hasattr = builtins.hasattr if not isinstance(_builtin_hasattr, types.BuiltinFunctionType): raise Exception("hasattr already patched by someone else!") def hasattr(obj, name): return _builtin_hasattr(obj, name) builtins.hasattr = hasattr def probably_called_from_hasattr(): frame = currentframe().f_back.f_back return frame.f_code is hasattr.__code__ # module wrapper with checks for existence of methods class ModuleWrapper(ModuleType): # TODO add selcted methods def __init__(self, module): self.module = module def __setattr__(self, name, value): super().__setattr__(name, value) if name != "module": setattr(self.module, name, value) def __getattr__(self, name: str) -> Any: if not hasattr(self.module, name): if name == "__path__": return None if name == "__all__": return [attr for attr in dir(self.module) if not attr.startswith("_")] new_name = self.module.__name__ + "." + name if _importer.lazy and not probably_called_from_hasattr(): if _importer.verbose: print( f'"{new_name}" is not found in oneflow, use dummy object as fallback.' ) return DummyModule(new_name, verbose=_importer.verbose) else: if _importer.lazy and _importer.verbose: print(f"hasattr({self.module.__name__}, {name}) returns False") raise AttributeError(new_name + error_msg) attr = getattr(self.module, name) if ismodule(attr): return ModuleWrapper(attr) else: return attr class OneflowImporter(MockEnableDisableMixin, MetaPathFinder, Loader): def __init__(self): # module_from_spec will try to call the loader's create_module, resulting in infinite recursion self.in_create_module = False self.enable = False # both __init__.py of oneflow and torch can't be executed multiple times, so we use a cache self.enable_mod_cache = {} self.disable_mod_cache = {} # Record modules loaded during mocking for deletion self.delete_list = [] def find_spec(self, fullname, path, target=None): if module_dict_global.in_forward_dict( fullname ): # don't touch modules other than torch or extra libs module # for first import of real torch, we use default meta path finders, not our own if not self.enable and self.disable_mod_cache.get(fullname) is None: return None return ModuleSpec(fullname, self) self.delete_list.append(fullname) return None def find_module(self, fullname, path=None): spec = self.find_spec(fullname, path) return spec def create_module(self, spec): if self.in_create_module: return None self.in_create_module = True if self.enable: if module_dict_global.in_forward_dict(spec.name): oneflow_mod_fullname = module_dict_global.forward_name(spec.name) if ( sys.modules.get(oneflow_mod_fullname) is None and self.enable_mod_cache.get(spec.name) is None ): # get actual oneflow module try: real_spec = find_spec(oneflow_mod_fullname) except ModuleNotFoundError: real_spec = None if real_spec is None: self.in_create_module = False if self.lazy: if self.verbose: print( f"{oneflow_mod_fullname} is not found in oneflow, use dummy object as fallback." 
) return DummyModule(oneflow_mod_fullname, verbose=self.verbose) else: raise ModuleNotFoundError(oneflow_mod_fullname + error_msg) real_mod = module_from_spec(real_spec) loader = real_spec.loader if isinstance(loader, zipimporter): # TODO: verify can mock torch as oneflow in zipimporter pass else: loader.exec_module(real_mod) else: real_mod = sys.modules.get(oneflow_mod_fullname) if real_mod is None: real_mod = self.enable_mod_cache[spec.name] self.in_create_module = False return real_mod else: torch_full_name = spec.name real_mod = self.disable_mod_cache[torch_full_name] self.in_create_module = False return real_mod def exec_module(self, module): module_name = module.__name__ if module_dict_global.in_inverse_dict(module_name): fullname = module_dict_global.inverse_name(module_name) if self.enable: if not isinstance(module, DummyModule): module = ModuleWrapper(module) sys.modules[fullname] = module globals()[fullname] = module def _enable( self, globals=None, lazy=False, verbose=False, *, main_pkg: str = None, mock_version: bool = None, required_dependencies: List[str] = [], from_cli: bool = False, ): if verbose: print("enable mock torch", globals["__name__"]) if self.enable: # already enabled of_importer_module_name = self.globals["__name__"] input_module_name = globals["__name__"] if of_importer_module_name != input_module_name: print( f"Warning: {of_importer_module_name} is already enabled, but {input_module_name} is trying to enable it again. skip." ) return # record config for re-enabling self._mock_enable_config = {k: v for k, v in locals().items() if k != "self"} # insert importer to the first place of meta_path sys.meta_path.insert(0, self) self.lazy = lazy self.verbose = verbose self.from_cli = from_cli self.globals = globals self.mock_enable( globals=globals, module_dict=module_dict_global, main_pkg=main_pkg, mock_version=mock_version, required_dependencies=required_dependencies, from_cli=from_cli, verbose=verbose, ) self.enable = True def _disable(self, globals, *, verbose=False): if verbose: print( "disable mock torch in", globals["__name__"], "\tself.enable: ", self.enable, ) if not self.enable: # already disabled return of_importer_module_name = self.globals["__name__"] input_module_name = globals["__name__"] if of_importer_module_name != input_module_name: raise RuntimeError( f"Error: {of_importer_module_name} is enabled, but {input_module_name} is trying to disable it. must disable it in the same module." 
) self.mock_disable( globals=globals, module_dict=module_dict_global, delete_list=self.delete_list, from_cli=self.from_cli, ) sys.meta_path.remove(self) self.enable = False self.delete_list = [] self.globals = None _importer = OneflowImporter() class BaseMockConfig: def __init__( self, lazy: Optional[bool] = None, verbose: Optional[bool] = None, extra_dict: Optional[Dict[str, str]] = None, *, main_pkg: Optional[str] = None, mock_version: Optional[str] = None, required_dependencies: List[str] = [], _from_cli: bool = False, ): global module_dict_global module_dict_global = MockModuleDict(extra_dict) module_dict_global.add("torch", "oneflow") required_dependencies.extend( [k for k in extra_dict or {} if k not in required_dependencies] ) if "torch" not in required_dependencies: required_dependencies.append("torch") parse_bool_env = partial( env_var_util.parse_boolean_from_env, defalut_value=False ) forcedly_disabled_by_env_var = parse_bool_env("ONEFLOW_DISABLE_MOCK_TORCH") lazy = lazy if lazy is not None else parse_bool_env("ONEFLOW_MOCK_TORCH_LAZY") verbose = ( verbose if verbose is not None else parse_bool_env("ONEFLOW_MOCK_TORCH_VERBOSE") ) self.lazy = lazy self.verbose = verbose self.forcedly_disabled_by_env_var = forcedly_disabled_by_env_var self.required_dependencies = required_dependencies self.parse_bool_env = parse_bool_env self._from_cli = _from_cli self.main_pkg = main_pkg self.mock_version = mock_version class enable(BaseMockConfig): """https://docs.oneflow.org/master/cookies/oneflow_torch.html""" def __init__( self, lazy: Optional[bool] = None, verbose: Optional[bool] = None, extra_dict: Optional[Dict[str, str]] = None, *, main_pkg: Optional[str] = None, mock_version: Optional[str] = None, required_dependencies: List[str] = [], _from_cli: bool = False, ): super().__init__( lazy=lazy, verbose=verbose, extra_dict=extra_dict, main_pkg=main_pkg, mock_version=mock_version, required_dependencies=required_dependencies, _from_cli=_from_cli, ) if self.forcedly_disabled_by_env_var: # super().__init__ will set this return self.globals = currentframe().f_back.f_globals self.skip_processing = False if getattr(_importer, "globals", None) is not None: import_name = _importer.globals["__name__"] if import_name == self.globals["__name__"]: self.skip_processing = True return self._importer_enable = _importer.enable if self._importer_enable: self._mock_enable_config = _importer._mock_enable_config _importer._disable(_importer.globals, verbose=self.verbose) _importer._enable( self.globals, lazy=self.lazy, verbose=self.verbose, main_pkg=main_pkg, mock_version=mock_version, required_dependencies=required_dependencies, from_cli=_from_cli, ) def __enter__(self): pass def __exit__(self, exception_type, exception_value, traceback): if self.forcedly_disabled_by_env_var or self.skip_processing: return _importer._disable(_importer.globals, verbose=self.verbose) if self._importer_enable: _importer._enable( # When re-enabling mock torch, from_cli shoule always be False **self._mock_enable_config, ) class disable: def __init__(self): self._importer_enable = _importer.enable if not self._importer_enable: return self.globals = currentframe().f_back.f_globals self.lazy = _importer.lazy self.verbose = _importer.verbose self._mock_enable_config = _importer._mock_enable_config _importer._disable(_importer.globals, verbose=self.verbose) def __enter__(self): pass def __exit__(self, exception_type, exception_value, traceback): if self._importer_enable: _importer._enable( # When re-enabling mock torch, from_cli shoule 
always be False **self._mock_enable_config, ) def is_enabled(): return _importer.enable ================================================ FILE: python/oneflow/mock_torch/mock_modules.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from types import ModuleType __all__ = ["MockModuleDict", "DummyModule"] class MockModuleDict: def __init__(self, mapping=None): if mapping is not None and not isinstance(mapping, dict): raise ValueError("Extra mock library must be a dict.") self.forward = {} self.inverse = {} if mapping is not None: for key, value in mapping.items(): self.add(key, value) def add(self, key, value): """mock key thorugh value.""" if key in self.forward or value in self.inverse: raise ValueError("Key or value already exists.") self.forward[key] = value self.inverse[value] = key def remove(self, key=None, value=None): if key is not None: value = self.forward.pop(key) self.inverse.pop(value) elif value is not None: key = self.inverse.pop(value) self.forward.pop(key) else: raise ValueError("Must provide a key or value to remove.") def in_forward_dict(self, s): return s.split(".")[0] in self.forward.keys() def in_inverse_dict(self, s): return s.split(".")[0] in self.inverse.keys() def inverse_name(self, s: str): # s: spec.name return self.inverse[s.split(".")[0]] + s[len(s.split(".")[0]) :] def forward_name(self, s: str): return self.forward[s.split(".")[0]] + s[len(s.split(".")[0]) :] class DummyModule(ModuleType): def __init__(self, name, verbose=False): super().__init__(name) self._verbose = verbose def __getattr__(self, name): if self._verbose: print( f'"{self.__name__}" is a dummy object, and its attr "{name}" is accessed.' ) if name == "__path__": return None if name == "__all__": return [] if name == "__file__": return None if name == "__mro_entries__": return lambda x: () return DummyModule(self.__name__ + "." + name, self._verbose) def __getitem__(self, name): new_name = f"{self.__name__}[{name}]" if isinstance(name, int): if self._verbose: print( f'"{self.__name__}" is a dummy object, and `{new_name}` is called. Raising an IndexError to simulate an empty list.' ) raise IndexError if self._verbose: print(f'"{self.__name__}" is a dummy object, and `{new_name}` is called.') return DummyModule(new_name, self._verbose) def __call__(self, *args, **kwargs): new_name = f'{self.__name__}({", ".join(map(repr, args))}, {", ".join(["{}={}".format(k, repr(v)) for k, v in kwargs.items()])})' if self._verbose: print(f'"{self.__name__}" is a dummy object, and `{new_name}` is called.') return DummyModule(new_name, self._verbose) def __bool__(self): if self._verbose: print( f'"{self.__name__}" is a dummy object, and its bool value is accessed.' ) return False def __enter__(self): raise RuntimeError( f'"{self.__name__}" is a dummy object, and does not support "with" statement.' 
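# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this file): minimal, assumed-typical usage
# of the enable/disable API defined in python/oneflow/mock_torch/__init__.py.
#
#     import oneflow.mock_torch as mock_torch
#
#     with mock_torch.enable(lazy=True, verbose=False):
#         import torch                  # resolved to oneflow by the meta-path finder
#         assert mock_torch.is_enabled()
#
#     with mock_torch.disable():        # temporarily restore the real torch
#         import torch                  # the genuine PyTorch, if installed
# ---------------------------------------------------------------------------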
) def __exit__(self, exception_type, exception_value, traceback): raise RuntimeError( f'"{self.__name__}" is a dummy object, and does not support "with" statement.' ) def __subclasscheck__(self, subclass): return False def __instancecheck__(self, instance): return False ================================================ FILE: python/oneflow/mock_torch/mock_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import sys import sysconfig import pkgutil from collections import deque from importlib import import_module if sys.version_info <= (3, 8): try: from importlib_metadata import requires except ImportError: import subprocess subprocess.check_call("pip install importlib_metadata", shell=True) subprocess.check_call("pip install packaging", shell=True) else: from importlib.metadata import requires from packaging.requirements import Requirement from pathlib import Path from functools import lru_cache from typing import List, Optional from types import ModuleType __all__ = ["MockEnableDisableMixin"] class PackageDependencyMixin: """Get all dependencies of a package filtered by a list of dependencies. Example: >>> import diffusers # version 0.24.0 >>> op = PackageDependencyMixin() >>> result = op.has_dependencies("diffusers", ["torch"]) >>> print(result) ['huggingface_hub', 'diffusers'] """ pkg_cache = {} # {pkg: [deps]} @staticmethod def find_matching_dependencies( main_pkg: str, dependencies: List[str], max_visits=1000 ) -> List[str]: @lru_cache() def python_stdlib_packages(): # current python stdlib path stdlib_path = sysconfig.get_paths()["stdlib"] # use pkgutil to list all modules in the standard library python_modules = [ name for _, name, _ in pkgutil.iter_modules([stdlib_path]) ] # combine built-in module names and Python modules all_modules = list(sys.builtin_module_names) + python_modules return all_modules def format_package_name(pkg: str): return Requirement(pkg).name.replace("-", "_") @lru_cache() def get_requirements(pkg: str): python_modules = python_stdlib_packages() if pkg in python_modules: return [] try: direct_dependencies = requires(pkg) if len(direct_dependencies) == 0: return [] result = set() for pkg in direct_dependencies: pkg = format_package_name(pkg) if pkg == main_pkg: continue if pkg not in python_modules: result.add(pkg) return list(result) except: return [] def is_leaf_package(pkg) -> bool: if pkg in dependencies: return True return len(get_requirements(pkg)) == 0 main_pkg = format_package_name(main_pkg) # build graph graph = {} # {dep: [pkg1, pkg2, ...]} queue = deque([main_pkg]) visited = set() stops = set() while queue: pkg = queue.popleft() if is_leaf_package(pkg): stops.add(pkg) continue if pkg in visited: continue visited.add(pkg) if len(visited) > max_visits: print( f"\033[1;33mWARNING: max_visits {max_visits} reached, stop searching.\033[0m" ) break for req in get_requirements(pkg): graph.setdefault(req, set()).add(pkg) queue.append(req) # init cache and queue cache = {} 
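# Walk back up the reverse-dependency graph from the leaf ("stop") packages:
# cache[pkg] becomes True only if pkg, or one of its transitive requirements,
# is in the target `dependencies` list, and the flag is propagated to every
# package that requires it. E.g., per the class docstring above,
# has_dependencies("diffusers", ["torch"]) may yield
# ["huggingface_hub", "diffusers"] in an environment where both require torch.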
visited.clear() queue = deque(stops) for pkg in stops: cache[pkg] = True if pkg in dependencies else False # bfs_from_stops while queue: pkg = queue.popleft() if pkg in visited: continue visited.add(pkg) for dep in graph.get(pkg, set()): is_ok = cache.get(dep, False) if cache[pkg] or is_ok: is_ok = True cache[dep] = is_ok queue.append(dep) return [pkg for pkg, is_ok in cache.items() if is_ok] @staticmethod def varify_input(main_pkg, dependencies, callback, verbose=False): try: requires(main_pkg) except: if verbose: print( f"WARNING: main_pkg {main_pkg} has no meta information, please check if it is a valid package." ) print("will set it as its own dependency to avoid error.") PackageDependencyMixin.pkg_cache[main_pkg] = [main_pkg] + dependencies if not isinstance(main_pkg, str): raise ValueError("main_pkg must be a string.") if not isinstance(dependencies, list): raise ValueError("dependencies must be a list.") if not all([isinstance(dep, str) for dep in dependencies]): raise ValueError("dependencies must be a list of strings.") if callback is not None and not callable(callback): raise ValueError("callback must be a callable.") @classmethod def has_dependencies( self, main_pkg: str, dependencies: List[str], callback: callable = None, *, verbose=False, ) -> List[str]: """Check if a package has any dependencies in a list of dependencies.""" PackageDependencyMixin.varify_input(main_pkg, dependencies, callback, verbose) deps = PackageDependencyMixin.pkg_cache.get(main_pkg, None) if deps is None: deps = PackageDependencyMixin.find_matching_dependencies( main_pkg, dependencies ) PackageDependencyMixin.pkg_cache.update({main_pkg: deps}) if verbose: print("PackageDependencyMixin : main_pkg=", main_pkg, ", deps=", deps) if callback: return callback(deps) else: return deps class VersionMixin: version_cache = {} def mock_version(self, module_a: ModuleType, module_b: ModuleType): """Mock the version of module_a with the version of module_b.""" if isinstance(module_a, str): module_a = import_module(module_a) if isinstance(module_b, str): module_b = import_module(module_b) attr_name = "__version__" orig_attr = getattr(module_a, attr_name, None) setattr(module_a, attr_name, getattr(module_b, attr_name, None)) VersionMixin.version_cache.update({module_a: (attr_name, orig_attr)}) def restore_version(self): for module, (attr_name, orig_attr) in self.version_cache.items(): setattr(module, attr_name, orig_attr) VersionMixin.version_cache.clear() class MockEnableDisableMixin(PackageDependencyMixin, VersionMixin): """Mock torch package using OneFlow.""" # list of hazardous modules that may cause issues, handle with care hazard_list = [ "_distutils_hack", "importlib", "regex", "tokenizers", "safetensors._safetensors_rust", ] def is_safe_module(self, module_key): k = module_key hazard_list = MockEnableDisableMixin.hazard_list name = k if "." not in k else k[: k.find(".")] if name in hazard_list or k in hazard_list: return False return True def mock_enable( self, globals, # parent_globals module_dict, # MockModuleDict object *, main_pkg: Optional[str] = None, mock_version: Optional[str] = None, required_dependencies: List[str], from_cli=False, verbose=False, **kwargs, ): """Mock torch package using OneFlow. Args: `globals`: The globals() of the parent module. `module_dict`: MockModuleDict object. `main_pkg`: The main package to mock. `required_dependencies`: The dependencies to mock for the `main_pkg`. 
""" if mock_version: mock_map = module_dict.forward for pkg, mock_pkg in mock_map.items(): self.mock_version(pkg, mock_pkg) if not hasattr(self, "enable_mod_cache"): self.enable_mod_cache = {} if not hasattr(self, "disable_mod_cache"): self.disable_mod_cache = {} if not hasattr(self, "mock_safety_packages"): self.mock_safety_packages = set() if main_pkg: # Analyze the dependencies of the main package cur_sys_modules = sys.modules.copy() existing_deps = self.has_dependencies( main_pkg, dependencies=required_dependencies, callback=lambda x: [dep for dep in x if dep in cur_sys_modules], verbose=verbose, ) if verbose: print( "Existing dependencies of ", "main_pkg: ", main_pkg, "existing_deps: ", existing_deps, ) self.mock_safety_packages.update(existing_deps) # disable non-safe modules loaded before mocking def can_disable_mod_cache(k): # module_key if not self.is_safe_module(k): return False if module_dict.in_forward_dict(k): return True for dep_pkg in self.mock_safety_packages: if k.startswith(dep_pkg + ".") or k == dep_pkg: return True return False for k, v in sys.modules.copy().items(): exclude_torch_from_cli = not (from_cli and k == "torch") if not exclude_torch_from_cli: # torch is imported from CLI continue if can_disable_mod_cache(k): aliases = [alias for alias, value in globals.items() if value is v] self.disable_mod_cache.update({k: (v, aliases)}) del sys.modules[k] for alias in aliases: del globals[alias] # restore modules loaded during mocking for k, (v, aliases) in self.enable_mod_cache.items(): sys.modules.update({k: v}) for alias in aliases: globals.update({alias: v}) def mock_disable(self, globals, module_dict, delete_list, from_cli=False): """Disable the mocked packages.""" if not hasattr(self, "enable_mod_cache") or not hasattr( self, "disable_mod_cache" ): RuntimeError("Please call mock_enable() first.") # disable modules loaded during mocking def can_enable_mod_cache(k): # module_key if not self.is_safe_module(k): return False if module_dict.in_forward_dict(k): return True return k in delete_list for k, v in sys.modules.copy().items(): if can_enable_mod_cache(k): aliases = [alias for alias, value in globals.items() if value is v] self.enable_mod_cache.update({k: (v, aliases)}) del sys.modules[k] for alias in aliases: del globals[alias] # restore modules loaded during before mocking for k, (v, aliases) in self.disable_mod_cache.items(): sys.modules.update({k: v}) for alias in aliases: globals.update({alias: v}) if from_cli: torch_env = Path(__file__).parent if str(torch_env) in sys.path: sys.path.remove(str(torch_env)) self.restore_version() ================================================ FILE: python/oneflow/mock_torch/torch/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import sys import oneflow from oneflow.mock_torch import ModuleWrapper, enable def __getattr__(name: str): return ModuleWrapper(oneflow).__getattr__(name) enable(_from_cli=True) ================================================ FILE: python/oneflow/model.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.model import Callback, CheckpointConfig, DataModule from oneflow.framework.model import Model as Model from oneflow.framework.model import NumpyDataModule, TrainingConfig, ValidationConfig ================================================ FILE: python/oneflow/multiprocessing/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """ oneflow.multiprocessing is a wrapper around the native :mod:`multiprocessing` module. It registers custom reducers, that use shared memory to provide shared views on the same data in different processes. Once the tensor/storage is moved to shared_memory (see :func:`~oneflow.Tensor.share_memory_`), it will be possible to send it to other processes without making any copies. The API is 100% compatible with the original module - it's enough to change ``import multiprocessing`` to ``import oneflow.multiprocessing`` to have all the tensors sent through the queues or shared via other mechanisms, moved to shared memory. Because of the similarity of APIs we do not document most of this package contents, and we recommend referring to very good docs of the original module. """ import oneflow as flow import sys from .reductions import init_reductions import multiprocessing __all__ = [ "set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies", "unlink_all_shared_memory", ] from multiprocessing import * # noqa: F403 __all__ += multiprocessing.__all__ # type: ignore[attr-defined] # This call adds a Linux specific prctl(2) wrapper function to this module. # See https://github.com/pytorch/pytorch/pull/14391 for more information. flow._oneflow_internal._multiprocessing_init() """Add helper function to spawn N processes and wait for completion of any of them. 
This depends `mp.get_context` which was added in Python 3.4.""" from .spawn import ( spawn, SpawnContext, start_processes, ProcessContext, ProcessRaisedException, ProcessExitedException, ) if sys.platform == "darwin" or sys.platform == "win32": _sharing_strategy = "file_system" _all_sharing_strategies = {"file_system"} else: _sharing_strategy = "file_descriptor" _all_sharing_strategies = {"file_descriptor", "file_system"} def set_sharing_strategy(new_strategy): """Sets the strategy for sharing CPU tensors. Args: new_strategy (str): Name of the selected strategy. Should be one of the values returned by :func:`get_all_sharing_strategies()`. """ global _sharing_strategy assert new_strategy in _all_sharing_strategies _sharing_strategy = new_strategy def get_sharing_strategy(): """Returns the current strategy for sharing CPU tensors.""" return _sharing_strategy def get_all_sharing_strategies(): """Returns a set of sharing strategies supported on a current system.""" return _all_sharing_strategies def unlink_all_shared_memory(): flow._oneflow_internal.multiprocessing.unlink_all_shared_memory() init_reductions() ================================================ FILE: python/oneflow/multiprocessing/_atfork.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import sys __all__ = ["register_after_fork"] if sys.platform == "win32" or sys.version_info < (3, 7): import multiprocessing.util as _util def _register(func): def wrapper(arg): func() _util.register_after_fork(_register, wrapper) else: import os def _register(func): os.register_at_fork(after_in_child=func) def register_after_fork(func): """Register a callable to be executed in the child process after a fork. Note: In python < 3.7 this will only work with processes created using the ``multiprocessing`` module. In python >= 3.7 it also works with ``os.fork()``. Args: func (function): Function taking no arguments to be called in the child after fork """ _register(func) ================================================ FILE: python/oneflow/multiprocessing/pool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
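# ---------------------------------------------------------------------------
# Illustrative sketch (not part of pool.py): the drop-in usage described by the
# oneflow.multiprocessing module docstring above; `worker` is an example name.
#
#     import oneflow as flow
#     import oneflow.multiprocessing as mp
#
#     def worker(q):
#         t = q.get()               # tensor arrives via shared memory, no copy
#         print(t.shape)
#
#     if __name__ == "__main__":
#         q = mp.SimpleQueue()
#         p = mp.Process(target=worker, args=(q,))
#         p.start()
#         q.put(flow.ones(2, 3))    # pickled with the reducers registered below
#         p.join()
# ---------------------------------------------------------------------------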
""" import multiprocessing.pool import multiprocessing.util as util from .queue import SimpleQueue def clean_worker(*args, **kwargs): import gc multiprocessing.pool.worker(*args, **kwargs) # Regular multiprocessing workers don't fully clean up after themselves, # so we have to explicitly trigger garbage collection to make sure that all # destructors are called... gc.collect() class Pool(multiprocessing.pool.Pool): """Pool implementation which uses our version of SimpleQueue. This lets us pass tensors in shared memory across processes instead of serializing the underlying data.""" def _setup_queues(self): self._inqueue = SimpleQueue() self._outqueue = SimpleQueue() self._quick_put = self._inqueue._writer.send self._quick_get = self._outqueue._reader.recv def _repopulate_pool(self): """Bring the number of pool processes up to the specified number, for use after reaping workers which have exited. """ for i in range(self._processes - len(self._pool)): # changed worker -> clean_worker args = ( self._inqueue, self._outqueue, self._initializer, self._initargs, self._maxtasksperchild, ) if hasattr(self, "_wrap_exception"): args += (self._wrap_exception,) w = self.Process(target=clean_worker, args=args) self._pool.append(w) w.name = w.name.replace("Process", "PoolWorker") w.daemon = True w.start() util.debug("added worker") ================================================ FILE: python/oneflow/multiprocessing/queue.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import io import multiprocessing.queues from multiprocessing.reduction import ForkingPickler import pickle class ConnectionWrapper(object): """Proxy class for _multiprocessing.Connection which uses ForkingPickler to serialize objects""" def __init__(self, conn): self.conn = conn def send(self, obj): buf = io.BytesIO() ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj) self.send_bytes(buf.getvalue()) def recv(self): buf = self.recv_bytes() return pickle.loads(buf) def __getattr__(self, name): if "conn" in self.__dict__: return getattr(self.conn, name) raise AttributeError( "'{}' object has no attribute '{}'".format(type(self).__name__, "conn") ) class Queue(multiprocessing.queues.Queue): def __init__(self, *args, **kwargs): super(Queue, self).__init__(*args, **kwargs) self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) self._send = self._writer.send self._recv = self._reader.recv class SimpleQueue(multiprocessing.queues.SimpleQueue): def _make_methods(self): if not isinstance(self._reader, ConnectionWrapper): self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) super(SimpleQueue, self)._make_methods() # type: ignore[misc] ================================================ FILE: python/oneflow/multiprocessing/reductions.py ================================================ """ Copyright 2020 The OneFlow Authors. 
All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from multiprocessing.reduction import ForkingPickler
import numpy as np

import oneflow as flow
from oneflow.nn.parameter import Parameter
from oneflow.framework.tensor import Tensor
from oneflow.multiprocessing import shared_memory

try:
    # Early load resource_sharer to prevent a partially initialized instance
    # from being inherited in a forked child process. The reduce_storage method
    # requires this module indirectly through DupFd(). The built-in mp.Queue
    # class pickles arguments in a background thread which may overlap with the
    # fork.
    import multiprocessing.resource_sharer
except ImportError:
    pass


def rebuild_empty_tensor(shape, dtype, requires_grad):
    t = flow.tensor([], dtype=dtype)
    t.requires_grad = requires_grad
    return t.reshape(*shape)


def rebuild_shm_tensor(shm, shape, dtype, requires_grad):
    def delete_shm():
        try:
            # For unknown reasons delete_shm called in dataloader may fail
            # with "StopIteration".
            # An example is when dataloader is wrapped in a generator like
            # `log_every`.
            shm.close()
            shm.unlink()
        except:
            pass

    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    t = flow.from_numpy(arr)
    t._register_storage_delete_hook(delete_shm)
    t.requires_grad = requires_grad
    return t


def rebuild_empty_parameter(shape, dtype, requires_grad):
    t = flow.tensor([], dtype=dtype)
    t = t.reshape(*shape)
    return Parameter(t, requires_grad=requires_grad)


def rebuild_shm_parameter(shm, shape, dtype, requires_grad):
    def delete_shm():
        shm.close()
        shm.unlink()

    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    t = flow.from_numpy(arr)
    t._register_storage_delete_hook(delete_shm)
    return Parameter(t, requires_grad=requires_grad)


def reduce_tensor(tensor):
    tensor_data = tensor.numpy()
    requires_grad = tensor.requires_grad
    if tensor_data.nbytes == 0:
        return (rebuild_empty_tensor, (tensor.shape, tensor.dtype, requires_grad))
    else:
        shm = shared_memory.SharedMemory(create=True, size=tensor_data.nbytes)
        shm_numpy = np.ndarray(
            tensor_data.shape, dtype=tensor_data.dtype, buffer=shm.buf
        )
        shm_numpy[:] = tensor_data[:]
        return (
            rebuild_shm_tensor,
            (shm, tensor_data.shape, tensor_data.dtype, requires_grad),
        )


def reduce_parameter(tensor):
    tensor_data = tensor.numpy()
    requires_grad = tensor.requires_grad
    if tensor_data.nbytes == 0:
        return (rebuild_empty_parameter, (tensor.shape, tensor.dtype, requires_grad))
    else:
        shm = shared_memory.SharedMemory(create=True, size=tensor_data.nbytes)
        shm_numpy = np.ndarray(
            tensor_data.shape, dtype=tensor_data.dtype, buffer=shm.buf
        )
        shm_numpy[:] = tensor_data[:]
        return (
            rebuild_shm_parameter,
            (shm, tensor_data.shape, tensor_data.dtype, requires_grad),
        )


def init_reductions():
    ForkingPickler.register(Tensor, reduce_tensor)
    ForkingPickler.register(flow._oneflow_internal.Tensor, reduce_tensor)
    ForkingPickler.register(Parameter, reduce_parameter)
    ForkingPickler.register(flow._oneflow_internal.nn.Parameter, reduce_parameter)


================================================
FILE: python/oneflow/multiprocessing/shared_memory/__init__.py
================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow __all__ = ["SharedMemory"] class SharedMemory: def __init__(self, name=None, create=False, size=0): if not size >= 0: raise ValueError("'size' must be a non-negative integer") if create: if size == 0: raise ValueError("'size' must be a positive number different from zero") self.shm_ = flow._oneflow_internal.multiprocessing.SharedMemory( name=name if name is not None else "", create=create, size=size ) def __del__(self): try: if hasattr(self, "shm_"): self.close() except OSError: pass def __reduce__(self): return ( self.__class__, (self.name, False, self.size,), ) def __repr__(self): return f"{self.__class__.__name__}({self.name!r}, size={self.size})" @property def buf(self): "A memoryview of contents of the shared memory block." return self.shm_.buf @property def name(self): "Unique name that identifies the shared memory block." return self.shm_.name @property def size(self): "Size in bytes." return self.shm_.size def close(self): """Closes access to the shared memory from this instance but does not destroy the shared memory block.""" return self.shm_.close() def unlink(self): """Requests that the underlying shared memory block be destroyed. In order to ensure proper cleanup of resources, unlink should be called once (and only once) across all processes which have access to the shared memory block.""" return self.shm_.unlink() ================================================ FILE: python/oneflow/multiprocessing/spawn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional import multiprocessing import multiprocessing.connection import signal import sys import warnings from oneflow.multiprocessing import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] class ProcessException(Exception): __slots__ = ["error_index", "error_pid"] def __init__(self, msg: str, error_index: int, pid: int): super().__init__(msg) self.error_index = error_index self.pid = pid class ProcessRaisedException(ProcessException): """ Exception is thrown when the process failed due to exception raised by the code. """ def __init__( self, msg: str, error_index: int, error_pid: int, ): super().__init__(msg, error_index, error_pid) class ProcessExitedException(ProcessException): """ Exception is thrown when the process failed due to signal or exited with a specific code. 
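# ---------------------------------------------------------------------------
# Illustrative sketch (not part of spawn.py): typical use of the SharedMemory
# wrapper defined above; the block name is assigned by the runtime and is read
# back from `.name`.
#
#     import numpy as np
#     from oneflow.multiprocessing import shared_memory
#
#     shm = shared_memory.SharedMemory(create=True, size=6 * 4)  # six float32
#     src = np.arange(6, dtype=np.float32)
#     np.ndarray(src.shape, dtype=src.dtype, buffer=shm.buf)[:] = src
#
#     # another process can attach with SharedMemory(name=shm.name)
#     shm.close()     # drop this handle
#     shm.unlink()    # destroy the block (once, across all processes)
# ---------------------------------------------------------------------------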
""" __slots__ = ["exit_code"] def __init__( self, msg: str, error_index: int, error_pid: int, exit_code: int, signal_name: Optional[str] = None, ): super().__init__(msg, error_index, error_pid) self.exit_code = exit_code self.signal_name = signal_name def _wrap(fn, i, args, error_queue): # prctl(2) is a Linux specific system call. # On other systems the following function call has no effect. # This is set to ensure that non-daemonic child processes can # terminate if their parent terminates before they do. _prctl_pr_set_pdeathsig(signal.SIGINT) try: fn(i, *args) except KeyboardInterrupt: pass # SIGINT; Killed by parent, do nothing except Exception: # Propagate exception to parent process, keeping original traceback import traceback error_queue.put(traceback.format_exc()) sys.exit(1) class ProcessContext: def __init__(self, processes, error_queues): self.error_queues = error_queues self.processes = processes self.sentinels = { process.sentinel: index for index, process in enumerate(processes) } def pids(self): return [int(process.pid) for process in self.processes] def join(self, timeout=None): r""" Tries to join one or more processes in this spawn context. If one of them exited with a non-zero exit status, this function kills the remaining processes and raises an exception with the cause of the first process exiting. Returns ``True`` if all processes have been joined successfully, ``False`` if there are more processes that need to be joined. Args: timeout (float): Wait this long before giving up on waiting. """ # Ensure this function can be called even when we're done. if len(self.sentinels) == 0: return True # Wait for any process to fail or all of them to succeed. ready = multiprocessing.connection.wait(self.sentinels.keys(), timeout=timeout,) error_index = None for sentinel in ready: index = self.sentinels.pop(sentinel) process = self.processes[index] process.join() if process.exitcode != 0: error_index = index break # Return if there was no error. if error_index is None: # Return whether or not all processes have been joined. return len(self.sentinels) == 0 # Assume failure. Terminate processes that are still alive. for process in self.processes: if process.is_alive(): process.terminate() process.join() # There won't be an error on the queue if the process crashed. failed_process = self.processes[error_index] if self.error_queues[error_index].empty(): exitcode = self.processes[error_index].exitcode if exitcode < 0: name = signal.Signals(-exitcode).name raise ProcessExitedException( "process %d terminated with signal %s" % (error_index, name), error_index=error_index, error_pid=failed_process.pid, exit_code=exitcode, signal_name=name, ) else: raise ProcessExitedException( "process %d terminated with exit code %d" % (error_index, exitcode), error_index=error_index, error_pid=failed_process.pid, exit_code=exitcode, ) original_trace = self.error_queues[error_index].get() msg = "\n\n-- Process %d terminated with the following error:\n" % error_index msg += original_trace raise ProcessRaisedException(msg, error_index, failed_process.pid) class SpawnContext(ProcessContext): def __init__(self, processes, error_queues): warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.") super(SpawnContext, self).__init__(processes, error_queues) pass # Note: [start_processes] # mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a # more generalized API than mp.spawn. Currently we only document mp.spawn as it's the # CUDA compatible start_method. 
However, in environments like Ipython notebooks, 'fork' # works better than 'spawn'. Every helper function we created for mp.spawn is indeed # general enough, and backends like XLA can reuse them in Colab notebooks as well. # Currently we only add this API first, we can consider adding it to documentation as # needed in the future. def start_processes( fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn" ): mp = multiprocessing.get_context(start_method) error_queues = [] processes = [] for i in range(nprocs): error_queue = mp.SimpleQueue() process = mp.Process( target=_wrap, args=(fn, i, args, error_queue), daemon=daemon, ) process.start() error_queues.append(error_queue) processes.append(process) context = ProcessContext(processes, error_queues) if not join: return context # Loop on join until it returns True or raises an exception. while not context.join(): pass def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``. If one of the processes exits with a non-zero exit status, the remaining processes are killed and an exception is raised with the cause of termination. In the case an exception was caught in the child process, it is forwarded and its traceback is included in the exception raised in the parent process. Args: fn (function): Function is called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This is a requirement imposed by multiprocessing. The function is called as ``fn(i, *args)``, where ``i`` is the process index and ``args`` is the passed through tuple of arguments. args (tuple): Arguments passed to ``fn``. nprocs (int): Number of processes to spawn. join (bool): Perform a blocking join on all processes. daemon (bool): The spawned processes' daemon flag. If set to True, daemonic processes will be created. start_method (string): (deprecated) this method will always use ``spawn`` as the start method. To use a different start method use ``start_processes()``. Returns: None if ``join`` is ``True``, :class:`~ProcessContext` if ``join`` is ``False`` """ if start_method != "spawn": msg = ( "This method only supports start_method=spawn (got: %s).\n" "To use a different start_method use:\n\t\t" " oneflow.multiprocessing.start_processes(...)" % start_method ) warnings.warn(msg) return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") ================================================ FILE: python/oneflow/nn/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
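# ---------------------------------------------------------------------------
# Illustrative sketch (not part of nn/__init__.py; refers to the spawn() helper
# defined in multiprocessing/spawn.py above). `train` and `world_size` are
# example names only.
#
#     import oneflow.multiprocessing as mp
#
#     def train(rank, world_size):
#         print(f"worker {rank} of {world_size}")
#
#     if __name__ == "__main__":
#         world_size = 2
#         mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
# ---------------------------------------------------------------------------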
""" from .modules import * from oneflow.nn.graph import Graph from oneflow.nn.modules.activation import ( ELU, CELU, GELU, QuickGELU, SquareReLU, GLU, Hardsigmoid, Hardshrink, Hardswish, Hardtanh, LeakyReLU, RReLU, LogSigmoid, LogSoftmax, Mish, PReLU, ReLU, ReLU6, Sigmoid, Softmax, Softshrink, Softplus, Tanh, SELU, SiLU, Softsign, Threshold, ) from oneflow.nn.modules.all_reduce import AllReduce from oneflow.nn.modules.batchnorm import ( BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, ) from oneflow.nn.modules.batchnorm_fused import ( FusedBatchNorm1d, FusedBatchNorm2d, FusedBatchNorm3d, ) from oneflow.nn.modules.fused_mlp import FusedMLP from oneflow.nn.modules.container import ( ModuleDict, ModuleList, ParameterDict, ParameterList, Sequential, ) from oneflow.nn.modules.conv import ( Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, ) from oneflow.nn.modules.distance import CosineSimilarity, PairwiseDistance from oneflow.nn.modules.min_max_observer import MinMaxObserver from oneflow.nn.modules.moving_average_min_max_observer import ( MovingAverageMinMaxObserver, ) from oneflow.nn.modules.fake_quantization import FakeQuantization from oneflow.nn.modules.quantization import Quantization from oneflow.nn.modules.distributed_partial_fc_sample import ( DistributedPariticalFCSample, ) from oneflow.nn.modules.dataset import ( COCOReader, CoinFlip, CropMirrorNormalize, OFRecordImageDecoder, OFRecordImageDecoderRandomCrop, OFRecordImageGpuDecoderRandomCropResize, OFRecordRawDecoder, OFRecordRawDecoder as OfrecordRawDecoder, OFRecordReader, OFRecordReader as OfrecordReader, OFRecordBytesDecoder, GPTIndexedBinDataReader, RawReader, ) from oneflow.nn.modules.dropout import Dropout, Dropout1d, Dropout2d, Dropout3d from oneflow.nn.modules.flatten import Flatten from oneflow.nn.modules.instancenorm import ( InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, ) from oneflow.nn.modules.linear import Identity, Linear from oneflow.nn.modules.loss import ( BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, CTCLoss, KLDivLoss, L1Loss, MarginRankingLoss, MSELoss, NLLLoss, SmoothL1Loss, CombinedMarginLoss, TripletMarginLoss, ) from oneflow.nn.modules.normalization import GroupNorm, LayerNorm, RMSLayerNorm, RMSNorm from oneflow.nn.modules.padding import ( ConstantPad1d, ConstantPad2d, ConstantPad3d, ReflectionPad1d, ReflectionPad2d, ReplicationPad1d, ReplicationPad2d, ZeroPad2d, ) from oneflow.nn.modules.pixelshuffle import PixelShufflev2 as PixelShuffle from oneflow.nn.modules.pooling import ( AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, ) from oneflow.nn.modules.sparse import Embedding from oneflow.nn.modules.upsampling import ( Upsample, UpsamplingBilinear2d, UpsamplingNearest2d, ) from oneflow.nn.modules.fold import Fold, Unfold from oneflow.nn.parameter import Parameter from oneflow.nn import utils from . import functional from . import parallel from oneflow.nn.modules.rnn import ( RNNCellBase, RNNCell, LSTMCell, GRUCell, RNNBase, RNN, LSTM, GRU, ) from oneflow.nn.qat.conv import QatConv1d, QatConv2d, QatConv3d class DataParallel(Module): def __init__(self): raise NotImplementedError() ================================================ FILE: python/oneflow/nn/common_types.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Tuple, TypeVar, Union T = TypeVar("T") _scalar_or_tuple_any_t = Union[T, Tuple[T, ...]] _scalar_or_tuple_1_t = Union[T, Tuple[T]] _scalar_or_tuple_2_t = Union[T, Tuple[T, T]] _scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]] _scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]] _scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]] _scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]] _size_any_t = _scalar_or_tuple_any_t[int] _size_1_t = _scalar_or_tuple_1_t[int] _size_2_t = _scalar_or_tuple_2_t[int] _size_3_t = _scalar_or_tuple_3_t[int] _size_4_t = _scalar_or_tuple_4_t[int] _size_5_t = _scalar_or_tuple_5_t[int] _size_6_t = _scalar_or_tuple_6_t[int] _ratio_2_t = _scalar_or_tuple_2_t[float] _ratio_3_t = _scalar_or_tuple_3_t[float] _ratio_any_t = _scalar_or_tuple_any_t[float] ================================================ FILE: python/oneflow/nn/functional/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from oneflow.nn.modules.interpolate import interpolate, interpolate_like from oneflow.nn.modules.affine_grid import affine_grid from oneflow.nn.modules.grid_sample import grid_sample from oneflow.nn.modules.sparse_softmax_cross_entropy import sparse_softmax_cross_entropy from oneflow._C import conv1d from oneflow._C import conv2d from oneflow._C import conv3d from oneflow._C import deconv1d as conv_transpose1d from oneflow._C import deconv2d as conv_transpose2d from oneflow._C import deconv3d as conv_transpose3d from oneflow._C import avg_pool1d from oneflow._C import avg_pool2d from oneflow._C import avg_pool3d from .maxpool import max_pool1d from .maxpool import max_pool2d from .maxpool import max_pool3d from .maxpool import adaptive_max_pool1d from .maxpool import adaptive_max_pool2d from .maxpool import adaptive_max_pool3d from oneflow._C import adaptive_avg_pool1d from oneflow._C import adaptive_avg_pool2d from oneflow._C import adaptive_avg_pool3d from oneflow._C import max_unpool1d from oneflow._C import max_unpool2d from oneflow._C import max_unpool3d from oneflow._C import cosine_similarity, pairwise_distance from oneflow._C import relu from oneflow._C import square_relu from oneflow._C import hardtanh from oneflow._C import hardsigmoid from oneflow._C import hardshrink from oneflow._C import hardswish from oneflow._C import leaky_relu from oneflow._C import rrelu, rrelu_ from oneflow._C import elu from oneflow._C import celu from oneflow._C import selu from oneflow._C import sigmoid from oneflow._C import softshrink from oneflow._C import prelu from oneflow._C import gelu_with_approximate as gelu from oneflow._C import quick_gelu from oneflow._C import glu from oneflow._C import logsigmoid from oneflow._C import log_softmax from oneflow._C import softsign from .softmax import softmax from oneflow._C import softplus from oneflow._C import tanh from oneflow._C import threshold from oneflow._C import silu from oneflow._C import mish from oneflow.nn.modules.normalization import layer_norm, group_norm from oneflow._C import dropout, dropout1d, dropout2d, dropout3d from oneflow._C import smooth_l1_loss from .pad import pad from .batch_norm import batch_norm from oneflow._C import triplet_margin_loss from oneflow._C import ctc_greedy_decoder from .ctc_loss import ctc_loss from oneflow._C import one_hot from oneflow._C import normalize from oneflow._C import mse_loss from oneflow._C import l1_loss from oneflow._C import cross_entropy from oneflow._C import binary_cross_entropy_loss as binary_cross_entropy from oneflow._C import ( binary_cross_entropy_with_logits_loss as binary_cross_entropy_with_logits, ) from oneflow.nn.modules.sparse import embedding from oneflow.nn.modules.linear import linear from oneflow.nn.modules.activation import relu6 from oneflow.nn.modules.upsampling import Upsample as upsample from oneflow._C import unfold from oneflow._C import fold from .deform_conv import deform_conv2d from oneflow._C import kl_div_loss as kl_div from oneflow._C import gumbel_softmax from .depend import depend ================================================ FILE: python/oneflow/nn/functional/batch_norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os from typing import List, Optional from oneflow.framework.tensor import Tensor import oneflow as flow def batch_norm( input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, training: bool = False, momentum: float = 0.1, eps: float = 1e-5, ) -> Tensor: r"""Applies Batch Normalization for each channel across a batch of data. See :class:`~oneflow.nn.BatchNorm1d`, :class:`~oneflow.nn.BatchNorm2d`, :class:`~oneflow.nn.BatchNorm3d` for details. """ if input.ndim == 4 and os.getenv("ONEFLOW_ENABLE_NHWC") == "1": axis = 3 else: axis = 1 return flow._C.normalization( input, running_mean, running_var, weight, bias, axis, eps, momentum, training, ) ================================================ FILE: python/oneflow/nn/functional/ctc_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.tensor import Tensor import oneflow as flow def ctc_loss( log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank=0, reduction="mean", zero_infinity=False, ) -> Tensor: r""" The Connectionist Temporal Classification loss. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.nn.functional.ctc_loss.html See :class:`~oneflow.nn.CTCLoss` for details. Args: log_probs: The logarithmized probabilities of the outputs. targets: Targets cannot be blank. In the second form, the targets are assumed to be concatenated. input_lengths: Lengths of the inputs. target_lengths: Lengths of the targets. blank: Black label, default 0. reduction: Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'`` . Default ``'Mean'``. zero_infinity: Whether to zero infinite losses and the associated gradients. Default ``False``. Example: >>> import oneflow as flow >>> import oneflow.nn as nn >>> import oneflow.nn.functional as F >>> log_probs = flow.tensor( ... [ ... [[-1.1031, -0.7998, -1.5200], [-0.9808, -1.1363, -1.1908]], ... [[-1.2258, -1.0665, -1.0153], [-1.1135, -1.2331, -0.9671]], ... [[-1.3348, -0.6611, -1.5118], [-0.9823, -1.2355, -1.0941]], ... [[-1.3850, -1.3273, -0.7247], [-0.8235, -1.4783, -1.0994]], ... [[-0.9049, -0.8867, -1.6962], [-1.4938, -1.3630, -0.6547]], ... ], ... dtype=flow.float32, ... requires_grad=True, ... 
) >>> targets = flow.tensor([[1, 2, 2], [1, 2, 2]], dtype=flow.int32, device="cuda") >>> input_lengths = flow.tensor([5, 5], dtype=flow.int32) >>> target_lengths = flow.tensor([3, 3], dtype=flow.int32) >>> out = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> out tensor(1.1376, dtype=oneflow.float32, grad_fn=) """ max_target_length = 0 if targets.ndim == 1: max_target_length = target_lengths.max().item() elif targets.ndim == 2: max_target_length = targets.shape[1] return flow._C.ctc_loss( log_probs, targets, input_lengths, target_lengths, max_target_length, blank, zero_infinity, reduction, ) ================================================ FILE: python/oneflow/nn/functional/deform_conv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional, Tuple, Union import oneflow as flow from oneflow.framework.tensor import Tensor def deform_conv2d( input: Tensor, offset: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Tuple[int, int] = (1, 1), padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), mask: Optional[Tensor] = None, ) -> Tensor: r""" Performs Deformable Convolution v2, described in `Deformable ConvNets v2: More Deformable, Better Results `__ if :attr:`mask` is not ``None`` and Performs Deformable Convolution, described in `Deformable Convolutional Networks `__ if :attr:`mask` is ``None``. Args: input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]): offsets to be applied for each position in the convolution kernel. weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): convolution weights, split into groups of size (in_channels // groups) bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None stride (int or Tuple[int, int]): distance between convolution centers. Default: 1 padding (int or Tuple[int, int]): height/width of padding of zeroes around each image. Default: 0 dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1 mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]): masks to be applied for each position in the convolution kernel. Default: None Returns: Tensor[batch_sz, out_channels, out_h, out_w]: result of convolution Examples:: >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> input = flow.rand(4, 3, 10, 10) >>> kh, kw = 3, 3 >>> weight = flow.rand(5, 3, kh, kw) >>> # offset and mask should have the same spatial size as the output >>> # of the convolution. 
In this case, for an input of 10, stride of 1 >>> # and kernel size of 3, without padding, the output size is 8 >>> offset = flow.rand(4, 2 * kh * kw, 8, 8) >>> mask = flow.rand(4, kh * kw, 8, 8) >>> out = F.deform_conv2d(input, offset, weight, mask=mask) >>> out.size() oneflow.Size([4, 5, 8, 8]) """ use_mask = mask is not None if mask is None: mask = flow.zeros((input.shape[0], 0), dtype=input.dtype).to(input.device) stride_h = stride[0] stride_w = stride[1] pad_h = padding[0] pad_w = padding[1] dil_h = dilation[0] dil_w = dilation[1] weights_h, weights_w = weight.shape[-2:] # TODO(yzm): Support rectangle convolution if weights_h != weights_w: raise NotImplementedError("Rectangle convolution is not supported currently.") if use_mask and len(mask.shape) != 4: raise RuntimeError("The dimension of mask tensor weight must be 4") if len(input.shape) != 4: raise RuntimeError("The dimension of input tensor weight must be 4") if len(weight.shape) != 4: raise RuntimeError("The dimension of weight tensor weight must be 4") if len(offset.shape) != 4: raise RuntimeError("The dimension of offset tensor weight must be 4") _, n_in_channels, _, _ = input.shape n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w) n_weight_grps = n_in_channels // weight.shape[1] if n_offset_grps == 0: raise RuntimeError( "The shape of the offset tensor at dimension 1 is not valid. It should " "be a multiple of 2 * weight.size[2] * weight.size[3].\n" f"Got offset.shape[1]={offset.shape[1]}, while 2 * weight.size[2] * weight.size[3]={2 * weights_h * weights_w}" ) return flow._C.deform_conv2d( input, weight, offset, mask, bias, stride_h, stride_w, pad_h, pad_w, dil_h, dil_w, n_weight_grps, n_offset_grps, use_mask, ) ================================================ FILE: python/oneflow/nn/functional/depend.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.tensor import Tensor import oneflow as flow from typing import Union, List def depend(input: Tensor, depend: Union[Tensor, List[Tensor]]) -> Tensor: r""" Add control dependency to guarantee OP A is executed before OP B. Used to prevent OPs from being rearranged or eliminated during graph compilation. Args: input (Tensor): a tensor intended to input OP B depend (Tensor or List[Tensor]): one of the output tensors of OP A (support passing in multiple tensors form different OP) Returns: Tensor: the identity of "input" tensor Examples: >>> import oneflow as flow >>> import oneflow.nn as nn >>> import oneflow.nn.functional as F >>> class Model(nn.Module): ... def __init__(self): ... super().__init__() ... self.OP_A = nn.Linear(128, 128) ... self.OP_B = nn.Linear(128, 128) ... ... def forward(self, x): ... x1 = self.OP_A(x) ... x = F.depend(x, x1) ... return self.OP_B(x) ... >>> model = Model() >>> class Graph(nn.Graph): ... def __init__(self) -> None: ... super().__init__() ... self.model = model ... ... def build(self, x): ... return self.model(x) ... 
>>> graph = Graph() >>> x = flow.randn([1, 128], dtype=flow.float32) >>> y = graph(x) """ # avoid performance loss in eager mode if not input.is_lazy: return input # avoid self-loop if isinstance(depend, Tensor) and input is depend: raise RuntimeError('"input" and "depend" can NOT be the same tensor.') if isinstance(depend, List): for idx, t_depend in enumerate(depend): if input is t_depend: raise RuntimeError( '"input" and "depend[%d]" are the same tensor, which is not allowed.' % idx ) return flow._C.depend(input, depend) ================================================ FILE: python/oneflow/nn/functional/maxpool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow # oneflow._C.max_poolXd returns a TensorTuple, to align torch, # here we return different result according to the param `return_indices`. def max_pool1d( x, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, data_format="channels_first", ): r""" max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False,ceil_mode=False, data_format="channels_first") Applies a 1D max pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.functional.max_pool1d.html. .. note:: The order of :attr:`ceil_mode` and :attr:`return_indices` is different from what seen in :class:`~oneflow.nn.MaxPool1d`, and will change in a future release. See :class:`~oneflow.nn.MaxPool1d` for details. Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in_channels} , iW)`, minibatch dim optional. kernel_size: the size of the window. Can be a single number or a tuple `(kW,)` stride: the stride of the window. Can be a single number or a tuple `(sW,)`. Default: :attr:`kernel_size` padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. dilation: The stride between elements within a sliding window, must be > 0. return_indices: If ``True``, will return the argmax along with the max values.Useful for :class:`oneflow.nn.functional.max_unpool1d` later. ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. """ _max_pool_out = oneflow._C.max_pool1d( x, kernel_size, stride, padding, dilation, return_indices, ceil_mode, data_format, ) if return_indices: return _max_pool_out else: return _max_pool_out[0] def max_pool2d( x, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, data_format="channels_first", ): r""" max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False,data_format="channels_first") Applies a 2D max pooling over an input signal composed of several input planes. 
The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.functional.max_pool2d.html. .. note:: The order of :attr:`ceil_mode` and :attr:`return_indices` is different from what seen in :class:`~oneflow.nn.MaxPool2d`, and will change in a future release. See :class:`~oneflow.nn.MaxPool2d` for details. Args: input: input tensor :math:`(\text{minibatch} , \text{in_channels} , iH , iW)`, minibatch dim optional. kernel_size: size of the pooling region. Can be a single number or a tuple `(kH, kW)` stride: stride of the pooling operation. Can be a single number or a tuple `(sH, sW)`. Default: :attr:`kernel_size` padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. dilation: The stride between elements within a sliding window, must be > 0. return_indices: If ``True``, will return the argmax along with the max values.Useful for :class:`oneflow.nn.functional.max_unpool2d` later. ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. """ _max_pool_out = oneflow._C.max_pool2d( x, kernel_size, stride, padding, dilation, return_indices, ceil_mode, data_format, ) if return_indices: return _max_pool_out else: return _max_pool_out[0] def max_pool3d( x, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, data_format="channels_first", ): r""" max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, data_format="channels_first") Applies a 3D max pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.functional.max_pool3d.html. .. note:: The order of :attr:`ceil_mode` and :attr:`return_indices` is different from what seen in :class:`~oneflow.nn.MaxPool3d`, and will change in a future release. See :class:`~oneflow.nn.MaxPool3d` for details. Args: input: input tensor :math:`(\text{minibatch} , \text{in_channels} , iD, iH , iW)`, minibatch dim optional. kernel_size: size of the pooling region. Can be a single number or a tuple `(kT, kH, kW)` stride: stride of the pooling operation. Can be a single number or a tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. dilation: The stride between elements within a sliding window, must be > 0. return_indices: If ``True``, will return the argmax along with the max values.Useful for :class:`~oneflow.nn.functional.max_unpool3d` later. ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. """ _max_pool_out = oneflow._C.max_pool3d( x, kernel_size, stride, padding, dilation, return_indices, ceil_mode, data_format, ) if return_indices: return _max_pool_out else: return _max_pool_out[0] def adaptive_max_pool1d(input, output_size, return_indices: bool = False): r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.adaptive_max_pool1d.html See :class:`~oneflow.nn.AdaptiveMaxPool1d` for details and output shape. Args: output_size: the target output size (single integer) return_indices: whether to return pooling indices. 
Default: ``False`` """ _out = oneflow._C.adaptive_max_pool1d(input, output_size) if return_indices: return _out else: return _out[0] def adaptive_max_pool2d( input, output_size, return_indices: bool = False, data_format="channels_first" ): r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.adaptive_max_pool2d.html See :class:`~oneflow.nn.AdaptiveMaxPool2d` for details and output shape. Args: output_size: the target output size (single integer or double-integer tuple) return_indices: whether to return pooling indices. Default: ``False`` """ _out = oneflow._C.adaptive_max_pool2d(input, output_size, data_format=data_format) if return_indices: return _out else: return _out[0] def adaptive_max_pool3d(input, output_size, return_indices: bool = False): r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.adaptive_max_pool3d.html See :class:`~oneflow.nn.AdaptiveMaxPool3d` for details and output shape. Args: output_size: the target output size (single integer or triple-integer tuple) return_indices: whether to return pooling indices. Default: ``False`` """ _out = oneflow._C.adaptive_max_pool3d(input, output_size) if return_indices: return _out else: return _out[0] ================================================ FILE: python/oneflow/nn/functional/pad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import List from oneflow.framework.tensor import Tensor import oneflow as flow def pad( input: Tensor, pad: List[int], mode: str = "constant", value: float = 0.0 ) -> Tensor: r"""Pads tensor. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.pad.html. Padding size: The padding size by which to pad some dimensions of :attr:`input` are described starting from the last dimension and moving forward. :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions of ``input`` will be padded. For example, to pad only the last dimension of the input tensor, then :attr:`pad` has the form :math:`(\text{padding_left}, \text{padding_right})`; to pad the last 2 dimensions of the input tensor, then use :math:`(\text{padding_left}, \text{padding_right},` :math:`\text{padding_top}, \text{padding_bottom})`; to pad the last 3 dimensions, use :math:`(\text{padding_left}, \text{padding_right},` :math:`\text{padding_top}, \text{padding_bottom}` :math:`\text{padding_front}, \text{padding_back})`. Padding mode: See :class:`oneflow.nn.ConstantPad2d`, :class:`oneflow.nn.ReflectionPad2d`, and :class:`oneflow.nn.ReplicationPad2d` for concrete examples on how each of the padding modes works. Constant padding is implemented for arbitrary dimensions. 
Replicate and reflection padding is implemented for padding the last 3 dimensions of 5D input tensor, or the last 2 dimensions of 4D input tensor, or the last dimension of 3D input tensor. Note: When using the CUDA backend, this operation may induce nondeterministic behaviour in its backward pass that is not easily switched off. Args: input (Tensor): N-dimensional tensor pad (tuple): m-elements tuple, where :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even. mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'constant'`` value: fill value for ``'constant'`` padding. Default: ``0`` Examples:: >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> t4d = flow.empty(3, 3, 4, 2) >>> p1d = (1, 1) >>> out = F.pad(t4d, p1d) >>> out.size() oneflow.Size([3, 3, 4, 4]) """ assert len(pad) % 2 == 0, "Padding length must be divisible by 2" assert len(pad) // 2 <= input.dim(), "Padding length too large" if mode == "constant": return flow._C.pad(input, pad, mode="constant", value=value) else: assert ( value == 0.0 ), 'Padding mode "{}"" doesn\'t take in value argument'.format(mode) if len(pad) == 2 and (input.dim() == 2 or input.dim() == 3): if mode == "reflect": return flow._C.pad(input, pad, mode="reflect") elif mode == "replicate": return flow._C.pad(input, pad, mode="replicate") elif mode == "circular": raise NotImplementedError( "1D circular padding are not supported for now" ) else: raise NotImplementedError elif len(pad) == 4 and (input.dim() == 3 or input.dim() == 4): if mode == "reflect": return flow._C.pad(input, pad, mode="reflect") elif mode == "replicate": return flow._C.pad(input, pad, mode="replicate") elif mode == "circular": raise NotImplementedError( "2D circular padding are not supported for now" ) else: raise NotImplementedError elif len(pad) == 6 and (input.dim() == 4 or input.dim() == 5): if mode == "reflect": raise NotImplementedError( "3D reflect padding are not supported for now" ) elif mode == "replicate": raise NotImplementedError( "3D replicate padding are not supported for now" ) elif mode == "circular": raise NotImplementedError( "3D circular padding are not supported for now" ) else: raise NotImplementedError else: raise NotImplementedError( "Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now" ) ================================================ FILE: python/oneflow/nn/functional/softmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os from typing import List, Optional from oneflow.framework.tensor import Tensor import oneflow as flow # ref https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py def softmax(input: Tensor, dim: Optional[int] = None, dtype=None) -> Tensor: r"""Applies a softmax function. 
Softmax is defined as: :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}` It is applied to all slices along dim, and will re-scale them so that the elements lie in the range `[0, 1]` and sum to 1. See :class:`~oneflow.nn.Softmax` for more details. Args: input (Tensor): input dim (int): A dimension along which softmax will be computed. dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor. If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. .. note:: This function doesn't work directly with NLLLoss, which expects the Log to be computed between the Softmax and itself. Use log_softmax instead (it's faster and has better numerical properties). """ if dtype is None: ret = flow._C.softmax(input, dim) else: ret = flow._C.softmax(input.to(dtype), dim) return ret ================================================ FILE: python/oneflow/nn/graph/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from .graph import Graph from .proxy import Proxy from .graph_block import GraphModule from .graph_block import GraphTensor ================================================ FILE: python/oneflow/nn/graph/cache.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import weakref from collections import deque, OrderedDict from typing import Dict, Union from oneflow.framework.args_tree import ArgsTree from oneflow.framework.tensor import Tensor import oneflow class LRUCache(object): _cnt: int = 0 def __init__(self, cache_size, keep_the_1st=True): assert cache_size >= 2 self.cache_size = cache_size self.hash_map = dict() self.keep_the_1st = keep_the_1st self.queue = deque() def is_empty(self): return len(self.hash_map) == 0 def is_full(self): return len(self.hash_map) >= self.cache_size def pop(self): if len(self.queue) == 0: return None pop_key = self.queue.pop() value = self.hash_map.pop(pop_key) del value return pop_key def set(self, key, value): new_key = None old_key = None if key in self.hash_map: return new_key, old_key if self.is_full(): old_key = self.pop() assert old_key is not None, f"Cache size is {self.cache_size}, at least 2." 
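# After the possible eviction above there is room again. Note that when
# keep_the_1st is True, the very first entry (the base graph) is kept in
# hash_map but never pushed onto the queue, so pop() can never evict it.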
assert not self.is_full() if not (self.keep_the_1st and self.is_empty()): self.queue.appendleft(key) value._oneflow_graph_cache_order = LRUCache._cnt LRUCache._cnt += 1 self.hash_map[key] = value new_key = key return new_key, old_key def get(self, key): if key in self.hash_map: if key in self.queue: self.queue.remove(key) self.queue.appendleft(key) return self.hash_map[key] return None def items(self): for (key, value) in self.hash_map.items(): yield (key, value) class AvoidRecursiveCacheCall(object): def __init__(self, graph) -> None: self._g = graph self._prev_flag = self._g._run_with_cache def __enter__(self): self._g._run_with_cache = False def __exit__(self, exc_type, exc_val, exc_tb): self._g._run_with_cache = self._prev_flag class GraphCache(object): def __init__(self, base_graph, cache_size=10, enable_graph_shared=True): assert base_graph is not None and isinstance(base_graph, weakref.ProxyTypes) self._base_graph = base_graph self._cache_size = cache_size self._cache = None self._enable_shared = enable_graph_shared def set_cache_size(self, cache_size): self._cache_size = cache_size def enable_shared(self, enabled=True): self._enable_shared = enabled def __call__(self, *args, **kwargs): graph = self.get_graph(*args, **kwargs) with AvoidRecursiveCacheCall(graph): return graph(*args, **kwargs) def _compile(self, *args, **kwargs): graph = self.get_graph(*args, **kwargs) with AvoidRecursiveCacheCall(graph): return graph._compile(*args, **kwargs) def runtime_state_dict( self, destination=None, with_eager=False, ) -> Dict[str, Dict[str, Union[Dict[str, Tensor], str]]]: if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() for (key, graph) in self._cache.items(): with AvoidRecursiveCacheCall(graph): state_dict = graph.runtime_state_dict(with_eager=with_eager) state_dict["cache_order"] = graph._oneflow_graph_cache_order state_dict["cache_key"] = key destination[state_dict["graph_name"]] = state_dict return destination @staticmethod def runtime_state_dict_to( state_dict: Union[ Dict[str, Union[Dict[str, Tensor], str]], Dict[str, Dict[str, Union[Dict[str, Tensor], str]]], ], device: str, ) -> Union[ Dict[str, Union[Dict[str, Tensor], str]], Dict[str, Dict[str, Union[Dict[str, Tensor], str]]], ]: destination = OrderedDict() destination._metadata = OrderedDict() for (key, sub_state_dict) in state_dict.items(): dest_sub_state_dict = oneflow.nn.Graph.runtime_state_dict_to( sub_state_dict, device ) dest_sub_state_dict["cache_order"] = sub_state_dict["cache_order"] dest_sub_state_dict["cache_key"] = sub_state_dict["cache_key"] destination[key] = dest_sub_state_dict return destination def _init_and_get_a_graph_in_cache(self, cache_key): self._base_graph._print( 0, 0, self._base_graph._shallow_repr() + f" is creating a graph cache with key {cache_key}.", ) cur_is_base = False if self._cache.is_empty(): # Has no graph yet cur_is_base = True graph = self._base_graph else: # Create new graph from base graph = self._base_graph.__class__( *self._base_graph._cached_init_args, **self._base_graph._cached_init_kwargs, ) graph._run_with_cache = False graph._dynamic_input_graph_cache = None graph._cached_init_args = None graph._cached_init_kwargs = None if self._enable_shared is True: if cur_is_base: graph.enable_shared() else: graph.share_from(self._base_graph) new_key, old_key = self._cache.set(cache_key, graph) if old_key is not None: self._base_graph._print( 0, 0, self._base_graph._shallow_repr() + f" cache is full(cache size {self._cache_size}), has deleted an old graph 
cache with key {old_key}.", ) assert new_key is not None return graph def load_runtime_state_dict( self, state_dict: Dict[str, Dict[str, Union[Dict[str, Tensor], str]]], *, warmup_with_run: bool = False, ) -> None: graph_dict = dict() for _, sub_state_dict in state_dict.items(): cache_order = sub_state_dict["cache_order"] graph_dict[cache_order] = sub_state_dict if self._cache is None: self._cache = LRUCache(self._cache_size) for _, sub_state_dict in sorted(graph_dict.items()): cache_key = sub_state_dict["cache_key"] graph = self._cache.get(cache_key) assert graph is None graph = self._init_and_get_a_graph_in_cache(cache_key) with AvoidRecursiveCacheCall(graph): graph.load_runtime_state_dict( sub_state_dict, warmup_with_run=warmup_with_run ) def gen_key(self, *args, **kwargs): flattened_shapes = [] args_tree = ArgsTree((args, kwargs), False) for arg in args_tree.iter_nodes(): if isinstance(arg, Tensor): flattened_shapes.append(arg.shape) return tuple(flattened_shapes) def get_graph(self, *args, **kwargs): if self._cache is None: self._cache = LRUCache(self._cache_size) cache_key = hash(self.gen_key(*args, **kwargs)) graph = self._cache.get(cache_key) # Create graph if graph is None: self._base_graph._print( 0, 0, self._base_graph._shallow_repr() + " got a new input shape, is compiling a new graph.", ) graph = self._init_and_get_a_graph_in_cache(cache_key) return graph ================================================ FILE: python/oneflow/nn/graph/graph.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import logging import warnings import os import sys import time import inspect import weakref from collections import OrderedDict from functools import partial, wraps from typing import Dict, Optional, Union, List, Callable from google.protobuf import text_format from copy import deepcopy import oneflow import oneflow._oneflow_internal import oneflow.core.job.job_pb2 as job_pb import oneflow.framework.c_api_util as c_api_util import oneflow.framework.graph_build_util as graph_build_util import oneflow.framework.session_context as session_ctx from oneflow.amp import GradScaler, StaticGradScaler from oneflow.env import get_rank from oneflow.framework.multi_client_session import MultiClientSession from oneflow.framework.tensor import Tensor, TensorTuple from oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple from oneflow.nn.graph.proxy import ( Proxy, GraphBlockType, get_proxy_cls, GraphModule, GraphTensor, ) from oneflow.nn.graph.graph_config import GraphConfig from oneflow.nn.graph.optimizer import OptDict, VariableConfig from oneflow.nn.graph.util import ( add_indent, operators_repr, GraphIR, seq_to_func_return, sys_exc_error_msg, _rsd_sub_destination_to, _job_to, _plan_to, ) from oneflow.framework.args_tree import ArgsTree from oneflow.nn.modules.module import Module from oneflow.nn.optimizer.lr_scheduler import LRScheduler from oneflow.optim.optimizer import Optimizer class Graph(object): r"""Base class for training or evaluating a neural network in static graph mode. To use static graph mode for model training or evaluation in OneFlow, you should: 1. Define your customized graph as a subclass of ``nn.Graph``. 2. Add ``super().__init__()`` in your subclass's ``__init__()``. 3. Add modules to your graph as regular attributes. 4. Define computation logical in ``build()`` method. 5. Instantiate your graph then call it. For example: .. code-block:: python >>> import oneflow as flow >>> class LinearGraph(flow.nn.Graph): ... def __init__(self): ... super().__init__() ... # Add a module to the graph. ... self.linear = flow.nn.Linear(3, 8, False) ... def build(self, x): ... # Use the module to build the computation logic of the graph. ... return self.linear(x) # Instantiate the graph >>> linear_graph = LinearGraph() >>> x = flow.randn(4, 3) # First call on graph will run graph's build() method to # trace a computatioin graph. Then the computation graph will be # optimized and executed for the first time. >>> linear_graph(x).shape oneflow.Size([4, 8]) # Later call on graph will execute the computation graph directly. >>> linear_graph(x).shape oneflow.Size([4, 8]) Note: nn.Graph cannot be nested at the moment. """ _child_init_cnt = dict() def __init__( self, *, enable_get_runtime_state_dict: bool = False, debug_v_level: int = -1, debug_ranks: Optional[Union[int, List[int]]] = None, debug_max_py_stack_depth: int = 2, debug_only_user_py_stack=True, debug_op_repr_with_py_stack=False, ): """ Initializes internal Graph states. It MUST be called in ``__init__`` method of subclass. For example: .. code-block:: python >>> import oneflow as flow >>> class CustomGraph(flow.nn.Graph): ... def __init__(self): ... super().__init__() # MUST be called ... # Then define the graph attributes ... def build(self): ... 
pass """ self._generate_name() self.config = GraphConfig() self._blocks = OrderedDict() self._opts = [] self._verbose = False self._grad_scaler = None self._variables_conf = OrderedDict() self._additional_variable_tobe_loaded = OrderedDict() self._is_compiled = False self._is_user_mode = False # Default is local view self._is_global_view = False # Optimize the overhead of graph input/output process self._is_simple_tuple_input = False self._is_simple_tuple_output = False self._outputs_buffer_size = 2 self._cur_index_of_ouputs_buffer = 0 # For graph level op rewrite self._unique_global_op_dict = dict() self._unique_identity_op_dict = dict() # Graph compilation related. # forward graph job proto self._forward_job_proto = None # forward, backward and optimized graph job proto self._full_job_proto = None # completed graph job proto self._compiled_job_proto = None self._job_id = None self._args_repr = [] self._outs_repr = [] self._oneflow_internal_graph_ir__ = None enalbe_lazy_separate_compile = os.environ.get( "ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE" ) if enalbe_lazy_separate_compile != None and enalbe_lazy_separate_compile == "1": os.environ["ONEFLOW_LAZY_COMPILE_MODE"] = "rank_per_process" # Separate compile mode only works with nccl use compute stream and logical chain. os.environ["ENABLE_LOGICAL_CHAIN"] = "1" oneflow.boxing.nccl.enable_use_compute_stream(True) self._session = session_ctx.GetDefaultSession() assert type(self._session) is MultiClientSession self._session.TryInit() self._c_nn_graph = None self.env_enable_mlir_inference_opt = None # For build graph from another graph with different input shape. self._enable_shared_from_this = False self._build_with_shared_graph = False # For load graph from runtime states. self.enable_save_runtime_state_dict(enable_get_runtime_state_dict) self._is_from_runtime_state_dict = False # For run graph with dynamic shape cache self._run_with_cache = False # For debug self._debug = False self._debug_min_s_level = 2 self._debug_max_v_level = 0 self._debug_max_py_stack_depth = 2 self._debug_op_repr_with_py_stack = False self._debug_only_user_py_stack = True self.debug( debug_v_level, ranks=debug_ranks, max_py_stack_depth=debug_max_py_stack_depth, only_user_py_stack=debug_only_user_py_stack, op_repr_with_py_stack=debug_op_repr_with_py_stack, ) def build(self, *args, **kwargs): r"""The ``build()`` method must be overridden to define neural network computaion logic. The ``build()`` method of nn.Graph is very similar to the ``forward()`` method of nn.Module. It is used to describe the computatioin logical of a neural network. When a graph object being called for the first time, the ``build()`` method will be called implicitly to build the computatioin graph. Make sure to call modules's ``train()`` or ``eval()`` method before the first call of your graph to make the module executing the right training or evaluation logic if needed. For example: .. code-block:: python >>> import oneflow as flow >>> linear = flow.nn.Linear(3, 8, False) >>> class MyGraph(flow.nn.Graph): ... def __init__(self): ... super().__init__() ... self.model = linear ... def build(self, x): ... 
return self.model(x) >>> linear_graph = MyGraph() >>> x = flow.randn(4, 3) >>> linear.eval() # make the linear module execute in evaluation mode Linear(in_features=3, out_features=8, bias=False) >>> y = linear_graph(x) # The build() method is called implicitly Note: ``build()`` method's inputs and outputs support list/tuple/dict, but the item in them must be one of these types: * ``Tensor`` * ``None`` """ raise NotImplementedError( "nn.Graph.build() method must be overridden when subclassing the nn.Graph." ) def __call__(self, *args, **kwargs): r"""Call nn.Graph subclass instance to run your customized graph. Call your customized graph after the instantiation: For example: .. code-block:: python g = CustomGraph() out_tensors = g(input_tensors) The inputs of the ``__call__`` method must match the inputs of the ``build()`` method. And the ``__call__`` method will return outputs matching the outputs of the ``build()`` method. Note: The first call takes longer than later calls, because nn.Graph will do the computation graph generation and optimization at the first call. Do not override this function. """ # For caching graphs with dynamic input shapes. if self._run_with_cache: return self._dynamic_input_graph_cache(*args, **kwargs) if not self._is_compiled: self._compile(*args, **kwargs) return self.__run(*args, **kwargs) def add_optimizer( self, optim: Optimizer, *, lr_sch: LRScheduler = None, is_sparse: bool = False, ): r"""Add an optimizer and a learning rate scheduler to the graph. To do training with nn.Graph, you should do 2 more things: 1. Add at least one optimizer (learning rate schedulers are optional) with the ``add_optimizer()`` method. 2. Call the loss tensor's ``backward()`` method in the ``build()`` method. Note that the computation graph will automatically execute these methods: * optimizer's ``clip_grad()`` if an optimizer is set to do gradient clipping. * optimizer's ``step()``. * optimizer's ``zero_grad()``. * learning rate scheduler's ``step()``. Also note that only scalar tensors are allowed to call ``backward()`` in ``nn.Graph.build()`` for the moment. So you may call methods such as ``Tensor.mean()`` to make the loss tensor a scalar tensor. Note: If you want to output the learning rate information for each step, set the ``verbose`` parameter of the ``lr_scheduler`` to ``True``, and you will see the result at rank 0. This feature is the same as eager mode. For example: .. code-block:: python >>> import oneflow as flow >>> loss_fn = flow.nn.MSELoss(reduction="sum") >>> model = flow.nn.Sequential(flow.nn.Linear(3, 1), flow.nn.Flatten(0, 1)) >>> optimizer = flow.optim.SGD(model.parameters(), lr=1e-6) >>> class LinearTrainGraph(flow.nn.Graph): ... def __init__(self): ... super().__init__() ... self.model = model ... self.loss_fn = loss_fn ... # Add an optimizer ... self.add_optimizer(optimizer) ... def build(self, x, y): ... y_pred = self.model(x) ... loss = self.loss_fn(y_pred, y) ... # Call the loss tensor's backward(), the loss tensor must be a scalar tensor ... loss.backward() ... return loss >>> linear_graph = LinearTrainGraph() >>> x = flow.randn(10, 3) >>> y = flow.randn(10) >>> model.train() # make the model execute in training mode Sequential( (0): Linear(in_features=3, out_features=1, bias=True) (1): Flatten(start_dim=0, end_dim=1) ) >>> for t in range(3): ... loss = linear_graph(x, y) Args: optim (oneflow.optim.Optimizer): The optimizer. lr_sch : The learning rate scheduler, see oneflow.optim.lr_scheduler. is_sparse: When set to True, treat optim as a sparse optimizer. Default is False.
""" opt_dict = dict() assert optim is not None, "optimizer cannot be None" assert isinstance( optim, Optimizer ), "optimizer must be an instance of Optimizer" opt_dict["optim"] = optim opt_dict["is_sparse"] = bool(is_sparse) if lr_sch is not None: assert isinstance(lr_sch, LRScheduler) assert ( lr_sch.optimizer is optim ), "lr_scheduler's optimizer must be the same optimizer in add_optimizer." opt_dict["lr_sch"] = lr_sch self._verbose = opt_dict["lr_sch"].verbose rank = get_rank() if rank != 0: self._verbose = False oneflow._oneflow_internal.SetGraphLRVerbose(self._verbose) self._opts.append(opt_dict) # Set the training config if there is an optimizer add in graph. if len(self._opts) == 1: self.config._train(True) def set_grad_scaler(self, grad_scaler: GradScaler = None): r"""Set the GradScaler for gradient and loss scaling.""" assert isinstance(grad_scaler, (GradScaler, StaticGradScaler)) self._grad_scaler = grad_scaler def state_dict( self, destination=None ) -> Dict[str, Union[Dict[str, Tensor], Tensor]]: r"""Returns a dictionary containing a whole state of the graph. States of modules/optimizers/lr schedulers in a graph are included. Keys of modules' state dict are corresponding to their name in the graph. Values of modules' state dict are corresponding to their nn.Module's state dict. Other keys and tensors are states of optimizers/lr schedulers/etc. Returns: dict: a dictionary containing the whole state of the graph. """ # Sync to make sure states has been updated. oneflow._oneflow_internal.eager.Sync() if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() # Get states from sub module block for name, block in self._blocks.items(): assert block.to(GraphModule).type == GraphBlockType.MODULE sub_destination = OrderedDict() sub_destination._metadata = OrderedDict() module = block.to(Module) if module is not None: module.state_dict( sub_destination, "", keep_vars=False, ) destination[name] = sub_destination # Get additional states. # Additional variables are states in Optimizer/LRScheduler and free eager tensors of nn.Graph. if self._is_compiled: # Get from _c_nn_graph. additional_var_names = self._c_nn_graph.additional_var_names additional_var_tensors = self._c_nn_graph.additional_var_tensors assert len(additional_var_names) == len(additional_var_tensors) for i in range(len(additional_var_names)): additional_tensor = additional_var_tensors[i] if not self._is_global_view and additional_tensor.is_global: additional_tensor = additional_tensor.to_local() destination[additional_var_names[i]] = additional_tensor else: # Get from loaded dict. for name, item in self._additional_variable_tobe_loaded.items(): destination[name] = item return destination def load_state_dict( self, state_dict: Dict[str, Union[Dict[str, Tensor], Tensor]], strict: bool = True, ): r"""Copies module's states and other graph states from :attr:`state_dict` into this graph. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`nn.Graph.state_dict` function. Args: state_dict (dict): a dict containing module's states and other graph states. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this graph's :meth:`nn.Graph.state_dict` function. Default: ``True``. Note: nn.Graph's state dict can only be loaded before the first call of a graph. """ assert ( not self._is_compiled ), "nn.Graph's state dict can only be loaded before the first call of a graph." 
# Additional variables are states in Optimizer or LRScheduler of nn.Graph. for name, item in state_dict.items(): if name in self._blocks: # 1 load parameter/buffer to Modules self._blocks[name].to(Module).load_state_dict(item, strict) else: # 2 store other states to CNNGraph, CNNGraph loads them after the job pass assert isinstance(item, Tensor) self._additional_variable_tobe_loaded[name] = item @property def name(self): r"""Name auto-generated for this graph.""" return self._name @property def is_compiled(self): r"""Whether this graph is compiled or not.""" return self._is_compiled @property def training(self): r"""In training mode if the graph has an optimizer.""" return self.config.training def debug( self, v_level: int = -1, *, ranks: Optional[Union[int, List[int]]] = None, max_py_stack_depth: int = 2, only_user_py_stack=True, op_repr_with_py_stack=False, ) -> None: r"""Open or close debug mode of the graph. If in debug mode, logs of computation graph building info or warnings will be printed. Otherwise, only errors will be printed. Each nn.Module inside a nn.Graph also has a debug() method to enable debug mode. Use ``v_level`` to choose the verbose debug info level, default level is 0, max level is 3. ``v_level`` -1 will disable the debug mode of the graph (i.e. no info will be printed). ``v_level`` 0 will print warnings and graph building stages. ``v_level`` 1 will additionally print graph build info of each nn.Module. ``v_level`` 2 will additionally print graph build info of each operation. ``v_level`` 3 will additionally print more detailed info of each operation. Use ``ranks`` to choose which rank to print the debug information on. Use ``max_py_stack_depth`` to specify the max Python stack depth for the debug information. Use ``only_user_py_stack`` to only print the operators' locations which are from users' code or models. Use ``op_repr_with_py_stack`` to print operators' locations when printing nn.Graph's repr. For example: .. code-block:: python g = CustomGraph() g.debug() # Open debug mode out_tensors = g(input_tensors) # Will print log for debug at the first call Args: v_level (int): choose the verbose debug info level, default v_level is 0, max v_level is 3. v_level can be set to -1 to close the debug mode. ranks (int or list(int)): choose ranks to print the debug information. Default rank ``0``. You can choose any valid rank. ``-1`` means debug on all ranks. max_py_stack_depth(int): the maximum depth for the Python stack debug information. Default: ``2``. only_user_py_stack(bool): only print the operators' locations from users' code. Default: ``True``. op_repr_with_py_stack(bool): print operators' locations when printing nn.Graph's repr. Default: ``False``. """ assert isinstance(v_level, int) assert v_level >= -1, "The min verbose debug info level is -1." assert v_level <= 3, "The max verbose debug info level is 3." assert max_py_stack_depth >= 0, "The min max stack depth is 0."
assert isinstance(max_py_stack_depth, int) assert isinstance(only_user_py_stack, bool) assert isinstance(op_repr_with_py_stack, bool) if ranks is None: rank_list = [0] elif isinstance(ranks, int): rank_list = [ranks] elif isinstance(ranks, list): rank_list = ranks else: raise ValueError("ranks must be int or List[int].") my_rank = get_rank() if -1 in rank_list or my_rank in rank_list: self._debug = v_level >= 0 if self._debug: self._debug_min_s_level = 0 self._debug_max_v_level = max(0, v_level) for name, block in self._blocks.items(): assert block.to(GraphModule).type == GraphBlockType.MODULE block.to(GraphModule).debug( v_level, ranks=ranks, max_py_stack_depth=max_py_stack_depth, only_user_py_stack=only_user_py_stack, op_repr_with_py_stack=op_repr_with_py_stack, ) self._debug_max_py_stack_depth = max_py_stack_depth self._debug_op_repr_with_py_stack = op_repr_with_py_stack self._debug_only_user_py_stack = only_user_py_stack def __repr__(self): r"""For printing the graph structure. The graph structure can be printed after graph instantiation. After the first call of graph, inputs and outputs will be added to the graph structure. For example: .. code-block:: python g = CustomGraph() print(g) out_tensors = g(input_tensors) print(g) # Inputs and Outputs infos are added """ child_lines = [] child_lines.append(add_indent(repr(self.config), 2)) if len(self._args_repr) > 0: for in_str in self._args_repr: input_str = add_indent(in_str, 2) child_lines.append(input_str) if len(self._blocks) > 0: for n, m in self._blocks.items(): mod_str = repr(m) mod_str = add_indent(mod_str, 2) child_lines.append(mod_str) for op_str in self._ops_repr(): child_lines.append(add_indent(op_str, 2)) if len(self._outs_repr) > 0: for out_str in self._outs_repr: output_str = add_indent(out_str, 2) child_lines.append(output_str) main_str = self._shallow_repr() + ": (" if len(child_lines) > 0: main_str += "\n " + "\n ".join(child_lines) + "\n" main_str += ")" return main_str def _shallow_repr(self): shallow_repr = "(GRAPH:" + self._name + ":" + self.__class__.__name__ + ")" return shallow_repr def _ops_repr(self): r"""Generate operators' string representation of this graph """ if self._compiled_graph_proto is not None: module_conf = self._compiled_graph_proto.module_name2module_conf[self.name] if self._oneflow_internal_graph_ir__ is None: self._oneflow_internal_graph_ir__ = GraphIR(self._compiled_graph_proto) return operators_repr( module_conf.ops, self._oneflow_internal_graph_ir__, self._debug_op_repr_with_py_stack, ) return [] def __print(self, s_level=2, v_level=0, msg=None): r"""Do print according to info level.""" assert isinstance(s_level, int) assert isinstance(v_level, int) assert isinstance(msg, str) or isinstance(msg, Callable) if s_level >= self._debug_min_s_level: if (s_level > 0) or (s_level == 0 and v_level <= self._debug_max_v_level): if isinstance(msg, str): print(msg, flush=True) elif isinstance(msg, Callable): print(msg(), flush=True) def _print(self, s_level=2, v_level=0, msg=None): self.__print(s_level, v_level, msg) @property def _config_proto(self): return self.config.proto @property def _optimization_conf_proto(self): return self._session.resource @property def _graph_proto(self): if not self._is_compiled: self.__print( 2, 0, f"[ERROR]{self._shallow_repr()} has not been compiled, so it's graph proto is None." 
" You can call the graph to trigger it's compilation.", ) return self._forward_job_proto @property def _full_graph_proto(self): if self._full_job_proto is None: self.__print( 2, 0, f"[ERROR]{self._shallow_repr()} has not been compiled, so it's full graph proto is None." " You can call the graph to trigger it's compilation.", ) return self._full_job_proto @_full_graph_proto.setter def _full_graph_proto(self, full_job_proto): assert ( not self._is_compiled ), "nn.Graph's full graph proto can only be set before the first compilation." self._full_job_proto = full_job_proto self._c_nn_graph.job = full_job_proto.SerializeToString() @property def _compiled_graph_proto(self): if not self._is_compiled and self._compiled_job_proto is None: self.__print( 2, 0, f"[ERROR]{self._shallow_repr()} has not been compiled, so it's compiled graph proto is None." " You can call the graph to trigger it's compilation.", ) return self._compiled_job_proto def _generate_name(self): child_name = self.__class__.__name__ if Graph._child_init_cnt.get(child_name) is None: Graph._child_init_cnt[child_name] = 0 self._name = child_name + "_" + str(Graph._child_init_cnt[child_name]) Graph._child_init_cnt[child_name] += 1 def _state(self): for _, b in self._blocks.items(): pa_gen = b.parameters(recurse=True) for pa in pa_gen: yield pa bu_gen = b.buffers(recurse=True) for bu in bu_gen: yield bu def __ensure_state_tensors_contiguous(self): for state_block in self._state(): state_tensor = state_block.to(Tensor) if not state_tensor.is_contiguous(): state_tensor.contiguous_() def _filter_states(self): state_tensor_set = set() state_tensors = [] state_op_names = [] for state_block in self._state(): state_tensor = state_block.to(Tensor) # If any state tensor is global tensor, graph is in global view. if state_tensor.is_global: self._is_global_view = True if state_tensor in state_tensor_set: continue op_name = ( state_block.to(GraphTensor).name_prefix + state_block.to(GraphTensor).name ) state_tensor_set.add(state_tensor) state_tensors.append(state_tensor) state_op_names.append(op_name) if state_block.to(GraphTensor).type == GraphBlockType.PARAMETER: self._variables_conf[state_tensor] = VariableConfig(op_name) self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors) self._eager_state_op_names = deepcopy(state_op_names) return state_op_names def _generate_config_proto(self): self.config.proto.job_name = self._name self._outputs_buffer_size = self.config._outputs_buffer_size if self._grad_scaler is not None: self._grad_scaler._generate_conf_for_graph(self.config.proto.train_conf) for opt in self._opts: opt_dict = OptDict(opt) self.config._generate_optimizer_and_variable_configs( opt_dict, self._variables_conf ) def _create_states_builder(self): state2lazy_builder = dict() for state_block in self._state(): state_tensor = state_block.to(Tensor) op_name = ( state_block.to(GraphTensor).name_prefix + state_block.to(GraphTensor).name ) if state_tensor in state2lazy_builder: # Differe tensor block shares the same tensor, so they need to share the same # builder. 
state_block.set_lazy_origin_builder(state2lazy_builder[state_tensor]) else: if state_block.to(GraphTensor).type == GraphBlockType.PARAMETER: assert state_tensor in self._variables_conf state_config = self._variables_conf[state_tensor] op_name = state_config.name else: state_config = None # Init a new lazy tensor builder state_block.lazy_origin_builder().name = op_name state_block.lazy_origin_builder().method = partial( graph_build_util.build_graph_state, op_name, state_tensor, state_config, ) state2lazy_builder[state_tensor] = state_block.lazy_origin_builder() def _mark_variable_gradients(self): variable = [] gradients = [] for state_block in self._state(): if ( state_block.to(GraphTensor).type == GraphBlockType.PARAMETER and state_block.to(Tensor).grad is not None and state_block.to(Tensor).grad.is_lazy ): variable.append(state_block.to(Tensor)) gradients.append(state_block.to(Tensor).grad) oneflow._oneflow_internal.nn.graph.MarkVariableGradients(variable, gradients) @staticmethod def trace(func): """Trace a function to do static graph and run with nn.Graph. After decorating a function with ``trace``, the function is turned into a naive `nn.Graph`. Note: This is just a quick way to run a simple function with nn.Graph. If you want to do training or model save/load, customize a nn.Graph class instead, donot use ``trace``. For example: .. code-block:: python >>> import oneflow as flow >>> @flow.nn.Graph.trace ... def test_func(x): ... return x * 2 >>> input = flow.tensor((1, 2), dtype=flow.float32) >>> out = test_func(input) >>> out tensor([2., 4.], dtype=oneflow.float32) .. Feature Stage of Feature [trace]. - Maintainer List [@strint] - Current Stage [Pre-alpha, note that this is an experimental feature and maybe removed without notice.] """ assert inspect.isfunction( func ), f"nn.Graph.trace only support function currently, so {func} must be a function." graph_cls_name = func.__name__ + "_graph" def init(self): super(graph_cls_name, self).__init__() def build(self, *args, **kwargs): return func(*args, **kwargs) graph_cls_name = type( graph_cls_name, (Graph,), {"__init__": init, "build": build,}, ) a_graph = graph_cls_name() return a_graph def _compile(self, *args, **kwargs): if self._run_with_cache: return self._dynamic_input_graph_cache._compile(*args, **kwargs) if not self._is_compiled: if not self._build_with_shared_graph: return self._compile_new(*args, **kwargs) else: return self._compile_from_shared(*args, **kwargs) else: warnings.warn( f"{self._shallow_repr()} has been compiled, no need to compile again." ) return def _compile_new(self, *args, **kwargs): if ( len(args) != 0 and isinstance(args, (tuple, list)) and len(kwargs) == 0 and all(isinstance(arg, Tensor) for arg in args) ): self._is_simple_tuple_input = True self.__ensure_input_tensors_contiguous(*args, **kwargs) _, eager_outputs = self.build_graph(*args, **kwargs) if isinstance(eager_outputs, (tuple, list)) and all( isinstance(arg, Tensor) for arg in eager_outputs ): self._is_simple_tuple_output = True self.finish_compile_and_init_runtime() return eager_outputs def enable_shared(self, mode: bool = True): if mode: assert ( not self._is_compiled ), " enable_shared must be set before graph compile." # If enable shared, graph compile will generate more data for sharing. self._enable_shared_from_this = True else: self._enable_shared_from_this = False def share_from(self, shared_graph: "Graph") -> None: assert isinstance( shared_graph, Graph ), "shared_graph must be an instance of nn.Graph." 
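# share_from() does not compile anything by itself; it only records the already
# compiled shared_graph so that this graph's next compilation goes through
# _compile_from_shared() and reuses the shared graph's compiled job and state tensors.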
assert ( shared_graph._enable_shared_from_this ), "shared_graph must have been enabled to be shared." assert shared_graph._is_compiled, "shared_graph must have been compiled." self._shared_graph = shared_graph self._enable_shared_from_this = False self._build_with_shared_graph = True def _compile_from_shared(self, *args, **kwargs): self.__print( 0, 0, self._shallow_repr() + " start building a shared graph and plan." ) build_graph_start = time.perf_counter() self.__ensure_input_tensors_contiguous(*args, **kwargs) self.__ensure_state_tensors_contiguous() # Filter to get unique states in graph state_op_names = self._filter_states() # Generate new config. if self._shared_graph._is_from_runtime_state_dict: # To avoid same graph name with the loaded graphs. self._name = ( self._name + "_of_shared_from_loaded_" + self._shared_graph.name ) self._generate_config_proto() # Deal with parameter and buffer self._create_states_builder() # Build current forward graph to generate some new attributes of this graph. with graph_build_util.graph_build_context(self.config.proto, self._session): self._job_id = ( oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobId() ) # Deal with inputs (input_op_names, lazy_args, lazy_kwargs, args_repr, _,) = self.__build_io( "input", graph_build_util.build_graph_input_arg, *args, **kwargs ) # Deal with module in self.build(*args) self._is_user_mode = True outputs = self.build(*lazy_args, **lazy_kwargs) self._is_user_mode = False # Always pack output to remain type of outputs outputs = (outputs,) ( output_op_names, build_eager_outputs, _, # empty kwargs return outs_repr, out2name, ) = self.__build_io("output", graph_build_util.build_graph_output, *outputs) # Save forward graph job proto self._forward_job_proto = c_api_util.GetCurrentJob() # Create op name vectors from shared graph and this graph. assert len(self._forward_job_proto.net.op) == len( self._shared_graph._forward_job_proto.net.op ) # This graph and the shared graph's original graph have same operators and operator order. # We use this to find the corresponding operator in shared graph. shared_op_names_from_ordered_original_graph = [] for op_idx in range(len(self._forward_job_proto.net.op)): shared_op_names_from_ordered_original_graph.append( self._shared_graph._forward_job_proto.net.op[op_idx].name ) # Copy the completed graph from the shared graphwo and reuse it. self._compiled_job_proto = deepcopy(self._shared_graph._compiled_graph_proto) self._compiled_job_proto.job_conf.job_name = self._name # Create a c nn graph to run with lazy runtime. self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph( self._name, self._compiled_job_proto.SerializeToString(), self._job_id, self._session._session_ctx, ) # Build graph with new inputs from a compiled job of a shared graph. inputs_tensor_tuple = convert_to_tensor_tuple( self.__flatten_io("input", *args, **kwargs) ) input_op_names = self._shared_graph._input_op_names self._c_nn_graph.build_with_new_input_from_shared_graph( input_op_names, inputs_tensor_tuple, shared_op_names_from_ordered_original_graph, self._forward_job_proto.SerializeToString(), ) # Get new compiled job proto compiled_job_str = self._c_nn_graph.get_current_job_str() self._compiled_job_proto = job_pb.Job() self._compiled_job_proto.ParseFromString(compiled_job_str) # Build output tensor buffer with new shape from the new compiled job proto. 
self.__rebuild_outputs( self._shared_graph._out2name, self._compiled_job_proto, self._shared_graph._build_eager_outputs, ) # Register output/variable/buffer to _c_nn_graph output_op_names = self._shared_graph._output_op_names self._c_nn_graph.register_output_op_names_and_tensors( output_op_names, self._outputs_tensor_tuple ) self._state_tensor_tuple = self._shared_graph._state_tensor_tuple self._c_nn_graph.register_variable_op_names_and_tensors( self._shared_graph._state_op_names, self._state_tensor_tuple ) self.__prepare_for_share_or_runtime_save( input_op_names, inputs_tensor_tuple, output_op_names, build_eager_outputs, out2name, *args, **kwargs, ) # Init runtime. # TODO(strint): align states needs to care about free eager tensor. self._c_nn_graph.align_states_after_logical_graph_compile() self._c_nn_graph.compile_plan_for_runtime() self._c_nn_graph.init_runtime() self._is_compiled = True build_graph_end = time.perf_counter() self.__print( 0, 0, self._shallow_repr() + " building a shared graph and plan Done! Cost time: " + str(round(build_graph_end - build_graph_start, 2)) + "s." + "\n", ) return (seq_to_func_return(self._eager_outputs_buffer[0], True),) def enable_save_runtime_state_dict(self, mode: bool = True): if mode: assert ( not self._is_compiled ), " enable_save_runtime_state_dict must be set before graph compile." # If enable save runtime states, graph compile will generate more data for save. self._enable_save_runtime_state_dict = True else: self._enable_save_runtime_state_dict = False def runtime_state_dict( self, destination=None, with_eager=False ) -> Union[ Dict[str, Union[Dict[str, Tensor], str]], Dict[str, Dict[str, Union[Dict[str, Tensor], str]]], ]: if self._run_with_cache: return self._dynamic_input_graph_cache.runtime_state_dict( with_eager=with_eager ) assert ( self._enable_save_runtime_state_dict ), "nn.Graph's runtime state dict can only be got when enable_save_runtime_state_dict is set with True." assert ( self._is_compiled ), "nn.Graph's runtime state dict can only be got after the first call of a graph." # Sync to make sure states has been updated. oneflow._oneflow_internal.eager.Sync() if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() destination["oneflow_version"] = oneflow.__version__ destination["graph_name"] = self.name destination["job_id"] = self._job_id def _fill_sub_destination(dest_dict, name_list, tensor_tuple): assert len(tensor_tuple) == len(name_list) for name_idx in range(len(name_list)): tensor_item = tensor_tuple[name_idx] device_str = ":".join( (tensor_item.device.type, str(tensor_item.device.index)) ) dest_dict[name_list[name_idx]] = (tensor_item, device_str) # This is original outputs is needed to build output buffer. 
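# gen_index_in_tuple() below replaces every Tensor in the original (nested)
# inputs/outputs structure with a string placeholder "_OFTPI<i>" recording its
# position in the flat tensor tuple; load_runtime_state_dict() later resolves the
# placeholders back to tensors via get_tensor_in_tuple().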
tuple_idx = -1 def gen_index_in_tuple(item): nonlocal tuple_idx if isinstance(item, Tensor): tuple_idx += 1 return "_OFTPI" + str(tuple_idx) # oneflow tuple index else: return item inputs_sub_destination = OrderedDict() _fill_sub_destination( inputs_sub_destination, self._input_op_names, self._inputs_tensor_tuple ) _eager_inputs_args, _eager_inputs_kwargs = self.__map_io_lite( gen_index_in_tuple, *self.inputs_original[0], **self.inputs_original[1], ) destination["inputs"] = inputs_sub_destination destination["inputs_original"] = (_eager_inputs_args, _eager_inputs_kwargs) tuple_idx = -1 _eager_outputs, _ = self.__map_io_lite(gen_index_in_tuple, *self._eager_outputs) destination["outputs_original"] = _eager_outputs assert len(self._outputs_tensor_tuple) == tuple_idx + 1 outputs_sub_destination = OrderedDict() _fill_sub_destination( outputs_sub_destination, self._output_op_names, self._outputs_tensor_tuple ) destination["outputs"] = outputs_sub_destination destination["oneflow_with_eager_tensor"] = with_eager if not self._build_with_shared_graph: _state_tensor_tuple4save = [] if with_eager: _state_tensor_tuple4save = self._state_tensor_tuple else: assert len(self._state_tensor_tuple) == len(self._state_op_names) for state_idx in range(len(self._state_tensor_tuple)): if self._state_op_names[state_idx] in self._eager_state_op_names: # This state tensor is from eager module. Just save a dummy tensor here. _state_tensor_tuple4save.append( oneflow.Tensor().to( self._state_tensor_tuple[state_idx].device ) ) else: _state_tensor_tuple4save.append( self._state_tensor_tuple[state_idx] ) states_sub_destination = OrderedDict() _fill_sub_destination( states_sub_destination, self._state_op_names, _state_tensor_tuple4save ) destination["states"] = states_sub_destination destination["exe_plan"] = self._c_nn_graph.plan if self._enable_shared_from_this: destination["forward_graph"] = self._forward_job_proto destination["compile_graph"] = self._compiled_job_proto destination["id_state"] = oneflow._oneflow_internal.get_id_state() return destination def load_runtime_state_dict( self, state_dict: Union[ Dict[str, Union[Dict[str, Tensor], str]], Dict[str, Dict[str, Union[Dict[str, Tensor], str]]], ], *, warmup_with_run: bool = True, ) -> None: if self._run_with_cache: return self._dynamic_input_graph_cache.load_runtime_state_dict( state_dict, warmup_with_run=warmup_with_run ) build_graph_start = time.perf_counter() # init id state oneflow._oneflow_internal.set_id_state(state_dict["id_state"]) self._is_from_runtime_state_dict = True self._name = state_dict["graph_name"] if "oneflow_version" not in state_dict: state_dict["oneflow_version"] = "none" if state_dict["oneflow_version"] != oneflow.__version__: warnings.warn( f"nn.Graph {self._name} WARNING: current oneflow version ({oneflow.__version__}) is loading " f"runtime_state_dict from a different version ({state_dict['oneflow_version']}), " "there may has compatibility problems." ) # Generate new config. self._generate_config_proto() self.__print(0, 0, self._shallow_repr() + " start loading a graph and plan.") self._job_id = state_dict["job_id"] # Create a c nn graph to run with lazy runtime. 
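# The CNNGraph below is constructed directly from the serialized execution plan
# stored in the state dict (the trailing True argument means "init from plan"),
# so loading reuses the saved plan instead of compiling a new one.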
self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph( self._name, state_dict["exe_plan"], self._job_id, self._session._session_ctx, True, # Init from plan ) def _load_list_from_state_dict(state_dict): name_list = [] tensor_list = [] for name, item in state_dict.items(): name_list.append(name) tensor_of_item, device_of_item = item tensor_list.append(tensor_of_item.to(device_of_item)) return (name_list, convert_to_tensor_tuple(tensor_list)) self._input_op_names, self._inputs_tensor_tuple = _load_list_from_state_dict( state_dict["inputs"] ) self._output_op_names, self._outputs_tensor_tuple = _load_list_from_state_dict( state_dict["outputs"] ) _eager_inputs_args_index, _eager_inputs_kwargs_index = state_dict[ "inputs_original" ] _eager_outputs_index = state_dict["outputs_original"] def get_tensor_in_tuple(tensor_tuple, map_item): if isinstance(map_item, str) and map_item.startswith("_OFTPI"): of_idx = int(map_item[6:]) return tensor_tuple[of_idx] else: return map_item _eager_inputs_args, _eager_inputs_kwargs = self.__map_io_lite( lambda map_item: get_tensor_in_tuple(self._inputs_tensor_tuple, map_item), *_eager_inputs_args_index, **_eager_inputs_kwargs_index, ) _eager_outputs, _ = self.__map_io_lite( lambda map_item: get_tensor_in_tuple(self._outputs_tensor_tuple, map_item), *_eager_outputs_index, ) self._eager_outputs = _eager_outputs # The base graph need extra info to create new shared graph if self._enable_shared_from_this: self._forward_job_proto = state_dict["forward_graph"] self._compiled_job_proto = state_dict["compile_graph"] self._build_eager_outputs = self._eager_outputs self._out2name = dict() for output_idx in range(len(self._output_op_names)): self._out2name[ self._outputs_tensor_tuple[output_idx] ] = self._output_op_names[output_idx] # Load state tensor of modules if "oneflow_with_eager_tensor" in state_dict: with_eager = state_dict["oneflow_with_eager_tensor"] else: with_eager = True if self._build_with_shared_graph: self._state_op_names = self._shared_graph._state_op_names self._state_tensor_tuple = self._shared_graph._state_tensor_tuple else: self._state_op_names, self._state_tensor_tuple = _load_list_from_state_dict( state_dict["states"] ) if type(self) != Graph: # Graph init with eager module, try to share mem with eager module states_from_eager = dict() for state_block in self._state(): state_tensor = state_block.to(Tensor) state_op_name = ( state_block.to(GraphTensor).name_prefix + state_block.to(GraphTensor).name ) states_from_eager[state_op_name] = state_tensor for s_idx, s_name in enumerate(self._state_op_names): if s_name in states_from_eager: state_tensor_from_eager = states_from_eager[s_name] assert ( state_tensor_from_eager.device == self._state_tensor_tuple[s_idx].device ) if with_eager: assert oneflow.allclose( state_tensor_from_eager, self._state_tensor_tuple[s_idx] ) self._state_tensor_tuple[s_idx] = state_tensor_from_eager if not with_eager: for s_idx, s_name in enumerate(self._state_op_names): if (oneflow.numel(self._state_tensor_tuple[s_idx]) == 0) and ( s_name not in states_from_eager ): warnings.warn( f"Current graph is missing parameter {s_name}, but load_runtime_state_dict needs it. This may cause error later." 
) self.__build_outputs_buffer() self._c_nn_graph.register_input_op_names_and_tensors( self._input_op_names, self._inputs_tensor_tuple ) self._c_nn_graph.register_output_op_names_and_tensors( self._output_op_names, self._outputs_tensor_tuple ) self._c_nn_graph.register_variable_op_names_and_tensors( self._state_op_names, self._state_tensor_tuple ) self._c_nn_graph.align_states_after_logical_graph_compile() self._c_nn_graph.init_runtime() self._is_compiled = True if warmup_with_run: self.__run( *_eager_inputs_args, **_eager_inputs_kwargs ) # pre-run to warm up oneflow._oneflow_internal.eager.Sync() build_graph_end = time.perf_counter() self.__print( 0, 0, self._shallow_repr() + " load a graph and plan Done! Cost time: " + str(round(build_graph_end - build_graph_start, 2)) + "s." + "\n", ) @staticmethod def runtime_state_dict_to( state_dict: Union[ Dict[str, Union[Dict[str, Tensor], str]], Dict[str, Dict[str, Union[Dict[str, Tensor], str]]], ], device: str, ) -> Union[ Dict[str, Union[Dict[str, Tensor], str]], Dict[str, Dict[str, Union[Dict[str, Tensor], str]]], ]: if "job_id" not in state_dict: from oneflow.nn.graph.cache import GraphCache return GraphCache.runtime_state_dict_to(state_dict, device) dest_device = oneflow.device(device) assert dest_device.type == "cuda", "device must be cuda." destination = OrderedDict() destination._metadata = OrderedDict() destination["oneflow_version"] = state_dict["oneflow_version"] destination["graph_name"] = state_dict["graph_name"] destination["job_id"] = state_dict["job_id"] destination["inputs"] = _rsd_sub_destination_to(state_dict["inputs"], device) destination["inputs_original"] = state_dict["inputs_original"] destination["outputs"] = _rsd_sub_destination_to(state_dict["outputs"], device) destination["outputs_original"] = state_dict["outputs_original"] destination["oneflow_with_eager_tensor"] = state_dict[ "oneflow_with_eager_tensor" ] if "states" in state_dict: destination["states"] = _rsd_sub_destination_to( state_dict["states"], device ) destination["exe_plan"] = _plan_to(state_dict["exe_plan"], dest_device) if "forward_graph" in state_dict: forward_graph = deepcopy(state_dict["forward_graph"]) _job_to(forward_graph, dest_device) destination["forward_graph"] = forward_graph if "compile_graph" in state_dict: compile_graph = deepcopy(state_dict["compile_graph"]) _job_to(compile_graph, dest_device) destination["compile_graph"] = compile_graph destination["id_state"] = state_dict["id_state"] return destination def build_graph(self, *args, **kwargs): # Build graph try: self.__print(0, 0, self._shallow_repr() + " start building graph.") assert not self._is_compiled, ( "nn.Graph " + self._name + " has already been compiled." ) build_graph_start = time.perf_counter() with graph_build_util.DebugScopeContext( self._debug_min_s_level, self._debug_max_v_level, self._debug, self._debug_max_py_stack_depth, self._debug_only_user_py_stack, ): outputs = self.__build_graph(*args, **kwargs) build_graph_end = time.perf_counter() self.__print( 0, 0, self._shallow_repr() + " building graph Done! Cost time: " + str(round(build_graph_end - build_graph_start, 2)) + "s." + "\n", ) return outputs except: self.__print( 2, 0, "[ERROR]" + self._shallow_repr() + " building graph got error." 
) raise def finish_compile_and_init_runtime(self): additional_var_names = list() additional_var_tensors = list() for name, tensor in self._additional_variable_tobe_loaded.items(): additional_var_names.append(name) additional_var_tensors.append(tensor) if len(additional_var_names) > 0: self._c_nn_graph.register_additional_variable_names_and_tensors( additional_var_names, convert_to_tensor_tuple(additional_var_tensors) ) # Sync to make sure states has been loaded. oneflow._oneflow_internal.eager.Sync() # Complie graph to execution plan and init Runtime try: self.__print( 0, 0, self._shallow_repr() + " start building plan.", ) compile_and_init_start = time.perf_counter() with graph_build_util.DebugScopeContext( self._debug_min_s_level, self._debug_max_v_level, self._debug, self._debug_max_py_stack_depth, self._debug_only_user_py_stack, ): self._c_nn_graph.align_states_after_logical_graph_compile() self._c_nn_graph.complete_graph_for_runtime() # Get compiled job compiled_job_str = self._c_nn_graph.get_current_job_str() self._compiled_job_proto = job_pb.Job() self._compiled_job_proto.ParseFromString(compiled_job_str) self.__print( 0, 1, lambda: f"{self.name} with operators:\n" + self.__repr__() ) self._c_nn_graph.compile_plan_for_runtime() self._c_nn_graph.init_runtime() compile_and_init_end = time.perf_counter() self.__print( 0, 0, self._shallow_repr() + " building plan Done! Cost time: " + str(round(compile_and_init_end - compile_and_init_start, 2)) + "s." + "\n", ) except Exception as e: print(e, file=sys.stderr) self.__print( 2, 0, "[ERROR]" + self._shallow_repr() + " building plan got error." ) raise self._is_compiled = True # After compile, _additional_variable_tobe_loaded is useless. self._additional_variable_tobe_loaded.clear() def __build_graph(self, *args, **kwargs): self.__ensure_state_tensors_contiguous() # Filter to get unique states in graph state_op_names = self._filter_states() self._generate_config_proto() # Deal with parameter and buffer self.__print( 0, 1, self._shallow_repr() + " start building graph builders of parameters and buffers.", ) self._create_states_builder() self.__print( 0, 1, self._shallow_repr() + " end building graph builders of parameters and buffers.", ) with graph_build_util.graph_build_context(self.config.proto, self._session): # Deal with inputs self.__print(0, 1, self._shallow_repr() + " start building graph inputs.") ( input_op_names, lazy_args, lazy_kwargs, self._args_repr, _, ) = self.__build_io( "input", graph_build_util.build_graph_input_arg, *args, **kwargs ) self.__print(0, 1, self._shallow_repr() + " end building graph inputs.") # Deal with module in self.build(*args) self.__print(0, 1, self._shallow_repr() + " start building graph modules.") self._is_user_mode = True outputs = self.build(*lazy_args, **lazy_kwargs) self._is_user_mode = False self.__print(0, 1, self._shallow_repr() + " end building graph modules.") # Deal with outputs self.__print(0, 1, self._shallow_repr() + " start building graph outputs.") # Always pack output to remain type of outputs outputs = (outputs,) ( output_op_names, build_eager_outputs, _, # empty kwargs return self._outs_repr, out2name, ) = self.__build_io("output", graph_build_util.build_graph_output, *outputs) self.__print(0, 1, self._shallow_repr() + " end building graph outputs.") # Save forward graph job proto self._forward_job_proto = c_api_util.GetCurrentJob() if self.training: self._mark_variable_gradients() self.__print( 0, 1, self._shallow_repr() + " start building graph with compile passes.", ) 
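# NOTE: the self.__print(s_level, v_level, ...) messages above are only emitted when debug
# logging is enabled on this graph. A small sketch (g and x are hypothetical; debug() is
# assumed to take the verbose level as its first argument, as GraphModule.debug does below):
#
#   g = MyGraph()
#   g.debug(1)    # v_level >= 1 also shows the per-phase "start/end building ..." messages
#   y = g(x)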
self.env_enable_mlir_inference_opt = os.getenv( "ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION" ) enable_mlir_inference_opt = ( False if self.env_enable_mlir_inference_opt is None else bool(self.env_enable_mlir_inference_opt) ) modules_has_training = False for item in self._blocks.values(): if item.to(Module).training: modules_has_training = True break if ( modules_has_training or self.training or self._is_global_view ) and enable_mlir_inference_opt: log_for_mlir_inference_opt = lambda extra_info: logging.warning( f"environment variable ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION will be ignored {extra_info}." ) if self.training: log_for_mlir_inference_opt("in training mode") if modules_has_training and not self.training: log_for_mlir_inference_opt( "when not all modules in graph are in eval mode" ) if self._is_global_view: log_for_mlir_inference_opt("in global mode") enable_mlir_inference_opt = False del os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] oneflow._oneflow_internal.FillVariableTensorMgr( state_op_names, self._state_tensor_tuple ) # Optimize the graph with compile passes. oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete() # Save full graph job proto after job Complete for find real output blob shape and build it. self._full_job_proto = c_api_util.GetCurrentJob() self._job_id = ( oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobId() ) self.__print( 0, 1, self._shallow_repr() + " end building graph with compile passes." ) # Re-build outputs accoring to full graph and outputs buffer config. self.__print( 0, 1, self._shallow_repr() + " start re-building graph outputs for optimizatioin.", ) self.__rebuild_outputs(out2name, self._full_job_proto, build_eager_outputs) self.__print( 0, 1, self._shallow_repr() + " end re-building graph outputs for optimizatioin.", ) # Create a c nn graph to run with lazy runtime. self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph( self._name, self._full_job_proto.SerializeToString(), self._job_id, self._session._session_ctx, ) # Register input/output/variable/buffer to _c_nn_graph inputs_tensor_tuple = convert_to_tensor_tuple( self.__flatten_io("input", *args, **kwargs) ) self._c_nn_graph.register_input_op_names_and_tensors( input_op_names, inputs_tensor_tuple ) self._c_nn_graph.register_output_op_names_and_tensors( output_op_names, self._outputs_tensor_tuple ) ( self._state_op_names, state_tensors, ) = oneflow._oneflow_internal.DumpVariableTensorMgr() self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors) self._c_nn_graph.register_variable_op_names_and_tensors( self._state_op_names, self._state_tensor_tuple ) self.__prepare_for_share_or_runtime_save( input_op_names, inputs_tensor_tuple, output_op_names, build_eager_outputs, out2name, *args, **kwargs, ) # Clear useless dict used in graph build. 
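# NOTE: a small sketch of how the MLIR inference optimization above is switched on
# (MyInferGraph, m and x are hypothetical; the env var must be set before the first call,
# and it is ignored for training-mode or global-view graphs as handled above):
#
#   import os
#   os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
#   m.eval()                  # every module in the graph must be in eval mode
#   g = MyInferGraph(m)
#   y = g(x)                  # optimization is applied while this first call compiles the graph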
self._unique_global_op_dict.clear() self._unique_identity_op_dict.clear() # Always pack outputs to remain type of outputs return ( self._full_job_proto, seq_to_func_return(self._eager_outputs_buffer[0], True), ) def __prepare_for_share_or_runtime_save( self, input_op_names, inputs_tensor_tuple, output_op_names, build_eager_outputs, out2name, *args, **kwargs, ): if self._enable_save_runtime_state_dict or self._enable_shared_from_this: self._input_op_names = input_op_names self._output_op_names = output_op_names if self._enable_shared_from_this: self._build_eager_outputs = build_eager_outputs self._out2name = out2name if self._enable_save_runtime_state_dict: self._inputs_tensor_tuple = inputs_tensor_tuple self.inputs_original = (args, kwargs) def __rebuild_outputs( self, out2name=None, compiled_graph_proto=None, build_eager_outputs=None ): # NOTE(chengcheng): # Lazy build output eager tensors. # # After JobBuildAndInferCtxt.Complete, the output tensor shape # could be changed by JobPass, such as GradientAccumulationRewritePass. def build_real_output(fake_eager_out): lbn = out2name[fake_eager_out] + "/out" assert lbn in compiled_graph_proto.helper.lbn2logical_blob_desc blob_conf = compiled_graph_proto.helper.lbn2logical_blob_desc[lbn] shape = tuple(blob_conf.shape.dim) dtype = fake_eager_out.dtype with oneflow._oneflow_internal.lazy_mode.guard(False): if fake_eager_out.is_global: eager_out = oneflow.empty( shape, dtype=dtype, placement=fake_eager_out.placement, sbp=fake_eager_out.sbp, ) else: eager_out = oneflow.empty( shape, dtype=dtype, device=fake_eager_out.device ) return eager_out self._eager_outputs, _ = self.__map_io( "output", build_real_output, *build_eager_outputs ) self.__build_outputs_buffer() def __build_outputs_buffer(self): def convert_to_synced_tensor_tuple(*args): tensor_tuple = convert_to_tensor_tuple(*args) # tensors acting as buffer should be synced once upon created. oneflow._oneflow_internal.nn.graph.SoftSyncNNGraphBuffers( tensor_tuple, self._c_nn_graph ) return tensor_tuple self._outputs_tensor_tuple = convert_to_synced_tensor_tuple( self.__flatten_io("output", *self._eager_outputs) ) self._eager_outputs_buffer = [ self._eager_outputs, ] self._outputs_tensor_tuple_buffer = [ self._outputs_tensor_tuple, ] # Make outputs buffer for i in range(self._outputs_buffer_size - 1): outputs_buffer_item, _ = self.__empty_like_io( "output", *self._eager_outputs ) self._eager_outputs_buffer.append(outputs_buffer_item) outputs_tensor_tuple_buffer_item = convert_to_synced_tensor_tuple( self.__flatten_io("output", *outputs_buffer_item) ) self._outputs_tensor_tuple_buffer.append(outputs_tensor_tuple_buffer_item) self.__check_outputs_buffer() def __check_outputs_buffer(self): has_len = len(self._outputs_tensor_tuple_buffer) assert ( has_len == self._outputs_buffer_size ), f"nn.Graph's outputs buffer size {has_len} donot match the set value {self._outputs_buffer_size}." # Check there is not duplicated outputs buffer tensor. out_id_dic = dict() def check_id_and_add(t, name): if t is not None: tid = id(t) assert ( tid not in out_id_dic ), f"nn.Graph's outputs buffer add buffer tensor tid {tid} has conflict, new item name {name}, old item name {out_id_dic[tid]}." 
out_id_dic[tid] = name for b_idx, buffer in enumerate(self._outputs_tensor_tuple_buffer): for i_idx, item in enumerate(buffer): check_id_and_add( item, "graph_ouputs_buffer_" + str(b_idx) + "_" + str(i_idx) ) def __run(self, *args, **kwargs): try: flattened_eager_args = self.__ensure_input_tensors_contiguous_and_flatten( *args, **kwargs ) if oneflow.support.env_var_util.parse_boolean_from_env( "ONEFLOW_RUN_GRAPH_BY_VM", False ): eager_outputs = oneflow._oneflow_internal.nn.graph.RunLazyNNGraphByVM( convert_to_tensor_tuple(flattened_eager_args), self._c_nn_graph, ) if len(eager_outputs) == 1: return eager_outputs[0] else: return eager_outputs else: outputs_tensor_tuple = self._outputs_tensor_tuple_buffer[ self._cur_index_of_ouputs_buffer ] eager_outputs = self._eager_outputs_buffer[ self._cur_index_of_ouputs_buffer ] # oneflow._oneflow_internal.eager.Sync() NOTE(chengcheng): Need Sync? oneflow._oneflow_internal.nn.graph.RunLazyNNGraph( convert_to_tensor_tuple(flattened_eager_args), outputs_tensor_tuple, self._c_nn_graph, ) # Update outputs buffer reading index self._cur_index_of_ouputs_buffer += 1 if self._cur_index_of_ouputs_buffer >= self._outputs_buffer_size: self._cur_index_of_ouputs_buffer = 0 # Copy outputs from buffer eager_outputs, _ = self.__copy_io("output", *eager_outputs) # Make sure that last used devices of tensors in `outputs_tensor_tuple` are # "critical_section". # NNGraph's execution flow will be broken if `last_used_device` of `outputs_tensor_tuple` # are not "critical_section". oneflow._oneflow_internal.nn.graph.SoftSyncNNGraphBuffers( outputs_tensor_tuple, self._c_nn_graph ) except: self.__print( 2, 0, "[ERROR]" + self._shallow_repr() + " run got error: " + sys_exc_error_msg(), ) raise # Always pack outputs to remain type of outputs return seq_to_func_return(eager_outputs, True) def __build_io(self, io_type, build_func, *args, **kwargs): assert io_type in ("input", "output") op_names = [] args_repr = [] tensor2op_name = {} def build_tensor_or_any(tensor, name, repr_str): if isinstance(tensor, Tensor): build_arg = build_func(name, tensor) op_names.append(name) tensor2op_name[build_arg] = name else: build_arg = tensor args_repr.append(repr_str) self.__print(0, 1, repr_str) return build_arg args_tree = ArgsTree( (args, kwargs), True, "_" + self.name + "_" + io_type, None ) def leaf_arg_fn(arg): name = arg.prefix() + "_" + arg.name() if isinstance(arg.value(), Tensor): arg_repr = self.__io_item_check_and_gen_repr( arg.value(), Tensor, io_type, name ) build_arg = build_tensor_or_any(arg.value(), name, arg_repr) return build_arg else: # Opaque arg_repr = self.__io_item_check_and_gen_repr( arg.value(), None, io_type, name ) build_arg = build_tensor_or_any(arg.value(), name, arg_repr) out = args_tree.map_leaf(leaf_arg_fn) build_args = out[0] build_kwargs = out[1] return op_names, build_args, build_kwargs, args_repr, tensor2op_name def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name): assert io_type in ("input", "output") if expect_type is None: repr_str = ( "[WARNING](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")" ) self.__print(1, 0, repr_str) return repr_str elif expect_type is not None and isinstance(item, expect_type): if isinstance(item, Tensor): repr_str = ( "(" + io_type.upper() + ":" + name + ":" + item._meta_repr() + ")" ) else: repr_str = ( "[WARNING](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")" ) return repr_str else: repr_str = ( "[ERROR](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")" ) 
self.__print(2, 0, repr_str) raise NotImplementedError( "nn.Graph.build()'s input/output item only support types: Tensor/None." ) def __map_io(self, io_type, func, *args, **kwargs): assert io_type in ("input", "output") def mapping_tensor_or_any(tensor): if isinstance(tensor, Tensor): mapped_arg = func(tensor) else: mapped_arg = tensor return mapped_arg def leaf_arg_fn(arg): arg_value = arg.value() return mapping_tensor_or_any(arg_value) # NOTE(lixiang): Reduce the overhead of traversal and parsing of io args. if self._is_simple_tuple_output or self._is_simple_tuple_input: args_tree = ArgsTree(args, False) out = args_tree.map_tuple_leaf(mapping_tensor_or_any) return out, kwargs args_tree = ArgsTree( (args, kwargs), True, "_" + self.name + "_" + io_type, None ) out = args_tree.map_leaf(leaf_arg_fn) mapped_args = out[0] mapped_kwargs = out[1] return mapped_args, mapped_kwargs def __map_io_lite(self, func, *args, **kwargs): args_tree = ArgsTree((args, kwargs), False) out = args_tree.map_leaf(func) mapped_args = out[0] mapped_kwargs = out[1] return mapped_args, mapped_kwargs def __flatten_io(self, io_type, *args, **kwargs): flattened_args = [] args_tree = ArgsTree((args, kwargs), False) for arg in args_tree.iter_nodes(): if isinstance(arg, Tensor): flattened_args.append(arg) else: continue return flattened_args def __io_item_check(self, item, expect_type, io_type, name): if expect_type is None and item is None: return elif expect_type is not None and isinstance(item, expect_type): return else: assert io_type in ("input", "output") repr_str = ( "[ERROR](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")" ) self.__print(2, 0, repr_str) raise NotImplementedError( "nn.Graph.build()'s input/output item only support types: Tensor/None." ) def __empty_like_io(self, io_type, *args, **kwargs): def func(t): shape = t.shape dtype = t.dtype with oneflow._oneflow_internal.lazy_mode.guard(False): if t.is_global: eager_out = oneflow.empty( shape, dtype=dtype, placement=t.placement, sbp=t.sbp, ) else: eager_out = oneflow.empty(shape, dtype=dtype, device=t.device) return eager_out return self.__map_io(io_type, func, *args, **kwargs) def __copy_io(self, io_type, *args, **kwargs): def func(tensor): with oneflow._oneflow_internal.lazy_mode.guard(False): build_arg = tensor.to(copy=True) return build_arg return self.__map_io(io_type, func, *args, **kwargs) def _add_module(self, name: str, module: Module = None) -> None: r"""Adds module to the graph as a block so that the module will be called in nn.Graph.build. Args: name (str): name of the child block. The child block can be accessed from this graph using the given name. module (Module): child module to be added to the graph. Just assign nn.Module in nn.Graph, _add_module will be called to add the module as a ProxyModule: For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> class LinearGraph(flow.nn.Graph): ... def __init__(self): ... super().__init__() ... # add a nn.Module as a block to graph. ... self.linear = flow.nn.Linear(3, 8, False) ... def build(self, x): ... # call the nn.Module block. ... return self.linear(x) The block can be accessed as an attribute using the given name. 
g = LinearGraph() g(flow.Tensor(np.random.randn(8, 3))) print(g.linear) (MODULE:linear:Linear(in_features=3, out_features=8, bias=False)): ( (INPUT:_linear_input.0.0_2:tensor(..., is_lazy='True', size=(8, 3), dtype=oneflow.float32)) (PARAMETER:linear.weight:tensor(..., size=(8, 3), dtype=oneflow.float32, grad_fn=)): () (OUTPUT:_linear_output.0.0_2:tensor(..., is_lazy='True', size=(8, 8), dtype=oneflow.float32, grad_fn=)) (GraphModule:linear()): ( (OPERATOR: linear.weight() -> (out:sbp=(B), size=(8, 3), dtype=(oneflow.float32)), placement=(oneflow.placement(type="cpu", ranks=[0]))) (OPERATOR: linear-matmul-0(_LinearGraph_0_input.0.0_2/out:(sbp=(B), size=(8, 3), dtype=(oneflow.float32)), linear.weight/out:(sbp=(B), size=(8, 3), dtype=(oneflow.float32))) -> (linear-matmul-0/out_0:(sbp=(B), size=(8, 8), dtype=(oneflow.float32))), placement=(oneflow.placement(type="cpu", ranks=[0]))) ) ) """ if "_name" not in self.__dict__: raise AttributeError( "Base class nn.Graph has not been initialized, " "please call super().__init__() in subclass of nn.Graph " "before assigning any attribute." ) if not isinstance(module, Module) and module is not None: raise TypeError("{} is not a Module subclass".format(type(module))) elif not isinstance(name, str): raise TypeError("module name should be a string. Got {}".format(type(name))) elif hasattr(self, name) and name not in self._blocks: raise KeyError("attribute '{}' already exists".format(name)) elif "." in name: raise KeyError('module name can\'t contain ".", got: {}'.format(name)) elif name == "": raise KeyError('module name can\'t be empty string ""') self._blocks[name] = get_proxy_cls(module)( module, "", name, weakref.proxy(self) ) def __setattr__(self, name: str, value=None): if isinstance(value, Module): self._add_module(name, value) elif isinstance(value, Optimizer): raise AttributeError( "'{}' nn.Graph is not allowed to set Optimizer attribute named '{}'. " "Please use add_optimizer(...) instead.".format( type(self).__name__, name ) ) elif isinstance(value, Tensor): raise AttributeError( "'{}' nn.Graph is not allowed to set Tensor attribute named '{}'. " "Please use nn.Module to hold the tensor, then add the nn.Module to nn.Graph.".format( type(self).__name__, name ) ) else: object.__setattr__(self, name, value) def __getattr__(self, name: str): if "_blocks" in self.__dict__: if name in self._blocks: return self._blocks[name] if name in self.__dict__: return self.__dict__[name] raise AttributeError( "'{}' object has no attribute '{}'".format(type(self).__name__, name) ) def __del__(self): # Ensure vm has finished running this graph. if self._session._env.is_shutting_down(): # After python shutting down, it's not safe to call oneflow._oneflow_internal.eager. # But shutting down will do sync in SwitchToShuttingDownPhase. # So it's safe to skip sync here. return oneflow._oneflow_internal.eager.Sync() current_env_enable_mlir_inference_opt = os.getenv( "ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION" ) if (self.env_enable_mlir_inference_opt is not None) and ( current_env_enable_mlir_inference_opt is None ): os.environ[ "ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION" ] = self.env_enable_mlir_inference_opt oneflow._oneflow_internal.ResetVariableTensorMgr() def __ensure_input_tensors_contiguous(self, *args, **kwargs): args_tree = ArgsTree((args, kwargs), False) def func(value): if isinstance(value, Tensor) and not value.is_contiguous(): value.contiguous_() return value # NOTE(lixiang): Reduce the overhead of traversal and parsing of input args. 
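# NOTE: a small sketch of the attribute rules enforced by __setattr__ above
# (MyGraph is hypothetical; add_optimizer is the API referenced by the error message):
#
#   class MyGraph(flow.nn.Graph):
#       def __init__(self):
#           super().__init__()
#           self.linear = flow.nn.Linear(3, 8, False)   # nn.Module: wrapped into a ProxyModule
#           # self.opt = flow.optim.SGD(self.linear.parameters(), lr=0.1)
#           #     -> raises AttributeError; use self.add_optimizer(...) instead
#           # self.w = flow.randn(3, 8)
#           #     -> raises AttributeError; hold tensors inside an nn.Module
#       def build(self, x):
#           return self.linear(x)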
if self._is_simple_tuple_input: args_tree.map_tuple_leaf(func) return args_tree.map_leaf(func) def __ensure_input_tensors_contiguous_and_flatten(self, *args, **kwargs): flattened_args = [] def func(value): if isinstance(value, Tensor) and not value.is_contiguous(): value.contiguous_() return value # NOTE(lixiang): Reduce the overhead of traversal and parsing of input args. if self._is_simple_tuple_input: args_tree = ArgsTree(args, False) # contiguous args_tree.map_tuple_leaf(func) # flatten for arg in args_tree.iter_nodes(): if isinstance(arg, Tensor): flattened_args.append(arg) else: continue return flattened_args args_tree = ArgsTree((args, kwargs), False) # contiguous args_tree.map_leaf(func) # flatten for arg in args_tree.iter_nodes(): if isinstance(arg, Tensor): flattened_args.append(arg) else: continue return flattened_args @staticmethod def with_dynamic_input_shape(*, size: int = 10, enable_shared: bool = True): def deco_with_config(graph_init_func): @wraps(graph_init_func) def deco_func(self, *args, **kwargs): graph_init_func(self, *args, **kwargs) self._run_with_cache = True import oneflow.nn.graph.cache as cache self._dynamic_input_graph_cache = cache.GraphCache( weakref.proxy(self), cache_size=size, enable_graph_shared=enable_shared, ) self._cached_init_args = args self._cached_init_kwargs = kwargs return deco_func return deco_with_config if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/graph/graph_block.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import weakref from collections import OrderedDict from typing import Iterator, Optional, Set, Union, List import oneflow._oneflow_internal from oneflow.env import get_rank from oneflow.framework import graph_build_util from oneflow.nn.graph.util import ( add_indent, operators_repr, GraphIR, ) class GraphBlockType: NONE = "NONE" MODULE = "MODULE" PARAMETER = "PARAMETER" BUFFER = "BUFFER" # Module or Tensor are both treated as Block. 
class GraphBlock(object): def __init__( self, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, belonged_proxy: weakref.ProxyTypes = None, block_graph_type: GraphBlockType = GraphBlockType.NONE, ): self._name = name self._name_prefix = prefix self._type = block_graph_type self._scope = None self._prev_scope = None assert belonged_graph is None or isinstance(belonged_graph, weakref.ProxyTypes) self._belonged_graph = belonged_graph assert belonged_proxy is None or isinstance(belonged_proxy, weakref.ProxyTypes) self._belonged_proxy = belonged_proxy @property def name(self): return self._name @property def name_prefix(self): return self._name_prefix @property def type(self): return self._type @property def prev_scope(self): if self._prev_scope is None: self._prev_scope = oneflow._oneflow_internal.GetCurrentScope() return self._prev_scope @property def scope(self): if self._scope is None: self._scope = graph_build_util.make_new_blockgraph_scope( self.prev_scope, self ) return self._scope def scope_context(self): return graph_build_util.BlockScopeContext(self.prev_scope, self.scope) class GraphModule(GraphBlock): r"""GraphModule is the graph representation of a nn.Module in a nn.Graph. When an nn.Module is added into an nn.Graph, it is wrapped into a ProxyModule. The ProxyModule has a GraphModule inside it. You can get and set the GraphModule to enable graph optimization on the nn.Module. """ def __init__( self, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, belonged_proxy: weakref.ProxyTypes = None, ): super().__init__( prefix, name, belonged_graph, belonged_proxy, GraphBlockType.MODULE ) self._is_null = True self._stage_id = None self._stage_placement = None self._activation_checkpointing = None self._debug = False self._debug_min_s_level = 2 self._debug_max_v_level = 0 self._debug_max_py_stack_depth = 2 self._debug_only_user_py_stack = True self._debug_op_repr_with_py_stack = False self._is_executing_forward = False self._args_repr = [] self._outs_repr = [] def set_stage(self, stage_id: int = None, placement=None): r"""Set stage id and placement of nn.Module in pipeline parallelism. Args: stage_id (int): stage id of this module. placement (flow.placement): the placement of all tensor in this module. Note: There will be automatically do tensor.to_global(placement) for all input tensor of this module. So there is no need to write to_global() in the module forward when using Pipeline Parallelism which is not recommended. For example: .. code-block:: python # module0 and module1 are two nn.Module in a nn.Graph. # When a nn.Module is added into a nn.Graph, it is wrapped into a ProxyModule. # We can set Stage ID and Placement by using ProxyModule.to(GraphModule).set_stage() # The Stage ID is numbered starting from 0 and increasing by 1. # The Placement is all tensors placement of this module. import oneflow as flow from oneflow.nn.graph import GraphModule P_0 = flow.placement(type = "cuda", ranks = [0, 1]) P_1 = flow.placement(type = "cuda", ranks = [2, 3]) self.module0.to(GraphModule).set_stage(stage_id = 0, placement = P0) self.module1.to(GraphModule).set_stage(stage_id = 1, placement = P1) """ self._is_null = False self._stage_id = stage_id self._stage_placement = placement # NOTE(lixiang): For the normal display of docstr, the API Doc of the get and set methods are written together in the stage_id function. @property def stage_id(self): r"""Set/Get stage id of nn.Module/GraphModule in pipeline parallelism. 
When calling stage_id(value: int = None), set different module's stage id to hint the graph preparing right num of buffers in pipeline. (Not Recommended, for easy and efficient pipeline parallelism experience, please use set_stage(stage_id, placement)) """ return self._stage_id @stage_id.setter def stage_id(self, value: int = None): r"""Set stage id of Module in pipeline parallelism. Set different module's stage id to hint the graph preparing right num of buffers in pipeline. """ print( "Warning: `stage_id = i` is deprecated, please use \n", " set_stage(i, placement) for easy and efficient Pipeline parallel experience.", ) self._is_null = False self._stage_id = value @property def stage_placement(self): return self._stage_placement # NOTE(lixiang): For the normal display of docstr, the API Doc of the get and set methods are written together in the activation_checkpointing function. @property def activation_checkpointing(self): r"""Set/Get whether do activation checkpointing in this nn.Module. For example: .. code-block:: python import oneflow as flow from oneflow.nn.graph import GraphModule class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(3, 5, False) self.linear2 = flow.nn.Linear(5, 8, False) self.linear1.to(GraphModule).activation_checkpointing = True self.linear2.to(GraphModule).activation_checkpointing = True def build(self, x): y_pred = self.linear1(x) y_pred = self.linear2(y_pred) return y_pred graph = Graph() """ return self._activation_checkpointing @activation_checkpointing.setter def activation_checkpointing(self, mode: bool = False): r"""Set whether do activation checkpointing in this Module. """ self._is_null = False self._activation_checkpointing = mode def _config_repr(self): main_str = ( "(" + self.__class__.__name__ + "(" + ( ("stage_id=" + str(self.stage_id) + ", ") if self.stage_id is not None else "" ) + ( ( "activation_checkpointing=" + str(self.activation_checkpointing) + ", " ) if self.activation_checkpointing is not None else "" ) + "))" ) return main_str def debug( self, v_level: int = 0, *, ranks: Optional[Union[int, List[int]]] = None, max_py_stack_depth: int = 2, only_user_py_stack=True, op_repr_with_py_stack=False, ) -> None: assert isinstance(v_level, int) assert isinstance(max_py_stack_depth, int) assert isinstance(only_user_py_stack, bool) assert isinstance(op_repr_with_py_stack, bool) if ranks is None: rank_list = [0] elif isinstance(ranks, int): rank_list = [ranks] elif isinstance(ranks, list): rank_list = ranks else: raise ValueError("ranks must be int or List[int].") my_rank = get_rank() if -1 in rank_list or my_rank in rank_list: self._debug = v_level >= 0 if self._debug: self._debug_min_s_level = 0 self._debug_max_v_level = max(0, v_level) self._debug_max_py_stack_depth = max_py_stack_depth self._debug_only_user_py_stack = only_user_py_stack self._debug_op_repr_with_py_stack = op_repr_with_py_stack if self._type == GraphBlockType.MODULE: def _set_child(d): for (_, n) in d.items(): n.to(GraphModule).debug( v_level, ranks=ranks, max_py_stack_depth=max_py_stack_depth, only_user_py_stack=only_user_py_stack, op_repr_with_py_stack=op_repr_with_py_stack, ) assert self._belonged_proxy is not None and isinstance( self._belonged_proxy, weakref.ProxyTypes ) _set_child(self._belonged_proxy._modules) def _ops_repr(self): r"""Generate operators' string representation of this GraphModule """ assert self._belonged_graph, ( "ProxyModule: " + self._name_prefix + self.name + "'s belonged graph is not set." 
) if self._belonged_graph._compiled_graph_proto is not None: module_conf = self._belonged_graph._compiled_graph_proto.module_name2module_conf[ self.name_prefix + self.name ] if self._belonged_graph._oneflow_internal_graph_ir__ is None: self._belonged_graph._oneflow_internal_graph_ir__ = GraphIR( self._belonged_graph._compiled_graph_proto ) return operators_repr( module_conf.ops, self._belonged_graph._oneflow_internal_graph_ir__, self._debug_op_repr_with_py_stack, ) return [] def _shallow_repr(self): main_str = ( "(" + self.__class__.__name__ + ":" + self._name_prefix + self._name + "(" + ( ("stage_id=" + str(self.stage_id) + ", ") if self.stage_id is not None else "" ) + ( ( "activation_checkpointing=" + str(self.activation_checkpointing) + ", " ) if self.activation_checkpointing is not None else "" ) + "))" ) return main_str def _repr_lines(self): child_lines = [] for op_str in self._ops_repr(): child_lines.append(add_indent(op_str, 2)) return child_lines def __repr__(self): lines = None child_lines = self._repr_lines() if len(child_lines) > 0: lines = child_lines main_str = self._shallow_repr() + ": (" if lines is not None: main_str += "\n " + "\n ".join(lines) + "\n" main_str += ")" return main_str class GraphTensor(GraphBlock): r"""GraphTensor is the graph representation of a Tensor in a nn.Graph. """ def __init__( self, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, belonged_proxy: weakref.ProxyTypes = None, tensor_graph_type: GraphBlockType = GraphBlockType.NONE, ): super().__init__( prefix, name, belonged_graph, belonged_proxy, tensor_graph_type ) self._stage_id = None self._stage_placement = None def set_stage(self, stage_id: int = None, placement=None): self._stage_id = stage_id self._stage_placement = placement @property def stage_id(self): return self._stage_id @stage_id.setter def stage_id(self, value: int = None): print( "Warning: `stage_id = i` is deprecated, please use \n", " set_stage(i, placement) for easy and efficient Pipeline parallel experience.", ) self._stage_id = value @property def stage_placement(self): return self._stage_placement ================================================ FILE: python/oneflow/nn/graph/graph_config.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os from collections import OrderedDict import oneflow.boxing.nccl as nccl_config from oneflow.nn.graph.optimizer import OptDict import oneflow.core.job.job_conf_pb2 as job_conf_pb import oneflow as flow class GraphConfig(object): r"""For configuration of nn.Graph. 
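
    A GraphConfig object is available as the ``config`` attribute of an nn.Graph and is
    usually tuned in the graph's ``__init__``. For example (a small sketch combining
    options documented below):

    .. code-block:: python

        import oneflow as flow

        class Graph(flow.nn.Graph):
            def __init__(self):
                super().__init__()
                self.linear = flow.nn.Linear(3, 8, False)
                self.config.enable_amp(True)                    # mixed precision
                self.config.set_gradient_accumulation_steps(4)  # gradient accumulation
            def build(self, x):
                return self.linear(x)

        graph = Graph()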
""" def __init__(self): super().__init__() self._outputs_buffer_size = 2 self.proto = job_conf_pb.JobConfigProto() self._train(False) def _train(self, mode: bool = True): if mode: self.proto.train_conf.SetInParent() else: self.proto.predict_conf.SetInParent() @property def training(self): if self.proto.HasField("train_conf"): return True if self.proto.HasField("predict_conf"): return False raise NotImplementedError def enable_amp(self, mode: bool = True, *, dtype: flow.dtype = flow.float16): r"""If set to true, then graph will use mixed precision mode, it means use both float16 and float32 during model training. For example: .. code-block:: python import oneflow as flow class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = flow.nn.Linear(3, 8, False) self.config.enable_amp(True) # Use mixed precision mode. def build(self, x): return self.linear(x) graph = Graph() Args: mode (bool, optional): The default value is True. """ assert type(mode) is bool assert dtype in (flow.float16, flow.bfloat16) self.proto.enable_auto_mixed_precision = mode self.proto.mixed_precision_data_type = flow._oneflow_internal.deprecated.GetProtoDtype4OfDtype( dtype ) def set_zero_redundancy_optimizer_mode(self, mode: str = "distributed_split"): raise RuntimeError( "`set_zero_redundancy_optimizer_mode` has been changed to `enable_zero`, please use `enable_zero(True)` to activate ZeRO optimization." ) def enable_zero( self, mode: bool = True, *, stage: int = 2, shard_min_size: int = 1024, shard_restore_level: int = 1, ): r"""Enable ZeRO redundancy optimizer. This optimization will reduce optimizer states memory consumption as described by ZeRO https://arxiv.org/abs/1910.02054 . The default zero stage is 2. For example: .. code-block:: python import oneflow as flow class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = flow.nn.Linear(3, 8, False) self.config.enable_zero() def build(self, x): return self.linear(x) graph = Graph() Args: mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices. stage (int): optimization stage, range from 1 to 3. shard_min_size (int): min size (element count) of a shard of an optimizer state. shard_restore_level (int): level to restore sharded parameter to whole parameter for consumer operators, level 0 is no restore, level 1 is soft restore, level 2 is hard restore. Note that this parameter is at pre-alpha stage. """ if not mode: self.proto.optimizer_placement_optimization_mode = "none" return assert stage >= 1 and stage <= 3, "ZeRO stage must range from 1 to 3." assert ( shard_min_size > 0 ), "ZeRO min size of a sharded optimizer state must > 0." assert stage >= 1 and stage <= 3, "ZeRO stage must range from 1 to 3." if stage >= 1: self.proto.optimizer_placement_optimization_mode = "distributed_split" self.proto.optimizer_placement_optimization_threshold = shard_min_size self.proto.optimizer_placement_optimization_shard_restore_level = ( shard_restore_level ) if stage >= 2: nccl_config.enable_use_compute_stream(True) if stage >= 3: nccl_config.disable_group_boxing_by_dst_parallel(True) def allow_fuse_model_update_ops(self, mode: bool = True): r"""If set to true, try to fuse cast + scale + l1_l2_regularize_gradient + model_update to one op to improve performance. For example: .. 
code-block:: python import oneflow as flow class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = flow.nn.Linear(3, 8, False) self.config.allow_fuse_model_update_ops(True) def build(self, x): return self.linear(x) graph = Graph() Args: mode (bool, optional): The default value is True. """ self.proto.enable_fuse_model_update_ops = mode def allow_fuse_add_to_output(self, mode: bool = True): r"""If set to true, try to fuse a binary element-wise add operator to one of the predecessors to improve performance. For example: .. code-block:: python import oneflow as flow class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.bn1 = flow.nn.BatchNorm1d(100) self.config.allow_fuse_add_to_output(True) def build(self, x): bn = self.bn1(x) out = bn + x return out graph = Graph() Args: mode (bool, optional): The default value is True. """ self.proto.enable_fuse_add_to_output = mode def allow_fuse_cast_scale(self, mode: bool = True): r"""If set to true, try to fuse cast and scalar_mul_by_tensor to improve performance. For example: .. code-block:: python import oneflow as flow def model(x): return flow.mul(1,flow.cast(x,flow.int8)) class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.m=model self.config.allow_fuse_cast_scale(True) def build(self, x): return self.m(x) graph = Graph() Args: mode (bool, optional): The default value is True. """ self.proto.enable_fuse_cast_scale = mode def set_gradient_accumulation_steps(self, value): r"""Set num of steps to accumulate gradient. For example: .. code-block:: python import oneflow as flow class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = flow.nn.Linear(3, 8, False) # Let graph do gradient accumulation, such as pipelining parallelism depends on gradient accumulation. self.config.set_gradient_accumulation_steps(4) def build(self, x): return self.linear(x) graph = Graph() Args: value (int): num of steps. """ self.proto.num_gradient_accumulation_steps = value if value > 1: # NOTE(chengcheng): when use gradient accumulation, optimizer nccl allreduce can NOT # overlap with backward, so nccl use compute stream is optimization without negative # effects. nccl_config.enable_use_compute_stream(True) def set_outputs_buffer_size(self, value: int = 2): r"""Set the outputs buffer size of ``nn.Graph``. When graph's outputs buffer size is greater than 2, multiple call on the graph can work like a pipeline. This makes multiple call takes less time. The default outputs buffer size is 2. # TODO (lixiang): Explain the meaning of the size of buffer size and add sample code. # The size of the buffer size indicates the maximum number of iterations that the output of the Graph and the Graph actually executed asynchronously can overlap. # If the buffer size is 1, there is no pipeline. A size of 2 means that it can execute 1 iter ahead of time. A size of 3 means that two iters can be executed ahead of time. Args: value (int): graph outputs buffer size. """ assert isinstance(value, int) assert value >= 1 self._outputs_buffer_size = value def enable_cudnn_conv_heuristic_search_algo(self, mode: bool = True): r""" Whether enable cudnn conv operation to use heuristic search algorithm. Note: It is recommended to use `flow.backends.cudnn.enable_conv_heuristic_search_algo(False)` instead of this function. For example: .. 
code-block:: python import oneflow as flow class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.m = flow.nn.Conv2d(16, 32, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) # Do not enable the cudnn conv operation to use the heuristic search algorithm. self.config.enable_cudnn_conv_heuristic_search_algo(False) def build(self, x): return self.m(x) graph = Graph() Args: mode (bool, optional): The default value is True. """ self.proto.cudnn_conv_heuristic_search_algo = mode def enable_straighten_algorithm(self, mode: str = "MemoryFirst"): r""" Whether enable the straighten algorithm. straighten_algorithm_tag 1: Disable Disable the straighten algorithm in the task graph. Would use the original topography order for executing task nodes. straighten_algorithm_tag 2: SpeedFirst Under the second configuration, the straighten algorithm would try to speed up the training as much as possible. If using nccl compute stream, setting the tag to 2 might not speed up the training. If not using nccl compute stream, setting the tag to 2 might speed up data parallelism by 0.6% and model parallelism by 6%. Considering memory, enabling the straighten algorithm is forbidden with one machine/device only, and not recommended under pipeline parallelism. straighten_algorithm_tag 3: MemoryFirst Under the third configuration, the straighten algorithm would try to compress memory as much as possible. It might save up to 13% of the memory for some models. And might save nothing for some models. straighten_algorithm_tag 4: OverlapCpuGpu Under the forth configuration, the straighten algorithm would try to run the cpu nodes and gpu nodes alternately. Such procedure would reduce the gaps of the execution on gpus. It might speed up the training by 2%. If no cpu nodes exist, the straighten_algorithm_tag would be switch to 3 automatically. straighten_algorithm_tag 5: DelayShortGpu Under the fifth configuration, the straighten algorithm would try to delay the cpu nodes. Such procedure would reduce the gaps of the execution on gpus. It might speed up the validation (or training). If no cpu nodes exist, the straighten_algorithm_tag would be switch to 3 automatically. """ assert ( mode == "Disable" or mode == "SpeedFirst" or mode == "MemoryFirst" or mode == "OverlapCpuGpu" or mode == "DelayShortGpu" ), "please choose one type among {Disable, SpeedFirst, MemoryFirst, OverlapCpuGpu, DelayShortGpu}" if mode == "Disable": self.proto.straighten_algorithm_tag_in_task_graph = 1 elif mode == "SpeedFirst": self.proto.straighten_algorithm_tag_in_task_graph = 2 elif mode == "MemoryFirst": self.proto.straighten_algorithm_tag_in_task_graph = 3 elif mode == "OverlapCpuGpu": self.proto.straighten_algorithm_tag_in_task_graph = 4 else: self.proto.straighten_algorithm_tag_in_task_graph = 5 def enable_compress_memory(self, mode: bool = True): """If true, then the graph will try its best to find the minimum memory allocation strategy. This process might take several minutes for a small graph and half an hour for a large one. The compressed memory would be closed to the lower bound of the peak memory. It benefits a lot if you need to train a lot of batches. Args: mode (bool, optional): [description]. Default is True. """ self.proto.enable_compress_memory = mode def enable_choose_best_memory_allocation(self, mode: bool = True): """If true, then the graph will go through all the memory allocation algorithms. 
Including large memory first algorithm, long lifetime first algorithm, first in first allocates algorithm, large memory volume first algorithm with the compact insertion on and off. The the graph will choose the one with the least memory. If false, the graph will directly choose the large memory first algorithm with compact insertion off. Since the large memory first algorithm is the best one among those algorithms during most of our test cases. And turning compact insertion off will save half of the time of this algorithm. """ if mode: self.proto.memory_allocation_algorithm_conf.use_mem_size_first_algo = True self.proto.memory_allocation_algorithm_conf.use_lifetime_first_algo = True self.proto.memory_allocation_algorithm_conf.use_time_line_algo = True self.proto.memory_allocation_algorithm_conf.use_mem_volume_first_algo = True self.proto.memory_compact_insert_conf.use_compact_insert = True self.proto.memory_compact_insert_conf.use_non_compact_insert = True def enable_auto_parallel(self, mode: bool = True): """If true, then graph will use the auto parallel algorithm to select a parallelism strategy. Args: mode (bool, optional): [description]. Default is True. """ self.proto.enable_auto_parallel = mode def enable_auto_parallel_ignore_user_sbp_config(self, mode: bool = True): """If true, it will ignore all user configurations of SBP. Args: mode (bool, optional): [description]. Default is True. """ self.proto.enable_auto_parallel_ignore_user_sbp_config = mode def set_auto_parallel_computation_cost_ratio(self, ratio): """ Set coefficient of computation cost in auto-parallel algorithm. """ self.proto.auto_parallel_computation_cost_ratio = ratio def set_auto_parallel_wait_time(self, cost): """ Set wait time for auto-parallel algorithm. wait time: An auto-parallel parameter. Describe the mutable extra time it will take when communication between devices occurs. It will be added to the copy cost and may get reduced when cover by computation cost. """ self.proto.auto_parallel_wait_time = cost def enable_auto_parallel_trunk_algo(self, mode: bool = True): """ Find the trunk of the SBP graph, then reduce the wait time for tributaries. """ self.proto.enable_auto_parallel_trunk_algo = mode def enable_auto_parallel_sbp_collector(self, mode: bool = True): """ Use \"sbp collector\" to create \"sbp proxy\" for nodes with multiple downstream operators. """ self.proto.enable_auto_parallel_sbp_collector = mode def enable_auto_memory(self, mode: str = "AdaptiveMemory"): r""" Whether we use a parallelism strategy with less memory Auto memory strategy 1: Disable Disable auto memory in auto parallel. Ignore the memory and try our best to speed up the training. Auto memory strategy 2: SlightMemoryDown Try to decrease the memory while maintaining the throughput. Auto memory strategy 3: ModerateMemoryDown Decrease the memory, throughput might or might not be affected. Similar to data parallelism + ZeRO. Auto memory strategy 4: HeavyMemoryDown Try our best to decrease the memory, ignoring the throughput. Auto memory strategy 5: AdaptiveMemory Use normal auto parallelism without consideration of memory while we have enough memory. Gradually decrease the memory to avoid out of memory while we have inadequate memory. Always try to find the highest throughput under the current limitation of memory. 
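
    For example (a small sketch; auto memory is a companion of the auto parallel options above):

    .. code-block:: python

        import oneflow as flow

        class Graph(flow.nn.Graph):
            def __init__(self):
                super().__init__()
                self.linear = flow.nn.Linear(3, 8, False)
                self.config.enable_auto_parallel(True)
                self.config.enable_auto_memory("AdaptiveMemory")
            def build(self, x):
                return self.linear(x)

        graph = Graph()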
""" assert ( mode == "Disable" or mode == "SlightMemoryDown" or mode == "ModerateMemoryDown" or mode == "HeavyMemoryDown" or mode == "AdaptiveMemory" ) if mode == "Disable": self.proto.enable_auto_memory = 1 elif mode == "SlightMemoryDown": self.proto.enable_auto_memory = 2 elif mode == "ModerateMemoryDown": self.proto.enable_auto_memory = 3 elif mode == "HeavyMemoryDown": self.proto.enable_auto_memory = 4 else: self.proto.enable_auto_memory = 5 def enable_multi_tensor_update(self, mode: bool = True): """ Enable Multi Tensor Update Pass, it will merge small optimizer kernels to reduce kernel launch overhead. """ self.proto.enable_multi_tensor_update = mode def enable_fused_model_update_cast(self, mode: bool = True): """ This option only works in AMP Mode, it will fuse optimizer update and model weights cast to half precision operation. """ self.proto.enable_fused_model_update_cast = mode def _generate_optimizer_and_variable_configs( self, opt_dict: OptDict = None, variables_conf: OrderedDict = None, ): opt_dict.generate_optimizer_and_variable_configs(self.proto, variables_conf) def __repr__(self): main_str = ( "(" + "CONFIG" + ":config:" + self.__class__.__name__ + "(" + ("training=" + str(self.training) + ", ") + "))" ) return main_str ================================================ FILE: python/oneflow/nn/graph/optimizer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.optim.optimizer import Optimizer from oneflow.nn.optimizer.lr_scheduler import LRScheduler class OptDict(object): def __init__(self, opt_dict): if not isinstance(opt_dict, dict): raise ValueError("opt_dict is not a dict") if "optim" in opt_dict: if isinstance(opt_dict["optim"], Optimizer): self._optimizer = opt_dict["optim"] else: raise ValueError('opt_dict["optim"] is not an instance of Optimizer.') else: raise ValueError("Key 'optim' doesn't exist in opt_dict.") if "is_sparse" in opt_dict and opt_dict["is_sparse"] is True: self._is_sparse = True else: self._is_sparse = False self._lr_scheduler = None if "lr_sch" in opt_dict: if not isinstance(opt_dict["lr_sch"], LRScheduler): raise ValueError( 'opt_dict["lr_sch"] is not an instance of LRScheduler.' 
) if opt_dict["lr_sch"].optimizer is not self._optimizer: raise ValueError("lr_scheduler doesn't match optimizer.") self._lr_scheduler = opt_dict["lr_sch"] def generate_optimizer_and_variable_configs(self, job_conf, vars_conf): train_conf = job_conf.train_conf if self._optimizer is None: return # Check first self._optimizer._check_variables_in_graph(vars_conf) self._optimizer._check_variables_optimizer_bound(vars_conf) opt_confs = self._optimizer._generate_conf_for_graph(train_conf, vars_conf) if self._is_sparse: self._optimizer._generate_indexed_slices_optimizer_conf(job_conf, vars_conf) if self._lr_scheduler is None: return for opt_conf in opt_confs: self._lr_scheduler._generate_conf_for_graph(opt_conf.learning_rate_decay) class VariableConfig(object): def __init__(self, name: str): assert name != "" self._name = name self._l2 = 0.0 self._bound_opt = None @property def name(self): return self._name @property def l2(self): return self._l2 @l2.setter def l2(self, l2: float = 0.0): self._l2 = l2 @property def bound_optimizer(self): return self._bound_opt @bound_optimizer.setter def bound_optimizer(self, opt): self._bound_opt = opt def __repr__(self): return "(variable name: " + self._name + "):(l2: " + str(self._l2) + ".)" ================================================ FILE: python/oneflow/nn/graph/proxy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Iterator, Optional, Set, Union, List import weakref import types import oneflow._C import oneflow._oneflow_internal from oneflow.framework import graph_build_util from oneflow.framework.tensor import Tensor, TensorTuple from oneflow.nn.modules.module import Module from oneflow.nn.modules.container import * from oneflow.nn.utils.container import * from oneflow.nn.parameter import Parameter from oneflow.nn.graph.graph_block import ( GraphBlockType, GraphBlock, GraphModule, GraphTensor, ) from oneflow.nn.graph.util import ( add_indent, seq_to_func_return, ) from oneflow.framework.args_tree import ArgsTree def get_proxy_cls(item): if isinstance(item, Sequential): return ProxySequential elif isinstance(item, ModuleList): return ProxyModuleList elif isinstance(item, ModuleDict): return ProxyModuleDict elif isinstance(item, ParameterList): return ProxyParameterList elif isinstance(item, ParameterDict): return ProxyParameterDict elif isinstance(item, Module): return ProxyModule elif isinstance(item, Tensor): return ProxyTensor else: raise NotImplementedError() class Proxy(object): def __init__(self): """ An ecution proxy of nn.Module or Tensor. A proxy contains the original data(nn.Module or Tensor) and a graph representation of the original data. 
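
        For example (a sketch; proxy stands for a ProxyModule or ProxyTensor created by nn.Graph):

            proxy.to(Module)       # ProxyModule: returns the original nn.Module
            proxy.to(Tensor)       # ProxyTensor: returns the original Tensor
            proxy.to(GraphModule)  # ProxyModule: returns its GraphModule for graph-level config
            proxy.to(GraphTensor)  # ProxyTensor: returns its GraphTensor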
""" # The original data self._oneflow_internal_origin__ = None # The graph representation of the original data self._oneflow_internal_graphblock__ = None def to(self, *args, **kwargs): """ """ if len(args) == 1 and issubclass(args[0], GraphBlock): return self._oneflow_internal_graphblock__ elif len(args) == 1 and (args[0] is Module or args[0] is Tensor): return self._oneflow_internal_origin__ else: self._oneflow_internal_origin__.to(*args, **kwargs) class ProxyModule(Proxy): def __init__( self, origin: Module = None, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, ): assert not isinstance(origin, Proxy) super().__init__() self._oneflow_internal_graphblock__ = GraphModule( prefix, name, belonged_graph, weakref.proxy(self) ) self._modules = OrderedDict() self._parameters = OrderedDict() self._buffers = OrderedDict() self._oneflow_internal_graphblock__set_origin(origin) def _oneflow_internal_graphblock__set_origin(self, origin): self._oneflow_internal_origin__ = origin if origin is None: return assert isinstance(origin, Module) for (n, m) in origin._modules.items(): self.__setattr__( n, get_proxy_cls(m)( m, self.to(GraphModule)._name_prefix + self.to(GraphModule)._name + ".", n, self.to(GraphModule)._belonged_graph, ), ) for (n, p) in list(origin.named_parameters("", False)): self.__setattr__( n, get_proxy_cls(p)( p, self.to(GraphTensor)._name_prefix + self.to(GraphTensor)._name + ".", n, ), ) for (n, b) in list(origin.named_buffers("", False)): self.__setattr__( n, get_proxy_cls(b)( b, self.to(GraphTensor)._name_prefix + self.to(GraphTensor)._name + ".", n, ), ) def __call__(self, *args, **kwargs): assert self.to(GraphModule)._type == GraphBlockType.MODULE self.__print(0, 1, self._shallow_repr()) args_tree = ArgsTree( (args, kwargs), True, "_" + self.to(GraphModule).name_prefix + self.to(GraphModule).name + "_input", None, ) for (name, arg) in args_tree.iter_named_nodes(): if arg.is_leaf(): arg_value = arg.value() meta_repr_str = ( arg_value._meta_repr() if isinstance(arg_value, Tensor) else str(type(arg_value)) ) in_str = "(INPUT:" + name + ":" + meta_repr_str + ")" if not isinstance(arg_value, Tensor): in_str = "[WARNING]" + in_str self.to(GraphModule)._args_repr.append(in_str) self.__print(0, 1, in_str) def _print_state(d): for (_, n) in d.items(): self.__print(0, 1, n._shallow_repr()) _print_state(self._parameters) _print_state(self._buffers) # NOTE: The original nn.Module's __call__ method is ignored, which means # that hooks of nn.Modules are ignored. It is not recommended # to use hooks of nn.Module in nn.Graph for the moment. 
with graph_build_util.DebugScopeContext( self.to(GraphModule)._debug_min_s_level, self.to(GraphModule)._debug_max_v_level, self.to(GraphModule)._debug, self.to(GraphModule)._debug_max_py_stack_depth, self.to(GraphModule)._debug_only_user_py_stack, ): result = self.__block_forward(*args, **kwargs) outputs = () if not (type(result) is tuple or type(result) is list): outputs = (result,) else: outputs = result args_tree = ArgsTree( (outputs, {}), True, "_" + self.to(GraphModule).name_prefix + self.to(GraphModule).name + "_output", None, ) for (name, arg) in args_tree.iter_named_nodes(): if arg.is_leaf(): arg_value = arg.value() meta_repr_str = ( arg_value._meta_repr() if isinstance(arg_value, Tensor) else str(type(arg_value)) ) out_str = "(OUTPUT:" + name + ":" + meta_repr_str + ")" if not isinstance(arg_value, Tensor): out_str = "[WARNING]" + out_str self.to(GraphModule)._outs_repr.append(out_str) self.__print(0, 1, out_str) return result @property def __class__(self): if self.to(GraphModule)._belonged_graph._is_user_mode == True: return self.to(Module).__class__ else: return type(self) def __block_forward(self, *args, **kwargs): self.to(GraphModule)._is_executing_forward = True args, kwargs = self.__pre_forward_map(*args, **kwargs) with self.to(GraphModule).scope_context(): # "Instance method __func__ is the function object", "when an instance method object is called, # the underlying function __func__ is called, inserting the class instance __self__ in front of # the argument list." # Reference: https://docs.python.org/3/reference/datamodel.html unbound_forward_of_module_instance = self.to(Module).forward.__func__ result = unbound_forward_of_module_instance(self, *args, **kwargs) self.to(GraphModule)._is_executing_forward = False return result def __pre_forward_map(self, *args, **kwargs): # Insert identity op when doing activation checkpointing or pipeline execution. # Identity op outside activation checkpointing scope will be the endpoint of an activation checkpointing segment. # Identity op as the first op of a pipeline stage will make backward op depends on the identity op within the stage, # otherwise the backward op may depends the op in former stage which will make graph creates unnessary buffers. 
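        # Illustrative sketch (hypothetical tensor `t` and placement, added for clarity)
        # of the per-input mapping applied below: with a stage placement set, each input
        # tensor is first rebound to that placement, and with activation checkpointing or
        # a pipeline stage id set, an identity op is inserted to mark the
        # checkpointing/stage boundary:
        #
        #     t = t.to_global(placement=stage_placement)    # via __get_or_create_global
        #     t = oneflow._C.identity(t)                    # via __get_or_create_identity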
if self.to(GraphModule)._stage_placement is not None: def insert_to_global(t): assert isinstance(t, Tensor) return self.__get_or_create_global( t, self.to(GraphModule)._stage_placement ) args, kwargs = self.__map_io( "input", insert_to_global, "insert_to_global", *args, **kwargs ) if self.to(GraphModule).activation_checkpointing or ( self.to(GraphModule).stage_id is not None and self.to(GraphModule).stage_id >= 0 ): def insert_identity(t): assert isinstance(t, Tensor) return self.__get_or_create_identity(t) args, kwargs = self.__map_io( "input", insert_identity, "insert_identity", *args, **kwargs ) return args, kwargs def __get_or_create_global(self, input_tensor: Tensor = None, placement=None): assert input_tensor is not None assert placement is not None key = str(id(input_tensor)) + str(placement) # input_tensor + placement -> unique_global_tensor if key not in self.to(GraphModule)._belonged_graph._unique_global_op_dict: # store input tensor to avoid tensor id recycle self.to(GraphModule)._belonged_graph._unique_global_op_dict[key] = ( input_tensor.to_global(placement=placement), input_tensor, ) return self.to(GraphModule)._belonged_graph._unique_global_op_dict[key][0] def __get_or_create_identity(self, input_tensor: Tensor = None): assert input_tensor is not None key = input_tensor # input_tensor(with placement) -> unique_identity_tensor # When placement is different, the input tensor(output tensor of __get_or_create_global) is different, so the # key can use only input tensor. if key not in self.to(GraphModule)._belonged_graph._unique_identity_op_dict: # Reuse current module name for indentity op ident_name_scope = graph_build_util.make_new_name_scope( self.to(GraphModule).prev_scope, self.to(GraphModule).name_prefix + self.to(GraphModule).name, ) with graph_build_util.BlockScopeContext( self.to(GraphModule).prev_scope, ident_name_scope ): # store input tensor to avoid tensor id recycle self.to(GraphModule)._belonged_graph._unique_identity_op_dict[ key ] = oneflow._C.identity(input_tensor) return self.to(GraphModule)._belonged_graph._unique_identity_op_dict[key] def add_module(self, name: str, module: Optional[Module]) -> None: if isinstance(module, Module): self.__setattr__( name, get_block_cls(module)( module, self.to(GraphModule)._name_prefix + self.to(GraphModule)._name + ".", name, self.to(GraphModule)._belonged_graph, ), ) elif isinstance(module, Proxy): self.__setattr__(name, module) def register_parameter(self, name: str, param: Optional[Parameter]) -> None: self.__setattr__( name, get_proxy_cls(param)( param, self.to(GraphModule)._name_prefix + self.to(GraphModule)._name + ".", name, ), ) def modules(self, memo: Optional[Set["Proxy"]] = None) -> Iterator["Proxy"]: assert self.to(GraphModule)._type == GraphBlockType.MODULE if memo is None: memo = set() if self not in memo: memo.add(self) yield self for (name, module) in self._modules.items(): if module is None: continue for m in module.modules(memo): yield m def __map_io(self, io_type, func, func_desc, *args, **kwargs): assert isinstance(func_desc, str) assert io_type in ("input", "output") mapped_args = [] def map_tensor(item): assert isinstance(item, Tensor) return func(item) args_tree = ArgsTree( (args, kwargs), True, "_" + self.to(GraphModule).name_prefix + self.to(GraphModule).name + "_" + io_type, None, ) def leaf_node_fn(leaf_node): arg = leaf_node.value() name = leaf_node.prefix() + "_" + leaf_node.name() is_tensor, repr_str = self.__io_tensor_check_and_gen(arg, io_type, name) if is_tensor: self.__print( 0, 1, 
f"{repr_str} is a Tensor, {func_desc} transformation has been done.", ) return map_tensor(arg) else: self.__print( 0, 0, f"{repr_str} is not a Tensor, {func_desc} transformation will be ignored.", ) return arg out = args_tree.map_leaf(leaf_node_fn) mapped_args = out[0] mapped_kwargs = out[1] return mapped_args, mapped_kwargs def __io_tensor_check_and_gen(self, item, io_type, name): assert io_type in ("input", "output") if isinstance(item, Tensor): repr_str = ( "(" + io_type.upper() + ":" + name + ":" + item._meta_repr() + ")" ) return True, repr_str else: repr_str = ( "[WARNING](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")" ) return False, repr_str def __members(self, get_members_fn, recurse=True) -> Iterator["Proxy"]: assert self.to(GraphModule)._type == GraphBlockType.MODULE memo = set() modules = self.modules() if recurse else [self] for module in modules: members = get_members_fn(module) for (k, v) in members: if v is None or v in memo: continue memo.add(v) yield v def parameters(self, recurse: bool = True) -> Iterator["Proxy"]: assert self.to(GraphModule)._type == GraphBlockType.MODULE gen = self.__members(lambda module: module._parameters.items(), recurse=recurse) for elem in gen: yield elem def buffers(self, recurse: bool = True) -> Iterator["Proxy"]: assert self.to(GraphModule)._type == GraphBlockType.MODULE gen = self.__members(lambda module: module._buffers.items(), recurse=recurse) for elem in gen: yield elem def __setattr__(self, name: str, value=None) -> None: if value is None or not isinstance(value, Proxy): self.__dict__[name] = value else: dicts_or_sets = ( self.__dict__, self._modules, self._parameters, self._buffers, ) for d in dicts_or_sets: if name in d: raise AttributeError( "'{}' object has duplicated attribute named '{}'".format( self.to(GraphModule)._name, name ) ) if value.to(GraphModule).type == GraphBlockType.MODULE: self._modules[name] = value elif value.to(GraphTensor).type == GraphBlockType.PARAMETER: self._parameters[name] = value elif value.to(GraphTensor).type == GraphBlockType.BUFFER: self._buffers[name] = value else: raise AttributeError( "'{}' object are not allowed to set attribute named '{}'".format( type(self).__name__, name ) ) def __getattr__(self, name: str): if name in self.__dict__: return self.__dict__[name] # support get module if "_modules" in self.__dict__: modules = self.__dict__["_modules"] if name in modules: return modules[name] # support get parameter p_state = self._get_from_states(name, "_parameters") if p_state is not None: return p_state # support get buffer b_state = self._get_from_states(name, "_buffers") if b_state is not None: return b_state # support none parameter or buffer if name in self.to(Module)._parameters: p_none = self.to(Module)._parameters[name] assert p_none is None return None if name in self.to(Module)._buffers: b_none = self.to(Module)._buffers[name] assert b_none is None return None if hasattr(self.to(Module), name): # support getting normal attr from the nn.Module attr = getattr(self.to(Module), name) if isinstance(attr, types.MethodType): # If the attr is MethodType, rebind the method to self attr = types.MethodType(attr.__func__, self) return attr raise AttributeError( "'{}' '{}' object '{}' in nn.Graph has no attribute '{}'".format( self.to(GraphModule)._type, type(self).__name__, self.to(GraphModule)._name_prefix + self.to(GraphModule).name, name, ) ) def _get_from_states(self, name, states_name): if states_name not in self.__dict__: return None _states = self.__dict__[states_name] if name 
not in _states: return None _s_block = _states[name] if graph_build_util.lazy_mode.is_enabled(): _s_block.try_build() return _s_block.lazy_origin elif (not graph_build_util.lazy_mode.is_enabled()) and self.to( GraphModule )._is_executing_forward: # eager and inside nn.Graph.build() return _s_block.to(Tensor) else: # outside nn.Graph.build() # eager and inside nn.Graph.build() return _s_block def __repr__(self): lines = None child_lines = [] if len(self.to(GraphModule)._args_repr) > 0: for in_str in self.to(GraphModule)._args_repr: input_str = add_indent(in_str, 2) child_lines.append(input_str) def _append_child(d): for (_, n) in d.items(): n_str = repr(n) n_str = add_indent(n_str, 2) child_lines.append(n_str) _append_child(self._parameters) _append_child(self._buffers) _append_child(self._modules) if len(self.to(GraphModule)._outs_repr) > 0: for out_str in self.to(GraphModule)._outs_repr: output_str = add_indent(out_str, 2) child_lines.append(output_str) child_lines.append(add_indent(repr(self.to(GraphModule)), 2)) if len(child_lines) > 0: lines = child_lines main_str = self._shallow_repr() + ": (" if lines is not None: main_str += "\n " + "\n ".join(lines) + "\n" main_str += ")" return main_str def _shallow_repr(self): shallow_repr = ( "(" + self.to(GraphModule)._type + ":" + self.to(GraphModule)._name_prefix + self.to(GraphModule)._name + ":" + self._oneflow_internal_origin__._shallow_repr() + ")" ) return shallow_repr def __print(self, s_level=2, v_level=0, msg: str = ""): r"""Do print according to info level. """ assert isinstance(s_level, int) assert isinstance(v_level, int) assert isinstance(msg, str) if s_level >= self.to(GraphModule)._debug_min_s_level: if (s_level > 0) or ( s_level == 0 and v_level <= self.to(GraphModule)._debug_max_v_level ): print(msg, flush=True) class LazyBuilder(object): def __init__(self, name: str = None, method=None): self.name = name self.method = method self.result = None self.finished = False def try_build(self, block=None): if not self.finished: assert self.name is not None assert self.method is not None assert self.result is None with block.to(GraphTensor).scope_context(): self.result = self.method() self.finished = True class ProxyTensor(Proxy): def __init__( self, origin: Union[Parameter, Tensor] = None, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, ): assert not isinstance(origin, Proxy) if isinstance(origin, Parameter): self._oneflow_internal_graphblock__ = GraphTensor( prefix, name, belonged_graph, weakref.proxy(self), GraphBlockType.PARAMETER, ) elif isinstance(origin, Tensor): self._oneflow_internal_graphblock__ = GraphTensor( prefix, name, belonged_graph, weakref.proxy(self), GraphBlockType.BUFFER ) else: raise NotImplementedError() self._lazy_origin_builder = LazyBuilder() self.build_finished = False self._oneflow_internal_graphblock__set_origin(origin) def _oneflow_internal_graphblock__set_origin(self, origin): self._oneflow_internal_origin__ = origin @property def lazy_origin(self): assert ( self.to(GraphTensor)._type == GraphBlockType.PARAMETER or self.to(GraphTensor)._type == GraphBlockType.BUFFER ), "Only Parameter or Buffer Proxy has lazy_origin" return self._lazy_origin_builder.result def lazy_origin_builder(self): assert ( self.to(GraphTensor)._type == GraphBlockType.PARAMETER or self.to(GraphTensor)._type == GraphBlockType.BUFFER ), "Only Parameter or Buffer Proxy has lazy_origin_builder" return self._lazy_origin_builder def set_lazy_origin_builder(self, builder=None): assert ( 
self.to(GraphTensor)._type == GraphBlockType.PARAMETER or self.to(GraphTensor)._type == GraphBlockType.BUFFER ), "Only Parameter or Buffer Proxy has lazy_origin_builder" self._lazy_origin_builder = builder def try_build(self): if not self.build_finished: self._lazy_origin_builder.try_build(self) self.build_finished = True def __repr__(self): lines = None main_str = self._shallow_repr() + ": (" if lines is not None: main_str += "\n " + "\n ".join(lines) + "\n" main_str += ")" return main_str def _shallow_repr(self): shallow_repr = ( "(" + self.to(GraphTensor)._type + ":" + self.to(GraphTensor)._name_prefix + self.to(GraphTensor)._name + ":" + self._oneflow_internal_origin__._meta_repr() + ")" ) return shallow_repr class ProxySequential(get_seq(ProxyModule)): def __init__( self, origin: Sequential = None, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, ): super().__init__() self.to(GraphModule)._name_prefix = prefix self.to(GraphModule)._name = name self.to(GraphModule)._belonged_graph = belonged_graph self.to(GraphModule)._belonged_block = weakref.proxy(self) self._oneflow_internal_graphblock__set_origin(origin) class ProxyModuleList(get_list(ProxyModule)): def __init__( self, origin: ModuleList = None, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, ): if isinstance(origin, ModuleList): super().__init__() self.to(GraphModule)._name_prefix = prefix self.to(GraphModule)._name = name self.to(GraphModule)._belonged_graph = belonged_graph self._oneflow_internal_graphblock__set_origin(origin) # ModuleList is a container without forward() method, elif isinstance(origin, list): super().__init__(origin) first = origin[0] new_name = "_idx" new_list = [] for item in origin: new_name += "-" + item.to(GraphModule).name new_list.append(item.to(Module)) new_module_list = ModuleList(new_list) self.to(GraphModule)._name_prefix = ( first.to(GraphModule).name_prefix + first.to(GraphModule).name ) self.to(GraphModule)._name = new_name self.to(GraphModule)._belonged_graph = first.to(GraphModule)._belonged_graph self._oneflow_internal_origin__ = new_module_list class ProxyModuleDict(get_dict(ProxyModule)): def __init__( self, origin: ModuleDict = None, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, ): super().__init__() self.to(GraphModule)._name_prefix = prefix self.to(GraphModule)._name = name self.to(GraphModule)._belonged_graph = belonged_graph self.to(GraphModule)._belonged_block = weakref.proxy(self) self._oneflow_internal_graphblock__set_origin(origin) class ProxyParameterList(get_para_list(ProxyModule)): def __init__( self, origin: ParameterList = None, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, ): super().__init__() self.to(GraphModule)._name_prefix = prefix self.to(GraphModule)._name = name self.to(GraphModule)._belonged_graph = belonged_graph self.to(GraphModule)._belonged_block = weakref.proxy(self) self._oneflow_internal_graphblock__set_origin(origin) self.to(GraphModule)._is_executing_forward = True def __getitem__(self, idx): assert isinstance(idx, int) idx = self._get_abs_string_index(idx) key = str(idx) p_state = self._get_from_states(key, "_parameters") if p_state is not None: return p_state else: raise AttributeError("ParameterList dosen't contain ", key) class ProxyParameterDict(get_para_dict(ProxyModule)): def __init__( self, origin: ParameterDict = None, prefix: str = "", name: str = "", belonged_graph: weakref.ProxyTypes = None, ): super().__init__() 
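        # Note (added for clarity; the attribute name "params" below is hypothetical):
        # like ProxyParameterList above, this container sets _is_executing_forward = True
        # so that item access resolves through _get_from_states(), e.g.
        #
        #     w = self.params["weight"]    # ProxyParameterDict.__getitem__
        #     # lazy (graph build) mode -> the lazily built tensor (lazy_origin)
        #     # eager, inside build()   -> the eager Tensor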
self.to(GraphModule)._name_prefix = prefix self.to(GraphModule)._name = name self.to(GraphModule)._belonged_graph = belonged_graph self.to(GraphModule)._belonged_block = weakref.proxy(self) self._oneflow_internal_graphblock__set_origin(origin) self.to(GraphModule)._is_executing_forward = True def __getitem__(self, key: str): p_state = self._get_from_states(key, "_parameters") if p_state is not None: return p_state else: raise AttributeError("ParameterDict dosen't contain key ", key) ================================================ FILE: python/oneflow/nn/graph/util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import sys from string import Template from typing import Callable, Dict, Union, List, Tuple, Optional from collections import OrderedDict from google.protobuf import text_format from google.protobuf.message import Message import oneflow import oneflow.core.job.job_pb2 as job_pb import oneflow.core.job.plan_pb2 as plan_pb import oneflow.core.common.device_type_pb2 as device_type import oneflow.core.operator.op_conf_pb2 as op_conf_util from oneflow.framework.tensor import Tensor def _nd_sbp2repr(nd_sbp): dim_len = len(nd_sbp.sbp_parallel) nd_sbp_str = "sbp=(" for i in range(dim_len): if i > 0: nd_sbp_str += ", " sbp = nd_sbp.sbp_parallel[i] if sbp.HasField("broadcast_parallel"): nd_sbp_str += "B" elif sbp.HasField("partial_sum_parallel"): nd_sbp_str += "P" elif sbp.HasField("split_parallel"): nd_sbp_str += "S(" + str(sbp.split_parallel.axis) + ")" nd_sbp_str += ")" return nd_sbp_str def _blob_desc_repr(blob_desc): desc_str = "size=(" for i in range(len(blob_desc.shape.dim)): if i > 0: desc_str += ", " desc_str += str(blob_desc.shape.dim[i]) desc_str += "), " desc_str += "dtype=(" desc_str += str(oneflow.dtype.get(int(blob_desc.data_type))) desc_str += ")" return desc_str def _get_args_repr(ordered_bn, bn2lbn, bn2nd_sbp, lbn2blob_desc): arg_repr_list = [] for bn in ordered_bn: lbns = list(bn2lbn[bn].s) # sbp repr sub_bns_sbp = [] for bn_idx in range(len(lbns)): sub_bn = bn + "_" + str(bn_idx) nd_sbp = bn2nd_sbp[sub_bn] sub_bns_sbp.append(_nd_sbp2repr(nd_sbp)) # TODO: placement repr # shape repr and dtype sub_bns_desc = [] for bn_idx in range(len(lbns)): sub_bns_desc.append(_blob_desc_repr(lbn2blob_desc[lbns[bn_idx]])) # sub arg repr sub_arg_repr_list = [] for bn_idx in range(len(lbns)): sub_arg_repr_list.append( lbns[bn_idx] + ":(" + sub_bns_sbp[bn_idx] + ", " + sub_bns_desc[bn_idx] + ")" ) if len(lbns) > 1: # arg of multiple tensors arg_repr_list.append("[" + (", ").join(sub_arg_repr_list) + "]") else: assert len(lbns) == 1 arg_repr_list.append(sub_arg_repr_list[0]) return arg_repr_list def _get_user_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc): user_op_conf = op_conf.user_conf input_sig_str = ", ".join( _get_args_repr( user_op_conf.input_order, user_op_conf.input, bn2nd_sbp, lbn2blob_desc ) ) output_sig_str = ", ".join( _get_args_repr( user_op_conf.output_order, user_op_conf.output, bn2nd_sbp, 
lbn2blob_desc ) ) return input_sig_str, output_sig_str def _get_var_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc): input_sig_str = "" var_op_conf = op_conf.variable_conf output_lbn = op_conf.name + "/" + var_op_conf.out output_sig_str = var_op_conf.out nd_sbp = bn2nd_sbp[var_op_conf.out] output_sig_str += ( ":" + _nd_sbp2repr(nd_sbp) + ", " + _blob_desc_repr(lbn2blob_desc[output_lbn]) ) return input_sig_str, output_sig_str def _get_iden_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc): iden_op_conf = op_conf.identity_conf input_lbn = getattr(iden_op_conf, "in") input_sig_str = ( input_lbn + ":" + _nd_sbp2repr(bn2nd_sbp["in"]) + ", " + _blob_desc_repr(lbn2blob_desc[input_lbn]) ) output_lbn = op_conf.name + "/" + iden_op_conf.out output_sig_str = iden_op_conf.out nd_sbp = bn2nd_sbp[iden_op_conf.out] output_sig_str += ( ":" + _nd_sbp2repr(nd_sbp) + ", " + _blob_desc_repr(lbn2blob_desc[output_lbn]) ) return input_sig_str, output_sig_str def _get_input_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc): op_input_conf = op_conf.input_conf output_lbn = op_conf.name + "/" + op_input_conf.out nd_sbp = bn2nd_sbp[op_input_conf.out] output_sig_str = ( output_lbn + ":" + _nd_sbp2repr(nd_sbp) + ", " + _blob_desc_repr(lbn2blob_desc[output_lbn]) ) return "", output_sig_str def _get_output_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc): op_output_conf = op_conf.output_conf input_lbn = getattr(op_output_conf, "in") output_lbn = op_conf.name + "/" + op_output_conf.out input_sig_str = ( input_lbn + ":" + _nd_sbp2repr(bn2nd_sbp["in"]) + ", " + _blob_desc_repr(lbn2blob_desc[output_lbn]) ) nd_sbp = bn2nd_sbp[op_output_conf.out] output_sig_str = ( output_lbn + ":" + _nd_sbp2repr(nd_sbp) + ", " + _blob_desc_repr(lbn2blob_desc[output_lbn]) ) return input_sig_str, output_sig_str class GraphIR(object): def __init__(self, g_proto: job_pb.Job): assert g_proto is not None and isinstance(g_proto, job_pb.Job) self._graph_proto = g_proto self._op2conf = None self._op2placement = None def get_op_conf(self, op_name: str) -> Optional[op_conf_util.OperatorConf]: if self._op2conf is None: self._op2conf = dict() for op_conf in self._graph_proto.net.op: self._op2conf[op_conf.name] = op_conf if op_name not in self._op2conf: return None return self._op2conf[op_name] def get_op_placement(self, op_name: str) -> Optional[oneflow.placement]: if self._op2placement is None: self._op2placement = dict() for group in self._graph_proto.placement.placement_group: parallel_conf = group.parallel_conf for this_op_name in group.op_set.op_name: self._op2placement[this_op_name] = oneflow.placement( proto_str=text_format.MessageToString(parallel_conf) ) if op_name not in self._op2placement: return None return self._op2placement[op_name] def _op_signature( op: op_conf_util.OperatorConf, graph_proto: job_pb.Job, graph_ir: GraphIR, show_op_loc: bool, ) -> Tuple[bool, str]: bn2nd_sbp = graph_proto.job_parallel_view_conf.op_name2nd_sbp_signature_conf[ op.name ].bn_in_op2nd_sbp lbn2blob_desc = graph_proto.helper.lbn2logical_blob_desc signature_template = Template( op.name + "($input) -> ($output)" + ", placement=(" + str(graph_ir.get_op_placement(op.name)) + ")" ) input_sig_str = "..." output_sig_str = "..." # Only deal with UserOpConf and VariableOpConf for now. 
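    # Illustrative example (all values assumed) of the operator repr assembled below for
    # a user op, following signature_template and the sbp/shape repr helpers above:
    #
    #   (OPERATOR: model.fc-matmul-3(model.fc.weight/out:(sbp=(B), size=(8, 4),
    #    dtype=(oneflow.float32))) -> (model.fc-matmul-3/out_0:(sbp=(S(0)), size=(2, 4),
    #    dtype=(oneflow.float32))), placement=(oneflow.placement(type="cuda", ranks=[0])))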
if op.HasField("user_conf"): input_sig_str, output_sig_str = _get_user_op_io_repr( op, bn2nd_sbp, lbn2blob_desc ) elif op.HasField("variable_conf"): input_sig_str, output_sig_str = _get_var_op_io_repr( op, bn2nd_sbp, lbn2blob_desc ) elif op.HasField("identity_conf"): input_sig_str, output_sig_str = _get_iden_op_io_repr( op, bn2nd_sbp, lbn2blob_desc ) elif op.HasField("input_conf"): input_sig_str, output_sig_str = _get_input_op_io_repr( op, bn2nd_sbp, lbn2blob_desc ) elif op.HasField("output_conf"): input_sig_str, output_sig_str = _get_output_op_io_repr( op, bn2nd_sbp, lbn2blob_desc ) elif op.name.startswith("System-"): return False, "" op_str = "(OPERATOR: " op_str += signature_template.substitute(input=input_sig_str, output=output_sig_str) if show_op_loc and op.loc: op_str += ", location=(" + op.loc + ")" op_str += ")" return True, op_str def operators_repr(ops: Message, graph_ir: GraphIR, show_op_loc: bool,) -> List[str]: r"""Generate operators' string representation of this module """ graph_proto = graph_ir._graph_proto ops_strs = [] for op in ops: op_conf = graph_ir.get_op_conf(op) if op_conf is None: continue assert isinstance(op_conf, op_conf_util.OperatorConf) got_repr, op_str = _op_signature(op_conf, graph_proto, graph_ir, show_op_loc) if got_repr: ops_strs.append(op_str) return ops_strs def add_indent(in_s, num_spaces): s = in_s.split("\n") if len(s) == 1: return in_s first = s.pop(0) s = [num_spaces * " " + line for line in s] s = "\n".join(s) s = first + "\n" + s return s def sys_exc_error_msg(): msg = "" exc_info = sys.exc_info() if len(exc_info) > 0: msg += str(exc_info[0]) if len(exc_info) > 1: msg += " " + str(exc_info[1]) return msg def seq_to_func_return(seq, need_unpack=False): if need_unpack: return seq[0] return seq def _rsd_sub_destination_to(origin_dict, dest_device_str): dest_dict = OrderedDict() for k, v in origin_dict.items(): tensor_item, device_str = v dest_dict[k] = ( tensor_item.to(device=oneflow.device(dest_device_str), copy=True), dest_device_str, ) return dest_dict def _parallel_conf_to(parallel_conf, dest_device): if parallel_conf.device_tag == "cuda": assert len(parallel_conf.device_name) == 1 parallel_conf.device_name[0] = "@0:" + str(dest_device.index) def _mem_case_to(mem_case, dest_device): if mem_case.device_type == device_type.DeviceType.kCUDA: mem_case.device_id = dest_device.index if ( mem_case.HasField("pinned_device_type") and mem_case.pinned_device_type == device_type.DeviceType.kCUDA ): mem_case.pinned_device_id = dest_device.index def _job_to(job, dest_device): for pg in job.placement.placement_group: _parallel_conf_to(pg.parallel_conf, dest_device) for bpg in job.placement.blob_placement_group: _parallel_conf_to(bpg.parallel_conf, dest_device) def _modify_bits(original_num, k, j, new_num): if k > j: return original_num mask = ((1 << (j - k + 1)) - 1) << k cleared_num = original_num & ~mask modified_num = cleared_num | ((new_num & ((1 << (j - k + 1)) - 1)) << k) return modified_num def _get_bits(original_num, k, j): mask = ((1 << (j - k + 1)) - 1) << k cleared_num = (original_num & mask) >> k return cleared_num def _task_id_to(task_id, dest_device): if _get_bits(task_id, 43, 48) == 2: new_id = _modify_bits(task_id, 36, 43, dest_device.index) return new_id else: return task_id def _thrd_id_to(thrd_id, dest_device): if _get_bits(thrd_id, 22, 27) == 2: new_id = _modify_bits(thrd_id, 15, 22, dest_device.index) return new_id else: return thrd_id def _plan_to(plan_str, dest_device): plan = plan_pb.Plan() plan.ParseFromString(plan_str) for task in 
plan.task: task.task_id = _task_id_to(task.task_id, dest_device) task.thrd_id = _thrd_id_to(task.thrd_id, dest_device) for node in task.exec_sequence.exec_node: _parallel_conf_to( node.kernel_conf.op_attribute.parallel_conf_signature.op_parallel_conf, dest_device, ) for name, regst in task.produced_regst_desc.items(): regst.producer_task_id = _task_id_to(regst.producer_task_id, dest_device) for c_task_id_idx in range(len(regst.consumer_task_id)): regst.consumer_task_id[c_task_id_idx] = _task_id_to( regst.consumer_task_id[c_task_id_idx], dest_device ) _mem_case_to(regst.mem_case, dest_device) for mem_block in plan.block_chunk_list.mem_block: _mem_case_to(mem_block.mem_case, dest_device) mem_block.thrd_id_hint = _thrd_id_to(mem_block.thrd_id_hint, dest_device) for chunk in plan.block_chunk_list.chunk: _mem_case_to(chunk.mem_case, dest_device) new_ctrl_regst_desc_id2producer_task_id = {} for ( regst_desc_id, producer_task_id, ) in plan.ctrl_regst_desc_info.ctrl_regst_desc_id2producer_task_id.items(): new_ctrl_regst_desc_id2producer_task_id[regst_desc_id] = _task_id_to( producer_task_id, dest_device ) for ( regst_desc_id, producer_task_id, ) in new_ctrl_regst_desc_id2producer_task_id.items(): plan.ctrl_regst_desc_info.ctrl_regst_desc_id2producer_task_id[ regst_desc_id ] = producer_task_id for job_id, op_attr_tab in plan.job_id2op_attribute_ref_table.items(): for _, op_attr in op_attr_tab.op_name2op_attribute.items(): _parallel_conf_to( op_attr.parallel_conf_signature.op_parallel_conf, dest_device ) return plan.SerializeToString() ================================================ FILE: python/oneflow/nn/image.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.nn.modules.dataset import ImageBatchAlign as batch_align from oneflow.nn.modules.dataset import ImageDecode as decode from oneflow.nn.modules.dataset import ImageFlip as flip from oneflow.nn.modules.dataset import ImageNormalize as normalize from oneflow.nn.modules.dataset import ImageResize as Resize ================================================ FILE: python/oneflow/nn/init.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import math import warnings import numpy as np import oneflow as flow from oneflow.ops.util.initializer_util import ( calc_gain as calculate_gain, calc_fan, get_data_format, ) from oneflow.framework.tensor import Tensor import oneflow.framework.dtype as dtype_util def uniform_(tensor, a=0.0, b=1.0): r""" Fills the input Tensor with values drawn from the uniform distribution :math:`\mathcal{U}(a, b)`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Args: tensor: an n-dimensional `oneflow.Tensor` a: the lower bound of the uniform distribution b: the upper bound of the uniform distribution Examples: >>> w = flow.empty(3, 5) >>> nn.init.uniform_(w) """ assert a <= b, "b must be greater than or equal to a,but got {%d} vs {%d}" % (b, a) with flow.no_grad(): return flow._C.uniform_(tensor, a, b) def normal_(tensor, mean=0.0, std=1.0): r""" Fills the input Tensor with values drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Args: tensor: an n-dimensional `oneflow.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution Examples: >>> w = flow.empty(3, 5) >>> nn.init.normal_(w) """ with flow.no_grad(): if tensor.is_local: return flow.normal(mean=mean, std=std, size=tensor.shape, out=tensor) else: return flow.normal( mean=mean, std=std, size=tensor.shape, out=tensor, placement=tensor.placement, sbp=tensor.sbp, ) def xavier_uniform_(tensor, gain=1.0, *, data_format="NCHW"): r""" Fills the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform distribution. The resulting tensor will have values sampled from :math:`\mathcal{U}(-a, a)` where .. math:: a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Also known as Glorot initialization. Args: tensor: an n-dimensional `oneflow.Tensor` gain: an optional scaling factor Examples: >>> w = flow.empty(3, 5) >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')) """ fan = calc_fan(tensor.shape, "fan_sum", get_data_format(data_format)) std = gain * math.sqrt(2.0 / fan) bound = math.sqrt(3.0) * std return uniform_(tensor, -bound, bound) def xavier_normal_(tensor, gain=1.0, *, data_format="NCHW"): r""" Fills the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal distribution. The resulting tensor will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where .. math:: \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Also known as Glorot initialization. 
Args: tensor: an n-dimensional `oneflow.Tensor` gain: an optional scaling factor Examples: >>> w = flow.empty(3, 5) >>> nn.init.xavier_normal_(w) """ if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": data_format = "NHWC" fan = calc_fan(tensor.shape, "fan_sum", get_data_format(data_format)) std = gain * math.sqrt(2.0 / fan) return normal_(tensor, 0.0, std) def orthogonal_(tensor, gain=1.0): r""" Fills the input `Tensor` with a (semi) orthogonal matrix, as described in `Exact solutions to the nonlinear dynamics of learning in deep linear neural networks` - Saxe, A. et al. (2013). The input tensor must have at least 2 dimensions, and for tensors with more than 2 dimensions the trailing dimensions are flattened. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Args: tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2` gain: optional scaling factor Examples: >>> w = flow.empty(3, 5) >>> nn.init.orthogonal_(w) """ with flow.no_grad(): return tensor.orthogonal_(gain) def kaiming_uniform_( tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", *, data_format="NCHW" ): r""" Fills the input `Tensor` with values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` - He, K. et al. (2015), using a uniform distribution. The resulting tensor will have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where .. math:: \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan_mode}}} The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Also known as He initialization. Args: tensor: an n-dimensional `oneflow.Tensor` a: the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``) mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` preserves the magnitude of the variance of the weights in the forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the backwards pass. nonlinearity: the non-linear function (`nn.functional` name), recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). Examples: >>> w = flow.empty(3, 5) >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') """ if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": data_format = "NHWC" fan = calc_fan(tensor.shape, mode, get_data_format(data_format)) gain = calculate_gain(nonlinearity, a) std = gain / math.sqrt(fan) bound = math.sqrt(3.0) * std return uniform_(tensor, -bound, bound) def kaiming_normal_( tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", *, data_format="NCHW" ): r""" Fills the input `Tensor` with values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` - He, K. et al. (2015), using a normal distribution. The resulting tensor will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where .. math:: \text{std} = \frac{\text{gain}}{\sqrt{\text{fan_mode}}} The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Also known as He initialization. Args: tensor: an n-dimensional `oneflow.Tensor` a: the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``) mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` preserves the magnitude of the variance of the weights in the forward pass. 
Choosing ``'fan_out'`` preserves the magnitudes in the backwards pass. nonlinearity: the non-linear function (`nn.functional` name), recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). Examples: >>> w = flow.empty(3, 5) >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') """ if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": data_format = "NHWC" assert mode in ["fan_in", "fan_out"] fan = calc_fan(tensor.shape, mode, get_data_format(data_format)) gain = calculate_gain(nonlinearity, a) std = gain / math.sqrt(fan) return normal_(tensor, 0.0, std) # The trunc_normal_ implemention is referenced from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L22 def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) with flow.no_grad(): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) return tensor def constant_(tensor, val): r""" Fills the input Tensor with the value :math:`\text{val}`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Args: tensor: an n-dimensional `oneflow.Tensor` val: the value to fill the tensor with Examples: >>> w = flow.empty(3, 5) >>> nn.init.constant_(w, 0.3) """ with flow.no_grad(): tensor[...] = val return tensor def ones_(tensor): r""" Fills the input Tensor with the scalar value `1`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Args: tensor: an n-dimensional `oneflow.Tensor` Examples: >>> w = flow.empty(3, 5) >>> nn.init.ones_(w) """ with flow.no_grad(): return constant_(tensor, 1) def zeros_(tensor): r""" Fills the input Tensor with the scalar value `0`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. Args: tensor: an n-dimensional `oneflow.Tensor` Examples: >>> w = flow.empty(3, 5) >>> nn.init.zeros_(w) """ with flow.no_grad(): return constant_(tensor, 0) def eye_(tensor): r""" Fills the 2-dimensional input `Tensor` with the identity matrix. Preserves the identity of the inputs in `Linear` layers, where as many inputs are preserved as possible. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html. 
Args: tensor: a 2-dimensional `oneflow.Tensor` Examples: >>> w = flow.empty(3, 5) >>> nn.init.eye_(w) """ if tensor.ndimension() != 2: raise ValueError("Only tensors with 2 dimensions are supported") with flow.no_grad(): return flow._C.eye_(tensor) def _calculate_fan_in_and_fan_out(tensor): dimensions = tensor.ndimension() if dimensions < 2: raise ValueError( "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" ) num_input_fmaps = tensor.size(1) num_output_fmaps = tensor.size(0) receptive_field_size = 1 if tensor.ndimension() > 2: for s in tensor.size()[2:]: receptive_field_size *= s fan_in = num_input_fmaps * receptive_field_size fan_out = num_output_fmaps * receptive_field_size return (fan_in, fan_out) ================================================ FILE: python/oneflow/nn/modules/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from .module import Module ================================================ FILE: python/oneflow/nn/modules/_functions.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow class BackwardHookFunction(flow.autograd.Function): @staticmethod def forward(ctx, *args): ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad]) return args @staticmethod def backward(ctx, *args): return args ================================================ FILE: python/oneflow/nn/modules/activation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings from typing import Optional import oneflow as flow from oneflow.nn.modules.module import Module class PReLU(Module): """Applies the element-wise function: .. math:: PReLU(x) = \\max(0,x) + a * \\min(0,x) Here :math:`a` is a learnable parameter. 
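    For instance (illustrative values), with the default a = 0.25:
    PReLU(-2) = 0.25 * (-2) = -0.5 and PReLU(3) = 3, which matches the example output below.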
When called without arguments, `nn.PReLU()` uses a single parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, a separate :math:`a` is used for each input channel. .. note:: weight decay should not be used when learning :math:`a` for good performance. .. note:: Channel dim is the 2nd dim of input. When input has dims < 2, then there is no channel dim and the number of channels = 1. Args: num_parameters (int): number of :math:`a` to learn. Although it takes an int as input, there is only two values are legitimate: 1, or the number of channels at input. Default: 1 init (float): the initial value of :math:`a`. Default: 0.25 Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input Attr: - weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.PReLU() >>> input = flow.tensor(np.asarray([[[[1, -2], [3, 4]]]]), dtype=flow.float32) >>> print(m(input).numpy()) [[[[ 1. -0.5] [ 3. 4. ]]]] """ def __init__( self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None ) -> None: super().__init__() self.num_parameters = num_parameters self.weight = flow.nn.Parameter( flow.empty(num_parameters, dtype=dtype, device=device).fill_(init) ) def forward(self, x): return flow._C.prelu(x, self.weight) def extra_repr(self) -> str: return "num_parameters={}".format(self.num_parameters) class ReLU(Module): """Applies the rectified linear unit function element-wise: :math:`\\text{ReLU}(x) = (x)^+ = \\max(0, x)` Args: inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> relu = flow.nn.ReLU() >>> ndarr = np.asarray([1, -2, 3]) >>> x = flow.Tensor(ndarr) >>> relu(x) tensor([1., 0., 3.], dtype=oneflow.float32) """ def __init__(self, inplace: bool = False): super().__init__() self.inplace = inplace def forward(self, x): return flow._C.relu(x, self.inplace) def extra_repr(self): inplace_str = "inplace=True" if self.inplace else "" return inplace_str class ReLU6(Module): """Applies the element-wise function: .. math:: \\text{Relu6}(x) = \\begin{cases} 6 & \\text{ if } x > 6 \\\\ 0 & \\text{ if } x < 0 \\\\ x & \\text{ otherwise } \\\\ \\end{cases} Args: inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> relu6 = flow.nn.ReLU6() >>> out = relu6(input) >>> out tensor([0.0000, 0.0000, 0.5000], dtype=oneflow.float32) """ def __init__(self, inplace: bool = False): super().__init__() self.inplace = inplace def forward(self, x): if self.inplace: warnings.warn("ReLU6 module do not support inplace now") return flow._C.hardtanh(x, min_val=0.0, max_val=6.0) def extra_repr(self): inplace_str = "inplace=True" if self.inplace else "" return inplace_str def relu6(input, inplace=False): r"""relu6(input, inplace=False) -> Tensor Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. See :class:`~oneflow.nn.ReLU6` for more details. 
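    A minimal usage sketch (added for illustration; it assumes this function is exported
    as ``flow.nn.functional.relu6``, and the printed output is the expected clamp to [0, 6]):

    .. code-block:: python

        >>> import oneflow as flow
        >>> x = flow.tensor([-1.0, 3.0, 7.0])
        >>> flow.nn.functional.relu6(x)
        tensor([0., 3., 6.], dtype=oneflow.float32)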
""" if inplace: warnings.warn("nn.functional.relu6 do not support inplace now") return flow._C.hardtanh(input, min_val=0.0, max_val=6.0) class Tanh(Module): """This operator computes the hyperbolic tangent value of Tensor. The equation is: .. math:: out = \\frac{e^x-e^{-x}}{e^x+e^{-x}} Args: input (oneflow.Tensor): A Tensor Returns: oneflow.Tensor: The result Tensor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-1, 0, 1]).astype(np.float32) >>> input = flow.Tensor(x) >>> tanh = flow.nn.Tanh() >>> out = tanh(input) >>> out tensor([-0.7616, 0.0000, 0.7616], dtype=oneflow.float32) """ def __init__(self): super().__init__() def forward(self, input): return flow._C.tanh(input) class ELU(Module): """Applies the element-wise function :math:`\\text{ELU}(x) = \\begin{cases}x & \\text{ if } x \\gt 0 \\\\\\alpha*(exp(x)-1) & \\text{ if } x \\le 0 \\\\\\end{cases}` Args: alpha: the :math:`\\alpha` value for the ELU formulation. Default: 1.0 inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> elu = flow.nn.ELU() >>> out = elu(input) >>> out tensor([-0.3935, 0.0000, 0.5000], dtype=oneflow.float32) """ def __init__(self, alpha: float = 1.0, inplace: bool = False): super().__init__() self.alpha = alpha self.inplace = inplace def forward(self, x): if self.inplace: warnings.warn("ELU module do not support inplace now") return flow._C.elu(x, alpha=self.alpha) def extra_repr(self): param_str = f"alpha={self.alpha}" param_str += ", inplace=True" if self.inplace else "" return param_str class CELU(Module): """Applies the element-wise function: .. math:: \\text{CELU}(x, \\alpha) = \\begin{cases} x & \\text{ if } x \\ge 0 \\\\ \\alpha*(exp(\\frac{x}{\\alpha})-1) & \\text{ otherwise } \\\\ \\end{cases} Args: alpha: the :math:`\\alpha` value for the CELU formulation. Default: 1.0 inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> celu = flow.nn.CELU(alpha=0.5) >>> out = celu(input) >>> out tensor([-0.3161, 0.0000, 0.5000], dtype=oneflow.float32) """ def __init__(self, alpha: float = 1.0, inplace: bool = False): super().__init__() self.alpha = alpha self.inplace = inplace def forward(self, x): return flow._C.celu(x, alpha=self.alpha, inplace=self.inplace) def extra_repr(self): param_str = f"alpha={self.alpha}" param_str += ", inplace=True" if self.inplace else "" return param_str class GELU(Module): """ GELU(approximate='none') -> Tensor The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.GELU.html. Applies the Gaussian Error Linear Units function: .. math:: \\text{GELU}(x) = x * \Phi(x) where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. When the approximate argument is 'tanh', Gelu is estimated with: .. 
math:: \\text{GELU}(x) = 0.5 * x * (1 + \\text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) Args: input (oneflow.Tensor): Input Tensor approximate (string, optional): the gelu approximation algorithm to use: ``'none'`` | ``'tanh'``. Default: ``'none'`` Returns: oneflow.Tensor: A Tensor has same shape as the input. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> gelu = flow.nn.GELU() >>> out = gelu(input) >>> out tensor([-0.1543, 0.0000, 0.3457], dtype=oneflow.float32) """ def __init__(self, approximate: str = "none"): super().__init__() self.approximate = approximate def forward(self, input): if self.approximate == "none" or self.approximate == "tanh": return flow._C.gelu_with_approximate(input, self.approximate) else: raise NotImplementedError class QuickGELU(Module): """ QuickGELU() -> Tensor Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs .. math:: \\text{QuickGELU}(x) = x * \\sigma(1.702x) = x * \\frac{1}{1 + \\exp(-1.702x)} Args: input (oneflow.Tensor): Input Tensor Returns: oneflow.Tensor: A Tensor has same shape as the input. For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor([-0.5, 0, 0.5]) >>> gelu = flow.nn.QuickGELU() >>> out = gelu(input) >>> out tensor([-0.1496, 0.0000, 0.3504], dtype=oneflow.float32) """ def __init__(self): super().__init__() def forward(self, x): return flow._C.quick_gelu(x) class SquareReLU(Module): """ SquareReLU() -> Tensor Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 .. math:: :math:`\\text{SquareReLU}(x) = \\max(0, x) * \\max(0, x)` Args: input (oneflow.Tensor): Input Tensor Returns: oneflow.Tensor: A Tensor has same shape as the input. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> square_relu = flow.nn.SquareReLU() >>> out = square_relu(input) >>> out tensor([0.0000, 0.0000, 0.2500], dtype=oneflow.float32) """ def __init__(self): super().__init__() def forward(self, x): return flow._C.square_relu(x) class Sigmoid(Module): """Applies the element-wise function: .. math:: \\text{Sigmoid}(x) = \\sigma(x) = \\frac{1}{1 + \\exp(-x)} Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = flow.Tensor(np.array([0.81733328, 0.43621480, 0.10351428])) >>> m = flow.nn.Sigmoid() >>> out = m(x) >>> out tensor([0.6937, 0.6074, 0.5259], dtype=oneflow.float32) """ def __init__(self): super().__init__() def forward(self, x): return flow._C.sigmoid(x) class Hardsigmoid(Module): """Applies the element-wise function: .. math:: \\text{Hardsigmoid}(x) = \\begin{cases} 0 & \\text{ if } x \\le -3 \\\\ 1 & \\text{ if } x \\ge +3 \\\\ \\frac{x}{6} + \\frac{1}{2} & \\text{ otherwise } \\\\ \\end{cases} Args: inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> hardsigmoid = flow.nn.Hardsigmoid() >>> out = hardsigmoid(input) >>> out tensor([0.4167, 0.5000, 0.5833], dtype=oneflow.float32) """ def __init__(self, inplace: bool = False): super().__init__() self.inplace = inplace def forward(self, x): if self.inplace: return flow._C.hardsigmoid(x, True) return flow._C.hardsigmoid(x, False) def extra_repr(self): inplace_str = "inplace=True" if self.inplace else "" return inplace_str class Hardshrink(Module): r""" The Hardshrink activation. The formula is: .. math:: \text{Hardshrink}(x) = \begin{cases} x, & \text{ if } x > \lambda \\ x, & \text{ if } x < -\lambda \\ 0, & \text{ otherwise } \end{cases} Args: lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-1.1, 0, 0.2, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> hardshrink = flow.nn.Hardshrink(lambd=0.5) >>> out = hardshrink(input) >>> out tensor([-1.1000, 0.0000, 0.0000, 0.0000], dtype=oneflow.float32) """ def __init__(self, lambd: float = 0.5, inplace: bool = False): super().__init__() self.inplace = inplace self.lambd = lambd def forward(self, x): return flow._C.hardshrink(x, lambd=self.lambd, inplace=self.inplace) def extra_repr(self) -> str: param_str = f"lambd={self.lambd}" param_str += ", inplace=True" if self.inplace else "" return param_str class Softmax(Module): """Applies the Softmax function to an n-dimensional input Tensor rescaling them so that the elements of the n-dimensional output Tensor lie in the range [0,1] and sum to 1. Softmax is defined as: .. math:: \\text{Softmax}(x_{i}) = \\frac{\\exp(x_i)}{\\sum_j \\exp(x_j)} When the input Tensor is a sparse tensor then the unspecifed values are treated as ``-inf``. Shape: - Input: :math:`(*)` where `*` means, any number of additional dimensions - Output: :math:`(*)`, same shape as the input Returns: a Tensor of the same dimension and shape as the input with values in the range [0, 1] Args: dim (int): A dimension along which Softmax will be computed (so every slice along dim will sum to 1). For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.Softmax(dim = 2) >>> x = flow.Tensor( ... np.array( ... [[[-0.46716809, 0.40112534, 0.61984003], ... [-1.31244969, -0.42528763, 1.47953856]]] ... ) ... ) >>> out = m(x) >>> out tensor([[[0.1575, 0.3754, 0.4671], [0.0507, 0.1230, 0.8263]]], dtype=oneflow.float32) """ def __init__(self, dim: Optional[int] = None): super(Softmax, self).__init__() self.dim = dim def forward(self, x): return flow._C.softmax(x, self.dim) def extra_repr(self): return f"dim={self.dim}" class LogSoftmax(Module): r"""Applies the LogSoftmax function to an n-dimensional input Tensor. The LogSoftmax formulation can be simplified as: .. math:: \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) = x_i - \log({ \sum_j \exp(x_j)}) Args: dim (int): A dimension along which LogSoftmax will be computed. Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.LogSoftmax(dim=1) >>> x = flow.Tensor( ... np.array( ... [[ 0.4296, -1.1957, 2.5463], ... [ 1.2552, -1.5747, 0.6923]] ... ) ... ) >>> out = m(x) >>> out tensor([[-2.2513, -3.8766, -0.1346], [-0.4877, -3.3176, -1.0506]], dtype=oneflow.float32) """ def __init__(self, dim: Optional[int] = None): super(LogSoftmax, self).__init__() self.dim = dim def forward(self, x): return flow._C.log_softmax(x, self.dim) def extra_repr(self): return f"dim={self.dim}" class LogSigmoid(Module): """Applies the element-wise function: .. math:: \\text{LogSigmoid}(x) = \\log\\left(\\frac{ 1 }{ 1 + \\exp(-x)}\\right) Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> logsigmoid = flow.nn.LogSigmoid() >>> out = logsigmoid(input) >>> out tensor([-0.9741, -0.6931, -0.4741], dtype=oneflow.float32) """ def __init__(self): super().__init__() def forward(self, x): return flow._C.logsigmoid(x) class Softplus(Module): """Applies the element-wise function: .. math:: \\text{Softplus}(x) = \\frac{1}{\\beta} * \\log(1 + \\exp(\\beta * x)) SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive. For numerical stability the implementation reverts to the linear function when :math:`input \\times \\beta > threshold`. Args: beta: the :math:`\\beta` value for the Softplus formulation. Default: 1 threshold: values above this revert to a linear function. Default: 20 Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> softplus = flow.nn.Softplus() >>> out = softplus(input) >>> out tensor([0.4741, 0.6931, 0.9741], dtype=oneflow.float32) """ def __init__(self, beta: int = 1, threshold: int = 20): super().__init__() self.beta = beta self.threshold = threshold def forward(self, x): return flow._C.softplus(x, beta=self.beta, threshold=self.threshold) def extra_repr(self): return f"beta={self.beta}, threshold={self.threshold}" class Hardswish(Module): """Applies the hardswish function, element-wise, as described in the paper `Searching for MobileNetV3 `__. .. math:: \\text{Hardswish}(x) = \\begin{cases} 0 & \\text{ if } x \\le -3 \\\\ x & \\text{ if } x \\ge +3 \\\\ x*(x+3)/6 & \\text{ otherwise } \\\\ \\end{cases} Args: inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> hardswish = flow.nn.Hardswish() >>> out = hardswish(input) >>> out tensor([-0.2083, 0.0000, 0.2917], dtype=oneflow.float32) """ def __init__(self, inplace: bool = False): super().__init__() self.inplace = inplace def forward(self, x): if self.inplace: warnings.warn("Hardswish module do not support inplace now") return flow._C.hardswish(x) def extra_repr(self): inplace_str = "inplace=True" if self.inplace else "" return inplace_str class Hardtanh(Module): """ Applies the HardTanh function element-wise HardTanh is defined as: .. math:: \\text{HardTanh}(x) = \\begin{cases} 1 & \\text{ if } x > 1 \\\\ -1 & \\text{ if } x < -1 \\\\ x & \\text{ otherwise } \\\\ \\end{cases} The range of the linear region :math:`[-1, 1]` can be adjusted using :attr:`min_val` and :attr:`max_val`. Args: min_val: minimum value of the linear region range. Default: -1 max_val: maximum value of the linear region range. Default: 1 inplace: can optionally do the operation in-place. Default: ``False`` Keyword arguments :attr:`min_value` and :attr:`max_value` have been deprecated in favor of :attr:`min_val` and :attr:`max_val`. Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.Hardtanh() >>> arr = np.array([0.2, 0.3, 3.0, 4.0]) >>> x = flow.Tensor(arr) >>> out = m(x) >>> out tensor([0.2000, 0.3000, 1.0000, 1.0000], dtype=oneflow.float32) """ def __init__( self, min_val: float = -1, max_val: float = 1, inplace: bool = False, min_value: Optional[float] = None, max_value: Optional[float] = None, ): super().__init__() if min_value is not None: warnings.warn( "keyword argument min_value is deprecated and rename to min_val" ) min_val = min_value if max_value is not None: warnings.warn( "keyword argument max_value is deprecated and rename to max_val" ) max_val = max_value self.min_val = min_val self.max_val = max_val self.inplace = inplace def forward(self, x): if self.inplace: warnings.warn("Hardtanh module do not support inplace now") return flow._C.hardtanh(x, min_val=self.min_val, max_val=self.max_val) def extra_repr(self): param_str = f"min_val={self.min_val}, max_val={self.max_val}" param_str += ", inplace=True" if self.inplace else "" return param_str class LeakyReLU(Module): """Applies the element-wise function: .. math:: \\text{LeakyRELU}(x) = \\begin{cases} x, & \\text{ if } x \\geq 0 \\\\ \\text{negative_slope} \\times x, & \\text{ otherwise } \\end{cases} Args: negative_slope: Controls the angle of the negative slope. Default: 1e-2 inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.LeakyReLU(0.1) >>> arr = np.array([0.2, 0.3, 3.0, 4.0]) >>> x = flow.Tensor(arr) >>> out = m(x) >>> out tensor([0.2000, 0.3000, 3.0000, 4.0000], dtype=oneflow.float32) """ def __init__(self, negative_slope: float = 0.01, inplace: bool = False): super().__init__() self.negative_slope = negative_slope self.inplace = inplace def forward(self, x): return flow._C.leaky_relu(x, alpha=self.negative_slope, inplace=self.inplace) def extra_repr(self): param_str = f"negative_slope={self.negative_slope}" param_str += ", inplace=True" if self.inplace else "" return param_str class RReLU(Module): """Applies the randomized leaky rectified liner unit function, element-wise: .. math:: \\text{RReLU}(x) = \\begin{cases} x, & \\text{ if } x \\geq 0 \\\\ a \\times x, & \\text{ otherwise } \\end{cases} where :math:`a` is randomly sampled from uniform distribution :math:`\mathcal{U}(\text{lower}, \text{upper})`. .. note:: See `Empirical Evaluation of Rectified Activations in Convolution Network: `_ Args: lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> m = flow.nn.RReLU(0.1, 0.3) >>> arr = np.array([0.2, -0.3, -3.0, 4.0, 0.5, -2.2]) >>> x = flow.Tensor(arr) >>> out = m(x) >>> print(out) # doctest: +SKIP tensor([ 0.2000, -0.0824, -0.5418, 4.0000, 0.5000, -0.4213], dtype=oneflow.float32) # doctest: +SKIP """ def __init__( self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False ): super().__init__() self.lower = lower self.upper = upper self.inplace = inplace def forward(self, x): return flow._C.rrelu(x, self.lower, self.upper, self.training, self.inplace) def extra_repr(self): param_str = f"lower={self.lower}" param_str += f"upper={self.upper}" param_str += ", inplace=True" if self.inplace else "" return param_str class Mish(Module): """Applies the element-wise function: .. math:: \\text{Mish}(x) = x * \\text{Tanh}(\\text{Softplus}(x)) .. note:: See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([1, 2, 3]).astype(np.float32) >>> input = flow.Tensor(x) >>> mish = flow.nn.Mish() >>> out = mish(input) >>> out tensor([0.8651, 1.9440, 2.9865], dtype=oneflow.float32) """ def __init__(self, inplace: bool = False): self.inplace = inplace super().__init__() def forward(self, x): return flow._C.mish(x) class SiLU(Module): r"""SiLU(Swish) activation: .. math:: \text{SiLU}(x) = x * sigmoid(x) .. note:: See `Gaussian Error Linear Units (GELUs) `_ where the SiLU (Sigmoid Linear Unit) was originally coined, and see `Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning `_ and `Swish: a Self-Gated Activation Function `_ where the SiLU was experimented with later. Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([1, 2, 3]).astype(np.float32) >>> input = flow.Tensor(x) >>> silu = flow.nn.SiLU() >>> out = silu(input) >>> out tensor([0.7311, 1.7616, 2.8577], dtype=oneflow.float32) """ def __init__(self, inplace: bool = False): self.inplace = inplace super().__init__() def forward(self, x): return flow._C.silu(x) class SELU(Module): r"""Applies the element-wise function: The formula is: .. math:: \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) with :math:`\alpha = 1.6732632423543772848170429916717` and :math:`\text{scale} = 1.0507009873554804934193349852946`. .. warning:: When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation, ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'`` in order to get `Self-Normalizing Neural Networks`_. See :func:`torch.nn.init.calculate_gain` for more information. More details can be found in the paper `Self-Normalizing Neural Networks `_. Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([1, 2, 3]).astype(np.float32) >>> input = flow.Tensor(x) >>> selu = flow.nn.SELU() >>> out = selu(input) >>> out tensor([1.0507, 2.1014, 3.1521], dtype=oneflow.float32) """ def __init__(self, inplace: bool = False): self.inplace = inplace super().__init__() def forward(self, x): return flow._C.selu(x) class Softshrink(Module): r""" The Softshrink activation. The formula is: .. math:: \text{Softshrink}(x) = \begin{cases} x - \lambd, & \text{ if } x > \lambda \\ x + \lambd, & \text{ if } x < -\lambda \\ 0, & \text{ otherwise } \end{cases} The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Softshrink.html. Args: lambd: the :math:`\lambda` value for the Softshrink formulation. Default: 0.5 inplace: can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([-1, 0, 0.2, 0.5]).astype(np.float32) >>> input = flow.Tensor(x) >>> softshrink = flow.nn.Softshrink(lambd=0.5) >>> out = softshrink(input) >>> out tensor([-0.5000, 0.0000, 0.0000, 0.0000], dtype=oneflow.float32) """ def __init__(self, lambd: float = 0.5, inplace: bool = False): self.inplace = inplace self.lambd = lambd super().__init__() def forward(self, x): return flow._C.softshrink(x, alpha=self.lambd, inplace=self.inplace) def extra_repr(self) -> str: param_str = f"lambd={self.lambd}" param_str += ", inplace=True" if self.inplace else "" return param_str class Softsign(Module): r"""The SoftSign activation. The formula is: .. math:: SoftSign(x) = \frac{x}{1 + |x|} Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([1, 2, 3]).astype(np.float32) >>> input = flow.Tensor(x) >>> softsign = flow.nn.Softsign() >>> out = softsign(input) >>> out tensor([0.5000, 0.6667, 0.7500], dtype=oneflow.float32) """ def __init__(self, inplace: bool = False): self.inplace = inplace super().__init__() def forward(self, x): return flow._C.softsign(x) class GLU(Module): r"""The GLU activation. Args: input (Tensor, float): input tensor. dim (int, optional): dimension on which to split the input. Default: -1 Shape: - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional dimensions - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` The formula is: .. math:: GLU(input) = GLU(a, b) = a \otimes sigmoid(b) .. note:: where input is split in half along dim to form a and b, ⊗ is the element-wise product between matrices. For example: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> m = nn.GLU() >>> x = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=flow.float32) >>> y = m(x) >>> y tensor([[0.9526, 1.9640], [4.9954, 5.9980]], dtype=oneflow.float32) """ def __init__(self, dim: Optional[int] = -1): super().__init__() self.dim = dim def forward(self, input): return flow._C.glu(input, self.dim) class Threshold(Module): r"""The Threshold Activation. Return ``x`` if ``x`` is greater than ``threshold``, else return ``value``. The interface is consistent with PyTorch. The documentation is referenced from https://pytorch.org/docs/1.10/generated/torch.nn.Threshold.html. The formula is: .. math:: \text{Threshold}(x) = \begin{cases} x, & \text{ if } x > \text{ threshold } \\ \text{value }, & \text{ otherwise } \end{cases} Args: threshold (float): The ``threshold`` value for the Threshold formulation value (float): The ``value`` value for the Threshold formulation Shapes: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input Returns: Oneflow.Tensor: The result tensor For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = np.array([-1, 0, 0.5, 1]).astype(np.float32) >>> input = flow.Tensor(x) >>> th = flow.nn.Threshold(threshold=0.5, value=0.2) >>> out = th(input) >>> out tensor([0.2000, 0.2000, 0.2000, 1.0000], dtype=oneflow.float32) """ def __init__(self, threshold: float, value: float): super().__init__() self.threshold = threshold self.value = value def forward(self, input): return flow._C.threshold(input, threshold=self.threshold, value=self.value) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/affine_grid.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from typing import List import oneflow as flow def affine_grid(theta, size: List[int], align_corners: bool = False): """The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.affine_grid.html. Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`. .. note:: This function is often used in conjunction with :func:`grid_sample` to build `Spatial Transformer Networks`_ . Args: theta (Tensor): input batch of affine matrices with shape (:math:`N, 2, 3`) for 2D or (:math:`N, 3, 4`) for 3D size (oneflow.Size): the target output image size. (:math:`N, C, H, W` for 2D or :math:`N, C, D, H, W` for 3D) Example: oneflow.Size((32, 3, 24, 24)) align_corners (bool): if ``True``, consider ``-1`` and ``1`` to refer to the centers of the corner pixels rather than the image corners. Refer to :func:`grid_sample` for a more complete description. A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample` with the same setting for this option. Default: ``False`` Returns: output (Tensor): output Tensor of size (:math:`N, H, W, 2`) .. _`Spatial Transformer Networks`: https://arxiv.org/abs/1506.02025 Examples:: >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.arange(1., 7).reshape((1, 2, 3)), dtype=flow.float32) >>> output = flow.nn.functional.affine_grid(input, flow.Size([1, 1, 2, 2]), align_corners=True) >>> output tensor([[[[ 0., -3.], [ 2., 5.]], [[ 4., 7.], [ 6., 15.]]]], dtype=oneflow.float32) """ y = flow._C.affine_grid(theta, size=size, align_corners=align_corners) return y if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/all_reduce.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.nn.modules.module import Module from typing import Sequence class AllReduce(Module): def __init__(self, parallel_conf_str: str): super().__init__() self._op = ( flow.stateful_op("eager_ccl_all_reduce").Input("in").Output("out").Build() ) self.parallel_conf = parallel_conf_str def forward(self, x): assert x.device.type == "cuda" assert x.device.index == flow.env.get_local_rank() return flow._C.dispatch_eager_ccl_all_reduce(self._op, parallel_conf) ================================================ FILE: python/oneflow/nn/modules/arange.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import List, Union import oneflow as flow def arange_op( start: Union[int, flow.Tensor] = None, end: Union[int, flow.Tensor] = None, step: int = 1, dtype: flow.dtype = None, device: Union[str, flow.device] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, requires_grad: bool = False, ): if start is None: start = 0 elif flow.is_tensor(start): # support start as a Scalar Tensor assert len(start.shape) == 0, "start must be a Scalar" start = start.item() if end is None: end = start start = 0 elif flow.is_tensor(end): # support end as a Scalar Tensor assert len(end.shape) == 0, "end must be a Scalar" end = end.item() if placement is None: if isinstance(device, str): device = flow.device(device) res = flow._C.arange(start, end, step, dtype=dtype, device=device) else: assert isinstance( placement, flow._oneflow_internal.placement ), "placement should be oneflow._oneflow_internal.placement type." assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp if isinstance(sbp, flow.sbp.sbp): sbp = (sbp,) else: for elem in sbp: assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp assert len(sbp) == len(placement.ranks.shape) res = flow._C.global_arange( start, end, step, dtype=dtype, placement=placement, sbp=sbp ) res.requires_grad = requires_grad return res if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/argsort.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module from oneflow.ops.transpose_util import ( get_inversed_perm, get_perm_when_transpose_axis_to_last_dim, ) def argsort_op(input, dim: int = -1, descending: bool = False): num_dims = len(input.shape) dim = dim if dim >= 0 else dim + num_dims direction = "DESCENDING" if descending else "ASCENDING" assert 0 <= dim < num_dims, "dim out of range" if dim == num_dims - 1: return flow._C.arg_sort(input, direction) else: perm = get_perm_when_transpose_axis_to_last_dim(num_dims, dim) x = flow._C.transpose(input, perm=perm) x = flow._C.arg_sort(x, direction) return flow._C.transpose(x, perm=get_inversed_perm(perm)) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/argwhere.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional import numpy as np import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module def argwhere_op(input, dtype: Optional[flow.dtype] = flow.int32): """This operator finds the indices of input Tensor `input` elements that are non-zero. It returns a list in which each element is a coordinate that points to a non-zero element in the condition. Args: input (oneflow.Tensor): the input Tensor. dtype (Optional[flow.dtype], optional): The data type of output. Defaults to None. Returns: oneflow.Tensor: The result Tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array([[0, 1, 0], ... [2, 0, 2]]).astype(np.float32) >>> input = flow.Tensor(x) >>> output = flow.argwhere(input) >>> output tensor([[0, 1], [1, 0], [1, 2]], dtype=oneflow.int32) """ if input.is_lazy: raise ValueError("A lazy tensor can not be applied to argwhere.") (res, size) = flow._C.argwhere(input, dtype=dtype) slice_tup_list = [(0, size.numpy().item(), 1)] return flow.slice(res, slice_tup_list=slice_tup_list) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/as_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import numpy as np import oneflow as flow def as_tensor(data, dtype=None, device=None): if flow.is_tensor(data): if dtype is None: dtype = data.dtype if device is None: device = data.device if data.dtype is dtype and data.device is device: return data else: data = data.to(dtype=dtype, device=device) elif isinstance(data, (np.ndarray)): if dtype is None: if (device is None) or (device.type == "cpu"): data = flow.from_numpy(data) else: data = flow.tensor(data, device=device) else: data_infer_flow_type = flow.framework.dtype.convert_numpy_dtype_to_oneflow_dtype( data.dtype ) if data_infer_flow_type is dtype: if (device is None) or (device.type == "cpu"): data = flow.from_numpy(data) else: data = flow.tensor(data, dtype=dtype, device=device) else: if (device is None) or (device.type == "cpu"): data = flow.tensor(data, dtype=dtype) else: data = flow.tensor(data, dtype=dtype, device=device) else: # not shared memory in this case data = flow.tensor(data) if device is not None: data = data.to(device) if dtype is not None: data = data.to(dtype) return data ================================================ FILE: python/oneflow/nn/modules/batchnorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from typing import Union import os import oneflow as flow from oneflow.nn.modules.module import Module from oneflow.autograd import Function class _NormBase(Module): """Common base of _InstanceNorm and _BatchNorm""" def __init__( self, num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, ) -> None: super().__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: self.weight = flow.nn.Parameter(flow.Tensor(num_features)) self.bias = flow.nn.Parameter(flow.Tensor(num_features)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) if self.track_running_stats: self.register_buffer("running_mean", flow.zeros(num_features)) self.register_buffer("running_var", flow.ones(num_features)) self.register_buffer("num_batches_tracked", flow.tensor(0, dtype=flow.long)) else: self.register_buffer("running_mean", None) self.register_buffer("running_var", None) self.register_buffer("num_batches_tracked", None) self.reset_parameters() def reset_running_stats(self) -> None: if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) self.num_batches_tracked.zero_() def reset_parameters(self) -> None: self.reset_running_stats() if self.affine: flow.nn.init.ones_(self.weight) flow.nn.init.zeros_(self.bias) def _check_input_dim(self, input): raise NotImplementedError def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): if self.track_running_stats: num_batches_tracked_key = prefix + "num_batches_tracked" if not num_batches_tracked_key in state_dict: if self.running_mean.is_global: sbp = self.running_mean.sbp placement = self.running_mean.placement state_dict[num_batches_tracked_key] = flow.tensor( 0, dtype=flow.long ).to_global(sbp=sbp, placement=placement) else: state_dict[num_batches_tracked_key] = flow.tensor( 0, dtype=flow.long ) super(_NormBase, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) def extra_repr(self): return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format( **self.__dict__ ) class _BatchNorm(_NormBase): def __init__( self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, ): super().__init__(num_features, eps, momentum, affine, track_running_stats) self.channel_axis = 1 def forward(self, x): self._check_input_dim(x) exponential_average_factor = self.momentum if self.training and self.track_running_stats: if self.num_batches_tracked is not None: self.num_batches_tracked.add_(1) if self.momentum is None: exponential_average_factor = 1.0 / float(self.num_batches_tracked) if self.training: is_training = True else: is_training = (self.running_mean is None) and (self.running_var is None) # NOTE(lixiang): If it is training mode, pass running_mean and running_var directly to the functor layer. 
return flow._C.normalization( x, self.running_mean, self.running_var, self.weight, self.bias, axis=self.channel_axis, epsilon=self.eps, momentum=exponential_average_factor, is_training=is_training, ) class BatchNorm1d(_BatchNorm): """Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift `__ . .. math:: y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The mean and standard-deviation are calculated per-dimension over the mini-batches and :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors of size `C` (where `C` is the input size). By default, the elements of :math:`\\gamma` are set to 1 and the elements of :math:`\\beta` are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to `oneflow.var(input, unbiased=False)`. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default :attr:`momentum` of 0.1. If :attr:`track_running_stats` is set to ``False``, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well. .. note:: This :attr:`momentum` argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`, where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. Because the Batch Normalization is done over the `C` dimension, computing statistics on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Can be set to ``None`` for cumulative moving average (i.e. simple average). Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics, and initializes statistics buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` Shape: - Input: :math:`(N, C)` or :math:`(N, C, L)` - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.Tensor(np.random.randn(20, 100)) >>> m = flow.nn.BatchNorm1d(100) >>> y = m(x) """ def _check_input_dim(self, input): if input.ndim != 2 and input.ndim != 3: raise ValueError( "expected 2D or 3D input (got {}D input)".format(input.ndim) ) class BatchNorm2d(_BatchNorm): """Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift `__ . .. math:: y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The mean and standard-deviation are calculated per-dimension over the mini-batches and :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors of size `C` (where `C` is the input size). By default, the elements of :math:`\\gamma` are set to 1 and the elements of :math:`\\beta` are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to `oneflow.var(input, unbiased=False)`. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default :attr:`momentum` of 0.1. If :attr:`track_running_stats` is set to ``False``, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well. .. note:: This :attr:`momentum` argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`, where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. Because the Batch Normalization is done over the `C` dimension, computing statistics on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, H, W)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Can be set to ``None`` for cumulative moving average (i.e. simple average). Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics, and initializes statistics buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` - Output: :math:`(N, C, H, W)` (same shape as input) For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.Tensor(np.random.randn(4, 2, 8, 3)) >>> m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1) >>> y = m(x) """ def __init__( self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, ): super().__init__(num_features, eps, momentum, affine, track_running_stats) if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": self.channel_axis = 3 def to_memory_format(self, memory_format) -> None: if memory_format is flow.channels_last: self.channel_axis = 3 elif memory_format is flow.contiguous_format: self.channel_axis = 1 def _check_input_dim(self, input): if input.ndim != 4: raise ValueError("expected 4D input (got {}D input)".format(input.ndim)) class BatchNorm3d(_BatchNorm): r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift `__ . .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta The mean and standard-deviation are calculated per-dimension over the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to `oneflow.var(input, unbiased=False)`. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default :attr:`momentum` of 0.1. If :attr:`track_running_stats` is set to ``False``, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well. .. note:: This :attr:`momentum` argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. Because the Batch Normalization is done over the `C` dimension, computing statistics on `(N, D, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, D, H, W)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Can be set to ``None`` for cumulative moving average (i.e. simple average). Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics, and initializes statistics buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` - Output: :math:`(N, C, D, H, W)` (same shape as input) For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.Tensor(np.random.randn(3, 2, 5, 8, 4)) >>> m = flow.nn.BatchNorm3d(num_features=2, eps=1e-5, momentum=0.1) >>> y = m(x) >>> y.size() oneflow.Size([3, 2, 5, 8, 4]) """ def _check_input_dim(self, input): if input.ndim != 5: raise ValueError("expected 5D input (got {}D input)".format(input.ndim)) global_eps = 0.1 global_momentum = 0.1 global_world_size = 1 global_axis = 1 class SyncBatchNormFunction(flow.autograd.Function): @staticmethod def forward(self, input, weight, bias, running_mean, running_var): assert input.is_local, "SyncBatchNorm does not support global tensor as input." if not input.is_contiguous(): input = input.contiguous() if weight is not None: weight = weight.contiguous() size = int(input.numel() // input.size(1)) if size == 1 and global_world_size < 2: raise ValueError( "Expected more than 1 value per channel when training, got input size {}".format( size ) ) num_channels = input.shape[global_axis] if input.numel() > 0: # calculate mean/invstd for input. mean, invstd = flow._C.batch_norm_stats(input, global_axis, global_eps) count = flow.full( (1,), input.numel() // input.size(global_axis), dtype=mean.dtype, device=mean.device, ) # C, C, 1 -> (2C + 1) combined = flow.cat([mean, invstd, count], dim=0) else: # for empty input, set stats and the count to zero. The stats with # zero count will be filtered out later when computing global mean # & invstd, but they still needs to participate the all_gather # collective communication to unblock other peer processes. combined = flow.zeros( 2 * num_channels + 1, dtype=input.dtype, device=input.device ) # Use allgather instead of allreduce because count could be different across # ranks, simple all reduce op can not give correct results. # batch_norm_gather_stats_with_counts calculates global mean & invstd based on # all gathered mean, invstd and count. 
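        # Layout of the gathered buffer (illustration): each rank contributes
        # combined = [mean (C), invstd (C), count (1)], i.e. 2C + 1 values.
        # For example, with C = 3 channels and world_size = 4, combined_flat below
        # has shape (4, 7), and splitting along dim=1 by num_channels yields
        # mean_all (4, 3), invstd_all (4, 3) and count_all (4, 1).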
# world_size * (2C + 1) combined_size = combined.numel() combined_flat = flow.empty( global_world_size, combined_size, dtype=combined.dtype, device=combined.device, ) flow.comm.all_gather_into_tensor(combined_flat, combined) # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 mean_all, invstd_all, count_all = flow.split(combined_flat, num_channels, dim=1) # remove stats from empty inputs mask = count_all.squeeze(-1) >= 1 count_all = count_all[mask] mean_all = mean_all[mask] invstd_all = invstd_all[mask] # calculate global mean & invstd mean, invstd = flow._C.batch_norm_gather_stats_with_counts( input, mean_all, invstd_all, running_mean, running_var, global_momentum, global_eps, count_all.view(-1), ) self.save_for_backward(input, weight, mean, invstd, count_all.to(flow.int32)) # apply element-wise normalization if input.numel() > 0: return flow._C.batch_norm_elemt( input, weight, bias, mean, invstd, global_axis, global_eps ) else: return flow.zeros(*(input.shape), dtype=input.dtype, device=input.device) @staticmethod def backward(self, grad_output): if not grad_output.is_contiguous(): grad_output = grad_output.contiguous() saved_input, weight, mean, invstd, count_tensor = self.saved_tensors grad_input = grad_weight = grad_bias = None channel_axis = 1 if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": if saved_input.dim() == 3: channel_axis = 2 elif saved_input.dim() == 4: channel_axis = 3 elif saved_input.dim() == 5: channel_axis = 4 # calculate local stats as well as grad_weight / grad_bias sum_dy, sum_dy_xmu, grad_weight, grad_bias = flow._C.batch_norm_backward_reduce( grad_output, saved_input, mean, invstd, channel_axis ) # synchronizing stats used to calculate input gradient. num_channels = sum_dy.shape[0] combined = flow.cat([sum_dy, sum_dy_xmu], dim=0) flow.comm.all_reduce(combined) sum_dy, sum_dy_xmu = flow.split(combined, num_channels) # backward pass for gradient calculation grad_input = flow._C.batch_norm_backward_elemt( grad_output, saved_input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor, channel_axis, ) # synchronizing of grad_weight / grad_bias is not needed as distributed # training would handle all reduce. return grad_input, grad_weight, grad_bias, None, None class SyncBatchNorm(_BatchNorm): r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs with additional channel dimension) as described in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift `__ . .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta The mean and standard-deviation are calculated per-dimension over all mini-batches of the same process groups. :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to `oneflow.var(input, unbiased=False)`. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default :attr:`momentum` of 0.1. If :attr:`track_running_stats` is set to ``False``, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well. .. 
note:: This :attr:`momentum` argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. Because the Batch Normalization is done for each channel in the ``C`` dimension, computing statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch Normalization or Spatio-temporal Batch Normalization. Currently :class:`SyncBatchNorm` only supports :class:`~oneflow.nn.DistributedDataParallel` (DDP) with single GPU per process. Use :meth:`oneflow.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping Network with DDP. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, +)` eps: a value added to the denominator for numerical stability. Default: ``1e-5`` momentum: the value used for the running_mean and running_var computation. Can be set to ``None`` for cumulative moving average (i.e. simple average). Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics, and initializes statistics buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` Shape: - Input: :math:`(N, C, +)` - Output: :math:`(N, C, +)` (same shape as input) .. note:: Synchronization of batchnorm statistics occurs only while training, i.e. synchronization is disabled when ``model.eval()`` is set or if ``self.training`` is otherwise ``False``. 
Examples:: >>> import oneflow as flow >>> bn = flow.nn.BatchNorm2d(100) >>> sync_bn = flow.nn.SyncBatchNorm.convert_sync_batchnorm(bn).cuda() >>> input = flow.randn(20, 100, 35, 45, device="cuda") >>> output = sync_bn(input) """ def __init__( self, num_features: int, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, ) -> None: super().__init__(num_features, eps, momentum, affine, track_running_stats) def _check_input_dim(self, input): if input.dim() < 2: raise ValueError( "expected at least 2D input (got {}D input)".format(input.dim()) ) if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": if input.dim() == 3: self.channel_axis = 2 elif input.dim() == 4: self.channel_axis = 3 elif input.dim() == 5: self.channel_axis = 4 def _check_non_zero_input_channels(self, input): if input.size(1) == 0: raise ValueError( "SyncBatchNorm number of input channels should be non-zero" ) def forward(self, input): # currently only GPU input is supported if not input.is_cuda: raise ValueError("SyncBatchNorm expected input tensor to be on GPU") self._check_input_dim(input) self._check_non_zero_input_channels(input) if self.momentum is None: exponential_average_factor = 0.0 else: exponential_average_factor = self.momentum if self.training and self.track_running_stats: assert self.num_batches_tracked is not None self.num_batches_tracked.add_(1) if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / self.num_batches_tracked.item() else: # use exponential moving average exponential_average_factor = self.momentum r""" Decide whether the mini-batch stats should be used for normalization rather than the buffers. Mini-batch stats are used in training mode, and in eval mode when buffers are None. """ if self.training: bn_training = True else: bn_training = (self.running_mean is None) and (self.running_var is None) # Don't sync batchnorm stats in inference mode (model.eval()). need_sync = bn_training and self.training if need_sync: need_sync = flow.env.get_world_size() > 1 # # fallback to framework BN when synchronization is not necessary if not need_sync: return flow._C.normalization( input, self.running_mean, self.running_var, self.weight, self.bias, axis=self.channel_axis, epsilon=self.eps, momentum=exponential_average_factor, is_training=bn_training, ) else: assert bn_training global global_eps global global_momentum global global_world_size global global_axis global_eps = self.eps global_momentum = exponential_average_factor global_world_size = flow.env.get_world_size() global_axis = self.channel_axis assert ( self.track_running_stats ), "`track_running_stats` should be True when using SyncBatchNorm." return SyncBatchNormFunction.apply( input, self.weight, self.bias, self.running_mean, self.running_var, ) @classmethod def convert_sync_batchnorm(cls, module): r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to :class:`oneflow.nn.SyncBatchNorm` layers. Args: module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers Returns: The original :attr:`module` with the converted :class:`oneflow.nn.SyncBatchNorm` layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer, a new :class:`oneflow.nn.SyncBatchNorm` layer object will be returned instead. 
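        The conversion is applied recursively to every child module of :attr:`module`;
        for each converted layer the existing ``weight``, ``bias`` and running-statistics
        tensors are reused by the new :class:`oneflow.nn.SyncBatchNorm` instance rather
        than copied (see the implementation below).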
Example:: >>> import oneflow as flow >>> module = flow.nn.Sequential( flow.nn.Linear(20, 100), flow.nn.BatchNorm1d(100)).cuda() >>> sync_bn_module = flow.nn.SyncBatchNorm.convert_sync_batchnorm(module) """ module_output = module if isinstance(module, flow.nn.modules.batchnorm._BatchNorm): module_output = flow.nn.SyncBatchNorm( module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, ) if module.affine: with flow.no_grad(): module_output.weight = module.weight module_output.bias = module.bias module_output.running_mean = module.running_mean module_output.running_var = module.running_var module_output.num_batches_tracked = module.num_batches_tracked for name, child in module.named_children(): module_output.add_module(name, cls.convert_sync_batchnorm(child)) del module return module_output if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/batchnorm_fused.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Union import os import oneflow as flow from oneflow.nn.modules.module import Module class _FusedNormBase(Module): """Common base of _FusedBatchNorm""" def __init__( self, num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, ) -> None: super().__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: self.weight = flow.nn.Parameter(flow.Tensor(num_features)) self.bias = flow.nn.Parameter(flow.Tensor(num_features)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) if self.track_running_stats: self.register_buffer("running_mean", flow.Tensor(num_features)) self.register_buffer("running_var", flow.Tensor(num_features)) else: self.register_parameter("running_mean", None) self.register_parameter("running_var", None) self.reset_parameters() def reset_running_stats(self) -> None: if self.track_running_stats: self.running_mean.fill_(0) self.running_var.fill_(1) def reset_parameters(self) -> None: self.reset_running_stats() if self.affine: flow.nn.init.ones_(self.weight) flow.nn.init.zeros_(self.bias) def _check_input_dim(self, input): raise NotImplementedError def extra_repr(self): return "num_features={num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format( **self.__dict__ ) class _FusedBatchNorm(_FusedNormBase): def __init__( self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, ): super().__init__(num_features, eps, momentum, affine, track_running_stats) self.channel_axis = 1 def forward(self, x, addend=None): self._check_input_dim(x) if self.training: is_training = True else: is_training = (self.running_mean is None) and (self.running_var is None) return 
flow._C.normalization_add_relu( x, addend if addend is not None else None, self.running_mean if not self.training or self.track_running_stats else None, self.running_var if not self.training or self.track_running_stats else None, self.weight, self.bias, axis=self.channel_axis, epsilon=self.eps, momentum=self.momentum, is_training=is_training, ) class FusedBatchNorm1d(_FusedBatchNorm): """Applies Fused Batch Normalization over a 2D or 3D input, the formula is: .. math:: out = ReLU(BatchNorm(input) + addend) The formula of Batch Normalization is: .. math:: y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The mean and standard-deviation are calculated per-dimension over the mini-batches and :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors of size `C` (where `C` is the input size). By default, the elements of :math:`\\gamma` are set to 1 and the elements of :math:`\\beta` are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default :attr:`momentum` of 0.1. If :attr:`track_running_stats` is set to ``False``, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well. .. note:: This :attr:`momentum` argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`, where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. Because the Batch Normalization is done over the `C` dimension, computing statistics on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Can be set to ``None`` for cumulative moving average (i.e. simple average). Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics, and initializes statistics buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` Shape: - Input: :math:`(N, C)` or :math:`(N, C, L)` - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.Tensor(np.random.randn(20, 100)).to("cuda") # FusedBatchNorm support in GPU currently. 
>>> m = flow.nn.FusedBatchNorm1d(num_features=100, eps=1e-5, momentum=0.1).to("cuda") >>> y = m(x, addend=None) """ def _check_input_dim(self, input): if input.ndim != 2 and input.ndim != 3: raise ValueError( "expected 2D or 3D input (got {}D input)".format(input.ndim) ) class FusedBatchNorm2d(_FusedBatchNorm): """Applies Fused Batch Normalization over a 4D input, the formula is: .. math:: out = ReLU(BatchNorm(input) + addend) The formula of Batch Normalization is: .. math:: y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The mean and standard-deviation are calculated per-dimension over the mini-batches and :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors of size `C` (where `C` is the input size). By default, the elements of :math:`\\gamma` are set to 1 and the elements of :math:`\\beta` are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default :attr:`momentum` of 0.1. If :attr:`track_running_stats` is set to ``False``, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well. .. note:: This :attr:`momentum` argument is different from the one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`, where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. Because the Batch Normalization is done over the `C` dimension, computing statistics on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, H, W)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Can be set to ``None`` for cumulative moving average (i.e. simple average). Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics, and initializes statistics buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics in both training and eval modes. Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` - Output: :math:`(N, C, H, W)` (same shape as input) For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.Tensor(np.random.randn(4, 2, 8, 3)).to("cuda") # FusedBatchNorm is only supported on GPU currently. 
>>> m = flow.nn.FusedBatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to("cuda") >>> y = m(x, addend=None) """ def __init__( self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, ): super().__init__(num_features, eps, momentum, affine, track_running_stats) if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": self.channel_axis = 3 def _check_input_dim(self, input): if input.ndim != 4: raise ValueError("expected 4D input (got {}D input)".format(input.ndim)) class FusedBatchNorm3d(_FusedBatchNorm): r"""Applies Fused Batch Normalization over a 5D input, the formula is: .. math:: out = ReLU(BatchNorm(input) + addend) The formula of Batch Normalization is: .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta The mean and standard-deviation are calculated per-dimension over the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default :attr:`momentum` of 0.1. If :attr:`track_running_stats` is set to ``False``, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well. .. note:: This :attr:`momentum` argument is different from the one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. Because the Batch Normalization is done over the `C` dimension, computing statistics on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization or Spatio-temporal Batch Normalization. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, D, H, W)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Can be set to ``None`` for cumulative moving average (i.e. simple average). Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics, and initializes statistics buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics in both training and eval modes. Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` - Output: :math:`(N, C, D, H, W)` (same shape as input) For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.Tensor(np.random.randn(3, 2, 5, 8, 4)).to("cuda") # FusedBatchNorm is only supported on GPU currently. 
>>> m = flow.nn.FusedBatchNorm3d(num_features=2, eps=1e-5, momentum=0.1).to("cuda") >>> y = m(x, addend=None) """ def _check_input_dim(self, input): if input.ndim != 5: raise ValueError("expected 5D input (got {}D input)".format(input.ndim)) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/broadcast_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.nn.modules.utils import _single, _handle_size_arg def broadcast_shapes(*shapes): r"""broadcast_shapes(*shapes) -> Size The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.broadcast_shapes.html. Similar to :func:`oneflow.broadcast_tensors` but for shapes. This is equivalent to ``flow.broadcast_tensors(*map(flow.empty, shapes))[0].shape`` but avoids the need to create intermediate tensors. This is useful for broadcasting tensors of common batch shape but different rightmost shape, e.g. to broadcast mean vectors with covariance matrices. Args: \*shapes (flow.Size): Shapes of tensors. Returns: A shape compatible with all input shapes. Raises: RuntimeError: If shapes are incompatible. Example:: >>> import oneflow as flow >>> flow.broadcast_shapes((2,), (3, 1), (1, 1, 1)) oneflow.Size([1, 3, 2]) """ shapes = _single(shapes) return flow._C.broadcast_shapes(shapes) def broadcast_tensors(*tensors): r"""broadcast_tensors(*tensors) -> List of Tensors The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.broadcast_tensors.html. Broadcasts the given tensors according to ``broadcasting-semantics``. Args: *tensors: any number of tensors of the same type .. warning:: More than one element of a broadcasted tensor may refer to a single memory location. As a result, in-place operations (especially ones that are vectorized) may result in incorrect behavior. If you need to write to the tensors, please clone them first. Example:: >>> import oneflow as flow >>> x = flow.arange(3).view(1, 3) >>> y = flow.arange(2).view(2, 1) >>> a, b = flow.broadcast_tensors(x, y) >>> a.size() oneflow.Size([2, 3]) >>> a tensor([[0, 1, 2], [0, 1, 2]], dtype=oneflow.int64) """ tensors = _single(tensors) return flow._C.broadcast_tensors(tensors) def broadcast_to(input, shape): r"""broadcast_to(input, shape) -> Tensor The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.broadcast_to.html. Broadcasts ``input`` to the shape ``shape``. Equivalent to calling ``input.expand(shape)``. See :func:`oneflow.expand` for details. Args: input (oneflow.Tensor): the input tensor. shape (list, tuple, or oneflow.Size): the new shape. 
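.. note:: Broadcasting follows the usual rules: sizes are compared from the trailing dimension backwards, each dimension of ``input`` must either match the corresponding entry of ``shape`` or be 1, and ``shape`` may add extra leading dimensions.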
Example:: >>> import oneflow as flow >>> x = flow.tensor([1, 2, 3]) >>> flow.broadcast_to(x, (3, 3)) tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype=oneflow.int64) """ shape = _handle_size_arg(shape) shape = _single(shape) return flow._C.broadcast_to(input, shape) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/constant.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import List, Optional, Union import numpy as np import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.common_types import _size_any_t from oneflow.nn.modules.utils import _single, _handle_size_arg class _ConstantBase: def __init__( self, size: Union[_size_any_t, flow.Size], value: Union[float, int, complex], dtype: Optional[flow.dtype], device: Union[flow.device, int, str] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, requires_grad: bool = False, ) -> None: assert size is not None, "shape must not be None!" assert isinstance( size, (int, tuple, list, flow.Size) ), "shape should be int or tuple int!" self.device = device if isinstance(self.device, int): self.device = flow.device("cuda", self.device) if isinstance(self.device, str): self.device = flow.device(self.device) self.requires_grad = requires_grad size = _single(size) if dtype is None: dtype = flow.get_default_dtype() if placement is None: if device is None: self.device = flow.device("cpu") else: assert device is None self.placement = placement self.sbp = sbp if placement is not None: assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp if isinstance(self.sbp, flow.sbp.sbp): self.sbp = (self.sbp,) else: for elem in sbp: assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp assert len(self.sbp) == len(placement.ranks.shape) else: assert sbp is None, "sbp: %s" % sbp self.shape = size self.value = value self.dtype = dtype def forward(self): if self.placement is not None: if isinstance(self.value, flow.Tensor): assert ( self.value.ndim <= 1 and self.value.numel() == 1 ), "Only tensor with single element or scalar tensor are supported as value!" res = flow._C.global_tensor_constant( self.shape, self.value, dtype=self.dtype, placement=self.placement, sbp=self.sbp, ) else: res = flow._C.global_constant( self.shape, self.value, dtype=self.dtype, placement=self.placement, sbp=self.sbp, ) else: if isinstance(self.value, flow.Tensor): assert ( self.value.ndim <= 1 and self.value.numel() == 1 ), "Only tensor with single element or scalar tensor are supported as value!" 
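# Local (non-global) path: a 0-dim / single-element Tensor used as the fill value goes through flow._C.tensor_constant below, while a plain Python scalar is handled by flow._C.constant in the else branch.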
res = flow._C.tensor_constant( self.shape, self.value, dtype=self.dtype, device=self.device ) else: res = flow._C.constant( self.shape, self.value, dtype=self.dtype, device=self.device ) res.requires_grad = self.requires_grad return res def _handle_meta_args( input, size: Union[_size_any_t, List[int], flow.Size, None] = None, dtype: Optional[flow.dtype] = None, device: Union[flow.device, str, None] = None, placement: flow.placement = None, sbp: Union[ flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None ] = None, requires_grad: bool = False, ): if isinstance(device, str): device = flow.device(device) if size is None: new_size = input.shape else: new_size = _handle_size_arg(size) if dtype is None: new_dtype = input.dtype else: new_dtype = dtype new_device = device new_placement = placement new_sbp = sbp new_requires_grad = requires_grad if new_device is not None: assert ( new_placement is None ), "argument 'placement' must be None when argument 'device' exist" assert ( new_sbp is None ), "argument 'sbp' must be None when argument 'device' exist" elif new_device is None and new_placement is None and new_sbp is None: new_device = input.device if input.is_local else None new_placement = input.placement if input.is_global else None new_sbp = input.sbp if input.is_global else None else: if new_placement is None and new_sbp is not None: assert ( input.is_global ), "argument 'placement' must not be None when argument 'sbp' exist and Tensor is local" new_placement = input.placement elif new_placement is not None and new_sbp is None: assert ( input.is_global ), "argument 'sbp' must not be None when argument 'placement' exist and Tensor is local" new_sbp = input.sbp assert isinstance( new_size, (int, tuple, list, flow.Size) ), f"argument 'size' must be tuple of ints, not %s" % (type(new_size)) assert isinstance( new_dtype, flow.dtype ), f"argument 'dtype' must be flow.dtype, not %s" % (type(new_dtype)) if new_placement is not None: assert isinstance( new_placement, flow.placement ), f"argument 'placement' must be flow.placement, not %s" % ( type(new_placement) ) assert isinstance( new_sbp, (flow.sbp.sbp, tuple) ), f"argument 'sbp' must be flow.sbp.sbp, not %s" % (type(new_sbp)) else: assert isinstance( new_device, (str, flow.device) ), f"argument 'device' must be flow.device, not %s" % (type(new_device)) assert isinstance( new_requires_grad, bool ), f"argument 'requires_grad' must be bool, not %s" % (type(new_requires_grad)) return new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad class Ones(_ConstantBase): def __init__( self, size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False, ): super().__init__(size, 1, dtype, device, placement, sbp, requires_grad) def ones_op( *size: Union[_size_any_t, flow.Size, List[int]], dtype: Optional[flow.dtype] = None, device: Union[flow.device, str, None] = None, placement: flow.placement = None, sbp: Union[ flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None ] = None, requires_grad: bool = False, ): """ Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument `size`. Args: size (an integer or tuple of integer values): defining the shape of the output tensor. Can be \\ a variable number of arguments or a collection like a list or tuple. dtype (flow.dtype, optional): the desired data type of returned tensor. device (flow.device, optional): the desired device of returned tensor. 
Default: if None, uses the current device for the default tensor type placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> y = flow.ones(5) >>> y tensor([1., 1., 1., 1., 1.], dtype=oneflow.float32) >>> y = flow.ones(2,3) # construct local tensor >>> y tensor([[1., 1., 1.], [1., 1., 1.]], dtype=oneflow.float32) >>> placement = flow.placement("cpu", ranks=[0]) >>> y = flow.ones(4, 5, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor >>> y.is_global True """ size = _handle_size_arg(size) return Ones(size, dtype, device, placement, sbp, requires_grad).forward() def ones_like_op( input, dtype: Optional[flow.dtype] = None, device: Union[flow.device, str, None] = None, placement: flow.placement = None, sbp: Union[ flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None ] = None, requires_grad: bool = False, ): ( new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad, ) = _handle_meta_args(input, None, dtype, device, placement, sbp, requires_grad) return Ones( new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad ).forward() class Zeros(_ConstantBase): def __init__( self, size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False, ): super().__init__(size, 0, dtype, device, placement, sbp, requires_grad) def zeros_op( *size: Union[_size_any_t, flow.Size, List[int]], dtype: Optional[flow.dtype] = None, device: Union[flow.device, str, None] = None, placement: flow.placement = None, sbp: Union[ flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None ] = None, requires_grad: bool = False, ): """ Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument `size`. Args: size(an integer or tuple of integer values) - defining the shape of the output tensor. Can be \\ a variable number of arguments or a collection like a list or tuple. dtype (flow.dtype, optional): the desired data type of returned tensor. device (flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. 
code-block:: python >>> import oneflow as flow >>> y = flow.zeros(5) >>> y tensor([0., 0., 0., 0., 0.], dtype=oneflow.float32) >>> y = flow.zeros(2,3) >>> y tensor([[0., 0., 0.], [0., 0., 0.]], dtype=oneflow.float32) """ size = _handle_size_arg(size) return Zeros(size, dtype, device, placement, sbp, requires_grad).forward() def zeros_like_op( input, dtype: Optional[flow.dtype] = None, device: Union[flow.device, str, None] = None, placement: flow.placement = None, sbp: Union[ flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None ] = None, requires_grad: bool = False, ): ( new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad, ) = _handle_meta_args(input, None, dtype, device, placement, sbp, requires_grad) return Zeros( new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad ).forward() class Full(_ConstantBase): def __init__( self, size, value, dtype, device=None, placement=None, sbp=None, requires_grad=False, ): super().__init__(size, value, dtype, device, placement, sbp, requires_grad) def full_op( size: Union[_size_any_t, flow.Size], fill_value: Union[float, int, complex], dtype: Optional[flow.dtype] = None, device: Union[flow.device, str, None] = None, placement: flow.placement = None, sbp: Union[ flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None ] = None, requires_grad: bool = False, ): """ Creates a tensor of size `size` filled with `fill_value`. The tensor’s dtype is inferred from `fill_value` when `dtype` is not given. Args: size(int...): a list, tuple, or oneflow.Size of integers defining the shape of the output tensor. fill_value(Scalar): the value to fill the output tensor with. dtype (oneflow.dtype, optional): the desired data type of returned tensor. device (oneflow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type placement (oneflow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> y = flow.full((5,),5) >>> y tensor([5, 5, 5, 5, 5], dtype=oneflow.int64) >>> y = flow.full((2,3),5.0) # construct local tensor >>> y tensor([[5., 5., 5.], [5., 5., 5.]], dtype=oneflow.float32) >>> placement = flow.placement("cpu", ranks=[0]) >>> y = flow.full((2,3), 5.0, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor >>> y.is_global True """ size = _handle_size_arg(size) if not isinstance(fill_value, (int, float, complex, flow.Tensor)): # handle numpy scalar dtype assert isinstance( fill_value.dtype, (np.dtype) ), "fill_value must be a Python scalar or a NumPy scalar." 
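# A NumPy scalar reaches this point; .item() converts it to the corresponding Python scalar before dtype inference and the constant op below.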
fill_value = fill_value.item() if dtype is None: dtype = flow.tensor(fill_value).dtype return Full( size, fill_value, dtype, device, placement, sbp, requires_grad ).forward() def full_like_op( input, fill_value, dtype: Optional[flow.dtype] = None, device: Union[flow.device, str, None] = None, placement: flow.placement = None, sbp: Union[ flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None ] = None, requires_grad: bool = False, ): """ full_like(input, fill_value, \*, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. ``oneflow.full_like(input, fill_value)`` is equivalent to ``oneflow.full(input.size(), fill_value, dtype=input.dtype, device=input.device)``. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.full_like.html. Args: input(oneflow.Tensor): the size of input will determine the size of the output tensor. fill_value(Scalar): the value to fill the output tensor with. dtype (oneflow.dtype, optional): the desired data type of returned tensor. device (oneflow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type placement (oneflow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.randn(2, 3) >>> y = flow.full_like(x, 2.0) >>> y tensor([[2., 2., 2.], [2., 2., 2.]], dtype=oneflow.float32) >>> y = flow.full_like(x, 2, dtype=flow.int32) >>> y tensor([[2, 2, 2], [2, 2, 2]], dtype=oneflow.int32) >>> placement = flow.placement("cpu", ranks=[0]) >>> y = flow.full_like(x, 5.0, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor >>> y.is_global True """ ( new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad, ) = _handle_meta_args(input, None, dtype, device, placement, sbp, requires_grad) return Full( new_size, fill_value, new_dtype, new_device, new_placement, new_sbp, new_requires_grad, ).forward() def new_ones_op( x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False ): ( new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad, ) = _handle_meta_args(x, size, dtype, device, placement, sbp, requires_grad) if new_placement is not None: res = flow._C.global_constant( new_size, 1.0, dtype=new_dtype, placement=new_placement, sbp=new_sbp ) else: res = flow._C.constant(new_size, 1.0, dtype=new_dtype, device=new_device) res.requires_grad = new_requires_grad return res def new_zeros_op( x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False ): ( new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad, ) = _handle_meta_args(x, size, dtype, device, placement, sbp, requires_grad) if new_placement is not None: res = flow._C.global_constant( new_size, 0.0, dtype=new_dtype, placement=new_placement, sbp=new_sbp ) else: res = flow._C.constant(new_size, 0.0, dtype=new_dtype, device=new_device) res.requires_grad = new_requires_grad return res def new_full_op( x, size, fill_value, dtype=None, 
device=None, placement=None, sbp=None, requires_grad=False, ): size = _handle_size_arg(size) ( new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad, ) = _handle_meta_args(x, size, dtype, device, placement, sbp, requires_grad) if flow.is_tensor(fill_value): assert ( len(fill_value.size()) == 0 ), "new_full(): argument 'fill_value' must be Number, not Tensor" fill_value = fill_value.item() if new_placement is not None: res = flow._C.global_constant( new_size, fill_value, dtype=new_dtype, placement=new_placement, sbp=new_sbp ) else: res = flow._C.constant(new_size, fill_value, dtype=new_dtype, device=new_device) res.requires_grad = new_requires_grad return res if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/container.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.nn.utils.container import * from oneflow.nn.modules.module import Module class Sequential(get_seq(Module)): """A sequential container. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Sequential.html?#torch.nn.Sequential. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in. To make it easier to understand, here is a small example: .. code-block:: python >>> import oneflow.nn as nn >>> from collections import OrderedDict >>> nn.Sequential(nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU()) #doctest: +ELLIPSIS Sequential( (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) (1): ReLU() (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) (3): ReLU() ) >>> nn.Sequential(OrderedDict([ ... ('conv1', nn.Conv2d(1,20,5)), ... ('relu1', nn.ReLU()), ... ('conv2', nn.Conv2d(20,64,5)), ... ('relu2', nn.ReLU()) ... ])) #doctest: +ELLIPSIS Sequential( (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) (relu1): ReLU() (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) (relu2): ReLU() ) """ pass class ModuleList(get_list(Module)): """Holds submodules in a list. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ModuleList.html?#torch.nn.ModuleList. :class:`~oneflow.nn.ModuleList` can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all :class:`~oneflow.nn.Module` methods. Args: modules (iterable, optional): an iterable of modules to add .. code-block:: python >>> import oneflow.nn as nn >>> class MyModule(nn.Module): ... def __init__(self): ... super(MyModule, self).__init__() ... self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) ... def forward(self, x): ... # ModuleList can act as an iterable, or be indexed using ints ... for i, l in enumerate(self.linears): ... 
x = self.linears[i // 2](x) + l(x) ... return x >>> model = MyModule() >>> model.linears ModuleList( (0): Linear(in_features=10, out_features=10, bias=True) (1): Linear(in_features=10, out_features=10, bias=True) (2): Linear(in_features=10, out_features=10, bias=True) (3): Linear(in_features=10, out_features=10, bias=True) (4): Linear(in_features=10, out_features=10, bias=True) (5): Linear(in_features=10, out_features=10, bias=True) (6): Linear(in_features=10, out_features=10, bias=True) (7): Linear(in_features=10, out_features=10, bias=True) (8): Linear(in_features=10, out_features=10, bias=True) (9): Linear(in_features=10, out_features=10, bias=True) ) """ pass class ModuleDict(get_dict(Module)): """Holds submodules in a dictionary. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ModuleDict.html?#torch.nn.ModuleDict. :class:`~oneflow.nn.ModuleDict` can be indexed like a regular Python dictionary, but modules it contains are properly registered, and will be visible by all :class:`~oneflow.nn.Module` methods. :class:`~oneflow.nn.ModuleDict` is an **ordered** dictionary that respects * the order of insertion, and * in :meth:`~oneflow.nn.ModuleDict.update`, the order of the merged ``OrderedDict``, ``dict`` (started from Python 3.6) or another :class:`~oneflow.nn.ModuleDict` (the argument to :meth:`~oneflow.nn.ModuleDict.update`). Note that :meth:`~oneflow.nn.ModuleDict.update` with other unordered mapping types (e.g., Python's plain ``dict`` before Python version 3.6) does not preserve the order of the merged mapping. Args: modules (iterable, optional): a mapping (dictionary) of (string: module) or an iterable of key-value pairs of type (string, module) .. code-block:: python >>> import oneflow.nn as nn >>> class MyModule(nn.Module): ... def __init__(self): ... super(MyModule, self).__init__() ... self.choices = nn.ModuleDict({ ... 'conv': nn.Conv2d(10, 10, 3), ... 'pool': nn.MaxPool2d(3) ... }) ... self.activations = nn.ModuleDict([ ... ['lrelu', nn.LeakyReLU()], ... ['prelu', nn.PReLU()] ... ]) ... def forward(self, x, choice, act): ... x = self.choices[choice](x) ... x = self.activations[act](x) ... return x >>> model = MyModule() >>> model.choices ModuleDict( (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)) (pool): MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=(0, 0), dilation=(1, 1)) ) """ pass class ParameterList(get_para_list(Module)): """Holds parameters in a list. :class:`~oneflow.nn.ParameterList` can be indexed like a regular Python list, but parameters it contains are properly registered, and will be visible by all :class:`~oneflow.nn.Module` methods. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ParameterList.html?#torch.nn.ParameterList. Args: parameters (iterable, optional): an iterable of :class:`~oneflow.nn.Parameter` to add .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> class MyModule(nn.Module): ... def __init__(self): ... super(MyModule, self).__init__() ... self.params = nn.ParameterList([nn.Parameter(flow.randn(10, 10)) for i in range(10)]) ... ... def forward(self, x): ... # ParameterList can act as an iterable, or be indexed using ints ... for i, p in enumerate(self.params): ... x = self.params[i // 2].mm(x) + p.mm(x) ... 
return x >>> model = MyModule() >>> model.params ParameterList( (0): Parameter containing: [ of size 10x10] (1): Parameter containing: [ of size 10x10] (2): Parameter containing: [ of size 10x10] (3): Parameter containing: [ of size 10x10] (4): Parameter containing: [ of size 10x10] (5): Parameter containing: [ of size 10x10] (6): Parameter containing: [ of size 10x10] (7): Parameter containing: [ of size 10x10] (8): Parameter containing: [ of size 10x10] (9): Parameter containing: [ of size 10x10] ) """ pass class ParameterDict(get_para_dict(Module)): """ Holds parameters in a dictionary. ParameterDict can be indexed like a regular Python dictionary, but parameters it contains are properly registered, and will be visible by all Module methods. :class:`~oneflow.nn.ParameterDict` is an **ordered** dictionary that respects * the order of insertion, and * in :meth:`~oneflow.nn.ParameterDict.update`, the order of the merged ``OrderedDict`` or another :class:`~oneflow.nn.ParameterDict` (the argument to :meth:`~oneflow.nn.ParameterDict.update`). Note that :meth:`~oneflow.nn.ParameterDict.update` with other unordered mapping types (e.g., Python's plain ``dict``) does not preserve the order of the merged mapping. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ParameterDict.html?#torch.nn.ParameterDict. Args: parameters (iterable, optional): a mapping (dictionary) of (string : :class:`~oneflow.nn.Parameter`) or an iterable of key-value pairs of type (string, :class:`~oneflow.nn.Parameter`) .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> class MyModule(nn.Module): ... def __init__(self): ... super(MyModule, self).__init__() ... self.params = nn.ParameterDict({ ... 'left': nn.Parameter(flow.randn(5, 10)), ... 'right': nn.Parameter(flow.randn(5, 10)) ... }) ... ... def forward(self, x, choice): ... x = self.params[choice].mm(x) ... return x >>> model = MyModule() >>> model.params ParameterDict( (left): Parameter containing: [ of size 5x10] (right): Parameter containing: [ of size 5x10] ) """ pass if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/conv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import math import os import oneflow as flow from oneflow.nn import init from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t from oneflow.nn.modules.module import Module from oneflow.nn.modules.utils import _pair, _single, _triple from typing import Union def slice(x, begin, size): ndim = len(x.shape) if not isinstance(begin, (list, tuple)) or len(begin) != ndim: raise ValueError( "begin must be a list/tuple with the same length as input tensor's number of dimensions" ) if not all((isinstance(b, int) or b is None for b in begin)): raise ValueError("element of begin must be a int or None") if not isinstance(size, (list, tuple)) or len(size) != ndim: raise ValueError( "size must be a list/tuple with the same length as input tensor's number of dimensions." ) if not all((isinstance(s, int) or s is None for s in size)): raise ValueError("element of size must be a int or None") slice_tup_list = [] for (b, s, dim_size) in zip(begin, size, x.shape): (start, stop, step) = (None, None, 1) if b is not None: if b < -dim_size or b >= dim_size: raise ValueError("element of begin is out of range") start = b if s is not None: if s == -1: stop = dim_size else: if s <= 0 or s > dim_size: raise ValueError("element of size is invalid") if b + s < dim_size: stop = b + s slice_tup_list.append((start, stop, step)) return flow.slice(x, slice_tup_list) class ConvUtil(object): @classmethod def split(cls, x, axis, split_num): split_len = x.shape[axis] // split_num result_list = [] slice_begin = [0] * len(x.shape) slice_size = [-1] * len(x.shape) slice_size[axis] = split_len for i in range(split_num): slice_begin[axis] = i * split_len result = slice(x, slice_begin, slice_size) result_list.append(result) return result_list def get_padding(padding, kernel_size, dilation, stride): valid_padding_strings = {"same", "valid"} if isinstance(padding, str): if padding not in valid_padding_strings: raise ValueError( "Invalid padding string {!r}, should be one of {}".format( padding, valid_padding_strings ) ) if padding == "same" and any(s != 1 for s in list(stride)): raise ValueError("padding='same' is not supported for strided convolutions") out_padding = [0] * len(kernel_size) if padding == "same": for d, k, i in zip(dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)): total_padding = d * (k - 1) left_pad = total_padding // 2 out_padding[i] = left_pad return out_padding class Conv1d(Module): """Applies a 1D convolution over an input signal composed of several input planes. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Conv1d.html. In the simplest case, the output value of the layer with input size :math:`(N, C_{\\text{in}}, L)` and output :math:`(N, C_{\\text{out}}, L_{\\text{out}})` can be precisely described as: .. math:: \\text{out}(N_i, C_{\\text{out}_j}) = \\text{bias}(C_{\\text{out}_j}) + \\sum_{k = 0}^{C_{in} - 1} \\text{weight}(C_{\\text{out}_j}, k) \\star \\text{input}(N_i, k) where :math:`\\star` is the valid `cross-correlation`_ operator, :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`L` is a length of signal sequence. * :attr:`stride` controls the stride for the cross-correlation, a single number or a one-element tuple. * :attr:`padding` controls the amount of padding applied to the input. It can be either a string {{'valid', 'same'}} or a tuple of ints giving the amount of implicit padding applied on both sides. 
* :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. Note: ``padding='valid'`` is the same as no padding. ``padding='same'`` pads the input so the output has the shape as the input. However, this mode doesn't support any stride values other than 1. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to both sides of the input. Default: 0 padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` Shape: - Input: :math:`(N, C_{in}, L_{in})` - Output: :math:`(N, C_{out}, L_{out})` where .. math:: L_{out} = \\left\\lfloor\\frac{L_{in} + 2 \\times \\text{padding} - \\text{dilation} \\times (\\text{kernel\\_size} - 1) - 1}{\\text{stride}} + 1\\right\\rfloor Attributes: weight (Tensor): the learnable weights of the module of shape :math:`(\\text{out\\_channels}, \\frac{\\text{in\\_channels}}{\\text{groups}}, \\text{kernel\\_size})`. The values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{in} * \\text{kernel\\_size}}` bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{in} * \\text{kernel\\_size}}` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> arr = np.random.randn(20, 16, 50) >>> input = flow.Tensor(arr) >>> m = nn.Conv1d(16, 33, 3, stride=2) >>> output = m(input) .. _cross-correlation: https://en.wikipedia.org/wiki/Cross-correlation .. 
_link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_1_t, stride: _size_1_t = 1, padding: Union[str, _size_1_t] = 0, dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", device=None, dtype=None, ): super().__init__() assert padding_mode == "zeros" self.padding_mode = padding_mode self.kernel_size = _single(kernel_size) self.stride = _single(stride) self.dilation = _single(dilation) self.padding = ( get_padding(padding, self.kernel_size, self.dilation, self.stride) if isinstance(padding, str) else _single(padding) ) self.groups = groups self.channel_pos = "channels_first" assert in_channels % groups == 0 assert out_channels % groups == 0 self.in_channels = in_channels self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.empty( out_channels, in_channels // groups, *self.kernel_size, dtype=dtype, device=device ) ) self.out_channel_groups = out_channels // groups self.bias = None if bias: self.bias = flow.nn.Parameter( flow.empty(out_channels, dtype=dtype, device=device) ) self.reset_parameters() def reset_parameters(self) -> None: init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def _conv_forward(self, x, weight, bias): return flow._C.conv1d( x, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, channel_pos=self.channel_pos, ) def forward(self, x): return self._conv_forward(x, self.weight, self.bias) def extra_repr(self): s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" if self.padding != (0,) * len(self.padding): s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): s += ", dilation={dilation}" if self.groups != 1: s += ", groups={groups}" if self.bias is None: s += ", bias=False" if self.padding_mode != "zeros": s += ", padding_mode={padding_mode}" return s.format(**self.__dict__) class Conv2d(Module): """Applies a 2D convolution over an input signal composed of several input planes. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Conv2d.html. In the simplest case, the output value of the layer with input size :math:`(N, C_{\\text{in}}, H, W)` and output :math:`(N, C_{\\text{out}}, H_{\\text{out}}, W_{\\text{out}})` can be precisely described as: .. math:: \\text{out}(N_i, C_{\\text{out}_j}) = \\text{bias}(C_{\\text{out}_j}) + \\sum_{k = 0}^{C_{\\text{in}} - 1} \\text{weight}(C_{\\text{out}_j}, k) \\star \\text{input}(N_i, k) where :math:`\\star` is the valid 2D `cross-correlation`_ operator, :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`H` is a height of input planes in pixels, and :math:`W` is width in pixels. * :attr:`stride` controls the stride for the cross-correlation, a single number or a tuple. * :attr:`padding` controls the amount of implicit padding on both sides for :attr:`padding` number of points for each dimension. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. * :attr:`groups` controls the connections between inputs and outputs. 
:attr:`in_channels` and :attr:`out_channels` must both be divisible by :attr:`groups`. For example, * At groups=1, all inputs are convolved to all outputs. * At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated. * At groups= :attr:`in_channels`, each input channel is convolved with its own set of filters (of size :math:`\\frac{\\text{out_channels}}{\\text{in_channels}}`)., The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, and the second `int` for the width dimension Note: When `groups == in_channels` and `out_channels == K * in_channels`, where `K` is a positive integer, this operation is also known as a "depthwise convolution". In other words, for an input of size :math:`(N, C_{in}, L_{in})`, a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments :math:`(C_\\text{in}=C_\\text{in}, C_\\text{out}=C_\\text{in} \\times \\text{K}, ..., \\text{groups}=C_\\text{in})`. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where .. math:: H_{out} = \\left\\lfloor\\frac{H_{in} + 2 \\times \\text{padding}[0] - \\text{dilation}[0] \\times (\\text{kernel_size}[0] - 1) - 1}{\\text{stride}[0]} + 1\\right\\rfloor .. math:: W_{out} = \\left\\lfloor\\frac{W_{in} + 2 \\times \\text{padding}[1] - \\text{dilation}[1] \\times (\\text{kernel_size}[1] - 1) - 1}{\\text{stride}[1]} + 1\\right\\rfloor Attr: - weight (Tensor): the learnable weights of the module of shape :math:`(\\text{out_channels}, \\frac{\\text{in_channels}}{\\text{groups}},` :math:`\\text{kernel_size[0]}, \\text{kernel_size[1]})`. The values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{1}\\text{kernel_size}[i]}` - bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{1}\\text{kernel_size}[i]}` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> arr = np.random.randn(20, 16, 50, 100) >>> input = flow.Tensor(arr) >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) >>> output = m(input) .. _cross-correlation: https://en.wikipedia.org/wiki/Cross-correlation .. 
_link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, padding: Union[str, _size_2_t] = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", device=None, dtype=None, ): super().__init__() assert padding_mode == "zeros" self.padding_mode = padding_mode self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) self.dilation = _pair(dilation) self.padding = ( get_padding(padding, self.kernel_size, self.dilation, self.stride) if isinstance(padding, str) else _pair(padding) ) self.groups = groups self.transposed = False if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": self.channel_pos = "channels_last" self.transposed = True else: self.channel_pos = "channels_first" assert in_channels % groups == 0 assert out_channels % groups == 0 self.in_channels = in_channels self.out_channels = out_channels if self.channel_pos == "channels_first": self.weight = flow.nn.Parameter( flow.empty( out_channels, in_channels // groups, *self.kernel_size, device=device, dtype=dtype ) ) else: self.weight = flow.nn.Parameter( flow.empty( out_channels, *self.kernel_size, in_channels // groups, device=device, dtype=dtype ) ) self.out_channel_groups = out_channels // groups self.bias = None if bias: self.bias = flow.nn.Parameter( flow.empty(out_channels, device=device, dtype=dtype) ) self.reset_parameters() def reset_parameters(self) -> None: init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def to_memory_format(self, memory_format) -> None: if self.channel_pos == "channels_first" and memory_format is flow.channels_last: self.channel_pos = "channels_last" with flow.no_grad(): self.weight.data = self.weight.to(memory_format=flow.channels_last) elif ( self.channel_pos == "channels_last" and memory_format is flow.contiguous_format ): self.channel_pos = "channels_first" with flow.no_grad(): self.weight.data = self.weight.to(memory_format=flow.contiguous_format) def _conv_forward(self, x, weight, bias): return flow._C.conv2d( x, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, channel_pos=self.channel_pos, ) def forward(self, x): return self._conv_forward(x, self.weight, self.bias) def extra_repr(self): s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" if self.padding != (0,) * len(self.padding): s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): s += ", dilation={dilation}" if self.groups != 1: s += ", groups={groups}" if self.bias is None: s += ", bias=False" if self.padding_mode != "zeros": s += ", padding_mode={padding_mode}" return s.format(**self.__dict__) class Conv3d(Module): r"""Applies a 3D convolution over an input signal composed of several input planes. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Conv3d.html. In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as: .. 
math:: out(N_i, C_{out_j}) = bias(C_{out_j}) + \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) where :math:`\star` is the valid 3D `cross-correlation`_ operator * :attr:`stride` controls the stride for the cross-correlation. * :attr:`padding` controls the amount of padding applied to the input. It can be either a string {{'valid', 'same'}} or a tuple of ints giving the amount of implicit padding applied on both sides. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimension - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, the second `int` for the height dimension and the third `int` for the width dimension Note: ``padding='valid'`` is the same as no padding. ``padding='same'`` pads the input so the output has the shape as the input. However, this mode doesn't support any stride values other than 1. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all six sides of the input. Default: 0 padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where .. math:: D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor .. math:: H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor .. math:: W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor Attributes: weight (Tensor): the learnable weights of the module of shape :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. The values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> arr = np.random.randn(1, 2, 5, 5, 5) >>> input = flow.Tensor(arr) >>> m = nn.Conv3d(2, 4, kernel_size=3, stride=1) >>> output = m(input) .. 
_cross-correlation: https://en.wikipedia.org/wiki/Cross-correlation .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_3_t, stride: _size_3_t = 1, padding: Union[str, _size_3_t] = 0, dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", # TODO: refine this type device=None, dtype=None, ): super().__init__() assert padding_mode == "zeros" self.padding_mode = padding_mode self.kernel_size = _triple(kernel_size) self.stride = _triple(stride) self.dilation = _triple(dilation) self.padding = ( get_padding(padding, self.kernel_size, self.dilation, self.stride) if isinstance(padding, str) else _triple(padding) ) self.groups = groups self.channel_pos = "channels_first" assert in_channels % groups == 0, "in_channels must be divisible by groups" assert out_channels % groups == 0, "out_channels must be divisible by groups" self.in_channels = in_channels self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.empty( out_channels, in_channels // groups, *self.kernel_size, device=device, dtype=dtype ) ) self.out_channel_groups = out_channels // groups self.bias = None if bias: self.bias = flow.nn.Parameter( flow.empty(out_channels, device=device, dtype=dtype) ) self.reset_parameters() def reset_parameters(self) -> None: init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def to_memory_format(self, memory_format) -> None: if self.channel_pos == "channels_first" and memory_format is flow.channels_last: self.channel_pos = "channels_last" with flow.no_grad(): self.weight.data = self.weight.to(memory_format=flow.channels_last) elif ( self.channel_pos == "channels_last" and memory_format is flow.contiguous_format ): self.channel_pos = "channels_first" with flow.no_grad(): self.weight.data = self.weight.to(memory_format=flow.contiguous_format) def _conv_forward(self, x, weight, bias): return flow._C.conv3d( x, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, channel_pos=self.channel_pos, ) def forward(self, x): return self._conv_forward(x, self.weight, self.bias) def extra_repr(self): s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" if self.padding != (0,) * len(self.padding): s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): s += ", dilation={dilation}" if self.groups != 1: s += ", groups={groups}" if self.bias is None: s += ", bias=False" if self.padding_mode != "zeros": s += ", padding_mode={padding_mode}" return s.format(**self.__dict__) class ConvTranspose1d(Module): r"""Applies a 1D transposed convolution operator over an input image composed of several input planes. This module can be seen as the gradient of Conv1d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation). This module supports TensorFloat32. * :attr:`stride` controls the stride for the cross-correlation. * :attr:`padding` controls the amount of implicit zero padding on both sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note below for details. * :attr:`output_padding` controls the additional size added to one side of the output shape. See note below for details. 
* :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. Note: The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` amount of zero padding to both sides of the input. This is set so that when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d` are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when ``stride > 1``, :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output shape. :attr:`output_padding` is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that :attr:`output_padding` is only used to find output shape, but does not actually add zero-padding to output. Note: In some circumstances when using the CUDA backend with CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting ``torch.backends.cudnn.deterministic = True``. Please see the notes on randomness for background. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both sides of the input. Default: 0 output_padding (int or tuple, optional): Additional size added to one side of the output shape. Default: 0 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 Shape: - Input: :math:`(N, C_{in}, L_{in})` - Output: :math:`(N, C_{out}, L_{out})` where .. math:: L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation} \times (\text{kernel_size} - 1) + \text{output_padding} + 1 Attributes: weight (Tensor): the learnable weights of the module of shape :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` :math:`\text{kernel\_size})`. The values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` .. _cross-correlation: https://en.wikipedia.org/wiki/Cross-correlation ..
_link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_1_t, stride: _size_1_t = 1, padding: _size_1_t = 0, output_padding: _size_1_t = 0, groups: int = 1, bias: bool = True, dilation: _size_1_t = 1, padding_mode: str = "zeros", ) -> None: super().__init__() assert ( padding_mode == "zeros" ), "Only `zeros` padding mode is supported for ConvTranspose1d" self.kernel_size = _single(kernel_size) self.stride = _single(stride) self.padding = _single(padding) self.dilation = _single(dilation) self.output_padding = _single(output_padding) self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 self.in_channels = in_channels self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.Tensor(in_channels, out_channels // groups, *self.kernel_size) ) self.filters = out_channels self.bias = None self._bias_add_op = None if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) self.reset_parameters() def reset_parameters(self) -> None: init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def forward(self, x): return flow._C.deconv1d( x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.groups, self.dilation, "channels_first", ) class ConvTranspose2d(Module): """ Applies a 2D transposed convolution operator over an input image composed of several input planes. This module can be seen as the gradient of Conv2d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation). Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both sides of each dimension in the input. Default: 0 output_padding (int or tuple, optional): Additional size added to one side of each dimension in the output shape. Default: 0 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where .. math:: H_{out} = (H_{in} - 1) \\times \\text{stride}[0] - 2 \\times \\text{padding}[0] + \\text{dilation}[0] \\times (\\text{kernel_size}[0] - 1) + \\text{output_padding}[0] + 1 .. math:: W_{out} = (W_{in} - 1) \\times \\text{stride}[1] - 2 \\times \\text{padding}[1] + \\text{dilation}[1] \\times (\\text{kernel_size}[1] - 1) + \\text{output_padding}[1] + 1 Attributes: ConvTranspose2d.weight (Tensor): the learnable weights of the module of shape :math:`(\\text{in_channels}, \\frac{\\text{out_channels}}{\\text{groups}},` :math:`\\text{kernel_size[0]}, \\text{kernel_size[1]})`. 
The values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{out} * \\prod_{i=0}^{1}\\text{kernel_size}[i]}` ConvTranspose2d.bias (Tensor): the learnable bias of the module of shape (out_channels) If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{out} * \\prod_{i=0}^{1}\\text{kernel_size}[i]}` Examples:: >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2) >>> # non-square kernels and unequal stride and with padding >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) >>> m = m.to("cuda") >>> input = flow.Tensor(np.random.randn(20, 16, 50, 100), device=flow.device("cuda")) >>> output = m(input) >>> output.size() oneflow.Size([20, 33, 93, 100]) .. _cross-correlation: https://en.wikipedia.org/wiki/Cross-correlation .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, padding: _size_2_t = 0, output_padding: _size_2_t = 0, groups: int = 1, bias: bool = True, dilation: int = 1, padding_mode: str = "zeros", ) -> None: super().__init__() assert padding_mode == "zeros" self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) self.padding = _pair(padding) self.output_padding = _pair(output_padding) self.dilation = _pair(dilation) self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 self.in_channels = in_channels self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.Tensor(in_channels, out_channels // groups, *self.kernel_size) ) self.in_channel_groups = in_channels // groups self.filters = out_channels self.bias = None self._bias_add_op = None if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) self.reset_parameters() def reset_parameters(self) -> None: init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def forward(self, x): res = flow._C.deconv2d( x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.groups, self.dilation, "channels_first", ) return res class ConvTranspose3d(Module): r""" Applies a 3D transposed convolution operator over an input image composed of several input planes. The transposed convolution operator multiplies each input value element-wise by a learnable kernel, and sums over the outputs from all input feature planes. This module can be seen as the gradient of Conv3d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation). This module supports TensorFloat32. * :attr:`stride` controls the stride for the cross-correlation. * :attr:`padding` controls the amount of implicit zero padding on both sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note below for details. * :attr:`output_padding` controls the additional size added to one side of the output shape. See note below for details. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. 
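As a minimal, shape-only sketch of the interplay described above (the sizes follow from the shape formulas listed further below, with the default ``dilation``), ``output_padding`` picks out the input size that a strided convolution maps ambiguously:

.. code-block:: python

    >>> import oneflow as flow
    >>> import oneflow.nn as nn
    >>> x = flow.randn(1, 4, 10, 10, 10)
    >>> y = nn.Conv3d(4, 8, kernel_size=3, stride=2, padding=1)(x)  # each spatial extent: 10 -> 5
    >>> nn.ConvTranspose3d(8, 4, kernel_size=3, stride=2, padding=1)(y).shape
    oneflow.Size([1, 4, 9, 9, 9])
    >>> nn.ConvTranspose3d(8, 4, kernel_size=3, stride=2, padding=1, output_padding=1)(y).shape
    oneflow.Size([1, 4, 10, 10, 10])

Both 9 and 10 are valid input extents for the forward convolution above; ``output_padding=1`` selects the larger one without adding zeros to the output itself.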
The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, the second `int` for the height dimension and the third `int` for the width dimension Note: The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` amount of zero padding to both sides of the input. This is set so that when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d` are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when ``stride > 1``, :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output shape. :attr:`output_padding` is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that :attr:`output_padding` is only used to find output shape, but does not actually add zero-padding to output. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both sides of each dimension in the input. Default: 0 output_padding (int or tuple, optional): Additional size added to one side of each dimension in the output shape. Default: 0 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where .. math:: D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel_size}[0] - 1) + \text{output_padding}[0] + 1 .. math:: H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] \times (\text{kernel_size}[1] - 1) + \text{output_padding}[1] + 1 .. math:: W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2] \times (\text{kernel_size}[2] - 1) + \text{output_padding}[2] + 1 Attributes: weight (Tensor): the learnable weights of the module of shape :math:`(\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},` :math:`\text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})`.
The values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel_size}[i]}` bias (Tensor): the learnable bias of the module of shape (out_channels) If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel_size}[i]}` Examples:: >>> import oneflow as flow >>> import oneflow.nn as nn >>> # With square kernels and equal stride >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2) >>> # non-square kernels and unequal stride and with padding >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2)) >>> input = flow.randn(20, 16, 10, 50, 100) >>> output = m(input) .. _cross-correlation: https://en.wikipedia.org/wiki/Cross-correlation .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_3_t, stride: _size_3_t = 1, padding: _size_3_t = 0, output_padding: _size_3_t = 0, groups: int = 1, bias: bool = True, dilation: _size_3_t = 1, padding_mode: str = "zeros", ) -> None: super().__init__() assert padding_mode == "zeros", "Only `zeros` padding mode is supported" self.kernel_size = _triple(kernel_size) self.stride = _triple(stride) self.padding = _triple(padding) self.dilation = _triple(dilation) self.output_padding = _triple(output_padding) self.groups = groups self.in_channels = in_channels self.out_channels = out_channels assert in_channels % groups == 0 assert out_channels % groups == 0 self.weight = flow.nn.Parameter( flow.Tensor(in_channels, out_channels // groups, *self.kernel_size) ) self.filters = out_channels self.bias = None self._bias_add_op = None if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) self.reset_parameters() def reset_parameters(self) -> None: init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def forward(self, x): return flow._C.deconv3d( x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.groups, self.dilation, "channels_first", ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import random import sys import traceback from google.protobuf import text_format from typing import List, Optional, Sequence, Tuple, Union import oneflow as flow import oneflow._oneflow_internal._C as _C from oneflow.framework.tensor import Tensor from oneflow.framework.scope_util import current_scope from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t, _size_any_t from oneflow.nn.modules.module import Module from oneflow.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple import oneflow.framework.id_util as id_util def local_gen_random_seed(seed=None): if seed is None: seed = -1 has_seed = False else: has_seed = True return (seed, has_seed) class OFRecordReader(Module): def __init__( self, ofrecord_dir: str, batch_size: int = 1, data_part_num: int = 1, part_name_prefix: str = "part-", part_name_suffix_length: int = -1, random_shuffle: bool = False, shuffle_buffer_size: int = 1024, shuffle_after_epoch: bool = False, random_seed: int = -1, device: Union[flow.device, str] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, name: Optional[str] = None, ): super().__init__() if name is not None: print("WARNING: name has been deprecated and has NO effect.\n") self.ofrecord_dir = ofrecord_dir self.batch_size = batch_size self.data_part_num = data_part_num self.part_name_prefix = part_name_prefix self.part_name_suffix_length = part_name_suffix_length self.random_shuffle = random_shuffle self.shuffle_buffer_size = shuffle_buffer_size self.shuffle_after_epoch = shuffle_after_epoch self.placement = placement if placement is None: self.device = device or flow.device("cpu") else: assert device is None if placement is not None: assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp if isinstance(sbp, flow.sbp.sbp): sbp = (sbp,) else: for elem in sbp: assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp assert len(sbp) == len(placement.ranks.shape) else: assert sbp is None, "sbp: %s" % sbp self.sbp = sbp (self.seed, self.has_seed) = local_gen_random_seed(random_seed) self._op = flow.stateful_op("OFRecordReader").Output("out").Build() def forward(self): if self.placement is not None: res = _C.dispatch_ofrecord_reader( self._op, data_dir=self.ofrecord_dir, data_part_num=self.data_part_num, part_name_prefix=self.part_name_prefix, part_name_suffix_length=self.part_name_suffix_length, batch_size=self.batch_size, shuffle_buffer_size=self.shuffle_buffer_size, random_shuffle=self.random_shuffle, shuffle_after_epoch=self.shuffle_after_epoch, seed=self.seed, sbp=self.sbp, placement=self.placement, ) else: res = _C.dispatch_ofrecord_reader( self._op, data_dir=self.ofrecord_dir, data_part_num=self.data_part_num, part_name_prefix=self.part_name_prefix, part_name_suffix_length=self.part_name_suffix_length, batch_size=self.batch_size, shuffle_buffer_size=self.shuffle_buffer_size, random_shuffle=self.random_shuffle, shuffle_after_epoch=self.shuffle_after_epoch, seed=self.seed, device=self.device, ) return res class OFRecordRawDecoder(Module): def __init__( self, blob_name: str, shape: Sequence[int], dtype: flow.dtype, dim1_varying_length: bool = False, truncate: bool = False, auto_zero_padding: bool = False, name: Optional[str] = None, ): super().__init__() if auto_zero_padding: print( "WARNING: auto_zero_padding has been deprecated, Please use truncate instead.\n" ) if name is not None: print("WARNING: name has been deprecated and has NO effect.\n") self.blob_name = blob_name self.shape = shape self.dtype = 
dtype self.dim1_varying_length = dim1_varying_length self.truncate = truncate self.auto_zero_padding = auto_zero_padding self._op = ( flow.stateful_op("ofrecord_raw_decoder").Input("in").Output("out").Build() ) def forward(self, input): res = _C.dispatch_ofrecord_raw_decoder( self._op, input, name=self.blob_name, shape=self.shape, data_type=self.dtype, dim1_varying_length=self.dim1_varying_length, truncate=self.truncate or self.auto_zero_padding, ) return res class CoinFlip(Module): r""" CoinFlip(batch_size=1, random_seed=None, probability=0.5, device=None, placement=None, sbp=None) Generates random boolean values following a bernoulli distribution. The probability of generating a value 1 (true) is determined by the ``probability`` argument. The shape of the generated data can be either specified explicitly with a ``shape`` argument, or chosen to match the shape of the input, if provided. If none are present, a single value per sample is generated. The documentation is referenced from: https://docs.nvidia.com/deeplearning/dali/user-guide/docs/supported_ops_legacy.html#nvidia.dali.ops.CoinFlip. Args: batch_size (int, optional): Maximum batch size of the pipeline. Negative values for this parameter are invalid - the default value may only be used with serialized pipeline (the value stored in serialized pipeline is used instead). In most cases, the actual batch size of the pipeline will be equal to the maximum one. Default: 1 random_seed (int, optional): Random seed. Default: None probability (float, optional): Probability of value 1. Default: 0.5 device (oneflow.device, optional): Desired device of returned tensor. Default: if None, uses the current device for the default tensor type. placement (oneflow.placement, optional): Desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): Desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. """ def __init__( self, batch_size: int = 1, random_seed: Optional[int] = None, probability: float = 0.5, device: Union[flow.device, str] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, ): super().__init__() self.batch_size = batch_size self.probability = probability self.placement = placement if placement is None: self.device = device or flow.device("cpu") assert self.device == "cpu" or self.device == flow.device( "cpu" ), "coin flip only supports cpu currently." else: assert device is None if placement is not None: assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp if isinstance(sbp, flow.sbp.sbp): sbp = (sbp,) else: for elem in sbp: assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp assert len(sbp) == len(placement.ranks.shape) assert ( self.placement.type == "cpu" ), "coin flip only supports cpu currently." 
else: assert sbp is None, "sbp: %s" % sbp self.sbp = sbp (self.seed, self.has_seed) = local_gen_random_seed(random_seed) self._op = flow.stateful_op("coin_flip").Output("out").Build() def forward(self): if self.placement is not None: res = _C.dispatch_coin_flip( self._op, batch_size=self.batch_size, probability=self.probability, has_seed=self.has_seed, seed=self.seed, placement=self.placement, sbp=self.sbp, ) else: res = _C.dispatch_coin_flip( self._op, batch_size=self.batch_size, probability=self.probability, has_seed=self.has_seed, seed=self.seed, device=self.device, ) return res class CropMirrorNormalize(Module): r""" CropMirrorNormalize(color_space="BGR", output_layout="NCHW", crop_h=0, crop_w=0, crop_pos_y=0.5, crop_pos_x=0.5, mean= [0.0], std= [1.0], output_dtype=oneflow.float) Performs fused cropping, normalization, format conversion (NHWC to NCHW) if desired, and type casting. Normalization takes the input images and produces the output by using the following formula: .. math:: output = (input - mean) / std .. note:: If no cropping arguments are specified, only mirroring and normalization will occur. This operator allows sequence inputs and supports volumetric data. The documentation is referenced from: https://docs.nvidia.com/deeplearning/dali/user-guide/docs/supported_ops_legacy.html#nvidia.dali.ops.CropMirrorNormalize. Args: color_space (str, optional): The color space of the input image. Default: "BGR" output_layout (str, optional): Tensor data layout for the output. Default: "NCHW" crop_h (int, optional): Cropping the window height (in pixels). Default: 0 crop_w (int, optional): Cropping window width (in pixels). Default: 0 crop_pos_y (float, optional): Normalized (0.0 - 1.0) vertical position of the start of the cropping window (typically, the upper left corner). Default: 0.5 crop_pos_x (float, optional): Normalized (0.0 - 1.0) horizontal position of the cropping window (upper left corner). Default: 0.5 mean (float or list of float, optional): Mean pixel values for image normalization. Default: [0.0], std (float or list of float, optional): Standard deviation values for image normalization. Default: [1.0] output_dtype (oneflow.dtype, optional): Output data type. Default: ``oneflow.float`` """ def __init__( self, color_space: str = "BGR", output_layout: str = "NCHW", crop_h: int = 0, crop_w: int = 0, crop_pos_y: float = 0.5, crop_pos_x: float = 0.5, mean: Sequence[float] = [0.0], std: Sequence[float] = [1.0], output_dtype: flow.dtype = flow.float, ): super().__init__() if output_layout != "NCHW": print( "WARNING: output_layout has been deprecated. Please use Environment Variable ONEFLOW_ENABLE_NHWC, and make it equals 1." 
) if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": output_layout = "NHWC" else: output_layout = "NCHW" self.color_space = color_space self.output_layout = output_layout self.mean = mean self.std = std self.crop_h = crop_h self.crop_w = crop_w self.crop_pos_y = crop_pos_y self.crop_pos_x = crop_pos_x self.output_dtype = output_dtype self._op_uint8_with_mirror = ( flow.stateful_op("crop_mirror_normalize_from_uint8") .Input("in") .Input("mirror") .Output("out") .Build() ) self._op_uint8_no_mirror = ( flow.stateful_op("crop_mirror_normalize_from_uint8") .Input("in") .Output("out") .Build() ) self._op_buffer_with_mirror = ( flow.stateful_op("crop_mirror_normalize_from_tensorbuffer") .Input("in") .Input("mirror") .Output("out") .Build() ) self._op_buffer_no_mirror = ( flow.stateful_op("crop_mirror_normalize_from_tensorbuffer") .Input("in") .Output("out") .Build() ) def forward(self, input, mirror=None): if input.dtype is flow.uint8: if mirror is not None: res = _C.dispatch_crop_mirror_normalize_from_uint8( self._op_uint8_with_mirror, (input, mirror), color_space=self.color_space, output_layout=self.output_layout, mean=self.mean, std=self.std, crop_h=self.crop_h, crop_w=self.crop_w, crop_pos_x=self.crop_pos_x, crop_pos_y=self.crop_pos_y, output_dtype=self.output_dtype, ) else: res = _C.dispatch_crop_mirror_normalize_from_uint8( self._op_uint8_no_mirror, (input,), color_space=self.color_space, output_layout=self.output_layout, mean=self.mean, std=self.std, crop_h=self.crop_h, crop_w=self.crop_w, crop_pos_x=self.crop_pos_x, crop_pos_y=self.crop_pos_y, output_dtype=self.output_dtype, ) elif input.dtype is flow.tensor_buffer: if mirror is not None: res = _C.dispatch_crop_mirror_normalize_from_tensorbuffer( self._op_buffer_with_mirror, (input, mirror), color_space=self.color_space, output_layout=self.output_layout, mean=self.mean, std=self.std, crop_h=self.crop_h, crop_w=self.crop_w, crop_pos_x=self.crop_pos_x, crop_pos_y=self.crop_pos_y, output_dtype=self.output_dtype, ) else: res = _C.dispatch_crop_mirror_normalize_from_tensorbuffer( self._op_buffer_no_mirror, (input,), color_space=self.color_space, output_layout=self.output_layout, mean=self.mean, std=self.std, crop_h=self.crop_h, crop_w=self.crop_w, crop_pos_x=self.crop_pos_x, crop_pos_y=self.crop_pos_y, output_dtype=self.output_dtype, ) else: print( "ERROR! 
oneflow.nn.CropMirrorNormalize module NOT support input dtype = ", input.dtype, ) raise NotImplementedError return res class OFRecordImageDecoderRandomCrop(Module): def __init__( self, blob_name: str, color_space: str = "BGR", num_attempts: int = 10, random_seed: Optional[int] = None, random_area: Sequence[float] = [0.08, 1.0], random_aspect_ratio: Sequence[float] = [0.75, 1.333333], ): super().__init__() self.blob_name = blob_name self.color_space = color_space self.num_attempts = num_attempts self.random_area = random_area self.random_aspect_ratio = random_aspect_ratio (self.seed, self.has_seed) = local_gen_random_seed(random_seed) self._op = ( flow.stateful_op("ofrecord_image_decoder_random_crop") .Input("in") .Output("out") .Build() ) def forward(self, input): res = _C.dispatch_ofrecord_image_decoder_random_crop( self._op, input, name=self.blob_name, color_space=self.color_space, num_attempts=self.num_attempts, random_area=self.random_area, random_aspect_ratio=self.random_aspect_ratio, has_seed=self.has_seed, seed=self.seed, ) return res class OFRecordImageDecoder(Module): def __init__(self, blob_name: str, color_space: str = "BGR"): super().__init__() self._op = ( flow.stateful_op("ofrecord_image_decoder").Input("in").Output("out").Build() ) self.blob_name = blob_name self.color_space = color_space def forward(self, input): res = _C.dispatch_ofrecord_image_decoder( self._op, input, name=self.blob_name, color_space=self.color_space ) return res class OFRecordImageGpuDecoderRandomCropResize(Module): def __init__( self, target_width: int, target_height: int, num_attempts: Optional[int] = 10, seed: Optional[int] = 0, random_area: Optional[Sequence[float]] = [0.08, 1.0], random_aspect_ratio: Optional[Sequence[float]] = [0.75, 1.333333], num_workers: Optional[int] = 3, warmup_size: Optional[int] = 6400, max_num_pixels: Optional[int] = 67108864, ): super().__init__() self.target_width = target_width self.target_height = target_height self.num_attempts = num_attempts self.seed = seed assert len(random_area) == 2 self.random_area = random_area assert len(random_aspect_ratio) == 2 self.random_aspect_ratio = random_aspect_ratio self.num_workers = num_workers self.warmup_size = warmup_size self.max_num_pixels = max_num_pixels gpu_decoder_conf = ( flow.core.operator.op_conf_pb2.ImageDecoderRandomCropResizeOpConf() ) # parse failed when excu clang format if use `gpu_decoder_conf.in = "error_input_need_to_be_replaced"` setattr(gpu_decoder_conf, "in", "error_input_need_to_be_replaced") gpu_decoder_conf.out = "out" gpu_decoder_conf.target_width = ( -1 ) # Set the default value, otherwise the parsing fails gpu_decoder_conf.target_height = -1 gpu_decoder_conf_str = text_format.MessageToString(gpu_decoder_conf) self._op = flow._oneflow_internal.one.ImageDecoderRandomCropResizeOpExpr( id_util.UniqueStr("ImageGpuDecoder"), gpu_decoder_conf_str, ["in"], ["out"] ) def forward(self, input): if not input.is_lazy: print( "ERROR! 
oneflow.nn.OFRecordImageGpuDecoderRandomCropResize module ", "NOT support run as eager module, please use it in nn.Graph.", ) raise NotImplementedError res = _C.dispatch_image_decoder_random_crop_resize( self._op, input, target_width=self.target_width, target_height=self.target_height, num_attempts=self.num_attempts, seed=self.seed, random_area_min=self.random_area[0], random_area_max=self.random_area[1], random_aspect_ratio_min=self.random_aspect_ratio[0], random_aspect_ratio_max=self.random_aspect_ratio[1], num_workers=self.num_workers, warmup_size=self.warmup_size, max_num_pixels=self.max_num_pixels, ) if not res.is_cuda: print( "WARNING! oneflow.nn.OFRecordImageGpuDecoderRandomCropResize ONLY support ", "CUDA runtime version >= 10.2, so now it degenerates into CPU decode version.", ) return res class TensorBufferToListOfTensors(Module): def __init__( self, out_shapes, out_dtypes, out_num: int = 1, dynamic_out: bool = False ): super().__init__() self._op = ( flow.stateful_op("tensor_buffer_to_list_of_tensors_v2") .Input("in") .Output("out", out_num) .Build() ) self.out_shapes = out_shapes self.out_dtypes = out_dtypes self.dynamic_out = dynamic_out def forward(self, input): return _C.dispatch_tensor_buffer_to_list_of_tensors_v2( self._op, input, out_shapes=self.out_shapes, out_dtypes=self.out_dtypes, dynamic_out=self.dynamic_out, ) def tensor_buffer_to_list_of_tensors(tensor, out_shapes, out_dtypes): return TensorBufferToListOfTensors( [list(out_shape) for out_shape in out_shapes], out_dtypes, len(out_shapes) )(tensor) class ImageResize(Module): def __init__( self, target_size: Union[int, Sequence[int]] = None, min_size: Optional[int] = None, max_size: Optional[int] = None, keep_aspect_ratio: bool = False, resize_side: str = "shorter", channels: int = 3, dtype: Optional[flow.dtype] = None, interpolation_type: str = "auto", name: Optional[str] = None, color_space: Optional[str] = None, interp_type: Optional[str] = None, resize_shorter: int = 0, resize_x: int = 0, resize_y: int = 0, ): super().__init__() if name is not None: print("WARNING: name has been deprecated and has NO effect.\n") deprecated_param_used = False if color_space is not None: print( "WARNING: color_space has been deprecated. Please use channels instead." ) print(traceback.format_stack()[-2]) deprecated_param_used = True assert isinstance(color_space, str) if color_space.upper() == "RGB" or color_space.upper() == "BGR": channels = 3 elif color_space.upper() == "GRAY": channels = 1 else: raise ValueError("invalid color_space") self.channels = channels if interp_type is not None: print( "WARNING: interp_type has been deprecated. Please use interpolation_type instead." ) print(traceback.format_stack()[-2]) deprecated_param_used = True assert isinstance(interp_type, str) if interp_type == "Linear": interpolation_type = "bilinear" elif interp_type == "NN": interpolation_type = "nearest_neighbor" elif interp_type == "Cubic": interpolation_type = "bicubic" else: raise ValueError("invalid interp_type") self.interpolation_type = interpolation_type if resize_x > 0 and resize_y > 0: print( "WARNING: resize_x and resize_y has been deprecated. Please use target_size instead." ) print(traceback.format_stack()[-2]) deprecated_param_used = True target_size = (resize_x, resize_y) keep_aspect_ratio = False if resize_shorter > 0: print( "WARNING: resize_shorter has been deprecated. Please use target_size instead." 
) print(traceback.format_stack()[-2]) deprecated_param_used = True target_size = resize_shorter keep_aspect_ratio = True resize_side = "shorter" self.keep_aspect_ratio = keep_aspect_ratio if self.keep_aspect_ratio: if not isinstance(target_size, int): raise ValueError( "target_size must be an int when keep_aspect_ratio is True" ) if min_size is None: min_size = 0 if max_size is None: max_size = 0 if resize_side == "shorter": resize_longer = False elif resize_side == "longer": resize_longer = True else: raise ValueError('resize_side must be "shorter" or "longer"') self.target_size = target_size self.min_size = min_size self.max_size = max_size self.resize_longer = resize_longer self._op = ( flow.stateful_op("image_resize_keep_aspect_ratio") .Input("in") .Output("out") .Output("size") .Output("scale") .Build() ) else: if ( not isinstance(target_size, (list, tuple)) or len(target_size) != 2 or (not all((isinstance(size, int) for size in target_size))) ): raise ValueError( "target_size must be a form like (width, height) when keep_aspect_ratio is False" ) if dtype is None: dtype = flow.uint8 self.dtype = dtype (self.target_w, self.target_h) = target_size self._op = ( flow.stateful_op("image_resize_to_fixed") .Input("in") .Output("out") .Output("scale") .Build() ) def forward(self, input): if self.keep_aspect_ratio: res = _C.dispatch_image_resize_keep_aspect_ratio( self._op, input, target_size=self.target_size, min_size=self.min_size, max_size=self.max_size, resize_longer=self.resize_longer, interpolation_type=self.interpolation_type, ) new_size = flow.tensor_buffer_to_tensor( res[1], dtype=flow.int32, instance_shape=(2,) ) scale = flow.tensor_buffer_to_tensor( res[2], dtype=flow.float32, instance_shape=(2,) ) else: res = _C.dispatch_image_resize_to_fixed( self._op, input, target_width=self.target_w, target_height=self.target_h, channels=self.channels, data_type=self.dtype, interpolation_type=self.interpolation_type, ) new_size = None scale = res[1] res_image = res[0] return (res_image, scale, new_size) def raw_decoder( input_record, blob_name: str, shape: Sequence[int], dtype: flow.dtype, dim1_varying_length: bool = False, truncate: bool = False, auto_zero_padding: bool = False, name: Optional[str] = None, ): if auto_zero_padding: print( "WARNING: auto_zero_padding has been deprecated, Please use truncate instead.\n " ) return OFRecordRawDecoder( blob_name, shape, dtype, dim1_varying_length, truncate or auto_zero_padding, name, ).forward(input_record) def get_ofrecord_handle( ofrecord_dir: str, batch_size: int = 1, data_part_num: int = 1, part_name_prefix: str = "part-", part_name_suffix_length: int = -1, random_shuffle: bool = False, shuffle_buffer_size: int = 1024, shuffle_after_epoch: bool = False, name: Optional[str] = None, ): return OFRecordReader( ofrecord_dir, batch_size, data_part_num, part_name_prefix, part_name_suffix_length, random_shuffle, shuffle_buffer_size, shuffle_after_epoch, name, )() class ImageFlip(Module): """This operator flips the images. The flip code corresponds to the different flip mode: 0 (0x00): Non Flip 1 (0x01): Horizontal Flip 2 (0x02): Vertical Flip 3 (0x03): Both Horizontal and Vertical Flip Args: images: The input images. flip_code: The flip code. Returns: The result image. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> arr = np.array([ ... [[[1, 2, 3], [3, 2, 1]], ... [[2, 3, 4], [4, 3, 2]]], ... [[[3, 4, 5], [5, 4, 3]], ... 
[[4, 5, 6], [6, 5, 4]]]]) >>> image_tensors = flow.Tensor(arr, device=flow.device("cpu")) >>> image_tensor_buffer = flow.tensor_to_tensor_buffer(image_tensors, instance_dims=3) >>> flip_code = flow.ones(arr.shape[0], dtype=flow.int8) >>> output = nn.image.flip()(image_tensor_buffer, flip_code).numpy() >>> output[0] array([[[3., 2., 1.], [1., 2., 3.]], [[4., 3., 2.], [2., 3., 4.]]], dtype=float32) >>> output[1] array([[[5., 4., 3.], [3., 4., 5.]], [[6., 5., 4.], [4., 5., 6.]]], dtype=float32) """ def __init__(self): super().__init__() def forward(self, images, flip_code): return flow._C.image_flip(images, flip_code=flip_code) class ImageDecode(Module): def __init__(self, dtype: flow.dtype = flow.uint8, color_space: str = "BGR"): super().__init__() self.color_space = color_space self.dtype = dtype self._op = flow.stateful_op("image_decode").Input("in").Output("out").Build() def forward(self, input): return _C.dispatch_image_decode( self._op, input, color_space=self.color_space, data_type=self.dtype ) class ImageNormalize(Module): def __init__(self, std: Sequence[float], mean: Sequence[float]): super().__init__() self.std = std self.mean = mean self._op = flow.stateful_op("image_normalize").Input("in").Output("out").Build() def forward(self, input): return _C.dispatch_image_normalize( self._op, input, mean=self.mean, std=self.std ) class COCOReader(Module): def __init__( self, annotation_file: str, image_dir: str, batch_size: int, shuffle: bool = True, random_seed: Optional[int] = None, group_by_aspect_ratio: bool = True, remove_images_without_annotations: bool = True, stride_partition: bool = True, device: Union[flow.device, str] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, ): super().__init__() _handle_shuffle_args(self, shuffle, random_seed) _handle_distributed_args(self, device, placement, sbp) self.annotation_file = annotation_file self.image_dir = image_dir self.batch_size = batch_size self.group_by_aspect_ratio = group_by_aspect_ratio self.remove_images_without_annotations = remove_images_without_annotations self.stride_partition = stride_partition self._op = ( flow.stateful_op("COCOReader") .Output("image") .Output("image_id") .Output("image_size") .Output("gt_bbox") .Output("gt_label") .Output("gt_segm") .Output("gt_segm_index") .Build() ) def forward(self): if self.placement is None: # local apply outputs = _C.dispatch_coco_reader( self._op, session_id=current_scope().session_id, annotation_file=self.annotation_file, image_dir=self.image_dir, batch_size=self.batch_size, shuffle_after_epoch=self.shuffle, random_seed=self.random_seed, group_by_ratio=self.group_by_aspect_ratio, remove_images_without_annotations=self.remove_images_without_annotations, stride_partition=self.stride_partition, device=self.device, ) else: # consistent apply outputs = _C.dispatch_coco_reader( self._op, session_id=current_scope().session_id, annotation_file=self.annotation_file, image_dir=self.image_dir, batch_size=self.batch_size, shuffle_after_epoch=self.shuffle, random_seed=self.random_seed, group_by_ratio=self.group_by_aspect_ratio, remove_images_without_annotations=self.remove_images_without_annotations, stride_partition=self.stride_partition, placement=self.placement, sbp=self.sbp, ) return outputs class ImageBatchAlign(Module): def __init__(self, shape: Sequence[int], dtype: flow.dtype, alignment: int): super().__init__() self._op = ( flow.stateful_op("image_batch_align").Input("in").Output("out").Build() ) self.shape = shape self.dtype = dtype 
self.alignment = alignment def forward(self, input): return _C.dispatch_image_batch_align( self._op, input, shape=self.shape, data_type=self.dtype, alignment=self.alignment, dynamic_out=False, ) class OFRecordBytesDecoder(Module): r"""This operator reads a tensor as bytes. The output might need further decoding process like cv2.imdecode() for images and decode("utf-8") for characters, depending on the downstream task. Args: blob_name: The name of the target feature in OFRecord. name: The name for this component in the graph. input: the Tensor which might be provided by an OFRecordReader. Returns: The result Tensor encoded with bytes. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> def example(): ... batch_size = 16 ... record_reader = flow.nn.OFRecordReader( ... "dataset/", ... batch_size=batch_size, ... part_name_suffix_length=5, ... ) ... val_record = record_reader() ... bytesdecoder_img = flow.nn.OFRecordBytesDecoder("encoded") ... image_bytes_batch = bytesdecoder_img(val_record) ... image_bytes = image_bytes_batch.numpy()[0] ... return image_bytes ... example() # doctest: +SKIP array([255 216 255 ... 79 255 217], dtype=uint8) """ def __init__(self, blob_name: str, name: Optional[str] = None): super().__init__() if name is not None: print("WARNING: name has been deprecated and has NO effect.\n") self._op = ( flow.stateful_op("ofrecord_bytes_decoder").Input("in").Output("out").Build() ) self.blob_name = blob_name def forward(self, input): return _C.dispatch_ofrecord_bytes_decoder(self._op, input, name=self.blob_name) class GPTIndexedBinDataReader(Module): def __init__( self, data_file_prefix: str, seq_length: int, num_samples: int, batch_size: int, dtype: flow.dtype = flow.int64, shuffle: bool = True, random_seed: Optional[int] = None, split_sizes: Optional[Sequence[str]] = None, split_index: Optional[int] = None, device: Union[flow.device, str] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, ): super().__init__() _handle_shuffle_args(self, shuffle, random_seed) _handle_distributed_args(self, device, placement, sbp) self.data_file_prefix = data_file_prefix self.batch_size = batch_size self.num_samples = num_samples self.seq_length = seq_length self.dtype = dtype if split_index is None: split_index = 0 self.split_index = split_index if split_sizes is None: split_sizes = (1,) self.split_sizes = split_sizes if split_index >= len(split_sizes): raise ValueError( "split index {} is out of range, split_sizes {}".format( split_index, split_sizes ) ) self.op_ = ( flow.stateful_op("megatron_gpt_mmap_data_loader").Output("out").Build() ) def forward(self): if self.placement is None: output = _C.dispatch_megatron_gpt_mmap_data_loader( self.op_, data_file_prefix=self.data_file_prefix, seq_length=self.seq_length, label_length=1, num_samples=self.num_samples, batch_size=self.batch_size, dtype=self.dtype, shuffle=self.shuffle, random_seed=self.random_seed, split_sizes=self.split_sizes, split_index=self.split_index, device=self.device, ) else: output = _C.dispatch_megatron_gpt_mmap_data_loader( self.op_, data_file_prefix=self.data_file_prefix, seq_length=self.seq_length, label_length=1, num_samples=self.num_samples, batch_size=self.batch_size, dtype=self.dtype, shuffle=self.shuffle, random_seed=self.random_seed, split_sizes=self.split_sizes, split_index=self.split_index, placement=self.placement, sbp=self.sbp, ) return output class RawReader(Module): def __init__( self, files: List[str], shape: Sequence[int], dtype:
flow.dtype, batch_size: int, random_shuffle: bool = True, shuffle_block_size: int = 0, random_seed: Optional[int] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, ): super().__init__() _handle_shuffle_args(self, random_shuffle, random_seed) _handle_distributed_args(self, None, placement, sbp) self.files = files self.shape = shape self.dtype = dtype self.batch_size = batch_size self.shuffle_block_size = shuffle_block_size self.op = flow.stateful_op("raw_reader").Output("out").Build() def forward(self): if self.placement is None: output = _C.dispatch_raw_reader( self.op, files=self.files, shape=self.shape, data_type=self.dtype, batch_size=self.batch_size, random_shuffle=self.shuffle, shuffle_block_size=self.shuffle_block_size, random_seed=self.random_seed, device=self.device, ) else: output = _C.dispatch_raw_reader( self.op, files=self.files, shape=self.shape, data_type=self.dtype, batch_size=self.batch_size, random_shuffle=self.shuffle, shuffle_block_size=self.shuffle_block_size, random_seed=self.random_seed, placement=self.placement, sbp=self.sbp, ) return output def _handle_distributed_args(module, device, placement, sbp): module.placement = placement if placement is None: module.device = device or flow.device("cpu") else: if device is not None: raise ValueError( "The 'device' and 'placement' arguments can't be specified at the same time." ) module.device = None if isinstance(sbp, (tuple, list)): for sbp_item in sbp: if not isinstance(sbp_item, flow.sbp.sbp): raise ValueError(f"invalid sbp item: {sbp_item}") elif isinstance(sbp, flow.sbp.sbp): sbp = (sbp,) else: raise ValueError(f"invalid 'sbp' argument: {sbp}") if len(sbp) != len(placement.ranks.shape): raise ValueError( "The number of dimensions of 'sbp' must equal the number of dimensions of the placement's ranks." f" {len(sbp)} vs. {len(placement.ranks.shape)}" ) module.sbp = sbp def _handle_shuffle_args(module, shuffle, random_seed): module.shuffle = shuffle if random_seed is None: if shuffle: module.random_seed = random.randrange(sys.maxsize) else: module.random_seed = -1 else: assert isinstance(random_seed, int) module.random_seed = random_seed if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/distance.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.modules.module import Module from typing import Optional class CosineSimilarity(Module): r""" Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`. .. math :: \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}. The interface is consistent with PyTorch.
The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.CosineSimilarity.html#torch.nn.CosineSimilarity Args: dim (int, optional): Dimension where cosine similarity is computed. Default: 1 eps (float, optional): Small value to avoid division by zero. Default: 1e-8 Shape: - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`. - Input2: :math:`(\ast_1, D, \ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`, and broadcastable with x1 at other dimensions. - Output: :math:`(\ast_1, \ast_2)` For example: .. code-block:: python >>> import oneflow as flow >>> from oneflow import nn >>> input1 = flow.randn(100, 128) >>> input2 = flow.randn(100, 128) >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) >>> output = cos(input1, input2) """ def __init__(self, dim: Optional[int] = 1, eps: Optional[float] = 1e-08,) -> None: super().__init__() self.dim = dim self.eps = eps def forward(self, x1: Tensor, x2: Tensor) -> Tensor: return flow._C.cosine_similarity(x1, x2, self.dim, self.eps) class PairwiseDistance(Module): r"""Computes the pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm: .. math :: \left \| x \right \| _p = (\sum_{i=1}^n \left | x_i \right |^p )^{\frac{1}{p}} The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.PairwiseDistance.html. Args: p (real): the norm degree. Default: 2 eps (float, optional): Small value to avoid division by zero. Default: 1e-6 keepdim (bool, optional): Determines whether or not to keep the vector dimension. Default: False Shape: - Input1: :math:`(N, D)` or :math:`(D)`, where N = batch dimension and D = vector dimension - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the input1 - Output: :math:`(N)` or :math:`()` based on input dimension. If keepdim is True, then :math:`(N, 1)` or :math:`(1)` based on input dimension. For example: .. code-block:: python >>> import oneflow as flow >>> pdist = flow.nn.PairwiseDistance(p=2) >>> x1 = flow.arange(12).reshape(3, 4) >>> x2 = flow.arange(12).reshape(3, 4) >>> pdist(x1, x2) tensor([2.0000e-06, 2.0000e-06, 2.0000e-06], dtype=oneflow.float32) >>> pdist(x1, x2).shape oneflow.Size([3]) """ def __init__( self, p: Optional[float] = 2.0, eps: Optional[float] = 1e-06, keepdim: Optional[bool] = False, ) -> None: super().__init__() self.p = p self.eps = eps self.keepdim = keepdim def forward(self, x1: Tensor, x2: Tensor) -> Tensor: return flow._C.pairwise_distance( x1, x2, p=self.p, eps=self.eps, keepdim=self.keepdim ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/distributed_partial_fc_sample.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import warnings import oneflow as flow import oneflow._oneflow_internal._C as _C from oneflow.nn.modules.module import Module class DistributedPariticalFCSample(Module): def __init__(self, num_sample): super().__init__() self.num_sample = num_sample self._op = ( flow.stateful_op("distributed_partial_fc_sample") .Input("weight") .Input("label") .Output("mapped_label") .Output("sampled_label") .Output("sampled_weight") .Build() ) def forward(self, weight, label): res = _C.dispatch_distributed_partial_fc_sample( self._op, weight=weight, label=label, num_sample=self.num_sample ) return res def distributed_partial_fc_sample_op(weight, label, num_sample): warnings.warn( "oneflow.distributed_partial_fc_sample is deprecated. Please use nn.DistributedPariticalFCSample module instead.", DeprecationWarning, ) return DistributedPariticalFCSample(num_sample)(weight, label) ================================================ FILE: python/oneflow/nn/modules/dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import random import sys import oneflow as flow import oneflow.framework.id_util as id_util from oneflow.nn.modules.module import Module class _DropoutNd(Module): __constants__ = ["p", "inplace"] p: float inplace: bool def __init__(self, p: float = 0.5, inplace: bool = False) -> None: super(_DropoutNd, self).__init__() if p < 0 or p > 1: raise ValueError( "dropout probability has to be between 0 and 1, but got {}".format(p) ) self.p = p self.inplace = inplace def extra_repr(self) -> str: return "p={}, inplace={}".format(self.p, self.inplace) class Dropout(_DropoutNd): def __init__(self, p: float = 0.5, inplace: bool = False, generator=None): _DropoutNd.__init__(self, p, inplace) self.p = p self.generator = generator def forward(self, x, addend=None): return flow._C.dropout( x, self.p, self.training, self.inplace, self.generator, addend=addend if addend is not None else None, ) class Dropout1d(Dropout): def forward(self, x, addend=None): return flow._C.dropout1d(x, self.p, self.training) class Dropout2d(Dropout): def forward(self, x, addend=None): return flow._C.dropout2d(x, self.p, self.training) class Dropout3d(Dropout): def forward(self, x, addend=None): return flow._C.dropout3d(x, self.p, self.training) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/einsum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def einsum_op(equation, *operands): return flow._C.einsum(equation, operands) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/empty.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import List, Optional, Union import oneflow as flow from oneflow.nn.common_types import _size_any_t from oneflow.nn.modules.utils import _handle_size_arg, _single def empty_op( *size, dtype: Optional[flow.dtype] = None, device: Union[flow.device, str] = None, placement: flow.placement = None, sbp: Union[ flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp] ] = None, requires_grad: bool = False, pin_memory: bool = False, ): assert size is not None, "shape must not be None" shape = _single(_handle_size_arg(size)) if dtype is None: dtype = flow.get_default_dtype() if placement is not None: assert ( device is None ), "argument 'device' must be None when argument 'placement' exist" if placement is not None: assert ( sbp is not None ), "argument 'sbp' must not be None when argument 'placement' exist" assert isinstance( sbp, (flow.sbp.sbp, tuple, list) ), f"argument 'sbp' must be flow.sbp.sbp, not %s" % (type(sbp)) if isinstance(sbp, flow.sbp.sbp): sbp = (sbp,) else: for elem in sbp: assert isinstance(elem, flow.sbp.sbp), ( "Element in argument 'sbp' must be flow.sbp.sbp, not %s" % (type(elem)) ) assert len(sbp) == len(placement.ranks.shape) else: assert sbp is None, "argument 'sbp' must be None" if placement is not None: tensor = flow._C.global_empty(shape, dtype=dtype, placement=placement, sbp=sbp) tensor.requires_grad_(requires_grad) else: tensor = flow._C.empty( shape, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory, ) return tensor def empty_like_op( input, dtype: Optional[flow.dtype] = None, device: Union[flow.device, str, None] = None, placement: flow.placement = None, sbp: flow._oneflow_internal.sbp.sbp = None, requires_grad: bool = False, ): new_size = _single(_handle_size_arg(input.size())) if placement is None and input.is_global and input.placement is not None: placement = input.placement if sbp is None and input.is_global and input.sbp is not None: sbp = input.sbp if dtype is None: dtype = input.dtype if placement is None and device is None: device = input.device return empty_op( new_size, dtype=dtype, device=device, placement=placement, sbp=sbp, requires_grad=requires_grad, ) def new_empty_op( x, size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False ): new_size = _single(_handle_size_arg(size)) new_dtype = dtype new_device = device new_placement = placement new_sbp = sbp if dtype is None: new_dtype = x.dtype if device is None: new_device = x.device if x.is_local else None if placement is None: new_placement = 
x.placement if x.is_global else None if sbp is None: new_sbp = x.sbp if x.is_global else None return empty_op( new_size, dtype=new_dtype, device=new_device, placement=new_placement, sbp=new_sbp, requires_grad=requires_grad, ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/expand.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.nn.modules.utils import _single, _handle_size_arg def expand_op(input, *sizes): sizes = _handle_size_arg(sizes) sizes = _single(sizes) return flow._C.expand(input, sizes) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/fake_quantization.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.nn.modules.module import Module class FakeQuantization(Module): """ Simulate the quantize and dequantize operations in training time. The output will be computed as: if quantization_scheme == "symmetric": .. math:: & quant\\_max = 2^{quantization\\_to\\_bit - 1} - 1 & quant\\_min = -quant\\_max & clamp(round(x / scale), quant\\_min, quant\\_max) * scale elif quantization_scheme == "affine": .. math:: & quant\\_max = 2^{quantization\\_to\\_bit} - 1 & quant\\_min = 0 & (clamp(round(x / scale + zero\\_point), quant\\_min, quant\\_max) - zero\\_point) * scale Args: input(oneflow.Tensor): the input value(s), in ``oneflow.float32``. scale(oneflow.Tensor): quantization scale. zero_point(oneflow.Tensor): quantization zero_point. quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8. quantization_scheme (str): "symmetric" or "affine", quantize to signed / unsigned integer. Defaults to "symmetric". quantization_formula (str): Support "google" or "cambricon". Returns: oneflow.Tensor: Input tensor after quantize and dequantize operations. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32) >>> input_tensor = flow.tensor( ... weight, dtype=flow.float32 ... 
) >>> quantization_bit = 8 >>> quantization_scheme = "symmetric" >>> quantization_formula = "google" >>> per_layer_quantization = True >>> min_max_observer = flow.nn.MinMaxObserver(quantization_formula=quantization_formula, quantization_bit=quantization_bit, ... quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization) >>> fake_quantization = flow.nn.FakeQuantization(quantization_formula=quantization_formula, quantization_bit=quantization_bit, ... quantization_scheme=quantization_scheme) >>> scale, zero_point = min_max_observer( ... input_tensor, ... ) >>> output_tensor = fake_quantization( ... input_tensor, ... scale, ... zero_point, ... ) """ def __init__( self, quantization_formula: str = "google", quantization_bit: int = 8, quantization_scheme: str = "symmetric", ) -> None: super().__init__() self.quantization_formula = quantization_formula self.quantization_bit = quantization_bit self.quantization_scheme = quantization_scheme def forward(self, input, scale, zero_point): return flow._C.fake_quantization( input, scale, zero_point, self.quantization_formula, self.quantization_bit, self.quantization_scheme, ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/flatten.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module class Flatten(Module): """Flattens a contiguous range of dims into a tensor. For use with: nn.Sequential. Args: start_dim: first dim to flatten (default = 1). end_dim: last dim to flatten (default = -1). For example: .. code-block:: python >>> import oneflow as flow >>> input = flow.Tensor(32, 1, 5, 5) >>> m = flow.nn.Flatten() >>> output = m(input) >>> output.shape oneflow.Size([32, 25]) """ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: super().__init__() self.start_dim = start_dim self.end_dim = end_dim def forward(self, input): return flow._C.flatten(input, start_dim=self.start_dim, end_dim=self.end_dim) def extra_repr(self) -> str: return "start_dim={}, end_dim={}".format(self.start_dim, self.end_dim) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/fold.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
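To make the symmetric-scheme formula above concrete, a minimal NumPy hand computation (illustrative only, not the kernel implementation; the scale choice is an assumption):

import numpy as np

x = np.array([-0.40, 0.10, 0.30], dtype=np.float32)
quantization_bit = 8
quant_max = 2 ** (quantization_bit - 1) - 1       # 127
quant_min = -quant_max
scale = np.float32(np.abs(x).max() / quant_max)   # illustrative per-layer scale
fake_quant = np.clip(np.round(x / scale), quant_min, quant_max) * scale
# fake_quant stays float32 but only takes values on the int8 grid spaced by `scale`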
See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.nn.common_types import _size_2_t from oneflow.nn.modules.module import Module class Fold(Module): r""" Fold(output_size, kernel_size, dilation=1, padding=0, stride=1) The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Fold.html. Combines an array of sliding local blocks into a large containing tensor, it also called `col2img` Consider a batched :attr:`input` tensor containing sliding local blocks, e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel_size}), L)`, where :math:`N` is batch dimension, :math:`C \times \prod(\text{kernel_size})` is the number of values within a block (a block has :math:`\prod(\text{kernel_size})` spatial locations each containing a :math:`C`-channeled vector), and :math:`L` is the total number of blocks. (This is exactly the same specification as the output shape of :class:`~oneflow.nn.Unfold`.) This operation combines these local blocks into the large :attr:`output` tensor of shape :math:`(N, C, \text{output_size}[0], \text{output_size}[1], \dots)` by summing the overlapping values. Similar to :class:`~oneflow.nn.Unfold`, the arguments must satisfy .. math:: L = \prod_d \left\lfloor\frac{\text{output_size}[d] + 2 \times \text{padding}[d] % - \text{dilation}[d] \times (\text{kernel_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, where :math:`d` is over all spatial dimensions. * :attr:`output_size` describes the spatial shape of the large containing tensor of the sliding local blocks. It is useful to resolve the ambiguity when multiple input shapes map to same number of sliding blocks, e.g., with ``stride > 0``. The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify how the sliding blocks are retrieved. * :attr:`stride` controls the stride for the sliding blocks. * :attr:`padding` controls the amount of implicit zero-paddings on both sides for :attr:`padding` number of points for each dimension before reshaping. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. Args: output_size (int or tuple): the shape of the spatial dimensions of the output (i.e., ``output.sizes()[2:]``) kernel_size (int or tuple): the size of the sliding blocks stride (int or tuple): the stride of the sliding blocks in the input spatial dimensions. Default: 1 padding (int or tuple, optional): implicit zero padding to be added on both sides of input. Default: 0 dilation (int or tuple, optional): a parameter that controls the stride of elements within the neighborhood. Default: 1 * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then their values will be replicated across all spatial dimensions. * For the case of two output spatial dimensions this operation is sometimes called ``col2im``. .. note:: :class:`~oneflow.nn.Fold` calculates each combined value in the resulting large tensor by summing all values from all containing blocks. :class:`~oneflow.nn.Unfold` extracts the values in the local blocks by copying from the large tensor. So, if the blocks overlap, they are not inverses of each other. In general, folding and unfolding operations are related as follows. 
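Working the block-count formula above through the 4x4 case used in the example later in this docstring (a hand check, not library code):

import math

output_size = (4, 4)
kernel_size = (3, 3)
padding = (1, 1)
dilation = (1, 1)
stride = (1, 1)
L = 1
for d in range(2):
    L *= math.floor(
        (output_size[d] + 2 * padding[d]
         - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1
    )
# L == 16, which is why the example below feeds Fold a (1, 9, 16) tensor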
Consider :class:`~oneflow.nn.Fold` and :class:`~oneflow.nn.Unfold` instances created with the same parameters: >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) >>> fold = nn.Fold(output_size=..., **fold_params) >>> unfold = nn.Unfold(**fold_params) Then for any (supported) ``input`` tensor the following equality holds: :: fold(unfold(input)) == divisor * input where ``divisor`` is a tensor that depends only on the shape and dtype of the ``input``: >>> input_ones = oneflow.ones(input.shape, dtype=input.dtype) >>> divisor = fold(unfold(input_ones)) When the ``divisor`` tensor contains no zero elements, then ``fold`` and ``unfold`` operations are inverses of each other (up to constant divisor). .. warning:: Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. Shape: - Input: :math:`(N, C \times \prod(\text{kernel_size}), L)` or :math:`(C \times \prod(\text{kernel_size}), L)` - Output: :math:`(N, C, \text{output_size}[0], \text{output_size}[1], \dots)` or :math:`(C, \text{output_size}[0], \text{output_size}[1], \dots)` as described above For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x_tensor = flow.Tensor(np.random.randn(1, 9, 16)) >>> fold = flow.nn.Fold(output_size=(4, 4), kernel_size=3, padding=1) >>> out = fold(x_tensor) >>> out.shape oneflow.Size([1, 1, 4, 4]) """ def __init__( self, output_size: _size_2_t, kernel_size: _size_2_t, dilation: _size_2_t = 1, padding: _size_2_t = 0, stride: _size_2_t = 1, ) -> None: super(Fold, self).__init__() self.output_size = output_size self.kernel_size = kernel_size self.dilation = dilation self.padding = padding self.stride = stride def forward(self, input): return flow._C.fold( input, self.output_size, self.kernel_size, self.dilation, self.padding, self.stride, "channels_first", ) def extra_repr(self) -> str: return ( "output_size={output_size}, kernel_size={kernel_size}, " "dilation={dilation}, padding={padding}, stride={stride}".format( **self.__dict__ ) ) class Unfold(Module): r""" Unfold(kernel_size, dilation=1, padding=0, stride=1) The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Unfold.html. This op extracts elements in a local window from input tensor, it also called `img2col`. Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`, where :math:`N` is the batch dimension, :math:`C` is the channel dimension, and :math:`*` represent arbitrary spatial dimensions. This operation flattens each sliding :attr:`kernel_size`-sized block within the spatial dimensions of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output` tensor of shape :math:`(N, C \times \prod(\text{kernel_size}), L)`, where :math:`C \times \prod(\text{kernel_size})` is the total number of values within each block (a block has :math:`\prod(\text{kernel_size})` spatial locations each containing a :math:`C`-channeled vector), and :math:`L` is the total number of such blocks: .. math:: L = \prod_d \left\lfloor\frac{\text{spatial_size}[d] + 2 \times \text{padding}[d] % - \text{dilation}[d] \times (\text{kernel_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, where :math:`\text{spatial_size}` is formed by the spatial dimensions of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial dimensions. Therefore, indexing :attr:`output` at the last dimension (column dimension) gives all values within a certain block. 
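The fold(unfold(input)) == divisor * input relationship described above can be checked directly; a minimal sketch (shapes chosen arbitrarily):

import oneflow as flow

fold_params = dict(kernel_size=3, dilation=1, padding=1, stride=1)
fold = flow.nn.Fold(output_size=(4, 4), **fold_params)
unfold = flow.nn.Unfold(**fold_params)

x = flow.randn(1, 2, 4, 4)
divisor = fold(unfold(flow.ones(x.shape, dtype=x.dtype)))
# fold(unfold(x)) should equal divisor * x element-wise, up to floating-point error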
The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify how the sliding blocks are retrieved. * :attr:`stride` controls the stride for the sliding blocks. * :attr:`padding` controls the amount of implicit zero-paddings on both sides for :attr:`padding` number of points for each dimension before reshaping. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. Args: kernel_size (int or tuple): the size of the sliding blocks stride (int or tuple, optional): the stride of the sliding blocks in the input spatial dimensions. Default: 1 padding (int or tuple, optional): implicit zero padding to be added on both sides of input. Default: 0 dilation (int or tuple, optional): a parameter that controls the stride of elements within the neighborhood. Default: 1 * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or :attr:`stride` is an int or a tuple of length 1, their values will be replicated across all spatial dimensions. * For the case of two input spatial dimensions this operation is sometimes called ``im2col``. .. note:: :class:`~oneflow.nn.Fold` calculates each combined value in the resulting large tensor by summing all values from all containing blocks. :class:`~oneflow.nn.Unfold` extracts the values in the local blocks by copying from the large tensor. So, if the blocks overlap, they are not inverses of each other. In general, folding and unfolding operations are related as follows. Consider :class:`~oneflow.nn.Fold` and :class:`~oneflow.nn.Unfold` instances created with the same parameters: >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) >>> fold = nn.Fold(output_size=..., **fold_params) >>> unfold = nn.Unfold(**fold_params) Then for any (supported) ``input`` tensor the following equality holds: :: fold(unfold(input)) == divisor * input where ``divisor`` is a tensor that depends only on the shape and dtype of the ``input``: >>> input_ones = oneflow.ones(input.shape, dtype=input.dtype) >>> divisor = fold(unfold(input_ones)) When the ``divisor`` tensor contains no zero elements, then ``fold`` and ``unfold`` operations are inverses of each other (up to constant divisor). .. warning:: Currently, only 4-D input tensors (batched image-like tensors) are supported. Shape: - Input: :math:`(N, C, *)` - Output: :math:`(N, C \times \prod(\text{kernel_size}), L)` as described above For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x_tensor = flow.Tensor(np.random.randn(1, 1, 4, 4)) >>> unfold = flow.nn.Unfold(kernel_size=3, padding=1) >>> out = unfold(x_tensor) >>> out.shape oneflow.Size([1, 9, 16]) """ def __init__( self, kernel_size: _size_2_t, dilation: _size_2_t = 1, padding: _size_2_t = 0, stride: _size_2_t = 1, ) -> None: super(Unfold, self).__init__() self.kernel_size = kernel_size self.dilation = dilation self.padding = padding self.stride = stride def forward(self, input): return flow._C.unfold( input, self.kernel_size, self.dilation, self.padding, self.stride, "channels_first", ) def extra_repr(self) -> str: return ( "kernel_size={kernel_size}, dilation={dilation}, padding={padding}," " stride={stride}".format(**self.__dict__) ) ================================================ FILE: python/oneflow/nn/modules/fused_mlp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
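Before moving on, a quick shape check for the Unfold module defined above (a sketch; the channel count is chosen arbitrarily):

import oneflow as flow

x = flow.randn(1, 2, 4, 4)                        # N=1, C=2, 4x4 spatial
unfold = flow.nn.Unfold(kernel_size=3, padding=1)
cols = unfold(x)
# cols.shape == (1, 2 * 3 * 3, 16): C * prod(kernel_size) values per block, L = 16 blocks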
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.init import _calculate_fan_in_and_fan_out from oneflow.nn.modules.module import Module from typing import Tuple class FusedMLP(Module): """Applies a linear transformation with relu activation to the incoming data: :math:`y = ReLU(xA^T + b)` Args: in_features: size of each input sample hidden_features: A tuple of each Linear layer hidden size out_features: The final Linear layer hidden size hidden_dropout_rate: A tuple of each hidden layer's dropout rate out_dropout_rate: The final Linear layer's dropout rate Shape: - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of additional dimensions and :math:`H_{in} = {in\\_features}` - Output: :math:`(N, *, H_{out})` where all but the last dimension are the same shape as the input and :math:`H_{out} = {out\\_features}`. Attr: - :attr:`skip_final_activation`: Whether to skip final hidden layer's activation. Default: False. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.FusedMLP(128, [256, 512], 1024).to("cuda") >>> input = flow.Tensor(np.random.randn(1, 128)).to("cuda") >>> output = m(input) >>> output.size() oneflow.Size([1, 1024]) """ def __init__( self, in_features: int, hidden_features: Tuple[int], out_features: int, hidden_dropout_rate: Tuple[float] = None, out_dropout_rate: float = 0.0, skip_final_activation=False, ) -> None: super().__init__() self.in_features = in_features self.hidden_features = hidden_features self.out_features = out_features # TODO(zzk): Add more activation support. self.skip_final_activation = skip_final_activation self.hidden_layer_num = len(hidden_features) self.dropout_rate_list = ( hidden_dropout_rate if hidden_dropout_rate else [0.0] * (self.hidden_layer_num) ) self.dropout_rate_list += [out_dropout_rate] self.add_parameters() self.reset_parameters() self.use_dropout = False for i in range(self.hidden_layer_num + 1): if self.dropout_rate_list[i] != 0.0: self.use_dropout = True break def add_parameters(self) -> None: """Register parameter in FusedMLP module. """ if self.hidden_layer_num != 0: # First layer. self.register_parameter( f"weight_{0}", flow.nn.Parameter( flow.Tensor(self.hidden_features[0], self.in_features) ), ) self.register_parameter( f"bias_{0}", flow.nn.Parameter(flow.Tensor(self.hidden_features[0])) ) # Middle Layer. for idx in range(1, self.hidden_layer_num): self.register_parameter( f"weight_{idx}", flow.nn.Parameter( flow.Tensor( self.hidden_features[idx], self.hidden_features[idx - 1], ) ), ) self.register_parameter( f"bias_{idx}", flow.nn.Parameter(flow.Tensor(self.hidden_features[idx])), ) # Final Layer. self.register_parameter( f"weight_{self.hidden_layer_num}", flow.nn.Parameter( flow.Tensor( self.out_features, self.hidden_features[self.hidden_layer_num - 1], ) ), ) self.register_parameter( f"bias_{self.hidden_layer_num}", flow.nn.Parameter(flow.Tensor(self.out_features)), ) else: # there is only 1 layer. 
self.register_parameter( f"weight_{0}", flow.nn.Parameter(flow.Tensor(self.out_features, self.in_features)), ) self.register_parameter( f"bias_{0}", flow.nn.Parameter(flow.Tensor(self.out_features)) ) def weight(self, i): """Returns the ith weight. """ return getattr(self, f"weight_{i}") def weights(self): """Returns the weight list in FusedMLP module. """ return [self.weight(i) for i in range(self.hidden_layer_num + 1)] def bias(self, i): """Return the ith bias. """ return getattr(self, f"bias_{i}") def biases(self): """Returns the bias list in FusedMLP module. """ return [self.bias(i) for i in range(self.hidden_layer_num + 1)] def reset_parameters(self) -> None: """Reset the parameters in FusedMLP module. """ for layer_idx in range(self.hidden_layer_num + 1): flow.nn.init.kaiming_uniform_(self.weight(layer_idx), a=math.sqrt(5)) (fan_in, _) = _calculate_fan_in_and_fan_out(self.weight(layer_idx)) bound = 1 / math.sqrt(fan_in) flow.nn.init.uniform_(self.bias(layer_idx), -bound, bound) def forward(self, x): if not self.training or not self.use_dropout: return flow._C.fused_mlp( x, self.weights(), self.biases(), self.skip_final_activation ) else: return flow._C.fused_matmul_bias_add_relu_dropout( x, self.weights(), self.biases(), self.skip_final_activation, self.dropout_rate_list, ) def extra_repr(self) -> str: return "in_features={}, hidden_features={}, out_features={}, skip_final_activation={}".format( self.in_features, self.hidden_features, self.out_features, self.skip_final_activation, ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/global_cast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import register_tensor_op, Tensor from oneflow.nn.modules.module import Module def _check_sbp(sbp): if sbp is None: pass elif isinstance(sbp, (tuple, list)): if not all(isinstance(sbp_item, flow.sbp.sbp) for sbp_item in sbp): raise TypeError( "sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp" ) elif isinstance(sbp, flow.sbp.sbp): sbp = (sbp,) else: raise TypeError(f"Invalid parameter sbp with type {type(sbp)}") return sbp def local_to_global_op(input, placement=None, sbp=None, *, check_meta=True, copy=False): # Convert None to a tensor with shape 0, in order to input it into flow._C.to_global. if input is None: input = flow.tensor(()) assert isinstance(input, Tensor) assert input.is_local, "input must be a local tensor" if placement is None or sbp is None: raise ValueError( "Converting a local tensor to global tensor must have placement and sbp parameters." 
) assert isinstance( placement, flow.placement ), f"Invalid parameter placement with type {type(placement)}" sbp = _check_sbp(sbp) grad_sbp = tuple() return flow._C.to_global(input, placement, sbp, grad_sbp, check_meta, copy) def global_to_global_op( input, placement=None, sbp=None, *, grad_sbp=None, check_meta=False, copy=False ): assert isinstance(input, Tensor) assert input.is_global, "input must be a global tensor" sbp = _check_sbp(sbp) if placement is None: placement = input.placement if sbp is None: sbp = input.sbp assert isinstance( placement, flow.placement ), f"Invalid parameter placement with type {type(placement)}" grad_sbp = _check_sbp(grad_sbp) if grad_sbp is None: grad_sbp = tuple() return flow._C.to_global(input, placement, sbp, grad_sbp, check_meta, copy) def to_global_op(input, placement=None, sbp=None, **kwargs): assert isinstance(input, Tensor) if input.is_global: return global_to_global_op(input=input, placement=placement, sbp=sbp, **kwargs) else: if "grad_sbp" in kwargs: del kwargs["grad_sbp"] return local_to_global_op(input=input, placement=placement, sbp=sbp, **kwargs) def to_local_op(input, *, copy=False): assert input.is_global, "Expected global tensor for to_local but got local tensor!" return flow._C.to_local(input, copy) ================================================ FILE: python/oneflow/nn/modules/grid_sample.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def grid_sample( input, grid, mode: str = "bilinear", padding_mode: str = "zeros", align_corners: bool = False, ): """The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.grid_sample.html. Given an :attr:`input` and a flow-field :attr:`grid`, computes the ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are supported. In the spatial (4-D) case, for :attr:`input` with shape :math:`(N, C, H_{in}, W_{in})` and :attr:`grid` with shape :math:`(N, H_{out}, W_{out}, 2)`, the output will have shape :math:`(N, C, H_{out}, W_{out})`. For each output location ``output[n, :, h, w]``, the size-2 vector ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``, which are used to interpolate the output value ``output[n, :, h, w]``. In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the ``x``, ``y``, ``z`` pixel locations for interpolating ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or ``bilinear`` interpolation method to sample the input pixels. :attr:`grid` specifies the sampling pixel locations normalized by the :attr:`input` spatial dimensions. Therefore, it should have most values in the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` is the right-bottom pixel of :attr:`input`. 
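A minimal sketch of the local/global conversion helpers in global_cast.py above (assumes a single-process run on rank 0; placement and sbp values are illustrative):

import oneflow as flow

x = flow.randn(2, 3)                                    # a local tensor
placement = flow.placement("cpu", ranks=[0])
x_global = x.to_global(placement=placement, sbp=flow.sbp.broadcast)
x_back = x_global.to_local()                            # round-trips through to_local_op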
If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding outputs are handled as defined by :attr:`padding_mode`. Options are * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations, * ``padding_mode="border"``: use border values for out-of-bound grid locations, * ``padding_mode="reflection"``: use values at locations reflected by the border for out-of-bound grid locations. For location far away from the border, it will keep being reflected until becoming in bound, e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1`` and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes ``x'' = -0.5``. Note: This function is often used in conjunction with :func:`affine_grid` to build `Spatial Transformer Networks`_ . Note: NaN values in :attr:`grid` would be interpreted as ``-1``. Args: input (Tensor): input of shape :math:`(N, C, H_{in}, W_{in})` (4-D case) or :math:`(N, C, D_{in}, H_{in}, W_{in})` (5-D case) grid (Tensor): flow-field of shape :math:`(N, H_{out}, W_{out}, 2)` (4-D case) or :math:`(N, D_{out}, H_{out}, W_{out}, 3)` (5-D case) mode (str): interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` Note: ``mode='bicubic'`` supports only 4-D input. When ``mode='bilinear'`` and the input is 5-D, the interpolation mode used internally will actually be trilinear. However, when the input is 4-D, the interpolation mode will legitimately be bilinear. padding_mode (str): padding mode for outside grid values ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'`` align_corners (bool): Geometrically, we consider the pixels of the input as squares rather than points. If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring to the center points of the input's corner pixels. If set to ``False``, they are instead considered as referring to the corner points of the input's corner pixels, making the sampling more resolution agnostic. This option parallels the ``align_corners`` option in :func:`interpolate`, and so whichever option is used here should also be used there to resize the input image before grid sampling. Default: ``False`` Returns: output (Tensor): output Tensor .. _`Spatial Transformer Networks`: https://arxiv.org/abs/1506.02025 .. note:: ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\\alpha=-0.75`. The constant :math:`\\alpha` might be different from packages to packages. For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. This algorithm may "overshoot" the range of values it's interpolating. For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. Clamp the results with :func: `flow.clamp` to ensure they are within the valid range. .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 Examples:: >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.arange(1., 11).reshape((1, 1, 2, 5)), dtype=flow.float32) >>> np_grid = np.array( ... [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]], ... [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]] ... 
).reshape(1, 2, 5, 2) >>> grid = flow.tensor(np_grid, dtype=flow.float32) >>> output = flow.nn.functional.grid_sample(input, grid, mode='nearest', padding_mode='zeros', ... align_corners=True) >>> output tensor([[[[0., 8., 5., 7., 9.], [1., 8., 5., 8., 0.]]]], dtype=oneflow.float32) """ y = flow._C.grid_sample( input, grid, interpolation_mode=mode, padding_mode=padding_mode, align_corners=align_corners, ) return y if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/instancenorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.nn.modules.batchnorm import _NormBase class _InstanceNorm(_NormBase): def __init__( self, num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = False, track_running_stats: bool = False, ): super().__init__(num_features, eps, momentum, affine, track_running_stats) def _forward(self, x): axis = 1 params_shape = [x.shape[axis]] weight = self.weight bias = self.bias nd_params_shape = [1] * len(x.shape) nd_params_shape[axis] = params_shape[0] mean = x.mean(2, keepdim=True) variance = x.var(2, unbiased=False, keepdim=True) normalized = (x - mean) / flow.sqrt(variance + self.eps) if self.weight is not None and params_shape[0] == self.weight.nelement(): weight = flow.reshape(self.weight, shape=nd_params_shape) if self.bias is not None and params_shape[0] == self.bias.nelement(): bias = flow.reshape(self.bias, shape=nd_params_shape) if self.weight is not None: normalized = normalized * weight if self.bias is not None: normalized = normalized + bias return normalized def forward(self, x): self._check_input_dim(x) reshape_to_1d = flow.reshape(x, [x.shape[0], x.shape[1], -1]) normalized_1d_out = self._forward(reshape_to_1d) reshape_back_to_nd = flow.reshape(normalized_1d_out, list(x.shape)) return reshape_back_to_nd class InstanceNorm1d(_InstanceNorm): """ Applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for Fast Stylization `__. .. math:: y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch. :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors of size `C` (where `C` is the input size) if :attr:`affine` is ``True``. The standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. By default, this layer uses instance statistics computed from input data in both training and evaluation modes. If :attr:`track_running_stats` is set to ``True``, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. 
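A small sketch of what the _InstanceNorm._forward helper above computes per (sample, channel) pair, normalizing over the spatial axis only (shapes and eps assumed):

import oneflow as flow

x = flow.randn(2, 3, 8)                      # (N, C, L)
m = flow.nn.InstanceNorm1d(3)                # affine=False by default
y = m(x)
# hand computation mirroring _forward
mean = x.mean(2, keepdim=True)
var = x.var(2, unbiased=False, keepdim=True)
y_ref = (x - mean) / flow.sqrt(var + 1e-5)   # eps defaults to 1e-5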
The running estimates are kept with a default :attr:`momentum` of 0.1. .. note:: This :attr:`momentum` argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`, where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. .. note:: :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but have some subtle differences. :class:`InstanceNorm1d` is applied on each channel of channeled data like multidimensional time series, but :class:`LayerNorm` is usually applied on entire sample and often in NLP tasks. Additionally, :class:`LayerNorm` applies elementwise affine transform, while :class:`InstanceNorm1d` usually don't apply affine transform. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.InstanceNorm1d.html. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. Default: ``False``. track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` Shape: - Input: :math:`(N, C, L)` - Output: :math:`(N, C, L)` (same shape as input) For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> # Without Learnable Parameters >>> m = flow.nn.InstanceNorm1d(100) >>> # With Learnable Parameters >>> m = flow.nn.InstanceNorm1d(100, affine=True) >>> x = flow.Tensor(np.random.randn(20, 100, 40)) >>> output = m(x) """ def _check_input_dim(self, input): if input.dim() == 2: raise ValueError( "InstanceNorm1d returns 0-filled tensor to 2D tensor.This is because InstanceNorm1d reshapes inputs to(1, N * C, ...) from (N, C,...) and this makesvariances 0." ) if input.dim() != 3: raise ValueError("expected 3D input (got {}D input)".format(input.dim())) class InstanceNorm2d(_InstanceNorm): """ Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for Fast Stylization `__. .. math:: y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch. :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors of size `C` (where `C` is the input size) if :attr:`affine` is ``True``. The standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. By default, this layer uses instance statistics computed from input data in both training and evaluation modes. If :attr:`track_running_stats` is set to ``True``, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. 
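The running-statistics update rule quoted in the note above, as a one-line numeric check (values assumed):

momentum = 0.1
running_mean, batch_mean = 0.0, 2.0
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# running_mean == 0.2 after a single update, per x_hat_new = (1 - momentum) * x_hat + momentum * x_t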
The running estimates are kept with a default :attr:`momentum` of 0.1. .. note:: This :attr:`momentum` argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`, where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. .. note:: :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but have some subtle differences. :class:`InstanceNorm2d` is applied on each channel of channeled data like RGB images, but :class:`LayerNorm` is usually applied on entire sample and often in NLP tasks. Additionally, :class:`LayerNorm` applies elementwise affine transform, while :class:`InstanceNorm2d` usually don't apply affine transform. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.InstanceNorm2d.html. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, H, W)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. Default: ``False``. track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` Shape: - Input: :math:`(N, C, H, W)` - Output: :math:`(N, C, H, W)` (same shape as input) For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> # Without Learnable Parameters >>> m = flow.nn.InstanceNorm2d(100) >>> # With Learnable Parameters >>> m = flow.nn.InstanceNorm2d(100, affine=True) >>> x = flow.Tensor(np.random.randn(20, 100, 35, 45)) >>> output = m(x) """ def _check_input_dim(self, input): if input.dim() != 4: raise ValueError("expected 4D input (got {}D input)".format(input.dim())) class InstanceNorm3d(_InstanceNorm): """ Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for Fast Stylization `__. .. math:: y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch. :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors of size C (where C is the input size) if :attr:`affine` is ``True``. The standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. By default, this layer uses instance statistics computed from input data in both training and evaluation modes. If :attr:`track_running_stats` is set to ``True``, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default :attr:`momentum` of 0.1. .. note:: This :attr:`momentum` argument is different from one used in optimizer classes and the conventional notion of momentum. 
Mathematically, the update rule for running statistics here is :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`, where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. .. note:: :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but have some subtle differences. :class:`InstanceNorm3d` is applied on each channel of channeled data like 3D models with RGB color, but :class:`LayerNorm` is usually applied on entire sample and often in NLP tasks. Additionally, :class:`LayerNorm` applies elementwise affine transform, while :class:`InstanceNorm3d` usually don't apply affine transform. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.InstanceNorm3d.html. Args: num_features: :math:`C` from an expected input of size :math:`(N, C, D, H, W)` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. Default: ``False``. track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` Shape: - Input: :math:`(N, C, D, H, W)` - Output: :math:`(N, C, D, H, W)` (same shape as input) For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> # Without Learnable Parameters >>> m = flow.nn.InstanceNorm3d(100) >>> # With Learnable Parameters >>> m = flow.nn.InstanceNorm3d(100, affine=True) >>> x = flow.Tensor(np.random.randn(20, 100, 35, 45, 10)) >>> output = m(x) """ def _check_input_dim(self, input): if input.dim() != 5: raise ValueError("expected 5D input (got {}D input)".format(input.dim())) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/interpolate.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import math import warnings from typing import Optional, Tuple, Union import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module class Interpolate: def __init__( self, size: Optional[Union[int, Tuple[int, ...]]] = None, scale_factor: Optional[Union[float, Tuple[float, ...]]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, ): self.size = size if isinstance(scale_factor, tuple): self.scale_factor = tuple((float(factor) for factor in scale_factor)) else: self.scale_factor = float(scale_factor) if scale_factor else None if mode in ("nearest", "area") and align_corners is not None: raise ValueError( "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear" ) self.mode = mode self.recompute_scale_factor = recompute_scale_factor if align_corners == None: align_corners = False self.align_corners = align_corners self.height_scale = None self.width_scale = None if isinstance(self.scale_factor, float): self.height_scale = self.scale_factor self.width_scale = self.scale_factor elif isinstance(self.scale_factor, tuple): self.height_scale = self.scale_factor[0] self.width_scale = self.scale_factor[1] else: pass if self.mode not in ( "nearest", "bilinear", "linear", "area", "bicubic", "trilinear", ): raise ValueError( 'interpolation must be "nearest" or "bilinear" or "linear" or "area" or "bicubic" or "trilinear".' ) if self.mode == "nearest" and self.align_corners: raise ValueError('interpolation "nearest" does not support align_corners.') def forward(self, x): if len(x.shape) == 3 and self.mode == "bilinear": raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") if len(x.shape) == 3 and self.mode == "trilinear": raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") if len(x.shape) == 4 and self.mode == "linear": raise NotImplementedError("Got 4D input, but linear mode needs 3D input") if len(x.shape) == 4 and self.mode == "trilinear": raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") if len(x.shape) == 5 and self.mode == "linear": raise NotImplementedError("Got 5D input, but linear mode needs 3D input") if len(x.shape) == 5 and self.mode == "bilinear": raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") dim = len(x.shape) - 2 if self.size is not None and self.scale_factor is not None: raise ValueError("only one of size or scale_factor should be defined") elif self.size is not None: assert self.scale_factor is None scale_factors = [] if isinstance(self.size, (list, tuple)): if len(self.size) != dim: raise ValueError( "size shape must match input shape. Input is {}D, size is {}".format( dim, len(self.size) ) ) output_size = self.size else: output_size = [self.size for _ in range(dim)] for i in range(dim): scale_factors.append(output_size[i] / x.shape[i + 2]) elif self.scale_factor is not None: assert self.size is None output_size = None if isinstance(self.scale_factor, (list, tuple)): if len(self.scale_factor) != dim: raise ValueError( "scale_factor shape must match input shape. 
Input is {}D, scale_factor is {}".format( dim, len(self.scale_factor) ) ) scale_factors = self.scale_factor else: scale_factors = [self.scale_factor for _ in range(dim)] else: raise ValueError("either size or scale_factor should be defined") if self.recompute_scale_factor and self.size is not None: raise ValueError( "recompute_scale_factor is not meaningful with an explicit size." ) if self.mode == "area" and output_size is None: self.recompute_scale_factor = True if self.recompute_scale_factor is True: assert scale_factors is not None output_size = [ int(math.floor(float(x.size(i + 2)) * scale_factors[i])) for i in range(dim) ] scale_factors = [] for i in range(dim): scale_factors.append(output_size[i] / x.shape[2 + i]) if len(x.shape) == 3 and self.mode == "nearest": return flow._C.upsample_nearest_1d( x, scale_factor=scale_factors[0], output_size=output_size, data_format="channels_first", ) if len(x.shape) == 4 and self.mode == "nearest": return flow._C.upsample_nearest_2d( x, height_scale=scale_factors[0], width_scale=scale_factors[1], output_size=output_size, data_format="channels_first", ) if len(x.shape) == 5 and self.mode == "nearest": return flow._C.upsample_nearest_3d( x, depth_scale=scale_factors[0], height_scale=scale_factors[1], width_scale=scale_factors[2], output_size=output_size, data_format="channels_first", ) if len(x.shape) == 3 and self.mode == "area": assert output_size is not None return flow._C.adaptive_avg_pool1d(x, output_size) if len(x.shape) == 4 and self.mode == "area": assert output_size is not None return flow._C.adaptive_avg_pool2d(x, output_size) if len(x.shape) == 5 and self.mode == "area": assert output_size is not None return flow._C.adaptive_avg_pool3d(x, output_size) if len(x.shape) == 3 and self.mode == "linear": assert self.align_corners is not None return flow._C.upsample_linear_1d( x, scale_factor=scale_factors[0], align_corners=self.align_corners, output_size=output_size, data_format="channels_first", ) if len(x.shape) == 4 and self.mode == "bilinear": assert self.align_corners is not None return flow._C.upsample_bilinear_2d( x, height_scale=scale_factors[0], width_scale=scale_factors[1], align_corners=self.align_corners, output_size=output_size, data_format="channels_first", ) if len(x.shape) == 4 and self.mode == "bicubic": assert self.align_corners is not None return flow._C.upsample_bicubic_2d( x, height_scale=scale_factors[0], width_scale=scale_factors[1], align_corners=self.align_corners, output_size=output_size, data_format="channels_first", ) if len(x.shape) == 5 and self.mode == "trilinear": assert self.align_corners is not None return flow._C.upsample_trilinear_3d( x, depth_scale=scale_factors[0], height_scale=scale_factors[1], width_scale=scale_factors[2], align_corners=self.align_corners, output_size=output_size, data_format="channels_first", ) raise NotImplementedError( "Input Error: Only 3D, 4D and 5D input Tensors supported" " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area" " (got {})".format(len(x.shape), self.mode) ) def interpolate( input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, ): """The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/functional.html#interpolate. Down/up samples the input to either the given :attr:`size` or the given :attr:`scale_factor` The algorithm used for interpolation is determined by :attr:`mode`. 
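As a hand check of how the Interpolate.forward code above turns scale_factor into an output size when recompute_scale_factor is set (input sizes assumed for illustration):

import math

input_spatial = (4, 6)             # spatial dims of a 4-D (N, C, H, W) input
scale_factors = (2.0, 1.5)
output_size = [int(math.floor(s * f)) for s, f in zip(input_spatial, scale_factors)]
# output_size == [8, 9]; the per-dim scales are then recomputed as output_size[i] / input_spatial[i]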
Currently temporal, spatial and volumetric sampling are supported, i.e. expected inputs are 3-D, 4-D or 5-D in shape. The input dimensions are interpreted in the form: `mini-batch x channels x [optional depth] x [optional height] x width`. The modes available for resizing are: `nearest`, `linear` (3D-only), `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area` Args: input (Tensor): the input tensor size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): output spatial size. scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. mode (str): algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | ``'trilinear'`` | ``'area'``. Default: ``'nearest'`` align_corners (bool, optional): Geometrically, we consider the pixels of the input and output as squares rather than points. If set to ``True``, the input and output tensors are aligned by the center points of their corner pixels, preserving the values at the corner pixels. If set to ``False``, the input and output tensors are aligned by the corner points of their corner pixels, and the interpolation uses edge value padding for out-of-boundary values, making this operation *independent* of input size when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. Default: ``False`` recompute_scale_factor (bool, optional): recompute the scale_factor for use in the interpolation calculation. When `scale_factor` is passed as a parameter, it is used to compute the `output_size`. If `recompute_scale_factor` is ``False`` or not specified, the passed-in `scale_factor` will be used in the interpolation computation. Otherwise, a new `scale_factor` will be computed based on the output and input sizes for use in the interpolation computation (i.e. the computation will be identical to if the computed `output_size` were passed-in explicitly). Note that when `scale_factor` is floating-point, the recomputed scale_factor may differ from the one passed in due to rounding and precision issues. .. note:: With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce negative values or values greater than 255 for images. Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot when displaying the image. .. warning:: With ``align_corners = True``, the linearly interpolating modes (`linear`, `bilinear`, and `trilinear`) don't proportionally align the output and input pixels, and thus the output values can depend on the input size. This was the default behavior for these modes up to version 0.3.1. Since then, the default behavior is ``align_corners = False``. See :class:`~torch.nn.Upsample` for concrete examples on how this affects the outputs. .. warning:: When scale_factor is specified, if recompute_scale_factor=True, scale_factor is used to compute the output_size which will then be used to infer new scales for the interpolation. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 4)), dtype=flow.float32) >>> output = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="linear") >>> output tensor([[[1.0000, 1.2500, 1.7500, 2.2500, 2.7500, 3.2500, 3.7500, 4.0000]]], dtype=oneflow.float32) """ return Interpolate( size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, ).forward(input) def interpolate_like( input, like, mode="nearest", align_corners=None, ): """The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/functional.html#interpolate. Down/up samples the input to the same shape as the `like` tensor. The algorithm used for interpolation is determined by :attr:`mode`. Currently temporal, spatial and volumetric sampling are supported, i.e. expected inputs are 3-D, 4-D or 5-D in shape. The input dimensions are interpreted in the form: `mini-batch x channels x [optional depth] x [optional height] x width`. The modes available for resizing are: `nearest`, `linear` (3D-only), `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area` Args: input (Tensor): the input tensor like (Tensor): the like tensor mode (str): algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | ``'trilinear'`` | ``'area'``. Default: ``'nearest'`` align_corners (bool, optional): Geometrically, we consider the pixels of the input and output as squares rather than points. If set to ``True``, the input and output tensors are aligned by the center points of their corner pixels, preserving the values at the corner pixels. If set to ``False``, the input and output tensors are aligned by the corner points of their corner pixels, and the interpolation uses edge value padding for out-of-boundary values. This only has an effect when :attr:`mode` is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. Default: ``False`` .. note:: With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce negative values or values greater than 255 for images. Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot when displaying the image. .. warning:: With ``align_corners = True``, the linearly interpolating modes (`linear`, `bilinear`, and `trilinear`) don't proportionally align the output and input pixels, and thus the output values can depend on the input size. This was the default behavior for these modes up to version 0.3.1. Since then, the default behavior is ``align_corners = False``. See :class:`~torch.nn.Upsample` for concrete examples on how this affects the outputs. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) >>> like = flow.randn(1, 1, 4, 4) >>> output = flow.nn.functional.interpolate_like(input, like, mode="nearest") >>> output tensor([[[[1., 1., 2., 2.], [1., 1., 2., 2.], [3., 3., 4., 4.], [3., 3., 4., 4.]]]], dtype=oneflow.float32) """ return Interpolate( size=like.shape[2:], mode=mode, align_corners=align_corners, ).forward(input) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/is_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def is_tensor_op(obj): r""" is_tensor(input) -> (bool) Note that this function is simply doing ``isinstance(obj, Tensor)``. Using that ``isinstance`` check is better for typechecking with mypy, and more explicit - so it's recommended to use that instead of ``is_tensor``. Args: obj (Object): Object to test For example: .. code-block:: python >>> import oneflow as flow >>> x=flow.tensor([1,2,3]) >>> flow.is_tensor(x) True """ return isinstance(obj, flow.Tensor) ================================================ FILE: python/oneflow/nn/modules/linear.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.init import _calculate_fan_in_and_fan_out from oneflow.nn.modules.module import Module import os class Identity(Module): """A placeholder identity operator that is argument-insensitive. Args: args: any argument (unused) kwargs: any keyword argument (unused) For example: .. code-block:: python import numpy as np import oneflow as flow m = flow.nn.Identity() input = flow.Tensor(np.random.rand(2, 3, 4, 5)) output = m(input) # output = input """ def __init__(self, *args, **kwargs): super().__init__() def forward(self, input: Tensor) -> Tensor: return input class Linear(Module): """Applies a linear transformation to the incoming data: :math:`y = xA^T + b` Args: - in_features: size of each input sample - out_features: size of each output sample - bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True`` Shape: - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of additional dimensions and :math:`H_{in} = {in\\_features}` - Output: :math:`(N, *, H_{out})` where all but the last dimension are the same shape as the input and :math:`H_{out} = {out\\_features}`. Attr: - :attr:`weight`: the learnable weights of the module of shape :math:`({out\\_features}, {in\\_features})`. The values are initialized from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})`, where :math:`(k = 1 / {in\\_features})` - :attr:`bias`: the learnable bias of the module of shape :math:`({out\\_features})`. If :attr:`bias` is ``True``, the values are initialized from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`(k = 1 / {in\\_features})` For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.Linear(20, 30, False) >>> input = flow.Tensor(np.random.randn(128, 20)) >>> output = m(input) >>> output.size() oneflow.Size([128, 30]) """ def __init__( self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None, ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.weight = flow.nn.Parameter( flow.Tensor(out_features, in_features).to(dtype=dtype, device=device) ) self.bias = ( flow.nn.Parameter(flow.Tensor(out_features).to(dtype=dtype, device=device)) if bias else None ) self.use_fused_matmul_bias = ( self.bias is not None and os.getenv("ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR") == "1" ) self.reset_parameters() def reset_parameters(self) -> None: if os.getenv("ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT", "0") == "1": return flow.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: (fan_in, _) = _calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) flow.nn.init.uniform_(self.bias, -bound, bound) def forward(self, x): if self.use_fused_matmul_bias: return flow._C.fused_matmul_bias(x, self.weight, self.bias) else: res = flow._C.matmul(x, self.weight, transpose_a=False, transpose_b=True) if self.bias is not None: res += self.bias return res def extra_repr(self) -> str: return "in_features={}, out_features={}, bias={}".format( self.in_features, self.out_features, self.bias is not None ) def linear(input, weight, bias=None): r""" Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. Shape: - Input: :math:`(N, *, in\_features)` N is the batch size, `*` means any number of additional dimensions - Weight: :math:`(out\_features, in\_features)` - Bias: :math:`(out\_features)` - Output: :math:`(N, *, out\_features)` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor(np.random.randn(128, 20)) >>> weight = flow.tensor(np.random.randn(30, 20)) >>> output = flow.nn.functional.linear(input, weight) >>> output.size() oneflow.Size([128, 30]) """ res = flow._C.matmul(input, weight, transpose_a=False, transpose_b=True) if bias is not None: res += bias return res if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/linspace.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
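A note on the fused path in `Linear` above: the fused matmul-plus-bias kernel is only selected at construction time, when a bias is present and the `ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR` environment variable is set to `"1"`. A minimal sketch (assuming a build in which `flow._C.fused_matmul_bias` is available, e.g. a CUDA build):

.. code-block:: python

    import os
    # Must be set before the module is constructed; the flag is read in __init__.
    os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"

    import oneflow as flow

    m = flow.nn.Linear(20, 30)   # bias=True, so use_fused_matmul_bias becomes True
    x = flow.randn(4, 20)
    y = m(x)                     # forward dispatches to flow._C.fused_matmul_bias
    print(y.shape)               # oneflow.Size([4, 30])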
""" from typing import List, Optional, Union import math import oneflow as flow def linspace_op( start: Union[float, flow.Tensor], end: Union[float, flow.Tensor], steps: Union[int, flow.Tensor], dtype: flow.dtype = flow.float32, device: Union[str, flow.device] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, requires_grad: bool = False, ): r""" Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: .. math:: (\text{start}, \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, \ldots, \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, \text{end}) Args: start (float): the starting value for the set of points end (float): the ending value for the set of points steps (int): size of the constructed tensor Keyword arguments: dtype(flow.dtype, optional): If `dtype` is not given, the `dtype` is inferred to be `flow.float32`. device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. For example: .. code-block:: python >>> import oneflow as flow >>> y = flow.linspace(3, 10, steps=5) >>> y tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000], dtype=oneflow.float32) """ def is_scalar(tensor): return tensor.ndim == 0 and tensor.nelement() == 1 if isinstance(start, flow.Tensor): if not is_scalar(start): raise TypeError( "linspace(): argument 'start' (position 1) must be Number, not Tensor" ) start = start.item() if isinstance(end, flow.Tensor): if not is_scalar(end): raise TypeError( "linspace(): argument 'end' (position 2) must be Number, not Tensor" ) end = end.item() if isinstance(steps, flow.Tensor): if not is_scalar(steps): raise TypeError( "linspace(): argument 'steps' (position 3) must be Number, not Tensor" ) if flow.is_floating_point(steps): raise TypeError( "linspace(): argument 'steps' (position 3) must be int, not Tensor (with dtype: " + str(steps.dtype) + ")" ) steps = steps.item() if start == end: return flow.full((steps,), start * 1.0) step = 1.0 if steps == 0: end = start elif steps == 1: end = start + 1.0 else: step = (end - start) * 1.0 / (steps - 1) if math.isclose(((end - start) / (steps - 1)) * (steps - 1), (end - start)): end = end + step / 2.0 if placement is None: if isinstance(device, str): device = flow.device(device) res = flow._C.arange(start, end, step, dtype=dtype, device=device) else: assert isinstance( placement, flow._oneflow_internal.placement ), "placement should be oneflow._oneflow_internal.placement type." assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp if isinstance(sbp, flow.sbp.sbp): sbp = (sbp,) else: for elem in sbp: assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp assert len(sbp) == len(placement.ranks.shape) res = flow._C.global_arange( start, end, step, dtype=dtype, placement=placement, sbp=sbp ) res.requires_grad = requires_grad return res if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/logspace.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from cgitb import reset from typing import List, Optional, Union import math import oneflow as flow def logspace_op( start: float, end: float, steps: int, base: Optional[float] = 10.0, dtype: flow.dtype = None, device: Union[str, flow.device] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, requires_grad: bool = False, ): r""" logspace(start, end, steps, base=10.0, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor This function is equivalent to PyTorch’s logspace function. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.logspace.html. Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale with base :attr:`base`. That is, the values are: .. math:: (\text{base}^{\text{start}}, \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \ldots, \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \text{base}^{\text{end}}) Args: start (float): the starting value for the set of points end (float): the ending value for the set of points steps (int): size of the constructed tensor base (float, optional): base of the logarithm function. Default: ``10.0``. Keyword arguments: dtype (oneflow.dtype, optional): the data type to perform the computation in. Default: if None, uses the global default dtype (see oneflow.get_default_dtype()) when both :attr:`start` and :attr:`end` are real, and corresponding complex dtype when either is complex. device (oneflow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type placement (oneflow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. Example:: >>> import oneflow as flow >>> flow.logspace(start=-10, end=10, steps=2) tensor([1.0000e-10, 1.0000e+10], dtype=oneflow.float32) >>> flow.logspace(start=0.1, end=1.0, steps=5) tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000], dtype=oneflow.float32) >>> flow.logspace(start=0.1, end=1.0, steps=1) tensor([1.2589], dtype=oneflow.float32) >>> flow.logspace(start=2, end=2, steps=1, base=2) tensor([4.], dtype=oneflow.float32) """ # TODO: Migrate to C++ indice = flow.linspace( start=start, end=end, steps=steps, dtype=dtype, device=device, placement=placement, sbp=sbp, ) res = flow.pow(base, indice) res.requires_grad = requires_grad return res ================================================ FILE: python/oneflow/nn/modules/loss.py ================================================ """ Copyright 2020 The OneFlow Authors. 
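Since `logspace_op` above is currently composed from `linspace` and `pow` (note the `# TODO: Migrate to C++`), the two results below are expected to agree; this is a sketch of the current composition, not a documented guarantee.

.. code-block:: python

    import oneflow as flow

    a = flow.logspace(start=0.1, end=1.0, steps=5)
    b = flow.pow(10.0, flow.linspace(start=0.1, end=1.0, steps=5))
    # a and b should match element-wise:
    # tensor([ 1.2589,  2.1135,  3.5481,  5.9566, 10.0000], dtype=oneflow.float32)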
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.modules.module import Module from oneflow.nn.modules.constant import _ConstantBase class _Loss(Module): def __init__(self, reduction: str = "mean") -> None: super(_Loss, self).__init__() assert reduction in ["none", "mean", "sum"] self.reduction = reduction class _WeightedLoss(_Loss): def __init__( self, weight: Optional[Tensor] = None, reduction: str = "mean" ) -> None: super(_WeightedLoss, self).__init__(reduction=reduction) self.register_buffer("weight", weight) class L1Loss(_Loss): """This operator computes the L1 Loss between each element in `input` and `target`. The equation is: if reduction = "none": .. math:: output = |Target - Input| if reduction = "mean": .. math:: output = \\frac{1}{n}\\sum_{i=1}^n|Target_i - Input_i| if reduction = "sum": .. math:: output = \\sum_{i=1}^n|Target_i - Input_i| Args: input (oneflow.Tensor): the input Tensor. target (oneflow.Tensor): The target Tensor. reduction (str): The reduce type, it can be one of "none", "mean", "sum". Defaults to "mean". Returns: oneflow.Tensor: The result Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor([[1, 1, 1], [2, 2, 2], [7, 7, 7]], dtype = flow.float32) >>> target = flow.tensor([[4, 4, 4], [4, 4, 4], [4, 4, 4]], dtype = flow.float32) >>> m = flow.nn.L1Loss(reduction="none") >>> out = m(input, target) >>> out tensor([[3., 3., 3.], [2., 2., 2.], [3., 3., 3.]], dtype=oneflow.float32) >>> m_mean = flow.nn.L1Loss(reduction="mean") >>> out = m_mean(input, target) >>> out tensor(2.6667, dtype=oneflow.float32) >>> m_mean = flow.nn.L1Loss(reduction="sum") >>> out = m_mean(input, target) >>> out tensor(24., dtype=oneflow.float32) """ def __init__(self, reduction: str = "mean") -> None: super(L1Loss, self).__init__(reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: return flow._C.l1_loss(input, target, self.reduction) class CrossEntropyLoss(_WeightedLoss): r""" The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.CrossEntropyLoss.html. This criterion combines :class:`~flow.nn.LogSoftmax` and :class:`~flow.nn.NLLLoss` in one single class. It is useful when training a classification problem with `C` classes. If provided, the optional argument `weight` should be a 1D Tensor assigning weight to each of the classes. This is particularly useful when you have an unbalanced training set. The `input` is expected to contain raw, unnormalized scores for each class. `input` has to be a Tensor of size either :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the `K`-dimensional case (described later). 
The target that this criterion expects should contain either: - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if `ignore_index` is specified, this loss also accepts this class index (this index may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss for this case can be described as: .. math:: \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\} where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as :math:`d_1, ..., d_k` for the `K`-dimensional case. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then .. math:: \ell(x, y) = \begin{cases} \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}} l_n, & \text{if reduction} = \text{'mean';}\\ \sum_{n=1}^N l_n, & \text{if reduction} = \text{'sum'.} \end{cases} Note that this case is equivalent to the combination of :class:`~torch.nn.LogSoftmax` and :class:`~torch.nn.NLLLoss`. - Probabilities for each class; useful when labels beyond a single class per minibatch item are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss for this case can be described as: .. math:: \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c} where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as :math:`d_1, ..., d_k` for the `K`-dimensional case. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then .. math:: \ell(x, y) = \begin{cases} \frac{\sum_{n=1}^N l_n}{N}, & \text{if reduction} = \text{'mean';}\\ \sum_{n=1}^N l_n, & \text{if reduction} = \text{'sum'.} \end{cases} Args: weight (oneflow.Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size `C` ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. When ``reduction`` is ``mean``, the loss is averaged over non-ignored targets. Note that ``ignore_index`` is only applicable when the target contains class indices. reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the weighted mean of the output is taken, ``'sum'``: the output will be summed. Default: ``'mean'`` label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in `Rethinking the Inception Architecture for Computer Vision `_. Default: :math:`0.0`. Shape: - Input: Shape :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of `K`-dimensional loss. - Target: If containing class indices, shape :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`.
- Output: If reduction is 'none', same shape as the target. Otherwise, scalar. where: .. math:: \begin{aligned} C ={} & \text{number of classes} \\ N ={} & \text{batch size} \\ \end{aligned} For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor( ... [[-0.1664078, -1.7256707, -0.14690138], ... [-0.21474946, 0.53737473, 0.99684894], ... [-1.135804, -0.50371903, 0.7645404]], dtype=flow.float32) >>> target = flow.tensor(np.array([0, 1, 2]), dtype=flow.int32) >>> out = flow.nn.CrossEntropyLoss(reduction="none")(input, target) >>> out tensor([0.8020, 1.1167, 0.3583], dtype=oneflow.float32) >>> out_sum = flow.nn.CrossEntropyLoss(reduction="sum")(input, target) >>> out_sum tensor(2.2769, dtype=oneflow.float32) >>> out_mean = flow.nn.CrossEntropyLoss(reduction="mean")(input, target) >>> out_mean tensor(0.7590, dtype=oneflow.float32) >>> out_ignore_0 = flow.nn.CrossEntropyLoss(reduction="none", ignore_index=0)(input, target) >>> out_ignore_0 tensor([0.0000, 1.1167, 0.3583], dtype=oneflow.float32) >>> out_label_smoothing = flow.nn.CrossEntropyLoss(reduction="none", label_smoothing=0.5)(input, target) >>> out_label_smoothing tensor([1.0586, 1.1654, 0.8864], dtype=oneflow.float32) >>> probs = flow.tensor([[ 0.99495536, 0.28255007, -0.2775054 ], ... [ 0.42397153, 0.01075112, 0.56527734], ... [ 0.72356546, -0.1304398 , 0.4068744 ]], dtype=flow.float32) >>> out = flow.nn.CrossEntropyLoss()(input, probs) >>> out tensor(1.3305, dtype=oneflow.float32) """ def __init__( self, weight: Optional[Tensor] = None, ignore_index: int = -100, reduction: str = "mean", label_smoothing: float = 0.0, ) -> None: super(CrossEntropyLoss, self).__init__(weight, reduction) self.ignore_index = ignore_index self.label_smoothing = label_smoothing if self.label_smoothing < 0.0 or self.label_smoothing > 1.0: raise ValueError( "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing ) def forward(self, input, target): return flow._C.cross_entropy( input, target, self.weight, self.ignore_index, self.reduction, self.label_smoothing, ) class BCELoss(_WeightedLoss): """This operator computes the binary cross entropy loss. The equation is: if reduction = "none": .. math:: out = -(Target_i*log(Input_i) + (1-Target_i)*log(1-Input_i)) if reduction = "mean": .. math:: out = -\\frac{1}{n}\\sum_{i=1}^n(Target_i*log(Input_i) + (1-Target_i)*log(1-Input_i)) if reduction = "sum": .. math:: out = -\\sum_{i=1}^n(Target_i*log(Input_i) + (1-Target_i)*log(1-Input_i)) Args: weight (oneflow.Tensor, optional): The manual rescaling weight to the loss. Default to None, whose corresponding weight value is 1. reduction (str, optional): The reduce type, it can be one of "none", "mean", "sum". Defaults to "mean". Attention: The input value must be in the range of (0, 1). Or the loss function may return `nan` value. Returns: oneflow.Tensor: The result Tensor. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.Tensor(np.array([[1.2, 0.2, -0.3], [0.7, 0.6, -2]]).astype(np.float32)) >>> target = flow.Tensor(np.array([[0, 1, 0], [1, 0, 1]]).astype(np.float32)) >>> weight = flow.Tensor(np.array([[2, 2, 2], [2, 2, 2]]).astype(np.float32)) >>> activation = flow.nn.Sigmoid() >>> sigmoid_input = activation(input) >>> m = flow.nn.BCELoss(weight, reduction="none") >>> out = m(sigmoid_input, target) >>> out tensor([[2.9266, 1.1963, 1.1087], [0.8064, 2.0750, 4.2539]], dtype=oneflow.float32) >>> m_sum = flow.nn.BCELoss(weight, reduction="sum") >>> out = m_sum(sigmoid_input, target) >>> out tensor(12.3668, dtype=oneflow.float32) >>> m_mean = flow.nn.BCELoss(weight, reduction="mean") >>> out = m_mean(sigmoid_input, target) >>> out tensor(2.0611, dtype=oneflow.float32) >>> m_none = flow.nn.BCELoss() >>> out = m_none(sigmoid_input, target) >>> out tensor(1.0306, dtype=oneflow.float32) """ def __init__( self, weight: Optional[Tensor] = None, reduction: str = "mean" ) -> None: super(BCELoss, self).__init__(weight, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: return flow._C.binary_cross_entropy_loss( input, target, self.weight, self.reduction ) class NLLLoss(_WeightedLoss): """ The negative log likelihood loss. It is useful to train a classification problem with `C` classes. The `input` given through a forward call is expected to contain log-probabilities of each class. `input` has to be a Tensor of size either :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \\geq 1` for the `K`-dimensional case (described later). Obtaining log-probabilities in a neural network is easily achieved by adding a `LogSoftmax` layer in the last layer of your network. You may use `CrossEntropyLoss` instead, if you prefer not to add an extra layer. The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` where `C = number of classes`; The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: .. math:: \\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad l_n = - w_{y_n} x_{n,y_n}, \\quad w_{c} = \\mathbb{1}, where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then .. math:: \\ell(x, y) = \\begin{cases} \\sum_{n=1}^N \\frac{1}{N} l_n, & \\text{if reduction} = \\text{`mean';}\\\\ \\sum_{n=1}^N l_n, & \\text{if reduction} = \\text{`sum'.} \\end{cases} Can also be used for higher dimension inputs, such as 2D images, by providing an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \\geq 1`, where :math:`K` is the number of dimensions, and a target of appropriate shape (see below). In the case of images, it computes NLL loss per-pixel. Args: reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the weighted mean of the output is taken, ``'sum'``: the output will be summed. Default: ``'mean'`` For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor( ... [[-0.1664078, -1.7256707, -0.14690138], ... [-0.21474946, 0.53737473, 0.99684894], ... 
[-1.135804, -0.50371903, 0.7645404]], dtype=flow.float32) >>> target = flow.tensor(np.array([0, 1, 2]), dtype=flow.int32) >>> m = flow.nn.NLLLoss(reduction="none") >>> out = m(input, target) >>> out tensor([ 0.1664, -0.5374, -0.7645], dtype=oneflow.float32) >>> m = flow.nn.NLLLoss(reduction="sum") >>> out = m(input, target) >>> out tensor(-1.1355, dtype=oneflow.float32) >>> m = flow.nn.NLLLoss(reduction="mean") >>> out = m(input, target) >>> out tensor(-0.3785, dtype=oneflow.float32) """ def __init__( self, weight: Optional[Tensor] = None, ignore_index: int = -100, reduction: str = "mean", ) -> None: super(NLLLoss, self).__init__(weight, reduction) self.ignore_index = ignore_index def forward(self, input: Tensor, target: Tensor) -> Tensor: return flow._C.nll_loss( input, target, self.weight, self.ignore_index, self.reduction ) class KLDivLoss(_Loss): """ The Kullback-Leibler divergence loss measure `Kullback-Leibler divergence`_ is a useful distance measure for continuous distributions and is often useful when performing direct regression over the space of (discretely sampled) continuous output distributions. As with :class:`~torch.nn.NLLLoss`, the `input` given is expected to contain *log-probabilities* and is not restricted to a 2D Tensor. The targets are interpreted as *probabilities* by default, but could be considered as *log-probabilities* with :attr:`log_target` set to ``True``. This criterion expects a `target` `Tensor` of the same size as the `input` `Tensor`. The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: .. math:: l(x,y) = L = \\{ l_1,\\dots,l_N \\}, \\quad l_n = y_n \\cdot \\left( \\log y_n - x_n \\right) where the index :math:`N` spans all dimensions of ``input`` and :math:`L` has the same shape as ``input``. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then: .. math:: \\ell(x, y) = \\begin{cases} \\operatorname{mean}(L), & \\text{if reduction} = \\text{`mean';} \\\\ \\operatorname{sum}(L), & \\text{if reduction} = \\text{`sum'.} \\end{cases} In default :attr:`reduction` mode ``'mean'``, the losses are averaged for each minibatch over observations **as well as** over dimensions. ``'batchmean'`` mode gives the correct KL divergence where losses are averaged over batch dimension only. ``'mean'`` mode's behavior will be changed to the same as ``'batchmean'`` in the next major release. .. _`kullback-leibler divergence`: https://en.wikipedia.org/wiki/Kullback-Leibler_divergence The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.KLDivLoss.html. Args: reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. ``'none'``: no reduction will be applied. ``'batchmean'``: the sum of the output will be divided by batchsize. ``'sum'``: the output will be summed. ``'mean'``: the output will be divided by the number of elements in the output. Default: ``'mean'`` log_target (bool, optional): Specifies whether `target` is passed in the log space. Default: ``False`` .. note:: :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition. In the next major release, ``'mean'`` will be changed to be the same as ``'batchmean'``. Shape: - Input: :math:`(N, *)` where :math:`*` means, any number of additional dimensions - Target: :math:`(N, *)`, same shape as the input - Output: scalar by default. 
If :attr:``reduction`` is ``'none'``, then :math:`(N, *)`, the same shape as the input For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor([-0.9021705, 0.08798598, 1.04686249], dtype=flow.float32) >>> target = flow.tensor([1.22386942, -0.89729659, 0.01615712], dtype=flow.float32) >>> m = flow.nn.KLDivLoss(reduction="none", log_target=False) >>> out = m(input, target) >>> out tensor([ 1.3514, 0.0000, -0.0836], dtype=oneflow.float32) >>> m = flow.nn.KLDivLoss(reduction="mean", log_target=False) >>> out = m(input, target) >>> out tensor(0.4226, dtype=oneflow.float32) >>> m = flow.nn.KLDivLoss(reduction="sum", log_target=True) >>> out = m(input, target) >>> out tensor(5.7801, dtype=oneflow.float32) """ def __init__(self, reduction: str = "mean", log_target: bool = False) -> None: if reduction == "batchmean": super(KLDivLoss, self).__init__("sum") self.reduction = "batchmean" else: super(KLDivLoss, self).__init__(reduction) self.log_target = log_target def forward(self, input: Tensor, target: Tensor) -> Tensor: return flow._C.kl_div_loss(input, target, self.log_target, self.reduction) class MSELoss(_Loss): """ Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input :math:`x` and target :math:`y`. The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: .. math:: \\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad l_n = \\left( x_n - y_n \\right)^2, where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then: .. math:: \\ell(x, y) = \\begin{cases} \\operatorname{mean}(L), & \\text{if reduction} = \\text{`mean';}\\\\ \\operatorname{sum}(L), & \\text{if reduction} = \\text{`sum'.} \\end{cases} :math:`x` and :math:`y` are tensors of arbitrary shapes with a total of :math:`n` elements each. The mean operation still operates over all the elements, and divides by :math:`n`. The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MSELoss.html. Args: reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` Shape: - Input: :math:`(N, *)` where :math:`*` means, any number of additional dimensions - Target: :math:`(N, *)`, same shape as the input For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor( ... [[-0.02557137, 0.03101675, 1.37493674], ... [0.25599439, -1.08372561, -0.21006816]], dtype=flow.float32) >>> target = flow.tensor( ... [[-1.53105064, -0.68137555, 0.5931354], ... 
[-0.49158347, 0.93673637, 0.1324141]], dtype=flow.float32) >>> m = flow.nn.MSELoss(reduction="none") >>> out = m(input, target) >>> out tensor([[2.2665, 0.5075, 0.6112], [0.5589, 4.0823, 0.1173]], dtype=oneflow.float32) >>> m = flow.nn.MSELoss(reduction="mean") >>> out = m(input, target) >>> out tensor(1.3573, dtype=oneflow.float32) >>> m = flow.nn.MSELoss(reduction="sum") >>> out = m(input, target) >>> out tensor(8.1436, dtype=oneflow.float32) """ def __init__(self, reduction: str = "mean") -> None: super(MSELoss, self).__init__(reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: return flow._C.mse_loss(input, target, self.reduction) class MarginRankingLoss(_Loss): """Creates a criterion that measures the loss given inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`, and a label 1D mini-batch tensor :math:`y` (containing 1 or -1). If :math:`y = 1` then it assumed the first input should be ranked higher (have a larger value) than the second input, and vice-versa for :math:`y = -1`. The loss function for each sample in the mini-batch is: .. math:: \\text{loss}(x1, x2, y) = \\max(0, -y * (x1 - x2) + \\text{margin}) Args: margin (float, optional): Has a default value of :math:`0`. reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` Shape: - `x1` : :math:`(N, D)` where `N` is the batch size and `D` is the size of a sample. - `x2` : :math:`(N, D)` where `N` is the batch size and `D` is the size of a sample. - Target: :math:`(N)` - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=flow.float32) >>> x2 = flow.tensor(np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]]), dtype=flow.float32) >>> target = flow.tensor(np.array([[1, -1, 1],[-1, 1, -1], [1, 1, 1]]), dtype=flow.float32) >>> m = flow.nn.MarginRankingLoss(margin =1.0, reduction="none") >>> out = m(x1, x2, target) >>> out tensor([[2., 1., 0.], [3., 0., 5.], [0., 0., 0.]], dtype=oneflow.float32) >>> m = flow.nn.MarginRankingLoss(margin = 0.3, reduction="sum") >>> out = m(x1, x2, target) >>> out tensor(8.2000, dtype=oneflow.float32) >>> m = flow.nn.MarginRankingLoss(margin = 10, reduction="mean") >>> out = m(x1, x2, target) >>> out tensor(8.3333, dtype=oneflow.float32) """ def __init__(self, margin: float = 0.0, reduction: str = "mean") -> None: super(MarginRankingLoss, self).__init__(reduction) self.margin = margin def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: return flow._C.margin_ranking_loss( input1, input2, target, self.margin, self.reduction ) class CTCLoss(_Loss): """The Connectionist Temporal Classification loss. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.CTCLoss.html. Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the probability of possible alignments of input to target, producing a loss value which is differentiable with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which limits the length of the target sequence such that it must be :math:`\\leq` the input length. 
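Targets may also be supplied in the concatenated 1-D form described in the `Shape` section below; the `forward` implementation further down handles both layouts. A small sketch of the 1-D form (illustrative only; in real use `log_probs` should be log-probabilities, e.g. the output of a log-softmax):

.. code-block:: python

    import oneflow as flow

    T, N, C = 5, 2, 3
    log_probs = flow.randn(T, N, C)  # placeholder values, shape (T, N, C)
    # Concatenated targets: target_lengths must sum to targets.numel().
    targets = flow.tensor([1, 2, 2, 1, 2, 2], dtype=flow.int32)
    input_lengths = flow.tensor([5, 5], dtype=flow.int32)
    target_lengths = flow.tensor([3, 3], dtype=flow.int32)

    loss = flow.nn.CTCLoss()(log_probs, targets, input_lengths, target_lengths)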
Args: blank (int, optional): blank label. Default :math:`0`. reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the output losses will be divided by the target lengths and then the mean over the batch is taken. Default: ``'mean'`` zero_infinity (bool, optional): Whether to zero infinite losses and the associated gradients. Default: ``False`` Infinite losses mainly occur when the inputs are too short to be aligned to the targets. Shape: - Log_probs: Tensor of size :math:`(T, N, C)`, where :math:`T = \\text{input length}`, :math:`N = \\text{batch size}`, and :math:`C = \\text{number of classes (including blank)}`. - Targets: Tensor of size :math:`(N, S)` or :math:`(\\operatorname{sum}(\\text{target_lengths}))`, where :math:`N = \\text{batch size}` and :math:`S = \\text{max target length, if shape is } (N, S)`. It represent the target sequences. Each element in the target sequence is a class index. And the target index cannot be blank (default=0). In the :math:`(N, S)` form, targets are padded to the length of the longest sequence, and stacked. In the :math:`(\\operatorname{sum}(\\text{target_lengths}))` form, the targets are assumed to be un-padded and concatenated within 1 dimension. - Input_lengths: Tuple or tensor of size :math:`(N)`, where :math:`N = \\text{batch size}`. It represent the lengths of the inputs (must each be :math:`\\leq T`). And the lengths are specified for each sequence to achieve masking under the assumption that sequences are padded to equal lengths. - Target_lengths: Tuple or tensor of size :math:`(N)`, where :math:`N = \\text{batch size}`. It represent lengths of the targets. Lengths are specified for each sequence to achieve masking under the assumption that sequences are padded to equal lengths. If target shape is :math:`(N,S)`, target_lengths are effectively the stop index :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for each target in a batch. Lengths must each be :math:`\\leq S` If the targets are given as a 1d tensor that is the concatenation of individual targets, the target_lengths must add up to the total length of the tensor. Reference: A. Graves et al.: Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks: https://www.cs.toronto.edu/~graves/icml_2006.pdf For example: .. code-block:: python >>> import oneflow as flow >>> log_probs = flow.tensor( ... [ ... [[-1.1031, -0.7998, -1.5200], [-0.9808, -1.1363, -1.1908]], ... [[-1.2258, -1.0665, -1.0153], [-1.1135, -1.2331, -0.9671]], ... [[-1.3348, -0.6611, -1.5118], [-0.9823, -1.2355, -1.0941]], ... [[-1.3850, -1.3273, -0.7247], [-0.8235, -1.4783, -1.0994]], ... [[-0.9049, -0.8867, -1.6962], [-1.4938, -1.3630, -0.6547]], ... 
], dtype=flow.float32) >>> targets = flow.tensor([[1, 2, 2], [1, 2, 2]], dtype=flow.int32) >>> input_lengths = flow.tensor([5, 5], dtype=flow.int32) >>> target_lengths = flow.tensor([3, 3], dtype=flow.int32) >>> loss_mean = flow.nn.CTCLoss() >>> out = loss_mean(log_probs, targets, input_lengths, target_lengths) >>> out tensor(1.1376, dtype=oneflow.float32) >>> loss_sum = flow.nn.CTCLoss(blank=0, reduction="sum") >>> out = loss_sum(log_probs, targets, input_lengths, target_lengths) >>> out tensor(6.8257, dtype=oneflow.float32) """ def __init__( self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False ) -> None: super(CTCLoss, self).__init__(reduction) self.blank = blank self.zero_infinity = zero_infinity def forward( self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, ) -> Tensor: max_target_length = 0 if targets.ndim == 1: max_target_length = target_lengths.max().item() elif targets.ndim == 2: max_target_length = targets.shape[1] return flow._C.ctc_loss( log_probs, targets, input_lengths, target_lengths, max_target_length, self.blank, self.zero_infinity, self.reduction, ) class BCEWithLogitsLoss(_WeightedLoss): """This operator combines the `Sigmoid` and `BCELoss` together. For numerical stability, we apply some math tricks instead of using `Sigmoid` layer with `BCELoss`. The equation is: if :attr:`reduction` = ``"none"``: .. math:: out = -weight*[Pos\\_weight*y*log\\sigma({x}) + (1-y)*log(1-\\sigma(x))] if :attr:`reduction` = ``"mean"``: .. math:: out = -\\frac{weight}{n}\\sum_{i=1}^n[Pos\\_weight*y*log\\sigma({x}) + (1-y)*log(1-\\sigma(x))] if :attr:`reduction` = ``"sum"``: .. math:: out = -weight*\\sum_{i=1}^n[Pos\\_weight*y*log\\sigma({x}) + (1-y)*log(1-\\sigma(x))] Args: weight (Tensor, optional): The manual rescaling weight to the loss. Default: ``None`` size_average (bool, optional): Deprecated (see :attr:`reduction`). Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). Default: ``True`` reduction (str, optional): The reduce type, it can be one of ``"none"``, ``"mean"``, ``"sum"``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``"mean"`` pos_weight (Tensor, optional): The manual rescaling weight to the positive examples. Default: ``None`` Shape: - Input: :math:`(N,*)` where `*` means, any number of additional dimensions - Target: :math:`(N,*)`, same shape as the input - Output: scalar. If :attr:`reduction` is ``"none"``, then :math:`(N,*)`, same shape as input. For example: .. 
code-block:: python >>> import oneflow as flow >>> input = flow.tensor([[1.2, 0.2, -0.3], [0.7, 0.6, -2], [0.7, 0.6, -2]], dtype=flow.float32) >>> target = flow.tensor([[0, 1, 0], [1, 0, 1], [1, 0, 1]], dtype=flow.float32) >>> weight = flow.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]], dtype=flow.float32) >>> pos_weight = flow.tensor([1.2, 1.3, 1.4], dtype=flow.float32) >>> m = flow.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight, reduction="none") >>> out = m(input, target) >>> out tensor([[2.9266, 1.5552, 1.1087], [0.9676, 2.0750, 5.9554], [0.9676, 2.0750, 5.9554]], dtype=oneflow.float32) >>> m = flow.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight, reduction="mean") >>> out = m(input, target) >>> out tensor(2.6207, dtype=oneflow.float32) >>> m = flow.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight, reduction="sum") >>> out = m(input, target) >>> out tensor(23.5865, dtype=oneflow.float32) """ def __init__( self, weight: Optional[Tensor] = None, reduction: str = "mean", pos_weight: Optional[Tensor] = None, ) -> None: super(BCEWithLogitsLoss, self).__init__(weight, reduction) self.reduction = reduction self.pos_weight = pos_weight def forward(self, input: Tensor, target: Tensor) -> Tensor: return flow._C.binary_cross_entropy_with_logits_loss( input, target, self.weight, self.pos_weight, self.reduction ) class SmoothL1Loss(_Loss): """Creates a criterion that uses a squared term if the absolute element-wise error falls below beta and an L1 term otherwise. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.SmoothL1Loss.html. It is less sensitive to outliers than :class:`torch.nn.MSELoss` and in some cases prevents exploding gradients (e.g. see the paper `Fast R-CNN `__ by Ross Girshick).. For a batch of size :math:`N`, the unreduced loss can be described as: .. math:: \\ell(x, y) = L = \\{l_1, ..., l_N\\}^T with .. math:: l_n = \\begin{cases} 0.5 (x_n - y_n)^2 / beta, & \\text{if } |x_n - y_n| < beta \\\\ |x_n - y_n| - 0.5 * beta, & \\text{otherwise } \\end{cases} If `reduction` is not `none`, then: .. math:: \\ell(x, y) = \\begin{cases} \\operatorname{mean}(L), & \\text{if reduction} = \\text{`mean';}\\\\ \\operatorname{sum}(L), & \\text{if reduction} = \\text{`sum'.} \\end{cases} .. note:: Smooth L1 loss can be seen as exactly :class:`L1Loss`, but with the :math:`|x - y| < beta` portion replaced with a quadratic function such that its slope is 1 at :math:`|x - y| = beta`. The quadratic segment smooths the L1 loss near :math:`|x - y| = 0`. .. note:: Smooth L1 loss is closely related to :class:`HuberLoss`, being equivalent to :math:`huber(x, y) / beta` (note that Smooth L1's beta hyper-parameter is also known as delta for Huber). This leads to the following differences: * As beta -> 0, Smooth L1 loss converges to :class:`L1Loss`, while :class:`HuberLoss` converges to a constant 0 loss. * As beta -> :math:`+\\infty`, Smooth L1 loss converges to a constant 0 loss, while :class:`HuberLoss` converges to :class:`MSELoss`. * For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1. For :class:`HuberLoss`, the slope of the L1 segment is beta. Args: size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. 
Ignored when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss. The value must be non-negative. Default: 1.0 Shape: - Input: :math:`(N, *)` where :math:`*` means any number of additional dimensions - Target: :math:`(N, *)`; same shape as the input - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`; same shape as the input For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.array([0.1, 0.4, 0.3, 0.5, 0.9]).astype(np.float32), dtype=flow.float32) >>> y = flow.tensor(np.array([0.3, 0.9, 2.5, 0.4, 0.3]).astype(np.float32), dtype=flow.float32) >>> m = flow.nn.SmoothL1Loss(reduction="none") >>> out = m(x, y) >>> out tensor([0.0200, 0.1250, 1.7000, 0.0050, 0.1800], dtype=oneflow.float32) >>> m = flow.nn.SmoothL1Loss(reduction="mean") >>> out = m(x, y) >>> out tensor(0.4060, dtype=oneflow.float32) >>> m = flow.nn.SmoothL1Loss(reduction="sum") >>> out = m(x, y) >>> out tensor(2.0300, dtype=oneflow.float32) """ def __init__(self, reduction: str = "mean", beta: float = 1.0) -> None: super(SmoothL1Loss, self).__init__(reduction) self.beta = beta def forward(self, input: Tensor, target: Tensor) -> Tensor: return flow._C.smooth_l1_loss(input, target, self.beta, self.reduction) class CombinedMarginLoss(Module): r"""The operation implements "margin_softmax" in InsightFace: https://github.com/deepinsight/insightface/blob/master/recognition/arcface_mxnet/train.py The implementation of margin_softmax in InsightFace is composed of multiple operators. We fuse them for speed up. Applies the function: .. math:: {\rm CombinedMarginLoss}(x_i, label) = \left\{\begin{matrix} \cos(m_1\cdot\arccos x_i+m_2) - m_3 & {\rm if} \ i == label \\ x_i & {\rm otherwise} \end{matrix}\right. Args: x (oneflow.Tensor): A Tensor label (oneflow.Tensor): label with integer data type m1 (float): loss m1 parameter m2 (float): loss m2 parameter m3 (float): loss m3 parameter .. note:: Here are some special cases: - when :math:`m_1=1, m_2\neq 0, m_3=0`, CombineMarginLoss has the same parameter as `ArcFace `__ . - when :math:`m_1=1, m_2=0, m_3\neq 0`, CombineMarginLoss has the same parameter as `CosFace (a.k.a AM-Softmax) `__ . - when :math:`m_1\gt 1, m_2=m_3=0`, CombineMarginLoss has the same parameter as `A-Softmax `__. Returns: oneflow.Tensor: A Tensor For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> np_x = np.array([[-0.7027179, 0.0230609], [-0.02721931, -0.16056311], [-0.4565852, -0.64471215]]) >>> np_label = np.array([0, 1, 1]) >>> x = flow.tensor(np_x, dtype=flow.float32) >>> label = flow.tensor(np_label, dtype=flow.int32) >>> loss_func = flow.nn.CombinedMarginLoss(0.3, 0.5, 0.4) >>> out = loss_func(x, label) >>> out tensor([[-0.0423, 0.0231], [-0.0272, 0.1237], [-0.4566, -0.0204]], dtype=oneflow.float32) """ def __init__(self, m1: float = 1.0, m2: float = 0.0, m3: float = 0.0) -> None: super().__init__() self.m1 = m1 self.m2 = m2 self.m3 = m3 def forward(self, x: Tensor, label: Tensor) -> Tensor: return flow._C.combined_margin_loss( x, label, m1=self.m1, m2=self.m2, m3=self.m3 ) class TripletMarginLoss(Module): r"""Creates a criterion that measures the triplet loss given an input tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. This is used for measuring a relative similarity between samples. A triplet is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative examples` respectively). The shapes of all input tensors should be :math:`(N, D)`. The distance swap is described in detail in the paper `Learning shallow convolutional feature descriptors with triplet losses `__ by V. Balntas, E. Riba et al. The loss function for each sample in the mini-batch is: .. math:: L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} where .. math:: d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p Args: margin (float, optional): Default: :math:`1`. p (float, optional): The norm degree for pairwise distance. Default: :math:`2.0`. swap (bool, optional): The distance swap is described in detail in the paper `Learning shallow convolutional feature descriptors with triplet losses` by V. Balntas, E. Riba et al. Default: ``False``. reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` Shape: - Input: :math:`(N, D)` where :math:`D` is the vector dimension. - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar otherwise. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> triplet_loss = flow.nn.TripletMarginLoss(margin=1.0, p=2) >>> anchor = np.array([[1, -1, 1],[-1, 1, -1], [1, 1, 1]]) >>> positive = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) >>> negative = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]]) >>> output = triplet_loss(flow.Tensor(anchor), flow.Tensor(positive), flow.Tensor(negative)) >>> output tensor(6.2971, dtype=oneflow.float32) """ def __init__( self, margin: float = 1.0, p: float = 2.0, eps: float = 1e-6, swap: bool = False, size_average=None, reduce=None, reduction: str = "mean", ) -> None: super().__init__() self.margin = margin self.p = p self.eps = eps self.swap = swap self.reduction = reduction def forward(self, anchor, positive, negative): triplet_loss = flow._C.triplet_margin_loss( anchor, positive, negative, margin=self.margin, p=self.p, eps=self.eps, swap=self.swap, reduction=self.reduction, ) return triplet_loss if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/masked_select.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def masked_select_op(input, mask): """ Returns a new 1-D tensor which indexes the input tensor according to the boolean mask ``mask``, which is a BoolTensor (in OneFlow, BoolTensor is replaced by Int8Tensor). The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable. Args: input (Tensor): the input tensor. mask (Tensor): the tensor containing the binary mask to index with For example: ..
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), dtype=flow.float32) >>> mask = input.gt(0.05) >>> out = flow.masked_select(input, mask) >>> out tensor([0.3139, 0.3898], dtype=oneflow.float32) """ assert input.is_global == mask.is_global, ( f"input tensor is %s tensor, but mask is %s tensor" % ( "global" if input.is_global else "local", "global" if mask.is_global else "local", ) ) broadcast_shape = [] input_shape_len = len(input.shape) mask_shape_len = len(mask.shape) input_shape = [input.shape[i] for i in range(input_shape_len)] input_shape.reverse() mask_shape = [mask.shape[i] for i in range(mask_shape_len)] mask_shape.reverse() for i in range(max(input_shape_len, mask_shape_len)): if i < input_shape_len and i < mask_shape_len: broadcast_shape.append(max(input_shape[i], mask_shape[i])) elif i < input_shape_len: broadcast_shape.append(input_shape[i]) else: broadcast_shape.append(mask_shape[i]) broadcast_shape.reverse() broadcast_input = input.expand(broadcast_shape) broadcast_mask = mask.expand(broadcast_shape) indices = flow.argwhere(broadcast_mask) gather_res = flow._C.gather_nd(broadcast_input, indices) return gather_res.flatten() if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/math_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import collections from typing import Optional, Sequence, Union import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module from oneflow.nn.modules.utils import _check_axis from oneflow.ops.transpose_util import ( get_inversed_perm, get_perm_when_transpose_axis_to_last_dim, ) def asin_op(input): """ Returns a new tensor with the arcsine of the elements of :attr:`input`. .. math:: \\text{out}_{i} = \\sin^{-1}(\\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([-0.5, 0.8, 1.0, -0.8]), dtype=flow.float32) >>> output = flow.asin(input) >>> output.shape oneflow.Size([4]) >>> output tensor([-0.5236, 0.9273, 1.5708, -0.9273], dtype=oneflow.float32) >>> input1 = flow.tensor(np.array([[0.8, 1.0], [-0.6, -1.0]]), dtype=flow.float32) >>> output1 = input1.asin() >>> output1.shape oneflow.Size([2, 2]) >>> output1 tensor([[ 0.9273, 1.5708], [-0.6435, -1.5708]], dtype=oneflow.float32) """ return flow._C.asin(input) def arcsin_op(input): """ Alias for :func:`oneflow.asin` """ return flow._C.asin(input) def asinh_op(input): """ Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. .. math:: \\text{out}_{i} = \\sinh^{-1}(\\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([2, 3, 4]), dtype=flow.float32) >>> output = flow.asinh(input) >>> output.shape oneflow.Size([3]) >>> output tensor([1.4436, 1.8184, 2.0947], dtype=oneflow.float32) >>> input1 = flow.tensor(np.array([[-1, 0, -0.4], [5, 7, 0.8]]), dtype=flow.float32) >>> output1 = input1.asinh() >>> output1.shape oneflow.Size([2, 3]) >>> output1 tensor([[-0.8814, 0.0000, -0.3900], [ 2.3124, 2.6441, 0.7327]], dtype=oneflow.float32) """ return flow._C.asinh(input) def arcsinh_op(input): """ Alias for :func:`oneflow.asinh` """ return flow._C.asinh(input) def asinh_op_tensor(input): """ See :func:`oneflow.asinh` """ return flow._C.asinh(input) def inplace_sin_op_tensor(input): """ In-place version of :func:`oneflow.sin` """ return flow._C.sin_(input) def atan_op(input): """ Returns a new tensor with the arctangent of the elements of :attr:`input`. .. math:: \\text{out}_{i} = \\tan^{-1}(\\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([0.5, 0.6, 0.7]), dtype=flow.float32) >>> output = flow.atan(input) >>> output.shape oneflow.Size([3]) """ return flow._C.atan(input) def arctan_op(input): """ Alias for :func:`oneflow.atan` """ return flow._C.atan(input) def fmod_op(input, other): """ fmod(input, other, *, out=None) -> Tensor Computes the element-wise remainder of division. The dividend and divisor may contain both integer and floating point numbers. The remainder has the same sign as the dividend :attr:`input`. Supports broadcasting to a common shape, integer and float inputs. Args: input (Tensor): the dividend other (Tensor or Scalar): the divisor Keyword args: out (Tensor, optional): the output tensor. Example:: >>> import oneflow as flow >>> flow.fmod(flow.tensor([-3., -2, -1, 1, 2, 3]), 2.) tensor([-1., -0., -1., 1., 0., 1.], dtype=oneflow.float32) >>> flow.fmod(flow.tensor([1, 2, 3, 4, 5.]), 1.5) tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000], dtype=oneflow.float32) >>> flow.fmod(flow.tensor([1, 2, 3, 4., -5]), flow.tensor([4, 2, 1, 3., 1])) tensor([1., 0., 0., 1., -0.], dtype=oneflow.float32) """ return flow._C.fmod(input, other) def addmm(x, mat1, mat2, alpha=1, beta=1): if len(x.shape) > 2 or len(mat1.shape) > 2 or len(mat2.shape) > 2: raise ValueError("input matrices cannot have more than 2 dimensions") else: return flow.mul(x, beta) + flow.mul(flow._C.matmul(mat1, mat2), alpha) def addmm_op(input, mat1, mat2, alpha=1, beta=1): """addmm(beta=1, input, alpha=1, mat1, mat2, out=None) -> Tensor Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. The matrix :attr:`input` is added to the final result. If :attr:`mat1` is a :math:`(n \\times m)` tensor, :attr:`mat2` is a :math:`(m \\times p)` tensor, then :attr:`input` must be broadcastable with a :math:`(n \\times p)` tensor and :attr:`out` will be a :math:`(n \\times p)` tensor. :attr:`alpha` and :attr:`beta` are scaling factors on the matrix-matrix product between :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. .. math:: \\text{out} = \\beta\\ \\text{input} + \\alpha\\ (\\text{mat1} \\mathbin{@} \\text{mat2}) For inputs of type `float` or `double`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers.
Args: beta (Number, optional): multiplier for :attr:`input` (:math:`\\beta`) input (Tensor): matrix to be added alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\\alpha`) mat1 (Tensor): the first matrix to be multiplied mat2 (Tensor): the second matrix to be multiplied out (Tensor, optional): the output tensor. For example: >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor(np.array([[1,2,4],[5,11,9.1]])) >>> mat1 = flow.tensor(np.array([[7.3,1.9,7.3],[10.2,1,5.5]])) >>> mat2 = flow.tensor(np.array([[7.3,1.9,7.3],[10.2,1,5.5],[3.7,2.2,8.1]])) >>> output = flow.addmm(input, mat1, mat2) >>> output tensor([[100.6800, 33.8300, 126.8700], [110.0100, 43.4800, 133.6100]], dtype=oneflow.float64) >>> output.shape oneflow.Size([2, 3]) >>> input2 = flow.tensor(np.array([1.7])) >>> mat1 = flow.tensor(np.array([[1,2],[5,9.1],[7.7,1.4]])) >>> mat2 = flow.tensor(np.array([[1,2,3.7],[5,9.1,6.8]])) >>> output2 = flow.addmm(input2, mat1, mat2, alpha=1, beta=2) >>> output2 tensor([[14.4000, 23.6000, 20.7000], [53.9000, 96.2100, 83.7800], [18.1000, 31.5400, 41.4100]], dtype=oneflow.float64) >>> output2.shape oneflow.Size([3, 3]) """ return addmm(input, mat1, mat2, alpha, beta) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/meshgrid.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def meshgrid_op(*tensors, indexing="ij"): if isinstance(tensors[0], (list, tuple)): return flow._C.meshgrid(tensors[0], indexing) return flow._C.meshgrid(tensors, indexing) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/min_max_observer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module class MinMaxObserver(Module): """ Compute the quantization parameters of the input tensor. First compute the max and min values of input tensor: .. math:: & max\\_value = max(input) & min\\_value = min(input) Then compute the scale and zero_point with the following equations: if quantization_scheme == "symmetric": .. 
math:: & denom = 2^{quantization\\_to\\_bit - 1} - 1 & scale = max(|max\\_value|,|min\\_value|) / denom & zero\\_point = 0 elif quantization_scheme == "affine": .. math:: & denom = 2^{quantization\\_to\\_bit} - 1 & scale = (max\\_value - min\\_value) / denom & zero\\_point = -min\\_value / scale If per_layer_quantization is False, then the shape of scale and zero_point will be (input.shape[0],). Args: input(oneflow.Tensor): the input value(s), in ``oneflow.float32``. quantization_formula (str): Support "google" or "cambricon". quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8. quantization_scheme (str): "symmetric" or "affine", quantize to signed / unsigned integer. Defaults to "symmetric". per_layer_quantization (bool): True or False, means per-layer / per-channel quantization. Defaults to True. Returns: Tuple[oneflow.Tensor, oneflow.Tensor]: The scale and zero_point of input tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32) >>> input_tensor = flow.tensor( ... weight, dtype=flow.float32 ... ) >>> quantization_bit = 8 >>> quantization_scheme = "symmetric" >>> quantization_formula = "google" >>> per_layer_quantization = True >>> min_max_observer = flow.nn.MinMaxObserver(quantization_formula=quantization_formula, quantization_bit=quantization_bit, ... quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization) >>> scale, zero_point = min_max_observer( ... input_tensor, ) """ def __init__( self, quantization_formula: str = "google", quantization_bit: int = 8, quantization_scheme: str = "symmetric", per_layer_quantization: bool = True, ) -> None: super().__init__() self.quantization_formula = quantization_formula self.quantization_bit = quantization_bit self.quantization_scheme = quantization_scheme self.per_layer_quantization = per_layer_quantization def forward(self, input): return flow._C.min_max_observer( input, self.quantization_formula, self.quantization_bit, self.quantization_scheme, self.per_layer_quantization, ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/module.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import itertools from collections import OrderedDict, namedtuple from typing import ( Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, TypeVar, Union, overload, ) import traceback import functools import weakref import warnings import numpy as np import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.parameter import Parameter from contextlib import contextmanager class _WrappedHook(object): def __init__(self, hook: Callable, module: Optional["Module"] = None): self.hook: Callable = hook functools.update_wrapper(self, hook) self.with_module: bool = False if module is not None: self.module: weakref.ReferenceType["Module"] = weakref.ref(module) self.with_module = True def __call__(self, *args: Any, **kwargs: Any) -> Any: if self.with_module: module = self.module() if module is None: raise RuntimeError("You are trying to call the hook of a dead Module!") return self.hook(module, *args, **kwargs) return self.hook(*args, **kwargs) def __getstate__(self) -> Dict: result = {"hook": self.hook, "with_module": self.with_module} if self.with_module: result["module"] = self.module() return result def __setstate__(self, state: Dict): self.hook = state["hook"] self.with_module = state["with_module"] if self.with_module: if state["module"] is None: raise RuntimeError( "You are trying to revive the hook of a dead Module!" ) self.module = weakref.ref(state["module"]) class _IncompatibleKeys( namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]) ): def __repr__(self): if not self.missing_keys and (not self.unexpected_keys): return "" return super(_IncompatibleKeys, self).__repr__() __str__ = __repr__ def _addindent(s_, numSpaces): s = s_.split("\n") if len(s) == 1: return s_ first = s.pop(0) s = [numSpaces * " " + line for line in s] s = "\n".join(s) s = first + "\n" + s return s T = TypeVar("T", bound="Module") class Module(object): r"""Base class for all neural network modules. This class is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Module.html. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import oneflow.nn as nn import oneflow.nn.functional as F class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc. .. note:: As per the example above, an ``__init__()`` call to the parent class must be made before assignment on the child. :ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool """ def __init__(self): """ Calls super().__setattr__('a', a) instead of the typical self.a = a to avoid Module.__setattr__ overhead. Module's __setattr__ has special handling for parameters, submodules, and buffers but simply calls into super().__setattr__ for all other attributes. 
""" super().__setattr__("training", True) super().__setattr__("_parameters", OrderedDict()) super().__setattr__("_buffers", OrderedDict()) super().__setattr__("_non_persistent_buffers_set", set()) super().__setattr__("_backward_hooks", OrderedDict()) super().__setattr__("_is_full_backward_hook", None) super().__setattr__("_forward_hooks", OrderedDict()) super().__setattr__("_forward_pre_hooks", OrderedDict()) super().__setattr__("_state_dict_hooks", OrderedDict()) super().__setattr__("_load_state_dict_pre_hooks", OrderedDict()) super().__setattr__("_modules", OrderedDict()) super().__setattr__("_is_ddp_module", False) super().__setattr__("_oneflow_internal_module_tensor_applied_dict__", None) super().__setattr__("cpg", None) def __getstate__(self): if not self._is_ddp_module: if ( len(self._backward_hooks) > 0 or len(self._forward_hooks) > 0 or len(self._forward_pre_hooks) > 0 or len(self._state_dict_hooks) > 0 or len(self._load_state_dict_pre_hooks) > 0 ): warnings.warn("The module hooks will not be remained after serializing") state = self.__dict__.copy() del state["_backward_hooks"] del state["_forward_hooks"] del state["_forward_pre_hooks"] del state["_state_dict_hooks"] del state["_load_state_dict_pre_hooks"] del state["_is_full_backward_hook"] del state["_non_persistent_buffers_set"] return state def __setstate__(self, state): self.__dict__.update(state) self._backward_hooks = OrderedDict() self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._state_dict_hooks = OrderedDict() self._load_state_dict_pre_hooks = OrderedDict() self._is_full_backward_hook = None self._non_persistent_buffers_set = set() if hasattr(self, "_is_ddp_module") and self._is_ddp_module: # flow.nn.parallel.DistributedDataParallel updates the module inplace flow.nn.parallel.DistributedDataParallel(self, broadcast_parameters=False) def forward(self, *args, **kwargs): raise NotImplementedError() def __call__(self, *args, **kwargs): if flow._oneflow_internal.lazy_mode.is_enabled(): warnings.warn( self._shallow_repr() + " is called in a nn.Graph, but not registered into a nn.Graph." ) full_backward_hooks, non_full_backward_hooks = [], [] if self._backward_hooks: full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() for hook in list(self._forward_pre_hooks.values()): result = hook(self, args) if result is not None: if not isinstance(result, tuple): result = (result,) args = result bw_hook = None if full_backward_hooks: bw_hook = flow.utils.hooks.BackwardHook(self, full_backward_hooks, []) args = bw_hook.setup_input_hook(args) res = self.forward(*args, **kwargs) for hook in list(self._forward_hooks.values()): result = hook(self, args, res) if result is not None: res = result if bw_hook is not None: res = bw_hook.setup_output_hook(res) if non_full_backward_hooks: var = res while not isinstance(var, Tensor): if isinstance(var, dict): var = next((v for v in var.values() if isinstance(v, Tensor))) else: var = var[0] grad_fn = var.grad_fn if grad_fn is not None: self._maybe_warn_non_full_backward_hook(args, res, grad_fn) for hook in non_full_backward_hooks: wrapper = functools.partial(hook, self) functools.update_wrapper(wrapper, hook) grad_fn.register_hook(wrapper) return res def add_module(self, name: str, module: Optional["Module"]) -> None: r""" add_module(name, module) Adds a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (string): name of the child module. 
The child module can be accessed from this module using the given name module (Module): child module to be added to the module. """ if not isinstance(module, Module) and module is not None: raise TypeError("{} is not a Module subclass".format(type(module))) elif not isinstance(name, str): raise TypeError("module name should be a string. Got {}".format(type(name))) elif hasattr(self, name) and name not in self._modules: raise KeyError("attribute '{}' already exists".format(name)) elif "." in name: raise KeyError('module name can\'t contain ".", got: {}'.format(name)) elif name == "": raise KeyError('module name can\'t be empty string ""') self._modules[name] = module def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: r""" register_buffer(name, tensor, persistent=True) Adds a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (string): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> self.register_buffer('running_mean', oneflow.zeros(num_features)) # doctest: +SKIP """ if "_buffers" not in self.__dict__: raise AttributeError("cannot assign buffer before Module.__init__() call") elif not isinstance(name, str): raise TypeError("buffer name should be a string. Got {}".format(type(name))) elif "." in name: raise KeyError('buffer name can\'t contain "."') elif name == "": raise KeyError('buffer name can\'t be empty string ""') elif hasattr(self, name) and name not in self._buffers: raise KeyError("attribute '{}' already exists".format(name)) elif tensor is not None and (not isinstance(tensor, Tensor)): raise TypeError( "cannot assign '{}' object to buffer '{}' (Tensor or None required)".format( type(tensor), name ) ) else: self._buffers[name] = tensor if persistent: self._non_persistent_buffers_set.discard(name) else: self._non_persistent_buffers_set.add(name) def register_parameter(self, name: str, param: Optional[Parameter]) -> None: r""" register_parameter(name, param) Adds a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (string): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. """ if "_parameters" not in self.__dict__: raise AttributeError( "cannot assign parameter before Module.__init__() call" ) elif not isinstance(name, str): raise TypeError( "parameter name should be a string. Got {}".format(type(name)) ) elif "." 
in name: raise KeyError('parameter name can\'t contain "."') elif name == "": raise KeyError('parameter name can\'t be empty string ""') elif hasattr(self, name) and name not in self._parameters: raise KeyError("attribute '{}' already exists".format(name)) if param is None: self._parameters[name] = None elif not isinstance(param, Parameter): raise TypeError( "cannot assign '{}' object to parameter '{}' (nn.Parameter or None required)".format( type(param), name ) ) else: self._parameters[name] = param def _register_state_dict_hook(self, hook): r"""These hooks will be called with arguments: `self`, `state_dict`, `prefix`, `local_metadata`, after the `state_dict` of `self` is set. Note that only parameters and buffers of `self` or its children are guaranteed to exist in `state_dict`. The hooks may modify `state_dict` inplace or return a new one. .. note: Do not use `module.state_dict()` in _register_state_dict_hook function """ handle = flow.utils.hooks.RemovableHandle(self._state_dict_hooks) self._state_dict_hooks[handle.id] = hook return handle def _register_load_state_dict_pre_hook( self, hook: Callable[..., None], with_module=False ): r"""These hooks will be called with arguments: `state_dict`, `prefix`, `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, `error_msgs`, before loading `state_dict` into `self`. These arguments are exactly the same as those of `_load_from_state_dict`. If ``with_module`` is ``True``, then the first argument to the hook is an instance of the module. Arguments: hook (Callable): Callable hook that will be invoked before loading the state dict. with_module (bool, optional): Whether or not to pass the module instance to the hook as the first parameter. """ handle = flow.utils.hooks.RemovableHandle(self._load_state_dict_pre_hooks) self._load_state_dict_pre_hooks[handle.id] = _WrappedHook( hook, self if with_module else None ) return handle def register_state_dict_pre_hook(self, hook): r"""These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. 
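A minimal, illustrative sketch of registering such a hook (``net`` is assumed to be an existing ``nn.Module``; the hook name ``log_prefix`` is made up for the example)::

    >>> def log_prefix(module, prefix, keep_vars):  # doctest: +SKIP
    ...     print("about to save:", prefix or "<root>")  # doctest: +SKIP
    >>> handle = net.register_state_dict_pre_hook(log_prefix)  # doctest: +SKIP
    >>> _ = net.state_dict()  # doctest: +SKIP
    >>> handle.remove()  # doctest: +SKIP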
""" handle = flow.utils.hooks.RemovableHandle(self._state_dict_pre_hooks) self._state_dict_pre_hooks[handle.id] = hook return handle def __getattr__(self, name: str) -> Union[Tensor, "Module"]: if "_parameters" in self.__dict__: _parameters = self.__dict__["_parameters"] if name in _parameters: return _parameters[name] if "_buffers" in self.__dict__: _buffers = self.__dict__["_buffers"] if name in _buffers: return _buffers[name] if "_modules" in self.__dict__: modules = self.__dict__["_modules"] if name in modules: return modules[name] raise AttributeError( "'{}' object has no attribute '{}'".format(type(self).__name__, name) ) def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: def remove_from(*dicts_or_sets): for d in dicts_or_sets: if name in d: if isinstance(d, dict): del d[name] else: d.discard(name) params = self.__dict__.get("_parameters") if isinstance(value, Parameter): if params is None: raise AttributeError( "cannot assign parameters before Module.__init__() call" ) remove_from( self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set, ) self.register_parameter(name, value) elif params is not None and name in params: if value is not None: raise TypeError( "cannot assign '{}' as parameter '{}' (nn.Parameter or None expected)".format( type(value), name ) ) self.register_parameter(name, value) else: modules = self.__dict__.get("_modules") if isinstance(value, Module): if modules is None: raise AttributeError( "cannot assign module before Module.__init__() call" ) remove_from( self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set, ) modules[name] = value elif modules is not None and name in modules: if value is not None: raise TypeError( "cannot assign '{}' as child module '{}' (nn.Module or None expected)".format( type(value), name ) ) modules[name] = value else: buffers = self.__dict__.get("_buffers") if buffers is not None and name in buffers: if value is not None and (not isinstance(value, Tensor)): raise TypeError( "cannot assign '{}' as buffer '{}' (Tensor or None expected)".format( type(value), name ) ) buffers[name] = value else: object.__setattr__(self, name, value) def __delattr__(self, name): if name in self._parameters: del self._parameters[name] elif name in self._buffers: del self._buffers[name] self._non_persistent_buffers_set.discard(name) elif name in self._modules: del self._modules[name] else: super().__delattr__(name) def _named_members(self, get_members_fn, prefix="", recurse=True): memo = set() modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] for (module_prefix, module) in modules: members = get_members_fn(module) for (k, v) in members: if v is None or v in memo: continue memo.add(v) name = module_prefix + ("." if module_prefix else "") + k yield (name, v) def parameters(self, recurse: bool = True) -> Iterator[Parameter]: r""" parameters(recurse=True) -> Iterator[Parameter] Returns an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> for param in model.parameters(): # doctest: +SKIP ... 
print(type(param), param.size()) # doctest: +SKIP oneflow.Size([10]) """ for (name, param) in self.named_parameters(recurse=recurse): yield param def named_parameters( self, prefix: str = "", recurse: bool = True ) -> Iterator[Tuple[str, Tensor]]: r""" named_parameters(prefix="", recurse=True) -> Iterator[Tuple[str, Tensor]] Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: (string, Parameter): Tuple containing the name and parameter Example:: >>> for name, param in self.named_parameters(): # doctest: +SKIP ... if name in ['bias']: # doctest: +SKIP ... print(param.size()) # doctest: +SKIP """ gen = self._named_members( lambda module: module._parameters.items(), prefix=prefix, recurse=recurse ) for elem in gen: yield elem def buffers(self, recurse: bool = True) -> Iterator[Tensor]: r""" buffers(recurse=True) -> Iterator[Tensor] Returns an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: oneflow.Tensor: module buffer Example:: >>> for buf in model.buffers(): # doctest: +SKIP ... print(type(buf), buf.size()) # doctest: +SKIP oneflow.Size([10]) """ for (name, buf) in self.named_buffers(recurse=recurse): yield buf def named_buffers( self, prefix: str = "", recurse: bool = True ) -> Iterator[Tuple[str, Tensor]]: r""" named_buffers(prefix="", recurse=True) -> Iterator[Tuple[str, Tensor]] Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: (string, oneflow.Tensor): Tuple containing the name and buffer Example:: >>> for name, buf in self.named_buffers(): # doctest: +SKIP ... if name in ['running_var']: # doctest: +SKIP ... print(buf.size()) # doctest: +SKIP """ gen = self._named_members( lambda module: module._buffers.items(), prefix=prefix, recurse=recurse ) for elem in gen: yield elem def children(self) -> Iterator["Module"]: r""" children() -> Iterator["Module"] Returns an iterator over immediate children modules. Yields: Module: a child module Example:: >>> import oneflow.nn as nn >>> l1 = nn.Linear(2, 2) >>> l2 = nn.Linear(2, 2) >>> net = nn.Sequential(l1, l2) >>> for idx, m in enumerate(net.children()): ... print(idx, '->', m) 0 -> Linear(in_features=2, out_features=2, bias=True) 1 -> Linear(in_features=2, out_features=2, bias=True) """ for (name, module) in self.named_children(): yield module def named_children(self) -> Iterator[Tuple[str, "Module"]]: r""" named_children() -> Iterator[Tuple[str, "Module"]] Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (string, Module): Tuple containing a name and child module Example:: >>> for name, module in model.named_children(): # doctest: +SKIP ... if name in ['conv4', 'conv5']: # doctest: +SKIP ... 
print(module) # doctest: +SKIP """ memo = set() for (name, module) in self._modules.items(): if module is not None and module not in memo: memo.add(module) yield (name, module) def modules(self) -> Iterator["Module"]: r""" modules() -> Iterator["Module"] Returns an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> import oneflow.nn as nn >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) """ for (name, module) in self.named_modules(): yield module def named_modules(self, memo: Optional[Set["Module"]] = None, prefix: str = ""): r""" named_modules(memo=None, prefix="") Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module Yields: (string, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> import oneflow.nn as nn >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) """ if memo is None: memo = set() if self not in memo: memo.add(self) yield (prefix, self) for (name, module) in self._modules.items(): if module is None: continue submodule_prefix = prefix + ("." if prefix else "") + name for m in module.named_modules(memo, submodule_prefix): yield m def train(self: T, mode: bool = True) -> T: r""" train(mode=True) Sets the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm1d`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self """ self.training = mode for module in self.children(): module.train(mode) return self def eval(self: T) -> T: r""" eval() Sets the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm1d`, etc. This is equivalent with :meth:`self.train(False) `. Returns: Module: self """ return self.train(False) def requires_grad_(self: T, requires_grad: bool = True) -> T: r"""Change if autograd should record operations on parameters in this module. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Module.html?highlight=requires_grad_#torch.nn.Module.requires_grad_. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). 
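For example, a minimal sketch that freezes every parameter of a small model::

    >>> import oneflow.nn as nn
    >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    >>> _ = net.requires_grad_(False)
    >>> any(p.requires_grad for p in net.parameters())
    False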
Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self """ for p in self.parameters(): p.requires_grad_(requires_grad) return self def zero_grad(self, set_to_none: bool = False) -> None: r""" zero_grad(set_to_none=False) Sets gradients of all model parameters to zero. See similar function under :class:`oneflow.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`oneflow.optim.Optimizer.zero_grad` for details. """ if getattr(self, "_is_replica", False): warnings.warn( "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " "The parameters are copied (in a differentiable manner) from the original module. " "This means they are not leaf nodes in autograd and so don't accumulate gradients. " "If you need gradients in your forward method, consider using autograd.grad instead." ) for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_() def _save_to_state_dict(self, destination, prefix, keep_vars): for (name, param) in self._parameters.items(): if param is not None: destination[prefix + name] = param for (name, buf) in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: destination[prefix + name] = buf def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): for hook in self._load_state_dict_pre_hooks.values(): hook( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) persistent_buffers = { k: v for (k, v) in self._buffers.items() if k not in self._non_persistent_buffers_set } local_name_params = itertools.chain( self._parameters.items(), persistent_buffers.items() ) local_state = {k: v for (k, v) in local_name_params if v is not None} for (name, param) in local_state.items(): key = prefix + name if key in state_dict: input_param = state_dict[key] if tuple(input_param.shape) != tuple(param.shape): error_msgs.append( "size mismatch for {}: copying a param with shape {} from checkpoint, the shape in current model is {}.".format( key, input_param.shape, param.shape ) ) continue if ( isinstance(input_param, Tensor) and input_param.is_global != param.is_global ): if param.is_global: help_msg = "Maybe you need to convert the checkpoint param to global, or set global_src_rank=0 when using flow.load to load model's state_dict" else: help_msg = "Maybe you need to convert your model to global." error_msgs.append( 'local / global mismatch for "{}": param from checkpoint is {} tensor, but the param in current model is {} tensor. 
{}'.format( key, "global" if input_param.is_global else "local", "global" if param.is_global else "local", help_msg, ) ) continue try: with flow.no_grad(): param.copy_(input_param) except Exception as ex: error_msgs.append( 'While copying the parameter "{}", an exception occurred : \n\n{}.'.format( key, "".join( map( lambda line: "\t" + line, traceback.format_exc().splitlines(True), ) ), ) ) elif strict: missing_keys.append(key) if strict: for key in state_dict.keys(): if key.startswith(prefix): input_name = key[len(prefix) :] input_name = input_name.split(".", 1)[0] if ( input_name not in self._modules and input_name not in local_state ): unexpected_keys.append(key) def load_state_dict( self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], strict: bool = True, ): r""" load_state_dict(state_dict, strict=True) Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~oneflow.nn.Module.state_dict` function. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~oneflow.nn.Module.state_dict` function. Default: ``True`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. """ missing_keys = [] unexpected_keys = [] error_msgs = [] metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=""): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs, ) for (name, child) in module._modules.items(): if child is not None: load(child, prefix + name + ".") load(self) load = None if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, "Unexpected key(s) in state_dict: {}. ".format( ", ".join(('"{}"'.format(k) for k in unexpected_keys)) ), ) if len(missing_keys) > 0: error_msgs.insert( 0, "Missing key(s) in state_dict: {}. ".format( ", ".join(('"{}"'.format(k) for k in missing_keys)) ), ) if len(error_msgs) > 0: raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs) ) ) return _IncompatibleKeys(missing_keys, unexpected_keys) def state_dict( self, destination=None, prefix="", keep_vars=False ) -> Dict[str, Tensor]: r""" state_dict(destination=None, prefix="", keep_vars=False) -> Dict[str, Tensor] Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. Args: destination (dict, optional): Deprecated. This dict is returned with the module state saved in it. It should also have an attribute ``_metadata: dict`` to save metadata of the module state. If it's not provided, an ``OrderedDict`` is created and returned. 
Default: ``None`` prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in dict. Default: ``''`` keep_vars (bool, optional): by default the :class:`~oneflow.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching is not performed. Default: ``False`` Returns: dict: a dictionary containing a whole state of the module Example:: >>> import oneflow.nn as nn >>> l1 = nn.Linear(2, 2) >>> l2 = nn.Linear(2, 2) >>> net = nn.Sequential(l1, l2) >>> net.state_dict().keys() odict_keys(['0.weight', '0.bias', '1.weight', '1.bias']) """ if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() # TODO(hujiakui): add _version for nn.Module local_metadata = dict(version=1) if hasattr(destination, "_metadata"): destination._metadata[prefix[:-1]] = local_metadata self._save_to_state_dict(destination, prefix, keep_vars) for (name, module) in self._modules.items(): if module is not None: module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars) for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) if hook_result is not None: destination = hook_result return destination _grad_t = Union[Tuple[Tensor, ...], Tensor] def register_backward_hook( self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, Tensor]] ): r"""Registers a backward hook on the module. This function is deprecated in favor of :meth:`~oneflow.nn.Module.register_full_backward_hook` and the behavior of this function will change in future versions. Returns: :class:`oneflow.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` """ if self._is_full_backward_hook is True: raise RuntimeError( "Cannot use both regular backward hooks and full backward hooks on a " "single Module. Please use only one of them." ) self._is_full_backward_hook = False handle = flow.utils.hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle def register_full_backward_hook( self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, Tensor]], ): r"""Registers a backward hook on the module. The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> TensorTuple or None The :attr:`grad_input` and :attr:`grad_output` are :class:`oneflow.TensorTuple` that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. 
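For example, a minimal sketch of attaching such a hook (the hook body below is illustrative and only inspects gradient shapes)::

    >>> import oneflow as flow
    >>> import oneflow.nn as nn
    >>> m = nn.Linear(2, 2)
    >>> def report(module, grad_input, grad_output):  # doctest: +SKIP
    ...     print([g.shape for g in grad_output if g is not None])  # doctest: +SKIP
    >>> handle = m.register_full_backward_hook(report)  # doctest: +SKIP
    >>> m(flow.ones(1, 2)).sum().backward()  # doctest: +SKIP
    >>> handle.remove()  # doctest: +SKIP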
Returns: :class:`oneflow.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` """ if self._is_full_backward_hook is False: raise RuntimeError( "Cannot use both regular backward hooks and full backward hooks on a " "single Module. Please use only one of them." ) self._is_full_backward_hook = True handle = flow.utils.hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle def _get_backward_hooks(self): r"""Returns the backward hooks for use in the call function. It returns two lists, one with the full backward hooks and one with the non-full backward hooks. """ full_backward_hooks: List[Callable] = [] if self._is_full_backward_hook is True: full_backward_hooks += self._backward_hooks.values() non_full_backward_hooks: List[Callable] = [] if self._is_full_backward_hook is False: non_full_backward_hooks += self._backward_hooks.values() return full_backward_hooks, non_full_backward_hooks def _maybe_warn_non_full_backward_hook(self, args, res, grad_fn): if not isinstance(res, Tensor): if not ( isinstance(res, tuple) and all([isinstance(r, Tensor) for r in res]) ): warnings.warn( "Using non-full backward hooks on a Module that does not return a " "single Tensor or a tuple of Tensors is deprecated and will be removed " "in future versions. This hook will be missing some of the grad_output. " "Please use register_full_backward_hook to get the documented behavior." ) return else: res = (res,) if not isinstance(args, Tensor): if not ( isinstance(args, tuple) and all([isinstance(i, Tensor) for i in args]) ): warnings.warn( "Using non-full backward hooks on a Module that does not take as input a " "single Tensor or a tuple of Tensors is deprecated and will be removed " "in future versions. This hook will be missing some of the grad_input. " "Please use register_full_backward_hook to get the documented behavior." ) return else: args = (args,) # At this point we are sure that inputs and result are tuples of Tensors out_grad_fn = {r.grad_fn for r in res if r.grad_fn is not None} if len(out_grad_fn) == 0 or ( len(out_grad_fn) == 1 and grad_fn not in out_grad_fn ): warnings.warn( "Using a non-full backward hook when outputs are nested in python data structure " "is deprecated and will be removed in future versions. This hook will be missing " "some grad_output." ) elif len(out_grad_fn) > 1: warnings.warn( "Using a non-full backward hook when outputs are generated by different autograd Nodes " "is deprecated and will be removed in future versions. This hook will be missing " "some grad_output. Please use register_full_backward_hook to get the documented behavior." ) else: # At this point the grad_output part of the hook will most likely be correct inputs_grad_fn = {i.grad_fn for i in args if i.grad_fn is not None} next_functions = {grad_fn.next_functions[0][0]} if inputs_grad_fn != next_functions: warnings.warn( "Using a non-full backward hook when the forward contains multiple autograd Nodes " "is deprecated and will be removed in future versions. This hook will be missing " "some grad_input. Please use register_full_backward_hook to get the documented " "behavior." ) def register_forward_pre_hook(self, hook: Callable[..., None]): r""" register_forward_pre_hook(hook) Registers a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked.
It should have the following signature:: hook(module, input) -> None or modified input The input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned(unless that value is already a tuple). """ handle = flow.utils.hooks.RemovableHandle(self._forward_pre_hooks) self._forward_pre_hooks[handle.id] = hook return handle def register_forward_hook(self, hook: Callable[..., None]): r""" register_forward_hook(hook) Registers a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. It should have the following signature:: hook(module, input, output) -> None or modified output The input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. """ handle = flow.utils.hooks.RemovableHandle(self._forward_hooks) self._forward_hooks[handle.id] = hook return handle def _apply(self, fn): if not hasattr(self, "cpg"): self.cpg = None if self.cpg is not None: self.cpg = None warnings.warn( "deleted ContiguousParamsGroup since creating it before " "apply operations like to(), to_global() will cause error." ) # A dict to store tensors that has already been applied. # There is no need to apply multiple times on a same tensor. if self._oneflow_internal_module_tensor_applied_dict__ is None: self._oneflow_internal_module_tensor_applied_dict__ = dict() for module in self.children(): module._oneflow_internal_module_tensor_applied_dict__ = ( self._oneflow_internal_module_tensor_applied_dict__ ) module._apply(fn) module._oneflow_internal_module_tensor_applied_dict__ = None def can_use_assign_copy(tensor, tensor_applied): return tensor.is_local == tensor_applied.is_local for (key, param) in self._parameters.items(): if param is None: continue need_apply = False if param not in self._oneflow_internal_module_tensor_applied_dict__: need_apply = True assert isinstance(param, Parameter) assert param.is_leaf with flow.no_grad(): param_applied = fn(param) param_applied.requires_grad = param.requires_grad if param.grad is not None: assert param.grad.is_leaf with flow.no_grad(): grad_applied = fn(param.grad) grad_applied.requires_grad = param.grad.requires_grad param_applied.grad = grad_applied else: param_applied = self._oneflow_internal_module_tensor_applied_dict__[ param ] if can_use_assign_copy(param_applied, param): if need_apply: self._parameters[key].data = param_applied self._oneflow_internal_module_tensor_applied_dict__[ param ] = param_applied else: # The parameter's data has already been set when it can use assign copy. 
pass else: if need_apply: new_param = Parameter(param_applied, param.requires_grad) self._parameters[key] = new_param self._oneflow_internal_module_tensor_applied_dict__[ param ] = new_param else: self._parameters[ key ] = self._oneflow_internal_module_tensor_applied_dict__[param] for (key, buf) in self._buffers.items(): if buf is not None: if buf not in self._oneflow_internal_module_tensor_applied_dict__: buf_applied = fn(buf) self._buffers[key] = buf_applied self._oneflow_internal_module_tensor_applied_dict__[ buf ] = buf_applied else: self._buffers[ key ] = self._oneflow_internal_module_tensor_applied_dict__[buf] self._oneflow_internal_module_tensor_applied_dict__ = None return self def apply(self: T, fn: Callable[["Module"], None]) -> T: r""" apply(fn) Applies ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model. Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> import oneflow as flow >>> import oneflow.nn as nn >>> @flow.no_grad() ... def init_weights(m): ... print(m) ... if type(m) == nn.Linear: ... m.weight.fill_(1.0) ... print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) tensor([[1., 1.], [1., 1.]], dtype=oneflow.float32, requires_grad=True) Linear(in_features=2, out_features=2, bias=True) tensor([[1., 1.], [1., 1.]], dtype=oneflow.float32, requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) """ if self.cpg is not None: self.cpg = None warnings.warn( "deleted ContiguousParamsGroup since creating it before " "apply operations like to(), to_global() will cause error." ) for module in self.children(): module.apply(fn) fn(self) return self def to_empty(self: T, *, device: Union[str, flow.device]) -> T: r"""Moves the parameters and buffers to the specified device without copying storage. Args: device (:class:`oneflow.device`): the desired device of the parameters and buffers in this module Returns: Module: self """ return self._apply(lambda t: flow.empty_like(t, device=device)) def _to_memory_format(self, memory_format): r"""Casts the parameters and buffers in this module to another memory format. The data_format attribute should also be modified. Note: This interface is unstable and may be removed in the future once the data_format attribute has been removed from the module. Args: memory_format (:class:`oneflow.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self """ for module in self.children(): module._to_memory_format(memory_format) self.to_memory_format(memory_format) return self def to_memory_format(self, memory_format) -> None: pass @overload def to( self: T, device: Optional[Union[int, str, flow.device]] = ..., dtype: Optional[flow.dtype] = ..., ) -> T: ... @overload def to(self: T, dtype: flow.dtype) -> T: ... @overload def to(self: T, tensor: Tensor) -> T: ... def to(self, *args, **kwargs): r"""Moves and/or casts the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None) :noindex: .. function:: to(dtype) :noindex: .. function:: to(memory_format=None) :noindex: .. 
function:: to(tensor) :noindex: Its signature is similar to :meth:`oneflow.Tensor.to`, but only accepts floating point :attr:`dtype`\ s. In addition, this method will only cast the floating point parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`oneflow.device`): the desired device of the parameters and buffers in this module dtype (:class:`oneflow.dtype`): the desired floating point dtype of the parameters and buffers in this module memory_format (:class:`oneflow.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) tensor (oneflow.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module Returns: Module: self Examples:: >>> import oneflow as flow >>> import oneflow.nn as nn >>> linear = nn.Linear(2, 2) >>> linear.weight.device device(type='cpu', index=0) >>> linear.weight.dtype oneflow.float32 >>> linear.to(flow.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight.dtype oneflow.float64 >>> gpu1 = flow.device("cuda:1") >>> linear.to(gpu1, dtype=flow.half) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight.device device(type='cuda', index=1) >>> linear.weight.dtype oneflow.float16 >>> cpu = flow.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight.device device(type='cpu', index=0) """ device = None dtype = None memory_format = None if len(args) + len(kwargs) == 2: device = kwargs.pop("device", None) or args[0] dtype = kwargs.pop("dtype", None) or args[1] elif len(args) + len(kwargs) == 1: if len(args) == 1: arg = args[0] if isinstance(arg, Tensor): device = arg.device dtype = arg.dtype elif isinstance(arg, flow.dtype): dtype = arg device = None elif isinstance(arg, (flow.device, str, int)): dtype = None device = arg elif isinstance(arg, flow.memory_format): memory_format = arg else: raise ValueError(f"Unsupported parameters in module.to: {arg}") else: device = kwargs.pop("device", None) dtype = kwargs.pop("dtype", None) memory_format = kwargs.pop("memory_format", None) tensor = kwargs.pop("tensor", None) if tensor is not None: device = tensor.device dtype = tensor.dtype else: raise ValueError( f"Unsupported parameters in module.to: {args} and {kwargs}" ) if dtype is not None: if not dtype.is_floating_point: raise TypeError( "nn.Module.to only accepts floating point " "dtypes, but got desired dtype={}".format(dtype) ) if memory_format is not None: self._to_memory_format(memory_format) def convert(t): return t.to(device, dtype if t.is_floating_point() else None) return self._apply(convert) def to_consistent(self, *args, **kwargs): raise RuntimeError( ".to_consistent has been removed, please use .to_global instead" ) def to_global(self, placement=None, sbp=None): def convert(t): return t.to_global(placement=placement, sbp=sbp) return self._apply(convert) def to_local(self): def convert(t): return t.to_local() return self._apply(convert) def cpu(self: T) -> T: r""" cpu() Moves all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self """ return self._apply(lambda t: t.cpu()) def cuda(self: T, device: Optional[Union[int, flow.device]] = None) -> T: r""" cuda(device=None) Moves all model parameters and buffers to the GPU. 
This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self """ return self._apply(lambda t: t.cuda(device)) def float(self: T) -> T: r""" float() Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self """ return self._apply(lambda t: t.float() if t.is_floating_point() else t) def double(self: T) -> T: r""" double() Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self """ return self._apply(lambda t: t.double() if t.is_floating_point() else t) def half(self: T) -> T: r""" half() Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self """ return self._apply(lambda t: t.half() if t.is_floating_point() else t) def _get_name(self): return self.__class__.__name__ def get_submodule(self, target: str): r"""Get the submodule referenced by ``target``. Args: target (str): The name of submodule to find. .. code-block:: python >>> from oneflow import nn >>> class Net3(nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.linear = nn.Linear(3, 2) >>> >>> class Net2(nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.net3 = Net3() >>> >>> class Net1(nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.net2 = Net2() >>> >>> net = Net1() >>> print(net.get_submodule("net2.net3")) Net3( (linear): Linear(in_features=3, out_features=2, bias=True) ) >>> print(net.get_submodule("net2")) Net2( (net3): Net3( (linear): Linear(in_features=3, out_features=2, bias=True) ) ) Returns: oneflow.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the module can't reference the submodule according to ``target`` TypeError: If the result referenced by ``target`` is not an ``nn.Module`` """ if target == "": return self curr_module_name = [self._get_name()] submodule_names = target.split(".") mod = self for submodule_name in submodule_names: if not hasattr(mod, submodule_name): raise AttributeError( f"`{'.'.join(curr_module_name)}` doesn't have submodule `{submodule_name}`" ) mod = getattr(mod, submodule_name) curr_module_name.append(submodule_name) if not isinstance(mod, flow.nn.Module): raise TypeError( f"`{'.'.join(curr_module_name)}` isn't an oneflow.Module, but a {type(mod)}" ) return mod def get_parameter(self, target: str): r"""Return the parameter referenced by ``target``. Args: target (str): The name of parameter to find. ..
code-block:: python >>> from oneflow import nn >>> class Net3(nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.linear = nn.Linear(3, 3) >>> >>> class Net2(nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.net3 = Net3() >>> self.linear = nn.Linear(2, 2) >>> >>> class Net1(nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.net2 = Net2() >>> self.linear = nn.Linear(1, 1) >>> >>> net = Net1() >>> print(net.get_parameter("linear.weight").shape) oneflow.Size([1, 1]) >>> print(net.get_parameter("net2.linear.weight").shape) oneflow.Size([2, 2]) Returns: oneflow.nn.Parameter: The parameter referenced by ``target`` Raises: AttributeError: If the module can't reference the parameter according to ``target`` TypeError: If the result refererenced by ``target`` is not an ``nn.Parameter`` """ sub_module_name, _, parameter_name = target.rpartition(".") sub_module = self.get_submodule(sub_module_name) if hasattr(sub_module, parameter_name): parameter = getattr(sub_module, parameter_name) else: raise AttributeError( f"`{sub_module_name}` doesn't have attribute `{parameter_name}`" ) if not isinstance(parameter, flow.Tensor): raise TypeError( f"`{target}` is not an oneflow.Tensor, but {type(parameter)}" ) return parameter def extra_repr(self) -> str: """Set the extra representation of the module To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. """ return "" def make_contiguous_params_group(self): r"""Get contiguous parameters group after creating the whole module. Rearrange the parameters of the model in the same dtype and device (or placement and sbp for global tensor) to form a single tensor for accelerating the element-wise operations of parameters' data or gradient. .. note:: This method should be used strictly after all parameters have finished doing apply operations, otherwise it will cause an error. Example:: >>> net = Network().to(device) >>> net.make_contiguous_params_group() """ self.cpg = flow.nn.utils.parameters_grouping.ContiguousParamsGroup( list(self.parameters()), group_on_current_buffer=False ) def __repr__(self): extra_lines = [] extra_repr = self.extra_repr() if extra_repr: extra_lines = extra_repr.split("\n") child_lines = [] for (key, module) in self._modules.items(): mod_str = repr(module) mod_str = _addindent(mod_str, 2) child_lines.append("(" + key + "): " + mod_str) lines = extra_lines + child_lines main_str = self._get_name() + "(" if lines: if len(extra_lines) == 1 and (not child_lines): main_str += extra_lines[0] else: main_str += "\n " + "\n ".join(lines) + "\n" main_str += ")" return main_str def _shallow_repr(self): extra_lines = [] extra_repr = self.extra_repr() if extra_repr: extra_lines = extra_repr.split("\n") lines = extra_lines main_str = self._get_name() + "(" if lines: if len(extra_lines) == 1: main_str += extra_lines[0] else: main_str += "\n " + "\n ".join(lines) + "\n" main_str += ")" return main_str ================================================ FILE: python/oneflow/nn/modules/moving_average_min_max_observer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import numpy as np import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module class MovingAverageMinMaxObserver(Module): """ Compute the quantization parameters based on the moving average of the input tensor's min and max values. First compute the moving\\_max and moving\\_min value of input tensor: if quantization_scheme == "symmetric": .. math:: & moving\\_max = moving\\_max * momentum + |max(input)| * (1 - momentum) & moving\\_min = moving\\_max elif quantization_scheme == "affine": .. math:: & moving\\_max = moving\\_max * momentum + max(input) * (1 - momentum) & moving\\_min = moving\\_min * momentum + min(input) * (1 - momentum) The moving averages of the min and max values are initialized as the min and max of the first batch of input `Blob`. Then compute the scale and zero_point with the following equations: if quantization_scheme == "symmetric": .. math:: & denom = 2^{quantization\\_bit - 1} - 1 & scale = moving\\_max / denom & zero\\_point = 0 elif quantization_scheme == "affine": .. math:: & denom = 2^{quantization\\_bit} - 1 & scale = (moving\\_max - moving\\_min) / denom & zero\\_point = -moving\\_min / scale Note: ``current_train_step`` can be directly assigned to an optimizer (e.g. SGD) step. Args: input(oneflow.Tensor): the input value(s), in ``oneflow.float32``. current_train_step_tensor(oneflow.Tensor): records the train step for quantization aware training. stop_update_after_iters(int): stop recording the train step for quantization aware training when the train iter is greater than stop_update_after_iters. quantization_formula (str): Support "google" or "cambricon". quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8. quantization_scheme (str): "symmetric" or "affine", quantize to signed / unsigned integer. Defaults to "symmetric". momentum (float): Smoothing parameter for exponential moving average operation. Defaults to 0.95. Returns: Tuple[oneflow.Tensor, oneflow.Tensor]: The scale and zero_point of the input tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32) >>> input_tensor = flow.tensor( ... weight, dtype=flow.float32 ... ) >>> current_train_step_tensor = flow.tensor( ... np.zeros((1,)).astype(np.float32), ... dtype=flow.int64, ... ) >>> momentum = 0.95 >>> quantization_bit = 8 >>> quantization_scheme = "symmetric" >>> quantization_formula = "google" >>> moving_average_min_max_observer = flow.nn.MovingAverageMinMaxObserver(stop_update_after_iters=1, ... quantization_formula=quantization_formula, quantization_bit=quantization_bit, ... quantization_scheme=quantization_scheme, momentum=momentum, ... ) >>> (scale, zero_point) = moving_average_min_max_observer( ... input_tensor, ... current_train_step_tensor, ...
) """ def __init__( self, stop_update_after_iters: int = 1, quantization_formula: str = "google", quantization_bit: int = 8, quantization_scheme: str = "symmetric", momentum: float = 0.95, ) -> None: super().__init__() self.quantization_formula = quantization_formula self.stop_update_after_iters = stop_update_after_iters self.quantization_bit = quantization_bit self.quantization_scheme = quantization_scheme self.momentum = momentum self.register_buffer("moving_max", flow.Tensor(1)) self.register_buffer("moving_min", flow.Tensor(1)) self.reset_running_stats() def reset_running_stats(self) -> None: self.moving_max.fill_(0) self.moving_min.fill_(0) def forward(self, input, current_train_step): return flow._C.moving_average_min_max_observer( input, current_train_step, self.moving_max, self.moving_min, self.training, self.stop_update_after_iters, self.quantization_formula, self.quantization_bit, self.quantization_scheme, self.momentum, ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/nms.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module def nms_op(boxes, scores, iou_threshold: float): score_inds = flow.argsort(scores, dim=0, descending=True) boxes = flow._C.gather(boxes, score_inds, axis=0) keep = flow._C.nms(boxes, iou_threshold) index = flow.squeeze(flow.argwhere(keep), dim=[1]) return flow._C.gather(score_inds, index, axis=0) ================================================ FILE: python/oneflow/nn/modules/nonzero.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional import numpy as np import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module def nonzero_op(input, as_tuple=False): meta_device_flag = False if input.is_global: if input.placement.type == "meta": meta_device_flag = True else: if input.device.type == "meta": meta_device_flag = True if meta_device_flag: raise RuntimeError( "Could not run nonzero with arguments from the meta backend." 
) if as_tuple: return flow._C.nonzero(input, as_tuple) else: return flow._C.nonzero(input, as_tuple)[0] if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def norm(input, p="fro", dim=None, keepdim=False, dtype=None): """ Returns the matrix norm or vector norm of a given tensor. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.norm.html. .. warning:: Use :func:`oneflow.linalg.norm`, instead, or :func:`oneflow.linalg.vector_norm` when computing vector norms and :func:`oneflow.linalg.matrix_norm` when computing matrix norms. Note, however, the signature for these functions is slightly different than the signature for oneflow.norm. Args: input (Tensor): The input tensor. Its data type must be either a floating point or complex type. For complex inputs, the norm is calculated using the absolute value of each element. If the input is complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat). p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'`` The following norms can be calculated: ====== ============== ========================== ord matrix norm vector norm ====== ============== ========================== 'fro' Frobenius norm -- 'nuc' nuclear norm -- Number -- sum(abs(x)**p)**(1./p) ====== ============== ========================== The vector norm can be calculated across any number of dimensions. The corresponding dimensions of :attr:`input` are flattened into one dimension, and the norm is calculated on the flattened dimension. Frobenius norm produces the same result as ``p=2`` in all cases except when :attr:`dim` is a list of three or more dims, in which case Frobenius norm throws an error. Nuclear norm can only be calculated across exactly two dimensions. dim (int, tuple of ints, list of ints, optional): Specifies which dimension or dimensions of :attr:`input` to calculate the norm across. If :attr:`dim` is ``None``, the norm will be calculated across all dimensions of :attr:`input`. If the norm type indicated by :attr:`p` does not support the specified number of dimensions, an error will occur. keepdim (bool, optional): whether the output tensors have :attr:`dim` retained or not. Ignored if :attr:`dim` = ``None`` and :attr:`out` = ``None``. Default: ``False`` dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor. If specified, the input tensor is casted to :attr:`dtype` while performing the operation. Default: None. .. note:: Even though ``p='fro'`` supports any number of dimensions, the true mathematical definition of Frobenius norm only applies to tensors with exactly two dimensions. 
:func:`oneflow.linalg.norm` with ``ord='fro'`` aligns with the mathematical definition, since it can only be applied across exactly two dimensions. Example:: >>> import oneflow as flow >>> a = flow.arange(9, dtype= flow.float) - 4 >>> b = a.reshape((3, 3)) >>> flow.norm(a) tensor(7.7460, dtype=oneflow.float32) >>> flow.norm(b) tensor(7.7460, dtype=oneflow.float32) >>> flow.norm(a, float('inf')) tensor(4., dtype=oneflow.float32) >>> flow.norm(b, float('inf')) tensor(9., dtype=oneflow.float32) >>> c = flow.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= flow.float) >>> flow.norm(c, dim=0) tensor([1.4142, 2.2361, 5.0000], dtype=oneflow.float32) >>> flow.norm(c, dim=1) tensor([3.7417, 4.2426], dtype=oneflow.float32) >>> flow.norm(c, p=1, dim=1) tensor([6., 6.], dtype=oneflow.float32) >>> d = flow.arange(8, dtype= flow.float).reshape(2,2,2) >>> flow.norm(d, dim=(1,2)) tensor([ 3.7417, 11.2250], dtype=oneflow.float32) >>> flow.norm(d[0, :, :]), flow.norm(d[1, :, :]) (tensor(3.7417, dtype=oneflow.float32), tensor(11.2250, dtype=oneflow.float32)) """ if type(p) == str or dim != None: return flow._C.norm(input=input, ord=p, dim=dim, keepdim=keepdim, dtype=dtype) return flow._C.norm( input=input, ord=p, dim=dim, keepdim=keepdim, dtype=dtype, for_norm=True ) ================================================ FILE: python/oneflow/nn/modules/normalization.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings from typing import Tuple, Union import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn import init from oneflow.nn.modules.module import Module _shape_t = Union[int, Tuple[int], flow._oneflow_internal.Size] def group_norm( input: Tensor, num_groups: int, weight: Tensor = None, bias: Tensor = None, eps: float = 1e-05, num_channels: int = None, ): r"""Apply Group Normalization for last certain number of dimensions. See :class:`~oneflow.nn.GroupNorm` for details. 
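Conceptually, for an input of shape :math:`(N, C, *)` split into ``num_groups`` groups, the result is equivalent to the following sketch with basic tensor ops (illustrative only; it mirrors the CPU fallback implemented below):

.. code-block:: python

    import oneflow as flow

    x = flow.randn(2, 6, 4, 4)
    num_groups, eps = 3, 1e-5
    g = x.reshape(2, num_groups, -1)                 # (N, groups, C//groups * H * W)
    mean = g.mean(dim=2, keepdim=True)
    var = g.var(dim=2, unbiased=False, keepdim=True)
    y = ((g - mean) / flow.sqrt(var + eps)).reshape(x.shape)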
""" assert len(input.shape) >= 3, "The dimensions of input tensor must larger than 2" if num_channels is None: num_channels = input.shape[1] assert ( input.shape[1] == num_channels ), "The channels of input tensor must equal num_channels" affine = weight is not None and bias is not None if not input.is_cpu: return flow._C.group_norm(input, weight, bias, affine, num_groups, eps) else: origin_shape = input.shape reshape_to_1d = flow.reshape(input, shape=[origin_shape[0], num_groups, -1]) mean = flow.mean(reshape_to_1d, dim=2, keepdim=True) variance = flow.var(reshape_to_1d, dim=2, unbiased=False, keepdim=True) normalized = (reshape_to_1d - mean) / flow.sqrt(variance + eps) normalized = flow.reshape(normalized, shape=[origin_shape[0], num_channels, -1]) if weight is not None: normalized = normalized * weight.reshape(1, num_channels, 1) if bias is not None: normalized = normalized + bias.reshape(1, num_channels, 1) res = flow.reshape(normalized, shape=tuple(input.shape)) return res class GroupNorm(Module): """ Applies Group Normalization over a mini-batch of inputs as described in the paper `Group Normalization `__ .. math:: y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The input channels are separated into :attr:`num_groups` groups, each containing ``num_channels / num_groups`` channels. The mean and standard-deviation are calculated separately over the each group. :math:`\\gamma` and :math:`\\beta` are learnable per-channel affine transform parameter vectors of size :attr:`num_channels` if :attr:`affine` is ``True``. The standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. This layer uses statistics computed from input data in both training and evaluation modes. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.GroupNorm.html. Args: num_groups (int): number of groups to separate the channels into num_channels (int): number of channels expected in input eps: a value added to the denominator for numerical stability. Default: 1e-5 affine: a boolean value that when set to ``True``, this module has learnable per-channel affine parameters initialized to ones (for weights) and zeros (for biases). Default: ``True``. Shape: - Input: :math:`(N, C, *)` where :math:`C=\\text{num_channels}` - Output: :math:`(N, C, *)` (same shape as input) For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.Tensor(np.random.randn(20, 6, 10, 10)) >>> # Separate 6 channels into 3 groups >>> m = flow.nn.GroupNorm(3, 6) >>> # Separate 6 channels into 6 groups (equivalent to InstanceNorm) >>> m = flow.nn.GroupNorm(6, 6) >>> # Put all 6 channels into a single group (equivalent to LayerNorm) >>> m = flow.nn.GroupNorm(1, 6) >>> # Activating the module >>> output = m(input) """ def __init__( self, num_groups: int, num_channels: int, eps: float = 1e-05, affine: bool = True, device=None, dtype=None, ) -> None: super().__init__() assert num_groups > 0, "The num_groups must be larger than zero" assert num_channels > 0, "The num_channels must be larger than zero" self.num_groups = num_groups self.num_channels = num_channels self.eps = eps self.affine = affine factory_kwargs = {} if device: factory_kwargs["device"] = device if dtype: factory_kwargs["dtype"] = dtype if self.affine: self.weight = flow.nn.Parameter( flow.Tensor(num_channels).to(**factory_kwargs) ) self.bias = flow.nn.Parameter( flow.Tensor(num_channels).to(**factory_kwargs) ) else: self.register_parameter("weight", None) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: if self.affine: flow.nn.init.ones_(self.weight) flow.nn.init.zeros_(self.bias) def forward(self, input: Tensor) -> Tensor: return group_norm( input, self.num_groups, self.weight, self.bias, self.eps, self.num_channels ) def extra_repr(self) -> str: return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format( **self.__dict__ ) def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): assert len(input.shape) > len( normalized_shape ), "Input tensor dim must be greater than normalized dim!"
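# Note (added for clarity): both axes below count the leading dimensions that are NOT normalized.
# For example, with input of shape (N, C, H, W) and normalized_shape=(H, W), begin_norm_axis and
# begin_params_axis are both 2, so the statistics and the affine weight/bias apply to the trailing
# len(normalized_shape) dimensions.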
begin_norm_axis = len(input.shape) - len(normalized_shape) begin_params_axis = len(input.shape) - len(normalized_shape) elementwise_affine = True if (weight is not None and bias is not None) else False for i in range(0, len(normalized_shape)): if input.shape[i + begin_params_axis] != normalized_shape[i]: raise RuntimeError( f"Given normalized_shape={normalized_shape}, expected input with shape [*, {str(normalized_shape)[1:-1]}], but got input of size {input.shape}" ) if input.is_cpu: reduce_axis = [] for dim in range(len(input.shape)): if dim >= begin_norm_axis: reduce_axis.append(dim) mean = input.mean(dim=reduce_axis, keepdim=True) variance = input.var(dim=reduce_axis, unbiased=False, keepdim=True) params_shape = input.shape[begin_params_axis:] if len(mean.shape) == 1: nd_params_shape = [1] * len(input.shape) nd_params_shape[begin_norm_axis] = params_shape[0] mean = flow.reshape(mean, shape=nd_params_shape) variance = flow.reshape(variance, nd_params_shape) if weight is not None and params_shape[0] == weight.nelement(): weight = flow.reshape(weight, shape=nd_params_shape) if bias is not None and params_shape[0] == bias.nelement(): bias = flow.reshape(bias, shape=nd_params_shape) elif len(mean.shape) == len(input.shape): pass else: raise ValueError( "shape of mean and variance should be 1D or has number of axes and x's" ) variance += eps normalized = (input - mean) * variance.rsqrt() if elementwise_affine: normalized = normalized * weight + bias return normalized else: if elementwise_affine: res = flow._C.layer_norm_affine( input, weight, bias, begin_norm_axis=begin_norm_axis, begin_params_axis=begin_params_axis, epsilon=eps, ) else: res = flow._C.layer_norm( input, begin_norm_axis=begin_norm_axis, begin_params_axis=begin_params_axis, epsilon=eps, ) return res class LayerNorm(Module): """Applies Layer Normalization over a mini-batch of inputs as described in the paper `Layer Normalization `__ .. math:: y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta The mean and standard-deviation are calculated separately over the last certain number dimensions which have to be of the shape specified by :attr:`normalized_shape`. :math:`\\gamma` and :math:`\\beta` are learnable affine transform parameters of :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. The standard-deviation is calculated via the biased estimator. .. note:: Unlike Batch Normalization and Instance Normalization, which applies scalar scale and bias for each entire channel/plane with the :attr:`affine` option, Layer Normalization applies per-element scale and bias with :attr:`elementwise_affine`. This layer uses statistics computed from input data in both training and evaluation modes. Args: normalized_shape (int or list or oneflow.Size): input shape from an expected input of size .. math:: [* \\times \\text{normalized_shape}[0] \\times \\text{normalized_shape}[1] \\times \\ldots \\times \\text{normalized_shape}[-1]] If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps: a value added to the denominator for numerical stability. Default: 1e-5 elementwise_affine: a boolean value that when set to ``True``, this module has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases). Default: ``True``. Shape: - Input: :math:`(N, *)` - Output: :math:`(N, *)` (same shape as input) For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input_arr = np.array( ... [ ... [ ... [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]], ... [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]], ... ], ... [ ... [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]], ... [[1.07541728, 0.11008703], [0.26361224, -0.48663723]], ... ], ... ], ... dtype=np.float32, ... ) >>> x = flow.Tensor(input_arr) >>> m = flow.nn.LayerNorm(2) >>> y = m(x).numpy() >>> y array([[[[ 0.99997395, -0.99997395], [-0.999947 , 0.999947 ]], [[-0.99995965, 0.9999595 ], [ 0.99998784, -0.99998784]]], [[[-0.9998348 , 0.99983466], [ 0.9999914 , -0.9999914 ]], [[ 0.9999785 , -0.9999785 ], [ 0.9999646 , -0.9999646 ]]]], dtype=float32) """ __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool def __init__( self, normalized_shape: _shape_t, eps: float = 1e-05, elementwise_affine: bool = True, ) -> None: super(LayerNorm, self).__init__() if isinstance(normalized_shape, int): normalized_shape = (normalized_shape,) self.normalized_shape = tuple(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = flow.nn.Parameter(flow.Tensor(*self.normalized_shape)) self.bias = flow.nn.Parameter(flow.Tensor(*self.normalized_shape)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: if self.elementwise_affine: init.ones_(self.weight) init.zeros_(self.bias) def forward(self, x): return layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) def extra_repr(self) -> str: return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format( **self.__dict__ ) class RMSLayerNorm(Module): """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated w/o mean and there is no bias. Additionally we want to make sure that the accumulation for half-precision inputs is done in fp32. Args: hidden_size (int): number of features in the hidden state eps: a value added to the denominator for numerical stability. Default: 1e-6 Shape: - Input: :math:`(N, *)` - Output: :math:`(N, *)` (same shape as input) For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.randn(2, 4, 3) >>> m = flow.nn.RMSLayerNorm(3) >>> y = m(x) >>> y.size() oneflow.Size([2, 4, 3]) """ def __init__(self, hidden_size, eps=1e-6): warnings.warn( f"nn.RMSLayerNorm has been deprecated. Please use nn.RMSNorm instead." ) super().__init__() self.weight = flow.nn.Parameter(flow.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): return flow._C.rms_layer_norm(hidden_states, self.weight, self.variance_epsilon) class RMSNorm(Module): """Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in the paper `Root Mean Square Layer Normalization `__ .. math:: y = \\frac{x}{\\mathrm{RMS}[x]} \\mathrm{weight},\\text{ where }\\mathrm{RMS}[x] = \\sqrt{\\frac{1}{n} \\sum_{i=1}^{n} x^{2}} There is no bias and no subtraction of mean with RMS Layer Normalization, and it only scales and doesn't shift. 
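For reference, the formula above can be written out directly with elementary ops; a short sketch normalizing over the last dimension (``eps`` omitted for clarity):

.. code-block:: python

    import oneflow as flow

    x = flow.randn(2, 4, 3)
    weight = flow.ones(3)
    rms = flow.sqrt(x.pow(2).mean(dim=-1, keepdim=True))
    y = x / rms * weight                  # RMS-normalized, then scaled per element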
The root mean squre are calculated separately over the last certain number dimensions which have to be of the shape specified by :attr:`normalized_shape`. :math:`\\weight` is learnable affine transform parameters of :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. .. note:: Like Layer Normalization, Root Mean Square Layer Normalization applies per-element scale with :attr:`elementwise_affine`. This layer uses statistics computed from input data in both training and evaluation modes. Args: normalized_shape (int or list or oneflow.Size): input shape from an expected input of size .. math:: [* \\times \\text{normalized_shape}[0] \\times \\text{normalized_shape}[1] \\times \\ldots \\times \\text{normalized_shape}[-1]] If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps: a value added to the denominator for numerical stability. Default: 1e-5 elementwise_affine: a boolean value that when set to ``True``, this module has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``. Shape: - Input: :math:`(N, *)` - Output: :math:`(N, *)` (same shape as input) For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input_arr = np.array( ... [ ... [ ... [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]], ... [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]], ... ], ... [ ... [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]], ... [[1.07541728, 0.11008703], [0.26361224, -0.48663723]], ... ], ... ], ... dtype=np.float32, ... ) >>> x = flow.Tensor(input_arr, device="cuda") >>> m = flow.nn.RMSNorm(2).to(device="cuda") >>> y = m(x).numpy() >>> y array([[[[-0.21632987, -1.3975569 ], [-1.127044 , 0.8541454 ]], [[-1.2975204 , -0.5625112 ], [ 1.3711083 , 0.34648165]]], [[[-1.2388322 , -0.6820876 ], [ 0.16959298, -1.4040003 ]], [[ 1.4068495 , 0.14401469], [ 0.6735778 , -1.2434478 ]]]], dtype=float32) """ _constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool def __init__( self, normalized_shape: _shape_t, eps: float = 1e-05, elementwise_affine: bool = True, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if isinstance(normalized_shape, int): normalized_shape = (normalized_shape,) self.normalized_shape = tuple(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = flow.nn.Parameter( flow.ones(*self.normalized_shape, **factory_kwargs) ) else: self.register_parameter("weight", None) def forward(self, x): return flow._C.rms_norm(x, self.weight, self.normalized_shape, self.eps) def extra_repr(self) -> str: return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format( **self.__dict__ ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/numel.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def numel_op(input): """ numel(input) -> int Returns the total number of elements in the :attr:`input` tensor. Args: input (oneflow.Tensor): Input Tensor .. code-block:: python >>> import oneflow as flow >>> a = flow.randn(1, 2, 3, 4, 5) >>> flow.numel(a) 120 >>> a = flow.zeros(4,4) >>> flow.numel(a) 16 """ return input.numel() if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/padding.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Union import oneflow as flow from oneflow.nn.common_types import _size_2_t, _size_4_t from oneflow.nn.modules.module import Module from oneflow.nn.modules.utils import _pair, _quadruple class ReplicationPad1d(Module): r""" ReplicationPad1d(padding) Pads the input tensor using replication of the input boundary. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ReplicationPad1d.html. For `N`-dimensional padding, use :func:`oneflow.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 2-`tuple`, uses (:math:`\text{padding_left}`, :math:`\text{padding_right}`) Shape: - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where :math:`W_{out} = W_{in} + \text{padding_left} + \text{padding_right}` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> m = flow.nn.ReplicationPad1d((2, 2)) >>> input = flow.tensor(np.arange(18).reshape((2, 3, 3)).astype(np.float32)) >>> out = m(input) >>> out tensor([[[ 0., 0., 0., 1., 2., 2., 2.], [ 3., 3., 3., 4., 5., 5., 5.], [ 6., 6., 6., 7., 8., 8., 8.]], [[ 9., 9., 9., 10., 11., 11., 11.], [12., 12., 12., 13., 14., 14., 14.], [15., 15., 15., 16., 17., 17., 17.]]], dtype=oneflow.float32) """ def __init__(self, padding: _size_4_t): super().__init__() if isinstance(padding, tuple): assert len(padding) == 2, ValueError("Padding length must be 2") boundary = [*padding] elif isinstance(padding, int): boundary = _pair(padding) else: raise ValueError("padding must be in or list or tuple!") self.padding = boundary def forward(self, x): return flow._C.pad(x, pad=self.padding, mode="replicate") def extra_repr(self) -> str: return "{}".format(self.padding) class ReplicationPad2d(Module): """ ReplicationPad2d(padding) Pads the input tensor using the replication of the input boundary. 
The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ReplicationPad2d.html. Args: padding (Union[int, tuple, list]): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\\mathrm{padding_{left}}`, :math:`\\mathrm{padding_{right}}`, :math:`\\mathrm{padding_{top}}`, :math:`\\mathrm{padding_{bottom}}`) Shape: - Input: :math:`(N, C, H_{\\text{in}}, W_{\\text{in}})` or :math:`(C, H_{in}, W_{in})` - Output: :math:`(N, C, H_{\\text{out}}, W_{\\text{out}})` or :math:`(C, H_{out}, W_{out})` where :math:`H_{out} = H_{in} + \\mathrm{padding_{top}} + \\mathrm{padding_{bottom}}` :math:`W_{out} = W_{in} + \\mathrm{padding_{left}} + \\mathrm{padding_{right}}` For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> m = flow.nn.ReplicationPad2d((2, 2, 1, 1)) >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32)) >>> input_int = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.int32)) >>> output = m(input) >>> output.shape oneflow.Size([1, 2, 5, 7]) >>> output tensor([[[[ 0., 0., 0., 1., 2., 2., 2.], [ 0., 0., 0., 1., 2., 2., 2.], [ 3., 3., 3., 4., 5., 5., 5.], [ 6., 6., 6., 7., 8., 8., 8.], [ 6., 6., 6., 7., 8., 8., 8.]], [[ 9., 9., 9., 10., 11., 11., 11.], [ 9., 9., 9., 10., 11., 11., 11.], [12., 12., 12., 13., 14., 14., 14.], [15., 15., 15., 16., 17., 17., 17.], [15., 15., 15., 16., 17., 17., 17.]]]], dtype=oneflow.float32) """ def __init__(self, padding: _size_4_t): super().__init__() if isinstance(padding, (tuple, list)): assert len(padding) == 4, ValueError("Length of padding must be 4") boundary = [*padding] elif isinstance(padding, int): boundary = _quadruple(padding) else: raise ValueError("padding must be int or list or tuple!") self.padding = boundary def forward(self, x): return flow._C.pad(x, pad=self.padding, mode="replicate") def extra_repr(self) -> str: return "{}".format(self.padding) class ReflectionPad1d(Module): """ ReflectionPad1d(padding) This operator pads the input tensor using the reflection of the input boundary. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ReflectionPad1d.html. Args: padding (Union[int,tuple]): The size or bundary of padding, if is `int` uses the same padding in all dimension; if 4-dims `tuple`, uses :math:`(\\text{padding}_{\\text{left}}, \\text{padding}_{\\text{right}}, \\text{padding}_{\\text{top}}, \\text{padding}_{\\text{bottom}} )` Returns: Tensor: Returns a new tensor which is result of the reflection padding of the input tensor. Shape: - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where :math:`W_{out} = W_{in} + \\text{padding_left} + \\text{padding_right}` For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.arange(18).reshape((2, 3, 3)).astype(np.float32)) >>> m = flow.nn.ReflectionPad1d((2, 2)) >>> out = m(input) >>> out tensor([[[ 2., 1., 0., 1., 2., 1., 0.], [ 5., 4., 3., 4., 5., 4., 3.], [ 8., 7., 6., 7., 8., 7., 6.]], [[11., 10., 9., 10., 11., 10., 9.], [14., 13., 12., 13., 14., 13., 12.], [17., 16., 15., 16., 17., 16., 15.]]], dtype=oneflow.float32) """ def __init__(self, padding: _size_2_t) -> None: super().__init__() if isinstance(padding, tuple): assert len(padding) == 2, ValueError("Padding length must be 2") boundary = [*padding] elif isinstance(padding, int): boundary = _pair(padding) else: raise ValueError("padding must be in or list or tuple!") self.padding = boundary def forward(self, x): return flow._C.pad(x, pad=self.padding, mode="reflect") def extra_repr(self) -> str: return "{}".format(self.padding) class ReflectionPad2d(Module): """ ReflectionPad2d(padding) This operator pads the input tensor using the reflection of the input boundary. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ReflectionPad2d.html. Args: padding (Union[int,tuple]): The size or bundary of padding, if is `int` uses the same padding in all dimension; if 4-dims `tuple`, uses :math:`(\\text{padding}_{\\text{left}}, \\text{padding}_{\\text{right}}, \\text{padding}_{\\text{top}}, \\text{padding}_{\\text{bottom}} )` Returns: Tensor: Returns a new tensor which is result of the reflection padding of the input tensor. Shape: - Input: :math:`(N, C, H_{\\text{in}}, W_{\\text{in}})` or :math:`(C, H_{in}, W_{in})` - Output: :math:`(N, C, H_{\\text{out}}, W_{\\text{out}})` or :math:`(C, H_{out}, W_{out})` where :math:`H_{\\text{out}} = H_{\\text{in}} + \\text{padding}_{\\text{top}} + \\text{padding}_{\\text{bottom}}` :math:`W_{\\text{out}} = W_{\\text{in}} + \\text{padding}_{\\text{left}} + \\text{padding}_{\\text{right}}` For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32)) >>> m = flow.nn.ReflectionPad2d((2, 2, 1, 1)) >>> out = m(input) >>> out tensor([[[[ 5., 4., 3., 4., 5., 4., 3.], [ 2., 1., 0., 1., 2., 1., 0.], [ 5., 4., 3., 4., 5., 4., 3.], [ 8., 7., 6., 7., 8., 7., 6.], [ 5., 4., 3., 4., 5., 4., 3.]], [[14., 13., 12., 13., 14., 13., 12.], [11., 10., 9., 10., 11., 10., 9.], [14., 13., 12., 13., 14., 13., 12.], [17., 16., 15., 16., 17., 16., 15.], [14., 13., 12., 13., 14., 13., 12.]]]], dtype=oneflow.float32) """ def __init__(self, padding: _size_4_t) -> None: super().__init__() if isinstance(padding, tuple): assert len(padding) == 4, ValueError("Padding length must be 4") boundary = [*padding] elif isinstance(padding, int): boundary = _quadruple(padding) else: raise ValueError("padding must be in or list or tuple!") self.padding = boundary def forward(self, x): return flow._C.pad(x, pad=self.padding, mode="reflect") def extra_repr(self) -> str: return "{}".format(self.padding) class ConstantPad1d(Module): """ ConstantPad1d(padding) Pads the input tensor boundaries with a constant value. The interface is consistent with PyTorch, and referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ConstantPad1d.html. For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, list, tuple): the size of the padding. If is `int`, uses the same padding in both boundaries. 
If a 2-`tuple`, uses (:math:`\\text{padding_left}`, :math:`\\text{padding_right}`) value (int, float): The constant value used for padding. Defaults to 0. Shape: - Input: :math:`(N, C, W_{in})` - Output: :math:`(N, C, W_{out})` where :math:`W_{out} = W_{in} + \\text{padding\\_left} + \\text{padding\\_right}` For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.arange(8).reshape(2,2,2).astype(np.float32)) >>> m = flow.nn.ConstantPad1d(padding=[1, 2], value=9.9999) >>> output = m(input) >>> output tensor([[[9.9999, 0.0000, 1.0000, 9.9999, 9.9999], [9.9999, 2.0000, 3.0000, 9.9999, 9.9999]], [[9.9999, 4.0000, 5.0000, 9.9999, 9.9999], [9.9999, 6.0000, 7.0000, 9.9999, 9.9999]]], dtype=oneflow.float32) """ def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0): super().__init__() if isinstance(padding, (tuple, list)): boundary = padding elif isinstance(padding, int): boundary = [padding] * 2 else: raise ValueError("padding must be int or list or tuple!") self.padding = boundary self.value = value def forward(self, x): return flow._C.pad(x, pad=self.padding, mode="constant", value=self.value) class ConstantPad2d(Module): """ ConstantPad2d(padding) This operator pads the input with constant value that user specifies. User can set the amount of padding by setting the parameter `paddings`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ConstantPad2d.html. Args: padding (int, tuple, list): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\\mathrm{padding_{left}}`, :math:`\\mathrm{padding_{right}}`, :math:`\\mathrm{padding_{top}}`, :math:`\\mathrm{padding_{bottom}}`) value (int, float): The constant value used for padding. Defaults to 0. Shape: - Input: :math:`(N, C, H_{in}, W_{in})` - Output: :math:`(N, C, H_{out}, W_{out})` where :math:`H_{out} = H_{in} + \\mathrm{padding_{top}} + \\mathrm{padding_{bottom}}` :math:`W_{out} = W_{in} + \\mathrm{padding_{left}} + \\mathrm{padding_{right}}` For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> m = flow.nn.ConstantPad2d((2, 2, 1, 1), 1) >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32)) >>> output = m(input) >>> output.shape oneflow.Size([1, 2, 5, 7]) >>> output tensor([[[[ 1., 1., 1., 1., 1., 1., 1.], [ 1., 1., 0., 1., 2., 1., 1.], [ 1., 1., 3., 4., 5., 1., 1.], [ 1., 1., 6., 7., 8., 1., 1.], [ 1., 1., 1., 1., 1., 1., 1.]], [[ 1., 1., 1., 1., 1., 1., 1.], [ 1., 1., 9., 10., 11., 1., 1.], [ 1., 1., 12., 13., 14., 1., 1.], [ 1., 1., 15., 16., 17., 1., 1.], [ 1., 1., 1., 1., 1., 1., 1.]]]], dtype=oneflow.float32) """ def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0): super().__init__() if isinstance(padding, (tuple, list)): boundary = padding elif isinstance(padding, int): boundary = [padding] * 4 else: raise ValueError("padding must be int or list or tuple!") self.padding = boundary self.value = value def forward(self, x): return flow._C.pad(x, pad=self.padding, mode="constant", value=self.value) class ConstantPad3d(Module): """ ConstantPad3d(padding) Pads the input tensor boundaries with a constant value. The interface is consistent with PyTorch, and referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ConstantPad3d.html. For `N`-dimensional padding, use :func:`flow.nn.functional.pad()`. 
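A hedged sketch of that functional route, assuming ``flow.nn.functional.pad`` follows the usual ``(input, pad, mode, value)`` signature with ``pad`` listing (left, right) pairs from the last dimension backwards:

.. code-block:: python

    import oneflow as flow

    x = flow.zeros(1, 1, 2, 2, 2)
    # pad each of the last three dimensions by 1 on both sides with the value 9
    y = flow.nn.functional.pad(x, (1, 1, 1, 1, 1, 1), mode="constant", value=9.0)
    # y.shape -> (1, 1, 4, 4, 4)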
Args: padding (int, list, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 6-`tuple`, uses (:math:`\\text{padding_left}`, :math:`\\text{padding_right}`, :math:`\\text{padding_top}`, :math:`\\text{padding_bottom}`, :math:`\\text{padding_front}`, :math:`\\text{padding_back}`) value (int, float): The constant value used for padding. Defaults to 0. Shape: - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where :math:`D_{out} = D_{in} + \\text{padding_front} + \\text{padding_back}` :math:`H_{out} = H_{in} + \\text{padding_top} + \\text{padding_bottom}` :math:`W_{out} = W_{in} + \\text{padding_left} + \\text{padding_right}` Examples:: >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.arange(8).reshape(1,1,2,2,2).astype(np.int32)) >>> m = flow.nn.ConstantPad3d(padding=1, value=9) >>> output = m(input) >>> output tensor([[[[[9, 9, 9, 9], [9, 9, 9, 9], [9, 9, 9, 9], [9, 9, 9, 9]], [[9, 9, 9, 9], [9, 0, 1, 9], [9, 2, 3, 9], [9, 9, 9, 9]], [[9, 9, 9, 9], [9, 4, 5, 9], [9, 6, 7, 9], [9, 9, 9, 9]], [[9, 9, 9, 9], [9, 9, 9, 9], [9, 9, 9, 9], [9, 9, 9, 9]]]]], dtype=oneflow.int32) """ def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0): super().__init__() if isinstance(padding, (tuple, list)): boundary = padding elif isinstance(padding, int): boundary = [padding] * 6 else: raise ValueError("padding must be int or list or tuple!") self.padding = boundary self.value = value def forward(self, x): return flow._C.pad(x, pad=self.padding, mode="constant", value=self.value) class ZeroPad2d(Module): """ ZeroPad2d(padding) Pads the input tensor boundaries with zero. User can set the amount of padding by setting the parameter `paddings`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ZeroPad2d.html. Args: padding (Union[int, tuple]): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\\mathrm{padding_{left}}`, :math:`\\mathrm{padding_{right}}`, :math:`\\mathrm{padding_{top}}`, :math:`\\mathrm{padding_{bottom}}`) Shape: - Input: :math:`(N, C, H_{in}, W_{in})` - Output: :math:`(N, C, H_{out}, W_{out})` where :math:`H_{out} = H_{in} + \\mathrm{padding_{top}} + \\mathrm{padding_{bottom}}` :math:`W_{out} = W_{in} + \\mathrm{padding_{left}} + \\mathrm{padding_{right}}` For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> m1 = flow.nn.ZeroPad2d(2) >>> m2 = flow.nn.ZeroPad2d((1,2,2,0)) >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32)) >>> output = m1(input) >>> output.shape oneflow.Size([1, 2, 7, 7]) >>> output tensor([[[[ 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 1., 2., 0., 0.], [ 0., 0., 3., 4., 5., 0., 0.], [ 0., 0., 6., 7., 8., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0.]], [[ 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 9., 10., 11., 0., 0.], [ 0., 0., 12., 13., 14., 0., 0.], [ 0., 0., 15., 16., 17., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0.]]]], dtype=oneflow.float32) >>> output = m2(input) >>> output tensor([[[[ 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0.], [ 0., 0., 1., 2., 0., 0.], [ 0., 3., 4., 5., 0., 0.], [ 0., 6., 7., 8., 0., 0.]], [[ 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0.], [ 0., 9., 10., 11., 0., 0.], [ 0., 12., 13., 14., 0., 0.], [ 0., 15., 16., 17., 0., 0.]]]], dtype=oneflow.float32) """ def __init__(self, padding: Union[int, tuple, list]): super().__init__() if isinstance(padding, (tuple, list)): boundary = padding elif isinstance(padding, int): boundary = [padding] * 4 else: raise ValueError("padding must be int or list or tuple!") self.padding = boundary self.value = 0.0 def forward(self, x): return flow._C.pad(x, pad=self.padding, mode="constant", value=self.value) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/pixelshuffle.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.modules.module import Module class PixelShufflev2(Module): """ Part of the documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.PixelShuffle.html. Rearranges elements in a tensor of shape :math:`(*, C \\times r_h \\times r_w, H, W)` to a tensor of shape :math:`(*, C, H \\times r_h, W \\times r_w)`, where r_h and r_w are upscale factors. This is useful for implementing efficient sub-pixel convolution with a stride of :math:`1/r`. See the paper: `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ by Shi et. al (2016) for more details. Args: upscale_factor (int, optional): factor to increase spatial resolution by, only use when factors of height and width spatial are the same. h_upscale_factor (int, optional): factor to increase height spatial resolution by, only one of h_upscale_factor and upscale_factor can be used. w_upscale_factor (int, optional): factor to increase width spatial resolution by, only one of w_upscale_factor and upscale_factor can be used. 
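Conceptually the rearrangement can be sketched with reshape/permute; this is illustrative only and assumes the channels group the factors as ``(C, r_h, r_w)``:

.. code-block:: python

    import oneflow as flow

    n, c, rh, rw, h, w = 1, 2, 3, 4, 2, 2
    x = flow.randn(n, c * rh * rw, h, w)
    y = (
        x.reshape(n, c, rh, rw, h, w)
        .permute(0, 1, 4, 2, 5, 3)        # (n, c, h, rh, w, rw)
        .reshape(n, c, h * rh, w * rw)
    )
    # y.shape -> (1, 2, 6, 8), matching the PixelShuffle example below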
Shape: - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where if use upscale_factor: .. math:: C_{out} = C_{in} \\div \\text{h_upscale_factor}^2 H_{out} = H_{in} \\times \\text{upscale_factor} W_{out} = W_{in} \\times \\text{upscale_factor} if use h_upscale_factor and w_upscale_factor: .. math:: C_{out} = C_{in} \\div \\text{h_upscale_factor} \\div \\text{w_upscale_factor} H_{out} = H_{in} \\times \\text{h_upscale_factor} W_{out} = W_{in} \\times \\text{w_upscale_factor} For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> m = flow.nn.PixelShuffle(upscale_factor=2) >>> x = flow.Tensor(np.random.randn(3, 4, 5, 5)) >>> y = m(x) >>> y.shape oneflow.Size([3, 1, 10, 10]) >>> m = flow.nn.PixelShuffle(h_upscale_factor=3, w_upscale_factor=4) >>> x = flow.Tensor(np.random.randn(1, 24, 2, 2)) >>> y = m(x) >>> y.shape oneflow.Size([1, 2, 6, 8]) .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: https://arxiv.org/abs/1609.05158 """ def __init__( self, upscale_factor: Optional[int] = None, h_upscale_factor: Optional[int] = None, w_upscale_factor: Optional[int] = None, ) -> None: super().__init__() if upscale_factor is None: assert ( h_upscale_factor is not None and w_upscale_factor is not None ), "h_upscale_factor and w_upscale_factor should be None if use upscale_factor" else: assert ( h_upscale_factor is None and w_upscale_factor is None ), "upscale_factor should be None if use h_upscale_factor and w_upscale_factor" h_upscale_factor = upscale_factor w_upscale_factor = upscale_factor assert ( h_upscale_factor > 0 and w_upscale_factor > 0 ), "The scale factor of height and width must larger than zero" self.h_upscale_factor = h_upscale_factor self.w_upscale_factor = w_upscale_factor def forward(self, input: Tensor) -> Tensor: return flow._C.pixel_shuffle( input, self.h_upscale_factor, self.w_upscale_factor ) def extra_repr(self) -> str: return f"w_upscale_factor={self.w_upscale_factor}, h_upscale_factor={self.h_upscale_factor}" if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/pooling.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional, Union, List import os import oneflow as flow from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t from oneflow.nn.modules.module import Module from oneflow.nn.modules.utils import ( _generate_output_size, _getint, _pair, _single, _triple, ) class MaxPool1d(Module): r"""Applies a 1D max pooling over an input signal composed of several input planes. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxPool1d.html. 
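In addition to the basic example further below, ``return_indices=True`` makes the module also return the argmax positions; a short sketch (shapes are illustrative):

.. code-block:: python

    import oneflow as flow

    m = flow.nn.MaxPool1d(kernel_size=2, stride=2, return_indices=True)
    x = flow.randn(1, 3, 8)
    values, indices = m(x)                # indices locate each max within the input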
In the simplest case, the output value of the layer with input size :math:`(N, C, L)` and output :math:`(N, C, L_{out})` can be precisely described as: .. math:: out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1} input(N_i, C_j, stride \times k + m) If :attr:`padding` is non-zero, then the input is implicitly padded with minimum value on both sides for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the sliding window. This link has a nice visualization of the pooling parameters. Note: When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the input. Sliding windows that would start in the right padded region are ignored. Args: kernel_size: The size of the sliding window, must be > 0. stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`. padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. dilation: The stride between elements within a sliding window, must be > 0. return_indices: If ``True``, will return the argmax along with the max values. ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window. Shape: - Input: :math:`(N, C, L_{in})` - Output: :math:`(N, C, L_{out})`, where .. math:: L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation} \times (\text{kernel_size} - 1) - 1}{\text{stride}} + 1\right\rfloor For example: .. code-block:: python import oneflow as flow import numpy as np of_maxpool1d = flow.nn.MaxPool1d(kernel_size=3, padding=1, stride=1) x = flow.Tensor(np.random.randn(1, 4, 4)) y = of_maxpool1d(x) y.shape oneflow.Size([1, 4, 4]) """ def __init__( self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0, dilation: _size_1_t = 1, return_indices: bool = False, ceil_mode: bool = False, ): super().__init__() self.kernel_size = _single(kernel_size) self.stride = _single(stride) if stride is not None else self.kernel_size data_format = "NCL" # only support "NCL" for now ! 
self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last" self.dilation = _single(dilation) self.padding = _single(padding) self.return_indices = return_indices self.ceil_mode = ceil_mode def forward(self, x): y, indice = flow._C.max_pool1d( x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, return_indices=True, ceil_mode=self.ceil_mode, data_format=self.channel_pos, ) if self.return_indices: return y, indice else: return y def extra_repr(self) -> str: return "kernel_size={}, stride={}, padding={}".format( self.kernel_size, self.stride, self.padding ) def get_dhw_offset(channel_pos): if channel_pos == "channels_first": return 2 else: return 1 def get_ndim_pads_list(padding, dhw_offset, ndims): pads_list = [] for i in range(len(padding)): pad = padding[i] if isinstance(pad, int): pad = [pad, pad] elif isinstance(pad, (list, tuple)): assert len(pad) == 2 pad = [pad[0], pad[1]] else: raise ValueError("padding must be list tuple or int") if i in range(dhw_offset, dhw_offset + ndims): pads_list.append(pad) else: assert pad == [0, 0] return pads_list def calc_pool_padding(padding, dhw_offset, ndims): if isinstance(padding, str): padding = "SAME_LOWER" if padding.upper() == "SAME" else padding assert padding.upper() in ["VALID", "SAME_LOWER", "SAME_UPPER"] padding_type = padding.lower() ndim_pads_list = [[0, 0]] * ndims elif isinstance(padding, (list, tuple)): padding_type = "customized" ndim_pads_list = get_ndim_pads_list(padding, dhw_offset, ndims) else: raise ValueError("padding must be str or a list.") return (padding_type, ndim_pads_list) class MaxPool2d(Module): r"""Applies a 2D max pooling over an input signal composed of several input planes. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxPool2d.html. In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` can be precisely described as: .. math:: \begin{aligned} out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ & \text{input}(N_i, C_j, \text{stride[0]} \times h + m, \text{stride[1]} \times w + n) \end{aligned} If :attr:`padding` is non-zero, then the input is implicitly minimum value padded on both sides for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this link has a nice visualization of what :attr:`dilation` does. Note: When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the input. Sliding windows that would start in the right padded region are ignored. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, and the second `int` for the width dimension Args: kernel_size: the size of the window to take a max over stride: the stride of the window. Default value is :attr:`kernel_size` padding: implicit minimum value padding to be added on both sides dilation: a parameter that controls the stride of elements in the window return_indices: if ``True``, will return the max indices along with the outputs. 
Useful for :class:`torch.nn.MaxUnpool2d` later ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape Shape: - Input: :math:`(N, C, H_{in}, W_{in})` - Output: :math:`(N, C, H_{out}, W_{out})`, where .. math:: H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor .. math:: W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor For example: .. code-block:: python import oneflow as flow import numpy as np m = flow.nn.MaxPool2d(kernel_size=3, padding=1, stride=1) x = flow.Tensor(np.random.randn(1, 4, 4, 4)) y = m(x) y.shape oneflow.Size([1, 4, 4, 4]) """ def __init__( self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, dilation: _size_2_t = 1, return_indices: bool = False, ceil_mode: bool = False, ): super().__init__() self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size) self.padding = _pair(padding) self.dilation = _pair(dilation) self.return_indices = return_indices self.ceil_mode = ceil_mode if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": self.channel_pos = "channels_last" else: self.channel_pos = "channels_first" def to_memory_format(self, memory_format) -> None: if memory_format is flow.channels_last: self.channel_pos = "channels_last" elif memory_format is flow.contiguous_format: self.channel_pos = "channels_first" def forward(self, x): if not self.return_indices: return flow._C.max_pool2d( x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, return_indices=self.return_indices, ceil_mode=self.ceil_mode, data_format=self.channel_pos, )[0] else: return flow._C.max_pool2d( x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, return_indices=self.return_indices, ceil_mode=self.ceil_mode, data_format=self.channel_pos, ) def extra_repr(self) -> str: return "kernel_size={}, stride={}, padding={}, dilation={}".format( self.kernel_size, self.stride, self.padding, self.dilation ) class MaxPool3d(Module): r"""Applies a 3D max pooling over an input signal composed of several input planes. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxPool3d.html. In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` can be precisely described as: .. math:: \begin{aligned} \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ & \text{input}(N_i, C_j, \text{stride[0]} \times d + k, \text{stride[1]} \times h + m, \text{stride[2]} \times w + n) \end{aligned} If :attr:`padding` is non-zero, then the input is implicitly minimum value on both sides for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this link has a nice visualization of what :attr:`dilation` does. Note: When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the input. Sliding windows that would start in the right padded region are ignored. 
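# --- Illustrative aside (not part of the module source). A small sketch of how the
# ceil_mode note above plays out in practice, using the MaxPool2d module defined earlier.
# The printed sizes are what the floor/ceil output-shape formulas predict.
import oneflow as flow

x = flow.randn(1, 1, 5, 5)
floor_pool = flow.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False)
ceil_pool = flow.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
print(floor_pool(x).shape)  # expected: oneflow.Size([1, 1, 2, 2]) -- trailing partial window dropped
print(ceil_pool(x).shape)   # expected: oneflow.Size([1, 1, 3, 3]) -- partial window starting inside the input is kept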
The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimension - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, the second `int` for the height dimension and the third `int` for the width dimension Args: kernel_size: the size of the window to take a max over stride: the stride of the window. Default value is :attr:`kernel_size` padding: implicit minimum value padding to be added on all three sides dilation: a parameter that controls the stride of elements in the window return_indices: if ``True``, will return the max indices along with the outputs. Useful for :class:`torch.nn.MaxUnpool3d` later ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape Shape: - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where .. math:: D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor .. math:: H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor .. math:: W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times (\text{kernel_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor For example: .. code-block:: python import oneflow as flow import numpy as np of_maxpool3d = flow.nn.MaxPool3d(kernel_size=3, padding=1, stride=1) x = flow.Tensor(np.random.randn(1, 4, 4, 4, 4)) y = of_maxpool3d(x) y.shape oneflow.Size([1, 4, 4, 4, 4]) """ def __init__( self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, dilation: _size_3_t = 1, return_indices: bool = False, ceil_mode: bool = False, ): super().__init__() self.kernel_size = _triple(kernel_size) self.stride = _triple(stride) if (stride is not None) else _triple(kernel_size) data_format = "NCDHW" self.channel_pos = ( "channels_last" if data_format == "NDHWC" else "channels_first" ) self.dilation = _triple(dilation) self.padding = _triple(padding) self.return_indices = return_indices self.ceil_mode = ceil_mode def forward(self, x): y, indice = flow._C.max_pool3d( x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, return_indices=True, ceil_mode=self.ceil_mode, data_format=self.channel_pos, ) if self.return_indices: return y, indice else: return y def extra_repr(self) -> str: return "kernel_size={}, stride={}, padding={}, dilation={}".format( self.kernel_size, self.stride, self.padding, self.dilation ) class AvgPool1d(Module): r"""Applies a 1D average pooling over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size :math:`(N, C, L_{in})`, output :math:`(N, C, L_{out})` and `kernel_size` :math:`k` can be precisely described as: .. math:: out(N_i, C_j, l) = \\frac{1}{k} \\sum_{m=0}^{k-1} input(N_i, C_j, stride \\times l + m) If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points. The parameters kernel_size, stride, padding can each be an int or a one-element tuple. Note: When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the input.
Sliding windows that would start in the right padded region are ignored. Args: kernel_size: the size of the window. strides: the stride of the window. Default value is kernel_size. padding: implicit zero padding to be added on both sides. ceil_mode: when True, will use ceil instead of floor to compute the output shape. count_include_pad: when True, will include the zero-padding in the averaging calculation. For example: .. code-block:: python import oneflow as flow import numpy as np m = flow.nn.AvgPool1d(kernel_size=3, padding=1, stride=1) x = flow.tensor(np.random.randn(1, 4, 4)) y = m(x) y.shape oneflow.Size([1, 4, 4]) """ def __init__( self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, ): super().__init__() self.kernel_size = _single(kernel_size) data_format = "NCHW" # only support "NCHW" for now ! self.channel_pos = ( "channels_first" if data_format == "NCHW" else "channels_last" ) self.stride = _single(stride) if (stride is not None) else _single(kernel_size) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad self.padding = _single(padding) def forward(self, x): return flow._C.avg_pool1d( x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, ceil_mode=self.ceil_mode, count_include_pad=self.count_include_pad, divisor_override=0, data_format=self.channel_pos, ) def extra_repr(self) -> str: return ( "kernel_size={kernel_size}, stride={stride}, padding={padding}" ", ceil_mode={ceil_mode}".format(**self.__dict__) ) class AvgPool2d(Module): r"""Performs the 2d-average pooling on the input. In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, output :math:`(N, C, H_{out}, W_{out})` and `kernel_size` :math:`(kH, kW)` can be precisely described as: .. math:: out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) Args: kernel_size (Union[int, Tuple[int, int]]): An int or list of ints that has length 1, 2. The size of the window for each dimension of the input Tensor. strides (Union[int, Tuple[int, int]]): An int or list of ints that has length 1, 2. The stride of the sliding window for each dimension of the input Tensor. padding (Tuple[int, int]): An int or list of ints that has length 1, 2. Implicit zero padding to be added on both sides. ceil_mode (bool, default to False): When True, will use ceil instead of floor to compute the output shape. For example: .. 
code-block:: python import oneflow as flow import numpy as np m = flow.nn.AvgPool2d(kernel_size=3, padding=1, stride=1) x = flow.tensor(np.random.randn(1, 4, 4, 4)) y = m(x) y.shape oneflow.Size([1, 4, 4, 4]) """ def __init__( self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: int = 0, ): super().__init__() self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size) self.ceil_mode = ceil_mode self.channel_pos = "channels_first" if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": self.channel_pos = "channels_last" self.padding = _pair(padding) self.count_include_pad = count_include_pad self.divisor_override = int(divisor_override) def to_memory_format(self, memory_format) -> None: if memory_format is flow.channels_last: self.channel_pos = "channels_last" elif memory_format is flow.contiguous_format: self.channel_pos = "channels_first" def forward(self, x): return flow._C.avg_pool2d( x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, ceil_mode=self.ceil_mode, count_include_pad=self.count_include_pad, divisor_override=self.divisor_override, data_format=self.channel_pos, ) def extra_repr(self) -> str: return ( "kernel_size={kernel_size}, stride={stride}, padding={padding}" ", ceil_mode={ceil_mode}".format(**self.__dict__) ) class AvgPool3d(Module): r"""Applies a 3D average pooling over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, output :math:`(N, C, D_{out}, H_{out}, W_{out})` and `kernel_size` :math:`(kD, kH, kW)` can be precisely described as: .. math:: out(N_i, C_j, d, h, w) = \\frac{1}{kD * kH * kW } \\sum_{k=0}^{kD-1} \\sum_{m=0}^{kH-1} \\sum_{n=0}^{kW-1} input(N_i, C_j, stride[0] \\times d + k, stride[1] \\times h + m, stride[2] \\times w + n) If padding is non-zero, then the input is implicitly zero-padded on all three sides for padding number of points. Note: When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the input. Sliding windows that would start in the right padded region are ignored. Args: kernel_size: the size of the window. strides: the stride of the window. Default value is kernel_size. padding: implicit zero padding to be added on all three sides. ceil_mode: when True, will use ceil instead of floor to compute the output shape. count_include_pad: when True, will include the zero-padding in the averaging calculation. divisor_override: if specified, it will be used as divisor, otherwise kernel_size will be used. Shape: - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where .. math:: D_{out} = \\left\\lfloor\\frac{D_{in} + 2 \\times \\text{padding}[0] - \\text{kernel_size}[0]}{\\text{stride}[0]} + 1\\right\\rfloor .. math:: H_{out} = \\left\\lfloor\\frac{H_{in} + 2 \\times \\text{padding}[1] - \\text{kernel_size}[1]}{\\text{stride}[1]} + 1\\right\\rfloor .. math:: W_{out} = \\left\\lfloor\\frac{W_{in} + 2 \\times \\text{padding}[2] - \\text{kernel_size}[2]}{\\text{stride}[2]} + 1\\right\\rfloor For example: .. 
code-block:: python import oneflow as flow import numpy as np m = flow.nn.AvgPool3d(kernel_size=(2,2,2),padding=(0,0,0),stride=(1,1,1)) x = flow.tensor(np.random.randn(9, 7, 11, 32, 20)) y = m(x) y.shape oneflow.Size([9, 7, 10, 31, 19]) """ def __init__( self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: int = 0, ): super().__init__() self.kernel_size = _triple(kernel_size) data_format = "NCHW" # only support "NCHW" for now ! self.channel_pos = ( "channels_first" if data_format == "NCHW" else "channels_last" ) self.stride = _triple(stride) if (stride is not None) else _triple(kernel_size) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad self.divisor_override = int(divisor_override) self.padding = _triple(padding) def forward(self, x): return flow._C.avg_pool3d( x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, ceil_mode=self.ceil_mode, count_include_pad=self.count_include_pad, divisor_override=self.divisor_override, data_format=self.channel_pos, ) def extra_repr(self) -> str: return ( "kernel_size={kernel_size}, stride={stride}, padding={padding}" ", ceil_mode={ceil_mode}".format(**self.__dict__) ) class AdaptiveAvgPool1d(Module): """Applies a 1D adaptive average pooling over an input signal composed of several input planes. The output size is H, for any input size. The number of output features is equal to the number of input planes. Args: output_size: the target output size H For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> m = nn.AdaptiveAvgPool1d(5) >>> input = flow.Tensor(np.random.randn(1, 64, 8)) >>> output = m(input) >>> output.size() oneflow.Size([1, 64, 5]) """ def __init__(self, output_size: _size_1_t) -> None: super().__init__() assert output_size is not None, "'output_size' cannot be NoneType" self.output_size = _single(output_size) def forward(self, x): assert ( len(x.shape) == 3 and len(self.output_size) == 1 ), "the length of 'output_size' does not match the input size, 1 expected" assert isinstance( self.output_size[0], int ), "numbers in 'output_size' should be integer" return flow._C.adaptive_avg_pool1d(x, output_size=self.output_size) def adaptive_avg_pool1d(input, output_size): """Applies a 1D adaptive average pooling over an input signal composed of several input planes. See :mod:`oneflow.nn.AdaptiveAvgPool1d` Args: input: input tensor output_size: the target output size (single integer) """ return AdaptiveAvgPool1d(output_size)(input) class AdaptiveAvgPool2d(Module): """Applies a 2D adaptive average pooling over an input signal composed of several input planes. The output is of size H x W, for any input size. The number of output features is equal to the number of input planes. Args: output_size: the target output size of the image of the form H x W. Can be a tuple (H, W) or a single H for a square image H x H. H and W can be either a ``int``, or ``None`` which means the size will be the same as that of the input. For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> m = nn.AdaptiveAvgPool2d((5,7)) >>> input = flow.Tensor(np.random.randn(1, 64, 8, 9)) >>> output = m(input) >>> output.size() oneflow.Size([1, 64, 5, 7]) >>> m = nn.AdaptiveAvgPool2d(7) >>> input = flow.Tensor(np.random.randn(1, 64, 10, 9)) >>> output = m(input) >>> output.size() oneflow.Size([1, 64, 7, 7]) >>> m = nn.AdaptiveAvgPool2d((None, 7)) >>> input = flow.Tensor(np.random.randn(1, 64, 10, 9)) >>> output = m(input) >>> output.size() oneflow.Size([1, 64, 10, 7]) """ def __init__(self, output_size, data_format=None) -> None: super().__init__() assert output_size is not None, "'output_size' cannot be NoneType" self.output_size = _pair(output_size) if data_format: if not data_format in ["channels_first", "channels_last"]: raise ValueError( f"data_format must be one of ['channels_first', 'channels_last'], but got {data_format}" ) self.channel_pos = data_format elif os.getenv("ONEFLOW_ENABLE_NHWC") == "1": self.channel_pos = "channels_last" else: self.channel_pos = "channels_first" def to_memory_format(self, memory_format) -> None: if memory_format is flow.channels_last: self.channel_pos = "channels_last" elif memory_format is flow.channels_first: self.channel_pos = "channels_first" def forward(self, x): assert ( len(x.shape) == 4 ), f"expected 4-dimensional tensor, but got {len(x.shape)}-dimensional tensor" new_output_size = _generate_output_size(x.shape, self.output_size) return flow._C.adaptive_avg_pool2d( x, output_size=new_output_size, data_format=self.channel_pos ) def adaptive_avg_pool2d(input, output_size, data_format=None): """Applies a 2D adaptive average pooling over an input signal composed of several input planes. See :mod:`oneflow.nn.AdaptiveAvgPool2d` Args: input: input tensor output_size: the target output size (single integer or double-integer tuple) """ return AdaptiveAvgPool2d(output_size, data_format)(input) class AdaptiveAvgPool3d(Module): """Applies a 3D adaptive average pooling over an input signal composed of several input planes. The output is of size D x H x W, for any input size. The number of output features is equal to the number of input planes. Args: output_size: the target output size of the form D x H x W. Can be a tuple (D, H, W) or a single number D for a cube D x D x D. D, H and W can be either a ``int``, or ``None`` which means the size will be the same as that of the input. For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> m = nn.AdaptiveAvgPool3d((5,7,9)) >>> input = flow.Tensor(np.random.randn(1, 64, 8, 9, 10)) >>> output = m(input) >>> output.size() oneflow.Size([1, 64, 5, 7, 9]) >>> m = nn.AdaptiveAvgPool3d(7) >>> input = flow.Tensor(np.random.randn(1, 64, 10, 9, 8)) >>> output = m(input) >>> output.size() oneflow.Size([1, 64, 7, 7, 7]) >>> m = nn.AdaptiveAvgPool3d((7, None, None)) >>> input = flow.Tensor(np.random.randn(1, 64, 10, 9, 8)) >>> output = m(input) >>> output.size() oneflow.Size([1, 64, 7, 9, 8]) """ def __init__(self, output_size) -> None: super().__init__() assert output_size is not None, "'output_size' cannot be NoneType" self.output_size = _triple(output_size) def forward(self, x): assert ( len(x.shape) == 5 ), f"expected 5-dimensional tensor, but got {len(x.shape)}-dimensional tensor" new_output_size = _generate_output_size(x.shape, self.output_size) return flow._C.adaptive_avg_pool3d(x, output_size=new_output_size) def adaptive_avg_pool3d(input, output_size): """Applies a 3D adaptive average pooling over an input signal composed of several input planes. See :mod:`oneflow.nn.AdaptiveAvgPool3d` Args: input: input tensor output_size: the target output size (single integer or triple-integer tuple) """ return AdaptiveAvgPool3d(output_size)(input) class _AdaptiveMaxPoolNd(Module): def __init__(self, output_size, return_indices: bool = False) -> None: super(_AdaptiveMaxPoolNd, self).__init__() self.output_size = output_size self.return_indices = return_indices def extra_repr(self) -> str: return "output_size={}".format(self.output_size) class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.AdaptiveMaxPool1d.html. The output size is :math:`L_{out}`, for any input size. The number of output features is equal to the number of input planes. Args: output_size: the target output size :math:`L_{out}`. return_indices: if ``True``, will return the indices along with the outputs. Default: ``False`` Shape: - Input: :math:`(N, C, L_{in})`. - Output: :math:`(N, C, L_{out})`, where :math:`L_{out}=\text{output_size}`. Examples: .. code-block:: python >>> import oneflow as flow >>> # target output size of 5 >>> m = flow.nn.AdaptiveMaxPool1d(5) >>> input = flow.randn(1, 64, 8) >>> output = m(input) >>> print(output.shape) oneflow.Size([1, 64, 5]) """ def forward(self, input): self.output_size = _single(self.output_size) assert ( len(input.shape) == 3 and len(self.output_size) == 1 ), "the length of 'output_size' does not match the input size, 1 expected" new_output_size = _generate_output_size(input.shape, self.output_size) return flow.nn.functional.adaptive_max_pool1d( input, self.output_size, self.return_indices ) class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.AdaptiveMaxPool2d.html. The output is of size :math:`H_{out} \times W_{out}`, for any input size. The number of output features is equal to the number of input planes. Args: output_size: the target output size of the image of the form :math:`H_{out} \times W_{out}`. Can be a tuple :math:`(H_{out}, W_{out})` or a single :math:`H_{out}` for a square image :math:`H_{out} \times H_{out}`. 
:math:`H_{out}` and :math:`W_{out}` should be a ``int``. return_indices: if ``True``, will return the indices along with the outputs. Default: ``False`` Shape: - Input: :math:`(N, C, H_{in}, W_{in})`. - Output: :math:`(N, C, H_{out}, W_{out})`, where :math:`(H_{out}, W_{out})=\text{output_size}`. Examples: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> # target output size of 5x7 >>> m = nn.AdaptiveMaxPool2d((5,7)) >>> input = flow.randn(1, 64, 8, 9) >>> output = m(input) >>> print(output.shape) oneflow.Size([1, 64, 5, 7]) >>> # target output size of 7x7 (square) >>> m = nn.AdaptiveMaxPool2d(7) >>> input = flow.randn(1, 64, 10, 9) >>> output = m(input) >>> print(output.shape) oneflow.Size([1, 64, 7, 7]) """ def __init__(self, output_size, return_indices=False, data_format=None) -> None: super().__init__(output_size, return_indices=return_indices) if data_format: if not data_format in ["channels_first", "channels_last"]: raise ValueError( f"data_format must be one of ['channels_first', 'channels_last'], but got {data_format}" ) self.channel_pos = data_format elif os.getenv("ONEFLOW_ENABLE_NHWC") == "1": self.channel_pos = "channels_last" else: self.channel_pos = "channels_first" def to_memory_format(self, memory_format) -> None: if memory_format is flow.channels_last: self.channel_pos = "channels_last" elif memory_format is flow.channels_first: self.channel_pos = "channels_first" def forward(self, input): self.output_size = _pair(self.output_size) assert ( len(input.shape) == 4 ), f"expected 4-dimensional tensor, but got {len(input.shape)}-dimensional tensor" new_output_size = _generate_output_size(input.shape, self.output_size) return flow.nn.functional.adaptive_max_pool2d( input, self.output_size, self.return_indices, self.channel_pos ) class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.AdaptiveMaxPool3d.html. The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size. The number of output features is equal to the number of input planes. Args: output_size: the target output size of the image of the form :math:`D_{out} \times H_{out} \times W_{out}`. Can be a tuple :math:`(D_{out}, H_{out}, W_{out})` or a single :math:`D_{out}` for a cube :math:`D_{out} \times D_{out} \times D_{out}`. :math:`D_{out}`, :math:`H_{out}` and :math:`W_{out}` should be a ``int``. return_indices: if ``True``, will return the indices along with the outputs. Default: ``False`` Shape: - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`. - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where :math:`(D_{out}, H_{out}, W_{out})=\text{output_size}`. Examples: .. 
code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> # target output size of 5x7x9 >>> m = nn.AdaptiveMaxPool3d((5,7,9)) >>> input = flow.randn(1, 64, 8, 9, 10) >>> output = m(input) >>> print(output.shape) oneflow.Size([1, 64, 5, 7, 9]) >>> # target output size of 7x7x7 (cube) >>> m = nn.AdaptiveMaxPool3d(7) >>> input = flow.randn(1, 64, 10, 9, 8) >>> output = m(input) >>> print(output.shape) oneflow.Size([1, 64, 7, 7, 7]) """ def forward(self, input): self.output_size = _triple(self.output_size) assert ( len(input.shape) == 5 ), f"expected 5-dimensional tensor, but got {len(input.shape)}-dimensional tensor" new_output_size = _generate_output_size(input.shape, self.output_size) return flow.nn.functional.adaptive_max_pool3d( input, self.output_size, self.return_indices ) class MaxUnpool1d(Module): r"""Computes a partial inverse of :class:`MaxPool1d`. :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost. :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d` including the indices of the maximal values and computes a partial inverse in which all non-maximal values are set to zero. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxUnpool1d.html. .. note:: :class:`MaxPool1d` can map several input sizes to the same output sizes. Hence, the inversion process can get ambiguous. To accommodate this, you can provide the needed output size as an additional argument :attr:`output_size` in the forward call. See the Inputs and Example below. Args: kernel_size (int or tuple): Size of the max pooling window. stride (int or tuple): Stride of the max pooling window. It is set to :attr:`kernel_size` by default. padding (int or tuple): Padding that was added to the input Inputs: - `input`: the input Tensor to invert - `indices`: the indices given out by :class:`~oneflow.nn.MaxPool1d` - `output_size` (optional): the targeted output size Shape: - Input: :math:`(N, C, H_{in})`. - Output: :math:`(N, C, H_{out})`, where .. math:: H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0] or as given by :attr:`output_size` in the call operator For example: .. code-block:: python >>> import oneflow as flow >>> pool = flow.nn.MaxPool1d(2, stride=2, return_indices=True) >>> unpool = flow.nn.MaxUnpool1d(2, stride=2) >>> input = flow.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]]) >>> output, indices = pool(input) >>> unpool(output, indices) tensor([[[0., 2., 0., 4., 0., 6., 0., 8.]]], dtype=oneflow.float32) >>> # Example showcasing the use of output_size >>> input = flow.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]]) >>> output, indices = pool(input) >>> unpool(output, indices, output_size=input.size()) tensor([[[0., 2., 0., 4., 0., 6., 0., 8., 0.]]], dtype=oneflow.float32) >>> unpool(output, indices) tensor([[[0., 2., 0., 4., 0., 6., 0., 8.]]], dtype=oneflow.float32) .. note:: When `indices` contains elements out of the `output_size` range, an RuntimeError will be raised on the cpu and an indeterminate result will be calculated on the cuda. 
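# --- Illustrative aside (not part of the module source). Checking the default
# MaxUnpool1d output length against the formula quoted above:
#     H_out = (H_in - 1) * stride - 2 * padding + kernel_size
# `maxunpool1d_out_len` is a hypothetical helper used only for this shape check.
def maxunpool1d_out_len(h_in, kernel_size, stride=None, padding=0):
    stride = kernel_size if stride is None else stride
    return (h_in - 1) * stride - 2 * padding + kernel_size

# In the docstring example above, MaxPool1d(2, stride=2) shrinks a length-8 input to
# length 4; with no explicit output_size, MaxUnpool1d(2, stride=2) restores length 8:
assert maxunpool1d_out_len(4, kernel_size=2, stride=2) == 8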
""" def __init__( self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: Optional[_size_1_t] = 0, ): super().__init__() self.kernel_size = kernel_size self.stride = stride self.padding = padding def forward(self, x, indices, output_size=None): return flow._C.max_unpool1d( x, indices, self.kernel_size, self.stride, self.padding, output_size ) class MaxUnpool2d(Module): r"""Computes a partial inverse of :class:`MaxPool2d`. :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost. :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d` including the indices of the maximal values and computes a partial inverse in which all non-maximal values are set to zero. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxUnpool2d.html. .. note:: :class:`MaxPool2d` can map several input sizes to the same output sizes. Hence, the inversion process can get ambiguous. To accommodate this, you can provide the needed output size as an additional argument :attr:`output_size` in the forward call. See the Inputs and Example below. Args: kernel_size (int or tuple): Size of the max pooling window. stride (int or tuple): Stride of the max pooling window. It is set to :attr:`kernel_size` by default. padding (int or tuple): Padding that was added to the input Inputs: - `input`: the input Tensor to invert - `indices`: the indices given out by :class:`~oneflow.nn.MaxPool2d` - `output_size` (optional): the targeted output size Shape: - Input: :math:`(N, C, H_{in}, W_{in})` . - Output: :math:`(N, C, H_{out}, W_{out})`, where .. math:: H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} .. math:: W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} or as given by :attr:`output_size` in the call operator For example: .. code-block:: python >>> import oneflow as flow >>> pool = flow.nn.MaxPool2d(2, stride=2, return_indices=True) >>> unpool = flow.nn.MaxUnpool2d(2, stride=2) >>> input = flow.tensor([[[[ 1., 2, 3, 4], ... [ 5, 6, 7, 8], ... [ 9, 10, 11, 12], ... [13, 14, 15, 16]]]]) >>> output, indices = pool(input) >>> unpool(output, indices) # doctest: +SKIP tensor([[[[ 0., 0., 0., 0.], [ 0., 6., 0., 8.], [ 0., 0., 0., 0.], [ 0., 14., 0., 16.]]]], dtype=oneflow.float32) >>> # specify a different output size than input size >>> unpool(output, indices, output_size=flow.Size([1, 1, 5, 5])) # doctest: +SKIP tensor([[[[ 0., 0., 0., 0., 0.], [ 6., 0., 8., 0., 0.], [ 0., 0., 0., 14., 0.], [16., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0.]]]], dtype=oneflow.float32) .. note:: When `indices` contains elements out of the `output_size` range, an RuntimeError will be raised on the cpu and an indeterminate result will be calculated on the cuda. """ def __init__( self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: Optional[_size_2_t] = 0, ): super().__init__() self.kernel_size = kernel_size self.stride = stride self.padding = padding def forward(self, x, indices, output_size=None): return flow._C.max_unpool2d( x, indices, self.kernel_size, self.stride, self.padding, output_size ) class MaxUnpool3d(Module): r"""Computes a partial inverse of :class:`MaxPool3d`. :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost. 
:class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d` including the indices of the maximal values and computes a partial inverse in which all non-maximal values are set to zero. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxUnpool3d.html. .. note:: :class:`MaxPool3d` can map several input sizes to the same output sizes. Hence, the inversion process can get ambiguous. To accommodate this, you can provide the needed output size as an additional argument :attr:`output_size` in the forward call. See the Inputs section below. Args: kernel_size (int or tuple): Size of the max pooling window. stride (int or tuple): Stride of the max pooling window. It is set to :attr:`kernel_size` by default. padding (int or tuple): Padding that was added to the input Inputs: - `input`: the input Tensor to invert - `indices`: the indices given out by :class:`~oneflow.nn.MaxPool3d` - `output_size` (optional): the targeted output size Shape: - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`. - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where .. math:: D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} .. math:: H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} .. math:: W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} or as given by :attr:`output_size` in the call operator For example: .. code-block:: python >>> import oneflow as flow >>> # pool of square window of size=3, stride=2 >>> pool = flow.nn.MaxPool3d(3, stride=2, return_indices=True) >>> unpool = flow.nn.MaxUnpool3d(3, stride=2) >>> output, indices = pool(flow.randn(20, 16, 51, 33, 15)) >>> unpooled_output = unpool(output, indices) >>> unpooled_output.size() oneflow.Size([20, 16, 51, 33, 15]) .. note:: When `indices` contains elements out of the `output_size` range, a RuntimeError will be raised on the CPU and an indeterminate result will be computed on CUDA. """ def __init__( self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: Optional[_size_3_t] = 0, ): super().__init__() self.kernel_size = kernel_size self.stride = stride self.padding = padding def forward(self, x, indices, output_size=None): return flow._C.max_unpool3d( x, indices, self.kernel_size, self.stride, self.padding, output_size ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/quantization.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.nn.modules.module import Module class Quantization(Module): """ Simulates the quantize operation at inference time. The output will be computed as: if quantization_scheme == "symmetric": ..
math:: & quant\\_max = 2^{quantization\\_to\\_bit - 1} - 1 & quant\\_min = -quant\\_max & clamp(round(x / scale), quant\\_min, quant\\_max) elif quantization_scheme == "affine": .. math:: & quant\\_max = 2^{quantization\\_to\\_bit} - 1 & quant\\_min = 0 & (clamp(round(x / scale + zero\\_point), quant\\_min, quant\\_max) - zero\\_point) Args: quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8. quantization_scheme (str): "symmetric" or "affine", quantize to signed / unsigned integer. Defaults to "symmetric". quantization_formula (str): Support "google" or "cambricon". Returns: oneflow.Tensor: Input tensor after quantize operation. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32) >>> input_tensor = flow.tensor( ... weight, dtype=flow.float32 ... ) >>> quantization_bit = 8 >>> quantization_scheme = "symmetric" >>> quantization_formula = "google" >>> per_layer_quantization = True >>> min_max_observer = flow.nn.MinMaxObserver(quantization_formula=quantization_formula, quantization_bit=quantization_bit, ... quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization) >>> quantization = flow.nn.Quantization(quantization_formula=quantization_formula, quantization_bit=quantization_bit, ... quantization_scheme=quantization_scheme) >>> scale, zero_point = min_max_observer( ... input_tensor, ... ) >>> output_tensor = quantization( ... input_tensor, ... scale, ... zero_point, ... ) """ def __init__( self, quantization_formula: str = "google", quantization_bit: int = 8, quantization_scheme: str = "symmetric", ) -> None: super().__init__() self.quantization_formula = quantization_formula self.quantization_bit = quantization_bit self.quantization_scheme = quantization_scheme def forward(self, input, scale, zero_point): return flow._C.quantization( input, scale, zero_point, self.quantization_formula, self.quantization_bit, self.quantization_scheme, ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/reshape.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Sequence import oneflow as flow from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module def _input_args_is_int(args): return all((isinstance(x, int) for x in args)) def _input_args_is_flow_size(args): return all((isinstance(x, flow.Size) for x in args)) and len(args) == 1 def reshape_op(input, shape: Sequence[int] = None): """This operator reshapes a Tensor. We can set one dimension in `shape` as `-1`, the operator will infer the complete shape. Args: x: A Tensor. shape: Shape of the output tensor. Returns: A Tensor has the same type as `x`. For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.array( ... [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ... ).astype(np.float32) >>> input = flow.Tensor(x) >>> y = flow.reshape(input, shape=[2, 2, 2, -1]).shape >>> y oneflow.Size([2, 2, 2, 2]) """ return flow._C.reshape(input, shape) def view_op(input, *shape): if len(shape) == 1: new_shape = shape[0] if isinstance(new_shape, int): new_shape = (new_shape,) else: new_shape = shape return flow._C.view(input, new_shape) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/rnn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import warnings import numbers from typing import List, Tuple, Optional import oneflow as flow from oneflow import nn from oneflow.framework.tensor import Tensor from oneflow.nn.utils.rnn import PackedSequence # NOTE(Liang Depeng): The implementation of rnn modules are modified from # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: return tensor.index_select(dim, permutation) class RNNBase(nn.Module): def __init__( self, mode: str, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0.0, bidirectional: bool = False, proj_size: int = 0, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.mode = mode self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.bias = bias self.batch_first = batch_first self.dropout = float(dropout) self.bidirectional = bidirectional self.proj_size = proj_size num_directions = 2 if bidirectional else 1 if ( not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or isinstance(dropout, bool) ): raise ValueError( "dropout should be a number in range [0, 1] " "representing the probability of an element being " "zeroed" ) if dropout > 0 and num_layers == 1: warnings.warn( "dropout option adds dropout after all but last " "recurrent layer, so non-zero dropout expects " "num_layers greater than 1, but got dropout={} and " "num_layers={}".format(dropout, num_layers) ) if proj_size < 0: raise ValueError( "proj_size should be a positive integer or zero to disable projections" ) if proj_size >= hidden_size: raise ValueError("proj_size has to be smaller than hidden_size") if mode == "LSTM": gate_size = 4 * hidden_size elif mode == "GRU": gate_size = 3 * hidden_size elif mode == "RNN_TANH": gate_size = hidden_size elif mode == "RNN_RELU": gate_size = hidden_size else: raise ValueError("Unrecognized RNN mode: " + mode) self._flat_weights_names = [] self._all_weights = [] for layer in range(num_layers): for direction in range(num_directions): real_hidden_size = proj_size if proj_size > 0 else 
hidden_size layer_input_size = ( input_size if layer == 0 else real_hidden_size * num_directions ) w_ih = nn.Parameter( flow.empty((gate_size, layer_input_size), **factory_kwargs) ) w_hh = nn.Parameter( flow.empty((gate_size, real_hidden_size), **factory_kwargs) ) b_ih = nn.Parameter(flow.empty(gate_size, **factory_kwargs)) b_hh = nn.Parameter(flow.empty(gate_size, **factory_kwargs)) layer_params: Tuple[Tensor, ...] = () if self.proj_size == 0: if bias: layer_params = (w_ih, w_hh, b_ih, b_hh) else: layer_params = (w_ih, w_hh) else: w_hr = nn.Parameter( flow.empty((proj_size, hidden_size), **factory_kwargs) ) if bias: layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) else: layer_params = (w_ih, w_hh, w_hr) suffix = "_reverse" if direction == 1 else "" param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] if bias: param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] if self.proj_size > 0: param_names += ["weight_hr_l{}{}"] param_names = [x.format(layer, suffix) for x in param_names] for name, param in zip(param_names, layer_params): setattr(self, name, param) self._flat_weights_names.extend(param_names) self._all_weights.append(param_names) self._flat_weights = [ (lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names ] self.reset_parameters() def __setattr__(self, attr, value): if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names: # keep self._flat_weights up to date if you do self.weight = ... idx = self._flat_weights_names.index(attr) self._flat_weights[idx] = value super().__setattr__(attr, value) def to_global(self, placement=None, sbp=None): def convert(t): return t.to_global(placement=placement, sbp=sbp) self = self._apply(convert) self._flat_weights = [ (lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names ] return self def reset_parameters(self) -> None: stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 for weight in self.parameters(): nn.init.uniform_(weight, -stdv, stdv) def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: expected_input_dim = 2 if batch_sizes is not None else 3 if input.dim() != expected_input_dim: raise RuntimeError( "input must have {} dimensions, got {}".format( expected_input_dim, input.dim() ) ) if self.input_size != input.size(-1): raise RuntimeError( "input.size(-1) must be equal to input_size. 
Expected {}, got {}".format( self.input_size, input.size(-1) ) ) def get_expected_hidden_size( self, input: Tensor, batch_sizes: Optional[Tensor] ) -> Tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) else: mini_batch = input.size(0) if self.batch_first else input.size(1) num_directions = 2 if self.bidirectional else 1 if self.proj_size > 0: expected_hidden_size = ( self.num_layers * num_directions, mini_batch, self.proj_size, ) else: expected_hidden_size = ( self.num_layers * num_directions, mini_batch, self.hidden_size, ) return expected_hidden_size def check_hidden_size( self, hx: Tensor, expected_hidden_size: Tuple[int, int, int], msg: str = "Expected hidden size {}, got {}", ) -> None: if hx.size() != expected_hidden_size: raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) def check_forward_args( self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] ): self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) self.check_hidden_size(hidden, expected_hidden_size) def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): if permutation is None: return hx return apply_permutation(hx, permutation) def extra_repr(self) -> str: s = "{input_size}, {hidden_size}" if self.proj_size != 0: s += ", proj_size={proj_size}" if self.num_layers != 1: s += ", num_layers={num_layers}" if self.bias is not True: s += ", bias={bias}" if self.batch_first is not False: s += ", batch_first={batch_first}" if self.dropout != 0: s += ", dropout={dropout}" if self.bidirectional is not False: s += ", bidirectional={bidirectional}" return s.format(**self.__dict__) @property def all_weights(self) -> List[List[nn.Parameter]]: return [ [getattr(self, weight) for weight in weights] for weights in self._all_weights ] class RNN(RNNBase): r""" Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an input sequence. For each element in the input sequence, each layer computes the following function: .. math:: h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh}) where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the previous layer at time `t-1` or the initial hidden state at time `0`. If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.RNN.html. Args: input_size: The number of expected features in the input `x` hidden_size: The number of features in the hidden state `h` num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` would mean stacking two RNNs together to form a `stacked RNN`, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1 nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided as `(batch, seq, feature)` instead of `(seq, batch, feature)`. Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details.
Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each RNN layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False`` Inputs: input, h_0 * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of the input sequence. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden state for each element in the batch. Defaults to zeros if not provided. where: .. math:: \begin{aligned} N ={} & \text{batch size} \\ L ={} & \text{sequence length} \\ D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ H_{in} ={} & \text{input_size} \\ H_{out} ={} & \text{hidden_size} \end{aligned} Outputs: output, h_n * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features `(h_t)` from the last layer of the RNN, for each `t`. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state for each element in the batch. Attributes: weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is `(hidden_size, num_directions * hidden_size)` weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, of shape `(hidden_size, hidden_size)` bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, of shape `(hidden_size)` bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, of shape `(hidden_size)` .. note:: All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` .. note:: For bidirectional RNNs, forward and backward are directions 0 and 1 respectively. Example of splitting the output layers when ``batch_first=False``: ``output.view((seq_len, batch, num_directions, hidden_size))``. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> rnn = flow.nn.RNN(10, 20, 2) >>> input = flow.tensor(np.random.randn(5, 3, 10), dtype=flow.float32) >>> h0 = flow.tensor(np.random.randn(2, 3, 20), dtype=flow.float32) >>> output, hn = rnn(input, h0) >>> output.size() oneflow.Size([5, 3, 20]) """ def __init__(self, *args, **kwargs): if "proj_size" in kwargs: raise ValueError( "proj_size argument is only supported for LSTM, not RNN or GRU" ) self.nonlinearity = kwargs.pop("nonlinearity", "tanh") if self.nonlinearity == "tanh": mode = "RNN_TANH" elif self.nonlinearity == "relu": mode = "RNN_RELU" else: raise ValueError("Unknown nonlinearity '{}'".format(self.nonlinearity)) super().__init__(mode, *args, **kwargs) def forward(self, input, hx=None): # noqa: F811 orig_input = input if isinstance(orig_input, PackedSequence): input = orig_input.data batch_sizes = orig_input.batch_sizes sorted_indices = orig_input.sorted_indices unsorted_indices = orig_input.unsorted_indices max_batch_size = int(batch_sizes[0]) else: batch_sizes = None is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: input = input.unsqueeze(batch_dim) if hx is not None: if hx.dim() != 2: raise RuntimeError( f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" ) hx = hx.unsqueeze(1) else: if hx is not None and hx.dim() != 3: raise RuntimeError( f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" ) max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None if hx is None: num_directions = 2 if self.bidirectional else 1 if input.is_global: hx = flow.zeros( self.num_layers * num_directions, max_batch_size, self.hidden_size, dtype=input.dtype, sbp=input.sbp, placement=input.placement, ) else: hx = flow.zeros( self.num_layers * num_directions, max_batch_size, self.hidden_size, dtype=input.dtype, device=input.device, ) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. 
hx = self.permute_hidden(hx, sorted_indices) self._flat_weights = [ (lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names ] assert hx is not None self.check_forward_args(input, hx, batch_sizes) assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU" if batch_sizes is None: if self.mode == "RNN_TANH": result = flow._C.rnn_tanh( input, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, self.batch_first, ) else: result = flow._C.rnn_relu( input, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, self.batch_first, ) else: if self.mode == "RNN_TANH": result = flow._C.rnn_tanh( input, batch_sizes, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, ) else: result = flow._C.rnn_relu( input, batch_sizes, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, ) output = result[0] hidden = result[1] if isinstance(orig_input, PackedSequence): output_packed = PackedSequence( output, batch_sizes, sorted_indices, unsorted_indices ) return output_packed, self.permute_hidden(hidden, unsorted_indices) if not is_batched: output = output.squeeze(batch_dim) hidden = hidden.squeeze(1) return output, self.permute_hidden(hidden, unsorted_indices) class LSTM(RNNBase): r""" Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. For each element in the input sequence, each layer computes the following function: .. math:: \begin{array}{ll} \\ i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \\ \end{array} where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}` is the hidden state of the layer at time `t-1` or the initial hidden state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and output gates, respectively. :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random variable which is :math:`0` with probability :attr:`dropout`. If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly). Second, the output hidden state of each layer will be multiplied by a learnable projection matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/modules/rnn.html#LSTM. 
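# --- Illustrative aside (not part of the module source). A minimal sketch of the shape
# change introduced by proj_size > 0, assuming the projected-LSTM behaviour described
# above: outputs and hidden states use proj_size, while cell states keep hidden_size.
import oneflow as flow

rnn = flow.nn.LSTM(input_size=10, hidden_size=20, num_layers=2, proj_size=5)
x = flow.randn(7, 3, 10)  # (seq_len, batch, input_size), batch_first=False
output, (h_n, c_n) = rnn(x)
print(output.shape)  # expected: oneflow.Size([7, 3, 5])   -- last dim is D * proj_size
print(h_n.shape)     # expected: oneflow.Size([2, 3, 5])   -- proj_size
print(c_n.shape)     # expected: oneflow.Size([2, 3, 20])  -- hidden_size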
Args: input_size: The number of expected features in the input `x` hidden_size: The number of features in the hidden state `h` num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` would mean stacking two LSTMs together to form a `stacked LSTM`, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1 bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided as `(batch, seq, feature)` instead of `(seq, batch, feature)`. Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False`` proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 Inputs: input, (h_0, c_0) * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of the input sequence. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden state for each element in the batch. Defaults to zeros if (h_0, c_0) is not provided. * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the initial cell state for each element in the batch. Defaults to zeros if (h_0, c_0) is not provided. where: .. math:: \begin{aligned} N ={} & \text{batch size} \\ L ={} & \text{sequence length} \\ D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ H_{in} ={} & \text{input\_size} \\ H_{cell} ={} & \text{hidden\_size} \\ H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\ \end{aligned} Outputs: output, (h_n, c_n) * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features `(h_t)` from the last layer of the LSTM, for each `t`. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state for each element in the batch. * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the final cell state for each element in the batch. Attributes: weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`. Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)` weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0`` was specified, the shape will be `(4*hidden_size, proj_size)`. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)` bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)` weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was specified. .. 
note:: All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` .. note:: For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. Example of splitting the output layers when ``batch_first=False``: ``output.view(seq_len, batch, num_directions, hidden_size)``. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> rnn = flow.nn.LSTM(10, 20, 2) >>> input = flow.tensor(np.random.randn(5, 3, 10), dtype=flow.float32) >>> h0 = flow.tensor(np.random.randn(2, 3, 20), dtype=flow.float32) >>> c0 = flow.tensor(np.random.randn(2, 3, 20), dtype=flow.float32) >>> output, (hn, cn) = rnn(input, (h0, c0)) >>> output.size() oneflow.Size([5, 3, 20]) """ def __init__(self, *args, **kwargs): super().__init__("LSTM", *args, **kwargs) def get_expected_cell_size( self, input: Tensor, batch_sizes: Optional[Tensor] ) -> Tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) else: mini_batch = input.size(0) if self.batch_first else input.size(1) num_directions = 2 if self.bidirectional else 1 expected_hidden_size = ( self.num_layers * num_directions, mini_batch, self.hidden_size, ) return expected_hidden_size def check_forward_args( self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor], ): self.check_input(input, batch_sizes) self.check_hidden_size( hidden[0], self.get_expected_hidden_size(input, batch_sizes), "Expected hidden[0] size {}, got {}", ) self.check_hidden_size( hidden[1], self.get_expected_cell_size(input, batch_sizes), "Expected hidden[1] size {}, got {}", ) def permute_hidden( self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor] ) -> Tuple[Tensor, Tensor]: if permutation is None: return hx return ( apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation), ) def forward(self, input, hx=None): orig_input = input batch_sizes = None if isinstance(orig_input, PackedSequence): input = orig_input.data batch_sizes = orig_input.batch_sizes sorted_indices = orig_input.sorted_indices unsorted_indices = orig_input.unsorted_indices max_batch_size = int(batch_sizes[0]) else: batch_sizes = None is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: input = input.unsqueeze(batch_dim) max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None if hx is None: num_directions = 2 if self.bidirectional else 1 real_hidden_size = ( self.proj_size if self.proj_size > 0 else self.hidden_size ) if input.is_global: h_zeros = flow.zeros( self.num_layers * num_directions, max_batch_size, real_hidden_size, dtype=input.dtype, sbp=input.sbp, placement=input.placement, ) c_zeros = flow.zeros( self.num_layers * num_directions, max_batch_size, self.hidden_size, dtype=input.dtype, sbp=input.sbp, placement=input.placement, ) else: h_zeros = flow.zeros( self.num_layers * num_directions, max_batch_size, real_hidden_size, dtype=input.dtype, device=input.device, ) c_zeros = flow.zeros( self.num_layers * num_directions, max_batch_size, self.hidden_size, dtype=input.dtype, device=input.device, ) hx = (h_zeros, c_zeros) else: if batch_sizes is None: # If not PackedSequence input. 
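# Note: the validation below runs when the caller supplied (h_0, c_0) for a
# non-packed input. Batched 3-D inputs require 3-D hidden and cell states;
# unbatched 2-D inputs require 2-D states, which are then unsqueezed to a
# batch dimension of size 1 before dispatch.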
if is_batched: if hx[0].dim() != 3 or hx[1].dim() != 3: msg = ( "For batched 3-D input, hx and cx should " f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" ) raise RuntimeError(msg) else: if hx[0].dim() != 2 or hx[1].dim() != 2: msg = ( "For unbatched 2-D input, hx and cx should " f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" ) raise RuntimeError(msg) hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. hx = self.permute_hidden(hx, sorted_indices) self.check_forward_args(input, hx, batch_sizes) self._flat_weights = [ (lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names ] if batch_sizes is None: result = flow._C.lstm( input, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, self.batch_first, ) else: result = flow._C.lstm( input, batch_sizes, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, ) output = result[0] hidden = result[1:] if isinstance(orig_input, PackedSequence): output_packed = PackedSequence( output, batch_sizes, sorted_indices, unsorted_indices ) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: if not is_batched: output = output.squeeze(batch_dim) hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) return output, self.permute_hidden(hidden, unsorted_indices) class GRU(RNNBase): r""" Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. For each element in the input sequence, each layer computes the following function: .. math:: \begin{array}{ll} r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ n_t = \\tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \end{array} where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random variable which is :math:`0` with probability :attr:`dropout`. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/modules/rnn.html#GRU. Args: num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` would mean stacking two GRUs together to form a `stacked GRU`, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1 bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided as `(batch, seq, feature)` instead of `(seq, batch, feature)`. Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. 
Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each GRU layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` Inputs: input, h_0 * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of the input sequence. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden state for each element in the batch. Defaults to zeros if not provided. where: .. math:: \begin{aligned} N ={} & \text{batch size} \\ L ={} & \text{sequence length} \\ D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ H_{in} ={} & \text{input\_size} \\ H_{out} ={} & \text{hidden\_size} \end{aligned} Outputs: output, h_n * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features `(h_t)` from the last layer of the GRU, for each `t`. If a * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state for each element in the batch. Attributes: weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer (b_ir|b_iz|b_in), of shape `(3*hidden_size)` bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` .. note:: All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` .. note:: For bidirectional GRUs, forward and backward are directions 0 and 1 respectively. Example of splitting the output layers when ``batch_first=False``: ``output.view(seq_len, batch, num_directions, hidden_size)``. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> rnn = flow.nn.GRU(10, 20, 2) >>> input = flow.tensor(np.random.randn(5, 3, 10), dtype=flow.float32) >>> h0 = flow.tensor(np.random.randn(2, 3, 20), dtype=flow.float32) >>> output, hn = rnn(input, h0) >>> output.size() oneflow.Size([5, 3, 20]) """ def __init__(self, *args, **kwargs): if "proj_size" in kwargs: raise ValueError( "proj_size argument is only supported for LSTM, not RNN or GRU" ) super().__init__("GRU", *args, **kwargs) def forward(self, input, hx=None): orig_input = input if isinstance(orig_input, PackedSequence): input = orig_input.data batch_sizes = orig_input.batch_sizes sorted_indices = orig_input.sorted_indices unsorted_indices = orig_input.unsorted_indices max_batch_size = int(batch_sizes[0]) else: batch_sizes = None is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: input = input.unsqueeze(batch_dim) if hx is not None: if hx.dim() != 2: raise RuntimeError( f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" ) hx = hx.unsqueeze(1) else: if hx is not None and hx.dim() != 3: raise RuntimeError( f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" ) max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None if hx is None: num_directions = 2 if self.bidirectional else 1 if input.is_global: hx = flow.zeros( self.num_layers * num_directions, max_batch_size, self.hidden_size, dtype=input.dtype, sbp=input.sbp, placement=input.placement, ) else: hx = flow.zeros( self.num_layers * num_directions, max_batch_size, self.hidden_size, dtype=input.dtype, device=input.device, ) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. 
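# Note: sorted_indices / unsorted_indices are only set for a PackedSequence
# (typically one packed with enforce_sorted=False). h_0 is permuted into the
# length-sorted batch order before the kernel call and permuted back to the
# caller's order via unsorted_indices when the result is returned.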
hx = self.permute_hidden(hx, sorted_indices) self.check_forward_args(input, hx, batch_sizes) self._flat_weights = [ (lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names ] if batch_sizes is None: result = flow._C.gru( input, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, self.batch_first, ) else: result = flow._C.gru( input, batch_sizes, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, ) output = result[0] hidden = result[1] if isinstance(orig_input, PackedSequence): output_packed = PackedSequence( output, batch_sizes, sorted_indices, unsorted_indices ) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: if not is_batched: output = output.squeeze(batch_dim) hidden = hidden.squeeze(1) return output, self.permute_hidden(hidden, unsorted_indices) class RNNCellBase(nn.Module): def __init__( self, input_size: int, hidden_size: int, bias: bool, num_chunks: int, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.bias = bias self.weight_ih = nn.Parameter( flow.empty(num_chunks * hidden_size, input_size, **factory_kwargs) ) self.weight_hh = nn.Parameter( flow.empty(num_chunks * hidden_size, hidden_size, **factory_kwargs) ) if bias: self.bias_ih = nn.Parameter( flow.empty(num_chunks * hidden_size, **factory_kwargs) ) self.bias_hh = nn.Parameter( flow.empty(num_chunks * hidden_size, **factory_kwargs) ) else: self.register_parameter("bias_ih", None) self.register_parameter("bias_hh", None) self.reset_parameters() def extra_repr(self) -> str: s = "{input_size}, {hidden_size}" if "bias" in self.__dict__ and self.bias is not True: s += ", bias={bias}" if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": s += ", nonlinearity={nonlinearity}" return s.format(**self.__dict__) def reset_parameters(self) -> None: stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 for weight in self.parameters(): nn.init.uniform_(weight, -stdv, stdv) class RNNCell(RNNCellBase): r""" An Elman RNN cell with tanh or ReLU non-linearity. .. math:: h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.RNNCell.html. Args: input_size: The number of expected features in the input `x` hidden_size: The number of features in the hidden state `h` bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` Inputs: input, hidden - **input**: tensor containing input features - **hidden**: tensor containing the initial hidden state Defaults to zero if not provided. Outputs: h' - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state for each element in the batch Shape: - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where :math:`H_{in}` = `input_size`. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. 
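In practice the cell is applied once per time step, so a sequence is processed
by unrolling the cell in a Python loop. A minimal sketch (variable names here
are illustrative, not part of the API):

.. code-block:: python

    import oneflow as flow
    import oneflow.nn as nn

    cell = nn.RNNCell(10, 20)
    seq = flow.randn(6, 3, 10)      # (time_steps, batch, input_size)
    hx = flow.zeros(3, 20)          # initial hidden state
    outputs = []
    for t in range(seq.size(0)):    # reuse the same cell at every step
        hx = cell(seq[t], hx)
        outputs.append(hx)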
Attributes: weight_ih: the learnable input-hidden weights, of shape `(hidden_size, input_size)` weight_hh: the learnable hidden-hidden weights, of shape `(hidden_size, hidden_size)` bias_ih: the learnable input-hidden bias, of shape `(hidden_size)` bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)` .. note:: All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` For example: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> rnn = nn.RNNCell(10, 20) >>> input = flow.randn(6, 3, 10) >>> hx = flow.randn(3, 20) >>> hx = rnn(input[0], hx) >>> hx.size() oneflow.Size([3, 20]) """ def __init__( self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh", device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super(RNNCell, self).__init__( input_size, hidden_size, bias, num_chunks=1, **factory_kwargs ) self.nonlinearity = nonlinearity def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: assert input.dim() in ( 1, 2, ), f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: if input.is_global(): hx = flow.zeros( input.size(0), self.hidden_size, dtype=input.dtype, sbp=input.sbp, placement=input.placement, ) else: hx = flow.zeros( input.size(0), self.hidden_size, dtype=input.dtype, device=input.device, ) else: hx = hx.unsqueeze(0) if not is_batched else hx if self.nonlinearity == "tanh": ret = flow._C.rnn_tanh_cell( input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh, ) elif self.nonlinearity == "relu": ret = flow._C.rnn_relu_cell( input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh, ) else: raise RuntimeError("Unknown nonlinearity: {}".format(self.nonlinearity)) if not is_batched: ret = ret.squeeze(0) return ret class LSTMCell(RNNCellBase): r""" A long short-term memory (LSTM) cell. .. math:: \begin{array}{ll} i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array} where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.LSTMCell.html. Args: input_size: The number of expected features in the input `x` hidden_size: The number of features in the hidden state `h` bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` Inputs: input, (h_0, c_0) - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. 
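As with the other cells, an LSTMCell is typically unrolled over time in a Python
loop, carrying both states forward. A minimal sketch (names are illustrative):

.. code-block:: python

    import oneflow as flow
    import oneflow.nn as nn

    cell = nn.LSTMCell(10, 20)
    seq = flow.randn(2, 3, 10)      # (time_steps, batch, input_size)
    hx = flow.zeros(3, 20)
    cx = flow.zeros(3, 20)
    for t in range(seq.size(0)):
        hx, cx = cell(seq[t], (hx, cx))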
Outputs: (h_1, c_1) - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state Attributes: weight_ih: the learnable input-hidden weights, of shape `(4*hidden_size, input_size)` weight_hh: the learnable hidden-hidden weights, of shape `(4*hidden_size, hidden_size)` bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)` bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)` .. note:: All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` For example: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size) >>> input = flow.randn(2, 3, 10) # (time_steps, batch, input_size) >>> hx = flow.randn(3, 20) # (batch, hidden_size) >>> cx = flow.randn(3, 20) >>> hx, cx = rnn(input[0], (hx, cx)) >>> hx.size() oneflow.Size([3, 20]) """ def __init__( self, input_size: int, hidden_size: int, bias: bool = True, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(LSTMCell, self).__init__( input_size, hidden_size, bias, num_chunks=4, **factory_kwargs ) def forward( self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None ) -> Tuple[Tensor, Tensor]: assert input.dim() in ( 1, 2, ), f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: if input.is_global(): zeros = flow.zeros( input.size(0), self.hidden_size, dtype=input.dtype, sbp=input.sbp, placement=input.placement, ) else: zeros = flow.zeros( input.size(0), self.hidden_size, dtype=input.dtype, device=input.device, ) hx = (zeros, zeros) else: hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx ret = flow._C.lstm_cell( input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh, ) if not is_batched: ret = (ret[0].squeeze(0), ret[1].squeeze(0)) return ret class GRUCell(RNNCellBase): r""" A gated recurrent unit (GRU) cell .. math:: \begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \end{array} where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.GRUCell.html. Args: input_size: The number of expected features in the input `x` hidden_size: The number of features in the hidden state `h` bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` Inputs: input, hidden - **input** : tensor containing input features - **hidden** : tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided. Outputs: h' - **h'** : tensor containing the next hidden state for each element in the batch Shape: - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where :math:`H_{in}` = `input_size`. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. 
- output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. Attributes: weight_ih: the learnable input-hidden weights, of shape `(3*hidden_size, input_size)` weight_hh: the learnable hidden-hidden weights, of shape `(3*hidden_size, hidden_size)` bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)` bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)` .. note:: All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` For example: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn as nn >>> rnn = nn.GRUCell(10, 20) >>> input = flow.randn(6, 3, 10) >>> hx = flow.randn(3, 20) >>> hx = rnn(input[0], hx) >>> hx.size() oneflow.Size([3, 20]) """ def __init__( self, input_size: int, hidden_size: int, bias: bool = True, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: assert input.dim() in ( 1, 2, ), f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: if input.is_global(): hx = flow.zeros( input.size(0), self.hidden_size, dtype=input.dtype, sbp=input.sbp, placement=input.placement, ) else: hx = flow.zeros( input.size(0), self.hidden_size, dtype=input.dtype, device=input.device, ) else: hx = hx.unsqueeze(0) if not is_batched else hx ret = flow._C.gru_cell( input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh, ) if not is_batched: ret = ret.squeeze(0) return ret if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/roll.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import register_tensor_op def roll_op(input, shifts, dims=None): """Roll the tensor along the given dimension(s). Elements that are shifted beyond the last position are re-introduced at the first position. If a dimension is not specified, the tensor will be flattened before rolling and then restored to the original shape. Args: input (oneflow.Tensor): the input Tensor. shifts (int or tuple of ints): The number of places by which the elements of the tensor are shifted. If shifts is a tuple, dims must be a tuple of the same size, and each dimension will be rolled by the corresponding value. dims (int or tuple of ints): Axis along which to roll. Returns: oneflow.Tensor: The result Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x = np.array([[1, 2], ... [3, 4], ... [5, 6], ... 
[7, 8]]) >>> input = flow.Tensor(x) >>> input.shape oneflow.Size([4, 2]) >>> out = flow.roll(input, 1, 0) >>> out tensor([[7., 8.], [1., 2.], [3., 4.], [5., 6.]], dtype=oneflow.float32) >>> input.roll(-1, 1) tensor([[2., 1.], [4., 3.], [6., 5.], [8., 7.]], dtype=oneflow.float32) """ return flow._C.roll(input, shifts, dims) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/scatter.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.modules.module import Module __all__ = ["scatter", "scatter_add", "scatter_nd", "tensor_scatter_nd_update"] def scatter(input, dim, index, src, *, reduce=None): r"""This operator writes the elements specified by `index` along with the axis `dim` from the `src` into the `input`. Take a 3-D blob as example, the output is specified by: .. code-block:: python input[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 input[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 input[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 input, index and src (if it is a Tensor) should all have the same number of dimensions. It is also required that index.shape(d) <= src.shape(d) for all dimensions d, and that index.shape(d) <= input.shape(d) for all dimensions d != dim. Note that index and src do not broadcast. .. warning:: When indices are not unique, the behavior is non-deterministic (one of the values from src will be picked arbitrarily) and the gradient will be incorrect (it will be propagated to all locations in the source that correspond to the same index)! .. note:: The backward pass is implemented only for ``src.shape == index.shape``. Additionally accepts an optional ``reduce`` argument that allows specification of an optional reduction operation, which is applied to all values in the tensor ``src`` into ``input`` at the indicies specified in the ``index``. For each value in ``src``, the reduction operation is applied to an index in ``input`` which is specified by its index in ``src`` for ``dimension != dim`` and by the corresponding value in ``index`` for ``dimension = dim``. Given a 3-D tensor and reduction using the multiplication operation, input is updated as: .. code-block:: python input[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 input[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 input[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 Reducing with the addition operation is the same as using :func:`oneflow.scatter_add()`. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.scatter\_.html. Args: input (Tensor): The input blob. dim (int): The axis along which to index index (Tensor): The index blob of elements to scatter. src (Tensor or float): The source blob whose elements will be scatterd and updated to output. 
reduce (str, optional): Reduction operation to apply, can be either ``add`` or ``multiply``. Returns: Tensor: The scatterd Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.ones((3,5))*2 >>> index = flow.tensor(np.array([[0,1,2],[0,1,4]], ), dtype=flow.int32) >>> src = flow.Tensor(np.array([[0,10,20,30,40],[50,60,70,80,90]])) >>> out = flow.scatter(input, 1, index, src) >>> out tensor([[ 0., 10., 20., 2., 2.], [50., 60., 2., 2., 70.], [ 2., 2., 2., 2., 2.]], dtype=oneflow.float32) """ return flow._C.scatter(input, dim, index, src, reduce=reduce) def scatter_add(input, dim, index, src): r"""This operator scatter the src with addition operation according to index along dim into the input. Take a 3-D blob as example, the output is specified by: .. code-block:: python input[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 input[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 input[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 Args: input (Tensor): The input blob. dim (int): The axis along which to index index (Tensor): The index blob of elements to scatter. src (Tensor): The source blob whose elements will be scatterd and added to output. Returns: Tensor: The scatterd Tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> input = flow.ones((3,5))*2 >>> index = flow.tensor(np.array([[0,1,2],[0,1,4]], ), dtype=flow.int32) >>> src = flow.Tensor(np.array([[0,10,20,30,40],[50,60,70,80,90]])) >>> out = flow.scatter_add(input, 1, index, src) >>> out tensor([[ 2., 12., 22., 2., 2.], [52., 62., 2., 2., 72.], [ 2., 2., 2., 2., 2.]], dtype=oneflow.float32) """ assert type(src) in [ flow.Tensor ], f"type of src must be oneflow.Tensor, but %s givien" % type(src) return flow._C.scatter_add(input, dim, index, src) def scatter_nd(index, update, shape): """This operator inserts the elements in `update` according to the `index` and create a new Tensor. Args: index: The indices of `update`. Its type should be `flow.int`. update: The update Tensor. shape (Sequence[int]): The constant tensor shape, the constant tensor elements are all zero. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> index = flow.tensor(np.array([[1], [6], [4]]), dtype=flow.int) >>> update = flow.tensor(np.array([10.2, 5.1, 12.7]), dtype=flow.float) >>> out = flow.scatter_nd(index, update, [8]) >>> out tensor([ 0.0000, 10.2000, 0.0000, 0.0000, 12.7000, 0.0000, 5.1000, 0.0000], dtype=oneflow.float32) """ return flow._C.scatternd(index, update, shape) def tensor_scatter_nd_update(tensor, indices, updates): r""" This operation creates a new tensor by applying sparse updates to the input tensor. This is similar to an index assignment. This operator is very similar to :meth:`scatter_nd`, except that the updates are scattered onto an existing tensor (as opposed to a zero-tensor). Args: tensor: The tensor will be scattered. indices: The indices of ``update``. Its type should be `flow.int`. update: The update Tensor. For example: .. 
code-block:: python >>> import oneflow as flow >>> tensor = flow.arange(8) >>> indices = flow.tensor([[1], [3], [5]]) >>> updates = flow.tensor([-1, -2, -3]) >>> flow.tensor_scatter_nd_update(tensor, indices, updates) tensor([ 0, -1, 2, -2, 4, -3, 6, 7], dtype=oneflow.int64) """ return flow._C.tensor_scatter_nd_update(tensor, indices, updates) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/slice.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Sequence, Tuple import oneflow as flow from oneflow.ops.array_ops import parse_slice_tuple_list def slice_op(input, slice_tup_list: Sequence[Tuple[int, int, int]]): """Extracts a slice from a tensor. The `slice_tup_list` assigns the slice indices in each dimension, the format is (start, stop, step). The operator will slice the tensor according to the `slice_tup_list`. Args: input: A `Tensor`. slice_tup_list: A list of slice tuple, indicate each dimension slice (start, stop, step). For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.Tensor(np.random.randn(3, 6, 9).astype(np.float32)) >>> tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]] >>> y = flow.slice(input, slice_tup_list=tup_list) >>> y.shape oneflow.Size([3, 3, 2]) """ (start, stop, step) = parse_slice_tuple_list(slice_tup_list, input.shape) return flow._C.slice(input, start, stop, step) def slice_update_op(input, update, slice_tup_list: Sequence[Tuple[int, int, int]]): """Update a slice of tensor `x`. Like `x[start:stop:step] = update`. Args: x: A `Tensor`, whose slice will be updated. update: A `Tensor`, indicate the update content. slice_tup_list: A list of slice tuple, indicate each dimension slice (start, stop, step). For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.Tensor(np.array([1, 1, 1, 1, 1]).astype(np.float32)) >>> update = flow.Tensor(np.array([2, 3, 4]).astype(np.float32)) >>> flow.slice_update(input, update, slice_tup_list=[[1, 4, 1]]) tensor([1., 2., 3., 4., 1.], dtype=oneflow.float32) """ (start, stop, step) = parse_slice_tuple_list(slice_tup_list, input.shape) if update.dtype != input.dtype: update = update.to(dtype=input.dtype) return flow._C.slice_update(input, update, start, stop, step, inplace=True) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/sparse.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os from typing import List, Optional, Tuple import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.nn.modules.module import Module class Embedding(Module): """A simple lookup table that stores embeddings of a fixed dictionary and size. This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings. Args: num_embeddings (int): size of the dictionary of embeddings embedding_dim (int): the size of each embedding vector padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated during training, i.e. it remains as a fixed "pad". For a newly constructed Embedding, the embedding vector at :attr:`padding_idx` will default to all zeros, but can be updated to another value to be used as the padding vector. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` is renormalized to have norm :attr:`max_norm` norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default :attr:`2`. scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default :attr:`False` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> indices = flow.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=flow.int) >>> m = flow.nn.Embedding(10, 3) >>> y = m(indices) .. Feature Stage of Operator [Embedding]. - Maintainer List [@EsdeathYZH] - Current Stage [ ] - Alpha Stage Check List [ ] - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes] - Doc(API Doc must be provided and showed normally on the web page.)[Yes] - Functionality and its' Test [ ] - Functionality is highly compatiable with PyTorch 1.11. 
[Yes] - eager local [Yes] [@EsdeathYZH] - forward [Yes] - backward [Yes] - gpu [Yes] - cpu [Yes] - graph local [ ] [@BBuf, @strint, @hjchen2] - forward [Yes] - backward [ ] - gpu [Yes] - cpu [Yes] - Exception Handling - Exception Message and Hint must be provided [ ] - Beta Stage Check List [ ] - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ] - Doc(Same standard as Alpha Stage)[ ] - Functionality and its' Test [ ] - eager global [ ] - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - graph gloal [ ] - forward [ ] - backward [ ] - gpu [ ] - cpu [ ] - Performance and Scalability(Must be evaluated.)[ ] - CUDA kernel [ ] - CPU kernel [ ] - N nodes M devices [ ] - Exception Handling [ ] - Exception Message and Hint must be provided [ ] - Try you best to do Exception Recovery [ ] - Stable Stage Check List [ ] - API(Same standard as Beta Stage)[ ] - Doc(Same standard as Beta Stage)[ ] - Functionality and its' Test [ ] - fp16 and AMP [ ] - NHWC [ ] - Performance and Scalability(Must be evaluated.)[ ] - Exception Handling [ ] """ def __init__( self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, _weight: Optional[Tensor] = None, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: assert ( padding_idx < self.num_embeddings ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: assert ( padding_idx >= -self.num_embeddings ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.max_norm = max_norm self.norm_type = norm_type self.scale_grad_by_freq = scale_grad_by_freq assert sparse is False, "Not support sparse=True yet!" 
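# Note: at this point a negative padding_idx has already been wrapped to the
# equivalent non-negative index (num_embeddings + padding_idx). Below, the
# weight is either freshly allocated and initialized by reset_parameters()
# or taken over from the user-provided _weight.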
if _weight is None: self.weight = flow.nn.Parameter( flow.empty((num_embeddings, embedding_dim), **factory_kwargs) ) self.reset_parameters() else: assert list(_weight.shape) == [ num_embeddings, embedding_dim, ], "Shape of weight does not match num_embeddings and embedding_dim" self.weight = flow.nn.Parameter(_weight) self.sparse = sparse def reset_parameters(self) -> None: if os.getenv("ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT", "0") == "1": return flow.nn.init.normal_(self.weight) self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with flow.no_grad(): self.weight[self.padding_idx] = 0 def extra_repr(self) -> str: s = "{num_embeddings}, {embedding_dim}" if self.padding_idx is not None: s += ", padding_idx={padding_idx}" if self.max_norm is not None: s += ", max_norm={max_norm}" if self.norm_type != 2: s += ", norm_type={norm_type}" if self.scale_grad_by_freq is not False: s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: s += ", sparse=True" return s.format(**self.__dict__) def forward(self, indices): if self.max_norm is not None: with flow.no_grad(): flow._C.embedding_renorm_( self.weight, indices, self.max_norm, self.norm_type ) if self.padding_idx is None and not self.scale_grad_by_freq: return flow._C.gather(self.weight, indices, axis=0) else: return flow._C.embedding( self.weight, indices, self.padding_idx, self.scale_grad_by_freq ) def embedding( input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, ): r"""A simple lookup table that looks up embeddings in a fixed dictionary and size. This module is often used to retrieve word embeddings using indices. The input to the module is a list of indices, and the embedding matrix, and the output is the corresponding word embeddings. See :class:`oneflow.nn.Embedding` for more details. Args: input (oneflow.LongTensor): Tensor containing indices into the embedding matrix weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, and number of columns equal to the embedding size padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated during training, i.e. it remains as a fixed "pad". max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False For example: .. code-block:: python >>> import oneflow as flow >>> import oneflow.nn.functional as F >>> # a batch of 2 samples of 4 indices each >>> input = flow.tensor([[1,2,4,5],[4,3,2,9]]) >>> # an embedding matrix containing 10 tensors of size 3 >>> embedding_matrix = flow.rand(10, 3) >>> output = F.embedding(input, embedding_matrix) >>> output.shape oneflow.Size([2, 4, 3]) >>> # example with padding_idx >>> input = flow.tensor([[0,2,0,5]]) >>> output = F.embedding(input, embedding_matrix, padding_idx=0) >>> output.shape oneflow.Size([1, 4, 3]) """ assert sparse is False, "Not support sparse=True yet!" 
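# Note: the functional form mirrors nn.Embedding: padding_idx is range-checked
# and wrapped if negative, max_norm triggers an in-place renorm of `weight`,
# and a plain gather is used when neither padding_idx nor scale_grad_by_freq
# is requested.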
if padding_idx is not None: if padding_idx > 0: assert padding_idx < weight.size( 0 ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: assert padding_idx >= -weight.size( 0 ), "Padding_idx must be within num_embeddings" padding_idx = weight.size(0) + padding_idx if max_norm is not None: with flow.no_grad(): weight = flow._C.embedding_renorm_(weight, input, max_norm, norm_type) if padding_idx is None and not scale_grad_by_freq: return flow._C.gather(weight, input, axis=0) else: return flow._C.embedding(weight, input, padding_idx, scale_grad_by_freq) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/sparse_softmax_cross_entropy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def sparse_softmax_cross_entropy(labels, logits): """The interface is consistent with TensorFlow. The documentation is referenced from: https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits Computes sparse softmax cross entropy between `logits` and `labels`. Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both. A common use case is to have logits of shape `[batch_size, num_classes]` and have labels of shape `[batch_size]`, but higher dimensions are supported, in which case the `dim`-th dimension is assumed to be of size `num_classes`. `logits` must have the dtype of `float16`, `float32`, or `float64`, and `labels` must have the dtype of `int32` or `int64`. Args: labels (Tensor): shape with [d_0, d_1, ..., d_{r-1}] (where `r` is rank of `labels` and output) and dtype `int32` or `int64`. Each entry in `labels` must be an index in [0, num_classes). logits (Tensor): Per-label activations (typically a linear output) of shape [d_0, d_1, ..., d_{r-1}, num_classes] and dtype `float16`, `float32`, or `float64`. These activation energies are interpreted as unnormalized log probabilities. Returns: output (Tensor): A `Tensor` of the same shape as `labels` and of the same type as `logits` with the softmax cross entropy loss. Examples:: >>> import numpy as np >>> import oneflow as flow >>> np_logits = np.array( ... [ ... [2.0, -5.0, 0.5, -0.1], ... [0.0, 0.0, 1.9, 1.4], ... [-100.0, 100.0, -100.0, -100.0], ... ] ... ) >>> np_labels = np.array([0, 3, 1]) >>> logits = flow.tensor(np_logits, dtype=flow.float32) >>> labels = flow.tensor(np_labels, dtype=flow.int32) >>> output = flow.nn.functional.sparse_softmax_cross_entropy( ... labels=labels, logits=logits ... 
) >>> output tensor([ 2.9751e-01, 1.1448e+00, -1.4305e-06], dtype=oneflow.float32) """ return flow._C.sparse_softmax_cross_entropy(logits, labels) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/tensor_buffer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional, Sequence import oneflow as flow def tensor_buffer_to_tensor_op(x, dtype: flow.dtype, instance_shape: Sequence[int]): """This operator converts the Tensor's type from TensorBuffer to original type. Some operator's output data type is `TensorBuffer`, you can use this operator to convert back to `Tensor`. Refer to `Concept Explanation `_ for more about TensorBuffer. Args: x (oneflow.Tensor): The input Tensor. dtype (flow.dtype): The data dtype. instance_shape (Sequence[int]): The shape of each TensorBuffer instance. Returns: oneflow.Tensor: The result Tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.random.randn(4, 16, 64, 64).astype(np.float32) >>> x = flow.Tensor(x) >>> x = flow.tensor_to_tensor_buffer(x, instance_dims=2) >>> output = flow.tensor_buffer_to_tensor(x, instance_shape=(64, 64), dtype=flow.float) >>> output.shape oneflow.Size([4, 16, 64, 64]) """ return flow._C.tensor_buffer_to_tensor( x, dtype=dtype, instance_shape=instance_shape ) def tensor_to_tensor_buffer(x, instance_dims: int): """This operator converts the Tensor's type to TensorBuffer. Refer to `Concept Explanation `_ for more about TensorBuffer. Args: x (oneflow.Tensor): The input Tensor. instance_dims (int): The dimensions of dynamic tensor instance. Returns: oneflow.Tensor: The result Tensor. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x = np.random.randn(4, 16, 64, 64).astype(np.float32) >>> x = flow.Tensor(x) >>> x = flow.tensor_to_tensor_buffer(x, instance_dims=2) >>> output = flow.tensor_buffer_to_tensor(x, instance_shape=(64, 64), dtype=flow.float) >>> output.shape oneflow.Size([4, 16, 64, 64]) """ return flow._C.tensor_to_tensor_buffer(x, instance_dims) def gen_tensor_buffer( shape: Sequence[int], shape_list: Sequence[Sequence[int]], value_list: Sequence[float], data_type: Optional[flow.dtype] = flow.float32, dynamic_out: Optional[bool] = False, ): return flow._C.gen_tensor_buffer( shape, shape_list, value_list, data_type, dynamic_out ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/tensordot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from typing import Union, List, Tuple import warnings def tensordot( a, b, dims: Union[oneflow.Tensor, int, List[List[int]], Tuple[List[int]]] = 2, out=None, ): if out is not None: raise NotImplementedError( "tensordot with `out` parameter which is not None is not yet implemented" ) if not isinstance(dims, (oneflow.Tensor, int, list, tuple)): raise TypeError( f"oneflow.tensordot expects dims to be one of oneflow.Tensor, int, Tuple[List[int], List[int]] or List[List[int]] containing two lists, but got {type(dims)}" ) if isinstance(dims, int): return oneflow._C.tensordot(a, b, dims) elif isinstance(dims, (list, tuple)): assert ( len(dims) == 2 ), f"The list/tuple of dims must contain two lists, got {len(dims)}" dim_a = list(dims[0]) dim_b = list(dims[1]) elif isinstance(dims, oneflow.Tensor): warnings.warn( "tensordot doesn't support nn.Graph when the type of `dims` is oneflow.Tensor, because it needs synchronization." ) if dims.numel() == 1: return oneflow._C.tensordot(a, b, dims.item()) assert ( dims.dim() == 2 ), f"The dims tensor must have two dimensions, got {dims.dim()}" assert ( len(dims) == 2 and dims.dim() == 2 ), f"The dims tensor must have two rows, got {len(dims)}" dim_a = dims[0].tolist() dim_b = dims[1].tolist() return oneflow._C.tensordot(a, b, dim_a, dim_b) ================================================ FILE: python/oneflow/nn/modules/trigonometric_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.nn.modules.module import Module from oneflow.framework.tensor import register_tensor_op def sign_op(input): """Computes the sign of Tensor. .. math:: \\text{out}_{i} = \\text{sgn}(\\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.Tensor(np.array([-2, 0, 2]).astype(np.float32)) >>> out1 = flow.sign(x1) >>> out1.numpy() array([-1., 0., 1.], dtype=float32) >>> x2 = flow.Tensor(np.array([-3.2, -4.5, 5.8]).astype(np.float32),device=flow.device('cuda')) >>> out2 = flow.sign(x2) >>> out2.numpy() array([-1., -1., 1.], dtype=float32) """ return flow._C.sign(input) def sinh_op(input): """Returns a new tensor with the hyperbolic sine of the elements of :attr:`input`. .. math:: \\text{out}_{i} = \\sinh(\\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> x1 = flow.Tensor(np.array([1, 2, 3])) >>> x2 = flow.Tensor(np.array([1.53123589,0.54242598,0.15117185])) >>> x3 = flow.Tensor(np.array([1,0,-1])) >>> flow.sinh(x1).numpy() array([ 1.1752012, 3.6268604, 10.017875 ], dtype=float32) >>> flow.sinh(x2).numpy() array([2.20381 , 0.5694193, 0.1517483], dtype=float32) >>> flow.sinh(x3).numpy() array([ 1.1752012, 0. , -1.1752012], dtype=float32) """ return flow._C.sinh(input) def tan_op(input): """Returns the tan value of the elements of :attr:`input`. .. math:: \\text{out}_{i} = \\tan(\\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> np_arr = np.array([-1/4*np.pi, 0, 1/4*np.pi]).astype(np.float32) >>> input = flow.Tensor(np_arr) >>> output = flow.tan(input) >>> output tensor([-1., 0., 1.], dtype=oneflow.float32) """ return flow._C.tan(input) def acosh_op(input): """Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. .. math:: \\text{out}_{i} = \\cosh^{-1}(\\text{input}_{i}) Args: input (Tensor): the input tensor. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.Tensor(np.array([2, 3, 4]).astype(np.float32)) >>> out1 = flow.acosh(x1) >>> out1 tensor([1.3170, 1.7627, 2.0634], dtype=oneflow.float32) >>> x2 = flow.Tensor(np.array([1.5, 2.6, 3.7]).astype(np.float32),device=flow.device('cuda')) >>> out2 = flow.acosh(x2) >>> out2 tensor([0.9624, 1.6094, 1.9827], device='cuda:0', dtype=oneflow.float32) """ return flow._C.acosh(input) def arccosh_op(input): """ See :func:`oneflow.acosh` """ return flow._C.acosh(input) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/unique.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow def unique_op( input, sorted=True, return_inverse=False, return_counts=False, dtype=flow.int ): r""" Returns the unique elements of the input tensor. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.unique.html. Args: input (Tensor): The input tensor. sorted (bool): Whether to sort the unique elements in ascending order before returning as output. return_inverse (bool): Whether to also return the indices for where elements in the original input ended up in the returned unique list. return_counts (bool): Whether to also return the counts for each unique element. dtype (flow.dtype): Dtype of the returned indices and counts. Returns: oneflow.Tensor or List of oneflow.Tensor: - **output** (Tensor): the output list of unique scalar elements. 
- **inverse_indices** (Tensor): (optional) if return_inverse is True, there will be an additional returned tensor (same shape as input) representing the indices for where elements in the original input map to in the output; otherwise, this function will only return a single tensor. - **counts** (Tensor): (optional) if return_counts is True, there will be an additional returned tensor (same shape as output or output.size(dim), if dim was specified) representing the number of occurrences for each unique value or tensor. For example: .. code-block:: python >>> import oneflow as flow >>> x = flow.tensor([3, 1, 2, 0 ,2]) >>> flow.unique(x) tensor([0, 1, 2, 3], dtype=oneflow.int64) >>> flow.unique(x, sorted=False) tensor([3, 1, 2, 0], dtype=oneflow.int64) >>> results, indices = flow.unique(x, return_inverse=True) >>> indices tensor([3, 1, 2, 0, 2], dtype=oneflow.int32) >>> results, counts = flow.unique(x, return_counts=True) >>> counts tensor([1, 1, 2, 1], dtype=oneflow.int32) >>> results, indices = flow.unique(x, return_inverse=True, dtype=flow.long) >>> indices tensor([3, 1, 2, 0, 2], dtype=oneflow.int64) """ if not return_inverse and not return_counts: return flow._C.unique(input, sorted, dtype=dtype) else: return flow._C.unique( input, sorted, return_inverse=return_inverse, return_counts=return_counts, dtype=dtype, ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/upsampling.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Optional, Tuple, Union import oneflow as flow from oneflow.nn.modules.module import Module class Upsample(Module): """ Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. The input data is assumed to be of the form `minibatch x channels x [optional depth] x [optional height] x width`. Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. The algorithms available for upsampling are nearest neighbor and linear, bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor, respectively. One can either give a :attr:`scale_factor` or the target output :attr:`size` to calculate the output size. (You cannot give both, as it is ambiguous) The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/modules/upsampling.html. Args: size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): output spatial sizes scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): multiplier for spatial size. Has to match input size if it is a tuple. mode (str, optional): the upsampling algorithm: one of ``'nearest'``, ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. 
Default: ``'nearest'`` align_corners (bool, optional): if ``True``, the corner pixels of the input and output tensors are aligned, and thus preserving the values at those pixels. This only has effect when :attr:`mode` is ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False`` Shape: - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})` or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where .. math:: D_{out} = \\left\\lfloor D_{in} \\times \\text{scale_factor} \\right\\rfloor .. math:: H_{out} = \\left\\lfloor H_{in} \\times \\text{scale_factor} \\right\\rfloor .. math:: W_{out} = \\left\\lfloor W_{in} \\times \\text{scale_factor} \\right\\rfloor .. warning:: With ``align_corners = True``, the linearly interpolating modes (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally align the output and input pixels, and thus the output values can depend on the input size. This was the default behavior for these modes up to version 0.3.1. Since then, the default behavior is ``align_corners = False``. See below for concrete examples on how this affects the outputs. .. note:: If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`. For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) >>> input = input.to("cuda") >>> m = flow.nn.Upsample(scale_factor=2.0, mode="nearest") >>> output = m(input) >>> output #doctest: +ELLIPSIS tensor([[[[1., 1., 2., 2.], ... [3., 3., 4., 4.]]]], device='cuda:0', dtype=oneflow.float32) """ def __init__( self, size: Optional[Union[int, Tuple[int, ...]]] = None, scale_factor: Optional[Union[float, Tuple[float, ...]]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, ): super().__init__() self.size = size self.scale_factor = scale_factor self.mode = mode self.align_corners = align_corners def forward(self, x): return flow.nn.functional.interpolate( x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners, ) def extra_repr(self) -> str: if self.scale_factor is not None: info = "scale_factor=" + str(self.scale_factor) else: info = "size=" + str(self.size) info += ", mode=" + self.mode return info class UpsamplingNearest2d(Upsample): """Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels. To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor` as it's constructor argument. When :attr:`size` is given, it is the output size of the image `(h, w)`. Args: size (int or Tuple[int, int], optional): output spatial sizes scale_factor (float or Tuple[float, float], optional): multiplier for spatial size. .. warning:: This class is deprecated in favor of :func:`~nn.functional.interpolate`. Shape: - Input: :math:`(N, C, H_{in}, W_{in})` - Output: :math:`(N, C, H_{out}, W_{out})` where .. math:: H_{out} = \\left\\lfloor H_{in} \\times \\text{scale_factor} \\right\\rfloor .. math:: W_{out} = \\left\\lfloor W_{in} \\times \\text{scale_factor} \\right\\rfloor For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) >>> input = input.to("cuda") >>> m = flow.nn.UpsamplingNearest2d(scale_factor=2.0) >>> output = m(input) >>> output #doctest: +ELLIPSIS tensor([[[[1., 1., 2., 2.], ... [3., 3., 4., 4.]]]], device='cuda:0', dtype=oneflow.float32) """ def __init__( self, size: Optional[Tuple[int, int]] = None, scale_factor: Optional[Tuple[float, float]] = None, ) -> None: super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode="nearest") class UpsamplingBilinear2d(Upsample): """Applies a 2D bilinear upsampling to an input signal composed of several input channels. To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor` as it's constructor argument. When :attr:`size` is given, it is the output size of the image `(h, w)`. Args: size (int or Tuple[int, int], optional): output spatial sizes scale_factor (float or Tuple[float, float], optional): multiplier for spatial size. .. warning:: This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. Shape: - Input: :math:`(N, C, H_{in}, W_{in})` - Output: :math:`(N, C, H_{out}, W_{out})` where .. math:: H_{out} = \\left\\lfloor H_{in} \\times \\text{scale_factor} \\right\\rfloor .. math:: W_{out} = \\left\\lfloor W_{in} \\times \\text{scale_factor} \\right\\rfloor For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) >>> input = input.to("cuda") >>> m = flow.nn.UpsamplingBilinear2d(scale_factor=2.0) >>> output = m(input) >>> output #doctest: +ELLIPSIS tensor([[[[1.0000, 1.3333, 1.6667, 2.0000], ... [3.0000, 3.3333, 3.6667, 4.0000]]]], device='cuda:0', dtype=oneflow.float32) """ def __init__( self, size: Optional[Tuple[int, int]] = None, scale_factor: Optional[Tuple[float, float]] = None, ) -> None: super(UpsamplingBilinear2d, self).__init__( size, scale_factor, mode="bilinear", align_corners=True ) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/modules/utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import collections.abc as container_abcs from itertools import repeat from typing import List import oneflow as flow def _ntuple(n): def parse(x): if isinstance(x, container_abcs.Iterable): return tuple(x) return tuple(repeat(x, n)) return parse def _getint(): def parse(x): if isinstance(x, container_abcs.Iterable): return int(x[0]) return int(x) return parse _getint = _getint() _single = _ntuple(1) _pair = _ntuple(2) _triple = _ntuple(3) _quadruple = _ntuple(4) def _handle_size_arg(size): if len(size) == 0: return size assert len(size) > 0, "size of tensor doesn't exists" if isinstance(size[0], (list, tuple, flow.Size)): assert ( len(size) == 1 ), "shape should be specified by tuple of int size, not tuple of list" size = size[0] return size def _reverse_repeat_tuple(t, n): """Reverse the order of `t` and repeat each element for `n` times. This can be used to translate padding arg used by Conv and Pooling modules to the ones used by `F.pad`. """ return tuple((x for x in reversed(t) for _ in range(n))) def _list_with_default(out_size, defaults): if isinstance(out_size, int): return out_size if len(defaults) <= len(out_size): raise ValueError( "Input dimension should be at least {}".format(len(out_size) + 1) ) return [ v if v is not None else d for (v, d) in zip(out_size, defaults[-len(out_size) :]) ] def _check_axis(axis, shape): ndim = len(shape) if axis is None: axis = list(range(len(shape))) if isinstance(axis, int): axis = [axis] assert isinstance(axis, (list, tuple)), "Invalid axis {}".format(axis) axis = list(axis) for i in range(len(axis)): assert ( -ndim <= axis[i] <= ndim - 1 ), "Dimension out of range (expected to be in range of [{}, {}], but got {})".format( -ndim, ndim - 1, axis[i] ) if axis[i] < 0: axis[i] = axis[i] + ndim return axis def _generate_output_size(input_size, output_size): new_output_size = [] assert len(input_size) - 2 == len( output_size ), f"the length of 'output_size' does not match the input size, {len(input_size) - 2} expected" for i in range(len(output_size)): if output_size[i] is None: new_output_size.append(input_size[i + 2]) else: assert isinstance( output_size[i], int ), "numbers in 'output_size' should be integer" new_output_size.append(output_size[i]) return tuple(new_output_size) ================================================ FILE: python/oneflow/nn/modules/where.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow as flow from oneflow.framework.tensor import register_tensor_op def where_op(condition, x=None, y=None): if x is None and y is None: return flow.nonzero(condition, as_tuple=True) return flow._C.where(condition, x, y) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/optimizer/__init__.py ================================================ ================================================ FILE: python/oneflow/nn/optimizer/adadelta.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import collections import math from typing import Callable, Dict, Iterator, List, Tuple, Union import oneflow as flow from oneflow.optim.optimizer import Optimizer, ParamGroup from oneflow.nn.parameter import Parameter class Adadelta(Optimizer): r"""Implements Adadelta Optimizer. The formula is: .. math:: & v_{t} = v_{t-1} * rho + g_{t}^2 * (1 - rho) & delta = \frac{\sqrt{u_{t-1} + \epsilon}}{\sqrt{v_{t} + \epsilon}} * g_{t} & u_{t} = u_{t-1} * rho + delta^2*(1 - rho) & x_{t} = x_{t-1} - lr * delta Args: params (Union[Iterator[Parameter], List[Dict]]): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): The learning rate. Defaults to 0.001. rho (float, optional): The decay factor of learning rate. Defaults to 0.0. eps (float, optional): A small constant terms added to the denominator to improve numerical stability. Defaults to 1e-10. weight_decay (float, optional): The weight decay. Defaults to 0. maximize (bool, optional): maximize the params based on the objective, instead of minimizing. Defaults False. contiguous_params (bool, optional): whether to use contiguous ParamGroup which puts all parameters of the same type, device and group into the same tensor and update them together. (default: False) For example: Example 1: .. code-block:: python # Assume net is a custom model. adadelta = flow.optim.Adadelta(net.parameters(), lr=1e-3) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() adadelta.step() adadelta.zero_grad() Example 2: .. code-block:: python # Assume net is a custom model. adadelta = flow.optim.Adadelta( [ { "params": net.parameters(), "lr": learning_rate, "clip_grad_max_norm": 0.5, "clip_grad_norm_type": 2.0, } ], ) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() adadelta.clip_grad() adadelta.step() adadelta.zero_grad() If you want to use clip_grad, you can refer this example. For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
""" def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 1.0, rho: float = 0.9, eps: float = 1e-6, weight_decay: float = 0, maximize: bool = False, contiguous_params: bool = False, ): assert lr >= 0.0, f"Invalid learning rate: {lr}" assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" assert eps >= 0.0, f"Invalid epsilon value: {eps}" assert 1.0 >= rho >= 0.0, f"Invalid rho value: {rho}" assert ( not maximize ), f"In Graph Mode, weight decay has been added to Variable, it cause different result with Eager Mode when maximize = True" options = dict() options["lr"] = lr options["rho"] = rho options["eps"] = eps options["maximize"] = maximize options["weight_decay"] = weight_decay options["contiguous_params"] = contiguous_params super().__init__(params, options) for param_group in self.param_groups: if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() self.state[param]["square_avgs"] = flow.zeros_like(param) self.state[param]["acc_deltas"] = flow.zeros_like(param) self._op = ( flow.stateful_op("adadelta_update") .Input("model") .Input("model_diff") .Input("square_avgs") .Input("acc_deltas") .Build() ) def step(self, closure: Callable = None): """Performs a single optimization step. Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ with flow.no_grad(): loss = None if closure is not None: with flow.enable_grad(): loss = closure() for param_group in self.param_groups: kwargs = { "learning_rate": param_group["lr"], "l2": param_group["weight_decay"], "rho": param_group["rho"], "epsilon": param_group["eps"], "maximize": param_group["maximize"], } if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: if param.grad is None: continue square_avgs_tensor = self.state[param]["square_avgs"] acc_deltas_tensor = self.state[param]["acc_deltas"] flow._C.dispatch_adadelta_update( self._op, (param, param.grad, square_avgs_tensor, acc_deltas_tensor), **kwargs, ) self.state["step"] = self.state["step"] + 1 return loss def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: assert ( param_group["contiguous_params"] != True ), "contiguous_params cannot be used in graph" optimizer_conf = train_conf.optimizer_conf.add() lr = ( param_group["initial_lr"] if "initial_lr" in param_group else param_group["lr"] ) l2 = param_group["weight_decay"] rho = param_group["rho"] epsilon = param_group["eps"] maximize = param_group["maximize"] optimizer_conf.base_learning_rate = lr self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf) optimizer_conf.adadelta_conf.rho = rho optimizer_conf.adadelta_conf.epsilon = epsilon optimizer_conf.adadelta_conf.maximize = maximize self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) for param in param_group.parameters: vars_conf[param].l2 = l2 if param.requires_grad: optimizer_conf.variable_op_names.append(vars_conf[param].name) new_opt_confs.append(optimizer_conf) return new_opt_confs ================================================ FILE: python/oneflow/nn/optimizer/adagrad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import collections import math from typing import Callable, Dict, Iterator, List, Tuple, Union import oneflow as flow from oneflow.optim.optimizer import Optimizer, ParamGroup from oneflow.nn.parameter import Parameter class Adagrad(Optimizer): r"""Implements Adagrad Optimizer. The formula is: .. math:: & S_{t} = S_{t-1} + grad \odot grad & decay\_lr = \frac{learning\_rate}{(1 + (train\_step - 1) * lr\_decay)} & X_{t} = X_{t-1} - \frac{decay\_lr}{\sqrt{S_{t} + \epsilon}} \odot grad Args: params (Union[Iterator[Parameter], List[Dict]]): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): The learning rate. Defaults to 0.001. lr_decay (float, optional): The decay factor of learning rate. Defaults to 0.0. weight_decay (float, optional): The weight decay. Defaults to 0. initial_accumulator_value (float, optional): The initial value of S. Defaults to 0.0. eps (float, optional): A small constant terms added to the denominator to improve numerical stability. Defaults to 1e-10. contiguous_params (bool, optional): whether to use contiguous ParamGroup which puts all parameters of the same type, device and group into the same tensor and update them together. (default: False) For example: Example 1: .. code-block:: python # Assume net is a custom model. adagrad = flow.optim.Adagrad(net.parameters(), lr=1e-3) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() adagrad.step() adagrad.zero_grad() Example 2: .. code-block:: python # Assume net is a custom model. adagrad = flow.optim.Adagrad( [ { "params": net.parameters(), "lr": learning_rate, "clip_grad_max_norm": 0.5, "clip_grad_norm_type": 2.0, } ], ) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() adagrad.clip_grad() adagrad.step() adagrad.zero_grad() If you want to use clip_grad, you can refer this example. For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
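A minimal NumPy sketch of one update step, mirroring the formula above (illustrative only; the values are arbitrary and ``S`` plays the role of the accumulated ``sum`` state):

.. code-block:: python

    import numpy as np

    lr, lr_decay, eps = 0.001, 0.0, 1e-10
    train_step = 1              # 1-based step counter, as passed via "train_step_val"
    x = np.array([1.0, 2.0])    # parameter
    g = np.array([0.1, -0.2])   # gradient of the loss w.r.t. x
    S = np.zeros_like(x)        # accumulated squared gradients ("sum")

    S = S + g * g
    decay_lr = lr / (1 + (train_step - 1) * lr_decay)
    x = x - decay_lr / np.sqrt(S + eps) * g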
""" def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 0.001, lr_decay: float = 0.0, weight_decay: float = 0, initial_accumulator_value: float = 0.0, eps: float = 1e-10, contiguous_params: bool = False, ): assert lr >= 0.0, f"Invalid learning rate: {lr}" assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" assert ( initial_accumulator_value >= 0.0 ), f"Invalid initial_accumulator_value value: {initial_accumulator_value}" assert eps >= 0.0, f"Invalid epsilon value: {eps}" options = dict() options["lr"] = lr options["initial_accumulator_value"] = initial_accumulator_value options["lr_decay"] = lr_decay options["weight_decay"] = weight_decay options["eps"] = eps options["contiguous_params"] = contiguous_params super().__init__(params, options) for param_group in self.param_groups: if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() self.state[param]["sum"] = flow.zeros_like(param).fill_( param_group["initial_accumulator_value"] ) self._op = ( flow.stateful_op("adagrad_update") .Input("model") .Input("model_diff") .Input("sum") .Build() ) def step(self, closure: Callable = None): """Performs a single optimization step. Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ with flow.no_grad(): loss = None if closure is not None: with flow.enable_grad(): loss = closure() for param_group in self.param_groups: kwargs = { "learning_rate": param_group["lr"], "l2": param_group["weight_decay"], "epsilon": param_group["eps"], "lr_decay": param_group["lr_decay"], "train_step_val": self.state["step"] + 1, } if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: if param.grad is None: continue sum_tensor = self.state[param]["sum"] flow._C.dispatch_adagrad_update( self._op, (param, param.grad, sum_tensor), **kwargs ) self.state["step"] = self.state["step"] + 1 return loss def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: assert ( param_group["contiguous_params"] != True ), "contiguous_params cannot be used in graph" optimizer_conf = train_conf.optimizer_conf.add() lr = ( param_group["initial_lr"] if "initial_lr" in param_group else param_group["lr"] ) l2 = param_group["weight_decay"] initial_accumulator_value = param_group["initial_accumulator_value"] lr_decay = param_group["lr_decay"] epsilon = param_group["eps"] optimizer_conf.base_learning_rate = lr self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf) optimizer_conf.adagrad_conf.initial_accumulator_value = ( initial_accumulator_value ) optimizer_conf.adagrad_conf.lr_decay = lr_decay optimizer_conf.adagrad_conf.epsilon = epsilon self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) for param in param_group.parameters: vars_conf[param].l2 = l2 if param.requires_grad: optimizer_conf.variable_op_names.append(vars_conf[param].name) new_opt_confs.append(optimizer_conf) return new_opt_confs ================================================ FILE: python/oneflow/nn/optimizer/adam.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings import math from typing import Callable, Dict, Iterator, List, Tuple, Union import oneflow as flow from oneflow.optim.optimizer import Optimizer, ParamGroup from oneflow.nn.parameter import Parameter class Adam(Optimizer): """Implements Adam algorithm. It has been proposed in `Adam: A Method for Stochastic Optimization`_. The implementation of the L2 penalty follows changes proposed in `Decoupled Weight Decay Regularization`_. This algorithm can adjust the learning rate of each parameter dynamically according to the 1st-moment estimates and the 2nd-moment estimates of gradient. the equation of parameters updating is: .. math:: & V_t = \\beta_1*V_{t-1} + (1-\\beta_1)*grad & S_t = \\beta_2*S_{t-1} + (1-\\beta_2)*{grad} \\odot {grad} & \\hat{g} = learning\\_rate*\\frac{{V_t}}{\\sqrt{{S_t}}+\\epsilon} & param_{new} = param_{old} - \\hat{g} Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm. (default: False) do_bias_correction (bool, optional): whether to do bias correction (default: True) contiguous_params (bool, optional): whether to use contiguous ParamGroup which puts all parameters of the same type, device and group into the same tensor and update them together. (default: False) fused (bool, optional): whether to divide all the parameters into several groups, then update each group of parameters with the fused kernel. (default: False) .. _Adam\\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 For example: Example 1: .. code-block:: python # Assume net is a custom model. adam = flow.optim.Adam(net.parameters(), lr=1e-3) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() adam.step() adam.zero_grad() Example 2: .. code-block:: python # Assume net is a custom model. adam = flow.optim.Adam( [ { "params": net.parameters(), "lr": learning_rate, "clip_grad_max_norm": 0.5, "clip_grad_norm_type": 2.0, } ], ) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() adam.clip_grad() adam.step() adam.zero_grad() If you want to use clip_grad, you can refer this example. For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
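A minimal NumPy sketch of one update step in the conventional Adam form, including the bias correction that ``step()`` computes from the step counter (illustrative only; the exact ordering inside the dispatched kernel may differ):

.. code-block:: python

    import math
    import numpy as np

    lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
    step = 0                    # value of self.state["step"] before this update
    x = np.array([1.0, 2.0])    # parameter
    g = np.array([0.1, -0.2])   # gradient of the loss w.r.t. x
    m = np.zeros_like(x)        # first-moment estimate ("exp_avg")
    v = np.zeros_like(x)        # second-moment estimate ("exp_avg_sq")

    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    bias_correction1 = 1.0 - math.pow(beta1, step + 1)
    bias_correction2 = 1.0 - math.pow(beta2, step + 1)
    x = x - lr * (m / bias_correction1) / (np.sqrt(v / bias_correction2) + eps)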
""" def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False, do_bias_correction: bool = True, contiguous_params: bool = False, fused: bool = False, ): assert lr >= 0.0, f"Invalid learning rate: {lr}" assert eps >= 0.0, f"Invalid epsilon value: {eps}" assert ( betas[0] >= 0.0 and betas[0] < 1.0 ), f"Invalid beta parameter at index 0: {betas[0]}" assert ( betas[1] >= 0.0 and betas[1] < 1.0 ), f"Invalid beta parameter at index 1: {betas[1]}" assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" options = dict() options["lr"] = lr options["eps"] = eps options["betas"] = betas options["weight_decay"] = weight_decay options["amsgrad"] = amsgrad options["bias_correction1"] = 1.0 options["bias_correction2"] = 1.0 options["do_bias_correction"] = do_bias_correction options["contiguous_params"] = contiguous_params options["fused"] = fused super().__init__(params, options) for param_group in self.param_groups: if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() if param_group["fused"] and param_group["amsgrad"]: warnings.warn("Fused Adam is not supported when amsgrad=True.") param_group["fused"] = False if param_group["fused"] and not param.is_cuda: warnings.warn("Fused Adam only support cuda parameters.") param_group["fused"] = False self._op_with_amsgrad = ( flow.stateful_op("adam_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") .Input("max_v") .Build() ) self._op_without_amsgrad = ( flow.stateful_op("adam_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") .Build() ) def _single_tensor_update(self, param_group): kwargs = { "learning_rate": param_group["lr"], "bias_correction1": param_group["bias_correction1"], "bias_correction2": param_group["bias_correction2"], "l2": param_group["weight_decay"], "beta1": param_group["betas"][0], "beta2": param_group["betas"][1], "epsilon": param_group["eps"], "do_bias_correction": param_group["do_bias_correction"], "amsgrad": param_group["amsgrad"], } if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: if param.grad is None: continue if "exp_avg" not in self.state[param]: self.state[param]["exp_avg"] = flow.zeros_like(param) if "exp_avg_sq" not in self.state[param]: self.state[param]["exp_avg_sq"] = flow.zeros_like(param) if param_group["amsgrad"]: if "max_exp_avg_sq" not in self.state[param]: self.state[param]["max_exp_avg_sq"] = flow.zeros_like(param) m_tensor = self.state[param]["exp_avg"] v_tensor = self.state[param]["exp_avg_sq"] if param_group["amsgrad"]: max_v_tensor = self.state[param]["max_exp_avg_sq"] flow._C.dispatch_adam_update( self._op_with_amsgrad, (param, param.grad, m_tensor, v_tensor, max_v_tensor), **kwargs, ) else: flow._C.dispatch_adam_update( self._op_without_amsgrad, (param, param.grad, m_tensor, v_tensor), **kwargs, ) def _fused_update(self, param_group): param_list = [] param_grad_list = [] m_tensor_list = [] v_tensor_list = [] for param in param_group.parameters: if param.grad is None: continue if "exp_avg" not in self.state[param]: self.state[param]["exp_avg"] = flow.zeros_like(param) if "exp_avg_sq" not in self.state[param]: self.state[param]["exp_avg_sq"] = 
flow.zeros_like(param) if param_group["amsgrad"]: if "max_exp_avg_sq" not in self.state[param]: self.state[param]["max_exp_avg_sq"] = flow.zeros_like(param) param_list.append(param) param_grad_list.append(param.grad) m_tensor_list.append(self.state[param]["exp_avg"]) v_tensor_list.append(self.state[param]["exp_avg_sq"]) flow._C.multi_tensor_adam_update( model=param_list, model_diff=param_grad_list, m=m_tensor_list, v=v_tensor_list, learning_rate_val=param_group["lr"], l2=param_group["weight_decay"], beta1=param_group["betas"][0], beta2=param_group["betas"][1], bias_correction1_val=param_group["bias_correction1"], bias_correction2_val=param_group["bias_correction2"], do_bias_correction=param_group["do_bias_correction"], scale=1.0, weight_decay=0.0, epsilon=param_group["eps"], ) def step(self, closure: Callable = None): """Performs a single optimization step. Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ with flow.no_grad(): loss = None if closure is not None: with flow.enable_grad(): loss = closure() for param_group in self.param_groups: if param_group["do_bias_correction"]: param_group["bias_correction1"] = 1.0 - math.pow( param_group["betas"][0], self.state["step"] + 1 ) param_group["bias_correction2"] = 1.0 - math.pow( param_group["betas"][1], self.state["step"] + 1 ) if param_group["fused"]: self._fused_update(param_group) else: self._single_tensor_update(param_group) self.state["step"] += 1 return loss def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: assert ( param_group["contiguous_params"] != True ), "contiguous_params cannot be used in graph" optimizer_conf = train_conf.optimizer_conf.add() lr = ( param_group["initial_lr"] if "initial_lr" in param_group else param_group["lr"] ) l2 = param_group["weight_decay"] beta1 = param_group["betas"][0] beta2 = param_group["betas"][1] epsilon = param_group["eps"] do_bias_correction = param_group["do_bias_correction"] amsgrad = param_group["amsgrad"] optimizer_conf.base_learning_rate = lr self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf) optimizer_conf.adam_conf.beta1 = beta1 optimizer_conf.adam_conf.beta2 = beta2 optimizer_conf.adam_conf.epsilon = epsilon optimizer_conf.adam_conf.do_bias_correction = do_bias_correction optimizer_conf.adam_conf.amsgrad = amsgrad self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) for param in param_group.parameters: vars_conf[param].l2 = l2 if param.requires_grad: optimizer_conf.variable_op_names.append(vars_conf[param].name) new_opt_confs.append(optimizer_conf) return new_opt_confs @property def support_sparse(self): return True ================================================ FILE: python/oneflow/nn/optimizer/adamw.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
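The ``_fused_update`` path above gathers the parameters, gradients and moment buffers into parallel lists and hands them to a single ``multi_tensor_adam_update`` call. A rough pure-Python sketch of that grouping idea (not the actual kernel; bias correction and weight decay omitted):

.. code-block:: python

    import numpy as np

    def multi_tensor_adam_sketch(params, grads, ms, vs, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        # walk all gathered tensors in one pass instead of one dispatch per parameter
        for p, g, m, v in zip(params, grads, ms, vs):
            m[:] = beta1 * m + (1 - beta1) * g
            v[:] = beta2 * v + (1 - beta2) * g * g
            p[:] -= lr * m / (np.sqrt(v) + eps)

    params = [np.ones(3), np.ones(5)]
    grads = [np.full(3, 0.1), np.full(5, -0.2)]
    ms = [np.zeros(3), np.zeros(5)]
    vs = [np.zeros(3), np.zeros(5)]
    multi_tensor_adam_sketch(params, grads, ms, vs)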
""" import warnings import math from typing import Callable, Dict, Iterator, List, Tuple, Union import oneflow as flow from oneflow.optim.optimizer import Optimizer, ParamGroup from oneflow.nn.parameter import Parameter class AdamW(Optimizer): """Implements AdamW algorithm. The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. The optimizer of the Adam-weight-decay algorithm. (More details please refer to `Adam-weight-decay `_). So we use Adam-weight-decay algorithm to solve this problem. the equation of parameters updating is: .. math:: & V_t = \\beta_1*V_{t-1} + (1-\\beta_1)*grad & S_t = \\beta_2*S_{t-1} + (1-\\beta_2)*{grad} \\odot {grad} & \\hat{g} = learning\\_rate*(\\frac{{V_t}}{\\sqrt{{S_t}}+\\epsilon}+\\lambda*param_{old}) & param_{new} = param_{old} - \\hat{g} Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (In the equation is λ, default: 0) amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm. (default: False) do_bias_correction (bool, optional): whether to do bias correction (default: True) contiguous_params (bool, optional): whether to use contiguous ParamGroup which puts all parameters of the same type, device and group into the same tensor and update them together. (default: False) fused (bool, optional): whether to divide all the parameters into several groups, then update each group of parameters with the fused kernel. (default: False) .. _Adam\\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 For example: Example 1: .. code-block:: python # Assume net is a custom model. adamw = flow.optim.AdamW(net.parameters(), lr=1e-3) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() adamw.step() adamw.zero_grad() Example 2: .. code-block:: python # Assume net is a custom model. adamw = flow.optim.AdamW( [ { "params": net.parameters(), "lr": learning_rate, "clip_grad_max_norm": 0.5, "clip_grad_norm_type": 2.0, } ], ) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() adamw.clip_grad() adamw.step() adamw.zero_grad() If you want to use clip_grad, you can refer this example. For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
""" def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False, do_bias_correction: bool = True, contiguous_params: bool = False, fused: bool = False, ): assert lr >= 0.0, f"Invalid learning rate: {lr}" assert eps >= 0.0, f"Invalid epsilon value: {eps}" assert ( betas[0] >= 0.0 and betas[0] < 1.0 ), f"Invalid beta parameter at index 0: {betas[0]}" assert ( betas[1] >= 0.0 and betas[1] < 1.0 ), f"Invalid beta parameter at index 1: {betas[1]}" assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" options = dict() options["lr"] = lr options["eps"] = eps options["betas"] = betas options["weight_decay"] = weight_decay options["bias_correction1"] = 1.0 options["bias_correction2"] = 1.0 options["do_bias_correction"] = do_bias_correction options["amsgrad"] = amsgrad options["contiguous_params"] = contiguous_params options["fused"] = fused super().__init__(params, options) for param_group in self.param_groups: if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() if param_group["fused"] and param_group["amsgrad"]: warnings.warn("Fused Adamw is not supported when amsgrad=True.") param_group["fused"] = False if param_group["fused"] and not param.is_cuda: warnings.warn("Fused Adamw only support cuda parameters.") param_group["fused"] = False self._op_with_amsgrad = ( flow.stateful_op("adam_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") .Input("max_v") .Build() ) self._op_without_amsgrad = ( flow.stateful_op("adam_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") .Build() ) def _single_tensor_update(self, param_group): kwargs = { "learning_rate": param_group["lr"], "bias_correction1": param_group["bias_correction1"], "bias_correction2": param_group["bias_correction2"], "weight_decay": param_group["weight_decay"], "beta1": param_group["betas"][0], "beta2": param_group["betas"][1], "epsilon": param_group["eps"], "do_bias_correction": param_group["do_bias_correction"], "amsgrad": param_group["amsgrad"], } if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: if param.grad is None: continue if "exp_avg" not in self.state[param]: self.state[param]["exp_avg"] = flow.zeros_like(param) if "exp_avg_sq" not in self.state[param]: self.state[param]["exp_avg_sq"] = flow.zeros_like(param) if param_group["amsgrad"]: if "max_exp_avg_sq" not in self.state[param]: self.state[param]["max_exp_avg_sq"] = flow.zeros_like(param) m_tensor = self.state[param]["exp_avg"] v_tensor = self.state[param]["exp_avg_sq"] if param_group["amsgrad"]: max_v_tensor = self.state[param]["max_exp_avg_sq"] flow._C.dispatch_adam_update( self._op_with_amsgrad, (param, param.grad, m_tensor, v_tensor, max_v_tensor), **kwargs, ) else: flow._C.dispatch_adam_update( self._op_without_amsgrad, (param, param.grad, m_tensor, v_tensor), **kwargs, ) def _fused_update(self, param_group): param_list = [] param_grad_list = [] m_tensor_list = [] v_tensor_list = [] for param in param_group.parameters: if param.grad is None: continue if "exp_avg" not in self.state[param]: self.state[param]["exp_avg"] = flow.zeros_like(param) if "exp_avg_sq" not in self.state[param]: 
self.state[param]["exp_avg_sq"] = flow.zeros_like(param) if param_group["amsgrad"]: if "max_exp_avg_sq" not in self.state[param]: self.state[param]["max_exp_avg_sq"] = flow.zeros_like(param) param_list.append(param) param_grad_list.append(param.grad) m_tensor_list.append(self.state[param]["exp_avg"]) v_tensor_list.append(self.state[param]["exp_avg_sq"]) flow._C.multi_tensor_adam_update( model=param_list, model_diff=param_grad_list, m=m_tensor_list, v=v_tensor_list, learning_rate_val=param_group["lr"], l2=0.0, beta1=param_group["betas"][0], beta2=param_group["betas"][1], bias_correction1_val=param_group["bias_correction1"], bias_correction2_val=param_group["bias_correction2"], do_bias_correction=param_group["do_bias_correction"], scale=1.0, weight_decay=param_group["weight_decay"], epsilon=param_group["eps"], ) def step(self, closure: Callable = None): """Performs a single optimization step. Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ with flow.no_grad(): loss = None if closure is not None: with flow.enable_grad(): loss = closure() for param_group in self.param_groups: if param_group["do_bias_correction"]: param_group["bias_correction1"] = 1.0 - math.pow( param_group["betas"][0], self.state["step"] + 1 ) param_group["bias_correction2"] = 1.0 - math.pow( param_group["betas"][1], self.state["step"] + 1 ) if param_group["fused"]: self._fused_update(param_group) else: self._single_tensor_update(param_group) self.state["step"] += 1 return loss def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: assert ( param_group["contiguous_params"] != True ), "contiguous_params cannot be used in graph" optimizer_conf = train_conf.optimizer_conf.add() lr = ( param_group["initial_lr"] if "initial_lr" in param_group else param_group["lr"] ) weight_decay = param_group["weight_decay"] beta1 = param_group["betas"][0] beta2 = param_group["betas"][1] epsilon = param_group["eps"] do_bias_correction = param_group["do_bias_correction"] amsgrad = param_group["amsgrad"] optimizer_conf.base_learning_rate = lr self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf) optimizer_conf.adam_conf.beta1 = beta1 optimizer_conf.adam_conf.beta2 = beta2 optimizer_conf.adam_conf.epsilon = epsilon optimizer_conf.adam_conf.do_bias_correction = do_bias_correction optimizer_conf.adam_conf.amsgrad = amsgrad optimizer_conf.weight_decay_conf.weight_decay_rate = weight_decay self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) for param in param_group.parameters: if param.requires_grad: optimizer_conf.variable_op_names.append(vars_conf[param].name) new_opt_confs.append(optimizer_conf) return new_opt_confs @property def support_sparse(self): """Whether AdamW Optimizer support sparse update. """ return True ================================================ FILE: python/oneflow/nn/optimizer/chained_scheduler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ from .lr_scheduler import LRScheduler class ChainedScheduler(LRScheduler): """Chains list of learning rate schedulers. It takes a list of chainable learning rate schedulers and performs consecutive step() functions belong to them by just one call. Args: schedulers (list): List of chained schedulers. Example: >>> # Assuming optimizer uses lr = 1. for all groups >>> # lr = 0.09 if step == 0 >>> # lr = 0.081 if step == 1 >>> # lr = 0.729 if step == 2 >>> # lr = 0.6561 if step == 3 >>> # lr = 0.59049 if step >= 4 >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) >>> for _ in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() """ def __init__(self, schedulers): if not isinstance(schedulers, (list, tuple)) or any( not isinstance(s, LRScheduler) for s in schedulers ): raise ValueError("ChainedScheduler expects a list of schedulers") if len(schedulers) == 0: raise ValueError("length of list of schedulers must be greater than 0") opt = schedulers[0].optimizer for i in range(1, len(schedulers)): if schedulers[i].optimizer != opt: raise ValueError( "ChainedScheduler expects all schedulers to belong to the same optimizer, but " f"got schedulers at index {0} and {i} to be different" ) self.schedulers = list(schedulers) super().__init__(optimizer=opt) def step(self): self.last_step += 1 lrs = self.schedulers[0].base_lrs.copy() for scheduler in self.schedulers: for i, lr in enumerate(lrs): lrs[i] = scheduler.get_lr(lr, self.last_step) scheduler.last_step = self.last_step self.update_lrs(lrs) def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. The wrapped scheduler states will also be saved. """ state_dict = { key: value for key, value in self.__dict__.items() if key not in ("optimizer", "schedulers") } state_dict["schedulers"] = [None] * len(self.schedulers) for i, s in enumerate(self.schedulers): state_dict["schedulers"][i] = s.state_dict() return state_dict def load_state_dict(self, state_dict): """Loads the schedulers state. Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ scheduler_states = state_dict.pop("schedulers") self.__dict__.update(state_dict) # avoid side effect of calling load_state_dict twice state_dict["schedulers"] = scheduler_states for i, s in enumerate(scheduler_states): self.schedulers[i].load_state_dict(s) def _generate_conf_for_graph(self, lr_conf): raise NotImplementedError("ChainedScheduler is not supported in graph mode yet") ================================================ FILE: python/oneflow/nn/optimizer/constant_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
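``ChainedScheduler.step()`` above threads the learning rate through each wrapped scheduler's ``get_lr`` at the same step. A small standalone sketch of that composition (plain functions stand in for schedulers; the numbers are arbitrary):

.. code-block:: python

    def constant_factor(lr, step, factor=0.1, total_iters=2):
        return lr * factor if step < total_iters else lr

    def exponential(lr, step, gamma=0.9):
        return lr * gamma ** step

    base_lr = 1.0
    for step in range(5):
        lr = base_lr
        for get_lr in (constant_factor, exponential):
            lr = get_lr(lr, step)  # each scheduler transforms the lr produced by the previous one
        print(step, lr)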
""" from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class ConstantLR(LRScheduler): """Decays the learning rate of each parameter group by a small constant factor until the number of step reaches a pre-defined milestone: total_iters. Args: optimizer (Optimizer): Wrapped optimizer. factor (float): The number we multiply learning rate until the milestone. Default: 1./3. total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. last_step (int): The last step. Default: -1. verbose (bool): If ``True``, prints a message to stdout for each step. Default: ``False``. Example: >>> # Assuming optimizer uses lr = 0.05 for all groups >>> # lr = 0.025 if step == 0 >>> # lr = 0.025 if step == 1 >>> # lr = 0.025 if step == 2 >>> # lr = 0.025 if step == 3 >>> # lr = 0.05 if step >= 4 >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) >>> for step in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() """ def __init__( self, optimizer: Optimizer, factor: float = 1.0 / 3, total_iters: int = 5, last_step: int = -1, verbose: bool = False, ): assert isinstance(optimizer, Optimizer) if factor > 1.0 or factor < 0: raise ValueError( "Constant multiplicative factor expected to be between 0 and 1." ) self.factor = factor self.total_iters = total_iters super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): if step < self.total_iters: return base_lr * self.factor return base_lr def _generate_conf_for_graph(self, lr_conf): lr_conf.constant_lr_conf.SetInParent() constant_lr_conf = lr_conf.constant_lr_conf constant_lr_conf.factor = self.factor constant_lr_conf.total_iters = self.total_iters ================================================ FILE: python/oneflow/nn/optimizer/cosine_annealing_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class CosineAnnealingLR(LRScheduler): r""" Set the learning rate of each parameter group using a cosine annealing schedule, where :math:`\eta_{max}` is set to the initial lr and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: .. math:: \begin{aligned} \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), & T_{cur} \neq (2k+1)T_{max}; \\ \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), & T_{cur} = (2k+1)T_{max}. \end{aligned} When last_step=-1, sets initial lr as lr. Notice that because the schedule is defined recursively, the learning rate can be simultaneously modified outside this scheduler by other operators. If the learning rate is set solely by this scheduler, the learning rate at each step becomes: .. 
math:: \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only implements the cosine annealing part of SGDR, and not the restarts. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html. Args: optimizer (Optimizer): Wrapped optimizer. T_max (int): Maximum number of iterations. eta_min (float): Minimum learning rate. Default: 0. last_step (int): The index of last epoch. Default: -1. verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``. .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: https://arxiv.org/abs/1608.03983 """ def __init__( self, optimizer: Optimizer, T_max: int, eta_min: float = 0.0, last_step: int = -1, verbose: bool = False, ): self.T_max = T_max self.eta_min = eta_min super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): cos_decay = 0.5 * (1 + math.cos(math.pi * step / self.T_max)) return self.eta_min + (base_lr - self.eta_min) * cos_decay def _generate_conf_for_graph(self, lr_conf): lr_conf.cosine_annealing_conf.SetInParent() cosine_annealing_conf = lr_conf.cosine_annealing_conf cosine_annealing_conf.t_max = self.T_max cosine_annealing_conf.eta_min = self.eta_min ================================================ FILE: python/oneflow/nn/optimizer/cosine_annealing_warm_restarts.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class CosineAnnealingWarmRestarts(LRScheduler): r"""Set the learning rate of each parameter group using a cosine annealing schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` is the number of steps since the last restart and :math:`T_{i}` is the number of steps between two warm restarts in SGDR: .. math:: \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Args: optimizer (Optimizer): Wrapped optimizer. T_0 (int): Number of iterations for the first restart. T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. eta_min (float, optional): Minimum learning rate. Default: 0. decay_rate (float, optional): Decay rate every restarts. restart_limit (int, optional): The limit of restarts. 0 indicate unlimited restarts. Default: 0. last_step (int, optional): The index of last step. Default: -1. verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``. .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: https://arxiv.org/abs/1608.03983 """ def __init__( self, optimizer: Optimizer, T_0: int, T_mult: int = 1, eta_min: float = 0.0, decay_rate: float = 1.0, restart_limit: int = 0, last_step: int = -1, verbose: bool = False, ): assert isinstance(optimizer, Optimizer) if T_0 <= 0 or not isinstance(T_0, int): raise ValueError(f"Expected positive integer T_0, but got {T_0}") if T_mult < 1 or not isinstance(T_mult, int): raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}") self.T_0 = T_0 self.T_mult = T_mult self.eta_min = eta_min self.decay_rate = decay_rate self.restart_limit = restart_limit super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): if self.T_mult > 1: epoch = math.floor( math.log(1 - step / self.T_0 * (1 - self.T_mult), self.T_mult) ) epoch_steps = self.T_mult ** epoch * self.T_0 step_in_epoch = ( step - (1 - self.T_mult ** epoch) / (1 - self.T_mult) * self.T_0 ) else: epoch = step // self.T_0 epoch_steps = self.T_0 step_in_epoch = step - (epoch_steps * epoch) gamma = self.decay_rate ** epoch if self.restart_limit == 0 or ( self.restart_limit > 0 and epoch < self.restart_limit ): cos_decay = 0.5 * (1 + math.cos(math.pi * step_in_epoch / epoch_steps)) return self.eta_min + (base_lr * gamma - self.eta_min) * cos_decay return self.eta_min def _generate_conf_for_graph(self, lr_conf): lr_conf.cosine_annealing_warm_restarts_conf.SetInParent() cosa_warm_restarts_conf = lr_conf.cosine_annealing_warm_restarts_conf cosa_warm_restarts_conf.t_initial = self.T_0 cosa_warm_restarts_conf.t_mult = self.T_mult cosa_warm_restarts_conf.eta_min = self.eta_min cosa_warm_restarts_conf.decay_rate = self.decay_rate cosa_warm_restarts_conf.restart_limit = self.restart_limit ================================================ FILE: python/oneflow/nn/optimizer/cosine_decay_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class CosineDecayLR(LRScheduler): """This operator creates a Cosine decayed learning rate scheduler. Before the decay_steps are specified by user, the learning rate will be updated as: .. math:: & cos\\_decay = 0.5*(1+cos(\\pi*\\frac{current\\_step}{decay\\_steps})) & decay\\_factor = (1-\\alpha)*cos\\_decay+\\alpha & learning\\_rate = base\\_learning\\_rate*decay\\_factor After the decay_steps specified by user, the learning rate will be : .. math:: learning\\_rate = {base\\_learning\\_rate}*{\\alpha} It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only implements the cosine annealing part of SGDR, and not the restarts. Args: optimizer(Optimizer): Wrapped optimizer. decay_steps (int): The decay steps in the scheduler. alpha (float, optional): The learning rate scale factor (:math:`\\alpha`). (default: 0.0) last_step (int, optional): The index of last step. 
(default: -1) verbose (bool, optional): If ``True``, prints a message to stdout for each update. (default: ``False``) For example: .. code-block:: python import oneflow as flow ... cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(optimizer, decay_steps=100, alpha=0.0) for epoch in range(num_epoch): train(...) cosine_decay_lr.step() .. _SGDR\\: Stochastic Gradient Descent with Warm Restarts: https://arxiv.org/abs/1608.03983 """ def __init__( self, optimizer: Optimizer, decay_steps: int, alpha: float = 0.0, last_step: int = -1, verbose: bool = False, ): assert ( decay_steps > 0 ), f"decay_steps must greater than zero, but got {decay_steps}" self.decay_steps = decay_steps self.alpha = alpha super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): if step < self.decay_steps: cos_decay = 0.5 * (1 + math.cos(math.pi * step / self.decay_steps)) decay_factor = (1 - self.alpha) * cos_decay + self.alpha else: decay_factor = self.alpha return base_lr * decay_factor def _generate_conf_for_graph(self, lr_conf): # CosineDecayLR is the same as CosineDecayConf in nn.Graph lr_conf.cosine_conf.SetInParent() cosine_decay_conf = lr_conf.cosine_conf cosine_decay_conf.decay_batches = self.decay_steps cosine_decay_conf.alpha = self.alpha ================================================ FILE: python/oneflow/nn/optimizer/exponential_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class ExponentialLR(LRScheduler): """ Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. gamma (float): Multiplicative factor of learning rate decay. last_step (int): The index of last step. Default: -1. verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``. """ def __init__( self, optimizer: Optimizer, gamma: float, last_step: int = -1, verbose: bool = False, ): assert isinstance(optimizer, Optimizer) if gamma <= 0.0: raise ValueError(f"'gamma' must be greater than zero, but got {gamma}") self.gamma = gamma super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): return base_lr * (self.gamma ** step) def _generate_conf_for_graph(self, lr_conf): lr_conf.step_conf.SetInParent() step_conf = lr_conf.step_conf step_conf.step_size = 1 step_conf.gamma = self.gamma ================================================ FILE: python/oneflow/nn/optimizer/lamb.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
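The decay factor computed by ``CosineDecayLR.get_lr`` earlier in this file can be traced with a few lines of plain Python (illustrative values only):

.. code-block:: python

    import math

    def cosine_decay_factor(step, decay_steps=100, alpha=0.0):
        if step < decay_steps:
            cos_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
            return (1 - alpha) * cos_decay + alpha
        return alpha

    for step in (0, 50, 100):
        print(step, cosine_decay_factor(step))  # 1.0 at step 0, 0.5 halfway, alpha afterwards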
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Callable, Dict, Iterator, List, Union, Tuple import math import oneflow as flow from oneflow.optim.optimizer import Optimizer from oneflow.nn.parameter import Parameter class LAMB(Optimizer): """Implements LAMB algorithm. LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. The equation of parameters updating is: .. math:: & V_t = \\beta_1*V_{t-1} + (1-\\beta_1)*grad & S_t = \\beta_2*S_{t-1} + (1-\\beta_2)*{grad} \\odot {grad} & \\hat{u} = \\frac{{V_t}}{\\sqrt{{S_t}}+\\epsilon} & \\hat{r} = learning\\_rate * \\frac{||param_{old}||_2}{||\\hat{u}||_2} & param_{new} = param_{old} - \\hat{r} * \\hat{u} Args: parameters (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) adam_w_mode (bool, optional): apply L2 regularization or weight decay True for decoupled weight decay (also known as AdamW) (default: True) do_bias_correction (bool, optional): whether to do bias correction (default: True) amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm. NOT SUPPORTED now! (default: False) contiguous_params (bool, optional): whether to use contiguous ParamGroup which puts all parameters of the same type, device and group into the same tensor and update them together. (default: False) .. _Large Batch Optimization for Deep Learning\\: Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962 For example: Example 1: .. code-block:: python # Assume net is a custom model. lamb = flow.optim.LAMB(net.parameters(), lr=1e-3) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() lamb.step() lamb.zero_grad() Example 2: .. code-block:: python # Assume net is a custom model. lamb = flow.optim.LAMB( [ { "params": net.parameters(), "lr": learning_rate, "clip_grad_max_norm": 0.5, "clip_grad_norm_type": 2.0, } ], ) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() lamb.clip_grad() lamb.step() lamb.zero_grad() If you want to use clip_grad, you can refer this example. For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
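For reference, the update rule above can be sketched in plain Python (an illustrative,
single-tensor sketch that omits bias correction, weight decay and AMSGrad; the actual
``step`` dispatches the fused ``lamb_update`` kernel):

.. code-block:: python

    import math

    def lamb_reference_step(param, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        # V_t and S_t: moving averages of the gradient and the squared gradient
        for i, g in enumerate(grad):
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g * g
        # u_hat = V_t / (sqrt(S_t) + eps)
        update = [m[i] / (math.sqrt(v[i]) + eps) for i in range(len(param))]
        # r_hat = lr * ||param||_2 / ||u_hat||_2 (the trust ratio)
        param_norm = math.sqrt(sum(p * p for p in param))
        update_norm = math.sqrt(sum(u * u for u in update))
        trust_ratio = lr * param_norm / update_norm if update_norm > 0 else lr
        # param_new = param_old - r_hat * u_hat
        return [p - trust_ratio * u for p, u in zip(param, update)]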
""" def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, adam_w_mode: bool = True, do_bias_correction: bool = True, amsgrad: bool = False, contiguous_params: bool = False, ): if amsgrad: # TODO: supported amsgrad in Lamb raise RuntimeError("LAMB does not support AMSGrad variant.") assert lr >= 0.0, f"Invalid learning rate: {lr}" assert eps >= 0.0, f"Invalid epsilon value: {eps}" assert ( betas[0] >= 0.0 and betas[0] < 1.0 ), f"Invalid beta parameter at index 0: {betas[0]}" assert ( betas[1] >= 0.0 and betas[1] < 1.0 ), f"Invalid beta parameter at index 1: {betas[1]}" assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" options = dict() options["lr"] = lr options["eps"] = eps options["betas"] = betas options["weight_decay"] = weight_decay options["amsgrad"] = amsgrad options["adam_w_mode"] = adam_w_mode options["bias_correction1"] = 1.0 options["bias_correction2"] = 1.0 options["do_bias_correction"] = do_bias_correction options["contiguous_params"] = contiguous_params super().__init__(params, options) for param_group in self.param_groups: if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() self._op = ( flow.stateful_op("lamb_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") .Build() ) def step(self, closure: Callable = None): """Performs a single optimization step. Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ with flow.no_grad(): loss = None if closure is not None: with flow.enable_grad(): loss = closure() for param_group in self.param_groups: if param_group["do_bias_correction"]: param_group["bias_correction1"] = 1.0 - math.pow( param_group["betas"][0], self.state["step"] + 1 ) param_group["bias_correction2"] = 1.0 - math.pow( param_group["betas"][1], self.state["step"] + 1 ) kwargs = { "learning_rate": param_group["lr"], "bias_correction1": param_group["bias_correction1"], "bias_correction2": param_group["bias_correction2"], "beta1": param_group["betas"][0], "beta2": param_group["betas"][1], "epsilon": param_group["eps"], "do_bias_correction": param_group["do_bias_correction"], } if param_group["adam_w_mode"]: kwargs["weight_decay"] = param_group["weight_decay"] kwargs["l2"] = 0.0 else: kwargs["l2"] = param_group["weight_decay"] kwargs["weight_decay"] = 0.0 if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: if param.grad is None: continue if "exp_avg" not in self.state[param]: self.state[param]["exp_avg"] = flow.zeros_like(param) if "exp_avg_sq" not in self.state[param]: self.state[param]["exp_avg_sq"] = flow.zeros_like(param) m_tensor = self.state[param]["exp_avg"] v_tensor = self.state[param]["exp_avg_sq"] flow._C.dispatch_lamb_update( self._op, (param, param.grad, m_tensor, v_tensor), **kwargs ) self.state["step"] += 1 return loss def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: assert ( param_group["contiguous_params"] != True ), "contiguous_params cannot be used in graph" optimizer_conf = train_conf.optimizer_conf.add() lr = ( param_group["initial_lr"] if "initial_lr" in param_group else param_group["lr"] ) adam_w_mode = 
param_group["adam_w_mode"] weight_decay = param_group["weight_decay"] beta1 = param_group["betas"][0] beta2 = param_group["betas"][1] do_bias_correction = param_group["do_bias_correction"] epsilon = param_group["eps"] optimizer_conf.base_learning_rate = lr self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf) optimizer_conf.lamb_conf.beta1 = beta1 optimizer_conf.lamb_conf.beta2 = beta2 optimizer_conf.lamb_conf.epsilon = epsilon optimizer_conf.lamb_conf.do_bias_correction = do_bias_correction self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) if adam_w_mode: optimizer_conf.weight_decay_conf.weight_decay_rate = weight_decay else: optimizer_conf.weight_decay_conf.weight_decay_rate = 0.0 for param in param_group.parameters: if not adam_w_mode: # Set l2 penalty as weight decay if **NOT** using adam_w_mode vars_conf[param].l2 = weight_decay if param.requires_grad: optimizer_conf.variable_op_names.append(vars_conf[param].name) new_opt_confs.append(optimizer_conf) return new_opt_confs ================================================ FILE: python/oneflow/nn/optimizer/lambda_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import types from .lr_scheduler import LRScheduler class LambdaLR(LRScheduler): """ Sets the learning rate of each parameter group to the initial lr times a given function. When last_step=-1, sets initial lr as lr. .. math:: learning\\_rate = base\\_learning\\_rate*lambda(last\\_step) Args: optimizer(Optimizer): Wrapped optimizer. lr_lambda(function or list): A function which computes a multiplicative factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups. last_step (int, optional): The index of last step. (default: -1) verbose (bool, optional): If ``True``, prints a message to stdout for each update. (default: ``False``) For example: .. code-block:: python import oneflow as flow ... lambda1 = lambda step: step // 30 lambda2 = lambda step: 0.95 * step lambda_lr = flow.optim.lr_scheduler.LambdaLR(optimizer, [lambda1, lambda2]) for epoch in range(num_epoch): train(...) lambda_lr.step() """ def __init__(self, optimizer, lr_lambda, last_step=-1, verbose=False): if not isinstance(lr_lambda, (list, tuple)): self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) else: assert len(lr_lambda) == len( optimizer.param_groups ), f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" self.lr_lambdas = list(lr_lambda) super().__init__(optimizer, last_step, verbose) def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. The learning rate lambda functions will only be saved if they are callable objects and not if they are functions or lambdas. 
""" state_dict = { key: value for (key, value) in self.__dict__.items() if key not in ("optimizer", "lr_lambdas") } state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) for (idx, fn) in enumerate(self.lr_lambdas): if not isinstance(fn, types.FunctionType): state_dict["lr_lambdas"][idx] = fn.__dict__.copy() return state_dict def load_state_dict(self, state_dict): """Loads the schedulers state. Arguments: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ lr_lambdas = state_dict.pop("lr_lambdas") self.__dict__.update(state_dict) state_dict["lr_lambdas"] = lr_lambdas for (idx, fn) in enumerate(lr_lambdas): if fn is not None: self.lr_lambdas[idx].__dict__.update(fn) def step(self): """Performs a single learning rate schedule step. """ self.last_step += 1 lrs = [] for (lmbda, base_lr) in zip(self.lr_lambdas, self.base_lrs): lrs.append(base_lr * lmbda(self.last_step)) self.update_lrs(lrs) ================================================ FILE: python/oneflow/nn/optimizer/lbfgs.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Callable, Dict, Iterator, List, Tuple, Union from functools import reduce from oneflow.optim.optimizer import Optimizer from oneflow.nn.parameter import Parameter import oneflow as flow # TODO implement quadrati_interpolate op def _quadratic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): if bounds is not None: xmin_bound, xmax_bound = bounds else: xmin_bound, xmax_bound = (x1, x2) if x1 < x2 else (x2, x1) if x1 == 0: t_new = -(g1 * (x2 ** 2)) / (2 * (f2 - f1 - g1 * x2)) else: a = -(f1 - f2 - g1 * (x1 - x2)) / ((x1 - x2) ** 2) t_new = x1 - g1 / (2 * a) return min(xmax_bound, max(xmin_bound, t_new)) def _strong_wolfe( eval_closure, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25 ): d_norm = d.abs().max() g = g.clone() f_new, g_new = eval_closure(x, t, d) ls_func_evals = 1 gtd_new = g_new.dot(d) t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd done = False ls_iter = 0 while ls_iter < max_ls: if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new > f_prev): search_area = [t_prev, t] search_area_f = [f_prev, f_new] search_area_g = [g_prev, g_new.clone()] search_area_gtd = [gtd_prev, gtd_new] break if abs(gtd_new) <= -c2 * gtd: search_area = [t] search_area_f = [f_new] search_area_g = [g_new] done = True break if gtd_new >= 0: search_area = [t_prev, t] search_area_f = [f_prev, f_new] search_area_g = [g_prev, g_new.clone()] search_area_gtd = [gtd_prev, gtd_new] min_step = t + 0.01 * (t - t_prev) max_step = t * 10 tmp = t t = _quadratic_interpolate( t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step) ) t_prev = tmp f_prev = f_new g_prev = g_new.clone() gtd_prev = gtd_new f_new, g_new = eval_closure(x, t, d) ls_func_evals += 1 gtd_new = g_new.dot(d) ls_iter += 1 if ls_iter == max_ls: search_area = [0, t] search_area_f = [f, f_new] search_area_g = [g, g_new] # zoom low_pos, high_pos = (0, 1) 
if search_area_f[0] <= search_area_f[-1] else (1, 0) while not done and ls_iter < max_ls: if abs(search_area[1] - search_area[0]) * d_norm < tolerance_change: break t = _quadratic_interpolate( search_area[0], search_area_f[0], search_area_gtd[0], search_area[1], search_area_f[1], search_area_gtd[1], ) f_new, g_new = eval_closure(x, t, d) ls_func_evals += 1 gtd_new = g_new.dot(d) ls_iter += 1 if f_new > (f + c1 * t * gtd) or f_new >= search_area_f[low_pos]: search_area[high_pos] = t search_area_f[high_pos] = f_new search_area_g[high_pos] = g_new.clone() search_area_gtd[high_pos] = gtd_new low_pos, high_pos = ( (0, 1) if search_area_f[0] <= search_area_f[1] else (1, 0) ) if abs(gtd_new) <= -c2 * gtd: done = True elif gtd_new * (search_area[high_pos] - search_area[low_pos]) >= 0: search_area[high_pos] = search_area[low_pos] search_area_f[high_pos] = search_area_f[low_pos] search_area_g[high_pos] = search_area_g[low_pos] search_area_gtd[high_pos] = search_area_gtd[low_pos] search_area[low_pos] = t search_area_f[low_pos] = f_new search_area_g[low_pos] = g_new.clone() search_area_gtd[low_pos] = gtd_new t = search_area[low_pos] f_new = search_area_f[low_pos] g_new = search_area_g[low_pos] return f_new, g_new, t, ls_func_evals class LBFGS(Optimizer): """Implements LBFGS algorithm It has been propose in `On the limited memory BFGS method for large scale optimization`_. The implementation of the two-loop recursion proposed in `Updating Quasi-Newton Matrices with Limited Storage`_. The implementation of the strong_wolfe line search proposed in `Numerical_Optimization_v2` This algorithm uses an estimated inverse Hessian matrix to steer its search through variable space and determine the optimal direction. The line search algorithm terminates with a step length that satisfies the strong Wolfe conditions. This optimizer only support one parameter group. Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) max_iter (int,optional): max iteration per step (default: 20) max_eval (int,optional): max func evals per step (default: max_iter * 1.25) tolerance_grad (float, optional): termination tolerance on first order optimality (default 1e-7) tolerance_change (float, optional): termination tolerance on paramter changes (default: 1e-9) history_size (int,optional): paramter update history size (default: 100) line_search_fn (str,optional): line search function `strong_wolfe` or None (default: None) contiguous_params (bool, optional): whether to use contiguous ParamGroup which puts all parameters of the same type, device and group into the same tensor and update them together. (default: False) .. _On the limited memory BFGS method for large scale optimization: https://dl.acm.org/doi/10.5555/3112655.3112866 .. _Updating Quasi-Newton Matrices with Limited Storage: https://www.ams.org/journals/mcom/1980-35-151/S0025-5718-1980-0572855-7/S0025-5718-1980-0572855-7.pdf For example: .. code-block:: python # Assume net is a custom model. lbfgs = flow.optim.LBFGS(net.parameters()) for epoch in range (epochs): def closure(): lbfgs.zero_grad() # Read data, Compute the loss and so on. 
loss.backward() return loss lbfgs.step(closure) """ def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 0.001, max_iter: int = 20, max_eval: int = None, tolerance_grad: float = 1e-7, tolerance_change: float = 1e-9, history_size: int = 100, line_search_fn=None, contiguous_params: bool = False, ): if max_eval is None: max_eval = max_iter * 1.25 options = dict() options["lr"] = lr options["max_iter"] = max_iter options["max_eval"] = max_eval options["tolerance_grad"] = tolerance_grad options["tolerance_change"] = tolerance_change options["history_size"] = history_size options["line_search_fn"] = line_search_fn options["contiguous_params"] = contiguous_params super().__init__(params, options) assert ( len(self.param_groups) == 1 ), "LBFGS not support parameter groups (there can be only one)" param_group = self.param_groups[0] if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: assert param.is_leaf, "parameters must be leaf tensor" self._params = param_list self._numel_cache = None def _gather_flat_grad(self): views = [] for p in self._params: if p.grad is None: view = p.new(p.numel()).zero_() else: view = p.grad.view(-1) views.append(view) return flow.cat(views, 0) def _numel(self): # get parameters total numel if self._numel_cache is None: self._numel_cache = reduce( lambda totnumel, p: totnumel + p.numel(), self._params, 0, ) return self._numel_cache def _update(self, step_size, direction): # update parameters offset = 0 for p in self._params: numel = p.numel() p.add_(direction[offset : offset + numel].view_as(p), alpha=step_size) offset += numel assert offset == self._numel() def _try_direction(self, closure, x, t, d): self._update(t, d) with flow.enable_grad(): loss = float(closure()) flag_grad = self._gather_flat_grad() for p, data in zip(self._params, x): p.copy_(data) return loss, flag_grad def step(self, closure: Callable = None): """Performs a single optimization step. Args: closure (callable): A closure that reevaluates the model and returns the loss. 
""" with flow.no_grad(): assert closure != None, "closure must not be None" param_group = self.param_groups[0] lr = param_group["lr"] max_iter = param_group["max_iter"] max_eval = param_group["max_eval"] tolerance_grad = param_group["tolerance_grad"] tolerance_change = param_group["tolerance_change"] line_search_fn = param_group["line_search_fn"] history_size = param_group["history_size"] state = self.state[self._params[0]] state.setdefault("func_evals", 0) state.setdefault("n_iter", 0) with flow.enable_grad(): origin_loss = closure() loss = float(origin_loss) current_evals = 1 state["func_evals"] += 1 flat_grad = self._gather_flat_grad() if flat_grad.abs().max() <= tolerance_grad: return origin_loss # prev state d = state.get("d") t = state.get("t") old_diffs = state.get("old_diffs") old_step_size = state.get("old_step_size") ro = state.get("ro") H_diag = state.get("H_diag") prev_flat_grad = state.get("prev_flat_grad") prev_loss = state.get("prev_loss") n_iter = 0 while n_iter < max_iter: n_iter += 1 state["n_iter"] += 1 # compute direction if state["n_iter"] == 1: d = flat_grad.neg() old_diffs = [] old_step_size = [] ro = [] H_diag = 1 else: y = flat_grad.sub(prev_flat_grad) s = d.mul(t) ys = y.dot(s) # ys must be positive if ys > 1e-10: if len(old_diffs) == history_size: old_diffs.pop(0) old_step_size.pop(0) ro.pop(0) old_diffs.append(y) old_step_size.append(s) ro.append(1.0 / ys) H_diag = ys / y.dot(y) num_old = len(old_diffs) if "alpha" not in state: state["alpha"] = [None] * history_size alpha = state["alpha"] q = flat_grad.neg() for i in range(num_old - 1, -1, -1): alpha[i] = old_step_size[i].dot(q) * ro[i] q.add_(old_diffs[i], alpha=-alpha[i]) d = q.mul(H_diag) for i in range(num_old): beta_i = old_diffs[i].dot(d) * ro[i] d.add_(old_step_size[i], alpha=alpha[i] - beta_i) # compute step size if prev_flat_grad is None: prev_flat_grad = flat_grad.clone() else: prev_flat_grad.copy_(flat_grad) prev_loss = loss if state["n_iter"] == 1: t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr else: t = lr gtd = flat_grad.dot(d) if gtd > -tolerance_change: break ls_func_evals = 0 if line_search_fn is None: self._update(t, d) if n_iter != max_iter: with flow.enable_grad(): loss = float(closure()) flat_grad = self._gather_flat_grad() ls_func_evals = 1 else: assert ( line_search_fn == "strong_wolfe" ), "only strong_wolfe is expected" init_param = [p.clone() for p in self._params] def eval_func(x, t, d): return self._try_direction(closure, x, t, d) loss, flat_grad, t, ls_func_evals = _strong_wolfe( eval_func, init_param, t, d, loss, flat_grad, gtd ) self._update(t, d) current_evals += ls_func_evals state["func_evals"] += ls_func_evals if n_iter == max_iter: break if current_evals >= max_eval: break if flat_grad.abs().max() <= tolerance_grad: break if d.mul(t).abs().max() <= tolerance_change: break if abs(loss - prev_loss) < tolerance_change: break state["d"] = d state["t"] = t state["old_diffs"] = old_diffs state["old_step_size"] = old_step_size state["ro"] = ro state["prev_flat_grad"] = prev_flat_grad state["prev_loss"] = prev_loss state["H_diag"] = H_diag return origin_loss ================================================ FILE: python/oneflow/nn/optimizer/linear_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class LinearLR(LRScheduler): """Decays the learning rate of each parameter group by linearly changing small multiplicative factor until the number of step reaches a pre-defined milestone: total_iters. Args: optimizer (Optimizer): Wrapped optimizer. start_factor (float): The number we multiply learning rate in the first step. The multiplication factor changes towards end_factor in the following steps. Default: 1./3. end_factor (float): The number we multiply learning rate at the end of linear changing process. Default: 1.0. total_iters (int): The number of iterations that multiplicative factor reaches to 1. Default: 5. last_step (int): The index of the last step. Default: -1. verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``. Example: >>> # Assuming optimizer uses lr = 0.05 for all groups >>> # lr = 0.025 if step == 0 >>> # lr = 0.03125 if step == 1 >>> # lr = 0.0375 if step == 2 >>> # lr = 0.04375 if step == 3 >>> # lr = 0.05 if step >= 4 >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) >>> for step in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() """ def __init__( self, optimizer: Optimizer, start_factor: float = 1.0 / 3, end_factor: float = 1.0, total_iters: int = 5, last_step: int = -1, verbose: bool = False, ): assert isinstance(optimizer, Optimizer) if start_factor > 1.0 or start_factor < 0: raise ValueError( "Starting multiplicative factor expected to be between 0 and 1." ) if end_factor > 1.0 or end_factor < 0: raise ValueError( "Ending multiplicative factor expected to be between 0 and 1." ) self.start_factor = start_factor self.end_factor = end_factor self.total_iters = total_iters super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): if step < self.total_iters: multiplier = self.start_factor + (self.end_factor - self.start_factor) * ( step / self.total_iters ) else: multiplier = self.end_factor return base_lr * multiplier def _generate_conf_for_graph(self, lr_conf): lr_conf.linear_lr_conf.SetInParent() linear_lr_conf = lr_conf.linear_lr_conf linear_lr_conf.start_factor = self.start_factor linear_lr_conf.end_factor = self.end_factor linear_lr_conf.total_iters = self.total_iters ================================================ FILE: python/oneflow/nn/optimizer/lr_scheduler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from ...optim.optimizer import Optimizer class LRScheduler(object): def __init__( self, optimizer: Optimizer, last_step: int = -1, verbose: bool = False ): if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer object") self.optimizer = optimizer self.last_step = last_step self.verbose = verbose self._init_base_lrs() self.step() def state_dict(self): """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. """ return { key: value for (key, value) in self.__dict__.items() if key != "optimizer" } def load_state_dict(self, state_dict): """Load the schedulers state. Arguments: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ self.__dict__.update(state_dict) def get_lr(self, base_lr, step): """Compute learning rate using chainable form of the scheduler""" raise NotImplementedError def get_last_lr(self): """Return last computed learning rate by current scheduler.""" return self._last_lr def print_lr(self, group, lr): """Display the current learning rate.""" print( f"Last step {self.last_step} of {type(self)} adjusting learning rate " f"of param_groups[{group}] to {lr:.5f}" ) def step(self): self.last_step += 1 lrs = [self.get_lr(base_lr, self.last_step) for base_lr in self.base_lrs] self.update_lrs(lrs) def update_lrs(self, lrs): self._last_lr = [] for i, (group, lr) in enumerate(zip(self.optimizer.param_groups, lrs)): group["lr"] = lr self._last_lr.append(lr) if self.verbose: self.print_lr(i, lr) def _init_base_lrs(self): if self.last_step == -1: for group in self.optimizer.param_groups: if "initial_lr" not in group: group.setdefault("initial_lr", group["lr"]) else: for (i, group) in enumerate(self.optimizer.param_groups): if "initial_lr" not in group: raise KeyError( "param 'initial_lr' is not specified " f"in param_groups[{i}] when resuming an optimizer" ) self.base_lrs = [group["initial_lr"] for group in self.optimizer.param_groups] ================================================ FILE: python/oneflow/nn/optimizer/multiplicative_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class MultiplicativeLR(LRScheduler): """Multiply the learning rate of each parameter group by the factor given in the specified function. When last_epoch=-1, sets initial lr as lr. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiplicativeLR Args: optimizer (Optimizer): Wrapped optimizer. lr_lambda (function or list): A function which computes a multiplicative factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups. last_step (int): The index of last step. Default: -1. verbose (bool): If ``True``, prints a message to stdout for each update. 
Default: ``False``. For example: .. code-block:: python import oneflow as flow ... lmbda = lambda epoch: 0.95 step_lr = flow.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda) for epoch in range(num_epoch): train(...) step_lr.step() """ def __init__(self, optimizer, lr_lambda, last_step=-1, verbose=False): self.optimizer = optimizer if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) else: if len(lr_lambda) != len(optimizer.param_groups): raise ValueError( "Expected {} lr_lambdas, but got {}".format( len(optimizer.param_groups), len(lr_lambda) ) ) self.lr_lambdas = list(lr_lambda) super().__init__(optimizer, last_step, verbose) def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. The learning rate lambda functions will only be saved if they are callable objects and not if they are functions or lambdas. """ state_dict = { key: value for key, value in self.__dict__.items() if key not in ("optimizer", "lr_lambdas") } state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) for idx, fn in enumerate(self.lr_lambdas): if not isinstance(fn, types.FunctionType): state_dict["lr_lambdas"][idx] = fn.__dict__.copy() return state_dict def load_state_dict(self, state_dict): """Loads the schedulers state. Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ lr_lambdas = state_dict.pop("lr_lambdas") self.__dict__.update(state_dict) state_dict["lr_lambdas"] = lr_lambdas for idx, fn in enumerate(lr_lambdas): if fn is not None: self.lr_lambdas[idx].__dict__.update(fn) def step(self): """Performs a single learning rate schedule step. """ self.last_step += 1 if self.last_step > 0: lrs = [ group["lr"] * lmbda(self.last_step) for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) ] else: lrs = [group["lr"] for group in self.optimizer.param_groups] self.update_lrs(lrs) ================================================ FILE: python/oneflow/nn/optimizer/multistep_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import bisect from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class MultiStepLR(LRScheduler): """ Decays the learning rate of each parameter group by gamma once the number of step reaches one of the milestones. Notice that such decay can happen simultaneously with other changes to the learning rate from outside this scheduler.When last_step=-1, sets initial lr as lr. Args: optimizer(Optimizer): Wrapped optimizer. milestones(list): List of step indices. Must be increasing gamma (float, optional): Multiplicative factor of learning rate decay. (default: 0.1) last_step (int, optional): The index of last step. (default: -1) verbose (bool, optional): If ``True``, prints a message to stdout for each update. 
(default: ``False``) For example: .. code-block:: python import oneflow as flow ... multistep_lr = flow.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) for epoch in range(num_epoch): train(...) multistep_lr.step() """ def __init__( self, optimizer: Optimizer, milestones: list, gamma: float = 0.1, last_step: int = -1, verbose: bool = False, ): for i in range(1, len(milestones)): assert ( milestones[i] > milestones[i - 1] ), f"values in `list` milestone must be increasing, but got {milestones}" assert gamma > 0.0, f"gamma must greater than zero, but got {gamma}" self.milestones = milestones self.gamma = gamma super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): sect = bisect.bisect_right(self.milestones, step) factor = self.gamma ** sect return base_lr * factor def _generate_conf_for_graph(self, lr_conf): lr_conf.multi_step_conf.SetInParent() multi_step_conf = lr_conf.multi_step_conf for milestone in self.milestones: multi_step_conf.milestones.append(milestone) multi_step_conf.gamma = self.gamma ================================================ FILE: python/oneflow/nn/optimizer/polynomial_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math from .lr_scheduler import LRScheduler class PolynomialLR(LRScheduler): r""" This operator creates a polynomial decayed learning rate scheduler. The learning rate will be updated as follows: If cycle is `True`, the equation is: .. math:: \begin{aligned} & decay\_batch = decay\_batch*ceil(\frac{current\_batch}{decay\_batch}) \\ & learning\_rate = (base\_lr-end\_lr)*(1-\frac{current\_batch}{decay\_batch})^{power}+end\_lr \end{aligned} If cycle is `False`, the equation is: .. math:: \begin{aligned} & current\_batch = min(decay\_batch, current\_batch) \\ & learning\_rate = (base\_lr-end\_lr)*(1-\frac{current\_batch}{decay\_batch})^{power}+end\_lr \end{aligned} Args: optimizer (Optimizer): Wrapper optimizer. decay_batch (int): The decayed steps. end_learning_rate (float, optional): The final learning rate. Defaults to 0.0001. power (float, optional): The power of polynomial. Defaults to 1.0. cycle (bool, optional): If cycle is True, the scheduler will decay the learning rate every decay steps. Defaults to False. For example: .. code-block:: python import oneflow as flow ... polynomial_scheduler = flow.optim.lr_scheduler.PolynomialLR( optimizer, decay_batch=5, end_learning_rate=0.00001, power=2 ) for epoch in range(num_epoch): train(...) 
polynomial_scheduler.step() """ def __init__( self, optimizer, decay_batch: int, end_learning_rate: float = 0.0001, power: float = 1.0, cycle: bool = False, last_step: int = -1, verbose: bool = False, ): assert ( decay_batch > 0 ), f"decay_batch must greater than zero, but got {decay_batch}" self.max_decay_steps = decay_batch self.end_learning_rate = end_learning_rate self.power = power self.cycle = cycle super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): decay_batch = self.max_decay_steps cur_batch = step if self.cycle: if cur_batch == 0: cur_batch = 1 decay_batch = decay_batch * math.ceil(cur_batch / decay_batch) else: cur_batch = min(cur_batch, decay_batch) factor = (1 - cur_batch / decay_batch) ** (self.power) return (base_lr - self.end_learning_rate) * factor + self.end_learning_rate def _generate_conf_for_graph(self, lr_conf): lr_conf.polynomial_conf.SetInParent() polynomial_conf = lr_conf.polynomial_conf polynomial_conf.decay_batches = self.max_decay_steps polynomial_conf.end_learning_rate = self.end_learning_rate polynomial_conf.power = self.power polynomial_conf.cycle = self.cycle ================================================ FILE: python/oneflow/nn/optimizer/reduce_lr_on_plateau.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from math import inf from ...optim.optimizer import Optimizer class ReduceLROnPlateau(object): """Reduce learning rate when a metric has stopped improving. Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates. This scheduler reads a metrics quantity and if no improvement is seen for a 'patience' number of epochs, the learning rate is reduced. Args: optimizer (Optimizer): Wrapped optimizer. mode (str): One of `min`, `max`. In `min` mode, lr will be reduced when the quantity monitored has stopped decreasing; in `max` mode it will be reduced when the quantity monitored has stopped increasing. Default: 'min'. factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1. patience (int): Number of epochs with no improvement after which learning rate will be reduced. For example, if `patience = 2`, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn't improved then. Default: 10. threshold (float): Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. threshold_mode (str): One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * ( 1 + threshold ) in 'max' mode or best * ( 1 - threshold ) in `min` mode. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or best - threshold in `min` mode. Default: 'rel'. cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0. min_lr (float or list): A scalar or a list of scalars. 
A lower bound on the learning rate of all param groups or each group respectively. Default: 0. eps (float): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``. For example: .. code-block:: python optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) scheduler = flow.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') for epoch in range(10): train(...) val_loss = validate(...) # Note that step should be called after validate() scheduler.step(val_loss) """ def __init__( self, optimizer, mode="min", factor=0.1, patience=10, threshold=1e-4, threshold_mode="rel", cooldown=0, min_lr=0, eps=1e-8, verbose=False, ): if factor >= 1.0: raise ValueError("Factor should be < 1.0.") self.factor = factor # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer if isinstance(min_lr, list) or isinstance(min_lr, tuple): if len(min_lr) != len(optimizer.param_groups): raise ValueError( "expected {} min_lrs, got {}".format( len(optimizer.param_groups), len(min_lr) ) ) self.min_lrs = list(min_lr) else: self.min_lrs = [min_lr] * len(optimizer.param_groups) self.patience = patience self.verbose = verbose self.cooldown = cooldown self.cooldown_counter = 0 self.mode = mode self.threshold = threshold self.threshold_mode = threshold_mode self.best = None self.num_bad_steps = None self.mode_worse = None # the worse value for the chosen mode self.eps = eps self.last_step = 0 self._init_is_better( mode=mode, threshold=threshold, threshold_mode=threshold_mode ) self._reset() def step(self, metrics): """Performs a single learning rate schedule step. Arguments: metrics (float): a metrics quantity of Measuring the effect of model training. """ # convert `metrics` to float, in case it's a zero-dim Tensor current = float(metrics) self.last_step = self.last_step + 1 if self.is_better(current, self.best): self.best = current self.num_bad_steps = 0 else: self.num_bad_steps += 1 if self.in_cooldown: self.cooldown_counter -= 1 self.num_bad_steps = 0 # ignore any bad epochs in cooldown if self.num_bad_steps > self.patience: self._reduce_lr(self.last_step) self.cooldown_counter = self.cooldown self.num_bad_steps = 0 self._last_lr = [group["lr"] for group in self.optimizer.param_groups] @property def in_cooldown(self): """Whether the learning rate scheduler in cooldown phase. """ return self.cooldown_counter > 0 def is_better(self, a, best): """Whether the metric has improvement. """ if self.mode == "min" and self.threshold_mode == "rel": rel_epsilon = 1.0 - self.threshold return a < best * rel_epsilon elif self.mode == "min" and self.threshold_mode == "abs": return a < best - self.threshold elif self.mode == "max" and self.threshold_mode == "rel": rel_epsilon = self.threshold + 1.0 return a > best * rel_epsilon else: # mode == 'max' and epsilon_mode == 'abs': return a > best + self.threshold def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. """ return { key: value for key, value in self.__dict__.items() if key != "optimizer" } def load_state_dict(self, state_dict): """Loads the schedulers state. Arguments: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. 
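For example (an illustrative sketch of saving and restoring the scheduler state):

.. code-block:: python

    state = scheduler.state_dict()
    # ... later, e.g. after re-creating the scheduler with the same arguments ...
    scheduler.load_state_dict(state)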
""" self.__dict__.update(state_dict) self._init_is_better( mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode ) def _reduce_lr(self, epoch): for i, param_group in enumerate(self.optimizer.param_groups): old_lr = float(param_group["lr"]) new_lr = max(old_lr * self.factor, self.min_lrs[i]) if old_lr - new_lr > self.eps: param_group["lr"] = new_lr if self.verbose: print( "Epoch {:5d}: reducing learning rate" " of group {} to {:.4e}.".format(epoch, i, new_lr) ) def _reset(self): """Resets num_bad_steps counter and cooldown counter.""" self.best = self.mode_worse self.cooldown_counter = 0 self.num_bad_steps = 0 def _init_is_better(self, mode, threshold, threshold_mode): if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") if threshold_mode not in {"rel", "abs"}: raise ValueError("threshold mode " + threshold_mode + " is unknown!") if mode == "min": self.mode_worse = inf else: # mode == 'max': self.mode_worse = -inf self.mode = mode self.threshold = threshold self.threshold_mode = threshold_mode ================================================ FILE: python/oneflow/nn/optimizer/rmsprop.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import collections from typing import Callable, Dict, Iterator, List, Union import oneflow as flow from oneflow.optim.optimizer import Optimizer, ParamGroup from oneflow.nn.parameter import Parameter class RMSprop(Optimizer): """Implements RMSprop algorithm. oot Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method. The original slides proposed RMSProp: Slide 29 of http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf . The original equation is as follows: .. math:: r(w, t) = \\alpha r(w, t-1) + (1 - \\alpha)(\\nabla Q_{i}(w))^2 W = w - \\frac{\\eta} {\\\\sqrt{r(w,t) + \\epsilon}} \\nabla Q_{i}(w) The first equation calculates moving average of the squared gradient for each weight. Then dividing the gradient by :math:`sqrt{v(w,t)}`. In some cases, adding a momentum term :math: `\\beta` is beneficial. In our implementation, Nesterov momentum is used: .. math:: r(w, t) = \\alpha r(w, t-1) + (1 - \\alpha)(\\nabla Q_{i}(w))^2 v(w, t) = \\beta v(w, t-1) + \\frac{\\eta} {\\\\sqrt{r(w,t) + \\epsilon}} \\nabla Q_{i}(w) w = w - v(w, t) if centered is True: .. math:: r(w, t) = \\alpha r(w, t-1) + (1 - \\alpha)(\\nabla Q_{i}(w))^2 g(w, t) = \\alpha g(w, t-1) + (1 - \\alpha)\\nabla Q_{i}(w) v(w, t) = \\beta v(w, t-1) + \\frac{\\eta} {\\\\sqrt{r(w,t) - (g(w, t))^2 + \\epsilon}} \\nabla Q_{i}(w) w = w - v(w, t) where, :math:`\\alpha` is a hyperparameter and typical values are 0.99, 0.95 and so on. :math:`\\beta` is the momentum term. :math:`\\epsilon` is a smoothing term to avoid division by zero, usually set somewhere in range from 1e-4 to 1e-8. 
Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-2) momentum (float, optional): momentum factor (default: 0, oneflow not support momenmtum > 0 now!) alpha (float, optional): smoothing constant (default: 0.99) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) centered (bool, optional) : if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance weight_decay (float, optional): weight decay (L2 penalty) (default: 0) contiguous_params (bool, optional): whether to use contiguous ParamGroup which puts all parameters of the same type, device and group into the same tensor and update them together. (default: False) For example: Example 1: .. code-block:: python # Assume net is a custom model. rmsprop = flow.optim.RMSprop(net.parameters(), lr=1e-3) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() rmsprop.step() rmsprop.zero_grad() Example 2: .. code-block:: python # Assume net is a custom model. rmsprop = flow.optim.RMSprop( [ { "params": net.parameters(), "lr": learning_rate, "clip_grad_max_norm": 0.5, "clip_grad_norm_type": 2.0, } ], ) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() rmsprop.clip_grad() rmsprop.step() rmsprop.zero_grad() If you want to use clip_grad, you can refer this example. For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. """ def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 0.001, alpha: float = 0.99, eps: float = 1e-08, weight_decay: float = 0, momentum: float = 0.0, centered: bool = False, contiguous_params: bool = False, ): assert lr >= 0.0, f"Invalid learning rate: {lr}" assert alpha >= 0.0, f"Invalid alpha value: {alpha}" assert eps >= 0.0, f"Invalid epsilon value: {eps}" assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" assert momentum == 0.0, "Not support momentum greater than zeros now!" options = dict() options["lr"] = lr options["alpha"] = alpha options["eps"] = eps options["weight_decay"] = weight_decay options["centered"] = centered options["contiguous_params"] = contiguous_params super().__init__(params, options) for param_group in self.param_groups: if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() self._centered_rmsprop = ( flow.stateful_op("rmsprop_update") .Input("model") .Input("model_diff") .Input("mean_square") .Input("mean_gradient") .Build() ) self._rmsprop = ( flow.stateful_op("rmsprop_update") .Input("model") .Input("model_diff") .Input("mean_square") .Build() ) def step(self, closure: Callable = None): """Performs a single optimization step. Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. 
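For reference, the per-parameter update applied here (with ``momentum == 0`` and
``centered=False``) can be sketched in plain Python (illustrative only; the actual
implementation dispatches the fused ``rmsprop_update`` kernel):

.. code-block:: python

    import math

    def rmsprop_reference_step(param, grad, square_avg, lr=1e-3, alpha=0.99, eps=1e-8):
        # r(w, t) = alpha * r(w, t-1) + (1 - alpha) * grad ** 2
        new_square_avg = [alpha * r + (1 - alpha) * g * g for r, g in zip(square_avg, grad)]
        # w = w - lr * grad / sqrt(r(w, t) + eps)
        new_param = [p - lr * g / math.sqrt(r + eps)
                     for p, g, r in zip(param, grad, new_square_avg)]
        return new_param, new_square_avg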
""" with flow.no_grad(): loss = None if closure is not None: with flow.enable_grad(): loss = closure() for param_group in self.param_groups: kwargs = { "learning_rate": param_group["lr"], "epsilon": param_group["eps"], "decay_rate": param_group["alpha"], "l2": param_group["weight_decay"], } if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: if param.grad is None: continue if "square_avg" not in self.state[param]: self.state[param]["square_avg"] = flow.zeros_like(param) ms_tensor = self.state[param]["square_avg"] if param_group["centered"]: if "grad_avg" not in self.state[param]: self.state[param]["grad_avg"] = flow.zeros_like(param) mg_tensor = self.state[param]["grad_avg"] flow._C.dispatch_rmsprop_update( self._centered_rmsprop, (param, param.grad, ms_tensor, mg_tensor), centered=True, **kwargs, ) else: flow._C.dispatch_rmsprop_update( self._rmsprop, (param, param.grad, ms_tensor), **kwargs ) self.state["step"] = self.state["step"] + 1 return loss def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: assert ( param_group["contiguous_params"] != True ), "contiguous_params cannot be used in graph" optimizer_conf = train_conf.optimizer_conf.add() lr = ( param_group["initial_lr"] if "initial_lr" in param_group else param_group["lr"] ) decay_rate = param_group["alpha"] centered = param_group["centered"] weight_decay = param_group["weight_decay"] epslion = param_group["eps"] optimizer_conf.base_learning_rate = lr self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf) optimizer_conf.rmsprop_conf.decay_rate = decay_rate optimizer_conf.rmsprop_conf.centered = centered optimizer_conf.rmsprop_conf.epsilon = epslion self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) # Set l2 penalty as weight decay for param in param_group.parameters: vars_conf[param].l2 = weight_decay if param.requires_grad: optimizer_conf.variable_op_names.append(vars_conf[param].name) new_opt_confs.append(optimizer_conf) return new_opt_confs ================================================ FILE: python/oneflow/nn/optimizer/sequential_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import bisect from typing import Sequence, Union from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class SequentialLR(LRScheduler): """Receives the list of schedulers that is expected to be called sequentially during optimization process and milestone points that provides exact intervals to reflect which scheduler is supposed to be called at a given step. Args: optimizer (Optimizer): Wrapped optimizer. schedulers (list): List of chained schedulers. milestones (list): List of integers that reflects milestone points. interval_rescaling (bool or list): Each scheduler has a corresponding 'interval_rescaling'. 
If it is set to True, scheduler will start and end at the same values as it would if it were the only scheduler, otherwise all schedulers share the same step. Default is False for all schedulers. last_step (int): The index of last step. Default: -1. verbose (bool): Default: False. Print lr if is set to True. Example: >>> # Assuming optimizer uses lr = 1. for all groups >>> # lr = 0.1 if step == 0 >>> # lr = 0.1 if step == 1 >>> # lr = 0.9 if step == 2 >>> # lr = 0.81 if step == 3 >>> # lr = 0.729 if step == 4 >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) >>> for step in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() """ def __init__( self, optimizer: Optimizer, schedulers: Sequence[LRScheduler], milestones: Sequence[int], interval_rescaling: Union[Sequence[bool], bool] = False, last_step: int = -1, verbose: bool = False, ): assert isinstance(optimizer, Optimizer) assert isinstance(schedulers, (list, tuple)) assert isinstance(milestones, (list, tuple)) if len(schedulers) == 0: raise ValueError("Sequential Schedulers expects at least one scheduler") for i in range(len(schedulers)): if schedulers[i].optimizer != optimizer: raise ValueError( "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " f"got schedulers at index {i} to be different than the optimizer passed in." ) if len(milestones) != len(schedulers) - 1: raise ValueError( f"Sequential Schedulers expects number of schedulers provided to be one more " f"than the number of milestone points, but got number of schedulers {len(schedulers)} " f"and the number of milestones to be equal to {len(milestones)}" ) if isinstance(interval_rescaling, (list, tuple)): if len(interval_rescaling) != len(milestones): raise ValueError( "'interval_rescaling' expects a bool or a list of bool with length be equal to " f"the number of milestones, but got number of milestones {len(milestones)} " f"and the length of list of interval_rescaling {len(interval_rescaling)}" ) assert all([isinstance(r, bool) for r in interval_rescaling]) else: assert isinstance(interval_rescaling, bool) interval_rescaling = [interval_rescaling] * (len(milestones)) self.schedulers = list(schedulers) self.milestones = list(milestones) self.interval_rescaling = list(interval_rescaling) super().__init__(optimizer, last_step, verbose) def step(self): self.last_step += 1 cur_step = self.last_step s_i = bisect.bisect_right(self.milestones, cur_step) if s_i > 0 and self.interval_rescaling[s_i - 1]: cur_step = self.last_step - self.milestones[s_i - 1] scheduler = self.schedulers[s_i] scheduler.last_step = cur_step lrs = [scheduler.get_lr(base_lr, cur_step) for base_lr in self.base_lrs] self.update_lrs(lrs) def state_dict(self): # exclude optimizer and nested schedulers state_dict = { key: value for key, value in self.__dict__.items() if key not in ("optimizer", "schedulers") } state_dict["schedulers"] = [None] * len(self.schedulers) for i, s in enumerate(self.schedulers): state_dict["schedulers"][i] = s.state_dict() return state_dict def load_state_dict(self, state_dict): scheduler_states = state_dict.pop("schedulers") self.__dict__.update(state_dict) # avoid side effect of calling load_state_dict twice state_dict["schedulers"] = scheduler_states for i, s in enumerate(scheduler_states): self.schedulers[i].load_state_dict(s) def _generate_conf_for_graph(self, lr_conf): 
lr_conf.sequential_scheduler_conf.SetInParent() seq_lr_conf = lr_conf.sequential_scheduler_conf for scheduler in self.schedulers: scheduler._generate_conf_for_graph(seq_lr_conf.schedulers.add()) for m in self.milestones: seq_lr_conf.milestones.append(m) for r in self.interval_rescaling: seq_lr_conf.interval_rescaling.append(r) ================================================ FILE: python/oneflow/nn/optimizer/sgd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings from typing import Callable, Dict, Iterator, List, Union import oneflow as flow from oneflow.nn.parameter import Parameter from ...optim.optimizer import Optimizer, ParamGroup class SGD(Optimizer): """Implements SGD algorithm. This algorithm takes a random sample's gradient as an approximate estimate of the overall gradient in small batch gradient descent. When the momentum = 0, the equation of parameters updating is: .. math:: param_{new} = param_{old} - learning\\_rate * grad With momentum, the equation of parameters updating is: .. math:: & V_t = \\beta * V_{t-1} - learning\\_rate * (g_t + param_{old} * weight\\_decay) & param_{new} = param_{old} + V_t Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) momentum (float, optional): Momentum factor (default: 0.0) weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) contiguous_params (bool, optional): whether to use contiguous ParamGroup which puts all parameters of the same type, device and group into the same tensor and update them together. (default: False) fused (bool, optional): whether to divide all the parameters into several groups, then update each group of parameters with the fused kernel. (default: False) For example: Example 1: .. code-block:: python # Assume net is a custom model. sgd = flow.optim.SGD(net.parameters(), lr=1e-3) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() sgd.step() sgd.zero_grad() Example 2: .. code-block:: python # Assume net is a custom model. sgd = flow.optim.SGD( [ { "params": net.parameters(), "lr": learning_rate, "clip_grad_max_norm": 0.5, "clip_grad_norm_type": 2.0, } ], ) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() sgd.clip_grad() sgd.step() sgd.zero_grad() If you want to use clip_grad, you can refer this example. For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
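For reference, the momentum update above can be sketched in plain Python (an
illustrative, single-tensor sketch; the actual ``step`` dispatches the fused
``sgd_update``/``momentum_update`` kernels):

.. code-block:: python

    def sgd_momentum_reference_step(param, grad, velocity, lr=1e-3, beta=0.9, weight_decay=0.0):
        # V_t = beta * V_{t-1} - lr * (g_t + param_old * weight_decay)
        new_velocity = [beta * v - lr * (g + p * weight_decay)
                        for v, g, p in zip(velocity, grad, param)]
        # param_new = param_old + V_t
        new_param = [p + v for p, v in zip(param, new_velocity)]
        return new_param, new_velocity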
""" def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 0.001, momentum: float = 0.0, dampening: float = 0.0, weight_decay: float = 0.0, nesterov: bool = False, maximize: bool = False, contiguous_params: bool = False, fused: bool = False, ): assert lr >= 0.0, f"Invalid learning rate: {lr}" assert momentum >= 0.0, f"Invalid momentum: {momentum}" assert weight_decay >= 0.0, f"Invalid weight_decay: {weight_decay}" if maximize: warnings.warn( "Only Momentum > 0.0, param `maximize` takes effect. ", FutureWarning, ) options = dict() options["lr"] = lr options["momentum"] = momentum options["dampening"] = dampening options["weight_decay"] = weight_decay options["nesterov"] = nesterov options["maximize"] = maximize options["contiguous_params"] = contiguous_params options["fused"] = fused super().__init__(params, options) for param_group in self.param_groups: if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() if param_group["fused"] and not param.is_cuda: warnings.warn("Fused SGD only support cuda parameters.") param_group["fused"] = False self._momentum_sgd = ( flow.stateful_op("momentum_update") .Input("model") .Input("model_diff") .Input("momentum") .Build() ) self._sgd = ( flow.stateful_op("sgd_update").Input("model").Input("model_diff").Build() ) def _single_tensor_update(self, param_group): lr = param_group["lr"] l2 = param_group["weight_decay"] if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: if param.grad is None: continue if param_group["momentum"] == 0.0: # TODO: Support param `maximize` in Naive SGD Optimizer. (zhengzekang) flow._C.dispatch_sgd_update( self._sgd, (param, param.grad), learning_rate=lr, l2=l2 ) else: if "momentum_buf" not in self.state[param]: self.state[param]["momentum_buf"] = flow.zeros_like(param) momentum_buf = self.state[param]["momentum_buf"] beta = param_group["momentum"] dampening = param_group["dampening"] nesterov = param_group["nesterov"] maximize = param_group["maximize"] flow._C.dispatch_momentum_update( self._momentum_sgd, (param, param.grad, momentum_buf), learning_rate=lr, l2=l2, beta=beta, dampening=dampening, nesterov=nesterov, maximize=maximize, ) def _fused_update(self, param_group): use_momentum = param_group["momentum"] != 0 param_list = [] param_grad_list = [] if use_momentum: momentum_buf_list = [] for param in param_group.parameters: if param.grad is None: continue param_list.append(param) param_grad_list.append(param.grad) if use_momentum: if "momentum_buf" not in self.state[param]: self.state[param]["momentum_buf"] = flow.zeros_like(param) momentum_buf_list.append(self.state[param]["momentum_buf"]) if not use_momentum: flow._C.multi_tensor_sgd_update( model=param_list, model_diff=param_grad_list, scale=1.0, weight_decay=param_group["weight_decay"], learning_rate_val=param_group["lr"], ) else: flow._C.multi_tensor_momentum_update( model=param_list, model_diff=param_grad_list, momentum_buf=momentum_buf_list, scale=1.0, weight_decay=param_group["weight_decay"], learning_rate_val=param_group["lr"], momentum=param_group["momentum"], dampening=param_group["dampening"], nesterov=param_group["nesterov"], maximize=param_group["maximize"], ) def step(self, closure: Callable = None): """Performs a single optimization step. 
Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ with flow.no_grad(): loss = None if closure is not None: with flow.enable_grad(): loss = closure() for param_group in self.param_groups: if param_group["fused"]: self._fused_update(param_group) else: self._single_tensor_update(param_group) self.state["step"] = self.state["step"] + 1 return loss def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: assert ( param_group["contiguous_params"] != True ), "contiguous_params cannot be used in graph" optimizer_conf = train_conf.optimizer_conf.add() lr = ( param_group["initial_lr"] if "initial_lr" in param_group else param_group["lr"] ) beta = param_group["momentum"] l2 = param_group["weight_decay"] dampening = param_group["dampening"] nesterov = param_group["nesterov"] maximize = param_group["maximize"] optimizer_conf.base_learning_rate = lr self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf) if beta == 0: optimizer_conf.naive_conf.SetInParent() else: optimizer_conf.momentum_conf.beta = beta # Only Momentum Optimizer support these params. optimizer_conf.momentum_conf.dampening = dampening optimizer_conf.momentum_conf.nesterov = nesterov optimizer_conf.momentum_conf.maximize = maximize self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) for param in param_group.parameters: vars_conf[param].l2 = l2 if param.requires_grad: optimizer_conf.variable_op_names.append(vars_conf[param].name) new_opt_confs.append(optimizer_conf) return new_opt_confs @property def support_sparse(self): """Whether SGD Optimizer support sparse update. """ return True ================================================ FILE: python/oneflow/nn/optimizer/step_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler class StepLR(LRScheduler): """ Decays the learning rate of each parameter group by gamma every step_size steps. Notice that such decay can happen simultaneously with other changes to the learning rate fromoutside this scheduler. When last_step=-1, sets initial lr as lr. Args: optimizer(Optimizer): Wrapped optimizer. step_size (int): Period of learning rate decay. gamma (float, optional): Multiplicative factor of learning rate decay. (default: 0.1) last_step (int, optional): The index of last step. (default: -1) verbose (bool, optional): If ``True``, prints a message to stdout for each update. (default: ``False``) For example: .. code-block:: python import oneflow as flow ... step_lr = flow.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) for epoch in range(num_epoch): train(...) 
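            # step() advances the schedule; the lr is multiplied by gamma once every step_size calls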
step_lr.step() """ def __init__( self, optimizer: Optimizer, step_size: int, gamma: float = 0.1, last_step: int = -1, verbose: bool = False, ): assert step_size > 0, f"step_size must be greater than zero, but got {step_size}" assert gamma > 0.0, f"gamma must be greater than zero, but got {gamma}" self.step_size = step_size self.gamma = gamma super().__init__(optimizer, last_step, verbose) def get_lr(self, base_lr, step): step_stage = math.floor(step / self.step_size) factor = self.gamma ** step_stage return base_lr * factor def _generate_conf_for_graph(self, lr_conf): lr_conf.step_conf.SetInParent() step_conf = lr_conf.step_conf step_conf.step_size = self.step_size step_conf.gamma = self.gamma ================================================ FILE: python/oneflow/nn/optimizer/swa_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r""" The swa_utils methods are consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging. """ import itertools import math from copy import deepcopy import warnings import oneflow as flow from oneflow.nn import Module from oneflow.nn.optimizer.lr_scheduler import LRScheduler __all__ = ["AveragedModel", "update_bn", "SWALR"] class AveragedModel(Module): r"""Implements averaged model for Stochastic Weight Averaging (SWA). The documentation is referenced from: https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging Stochastic Weight Averaging was proposed in `Averaging Weights Leads to Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018). AveragedModel class creates a copy of the provided module :attr:`model` on the device :attr:`device` and allows computing running averages of the parameters of the :attr:`model`. Args: model (oneflow.nn.Module): model to use with SWA device (oneflow.device, optional): if provided, the averaged model will be stored on the :attr:`device` avg_fn (function, optional): the averaging function used to update parameters; the function must take in the current value of the :class:`AveragedModel` parameter, the current value of :attr:`model` parameter and the number of models already averaged; if None, equally weighted average is used (default: None) use_buffers (bool): if ``True``, it will compute running averages for both the parameters and the buffers of the model. (default: ``False``) For example: .. code-block:: python import oneflow as flow ... loader, optimizer, model, loss_fn = ...
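        # swa_model maintains a running average of model's parameters (equally weighted by default)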
swa_model = flow.optim.swa_utils.AveragedModel(model) scheduler = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300) swa_start = 160 swa_scheduler = SWALR(optimizer, swa_lr=0.05) for i in range(300): for input, target in loader: optimizer.zero_grad() loss_fn(model(input), target).backward() optimizer.step() if i > swa_start: swa_model.update_parameters(model) swa_scheduler.step() else: scheduler.step() # Update bn statistics for the swa_model at the end flow.optim.swa_utils.update_bn(loader, swa_model) You can also use custom averaging functions with `avg_fn` parameter. If no averaging function is provided, the default is to compute equally-weighted average of the weights. For example: .. code-block:: python import oneflow as flow ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: ( 0.1 * averaged_model_parameter + 0.9 * model_parameter) swa_model = flow.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg, use_buffers=True) .. note:: When using SWA with models containing Batch Normalization you may need to update the activation statistics for Batch Normalization. This can be done either by using the :meth:`oneflow.optim.swa_utils.update_bn` or by setting :attr:`use_buffers` to `True`. The first approach updates the statistics in a post-training step by passing data through the model. The second does it during the parameter update phase by averaging all buffers. Empirical evidence has shown that updating the statistics in normalization layers increases accuracy, but you may wish to empirically test which approach yields the best results in your problem. .. note:: :attr:`avg_fn` is not saved in the :meth:`state_dict` of the model. .. note:: When :meth:`update_parameters` is called for the first time (i.e. :attr:`n_averaged` is `0`) the parameters of `model` are copied to the parameters of :class:`AveragedModel`. For every subsequent call of :meth:`update_parameters` the function `avg_fn` is used to update the parameters. .. _Averaging Weights Leads to Wider Optima and Better Generalization: https://arxiv.org/abs/1803.05407 .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should Average: https://arxiv.org/abs/1806.05594 .. _SWALP: Stochastic Weight Averaging in Low-Precision Training: https://arxiv.org/abs/1904.11943 .. 
_Stochastic Weight Averaging in Parallel: Large-Batch Training That Generalizes Well: https://arxiv.org/abs/2001.02312 """ def __init__(self, model, device=None, avg_fn=None, use_buffers=False): super(AveragedModel, self).__init__() self.module = deepcopy(model) if device is not None: self.module = self.module.to(device) self.register_buffer( "n_averaged", flow.tensor(0, dtype=flow.long, device=device) ) if avg_fn is None: def avg_fn(averaged_model_parameter, model_parameter, num_averaged): return averaged_model_parameter + ( model_parameter - averaged_model_parameter ) / (num_averaged + 1) self.avg_fn = avg_fn self.use_buffers = use_buffers def forward(self, *args, **kwargs): return self.module(*args, **kwargs) def update_parameters(self, model): self_param = ( itertools.chain(self.module.parameters(), self.module.buffers()) if self.use_buffers else self.parameters() ) model_param = ( itertools.chain(model.parameters(), model.buffers()) if self.use_buffers else model.parameters() ) for p_swa, p_model in zip(self_param, model_param): device = p_swa.device p_model_ = p_model.detach().to(device) if self.n_averaged == 0: p_swa.detach().copy_(p_model_) else: p_swa.detach().copy_( self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device)) ) if not self.use_buffers: # If not apply running averages to the buffers, # keep the buffers in sync with the source model. for b_swa, b_model in zip(self.module.buffers(), model.buffers()): b_swa.detach().copy_(b_model.detach().to(device)) self.n_averaged += 1 def update_bn(loader, model, device=None): r"""Updates BatchNorm running_mean, running_var buffers in the model. The documentation is referenced from: https://pytorch.org/docs/stable/optim.html#taking-care-of-batch-normalization It performs one pass over data in `loader` to estimate the activation statistics for BatchNorm layers in the model. Args: loader (oneflow.utils.data.DataLoader): dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor containing data. model (oneflow.nn.Module): model for which we seek to update BatchNorm statistics. device (oneflow.device, optional): If set, data will be transferred to :attr:`device` before being passed into :attr:`model`. For example: .. code-block:: python import oneflow as flow loader, model = ... flow.optim.swa_utils.update_bn(loader, model) .. note:: The `update_bn` utility assumes that each data batch in :attr:`loader` is either a tensor or a list or tuple of tensors; in the latter case it is assumed that :meth:`model.forward()` should be called on the first element of the list or tuple corresponding to the data batch. """ with flow.no_grad(): momenta = {} for module in model.modules(): if isinstance(module, flow.nn.modules.batchnorm._BatchNorm): module.running_mean = flow.zeros_like(module.running_mean) module.running_var = flow.ones_like(module.running_var) momenta[module] = module.momentum if not momenta: return was_training = model.training model.train() for module in momenta.keys(): module.momentum = None module.num_batches_tracked *= 0 for input in loader: if isinstance(input, (list, tuple)): input = input[0] if device is not None: input = input.to(device) model(input) for bn_module in momenta.keys(): bn_module.momentum = momenta[bn_module] model.train(was_training) class SWALR(LRScheduler): r"""Anneals the learning rate in each parameter group to a fixed value. 
The documentation is referenced from: https://pytorch.org/docs/stable/optim.html#swa-learning-rate-schedules This learning rate scheduler is meant to be used with the Stochastic Weight Averaging (SWA) method (see `oneflow.optim.swa_utils.AveragedModel`). Args: optimizer (oneflow.optim.Optimizer): wrapped optimizer swa_lr (float or list): the learning rate value for all param groups together or separately for each group. anneal_epochs (int): number of epochs in the annealing phase (default: 10) anneal_strategy (str): "cos" or "linear"; specifies the annealing strategy: "cos" for cosine annealing, "linear" for linear annealing (default: "cos") last_epoch (int): the index of the last epoch (default: -1) The :class:`SWALR` scheduler can be used together with other schedulers to switch to a constant learning rate late in training, as in the example below. For example: .. code-block:: python import oneflow as flow loader, optimizer, model = ... lr_lambda = lambda epoch: 0.9 scheduler = flow.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lr_lambda) swa_scheduler = flow.optim.swa_utils.SWALR(optimizer, anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05) swa_start = 160 for i in range(300): for input, target in loader: optimizer.zero_grad() loss_fn(model(input), target).backward() optimizer.step() if i > swa_start: swa_scheduler.step() else: scheduler.step() .. _Averaging Weights Leads to Wider Optima and Better Generalization: https://arxiv.org/abs/1803.05407 """ def __init__( self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy="cos", last_epoch=-1 ): swa_lrs = self._format_param(optimizer, swa_lr) for swa_lr, group in zip(swa_lrs, optimizer.param_groups): group["swa_lr"] = swa_lr if anneal_strategy not in ["cos", "linear"]: raise ValueError( "anneal_strategy must be one of 'cos' or 'linear', " f"instead got {anneal_strategy}" ) elif anneal_strategy == "cos": self.anneal_func = self._cosine_anneal elif anneal_strategy == "linear": self.anneal_func = self._linear_anneal if not isinstance(anneal_epochs, int) or anneal_epochs < 0: raise ValueError( f"anneal_epochs must be equal to or greater than 0, got {anneal_epochs}" ) self.anneal_epochs = anneal_epochs self.param_group_index = 0 super(SWALR, self).__init__(optimizer, last_epoch) @staticmethod def _format_param(optimizer, swa_lrs): if isinstance(swa_lrs, (list, tuple)): if len(swa_lrs) != len(optimizer.param_groups): raise ValueError( "swa_lr must have the same length as " f"optimizer.param_groups: swa_lr has {len(swa_lrs)}, " f"optimizer.param_groups has {len(optimizer.param_groups)}" ) return swa_lrs else: return [swa_lrs] * len(optimizer.param_groups) @staticmethod def _linear_anneal(t): return t @staticmethod def _cosine_anneal(t): return (1 - math.cos(math.pi * t)) / 2 @staticmethod def _get_initial_lr(lr, swa_lr, alpha): if alpha == 1: return swa_lr return (lr - alpha * swa_lr) / (1 - alpha) def get_lr(self, base_lr, step): if self.anneal_epochs == 0: step = max(1, step) prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs))) prev_alpha = self.anneal_func(prev_t) group = self.optimizer.param_groups[self.param_group_index] prev_lr = self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha) self.param_group_index += 1 if self.param_group_index == len(self.optimizer.param_groups): self.param_group_index = 0 t = max(0, min(1, step / max(1, self.anneal_epochs))) alpha = self.anneal_func(t) return group["swa_lr"] * alpha + prev_lr * (1 - alpha) ================================================ FILE:
python/oneflow/nn/optimizer/warmup_lr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import numpy as np from typing import Union from ...optim.optimizer import Optimizer from .lr_scheduler import LRScheduler from .sequential_lr import SequentialLR from .constant_lr import ConstantLR from .linear_lr import LinearLR class WarmupLR(SequentialLR): r"""Increases the learning rate starting from a small warmup factor until the number of epochs reaches warmup_iters. You can assign an optimizer or a learning rate scheduler. Notice that the warmup can happen simultaneously with the wrapped learning rate scheduler. Args: scheduler_or_optimizer ([type]): Wrapped learning rate scheduler or optimizer warmup_factor (float, optional): The warmup factor. Defaults to 1.0/3. warmup_iters (int, optional): The number of warmup steps. Defaults to 5. warmup_method (str, optional): The method of warmup, you can choose "linear" or "constant". In linear mode, the multiplication factor starts with warmup_factor in the first epoch and then increases linearly to reach 1. Defaults to "linear". last_step (int, optional): The index of the last step. Defaults to -1. verbose (bool, optional): If True, it prints a message to stdout for each update step. Defaults to False. Raises: ValueError: The warmup method should be one of "constant" and "linear" For example: Example 1: .. code:: python # lr = 0.0005 if epoch == 0 # lr = 0.0005 if epoch == 1 # lr = 0.0005 if epoch == 2 # lr = 0.0005 if epoch == 3 # lr = 0.0005 if epoch == 4 # lr = 0.001 if epoch >= 5 of_sgd = flow.optim.SGD(parameters, lr=0.001) constant_warmup_lr = flow.optim.lr_scheduler.WarmUpLR( of_sgd, warmup_factor=0.5, warmup_iters=5, warmup_method="constant" ) ... Example 2: .. code:: python # lr = 0.0005 if epoch == 0 # lr = 0.0006 if epoch == 1 # lr = 0.0007 if epoch == 2 # lr = 0.0008 if epoch == 3 # lr = 0.0009 if epoch == 4 # lr = 0.001 if epoch >= 5 of_sgd = flow.optim.SGD(parameters, lr=0.001) linear_warmup_lr = flow.optim.lr_scheduler.WarmUpLR( of_sgd, warmup_factor=0.5, warmup_iters=5, warmup_method="linear" ) ... Example 3: .. code:: python # lr = 0.0005 if epoch == 0 # lr = 0.00075 if epoch == 1 # Above is WarmUpLR, then we start CosineDecayLR # lr = 0.000689 if epoch == 2 # lr = 0.000410 if epoch == 3 # .... of_sgd = flow.optim.SGD(parameters, lr=0.001) alpha = 0.1 decay_steps = 5 cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR( of_sgd, decay_steps=decay_steps, alpha=alpha ) linear_warmup_cosine_lr = flow.optim.lr_scheduler.WarmUpLR( cosine_decay_lr, warmup_factor=0.5, warmup_iters=2, warmup_method="linear" ) ...
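    Example 4 (a minimal sketch of ``warmup_prefix``; when it is True the wrapped scheduler is
    restarted from its own step 0 once the warmup finishes, instead of continuing from step
    warmup_iters):

    .. code:: python

        of_sgd = flow.optim.SGD(parameters, lr=0.001)
        cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(
            of_sgd, decay_steps=5, alpha=0.1
        )
        # the warmup rises to the full lr over 2 steps, then CosineDecayLR begins at its step 0
        warmup_prefix_lr = flow.optim.lr_scheduler.WarmUpLR(
            cosine_decay_lr, warmup_factor=0.5, warmup_iters=2,
            warmup_method="linear", warmup_prefix=True
        )
        ...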
""" def __init__( self, scheduler_or_optimizer: Union[LRScheduler, Optimizer], warmup_factor: float = 1.0 / 3, warmup_iters: int = 5, warmup_method: str = "linear", warmup_prefix: bool = False, last_step=-1, verbose=False, ): if not isinstance(scheduler_or_optimizer, (LRScheduler, Optimizer)): raise ValueError( "'scheduler_or_optimizer' must be a LRScheduler or an Optimizer, but got " f"{type(scheduler_or_optimizer)}" ) if warmup_method not in ("linear", "constant"): raise ValueError( f"'warmup_method' must be 'linear' or 'constant', but got {warmup_method}" ) if isinstance(scheduler_or_optimizer, LRScheduler): opt = scheduler_or_optimizer.optimizer scheduler = scheduler_or_optimizer else: opt = scheduler_or_optimizer scheduler = None if scheduler is None and warmup_iters == 0: raise ValueError( "When 'scheduler_or_optimizer' is an optimizer warmup_iters can't be equal to 0" ) self.warmup_factor = warmup_factor self.warmup_iters = warmup_iters self.warmup_method = warmup_method self.warmup_prefix = warmup_prefix # manually init optimizer, last_step, base_lrs first self.optimizer = opt self.last_step = last_step self.verbose = verbose self._init_base_lrs() warmup = self._init_warmup_scheduler(scheduler) self._init_seq_scheduler(scheduler, warmup) def _init_warmup_scheduler(self, scheduler): warmup = None if self.warmup_iters <= 0: return if self.warmup_method == "linear": if scheduler and self.warmup_prefix is False: base_lr = self.base_lrs[0] if not np.isclose(self.base_lrs, base_lr).all(): raise ValueError( "The param_groups in optimizer have different warmup configs, please use different optimizers." ) end_lr = scheduler.get_lr(base_lr, self.warmup_iters) end_factor = end_lr / base_lr else: end_factor = 1.0 warmup = LinearLR( self.optimizer, start_factor=self.warmup_factor, end_factor=end_factor, total_iters=self.warmup_iters, last_step=self.last_step, verbose=self.verbose, ) else: # "constant" warmup = ConstantLR( self.optimizer, factor=self.warmup_factor, total_iters=self.warmup_iters, last_step=self.last_step, verbose=self.verbose, ) return warmup def _init_seq_scheduler(self, scheduler, warmup): if warmup and scheduler: schedulers = [warmup, scheduler] milestones = [self.warmup_iters] interval_rescaling = [self.warmup_prefix] elif warmup: schedulers = [warmup] milestones = [] interval_rescaling = [] elif scheduler: schedulers = [scheduler] milestones = [] interval_rescaling = [] else: raise ValueError("No scheduler can work") super().__init__( self.optimizer, schedulers=schedulers, milestones=milestones, interval_rescaling=interval_rescaling, last_step=self.last_step, verbose=self.verbose, ) ================================================ FILE: python/oneflow/nn/parallel/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from .distributed import DistributedDataParallel __all__ = ["DistributedDataParallel"] ================================================ FILE: python/oneflow/nn/parallel/distributed.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings from collections import OrderedDict import oneflow as flow from oneflow.support.env_var_util import parse_boolean_from_env from oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple from oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup from oneflow.framework.args_tree import ArgsTree def allreduce_fn(module, param, use_bucket): ddp_state_for_reversed_params = module._ddp_state_for_reversed_params def allreduce_with_bucket(grad): buckets = module._buckets bucket_tensors = module._bucket_tensors ddp_state_for_reversed_params[param][0] = True for index, bucket in enumerate(buckets): deleted = all(ddp_state_for_reversed_params[x][1] for x in bucket) if deleted: continue assert not any(ddp_state_for_reversed_params[x][1] for x in bucket) all_params_in_bucket_ready = all( ddp_state_for_reversed_params[x][0] for x in bucket ) if all_params_in_bucket_ready: for x in bucket: ddp_state_for_reversed_params[x][1] = True # NOTE(jianhao)(higher-order-grad): # local allreduce doesn't have gradient function, higher-order grad may be unsupported flow._C.local_all_reduce(bucket_tensors[index], inplace=True) else: break def allreduce_without_bucket(grad): ddp_state_for_reversed_params[param][0] = True for cur_param, (ready, deleted) in ddp_state_for_reversed_params.items(): if deleted: continue if ready: ddp_state_for_reversed_params[cur_param][1] = True # NOTE(jianhao)(higher-order-grad): local allreduce doesn't have gradient function, higher-order grad may be unsupported if cur_param is param: flow._C.local_all_reduce(grad, True) else: flow._C.local_all_reduce(cur_param.grad, True) else: break return allreduce_with_bucket if use_bucket else allreduce_without_bucket def DistributedDataParallel( module: "flow.nn.Module", *, broadcast_buffers: bool = True, broadcast_parameters: bool = True, bucket_size: int = 10, use_bucket: bool = True, ): assert all(x.dtype == flow.float32 for x in module.parameters()) if use_bucket and parse_boolean_from_env("ONEFLOW_DISABLE_VIEW", False): warnings.warn( "because the environment variable 'ONEFLOW_DISABLE_VIEW' is set to true, so the view mechanism is disabled, and we will set use_bucket=False" ) use_bucket = False world_size = flow.env.get_world_size() if broadcast_parameters: with flow.no_grad(): for x in module.parameters(): requires_grad = x.requires_grad flow._C.comm_broadcast(x, inplace=True) # TODO: fix the bug that x's requires_grad is discarded # after flow._C.comm_broadcast x.requires_grad_(requires_grad) if use_bucket: all_grad_size = sum([x.numel() for x in module.parameters()]) if all_grad_size > 0: device = list(module.parameters())[0].device assert all(x.device == device for x in module.parameters()) 
reversed_param_list = list( reversed( list([param for param in module.parameters() if param.requires_grad]) ) ) module._bucket_index = { x: i // bucket_size for i, x in enumerate(reversed_param_list) } module._buckets = [ reversed_param_list[i : i + bucket_size] for i in range(0, len(reversed_param_list), bucket_size) ] module._params_group = ContiguousParamsGroup(module._buckets) module._bucket_tensors = module._params_group.grouped_parameters_grad ddp_state_for_reversed_params = OrderedDict( reversed([(x, [False, False]) for x in module.parameters() if x.requires_grad]) ) module._ddp_state_for_reversed_params = ddp_state_for_reversed_params # The gradient shoule be averaged by all the nodes, so besides allreduce, # a division by world_size is required. # Use x * (1 / world_size) instead of x / world_size for two reasons: # 1. multiplication is faster than division # 2. An inplace operation is needed here (for allreduce grouping) # But we do not have inplace division in oneflow. mul_factor = 1 / world_size def inplace_mul_and_return_none(x): x.mul_(mul_factor) return None for param in module.parameters(): if param.requires_grad: param._register_post_grad_accumulation_hook(inplace_mul_and_return_none) param._register_post_grad_accumulation_hook( allreduce_fn(module, param, use_bucket) ) def post_forward_hook(module, input, output): ddp_state_for_reversed_params = module._ddp_state_for_reversed_params for state in ddp_state_for_reversed_params.values(): state[0], state[1] = False, False output = ArgsTree(output).map_leaf( lambda x: flow._C.select_top_n( convert_to_tensor_tuple([x, *ddp_state_for_reversed_params.keys()]), n=1, )[0] ) buffers = list(module.buffers()) if len(buffers) > 0: flow._C.stream_touch(buffers) return output module.register_forward_hook(post_forward_hook) if broadcast_buffers: def pre_forward_hook(module, input): with flow.no_grad(): buffers = list(module.buffers()) flow._C.comm_broadcast(buffers, inplace=True) module.register_forward_pre_hook(pre_forward_hook) module._is_ddp_module = True return module ================================================ FILE: python/oneflow/nn/parameter.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow Parameter = flow._oneflow_internal.nn.Parameter ================================================ FILE: python/oneflow/nn/qat/__init__.py ================================================ ================================================ FILE: python/oneflow/nn/qat/conv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow import nn as nn from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t from typing import Union def get_conv_fake_quantized( input, input_observer, current_train_step, weight, weight_observer, fake_quantizer ): in_scale, in_zero_point = input_observer(input, current_train_step) input_fake_quanted = fake_quantizer(input, in_scale, in_zero_point) w_scale, w_zero_point = weight_observer(weight) weight_fake_quanted = fake_quantizer(weight, w_scale, w_zero_point) return input_fake_quanted, weight_fake_quanted def init_conv_fake_quants( self, quantization_formula: str = "google", quantization_bit: int = 8, quantization_scheme: str = "symmetric", weight_quant_per_layer: bool = True, input_quant_momentum: float = 0.95, ): self.input_min_max_observer = nn.MovingAverageMinMaxObserver( stop_update_after_iters=1, quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, momentum=input_quant_momentum, ) self.register_buffer("current_train_step", flow.zeros(1, dtype=flow.int64,)) self.weight_min_max_observer = nn.MinMaxObserver( quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, per_layer_quantization=weight_quant_per_layer, ) self.fake_quantizer = nn.FakeQuantization( quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, ) class QatConv1d(nn.Conv1d): r"""A Conv1d module attached with `nn.MinMaxObserver`, `nn.MovingAverageMinMaxObserver` and `nn.FakeQuantization` modules for weight and input, used for quantization aware training. The parameters of QatConv1d are the same as :class:`~oneflow.nn.Conv1d` with some extra parameters for fake quantization, see :class:`~oneflow.nn.MinMaxObserver`, :class:`~oneflow.nn.MovingAverageMinMaxObserver` and :class:`~oneflow.nn.FakeQuantization` for more details. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to both sides of the input. Default: 0 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'`` quantization_formula (str): Support "google" or "cambricon". quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8. quantization_scheme (str): "symmetric" or "affine", quantize to signed / unsigned integer. Defaults to "symmetric". weight_quant_per_layer (bool): True or False, means per-layer / per-channel for weight quantization. Defaults to True. input_quant_momentum (float): Smoothing parameter for exponential moving average operation for input quantization. 
Defaults to 0.95. Shape: - Input: :math:`(N, C_{in}, L_{in})` - Output: :math:`(N, C_{out}, L_{out})` where .. math:: L_{out} = \\left\\lfloor\\frac{L_{in} + 2 \\times \\text{padding} - \\text{dilation} \\times (\\text{kernel\\_size} - 1) - 1}{\\text{stride}} + 1\\right\\rfloor Attributes: weight (Tensor): the learnable weights of the module of shape :math:`(\\text{out\\_channels}, \\frac{\\text{in\\_channels}}{\\text{groups}}, \\text{kernel\\_size})`. The values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{in} * \\text{kernel\\_size}}` bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{in} * \\text{kernel\\_size}}` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> arr = np.random.randn(20, 16, 50) >>> input = flow.Tensor(arr) >>> m = nn.QatConv1d(16, 33, 3, stride=2, quantization_formula="google", quantization_bit=8, quantization_scheme="symmetric") >>> output = m(input) """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_1_t, stride: _size_1_t = 1, padding: Union[str, _size_1_t] = 0, dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", quantization_formula: str = "google", quantization_bit: int = 8, quantization_scheme: str = "symmetric", weight_quant_per_layer: bool = True, input_quant_momentum: float = 0.95, ): super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, ) self.channel_pos = "channels_first" init_conv_fake_quants( self, quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, weight_quant_per_layer=weight_quant_per_layer, input_quant_momentum=input_quant_momentum, ) def forward(self, x): fake_quan_input, fake_quan_weight = get_conv_fake_quantized( x, self.input_min_max_observer, self.current_train_step, self.weight, self.weight_min_max_observer, self.fake_quantizer, ) return self._conv_forward(fake_quan_input, fake_quan_weight, self.bias) class QatConv2d(nn.Conv2d): r"""A Conv2d module attached with `nn.MinMaxObserver`, `nn.MovingAverageMinMaxObserver` and `nn.FakeQuantization` modules for weight and input, used for quantization aware training. The parameters of QatConv2d are the same as :class:`~oneflow.nn.Conv2d` with some extra parameters for fake quantization, see :class:`~oneflow.nn.MinMaxObserver`, :class:`~oneflow.nn.MovingAverageMinMaxObserver` and :class:`~oneflow.nn.FakeQuantization` for more details. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'`` quantization_formula (str): Support "google" or "cambricon". 
quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8. quantization_scheme (str): "symmetric" or "affine", quantize to signed / unsigned integer. Defaults to "symmetric". weight_quant_per_layer (bool): True or False, means per-layer / per-channel for weight quantization. Defaults to True. input_quant_momentum (float): Smoothing parameter for exponential moving average operation for input quantization. Defaults to 0.95. Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where .. math:: H_{out} = \\left\\lfloor\\frac{H_{in} + 2 \\times \\text{padding}[0] - \\text{dilation}[0] \\times (\\text{kernel_size}[0] - 1) - 1}{\\text{stride}[0]} + 1\\right\\rfloor .. math:: W_{out} = \\left\\lfloor\\frac{W_{in} + 2 \\times \\text{padding}[1] - \\text{dilation}[1] \\times (\\text{kernel_size}[1] - 1) - 1}{\\text{stride}[1]} + 1\\right\\rfloor Attr: - weight (Tensor): the learnable weights of the module of shape :math:`(\\text{out_channels}, \\frac{\\text{in_channels}}{\\text{groups}},` :math:`\\text{kernel_size[0]}, \\text{kernel_size[1]})`. The values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{1}\\text{kernel_size}[i]}` - bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{1}\\text{kernel_size}[i]}` For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> arr = np.random.randn(20, 16, 50, 100) >>> input = flow.Tensor(arr) >>> m = nn.QatConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1), quantization_formula="google", quantization_bit=8, quantization_scheme="symmetric") >>> output = m(input) """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, padding: Union[str, _size_2_t] = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", quantization_formula: str = "google", quantization_bit: int = 8, quantization_scheme: str = "symmetric", weight_quant_per_layer: bool = True, input_quant_momentum: float = 0.95, ): super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, ) self.channel_pos = "channels_first" init_conv_fake_quants( self, quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, weight_quant_per_layer=weight_quant_per_layer, input_quant_momentum=input_quant_momentum, ) def forward(self, x): fake_quan_input, fake_quan_weight = get_conv_fake_quantized( x, self.input_min_max_observer, self.current_train_step, self.weight, self.weight_min_max_observer, self.fake_quantizer, ) return self._conv_forward(fake_quan_input, fake_quan_weight, self.bias) class QatConv3d(nn.Conv3d): r"""A Conv3d module attached with `nn.MinMaxObserver`, `nn.MovingAverageMinMaxObserver` and `nn.FakeQuantization` modules for weight and input, used for quantization aware training. The parameters of QatConv3d are the same as :class:`~oneflow.nn.Conv3d` with some extra parameters for fake quantization, see :class:`~oneflow.nn.MinMaxObserver`, :class:`~oneflow.nn.MovingAverageMinMaxObserver` and :class:`~oneflow.nn.FakeQuantization` for more details. 
Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all six sides of the input. Default: 0 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'`` quantization_formula (str): Support "google" or "cambricon". quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8. quantization_scheme (str): "symmetric" or "affine", quantize to signed / unsigned integer. Defaults to "symmetric". weight_quant_per_layer (bool): True or False, means per-layer / per-channel for weight quantization. Defaults to True. input_quant_momentum (float): Smoothing parameter for exponential moving average operation for input quantization. Defaults to 0.95. Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where .. math:: D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor .. math:: H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor .. math:: W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor Attributes: weight (Tensor): the learnable weights of the module of shape :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. The values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow >>> import oneflow.nn as nn >>> arr = np.random.randn(1, 2, 5, 5, 5) >>> input = flow.Tensor(arr) >>> m = nn.QatConv3d(2, 4, kernel_size=3, stride=1, quantization_formula="google", quantization_bit=8, quantization_scheme="symmetric") >>> output = m(input) """ def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_3_t, stride: _size_3_t = 1, padding: Union[str, _size_3_t] = 0, dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", quantization_formula: str = "google", quantization_bit: int = 8, quantization_scheme: str = "symmetric", weight_quant_per_layer: bool = True, input_quant_momentum: float = 0.95, ): super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, ) self.channel_pos = "channels_first" init_conv_fake_quants( self, quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, weight_quant_per_layer=weight_quant_per_layer, input_quant_momentum=input_quant_momentum, ) def forward(self, x): fake_quan_input, fake_quan_weight = get_conv_fake_quantized( x, self.input_min_max_observer, self.current_train_step, self.weight, self.weight_min_max_observer, self.fake_quantizer, ) return self._conv_forward(fake_quan_input, fake_quan_weight, self.bias) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/utils/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.nn.utils.clip_grad import clip_grad_norm_, clip_grad_value_ from oneflow.nn.utils.weight_norm import weight_norm from oneflow.nn.utils.weight_norm import remove_weight_norm from oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup from oneflow.nn.utils.convert_parameters import ( parameters_to_vector, vector_to_parameters, ) from oneflow.nn.utils.skip_init import skip_init ================================================ FILE: python/oneflow/nn/utils/clip_grad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import warnings from typing import Union, Iterable import numpy as np import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.framework.tensor import register_tensor_op from oneflow.nn.modules.module import Module _tensor_or_tensors = Union[Tensor, Iterable[Tensor]] def clip_grad_norm_( parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, fused: bool = False, error_if_nonfinite: bool = False, ) -> Tensor: r"""Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Args: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized max_norm (float or int): max norm of the gradients norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. error_if_nonfinite (bool): if True, an error is thrown if the total norm of the gradients from :attr:``parameters`` is ``nan``, ``inf``, or ``-inf``. Default: False (will switch to True in the future) Returns: Parameters after cliping gradient norm Total norm of the parameters (viewed as a single vector). For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> x1 = flow.tensor(np.array([[2, 3, 4], [1.5, 2.6, 3.7]]).astype(np.float32), requires_grad=True) >>> m1 = flow.nn.ReLU() >>> out1 = m1(x1) >>> out1 = out1.sum() >>> out1.backward() >>> norm1 = flow.nn.utils.clip_grad_norm_(x1, 0.6, 1.0) >>> norm1 tensor(6., dtype=oneflow.float32) >>> x1.grad tensor([[0.1000, 0.1000, 0.1000], [0.1000, 0.1000, 0.1000]], dtype=oneflow.float32) >>> x2 = flow.tensor(np.array([[-2, -3, -4], [2.5, 0, 3.2]]).astype(np.float32), requires_grad=True) >>> out2 = flow.atan(x2) >>> out2 = out2.sum() >>> out2.backward() >>> norm2 = flow.nn.utils.clip_grad_norm_(x2, 0.5) >>> norm2 tensor(1.0394, dtype=oneflow.float32) >>> x2.grad tensor([[0.0962, 0.0481, 0.0283], [0.0663, 0.4810, 0.0428]], dtype=oneflow.float32) """ if isinstance(parameters, (Tensor, flow._oneflow_internal.Tensor)): parameters = [parameters] parameters = [p for p in parameters if p.grad is not None] max_norm = float(max_norm) norm_type = float(norm_type) if len(parameters) == 0: return flow.tensor(0.0) if parameters[0].is_global: assert all( [p.is_global for p in parameters] ), "All parameters must be global tensor." sbp_broadcast = [flow.sbp.broadcast for _ in parameters[0].sbp] param0_placement = parameters[0].placement if norm_type == float("inf"): norms = [ p.grad.detach() .to_global(sbp=sbp_broadcast) .abs() .max() .to_global(placement=param0_placement) for p in parameters ] total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms)) elif norm_type == float("-inf"): norms = [ p.grad.detach() .to_global(sbp=sbp_broadcast) .abs() .min() .to_global(placement=param0_placement) for p in parameters ] total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms)) else: total_norm = flow.linalg.vector_norm( flow.stack( [ flow.linalg.vector_norm( p.grad.detach().to_global(sbp=sbp_broadcast), norm_type ).to_global(placement=param0_placement) for p in parameters ] ), norm_type, ) if error_if_nonfinite and flow.logical_or( total_norm.isnan(), total_norm.isinf() ): raise RuntimeError( f"The total norm of order {norm_type} for gradients from " "`parameters` is non-finite, so it cannot be clipped. 
To disable " "this error and scale the gradients by the non-finite norm anyway, " "set `error_if_nonfinite=False`" ) clip_coef = max_norm / (total_norm + 1e-6) clip_coef_clamped = clip_coef.clamp(max=1.0) for p in parameters: p.grad.detach().mul_(clip_coef_clamped.to_global(placement=p.placement)) elif fused and not error_if_nonfinite and all([p.grad.is_cuda for p in parameters]): param_grad_list = [] for param in parameters: param_grad_list.append(param.grad) total_norm = flow._C.fused_clip_grad(param_grad_list, max_norm, norm_type,) else: device = parameters[0].grad.device if norm_type == float("inf"): norms = [p.grad.detach().abs().max().to(device) for p in parameters] total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms)) elif norm_type == float("-inf"): norms = [p.grad.detach().abs().min().to(device) for p in parameters] total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms)) else: total_norm = flow.linalg.vector_norm( flow.stack( [ flow.linalg.vector_norm(p.grad.detach(), norm_type).to(device) for p in parameters ] ), norm_type, ) if error_if_nonfinite and flow.logical_or( total_norm.isnan(), total_norm.isinf() ): raise RuntimeError( f"The total norm of order {norm_type} for gradients from " "`parameters` is non-finite, so it cannot be clipped. To disable " "this error and scale the gradients by the non-finite norm anyway, " "set `error_if_nonfinite=False`" ) clip_coef = max_norm / (total_norm + 1e-6) clip_coef_clamped = clip_coef.clamp(max=1.0) for p in parameters: p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device)) return total_norm def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float) -> None: r"""Clips gradient of an iterable of parameters at specified value. Gradients are modified in-place. Args: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized clip_value (float or int): maximum allowed value of the gradients. The gradients are clipped in the range :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]` """ if isinstance(parameters, flow.Tensor): parameters = [parameters] clip_value = float(clip_value) for p in filter(lambda p: p.grad is not None, parameters): # TODO: Switch to inplace clamp function p.grad[:] = p.grad.clamp(min=-clip_value, max=clip_value) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/utils/container.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import collections.abc import warnings import operator from collections import OrderedDict, abc as container_abcs from itertools import islice from typing import ( Any, Iterable, Iterator, Mapping, Optional, Tuple, TypeVar, Union, overload, Generic, ) import oneflow as flow from oneflow.nn.modules.module import Module T = TypeVar("T") def get_seq(T): class SequentialContainer(T): @overload def __init__(self, *args: T) -> None: ... @overload def __init__(self, arg: "OrderedDict[str, T]") -> None: ... def __init__(self, *args: Any): super(SequentialContainer, self).__init__() if len(args) == 1 and isinstance(args[0], OrderedDict): for (key, module) in args[0].items(): self.add_module(key, module) else: for (idx, module) in enumerate(args): self.add_module(str(idx), module) def _get_item_by_idx(self, iterator, idx): """Get the idx-th item of the iterator""" size = len(self) idx = operator.index(idx) if not -size <= idx < size: raise IndexError("index {} is out of range".format(idx)) idx %= size return next(islice(iterator, idx, None)) def __getitem__(self: T, idx) -> T: if isinstance(idx, slice): return self.__class__(OrderedDict(list(self._modules.items())[idx])) else: return self._get_item_by_idx(self._modules.values(), idx) def __setitem__(self, idx: int, module: T) -> None: key = self._get_item_by_idx(self._modules.keys(), idx) return setattr(self, key, module) def __delitem__(self, idx: Union[slice, int]) -> None: if isinstance(idx, slice): for key in list(self._modules.keys())[idx]: delattr(self, key) else: key = self._get_item_by_idx(self._modules.keys(), idx) delattr(self, key) def __len__(self) -> int: return len(self._modules) def __dir__(self): keys = super(SequentialContainer, self).__dir__() keys = [key for key in keys if not key.isdigit()] return keys def __iter__(self) -> Iterator[T]: return iter(self._modules.values()) def forward(self, input): for module in self: input = module(input) return input return SequentialContainer def get_list(T): class ListContainer(T): def __init__(self, modules: Optional[Iterable[T]] = None) -> None: super(ListContainer, self).__init__() if modules is not None: self += modules def _get_abs_string_index(self, idx): """Get the absolute index for the list of modules""" idx = operator.index(idx) if not -len(self) <= idx < len(self): raise IndexError("index {} is out of range".format(idx)) if idx < 0: idx += len(self) return str(idx) def __getitem__(self, idx: int) -> T: if isinstance(idx, slice): return self.__class__(list(self._modules.values())[idx]) else: return self._modules[self._get_abs_string_index(idx)] def __setitem__(self, idx: int, module: T) -> None: idx = self._get_abs_string_index(idx) return setattr(self, str(idx), module) def __delitem__(self, idx: Union[int, slice]) -> None: if isinstance(idx, slice): for k in range(len(self._modules))[idx]: delattr(self, str(k)) else: delattr(self, self._get_abs_string_index(idx)) str_indices = [str(i) for i in range(len(self._modules))] self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) def __len__(self) -> int: return len(self._modules) def __iter__(self) -> Iterator[T]: return iter(self._modules.values()) def __iadd__(self: T, modules: Iterable[T]) -> T: return self.extend(modules) def __dir__(self): keys = super(ListContainer, self).__dir__() keys = [key for key in keys if not key.isdigit()] return keys def insert(self, index: int, module: T) -> None: """Insert a given module before a given index in the list. Arguments: index (int): index to insert. 
module (nn.Module): module to insert """ for i in range(len(self._modules), index, -1): self._modules[str(i)] = self._modules[str(i - 1)] self._modules[str(index)] = module def append(self: T, module: T) -> T: """Appends a given module to the end of the list. Arguments: module (nn.Module): module to append """ self.add_module(str(len(self)), module) return self def extend(self: T, modules: Iterable[T]) -> T: """Appends modules from a Python iterable to the end of the list. Arguments: modules (iterable): iterable of modules to append """ if not isinstance(modules, collections.abc.Iterable): raise TypeError( "ModuleList.extend should be called with an iterable, but got " + type(modules).__name__ ) offset = len(self) for (i, module) in enumerate(modules): self.add_module(str(offset + i), module) return self def forward(self): raise NotImplementedError() return ListContainer def get_dict(T): class DictContainer(T): def __init__(self, modules: Optional[Mapping[str, T]] = None) -> None: super(DictContainer, self).__init__() if modules is not None: self.update(modules) def __getitem__(self, key: str) -> T: return self._modules[key] def __setitem__(self, key: str, module: T) -> None: self.add_module(key, module) def __delitem__(self, key: str) -> None: del self._modules[key] def __len__(self) -> int: return len(self._modules) def __iter__(self) -> Iterator[str]: return iter(self._modules) def __contains__(self, key: str) -> bool: return key in self._modules def clear(self) -> None: """Remove all items from the ModuleDict. """ self._modules.clear() def pop(self, key: str) -> T: """Remove key from the ModuleDict and return its module. Arguments: key (string): key to pop from the ModuleDict """ v = self[key] del self[key] return v def keys(self) -> Iterable[str]: """Return an iterable of the ModuleDict keys. """ return self._modules.keys() def items(self) -> Iterable[Tuple[str, T]]: """Return an iterable of the ModuleDict key/value pairs. """ return self._modules.items() def values(self) -> Iterable[T]: """Return an iterable of the ModuleDict values. """ return self._modules.values() def update(self, modules: Mapping[str, T]) -> None: if not isinstance(modules, collections.abc.Iterable): raise TypeError( "ModuleDict.update should be called with an iterable of key/value pairs, but got " + type(modules).__name__ ) if isinstance(modules, (OrderedDict, T, collections.abc.Mapping)): for (key, module) in modules.items(): self[key] = module else: for (j, m) in enumerate(modules): if not isinstance(m, collections.abc.Iterable): raise TypeError( "ModuleDict update sequence element #" + str(j) + " should be Iterable; is" + type(m).__name__ ) if not len(m) == 2: raise ValueError( "ModuleDict update sequence element #" + str(j) + " has length " + str(len(m)) + "; 2 is required" ) self[m[0]] = m[1] return DictContainer def get_para_list(T): class ParameterListContainer(T): def __init__(self, parameters=None) -> None: super(ParameterListContainer, self).__init__() self._initialized = True if parameters is not None: self += parameters def __setstate__(self, state): state["_initialized"] = False super(ParameterListContainer, self).__setstate__(state) self._initialized = True def _get_abs_string_index(self, idx): """Get the absolute index for the list of modules""" idx = operator.index(idx) if not -len(self) <= idx < len(self): raise IndexError("index {} is out of range".format(idx)) if idx < 0: idx += len(self) return str(idx) @overload def __getitem__(self, idx: int): ... 
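# --- Editor's illustrative sketch (not part of the original source) ---
# get_seq(), get_list() and get_dict() above are class factories: each takes a
# Module-like base class T and returns a container class built on top of it.
# A plausible way they are consumed is sketched below (the real wiring lives
# in oneflow.nn.modules.container and may differ); the public nn.Sequential,
# nn.ModuleList and nn.ModuleDict classes behave like the results of these calls.
from collections import OrderedDict
import oneflow as flow
from oneflow.nn.modules.module import Module

Sequential = get_seq(Module)    # hypothetical local aliases for illustration
ModuleList = get_list(Module)
ModuleDict = get_dict(Module)

block = Sequential(
    OrderedDict(
        [
            ("fc1", flow.nn.Linear(4, 8)),
            ("act", flow.nn.ReLU()),
            ("fc2", flow.nn.Linear(8, 2)),
        ]
    )
)
y = block(flow.randn(3, 4))     # submodules are applied in insertion order

layers = ModuleList([flow.nn.Linear(4, 4) for _ in range(3)])
layers.append(flow.nn.Linear(4, 2))   # indexable and iterable, but not callable

heads = ModuleDict({"cls": flow.nn.Linear(8, 10), "reg": flow.nn.Linear(8, 1)})
assert "cls" in heads and len(heads) == 2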
@overload def __getitem__(self: T, idx: slice) -> T: ... def __getitem__(self, idx): if isinstance(idx, slice): return self.__class__(list(self._parameters.values())[idx]) else: idx = self._get_abs_string_index(idx) return self._parameters[str(idx)] def __setitem__(self, idx: int, param) -> None: idx = self._get_abs_string_index(idx) return self.register_parameter(str(idx), param) def __len__(self) -> int: return len(self._parameters) def __iter__(self): return iter(self._parameters.values()) def __iadd__(self, parameters): return self.extend(parameters) def __dir__(self): keys = super(ParameterListContainer, self).__dir__() keys = [key for key in keys if not key.isdigit()] return keys def append(self: T, parameter) -> T: """Appends a given parameter at the end of the list. Arguments: parameter (nn.Parameter): parameter to append """ self.register_parameter(str(len(self)), parameter) return self def extend(self: T, parameters) -> T: """Appends parameters from a Python iterable to the end of the list. Arguments: parameters (iterable): iterable of parameters to append """ if not isinstance(parameters, collections.abc.Iterable): raise TypeError( "ParameterList.extend should be called with an iterable, but got " + type(parameters).__name__ ) offset = len(self) for (i, param) in enumerate(parameters): self.register_parameter(str(offset + i), param) return self def extra_repr(self) -> str: child_lines = [] for (k, p) in self._parameters.items(): size_str = "x".join((str(size) for size in p.size())) device_str = "" if not p.is_cuda else " (GPU {})".format(p.get_device()) parastr = "Parameter containing: [{} of size {}{}]".format( type(p), size_str, device_str ) child_lines.append(" (" + str(k) + "): " + parastr) tmpstr = "\n".join(child_lines) return tmpstr def __call__(self, input): raise RuntimeError("ParameterList should not be called.") def _replicate_for_data_parallel(self): warnings.warn( "nn.ParameterList is being used with DataParallel but this is not supported. This list will appear empty for the models replicated on each GPU except the original one." ) return super(ParameterListContainer, self)._replicate_for_data_parallel() return ParameterListContainer def get_para_dict(T): class ParameterDictContainer(T): def __init__(self, parameters=None) -> None: super(ParameterDictContainer, self).__init__() self._initialized = True if parameters is not None: self.update(parameters) def __setstate__(self, state): state["_initialized"] = False super(ParameterDictContainer, self).__setstate__(state) self._initialized = True def __getitem__(self, key: str): return self._parameters[key] def __setitem__(self, key: str, parameter) -> None: self.register_parameter(key, parameter) def __delitem__(self, key: str) -> None: del self._parameters[key] def __len__(self) -> int: return len(self._parameters) def __iter__(self) -> Iterator[str]: return iter(self._parameters.keys()) def __contains__(self, key: str) -> bool: return key in self._parameters def clear(self) -> None: """Remove all items from the ParameterDict. """ self._parameters.clear() def pop(self, key: str): r"""Remove key from the ParameterDict and return its parameter. Args: key (string): key to pop from the ParameterDict """ v = self[key] del self[key] return v def keys(self) -> Iterable[str]: r"""Return an iterable of the ParameterDict keys. """ return self._parameters.keys() def items(self): r"""Return an iterable of the ParameterDict key/value pairs. 
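# --- Editor's illustrative sketch (not part of the original source) ---
# get_para_list() above produces a ParameterList-style container. A small
# usage sketch, assuming the factory is instantiated with the Module base
# class as for the containers above (the public class is nn.ParameterList):
import oneflow as flow
from oneflow.nn.modules.module import Module

ParameterList = get_para_list(Module)  # hypothetical alias for illustration

params = ParameterList([flow.nn.Parameter(flow.randn(4, 4)) for _ in range(2)])
params.append(flow.nn.Parameter(flow.randn(4)))
assert len(params) == 3
first = params[0]   # indexing returns the registered Parameter
# A ParameterList is a plain container: calling it raises a RuntimeError, so
# it is normally held as an attribute of a Module and iterated inside forward().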
""" return self._parameters.items() def values(self): r"""Return an iterable of the ParameterDict values. """ return self._parameters.values() def update(self, parameters) -> None: r"""Update the :class:`~flow.nn.ParameterDict` with the key-value pairs from a mapping or an iterable, overwriting existing keys. .. note:: If :attr:`parameters` is an ``OrderedDict``, a :class:`~flow.nn.ParameterDict`, or an iterable of key-value pairs, the order of new elements in it is preserved. Args: parameters (iterable): a mapping (dictionary) from string to :class:`~flow.nn.Parameter`, or an iterable of key-value pairs of type (string, :class:`~flow.nn.Parameter`) """ if not isinstance(parameters, container_abcs.Iterable): raise TypeError( "ParametersDict.update should be called with an " "iterable of key/value pairs, but got " + type(parameters).__name__ ) if isinstance(parameters, (OrderedDict, ParameterDictContainer)): for key, parameter in parameters.items(): self[key] = parameter elif isinstance(parameters, container_abcs.Mapping): for key, parameter in sorted(parameters.items()): self[key] = parameter else: for j, p in enumerate(parameters): if not isinstance(p, container_abcs.Iterable): raise TypeError( "ParameterDict update sequence element " "#" + str(j) + " should be Iterable; is" + type(p).__name__ ) if not len(p) == 2: raise ValueError( "ParameterDict update sequence element " "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" ) # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment self[p[0]] = p[1] # type: ignore[assignment] def extra_repr(self) -> str: child_lines = [] for k, p in self._parameters.items(): size_str = "x".join(str(size) for size in p.size()) device_str = "" if not p.is_cuda else " (GPU {})".format(p.get_device()) parastr = "Parameter containing: [{} of size {}{}]".format( type(p), size_str, device_str ) child_lines.append(" (" + k + "): " + parastr) tmpstr = "\n".join(child_lines) return tmpstr def __call__(self, input): raise RuntimeError("ParameterDict should not be called.") def _replicate_for_data_parallel(self): warnings.warn( "nn.ParameterDict is being used with DataParallel but this is not " "supported. This dict will appear empty for the models replicated " "on each GPU except the original one." ) return super(ParameterDictContainer, self)._replicate_for_data_parallel() return ParameterDictContainer ================================================ FILE: python/oneflow/nn/utils/convert_parameters.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from typing import Iterable, Optional from oneflow.framework.tensor import Tensor def parameters_to_vector(parameters: Iterable[Tensor]) -> Tensor: r"""Convert parameters to one vector The method is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.nn.utils.parameters_to_vector. 
Args: parameters (Iterable[Tensor]): an iterator of Tensors that are the parameters of a model. Returns: The parameters represented by a single vector """ # Flag for the device where the parameter is located param_device = None vec = [] for param in parameters: # Ensure the parameters are located in the same device param_device = _check_param_device(param, param_device) vec.append(param.view(-1)) return flow.cat(vec) def vector_to_parameters(vec: Tensor, parameters: Iterable[Tensor]) -> None: r"""Convert one vector to the parameters The method is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.nn.utils.vector_to_parameters. Args: vec (Tensor): a single vector represents the parameters of a model. parameters (Iterable[Tensor]): an iterator of Tensors that are the parameters of a model. """ # Ensure vec of type Tensor if not isinstance(vec, Tensor): raise TypeError("expected flow.Tensor, but got: {}".format(flow.typename(vec))) # Flag for the device where the parameter is located param_device = None # Pointer for slicing the vector for each parameter pointer = 0 for param in parameters: # Ensure the parameters are located in the same device param_device = _check_param_device(param, param_device) # The length of the parameter num_param = param.numel() # Slice the vector, reshape it, and replace the old data of the parameter param.data = vec[pointer : pointer + num_param].view_as(param).data # Increment the pointer pointer += num_param def _check_param_device(param: Tensor, old_param_device: Optional[int]) -> int: r"""This helper function is to check if the parameters are located in the same device. Currently, the conversion between model parameters and single vector form is not supported for multiple allocations, e.g. parameters in different GPUs, or mixture of CPU/GPU. The method is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.html#utilities. Args: param ([Tensor]): a Tensor of a parameter of a model old_param_device (int): the device where the first parameter of a model is allocated. Returns: old_param_device (int): report device for the first time """ # Meet the first parameter if old_param_device is None: old_param_device = param.get_device() if param.is_cuda else -1 else: warn = False if param.is_cuda: # Check if in same GPU warn = param.get_device() != old_param_device else: # Check if in CPU warn = old_param_device != -1 if warn: raise TypeError( "Found two parameters on different devices, " "this is currently not supported." ) return old_param_device ================================================ FILE: python/oneflow/nn/utils/parameters_grouping.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
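# --- Editor's illustrative sketch (not part of the original source) ---
# Round-tripping a model's parameters through a single flat vector with the
# two helpers defined above. This is the usual pattern for optimizers or
# algorithms that operate on a flattened view of all parameters.
import oneflow as flow

model = flow.nn.Linear(3, 2)
vec = parameters_to_vector(model.parameters())   # weight (2x3) + bias (2) -> shape (8,)
assert vec.numel() == sum(p.numel() for p in model.parameters())

# ... modify the flat vector (e.g. add noise, apply an update) ...
vec = vec + 0.01 * flow.randn(vec.shape[0])

# Write the (possibly modified) vector back into the model's parameters.
vector_to_parameters(vec, model.parameters())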
""" import collections import warnings from typing import Union, List import oneflow as flow from oneflow.framework.tensor import Tensor _tensor_or_tensors = Union[Tensor, List[Tensor], List[List[Tensor]]] def numel_in_bucket(tensor: Tensor): assert flow.is_floating_point(tensor), "params grouping only support float tensor." def align(x: int, unit_size: int): return (x + (unit_size - 1)) // unit_size * unit_size # tensor memory should be align to 512 bytes for cuda operations, # align size depends on floating type return align( tensor.numel(), flow._oneflow_internal.max_alignment_size() // (flow.finfo(tensor.dtype).bits // 8), ) class ContiguousParamsGroup(object): """Arange tensors into contiguous buffer according to their group. Args: params_group_list (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will be made into buffers. group_on_current_buffer (bool, optional): whether to group tensors on allocated buffers. (default: True) Note: The ContiguousParamsGroup is created by 2D List of Tensors, which indicates the Tensors in the same 1D List should be grouped into the same Tensor buffer, otherwise try to make them into a 2D List. If group_on_current_buffer is set True but there is not any buffer created before, ContiguousParamsGroup will allocate default buffers for all parameters. """ def __init__( self, params_group_list: _tensor_or_tensors, group_on_current_buffer: bool = True, ): self.params_group_list = params_group_list.copy() self._make_valid_params_group_list() self._remove_no_grad_tensors() self._check_tensor_position_consistency() self.group_on_current_buffer = group_on_current_buffer self.grouped_tensors = [] self.grouped_grads = [] if not self.group_on_current_buffer: self._parameters_grouping_on_new_buffer() else: self._check_current_buffer() self._parameters_grouping_on_current_buffer() def _make_valid_params_group_list(self): """making params_group_list 2D List of Tensors """ if isinstance(self.params_group_list, Tensor): warnings.warn("Single tensor is best not do grouping.") self.params_group_list = [[self.params_group_list]] elif all([isinstance(p, Tensor) for p in self.params_group_list]): self.params_group_list = [self.params_group_list] elif all( [ all([isinstance(p, Tensor) for p in params]) for params in self.params_group_list ] ): pass else: raise ValueError("The shape of params_group_list is illegal!") def _remove_no_grad_tensors(self): self.params_group_list = [ [p for p in params if p.requires_grad] for params in self.params_group_list ] def _check_tensor_position_consistency(self): if all( [all([p.is_global for p in params]) for params in self.params_group_list] ): self.is_global = True elif all( [all([p.is_local for p in params]) for params in self.params_group_list] ): self.is_global = False else: raise ValueError( "Parameters must be all local tensors or all global tensors for params grouping." ) def _check_current_buffer(self): """If all tensors are not held by any buffer, try to create buffer. """ for params in self.params_group_list: for p in params: if p._ref_tensor is not None: return warnings.warn("create defualt buffer for all parameters as one group.") self._physical_preparation = ContiguousParamsGroup( self.params_group_list, group_on_current_buffer=False, ) def _make_buffer_params_mapping(self): buffer_params_mapping = collections.defaultdict(list) for params in self.params_group_list: for p in params: if p._ref_tensor is not None: assert ( p._ref_index < p._ref_tensor.numel() ), "invalid ref tensor index." 
buffer_params_mapping[p._ref_tensor].append((p._ref_index, p)) for buffer, params_list in buffer_params_mapping.items(): buffer_params_mapping[buffer] = sorted(params_list, key=lambda x: x[0]) return buffer_params_mapping def _parameters_grouping_on_new_buffer(self): # Use the group in params_group_list to create default buffer. # A buffer that is too large will affect the parallelism of different parameters. params_buffer_size = {} physical_params_buffer = {} params_buffer_index = {} for idx, params in enumerate(self.params_group_list): for p in params: if self.is_global: tensor_key = (p.dtype, p.placement, p.sbp, idx) else: tensor_key = (p.dtype, p.device, idx) params_buffer_size[tensor_key] = params_buffer_size.get( tensor_key, 0 ) + numel_in_bucket(p) for tensor_key, buffer_size in params_buffer_size.items(): dtype = tensor_key[0] if self.is_global: placement = tensor_key[1] sbp = tensor_key[2] physical_param_buf = flow.zeros( buffer_size, dtype=dtype, placement=placement, sbp=sbp ) physical_param_buf.grad = flow.zeros( buffer_size, dtype=dtype, placement=placement, sbp=sbp ) else: device = tensor_key[1] physical_param_buf = flow.zeros(buffer_size, dtype=dtype, device=device) physical_param_buf.grad = flow.zeros( buffer_size, dtype=dtype, device=device ) self.grouped_tensors.append(physical_param_buf) self.grouped_grads.append(physical_param_buf.grad) physical_params_buffer[tensor_key] = physical_param_buf params_buffer_index[tensor_key] = 0 for idx, params in enumerate(self.params_group_list): for p in params: if self.is_global: tensor_key = (p.dtype, p.placement, p.sbp, idx) else: tensor_key = (p.dtype, p.device, idx) param_buf = physical_params_buffer[tensor_key] index = params_buffer_index[tensor_key] size = p.numel() shape = p.data.shape assert index + numel_in_bucket(p) <= param_buf.numel() param_buf[index : index + size] = p.data.detach().clone().view(-1) p.data = param_buf[index : index + size].view(shape) p.grad = param_buf.grad[index : index + size].view(shape) p._ref_tensor = param_buf p._ref_index = index index += numel_in_bucket(p) params_buffer_index[tensor_key] = index def _parameters_grouping_on_current_buffer(self): buffer_params_mapping = self._make_buffer_params_mapping() if buffer_params_mapping is None or len(buffer_params_mapping) == 0: warnings.warn( "Since nn.Module didn't use make_contiguous_params_group() to create " "a contiguous module, the remapping won't make any difference for parameters. 
" ) params_group = [] for params in self.params_group_list: group = set() for p in params: group.add(p) params_group.append(group) # handling the parameters already on allocated buffers # try best to make the adjacent tensors on device into same logical buffer for param_buf, params in buffer_params_mapping.items(): logical_buffer_start, logical_buffer_size = 0, 0 pre_group_index = -1 params_cnt = len(params) for p_index, (_, p) in enumerate(params): current_group_index = -1 for group_index, group in enumerate(params_group): if p in group: current_group_index = group_index break if current_group_index == -1: continue params_group[current_group_index].remove(p) def _make_logical_buf(): nonlocal logical_buffer_start, logical_buffer_size nonlocal pre_group_index, current_group_index pre_group_index = current_group_index if logical_buffer_size == 0: return logical_param_buf = param_buf[ logical_buffer_start : logical_buffer_start + logical_buffer_size ].view(logical_buffer_size) logical_param_grad_buf = param_buf.grad[ logical_buffer_start : logical_buffer_start + logical_buffer_size ].view(logical_buffer_size) logical_param_buf.grad = logical_param_grad_buf self.grouped_tensors.append(logical_param_buf) self.grouped_grads.append(logical_param_grad_buf) logical_buffer_start += logical_buffer_size logical_buffer_size = 0 if current_group_index != pre_group_index: _make_logical_buf() logical_buffer_size += numel_in_bucket(p) if p_index == params_cnt - 1: _make_logical_buf() # handling params not on any buffer # however, we don't make new tensors into contiguous buffer this time for group in params_group: for p in group: self.grouped_tensors.append(p) self.grouped_grads.append(p.grad) @property def grouped_parameters(self): """the grouped contiguous parameters """ return self.grouped_tensors @property def grouped_parameters_grad(self): """the grouped contiguous parameters' gradient """ return self.grouped_grads ================================================ FILE: python/oneflow/nn/utils/prune.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r""" Prune Methods are consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/stable/nn.html#module-torch.nn.utils. """ import numbers from abc import ABC, abstractmethod from collections.abc import Iterable from typing import Tuple import numpy as np import oneflow as flow class BasePruningMethod(ABC): r"""Abstract base class for creation of new pruning techniques. Provides a skeleton for customization requiring the overriding of methods such as :meth:`compute_mask` and :meth:`apply`. """ _tensor_name: str def __init__(self): pass def __call__(self, module, inputs): r"""Multiplies the mask (stored in ``module[name + '_mask']``) into the original tensor (stored in ``module[name + '_orig']``) and stores the result into ``module[name]`` by using :meth:`apply_mask`. Args: module (nn.Module): module containing the tensor to prune inputs: not used. 
""" setattr(module, self._tensor_name, self.apply_mask(module)) @abstractmethod def compute_mask(self, t, default_mask): r"""Computes and returns a mask for the input tensor ``t``. Starting from a base ``default_mask`` (which should be a mask of ones if the tensor has not been pruned yet), generate a random mask to apply on top of the ``default_mask`` according to the specific pruning method recipe. Args: t (flow.Tensor): tensor representing the importance scores of the parameter to prune. default_mask (flow.Tensor): Base mask from previous pruning iterations, that need to be respected after the new mask is applied. Same dims as ``t``. Returns: mask (flow.Tensor): mask to apply to ``t``, of same dims as ``t`` """ pass def apply_mask(self, module): r"""Simply handles the multiplication between the parameter being pruned and the generated mask. Fetches the mask and the original tensor from the module and returns the pruned version of the tensor. Args: module (nn.Module): module containing the tensor to prune Returns: pruned_tensor (flow.Tensor): pruned version of the input tensor """ # to carry out the multiplication, the mask needs to have been computed, # so the pruning method must know what tensor it's operating on assert self._tensor_name is not None, "Module {} has to be pruned".format( module ) # this gets set in apply() mask = getattr(module, self._tensor_name + "_mask") orig = getattr(module, self._tensor_name + "_orig") pruned_tensor = mask.to(dtype=orig.dtype) * orig return pruned_tensor @classmethod def apply(cls, module, name, *args, importance_scores=None, **kwargs): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. args: arguments passed on to a subclass of :class:`BasePruningMethod` importance_scores (flow.Tensor): tensor of importance scores (of same shape as module parameter) used to compute mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the parameter will be used in its place. kwargs: keyword arguments passed on to a subclass of a :class:`BasePruningMethod` """ def _get_composite_method(cls, module, name, *args, **kwargs): # Check if a pruning method has already been applied to # `module[name]`. If so, store that in `old_method`. old_method = None found = 0 # there should technically be only 1 hook with hook.name == name # assert this using `found` hooks_to_remove = [] for k, hook in module._forward_pre_hooks.items(): # if it exists, take existing thing, remove hook, then # go through normal thing if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: old_method = hook hooks_to_remove.append(k) found += 1 assert ( found <= 1 ), "Avoid adding multiple pruning hooks to the\ same tensor {} of module {}. Use a PruningContainer.".format( name, module ) for k in hooks_to_remove: del module._forward_pre_hooks[k] # Apply the new pruning method, either from scratch or on top of # the previous one. 
method = cls(*args, **kwargs) # new pruning # Have the pruning method remember what tensor it's been applied to method._tensor_name = name # combine `methods` with `old_method`, if `old_method` exists if old_method is not None: # meaning that there was a hook # if the hook is already a pruning container, just add the # new pruning method to the container if isinstance(old_method, PruningContainer): old_method.add_pruning_method(method) method = old_method # rename old_method --> method # if the hook is simply a single pruning method, create a # container, add the old pruning method and the new one elif isinstance(old_method, BasePruningMethod): container = PruningContainer(old_method) # Have the pruning method remember the name of its tensor # setattr(container, '_tensor_name', name) container.add_pruning_method(method) method = container # rename container --> method return method method = _get_composite_method(cls, module, name, *args, **kwargs) # at this point we have no forward_pre_hooks but we could have an # active reparametrization of the tensor if another pruning method # had been applied (in which case `method` would be a PruningContainer # and not a simple pruning method). # Pruning is to be applied to the module's tensor named `name`, # starting from the state it is found in prior to this iteration of # pruning. The pruning mask is calculated based on importances scores. orig = getattr(module, name) if importance_scores is not None: assert ( importance_scores.shape == orig.shape ), "importance_scores should have the same shape as parameter \ {} of {}".format( name, module ) else: importance_scores = orig # If this is the first time pruning is applied, take care of moving # the original tensor to a new parameter called name + '_orig' and # and deleting the original parameter if not isinstance(method, PruningContainer): # copy `module[name]` to `module[name + '_orig']` module.register_parameter(name + "_orig", orig) # temporarily delete `module[name]` del module._parameters[name] default_mask = flow.ones_like(orig) # temp # If this is not the first time pruning is applied, all of the above # has been done before in a previous pruning iteration, so we're good # to go else: default_mask = getattr(module, name + "_mask").detach().clone() # Use try/except because if anything goes wrong with the mask # computation etc., you'd want to roll back. try: # get the final mask, computed according to the specific method mask = method.compute_mask(importance_scores, default_mask=default_mask) # reparametrize by saving mask to `module[name + '_mask']`... module.register_buffer(name + "_mask", mask) # ... and the new pruned tensor to `module[name]` setattr(module, name, method.apply_mask(module)) # associate the pruning method to the module via a hook to # compute the function before every forward() (compile by run) module.register_forward_pre_hook(method) except Exception as e: if not isinstance(method, PruningContainer): orig = getattr(module, name + "_orig") module.register_parameter(name, orig) del module._parameters[name + "_orig"] raise e return method def prune(self, t, default_mask=None, importance_scores=None): r"""Computes and returns a pruned version of input tensor ``t`` according to the pruning rule specified in :meth:`compute_mask`. Args: t (flow.Tensor): tensor to prune (of same dimensions as ``default_mask``). importance_scores (flow.Tensor): tensor of importance scores (of same shape as ``t``) used to compute mask for pruning ``t``. 
The values in this tensor indicate the importance of the corresponding elements in the ``t`` that is being pruned. If unspecified or None, the tensor ``t`` will be used in its place. default_mask (flow.Tensor, optional): mask from previous pruning iteration, if any. To be considered when determining what portion of the tensor that pruning should act on. If None, default to a mask of ones. Returns: pruned version of tensor ``t``. """ if importance_scores is not None: assert ( importance_scores.shape == t.shape ), "importance_scores should have the same shape as tensor t" else: importance_scores = t default_mask = default_mask if default_mask is not None else flow.ones_like(t) return t * self.compute_mask(importance_scores, default_mask=default_mask) def remove(self, module): r"""Removes the pruning reparameterization from a module. The pruned parameter named ``name`` remains permanently pruned, and the parameter named ``name+'_orig'`` is removed from the parameter list. Similarly, the buffer named ``name+'_mask'`` is removed from the buffers. Note: Pruning itself is NOT undone or reversed! """ # before removing pruning from a tensor, it has to have been applied assert ( self._tensor_name is not None ), "Module {} has to be pruned\ before pruning can be removed".format( module ) # this gets set in apply() # to update module[name] to latest trained weights weight = self.apply_mask(module) # masked weights # delete and reset if hasattr(module, self._tensor_name): delattr(module, self._tensor_name) orig = module._parameters[self._tensor_name + "_orig"] orig.data = weight.data del module._parameters[self._tensor_name + "_orig"] del module._buffers[self._tensor_name + "_mask"] setattr(module, self._tensor_name, orig) class PruningContainer(BasePruningMethod): """Container holding a sequence of pruning methods for iterative pruning. Keeps track of the order in which pruning methods are applied and handles combining successive pruning calls. Accepts as argument an instance of a BasePruningMethod or an iterable of them. """ def __init__(self, *args): self._pruning_methods: Tuple["BasePruningMethod", ...] = tuple() if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name self.add_pruning_method(args) elif len(args) == 1: # only 1 item in a tuple self._tensor_name = args[0]._tensor_name self.add_pruning_method(args[0]) else: # manual construction from list or other iterable (or no args) for method in args: self.add_pruning_method(method) def add_pruning_method(self, method): r"""Adds a child pruning ``method`` to the container. Args: method (subclass of BasePruningMethod): child pruning method to be added to the container. 
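# --- Editor's illustrative sketch (not part of the original source) ---
# A minimal custom pruning method built on the BasePruningMethod machinery
# described above: compute_mask() only has to produce a new mask on top of
# default_mask, and the inherited apply() takes care of registering
# `<name>_orig`, `<name>_mask` and the forward pre-hook. ThresholdPruning is
# a hypothetical example class, not part of OneFlow.
import oneflow as flow

class ThresholdPruning(BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Zero out entries whose magnitude is below the threshold, while
        # keeping anything the previous mask already pruned at zero.
        return default_mask * (t.abs() > self.threshold).to(dtype=default_mask.dtype)

    @classmethod
    def apply(cls, module, name, threshold):
        return super(ThresholdPruning, cls).apply(module, name, threshold=threshold)

layer = flow.nn.Linear(4, 4)
ThresholdPruning.apply(layer, "weight", threshold=0.1)
# layer.weight is now weight_orig * weight_mask, recomputed before each forward.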
""" # check that we're adding a pruning method to the container if not isinstance(method, BasePruningMethod) and method is not None: raise TypeError( "{} is not a BasePruningMethod subclass".format(type(method)) ) elif method is not None and self._tensor_name != method._tensor_name: raise ValueError( "Can only add pruning methods acting on " "the parameter named '{}' to PruningContainer {}.".format( self._tensor_name, self ) + " Found '{}'".format(method._tensor_name) ) # if all checks passed, add to _pruning_methods tuple self._pruning_methods += (method,) # type: ignore[operator] def __len__(self): return len(self._pruning_methods) def __iter__(self): return iter(self._pruning_methods) def __getitem__(self, idx): return self._pruning_methods[idx] def compute_mask(self, t, default_mask): r"""Applies the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``. The new partial mask should be computed on the entries or channels that were not zeroed out by the ``default_mask``. Which portions of the tensor ``t`` the new mask will be calculated from depends on the ``PRUNING_TYPE`` (handled by the type handler): * for 'unstructured', the mask will be computed from the raveled list of nonmasked entries; * for 'structured', the mask will be computed from the nonmasked channels in the tensor; * for 'global', the mask will be computed across all entries. Args: t (flow.Tensor): tensor representing the parameter to prune (of same dimensions as ``default_mask``). default_mask (flow.Tensor): mask from previous pruning iteration. Returns: mask (flow.Tensor): new mask that combines the effects of the ``default_mask`` and the new mask from the current pruning ``method`` (of same dimensions as ``default_mask`` and ``t``). """ def _combine_masks(method, t, mask): r""" Args: method (a BasePruningMethod subclass): pruning method currently being applied. t (flow.Tensor): tensor representing the parameter to prune (of same dimensions as mask). mask (flow.Tensor): mask from previous pruning iteration Returns: new_mask (flow.Tensor): new mask that combines the effects of the old mask and the new mask from the current pruning method (of same dimensions as mask and t). """ new_mask = mask # start off from existing mask new_mask = new_mask.to(dtype=t.dtype) # compute a slice of t onto which the new pruning method will operate if method.PRUNING_TYPE == "unstructured": # prune entries of t where the mask is 1 slc = mask == 1 # for struct pruning, exclude channels that have already been # entirely pruned elif method.PRUNING_TYPE == "structured": if not hasattr(method, "dim"): raise AttributeError( "Pruning methods of PRUNING_TYPE " '"structured" need to have the attribute `dim` defined.' ) # find the channels to keep by removing the ones that have been # zeroed out already (i.e. where sum(entries) == 0) n_dims = t.dim() # "is this a 2D tensor? 3D? ..." dim = method.dim # convert negative indexing if dim < 0: dim = n_dims + dim # if dim is still negative after subtracting it from n_dims if dim < 0: raise IndexError( "Index is out of bounds for tensor with dimensions {}".format( n_dims ) ) # find channels along dim = dim that aren't already tots 0ed out keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 # create slice to identify what to prune slc = [slice(None)] * n_dims slc[dim] = keep_channel elif method.PRUNING_TYPE == "global": n_dims = len(t.shape) # "is this a 2D tensor? 3D? ..." 
slc = [slice(None)] * n_dims else: raise ValueError( "Unrecognized PRUNING_TYPE {}".format(method.PRUNING_TYPE) ) # compute the new mask on the unpruned slice of the tensor t partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) return new_mask method = self._pruning_methods[-1] mask = _combine_masks(method, t, default_mask) return mask class Identity(BasePruningMethod): r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones. """ PRUNING_TYPE = "unstructured" def compute_mask(self, t, default_mask): mask = default_mask return mask @classmethod def apply(cls, module, name): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. """ return super(Identity, cls).apply(module, name) class RandomUnstructured(BasePruningMethod): r"""Prune (currently unpruned) units in a tensor at random. Args: name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. """ PRUNING_TYPE = "unstructured" def __init__(self, amount): # Check range of validity of pruning amount _validate_pruning_amount_init(amount) self.amount = amount def compute_mask(self, t, default_mask): # Check that the amount of units to prune is not > than the number of # parameters in t tensor_size = t.nelement() # Compute number of units to prune: amount if int, # else amount * tensor_size nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) # This should raise an error if the number of units to prune is larger # than the number of units in the tensor _validate_pruning_amount(nparams_toprune, tensor_size) mask = default_mask.clone() if nparams_toprune != 0: # k=0 not supported by flow.kthvalue # prob = flow.rand_like(t) prob = flow.rand(t.size(), dtype=t.dtype, device=t.device) topk = flow.topk(prob.view(-1), k=nparams_toprune) mask.view(-1)[topk.indices] = 0 return mask @classmethod def apply(cls, module, name, amount): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. """ return super(RandomUnstructured, cls).apply(module, name, amount=amount) class L1Unstructured(BasePruningMethod): r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm. Args: amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. 
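# --- Editor's illustrative sketch (not part of the original source) ---
# Applying RandomUnstructured directly through its class-level apply() and
# inspecting the resulting reparametrization (the functional wrappers such as
# random_unstructured() further below do the same thing for you):
import oneflow as flow

layer = flow.nn.Linear(4, 4)
RandomUnstructured.apply(layer, "weight", amount=0.5)

# The original tensor is kept as `weight_orig`, the binary mask as
# `weight_mask`, and `layer.weight` is recomputed as weight_orig * weight_mask
# by the registered forward pre-hook.
assert hasattr(layer, "weight_orig") and hasattr(layer, "weight_mask")
pruned = flow.sum(layer.weight_mask == 0)   # expected: 8 of the 16 entries pruned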
""" PRUNING_TYPE = "unstructured" def __init__(self, amount): # Check range of validity of pruning amount _validate_pruning_amount_init(amount) self.amount = amount def compute_mask(self, t, default_mask): # Check that the amount of units to prune is not > than the number of # parameters in t tensor_size = t.nelement() # Compute number of units to prune: amount if int, # else amount * tensor_size nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) # This should raise an error if the number of units to prune is larger # than the number of units in the tensor _validate_pruning_amount(nparams_toprune, tensor_size) mask = default_mask.clone() if nparams_toprune != 0: # k=0 not supported by flow.kthvalue # largest=True --> top k; largest=False --> bottom k # Prune the smallest k topk = flow.topk(flow.abs(t).view(-1), k=nparams_toprune, largest=False) # topk will have .indices and .values mask.view(-1)[topk.indices] = 0 return mask @classmethod def apply(cls, module, name, amount, importance_scores=None): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. importance_scores (flow.Tensor): tensor of importance scores (of same shape as module parameter) used to compute mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the module parameter will be used in its place. """ return super(L1Unstructured, cls).apply( module, name, amount=amount, importance_scores=importance_scores ) class RandomStructured(BasePruningMethod): r"""Prune entire (currently unpruned) channels in a tensor at random. Args: amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. dim (int, optional): index of the dim along which we define channels to prune. Default: -1. """ PRUNING_TYPE = "structured" def __init__(self, amount, dim=-1): # Check range of validity of amount _validate_pruning_amount_init(amount) self.amount = amount self.dim = dim def compute_mask(self, t, default_mask): r"""Computes and returns a mask for the input tensor ``t``. Starting from a base ``default_mask`` (which should be a mask of ones if the tensor has not been pruned yet), generate a random mask to apply on top of the ``default_mask`` by randomly zeroing out channels along the specified dim of the tensor. Args: t (flow.Tensor): tensor representing the parameter to prune default_mask (flow.Tensor): Base mask from previous pruning iterations, that need to be respected after the new mask is applied. Same dims as ``t``. Returns: mask (flow.Tensor): mask to apply to ``t``, of same dims as ``t`` Raises: IndexError: if ``self.dim >= len(t.shape)`` """ # Check that tensor has structure (i.e. 
more than 1 dimension) such # that the concept of "channels" makes sense _validate_structured_pruning(t) # Check that self.dim is a valid dim to index t, else raise IndexError _validate_pruning_dim(t, self.dim) # Check that the amount of channels to prune is not > than the number of # channels in t along the dim to prune tensor_size = t.shape[self.dim] # Compute number of units to prune: amount if int, # else amount * tensor_size nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) # This should raise an error if the number of units to prune is larger # than the number of units in the tensor _validate_pruning_amount(nparams_toprune, tensor_size) # Compute binary mask by initializing it to all 0s and then filling in # 1s wherever topk.indices indicates, along self.dim. # mask has the same shape as tensor t def make_mask(t, dim, nchannels, nchannels_toprune): # generate a random number in [0, 1] to associate to each channel prob = flow.rand(nchannels) # generate mask for each channel by 0ing out the channels that # got assigned the k = nchannels_toprune lowest values in prob # threshold = flow.kthvalue(prob, k=nchannels_toprune).values # --------------------------------------------------------------- # Oneflow does not support kthvalue, but because the operation of kthvalue is # relatively simple, it is implemented directly in python y, i = flow.sort(prob) threshold = y[nchannels_toprune - 1] # --------------------------------------------------------------- channel_mask = prob > threshold mask = flow.zeros_like(t) slc = [slice(None)] * len(t.shape) slc[dim] = channel_mask mask[slc] = 1 return mask if nparams_toprune == 0: # k=0 not supported by flow.kthvalue mask = default_mask else: # apply the new structured mask on top of prior (potentially # unstructured) mask mask = make_mask(t, self.dim, tensor_size, nparams_toprune) mask *= default_mask.to(dtype=mask.dtype) return mask @classmethod def apply(cls, module, name, amount, dim=-1): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. dim (int, optional): index of the dim along which we define channels to prune. Default: -1. """ return super(RandomStructured, cls).apply(module, name, amount=amount, dim=dim) class LnStructured(BasePruningMethod): r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm. Args: amount (int or float): quantity of channels to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid entries for argument ``p`` in :func:`flow.norm`. dim (int, optional): index of the dim along which we define channels to prune. Default: -1. """ PRUNING_TYPE = "structured" def __init__(self, amount, n, dim=-1): # Check range of validity of amount _validate_pruning_amount_init(amount) self.amount = amount self.n = n self.dim = dim def compute_mask(self, t, default_mask): r"""Computes and returns a mask for the input tensor ``t``. 
Starting from a base ``default_mask`` (which should be a mask of ones if the tensor has not been pruned yet), generate a mask to apply on top of the ``default_mask`` by zeroing out the channels along the specified dim with the lowest L\ ``n``-norm. Args: t (flow.Tensor): tensor representing the parameter to prune default_mask (flow.Tensor): Base mask from previous pruning iterations, that need to be respected after the new mask is applied. Same dims as ``t``. Returns: mask (flow.Tensor): mask to apply to ``t``, of same dims as ``t`` Raises: IndexError: if ``self.dim >= len(t.shape)`` """ # Check that tensor has structure (i.e. more than 1 dimension) such # that the concept of "channels" makes sense _validate_structured_pruning(t) # Check that self.dim is a valid dim to index t, else raise IndexError _validate_pruning_dim(t, self.dim) # Check that the amount of channels to prune is not > than the number of # channels in t along the dim to prune tensor_size = t.shape[self.dim] # Compute number of units to prune: amount if int, # else amount * tensor_size nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) nparams_tokeep = tensor_size - nparams_toprune # This should raise an error if the number of units to prune is larger # than the number of units in the tensor _validate_pruning_amount(nparams_toprune, tensor_size) # Structured pruning prunes entire channels so we need to know the # L_n norm along each channel to then find the topk based on this # metric norm = _compute_norm(t, self.n, self.dim) # largest=True --> top k; largest=False --> bottom k # Keep the largest k channels along dim=self.dim topk = flow.topk(norm, k=nparams_tokeep, largest=True) # topk will have .indices and .values # Compute binary mask by initializing it to all 0s and then filling in # 1s wherever topk.indices indicates, along self.dim. # mask has the same shape as tensor t def make_mask(t, dim, indices): # init mask to 0 mask = flow.zeros_like(t) # e.g.: slc = [None, None, None], if len(t.shape) = 3 slc = [slice(None)] * len(t.shape) # replace a None at position=dim with indices # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3] slc[dim] = indices # use slc to slice mask and replace all its entries with 1s # e.g.: mask[:, :, [0, 2, 3]] = 1 mask[slc] = 1 return mask if nparams_toprune == 0: # k=0 not supported by flow.kthvalue mask = default_mask else: mask = make_mask(t, self.dim, topk.indices) mask *= default_mask.to(dtype=mask.dtype) return mask @classmethod def apply(cls, module, name, amount, n, dim, importance_scores=None): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid entries for argument ``p`` in :func:`flow.norm`. dim (int): index of the dim along which we define channels to prune. importance_scores (flow.Tensor): tensor of importance scores (of same shape as module parameter) used to compute mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. 
If unspecified or None, the module parameter will be used in its place. """ return super(LnStructured, cls).apply( module, name, amount=amount, n=n, dim=dim, importance_scores=importance_scores, ) class CustomFromMask(BasePruningMethod): PRUNING_TYPE = "global" def __init__(self, mask): self.mask = mask def compute_mask(self, t, default_mask): assert default_mask.shape == self.mask.shape mask = default_mask * self.mask.to(dtype=default_mask.dtype) return mask @classmethod def apply(cls, module, name, mask): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. """ return super(CustomFromMask, cls).apply(module, name, mask=mask) def identity(module, name): r"""Applies pruning reparametrization to the tensor corresponding to the parameter called ``name`` in ``module`` without actually pruning any units. Modifies module in place (and also return the modified module) by: 1) adding a named buffer called ``name+'_mask'`` corresponding to the binary mask applied to the parameter ``name`` by the pruning method. 2) replacing the parameter ``name`` by its pruned version, while the original (unpruned) parameter is stored in a new parameter named ``name+'_orig'``. Note: The mask is a tensor of ones. Args: module (nn.Module): module containing the tensor to prune. name (str): parameter name within ``module`` on which pruning will act. Returns: module (nn.Module): modified (i.e. pruned) version of the input module Examples: >>> # xdoctest: +SKIP >>> m = prune.identity(nn.Linear(2, 3), 'bias') >>> print(m.bias_mask) tensor([1., 1., 1.]) """ Identity.apply(module, name) return module def random_unstructured(module, name, amount): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified ``amount`` of (currently unpruned) units selected at random. Modifies module in place (and also return the modified module) by: 1) adding a named buffer called ``name+'_mask'`` corresponding to the binary mask applied to the parameter ``name`` by the pruning method. 2) replacing the parameter ``name`` by its pruned version, while the original (unpruned) parameter is stored in a new parameter named ``name+'_orig'``. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. Returns: module (nn.Module): modified (i.e. pruned) version of the input module Examples: >>> # xdoctest: +SKIP >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1) >>> flow.sum(m.weight_mask == 0) tensor(1) """ RandomUnstructured.apply(module, name, amount) return module def l1_unstructured(module, name, amount, importance_scores=None): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified `amount` of (currently unpruned) units with the lowest L1-norm. Modifies module in place (and also return the modified module) by: 1) adding a named buffer called ``name+'_mask'`` corresponding to the binary mask applied to the parameter ``name`` by the pruning method. 
2) replacing the parameter ``name`` by its pruned version, while the original (unpruned) parameter is stored in a new parameter named ``name+'_orig'``. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. importance_scores (flow.Tensor): tensor of importance scores (of same shape as module parameter) used to compute mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the module parameter will be used in its place. Returns: module (nn.Module): modified (i.e. pruned) version of the input module Examples: >>> # xdoctest: +SKIP >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2) >>> m.state_dict().keys() odict_keys(['bias', 'weight_orig', 'weight_mask']) """ L1Unstructured.apply( module, name, amount=amount, importance_scores=importance_scores ) return module def random_structured(module, name, amount, dim): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified ``amount`` of (currently unpruned) channels along the specified ``dim`` selected at random. Modifies module in place (and also return the modified module) by: 1) adding a named buffer called ``name+'_mask'`` corresponding to the binary mask applied to the parameter ``name`` by the pruning method. 2) replacing the parameter ``name`` by its pruned version, while the original (unpruned) parameter is stored in a new parameter named ``name+'_orig'``. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. dim (int): index of the dim along which we define channels to prune. Returns: module (nn.Module): modified (i.e. pruned) version of the input module Examples: >>> # xdoctest: +SKIP >>> m = prune.random_structured( ... nn.Linear(5, 3), 'weight', amount=3, dim=1 ... ) >>> columns_pruned = int(sum(flow.sum(m.weight, dim=0) == 0)) >>> print(columns_pruned) 3 """ RandomStructured.apply(module, name, amount, dim) return module def ln_structured(module, name, amount, n, dim, importance_scores=None): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified ``amount`` of (currently unpruned) channels along the specified ``dim`` with the lowest L\ ``n``-norm. Modifies module in place (and also return the modified module) by: 1) adding a named buffer called ``name+'_mask'`` corresponding to the binary mask applied to the parameter ``name`` by the pruning method. 2) replacing the parameter ``name`` by its pruned version, while the original (unpruned) parameter is stored in a new parameter named ``name+'_orig'``. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. amount (int or float): quantity of parameters to prune. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. 
If ``int``, it represents the absolute number of parameters to prune. n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid entries for argument ``p`` in :func:`flow.norm`. dim (int): index of the dim along which we define channels to prune. importance_scores (flow.Tensor): tensor of importance scores (of same shape as module parameter) used to compute mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the module parameter will be used in its place. Returns: module (nn.Module): modified (i.e. pruned) version of the input module Examples: >>> # xdoctest: +SKIP >>> m = prune.ln_structured( ... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') ... ) """ LnStructured.apply( module, name, amount, n, dim, importance_scores=importance_scores ) return module def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): r""" Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. Modifies modules in place by: 1) adding a named buffer called ``name+'_mask'`` corresponding to the binary mask applied to the parameter ``name`` by the pruning method. 2) replacing the parameter ``name`` by its pruned version, while the original (unpruned) parameter is stored in a new parameter named ``name+'_orig'``. Args: parameters (Iterable of (module, name) tuples): parameters of the model to prune in a global fashion, i.e. by aggregating all weights prior to deciding which ones to prune. module must be of type :class:`nn.Module`, and name must be a string. pruning_method (function): a valid pruning function from this module, or a custom one implemented by the user that satisfies the implementation guidelines and has ``PRUNING_TYPE='unstructured'``. importance_scores (dict): a dictionary mapping (module, name) tuples to the corresponding parameter's importance scores tensor. The tensor should be the same shape as the parameter, and is used for computing mask for pruning. If unspecified or None, the parameter will be used in place of its importance scores. kwargs: other keyword arguments such as: amount (int or float): quantity of parameters to prune across the specified parameters. If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. Raises: TypeError: if ``PRUNING_TYPE != 'unstructured'`` Note: Since global structured pruning doesn't make much sense unless the norm is normalized by the size of the parameter, we now limit the scope of global pruning to unstructured methods. Examples: >>> # xdoctest: +SKIP >>> net = nn.Sequential(OrderedDict([ ... ('first', nn.Linear(10, 4)), ... ('second', nn.Linear(4, 1)), ... ])) >>> parameters_to_prune = ( ... (net.first, 'weight'), ... (net.second, 'weight'), ... ) >>> prune.global_unstructured( ... parameters_to_prune, ... pruning_method=prune.L1Unstructured, ... amount=10, ... 
) >>> print(sum(flow.nn.utils.parameters_to_vector(net.buffers()) == 0)) tensor(10, dtype=flow.uint8) """ # ensure parameters is a list or generator of tuples if not isinstance(parameters, Iterable): raise TypeError("global_unstructured(): parameters is not an Iterable") importance_scores = importance_scores if importance_scores is not None else {} if not isinstance(importance_scores, dict): raise TypeError("global_unstructured(): importance_scores must be of type dict") # flatten importance scores to consider them all at once in global pruning relevant_importance_scores = flow.nn.utils.parameters_to_vector( [ importance_scores.get((module, name), getattr(module, name)) for (module, name) in parameters ] ) # similarly, flatten the masks (if they exist), or use a flattened vector # of 1s of the same dimensions as t default_mask = flow.nn.utils.parameters_to_vector( [ getattr(module, name + "_mask", flow.ones_like(getattr(module, name))) for (module, name) in parameters ] ) # use the canonical pruning methods to compute the new mask, even if the # parameter is now a flattened out version of `parameters` container = PruningContainer() container._tensor_name = "temp" # to make it match that of `method` method = pruning_method(**kwargs) method._tensor_name = "temp" # to make it match that of `container` if method.PRUNING_TYPE != "unstructured": raise TypeError( 'Only "unstructured" PRUNING_TYPE supported for ' "the `pruning_method`. Found method {} of type {}".format( pruning_method, method.PRUNING_TYPE ) ) container.add_pruning_method(method) # use the `compute_mask` method from `PruningContainer` to combine the # mask computed by the new method with the pre-existing mask final_mask = container.compute_mask(relevant_importance_scores, default_mask) # Pointer for slicing the mask to match the shape of each parameter pointer = 0 for module, name in parameters: param = getattr(module, name) # The length of the parameter num_param = param.numel() # Slice the mask, reshape it param_mask = final_mask[pointer : pointer + num_param].view_as(param) # Assign the correct pre-computed mask to each parameter and add it # to the forward_pre_hooks like any other pruning method custom_from_mask(module, name, mask=param_mask) # Increment the pointer to continue slicing the final_mask pointer += num_param def custom_from_mask(module, name, mask): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``. Modifies module in place (and also return the modified module) by: 1) adding a named buffer called ``name+'_mask'`` corresponding to the binary mask applied to the parameter ``name`` by the pruning method. 2) replacing the parameter ``name`` by its pruned version, while the original (unpruned) parameter is stored in a new parameter named ``name+'_orig'``. Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. mask (Tensor): binary mask to be applied to the parameter. Returns: module (nn.Module): modified (i.e. pruned) version of the input module Examples: >>> # xdoctest: +SKIP >>> m = prune.custom_from_mask( ... nn.Linear(5, 3), name='bias', mask=flow.tensor([0, 1, 0]) ... ) >>> print(m.bias_mask) tensor([0., 1., 0.]) """ CustomFromMask.apply(module, name, mask) return module def remove(module, name): r"""Removes the pruning reparameterization from a module and the pruning method from the forward hook. 
The pruned parameter named ``name`` remains permanently pruned, and the parameter named ``name+'_orig'`` is removed from the parameter list. Similarly, the buffer named ``name+'_mask'`` is removed from the buffers. Note: Pruning itself is NOT undone or reversed! Args: module (nn.Module): module containing the tensor to prune name (str): parameter name within ``module`` on which pruning will act. Examples: >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2) >>> m = remove(m, name='weight') """ for k, hook in module._forward_pre_hooks.items(): if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: hook.remove(module) del module._forward_pre_hooks[k] return module raise ValueError( "Parameter '{}' of module {} has to be pruned " "before pruning can be removed".format(name, module) ) def is_pruned(module): r"""Check whether ``module`` is pruned by looking for ``forward_pre_hooks`` in its modules that inherit from the :class:`BasePruningMethod`. Args: module (nn.Module): object that is either pruned or unpruned Returns: binary answer to whether ``module`` is pruned. Examples: >>> m = nn.Linear(5, 7) >>> # xdoctest: +SKIP >>> print(prune.is_pruned(m)) False >>> prune.random_unstructured(m, name='weight', amount=0.2) >>> print(prune.is_pruned(m)) True """ for _, submodule in module.named_modules(): for _, hook in submodule._forward_pre_hooks.items(): if isinstance(hook, BasePruningMethod): return True return False def _validate_pruning_amount_init(amount): r"""Validation helper to check the range of amount at init. Args: amount (int or float): quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune. Raises: ValueError: if amount is a float not in [0, 1], or if it's a negative integer. TypeError: if amount is neither a float nor an integer. Note: This does not take into account the number of parameters in the tensor to be pruned, which is known only at prune. """ if not isinstance(amount, numbers.Real): raise TypeError( "Invalid type for amount: {}. Must be int or float." "".format(amount) ) if (isinstance(amount, numbers.Integral) and amount < 0) or ( not isinstance(amount, numbers.Integral) # so it's a float and (float(amount) > 1.0 or float(amount) < 0.0) ): raise ValueError( "amount={} should either be a float in the " "range [0, 1] or a non-negative integer" "".format(amount) ) def _validate_pruning_amount(amount, tensor_size): r"""Validation helper to check that the amount of parameters to prune is meaningful wrt to the size of the data (`tensor_size`). Args: amount (int or float): quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune. tensor_size (int): absolute number of parameters in the tensor to prune. """ # TODO: consider removing this check and allowing users to specify # a number of units to prune that is greater than the number of units # left to prune. In this case, the tensor will just be fully pruned. if isinstance(amount, numbers.Integral) and amount > tensor_size: raise ValueError( "amount={} should be smaller than the number of " "parameters to prune={}".format(amount, tensor_size) ) def _validate_structured_pruning(t): r"""Validation helper to check that the tensor to be pruned is multi- dimensional, such that the concept of "channels" is well-defined. 
Args: t (flow.Tensor): tensor representing the parameter to prune Raises: ValueError: if the tensor `t` is not at least 2D. """ shape = t.shape if len(shape) <= 1: raise ValueError( "Structured pruning can only be applied to " "multidimensional tensors. Found tensor of shape " "{} with {} dims".format(shape, len(shape)) ) def _compute_nparams_toprune(amount, tensor_size): r"""Since amount can be expressed either in absolute value or as a percentage of the number of units/channels in a tensor, this utility function converts the percentage to absolute value to standardize the handling of pruning. Args: amount (int or float): quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune. tensor_size (int): absolute number of parameters in the tensor to prune. Returns: int: the number of units to prune in the tensor """ # incorrect type already checked in _validate_pruning_amount_init if isinstance(amount, numbers.Integral): return amount else: return round(amount * tensor_size) def _validate_pruning_dim(t, dim): r""" Args: t (flow.Tensor): tensor representing the parameter to prune dim (int): index of the dim along which we define channels to prune """ if dim >= t.dim(): raise IndexError("Invalid index {} for tensor of size {}".format(dim, t.shape)) def _compute_norm(t, n, dim): r"""Compute the L_n-norm across all entries in tensor `t` along all dimension except for the one identified by dim. Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim), then norm will have Size [4], and each entry will represent the `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels. Args: t (flow.Tensor): tensor representing the parameter to prune n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid entries for argument p in flow.norm dim (int): dim identifying the channels to prune Returns: norm (flow.Tensor): L_n norm computed across all dimensions except for `dim`. By construction, `norm.shape = t.shape[-1]`. """ # dims = all axes, except for the one identified by `dim` dims = list(range(t.dim())) # convert negative indexing if dim < 0: dim = dims[dim] dims.remove(dim) # norm = flow.norm(t, p=n, dim=dims) # torch.norm in pytorch can support the norm of multi-dimensional arrays, # but the norm of the oneflow version only supports the norm of two-dimensional # arrays. So we need to reshape tensor into two-dimensional tensor. The dim of 1 # represent the dims to compute norm. a = t.clone() fullDims = list(range(a.dim())) retainedDims = list(set(fullDims).difference(set(dims))) permute_order = retainedDims + dims reshape_size = 1 for item in retainedDims: reshape_size *= a.shape[item] a = a.permute(permute_order) a = a.reshape(reshape_size, -1) norm = flow.norm(a, p=n, dim=1) return norm ================================================ FILE: python/oneflow/nn/utils/rnn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ from collections import namedtuple from typing import List, Tuple, Union, Iterable, Optional import warnings import oneflow as flow from oneflow.framework.tensor import Tensor # The implementation of rnn util is modified from: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/rnn.py def bind(optional, fn): if optional is None: return None return fn(optional) def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]: if permutation is None: return None return flow.scatter( flow.zeros_like(permutation), 0, permutation, flow.arange( 0, permutation.numel(), device=permutation.device, dtype=flow.int32 ), ) class PackedSequence(object): """The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.utils.rnn.PackedSequence.html. Holds the data and list of :attr:`batch_sizes` of a packed sequence. All RNN modules accept packed sequences as inputs. Note: Instances of this class should never be created manually. They are meant to be instantiated by functions like :func:`pack_padded_sequence`. Batch sizes represent the number of elements at each sequence step in the batch, not the varying sequence lengths passed to :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x`` the :class:`PackedSequence` would contain data ``axbc`` with ``batch_sizes=[2,1,1]``. Attributes: data (Tensor): Tensor containing packed sequence batch_sizes (Tensor): Tensor of integers holding information about the batch size at each sequence step sorted_indices (Tensor, optional): Tensor of integers holding how this :class:`PackedSequence` is constructed from sequences. unsorted_indices (Tensor, optional): Tensor of integers holding how to recover the original sequences with correct order. .. note:: :attr:`data` can be on arbitrary device and of arbitrary dtype. :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``oneflow.int64`` tensors on the same device as :attr:`data`. However, :attr:`batch_sizes` should always be a CPU ``oneflow.int64`` tensor. This invariant is maintained throughout the :class:`PackedSequence` class, and all functions that construct a :class:`PackedSequence` only pass in tensors conforming to this constraint. """ def __init__( self, data: Tensor, batch_sizes: Optional[Tensor] = None, sorted_indices: Optional[Tensor] = None, unsorted_indices: Optional[Tensor] = None, ): self.sorted_indices = sorted_indices if unsorted_indices is None: unsorted_indices = invert_permutation(sorted_indices) self.unsorted_indices = unsorted_indices if batch_sizes is not None: if batch_sizes.device.type != "cpu": raise ValueError( "batch_sizes should always be on CPU. " "Instances of PackedSequence should never be created manually.
" "They should be instantiated by functions like pack_sequence " "and pack_padded_sequences in nn.rnn_utils " ) self.data = data self.batch_sizes = batch_sizes else: assert isinstance(data, (list, tuple)) and len(data) == 2 self.data = data[0] self.batch_sizes = data[1] def pin_memory(self): return PackedSequence( self.data.pin_memory(), self.batch_sizes, bind(self.sorted_indices, lambda t: t.pin_memory()), bind(self.unsorted_indices, lambda t: t.pin_memory()), ) def cuda(self, *args, **kwargs): ex = flow.tensor((), dtype=self.data.dtype, device=self.data.device).to( *args, **kwargs ) if ex.is_cuda: return self.to(*args, **kwargs) return self.to(*args, device="cuda", **kwargs) def cpu(self, *args, **kwargs): ex = flow.tensor((), dtype=self.data.dtype, device=self.data.device).to( *args, **kwargs ) if ex.device.type == "cpu": return self.to(*args, **kwargs) return self.to(*args, device="cpu", **kwargs) def double(self): return self.to(dtype=flow.double) def float(self): return self.to(dtype=flow.float) def half(self): return self.to(dtype=flow.half) def long(self): return self.to(dtype=flow.long) def int(self): return self.to(dtype=flow.int) def short(self): return self.to(dtype=flow.short) def char(self): return self.to(dtype=flow.int8) def byte(self): return self.to(dtype=flow.uint8) def to(self, *args, **kwargs): """Performs dtype and/or device conversion on `self.data`. It has similar signature as :meth:`oneflow.Tensor.to`, except optional arguments like `non_blocking` and `copy` should be passed as kwargs, not args, or they will not apply to the index tensors. .. note:: If the ``self.data`` Tensor already has the correct :class:`oneflow.dtype` and :class:`oneflow.device`, then ``self`` is returned. Otherwise, returns a copy with the desired configuration. """ data = self.data.to(*args, **kwargs) if data is self.data: return self else: kwargs = { k: v for k, v in filter( lambda t: t[0] != "device" and t[0] != "dtype", kwargs.items() ) } sorted_indices = bind( self.sorted_indices, lambda t: t.to(data.device, **kwargs) ) unsorted_indices = bind( self.unsorted_indices, lambda t: t.to(data.device, **kwargs) ) return PackedSequence( data, self.batch_sizes, sorted_indices, unsorted_indices ) @property def is_cuda(self): r"""Returns true if `self.data` stored on a gpu""" return self.data.is_cuda def is_pinned(self): r"""Returns true if `self.data` stored on in pinned memory""" return self.data.is_pinned() def pack_padded_sequence( input: Tensor, lengths: Tensor, batch_first: bool = False, enforce_sorted: bool = True, ) -> PackedSequence: """The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.utils.rnn.pack_padded_sequence.html. Packs a Tensor containing padded sequences of variable length. :attr:`input` can be of size ``T x B x *`` where `T` is the length of the longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and ``*`` is any number of dimensions (including 0). If ``batch_first`` is ``True``, ``B x T x *`` :attr:`input` is expected. For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is ``True``, the sequences should be sorted by length in a decreasing order, i.e. ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest one. `enforce_sorted = True` is only necessary for ONNX export. Note: This function accepts any input that has at least two dimensions. 
You can apply it to pack the labels, and use the output of the RNN with them to compute the loss directly. A Tensor can be retrieved from a :class:`PackedSequence` object by accessing its ``.data`` attribute. Args: input (Tensor): padded batch of variable length sequences. lengths (Tensor or list(int)): list of sequence lengths of each batch element (must be on the CPU if provided as a tensor). batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *`` format. enforce_sorted (bool, optional): if ``True``, the input is expected to contain sequences sorted by length in a decreasing order. If ``False``, the input will get sorted unconditionally. Default: ``True``. Returns: a :class:`PackedSequence` object """ lengths = flow.as_tensor(lengths, dtype=flow.int64) assert ( enforce_sorted == True ), "Only support enforce_sorted == True for now. Plesase Sort the input by length in a decreasing order." if enforce_sorted: sorted_indices = None else: lengths, sorted_indices = flow.sort(lengths, descending=True) sorted_indices = sorted_indices.to(input.device) batch_dim = 0 if batch_first else 1 input = input.index_select(batch_dim, sorted_indices) data, batch_sizes = flow._C.pack_padded_sequence(input, lengths, batch_first) return PackedSequence(data, batch_sizes, sorted_indices, None) def pad_packed_sequence( sequence: PackedSequence, batch_first: bool = False, padding_value: float = 0.0, total_length: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: """The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.utils.rnn.pad_packed_sequence.html. Pads a packed batch of variable length sequences. It is an inverse operation to :func:`pack_padded_sequence`. The returned Tensor's data will be of size ``T x B x *``, where `T` is the length of the longest sequence and `B` is the batch size. If ``batch_first`` is True, the data will be transposed into ``B x T x *`` format. .. note:: :attr:`total_length` is useful to implement the ``pack sequence -> recurrent network -> unpack sequence`` pattern in a :class:`~oneflow.nn.Module` wrapped in :class:`~oneflow.nn.DataParallel`. Args: sequence (PackedSequence): batch to pad batch_first (bool, optional): if ``True``, the output will be in ``B x T x *`` format. padding_value (float, optional): values for padded elements. total_length (int, optional): if not ``None``, the output will be padded to have length :attr:`total_length`. This method will throw :class:`ValueError` if :attr:`total_length` is less than the max sequence length in :attr:`sequence`. Returns: Tuple of Tensor containing the padded sequence, and a Tensor containing the list of lengths of each sequence in the batch. Batch elements will be re-ordered as they were ordered originally when the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``. For example: .. 
code-block:: python >>> from oneflow.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence >>> import oneflow as flow >>> seq = flow.tensor([[4,5,6], [1,2,0], [3,0,0]]) >>> lens = [3, 2, 1] >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=True) >>> packed.data tensor([4, 1, 3, 5, 2, 6], dtype=oneflow.int64) >>> packed.batch_sizes tensor([3, 2, 1], dtype=oneflow.int64) >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True) >>> seq_unpacked tensor([[4, 5, 6], [1, 2, 0], [3, 0, 0]], dtype=oneflow.int64) >>> lens_unpacked tensor([3., 2., 1.], dtype=oneflow.float32) """ max_seq_length = sequence.batch_sizes.shape[0] if total_length is not None: if total_length < max_seq_length: raise ValueError( "Expected total_length to be at least the length " "of the longest sequence in input, but got " "total_length={} and max sequence length being {}".format( total_length, max_seq_length ) ) else: total_length = max_seq_length batch_sizes_t = sequence.batch_sizes.contiguous() assert ( len(batch_sizes_t.shape) == 1 and batch_sizes_t.device.type == "cpu" and batch_sizes_t.dtype == flow.int64 ), f"'sequence.batch_sizes' should be a 1D CPU int64 tensor, but got {len(batch_sizes_t.shape)} D {batch_sizes_t.device.type} {batch_sizes_t.dtype} tensor" batch_sizes = batch_sizes_t.numpy() max_batch_size = int(batch_sizes[0]) max_real_seq_length = batch_sizes_t.shape[0] max_seq_length = max_real_seq_length if total_length > 0: assert ( total_length >= max_seq_length ), f"Expected total_length to be at least the length of the longest sequence in input, but got total_length={total_length} and max sequence length being {max_seq_length}" max_seq_length = total_length output_size = [] # == [max_seq_length, max_batch_size, *sequence.data.size()[1:]] output_size.append(max_seq_length) output_size.append(max_batch_size) output_size = output_size + list(sequence.data.shape[1:]) padded_output = flow.full( output_size, padding_value, dtype=sequence.data.dtype, device=sequence.data.device, requires_grad=sequence.data.requires_grad, ) # `padded_output` is leaf tensor which needs to be transformed into non-leaf tensor # when it requires grad by calling the `clone` method before the following # in-place operation to avoid runtime check error . if padded_output.requires_grad == True: padded_output = padded_output.clone() # This will be modified at every iteration, but we reserve memory for it now. 
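# The copy loop below walks the packed data one time step at a time: whenever
# the batch size drops (one or more sequences just ended), the contiguous slab
# of packed entries accumulated since the previous change is reshaped and
# copied into `padded_output`, and the lengths of the sequences that just
# finished are recorded, filling `lengths` from the back since shorter
# sequences finish first.
# E.g. with batch_sizes=[3, 2, 1]: rows [0:3) fill step 0 of all 3 sequences,
# rows [3:5) fill step 1 of the first 2, rows [5:6) fill step 2 of the first
# one, and `lengths` ends up as [3, 2, 1].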
tmp_view_size = output_size # == [-1, -1, *sequence.data.size()[1:]] lengths = flow.empty(max_batch_size) data_offset = 0 prev_batch_size = max_batch_size prev_i = 0 lengths_idx = max_batch_size - 1 for i in range(max_real_seq_length + 1): batch_size = batch_sizes[i] if i != max_real_seq_length else 0 if batch_size != prev_batch_size: l = prev_batch_size * (i - prev_i) tmp_view_size[0] = i - prev_i tmp_view_size[1] = prev_batch_size padded_output[prev_i:i, 0:prev_batch_size] = sequence.data[ data_offset : data_offset + l ].view(tmp_view_size) data_offset += l prev_i = i dec = prev_batch_size - batch_size if dec > 0: for j in range(dec): lengths[lengths_idx] = i lengths_idx = lengths_idx - 1 prev_batch_size = batch_size if batch_first: permute_dims = [1, 0] for i in range(2, padded_output.ndim): permute_dims.append(i) padded_output = padded_output.permute(permute_dims) unsorted_indices = sequence.unsorted_indices if unsorted_indices is not None: batch_dim = 0 if batch_first else 1 return ( padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices], ) return padded_output, lengths def pad_sequence( sequences: Union[Tensor, List[Tensor]], batch_first: bool = False, padding_value: float = 0.0, ) -> Tensor: """The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.utils.rnn.pad_sequence.html. Pad a list of variable length Tensors with ``padding_value`` ``pad_sequence`` stacks a list of Tensors along a new dimension, and pads them to equal length. For example, if the input is list of sequences with size ``L x *`` and if batch_first is False, and ``T x B x *`` otherwise. `B` is batch size. It is equal to the number of elements in ``sequences``. `T` is length of the longest sequence. `L` is length of the sequence. `*` is any number of trailing dimensions, including none. Note: This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` where `T` is the length of the longest sequence. This function assumes trailing dimensions and type of all the Tensors in sequences are same. Args: sequences (list[Tensor]): list of variable length sequences. batch_first (bool, optional): output will be in ``B x T x *`` if True, or in ``T x B x *`` otherwise. Default: False. padding_value (float, optional): value for padded elements. Default: 0. Returns: Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. Tensor of size ``B x T x *`` otherwise For example: .. 
code-block:: python >>> from oneflow.nn.utils.rnn import pad_sequence >>> import oneflow as flow >>> a = flow.ones(25, 300) >>> b = flow.ones(22, 300) >>> c = flow.ones(15, 300) >>> out = pad_sequence([a, b, c]) >>> out.size() oneflow.Size([25, 3, 300]) """ if isinstance(sequences, Tensor): sequences = sequences.unbind(0) # assuming trailing dimensions and type of all the Tensors # in sequences are same and fetching those from sequences[0] sequences_size = len(sequences) max_size = sequences[0].shape trailing_dims = max_size[1:] lens = [seq.shape[0] for seq in sequences] lens.sort(reverse=True) max_len = lens[0] out_dims = [sequences_size, max_len] if batch_first else [max_len, sequences_size] out_dims = out_dims + list(trailing_dims) out = flow.full( out_dims, padding_value, dtype=sequences[0].dtype, device=sequences[0].device, requires_grad=sequences[0].requires_grad, ) for i in range(sequences_size): currseq = sequences[i] length_i = currseq.shape[0] # use index notation to prevent duplicate references to the tensor if batch_first: out[i, 0:length_i] = currseq else: out[0:length_i, i] = currseq return out def unpad_sequence( padded_sequences: Tensor, lengths: Tensor, batch_first: bool = False, ) -> List[Tensor]: """ Unpad padded Tensor into a list of variable length Tensors ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors. Args: padded_sequences (Tensor): padded sequences. lengths (Tensor): length of original (unpadded) sequences. batch_first (bool, optional): whether batch dimension first or not. Default: False. Returns: a list of :class:`Tensor` objects For example: .. code-block:: python >>> from oneflow.nn.utils.rnn import pad_sequence, unpad_sequence >>> import oneflow as flow >>> import numpy as np >>> a = flow.ones(25, 300) >>> b = flow.ones(22, 300) >>> c = flow.ones(15, 300) >>> sequences = [a, b, c] >>> padded_sequences = pad_sequence(sequences) >>> lengths = flow.as_tensor([v.size(0) for v in sequences]) >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths) >>> np.allclose(sequences[0].numpy(), unpadded_sequences[0].numpy()) True >>> np.allclose(sequences[1].numpy(), unpadded_sequences[1].numpy()) True >>> np.allclose(sequences[2].numpy(), unpadded_sequences[2].numpy()) True """ unpadded_sequences = [] if not batch_first: padded_sequences = padded_sequences.permute((1, 0, 2)) max_length = padded_sequences.shape[1] idx = flow.arange(max_length) for seq, length in zip(padded_sequences, lengths): mask = idx < length unpacked_seq = seq[mask] unpadded_sequences.append(unpacked_seq) return unpadded_sequences def pack_sequence( sequences: List[Tensor], enforce_sorted: bool = True ) -> PackedSequence: """Packs a list of variable length Tensors Consecutive call of the next functions: ``pad_sequence``, ``pack_padded_sequence``. ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is the length of a sequence and `*` is any number of trailing dimensions, including zero. For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted`` is ``True``, the sequences should be sorted in the order of decreasing length. ``enforce_sorted = True`` is only necessary for ONNX export. Args: sequences (list[Tensor]): A list of sequences of decreasing length. enforce_sorted (bool, optional): if ``True``, checks that the input contains sequences sorted by length in a decreasing order. If ``False``, this condition is not checked. Default: ``True``. Returns: a :class:`PackedSequence` object For example: .. 
code-block:: python >>> from oneflow.nn.utils.rnn import pack_sequence >>> import oneflow as flow >>> a = flow.tensor([1,2,3]) >>> b = flow.tensor([4,5]) >>> c = flow.tensor([6]) >>> packed = pack_sequence([a, b, c]) >>> packed.data tensor([1, 4, 6, 2, 5, 3], dtype=oneflow.int64) >>> packed.batch_sizes tensor([3, 2, 1], dtype=oneflow.int64) """ lengths = flow.as_tensor([v.size(0) for v in sequences]) return pack_padded_sequence( pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted ) def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]: """Unpacks PackedSequence into a list of variable length Tensors ``packed_sequences`` should be a PackedSequence object. Args: packed_sequences (PackedSequence): A PackedSequence object. Returns: a list of :class:`Tensor` objects For example: .. code-block:: python >>> from oneflow.nn.utils.rnn import pack_sequence, unpack_sequence >>> import oneflow as flow >>> a = flow.tensor([1,2,3]) >>> b = flow.tensor([4,5]) >>> c = flow.tensor([6]) >>> sequences = [a, b, c] >>> packed_sequences = pack_sequence(sequences) >>> packed_sequences.data tensor([1, 4, 6, 2, 5, 3], dtype=oneflow.int64) >>> packed_sequences.batch_sizes tensor([3, 2, 1], dtype=oneflow.int64) >>> unpacked_sequences = unpack_sequence(packed_sequences) >>> unpacked_sequences [tensor([1, 2, 3], dtype=oneflow.int64), tensor([4, 5], dtype=oneflow.int64), tensor([6], dtype=oneflow.int64)] """ padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True) unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True) return unpacked_sequences if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/nn/utils/skip_init.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import inspect from oneflow.nn.modules.module import Module def skip_init(module_cls, *args, **kwargs): if not issubclass(module_cls, Module): raise RuntimeError("Expected a Module; got {}".format(module_cls)) if "device" not in inspect.signature(module_cls).parameters: raise RuntimeError("Module must support a 'device' arg to skip initialization") final_device = kwargs.pop("device", "cpu") kwargs["device"] = "meta" return module_cls(*args, **kwargs).to_empty(device=final_device) ================================================ FILE: python/oneflow/nn/utils/weight_norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow from oneflow.framework.tensor import Tensor from typing import Any, TypeVar from oneflow.nn.modules.module import Module def _norm_except_dim_0(v: Tensor): output_size = [1] * v.dim() output_size[0] = v.size(0) return flow.linalg.norm(v.view(v.size(0), -1), ord=2, dim=1).view(*output_size) def _norm_except_dim(v: Tensor, dim: int): assert -v.dim() <= dim <= v.dim() - 1, "dim out of range" if dim == -1: return flow.linalg.norm(v, ord="fro") elif dim == 0: return _norm_except_dim_0(v) elif dim == v.dim() - 1: output_size = [1] * v.dim() output_size[v.dim() - 1] = v.size(v.dim() - 1) return flow.linalg.norm(v.view(-1, v.size(v.dim() - 1)), ord=2, dim=0).view( *output_size ) else: return flow.transpose(_norm_except_dim_0(flow.transpose(v, 0, dim)), 0, dim) class WeightNorm(object): name: str dim: int def __init__(self, name: str, dim: int) -> None: if dim is None: dim = -1 self.name = name self.dim = dim def compute_weight(self, module: Module) -> Any: g = getattr(module, self.name + "_g") v = getattr(module, self.name + "_v") return v * (g / _norm_except_dim(v, self.dim)) @staticmethod def apply(module, name: str, dim: int) -> "WeightNorm": for k, hook in module._forward_pre_hooks.items(): if isinstance(hook, WeightNorm) and hook.name == name: raise RuntimeError( "Cannot register two weight_norm hooks on " "the same parameter {}".format(name) ) if dim is None: dim = -1 fn = WeightNorm(name, dim) weight = getattr(module, name) del module._parameters[name] # add g and v as new parameters and express w as g/||v|| * v module.register_parameter( name + "_g", flow.nn.Parameter(_norm_except_dim(weight, dim)) ) module.register_parameter(name + "_v", flow.nn.Parameter(weight)) setattr(module, name, fn.compute_weight(module)) # recompute weight before every forward() module.register_forward_pre_hook(fn) return fn def remove(self, module: Module) -> None: weight = self.compute_weight(module) delattr(module, self.name) del module._parameters[self.name + "_g"] del module._parameters[self.name + "_v"] setattr(module, self.name, flow.nn.Parameter(weight)) def __call__(self, module: Module, inputs: Any) -> None: setattr(module, self.name, self.compute_weight(module)) T_module = TypeVar("T_module", bound=Module) def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module: r"""Applies weight normalization to a parameter in the given module. .. math:: \mathbf{w}=g \frac{\mathbf{v}}{\|\mathbf{v}\|} Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``). Weight normalization is implemented via a hook that recomputes the weight tensor from the magnitude and direction before every :meth:`~Module.forward` call. By default, with ``dim=0``, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use ``dim=None``. See https://arxiv.org/abs/1602.07868 This document description is refereced to the Pytorch document: https://pytorch.org/docs/1.10/generated/torch.nn.utils.weight_norm.html. 
Args: module (Module): containing module name (str, optional): name of weight parameter dim (int, optional): dimension over which to compute the norm Returns: The original module with the weight norm hook For example: .. code-block:: python >>> import oneflow as flow >>> m = flow.nn.utils.weight_norm(flow.nn.Linear(20, 40), name='weight') >>> m Linear(in_features=20, out_features=40, bias=True) >>> m.weight_g.size() oneflow.Size([40, 1]) >>> m.weight_v.size() oneflow.Size([40, 20]) """ WeightNorm.apply(module, name, dim) return module def remove_weight_norm(module: T_module, name: str = "weight") -> T_module: r"""Removes the weight normalization reparameterization from a module. Args: module (Module): containing module name (str, optional): name of weight parameter For example: .. code-block:: python >>> import oneflow as flow >>> m = flow.nn.utils.weight_norm(flow.nn.Linear(20, 40)) >>> flow.nn.utils.remove_weight_norm(m) Linear(in_features=20, out_features=40, bias=True) """ for k, hook in module._forward_pre_hooks.items(): if isinstance(hook, WeightNorm) and hook.name == name: hook.remove(module) del module._forward_pre_hooks[k] return module raise ValueError("weight_norm of '{}' not found in {}".format(name, module)) if __name__ == "__main__": import doctest doctest.testmod(raise_on_error=True) ================================================ FILE: python/oneflow/one_embedding.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from typing import Callable, Dict, Iterator, List, Union import oneflow as flow from oneflow.nn.modules.module import Module from oneflow.optim.optimizer import Optimizer from oneflow.nn.parameter import Parameter import json import datetime from oneflow._oneflow_internal import OneEmbeddingHandler from oneflow._oneflow_internal import PersistentTableReader from oneflow._oneflow_internal import PersistentTableWriter import numpy as np import traceback from oneflow import nn import oneflow.framework.graph_build_util as graph_build_util def _check_initializer(initializer): assert isinstance(initializer, dict) assert initializer.__contains__("type") initializer_type = initializer["type"] assert initializer_type in ["uniform", "normal", "constant", "trunc_normal"] if initializer_type == "uniform": assert initializer.__contains__("low") assert initializer.__contains__("high") elif initializer_type == "normal": assert initializer.__contains__("mean") assert initializer.__contains__("std") elif initializer_type == "constant": assert initializer.__contains__("value") elif initializer_type == "trunc_normal": assert initializer.__contains__("mean") assert initializer.__contains__("std") assert initializer.__contains__("a") assert initializer.__contains__("b") else: raise NotImplementedError("unsupported initializer_type") def _check_cache(cache): assert isinstance(cache, dict) assert cache.__contains__("policy") assert cache["policy"] in ["lru", "full"] cache_memory_budget_mb = 0 if cache.__contains__("cache_memory_budget_mb"): cache_memory_budget_mb = cache["cache_memory_budget_mb"] capacity = 0 if cache.__contains__("capacity"): capacity = cache["capacity"] assert cache_memory_budget_mb > 0 or capacity > 0 assert cache.__contains__("value_memory_kind") assert cache["value_memory_kind"] in ["device", "host"] def _init( name, embedding_dims, dtype, key_type, tables, store_options, default_initializer ): default_initializer = default_initializer or { "type": "normal", "mean": 0, "std": 0.05, } key_value_store_options = {} embedding_tables = {} key_value_store_options["name"] = name if isinstance(embedding_dims, (list, tuple)): column_dims = embedding_dims embedding_dim = sum(embedding_dims) else: assert embedding_dims > 0 column_dims = [embedding_dims] embedding_dim = embedding_dims parallel_num = flow.env.get_world_size() key_type_size = np.dtype( flow.convert_oneflow_dtype_to_numpy_dtype(key_type) ).itemsize assert key_type_size > 0 key_value_store_options["key_type_size"] = key_type_size value_type_size = np.dtype( flow.convert_oneflow_dtype_to_numpy_dtype(dtype) ).itemsize assert value_type_size > 0 key_value_store_options["value_type_size"] = value_type_size key_value_store_options["value_type"] = str(dtype) scale_factor = store_options["size_factor"] storage_dim = store_options["storage_dim"] if storage_dim != -1: key_value_store_options["storage_dim"] = storage_dim else: key_value_store_options["storage_dim"] = scale_factor * embedding_dim # kv store assert store_options.__contains__("kv_store") kv_store = store_options["kv_store"] assert isinstance(kv_store, dict) if kv_store.__contains__("caches"): caches = kv_store["caches"] assert isinstance(caches, (dict, list, tuple)) if isinstance(caches, dict): _check_cache(caches) caches = [caches] else: assert len(caches) <= 2 for i in range(len(caches)): assert isinstance(caches[i], dict) _check_cache(caches[i]) for i in range(len(caches)): if caches[i].__contains__("capacity"): caches[i]["capacity"] = caches[i]["capacity"] // parallel_num 
assert kv_store.__contains__("persistent_table") persistent_table = kv_store["persistent_table"] assert isinstance(persistent_table, dict) assert persistent_table.__contains__("path") persistent_table_path = persistent_table["path"] assert isinstance(persistent_table_path, (str, list, tuple)) if isinstance(persistent_table_path, (list, tuple)): assert len(persistent_table_path) == parallel_num if persistent_table.__contains__("physical_block_size"): assert persistent_table["physical_block_size"] in [512, 4096] else: persistent_table["physical_block_size"] = 4096 if persistent_table.__contains__("capacity_hint"): assert persistent_table["capacity_hint"] >= 0 persistent_table["capacity_hint"] = ( persistent_table["capacity_hint"] // parallel_num ) key_value_store_options["kv_store"] = kv_store # initializer if tables is not None: assert isinstance(tables, (list, tuple)) for i in range(len(tables)): table = tables[i] if table.__contains__("columns"): assert not table.__contains__("initializer") columns = table["columns"] assert len(columns) == len(column_dims) for column in columns: assert isinstance(column, dict) assert column.__contains__("initializer") _check_initializer(column["initializer"]) else: assert isinstance(table, dict) assert table.__contains__("initializer") _check_initializer(table["initializer"]) columns = [] for j in range(len(column_dims)): columns.append(make_column_options(table["initializer"])) table["columns"] = columns del table["initializer"] embedding_tables["tables"] = tables else: assert default_initializer is not None _check_initializer(default_initializer) columns = [] for j in range(len(column_dims)): columns.append(make_column_options(default_initializer)) embedding_tables["tables"] = [{"columns": columns}] embedding_tables["column_dims"] = column_dims key_value_store_options["parallel_num"] = parallel_num return embedding_dim, embedding_tables, key_value_store_options class Embedding(Module): def __init__( self, name, embedding_dim, dtype, key_type, tables, store_options, default_initializer=None, padding_idx=None, seed=0, ): super().__init__() self.dtype = dtype self.key_type = key_type parallel_num = flow.env.get_world_size() self.embedding_dim, embedding_tables, key_value_store_options = _init( name, embedding_dim, dtype, key_type, tables, store_options, default_initializer, ) self.storage_dim = key_value_store_options["storage_dim"] self.embedding_name = key_value_store_options["name"] self.seed = seed self.is_full_cache = ( len(key_value_store_options["kv_store"]["caches"]) > 0 and key_value_store_options["kv_store"]["caches"][0]["policy"] == "full" ) self.key_value_store_options = json.dumps(key_value_store_options) self.embedding_tables = json.dumps(embedding_tables) self.num_tables = len(embedding_tables["tables"]) self.local_rank = flow.env.get_local_rank() self.rank_id = flow.env.get_rank() self.world_size = flow.env.get_world_size() self.handler = OneEmbeddingHandler( self.key_value_store_options, self.local_rank, self.rank_id, self.world_size ) self.shadow = flow.nn.Parameter(flow.Tensor(1)) self.padding_idx = padding_idx self.embedding = None def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) snapshot_timestamp_tensor = flow.tensor( datetime.datetime.now().timestamp(), dtype=flow.float64, device="cuda" ) # Broadcast timestamp tensor from master rank. 
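# Every rank must use the identical snapshot name, so rank 0's timestamp is
# broadcast to all ranks before being formatted into the name passed to
# SaveSnapshot.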
flow.comm.broadcast(snapshot_timestamp_tensor, src=0) snapshot_timestamp = float(snapshot_timestamp_tensor.numpy()) snapshot_timestamp_datetime = datetime.datetime.fromtimestamp( snapshot_timestamp ) snapshot_timestamp_str = snapshot_timestamp_datetime.strftime( "%Y-%m-%d-%H-%M-%S-%f" ) self.handler.SaveSnapshot(snapshot_timestamp_str) destination[prefix + "OneEmbeddingSnapshot"] = snapshot_timestamp_str destination[ prefix + "OneEmbeddingKeyValueOptions" ] = self.key_value_store_options def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): key = prefix + "OneEmbeddingSnapshot" if key in state_dict: saved_snapshot_name = state_dict[key] try: self.handler.LoadSnapshot(saved_snapshot_name) except Exception as ex: error_msgs.append( 'Loading OneEmbedding snapshot named "{}" failed, please check whether the snapshot exists'.format( saved_snapshot_name ) ) def save_snapshot(self, snapshot_name): """save snapshot Args: snapshot_name (str): the snapshot name; the snapshot will be saved in the "snapshots" dir under your_configed_persistent_path For example: .. code-block:: python >>> import oneflow as flow >>> # use an embedding created by flow.one_embedding.MultiTableEmbedding >>> embedding.save_snapshot("my_snapshot1") >>> # a snapshot named "my_snapshot1" has been saved in the "snapshots" dir under your_configed_persistent_path >>> # and can be reloaded by flow.one_embedding.load_snapshot """ self.handler.SaveSnapshot(snapshot_name) def load_snapshot(self, snapshot_name): """load snapshot Args: snapshot_name (str): the snapshot name; the snapshot will be loaded from your_configed_persistent_path For example: .. code-block:: python >>> import oneflow as flow >>> # use an embedding created by flow.one_embedding.MultiTableEmbedding >>> embedding.load_snapshot("my_snapshot1") >>> # load a snapshot named "my_snapshot1" from your_configed_persistent_path """ self.handler.LoadSnapshot(snapshot_name) def forward(self, ids, table_ids=None): """Embedding lookup operation Args: ids (flow.tensor): the feature ids table_ids (flow.tensor, optional): the table_id of each id, must be the same shape as ids. There is no need to pass table_ids if only one table is configured, or if ids has shape (batch_size, num_tables) and each column's id belongs to the column_id-th table; otherwise, you should pass table_ids. Returns: flow.tensor: the result of embedding lookup """ assert self.key_type == ids.dtype, "ids data_type must equal key_type" embedding = flow._C.one_embedding_fused_lookup( self.shadow, ids, table_ids, self.dtype, self.embedding_name, self.storage_dim, self.embedding_dim, self.is_full_cache, self.num_tables, self.embedding_tables, self.padding_idx, self.seed, ) if embedding.requires_grad and not graph_build_util.lazy_mode.is_enabled(): if self.embedding is not None: raise ValueError( "You are training without setting the embedding optimizer. Please add flow.one_embedding.Optimizer after the optimizer.
) self.embedding = embedding self.embedding.retain_grad() self.ids = ids self.table_ids = table_ids return embedding def shuffle_and_lookup(self, state_initializer): embedding_grad = self.embedding.grad if self.world_size > 1: ( num_unique_matrix, inverse_unique_partition_indices, cur_rank_num_unique, cur_rank_unique_ids, cur_rank_unique_table_ids, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle( self.ids, self.table_ids, self.num_tables, self.embedding_name ) unique_values = flow._C.one_embedding_lookup( cur_rank_num_unique, cur_rank_unique_ids, cur_rank_unique_table_ids, self.dtype, self.dtype, self.storage_dim, self.embedding_dim, self.embedding_name, self.embedding_tables, state_initializer, seed=self.seed, ) cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle( embedding_grad, num_unique_matrix, cur_rank_inverse_indices, inverse_unique_partition_indices, self.embedding_name, ) else: ( cur_rank_num_unique, cur_rank_unique_ids, cur_rank_unique_table_ids, inverse_indices, ) = flow._C.one_embedding_unique_key_value_pair( self.ids, self.table_ids, self.num_tables, self.embedding_name ) unique_values = flow._C.one_embedding_lookup( cur_rank_num_unique, cur_rank_unique_ids, cur_rank_unique_table_ids, self.dtype, self.dtype, self.storage_dim, self.embedding_dim, self.embedding_name, self.embedding_tables, state_initializer, seed=self.seed, ) cur_rank_unique_embedding_grad = flow._C.unsorted_segment_sum( embedding_grad, inverse_indices, axis=0, num_segments=unique_values.shape[0], ) self.embedding = None return ( cur_rank_num_unique, cur_rank_unique_ids, unique_values, cur_rank_unique_embedding_grad, ) def sgd_update(self, param_group, step): lr = param_group["lr"] l2 = param_group["weight_decay"] momentum = param_group["momentum"] ( cur_rank_num_unique, cur_rank_unique_ids, unique_values, cur_rank_unique_embedding_grad, ) = self.shuffle_and_lookup("") updated_values = flow._C.one_embedding_sgd_update( cur_rank_num_unique, unique_values, cur_rank_unique_embedding_grad, learning_rate_val=lr, scale=1.0, weight_decay=l2, momentum=momentum, line_size=self.storage_dim, embedding_size=self.embedding_dim, embedding_name=self.embedding_name, ) flow._C.one_embedding_embedding_put( cur_rank_num_unique, cur_rank_unique_ids, updated_values, self.embedding_name, self.storage_dim, ) def adam_update(self, param_group, step): line_size = self.storage_dim embedding_size = self.embedding_dim lr = param_group["lr"] # not adjust, because it has been set in optimizer's step bias_correction1 = param_group["bias_correction1"] bias_correction2 = param_group["bias_correction2"] l2 = param_group["weight_decay"] beta1 = param_group["betas"][0] beta2 = param_group["betas"][1] epsilon = param_group["eps"] do_bias_correction = param_group["do_bias_correction"] amsgrad = param_group["amsgrad"] assert amsgrad == False, "one_embedding's adam not support amsgrad" state_initializer = [make_constant_initializer(0), make_constant_initializer(0)] ( cur_rank_num_unique, cur_rank_unique_ids, unique_values, cur_rank_unique_embedding_grad, ) = self.shuffle_and_lookup(json.dumps(state_initializer)) updated_values = flow._C.one_embedding_adam_update( cur_rank_num_unique, unique_values, cur_rank_unique_embedding_grad, learning_rate_val=lr, scale=1.0, weight_decay=l2, beta1=beta1, beta2=beta2, bias_correction1_val=bias_correction1, bias_correction2_val=bias_correction2, epsilon=epsilon, do_bias_correction=do_bias_correction, line_size=line_size, embedding_size=embedding_size, 
embedding_name=self.embedding_name, ) flow._C.one_embedding_embedding_put( cur_rank_num_unique, cur_rank_unique_ids, updated_values, self.embedding_name, line_size, ) def adagrad_update(self, param_group, step): lr = param_group["lr"] l2 = param_group["weight_decay"] epsilon = param_group["eps"] lr_decay = param_group["lr_decay"] initial_accumulator_value = param_group["initial_accumulator_value"] state_initializer = [make_constant_initializer(initial_accumulator_value)] ( cur_rank_num_unique, cur_rank_unique_ids, unique_values, cur_rank_unique_embedding_grad, ) = self.shuffle_and_lookup(json.dumps(state_initializer)) updated_values = flow._C.one_embedding_adagrad_update( cur_rank_num_unique, unique_values, cur_rank_unique_embedding_grad, train_step_val=step + 1, learning_rate_val=lr, scale=1.0, weight_decay=l2, lr_decay=lr_decay, epsilon=epsilon, line_size=self.storage_dim, embedding_size=self.embedding_dim, embedding_name=self.embedding_name, ) flow._C.one_embedding_embedding_put( cur_rank_num_unique, cur_rank_unique_ids, updated_values, self.embedding_name, self.storage_dim, ) def ftrl_update(self, param_group, step): lr = param_group["lr"] l2 = param_group["weight_decay"] lr_power = param_group["lr_power"] lambda1 = param_group["lambda1"] lambda2 = param_group["lambda2"] beta = param_group["beta"] initial_accumulator_value = param_group["initial_accumulator_value"] state_initializer = [ make_constant_initializer(initial_accumulator_value), make_constant_initializer(initial_accumulator_value), ] ( cur_rank_num_unique, cur_rank_unique_ids, unique_values, cur_rank_unique_embedding_grad, ) = self.shuffle_and_lookup(json.dumps(state_initializer)) updated_values = flow._C.one_embedding_ftrl_update( cur_rank_num_unique, unique_values, cur_rank_unique_embedding_grad, learning_rate_val=lr, scale=1.0, weight_decay=l2, lr_power=lr_power, lambda1=lambda1, lambda2=lambda2, beta=beta, line_size=self.storage_dim, embedding_size=self.embedding_dim, embedding_name=self.embedding_name, ) flow._C.one_embedding_embedding_put( cur_rank_num_unique, cur_rank_unique_ids, updated_values, self.embedding_name, self.storage_dim, ) def make_device_mem_store_options( persistent_path, capacity, size_factor=1, storage_dim=-1, physical_block_size=4096 ): """make GPU only store_options param of MultiTableEmbedding Args: persistent_path (str, list): persistent storage path of Embedding. If passed a str, current rank Embedding will be saved in path/rank_id-num_ranks path. If passed a list, the list length must equals num_ranks, each elem of list represent the path of rank_id Embedding. capacity (int): total capacity of Embedding size_factor (int, optional): store size factor of embedding_dim, if SGD update, and momentum = 0, should be 1, if momentum > 0, it should be 2. if Adam, should be 3. Defaults to 1. storage_dim (int, optional): number of elements in embedding storage, if set storage_dim, the size_factor param will be invalid. if SGD update, and momentum = 0, storage_dim should be embedding_size*1, if momentum > 0, storage_dim should be embedding_size*2. if Adam, storage_dim should be embedding_size*3. Defaults to -1. physical_block_size (int, optional): physical_block_size should be sector size. Defaults to 4096. 
Returns: dict: GPU only store_options param of MultiTableEmbedding See also :func:`oneflow.one_embedding.make_cached_ssd_store_options` """ assert isinstance(persistent_path, (str, list, tuple)) assert capacity > 0 options = { "kv_store": { "caches": [ { "policy": "full", "capacity": int(capacity), "value_memory_kind": "device", } ], "persistent_table": { "path": persistent_path, "physical_block_size": physical_block_size, "capacity_hint": int(capacity), }, }, "size_factor": size_factor, "storage_dim": storage_dim, } return options def make_cached_ssd_store_options( cache_budget_mb, persistent_path, capacity=None, size_factor=1, storage_dim=-1, physical_block_size=4096, host_cache_budget_mb=0, ): """make SSD use GPU and host as cache store_options param of MultiTableEmbedding. If cache_budget_mb > 0 and host_cache_budget_mb > 0, use GPU and host memory as multi-level cache. Args: cache_budget_mb (int): the MB budget of per GPU as cache. persistent_path (str, list): persistent storage path of Embedding, must use fast SSD because of frequently random disk access during training. If passed a str, current rank Embedding will be saved in path/rank_id-num_ranks path. If passed a list, the list length must equals num_ranks, each elem of list represent the path of rank_id Embedding. capacity (int): total capacity of Embedding size_factor (int, optional): store size factor of embedding_dim, if SGD update, and momentum = 0, should be 1, if momentum > 0, it should be 2. if Adam, should be 3. Defaults to 1. storage_dim (int, optional): number of elements in embedding storage, if set storage_dim, the size_factor param will be invalid. if SGD update, and momentum = 0, storage_dim should be embedding_size*1, if momentum > 0, storage_dim should be embedding_size*2. if Adam, storage_dim should be embedding_size*3. Defaults to -1. physical_block_size (int, optional): physical_block_size should be sector size. Defaults to 4096. host_cache_budget_mb (int): the MB budget of host memory as cache per rank. Defaults to 0. Returns: dict: SSD use GPU and host as cache store_options param of MultiTableEmbedding For example: .. code-block:: python >>> import oneflow as flow >>> store_options = flow.one_embedding.make_cached_ssd_store_options( >>> cache_budget_mb=8192, persistent_path="/your_path_to_ssd", capacity=vocab_size, >>> ) >>> # pass the store_options to the "store_options" param of flow.one_embedding.MultiTableEmbedding >>> # ... """ assert isinstance(persistent_path, (str, list, tuple)) assert cache_budget_mb > 0 or host_cache_budget_mb > 0 if capacity is not None: assert capacity > 0 else: capacity = 0 cache_list = [] if cache_budget_mb > 0: cache_list.append( { "policy": "lru", "cache_memory_budget_mb": cache_budget_mb, "value_memory_kind": "device", } ) if host_cache_budget_mb > 0: cache_list.append( { "policy": "lru", "cache_memory_budget_mb": host_cache_budget_mb, "value_memory_kind": "host", } ) options = { "kv_store": { "caches": cache_list, "persistent_table": { "path": persistent_path, "physical_block_size": physical_block_size, "capacity_hint": int(capacity), }, }, "size_factor": size_factor, "storage_dim": storage_dim, } return options def make_cached_host_mem_store_options( cache_budget_mb, persistent_path, capacity, size_factor=1, storage_dim=-1, physical_block_size=4096, ): """make host use GPU as cache store_options param of MultiTableEmbedding Args: cache_budget_mb (int): the MB budget of per GPU as cache. persistent_path (str, list): persistent storage path of Embedding. 
If passed a str, the current rank Embedding will be saved in the path/rank_id-num_ranks path. If passed a list, the list length must equal num_ranks, and each element of the list represents the path of the rank_id Embedding. capacity (int): total capacity of Embedding size_factor (int, optional): store size factor of embedding_dim. For SGD update with momentum = 0 it should be 1, with momentum > 0 it should be 2, and for Adam it should be 3. Defaults to 1. storage_dim (int, optional): number of elements in embedding storage. If storage_dim is set, the size_factor param is ignored. For SGD update with momentum = 0, storage_dim should be embedding_size*1; with momentum > 0 it should be embedding_size*2; for Adam it should be embedding_size*3. Defaults to -1. physical_block_size (int, optional): physical_block_size should be the sector size. Defaults to 4096. Returns: dict: host use GPU as cache store_options param of MultiTableEmbedding See also :func:`oneflow.one_embedding.make_cached_ssd_store_options` """ assert isinstance(persistent_path, (str, list, tuple)) assert cache_budget_mb > 0 assert capacity > 0 options = { "kv_store": { "caches": [ { "policy": "lru", "cache_memory_budget_mb": cache_budget_mb, "value_memory_kind": "device", }, { "policy": "full", "capacity": int(capacity), "value_memory_kind": "host", }, ], "persistent_table": { "path": persistent_path, "physical_block_size": physical_block_size, "capacity_hint": int(capacity), }, }, "size_factor": size_factor, "storage_dim": storage_dim, } return options def make_uniform_initializer(low=0.0, high=1.0): """make uniform initializer param of make_table_options Args: low (float): A python scalar. Lower bound of the range of random values to generate. high (float): A python scalar. Upper bound of the range of random values to generate. Returns: dict: initializer param of make_table_options For example: .. code-block:: python >>> import oneflow as flow >>> initializer = flow.one_embedding.make_uniform_initializer(low=-scale, high=scale) >>> # pass the initializer to flow.one_embedding.make_table_options >>> # ... """ return {"type": "uniform", "low": low, "high": high} def make_normal_initializer(mean=0.0, std=1.0): """make normal initializer param of make_table_options Args: mean (float): A python scalar. Mean of the random values to generate. std (float): A python scalar. Standard deviation of the random values to generate. Returns: dict: initializer param of make_table_options For example: .. code-block:: python >>> import oneflow as flow >>> initializer = flow.one_embedding.make_normal_initializer(mean=0, std=0.01) >>> # pass the initializer to flow.one_embedding.make_table_options >>> # ... """ return {"type": "normal", "mean": mean, "std": std} def make_constant_initializer(value): """make constant initializer param of make_table_options Args: value (float): A python scalar. Value to generate. Returns: dict: initializer param of make_table_options For example: .. code-block:: python >>> import oneflow as flow >>> initializer = flow.one_embedding.make_constant_initializer(value=0) >>> # pass the initializer to flow.one_embedding.make_table_options >>> # ... """ return {"type": "constant", "value": value} def make_trunc_normal_initializer(mean=0.0, std=1.0, a=-2.0, b=2.0): """make truncated normal initializer param of make_table_options Args: mean (float): A python scalar. Mean of the random values to generate. std (float): A python scalar. Standard deviation of the random values to generate. a (float): A python scalar. The minimum cutoff value.
b (float): A python scalar. The maximum cutoff value. Returns: dict: initializer param of make_table_options For example: .. code-block:: python >>> import oneflow as flow >>> initializer = flow.one_embedding.make_trunc_normal_initializer(mean=0, std=0.01, a=-0.02, b=0.02) >>> # pass the initializer to flow.one_embedding.make_table_options >>> # ... """ return {"type": "trunc_normal", "mean": mean, "std": std, "a": a, "b": b} def make_table_options(param): """make table param of Embedding tables Args: param (dict or list): param can be initializer or list of column_option. initializer can be made by make_uniform_initializer or make_normal_initializer or make_constant_initializer, column options can be made by make_column_options Returns: dict: table param of Embedding tables For example: .. code-block:: python >>> import oneflow as flow >>> initializer = flow.one_embedding.make_uniform_initializer(low=-scale, high=scale) >>> table1 = flow.one_embedding.make_table_options(initializer) >>> table2 = flow.one_embedding.make_table_options(initializer) >>> tables = [table1, table2] >>> # pass the tables to the "tables" param of flow.one_embedding.MultiTableEmbedding or flow.one_embedding.MultiTableMultiColumnEmbedding >>> # ... """ if isinstance(param, dict): table = {"initializer": param} elif isinstance(param, (list, tuple)): table = {"columns": param} else: raise ValueError("param must be initializer or columns") return table def make_column_options(initializer): return {"initializer": initializer} def make_table(param): """alias of `oneflow.one_embedding.make_table_options` See also :func:`oneflow.one_embedding.make_table_options` """ return make_table_options(param) class MultiTableEmbedding(Embedding): r"""MultiTableEmbedding represent multi Embedding tables with same embedding_dim, dtype, and key_type. Args: name (str): The name of Embedding embedding_dim (int): the size of each embedding vector dtype (flow.dtype): the data type of embeddings key_type (flow.dtype): the data type of feature ids tables (list): list of table param which can be made by flow.one_embedding.make_table_options store_options (dict): store option of Embedding default_initializer (dict, optional): if tables param is None, use default_initializer to initialize table. Defaults to None. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated during training, the embedding vector at :attr:`padding_idx` will default to all zeros. For example: .. 
code-block:: python >>> import oneflow as flow >>> import numpy as np >>> import oneflow.nn as nn >>> # a simple example with 3 table >>> table_size_array = [39884407, 39043, 17289] >>> vocab_size = sum(table_size_array) >>> num_tables = len(table_size_array) >>> embedding_size = 128 >>> scales = np.sqrt(1 / np.array(table_size_array)) >>> tables = [ >>> flow.one_embedding.make_table_options( >>> flow.one_embedding.make_uniform_initializer(low=-scale, high=scale) >>> ) >>> for scale in scales >>> ] >>> store_options = flow.one_embedding.make_cached_ssd_store_options( >>> cache_budget_mb=8192, persistent_path="/your_path_to_ssd", capacity=vocab_size, >>> ) >>> embedding = flow.one_embedding.MultiTableEmbedding( >>> name="my_embedding", >>> embedding_dim=embedding_size, >>> dtype=flow.float, >>> key_type=flow.int64, >>> tables=tables, >>> store_options=store_options, >>> ) >>> embedding.to("cuda") >>> mlp = flow.nn.FusedMLP( >>> in_features=embedding_size * num_tables, >>> hidden_features=[512, 256, 128], >>> out_features=1, >>> skip_final_activation=True, >>> ) >>> mlp.to("cuda") >>> >>> class TrainGraph(flow.nn.Graph): >>> def __init__(self,): >>> super().__init__() >>> self.embedding_lookup = embedding >>> self.mlp = mlp >>> self.add_optimizer( >>> flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0) >>> ) >>> self.add_optimizer( >>> flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0) >>> ) >>> def build(self, ids): >>> embedding = self.embedding_lookup(ids) >>> loss = self.mlp(flow.reshape(embedding, (-1, num_tables * embedding_size))) >>> loss = loss.sum() >>> loss.backward() >>> return loss >>> ids = np.random.randint(0, 1000, (100, num_tables), dtype=np.int64) >>> ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") >>> graph = TrainGraph() >>> loss = graph(ids_tensor) >>> print(loss) """ def __init__( self, name, embedding_dim, dtype, key_type, tables, store_options, default_initializer=None, padding_idx=None, seed=0, ): assert isinstance(embedding_dim, int) super().__init__( name, embedding_dim, dtype, key_type, tables, store_options, default_initializer, padding_idx, seed, ) class MultiTableMultiColumnEmbedding(Embedding): r"""MultiTableMultiColumnEmbedding represent multi Embedding tables with multi embedding_dim, same dtype, and key_type. Args: name (str): The name of Embedding embedding_dim (list): list of the size of each embedding vector dtype (flow.dtype): the data type of embeddings key_type (flow.dtype): the data type of feature ids tables (list): list of table param which can be made by flow.one_embedding.make_table_options store_options (dict): store option of Embedding default_initializer (dict, optional): if tables param is None, use default_initializer to initialize table. Defaults to None. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated during training, the embedding vector at :attr:`padding_idx` will default to all zeros. For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np >>> import oneflow.nn as nn >>> # a simple example with 3 table, every table has two column, the first column embedding_size is 10 and the second is 1. 
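>>> # each id looked up from a table therefore returns sum(embedding_size_list) == 11 values, which is why the MLP below uses in_features=sum(embedding_size_list) * num_tables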
>>> # every table's first column initialize with uniform(-1/sqrt(table_size), 1/sqrt(table_size)), second column initialize with normal(0, 1/sqrt(table_size)) >>> table_size_array = [39884407, 39043, 17289] >>> vocab_size = sum(table_size_array) >>> num_tables = len(table_size_array) >>> embedding_size_list = [10, 1] >>> scales = np.sqrt(1 / np.array(table_size_array)) >>> tables = [ >>> flow.one_embedding.make_table_options( >>> [flow.one_embedding.make_column_options( >>> flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)), >>> flow.one_embedding.make_column_options( >>> flow.one_embedding.make_normal_initializer(mean=0, std=scale))] >>> ) >>> for scale in scales >>> ] >>> store_options = flow.one_embedding.make_cached_ssd_store_options( >>> cache_budget_mb=8192, persistent_path="/your_path_to_ssd", capacity=vocab_size, >>> ) >>> embedding = flow.one_embedding.MultiTableMultiColumnEmbedding( >>> name="my_embedding", >>> embedding_dim=embedding_size_list, >>> dtype=flow.float, >>> key_type=flow.int64, >>> tables=tables, >>> store_options=store_options, >>> ) >>> embedding.to("cuda") >>> mlp = flow.nn.FusedMLP( >>> in_features=sum(embedding_size_list) * num_tables, >>> hidden_features=[512, 256, 128], >>> out_features=1, >>> skip_final_activation=True, >>> ) >>> mlp.to("cuda") >>> >>> class TrainGraph(flow.nn.Graph): >>> def __init__(self,): >>> super().__init__() >>> self.embedding_lookup = embedding >>> self.mlp = mlp >>> self.add_optimizer( >>> flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0) >>> ) >>> self.add_optimizer( >>> flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0) >>> ) >>> def build(self, ids): >>> embedding = self.embedding_lookup(ids) >>> loss = self.mlp(flow.reshape(embedding, (-1, num_tables * sum(embedding_size_list)))) >>> loss = loss.sum() >>> loss.backward() >>> return loss >>> ids = np.random.randint(0, 1000, (100, num_tables), dtype=np.int64) >>> ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") >>> graph = TrainGraph() >>> loss = graph(ids_tensor) >>> print(loss) """ def __init__( self, name, embedding_dim, dtype, key_type, tables, store_options, default_initializer=None, padding_idx=None, seed=0, ): if isinstance(embedding_dim, (list, tuple)): for dim in embedding_dim: assert isinstance(dim, int) else: assert isinstance(embedding_dim, int) super().__init__( name, embedding_dim, dtype, key_type, tables, store_options, default_initializer, padding_idx, seed, ) class Ftrl(Optimizer): r"""FTRL Optimizer. The formula is: .. math:: \begin{align} accumlator_{i+1} = accumlator_{i} + grad * grad \\ sigma = (accumulator_{i+1}^{lr\_power} - accumulator_{i}^{lr\_power}) / learning\_rate \\ z_{i+1} = z_{i} + grad - sigma * param_{i} \\ \text{} param_{i+1} = \begin{cases} 0 & \text{ if } |z_{i+1}| < \lambda_1 \\ -(\frac{\beta+accumlator_{i+1}^{lr\_power}}{learning\_rate} + \lambda_2)*(z_{i+1} - sign(z_{i+1})*\lambda_1) & \text{ otherwise } \\ \end{cases} \end{align} Example 1: .. code-block:: python # Assume net is a custom model. ftrl = flow.one_embedding.FTRL(net.parameters(), lr=1e-3) for epoch in range(epochs): # Read data, Compute the loss and so on. # ... loss.backward() ftrl.step() ftrl.zero_grad() Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate. Defaults to 1e-3. weight_decay (float, optional): weight decay (L2 penalty). Defaults to 0.0. lr_power (float, optional): learning rate decrease factor. Defaults to -0.5. 
initial_accumulator_value (float, optional): The initial value of accumlator. Defaults to 0.1. lambda1 (float, optional): L1 regularization strength. Defaults to 0.0. lambda2 (float, optional): L2 regularization strength. Defaults to 0.0. beta (float, optional): The value of beta. Defaults to 0.0. """ def __init__( self, params: Union[Iterator[Parameter], List[Dict]], lr: float = 0.001, weight_decay: float = 0.0, lr_power: float = -0.5, initial_accumulator_value: float = 0.1, lambda1: float = 0.0, lambda2: float = 0.0, beta: float = 0.0, ): assert lr >= 0.0, f"Invalid learning rate: {lr}" assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" options = dict() options["lr"] = lr options["weight_decay"] = weight_decay options["lr_power"] = lr_power options["initial_accumulator_value"] = initial_accumulator_value options["lambda1"] = lambda1 options["lambda2"] = lambda2 options["beta"] = beta super().__init__(params, options) for param_group in self.param_groups: for param in param_group.parameters: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() self.state[param]["accumulator_value"] = flow.zeros_like(param).fill_( param_group["initial_accumulator_value"] ) self._op = ( flow.stateful_op("ftrl_update") .Input("model") .Input("model_diff") .Input("accumulate") .Input("z") .Build() ) def step(self, closure: Callable = None): """Performs a single optimization step. Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ with flow.no_grad(): loss = None if closure is not None: loss = closure() for param_group in self.param_groups: kwargs = { "learning_rate": param_group["lr"], "l2": param_group["weight_decay"], "lr_power": param_group["lr_power"], "lambda1": param_group["lambda1"], "lambda2": param_group["lambda2"], "beta": param_group["beta"], } for param in param_group.parameters: if param.grad is None: continue if "z" not in self.state[param]: self.state[param]["z"] = flow.zeros_like(param) accumulate_tensor = self.state[param]["accumulator_value"] z_tensor = self.state[param]["z"] flow._C.dispatch_ftrl_update( self._op, (param, param.grad, accumulate_tensor, z_tensor), **kwargs, ) return loss def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: optimizer_conf = train_conf.optimizer_conf.add() lr = ( param_group["initial_lr"] if "initial_lr" in param_group else param_group["lr"] ) l2 = param_group["weight_decay"] initial_accumulator_value = param_group["initial_accumulator_value"] lr_power = param_group["lr_power"] lambda1 = param_group["lambda1"] lambda2 = param_group["lambda2"] beta = param_group["beta"] optimizer_conf.base_learning_rate = lr self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf) optimizer_conf.ftrl_conf.initial_accumulator_value = ( initial_accumulator_value ) optimizer_conf.ftrl_conf.lr_power = lr_power optimizer_conf.ftrl_conf.lambda1 = lambda1 optimizer_conf.ftrl_conf.lambda2 = lambda2 optimizer_conf.ftrl_conf.beta = beta self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) for param in param_group.parameters: vars_conf[param].l2 = l2 if param.requires_grad: optimizer_conf.variable_op_names.append(vars_conf[param].name) new_opt_confs.append(optimizer_conf) return new_opt_confs @property def support_sparse(self): return False def make_persistent_table_reader( paths, snapshot_name, key_type, value_type, storage_dim, physical_block_size=4096, ): r"""Creates a reader for reading 
persistent table. Args: paths (list): paths of tables to read snapshot_name (str): name of the snapshot to read key_type (flow.dtype): the data type of key value_type (flow.dtype): the data type of value storage_dim (int): number of elements in each value physical_block_size (int, optional): physical_block_size should be sector size. Defaults to 4096 """ return PersistentTableReader( paths, snapshot_name, key_type, value_type, storage_dim, 4 * 1024, physical_block_size, ) def make_persistent_table_writer( paths, snapshot_name, key_type, value_type, storage_dim, physical_block_size=4096, ): r"""Creates a writer for writing persistent table. Args: paths (list): paths of tables to write snapshot_name (str): name of the snapshot to write key_type (flow.dtype): the data type of key value_type (flow.dtype): the data type of value storage_dim (int): number of elements in each value physical_block_size (int, optional): physical_block_size should be sector size. Defaults to 4096 """ return PersistentTableWriter( paths, snapshot_name, key_type, value_type, storage_dim, 4 * 1024, physical_block_size, ) class SmartDecayAdam(flow.nn.optimizer.adam.Adam): """Implements SmartDecayAdam algorithm. The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. For Sparse Embedding Table in OneEmbedding, implement the SmartDecayAdam algorithm. For other models, it is same as Adam. """ def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = super()._generate_conf_for_graph(train_conf, vars_conf) for opt_conf in new_opt_confs: opt_conf.adam_conf.smart_decay = True class Optimizer(Optimizer): def __init__( self, optimizer: Optimizer, embeddings: List[Embedding], ): self.optimizer = optimizer self.embeddings = embeddings self.param_groups = optimizer.param_groups # self._default_options = optimizer._default_options # self._state = optimizer._state self.defaults = optimizer.defaults self.state = optimizer.state self.embedding_param_group_dict = {} for embedding in self.embeddings: for group in self.param_groups: param_set = set() for param in group.parameters: param_set.add(param) if embedding.shadow in param_set: self.embedding_param_group_dict[embedding.embedding_name] = group if not embedding.embedding_name in self.embedding_param_group_dict: raise ValueError("embedding must in optimizers param_group") def step(self, closure: Callable = None): step = self.optimizer.state["step"] for embedding in self.embeddings: param_group = self.embedding_param_group_dict[embedding.embedding_name] if type(self.optimizer) is flow.nn.optimizer.sgd.SGD: embedding.sgd_update(param_group, step) elif type(self.optimizer) is flow.nn.optimizer.adam.Adam: embedding.adam_update(param_group, step) elif type(self.optimizer) is flow.nn.optimizer.adagrad.Adagrad: embedding.adagrad_update(param_group, step) elif type(self.optimizer) is flow.one_embedding.Ftrl: embedding.ftrl_update(param_group, step) else: raise NotImplementedError("only support sgd, adam, adagrad and ftrl") self.optimizer.step() def _generate_conf_for_graph(self, train_conf, vars_conf): return self.optimizer._generate_conf_for_graph(train_conf, vars_conf) ================================================ FILE: python/oneflow/onnx/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings def symbolic_opset11(): warnings.warn( "The oneflow.onnx.symbolic_opset11 interface is just to align the torch.onnx.symbolic_opset11 interface and has no practical significance." ) def register_custom_op_symbolic(*args, **kwargs): warnings.warn( "The oneflow.onnx.register_custom_op_symbolic interface is just to align the torch.onnx.register_custom_op_symbolic interface and has no practical significance." ) ================================================ FILE: python/oneflow/onnx/symbolic_helper.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings def parse_args(*args, **kwargs): warnings.warn( "The oneflow.onnx.parse_args interface is just to align the torch.onnx.parse_args interface and has no practical significance." ) def func(fn): return fn return func ================================================ FILE: python/oneflow/ops/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ def load_library(path): raise ImportError("load_library is not implemented") ================================================ FILE: python/oneflow/ops/array_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" def parse_slice_tuple_list(slice_tup_list, shape): ndim = len(shape) if not isinstance(slice_tup_list, (list, tuple)) or len(slice_tup_list) > ndim: raise ValueError( "slice_tup_list must be a list or tuple with length less than or equal " "to number of dimensions of input tensor" ) if len(slice_tup_list) < ndim: supple_ndim = ndim - len(slice_tup_list) slice_tup_list += type(slice_tup_list)([(None, None, None)] * supple_ndim) start_list, stop_list, step_list = [], [], [] for (slice_tup, dim) in zip(slice_tup_list, shape): if not isinstance(slice_tup, (tuple, list)) or len(slice_tup) != 3: raise ValueError( "element of slice_tup_list must be a list or tuple with form (start, stop, step)" ) if not all((isinstance(elem, int) or elem is None for elem in slice_tup)): raise ValueError("element of slice tuple must int or None") (start, stop, step) = slice_tup if step is None: step = 1 if step == 0: raise ValueError("slice step can't be 0") if start is None: start = 0 if step > 0 else dim if stop is None: stop = dim if step > 0 else -dim - 1 # start range is [-dim, dim-1] start = max(min(start, dim - 1), -dim) # stop range is [-dim-1, dim] stop = max(min(stop, dim), -dim - 1) reg_start = start if start >= 0 else start + dim reg_stop = stop if stop >= 0 else stop + dim if step > 0 and reg_stop < reg_start: stop = start if step < 0 and reg_start < reg_stop: stop = start start_list.append(start) stop_list.append(stop) step_list.append(step) return start_list, stop_list, step_list ================================================ FILE: python/oneflow/ops/stateful_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow import oneflow._oneflow_internal import oneflow.framework.id_util as id_util class StatefulOp(object): def __init__(self, op_type_name, op_name=None): if op_name is None: op_name = id_util.UniqueStr(op_type_name) self._builder = oneflow._oneflow_internal.one.OpBuilder(op_type_name, op_name) self._op = None self._op_type_name = op_type_name @property def op(self): """access the builtin op Returns: the builtin op """ if self._op is None: self._op = self._builder.build() return self._op def Input(self, input_name, num=1): """Set input blob of op Args: input_name (str): input name of blob num (int, optional) : Defaults to 1. Returns: self """ assert isinstance(num, int) and num >= 1 self._builder.input(input_name, num) return self def Output(self, output_name, num=1): """Set output blob of op Args: output_name (str): name of output blob num (int, optional): Defaults to 1. 
Returns: self """ assert isinstance(num, int) and num >= 1 self._builder.output(output_name, num) return self def Build(self): """Explicitly complete the construction of the builtin op Returns: the completed builtin op """ if self._op is None: self._op = self._builder.build() return self._op ================================================ FILE: python/oneflow/ops/transpose_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Sequence def is_perm(perm: Sequence[int]) -> bool: return list(range(len(perm))) == sorted(list(perm)) def get_perm_when_transpose_axis_to_last_dim(num_axes: int, axis: int) -> tuple: axis = axis if axis >= 0 else axis + num_axes assert 0 <= axis < num_axes, "axis out of range" perm = [dim if dim < axis else dim + 1 for dim in range(num_axes - 1)] perm.append(axis) return tuple(perm) def get_inversed_perm(perm: Sequence[int]) -> tuple: assert is_perm(perm) inversed_perm = [-1] * len(perm) for i in range(len(perm)): inversed_perm[perm[i]] = i return tuple(inversed_perm) ================================================ FILE: python/oneflow/ops/util/__init__.py ================================================ ================================================ FILE: python/oneflow/ops/util/initializer_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import functools import math from typing import Optional, Sequence, Union import numpy as np import oneflow as flow import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util import oneflow.core.operator.op_conf_pb2 as op_conf_util import oneflow.framework.dtype as dtype_util def get_random_distribution(distribution): if distribution.lower() == "truncated_normal": return initializer_conf_util.kTruncatedNormal elif distribution.lower() == "random_normal": return initializer_conf_util.kRandomNormal elif distribution.lower() == "random_uniform": return initializer_conf_util.kRandomUniform else: raise ValueError("Invalid random_distribution") def get_data_format(data_format): assert isinstance(data_format, str), "data_format must be a string" if data_format.startswith("NC"): return "channels_first" elif data_format.startswith("N") and data_format.endswith("C"): return "channels_last" else: assert data_format == "", ValueError( 'data_format must be "N...C" or "NC..." 
or ""' ) return "" def calc_fan(shape, mode, data_format): assert ( len(shape) >= 2 ), "Fan in and fan out can out be computed for tensor with fewer 2 dimensions" if len(shape) == 2: fan_in = shape[1] fan_out = shape[0] else: fan_in = 1.0 for dim in shape[1:]: fan_in *= dim fan_out = shape[0] if data_format == "channels_first": for dim in shape[2:]: fan_out *= dim elif data_format == "channels_last": for dim in shape[1:-1]: fan_out *= dim else: raise NotImplementedError( "Only support 'channels_first' and 'channels_last' data format" ) if mode == "fan_sum": return float(fan_in) + float(fan_out) elif mode == "fan_in": return float(fan_in) elif mode == "fan_out": return float(fan_out) else: raise NotImplementedError("Only support 'fan_in', 'fan_out' and 'fan_sum' mode") def calc_gain(nonlinearity, param=None): linear_fns = [ "linear", "conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d", ] if nonlinearity in linear_fns or nonlinearity == "sigmoid": return 1 elif nonlinearity == "tanh": return 5.0 / 3 elif nonlinearity == "relu": return math.sqrt(2.0) elif nonlinearity == "leaky_relu": if param is None: negative_slope = 0.01 elif ( not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float) ): negative_slope = param else: raise ValueError("negative_slope {} not a valid number".format(param)) return math.sqrt(2.0 / (1 + negative_slope ** 2)) elif nonlinearity == "selu": return 3.0 / 4 else: raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) ================================================ FILE: python/oneflow/optim/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.nn.optimizer.adam import Adam from oneflow.nn.optimizer.adamw import AdamW from oneflow.optim.optimizer import Optimizer from oneflow.nn.optimizer.rmsprop import RMSprop from oneflow.nn.optimizer.sgd import SGD from oneflow.nn.optimizer.adagrad import Adagrad from oneflow.nn.optimizer.lamb import LAMB from oneflow.nn.optimizer.adadelta import Adadelta from oneflow.nn.optimizer.lbfgs import LBFGS from . import lr_scheduler from . import swa_utils ================================================ FILE: python/oneflow/optim/lr_scheduler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from oneflow.nn.optimizer.lr_scheduler import LRScheduler as _LRScheduler from oneflow.nn.optimizer.cosine_decay_lr import CosineDecayLR from oneflow.nn.optimizer.cosine_annealing_lr import CosineAnnealingLR from oneflow.nn.optimizer.lambda_lr import LambdaLR from oneflow.nn.optimizer.step_lr import StepLR from oneflow.nn.optimizer.multistep_lr import MultiStepLR from oneflow.nn.optimizer.exponential_lr import ExponentialLR from oneflow.nn.optimizer.reduce_lr_on_plateau import ReduceLROnPlateau from oneflow.nn.optimizer.polynomial_lr import PolynomialLR from oneflow.nn.optimizer.constant_lr import ConstantLR from oneflow.nn.optimizer.linear_lr import LinearLR from oneflow.nn.optimizer.warmup_lr import WarmupLR from oneflow.nn.optimizer.warmup_lr import WarmupLR as WarmUpLR from oneflow.nn.optimizer.cosine_annealing_warm_restarts import ( CosineAnnealingWarmRestarts, ) from oneflow.nn.optimizer.chained_scheduler import ChainedScheduler from oneflow.nn.optimizer.sequential_lr import SequentialLR from oneflow.nn.optimizer.multiplicative_lr import MultiplicativeLR ================================================ FILE: python/oneflow/optim/optimizer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import collections import warnings from copy import deepcopy from itertools import chain from typing import Any, Callable, Dict, Union from oneflow.framework.tensor import Tensor from oneflow.nn.graph.proxy import ProxyTensor from oneflow.nn.parameter import Parameter from oneflow.nn.utils.clip_grad import clip_grad_norm_ from oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup import oneflow as flow from collections import defaultdict, abc as container_abcs class ParamGroup(dict): def __init__( self, parameters: Dict[str, Any], default_options: Dict, ): # ParamGroup must be constructed by Dict["params": parameters: List[Parameter, Tensor or ProxyTensor], "...": ...] assert isinstance(parameters, dict) and "params" in parameters assert not isinstance(parameters["params"], (Parameter, Tensor)) self._parameters = list() for p in parameters["params"]: if isinstance(p, (Parameter, Tensor)): self._parameters.append(p) elif isinstance(p, ProxyTensor): # Add parameter from nn.Graph self._parameters.append(p.to(Tensor)) else: raise ValueError( "parameters in ParamGroup must be Tensor or ProxyTensor." 
) self._options = deepcopy(default_options) # rewrite options in default_options for key in self._options: if key in parameters: self._options[key] = parameters[key] # add excess keys in dict for key in parameters: if key not in self._options and key != "params": self._options[key] = parameters[key] self._enable_clip_grad = False if "clip_grad_max_norm" in parameters and "clip_grad_norm_type" in parameters: self._enable_clip_grad = True self._options["clip_grad_max_norm"] = parameters["clip_grad_max_norm"] self._options["clip_grad_norm_type"] = parameters["clip_grad_norm_type"] self._make_options_valid() self.contiguous_params = self._options.get("contiguous_params", False) if self.contiguous_params: self.params_group = ContiguousParamsGroup([parameters["params"]]) super().__init__(**self._options, params=self._parameters) super().setdefault("contiguous_params", False) super().setdefault("_enable_clip_grad", self._enable_clip_grad) def _make_options_valid(self): """handle the conflict between optimizer options """ if self._options.get("contiguous_params", False) and self._options.get( "fused", False ): self._options["fused"] = False warnings.warn( "do not set contiguous_params and fused at the same time, " "now only contiguous_params is set." ) @property def parameters(self): return self._parameters @property def contiguous_parameters(self): """return contiguous_parameters for fast updating """ return self.params_group.grouped_parameters class _SourceOpOnlyResourceDependenceMode: def __init__(self): self.guard_ = None def __enter__(self): self.guard = ( flow._oneflow_internal.eager.SourceOpOnlyResourceDependenceModeGuard() ) def __exit__(self, *args, **kwargs): del self.guard def _decorate_step(step): def decorated_step(*args, **kwargs): with _SourceOpOnlyResourceDependenceMode(): return step(*args, **kwargs) return decorated_step class _RequiredParameter(object): """Singleton class representing a required parameter for an Optimizer.""" def __repr__(self): return "" required = _RequiredParameter() class Optimizer(object): def __init__(self, parameters, options): self.param_groups = list() self.state = defaultdict(dict) self.defaults = options self.state["step"] = 0 self._parse_input_parameters(parameters) all_remat = all( p.is_local and p.device.rematable for pg in self.param_groups for p in pg.parameters ) all_not_remat = all( not p.is_local or not p.device.rematable for pg in self.param_groups for p in pg.parameters ) if not all_remat and not all_not_remat: raise ValueError( "Parameters should be all on rematable device or all on non-rematable device." ) if all_not_remat: # _decorate_step makes mutable update interleaved with backward # computation, producing wrong results in DTR if the original # weight is used to recompute other tensors. # Besides, it makes parameters remain in memory by unknown reasons # even after parameters and optimizer are not hold by python # interpreter. self.step = _decorate_step(self.step) self._state_not_saved = [ "params_group", "_parameters", ] def add_param_group(self, param_group) -> None: r""" Add a param group to the :class:`Optimizer` s `param_groups`. This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the :class:`Optimizer` as training progresses. Args: param_group (dict): Specifies what Tensors should be optimized along with group specific optimization options. 
Example: >>> import oneflow >>> import oneflow.optim as optim >>> w1 = oneflow.ones(3, 3) >>> w1.requires_grad = True >>> w2 = oneflow.ones(3, 3) >>> w2.requires_grad = True >>> o = optim.SGD([w1]) >>> o.param_groups[0] {'lr': 0.001, 'momentum': 0.0, 'dampening': 0.0, 'weight_decay': 0.0, 'nesterov': False, 'maximize': False, 'params': [tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], dtype=oneflow.float32, requires_grad=True)]} >>> o.add_param_group({'params': w2}) >>> o.param_groups[1] {'lr': 0.001, 'momentum': 0.0, 'dampening': 0.0, 'weight_decay': 0.0, 'nesterov': False, 'maximize': False, 'params': [tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], dtype=oneflow.float32, requires_grad=True)]} """ assert isinstance(param_group, dict), "param group must be a dict" params = param_group["params"] if isinstance(params, flow.Tensor): param_group["params"] = [params] elif isinstance(params, set): raise TypeError( "optimizer parameters need to be organized in ordered collections, but " "the ordering of tensors in sets will change between runs. Please use a list instead." ) else: param_group["params"] = list(params) for param in param_group["params"]: if not isinstance(param, flow.Tensor): raise TypeError( "optimizer can only optimize Tensors, " "but one of the params is " + str(type(param)) ) if not param.is_leaf: raise ValueError("can't optimize a non-leaf Tensor") for name, default in self.defaults.items(): if default is required and name not in param_group: raise ValueError( "parameter group didn't specify a value of required optimization parameter " + name ) else: param_group.setdefault(name, default) params = param_group["params"] if len(params) != len(set(params)): warnings.warn( "optimizer contains a parameter group with duplicate parameters; " "in future, this will cause an error; ", stacklevel=3, ) param_set = set() for group in self.param_groups: param_set.update(set(group.parameters)) if not param_set.isdisjoint(set(param_group["params"])): raise ValueError("some parameters appear in more than one parameter group") self.param_groups.append(ParamGroup(param_group, self.defaults)) for param in param_group["params"]: assert param.is_leaf, "parameters must be leaf tensor" self.state[param] = dict() def load_state_dict(self, state_dict) -> None: r""" Load the state of the optimizer which is created by the `state_dict` function. It is almost copied from: https://pytorch.org/docs/1.10/_modules/torch/optim/optimizer.html#Optimizer.load_state_dict. """ # Validate the state_dict groups = self.param_groups saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): raise ValueError( "loaded state dict has a different number of parameter groups" ) for param, saved_param in zip(groups, saved_groups): # the contiguous_params property is retained in state_dict, # so contiguous_params of state_dict and current optimizer should match.
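# Note: saved groups store parameters as order indices (see state_dict below); after this validation loop, an id_map from those saved indices to the live parameters is built, and every saved per-parameter state entry is cast to its parameter's device/placement before being re-attached.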
if "contiguous_params" in param and param[ "contiguous_params" ] != saved_param.get("contiguous_params", False): raise ValueError( "loaded contiguous_params state doesn't match the optimizer" ) if param["contiguous_params"]: param_list = param.contiguous_parameters else: param_list = param.parameters if len(param_list) != len(saved_param["params"]): raise ValueError( "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" ) # Update the state id_map = { old_id: p for old_id, p in zip( chain.from_iterable((g["params"] for g in saved_groups)), chain.from_iterable( ( g.parameters if not g["contiguous_params"] else g.contiguous_parameters for g in groups ) ), ) } def cast(param, value): r"""Make a deep copy of value, casting all tensors to device or placement of param.""" if isinstance(value, Tensor): if value.is_local: value = value.to(param.device) else: cpu_value_placement = flow.placement("cpu", value.placement.ranks) cpu_param_placement = flow.placement("cpu", param.placement.ranks) value = ( value.to_global(placement=cpu_value_placement) .to_global(placement=cpu_param_placement, sbp=param.sbp) .to_global(placement=param.placement) ) return value elif isinstance(value, dict): return {k: cast(param, v) for k, v in value.items()} elif isinstance(value, collections.Iterable): return type(value)(cast(param, v) for v in value) else: return value # Copy state assigned to params (and cast tensors to appropriate types). # State that is not assigned to params is copied as is (needed for # backward compatibility). state = dict() for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] state[param] = cast(param, v) else: state[k] = v self.state = state # Update parameter groups, setting their 'params' value def update_group(group, new_group): new_group.pop("params") g = deepcopy(new_group) group.update(g) group._enable_clip_grad = g["_enable_clip_grad"] group._options = g return group param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] self.param_groups = param_groups def state_dict(self): r""" Returns the state of the optimizer as a :py:class:`dict`. It contains two entries: * state - a dict holding current optimization state. Its content differs between optimizer classes. * param_group - a dict containing all parameter groups. It almost copied from: https://pytorch.org/docs/1.10/_modules/torch/optim/optimizer.html#Optimizer.state_dict. """ # Save order indices instead of Tensors param_mappings = {} start_index = 0 def pack_group(group): if group["contiguous_params"]: param_list = group.contiguous_parameters else: param_list = group.parameters nonlocal start_index packed = {k: v for k, v in group.items() if k not in self._state_not_saved} param_mappings.update( { id(p): i for i, p in enumerate(param_list, start_index) if id(p) not in param_mappings } ) packed["params"] = [param_mappings[id(p)] for p in param_list] start_index += len(packed["params"]) return packed param_groups = [pack_group(g) for g in self.param_groups] # Remap state to use order indices as keys packed_state = { (param_mappings[id(k)] if isinstance(k, Tensor) else k): v for k, v in self.state.items() } return { "state": packed_state, "param_groups": param_groups, } def step(self, closure: Union[Callable, None] = None) -> Union[Tensor, None]: """Performs a single optimization step (parameter update). Args: closure (Union[Callable, None], optional): A closure that reevaluates the model and returns the loss. Optional for most optimizers. 
Returns: Union[Tensor, None]: The loss. """ raise NotImplementedError() def clip_grad(self, error_if_nonfinite: bool = False): r"""Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. You can set the max_norm and norm_type. For more details, you can refer to the documentation of each optimizer(like Adam, SGD and so on). You can also refer the code in :func:`oneflow.nn.utils.clip_grad_norm_` Args: error_if_nonfinite (bool): if True, an error is thrown if the total norm of the gradients from :attr:``parameters`` is ``nan``, ``inf``, or ``-inf``. Default: False (will switch to True in the future) """ for param_group in self.param_groups: if param_group._enable_clip_grad: clip_grad_norm_( param_group.parameters, param_group["clip_grad_max_norm"], param_group["clip_grad_norm_type"], error_if_nonfinite, param_group.get("fused", False), ) else: warnings.warn( "To enable clip_grad, passing the `clip_grad_max_norm` and `clip_grad_norm_type` parameters when instantializing the Optimizer." ) def zero_grad(self, set_to_none: bool = False): """Sets the gradients of all optimized :class:`oneflow.Tensor` s to zero. Args: set_to_none (bool): instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently. 2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, grads are guaranteed to be None for params that did not receive a gradient. 3. Optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether). """ for param_group in self.param_groups: if param_group["contiguous_params"]: param_list = param_group.contiguous_parameters else: param_list = param_group.parameters for param in param_list: param._zero_grad_(set_to_none) def _parse_input_parameters(self, parameters): """ Supports such parameters: 1. Iterator: flow.optim.SGD(module.parameters(), lr=0.1) 2. List[Dict]: flow.optim.SGD([{"params": module1.parameters()}, {"params": module2.parameters()}]) 3. 
List[Parameter or Tensor]: flow.optim.SGD([module.weight, module.bias]) """ if isinstance(parameters, collections.abc.Iterator): # Iterator self.param_groups.append( ParamGroup({"params": list(parameters)}, self.defaults) ) elif isinstance(parameters, collections.abc.Iterable): # List[Dict] if isinstance(parameters[0], dict): for param in parameters: assert isinstance(param, dict) self.param_groups.append(ParamGroup(param, self.defaults)) # List[Parameter or Tensor] else: self.param_groups.append( ParamGroup({"params": parameters}, self.defaults) ) else: raise TypeError( f"params argument given to the optimizer should be an iterable of Tensors or dicts, but got {type(parameters)}" ) def _generate_grad_clip_conf_for_optim_conf(self, param_group, optimizer_conf): if not param_group._enable_clip_grad: return assert "clip_grad_max_norm" in param_group assert "clip_grad_norm_type" in param_group max_norm = float(param_group["clip_grad_max_norm"]) norm_type = float(param_group["clip_grad_norm_type"]) clip_grad_norm = optimizer_conf.clip_conf.clip_by_global_norm clip_grad_norm.max_norm = max_norm clip_grad_norm.norm_type = norm_type def _generate_lr_scale_for_optim_conf(self, param_group, optimizer_conf): if "lr_scale" not in param_group: return lr_scale = float(param_group["lr_scale"]) optimizer_conf.lr_scale = lr_scale @property def support_sparse(self): """Whether the Optimizer support sparse update. """ return False def _check_variables_in_graph(self, vars_conf): for param_group in self.param_groups: for param in param_group.parameters: if not param.requires_grad: continue if param not in vars_conf: raise ValueError( f"Parameter <{param}> is not in the corresponding nn.Graph/nn.Module." " Please make sure you call the module's to(..)/to_global(...) method first," " then add the module's parameters into an optimizer." ) def _check_variables_optimizer_bound(self, vars_conf): for param_group in self.param_groups: for param in param_group.parameters: if not param.requires_grad: continue if vars_conf[param].bound_optimizer is None: vars_conf[param].bound_optimizer = self elif vars_conf[param].bound_optimizer is not self: raise ValueError( f"<{vars_conf[param].name}> is already bound to another optimizer." ) def _generate_indexed_slices_optimizer_conf(self, job_conf, vars_conf): if not self.support_sparse: raise ValueError(f"{self.__class__} does not support sparse updating.") for param_group in self.param_groups: for param in param_group.parameters: if not param.requires_grad: continue sparse_opt_conf = job_conf.indexed_slices_optimizer_conf sparse_variable_op_names = sparse_opt_conf.include_op_names sparse_variable_op_names.op_name.append(vars_conf[param].name) ================================================ FILE: python/oneflow/optim/swa_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from oneflow.nn.optimizer.swa_utils import SWALR, update_bn, AveragedModel ================================================ FILE: python/oneflow/profiler/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal from oneflow.profiler.profiler import ( profile, record_function, ProfilerActivity, ProfilerAction, tensorboard_trace_handler, ) __all__ = [ "range_push", "range_pop", "profiler_start", "profiler_stop", "profile", "record_function", "ProfilerActivity", "kineto_available", "tensorboard_trace_handler", "ProfilerAction", ] def range_push(range_name): oneflow._oneflow_internal.profiler.RangePush(range_name) def range_pop(): oneflow._oneflow_internal.profiler.RangePop() def profiler_start(): oneflow._oneflow_internal.profiler.ProfilerStart() def profiler_stop(): oneflow._oneflow_internal.profiler.ProfilerStop() def kineto_available(): return True ================================================ FILE: python/oneflow/profiler/events.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import json import copy from enum import Enum from typing import Tuple, List, Dict from collections import OrderedDict from rich import box from rich.console import Console from rich.table import Table from oneflow.profiler.util import format_time class EventType(Enum): Custom = 0 Kernel = 1 class CustomEventType(Enum): Default = 0 CudaKernel = 1 CudaRuntime = 2 class EventBase: MAX_NAME_LENGTH = 55 def __init__(self, name: str, time_total: float, event_type: EventType) -> None: self._name: str = name self._time_total: float = time_total self.count: int = 1 self.event_type: EventType = event_type def update(self, event) -> None: assert self.event_type == event.event_type self.cpu_time_total += event.cpu_time_total self.count += event.count @property def name(self): if len(self._name) > self.MAX_NAME_LENGTH: return self._name[: self.MAX_NAME_LENGTH - 3] + "..." 
return self._name @property def cpu_time_total(self): return self._time_total @cpu_time_total.setter def cpu_time_total(self, new_time): self._time_total = new_time @property def cpu_time(self): return self._time_total / self.count @property def cuda_time_total(self): return None @cuda_time_total.setter def cuda_time_total(self, new_time): pass @property def cuda_time(self): if self.cuda_time_total is None: return None return self.cuda_time_total / self.count def has_cuda_time(self) -> bool: return self.cuda_time_total is not None def __eq__(self, __o: object) -> bool: return ( self.name == __o.name and self.count == __o.count and self.cpu_time_total == __o.cpu_time_total and self.cuda_time_total == __o.cuda_time_total ) class CustomEvent(EventBase): def __init__( self, name: str, time_total: float, custom_event_type: CustomEventType ) -> None: super().__init__(name, time_total, EventType.Custom) self.custom_event_type = custom_event_type @classmethod def from_dict(cls, d: dict): return cls(d.get("name"), d.get("time"), CustomEventType(d.get("custom_type"))) @property def key(self): return self.name, self.custom_event_type @property def cuda_time_total(self): if self.custom_event_type == CustomEventType.CudaKernel: return self._time_total return None def to_dict(self): device_prefix = "cuda" if self.has_cuda_time() else "cpu" time_attrs = [f"{device_prefix}_{suffix}" for suffix in ["time", "time_total"]] result = { "name": self.name, "count": self.count, } for time_attr in time_attrs: result[time_attr] = format_time(getattr(self, time_attr)) return result def __eq__(self, __o: object) -> bool: return ( super().__eq__(__o) and isinstance(__o, type(self)) and self.custom_event_type == __o.custom_event_type ) class KernelEvent(EventBase): def __init__( self, name: str, time_total: float, memory_size: int, description: Dict[str, str], ) -> None: super().__init__(name, time_total, EventType.Kernel) self.children: List[CustomEvent] = [] self.memory_size = memory_size self.description = description self._cuda_time_total = 0.0 self._enable_show_input_shapes = True self._enable_show_attributes = True def add_child(self, event: CustomEvent): self.children.append(event) if event.has_cuda_time(): self._cuda_time_total += event.cuda_time @classmethod def from_dict(cls, d: dict): kernel_event = cls( d.get("name"), d.get("time"), d.get("memory_size"), d.get("description", {}) ) if "children" in d.keys(): children_list = d.get("children") if len(children_list) > 0: for child_dict in children_list: kernel_event.add_child(CustomEvent.from_dict(child_dict)) return kernel_event @property def key(self): def get_extra_keys(): extra_keys = [] if self.input_shapes != "" and self._enable_show_input_shapes: extra_keys.append(self.description.get("input_shapes")[1]) if self.attributes != "" and self._enable_show_attributes: extra_keys.append(self.description.get("attrs")[1]) return tuple(extra_keys) if len(self.children) == 0: return (self.name,) + get_extra_keys() return ( self.name, *get_extra_keys(), ",".join([x.name for x in self.children]), ) @property def cuda_time_total(self): if self._cuda_time_total > 0.0: return self._cuda_time_total return None @cuda_time_total.setter def cuda_time_total(self, new_time): self._cuda_time_total = new_time @property def input_shapes(self): if "input_shapes" in self.description: return self.description["input_shapes"][0] return "" @property def attributes(self): if "attrs" in self.description: return self.description["attrs"][0] return "" @property def bandwidth(self): if 
len(self.children) > 0 and self.has_cuda_time(): if self.memory_size != -1: return f"{self.memory_size / (1024.0 * 1024.0 * 1024.0) / (self.cuda_time / (1000 * 1000)):.3f}GB/s" return "" def to_dict(self): result = { "name": self.name, "cpu_time_total": format_time(self.cpu_time_total), "cpu_time": format_time(self.cpu_time), "count": self.count, "input_shapes": self.input_shapes, "attributes": self.attributes, } if self.has_cuda_time(): result.update( { "cuda_time_total": format_time(self.cuda_time_total), "cuda_time": format_time(self.cuda_time), } ) return result def update(self, event): assert id(self) != id(event) assert isinstance(event, type(self)) assert len(self.children) == len(event.children) assert self.has_cuda_time() == event.has_cuda_time() assert self.key == event.key super().update(event) if self.has_cuda_time(): self.cuda_time_total += event.cuda_time_total for i in range(len(self.children)): self.children[i].update(event.children[i]) def make_children_average(self): stats: Dict[Tuple[str, ...], CustomEvent] = OrderedDict() for event in self.children: if event.key in stats: stats[event.key].update(event) else: stats[event.key] = copy.deepcopy(event) self.children = list(stats.values()) self.children.sort(key=lambda x: x.name) def __eq__(self, __o: object) -> bool: return ( super().__eq__(__o) and isinstance(__o, type(self)) and self.children == __o.children and self.memory_size == __o.memory_size and self.input_shapes == __o.input_shapes and self.attributes == __o.attributes ) class Events(list): def __init__(self, events: str = "") -> None: list.__init__([]) if events != "": self.__init_events(events) def __init_events(self, events: str): events_json = json.loads(events) classes = [CustomEvent, KernelEvent] for event_json in events_json: self.append(classes[event_json.get("type")].from_dict(event_json)) def __str__(self): return self.table() def key_averages(self, group_by_input_shape=False, group_by_attributes=False): stats: Dict[Tuple[str, ...], EventBase] = OrderedDict() def deal_event(e): if isinstance(e, KernelEvent): e._enable_show_input_shapes = group_by_input_shape e._enable_show_attributes = group_by_attributes key = e.key if key in stats: stats[key].update(e) else: stats[key] = copy.deepcopy(e) for event in self: if isinstance(event, KernelEvent) and len(event.children) != 0: event.make_children_average() for event_child in event.children: deal_event(event_child) event.children = [] deal_event(event) results = Events() results.extend(stats.values()) return results def table(self): has_input_shapes = any( [ x.input_shapes != "" and x._enable_show_input_shapes for x in self if isinstance(x, KernelEvent) ] ) has_attributes = any( [ x.attributes != "" and x._enable_show_attributes for x in self if isinstance(x, KernelEvent) ] ) has_bandwidth = any( [x.bandwidth != "" for x in self if isinstance(x, KernelEvent)] ) t = Table( "Name", "CPU time total", "CPU time", "GPU time total", "GPU time", "Number of calls", box=box.SIMPLE, ) field_keys = [ "name", "cpu_time_total", "cpu_time", "cuda_time_total", "cuda_time", "count", ] if has_input_shapes: t.add_column("Input shapes") field_keys.append("input_shapes") if has_attributes: t.add_column("Attributes") field_keys.append("attributes") if has_bandwidth: t.add_column("Bandwidth") field_keys.append("bandwidth") def build_row(data: dict): return tuple(str(data.get(key, "")) for key in field_keys) for item in self: if isinstance(item, CustomEvent): t.add_row(*build_row(item.to_dict())) if isinstance(item, KernelEvent): 
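                # Descriptive note (added, not in the original source): a kernel
                # event gets its own row; its child events (e.g. the CUDA kernels
                # it launched) are printed as additional rows right below it.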
t.add_row(*build_row(item.to_dict())) if len(item.children) > 0: for child in item.children: t.add_row(*build_row(child.to_dict())) console = Console() with console.capture() as capture: console.print(t) return capture.get() ================================================ FILE: python/oneflow/profiler/profiler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal from enum import Enum from typing import Optional, Iterable, Set from oneflow.profiler.events import Events class ProfilerActivity(Enum): CPU = 1 CUDA = 2 class ProfilerAction(Enum): """ Profiler actions that can be taken at the specified intervals """ NONE = 0 WARMUP = 1 RECORD = 2 RECORD_AND_SAVE = 3 def tensorboard_trace_handler(): raise NotImplementedError() def supported_activities() -> Set[ProfilerActivity]: activities = set([ProfilerActivity.CPU]) if oneflow.cuda.is_available(): activities.add(ProfilerActivity.CUDA) return activities class profile: def __init__( self, activities: Optional[Iterable[ProfilerActivity]] = None, record_shapes: bool = False, record_attrs: bool = False, record_bandwidth_for_cuda: bool = False, ) -> None: self.activities = set(activities) if activities else supported_activities() assert ( len(self.activities) > 0 ), "At least one ProfilerActivity must be specified." for item in self.activities: assert ( item in supported_activities() ), f"Unsupported ProfilerActivity {item}" self.record_shapes = record_shapes self.record_attrs = record_attrs if not (ProfilerActivity.CUDA in self.activities): assert ( record_bandwidth_for_cuda == False ), "record_bandwidth_for_cuda = True can only work with cuda." 
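        # Descriptive note (added, not in the original source): bandwidth figures
        # are derived from CUDA kernel timings, hence the assertion above --
        # record_bandwidth_for_cuda is only valid when ProfilerActivity.CUDA is
        # among the selected activities.
        #
        # A minimal usage sketch (not part of the original source; `model` and
        # `x` are placeholders for your own module and input):
        #
        #   import oneflow
        #   import oneflow.profiler
        #
        #   with oneflow.profiler.profile(record_shapes=True) as prof:
        #       y = model(x)
        #   print(prof.key_averages(group_by_input_shape=True))
        #
        # key_averages() returns an Events list whose __str__ renders the
        # rich-based table defined in events.py.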
        self.record_bandwidth_for_cuda = record_bandwidth_for_cuda
        self.profile_events: Optional[Events] = None

    def __enter__(self):
        oneflow._oneflow_internal.profiler.EnableProfiler(
            ProfilerActivity.CPU in self.activities,
            ProfilerActivity.CUDA in self.activities,
            self.record_shapes,
            self.record_attrs,
            self.record_bandwidth_for_cuda,
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.profile_events = Events(
            oneflow._oneflow_internal.profiler.DisableProfilerAndReturnResult()
        )

    def __check_finish(self):
        if self.profile_events is None:
            raise RuntimeError("Profiler didn't finish running")

    def key_averages(self, group_by_input_shape=False, group_by_attributes=False):
        self.__check_finish()
        return self.profile_events.key_averages(
            group_by_input_shape=group_by_input_shape,
            group_by_attributes=group_by_attributes,
        )

    def events(self):
        self.__check_finish()
        return self.profile_events


class record_function:
    def __init__(self, name: str) -> None:
        self.name = name
        self.__event_recorder_key = ""

    def __enter__(self):
        self.__event_recorder_key = oneflow._oneflow_internal.profiler.StartRecord(
            self.name
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        oneflow._oneflow_internal.profiler.EndRecord(self.__event_recorder_key)


================================================
FILE: python/oneflow/profiler/util.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
US_IN_MS = 1000.0
US_IN_SECOND = US_IN_MS * 1000.0


def format_time(time_us):
    if time_us >= US_IN_SECOND:
        return "{:.3f}s".format(time_us / US_IN_SECOND)
    if time_us >= US_IN_MS:
        return "{:.3f}ms".format(time_us / US_IN_MS)
    return "{:.3f}us".format(time_us)


================================================
FILE: python/oneflow/remat/__init__.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import re import oneflow as flow def parse_size(size): units = { "B": 1, "KB": 2 ** 10, "MB": 2 ** 20, "GB": 2 ** 30, "TB": 2 ** 40, "": 1, "KIB": 10 ** 3, "MIB": 10 ** 6, "GIB": 10 ** 9, "TIB": 10 ** 12, } m = re.match(r"^([\d\.]+)\s*([a-zA-Z]{0,3})$", str(size).strip()) assert m is not None number, unit = float(m.group(1)), m.group(2).upper() return int(number * units[unit]) def set_budget(budget: str): budget_in_bytes = parse_size(budget) flow._oneflow_internal.remat.set_budget_in_bytes(budget_in_bytes) def get_budget(): budget_in_bytes = flow._oneflow_internal.remat.budget_in_bytes() return budget_in_bytes set_small_pieces_optimization = ( flow._oneflow_internal.remat.set_small_pieces_optimization ) is_small_pieces_optimization_enabled = ( flow._oneflow_internal.remat.is_small_pieces_optimization_enabled ) ================================================ FILE: python/oneflow/sbp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.distribute import split_sbp as split import oneflow._oneflow_internal sbp = oneflow._oneflow_internal.sbp.sbp broadcast = oneflow._oneflow_internal.sbp.broadcast() partial_sum = oneflow._oneflow_internal.sbp.partial_sum() ================================================ FILE: python/oneflow/special/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from .special_ops import erf from .special_ops import erfc from .special_ops import erfinv from .special_ops import exp2 from .special_ops import expm1 from .special_ops import log1p from .special_ops import log_softmax from .special_ops import logsumexp from .special_ops import round from .special_ops import softmax from .special_ops import digamma from .special_ops import psi from .special_ops import zeta ================================================ FILE: python/oneflow/special/special_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.tensor import Tensor # avoid redefine error when add_doc def erf(x: Tensor): return oneflow._C.erf(x) def erfc(x: Tensor): return oneflow._C.erfc(x) def erfinv(x: Tensor): return oneflow._C.erfinv(x) def exp2(x: Tensor): return oneflow._C.exp2(x) def expm1(x: Tensor): return oneflow._C.expm1(x) def log1p(x: Tensor): return oneflow._C.log1p(x) def log_softmax(x: Tensor, dim: int): return oneflow._C.log_softmax(x, dim) def logsumexp(x: Tensor, dim: int, keepdim=False): return oneflow._C.logsumexp(x, dim, keepdim) def round(x: Tensor): return oneflow._C.round(x) def softmax(x: Tensor, dim: int): return oneflow._C.softmax(x, dim) def digamma(x: Tensor): return oneflow._C.digamma(x) def psi(x: Tensor): return oneflow._C.digamma(x) def zeta(input, other): return oneflow._C.zeta(input, other) ================================================ FILE: python/oneflow/support/__init__.py ================================================ ================================================ FILE: python/oneflow/support/async_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import threading def Await(counter, func): assert counter > 0 cond_var = threading.Condition() counter_box = [counter] result_list = [] def Yield(result=None): result_list.append(result) cond_var.acquire() assert counter_box[0] > 0 counter_box[0] -= 1 cond_var.notify() cond_var.release() func(Yield) cond_var.acquire() while counter_box[0] > 0: cond_var.wait() cond_var.release() return result_list ================================================ FILE: python/oneflow/support/box.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ class Box(object): def __init__(self, *arg): assert len(arg) <= 1 self.has_value_ = len(arg) > 0 self.value_ = None if self.has_value_: self.value_ = arg[0] @property def value(self): assert self.has_value_ return self.value_ @property def value_setter(self): return lambda val: self.set_value(val) def set_value(self, val): self.value_ = val self.has_value_ = True def has_value(self): return self.has_value_ ================================================ FILE: python/oneflow/support/enable_if.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import inspect import oneflow.support.traceinfo as traceinfo def condition(hob_expr): def Decorator(func): func.__oneflow_condition_hob__ = hob_expr return func return Decorator def get_condition_hob(func): assert hasattr(func, "__oneflow_condition_hob__") return func.__oneflow_condition_hob__ def set_condition_hob(func, hob): func.__oneflow_condition_hob__ = hob def unique(arg_funcs, context=None, default=None): assert isinstance(arg_funcs, (list, tuple)) conditional_functions = [] for arg_func in arg_funcs: if isinstance(arg_func, tuple): (func, hob_expr) = arg_func elif inspect.isfunction(arg_func): func = arg_func assert hasattr(func, "__oneflow_condition_hob__") hob_expr = func.__oneflow_condition_hob__ else: raise NotImplementedError debug_str = func.__name__ if hasattr(func, "__debug_str__"): debug_str = func.__debug_str__ conditional_functions.append((hob_expr, func, debug_str)) if default is None: def default(get_failed_info, *args, **kwargs): raise NotImplementedError(get_failed_info()) matched_func = GetMatchedFunction(default, conditional_functions, context=context) if matched_func is not None: return matched_func return MakeDefaultFunction(default, conditional_functions, context=context) def GetMatchedFunction(default, conditional_functions, context=None): select_triple = (None, None, None) for triple in conditional_functions: if not triple[0](context): continue if select_triple[1] is not None: return _MultiMatchedErrorFunction( default, [select_triple, triple], context=context ) select_triple = triple return select_triple[1] def MakeDefaultFunction(default, conditional_functions, context=None): def get_failed_info(customized_prompt=None): failed_info = "no avaliable function found.\n" for (bf, func, location) in conditional_functions: prompt = location if customized_prompt is None else customized_prompt failed_info += "\n%s: \x1b[1;31mFAILED\x1b[0m\n\t%s\n" % ( prompt, bf.debug_str(context), ) return failed_info return lambda *args, **kwargs: default(get_failed_info, *args, **kwargs) def _MultiMatchedErrorFunction(default, matched_functions, context=None): def get_failed_info(customized_prompt=None): failed_info = "at least two conditional functions matched.\n" for (bf, func, location) in matched_functions: prompt = location if customized_prompt is None else customized_prompt failed_info += "\n%s: \x1b[1;31mPASSED\x1b[0m\n\t%s\n" % ( prompt, bf.debug_str(context), ) return failed_info return lambda *args, **kwargs: default(get_failed_info, *args, **kwargs) ================================================ FILE: python/oneflow/support/env_var_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os def string_to_bool(env_str): if env_str.lower() in ("1", "true", "yes", "on", "y"): return True return False def parse_boolean_from_env(env_var, defalut_value): # This function aligns with ParseBooleanFromEnv() in oneflow/core/common/util.cpp assert isinstance(env_var, str), "env variable must be string, but got: " + type( env_var ) assert isinstance( defalut_value, bool ), "env variable defalut value must be boolean, but got: " + type(defalut_value) if os.getenv(env_var) is None: return defalut_value else: return string_to_bool(os.getenv(env_var)) ================================================ FILE: python/oneflow/support/func_inspect_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import inspect import sys if sys.version_info > (2, 7) and sys.version_info < (3, 0): def GetArgNameAndDefaultTuple(func): """ returns a dictionary of arg_name:default_values for the input function """ (args, varargs, keywords, defaults) = inspect.getargspec(func) defaults = list(defaults) if defaults is not None else [] while len(defaults) < len(args): defaults.insert(0, None) return tuple(zip(args, defaults)) elif sys.version_info >= (3, 0): def GetArgNameAndDefaultTuple(func): signature = inspect.signature(func) return tuple( [ (k, v.default if v.default is not inspect.Parameter.empty else None) for (k, v) in signature.parameters.items() ] ) else: raise NotImplementedError def GetArgDefaults(func): return tuple(map(lambda x: x[1], GetArgNameAndDefaultTuple(func))) ================================================ FILE: python/oneflow/support/high_order_bool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow import oneflow._oneflow_internal def bool_functor(verbose_debug_str): def Decorator(match_function): return HighOrderBool(verbose_debug_str, match_function) return Decorator def hob_context_attr(attr_name): def Decorator(attr_getter): return HobContextAttr(attr_name, attr_getter) return Decorator class BoolFunctor(object): def debug_str(self, ctx, display_result=True): if hasattr(self, "__debug_str__"): if display_result: return '"%s"[%s]' % (self.__debug_str__, self(ctx)) else: return '"%s"' % self.__debug_str__ return self.verbose_debug_str(ctx, display_result=display_result) def verbose_debug_str(self, ctx, display_result=True): raise NotImplementedError def __call__(self, ctx): raise NotImplementedError def __and__(self, rhs): return _AndBoolFunctor(self, rhs) def __or__(self, rhs): return _OrBoolFunctor(self, rhs) def __invert__(self): return _NotBoolFunctor(self) class HighOrderBool(BoolFunctor): def __init__(self, verbose_debug_str, function): self.verbose_debug_str_ = verbose_debug_str self.function_ = function def verbose_debug_str(self, ctx, display_result=True): if display_result: return '"%s"[%s]' % (self.verbose_debug_str_, self.function_(ctx)) else: return '"%s"' % self.verbose_debug_str_ def __call__(self, ctx): return self.function_(ctx) always_true = HighOrderBool("Always true", lambda: True) always_false = HighOrderBool("Always false", lambda: False) class _AndBoolFunctor(BoolFunctor): def __init__(self, lhs, rhs): assert isinstance(lhs, BoolFunctor) assert isinstance(rhs, BoolFunctor) self.lhs_ = lhs self.rhs_ = rhs def verbose_debug_str(self, ctx, display_result=True): left_display = self.lhs_.debug_str(ctx, display_result) display_result = display_result and self.lhs_(ctx) right_display = self.rhs_.debug_str(ctx, display_result) return "(%s and %s)" % (left_display, right_display) def __call__(self, ctx): return self.lhs_(ctx) and self.rhs_(ctx) class _OrBoolFunctor(BoolFunctor): def __init__(self, lhs, rhs): assert isinstance(lhs, BoolFunctor) assert isinstance(rhs, BoolFunctor) self.lhs_ = lhs self.rhs_ = rhs def verbose_debug_str(self, ctx, display_result=True): left_display = self.lhs_.debug_str(ctx, display_result) display_result = display_result and (not self.lhs_(ctx)) right_display = self.rhs_.debug_str(ctx, display_result) return "(%s or %s)" % (left_display, right_display) def __call__(self, ctx): return self.lhs_(ctx) or self.rhs_(ctx) class _NotBoolFunctor(BoolFunctor): def __init__(self, x): assert isinstance(x, BoolFunctor) self.x_ = x def verbose_debug_str(self, ctx, display_result=True): return "(not %s)" % self.x_.debug_str(ctx, display_result) def __call__(self, ctx): return not self.x_(ctx) class HobContextGetter(object): def __init__(self, attr_name, attr_getter): self.attr_name_ = attr_name self.attr_getter_ = attr_getter @property def attr_name(self): return self.attr_name_ @property def attr_getter(self): return self.attr_getter_ def __eq__(self, other): if not isinstance(other, HobContextGetter): other = HobContextConstant(other) return self._MakeHob(other, "==", lambda a, b: a == b) def __ne__(self, other): if not isinstance(other, HobContextGetter): other = HobContextConstant(other) return self._MakeHob(other, "!=", lambda a, b: a != b) def __gt__(self, other): if not isinstance(other, HobContextGetter): other = HobContextConstant(other) return self._MakeHob(other, ">", lambda a, b: a > b) def __ge__(self, other): if not isinstance(other, HobContextGetter): other = HobContextConstant(other) return self._MakeHob(other, ">=", 
lambda a, b: a >= b) def __lt__(self, other): if not isinstance(other, HobContextGetter): other = HobContextConstant(other) return self._MakeHob(other, "<", lambda a, b: a < b) def __le__(self, other): if not isinstance(other, HobContextGetter): other = HobContextConstant(other) return self._MakeHob(other, "<=", lambda a, b: a <= b) def _MakeHob(self, other, cmp_str, cmp_func): @bool_functor("%s %s %s" % (self.attr_name, cmp_str, other.attr_name)) def HobHob(context): return cmp_func(self.attr_getter(context), other.attr_getter(context)) return HobHob class HobContextConstant(HobContextGetter): def __init__(self, value): HobContextGetter.__init__(self, str(value), lambda ctx: value) class HobContextAttr(HobContextGetter): def __init__(self, attr_name, attr_getter): HobContextGetter.__init__(self, attr_name, attr_getter) def __getattr__(self, attr_name): @hob_context_attr("%s.%s" % (self.attr_name, attr_name)) def HobCtxAttr(ctx): obj = self.attr_getter(ctx) return getattr(obj, attr_name) return HobCtxAttr def HasField(self, attr_name): @bool_functor('%s.HasField("%s")' % (self.attr_name, attr_name)) def BoolFunctor(ctx): obj = self.attr_getter(ctx) if hasattr(obj, "HasField"): return obj.HasField(attr_name) else: return hasattr(obj, attr_name) return BoolFunctor ================================================ FILE: python/oneflow/support/lazy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ class Lazy(object): def __init__(self, get_value): self.value_ = None self.has_value_ = False self.get_value_ = get_value @property def value(self): if not self.has_value_: self.value_ = self.get_value_() self.has_value_ = True return self.value_ ================================================ FILE: python/oneflow/support/pb_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" def PythonDict2PbMessage(value, msg): def extend_dict(values, msg): for (k, v) in values.items(): if type(v) is dict: extend_dict(v, getattr(msg, k)) elif type(v) is list or type(v) is tuple: extend_list_or_tuple(v, getattr(msg, k)) else: setattr(msg, k, v) else: msg.SetInParent() def extend_list_or_tuple(values, msg): if len(values) == 0: return if type(values[0]) is dict: for v in values: cmd = msg.add() extend_dict(v, cmd) else: msg.extend(values) extend_dict(value, msg) return msg def MergePbMessage(dst, src): assert type(dst) is type(src) for field in dst.DESCRIPTOR.fields: field_name = field.name if field.containing_oneof is not None: if dst.WhichOneof(field.containing_oneof.name) is not None: continue src_field_name = src.WhichOneof(field.containing_oneof.name) if src_field_name is None: continue if field_name != src_field_name: continue else: if dst.HasField(field_name): continue if not src.HasField(field_name): continue _MergePbMessageField(dst, src, field) def _MergePbMessageField(dst, src, field): if field.message_type is None: setattr(dst, field.name, getattr(src, field.name)) else: MergePbMessage(getattr(dst, field.name), getattr(src, field.name)) ================================================ FILE: python/oneflow/support/scope_stack.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from contextlib import contextmanager class ScopeStack(object): def __init__(self, init=[]): if not isinstance(init, list): init = [init] assert isinstance(init, list) self.stack_ = init def Current(self): assert len(self.stack_) > 0 return self.stack_[0] @contextmanager def NewScope(self, scope): self.stack_.insert(0, scope) yield self.stack_.pop(0) ================================================ FILE: python/oneflow/support/traceinfo.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import traceback def GetFrameLocationStr(depth=-1): assert depth < 0 frame = traceback.extract_stack()[depth - 1] return "%s:%d" % (frame[0], frame[1]) def GetStackInfoExcludeOneflowPythonFile(): import oneflow dirname = os.path.dirname(oneflow.__file__) stack_info = traceback.extract_stack() filtered_stack_info = filter( lambda x: x[0].startswith(dirname) == False, stack_info ) return list(filtered_stack_info) ================================================ FILE: python/oneflow/sysconfig.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow from oneflow.framework.sysconfig import ( cmake_build_type, get_compile_flags, get_include, get_lib, get_link_flags, get_liboneflow_link_flags, has_rpc_backend_grpc, has_rpc_backend_local, with_cuda, get_cuda_version, with_rdma, ) from oneflow._oneflow_internal.flags import ( with_mlir, with_mlir_cuda_codegen, ) ================================================ FILE: python/oneflow/test/README.md ================================================ ## Ops Version : Alpha | Op Name | Doc Test | Compatiable/Completeness Test | Exception | Performance Test | | ------------------------- | ------------- | ----------------------------- | --------- | ---------------- | | oneflow.autograd.backward | [oneflow.Tensor.backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L727) | [unsqueeze_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L54) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.autograd.grad | [oneflow.Tensor.grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L753) | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.autograd.no_grad | | [no_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L62) | | | | oneflow.autograd.enable_grad | | [enable_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L50) | | | | oneflow.autograd.set_grad_enabled | | 
[set_grad_enabled](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L74) | | | | oneflow.autograd.inference_mode | | [inference_mode](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L27) | | | | oneflow.Tensor.grad | [oneflow.Tensor.grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L753) | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.Tensor.requires_grad | [oneflow.Tensor.requires_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L800) | [requires_grad_tensor_inplace_and_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L170) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.Tensor.is_leaf | [oneflow.Tensor.is_leaf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L767) | | | | | oneflow.Tensor.backward | [oneflow.Tensor.backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L727) | [unsqueeze_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L54) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.Tensor.detach | | [tensor_detach](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_2.py#L91) | | | | oneflow.Tensor.register_hook | [oneflow.Tensor.register_hook](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L833) | [tensor_register_hook](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L446) | | | | oneflow.Tensor.retain_grad | [oneflow.Tensor.retain_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L866) | 
[retain_grad_for_leaf_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L178) | | | | oneflow.autograd.Function.forward | | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27) | | | | oneflow.autograd.Function.backward | [oneflow.Tensor.backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L727) | [unsqueeze_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L54) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.autograd.Function.apply | | [module_apply](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L161) | | | | oneflow.autograd.autograd_function.FunctionAutoGradCaptureState.mark_non_differentiable | | | | | | oneflow.autograd.autograd_function.FunctionAutoGradCaptureState.save_for_backward | | | | | | oneflow.autograd.autograd_function.FunctionAutoGradCaptureState.saved_tensors | | | | | | oneflow.cuda.is_available | | | | | | oneflow.cuda.device_count | | | | | | oneflow.cuda.current_device | | | | | | oneflow.cuda.set_device | | | | | | oneflow.cuda.synchronize | | | | | | oneflow.cuda.manual_seed_all | | | | | | oneflow.cuda.manual_seed | | [generator_manual_seed](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L72) | | | | oneflow.cuda.HalfTensor | | | | | | oneflow.cuda.FloatTensor | | | | | | oneflow.cuda.DoubleTensor | | | | | | oneflow.cuda.BoolTensor | | | | | | oneflow.cuda.ByteTensor | | | | | | oneflow.cuda.CharTensor | | | | | | oneflow.cuda.IntTensor | | | | | | oneflow.cuda.LongTensor | | | | | | oneflow.cuda.empty_cache | | | | | | oneflow.nn.functional.conv1d | [oneflow._C.conv1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L20) | [conv1d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L128) | | | | oneflow.nn.functional.conv2d | [oneflow._C.conv2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L57) | [conv2d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L134) | | done | | oneflow.nn.functional.conv3d | [oneflow._C.conv3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L95) | 
[conv3d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L140) | | | | oneflow.nn.functional.conv_transpose1d | | | | | | oneflow.nn.functional.conv_transpose2d | | | | done | | oneflow.nn.functional.conv_transpose3d | | | | | | oneflow.nn.functional.fold | [oneflow.nn.functional.fold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/convolution.py#L20) | [fold_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_fold.py#L25) | | done | | oneflow.nn.functional.unfold | [oneflow.Tensor.unfold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L563) | [unfold_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unfold_tensor.py#L30) | | | | oneflow.nn.functional.avg_pool1d | [oneflow._C.avg_pool1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L99) | | | | | oneflow.nn.functional.avg_pool2d | [oneflow._C.avg_pool2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L129) | | | done | | oneflow.nn.functional.avg_pool3d | [oneflow._C.avg_pool3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L151) | | | | | oneflow.nn.functional.max_pool1d | | | | | | oneflow.nn.functional.max_pool2d | | | | done | | oneflow.nn.functional.max_pool3d | | | | | | oneflow.nn.functional.adaptive_avg_pool1d | [oneflow._C.adaptive_avg_pool1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L20) | | | done | | oneflow.nn.functional.adaptive_avg_pool2d | [oneflow._C.adaptive_avg_pool2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L48) | | | done | | oneflow.nn.functional.adaptive_avg_pool3d | [oneflow._C.adaptive_avg_pool3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L74) | | | done | | oneflow.nn.functional.threshold | [oneflow._C.threshold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L496) | [softplus_threshold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L533) | | done | | oneflow.nn.functional.relu | [oneflow.relu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L50) | 
[relu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L33) | [relu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L29) | done | | oneflow.nn.functional.hardtanh | [oneflow._C.hardtanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L363) | [hardtanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L618) | | done | | oneflow.nn.functional.hardswish | [oneflow._C.hardswish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L316) | [hardswish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L590) | | done | | oneflow.nn.functional.relu6 | | [relu6_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L73) | | done | | oneflow.nn.functional.elu | [oneflow._C.elu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L385) | [elu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L165) | | done | | oneflow.nn.functional.selu | [oneflow.selu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L409) | [selu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L754) | | done | | oneflow.nn.functional.celu | [oneflow._C.celu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L468) | [celu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L203) | [celu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L47) | done | | oneflow.nn.functional.leaky_relu | [oneflow._C.leaky_relu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L373) | | | done | | oneflow.nn.functional.prelu | [oneflow._C.prelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L20) | 
[prelu_4dim_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_prelu.py#L32) | [prelu_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L38) | | | oneflow.nn.functional.glu | [oneflow._C.glu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L436) | [glu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_glu.py#L37) | [glu_scalar_tensor_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L57) | done | | oneflow.nn.functional.gelu | [oneflow.gelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L74) | [gelu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L253) | | done | | oneflow.nn.functional.logsigmoid | [oneflow._C.logsigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L177) | [logsigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L484) | | done | | oneflow.nn.functional.hardshrink | [oneflow._C.hardshrink](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L507) | [hardshrink_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L857) | | done | | oneflow.nn.functional.softsign | [oneflow._C.softsign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L207) | [softsign_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L782) | | done | | oneflow.nn.functional.softplus | [oneflow.softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L146) | [softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_softplus.py#L43) | | done | | oneflow.nn.functional.softmax | [oneflow._C.softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L118) | 
[softmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L436) | [softmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L109) | done | | oneflow.nn.functional.softshrink | [oneflow._C.softshrink](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L518) | [softshrink_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L895) | | done | | oneflow.nn.functional.log_softmax | [oneflow._C.log_softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L132) | | | done | | oneflow.nn.functional.tanh | [oneflow.tanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L163) | [tanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L106) | | done | | oneflow.nn.functional.sigmoid | [oneflow.sigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L338) | [sigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L281) | [hard_sigmoid_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L87) | done | | oneflow.nn.functional.hardsigmoid | [oneflow._C.hardsigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L298) | [hardsigmoid_inplace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L336) | | done | | oneflow.nn.functional.silu | [oneflow.silu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L237) | [silu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L726) | | done | | oneflow.nn.functional.mish | [oneflow.mish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L267) | [mish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L698) | | done | | oneflow.nn.functional.layer_norm | 
[oneflow.nn.functional.layer_norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/normalization.py#L20) | [t5_layer_norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_t5_layernorm.py#L55) | | | | oneflow.nn.functional.normalize | [oneflow._C.normalize](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L268) | [functional_normalize](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_normalize.py#L54) | | | | oneflow.nn.functional.linear | | [interpolate_linear_1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_interpolate.py#L27) | | | | oneflow.nn.functional.dropout | [oneflow._C.dropout](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L20) | [dropout_p01](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_dropout.py#L44) | | done | | oneflow.nn.functional.dropout1d | [oneflow._C.dropout1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L102) | [dropout1d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L309) | | | | oneflow.nn.functional.dropout2d | [oneflow._C.dropout2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L124) | [dropout2d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L316) | | | | oneflow.nn.functional.dropout3d | [oneflow._C.dropout3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L146) | [dropout3d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L323) | | | | oneflow.nn.functional.embedding | | [one_embedding_adagrad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_adagrad.py#L174) | | | | oneflow.nn.functional.one_hot | [oneflow._C.one_hot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/onehot.py#L20) | [one_hot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_hot.py#L27) | | | | oneflow.nn.functional.cosine_similarity | 
[oneflow._C.cosine_similarity](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/distance.py#L20) | | [cosine_similarity_not_floating_type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_cosine_similarity.py#L24) | done | | oneflow.nn.functional.pairwise_distance | [oneflow._C.pairwise_distance](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/distance.py#L54) | [pairwise_distance_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_pairwise_distance.py#L27) | | | | oneflow.nn.functional.sparse_softmax_cross_entropy | | [eager_global_sparse_softmax_cross_entropy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sparse_softmax_cross_entropy.py#L131) | [sparse_softmax_cross_entropy_prediction_numaxes_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_sparse_softmax_cross_entropy_op.py#L23) | | | oneflow.nn.functional.cross_entropy | [oneflow._C.cross_entropy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L82) | [eager_global_sparse_softmax_cross_entropy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sparse_softmax_cross_entropy.py#L131) | [sparse_cross_entropy_prediction_numaxes_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_sparse_cross_entropy_op.py#L23) | | | oneflow.nn.functional.l1_loss | [oneflow._C.l1_loss](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L130) | [l1_loss_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L277) | [smooth_l1_loss_shape_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_smooth_l1_loss_op.py#L23) | | | oneflow.nn.functional.mse_loss | [oneflow._C.mse_loss](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L156) | [mse_loss_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L328) | | | | oneflow.nn.functional.smooth_l1_loss | [oneflow._C.smooth_l1_loss](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L186) | 
[smooth_l1_loss_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L308) | [smooth_l1_loss_shape_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_smooth_l1_loss_op.py#L23) | | | oneflow.nn.functional.triplet_margin_loss | [oneflow._C.triplet_margin_loss](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L20) | | [triplet_margin_loss_reduce_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L255) | | | oneflow.nn.functional.binary_cross_entropy | | [nn_functional_binary_cross_entropy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L245) | | | | oneflow.nn.functional.binary_cross_entropy_with_logits | | [nn_functional_binary_cross_entropy_with_logits](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L269) | | | | oneflow.nn.functional.pad | [oneflow._C.pad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/vision.py#L20) | [pad_1d_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_pad.py#L25) | [pad_size_attribute_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L89) | | | oneflow.nn.functional.interpolate | | [interpolate_linear_1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_interpolate.py#L27) | | | | oneflow.nn.functional.upsample | | [upsample_bilinear_align_corners](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_upsample.py#L338) | | | | oneflow.nn.functional.grid_sample | | [grid_sample_4d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_grid_sample.py#L31) | | done | | oneflow.nn.functional.affine_grid | | [affine_grid_2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_affine_grid.py#L31) | | done | | oneflow.nn.functional.ctc_greedy_decoder | [oneflow._C.ctc_greedy_decoder](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/ctc_decode.py#L20) | [ctc_greedy_decoder](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_ctc_greedy_decoder.py#L111) | | | | oneflow.Tensor.new_empty | 
[oneflow.Tensor.new_empty](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L201) | [new_empty](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_empty.py#L52) | | | | oneflow.Tensor.new_ones | [oneflow.Tensor.new_ones](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L229) | [flow_new_ones_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant.py#L107) | | | | oneflow.Tensor.new_zeros | [oneflow.Tensor.new_zeros](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L238) | [new_zeros](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant.py#L129) | | | | oneflow.Tensor.new_tensor | | [new_tensor_local_mode_with_default_args](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_new_tensor.py#L25) | | | | oneflow.Tensor.is_cuda | [oneflow.Tensor.is_cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2071) | | | | | oneflow.Tensor.is_global | [oneflow.Tensor.is_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L109) | | | | | oneflow.Tensor.device | [oneflow.Tensor.device](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L85) | [non_default_device](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_randperm.py#L133) | [device_type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_device.py#L25) | | | oneflow.Tensor.grad | [oneflow.Tensor.grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L753) | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.Tensor.ndim | [oneflow.Tensor.ndim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1315) | [abs_with_ndim_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_abs.py#L34) | | | | oneflow.Tensor.abs | 
[oneflow.abs](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L20) | [abs_with_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_abs.py#L27) | | done | | oneflow.Tensor.acos | [oneflow.acos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L509) | [acos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L122) | | | | oneflow.Tensor.acosh | [oneflow.acosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L535) | [acosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L138) | | | | oneflow.Tensor.add | [oneflow.add](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L41) | [scatter_add_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_scatter_ops.py#L57) | [add_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L27) | done | | oneflow.Tensor.add_ | [oneflow.Tensor.add_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1222) | [scatter_add_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_scatter_ops.py#L57) | [add_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L27) | | | oneflow.Tensor.addcdiv | [oneflow.Tensor.addcdiv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L939) | [addcdiv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_addcdiv.py#L25) | | done | | oneflow.Tensor.addcdiv_ | [oneflow.Tensor.addcdiv_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L946) | [tensor_addcdiv_inplace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_addcdiv.py#L49) | | | | oneflow.Tensor.addcmul | [oneflow.addcmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1558) | 
[addcmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_addcmul.py#L37) | | done | | oneflow.Tensor.addcmul_ | [oneflow.Tensor.addcmul_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1236) | [tensor_addcmul_inplace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_addcmul.py#L50) | | | | oneflow.Tensor.addmm | [oneflow.Tensor.addmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1215) | [addmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_addmm.py#L60) | | done | | oneflow.Tensor.all | [oneflow.Tensor.all](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1822) | [flow_var_all_dim_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_var.py#L27) | | | | oneflow.Tensor.amin | [oneflow.Tensor.amin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2167) | [amin_with_negative_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_amin.py#L34) | | done | | oneflow.Tensor.amax | [oneflow.Tensor.amax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L911) | [amax_with_negative_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_amax.py#L35) | | done | | oneflow.Tensor.any | [oneflow.Tensor.any](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1831) | [any_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_reduce.py#L52) | | | | oneflow.Tensor.arccos | [oneflow.Tensor.arccos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L664) | [arccos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L114) | | | | oneflow.Tensor.arccosh | [oneflow.Tensor.arccosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L678) | [arccosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L130) | | | | oneflow.Tensor.arcsin | 
[oneflow.Tensor.arcsin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1257) | [flow_arcsin_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L230) | | | | oneflow.Tensor.arcsinh | [oneflow.Tensor.arcsinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1264) | [flow_arcsinh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L247) | | | | oneflow.Tensor.arctan | [oneflow.Tensor.arctan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1343) | [flow_arctan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L274) | | | | oneflow.Tensor.arctanh | [oneflow.Tensor.arctanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L685) | [flow_arctanh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L296) | | | | oneflow.Tensor.argmax | [oneflow.Tensor.argmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L692) | [argmax_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmax.py#L29) | [argmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L22) | done | | oneflow.Tensor.argmin | [oneflow.Tensor.argmin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L699) | [argmin_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmin.py#L29) | | | | oneflow.Tensor.argsort | [oneflow.Tensor.argsort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L706) | [argsort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_argsort.py#L37) | | done | | oneflow.Tensor.argwhere | [oneflow.Tensor.argwhere](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L713) | [argwhere_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argwhere.py#L50) | | | | oneflow.Tensor.asin | 
[oneflow.asin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L285) | [flow_asin_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L223) | | | | oneflow.Tensor.asinh | [oneflow.asinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L318) | [flow_asinh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L240) | | | | oneflow.Tensor.atan | [oneflow.atan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L353) | [flow_atan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L267) | | | | oneflow.Tensor.atan2 | [oneflow.Tensor.atan2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L123) | [atan2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L155) | | | | oneflow.Tensor.atanh | [oneflow.atanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L564) | [flow_atanh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L289) | | | | oneflow.Tensor.backward | [oneflow.Tensor.backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L727) | [unsqueeze_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L54) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.Tensor.bmm | [oneflow.Tensor.bmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L876) | [bmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_bmm.py#L93) | [bmm_exception_dim_not_right](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_bmm.py#L25) | | | oneflow.Tensor.byte | [oneflow.Tensor.byte](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2159) | 
[byte](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1234) | | | | oneflow.Tensor.cast | [oneflow.cast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/cast.py#L20) | [cast_float2int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cast.py#L28) | [add_broad_cast_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L37) | | | oneflow.Tensor.ceil | [oneflow.ceil](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L378) | [ceil_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_ceil.py#L29) | | | | oneflow.Tensor.chunk | [oneflow.Tensor.chunk](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L883) | [flow_chunk_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_chunk.py#L46) | [chunk_0_dim_input_exception](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_chunk.py#L25) | | | oneflow.Tensor.clamp | [oneflow.clamp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L20) | [clamp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L96) | | | | oneflow.Tensor.clamp_ | [oneflow.Tensor.clamp_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1548) | [clamp_scalar_min](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L47) | | | | oneflow.Tensor.clip | [oneflow.clip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L152) | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213) | | | | oneflow.Tensor.clip_ | [oneflow.Tensor.clip_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1562) | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213) | | | | oneflow.Tensor.clone | | 
[clone_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_clone.py#L24) | | | | oneflow.Tensor.contiguous | | [tensor_scatter_nd_update_with_non_contiguous_input](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_scatter_nd_update.py#L40) | | | | oneflow.Tensor.copy_ | [oneflow.Tensor.copy_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1468) | [copy_broadcast_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_copy.py#L30) | | | | oneflow.Tensor.cos | [oneflow.cos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L712) | [global_cos_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L65) | | | | oneflow.Tensor.cosh | [oneflow.cosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L736) | | | | | oneflow.Tensor.cpu | [oneflow.Tensor.cpu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1569) | [from_torch_cpu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_from_torch.py#L26) | | | | oneflow.Tensor.cuda | [oneflow.Tensor.cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1587) | [cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L110) | | | | oneflow.Tensor.cumprod | [oneflow.cumprod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1788) | [cumprod_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cumprod.py#L25) | | done | | oneflow.Tensor.cumsum | [oneflow.cumsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1755) | [cumsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cumsum.py#L37) | | done | | oneflow.Tensor.data | | [swapdims_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapdims.py#L32) | [normal_data_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L278) | | | oneflow.Tensor.dot | 
[oneflow.dot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1438) | [dot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L903) | [dot_shape_error_msg](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_dot.py#L24) | done | | oneflow.Tensor.detach | | [tensor_detach](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_2.py#L91) | | | | oneflow.Tensor.placement | [oneflow.Tensor.placement](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L95) | [eager_boxing_with_same_placement_p_to_s1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing.py#L3093) | [multi_input_with_diff_placement](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_multi_input_with_diff_device_or_placement.py#L42) | | | oneflow.Tensor.sbp | [oneflow.Tensor.sbp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L102) | [eager_global_cast_with_same_placement_and_sbp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing.py#L3205) | [get_sbp_with_invalid_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L24) | | | oneflow.Tensor.diag | [oneflow.Tensor.diag](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L932) | [diag_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diag.py#L26) | | done | | oneflow.Tensor.diagonal | [oneflow.Tensor.diagonal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1294) | [diagonal_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diagonal.py#L24) | [diagonal_index_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L204) | done | | oneflow.Tensor.dim | [oneflow.Tensor.dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L953) | [cosine_similartiy_module_with_nonequal_dim_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cosine_similarity.py#L53) | 
[glu_dim_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L67) | | | oneflow.Tensor.div | [oneflow.div](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L143) | [div_grad_grad_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_higher_derivative_div.py#L26) | [div_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L81) | done | | oneflow.Tensor.div_ | [oneflow.Tensor.div_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1116) | [div_grad_grad_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_higher_derivative_div.py#L26) | [div_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L81) | | | oneflow.Tensor.double | [oneflow.Tensor.double](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2041) | [double](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L211) | | | | oneflow.Tensor.dtype | | [out_grad_with_different_dtype](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L113) | [sparse_cross_entropy_label_dtype_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_sparse_cross_entropy_op.py#L53) | | | oneflow.Tensor.element_size | [oneflow.Tensor.element_size](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L962) | | | | | oneflow.Tensor.eq | [oneflow.Tensor.eq](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1011) | [eq_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_eq.py#L25) | | done | | oneflow.Tensor.erf | [oneflow.erf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L763) | [flow_erf_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erf.py#L33) | | done | | oneflow.Tensor.erfc | [oneflow.erfc](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L810) | 
[erfc_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_erfc.py#L25) | | done | | oneflow.Tensor.erfinv | [oneflow.Tensor.erfinv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L997) | [flow_erfinv_with_inf_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erfinv.py#L30) | | done | | oneflow.Tensor.erfinv_ | [oneflow.Tensor.erfinv_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1004) | [flow_erfinv_with_inf_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erfinv.py#L30) | | | | oneflow.Tensor.exp | [oneflow.exp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L476) | [exp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L72) | | | | oneflow.Tensor.expand | [oneflow.Tensor.expand](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L130) | [expand_new_dims_broadcast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_expand_op.py#L28) | [expand_dim_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L78) | | | oneflow.Tensor.expand_as | [oneflow.Tensor.expand_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L139) | | | | | oneflow.Tensor.expm1 | [oneflow.expm1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L845) | [expm1_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_expm1.py#L29) | | done | | oneflow.Tensor.fill_ | [oneflow.Tensor.fill_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1053) | [masked_fill_with_0dim_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_masked_fill.py#L35) | | done | | oneflow.Tensor.flatten | [oneflow.Tensor.flatten](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L155) | [to_global_flatten_hierarchy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L30) | | done | | oneflow.Tensor.flip | 
[oneflow.Tensor.flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L169) | [image_flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_flip.py#L70) | | done | | oneflow.Tensor.float | [oneflow.Tensor.float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2020) | [logical_xor_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L37) | | | | oneflow.Tensor.floor | [oneflow.floor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L100) | [floor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_floor.py#L35) | | done | | oneflow.Tensor.floor_ | [oneflow.floor_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L135) | [flow_floor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_floor.py#L57) | | | | oneflow.Tensor.floor_divide | | | | | | oneflow.Tensor.fmod | [oneflow.fmod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L890) | [flow_fmod_element_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1021) | | done | | oneflow.Tensor.gather | [oneflow.Tensor.gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1531) | [gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_gather_nd.py#L85) | [gather_index_type_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L120) | done | | oneflow.Tensor.ge | [oneflow.Tensor.ge](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1062) | | | | | oneflow.Tensor.get_device | [oneflow.Tensor.get_device](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1069) | | | | | oneflow.Tensor.grad_fn | [oneflow.Tensor.grad_fn](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L760) | [parameter_grad_fn_none](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_parameter.py#L29) | | | | oneflow.Tensor.gt | 
[oneflow.Tensor.gt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1080) | | | done | | oneflow.Tensor.half | [oneflow.Tensor.half](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1520) | [module_to_half](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to_half.py#L25) | | | | oneflow.Tensor.in_top_k | [oneflow.Tensor.in_top_k](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L176) | [in_top_k_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_in_top_k.py#L82) | [in_top_k_num_equal_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L389) | | | oneflow.Tensor.index_select | [oneflow.Tensor.index_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L185) | [index_select_by_random](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_index_select.py#L30) | [index_select_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L330) | | | oneflow.Tensor.int | [oneflow.Tensor.int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1978) | [logical_xor_int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L27) | [tensordot_too_large_int_dims_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_tensordot.py#L35) | | | oneflow.Tensor.is_contiguous | [oneflow.Tensor.is_contiguous](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2062) | | | | | oneflow.Tensor.is_floating_point | [oneflow.is_floating_point](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/is_floating_point.py#L20) | [is_floating_point](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_ops.py#L176) | | | | oneflow.Tensor.is_lazy | [oneflow.Tensor.is_lazy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L116) | | | | | oneflow.Tensor.is_leaf | 
[oneflow.Tensor.is_leaf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L767) | | | | | oneflow.Tensor.isinf | [oneflow.Tensor.isinf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2152) | [isinf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_util_ops.py#L33) | | | | oneflow.Tensor.isnan | [oneflow.Tensor.isnan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2145) | [isnan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_util_ops.py#L24) | | | | oneflow.Tensor.item | [oneflow.Tensor.item](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2087) | [tensordot_single_item_tensor_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensordot.py#L105) | | | | oneflow.Tensor.le | [oneflow.Tensor.le](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1025) | | | | | oneflow.Tensor.log | [oneflow.log](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L923) | [log](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L56) | | | | oneflow.Tensor.log1p | [oneflow.log1p](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L455) | [log1p_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_log1p.py#L31) | | | | oneflow.Tensor.log2 | [oneflow.log2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L948) | [log2_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L971) | | | | oneflow.Tensor.logical_and | [oneflow.Tensor.logical_and](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1677) | [logical_and](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_and.py#L58) | | | | oneflow.Tensor.logical_or | [oneflow.Tensor.logical_or](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1687) | 
[logical_or](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_or.py#L58) | | | | oneflow.Tensor.logical_not | [oneflow.Tensor.logical_not](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L520) | [logical_not](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_not.py#L43) | | | | oneflow.Tensor.logical_xor | [oneflow.Tensor.logical_xor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1698) | [logical_xor_int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L27) | | | | oneflow.Tensor.long | [oneflow.Tensor.long](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1999) | [long](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L145) | | | | oneflow.Tensor.lt | [oneflow.Tensor.lt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1018) | | | | | oneflow.Tensor.masked_fill | [oneflow.Tensor.masked_fill](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1708) | [masked_fill](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_masked_fill.py#L58) | | | | oneflow.Tensor.masked_select | [oneflow.Tensor.masked_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1715) | [masked_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_masked_select.py#L87) | | | | oneflow.Tensor.matmul | [oneflow.matmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1249) | [fused_matmul_op](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cublas_fused_mlp.py#L173) | [matmul_dimension_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L220) | | | oneflow.Tensor.mm | [oneflow.mm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1311) | [flow_mm_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_matmul.py#L69) | 
[mm_not_2dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_mm.py#L24) | | | oneflow.Tensor.mv | [oneflow.mv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1278) | [flow_mv_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_matmul.py#L78) | [mv_not_matrix](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_mv.py#L23) | done | | oneflow.Tensor.max | [oneflow.Tensor.max](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1774) | [moving_average_min_max_observer](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_moving_average_max_min_observer.py#L83) | | | | oneflow.Tensor.maximum | [oneflow.maximum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L997) | [broadcast_maximum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maximum_minimum.py#L32) | | | | oneflow.Tensor.median | [oneflow.median](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1019) | [median](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_median.py#L48) | [median_exception_dim_out_of_range](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_median.py#L25) | | | oneflow.Tensor.mean | [oneflow.Tensor.mean](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1840) | [mean](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_mean.py#L70) | [normalization_moving_mean_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L317) | | | oneflow.Tensor.min | [oneflow.Tensor.min](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1783) | [moving_average_min_max_observer](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_moving_average_max_min_observer.py#L83) | | | | oneflow.Tensor.minimum | [oneflow.minimum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L975) | 
[broadcast_minimum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maximum_minimum.py#L50) | | | | oneflow.Tensor.mish | [oneflow.mish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L267) | [mish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L698) | | done | | oneflow.Tensor.mul | [oneflow.mul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L186) | [einsum_eltwise_mul_then_reduce_sum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_eltwise_mul_then_reduce_sum.py#L40) | | | | oneflow.Tensor.mul_ | [oneflow.Tensor.mul_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1108) | [fused_matmul_op](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cublas_fused_mlp.py#L173) | [matmul_dimension_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L220) | | | oneflow.Tensor.narrow | [oneflow.Tensor.narrow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L629) | [narrow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_narrow.py#L35) | [narrow_dim_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L178) | | | oneflow.Tensor.ndimension | | | | | | oneflow.Tensor.ne | [oneflow.Tensor.ne](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1032) | [ne](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ne.py#L31) | | | | oneflow.Tensor.neg | [oneflow.Tensor.neg](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1039) | [flow_split_sizes_neg_dim_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_split.py#L63) | [tensordot_neg_dims_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_tensordot.py#L25) | | | oneflow.Tensor.negative | [oneflow.negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L428) | 
[argmax_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmax.py#L29) | [repeat_interleave_negative_tensor_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L58) | | | oneflow.Tensor.nelement | [oneflow.Tensor.nelement](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1137) | [tensor_nelement](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L552) | | | | oneflow.Tensor.nonzero | [oneflow.nonzero](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/nonzero.py#L20) | [nonzero](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_nonzero.py#L51) | | | | oneflow.Tensor.norm | [oneflow.linalg.norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L160) | [clip_grad_norm_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L50) | | | | oneflow.Tensor.normal_ | [oneflow.Tensor.normal_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1154) | [eager_boxing_normal_1d_exhaustive_testing](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L113) | [normal_data_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L278) | | | oneflow.Tensor.numel | [oneflow.Tensor.numel](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L194) | [tensor_numel](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L558) | | | | oneflow.Tensor.numpy | [oneflow.Tensor.numpy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1163) | [dropout_numpy_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L29) | [numpy_type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_pad.py#L32) | | | oneflow.Tensor.permute | [oneflow.Tensor.permute](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L643) | 
[einsum_batch_permute](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_batch_permute.py#L42) | | | | oneflow.Tensor.pow | [oneflow.pow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1132) | [pow_with_scalar](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L96) | | | | oneflow.Tensor.prod | [oneflow.Tensor.prod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1849) | [prod_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_reduce.py#L59) | | | | oneflow.Tensor.reciprocal | [oneflow.reciprocal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L226) | [flow_reciprocal_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_reciprocal.py#L32) | | | | oneflow.Tensor.register_hook | [oneflow.Tensor.register_hook](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L833) | [tensor_register_hook](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L446) | | | | oneflow.Tensor.relu | [oneflow.relu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L50) | [relu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L33) | [relu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L29) | done | | oneflow.Tensor.repeat | [oneflow.Tensor.repeat](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1622) | [flow_tensor_repeat_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_repeat.py#L27) | [repeat_interleave_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L25) | | | oneflow.Tensor.repeat_interleave | [oneflow.Tensor.repeat_interleave](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1631) | 
[flow_int_repeat_interleave_dim_none](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_repeat_interleave.py#L29) | [repeat_interleave_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L25) | | | oneflow.Tensor.requires_grad | [oneflow.Tensor.requires_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L800) | [requires_grad_tensor_inplace_and_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L170) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.Tensor.requires_grad_ | [oneflow.Tensor.requires_grad_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L809) | [requires_grad_tensor_inplace_and_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L170) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.Tensor.reshape | [oneflow.Tensor.reshape](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1858) | [reshape_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_reshape.py#L27) | [reshape_like_size_match_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reshape_like_op.py#L24) | done | | oneflow.Tensor.reshape_as | [oneflow.Tensor.reshape_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1865) | [reshape_as_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1181) | | | | oneflow.Tensor.retain_grad | [oneflow.Tensor.retain_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L866) | [retain_grad_for_leaf_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L178) | | | | oneflow.Tensor.roll | [oneflow.Tensor.roll](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1187) | 
[roll](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_roll.py#L27) | [roll_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L112) | | | oneflow.Tensor.round | [oneflow.round](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1346) | [flow_round_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_round.py#L30) | | | | oneflow.Tensor.rsqrt | [oneflow.rsqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1173) | [rsqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L80) | | | | oneflow.Tensor.selu | [oneflow.selu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L409) | [selu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L754) | | done | | oneflow.Tensor.shape | | [randn_tuple_shape](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_randn.py#L62) | [layernorm_exception_input_shape_not_match](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_layernorm.py#L25) | | | oneflow.Tensor.sigmoid | [oneflow.sigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L338) | [sigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L281) | [hard_sigmoid_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L87) | done | | oneflow.Tensor.sign | [oneflow.sign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L589) | [sign_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sign.py#L25) | | | | oneflow.Tensor.silu | [oneflow.silu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L237) | [silu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L726) | | done | | oneflow.Tensor.sin | 
[oneflow.sin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L618) | [global_sin_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L59) | | | | oneflow.Tensor.sin_ | [oneflow.sin_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L648) | [global_sin_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L59) | | | | oneflow.Tensor.sinh | [oneflow.sinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L656) | [sinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L23) | | | | oneflow.Tensor.size | [oneflow.Tensor.size](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1392) | [unsqueeze_with_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_unsqueeze.py#L62) | [local_to_global_with_invalid_size](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L75) | | | oneflow.Tensor.softmax | [oneflow._C.softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L118) | [softmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L436) | [softmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L109) | done | | oneflow.Tensor.softplus | [oneflow.softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L146) | [softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_softplus.py#L43) | | done | | oneflow.Tensor.softsign | [oneflow._C.softsign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L207) | [softsign_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L782) | | done | | oneflow.Tensor.sort | [oneflow.Tensor.sort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1947) | 
[sort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_sort.py#L69) | | | | oneflow.Tensor.split | [oneflow.Tensor.split](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L890) | [eager_boxing_2d_special_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L146) | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39) | | | oneflow.Tensor.sqrt | [oneflow.sqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1198) | [sqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L64) | | | | oneflow.Tensor.square | [oneflow.square](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1224) | [inv_random_square_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_inv.py#L39) | [inv_exception_not_square_matrix](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_inv.py#L34) | | | oneflow.Tensor.squeeze | [oneflow.Tensor.squeeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L556) | [squeeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_squeeze.py#L94) | [squeeze_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L106) | | | oneflow.Tensor.std | [oneflow.std](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1371) | [std_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_std.py#L26) | | | | oneflow.Tensor.storage_offset | [oneflow.Tensor.storage_offset](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L268) | | | | | oneflow.Tensor.stride | | [flow_as_strided_with_stride](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_as_stride.py#L49) | | | | oneflow.Tensor.sum | [oneflow.Tensor.sum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1813) | 
[sum_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_sum.py#L29) | [reduce_sum_like_empty_axis_case_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reduce_like_ops.py#L24) | | | oneflow.Tensor.swapaxes | [oneflow.Tensor.swapaxes](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L904) | [swapaxes_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapaxes.py#L31) | | | | oneflow.Tensor.swapdims | [oneflow.Tensor.swapdims](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L918) | [swapdims_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapdims.py#L32) | | | | oneflow.Tensor.sub | [oneflow.sub](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L246) | [global_sub](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sub.py#L50) | | | | oneflow.Tensor.sub_ | [oneflow.Tensor.sub_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1123) | [global_sub_with_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sub.py#L56) | | | | oneflow.Tensor.tan | [oneflow.tan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L687) | [flow_tan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L257) | | | | oneflow.Tensor.tanh | [oneflow.tanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L163) | [tanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L106) | | done | | oneflow.Tensor.tile | [oneflow.tile](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tile.py#L20) | [flow_tile_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tile.py#L27) | [tile_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L431) | | | oneflow.Tensor.to | 
[oneflow.Tensor.to](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1485) | [dummy_module_to](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L58) | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39) | | | oneflow.Tensor.local_to_global | [oneflow.Tensor.local_to_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L286) | [local_to_global_2d_sbp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L85) | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39) | | | oneflow.Tensor.global_to_global | [oneflow.Tensor.global_to_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L334) | [cuda_global_to_global_cpu_s2b](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L210) | [global_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L51) | | | oneflow.Tensor.to_global | [oneflow.Tensor.to_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L381) | [to_global_flatten_hierarchy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L30) | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39) | | | oneflow.Tensor.to_local | [oneflow.Tensor.to_local](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L473) | | [call_to_local_for_local_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L65) | | | oneflow.Tensor.to_consistent | [oneflow.Tensor.to_consistent](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L466) | | | | | oneflow.Tensor.tolist | [oneflow.Tensor.tolist](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2108) | 
[tolist](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L257) | | | | oneflow.Tensor.topk | [oneflow.Tensor.topk](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1751) | [flow_topk_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L306) | | | | oneflow.Tensor.transpose | [oneflow.Tensor.transpose](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L513) | [einsum_matrix_transpose](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_matrix_transpose.py#L35) | | | | oneflow.Tensor.tril | [oneflow.Tensor.tril](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1441) | [fused_scale_tril](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_fused_scale_tril.py#L78) | | | | oneflow.Tensor.triu | [oneflow.Tensor.triu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1448) | [triu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_triu.py#L47) | | | | oneflow.Tensor.type_as | [oneflow.Tensor.type_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1954) | [type_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_ops.py#L165) | | | | oneflow.Tensor.type | [oneflow.Tensor.type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2192) | [type_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_type_tensor.py#L74) | [cosine_similarity_not_floating_type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_cosine_similarity.py#L24) | | | oneflow.Tensor.t | [oneflow.Tensor.t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1640) | [global_tensor_scatter_nd_update_t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L140) | [t_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L439) | | | oneflow.Tensor.T | 
[oneflow.Tensor.t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1640) | [global_tensor_scatter_nd_update_t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L140) | [t_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L439) | | | oneflow.Tensor.unbind | [oneflow.Tensor.unbind](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L897) | [unbind_flow_with_random_data1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unbind.py#L32) | [unbind_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L248) | | | oneflow.Tensor.unfold | [oneflow.Tensor.unfold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L563) | [unfold_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unfold_tensor.py#L30) | | | | oneflow.Tensor.uniform_ | [oneflow.Tensor.uniform_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1455) | | | | | oneflow.Tensor.unsqueeze | [oneflow.Tensor.unsqueeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L636) | [unsqueeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L68) | | | | oneflow.Tensor.var | [oneflow.var](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1407) | [module_to_with_var_reuse](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L93) | | | | oneflow.Tensor.view | [oneflow.Tensor.view](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1881) | [view](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_view.py#L79) | [view_exception](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_view.py#L25) | | | oneflow.Tensor.view_as | [oneflow.Tensor.view_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1931) | | | | | oneflow.Tensor.where | 
[oneflow.Tensor.where](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2129) | [where](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_where.py#L196) | | | | oneflow.Tensor.zero_ | [oneflow.Tensor.zero_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2136) | [nonzero_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_nonzero.py#L64) | | | | oneflow.Tensor.nms | [oneflow.Tensor.nms](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1758) | [nms](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_nms.py#L50) | | | | oneflow.Tensor.pin_memory | [oneflow.Tensor.pin_memory](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2174) | [tensor_pin_memory](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_pin_memory.py#L33) | | | | oneflow.Tensor.is_pinned | [oneflow.Tensor.is_pinned](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2183) | [tensor_is_pinned](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_pin_memory.py#L76) | | | | oneflow.nn.Parameter | | [ddp_with_partial_requires_grad_parameter](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_ddp.py#L225) | [direction_parameter_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_arg_sort_op.py#L23) | | | oneflow.nn.Module | [oneflow.nn.Module.to_consistent](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/module.py#L20) | [dummy_module](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L45) | | | | oneflow.nn.Sequential | | | | | | oneflow.nn.ModuleList | | | | | | oneflow.nn.ModuleDict | | [moduledict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L353) | | | | oneflow.nn.ParameterList | | | | | | oneflow.nn.ParameterDict | | | | | | oneflow.nn.Module.add_module | | | | | | oneflow.nn.Module.apply | | [module_apply](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L161) | | | | oneflow.nn.Module.buffers | | | | | | oneflow.nn.Module.children | | | | | | oneflow.nn.Module.cpu | 
[oneflow.Tensor.cpu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1569) | [from_torch_cpu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_from_torch.py#L26) | | | | oneflow.nn.Module.cuda | [oneflow.Tensor.cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1587) | [cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L110) | | | | oneflow.nn.Module.double | [oneflow.Tensor.double](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2041) | [double](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L211) | | | | oneflow.nn.Module.train | | [train_eval](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L121) | | | | oneflow.nn.Module.eval | | [dropout_eval_p01](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_dropout.py#L33) | [normalization_eval_need_moving_statistic_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L347) | | | oneflow.nn.Module.extra_repr | | | | | | oneflow.nn.Module.float | [oneflow.Tensor.float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2020) | [logical_xor_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L37) | | | | oneflow.nn.Module.forward | | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27) | | | | oneflow.nn.Module.load_state_dict | | [load_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L63) | | | | oneflow.nn.Module.modules | | | | | | oneflow.nn.Module.named_buffers | | | | | | oneflow.nn.Module.named_children | | | | | | oneflow.nn.Module.named_modules | | | | | | oneflow.nn.Module.named_parameters | | | | | | oneflow.nn.Module.parameters | | | | | | oneflow.nn.Module.register_buffer | | | | | | oneflow.nn.Module.register_forward_hook | | | | | | oneflow.nn.Module.register_forward_pre_hook | | | | | | oneflow.nn.Module.register_parameter | | | | | | oneflow.nn.Module.requires_grad_ | [oneflow.Tensor.requires_grad_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L809) | 
[requires_grad_tensor_inplace_and_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L170) | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24) | | | oneflow.nn.Module.state_dict | | [load_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L63) | | | | oneflow.nn.Module.to | [oneflow.Tensor.to](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1485) | [dummy_module_to](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L58) | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39) | | | oneflow.nn.Module.zero_grad | | | | | | oneflow.nn.Conv1d | [oneflow._C.conv1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L20) | [conv1d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L128) | | | | oneflow.nn.Conv2d | [oneflow._C.conv2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L57) | [conv2d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L134) | | | | oneflow.nn.Conv3d | [oneflow._C.conv3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L95) | [conv3d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L140) | | | | oneflow.nn.ConvTranspose1d | | | | | | oneflow.nn.ConvTranspose2d | | | | | | oneflow.nn.ConvTranspose3d | | | | | | oneflow.nn.Unfold | [oneflow.Tensor.unfold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L563) | [unfold_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unfold_tensor.py#L30) | | | | oneflow.nn.Fold | [oneflow.nn.functional.fold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/convolution.py#L20) | [fold_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_fold.py#L25) | | | | oneflow.nn.MaxPool1d | | 
[maxpool1d_functional](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maxpool.py#L28) | | | |
oneflow.nn.MaxPool2d | | [maxpool2d_functional](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maxpool.py#L51) | | | |
oneflow.nn.MaxPool3d | | [maxpool3d_functional](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maxpool.py#L75) | | | |
oneflow.nn.AdaptiveAvgPool1d | | | | | |
oneflow.nn.AdaptiveAvgPool2d | | | | | |
oneflow.nn.AdaptiveAvgPool3d | | | | | |
oneflow.nn.AvgPool1d | | [adaptive_avgpool1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_adaptive_pool.py#L39) | | | |
oneflow.nn.AvgPool2d | | [adaptive_avgpool2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_adaptive_pool.py#L53) | | | |
oneflow.nn.AvgPool3d | | [adaptive_avgpool3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_adaptive_pool.py#L72) | | | |
oneflow.nn.ConstantPad1d | | [constantpad1d_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant_pad.py#L32) | | | |
oneflow.nn.ConstantPad2d | | [ConstantPad2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_zeropad2d.py#L96) | | | |
oneflow.nn.ConstantPad3d | | [constantpad3d_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant_pad.py#L64) | | | |
oneflow.nn.ReflectionPad1d | | | | | |
oneflow.nn.ReflectionPad2d | | | | | |
oneflow.nn.ReplicationPad1d | | | | | |
oneflow.nn.ReplicationPad2d | | [ReplicationPad2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_replication_pad.py#L104) | | | |
oneflow.nn.ZeroPad2d | | [global_ZeroPad2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_zeropad2d.py#L37) | | | |
oneflow.nn.ELU | [oneflow._C.elu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L385) | [elu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L165) | | | |
oneflow.nn.Hardshrink | [oneflow._C.hardshrink](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L507) | [hardshrink_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L857) | | | |
oneflow.nn.Hardsigmoid | [oneflow._C.hardsigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L298) | [hardsigmoid_inplace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L336) | | | |
oneflow.nn.Hardswish | [oneflow._C.hardswish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L316) | [hardswish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L590) | | | |
oneflow.nn.Hardtanh | [oneflow._C.hardtanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L363) | [hardtanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L618) | | | |
oneflow.nn.LeakyReLU | | [leakyrelu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L656) | | | |
oneflow.nn.LogSigmoid | [oneflow._C.logsigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L177) | [logsigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L484) | | | |
oneflow.nn.PReLU | [oneflow._C.prelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L20) | [prelu_4dim_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_prelu.py#L32) | [prelu_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L38) | | |
oneflow.nn.ReLU | [oneflow.relu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L50) | [relu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L33) | [relu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L29) | | |
oneflow.nn.ReLU6 | | [relu6_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L73) | | | |
oneflow.nn.SELU | [oneflow.selu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L409) | [selu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L754) | | | |
oneflow.nn.CELU | [oneflow._C.celu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L468) | [celu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L203) | [celu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L47) | | |
oneflow.nn.GELU | [oneflow.gelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L74) | [gelu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L253) | | | |
oneflow.nn.SiLU | [oneflow.silu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L237) | [silu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L726) | | | |
oneflow.nn.Sigmoid | [oneflow.sigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L338) | [sigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L281) | [hard_sigmoid_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L87) | | |
oneflow.nn.Mish | [oneflow.mish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L267) | [mish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L698) | | | |
oneflow.nn.Softplus | [oneflow.softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L146) | [softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_softplus.py#L43) | | | |
oneflow.nn.Softshrink | [oneflow._C.softshrink](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L518) | [softshrink_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L895) | | | |
oneflow.nn.Softsign | [oneflow._C.softsign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L207) | [softsign_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L782) | | | |
oneflow.nn.Tanh | [oneflow.tanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L163) | [tanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L106) | | | |
oneflow.nn.Threshold | [oneflow._C.threshold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L496) | [softplus_threshold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L533) | | | |
oneflow.nn.GLU | [oneflow._C.glu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L436) | [glu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_glu.py#L37) | [glu_scalar_tensor_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L57) | | |
oneflow.nn.Softmax | [oneflow._C.softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L118) | [softmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L436) | [softmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L109) | | |
oneflow.nn.LogSoftmax | | [logsoftmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L460) | | | |
oneflow.nn.BatchNorm1d | | [batchnorm1d_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_batchnorm.py#L34) | | | |
oneflow.nn.BatchNorm2d | | [batchnorm2d_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_batchnorm.py#L52) | | | |
oneflow.nn.BatchNorm3d | |
[batchnorm3d_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_batchnorm.py#L70) | | | |
oneflow.nn.FusedBatchNorm1d | | | | | |
oneflow.nn.FusedBatchNorm2d | | | | | |
oneflow.nn.FusedBatchNorm3d | | | | | |
oneflow.nn.GroupNorm | | [groupnorm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_groupnorm.py#L332) | | | |
oneflow.nn.InstanceNorm1d | | [instancenorm1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_instancenorm.py#L29) | | | |
oneflow.nn.InstanceNorm2d | | [instancenorm2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_instancenorm.py#L71) | | | |
oneflow.nn.InstanceNorm3d | | [instancenorm3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_instancenorm.py#L141) | | | |
oneflow.nn.LayerNorm | | [t5_layernorm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_t5_layernorm.py#L83) | [layernorm_exception_input_shape_not_match](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_layernorm.py#L25) | | |
oneflow.nn.RMSLayerNorm | | | | | |
oneflow.nn.RNN | | [rnn_relu_cell](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_rnn_cell.py#L206) | | | |
oneflow.nn.LSTM | | [lstm_cell](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_rnn_cell.py#L200) | | | |
oneflow.nn.GRU | | [gru_cell](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_rnn_cell.py#L218) | | | |
oneflow.nn.RNNCell | | | | | |
oneflow.nn.LSTMCell | | | | | |
oneflow.nn.GRUCell | | | | | |
oneflow.nn.Identity | | [identity](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_linear.py#L113) | | | |
oneflow.nn.Linear | | [interpolate_linear_1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_interpolate.py#L27) | | | |
oneflow.nn.Dropout | [oneflow._C.dropout](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L20) | [dropout_p01](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_dropout.py#L44) | | | |
oneflow.nn.Dropout1d | [oneflow._C.dropout1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L102) | [dropout1d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L309) | | | |
oneflow.nn.Dropout2d | [oneflow._C.dropout2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L124) | [dropout2d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L316) | | | |
oneflow.nn.Dropout3d | [oneflow._C.dropout3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L146) | [dropout3d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L323) | | | |
oneflow.nn.Embedding | | [one_embedding_adagrad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_adagrad.py#L174) | | | |
oneflow.nn.CosineSimilarity | | | | | |
oneflow.nn.PairwiseDistance | | | | | |
oneflow.nn.BCELoss | | | | | |
oneflow.nn.BCEWithLogitsLoss | | | | | |
oneflow.nn.CTCLoss | | | [ctcloss_reduction_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L62) | | |
oneflow.nn.CombinedMarginLoss | | | | | |
oneflow.nn.CrossEntropyLoss | | | | | |
oneflow.nn.KLDivLoss | | | | | |
oneflow.nn.L1Loss | | | | | |
oneflow.nn.MSELoss | | | | | |
oneflow.nn.MarginRankingLoss | | | | | |
oneflow.nn.NLLLoss | | | | | |
oneflow.nn.SmoothL1Loss | | | | | |
oneflow.nn.TripletMarginLoss | | | | | |
oneflow.nn.PixelShuffle | | | | | |
oneflow.nn.Upsample | | [upsample_bilinear_align_corners](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_upsample.py#L338) | | | |
oneflow.nn.UpsamplingBilinear2d | | [UpsamplingBilinear2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_upsample.py#L97) | | | |
oneflow.nn.UpsamplingNearest2d | | [UpsamplingNearest2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_upsample.py#L74) | | | |
oneflow.nn.parallel.DistributedDataParallel | | | | | |
oneflow.nn.COCOReader | | | | | |
oneflow.nn.CoinFlip | | | | | |
oneflow.nn.CropMirrorNormalize | | | | | |
oneflow.nn.OFRecordBytesDecoder | | [OFRecordBytesDecoder](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dataset.py#L351) | | | |
oneflow.nn.OFRecordImageDecoder | | | | | |
oneflow.nn.OFRecordImageDecoderRandomCrop | | | | | |
oneflow.nn.OFRecordRawDecoder | | | | | |
oneflow.nn.OFRecordReader | | | | | |
oneflow.nn.MinMaxObserver | | | | | |
oneflow.nn.MovingAverageMinMaxObserver | | | | | |
oneflow.nn.FakeQuantization | | | | | |
oneflow.nn.QatConv1d | | | | | |
oneflow.nn.QatConv2d | | | | | |
oneflow.nn.QatConv3d | | | | | |
oneflow.nn.utils.clip_grad_norm_ | |
[clip_grad_norm_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L50) | | | | oneflow.nn.utils.clip_grad_value_ | | [clip_grad_value_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L79) | | | | oneflow.nn.utils.weight_norm | | [weight_norm_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_weight_norm.py#L150) | | | | oneflow.nn.utils.remove_weight_norm | | | | | | oneflow.nn.utils.rnn.PackedSequence | | | | | | oneflow.nn.utils.rnn.pack_padded_sequence | | | | | | oneflow.nn.utils.rnn.pad_packed_sequence | | | | | | oneflow.nn.utils.rnn.pad_sequence | | | | | | oneflow.nn.utils.rnn.pack_sequence | | | | | | oneflow.nn.Flatten | [oneflow.Tensor.flatten](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L155) | [to_global_flatten_hierarchy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L30) | | | | oneflow.nn.FakeQuantization | | | | | | oneflow.nn.MinMaxObserver | | | | | | oneflow.nn.MovingAverageMinMaxObserver | | | | | | oneflow.nn.Quantization | | | | | | oneflow.BoolTensor | | | | | | oneflow.ByteTensor | | | | | | oneflow.CharTensor | | | | | | oneflow.DoubleTensor | | | | | | oneflow.FloatTensor | | | | | | oneflow.HalfTensor | | | | | | oneflow.IntTensor | | | | | | oneflow.LongTensor | | | | | | oneflow.is_tensor | | [ellipsis_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_indexing2.py#L900) | [rol_align_rois_tensor_dimension_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_roi_align_op.py#L34) | | | oneflow.is_floating_point | [oneflow.is_floating_point](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/is_floating_point.py#L20) | [is_floating_point](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_ops.py#L176) | | | | oneflow.is_nonzero | | | | | | oneflow.numel | [oneflow.Tensor.numel](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L194) | [tensor_numel](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L558) | | | | oneflow.set_printoptions | | | | | | oneflow.tensor | [oneflow.tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L20) | [type_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_type_tensor.py#L74) | 
[call_to_local_for_local_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L65) | | | oneflow.as_tensor | [oneflow.as_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/as_tensor.py#L20) | [reshape_as_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1181) | | | | oneflow.as_strided | [oneflow.as_strided](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1529) | [flow_as_strided_with_stride](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_as_stride.py#L49) | | | | oneflow.from_numpy | [oneflow.from_numpy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L55) | [copy_to_and_from_numpy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L73) | | | | oneflow.zeros | | [zeros_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_zeros_like.py#L27) | | | | oneflow.zeros_like | [oneflow.zeros_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/constant.py#L53) | [zeros_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_zeros_like.py#L27) | | | | oneflow.ones | | [ones_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ones_like.py#L27) | | | | oneflow.ones_like | [oneflow.ones_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/constant.py#L20) | [ones_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ones_like.py#L27) | | | | oneflow.randint_like | [oneflow._C.randint_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L242) | [consistent_randint_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_randint_like.py#L27) | | | | oneflow.masked_fill | [oneflow.Tensor.masked_fill](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1708) | [masked_fill](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_masked_fill.py#L58) | | | | 
oneflow.new_ones | [oneflow.Tensor.new_ones](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L229) | [flow_new_ones_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant.py#L107) | | | | oneflow.arange | [oneflow.arange](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/arange.py#L20) | [arange](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_arange.py#L63) | | done | | oneflow.linspace | | [global_linspace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_linspace.py#L26) | | | | oneflow.eye | [oneflow.eye](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1597) | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27) | | done | | oneflow.empty | [oneflow.empty](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/constant.py#L119) | [slice_empty](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_slice.py#L51) | [reduce_sum_like_empty_axis_case_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reduce_like_ops.py#L24) | | | oneflow.empty_like | [oneflow.empty_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/constant.py#L160) | | | | | oneflow.full | | [global_full](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_full.py#L27) | | | | oneflow.full_like | | [full_like_with_random_data_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant.py#L154) | | | | oneflow.tensor_scatter_nd_update | | [global_tensor_scatter_nd_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L128) | [tensor_scatter_nd_update_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L156) | | | oneflow.logspace | | [logspace_int_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logspace.py#L26) | | | | oneflow.argwhere | 
[oneflow.Tensor.argwhere](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L713) | [argwhere_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argwhere.py#L50) | | | | oneflow.atleast_1d | [oneflow.atleast_1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L272) | [atleast_1d_with_list_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_atleast.py#L28) | | | | oneflow.atleast_2d | [oneflow.atleast_2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L306) | [atleast_2d_with_list_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_atleast.py#L43) | | | | oneflow.atleast_3d | [oneflow.atleast_3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L344) | [atleast_3d_with_list_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_atleast.py#L59) | | | | oneflow.cat | [oneflow.cat](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L613) | [cat_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_concat.py#L138) | | | | oneflow.column_stack | [oneflow.column_stack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L513) | [column_stack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L48) | | | | oneflow.concat | | [concat_with_input_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_concat.py#L164) | [concat_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L37) | | | oneflow.chunk | [oneflow.Tensor.chunk](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L883) | [flow_chunk_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_chunk.py#L46) | [chunk_0_dim_input_exception](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_chunk.py#L25) | | | oneflow.dstack | 
[oneflow.dstack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L481) | [dstack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L115) | | | | oneflow.expand | [oneflow.Tensor.expand](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L130) | [expand_new_dims_broadcast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_expand_op.py#L28) | [expand_dim_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L78) | | | oneflow.gather | [oneflow.Tensor.gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1531) | [gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_gather_nd.py#L85) | [gather_index_type_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L120) | done | | oneflow.gather_nd | [oneflow.gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L685) | [gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_gather_nd.py#L85) | | | | oneflow.batch_gather | [oneflow.batch_gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L199) | [batch_gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_batch_gather.py#L74) | | | | oneflow.hsplit | [oneflow.hsplit](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1674) | [flow_hsplit_vec](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_hsplit.py#L27) | | | | oneflow.hstack | [oneflow.hstack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L413) | [hstack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L80) | | | | oneflow.vsplit | [oneflow.vsplit](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1714) | 
[flow_vsplit_vec](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_vsplit.py#L27) | | | | oneflow.vstack | [oneflow.vstack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L447) | [vstack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L98) | | | | oneflow.index_select | [oneflow.Tensor.index_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L185) | [index_select_by_random](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_index_select.py#L30) | [index_select_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L330) | | | oneflow.masked_select | [oneflow.Tensor.masked_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1715) | [masked_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_masked_select.py#L87) | | | | oneflow.movedim | [oneflow.movedim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1496) | [flow_movedim_with_vector](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_movedim.py#L27) | | | | oneflow.narrow | [oneflow.Tensor.narrow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L629) | [narrow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_narrow.py#L35) | [narrow_dim_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L178) | | | oneflow.nonzero | [oneflow.nonzero](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/nonzero.py#L20) | [nonzero](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_nonzero.py#L51) | | | | oneflow.permute | [oneflow.Tensor.permute](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L643) | [einsum_batch_permute](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_batch_permute.py#L42) | | | | oneflow.repeat | 
[oneflow.Tensor.repeat](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1622) | [flow_tensor_repeat_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_repeat.py#L27) | [repeat_interleave_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L25) | | | oneflow.reshape | [oneflow.Tensor.reshape](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1858) | [reshape_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_reshape.py#L27) | [reshape_like_size_match_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reshape_like_op.py#L24) | done | | oneflow.row_stack | [oneflow.row_stack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L547) | [row_stack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L64) | | | | oneflow.select | [oneflow.select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1467) | [flow_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_select.py#L28) | [index_select_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L330) | | | oneflow.scatter | | [global_tensor_scatter_nd_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L128) | [tensor_scatter_nd_update_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L156) | | | oneflow.scatter_add | | [scatter_add_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_scatter_ops.py#L57) | | | | oneflow.scatter_nd | | [global_tensor_scatter_nd_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L128) | [tensor_scatter_nd_update_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L156) | | | oneflow.slice | | 
[slice_grad_grad_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_higher_derivative_slice.py#L38) | [slice_update_start_list_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_slice_op.py#L23) | | | oneflow.slice_update | | [slice_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_slice_update.py#L120) | [slice_update_start_list_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_slice_op.py#L23) | | | oneflow.split | [oneflow.Tensor.split](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L890) | [eager_boxing_2d_special_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L146) | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39) | | | oneflow.squeeze | [oneflow.Tensor.squeeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L556) | [squeeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_squeeze.py#L94) | [squeeze_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L106) | | | oneflow.stack | [oneflow.stack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L382) | [stack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L28) | [stack_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L62) | | | oneflow.swapaxes | [oneflow.Tensor.swapaxes](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L904) | [swapaxes_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapaxes.py#L31) | | | | oneflow.swapdims | [oneflow.Tensor.swapdims](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L918) | [swapdims_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapdims.py#L32) | | | | oneflow.t | 
[oneflow.Tensor.t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1640) | [global_tensor_scatter_nd_update_t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L140) | [t_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L439) | | | oneflow.tile | [oneflow.tile](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tile.py#L20) | [flow_tile_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tile.py#L27) | [tile_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L431) | | | oneflow.transpose | [oneflow.Tensor.transpose](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L513) | [einsum_matrix_transpose](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_matrix_transpose.py#L35) | | | | oneflow.unbind | [oneflow.Tensor.unbind](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L897) | [unbind_flow_with_random_data1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unbind.py#L32) | [unbind_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L248) | | | oneflow.unsqueeze | [oneflow.Tensor.unsqueeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L636) | [unsqueeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L68) | | | | oneflow.where | [oneflow.Tensor.where](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2129) | [where](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_where.py#L196) | | | | oneflow.tensor_split | [oneflow.tensor_split](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1634) | [flow_tensor_split_vec](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_split.py#L27) | | | | oneflow.seed | | 
[generator_manual_seed](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L72) | | | | oneflow.manual_seed | | [generator_manual_seed](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L72) | | | | oneflow.initial_seed | | | | | | oneflow.get_rng_state | | [get_rng_state](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L137) | | | | oneflow.set_rng_state | | [set_rng_state](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L148) | | | | oneflow.bernoulli | [oneflow.bernoulli](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L20) | [bernoulli](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_bernoulli.py#L56) | | | | oneflow.normal | [oneflow._C.normal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L154) | [eager_boxing_normal_1d_exhaustive_testing](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L113) | [normal_data_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L278) | | | oneflow.rand | [oneflow._C.rand](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L112) | [0d_rand](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_rand.py#L45) | | | | oneflow.randint | [oneflow._C.randint](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L191) | [global_randint](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_randint.py#L27) | | | | oneflow.randn | [oneflow._C.randn](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L71) | [randn](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_randn.py#L103) | | | | oneflow.randperm | [oneflow._C.randperm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L291) | [global_randperm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_randperm.py#L26) | 
[randperm_n_value_err_mes](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_randperm_op.py#L24) | | | oneflow.save | | [save_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L222) | | | | oneflow.load | | [resnet18_load_weight_compatibile](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_resnet_load_torch_weight_compatibile.py#L30) | | | | oneflow.set_num_threads | [oneflow.set_num_threads](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/oneflow.py#L20) | | | | | oneflow.no_grad | | [no_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L62) | | | | oneflow.set_grad_enabled | | [set_grad_enabled](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L74) | | | | oneflow.enable_grad | | [enable_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L50) | | | | oneflow.is_grad_enabled | | | | | | oneflow.inference_mode | | [inference_mode](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L27) | | | | oneflow.abs | [oneflow.abs](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L20) | [abs_with_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_abs.py#L27) | | done | | oneflow.acos | [oneflow.acos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L509) | [acos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L122) | | | | oneflow.acosh | [oneflow.acosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L535) | [acosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L138) | | | | oneflow.arccos | [oneflow.Tensor.arccos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L664) | [arccos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L114) | | | | oneflow.arccosh | [oneflow.Tensor.arccosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L678) | 
[arccosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L130) | | | | oneflow.add | [oneflow.add](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L41) | [scatter_add_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_scatter_ops.py#L57) | [add_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L27) | done | | oneflow.addcdiv | [oneflow.Tensor.addcdiv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L939) | [addcdiv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_addcdiv.py#L25) | | done | | oneflow.addcmul | [oneflow.addcmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1558) | [addcmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_addcmul.py#L37) | | done | | oneflow.asin | [oneflow.asin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L285) | [flow_asin_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L223) | | | | oneflow.asinh | [oneflow.asinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L318) | [flow_asinh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L240) | | | | oneflow.arcsin | [oneflow.Tensor.arcsin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1257) | [flow_arcsin_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L230) | | | | oneflow.arcsinh | [oneflow.Tensor.arcsinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1264) | [flow_arcsinh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L247) | | | | oneflow.atan | [oneflow.atan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L353) | 
[flow_atan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L267) | | | | oneflow.atanh | [oneflow.atanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L564) | [flow_atanh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L289) | | | | oneflow.arctan | [oneflow.Tensor.arctan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1343) | [flow_arctan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L274) | | | | oneflow.arctanh | [oneflow.Tensor.arctanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L685) | [flow_arctanh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L296) | | | | oneflow.atan2 | [oneflow.Tensor.atan2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L123) | [atan2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L155) | | | | oneflow.ceil | [oneflow.ceil](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L378) | [ceil_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_ceil.py#L29) | | | | oneflow.clamp | [oneflow.clamp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L20) | [clamp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L96) | | | | oneflow.clamp_min | [oneflow.clamp_min](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L70) | [clamp_min_none_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L119) | | | | oneflow.clamp_max | [oneflow.clamp_max](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L111) | [clamp_max_none_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L126) | | | | oneflow.clip | 
[oneflow.clip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L152) | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213) | | | | oneflow.cos | [oneflow.cos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L712) | [global_cos_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L65) | | | | oneflow.cosh | [oneflow.cosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L736) | | | | | oneflow.div | [oneflow.div](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L143) | [div_grad_grad_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_higher_derivative_div.py#L26) | [div_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L81) | done | | oneflow.erf | [oneflow.erf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L763) | [flow_erf_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erf.py#L33) | | done | | oneflow.erfc | [oneflow.erfc](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L810) | [erfc_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_erfc.py#L25) | | done | | oneflow.erfinv | [oneflow.Tensor.erfinv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L997) | [flow_erfinv_with_inf_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erfinv.py#L30) | | done | | oneflow.exp | [oneflow.exp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L476) | [exp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L72) | | | | oneflow.expm1 | [oneflow.expm1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L845) | [expm1_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_expm1.py#L29) | | 
done | | oneflow.floor | [oneflow.floor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L100) | [floor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_floor.py#L35) | | done | | oneflow.floor_ | [oneflow.floor_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L135) | [flow_floor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_floor.py#L57) | | | | oneflow.fmod | [oneflow.fmod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L890) | [flow_fmod_element_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1021) | | done | | oneflow.gelu | [oneflow.gelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L74) | [gelu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L253) | | done | | oneflow.log | [oneflow.log](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L923) | [log](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L56) | | | | oneflow.log1p | [oneflow.log1p](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L455) | [log1p_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_log1p.py#L31) | | | | oneflow.log2 | [oneflow.log2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L948) | [log2_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L971) | | | | oneflow.logical_and | [oneflow.Tensor.logical_and](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1677) | [logical_and](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_and.py#L58) | | | | oneflow.logical_not | [oneflow.Tensor.logical_not](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L520) | 
[logical_not](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_not.py#L43) | | | | oneflow.logical_or | [oneflow.Tensor.logical_or](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1687) | [logical_or](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_or.py#L58) | | | | oneflow.logical_xor | [oneflow.Tensor.logical_xor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1698) | [logical_xor_int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L27) | | | | oneflow.mish | [oneflow.mish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L267) | [mish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L698) | | done | | oneflow.mul | [oneflow.mul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L186) | [einsum_eltwise_mul_then_reduce_sum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_eltwise_mul_then_reduce_sum.py#L40) | | | | oneflow.neg | [oneflow.Tensor.neg](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1039) | [flow_split_sizes_neg_dim_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_split.py#L63) | [tensordot_neg_dims_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_tensordot.py#L25) | | | oneflow.negative | [oneflow.negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L428) | [argmax_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmax.py#L29) | [repeat_interleave_negative_tensor_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L58) | | | oneflow.pow | [oneflow.pow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1132) | [pow_with_scalar](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L96) | | | | oneflow.reciprocal | 
[oneflow.reciprocal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L226) | [flow_reciprocal_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_reciprocal.py#L32) | | | | oneflow.round | [oneflow.round](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1346) | [flow_round_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_round.py#L30) | | | | oneflow.rsqrt | [oneflow.rsqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1173) | [rsqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L80) | | | | oneflow.selu | [oneflow.selu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L409) | [selu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L754) | | done | | oneflow.softmax | [oneflow._C.softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L118) | [softmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L436) | [softmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L109) | done | | oneflow.softplus | [oneflow.softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L146) | [softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_softplus.py#L43) | | done | | oneflow.softsign | [oneflow._C.softsign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L207) | [softsign_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L782) | | done | | oneflow.silu | [oneflow.silu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L237) | [silu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L726) | | done | | oneflow.sigmoid | 
[oneflow.sigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L338) | [sigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L281) | [hard_sigmoid_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L87) | done | | oneflow.sign | [oneflow.sign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L589) | [sign_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sign.py#L25) | | | | oneflow.sin | [oneflow.sin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L618) | [global_sin_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L59) | | | | oneflow.sinh | [oneflow.sinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L656) | [sinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L23) | | | | oneflow.sin_ | [oneflow.sin_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L648) | [global_sin_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L59) | | | | oneflow.sqrt | [oneflow.sqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1198) | [sqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L64) | | | | oneflow.square | [oneflow.square](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1224) | [inv_random_square_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_inv.py#L39) | [inv_exception_not_square_matrix](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_inv.py#L34) | | | oneflow.sub | [oneflow.sub](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L246) | 
[global_sub](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sub.py#L50) | | | | oneflow.tan | [oneflow.tan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L687) | [flow_tan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L257) | | | | oneflow.tanh | [oneflow.tanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L163) | [tanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L106) | | done | | oneflow.floor_divide | | | | | | oneflow.argmax | [oneflow.Tensor.argmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L692) | [argmax_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmax.py#L29) | [argmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L22) | done | | oneflow.argmin | [oneflow.Tensor.argmin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L699) | [argmin_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmin.py#L29) | | | | oneflow.amax | [oneflow.Tensor.amax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L911) | [amax_with_negative_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_amax.py#L35) | | done | | oneflow.amin | [oneflow.Tensor.amin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2167) | [amin_with_negative_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_amin.py#L34) | | done | | oneflow.any | [oneflow.Tensor.any](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1831) | [any_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_reduce.py#L52) | | | | oneflow.max | [oneflow.Tensor.max](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1774) | 
[moving_average_min_max_observer](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_moving_average_max_min_observer.py#L83) | | | | oneflow.min | [oneflow.Tensor.min](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1783) | [moving_average_min_max_observer](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_moving_average_max_min_observer.py#L83) | | | | oneflow.mean | [oneflow.Tensor.mean](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1840) | [mean](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_mean.py#L70) | [normalization_moving_mean_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L317) | | | oneflow.median | [oneflow.median](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1019) | [median](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_median.py#L48) | [median_exception_dim_out_of_range](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_median.py#L25) | | | oneflow.prod | [oneflow.Tensor.prod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1849) | [prod_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_reduce.py#L59) | | | | oneflow.std | [oneflow.std](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1371) | [std_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_std.py#L26) | | | | oneflow.sum | [oneflow.Tensor.sum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1813) | [sum_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_sum.py#L29) | [reduce_sum_like_empty_axis_case_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reduce_like_ops.py#L24) | | | oneflow.var | [oneflow.var](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1407) | 
[module_to_with_var_reuse](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L93) | | | | oneflow.norm | [oneflow.linalg.norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L160) | [clip_grad_norm_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L50) | | | | oneflow.all | [oneflow.Tensor.all](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1822) | [flow_var_all_dim_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_var.py#L27) | | | | oneflow.argsort | [oneflow.Tensor.argsort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L706) | [argsort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_argsort.py#L37) | | done | | oneflow.eq | [oneflow.Tensor.eq](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1011) | [eq_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_eq.py#L25) | | done | | oneflow.equal | | [softmax_module_with_batch_size_equal_1024](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L464) | [concat_dim_equal_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L44) | | | oneflow.gt | [oneflow.Tensor.gt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1080) | | | done | | oneflow.isinf | [oneflow.Tensor.isinf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2152) | [isinf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_util_ops.py#L33) | | | | oneflow.isnan | [oneflow.Tensor.isnan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2145) | [isnan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_util_ops.py#L24) | | | | oneflow.le | [oneflow.Tensor.le](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1025) | | | | | oneflow.lt | 
[oneflow.Tensor.lt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1018) | | | | | oneflow.ne | [oneflow.Tensor.ne](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1032) | [ne](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ne.py#L31) | | | | oneflow.sort | [oneflow.Tensor.sort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1947) | [sort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_sort.py#L69) | | | | oneflow.topk | [oneflow.Tensor.topk](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1751) | [flow_topk_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L306) | | | | oneflow.ge | [oneflow.Tensor.ge](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1062) | | | | | oneflow.greater | [oneflow.greater](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/comparison.py#L21) | [greater_normal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_greater.py#L29) | | | | oneflow.greater_equal | [oneflow.greater_equal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/comparison.py#L49) | [greater_equal_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_greater_equal.py#L25) | | | | oneflow.maximum | [oneflow.maximum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L997) | [broadcast_maximum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maximum_minimum.py#L32) | | | | oneflow.minimum | [oneflow.minimum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L975) | [broadcast_minimum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maximum_minimum.py#L50) | | | | oneflow.not_equal | | | | | | oneflow.hann_window | [oneflow.hann_window](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/hann_window.py#L20) | 
[global_hann_window](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_hann_window.py#L26) | [hann_window_dtype_not_support](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_hann_window.py#L25) | done | | oneflow.adaptive_avg_pool1d | [oneflow._C.adaptive_avg_pool1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L20) | | | done | | oneflow.adaptive_avg_pool2d | [oneflow._C.adaptive_avg_pool2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L48) | | | done | | oneflow.adaptive_avg_pool3d | [oneflow._C.adaptive_avg_pool3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L74) | | | done | | oneflow.broadcast_like | [oneflow.broadcast_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/broadcast_like.py#L20) | [broadcast_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_broadcast_like.py#L161) | [broadcast_like_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L28) | | | oneflow.cast | [oneflow.cast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/cast.py#L20) | [cast_float2int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cast.py#L28) | [add_broad_cast_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L37) | | | oneflow.cumprod | [oneflow.cumprod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1788) | [cumprod_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cumprod.py#L25) | | done | | oneflow.cumsum | [oneflow.cumsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1755) | [cumsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cumsum.py#L37) | | done | | oneflow.diag | [oneflow.Tensor.diag](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L932) | [diag_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diag.py#L26) | | done | | oneflow.diagonal | 
[oneflow.Tensor.diagonal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1294) | [diagonal_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diagonal.py#L24) | [diagonal_index_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L204) | done | | oneflow.einsum | [oneflow.einsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/einsum.py#L20) | [einsum_alphaflod_usecase11](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_alphaflod_usecase11.py#L38) | | | | oneflow.flatten | [oneflow.Tensor.flatten](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L155) | [to_global_flatten_hierarchy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L30) | | done | | oneflow.flip | [oneflow.Tensor.flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L169) | [image_flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_flip.py#L70) | | done | | oneflow.in_top_k | [oneflow.Tensor.in_top_k](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L176) | [in_top_k_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_in_top_k.py#L82) | [in_top_k_num_equal_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L389) | | | oneflow.meshgrid | [oneflow.meshgrid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/meshgrid.py#L20) | [meshgrid_forawd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_meshgrid.py#L29) | [meshgrid_tensors_scalar_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L276) | | | oneflow.nms | [oneflow.Tensor.nms](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1758) | [nms](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_nms.py#L50) | | | | oneflow.roc_auc_score | 
[oneflow.roc_auc_score](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/roc_auc_score.py#L20) | [roc_auc_score](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_roc_auc_score.py#L52) | | | | oneflow.roll | [oneflow.Tensor.roll](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1187) | [roll](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_roll.py#L27) | [roll_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L112) | | | oneflow.searchsorted | [oneflow.searchsorted](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/searchsorted.py#L20) | | | | | oneflow.tensordot | [oneflow.tensordot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensordot.py#L20) | [tensordot_intdim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensordot.py#L28) | [tensordot_neg_dims_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_tensordot.py#L25) | | | oneflow.tril | [oneflow.Tensor.tril](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1441) | [fused_scale_tril](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_fused_scale_tril.py#L78) | | | | oneflow.repeat_interleave | [oneflow.Tensor.repeat_interleave](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1631) | [flow_int_repeat_interleave_dim_none](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_repeat_interleave.py#L29) | [repeat_interleave_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L25) | | | oneflow.triu | [oneflow.Tensor.triu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1448) | [triu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_triu.py#L47) | | | | oneflow.addmm | [oneflow.Tensor.addmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1215) | 
[addmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_addmm.py#L60) | | done | | oneflow.bmm | [oneflow.Tensor.bmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L876) | [bmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_bmm.py#L93) | [bmm_exception_dim_not_right](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_bmm.py#L25) | | | oneflow.dot | [oneflow.dot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1438) | [dot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L903) | [dot_shape_error_msg](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_dot.py#L24) | done | | oneflow.matmul | [oneflow.matmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1249) | [fused_matmul_op](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cublas_fused_mlp.py#L173) | [matmul_dimension_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L220) | | | oneflow.mm | [oneflow.mm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1311) | [flow_mm_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_matmul.py#L69) | [mm_not_2dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_mm.py#L24) | | | oneflow.mv | [oneflow.mv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1278) | [flow_mv_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_matmul.py#L78) | [mv_not_matrix](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_mv.py#L23) | done | | oneflow.env.all_device_placement | | | | | | oneflow.env.get_world_size | | | | | | oneflow.env.get_rank | | | | | | oneflow.env.get_local_rank | | | | | | oneflow.env.get_node_size | | | | | | oneflow.env.init_rdma | | | | | | oneflow.env.rdma_is_initialized | | | | | | oneflow.comm.all_reduce | | [all_reduce_1n2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm_ops.py#L31) | | | | 
oneflow.comm.all_gather | | [all_gather_1n2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm_ops.py#L48) | | | | oneflow.comm.all_to_all | | [all_to_all_1n4d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm_ops.py#L148) | | | | oneflow.comm.broadcast | | [cosine_similartiy_broadcast_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cosine_similarity.py#L45) | [cosine_similarity_broadcast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_cosine_similarity.py#L34) | | | oneflow.comm.barrier | | | | | | oneflow.comm.gather | [oneflow.Tensor.gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1531) | [gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_gather_nd.py#L85) | [gather_index_type_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L120) | done | | oneflow.comm.reduce | | [min_reduce_random_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_min.py#L28) | [reduce_sum_like_empty_axis_case_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reduce_like_ops.py#L24) | | | oneflow.comm.reduce_scatter | | [reduce_scatter_1n4d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm_ops.py#L167) | | | | oneflow.comm.recv | [oneflow.comm.recv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/comm.py#L32) | [send_recv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm.py#L28) | | | | oneflow.comm.scatter | | [global_tensor_scatter_nd_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L128) | [tensor_scatter_nd_update_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L156) | | | oneflow.comm.send | [oneflow.comm.send](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/comm.py#L20) | [send_recv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm.py#L28) | | | | oneflow.linalg.norm | 
[oneflow.linalg.norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L160) | [clip_grad_norm_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L50) | | | | oneflow.linalg.vector_norm | [oneflow.linalg.vector_norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L21) | [vector_norm_only_zero_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_norm.py#L318) | | | | oneflow.linalg.matrix_norm | [oneflow.linalg.matrix_norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L88) | | | | | oneflow.linalg.diagonal | [oneflow.Tensor.diagonal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1294) | [diagonal_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diagonal.py#L24) | [diagonal_index_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L204) | done | | oneflow.linalg.inv | [oneflow.linalg.inv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/inv.py#L21) | [inv_3by3_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_inv.py#L27) | [inv_exception_dim_short](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_inv.py#L25) | done | | oneflow.optim.Optimizer.add_param_group | | [sgd_add_param_group](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_add_param_group.py#L44) | [sgd_add_param_group_not_unique](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_optim_add_param_group.py#L23) | | | oneflow.optim.Optimizer.load_state_dict | | [load_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L63) | | | | oneflow.optim.Optimizer.state_dict | | [load_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L63) | | | | oneflow.optim.Optimizer.step | | [arange_step_prarm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_arange.py#L35) | 
[slice_update_step_list_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_slice_op.py#L49) | | | oneflow.optim.Optimizer.zero_grad | | | | | | oneflow.optim.Adagrad | | [one_embedding_adagrad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_adagrad.py#L174) | | | | oneflow.optim.Adam | | [multi_tensor_adam_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_multi_tensor_adam_update.py#L157) | | | | oneflow.optim.AdamW | | [adamw](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adamw.py#L244) | | | | oneflow.optim.LAMB | | [lamb](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_lamb.py#L157) | | | | oneflow.optim.RMSprop | | [rmsprop](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_rmsprop.py#L228) | | | | oneflow.optim.SGD | | [one_embedding_sgd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_sgd.py#L190) | [sgd_add_param_group_not_unique](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_optim_add_param_group.py#L23) | | | oneflow.optim.lr_scheduler.CosineAnnealingLR | | | | | | oneflow.optim.lr_scheduler.CosineDecayLR | | | | | | oneflow.optim.lr_scheduler.ExponentialLR | | | | | | oneflow.optim.lr_scheduler.LambdaLR | | | | | | oneflow.optim.lr_scheduler.MultiStepLR | | | | | | oneflow.optim.lr_scheduler.PolynomialLR | | | | | | oneflow.optim.lr_scheduler.ReduceLROnPlateau | | | | | | oneflow.optim.lr_scheduler.StepLR | | | | | | oneflow.optim.lr_scheduler.ConstantLR | | | | | | oneflow.optim.lr_scheduler.LinearLR | | | | | | oneflow.optim.lr_scheduler.ChainedScheduler | | | | | | oneflow.optim.lr_scheduler.SequentialLR | | | | | | oneflow.optim.lr_scheduler.CosineAnnealingWarmRestarts | | | | | | oneflow.one_embedding.make_table_options | | | | | | oneflow.one_embedding.make_table | | | | | | oneflow.one_embedding.make_uniform_initializer | | | | | | oneflow.one_embedding.make_normal_initializer | | | | | | oneflow.one_embedding.make_device_mem_store_options | | | | | | oneflow.one_embedding.make_cached_ssd_store_options | | | | | | oneflow.one_embedding.make_cached_host_mem_store_options | | | | | | oneflow.one_embedding.MultiTableEmbedding | | | | | | oneflow.one_embedding.MultiTableEmbedding.forward | | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27) | | | | oneflow.one_embedding.MultiTableEmbedding.save_snapshot | | | | | | oneflow.one_embedding.MultiTableEmbedding.load_snapshot | | | | | | oneflow.one_embedding.MultiTableMultiColumnEmbedding | | | | | | oneflow.one_embedding.MultiTableMultiColumnEmbedding.forward | | 
[eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27) | | | | oneflow.one_embedding.MultiTableMultiColumnEmbedding.save_snapshot | | | | | | oneflow.one_embedding.MultiTableMultiColumnEmbedding.load_snapshot | | | | | | oneflow.one_embedding.make_persistent_table_reader | | | | | | oneflow.one_embedding.make_persistent_table_writer | | | | | | oneflow.one_embedding.Ftrl | | [ftrl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_ftrl.py#L191) | | | | oneflow.nn.init.calculate_gain | | | | | | oneflow.nn.init.uniform_ | [oneflow.Tensor.uniform_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1455) | | | | | oneflow.nn.init.normal_ | [oneflow.Tensor.normal_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1154) | [eager_boxing_normal_1d_exhaustive_testing](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L113) | [normal_data_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L278) | | | oneflow.nn.init.constant_ | | [constant_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_constant.py#L99) | | | | oneflow.nn.init.ones_ | | [ones_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ones_like.py#L27) | | | | oneflow.nn.init.zeros_ | | [zeros_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_zeros_like.py#L27) | | | | oneflow.nn.init.xavier_uniform_ | | | | | | oneflow.nn.init.xavier_normal_ | | | | | | oneflow.nn.init.kaiming_uniform_ | | | | | | oneflow.nn.init.kaiming_normal_ | | | | | | oneflow.nn.init.trunc_normal_ | | | | | | oneflow.nn.init.orthogonal_ | | | | | | oneflow.nn.image.Resize | | [image_resize_to_fixed_size](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_resize.py#L192) | | | | oneflow.nn.image.batch_align | | [image_batch_align](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_batch_align.py#L52) | | | | oneflow.nn.image.decode | | [read_decode](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_rec_ops.py#L78) | | | | oneflow.nn.image.flip | [oneflow.Tensor.flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L169) | 
[image_flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_flip.py#L70) | | done | | oneflow.nn.image.normalize | [oneflow._C.normalize](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L268) | [functional_normalize](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_normalize.py#L54) | | | | oneflow.utils.data.random_split | | | | |

## Test Data Summary

- OneFlow Total API Number: 771
- Doc Test Ratio: 63.81% (492 / 771)
- Compatible/Completeness Test Ratio: 73.80% (569 / 771)
- Exception Test Ratio: 19.71% (152 / 771)
- Performance Test Ratio: 15.56% (120 / 771)
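Several of the `oneflow.optim` and `oneflow.optim.lr_scheduler` rows above have no linked test, so for orientation here is a minimal sketch of how an optimizer and scheduler from that list are typically wired together. It uses only APIs named in the table (`oneflow.optim.SGD`, `oneflow.optim.lr_scheduler.StepLR`, `Optimizer.zero_grad`); the toy `nn.Linear` model, random data, and the concrete hyperparameters (`lr=0.1`, `step_size=1`, `gamma=0.5`) are illustrative assumptions, not one of the linked test cases.

```python
import oneflow as flow
import oneflow.nn as nn

# Toy model/data, only to drive a couple of optimizer and scheduler steps.
model = nn.Linear(4, 2)
optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = flow.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

x = flow.randn(8, 4)
target = flow.randn(8, 2)
loss_fn = nn.MSELoss()

for epoch in range(2):
    optimizer.zero_grad()          # oneflow.optim.Optimizer.zero_grad from the table
    loss = loss_fn(model(x), target)
    loss.backward()
    optimizer.step()
    scheduler.step()               # halves the learning rate after each epoch
```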
================================================
FILE: python/oneflow/test/dataloader/data_utils.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os

import oneflow as flow
import flowvision as vision
import flowvision.transforms as transforms


def load_data_cifar10(
    batch_size,
    data_dir="./data-test/cifar10",
    download=True,
    transform=None,
    source_url=None,
    num_workers=0,
):
    cifar10_train = vision.datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=download,
        transform=transform,
        source_url=source_url,
    )
    cifar10_test = vision.datasets.CIFAR10(
        root=data_dir,
        train=False,
        download=download,
        transform=transform,
        source_url=source_url,
    )
    train_iter = flow.utils.data.DataLoader(
        cifar10_train, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_iter = flow.utils.data.DataLoader(
        cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_iter, test_iter


def load_data_mnist(
    batch_size, resize=None, root="./data/mnist", download=True, source_url=None
):
    """Download the MNIST dataset and then load into memory."""
    root = os.path.expanduser(root)
    transformer = []
    if resize:
        transformer += [transforms.Resize(resize)]
    transformer += [transforms.ToTensor()]
    transformer = transforms.Compose(transformer)
    mnist_train = vision.datasets.MNIST(
        root=root,
        train=True,
        transform=transformer,
        download=download,
        source_url=source_url,
    )
    mnist_test = vision.datasets.MNIST(
        root=root,
        train=False,
        transform=transformer,
        download=download,
        source_url=source_url,
    )
    train_iter = flow.utils.data.DataLoader(
        mnist_train, batch_size, shuffle=True, num_workers=2
    )
    test_iter = flow.utils.data.DataLoader(
        mnist_test, batch_size, shuffle=False, num_workers=2
    )
    return train_iter, test_iter


def get_fashion_mnist_dataset(
    resize=None, root="./data-test/fashion-mnist", download=True, source_url=None,
):
    root = os.path.expanduser(root)
    trans = []
    if resize:
        trans.append(transforms.Resize(resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)
    mnist_train = vision.datasets.FashionMNIST(
        root=root,
        train=True,
        transform=transform,
        download=download,
        source_url=source_url,
    )
    mnist_test = vision.datasets.FashionMNIST(
        root=root,
        train=False,
        transform=transform,
        download=download,
        source_url=source_url,
    )
    return mnist_train, mnist_test


# reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch
def load_data_fashion_mnist(
    batch_size,
    resize=None,
    root="./data-test/fashion-mnist",
    download=True,
    source_url=None,
    num_workers=0,
):
    """Download the Fashion-MNIST dataset and then load into memory."""
    root = os.path.expanduser(root)
    trans = []
    if resize:
        trans.append(transforms.Resize(resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)
    mnist_train = vision.datasets.FashionMNIST(
        root=root,
        train=True,
        transform=transform,
        download=download,
        source_url=source_url,
    )
    mnist_test = vision.datasets.FashionMNIST(
        root=root,
        train=False,
        transform=transform,
        download=download,
        source_url=source_url,
    )
    train_iter = flow.utils.data.DataLoader(
        mnist_train, batch_size, shuffle=True, num_workers=num_workers
    )
    test_iter = flow.utils.data.DataLoader(
        mnist_test, batch_size, shuffle=False, num_workers=num_workers
    )
    return train_iter, test_iter


================================================
FILE: python/oneflow/test/dataloader/test_cifar_dataset_multiprocess.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import os import unittest import oneflow.unittest import oneflow as flow import oneflow.nn as nn import oneflow.optim as optim from data_utils import load_data_cifar10 import flowvision as vision import flowvision.transforms as transforms classes = ( "plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck", ) class Net(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(flow._C.relu(self.conv1(x))) x = self.pool(flow._C.relu(self.conv2(x))) x = flow.flatten(x, 1) # flatten all dimensions except batch x = flow._C.relu(self.fc1(x)) x = flow._C.relu(self.fc2(x)) x = self.fc3(x) return x def _test(test_case): if os.getenv("ONEFLOW_TEST_CPU_ONLY"): device = flow.device("cpu") else: device = flow.device("cuda") net = Net() net.to(device) optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9) criterion = nn.CrossEntropyLoss() criterion.to(device) transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),] ) train_epoch = 1 batch_size = 4 num_workers = 4 data_dir = os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "cifar10" ) train_iter, test_iter = load_data_cifar10( batch_size=batch_size, data_dir=data_dir, download=True, transform=transform, source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz", num_workers=num_workers, ) final_loss = 0 for epoch in range(1, train_epoch + 1): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(train_iter, 1): # get the inputs; data is a list of [inputs, labels] inputs, labels = data inputs = inputs.to(dtype=flow.float32, device=device) labels = labels.to(dtype=flow.int64, device=device) # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # print statistics running_loss += loss.item() if i % 200 == 0: # print every 200 mini-batches final_loss = running_loss / 200 print("epoch: %d step: %5d loss: %.3f " % (epoch, i, final_loss)) running_loss = 0.0 break print("final loss : ", final_loss) @flow.unittest.skip_unless_1n1d() class TestCifarDataset(flow.unittest.TestCase): def test_cifar_dataset(test_case): _test(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/dataloader/test_cifar_dataset_singleprocess.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import flowvision as vision import flowvision.transforms as transforms import oneflow.unittest import oneflow as flow import oneflow.nn as nn import oneflow.optim as optim from data_utils import load_data_cifar10 classes = ( "plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck", ) class Net(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(flow._C.relu(self.conv1(x))) x = self.pool(flow._C.relu(self.conv2(x))) x = flow.flatten(x, 1) # flatten all dimensions except batch x = flow._C.relu(self.fc1(x)) x = flow._C.relu(self.fc2(x)) x = self.fc3(x) return x def _test(test_case): if os.getenv("ONEFLOW_TEST_CPU_ONLY"): device = flow.device("cpu") else: device = flow.device("cuda") net = Net() net.to(device) optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9) criterion = nn.CrossEntropyLoss() criterion.to(device) transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),] ) train_epoch = 1 batch_size = 4 num_workers = 0 data_dir = os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "cifar10" ) train_iter, test_iter = load_data_cifar10( batch_size=batch_size, data_dir=data_dir, download=True, transform=transform, source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz", num_workers=num_workers, ) final_loss = 0 for epoch in range(1, train_epoch + 1): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(train_iter, 1): # get the inputs; data is a list of [inputs, labels] inputs, labels = data inputs = inputs.to(dtype=flow.float32, device=device) labels = labels.to(dtype=flow.int64, device=device) # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # print statistics running_loss += loss.item() if i % 200 == 0: # print every 200 mini-batches final_loss = running_loss / 200 print("epoch: %d step: %5d loss: %.3f " % (epoch, i, final_loss)) running_loss = 0.0 break print("final loss : ", final_loss) # test_case.assertLess(final_loss, 1.50) @flow.unittest.skip_unless_1n1d() class TestCifarDataset(flow.unittest.TestCase): def test_cifar_dataset(test_case): _test(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/dataloader/test_fashion_mnist_dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import time import oneflow.unittest import oneflow as flow import oneflow.nn as nn from data_utils import load_data_fashion_mnist def get_fashion_mnist_labels(labels): """Get text labels for Fashion-MNIST.""" text_labels = [ "t-shirt", "trouser", "pullover", "dress", "coat", "sandal", "shirt", "sneaker", "bag", "ankle boot", ] return [text_labels[int(i)] for i in labels] class FlattenLayer(nn.Module): def __init__(self): super(FlattenLayer, self).__init__() def forward(self, x): # x shape: (batch, *, *, ...) res = x.reshape(x.shape[0], -1) return res def evaluate_accuracy(data_iter, net, device=None): if device is None and isinstance(net, nn.Module): # using net device if not specified device = list(net.parameters())[0].device acc_sum, n = 0.0, 0 net.eval() with flow.no_grad(): for X, y in data_iter: X = X.to(device=device) y = y.to(device=device) acc_sum += ( net(X.to(device)).argmax(dim=1).numpy() == y.to(device).numpy() ).sum() n += y.shape[0] net.train() return acc_sum / n def _test(test_case): num_inputs, num_outputs, num_hiddens = 784, 10, 256 net = nn.Sequential( FlattenLayer(), nn.Linear(num_inputs, num_hiddens), nn.ReLU(), nn.Linear(num_hiddens, num_outputs), ) if os.getenv("ONEFLOW_TEST_CPU_ONLY"): device = flow.device("cpu") else: device = flow.device("cuda") net.to(device) batch_size = 256 num_epochs = 1 data_dir = os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "fashion-mnist" ) source_url = "https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/" train_iter, test_iter = load_data_fashion_mnist( batch_size, resize=None, root=data_dir, download=True, source_url=source_url ) loss = nn.CrossEntropyLoss() loss.to(device) optimizer = flow.optim.SGD(net.parameters(), lr=0.1) final_accuracy = 0 for epoch in range(num_epochs): train_l_sum, train_acc_sum, n = 0.0, 0.0, 0 start = time.time() for X, y in train_iter: X = X.to(device=device) y = y.to(device=device) y_hat = net(X) l = loss(y_hat, y).sum() optimizer.zero_grad() l.backward() optimizer.step() train_l_sum += l.numpy() train_acc_sum += (y_hat.argmax(dim=1).numpy() == y.numpy()).sum() n += y.shape[0] if n > 200: break test_acc = evaluate_accuracy(test_iter, net) final_accuracy = train_acc_sum / n print( "epoch %d, loss %.4f, train acc %.3f, test acc %.3f, cost >>>>>>> %s(s)" % ( epoch + 1, train_l_sum / n, final_accuracy, test_acc, str(time.time() - start), ) ) final_accuracy = train_acc_sum / n # test_case.assertLess(0.60, final_accuracy) @flow.unittest.skip_unless_1n1d() class TestFashionMnistDataset(flow.unittest.TestCase): def test_fashion_mnist_dataset(test_case): _test(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/dataloader/test_lenet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import time import unittest import oneflow as flow import oneflow.nn as nn import oneflow.unittest from data_utils import load_data_fashion_mnist # reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter05_CNN/5.5_lenet class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv = nn.Sequential( nn.Conv2d(1, 6, kernel_size=5), # in_channels, out_channels, kernel_size nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), # kernel_size, stride nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), ) self.fc = nn.Sequential( nn.Linear(16 * 4 * 4, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10), ) def forward(self, img): feature = self.conv(img) feature = feature.flatten(start_dim=1) output = self.fc(feature) return output def evaluate_accuracy(data_iter, net, device=None): if device is None and isinstance(net, nn.Module): device = list(net.parameters())[0].device acc_sum, n = 0.0, 0 net.eval() with flow.no_grad(): for X, y in data_iter: X = X.to(device=device) y = y.to(device=device) acc_sum += (net(X).argmax(dim=1).numpy() == y.numpy()).sum() n += y.shape[0] net.train() return acc_sum / n def _test_train_and_eval(test_case): if os.getenv("ONEFLOW_TEST_CPU_ONLY"): device = flow.device("cpu") else: device = flow.device("cuda") net = LeNet() lr, num_epochs = 0.02, 1 optimizer = flow.optim.SGD(net.parameters(), lr=lr, momentum=0.9) net.to(device) batch_size = 256 data_dir = os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "fashion-mnist-lenet" ) source_url = "https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/" train_iter, test_iter = load_data_fashion_mnist( batch_size=batch_size, resize=None, root=data_dir, download=True, source_url=source_url, num_workers=0, ) loss = nn.CrossEntropyLoss() loss.to(device) final_accuracy = 0 for epoch in range(num_epochs): train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time() for X, y in train_iter: X = X.to(device=device) y = y.to(device=device) # forward y_hat = net(X) l = loss(y_hat, y).sum() # backward l.backward() optimizer.step() optimizer.zero_grad() train_l_sum += l.numpy() train_acc_sum += (y_hat.argmax(dim=1).numpy() == y.numpy()).sum() n += y.shape[0] batch_count += 1 if batch_count == 20: break test_acc = evaluate_accuracy(test_iter, net) final_accuracy = train_acc_sum / n print( "epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec" % ( epoch + 1, train_l_sum / batch_count, final_accuracy, test_acc, time.time() - start, ) ) # test_case.assertLess(0.4, final_accuracy) @flow.unittest.skip_unless_1n1d() class TestLenet(flow.unittest.TestCase): def test_lenet(test_case): _test_train_and_eval(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/dataloader/test_mnist_dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import flowvision as vision import flowvision.transforms as transforms import oneflow.unittest import oneflow as flow import oneflow.nn as nn from data_utils import load_data_mnist data_dir = os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "mnist-dataset" ) train_iter, test_iter = load_data_mnist( batch_size=128, download=True, root=data_dir, source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/MNIST/", ) def evaluate_accuracy(data_iter, net, device=None): n_correct, n_samples = 0.0, 0 net.to(device) net.eval() with flow.no_grad(): for images, labels in data_iter: images = images.reshape(-1, 28 * 28) images = images.to(device=device) labels = labels.to(device=device) n_correct += (net(images).argmax(dim=1).numpy() == labels.numpy()).sum() n_samples += images.shape[0] net.train() return n_correct / n_samples class Net(nn.Module): def __init__( self, input_size=784, hidden_size1=128, hidden_size2=64, num_classes=10 ): super(Net, self).__init__() self.l1 = nn.Linear(input_size, hidden_size1) self.relu1 = nn.ReLU() self.l2 = nn.Linear(hidden_size1, hidden_size2) self.relu2 = nn.ReLU() self.l3 = nn.Linear(hidden_size2, num_classes) def forward(self, x): out = self.l1(x) out = self.relu1(out) out = self.l2(out) out = self.relu2(out) out = self.l3(out) return out def _test_train_and_eval(test_case): if os.getenv("ONEFLOW_TEST_CPU_ONLY"): device = flow.device("cpu") else: device = flow.device("cuda") model = Net() model.to(device) loss = nn.CrossEntropyLoss().to(device) optimizer = flow.optim.SGD(model.parameters(), lr=0.10) num_epochs = 1 for epoch in range(num_epochs): train_loss, n_correct, n_samples = 0.0, 0.0, 0 for images, labels in train_iter: images = images.reshape(-1, 28 * 28) images = images.to(device=device) labels = labels.to(device=device) features = model(images) l = loss(features, labels).sum() optimizer.zero_grad() l.backward() optimizer.step() train_loss += l.numpy() n_correct += (features.argmax(dim=1).numpy() == labels.numpy()).sum() n_samples += images.shape[0] if n_samples > 2000: break test_acc = evaluate_accuracy(test_iter, model, device) train_acc = n_correct / n_samples print( "epoch %d, train loss %.4f, train acc %.3f, test acc %.3f" % (epoch + 1, train_loss / n_samples, train_acc, test_acc) ) # test_case.assertLess(0.8, test_acc) @flow.unittest.skip_unless_1n1d() class TestMnistDataset(flow.unittest.TestCase): def test_mnist_dataset(test_case): _test_train_and_eval(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/dataloader/test_numpy_dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest class ScpDataset(flow.utils.data.Dataset): def __init__(self, chunksize=200, dim=81, length=2000): self.chunksize = chunksize self.dim = dim self.length = length def __getitem__(self, index): np.random.seed(index) return np.random.randn(self.chunksize, self.dim) def __len__(self): return self.length @flow.unittest.skip_unless_1n1d() class TestNumpyDataset(flow.unittest.TestCase): def test_numpy_dataset(test_case): dataset = ScpDataset() dataloader = flow.utils.data.DataLoader(dataset, batch_size=16, shuffle=True) for X in dataloader: test_case.assertEqual(X.shape, flow.Size([16, 200, 81])) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/dataloader/test_tensor_dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.nn as nn import oneflow.unittest import oneflow.optim as optim class LinearNet(nn.Module): def __init__(self, n_feature): super(LinearNet, self).__init__() self.linear = nn.Linear(n_feature, 1) def forward(self, x): y = self.linear(x) return y @unittest.skip("optimizer has a bug with 0-dim tensor") class TestTensorDataset(flow.unittest.TestCase): def test_tensor_dataset(test_case): num_inputs = 2 num_examples = 1000 true_w = [2, -3.4] true_b = 4.2 net = LinearNet(num_inputs) flow.nn.init.normal_(net.linear.weight, mean=0, std=0.01) flow.nn.init.constant_(net.linear.bias, val=0) loss = nn.MSELoss() optimizer = optim.SGD(net.parameters(), lr=0.03) features = flow.tensor( np.random.normal(0, 1, (num_examples, num_inputs)), dtype=flow.float ) labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b labels += flow.tensor( np.random.normal(0, 0.01, size=labels.size()), dtype=flow.float ) batch_size = 10 dataset = flow.utils.data.TensorDataset(features, labels) data_iter = flow.utils.data.DataLoader( dataset, batch_size, shuffle=True, num_workers=0 ) num_epochs = 10 for epoch in range(1, num_epochs + 1): for (X, y) in data_iter: output = net(X) l = loss(output, y).sum() optimizer.zero_grad() l.backward() optimizer.step() if epoch == num_epochs: test_case.assertLess(l.numpy(), 0.00025) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/dataloader/test_transforms.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import flowvision as vision import flowvision.transforms as transforms import oneflow as flow import oneflow.nn as nn import oneflow.optim as optim import oneflow.unittest from data_utils import load_data_cifar10 class Net(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(flow._C.relu(self.conv1(x))) x = self.pool(flow._C.relu(self.conv2(x))) x = flow.flatten(x, 1) # flatten all dimensions except batch x = flow._C.relu(self.fc1(x)) x = flow._C.relu(self.fc2(x)) x = self.fc3(x) return x def _test(test_case): if os.getenv("ONEFLOW_TEST_CPU_ONLY"): device = flow.device("cpu") else: device = flow.device("cuda") net = Net() net.to(device) optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9) criterion = nn.CrossEntropyLoss() criterion.to(device) transform = transforms.Compose( [ transforms.Pad(10), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomVerticalFlip(p=0.5), transforms.CenterCrop(32), transforms.Resize([32, 32]), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] ) train_epoch = 1 batch_size = 4 data_dir = os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "cifar10" ) train_iter, test_iter = load_data_cifar10( batch_size=batch_size, data_dir=data_dir, download=True, transform=transform, source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz", num_workers=0, ) final_loss = 0 for epoch in range(1, train_epoch + 1): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(train_iter, 1): # get the inputs; data is a list of [inputs, labels] inputs, labels = data inputs = inputs.to(dtype=flow.float32, device=device) labels = labels.to(dtype=flow.int64, device=device) # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # print statistics running_loss += loss.numpy() # print every 2000 mini-batches if i % 2000 == 0: final_loss = running_loss / 2000 print("epoch: %d step: %5d loss: %.3f " % (epoch, i, final_loss)) running_loss = 0.0 print("final loss : ", final_loss) # test_case.assertLess(final_loss, 1.79) @flow.unittest.skip_unless_1n1d() class TestCifarDataset(flow.unittest.TestCase): def test_cifar_dataset(test_case): _test(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_activation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np import time import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * class TestActivationError(flow.unittest.TestCase): def test_relu_inplace_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) x.relu_() test_case.assertTrue( "a leaf Tensor that requires grad is being used in an in-place operation" in str(context.exception) ) def test_prelu_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) m = flow.nn.PReLU(5) y = m(x) test_case.assertTrue( "num_parameters in prelu must be 1 or 4" in str(context.exception) ) def test_celu_inplace_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) m = flow.nn.CELU(alpha=1.0, inplace=True) y = m(x) test_case.assertTrue( "a leaf Tensor that requires grad is being used in an in-place operation" in str(context.exception) ) def test_glu_scalar_tensor_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor(1.0) m = flow.nn.GLU() y = m(x) test_case.assertTrue( "glu does not support scalars because halving size must be even" in str(context.exception) ) def test_glu_dim_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.randn(2, 4) m = flow.nn.GLU(dim=3) y = m(x) test_case.assertTrue( "Dimension out of range (expected to be in range of [-2, 1], but got 3)" in str(context.exception) ) def test_glu_dim_even_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.randn(2, 3) m = flow.nn.GLU() y = m(x) test_case.assertTrue( "Halving dimension must be even, but dimension 1 is size 3" in str(context.exception) ) def test_hard_sigmoid_inplace_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.randn(2) x.requires_grad = True m = flow.nn.Hardsigmoid(inplace=True) y = m(x) test_case.assertTrue( "a leaf Tensor that requires grad is being used in an in-place operation" in str(context.exception) ) def test_hard_shrink_inplace_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.randn(2) x.requires_grad = True m = flow.nn.Hardshrink(inplace=True) y = m(x) test_case.assertTrue( "a leaf Tensor that requires grad is being used in an in-place operation" in str(context.exception) ) def test_softmax_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.randn(2, 4) m = flow.nn.Softmax(dim=2) y = m(x) test_case.assertTrue( "Dimension out of range (expected to be in range of [-2, 1], but got 2)" in str(context.exception) ) def test_soft_shrink_inplace_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.randn(2) x.requires_grad = True m = flow.nn.Softshrink(inplace=True) y = m(x) test_case.assertTrue( "a leaf Tensor that requires grad is being used in an in-place operation" in str(context.exception) ) def test_soft_shrink_alpha_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.randn(2) x.requires_grad = True m = flow.nn.Softshrink(-0.1) y = m(x) test_case.assertTrue( "alpha must be greater or equal to 0, but found to be -0.1." 
in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_add_n_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestAddN(flow.unittest.TestCase): def test_add_n_shape_error_msg(test_case): a = flow.tensor([1, 2]) b = flow.tensor([3, 4]) c = flow.tensor([[2, 2], [2, 2]]) with test_case.assertRaises(RuntimeError) as context: flow.add(a, b, c) test_case.assertTrue( "inconsistent tensor size, expected all tensor to have the same number of elements, but got" in str(context.exception) ) def test_add_n_dtype_error_msg(test_case): a = flow.tensor([1, 2], dtype=flow.int64) b = flow.tensor([3, 4], dtype=flow.int64) c = flow.tensor([2, 2], dtype=flow.float64) with test_case.assertRaises(RuntimeError) as context: flow.add(a, b, c) test_case.assertTrue( "expected all tenser to have same type, but found" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_arg_sort_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestArgSort(flow.unittest.TestCase): def test_direction_parameter_err(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.tensor([5, 10, 7, 8, 9, 1]) flow._C.arg_sort(x, direction="NONE") test_case.assertTrue( "expected the input direction parameter value is" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_array_functor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow.unittest import oneflow as flow class TestArrayError(flow.unittest.TestCase): def test_argmax_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) y = flow.argmax(x, dim=4) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_broadcast_like_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 0), dtype=flow.float32, requires_grad=True) like = flow.ones((2, 2, 2), dtype=flow.float32, requires_grad=True) y = flow.broadcast_like(x, like) test_case.assertTrue( "The expanded size of the tensor" in str(context.exception) ) def test_broadcast_like_numaxes_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((2, 2, 2), dtype=flow.float32, requires_grad=True) like = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) y = flow._C.broadcast_like(x, like) print(str(context.exception)) test_case.assertTrue("The number of sizes provided" in str(context.exception)) def test_concat_index_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) y = flow.concat([x1, x2], dim=3) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_concat_dim_equal_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 2, 2), dtype=flow.float32, requires_grad=True) y = flow.concat([x1, x2]) test_case.assertTrue( "Tensors must have same number of dimensions" in str(context.exception) ) def test_concat_match_size_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 3), dtype=flow.float32, requires_grad=True) y = flow.concat([x1, x2]) test_case.assertTrue( "Sizes of tensors must match except in dimension" in str(context.exception) ) def test_stack_index_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True) y = flow.concat([x1, x2], dim=4) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_stack_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) y = flow.stack([x1, x2]) test_case.assertTrue( "stack expects each tensor to be equal size" in str(context.exception) ) def test_expand_dim_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2), dtype=flow.float32, requires_grad=True) y = flow.expand(x1, x2.shape) test_case.assertTrue( "be greater or equal to the number of dimensions in the tensor" in str(context.exception) ) def test_expand_g_shape_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 4), dtype=flow.float32, requires_grad=True) y = flow.expand(x1, x2.shape) test_case.assertTrue( "The expanded size of the tensor" in str(context.exception) ) def test_expand_l_shape_runtime_error(test_case): with 
test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 0), dtype=flow.float32, requires_grad=True) y = flow.expand(x1, x2.shape) test_case.assertTrue( "The expanded size of the tensor" in str(context.exception) ) def test_squeeze_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((2, 1), dtype=flow.float32, requires_grad=True) y = flow.squeeze(x, dim=4) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_roll_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) y = flow.roll(x, [0, 1], [0]) test_case.assertTrue( "shifts and dimensions must align" in str(context.exception) ) def test_gather_index_type_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 2), dtype=flow.float32) y = flow.gather(x1, 1, x2) test_case.assertTrue( "gather(): Expected dtype int32 or int64 for index" in str(context.exception) ) def test_gather_dim_value_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 2), dtype=flow.int64) y = flow.gather(x1, 2, x2) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_gather_dim_equal_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2, 2, 2), dtype=flow.int64) y = flow.gather(x1, 1, x2) test_case.assertTrue( "Index tensor must have the same number of dimensions as input tensor" in str(context.exception) ) def test_gather_size_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((4, 2), dtype=flow.int64) y = flow.gather(x1, 1, x2) test_case.assertTrue( "Size does not match at dimension" in str(context.exception) ) def test_tensor_scatter_nd_update_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.arange(8, dtype=flow.float32, requires_grad=True) indices = flow.tensor([[1], [3], [5]]) updates = flow.tensor([-1, -2, -3], dtype=flow.float64, requires_grad=True) y = flow.tensor_scatter_nd_update(x, indices, updates) test_case.assertTrue( "The dtype of tensor and updates must be same." in str(context.exception) ) def test_view_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2, 3, 4), dtype=flow.float32, requires_grad=True).permute( 1, 0, 2 ) x2 = flow.ones((4, 6), dtype=flow.float32, requires_grad=True) y = flow.view(x1, x2.shape) test_case.assertTrue( "view size is not compatible with input tensor's size" in str(context.exception) ) def test_narrow_dim_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((3, 3), dtype=flow.float32, requires_grad=True) y = flow.narrow(x, 3, 0, 2) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_narrow_0_dim_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor(1, dtype=flow.float32, requires_grad=True) y = flow.narrow(x, 0, 0, 0) test_case.assertTrue( "narrow() cannot be applied to a 0-dim tensor." 
in str(context.exception) ) def test_narrow_start_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((3, 3), dtype=flow.float32, requires_grad=True) y = flow.narrow(x, 0, 4, 0) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_narrow_length_exceed_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((3, 3), dtype=flow.float32, requires_grad=True) y = flow.narrow(x, 0, 2, 2) test_case.assertTrue("exceeds dimension size" in str(context.exception)) def test_diagonal_index_error1(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.diagonal(x, 1, 3, 2) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_diagonal_index_error2(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.diagonal(x, 1, 2, 3) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_diagonal_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.diagonal(x, 1, 2, 2) test_case.assertTrue( "diagonal dimensions cannot be identical" in str(context.exception) ) def test_split_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.split(x, split_size_or_sections=0, dim=4) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_split_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.split(x, split_size_or_sections=-1) test_case.assertTrue( "split expects split_size be non-negative, but got split_size" in str(context.exception) ) def test_splitwithsize_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((5, 2), dtype=flow.float32, requires_grad=True) y = flow.split(x, [1, 3]) test_case.assertTrue( "split_with_sizes expects split_sizes to sum exactly to " in str(context.exception) ) def test_unbind_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.unbind(x, dim=4) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_chunk_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.chunk(x, chunks=2, dim=4) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_chunk_tensor_dim_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor(1, dtype=flow.float32, requires_grad=True) y = flow.chunk(x, chunks=2, dim=4) test_case.assertTrue( "chunk expects at least a 1-dimensional tensor" in str(context.exception) ) def test_chunk_value_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.chunk(x, chunks=-1, dim=4) test_case.assertTrue( "chunk expects `chunks` to be greater than 0, got" in str(context.exception) ) def test_meshgrid_tensors_scalar_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.tensor([], dtype=flow.float32, requires_grad=True) x2 = 
flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True) y = flow.meshgrid(x1, x2) test_case.assertTrue( "Expected scalar or 1D tensor in the tensor list" in str(context.exception) ) def test_meshgrid_tensors_size_runtime_error(test_case): with test_case.assertRaises(Exception) as context: y = flow.meshgrid([]) test_case.assertTrue( "meshgrid expects a non-empty TensorList" in str(context.exception) ) def test_meshgrid_tensors_dtype_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2), dtype=flow.float16, requires_grad=True) y = flow.meshgrid(x1, x2) test_case.assertTrue( "meshgrid expects all tensors to have the same dtype" in str(context.exception) ) def test_meshgrid_tensors_placement_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.tensor( [0.0, 1.0], dtype=flow.float32, placement=flow.placement("cpu", ranks=[0]), sbp=[flow.sbp.broadcast], ) x2 = flow.tensor( [0.0, 1.0], dtype=flow.float32, placement=flow.placement("cpu", ranks=[0]), sbp=[flow.sbp.broadcast], ).to_local() y = flow.meshgrid(x1, x2) test_case.assertTrue( "meshgrid expects all tensors are global tensor" in str(context.exception) ) def test_meshgrid_indexing_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x1 = flow.ones((2), dtype=flow.float32, requires_grad=True) x2 = flow.ones((2), dtype=flow.float32, requires_grad=True) y = flow.meshgrid(x1, x2, indexing="ab") test_case.assertTrue( "meshgrid: indexing must be one of" in str(context.exception) ) def test_index_select_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor( [[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True ) index = flow.tensor([0, 1], dtype=flow.float32) y = flow.index_select(x, 1, index) test_case.assertTrue( "Expected dtype int32 or int64 for index" in str(context.exception) ) def test_index_select_index_num_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor( [[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True ) index = flow.tensor([[0]], dtype=flow.int32) y = flow.index_select(x, 1, index) test_case.assertTrue( "Index is supposed to be a vector" in str(context.exception) ) def test_index_select_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor( [[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True ) index = flow.tensor([0], dtype=flow.int32) y = flow.index_select(x, 4, index) test_case.assertTrue("Dimension out of range" in str(context.exception)) def test_to_device_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor( [0.0, 1.0], dtype=flow.float32, placement=flow.placement("cpu", ranks=[0]), sbp=[flow.sbp.split(0)], ) x.to("cpp") test_case.assertTrue( "Only string device without device id" in str(context.exception) ) def test_to_other_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([0.0, 1.0], dtype=flow.float32) other = flow.tensor( [0.0, 1.0], dtype=flow.float32, placement=flow.placement("cpu", ranks=[0]), sbp=[flow.sbp.split(0)], ) x.to(other) test_case.assertTrue( "tensor.to(other) can only be called when tensor and other are local tensors" in str(context.exception) ) def test_in_top_k_num_equal_runtime_error(test_case): with test_case.assertRaises(Exception) as context: target = flow.tensor([[3, 1]], dtype=flow.int32) prediction = flow.tensor( 
[[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]], dtype=flow.float32 ) out = flow.in_top_k(target, prediction, k=1) test_case.assertTrue( "The num of targets must equal the num of predictions" in str(context.exception) ) def test_in_top_k_targets_dim_runtime_error(test_case): with test_case.assertRaises(Exception) as context: target = flow.tensor([[3, 1], [1, 3]], dtype=flow.int32) prediction = flow.tensor( [[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]], dtype=flow.float32 ) out = flow.in_top_k(target, prediction, k=1) test_case.assertTrue( "The dimension of targets must be 1" in str(context.exception) ) def test_in_top_k_pre_dim_runtime_error(test_case): with test_case.assertRaises(Exception) as context: target = flow.tensor([3, 1], dtype=flow.int32) prediction = flow.tensor( [[[0.0, 1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0, 0.0]]], dtype=flow.float32 ) out = flow.in_top_k(target, prediction, k=1) test_case.assertTrue( "The dimension of predictions must be 2" in str(context.exception) ) def test_repeat_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1], [1]], dtype=flow.int32) y = x.repeat(1) test_case.assertTrue( "Number of dimensions of repeat dims can not be" in str(context.exception) ) def test_tile_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1], [1]], dtype=flow.int32) y = x.tile(-1) test_case.assertTrue( "Trying to create tensor with negative dimension" in str(context.exception) ) def test_t_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[[1]]], dtype=flow.int32) y = x.t() test_case.assertTrue( "t() expects a tensor with <= 2 dimensions" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_autograd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import re import unittest import oneflow as flow import oneflow.unittest class TestAutograd(flow.unittest.TestCase): def test_non_requires_grad_tensor_backward(test_case): x = flow.ones(4, 4) with test_case.assertRaises(Exception) as context: x.backward() test_case.assertIsNotNone( re.search( r"\nRuntimeError: element \d of tensors does not require grad and does not have a grad_fn", str(context.exception), ) ) def test_allow_unused(test_case): with test_case.assertRaises(Exception) as context: x = flow.ones(4, 4).requires_grad_() y = flow.ones(4, 4).requires_grad_() z = x * x dx, dy = flow.autograd.grad(z, [x, y], flow.ones_like(z)) test_case.assertTrue( "allow_unused=True if this is the desired behavior" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_batch_gather_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np from numpy import array, dtype import oneflow as flow import oneflow.unittest class TestBatchGather(flow.unittest.TestCase): def test_input_tensor_dimesion_error_msg(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.tensor(1) indice = flow.tensor([1]) flow.batch_gather(x, indice) test_case.assertTrue( "The dimension of the input tensor should be greater than zero, but got" in str(context.exception) ) def test_indices_dimesion_error_msg(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.tensor([1]) indice = flow.tensor(1) flow.batch_gather(x, indice) test_case.assertTrue( "The dimension of the indices tensor should be greater than zero, but got" in str(context.exception) ) def test_legal_dimension_error_msg(test_case): with test_case.assertRaises(RuntimeError) as context: x = np.random.randn(1) x_tensor = flow.tensor(x) indice = flow.tensor([[1, 1], [1, 1], [1, 1]]) flow.batch_gather(x_tensor, indice) test_case.assertTrue( "The dimension of the input tensor should be greater than or equal to the dimension of the indices tensor" in str(context.exception) ) def test_indice_type_error_msg(test_case): with test_case.assertRaises(TypeError) as context: x = np.random.randn(2) x_tensor = flow.tensor(x) indice = flow.tensor([1, 1], dtype=flow.float64) flow.batch_gather(x_tensor, indice) test_case.assertTrue( "The dtype of the indices tensor must be int32 or int64" in str(context.exception) ) def test_tensor_shape_size_error_msg(test_case): with test_case.assertRaises(RuntimeError) as context: x = np.random.randn(4, 5) x_tensor = flow.tensor(x) indice = flow.tensor([[1, 2], [1, 2], [1, 2]]) out = flow.batch_gather(x_tensor, indice) test_case.assertTrue( "The size of indices tensor must match the size of input tensor" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_bias_add_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest class TestBiasAdd(flow.unittest.TestCase): def test_b_tensor_numaxes_err(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.tensor([[1, 1], [2, 2]]) y = flow.tensor([[2, 2], [1, 1]]) out = flow._C.bias_add(y, x, axis=0) test_case.assertTrue( "Bias tensor has to be a one-dimensional vector" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_binary_functor_exception.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import os import numpy as np import time import oneflow as flow import oneflow.unittest class TestBinaryFunctorError(flow.unittest.TestCase): def test_add_inplace_runtime_error(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) x.add_(y) test_case.assertTrue( "a leaf Tensor that requires grad is being used in an in-place operation" in str(context.exception) ) def test_add_broad_cast_runtime_error(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.ones((2, 3)) y = flow.ones((2, 4)) x.add_(y) test_case.assertTrue( "Tensor with shape (2,3) doesn't match the broadcast shape in an inplace operation" in str(context.exception) ) with test_case.assertRaises(RuntimeError) as context: x = flow.ones((3, 3)) y = flow.ones((2, 3, 3)) x.add_(y) test_case.assertTrue( "Can not expand origin shape (2,3,3) to (3,3)" in str(context.exception) ) with test_case.assertRaises(RuntimeError) as context: x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) x.mul_(y) test_case.assertTrue( "a leaf Tensor that requires grad is being used in an in-place operation" in str(context.exception) ) with test_case.assertRaises(RuntimeError) as context: x = flow.ones((2, 3)) y = flow.ones((2, 4)) x.mul_(y) test_case.assertTrue( "Tensor with shape (2,3) doesn't match the broadcast shape in an inplace operation" in str(context.exception) ) with test_case.assertRaises(RuntimeError) as context: x = flow.ones((3, 3)) y = flow.ones((2, 3, 3)) x.mul_(y) test_case.assertTrue( "Can not expand origin shape (2,3,3) to (3,3)" in str(context.exception) ) def test_div_inplace_runtime_error(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True) x.div_(y) test_case.assertTrue( "a leaf Tensor that requires grad is being used in an in-place operation" in str(context.exception) ) with test_case.assertRaises(RuntimeError) as context: x = flow.ones((2, 3)) y = flow.ones((2, 4)) x.div_(y) test_case.assertTrue( "Tensor with shape (2,3) doesn't match the 
broadcast shape in an inplace operation" in str(context.exception) ) with test_case.assertRaises(RuntimeError) as context: x = flow.ones((3, 3)) y = flow.ones((2, 3, 3)) x.div_(y) test_case.assertTrue( "Can not expand origin shape (2,3,3) to (3,3)" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_bmm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestBmm(flow.unittest.TestCase): def test_bmm_exception_dim_not_right(test_case): x = flow.tensor((2, 2)) with test_case.assertRaises(RuntimeError) as ctx: y = flow.bmm(x, x) test_case.assertTrue( "Expected 3-dimensional tensor, but got 1-dimensional tensor for argument #1" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_broadcast_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest binary_ops = [ flow.add, flow.sub, flow.mul, flow.div, flow.min, flow.minimum, flow.max, flow.maximum, flow.fmod, flow.pow, flow.eq, flow.ne, flow.gt, flow.ge, flow.lt, flow.le, flow.logical_and, flow.logical_or, flow.logical_xor, ] @flow.unittest.skip_unless_1n1d() class TestBroadcastOps(flow.unittest.TestCase): def test_broadcast_binary_ops(test_case): x = flow.Tensor(8, 10) y = flow.Tensor(8) for op in binary_ops: with test_case.assertRaises(RuntimeError) as ctx: op(x, y) test_case.assertTrue( "The size of tensor a (10) must match the size of tensor b (8) at non-singleton dimension 1" in str(ctx.exception) ) def test_broadcast_shapes(test_case): with test_case.assertRaises(RuntimeError) as ctx: y = flow.broadcast_shapes((2,), (3, 3), (1, 1, 1)) test_case.assertTrue( "input and other can't be broadcasted to a single shape." in str(ctx.exception) ) with test_case.assertRaises(RuntimeError) as ctx: y = flow.broadcast_shapes() test_case.assertTrue("shapes should not be empty." in str(ctx.exception)) def test_broadcast_tensors(test_case): with test_case.assertRaises(RuntimeError) as ctx: y, z = flow.broadcast_tensors(flow.ones(2, 3), flow.ones(4, 3)) test_case.assertTrue( "input and other can't be broadcasted to a single shape." 
            in str(ctx.exception)
        )
        with test_case.assertRaises(RuntimeError) as ctx:
            y = flow.broadcast_tensors()
        test_case.assertTrue("tensors should not be empty." in str(ctx.exception))

    def test_broadcast_to(test_case):
        # see flow.expand, because broadcast_to is an alias of flow.expand
        pass


if __name__ == "__main__":
    unittest.main()

================================================
FILE: python/oneflow/test/exceptions/test_chunk.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import oneflow as flow
import oneflow.unittest


@flow.unittest.skip_unless_1n1d()
class TestModule(flow.unittest.TestCase):
    def test_chunk_0_dim_input_exception(test_case):
        # torch exception and message:
        #
        # RuntimeError: chunk expects at least a 1-dimensional tensor.
        #
        x = flow.tensor(3.14)
        with test_case.assertRaises(RuntimeError) as ctx:
            y = flow.chunk(x, chunks=1, dim=0)
        test_case.assertTrue(
            "chunk expects at least a 1-dimensional tensor" in str(ctx.exception)
        )

    def test_chunk_0_chunks_param_exception(test_case):
        # torch exception and message:
        #
        # RuntimeError: chunk expects `chunks` to be greater than 0, got: 0
        #
        x = flow.tensor([[1, 2, 3], [4, 5, 6]])
        with test_case.assertRaises(RuntimeError) as ctx:
            y = flow.chunk(x, chunks=0, dim=0)
        test_case.assertTrue(
            "chunk expects `chunks` to be greater than 0, got: " in str(ctx.exception)
        )

    def test_chunk_dim_param_exception(test_case):
        # torch exception and message:
        #
        # IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)
        #
        x = flow.tensor([[1, 2, 3], [4, 5, 6]])
        with test_case.assertRaises(IndexError) as ctx:
            y = flow.chunk(x, chunks=2, dim=-3)
        test_case.assertTrue(
            "Dimension out of range (expected to be in range of [-2, 1], but got -3)"
            in str(ctx.exception)
        )


if __name__ == "__main__":
    unittest.main()

================================================
FILE: python/oneflow/test/exceptions/test_cosine_similarity.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestCosineSimilarity(flow.unittest.TestCase): def test_cosine_similarity_not_floating_type(test_case): x = flow.randn(2, 5).to(flow.int32) y = flow.randn(2, 5).to(flow.int32) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.cosine_similarity(x, y, dim=1) test_case.assertTrue( "expected common dtype to be floating point, yet common dtype is oneflow.int32" in str(ctx.exception) ) def test_cosine_similarity_broadcast(test_case): x = flow.randn(2, 5) y = flow.randn(2, 4) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.cosine_similarity(x, y, dim=1) test_case.assertTrue( "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_deform_conv2d_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestDeformConv(flow.unittest.TestCase): def test_deform_conv2d_invalid_input_sizes(test_case): input = flow.randn(2, 5, 1) weight = flow.randn(2, 5, 1, 1) offset = flow.randn(2, 5, 1, 1) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d(input, offset, weight) test_case.assertTrue( "The dimension of input tensor weight must be " in str(ctx.exception) ) def test_deform_conv2d_invalid_offset_sizes(test_case): input = flow.randn(2, 5, 1, 1) weight = flow.randn(2, 5, 1, 1) offset = flow.randn(2, 5, 1) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d(input, offset, weight) test_case.assertTrue( "The dimension of offset tensor weight must be " in str(ctx.exception) ) def test_deform_conv2d_invalid_weight_sizes(test_case): input = flow.randn(2, 5, 1, 1) weight = flow.randn(2, 5, 5) offset = flow.randn(2, 3, 1, 1) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d(input, offset, weight) test_case.assertTrue( "The dimension of weight tensor weight must be " in str(ctx.exception) ) def test_deform_conv2d_invalid_mask_sizes(test_case): input = flow.randn(2, 5, 1, 1) weight = flow.randn(2, 4, 1, 1) offset = flow.randn(2, 3, 1, 1) mask = flow.randn(2, 3, 1) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d(input, offset, weight, mask=mask) test_case.assertTrue( "The dimension of mask tensor weight must be" in str(ctx.exception) ) def test_deform_conv2d_invalid_dilation_parm(test_case): input = flow.randn(4, 3, 10, 10) weight = flow.randn(5, 3, 3, 3) offset = flow.randn(4, 18, 8, 8) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d( input, offset, weight, dilation=(-1, 0) ) test_case.assertTrue("The dilation must be greater than" in 
str(ctx.exception)) def test_deform_conv2d_invalid_pad_parm(test_case): input = flow.randn(4, 3, 10, 10) weight = flow.randn(5, 3, 3, 3) offset = flow.randn(4, 18, 8, 8) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d( input, offset, weight, padding=(-1, 0) ) test_case.assertTrue("The pad must be greater than" in str(ctx.exception)) def test_deform_conv2d_invalid_stride_parm(test_case): input = flow.randn(4, 3, 10, 10) weight = flow.randn(5, 3, 3, 3) offset = flow.randn(4, 18, 8, 8) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d( input, offset, weight, stride=(-1, 0) ) test_case.assertTrue("The stride must be greater than" in str(ctx.exception)) def test_deform_conv2d_invalid_offset_shape(test_case): input = flow.randn(4, 3, 10, 10) weight = flow.randn(5, 3, 3, 3) offset = flow.randn(4, 9, 8, 8) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d(input, offset, weight) test_case.assertTrue( "The shape of the offset tensor at dimension 1 is not valid" in str(ctx.exception) ) def test_deform_conv2d_invalid_batch_size(test_case): input = flow.randn(4, 3, 10, 10) weight = flow.randn(5, 3, 3, 3) offset = flow.randn(3, 18, 8, 8) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d(input, offset, weight) test_case.assertTrue("invalid batch size of offset" in str(ctx.exception)) def test_deform_conv2d_invalid_mask_shape(test_case): input = flow.randn(4, 3, 10, 10) weight = flow.randn(5, 3, 3, 3) offset = flow.randn(4, 18, 8, 8) mask = flow.randn(4, 1, 8, 8) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d(input, offset, weight, mask=mask) test_case.assertTrue("mask.shape[1] is not valid" in str(ctx.exception)) def test_deform_conv2d_invalid_output_size(test_case): input = flow.randn(4, 3, 10, 10) weight = flow.randn(5, 3, 3, 3) offset = flow.randn(4, 18, 8, 8) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d( input, offset, weight, dilation=(10, 10) ) test_case.assertTrue("Calculated output size too small" in str(ctx.exception)) def test_deform_conv2d_invalid_offset_output_dims(test_case): input = flow.randn(4, 3, 10, 10) weight = flow.randn(5, 3, 3, 3) offset = flow.randn(4, 18, 8, 8) with test_case.assertRaises(RuntimeError) as ctx: out = flow.nn.functional.deform_conv2d( input, offset, weight, dilation=(2, 2) ) test_case.assertTrue("invalid offset output dims" in str(ctx.exception)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_device.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import re import unittest import oneflow as flow import oneflow.unittest import oneflow.nn.functional as F @flow.unittest.skip_unless_1n1d() class TestDevice(flow.unittest.TestCase): def test_device_type(test_case): with test_case.assertRaises(RuntimeError) as exp: flow.device("xpu") test_case.assertTrue( re.match( "Expected one of (.*) device type at start of device string: xpu", str(exp.exception), ) is not None ) def test_device_index(test_case): # TODO(hjchen2): throw runtime error if cuda reports error # with test_case.assertRaises(RuntimeError) as exp: # device = flow.device("cuda:1000") # flow.Tensor(2, 3).to(device=device) # test_case.assertTrue("CUDA error: invalid device ordinal" in str(exp.exception)) pass if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_dot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestDot(flow.unittest.TestCase): def test_dot_shape_error_msg(test_case): with test_case.assertRaises(RuntimeError) as exp: a = flow.tensor([2, 3]) b = flow.tensor([2, 3, 4]) flow.dot(a, b) test_case.assertTrue("inconsistent tensor size" in str(exp.exception)) def test_dot_dims_error_msg(test_case): with test_case.assertRaises(RuntimeError) as exp: a = flow.tensor([[2, 3], [3, 4]]) flow.dot(a, a) test_case.assertTrue("1D tensors expected" in str(exp.exception)) def test_dot_dtype_error_msg(test_case): with test_case.assertRaises(RuntimeError) as exp: a = flow.tensor([2, 3], dtype=flow.int64) b = flow.tensor([2, 3], dtype=flow.float32) flow.dot(a, b) test_case.assertTrue( "expected both vectors to have same dtype" in str(exp.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_error_reported_in_thread.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import subprocess import sys import tempfile import os import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() def test_error_reported_in_thread(): for env_name in ["ONEFLOW_DEBUG", "ONEFLOW_PYTHON_STACK_GETTER"]: env = os.environ.copy() env[env_name] = "1" # Run a new process to capture the error output p = subprocess.run( [sys.executable, "throw_error.py"], capture_output=True, cwd=os.path.dirname(os.path.realpath(__file__)), env=env, ) assert p.returncode != 0 error_msg = p.stderr.decode("utf-8") print(error_msg) assert ( """File "throw_error.py", line 19, in g flow._C.throw_error(x) File "throw_error.py", line 23, in f g(x) File "throw_error.py", line 26, in f(x)""" in error_msg ) @flow.unittest.skip_unless_1n1d() def test_python_stack_getter_disabled(): # Run a new process to capture the error output p = subprocess.run( [sys.executable, "throw_error.py"], capture_output=True, cwd=os.path.dirname(os.path.realpath(__file__)), ) assert p.returncode != 0 error_msg = p.stderr.decode("utf-8") assert "No Python stack available." in error_msg assert "ONEFLOW_DEBUG" in error_msg assert "ONEFLOW_PYTHON_STACK_GETTER" in error_msg ================================================ FILE: python/oneflow/test/exceptions/test_gird_sample_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow.unittest import oneflow.nn import oneflow as flow from oneflow.test_utils.test_util import GenArgList import numpy as np from collections import OrderedDict arg_dict = OrderedDict() arg_dict["N"] = [3, 4, 5] arg_dict["C"] = [4, 5, 6] arg_dict["D_in"] = [8, 11, 13] arg_dict["H_in"] = [5, 6, 7] arg_dict["W_in"] = [7, 8, 9] arg_dict["D_out"] = [13, 15, 17] arg_dict["H_out"] = [9, 10, 11] arg_dict["W_out"] = [11, 12, 13] def _test_dimention_error_msg_impl(test_case, N, C, H_in, H_out): inputval = oneflow.ones(N, C, H_in,) grid = oneflow.ones(N, H_out, 1) with test_case.assertRaises(RuntimeError) as ctx: flow.nn.functional.grid_sample( inputval, grid, mode="bilinear", padding_mode="zeros" ) test_case.assertTrue("MUST be 4D or 5D input" in str(ctx.exception)) def _test_4d_gird_shape_error_msg_impl(test_case, N, C, H_in, W_in, H_out, W_out): inputval = oneflow.ones(N, C, H_in, W_in) grid = oneflow.ones(N, H_out, W_out, 1) with test_case.assertRaises(RuntimeError) as ctx: flow.nn.functional.grid_sample( inputval, grid, mode="bilinear", padding_mode="zeros" ) test_case.assertTrue("Grid shape MUST (N, H_out, W_out, 2)" in str(ctx.exception)) def _test_4d_grid_input_not_same_shape_error_msg_impl( test_case, N, C, H_in, W_in, H_out, W_out ): inputval = oneflow.ones(N, C, H_in, W_in) grid = oneflow.ones(N, H_out, W_out) with test_case.assertRaises(RuntimeError) as ctx: flow.nn.functional.grid_sample( inputval, grid, mode="bilinear", padding_mode="zeros" ) test_case.assertTrue( "Grid and input MUST have same dimention" in str(ctx.exception) ) def _test_5d_gird_shape_error_msg_impl( test_case, N, C, D_in, H_in, W_in, D_out, H_out, W_out ): inputval = oneflow.ones(N, C, D_in, H_in, W_in) grid = oneflow.ones(N, D_out, H_out, W_out, 2) with test_case.assertRaises(RuntimeError) as ctx: flow.nn.functional.grid_sample( inputval, grid, mode="bilinear", padding_mode="zeros" ) test_case.assertTrue("Grid shape MUST (N, H_out, W_out, 3)" in str(ctx.exception)) def _test_5d_grid_input_not_same_shape_error_msg_impl( test_case, N, C, D_in, H_in, W_in, D_out, H_out, W_out ): inputval = oneflow.ones(N, C, D_in, H_in, W_in) grid = oneflow.ones(N, D_out, H_out, W_out) with test_case.assertRaises(RuntimeError) as ctx: flow.nn.functional.grid_sample( inputval, grid, mode="bilinear", padding_mode="zeros" ) test_case.assertTrue( "Grid and input MUST have same dimention" in str(ctx.exception) ) class TestGridSample(flow.unittest.TestCase): def test_dimention_error_msg(test_case): for arg in GenArgList(arg_dict): _test_dimention_error_msg_impl(test_case, arg[0], arg[1], arg[3], arg[6]) def test_4d_gird_shape_error_msg(test_case): for arg in GenArgList(arg_dict): _test_4d_gird_shape_error_msg_impl( test_case, arg[0], arg[1], arg[3], arg[4], arg[6], arg[7] ) def test_4d_grid_input_not_same_shape_error_msg(test_case): for arg in GenArgList(arg_dict): _test_4d_grid_input_not_same_shape_error_msg_impl( test_case, arg[0], arg[1], arg[3], arg[4], arg[6], arg[7] ) def test_5d_gird_shape_error_msg(test_case): for arg in GenArgList(arg_dict): _test_5d_gird_shape_error_msg_impl(test_case, *arg[0:]) def test_5d_grid_input_not_same_shape_error_msg(test_case): for arg in GenArgList(arg_dict): _test_5d_grid_input_not_same_shape_error_msg_impl(test_case, *arg[0:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_global_branch_error_local_to_global_with_broadcast_sbp_1n2d.py ================================================ """ 
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import os import numpy as np import time import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n2d() class TestLocalToGlobalBranchError(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_global_branch_error_with_local_to_global(test_case): try: os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "2" data = flow.rand(2, dtype=flow.float32) placement = flow.placement(type="cpu", ranks=[0, 1]) sbp = flow.sbp.broadcast if flow.env.get_rank() == 0: global_data = data.to_global(placement=placement, sbp=sbp) else: time.sleep(2) except Exception as e: err_msg = "Maybe executing different code in different ranks, please check if the code is branched and operates on the global tensor" assert err_msg in str(e) finally: os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "300" if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_global_branch_error_local_to_global_with_broadcast_sbp_1n4d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np import time import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n4d() class TestLocalToGlobalBranchError(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_global_branch_error_with_local_to_global(test_case): try: os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "2" data = flow.rand(2, dtype=flow.float32) placement = flow.placement(type="cpu", ranks=[0, 1]) sbp = flow.sbp.broadcast if flow.env.get_rank() == 0: global_data = data.to_global(placement=placement, sbp=sbp) else: time.sleep(2) except Exception as e: err_msg = "Maybe executing different code in different ranks, please check if the code is branched and operates on the global tensor" assert err_msg in str(e) finally: os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "300" if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_global_branch_error_local_to_global_with_split_sbp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import os import numpy as np import time import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n2d() class TestLocalToGlobalBranchError(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_global_branch_error_with_local_to_global(test_case): try: os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "2" data = flow.rand(2, dtype=flow.float32) placement = flow.placement.all("cuda") sbp = flow.sbp.split(0) if flow.env.get_rank() == 0: global_data = data.to_global(placement=placement, sbp=sbp) else: time.sleep(2) except Exception as e: err_msg = "Maybe executing different code in different ranks, please check if the code is branched and operates on the global tensor" assert err_msg in str(e) finally: os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "300" if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_global_branch_error_with_global_mean.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np import time import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n2d() class TestGlobalMeanBranchError(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_global_branch_error_global_data_mean(test_case): try: os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "2" data = flow.rand(2, dtype=flow.float32) placement = flow.placement.all("cuda") sbp = flow.sbp.split(0) global_data = data.to_global(placement=placement, sbp=sbp) if flow.env.get_rank() == 0: print(data.mean()) print(global_data.mean()) else: time.sleep(2) except Exception as e: err_msg = "Maybe executing different code in different ranks, please check if the code is branched and operates on the global tensor" assert err_msg in str(e) finally: os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "300" if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_hann_window.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestHannWindow(flow.unittest.TestCase): def test_hann_window_dtype_not_support(test_case): window_length = 8 dtype = flow.int64 with test_case.assertRaises(RuntimeError) as ctx: x = flow.hann_window(window_length, dtype=dtype) test_case.assertTrue( "hann_window expects floating point dtypes, got: " in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_in_top_k.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow.unittest import oneflow as flow import numpy as np class TestInTopK(flow.unittest.TestCase): def test_in_top_k_error_msg(test_case): arr = np.array([1, 1]) targets = flow.Tensor(arr) targets = flow.cast(targets, flow.float) arr = np.array([[0.8, 0.6, 0.3], [0.1, 0.6, 0.4]]) predictions = flow.Tensor(arr) with test_case.assertRaises(RuntimeError) as ctx: flow.in_top_k(targets, predictions, 1) test_case.assertTrue( "targets data type must be index type" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_inv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestInv(flow.unittest.TestCase): def test_inv_exception_dim_short(test_case): x = flow.tensor((2, 2)) with test_case.assertRaises(RuntimeError) as ctx: y = flow.linalg.inv(x) test_case.assertTrue( "linalg.inv: The input tensor must be at least 2 dimensions." in str(ctx.exception) ) def test_inv_exception_not_square_matrix(test_case): x = flow.randn(2, 3, 2) with test_case.assertRaises(RuntimeError) as ctx: y = flow.linalg.inv(x) test_case.assertTrue( "RuntimeError: linalg.inv: A must be batches of square matrices, but they are 3 by 2 matrices" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_layernorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestLayerNormModule(flow.unittest.TestCase): def test_layernorm_exception_input_shape_not_match(test_case): x = flow.randn(2, 3) m = flow.nn.LayerNorm(2) with test_case.assertRaises(RuntimeError) as ctx: y = m(x) test_case.assertTrue( "Given normalized_shape=(2,), expected input with shape [*, 2,], but got input of size oneflow.Size([2, 3])" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_linalg.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest class TestLinalgCross(flow.unittest.TestCase): def test_cross_has_no_3_error(test_case): a = flow.randn(4, 2) b = flow.randn(4, 2) with test_case.assertRaises(RuntimeError) as ctx: flow.cross(a, b) test_case.assertTrue( "RuntimeError: no dimension of size 3 in input." in str(ctx.exception) ) def test_linalg_cross_has_no_3_error(test_case): a = flow.randn(4, 2) b = flow.randn(4, 2) with test_case.assertRaises(RuntimeError) as ctx: flow.linalg.cross(a, b) test_case.assertTrue( "RuntimeError: the size of the specified dimension(which is -1) is not 3." in str(ctx.exception) ) def test_linalg_cross_broadcast_error(test_case): a = flow.randn(4) b = flow.randn(4, 2) with test_case.assertRaises(RuntimeError) as ctx: flow.linalg.cross(a, b) test_case.assertTrue( "RuntimeError: input and other can't be broadcasted to a single shape. [input's shape: (1,4), other's shape: (4,2)]." in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_local_global_convert_error.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestModule(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_get_sbp_with_invalid_axis(test_case): with test_case.assertRaises(RuntimeError) as ctx: sbp = flow.sbp.split(-1) test_case.assertTrue( "Split axis must not be negative, but got -1!" in str(ctx.exception) ) with test_case.assertRaises(RuntimeError) as ctx: sbp = flow.sbp.split(7) test_case.assertTrue( "Expected split axis to be less than the supported maximum axis (6), but got 7!" in str(ctx.exception) ) @flow.unittest.skip_unless_1n1d() def test_local_to_global_with_invalid_split_axis(test_case): x = flow.tensor([1, 2, 3, 4]) with test_case.assertRaises(RuntimeError) as ctx: y = x.to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.split(1)) test_case.assertTrue( "Split axis out of range (expected to be in range of [0, 1), but got 1!" 
in str(ctx.exception) ) @flow.unittest.skip_unless_1n1d() def test_global_to_global_with_invalid_split_axis(test_case): x = flow.tensor( [1, 2, 3, 4], placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast, ) with test_case.assertRaises(RuntimeError) as ctx: y = x.to_global(sbp=flow.sbp.split(1)) test_case.assertTrue( "Split axis out of range (expected to be in range of [0, 1), but got 1!" in str(ctx.exception) ) @flow.unittest.skip_unless_1n1d() def test_call_to_local_for_local_tensor(test_case): x = flow.tensor([1, 2, 3, 4]) with test_case.assertRaises(RuntimeError) as ctx: y = x.to_local() test_case.assertTrue( "Expected global tensor for to_local but got local tensor!" in str(ctx.exception) ) @flow.unittest.skip_unless_1n2d() def test_local_to_global_with_invalid_size(test_case): if flow.env.get_rank() == 0: x = flow.Tensor(2, 4) # size(2, 4) else: x = flow.Tensor(4, 4) # size(4, 4) with test_case.assertRaises(RuntimeError) as ctx: y = x.to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.split(0)) test_case.assertTrue( "Sizes of tensors in dimension 0 must be same or match balanced split distribution. " "See https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/common/balanced_splitter.h " "for details of balanced split" in str(ctx.exception) ) with test_case.assertRaises(RuntimeError) as ctx: y = x.to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.split(1)) test_case.assertTrue( "Sizes of tensors must match except in dimension 1. Expected size 2 but got size 4 for tensor on rank 1!" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_median.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestMedian(flow.unittest.TestCase): def test_median_exception_dim_out_of_range(test_case): x = flow.tensor((2, 2)) with test_case.assertRaises(IndexError) as ctx: y = flow.median(x, 1) test_case.assertTrue( "Dimension out of range (expected to be in range of [-1, 0], but got 1)" in str(ctx.exception) ) def test_median_exception_reduce_0dim(test_case): x = flow.randn(2, 0, 2) with test_case.assertRaises(IndexError) as ctx: y = flow.median(x, 1) test_case.assertTrue( "IndexError: Expected reduction dim 1 to have non-zero size." in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_mm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest import oneflow.nn.functional as F @flow.unittest.skip_unless_1n1d() class TestMm(flow.unittest.TestCase): def test_mm_not_2dim(test_case): with test_case.assertRaises(Exception) as exp: mat1 = flow.randn(2, 3, 3) mat2 = flow.randn(3, 3) out = flow.mm(mat1, mat2) test_case.assertTrue("self must be a matrix" in str(exp.exception)) with test_case.assertRaises(Exception) as exp: mat1 = flow.randn(2, 3) mat2 = flow.randn(3, 3, 2) out = flow.mm(mat1, mat2) test_case.assertTrue("mat2 must be a matrix" in str(exp.exception)) def test_mm_dim_not_match(test_case): with test_case.assertRaises(Exception) as exp: mat1 = flow.randn(2, 3) mat2 = flow.randn(4, 3) out = flow.mm(mat1, mat2) test_case.assertTrue( "mat1 and mat2 shapes cannot be multiplied (2x3 and 4x3)" in str(exp.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_mode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestMode(flow.unittest.TestCase): def test_mode_exception_dim_out_of_range(test_case): x = flow.tensor((2, 2)) with test_case.assertRaises(IndexError) as ctx: y = flow.mode(x, 1) test_case.assertTrue( "Dimension out of range (expected to be in range of [-1, 0], but got 1)" in str(ctx.exception) ) def test_mode_exception_reduce_0dim(test_case): x = flow.randn(2, 0, 2) with test_case.assertRaises(IndexError) as ctx: y = flow.mode(x, 1) test_case.assertTrue( "IndexError: Expected reduction dim 1 to have non-zero size." in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_multi_input_with_diff_device_or_placement.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestModule(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_multi_input_with_diff_device(test_case): # torch exception and messge: # # RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! # x = flow.tensor([1, 2, 3, 4]) y = flow.tensor([2, 4, 6, 8], device="cuda") with test_case.assertRaises(RuntimeError) as ctx: z = flow.add(x, y) test_case.assertTrue( "Expected all tensors to be on the same device, but found at least two devices" in str(ctx.exception) ) @flow.unittest.skip_unless_1n2d() def test_multi_input_with_diff_placement(test_case): x = flow.tensor( [1, 2, 3, 4], placement=flow.placement("cuda", [0]), sbp=flow.sbp.broadcast ) y = flow.tensor( [2, 4, 6, 8], placement=flow.placement("cuda", [1]), sbp=flow.sbp.broadcast ) with test_case.assertRaises(RuntimeError) as ctx: z = flow.add(x, y) test_case.assertTrue( "Expected all tensors to be on the same placement, but found at least two placements" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_mv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestMv(flow.unittest.TestCase): def test_mv_not_matrix(test_case): with test_case.assertRaises(Exception) as exp: mat = flow.randn(2, 3, 3) vec = flow.randn(3) out = flow.mv(mat, vec) test_case.assertTrue( "vector + matrix @ vector expected, got 1, 3, 1" in str(exp.exception) ) def test_mv_not_vector(test_case): with test_case.assertRaises(Exception) as exp: mat = flow.randn(2, 3) vec = flow.randn(3, 1) out = flow.mv(mat, vec) test_case.assertTrue( "vector + matrix @ vector expected, got 1, 2, 2" in str(exp.exception) ) def test_mv_size_mismatch(test_case): with test_case.assertRaises(Exception) as exp: mat = flow.randn(2, 3) vec = flow.randn(4) out = flow.mv(mat, vec) test_case.assertTrue("size mismatch" in str(exp.exception)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_nn_functor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import re import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * class TestBiasAddError(flow.unittest.TestCase): def test_bias_add_dimension_match_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) bias = flow.ones((5,), dtype=flow.float32) out = flow._C.bias_add(x, bias, axis=1) test_case.assertTrue( "The size of tensor x (4,4) must match the size of tensor b (5,) at dimension 1" in str(ctx.exception) ) def test_bias_add_index_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) bias = flow.ones((5,), dtype=flow.float32) out = flow._C.bias_add(x, bias, axis=3) test_case.assertTrue( "Dimension out of range (expected to be in range of [-2,1], but got 3)" in str(ctx.exception) ) class TestCrossEntropyError(flow.unittest.TestCase): def test_cross_entropy_reduction_type_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) target = flow.ones((4, 4), dtype=flow.float32) out = flow._C.cross_entropy(x, target, None, 0, "just_test") test_case.assertTrue( "Reduction should be none, sum or mean." in str(ctx.exception) ) class TestCTCLossError(flow.unittest.TestCase): def test_ctcloss_reduction_type_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((5, 2, 3), dtype=flow.float32) targets = flow.tensor([[1, 2, 2], [1, 2, 2]], dtype=flow.int32) input_lengths = flow.tensor([5, 5], dtype=flow.int32) target_lengths = flow.tensor([3, 3], dtype=flow.int32) max_target_length = 0 if targets.ndim == 1: max_target_length = target_lengths.max().item() elif targets.ndim == 2: max_target_length = targets.shape[1] loss = flow._C.ctc_loss( x, targets, input_lengths, target_lengths, max_target_length, blank=0, zero_infinity=False, reduction="just_test", ) test_case.assertTrue( "Reduction should be none, sum or mean." in str(ctx.exception) ) class TestPadError(flow.unittest.TestCase): def test_pad_size_attribute_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((1, 1), dtype=flow.float32) out = flow._C.pad(x, (1, 1, 1, 1, 1)) test_case.assertTrue( "Pad size should less than or equal to input axes * 2." in str(ctx.exception) ) def test_pad_size_mod2_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((1, 1), dtype=flow.float32) out = flow._C.pad(x, (1, 1, 1,)) test_case.assertTrue( "Length of pad must be even but instead it equals 3" in str(ctx.exception) ) def test_reflect_pad_size_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((1, 1, 2, 2), dtype=flow.float32) out = flow._C.pad(x, (4, 4, 4, 4), mode="reflect") test_case.assertTrue( "Padding size should be less than the corresponding input dimension, but got:" in str(ctx.exception) ) def test_pad_mode_error(test_case): with test_case.assertRaises(NotImplementedError) as ctx: x = flow.ones((1, 1, 2, 2), dtype=flow.float32) out = flow._C.pad(x, (4, 4, 4, 4), mode="test") test_case.assertTrue( "Pad mode is test, but only constant, reflect and replicate are valid." 
in str(ctx.exception) ) class TestFusedMLPError(flow.unittest.TestCase): def test_fuse_mlp_weight_size_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) bias = flow.ones((4,), dtype=flow.float32) out = flow._C.fused_mlp(x, [], [bias], False) test_case.assertTrue( "The number of weights should be greater equal than 1" in str(ctx.exception) ) def test_fuse_mlp_weight_bias_size_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) w1 = flow.ones((4, 4), dtype=flow.float32) w2 = flow.ones((4, 4), dtype=flow.float32) bias1 = flow.ones((4,), dtype=flow.float32) out = flow._C.fused_mlp(x, [w1, w2], [bias1], False) test_case.assertTrue( "The number of weights should be equal to biases" in str(ctx.exception) ) def test_fuse_mlp_weight_numaxes_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) w1 = flow.ones((4,), dtype=flow.float32) bias1 = flow.ones((4,), dtype=flow.float32) out = flow._C.fused_mlp(x, [w1,], [bias1,], False) test_case.assertTrue("Weight's dim size should == 2" in str(ctx.exception)) def test_fuse_mlp_bias_numaxes_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) w1 = flow.ones((4, 4), dtype=flow.float32) bias1 = flow.ones((4, 4), dtype=flow.float32) out = flow._C.fused_mlp(x, [w1,], [bias1,], False) test_case.assertTrue("Bias's dim size should == 1" in str(ctx.exception)) def test_fuse_mlp_bias_first_dim_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) w1 = flow.ones((6, 4), dtype=flow.float32) bias1 = flow.ones((5), dtype=flow.float32) out = flow._C.fused_mlp(x, [w1,], [bias1,], False) test_case.assertTrue( "Bias's dim is not equal to weight's first dim." in str(ctx.exception) ) def test_fuse_mlp_weight_second_dim_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((2, 4), dtype=flow.float32) w1 = flow.ones((3, 6), dtype=flow.float32) bias1 = flow.ones((3), dtype=flow.float32) out = flow._C.fused_mlp(x, [w1,], [bias1,], False) test_case.assertTrue( "weight's second dim should be equal to input's second dim." in str(ctx.exception) ) class TestL2NormalizeError(flow.unittest.TestCase): def test_l2normalize_axis_error1(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((3, 3), dtype=flow.float32) out = flow._C.normalize(x, dim=3, use_l2_norm_kernel=True) test_case.assertTrue("Axis should < 2 but axis is 3 now." in str(ctx.exception)) def test_l2normalize_axis_error2(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((3, 3), dtype=flow.float32) out = flow._C.normalize(x, dim=-3, use_l2_norm_kernel=True) test_case.assertTrue( "Axis should >=0 but axis is -1 now." in str(ctx.exception) ) class TestLossBaseFunctorError(flow.unittest.TestCase): def test_loss_base_reduction_type_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) target = flow.ones((4, 4), dtype=flow.float32) out = flow._C.mse_loss(x, target, "just_test") test_case.assertTrue( "Reduction should be none, sum or mean." 
in str(ctx.exception) ) class TestMatmulError(flow.unittest.TestCase): def test_matmul_dimension_error1(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((), dtype=flow.float32) w = flow.ones((4, 4), dtype=flow.float32) out = flow._C.matmul(x, w, False, False, 1.0) test_case.assertTrue("Tensor a's dim should >= 1" in str(ctx.exception)) def test_matmul_dimension_error2(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((4, 4), dtype=flow.float32) w = flow.ones((), dtype=flow.float32) out = flow._C.matmul(x, w, False, False, 1.0) test_case.assertTrue("Tensor b's dim should >= 1" in str(ctx.exception)) class TestPixelShuffleError(flow.unittest.TestCase): def test_pixel_shuffle_4D_input_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((1, 8, 4, 4, 1), dtype=flow.float32) out = flow._C.pixel_shuffle(x, 2, 2) test_case.assertTrue("Only Accept 4D Tensor" in str(ctx.exception)) def test_pixel_shuffle_channel_divisble_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((1, 8, 4, 4), dtype=flow.float32) out = flow._C.pixel_shuffle(x, 2, 3) test_case.assertTrue( "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)" in str(ctx.exception) ) class TestTripletMarginLossError(flow.unittest.TestCase): def test_triplet_margin_loss_reduce_type_error(test_case): with test_case.assertRaises(Exception) as ctx: anchor = flow.ones((3, 3), dtype=flow.float32) positive = flow.ones((3, 3), dtype=flow.float32) negative = flow.ones((3, 3), dtype=flow.float32) triplet_loss = flow._C.triplet_margin_loss( anchor, positive, negative, margin=0.001, p=2, eps=1e-5, swap=False, reduction="just_test", ) test_case.assertTrue( "Reduction should be none, sum or mean." in str(ctx.exception) ) class TestNormalError(flow.unittest.TestCase): def test_normal_data_type_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow._C.normal(mean=0.0, std=1.0, size=(3, 3), dtype=flow.int32) test_case.assertTrue( "Only support float and double in normal()." in str(ctx.exception) ) def test_normal_out_tensor_data_type_error(test_case): with test_case.assertRaises(RuntimeError) as ctx: out = flow.zeros((3, 3), dtype=flow.float64) x = flow._C.normal( mean=0.0, std=1.0, size=(3, 3), dtype=flow.float32, out=out ) test_case.assertTrue( "data type oneflow.float32 does not match data type of out parameter oneflow.float64" in str(ctx.exception) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_normal_out_tensor_device_type_error(test_case): with test_case.assertRaises(RuntimeError) as ctx: out = flow.zeros((3, 3), dtype=flow.float32, device="cuda") x = flow._C.normal( mean=0.0, std=1.0, size=(3, 3), dtype=flow.float32, out=out, device="cpu", ) test_case.assertTrue( "does not match device type of out parameter" in str(ctx.exception) ) class TestNormalizationError(flow.unittest.TestCase): def test_normalization_moving_mean_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((1, 4, 2, 2), dtype=flow.float32) moving_mean = flow.ones((4,), dtype=flow.float32) weight = flow.ones((4,), dtype=flow.float32) bias = flow.ones((4,), dtype=flow.float32) out = flow._C.normalization( x, moving_mean, None, weight, bias, 1, 1e-5, 0.9, False ) test_case.assertTrue( "Both moving_mean and moving_variance should be None or Tensor." 
in str(ctx.exception) ) def test_normalization_x_input_axes_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((1,), dtype=flow.float32) weight = flow.ones((4,), dtype=flow.float32) bias = flow.ones((4,), dtype=flow.float32) out = flow._C.normalization( x, None, None, weight, bias, 1, 1e-5, 0.9, False ) test_case.assertTrue( "NumAxes of x should be greater or equal than 2." in str(ctx.exception) ) def test_normalization_eval_need_moving_statistic_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((1, 2,), dtype=flow.float32) weight = flow.ones((2,), dtype=flow.float32) bias = flow.ones((2,), dtype=flow.float32) out = flow._C.normalization( x, None, None, weight, bias, 1, 1e-5, 0.9, False ) test_case.assertTrue( "Must have moving_mean and moving_variance in eval mode." in str(ctx.exception) ) class TestOnehotError(flow.unittest.TestCase): def test_onehot_error(test_case): with test_case.assertRaises(Exception) as ctx: x = flow.ones((3, 3), dtype=flow.float32) out = flow._C.one_hot(x, 3, 0.9, 0) test_case.assertTrue( "one_hot is only applicable to index tensor." in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_optim_add_param_group.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestSgdAddParamGroup(flow.unittest.TestCase): def test_sgd_add_param_group_not_unique(test_case): with test_case.assertRaises(Exception) as exp: w1 = flow.ones(3, 3) w1.requires_grad = True w2 = flow.ones(3, 3) w2.requires_grad = True o = flow.optim.SGD([w1]) o.add_param_group({"params": w2}) o.add_param_group({"params": w2}) print(str(exp.exception)) test_case.assertTrue( "some parameters appear in more than one parameter group" in str(exp.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_pad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
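# A minimal usage sketch (not part of the original test file) for the
# duplicate-parameter check above: each tensor may belong to at most one
# parameter group, so adding a fresh tensor once is accepted.
import oneflow as flow

w1 = flow.ones(3, 3)
w1.requires_grad = True
w2 = flow.ones(3, 3)
w2.requires_grad = True
opt = flow.optim.SGD([w1], lr=0.1)
opt.add_param_group({"params": w2})  # accepted: w2 is not in any existing group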
""" import unittest import oneflow as flow import oneflow.unittest import oneflow.nn.functional as F import torch @flow.unittest.skip_unless_1n1d() class TestPad(flow.unittest.TestCase): def test_torch_type(test_case): with test_case.assertRaises(TypeError) as exp: F.pad(torch.randn(2, 2)) test_case.assertTrue( "pad() missing 1 required positional argument: 'pad'" in str(exp.exception) ) def test_numpy_type(test_case): import numpy as np with test_case.assertRaises(TypeError) as exp: F.pad(np.random.randn(2, 2)) test_case.assertTrue( "pad() missing 1 required positional argument: 'pad'" in str(exp.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_placement.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestPlacement(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_inconsistent_placement(test_case): x = flow.randn(2, 3) if flow.env.get_rank() == 0: placement = flow.placement("cpu", [0, 1]) else: placement = flow.placement("cpu", [0]) sbp = flow.sbp.split(1) with test_case.assertRaises(RuntimeError) as ctx: x_global = x.to_global(placement=placement, sbp=sbp) test_case.assertTrue("Inconsistent parallel description" in str(ctx.exception)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_randperm_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestRandpermOp(flow.unittest.TestCase): def test_randperm_n_value_err_mes(test_case): with test_case.assertRaises(RuntimeError) as ctx: a = flow.randperm(-1) test_case.assertTrue( "Trying to create tensor with negative dimension" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_reduce_like_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestReduceSumLikeOps(flow.unittest.TestCase): def test_reduce_sum_like_empty_axis_case_err(test_case): a = flow.tensor([1, 1]) b = flow.tensor([1, 1, 1]) with test_case.assertRaises(RuntimeError) as ctx: flow._C.reduce_sum_like(a, b, []) test_case.assertTrue( "The shape of the x tensor must be consistent to the shape of the like tensor" in str(ctx.exception) ) def test_reduce_sum_like_type_err(test_case): a = flow.tensor([1, 1], dtype=flow.int64) b = flow.tensor([1, 1], dtype=flow.float64) with test_case.assertRaises(TypeError) as ctx: flow._C.reduce_sum_like(a, b, [1]) test_case.assertTrue( "Tensors x and like must have the same type" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_reduce_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestReduceOps(flow.unittest.TestCase): def test_exception_dim_out_of_int_range(test_case): x = flow.randn(2, 3, 4) with test_case.assertRaises(IndexError) as exp: flow.sum(x, 3) test_case.assertTrue("Dimension out of range" in str(exp.exception)) def test_exception_dim_out_of_list_range(test_case): x = flow.randn(2, 3, 4) with test_case.assertRaises(IndexError) as exp: flow.sum(x, [-4]) test_case.assertTrue("Dimension out of range" in str(exp.exception)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_repeat_interleave.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
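# A minimal sketch of the dim range checked above: for a 3-dim tensor the valid
# reduction dims are [-3, 2], and negative dims count from the end.
import oneflow as flow

x = flow.randn(2, 3, 4)
s = flow.sum(x, 2)     # reduce the last dim, result shape (2, 3)
t = flow.sum(x, [-1])  # the same reduction written with a negative dim
assert s.shape == t.shape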
""" import unittest import oneflow as flow import oneflow.unittest import oneflow.nn.functional as F import torch @flow.unittest.skip_unless_1n1d() class TestRepeatInterleave(flow.unittest.TestCase): def test_repeat_interleave_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2], [3, 4]]) y = flow.repeat_interleave(x, 3, dim=4) test_case.assertTrue( "Dimension out of range (expected to be in range of [-2, 1], but got 4)" in str(context.exception) ) def test_repeat_interleave_tensor_shape_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2], [3, 4]]) r = flow.tensor([[1, 2], [3, 4]]) y = flow.repeat_interleave(x, r, dim=1) test_case.assertTrue( "repeat_interleave only accept 1D vector as repeat" in str(context.exception) ) def test_repeat_interleave_dtype_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2], [3, 4]]) r = flow.tensor([1.0, 2.0]) y = flow.repeat_interleave(x, r, dim=1) test_case.assertTrue("repeats has to be Long tensor" in str(context.exception)) def test_repeat_interleave_negative_tensor_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2], [3, 4]]) r = flow.tensor([1, -2]) y = flow.repeat_interleave(x, r, dim=1) test_case.assertTrue("repeats can not be negative" in str(context.exception)) def test_repeat_interleave_negative_tensor_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2], [3, 4]]) r = flow.tensor([1, 2]) y = flow.repeat_interleave(x, r, dim=2) test_case.assertTrue( "Dimension out of range (expected to be in range of [-2, 1], but got 2)" in str(context.exception) ) def test_repeat_interleave_dim_not_match_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2], [3, 4]]) r = flow.tensor([1]) y = flow.repeat_interleave(x, r, dim=1) test_case.assertTrue( "repeats must have the same size as input along dim" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_reshape.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_reshape_exception_invalid_dim(test_case): # torch exception and messge: # # RuntimeError: Invalid shape dimension -2 # x = flow.tensor((2, 2)) with test_case.assertRaises(RuntimeError) as ctx: y = x.reshape((-2, 4)) test_case.assertTrue("Invalid shape dimension -2" in str(ctx.exception)) def test_reshape_exception_invalid_size(test_case): # torch exception and messge: # # RuntimeError: shape '[2, 3, 5]' is invalid for input of size 24 # x = flow.arange(24).reshape(2, 3, 4) with test_case.assertRaises(RuntimeError) as ctx: y = x.reshape((2, 3, 5)) test_case.assertTrue("is invalid for input of size 24" in str(ctx.exception)) def test_reshape_exception_only_one_dim_infered(test_case): # torch exception and messge: # # RuntimeError: only one dimension can be inferred # x = flow.tensor((2, 2)) with test_case.assertRaises(RuntimeError) as ctx: y = x.reshape((-1, -1)) test_case.assertTrue("only one dimension can be inferred" in str(ctx.exception)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_reshape_like_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestReshapeLikeOp(flow.unittest.TestCase): def test_reshape_like_size_match_err(test_case): a = flow.tensor([1, 1]) b = flow.tensor([[1, 1, 1], [1, 1, 1]]) with test_case.assertRaises(RuntimeError) as ctx: flow._C.reshape_like(a, b) test_case.assertTrue( "The element number of the in tensor must be equal to the element number of the like tensor" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_roi_align_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest class TestRoiAlignOp(flow.unittest.TestCase): def test_rol_align_x_tensor_dimension_err(test_case): x = flow.randn(2, 3, 64) rois = flow.randn(2, 3, 64, 64) with test_case.assertRaises(RuntimeError) as ctx: flow.roi_align(x, rois, 2.0, 14, 14, 2, True) test_case.assertTrue( "The dimension of x tensor must be equal to 4, but got" in str(ctx.exception) ) def test_rol_align_rois_tensor_dimension_err(test_case): x = flow.randn(2, 3, 64, 5) rois = flow.randn(2, 3, 64, 64) with test_case.assertRaises(RuntimeError) as ctx: flow.roi_align(x, rois, 2.0, 14, 14, 2, True) test_case.assertTrue( "The dimension of rois tensor must be equal to 2, but got" in str(ctx.exception) ) def test_rol_align_rois_tensor_size_err(test_case): x = flow.randn(2, 3, 64, 5) rois = flow.randn(2, 3) with test_case.assertRaises(RuntimeError) as ctx: flow.roi_align(x, rois, 2.0, 14, 14, 2, True) test_case.assertTrue( "The size of rois tensor must be equal to 5 at dimension 1, but got" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_save_load.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import tempfile import oneflow as flow import oneflow.unittest import torch @flow.unittest.skip_unless_1n1d() class TestSaveLoad(flow.unittest.TestCase): def test_support_pytorch_with_global_src_rank(test_case): conv_torch = torch.nn.Conv2d(3, 3, 3) conv_flow = flow.nn.Conv2d(3, 3, 3) with tempfile.NamedTemporaryFile() as f: torch.save(conv_torch.state_dict(), f.name) with test_case.assertRaises(ValueError) as ctx: conv_flow.load_state_dict( flow.load(f.name, support_pytorch_format=False) ) test_case.assertTrue("Cannot load file" in str(ctx.exception)) def test_load_invalid_file(test_case): f = tempfile.NamedTemporaryFile() f.write(b"invalid file") f.flush() with test_case.assertRaises(ValueError) as ctx: flow.load(f.name) test_case.assertTrue("Cannot load file" in str(ctx.exception)) f.close() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_saved_tensor_hooks.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestSavedTensorHooks(flow.unittest.TestCase): def test_unpack_returns_non_tensor(test_case): x = flow.ones(1, 2, 3).to("cuda").requires_grad_() y = flow.zeros(1, 2, 3).to("cuda").requires_grad_() def pack(x): return x def unpack(x): return 0 with flow.autograd.graph.saved_tensors_hooks(pack, unpack): z = x * y with test_case.assertRaises(Exception) as exp: z.sum().backward() test_case.assertTrue( "unpack_hook should return a Tensor, but got `` instead" in str(exp.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_slice_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest import numpy as np class TestSlice(flow.unittest.TestCase): def test_slice_update_start_list_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([[1], [2]]) value = flow.tensor([[1], [2]]) start = [-1] stop = [1] step = [1] flow._C.slice_update(ref, value, start, stop, step) test_case.assertTrue( "The start list elements must be greater than or equal to 0, but got" in str(context.exception) ) def test_slice_update_stop_list_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([[1], [2]]) value = flow.tensor([[1], [2]]) start = [1] stop = [-1] step = [1] flow._C.slice_update(ref, value, start, stop, step) test_case.assertTrue( "The stop list elements must be greater than or equal to 0" in str(context.exception) ) def test_slice_update_step_list_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([[1], [2]]) value = flow.tensor([[1], [2]]) start = [1] stop = [1] step = [0] flow._C.slice_update(ref, value, start, stop, step) test_case.assertTrue( "The step list elements must be greater than 0, but got" in str(context.exception) ) def test_slice_update_start_and_stop_compare_value_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([[1], [2]]) value = flow.tensor([[1], [2]]) start = [2] stop = [1] step = [1] flow._C.slice_update(ref, value, start, stop, step) test_case.assertTrue( "The element in start list must be less than or equal to the element in stop list at index" in str(context.exception) ) def test_slice_update_turple_size_match_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([1, 2]) value = flow.tensor([1, 2]) start = [1, 2, 3] stop = [1, 2, 3] step = [1, 2, 3] flow._C.slice_update(ref, value, start, stop, step) test_case.assertTrue( "The size of slice tuple must be equal to the size of value tensor at dimension" in str(context.exception) ) def test_slice_update_type_err(test_case): with 
test_case.assertRaises(TypeError) as context: ref = flow.tensor([1], dtype=flow.int64) value = flow.tensor([0.545], dtype=flow.float32) start = [1] stop = [2] step = [1] flow._C.slice_update(ref, value, start, stop, step) test_case.assertTrue( "Tensors ref and value must have same type" in str(context.exception) ) def test_slice_start_list_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([1]) start = [-1] stop = [1] step = [1] flow._C.slice(ref, start, stop, step) test_case.assertTrue( "The start list elements must be greater than or equal to 0, but got " in str(context.exception) ) def test_slice_stop_list_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([1]) start = [1] stop = [-1] step = [1] flow._C.slice(ref, start, stop, step) test_case.assertTrue( "The stop list elements must be greater than or equal to 0, but got " in str(context.exception) ) def test_slice_step_list_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([1]) start = [1] stop = [1] step = [-1] flow._C.slice(ref, start, stop, step) test_case.assertTrue( "The step list elements must be greater than 0, but got " in str(context.exception) ) def test_slice_start_and_stop_compare_value_err(test_case): with test_case.assertRaises(RuntimeError) as context: ref = flow.tensor([1]) start = [2] stop = [1] step = [1] flow._C.slice(ref, start, stop, step) test_case.assertTrue( "The element in start list must be less than or equal to the element in stop list at index " in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_smooth_l1_loss_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest class TestSmoothL1LossError(flow.unittest.TestCase): def test_smooth_l1_loss_shape_err(test_case): with test_case.assertRaises(RuntimeError) as context: input = flow.randn(10) target = flow.randn(11) reduction = "mean" beta = 1.0 flow._C.smooth_l1_loss(input, target, beta, reduction) test_case.assertTrue("must match the size of target" in str(context.exception)) def test_smooth_l1_loss_beta_err(test_case): with test_case.assertRaises(RuntimeError) as context: input = flow.randn(10) target = flow.randn(10) reduction = "mean" beta = -1.0 flow._C.smooth_l1_loss(input, target, beta, reduction) test_case.assertTrue( "beta must be greater than or equal to 0" in str(context.exception) ) def test_smooth_l1_loss_dtype_err(test_case): with test_case.assertRaises(TypeError) as context: input = flow.randn(10, dtype=flow.float32) target = flow.randn(10, dtype=flow.float64) reduction = "mean" beta = 1.0 flow._C.smooth_l1_loss(input, target, beta, reduction) test_case.assertTrue( "input and target are expected to have the same dtype" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_softmax_cross_entropy_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest class TestSoftmaxCrossEntropyError(flow.unittest.TestCase): def test_softmax_cross_entropy_prediction_numaxes_err(test_case): with test_case.assertRaises(RuntimeError) as context: prediction = flow.randn(10) label = flow.randn(1, 10) flow._C.softmax_cross_entropy(prediction, label) test_case.assertTrue( "The dimension of prediction must be greater than or equal to 2, but found" in str(context.exception) ) def test_softmax_cross_entropy_prediction_shape_err(test_case): with test_case.assertRaises(RuntimeError) as context: prediction = flow.randn(1, 10) label = flow.randn(1, 11) flow._C.softmax_cross_entropy(prediction, label) test_case.assertTrue( "must match the size of prediction" in str(context.exception) ) def test_softmax_cross_entropy_dtype_err(test_case): with test_case.assertRaises(TypeError) as context: prediction = flow.randn(1, 10, dtype=flow.float32) label = flow.randn(1, 10, dtype=flow.float64) flow._C.softmax_cross_entropy(prediction, label) test_case.assertTrue( "label and prediction are expected to have the same dtype, but found" in str(context.exception) ) def test_softmax_cross_entropy_grad_prob_numaxes_err(test_case): with test_case.assertRaises(RuntimeError) as context: dy = flow.randn(10, 5) label = flow.randn(10, 10, 5) prob = flow.randn(10) flow._C.softmax_cross_entropy_grad(dy, label, prob) test_case.assertTrue( "The dimension of prob must be greater than or equal to 2, but found " in str(context.exception) ) def test_softmax_cross_entropy_grad_dy_numaxes_err(test_case): with test_case.assertRaises(RuntimeError) as context: dy = flow.randn(10, 10, 5) label = flow.randn(10, 10, 5) prob = flow.randn(10, 10, 5) flow._C.softmax_cross_entropy_grad(dy, label, prob) test_case.assertTrue( "The dimension of dy is expected to be less than that of prob by 1, but found" in str(context.exception) ) def test_softmax_cross_entropy_grad_dy_i_shape_err(test_case): with test_case.assertRaises(RuntimeError) as context: dy = flow.randn(10, 8) label = flow.randn(10, 10, 5) prob = flow.randn(10, 10, 5) flow._C.softmax_cross_entropy_grad(dy, label, prob) test_case.assertTrue("must match the size of label" in str(context.exception)) def test_softmax_cross_entropy_grad_prob_shape_err(test_case): with test_case.assertRaises(RuntimeError) as context: dy = flow.randn(10, 10) label = flow.randn(10, 10, 5) prob = flow.randn(10, 10, 6) flow._C.softmax_cross_entropy_grad(dy, label, prob) test_case.assertTrue("must match the size of prob" in str(context.exception)) def test_softmax_cross_entropy_grad_label_dtype_err(test_case): with test_case.assertRaises(TypeError) as context: dy = flow.randn(10, 10, dtype=flow.float64) label = flow.randn(10, 10, 5, dtype=flow.float32) prob = flow.randn(10, 10, 5, dtype=flow.float64) flow._C.softmax_cross_entropy_grad(dy, label, prob) test_case.assertTrue( "label and prob are expected to have the same dtype, but found" in str(context.exception) ) def test_softmax_cross_entropy_grad_dy_dtype_err(test_case): with test_case.assertRaises(TypeError) as context: dy = flow.randn(10, 10, dtype=flow.float32) label = flow.randn(10, 10, 5, dtype=flow.float64) prob = flow.randn(10, 10, 5, dtype=flow.float64) flow._C.softmax_cross_entropy_grad(dy, label, prob) print(str(context.exception)) test_case.assertTrue( "dy and prob are expected to have the same dtype, but found" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: 
python/oneflow/test/exceptions/test_sparse_cross_entropy_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest class TestSparseCrossEntropyError(flow.unittest.TestCase): def test_sparse_cross_entropy_prediction_numaxes_err(test_case): with test_case.assertRaises(RuntimeError) as context: prediction = flow.randn(10) label = flow.randint(0, 10, (10, 10), dtype=flow.int64) depth = 10 flow._C.sparse_cross_entropy(prediction, label, depth) test_case.assertTrue( "The dimension of prediction must be greater than or equal to 2, but found" in str(context.exception) ) def test_sparse_cross_entropy_label_numaxes_err(test_case): with test_case.assertRaises(RuntimeError) as context: prediction = flow.randn(10, 10, 5) label = flow.randint(0, 10, (10, 10, 5), dtype=flow.int64) depth = 10 flow._C.sparse_cross_entropy(prediction, label, depth) test_case.assertTrue( "The dimension of label is expected to be less than that of prediction by 1" in str(context.exception) ) def test_sparse_cross_entropy_prediction_i_shape_err(test_case): with test_case.assertRaises(RuntimeError) as context: prediction = flow.randn(10, 10, 5) label = flow.randint(0, 10, (10, 5), dtype=flow.int64) depth = 10 flow._C.sparse_cross_entropy(prediction, label, depth) test_case.assertTrue(" must match the size of label" in str(context.exception)) def test_sparse_cross_entropy_label_dtype_err(test_case): with test_case.assertRaises(TypeError) as context: prediction = flow.randn(10, 10, 5) label = flow.randn((10, 10), dtype=flow.float32) depth = 10 flow._C.sparse_cross_entropy(prediction, label, depth) test_case.assertTrue( "The dtype of label must be integer, but found" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_sparse_softmax_cross_entropy_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest class TestSparseSoftmaxCrossEntropyError(flow.unittest.TestCase): def test_sparse_softmax_cross_entropy_prediction_numaxes_err(test_case): with test_case.assertRaises(RuntimeError) as context: prediction = flow.randn(10) label = flow.randint(0, 10, (10, 10), dtype=flow.int64) flow._C.sparse_softmax_cross_entropy(prediction, label) test_case.assertTrue( "The dimension of prediction must be greater than or equal to 2, but found" in str(context.exception) ) def test_sparse_softmax_cross_entropy_label_numaxes_err(test_case): with test_case.assertRaises(RuntimeError) as context: prediction = flow.randn(10, 10, 5) label = flow.randint(0, 10, (10, 10, 5), dtype=flow.int64) flow._C.sparse_softmax_cross_entropy(prediction, label) test_case.assertTrue( "The dimension of label is expected to be less than that of prediction by 1" in str(context.exception) ) def test_sparse_softmax_cross_entropy_prediction_i_shape_err(test_case): with test_case.assertRaises(RuntimeError) as context: prediction = flow.randn(10, 10, 5) label = flow.randint(0, 10, (10, 9), dtype=flow.int64) flow._C.sparse_softmax_cross_entropy(prediction, label) test_case.assertTrue("must match the size of label" in str(context.exception)) def test_sparse_softmax_cross_entropy_label_dtype_err(test_case): with test_case.assertRaises(TypeError) as context: prediction = flow.randn(10, 10, 5) label = flow.randn(10, 10, dtype=flow.float32) flow._C.sparse_softmax_cross_entropy(prediction, label) test_case.assertTrue( "The dtype of label must be integer, but found " in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_split_like_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest class TestSplitLikeError(flow.unittest.TestCase): def test_split_like_like_axes_err(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.randn(4, 4) like = (flow.randn(2, 4, 4), flow.randn(2, 4, 4)) axis = 0 flow._C.split_like(x, like, axis) test_case.assertTrue( ") should be less than or equal to input (" in str(context.exception) ) def test_split_like_split_axes_err(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.randn(4, 4) like = (flow.randn(2, 4), flow.randn(2, 4)) axis = 3 flow._C.split_like(x, like, axis) test_case.assertTrue( "should be less than the dimension of like" in str(context.exception) ) def test_split_like_like_i_axes_err(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.randn(4, 4) like = (flow.randn(2, 4), flow.randn(2)) axis = 0 flow._C.split_like(x, like, axis) test_case.assertTrue( "must match the dimension of the first like" in str(context.exception) ) def test_split_like_x_i_shape_err(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.randn(4, 4) like = (flow.randn(2, 4), flow.randn(2, 3)) axis = 0 flow._C.split_like(x, like, axis) test_case.assertTrue("must match the size of like_i" in str(context.exception)) def test_split_like_non_dynamic_static_dim_err(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.randn(4, 4) like = (flow.randn(2, 4), flow.randn(3, 4)) axis = 0 flow._C.split_like(x, like, axis) test_case.assertTrue( "shape situation, the total size of like" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_stft_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest import numpy as np @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_stft_illegal_input_dim(test_case): np_tensor = np.arange(1, 13, dtype=float).reshape(2, 2, 3) with test_case.assertRaises(RuntimeError) as ctx: x_flow = flow.tensor(np_tensor) flow.stft( x_flow, n_fft=4, center=True, onesided=True, return_complex=False, normalized=False, ) test_case.assertTrue("Expected a 1D or 2D tensor,but got" in str(ctx.exception)) def test_stft_illegal_nfft(test_case): np_tensor = np.arange(1, 13, dtype=float).reshape(4, 3) win_tensor = np.arange(1, 5, dtype=float) with test_case.assertRaises(RuntimeError) as ctx: x_flow = flow.tensor(np_tensor) flow_win = flow.tensor(win_tensor) flow.stft( x_flow, n_fft=-1, window=flow_win, center=True, onesided=True, return_complex=False, normalized=False, ) test_case.assertTrue("Expected 0 < n_fft" in str(ctx.exception)) def test_stft_illegal_hop_length(test_case): np_tensor = np.arange(1, 13, dtype=float).reshape(4, 3) with test_case.assertRaises(RuntimeError) as ctx: x_flow = flow.tensor(np_tensor) flow.stft( x_flow, n_fft=4, hop_length=-1, center=True, onesided=True, return_complex=False, normalized=False, ) test_case.assertTrue("Expected hop_length > 0, but got" in str(ctx.exception)) def test_stft_illegal_win_length(test_case): np_tensor = np.arange(1, 13, dtype=float).reshape(4, 3) with test_case.assertRaises(RuntimeError) as ctx: x_flow = flow.tensor(np_tensor) flow.stft( x_flow, n_fft=4, win_length=-1, center=True, onesided=True, return_complex=False, normalized=False, ) test_case.assertTrue( "Expected 0 < win_length <=n_fft ,but got" in str(ctx.exception) ) def test_stft_illegal_window(test_case): np_tensor = np.arange(1, 13, dtype=float).reshape(2, 6) win_tensor = np.arange(1, 10, dtype=float) with test_case.assertRaises(RuntimeError) as ctx: x_flow = flow.tensor(np_tensor) flow_win = flow.tensor(win_tensor) flow.stft( x_flow, n_fft=4, window=flow_win, center=True, onesided=True, return_complex=False, normalized=False, ) test_case.assertTrue( "Expected a 1D window tensor of size equal to win_length=" in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_tensor_index.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow.unittest import oneflow as flow class TestTensorIndexError(flow.unittest.TestCase): def test_PrepareSliceIndices_indices_amount_index_error(test_case): with test_case.assertRaises(IndexError) as context: x = flow.arange(16).reshape(4, 4) x[0, 0, 0] = 0 test_case.assertTrue( "Too many indices for tensor of dimension" in str(context.exception) ) def test_PrepareSliceIndices_slice_step_runtime_error(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.tensor([0, 1, 2, 3], dtype=flow.int32) s = slice(0, 2, -1) y = x[s] test_case.assertTrue("Step must be greater than zero" in str(context.exception)) def test_ApplySelectIndexing_input_dim_runtime_error(test_case): with test_case.assertRaises(RuntimeError) as context: x = flow.tensor(5, dtype=flow.int32) y = x[0] test_case.assertTrue( "select() cannot be applied to a 0-dim tensor." in str(context.exception) ) def test_ApplySelectIndexing_index_error(test_case): with test_case.assertRaises(IndexError) as context: x = flow.ones(2, 3, dtype=flow.int32) y = x[3] test_case.assertTrue( "Index out of range (expected to be in range of" in str(context.exception) ) def test_ApplyAdvancedIndexing_index_error(test_case): with test_case.assertRaises(IndexError) as context: x = flow.ones(2, 2, dtype=flow.int32) index = ( flow.tensor(1, dtype=flow.int32), flow.tensor(1, dtype=flow.int32), flow.tensor(1, dtype=flow.int32), ) y = x[index] test_case.assertTrue( "Too many indices for tensor of dimension" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_tensordot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest import oneflow.nn.functional as F import torch @flow.unittest.skip_unless_1n1d() class TestTensordotError(flow.unittest.TestCase): def test_tensordot_neg_dims_runtime_error(test_case): with test_case.assertRaises(Exception) as context: a = flow.randn(1, 2, 3) b = flow.randn(1, 2, 3) flow.tensordot(a, b, dims=-1) test_case.assertTrue( "tensordot expects dims >= 0, but got dims=-1" in str(context.exception) ) @unittest.skip("PyTorch doesn't have corresponding error message") def test_tensordot_too_large_int_dims_runtime_error(test_case): with test_case.assertRaises(Exception) as context: a = flow.randn(1, 2, 3) b = flow.randn(1, 2, 3) flow.tensordot(a, b, dims=100) test_case.assertTrue( "tensordot expects dims <= a.ndim which is 3, but got 100" in str(context.exception) ) def test_tensordot_out_of_range_dims_runtime_error(test_case): with test_case.assertRaises(Exception) as context: a = flow.randn(1, 2, 3) b = flow.randn(1, 2, 3) flow.tensordot(a, b, dims=[[3], [2]]) test_case.assertTrue( "Dimension out of range (expected to be in range of [-3, 2], but got 3)" in str(context.exception) ) def test_tensordot_unmatch_dims_runtime_error(test_case): with test_case.assertRaises(Exception) as context: a = flow.randn(1, 2, 3) b = flow.randn(1, 2, 3) flow.tensordot(a, b, dims=[[1], [2]]) test_case.assertTrue( "contracted dimensions need to match, but first has size 2 in dim 1 and second has size 3 in dim 2" in str(context.exception) ) def test_tensordot_recurring_dim_runtime_error(test_case): with test_case.assertRaises(Exception) as context: a = flow.randn(1, 2, 3) b = flow.randn(1, 2, 3) flow.tensordot(a, b, dims=[[1, 1], [1, 1]]) test_case.assertTrue( "dim 1 appears multiple times in the list of dims" in str(context.exception) ) def test_tensordot_dims_different_length_runtime_error(test_case): with test_case.assertRaises(Exception) as context: a = flow.randn(1, 2, 3) b = flow.randn(1, 2, 3) flow.tensordot(a, b, dims=[[1], [1, 2]]) test_case.assertTrue( "both dimension lists should have same length" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_to_global_error.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np import time import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n2d() class TestToGlobalError(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_tensor_to_consistent(self): with self.assertRaises(Exception) as context: data = flow.rand(2, dtype=flow.float32) placement = flow.placement.all("cuda") sbp = flow.sbp.split(0) global_data = data.to_consistent(placement=placement, sbp=sbp) self.assertTrue( ".to_consistent has been removed, please use .to_global instead" in str(context.exception) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_tensor_is_global(self): with self.assertRaises(Exception) as context: data = flow.rand(2, dtype=flow.float32) print(data.is_consistent()) self.assertTrue( ".is_consistent has been removed, please use .is_global instead" in str(context.exception) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_module_to_consistent(self): with self.assertRaises(Exception) as context: m = flow.nn.Conv2d(1, 1, 1) placement = flow.placement.all("cuda") sbp = flow.sbp.split(0) m.to_consistent(placement=placement, sbp=sbp) self.assertTrue( ".to_consistent has been removed, please use .to_global instead" in str(context.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/test_view.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_view_exception(test_case): # torch exception and messge: # # RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. # a = flow.arange(9).reshape(3, 3) b = a.permute(1, 0) with test_case.assertRaises(RuntimeError) as ctx: print(b.view(9)) test_case.assertTrue( "view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead." in str(ctx.exception) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/exceptions/throw_error.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # This file is intended to be run in # python/oneflow/test/exceptions/test_error_reported_in_thread.py import oneflow as flow def g(x): flow._C.throw_error(x) def f(x): x = x.relu() g(x) x = flow.ones(3, 3, 4) f(x) ================================================ FILE: python/oneflow/test/expensive/README.md ================================================ # Expensive tests - Tests requires a lot of time, memory to run. - Every test should have exclusive access to GPU when running ================================================ FILE: python/oneflow/test/expensive/_internally_replaced_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import importlib.machinery def _download_file_from_remote_location(fpath: str, url: str) -> None: pass def _is_remote_location_available() -> bool: return False try: from torch.hub import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_url as load_state_dict_from_url def _get_extension_path(lib_name): lib_dir = os.path.dirname(__file__) if os.name == "nt": # Register the main torchvision library location on the default DLL path import ctypes import sys kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) with_load_library_flags = hasattr(kernel32, "AddDllDirectory") prev_error_mode = kernel32.SetErrorMode(0x0001) if with_load_library_flags: kernel32.AddDllDirectory.restype = ctypes.c_void_p if sys.version_info >= (3, 8): os.add_dll_directory(lib_dir) elif with_load_library_flags: res = kernel32.AddDllDirectory(lib_dir) if res is None: err = ctypes.WinError(ctypes.get_last_error()) err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' raise err kernel32.SetErrorMode(prev_error_mode) loader_details = ( importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES, ) extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) ext_specs = extfinder.find_spec(lib_name) if ext_specs is None: raise ImportError return ext_specs.origin ================================================ FILE: python/oneflow/test/expensive/_test_remat.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """ This file (_test_remat.py) is intended to be run inside test_remat.py with correct environment variables like ONEFLOW_VM_MULTI_THREAD=0 """ from contextlib import contextmanager import os import unittest import functools import numpy as np import oneflow as flow from oneflow import nn import flowvision import oneflow.unittest def evict(tensor): flow._oneflow_internal.remat.evict(tensor) def is_in_memory(tensor): return flow._oneflow_internal.remat.is_in_memory(tensor) placeholder_size = 0 def allocated_memory(device, include_test_placeholder=False): if device == "cuda" and not flow.sysconfig.with_cuda(): return 0 return flow._oneflow_internal.remat.allocated_memory(device) - ( 0 if include_test_placeholder else placeholder_size ) def display(device): return flow._oneflow_internal.remat.display(device) def only_fbip(): if os.getenv("ONEFLOW_REMAT_COPY_ON_WRITE") is None: return lambda f: f else: return unittest.skip("") def only_copy_on_write(): if os.getenv("ONEFLOW_REMAT_COPY_ON_WRITE") is not None: return lambda f: f else: return unittest.skip("") def loss_test(): if os.getenv("ONEFLOW_REMAT_RUN_LOSS_TEST") is not None: return lambda f: f else: return unittest.skip( "Environment variable 'ONEFLOW_REMAT_RUN_LOSS_TEST' need to be set to run this test." ) @contextmanager def generate_placeholder(size_mb, device): global placeholder_size placeholder_size = size_mb * 1024 * 1024 x = flow.zeros(int(placeholder_size), dtype=flow.int8, device=device) flow._oneflow_internal.remat.disable_eviction(x) try: yield finally: del x placeholder_size = 0 def memory_budget(budget_mb, device): if device == "cuda" and not oneflow.sysconfig.with_cuda(): return unittest.skip("Skip CUDA tests on CPU build") def deco(f): @functools.wraps(f) def new_f(*args, **kwargs): total_budget = flow.remat.get_budget() / 1024 / 1024 assert total_budget >= budget_mb, "Not enough memory budget" remat_device = device + "+remat" with generate_placeholder(total_budget - budget_mb, remat_device): return f(*args, remat_device, **kwargs) return new_f return deco class TestRemat(flow.unittest.TestCase): @classmethod def setUpClass(cls): flow.remat.set_budget("500MB") flow.remat.set_small_pieces_optimization(False) def setUp(self): super().setUp() assert ( os.getenv("ONEFLOW_VM_MULTI_THREAD") is not None ), "Please set ONEFLOW_VM_MULTI_THREAD to False, 0 or OFF" # check the memory is empty at the beginning of every test case if allocated_memory("cpu") > 0: print("allocated_memory(cpu):", allocated_memory("cpu")) display("cpu") if allocated_memory("cuda") > 0: print("allocated_memory(cuda):", allocated_memory("cuda")) display("cuda") self.assertEqual(allocated_memory("cpu"), 0) self.assertEqual(allocated_memory("cuda"), 0) flow._oneflow_internal.remat.clear_stats() def tearDown(self): super().tearDown() # check the memory is empty at the end of every test case self.assertEqual(allocated_memory("cpu"), 0) self.assertEqual(allocated_memory("cuda"), 0) @flow.unittest.skip_unless_1n1d() @only_fbip() @memory_budget(12, "cpu") def test_remat_work_on_fbip_1(self, device): x1 = flow.ones(1024 * 1024, device=device) # 4MB x2 = x1 * -2 # 8MB x3 = x2 - 2 # 
12MB x2.relu_() # 12MB self.assertTrue(is_in_memory(x1)) self.assertTrue(is_in_memory(x2)) self.assertTrue(is_in_memory(x3)) evict(x3) self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * -4)) evict(x2) self.assertTrue(np.array_equal(x2.numpy(), np.zeros(x2.shape))) @flow.unittest.skip_unless_1n1d() @only_fbip() @memory_budget(12, "cpu") def test_remat_work_on_fbip_2(self, device): x1 = flow.ones(1024 * 1024, device=device) # 4MB x2 = x1[0] x3 = x2 + 2 evict(x3) self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * 3)) evict(x2) evict(x3) self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * 3)) evict(x2) self.assertTrue(np.array_equal(x2.numpy(), np.ones(x2.shape))) @flow.unittest.skip_unless_1n1d() @unittest.skip("mutation other than inplace is not supported yet") @only_fbip() @memory_budget(12, "cpu") def test_remat_work_on_fbip_3(self, device): x1 = flow.ones(1024 * 1024, device=device) # 4MB x2 = x1 * -2 # 8MB x1.zero_() evict(x2) print(x2.numpy()) self.assertTrue(np.array_equal(x2.numpy(), np.ones(x2.shape) * -2)) @flow.unittest.skip_unless_1n1d() @only_fbip() @memory_budget(12, "cuda") def test_remat_work_on_fbip_4(self, device): x1 = flow.ones(1024 * 1024, device=device) # 4MB x2 = x1 + 1 x2 += x1 x3 = x2.relu() x4 = x3 + 1 evict(x3) evict(x2) evict(x1) evict(x3) self.assertTrue(np.array_equal(x4.numpy(), np.ones(x4.shape) * 4)) @flow.unittest.skip_unless_1n1d() @memory_budget(12, "cpu") def test_remat_work_on_simple_case_1(self, device): x1 = flow.ones(1024 * 1024, device=device) # 4MB self.assertTrue(is_in_memory(x1)) self.assertEqual(allocated_memory(device), 4 * 1024 * 1024) x2 = x1 + 2 self.assertEqual(allocated_memory(device), 8 * 1024 * 1024) # eager eviction del x1 self.assertEqual(allocated_memory(device), 4 * 1024 * 1024) self.assertTrue(is_in_memory(x2)) x3 = x2 + 2 self.assertTrue(is_in_memory(x2)) x4 = x3 + 2 self.assertTrue(is_in_memory(x2)) x5 = x4 + 2 self.assertFalse(is_in_memory(x2)) self.assertTrue(is_in_memory(x3)) self.assertTrue(is_in_memory(x4)) x6 = x5 + 2 self.assertFalse(is_in_memory(x2)) # the eviction of x2 increases the cost of x3, so x4 is evicted self.assertTrue(is_in_memory(x3)) self.assertFalse(is_in_memory(x4)) self.assertTrue(np.array_equal(x6.numpy(), np.ones(x6.shape) * 11)) self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * 5)) @flow.unittest.skip_unless_1n1d() @memory_budget(12, "cpu") def test_remat_work_on_simple_case_2(self, device): x1 = flow.ones(1024 * 1024, device=device) # 4MB self.assertTrue(is_in_memory(x1)) self.assertEqual(allocated_memory(device), 4 * 1024 * 1024) x2 = x1 + 2 # eager eviction del x1 self.assertTrue(is_in_memory(x2)) x3 = x2 + 2 self.assertTrue(is_in_memory(x2)) x4 = x3 + 2 self.assertTrue(is_in_memory(x2)) x5 = x4 + 2 self.assertFalse(is_in_memory(x2)) self.assertTrue(is_in_memory(x3)) self.assertTrue(is_in_memory(x4)) x6 = x5 + 2 self.assertFalse(is_in_memory(x2)) # the eviction of x2 increases the cost of x3, so x4 is evicted self.assertTrue(is_in_memory(x3)) self.assertFalse(is_in_memory(x4)) self.assertTrue(np.array_equal(x6.numpy(), np.ones(x6.shape) * 11)) self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * 5)) @flow.unittest.skip_unless_1n1d() @memory_budget(12, "cpu") def test_remat_full_and_init_constant(self, device): x1 = flow.eye(1024, 1024, device=device) self.assertTrue(is_in_memory(x1)) self.assertEqual(allocated_memory(device), 4 * 1024 * 1024) x2 = flow.full(x1.shape, 3.0, device=device) flow.nn.init.constant_(x1, x2) # type: ignore[arg-type] del 
x2 self.assertEqual(allocated_memory(device), 4 * 1024 * 1024) evict(x1) self.assertTrue(np.array_equal(x1.numpy(), np.ones(x1.shape) * 3)) @flow.unittest.skip_unless_1n1d() @memory_budget(12, "cpu") def test_remat_lifecycle_of_view_tensor(self, device): x1 = flow.eye(2, 3, device=device) self.assertTrue(is_in_memory(x1)) x2 = flow.ones(3, device=device) x3 = flow.expand(x2, (2, 3)) x1[:] = x3 del x3 del x2 evict(x1) self.assertTrue(np.array_equal(x1.numpy(), np.ones(x1.shape))) @flow.unittest.skip_unless_1n1d() @memory_budget(16, "cpu") def test_remat_init_constant_and_scalar(self, device): x0 = flow.ones(1024, 1024).to(device) x1 = x0 + 0 x2 = x1 + 1 flow.nn.init.constant_(x1, 5.0) # type: ignore[arg-type] evict(x1) self.assertTrue(np.array_equal(x1.numpy(), np.ones(x1.shape) * 5)) evict(x1) evict(x2) self.assertTrue(np.array_equal(x2.numpy(), np.ones(x2.shape) * 2)) @flow.unittest.skip_unless_1n1d() @memory_budget(80, "cpu") def test_copy(self, device): x1 = flow.ones(1) x2 = x1.to(device) self.assertTrue(x2.device.rematable) x3 = x2.to(flow.int64) self.assertTrue(x3.device.rematable) x4 = x2 + 1 self.assertTrue(x4.device.rematable) @flow.unittest.skip_unless_1n1d() @memory_budget(80, "cuda") def test_simple_network(self, device): model = nn.Sequential( nn.Conv2d(3, 32, 3, 2, 1), nn.BatchNorm2d(32), nn.ReLU(inplace=False), nn.Conv2d(32, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(inplace=False), nn.Conv2d(32, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(inplace=False), nn.Conv2d(32, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(inplace=False), ).to(device) for p in model.parameters(): p.grad = flow.zeros_like(p).to(device) optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0) x = flow.ones(4, 3, 224, 224).to(device) mem = allocated_memory(device) for _ in range(10): mem2 = allocated_memory(device) self.assertEqual(mem, mem2) loss = model(x).sum() loss.backward() del loss optimizer.step() optimizer.zero_grad() def _test_resnet18(self, optimizer_fn, ddp, expected_loss): flow.manual_seed(flow.env.get_rank()) device = "cpu+remat" model = flowvision.models.resnet18().to(device) if ddp: model = flow.nn.parallel.DistributedDataParallel(model, use_bucket=False) criterion = nn.CrossEntropyLoss().to(device) for x in model.parameters(): x.grad = flow.zeros_like(x).to(device) # optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0) optimizer = optimizer_fn(model.parameters()) x = flow.rand(10, 3, 224, 224).to(device) target = ( flow.randint(low=0, high=1000, size=(x.shape[0],)).to(device).to(flow.int32) ) # NOTE: there is a bug in current implementation about random ops: # x1 = flow.rand(5) # x2 = x1 + 1 # del x1 <--- we cannot block the eviction of x1 here because it is controlled by the user # evict(x2) # recompute(x2) <-- recomputing x2 triggers the recomputation of x1 and causes inconsistentness flow._oneflow_internal.remat.disable_eviction(x) flow._oneflow_internal.remat.disable_eviction(target) ITER_NUM = 5 for i in range(ITER_NUM): print("start allocated_memory(cpu):", allocated_memory("cpu")) print( "recomputation num: ", flow._oneflow_internal.remat.recomputation_num() ) output = model(x) loss = criterion(output, target) del output print(loss.numpy().item()) if i == 4 and expected_loss is not None: self.assertTrue(loss.numpy().item() in expected_loss) loss.backward() del loss optimizer.step() optimizer.zero_grad() print("end allocated_memory(cpu):", allocated_memory("cpu")) print( "recomputation num: ", flow._oneflow_internal.remat.recomputation_num() ) # check there is 
more than 10 recomputations each iteration # so the correctness check makes sense. self.assertGreater( flow._oneflow_internal.remat.recomputation_num(), ITER_NUM * 10 ) @flow.unittest.skip_unless_1n1d() @only_fbip() @memory_budget(220, "cpu") @loss_test() def test_resnet18_naive_sgd(self, _): # NOTE: this loss is only correct in my environment on 21 self._test_resnet18( lambda params: flow.optim.SGD(params, lr=0.1, momentum=0), False, [0.6304041147232056], ) @flow.unittest.skip_unless_1n2d() @only_fbip() @memory_budget(220, "cpu") @loss_test() def test_resnet18_naive_sgd_ddp_1n2d(self, _): # 2 devices, 2 losses # NOTE: these losses are only correct in my environment on 21 self._test_resnet18( lambda params: flow.optim.SGD(params, lr=0.1, momentum=0), True, [1.8890058994293213, 1.8992782831192017], ) @flow.unittest.skip_unless_1n1d() @only_fbip() @memory_budget(270, "cpu") @loss_test() def test_resnet18_momentum_sgd(self, _): # NOTE: this loss is only correct in my environment on 21 self._test_resnet18( lambda params: flow.optim.SGD(params, lr=0.1, momentum=0.9), False, None ) @flow.unittest.skip_unless_1n1d() @only_fbip() @memory_budget(310, "cpu") @loss_test() def test_resnet18_adam(self, _): # NOTE: this loss is only correct in my environment on 21 self._test_resnet18(lambda params: flow.optim.Adam(params, lr=0.1), False, None) @flow.unittest.skip_unless_1n1d() @only_copy_on_write() @memory_budget(12, "cpu") def test_copy_on_write(self, _): x1 = flow.ones(1024 * 1024) # 4MB x2 = flow.ones(1024 * 1024) x3 = x2 + 1 x2 += x1 display("cpu") print(f"x1 in memory?: {is_in_memory(x1)}") print(f"x2 in memory?: {is_in_memory(x2)}") print(f"x3 in memory?: {is_in_memory(x3)}") print(f"recompute num: {flow._oneflow_internal.remat.recomputation_num()}") print( f"forced eviction num: {flow._oneflow_internal.remat.forced_eviction_num()}" ) print( f"eager eviction num: {flow._oneflow_internal.remat.eager_eviction_num()}" ) print("-------------") print(x3.numpy()) print(f"x1 in memory?: {is_in_memory(x1)}") print(f"x2 in memory?: {is_in_memory(x2)}") print(f"x3 in memory?: {is_in_memory(x3)}") print(f"recompute num: {flow._oneflow_internal.remat.recomputation_num()}") print( f"forced eviction num: {flow._oneflow_internal.remat.forced_eviction_num()}" ) print( f"eager eviction num: {flow._oneflow_internal.remat.eager_eviction_num()}" ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/pytorch_alexnet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
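# A condensed sketch of the rematerialization flow the _test_remat.py cases above rely on,
# reusing the same internal helpers those tests call (the budget value is just an example):
# tensors placed on a "<device>+remat" device may be evicted under memory pressure and are
# recomputed from their producing op when accessed again.
import numpy as np
import oneflow as flow

flow.remat.set_budget("500MB")                       # same budget the test class configures
x1 = flow.ones(1024 * 1024, device="cpu+remat")      # 4MB source tensor
x2 = x1 + 2                                          # recomputable from x1
flow._oneflow_internal.remat.evict(x2)               # force x2 out of memory
assert not flow._oneflow_internal.remat.is_in_memory(x2)
assert np.array_equal(x2.numpy(), np.ones(x2.shape) * 3)   # access triggers recomputation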
""" import torch import torch.nn as nn from typing import Any __all__ = ["AlexNet", "alexnet"] class AlexNet(nn.Module): def __init__(self, num_classes: int = 1000) -> None: super(AlexNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), ) self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) self.classifier = nn.Sequential( nn.Dropout(), nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Linear(4096, num_classes), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.classifier(x) return x def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: r"""AlexNet model architecture from the `"One weird trick..." `_ paper. The required minimum input size of the model is 63x63. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ model = AlexNet(**kwargs) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_convmixer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch.nn as nn __all__ = ["ConvMixer", "convmixer_768_32_relu"] class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x): return self.fn(x) + x def ConvMixer(dim, depth, kernel_size=9, patch_size=7, n_classes=1000): return nn.Sequential( nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), nn.GELU(), nn.BatchNorm2d(dim), *[ nn.Sequential( Residual( nn.Sequential( nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), nn.GELU(), nn.BatchNorm2d(dim), ) ), nn.Conv2d(dim, dim, kernel_size=1), nn.GELU(), nn.BatchNorm2d(dim), ) for i in range(depth) ], nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(dim, n_classes) ) def convmixer_768_32_relu(pretrained: bool = False, progress: bool = True, **kwargs): """ Constructs the ConvMixer model with 32 depth and 768 hidden size and ReLU activation layer. .. note:: ConvMixer model with 32 depth and 768 hidden size and ReLU activation layer from the `Patched Are All You Need? `_ paper. Args: pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` progress (bool): If True, displays a progress bar of the download to stderr. 
Default: ``True`` """ model = ConvMixer(768, 32, kernel_size=7, patch_size=7, n_classes=1000) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_convnext.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn import torch.nn.functional as F from timm.models.layers import trunc_normal_, DropPath __all__ = ["ConvNeXt", "convnext_tiny"] class Block(nn.Module): r""" ConvNeXt Block. There are two equivalent implementations: (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back We use (2) as we find it slightly faster in PyTorch Args: dim (int): Number of input channels. drop_path (float): Stochastic depth rate. Default: 0.0 layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. """ def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6): super().__init__() self.dwconv = nn.Conv2d( dim, dim, kernel_size=7, padding=3, groups=dim ) # depthwise conv self.norm = LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear( dim, 4 * dim ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(4 * dim, dim) self.gamma = ( nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x): input = x x = self.dwconv(x) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) x = input + self.drop_path(x) return x class ConvNeXt(nn.Module): r""" ConvNeXt A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf Args: in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] drop_path_rate (float): Stochastic depth rate. Default: 0. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
""" def __init__( self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.0, layer_scale_init_value=1e-6, head_init_scale=1.0, ): super().__init__() self.downsample_layers = ( nn.ModuleList() ) # stem and 3 intermediate downsampling conv layers stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), ) self.downsample_layers.append(stem) for i in range(3): downsample_layer = nn.Sequential( LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), ) self.downsample_layers.append(downsample_layer) self.stages = ( nn.ModuleList() ) # 4 feature resolution stages, each consisting of multiple residual blocks dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] cur = 0 for i in range(4): stage = nn.Sequential( *[ Block( dim=dims[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value, ) for j in range(depths[i]) ] ) self.stages.append(stage) cur += depths[i] self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer self.head = nn.Linear(dims[-1], num_classes) self.apply(self._init_weights) self.head.weight.data.mul_(head_init_scale) self.head.bias.data.mul_(head_init_scale) def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): trunc_normal_(m.weight, std=0.02) nn.init.constant_(m.bias, 0) def forward_features(self, x): for i in range(4): x = self.downsample_layers[i](x) x = self.stages[i](x) return self.norm( x.mean([-2, -1]) ) # global average pooling, (N, C, H, W) -> (N, C) def forward(self, x): x = self.forward_features(x) x = self.head(x) return x class LayerNorm(nn.Module): r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). """ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError self.normalized_shape = (normalized_shape,) def forward(self, x): if self.data_format == "channels_last": return F.layer_norm( x, self.normalized_shape, self.weight, self.bias, self.eps ) elif self.data_format == "channels_first": u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x def convnext_tiny(pretrained=False, in_22k=False, **kwargs): model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_crossformer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ __all__ = ["CrossFormer", "crossformer_tiny_patch4_group7_224"] class Mlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class DynamicPosBias(nn.Module): def __init__(self, dim, num_heads, residual): super().__init__() self.residual = residual self.num_heads = num_heads self.pos_dim = dim // 4 self.pos_proj = nn.Linear(2, self.pos_dim) self.pos1 = nn.Sequential( nn.LayerNorm(self.pos_dim), nn.ReLU(inplace=True), nn.Linear(self.pos_dim, self.pos_dim), ) self.pos2 = nn.Sequential( nn.LayerNorm(self.pos_dim), nn.ReLU(inplace=True), nn.Linear(self.pos_dim, self.pos_dim), ) self.pos3 = nn.Sequential( nn.LayerNorm(self.pos_dim), nn.ReLU(inplace=True), nn.Linear(self.pos_dim, self.num_heads), ) def forward(self, biases): if self.residual: pos = self.pos_proj(biases) # 2Wh-1 * 2Ww-1, heads pos = pos + self.pos1(pos) pos = pos + self.pos2(pos) pos = self.pos3(pos) else: pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) return pos def flops(self, N): flops = N * 2 * self.pos_dim flops += N * self.pos_dim * self.pos_dim flops += N * self.pos_dim * self.pos_dim flops += N * self.pos_dim * self.num_heads return flops class Attention(nn.Module): r""" Multi-head self attention module with dynamic position bias. Args: dim (int): Number of input channels. group_size (tuple[int]): The height and width of the group. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 """ def __init__( self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0, position_bias=True, ): super().__init__() self.dim = dim self.group_size = group_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 self.position_bias = position_bias if position_bias: self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False) # generate mother-set position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0]) position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1]) biases = torch.stack( torch.meshgrid([position_bias_h, position_bias_w]) ) # 2, 2Wh-1, 2W2-1 biases = biases.flatten(1).transpose(0, 1).float() self.register_buffer("biases", biases) # get pair-wise relative position index for each token inside the group coords_h = torch.arange(self.group_size[0]) coords_w = torch.arange(self.group_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = ( coords_flatten[:, :, None] - coords_flatten[:, None, :] ) # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute( 1, 2, 0 ).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.group_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.group_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_groups*B, N, C) mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = ( self.qkv(x) .reshape(B_, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = ( qkv[0], qkv[1], qkv[2], ) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = q @ k.transpose(-2, -1) if self.position_bias: pos = self.pos(self.biases) # 2Wh-1 * 2Ww-1, heads # select position bias relative_position_bias = pos[self.relative_position_index.view(-1)].view( self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1, ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute( 2, 0, 1 ).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 1 ).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return ( f"dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}" ) def flops(self, N): # calculate flops for 1 group with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim if self.position_bias: flops += self.pos.flops(N) return flops class 
CrossFormerBlock(nn.Module): r""" CrossFormer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. group_size (int): Group size. lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__( self, dim, input_resolution, num_heads, group_size=7, lsda_flag=0, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1, ): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.group_size = group_size self.lsda_flag = lsda_flag self.mlp_ratio = mlp_ratio self.num_patch_size = num_patch_size if min(self.input_resolution) <= self.group_size: # if group size is larger than input resolution, we don't partition groups self.lsda_flag = 0 self.group_size = min(self.input_resolution) self.norm1 = norm_layer(dim) self.attn = Attention( dim, group_size=to_2tuple(self.group_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, position_bias=True, ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) attn_mask = None self.register_buffer("attn_mask", attn_mask) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # group embeddings G = self.group_size if self.lsda_flag == 0: # 0 for SDA x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5) else: # 1 for LDA x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5) x = x.reshape(B * H * W // G ** 2, G ** 2, C) # multi-head self-attention x = self.attn(x, mask=self.attn_mask) # nW*B, G*G, C # ungroup embeddings x = x.reshape(B, H // G, W // G, G, G, C) if self.lsda_flag == 0: x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C) else: x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C) x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return ( f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}" ) def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # LSDA nW = H * W / self.group_size / self.group_size flops += nW * self.attn.flops(self.group_size * self.group_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. 
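# A shape-only illustration (sizes are examples) of the SDA/LDA grouping performed in
# CrossFormerBlock.forward above: both variants pack the feature map into
# B * H * W / G**2 groups of G*G embeddings before attention.
import torch

B, H, W, C, G = 2, 14, 14, 32, 7
x = torch.randn(B, H, W, C)
# SDA: each group is a contiguous G x G window
sda = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)
sda = sda.reshape(B * H * W // G ** 2, G ** 2, C)
# LDA: each group samples embeddings at a fixed interval across the whole map
lda = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)
lda = lda.reshape(B * H * W // G ** 2, G ** 2, C)
print(sda.shape, lda.shape)   # torch.Size([8, 49, 32]) torch.Size([8, 49, 32])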
dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__( self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1, ): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reductions = nn.ModuleList() self.patch_size = patch_size self.norm = norm_layer(dim) for i, ps in enumerate(patch_size): if i == len(patch_size) - 1: out_dim = 2 * dim // 2 ** i else: out_dim = 2 * dim // 2 ** (i + 1) stride = 2 padding = (ps - stride) // 2 self.reductions.append( nn.Conv2d(dim, out_dim, kernel_size=ps, stride=stride, padding=padding) ) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." x = self.norm(x) x = x.view(B, H, W, C).permute(0, 3, 1, 2) xs = [] for i in range(len(self.reductions)): tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2) xs.append(tmp_x) x = torch.cat(xs, dim=2) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim for i, ps in enumerate(self.patch_size): if i == len(self.patch_size) - 1: out_dim = 2 * self.dim // 2 ** i else: out_dim = 2 * self.dim // 2 ** (i + 1) flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim return flops class Stage(nn.Module): """ CrossFormer blocks for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. group_size (int): variable G in the paper, one group has GxG embeddings mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__( self, dim, input_resolution, depth, num_heads, group_size, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, patch_size_end=[4], num_patch_size=None, ): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList() for i in range(depth): lsda_flag = 0 if (i % 2 == 0) else 1 self.blocks.append( CrossFormerBlock( dim=dim, input_resolution=input_resolution, num_heads=num_heads, group_size=group_size, lsda_flag=lsda_flag, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, num_patch_size=num_patch_size, ) ) # patch merging layer if downsample is not None: self.downsample = downsample( input_resolution, dim=dim, norm_layer=norm_layer, patch_size=patch_size_end, num_input_patch_size=num_patch_size, ) else: self.downsample = None def forward(self, x): for blk in self.blocks: # if self.use_checkpoint: # x = checkpoint.checkpoint(blk, x) # else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: [4]. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__( self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None ): super().__init__() img_size = to_2tuple(img_size) # patch_size = to_2tuple(patch_size) patches_resolution = [ img_size[0] // patch_size[0], img_size[0] // patch_size[0], ] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.projs = nn.ModuleList() for i, ps in enumerate(patch_size): if i == len(patch_size) - 1: dim = embed_dim // 2 ** i else: dim = embed_dim // 2 ** (i + 1) stride = patch_size[0] padding = (ps - patch_size[0]) // 2 self.projs.append( nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding) ) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert ( H == self.img_size[0] and W == self.img_size[1] ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
xs = [] for i in range(len(self.projs)): tx = self.projs[i](x).flatten(2).transpose(1, 2) xs.append(tx) # B Ph*Pw C x = torch.cat(xs, dim=2) if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = 0 for i, ps in enumerate(self.patch_size): if i == len(self.patch_size) - 1: dim = self.embed_dim // 2 ** i else: dim = self.embed_dim // 2 ** (i + 1) flops += ( Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i]) ) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class CrossFormer(nn.Module): r""" CrossFormer A PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention` - Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each stage. num_heads (tuple(int)): Number of attention heads in different layers. group_size (int): Group size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False """ def __init__( self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], group_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs, ): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, ) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter( torch.zeros(1, num_patches, embed_dim) ) trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size] for i_layer in range(self.num_layers): patch_size_end = ( merge_size[i_layer] if i_layer < self.num_layers - 1 else None ) num_patch_size = num_patch_sizes[i_layer] layer = Stage( dim=int(embed_dim * 2 ** i_layer), input_resolution=( patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer), ), depth=depths[i_layer], num_heads=num_heads[i_layer], group_size=group_size[i_layer], mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, patch_size_end=patch_size_end, num_patch_size=num_patch_size, ) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = ( nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def no_weight_decay(self): return {"absolute_pos_embed"} def no_weight_decay_keywords(self): return {"relative_position_bias_table"} def forward_features(self, x): x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for layer in self.layers: x = layer(x) x = self.norm(x) # B L C x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += ( self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) ) flops += self.num_features * self.num_classes return flops def _create_cross_former(arch, pretrained=False, progress=True, **model_kwargs): 
model = CrossFormer(**model_kwargs) return model def crossformer_tiny_patch4_group7_224(pretrained=False, progress=True, **kwargs): """ Constructs CrossFormer-T 224x224 model. .. note:: CrossFormer-T 224x224 model from `"CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention" `_. Args: pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` For example: .. code-block:: python >>> import flowvision >>> crossformer_tiny_patch4_group7_224 = flowvision.models.crossformer_tiny_patch4_group7_224(pretrained=False, progress=True) """ model_kwargs = dict( img_size=224, patch_size=(4, 8, 16, 32), embed_dim=64, depths=(1, 1, 8, 6), num_heads=(2, 4, 8, 16), group_size=(7, 7, 7, 7), merge_size=((2, 4), (2, 4), (2, 4)), drop_path_rate=0.1, **kwargs, ) return _create_cross_former( "crossformer_tiny_patch4_group7_224", pretrained=pretrained, progress=progress, **model_kwargs, ) ================================================ FILE: python/oneflow/test/expensive/pytorch_densenet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from collections import OrderedDict from typing import Any, List, Tuple __all__ = [ "DenseNet", "densenet121", ] class _DenseLayer(nn.Module): def __init__( self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False, ) -> None: super().__init__() self.norm1: nn.BatchNorm2d self.add_module("norm1", nn.BatchNorm2d(num_input_features)) self.relu1: nn.ReLU self.add_module("relu1", nn.ReLU(inplace=True)) self.conv1: nn.Conv2d self.add_module( "conv1", nn.Conv2d( num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False, ), ) self.norm2: nn.BatchNorm2d self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate)) self.relu2: nn.ReLU self.add_module("relu2", nn.ReLU(inplace=True)) self.conv2: nn.Conv2d self.add_module( "conv2", nn.Conv2d( bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False, ), ) self.drop_rate = float(drop_rate) self.memory_efficient = memory_efficient def bn_function(self, inputs: List[Tensor]) -> Tensor: concated_features = torch.cat(inputs, 1) bottleneck_output = self.conv1( self.relu1(self.norm1(concated_features)) ) # noqa: T484 return bottleneck_output # todo: rewrite when torchscript supports any def any_requires_grad(self, input: List[Tensor]) -> bool: for tensor in input: if tensor.requires_grad: return True return False # torchscript does not yet support *args, so we overload method # allowing it to take either a List[Tensor] or single Tensor def forward(self, input: Tensor) -> Tensor: # noqa: F811 if isinstance(input, Tensor): prev_features = [input] else: prev_features = input if self.memory_efficient and self.any_requires_grad(prev_features): 
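# Descriptive note (an assumption based on the torchvision original this file mirrors): the
# memory-efficient path below recomputes the 1x1 bottleneck during backward via checkpointing
# instead of keeping its activations, trading extra compute for lower memory use.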
if torch.jit.is_scripting(): raise Exception("Memory Efficient not supported in JIT") bottleneck_output = self.call_checkpoint_bottleneck(prev_features) else: bottleneck_output = self.bn_function(prev_features) new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) if self.drop_rate > 0: new_features = F.dropout( new_features, p=self.drop_rate, training=self.training ) return new_features class _DenseBlock(nn.ModuleDict): _version = 2 def __init__( self, num_layers: int, num_input_features: int, bn_size: int, growth_rate: int, drop_rate: float, memory_efficient: bool = False, ) -> None: super().__init__() for i in range(num_layers): layer = _DenseLayer( num_input_features + i * growth_rate, growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate, memory_efficient=memory_efficient, ) self.add_module("denselayer%d" % (i + 1), layer) def forward(self, init_features: Tensor) -> Tensor: features = [init_features] for name, layer in self.items(): new_features = layer(features) features.append(new_features) return torch.cat(features, 1) class _Transition(nn.Sequential): def __init__(self, num_input_features: int, num_output_features: int) -> None: super().__init__() self.add_module("norm", nn.BatchNorm2d(num_input_features)) self.add_module("relu", nn.ReLU(inplace=True)) self.add_module( "conv", nn.Conv2d( num_input_features, num_output_features, kernel_size=1, stride=1, bias=False, ), ) self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2)) class DenseNet(nn.Module): r"""Densenet-BC model class, based on `"Densely Connected Convolutional Networks" `_. Args: growth_rate (int) - how many filters to add each layer (`k` in paper) block_config (list of 4 ints) - how many layers in each pooling block num_init_features (int) - the number of filters to learn in the first convolution layer bn_size (int) - multiplicative factor for number of bottle neck layers (i.e. bn_size * k features in the bottleneck layer) drop_rate (float) - dropout rate after each dense layer num_classes (int) - number of classification classes memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
""" def __init__( self, growth_rate: int = 32, block_config: Tuple[int, int, int, int] = (6, 12, 24, 16), num_init_features: int = 64, bn_size: int = 4, drop_rate: float = 0, num_classes: int = 1000, memory_efficient: bool = False, ) -> None: super().__init__() # First convolution self.features = nn.Sequential( OrderedDict( [ ( "conv0", nn.Conv2d( 3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False, ), ), ("norm0", nn.BatchNorm2d(num_init_features)), ("relu0", nn.ReLU(inplace=True)), ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), ] ) ) # Each denseblock num_features = num_init_features for i, num_layers in enumerate(block_config): block = _DenseBlock( num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, memory_efficient=memory_efficient, ) self.features.add_module("denseblock%d" % (i + 1), block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: trans = _Transition( num_input_features=num_features, num_output_features=num_features // 2, ) self.features.add_module("transition%d" % (i + 1), trans) num_features = num_features // 2 # Final batch norm self.features.add_module("norm5", nn.BatchNorm2d(num_features)) # Linear layer self.classifier = nn.Linear(num_features, num_classes) # Official init from torch repo. for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.constant_(m.bias, 0) def forward(self, x: Tensor) -> Tensor: features = self.features(x) out = F.relu(features, inplace=True) out = F.adaptive_avg_pool2d(out, (1, 1)) out = torch.flatten(out, 1) out = self.classifier(out) return out def _densenet( growth_rate: int, block_config: Tuple[int, int, int, int], num_init_features: int, progress: bool, **kwargs: Any, ) -> DenseNet: model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) return model def densenet121(progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-121 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: weights (DenseNet121_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ return _densenet(32, (6, 12, 24, 16), 64, progress, **kwargs) ================================================ FILE: python/oneflow/test/expensive/pytorch_efficientnet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import torch from torch import nn, Tensor from torchvision.ops import StochasticDepth import copy import math import warnings from dataclasses import dataclass from functools import partial from typing import Any, Callable, Optional, List, Sequence, Tuple, Union __all__ = [ "EfficientNet", "efficientnet_b0", ] class SqueezeExcitation(torch.nn.Module): """ This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3. Args: input_channels (int): Number of channels in the input image squeeze_channels (int): Number of squeeze channels activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU`` scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid`` """ def __init__( self, input_channels: int, squeeze_channels: int, activation: Callable[..., torch.nn.Module] = torch.nn.ReLU, scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, ) -> None: super().__init__() self.avgpool = torch.nn.AdaptiveAvgPool2d(1) self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) self.activation = activation() self.scale_activation = scale_activation() def _scale(self, input: Tensor) -> Tensor: scale = self.avgpool(input) scale = self.fc1(scale) scale = self.activation(scale) scale = self.fc2(scale) return self.scale_activation(scale) def forward(self, input: Tensor) -> Tensor: scale = self._scale(input) return scale * input class ConvNormActivation(torch.nn.Sequential): def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: Optional[int] = None, groups: int = 1, norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: int = 1, inplace: Optional[bool] = True, bias: Optional[bool] = None, conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d, ) -> None: if padding is None: padding = (kernel_size - 1) // 2 * dilation if bias is None: bias = norm_layer is None layers = [ conv_layer( in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=bias, ) ] if norm_layer is not None: layers.append(norm_layer(out_channels)) if activation_layer is not None: params = {} if inplace is None else {"inplace": inplace} layers.append(activation_layer(**params)) super().__init__(*layers) self.out_channels = out_channels if self.__class__ == ConvNormActivation: warnings.warn( "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead." ) class Conv2dNormActivation(ConvNormActivation): """ Configurable block used for Convolution2d-Normalization-Activation blocks. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block kernel_size: (int, optional): Size of the convolving kernel. Default: 3 stride (int, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` groups (int, optional): Number of blocked connections from input channels to output channels. 
Default: 1 norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` dilation (int): Spacing between kernel elements. Default: 1 inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: Optional[int] = None, groups: int = 1, norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: int = 1, inplace: Optional[bool] = True, bias: Optional[bool] = None, ) -> None: super().__init__( in_channels, out_channels, kernel_size, stride, padding, groups, norm_layer, activation_layer, dilation, inplace, bias, torch.nn.Conv2d, ) def _make_divisible(v, divisor=8, min_value=None, round_limit=0.9): min_value = min_value or divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. if new_v < round_limit * v: new_v += divisor return new_v @dataclass class _MBConvConfig: expand_ratio: float kernel: int stride: int input_channels: int out_channels: int num_layers: int block: Callable[..., nn.Module] @staticmethod def adjust_channels( channels: int, width_mult: float, min_value: Optional[int] = None ) -> int: return _make_divisible(channels * width_mult, 8, min_value) class MBConvConfig(_MBConvConfig): # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper def __init__( self, expand_ratio: float, kernel: int, stride: int, input_channels: int, out_channels: int, num_layers: int, width_mult: float = 1.0, depth_mult: float = 1.0, block: Optional[Callable[..., nn.Module]] = None, ) -> None: input_channels = self.adjust_channels(input_channels, width_mult) out_channels = self.adjust_channels(out_channels, width_mult) num_layers = self.adjust_depth(num_layers, depth_mult) if block is None: block = MBConv super().__init__( expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block, ) @staticmethod def adjust_depth(num_layers: int, depth_mult: float): return int(math.ceil(num_layers * depth_mult)) class FusedMBConvConfig(_MBConvConfig): # Stores information listed at Table 4 of the EfficientNetV2 paper def __init__( self, expand_ratio: float, kernel: int, stride: int, input_channels: int, out_channels: int, num_layers: int, block: Optional[Callable[..., nn.Module]] = None, ) -> None: if block is None: block = FusedMBConv super().__init__( expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block, ) class MBConv(nn.Module): def __init__( self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module], se_layer: Callable[..., nn.Module] = SqueezeExcitation, ) -> None: super().__init__() if not (1 <= cnf.stride <= 2): raise ValueError("illegal stride value") self.use_res_connect = ( cnf.stride == 1 and cnf.input_channels == cnf.out_channels ) layers: 
List[nn.Module] = [] activation_layer = nn.SiLU # expand expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) if expanded_channels != cnf.input_channels: layers.append( Conv2dNormActivation( cnf.input_channels, expanded_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation_layer, ) ) # depthwise layers.append( Conv2dNormActivation( expanded_channels, expanded_channels, kernel_size=cnf.kernel, stride=cnf.stride, groups=expanded_channels, norm_layer=norm_layer, activation_layer=activation_layer, ) ) # squeeze and excitation squeeze_channels = max(1, cnf.input_channels // 4) layers.append( se_layer( expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True), ) ) # project layers.append( Conv2dNormActivation( expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None, ) ) self.block = nn.Sequential(*layers) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.out_channels = cnf.out_channels def forward(self, input: Tensor) -> Tensor: result = self.block(input) if self.use_res_connect: result = self.stochastic_depth(result) result += input return result class FusedMBConv(nn.Module): def __init__( self, cnf: FusedMBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module], ) -> None: super().__init__() if not (1 <= cnf.stride <= 2): raise ValueError("illegal stride value") self.use_res_connect = ( cnf.stride == 1 and cnf.input_channels == cnf.out_channels ) layers: List[nn.Module] = [] activation_layer = nn.SiLU expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) if expanded_channels != cnf.input_channels: # fused expand layers.append( Conv2dNormActivation( cnf.input_channels, expanded_channels, kernel_size=cnf.kernel, stride=cnf.stride, norm_layer=norm_layer, activation_layer=activation_layer, ) ) # project layers.append( Conv2dNormActivation( expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None, ) ) else: layers.append( Conv2dNormActivation( cnf.input_channels, cnf.out_channels, kernel_size=cnf.kernel, stride=cnf.stride, norm_layer=norm_layer, activation_layer=activation_layer, ) ) self.block = nn.Sequential(*layers) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.out_channels = cnf.out_channels def forward(self, input: Tensor) -> Tensor: result = self.block(input) if self.use_res_connect: result = self.stochastic_depth(result) result += input return result class EfficientNet(nn.Module): def __init__( self, inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], dropout: float, stochastic_depth_prob: float = 0.2, num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, last_channel: Optional[int] = None, **kwargs: Any, ) -> None: """ EfficientNet V1 and V2 main class Args: inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure dropout (float): The droupout probability stochastic_depth_prob (float): The stochastic depth probability num_classes (int): Number of classes norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use last_channel (int): The number of channels on the penultimate layer """ super().__init__() if not inverted_residual_setting: raise ValueError("The inverted_residual_setting should not be empty") elif not ( isinstance(inverted_residual_setting, Sequence) and all([isinstance(s, _MBConvConfig) for s in 
inverted_residual_setting]) ): raise TypeError( "The inverted_residual_setting should be List[MBConvConfig]" ) if "block" in kwargs: warnings.warn( "The parameter 'block' is deprecated since 0.13 and will be removed 0.15. " "Please pass this information on 'MBConvConfig.block' instead." ) if kwargs["block"] is not None: for s in inverted_residual_setting: if isinstance(s, MBConvConfig): s.block = kwargs["block"] if norm_layer is None: norm_layer = nn.BatchNorm2d layers: List[nn.Module] = [] # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append( Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU, ) ) # building inverted residual blocks total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting) stage_block_id = 0 for cnf in inverted_residual_setting: stage: List[nn.Module] = [] for _ in range(cnf.num_layers): # copy to avoid modifications. shallow copy is enough block_cnf = copy.copy(cnf) # overwrite info if not the first conv in the stage if stage: block_cnf.input_channels = block_cnf.out_channels block_cnf.stride = 1 # adjust stochastic depth probability based on the depth of the stage block sd_prob = ( stochastic_depth_prob * float(stage_block_id) / total_stage_blocks ) stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer)) stage_block_id += 1 layers.append(nn.Sequential(*stage)) # building last several layers lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = ( last_channel if last_channel is not None else 4 * lastconv_input_channels ) layers.append( Conv2dNormActivation( lastconv_input_channels, lastconv_output_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.SiLU, ) ) self.features = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) self.classifier = nn.Sequential( nn.Dropout(p=dropout, inplace=True), nn.Linear(lastconv_output_channels, num_classes), ) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): init_range = 1.0 / math.sqrt(m.out_features) nn.init.uniform_(m.weight, -init_range, init_range) nn.init.zeros_(m.bias) def _forward_impl(self, x: Tensor) -> Tensor: x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.classifier(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def _efficientnet( inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], dropout: float, last_channel: Optional[int], progress: bool, **kwargs: Any, ) -> EfficientNet: model = EfficientNet( inverted_residual_setting, dropout, last_channel=last_channel, **kwargs ) return model def _efficientnet_conf( arch: str, **kwargs: Any, ) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]: inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] if arch.startswith("efficientnet_b"): bneck_conf = partial( MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"), ) inverted_residual_setting = [ bneck_conf(1, 3, 1, 32, 16, 1), bneck_conf(6, 3, 2, 16, 24, 2), bneck_conf(6, 5, 2, 24, 40, 2), bneck_conf(6, 3, 2, 40, 80, 3), bneck_conf(6, 5, 1, 80, 112, 3), bneck_conf(6, 5, 2, 112, 192, 4), bneck_conf(6, 3, 1, 192, 320, 1), ] 
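# Each bneck_conf row above follows MBConvConfig's positional order:
# (expand_ratio, kernel, stride, input_channels, out_channels, num_layers).
# Because bneck_conf is a functools.partial with width_mult/depth_mult bound,
# the channel counts and layer counts are rescaled by adjust_channels() and
# adjust_depth() before being stored on the config.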
last_channel = None elif arch.startswith("efficientnet_v2_s"): inverted_residual_setting = [ FusedMBConvConfig(1, 3, 1, 24, 24, 2), FusedMBConvConfig(4, 3, 2, 24, 48, 4), FusedMBConvConfig(4, 3, 2, 48, 64, 4), MBConvConfig(4, 3, 2, 64, 128, 6), MBConvConfig(6, 3, 1, 128, 160, 9), MBConvConfig(6, 3, 2, 160, 256, 15), ] last_channel = 1280 elif arch.startswith("efficientnet_v2_m"): inverted_residual_setting = [ FusedMBConvConfig(1, 3, 1, 24, 24, 3), FusedMBConvConfig(4, 3, 2, 24, 48, 5), FusedMBConvConfig(4, 3, 2, 48, 80, 5), MBConvConfig(4, 3, 2, 80, 160, 7), MBConvConfig(6, 3, 1, 160, 176, 14), MBConvConfig(6, 3, 2, 176, 304, 18), MBConvConfig(6, 3, 1, 304, 512, 5), ] last_channel = 1280 elif arch.startswith("efficientnet_v2_l"): inverted_residual_setting = [ FusedMBConvConfig(1, 3, 1, 32, 32, 4), FusedMBConvConfig(4, 3, 2, 32, 64, 7), FusedMBConvConfig(4, 3, 2, 64, 96, 7), MBConvConfig(4, 3, 2, 96, 192, 10), MBConvConfig(6, 3, 1, 192, 224, 19), MBConvConfig(6, 3, 2, 224, 384, 25), MBConvConfig(6, 3, 1, 384, 640, 7), ] last_channel = 1280 else: raise ValueError(f"Unsupported model type {arch}") return inverted_residual_setting, last_channel def efficientnet_b0(progress: bool = True, **kwargs: Any) -> EfficientNet: """ Constructs a EfficientNet B0 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: weights (EfficientNet_B0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ inverted_residual_setting, last_channel = _efficientnet_conf( "efficientnet_b0", width_mult=1.0, depth_mult=1.0 ) return _efficientnet( inverted_residual_setting, 0.2, last_channel, progress, **kwargs ) ================================================ FILE: python/oneflow/test/expensive/pytorch_ghostnet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn import math __all__ = ["ghost_net"] def _make_divisible(v, divisor, min_value=None): """ This function is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by 8 It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py """ if min_value is None: min_value = divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. 
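# Worked examples of the rounding rule above: _make_divisible(17, 4) -> 16
# (rounds to the nearest multiple of 4; losing 1 of 17 is under 10%), while
# _make_divisible(9, 4) -> 12, because rounding down to 8 would lose more
# than 10% of the value, so one extra divisor is added below.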
if new_v < 0.9 * v: new_v += divisor return new_v class SELayer(nn.Module): def __init__(self, channel, reduction=4): super(SELayer, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel), ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) y = torch.clamp(y, 0, 1) return x * y def depthwise_conv(inp, oup, kernel_size=3, stride=1, relu=False): return nn.Sequential( nn.Conv2d( inp, oup, kernel_size, stride, kernel_size // 2, groups=inp, bias=False ), nn.BatchNorm2d(oup), nn.ReLU(inplace=True) if relu else nn.Sequential(), ) class GhostModule(nn.Module): def __init__( self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True ): super(GhostModule, self).__init__() self.oup = oup init_channels = math.ceil(oup / ratio) new_channels = init_channels * (ratio - 1) self.primary_conv = nn.Sequential( nn.Conv2d( inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False ), nn.BatchNorm2d(init_channels), nn.ReLU(inplace=True) if relu else nn.Sequential(), ) self.cheap_operation = nn.Sequential( nn.Conv2d( init_channels, new_channels, dw_size, 1, dw_size // 2, groups=init_channels, bias=False, ), nn.BatchNorm2d(new_channels), nn.ReLU(inplace=True) if relu else nn.Sequential(), ) def forward(self, x): x1 = self.primary_conv(x) x2 = self.cheap_operation(x1) out = torch.cat([x1, x2], dim=1) return out[:, : self.oup, :, :] class GhostBottleneck(nn.Module): def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se): super(GhostBottleneck, self).__init__() assert stride in [1, 2] self.conv = nn.Sequential( # pw GhostModule(inp, hidden_dim, kernel_size=1, relu=True), # dw depthwise_conv(hidden_dim, hidden_dim, kernel_size, stride, relu=False) if stride == 2 else nn.Sequential(), # Squeeze-and-Excite SELayer(hidden_dim) if use_se else nn.Sequential(), # pw-linear GhostModule(hidden_dim, oup, kernel_size=1, relu=False), ) if stride == 1 and inp == oup: self.shortcut = nn.Sequential() else: self.shortcut = nn.Sequential( depthwise_conv(inp, inp, kernel_size, stride, relu=False), nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) def forward(self, x): return self.conv(x) + self.shortcut(x) class GhostNet(nn.Module): def __init__(self, cfgs, num_classes=1000, width_mult=1.0): super(GhostNet, self).__init__() # setting of inverted residual blocks self.cfgs = cfgs # building first layer output_channel = _make_divisible(16 * width_mult, 4) layers = [ nn.Sequential( nn.Conv2d(3, output_channel, 3, 2, 1, bias=False), nn.BatchNorm2d(output_channel), nn.ReLU(inplace=True), ) ] input_channel = output_channel # building inverted residual blocks block = GhostBottleneck for k, exp_size, c, use_se, s in self.cfgs: output_channel = _make_divisible(c * width_mult, 4) hidden_channel = _make_divisible(exp_size * width_mult, 4) layers.append( block(input_channel, hidden_channel, output_channel, k, s, use_se) ) input_channel = output_channel self.features = nn.Sequential(*layers) # building last several layers output_channel = _make_divisible(exp_size * width_mult, 4) self.squeeze = nn.Sequential( nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False), nn.BatchNorm2d(output_channel), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), ) input_channel = output_channel output_channel = 1280 self.classifier = nn.Sequential( nn.Linear(input_channel, output_channel, bias=False), 
nn.BatchNorm1d(output_channel), nn.ReLU(inplace=True), nn.Dropout(0.2), nn.Linear(output_channel, num_classes), ) self._initialize_weights() def forward(self, x): x = self.features(x) x = self.squeeze(x) x = x.view(x.size(0), -1) x = self.classifier(x) return x def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def ghost_net(**kwargs): """ Constructs a GhostNet model """ cfgs = [ # k, t, c, SE, s [3, 16, 16, 0, 1], [3, 48, 24, 0, 2], [3, 72, 24, 0, 1], [5, 72, 40, 1, 2], [5, 120, 40, 1, 1], [3, 240, 80, 0, 2], [3, 200, 80, 0, 1], [3, 184, 80, 0, 1], [3, 184, 80, 0, 1], [3, 480, 112, 1, 1], [3, 672, 112, 1, 1], [5, 672, 160, 1, 2], [5, 960, 160, 0, 1], [5, 960, 160, 1, 1], [5, 960, 160, 0, 1], [5, 960, 160, 1, 1], ] return GhostNet(cfgs, **kwargs) ================================================ FILE: python/oneflow/test/expensive/pytorch_googlenet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor import warnings from typing import Optional, Tuple, List, Callable, Any __all__ = ["GoogLeNet", "googlenet"] class GoogLeNet(nn.Module): __constants__ = ["aux_logits", "transform_input"] def __init__( self, num_classes: int = 1000, aux_logits: bool = True, transform_input: bool = False, init_weights: Optional[bool] = None, blocks: Optional[List[Callable[..., nn.Module]]] = None, dropout: float = 0.2, dropout_aux: float = 0.7, ) -> None: super().__init__() if blocks is None: blocks = [BasicConv2d, Inception, InceptionAux] if init_weights is None: warnings.warn( "The default weight initialization of GoogleNet will be changed in future releases of " "torchvision. 
If you wish to keep the old behavior (which leads to long initialization times" " due to scipy/scipy#11299), please set init_weights=True.", FutureWarning, ) init_weights = True if len(blocks) != 3: raise ValueError(f"blocks length should be 3 instead of {len(blocks)}") conv_block = blocks[0] inception_block = blocks[1] inception_aux_block = blocks[2] self.aux_logits = aux_logits self.transform_input = transform_input self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3) self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.conv2 = conv_block(64, 64, kernel_size=1) self.conv3 = conv_block(64, 192, kernel_size=3, padding=1) self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32) self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64) self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64) self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64) self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64) self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64) self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128) self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128) self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128) if aux_logits: self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux) self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux) else: self.aux1 = None # type: ignore[assignment] self.aux2 = None # type: ignore[assignment] self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.dropout = nn.Dropout(p=dropout) self.fc = nn.Linear(1024, num_classes) if init_weights: for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def _transform_input(self, x: Tensor) -> Tensor: if self.transform_input: x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 x = torch.cat((x_ch0, x_ch1, x_ch2), 1) return x def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: # N x 3 x 224 x 224 x = self.conv1(x) # N x 64 x 112 x 112 x = self.maxpool1(x) # N x 64 x 56 x 56 x = self.conv2(x) # N x 64 x 56 x 56 x = self.conv3(x) # N x 192 x 56 x 56 x = self.maxpool2(x) # N x 192 x 28 x 28 x = self.inception3a(x) # N x 256 x 28 x 28 x = self.inception3b(x) # N x 480 x 28 x 28 x = self.maxpool3(x) # N x 480 x 14 x 14 x = self.inception4a(x) # N x 512 x 14 x 14 aux1: Optional[Tensor] = None if self.aux1 is not None: if self.training: aux1 = self.aux1(x) x = self.inception4b(x) # N x 512 x 14 x 14 x = self.inception4c(x) # N x 512 x 14 x 14 x = self.inception4d(x) # N x 528 x 14 x 14 aux2: Optional[Tensor] = None if self.aux2 is not None: if self.training: aux2 = self.aux2(x) x = self.inception4e(x) # N x 832 x 14 x 14 x = self.maxpool4(x) # N x 832 x 7 x 7 x = self.inception5a(x) # N x 832 x 7 x 7 x = self.inception5b(x) # N x 1024 x 7 x 7 x = self.avgpool(x) # N x 1024 x 1 x 1 x = torch.flatten(x, 1) # N x 1024 x = self.dropout(x) x = self.fc(x) # N x 1000 (num_classes) return x, aux2, aux1 def forward(self, x: 
Tensor): x = self._transform_input(x) x, aux1, aux2 = self._forward(x) return x class Inception(nn.Module): def __init__( self, in_channels: int, ch1x1: int, ch3x3red: int, ch3x3: int, ch5x5red: int, ch5x5: int, pool_proj: int, conv_block: Optional[Callable[..., nn.Module]] = None, ) -> None: super().__init__() if conv_block is None: conv_block = BasicConv2d self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1) self.branch2 = nn.Sequential( conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1), ) self.branch3 = nn.Sequential( conv_block(in_channels, ch5x5red, kernel_size=1), # Here, kernel_size=3 instead of kernel_size=5 is a known bug. # Please see https://github.com/pytorch/vision/issues/906 for details. conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1), ) self.branch4 = nn.Sequential( nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), conv_block(in_channels, pool_proj, kernel_size=1), ) def _forward(self, x: Tensor) -> List[Tensor]: branch1 = self.branch1(x) branch2 = self.branch2(x) branch3 = self.branch3(x) branch4 = self.branch4(x) outputs = [branch1, branch2, branch3, branch4] return outputs def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return torch.cat(outputs, 1) class InceptionAux(nn.Module): def __init__( self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None, dropout: float = 0.7, ) -> None: super().__init__() if conv_block is None: conv_block = BasicConv2d self.conv = conv_block(in_channels, 128, kernel_size=1) self.fc1 = nn.Linear(2048, 1024) self.fc2 = nn.Linear(1024, num_classes) self.dropout = nn.Dropout(p=dropout) def forward(self, x: Tensor) -> Tensor: # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 x = F.adaptive_avg_pool2d(x, (4, 4)) # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 x = self.conv(x) # N x 128 x 4 x 4 x = torch.flatten(x, 1) # N x 2048 x = F.relu(self.fc1(x), inplace=True) # N x 1024 x = self.dropout(x) # N x 1024 x = self.fc2(x) # N x 1000 (num_classes) return x class BasicConv2d(nn.Module): def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.bn = nn.BatchNorm2d(out_channels, eps=0.001) def forward(self, x: Tensor) -> Tensor: x = self.conv(x) x = self.bn(x) return F.relu(x, inplace=True) def googlenet(progress: bool = True, **kwargs: Any) -> GoogLeNet: r"""GoogLeNet (Inception v1) model architecture from `"Going Deeper with Convolutions" `_. The required minimum input size of the model is 15x15. Args: weights (GoogLeNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, adds two auxiliary branches that can improve training. Default: *False* when pretrained is True otherwise *True* transform_input (bool): If True, preprocesses the input according to the method with which it was trained on ImageNet. Default: True if ``weights=GoogLeNet_Weights.IMAGENET1K_V1``, else False. """ model = GoogLeNet(**kwargs) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_inception_v3.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn.functional as F from torch import nn, Tensor import warnings from typing import Callable, Any, Optional, Tuple, List __all__ = ["Inception3", "inception_v3"] class Inception3(nn.Module): def __init__( self, num_classes: int = 1000, aux_logits: bool = True, transform_input: bool = False, inception_blocks: Optional[List[Callable[..., nn.Module]]] = None, init_weights: Optional[bool] = None, dropout: float = 0.5, ) -> None: super().__init__() if inception_blocks is None: inception_blocks = [ BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux, ] if init_weights is None: warnings.warn( "The default weight initialization of inception_v3 will be changed in future releases of " "torchvision. If you wish to keep the old behavior (which leads to long initialization times" " due to scipy/scipy#11299), please set init_weights=True.", FutureWarning, ) init_weights = True if len(inception_blocks) != 7: raise ValueError( f"lenght of inception_blocks should be 7 instead of {len(inception_blocks)}" ) conv_block = inception_blocks[0] inception_a = inception_blocks[1] inception_b = inception_blocks[2] inception_c = inception_blocks[3] inception_d = inception_blocks[4] inception_e = inception_blocks[5] inception_aux = inception_blocks[6] self.aux_logits = aux_logits self.transform_input = transform_input self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) self.Mixed_5b = inception_a(192, pool_features=32) self.Mixed_5c = inception_a(256, pool_features=64) self.Mixed_5d = inception_a(288, pool_features=64) self.Mixed_6a = inception_b(288) self.Mixed_6b = inception_c(768, channels_7x7=128) self.Mixed_6c = inception_c(768, channels_7x7=160) self.Mixed_6d = inception_c(768, channels_7x7=160) self.Mixed_6e = inception_c(768, channels_7x7=192) self.AuxLogits: Optional[nn.Module] = None if aux_logits: self.AuxLogits = inception_aux(768, num_classes) self.Mixed_7a = inception_d(768) self.Mixed_7b = inception_e(1280) self.Mixed_7c = inception_e(2048) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.dropout = nn.Dropout(p=dropout) self.fc = nn.Linear(2048, num_classes) if init_weights: for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore torch.nn.init.trunc_normal_( m.weight, mean=0.0, std=stddev, a=-2, b=2 ) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def _transform_input(self, x: Tensor) -> Tensor: if self.transform_input: x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 x = torch.cat((x_ch0, x_ch1, 
x_ch2), 1) return x def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]: # N x 3 x 299 x 299 x = self.Conv2d_1a_3x3(x) # N x 32 x 149 x 149 x = self.Conv2d_2a_3x3(x) # N x 32 x 147 x 147 x = self.Conv2d_2b_3x3(x) # N x 64 x 147 x 147 x = self.maxpool1(x) # N x 64 x 73 x 73 x = self.Conv2d_3b_1x1(x) # N x 80 x 73 x 73 x = self.Conv2d_4a_3x3(x) # N x 192 x 71 x 71 x = self.maxpool2(x) # N x 192 x 35 x 35 x = self.Mixed_5b(x) # N x 256 x 35 x 35 x = self.Mixed_5c(x) # N x 288 x 35 x 35 x = self.Mixed_5d(x) # N x 288 x 35 x 35 x = self.Mixed_6a(x) # N x 768 x 17 x 17 x = self.Mixed_6b(x) # N x 768 x 17 x 17 x = self.Mixed_6c(x) # N x 768 x 17 x 17 x = self.Mixed_6d(x) # N x 768 x 17 x 17 x = self.Mixed_6e(x) # N x 768 x 17 x 17 aux: Optional[Tensor] = None if self.AuxLogits is not None: if self.training: aux = self.AuxLogits(x) # N x 768 x 17 x 17 x = self.Mixed_7a(x) # N x 1280 x 8 x 8 x = self.Mixed_7b(x) # N x 2048 x 8 x 8 x = self.Mixed_7c(x) # N x 2048 x 8 x 8 # Adaptive average pooling x = self.avgpool(x) # N x 2048 x 1 x 1 x = self.dropout(x) # N x 2048 x 1 x 1 x = torch.flatten(x, 1) # N x 2048 x = self.fc(x) # N x 1000 (num_classes) return x, aux def forward(self, x: Tensor): x = self._transform_input(x) x, aux = self._forward(x) return x class InceptionA(nn.Module): def __init__( self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None, ) -> None: super().__init__() if conv_block is None: conv_block = BasicConv2d self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) def _forward(self, x: Tensor) -> List[Tensor]: branch1x1 = self.branch1x1(x) branch5x5 = self.branch5x5_1(x) branch5x5 = self.branch5x5_2(branch5x5) branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] return outputs def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return torch.cat(outputs, 1) class InceptionB(nn.Module): def __init__( self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super().__init__() if conv_block is None: conv_block = BasicConv2d self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) def _forward(self, x: Tensor) -> List[Tensor]: branch3x3 = self.branch3x3(x) branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) outputs = [branch3x3, branch3x3dbl, branch_pool] return outputs def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return torch.cat(outputs, 1) class InceptionC(nn.Module): def __init__( self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = 
None, ) -> None: super().__init__() if conv_block is None: conv_block = BasicConv2d self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) c7 = channels_7x7 self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) self.branch_pool = conv_block(in_channels, 192, kernel_size=1) def _forward(self, x: Tensor) -> List[Tensor]: branch1x1 = self.branch1x1(x) branch7x7 = self.branch7x7_1(x) branch7x7 = self.branch7x7_2(branch7x7) branch7x7 = self.branch7x7_3(branch7x7) branch7x7dbl = self.branch7x7dbl_1(x) branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] return outputs def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return torch.cat(outputs, 1) class InceptionD(nn.Module): def __init__( self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super().__init__() if conv_block is None: conv_block = BasicConv2d self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) def _forward(self, x: Tensor) -> List[Tensor]: branch3x3 = self.branch3x3_1(x) branch3x3 = self.branch3x3_2(branch3x3) branch7x7x3 = self.branch7x7x3_1(x) branch7x7x3 = self.branch7x7x3_2(branch7x7x3) branch7x7x3 = self.branch7x7x3_3(branch7x7x3) branch7x7x3 = self.branch7x7x3_4(branch7x7x3) branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) outputs = [branch3x3, branch7x7x3, branch_pool] return outputs def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return torch.cat(outputs, 1) class InceptionE(nn.Module): def __init__( self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super().__init__() if conv_block is None: conv_block = BasicConv2d self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) self.branch_pool = conv_block(in_channels, 192, kernel_size=1) def _forward(self, x: Tensor) -> List[Tensor]: branch1x1 = self.branch1x1(x) 
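# In InceptionE, each 3x3 branch below fans out into parallel (1, 3) and
# (3, 1) convolutions whose 384-channel outputs are concatenated along the
# channel dimension before the final cat of all four branches.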
branch3x3 = self.branch3x3_1(x) branch3x3 = [ self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3), ] branch3x3 = torch.cat(branch3x3, 1) branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) branch3x3dbl = [ self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl), ] branch3x3dbl = torch.cat(branch3x3dbl, 1) branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] return outputs def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return torch.cat(outputs, 1) class InceptionAux(nn.Module): def __init__( self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None, ) -> None: super().__init__() if conv_block is None: conv_block = BasicConv2d self.conv0 = conv_block(in_channels, 128, kernel_size=1) self.conv1 = conv_block(128, 768, kernel_size=5) self.conv1.stddev = 0.01 # type: ignore[assignment] self.fc = nn.Linear(768, num_classes) self.fc.stddev = 0.001 # type: ignore[assignment] def forward(self, x: Tensor) -> Tensor: # N x 768 x 17 x 17 x = F.avg_pool2d(x, kernel_size=5, stride=3) # N x 768 x 5 x 5 x = self.conv0(x) # N x 128 x 5 x 5 x = self.conv1(x) # N x 768 x 1 x 1 # Adaptive average pooling x = F.adaptive_avg_pool2d(x, (1, 1)) # N x 768 x 1 x 1 x = torch.flatten(x, 1) # N x 768 x = self.fc(x) # N x 1000 return x class BasicConv2d(nn.Module): def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.bn = nn.BatchNorm2d(out_channels, eps=0.001) def forward(self, x: Tensor) -> Tensor: x = self.conv(x) x = self.bn(x) return F.relu(x, inplace=True) def inception_v3(progress: bool = True, **kwargs: Any) -> Inception3: r"""Inception v3 model architecture from `"Rethinking the Inception Architecture for Computer Vision" `_. The required minimum input size of the model is 75x75. .. note:: **Important**: In contrast to the other models the inception_v3 expects tensors with a size of N x 3 x 299 x 299, so ensure your images are sized accordingly. Args: weights (Inception_V3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, add an auxiliary branch that can improve training. Default: *True* transform_input (bool): If True, preprocesses the input according to the method with which it was trained on ImageNet. Default: True if ``weights=Inception_V3_Weights.IMAGENET1K_V1``, else False. """ model = Inception3(**kwargs) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_levit.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import torch import itertools from timm.models.vision_transformer import trunc_normal_ specification = { "LeViT_128S": { "C": "128_256_384", "D": 16, "N": "4_6_8", "X": "2_3_4", "drop_path": 0, "weights": "https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth", } } __all__ = ["LeViT_128S"] def LeViT_128S(num_classes=1000, distillation=False, pretrained=False, fuse=False): return model_factory( **specification["LeViT_128S"], num_classes=num_classes, distillation=distillation, pretrained=pretrained, fuse=fuse ) FLOPS_COUNTER = 0 class Conv2d_BN(torch.nn.Sequential): def __init__( self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000, ): super().__init__() self.add_module( "c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False) ) bn = torch.nn.BatchNorm2d(b) torch.nn.init.constant_(bn.weight, bn_weight_init) torch.nn.init.constant_(bn.bias, 0) self.add_module("bn", bn) global FLOPS_COUNTER output_points = ( (resolution + 2 * pad - dilation * (ks - 1) - 1) // stride + 1 ) ** 2 FLOPS_COUNTER += a * b * output_points * (ks ** 2) // groups @torch.no_grad() def fuse(self): c, bn = self._modules.values() w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 m = torch.nn.Conv2d( w.size(1) * self.c.groups, w.size(0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups, ) m.weight.data.copy_(w) m.bias.data.copy_(b) return m class Linear_BN(torch.nn.Sequential): def __init__(self, a, b, bn_weight_init=1, resolution=-100000): super().__init__() self.add_module("c", torch.nn.Linear(a, b, bias=False)) bn = torch.nn.BatchNorm1d(b) torch.nn.init.constant_(bn.weight, bn_weight_init) torch.nn.init.constant_(bn.bias, 0) self.add_module("bn", bn) global FLOPS_COUNTER output_points = resolution ** 2 FLOPS_COUNTER += a * b * output_points @torch.no_grad() def fuse(self): l, bn = self._modules.values() w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = l.weight * w[:, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 m = torch.nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m def forward(self, x): l, bn = self._modules.values() x = l(x) return bn(x.flatten(0, 1)).reshape_as(x) class BN_Linear(torch.nn.Sequential): def __init__(self, a, b, bias=True, std=0.02): super().__init__() self.add_module("bn", torch.nn.BatchNorm1d(a)) l = torch.nn.Linear(a, b, bias=bias) trunc_normal_(l.weight, std=std) if bias: torch.nn.init.constant_(l.bias, 0) self.add_module("l", l) global FLOPS_COUNTER FLOPS_COUNTER += a * b @torch.no_grad() def fuse(self): bn, l = self._modules.values() w = bn.weight / (bn.running_var + bn.eps) ** 0.5 b = ( bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 ) w = l.weight * w[None, :] if l.bias is None: b = b @ self.l.weight.T else: b = (l.weight @ b[:, None]).view(-1) + self.l.bias m = torch.nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m def b16(n, activation, resolution=224): return torch.nn.Sequential( Conv2d_BN(3, n // 8, 3, 2, 1, resolution=resolution), activation(), Conv2d_BN(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), activation(), Conv2d_BN(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), activation(), Conv2d_BN(n // 2, n, 3, 2, 1, resolution=resolution // 8), ) class Residual(torch.nn.Module): def __init__(self, 
m, drop): super().__init__() self.m = m self.drop = drop def forward(self, x): if self.training and self.drop > 0: return ( x + self.m(x) * torch.rand(x.size(0), 1, 1, device=x.device) .ge_(self.drop) .div(1 - self.drop) .detach() ) else: return x + self.m(x) class Attention(torch.nn.Module): def __init__( self, dim, key_dim, num_heads=8, attn_ratio=4, activation=None, resolution=14 ): super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim self.nh_kd = nh_kd = key_dim * num_heads self.d = int(attn_ratio * key_dim) self.dh = int(attn_ratio * key_dim) * num_heads self.attn_ratio = attn_ratio h = self.dh + nh_kd * 2 self.qkv = Linear_BN(dim, h, resolution=resolution) self.proj = torch.nn.Sequential( activation(), Linear_BN(self.dh, dim, bn_weight_init=0, resolution=resolution), ) points = list(itertools.product(range(resolution), range(resolution))) N = len(points) attention_offsets = {} idxs = [] for p1 in points: for p2 in points: offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) self.attention_biases = torch.nn.Parameter( torch.zeros(num_heads, len(attention_offsets)) ) self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N)) global FLOPS_COUNTER # queries * keys FLOPS_COUNTER += num_heads * (resolution ** 4) * key_dim # softmax FLOPS_COUNTER += num_heads * (resolution ** 4) # attention * v FLOPS_COUNTER += num_heads * self.d * (resolution ** 4) @torch.no_grad() def train(self, mode=True): super().train(mode) if mode and hasattr(self, "ab"): del self.ab else: self.ab = self.attention_biases[:, self.attention_bias_idxs] def forward(self, x): # x (B,N,C) B, N, C = x.shape qkv = self.qkv(x) q, k, v = qkv.view(B, N, self.num_heads, -1).split( [self.key_dim, self.key_dim, self.d], dim=3 ) q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) attn = (q @ k.transpose(-2, -1)) * self.scale + ( self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab ) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x class Subsample(torch.nn.Module): def __init__(self, stride, resolution): super().__init__() self.stride = stride self.resolution = resolution def forward(self, x): B, N, C = x.shape x = x.view(B, self.resolution, self.resolution, C)[ :, :: self.stride, :: self.stride ].reshape(B, -1, C) return x class AttentionSubsample(torch.nn.Module): def __init__( self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, activation=None, stride=2, resolution=14, resolution_=7, ): super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim self.nh_kd = nh_kd = key_dim * num_heads self.d = int(attn_ratio * key_dim) self.dh = int(attn_ratio * key_dim) * self.num_heads self.attn_ratio = attn_ratio self.resolution_ = resolution_ self.resolution_2 = resolution_ ** 2 h = self.dh + nh_kd self.kv = Linear_BN(in_dim, h, resolution=resolution) self.q = torch.nn.Sequential( Subsample(stride, resolution), Linear_BN(in_dim, nh_kd, resolution=resolution_), ) self.proj = torch.nn.Sequential( activation(), Linear_BN(self.dh, out_dim, resolution=resolution_) ) self.stride = stride self.resolution = resolution points = list(itertools.product(range(resolution), range(resolution))) points_ = list(itertools.product(range(resolution_), range(resolution_))) N = len(points) N_ = len(points_) 
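# Below, AttentionSubsample builds its relative-position bias table: every
# (subsampled query position, key position) pair is mapped to an offset
# bucket, attention_biases stores one learnable value per head per unique
# offset, and attention_bias_idxs (shape N_ x N) gathers those values into
# the full bias matrix added to the scaled q @ k^T scores at forward time.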
attention_offsets = {} idxs = [] for p1 in points_: for p2 in points: size = 1 offset = ( abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2), ) if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) self.attention_biases = torch.nn.Parameter( torch.zeros(num_heads, len(attention_offsets)) ) self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N_, N)) global FLOPS_COUNTER # queries * keys FLOPS_COUNTER += num_heads * (resolution ** 2) * (resolution_ ** 2) * key_dim # softmax FLOPS_COUNTER += num_heads * (resolution ** 2) * (resolution_ ** 2) # attention * v FLOPS_COUNTER += num_heads * (resolution ** 2) * (resolution_ ** 2) * self.d @torch.no_grad() def train(self, mode=True): super().train(mode) if mode and hasattr(self, "ab"): del self.ab else: self.ab = self.attention_biases[:, self.attention_bias_idxs] def forward(self, x): B, N, C = x.shape k, v = ( self.kv(x) .view(B, N, self.num_heads, -1) .split([self.key_dim, self.d], dim=3) ) k = k.permute(0, 2, 1, 3) # BHNC v = v.permute(0, 2, 1, 3) # BHNC q = ( self.q(x) .view(B, self.resolution_2, self.num_heads, self.key_dim) .permute(0, 2, 1, 3) ) attn = (q @ k.transpose(-2, -1)) * self.scale + ( self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab ) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) x = self.proj(x) return x class LeViT(torch.nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=[192], key_dim=[64], depth=[12], num_heads=[3], attn_ratio=[2], mlp_ratio=[2], hybrid_backbone=None, down_ops=[], attention_activation=torch.nn.Hardswish, mlp_activation=torch.nn.Hardswish, distillation=True, drop_path=0, ): super().__init__() global FLOPS_COUNTER self.num_classes = num_classes self.num_features = embed_dim[-1] self.embed_dim = embed_dim self.distillation = distillation self.patch_embed = hybrid_backbone self.blocks = [] down_ops.append([""]) resolution = img_size // patch_size for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops) ): for _ in range(dpth): self.blocks.append( Residual( Attention( ed, kd, nh, attn_ratio=ar, activation=attention_activation, resolution=resolution, ), drop_path, ) ) if mr > 0: h = int(ed * mr) self.blocks.append( Residual( torch.nn.Sequential( Linear_BN(ed, h, resolution=resolution), mlp_activation(), Linear_BN( h, ed, bn_weight_init=0, resolution=resolution ), ), drop_path, ) ) if do[0] == "Subsample": # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) resolution_ = (resolution - 1) // do[5] + 1 self.blocks.append( AttentionSubsample( *embed_dim[i : i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], activation=attention_activation, stride=do[5], resolution=resolution, resolution_=resolution_ ) ) resolution = resolution_ if do[4] > 0: # mlp_ratio h = int(embed_dim[i + 1] * do[4]) self.blocks.append( Residual( torch.nn.Sequential( Linear_BN(embed_dim[i + 1], h, resolution=resolution), mlp_activation(), Linear_BN( h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution, ), ), drop_path, ) ) self.blocks = torch.nn.Sequential(*self.blocks) # Classifier head self.head = ( BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() ) if distillation: self.head_dist = ( 
BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() ) self.FLOPS = FLOPS_COUNTER FLOPS_COUNTER = 0 def no_weight_decay(self): return {x for x in self.state_dict().keys() if "attention_biases" in x} def forward(self, x): x = self.patch_embed(x) x = x.flatten(2).transpose(1, 2) x = self.blocks(x) x = x.mean(1) if self.distillation: x = self.head(x), self.head_dist(x) if not self.training: x = (x[0] + x[1]) / 2 else: x = self.head(x) return x def model_factory( C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse ): embed_dim = [int(x) for x in C.split("_")] num_heads = [int(x) for x in N.split("_")] depth = [int(x) for x in X.split("_")] act = torch.nn.Hardswish model = LeViT( patch_size=16, embed_dim=embed_dim, num_heads=num_heads, key_dim=[D] * 3, depth=depth, attn_ratio=[2, 2, 2], mlp_ratio=[2, 2, 2], down_ops=[ # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) ["Subsample", D, embed_dim[0] // D, 4, 2, 2], ["Subsample", D, embed_dim[1] // D, 4, 2, 2], ], attention_activation=act, mlp_activation=act, hybrid_backbone=b16(embed_dim[0], activation=act), num_classes=num_classes, drop_path=drop_path, distillation=distillation, ) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_mnasnet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn from torch import Tensor import warnings from typing import Any, Dict, List __all__ = [ "MNASNet", "mnasnet1_0", ] # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # 1.0 - tensorflow. _BN_MOMENTUM = 1 - 0.9997 class _InvertedResidual(nn.Module): def __init__( self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1, ) -> None: super().__init__() if stride not in [1, 2]: raise ValueError(f"stride should be 1 or 2 instead of {stride}") if kernel_size not in [3, 5]: raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}") mid_ch = in_ch * expansion_factor self.apply_residual = in_ch == out_ch and stride == 1 self.layers = nn.Sequential( # Pointwise nn.Conv2d(in_ch, mid_ch, 1, bias=False), nn.BatchNorm2d(mid_ch, momentum=bn_momentum), nn.ReLU(inplace=True), # Depthwise nn.Conv2d( mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False, ), nn.BatchNorm2d(mid_ch, momentum=bn_momentum), nn.ReLU(inplace=True), # Linear pointwise. Note that there's no activation. 
nn.Conv2d(mid_ch, out_ch, 1, bias=False), nn.BatchNorm2d(out_ch, momentum=bn_momentum), ) def forward(self, input: Tensor) -> Tensor: if self.apply_residual: return self.layers(input) + input else: return self.layers(input) def _stack( in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float, ) -> nn.Sequential: """Creates a stack of inverted residuals.""" if repeats < 1: raise ValueError(f"repeats should be >= 1, instead got {repeats}") # First one has no skip, because feature map size changes. first = _InvertedResidual( in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum ) remaining = [] for _ in range(1, repeats): remaining.append( _InvertedResidual( out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum ) ) return nn.Sequential(first, *remaining) def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int: """Asymmetric rounding to make `val` divisible by `divisor`. With default bias, will round up, unless the number is no more than 10% greater than the smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88.""" if not 0.0 < round_up_bias < 1.0: raise ValueError( f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}" ) new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) return new_val if new_val >= round_up_bias * val else new_val + divisor def _get_depths(alpha: float) -> List[int]: """Scales tensor depths as in reference MobileNet code, prefers rouding up rather than down.""" depths = [32, 16, 24, 40, 80, 96, 192, 320] return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] class MNASNet(torch.nn.Module): """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This implements the B1 variant of the model. >>> model = MNASNet(1.0, num_classes=1000) >>> x = torch.rand(1, 3, 224, 224) >>> y = model(x) >>> y.dim() 2 >>> y.nelement() 1000 """ # Version 2 adds depth scaling in the initial stages of the network. _version = 2 def __init__( self, alpha: float, num_classes: int = 1000, dropout: float = 0.2 ) -> None: super().__init__() if alpha <= 0.0: raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}") self.alpha = alpha self.num_classes = num_classes depths = _get_depths(alpha) layers = [ # First layer: regular conv. nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False), nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), # Depthwise separable, no skip. nn.Conv2d( depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False, ), nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False), nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM), # MNASNet blocks: stacks of inverted residuals. _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM), _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM), _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM), _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM), _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM), _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM), # Final mapping to classifier input. 
nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False), nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), ] self.layers = nn.Sequential(*layers) self.classifier = nn.Sequential( nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes) ) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.kaiming_uniform_( m.weight, mode="fan_out", nonlinearity="sigmoid" ) nn.init.zeros_(m.bias) def forward(self, x: Tensor) -> Tensor: x = self.layers(x) # Equivalent to global avgpool and removing H and W dimensions. x = x.mean([2, 3]) return self.classifier(x) def _mnasnet(alpha: float, progress: bool, **kwargs: Any) -> MNASNet: model = MNASNet(alpha, **kwargs) return model def mnasnet1_0(progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 1.0 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: weights (MNASNet1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ return _mnasnet(1.0, progress, **kwargs) ================================================ FILE: python/oneflow/test/expensive/pytorch_poolformer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn from timm.models.layers import DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple import os import copy class PatchEmbed(nn.Module): """ Patch Embedding that is implemented by a layer of conv. Input: tensor in shape [B, C, H, W] Output: tensor in shape [B, C, H/stride, W/stride] """ def __init__( self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=None, ): super().__init__() patch_size = to_2tuple(patch_size) stride = to_2tuple(stride) padding = to_2tuple(padding) self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding ) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): x = self.proj(x) x = self.norm(x) return x class LayerNormChannel(nn.Module): """ LayerNorm only for Channel Dimension. Input: tensor in shape [B, C, H, W] """ def __init__(self, num_channels, eps=1e-05): super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x): u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight.unsqueeze(-1).unsqueeze(-1) * x + self.bias.unsqueeze( -1 ).unsqueeze(-1) return x class GroupNorm(nn.GroupNorm): """ Group Normalization with 1 group. 
Input: tensor in shape [B, C, H, W] """ def __init__(self, num_channels, **kwargs): super().__init__(1, num_channels, **kwargs) class Pooling(nn.Module): """ Implementation of pooling for PoolFormer --pool_size: pooling size """ def __init__(self, pool_size=3): super().__init__() self.pool = nn.AvgPool2d( pool_size, stride=1, padding=pool_size // 2, count_include_pad=False ) def forward(self, x): return self.pool(x) - x class Mlp(nn.Module): """ Implementation of MLP with 1*1 convolutions. Input: tensor with shape [B, C, H, W] """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Conv2d(in_features, hidden_features, 1) self.act = act_layer() self.fc2 = nn.Conv2d(hidden_features, out_features, 1) self.drop = nn.Dropout(drop) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class PoolFormerBlock(nn.Module): """ Implementation of one PoolFormer block. --dim: embedding dim --pool_size: pooling size --mlp_ratio: mlp expansion ratio --act_layer: activation --norm_layer: normalization --drop: dropout rate --drop path: Stochastic Depth, refer to https://arxiv.org/abs/1603.09382 --use_layer_scale, --layer_scale_init_value: LayerScale, refer to https://arxiv.org/abs/2103.17239 """ def __init__( self, dim, pool_size=3, mlp_ratio=4.0, act_layer=nn.GELU, norm_layer=GroupNorm, drop=0.0, drop_path=0.0, use_layer_scale=True, layer_scale_init_value=1e-5, ): super().__init__() self.norm1 = norm_layer(dim) self.token_mixer = Pooling(pool_size=pool_size) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # The following two techniques are useful to train deep PoolFormers. 
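# --- Editorial sketch (not part of the original file): what LayerScale contributes here.
# Each residual branch below is multiplied by a learnable per-channel scale initialized to
# layer_scale_init_value (1e-5 by default), so a freshly initialized block is close to the
# identity map. The values are illustrative; torch is already imported at the top of this file.
example_scale = 1e-5 * torch.ones(64)               # plays the role of layer_scale_1 for dim=64
example_branch = torch.randn(2, 64, 14, 14)         # a token-mixer / MLP branch output
example_residual = example_scale.unsqueeze(-1).unsqueeze(-1) * example_branch
# x + example_residual is ~x right after initialization, which helps train deep PoolFormers.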
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale_1 = nn.Parameter( layer_scale_init_value * torch.ones((dim)), requires_grad=True ) self.layer_scale_2 = nn.Parameter( layer_scale_init_value * torch.ones((dim)), requires_grad=True ) def forward(self, x): if self.use_layer_scale: x = x + self.drop_path( self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x)) ) x = x + self.drop_path( self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)) ) else: x = x + self.drop_path(self.token_mixer(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x def basic_blocks( dim, index, layers, pool_size=3, mlp_ratio=4.0, act_layer=nn.GELU, norm_layer=GroupNorm, drop_rate=0.0, drop_path_rate=0.0, use_layer_scale=True, layer_scale_init_value=1e-5, ): """ generate PoolFormer blocks for a stage return: PoolFormer blocks """ blocks = [] for block_idx in range(layers[index]): block_dpr = ( drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) ) blocks.append( PoolFormerBlock( dim, pool_size=pool_size, mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, ) ) blocks = nn.Sequential(*blocks) return blocks class PoolFormer(nn.Module): """ PoolFormer, the main class of our model --layers: [x,x,x,x], number of blocks for the 4 stages --embed_dims, --mlp_ratios, --pool_size: the embedding dims, mlp ratios and pooling size for the 4 stages --downsamples: flags to apply downsampling or not --norm_layer, --act_layer: define the types of normalization and activation --num_classes: number of classes for the image classification --in_patch_size, --in_stride, --in_pad: specify the patch embedding for the input image --down_patch_size --down_stride --down_pad: specify the downsample (patch embed.) 
--fork_feat: whether output features of the 4 stages, for dense prediction --init_cfg, --pretrained: for mmdetection and mmsegmentation to load pretrained weights """ def __init__( self, layers, embed_dims=None, mlp_ratios=None, downsamples=None, pool_size=3, norm_layer=GroupNorm, act_layer=nn.GELU, num_classes=1000, in_patch_size=7, in_stride=4, in_pad=2, down_patch_size=3, down_stride=2, down_pad=1, drop_rate=0.0, drop_path_rate=0.0, use_layer_scale=True, layer_scale_init_value=1e-5, fork_feat=False, init_cfg=None, pretrained=None, **kwargs, ): super().__init__() if not fork_feat: self.num_classes = num_classes self.fork_feat = fork_feat self.patch_embed = PatchEmbed( patch_size=in_patch_size, stride=in_stride, padding=in_pad, in_chans=3, embed_dim=embed_dims[0], ) # set the main block in network network = [] for i in range(len(layers)): stage = basic_blocks( embed_dims[i], i, layers, pool_size=pool_size, mlp_ratio=mlp_ratios[i], act_layer=act_layer, norm_layer=norm_layer, drop_rate=drop_rate, drop_path_rate=drop_path_rate, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, ) network.append(stage) if i >= len(layers) - 1: break if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: # downsampling between two stages network.append( PatchEmbed( patch_size=down_patch_size, stride=down_stride, padding=down_pad, in_chans=embed_dims[i], embed_dim=embed_dims[i + 1], ) ) self.network = nn.ModuleList(network) if self.fork_feat: # add a norm layer for each output self.out_indices = [0, 2, 4, 6] for i_emb, i_layer in enumerate(self.out_indices): if i_emb == 0 and os.environ.get("FORK_LAST3", None): # TODO: more elegant way """For RetinaNet, `start_level=1`. The first norm layer will not used. cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...` """ layer = nn.Identity() else: layer = norm_layer(embed_dims[i_emb]) layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) else: # Classifier head self.norm = norm_layer(embed_dims[-1]) self.head = ( nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self.cls_init_weights) self.init_cfg = copy.deepcopy(init_cfg) # load pre-trained model if self.fork_feat and (self.init_cfg is not None or pretrained is not None): self.init_weights() # init for classification def cls_init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) def get_classifier(self): return self.head def reset_classifier(self, num_classes): self.num_classes = num_classes self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) def forward_embeddings(self, x): x = self.patch_embed(x) return x def forward_tokens(self, x): outs = [] for idx, block in enumerate(self.network): x = block(x) if self.fork_feat and idx in self.out_indices: norm_layer = getattr(self, f"norm{idx}") x_out = norm_layer(x) outs.append(x_out) if self.fork_feat: # output the features of four stages for dense prediction return outs # output only the features of last layer for image classification return x def forward(self, x): # input embedding x = self.forward_embeddings(x) # through backbone x = self.forward_tokens(x) if self.fork_feat: # otuput features of four stages for dense prediction return x x = self.norm(x) cls_out = self.head(x.mean([-2, -1])) # for image classification return cls_out def poolformer_s12(pretrained=False, **kwargs): """ PoolFormer-S12 model, Params: 12M --layers: 
[x,x,x,x], numbers of layers for the four stages --embed_dims, --mlp_ratios: embedding dims and mlp ratios for the four stages --downsamples: flags to apply downsampling or not in four blocks """ layers = [2, 2, 6, 2] embed_dims = [64, 128, 320, 512] mlp_ratios = [4, 4, 4, 4] downsamples = [True, True, True, True] model = PoolFormer( layers, embed_dims=embed_dims, mlp_ratios=mlp_ratios, downsamples=downsamples, **kwargs, ) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_pvt.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn import torch.nn.functional as F from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from functools import partial __all__ = ["pvt_tiny"] class Mlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1, ): super().__init__() assert ( dim % num_heads == 0 ), f"dim {dim} should be divided by num_heads {num_heads}." 
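# --- Editorial sketch (not part of the original file): cost of spatial-reduction attention.
# When sr_ratio > 1, keys and values are computed from a feature map downsampled by a
# stride-sr_ratio conv (self.sr below), so the attended sequence shrinks by sr_ratio**2.
# Numbers assume pvt_tiny's first stage on a 224x224 input.
fm_h = fm_w = 56                                    # stage-1 feature map size
reduction = 8                                       # sr_ratios[0]
n_q = fm_h * fm_w                                   # 3136 query tokens
n_kv = (fm_h // reduction) * (fm_w // reduction)    # 49 key/value tokens after reduction
# attention matrix per head: n_q * n_kv = 153,664 entries vs. n_q * n_q = 9,834,496 without it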
self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.sr_ratio = sr_ratio if sr_ratio > 1: self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) self.norm = nn.LayerNorm(dim) def forward(self, x, H, W): B, N, C = x.shape q = ( self.q(x) .reshape(B, N, self.num_heads, C // self.num_heads) .permute(0, 2, 1, 3) ) if self.sr_ratio > 1: x_ = x.permute(0, 2, 1).reshape(B, C, H, W) x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) x_ = self.norm(x_) kv = ( self.kv(x_) .reshape(B, -1, 2, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) else: kv = ( self.kv(x) .reshape(B, -1, 2, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) k, v = kv[0], kv[1] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, ) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) def forward(self, x, H, W): x = x + self.drop_path(self.attn(self.norm1(x), H, W)) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ # f"img_size {img_size} should be divided by patch_size {patch_size}." 
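# --- Editorial sketch (not part of the original file): the patch grid computed just below.
# For pvt_tiny's first stage (img_size=224, patch_size=4) the image becomes a 56x56 grid of
# tokens; later stages use patch_size=2 on the previous stage's output.
example_grid = (224 // 4, 224 // 4)                       # (56, 56)
example_num_patches = example_grid[0] * example_grid[1]   # 3136 tokens enter stage 1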
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] self.num_patches = self.H * self.W self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size ) self.norm = nn.LayerNorm(embed_dim) def forward(self, x): B, C, H, W = x.shape x = self.proj(x).flatten(2).transpose(1, 2) x = self.norm(x) H, W = H // self.patch_size[0], W // self.patch_size[1] return x, (H, W) class PyramidVisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, ): super().__init__() self.num_classes = num_classes self.depths = depths self.num_stages = num_stages dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) ] # stochastic depth decay rule cur = 0 for i in range(num_stages): patch_embed = PatchEmbed( img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), patch_size=patch_size if i == 0 else 2, in_chans=in_chans if i == 0 else embed_dims[i - 1], embed_dim=embed_dims[i], ) num_patches = ( patch_embed.num_patches if i != num_stages - 1 else patch_embed.num_patches + 1 ) pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i])) pos_drop = nn.Dropout(p=drop_rate) block = nn.ModuleList( [ Block( dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, sr_ratio=sr_ratios[i], ) for j in range(depths[i]) ] ) cur += depths[i] setattr(self, f"patch_embed{i + 1}", patch_embed) setattr(self, f"pos_embed{i + 1}", pos_embed) setattr(self, f"pos_drop{i + 1}", pos_drop) setattr(self, f"block{i + 1}", block) self.norm = norm_layer(embed_dims[3]) # cls_token self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3])) # classification head self.head = ( nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() ) # init weights for i in range(num_stages): pos_embed = getattr(self, f"pos_embed{i + 1}") trunc_normal_(pos_embed, std=0.02) trunc_normal_(self.cls_token, std=0.02) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def no_weight_decay(self): # return {'pos_embed', 'cls_token'} # has pos_embed may be better return {"cls_token"} def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=""): self.num_classes = num_classes self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) def _get_pos_embed(self, pos_embed, patch_embed, H, W): if H * W == self.patch_embed1.num_patches: return pos_embed else: return ( F.interpolate( pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute( 0, 3, 1, 2 ), size=(H, W), mode="bilinear", ) .reshape(1, -1, H * W) .permute(0, 2, 1) ) def forward_features(self, x): B = x.shape[0] for i in range(self.num_stages): patch_embed = getattr(self, f"patch_embed{i + 1}") pos_embed = getattr(self, f"pos_embed{i + 1}") pos_drop = getattr(self, f"pos_drop{i + 1}") block = getattr(self, f"block{i + 1}") x, (H, W) = patch_embed(x) if i == self.num_stages 
- 1: cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W) pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1) else: pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W) x = pos_drop(x + pos_embed) for blk in block: x = blk(x, H, W) if i != self.num_stages - 1: x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() x = self.norm(x) return x[:, 0] def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def pvt_tiny(pretrained=False, **kwargs): model = PyramidVisionTransformer( patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], **kwargs, ) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_res2net.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch.nn as nn import torch import torch.nn.functional as F import math __all__ = ["Res2Net", "res2net50"] class Bottle2neck(nn.Module): expansion = 4 def __init__( self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype="normal", ): """ Constructor Args: inplanes: input channel dimensionality planes: output channel dimensionality stride: conv stride. Replaces pooling layer. downsample: None when stride = 1 baseWidth: basic width of conv3x3 scale: number of scale. type: 'normal': normal set. 'stage': first block of a new stage. 
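Example (editorial sketch, not part of the original docstring; the default 'normal' block keeps the spatial size and maps inplanes -> planes * expansion): >>> import torch >>> block = Bottle2neck(256, 64, baseWidth=26, scale=4) >>> block(torch.randn(1, 256, 56, 56)).shape torch.Size([1, 256, 56, 56])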
""" super(Bottle2neck, self).__init__() width = int(math.floor(planes * (baseWidth / 64.0))) self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(width * scale) if scale == 1: self.nums = 1 else: self.nums = scale - 1 if stype == "stage": self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) convs = [] bns = [] for i in range(self.nums): convs.append( nn.Conv2d( width, width, kernel_size=3, stride=stride, padding=1, bias=False ) ) bns.append(nn.BatchNorm2d(width)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) self.conv3 = nn.Conv2d( width * scale, planes * self.expansion, kernel_size=1, bias=False ) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stype = stype self.scale = scale self.width = width def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) spx = torch.split(out, self.width, 1) for i in range(self.nums): if i == 0 or self.stype == "stage": sp = spx[i] else: sp = sp + spx[i] sp = self.convs[i](sp) sp = self.relu(self.bns[i](sp)) if i == 0: out = sp else: out = torch.cat((out, sp), 1) if self.scale != 1 and self.stype == "normal": out = torch.cat((out, spx[self.nums]), 1) elif self.scale != 1 and self.stype == "stage": out = torch.cat((out, self.pool(spx[self.nums])), 1) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class Res2Net(nn.Module): def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): self.inplanes = 64 super(Res2Net, self).__init__() self.baseWidth = baseWidth self.scale = scale self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) self.layer4 = self._make_layer(block, 512, layers[3], stride=2) self.avgpool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d( self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, ), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append( block( self.inplanes, planes, stride, downsample=downsample, stype="stage", baseWidth=self.baseWidth, scale=self.scale, ) ) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append( block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale) ) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.fc(x) return x def res2net50(pretrained=False, **kwargs): """Constructs a Res2Net-50 model. Res2Net-50 refers to the Res2Net-50_26w_4s. 
Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_resmlp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn from timm.models.layers import trunc_normal_, DropPath, to_2tuple __all__ = ["resmlp_12"] class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__( self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size ) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): B, C, H, W = x.shape assert ( H == self.img_size[0] ), f"Input image height ({H}) doesn't match model ({self.img_size[0]})." assert ( W == self.img_size[1] ), f"Input image width ({W}) doesn't match model ({self.img_size[1]})." x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x class Affine(nn.Module): def __init__(self, dim): super().__init__() self.alpha = nn.Parameter(torch.ones(dim)) self.beta = nn.Parameter(torch.zeros(dim)) def forward(self, x): return self.alpha * x + self.beta class layers_scale_mlp_blocks(nn.Module): def __init__( self, dim, drop=0.0, drop_path=0.0, act_layer=nn.GELU, init_values=1e-4, num_patches=196, ): super().__init__() self.norm1 = Affine(dim) self.attn = nn.Linear(num_patches, num_patches) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = Affine(dim) self.mlp = Mlp( in_features=dim, hidden_features=int(4.0 * dim), act_layer=act_layer, drop=drop, ) self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) def forward(self, x): x = x + self.drop_path( self.gamma_1 *
self.attn(self.norm1(x).transpose(1, 2)).transpose(1, 2) ) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class resmlp_models(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, drop_rate=0.0, Patch_layer=PatchEmbed, act_layer=nn.GELU, drop_path_rate=0.0, init_scale=1e-4, ): super().__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim self.patch_embed = Patch_layer( img_size=img_size, patch_size=patch_size, in_chans=int(in_chans), embed_dim=embed_dim, ) num_patches = self.patch_embed.num_patches dpr = [drop_path_rate for i in range(depth)] self.blocks = nn.ModuleList( [ layers_scale_mlp_blocks( dim=embed_dim, drop=drop_rate, drop_path=dpr[i], act_layer=act_layer, init_values=init_scale, num_patches=num_patches, ) for i in range(depth) ] ) self.norm = Affine(embed_dim) self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")] self.head = ( nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=""): self.num_classes = num_classes self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x) for i, blk in enumerate(self.blocks): x = blk(x) x = self.norm(x) x = x.mean(dim=1).reshape(B, 1, -1) return x[:, 0] def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def resmlp_12(pretrained=False, dist=False, **kwargs): model = resmlp_models( patch_size=16, embed_dim=384, depth=12, Patch_layer=PatchEmbed, init_scale=0.1, **kwargs, ) return model ================================================ FILE: python/oneflow/test/expensive/pytorch_resnet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import torch from torch import Tensor import torch.nn as nn from typing import Type, Any, Callable, Union, List, Optional __all__ = [ "ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", ] def conv3x3( in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 ) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation, ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion: int = 1 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Bottleneck(nn.Module): # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) # while original implementation places the stride at the first 1x1 convolution(self.conv1) # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
expansion: int = 4 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( "replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation) ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer( block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] ) self.layer3 = self._make_layer( block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] ) self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] def _make_layer( self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False, ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append( block( self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, ) ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block( self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, ) ) return nn.Sequential(*layers) def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def _resnet( arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], pretrained: bool, progress: bool, **kwargs: Any ) -> ResNet: model = ResNet(block, layers, **kwargs) return model def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet( "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs ) def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_. 
Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet( "resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs ) def resnext50_32x4d( pretrained: bool = False, progress: bool = True, **kwargs: Any ) -> ResNet: r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs["groups"] = 32 kwargs["width_per_group"] = 4 return _resnet( "resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs ) def resnext101_32x8d( pretrained: bool = False, progress: bool = True, **kwargs: Any ) -> ResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs["groups"] = 32 kwargs["width_per_group"] = 8 return _resnet( "resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs ) def wide_resnet50_2( pretrained: bool = False, progress: bool = True, **kwargs: Any ) -> ResNet: r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs["width_per_group"] = 64 * 2 return _resnet( "wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs ) def wide_resnet101_2( pretrained: bool = False, progress: bool = True, **kwargs: Any ) -> ResNet: r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs["width_per_group"] = 64 * 2 return _resnet( "wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs ) ================================================ FILE: python/oneflow/test/expensive/pytorch_rexnet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import torch import torch.nn as nn import math __all__ = [ "ReXNetV1", "rexnetv1_1_0", ] def silu(x, inplace=False): return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) class SiLU(nn.Module): def __init__(self, inplace=True): super(SiLU, self).__init__() self.inplace = inplace def forward(self, x): return silu(x, self.inplace) def ConvBNAct( out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1, active=True, relu6=False, ): out.append( nn.Conv2d( in_channels, channels, kernel, stride, pad, groups=num_group, bias=False ) ) out.append(nn.BatchNorm2d(channels)) if active: out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True)) def ConvBNSiLU(out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1): out.append( nn.Conv2d( in_channels, channels, kernel, stride, pad, groups=num_group, bias=False ) ) out.append(nn.BatchNorm2d(channels)) out.append(SiLU(inplace=True)) class SE(nn.Module): def __init__(self, in_channels, channels, se_ratio=12): super(SE, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Conv2d(in_channels, channels // se_ratio, kernel_size=1, padding=0), nn.BatchNorm2d(channels // se_ratio), nn.ReLU(inplace=True), nn.Conv2d(channels // se_ratio, channels, kernel_size=1, padding=0), nn.Sigmoid(), ) def forward(self, x): y = self.avg_pool(x) y = self.fc(y) return x * y class LinearBottleneck(nn.Module): def __init__( self, in_channels, channels, t, stride, use_se=True, se_ratio=12, **kwargs ): super(LinearBottleneck, self).__init__(**kwargs) self.use_shortcut = stride == 1 and in_channels <= channels self.in_channels = in_channels self.out_channels = channels out = [] if t != 1: dw_channels = in_channels * t ConvBNSiLU(out, in_channels=in_channels, channels=dw_channels) else: dw_channels = in_channels ConvBNAct( out, in_channels=dw_channels, channels=dw_channels, kernel=3, stride=stride, pad=1, num_group=dw_channels, active=False, ) if use_se: out.append(SE(dw_channels, dw_channels, se_ratio)) out.append(nn.ReLU6()) ConvBNAct( out, in_channels=dw_channels, channels=channels, active=False, relu6=True ) self.out = nn.Sequential(*out) def forward(self, x): out = self.out(x) if self.use_shortcut: out[:, 0 : self.in_channels] += x return out class ReXNetV1(nn.Module): def __init__( self, input_ch=16, final_ch=180, width_mult=1.0, depth_mult=1.0, classes=1000, use_se=True, se_ratio=12, dropout_ratio=0.2, bn_momentum=0.9, ): super(ReXNetV1, self).__init__() layers = [1, 2, 2, 3, 3, 5] strides = [1, 2, 2, 2, 1, 2] use_ses = [False, False, True, True, True, True] layers = [math.ceil(element * depth_mult) for element in layers] strides = sum( [ [element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides) ], [], ) if use_se: use_ses = sum( [[element] * layers[idx] for idx, element in enumerate(use_ses)], [] ) else: use_ses = [False] * sum(layers[:]) ts = [1] * layers[0] + [6] * sum(layers[1:]) self.depth = sum(layers[:]) * 3 stem_channel = 32 / width_mult if width_mult < 1.0 else 32 inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch features = [] in_channels_group = [] channels_group = [] # The following channel configuration is a simple instance to make each layer become an expand layer. 
for i in range(self.depth // 3): if i == 0: in_channels_group.append(int(round(stem_channel * width_mult))) channels_group.append(int(round(inplanes * width_mult))) else: in_channels_group.append(int(round(inplanes * width_mult))) inplanes += final_ch / (self.depth // 3 * 1.0) channels_group.append(int(round(inplanes * width_mult))) ConvBNSiLU( features, 3, int(round(stem_channel * width_mult)), kernel=3, stride=2, pad=1, ) for block_idx, (in_c, c, t, s, se) in enumerate( zip(in_channels_group, channels_group, ts, strides, use_ses) ): features.append( LinearBottleneck( in_channels=in_c, channels=c, t=t, stride=s, use_se=se, se_ratio=se_ratio, ) ) pen_channels = int(1280 * width_mult) ConvBNSiLU(features, c, pen_channels) features.append(nn.AdaptiveAvgPool2d(1)) self.features = nn.Sequential(*features) self.output = nn.Sequential( nn.Dropout(dropout_ratio), nn.Conv2d(pen_channels, classes, 1, bias=True) ) def extract_features(self, x): return self.features[:-1](x) def forward(self, x): x = self.features(x) x = self.output(x).flatten(1) return x def _create_rexnetv1(arch, pretrained=False, progress=True, **model_kwargs): model = ReXNetV1(**model_kwargs) return model def rexnetv1_1_0(pretrained=False, progress=True, **kwargs): """ Constructs the ReXNet model with width multiplier of 1.0. .. note:: ReXNet model with width multiplier of 1.0 from the `Rethinking Channel Dimensions for Efficient Model Design `_ paper. Args: pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` """ model_kwargs = dict(width_mult=1.0, **kwargs) return _create_rexnetv1( "rexnetv1_1_0", pretrained=pretrained, progress=progress, **model_kwargs ) ================================================ FILE: python/oneflow/test/expensive/pytorch_rexnetv1_lite.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn from math import ceil __all__ = [ "ReXNetV1_lite", "rexnet_lite_1_0", ] def _make_divisible(channel_size, divisor=None, min_value=None): """ This function is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by 8 It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py """ if not divisor: return channel_size if min_value is None: min_value = divisor new_channel_size = max( min_value, int(channel_size + divisor / 2) // divisor * divisor ) # Make sure that round down does not go down by more than 10%. 
if new_channel_size < 0.9 * channel_size: new_channel_size += divisor return new_channel_size def _add_conv( out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1, active=True, relu6=True, bn_momentum=0.1, bn_eps=1e-5, ): out.append( nn.Conv2d( in_channels, channels, kernel, stride, pad, groups=num_group, bias=False ) ) out.append(nn.BatchNorm2d(channels, momentum=bn_momentum, eps=bn_eps)) if active: out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True)) class LinearBottleneck(nn.Module): def __init__( self, in_channels, channels, t, kernel_size=3, stride=1, bn_momentum=0.1, bn_eps=1e-5, **kwargs ): super(LinearBottleneck, self).__init__(**kwargs) self.conv_shortcut = None self.use_shortcut = stride == 1 and in_channels <= channels self.in_channels = in_channels self.out_channels = channels out = [] if t != 1: dw_channels = in_channels * t _add_conv( out, in_channels=in_channels, channels=dw_channels, bn_momentum=bn_momentum, bn_eps=bn_eps, ) else: dw_channels = in_channels _add_conv( out, in_channels=dw_channels, channels=dw_channels * 1, kernel=kernel_size, stride=stride, pad=(kernel_size // 2), num_group=dw_channels, bn_momentum=bn_momentum, bn_eps=bn_eps, ) _add_conv( out, in_channels=dw_channels, channels=channels, active=False, bn_momentum=bn_momentum, bn_eps=bn_eps, ) self.out = nn.Sequential(*out) def forward(self, x): out = self.out(x) if self.use_shortcut: out[:, 0 : self.in_channels] += x return out class ReXNetV1_lite(nn.Module): def __init__( self, fix_head_stem=False, divisible_value=8, input_ch=16, final_ch=164, multiplier=1.0, classes=1000, dropout_ratio=0.2, bn_momentum=0.1, bn_eps=1e-5, kernel_conf="333333", ): super(ReXNetV1_lite, self).__init__() layers = [1, 2, 2, 3, 3, 5] strides = [1, 2, 2, 2, 1, 2] kernel_sizes = [int(element) for element in kernel_conf] strides = sum( [ [element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides) ], [], ) ts = [1] * layers[0] + [6] * sum(layers[1:]) kernel_sizes = sum( [[element] * layers[idx] for idx, element in enumerate(kernel_sizes)], [] ) self.num_convblocks = sum(layers[:]) features = [] inplanes = input_ch / multiplier if multiplier < 1.0 else input_ch first_channel = 32 / multiplier if multiplier < 1.0 or fix_head_stem else 32 first_channel = _make_divisible( int(round(first_channel * multiplier)), divisible_value ) in_channels_group = [] channels_group = [] _add_conv( features, 3, first_channel, kernel=3, stride=2, pad=1, bn_momentum=bn_momentum, bn_eps=bn_eps, ) for i in range(self.num_convblocks): inplanes_divisible = _make_divisible( int(round(inplanes * multiplier)), divisible_value ) if i == 0: in_channels_group.append(first_channel) channels_group.append(inplanes_divisible) else: in_channels_group.append(inplanes_divisible) inplanes += final_ch / (self.num_convblocks - 1 * 1.0) inplanes_divisible = _make_divisible( int(round(inplanes * multiplier)), divisible_value ) channels_group.append(inplanes_divisible) for block_idx, (in_c, c, t, k, s) in enumerate( zip(in_channels_group, channels_group, ts, kernel_sizes, strides) ): features.append( LinearBottleneck( in_channels=in_c, channels=c, t=t, kernel_size=k, stride=s, bn_momentum=bn_momentum, bn_eps=bn_eps, ) ) pen_channels = ( int(1280 * multiplier) if multiplier > 1 and not fix_head_stem else 1280 ) _add_conv(features, c, pen_channels, bn_momentum=bn_momentum, bn_eps=bn_eps) self.features = nn.Sequential(*features) self.avgpool = nn.AdaptiveAvgPool2d(1) self.output = nn.Sequential( nn.Conv2d(pen_channels, 1024, 1, 
bias=True), nn.BatchNorm2d(1024, momentum=bn_momentum, eps=bn_eps), nn.ReLU6(inplace=True), nn.Dropout(dropout_ratio), nn.Conv2d(1024, classes, 1, bias=True), ) def forward(self, x): x = self.features(x) x = self.avgpool(x) x = self.output(x).flatten(1) return x def _create_rexnet_lite(arch, pretrained=False, progress=True, **model_kwargs): model = ReXNetV1_lite(**model_kwargs) return model def rexnet_lite_1_0(pretrained=False, progress=True, **kwargs): """ Constructs the ReXNet-lite model with width multiplier of 1.0. .. note:: ReXNet-lite model with width multiplier of 1.0 from the `Rethinking Channel Dimensions for Efficient Model Design `_ paper. Args: pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` For example: .. code-block:: python >>> import flowvision >>> rexnet_lite_1_0 = flowvision.models.rexnet_lite_1_0(pretrained=False, progress=True) """ model_kwargs = dict(multiplier=1.0, **kwargs) return _create_rexnet_lite( "rexnet_lite_1_0", pretrained=pretrained, progress=progress, **model_kwargs ) ================================================ FILE: python/oneflow/test/expensive/pytorch_senet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from __future__ import print_function, division, absolute_import from collections import OrderedDict import math import torch.nn as nn __all__ = ["SENet", "senet154"] class SEModule(nn.Module): def __init__(self, channels, reduction): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0) self.relu = nn.ReLU(inplace=True) self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0) self.sigmoid = nn.Sigmoid() def forward(self, x): module_input = x x = self.avg_pool(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) x = self.sigmoid(x) return module_input * x class Bottleneck(nn.Module): """ Base class for bottlenecks that implements `forward()` method. """ def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out = self.se_module(out) + residual out = self.relu(out) return out class SEBottleneck(Bottleneck): """ Bottleneck for SENet154. 
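Example (editorial sketch, not part of the original docstring; with matching input channels no downsample is needed): >>> import torch >>> block = SEBottleneck(256, 64, groups=64, reduction=16) >>> block(torch.randn(1, 256, 56, 56)).shape torch.Size([1, 256, 56, 56])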
""" expansion = 4 def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): super(SEBottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes * 2) self.conv2 = nn.Conv2d( planes * 2, planes * 4, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False, ) self.bn2 = nn.BatchNorm2d(planes * 4) self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.se_module = SEModule(planes * 4, reduction=reduction) self.downsample = downsample self.stride = stride class SEResNetBottleneck(Bottleneck): """ ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe implementation and uses `stride=stride` in `conv1` and not in `conv2` (the latter is used in the torchvision implementation of ResNet). """ expansion = 4 def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): super(SEResNetBottleneck, self).__init__() self.conv1 = nn.Conv2d( inplanes, planes, kernel_size=1, bias=False, stride=stride ) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, padding=1, groups=groups, bias=False ) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.se_module = SEModule(planes * 4, reduction=reduction) self.downsample = downsample self.stride = stride class SEResNeXtBottleneck(Bottleneck): """ ResNeXt bottleneck type C with a Squeeze-and-Excitation module. """ expansion = 4 def __init__( self, inplanes, planes, groups, reduction, stride=1, downsample=None, base_width=4, ): super(SEResNeXtBottleneck, self).__init__() width = math.floor(planes * (base_width / 64)) * groups self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1) self.bn1 = nn.BatchNorm2d(width) self.conv2 = nn.Conv2d( width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False, ) self.bn2 = nn.BatchNorm2d(width) self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.se_module = SEModule(planes * 4, reduction=reduction) self.downsample = downsample self.stride = stride class SENet(nn.Module): def __init__( self, block, layers, groups, reduction, dropout_p=0.2, inplanes=128, input_3x3=True, downsample_kernel_size=3, downsample_padding=1, num_classes=1000, ): """ Parameters ---------- block (nn.Module): Bottleneck class. - For SENet154: SEBottleneck - For SE-ResNet models: SEResNetBottleneck - For SE-ResNeXt models: SEResNeXtBottleneck layers (list of ints): Number of residual blocks for 4 layers of the network (layer1...layer4). groups (int): Number of groups for the 3x3 convolution in each bottleneck block. - For SENet154: 64 - For SE-ResNet models: 1 - For SE-ResNeXt models: 32 reduction (int): Reduction ratio for Squeeze-and-Excitation modules. - For all models: 16 dropout_p (float or None): Drop probability for the Dropout layer. If `None` the Dropout layer is not used. - For SENet154: 0.2 - For SE-ResNet models: None - For SE-ResNeXt models: None inplanes (int): Number of input channels for layer1. - For SENet154: 128 - For SE-ResNet models: 64 - For SE-ResNeXt models: 64 input_3x3 (bool): If `True`, use three 3x3 convolutions instead of a single 7x7 convolution in layer0. 
- For SENet154: True - For SE-ResNet models: False - For SE-ResNeXt models: False downsample_kernel_size (int): Kernel size for downsampling convolutions in layer2, layer3 and layer4. - For SENet154: 3 - For SE-ResNet models: 1 - For SE-ResNeXt models: 1 downsample_padding (int): Padding for downsampling convolutions in layer2, layer3 and layer4. - For SENet154: 1 - For SE-ResNet models: 0 - For SE-ResNeXt models: 0 num_classes (int): Number of outputs in `last_linear` layer. - For all models: 1000 """ super(SENet, self).__init__() self.inplanes = inplanes if input_3x3: layer0_modules = [ ("conv1", nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)), ("bn1", nn.BatchNorm2d(64)), ("relu1", nn.ReLU(inplace=True)), ("conv2", nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), ("bn2", nn.BatchNorm2d(64)), ("relu2", nn.ReLU(inplace=True)), ("conv3", nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), ("bn3", nn.BatchNorm2d(inplanes)), ("relu3", nn.ReLU(inplace=True)), ] else: layer0_modules = [ ( "conv1", nn.Conv2d( 3, inplanes, kernel_size=7, stride=2, padding=3, bias=False ), ), ("bn1", nn.BatchNorm2d(inplanes)), ("relu1", nn.ReLU(inplace=True)), ] # To preserve compatibility with Caffe weights `ceil_mode=True` # is used instead of `padding=1`. layer0_modules.append(("pool", nn.MaxPool2d(3, stride=2, ceil_mode=True))) self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) self.layer1 = self._make_layer( block, planes=64, blocks=layers[0], groups=groups, reduction=reduction, downsample_kernel_size=1, downsample_padding=0, ) self.layer2 = self._make_layer( block, planes=128, blocks=layers[1], stride=2, groups=groups, reduction=reduction, downsample_kernel_size=downsample_kernel_size, downsample_padding=downsample_padding, ) self.layer3 = self._make_layer( block, planes=256, blocks=layers[2], stride=2, groups=groups, reduction=reduction, downsample_kernel_size=downsample_kernel_size, downsample_padding=downsample_padding, ) self.layer4 = self._make_layer( block, planes=512, blocks=layers[3], stride=2, groups=groups, reduction=reduction, downsample_kernel_size=downsample_kernel_size, downsample_padding=downsample_padding, ) self.avg_pool = nn.AvgPool2d(7, stride=1) self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None self.last_linear = nn.Linear(512 * block.expansion, num_classes) def _make_layer( self, block, planes, blocks, groups, reduction, stride=1, downsample_kernel_size=1, downsample_padding=0, ): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d( self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, stride=stride, padding=downsample_padding, bias=False, ), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append( block(self.inplanes, planes, groups, reduction, stride, downsample) ) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes, groups, reduction)) return nn.Sequential(*layers) def features(self, x): x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) return x def logits(self, x): x = self.avg_pool(x) if self.dropout is not None: x = self.dropout(x) x = x.view(x.size(0), -1) x = self.last_linear(x) return x def forward(self, x): x = self.features(x) x = self.logits(x) return x def senet154(num_classes=1000, pretrained="imagenet"): model = SENet( SEBottleneck, [3, 8, 12, 3], groups=64, reduction=16, dropout_p=0.2, num_classes=num_classes, ) return 
model ================================================ FILE: python/oneflow/test/expensive/pytorch_shufflenetv2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn from torch import Tensor from typing import Callable, Any, List __all__ = [ "ShuffleNetV2", "shufflenet_v2_x2_0", ] def channel_shuffle(x: Tensor, groups: int) -> Tensor: batchsize, num_channels, height, width = x.size() channels_per_group = num_channels // groups # reshape x = x.view(batchsize, groups, channels_per_group, height, width) x = torch.transpose(x, 1, 2).contiguous() # flatten x = x.view(batchsize, -1, height, width) return x class InvertedResidual(nn.Module): def __init__(self, inp: int, oup: int, stride: int) -> None: super().__init__() if not (1 <= stride <= 3): raise ValueError("illegal stride value") self.stride = stride branch_features = oup // 2 if (self.stride == 1) and (inp != branch_features << 1): raise ValueError( f"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. If stride == 1 then inp should be equal to oup // 2 << 1." ) if self.stride > 1: self.branch1 = nn.Sequential( self.depthwise_conv( inp, inp, kernel_size=3, stride=self.stride, padding=1 ), nn.BatchNorm2d(inp), nn.Conv2d( inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False ), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), ) else: self.branch1 = nn.Sequential() self.branch2 = nn.Sequential( nn.Conv2d( inp if (self.stride > 1) else branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False, ), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), self.depthwise_conv( branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1, ), nn.BatchNorm2d(branch_features), nn.Conv2d( branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False, ), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), ) @staticmethod def depthwise_conv( i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False, ) -> nn.Conv2d: return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) def forward(self, x: Tensor) -> Tensor: if self.stride == 1: x1, x2 = x.chunk(2, dim=1) out = torch.cat((x1, self.branch2(x2)), dim=1) else: out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) out = channel_shuffle(out, 2) return out class ShuffleNetV2(nn.Module): def __init__( self, stages_repeats: List[int], stages_out_channels: List[int], num_classes: int = 1000, inverted_residual: Callable[..., nn.Module] = InvertedResidual, ) -> None: super().__init__() if len(stages_repeats) != 3: raise ValueError("expected stages_repeats as list of 3 positive ints") if len(stages_out_channels) != 5: raise ValueError("expected stages_out_channels as list of 5 positive ints") self._stage_out_channels = stages_out_channels input_channels = 3 output_channels = self._stage_out_channels[0] self.conv1 = nn.Sequential( nn.Conv2d(input_channels, 
output_channels, 3, 2, 1, bias=False), nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True), ) input_channels = output_channels self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Static annotations for mypy self.stage2: nn.Sequential self.stage3: nn.Sequential self.stage4: nn.Sequential stage_names = [f"stage{i}" for i in [2, 3, 4]] for name, repeats, output_channels in zip( stage_names, stages_repeats, self._stage_out_channels[1:] ): seq = [inverted_residual(input_channels, output_channels, 2)] for i in range(repeats - 1): seq.append(inverted_residual(output_channels, output_channels, 1)) setattr(self, name, nn.Sequential(*seq)) input_channels = output_channels output_channels = self._stage_out_channels[-1] self.conv5 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True), ) self.fc = nn.Linear(output_channels, num_classes) def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.maxpool(x) x = self.stage2(x) x = self.stage3(x) x = self.stage4(x) x = self.conv5(x) x = x.mean([2, 3]) # globalpool x = self.fc(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def _shufflenetv2(progress: bool, *args: Any, **kwargs: Any,) -> ShuffleNetV2: model = ShuffleNetV2(*args, **kwargs) return model def shufflenet_v2_x2_0(progress: bool = True, **kwargs: Any) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 2.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: weights (ShuffleNet_V2_X2_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ return _shufflenetv2(progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) ================================================ FILE: python/oneflow/test/expensive/pytorch_squeezenet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import torch import torch.nn as nn import torch.nn.init as init from typing import Any __all__ = ["SqueezeNet", "squeezenet1_1"] class Fire(nn.Module): def __init__( self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int, ) -> None: super().__init__() self.inplanes = inplanes self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) self.squeeze_activation = nn.ReLU(inplace=True) self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1) self.expand1x1_activation = nn.ReLU(inplace=True) self.expand3x3 = nn.Conv2d( squeeze_planes, expand3x3_planes, kernel_size=3, padding=1 ) self.expand3x3_activation = nn.ReLU(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.squeeze_activation(self.squeeze(x)) return torch.cat( [ self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x)), ], 1, ) class SqueezeNet(nn.Module): def __init__( self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5 ) -> None: super().__init__() self.num_classes = num_classes if version == "1_0": self.features = nn.Sequential( nn.Conv2d(3, 96, kernel_size=7, stride=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(96, 16, 64, 64), Fire(128, 16, 64, 64), Fire(128, 32, 128, 128), nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(256, 32, 128, 128), Fire(256, 48, 192, 192), Fire(384, 48, 192, 192), Fire(384, 64, 256, 256), nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(512, 64, 256, 256), ) elif version == "1_1": self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(64, 16, 64, 64), Fire(128, 16, 64, 64), nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(128, 32, 128, 128), Fire(256, 32, 128, 128), nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(256, 48, 192, 192), Fire(384, 48, 192, 192), Fire(384, 64, 256, 256), Fire(512, 64, 256, 256), ) else: # FIXME: Is this needed? SqueezeNet should only be called from the # FIXME: squeezenet1_x() functions # FIXME: This checking is not done for the other models raise ValueError( f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected" ) # Final convolution is initialized differently from the rest final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) self.classifier = nn.Sequential( nn.Dropout(p=dropout), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), ) for m in self.modules(): if isinstance(m, nn.Conv2d): if m is final_conv: init.normal_(m.weight, mean=0.0, std=0.01) else: init.kaiming_uniform_(m.weight) if m.bias is not None: init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) x = self.classifier(x) return torch.flatten(x, 1) def _squeezenet(version: str, progress: bool, **kwargs: Any,) -> SqueezeNet: model = SqueezeNet(version, **kwargs) return model def squeezenet1_1(progress: bool = True, **kwargs: Any) -> SqueezeNet: r"""SqueezeNet 1.1 model from the `official SqueezeNet repo `_. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters than SqueezeNet 1.0, without sacrificing accuracy. The required minimum input size of the model is 17x17. 
Args: weights (SqueezeNet1_1_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ return _squeezenet("1_1", progress, **kwargs) ================================================ FILE: python/oneflow/test/expensive/pytorch_swin_transformer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn from timm.models.layers import DropPath, to_2tuple, trunc_normal_ class Mlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = ( x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) ) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view( B, H // window_size, W // window_size, window_size, window_size, -1 ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 """ def __init__( self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0, ): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) ) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = ( coords_flatten[:, :, None] - coords_flatten[:, None, :] ) # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute( 1, 2, 0 ).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=0.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = ( self.qkv(x) .reshape(B_, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = ( qkv[0], qkv[1], qkv[2], ) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1) ].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1, ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute( 2, 0, 1 ).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 1 ).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__( self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, ): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert ( 0 <= self.shift_size < self.window_size ), "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) w_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition( img_mask, self.window_size ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill( attn_mask != 0, float(-100.0) ).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll( x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) ) else: shifted_x = x # partition windows x_windows = window_partition( shifted_x, self.window_size ) # nW*B, window_size, window_size, C x_windows = x_windows.view( -1, self.window_size * self.window_size, C ) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn( x_windows, mask=self.attn_mask ) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll( shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) ) else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + 
self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return ( f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" ) def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
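Example (illustrative, using the Swin-T stage-1 settings that appear later in this file): ``BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)`` expects input of shape ``(B, 56 * 56, 96)`` and, since ``downsample`` is ``None`` by default, returns a tensor of the same shape.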
""" def __init__( self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, ): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList( [ SwinTransformerBlock( dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, ) for i in range(depth) ] ) # patch merging layer if downsample is not None: self.downsample = downsample( input_resolution, dim=dim, norm_layer=norm_layer ) else: self.downsample = None def forward(self, x): for blk in self.blocks: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__( self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [ img_size[0] // patch_size[0], img_size[1] // patch_size[1], ] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size ) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert ( H == self.img_size[0] and W == self.img_size[1] ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = ( Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) ) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformer(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. 
window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ def __init__( self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, **kwargs, ): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, ) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter( torch.zeros(1, num_patches, embed_dim) ) trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), input_resolution=( patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer), ), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, ) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = ( nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def no_weight_decay(self): return {"absolute_pos_embed"} def no_weight_decay_keywords(self): return {"relative_position_bias_table"} def forward_features(self, x): x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for layer in self.layers: x = layer(x) x = self.norm(x) # B L C x = self.avgpool(x.transpose(1, 2)) # B C 1 x = 
torch.flatten(x, 1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += ( self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) ) flops += self.num_features * self.num_classes return flops def _create_swin_transformer(arch, pretrained=False, progress=True, **model_kwargs): model = SwinTransformer(**model_kwargs) return model def swin_tiny_patch4_window7_224(pretrained=False, progress=True, **kwargs): """ Constructs Swin-T 224x224 model trained on ImageNet-1k. .. note:: Swin-T 224x224 model from `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_. Args: pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` For example: .. code-block:: python >>> import flowvision >>> swin_tiny_patch4_window7_224 = flowvision.models.swin_tiny_patch4_window7_224(pretrained=False, progress=True) """ model_kwargs = dict( img_size=224, patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), drop_path_rate=0.2, **kwargs, ) return _create_swin_transformer( "swin_tiny_patch4_window7_224", pretrained=pretrained, progress=progress, **model_kwargs, ) ================================================ FILE: python/oneflow/test/expensive/pytorch_uniformer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from collections import OrderedDict import torch import torch.nn as nn from functools import partial from timm.models.layers import trunc_normal_, DropPath, to_2tuple layer_scale = False init_value = 1e-6 class Mlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class CMlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Conv2d(in_features, hidden_features, 1) self.act = act_layer() self.fc2 = nn.Conv2d(hidden_features, out_features, 1) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = ( qkv[0], qkv[1], qkv[2], ) # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class CBlock(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, ): super().__init__() self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) self.norm1 = nn.BatchNorm2d(dim) self.conv1 = nn.Conv2d(dim, dim, 1) self.conv2 = nn.Conv2d(dim, dim, 1) self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = nn.BatchNorm2d(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = CMlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) def forward(self, x): x = x + self.pos_embed(x) x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x))))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class SABlock(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, ): super().__init__() self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) # 
NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) global layer_scale self.ls = layer_scale if self.ls: global init_value print(f"Use layer_scale: {layer_scale}, init_values: {init_value}") self.gamma_1 = nn.Parameter( init_value * torch.ones((dim)), requires_grad=True ) self.gamma_2 = nn.Parameter( init_value * torch.ones((dim)), requires_grad=True ) def forward(self, x): x = x + self.pos_embed(x) B, N, H, W = x.shape x = x.flatten(2).transpose(1, 2) if self.ls: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) x = x.transpose(1, 2).reshape(B, N, H, W) return x class head_embedding(nn.Module): def __init__(self, in_channels, out_channels): super(head_embedding, self).__init__() self.proj = nn.Sequential( nn.Conv2d( in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), ), nn.BatchNorm2d(out_channels // 2), nn.GELU(), nn.Conv2d( out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), ), nn.BatchNorm2d(out_channels), ) def forward(self, x): x = self.proj(x) return x class middle_embedding(nn.Module): def __init__(self, in_channels, out_channels): super(middle_embedding, self).__init__() self.proj = nn.Sequential( nn.Conv2d( in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), ), nn.BatchNorm2d(out_channels), ) def forward(self, x): x = self.proj(x) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.norm = nn.LayerNorm(embed_dim) self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size ) def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert ( H == self.img_size[0] and W == self.img_size[1] ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
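# Note: unlike a ViT-style patch embed, the tokens are flattened to (B, N, C) only
# for the LayerNorm below and are then reshaped back to NCHW so that the
# convolutional CBlock/SABlock stages can consume the result directly.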
x = self.proj(x) B, C, H, W = x.shape x = x.flatten(2).transpose(1, 2) x = self.norm(x) x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() return x class UniFormer(nn.Module): """ Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ def __init__( self, depth=[3, 4, 8, 3], img_size=224, in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512], head_dim=64, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, representation_size=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=None, conv_stem=False, ): """ Args: depth (list): depth of each stage img_size (int, tuple): input image size in_chans (int): number of input channels num_classes (int): number of classes for classification head embed_dim (list): embedding dimension of each stage head_dim (int): head dimension mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True qk_scale (float): override default qk scale of head_dim ** -0.5 if set representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate norm_layer (nn.Module): normalization layer conv_stem (bool): whether use overlapped patch stem """ super().__init__() self.num_classes = num_classes self.num_features = ( self.embed_dim ) = embed_dim # num_features for consistency with other models norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) if conv_stem: self.patch_embed1 = head_embedding( in_channels=in_chans, out_channels=embed_dim[0] ) self.patch_embed2 = middle_embedding( in_channels=embed_dim[0], out_channels=embed_dim[1] ) self.patch_embed3 = middle_embedding( in_channels=embed_dim[1], out_channels=embed_dim[2] ) self.patch_embed4 = middle_embedding( in_channels=embed_dim[2], out_channels=embed_dim[3] ) else: self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0], ) self.patch_embed2 = PatchEmbed( img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1], ) self.patch_embed3 = PatchEmbed( img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2], ) self.patch_embed4 = PatchEmbed( img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3], ) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depth)) ] # stochastic depth decay rule num_heads = [dim // head_dim for dim in embed_dim] self.blocks1 = nn.ModuleList( [ CBlock( dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, ) for i in range(depth[0]) ] ) self.blocks2 = nn.ModuleList( [ CBlock( dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i + depth[0]], norm_layer=norm_layer, ) for i in range(depth[1]) ] ) self.blocks3 = nn.ModuleList( [ SABlock( dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i + depth[0] + depth[1]], norm_layer=norm_layer, ) for i in range(depth[2]) ] ) self.blocks4 = nn.ModuleList( [ SABlock( dim=embed_dim[3], num_heads=num_heads[3], 
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i + depth[0] + depth[1] + depth[2]], norm_layer=norm_layer, ) for i in range(depth[3]) ] ) self.norm = nn.BatchNorm2d(embed_dim[-1]) # Representation layer if representation_size: self.num_features = representation_size self.pre_logits = nn.Sequential( OrderedDict( [ ("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh()), ] ) ) else: self.pre_logits = nn.Identity() # Classifier head self.head = ( nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def no_weight_decay(self): return {"pos_embed", "cls_token"} def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=""): self.num_classes = num_classes self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) def forward_features(self, x): x = self.patch_embed1(x) x = self.pos_drop(x) for blk in self.blocks1: x = blk(x) x = self.patch_embed2(x) for blk in self.blocks2: x = blk(x) x = self.patch_embed3(x) for blk in self.blocks3: x = blk(x) x = self.patch_embed4(x) for blk in self.blocks4: x = blk(x) x = self.norm(x) x = self.pre_logits(x) return x def forward(self, x): x = self.forward_features(x) x = x.flatten(2).mean(-1) x = self.head(x) return x def uniformer_small(pretrained=True, **kwargs): model = UniFormer( depth=[3, 4, 8, 3], embed_dim=[64, 128, 320, 512], head_dim=64, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs, ) return model ================================================ FILE: python/oneflow/test/expensive/pytroch_mlp_mixer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import math import torch import torch.nn as nn from timm.models.layers import DropPath, lecun_normal_, to_2tuple from functools import partial from typing import Callable class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__( self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size ) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): B, C, H, W = x.shape assert ( H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", ) assert ( W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", ) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x def named_apply( fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False ) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply( fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True, ) if depth_first and include_root: fn(module=module, name=name) return module class GatedMlp(nn.Module): """ MLP as used in gMLP """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, gate_layer=None, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) if gate_layer is not None: assert hidden_features % 2 == 0 self.gate = gate_layer(hidden_features) hidden_features = ( hidden_features // 2 ) # FIXME base reduction on gate property? 
else: self.gate = nn.Identity() self.fc2 = nn.Linear(hidden_features, out_features) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.gate(x) x = self.fc2(x) x = self.drop2(x) return x class MixerBlock(nn.Module): """ Residual Block w/ token mixing and channel MLPs Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ def __init__( self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0.0, drop_path=0.0, ): super().__init__() tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)] self.norm1 = norm_layer(dim) self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) def forward(self, x): x = x + self.drop_path( self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2) ) x = x + self.drop_path(self.mlp_channels(self.norm2(x))) return x class Affine(nn.Module): def __init__(self, dim): super().__init__() self.alpha = nn.Parameter(torch.ones((1, 1, dim))) self.beta = nn.Parameter(torch.zeros((1, 1, dim))) def forward(self, x): return torch.addcmul(self.beta, self.alpha, x) class ResBlock(nn.Module): """ Residual MLP block w/ LayerScale and Affine 'norm' Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ def __init__( self, dim, seq_len, mlp_ratio=4, mlp_layer=Mlp, norm_layer=Affine, act_layer=nn.GELU, init_values=1e-4, drop=0.0, drop_path=0.0, ): super().__init__() channel_dim = int(dim * mlp_ratio) self.norm1 = norm_layer(dim) self.linear_tokens = nn.Linear(seq_len, seq_len) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop) self.ls1 = nn.Parameter(init_values * torch.ones(dim)) self.ls2 = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): x = x + self.drop_path( self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2) ) x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x))) return x class SpatialGatingUnit(nn.Module): """ Spatial Gating Unit Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm): super().__init__() gate_dim = dim // 2 self.norm = norm_layer(gate_dim) self.proj = nn.Linear(seq_len, seq_len) def init_weights(self): # special init for the projection gate, called as override by base model init nn.init.normal_(self.proj.weight, std=1e-6) nn.init.ones_(self.proj.bias) def forward(self, x): u, v = x.chunk(2, dim=-1) v = self.norm(v) v = self.proj(v.transpose(-1, -2)) return u * v.transpose(-1, -2) class SpatialGatingBlock(nn.Module): """ Residual Block w/ Spatial Gating Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ def __init__( self, dim, seq_len, mlp_ratio=4, mlp_layer=GatedMlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0.0, drop_path=0.0, ): super().__init__() channel_dim = int(dim * mlp_ratio) self.norm = norm_layer(dim) sgu = partial(SpatialGatingUnit, seq_len=seq_len) self.mlp_channels = mlp_layer( dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop ) self.drop_path = DropPath(drop_path) if 
drop_path > 0.0 else nn.Identity() def forward(self, x): x = x + self.drop_path(self.mlp_channels(self.norm(x))) return x class MlpMixer(nn.Module): def __init__( self, num_classes=1000, img_size=224, in_chans=3, patch_size=16, num_blocks=8, embed_dim=512, mlp_ratio=(0.5, 4.0), block_layer=MixerBlock, mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop_rate=0.0, drop_path_rate=0.0, nlhb=False, stem_norm=False, global_pool="avg", ): super().__init__() self.num_classes = num_classes self.global_pool = global_pool self.num_features = ( self.embed_dim ) = embed_dim # num_features for consistency with other models self.grad_checkpointing = False self.stem = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None, ) # FIXME drop_path (stochastic depth scaling rule or all the same?) self.blocks = nn.Sequential( *[ block_layer( embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate, ) for _ in range(num_blocks) ] ) self.norm = norm_layer(embed_dim) self.head = ( nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() ) self.init_weights(nlhb=nlhb) def init_weights(self, nlhb=False): head_bias = -math.log(self.num_classes) if nlhb else 0.0 named_apply( partial(_init_weights, head_bias=head_bias), module=self ) # depth-first def group_matcher(self, coarse=False): return dict( stem=r"^stem", # stem and embed blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], ) def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ("", "avg") self.global_pool = global_pool self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) def forward_features(self, x): x = self.stem(x) x = self.blocks(x) x = self.norm(x) return x def forward(self, x): x = self.forward_features(x) if self.global_pool == "avg": x = x.mean(dim=1) x = self.head(x) return x def _init_weights(module: nn.Module, name: str, head_bias: float = 0.0, flax=False): """ Mixer weight initialization (trying to match Flax defaults) """ if isinstance(module, nn.Linear): if name.startswith("head"): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) else: if flax: # Flax defaults lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) else: # like MLP init in vit (my original init) nn.init.xavier_uniform_(module.weight) if module.bias is not None: if "mlp" in name: nn.init.normal_(module.bias, std=1e-6) else: nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) elif hasattr(module, "init_weights"): # NOTE if a parent module contains init_weights method, it can override the init of the # child modules as this will be called in depth-first order. 
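# For instance, SpatialGatingUnit.init_weights re-initializes its projection with a
# near-zero weight and unit bias, overriding the generic nn.Linear init that was
# already applied to that child earlier in the traversal.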
module.init_weights() def mixer_s32_224(pretrained=False, **kwargs): """ Mixer-S/32 224x224 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs) model = MlpMixer(**model_args) return model ================================================ FILE: python/oneflow/test/expensive/resnet50_model.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Any, Callable, List, Optional, Type, Union import oneflow as flow import oneflow.nn as nn from oneflow import Tensor class FakeBN(nn.Module): """Common base of _InstanceNorm and _BatchNorm""" def __init__( self, num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, ) -> None: super().__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: self.weight = flow.nn.Parameter(flow.Tensor(num_features)) self.bias = flow.nn.Parameter(flow.Tensor(num_features)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) if self.track_running_stats: self.register_buffer("running_mean", flow.Tensor(num_features)) self.register_buffer("running_var", flow.Tensor(num_features)) else: self.register_parameter("running_mean", None) self.register_parameter("running_var", None) def forward(self, input): return flow._C.identity(input) def conv3x3( in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 ) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation, ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion: int = 1 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU() self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += 
identity out = self.relu(out) return out class Bottleneck(nn.Module): expansion: int = 4 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.0)) * groups self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU() self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( "replace_stride_with_dilation should be None or a 3-element tuple, got {}".format( replace_stride_with_dilation ) ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer( block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] ) self.layer3 = self._make_layer( block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] ) self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) self.avgpool = nn.AvgPool2d((7, 7)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) def _make_layer( self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False, ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append( block( self.inplanes, planes, stride, downsample, 
self.groups, self.base_width, previous_dilation, norm_layer, ) ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block( self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, ) ) return nn.Sequential(*layers) def _forward_impl(self, x: Tensor) -> Tensor: x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = flow.flatten(x, 1) x = self.fc(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def _resnet( arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], **kwargs: Any ) -> ResNet: model = ResNet(block, layers, **kwargs) return model def resnet50(**kwargs: Any) -> ResNet: """ResNet-5 `"Deep Residual Learning for Image Recognition" `_. """ return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], **kwargs) ================================================ FILE: python/oneflow/test/expensive/test_compatibility.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.test_utils.oneflow_pytorch_compatibility import * import os @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestApiCompatibility(flow.unittest.TestCase): def test_alexnet_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_alexnet.py", "alexnet", "cuda", 16, 224 ) def test_resnet50_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_resnet.py", "resnet50", "cuda", 16, 224 ) @unittest.skipIf( os.environ["ONEFLOW_CI"] == "1", "always get error: 'Check failed: cudnnConvolutionBackwardFilter'", ) def test_convmixer_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_convmixer.py", "convmixer_768_32_relu", "cuda", 4, 224 ) def test_densenet_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_densenet.py", "densenet121", "cuda", 8, 224 ) def test_ghostnet_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_ghostnet.py", "ghost_net", "cuda", 16, 224 ) def test_googlenet_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_googlenet.py", "googlenet", "cuda", 8, 224 ) def test_inception_v3_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_inception_v3.py", "inception_v3", "cuda", 4, 299 ) def test_mnasnet_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_mnasnet.py", "mnasnet1_0", "cuda", 16, 224 ) # def test_rexnet_compatibility(test_case): # do_test_train_loss_oneflow_pytorch( # test_case, "pytorch_rexnet.py", "rexnetv1_1_0", "cuda", 16, 224 # ) # TODO(): support non-contiguous inplace add # def test_rexnetv1_lite_compatibility(test_case): # do_test_train_loss_oneflow_pytorch( # test_case, 
"pytorch_rexnetv1_lite.py", "rexnet_lite_1_0", "cuda", 16, 224 # ) # def test_res2net_compatibility(test_case): # do_test_train_loss_oneflow_pytorch( # test_case, "pytorch_res2net.py", "res2net50", "cuda", 16, 224 # ) def test_shufflenetv2_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_shufflenetv2.py", "shufflenet_v2_x2_0", "cuda", 16, 224 ) def test_squeezenet_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_squeezenet.py", "squeezenet1_1", "cuda", 16, 224 ) @unittest.skipIf( os.environ["ONEFLOW_CI"] == "1", "always get error: 'Check failed: cudnnConvolutionBackwardFilter'", ) def test_convnext_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_convnext.py", "convnext_tiny", "cuda", 8, 224 ) # def test_crossformer_compatibility(test_case): # do_test_train_loss_oneflow_pytorch( # test_case, # "pytorch_crossformer.py", # "crossformer_tiny_patch4_group7_224", # "cuda", # 8, # 224, # ) # def test_efficientnet_compatibility(test_case): # do_test_train_loss_oneflow_pytorch( # test_case, "pytorch_efficientnet.py", "efficientnet_b0", "cuda", 8, 224, # ) def test_levit_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_levit.py", "LeViT_128S", "cuda", 8, 224, ) # def test_mlp_mixer_compatibility(test_case): # do_test_train_loss_oneflow_pytorch( # test_case, "pytroch_mlp_mixer.py", "mixer_s32_224", "cuda", 8, 224, # ) def test_poolformer_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_poolformer.py", "poolformer_s12", "cuda", 8, 224, ) def test_pvt_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_pvt.py", "pvt_tiny", "cuda", 8, 224, ) def test_resmlp_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_resmlp.py", "resmlp_12", "cuda", 8, 224, ) def test_uniformer_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_uniformer.py", "uniformer_small", "cuda", 8, 224, ) # TODO(): support non-contiguous inplace add # def test_swin_transformer_compatibility(test_case): # do_test_train_loss_oneflow_pytorch( # test_case, # "pytorch_swin_transformer.py", # "swin_tiny_patch4_window7_224", # "cuda", # 8, # 224, # ) def test_senet_compatibility(test_case): do_test_train_loss_oneflow_pytorch( test_case, "pytorch_senet.py", "senet154", "cuda", 2, 224, ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_conv3d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestConv3DModule(flow.unittest.TestCase): @autotest(n=3) def test_nn_functional_conv3d(test_case): flow.backends.cuda.matmul.allow_tf32 = True device = random_device() img = torch.ones((1, 3, 16, 16, 16), requires_grad=True).to(device) kernel = torch.ones((6, 3, 3, 3, 3), requires_grad=True).to(device) y = torch.nn.functional.conv3d(img, kernel) return y @autotest(n=10, rtol=1e-3, atol=1e-4) def test_conv3d_with_random_data(test_case): flow.backends.cuda.matmul.allow_tf32 = True channels = random(1, 6) m = torch.nn.Conv3d( in_channels=channels, out_channels=random(1, 6), kernel_size=random(1, 3), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 5) | nothing(), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=5, dim0=2, dim1=channels).to(device) y = m(x) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5, check_allclose=False, rtol=1e-3) def test_conv3d_group_with_random_data(test_case): flow.backends.cuda.matmul.allow_tf32 = True channels = 720 # lcm(1, 2, 3, 4, 5, 6) m = torch.nn.Conv3d( in_channels=channels, out_channels=channels, kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 7), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) m.pytorch.to("cuda") x = random_tensor(ndim=5, dim1=channels).to(device) x.pytorch = x.pytorch.to("cuda") y = m(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_convtranspose.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.nn as nn import oneflow.unittest def _test_convtranspose1d_bias_false(test_case, device): np_arr = np.array([[[0.35356437, -0.95761778, 0.19567713]]]) weight = np.ones((1, 2, 3)) test_out_data = np.array( [ [ [0.35356438, -0.6040534, -0.40837622, -0.7619406, 0.19567713], [0.35356438, -0.6040534, -0.40837622, -0.7619406, 0.19567713], ] ] ) test_out_grad = np.array([[[6.0, 6.0, 6.0]]]) input_flow = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m_f = nn.ConvTranspose1d(1, 2, 3, stride=1, bias=False) m_f.weight.data = flow.tensor(weight, dtype=flow.float32) m_f = m_f.to(device) out_flow = m_f(input_flow) test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-03, 1e-05)) out_flow = out_flow.sum() out_flow.backward() test_case.assertTrue( np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06) ) def _test_convtranspose1d_bias_true(test_case, device): np_arr = np.array([[[0.54925832, -0.64144184, 0.15213189]]]) weight = np.ones((1, 2, 3)) bias = np.array([0.16849578, 0.1509564]) test_out_data = np.array( [ [ [0.71775407, 0.07631224, 0.22844413, -0.32081416, 0.32062766], [0.7002147, 0.05877288, 0.21090476, -0.3383535, 0.3030883], ] ] ) test_out_grad = np.array([[[6.0, 6.0, 6.0]]]) input_flow = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m_f = nn.ConvTranspose1d(1, 2, 3, stride=1, bias=True) m_f.weight.data = flow.tensor(weight, dtype=flow.float32) m_f.bias = nn.Parameter(flow.Tensor(bias)) m_f = m_f.to(device) out_flow = m_f(input_flow) test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-02, 1e-05)) out_flow = out_flow.sum() out_flow.backward() test_case.assertTrue( np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06) ) def _test_convtranspose1d_group_bias_false(test_case, device): np_arr = np.array( [[[0.38072484, -0.01421228, -0.6512485], [-0.05744093, 2.47079971, 0.17573214]]] ) weight = np.ones((2, 1, 3)) test_out_data = np.array( [ [ [0.38072485, 0.36651257, -0.28473592, -0.66546077, -0.6512485], [-0.05744093, 2.4133587, 2.5890908, 2.6465318, 0.17573214], ] ] ) test_out_grad = np.array([[[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]]) input_flow = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m_f = nn.ConvTranspose1d(2, 2, 3, stride=1, groups=2, bias=False) m_f.weight.data = flow.tensor(weight, dtype=flow.float32) m_f = m_f.to(device) out_flow = m_f(input_flow) test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-06, 1e-06)) out_flow = out_flow.sum() out_flow.backward() test_case.assertTrue( np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06) ) def _test_convtranspose1d_group_bias_true(test_case, device): np_arr = np.array( [ [ [-0.77808793, 0.99824008, 0.57340066], [1.46278707, -0.65234252, -1.13087643], ], [ [0.76053973, 0.62332447, -1.17157106], [0.60291466, -0.0472167, 0.89986403], ], ] ) weight = np.ones((2, 1, 3)) bias = np.array([0.32546719, 0.14995032]) test_out_data = np.array( [ [ [-0.45262071, 0.54561937, 1.11902, 1.897108, 0.89886785], [1.6127374, 0.96039486, -0.1704815, -1.6332686, -0.9809261], ], [ [1.0860069, 1.7093314, 0.5377604, -0.22277936, -0.8461038], [0.75286496, 0.70564824, 1.6055121, 1.0025976, 1.0498143], ], ] ) test_out_grad = np.array( [[[3.0, 3.0, 
3.0], [3.0, 3.0, 3.0]], [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]] ) input_flow = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m_f = nn.ConvTranspose1d(2, 2, 3, stride=1, groups=2, bias=True) m_f.weight.data = flow.tensor(weight, dtype=flow.float32) m_f.bias = nn.Parameter(flow.Tensor(bias)) m_f = m_f.to(device) out_flow = m_f(input_flow) test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-06, 1e-06)) out_flow = out_flow.sum() out_flow.backward() test_case.assertTrue( np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06) ) def _test_convtranspose1d_group_large_out_channel(test_case, device): np_arr = np.array( [ [ [2.00934643, 1.5782626, -1.59060988], [-1.70463546, 1.30170714, -1.04025804], ], [ [0.60327536, 1.26085986, -0.58499662], [-0.48145872, -1.64391469, -0.09332249], ], ] ) weight = np.ones((2, 3, 3)) test_out_data = np.array( [ [ [2.0093465, 3.587609, 1.9969991, -0.01234734, -1.5906099], [2.0093465, 3.587609, 1.9969991, -0.01234734, -1.5906099], [2.0093465, 3.587609, 1.9969991, -0.01234734, -1.5906099], [-1.7046355, -0.40292835, -1.4431864, 0.2614491, -1.040258], [-1.7046355, -0.40292835, -1.4431864, 0.2614491, -1.040258], [-1.7046355, -0.40292835, -1.4431864, 0.2614491, -1.040258], ], [ [0.60327536, 1.8641353, 1.2791386, 0.6758632, -0.58499664], [0.60327536, 1.8641353, 1.2791386, 0.6758632, -0.58499664], [0.60327536, 1.8641353, 1.2791386, 0.6758632, -0.58499664], [-0.48145872, -2.1253734, -2.2186959, -1.7372372, -0.09332249], [-0.48145872, -2.1253734, -2.2186959, -1.7372372, -0.09332249], [-0.48145872, -2.1253734, -2.2186959, -1.7372372, -0.09332249], ], ] ) test_out_grad = np.array( [[[9.0, 9.0, 9.0], [9.0, 9.0, 9.0]], [[9.0, 9.0, 9.0], [9.0, 9.0, 9.0]]] ) input_flow = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m_f = nn.ConvTranspose1d(2, 6, 3, stride=1, groups=2, bias=False) m_f.weight.data = flow.tensor(weight, dtype=flow.float32) m_f = m_f.to(device) out_flow = m_f(input_flow) test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-06, 1e-06)) out_flow = out_flow.sum() out_flow.backward() test_case.assertTrue( np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06) ) def _test_convtranspose1d_group_large_in_channel(test_case, device): np_arr = np.array( [ [ [-0.3939792, -0.34989742, 0.15775536], [0.927185, 0.25040535, -1.22738067], [-0.2187831, -0.24346108, -0.07109655], [-1.55353756, -0.37241986, 0.59579139], ], [ [-0.01818884, -1.34408642, 1.31260516], [0.52124192, 0.52142919, 1.40499944], [0.7410308, 1.93069512, 0.25694943], [-0.30531658, 0.24990326, -0.9493729], ], ] ) weight = np.ones((4, 1, 3)) test_out_data = np.array( [ [ [0.5332058, 0.43371373, -0.6359115, -1.1691173, -1.0696253], [-1.7723207, -2.3882017, -1.8635068, -0.09118611, 0.52469486], ], [ [0.50305307, -0.31960416, 2.3980005, 1.8949474, 2.7176046], [0.43571424, 2.6163127, 1.9238893, 1.488175, -0.69242346], ], ] ) test_out_grad = np.array( [ [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0]], [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0]], ] ) input_flow = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m_f = nn.ConvTranspose1d(4, 2, 3, stride=1, groups=2, bias=False) m_f.weight.data = flow.tensor(weight, dtype=flow.float32) m_f = m_f.to(device) out_flow = m_f(input_flow) test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-06, 1e-06)) out_flow = out_flow.sum() out_flow.backward() 
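# Sanity check for the hand-computed reference above: with an all-ones kernel,
# d(sum(out)) / d(input[c, t]) = kernel_size * out_channels_per_group
# = 3 * (2 // 2) = 3, which is why test_out_grad is 3.0 everywhere in this case
# (and 9.0 = 3 * 3 in the large-out-channel case, 6.0 = 3 * 2 in the
# non-grouped cases further up).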
test_case.assertTrue( np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06) ) @flow.unittest.skip_unless_1n1d() class TestConvTranspose(flow.unittest.TestCase): def test_ConvTranspose1d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_convtranspose1d_bias_false, _test_convtranspose1d_bias_true, _test_convtranspose1d_group_bias_false, _test_convtranspose1d_group_bias_true, _test_convtranspose1d_group_large_out_channel, _test_convtranspose1d_group_large_in_channel, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, rtol=1e-2) def test_ConvTranspose1d_(test_case): channels = random(1, 6) m = torch.nn.ConvTranspose1d( in_channels=channels, out_channels=random(1, 20), kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 5) | nothing(), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim1=channels).to(device) y = m(x) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5) def test_deconv1d_group_with_random_data(test_case): channels = 720 # lcm(1, 2, 3, 4, 5, 6) m = torch.nn.ConvTranspose1d( in_channels=channels, out_channels=channels, kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 7), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) m.pytorch.to("cuda") x = random_tensor(ndim=3, dim1=channels).to(device) x.pytorch = x.pytorch.to("cuda") y = m(x) return y @autotest(n=5, rtol=1e-2) def test_ConvTranspose3d_(test_case): channels = random(1, 2) m = torch.nn.ConvTranspose3d( in_channels=channels, out_channels=random(1, 2), kernel_size=random(1, 2), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=1, padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=5, dim1=channels).to(device) y = m(x) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5) def test_deconv3d_group_with_random_data(test_case): channels = 120 # lcm(1, 2, 3, 4, 5) m = torch.nn.ConvTranspose3d( in_channels=channels, out_channels=channels, kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 6), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) m.pytorch.to("cuda") x = random_tensor(ndim=5, dim1=channels).to(device) x.pytorch = x.pytorch.to("cuda") y = m(x) return y @autotest(n=3, auto_backward=False) @unittest.skip("TODO: functional_conv_transpose might output incorrect result") def test_functional_conv_transpose1d(test_case): device = random_device() channels = random(1, 6) img = random_tensor(ndim=3, dim1=channels).to(device) kernel = random_tensor(ndim=3, dim0=channels).to(device) y = torch.nn.functional.conv_transpose1d(img, kernel) return y @autotest(n=3, auto_backward=False) @unittest.skip("TODO: functional_conv_transpose might output incorrect result") def test_functional_conv_transpose2d(test_case): device = random_device() channels = random(1, 6) img = random_tensor(ndim=4, dim1=channels).to(device) kernel = random_tensor(ndim=4, 
dim0=channels).to(device) y = torch.nn.functional.conv_transpose2d(img, kernel) return y @autotest(n=3, auto_backward=False) @unittest.skip("TODO: functional_conv_transpose might output incorrect result") def test_functional_conv_transpose3d(test_case): device = random_device() channels = random(1, 6) img = random_tensor(ndim=5, dim1=channels).to(device) kernel = random_tensor(ndim=5, dim0=channels).to(device) y = torch.nn.functional.conv_transpose3d(img, kernel) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_dynamic_allocation_gradient_shuffle.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os # dynamic memory allocation can't be tested in unittest os.environ["ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"] = "1" import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow def round_half_away_from_zero(x): sign = np.sign(x) abs_val = np.abs(x) abs_val += 0.5 floor_val = np.floor(abs_val) out = floor_val * sign return out def _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size): batch_size = 512 num_tables = 26 ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64) enable_quantized_comm = enable_quantize and embedding_size < 1025 if enable_quantized_comm: np_tolerance = 0.5 os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" ids = np.arange(batch_size * num_tables, dtype=np.int64) np.random.shuffle(ids) else: if fp16: np_tolerance = 1e-2 else: np_tolerance = 1e-3 os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids embedding_grad = np.random.uniform( low=-1, high=1, size=(batch_size, num_tables, embedding_size) ).astype(np.float32) ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") table_ids_tensor = flow.tensor(table_ids.astype(np.int32), requires_grad=False).to( "cuda" ) embedding_grad_tensor = flow.tensor(embedding_grad, requires_grad=False).to("cuda") class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids, embedding_grad): ( num_unique_matrix, inverse_unique_partition_indices, _, cur_rank_unique_ids, _, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, "test") if fp16: embedding_grad = flow.cast(embedding_grad, flow.float16) cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle( embedding_grad, num_unique_matrix, cur_rank_inverse_indices, inverse_unique_partition_indices, "test", ) if fp16: cur_rank_unique_embedding_grad = flow.cast( cur_rank_unique_embedding_grad, flow.float32 ) return ( cur_rank_unique_embedding_grad, flow.cast(cur_rank_unique_ids, flow.int32), flow.cast(cur_rank_inverse_indices, flow.int32), 
flow.cast(inverse_unique_partition_indices, flow.int32), ) graph = TestGraph() ( cur_rank_unique_embedding_grad, cur_rank_unique_ids, cur_rank_inverse_indices, inverse_unique_partition_indices, ) = graph(ids_tensor, table_ids_tensor, embedding_grad_tensor) np_unique_ids, np_inverse = np.unique(ids, return_inverse=True) np_num_unique = np_unique_ids.size np_cur_rank_unique_embedding_grad = np.zeros( cur_rank_unique_embedding_grad.shape, dtype=np.float32 ).reshape(-1, embedding_size) embedding_grad = embedding_grad.reshape(-1, embedding_size) if fp16: embedding_grad = embedding_grad.astype(np.float16) for k in range(np_num_unique): np_data = sum(embedding_grad[np.where(ids.flatten() == np_unique_ids[k])[0]]) # Quantize Embedding Gradient. if enable_quantized_comm: abs_max_factor = np.max(np.abs(np_data)) int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32) quantize_factor = int8_factor / abs_max_factor np_data = np_data * quantize_factor np_data = round_half_away_from_zero(np_data) np_data = np_data.astype(np.int8) np_data = np_data.astype(np.float32) dequantize_factor = abs_max_factor / int8_factor np_data = np_data * dequantize_factor np_cur_rank_unique_embedding_grad[k, :] = np_data reversed_ids = cur_rank_unique_ids[cur_rank_inverse_indices][ inverse_unique_partition_indices ] test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids)) of_cur_rank_embedding_grad = cur_rank_unique_embedding_grad[ cur_rank_inverse_indices ][inverse_unique_partition_indices] of_cur_rank_embedding_grad = flow.reshape( of_cur_rank_embedding_grad, (-1, embedding_size) ) np_cur_rank_embedding_grad = np_cur_rank_unique_embedding_grad[np_inverse] if fp16: np_cur_rank_embedding_grad = np_cur_rank_embedding_grad.astype(np.float32) test_case.assertTrue( np.allclose( of_cur_rank_embedding_grad.numpy().flatten(), np_cur_rank_embedding_grad.flatten(), atol=np_tolerance, rtol=np_tolerance, ) ) def _test_unique_key_value(test_case, has_table_id, num_tables): batch_size = 128 ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64) if has_table_id: table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids table_ids_tensor = flow.tensor( table_ids.astype(np.int32), requires_grad=False ).to("cuda") else: table_ids_tensor = None ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids): ( num_unique, unique_ids, unique_table_ids, inverse_indices, ) = flow._C.one_embedding_unique_key_value_pair(ids, table_ids, num_tables) return ( flow.cast(num_unique, flow.int32), flow.cast(unique_ids, flow.int32), flow.cast(unique_table_ids, flow.int32), flow.cast(inverse_indices, flow.int32), ) graph = TestGraph() (num_unique, unique_ids, unique_table_ids, inverse_indices,) = graph( ids_tensor, table_ids_tensor ) np_unique_ids, np_inverse = np.unique(ids, return_inverse=True) np_num_unique = np_unique_ids.size test_case.assertTrue(np.array_equal(np_num_unique, num_unique[0])) reversed_ids = unique_ids[inverse_indices] test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids)) if has_table_id: reversed_table_ids = unique_table_ids[inverse_indices] test_case.assertTrue(np.array_equal(reversed_table_ids.numpy(), table_ids)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class DataShuffleTestCase(flow.unittest.TestCase): def test_embedding_gradient_shuffle(test_case): arg_dict = 
OrderedDict() arg_dict["enable_quantize"] = [True, False] arg_dict["fp16"] = [True, False] arg_dict["embedding_size"] = [128, 17] for kwargs in GenArgDict(arg_dict): _test_embedding_gradient_shuffle(test_case, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_einsum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestEinsum(flow.unittest.TestCase): @autotest(n=5) def test_einsum_matrix_transpose(test_case): device = random_device() x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device) z = torch.einsum("ij->ji", x) return z @autotest(n=5) def test_einsum_eltwise_multiply(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) z = torch.einsum("ij,ij->ij", x, y) return z @autotest(n=5) def test_einsum_get_diagonal(test_case): device = random_device() dim = random(1, 6) x = random_tensor(ndim=2, dim0=dim, dim1=dim,).to(device) z = torch.einsum("ii->i", x) return z @autotest(n=5) def test_einsum_batch_permute(test_case): device = random_device() x = random_tensor( ndim=5, dim0=random(1, 6), dim1=random(1, 6), dim2=random(1, 6), dim3=random(1, 6), dim4=random(1, 6), ).to(device) z = torch.einsum("...ij->...ji", x) return z @autotest(n=5) def test_einsum_reduce_sum(test_case): device = random_device() x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device) z = torch.einsum("ij->", x) return z @autotest(n=5) def test_einsum_matrix_column_sum(test_case): device = random_device() x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device) z = torch.einsum("ij->j", x) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_matrix_vector_multiply(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) y = random_tensor(ndim=1, dim0=dim1,).to(device) # NOTE(Liang Depeng): the same as 'ik,k->i' z = torch.einsum("ik,k", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_matmul(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) y = random_tensor(ndim=2, dim0=dim1, dim1=dim2,).to(device) # NOTE(Liang Depeng): the same as 'ik,kj->ij' z = torch.einsum("ik,kj", x, y) return z @autotest(n=5) def test_einsum_vector_inner_product(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=1, dim0=dim0,).to(device) y = random_tensor(ndim=1, dim0=dim0,).to(device) # NOTE(Liang Depeng): the same as 'i,i->' z = torch.einsum("i,i", x, y) return z @autotest(n=5) def 
test_einsum_eltwise_mul_then_reduce_sum(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) # NOTE(Liang Depeng): the same as 'ij,ij->' z = torch.einsum("ij,ij", x, y) return z @autotest(n=5) def test_einsum_vector_outer_product(test_case): device = random_device() x = random_tensor(ndim=1, dim0=random(1, 6),).to(device) y = random_tensor(ndim=1, dim0=random(1, 6),).to(device) # NOTE(Liang Depeng): the same as 'i,j->ij' z = torch.einsum("i,j", x, y) return z @autotest(n=5, rtol=1e-2) def test_einsum_batch_matmul(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) z = torch.einsum("ijk,ikl->ijl", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_tensor_contraction(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6), ).to(device) y = random_tensor( ndim=5, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6), dim4=dim1, ).to(device) z = torch.einsum("pqrs,tuqvr->pstuv", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_bilinear_transformation(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim1, dim2=dim2,).to(device) w = random_tensor(ndim=2, dim0=dim0, dim1=dim2,).to(device) z = torch.einsum("ik,jkl,il->ij", x, y, w) return z @autotest(n=20, auto_backward=False, check_graph=True) def test_einsum_0_size_tensor(test_case): device = random_device() x = random_tensor(ndim=3, dim0=random(1, 6), dim1=0, dim2=random(1, 6),).to( device ) z = torch.einsum("ijk", x) return z @unittest.skip("skip for now, becase it failed 20 times in past week") @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_tensor_contraction2(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=random(1, 6), ).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6),).to(device) z = torch.einsum("b n h w, n d -> b d h w", x, y) return z @autotest(n=5) def test_einsum_eltwise_mul_sum_row(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) z = torch.einsum("n d, n d -> n", x, y) return z @unittest.skip("skip for now, becase it failed 20 times in past week") @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_matmul2(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=2, dim0=random(1, 6), dim1=dim0,).to(device) y = random_tensor(ndim=2, dim0=random(1, 6), dim1=dim0,).to(device) z = torch.einsum("i d, j d -> i j", x, y) return z @autotest(n=5, rtol=1e-3) def test_einsum_attention(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2, ).to(device) y = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2, ).to(device) z = torch.einsum("b h i d, b h j d -> b h i j", x, y) return z 
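# NOTE (illustrative sketch, not part of the test suite): the pattern covered by
# test_einsum_attention above is the multi-head attention score contraction. It
# reduces over the per-head feature dimension d and is equivalent to a batched
# matmul against the transposed key tensor, e.g. (assuming the usual oneflow
# tensor API mirrors torch here):
#
#   import oneflow as flow
#   q = flow.randn(2, 4, 8, 16)   # (batch, heads, query_len, head_dim)
#   k = flow.randn(2, 4, 10, 16)  # (batch, heads, key_len, head_dim)
#   scores = flow.einsum("bhid,bhjd->bhij", q, k)
#   same = flow.matmul(q, k.transpose(-1, -2))
#   assert flow.allclose(scores, same, atol=1e-5)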
@autotest(n=5, rtol=1e-3) def test_einsum_batch_matmul2(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2 ).to(device) y = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 6) ).to(device) z = torch.einsum("b h i j, b h j d -> b h i d", x, y) return z @unittest.skip("skip for now, becase it failed 28 times in past week") @autotest(n=5, rtol=1e-2) def test_einsum_batch_matrix_vector_multiply(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2,).to(device) y = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2, ).to(device) z = torch.einsum("b i d, b i j d -> b i j", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_batch_matmul3(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6), dim3=dim1, ).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device) z = torch.einsum("b x i d, b j d -> b x i j", x, y) return z @autotest(n=5, rtol=1e-2) def test_einsum_batch_matmul4(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6), dim3=dim1, ).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) z = torch.einsum("b x i j, b j d -> b x i d", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_alphaflod_usecase1(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) z = torch.einsum("hij, ijc->ihc", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_alphaflod_usecase2(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) z = torch.einsum("rac,rab->rbc", x, y) return z @autotest(n=5, rtol=1e-2) def test_einsum_alphaflod_usecase3(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) z = torch.einsum("ra,rab->rb", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_alphaflod_usecase4(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device) y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device) z = torch.einsum("qhc,khc->qkh", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_alphaflod_usecase5(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=2, dim0=random(1, 6), dim1=dim0,).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to( device ) z = torch.einsum("nm, mrc->nrc", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_alphaflod_usecase6(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), 
dim2=dim1,).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device) z = torch.einsum("abc,adc->bdc", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_alphaflod_usecase7(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6), ).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) z = torch.einsum("dceb,cef->dbf", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_alphaflod_usecase8(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to( device ) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to( device ) z = torch.einsum("acb,ade->dceb", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_alphaflod_usecase9(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( device ) y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6),).to(device) z = torch.einsum("qkc,ch->hqk", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_alphaflod_usecase10(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2, ).to(device) y = random_tensor( ndim=4, dim0=dim0, dim1=dim2, dim2=dim1, dim3=random(1, 6) ).to(device) z = torch.einsum("bhqk,bkhc->bqhc", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_alphaflod_usecase11(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( device ) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to( device ) z = torch.einsum("bqa,ahc->bqhc", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_ellipsis_usecase1(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( device ) y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( device ) z = torch.einsum("...lc, ...c -> ...l", x, y) return z @autotest(n=5, rtol=1e-2) def test_einsum_ellipsis_usecase2(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device) y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1).to(device) z = torch.einsum("...lc, ...lc -> ...l", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_ellipsis_usecase3(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( device ) y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to( device ) z = torch.einsum("...id,...jd->...ij", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_ellipsis_usecase4(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1 ).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6)).to(device) z = torch.einsum("...klm,kmn->...kln", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_ellipsis_usecase5(test_case): device = random_device() dim0 = random(1, 6) 
x = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6) ).to(device) y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to( device ) z = torch.einsum("...ikl, ...jk -> ...ijl", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_ellipsis_usecase6(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to( device ) y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to( device ) z = torch.einsum("...l,...l->...", x, y) return z @autotest(n=5) def test_einsum_ellipsis_usecase7(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2).to(device) y = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 6) ).to(device) z = torch.einsum("ijk,ijk...->ij...", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_other_usecase1(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1).to(device) y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim1, dim2=dim2).to(device) w = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim2).to(device) z = torch.einsum("bxi,oij,byj->boxy", x, y, w) return z @autotest(n=5) def test_einsum_other_usecase2(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6) ).to(device) y = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6) ).to(device) z = torch.einsum("ijac,ijkp->ijakcp", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_other_usecase3(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=random(1, 6), dim2=dim1, dim3=random(1, 6) ).to(device) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1).to(device) z = torch.einsum("cdij,cbi->cdbj", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_fastfold_usecase1(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) dim2 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2 ).to(device) y = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2 ).to(device) z = torch.einsum("bsid,bsjd->bijd", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_fastfold_usecase2(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6) ).to(device) y = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6) ).to(device) z = torch.einsum("bsid,bsje->bijde", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_openfold_usecase1(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6) ).to(device) y = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6) ).to(device) z = torch.einsum("...bac,...dae->...bdce", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_openfold_usecase2(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, 
dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1 ).to(device) y = random_tensor( ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1 ).to(device) z = torch.einsum("...abc,...adc->...bdc", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-4) def test_einsum_openfold_usecase3(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=dim1 ).to(device) y = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=dim1 ).to(device) z = torch.einsum("...qhd,...khd->...hqk", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_openfold_usecase4(test_case): device = random_device() dim0 = random(1, 6) dim1 = random(1, 6) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6) ).to(device) y = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim1, dim3=dim0 ).to(device) z = torch.einsum("...vhf,...qhv->...qhf", x, y) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_einsum_openfold_usecase5(test_case): device = random_device() dim0 = random(1, 6) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=random(1, 6), dim3=dim0 ).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6)).to(device) z = torch.einsum("...ij,jk->ik", x, y) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_global_tensor_offload.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.nn as nn import oneflow.unittest # NOTE(Li Xiang): This variable controls the mem comparison method of the tensor offload test. # 1: Strictly test, compare mem changes according to tensor size. # 2: Loose test, compare mem changes before and after offload; # 3: Execute only offload, skip mem check. 
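# A minimal sketch of the offload()/load() round trip the tests below exercise
# (assumes at least one CUDA device; shown here for a single-rank global tensor,
# whereas the tests use 2- and 4-rank placements):
#
#   placement = flow.placement("cuda", ranks=[0])
#   t = flow.randn(1024, 1024, placement=placement, sbp=flow.sbp.broadcast)
#   t.offload()                  # device memory is released, data is kept on the host
#   assert t.is_offloaded()
#   t.load()                     # data is copied back to the original placement
#   assert not t.is_offloaded()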
offload_tensor_test_mem_mode = 3 def _test_global_tensor_offload_d2h(test_case, input, tensor_mem): test_case.assertTrue(not input.is_offloaded()) flow.cuda.empty_cache() if input.placement == oneflow.placement(type="cuda", ranks=[0, 1]): flow._oneflow_internal.CudaSynchronize(0) flow._oneflow_internal.CudaSynchronize(1) elif input.placement == oneflow.placement(type="cuda", ranks=[0, 1, 2, 3]): flow._oneflow_internal.CudaSynchronize(0) flow._oneflow_internal.CudaSynchronize(1) flow._oneflow_internal.CudaSynchronize(2) flow._oneflow_internal.CudaSynchronize(3) flow._oneflow_internal.eager.ClusterSync() before_used = flow._oneflow_internal.GetCUDAMemoryUsed() before_id = id(input) print("cuda", before_used) input.offload() test_case.assertTrue(input.is_offloaded()) test_case.assertEqual(input.placement.type, "cuda") after_used = flow._oneflow_internal.GetCUDAMemoryUsed() after_id = id(input) print("cuda to cpu", after_used) # Check global_tensor_mem cuda memory released if offload_tensor_test_mem_mode == 1: # NOTE(Li Xiang): In the case of 4 gpus, the memory usage of the tensor sometimes has a 2MB error. if input.placement == oneflow.placement(type="cuda", ranks=[0, 1, 2, 3]): test_case.assertTrue( ((before_used - after_used) == tensor_mem) or ((before_used - after_used) == (tensor_mem - 2)) ) return test_case.assertTrue((before_used - after_used) == tensor_mem) elif offload_tensor_test_mem_mode == 2: test_case.assertTrue(before_used > after_used) elif offload_tensor_test_mem_mode == 3: print( "Device:", flow.env.get_rank(), ". cuda mem change value:", before_used - after_used, ) test_case.assertEqual(before_id, after_id) def _test_global_tensor_load_h2d(test_case, input, tensor_mem): test_case.assertTrue(input.is_offloaded()) if input.placement == oneflow.placement(type="cuda", ranks=[0, 1]): flow._oneflow_internal.CudaSynchronize(0) flow._oneflow_internal.CudaSynchronize(1) elif input.placement == oneflow.placement(type="cuda", ranks=[0, 1, 2, 3]): flow._oneflow_internal.CudaSynchronize(0) flow._oneflow_internal.CudaSynchronize(1) flow._oneflow_internal.CudaSynchronize(2) flow._oneflow_internal.CudaSynchronize(3) flow._oneflow_internal.eager.ClusterSync() before_used = flow._oneflow_internal.GetCUDAMemoryUsed() before_id = id(input) input.load() test_case.assertTrue(not input.is_offloaded()) test_case.assertEqual(input.placement.type, "cuda") after_used = flow._oneflow_internal.GetCUDAMemoryUsed() after_id = id(input) print("cpu to cuda", after_used) # Check global_tensor_mem cuda memory allocated if offload_tensor_test_mem_mode == 1: # NOTE(Li Xiang): In the case of 4 gpus, the memory usage of the tensor sometimes has a 2MB error. if input.placement == oneflow.placement(type="cuda", ranks=[0, 1, 2, 3]): test_case.assertTrue( ((after_used - before_used) == tensor_mem) or ((after_used - before_used) == (tensor_mem - 2)) ) return test_case.assertTrue((after_used - before_used) == tensor_mem) elif offload_tensor_test_mem_mode == 2: test_case.assertTrue(after_used > before_used) elif offload_tensor_test_mem_mode == 3: print( "Device:", flow.env.get_rank(), ". 
cuda mem change value:", after_used - before_used, ) test_case.assertEqual(before_id, after_id) def _get_specific_global_tensor_mem(placement, sbp, tensor): size_tensor = tensor.clone().detach().to_local() cnt_size = size_tensor.element_size() * flow.numel(size_tensor) return cnt_size / 1024 / 1024 @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalTensorOffload(flow.unittest.TestCase): @globaltest @flow.unittest.skip_unless_1n2d() def test_global_tensor_offload_and_load_2d(test_case): for i in range(5): placement = flow.placement("cuda", ranks=[0, 1]) for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True): input = flow.randn( 1024, 1024, 100, dtype=flow.float32, placement=placement, sbp=sbp ) data = input.numpy() tensor_mem = _get_specific_global_tensor_mem(placement, sbp, input) _test_global_tensor_offload_d2h(test_case, input, tensor_mem) _test_global_tensor_load_h2d(test_case, input, tensor_mem) test_case.assertTrue( np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001) ) @globaltest @flow.unittest.skip_unless_1n4d() def test_global_tensor_offload_and_load_4d(test_case): for i in range(5): placement = flow.placement("cuda", ranks=[0, 1, 2, 3]) for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True): input = flow.randn( 1024, 1024, 10, dtype=flow.float32, placement=placement, sbp=sbp ) data = input.numpy() tensor_mem = _get_specific_global_tensor_mem(placement, sbp, input) _test_global_tensor_offload_d2h(test_case, input, tensor_mem) _test_global_tensor_load_h2d(test_case, input, tensor_mem) test_case.assertTrue( np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001) ) @globaltest @flow.unittest.skip_unless_1n2d() def test_global_tensor_offload_and_load_2d_cpu_mem(test_case): flow.cuda.empty_cache() for i in range(5): placement = flow.placement("cuda", ranks=[0, 1]) for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True): input = flow.randn( 1024, 1024, 100, dtype=flow.float32, placement=placement, sbp=sbp ) before_used = flow._oneflow_internal.GetCPUMemoryUsed() before_id = id(input) input.offload() after_used = flow._oneflow_internal.GetCPUMemoryUsed() after_id = id(input) if offload_tensor_test_mem_mode == 2: test_case.assertTrue(after_used > before_used) elif offload_tensor_test_mem_mode == 3: print("cpu mem change value:", after_used - before_used) test_case.assertEqual(before_id, after_id) cur_used = flow._oneflow_internal.GetCPUMemoryUsed() before_id = id(input) input.load() after_used = flow._oneflow_internal.GetCPUMemoryUsed() after_id = id(input) if offload_tensor_test_mem_mode == 2: test_case.assertTrue(after_used < cur_used) elif offload_tensor_test_mem_mode == 3: print("cpu mem change value:", cur_used - after_used) test_case.assertEqual(before_id, after_id) @globaltest @flow.unittest.skip_unless_1n2d() def test_global_param_offload_and_load(test_case): def load_eager_model(model): for param in model.parameters(): if param.is_offloaded(): param.load() test_case.assertTrue(not param.is_offloaded()) def offload_eager_model(model): for param in model.parameters(): if not param.is_offloaded(): param.offload() test_case.assertTrue(param.is_offloaded()) class Model(nn.Module): def __init__(self): super().__init__() self.n_layer = 1 layer_list = list() for _ in range(self.n_layer): layer_list.append(nn.Linear(768, 4096)) self.layers = nn.Sequential(*layer_list) def forward(self, x): return self.layers(x) placement = flow.placement("cuda", ranks=[0, 1]) model0 = Model().cuda() 
model0.to_global(placement=placement, sbp=flow.sbp.broadcast) BZ = 128 dataset = [flow.rand((BZ, 768), dtype=flow.float32) for _ in range(128)] with flow.no_grad(): for idx, x in enumerate(dataset): print(f"iter {idx} begin") x = x.cuda() x = x.to_global(placement=placement, sbp=flow.sbp.broadcast) load_eager_model(model0) y0 = model0(x) offload_eager_model(model0) print(f"iter {idx} end") if idx == 1: break if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_graph_multi_graph_v2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import time import tempfile import multiprocessing import oneflow as flow import oneflow.unittest def _reset_session(): # Close session to avoid the buffer name duplicate error. oneflow.framework.session_context.TryCloseDefaultSession() time.sleep(5) flow.framework.session_context.NewDefaultSession(flow._oneflow_global_unique_env) def _with_new_session(fn): def new_fn(*args, **kwargs): # Avoid Singleton value duplication such as buffer names. # saved and loaded graph runtime share the same buffer names(job names). print( "function ", fn.__name__, " session reset to avoid Singleton value duplication ...", ) _reset_session() out = fn(*args, **kwargs) _reset_session() return out return new_fn def _test_linear_multi_graph_share(test_case, device, with_reshape): linear = flow.nn.Linear(3, 8, False) linear = linear.to(device) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) flow.nn.init.constant_(linear.weight, 2.3) class LinearReshapeModule(flow.nn.Module): def __init__(self, lin, with_r): super().__init__() self.linear = lin self.with_reshape = with_r def forward(self, x): y = self.linear(x) if with_reshape: assert len(y.shape) == 2 return flow.reshape(y, (y.shape[1], y.shape[0])) else: return y linear_reshape = LinearReshapeModule(linear, with_reshape) class LinearGraph(flow.nn.Graph): @flow.nn.Graph.with_dynamic_input_shape(size=4) def __init__(self, lin, with_r): super().__init__() self.my_linear = LinearReshapeModule(lin, with_r) def build(self, x): return self.my_linear(x) linear_g = LinearGraph(linear, with_reshape) input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) x = flow.tensor(input_arr, device=device) of_lazy_out = linear_g(x) of_eager_out = linear_reshape(x) test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())) input_arr1 = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], ], 
dtype=np.float32, ) x1 = flow.tensor(input_arr1, device=device) of_lazy_out1 = linear_g(x1) of_eager_out1 = linear_reshape(x1) test_case.assertTrue(np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy())) input_arr2 = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], ], dtype=np.float32, ) x2 = flow.tensor(input_arr2, device=device) of_lazy_out2 = linear_g(x2) of_eager_out2 = linear_reshape(x2) test_case.assertTrue(np.array_equal(of_lazy_out2.numpy(), of_eager_out2.numpy())) of_lazy_out2 = linear_g(x2) of_eager_out2 = linear_reshape(x2) test_case.assertTrue(np.array_equal(of_lazy_out2.numpy(), of_eager_out2.numpy())) def _get_state_dict_tensor_size(sd): from oneflow.framework.args_tree import ArgsTree def _get_tensor_mem(input): # if input.dim() == 0: # return 2 cnt_size = input.element_size() * flow.numel(input) return cnt_size args_tree = ArgsTree(sd, False) size = 0 for arg in args_tree.iter_nodes(): if isinstance(arg, flow.Tensor): size += _get_tensor_mem(arg) else: continue return size @_with_new_session def _test_linear_multi_graph_save(return_dict, device, with_reshape, with_eager): linear = flow.nn.Linear(3, 8, False) linear = linear.to(device) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) flow.nn.init.constant_(linear.weight, 2.3) class LinearReshapeModule(flow.nn.Module): def __init__(self): super().__init__() self.linear = linear def forward(self, x): y = self.linear(x) if with_reshape: assert len(y.shape) == 2 return flow.reshape(y, (y.shape[1], y.shape[0])) else: return y linear_reshape = LinearReshapeModule() class LinearGraph(flow.nn.Graph): @flow.nn.Graph.with_dynamic_input_shape(size=3) def __init__(self): super().__init__(enable_get_runtime_state_dict=True) self.my_linear = linear_reshape def build(self, x): return self.my_linear(x) linear_g = LinearGraph() input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) x = flow.tensor(input_arr, device=device) of_lazy_out = linear_g(x) of_eager_out = linear_reshape(x) test_case0 = np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy()) return_dict["save0"] = test_case0 input_arr1 = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], ], dtype=np.float32, ) x1 = flow.tensor(input_arr1, device=device) of_lazy_out1 = linear_g(x1) of_eager_out1 = linear_reshape(x1) test_case1 = np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy()) return_dict["save1"] = test_case1 input_arr2 = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], ], dtype=np.float32, ) x2 = flow.tensor(input_arr2, device=device) of_lazy_out2 = linear_g(x2) of_eager_out2 = linear_reshape(x2) test_case2 = np.array_equal(of_lazy_out2.numpy(), of_eager_out2.numpy()) return_dict["save2"] = test_case2 input_arr3 = np.array([[-0.94630778, -0.83378579, -0.87060891],], dtype=np.float32,) x3 = flow.tensor(input_arr3, device=device) of_lazy_out3 = linear_g(x3) of_eager_out3 = linear_reshape(x3) test_case3 = np.array_equal(of_lazy_out3.numpy(), of_eager_out3.numpy()) return_dict["save3"] = test_case3 of_lazy_out1 = linear_g(x1) test_case1 = 
np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy()) return_dict["save4"] = test_case1 state_dict = linear_g.runtime_state_dict(with_eager=with_eager) print("====> saved graphs", state_dict.keys()) return state_dict @_with_new_session def _test_linear_multi_graph_load( return_dict, device, with_reshape, state_dict, with_new_input ): linear = flow.nn.Linear(3, 8, False) linear = linear.to(device) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) flow.nn.init.constant_(linear.weight, 2.3) class LinearReshapeModule(flow.nn.Module): def __init__(self): super().__init__() self.linear = linear def forward(self, x): y = self.linear(x) if with_reshape: assert len(y.shape) == 2 return flow.reshape(y, (y.shape[1], y.shape[0])) else: return y linear_reshape = LinearReshapeModule() class LinearGraph(flow.nn.Graph): @flow.nn.Graph.with_dynamic_input_shape(size=20) def __init__(self): super().__init__(debug_v_level=0) self.my_linear = linear_reshape def build(self, x): return self.my_linear(x) linear_g = LinearGraph() print("====> load") linear_g.load_runtime_state_dict(state_dict) print("====> load finish") input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) x = flow.tensor(input_arr, device=device) of_lazy_out = linear_g(x) of_eager_out = linear_reshape(x) test_case0 = np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy()) return_dict["load0"] = test_case0 input_arr1 = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], ], dtype=np.float32, ) x1 = flow.tensor(input_arr1, device=device) of_lazy_out1 = linear_g(x1) of_eager_out1 = linear_reshape(x1) test_case1 = np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy()) return_dict["load1"] = test_case1 if with_new_input: # The following section is for testing the new input shape after completing the load. 
input_arr2 = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.08086036, -1.81075924, 1.20752494], ], dtype=np.float32, ) x2 = flow.tensor(input_arr2, device=device) of_lazy_out2 = linear_g(x2) of_eager_out2 = linear_reshape(x2) test_case2 = np.array_equal(of_lazy_out2.numpy(), of_eager_out2.numpy()) return_dict["load2"] = test_case2 def _graph_save(return_dict, filename, with_eager): state_dict = _test_linear_multi_graph_save( return_dict, flow.device("cuda:0"), True, with_eager, ) print( f"state_dict(with_eager={with_eager}) tensors size ", _get_state_dict_tensor_size(state_dict), ) flow.save(state_dict, filename) print("====> save process done") def _graph_load(return_dict, filename): state_dict_loaded = flow.load(filename) # load with nn.Graph _test_linear_multi_graph_load( return_dict, flow.device("cuda"), True, state_dict_loaded, True ) print("====> load process done") def _graph_load_to_another_device(return_dict, filename): state_dict_loaded = flow.load(filename) new_state_dict = flow.nn.Graph.runtime_state_dict_to( state_dict_loaded, flow.device("cuda:1") ) # load with nn.Graph _test_linear_multi_graph_load( return_dict, flow.device("cuda:1"), True, new_state_dict, False ) print("====> load process done") def _test_linear_multi_graph_save_load_gpu(test_case, with_eager): # A graph runtime state dict with tempfile.NamedTemporaryFile() as f: # Save a graph manager = multiprocessing.Manager() return_dict = manager.dict() save_p = multiprocessing.get_context("spawn").Process( target=_graph_save, args=(return_dict, f.name, with_eager), ) save_p.start() save_p.join() # Resume a graph from a graph runtime state dict load_p = multiprocessing.get_context("spawn").Process( target=_graph_load, args=(return_dict, f.name) ) load_p.start() load_p.join() # test_case can't be passed into sub process, so we check with return_dict. # Reference: https://stackoverflow.com/questions/52225003/writing-to-multiple-files-using-multiprocessing-error-typeerror-cannot-seria for (key, check_value) in return_dict.items(): test_case.assertTrue(check_value, key + " failed.") def _test_load_to_another_device(test_case, with_eager): # A graph runtime state dict with tempfile.NamedTemporaryFile() as f: # Save a graph manager = multiprocessing.Manager() return_dict = manager.dict() save_p = multiprocessing.get_context("spawn").Process( target=_graph_save, args=(return_dict, f.name, with_eager), ) save_p.start() save_p.join() print(save_p) # Resume a graph from a graph runtime state dict load_p = multiprocessing.get_context("spawn").Process( target=_graph_load_to_another_device, args=(return_dict, f.name) ) load_p.start() load_p.join() print(load_p) # test_case can't be passed into sub process, so we check with return_dict. 
# Reference: https://stackoverflow.com/questions/52225003/writing-to-multiple-files-using-multiprocessing-error-typeerror-cannot-seria for (key, check_value) in return_dict.items(): test_case.assertTrue(check_value, key + " failed.") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLinearMultiGraph(oneflow.unittest.TestCase): def test_linear_multi_graph_share_gpu(test_case): _test_linear_multi_graph_share(test_case, flow.device("cuda"), False) def test_linear_reshape_multi_graph_share_gpu(test_case): _test_linear_multi_graph_share(test_case, flow.device("cuda"), True) def test_linear_multi_graph_save_load_gpu_with_share(test_case): _test_linear_multi_graph_save_load_gpu(test_case, True) def test_linear_multi_graph_save_load_gpu_with_share_without_eager(test_case): _test_linear_multi_graph_save_load_gpu(test_case, False) def test_load_to_another_device(test_case): _test_load_to_another_device(test_case, False) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_id_shuffle.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os # dynamic memory allocation can't be tested in unittest os.environ["ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"] = "0" import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow def _test_id_shuffle(test_case, has_table_id, num_tables): batch_size = 512 ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64) if has_table_id: table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids table_ids_tensor = flow.tensor( table_ids.astype(np.int32), requires_grad=False ).to("cuda") else: table_ids_tensor = None ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids): ( num_unique_matrix, inverse_unique_partition_indices, cur_rank_num_unique, cur_rank_unique_ids, cur_rank_unique_table_ids, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, "test") return ( flow.cast(num_unique_matrix, flow.int32), flow.cast(inverse_unique_partition_indices, flow.int32), flow.cast(cur_rank_num_unique, flow.int32), flow.cast(cur_rank_unique_ids, flow.int32), flow.cast(cur_rank_unique_table_ids, flow.int32), flow.cast(cur_rank_inverse_indices, flow.int32), ) graph = TestGraph() ( num_unique_matrix, inverse_unique_partition_indices, cur_rank_num_unique, cur_rank_unique_ids, cur_rank_unique_table_ids, cur_rank_inverse_indices, ) = graph(ids_tensor, table_ids_tensor) np_unique_ids, np_inverse = np.unique(ids, return_inverse=True) np_num_unique = np_unique_ids.size test_case.assertTrue(np.array_equal(np_num_unique, num_unique_matrix[0])) 
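# Illustrative note (not part of the original test): the numpy reference uses
# np.unique with return_inverse, which lets the original ids be rebuilt from
# the unique set, e.g.
#
#     a = np.array([[3, 1, 3], [1, 2, 2]])
#     u, inv = np.unique(a, return_inverse=True)   # u == [1, 2, 3]
#     assert np.array_equal(u[inv].reshape(a.shape), a)
#
# The assertions below apply the same round trip to the op outputs: indexing
# cur_rank_unique_ids with the two inverse-index tensors must reproduce the
# input ids exactly.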
test_case.assertTrue(np.array_equal(np_num_unique, cur_rank_num_unique[0])) reversed_ids = cur_rank_unique_ids[cur_rank_inverse_indices][ inverse_unique_partition_indices ] test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids)) if has_table_id: reversed_table_ids = cur_rank_unique_table_ids[cur_rank_inverse_indices][ inverse_unique_partition_indices ] test_case.assertTrue(np.array_equal(reversed_table_ids.numpy(), table_ids)) # when has_table_id=False, we cannot test table ids, because in this case the same id does not necessarily map to the same table id def round_half_away_from_zero(x): sign = np.sign(x) abs_val = np.abs(x) abs_val += 0.5 floor_val = np.floor(abs_val) out = floor_val * sign return out def embedding_shuffle_quantize(np_data, np_dtype): # When using float16, the compute type is set to float32. np_reduce_data = np_data.astype(np.float32) abs_max_factor = np.max(np.abs(np_reduce_data), axis=2) abs_max_factor = np.expand_dims(abs_max_factor, axis=2) transport_quantize_factor = abs_max_factor.astype(np_dtype) int8_factor = np.ones(abs_max_factor.shape, dtype=np.float32) * 127.0 int8_factor = int8_factor.astype(np.float32) quantize_factor = int8_factor / abs_max_factor # Convert to the compute type. np_data.astype(np.float32) np_data = np_data * quantize_factor np_data = round_half_away_from_zero(np_data) np_data = np_data.astype(np.int8) # Convert to the compute type. np_data = np_data.astype(np.float32) dequantize_factor = transport_quantize_factor.astype(np.float32) / int8_factor np_data = np_data * dequantize_factor np_data = np_data.astype(np_dtype) return np_data def _test_embedding_shuffle(test_case, dtype, enable_quantize): batch_size = 512 num_tables = 26 embedding_size = 128 ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64) enable_quantized_comm = enable_quantize and embedding_size < 1025 if enable_quantized_comm: os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" else: os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids if dtype == flow.float16: np_dtype = np.float16 else: np_dtype = np.float32 data = np.random.rand(1000, embedding_size).astype(np_dtype) ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") table_ids_tensor = flow.tensor(table_ids.astype(np.int32), requires_grad=False).to( "cuda" ) data_tensor = flow.tensor(data, requires_grad=False).to("cuda") class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids, data): ( num_unique_matrix, inverse_unique_partition_indices, _, cur_rank_unique_ids, _, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, "test") unique_embeddings = flow._C.gather(data, cur_rank_unique_ids, axis=0) embeddings = flow._C.one_embedding_embedding_shuffle( unique_embeddings, num_unique_matrix, cur_rank_inverse_indices, inverse_unique_partition_indices, "test", ) return embeddings graph = TestGraph() embeddings = graph(ids_tensor, table_ids_tensor, data_tensor) np_embeddings = data[ids] # Quantized numpy embedding.
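# Hedged worked example (illustrative only): embedding_shuffle_quantize above
# performs a symmetric int8 round trip per embedding vector (abs-max over the
# last axis): scale = 127 / max|x|, q = round_half_away_from_zero(x * scale),
# dequantized = q * max|x| / 127. For max|x| = 0.5 and x = 0.3:
#
#     scale = 127 / 0.5            # 254.0
#     q = round(0.3 * scale)       # 76
#     x_hat = q * 0.5 / 127        # ~0.2992
#
# The numpy reference below applies the same transform before comparing
# against the op output with a small atol/rtol.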
if enable_quantized_comm: np_embeddings = embedding_shuffle_quantize(np_embeddings, np_dtype) test_case.assertTrue( np.allclose(embeddings.numpy(), np_embeddings, atol=1e-4, rtol=1e-4) ) def _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size): batch_size = 512 num_tables = 26 ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64) enable_quantized_comm = enable_quantize and embedding_size < 1025 if enable_quantized_comm: np_tolerance = 0.5 os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" ids = np.arange(batch_size * num_tables, dtype=np.int64) np.random.shuffle(ids) else: if fp16: np_tolerance = 1e-2 else: np_tolerance = 1e-4 os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids embedding_grad = np.random.uniform( low=-1, high=1, size=(batch_size, num_tables, embedding_size) ).astype(np.float32) ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") table_ids_tensor = flow.tensor(table_ids.astype(np.int32), requires_grad=False).to( "cuda" ) embedding_grad_tensor = flow.tensor(embedding_grad, requires_grad=False).to("cuda") class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids, embedding_grad): ( num_unique_matrix, inverse_unique_partition_indices, _, cur_rank_unique_ids, _, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, "test") if fp16: embedding_grad = flow.cast(embedding_grad, flow.float16) cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle( embedding_grad, num_unique_matrix, cur_rank_inverse_indices, inverse_unique_partition_indices, "test", ) if fp16: cur_rank_unique_embedding_grad = flow.cast( cur_rank_unique_embedding_grad, flow.float32 ) return ( cur_rank_unique_embedding_grad, flow.cast(cur_rank_unique_ids, flow.int32), flow.cast(cur_rank_inverse_indices, flow.int32), flow.cast(inverse_unique_partition_indices, flow.int32), ) graph = TestGraph() ( cur_rank_unique_embedding_grad, cur_rank_unique_ids, cur_rank_inverse_indices, inverse_unique_partition_indices, ) = graph(ids_tensor, table_ids_tensor, embedding_grad_tensor) np_unique_ids, np_inverse = np.unique(ids, return_inverse=True) np_num_unique = np_unique_ids.size np_cur_rank_unique_embedding_grad = np.zeros( cur_rank_unique_embedding_grad.shape, dtype=np.float32 ).reshape(-1, embedding_size) embedding_grad = embedding_grad.reshape(-1, embedding_size) if fp16: embedding_grad = embedding_grad.astype(np.float16) for k in range(np_num_unique): np_data = sum(embedding_grad[np.where(ids.flatten() == np_unique_ids[k])[0]]) # Quantize Embedding Gradient. 
if enable_quantized_comm: abs_max_factor = np.max(np.abs(np_data)) int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32) quantize_factor = int8_factor / abs_max_factor np_data = np_data * quantize_factor np_data = round_half_away_from_zero(np_data) np_data = np_data.astype(np.int8) np_data = np_data.astype(np.float32) dequantize_factor = abs_max_factor / int8_factor np_data = np_data * dequantize_factor np_cur_rank_unique_embedding_grad[k, :] = np_data reversed_ids = cur_rank_unique_ids[cur_rank_inverse_indices][ inverse_unique_partition_indices ] test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids)) of_cur_rank_embedding_grad = cur_rank_unique_embedding_grad[ cur_rank_inverse_indices ][inverse_unique_partition_indices] of_cur_rank_embedding_grad = flow.reshape( of_cur_rank_embedding_grad, (-1, embedding_size) ) np_cur_rank_embedding_grad = np_cur_rank_unique_embedding_grad[np_inverse] if fp16: np_cur_rank_embedding_grad = np_cur_rank_embedding_grad.astype(np.float32) test_case.assertTrue( np.allclose( of_cur_rank_embedding_grad.numpy().flatten(), np_cur_rank_embedding_grad.flatten(), atol=np_tolerance, rtol=np_tolerance, ) ) def _test_unique_key_value(test_case, has_table_id, num_tables): batch_size = 128 ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64) if has_table_id: table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids table_ids_tensor = flow.tensor( table_ids.astype(np.int32), requires_grad=False ).to("cuda") else: table_ids_tensor = None ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids): ( num_unique, unique_ids, unique_table_ids, inverse_indices, ) = flow._C.one_embedding_unique_key_value_pair( ids, table_ids, num_tables, "test" ) return ( flow.cast(num_unique, flow.int32), flow.cast(unique_ids, flow.int32), flow.cast(unique_table_ids, flow.int32), flow.cast(inverse_indices, flow.int32), ) graph = TestGraph() (num_unique, unique_ids, unique_table_ids, inverse_indices,) = graph( ids_tensor, table_ids_tensor ) np_unique_ids, np_inverse = np.unique(ids, return_inverse=True) np_num_unique = np_unique_ids.size test_case.assertTrue(np.array_equal(np_num_unique, num_unique[0])) reversed_ids = unique_ids[inverse_indices] test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids)) if has_table_id: reversed_table_ids = unique_table_ids[inverse_indices] test_case.assertTrue(np.array_equal(reversed_table_ids.numpy(), table_ids)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class DataShuffleTestCase(flow.unittest.TestCase): def test_id_shuffle(test_case): arg_dict = OrderedDict() arg_dict["has_table_id"] = [True, False] arg_dict["num_tables"] = [1, 26] for kwargs in GenArgDict(arg_dict): _test_id_shuffle(test_case, **kwargs) def test_embedding_shuffle(test_case): arg_dict = OrderedDict() arg_dict["dtype"] = [flow.float32, flow.float16] arg_dict["enable_quantize"] = [True, False] for kwargs in GenArgDict(arg_dict): _test_embedding_shuffle(test_case, **kwargs) def test_embedding_gradient_shuffle(test_case): arg_dict = OrderedDict() arg_dict["enable_quantize"] = [True, False] arg_dict["fp16"] = [True, False] arg_dict["embedding_size"] = [128, 17] for kwargs in GenArgDict(arg_dict): _test_embedding_gradient_shuffle(test_case, **kwargs) def test_unique_key_value(test_case): arg_dict = OrderedDict() 
arg_dict["has_table_id"] = [True, False] arg_dict["num_tables"] = [13, 26, 1] for kwargs in GenArgDict(arg_dict): _test_unique_key_value(test_case, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_id_shuffle_global.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os # dynamic memory allocation can't be tested in unittest os.environ["ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"] = "0" import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow parallel_num = 2 max_id = 1000 def get_tensors(batch_size, num_tables): placement = flow.placement(type="cuda", ranks=list(range(parallel_num))) ids = np.random.randint(0, max_id, (batch_size, num_tables), dtype=np.int64) ids_tensor = flow.tensor(ids, requires_grad=False).to_global( placement=placement, sbp=flow.sbp.split(0) ) table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids table_ids_tensor = flow.tensor( table_ids.astype(np.int32), requires_grad=False ).to_global(placement=placement, sbp=flow.sbp.split(0)) return ids_tensor, table_ids_tensor def _test_id_shuffle(test_case, has_table_id, num_tables): batch_size = int(1024 / parallel_num) placement = flow.placement(type="cuda", ranks=list(range(parallel_num))) class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids): ( num_unique_matrix, inverse_unique_partition_indices, cur_rank_num_unique, cur_rank_unique_ids, cur_rank_unique_table_ids, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, "test") return ( flow.cast(num_unique_matrix, flow.int32), flow.cast(inverse_unique_partition_indices, flow.int32), flow.cast(cur_rank_num_unique, flow.int32), flow.cast(cur_rank_unique_ids, flow.int32), flow.cast(cur_rank_unique_table_ids, flow.int32), flow.cast(cur_rank_inverse_indices, flow.int32), ) graph = TestGraph() for i in range(10): ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables) if not has_table_id: table_ids_tensor = None graph(ids_tensor, table_ids_tensor) ( num_unique_matrix, inverse_unique_partition_indices, local_cur_rank_num_unique, cur_rank_unique_ids, cur_rank_unique_table_ids, cur_rank_inverse_indices, ) = graph(ids_tensor, table_ids_tensor) cur_rank_num_unique = local_cur_rank_num_unique.to_local().to_global( placement=placement, sbp=flow.sbp.split(0) ) cur_rank_num_unique_list = [] cur_rank_unique_ids_list = [] cur_rank_unique_table_ids_list = [] cur_rank_num_ids = batch_size * num_tables * parallel_num for i in range(parallel_num): num_unique_i = cur_rank_num_unique.numpy()[i] unique_ids_i = cur_rank_unique_ids.numpy()[ cur_rank_num_ids * i : cur_rank_num_ids * (i + 1) ] unique_table_ids_i = cur_rank_unique_table_ids.numpy()[ cur_rank_num_ids * i : cur_rank_num_ids * 
(i + 1) ] cur_rank_num_unique_list.append(num_unique_i) cur_rank_unique_ids_list.append(np.array(unique_ids_i[0:num_unique_i])) cur_rank_unique_table_ids_list.append( np.array(unique_table_ids_i[0:num_unique_i]) ) global_ids = ids_tensor.numpy() np_unique_ids, np_unique_index, np_inverse = np.unique( global_ids, return_index=True, return_inverse=True ) np_num_unique = np_unique_ids.size # test num unique test_case.assertTrue( np.array_equal(np_num_unique, np.array(cur_rank_num_unique_list).sum()) ) # test unique ids unique_ids = np.concatenate(cur_rank_unique_ids_list) unique_ids.sort() np_unique_ids.sort() test_case.assertTrue(np.array_equal(unique_ids, np_unique_ids)) if has_table_id: # test unique table ids unique_table_ids = np.concatenate(cur_rank_unique_table_ids_list) unique_table_ids.sort() global_table_ids = table_ids_tensor.numpy() np_unique_table_ids = global_table_ids.flatten()[np_unique_index] np_unique_table_ids.sort() test_case.assertTrue(np.array_equal(unique_table_ids, np_unique_table_ids)) def round_half_away_from_zero(x): sign = np.sign(x) abs_val = np.abs(x) abs_val += 0.5 floor_val = np.floor(abs_val) out = floor_val * sign return out def embedding_shuffle_quantize(np_data, np_dtype): # When using float16, the compute type is set to float32. np_reduce_data = np_data.astype(np.float32) abs_max_factor = np.max(np.abs(np_reduce_data), axis=2) abs_max_factor = np.expand_dims(abs_max_factor, axis=2) transport_quantize_factor = abs_max_factor.astype(np_dtype) int8_factor = np.ones(abs_max_factor.shape, dtype=np.float32) * 127.0 int8_factor = int8_factor.astype(np.float32) quantize_factor = int8_factor / abs_max_factor # Convert to the compute type. np_data.astype(np.float32) np_data = np_data * quantize_factor np_data = round_half_away_from_zero(np_data) np_data = np_data.astype(np.int8) # Convert to the compute type.
np_data = np_data.astype(np.float32) dequantize_factor = transport_quantize_factor.astype(np.float32) / int8_factor np_data = np_data * dequantize_factor np_data = np_data.astype(np_dtype) return np_data def _test_embedding_shuffle(test_case, dtype, enable_quantize): batch_size = int(1024 / parallel_num) placement = flow.placement(type="cuda", ranks=list(range(parallel_num))) num_tables = 26 embedding_size = 128 enable_quantized_comm = enable_quantize and embedding_size < 1025 if enable_quantized_comm: os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" else: os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" if dtype == flow.float16: np_dtype = np.float16 else: np_dtype = np.float32 data = np.random.rand(max_id, embedding_size).astype(np_dtype) data_tensor = flow.tensor(data, requires_grad=False).to_global( placement=placement, sbp=flow.sbp.broadcast() ) class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids, data): ( num_unique_matrix, inverse_unique_partition_indices, _, cur_rank_unique_ids, _, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, "test") unique_embeddings = flow._C.gather(data, cur_rank_unique_ids, axis=0) embeddings = flow._C.one_embedding_embedding_shuffle( unique_embeddings, flow._C.identity(num_unique_matrix), flow._C.identity(cur_rank_inverse_indices), flow._C.identity(inverse_unique_partition_indices), "test", ) return embeddings graph = TestGraph() for i in range(10): ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables) graph(ids_tensor, table_ids_tensor, data_tensor) embeddings = graph(ids_tensor, table_ids_tensor, data_tensor) global_ids = ids_tensor.numpy() global_data = data_tensor.numpy() np_embeddings = global_data[global_ids] # Quantized numpy embedding. 
if enable_quantized_comm: np_embeddings = embedding_shuffle_quantize(np_embeddings, np_dtype) test_case.assertTrue(np.array_equal(embeddings.numpy(), np_embeddings)) def _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size): np_tolerance = 0 batch_size = int(1024 / parallel_num) placement = flow.placement(type="cuda", ranks=list(range(parallel_num))) num_tables = 26 enable_quantized_comm = enable_quantize and embedding_size < 1025 if enable_quantized_comm: np_tolerance = 0.5 os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" else: if fp16: np_tolerance = 1e-2 else: np_tolerance = 1e-4 os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" embedding_grad = np.random.rand(batch_size, num_tables, embedding_size).astype( np.float32 ) embedding_grad_tensor = flow.tensor(embedding_grad, requires_grad=False).to_global( placement=placement, sbp=flow.sbp.split(0) ) class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids, embedding_grad): ( num_unique_matrix, inverse_unique_partition_indices, cur_rank_num_unique, cur_rank_unique_ids, _, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, "test") if fp16: embedding_grad = flow.cast(embedding_grad, flow.float16) cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle( embedding_grad, num_unique_matrix, cur_rank_inverse_indices, inverse_unique_partition_indices, "test", ) if fp16: cur_rank_unique_embedding_grad = flow.cast( cur_rank_unique_embedding_grad, flow.float32 ) return ( cur_rank_unique_embedding_grad, flow.cast(cur_rank_num_unique, flow.int32), cur_rank_unique_ids, ) graph = TestGraph() for i in range(10): ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables) graph(ids_tensor, table_ids_tensor, embedding_grad_tensor) ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables) ( cur_rank_unique_embedding_grad, local_cur_rank_num_unique, cur_rank_unique_ids, ) = graph(ids_tensor, table_ids_tensor, embedding_grad_tensor) cur_rank_num_unique = local_cur_rank_num_unique.to_local().to_global( placement=placement, sbp=flow.sbp.split(0) ) global_ids = ids_tensor.numpy() global_embedding_grad = embedding_grad_tensor.numpy() np_unique_ids = np.unique(global_ids) np_num_unique = np_unique_ids.size np_cur_rank_unique_embedding_grad = np.zeros((max_id, embedding_size)) if fp16: global_embedding_grad = global_embedding_grad.astype(np.float16) for k in range(np_num_unique): unique_id = np_unique_ids[k] np_data = sum( global_embedding_grad.reshape(-1, embedding_size)[ np.where(global_ids.flatten() == unique_id)[0] ] ) # Quantize Embedding Gradient. 
if enable_quantized_comm: abs_max_factor = np.max(np.abs(np_data)) int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32) quantize_factor = int8_factor / abs_max_factor np_data = np_data * quantize_factor np_data = round_half_away_from_zero(np_data) np_data = np_data.astype(np.int8) np_data = np_data.astype(np.float32) dequantize_factor = abs_max_factor / int8_factor np_data = np_data * dequantize_factor np_cur_rank_unique_embedding_grad[unique_id, :] = np_data if fp16: np_cur_rank_unique_embedding_grad = np_cur_rank_unique_embedding_grad.astype( np.float32 ) cur_rank_num_ids = batch_size * num_tables * parallel_num of_unique_embedding_grad = np.zeros((max_id, embedding_size)) for i in range(parallel_num): num_unique_i = cur_rank_num_unique.numpy()[i] unique_ids_i = cur_rank_unique_ids.numpy()[ cur_rank_num_ids * i : cur_rank_num_ids * (i + 1) ] unique_embedding_grad_i = cur_rank_unique_embedding_grad.numpy()[ cur_rank_num_ids * i : cur_rank_num_ids * (i + 1) ] for j in range(num_unique_i): unique_id = unique_ids_i[j] of_unique_embedding_grad[unique_id, :] = unique_embedding_grad_i[j, :] test_case.assertTrue( np.allclose( of_unique_embedding_grad, np_cur_rank_unique_embedding_grad, atol=np_tolerance, rtol=np_tolerance, ), ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class DataShuffleTestCase(flow.unittest.TestCase): def test_id_shuffle(test_case): arg_dict = OrderedDict() arg_dict["has_table_id"] = [True, False] arg_dict["num_tables"] = [1, 26] for kwargs in GenArgDict(arg_dict): _test_id_shuffle(test_case, **kwargs) def test_embedding_shuffle(test_case): arg_dict = OrderedDict() arg_dict["dtype"] = [flow.float32, flow.float16] arg_dict["enable_quantize"] = [True, False] for kwargs in GenArgDict(arg_dict): _test_embedding_shuffle(test_case, **kwargs) def test_embedding_gradient_shuffle(test_case): arg_dict = OrderedDict() arg_dict["enable_quantize"] = [True, False] arg_dict["fp16"] = [True, False] arg_dict["embedding_size"] = [128, 17] for kwargs in GenArgDict(arg_dict): _test_embedding_gradient_shuffle(test_case, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_layernorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList input_arr = np.array( [ [ [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]], [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]], ], [ [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]], [[1.07541728, 0.11008703], [0.26361224, -0.48663723]], ], ], dtype=np.float32, ) def _test_layernorm(test_case, device): output = np.array( [ [ [[-0.0544118, -1.0509688], [-0.2696846, 0.4295622]], [[-1.2834904, -0.4838651], [2.0891891, 0.6236691]], ], [ [[-0.8555527, -0.3554582], [0.493019, -1.694826]], [[1.8035311, 0.4155158], [0.6362644, -0.4424936]], ], ], dtype=np.float32, ) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) m = flow.nn.LayerNorm(x.size()[1:]).to(device=flow.device(device)) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05)) def _test_layernorm_v2(test_case, device): output = np.array( [ [ [[0.3406544, -1.5249983], [-0.0623574, 1.2467014]], [[-1.2004623, -0.5688803], [1.4634399, 0.3059027]], ], [ [[-0.3180245, 0.3122248], [1.3815271, -1.3757277]], [[1.497291, -0.2341234], [0.0412391, -1.3044068]], ], ], dtype=np.float32, ) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) m = flow.nn.LayerNorm([2, 2], eps=1e-05).to(device=flow.device(device)) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05)) def _test_layernorm_v3(test_case, device): output = np.array( [ [ [[0.999974, -0.999974], [-0.999947, 0.999947]], [[-0.9999595, 0.9999595], [0.999988, -0.999988]], ], [ [[-0.9998344, 0.9998341], [0.9999914, -0.9999914]], [[0.9999787, -0.9999787], [0.9999645, -0.9999645]], ], ], dtype=np.float32, ) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) m = flow.nn.LayerNorm(2, elementwise_affine=True).to(device=flow.device(device)) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05)) def _test_layernorm_backward(test_case, device): output = np.array( [ [ [[-0.0544118, -1.0509688], [-0.2696846, 0.4295622]], [[-1.2834904, -0.4838651], [2.0891891, 0.6236691]], ], [ [[-0.8555527, -0.3554582], [0.493019, -1.694826]], [[1.8035311, 0.4155158], [0.6362644, -0.4424936]], ], ], dtype=np.float32, ) x = flow.tensor( input_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m = flow.nn.LayerNorm(x.size()[1:]).to(device=flow.device(device)) y = m(x) z = y.sum() z.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLayerNorm(flow.unittest.TestCase): def test_layernorm(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_layernorm, _test_layernorm_v2, _test_layernorm_v3, _test_layernorm_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=20, auto_backward=True, rtol=1e-3, atol=1e-3) def test_layernorm_with_random_data_warp(test_case): device = "cuda" channel = random(1, 32).to(int) height = random(1, 2).to(int) width = random(1, 1024).to(int) def get_random_norm_shape(): begin_axis = random(1, 3).to(int).value() return tuple((channel.value(), height.value(), width.value())[begin_axis:]) m = torch.nn.LayerNorm( normalized_shape=get_random_norm_shape(), 
elementwise_affine=random().to(bool), ).to(device) x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device) y = m(x) return y @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3) def test_layernorm_with_random_data_shared_mem(test_case): device = "cuda" channel = random(1, 32).to(int) height = random(1, 2).to(int) width = random(1024, 8192).to(int) def get_random_norm_shape(): begin_axis = random(1, 3).to(int).value() return tuple((channel.value(), height.value(), width.value())[begin_axis:]) m = torch.nn.LayerNorm( normalized_shape=get_random_norm_shape(), elementwise_affine=random().to(bool), ).to(device) x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device) y = m(x) return y @autotest(n=5, auto_backward=True, rtol=1e-3, atol=1e-3) def test_layernorm_with_random_data_uncached(test_case): device = "cuda" channel = random(1, 32).to(int) height = random(1, 2).to(int) width = random(8192, 32768).to(int) def get_random_norm_shape(): begin_axis = random(1, 3).to(int).value() return tuple((channel.value(), height.value(), width.value())[begin_axis:]) m = torch.nn.LayerNorm( normalized_shape=get_random_norm_shape(), elementwise_affine=random().to(bool), ).to(device) x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device) y = m(x) return y @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3) def test_layernorm_without_affine(test_case): device = random_device() channel = random(1, 32).to(int) height = random(1, 2).to(int) width = random(8192, 32768).to(int) def get_random_norm_shape(): begin_axis = random(1, 3).to(int).value() return tuple((channel.value(), height.value(), width.value())[begin_axis:]) m = torch.nn.LayerNorm(normalized_shape=get_random_norm_shape()).to(device) x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device) y = m(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_oneembedding.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow import oneflow.nn as nn import tempfile import hashlib class OneEmbedding(nn.Module): def __init__( self, test_id, embedding_vec_size, persistent_path, table_size_array, size_factor, ): assert table_size_array is not None vocab_size = sum(table_size_array) scales = np.sqrt(1 / np.array(table_size_array)) tables = [ flow.one_embedding.make_table( flow.one_embedding.make_uniform_initializer(low=-scale, high=scale) ) for scale in scales ] store_options = flow.one_embedding.make_device_mem_store_options( persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor, ) super(OneEmbedding, self).__init__() self.one_embedding = flow.one_embedding.MultiTableEmbedding( f"oneembedding_{test_id}", embedding_dim=embedding_vec_size, dtype=flow.float, key_type=flow.int64, tables=tables, store_options=store_options, ) def forward(self, ids): return self.one_embedding.forward(ids) class TestModule(nn.Module): def __init__( self, test_id, embedding_vec_size, persistent_path, table_size_array, size_factor, ): super(TestModule, self).__init__() self.embedding = OneEmbedding( test_id, embedding_vec_size, persistent_path, table_size_array, size_factor ) self.mlp = nn.Linear(embedding_vec_size, 1) def forward(self, inputs) -> flow.Tensor: embedding = self.embedding(inputs) logits = self.mlp(embedding).mean(dim=1) return logits class TrainGraph(flow.nn.Graph): def __init__( self, module, loss, optimizer, amp=False, ): super(TrainGraph, self).__init__() self.module = module self.loss = loss self.add_optimizer(optimizer) if amp: self.config.enable_amp(True) def build(self, labels, features): logits = self.module(features.to("cuda")) loss = self.loss(logits, labels.to("cuda")) reduce_loss = flow.mean(loss) reduce_loss.backward() return reduce_loss.to("cpu") def _test_one_embedding( test_case, batch_size, table_size_array, embedding_size, test_opt ): test_str = str([batch_size, table_size_array, embedding_size, test_opt]) test_hash = hashlib.sha256(test_str.encode("utf-8")).hexdigest() def np_to_global(np): t = flow.from_numpy(np) return t.to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.split(0)) with tempfile.TemporaryDirectory() as persistent_path: size_factor = 3 if test_opt == "Adam" else 1 module = TestModule( test_hash, embedding_size, persistent_path, table_size_array, size_factor ) module.to_global(flow.placement.all("cuda"), flow.sbp.broadcast) if test_opt == "Adam": opt = flow.optim.Adam(module.parameters(), lr=0.1) elif test_opt == "SGD": opt = flow.optim.SGD(module.parameters(), lr=0.1) else: assert False loss = flow.nn.BCEWithLogitsLoss(reduction="none").to("cuda") train_graph = TrainGraph(module, loss, opt) module.train() for step in range(1, 101): labels = np.random.randint(2, size=(batch_size, 1)).astype(np.float32) features = np.random.randint( sum(table_size_array), size=(batch_size, len(table_size_array)) ) labels = np_to_global(labels) features = np_to_global(features) loss = train_graph(labels, features) test_case.assertFalse(np.isnan(loss.numpy())) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class OneEmbeddingTestCase(flow.unittest.TestCase): def test_one_embedding(test_case): arg_dict = OrderedDict() arg_dict["batch_size"] = [32, 4096] arg_dict["table_size_array"] = [ [32, 65536, 100, 7], [32768, 10000, 17, 3, 686], ] arg_dict["embedding_size"] = [128, 17] arg_dict["test_opt"] 
= ["SGD", "Adam"] for kwargs in GenArgDict(arg_dict): _test_one_embedding(test_case, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_oneembedding_padding_idx.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow import oneflow.nn as nn import tempfile import hashlib import random class OneEmbedding(nn.Module): def __init__( self, test_id, embedding_vec_size, persistent_path, table_size_array, size_factor, padding_idx, ): assert table_size_array is not None vocab_size = sum(table_size_array) scales = np.sqrt(1 / np.array(table_size_array)) tables = [ flow.one_embedding.make_table( flow.one_embedding.make_uniform_initializer(low=-scale, high=scale) ) for scale in scales ] store_options = flow.one_embedding.make_device_mem_store_options( persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor, ) super(OneEmbedding, self).__init__() self.one_embedding = flow.one_embedding.MultiTableEmbedding( f"oneembedding_{test_id}", embedding_dim=embedding_vec_size, dtype=flow.float, key_type=flow.int64, tables=tables, store_options=store_options, padding_idx=padding_idx, ) def forward(self, ids): return self.one_embedding.forward(ids) class TestModule(nn.Module): def __init__( self, test_id, embedding_vec_size, persistent_path, table_size_array, size_factor, padding_idx, ): super(TestModule, self).__init__() self.embedding = OneEmbedding( test_id, embedding_vec_size, persistent_path, table_size_array, size_factor, padding_idx=padding_idx, ) def forward(self, inputs) -> flow.Tensor: embedding = self.embedding(inputs) return embedding class TrainGraph(flow.nn.Graph): def __init__( self, module, loss, optimizer, amp=False, ): super(TrainGraph, self).__init__() self.module = module self.loss = loss self.add_optimizer(optimizer) if amp: self.config.enable_amp(True) def build(self, labels, features): embedding = self.module(features.to("cuda")) reduce_loss = flow.mean(embedding) reduce_loss.backward() return embedding.to("cpu") def _test_one_embedding_padding_idx( test_case, batch_size, table_size_array, embedding_size, test_opt, padding_idx ): test_str = str([batch_size, table_size_array, embedding_size, test_opt]) test_hash = hashlib.sha256(test_str.encode("utf-8")).hexdigest() def np_to_global(np): t = flow.from_numpy(np) return t.to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.split(0)) with tempfile.TemporaryDirectory() as persistent_path: size_factor = 3 if test_opt == "Adam" else 1 module = TestModule( test_hash, embedding_size, persistent_path, table_size_array, size_factor, padding_idx, ) module.to_global(flow.placement.all("cuda"), flow.sbp.broadcast) if test_opt == "Adam": opt = flow.optim.Adam(module.parameters(), lr=0.1) elif test_opt == "SGD": opt = 
flow.optim.SGD(module.parameters(), lr=0.1) else: assert False loss = flow.nn.BCEWithLogitsLoss(reduction="none").to("cuda") train_graph = TrainGraph(module, loss, opt) module.train() padding_num = random.randint(0, batch_size - 1) labels = np.random.randint(2, size=(batch_size, 1)).astype(np.float32) padding_feature = np.full( (len(table_size_array)), fill_value=padding_idx ).astype(np.int64) features = np.random.randint( sum(table_size_array), size=(batch_size, len(table_size_array)) ) padding_feature_idx = np.random.randint(batch_size, size=(padding_num,)) for i in range(padding_num): idx = int(padding_feature_idx[i]) features[idx] = padding_feature labels = np_to_global(labels) features = np_to_global(features) embedding_val = train_graph(labels, features) for i in range(padding_feature_idx.size): idx = int(padding_feature_idx[i]) test_case.assertTrue( np.array_equal( embedding_val[idx].numpy(), np.zeros((len(table_size_array), embedding_size), dtype=np.float32), ) ) # Infer again to check the embedding in padding_idx is not updated. embedding_val = train_graph(labels, features) for i in range(padding_feature_idx.size): idx = int(padding_feature_idx[i]) test_case.assertTrue( np.array_equal( embedding_val[idx].numpy(), np.zeros((len(table_size_array), embedding_size), dtype=np.float32), ) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class OneEmbeddingWithPaddingIdxTestCase(flow.unittest.TestCase): def test_one_embedding_padding_idx(test_case): arg_dict = OrderedDict() arg_dict["batch_size"] = [32] arg_dict["table_size_array"] = [ [32, 64, 32, 32], ] arg_dict["embedding_size"] = [12] arg_dict["test_opt"] = ["SGD"] arg_dict["padding_idx"] = [2] os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "300" for kwargs in GenArgDict(arg_dict): _test_one_embedding_padding_idx(test_case, **kwargs) os.environ["ONEFLOW_TIMEOUT_SECONDS"] = "90" if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_permute.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict from random import shuffle import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest def _test_permute_impl(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out1 = flow.permute(input, (1, 0, 2, 3)) np_out = input.numpy().transpose((1, 0, 2, 3)) test_case.assertTrue(np.array_equal(of_out1.numpy().flatten(), np_out.flatten())) of_out = of_out1.sum() of_out.backward() np_grad = np.ones((2, 6, 5, 3)) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 0.0001, 0.0001)) def _test_tensor_permute_impl(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out1 = input.permute(1, 0, 2, 3) of_out2 = input.permute(*(1, 0, 2, 3)) of_out3 = input.permute((1, 0, 2, 3)) of_out4 = input.permute([1, 0, 2, 3]) np_out = input.numpy().transpose((1, 0, 2, 3)) test_case.assertTrue(np.array_equal(of_out1.numpy().flatten(), np_out.flatten())) test_case.assertTrue(np.array_equal(of_out2.numpy().flatten(), np_out.flatten())) test_case.assertTrue(np.array_equal(of_out3.numpy().flatten(), np_out.flatten())) test_case.assertTrue(np.array_equal(of_out4.numpy().flatten(), np_out.flatten())) of_out = of_out1.sum() of_out.backward() np_grad = np.ones((2, 6, 5, 3)) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestPermute(flow.unittest.TestCase): def test_permute(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_permute_impl(test_case, *arg) _test_tensor_permute_impl(test_case, *arg) @autotest(n=10, check_graph=False) def test_torch_permute4d_with_random_data(test_case): device = random_device() ndim = 4 permute_list = [0, 1, 2, 3] shuffle(permute_list) x = random_tensor(ndim=ndim, dim0=random().to(int)).to(device) y = torch.permute(x, dims=permute_list) return y @unittest.skip("pytorch 1.9.0 exist not torch.permute api") @autotest(n=10) def test_torch_permute4d_with_random_0dim_data(test_case): device = random_device() permute_list = [0, 1, 2, 3] shuffle(permute_list) x = random_tensor(ndim=0).to(device) y = torch.permute(x, dims=permute_list) return y @autotest(n=10, check_graph=True) def test_permute5d_tensor_with_random_data(test_case): device = random_device() ndim = 5 permute_list = [0, 1, 2, 3, 4] shuffle(permute_list) x = random_tensor( ndim=ndim, dim0=random(1, 16).to(int), dim1=random(1, 33).to(int), dim2=random(1, 64).to(int), dim3=random(45, 67).to(int), dim4=random(1, 64).to(int), ).to(device) y = x.permute(permute_list) return y @autotest(n=10, check_graph=True) def test_permute4d_tensor_with_random_data(test_case): device = random_device() ndim = 4 permute_list = [0, 1, 2, 3] shuffle(permute_list) x = random_tensor( ndim=ndim, dim0=random(1, 7).to(int), dim1=random(1, 15).to(int), dim2=random(1, 9).to(int), dim3=random(1, 19).to(int), ).to(device) y = x.permute(permute_list) return y @autotest(n=10, check_graph=True) def test_permute4d_tensor_with_stride(test_case): device = random_device() ndim = 4 permute_list1 = [0, 1, 2, 3] shuffle(permute_list1) x = random_tensor( ndim=ndim, dim0=random(1, 7).to(int), dim1=random(1, 15).to(int), dim2=random(1, 9).to(int), dim3=random(1, 19).to(int), ).to(device) y = 
x.permute(permute_list1) permute_list2 = [0, 1, 2, 3] shuffle(permute_list2) z = y.permute(permute_list2) return z @autotest(n=5, check_graph=True) def test_permute3d_tensor_with_random_data(test_case): device = random_device() ndim = 3 permute_list = [0, 1, 2] shuffle(permute_list) x = random_tensor( ndim=ndim, dim0=random(1, 18).to(int), dim1=random(1, 78).to(int), dim2=random(1, 99).to(int), ).to(device) y = x.permute(permute_list) return y @autotest(n=10, auto_backward=False, check_graph=True) def test_permute4d_tensor_bool_with_random_data(test_case): device = random_device() ndim = 4 permute_list = [0, 1, 2, 3] shuffle(permute_list) x = random_tensor( ndim=ndim, dim0=random(1, 7).to(int), dim1=random(1, 15).to(int), dim2=random(1, 9).to(int), dim3=random(1, 19).to(int), ).to(device=device, dtype=torch.bool) y = x.permute(permute_list) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_remat.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import subprocess import sys import os import unittest import oneflow as flow import oneflow.unittest class TestRemat(flow.unittest.TestCase): def test_remat_in_single_threaded_vm(test_case): env = os.environ.copy() env["ONEFLOW_VM_MULTI_THREAD"] = "0" p = subprocess.run( [sys.executable, "_test_remat.py"], cwd=os.path.dirname(os.path.realpath(__file__)), env=env, ) test_case.assertEqual(p.returncode, 0) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_resnet50_with_bn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from resnet50_model import resnet50 import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestResNet50(flow.unittest.TestCase): def test_resnet50_with_batchnorm(test_case): batch_size = 32 color_space = "RGB" height = 224 width = 224 output_layout = "NCHW" rgb_mean = [123.68, 116.779, 103.939] rgb_std = [58.393, 57.12, 57.375] record_reader = flow.nn.OFRecordReader( flow.unittest.dataset_dir("imagenette/ofrecord"), batch_size=batch_size, data_part_num=1, part_name_suffix_length=5, shuffle_after_epoch=False, ) record_image_decoder = flow.nn.OFRecordImageDecoder( "encoded", color_space=color_space ) record_label_decoder = flow.nn.OFRecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) resize = flow.nn.image.Resize( resize_side="shorter", keep_aspect_ratio=True, target_size=256 ) crop_mirror_normal = flow.nn.CropMirrorNormalize( color_space=color_space, output_layout=output_layout, crop_h=height, crop_w=width, crop_pos_y=0.5, crop_pos_x=0.5, mean=rgb_mean, std=rgb_std, output_dtype=flow.float, ) res50_module = resnet50( replace_stride_with_dilation=[False, False, False], norm_layer=flow.nn.BatchNorm2d, ) res50_module.train() res50_module.load_state_dict( flow.load(flow.unittest.dataset_dir("imagenette/resnet50_models")) ) of_corss_entropy = flow.nn.CrossEntropyLoss() res50_module.to("cuda") of_corss_entropy.to("cuda") learning_rate = 0.001 mom = 0.9 of_sgd = flow.optim.SGD( res50_module.parameters(), lr=learning_rate, momentum=mom ) errors = 0.0 for b in range(100): val_record = record_reader() label = record_label_decoder(val_record) image_raw_buffer = record_image_decoder(val_record) image = resize(image_raw_buffer)[0] image = crop_mirror_normal(image) image = image.to("cuda") label = label.to("cuda") logits = res50_module(image) loss = of_corss_entropy(logits, label) loss.backward() of_sgd.step() of_sgd.zero_grad() l = loss.numpy() test_case.assertTrue(l < 3.5) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_resnet50_without_bn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np from resnet50_model import FakeBN, resnet50 import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestResNet50(flow.unittest.TestCase): def test_resnet50_without_batchnorm(test_case): batch_size = 32 color_space = "RGB" height = 224 width = 224 output_layout = "NCHW" rgb_mean = [123.68, 116.779, 103.939] rgb_std = [58.393, 57.12, 57.375] record_reader = flow.nn.OFRecordReader( flow.unittest.dataset_dir("imagenette/ofrecord"), batch_size=batch_size, data_part_num=1, part_name_suffix_length=5, shuffle_after_epoch=False, ) record_image_decoder = flow.nn.OFRecordImageDecoder( "encoded", color_space=color_space ) record_label_decoder = flow.nn.OFRecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) resize = flow.nn.image.Resize( resize_side="shorter", keep_aspect_ratio=True, target_size=256 ) crop_mirror_normal = flow.nn.CropMirrorNormalize( color_space=color_space, output_layout=output_layout, crop_h=height, crop_w=width, crop_pos_y=0.5, crop_pos_x=0.5, mean=rgb_mean, std=rgb_std, output_dtype=flow.float, ) res50_module = resnet50( replace_stride_with_dilation=[False, False, False], norm_layer=FakeBN ) res50_module.train() res50_module.load_state_dict( flow.load(flow.unittest.dataset_dir("resnet50_wo_bn_weights_for_ci")) ) of_corss_entropy = flow.nn.CrossEntropyLoss() res50_module.to("cuda") of_corss_entropy.to("cuda") learning_rate = 0.001 mom = 0.9 of_sgd = flow.optim.SGD( res50_module.parameters(), lr=learning_rate, momentum=mom ) gt_of_losses = [ 49.83235168457031, 36.34172821044922, 23.585250854492188, 15.628865242004395, 9.552209854125977, 8.11514663696289, 6.364114284515381, 6.442500114440918, 4.439807891845703, 4.024901866912842, 4.7038373947143555, 4.253284454345703, 4.5806169509887695, 4.158677577972412, 3.0066077709198, 4.611920356750488, 4.46696138381958, 2.9725658893585205, 3.2383458614349365, 3.605447292327881, 3.8676259517669678, 3.2477705478668213, 2.9191272258758545, 3.162745475769043, 3.0127673149108887, 2.615905284881592, 2.7866411209106445, 3.471228837966919, 2.9467897415161133, 3.3623316287994385, ] for b in range(len(gt_of_losses)): val_record = record_reader() label = record_label_decoder(val_record) image_raw_buffer = record_image_decoder(val_record) image = resize(image_raw_buffer)[0] image = crop_mirror_normal(image) image = image.to("cuda") label = label.to("cuda") logits = res50_module(image) loss = of_corss_entropy(logits, label) loss.backward() of_sgd.step() of_sgd.zero_grad() l = loss.numpy() test_case.assertTrue( np.allclose(l.item(), gt_of_losses[b], rtol=1e-2, atol=1e-3) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_rnn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestRNNModules(flow.unittest.TestCase): @autotest(n=5, check_graph=True, rtol=1e-2, atol=1e-3) def test_rnn(test_case): device = random_device() batch_size = random(1, 6) time_steps = random(1, 6) num_layers = random(1, 6).to(int) input_size = random(2, 6).to(int) hidden_size = random(2, 6).to(int) m = torch.nn.RNN( input_size, hidden_size, num_layers=num_layers, nonlinearity="tanh", bias=random().to(bool), batch_first=random().to(bool), dropout=0, bidirectional=random().to(bool), ).to(device) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to(device) out = m(input) return out[0] @autotest(n=5, check_graph=True, rtol=1e-2) def test_lstm(test_case): device = random_device() batch_size = random(1, 6) time_steps = random(1, 6) num_layers = random(1, 6).to(int) input_size = random(2, 6).to(int) hidden_size = random(2, 6).to(int) proj_size = random(2, 6).to(int) m = torch.nn.LSTM( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=random().to(bool), batch_first=random().to(bool), dropout=0, bidirectional=random().to(bool), proj_size=proj_size, ).to(device) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to(device) out = m(input) return out[0] @autotest(n=5, check_graph=True, rtol=1e-2) def test_gru(test_case): device = random_device() batch_size = random(1, 6) time_steps = random(1, 6) num_layers = random(1, 6).to(int) input_size = random(2, 6).to(int) hidden_size = random(2, 6).to(int) m = torch.nn.GRU( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=random().to(bool), batch_first=random().to(bool), dropout=0, bidirectional=random().to(bool), ).to(device) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to(device) out = m(input) return out[0] if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_rnn_cell.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestRNN(flow.unittest.TestCase): @autotest(n=5, check_graph=True, rtol=1e-2, atol=1e-3) def test_rnn_tanh_cell(test_case): device = random_device() batch_size = random(1, 6) time_steps = random(1, 6) input_size = random(1, 6) * 2 hidden_size = random(1, 6) * 2 m = torch.nn.RNNCell( input_size=input_size, hidden_size=hidden_size, bias=random().to(bool), nonlinearity="tanh", ).to(device) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to(device) hx = random_tensor(ndim=2, dim0=batch_size, dim1=hidden_size).to(device) for i in range(time_steps.to(int).value()): hx = m(input[i], hx) return hx @autotest(n=5, check_graph=True) def test_rnn_relu_cell(test_case): device = random_device() batch_size = random(1, 6) time_steps = random(1, 6) input_size = random(1, 6) * 2 hidden_size = random(1, 6) * 2 m = torch.nn.RNNCell( input_size=input_size, hidden_size=hidden_size, bias=random().to(bool), nonlinearity="relu", ).to(device) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to(device) hx = random_tensor(ndim=2, dim0=batch_size, dim1=hidden_size).to(device) for i in range(time_steps.to(int).value()): hx = m(input[i], hx) return hx @unittest.skip("skip for now, becase it failed 4 times in past week") @autotest(n=5, check_graph=True, rtol=1e-2) def test_lstm_cell(test_case): device = random_device() batch_size = random(1, 6) time_steps = random(1, 6) input_size = random(1, 6) * 2 hidden_size = random(1, 6) * 2 has_bias = random().to(bool) cx_requires_grad = random().to(bool) m = torch.nn.LSTMCell( input_size=input_size, hidden_size=hidden_size, bias=has_bias, ).to(device) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to(device) hx = random_tensor( ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False ).to(device) cx = random_tensor( ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=cx_requires_grad ).to(device) for i in range(time_steps.to(int).value()): res = m(input[i], (hx, cx)) hx = res[0] cx = res[1] return res[0] @autotest(n=5, check_graph=True, rtol=1e-2) def test_gru_cell(test_case): device = random_device() batch_size = random(1, 6) time_steps = random(1, 6) input_size = random(1, 6) * 2 hidden_size = random(1, 6) * 2 has_bias = random().to(bool) m = torch.nn.GRUCell( input_size=input_size, hidden_size=hidden_size, bias=has_bias ).to(device) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to(device) hx = random_tensor(ndim=2, dim0=batch_size, dim1=hidden_size).to(device) for i in range(time_steps.to(int).value()): hx = m(input[i], hx) return hx if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_rnn_pack_sequence.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import random import numpy as np from collections import OrderedDict import torch import torch.nn.utils.rnn as torch_rnn_utils import oneflow as flow import oneflow.nn.utils.rnn as flow_rnn_utils import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_rnn_pack_sequence(test_case, device): l = ["tanh", "relu"] input_size = random.randint(10, 1000) hidden_size = random.randint(10, 1000) num_layers = random.randint(1, 6) nonlinearity = l[0 if num_layers <= 3 else 1] grad_tol = 1e-4 if nonlinearity == "relu": grad_tol = 100 bias = random.randint(-10, 10) <= 0 batch_first = False dropout = 0 bidirectional = random.randint(-10, 10) <= 0 rnn_torch = torch.nn.RNN( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, nonlinearity=nonlinearity, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, ) rnn_flow = flow.nn.RNN( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, nonlinearity=nonlinearity, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, ) torch_state_dict = rnn_torch.state_dict() new_dict = {} for k, v in torch_state_dict.items(): new_dict[k] = v.detach().numpy() rnn_flow.load_state_dict(new_dict) rnn_flow = rnn_flow.to(device) rnn_torch = rnn_torch.to(device) max_seq_len = random.randint(10, 50) batch_size = random.randint(10, 50) lengths = [] lengths.append(max_seq_len) for i in range(batch_size - 1): lengths.append(random.randint(1, max_seq_len)) lengths.sort(reverse=True) sequences = [] for i in range(batch_size): sequences.append(flow.rand(lengths[i], input_size).to(device)) x_flow = flow_rnn_utils.pack_sequence(sequences) torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences] x_torch = torch_rnn_utils.pack_sequence(torch_inputs) out_torch, hid_torch = rnn_torch(x_torch) out_flow, hid_flow = rnn_flow(x_flow) z_torch = out_torch.data.sum() z_torch.backward() z_flow = out_flow.data.sum() z_flow.backward() test_case.assertTrue( np.allclose( out_torch.data.cpu().detach().numpy(), out_flow.data.cpu().detach().numpy(), atol=1e-5, ) ) test_case.assertTrue( np.allclose( hid_torch.cpu().detach().numpy(), hid_flow.cpu().detach().numpy(), atol=1e-5, ) ) all_weights = rnn_torch.all_weights torch_params = [] for ls in all_weights: for l in ls: torch_params.append(l) all_weights = rnn_flow.all_weights flow_params = [] for ls in all_weights: for l in ls: flow_params.append(l) for i in range(len(flow_params)): torch_np = torch_params[i].grad.cpu().numpy() flow_np = flow_params[i].grad.cpu().numpy() test_case.assertTrue(np.allclose(torch_np, flow_np, atol=grad_tol)) def _test_lstm_pack_sequence(test_case, device): input_size = random.randint(10, 1000) hidden_size = random.randint(12, 1000) num_layers = random.randint(1, 6) bias = random.randint(-10, 10) <= 0 batch_first = False dropout = 0 bidirectional = random.randint(-10, 10) <= 0 proj_size = random.randint(0, hidden_size - 1) lstm_torch = torch.nn.LSTM( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, proj_size=proj_size, ) lstm_flow = flow.nn.LSTM( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, proj_size=proj_size, ) torch_state_dict = lstm_torch.state_dict() new_dict = {} 
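# NOTE (editorial sketch, not part of the original test file): the three
# helpers in this module repeat one recipe -- build a torch RNN and a flow RNN
# with the same hyper-parameters, copy the torch weights over through numpy,
# pack a batch of variable-length sequences, and compare the packed outputs.
# The uncalled function below is a minimal, self-contained version of that
# recipe with a tiny assumed configuration (input_size=4, hidden_size=8, two
# sequences of lengths 5 and 3).
def _sketch_rnn_pack_sequence_parity():
    import numpy as np
    import torch
    import torch.nn.utils.rnn as torch_rnn_utils

    import oneflow as flow
    import oneflow.nn.utils.rnn as flow_rnn_utils

    rnn_torch = torch.nn.RNN(input_size=4, hidden_size=8, num_layers=1)
    rnn_flow = flow.nn.RNN(input_size=4, hidden_size=8, num_layers=1)
    # flow modules accept numpy arrays in load_state_dict, so the torch
    # parameters can be copied over without sharing storage.
    rnn_flow.load_state_dict(
        {k: v.detach().numpy() for k, v in rnn_torch.state_dict().items()}
    )

    # pack_sequence expects sequences sorted by decreasing length.
    seqs_flow = [flow.rand(5, 4), flow.rand(3, 4)]
    seqs_torch = [torch.tensor(s.numpy()) for s in seqs_flow]
    packed_flow = flow_rnn_utils.pack_sequence(seqs_flow)
    packed_torch = torch_rnn_utils.pack_sequence(seqs_torch)

    out_flow, _ = rnn_flow(packed_flow)
    out_torch, _ = rnn_torch(packed_torch)
    assert np.allclose(
        out_torch.data.detach().numpy(), out_flow.data.detach().numpy(), atol=1e-5
    )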
for k, v in torch_state_dict.items(): new_dict[k] = v.detach().numpy() lstm_flow.load_state_dict(new_dict) lstm_flow = lstm_flow.to(device) lstm_torch = lstm_torch.to(device) max_seq_len = random.randint(10, 50) batch_size = random.randint(10, 50) lengths = [] lengths.append(max_seq_len) for i in range(batch_size - 1): lengths.append(random.randint(1, max_seq_len)) lengths.sort(reverse=True) sequences = [] for i in range(batch_size): sequences.append(flow.rand(lengths[i], input_size).to(device)) x_flow = flow_rnn_utils.pack_sequence(sequences) torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences] x_torch = torch_rnn_utils.pack_sequence(torch_inputs) out_torch, hid_torch = lstm_torch(x_torch) out_flow, hid_flow = lstm_flow(x_flow) z_torch = out_torch.data.sum() z_torch.backward() z_flow = out_flow.data.sum() z_flow.backward() test_case.assertTrue( np.allclose( out_torch.data.cpu().detach().numpy(), out_flow.data.cpu().detach().numpy(), atol=1e-5, ) ) test_case.assertTrue( np.allclose( hid_torch[0].cpu().detach().numpy(), hid_flow[0].cpu().detach().numpy(), atol=1e-5, ) ) test_case.assertTrue( np.allclose( hid_torch[1].cpu().detach().numpy(), hid_flow[1].cpu().detach().numpy(), atol=1e-5, ) ) all_weights = lstm_torch.all_weights torch_params = [] for ls in all_weights: for l in ls: torch_params.append(l) all_weights = lstm_flow.all_weights flow_params = [] for ls in all_weights: for l in ls: flow_params.append(l) for i in range(len(flow_params)): torch_np = torch_params[i].grad.cpu().numpy() flow_np = flow_params[i].grad.cpu().numpy() test_case.assertTrue(np.allclose(torch_np, flow_np, atol=1e-4)) def _test_gru_pack_sequence(test_case, device): input_size = random.randint(10, 1000) hidden_size = random.randint(10, 1000) num_layers = random.randint(1, 6) grad_tol = 1e-4 bias = random.randint(-10, 10) <= 0 batch_first = False dropout = 0 bidirectional = random.randint(-10, 10) <= 0 gru_torch = torch.nn.GRU( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, ) gru_flow = flow.nn.GRU( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, ) torch_state_dict = gru_torch.state_dict() new_dict = {} for k, v in torch_state_dict.items(): new_dict[k] = v.detach().numpy() gru_flow.load_state_dict(new_dict) gru_flow = gru_flow.to(device) gru_torch = gru_torch.to(device) max_seq_len = random.randint(10, 50) batch_size = random.randint(10, 50) lengths = [] lengths.append(max_seq_len) for i in range(batch_size - 1): lengths.append(random.randint(1, max_seq_len)) lengths.sort(reverse=True) sequences = [] for i in range(batch_size): sequences.append(flow.rand(lengths[i], input_size).to(device)) x_flow = flow_rnn_utils.pack_sequence(sequences) torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences] x_torch = torch_rnn_utils.pack_sequence(torch_inputs) out_torch, hid_torch = gru_torch(x_torch) out_flow, hid_flow = gru_flow(x_flow) z_torch = out_torch.data.sum() z_torch.backward() z_flow = out_flow.data.sum() z_flow.backward() test_case.assertTrue( np.allclose( out_torch.data.cpu().detach().numpy(), out_flow.data.cpu().detach().numpy(), atol=1e-5, ) ) test_case.assertTrue( np.allclose( hid_torch.cpu().detach().numpy(), hid_flow.cpu().detach().numpy(), atol=1e-5, ) ) all_weights = gru_torch.all_weights torch_params = [] for ls in all_weights: for l in ls: 
torch_params.append(l) all_weights = gru_flow.all_weights flow_params = [] for ls in all_weights: for l in ls: flow_params.append(l) for i in range(len(flow_params)): torch_np = torch_params[i].grad.cpu().numpy() flow_np = flow_params[i].grad.cpu().numpy() test_case.assertTrue(np.allclose(torch_np, flow_np, atol=grad_tol)) @flow.unittest.skip_unless_1n1d() class TestRNNModules(flow.unittest.TestCase): def test_rnn(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_rnn_pack_sequence, _test_lstm_pack_sequence, _test_gru_pack_sequence, ] arg_dict["device"] = ["cuda", "cpu"] for i in range(5): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_rnn_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import random import numpy as np from collections import OrderedDict import torch import torch.nn.utils.rnn as torch_rnn_utils import oneflow as flow import oneflow.nn.utils.rnn as flow_rnn_utils import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_rnn_utils_pack_padded_sequence(test_case, device): input_size = random.randint(10, 150) max_seq_len = random.randint(10, 300) batch_size = random.randint(10, 300) requires_grad = np.random.rand() > 0.5 padded_inputs = np.zeros((max_seq_len, batch_size, input_size)) lengths = [] lengths.append(max_seq_len) for i in range(batch_size - 1): lengths.append(random.randint(1, max_seq_len)) lengths.sort(reverse=True) for i in range(batch_size): padded_inputs[0 : lengths[i], i : i + 1, :] = i + 1 inputs = flow.from_numpy(padded_inputs).to(device) inputs.requires_grad = requires_grad flow_res = flow_rnn_utils.pack_padded_sequence(inputs, lengths) torch_inputs = torch.from_numpy(padded_inputs).to(device) torch_inputs.requires_grad = requires_grad torch_res = torch_rnn_utils.pack_padded_sequence(torch_inputs, lengths) test_case.assertTrue( np.allclose( torch_res.batch_sizes.cpu().detach().numpy(), flow_res.batch_sizes.cpu().detach().numpy(), atol=1e-8, ) ) test_case.assertTrue( np.allclose( torch_res.data.cpu().detach().numpy(), flow_res.data.cpu().detach().numpy(), atol=1e-8, ) ) torch_seq_unpacked, torch_lens_unpacked = torch_rnn_utils.pad_packed_sequence( torch_res, batch_first=False ) flow_seq_unpacked, flow_lens_unpacked = flow_rnn_utils.pad_packed_sequence( flow_res, batch_first=False ) if requires_grad: torch_seq_unpacked.sum().backward() flow_seq_unpacked.sum().backward() test_case.assertTrue( np.allclose( torch_seq_unpacked.cpu().detach().numpy(), flow_seq_unpacked.cpu().detach().numpy(), atol=1e-8, ) ) test_case.assertTrue( np.allclose( torch_lens_unpacked.cpu().detach().numpy(), flow_lens_unpacked.cpu().detach().numpy(), atol=1e-8, ) ) if requires_grad: test_case.assertTrue( np.allclose(inputs.grad.cpu().numpy(), torch_inputs.grad.cpu().numpy()) ) def 
_test_rnn_utils_pad_sequence(test_case, device): input_size = random.randint(10, 150) max_seq_len = random.randint(20, 300) batch_size = random.randint(20, 300) lengths = [] lengths.append(max_seq_len) for i in range(batch_size - 1): lengths.append(random.randint(1, max_seq_len)) lengths.sort(reverse=True) sequences = [] for i in range(batch_size): sequences.append(flow.rand(lengths[i], input_size).to(device)) flow_res = flow_rnn_utils.pad_sequence(sequences) torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences] torch_res = torch_rnn_utils.pad_sequence(torch_inputs) test_case.assertTrue( np.allclose( torch_res.cpu().detach().numpy(), flow_res.cpu().detach().numpy(), atol=1e-8, ) ) def _test_rnn_utils_pack_sequence(test_case, device): input_size = random.randint(10, 150) max_seq_len = random.randint(20, 300) batch_size = random.randint(20, 300) lengths = [] lengths.append(max_seq_len) for i in range(batch_size - 1): lengths.append(random.randint(1, max_seq_len)) lengths.sort(reverse=True) sequences = [] for i in range(batch_size): sequences.append(flow.rand(lengths[i], input_size).to(device)) flow_res = flow_rnn_utils.pack_sequence(sequences) torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences] torch_res = torch_rnn_utils.pack_sequence(torch_inputs) test_case.assertTrue( np.allclose( torch_res.batch_sizes.cpu().detach().numpy(), flow_res.batch_sizes.cpu().detach().numpy(), atol=1e-8, ) ) test_case.assertTrue( np.allclose( torch_res.data.cpu().detach().numpy(), flow_res.data.cpu().detach().numpy(), atol=1e-8, ) ) @flow.unittest.skip_unless_1n1d() class TestRNNUtils(flow.unittest.TestCase): def test_rnn_utils_pack_padded_sequence(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for i in range(10): for arg in GenArgList(arg_dict): _test_rnn_utils_pack_padded_sequence(test_case, *arg[0:]) def test_rnn_utils_pad_sequence(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for i in range(10): for arg in GenArgList(arg_dict): _test_rnn_utils_pad_sequence(test_case, *arg[0:]) def test_rnn_utils_pack_sequence(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for i in range(10): for arg in GenArgList(arg_dict): _test_rnn_utils_pack_sequence(test_case, *arg[0:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_sqrt_square_sum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestLinalgVectorNorm2D(flow.unittest.TestCase): @autotest(n=2, auto_backward=False, check_graph=True, rtol=0.5, atol=0.5) def test_sqrt_sum_with_cpu_random_data(test_case): device = cpu_device() x = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5, requires_grad=False).to( device ) y = torch.linalg.norm(x) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=2, auto_backward=False, check_graph=True) def test_sqrt_sum_with_cuda_random_data(test_case): device = gpu_device() x = random_tensor(ndim=4, dim1=10, dim2=10, dim3=10, requires_grad=False).to( device ) y = torch.linalg.norm(x) return y @autotest(n=2, auto_backward=False, check_graph=True, rtol=0.5, atol=0.5) def test_scalar_print_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5, requires_grad=False).to( device ) y = torch.linalg.norm(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_tensor_offload.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.nn as nn import oneflow.unittest # NOTE(Li Xiang): This variable controls the mem comparison method of the tensor offload test. # 1: Strictly test, compare mem changes according to tensor size. # 2: Loose test, compare mem changes before and after offload; # 3: Execute only offload, skip mem check. 
offload_tensor_test_mem_mode = 3 def _test_tensor_offload_d2h(test_case, input, tensor_mem): print("\n- test offload cuda mem use") test_case.assertTrue(not input.is_offloaded()) before_used = flow._oneflow_internal.GetCUDAMemoryUsed() print(" - before ", before_used) before_id = id(input) input.offload() test_case.assertTrue(input.is_offloaded()) test_case.assertEqual(input.device, flow.device("cuda")) after_used = flow._oneflow_internal.GetCUDAMemoryUsed() after_id = id(input) print(" - after ", after_used) change_as_expected = (before_used - after_used) == tensor_mem # Check tensor_mem cuda memory released if offload_tensor_test_mem_mode == 1: test_case.assertTrue(change_as_expected) elif offload_tensor_test_mem_mode == 2: if tensor_mem != 0: test_case.assertTrue(before_used > after_used) print(" - tensor size ", tensor_mem) print(" - change ", after_used - before_used) print(" - change as expected ", change_as_expected) test_case.assertEqual(before_id, after_id) def _test_tensor_load_h2d(test_case, input, tensor_mem): print("\n- test load cuda mem use") test_case.assertTrue(input.is_offloaded()) before_used = flow._oneflow_internal.GetCUDAMemoryUsed() print(" - before ", before_used) before_id = id(input) input.load() test_case.assertTrue(not input.is_offloaded()) test_case.assertEqual(input.device, flow.device("cuda")) after_used = flow._oneflow_internal.GetCUDAMemoryUsed() after_id = id(input) print(" - after ", after_used) # Check tensor_mem cuda memory allocated change_as_expected = (after_used - before_used) == tensor_mem if offload_tensor_test_mem_mode == 1: test_case.assertTrue(change_as_expected) elif offload_tensor_test_mem_mode == 2: if tensor_mem != 0: test_case.assertTrue(after_used > before_used) print(" - tensor size ", tensor_mem) print(" - change ", after_used - before_used) print(" - change as expected ", change_as_expected) test_case.assertEqual(before_id, after_id) def _get_tensor_mem(input): if input.dim() == 0: return 2 cnt_size = input.element_size() * flow.numel(input) return cnt_size / 1024 / 1024 @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestTensorOffload(flow.unittest.TestCase): def test_tensor_offload_and_load_float32(test_case): flow.cuda.empty_cache() input = flow.tensor( np.random.randn(1024, 1024, 100), dtype=flow.float32, device=flow.device("cuda"), ) data = input.numpy() for i in range(3): input_tensor_mem = _get_tensor_mem(input) # test tensor offload _test_tensor_offload_d2h(test_case, input, input_tensor_mem) # data = input.numpy() will raise error here # test tensor load _test_tensor_load_h2d(test_case, input, input_tensor_mem) # test data after tensor load test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001)) def test_tensor_offload_and_load_float16(test_case): flow.cuda.empty_cache() input = flow.tensor( np.random.randn(20, 1024, 1024), dtype=flow.float16, device=flow.device("cuda"), ) data = input.numpy() for i in range(3): input_tensor_mem = _get_tensor_mem(input) # test tensor offload _test_tensor_offload_d2h(test_case, input, input_tensor_mem) # data = input.numpy() will raise error here # test tensor load _test_tensor_load_h2d(test_case, input, input_tensor_mem) # test data after tensor load test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001)) def test_tensor_offload_and_load_int64(test_case): flow.cuda.empty_cache() input = flow.tensor( np.random.randn(20, 1024, 1024), dtype=flow.int64, device=flow.device("cuda"), 
) data = input.numpy() for i in range(3): input_tensor_mem = _get_tensor_mem(input) # test tensor offload _test_tensor_offload_d2h(test_case, input, input_tensor_mem) # data = input.numpy() will raise error here # test tensor load _test_tensor_load_h2d(test_case, input, input_tensor_mem) # test data after tensor load test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001)) @unittest.skip("0 dim tensor is unstable in CI container mem tests.") def test_tensor_offload_and_load_0dim(test_case): flow.cuda.empty_cache() input = flow.tensor( np.random.randint(1, 10), dtype=flow.float16, device=flow.device("cuda"), ) data = input.numpy() for i in range(3): input_tensor_mem = _get_tensor_mem(input) # test tensor offload _test_tensor_offload_d2h(test_case, input, input_tensor_mem) # data = input.numpy() will raise error here # test tensor load _test_tensor_load_h2d(test_case, input, input_tensor_mem) # test data after tensor load test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001)) def test_tensor_offload_and_load_0size(test_case): flow.cuda.empty_cache() input = flow.tensor( np.random.randn(0, 1024, 1024), dtype=flow.float16, device=flow.device("cuda"), ) data = input.numpy() for i in range(3): input_tensor_mem = 0 # test tensor offload _test_tensor_offload_d2h(test_case, input, input_tensor_mem) # data = input.numpy() will raise error here # test tensor load _test_tensor_load_h2d(test_case, input, input_tensor_mem) # test data after tensor load test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001)) def test_tensor_offload_and_load_cpu_mem(test_case): input = flow.tensor( np.random.randn(1024, 1024, 100), dtype=flow.float32, device=flow.device("cuda"), ) before_used = flow._oneflow_internal.GetCPUMemoryUsed() before_id = id(input) input.offload() after_used = flow._oneflow_internal.GetCPUMemoryUsed() after_id = id(input) if offload_tensor_test_mem_mode == 2: test_case.assertTrue(after_used > before_used) elif offload_tensor_test_mem_mode == 3: print("cpu mem change value:", after_used - before_used) test_case.assertEqual(before_id, after_id) cur_used = flow._oneflow_internal.GetCPUMemoryUsed() before_id = id(input) input.load() after_used = flow._oneflow_internal.GetCPUMemoryUsed() after_id = id(input) if offload_tensor_test_mem_mode == 2: test_case.assertTrue(after_used < cur_used) elif offload_tensor_test_mem_mode == 3: print("cpu mem change value:", cur_used - after_used) test_case.assertEqual(before_id, after_id) def test_param_offload(test_case): def load_eager_model(model): for param in model.parameters(): print("\n- test param load cuda mem use") test_case.assertTrue(param.is_offloaded()) before_used = flow._oneflow_internal.GetCUDAMemoryUsed() print(" - before ", before_used) param.load() after_used = flow._oneflow_internal.GetCUDAMemoryUsed() print(" - after ", after_used) tensor_mem = _get_tensor_mem(param) change_as_expected = (after_used - before_used) == tensor_mem print(" - tensor size ", tensor_mem) print(" - change ", after_used - before_used) print(" - change as expected ", change_as_expected) test_case.assertTrue(not param.is_offloaded()) def offload_eager_model(model): for param in model.parameters(): print("\n- test param offload cuda mem use") test_case.assertTrue(not param.is_offloaded()) before_used = flow._oneflow_internal.GetCUDAMemoryUsed() print(" - before ", before_used) param.offload() after_used = flow._oneflow_internal.GetCUDAMemoryUsed() print(" - after ", after_used) tensor_mem = 
_get_tensor_mem(param) change_as_expected = (before_used - after_used) == tensor_mem print(" - tensor size ", tensor_mem) print(" - change ", after_used - before_used) print(" - change as expected ", change_as_expected) test_case.assertTrue(param.is_offloaded()) class Model(nn.Module): def __init__(self): super().__init__() self.n_layer = 1 layer_list = list() for _ in range(self.n_layer): # Too small to seem mem change layer_list.append(nn.Linear(768, 4096)) # Big enough to seem mem change layer_list.append(nn.Linear(4096, 4096)) self.layers = nn.Sequential(*layer_list) def forward(self, x): return self.layers(x) model0 = Model().cuda() BZ = 128 dataset = [flow.rand((BZ, 768), dtype=flow.float32) for _ in range(128)] with flow.no_grad(): for idx, x in enumerate(dataset): print(f"iter {idx} begin") x = x.cuda() if idx != 0: # no need to load at first iter load_eager_model(model0) y0 = model0(x) offload_eager_model(model0) print(f"iter {idx} end") if idx == 1: break if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_tensor_str.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow import tensor import oneflow def _test_local_tensor_str(test_case, device): # int dtype x = flow.tensor([[1, 2, 3], [4, 5, -6]], device=flow.device(device)) tensor_str = str(x) test_case.assertTrue("3" in tensor_str) test_case.assertTrue("5" in tensor_str) test_case.assertTrue("-6" in tensor_str) test_case.assertTrue("2" in str(x[0][1])) test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy())) # empty x = flow.tensor([], device=flow.device(device)) tensor_str = str(x) test_case.assertTrue("[]" in tensor_str) test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy())) # scientific representation int_mode(val == np.ceil(val)) x = flow.tensor( [[1, 2, 3], [4, 5, 600000]], device=flow.device(device), dtype=flow.float64 ) tensor_str = str(x) test_case.assertTrue("6.0000e+05" in tensor_str) test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy())) # int_mode x = flow.tensor( [[1.0, 2.0, 3.0], [4.0, 5, 60]], device=flow.device(device), dtype=flow.float64 ) tensor_str = str(x) test_case.assertTrue("4." in tensor_str) test_case.assertTrue("60." 
in tensor_str) test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy())) # float dtype x = flow.tensor( [[1.3, 2.4, 3.5], [-4.6, 5, 60]], device=flow.device(device), dtype=flow.float64 ) tensor_str = str(x) test_case.assertTrue("3.5000" in tensor_str) test_case.assertTrue("-4.6000" in tensor_str) test_case.assertTrue("60.0000" in tensor_str) test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy())) # scientific representation float dtype x = flow.tensor( [[1.3, 2.4, 3.5], [-4.6, 5, 60000000]], device=flow.device(device), dtype=flow.float64, ) tensor_str = str(x) test_case.assertTrue("2.4000e+00" in tensor_str) test_case.assertTrue("3.5000e+00" in tensor_str) test_case.assertTrue("-4.6000e+00" in tensor_str) test_case.assertTrue("6.0000e+07" in tensor_str) test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy())) # summarized data float dtype x = flow.tensor( np.ones((100, 100, 100)), device=flow.device(device), dtype=flow.float64 ) tensor_str = str(x) test_case.assertTrue("1" in tensor_str) test_case.assertTrue("..." in tensor_str) def _test_global_tensor_str(test_case, device): placement = flow.placement(device, range(1)) # split global tensor x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.split(0)]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) # broadcast global tensor x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.broadcast]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) # partial_sum global tensor x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.partial_sum]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) # summarized global tensor x = flow.ones((100, 100), placement=placement, sbp=[flow.sbp.split(0)]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) test_case.assertTrue("..." in tensor_str) # empty global tensor x = flow.ones((0, 10), placement=placement, sbp=[flow.sbp.split(0)]) tensor_str = str(x) test_case.assertTrue("[]" in tensor_str) def _test_global_tensor_str_2d(test_case, device): placement = flow.placement(device, range(2)) x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.split(0)]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.broadcast]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) # TODO: x[0][0].to("cuda") has bug # test_case.assertTrue("1." in str(x[0][0])) x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.partial_sum]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) x = flow.ones((100, 100), placement=placement, sbp=[flow.sbp.split(0)]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) # TODO: this test has bug # test_case.assertTrue("..." in tensor_str) x = flow.ones((100, 100), placement=placement, sbp=[flow.sbp.split(1)]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) # TODO: this test has bug # test_case.assertTrue("..." in tensor_str) x = flow.ones( (10, 10), placement=flow.placement(device, ranks=[0]), sbp=[flow.sbp.broadcast] ) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) x = flow.ones((2, 5), placement=placement, sbp=[flow.sbp.split(0)]) tensor_str = str(x) test_case.assertTrue("1." 
in tensor_str) def _test_nd_sbp_tensor_str(test_case, device, sbp0, sbp1): placement = flow.placement(type=device, ranks=[[0, 1], [2, 3]]) sbp = [sbp0, sbp1] x = flow.ones((20, 20), placement=placement, sbp=sbp) tensor_str = str(x) test_case.assertTrue(str(sbp0) in tensor_str) test_case.assertTrue(str(sbp1) in tensor_str) class TestTensorStrModule(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() @unittest.skip("TODO: fengwei, this often fails") def test_local_tensor_str_1n1d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_local_tensor_str, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() @unittest.skip("TODO: fengwei, this often fails") def test_global_tensor_str_1n1d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_global_tensor_str, ] arg_dict["device"] = ["cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_tensor_str_1n2d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_global_tensor_str_2d, ] arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() def test_nd_sbp_tensor_str(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_nd_sbp_tensor_str, ] arg_dict["device"] = ["cpu", "cuda"] sbp_arg_dict = OrderedDict() sbp_list = [ flow.sbp.broadcast, flow.sbp.split(0), flow.sbp.partial_sum, ] sbp_arg_dict["sbp0"] = sbp_list sbp_arg_dict["sbp1"] = sbp_list for arg in GenArgList(arg_dict): for sbp in GenArgList(sbp_arg_dict): arg[0](test_case, *(arg[1:] + sbp[:])) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/expensive/test_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import itertools import os from collections import OrderedDict from collections.abc import Iterable import numpy as np import oneflow as flow import oneflow.unittest def GenCartesianProduct(sets): assert isinstance(sets, Iterable) for set in sets: assert isinstance(set, Iterable) if os.getenv("ONEFLOW_TEST_CPU_ONLY"): if "cuda" in set: set.remove("cuda") return itertools.product(*sets) def GenArgList(arg_dict): assert isinstance(arg_dict, OrderedDict) assert all([isinstance(x, list) for x in arg_dict.values()]) sets = [arg_set for (_, arg_set) in arg_dict.items()] return GenCartesianProduct(sets) def GenArgDict(arg_dict): return [dict(zip(arg_dict.keys(), x)) for x in GenArgList(arg_dict)] class Args: def __init__(self, flow_args, tf_args=None): super().__init__() if tf_args is None: tf_args = flow_args self.flow_args = flow_args self.tf_args = tf_args def __str__(self): return "flow_args={} tf_args={}".format(self.flow_args, self.tf_args) def __repr__(self): return self.__str__() type_name_to_flow_type = { "float16": flow.float16, "float32": flow.float32, "double": flow.double, "int8": flow.int8, "int32": flow.int32, "int64": flow.int64, "uint8": flow.uint8, } type_name_to_np_type = { "float16": np.float16, "float32": np.float32, "double": np.float64, "int8": np.int8, "int32": np.int32, "int64": np.int64, "uint8": np.uint8, } def FlattenArray(input_array): output_array = list() for x in np.nditer(input_array): output_array.append(x.tolist()) return output_array def Array2Numpy(input_array, target_shape): return np.array(input_array).reshape(target_shape, order="C") def Index2Coordinate(idx, tensor_shape): coordinate = [] tmp = idx for i in range(len(tensor_shape) - 1, -1, -1): axis_size = tensor_shape[i] coor = tmp % axis_size coordinate.insert(0, int(coor)) tmp = (tmp - coor) / axis_size return coordinate def Coordinate2Index(coordinate, tensor_shape): if len(coordinate) != len(tensor_shape): raise "wrong coordinate or shape" idx = 0 for (i, coor) in enumerate(coordinate): size_at_axis = coor for j in range(i + 1, len(tensor_shape)): size_at_axis *= tensor_shape[j] idx += size_at_axis return idx def generate_graph(func): class Graph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, *args): return func(*args) return Graph() ================================================ FILE: python/oneflow/test/gen_ops_process.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import subprocess import glob import re def get_api(rst_dir): """ Extract operator names from rst files. `currentmodule` is not regarded as operators. `autoclass` and `automodule` are regarded as operators in the absence of `members`. """ op_files = glob.glob(rst_dir + "/*.rst") op_files.remove(rst_dir + "/graph.rst") op_files.remove(rst_dir + "/index.rst") api_list = [] api_str = "" for op_file in op_files: with open(op_file, "r") as f: line = f.readline() pre = "" while line: skip = False if ".. 
currentmodule::" in line: pre = line.strip().replace(".. currentmodule::", "") + "." elif ".. autofunction::" in line: if "oneflow" not in line: api_str += pre api_str += line.replace(".. autofunction::", "") elif ( ".. autosummary::" in line or ".. autoclass::" in line or ":toctree:" in line or ":nosignatures:" in line or ":template:" in line ): if ":nosignatures:" in line: line = f.readline() if ":template:" in line: line = f.readline() line = f.readline() while line and len(line.replace(" ", "")) > 1: if "oneflow" not in line: api_str += pre api_str += line line = f.readline() elif ".. automodule::" in line: pre_a = line.replace(".. automodule:: ", "") line = f.readline() skip = True if ":members:" in line and len(line) > 14: pre_a = pre_a.strip() + "." api_str += pre_a + line.replace(":members:", "") line = f.readline() while ( line and ":" not in line and len(line.replace(" ", "")) > 1 ): api_str += pre_a + line line = f.readline() if not skip: line = f.readline() api_list = api_str.strip().replace(" ", "").replace(",", "").split("\n") return api_list def get_profile_func(path): """ Iterate through files under `path` to find out all operator names, and update code links to file_func_map_list by file_func_map. """ files = os.listdir(path) commit_bytes = subprocess.check_output(["git", "rev-parse", "HEAD"]) commit_str = commit_bytes.decode("utf-8").replace("\n", "") result_profile_func_list = [] for file in files: if file != "log" and not os.path.isdir(file) and file.find("__pycache__") == -1: f = open(os.path.join(path, file)) last_line = "" iter_f = iter(f) line_num = 1 for line in iter_f: line = line.strip() match = re.fullmatch(r"^@profile\((.+)\)$", line) if match: tem_profile = match.group(1) tem_profile_name = tem_profile.split(".")[-1] result_profile_func_list.append(tem_profile_name) return result_profile_func_list def get_test_func(path): """ Iterate through files under `path` to find out all operator names, and update code links to file_func_map_list by file_func_map. """ files = os.listdir(path) commit_bytes = subprocess.check_output(["git", "rev-parse", "HEAD"]) commit_str = commit_bytes.decode("utf-8").replace("\n", "") result_func_list = [] for file in files: if file != "log" and not os.path.isdir(file) and file.find("__pycache__") == -1: f = open(os.path.join(path, file)) last_line = "" iter_f = iter(f) line_num = 1 for line in iter_f: line = line.strip() rem = re.match("def .*?(test_.*)\(test_case.*", line) if rem and "#" not in line: func_name = rem.group(1).replace("_test_", "").replace("test_", "") result_func_list.append(func_name) file_func_map[func_name] = ( f" [{func_name}](" + "https://github.com/Oneflow-Inc/oneflow/blob/" + commit_str + "/python/oneflow/test/" + path + "/" + file + f"#L{line_num}) " ) elif last_line.startswith("add_docstr"): result_func_list.append(line[0:-1]) file_func_map[line[0:-1]] = ( f" [{line[0:-1]}](" + "https://github.com/Oneflow-Inc/oneflow/blob/" + commit_str + "/python/oneflow/test/" + path + "/" + file + f"#L{line_num}) " ) last_line = line line_num += 1 return result_func_list def pure_match(x, y): """ Check whether x contains y. The purpose of identifying "." is to accurately match operator documents. For example, if we make pos = x.find(y) while y = clip_, either oneflow.Tensor.clip or oneflow.Tensor.clip_ is right. Besides, identifying "_" is important. For example, if we make pos = x.find(y) while y = squeeze, either test of squeeze or unsqueeze is right. """ x = x.lower() y = y.lower() pos = -1 if "." 
in x: x = x.split(".") for i in x: if i == y: pos = 1 break elif "_" in y: pos = x.find(y) else: x = x.split("_") for i in x: if i == y: pos = 1 break return pos != -1 def match_test_func(func, func_list): """ func: operator name func_list: names of all operators Check whether func_list contains func. If yes, return matching content, or else return "". """ match_res = "" for i in range(len(func_list)): if pure_match(func_list[i], func): match_res = func_list[i] break return match_res if __name__ == "__main__": api_list = get_api("../../../docs/source") dir_list = [ ["../../../python/oneflow/framework/docstr"], ["../../../python/oneflow/test/modules", "../../../python/oneflow/test/tensor"], ["../../../python/oneflow/test/exceptions"], ] num_cols = 4 test_func_list = list() test_profile_list = list() file_func_map = dict() file_func_map_list = [] for i in range(0, len(dir_list)): tmp_func_list = list() tmp_profile_list = list() file_func_map = dict() for path in dir_list[i]: tmp_func_list.extend(get_test_func(path)) tmp_profile_list.extend(get_profile_func(path)) test_func_list.append(tmp_func_list) test_profile_list.extend(tmp_profile_list) file_func_map_list.append(file_func_map) result_list = [] result_list.append(f"## Ops Version : Alpha") result_list.append(f"") result_list.append(f"") table_head = f"| Op Name | Doc Test | Compatiable/Completeness Test | Exception | Performance Test |" result_list.append(table_head) result_list.append( f"| ------------------------- | ------------- | ----------------------------- | --------- | ---------------- |" ) cnt0 = 0 # the number of doc_test cnt1 = 0 # the number of compatiable_completeness_test cnt2 = 0 # the number of exception_test cnt3 = 0 # the number of profile_test for name in api_list: table_line = f"| {name} |" name = name.split(".")[-1] for i in range(3): match_name = match_test_func(name, test_func_list[i]) if match_name != "": if i == 0: cnt0 += 1 elif i == 1: cnt1 += 1 else: cnt2 += 1 table_line += file_func_map_list[i][match_name] table_line += " |" if name in test_profile_list: table_line += " done " cnt3 += 1 table_line += " |" result_list.append(table_line) doc_test_ratio = cnt0 / len(api_list) compatiable_completeness_test_ratio = cnt1 / len(api_list) exception_test_ratio = cnt2 / len(api_list) performance_test_ratio = cnt3 / len(api_list) result_list.append(f"## Test Data Summary") result_list.append(f"- OneFlow Total API Number: {len(api_list)}") result_list.append( f"- Doc Test Ratio: {100*doc_test_ratio:.2f}% ({cnt0} / {len(api_list)})" ) result_list.append( f"- Compatiable/Completeness Test Ratio: {100*compatiable_completeness_test_ratio:.2f}% ({cnt1} / {len(api_list)})" ) result_list.append( f"- Exception Test Ratio: {100*exception_test_ratio:.2f}% ({cnt2} / {len(api_list)})" ) result_list.append( f"- Performance Test Ratio: {100*performance_test_ratio:.2f}% ({cnt3} / {len(api_list)})" ) f = open("./README.md", "w") for line in result_list: f.write(line + "\n") f.close() ================================================ FILE: python/oneflow/test/graph/alexnet_model.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
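As a minimal illustration of the exact-component matching that pure_match implements above (the helper below is a hypothetical sketch, not part of gen_ops_process.py), splitting a dotted API path avoids the clip/clip_ and squeeze/unsqueeze false positives its docstring warns about:

# Hypothetical sketch of the component-matching idea; not repository code.
def _matches_component(qualified_name: str, op_name: str) -> bool:
    parts = qualified_name.lower().split(".")
    return op_name.lower() in parts

assert _matches_component("oneflow.Tensor.clip_", "clip_")
assert not _matches_component("oneflow.Tensor.clip_", "clip")        # "clip" != "clip_"
assert _matches_component("oneflow.Tensor.squeeze", "squeeze")
assert not _matches_component("oneflow.Tensor.unsqueeze", "squeeze") # a plain substring find would wrongly hit this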
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow import oneflow.nn as nn from typing import Any __all__ = ["AlexNet", "alexnet"] class AlexNet(nn.Module): def __init__(self, num_classes: int = 1000) -> None: super(AlexNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), ) self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) self.classifier = nn.Sequential( nn.Dropout(), nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Linear(4096, num_classes), ) def forward(self, x: flow.Tensor) -> flow.Tensor: x = self.features(x) x = self.avgpool(x) x = flow.flatten(x, 1) x = self.classifier(x) return x def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: r"""AlexNet model architecture from the `"One weird trick..." `_ paper. The required minimum input size of the model is 63x63. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ model = AlexNet(**kwargs) return model ================================================ FILE: python/oneflow/test/graph/ofrecord_data_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
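A minimal usage sketch of the alexnet() factory defined in alexnet_model.py above (assuming it is run from the same directory so the local import resolves); the 112x112 input matches what the graph tests below feed the model, and anything at least 63x63 works per the docstring:

# Illustrative only: build the test AlexNet and run one eager forward pass.
import oneflow as flow
from alexnet_model import alexnet

model = alexnet(num_classes=10)   # `pretrained` is accepted but this helper loads no weights
model.eval()
x = flow.randn(1, 3, 112, 112)
with flow.no_grad():
    logits = model(x)
print(logits.shape)               # expected: (1, 10)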
""" import oneflow as flow import os class OFRecordDataLoader(flow.nn.Module): def __init__( self, ofrecord_root: str = "./ofrecord", mode: str = "train", # "val" dataset_size: int = 9469, batch_size: int = 1, ): super().__init__() channel_last = False output_layout = "NHWC" if channel_last else "NCHW" self.train_record_reader = flow.nn.OFRecordReader( ofrecord_root, batch_size=batch_size, data_part_num=1, part_name_suffix_length=5, random_shuffle=True if mode == "train" else False, shuffle_after_epoch=True if mode == "train" else False, ) self.record_label_decoder = flow.nn.OFRecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) color_space = "RGB" height = 224 width = 224 self.record_image_decoder = ( flow.nn.OFRecordImageDecoderRandomCrop("encoded", color_space=color_space) if mode == "train" else flow.nn.OFRecordImageDecoder("encoded", color_space=color_space) ) self.resize = ( flow.nn.image.Resize(target_size=[height, width]) if mode == "train" else flow.nn.image.Resize( resize_side="shorter", keep_aspect_ratio=True, target_size=256 ) ) self.flip = flow.nn.CoinFlip(batch_size=batch_size) if mode == "train" else None rgb_mean = [123.68, 116.779, 103.939] rgb_std = [58.393, 57.12, 57.375] self.crop_mirror_norm = ( flow.nn.CropMirrorNormalize( color_space=color_space, output_layout=output_layout, mean=rgb_mean, std=rgb_std, output_dtype=flow.float, ) if mode == "train" else flow.nn.CropMirrorNormalize( color_space=color_space, output_layout=output_layout, crop_h=height, crop_w=width, crop_pos_y=0.5, crop_pos_x=0.5, mean=rgb_mean, std=rgb_std, output_dtype=flow.float, ) ) self.batch_size = batch_size self.dataset_size = dataset_size def __len__(self): return self.dataset_size // self.batch_size def forward(self): train_record = self.train_record_reader() label = self.record_label_decoder(train_record) image_raw_buffer = self.record_image_decoder(train_record) image = self.resize(image_raw_buffer)[0] rng = self.flip() if self.flip != None else None image = self.crop_mirror_norm(image, rng) return image, label ================================================ FILE: python/oneflow/test/graph/optimizer_test_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import numpy as np def clip_grad_norm_np(np_grad, max_norm, norm_type): max_norm = float(max_norm) norm_type = float(norm_type) if norm_type == float("inf"): total_norm = np.max(np.abs(np_grad)) if norm_type == float("-inf"): total_norm = np.min(np.abs(np_grad)) elif norm_type == 0: total_norm = np.sum(np.stack([np.sum(np_grad != 0)]) != 0) else: total_norm = np_grad for i in range(np_grad.ndim, 0, -1): total_norm = np.linalg.norm(total_norm, norm_type, axis=i - 1) clip_coef = max_norm / (total_norm + 1e-6) if clip_coef < 1: np_grad = np_grad * clip_coef return total_norm, np_grad ================================================ FILE: python/oneflow/test/graph/test_alexnet_auto_parallel.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import time import unittest import argparse import numpy as np import oneflow as flow import oneflow.unittest from alexnet_model import alexnet import flowvision as vision import flowvision.transforms as transforms def load_data_fashion_mnist( batch_size, resize=None, root="./data-test/fashion-mnist", download=True, source_url=None, num_workers=0, ): """Download the Fashion-MNIST dataset and then load into memory.""" root = os.path.expanduser(root) trans = [] if resize: trans.append(transforms.Resize(resize)) trans.append(transforms.ToTensor()) transform = transforms.Compose(trans) mnist_train = vision.datasets.FashionMNIST( root=root, train=True, transform=transform, download=download, source_url=source_url, ) mnist_test = vision.datasets.FashionMNIST( root=root, train=False, transform=transform, download=download, source_url=source_url, ) train_iter = flow.utils.data.DataLoader( mnist_train, batch_size, shuffle=True, num_workers=num_workers ) test_iter = flow.utils.data.DataLoader( mnist_test, batch_size, shuffle=False, num_workers=num_workers ) return train_iter, test_iter def _parse_args(): parser = argparse.ArgumentParser("flags for train alexnet") parser.add_argument( "--load_checkpoint", type=str, default="", help="load checkpoint" ) parser.add_argument( "--ofrecord_path", type=str, default=flow.unittest.dataset_dir("imagenette/ofrecord"), help="dataset path", ) # training hyper-parameters parser.add_argument( "--learning_rate", type=float, default=0.02, help="learning rate" ) parser.add_argument("--mom", type=float, default=0.9, help="momentum") parser.add_argument("--epochs", type=int, default=1, help="training epochs") parser.add_argument("--batch_size", type=int, default=128, help="val batch size") return parser.parse_known_args() def _test_alexnet_graph(test_case, args, placement, sbp): data_dir = os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "fashion-mnist-lenet" ) source_url = "https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/" train_iter, test_iter = load_data_fashion_mnist( batch_size=args.batch_size, root=data_dir, download=True, source_url=source_url, num_workers=0, resize=(112, 112), ) # oneflow init 
start_t = time.time() alexnet_module = alexnet(num_classes=10) end_t = time.time() print("init time : {}".format(end_t - start_t)) alexnet_module.to_global(placement, sbp) of_cross_entropy = flow.nn.CrossEntropyLoss().to_global(placement, sbp) of_sgd = flow.optim.SGD( alexnet_module.parameters(), lr=args.learning_rate, momentum=args.mom ) class AlexNetGraph(flow.nn.Graph): def __init__(self): super().__init__() self.alexnet = alexnet_module self.cross_entropy = of_cross_entropy self.add_optimizer(of_sgd) self.config.enable_auto_parallel(True) self.config.enable_auto_parallel_ignore_user_sbp_config(True) self.config.enable_auto_parallel_trunk_algo(True) self.config.enable_auto_parallel_sbp_collector(True) def build(self, image, label): logits = self.alexnet(image) loss = self.cross_entropy(logits, label) loss.backward() return loss alexnet_graph = AlexNetGraph() class AlexNetEvalGraph(flow.nn.Graph): def __init__(self): super().__init__() self.alexnet = alexnet_module self.config.enable_auto_parallel(True) self.config.enable_auto_parallel_ignore_user_sbp_config(True) self.config.enable_auto_parallel_trunk_algo(True) self.config.enable_auto_parallel_sbp_collector(True) def build(self, image): with flow.no_grad(): logits = self.alexnet(image) predictions = logits.softmax() return predictions alexnet_eval_graph = AlexNetEvalGraph() of_losses = [] print_interval = 20 acc = 0.0 for epoch in range(args.epochs): alexnet_module.train() for i, (image, label) in enumerate(train_iter): # oneflow graph train if image.shape[0] != args.batch_size: # drop last batch break start_t = time.time() image = image.to_global(placement, sbp).expand(args.batch_size, 3, 112, 112) label = label.to_global(placement, sbp) loss = alexnet_graph(image, label) end_t = time.time() if i % print_interval == 0: l = loss.numpy() of_losses.append(l) if flow.env.get_rank() == 0: print( "epoch {} train iter {}/{} oneflow loss {}, train time : {}".format( epoch, i, len(train_iter), l, end_t - start_t ) ) # Stop after 20 iters to save time break if flow.env.get_rank() == 0: print("epoch %d train done, start validation" % epoch) alexnet_module.eval() correct_of = 0.0 total_of = 0.0 for image, label in test_iter: # oneflow graph eval if image.shape[0] != args.batch_size: # drop last batch break start_t = time.time() image = image.to_global(placement, sbp).expand(args.batch_size, 3, 112, 112) predictions = alexnet_eval_graph(image) of_predictions = predictions.numpy() clsidxs = np.argmax(of_predictions, axis=1) label_nd = label.numpy() for i in range(args.batch_size): total_of += 1 if clsidxs[i] == label_nd[i]: correct_of += 1 end_t = time.time() acc = correct_of / total_of if flow.env.get_rank() == 0: print("epoch %d, oneflow top1 val acc: %f" % (epoch, acc)) # test_case.assertTrue(acc > 0.50) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestAlexnetAutoParallel(oneflow.unittest.TestCase): def test_alexnet_auto_parallel_1d_sbp(test_case): args, unknown_args = _parse_args() placement = flow.placement.all("cuda") sbp = [flow.sbp.broadcast,] * len(placement.ranks.shape) _test_alexnet_graph(test_case, args, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_alexnet_graph.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import argparse import numpy as np import os import time import unittest import oneflow as flow import oneflow.unittest from alexnet_model import alexnet from ofrecord_data_utils import OFRecordDataLoader def _parse_args(): parser = argparse.ArgumentParser("flags for train alexnet") parser.add_argument( "--save_checkpoint_path", type=str, default="./checkpoints", help="save checkpoint root dir", ) parser.add_argument( "--load_checkpoint", type=str, default="", help="load checkpoint" ) parser.add_argument( "--ofrecord_path", type=str, default=flow.unittest.dataset_dir("imagenette/ofrecord"), help="dataset path", ) parser.add_argument( "--train_dataset_size", type=int, default=400, help="train_dataset size" ) parser.add_argument( "--val_dataset_size", type=int, default=40, help="val_dataset size" ) # training hyper-parameters parser.add_argument( "--learning_rate", type=float, default=0.001, help="learning rate" ) parser.add_argument("--mom", type=float, default=0.9, help="momentum") parser.add_argument("--epochs", type=int, default=1, help="training epochs") parser.add_argument( "--train_batch_size", type=int, default=4, help="train batch size" ) parser.add_argument("--val_batch_size", type=int, default=4, help="val batch size") parser.add_argument("--device", type=str, default="cuda", help="device") return parser.parse_known_args() def _test_alexnet_graph_repr(test_case, args): train_data_loader = OFRecordDataLoader( ofrecord_root=args.ofrecord_path, mode="train", dataset_size=args.train_dataset_size, batch_size=args.train_batch_size, ) alexnet_module = alexnet() alexnet_module.to(args.device) of_cross_entropy = flow.nn.CrossEntropyLoss() of_cross_entropy.to(args.device) of_sgd = flow.optim.SGD( alexnet_module.parameters(), lr=args.learning_rate, momentum=args.mom ) class AlexNetGraph(flow.nn.Graph): def __init__(self): super().__init__() self.alexnet = alexnet_module self.cross_entropy = of_cross_entropy self.add_optimizer(of_sgd) def build(self, image, label): logits = self.alexnet(image) loss = self.cross_entropy(logits, label) loss.backward() return loss alexnet_graph = AlexNetGraph() print("repr(alexnet_graph) before run: \n", repr(alexnet_graph)) # debug graph build alexnet_graph.debug(1, op_repr_with_py_stack=True, max_py_stack_depth=4) alexnet_module.train() image, label = train_data_loader() image = image.to(args.device) label = label.to(args.device) loss = alexnet_graph(image, label) print("repr(alexnet_graph) after run: \n", repr(alexnet_graph)) def _test_alexnet_graph(test_case, args): train_data_loader = OFRecordDataLoader( ofrecord_root=args.ofrecord_path, mode="train", dataset_size=args.train_dataset_size, batch_size=args.train_batch_size, ) val_data_loader = OFRecordDataLoader( ofrecord_root=args.ofrecord_path, mode="val", dataset_size=args.val_dataset_size, batch_size=args.val_batch_size, ) # oneflow init start_t = time.time() alexnet_module = alexnet() end_t = time.time() print("init time : {}".format(end_t - start_t)) alexnet_module.to(args.device) of_cross_entropy = flow.nn.CrossEntropyLoss() of_cross_entropy.to(args.device) of_sgd = flow.optim.SGD( alexnet_module.parameters(), 
lr=args.learning_rate, momentum=args.mom ) class AlexNetGraph(flow.nn.Graph): def __init__(self): super().__init__() self.train_data_loader = train_data_loader self.alexnet = alexnet_module self.cross_entropy = of_cross_entropy self.add_optimizer(of_sgd) def build(self): image, label = self.train_data_loader() image = image.to(args.device) label = label.to(args.device) logits = self.alexnet(image) loss = self.cross_entropy(logits, label) loss.backward() return loss alexnet_graph = AlexNetGraph() class AlexNetEvalGraph(flow.nn.Graph): def __init__(self): super().__init__() self.val_data_loader = val_data_loader self.alexnet = alexnet_module def build(self): with flow.no_grad(): image, label = self.val_data_loader() image = image.to(args.device) logits = self.alexnet(image) predictions = logits.softmax() return predictions, label alexnet_eval_graph = AlexNetEvalGraph() of_losses = [] all_samples = len(val_data_loader) * args.val_batch_size print_interval = 10 for epoch in range(args.epochs): alexnet_module.train() for b in range(len(train_data_loader)): # oneflow graph train start_t = time.time() loss = alexnet_graph() end_t = time.time() if b % print_interval == 0: l = loss.numpy() of_losses.append(l) print( "epoch {} train iter {} oneflow loss {}, train time : {}".format( epoch, b, l, end_t - start_t ) ) print("epoch %d train done, start validation" % epoch) alexnet_module.eval() correct_of = 0.0 for b in range(len(val_data_loader)): start_t = time.time() predictions, label = alexnet_eval_graph() of_predictions = predictions.numpy() clsidxs = np.argmax(of_predictions, axis=1) label_nd = label.numpy() for i in range(args.val_batch_size): if clsidxs[i] == label_nd[i]: correct_of += 1 end_t = time.time() print("epoch %d, oneflow top1 val acc: %f" % (epoch, correct_of / all_samples)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestAlexnetGraph(oneflow.unittest.TestCase): def test_alexnet_graph_repr(test_case): args, unknown_args = _parse_args() args.device = "cuda" _test_alexnet_graph_repr(test_case, args) @unittest.skip("skip for now, becase it failed 2 times in past week") def test_alexnet_graph_gpu(test_case): args, unknown_args = _parse_args() args.device = "cuda" _test_alexnet_graph(test_case, args) def test_alexnet_graph_cpu(test_case): args, unknown_args = _parse_args() args.device = "cpu" args.train_batch_size = 40 _test_alexnet_graph(test_case, args) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_comb1to2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow from oneflow import nn import os import numpy as np import oneflow.unittest class _TestModuleDiffHierarchy(nn.Module): def forward(self, x): sbp_1ds = [ flow.sbp.broadcast, flow.sbp.partial_sum, flow.sbp.split(0), flow.sbp.split(1), flow.sbp.split(2), ] for sbp1 in sbp_1ds: for sbp2 in sbp_1ds: for sbp3 in sbp_1ds: # (2, 2) -> 4 x = x.to_global( placement=flow.placement(type="cuda", ranks=np.array(range(4))), sbp=[sbp1], ) # 4 -> (2, 2) x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), sbp=[sbp2, sbp3], ) return x class _TestModuleDiffPlacement(nn.Module): def forward(self, x): sbp_1ds = [ flow.sbp.broadcast, flow.sbp.partial_sum, flow.sbp.split(0), flow.sbp.split(1), flow.sbp.split(2), ] for sbp1 in sbp_1ds: for sbp2 in sbp_1ds: for sbp3 in sbp_1ds: # (2, 2) -> 3 # 4 is not divisible by 3 x = x.to_global( placement=flow.placement(type="cuda", ranks=np.array(range(3))), sbp=[sbp1], ) # 3 -> (2, 2) x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), sbp=[sbp2, sbp3], ) return x class _TestGraph(nn.Graph): def __init__(self, model): super().__init__() self.model = model def build(self, x): x = self.model(x) return x @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestLazyAllSbpCombinationTesting(flow.unittest.TestCase): def test_lazy_boxing_2d_all_combination(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0" x = flow.ones( 4, 12, 4, sbp=[flow.sbp.broadcast, flow.sbp.broadcast], placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), ) flow.boxing.nccl.enable_use_compute_stream(False) model_diff_hierarchy = _TestModuleDiffHierarchy() graph_diff_hierarchy = _TestGraph(model_diff_hierarchy) y = graph_diff_hierarchy(x) model_diff_placement = _TestModuleDiffPlacement() graph_diff_placement = _TestGraph(model_diff_placement) z = graph_diff_placement(x) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_comb2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow from oneflow import nn import os import numpy as np import oneflow.unittest class _TestModule(nn.Module): def forward(self, x): sbp_1ds = [ flow.sbp.broadcast, flow.sbp.partial_sum, flow.sbp.split(0), flow.sbp.split(1), flow.sbp.split(2), ] y = x for sbp1 in sbp_1ds: for sbp2 in sbp_1ds: for sbp3 in sbp_1ds: # in this case, use intra group boxing if sbp1 == sbp3: continue for sbp4 in sbp_1ds: # (2, 2) -> (2, 2) x = x.to_global(sbp=[sbp1, sbp2]) x = x.to_global(sbp=[sbp3, sbp4]) return x class _TestGraph(nn.Graph): def __init__(self, model): super().__init__() self.model = model def build(self, x): x = self.model(x) return x @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestLazyAllSbpCombinationTesting(flow.unittest.TestCase): def test_lazy_boxing_2d_all_combination(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0" model = _TestModule() graph = _TestGraph(model) flow.boxing.nccl.enable_use_compute_stream(False) x = flow.ones( 4, 4, 4, sbp=[flow.sbp.broadcast, flow.sbp.broadcast], placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), ) y = graph(x) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_forward_graph.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import oneflow import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestForwardGraph(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") def test_forward_graph(test_case): class SubModule(flow.nn.Module): def __init__(self): super().__init__() self.weight = flow.nn.Parameter(flow.Tensor(6, 6)) self.relu = flow.nn.ReLU() def forward(self, x, y): x = oneflow._C.matmul(x, self.weight) x = self.relu(x) y = self.relu(y) return (x, y) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.layer = SubModule() self.register_buffer("dummy_buff", flow.Tensor(6, 8)) def forward(self, x, y): (x, y) = self.layer(x, y) x = oneflow._C.flatten(x, 1) x = oneflow._C.matmul(x, self.dummy_buff) return (x, y) class CustomGraph(flow.nn.Graph): def __init__(self, module): super().__init__() self.m = module def build(self, x, y): out = self.m(x, y) return out m = CustomModule() m.to("cuda") g = CustomGraph(m) x = flow.Tensor(6, 6) flow.nn.init.uniform_(x, a=-1.0, b=1.0) x = x.to("cuda") y = flow.Tensor(10, 10) flow.nn.init.uniform_(y, a=-1.0, b=1.0) y = y.to("cuda") print(repr(g)) (z, a) = g._compile(x, y) test_case.assertEqual(z.shape, (6, 8)) test_case.assertEqual(z.is_lazy, False) test_case.assertEqual(a.shape, (10, 10)) test_case.assertEqual(a.is_lazy, False) print("graph proto: ", g._graph_proto) def test_add_backward(test_case): linear = flow.nn.Linear(3, 8) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) class GraphAddBackward(flow.nn.Graph): def __init__(self): super().__init__() self.linear = linear self.add_optimizer(of_sgd) def build(self, x): out = self.linear(x) out = out.mean() out.backward() return out g_with_b = GraphAddBackward() x = flow.ones(8, 3) out = g_with_b(x) print("graph proto: ", g_with_b._graph_proto) print(repr(g_with_b)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_free_tensor_not_in_job.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest import oneflow.nn as nn def get_bn_graph(): model = nn.BatchNorm1d(6) model.eval() model.to_global(flow.placement.all("cpu"), flow.sbp.broadcast) class Testgraph(flow.nn.Graph): def __init__(self, model): super(Testgraph, self).__init__() self.module = model def build(self, x): return self.module(x) test_graph = Testgraph(model) return test_graph @flow.unittest.skip_unless_1n1d() class TestFreeTensorNotInJob(flow.unittest.TestCase): def test_free_tensor_not_in_job(test_case): x = flow.randn(1, 6, 2).to_global( placement=flow.placement.all("cpu"), sbp=flow.sbp.split(0) ) y = get_bn_graph()(x) test_case.assertEqual(y.size(), (1, 6, 2)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_fx_fuse.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow import oneflow.nn as nn import numpy as np import unittest from oneflow.test_utils.automated_test_util import * import numpy as np import copy from typing import Dict, Any, Tuple def _fuse_conv_bn_eval(conv, bn): """ Given a conv Module `A` and an batch_norm module `B`, returns a conv module `C` such that C(x) == B(A(x)) in inference mode. """ assert not (conv.training or bn.training), "Fusion only for eval!" fused_conv = copy.deepcopy(conv) fused_conv.weight, fused_conv.bias = _fuse_conv_bn_weights( fused_conv.weight, fused_conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, ) return fused_conv def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): if conv_b is None: conv_b = flow.zeros_like(bn_rm) if bn_w is None: bn_w = flow.ones_like(bn_rm) if bn_b is None: bn_b = flow.zeros_like(bn_rm) bn_var_rsqrt = flow.rsqrt(bn_rv + bn_eps) conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape( [-1] + [1] * (len(conv_w.shape) - 1) ) conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b return flow.nn.Parameter(conv_w), flow.nn.Parameter(conv_b) def _parent_name(target: str) -> Tuple[str, str]: """ Splits a qualname into parent path and last atom. For example, `foo.bar.baz` -> (`foo.bar`, `baz`) """ *parent, name = target.rsplit(".", 1) return parent[0] if parent else "", name def _replace_node_module( node: flow.fx.Node, modules: Dict[str, Any], new_module: flow.nn.Module ): assert isinstance(node.target, str) parent_name, name = _parent_name(node.target) setattr(modules[parent_name], name, new_module) def _fx_fuse(model: flow.nn.Module) -> flow.nn.Module: model = copy.deepcopy(model) # The first step of most FX passes is to symbolically trace our model to # obtain a `GraphModule`. This is a representation of our original model # that is functionally identical to our original model, except that we now # also have a graph representation of our forward pass. 
fx_model: flow.fx.GraphModule = flow.fx.symbolic_trace(model) modules = dict(fx_model.named_modules()) # The primary representation for working with FX are the `Graph` and the # `Node`. Each `GraphModule` has a `Graph` associated with it - this # `Graph` is also what generates `GraphModule.code`. # The `Graph` itself is represented as a list of `Node` objects. Thus, to # iterate through all of the operations in our graph, we iterate over each # `Node` in our `Graph`. for node in fx_model.graph.nodes: # The FX IR contains several types of nodes, which generally represent # call sites to modules, functions, or methods. The type of node is # determined by `Node.op`. if ( node.op != "call_module" ): # If our current node isn't calling a Module then we can ignore it. continue # For call sites, `Node.target` represents the module/function/method # that's being called. Here, we check `Node.target` to see if it's a # batch norm module, and then check `Node.args[0].target` to see if the # input `Node` is a convolution. if ( type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d ): if len(node.args[0].users) > 1: # Output of conv is used by other nodes continue conv = modules[node.args[0].target] bn = modules[node.target] fused_conv = _fuse_conv_bn_eval(conv, bn) _replace_node_module(node.args[0], modules, fused_conv) # As we've folded the batch nor into the conv, we need to replace all uses # of the batch norm with the conv. node.replace_all_uses_with(node.args[0]) # Now that all uses of the batch norm have been replaced, we can # safely remove the batch norm. fx_model.graph.erase_node(node) fx_model.graph.lint() # After we've modified our graph, we need to recompile our graph in order # to keep the generated code in sync. fx_model.recompile() return fx_model @flow.unittest.skip_unless_1n1d() class TestConvBnFuse(flow.unittest.TestCase): def test_fuse(test_case): class WrappedBatchNorm(nn.Module): def __init__(self): super().__init__() self.mod = nn.BatchNorm2d(1) def forward(self, x): return self.mod(x) class M(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 1, 1) self.bn1 = nn.BatchNorm2d(1) self.conv2 = nn.Conv2d(1, 1, 1) self.nested = nn.Sequential(nn.BatchNorm2d(1), nn.Conv2d(1, 1, 1),) self.wrapped = WrappedBatchNorm() def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.conv2(x) x = self.nested(x) x = self.wrapped(x) return x model = M() model.eval() fused_model = _fx_fuse(model) for i in range(10): inp = flow.randn(5, 1, 32, 32) test_case.assertTrue( np.allclose(fused_model(inp).numpy(), model(inp).numpy(), atol=1e-6) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_fx_replace_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow as flow from oneflow.fx import symbolic_trace, replace_pattern from oneflow.test_utils.automated_test_util import * import unittest class M(flow.nn.Module): def __init__(self): super().__init__() def forward(self, x, w1, w2): val1 = flow.neg(w1) m1 = flow.cat([val1, w2]).sum() val2 = flow.neg(w1) m2 = flow.cat([val2, w2]).sum() return x + flow.max(m1) + flow.max(m2) @flow.unittest.skip_unless_1n1d() class TestReplaceOps(flow.unittest.TestCase): def test_pattern(test_case): traced = symbolic_trace(M()) def pattern(a1, a2): val1 = flow.neg(a1) return flow.cat([val1, a2]).sum() def replacement(w1, w2): return flow.stack([w1, w2]) replace_pattern(traced, pattern, replacement) test_case.assertTrue("cat" not in traced.code) test_case.assertTrue("neg" not in traced.code) test_case.assertTrue("stack" in traced.code) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_fx_symbolic_trace_module.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow import oneflow.nn as nn import numpy as np import unittest from oneflow.test_utils.automated_test_util import * class AlexNet(nn.Module): def __init__(self, num_classes: int = 1000) -> None: super(AlexNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), ) self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) self.classifier = nn.Sequential( nn.Dropout(), nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Linear(4096, num_classes), ) def forward(self, x: flow.Tensor) -> flow.Tensor: x = self.features(x) x = self.avgpool(x) x = flow.flatten(x, 1) x = self.classifier(x) return x @flow.unittest.skip_unless_1n1d() class TestAlexNet(flow.unittest.TestCase): def test_alexnet(test_case): m = AlexNet() m = m.eval() gm: flow.fx.GraphModule = flow.fx.symbolic_trace(m) for i in range(5): input = flow.randn(1, 3, 224, 224) test_case.assertTrue( np.allclose(gm(input).numpy(), m(input).numpy(), equal_nan=True) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_gbc1to2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList import time import os def _test_general_basic_communication_1d_to_2d(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() in dst_nd_sbp: return # input placement_x = flow.placement("cuda", ranks=[0, 1, 2]) placement_y = flow.placement("cuda", ranks=[[3, 0], [1, 2]]) local_np = np.arange(4 * 14).reshape(4, 14) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x) # check eager boxing eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y) test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) # check graph boxing flow.boxing.nccl.enable_use_compute_stream(False) class TestGeneralBasicCommunicationGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): y = x.to_global(sbp=dst_nd_sbp, placement=placement_y) return y graph = TestGeneralBasicCommunicationGraph() y = graph(x) out_np = y.numpy() in_np = x.numpy() test_case.assertTrue(np.array_equal(out_np, in_np)) def gen_nd_sbp_1d(): sbp_list = [ flow.sbp.partial_sum(), flow.sbp.broadcast(), flow.sbp.split(0), flow.sbp.split(1), ] return sbp_list def gen_nd_sbp_2d(): nd_sbp_list = [] for sbp0 in gen_nd_sbp_1d(): for sbp1 in gen_nd_sbp_1d(): nd_sbp_list.append([sbp0, sbp1]) return nd_sbp_list @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGeneralBasicCommunication(flow.unittest.TestCase): def test_general_basic_communication(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "1" arg_dict = OrderedDict() arg_dict["src_nd_sbp"] = gen_nd_sbp_1d() arg_dict["dst_nd_sbp"] = gen_nd_sbp_2d() for arg in GenArgList(arg_dict): _test_general_basic_communication_1d_to_2d(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_gbc2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList import time import os def _test_general_basic_communication_same_placement(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() in dst_nd_sbp: return # skip src == dst if src_nd_sbp == dst_nd_sbp: return # in this case, use intra group boxing if src_nd_sbp[0] == dst_nd_sbp[0]: return # in this case, use inter group boxing if ( src_nd_sbp[1] == dst_nd_sbp[1] and src_nd_sbp[0] != src_nd_sbp[1] and dst_nd_sbp[0] != dst_nd_sbp[1] ): return # input placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) local_np = np.arange(4 * 5).reshape(4, 5) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement) # check eager boxing eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) # check graph boxing flow.boxing.nccl.enable_use_compute_stream(False) class TestGeneralBasicCommunicationGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): y = x.to_global(sbp=dst_nd_sbp, placement=placement) return y graph = TestGeneralBasicCommunicationGraph() y = graph(x) out_np = y.numpy() in_np = x.numpy() test_case.assertTrue(np.array_equal(out_np, in_np)) def gen_nd_sbp(): sbp_list = [ flow.sbp.partial_sum(), flow.sbp.broadcast(), flow.sbp.split(0), flow.sbp.split(1), ] nd_sbp_list = [] for sbp0 in sbp_list: for sbp1 in sbp_list: nd_sbp_list.append([sbp0, sbp1]) return nd_sbp_list @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGeneralBasicCommunication(flow.unittest.TestCase): def test_general_basic_communication(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "1" arg_dict = OrderedDict() arg_dict["src_nd_sbp"] = gen_nd_sbp() arg_dict["dst_nd_sbp"] = gen_nd_sbp() for arg in GenArgList(arg_dict): _test_general_basic_communication_same_placement(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_gbc2to1d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList import time import os def _test_general_basic_communication_2d_to_1d(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() == dst_nd_sbp: return # input placement_x = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) placement_y = flow.placement("cuda", ranks=[0, 3, 4]) local_np = np.arange(13 * 5).reshape(13, 5) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x) # check eager boxing eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y) test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) # check graph boxing flow.boxing.nccl.enable_use_compute_stream(False) class TestGeneralBasicCommunicationGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): y = x.to_global(sbp=dst_nd_sbp, placement=placement_y) return y graph = TestGeneralBasicCommunicationGraph() y = graph(x) out_np = y.numpy() in_np = x.numpy() test_case.assertTrue(np.array_equal(out_np, in_np)) def gen_nd_sbp_1d(): sbp_list = [ flow.sbp.partial_sum(), flow.sbp.broadcast(), flow.sbp.split(0), flow.sbp.split(1), ] return sbp_list def gen_nd_sbp_2d(): nd_sbp_list = [] for sbp0 in gen_nd_sbp_1d(): for sbp1 in gen_nd_sbp_1d(): nd_sbp_list.append([sbp0, sbp1]) return nd_sbp_list @flow.unittest.skip_unless_2n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGeneralBasicCommunication(flow.unittest.TestCase): def test_general_basic_communication(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "1" arg_dict = OrderedDict() arg_dict["src_nd_sbp"] = gen_nd_sbp_2d() arg_dict["dst_nd_sbp"] = gen_nd_sbp_1d() for arg in GenArgList(arg_dict): _test_general_basic_communication_2d_to_1d(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_gbc2to2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList import time import os def _test_general_basic_communication_2d_to_2d(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() in dst_nd_sbp: return if dst_nd_sbp[0] == dst_nd_sbp[1] and src_nd_sbp[0] == src_nd_sbp[1]: return # input placement_x = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) placement_y = flow.placement("cuda", ranks=[[0, 3, 4], [2, 5, 6]]) local_np = np.arange(12 * 12).reshape(12, 12) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x) # check eager boxing eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y) test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) # check graph boxing flow.boxing.nccl.enable_use_compute_stream(False) class TestGeneralBasicCommunicationGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): y = x.to_global(sbp=dst_nd_sbp, placement=placement_y) return y graph = TestGeneralBasicCommunicationGraph() y = graph(x) out_np = y.numpy() in_np = x.numpy() test_case.assertTrue(np.array_equal(out_np, in_np)) def gen_nd_sbp(): sbp_list = [ flow.sbp.partial_sum(), flow.sbp.broadcast(), flow.sbp.split(0), flow.sbp.split(1), ] nd_sbp_list = [] for sbp0 in sbp_list: for sbp1 in sbp_list: nd_sbp_list.append([sbp0, sbp1]) return nd_sbp_list @flow.unittest.skip_unless_2n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGeneralBasicCommunication(flow.unittest.TestCase): def test_general_basic_communication(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "1" arg_dict = OrderedDict() arg_dict["src_nd_sbp"] = gen_nd_sbp() arg_dict["dst_nd_sbp"] = gen_nd_sbp() for arg in GenArgList(arg_dict): _test_general_basic_communication_2d_to_2d(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from threading import Thread import numpy as np import oneflow import oneflow as flow from oneflow.nn.graph import GraphModule, GraphTensor import oneflow.framework.graph_build_util as graph_build_util import oneflow.framework.scope_util as scope_util import oneflow.unittest class SubModule(flow.nn.Module): def __init__(self): super().__init__() self.conv1 = flow.nn.Conv2d(1, 1, 5) self.relu = flow.nn.ReLU() def forward(self, x): x = self.conv1(x) x = self.relu(x) return x class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.layer = SubModule() self.fc1 = flow.nn.Linear(36, 4) self.register_buffer("dummy_buff", flow.Tensor(1, 4)) def forward(self, x): x = self.layer(x) x = oneflow._C.flatten(x, 1) x = self.fc1(x) + self.dummy_buff return x @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraph(flow.unittest.TestCase): def test_add_nested_module(test_case): x = flow.Tensor(1, 1, 10, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) m = CustomModule() y = m(x) class CustomGraphNestedModule(flow.nn.Graph): def __init__(self): super().__init__() self.m = m def build(self, x): return self.m(x) g = CustomGraphNestedModule() test_case.assertTrue(isinstance(g.m, flow.nn.graph.Proxy)) test_case.assertEqual(g.m.to(GraphModule).type, "MODULE") test_case.assertEqual(g.m.to(GraphModule).name, "m") test_case.assertTrue(isinstance(g.m.dummy_buff, flow.nn.graph.Proxy)) test_case.assertEqual(g.m.dummy_buff.to(GraphTensor).type, "BUFFER") test_case.assertTrue(isinstance(g.m.layer.conv1, flow.nn.graph.Proxy)) test_case.assertEqual(g.m.layer.conv1.to(GraphModule).name, "conv1") test_case.assertEqual(g.m.layer.conv1.to(GraphModule).name_prefix, "m.layer.") test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.nn.graph.Proxy)) test_case.assertEqual(g.m.layer.conv1.weight.to(GraphTensor).type, "PARAMETER") g.m.layer.conv1.to(GraphModule)._is_executing_forward = True test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.Tensor)) g.m.layer.conv1.to(GraphModule)._is_executing_forward = False test_case.assertEqual(g.m.layer.conv1.kernel_size, (5, 5)) z = g.build(x) test_case.assertTrue(np.array_equal(y.numpy(), z.numpy())) def test_graph_name(test_case): class ACustomGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): return x class BCustomGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): return x class CBCustomGraph(BCustomGraph): def __init__(self): super().__init__() def create_graph(cnt): a = ACustomGraph() test_case.assertEqual(a.name, "ACustomGraph_" + str(cnt)) b = BCustomGraph() test_case.assertEqual(b.name, "BCustomGraph_" + str(cnt)) cb = CBCustomGraph() test_case.assertEqual(cb.name, "CBCustomGraph_" + str(cnt)) flow.nn.Graph._child_init_cnt.clear() for i in range(0, 3): create_graph(i) flow.nn.Graph._child_init_cnt.clear() for i in range(0, 3): create_graph(i) def test_graph_build_ctx(test_case): test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False) with graph_build_util.lazy_mode.guard(True): test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True) with graph_build_util.lazy_mode.guard(False): test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False) test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True) test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False) class CustomGraphGraphBuildCtx(flow.nn.Graph): def __init__(self): super().__init__() 
def build(self, x): test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True) import oneflow.framework.session_context as session_ctx from oneflow.framework.multi_client_session import MultiClientSession session = session_ctx.GetDefaultSession() test_case.assertEqual(type(session), MultiClientSession) import oneflow.framework.scope_util as scope_util scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) test_case.assertEqual(session.id, scope_proto.session_id) test_case.assertEqual( oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(), self.name, ) return x g = CustomGraphGraphBuildCtx() test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False) data = np.array([2.0, 1.0, 0.0, -1.0, -2.0]) x = flow.tensor(data, dtype=flow.float32) g._compile(x) print("graph proto", g._graph_proto) test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False) def test_block_scope(test_case): class SubModule0(flow.nn.Module): def __init__(self): super().__init__() self.conv1 = flow.nn.Conv2d(1, 1, 5) def forward(self, x): scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool test_case.assertEqual(ck_bool, True) stage_int = scope_proto.attr_name2attr_value[ "pipeline_stage_id_hint" ].at_int64 test_case.assertEqual(stage_int, 0) out = self.conv1(x) weight = self.conv1.weight test_case.assertTrue(weight.is_lazy) return out class SubModule1(flow.nn.Module): def __init__(self): super().__init__() self.fc1 = flow.nn.Linear(36, 4, False) self.register_buffer("dummy_buff", flow.Tensor(1, 4)) def forward(self, x): scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) test_case.assertEqual( scope_proto.parent_scope_symbol_id, self.to(flow.nn.graph.GraphModule).prev_scope.symbol_id, ) ck_bool = scope_proto.attr_name2attr_value["checkpointing"] test_case.assertEqual(ck_bool.WhichOneof("value"), None) stage_int = scope_proto.attr_name2attr_value[ "pipeline_stage_id_hint" ].at_int64 test_case.assertEqual(stage_int, 1) name = ( self.to(flow.nn.graph.GraphModule).name_prefix + self.to(flow.nn.graph.GraphModule).name ) prefixes = [] for prefix in scope_proto.scope_op_name_prefixes: prefixes.append(prefix) name_in_scope = ".".join(prefixes) test_case.assertEqual(name, name_in_scope) b = self.dummy_buff dummy_buff_scope_proto = graph_build_util.scope_to_proto( self._buffers["dummy_buff"].to(flow.nn.graph.GraphTensor).scope ) test_case.assertEqual( dummy_buff_scope_proto.parent_scope_symbol_id, scope.symbol_id ) x = self.fc1(x) return x + b class CustomModule1(flow.nn.Module): def __init__(self): super().__init__() self.layer0 = SubModule0() self.layer1 = SubModule1() def forward(self, x, y): print("x0: ", x.shape) x = self.layer0(x) print("x1: ", x.shape) print("y0: ", y.shape) y = self.layer1(y) print("y1: ", y.shape) return (x, y) m = CustomModule1() class CustomGraphBlockScope(flow.nn.Graph): def __init__(self): super().__init__() self.m = m self.m.layer0.to(GraphModule).set_stage(stage_id=0) self.m.layer0.to(GraphModule).activation_checkpointing = True self.m.layer1.to(GraphModule).set_stage(stage_id=1) def build(self, x, y): return self.m(x, y) g = CustomGraphBlockScope() print(g) x = np.ones((1, 1, 10, 10)) x = flow.tensor(x, dtype=flow.float32) y = np.ones((16, 36)) y = flow.tensor(y, dtype=flow.float32) g._compile(x, y) def test_create_optimizer_in_graph(test_case): device = "cuda" linear = flow.nn.Linear(3, 
8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) class OptCreatedInGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = linear # creat optimizer in nn.Graph and add parameter from ProxyModule self.add_optimizer( flow.optim.SGD(self.linear.parameters(), lr=0.001, momentum=0.9) ) def build(self, x): out = self.linear(x) out = out.sum() out.backward() return out g = OptCreatedInGraph() print(g) def test_graph_in_subthread(test_case): class TinyGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, input): return input + 1 def f(): tiny_graph = TinyGraph() input = flow.randn(1, 4) return tiny_graph(input) f() new_thread = Thread(target=f) new_thread.start() new_thread.join() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_activation_checkpoint.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import re import os import unittest import numpy as np import oneflow import oneflow as flow import oneflow.framework.graph_build_util as graph_build_util import oneflow.framework.scope_util as scope_util import oneflow.unittest from oneflow.nn.graph import GraphModule @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraphActivationCheckpoint(flow.unittest.TestCase): def test_activation_checkpoint(test_case): loss_fn = flow.nn.MSELoss(reduction="sum") model = flow.nn.Sequential(flow.nn.Linear(3, 4), flow.nn.Linear(4, 4)) model1 = flow.nn.Sequential(flow.nn.Linear(4, 1), flow.nn.Flatten(0, 1)) class SubModule0(flow.nn.Module): def __init__(self): super().__init__() self.model = model def forward(self, x): scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool test_case.assertEqual(ck_bool, True) out = self.model(x) return out class SubModule1(flow.nn.Module): def __init__(self): super().__init__() self.model = model1 def forward(self, x): scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool test_case.assertEqual(ck_bool, True) out = self.model(x) return out optimizer = flow.optim.SGD(model.parameters(), lr=1e-6) class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.model = SubModule0() self.model1 = SubModule1() self.loss_fn = loss_fn # Add an optimizer self.add_optimizer(optimizer) self.model.to(GraphModule).activation_checkpointing = True self.model1.to(GraphModule).activation_checkpointing = True def build(self, x, y): y_pred = self.model(x) y_pred = self.model1(y_pred) loss = self.loss_fn(y_pred, y) loss.backward() return loss linear_graph = LinearTrainGraph() x = flow.randn(10, 3) y = flow.randn(10) 
        linear_graph._compile(x, y)
        graph_proto = linear_graph._full_graph_proto
        for op in graph_proto.net.op:
            # Check that the flatten gradient operator takes checkpointing as input
            if re.search("flatten.*grad", op.name, re.I) is not None:
                find_check_point = False
                for value in op.user_conf.input.values():
                    if (
                        re.search("Sys-Checkpointing-Fake-Fw-Op", str(value), re.I)
                        is not None
                    ):
                        find_check_point = True
                        print(value)
                test_case.assertTrue(find_check_point)
            # Check that an identity op has been inserted and that the first fake op
            # of a segment has the identity grad as its ctrl-in op
            if (
                re.search(
                    "Sys-Checkpointing-Fake-Fw-Op_model.model.0-matmul*",
                    op.name,
                    re.I,
                )
                is not None
            ):
                find_ctrl = False
                for name in op.ctrl_in_op_name:
                    if re.search("identity", str(name), re.I) is not None:
                        find_ctrl = True
                        print(name)
                test_case.assertTrue(find_ctrl)


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/graph/test_graph_arange.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest

import numpy as np

import oneflow as flow
import oneflow.unittest


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n1d()
class TestArangeGraph(oneflow.unittest.TestCase):
    def test_arange_graph(test_case):
        of_eager_out = flow.arange(start=0, end=100, step=3, device=flow.device("cuda"))

        class ArangeGraph(flow.nn.Graph):
            def __init__(self):
                super().__init__()

            def build(self):
                return flow.arange(start=0, end=100, step=3, device=flow.device("cuda"))

        arange_g = ArangeGraph()
        of_lazy_out = arange_g()
        test_case.assertTrue(
            np.allclose(of_eager_out.numpy(), of_lazy_out.numpy(), 1e-05, 1e-05)
        )


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/graph/test_graph_asymmetric_io.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestGlobalAsymmetricGraph(oneflow.unittest.TestCase): def test_global_asymmetric_graph_gpu(test_case): Broadcast = [flow.sbp.broadcast] Placement_rank_0 = flow.placement("cuda", ranks=[0]) Placement_rank_1 = flow.placement("cuda", ranks=[1]) class MyGlobalAsymmetricModule(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(3, 8, False) self.linear2 = flow.nn.Linear(8, 7, False) self.linear1.to_global(placement=Placement_rank_0, sbp=Broadcast) self.linear2.to_global(placement=Placement_rank_1, sbp=Broadcast) flow.nn.init.ones_(self.linear1.weight) flow.nn.init.constant_(self.linear2.weight, 2.3) def forward(self, x, y): out0 = x + y out1 = self.linear1(out0) out1 = out1.to_global(placement=Placement_rank_1, sbp=Broadcast) out2 = self.linear2(out1) return out2 class MyLocalModule(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(3, 8, False) self.linear2 = flow.nn.Linear(8, 7, False) flow.nn.init.ones_(self.linear1.weight) flow.nn.init.constant_(self.linear2.weight, 2.3) def forward(self, x, y): # print("local_x in rank : ", flow.env.get_rank(), " is : ", x) # print("local_y in rank : ", flow.env.get_rank(), " is : ", y) out0 = x + y out1 = self.linear1(out0) out2 = self.linear2(out1) return out2 my_local_module = MyLocalModule() np_x = np.random.randn(5, 3) np_y = np.ones(3) local_x = flow.tensor(np_x, dtype=flow.float32) global_x = local_x.to_global( placement=flow.placement("cuda", ranks=[0, 1]), sbp=Broadcast ) local_x = global_x.to_local().to("cpu") local_y = flow.tensor(np_y, dtype=flow.float32) local_out = my_local_module(local_x, local_y) # print("eager_local_out: ", local_out) my_module = MyGlobalAsymmetricModule() x = local_x.to_global(placement=Placement_rank_0, sbp=Broadcast) y = local_y.to_global(placement=Placement_rank_0, sbp=Broadcast) class MyAsymmetricGraph(flow.nn.Graph): def __init__(self): super().__init__() self.my_net = my_module def build(self, x, y): return self.my_net(x, y) my_g = MyAsymmetricGraph() graph_out = my_g(x, y) test_case.assertTrue(graph_out.placement == Placement_rank_1) graph_local_out = graph_out.to_local() # NOTE(chengcheng): MUST call for each rank sync correct input copy graph_local_out_np = graph_local_out.numpy() # print("graph_local_out in rank ", flow.env.get_rank(), " is : ", graph_local_out) if flow.env.get_rank() == 0: test_case.assertTrue(graph_local_out.shape.numel() == 0) test_case.assertTrue(graph_local_out_np.size == np.array([]).size) elif flow.env.get_rank() == 1: test_case.assertTrue( np.allclose( graph_local_out.numpy(), local_out.numpy(), atol=1e-4, rtol=1e-4 ) ) else: test_case.assertTrue(False) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_block.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import types import warnings import numpy as np import oneflow as flow import oneflow.nn as nn import oneflow.unittest import oneflow.framework.graph_build_util as graph_build_util import oneflow.framework.scope_util as scope_util @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraphBlock(flow.unittest.TestCase): def test_module_has_custom_func(test_case): class CustomModuleHasFunc(flow.nn.Module): def __init__(self): super().__init__() self.data_mem = 10 def forward(self, x): return self._custom_func(x) def _custom_func(self, x): test_case.assertEqual(self.data_mem, 10) return x class CustomGraphHasFunc(flow.nn.Graph): def __init__(self): super().__init__() self.m = CustomModuleHasFunc() def build(self, x): return self.m(x) g = CustomGraphHasFunc() x = np.ones((10, 10)) x = flow.tensor(x, dtype=flow.float32) out = g(x) test_case.assertTrue(np.array_equal(x.numpy(), out.numpy())) def test_module_has_special_attr(test_case): class CustomModuleHasSpecialAttr(flow.nn.Module): def __init__(self): super().__init__() self.config = 1 self.name = "test_name" def forward(self, x): test_case.assertEqual(self.config, 1) test_case.assertEqual(self.name, "test_name") test_case.assertEqual(self.to(nn.graph.GraphModule).name, "m") return x class CustomGraphHasSpecialAttr(flow.nn.Graph): def __init__(self): super().__init__() self.m = CustomModuleHasSpecialAttr() def build(self, x): return self.m(x) g = CustomGraphHasSpecialAttr() x = np.ones((10, 10)) x = flow.tensor(x, dtype=flow.float32) out = g(x) test_case.assertTrue(np.array_equal(x.numpy(), out.numpy())) def test_block_with_parameter(test_case): device = "cuda" linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.linear = linear def forward(self, x): return self._forward_impl(x) def _forward_impl(self, x): test_case.assertTrue(isinstance(self.linear, flow.nn.graph.Proxy)) return self.linear(x) class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = CustomModule() self.add_optimizer(of_sgd) def build(self, x): out = self.m(x) out = out.sum() out.backward() test_case.assertTrue(self.m.linear.weight.is_lazy) return out linear_t_g = LinearTrainGraph() linear_t_g(x) def test_block_get_class_in_forward(test_case): device = "cuda" linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_sgd = 
flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.linear = linear def forward(self, x): return self._forward_impl(x) def _forward_impl(self, x): test_case.assertTrue(isinstance(self.linear, flow.nn.Module)) test_case.assertTrue(isinstance(self.linear, flow.nn.Linear)) return self.linear(x) class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = CustomModule() test_case.assertTrue(isinstance(self.m.linear, flow.nn.graph.Proxy)) self.add_optimizer(of_sgd) def build(self, x): test_case.assertTrue(isinstance(self.m.linear, flow.nn.Module)) test_case.assertTrue(isinstance(self.m.linear, flow.nn.Linear)) out = self.m(x) out = out.sum() out.backward() test_case.assertTrue(self.m.linear.weight.is_lazy) return out linear_t_g = LinearTrainGraph() test_case.assertTrue(isinstance(linear_t_g.m.linear, flow.nn.graph.Proxy)) linear_t_g(x) def test_block_with_not_registered_module(test_case): device = "cuda" linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.dict = {"lin": linear} def forward(self, x): return self._forward_impl(x) def _forward_impl(self, x): test_case.assertTrue(isinstance(self.dict["lin"], flow.nn.Module)) return self.dict["lin"](x) class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = CustomModule() def build(self, x): out = self.m(x) out = out.sum() return out linear_t_g = LinearTrainGraph() with warnings.catch_warnings(record=True) as w: # Here will print: # UserWarning: Linear(in_features=3, out_features=8, bias=True) is called in a nn.Graph, but not registered into a nn.Graph. 
linear_t_g(x) test_case.assertTrue(len(w) == 1) test_case.assertTrue(issubclass(w[-1].category, UserWarning)) test_case.assertTrue( "is called in a nn.Graph, but not registered into a nn.Graph" in str(w[-1].message) ) def test_block_with_seq_container(test_case): class SubModule0(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(10, 10, False) def forward(self, x): if graph_build_util.lazy_mode.is_enabled(): scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool test_case.assertEqual(ck_bool, True) out = self.linear(x) return out list_of_m = [SubModule0() for i in range(3)] class SeqModule(flow.nn.Module): def __init__(self): super().__init__() self.linears = flow.nn.Sequential(*list_of_m) def forward(self, x): x = self.linears(x) return x class SeqGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linears = flow.nn.Sequential(*list_of_m) self.linears.to(nn.graph.GraphModule).activation_checkpointing = True def build(self, x): x = self.linears(x) return x seq_m = SeqModule() seq_g = SeqGraph() input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32) output_m = seq_m(input) output_g = seq_g(input) # print(seq_g) test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy())) def test_block_with_list_container(test_case): class SubModule0(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(10, 10, False) def forward(self, x): if graph_build_util.lazy_mode.is_enabled(): scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool test_case.assertEqual(ck_bool, True) out = self.linear(x) return out list_of_m = [SubModule0() for i in range(3)] class ModuleListModule(flow.nn.Module): def __init__(self): super().__init__() self.linears = flow.nn.ModuleList(list_of_m) def forward(self, x): for i, _ in enumerate(self.linears): x = self.linears[i](x) return x class ModuleListGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linears = flow.nn.ModuleList(list_of_m) # NOTE: ModuleList doesn't have config. 
# self.linears.to(GraphModule).activation_checkpointing = True for i, _ in enumerate(self.linears): self.linears[i].to( nn.graph.GraphModule ).activation_checkpointing = True def build(self, x): # ModuleList can act as an iterable, or be indexed using ints for i, _ in enumerate(self.linears): x = self.linears[i](x) return x module_list_m = ModuleListModule() module_list_g = ModuleListGraph() input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32) output_m = module_list_m(input) output_g = module_list_g(input) # print(module_list_g) test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy())) def test_module_list_slice(test_case): class ModuleListSlice(nn.Module): def __init__(self,): super().__init__() linear1 = nn.Linear(5, 5, bias=False) linear2 = nn.Linear(5, 5, bias=False) linear3 = nn.Linear(5, 5, bias=False) self.modulelist = nn.ModuleList([linear1, linear2, linear3]) def forward(self, x): sliced_m = self.modulelist[1:] test_case.assertEqual(len(sliced_m), 2) y = sliced_m[1](x) return y class GraphModuleListSlice(nn.Graph): def __init__(self, m): super().__init__() self.m = m def build(self, x): return self.m(x) in_tensor = flow.randn(5, 5) m = ModuleListSlice() eager_out = m(in_tensor) g = GraphModuleListSlice(m) graph_out = g(in_tensor) test_case.assertTrue(np.array_equal(eager_out.numpy(), graph_out.numpy())) def test_block_with_dict_container(test_case): class SubModule0(flow.nn.Module): def __init__(self, out): super().__init__() self.linear = flow.nn.Linear(10, out, False) def forward(self, x): if graph_build_util.lazy_mode.is_enabled(): scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool test_case.assertEqual(ck_bool, True) out = self.linear(x) return out dict_of_m = {"0": SubModule0(10), "1": SubModule0(6)} class ModuleDictModule(flow.nn.Module): def __init__(self): super().__init__() self.linears = flow.nn.ModuleDict(dict_of_m) def forward(self, x): x = self.linears["0"](x) x = self.linears["1"](x) return x class ModuleDictGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linears = flow.nn.ModuleDict(dict_of_m) # NOTE: ModuleDict doesn't have config. 
# self.linears.to(GraphModule).activation_checkpointing = True for k, _ in self.linears.items(): self.linears[k].to( nn.graph.GraphModule ).activation_checkpointing = True def build(self, x): # ModuleDict can act as an iterable, or get using key x = self.linears["0"](x) x = self.linears["1"](x) return x module_dict_m = ModuleDictModule() module_dict_g = ModuleDictGraph() input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32) output_m = module_dict_m(input) output_g = module_dict_g(input) # print(module_dict_g) test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy())) def test_block_with_dict_container_nto1(test_case): class SubModule0(flow.nn.Module): def __init__(self, out): super().__init__() self.linear = flow.nn.Linear(10, out, False) def forward(self, x): if graph_build_util.lazy_mode.is_enabled(): scope = scope_util.current_scope() scope_proto = graph_build_util.scope_to_proto(scope) ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool test_case.assertEqual(ck_bool, True) out = self.linear(x) return out sub_m = SubModule0(10) dict_of_m = {"0": sub_m, "1": sub_m} class ModuleDictModule(flow.nn.Module): def __init__(self): super().__init__() self.linears = flow.nn.ModuleDict(dict_of_m) def forward(self, x): x = self.linears["0"](x) x = self.linears["1"](x) return x class ModuleDictGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linears = flow.nn.ModuleDict(dict_of_m) # NOTE: ModuleDict doesn't have config. # self.linears.to(GraphModule).activation_checkpointing = True for k, _ in self.linears.items(): self.linears[k].to( nn.graph.GraphModule ).activation_checkpointing = True def build(self, x): # ModuleDict can act as an iterable, or get using key x = self.linears["0"](x) x = self.linears["1"](x) return x module_dict_m = ModuleDictModule() module_dict_g = ModuleDictGraph() input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32) output_m = module_dict_m(input) output_g = module_dict_g(input) print(module_dict_g) # print(module_dict_g) test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy())) def test_block_with_para_list_container(test_case): list_of_p = [flow.nn.Parameter(flow.randn(10, 10)) for i in range(2)] class ParaListModule(flow.nn.Module): def __init__(self): super().__init__() self.params = flow.nn.ParameterList(list_of_p) def forward(self, x): for i, _ in enumerate(self.params): x = flow._C.matmul(x, self.params[i]) return x class ParaListGraph(flow.nn.Graph): def __init__(self): super().__init__() self.params = flow.nn.ParameterList(list_of_p) def build(self, x): for i, _ in enumerate(self.params): x = flow._C.matmul(x, self.params[i]) return x para_list_m = ParaListModule() para_list_g = ParaListGraph() # print(para_list_g) input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32) output_m = para_list_m(input) # print(output_m) output_g = para_list_g(input) # print(para_list_g) test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy())) def test_block_with_para_dict_container(test_case): dict_of_p = { "0": flow.nn.Parameter(flow.randn(10, 3)), "1": flow.nn.Parameter(flow.randn(10, 10)), } class ParaDictModule(flow.nn.Module): def __init__(self): super().__init__() self.params = flow.nn.ParameterDict(dict_of_p) def forward(self, x): x = flow._C.matmul(x, self.params["0"]) return x class ParaDictGraph(flow.nn.Graph): def __init__(self): super().__init__() self.params = flow.nn.ParameterDict(dict_of_p) def build(self, x): x = flow._C.matmul(x, self.params["0"]) return x 
para_dict_m = ParaDictModule() para_dict_g = ParaDictGraph() # print(para_dict_g) input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32) output_m = para_dict_m(input) # print(output_m) output_g = para_dict_g(input) # print(para_dict_g) test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy())) def test_mixin_module(test_case): class ModuleMixin(flow.nn.Module): def __init__(self): super().__init__() self._dtype = flow.float32 @property def dtype(self): return self._dtype class ConfigMixin: def hello_from_cfg(self): return "hello_from_cfg" @property def property_from_cfg(self): return 128 class MixedModule(ModuleMixin, ConfigMixin): def __init__(self): super().__init__() def forward(self, x): test_case.assertEqual(self.dtype, flow.float32) test_case.assertEqual(self.hello_from_cfg(), "hello_from_cfg") test_case.assertEqual(self.property_from_cfg, 128) return x mixedm = MixedModule() class GraphConfigMixin(object): @property def hello_from_graph(self): return "hello_from_gcfg" def mixin_get_name(self): return self.name class MixinGraph(flow.nn.Graph, GraphConfigMixin): def __init__(self): super().__init__() self.m = mixedm def build(self, x): test_case.assertEqual(self.hello_from_graph, "hello_from_gcfg") test_case.assertEqual(self.mixin_get_name(), self.name) return self.m(x) g = MixinGraph() x = np.ones((10, 10)) x = flow.tensor(x, dtype=flow.float32) out = g(x) test_case.assertTrue(np.array_equal(x.numpy(), out.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_buffer_limit.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import time import unittest import numpy as np import oneflow as flow import oneflow.unittest def _test_graph_buffer_limit(test_case): class StageLayerModule(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(10, 8, False) self.linear2 = flow.nn.Linear(8, 10, False) flow.nn.init.constant_(self.linear1.weight, 0.023) flow.nn.init.constant_(self.linear2.weight, 1.23) def forward(self, x): out0 = self.linear1(x) out0 = out0 + 1.0 out0 = out0 * 2.0 out1 = self.linear2(out0) return out1 P0 = flow.placement("cuda", ranks=[0]) P1 = flow.placement("cuda", ranks=[1]) PT = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast class PipelineModule(flow.nn.Module): def __init__(self): super().__init__() self.layer_0 = StageLayerModule() self.layer_1 = StageLayerModule() self.layer_0.to_global(P0, B) self.layer_1.to_global(P1, B) def forward(self, x): # stage 0 in0 = x.to_global(P0, B) out0 = self.layer_0(in0) # stage 1 in1 = out0.to_global(P1, B) out1 = self.layer_1(in1) return out1 pp_m = PipelineModule() pp_m.eval() class PipelineGraph(flow.nn.Graph): def __init__(self): super().__init__() self.pp_m = pp_m def build(self, x): return self.pp_m(x) pp_g = PipelineGraph() for i in range(500): x = flow.randn(16, 10) x = x.to_global(P0, B) out = pp_g(x) # print(out.to_local().mean()) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestGraphPipelineBufferLimit(oneflow.unittest.TestCase): def test_graph_buffer_limit(test_case): _test_graph_buffer_limit(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_clip_grad_norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow from oneflow.nn.graph import GraphModule import oneflow.unittest class MyModule1(flow.nn.Module): def __init__(self, param): super().__init__() self.param = flow.nn.Parameter(param) def forward(self, input): x = flow._C.matmul(input, self.param, transpose_b=True) return flow._C.gelu(x) class MyModule2(flow.nn.Module): def __init__(self, param): super().__init__() self.param = flow.nn.Parameter(param) def forward(self, input, target): x = flow._C.matmul(input, self.param) loss = flow._C.sparse_softmax_cross_entropy(x, target) return loss.mean() # return loss def _make_optimizer(params, norm_type, max_norm): return flow.optim.SGD( [ { "params": params, "lr": 1.0, "momentum": 0.0, "clip_grad_max_norm": max_norm, "clip_grad_norm_type": norm_type, }, ] ) class MyGraph(flow.nn.Graph): def __init__(self, module1, module2, optimizer=None, acc=1): super().__init__() self.m1 = module1 self.m2 = module2 if ( module1.param.is_global and module2.param.is_global and module1.param.placement != module2.param.placement ): self.m1.to(GraphModule).set_stage(0) self.m2.to(GraphModule).set_stage(1) if optimizer is not None: self.add_optimizer(optimizer) if acc > 1: self.config.set_gradient_accumulation_steps(acc) def build(self, input, target): x = self.m1(input) if x.is_global and target.is_global and x.placement != target.placement: x = x.to_global(placement=target.placement) loss = self.m2(x, target) loss.backward() return loss class TensorGenerator(object): def __init__( self, batch_size=8, feat1=10, feat2=8, device="cuda", parallel_mode=None ): input = flow.randn(batch_size, feat1).to(device) param1 = flow.randn(feat2, feat1).to(device) param2 = flow.randn(feat2, feat1).to(device) target = flow.randint(0, 10, (batch_size,)).to(device) ranks = np.array(range(flow.env.get_world_size())) placement = flow.placement(device, ranks) self.input = input.to_global(placement, sbp=flow.sbp.broadcast) self.param1 = param1.to_global(placement, sbp=flow.sbp.broadcast) self.param2 = param2.to_global(placement, sbp=flow.sbp.broadcast) self.target = target.to_global(placement, sbp=flow.sbp.broadcast) self.input_sbp = None self.target_sbp = None self.param1_sbp = None self.param2_sbp = None self.placement1 = None self.placement2 = None if parallel_mode is not None: assert isinstance(parallel_mode, str) or isinstance( parallel_mode, (list, tuple) ) if isinstance(parallel_mode, str): parallel_mode = [parallel_mode] assert all(p.upper() in ("DP", "MP", "PP") for p in parallel_mode) assert len(parallel_mode) > 0 and len(parallel_mode) <= 2 self.input_sbp = [] self.target_sbp = [] self.param1_sbp = [] self.param2_sbp = [] has_pp = False for p in parallel_mode: if p == "DP": self.input_sbp.append(flow.sbp.split(0)) self.target_sbp.append(flow.sbp.split(0)) self.param1_sbp.append(flow.sbp.broadcast()) self.param2_sbp.append(flow.sbp.broadcast()) elif p == "MP": self.input_sbp.append(flow.sbp.broadcast()) self.target_sbp.append(flow.sbp.broadcast()) self.param1_sbp.append(flow.sbp.split(0)) self.param2_sbp.append(flow.sbp.split(0)) elif p == "PP": ranks = ranks.reshape(2, -1) self.placement1 = flow.placement(device, ranks[0]) self.placement2 = flow.placement(device, ranks[1]) has_pp = True else: raise ValueError if len(parallel_mode) > 1 and not has_pp: ranks = ranks.reshape(2, -1) self.placement1 = flow.placement(device, ranks) self.placement2 = flow.placement(device, ranks) if len(self.input_sbp) == 0: self.input_sbp = None if len(self.target_sbp) == 0: 
self.target_sbp = None if len(self.param1_sbp) == 0: self.param1_sbp = None if len(self.param2_sbp) == 0: self.param2_sbp = None def local_input(self): return self.input.to_local() def local_target(self): return self.target.to_local() def local_param1(self): return self.param1.clone().to_local() def local_param2(self): return self.param2.clone().to_local() def global_input(self): if self.input_sbp is None and self.placement1 is None: return self.input return self.input.to_global(placement=self.placement1, sbp=self.input_sbp) def global_target(self): if self.target_sbp is None and self.placement2 is None: return self.target return self.target.to_global(placement=self.placement2, sbp=self.target_sbp) def global_param1(self): if self.param1_sbp is None and self.placement1 is None: return self.param1.clone() return self.param1.to_global(placement=self.placement1, sbp=self.param1_sbp) def global_param2(self): if self.param2_sbp is None and self.placement2 is None: return self.param2.clone() return self.param2.to_global(placement=self.placement2, sbp=self.param2_sbp) def _compare_with_eager( test_case, *, batch_size=8, acc=1, norm_type=2.0, max_norm=1.0, device="cuda", parallel_mode=None, rtol=1e-03, atol=1e-05, ): gen = TensorGenerator( batch_size=batch_size, device=device, parallel_mode=parallel_mode ) # eager m1 = MyModule1(gen.local_param1()) m2 = MyModule2(gen.local_param2()) opt = _make_optimizer([m1.param, m2.param], norm_type, max_norm) x = m1(gen.local_input()) loss = m2(x, gen.local_target()) opt.zero_grad() loss.backward() opt.clip_grad() opt.step() loss_a = loss.numpy() grad1_a = m1.param.numpy() grad2_a = m2.param.numpy() # graph graph_m1 = MyModule1(gen.global_param1()) graph_m2 = MyModule2(gen.global_param2()) opt = _make_optimizer([graph_m1.param, graph_m2.param], norm_type, max_norm) graph = MyGraph(graph_m1, graph_m2, opt, acc) graph_loss = graph(gen.global_input(), gen.global_target()) # debug # rank = flow.env.get_rank() # print("") # print(f"[rank{rank}] eager local loss: {loss}") # print( # f"[rank{rank}] graph_loss placement: {graph_loss.placement}, sbp: {graph_loss.sbp}" # ) # print(f"[rank{rank}] graph_loss: {graph_loss}") # local_loss = graph_loss.to_local() # print(f"[rank{rank}] local_loss.numel(): {local_loss.numel()}") # print(f"[rank{rank}] local_loss: {local_loss}") if acc > 1 and graph_loss.numel() == acc: graph_loss = graph_loss.mean() if parallel_mode is None: loss_b = graph_loss.numpy() grad1_b = graph.m1.to(flow.nn.Module).param.numpy() grad2_b = graph.m2.to(flow.nn.Module).param.numpy() else: ranks = np.array(range(flow.env.get_world_size())) placement = flow.placement(device, ranks) loss_b = graph_loss.to_global(placement, flow.sbp.broadcast).to_local().numpy() grad1_b = graph.m1.to(flow.nn.Module).param.to_global( placement, flow.sbp.broadcast ) grad1_b = grad1_b.to_local().numpy() grad2_b = graph.m2.to(flow.nn.Module).param.to_global( placement, flow.sbp.broadcast ) grad2_b = grad2_b.to_local().numpy() # compare test_case.assertTrue( np.allclose(loss_a, loss_b, rtol=rtol, atol=atol), f"{loss_a} vs. 
{loss_b}" ) test_case.assertTrue( np.allclose(grad1_a, grad1_b, rtol=rtol, atol=atol), f"\n{grad1_a}\nvs.\n{grad1_b}", ) test_case.assertTrue( np.allclose(grad2_a, grad2_b, rtol=rtol, atol=atol), f"\n{grad2_a}\nvs.\n{grad2_b}", ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGraphClipGradNorm(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_local(test_case): _compare_with_eager(test_case) @flow.unittest.skip_unless_1n1d() def test_acc(test_case): _compare_with_eager(test_case, batch_size=8, acc=8) @flow.unittest.skip_unless_1n2d() def test_dp(test_case): _compare_with_eager(test_case, parallel_mode="DP") @flow.unittest.skip_unless_1n2d() def test_mp(test_case): _compare_with_eager(test_case, parallel_mode="MP") @flow.unittest.skip_unless_1n2d() def test_pp(test_case): _compare_with_eager(test_case, parallel_mode="PP") @flow.unittest.skip_unless_1n2d() def test_pp_acc(test_case): _compare_with_eager(test_case, batch_size=8, acc=8, parallel_mode="PP") @flow.unittest.skip_unless_1n4d() def test_dp_mp(test_case): _compare_with_eager(test_case, parallel_mode=["DP", "MP"]) @flow.unittest.skip_unless_1n4d() def test_mp_pp(test_case): _compare_with_eager(test_case, parallel_mode=["MP", "PP"]) @flow.unittest.skip_unless_1n4d() def test_dp_pp(test_case): _compare_with_eager(test_case, parallel_mode=["DP", "PP"]) @flow.unittest.skip_unless_1n4d() def test_mp_pp_acc(test_case): _compare_with_eager(test_case, batch_size=8, acc=8, parallel_mode=["MP", "PP"]) @flow.unittest.skip_unless_1n4d() def test_dp_pp_acc(test_case): _compare_with_eager(test_case, batch_size=8, acc=4, parallel_mode=["DP", "PP"]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGraphClipGradNormInf(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_local(test_case): _compare_with_eager(test_case, norm_type=float("inf")) @flow.unittest.skip_unless_1n1d() def test_acc(test_case): _compare_with_eager( test_case, batch_size=8, acc=8, norm_type=-float("inf"), atol=1e-6 ) @flow.unittest.skip_unless_1n2d() def test_dp(test_case): _compare_with_eager( test_case, norm_type=float("inf"), max_norm=2.0, parallel_mode="DP", atol=1e-6, ) @flow.unittest.skip_unless_1n2d() def test_mp(test_case): _compare_with_eager( test_case, norm_type=-float("inf"), max_norm=3.0, parallel_mode="MP", atol=1e-6, ) @flow.unittest.skip_unless_1n2d() def test_pp(test_case): _compare_with_eager( test_case, norm_type=float("inf"), max_norm=4.0, parallel_mode="PP", atol=1e-6, ) @flow.unittest.skip_unless_1n2d() def test_pp_acc(test_case): _compare_with_eager( test_case, batch_size=8, acc=8, norm_type=-float("inf"), max_norm=5.0, parallel_mode="PP", atol=1e-6, ) @flow.unittest.skip_unless_1n4d() def test_dp_mp(test_case): _compare_with_eager( test_case, norm_type=float("inf"), max_norm=1.1, parallel_mode=["DP", "MP"], atol=1e-6, ) @flow.unittest.skip_unless_1n4d() def test_mp_pp(test_case): _compare_with_eager( test_case, norm_type=-float("inf"), max_norm=1.2, parallel_mode=["MP", "PP"], atol=1e-6, ) @flow.unittest.skip_unless_1n4d() def test_dp_pp(test_case): _compare_with_eager( test_case, norm_type=float("inf"), max_norm=1.3, parallel_mode=["DP", "PP"], atol=1e-6, ) @flow.unittest.skip_unless_1n4d() def test_mp_pp_acc(test_case): _compare_with_eager( test_case, batch_size=8, acc=8, norm_type=float("inf"), max_norm=2.1, parallel_mode=["MP", "PP"], atol=1e-6, ) @flow.unittest.skip_unless_1n4d() def test_dp_pp_acc(test_case): _compare_with_eager( test_case, 
batch_size=8, acc=4, norm_type=-float("inf"), max_norm=2.2, parallel_mode=["DP", "PP"], atol=1e-6, ) if __name__ == "__main__": # flow.manual_seed(0) unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_copy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestCopyGraph(oneflow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") def test_copy_graph(test_case): linear = flow.nn.Linear(3, 8, False) input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) x = flow.tensor(input_arr) flow.nn.init.constant_(linear.weight, 2.3) of_eager_out = linear(x) np_out = np.matmul(input_arr, np_weight) test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05)) class LinearGraph(flow.nn.Graph): def __init__(self): super().__init__() self.my_linear = linear.to(flow.device("cuda")) def build(self, x): x = x.to(flow.device("cuda")) return self.my_linear(x) linear_g = LinearGraph() of_lazy_out = linear_g(x) test_case.assertTrue( np.allclose(of_lazy_out.numpy(), of_eager_out.numpy(), 1e-05, 1e-05) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_debug.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import sys import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.nn.graph import GraphModule rank = flow.env.get_rank() def _graph_debug(test_case, v_level=0, ranks=None, max_py_stack_depth=2): class DebugGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = flow.nn.Linear(3, 3) def build(self, x): return x d_g = DebugGraph() d_g.debug(v_level, ranks=ranks, max_py_stack_depth=max_py_stack_depth) if ranks is None: rank_list = [0] elif isinstance(ranks, int): rank_list = [ranks] elif isinstance(ranks, list): rank_list = ranks if ( -1 in rank_list or rank in rank_list ) and v_level >= 0: # v_level == -1 means debug mode is closed test_case.assertTrue(d_g._debug) test_case.assertTrue(d_g.m.to(GraphModule)._debug) print(f"ranks {ranks} rank {rank} debug is opened.") else: test_case.assertTrue(not d_g._debug) test_case.assertTrue(not d_g.m.to(GraphModule)._debug) print(f"ranks {ranks} rank {rank} debug is closed.") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class TestGraphDebug(oneflow.unittest.TestCase): def test_graph_debug_rank_null(test_case): _graph_debug(test_case) def test_graph_debug_rank_0(test_case): _graph_debug(test_case, ranks=0) def test_graph_debug_rank_1(test_case): _graph_debug(test_case, ranks=1) def test_graph_debug_rank_1_and_2(test_case): _graph_debug(test_case, ranks=[1, 2]) def test_graph_debug_rank_all(test_case): _graph_debug(test_case, ranks=-1) def test_graph_debug_mode_closed(test_case): _graph_debug(test_case, v_level=-1) def test_graph_debug_mode_opened(test_case): _graph_debug(test_case, v_level=0) def test_graph_debug_max_py_stack_depth_2(test_case): _graph_debug(test_case, max_py_stack_depth=2) def test_graph_debug_max_py_stack_depth_8(test_case): _graph_debug(test_case, max_py_stack_depth=8) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_depend.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import numpy as np import unittest # used to observe operator optimization and execution order manually # import os # os.environ["ONEFLOW_DEBUG_MODE"] = "1" # os.environ["GLOG_v"] = "3" # os.environ["ENABLE_LOGICAL_CHAIN"] = "true" import oneflow as flow import oneflow.nn as nn import oneflow.unittest # NOTE: nn.functional.depend() behaves differently in the two modes # in EAGER mode, the OP has no effect. That is, the first paramerter # and output are the same tensor (like "y=x" in python), while the # second paramerter will be ignore. 
def _build_graph_and_test(TestModel, in_data, test_case): model = TestModel() y_eager = model(in_data) class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() self.model = model def build(self, x): return self.model(x) graph = TestGraph() # used to observe operator optimization and execution order manually # graph.debug(3) y_lazy = graph(in_data) test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy())) @flow.unittest.skip_unless_1n1d() class TestDependGraph(oneflow.unittest.TestCase): def test_depend_graph_case0(test_case): class TestModel_0(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(128, 128) def forward(self, x): # to ensure "x * 2" be executed before "self.linear(x)" in graph mode # base use case x1 = x * 2 x = nn.functional.depend(x, x1) x2 = self.linear(x) return x2 x = flow.randn([1, 128], dtype=flow.float32) _build_graph_and_test(TestModel_0, x, test_case) def test_depend_graph_case1(test_case): class TestModel_1(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(128, 128) def forward(self, x): # to ensure "x * 2" and "x + 2" be executed before "self.linear(x)" in graph mode # test multiple continuous nn.functional.depend() in a logical chain x1 = x * 2 x2 = x + 2 x = nn.functional.depend(x, x1) x = nn.functional.depend(x, x2) x3 = self.linear(x) return x3 x = flow.randn([1, 128], dtype=flow.float32) _build_graph_and_test(TestModel_1, x, test_case) def test_depend_graph_case2(test_case): class TestModel_2(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(128, 128) def forward(self, x): # to ensure "x * 2" and "x + 2" be executed before "self.linear(x)" in graph mode # some users may code like this x1 = x * 2 x2 = x + 2 x2 = nn.functional.depend(x2, x1) x = nn.functional.depend(x, x2) x3 = self.linear(x) return x3 x = flow.randn([1, 128], dtype=flow.float32) _build_graph_and_test(TestModel_2, x, test_case) def test_depend_graph_case3(test_case): class TestModel_3(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(128, 128) def forward(self, x): # to ensure "x * 2", "x + 2" and "x -2" be executed before "self.linear(x)" in graph mode # a combination of above cases x1 = x * 2 x2 = x + 2 x3 = x - 2 x = nn.functional.depend(x, x1) x2 = nn.functional.depend(x2, x3) x = nn.functional.depend(x, x2) x3 = self.linear(x) return x3 x = flow.randn([1, 128], dtype=flow.float32) _build_graph_and_test(TestModel_3, x, test_case) def test_depend_graph_case4(test_case): class TestModel_4(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(128, 128) def forward(self, x): # the depend OP do nothing and it should be pruned from graph correctly x1 = x * 2 x2 = nn.functional.depend(x, x1) x3 = self.linear(x) return x3 x = flow.randn([1, 128], dtype=flow.float32) _build_graph_and_test(TestModel_4, x, test_case) def test_depend_graph_case5(test_case): class TestModel_5(nn.Module): def __init__(self): super().__init__() self.linear0 = nn.Linear(128, 128) self.linear1 = nn.Linear(128, 128) def forward(self, x): # to ensure "x * 2" be executed before "self.linear0(x)" and # "self.linear1(x)" in graph mode # to test the case that depend OP connect to more than one OPs x1 = x * 2 x = nn.functional.depend(x, x1) x2 = self.linear0(x) x3 = self.linear1(x) return x2 + x3 x = flow.randn([1, 128], dtype=flow.float32) _build_graph_and_test(TestModel_5, x, test_case) def test_depend_graph_case6(test_case): class TestModel_6(nn.Module): def __init__(self): 
super().__init__() self.linear = nn.Linear(128, 128) def forward(self, x): # to ensure "x - 2" be executed before "self.linear(x)" in graph mode # to test the case that the OP connects to Depend OP also connects to other OPs x1 = x * 2 x2 = x1 - 2 x3 = nn.functional.depend(x2, x1) x4 = self.linear(x3) x5 = x2 + x4 return x5 x = flow.randn([1, 128], dtype=flow.float32) _build_graph_and_test(TestModel_6, x, test_case) def test_depend_graph_case7(test_case): class TestModel_7(nn.Module): def __init__(self): super().__init__() def forward(self, x): # to ensure "mp_values * 2" be executed before "max_pool1d" in graph mode # to test the case that OPs have mutiple outputs connect to depend OP x1 = x + 2 mp_values, mp_indices = nn.functional.max_pool1d( x, kernel_size=2, return_indices=True ) mp_values = nn.functional.depend(mp_values, x1) mp_values = mp_values * 2 return mp_values + mp_indices.to(flow.float32) x = flow.randn([1, 2, 3], dtype=flow.float32) _build_graph_and_test(TestModel_7, x, test_case) def test_depend_graph_case8(test_case): class TestModel_1(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(128, 128) def forward(self, x): # to ensure "x * 2" and "x + 2" be executed before "self.linear(x)" in graph mode # to test the case that inputting mutiple depend tensors at a time x1 = x * 2 x2 = x + 2 x = nn.functional.depend(x, [x1, x2]) x3 = self.linear(x) return x3 x = flow.randn([1, 128], dtype=flow.float32) _build_graph_and_test(TestModel_1, x, test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_eye.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import random import oneflow as flow import oneflow.unittest from test_util import generate_graph @flow.unittest.skip_unless_1n1d() class TestEyeGraph(oneflow.unittest.TestCase): def test_eye_graph(test_case): n = random.randint(1, 10) m = random.randint(1, 10) eye_fn = lambda: flow.eye(n, m) y_eager = eye_fn() eye_graph = generate_graph(eye_fn) y_lazy = eye_graph() test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_free_eager_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraphWithEagerTensorCaught(oneflow.unittest.TestCase): def test_eager_tensor_forward_graph(test_case): class MyModuleWithEagerTensorForward(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(3, 8, False) def forward(self, x): y0 = self.linear(x) eager_t = flow.tensor([1.0], dtype=y0.dtype, device=y0.device) out = y0 + eager_t return out my_net_module = MyModuleWithEagerTensorForward() flow.nn.init.constant_(my_net_module.linear.weight, 2.3) x = np.random.randn(5, 3) x = flow.tensor(x, dtype=flow.float32) class GraphEagerTensorCaught(flow.nn.Graph): def __init__(self): super().__init__() self.my_net = my_net_module def build(self, x): return self.my_net(x) my_g = GraphEagerTensorCaught() graph_out = my_g(x) eager_out = my_net_module(x) test_case.assertTrue( np.allclose(graph_out.numpy(), eager_out.numpy(), atol=1e-4, rtol=1e-4) ) @unittest.skip("skip for now, becase it failed 2 times in past week") def test_eager_tensor_to(test_case): class EagerTensorToModule(flow.nn.Module): def __init__(self): super().__init__() def forward(self): # test free eager tensor to t = flow.tensor([1.0], dtype=flow.float32).to("cuda") return t e_m = EagerTensorToModule() class EagerTensorToGraph(flow.nn.Graph): def __init__(self): super().__init__() self.e_m = e_m def build(self): return self.e_m() e_g = EagerTensorToGraph() graph_out = e_g() eager_out = e_m() test_case.assertTrue( np.allclose(graph_out.numpy(), eager_out.numpy(), atol=1e-4, rtol=1e-4) ) def test_two_graph_caught_same_free_eager_tensor(test_case): np_x = np.random.randn(5, 3) np_y = np.random.randn(5, 3) x = flow.tensor(np_x, dtype=flow.float32) y = flow.tensor(np_y, dtype=flow.float32) class GraphAdd(flow.nn.Graph): def __init__(self): super().__init__() def build(self): return x + y class GraphMul(flow.nn.Graph): def __init__(self): super().__init__() def build(self): return x * y g_add = GraphAdd() g_mul = GraphMul() add_out = g_add() mul_out = g_mul() test_case.assertTrue( np.allclose(add_out.numpy(), np_x + np_y, atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(mul_out.numpy(), np_x * np_y, atol=1e-4, rtol=1e-4) ) def test_graph_return_free_eager_tensor(test_case): np_x = np.random.randn(5, 3) x = flow.tensor(np_x, dtype=flow.float32) class GraphReturnEager(flow.nn.Graph): def __init__(self): super().__init__() def build(self): # Return free eager tensor return x g_return_eager = GraphReturnEager() # Run first time ret_eager_out = g_return_eager() test_case.assertTrue( np.allclose(ret_eager_out.numpy(), np_x, atol=1e-4, rtol=1e-4) ) # Run second time ret_eager_out1 = g_return_eager() test_case.assertTrue( np.allclose(ret_eager_out1.numpy(), np_x, atol=1e-4, rtol=1e-4) ) def test_graph_return_inplace_free_eager_tensor(test_case): np_x = np.random.randn(5, 3) x = flow.tensor(np_x, dtype=flow.float32) class GraphInplaceReturnEager(flow.nn.Graph): def __init__(self): super().__init__() def build(self): # x is free eager tensor # mul_ is inplace scalar mul # Input and output of mul_ are both tensor x # After lazy interpretr, tensor x's name will be the ouput lbn of mul_ x.mul_(2) # Here will return the output of mul_ return x g_return_eager = GraphInplaceReturnEager() # Run first time ret_eager_out = g_return_eager() # x in ouput changed # So nn.Graph simulate inplace in nn.Graph.build(). 
test_case.assertTrue( np.allclose(ret_eager_out.numpy(), np_x * 2, atol=1e-4, rtol=1e-4) ) # x has not changed # So nn.Graph inplace will not change free eager tensor. test_case.assertTrue(np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4)) # Run second time ret_eager_out = g_return_eager() test_case.assertTrue( np.allclose(ret_eager_out.numpy(), np_x * 2, atol=1e-4, rtol=1e-4) ) test_case.assertTrue(np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class GlobalFreeEagerTensorGraphTestCase(oneflow.unittest.TestCase): def test_global_eager_tensor_to(test_case): rank = flow.env.get_rank() placement = flow.placement("cpu", ranks=[0, 1]) t_l = flow.tensor([1.0, 2.0], dtype=flow.float32) t = t_l.to_global(placement=placement, sbp=flow.sbp.broadcast) class GlobalEagerTensorToModule(flow.nn.Module): def __init__(self): super().__init__() def forward(self): # test free eager tensor to nonlocal t t = t.to("cuda") return t e_m = GlobalEagerTensorToModule() class GlobalEagerTensorToGraph(flow.nn.Graph): def __init__(self): super().__init__() self.e_m = e_m def build(self): return self.e_m() e_g = GlobalEagerTensorToGraph() graph_out = e_g().to_local() print("g ", graph_out.numpy()) test_case.assertTrue( np.allclose(graph_out.numpy(), t_l.numpy(), atol=1e-4, rtol=1e-4) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_grad_acc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import numpy as np import oneflow as flow import oneflow.unittest def _test_grad_acc_graph(test_case, device): def get_linear_sgd(): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 1.23) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.01, momentum=0.9) return linear, of_sgd x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], device=device, requires_grad=False, ) free_one = flow.tensor([1.0], device=device, requires_grad=False) eager_linear, eager_sgd = get_linear_sgd() eager_out_list = [] eager_weight_list = [] for i in range(12): index = (i % 4) * 2 input = x[index : (index + 2)] # NOTE(chengcheng): unpack x by slice # print("i = ", i, " input = ", input) of_out = eager_linear(input) of_out += free_one # Test free eager tensor one = flow.ones(of_out.shape, dtype=of_out.dtype, device=of_out.device) of_out += one of_out = flow.reshape(of_out, shape=[-1]) of_out = of_out.sum() loss = of_out * 0.25 # NOTE(chengcheng): scale loss by grad acc loss.backward() if (i + 1) % 4 == 0: eager_sgd.step() eager_sgd.zero_grad() eager_weight_list.append(eager_linear.weight.numpy()) # print("of_eager_weight in step: ", i, # " weight = ", eager_linear.weight.numpy()) # print("of_eager_out : ", of_out.numpy()) eager_out_list.append(of_out.numpy()) graph_linear, graph_sgd = get_linear_sgd() graph_out_list = [] graph_weight_list = [] class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = graph_linear self.add_optimizer(graph_sgd) self.config.set_gradient_accumulation_steps(4) def build(self, x): out = self.linear(x) out += free_one # Test free eager tensor one = flow.ones(out.shape, dtype=out.dtype, device=out.device) out += one out = flow.reshape(out, shape=[-1]) # print("out.shape: ", out.shape) loss = out.sum() loss.backward() return out, loss linear_t_g = LinearTrainGraph() for i in range(3): # NOTE(chengcheng): Graph call 1 step for 1 mini-batch(4 micro-batch) non_scalar_out, of_out = linear_t_g(x) # print("of_lazy_out : ", of_out.numpy()) graph_out_list.append(of_out.numpy()) graph_weight_list.append(graph_linear.weight.numpy()) # print("of_lazy_weight in step: ", i, # " weight = ", graph_linear.weight.numpy()) for i in range(3): test_case.assertTrue(np.allclose(eager_weight_list[i], graph_weight_list[i])) for j in range(4): test_case.assertTrue( eager_out_list[i * 4 + j].item() == graph_out_list[i][j] ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGradAccGraph(oneflow.unittest.TestCase): def test_grad_acc_graph_gpu(test_case): _test_grad_acc_graph(test_case, flow.device("cuda")) def test_grad_acc_graph_cpu(test_case): _test_grad_acc_graph(test_case, flow.device("cpu")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_image_gpu_decoder.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.unittest class OFRecordDataLoader(flow.nn.Module): def __init__(self): super().__init__() batch_size = 4 image_size = 224 self.train_record_reader = flow.nn.OFRecordReader( flow.unittest.dataset_dir("imagenette/ofrecord"), batch_size=batch_size, data_part_num=1, part_name_suffix_length=5, random_shuffle=True, shuffle_after_epoch=True, # placement=flow.placement("cpu", ranks=[0]), # sbp=[flow.sbp.broadcast] ) self.record_label_decoder = flow.nn.OFRecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) self.bytes_decoder = flow.nn.OFRecordBytesDecoder("encoded") self.image_gpu_decoder = flow.nn.OFRecordImageGpuDecoderRandomCropResize( target_width=image_size, target_height=image_size, num_workers=3 ) color_space = "RGB" output_layout = "NHWC" self.flip = flow.nn.CoinFlip( batch_size=batch_size, # placement=flow.placement("cpu", ranks=[0]), # sbp=[flow.sbp.broadcast] ) rgb_mean = [123.68, 116.779, 103.939] rgb_std = [58.393, 57.12, 57.375] self.crop_mirror_norm = flow.nn.CropMirrorNormalize( color_space=color_space, output_layout=output_layout, mean=rgb_mean, std=rgb_std, output_dtype=flow.float, ) def forward(self) -> (flow.Tensor, flow.Tensor): train_record = self.train_record_reader() label = self.record_label_decoder(train_record) encoded = self.bytes_decoder(train_record) image = self.image_gpu_decoder(encoded) rng = self.flip() if image.is_cuda: rng = rng.to("cuda") image = self.crop_mirror_norm(image, rng) return image, label @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestImageGpuDecoderGraph(oneflow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") def test_image_gpu_decoder_graph(test_case): cc_reader = OFRecordDataLoader() class GraphReader(flow.nn.Graph): def __init__(self): super().__init__() self.my_reader = cc_reader def build(self): return self.my_reader() reader_g = GraphReader() image, label = reader_g() print(image.shape) print(label) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_inplace_add.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest def _test_graph_lazy_inplace(test_case, x, y): class LazyInplaceAdd(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x, y): x += y return x z = LazyInplaceAdd()(x, y) test_case.assertTrue(np.allclose(z.numpy(), (x + y).numpy(), 1e-05, 1e-05)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLocalInplace(oneflow.unittest.TestCase): def test_graph_inplace_gpu(test_case): x = flow.randn(10, 10, device=flow.device("cuda")) y = flow.ones(10, device=flow.device("cuda")) _test_graph_lazy_inplace(test_case, x, y) def test_graph_inplace_cpu(test_case): x = flow.randn(10, 10, device=flow.device("cpu")) y = flow.ones(10, device=flow.device("cpu")) _test_graph_lazy_inplace(test_case, x, y) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestGlobalInplace(oneflow.unittest.TestCase): def test_graph_inplace_gpu(test_case): x = flow.randn( 10, 10, placement=flow.placement("cuda", ranks=[0, 1]), sbp=flow.sbp.split(1), ) y = flow.ones( 10, placement=flow.placement("cuda", ranks=[0, 1]), sbp=flow.sbp.broadcast ) _test_graph_lazy_inplace(test_case, x, y) def test_graph_inplace_cpu(test_case): x = flow.randn( 10, 10, placement=flow.placement("cpu", ranks=[0, 1]), sbp=flow.sbp.split(1) ) y = flow.ones( 10, placement=flow.placement("cpu", ranks=[0, 1]), sbp=flow.sbp.broadcast ) _test_graph_lazy_inplace(test_case, x, y) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_io_check.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings from collections import OrderedDict from dataclasses import dataclass, fields from typing import Any, Tuple from collections import OrderedDict import os import unittest import sys import numpy as np import oneflow as flow import oneflow.unittest from oneflow.framework.tensor import Tensor, TensorTuple from oneflow.framework.args_tree import ArgsTree from oneflow.nn.graph import GraphModule class BaseOutput(OrderedDict): def __post_init__(self): class_fields = fields(self) # Safety and consistency checks if not len(class_fields): raise ValueError(f"{self.__class__.__name__} has no fields.") first_field = getattr(self, class_fields[0].name) other_fields_are_none = all( getattr(self, field.name) is None for field in class_fields[1:] ) if other_fields_are_none and isinstance(first_field, dict): for key, value in first_field.items(): self[key] = value else: for field in class_fields: v = getattr(self, field.name) if v is not None: self[field.name] = v def __delitem__(self, *args, **kwargs): raise Exception( f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." 
) def setdefault(self, *args, **kwargs): raise Exception( f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." ) def pop(self, *args, **kwargs): raise Exception( f"You cannot use ``pop`` on a {self.__class__.__name__} instance." ) def update(self, *args, **kwargs): raise Exception( f"You cannot use ``update`` on a {self.__class__.__name__} instance." ) def __getitem__(self, k): if isinstance(k, str): inner_dict = {k: v for (k, v) in self.items()} if ( self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample" ): warnings.warn( "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or" " `'images'` instead.", DeprecationWarning, ) return inner_dict["images"] return inner_dict[k] else: return self.to_tuple()[k] def __setattr__(self, name, value): if name in self.keys() and value is not None: # Don't call self.__setitem__ to avoid recursion errors super().__setitem__(name, value) super().__setattr__(name, value) def __setitem__(self, key, value): # Will raise a KeyException if needed super().__setitem__(key, value) # Don't call self.__setattr__ to avoid recursion errors super().__setattr__(key, value) def to_tuple(self) -> Tuple[Any]: """ Convert self to a tuple containing all the attributes/keys that are not `None`. """ return tuple(self[k] for k in self.keys()) @dataclass class CustomDataClass(BaseOutput): sample: flow.Tensor @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraphIOCheck(flow.unittest.TestCase): def test_io_node(test_case): x = np.ones((2, 2)) x = flow.tensor(x, dtype=flow.float32) t2 = np.ones((2, 2)) t2 = flow.tensor(t2, dtype=flow.float32) t3 = np.ones((2, 2)) t3 = flow.tensor(t3, dtype=flow.float32) lt0 = list() lt0.append(t2) lt0.append(t3) t4 = np.ones((2, 2)) t4 = flow.tensor(t4, dtype=flow.float32) t4 = np.ones((2, 2)) t4 = flow.tensor(t4, dtype=flow.float32) def fn(*args, **kwargs): inp = (args, kwargs) print("origin: ", inp) args_tree = ArgsTree(inp, True, "Graph_0", None) for (name, arg) in args_tree.iter_named_nodes(): print(name, repr(arg)) def leaf_fn(arg): if isinstance(arg.value(), str): return "mapped_str" return arg.value() m_v = args_tree.map_leaf(leaf_fn) print("mapped:", m_v) return m_v[0], m_v[1] ret = fn(None, 1, "test_str", x, lt0, {"t": t4, "l": lt0}, kw=t4) print(ret) test_case.assertEqual(ret[0][2], "mapped_str") test_case.assertEqual(id(ret[1]["kw"]), id(t4)) def test_io_node_with_simple_tuple_or_list_input(self): x = np.ones((2, 2)) x = flow.tensor(x, dtype=flow.float32) t2 = np.ones((2, 2)) t2 = flow.tensor(t2, dtype=flow.float32) t3 = np.ones((2, 2)) t3 = flow.tensor(t3, dtype=flow.float32) t4 = np.ones((2, 2)) t4 = flow.tensor(t4, dtype=flow.float32) t5 = np.ones((2, 2)) t5 = flow.tensor(t4, dtype=flow.float32) t6 = np.ones((2, 2)) t6 = flow.tensor(t4, dtype=flow.float32) input_tuple = (x, t2, t3, t4) input_list = [t5, t6] def fn(args): print("origin: ", args) args_tree = ArgsTree(args, False) for arg in args_tree.iter_nodes(): print(repr(arg)) def leaf_fn(value): if isinstance(value, Tensor) and not value.is_contiguous(): value.contiguous_() return value m_v = args_tree.map_tuple_leaf(leaf_fn) print("mapped:", m_v) return m_v # input tuple ret = fn(input_tuple) print(ret) self.assertTrue(isinstance(ret, tuple)) self.assertEqual(id(ret[0]), id(x)) self.assertEqual(id(ret[1]), id(t2)) self.assertEqual(id(ret[2]), id(t3)) self.assertEqual(id(ret[3]), id(t4)) # input list 
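        # Like the tuple case above, the mapped result is expected to keep the
        # container type (here a list) and the identity of each tensor, since
        # leaf_fn returns already-contiguous tensors unchanged.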
ret = fn(input_list) print(ret) self.assertTrue(isinstance(ret, list)) self.assertEqual(id(ret[0]), id(t5)) self.assertEqual(id(ret[1]), id(t6)) def test_custom_class(test_case): x = np.ones((2, 2)) x = flow.tensor(x, dtype=flow.float32) ordered_d = CustomDataClass(sample=x) def fn(*args, **kwargs): inp = (args, kwargs) print("origin: ", inp) args_tree = ArgsTree(inp, True, "Graph_0", None) for (name, arg) in args_tree.iter_named_nodes(): print(name, repr(arg)) def leaf_fn(arg): if isinstance(arg.value(), dict): return "replaced" return arg.value() m_v = args_tree.map_leaf(leaf_fn) print("mapped:", m_v) return m_v[0], m_v[1] ret = fn(ordered_d) print(ret) def test_non_tensor_types_of_module(test_case): class CustomModuleIOCheck(flow.nn.Module): def __init__(self): super().__init__() def forward(self, t, lt, n, i, s, **kwargs): return t, lt, n, i, s, kwargs class CustomGraphIOCheck(flow.nn.Graph): def __init__(self): super().__init__() self.m = CustomModuleIOCheck() self.m.to(GraphModule).activation_checkpointing = True def build(self, t, lt, n, **kwargs): rt, rlt, n, ri, rs, dic = self.m(t, lt, n, 1, "2", **kwargs) return t, lt, n, dic g = CustomGraphIOCheck() x = flow.tensor(np.random.randn(1,), dtype=flow.float32) t2 = flow.tensor(np.random.randn(1,), dtype=flow.float32) t3 = flow.tensor(np.random.randn(1,), dtype=flow.float32) lt0 = list() lt0.append(t2) lt0.append(t3) t7 = flow.tensor(np.random.randn(1,), dtype=flow.float32) dic2 = {"kw2": t7} lt0.append(dic2) t4 = flow.tensor(np.random.randn(1,), dtype=flow.float32) t5 = flow.tensor(np.random.randn(1,), dtype=flow.float32) t6 = flow.tensor(np.random.randn(1,), dtype=flow.float32) lt1 = list() lt1.append(t5) lt1.append(t6) ot, olt, on, odic = g(x, lt0, None, kw0=t4, kw1=lt1) # print(g) test_case.assertTrue(np.array_equal(x.numpy(), ot.numpy())) test_case.assertTrue(isinstance(olt, list)) test_case.assertTrue(isinstance(olt[0], Tensor)) test_case.assertTrue(np.array_equal(olt[0].numpy(), lt0[0].numpy())) test_case.assertTrue(isinstance(olt[1], Tensor)) test_case.assertTrue(np.array_equal(olt[1].numpy(), lt0[1].numpy())) test_case.assertTrue(isinstance(olt[2], dict)) test_case.assertTrue( np.array_equal(olt[2]["kw2"].numpy(), lt0[2]["kw2"].numpy()) ) test_case.assertTrue(on is None) test_case.assertTrue(isinstance(odic, dict)) test_case.assertTrue(np.array_equal(odic["kw0"].numpy(), t4.numpy())) test_case.assertTrue(np.array_equal(odic["kw1"][0].numpy(), t5.numpy())) test_case.assertTrue(np.array_equal(odic["kw1"][1].numpy(), t6.numpy())) def test_graph_return_size_0_tuple(test_case): def test_output(input, output_type): print(input) input = (input,) print(input) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() def forward(self, t): return t[0] class CustomGraphCheck1Ret(flow.nn.Graph): def __init__(self): super().__init__() self.m = CustomModule() def build(self, t): rt = self.m(t) return rt model = CustomModule() graph = CustomGraphCheck1Ret() model_out = model(input) graph_out = graph(input) if output_type is None: test_case.assertTrue(model_out is output_type) test_case.assertTrue(graph_out is output_type) else: test_case.assertTrue(isinstance(model_out, output_type)) test_case.assertTrue(isinstance(graph_out, output_type)) x = np.ones((1, 10)) x = flow.tensor(x, dtype=flow.float32) # test size 1 tuple x_tuple = (x,) test_output(x_tuple, tuple) # test size 1 list x_list = [ x, ] test_output(x_list, list) # test tensor test_output(x, Tensor) def test_graph_return_dict_tuple(test_case): def test_output(input): 
print(input) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() def forward(self, t): return {"output": t} class CustomGraphCheck1Ret(flow.nn.Graph): def __init__(self): super().__init__() self.m = CustomModule() def build(self, t): rt = self.m(t) return rt model = CustomModule() graph = CustomGraphCheck1Ret() model_out = model(input) graph_out = graph(input) test_case.assertTrue(isinstance(model_out, dict)) test_case.assertTrue(isinstance(graph_out, dict)) test_case.assertEqual(len(model_out), 1) test_case.assertEqual(len(graph_out), 1) test_case.assertTrue("output" in model_out) test_case.assertTrue("output" in graph_out) test_case.assertTrue( np.array_equal(model_out["output"].numpy(), graph_out["output"].numpy()) ) x = np.ones((1, 10)) x = flow.tensor(x, dtype=flow.float32) # test tensor test_output(x) def test_graph_outputs_buffer(test_case): class CustomModuleIOCheck(flow.nn.Module): def __init__(self): super().__init__() def forward(self, t, tp, lt, n, i, s): return t, tp, lt, n, i, s class CustomGraphIOCheck1(flow.nn.Graph): def __init__(self): super().__init__() self.config.set_outputs_buffer_size(5) self.m = CustomModuleIOCheck() def build(self, t, tp, lt, n): rt, rtp, rlt, n, ri, rs = self.m(t, tp, lt, n, 1, "2") return t, tp, lt, n g = CustomGraphIOCheck1() x = np.ones((10, 10)) x = flow.tensor(x, dtype=flow.float32) y = np.ones((10, 10)) y = flow.tensor(y, dtype=flow.float32) # IO with TensorTuple cannot pass this test, # its tensor item's id is weird. # t0 = np.ones((10, 10)) # t0 = flow.tensor(t0, dtype=flow.float32) # t1 = np.ones((10, 10)) # t1 = flow.tensor(t1, dtype=flow.float32) # tp0 = TensorTuple() # tp0.append(t0) # tp0.append(t1) t2 = np.ones((10, 10)) t2 = flow.tensor(t2, dtype=flow.float32) t3 = np.ones((10, 10)) t3 = flow.tensor(t3, dtype=flow.float32) lt0 = list() lt0.append(t2) lt0.append(t3) # Check there is not duplicated tensor in outputs buffer and outputs. out_id_dic = dict() out_tensor_holder = dict() def check_id_and_add(t, name): if t is not None: tid = id(t) assert ( tid not in out_id_dic ), f"tid {tid}, now name {name}, inserted name {out_id_dic[tid]}" test_case.assertTrue(tid not in out_id_dic) out_id_dic[tid] = name # It seems that python id maybe re-used, hold it to avoid gc re-using it. 
# ref: https://stackoverflow.com/questions/52096582/how-unique-is-pythons-id out_tensor_holder[name] = t def call_and_check(idx): # ot, otp, olt, on = g(x, tp0, lt0, None) ot, otp, olt, on = g(x, y, lt0, None) if idx == 0: test_case.assertEqual(len(g._outputs_tensor_tuple_buffer), 5) for b_idx, buffer in enumerate(g._outputs_tensor_tuple_buffer): for i_idx, item in enumerate(buffer): check_id_and_add( item, "buffer_" + str(b_idx) + "_" + str(i_idx) ) test_case.assertTrue(np.array_equal(x.numpy(), ot.numpy())) check_id_and_add(ot, "ot_" + str(idx)) # test_case.assertTrue(isinstance(otp, TensorTuple)) # check_id_and_add(otp, "otp_" + str(idx)) # test_case.assertTrue(isinstance(otp[0], Tensor)) # check_id_and_add(otp[0], "otp0_" + str(idx)) # test_case.assertTrue(np.array_equal(otp[0].numpy(), tp0[0].numpy())) # test_case.assertTrue(isinstance(otp[1], Tensor)) # check_id_and_add(otp[1], "otp1_" + str(idx)) # test_case.assertTrue(np.array_equal(otp[1].numpy(), tp0[1].numpy())) test_case.assertTrue(isinstance(otp, Tensor)) check_id_and_add(otp, "otp_" + str(idx)) test_case.assertTrue(np.array_equal(y.numpy(), otp.numpy())) test_case.assertTrue(isinstance(olt, list)) check_id_and_add(olt, "olt_" + str(idx)) test_case.assertTrue(isinstance(olt[0], Tensor)) check_id_and_add(olt[0], "olt0_" + str(idx)) test_case.assertTrue(np.array_equal(olt[0].numpy(), lt0[0].numpy())) check_id_and_add(olt[1], "olt1_" + str(idx)) test_case.assertTrue(np.array_equal(olt[1].numpy(), lt0[1].numpy())) test_case.assertTrue(on is None) for i in range(15): call_and_check(i) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_linear.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest def _test_linear_graph(test_case, device): linear = flow.nn.Linear(3, 8, False) linear = linear.to(device) input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) x = flow.tensor(input_arr, device=device) flow.nn.init.constant_(linear.weight, 2.3) of_eager_out = linear(x) np_out = np.matmul(input_arr, np_weight) test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05)) class LinearGraph(flow.nn.Graph): def __init__(self): super().__init__() self.my_linear = linear def build(self, x): return self.my_linear(x) linear_g = LinearGraph() linear_g.debug(0) of_lazy_out = linear_g(x) test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())) def _test_linear_graph_func(test_case, device): linear = flow.nn.Linear(3, 8, False) linear = linear.to(device) input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) x = flow.tensor(input_arr, device=device) flow.nn.init.constant_(linear.weight, 2.3) of_eager_out = linear(x) np_out = np.matmul(input_arr, np_weight) test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05)) @flow.nn.Graph.trace def linear_func(x): return linear(x) of_lazy_out = linear_func(x) test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLinearGraph(oneflow.unittest.TestCase): def test_linear_graph_gpu(test_case): _test_linear_graph(test_case, flow.device("cuda")) def test_linear_graph_cpu(test_case): _test_linear_graph(test_case, flow.device("cpu")) def test_linear_graph_func_gpu(test_case): _test_linear_graph_func(test_case, flow.device("cuda")) def test_linear_graph_func_cpu(test_case): _test_linear_graph_func(test_case, flow.device("cpu")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_linear_train.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import numpy as np import oneflow as flow import oneflow.unittest def _test_linear_train_graph(test_case, device): def train_with_module(iter_num=3): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) def one_iter(): of_out = linear(x) of_out = of_out.sum() of_out.backward() of_sgd.step() of_sgd.zero_grad() return of_out.numpy(), linear.weight.numpy() check_list = [] for i in range(iter_num): check_list.append(one_iter()) return check_list def train_with_graph(iter_num=3): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = linear self.add_optimizer(of_sgd) def build(self, x): out = self.linear(x) out = out.sum() out.backward() return out linear_t_g = LinearTrainGraph() def one_iter(): of_graph_out = linear_t_g(x) print(linear_t_g.linear) return ( of_graph_out.numpy(), linear_t_g.linear.weight.to(flow.Tensor).numpy(), ) check_list = [] for i in range(iter_num): check_list.append(one_iter()) return check_list iter_num = 3 module_check_list = train_with_module(iter_num) graph_check_list = train_with_graph(iter_num) for i in range(iter_num): # check equal on loss test_case.assertTrue( np.array_equal(module_check_list[i][0], graph_check_list[i][0]) ) # check equal on weight test_case.assertTrue( np.array_equal(module_check_list[i][1], graph_check_list[i][1]) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLinearTrainGraph(oneflow.unittest.TestCase): def test_linear_train_graph_gpu(test_case): _test_linear_train_graph(test_case, flow.device("cuda")) def test_linear_train_graph_cpu(test_case): _test_linear_train_graph(test_case, flow.device("cpu")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from collections import OrderedDict from test_util import GenArgList shapes = {2: (128, 8), 3: (16, 8, 64), 4: (16, 8, 32, 32), 5: (16, 8, 16, 16, 16)} def compare_loss(device_type, dim, reduction, cls, data_generator): x, y = data_generator(dim, device_type) f = cls(reduction=reduction).to(device_type) z_eager = f(x, y) class CurrentGraph(flow.nn.Graph): def __init__(self) -> None: super().__init__() self.f = f def build(self, x, y): return self.f(x, y) f_g = CurrentGraph() z_lazy = f_g(x, y) assert np.allclose(z_eager.numpy(), z_lazy.numpy(), rtol=1.0e-5, atol=1.0e-5) def generate_necessity_default(dim: int, device: str): shape = shapes[dim] x_np = np.random.uniform(0, 1, shape) y_np = np.random.uniform(0, 1, shape) x = flow.tensor(x_np, dtype=flow.float32, device=device) y = flow.tensor(y_np, dtype=flow.float32, device=device) return x, y def generate_necessity_for_cross_entropy_or_nll_loss(dim: int, device: str): shape = shapes[dim] y_shape = (shape[0],) if dim == 2 else (shape[0], *shape[2:]) x_np = np.random.uniform(0, 1, shape) y_np = np.random.randint(0, shape[1], y_shape) x = flow.tensor(x_np, dtype=flow.float32, device=device) y = flow.tensor(y_np, dtype=flow.int32, device=device) return x, y @flow.unittest.skip_unless_1n1d() class TestKLDivLossGraph(oneflow.unittest.TestCase): def test_kl_div_loss_graph(testcase): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dim"] = [2, 3, 4, 5] arg_dict["reduction"] = ["sum", "mean"] arg_dict["cls"] = [flow.nn.KLDivLoss] arg_dict["data_generator"] = [generate_necessity_default] for arg in GenArgList(arg_dict): compare_loss(*arg) @flow.unittest.skip_unless_1n1d() class TestSmoothL1LossGraph(oneflow.unittest.TestCase): def test_smooth_l1_loss_graph(testcase): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dim"] = [2, 3, 4, 5] arg_dict["reduction"] = ["sum", "mean"] arg_dict["cls"] = [flow.nn.SmoothL1Loss] arg_dict["data_generator"] = [generate_necessity_default] for arg in GenArgList(arg_dict): compare_loss(*arg) @flow.unittest.skip_unless_1n1d() class TestBCELossOrWithLogitsGraph(flow.unittest.TestCase): def test_bce_loss_graph(testcase): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dim"] = [2, 3, 4, 5] arg_dict["reduction"] = ["sum", "mean"] arg_dict["cls"] = [flow.nn.BCELoss, flow.nn.BCEWithLogitsLoss] arg_dict["data_generator"] = [generate_necessity_default] for arg in GenArgList(arg_dict): compare_loss(*arg) @flow.unittest.skip_unless_1n1d() class TestCrossEntropyOrNllLossGraph(flow.unittest.TestCase): def test_cross_entropy_loss_or_nll_loss_graph(testcase): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dim"] = [2, 3, 4, 5] arg_dict["reduction"] = ["sum", "mean"] arg_dict["cls"] = [flow.nn.CrossEntropyLoss, flow.nn.NLLLoss] arg_dict["data_generator"] = [generate_necessity_for_cross_entropy_or_nll_loss] for arg in GenArgList(arg_dict): compare_loss(*arg) if __name__ == "__main__": unittest.main() ================================================ FILE: 
python/oneflow/test/graph/test_graph_lr_scale.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict class _Block(flow.nn.Module): def __init__(self, feats, device=None, placement=None): super().__init__() ones = flow.ones(feats) if placement is not None: ones = ones.to_global(placement=placement, sbp=flow.sbp.broadcast()) elif device is not None: ones = ones.to(device) self.param = flow.nn.Parameter(ones) def forward(self, x): return x + self.param class _MyModule(flow.nn.Module): def __init__(self, feats, depth, device=None, placement=None): super().__init__() self.layers = flow.nn.ModuleList( [ _Block(feats=feats, device=device, placement=placement) for i in range(depth) ] ) def forward(self, x): for layer in self.layers: x = layer(x) return x class _MyGraph(flow.nn.Graph): def __init__(self, model, optimizer, lr_scheduler): super().__init__() self.m = model self.add_optimizer(optimizer, lr_sch=lr_scheduler) def build(self, input): out = self.m(input) out.sum().backward() return out def _lrs_param_groups(model, base_scale): param_groups = [] for i, layer in enumerate(model.layers): this_scale = base_scale ** (i + 1) param_group = {"params": layer.parameters(), "lr_scale": this_scale} param_groups.append(param_group) return param_groups def _rand_input(shape, device=None, placement=None, requires_grad=False): input = flow.tensor(np.random.rand(*shape).astype(np.float32)) if placement is not None: input = input.to_global(placement=placement, sbp=flow.sbp.split(0)) elif device is not None: input = input.to(device) if requires_grad: input.requires_grad_() return input def _test_lrs(test_case, **kwargs): verbose = kwargs.pop("verbose", False) if verbose: print(f"#### kwargs={kwargs}") batch_size = kwargs.pop("batch_size", 4) feats = kwargs.pop("feats", 768) depth = kwargs.pop("depth", 3) lr = kwargs.pop("lr", 1.0) base_scale = kwargs.pop("base_scale", 0.1) device_type = kwargs.pop("device_type", "cuda") placement = kwargs.pop("placement", None) graph_mode = kwargs.pop("graph_mode", True) model = _MyModule(feats=feats, depth=depth, device=device_type, placement=placement) param_groups = _lrs_param_groups(model, base_scale=base_scale) optimizer = flow.optim.SGD(param_groups, lr=lr) lr_scheduler = flow.optim.lr_scheduler.ConstantLR( optimizer, factor=1.0, total_iters=100 ) model_graph = _MyGraph(model, optimizer, lr_scheduler) input = _rand_input( (batch_size, feats), device=device_type, placement=placement, requires_grad=True ) t_params = [] if graph_mode: for i in range(depth): origin_p = model.layers[i].param.numpy() init_grad = float(batch_size * flow.env.get_world_size()) t_params.append(origin_p - float(init_grad) * lr * (base_scale ** (i + 1))) ret = model_graph(input) else: for i in range(depth): origin_p = model.layers[i].param.numpy() init_grad = 
float(batch_size * flow.env.get_world_size()) t_params.append(origin_p - float(init_grad) * lr) optimizer.zero_grad() ret = model(input) ret.sum().backward() optimizer.step() lr_scheduler.step() if verbose: print("#### input") print(input) # sync np_ret = ret.numpy() print("#### ret") print(np_ret) for i in range(depth): np_param = model.layers[i].param.numpy() print(f"#### layer{i} param") print(np_param) print("#### grad") print(input.grad) for i in range(depth): np_param = model.layers[i].param.numpy() t_param = t_params[i] test_case.assertTrue( np.allclose(np_param, t_param), f"\n{np_param}\n vs. \n{t_param}" ) @flow.unittest.skip_unless_1n1d() class LRScaleTest(flow.unittest.TestCase): def test_lr_scale(self): arg_dict = OrderedDict() arg_dict["batch_size"] = [2, 4] arg_dict["feats"] = [10, 13] arg_dict["depth"] = [3, 4] arg_dict["lr"] = [1.0, 0.1] arg_dict["base_scale"] = [0.1, 0.2] arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["is_global"] = [True, False] arg_dict["graph_mode"] = [True, False] for arg in GenArgDict(arg_dict): is_global = arg.pop("is_global", True) if is_global: device_type = arg.pop("device_type", "cuda") arg["placement"] = flow.placement.all(device_type) # arg["verbose"] = True _test_lrs(self, **arg) @flow.unittest.skip_unless_1n2d() class LRScaleParallelTest(flow.unittest.TestCase): def test_lr_scale_parallel(self): arg_dict = OrderedDict() arg_dict["batch_size"] = [2, 4] arg_dict["feats"] = [5, 10] arg_dict["depth"] = [3, 4] arg_dict["lr"] = [1.0, 0.1] arg_dict["base_scale"] = [0.1, 0.2] arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["graph_mode"] = [True, False] for arg in GenArgDict(arg_dict): device_type = arg.pop("device_type", "cuda") arg["placement"] = flow.placement.all(device_type) # arg["verbose"] = True _test_lrs(self, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_lr_scheduler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import numpy as np import glob import oneflow as flow import oneflow.unittest class MyModule(flow.nn.Module): def __init__(self): super().__init__() self.param = flow.nn.Parameter(flow.ones(3, 4)) def forward(self, input): return self.param + input class MyGraph(flow.nn.Graph): def __init__(self, module, optimizer, lr_scheduler): super().__init__() self.m = module self.add_optimizer(optimizer, lr_sch=lr_scheduler) def build(self, input): out = self.m(input) out.mean().backward() return out def _rand_input(): return flow.Tensor(np.random.rand(3, 4).astype(np.float32)) def _get_graph_lrs_from_log(log_path): lines = [] with open(log_path, "rt") as f: for line in f: lines.append(line.strip()) lines = lines[1:] lrs = [] for i, line in enumerate(lines): step, lr = line.split(",") assert int(step) == i lrs.append(float(lr)) return lrs class _DebugMode(object): def __enter__(self): os.environ["ONEFLOW_DEBUG_MODE"] = "True" def __exit__(self, type, value, traceback): del os.environ["ONEFLOW_DEBUG_MODE"] def _compare_graph_lr_scheduler_with_eager(test_case, **kwargs): lr_scheduler_class = kwargs.pop("lr_scheduler", None) base_lr = kwargs.pop("base_lr", None) iters = kwargs.pop("iters", None) rtol = kwargs.pop("rtol", 1e-05) atol = kwargs.pop("atol", 1e-08) if "warmup_method" in kwargs: warmup_method = kwargs.pop("warmup_method", "linear") warmup_iters = kwargs.pop("warmup_iters", 5) warmup_factor = kwargs.pop("warmup_factor", 0.1) warmup_prefix = kwargs.pop("warmup_prefix", False) need_warmup = True else: need_warmup = False assert base_lr is not None and iters is not None module = MyModule() optimizer = flow.optim.SGD([module.param], lr=base_lr) lr_scheduler = ( lr_scheduler_class(optimizer, **kwargs) if lr_scheduler_class else None ) if need_warmup: lr_scheduler = flow.optim.lr_scheduler.WarmupLR( lr_scheduler or optimizer, warmup_factor=warmup_factor, warmup_iters=warmup_iters, warmup_method=warmup_method, warmup_prefix=warmup_prefix, ) graph = MyGraph(module, optimizer, lr_scheduler) with _DebugMode(): for _ in range(iters + 1): ret = graph(_rand_input()) ret.numpy() # sync for graph finishing pid = os.getpid() lr_log_file = glob.glob(f"log/*/{pid}-train_step2lr.csv")[0] lrs = _get_graph_lrs_from_log(lr_log_file) lrs = lrs[:iters] optimizer.zero_grad(set_to_none=True) eager_lrs = [lr_scheduler.get_last_lr()[0]] for _ in range(iters): ret = module(_rand_input()) ret.numpy() optimizer.step() lr_scheduler.step() eager_lrs.append(lr_scheduler.get_last_lr()[0]) eager_lrs = eager_lrs[:iters] test_case.assertTrue( np.allclose(lrs, eager_lrs, rtol=rtol, atol=atol), f"\ngraph_lrs: {lrs}\nvs.\neager_lrs: {eager_lrs}", ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraphLRSchedulerWithEager(flow.unittest.TestCase): def test_constant_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=10, lr_scheduler=flow.optim.lr_scheduler.ConstantLR, factor=0.1, total_iters=10, ) def test_linear_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.LinearLR, start_factor=0.1, end_factor=1.0, total_iters=10, ) def test_linear_lr_end_factor(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.LinearLR, start_factor=0.1, end_factor=0.9, total_iters=10, ) def test_step_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=30, 
lr_scheduler=flow.optim.lr_scheduler.StepLR, step_size=10, gamma=0.1, ) def test_multi_step_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.MultiStepLR, milestones=[5, 15], gamma=0.2, ) def test_polynomial_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.PolynomialLR, decay_batch=20, end_learning_rate=1e-5, power=2.0, atol=1e-5, ) _compare_graph_lr_scheduler_with_eager( self, base_lr=0.01, iters=20, lr_scheduler=flow.optim.lr_scheduler.PolynomialLR, decay_batch=20, end_learning_rate=1e-4, power=1.0, cycle=True, ) def test_exponential_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=10, lr_scheduler=flow.optim.lr_scheduler.ExponentialLR, gamma=0.5, atol=1e-5, ) def test_cosine_decay_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.CosineDecayLR, decay_steps=10, alpha=1e-3, atol=1e-5, ) def test_cosine_annealing_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingLR, T_max=10, eta_min=1e-4, atol=1e-5, ) def test_linear_warmup_cosine_annealing_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingLR, T_max=20, eta_min=1e-5, warmup_method="linear", warmup_factor=0.1, warmup_iters=5, warmup_prefix=False, atol=1e-5, ) def test_linear_warmup_prefix_cosine_annealing_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingLR, T_max=20, eta_min=1e-5, warmup_method="linear", warmup_factor=0.1, warmup_iters=5, warmup_prefix=True, atol=1e-5, ) def test_linear_warmup_multistep_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.MultiStepLR, milestones=[10, 15], gamma=0.1, warmup_method="linear", warmup_factor=0.1, warmup_iters=5, ) def test_constant_warmup_cosine_decay_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.CosineDecayLR, decay_steps=20, alpha=1e-3, warmup_method="constant", warmup_factor=0.1, warmup_iters=5, atol=1e-5, ) def test_constant_warmup_prefix_cosine_decay_lr(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=20, lr_scheduler=flow.optim.lr_scheduler.CosineDecayLR, decay_steps=20, alpha=1e-3, warmup_method="constant", warmup_factor=0.1, warmup_iters=5, warmup_prefix=True, atol=1e-5, ) def test_only_warmup(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=10, lr_scheduler=None, warmup_method="linear", warmup_factor=0.1, warmup_iters=5, ) def test_warmup_iters_equal_to_zero(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=10, lr_scheduler=flow.optim.lr_scheduler.StepLR, step_size=3, gamma=0.5, warmup_method="linear", warmup_iters=0, ) def test_cosine_annealing_warm_restarts(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=50, lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingWarmRestarts, T_0=10, T_mult=1, eta_min=0.01, atol=1e-5, ) def test_cosine_annealing_warm_restarts_mult_2(self): _compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=70, lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingWarmRestarts, T_0=10, T_mult=2, eta_min=0.01, atol=1e-5, ) def test_cosine_annealing_warm_restarts_limit(self): 
_compare_graph_lr_scheduler_with_eager( self, base_lr=0.1, iters=50, lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingWarmRestarts, T_0=10, T_mult=2, eta_min=0.01, decay_rate=0.5, restart_limit=2, atol=1e-5, ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_lr_with_warmup.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import unittest import os import numpy as np import oneflow as flow import oneflow.unittest from oneflow.nn.parameter import Parameter def _test_linear_graph_train_with_lr_sch( test_case, iter_num, device, get_opt_and_lr_sch ): def train_with_module(iter_num=3): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, -0.68758) flow.nn.init.constant_(linear.bias, 0.23) opt, lr_sch = get_opt_and_lr_sch(linear.parameters()) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) def one_iter(): of_out = linear(x) of_out = of_out.sum() of_out.backward() opt.step() if lr_sch is not None: lr_sch.step() opt.zero_grad() return of_out.numpy(), linear.weight.numpy() check_list = [] for i in range(iter_num): check_list.append(one_iter()) return check_list def train_with_graph(iter_num=3): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, -0.68758) flow.nn.init.constant_(linear.bias, 0.23) opt, lr_sch = get_opt_and_lr_sch(linear.parameters()) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = linear if lr_sch is None: self.add_optimizer(opt) else: self.add_optimizer(opt, lr_sch=lr_sch) def build(self, x): out = self.linear(x) out = out.sum() out.backward() return out linear_t_g = LinearTrainGraph() def one_iter(): of_graph_out = linear_t_g(x) return ( of_graph_out.numpy(), linear_t_g.linear.weight.to(flow.Tensor).numpy(), ) check_list = [] for i in range(iter_num): check_list.append(one_iter()) return check_list module_check_list = train_with_module(iter_num) graph_check_list = train_with_graph(iter_num) for i in range(iter_num): # check equal on loss test_case.assertTrue( np.allclose( module_check_list[i][0], graph_check_list[i][0], rtol=0.00001, atol=0.00001, ) ) 
# check equal on weight test_case.assertTrue( np.allclose( module_check_list[i][1], graph_check_list[i][1], rtol=0.00001, atol=0.00001, ) ) def _sgd_cosine_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) alpha = 0.5 decay_steps = 10 cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR( of_sgd, decay_steps=decay_steps, alpha=alpha ) return of_sgd, cosine_decay_lr def _sgd_cosine_constant_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) alpha = 0.5 decay_steps = 10 cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR( of_sgd, decay_steps=decay_steps, alpha=alpha ) constant_warmup_cosine_lr = flow.optim.lr_scheduler.WarmUpLR( cosine_decay_lr, warmup_factor=0.5, warmup_iters=5, warmup_method="constant" ) return of_sgd, constant_warmup_cosine_lr def _sgd_constant_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) alpha = 0.5 steps = 10 constant_warmup_lr = flow.optim.lr_scheduler.WarmUpLR( of_sgd, warmup_factor=0.5, warmup_iters=5, warmup_method="constant" ) return of_sgd, constant_warmup_lr def _sgd_cosine_linear_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) alpha = 0.5 decay_steps = 10 cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR( of_sgd, decay_steps=decay_steps, alpha=alpha ) linear_warmup_cosine_lr = flow.optim.lr_scheduler.WarmUpLR( cosine_decay_lr, warmup_factor=0.5, warmup_iters=5, warmup_method="linear" ) return of_sgd, linear_warmup_cosine_lr def _sgd_linear_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) alpha = 0.5 steps = 10 linear_warmup_lr = flow.optim.lr_scheduler.WarmUpLR( of_sgd, warmup_factor=0.5, warmup_iters=5, warmup_method="linear" ) return of_sgd, linear_warmup_lr @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLinearGraphTrainWithCosineLrScheduler(flow.unittest.TestCase): def test_graph_cosine(test_case): _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cuda"), _sgd_cosine_fn ) _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cpu"), _sgd_cosine_fn ) def test_graph_cosine_constant(test_case): _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cuda"), _sgd_cosine_constant_fn ) _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cpu"), _sgd_cosine_constant_fn ) def test_graph_constant(test_case): _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cuda"), _sgd_constant_fn ) _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cpu"), _sgd_constant_fn ) def test_graph_cosine_linear(test_case): _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cuda"), _sgd_cosine_linear_fn ) _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cpu"), _sgd_cosine_linear_fn ) def test_graph_linear(test_case): _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cuda"), _sgd_linear_fn ) _test_linear_graph_train_with_lr_sch( test_case, 21, flow.device("cpu"), _sgd_linear_fn ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_lrs.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import unittest import os import numpy as np import oneflow as flow import oneflow.unittest from oneflow.nn.parameter import Parameter def _test_linear_graph_train_with_lr_sch( test_case, iter_num, device, get_opt_and_lr_sch ): def train_with_module(iter_num=3): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, -0.68758) flow.nn.init.constant_(linear.bias, 0.23) opt, lr_sch = get_opt_and_lr_sch(linear.parameters()) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) def one_iter(): of_out = linear(x) of_out = of_out.sum() of_out.backward() opt.step() if lr_sch is not None: lr_sch.step() opt.zero_grad() return of_out.numpy(), linear.weight.numpy() check_list = [] for i in range(iter_num): check_list.append(one_iter()) return check_list def train_with_graph(iter_num=3): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, -0.68758) flow.nn.init.constant_(linear.bias, 0.23) opt, lr_sch = get_opt_and_lr_sch(linear.parameters()) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = linear if lr_sch is None: self.add_optimizer(opt) else: self.add_optimizer(opt, lr_sch=lr_sch) def build(self, x): out = self.linear(x) out = out.sum() out.backward() return out linear_t_g = LinearTrainGraph() def one_iter(): of_graph_out = linear_t_g(x) return ( of_graph_out.numpy(), linear_t_g.linear.weight.to(flow.Tensor).numpy(), ) check_list = [] for i in range(iter_num): check_list.append(one_iter()) return check_list module_check_list = train_with_module(iter_num) graph_check_list = train_with_graph(iter_num) for i in range(iter_num): # check equal on loss test_case.assertTrue( np.allclose( module_check_list[i][0], graph_check_list[i][0], rtol=0.00001, atol=0.00001, ) ) # check equal on weight test_case.assertTrue( np.allclose( module_check_list[i][1], graph_check_list[i][1], rtol=0.00001, atol=0.00001, ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraphLRs(flow.unittest.TestCase): def test_step_lr(test_case): def _lr_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) step_lr = flow.optim.lr_scheduler.StepLR(of_sgd, step_size=7, gamma=0.1) return of_sgd, step_lr _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn) 
_test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn) def test_multistep_lr(test_case): def _lr_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) multistep_lr = flow.optim.lr_scheduler.MultiStepLR( of_sgd, milestones=[10, 15], gamma=0.1 ) return of_sgd, multistep_lr _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn) _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn) @unittest.skip("skip for now, becase it failed 6 times in past week") def test_cosine_annealing_lr(test_case): def _lr_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) lr = flow.optim.lr_scheduler.CosineAnnealingLR( of_sgd, T_max=5, eta_min=0.0001 ) return of_sgd, lr _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn) _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn) def test_polynomial_lr(test_case): def _lr_fn(parameters): of_sgd = flow.optim.SGD(parameters, lr=0.001) lr = flow.optim.lr_scheduler.PolynomialLR( of_sgd, decay_batch=10, end_learning_rate=0.00001, power=2, cycle=True ) return of_sgd, lr _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn) _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_masked_fill.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import random import oneflow as flow from oneflow import nn import oneflow.unittest from test_util import generate_graph @flow.unittest.skip_unless_1n1d() class TestMaskedFillGraph(flow.unittest.TestCase): def test_masked_fill_graph(test_case): k = random.randint(1, 10) model = nn.Sequential(nn.Linear(k, k)) optimizer = flow.optim.SGD(model.parameters(), lr=1e-3) loss_fn = nn.MSELoss() class MaskedFillGraph(flow.nn.Graph): def __init__(self,): super().__init__() self.model = model self.loss_fn = loss_fn self.add_optimizer(optimizer) def build(self, input, mask): output = self.model(input) output = flow.masked_fill(output, mask > 0.5, 0.5) loss = self.loss_fn(output, input) loss.backward() return loss input = flow.randn(k, k).requires_grad_() mask = flow.randn(k, k) model = MaskedFillGraph() return model(input, mask) def test_masked_fill_by_generate_graph(test_case): k = random.randint(1, 10) input = flow.randn(k, k) mask = flow.randn(k, k) masked_fill_fn = lambda: flow.masked_fill(input, mask > 0.5, 0.5) y_eager = masked_fill_fn() masked_fill_graph = generate_graph(masked_fill_fn) y_lazy = masked_fill_graph() test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_nccl_logical_fusion.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow from oneflow import nn import os import numpy as np import oneflow.unittest @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGraphNcclLogicalFusion(flow.unittest.TestCase): def test_graph_nccl_fusion_1d(test_case): x_list = [] local_np = np.arange(4 * 8, dtype=float).reshape(4, 8) P1d = flow.placement("cuda", ranks=[0, 1, 2, 3]) B = flow.sbp.broadcast() S0 = flow.sbp.split(0) S1 = flow.sbp.split(1) P = flow.sbp.partial_sum() in_0 = ( flow.tensor(local_np / 4.0) .to(flow.device("cuda")) .to_global(sbp=P, placement=P1d) ) flow.boxing.nccl.enable_use_compute_stream(True) class TestNcclFusion1DGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): # fuse group 0: x0 = x * 0.5 y0 = x0.to_global(sbp=B, placement=P1d) # P->B x1 = x * 1.0 y1 = x1.to_global(sbp=S0, placement=P1d) # P->S0 x2 = x * 2.0 y2 = x2.to_global(sbp=S1, placement=P1d) # P->S1 x3 = x * 3.0 y3 = x3.to_global(sbp=S1, placement=P1d) # P->S1 x4 = x * 4.0 y4 = x4.to_global(sbp=S0, placement=P1d) # P->S0 # fuse group 1: x5 = y1 * 5.0 y5 = x5.to_global(sbp=B, placement=P1d) # S0->B x6 = y2 * (6.0 / 2.0) y6 = x6.to_global(sbp=B, placement=P1d) # S1->B x7 = y3 * (9.0 / 3.0) y7 = x7.to_global(sbp=S0, placement=P1d) # S1->S0 x8 = y4 * (8.0 / 4.0) y8 = x8.to_global(sbp=S1, placement=P1d) # S0->S1 y = y0 + y1 + y2 + y3 + y4 + y5 + y6 + y7 + y8 return y, y0, y1, y2, y3, y4, y5, y6, y7, y8 graph = TestNcclFusion1DGraph() out, out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8 = graph(in_0) test_case.assertTrue(np.array_equal(out_0.numpy(), local_np * 0.5)) test_case.assertTrue(np.array_equal(out_1.numpy(), local_np * 1.0)) test_case.assertTrue(np.array_equal(out_2.numpy(), local_np * 2.0)) test_case.assertTrue(np.array_equal(out_3.numpy(), local_np * 3.0)) test_case.assertTrue(np.array_equal(out_4.numpy(), local_np * 4.0)) test_case.assertTrue(np.array_equal(out_5.numpy(), local_np * 5.0)) test_case.assertTrue(np.array_equal(out_6.numpy(), local_np * 6.0)) test_case.assertTrue(np.array_equal(out_7.numpy(), local_np * 9.0)) test_case.assertTrue(np.array_equal(out_8.numpy(), local_np * 8.0)) flow.boxing.nccl.enable_use_compute_stream(False) def test_graph_nccl_fusion_2d(test_case): x_list = [] local_np = np.arange(4 * 8, dtype=float).reshape(4, 8) P2d = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) B = flow.sbp.broadcast() S0 = flow.sbp.split(0) S1 = flow.sbp.split(1) P = flow.sbp.partial_sum() in_BP = ( flow.tensor(local_np / 2.0) .to(flow.device("cuda")) .to_global(sbp=(B, P), placement=P2d) ) in_PB = ( flow.tensor(local_np / 2.0) .to(flow.device("cuda")) .to_global(sbp=(P, B), placement=P2d) ) in_S0P = in_BP.to_global(sbp=(S0, P), placement=P2d) in_PS0 = in_PB.to_global(sbp=(P, S0), placement=P2d) flow.boxing.nccl.enable_use_compute_stream(True) class TestNcclFusion2DGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x, xsd1): # fuse group 0: x0 = x * 0.5 y0 = x0.to_global(sbp=(S0, B), placement=P2d) # same dim0 P->B x1 = x * 1.0 y1 = x1.to_global(sbp=(S0, B), placement=P2d) # same dim0 P->B xss0 = x.to_global(sbp=(S0, S0), placement=P2d) xss1 = x.to_global(sbp=(S0, S1), placement=P2d) x2 = xss0 * 2.0 y2 = x2.to_global(sbp=(S0, B), placement=P2d) # same dim0 S0->B x3 = xss1 * 3.0 y3 = x3.to_global(sbp=(S0, B), placement=P2d) # same dim0 S1->B x4 = xss0 * 4.0 y4 = x4.to_global(sbp=(S0, S1), placement=P2d) # same dim0 S0->S1 x5 = xss1 * 5.0 y5 = x5.to_global(sbp=(S0, 
S0), placement=P2d) # same dim0 S1->S0 x6 = xsd1 * 6.0 y6 = x6.to_global(sbp=(B, S0), placement=P2d) # same dim1 P-> B x7 = xsd1 * 7.0 y7 = x7.to_global(sbp=(B, S0), placement=P2d) # same dim1 P-> B y = y0 + y1 + y2 + y3 + y4 + y5 + y6 + y7 return y, y0, y1, y2, y3, y4, y5, y6, y7 graph = TestNcclFusion2DGraph() out, out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7 = graph( in_S0P, in_PS0 ) test_case.assertTrue(np.array_equal(out_0.numpy(), local_np * 0.5)) test_case.assertTrue(np.array_equal(out_1.numpy(), local_np * 1.0)) test_case.assertTrue(np.array_equal(out_2.numpy(), local_np * 2.0)) test_case.assertTrue(np.array_equal(out_3.numpy(), local_np * 3.0)) test_case.assertTrue(np.array_equal(out_4.numpy(), local_np * 4.0)) test_case.assertTrue(np.array_equal(out_5.numpy(), local_np * 5.0)) test_case.assertTrue(np.array_equal(out_6.numpy(), local_np * 6.0)) test_case.assertTrue(np.array_equal(out_7.numpy(), local_np * 7.0)) flow.boxing.nccl.enable_use_compute_stream(False) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_non_contiguous_tensors.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import oneflow.unittest import oneflow as flow import numpy as np class ModuleTest(flow.nn.Module): def __init__(self, contiguous: bool, device): super().__init__() if contiguous: self.weight = flow.nn.Parameter(flow.ones(4, 3, device=device)) else: self.weight = flow.nn.Parameter( flow.ones(3, 4, device=device).transpose(0, 1) ) def forward(self, input): res = flow.matmul(input, self.weight) return res def _test_graph_non_contiguous_tensors(test_case, device): bias = flow.tensor( [[1, 2, 3], [3, 4, 5], [7, 7, 7],], dtype=flow.float32, device=device ) free_eager_bias_contiguous = bias free_eager_bias_non_contiguous = bias.transpose(0, 1).contiguous().transpose(0, 1) test_case.assertTrue(free_eager_bias_contiguous.is_contiguous()) test_case.assertFalse(free_eager_bias_non_contiguous.is_contiguous()) class GraphTestContiguousTensors(flow.nn.Graph): def __init__(self): super().__init__() self.model = ModuleTest(True, device) def build(self, input): res = self.model(input) + free_eager_bias_contiguous return res class GraphTestNonContiguousTensors(flow.nn.Graph): def __init__(self): super().__init__() self.model = ModuleTest(False, device) def build(self, input): res = self.model(input) + free_eager_bias_non_contiguous return res graph_contiguous_tensors = GraphTestContiguousTensors() graph_non_contiguous_tensors = GraphTestNonContiguousTensors() test_case.assertTrue( graph_contiguous_tensors.model.weight.to(flow.Tensor).is_contiguous() ) test_case.assertFalse( graph_non_contiguous_tensors.model.weight.to(flow.Tensor).is_contiguous() ) inp = flow.tensor( [[1, 2, 3], [4, 5, 6], [3, 3, 3], [7, 8, 8]], dtype=flow.float32, device=device ) non_contiguous_input = inp.transpose(0, 1) test_case.assertFalse(non_contiguous_input.is_contiguous()) contiguous_input = non_contiguous_input.contiguous() test_case.assertTrue(contiguous_input.is_contiguous()) contiguous_graph_output = graph_contiguous_tensors(contiguous_input) non_contiguous_graph_output = graph_non_contiguous_tensors(non_contiguous_input) test_case.assertTrue( np.array_equal( contiguous_graph_output.numpy(), non_contiguous_graph_output.numpy() ) ) @flow.unittest.skip_unless_1n1d() class TestGraphNonContiguousTensor(oneflow.unittest.TestCase): def test_graph_non_contiguous_tensors_cpu(test_case): _test_graph_non_contiguous_tensors(test_case, flow.device("cpu")) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_graph_non_contiguous_tensors_gpu(test_case): _test_graph_non_contiguous_tensors(test_case, flow.device("cuda")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_normal_inplace.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import numpy as np import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict _fn_param_local = { "normal": lambda data: flow.normal( size=data.shape, mean=0.0, std=1.0, out=data ), # NOTE(lixiang): source op that can be inplaced. } _fn_param_global = { "normal": lambda data, placement, sbp: flow.normal( size=data.shape, mean=0.0, std=1.0, out=data, placement=placement, sbp=sbp, ), } def _test_data_local(test_case, device, fn): data_1 = flow.zeros([16, 64, 128, 128]).to(device) data_2 = flow.zeros([16, 64, 128, 128]).to(device) class NormalGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self): fn(data_1).to(device) return data_1 model = NormalGraph() lazy_x = model() fn(data_2) test_case.assertTrue(lazy_x.numpy().sum() != 0) test_case.assertTrue(data_2.numpy().sum() != 0) def _test_data_global(test_case, data_1, data_2, placement, sbp, fn): class GlobalNormalGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self): flow.manual_seed(233) fn(data_1, placement, sbp) return data_1 model = GlobalNormalGraph() lazy_x = model() flow.manual_seed(233) fn(data_2, placement, sbp) test_case.assertTrue( np.array_equal(lazy_x.to_local().numpy(), data_2.to_local().numpy()) ) class TestNormalOpInplaceData(flow.unittest.TestCase): @oneflow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_normal_op_data_local_with_eager_and_lazy(test_case): for device in ["cuda", "cpu"]: for _, fn in _fn_param_local.items(): _test_data_local(test_case, device, fn=fn) @unittest.skipIf(True, "refactor eager random to align pytorch") @globaltest def test_normal_op_data_consistent_with_eager_and_lazy(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True): data_1 = flow.empty([8, 64, 128, 128]).to_global( placement=placement, sbp=sbp ) data_2 = flow.empty([8, 64, 128, 128]).to_global( placement=placement, sbp=sbp ) for _, fn in _fn_param_global.items(): _test_data_global(test_case, data_1, data_2, placement, sbp, fn=fn) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_ofrecord_reader.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest class OFRecordDataLoader(flow.nn.Module): def __init__(self): super().__init__() batch_size = 4 self.train_record_reader = flow.nn.OFRecordReader( flow.unittest.dataset_dir("imagenette/ofrecord"), batch_size=batch_size, data_part_num=1, part_name_suffix_length=5, random_shuffle=True, shuffle_after_epoch=True, # placement=flow.placement("cpu", ranks=[0]), # sbp=[flow.sbp.broadcast] ) self.record_label_decoder = flow.nn.OFRecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) color_space = "RGB" output_layout = "NHWC" self.record_image_decoder = flow.nn.OFRecordImageDecoderRandomCrop( "encoded", color_space=color_space ) self.resize = flow.nn.image.Resize(target_size=[224, 224]) self.flip = flow.nn.CoinFlip( batch_size=batch_size, # placement=flow.placement("cpu", ranks=[0]), # sbp=[flow.sbp.broadcast] ) rgb_mean = [123.68, 116.779, 103.939] rgb_std = [58.393, 57.12, 57.375] self.crop_mirror_norm = flow.nn.CropMirrorNormalize( color_space=color_space, output_layout=output_layout, mean=rgb_mean, std=rgb_std, output_dtype=flow.float, ) def forward(self) -> (flow.Tensor, flow.Tensor): train_record = self.train_record_reader() label = self.record_label_decoder(train_record) image_raw_buffer = self.record_image_decoder(train_record) image = self.resize(image_raw_buffer)[0] rng = self.flip() image = self.crop_mirror_norm(image, rng) return image, label @flow.unittest.skip_unless_1n1d() class TestOFRecordReaderGraph(oneflow.unittest.TestCase): def test_ofrecord_reader_graph(test_case): cc_reader = OFRecordDataLoader() class GraphReader(flow.nn.Graph): def __init__(self): super().__init__() self.my_reader = cc_reader def build(self): return self.my_reader() reader_g = GraphReader() image, label = reader_g() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optim_adadelta.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import copy from test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow def compare_with_numpy_adadelta( test_case, device, x_shape, learning_rate, train_iters, rho, eps, maximize, weight_decay, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.Tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() adadelta0 = flow.optim.Adadelta( [ { "params": simp_module.parameters(), "lr": learning_rate, "weight_decay": weight_decay, } ], rho=rho, eps=eps, maximize=maximize, ) class CustomAdadeltaGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(adadelta0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] adadelta_graph = CustomAdadeltaGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) adadelta_x = adadelta_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value square_avgs = np.zeros_like(x) acc_deltas = np.zeros_like(x) def np_train_one_iter(grad): grad = grad if not maximize else -grad grad = grad + weight_decay * x new_square_avgs = square_avgs * rho + (1.0 - rho) * grad * grad std = np.sqrt(new_square_avgs + eps) delta = np.sqrt(acc_deltas + eps) / std * grad new_acc_deltas = acc_deltas * rho + delta * delta * (1 - rho) param = x - learning_rate * delta return (param, new_square_avgs, new_acc_deltas) for i in range(1, train_iters + 1): (x, square_avgs, acc_deltas) = np_train_one_iter(random_grad_seq[i - 1]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-4, atol=1e-4)) def compare_with_numpy_adadelta_clip_grad( test_case, device, x_shape, learning_rate, train_iters, rho, eps, maximize, weight_decay, clip_grad_max_norm, clip_grad_norm_type, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() adadelta0 = flow.optim.Adadelta( [ { "params": simp_module.parameters(), "lr": learning_rate, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], rho=rho, eps=eps, maximize=maximize, ) class CustomAdadeltaGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(adadelta0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] adadelta_graph = CustomAdadeltaGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) adadelta_x 
= adadelta_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value square_avgs = np.zeros_like(x) acc_deltas = np.zeros_like(x) def np_train_one_iter(grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad if not maximize else -grad grad = grad + weight_decay * x new_square_avgs = square_avgs * rho + (1.0 - rho) * grad * grad std = np.sqrt(new_square_avgs + eps) delta = np.sqrt(acc_deltas + eps) / std * grad new_acc_deltas = acc_deltas * rho + delta * delta * (1 - rho) param = x - learning_rate * delta return (param, new_square_avgs, new_acc_deltas) for i in range(1, train_iters + 1): (x, square_avgs, acc_deltas) = np_train_one_iter(random_grad_seq[i - 1]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-4, atol=1e-4)) @flow.unittest.skip_unless_1n1d() class TestAdadelta(flow.unittest.TestCase): @unittest.skip("skip for now, because it failed 8 times in the past week") def test_adadelta(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["rho"] = [0.9] arg_dict["eps"] = [1e-6] arg_dict["maximize"] = [False] arg_dict["weight_decay"] = [0.1] for arg in GenArgList(arg_dict): compare_with_numpy_adadelta(test_case, *arg) def test_adadelta_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["rho"] = [0.9] arg_dict["eps"] = [1e-6] arg_dict["maximize"] = [False] arg_dict["weight_decay"] = [0.1] arg_dict["clip_grad_max_norm"] = [1.0] arg_dict["clip_grad_norm_type"] = [2.0] for arg in GenArgList(arg_dict): compare_with_numpy_adadelta_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optim_adagrad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import os import unittest from collections import OrderedDict import numpy as np import copy from test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow def compare_with_numpy_adagrad( test_case, device, x_shape, learning_rate, train_iters, lr_decay, weight_decay, initial_accumulator_value, eps, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.Tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() adam0 = flow.optim.Adagrad( [ { "params": simp_module.parameters(), "lr": learning_rate, "eps": eps, "weight_decay": weight_decay, } ], lr_decay=lr_decay, initial_accumulator_value=initial_accumulator_value, ) class CustomAdagradGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(adam0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] adagrad_graph = CustomAdagradGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) adagrad_x = adagrad_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value st = np.ones_like(x) * initial_accumulator_value def train_one_iter(iter, grad): grad = grad + weight_decay * x lr = learning_rate / (1 + (iter - 1) * lr_decay) s = st + grad * grad param = x - lr / (np.sqrt(s) + eps) * grad return (param, s) for i in range(1, train_iters + 1): (x, st) = train_one_iter(i, random_grad_seq[i - 1]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=0.001, atol=0.001)) def compare_with_numpy_adagrad_clip_grad( test_case, device, x_shape, learning_rate, train_iters, lr_decay, weight_decay, initial_accumulator_value, eps, clip_grad_max_norm, clip_grad_norm_type, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.Tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() adam0 = flow.optim.Adagrad( [ { "params": simp_module.parameters(), "lr": learning_rate, "eps": eps, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], lr_decay=lr_decay, initial_accumulator_value=initial_accumulator_value, ) class CustomAdagradGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(adam0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] adagrad_graph = CustomAdagradGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) adagrad_x = adagrad_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = 
init_value st = np.ones_like(x) * initial_accumulator_value def np_train_one_iter(iter, grad): norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + weight_decay * x lr = learning_rate / (1 + (iter - 1) * lr_decay) s = st + grad * grad param = x - lr / (np.sqrt(s) + eps) * grad return (param, s) for i in range(1, train_iters + 1): (x, st) = np_train_one_iter(i, random_grad_seq[i - 1]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=0.001, atol=0.001)) @flow.unittest.skip_unless_1n1d() class TestAdagrad(flow.unittest.TestCase): def test_adagrad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["lr_decay"] = [0.9, 0.75] arg_dict["weight_decay"] = [0.0, 0.1] arg_dict["initial_accumulator_value"] = [1.0, 2.1] arg_dict["eps"] = [1e-08, 1e-07] for arg in GenArgList(arg_dict): compare_with_numpy_adagrad(test_case, *arg) def test_adagrad_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["lr_decay"] = [0.9, 0.75] arg_dict["weight_decay"] = [0.0, 0.9] arg_dict["initial_accumulator_value"] = [1.0, 2.1] arg_dict["eps"] = [1e-8] arg_dict["clip_grad_max_norm"] = [1.0] arg_dict["clip_grad_norm_type"] = [2.0] for arg in GenArgList(arg_dict): compare_with_numpy_adagrad_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optim_adam.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import copy from test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow def compare_with_numpy_adam( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.Tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() adam0 = flow.optim.Adam( [ { "params": simp_module.parameters(), "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, "do_bias_correction": do_bias_correction, "amsgrad": amsgrad, } ] ) class CustomAdamGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(adam0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] adam_graph = CustomAdamGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) adam_x = adam_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value vt = np.zeros_like(x) st = np.zeros_like(x) max_st = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] def np_train_one_iter(step, grad): grad = grad + weight_decay * x bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) v = beta1 * vt + (1 - beta1) * grad s = beta2 * st + (1 - beta2) * grad * grad max_s = np.zeros_like(x) if amsgrad: max_s = np.maximum(s, max_st) denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps else: denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps param = x - ((learning_rate / bias_correction1) * v / denom) return (param, v, s, max_s) for i in range(1, train_iters + 1): (x, vt, st, max_st) = np_train_one_iter(i, random_grad_seq[i - 1]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=0.001, atol=0.001)) def compare_with_numpy_adam_clip_grad( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, clip_grad_max_norm, clip_grad_norm_type, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() adam0 = flow.optim.Adam( [ { "params": simp_module.parameters(), "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, "do_bias_correction": do_bias_correction, "amsgrad": amsgrad, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ] ) class CustomAdamGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = 
simp_module self.add_optimizer(adam0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] adam_graph = CustomAdamGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) adam_x = adam_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value vt = np.zeros_like(x) st = np.zeros_like(x) max_st = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] def np_train_one_iter(step, grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + weight_decay * x bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) v = beta1 * vt + (1 - beta1) * grad s = beta2 * st + (1 - beta2) * grad * grad max_s = np.zeros_like(x) if amsgrad: max_s = np.maximum(s, max_st) denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps else: denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps param = x - ((learning_rate / bias_correction1) * v / denom) return (param, v, s, max_s) for i in range(1, train_iters + 1): (x, vt, st, max_st) = np_train_one_iter(i, random_grad_seq[i - 1]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-3, atol=1e-3)) @flow.unittest.skip_unless_1n1d() class TestAdam(flow.unittest.TestCase): def test_adam(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [0.001, 0.0] arg_dict["eps"] = [1e-8] arg_dict["do_bias_correction"] = [True, False] arg_dict["amsgrad"] = [True, False] for arg in GenArgList(arg_dict): compare_with_numpy_adam(test_case, *arg) def test_adam_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [0.0, 0.9] arg_dict["eps"] = [1e-8] arg_dict["do_bias_correction"] = [True, False] arg_dict["amsgrad"] = [True, False] arg_dict["clip_grad_max_norm"] = [1.0] arg_dict["clip_grad_norm_type"] = [2.0] for arg in GenArgList(arg_dict): compare_with_numpy_adam_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optim_adamw.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import copy from test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow def compare_with_numpy_adamw( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() adamw0 = flow.optim.AdamW( [ { "params": simp_module.parameters(), "lr": learning_rate, "betas": betas, "weight_decay": weight_decay, "do_bias_correction": do_bias_correction, "amsgrad": amsgrad, } ], do_bias_correction=do_bias_correction, amsgrad=amsgrad, ) class CustomAdamWGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(adamw0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] adamw_graph = CustomAdamWGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) adamw_x = adamw_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value vt = np.zeros_like(x) st = np.zeros_like(x) max_st = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] def np_train_one_iter(step, grad): v = beta1 * vt + (1 - beta1) * grad s = beta2 * st + (1 - beta2) * grad * grad bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) max_s = np.zeros_like(x) if amsgrad: max_s = np.maximum(s, max_st) denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps else: denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps lr = learning_rate / bias_correction1 / denom g = lr * v + learning_rate * weight_decay * x param = x - g return (param, v, s, max_s) for i in range(1, train_iters + 1): (x, vt, st, max_st) = np_train_one_iter(i, random_grad_seq[i - 1]) np_res_list.append(x) train_by_numpy() test_case.assertTrue(np.allclose(np_res_list, of_res_list, rtol=1e-4, atol=1e-4)) def compare_with_numpy_adamw_clip_grad( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, clip_grad_max_norm, clip_grad_norm_type, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() adamw0 = flow.optim.AdamW( [ { "params": simp_module.parameters(), "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], do_bias_correction=do_bias_correction, amsgrad=amsgrad, ) class CustomAdamWGraph(flow.nn.Graph): def __init__(self): 
super().__init__() self.m = simp_module self.add_optimizer(adamw0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] adamw_graph = CustomAdamWGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) adamw_x = adamw_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value vt = np.zeros_like(x) st = np.zeros_like(x) max_st = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] def np_train_one_iter(step, grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) v = beta1 * vt + (1 - beta1) * grad s = beta2 * st + (1 - beta2) * grad * grad bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) max_s = np.zeros_like(x) if amsgrad: max_s = np.maximum(s, max_st) denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps else: denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps lr = learning_rate / bias_correction1 / denom g = lr * v + learning_rate * weight_decay * x param = x - g return (param, v, s, max_s) for i in range(1, train_iters + 1): (x, vt, st, max_st) = np_train_one_iter(i, random_grad_seq[i - 1]) np_res_list.append(x) train_by_numpy() test_case.assertTrue(np.allclose(np_res_list, of_res_list, rtol=1e-4, atol=1e-4)) @flow.unittest.skip_unless_1n1d() class TestAdamW(flow.unittest.TestCase): def test_adamw(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [1e-3, 0.0] arg_dict["eps"] = [1e-8] arg_dict["do_bias_correction"] = [True, False] arg_dict["amsgrad"] = [True, False] for arg in GenArgList(arg_dict): compare_with_numpy_adamw(test_case, *arg) def test_adamw_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [0.0, 0.9] arg_dict["eps"] = [1e-8] arg_dict["do_bias_correction"] = [True, False] arg_dict["amsgrad"] = [True, False] arg_dict["clip_grad_max_norm"] = [1.0] arg_dict["clip_grad_norm_type"] = [2.0] for arg in GenArgList(arg_dict): compare_with_numpy_adamw_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optim_ftrl.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict import numpy as np import copy from test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.one_embedding import Ftrl def compare_with_numpy_ftrl( test_case, device, x_shape, learning_rate, train_iters, weight_decay, lr_power, initial_accumulator_value, lambda1, lambda2, beta, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.Tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() ftrl = Ftrl( [ { "params": simp_module.parameters(), "lr": learning_rate, "weight_decay": weight_decay, "lr_power": lr_power, "initial_accumulator_value": initial_accumulator_value, "lambda1": lambda1, "lambda2": lambda2, "beta": beta, } ] ) class CustomftrlGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(ftrl) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] ftrl_graph = CustomftrlGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) ftrl_x = ftrl_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value accum = np.zeros_like(x) accum.fill(initial_accumulator_value) z_arr = np.zeros_like(x) def np_train_one_iter(grad): grad = grad + weight_decay * x new_accum = accum + grad * grad sigma = ( np.power(new_accum, lr_power) - np.power(accum, lr_power) ) / learning_rate new_z_val = z_arr + grad - sigma * x update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / ( (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2 ) param = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val) return (param, new_accum, new_z_val) for i in range(1, train_iters + 1): (x, accum, z_arr) = np_train_one_iter(random_grad_seq[i - 1]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-4, atol=1e-4)) def compare_with_numpy_ftrl_clip_grad( test_case, device, x_shape, learning_rate, train_iters, weight_decay, lr_power, initial_accumulator_value, lambda1, lambda2, beta, clip_grad_max_norm, clip_grad_norm_type, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.Tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() ftrl = Ftrl( [ { "params": simp_module.parameters(), "lr": learning_rate, "weight_decay": weight_decay, "lr_power": lr_power, "initial_accumulator_value": initial_accumulator_value, "lambda1": lambda1, "lambda2": lambda2, "beta": beta, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ] ) class CustomftrlGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(ftrl) def 
build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] ftrl_graph = CustomftrlGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) ftrl_x = ftrl_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value accum = np.zeros_like(x) accum.fill(initial_accumulator_value) z_arr = np.zeros_like(x) def np_train_one_iter(grad): norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + weight_decay * x new_accum = accum + grad * grad sigma = ( np.power(new_accum, lr_power) - np.power(accum, lr_power) ) / learning_rate new_z_val = z_arr + grad - sigma * x update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / ( (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2 ) param = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val) return (param, new_accum, new_z_val) for i in range(1, train_iters + 1): (x, accum, z_arr) = np_train_one_iter(random_grad_seq[i - 1]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-4, atol=1e-4)) @flow.unittest.skip_unless_1n1d() class Testftrl(flow.unittest.TestCase): def test_ftrl(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["weight_decay"] = [0.9, 0.000] arg_dict["lr_power"] = [-0.5, 0.5] arg_dict["initial_accumulator_value"] = [0.1, 0.05] arg_dict["lambda1"] = [0.01] arg_dict["lambda2"] = [0.0, 0.01] arg_dict["beta"] = [1.0] for arg in GenArgList(arg_dict): compare_with_numpy_ftrl(test_case, *arg) def test_ftrl_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["weight_decay"] = [0.9, 0.000] arg_dict["lr_power"] = [-0.5, 0.5] arg_dict["initial_accumulator_value"] = [0.1, 0.05] arg_dict["lambda1"] = [0.01] arg_dict["lambda2"] = [0.0, 0.01] arg_dict["beta"] = [1.0] arg_dict["clip_grad_max_norm"] = [1.0] arg_dict["clip_grad_norm_type"] = [2.0] for arg in GenArgList(arg_dict): compare_with_numpy_ftrl_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optim_lamb.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict import numpy as np from test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow def compare_with_numpy_lamb( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, adam_w_mode, clip_grad_max_norm, clip_grad_norm_type, ): np.random.seed(1000) random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.param = flow.nn.Parameter( flow.Tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.param * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() optim_kwargs = { "params": simp_module.parameters(), "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, "adam_w_mode": adam_w_mode, "do_bias_correction": do_bias_correction, } if clip_grad_max_norm != -1: optim_kwargs["clip_grad_max_norm"] = clip_grad_max_norm optim_kwargs["clip_grad_norm_type"] = clip_grad_norm_type lamb_optim = flow.optim.LAMB([optim_kwargs]) class CustomLambGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(lamb_optim) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss lamb_graph = CustomLambGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) lamb_graph(mask_tensor) of_res = simp_module.param.numpy() def train_by_numpy(): x = init_value mt = np.zeros_like(x) vt = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] if adam_w_mode: l2 = 0 wd = weight_decay else: l2 = weight_decay wd = 0 def np_train_one_iter(step, grad): if clip_grad_max_norm != -1: _, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + l2 * x bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step + 1) bias_correction2 = 1.0 - np.power(beta2, step + 1) m = beta1 * mt + (1 - beta1) * grad v = beta2 * vt + (1 - beta2) * grad * grad denom = np.sqrt(v) / np.sqrt(bias_correction2) + eps adam_diff = m / bias_correction1 / denom w_norm = np.linalg.norm(x, ord=2) g_norm = np.linalg.norm(adam_diff, ord=2) if w_norm > 0 and g_norm > 0: trust_ratio = w_norm / g_norm else: trust_ratio = 1.0 param = x - learning_rate * trust_ratio * (adam_diff + wd * x) return (param, m, v) for i in range(train_iters): (x, mt, vt) = np_train_one_iter(i, random_grad_seq[i]) return x np_res = train_by_numpy() test_case.assertTrue( np.allclose(of_res.flatten(), np_res.flatten(), rtol=1e-3, atol=1e-3) ) @flow.unittest.skip_unless_1n1d() class TestLamb(flow.unittest.TestCase): def test_lamb(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [0.1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [0.001, 0.1] arg_dict["eps"] = [1e-8, 1e-6] arg_dict["do_bias_correction"] = [True, False] arg_dict["adam_w_mode"] = [True, False] # NOTE(l1aoxingyu): max_norm = -1 means no clip grad # nn.Graph only support `clip_grad_max_norm == 1.0` and `clip_grad_norm_type == 2.0` 
arg_dict["clip_grad_max_norm"] = [-1, 1.0] arg_dict["clip_grad_norm_type"] = [2.0] for arg in GenArgList(arg_dict): compare_with_numpy_lamb(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optim_rmsprop.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from collections import OrderedDict import numpy as np import copy from test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow @flow.unittest.skip_unless_1n1d() def compare_with_numpy_rmsprop( test_case, device, x_shape, learning_rate, momentum, train_iters, alpha, eps, weight_decay, centered, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModel(flow.nn.Module): def __init__(self): super().__init__() self.param0 = flow.nn.Parameter( flow.tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.param0 * mask simp_module = CustomModel() simp_module.to(flow.device(device)) simp_module.train() rmsprop0 = flow.optim.RMSprop( [ { "params": simp_module.parameters(), "lr": learning_rate, "alpha": alpha, "eps": eps, "weight_decay": weight_decay, "momentum": momentum, "centered": centered, } ] ) class CustomRMSpropGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(rmsprop0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] rmsprop_graph = CustomRMSpropGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) rmsprop_x = rmsprop_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.param0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value r = np.zeros_like(x) v = np.zeros_like(x) g = np.zeros_like(x) def np_train_one_iter(grad): # ref to: ../modules/test_optim_rmsprop.py -> train_by_numpy() # weight decay is equivalent to l2 penalty grad = grad + weight_decay * x r_ = alpha * r + (1 - alpha) * grad * grad if centered: g_ = alpha * g + (1 - alpha) * grad v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad else: g_ = g v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad param = x - v_ return (param, r_, g_, v_) for i in range(train_iters): (x, r, g, v) = np_train_one_iter(random_grad_seq[i]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-3, atol=1e-3)) @flow.unittest.skip_unless_1n1d() def compare_with_numpy_rmsprop_clip_grad( test_case, device, x_shape, learning_rate, momentum, train_iters, alpha, eps, weight_decay, centered, clip_grad_max_norm, clip_grad_norm_type, ): random_grad_seq = [] for _ in range(train_iters): 
random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModel(flow.nn.Module): def __init__(self): super().__init__() self.param0 = flow.nn.Parameter( flow.tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.param0 * mask simp_module = CustomModel() simp_module.to(flow.device(device)) simp_module.train() rmsprop0 = flow.optim.RMSprop( [ { "params": simp_module.parameters(), "lr": learning_rate, "alpha": alpha, "eps": eps, "weight_decay": weight_decay, "momentum": momentum, "centered": centered, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ] ) class CustomRMSpropGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(rmsprop0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] rmsprop_graph = CustomRMSpropGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) rmsprop_x = rmsprop_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.param0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value r = np.zeros_like(x) v = np.zeros_like(x) g = np.zeros_like(x) def np_train_one_iter(grad): norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) # weight decay is equivalent to l2 penalty grad = grad + weight_decay * x r_ = alpha * r + (1 - alpha) * grad * grad if centered: g_ = alpha * g + (1 - alpha) * grad v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad else: g_ = g v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad param = x - v_ return (param, r_, g_, v_) for i in range(train_iters): (x, r, g, v) = np_train_one_iter(random_grad_seq[i]) np_res_list.append(x) return x train_by_numpy() test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-3, atol=1e-3)) @flow.unittest.skip_unless_1n1d() class TestRMSprop(flow.unittest.TestCase): def test_rmsprop(test_case): args_dict = OrderedDict() args_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): args_dict["device"] = ["cpu"] args_dict["x_shape"] = [(1,), (10,)] args_dict["learning_rate"] = [1] args_dict["momentum"] = [0.0] # not supported momentum > 0 args_dict["train_iters"] = [10] args_dict["alpha"] = [0.9] args_dict["eps"] = [1e-8, 1e-5] args_dict["weight_decay"] = [0.1, 0.9] args_dict["centered"] = [False, True] for args in GenArgList(args_dict): compare_with_numpy_rmsprop(test_case, *args) def test_rmsprop_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1] arg_dict["momentum"] = [0.0] arg_dict["train_iters"] = [10] arg_dict["alpha"] = [0.9, 0.99] arg_dict["eps"] = [1e-08, 1e-05] arg_dict["weight_decay"] = [0.0, 0.9] arg_dict["centered"] = [False, True] arg_dict["clip_grad_max_norm"] = [1.0] arg_dict["clip_grad_norm_type"] = [2.0] for arg in GenArgList(arg_dict): compare_with_numpy_rmsprop_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optim_sgd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import copy from test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow def compare_with_numpy_sgd( test_case, device, x_shape, learning_rate, train_iters, momentum, dampening, nesterov, maximize, weight_decay, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() sgd0 = flow.optim.SGD( [ { "params": simp_module.parameters(), "lr": learning_rate, "weight_decay": weight_decay, } ], momentum=momentum, dampening=dampening, nesterov=nesterov, maximize=maximize, ) class CustomSGDGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(sgd0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] sgd_graph = CustomSGDGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) sgd_x = sgd_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value vt = np.zeros_like(x) def np_train_one_iter(grad): grad = grad + weight_decay * x if momentum > 0.0: next_momentum = momentum * vt + (1 - dampening) * grad v = next_momentum if nesterov: grad += momentum * next_momentum else: grad = next_momentum alpha = -learning_rate if maximize: alpha = learning_rate next_model = x + alpha * grad param = next_model else: v = learning_rate * grad param = x - v return (param, v) for i in range(train_iters): (x, vt) = np_train_one_iter(random_grad_seq[i]) np_res_list.append(x) train_by_numpy() test_case.assertTrue(np.allclose(np_res_list, of_res_list, rtol=1e-3, atol=1e-3)) def compare_with_numpy_sgd_clip_grad( test_case, device, x_shape, learning_rate, momentum, dampening, nesterov, maximize, weight_decay, clip_grad_max_norm, clip_grad_norm_type, train_iters, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter( flow.tensor(init_value, device=flow.device(device)) ) def forward(self, mask): return self.para0 * mask simp_module = CustomModule() simp_module.to(device) simp_module.train() sgd0 = flow.optim.SGD( [ { "params": simp_module.parameters(), "lr": learning_rate, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], momentum=momentum, dampening=dampening, nesterov=nesterov, maximize=maximize, ) class CustomSGDGraph(flow.nn.Graph): def __init__(self): 
super().__init__() self.m = simp_module self.add_optimizer(sgd0) def build(self, mask_tensor): loss = flow.sum(self.m(mask_tensor)) loss.backward() return loss of_res_list = [] sgd_graph = CustomSGDGraph() for i in range(train_iters): mask_tensor = flow.tensor( random_grad_seq[i], requires_grad=False, device=flow.device(device) ) sgd_x = sgd_graph(mask_tensor) of_res_list.append(copy.copy(simp_module.para0.numpy())) np_res_list = [] def train_by_numpy(): x = init_value vt = np.zeros_like(x) def np_train_one_iter(grad): norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + weight_decay * x if momentum > 0.0: next_momentum = momentum * vt + (1 - dampening) * grad v = next_momentum if nesterov: grad += momentum * next_momentum else: grad = next_momentum alpha = -learning_rate if maximize: alpha = learning_rate next_model = x + alpha * grad param = next_model else: v = learning_rate * grad param = x - v return (param, v) for i in range(train_iters): (x, vt) = np_train_one_iter(random_grad_seq[i]) np_res_list.append(x) train_by_numpy() for np_res, of_res in zip(np_res_list, of_res_list): test_case.assertTrue(np.allclose(np_res, of_res, rtol=0.001, atol=0.001)) @flow.unittest.skip_unless_1n1d() class TestGraphSGD(flow.unittest.TestCase): def test_sgd(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["momentum"] = [0.9, 0.8] arg_dict["dampening"] = [0.0, 0.9] arg_dict["nesterov"] = [True, False] arg_dict["maximize"] = [True, False] arg_dict["weight_decay"] = [0.001, 0.0] for arg in GenArgList(arg_dict): compare_with_numpy_sgd(test_case, *arg) def test_sgd_with_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 0.1] arg_dict["momentum"] = [0.0, 0.9] arg_dict["dampening"] = [0.0, 0.9] arg_dict["nesterov"] = [True, False] arg_dict["maximize"] = [True, False] arg_dict["weight_decay"] = [0.0, 0.9] arg_dict["clip_grad_max_norm"] = [1.0] arg_dict["clip_grad_norm_type"] = [2.0] arg_dict["train_iters"] = [10] for arg in GenArgList(arg_dict): compare_with_numpy_sgd_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_optimizer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestGraphOptimizer(flow.unittest.TestCase): def test_optimizer(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter(flow.Tensor(10, 4)) def forward(self, x): x = flow._C.matmul(x, self.para0) return x m = CustomModule() learning_rate = 0.1 momentum = 0.2 weight_decay = 0.7 sgd0 = flow.optim.SGD( [ { "params": [m.para0], "lr": learning_rate, "momentum": momentum, "weight_decay": weight_decay, } ] ) cosine_lr = flow.optim.lr_scheduler.CosineDecayLR( sgd0, decay_steps=100, alpha=0.1 ) class CustomGraph0(flow.nn.Graph): def __init__(self): super().__init__() self.m = m self.add_optimizer(sgd0) def build(self, x): out = self.m(x) out = out.mean() out.backward() return out g = CustomGraph0() x = flow.Tensor(4, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) z = g._compile(x) print("repr(g): \n", repr(g)) print("g.config.proto: \n", g.config.proto) print("graph proto: \n", g._graph_proto) def test_multi_optimizer_conf(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter(flow.Tensor(1, 4)) self.para1 = flow.nn.Parameter(flow.Tensor(1, 4)) self.para2 = flow.nn.Parameter(flow.Tensor(1, 4)) self.para2.requires_grad_(False) self.para3 = flow.nn.Parameter(flow.Tensor(1, 4)) self.para4 = flow.nn.Parameter(flow.Tensor(1, 4)) def forward(self, x): x = flow._C.matmul(self.para0, x) y = flow._C.matmul(self.para3, x) return x, y m = CustomModule() learning_rate = 0.1 momentum = 0.2 sgd0 = flow.optim.SGD( [ { "params": [m.para0, m.para1, m.para2], "lr": learning_rate, "momentum": momentum, "weight_decay": 0.3, } ] ) sgd1 = flow.optim.SGD( [ { "params": [m.para3], "lr": learning_rate, "momentum": momentum, "weight_decay": 0.4, }, { "params": [m.para4], "lr": learning_rate, "momentum": 0.9, "weight_decay": 0.5, }, ] ) cosine_lr0 = flow.optim.lr_scheduler.CosineDecayLR( sgd0, decay_steps=10, alpha=0.01 ) constant_warmup_cosine_lr0 = flow.optim.lr_scheduler.WarmUpLR( cosine_lr0, warmup_factor=0.5, warmup_iters=5, warmup_method="constant" ) cosine_lr1 = flow.optim.lr_scheduler.CosineDecayLR( sgd1, decay_steps=100, alpha=0.1 ) linear_warmup_cosine_lr1 = flow.optim.lr_scheduler.WarmUpLR( cosine_lr1, warmup_factor=0.5, warmup_iters=5, warmup_method="linear" ) class CustomGraph0(flow.nn.Graph): def __init__(self): super().__init__() self.m = m self.add_optimizer(sgd0, lr_sch=constant_warmup_cosine_lr0) self.add_optimizer(sgd1, lr_sch=linear_warmup_cosine_lr1) def build(self, x): out0, out1 = self.m(x) out0.backward() out1.backward() return out0, out1 g = CustomGraph0() x = flow.Tensor(4, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) g._filter_states() g._generate_config_proto() print("repr(g): \n", repr(g)) print("g.config.proto: \n", g.config.proto) @unittest.skip("skip for now, becase it failed 2 times in past week") def test_optimizer_with_clip_grad(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.para0 = flow.nn.Parameter(flow.Tensor(10, 4)) def forward(self, x): x = flow._C.matmul(x, self.para0) return x m = CustomModule() learning_rate = 0.1 momentum = 0.2 scale = 0.3 weight_decay = 0.7 clip_grad_max_norm = 1.0 clip_grad_norm_type = 2.0 sgd0 = flow.optim.SGD( [ { "params": [m.para0], "lr": learning_rate, "momentum": momentum, "scale": scale, "weight_decay": weight_decay, "clip_grad_max_norm": 
clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ] ) class CustomGraph0(flow.nn.Graph): def __init__(self): super().__init__() self.m = m self.add_optimizer(sgd0) def build(self, x): out = self.m(x) out = out.sum() out.backward() return out g = CustomGraph0() x = flow.Tensor(4, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) z = g._compile(x) print("repr(g): \n", repr(g)) print("g.config.proto: \n", g.config.proto) print("graph proto: \n", g._graph_proto) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_pipeline.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import sys import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.nn.graph import GraphModule rank = flow.env.get_rank() class OFRecordDataLoader(flow.nn.Module): def __init__( self, ofrecord_root: str = "./ofrecord", mode: str = "train", # "val" dataset_size: int = 9469, batch_size: int = 1, placement=None, sbp=None, ): super().__init__() channel_last = False output_layout = "NHWC" if channel_last else "NCHW" self.train_record_reader = flow.nn.OFRecordReader( ofrecord_root + "/" + mode, batch_size=batch_size, data_part_num=40, part_name_suffix_length=5, random_shuffle=False, shuffle_after_epoch=False, placement=placement, sbp=sbp, random_seed=0, ) self.record_label_decoder = flow.nn.OFRecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) color_space = "RGB" height = 22 width = 22 self.record_image_decoder = flow.nn.OFRecordImageDecoder( "encoded", color_space=color_space ) self.resize = flow.nn.image.Resize(target_size=[height, width]) self.batch_size = batch_size self.dataset_size = dataset_size def __len__(self): return self.dataset_size // self.batch_size def forward(self): train_record = self.train_record_reader() label = self.record_label_decoder(train_record) image_raw_buffer = self.record_image_decoder(train_record) image = self.resize(image_raw_buffer)[0] image = flow.flatten(image.to(flow.float32), start_dim=1) return image, label def _train_with_graph(iter_num=3): B = [flow.sbp.broadcast] P0 = flow.placement("cuda", ranks=[0]) P1 = flow.placement("cuda", ranks=[1]) P2 = flow.placement("cuda", ranks=[2]) P3 = flow.placement("cuda", ranks=[3]) train_data_loader = OFRecordDataLoader( ofrecord_root=flow.unittest.dataset_dir("ImageNet/ofrecord"), mode="validation", dataset_size=400, batch_size=4, placement=P0, sbp=B, ) def _get_ppm_and_opt(): class StageModule(flow.nn.Module): def __init__(self, *linear_args): super().__init__() self.linear = flow.nn.Linear(*linear_args) flow.nn.init.constant_(self.linear.weight, 0.00023) def forward(self, input): out = self.linear(input) return out class PipelineModule(flow.nn.Module): def __init__(self): super().__init__() # Initlize module and move each module to the right placement of its pipeline stage. 
                self.stage_0_m = StageModule(1452, 8, False).to_global(
                    placement=P0, sbp=B
                )
                self.stage_1_m = StageModule(8, 8, False).to_global(placement=P1, sbp=B)
                self.stage_2_m = StageModule(8, 8, False).to_global(placement=P2, sbp=B)
                self.stage_3_m = StageModule(8, 1, False).to_global(placement=P3, sbp=B)

            def forward(self, image):
                out = self.stage_0_m(image)
                # Move tensors between different pipeline stages.
                out = out.to_global(placement=P1, sbp=B)
                out = self.stage_1_m(out)
                out = out.to_global(placement=P2, sbp=B)
                out = self.stage_2_m(out)
                out = out.to_global(placement=P3, sbp=B)
                out = self.stage_3_m(out)
                return out

        pp_m = PipelineModule()
        sgd = flow.optim.SGD(pp_m.parameters(), lr=0.0001)
        return pp_m, sgd

    pp_m, sgd = _get_ppm_and_opt()

    class PipelineGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.train_data_loader = train_data_loader
            self.pp_m = pp_m
            # Set each module's stage id to hint the graph to prepare the right number of buffers in the pipeline.
            self.pp_m.stage_0_m.to(GraphModule).set_stage(0)
            self.pp_m.stage_1_m.to(GraphModule).set_stage(1)
            self.pp_m.stage_2_m.to(GraphModule).set_stage(2)
            self.pp_m.stage_3_m.to(GraphModule).set_stage(3)
            self.pp_m.stage_0_m.to(GraphModule).activation_checkpointing = True
            self.pp_m.stage_1_m.to(GraphModule).activation_checkpointing = True
            self.pp_m.stage_2_m.to(GraphModule).activation_checkpointing = True
            self.pp_m.stage_3_m.to(GraphModule).activation_checkpointing = True
            self.mseloss = flow.nn.MSELoss("sum")
            self.add_optimizer(sgd)
            # Let the graph do gradient accumulation; pipeline execution depends on gradient accumulation.
            self.config.set_gradient_accumulation_steps(4)

        def build(self):
            image, label = self.train_data_loader()
            # The dataloader's outputs are on host memory, so move the image to device 0.
            image = image.to_global(placement=P0, sbp=B)
            pp_m.train()
            out = self.pp_m(image)
            # The dataloader's outputs are on host memory, so move the label to device 3.
label = label.to_global(placement=P3, sbp=B) loss = self.mseloss(out, label.to(dtype=flow.float32)) loss.backward() # Returning image and label is just for re-using data in eager test image = image.to_global(placement=P3, sbp=B) return loss, image, label pp_g = PipelineGraph() def one_iter(iter_idx): loss, image, label = pp_g() if rank == 3: # loss on other rank are 0-Size tensor loss = loss.to_local() loss_np = loss.numpy() print("loss numpy \n", loss) image = image.to_local().numpy() label = label.to_local().numpy() return loss, image, label check_list = [] data_list = [] for i in range(iter_num): out = one_iter(i) if rank == 3: check_list.append(out[0]) data_list.append((out[1], out[2])) return check_list, data_list def _train_with_module(iter_num=3, data=None): class DataModule(flow.nn.Module): def __init__(self, data): super().__init__() self.data_list = [] self.idx = 0 for pair in data: for i in range(4): s = i * 4 e = s + 4 micro_batch_image = pair[0][s:e] micro_batch_label = pair[1][s:e] self.data_list.append( ( flow.Tensor(micro_batch_image).to("cuda:3"), flow.Tensor(micro_batch_label).to("cuda:3"), ) ) def forward(self): image = self.data_list[self.idx][0] label = self.data_list[self.idx][1] self.idx += 1 return image, label class TrainModule(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(1452, 8, False) flow.nn.init.constant_(self.linear.weight, 0.00023) self.linear.to("cuda:3") self.linear1 = flow.nn.Linear(8, 8, False) flow.nn.init.constant_(self.linear1.weight, 0.00023) self.linear1.to("cuda:3") self.linear2 = flow.nn.Linear(8, 8, False) flow.nn.init.constant_(self.linear2.weight, 0.00023) self.linear2.to("cuda:3") self.linear3 = flow.nn.Linear(8, 1, False) flow.nn.init.constant_(self.linear3.weight, 0.00023) self.linear3.to("cuda:3") self.mseloss = flow.nn.MSELoss("sum") def forward(self, image, label): out = self.linear(image) out = self.linear1(out) out = self.linear2(out) out = self.linear3(out) loss = self.mseloss(out, label) return loss if rank == 3: data_m = DataModule(data) train_m = TrainModule() sgd = flow.optim.SGD(train_m.parameters(), lr=0.0001) def one_iter(iter_idx): if rank == 3: image, label = data_m() loss = train_m(image, label) loss_np = loss.numpy() print("eager loss numpy \n", loss_np) loss = loss * 0.25 loss.backward() if iter_idx % 4 == 3: print(f"iter index: {iter_idx}") # eager gradient accumulatioin sgd.step() sgd.zero_grad() return loss_np check_list = [] for i in range(iter_num): check_list.append(one_iter(i)) return check_list def _test_graph_pipeline(test_case): iter_num = 3 graph_check_list, data = _train_with_graph(iter_num) module_check_list = _train_with_module(iter_num * 4, data) if rank == 3: for i in range(iter_num * 4): # check equal on loss test_case.assertTrue( np.array_equal(module_check_list[i], graph_check_list[i // 4][i % 4]) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class TestGraphPipeline(oneflow.unittest.TestCase): def test_graph_pipeline(test_case): _test_graph_pipeline(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_pipeline_delay.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import time import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.nn.graph import GraphModule def _test_graph_pipeline_delay_output(test_case): class StageLayerModule(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(10, 8, False) self.linear2 = flow.nn.Linear(8, 10) flow.nn.init.constant_(self.linear1.weight, 0.023) flow.nn.init.constant_(self.linear2.weight, 1.23) def forward(self, x): out0 = self.linear1(x) out0 = out0 + 1.0 out0 = out0 * 2.0 out1 = self.linear2(out0) return out1 P0 = flow.placement("cuda", ranks=[0]) P1 = flow.placement("cuda", ranks=[1]) B = flow.sbp.broadcast class PipelineModule(flow.nn.Module): def __init__(self): super().__init__() self.layer_0 = StageLayerModule() self.layer_1 = StageLayerModule() self.layer_0.to_global(P0, B) self.layer_1.to_global(P1, B) def forward(self, x): # stage 0 in0 = x.to_global(P0, B) out0 = self.layer_0(in0) # stage 1 in1 = out0.to_global(P1, B) out1 = self.layer_1(in1) return out1 pp_m = PipelineModule() pp_m.train() of_sgd = flow.optim.SGD(pp_m.parameters(), lr=0.001) class PipelineGraph(flow.nn.Graph): def __init__(self): super().__init__() self.pp_m = pp_m self.pp_m.layer_0.to(GraphModule).stage_id = 0 self.pp_m.layer_1.to(GraphModule).stage_id = 1 self.config.set_gradient_accumulation_steps(4) self.add_optimizer(of_sgd) def build(self, x, y): pp_out = self.pp_m(x) loss = pp_out.mean() loss.backward() y = x + y free_out = y.to_global(P1, B) return loss, free_out pp_g = PipelineGraph() rank = flow.env.get_rank() for i in range(3): x = flow.randn(16, 10) y = flow.randn(16, 10) x = x.to_global(P0, B) y = y.to_global(P0, B) if rank == 1: time.sleep(2) loss_pack_4, free_out = pp_g(x, y) if rank == 1: # NOTE(chengcheng): Before Oneflow-Inc/oneflow#6221 fix src/dst tick order with input/output, # this case use sleep in rank 1 will expose this BUG: # free_out is output only on rank 1, but NOT control in rank 1 src/dst tick, so if manual sleep # on rank 1, free out pull callback must exec before rank 1 src tick exec, so will meet BUG of # output_kernel buffer status empty. # After this PR fix, this test case ensure that src/dst tick and input/output cb exec order on # each rank is as expected. time.sleep(2) print( "rank: ", rank, "packed loss with 4 micro-batch = ", loss_pack_4.to_local(), ) print( "rank: ", rank, "packed image with 4 micro-batch = ", free_out.to_local(), ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestGraphPipelineDelayOutput(oneflow.unittest.TestCase): def test_graph_pipeline_delay_output(test_case): _test_graph_pipeline_delay_output(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_random_seed.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import numpy as np import unittest import inspect import types import oneflow as flow import oneflow.nn as nn import oneflow.unittest def _inspect_rand_op_and_args(rand_op, **kwargs): if inspect.isclass(rand_op) and issubclass(rand_op, nn.Module): init_method_signature = inspect.signature(rand_op.__init__) module_init_args = dict() for arg_name in list(init_method_signature.parameters.keys())[1:]: if arg_name in kwargs: module_init_args[arg_name] = kwargs.pop(arg_name) module_instance = rand_op(**module_init_args) return module_instance, kwargs if isinstance(rand_op, types.BuiltinFunctionType): return rand_op, kwargs if inspect.isfunction(rand_op): return rand_op, kwargs raise ValueError(f"invalid rand_op {rand_op}, type: {type(rand_op)}") # y1 = rand_op1(x) # y2 = rand_op2(x) # rand_op1 and rand_op2 should have different seed in graph, lead to different result def _test_rand_op_in_graph(test_case, rand_op, input=None, **kwargs): rand_op1, kwargs1 = _inspect_rand_op_and_args(rand_op, **kwargs) rand_op2, kwargs2 = _inspect_rand_op_and_args(rand_op, **kwargs) class TestGraphWithoutInput(nn.Graph): def __init__(self): super().__init__() self.rand_op1 = rand_op1 self.rand_op2 = rand_op2 def build(self): y1 = self.rand_op1(**kwargs1) y2 = self.rand_op2(**kwargs2) return y1, y2 class TestGraph(nn.Graph): def __init__(self): super().__init__() self.rand_op1 = rand_op1 self.rand_op2 = rand_op2 def build(self, x): x1 = x x2 = x.clone() y1 = self.rand_op1(x1, **kwargs1) y2 = self.rand_op2(x2, **kwargs2) return y1, y2 if input is None: graph = TestGraphWithoutInput() rand_result1, rand_result2 = graph() else: graph = TestGraph() rand_result1, rand_result2 = graph(input) if isinstance(rand_result1, (list, tuple)): rand_result1 = rand_result1[0] if isinstance(rand_result2, (list, tuple)): rand_result2 = rand_result2[0] test_case.assertFalse( np.allclose(rand_result1.numpy(), rand_result2.numpy()), f"\ninput:\n{input}\nrand_result1:\n{rand_result1}\nrand_result2:\n{rand_result2}", ) def _get_shape_and_device_from_args(pop_device=False, **kwargs): if "size" in kwargs: shape = kwargs["size"] elif "shape" in kwargs: shape = kwargs["shape"] elif "n" in kwargs: shape = (kwargs["n"],) else: raise ValueError(f"can't parse shape from kwargs {kwargs}") device = "cpu" if "device" in kwargs: device = kwargs["device"] return shape, device # Test FRB (Forward Recomputation Backpropagation) # y = rand_op(x) * w # dw = fake_rand_op(x) * dy # (y * w).backward() will result in dy == w # so dw == y demand rand_op(x) == fake_rand_op(x) # in checkpoint activation graph # fake_rand_op in backward should produce the same result with rand_op in forward def _test_rand_op_in_FRB(test_case, rand_op, input=None, **kwargs): rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs) class CheckpointActivationModule(nn.Module): def __init__(self, weight, is_src_rand=False): super().__init__() self.rand_op = rand_op self.is_src_rand = is_src_rand self.weight = weight self.param = nn.Parameter(flow.zeros(*weight.shape)) def forward(self, x): weight = self.param - self.weight if self.is_src_rand: y = self.rand_op(**kwargs) + x else: y = 
self.rand_op(x, **kwargs) if isinstance(y, (tuple, list)): y = y[0] return y * weight class TestGraph(nn.Graph): def __init__(self, model): super().__init__() self.model = model self.model.to(nn.graph.GraphModule).activation_checkpointing = True self.add_optimizer(flow.optim.SGD(self.model.parameters(), lr=1.0)) def build(self, x): y = self.model(x) (y * self.model.weight).sum().backward() return y if input is None: shape, device = _get_shape_and_device_from_args(**kwargs) x = flow.randn(*shape).to(device) weight = flow.randn(*shape).to(device) model = CheckpointActivationModule(weight, True).to(device) graph = TestGraph(model) else: x = input weight = flow.randn(*input.shape).to(input.device) model = CheckpointActivationModule(weight, False).to(input.device) graph = TestGraph(model) y = graph(x) test_case.assertTrue( np.allclose(y.numpy(), model.param.numpy()), f"\nx=\n{x.numpy()}\nweight=\n{weight.numpy()}\ny=\n{y.numpy()}\ndweight=\n{model.param.numpy()}", ) def _test_split_rand_op_in_graph(test_case, rand_op, input=None, **kwargs): rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs) class TestGraph(nn.Graph): def __init__(self): super().__init__() self.rand_op = rand_op def build(self, x): x = x.to_global(sbp=flow.sbp.split(0)) y = self.rand_op(x, **kwargs) return y class TestGraphWithoutInput(nn.Graph): def __init__(self, placement): super().__init__() self.rand_op = rand_op self.placement = placement def build(self): y = self.rand_op(placement=self.placement, sbp=flow.sbp.split(0), **kwargs) return y ranks = np.array(range(flow.env.get_world_size())) if input is None: device = kwargs.pop("device", None) placement = flow.placement(device, ranks) graph = TestGraphWithoutInput(placement) y_global = graph() else: x = flow.concat([input, input], dim=0) placement = flow.placement(input.device.type, ranks) # local to broadcast global x_global = x.to_global(placement=placement, sbp=flow.sbp.broadcast(), copy=True) graph = TestGraph() y_global = graph(x_global) if isinstance(y_global, (list, tuple)): y_global = y_global[0] y_global = y_global.to_global(placement=placement, sbp=flow.sbp.broadcast()) half = y_global.shape[0] // 2 first_half = y_global[0:half] second_half = y_global[half:] test_case.assertFalse(np.allclose(first_half.numpy(), second_half.numpy())) def _test_broadcast_rand_op_in_graph(test_case, rand_op, input=None, **kwargs): rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs) class TestGraph(nn.Graph): def __init__(self): super().__init__() self.rand_op = rand_op def build(self, x): y = self.rand_op(x, **kwargs) return y class TestGraphWithoutInput(nn.Graph): def __init__(self, placement): super().__init__() self.rand_op = rand_op self.placement = placement def build(self): y = self.rand_op(placement=placement, sbp=flow.sbp.broadcast(), **kwargs) return y ranks = np.array(range(flow.env.get_world_size())) if input is None: device = kwargs.pop("device", None) placement = flow.placement(device, ranks) graph = TestGraphWithoutInput(placement) y_global = graph() else: placement = flow.placement(input.device.type, ranks) # local to broadcast global x = input x_global = x.to_global(placement=placement, sbp=flow.sbp.broadcast(), copy=True) graph = TestGraph() y_global = graph(x_global) if isinstance(y_global, (list, tuple)): y_local = y_global[0].to_local() else: y_local = y_global.to_local() y_all_ranks = y_local.to_global(placement=placement, sbp=flow.sbp.split(0)) y_allgather = y_all_ranks.to_global(sbp=flow.sbp.broadcast()) half = y_allgather.shape[0] // 2 
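    # Descriptive note (added): under broadcast sbp every rank draws identical random values,
    # so the per-rank halves gathered above are expected to match in the check below.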
first_half = y_allgather[0:half] second_half = y_allgather[half:] test_case.assertTrue(np.allclose(first_half.numpy(), second_half.numpy())) @flow.unittest.skip_unless_1n1d() class TestRandOpInGraph(oneflow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_usual_rand_op(self): for device in ("cpu", "cuda"): x = flow.randn(4, 16, device=device) _test_rand_op_in_graph(self, nn.Dropout, x, p=0.5) _test_rand_op_in_graph(self, flow._C.rrelu, x, training=True) _test_rand_op_in_graph(self, nn.init.uniform_, x) _test_rand_op_in_graph(self, flow._C.exponential_, x) x1 = flow.rand(4, 16, device=device) _test_rand_op_in_graph( self, flow.multinomial, x1, num_samples=16, replacement=True ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_source_rand_op(self): shape = (4, 16) for device in ("cpu", "cuda"): _test_rand_op_in_graph(self, flow.rand, size=shape, device=device) _test_rand_op_in_graph( self, flow.normal, mean=0.0, std=1.0, size=shape, device=device ) _test_rand_op_in_graph( self, flow.randint, low=0, high=10, size=shape, device=device ) _test_rand_op_in_graph(self, flow.randperm, n=32, device=device) def test_bernoulli(self): x1 = flow.randn(4, 16) _test_rand_op_in_graph(self, flow.bernoulli, x1, p=0.5) x2 = flow.rand(4, 16) _test_rand_op_in_graph(self, flow.bernoulli, x2) @unittest.skip("skip for now, becase it failed 4 times in past week") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_random_mask_like(self): x = flow.randn(4, 16, 128, 128).to("cuda") _test_rand_op_in_graph( self, flow._C.fused_scale_tril_softmax_mask_scale, x, p=0.1, diagonal=2, tril_scale_value=-1000, ) @flow.unittest.skip_unless_1n1d() class TestRandOpInFRB(oneflow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_usual_rand_op(self): for device in ("cpu", "cuda"): x = flow.randn(4, 16, device=device) _test_rand_op_in_FRB(self, nn.Dropout, x, p=0.5) _test_rand_op_in_FRB(self, flow._C.rrelu, x, training=True) _test_rand_op_in_FRB(self, nn.init.uniform_, x) _test_rand_op_in_FRB(self, flow._C.exponential_, x) x1 = flow.rand(4, 16, device=device) _test_rand_op_in_FRB( self, flow.multinomial, x1, num_samples=16, replacement=True ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_source_rand_op(self): shape = (4, 16) for device in ("cpu", "cuda"): _test_rand_op_in_FRB(self, flow.rand, size=shape, device=device) _test_rand_op_in_FRB( self, flow.normal, mean=0.0, std=1.0, size=shape, device=device ) _test_rand_op_in_FRB( self, flow.randint, low=0, high=10, size=shape, device=device ) _test_rand_op_in_FRB(self, flow.randperm, n=32, device=device) def test_bernoulli(self): x1 = flow.randn(4, 16) _test_rand_op_in_FRB(self, flow.bernoulli, x1, p=0.5) x2 = flow.rand(4, 16) _test_rand_op_in_FRB(self, flow.bernoulli, x2) @unittest.skip("skip for now, becase it failed 4 times in past week") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_random_mask_like(self): x = flow.randn(4, 16, 128, 128).to("cuda") _test_rand_op_in_FRB( self, flow._C.fused_scale_tril_softmax_mask_scale, x, p=0.1, diagonal=0, tril_scale_value=-1000, ) @flow.unittest.skip_unless_1n2d() class TestGlobalRandInGraph(oneflow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_usual_rand_op_with_split(self): x = flow.randn(2, 4, device="cuda") 
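        # Descriptive note (added): _test_split_rand_op_in_graph asserts that, under split(0) sbp,
        # different ranks generate different random values.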
_test_split_rand_op_in_graph(self, nn.Dropout, x, p=0.5) _test_split_rand_op_in_graph(self, flow._C.rrelu, x, training=True) _test_split_rand_op_in_graph(self, nn.init.uniform_, x) _test_split_rand_op_in_graph(self, flow._C.exponential_, x) x1 = flow.rand(2, 8, device="cuda") _test_split_rand_op_in_graph( self, flow.multinomial, x1, num_samples=8, replacement=True ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_usual_rand_op_with_broadcast(self): x = flow.randn(2, 4, device="cuda") _test_broadcast_rand_op_in_graph(self, nn.Dropout, x, p=0.5) _test_broadcast_rand_op_in_graph(self, flow._C.rrelu, x, training=True) _test_broadcast_rand_op_in_graph(self, nn.init.uniform_, x) _test_broadcast_rand_op_in_graph(self, flow._C.exponential_, x) x1 = flow.rand(2, 8, device="cuda") _test_broadcast_rand_op_in_graph( self, flow.multinomial, x1, num_samples=8, replacement=True ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_source_rand_op_with_split(self): shape = (4, 4) _test_split_rand_op_in_graph(self, flow.rand, size=shape, device="cuda") _test_split_rand_op_in_graph( self, flow.normal, mean=0.0, std=1.0, size=shape, device="cuda" ) _test_split_rand_op_in_graph( self, flow.randint, low=0, high=10, size=shape, device="cuda" ) _test_split_rand_op_in_graph(self, flow.randperm, n=32, device="cuda") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_source_rand_op_with_broadcast(self): shape = (4, 4) _test_broadcast_rand_op_in_graph(self, flow.rand, size=shape, device="cuda") _test_broadcast_rand_op_in_graph( self, flow.normal, mean=0.0, std=1.0, size=shape, device="cuda" ) _test_broadcast_rand_op_in_graph( self, flow.randint, low=0, high=10, size=shape, device="cuda" ) _test_broadcast_rand_op_in_graph(self, flow.randperm, n=32, device="cuda") def test_bernoulli_with_split(self): x1 = flow.randn(2, 8) _test_split_rand_op_in_graph(self, flow.bernoulli, x1, p=0.5) x2 = flow.rand(2, 8) _test_split_rand_op_in_graph(self, flow.bernoulli, x2) def test_bernoulli_with_broadcast(self): x1 = flow.randn(2, 8) _test_broadcast_rand_op_in_graph(self, flow.bernoulli, x1, p=0.5) x2 = flow.rand(2, 8) _test_broadcast_rand_op_in_graph(self, flow.bernoulli, x2) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_random_mask_like_with_split(self): x = flow.randn(2, 16, 64).to("cuda") _test_split_rand_op_in_graph( self, flow._C.fused_scale_tril_softmax_mask_scale, x, p=0.1, diagonal=0, tril_scale_value=-1000, ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_random_mask_like_with_broadcast(self): x = flow.randn(2, 16, 64).to("cuda") _test_broadcast_rand_op_in_graph( self, flow._C.fused_scale_tril_softmax_mask_scale, x, p=0.2, diagonal=1, tril_scale_value=-100, ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_relu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestReluGraph(oneflow.unittest.TestCase): def test_relu_graph(test_case): data = np.array([2.0, 1.0, 0.0, -1.0, -2.0]) x = flow.tensor(data, dtype=flow.float32) MyRelu = flow.nn.ReLU() y_eager = MyRelu(x) # print("eager out :", y_eager) class ReluGraph(flow.nn.Graph): def __init__(self): super().__init__() self.cc_relu = MyRelu def build(self, x): return self.cc_relu(x) relu_g = ReluGraph() y_lazy = relu_g(x) # print(f"type of lazy y: {type(y_lazy)}") # print(f"lazy y shape: {y_lazy.shape}, data: {y_lazy}") test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_reshape_acc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import numpy as np import oneflow as flow import oneflow.unittest from oneflow.nn.graph import GraphModule def _test_graph_reshape_acc(test_case): class StageLayerModule(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(10, 8, False) self.linear2 = flow.nn.Linear(8, 10, False) flow.nn.init.constant_(self.linear1.weight, 0.023) flow.nn.init.constant_(self.linear2.weight, 1.23) def forward(self, x): out0 = self.linear1(x) out0 = flow.reshape(out0, (-1, 2, 4)) out0 = out0 + 1.0 out0 = out0 * 2.0 out0 = flow.reshape(out0, (-1, 8)) out1 = self.linear2(out0) return out1 P0 = flow.placement("cuda", ranks=[0]) P1 = flow.placement("cuda", ranks=[1]) B = flow.sbp.broadcast class PipelineModule(flow.nn.Module): def __init__(self): super().__init__() self.layer_0 = StageLayerModule() self.layer_1 = StageLayerModule() self.layer_0.to_global(P0, B) self.layer_1.to_global(P1, B) def forward(self, x): # stage 0 x = flow.flatten(x, start_dim=1) in0 = x.to_global(P0, B) out0 = self.layer_0(in0) # stage 1 in1 = out0.to_global(P1, B) out1 = self.layer_1(in1) return out1 pp_m = PipelineModule() pp_m.train() sgd = flow.optim.SGD(pp_m.parameters(), lr=0.001) class PipelineGraph(flow.nn.Graph): def __init__(self): super().__init__() self.pp_m = pp_m self.pp_m.layer_0.to(GraphModule).set_stage(0) self.pp_m.layer_1.to(GraphModule).set_stage(1) self.loss_fn = flow.nn.CrossEntropyLoss() self.config.set_gradient_accumulation_steps(2) self.add_optimizer(sgd) def build(self, x, y): out = self.pp_m(x) y = y.to_global(P1, B) loss = self.loss_fn(out, y) loss.backward() return loss pp_g = PipelineGraph() for i in range(20): x = flow.randn(6, 2, 5) y = flow.randint(0, 10, (6,)) x = x.to_global(P0, B) y = y.to_global(P1, B) out = pp_g(x, y) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestGraphReshapeAcc(oneflow.unittest.TestCase): def 
test_graph_reshape_acc(test_case): _test_graph_reshape_acc(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_reuse_var.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from collections import OrderedDict import numpy as np from test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGraphResueVar(flow.unittest.TestCase): def test_graph_reuse_var(test_case): rank = flow.env.get_rank() P = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast class ReuseVarModule(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(2, 2) self.linear2 = flow.nn.Linear(2, 2) # Reuse parameter self.linear2.weight = self.linear1.weight def forward(self, x): # Allow user to call parameter outside it's module. self.linear1.weight x = self.linear1(x) x = self.linear2(x) return x reuse_var_m = ReuseVarModule() reuse_var_m.to_global(placement=P, sbp=B) of_sgd = flow.optim.SGD(reuse_var_m.parameters(), lr=0.001, momentum=0.9) class ReuseVarGraph(flow.nn.Graph): def __init__(self): super().__init__() self.reuse_var_m = reuse_var_m self.add_optimizer(of_sgd) def build(self, x): x = self.reuse_var_m(x) loss = x.sum() loss.backward() return loss x = flow.randint(0, 1, (2, 2), placement=P, sbp=B, dtype=flow.float32) reuse_var_g = ReuseVarGraph() loss = reuse_var_g(x) # check lazy tensor builder block = reuse_var_g.reuse_var_m test_case.assertEqual( block.linear1.weight.lazy_origin_builder().name, "reuse_var_m.linear1.weight", ) test_case.assertEqual( block.linear1.weight.lazy_origin_builder().name, block.linear2.weight.lazy_origin_builder().name, ) # check optimizer's variable list var_list = [ "reuse_var_m.linear1.weight", "reuse_var_m.linear1.bias", "reuse_var_m.linear2.bias", ] var_list_in_conf = reuse_var_g._graph_proto.job_conf.train_conf.optimizer_conf[ 0 ].variable_op_names test_case.assertEqual(len(var_list_in_conf), 3) for idx in range(3): test_case.assertEqual(var_list[idx], var_list_in_conf[idx]) if rank == 0: print(var_list_in_conf[idx]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_save_load.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import numpy as np import tempfile import oneflow as flow import oneflow.unittest def _test_linear_graph_save_load(test_case, device): def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) class LinearTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = linear self.add_optimizer(of_sgd) def build(self, x): out = self.linear(x) out = out.sum() out.backward() return out linear_t_g = LinearTrainGraph() if call_cnt == 1: state_dict = flow.load(state_dict_file) linear_t_g.load_state_dict(state_dict) # Check state in module has been loaded. test_case.assertTrue( np.array_equal(state_dict["linear"]["weight"].numpy(), linear.weight) ) test_case.assertTrue( np.array_equal(state_dict["linear"]["bias"].numpy(), linear.bias) ) # Get state dict before compile is allowed. init_state_dict = linear_t_g.state_dict() of_graph_out = linear_t_g(x) iter0_state_dict = linear_t_g.state_dict() if call_cnt == 1: # Check additional variable state initialized in job has been loaded. cur_train_step = iter0_state_dict["System-Train-TrainStep"].numpy()[0] test_case.assertEqual(3, cur_train_step) test_case.assertTrue( cur_train_step == last_state_dict["System-Train-TrainStep"].numpy()[0] ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear"]["weight"].numpy(), last_state_dict["linear"]["weight"].numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear"]["bias"].numpy(), last_state_dict["linear"]["bias"].numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear.weight-momentum"].numpy(), last_state_dict["linear.weight-momentum"].numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear.bias-momentum"].numpy(), last_state_dict["linear.bias-momentum"].numpy(), ) ) of_graph_out = linear_t_g(x) of_graph_out.numpy() iter1_state_dict = linear_t_g.state_dict() if call_cnt == 0: flow.save(iter1_state_dict, state_dict_file) if call_cnt == 0: of_graph_out = linear_t_g(x) iter2_state_dict = linear_t_g.state_dict() of_graph_out.numpy() return iter2_state_dict with tempfile.NamedTemporaryFile(prefix="graph_save_load_local") as f: iter2_state_dict = train_with_graph(0, f.name) train_with_graph(1, f.name, iter2_state_dict) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLinearGraphSaveLoad(oneflow.unittest.TestCase): def test_linear_graph_save_load_gpu(test_case): _test_linear_graph_save_load(test_case, flow.device("cuda")) def _test_linear_graph_save_load_cpu(test_case): _test_linear_graph_save_load(test_case, flow.device("cpu")) def _test_linear_graph_save_load_global(test_case, device): P = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast S = flow.sbp.split(0) def train_with_graph(call_cnt=0, 
state_dict_file=None, last_state_dict=None): linear = flow.nn.Linear(3, 8) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) linear.to_global(placement=P, sbp=B) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=device, requires_grad=False, ) x = x.to_global(placement=P, sbp=S) class LinearTrainGraphGlobal(flow.nn.Graph): def __init__(self): super().__init__() self.linear = linear self.add_optimizer(of_sgd) def build(self, x): out = self.linear(x) out = out.sum() out.backward() return out linear_t_g = LinearTrainGraphGlobal() if call_cnt == 1: state_dict = flow.load(state_dict_file, global_src_rank=0) linear_t_g.load_state_dict(state_dict) # Check state in module has been loaded. # Tensors in state dict are save to rank 0, so they need to be broadcast to rank 0 and 1 before check. test_case.assertTrue( np.array_equal( state_dict["linear"]["weight"] .to_global(placement=P, sbp=B) .to_local() .numpy(), linear.weight.to_local().numpy(), ) ) test_case.assertTrue( np.array_equal( state_dict["linear"]["bias"] .to_global(placement=P, sbp=B) .to_local() .numpy(), linear.bias.to_local().numpy(), ) ) # Get state dict before compile is allowed. init_state_dict = linear_t_g.state_dict() of_graph_out = linear_t_g(x) iter0_state_dict = linear_t_g.state_dict() if call_cnt == 1: # Check additional variable state initialized in job has been loaded. # TrainStep's placement is only on rank 0, so it needs to be broadcast to rank 0 and 1 before check. 
cur_train_step = ( iter0_state_dict["System-Train-TrainStep"] .to_global(placement=P, sbp=B) .to_local() .numpy()[0] ) test_case.assertEqual(3, cur_train_step) test_case.assertTrue( cur_train_step == last_state_dict["System-Train-TrainStep"] .to_global(placement=P, sbp=B) .to_local() .numpy()[0] ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear"]["weight"].to_local().numpy(), last_state_dict["linear"]["weight"].to_local().numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear"]["bias"].to_local().numpy(), last_state_dict["linear"]["bias"].to_local().numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear.weight-momentum"].to_local().numpy(), last_state_dict["linear.weight-momentum"].to_local().numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear.bias-momentum"].to_local().numpy(), last_state_dict["linear.bias-momentum"].to_local().numpy(), ) ) of_graph_out = linear_t_g(x) of_graph_out.numpy() iter1_state_dict = linear_t_g.state_dict() if call_cnt == 0: flow.save(iter1_state_dict, state_dict_file, global_dst_rank=0) if call_cnt == 0: of_graph_out = linear_t_g(x) of_graph_out.numpy() iter2_state_dict = linear_t_g.state_dict() return iter2_state_dict with tempfile.NamedTemporaryFile(prefix="graph_save_load_global") as f: iter2_state_dict = train_with_graph(0, f.name) train_with_graph(1, f.name, iter2_state_dict) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestLinearGraphSaveLoadGlobal(oneflow.unittest.TestCase): def test_linear_graph_save_load_gpu(test_case): _test_linear_graph_save_load_global(test_case, flow.device("cuda")) def _test_linear_graph_save_load_cpu(test_case): _test_linear_graph_save_load_global(test_case, flow.device("cpu")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_save_load_global_b_s.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import tempfile import oneflow as flow import oneflow.unittest from oneflow.nn.graph import GraphModule def _test_linear_graph_save_load_global_broadcast( test_case, model_tensor_placement, model_file_placement ): """Data parallelism on 2 ranks. 
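    The model tensors are broadcast on model_tensor_placement, while the checkpoint
    file is written and read only on model_file_placement's rank(s).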
""" B = flow.sbp.broadcast S0 = flow.sbp.split(0) def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None): linear = flow.nn.Linear(3, 8) linear = linear.to(flow.device(model_tensor_placement.type)) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) linear.to_global(placement=model_tensor_placement, sbp=B) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=flow.device(model_tensor_placement.type), requires_grad=False, ) x = x.to_global(placement=model_tensor_placement, sbp=S0) class LinearTrainGraphGlobal(flow.nn.Graph): def __init__(self): super().__init__() self.linear = linear self.add_optimizer(of_sgd) def build(self, x): out = self.linear(x) out = out.sum() out.backward() return out linear_t_g = LinearTrainGraphGlobal() cur_rank = flow.env.get_rank() if call_cnt == 1: if cur_rank in model_file_placement.ranks: local_state_dict = flow.load(state_dict_file) else: local_state_dict = None global_state_dict = flow.utils.global_view.to_global( local_state_dict, placement=model_file_placement, sbp=B ) linear_t_g.load_state_dict(global_state_dict) if cur_rank == 0: # Ignore None on rank 1 # Check state in module has been loaded. test_case.assertTrue( np.array_equal( global_state_dict["linear"]["weight"].to_local().numpy(), linear.weight.to_local().numpy(), ) ) test_case.assertTrue( np.array_equal( global_state_dict["linear"]["bias"].to_local().numpy(), linear.bias.to_local().numpy(), ) ) # Get state dict before compile is allowed. init_state_dict = linear_t_g.state_dict() of_graph_out = linear_t_g(x) iter0_state_dict = linear_t_g.state_dict() # Load the model and check if call_cnt == 1: # Check additional variable state initialized in job has been loaded. # TrainStep's placement is only on rank 0, so it needs to be broadcast to all ranks before check. 
cur_train_step = ( iter0_state_dict["System-Train-TrainStep"] .to_global(placement=model_tensor_placement, sbp=B) .to_local() .numpy()[0] ) test_case.assertEqual(3, cur_train_step) test_case.assertTrue( cur_train_step == last_state_dict["System-Train-TrainStep"] .to_global(placement=model_tensor_placement, sbp=B) .to_local() .numpy()[0] ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear"]["weight"].to_local().numpy(), last_state_dict["linear"]["weight"].to_local().numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear"]["bias"].to_local().numpy(), last_state_dict["linear"]["bias"].to_local().numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear.weight-momentum"].to_local().numpy(), last_state_dict["linear.weight-momentum"].to_local().numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["linear.bias-momentum"].to_local().numpy(), last_state_dict["linear.bias-momentum"].to_local().numpy(), ) ) of_graph_out = linear_t_g(x) of_graph_out.numpy() iter1_state_dict = linear_t_g.state_dict() # Save the model if call_cnt == 0: # Transfer the state dict to model_file_placement model_file_state_dict = flow.utils.global_view.to_global( iter1_state_dict, placement=model_file_placement, sbp=B ) # Get the local component and save it on model_file_placement's rank(s) if cur_rank in model_file_placement.ranks: iter1_local_dict = flow.utils.global_view.to_local( model_file_state_dict ) flow.save(iter1_local_dict, state_dict_file) of_graph_out = linear_t_g(x) of_graph_out.numpy() iter2_state_dict = linear_t_g.state_dict() return iter2_state_dict rank_id = flow.env.get_rank() with tempfile.NamedTemporaryFile( prefix="graph_save_load_global_" + str(rank_id) ) as f: iter2_state_dict = train_with_graph(0, f.name) train_with_graph(1, f.name, iter2_state_dict) def _test_graph_save_load_global_split_2( test_case, model_tensor_placement, model_file_placement ): """Pipeline parallelism on 2 ranks. 
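    Stage 0 (Linear 16 -> 8) is placed on rank 0 and stage 1 (Linear 8 -> 1) on rank 1;
    get_sbp picks a per-tensor sbp when moving the state dict onto model_file_placement.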
""" P0 = flow.placement(model_tensor_placement.type, ranks=[0]) P1 = flow.placement(model_tensor_placement.type, ranks=[1]) BROADCAST = flow.sbp.broadcast def get_sbp(state_dict, tensor): if tensor is state_dict["System-Train-TrainStep"]: return BROADCAST if tensor is state_dict["module_pipeline"]["m_stage1.linear.weight"]: return flow.sbp.split(1) if tensor is state_dict["module_pipeline"]["m_stage1.linear.bias"]: return BROADCAST return flow.sbp.split(0) class Stage0Module(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(16, 8) self.relu = flow.nn.ReLU() def forward(self, x): return self.relu(self.linear(x)) class Stage1Module(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(8, 1) def forward(self, x): return self.linear(x) class PipelineModule(flow.nn.Module): def __init__(self): super().__init__() self.m_stage0 = Stage0Module() self.m_stage1 = Stage1Module() self.m_stage0.to_global(placement=P0, sbp=BROADCAST) self.m_stage1.to_global(placement=P1, sbp=BROADCAST) def forward(self, x): out_stage0 = self.m_stage0(x) in_stage1 = out_stage0.to_global(placement=P1, sbp=BROADCAST) out_stage1 = self.m_stage1(in_stage1) return out_stage1 class PipelineGraph(flow.nn.Graph): def __init__(self, module_pipleine): super().__init__() self.module_pipeline = module_pipleine self.module_pipeline.m_stage0.to(GraphModule).set_stage(0, P0) self.module_pipeline.m_stage1.to(GraphModule).set_stage(1, P1) self.config.set_gradient_accumulation_steps(2) self.add_optimizer( flow.optim.SGD(self.module_pipeline.parameters(), lr=0.001) ) def build(self, x): out = self.module_pipeline(x) out = out.sum() out.backward() return out def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None): # A fixed input with shape [2, 16] x = flow.tensor( [ [ 0.4286, 0.7402, 0.4161, 0.6103, 0.7394, 1.1330, -0.2311, -0.1013, 0.8537, 0.9757, -0.9842, 0.3839, -0.5551, -0.8832, 0.7820, 0.7421, ], [ -0.1581, -1.0319, 1.8430, 0.3576, 0.7288, -0.6912, 0.9966, 1.0840, -1.1760, 1.5683, -0.2098, -1.6439, -2.7049, 0.1949, 1.6377, 0.0745, ], ], dtype=oneflow.float32, placement=P0, sbp=BROADCAST, ) module_pipleine = PipelineModule() graph_model = PipelineGraph(module_pipleine) cur_rank = flow.env.get_rank() if call_cnt == 1: if cur_rank in model_file_placement.ranks: local_state_dict = flow.load(state_dict_file) else: local_state_dict = None # test sbp_for_special_keys global_state_dict = flow.utils.global_view.to_global( local_state_dict, placement=model_file_placement, sbp=get_sbp, ) graph_model.load_state_dict(global_state_dict) if cur_rank == 0: test_case.assertTrue( np.array_equal( global_state_dict["module_pipeline"]["m_stage0.linear.weight"] .to_local() .numpy(), module_pipleine.m_stage0.linear.weight.to_local().numpy()[ :4 ], # The first half of shape (8, 16) ) ) test_case.assertTrue( np.array_equal( global_state_dict["module_pipeline"]["m_stage0.linear.bias"] .to_local() .numpy(), module_pipleine.m_stage0.linear.bias.to_local().numpy()[ :4 ], # The first half of shape (8,) ) ) if cur_rank == 1: test_case.assertTrue( np.array_equal( global_state_dict["module_pipeline"]["m_stage1.linear.weight"] .to_local() .numpy(), module_pipleine.m_stage1.linear.weight.to_local().numpy()[ :, 4: ], # The second half of shape (1, 8) ) ) test_case.assertTrue( np.array_equal( global_state_dict["module_pipeline"]["m_stage1.linear.bias"] .to_local() .numpy(), module_pipleine.m_stage1.linear.bias.to_local().numpy(), ) ) graph_model(x) iter0_state_dict = 
graph_model.state_dict() if call_cnt == 1: # TrainStep cur_train_step = ( iter0_state_dict["System-Train-TrainStep"] .to_global(placement=model_tensor_placement, sbp=BROADCAST) .to_local() .numpy()[0] ) test_case.assertEqual(3, cur_train_step) test_case.assertTrue( cur_train_step == last_state_dict["System-Train-TrainStep"] .to_global(placement=model_tensor_placement, sbp=BROADCAST) .to_local() .numpy()[0] ) # Weight & bias test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage0.linear.weight"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage0.linear.weight"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage0.linear.bias"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage0.linear.bias"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage1.linear.weight"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage1.linear.weight"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage1.linear.bias"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage1.linear.bias"] .to_local() .numpy(), ) ) graph_model(x) iter1_state_dict = graph_model.state_dict() if call_cnt == 0: model_file_state_dict = flow.utils.global_view.to_global( iter1_state_dict, placement=model_file_placement, sbp=get_sbp, ) if flow.env.get_rank() in model_file_placement.ranks: flow.save( flow.utils.global_view.to_local(model_file_state_dict), state_dict_file, ) graph_model(x) iter2_state_dict = graph_model.state_dict() return iter2_state_dict rank_id = flow.env.get_rank() with tempfile.NamedTemporaryFile( prefix="graph_save_load_global_" + str(rank_id) ) as f: iter2_state_dict = train_with_graph(0, f.name) train_with_graph(1, f.name, iter2_state_dict) def _test_graph_save_load_global_split_4( test_case, model_tensor_placement, model_file_placement ): """Pipeline parallelism on 4 ranks. 
""" P0 = flow.placement(model_tensor_placement.type, ranks=[0]) P1 = flow.placement(model_tensor_placement.type, ranks=[1]) P2 = flow.placement(model_tensor_placement.type, ranks=[2]) P3 = flow.placement(model_tensor_placement.type, ranks=[3]) BROADCAST = flow.sbp.broadcast def get_sbp(state_dict, tensor): if tensor is state_dict["System-Train-TrainStep"]: return BROADCAST if tensor is state_dict["module_pipeline"]["m_stage3.linear.weight"]: return flow.sbp.split(1) if tensor is state_dict["module_pipeline"]["m_stage3.linear.bias"]: return BROADCAST return flow.sbp.split(0) class Stage0Module(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(16, 8) self.relu = flow.nn.ReLU() def forward(self, x): return self.relu(self.linear(x)) class Stage1Module(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(8, 4) self.relu = flow.nn.ReLU() def forward(self, x): return self.relu(self.linear(x)) class Stage2Module(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(4, 2) self.relu = flow.nn.ReLU() def forward(self, x): return self.relu(self.linear(x)) class Stage3Module(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(2, 1) def forward(self, x): return self.linear(x) class PipelineModule(flow.nn.Module): def __init__(self): super().__init__() self.m_stage0 = Stage0Module() self.m_stage1 = Stage1Module() self.m_stage2 = Stage2Module() self.m_stage3 = Stage3Module() self.m_stage0.to_global(placement=P0, sbp=BROADCAST) self.m_stage1.to_global(placement=P1, sbp=BROADCAST) self.m_stage2.to_global(placement=P2, sbp=BROADCAST) self.m_stage3.to_global(placement=P3, sbp=BROADCAST) def forward(self, x): out_stage0 = self.m_stage0(x) in_stage1 = out_stage0.to_global(placement=P1, sbp=BROADCAST) out_stage1 = self.m_stage1(in_stage1) in_stage2 = out_stage1.to_global(placement=P2, sbp=BROADCAST) out_stage2 = self.m_stage2(in_stage2) in_stage3 = out_stage2.to_global(placement=P3, sbp=BROADCAST) out_stage3 = self.m_stage3(in_stage3) return out_stage3 class PipelineGraph(flow.nn.Graph): def __init__(self, module_pipleine): super().__init__() self.module_pipeline = module_pipleine self.module_pipeline.m_stage0.to(GraphModule).set_stage(0, P0) self.module_pipeline.m_stage1.to(GraphModule).set_stage(1, P1) self.module_pipeline.m_stage2.to(GraphModule).set_stage(2, P2) self.module_pipeline.m_stage3.to(GraphModule).set_stage(3, P3) self.config.set_gradient_accumulation_steps(2) self.add_optimizer( flow.optim.SGD(self.module_pipeline.parameters(), lr=0.001) ) def build(self, x): out = self.module_pipeline(x) out = out.sum() out.backward() return out def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None): # A fixed input with shape [2, 16] x = flow.tensor( [ [ 0.4286, 0.7402, 0.4161, 0.6103, 0.7394, 1.1330, -0.2311, -0.1013, 0.8537, 0.9757, -0.9842, 0.3839, -0.5551, -0.8832, 0.7820, 0.7421, ], [ -0.1581, -1.0319, 1.8430, 0.3576, 0.7288, -0.6912, 0.9966, 1.0840, -1.1760, 1.5683, -0.2098, -1.6439, -2.7049, 0.1949, 1.6377, 0.0745, ], ], dtype=flow.float32, placement=P0, sbp=BROADCAST, ) module_pipleine = PipelineModule() graph_model = PipelineGraph(module_pipleine) cur_rank = flow.env.get_rank() if call_cnt == 1: if cur_rank in model_file_placement.ranks: local_state_dict = flow.load(state_dict_file) else: local_state_dict = None # test sbp_for_special_keys global_state_dict = flow.utils.global_view.to_global( local_state_dict, placement=model_file_placement, 
sbp=get_sbp, ) graph_model.load_state_dict(global_state_dict) if cur_rank == 0: test_case.assertTrue( np.array_equal( global_state_dict["module_pipeline"]["m_stage0.linear.weight"] .to_local() .numpy(), module_pipleine.m_stage0.linear.weight.to_local().numpy()[ :4 ], # The first half of shape (8, 16) ) ) test_case.assertTrue( np.array_equal( global_state_dict["module_pipeline"]["m_stage0.linear.bias"] .to_local() .numpy(), module_pipleine.m_stage0.linear.bias.to_local().numpy()[ :4 ], # The first half of shape (8,) ) ) if cur_rank == 1: test_case.assertTrue( np.array_equal( global_state_dict["module_pipeline"]["m_stage1.linear.weight"] .to_local() .numpy(), module_pipleine.m_stage1.linear.weight.to_local().numpy()[ 2:, : ], # The second half of shape (4, 8) ) ) test_case.assertTrue( np.array_equal( global_state_dict["module_pipeline"]["m_stage1.linear.bias"] .to_local() .numpy(), module_pipleine.m_stage1.linear.bias.to_local().numpy()[ 2: ], # The second half if shape (4,) ) ) graph_model(x) iter0_state_dict = graph_model.state_dict() if call_cnt == 1: # TrainStep cur_train_step = ( iter0_state_dict["System-Train-TrainStep"] .to_global(placement=model_tensor_placement, sbp=BROADCAST) .to_local() .numpy()[0] ) test_case.assertEqual(3, cur_train_step) test_case.assertTrue( cur_train_step == last_state_dict["System-Train-TrainStep"] .to_global(placement=model_tensor_placement, sbp=BROADCAST) .to_local() .numpy()[0] ) # Weight & bias test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage0.linear.weight"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage0.linear.weight"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage0.linear.bias"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage0.linear.bias"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage1.linear.weight"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage1.linear.weight"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage1.linear.bias"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage1.linear.bias"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage2.linear.weight"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage2.linear.weight"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage2.linear.bias"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage2.linear.bias"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage3.linear.weight"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage3.linear.weight"] .to_local() .numpy(), ) ) test_case.assertTrue( np.array_equal( iter0_state_dict["module_pipeline"]["m_stage3.linear.bias"] .to_local() .numpy(), last_state_dict["module_pipeline"]["m_stage3.linear.bias"] .to_local() .numpy(), ) ) graph_model(x) iter1_state_dict = graph_model.state_dict() if call_cnt == 0: model_file_state_dict = flow.utils.global_view.to_global( iter1_state_dict, placement=model_file_placement, sbp=get_sbp, ) if flow.env.get_rank() in model_file_placement.ranks: flow.save( flow.utils.global_view.to_local(model_file_state_dict), state_dict_file, ) graph_model(x) iter2_state_dict = graph_model.state_dict() return iter2_state_dict rank_id = 
flow.env.get_rank() with tempfile.NamedTemporaryFile( prefix="graph_save_load_global_" + str(rank_id) ) as f: iter2_state_dict = train_with_graph(0, f.name) train_with_graph(1, f.name, iter2_state_dict) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestGraphSaveLoadGlobal2d(oneflow.unittest.TestCase): def test_linear_graph_save_load_gpu_1_broadcast(test_case): _test_linear_graph_save_load_global_broadcast( test_case, model_tensor_placement=flow.placement("cuda", ranks=[0, 1]), model_file_placement=flow.placement("cpu", ranks=[0]), ) def test_linear_graph_save_load_cpu_1_broadcast(test_case): _test_linear_graph_save_load_global_broadcast( test_case, model_tensor_placement=flow.placement("cpu", ranks=[0, 1]), model_file_placement=flow.placement("cpu", ranks=[0]), ) def test_graph_save_load_gpu_2_split(test_case): _test_graph_save_load_global_split_2( test_case, model_tensor_placement=flow.placement("cuda", ranks=[0, 1]), model_file_placement=flow.placement("cpu", ranks=[0, 1]), ) @unittest.skip("skip for now, because it failed 2 times in the past week") def test_graph_save_load_cpu_2_split(test_case): _test_graph_save_load_global_split_2( test_case, model_tensor_placement=flow.placement("cpu", ranks=[0, 1]), model_file_placement=flow.placement("cpu", ranks=[0, 1]), ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class TestGraphSaveLoadGlobal4d(oneflow.unittest.TestCase): def test_graph_save_load_gpu_2_split_2_none(test_case): _test_graph_save_load_global_split_4( test_case, model_tensor_placement=flow.placement("cuda", ranks=[0, 1, 2, 3]), model_file_placement=flow.placement("cpu", ranks=[0, 1]), ) @unittest.skip("skip for now, because it failed 24 times in the past week") def test_graph_save_load_cpu_2_split_2_none(test_case): _test_graph_save_load_global_split_4( test_case, model_tensor_placement=flow.placement("cpu", ranks=[0, 1, 2, 3]), model_file_placement=flow.placement("cpu", ranks=[0, 1]), ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_scalar.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest def _test_scalar_graph(test_case, device): x = flow.tensor(3.0, device=device) class MyModule(flow.nn.Module): def __init__(self): super().__init__() self.weight = flow.nn.Parameter(flow.tensor(5.0, device=device)) def forward(self, x): return x * self.weight + 1.0 my_module = MyModule() of_eager_out = my_module(x) class ScalarGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = my_module def build(self, x): return self.m(x) scalar_g = ScalarGraph() of_lazy_out = scalar_g(x) test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())) def _test_scalar_train_graph(test_case, device): class MyModule(flow.nn.Module): def __init__(self): super().__init__() self.weight = flow.nn.Parameter(flow.tensor(5.0, device=device)) def forward(self, x): return x * self.weight + 1.0 my_module = MyModule() of_sgd = flow.optim.SGD(my_module.parameters(), lr=0.001, momentum=0.9) eager_out_list = [] for i in range(3): x = flow.tensor(i * 1.0, device=device, requires_grad=False) of_eager_out = my_module(x) of_eager_out.backward() of_sgd.step() of_sgd.zero_grad() eager_out_list.append(of_eager_out) lazy_module = MyModule() class ScalarTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = lazy_module of_sgd = flow.optim.SGD(lazy_module.parameters(), lr=0.001, momentum=0.9) # self.m = MyModule() # of_sgd = flow.optim.SGD(self.m.parameters(), lr=0.001, momentum=0.9) self.add_optimizer(of_sgd) def build(self, x): loss = self.m(x) loss.backward() return loss lazy_out_list = [] scalar_g = ScalarTrainGraph() for i in range(3): x = flow.tensor(i * 1.0, device=device) of_lazy_out = scalar_g(x) lazy_out_list.append(of_lazy_out) for i in range(3): test_case.assertTrue( np.array_equal(lazy_out_list[i].numpy(), eager_out_list[i].numpy()) ) def _test_scalar_global_train_graph(test_case, placement): sbp_b = flow.sbp.broadcast class MyModule(flow.nn.Module): def __init__(self): super().__init__() self.weight = flow.nn.Parameter(flow.tensor(5.0)) def forward(self, x): return x * self.weight + 1.0 my_module = MyModule() of_sgd = flow.optim.SGD(my_module.parameters(), lr=0.001, momentum=0.9) eager_out_list = [] for i in range(3): x = flow.tensor(i * 1.0, requires_grad=False) of_eager_out = my_module(x) of_eager_out.backward() of_sgd.step() of_sgd.zero_grad() eager_out_list.append(of_eager_out) lazy_module = MyModule() lazy_module.to_global(placement=placement, sbp=sbp_b) class ScalarTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = lazy_module of_sgd = flow.optim.SGD(lazy_module.parameters(), lr=0.001, momentum=0.9) self.add_optimizer(of_sgd) def build(self, x): loss = self.m(x) loss.backward() return loss lazy_out_list = [] scalar_g = ScalarTrainGraph() for i in range(3): x = flow.tensor(i * 1.0, requires_grad=False) x = x.to_global(placement=placement, sbp=sbp_b) of_lazy_out = scalar_g(x) lazy_out_list.append(of_lazy_out) for i in range(3): test_case.assertTrue( np.array_equal( lazy_out_list[i].to_local().numpy(), eager_out_list[i].numpy() ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestScalarGraph(oneflow.unittest.TestCase): def test_scalar_graph_gpu(test_case): _test_scalar_graph(test_case, flow.device("cuda")) def test_scalar_graph_cpu(test_case): _test_scalar_graph(test_case, flow.device("cpu")) def test_scalar_train_graph_gpu(test_case): _test_scalar_train_graph(test_case, 
flow.device("cuda")) def test_scalar_train_graph_cpu(test_case): _test_scalar_train_graph(test_case, flow.device("cpu")) def test_scalar_global_train_graph_gpu(test_case): _test_scalar_global_train_graph(test_case, flow.placement("cuda", ranks=[0])) def test_scalar_global_train_graph_cpu(test_case): _test_scalar_global_train_graph(test_case, flow.placement("cpu", ranks=[0])) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_separate_compile.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import contextlib import os import numpy as np import oneflow as flow from oneflow import nn import oneflow.unittest @contextlib.contextmanager def modified_environ(*remove, **update): """ From: https://stackoverflow.com/questions/2059482/temporarily-modify-the-current-processs-environment Temporarily updates the ``os.environ`` dictionary in-place. The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations. :param remove: Environment variables to remove. :param update: Dictionary of environment variables and values to add/update. """ env = os.environ update = update or {} remove = remove or [] # List of environment variables being updated or removed. stomped = (set(update.keys()) | set(remove)) & set(env.keys()) # Environment variables and values to restore on exit. update_after = {k: env[k] for k in stomped} # Environment variables and values to remove on exit. 
remove_after = frozenset(k for k in update if k not in env) try: env.update(update) [env.pop(k, None) for k in remove] yield finally: env.update(update_after) [env.pop(k) for k in remove_after] def run_testcase_with_sep_compile(test_case_cls): new_cls = type("SeparationCompile_" + test_case_cls.__name__, (test_case_cls,), {}) with modified_environ( ONEFLOW_LAZY_COMPILE_MODE="rank_per_process", ENABLE_LOGICAL_CHAIN="1" ): assert os.environ.get("ONEFLOW_LAZY_COMPILE_MODE") == "rank_per_process" assert os.environ.get("ENABLE_LOGICAL_CHAIN") == "1" flow.boxing.nccl.enable_use_compute_stream(True) unittest.TextTestRunner().run( unittest.TestLoader().loadTestsFromTestCase(new_cls) ) def _get_comb1to2d_test(): class _TestModuleDiffHierarchy(nn.Module): def forward(self, x): sbp_1ds = [ flow.sbp.broadcast, flow.sbp.partial_sum, flow.sbp.split(0), flow.sbp.split(1), flow.sbp.split(2), ] for sbp1 in sbp_1ds: for sbp2 in sbp_1ds: for sbp3 in sbp_1ds: # (2, 2) -> 4 x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(4)) ), sbp=[sbp1], ) # 4 -> (2, 2) x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), sbp=[sbp2, sbp3], ) return x class _TestModuleDiffPlacement(nn.Module): def forward(self, x): sbp_1ds = [ flow.sbp.broadcast, flow.sbp.partial_sum, flow.sbp.split(0), flow.sbp.split(1), flow.sbp.split(2), ] for sbp1 in sbp_1ds: for sbp2 in sbp_1ds: for sbp3 in sbp_1ds: # (2, 2) -> 3 # 4 is not divisible by 3 x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(3)) ), sbp=[sbp1], ) # 3 -> (2, 2) x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), sbp=[sbp2, sbp3], ) return x class _TestGraph(nn.Graph): def __init__(self, model): super().__init__() self.model = model def build(self, x): x = self.model(x) return x @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestSepCompileLazyAllSbpCombinationTesting(flow.unittest.TestCase): def test_lazy_boxing_2d_all_combination_diff_hierarchy(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0" x = flow.ones( 4, 12, 4, sbp=[flow.sbp.broadcast, flow.sbp.broadcast], placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), ) model_diff_hierarchy = _TestModuleDiffHierarchy() graph_diff_hierarchy = _TestGraph(model_diff_hierarchy) y = graph_diff_hierarchy(x) def test_lazy_boxing_2d_all_combination_diff_placement(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0" x = flow.ones( 4, 12, 4, sbp=[flow.sbp.broadcast, flow.sbp.broadcast], placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), ) model_diff_placement = _TestModuleDiffPlacement() graph_diff_placement = _TestGraph(model_diff_placement) z = graph_diff_placement(x) test_case.assertTrue(np.allclose(x.numpy(), z.numpy(), 1e-05, 1e-05)) return TestSepCompileLazyAllSbpCombinationTesting @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class TestSeparationCompile(oneflow.unittest.TestCase): def test_test_alexnet_auto_parallel(test_case): from test_alexnet_auto_parallel import TestAlexnetAutoParallel run_testcase_with_sep_compile(TestAlexnetAutoParallel) def _test_comb1to2d(test_case): 
run_testcase_with_sep_compile(_get_comb1to2d_test()) def test_graph_zero(test_case): from test_graph_zero import TestLinearTrainGraph2DWithZeRO run_testcase_with_sep_compile(TestLinearTrainGraph2DWithZeRO) def test_graph_clip_grad_norm(test_case): from test_graph_clip_grad_norm import TestGraphClipGradNorm run_testcase_with_sep_compile(TestGraphClipGradNorm) def test_graph_pipeline_grad_acc_and_activatioin_checkpointing(test_case): from test_graph_pipeline import TestGraphPipeline run_testcase_with_sep_compile(TestGraphPipeline) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_session_env_destruct.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.unittest linear = flow.nn.Linear(3, 8, False) input_arr = np.random.randn(8, 3).astype(np.float32) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) x = flow.tensor(input_arr) flow.nn.init.constant_(linear.weight, 2.3) of_eager_out = linear(x) np_out = np.matmul(input_arr, np_weight) assert np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05) class LinearGraphDestruct(flow.nn.Graph): def __init__(self): super().__init__() self.my_linear = linear def build(self, x): return self.my_linear(x) linear_g_d = LinearGraphDestruct() @flow.unittest.skip_unless_1n1d() class TestLinearGraphDestruct(oneflow.unittest.TestCase): def test_linear_graph_destruct(test_case): of_lazy_out = linear_g_d(x) assert np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy()) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_session_env_destruct1.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest linear = flow.nn.Linear(3, 8, False) input_arr = np.random.randn(8, 3).astype(np.float32) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) x = flow.tensor(input_arr) flow.nn.init.constant_(linear.weight, 2.3) of_eager_out = linear(x) np_out = np.matmul(input_arr, np_weight) assert np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05) class LinearGraphDestruct1(flow.nn.Graph): def __init__(self): super().__init__() self.my_linear = linear def build(self, x): return self.my_linear(x) # test graph destruction when graph is not compiled linear_g_d_not_compiled = LinearGraphDestruct1() print("test graph destruction when graph is not compiled") if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_sparse_optimizer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import oneflow as flow import oneflow.unittest class MyModule(flow.nn.Module): def __init__(self, placement=None, sbp=None): super().__init__() w = flow.randn(10, 10, placement=placement, sbp=sbp) self.weight = flow.nn.Parameter(w) def forward(self, input): return flow._C.gather(self.weight, input, 0) class MyGraph(flow.nn.Graph): def __init__(self, module): super().__init__() self.m = module sgd = flow.optim.SGD(module.parameters(), lr=1e-3) self.add_optimizer(sgd, is_sparse=True) def build(self, input): result = self.m(input) result.mean().backward() def _rand_input(placement=None, sbp=None): generator = flow.Generator() generator.manual_seed(0) return flow.randint(0, 10, (8,), generator=generator, placement=placement, sbp=sbp) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class GraphSparseOptimizerTest(oneflow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 6 times in past week") def test(test_case): PLC = flow.placement("cuda", ranks=[0]) SBP = flow.sbp.broadcast m = MyModule(PLC, SBP) graph = MyGraph(m) graph._compile(_rand_input(PLC, SBP)) sparse_optimizer_found = False for op in graph._full_graph_proto.net.op: # print("==>", op.name) if op.HasField("user_conf"): # print(" -->", op.user_conf.op_type_name) if op.user_conf.op_type_name == "indexed_slices_sgd_update": sparse_optimizer_found = True break test_case.assertTrue(sparse_optimizer_found) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_sparse_softmax_cross_entropy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.unittest class CrossEntropyModule(flow.nn.Module): def __init__(self, pred): super().__init__() if pred.is_global: self.param = flow.nn.Parameter( flow.zeros( *pred.shape, dtype=pred.dtype, placement=pred.placement, sbp=pred.sbp, ) ) else: self.param = flow.nn.Parameter( flow.zeros(*pred.shape, dtype=pred.dtype, device=pred.device) ) def forward(self, pred, label): pred = pred + self.param loss = flow._C.sparse_softmax_cross_entropy(pred, label) return loss.mean() class CrossEntropyGraph(flow.nn.Graph): def __init__(self, module): super().__init__() self.m = module self.add_optimizer(flow.optim.SGD([module.param], lr=1.0, momentum=0.0)) def build(self, pred, label): loss = self.m(pred, label) loss.backward() return loss def _compare_with_nn_cross_entropy_loss( test_case, pred, label, pred_sbp=None, label_sbp=None ): if pred.is_global: assert label.is_global pred_ = pred.to_local().detach().clone() label_ = label.to_local() else: pred_ = pred.detach().clone() label_ = label pred_.requires_grad = True cross_entropy_loss = flow.nn.CrossEntropyLoss() loss = cross_entropy_loss(pred_, label_) loss.backward() if pred_sbp is not None: pred = pred.to_global(sbp=pred_sbp) if label_sbp is not None: label = label.to_global(sbp=label_sbp) cross_entropy_module = CrossEntropyModule(pred) cross_entropy_graph = CrossEntropyGraph(cross_entropy_module) graph_loss = cross_entropy_graph(pred, label) loss_a = loss.numpy() grad_a = pred_.grad.numpy() if graph_loss.is_local: loss_b = graph_loss.numpy() grad_b = -cross_entropy_module.param.numpy() else: graph_loss = graph_loss.to_global( sbp=[flow.sbp.broadcast()] * len(graph_loss.sbp) ) loss_b = graph_loss.to_local().numpy() pred_grad = cross_entropy_module.param.to_global( sbp=[flow.sbp.broadcast()] * len(cross_entropy_module.param.sbp) ) grad_b = -pred_grad.to_local().numpy() test_case.assertTrue(np.allclose(loss_a, loss_b), f"{loss_a} vs. 
{loss_b}") test_case.assertTrue(np.allclose(grad_a, grad_b), f"\n{grad_a}\nvs.\n{grad_b}") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestSparseSoftmaxCrossEntropyGraph(oneflow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_local(test_case): pred = flow.randn(8, 10).to("cuda") label = flow.randint(0, 10, (8,)).to("cuda") _compare_with_nn_cross_entropy_loss(test_case, pred, label) @flow.unittest.skip_unless_1n2d() def test_data_split(test_case): pred = flow.randn(8, 10) label = flow.randint(0, 10, (8,)) placement = flow.placement("cuda", list(range(flow.env.get_world_size()))) pred = pred.to_global(placement=placement, sbp=flow.sbp.broadcast()) label = label.to_global(placement=placement, sbp=flow.sbp.broadcast()) _compare_with_nn_cross_entropy_loss( test_case, pred, label, flow.sbp.split(0), flow.sbp.split(0) ) @flow.unittest.skip_unless_1n2d() def test_model_split(test_case): pred = flow.randn(8, 10) label = flow.randint(0, 10, (8,)) placement = flow.placement("cuda", list(range(flow.env.get_world_size()))) pred = pred.to_global(placement=placement, sbp=flow.sbp.broadcast()) label = label.to_global(placement=placement, sbp=flow.sbp.broadcast()) _compare_with_nn_cross_entropy_loss( test_case, pred, label, flow.sbp.split(1), flow.sbp.broadcast() ) @flow.unittest.skip_unless_1n4d() def test_2d_split(test_case): pred = flow.randn(8, 10) label = flow.randint(0, 10, (8,)) placement = flow.placement( "cuda", np.array(range(flow.env.get_world_size())).reshape(2, 2) ) pred = pred.to_global( placement=placement, sbp=[flow.sbp.broadcast(), flow.sbp.broadcast()] ) label = label.to_global( placement=placement, sbp=[flow.sbp.broadcast(), flow.sbp.broadcast()] ) _compare_with_nn_cross_entropy_loss( test_case, pred, label, [flow.sbp.split(0), flow.sbp.split(1)], [flow.sbp.split(0), flow.sbp.broadcast()], ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_tensor_clone.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestTensorCloneGraph(oneflow.unittest.TestCase): def test_tensor_clone_graph(test_case): class TensorCloneGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): y = x.clone() x += x return x, y x = flow.randn(3, 4) res = TensorCloneGraph()(x) test_case.assertTrue(len(res) == 2) test_case.assertTrue(np.allclose(res[0], res[1] * 2, 1e-05, 1e-05)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_tensor_detach.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestTensorDetachGraph(oneflow.unittest.TestCase): def test_tensor_detach_graph(test_case): class TensorDetachGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): x += x y = x.detach() return x, y x = flow.randn(3, 4) res = TensorDetachGraph()(x) test_case.assertTrue(len(res) == 2) test_case.assertTrue(np.allclose(res[0], res[1], 1e-05, 1e-05)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_with_global.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import numpy as np import oneflow as flow import oneflow.unittest from oneflow.nn.graph import GraphModule import oneflow.utils.global_view as global_view from oneflow.utils.global_view import global_mode def _test_linear_train_graph_with_ddp(test_case): def train_with_graph(iter_num=1): PC = flow.placement("cpu", ranks=[0, 1]) P = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast S0 = flow.sbp.split(0) linear_dp = flow.nn.Linear(800, 400, bias=False) linear_dp = linear_dp.to_global(placement=P, sbp=B) flow.nn.init.constant_(linear_dp.weight, 2.068758) of_sgd = flow.optim.SGD( [{"params": linear_dp.parameters()}], lr=0.001, momentum=0.9, ) x = flow.ones((6, 800), placement=PC, sbp=S0) class LinearTrainGraphWithDDP(flow.nn.Graph): def __init__(self): super().__init__() self.linear_dp = linear_dp self.add_optimizer(of_sgd) def build(self, x): x = x.to_global(placement=P) out = self.linear_dp(x) loss = out.sum() loss.backward() return out class LinearEvalGraphWithDDP(flow.nn.Graph): def __init__(self): super().__init__() self.linear_dp = linear_dp def build(self, x): x = x.to_global(placement=P) out = self.linear_dp(x) return out linear_t_g = LinearTrainGraphWithDDP() # linear_t_g.debug(1) linear_e_g = LinearEvalGraphWithDDP() # linear_e_g.debug(1) result_check_list = [] def one_train_iter(iter_cnt=0): out = linear_t_g(x) result_check_list.append(out) # if iter_cnt == 0: # if flow.env.get_rank() == 0: # import traceback # try: # print(linear_t_g) # except: # print(traceback.format_exc()) def one_eval_iter(iter_cnt=0): out = linear_e_g(x) result_check_list.append(out) for i in range(iter_num): one_train_iter(i) # In evaluation graph, paramters's sbp are flow.sbp.split(0). # But their consumer will consum them as flow.sbp.broadcast. 
one_eval_iter() return result_check_list def train_with_graph_ddp(iter_num=1): PC = flow.placement("cpu", ranks=[0, 1]) P = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast S0 = flow.sbp.split(0) linear_dp = flow.nn.Linear(800, 400, bias=False) linear_dp = linear_dp.to_global(placement=P, sbp=B) flow.nn.init.constant_(linear_dp.weight, 2.068758) of_sgd = flow.optim.SGD( [{"params": linear_dp.parameters()}], lr=0.001, momentum=0.9, ) with global_mode(True, placement=PC, sbp=S0): x = flow.ones((6, 800), placement=PC, sbp=S0) class LinearTrainGraphWithDDP(flow.nn.Graph): def __init__(self): super().__init__() self.linear_dp = linear_dp self.add_optimizer(of_sgd) def build(self, x): # This is ok # x = x.to("cuda") # This is ok # x = x.to_global(placement=P) # This is not ok # x = x.to(device) with global_mode(True, placement=P, sbp=B): # Test global tensor to device device = self.linear_dp.weight.device x = x.to(device) out = self.linear_dp(x) # Test randn source op sample = flow.randn(out.shape, device="cpu").to(device) out = out + sample * 100 # Test disable global_mode while passing placement and sbp with global_mode(False, placement=P, sbp=B): out = out - sample * 100 cur_global_mode = global_view.current_global_mode() test_case.assertFalse(cur_global_mode.is_enabled) loss = out.sum() loss.backward() return out class LinearEvalGraphWithDDP(flow.nn.Graph): def __init__(self): super().__init__() self.linear_dp = linear_dp def build(self, x): with global_mode(True, placement=P, sbp=B): device = self.linear_dp.weight.device x = x.to(device) out = self.linear_dp(x) # Test randn source op sample = flow.randn(out.shape, device="cpu").to(device) out = out + sample * 100 out = out - sample * 100 return out linear_t_g = LinearTrainGraphWithDDP() # linear_t_g.debug(1) linear_e_g = LinearEvalGraphWithDDP() # linear_e_g.debug(1) result_check_list = [] def one_train_iter(iter_cnt=0): out = linear_t_g(x) result_check_list.append(out) # if iter_cnt == 0: # if flow.env.get_rank() == 0: # import traceback # try: # print(linear_t_g) # except: # print(traceback.format_exc()) def one_eval_iter(iter_cnt=0): out = linear_e_g(x) result_check_list.append(out) for i in range(iter_num): one_train_iter(i) # In evaluation graph, paramters's sbp are flow.sbp.split(0). # But their consumer will consum them as flow.sbp.broadcast. 
one_eval_iter() return result_check_list iter_num = 2 graph_check_list = train_with_graph(iter_num) graph_ddp_check_list = train_with_graph_ddp(iter_num) test_case.assertEqual(len(graph_check_list), iter_num + 1) test_case.assertEqual(len(graph_ddp_check_list), iter_num + 1) for i in range(iter_num + 1): test_case.assertTrue( np.allclose( graph_check_list[i].numpy(), graph_ddp_check_list[i].numpy(), rtol=1e-5, atol=1e-5, ), f"current index {i} \n base {graph_check_list[i].numpy()} \n ddp {graph_ddp_check_list[i].numpy()} \n diff {graph_ddp_check_list[i].numpy() - graph_check_list[i].numpy()}", ) def _test_global_mode(test_case): P = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast class GlobalModeGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self): with global_mode(True, placement=P, sbp=B): # Test global mode meta data cur_global_mode = global_view.current_global_mode() test_case.assertTrue(cur_global_mode.is_enabled) test_case.assertEqual(cur_global_mode.placement, P) test_case.assertEqual(cur_global_mode.sbp[0], B) # Test global mode source op randn_out = flow.randn((2, 2)) rand_out = flow.rand((2, 2)) randint_out = flow.randint(-100, 100, (2, 2)) randperm_out = flow.randperm(5) arange_out = flow.arange(10) empty_out = flow.empty((1, 2)) tensor_out = flow.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=flow.int) hann_window_out = flow.hann_window(8, dtype=flow.float) test_case.assertTrue(not global_view.current_global_mode().is_enabled) return { "randn_out": randn_out, "rand_out": rand_out, "randint_out": randint_out, "randperm_out": randperm_out, "arange_out": arange_out, "empty_out": empty_out, "tensor_out": tensor_out, "hann_window_out": hann_window_out, } global_graph = GlobalModeGraph() out = global_graph() for k, v in out.items(): test_case.assertEqual(v.is_global, True, k) test_case.assertEqual(v.placement, P, k) test_case.assertEqual(v.sbp[0], B, k) def _test_global_mode_with_default_placement_and_sbp(test_case): # create a tensor with broadcast split and placement on rank 0 a = flow.randn( (1, 8), sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0]) ) # enter global mode with broadcast split and placement on 2 GPUs with global_mode( True, placement=flow.placement(type="cuda", ranks=[0, 1]), sbp=flow.sbp.broadcast, ): # check tensor placement and split test_case.assertTrue(a.placement == flow.placement("cuda", ranks=[0])) test_case.assertTrue(a.sbp == (flow.sbp.broadcast,)) # check tensor print print(a) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestLinearTrainGraphWithDDP(oneflow.unittest.TestCase): def test_linear_train_graph_with_ddp(test_case): _test_linear_train_graph_with_ddp(test_case) @unittest.skip("skip for now, becase it failed 4 times in past week") def test_global_mode(test_case): _test_global_mode(test_case) _test_global_mode_with_default_placement_and_sbp(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_graph_zero.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import numpy as np import oneflow as flow import oneflow.unittest from oneflow.nn.graph import GraphModule def _test_linear_train_graph_with_zero(test_case, zero_stage=1): def train_with_graph(iter_num=1): P = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast S0 = flow.sbp.split(0) linear_dp = flow.nn.Linear(800, 400, bias=False) linear_dp = linear_dp.to_global(placement=P, sbp=B) flow.nn.init.constant_(linear_dp.weight, 2.068758) linear_mp = flow.nn.Linear(400, 500, bias=False) linear_mp = linear_mp.to_global(placement=P, sbp=S0) flow.nn.init.constant_(linear_mp.weight, 2.068758) of_sgd = flow.optim.SGD( [{"params": linear_dp.parameters()}, {"params": linear_mp.parameters()}], lr=0.001, momentum=0.9, ) grad_scaler = flow.amp.StaticGradScaler(200) x = flow.randint(1, 100, (6, 800), dtype=flow.float32, placement=P, sbp=S0) class LinearTrainGraphWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() self.linear_dp = linear_dp self.linear_mp = linear_mp self.add_optimizer(of_sgd) self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) self.config.enable_zero( True, stage=zero_stage, shard_min_size=1, shard_restore_level=0, ) self.debug(2) def build(self, x): out = self.linear_dp(x) out = out.to_global(placement=P, sbp=B) out = self.linear_mp(out) loss = out.sum() loss.backward() return out class LinearEvalGraphWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() self.linear_dp = linear_dp self.linear_mp = linear_mp self.config.enable_amp(True) def build(self, x): out = self.linear_dp(x) out = out.to_global(placement=P, sbp=B) out = self.linear_mp(out) return out linear_t_g = LinearTrainGraphWithZeRO() linear_t_g.debug(1) linear_e_g = LinearEvalGraphWithZeRO() linear_e_g.debug(1) def one_train_iter(): out = linear_t_g(x) if flow.env.get_rank() == 0: import traceback try: print(linear_t_g) except: print(traceback.format_exc()) def one_eval_iter(): out = linear_e_g(x) for i in range(iter_num): one_train_iter() # After pass rewrite in training graph, parameters' sbp has been # changed from flow.sbp.broadcast to flow.sbp.split(0) test_case.assertEqual(linear_dp.weight.sbp[0], S0) test_case.assertEqual(linear_mp.weight.sbp[0], S0) # In evaluation graph, parameter's sbp are flow.sbp.split(0). # But their consumer will consume them as flow.sbp.broadcast. 
one_eval_iter() iter_num = 1 graph_check_list = train_with_graph(iter_num) def _test_linear_train_graph_2d_with_zero(test_case, zero_stage=1): def train_with_graph(iter_num=1): P = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) B = flow.sbp.broadcast S0 = flow.sbp.split(0) S1 = flow.sbp.split(1) def get_mixed_linear(): linear_dp_mp = flow.nn.Linear(800, 400, bias=False) linear_dp_mp = linear_dp_mp.to_global(placement=P, sbp=[B, S0]) flow.nn.init.constant_(linear_dp_mp.weight, 1.068758) linear_mp_dp = flow.nn.Linear(800, 400, bias=False) linear_mp_dp = linear_mp_dp.to_global(placement=P, sbp=[S0, B]) flow.nn.init.constant_(linear_mp_dp.weight, 1.068758) class MixedLinear(flow.nn.Module): def __init__(self): super().__init__() self.dp_mp = linear_dp_mp self.mp_dp = linear_mp_dp def forward(self, x): x = self.dp_mp(x) x = flow.relu(x) x = self.mp_dp(x) x = flow.relu(x) return x return MixedLinear() mixed_linear0 = get_mixed_linear() mixed_linear1 = get_mixed_linear() of_sgd = flow.optim.SGD( [ {"params": mixed_linear0.parameters()}, {"params": mixed_linear1.parameters()}, ], lr=0.001, momentum=0.9, ) grad_scaler = flow.amp.StaticGradScaler(200) x = flow.rand((2, 800), dtype=flow.float32, placement=P, sbp=[S0, B]) class LinearTrainGraph2DWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() self.mixed_linear0 = mixed_linear0 self.mixed_linear0.to(GraphModule).activation_checkpointing = True self.mixed_linear1 = mixed_linear1 self.mixed_linear1.to(GraphModule).activation_checkpointing = True self.add_optimizer(of_sgd) self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) self.config.enable_zero( True, stage=zero_stage, shard_min_size=1, shard_restore_level=1, ) def build(self, x): out = self.mixed_linear0(x) out = self.mixed_linear1(out) loss = out.mean() loss.backward() return loss class LinearEvalGraph2DWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() self.mixed_linear0 = mixed_linear0 self.mixed_linear1 = mixed_linear1 self.config.enable_amp(True) def build(self, x): out = self.mixed_linear0(x) out = self.mixed_linear1(out) return out linear_t_g = LinearTrainGraph2DWithZeRO() linear_e_g = LinearEvalGraph2DWithZeRO() def one_train_iter(): out = linear_t_g(x) # if flow.env.get_rank() == 0: # print(linear_t_g) def one_eval_iter(): out = linear_e_g(x) for i in range(iter_num): one_train_iter() for state in linear_t_g._state(): test_case.assertEqual( state.to(flow.Tensor).sbp, (oneflow.sbp.split(dim=0), oneflow.sbp.split(dim=0)), ) # In evaluation graph, paramters's sbp are flow.sbp.split(0). # But their consumer will consum them as flow.sbp.broadcast. 
one_eval_iter() iter_num = 1 graph_check_list = train_with_graph(iter_num) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestLinearTrainGraphWithZeRO(oneflow.unittest.TestCase): def test_linear_train_graph_with_zero_1(test_case): _test_linear_train_graph_with_zero(test_case, 1) def test_linear_train_graph_with_zero_2(test_case): _test_linear_train_graph_with_zero(test_case, 2) def test_linear_train_graph_with_zero_3(test_case): _test_linear_train_graph_with_zero(test_case, 3) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class TestLinearTrainGraph2DWithZeRO(oneflow.unittest.TestCase): def test_linear_train_graph_2d_with_zero_3(test_case): _test_linear_train_graph_2d_with_zero(test_case, 3) def test_linear_train_graph_2d_with_zero_2(test_case): _test_linear_train_graph_2d_with_zero(test_case, 2) def test_linear_train_graph_2d_with_zero_1(test_case): _test_linear_train_graph_2d_with_zero(test_case, 1) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_input_op_expr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np from google.protobuf import text_format import oneflow import oneflow as flow import oneflow._oneflow_internal import oneflow._oneflow_internal._C as _C import oneflow.framework.c_api_util as c_api_util import oneflow.framework.session_context as session_ctx import oneflow.unittest from oneflow.framework.multi_client_session import MultiClientSession @flow.unittest.skip_unless_1n1d() class TestFeedInputTensor(unittest.TestCase): def test_feed_input_tensor(test_case): x = flow.Tensor(1, 1, 10, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) session = session_ctx.GetDefaultSession() test_case.assertTrue(isinstance(session, MultiClientSession)) session.TryInit() with oneflow._oneflow_internal.lazy_mode.guard(True): oneflow._oneflow_internal.JobBuildAndInferCtx_Open( "cc_test_input_op_expr_job" ) job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto() job_conf.job_name = "cc_test_input_op_expr_job" job_conf.predict_conf.SetInParent() c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf) op_name = "cc_Input_0" input_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf() input_conf.in_0 = "EagerTensorInput" input_conf.out_0 = "out_0" input_conf_str = text_format.MessageToString(input_conf) input_op = oneflow._oneflow_internal.one.FeedInputOpExpr( op_name, input_conf_str, ["in_0"], ["out_0"] ) out_tensor = _C.dispatch_feed_input(input_op, x) test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10)) test_case.assertTrue(out_tensor.is_lazy) test_case.assertTrue(out_tensor.is_local) oneflow._oneflow_internal.JobBuildAndInferCtx_Close() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_long_add_n_pass.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import argparse import numpy as np import os import time import unittest import oneflow as flow import oneflow.unittest def _test_long_add_n_graph(test_case, device): input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) x0 = flow.tensor(input_arr, device=device) x1 = flow.tensor(input_arr, device=device) x2 = flow.tensor(input_arr, device=device) x3 = flow.tensor(input_arr, device=device) x4 = flow.tensor(input_arr, device=device) x5 = flow.tensor(input_arr, device=device) x6 = flow.tensor(input_arr, device=device) x7 = flow.tensor(input_arr, device=device) x8 = flow.tensor(input_arr, device=device) x9 = flow.tensor(input_arr, device=device) class AddNGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self): # Deprecated `temp = x0 + x0` to avoid unstable test # enable this after fix https://github.com/Oneflow-Inc/oneflow/issues/9431 # temp = x0 + x0 temp = x0 temp = temp + x1 # test add_n1(add_n0(...), ...) temp = temp + temp # test add_n1(add_n0(...), add_n0(...)) temp = temp + x2 temp = temp + x3 temp = temp + x4 temp = temp + x5 temp = temp + x6 temp = temp + x7 other_add_n = x8 + x9 temp = temp + other_add_n # test add_n2(add_n0(), add_n1()) return temp add_n_g = AddNGraph() of_lazy_out = add_n_g() test_case.assertTrue(np.allclose(input_arr * 12, of_lazy_out.numpy(), 1e-05, 1e-05)) def _test_add_n_consume_multi_add_n_graph(test_case, device): input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) x0 = flow.tensor(input_arr, device=device) class AddNGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self): temp = x0 + x0 temp = temp + temp return temp add_n_g = AddNGraph() of_lazy_out = add_n_g() test_case.assertTrue(np.allclose(input_arr * 4, of_lazy_out.numpy(), 1e-05, 1e-05)) @unittest.skip("fail on ci") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLongAddNGraph(oneflow.unittest.TestCase): def test_add_n(test_case): device = "cuda" _test_long_add_n_graph(test_case, device) def test_consume_multi_add_n(test_case): device = "cuda" _test_add_n_consume_multi_add_n_graph(test_case, device) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_modify_module_forward.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import argparse import numpy as np import os import time import unittest from types import MethodType import oneflow as flow import oneflow.unittest from oneflow import nn @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestModifyForwardOfModule(oneflow.unittest.TestCase): def test_modify_forward(test_case): def forward2(self, x): return x + 1 class Model1(nn.Module): def __init__(self): super().__init__() def forward(self, x): return x class ForwardModifiedGraph(nn.Graph): def __init__(self, model): super().__init__() self.model = model self.model.eval() def build(self, x): return self.model(x) test_model = Model1() test_model.forward = MethodType(forward2, test_model) eval_graph_model1 = ForwardModifiedGraph(model=test_model) input_tensor = flow.tensor([0.0], requires_grad=True) eager_out = test_model(input_tensor) graph_out = eval_graph_model1(input_tensor) test_case.assertTrue(np.array_equal(graph_out.numpy(), eager_out.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_multi_client_session.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import oneflow import oneflow as flow import oneflow.framework.session_context as session_ctx import oneflow.unittest from oneflow.framework.multi_client_session import MultiClientSession @flow.unittest.skip_unless_1n1d() class TestMultiClientSession(unittest.TestCase): def test_case1(self): sess = session_ctx.GetDefaultSession() self.assertTrue(isinstance(sess, MultiClientSession)) sess.TryInit() self.assertEqual(sess.status, sess.Status.INITED) def test_case2(self): print("test_case2") sess = session_ctx.GetDefaultSession() self.assertTrue(isinstance(sess, MultiClientSession)) sess.TryInit() self.assertEqual(sess.status, sess.Status.INITED) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_multi_graph.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestMultiGraph(oneflow.unittest.TestCase): def test_multi_graph(test_case): relu_data = np.array([2.0, 1.0, 0.0, -1.0, -2.0]) relu_in = flow.tensor(relu_data, dtype=flow.float32) MyRelu = flow.nn.ReLU() relu_out_eager = MyRelu(relu_in) class ReluGraph(flow.nn.Graph): def __init__(self): super().__init__() self.cc_relu = MyRelu def build(self, x): return self.cc_relu(x) relu_g = ReluGraph() relu_out_lazy = relu_g(relu_in) test_case.assertTrue( np.array_equal(relu_out_lazy.numpy(), relu_out_eager.numpy()) ) linear = flow.nn.Linear(3, 8, False) linear = linear.to(flow.device("cuda")) input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) linear_in = flow.tensor(input_arr, device=flow.device("cuda")) flow.nn.init.constant_(linear.weight, 2.3) linear_out_eager = linear(linear_in) np_out = np.matmul(input_arr, np_weight) test_case.assertTrue( np.allclose(linear_out_eager.numpy(), np_out, 1e-05, 1e-05) ) class LinearGraph(flow.nn.Graph): def __init__(self): super().__init__() self.my_linear = linear def build(self, x): return self.my_linear(x) linear_g = LinearGraph() linear_out_lazy = linear_g(linear_in) test_case.assertTrue( np.array_equal(linear_out_lazy.numpy(), linear_out_eager.numpy()) ) relu_out_lazy = relu_g(relu_in) linear_out_lazy = linear_g(linear_in) test_case.assertTrue( np.array_equal(relu_out_eager.numpy(), relu_out_lazy.numpy()) ) test_case.assertTrue( np.array_equal(linear_out_eager.numpy(), linear_out_lazy.numpy()) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_multi_tensor_adam_update_with_cast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import copy import os from test_util import GenArgList import oneflow as flow def compare_with_numpy_adam( test_case, device, x_shape, tensor_num, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, ): random_weight_seq = [] init_value_seq = [] for _ in range(train_iters): random_grad_seq_per_iter = [] for i in range(tensor_num): random_grad_seq_per_iter.append( np.random.uniform(size=x_shape).astype(np.float32) ) random_weight_seq.append(random_grad_seq_per_iter) for i in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.add_parameters() def add_parameters(self) -> None: for idx in range(tensor_num): self.register_parameter( f"param_{idx}", flow.nn.Parameter( flow.tensor(init_value_seq[idx], device=flow.device(device)) ), ) def param(self, i): return getattr(self, f"param_{i}") def forward(self, mask_list): out = 0 for idx in range(tensor_num): out += flow._C.matmul(self.param(idx), mask_list[idx]) return out simp_module = CustomModule() simp_module.to(device) simp_module.train() adam0 = flow.optim.Adam( [ { "params": simp_module.parameters(), "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, }, ], do_bias_correction=do_bias_correction, amsgrad=amsgrad, ) class CustomAdamGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(adam0) self.config.enable_amp(True) self.config.allow_fuse_model_update_ops(True) self.config.enable_multi_tensor_update(True) self.config.enable_fused_model_update_cast(True) def build(self, mask_tensor_list): loss = flow.sum(self.m(mask_tensor_list)) loss.backward() return loss of_res_list = [] adam_graph = CustomAdamGraph() for i in range(train_iters): mask_tensor_list = [] for idx in range(tensor_num): mask_tensor_list.append( flow.tensor( random_weight_seq[i][idx], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) ) adam_x = adam_graph(mask_tensor_list) of_res_list.append([]) for idx in range(tensor_num): of_res_list[i].append(copy.copy(simp_module.param(idx).numpy())) np_res_list = [] def train_by_numpy(): x = init_value_seq m = [] v = [] for idx in range(tensor_num): m.append(np.zeros_like(x[idx])) v.append(np.zeros_like(x[idx])) beta1 = betas[0] beta2 = betas[1] ones = np.ones(x_shape).astype(np.float32) def train_one_iter(step, weight): for i in range(tensor_num): transposed_weight = np.transpose(weight[i], (1, 0)) grad = np.matmul(ones, transposed_weight) grad = grad + weight_decay * x[i] bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) m[i] = beta1 * m[i] + (1 - beta1) * grad v[i] = beta2 * v[i] + (1 - beta2) * grad * grad denom = np.sqrt(v[i]) / np.sqrt(bias_correction2) + eps x[i] = x[i] - ((learning_rate / bias_correction1) * m[i] / denom) return (x, m, v) for i in range(1, train_iters + 1): x, m, v = train_one_iter(i, random_weight_seq[i - 1]) np_res_list.append(copy.copy(x)) train_by_numpy() for i in range(tensor_num): test_case.assertTrue( np.allclose(np_res_list[i], of_res_list[i], rtol=1e-3, atol=1e-3) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestMultiTensorAdam(flow.unittest.TestCase): def test_multi_tensor_adam(test_case): arg_dict = 
OrderedDict() arg_dict["device"] = ["cuda"] arg_dict["x_shape"] = [(4, 4)] arg_dict["tensor_num"] = [4, 6] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [0.0, 1e-3] arg_dict["eps"] = [1e-5] arg_dict["do_bias_correction"] = [True, False] arg_dict["amsgrad"] = [False] # Multi tensor update do not support amsgrad for arg in GenArgList(arg_dict): compare_with_numpy_adam(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_multi_tensor_sgd_update_with_cast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import copy import os from test_util import GenArgList import oneflow as flow def compare_with_numpy_sgd( test_case, device, x_shape, tensor_num, learning_rate, train_iters, weight_decay ): random_weight_seq = [] init_value_seq = [] for _ in range(train_iters): random_grad_seq_per_iter = [] for i in range(tensor_num): random_grad_seq_per_iter.append( np.random.uniform(size=x_shape).astype(np.float32) ) random_weight_seq.append(random_grad_seq_per_iter) for i in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.add_parameters() def add_parameters(self) -> None: for idx in range(tensor_num): self.register_parameter( f"param_{idx}", flow.nn.Parameter( flow.tensor(init_value_seq[idx], device=flow.device(device)) ), ) def param(self, i): return getattr(self, f"param_{i}") def forward(self, mask_list): out = 0 for idx in range(tensor_num): out += flow._C.matmul(self.param(idx), mask_list[idx]) return out simp_module = CustomModule() simp_module.to(device) simp_module.train() sgd0 = flow.optim.SGD( [ { "params": simp_module.parameters(), "lr": learning_rate, "weight_decay": weight_decay, } ], ) class CustomSGDGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = simp_module self.add_optimizer(sgd0) self.config.enable_amp(True) self.config.allow_fuse_model_update_ops(True) self.config.enable_multi_tensor_update(True) self.config.enable_fused_model_update_cast(True) def build(self, mask_tensor_list): loss = flow.sum(self.m(mask_tensor_list)) loss.backward() return loss of_res_list = [] sgd_graph = CustomSGDGraph() for i in range(train_iters): mask_tensor_list = [] for idx in range(tensor_num): mask_tensor_list.append( flow.tensor( random_weight_seq[i][idx], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) ) sgd_x = sgd_graph(mask_tensor_list) of_res_list.append([]) for idx in range(tensor_num): of_res_list[i].append(copy.copy(simp_module.param(idx).numpy())) np_res_list = [] def train_by_numpy(): x = init_value_seq ones = np.ones(x_shape).astype(np.float32) def train_one_iter(weight): for i in range(tensor_num): 
transposed_weight = np.transpose(weight[i], (1, 0)) grad = np.matmul(ones, transposed_weight) grad = grad + weight_decay * x[i] x[i] = x[i] - learning_rate * grad return x for i in range(train_iters): x = train_one_iter(random_weight_seq[i]) np_res_list.append(copy.copy(x)) train_by_numpy() for i in range(tensor_num): test_case.assertTrue( np.allclose(np_res_list[i], of_res_list[i], rtol=1e-3, atol=1e-3) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestMultiTensorSGD(flow.unittest.TestCase): def test_multi_tensor_sgd(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] arg_dict["x_shape"] = [(4, 4)] arg_dict["tensor_num"] = [4, 6] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["weight_decay"] = [0.0, 1e-3] for arg in GenArgList(arg_dict): compare_with_numpy_sgd(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_nccl_logical_send_recv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList import time import os def _test_nccl_logical_send_recv_2d(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() in dst_nd_sbp: return # skip src == dst if src_nd_sbp == dst_nd_sbp: return # in this case, use intra group boxing if src_nd_sbp[0] == dst_nd_sbp[0]: return # in this case, use inter group boxing if ( src_nd_sbp[1] == dst_nd_sbp[1] and src_nd_sbp[0] != src_nd_sbp[1] and dst_nd_sbp[0] != dst_nd_sbp[1] ): return # input placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) local_np = np.arange(4 * 4 * 4).reshape(4, 4, 4) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement) # check eager boxing eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) # check graph boxing flow.boxing.nccl.enable_use_compute_stream(True) class TestNcclLogicalSendRecv2DGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): y = x.to_global(sbp=dst_nd_sbp, placement=placement) return y graph = TestNcclLogicalSendRecv2DGraph() # graph.debug() y = graph(x) out_np = y.numpy() in_np = x.numpy() # if flow.env.get_rank() == 0: # print("src sbp ", src_nd_sbp, ", dst sbp ", dst_nd_sbp) # equal = np.array_equal(out_np, in_np) # if not equal: # print("in ", in_np) # print("out ", out_np) test_case.assertTrue(np.array_equal(out_np, in_np)) flow.boxing.nccl.enable_use_compute_stream(False) def gen_2d_sbp(): sbp_list = [ flow.sbp.partial_sum(), flow.sbp.broadcast(), flow.sbp.split(0), flow.sbp.split(1), flow.sbp.split(2), ] nd_sbp_list = [] for sbp0 in sbp_list: for sbp1 in sbp_list: nd_sbp_list.append([sbp0, sbp1]) return 
nd_sbp_list @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestNcclLogicalSendRecv2D(flow.unittest.TestCase): def test_nccl_logical_send_recv_2d(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "1" arg_dict = OrderedDict() arg_dict["src_nd_sbp"] = gen_2d_sbp() arg_dict["dst_nd_sbp"] = gen_2d_sbp() for arg in GenArgList(arg_dict): _test_nccl_logical_send_recv_2d(test_case, *arg) def _test_nccl_logical_send_recv_1d(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() in dst_nd_sbp: return # skip src == dst if src_nd_sbp == dst_nd_sbp: return # input placement = flow.placement("cuda", ranks=[0, 1]) local_np = np.arange(2 * 2 * 2).reshape(2, 2, 2) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement) # check eager boxing eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) # check graph boxing flow.boxing.nccl.enable_use_compute_stream(True) class TestNcclLogicalSendRecv1DGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, x): y = x.to_global(sbp=dst_nd_sbp, placement=placement) return y graph = TestNcclLogicalSendRecv1DGraph() # graph.debug(0) y = graph(x) out_np = y.numpy() in_np = x.numpy() # if flow.env.get_rank() == 0: # print("src sbp ", src_nd_sbp, ", dst sbp ", dst_nd_sbp) # print(graph) # equal = np.array_equal(out_np, in_np) # if not equal: # print("in ", in_np) # print("out ", out_np) # print("====================") test_case.assertTrue(np.array_equal(out_np, in_np)) def gen_1d_sbp(): sbp_list = [ flow.sbp.partial_sum(), flow.sbp.broadcast(), flow.sbp.split(0), flow.sbp.split(1), flow.sbp.split(2), ] nd_sbp_list = [] for sbp0 in sbp_list: nd_sbp_list.append( [sbp0,] ) return nd_sbp_list @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestNcclLogicalSendRecv1D(flow.unittest.TestCase): def test_nccl_logical_send_recv_1d(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "1" arg_dict = OrderedDict() arg_dict["src_nd_sbp"] = gen_1d_sbp() arg_dict["dst_nd_sbp"] = gen_1d_sbp() for arg in GenArgList(arg_dict): _test_nccl_logical_send_recv_1d(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_neq_device_process_num.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow import oneflow as flow import oneflow.unittest import oneflow.sysconfig from oneflow.nn.graph import GraphModule @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGraphNeqDeviceProcessNum(flow.unittest.TestCase): def test_graph_process_num_greater_than_device(test_case): # NOTE(chengcheng): this test case is ONLY for 1n8d in 4d env. 
if not (flow.env.get_node_size() == 1 and flow.env.get_world_size() == 8): return if not oneflow.sysconfig.has_rpc_backend_grpc(): return BATCH_SIZE = 64 BROADCAST = [flow.sbp.broadcast] P0 = flow.placement("cpu", ranks=[0, 1, 2, 3]) P1 = flow.placement("cpu", ranks=[4, 5, 6, 7]) class Stage0Module(flow.nn.Module): def __init__(self): super().__init__() self.flatten = flow.nn.Flatten() self.linear0 = flow.nn.Linear(28 * 28, 512) self.relu0 = flow.nn.ReLU() def forward(self, x): out = self.flatten(x) out = self.linear0(out) out = self.relu0(out) return out class Stage1Module(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(512, 512) self.relu1 = flow.nn.ReLU() self.linear2 = flow.nn.Linear(512, 10) self.relu2 = flow.nn.ReLU() def forward(self, x): out = self.linear1(x) out = self.relu1(out) out = self.linear2(out) out = self.relu2(out) return out class PipelineModule(flow.nn.Module): def __init__(self): super().__init__() self.m_stage0 = Stage0Module() self.m_stage1 = Stage1Module() self.m_stage0.to_global(placement=P0, sbp=BROADCAST) self.m_stage1.to_global(placement=P1, sbp=BROADCAST) def forward(self, x): out_stage0 = self.m_stage0(x) in_stage1 = out_stage0.to_global(placement=P1, sbp=flow.sbp.split(0)) out_stage1 = self.m_stage1(in_stage1) return out_stage1 module_pipeline = PipelineModule() sgd = flow.optim.SGD(module_pipeline.parameters(), lr=0.001) class PipelineGraph(flow.nn.Graph): def __init__(self): super().__init__() self.module_pipeline = module_pipeline self.module_pipeline.m_stage0.to(GraphModule).set_stage(0) self.module_pipeline.m_stage1.to(GraphModule).set_stage(1) self.loss_fn = flow.nn.CrossEntropyLoss(reduction="none") self.config.set_gradient_accumulation_steps(2) self.add_optimizer(sgd) def build(self, x, y): out = self.module_pipeline(x) loss = self.loss_fn(out, y).sum() loss = loss.to_global(placement=P1, sbp=BROADCAST) loss.backward() return loss graph_pipeline = PipelineGraph() graph_pipeline.debug(1) x = flow.randn(BATCH_SIZE, 1, 28, 28) x = x.to_global(P0, sbp=flow.sbp.split(0)) y = flow.randint(0, 10, (BATCH_SIZE, 1)) y = y.to_global(P1, sbp=flow.sbp.split(0)) for i in range(2): loss = graph_pipeline(x, y) print(">>>>>>>", flow.env.get_rank(), loss.to_local().numpy(), flush=True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_oneflow_compiler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest import torch from oneflow.framework.infer_compiler import compile_from_torch, register from oneflow.framework.infer_compiler.with_oneflow_compile import ( DualModule, DualModuleList, ) class TorchModule(torch.nn.Module): def __init__(self): super().__init__() self.linears = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(10)]) def forward(self, x): for i, l in enumerate(self.linears): x = self.linears[i // 2](x) + l(x) return x class FlowModule(flow.nn.Module): def __init__(self): super().__init__() self.linears = flow.nn.ModuleList([flow.nn.Linear(10, 10) for _ in range(10)]) def forward(self, x): for i, l in enumerate(self.linears): x = self.linears[i // 2](x) + l(x) return x @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestOneflowInferCompiler(flow.unittest.TestCase): def setUp(self): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" def test_compile_from_torch(test_case): register(torch2oflow_class_map={TorchModule: FlowModule}) m = TorchModule().to("cuda") x = torch.randn(2, 10).to("cuda") y_torch = m(x) m = compile_from_torch(m) y_flow = m(x) test_case.assertTrue( np.allclose(y_torch.detach().cpu(), y_flow.detach().cpu(), 1e-03, 1e-03) ) test_case.assertIsInstance(m.linears, DualModuleList) x = getattr(m.linears, "1") test_case.assertIsInstance(x, DualModule) x.bias = None setattr(m.linears, "2", x) test_case.assertIsNone(m.linears[2].bias) test_case.assertIsNone(m.linears._torch_modules[2].bias) test_case.assertIsNone(m.linears._oneflow_modules[2].bias) m.linears[3] = x test_case.assertIsNone(m.linears[3].bias) test_case.assertIsNone(m.linears._torch_modules[3].bias) test_case.assertIsNone(m.linears._oneflow_modules[3].bias) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_optimization_conf.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import oneflow.framework.session_context as session_ctx import oneflow as flow import oneflow.unittest import oneflow.framework.config_util as config_util import oneflow.framework.attr_util as attr_util import random @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraphWithSysConf(flow.unittest.TestCase): def test_graph_config(test_case): flow.boxing.enable_fusion(True) flow.boxing.nccl.set_fusion_threshold_mbytes(800) flow.boxing.nccl.set_fusion_max_ops_num(10) flow.boxing.nccl.allow_fuse_all_reduce(True) flow.boxing.nccl.allow_fuse_reduce_scatter(True) flow.boxing.nccl.allow_fuse_all_gather(True) flow.boxing.nccl.allow_fuse_reduce(True) flow.boxing.nccl.allow_fuse_broadcast(True) flow.boxing.nccl.allow_fuse_mixed_ops(True) flow.boxing.nccl.enable_use_buffer_to_fuse_all_reduce(True) flow.boxing.nccl.set_stream_num(3) flow.boxing.nccl.enable_all_to_all(True) flow.boxing.nccl.enable_use_compute_stream(True) flow.boxing.nccl.disable_group_boxing_by_dst_parallel(True) flow.backends.cudnn.set_reserved_mem_mbytes(1000) flow.backends.cudnn.enable_fused_normalization_add_relu(True) flow.backends.cudnn.enable_conv_heuristic_search_algo(False) flow.utils.load_library("") class CustomGraphSysConf(flow.nn.Graph): def __init__(self): super().__init__() # amp self.config.enable_amp(True) grad_scaler = flow.amp.GradScaler( init_scale=3000, growth_factor=2.0, backoff_factor=0.5, growth_interval=1000, ) self.set_grad_scaler(grad_scaler) self.config.allow_fuse_model_update_ops(True) self.config.allow_fuse_add_to_output(True) self.config.set_gradient_accumulation_steps(100) self.config.allow_fuse_cast_scale(True) self.config.enable_zero(True) self.config.enable_cudnn_conv_heuristic_search_algo(False) def build(self, x): return x g = CustomGraphSysConf() print("optimization conf: \n", g._optimization_conf_proto) test_case.assertTrue(g._optimization_conf_proto.nccl_use_compute_stream) g._generate_config_proto() print("graph conf: \n", g._config_proto) # Test the resource config update eagerly # Note: this tests all the apis in oneflow.framework.config_util automatically def test_resource_config_update_apis_eagerly_automatically(): attrs_and_values_to_check = [] num_api_tested = 0 for api in config_util.api_attrs_and_type.keys(): attrs, type_ = config_util.api_attrs_and_type[api] if type_ is int: attr_value = random.randint(0, 9999) attrs_and_values_to_check.append((attrs, attr_value)) elif type_ is bool: attr_value = random.choice([True, False]) attrs_and_values_to_check.append((attrs, attr_value)) else: raise TypeError("Unsupported type!") api(attr_value) num_api_tested += 1 # check all the attributes are set correctly for (attrs, expected_attr_value) in attrs_and_values_to_check: current_attr_value = attr_util.get_nested_attribute( g._optimization_conf_proto, attrs ) test_case.assertTrue( current_attr_value == expected_attr_value, str(attrs) + " : " + str(current_attr_value) + " vs " + str(current_attr_value), ) print("number of APIs tested: " + str(num_api_tested)) # save the resource config before running random resource api tests session = session_ctx.GetDefaultSession() prev_resource_config = session.resource for i in range(5): test_resource_config_update_apis_eagerly_automatically() print("optimization conf after session init: \n", g._optimization_conf_proto) # restore the resource config session.update_resource_eagerly(prev_resource_config) if __name__ == "__main__": unittest.main() 
================================================ FILE: python/oneflow/test/graph/test_output_op_expr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np from google.protobuf import text_format import oneflow import oneflow as flow import oneflow._oneflow_internal import oneflow._oneflow_internal._C as _C import oneflow.framework.c_api_util as c_api_util import oneflow.framework.session_context as session_ctx import oneflow.unittest from oneflow.framework.multi_client_session import MultiClientSession @flow.unittest.skip_unless_1n1d() class TestFetchOutputTensor(unittest.TestCase): def test_fetch_output_tensor(test_case): x = flow.Tensor(1, 1, 10, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) session = session_ctx.GetDefaultSession() test_case.assertTrue(isinstance(session, MultiClientSession)) session.TryInit() with oneflow._oneflow_internal.lazy_mode.guard(True): oneflow._oneflow_internal.JobBuildAndInferCtx_Open( "cc_test_output_op_expr_job" ) job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto() job_conf.job_name = "cc_test_output_op_expr_job" job_conf.predict_conf.SetInParent() c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf) input_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf() input_conf.in_0 = "EagerTensorInput" input_conf.out_0 = "out_0" input_conf_str = text_format.MessageToString(input_conf) input_op = oneflow._oneflow_internal.one.FeedInputOpExpr( "cc_Input_0", input_conf_str, ["in_0"], ["out_0"] ) output_conf = oneflow.core.operator.op_conf_pb2.FetchOutputOpConf() output_conf.in_0 = "LazyTensorInput" output_conf.out_0 = "out_0" output_conf_str = text_format.MessageToString(output_conf) output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr( "cc_Output_0", output_conf_str, ["in_0"], ["out_0"] ) lazy_tensor = _C.dispatch_feed_input(input_op, x) test_case.assertEqual(lazy_tensor.shape, (1, 1, 10, 10)) test_case.assertTrue(lazy_tensor.is_lazy) test_case.assertTrue(lazy_tensor.is_local) eager_tensor = _C.dispatch_fetch_output(output_op, lazy_tensor) test_case.assertEqual(eager_tensor.shape, (1, 1, 10, 10)) test_case.assertTrue(not eager_tensor.is_lazy) test_case.assertTrue(eager_tensor.is_local) oneflow._oneflow_internal.JobBuildAndInferCtx_Close() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_run_global_graph_by_vm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import oneflow as flow import oneflow.unittest import numpy as np from test_run_graph_by_vm import RunGraphByVmEnv, Graph from test_graph_ofrecord_reader import OFRecordDataLoader @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalInterpreter(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_data_parallel_run_by_vm(test_case): with RunGraphByVmEnv(): class DataParallelMul(flow.nn.Module): def __init__(self, placement) -> None: super().__init__() self.w = flow.randn( 5, 8, placement=placement, sbp=flow.sbp.broadcast ) def forward(self, x): return flow.matmul(x, self.w) placement = flow.placement("cuda", [0, 1]) m = DataParallelMul(placement).eval() g = Graph(m) input = flow.randn(4, 5, placement=placement, sbp=flow.sbp.split(0)) graph_output = g(input) eager_output = m(input) test_case.assertTrue(graph_output.sbp == eager_output.sbp) test_case.assertTrue(graph_output.shape == eager_output.shape) test_case.assertTrue(graph_output.placement == eager_output.placement) test_case.assertTrue(np.allclose(graph_output, eager_output)) @flow.unittest.skip_unless_1n2d() def test_module_parallel_run_by_vm(test_case): with RunGraphByVmEnv(): class ModuleParallelMul(flow.nn.Module): def __init__(self, placement) -> None: super().__init__() self.w = flow.randn( 5, 8, placement=placement, sbp=flow.sbp.split(1) ) def forward(self, x): return flow.matmul(x, self.w) placement = flow.placement("cuda", [0, 1]) m = ModuleParallelMul(placement).eval() g = Graph(m) input = flow.randn(4, 5, placement=placement, sbp=flow.sbp.broadcast) graph_output = g(input) eager_output = m(input) test_case.assertTrue(graph_output.sbp == eager_output.sbp) test_case.assertTrue(graph_output.shape == eager_output.shape) test_case.assertTrue(graph_output.placement == eager_output.placement) test_case.assertTrue(np.allclose(graph_output, eager_output)) @flow.unittest.skip_unless_1n2d() def test_boxing_data_parallel_run_by_vm(test_case): with RunGraphByVmEnv(): flow.boxing.nccl.enable_use_compute_stream(False) class BoxingModuleParallelMul(flow.nn.Module): def __init__(self, placement) -> None: super().__init__() self.w1 = flow.randn( 5, 8, placement=placement, sbp=flow.sbp.split(1) ) self.w2 = flow.randn( 8, 6, placement=placement, sbp=flow.sbp.split(1) ) def forward(self, x): x = flow.matmul(x, self.w1) x = flow.matmul(x, self.w2) return x placement = flow.placement("cuda", [0, 1]) m = BoxingModuleParallelMul(placement).eval() g = Graph(m) input = flow.randn(4, 5, placement=placement, sbp=flow.sbp.broadcast) graph_output = g(input) eager_output = m(input) test_case.assertTrue(graph_output.sbp == eager_output.sbp) test_case.assertTrue(graph_output.shape == eager_output.shape) test_case.assertTrue(graph_output.placement == eager_output.placement) test_case.assertTrue(np.allclose(graph_output, eager_output)) @flow.unittest.skip_unless_1n1d() def test_empty_inputs(test_case): with RunGraphByVmEnv(): class GraphReader(flow.nn.Graph): def __init__(self): super().__init__() self.my_reader = OFRecordDataLoader() def build(self): return self.my_reader() reader_g = GraphReader() image, label = reader_g() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_run_graph_by_vm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import oneflow as flow import numpy as np class EnvVar(object): def __init__(self, env_list: dict): self.env_list = env_list def __enter__(self): os.environ.update(self.env_list) def __exit__(self, *args): for key in self.env_list.keys(): if key in os.environ.keys(): os.environ.pop(key) class RunGraphByVmEnv(EnvVar): def __init__(self): super().__init__( { "ONEFLOW_RUN_GRAPH_BY_VM": "1", "ONEFLOW_MLIR_ENABLE_ROUND_TRIP": "1", "ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION": "1", } ) class Graph(flow.nn.Graph): def __init__(self, m): super().__init__() self.m = m def build(self, x): return self.m(x) class M(flow.nn.Module): def __init__(self): super().__init__() self.w = flow.nn.Parameter(flow.randn(4)) def forward(self, x): # these broadcast_sub and cast ops will be # eliminated by nn.Graph w1 = self.w - self.w - self.w x = x * w1.to(flow.float32) return x def test_run_graph_by_vm(capsys): with RunGraphByVmEnv(): m = M().eval() g = Graph(m) input = flow.randn(4) graph_output = g(input) eager_output = m(input) assert graph_output.shape == (4,) assert np.allclose(graph_output, eager_output) input = flow.randn(3, 4) graph_output = g(input) eager_output = m(input) assert graph_output.shape == (3, 4) assert np.allclose(graph_output, eager_output) # Test the optimization in graph works. # broadcast_sub and cast ops are pruned. print(g) assert "broadcast_sub" not in capsys.readouterr().out assert "cast" not in capsys.readouterr().out assert "broadcast_mul" not in capsys.readouterr().out ================================================ FILE: python/oneflow/test/graph/test_to_global.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import numpy as np import oneflow as flow import oneflow.unittest x = np.array( [ [ 0.21490018, 0.22043167, 0.1605895, 0.25424683, 0.12975895, 0.49967155, 0.04753795, 0.7518577, 0.38964537, 0.01955934, ], [ 0.16392729, 0.41410774, 0.05424517, 0.7668146, 0.08050849, 0.5763975, 0.42364502, 0.4950619, 0.9608427, 0.11889187, ], ] ) y = np.array( [ [ 0.9903706, 0.11213686, 0.29525927, 0.79380244, 0.70357895, 0.6950597, 0.52552456, 0.32304054, 0.6997739, 0.15671141, ], [ 0.76867193, 0.59983397, 0.07774717, 0.07815815, 0.30385414, 0.7366552, 0.4607681, 0.40554753, 0.8290172, 0.8405671, ], [ 0.8900324, 0.5274955, 0.80989295, 0.71331054, 0.8076364, 0.94833183, 0.04778554, 0.23992656, 0.57683426, 0.81757474, ], ] ) class MyModule1(flow.nn.Module): def __init__(self, weight): assert isinstance(weight, flow._oneflow_internal.Tensor) super().__init__() self.weight = flow.nn.Parameter(weight) self.activation = flow.nn.ReLU() def forward(self, x): # print(f"x shape: {x.shape}, placement: {x.placement}, sbp: {x.sbp}") # print( # f"weight shape: {self.weight.shape}, placement: {self.weight.placement}, sbp: {self.weight.sbp}" # ) y = flow._C.matmul(x, self.weight, transpose_b=True) # print(f"y shape: {y.shape}, placement: {y.placement}, sbp: {y.sbp}") if y.is_global: y = y.to_global(sbp=flow.sbp.broadcast) # print(f"post y shape: {y.shape}, placement: {y.placement}, sbp: {y.sbp}") return self.activation(y) class MyModule2(flow.nn.Module): def __init__(self, weight): assert isinstance(weight, flow._oneflow_internal.Tensor) super().__init__() self.weight = flow.nn.Parameter(weight) self.activation = flow.nn.ReLU() def forward(self, x): # print(f"weight shape: {self.weight.shape}, placement: {self.weight.placement}, sbp: {self.weight.sbp}") if self.weight.is_global: y = self.weight.to_global(grad_sbp=flow.sbp.broadcast) z = flow._C.matmul(y, x, transpose_b=True) out = self.activation(z).sum() if self.weight.is_global: out = out.to_global(sbp=flow.sbp.broadcast) return out class MyModule3(flow.nn.Module): def __init__(self, transpose_a=False, transpose_b=False): super().__init__() self.activation = flow.nn.ReLU() self.transpose_a = transpose_a self.transpose_b = transpose_b def forward(self, x, y): z = flow._C.matmul(x, y, self.transpose_a, self.transpose_b) if z.is_global: z = z.to_global(sbp=flow.sbp.broadcast) return self.activation(z) class GlobalToModule(flow.nn.Module): def __init__(self, device="cuda"): super().__init__() self.device = device def forward(self, x): return x.to(self.device) class FreeTensorModule(flow.nn.Module): def __init__(self, shape, placement, sbp): super().__init__() self.shape = shape self.placement = placement self.sbp = sbp def forward(self, x): y = flow.ones( self.shape, dtype=flow.float32, placement=self.placement, sbp=self.sbp ) return flow._C.matmul(x, y, transpose_b=True) class ToPlacementModule(flow.nn.Module): def __init__(self, placement): super().__init__() self.placement = placement def forward(self, x): return x.to_global(placement=self.placement) class MyGraph(flow.nn.Graph): def __init__(self, module, optimizer=None): super().__init__() self.module = module if optimizer is not None: self.add_optimizer(optimizer) def build(self, *arg): y = self.module(*arg) if self.config.training: y.backward() return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class ToGlobalGraphTestCase(oneflow.unittest.TestCase): # @unittest.skipIf(True, "") def test_fwd_P2B(test_case): """ compare 
eager fwd and lazy bwd """ rank = flow.env.get_rank() # pid = os.getpid() # print(f"[{pid}][{rank}] ToGlobalGraphTestCase.test_fwd_P2B") local_x = flow.tensor(x, dtype=flow.float32, device=flow.device(f"cuda:{rank}")) local_y = flow.tensor(y, dtype=flow.float32, device=flow.device(f"cuda:{rank}")) z = flow._C.matmul( flow.cat([local_x, local_x], dim=1), flow.cat([local_y, local_y], dim=1), transpose_b=True, ) z = flow._C.relu(z) # print(f"z shape: {z.shape}, device: {z.device}") # print(z.numpy()) placement = flow.placement("cuda", ranks=[0, 1]) sbp = flow.sbp.split(1) c_x = local_x.to_global(placement=placement, sbp=sbp) c_y = local_y.to_global(placement=placement, sbp=sbp) # print(f"c_x shape: {c_x.shape}, placement: {c_x.placement}, sbp: {c_x.sbp}") # print(f"c_y shape: {c_y.shape}, placement: {c_y.placement}, sbp: {c_y.sbp}") m = MyModule1(c_y) g = MyGraph(m) g_z = g(c_x) # print(f"g_z shape: {g_z.shape}, placement: {g_z.placement}, sbp: {g_z.sbp}") # print(g_z.to_local().numpy()) test_case.assertTrue(np.allclose(z.numpy(), g_z.to_local().numpy())) # @unittest.skipIf(True, "") def test_bwd_P2B(test_case): """ compare eager bwd and lazy bwd """ rank = flow.env.get_rank() # pid = os.getpid() # print(f"[{pid}][{rank}] ToGlobalGraphTestCase.test_bwd_P2B") local_x = flow.tensor(x, dtype=flow.float32, device=flow.device(f"cuda:{rank}")) local_y = flow.tensor(y, dtype=flow.float32, device=flow.device(f"cuda:{rank}")) z = flow._C.matmul( local_y, flow.cat([local_x, local_x], dim=0), transpose_b=True, ) z = flow._C.relu(z) z = z.sum() placement = flow.placement("cuda", ranks=[0, 1]) c_x = local_x.to_global(placement=placement, sbp=flow.sbp.split(0)) c_y = local_y.to_global(placement=placement, sbp=flow.sbp.broadcast) m = MyModule2(c_y) optimizer = flow.optim.SGD(m.parameters(), lr=1.0) g = MyGraph(m, optimizer) g_z = g(c_x) # print(f"g_z shape: {g_z.shape}, placement: {g_z.placement}, sbp: {g_z.sbp}") test_case.assertTrue(g_z.is_global) test_case.assertTrue(g_z.sbp[0] == flow.sbp.broadcast) # S(1) -> B not supported yet # c_z = g_z.to_global(sbp=flow.sbp.broadcast) # print(f"c_z shape: {c_z.shape}, placement: {c_z.placement}, sbp: {c_z.sbp}") test_case.assertTrue(np.allclose(z.numpy(), g_z.to_local().numpy())) e_y = c_y.detach() # print(f"e_y shape: {e_y.shape}, placement: {e_y.placement}, sbp: {e_y.sbp}") e_m = MyModule2(e_y) e_z = e_m(c_x) # print(f"e_z shape: {e_z.shape}, placement: {e_z.placement}, sbp: {e_z.sbp}") e_z.backward() test_case.assertTrue( np.allclose(c_y.to_local().numpy(), e_y.to_local().numpy()) ) # @unittest.skipIf(True, "") def test_multi_graph(test_case): """ compare two lazy fwd """ rank = flow.env.get_rank() # pid = os.getpid() # print(f"[{pid}][{rank}] ToGlobalGraphTestCase.test_multi_graph") local_x = flow.tensor(x, dtype=flow.float32, device=flow.device(f"cuda:{rank}")) local_y = flow.tensor(y, dtype=flow.float32, device=flow.device(f"cuda:{rank}")) placement = flow.placement("cuda", ranks=[0, 1]) x1 = local_x.to_global(placement=placement, sbp=flow.sbp.broadcast) y1 = local_y.to_global(placement=placement, sbp=flow.sbp.broadcast) # B * B -> B -> B m1 = MyModule3(transpose_b=True) g1 = MyGraph(m1) slice_obj = slice( int(rank * local_x.shape[0] / 2), int((rank + 1) * local_x.shape[0] / 2) ) x2 = local_x[slice_obj, :] x2 = x2.to_global(placement=placement, sbp=flow.sbp.split(0)) y2 = local_y.to_global(placement=placement, sbp=flow.sbp.broadcast) # S(0) * B -> S(0) -> B m2 = MyModule3(transpose_b=True) g2 = MyGraph(m2) x3 = local_x[ :, int(rank * local_x.shape[1] / 2) : 
int((rank + 1) * local_x.shape[1] / 2) ] x3 = x3.to_global(placement=placement, sbp=flow.sbp.split(1)) y3 = local_y[ :, int(rank * local_y.shape[1] / 2) : int((rank + 1) * local_y.shape[1] / 2) ] y3 = y3.to_global(placement=placement, sbp=flow.sbp.split(1)) # S(1) * S(0) -> P -> B m3 = MyModule3(transpose_b=True) g3 = MyGraph(m3) z1 = g1(x1, y1) # print(f"z1 shape: {z1.shape}, placement: {z1.placement}, sbp: {z1.sbp}") # print(z1.to_local().numpy()) z2 = g2(x2, y2) # print(f"z2 shape: {z2.shape}, placement: {z2.placement}, sbp: {z2.sbp}") # print(z2.to_local().numpy()) z3 = g3(x3, y3) # print(f"z3 shape: {z3.shape}, placement: {z3.placement}, sbp: {z3.sbp}") # print(z3.to_local().numpy()) test_case.assertTrue(np.allclose(z1.to_local().numpy(), z2.to_local().numpy())) test_case.assertTrue(np.allclose(z1.to_local().numpy(), z3.to_local().numpy())) # @unittest.skipIf(True, "") def test_global_to(test_case): c_x = flow.ones( (4, 3), placement=flow.placement("cpu", ranks=[0, 1]), sbp=flow.sbp.split(0) ) global_to = GlobalToModule("cuda") g_global_to = MyGraph(global_to) e = global_to(c_x) test_case.assertTrue(e.is_cuda) test_case.assertTrue(e.is_global) test_case.assertTrue(e.sbp[0] == flow.sbp.split(0)) g = g_global_to(c_x) test_case.assertTrue(g.is_cuda) test_case.assertTrue(g.is_global) test_case.assertTrue(g.sbp[0] == flow.sbp.split(0)) test_case.assertTrue(np.allclose(e.to_local().numpy(), g.to_local().numpy())) # @unittest.skipIf(True, "") def test_free_tensor_to_global(test_case): local_x = flow.tensor(x, dtype=flow.float32, device="cpu") placement = flow.placement("cuda", ranks=[0, 1]) c_x = local_x.to_global(placement, flow.sbp.split(0)) m = FreeTensorModule((3, 10), placement, flow.sbp.broadcast) g = MyGraph(m) eager_out = m(c_x) test_case.assertTrue(eager_out.is_cuda) test_case.assertTrue(eager_out.is_global) test_case.assertTrue(eager_out.sbp[0] == flow.sbp.split(0)) graph_out = g(c_x) test_case.assertTrue(graph_out.is_cuda) test_case.assertTrue(graph_out.is_global) test_case.assertTrue(graph_out.sbp[0] == flow.sbp.split(0)) test_case.assertTrue( np.allclose(eager_out.to_local().numpy(), graph_out.to_local().numpy()) ) # @unittest.skipIf(True, "") def test_to_placement(test_case): rank = flow.env.get_rank() # pid = os.getpid() # print(f"[{pid}][{rank}] ToGlobalGraphTestCase.test_to_placement") if rank == 0: x = flow.ones((2, 3), dtype=flow.float32) elif rank == 1: x = flow.empty(tuple()) else: raise ValueError c_x = x.to_global( placement=flow.placement("cpu", ranks=[0]), sbp=flow.sbp.broadcast ) # print(f"c_x shape: {c_x.shape}, placement: {c_x.placement}, sbp: {c_x.sbp}") p1 = flow.placement("cpu", ranks=[0, 1]) m1 = ToPlacementModule(p1) g1 = MyGraph(m1) y1 = g1(c_x) # print(f"y1 shape: {y1.shape}, placement: {y1.placement}, sbp: {y1.sbp}") test_case.assertTrue(y1.placement == p1) test_case.assertTrue(y1.sbp[0] == flow.sbp.broadcast) test_case.assertTrue(y1.to_local().numpy().mean() == 1.0) p2 = flow.placement("cuda", ranks=[0, 1]) m2 = ToPlacementModule(p2) g2 = MyGraph(m2) y2 = g2(y1) # print(f"y2 shape: {y2.shape}, placement: {y2.placement}, sbp: {y2.sbp}") test_case.assertTrue(y2.placement == p2) test_case.assertTrue(y2.sbp[0] == flow.sbp.broadcast) test_case.assertTrue(y2.to_local().numpy().mean() == 1.0) # @unittest.skipIf(True, "") def test_to_dtype(test_case): x = flow.ones((2, 3), dtype=flow.int32, device="cpu") placement = flow.placement("cpu", ranks=[0, 1]) c_x = flow.ones( (2, 3), dtype=flow.int32, placement=placement, sbp=flow.sbp.broadcast ) class 
CastModule(flow.nn.Module): def __init__(self, dtype): super().__init__() self.dtype = dtype def forward(self, x): return x.to(dtype=self.dtype) m = CastModule(flow.float32) g = MyGraph(m) e_x = m(x) e_c_x = m(c_x) # NOTE(chengcheng): # There are two BUG in this test script: # 1. first call and second call input tensor meta is NOT same # 2. nn.Graph NOT support local input with multi-rank yet. # g_x = g(x) g_c_x = g(c_x) test_case.assertTrue(e_x.dtype == flow.float32) # test_case.assertTrue(g_x.dtype == flow.float32) test_case.assertTrue(e_c_x.dtype == flow.float32) test_case.assertTrue(g_c_x.dtype == flow.float32) class MyModule5(flow.nn.Module): def __init__(self, transpose_a=False, transpose_b=False, sbp=[]): super().__init__() self.transpose_a = transpose_a self.transpose_b = transpose_b self.sbp = sbp def forward(self, x, y): z = flow._C.matmul(x, y, self.transpose_a, self.transpose_b) assert z.is_global assert len(z.sbp) == len(self.sbp) return z.to_global(sbp=self.sbp) @unittest.skipIf(True, "") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class ToGlobal2DGraphTestCase(oneflow.unittest.TestCase): def test_matmul(test_case): placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) x = flow.ones( (4, 6), placement=placement, sbp=[flow.sbp.split(0), flow.sbp.split(1)] ) y = flow.ones( (4, 6), placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.split(1)] ) z = flow._C.matmul(x, y, transpose_b=True) print(f"z shape: {z.shape}, placement: {z.placement}, sbp: {z.sbp}") # m = MyModule5(transpose_b=True, sbp=[flow.sbp.split(0), flow.sbp.broadcast]) # z = m(x, y) # print(f"z shape: {z.shape}, placement: {z.placement}, sbp: {z.sbp}") @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestLazy1dTo2dGlobal(flow.unittest.TestCase): def test_lazy_1d_to_2d_sbp(test_case): P_1d = flow.placement( device_type="cuda", device_ids={0: range(4)}, hierarchy=(4,) ) P_2d = flow.placement( device_type="cuda", device_ids={0: range(4)}, hierarchy=(2, 2) ) B = flow.sbp.broadcast class Test1dTo2dModule(flow.nn.Module): def forward(self, x): return x.to_global(placement=P_2d, sbp=[B, B]) class Test1dTo2dGraph(flow.nn.Graph): def __init__(self, model): super().__init__() self.model = model def build(self, x): return self.model(x) class Test2dTo1dModule(flow.nn.Module): def forward(self, x): return x.to_global(placement=P_1d, sbp=[B]) class Test2dTo1dGraph(flow.nn.Graph): def __init__(self, model): super().__init__() self.model = model def build(self, x): return self.model(x) model_1d_to_2d = Test1dTo2dModule() graph_1d_to_2d = Test1dTo2dGraph(model_1d_to_2d) x = flow.zeros(4, 4, 4, 4, sbp=[B, B], placement=P_2d) x = x.to_global(placement=P_1d, sbp=[B]) test_case.assertTrue(x.sbp == (B,)) test_case.assertTrue(x.placement == P_1d) y = graph_1d_to_2d(x) test_case.assertTrue(y.sbp == (B, B)) test_case.assertTrue(y.placement == P_2d) model_2d_to_1d = Test2dTo1dModule() graph_2d_to_1d = Test2dTo1dGraph(model_2d_to_1d) z = graph_2d_to_1d(y) test_case.assertTrue(z.sbp == x.sbp) test_case.assertTrue(z.placement == x.placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_tvm_frontend_dependency_on_graph.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import re import unittest import numpy as np import oneflow as flow import oneflow.unittest from alexnet_model import alexnet class TvmFrontedGraph(flow.nn.Graph): def __init__(self, module): super().__init__() self.m = module def build(self, x): out = self.m(x) return out def parse_attr(attr): # Parse node_attr attrs = {} for a in attr: attr_str = str(attr[a]) if attr_str[0:7] == "at_list": attr_str_ = attr_str.split(" ")[0] if attr_str_ == "at_list_float": attrs[a] = tuple(attr[a].at_list_float.val) elif attr_str_ == "at_list_int32": attrs[a] = tuple(attr[a].at_list_int32.val) elif attr_str_ == "at_list_int64": attrs[a] = tuple(attr[a].at_list_int64.val) elif attr_str.split(":")[0] == "at_string": attrs[a] = attr[a].at_string elif attr_str.split(" ")[0] == "at_shape": attrs[a] = tuple(list(attr[a].at_shape.dim)) else: attr_str_ = attr_str.split(":")[0] if attr_str_ == "at_bool": attrs[a] = attr[a].at_bool elif attr_str_ == "at_double": attrs[a] = attr[a].at_double elif attr_str_ == "at_float": attrs[a] = attr[a].at_float elif attr_str_ == "at_int32": attrs[a] = attr[a].at_int32 elif attr_str_ == "at_int64": attrs[a] = attr[a].at_int64 return attrs def is_user_op(node): # Determine if the the node is the intermediate variables of graph return node.WhichOneof("op_type") == "user_conf" @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestConvertDependency(flow.unittest.TestCase): def test_get_params(test_case): class ConvModel(flow.nn.Module): def __init__(self): super(ConvModel, self).__init__() self.conv = flow.nn.Conv2d(3, 64, kernel_size=11, bias=False) def forward(self, x): x = self.conv(x) return x model = ConvModel().state_dict() for layer_name in model: layer_path = os.path.join(layer_name, "out") test_case.assertEqual(layer_path, "conv.weight/out") def test_infos_of_nodes(test_case): alexnet_module = alexnet() alexnet_graph = TvmFrontedGraph(alexnet_module) if not alexnet_graph._is_compiled: alexnet_graph._compile(flow.rand(1, 3, 224, 224)) graph_str = repr(alexnet_graph) if not alexnet_graph._is_compiled: alexnet_graph._compile(flow.rand(shape_input)) size_where = 2 if "cuda" in graph_str: size_where = 3 p_size = re.compile(r"size=\(.*?\)", re.S) p_type = re.compile(r"(dtype=.*?)[,|\)]", re.S) types = ["INPUT", "PARAMETER", "BUFFER", "OUTPUT"] num_nodes = {} for t in types: data = re.finditer(t + ":.*", graph_str) cnt = 0 for i in data: cnt += 1 attrs = i.group().split(":") size_strs = re.findall(p_size, attrs[size_where]) type_strs = re.findall(p_type, attrs[size_where]) test_case.assertEqual(size_strs != [], True) test_case.assertEqual(type_strs != [], True) size_attr = size_strs[0].replace("size=", "") type_attr = type_strs[0].replace("dtype=", "").replace(")", "") if size_attr[-2] == ",": size_attr = size_attr.replace(",", "") if type_attr[-1] == ",": type_attr = type_attr.replace(",", "") test_case.assertEqual(type_attr, "oneflow.float32") data_size = tuple(map(int, 
size_attr[1:-1].split(", "))) if cnt == 1 and t == "PARAMETER": test_case.assertEqual(data_size, (64, 3, 11, 11)) elif cnt == 15 and t == "PARAMETER": test_case.assertEqual(data_size, (1000, 4096)) num_nodes[t] = cnt test_case.assertEqual(num_nodes["INPUT"] != 0, True) test_case.assertEqual(num_nodes["BUFFER"], 0) test_case.assertEqual(num_nodes["PARAMETER"], 16) test_case.assertEqual(num_nodes["OUTPUT"] != 0, True) # get graph proto, if you don't _compile the graph, the _graph_proto will be None graph_input = re.search(r"INPUT:.*", graph_str).group().split(":") shape_input = tuple( map( int, re.findall(p_size, graph_input[size_where])[0] .replace("size=", "")[1:-1] .split(", "), ) ) graph_proto = alexnet_graph._graph_proto nodes = {} for op in graph_proto.net.op: nodes[op.name] = op op_names = [] op_attrs = [] for node_name in nodes: node = nodes[node_name] if is_user_op(node): op_name = node.user_conf.op_type_name op_attr = parse_attr(node.user_conf.attr) op_names.append(op_name) op_attrs.append(op_attr) test_case.assertEqual(op_names[0], "conv2d") test_case.assertEqual(op_names[1], "bias_add") test_case.assertEqual(op_names[2], "relu") kernel_size = op_attrs[0].get("kernel_size", None) strides = op_attrs[0].get("strides", None) padding_before = op_attrs[0].get("padding_before", None) test_case.assertEqual(kernel_size, (11, 11)) test_case.assertEqual(strides, (4, 4)) test_case.assertEqual(padding_before, (2, 2)) node_input_list = [] node_output_list = [] for node_name in nodes: node = nodes[node_name] if is_user_op(node) and node.user_conf.op_type_name == "conv2d": for input_name in node.user_conf.input: node_input_paths = getattr(node.user_conf.input[input_name], "s") for i in node_input_paths: node_input = i.split("/")[0] print(node_input) node_input_list.append(node_input) for output_name in node.user_conf.output: node_output_paths = getattr(node.user_conf.output[output_name], "s") for node_output_path in node_output_paths: node_output_name = node_output_path.split("/")[0] print(node_output_name) node_output_list.append(node_output_name) test_case.assertEqual("_TvmFrontedGraph_1_input.0.0_2" in node_input_list, True) test_case.assertEqual("m.features.0.weight" in node_input_list, True) test_case.assertEqual("m.features.5-max_pool_2d-7" in node_input_list, True) test_case.assertEqual("m.features.0-conv2d-0" in node_output_list, True) test_case.assertEqual("m.features.6-conv2d-8" in node_output_list, True) def test_buffer_convert_dependence(test_case): class SubModule(flow.nn.Module): def __init__(self): super().__init__() self.fc1 = flow.nn.Linear(36, 4, False) self.register_buffer("dummy_buff", flow.Tensor(1, 4)) def forward(self, x): x = self.fc1(x) x += self.dummy_buff return x sub_module = SubModule() sub_graph = TvmFrontedGraph(sub_module) graph_str = repr(sub_graph) size_where = 2 if "cuda" in graph_str: size_where = 3 p_size = re.compile(r"size=\(.*?\)", re.S) p_type = re.compile(r"dtype=.*?,", re.S) num_nodes = {} data = re.finditer("BUFFER:.*", graph_str) for i in data: attrs = i.group().split(":") size_strs = re.findall(p_size, attrs[size_where]) size_attr = size_strs[0].replace("size=", "") data_size = tuple(map(int, size_attr[1:-1].split(", "))) test_case.assertEqual(data_size, (1, 4)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_user_op_expr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np from google.protobuf import text_format import os import oneflow import oneflow as flow import oneflow._oneflow_internal import oneflow._oneflow_internal._C as _C import oneflow.framework.c_api_util as c_api_util import oneflow.framework.session_context as session_ctx import oneflow.unittest from oneflow.framework.multi_client_session import MultiClientSession def _get_c_tensor(t): if isinstance(t, oneflow._oneflow_internal.Tensor): return t else: raise NotImplementedError def _test_user_op_graph(test_case, is_cuda): x0 = flow.tensor(np.random.rand(20, 30), dtype=flow.float32) weight0 = flow.tensor(np.random.rand(30, 50), dtype=flow.float32) x1 = flow.tensor(np.random.rand(50, 70), dtype=flow.float32) if is_cuda: x0 = x0.to(device=flow.device("cuda")) weight0 = weight0.to(device=flow.device("cuda")) x1 = x1.to(device=flow.device("cuda")) # NOTE(chengcheng): this tiny net is: # x0 * weight0 -> out0 # relu(out0) -> y0 # y0 * x1 -> out1 # relu(out1) -> y1 session = session_ctx.GetDefaultSession() test_case.assertTrue(isinstance(session, MultiClientSession)) session.TryInit() with oneflow._oneflow_internal.lazy_mode.guard(True): oneflow._oneflow_internal.JobBuildAndInferCtx_Open( "cc_test_user_op_expr_job_with_cuda" + str(is_cuda) ) job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto() job_conf.job_name = "cc_test_user_op_expr_job_with_cuda" + str(is_cuda) job_conf.predict_conf.SetInParent() c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf) x0_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf() x0_conf.in_0 = "in_0" x0_conf.out_0 = "out_0" x0_conf_str = text_format.MessageToString(x0_conf) x0_op = oneflow._oneflow_internal.one.FeedInputOpExpr( "cc_Input_0", x0_conf_str, ["in_0"], ["out_0"] ) x1_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf() x1_conf.in_0 = "in_0" x1_conf.out_0 = "out_0" x1_conf_str = text_format.MessageToString(x1_conf) x1_op = oneflow._oneflow_internal.one.FeedInputOpExpr( "cc_Input_1", x1_conf_str, ["in_0"], ["out_0"] ) weight0_conf = oneflow.core.operator.op_conf_pb2.FeedVariableOpConf() weight0_conf.in_0 = "in_0" weight0_conf.out_0 = "out_0" weight0_conf_str = text_format.MessageToString(weight0_conf) weight0_op = oneflow._oneflow_internal.one.FeedVariableOpExpr( "cc_Variable_0", weight0_conf_str, ["in_0"], ["out_0"] ) output_conf = oneflow.core.operator.op_conf_pb2.FetchOutputOpConf() output_conf.in_0 = "in_0" output_conf.out_0 = "out_0" output_conf_str = text_format.MessageToString(output_conf) output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr( "cc_Output_0", output_conf_str, ["in_0"], ["out_0"] ) x0_lazy_tensor = _C.dispatch_feed_input(x0_op, x0) x1_lazy_tensor = _C.dispatch_feed_input(x1_op, x1) weight0_lazy_tensor = _C.dispatch_feed_input(weight0_op, weight0) test_case.assertEqual(x0_lazy_tensor.shape, (20, 30)) test_case.assertTrue(x0_lazy_tensor.is_lazy) test_case.assertEqual(weight0_lazy_tensor.shape, (30, 50)) test_case.assertTrue(weight0_lazy_tensor.is_lazy)
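# (Editor's note, illustrative only) Shape bookkeeping for the tiny net described in the NOTE above,
# given the random inputs created at the top of this function:
#   out0 = x0 (20, 30) @ weight0 (30, 50) -> (20, 50);  y0 = relu(out0) keeps (20, 50)
#   out1 = y0 (20, 50) @ x1 (50, 70)      -> (20, 70);  y1 = relu(out1) keeps (20, 70)
# The assertions that follow check exactly this progression on the lazy tensors.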
test_case.assertEqual(x1_lazy_tensor.shape, (50, 70)) test_case.assertTrue(x1_lazy_tensor.is_lazy) out0 = flow._C.matmul(x0_lazy_tensor, weight0_lazy_tensor) test_case.assertEqual(out0.shape, (20, 50)) test_case.assertTrue(out0.is_lazy) y0 = flow._C.relu(out0) test_case.assertEqual(y0.shape, (20, 50)) test_case.assertTrue(y0.is_lazy) out1 = flow._C.matmul(y0, x1_lazy_tensor) test_case.assertEqual(out1.shape, (20, 70)) test_case.assertTrue(out1.is_lazy) y1 = flow._C.relu(out1) test_case.assertEqual(y1.shape, (20, 70)) test_case.assertTrue(y1.is_lazy) eager_output = _C.dispatch_fetch_output(output_op, y1) test_case.assertEqual(eager_output.shape, (20, 70)) test_case.assertTrue(not eager_output.is_lazy) if is_cuda: test_case.assertTrue(x0_lazy_tensor.is_cuda) test_case.assertTrue(x1_lazy_tensor.is_cuda) test_case.assertTrue(weight0_lazy_tensor.is_cuda) test_case.assertTrue(out0.is_cuda) test_case.assertTrue(y0.is_cuda) test_case.assertTrue(out1.is_cuda) test_case.assertTrue(y1.is_cuda) oneflow._oneflow_internal.JobBuildAndInferCtx_Close() @flow.unittest.skip_unless_1n1d() class TestUserOpGraph(unittest.TestCase): def test_user_op_graph_cpu(test_case): _test_user_op_graph(test_case, False) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_user_op_graph_gpu(test_case): _test_user_op_graph(test_case, True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/graph/test_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import itertools import os from collections import OrderedDict from collections.abc import Iterable import numpy as np import oneflow as flow import oneflow.unittest def GenCartesianProduct(sets): assert isinstance(sets, Iterable) for set in sets: assert isinstance(set, Iterable) if os.getenv("ONEFLOW_TEST_CPU_ONLY"): if "cuda" in set: set.remove("cuda") return itertools.product(*sets) def GenArgList(arg_dict): assert isinstance(arg_dict, OrderedDict) assert all([isinstance(x, list) for x in arg_dict.values()]) sets = [arg_set for (_, arg_set) in arg_dict.items()] return GenCartesianProduct(sets) def GenArgDict(arg_dict): return [dict(zip(arg_dict.keys(), x)) for x in GenArgList(arg_dict)] class Args: def __init__(self, flow_args, tf_args=None): super().__init__() if tf_args is None: tf_args = flow_args self.flow_args = flow_args self.tf_args = tf_args def __str__(self): return "flow_args={} tf_args={}".format(self.flow_args, self.tf_args) def __repr__(self): return self.__str__() type_name_to_flow_type = { "float16": flow.float16, "float32": flow.float32, "double": flow.double, "int8": flow.int8, "int32": flow.int32, "int64": flow.int64, "uint8": flow.uint8, } type_name_to_np_type = { "float16": np.float16, "float32": np.float32, "double": np.float64, "int8": np.int8, "int32": np.int32, "int64": np.int64, "uint8": np.uint8, } def FlattenArray(input_array): output_array = list() for x in np.nditer(input_array): output_array.append(x.tolist()) return output_array def Array2Numpy(input_array, target_shape): return np.array(input_array).reshape(target_shape, order="C") def Index2Coordinate(idx, tensor_shape): coordinate = [] tmp = idx for i in range(len(tensor_shape) - 1, -1, -1): axis_size = tensor_shape[i] coor = tmp % axis_size coordinate.insert(0, int(coor)) tmp = (tmp - coor) / axis_size return coordinate def Coordinate2Index(coordinate, tensor_shape): if len(coordinate) != len(tensor_shape): raise "wrong coordinate or shape" idx = 0 for (i, coor) in enumerate(coordinate): size_at_axis = coor for j in range(i + 1, len(tensor_shape)): size_at_axis *= tensor_shape[j] idx += size_at_axis return idx def generate_graph(func): class Graph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, *args): return func(*args) return Graph() ================================================ FILE: python/oneflow/test/graph/test_variable_op_expr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np from google.protobuf import text_format import oneflow import oneflow as flow import oneflow._oneflow_internal import oneflow._oneflow_internal._C as _C import oneflow.framework.c_api_util as c_api_util import oneflow.framework.session_context as session_ctx import oneflow.unittest from oneflow.framework.multi_client_session import MultiClientSession @flow.unittest.skip_unless_1n1d() class TestFeedVariableTensor(unittest.TestCase): def test_feed_var_tensor(test_case): x = flow.Tensor(1, 1, 10, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) session = session_ctx.GetDefaultSession() test_case.assertTrue(isinstance(session, MultiClientSession)) session.TryInit() with oneflow._oneflow_internal.lazy_mode.guard(True): oneflow._oneflow_internal.JobBuildAndInferCtx_Open( "cc_test_variable_op_expr_job" ) job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto() job_conf.job_name = "cc_test_variable_op_expr_job" job_conf.predict_conf.SetInParent() c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf) op_name = "cc_Variable_0" var_conf = oneflow.core.operator.op_conf_pb2.FeedVariableOpConf() var_conf.in_0 = "EagerTensorInput" var_conf.out_0 = "out_0" var_conf_str = text_format.MessageToString(var_conf) var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr( op_name, var_conf_str, ["in_0"], ["out_0"] ) out_tensor = _C.dispatch_feed_variable(var_op, x, l2=0) test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10)) test_case.assertTrue(out_tensor.is_lazy) test_case.assertTrue(out_tensor.is_local) oneflow._oneflow_internal.JobBuildAndInferCtx_Close() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/mock_example.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch print(torch.__file__.find("mock_torch") != -1) def f(): return torch.__package__ ================================================ FILE: python/oneflow/test/misc/test_autograd_functional.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from packaging import version import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import torch from oneflow.test_utils.automated_test_util import random_tensor from oneflow.test_utils.automated_test_util import autotest def _func_tensor(x): return x.exp().sum(dim=1) def _func_scalar(x): return x.exp().sum() def _func_multi_tensor(x, y): return (x.exp() + y.pow(2)).sum(dim=1) def _func_multi_scalar(x, y): return (x.exp() + y.pow(2)).sum() def _func_scalar2tensor(x): return (x, x ** 2, x ** 3) @flow.unittest.skip_unless_1n1d() class TestAutogradFunctional(flow.unittest.TestCase): @autotest(n=1, check_graph=False) def test_vjp(test_case): inputs = random_tensor(ndim=2, dim0=5, dim1=5) v = random_tensor(ndim=1, dim0=5) result_tensor = torch.autograd.functional.vjp(_func_tensor, inputs, v) result_scalar = torch.autograd.functional.vjp(_func_scalar, inputs) inputs = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) result_tensors = torch.autograd.functional.vjp(_func_multi_tensor, inputs, v) result_scalars = torch.autograd.functional.vjp(_func_multi_scalar, inputs) @autotest(n=1, check_graph=False) def test_jvp(test_case): inputs = random_tensor(ndim=2, dim0=5, dim1=5) v = random_tensor(ndim=2, dim0=5, dim1=5) result_tensor = torch.autograd.functional.jvp(_func_tensor, inputs, v) inputs = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) v = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) result_tensors = torch.autograd.functional.jvp(_func_multi_tensor, inputs, v) inputs = random_tensor(1) result_scalar2tensor = torch.autograd.functional.jvp( _func_scalar2tensor, inputs ) @autotest(n=1, check_graph=False) def test_vhp(test_case): inputs = random_tensor(ndim=2, dim0=5, dim1=5) v = random_tensor(ndim=2, dim0=5, dim1=5) result_tensor = torch.autograd.functional.vhp(_func_scalar, inputs, v) inputs = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) v = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) result_tensors = torch.autograd.functional.vhp(_func_multi_scalar, inputs, v) @autotest(n=1, check_graph=False) def test_hvp(test_case): inputs = random_tensor(ndim=2, dim0=5, dim1=5) v = random_tensor(ndim=2, dim0=5, dim1=5) result_tensor = torch.autograd.functional.hvp(_func_scalar, inputs, v) inputs = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) v = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) result_tensors = torch.autograd.functional.hvp(_func_multi_scalar, inputs, v) # TODO: "'jacobian' and 'hessian' has no strategy parameter in PyTorch before '1.11.0'" @autotest(n=1, check_graph=False) def test_jacobian(test_case): inputs = random_tensor(ndim=2, dim0=5, dim1=5) if version.parse(torch.pytorch.__version__) < version.parse("1.11.0"): result_tensor = torch.autograd.functional.jacobian( _func_tensor, inputs, vectorize=False ) else: result_tensor = torch.autograd.functional.jacobian( _func_tensor, inputs, vectorize=False, strategy="reverse-mode" ) inputs = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) if version.parse(torch.pytorch.__version__) < version.parse("1.11.0"): result_tensors = torch.autograd.functional.jacobian( _func_multi_scalar, inputs, vectorize=False ) else: result_tensors = torch.autograd.functional.jacobian( _func_multi_scalar, inputs, vectorize=False, 
strategy="reverse-mode" ) @autotest(n=1, check_graph=False) def test_hessian(test_case): inputs = random_tensor(ndim=2, dim0=5, dim1=5) if version.parse(torch.pytorch.__version__) < version.parse("1.11.0"): result_tensor = torch.autograd.functional.hessian( _func_scalar, inputs, vectorize=False, ) else: result_tensor = torch.autograd.functional.hessian( _func_scalar, inputs, vectorize=False, outer_jacobian_strategy="reverse-mode", ) inputs = ( random_tensor(ndim=2, dim0=5, dim1=5), random_tensor(ndim=2, dim0=5, dim1=5), ) if version.parse(torch.pytorch.__version__) < version.parse("1.11.0"): result_tensors = torch.autograd.functional.hessian( _func_multi_scalar, inputs, vectorize=False, ) else: result_tensors = torch.autograd.functional.hessian( _func_multi_scalar, inputs, vectorize=False, outer_jacobian_strategy="reverse-mode", ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_distributed_env_vars.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import oneflow as flow import oneflow.unittest class TestDistributedEnvVars(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_default(test_case): test_case.assertFalse("MASTER_ADDR" in os.environ) test_case.assertFalse("MASTER_PORT" in os.environ) test_case.assertFalse("WORLD_SIZE" in os.environ) test_case.assertFalse("RANK" in os.environ) test_case.assertFalse("LOCAL_RANK" in os.environ) test_case.assertEqual(flow.distributed.get_world_size(), 1) test_case.assertEqual(flow.distributed.get_rank(), 0) test_case.assertEqual(flow.distributed.get_local_rank(), 0) @flow.unittest.skip_unless_1n2d() def test_1n2d(test_case): test_case.assertEqual(os.environ["MASTER_ADDR"], "127.0.0.1") test_case.assertEqual(os.environ["WORLD_SIZE"], "2") test_case.assertTrue(os.environ["RANK"] in ["0", "1"]) test_case.assertTrue(os.environ["LOCAL_RANK"] in ["0", "1"]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_empty_cache.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestEmptyCache(flow.unittest.TestCase): def test_cuda_to_cpu_empty_cache(test_case): if flow._oneflow_internal.flags.with_cuda(): x = flow.randn(512, 3, 512, 512).to("cuda") used_mem1 = flow._oneflow_internal.GetCUDAMemoryUsed() x = x.cpu() used_mem2 = flow._oneflow_internal.GetCUDAMemoryUsed() flow.cuda.empty_cache() used_mem3 = flow._oneflow_internal.GetCUDAMemoryUsed() test_case.assertTrue((used_mem3 < used_mem1) and (used_mem3 < used_mem2)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_env_cuda.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import oneflow as flow from oneflow.test_utils.automated_test_util.generators import nothing, oneof, random from oneflow.test_utils.automated_test_util import torch import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestEnv(flow.unittest.TestCase): def test_get_device_count(test_case): test_case.assertEqual(flow.cuda.device_count(), 2) def test_current_device_idx(test_case): test_case.assertEqual(flow.cuda.current_device(), flow.env.get_rank()) def test_cuda_is_available(test_case): test_case.assertEqual(flow.cuda.is_available(), True) def test_cuda_synchronize(test_case): flow.cuda.synchronize() flow.cuda.synchronize("cuda") flow.cuda.synchronize("cuda:0") flow.cuda.synchronize("cuda:1") flow.cuda.synchronize(0) flow.cuda.synchronize(1) flow.cuda.synchronize(flow.device("cuda:0")) flow.cuda.synchronize(flow.device("cuda:1")) with test_case.assertRaisesRegex(ValueError, "Expected a cuda device, but"): flow.cuda.synchronize(flow.device("cpu")) with test_case.assertRaisesRegex(ValueError, "Expected a cuda device, but"): flow.cuda.synchronize("cpu") def test_cuda_get_device_name(test_case): return torch.cuda.get_device_name(oneof(0, nothing())) def test_cuda_get_device_capability(test_case): return torch.cuda.get_device_capability(oneof(0, nothing())) def test_cuda_mem_get_info(test_case): device_idx = random(0, flow.cuda.device_count()).to(int).value() return torch.cuda.mem_get_info(device_idx) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_manual_seed_api.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestManualSeedApi(flow.unittest.TestCase): def test_cuda_manual_seed_all(test_case): flow.cuda.manual_seed_all(20) x = flow.randn(2, 4, device="cuda:0") y = flow.randn(2, 4, device="cuda:1") test_case.assertTrue(np.allclose(x.numpy(), y.numpy())) def test_cuda_manual_seed(test_case): flow.cuda.manual_seed(30) device = flow.device("cuda", flow.cuda.current_device()) x = flow.randn(2, 4, device=device) tensor_list = [flow.zeros((2, 4), dtype=flow.int32) for _ in range(2)] flow.comm.all_gather(tensor_list, x) test_case.assertTrue( np.allclose(tensor_list[0].numpy(), tensor_list[1].numpy()) ) def test_manual_seed(test_case): flow.manual_seed(40) x = flow.randn(2, 4, device="cuda:0") y = flow.randn(2, 4, device="cuda:1") test_case.assertTrue(np.allclose(x.numpy(), y.numpy())) def test_set_get_rng_state(test_case): x = flow.ByteTensor(5000) flow.set_rng_state(x) y = flow.get_rng_state() test_case.assertTrue(np.allclose(x.numpy(), y.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_mock_diffusers.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest """ If some modules import torch internally, flow.mock_torch.disable() should be able to restore the original torch within these modules. """ class TestMock(flow.unittest.TestCase): def test_mock_diffusers(test_case): flow.mock_torch.enable(lazy=True) from diffusers import UNet2DConditionModel torch_module = UNet2DConditionModel.__dict__["forward"].__globals__["torch"] flow.mock_torch.disable() from diffusers import UNet2DConditionModel torch_module = UNet2DConditionModel.__dict__["forward"].__globals__["torch"] # check whether the torch module is the original torch test_case.assertFalse(isinstance(torch_module, flow.mock_torch.ModuleWrapper)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_mock_scope.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import oneflow as flow import oneflow.unittest import oneflow.mock_torch as mock """ The enable & disable modes each hold a dict[str, ModuleType] like sys.modules, whose keys start with 'torch'. The two modes don't interfere with each other; sys.modules and the global scope are swapped on each switch. """ with mock.enable(): import torch import torch.nn import torch.version with mock.disable(): import torch import torch.nn import torch.version @flow.unittest.skip_unless_1n1d() class TestMock(flow.unittest.TestCase): def test_with(test_case): with mock.enable(): test_case.assertEqual(torch.__package__, "oneflow") test_case.assertEqual(torch.nn.__package__, "oneflow.nn") test_case.assertEqual(torch.version.__version__, flow.__version__) with mock.disable(): test_case.assertEqual(torch.__package__, "torch") test_case.assertEqual(torch.nn.__package__, "torch.nn") test_case.assertEqual(torch.version.__version__, torch.__version__) def test_simple(test_case): mock.enable() test_case.assertEqual(torch.__package__, "oneflow") test_case.assertEqual(torch.nn.__package__, "oneflow.nn") test_case.assertEqual(torch.version.__version__, flow.__version__) mock.disable() test_case.assertEqual(torch.__package__, "torch") test_case.assertEqual(torch.nn.__package__, "torch.nn") test_case.assertEqual(torch.version.__version__, torch.__version__) def test_import_from(test_case): mock.enable() from torch import nn from torch.version import __version__ test_case.assertEqual(nn.__package__, "oneflow.nn") test_case.assertEqual(__version__, flow.__version__) mock.disable() from torch import nn from torch.version import __version__ test_case.assertEqual(nn.__package__, "torch.nn") test_case.assertEqual(__version__, torch.__version__) def test_error(test_case): mock.enable() with test_case.assertRaises(ImportError) as context: from torch import noexist test_case.assertTrue( "cannot import name 'noexist' from 'oneflow'" in str(context.exception) ) with test_case.assertRaises(ModuleNotFoundError) as context: import torch.noexist test_case.assertTrue( "oneflow.noexist is not implemented" in str(context.exception) ) mock.disable() with test_case.assertRaises(ImportError) as context: from torch import noexist test_case.assertTrue( "cannot import name 'noexist' from 'torch'" in str(context.exception) ) with test_case.assertRaises(ModuleNotFoundError) as context: import torch.noexist test_case.assertTrue( "No module named 'torch.noexist'" in str(context.exception) ) def test_nested_with(test_case): with mock.enable(): test_case.assertEqual(torch.__package__, "oneflow") with mock.disable(): test_case.assertEqual(torch.__package__, "torch") test_case.assertEqual(torch.__package__, "oneflow") with mock.disable(): test_case.assertEqual(torch.__package__, "torch") with mock.enable(): test_case.assertEqual(torch.__package__, "oneflow") test_case.assertEqual(torch.__package__, "torch") def test_noop_disable(test_case): with mock.disable(): import torch test_case.assertEqual(torch.__package__, "torch") @unittest.skip("skip for now, because it failed 2 times in the past week") def test_3rd_party(test_case): with mock.enable(): from mock_example
import f test_case.assertEqual(f(), "oneflow") def test_env_var(test_case): os.environ["ONEFLOW_DISABLE_MOCK_TORCH"] = "1" with mock.enable(): import torch test_case.assertEqual(torch.__package__, "torch") os.environ["ONEFLOW_DISABLE_MOCK_TORCH"] = "0" def test_dummy_obj_fallback(test_case): with mock.enable(lazy=True): from torch import not_exist test_case.assertEqual(not_exist.__name__, "oneflow.not_exist") x = not_exist.x test_case.assertEqual(x.__name__, "oneflow.not_exist.x") def test_mock_torchvision(test_case): with mock.enable(lazy=True): import torchvision model = torchvision.models.resnet18(pretrained=False) test_case.assertEqual(len(list(model.parameters())), 62) def test_mock_lazy_for_loop(test_case): with mock.enable(lazy=True): import torch # Test no infinite loop for _ in torch.not_exist: pass def test_mock_lazy_in_if(test_case): with mock.enable(lazy=True): import torch if torch.not_exist: test_case.assertTrue(False) def test_hazard_list(test_case): with mock.enable(): import sys import safetensors test_case.assertTrue("safetensors._safetensors_rust" in sys.modules) import safetensors def test_isinstance(test_case): with mock.enable(lazy=True): import torch test_case.assertFalse(isinstance(int, torch._six.string_class)) def test_with_statement(test_case): with mock.enable(lazy=True): with test_case.assertRaises(RuntimeError) as context: import torch.noexist with torch.noexist: pass test_case.assertTrue( '"oneflow.noexist" is a dummy object, and does not support "with" statement.' in str(context.exception) ) def test_setattr(test_case): with mock.enable(): import torch torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward test_case.assertEqual( torch.nn.Linear_forward_before_lora, torch.nn.Linear.forward ) def test_hasattr_and_getattr_in_lazy_mode(test_case): with mock.enable(lazy=True): test_case.assertFalse(hasattr(torch, "not_exist")) test_case.assertFalse(hasattr(torch.nn.functional, "not_exist")) test_case.assertTrue(isinstance(torch.not_exist, mock.DummyModule)) test_case.assertTrue( isinstance(torch.nn.functional.not_exist, mock.DummyModule) ) import torch.nn.functional as F test_case.assertFalse(hasattr(F, "scaled_dot_product_attention")) test_case.assertFalse( hasattr(torch.nn.functional, "scaled_dot_product_attention") ) def test_mock_extra_dict(test_case): with mock.enable(lazy=True, extra_dict={"torchvision": "flowvision"}): import torchvision test_case.assertEqual(torchvision.models.__package__, "flowvision.models") # MUST use pytest to run this test def test_verbose(capsys): with mock.enable(lazy=True, verbose=True): import torch.not_exist captured = capsys.readouterr() assert "oneflow.not_exist is not found in oneflow" in captured.out if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_np_dtype_converter.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestNpDtypeConverter(flow.unittest.TestCase): def test_np_dtype_converter(test_case): for flow_dtype in flow.dtypes(): if flow_dtype in [flow.record, flow.tensor_buffer, flow.bfloat16]: continue np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(flow_dtype) test_case.assertEqual( flow.framework.dtype.convert_numpy_dtype_to_oneflow_dtype(np_dtype), flow_dtype, ) # Test whether dtype conversion works with arr.dtype np_arr = np.array([1, 2], dtype=np_dtype) test_case.assertEqual(np_arr.dtype, np_dtype) flow_tensor = flow.tensor([1, 2], dtype=flow_dtype) test_case.assertEqual(flow_tensor.dtype, flow_dtype) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_placement.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import oneflow as flow import oneflow.unittest class TestPlacement(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_placement_all_cuda(test_case): placement = flow.placement.all("cuda") test_case.assertEqual(placement.type, "cuda") # assertEqual fails to compare lists test_case.assertTrue( list(placement.ranks) == list(range(flow.env.get_world_size())) ) @unittest.skip("skip for now, becase it failed 10 times in past week") def test_placement_all_cpu(test_case): placement = flow.placement.all("cpu") test_case.assertEqual(placement.type, "cpu") # assertEqual fails to compare lists test_case.assertTrue( list(placement.ranks) == list(range(flow.env.get_world_size())) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/misc/test_pybind11_caster.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestPybind11Caster(flow.unittest.TestCase): def test_optional(test_case): test_case.assertEqual( flow._oneflow_internal.test_api.increase_if_not_none(1), 2 ) test_case.assertEqual( flow._oneflow_internal.test_api.increase_if_not_none(None), None ) def test_maybe(test_case): test_case.assertEqual(flow._oneflow_internal.test_api.divide(6, 2), 3) def test_maybe_void(test_case): flow._oneflow_internal.test_api.throw_if_zero(1) def test_return_maybe_shared_ptr(test_case): a1 = flow._oneflow_internal.test_api.get_singleton_a() x1 = a1.get_x() a1.inc_x() a2 = flow._oneflow_internal.test_api.get_singleton_a() x2 = a2.get_x() test_case.assertEqual(id(a1), id(a2)) test_case.assertEqual(x1 + 1, x2) def test_pass_optional_shared_ptr(test_case): a1 = flow._oneflow_internal.test_api.get_singleton_a() x1 = a1.get_x() a1.inc_x() a2 = flow._oneflow_internal.test_api.increase_x_of_a_if_not_none(a1) x2 = a2.get_x() test_case.assertEqual(id(a1), id(a2)) test_case.assertEqual(x1 + 2, x2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/image_test_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import random import cv2 import numpy as np import PIL import oneflow as flow global_coco_dict = dict() default_coco_anno_file = flow.unittest.dataset_dir( "mscoco_2017/annotations/instances_val2017.json" ) default_coco_image_dir = flow.unittest.dataset_dir("mscoco_2017/val2017") def get_coco(anno_file): global global_coco_dict if anno_file not in global_coco_dict: from pycocotools.coco import COCO global_coco_dict[anno_file] = COCO(anno_file) return global_coco_dict[anno_file] def random_sample_images_from_coco( anno_file=default_coco_anno_file, image_dir=default_coco_image_dir, batch_size=2 ): image_files = [] image_ids = [] batch_group_id = -1 coco = get_coco(anno_file) img_ids = coco.getImgIds() while len(image_files) < batch_size: rand_img_id = random.choice(img_ids) img_h = coco.imgs[rand_img_id]["height"] img_w = coco.imgs[rand_img_id]["width"] group_id = int(img_h / img_w) if batch_group_id == -1: batch_group_id = group_id if group_id != batch_group_id: continue image_files.append(os.path.join(image_dir, coco.imgs[rand_img_id]["file_name"])) image_ids.append(rand_img_id) assert len(image_files) == len(image_ids) return (image_files, image_ids) def read_images_by_cv(image_files, dtype, channels=3): np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(dtype) images = [cv2.imread(image_file).astype(np_dtype) for image_file in image_files] assert all((isinstance(image, np.ndarray) for image in images)) assert all((image.ndim == 3 for image in images)) assert all((image.shape[2] == channels for image in images)) return images def read_images_by_pil(image_files, dtype, channels=3): image_objs = [PIL.Image.open(image_file) for image_file in image_files] images = [] np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(dtype) for im in image_objs: bands = im.getbands() band = "".join(bands) if band == "RGB": images.append(np.asarray(im).astype(np_dtype)[:, :, ::-1]) elif band == "L": gs_image = np.asarray(im).astype(np_dtype) gs_image_shape = gs_image.shape assert len(gs_image_shape) == 2 gs_image = gs_image.reshape(gs_image_shape + (1,)) gs_image = np.broadcast_to(gs_image, shape=gs_image_shape + (3,)) images.append(gs_image) elif band == "BGR": images.append(np.asarray(im).astype(np_dtype)) else: raise NotImplementedError assert all((isinstance(image, np.ndarray) for image in images)) assert all((image.ndim == 3 for image in images)) assert all((image.shape[2] == channels for image in images)) return images def infer_images_static_shape(images, channels=3): image_shapes = [image.shape for image in images] assert all((image.ndim == 3 for image in images)) assert all((image.shape[2] == channels for image in images)) image_shapes = np.asarray(image_shapes) max_h = np.max(image_shapes[:, 0]).item() max_w = np.max(image_shapes[:, 1]).item() image_static_shape = (len(images), max_h, max_w, channels) group_ids = [] aspect_ratio_list = [] for image_shape in image_shapes: (h, w) = image_shape[0:2] if h < w: group_id = 0 aspect_ratio = h / w else: group_id = 1 aspect_ratio = w / h group_ids.append(group_id) aspect_ratio_list.append(aspect_ratio) assert all((group_id == group_ids[0] for group_id in group_ids)) return (image_static_shape, aspect_ratio_list) def compute_keep_aspect_ratio_resized_size( target_size, min_size, max_size, aspect_ratio, resize_side ): if resize_side == "shorter": min_res_size = target_size max_res_size = int(round(min_res_size / aspect_ratio)) if max_size is not None and max_res_size > max_size: max_res_size = max_size min_res_size = int(round(max_res_size * 
aspect_ratio)) elif resize_side == "longer": max_res_size = target_size min_res_size = int(round(max_res_size * aspect_ratio)) if min_size is not None and min_res_size < min_size: min_res_size = min_size max_res_size = int(round(min_res_size / aspect_ratio)) else: raise NotImplementedError return (min_res_size, max_res_size) def infer_keep_aspect_ratio_resized_images_static_shape( target_size, min_size, max_size, aspect_ratio_list, resize_side="shorter", channels=3, ): resized_size_list = [] for aspect_ratio in aspect_ratio_list: resized_size_list.append( compute_keep_aspect_ratio_resized_size( target_size, min_size, max_size, aspect_ratio, resize_side ) ) (res_min_size, res_max_size) = max( resized_size_list, key=lambda size: size[0] * size[1] ) return (res_min_size, res_max_size, channels) ================================================ FILE: python/oneflow/test/modules/optimizer_test_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import numpy as np def clip_grad_norm_np(np_grad, max_norm, norm_type): np_grad_is_list = True if isinstance(np_grad, np.ndarray): np_grad_is_list = False np_grad = [np_grad] max_norm = float(max_norm) norm_type = float(norm_type) if norm_type == float("inf"): total_norm = np.max(np.abs(np_grad)) elif norm_type == float("-inf"): total_norm = np.min(np.abs(np_grad)) else: norms = np_grad total_norm = [] for i, norm in enumerate(norms): for j in range(np_grad[i].ndim, 0, -1): norm = np.linalg.norm(norm, norm_type, axis=j - 1) total_norm.append(norm) total_norm = np.linalg.norm(np.array(total_norm, dtype=np.float32), norm_type) clip_coef = max_norm / (total_norm + 1e-6) if clip_coef < 1: for grad in np_grad: grad *= clip_coef if not np_grad_is_list: np_grad = np_grad[0] return total_norm, np_grad ================================================ FILE: python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d/tensor_3/meta ================================================ shape { dim: 3 dim: 3 dim: 3 dim: 3 } data_type: kFloat ================================================ FILE: python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d/tensor_4/meta ================================================ shape { dim: 3 } data_type: kFloat ================================================ FILE: python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d/tensor_4/out ================================================ w)I ================================================ FILE: python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d_params/tensor_5/meta ================================================ shape { dim: 3 dim: 3 dim: 3 dim: 3 } data_type: kFloat ================================================ FILE: python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d_params/tensor_6/meta ================================================ shape { dim: 3 } data_type: kFloat ================================================ FILE: 
python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d_params/tensor_6/out ================================================ w)I ================================================ FILE: python/oneflow/test/modules/sync_batchnorm_test_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow as flow ONEREC_URL = ( "https://oneflow-public.oss-cn-beijing.aliyuncs.com/sync_bn_test_datas.tar.gz" ) MD5 = "537ff00fb47be8be90df75f47a883b76" def md5(fname): import hashlib hash_md5 = hashlib.md5() with open(fname, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) result = hash_md5.hexdigest() print("md5", fname, result) return result def download_file(out_path: str, url): import requests from tqdm import tqdm resp = requests.get(url=url, stream=True) MB = 1024 ** 2 size = int(resp.headers["Content-Length"]) / MB print("File size: %.4f MB, downloading..." % size) with open(out_path, "wb") as f: for data in tqdm( iterable=resp.iter_content(MB), total=size, unit="m", desc=out_path ): f.write(data) print("Done!") def ensure_datas(): import os import pathlib data_dir = os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "sync_bn" ) file_path = pathlib.Path(data_dir) / ONEREC_URL.split("/")[-1] absolute_file_path = str(file_path.absolute()) if flow.env.get_rank() == 0: file_path.parent.mkdir(parents=True, exist_ok=True) if file_path.exists(): if MD5 != md5(absolute_file_path): file_path.unlink() download_file(absolute_file_path, ONEREC_URL) else: download_file(str(absolute_file_path), ONEREC_URL) assert MD5 == md5(absolute_file_path) import tarfile my_tar = tarfile.open(str(absolute_file_path)) my_tar.extractall(data_dir) # specify which folder to extract to my_tar.close() flow.comm.barrier() return os.path.join( os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "sync_bn", "sync_bn_test_datas", ) ================================================ FILE: python/oneflow/test/modules/test_0_dim_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * def _test_0_dim_tensor(test_case, device): scalar = 9.999 input_np = np.array(scalar) input = flow.tensor(input_np, device=device) test_case.assertEqual(input.numel(), 1) test_case.assertEqual(input.ndimension(), 0) x1 = flow.tensor(np.array(2), dtype=flow.float32, device=device) x2 = flow.tensor(np.array(3), dtype=flow.float32, device=device) y1 = x1 * x2 y2 = x1 + x2 test_case.assertEqual(y1.numpy(), 6.0) test_case.assertEqual(y2.numpy(), 5.0) def _test_scalar_mul(test_case, device): for dim in range(5): test_case.assertEqual( np.ones([2] * dim).sum(), flow.ones([2] * dim, device=device).sum().numpy() ) def _test_slice(test_case, device): x = flow.tensor(np.arange(10), device=device) for i in range(x.numel()): scalar_i = x[i] test_case.assertEqual(i, scalar_i.numpy()) test_case.assertEqual(scalar_i.numel(), 1) test_case.assertEqual(scalar_i.ndimension(), 0) def _test_slice_backward(test_case, device): np_grad = np.zeros(10) x = flow.tensor(np.arange(10).astype(np.float32), device=device, requires_grad=True) for i in range(x.numel()): y = x[i] z = y.sum() z.backward() np_grad[i] = 1 test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-04, 1e-04)) x2 = flow.tensor( np.arange(100).astype(np.float32), device=device, requires_grad=True ) y2 = x2[1:100] z2 = y2.sum() z2.backward() np_grad2 = np.ones(100) np_grad2[0] = 0 test_case.assertTrue(np.allclose(x2.grad.numpy(), np_grad2, 1e-04, 1e-04)) def _test_slice_scalar_graph(test_case, device): x = flow.tensor(3.0, device=device) class MyModule(flow.nn.Module): def __init__(self): super().__init__() self.weight = flow.nn.Parameter( flow.tensor([1.0, 2.0, 3.0, 4.0], device=device) ) def forward(self, x): return x * self.weight[3] my_module = MyModule() of_eager_out = my_module(x) class ScalarGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = my_module def build(self, x): return self.m(x) scalar_g = ScalarGraph() of_lazy_out = scalar_g(x) test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())) def _test_slice_scalar_train_graph(test_case, device): class MyModule(flow.nn.Module): def __init__(self): super().__init__() self.weight = flow.nn.Parameter( flow.tensor([1.0, 2.0, 3.0, 4.0], device=device) ) def forward(self, x): return x * self.weight[3] + 1.0 my_module = MyModule() of_sgd = flow.optim.SGD(my_module.parameters(), lr=0.001, momentum=0.9) eager_out_list = [] for i in range(3): x = flow.tensor(i * 1.0, device=device, requires_grad=False) of_eager_out = my_module(x) of_eager_out.backward() of_sgd.step() of_sgd.zero_grad() eager_out_list.append(of_eager_out) lazy_module = MyModule() class ScalarTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = lazy_module of_sgd = flow.optim.SGD(lazy_module.parameters(), lr=0.001, momentum=0.9) self.add_optimizer(of_sgd) def build(self, x): loss = self.m(x) loss.backward() return loss lazy_out_list = [] scalar_g = ScalarTrainGraph() for i in range(3): x = flow.tensor(i * 1.0, device=device) of_lazy_out = scalar_g(x) lazy_out_list.append(of_lazy_out) for i in range(3): test_case.assertTrue( np.array_equal(lazy_out_list[i].numpy(), eager_out_list[i].numpy()) ) @flow.unittest.skip_unless_1n1d() class TestZeroDimensionTensor(flow.unittest.TestCase): def test_0_dim_tensor(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ 
_test_0_dim_tensor, _test_scalar_mul, _test_slice, _test_slice_backward, _test_slice_scalar_graph, _test_slice_scalar_train_graph, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_TripletMarginLoss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow @flow.unittest.skip_unless_1n1d() class TestTripletMarginLoss(flow.unittest.TestCase): @unittest.skip("skip for now, because it failed 2 times in the past week") @autotest(n=10) def test_triplet_marginloss_with_random_data(test_case): margin = random().to(float) p = random().to(float) swap = random_bool() reduction = oneof("none", "sum", "mean", nothing()) m = torch.nn.TripletMarginLoss( margin=margin, p=p, swap=swap, reduction=reduction ) m.train(random()) device = random_device() m.to(device) shape = random_tensor(ndim=2, dim0=random(1, 8)).pytorch.shape anchor = random_tensor(len(shape), *shape).to(device) pos = random_tensor(len(shape), *shape).to(device) neg = random_tensor(len(shape), *shape).to(device) y = m(anchor, pos, neg) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_abs.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestAbsModule(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_abs_with_0_size_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.abs(x) return y @autotest(n=5, check_graph=True) def test_abs_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.abs(x) return y @profile(torch.abs) def profile_abs(test_case): torch.abs(torch.ones(1, 128, 28, 28)) torch.abs(torch.ones(16, 128, 28, 28)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_activation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import torch as pytorch from oneflow.test_utils.automated_test_util import * from scipy import special from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestReLUModule(flow.unittest.TestCase): @autotest(n=5) def test_relu_module_with_random_data(test_case): m = torch.nn.ReLU() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_relu_module_with_0dim_data(test_case): m = torch.nn.ReLU() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5, auto_backward=False) def test_relu_module_with_0_size_data(test_case): m = torch.nn.ReLU() m.train(random()) device = random_device() m.to(device) x = random_tensor(4, 2, 3, 0, 3).to(device) y = m(x) return y @profile(torch.nn.functional.relu) def profile_relu(test_case): torch.nn.functional.relu(torch.ones(1, 128, 28, 28)) torch.nn.functional.relu(torch.ones(1, 128, 28, 28), inplace=True) torch.nn.functional.relu(torch.ones(16, 128, 28, 28)) torch.nn.functional.relu(torch.ones(16, 128, 28, 28), inplace=True) @flow.unittest.skip_unless_1n1d() class TestReLU6Module(flow.unittest.TestCase): @autotest(n=5) def test_relu6_module_with_random_data(test_case): m = torch.nn.ReLU6() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_relu6_module_with_0dim_data(test_case): m = torch.nn.ReLU6() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5, auto_backward=False) def test_relu6_module_with_0_size_data(test_case): m = torch.nn.ReLU6() m.train(random()) device = random_device() m.to(device) x = random_tensor(4, 2, 3, 0, 3).to(device) y = m(x) return y @flow.unittest.skip_unless_1n1d() class TestTanh(flow.unittest.TestCase): @autotest(n=5) def test_tanh_module_with_random_data(test_case): m = torch.nn.Tanh() 
m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_tanh_module_with_0dim_data(test_case): m = torch.nn.Tanh() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5, auto_backward=False) def test_tanh_module_with_0_size_data(test_case): m = torch.nn.Tanh() m.train(random()) device = random_device() m.to(device) x = random_tensor(4, 2, 3, 0, 3).to(device) y = m(x) return y @autotest(n=5) def test_flow_tanh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.tanh(x) return y @autotest(n=5) def test_flow_tanh_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.tanh(x) return y @autotest(n=5, auto_backward=False) def test_flow_tanh_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 3, 0, 3).to(device) y = torch.tanh(x) return y @profile(torch.nn.functional.tanh) def profile_tanh(test_case): torch.nn.functional.tanh(torch.ones(1, 128, 28, 28)) torch.nn.functional.tanh(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestELUModule(flow.unittest.TestCase): @autotest(n=5) def test_elu_module_with_random_data(test_case): m = torch.nn.ELU(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_elu_module_with_0dim_data(test_case): m = torch.nn.ELU(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5, auto_backward=False) def test_elu_module_with_0_size_data(test_case): m = torch.nn.ELU(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(4, 2, 3, 0, 3).to(device) y = m(x) return y @profile(torch.nn.functional.elu) def profile_elu(test_case): torch.nn.functional.elu(torch.ones(1, 128, 28, 28), 1.0) torch.nn.functional.elu(torch.ones(16, 128, 28, 28), 1.0) @flow.unittest.skip_unless_1n1d() class TestCELUModule(flow.unittest.TestCase): @autotest(n=5) def test_celu_module_with_random_data(test_case): m = torch.nn.CELU(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_celu_module_with_0dim_data(test_case): m = torch.nn.CELU(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5, auto_backward=False) def test_celu_module_with_0_size_data(test_case): m = torch.nn.CELU(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(4, 2, 3, 0, 3).to(device) y = m(x) return y @autotest(n=10) def test_inplace_celu_module(test_case): m = torch.nn.CELU(alpha=random() | nothing(), inplace=True) device = random_device() m.to(device) x = random_tensor().to(device) y = x + 0.001 m(y) return y @profile(torch.nn.functional.celu) def profile_celu(test_case): torch.nn.functional.celu(torch.ones(1, 128, 28, 28)) torch.nn.functional.celu(torch.ones(1, 128, 28, 28), inplace=True) torch.nn.functional.celu(torch.ones(16, 128, 28, 28)) torch.nn.functional.celu(torch.ones(16, 128, 28, 28), inplace=True) @flow.unittest.skip_unless_1n1d() class TestGelu(flow.unittest.TestCase): @autotest(n=5) def test_gelu_module_with_random_data(test_case): m = torch.nn.GELU() 
m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_gelu_module_with_0dim_data(test_case): m = torch.nn.GELU() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @profile(torch.nn.functional.gelu) def profile_gelu(test_case): torch.nn.functional.gelu(torch.ones(1, 128, 28, 28)) torch.nn.functional.gelu(torch.ones(16, 128, 28, 28)) @unittest.skipIf( float(pytorch.__version__[:4]) < 1.12, f"need pytorch version >= 1.12, got {pytorch.__version__}", ) @flow.unittest.skip_unless_1n1d() class TestFastGelu(flow.unittest.TestCase): @autotest(n=5) def test_fast_gelu(test_case): device = random_device() x = random_tensor().to(device) y = torch.nn.functional.gelu(x, approximate="tanh") return y @autotest(n=5, atol=1e-2, rtol=1e-2) def test_fast_gelu_fp16(test_case): x = random_tensor().to(device=gpu_device(), dtype=torch.float16) y = torch.nn.functional.gelu(x, approximate="tanh") return y @autotest(n=5) def test_fast_gelu_scalar(test_case): x = random_tensor(ndim=0).to(device=random_device()) y = torch.nn.functional.gelu(x, approximate="tanh") return y @profile(torch.nn.functional.gelu) def profile_fast_gelu(test_case): torch.nn.functional.gelu(torch.ones(1, 128, 28, 28), approximate="tanh") torch.nn.functional.gelu(torch.ones(16, 128, 28, 28), approximate="tanh") @flow.unittest.skip_unless_1n1d() class TestSigmoidModule(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5, atol=1e-3, check_dtype=True) def test_sigmoid_flow_with_half_data(test_case): device = gpu_device() x = random_tensor().to(device=device, dtype=torch.float16) y = torch.sigmoid(x) return y @autotest(n=5) def test_sigmoid_module_with_random_data(test_case): m = torch.nn.Sigmoid() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_sigmoid_module_with_0dim_data(test_case): m = torch.nn.Sigmoid() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5) def test_sigmoid_flow_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.sigmoid(x) return y @autotest(n=5) def test_sigmoid_flow_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.sigmoid(x) return y @autotest(n=5) def test_sigmoid_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.sigmoid() return y @autotest(n=5) def test_sigmoid_tensor_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = x.sigmoid() return y @profile(torch.sigmoid) def profile_sigmoid(test_case): torch.sigmoid(torch.ones(1, 128, 28, 28)) torch.sigmoid(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestHardsigmoidModule(flow.unittest.TestCase): def test_hardsigmoid_inplace(test_case): def np_hardsigmoid(input): input_shape = input.shape input = input.flatten() elem_cnt = input.size _zero = np.zeros_like(input) for i in range(elem_cnt): if input[i] >= 3: _zero[i] = 1 elif input[i] <= -3: _zero[i] = 0 else: _zero[i] = input[i] / 6 + 0.5 np_hsigmoid_out = np.reshape(_zero, newshape=input_shape) return np.array(np_hsigmoid_out) def test_hardsigmoid_inplace_impl(test_case, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), 
requires_grad=True, ) x_inplace = x + 1 np_out = np_hardsigmoid(x_inplace.numpy()) id_old = id(x_inplace) y_inplace = flow.nn.functional.hardsigmoid(x_inplace, inplace=True) test_case.assertEqual(id_old, id(y_inplace)) test_case.assertTrue(np.allclose(y_inplace.numpy(), np_out, 1e-5, 1e-5)) arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): test_hardsigmoid_inplace_impl(test_case, *arg) @autotest(n=5) def test_hardsigmoid_module_with_random_data(test_case): m = torch.nn.Hardsigmoid() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_hardsigmoid_module_with_0dim_data(test_case): m = torch.nn.Hardsigmoid() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5) def test_functional_hardsigmoid_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.nn.functional.hardsigmoid(x, random_bool()) return y @autotest(n=5) def test_functional_hardsigmoid_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.nn.functional.hardsigmoid(x, random_bool()) return y @profile(torch.nn.functional.hardsigmoid) def profile_hardsigmoid(test_case): torch.nn.functional.hardsigmoid(torch.ones(1, 128, 28, 28)) torch.nn.functional.hardsigmoid(torch.ones(1, 128, 28, 28), inplace=True) torch.nn.functional.hardsigmoid(torch.ones(16, 128, 28, 28)) torch.nn.functional.hardsigmoid(torch.ones(16, 128, 28, 28), inplace=True) def do_test_softmax(batch_size: int, log_softmax: bool = False): num_dims = random(low=1, high=5).to(int) m = torch.nn.Softmax(dim=random(low=0, high=num_dims).to(int) | nothing()) if log_softmax: m = torch.nn.LogSoftmax(dim=random(low=0, high=num_dims).to(int) | nothing()) m.train(random()) device = random_device() m.to(device) x = ( random_tensor(ndim=num_dims).to(device) if batch_size < 0 else random_tensor(ndim=num_dims, dim0=batch_size).to(device) ) y = m(x) return y @flow.unittest.skip_unless_1n1d() class TestSoftmax(flow.unittest.TestCase): @autotest(n=5) def test_softmax_module_with_random_data(test_case): return do_test_softmax(batch_size=-1, log_softmax=False) @autotest(n=5) def test_softmax_module_with_batch_size_equal_1024(test_case): return do_test_softmax(batch_size=1024, log_softmax=False) @autotest(n=5, check_graph=True) def test_softmax_module_with_batch_size_equal_5120(test_case): return do_test_softmax(batch_size=5120, log_softmax=False) @autotest(n=2, check_graph=True) def test_softmax_module_with_batch_size_equal_10240(test_case): return do_test_softmax(batch_size=10240, log_softmax=False) @profile(torch.nn.functional.softmax) def profile_softmax(test_case): torch.nn.functional.softmax(torch.ones(1, 128, 28, 28)) torch.nn.functional.softmax(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestLogSoftmaxModule(flow.unittest.TestCase): @autotest(n=5) def test_logsoftmax_module_with_random_data(test_case): return do_test_softmax(batch_size=-1, log_softmax=True) @autotest(n=5) def test_softmax_module_with_batch_size_equal_1024(test_case): return do_test_softmax(batch_size=1024, log_softmax=True) @autotest(n=5, check_graph=True) def test_softmax_module_with_batch_size_equal_5120(test_case): return do_test_softmax(batch_size=5120, log_softmax=True) @autotest(n=2, check_graph=True) def test_softmax_module_with_batch_size_equal_10240(test_case): 
return do_test_softmax(batch_size=10240, log_softmax=True) @profile(torch.nn.functional.log_softmax) def profile_logsoftmax(test_case): torch.nn.functional.log_softmax(torch.ones(1, 128, 28, 28)) torch.nn.functional.log_softmax(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestLogSigmoidModule(flow.unittest.TestCase): @autotest(n=5) def test_logsigmoid_module_with_random_data(test_case): m = torch.nn.LogSigmoid() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_logsigmoid_module_with_0dim_data(test_case): m = torch.nn.LogSigmoid() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @profile(torch.nn.functional.logsigmoid) def profile_logsigmoid(test_case): torch.nn.functional.logsigmoid(torch.ones(1, 128, 28, 28)) torch.nn.functional.logsigmoid(torch.ones(16, 128, 28, 28)) def numpy_softplus(x, beta, threshold): return np.where( x * beta > threshold, x, 1.0 / beta * np.log(1.0 + np.exp(beta * x)) ) def _test_softplus(test_case, device): m = flow.nn.Softplus() arr = np.random.randn(2, 3, 4, 5) np_out = numpy_softplus(arr, 1.0, 20) x = flow.tensor(arr, device=flow.device(device)) of_out = m(x) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_softplus_beta(test_case, device): m = flow.nn.Softplus(beta=1.11) arr = np.random.randn(2, 3, 4, 5) np_out = numpy_softplus(arr, 1.11, 20) x = flow.tensor(arr, device=flow.device(device)) of_out = m(x) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_softplus_threshold(test_case, device): m = flow.nn.Softplus(beta=1.11, threshold=1.55) arr = np.random.randn(2, 3, 4, 5) np_out = np.where( arr * 1.11 > 1.55, arr, 1.0 / 1.11 * np.log(1.0 + np.exp(1.11 * arr)) ) np_out = numpy_softplus(arr, 1.11, 1.55) x = flow.tensor(arr, device=flow.device(device)) of_out = m(x) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_softplus_backward(test_case, device): m = flow.nn.Softplus() arr = np.array([1.0, 2.0, 21.0, 20.0, 4.0]) x = flow.tensor(arr, device=flow.device(device), requires_grad=True) of_out = m(x) of_out = of_out.sum() of_out.backward() np_grad = [0.7310585786300049, 0.8807970779778824, 1.0, 1.0, 0.9820137900379085] test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestSoftplusModule(flow.unittest.TestCase): def test_softplus(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_softplus, _test_softplus_beta, _test_softplus_threshold, _test_softplus_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skip("pytorch softplus backward has bug") @autotest(n=5) def test_softplus_module_with_random_data(test_case): m = torch.nn.Softplus(beta=random() | nothing(), threshold=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @profile(torch.nn.functional.softplus) def profile_softplus(test_case): torch.nn.functional.softplus(torch.ones(1, 128, 28, 28)) torch.nn.functional.softplus(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestHardswishModule(flow.unittest.TestCase): @autotest(n=5) def test_hardswish_module_with_random_data(test_case): m = torch.nn.Hardswish() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y 
@autotest(n=5) def test_hardswish_module_with_0dim_data(test_case): m = torch.nn.Hardswish() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @profile(torch.nn.functional.hardswish) def profile_hardswish(test_case): torch.nn.functional.hardswish(torch.ones(1, 128, 28, 28)) torch.nn.functional.hardswish(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestHardtanhModule(flow.unittest.TestCase): @autotest(n=5) def test_hardtanh_module_with_random_data(test_case): m = torch.nn.Hardtanh( min_val=random().to(float) | nothing(), max_val=random().to(float) | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4).to(device) y = m(x) return y @autotest(n=5) def test_hardtanh_module_with_0dim_data(test_case): m = torch.nn.Hardtanh( min_val=random().to(float) | nothing(), max_val=random().to(float) | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @profile(torch.nn.functional.hardtanh) def profile_hardtanh(test_case): torch.nn.functional.hardtanh( torch.ones(1, 128, 28, 28), min_val=-1.0, max_val=1.0 ) torch.nn.functional.hardtanh( torch.ones(16, 128, 28, 28), min_val=-1.0, max_val=1.0 ) @flow.unittest.skip_unless_1n1d() class TestLeakyReLUModule(flow.unittest.TestCase): @autotest(n=5) def test_leakyrelu_module_with_random_data(test_case): m = torch.nn.LeakyReLU(negative_slope=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_leakyrelu_module_with_inplace_arg(test_case): m = torch.nn.LeakyReLU( negative_slope=random() | nothing(), inplace=random().to(bool) | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_leakyrelu_module_with_0dim_data(test_case): m = torch.nn.LeakyReLU(negative_slope=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @profile(torch.nn.functional.leaky_relu) def profile_leaky_relu(test_case): torch.nn.functional.leaky_relu(torch.ones(1, 128, 28, 28), 0.1) torch.nn.functional.leaky_relu(torch.ones(1, 128, 28, 28), 0.1, inplace=True) torch.nn.functional.leaky_relu(torch.ones(16, 128, 28, 28), 0.1) torch.nn.functional.leaky_relu(torch.ones(16, 128, 28, 28), 0.1, inplace=True) @flow.unittest.skip_unless_1n1d() class TestMishModule(flow.unittest.TestCase): @autotest(n=5) def test_mish_module_with_random_data(test_case): m = torch.nn.Mish() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_mish_module_with_0dim_data(test_case): m = torch.nn.Mish() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @profile(torch.nn.functional.mish) def profile_mish(test_case): torch.nn.functional.mish(torch.ones(1, 128, 28, 28)) torch.nn.functional.mish(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestSiluModule(flow.unittest.TestCase): @autotest(n=5) def test_silu_module_with_random_data(test_case): m = torch.nn.SiLU() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_silu_module_with_0dim_data(test_case): m = torch.nn.SiLU() m.train(random()) device = random_device() m.to(device) x = 
random_tensor(ndim=0).to(device) y = m(x) return y @profile(torch.nn.functional.silu) def profile_silu(test_case): torch.nn.functional.silu(torch.ones(1, 128, 28, 28)) torch.nn.functional.silu(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestSeluModule(flow.unittest.TestCase): @autotest(n=5) def test_selu_module_with_random_data(test_case): m = torch.nn.SELU() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_selu_module_with_0dim_data(test_case): m = torch.nn.SELU() m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @profile(torch.nn.functional.selu) def profile_selu(test_case): torch.nn.functional.selu(torch.ones(1, 128, 28, 28)) torch.nn.functional.selu(torch.ones(16, 128, 28, 28)) @unittest.skip("still have error in ci test") class TestSoftsignModule(flow.unittest.TestCase): @autotest(n=5) def test_softsign_module_with_random_data(test_case): m = torch.nn.Softsign() m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y #'Ran 1 test in 0.000s',return a blank table @profile(torch.nn.functional.softsign) def profile_softsign(test_case): torch.nn.functional.softsign(torch.ones(1, 128, 28, 28)) torch.nn.functional.softsign(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestReluFunction(flow.unittest.TestCase): @autotest(n=5) def test_flow_relu_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=3).to(device) y = torch.relu(x) return y @autotest(n=5) def test_flow_relu_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.relu(x) return y @flow.unittest.skip_unless_1n1d() class TestRelu6Function(flow.unittest.TestCase): @autotest(n=5) def test_flow_nn_functional_relu6_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=3).to(device) y = torch.nn.functional.relu6(x) return y @autotest(n=5) def test_flow_nn_functional_relu6_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.nn.functional.relu6(x) return y @profile(torch.nn.functional.relu6) def profile_relu6(test_case): torch.nn.functional.relu6(torch.ones(1, 128, 28, 28)) torch.nn.functional.relu6(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestLogSigmoidFunction(flow.unittest.TestCase): @autotest(n=5) def test_flow_nn_functional_logsigmoid_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=3).to(device) y = torch.nn.functional.logsigmoid(x) return y @autotest(n=5) def test_flow_nn_functional_logsigmoid_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.nn.functional.logsigmoid(x) return y @flow.unittest.skip_unless_1n1d() class TestHardshrinkModule(flow.unittest.TestCase): @autotest(n=5) def test_hardshrink_module_with_random_data(test_case): m = torch.nn.Hardshrink(lambd=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_hardshrink_module_with_0dim_data(test_case): m = torch.nn.Hardshrink(lambd=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5, auto_backward=False) def test_hardshrink_module_with_0_size_data(test_case): m = 
torch.nn.Hardshrink(lambd=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(4, 2, 3, 0, 3).to(device) y = m(x) return y @profile(torch.nn.functional.hardshrink) def profile_hardshrink(test_case): torch.nn.functional.hardshrink(torch.ones(1, 128, 28, 28)) torch.nn.functional.hardshrink(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestSoftshrinkModule(flow.unittest.TestCase): @autotest(n=5) def test_softshrink_module_with_random_data(test_case): m = torch.nn.Softshrink(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_softshrink_module_with_0dim_data(test_case): m = torch.nn.Softshrink(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5, auto_backward=False) def test_softshrink_module_with_0_size_data(test_case): m = torch.nn.Softshrink(alpha=random() | nothing()) m.train(random()) device = random_device() m.to(device) x = random_tensor(4, 2, 3, 0, 3).to(device) y = m(x) return y @profile(torch.nn.functional.softshrink) def profile_softshrink(test_case): torch.nn.functional.softshrink(torch.ones(1, 128, 28, 28)) torch.nn.functional.softshrink(torch.ones(16, 128, 28, 28)) @flow.unittest.skip_unless_1n1d() class TestThresholdModule(flow.unittest.TestCase): @autotest(n=5) def test_threshold_module_with_random_data(test_case): m = torch.nn.Threshold( threshold=random() | nothing(), value=random() | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_threshold_module_with_0dim_data(test_case): m = torch.nn.Threshold( threshold=random() | nothing(), value=random() | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=0).to(device) y = m(x) return y @autotest(n=5, auto_backward=False) def test_threshold_module_with_0_size_data(test_case): m = torch.nn.Threshold( threshold=random() | nothing(), value=random() | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_tensor(4, 2, 3, 0, 3).to(device) y = m(x) return y @profile(torch.nn.functional.threshold) def profile_threshold(test_case): torch.nn.functional.threshold( torch.ones(1, 128, 28, 28), threshold=0.1, value=20 ) torch.nn.functional.threshold( torch.ones(16, 128, 28, 28), threshold=0.1, value=20 ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_adaptive_max_pool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.nn.common_types import _size_1_t from packaging import version import torch as torch_original from typing import Union, Tuple import numpy as np from oneflow.test_utils.automated_test_util import * NoneType = type(None) @flow.unittest.skip_unless_1n1d() class TestAdaptiveMaxPool(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") @autotest(n=5) def test_adaptive_maxpool1d(test_case): m = torch.nn.AdaptiveMaxPool1d(output_size=random().to(_size_1_t)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3).to(device) y = m(x) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_adaptive_maxpool2d_manually(test_case): def _test_adaptive_max_pool_nd(input_shape, output_shape, m1, m2): input_np = np.random.rand(2, 3, *input_shape) input_pt = torch_original.tensor( input_np, device="cuda", requires_grad=True ) input_of = flow.tensor(input_np, device="cuda", requires_grad=True) m_pt = m1(output_shape, True) m_of = m2(output_shape, True) output_pt = m_pt(input_pt) output_of = m_of(input_of) sum_pt = torch_original.sum(output_pt[0]) sum_of = flow.sum(output_of[0]) sum_pt.backward() sum_of.backward() test_case.assertTrue( np.array_equal( output_pt[0].detach().cpu().numpy(), output_of[0].detach().cpu().numpy(), ) ) test_case.assertTrue( np.array_equal( output_pt[1].detach().cpu().numpy(), output_of[1].detach().cpu().numpy(), ) ) test_case.assertTrue( np.array_equal(input_pt.grad.cpu().numpy(), input_of.grad.cpu().numpy()) ) _test_adaptive_max_pool_nd( (10, 11), (3, 4), torch_original.nn.AdaptiveMaxPool2d, flow.nn.AdaptiveMaxPool2d, ) _test_adaptive_max_pool_nd( (10, 11, 12), (3, 4, 5), torch_original.nn.AdaptiveMaxPool3d, flow.nn.AdaptiveMaxPool3d, ) @profile(torch.nn.functional.adaptive_max_pool1d) def profile_adaptive_max_pool1d(test_case): torch.nn.functional.adaptive_max_pool1d(torch.ones(1, 64, 8), 5) @profile(torch.nn.functional.adaptive_max_pool2d) def profile_adaptive_max_pool2d(test_case): torch.nn.functional.adaptive_max_pool2d(torch.ones(1, 64, 10, 9), 7) torch.nn.functional.adaptive_max_pool2d(torch.ones(1, 64, 8, 9), (5, 7)) @profile(torch.nn.functional.adaptive_max_pool3d) def profile_adaptive_max_pool3d(test_case): torch.nn.functional.adaptive_max_pool3d(torch.ones(1, 64, 8, 9, 10), (5, 7, 9)) torch.nn.functional.adaptive_max_pool3d(torch.ones(1, 64, 10, 9, 8), 7) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_adaptive_pool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.nn.common_types import _size_1_t from packaging import version import torch as torch_original from typing import Union, Tuple from oneflow.test_utils.automated_test_util import * NoneType = type(None) # Not the same as those in PyTorch because 'output_size' cannot be NoneType (even in 'torch.nn.AdaptiveAvgPoolXd') _size_2_opt_t_not_none = Union[int, Tuple[Union[int, NoneType], Union[int, NoneType]]] _size_3_opt_t_not_none = Union[ int, Tuple[Union[int, NoneType], Union[int, NoneType], Union[int, NoneType]] ] @flow.unittest.skip_unless_1n1d() class TestAdaptiveAvgPool(flow.unittest.TestCase): @autotest(n=5) def test_adaptive_avgpool1d(test_case): m = torch.nn.AdaptiveAvgPool1d(output_size=random().to(_size_1_t)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3).to(device) y = m(x) return y @profile(torch.nn.functional.adaptive_avg_pool1d) def profile_adaptive_avg_pool1d(test_case): torch.nn.functional.adaptive_avg_pool1d(torch.ones(1, 64, 8), 5) @autotest(n=5) def test_adaptive_avgpool2d(test_case): m = torch.nn.AdaptiveAvgPool2d(output_size=random().to(_size_2_opt_t_not_none)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4).to(device) y = m(x) return y @profile(torch.nn.functional.adaptive_avg_pool2d) def profile_adaptive_avg_pool2d(test_case): torch.nn.functional.adaptive_avg_pool2d(torch.ones(1, 64, 10, 9), 7) torch.nn.functional.adaptive_avg_pool2d(torch.ones(1, 64, 8, 9), (5, 7)) @unittest.skipIf( version.parse(torch_original.__version__) < version.parse("1.10.0"), "GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0'", ) @autotest(n=5) def test_adaptive_avgpool3d(test_case): m = torch.nn.AdaptiveAvgPool3d(output_size=random().to(_size_3_opt_t_not_none)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=5).to(device) y = m(x) return y @profile(torch.nn.functional.adaptive_avg_pool3d) def profile_adaptive_avg_pool3d(test_case): torch.nn.functional.adaptive_avg_pool3d(torch.ones(1, 64, 8, 9, 10), (5, 7, 9)) torch.nn.functional.adaptive_avg_pool3d(torch.ones(1, 64, 10, 9, 8), 7) @flow.unittest.skip_unless_1n1d() class TestAdaptiveAvgPoolFunctional(flow.unittest.TestCase): @autotest(n=5) def test_adaptive_avgpool1d_functional(test_case): device = random_device() x = random_tensor(ndim=3).to(device) return torch.nn.functional.adaptive_avg_pool1d(x, output_size=random().to(int)) @autotest(n=5) def test_adaptive_avgpool2d_functional(test_case): device = random_device() x = random_tensor(ndim=4).to(device) return torch.nn.functional.adaptive_avg_pool2d(x, output_size=random().to(int)) @unittest.skipIf( version.parse(torch_original.__version__) <= version.parse("1.10.0"), "GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0'", ) @autotest(n=5) def test_adaptive_avgpool3d_functional(test_case): device = random_device() x = random_tensor(ndim=5).to(device) return torch.nn.functional.adaptive_avg_pool3d(x, output_size=random().to(int)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_adaptive_pool_fp16.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.nn.common_types import _size_1_t from packaging import version import torch as torch_original from typing import Union, Tuple from oneflow.test_utils.automated_test_util import * NoneType = type(None) _size_2_opt_t_not_none = Union[int, Tuple[Union[int, NoneType], Union[int, NoneType]]] _size_3_opt_t_not_none = Union[ int, Tuple[Union[int, NoneType], Union[int, NoneType], Union[int, NoneType]] ] @flow.unittest.skip_unless_1n1d() class Test_CpuFp16_AdaptiveAvgPool(flow.unittest.TestCase): @autotest(n=5, rtol=0.01, atol=0.01) def test_adaptive_avgpool1d(test_case): m = torch.nn.AdaptiveAvgPool1d(output_size=random().to(_size_1_t)) m.train(random()) device = "cpu" m.to(device) x = random_tensor(ndim=3).to(device) x = x.clone().half() y = m(x) return y @profile(torch.nn.functional.adaptive_avg_pool1d) def profile_adaptive_avg_pool1d(test_case): return torch.nn.functional.adaptive_avg_pool1d(torch.ones(1, 64, 8).half(), 5) @autotest(n=5, rtol=0.01, atol=0.01) def test_adaptive_avgpool2d(test_case): m = torch.nn.AdaptiveAvgPool2d(output_size=random().to(_size_2_opt_t_not_none)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4).to(device) x = x.half() y = m(x) return y @profile(torch.nn.functional.adaptive_avg_pool2d) def profile_adaptive_avg_pool2d(test_case): torch.nn.functional.adaptive_avg_pool2d(torch.ones(1, 64, 10, 9).half(), 7) torch.nn.functional.adaptive_avg_pool2d(torch.ones(1, 64, 8, 9).half(), (5, 7)) @unittest.skipIf( version.parse(torch_original.__version__) < version.parse("1.10.0"), "GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0'", ) @autotest(n=5, rtol=0.01, atol=0.01) def test_adaptive_avgpool3d(test_case): m = torch.nn.AdaptiveAvgPool3d(output_size=random().to(_size_3_opt_t_not_none)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=5).to(device) x = x.half() y = m(x) return y @profile(torch.nn.functional.adaptive_avg_pool3d) def profile_adaptive_avg_pool3d(test_case): torch.nn.functional.adaptive_avg_pool3d( torch.ones(1, 64, 8, 9, 10).half(), (5, 7, 9) ) torch.nn.functional.adaptive_avg_pool3d(torch.ones(1, 64, 10, 9, 8).half(), 7) @flow.unittest.skip_unless_1n1d() class Test_CpuFp16_AdaptiveAvgPoolFunctional(flow.unittest.TestCase): @autotest(n=5, rtol=0.01, atol=0.01) def test_adaptive_avgpool1d_functional(test_case): device = random_device() x = random_tensor(ndim=3).to(device) x = x.half() return torch.nn.functional.adaptive_avg_pool1d(x, output_size=random().to(int)) @autotest(n=5, rtol=0.01, atol=0.01) def test_adaptive_avgpool2d_functional(test_case): device = random_device() x = random_tensor(ndim=4).to(device) x = x.half() return torch.nn.functional.adaptive_avg_pool2d(x, output_size=random().to(int)) @autotest(n=5, rtol=0.01, atol=0.01) def test_adaptive_avgpool3d_functional(test_case): device = random_device() x = random_tensor(ndim=5).to(device) x = x.half() return torch.nn.functional.adaptive_avg_pool3d(x, output_size=random().to(int)) if __name__ == "__main__": unittest.main() 
================================================ FILE: python/oneflow/test/modules/test_add.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import torch as torch_original import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_add_forward(test_case, shape, device): x = flow.tensor(np.random.randn(*shape), device=flow.device(device)) y = flow.tensor(np.random.randn(*shape), device=flow.device(device)) of_out = flow.add(x, y) np_out = np.add(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = 5 y = flow.tensor(np.random.randn(*shape), device=flow.device(device)) of_out = flow.add(x, y) np_out = np.add(x, y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.tensor(np.random.randn(*shape), device=flow.device(device)) y = 5 of_out = flow.add(x, y) np_out = np.add(x.numpy(), y) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.tensor(np.random.randn(*shape), device=flow.device(device)) y = flow.tensor(np.array([5.0]), device=flow.device(device)) of_out = flow.add(x, y) np_out = np.add(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.tensor(np.random.randn(1, 1), device=flow.device(device)) y = flow.tensor(np.random.randn(*shape), device=flow.device(device)) of_out = flow.add(x, y) np_out = np.add(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) def _test_add_backward(test_case, shape, device): x = 5 y = flow.tensor( np.random.randn(*shape), requires_grad=True, device=flow.device(device) ) of_out = flow.add(x, y).sum() of_out.backward() test_case.assertTrue( np.allclose(y.grad.numpy(), np.ones(shape=shape), 0.0001, 0.0001) ) def _test_inplace_add(test_case, shape, device): np_x = np.random.randn(*shape) of_x = flow.tensor( np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_x_inplace = of_x + 1 id_old = id(of_x_inplace) of_x_inplace.add_(5) test_case.assertEqual(id_old, id(of_x_inplace)) np_out = np_x + 1 + 5 test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05)) of_x_inplace = of_x_inplace.sum() of_x_inplace.backward() test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05)) of_x = flow.tensor( np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_y = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=False, ) of_x_inplace = of_x + 1 id_old = id(of_x_inplace) of_x_inplace.add_(of_y) test_case.assertEqual(id_old, id(of_x_inplace)) np_out = np_x + 1 + of_y.numpy() test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05)) of_x_inplace = of_x_inplace.sum() 
of_x_inplace.backward() test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05)) of_x = flow.tensor( np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_y = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=False, ) of_x_inplace = of_x + 1 id_old = id(of_x_inplace) of_x_inplace += of_y test_case.assertEqual(id_old, id(of_x_inplace)) np_out = np_x + 1 + of_y.numpy() test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05)) of_x_inplace = of_x_inplace.sum() of_x_inplace.backward() test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05)) of_x = flow.tensor( np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_y = flow.tensor( np.array([5.0]), dtype=flow.float32, device=flow.device(device), requires_grad=False, ) of_x_inplace = of_x + 1 id_old = id(of_x_inplace) of_x_inplace.add_(of_y) test_case.assertEqual(id_old, id(of_x_inplace)) np_out = np_x + 6 test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05)) of_x_inplace = of_x_inplace.sum() of_x_inplace.backward() test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05)) of_x = flow.tensor( np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True ) np_y = np.random.randn(*shape[:-1], 1) of_y = flow.tensor( np_y, dtype=flow.float32, device=flow.device(device), requires_grad=False ) of_x_inplace = of_x + 1 id_old = id(of_x_inplace) of_x_inplace.add_(of_y) test_case.assertEqual(id_old, id(of_x_inplace)) np_out = np_x + 1 + np_y test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05)) of_x_inplace = of_x_inplace.sum() of_x_inplace.backward() test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05)) def _test_inplace_add_with_type_promotion(test_case, shape, device): x = flow.tensor( np.random.randn(*shape), device=flow.device(device), dtype=flow.float16 ) y = flow.tensor( np.random.randn(*shape), device=flow.device(device), dtype=flow.float32 ) x += y test_case.assertTrue(x.dtype == flow.float16) def _test_inplace_add_0_size_tensor(test_case, shape, device): x = flow.randn(0, 256, device=device) y = flow.randn(1, 256, device=device) x += y test_case.assertEqual(x.size(), (0, 256)) @flow.unittest.skip_unless_1n1d() class TestAddModule(flow.unittest.TestCase): def test_add(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_add_forward, _test_add_backward, _test_inplace_add, _test_inplace_add_with_type_promotion, _test_inplace_add_0_size_tensor, ] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, include_complex=True) def test_0_size_add(test_case): device = random_device() x = random_tensor(2, 0, 3).to(device) y = random_tensor(2, 1, 3).to(device) out = x + y return out @autotest(n=6, auto_backward=False, include_complex=True) def test_0dim_inplace_add(test_case): device = random_device() x = random_tensor(2, 2, 3, requires_grad=False).to(device) y = random_tensor(1, 10).to(device) x += y.mean() return x @autotest(n=10, include_complex=True) def test_0dim_two_inplace_add(test_case): device = random_device() x = random_tensor(2, 2, 3).to(device).mean() y = random_tensor(2, 2, 3).to(device) x += y.mean() return x @autotest(n=6, include_complex=True) def test_add_with_alpha(test_case): device = random_device() x1 = random_tensor(2, 2, 
3).to(device).mean() x2 = random_tensor(2, 2, 3).to(device).mean() x3 = random_tensor(2, 2, 3).to(device).mean() y = random_tensor(2, 2, 3).to(device) s = random().to(float) alpha = random().to(float) z1 = torch.add(x1, y, alpha=alpha) z2 = torch.add(x2, s, alpha=alpha) z3 = torch.add(s, x3, alpha=alpha) return z1, z2, z3 @autotest(auto_backward=False) def test_bool_add(test_case): device = random_device() x = random_tensor(2, 1, 3).to(device, torch.bool) y = random_tensor(2, 1, 3).to(device, torch.bool) out = x + y return out @autotest(auto_backward=False) def test_0shape_bool_add(test_case): device = random_device() x = random_tensor(2, 0, 3).to(device, torch.bool) y = random_tensor(2, 1, 3).to(device, torch.bool) out = x + y return out @autotest(n=3, auto_backward=False) def test_0dim_bool_inplace_add(test_case): device = random_device() x = random_tensor(2, 2, 3, requires_grad=False).to(device, torch.bool) y = random_tensor(1, 10).to(device) x += y.mean().to(torch.bool) return x @autotest(auto_backward=False) def test_0dim_two_bool_inplace_add(test_case): device = random_device() x = random_tensor(2, 2, 3).to(device).mean().to(torch.bool) y = random_tensor(2, 2, 3).to(device) x += y.mean().to(torch.bool) return x @autotest(n=6, include_complex=True) def test_add_with_alpha_0dim(test_case): device = random_device() x1 = random_tensor(ndim=0).to(device).mean() x2 = random_tensor(ndim=0).to(device).mean() x3 = random_tensor(ndim=0).to(device).mean() y = random_tensor(ndim=0).to(device) s = random().to(float) alpha = random().to(float) z1 = torch.add(x1, y, alpha=alpha) z2 = torch.add(x2, s, alpha=alpha) z3 = torch.add(s, x3, alpha=alpha) return z1, z2, z3 @profile(torch.add) def profile_add(test_case): torch.add(torch.ones(100), 20) torch.add(torch.ones(100), torch.ones(100, 1), alpha=10) @autotest(n=6, include_complex=True) def test_non_contiguous_inplace_add(test_case): device = random_device() x = random_tensor(2, 2, 4).to(device) y = x + 1 y = y[:, 1:3] y += random_tensor(2, 2, 2).to(device) return y @autotest(n=10, include_complex=True) def test_scalar_add_with_random_devices(test_case): x1_device = random_device() x2_device = random_device() x1 = random_tensor(2, 2, 3).to(x1_device).mean() x2 = random_tensor(2, 2, 3).to(x2_device) y = x1 + x2 return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_addcdiv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestAddcdiv(flow.unittest.TestCase): @autotest(n=5) def test_addcdiv(test_case): device = random_device() ndim = random(2, 4).to(int).value() shape = [random(2, 4) for i in range(ndim)] input = random_tensor(ndim, *shape).to(device) tensor1 = random_tensor(ndim, *shape).to(device) tensor2 = random_tensor(ndim, *shape).to(device) value = random(2, 4).to(int) output = torch.addcdiv(input, tensor1, tensor2, value=value) return output @autotest(n=5) def test_tensor_addcdiv(test_case): device = random_device() ndim = random(2, 4).to(int).value() shape = [random(2, 4) for i in range(ndim)] input = random_tensor(ndim, *shape).to(device) tensor1 = random_tensor(ndim, *shape).to(device) tensor2 = random_tensor(ndim, *shape).to(device) value = random(2, 4).to(int) output = input.addcdiv(tensor1, tensor2, value=value) return output @autotest(n=5) def test_tensor_addcdiv_inplace(test_case): device = random_device() ndim = random(2, 4).to(int).value() shape = [random(2, 4) for i in range(ndim)] input = random_tensor(ndim, *shape).to(device) input = input + 1.0 tensor1 = random_tensor(ndim, *shape).to(device) tensor2 = random_tensor(ndim, *shape).to(device) value = random(2, 4).to(int) input.addcdiv_(tensor1, tensor2, value=value) return input @profile(torch.addcdiv) def profile_addcdiv(test_case): t = torch.ones(1, 3) t1 = torch.ones(3, 1) t2 = torch.ones(1, 3) torch.addcdiv(t, t1, t2, value=0.1) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_addcmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestAddcmul(flow.unittest.TestCase): @autotest(check_graph=True) def test_addcmul(test_case): device = random_device() ndim = random(low=2).to(int).value() shape = [random(low=2, high=4) for i in range(ndim)] input = random_tensor(len(shape), *shape).to(device) tensor1 = random_tensor(len(shape), *shape).to(device) tensor2 = random_tensor(len(shape), *shape).to(device) value = random(3, 6).to(int) output = torch.addcmul(input, tensor1, tensor2, value=value) return output @autotest(check_graph=True) def test_tensor_addcmul(test_case): device = random_device() ndim = random(low=2).to(int).value() shape = [random(low=2, high=4) for i in range(ndim)] input = random_tensor(len(shape), *shape).to(device) tensor1 = random_tensor(len(shape), *shape).to(device) tensor2 = random_tensor(len(shape), *shape).to(device) value = random(3, 6).to(int) output = input.addcmul(tensor1, tensor2, value=value) return output @autotest(check_graph=True) def test_tensor_addcmul_inplace(test_case): device = random_device() ndim = random(low=2).to(int).value() shape = [random(low=2, high=4) for i in range(ndim)] input = random_tensor(len(shape), *shape).to(device) input = input + 1.0 tensor1 = random_tensor(len(shape), *shape).to(device) tensor2 = random_tensor(len(shape), *shape).to(device) value = random(3, 6).to(int) input.addcmul_(tensor1, tensor2, value=value) return input @profile(torch.addcmul) def profile_addcmul(test_case): input = torch.ones(100, 12, 13) tensor1 = torch.ones(100, 12, 13) tensor2 = torch.ones(100, 12, 13) torch.addcmul(input, tensor1, tensor2, value=2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_addmm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_addmm(test_case, shape, alpha, beta, device): mat1 = np.random.randn(*shape) mat2 = np.random.randn(*shape) input = np.random.randn(*shape) mat1_tensor = flow.tensor(mat1, dtype=flow.float32, device=flow.device(device)) mat2_tensor = flow.tensor(mat2, dtype=flow.float32, device=flow.device(device)) input_tensor = flow.tensor(input, dtype=flow.float32, device=flow.device(device)) of_out = flow.addmm(input_tensor, mat1_tensor, mat2_tensor, alpha, beta) np_out = np.add(beta * input, alpha * np.matmul(mat1, mat2)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_addmm_backward(test_case, shape, alpha, beta, device): mat1 = np.random.randn(*shape) mat2 = np.random.randn(*shape) input = np.random.randn(*shape) mat1_tensor = flow.tensor(mat1, dtype=flow.float32, device=flow.device(device)) mat2_tensor = flow.tensor(mat2, dtype=flow.float32, device=flow.device(device)) input_tensor = flow.tensor( input, dtype=flow.float32, requires_grad=True, device=flow.device(device) ) of_out = flow.addmm(input_tensor, mat1_tensor, mat2_tensor, alpha, beta).sum() of_out.backward() np_grad_out = np.ones_like(input) * beta test_case.assertTrue( np.allclose(input_tensor.grad.numpy(), np_grad_out, 1e-05, 1e-05) ) @flow.unittest.skip_unless_1n1d() class TestAddmm(flow.unittest.TestCase): def test_addmm(test_case): arg_dict = OrderedDict() arg_dict["function_test"] = [_test_addmm, _test_addmm_backward] arg_dict["shape"] = [(3, 3)] arg_dict["alpha"] = [4, 1.2, -3.7] arg_dict["beta"] = [1.5, 4, -2] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, rtol=1e-2, atol=1e-3) def test_addmm_flow_with_random_data(test_case): device = random_device() input = random_tensor(ndim=2, dim0=2, dim1=3).to(device) mat1 = random_tensor(ndim=2, dim0=2, dim1=4).to(device) mat2 = random_tensor(ndim=2, dim0=4, dim1=3).to(device) y = torch.addmm( input, mat1, mat2, beta=random().to(float) | nothing(), alpha=random().to(float) | nothing(), ) return y @autotest(n=5, rtol=1e-2, atol=1e-3) def test_addmm_broadcast_flow_with_random_data(test_case): device = random_device() input = random_tensor(ndim=2, dim0=1, dim1=1).to(device) mat1 = random_tensor(ndim=2, dim0=2, dim1=4).to(device) mat2 = random_tensor(ndim=2, dim0=4, dim1=3).to(device) y = torch.addmm( input, mat1, mat2, beta=random().to(float) | nothing(), alpha=random().to(float) | nothing(), ) return y @profile(torch.addmm) def profile_addmm(test_case): input = torch.ones(2, 3) mat1 = torch.ones(2, 3) mat2 = torch.ones(3, 3) torch.addmm(input, mat1, mat2) torch.addmm(input, mat1, mat2, alpha=1, beta=2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_affine_grid.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest from random import randint from random import choice import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestAffineGrid(flow.unittest.TestCase): def test_affine_grid_2d(test_case): input = flow.tensor(np.arange(1.0, 7).reshape((1, 2, 3)), dtype=flow.float32) output = flow.nn.functional.affine_grid( input, flow.Size([1, 1, 2, 2]), align_corners=True ) groundtruth = np.array([[[[0.0, -3.0], [2.0, 5.0]], [[4.0, 7.0], [6.0, 15.0]]]]) test_case.assertTrue( np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4) ) output = flow.nn.functional.affine_grid( input, flow.Size([1, 1, 2, 2]), align_corners=False ) groundtruth = np.array([[[[1.5, 1.5], [2.5, 5.5]], [[3.5, 6.5], [4.5, 10.5]]]]) test_case.assertTrue( np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4) ) def test_affine_grid_3d(test_case): input = flow.tensor(np.arange(1.0, 13).reshape((1, 3, 4)), dtype=flow.float32) output = flow.nn.functional.affine_grid( input, flow.Size([1, 1, 2, 2, 2]), align_corners=True ) groundtruth = np.array( [ [ [ [[-2.0, -10.0, -18.0], [0.0, 0.0, 0.0]], [[2.0, 2.0, 2.0], [4.0, 12.0, 20.0]], ], [ [[4.0, 4.0, 4.0], [6.0, 14.0, 22.0]], [[8.0, 16.0, 24.0], [10.0, 26.0, 42.0]], ], ] ] ) test_case.assertTrue( np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4) ) output = flow.nn.functional.affine_grid( input, flow.Size([1, 1, 2, 2, 2]), align_corners=False ) groundtruth = np.array( [ [ [ [[1.0, -1.0, -3.0], [2.0, 4.0, 6.0]], [[3.0, 5.0, 7.0], [4.0, 10.0, 16.0]], ], [ [[4.0, 6.0, 8.0], [5.0, 11.0, 17.0]], [[6.0, 12.0, 18.0], [7.0, 17.0, 27.0]], ], ] ] ) test_case.assertTrue( np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4) ) @autotest(n=5, rtol=1e-03, atol=1e-04, check_allclose=False, check_graph=True) def test_flow_affine_grid_2d_with_random_data(test_case): N = randint(1, 8) C = randint(1, 8) H = randint(1, 8) W = randint(1, 8) device = random_device() align_corners = choice([True, False]) theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device) output = torch.nn.functional.affine_grid( theta, (N, C, H, W), align_corners=align_corners ).to(device) return output @autotest(rtol=1e-03, atol=1e-03, check_allclose=False, check_graph=True) def test_flow_affine_grid_3d_with_random_data(test_case): N = randint(1, 8) C = randint(1, 8) D = randint(1, 8) H = randint(1, 8) W = randint(1, 8) device = random_device() align_corners = choice([True, False]) theta = random_tensor(ndim=3, dim0=N, dim1=3, dim2=4).to(device) output = torch.nn.functional.affine_grid( theta, (N, C, D, H, W), align_corners=align_corners ).to(device) return output @profile(torch.nn.functional.affine_grid) def profile_affine_grid(test_case): input = torch.tensor(np.arange(1.0, 7).reshape((1, 2, 3)), dtype=torch.float32) torch.nn.functional.affine_grid( input, torch.Size([1, 1, 2, 2]), align_corners=True ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_allclose.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest rtol = 1e-3 def _perturbate(x): shape = x.oneflow.shape device = x.device diff = ( random_tensor(len(shape), *shape, low=-1, high=1, requires_grad=False).to( device ) * rtol * 2 ) return x + diff @flow.unittest.skip_unless_1n1d() class TestAllClose(flow.unittest.TestCase): @autotest(n=10, auto_backward=False, check_graph=False) def test_allclose_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(requires_grad=False).to(device) x2 = _perturbate(x1) y = torch.allclose(x1, x2, rtol=rtol) return y @autotest(n=10, auto_backward=False, check_graph=False) def test_allclose_with_0dim_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(requires_grad=False).to(device) x2 = _perturbate(x1) y = torch.allclose(x1, x2, rtol=rtol) return y @autotest(n=10, auto_backward=False, check_graph=False) def test_tensor_allclose_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(requires_grad=False).to(device) x2 = _perturbate(x1) y = x1.allclose(x2, rtol=rtol) return y @autotest(n=10, auto_backward=False, check_graph=False) def test_allclose_broadcast(test_case): device = random_device() x1 = random_tensor(2, 2, 8, requires_grad=False).to(device) x2 = _perturbate(x1[:, :1]) y = torch.allclose(x1, x2, rtol=rtol) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_allreduce.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestAllReduce(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_all_reduce(test_case): arr_rank1 = np.array([1, 2]) arr_rank2 = np.array([3, 4]) if flow.env.get_rank() == 0: x = flow.Tensor(arr_rank1) elif flow.env.get_rank() == 1: x = flow.Tensor(arr_rank2) else: raise ValueError x = x.to("cuda") y = flow._C.local_all_reduce(x) test_case.assertTrue(np.allclose(y.numpy(), arr_rank1 + arr_rank2)) @flow.unittest.skip_unless_2n2d() def test_all_reduce_2nodes(test_case): np_arr = np.array([1, 2]) x = flow.Tensor(np_arr * (flow.env.get_rank() + 1)) x = x.to("cuda") y = flow._C.local_all_reduce(x) test_case.assertTrue(np.allclose(y.numpy(), np_arr * 10)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_amax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import numpy as np def __check(test_case, input, dim, keepdim, device): of_out = flow.amax(input, dim=dim, keepdim=keepdim) if type(dim) is tuple: if len(dim) == 0: dim = None np_out = np.amax(input.numpy(), axis=dim, keepdims=keepdim) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=0.0001, atol=1e-05,)) def _test_amax_with_negative_dim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) dim = random(-4, 0).to(int).value() keepdim = random_bool().value() __check(test_case, input, dim, keepdim, device) def _test_amax_with_positive_dim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) dim = random(0, 4).to(int).value() keepdim = random_bool().value() __check(test_case, input, dim, keepdim, device) def _test_amax_with_multiple_axes(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) axes = set() num_axes = random(1, 4).to(int).value() for _ in range(num_axes): axes.add(random(0, 4).to(int).value()) keepdim = random_bool().value() __check(test_case, input, tuple(axes), keepdim, device) def _test_amax_with_empty_dim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) keepdim = random_bool().value() __check(test_case, input, None, keepdim, device) def _test_amax_keepdim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) dim = random(-4, 4).to(int).value() keepdim = True __check(test_case, input, dim, keepdim, device) def _test_amax_not_keepdim(test_case, device): input = 
flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) dim = random(-4, 4).to(int).value() keepdim = False __check(test_case, input, dim, keepdim, device) @flow.unittest.skip_unless_1n1d() class TestAmax(flow.unittest.TestCase): def test_amax(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_amax_with_negative_dim, _test_amax_with_positive_dim, _test_amax_with_multiple_axes, _test_amax_with_empty_dim, _test_amax_keepdim, _test_amax_not_keepdim, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_amax_with_random_data_single_dim(test_case): device = random_device() ndim = random(1, 6).to(int) x = random_tensor(ndim=ndim).to(device) y = torch.amax(x, dim=random(0, ndim), keepdim=random().to(bool)) return y @autotest(n=5) def test_amax_with_random_data_empty_dim(test_case): device = random_device() ndim = random(1, 6).to(int) x = random_tensor(ndim=ndim).to(device) y = torch.amax(x, dim=None, keepdim=random().to(bool)) return y @autotest(n=5) def test_amax_with_random_data_multi_dims(test_case): device = random_device() ndim = random(2, 6).to(int) x = random_tensor(ndim=ndim).to(device) dim = set() for _ in range(random(1, ndim).to(int).value()): dim.add(random(0, ndim).to(int).value()) y = torch.amax(x, dim=tuple(dim), keepdim=random().to(bool)) return y @profile(torch.amax) def profile_amax(test_case): input1 = torch.ones(4, 4) input2 = torch.ones(100, 100) torch.amax(input1, 1) torch.amax(input1, 1, True) torch.amax(input2, 1) torch.amax(input2, 1, True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_amin.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import numpy as np def __check(test_case, input, dim, keepdim, device): of_out = flow.amin(input, dim=dim, keepdim=keepdim) if type(dim) is tuple: if len(dim) == 0: dim = None np_out = np.amin(input.numpy(), axis=dim, keepdims=keepdim) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=0.0001, atol=1e-05,)) def _test_amin_with_negative_dim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) dim = random(-4, 0).to(int).value() keepdim = random_bool().value() __check(test_case, input, dim, keepdim, device) def _test_amin_with_positive_dim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) dim = random(0, 4).to(int).value() keepdim = random_bool().value() __check(test_case, input, dim, keepdim, device) def _test_amin_with_multiple_axes(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) axes = set() num_axes = random(1, 4).to(int).value() for _ in range(num_axes): axes.add(random(0, 4).to(int).value()) keepdim = random_bool().value() __check(test_case, input, tuple(axes), keepdim, device) def _test_amin_with_empty_dim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) keepdim = random_bool().value() __check(test_case, input, None, keepdim, device) def _test_amin_keepdim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) dim = random(-4, 4).to(int).value() keepdim = True __check(test_case, input, dim, keepdim, device) def _test_amin_not_keepdim(test_case, device): input = flow.tensor( np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device) ) dim = random(-4, 4).to(int).value() keepdim = False __check(test_case, input, dim, keepdim, device) @flow.unittest.skip_unless_1n1d() class TestAmin(flow.unittest.TestCase): def test_amin(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_amin_with_negative_dim, _test_amin_with_positive_dim, _test_amin_with_multiple_axes, _test_amin_with_empty_dim, _test_amin_keepdim, _test_amin_not_keepdim, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_amin_with_random_data_single_dim(test_case): device = random_device() ndim = random(1, 6).to(int) x = random_tensor(ndim=ndim).to(device) y = torch.amin(x, dim=random(0, ndim), keepdim=random().to(bool)) return y @autotest(n=5) def test_amin_with_random_data_empty_dim(test_case): device = random_device() ndim = random(1, 6).to(int) x = random_tensor(ndim=ndim).to(device) y = torch.amin(x, dim=None, keepdim=random().to(bool)) return y @autotest(n=5) def test_amin_with_random_data_multi_dims(test_case): device = random_device() ndim = random(2, 6).to(int) x = random_tensor(ndim=ndim).to(device) dim = set() for _ in range(random(1, ndim).to(int).value()): dim.add(random(0, ndim).to(int).value()) y = torch.amin(x, dim=tuple(dim), keepdim=random().to(bool)) return y @profile(torch.amin) def profile_amin(test_case): input1 = torch.ones(4, 4) input2 = torch.ones(100, 100) torch.amin(input1, 1) torch.amin(input1, 1, True) torch.amin(input2, 1) torch.amin(input2, 1, True) if __name__ == 
"__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_arange.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_arange(test_case, device): np_out = np.arange(13, dtype=np.float32) of_out = flow.arange(13, device=device, dtype=flow.float32) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) np_out = np.arange(13, dtype=np.float16) of_out = flow.arange(13, device=device, dtype=flow.float16) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_arange_step_prarm(test_case, device): np_out = np.arange(0, 20, 2) of_out = flow.arange(0, 20, step=2, device=device) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_arange_more_params(test_case, device): np_out = np.arange(0, 100, 3) of_out = flow.arange(start=0, end=100, step=3, device=device) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_arange_backward(test_case, device): x = flow.arange(13, dtype=flow.float32, device=device) x.requires_grad = True y = x.sum() y.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(13), 1e-05, 1e-05)) x = flow.arange(13, dtype=flow.float16, device=device) x.requires_grad = True y = x.sum() y.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(13), 1e-05, 1e-05)) def _test_arange_input_tensor_type(test_case, device): x = flow.tensor([[1, 2], [3, 4]], dtype=flow.int64).to(device) y = flow.arange(start=flow.min(x), end=flow.max(x), device=device) test_case.assertTrue(np.allclose(y.numpy(), np.arange(1, 4))) x = flow.tensor([[1, 2], [3, 4]], dtype=flow.int64).to(device) y = flow.arange( start=flow.min(x), end=flow.max(x), device=device, dtype=flow.float16 ) test_case.assertTrue(np.allclose(y.numpy(), np.arange(1, 4))) @flow.unittest.skip_unless_1n1d() class TestArange(flow.unittest.TestCase): def test_arange(test_case): arg_dict = OrderedDict() arg_dict["function_test"] = [ _test_arange, _test_arange_step_prarm, _test_arange_more_params, _test_arange_backward, _test_arange_input_tensor_type, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_arange_with_random_data(test_case): start = random().to(int) end = start + random().to(int) step = random(1, end - start + 1).to(int) x = torch.arange(start=start, end=end, step=step) device = random_device() x.to(device) return x @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_arange_with_float_delta(test_case): start = random().to(int) end = start + random().to(int) step = random(1, end - start + 
1).to(float) x = torch.arange(start=start, end=end, step=step) device = random_device() x.to(device) return x @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_arange_input_float_scalar_tensor(test_case): start = random().to(float) end = start + random().to(float) x = torch.arange(start=torch.tensor(start), end=torch.tensor(end)) device = random_device() x.to(device) return x @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_arange_input_float16_scalar_tensor(test_case): start = random().to(float) end = start + random().to(float) start, end = torch.tensor(start).half(), torch.tensor(end).half() x = torch.arange(start=start, end=end) device = random_device() x.to(device) return x def test_global_naive(test_case): placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x = flow.arange(start=0, end=10, step=1, placement=placement, sbp=sbp) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) @profile(torch.arange) def profile_arange(test_case): torch.arange(5) torch.arange(100000) torch.arange(1, 4) torch.arange(1, 2.5, 0.5) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_argmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_argmax_axis_negative(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) axis = -1 of_out = flow.argmax(input, dim=axis) np_out = np.argmax(input.numpy(), axis=axis) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_tensor_argmax(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) axis = 0 of_out = input.argmax(dim=axis) np_out = np.argmax(input.numpy(), axis=axis) test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_argmax_axis_postive(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) axis = 1 of_out = flow.argmax(input, dim=axis) np_out = np.argmax(input.numpy(), axis=axis) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_argmax_keepdims(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) axis = 0 of_out = input.argmax(axis, True) np_out = np.argmax(input.numpy(), axis=axis) np_out = np.expand_dims(np_out, axis=axis) test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_argmax_dim_equal_none(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = input.argmax() np_out = np.argmax(input.numpy().flatten(), axis=0) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) @flow.unittest.skip_unless_1n1d() class TestArgmax(flow.unittest.TestCase): def test_argmax(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_argmax_axis_negative, _test_tensor_argmax, _test_argmax_axis_postive, _test_argmax_keepdims, _test_argmax_dim_equal_none, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_argmax_with_random_data(test_case): device = random_device() ndim = random(1, 6).to(int) x = random_tensor(ndim=ndim).to(device) y = torch.argmax(x, dim=random(0, ndim).to(int), keepdim=random().to(bool)) return y @profile(torch.argmax) def profile_argmax(test_case): torch.argmax(torch.ones(100000)) torch.argmax(torch.ones(1000000)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_argmin.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_argmin_axis_negative(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) axis = -1 of_out = flow.argmin(input, dim=axis) np_out = np.argmin(input.numpy(), axis=axis) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_tensor_argmin(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) axis = 0 of_out = input.argmin(dim=axis) np_out = np.argmin(input.numpy(), axis=axis) test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_argmin_axis_postive(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) axis = 1 of_out = flow.argmin(input, dim=axis) np_out = np.argmin(input.numpy(), axis=axis) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_argmin_keepdims(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) axis = 0 of_out = input.argmin(axis, True) np_out = np.argmin(input.numpy(), axis=axis) np_out = np.expand_dims(np_out, axis=axis) test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_argmin_dim_equal_none(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = input.argmin() np_out = np.argmin(input.numpy().flatten(), axis=0) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) @flow.unittest.skip_unless_1n1d() class TestArgmin(flow.unittest.TestCase): def test_argmin(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_argmin_axis_negative, _test_tensor_argmin, _test_argmin_axis_postive, _test_argmin_keepdims, _test_argmin_dim_equal_none, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_argmin_with_random_data(test_case): device = random_device() ndim = random(1, 6).to(int) x = random_tensor(ndim=ndim).to(device) y = torch.argmin(x, dim=random(0, ndim).to(int), keepdim=random().to(bool)) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_argsort.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_argsort(test_case, data_shape, axis, descending, data_type, device): input = flow.tensor( np.random.randn(*data_shape), dtype=type_name_to_flow_type[data_type], device=flow.device(device), ) np_input = -input.numpy() if descending else input.numpy() if axis is not None: of_out = flow.argsort(input, dim=axis, descending=descending) np_out = np.argsort(np_input, axis=axis) else: of_out = flow.argsort(input, descending=descending) np_out = np.argsort(np_input) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_tensor_argsort(test_case, data_shape, axis, descending, data_type, device): input = flow.tensor( np.random.randn(*data_shape), dtype=type_name_to_flow_type[data_type], device=flow.device(device), ) np_input = -input.numpy() if descending else input.numpy() if axis is not None: of_out = input.argsort(dim=axis, descending=descending) np_out = np.argsort(np_input, axis=axis) else: of_out = input.argsort(descending=descending) np_out = np.argsort(np_input) test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) @flow.unittest.skip_unless_1n1d() class TestArgsort(flow.unittest.TestCase): def test_argsort(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_argsort, _test_tensor_argsort] arg_dict["data_shape"] = [(2, 6, 5, 4), (3, 4, 8)] arg_dict["axis"] = [-1, 0, 2, None] arg_dict["descending"] = [True, False] arg_dict["data_type"] = ["double", "float32", "int32"] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(auto_backward=False, check_graph=True) def test_argsort_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.argsort( x, dim=random(low=-4, high=4).to(int), descending=random_bool() ) return y @autotest(auto_backward=False, check_graph=True) def test_argsort_bool_with_random_data(test_case): x = random_tensor(ndim=4).to("cpu", torch.bool) y = torch.argsort( x, dim=random(low=-4, high=4).to(int), descending=random_bool() ) return y @profile(torch.argsort) def profile_argsort(test_case): torch.argsort(torch.ones(10, 10), dim=1) torch.argsort(torch.ones(1000, 1000), dim=1) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_argwhere.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from packaging import version from oneflow.test_utils.automated_test_util import * import torch as torch_original def _test_argwhere(test_case, shape, device): np_input = np.random.randn(*shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) of_out = flow.argwhere(input) np_out = np.argwhere(np_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape)) @flow.unittest.skip_unless_1n1d() class TestArgwhere(flow.unittest.TestCase): def test_argwhere(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_argwhere] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6), (2, 3, 0, 4)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skip("pytorch do not have argwhere fn/module yet!") @autotest(n=5, rtol=1e-5, atol=1e-5) def test_argwhere_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(2, 5).to(int)).to(device) y = torch.argwhere(x) return y has_pytorch_1_11 = version.parse(torch_original.__version__) >= version.parse( "1.11.0" ) @unittest.skipIf( not has_pytorch_1_11, "torch.argwhere only exists in PyTorch >= 1.11.0" ) @profile(torch.argwhere if has_pytorch_1_11 else None) def profile_argwhere(test_case): torch.argwhere(torch.ones(3, 3, 100, 100)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_as_strided.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np from random import shuffle from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestAsStrided(flow.unittest.TestCase): @autotest(n=10) def test_flow_AsStrided(test_case): device = random_device() ndim = np.random.randint(3, 6) dim0 = np.random.randint(2, 4) dim1 = np.random.randint(2, 4) dim2 = np.random.randint(2, 4) dim3 = np.random.randint(2, 4) dim4 = np.random.randint(2, 4) if ndim == 3: x = random_tensor(3, dim0, dim1, dim2) elif ndim == 4: x = random_tensor(4, dim0, dim1, dim2, dim3) elif ndim == 5: x = random_tensor(5, dim0, dim1, dim2, dim3, dim4) x = x.to(device) storage_offset = random(0, 3).to(int) z = torch.as_strided(x, (2, 2, 3), (1, 1, 2), storage_offset) return z @autotest(n=5) def test_tensor_as_strided(test_case): device = random_device() ndim = np.random.randint(3, 6) dim0 = np.random.randint(2, 4) dim1 = np.random.randint(2, 4) dim2 = np.random.randint(2, 4) dim3 = np.random.randint(2, 4) dim4 = np.random.randint(2, 4) if ndim == 3: x = random_tensor(3, dim0, dim1, dim2) elif ndim == 4: x = random_tensor(4, dim0, dim1, dim2, dim3) elif ndim == 5: x = random_tensor(5, dim0, dim1, dim2, dim3, dim4) x = x.to(device) storage_offset = random(0, 3).to(int) y = x.as_strided((2, 2, 3), (1, 1, 2), storage_offset) return y @autotest(n=10) def test_flow_as_strided_tensor_method(test_case): device = random_device() ndim = np.random.randint(3, 6) x = random_tensor(ndim, *[np.random.randint(2, 4) for _ in range(ndim)]) x = x.to(device) storage_offset = random(0, 3).to(int) z = x.as_strided((2, 2, 3), (1, 1, 2), storage_offset) return z @autotest(n=10) def test_flow_as_strided_with_stride(test_case): device = random_device() dim0 = np.random.randint(2, 4) dim1 = np.random.randint(2, 4) dim2 = np.random.randint(2, 4) dim3 = np.random.randint(2, 4) x = random_tensor(4, dim0, dim1, dim2, dim3) x = x.to(device) storage_offset = random(0, 3).to(int) perm = [0, 1, 2, 3] shuffle(perm) y = x.permute(perm) z = torch.as_strided(y, (2, 2, 3), (1, 1, 2), storage_offset) return z @autotest(n=5, auto_backward=False) def test_flow_as_strided_bool(test_case): device = random_device() ndim = np.random.randint(3, 6) dim0 = np.random.randint(2, 4) dim1 = np.random.randint(2, 4) dim2 = np.random.randint(2, 4) dim3 = np.random.randint(2, 4) dim4 = np.random.randint(2, 4) if ndim == 3: x = random_tensor(3, dim0, dim1, dim2) elif ndim == 4: x = random_tensor(4, dim0, dim1, dim2, dim3) elif ndim == 5: x = random_tensor(5, dim0, dim1, dim2, dim3, dim4) x = x.to(device) x = x.to(torch.bool) storage_offset = random(0, 3).to(int) z = torch.as_strided(x, (2, 2, 3), (1, 1, 2), storage_offset) return z @autotest(n=5, auto_backward=False) def test_flow_as_strided_int8(test_case): device = random_device() ndim = np.random.randint(3, 6) dim0 = np.random.randint(2, 4) dim1 = np.random.randint(2, 4) dim2 = np.random.randint(2, 4) dim3 = np.random.randint(2, 4) dim4 = np.random.randint(2, 4) if ndim == 3: x = random_tensor(3, dim0, dim1, dim2) elif ndim == 4: x = random_tensor(4, dim0, dim1, dim2, dim3) elif ndim == 5: x = random_tensor(5, dim0, dim1, dim2, dim3, dim4) x = x.to(device) x = x.to(torch.int8) storage_offset = random(0, 3).to(int) z = torch.as_strided(x, (2, 2, 3), (1, 1, 2), storage_offset) return z @autotest(n=5, auto_backward=False) def test_flow_as_strided_uint8(test_case): device = random_device() ndim = np.random.randint(3, 6) dim0 = np.random.randint(2, 4) dim1 = 
np.random.randint(2, 4) dim2 = np.random.randint(2, 4) dim3 = np.random.randint(2, 4) dim4 = np.random.randint(2, 4) if ndim == 3: x = random_tensor(3, dim0, dim1, dim2) elif ndim == 4: x = random_tensor(4, dim0, dim1, dim2, dim3) elif ndim == 5: x = random_tensor(5, dim0, dim1, dim2, dim3, dim4) x = x.to(device) x = x.to(torch.uint8) storage_offset = random(0, 3).to(int) z = torch.as_strided(x, (2, 2, 3), (1, 1, 2), storage_offset) return z @profile(torch.as_strided) def profile_as_strided(test_case): input = torch.ones(10, 10, 128, 128) torch.as_strided(input, (10, 3, 128, 128), (1, 1, 1, 1)) torch.as_strided(input, (10, 3, 128, 128), (1, 1, 1, 1), 1) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_as_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import random import unittest import numpy as np import oneflow as flow import oneflow.unittest numpy_dtype_to_oneflow_dtype_dict = { np.int32: flow.int32, np.int64: flow.int64, np.int8: flow.int8, np.uint8: flow.uint8, np.float64: flow.float64, np.float32: flow.float32, np.float16: flow.float16, } @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestAsTensor(flow.unittest.TestCase): def test_tensor_type(test_case): x = flow.randn(2, 3) y = flow.as_tensor(x) y[0] = 2.0 test_case.assertTrue(np.array_equal(x.numpy(), y.numpy())) test_case.assertTrue(np.array_equal(id(x), id(y))) x = flow.randn(2, 3) x = x.to("cuda") y = flow.as_tensor(x) y[0] = 2.0 test_case.assertTrue(np.array_equal(x.numpy(), y.numpy())) test_case.assertTrue(np.array_equal(id(x), id(y))) x = flow.randn(2, 3) y = flow.as_tensor(x, device=flow.device("cuda:0")) test_case.assertTrue(id(x) != id(y)) for dtype in [ flow.float64, flow.float16, flow.int64, flow.int32, flow.int8, flow.uint8, ]: x = flow.randn(2, 3) y = flow.as_tensor(x, dtype=dtype) test_case.assertTrue(id(x) != id(y)) def test_numpy_type(test_case): for device in [flow.device("cpu"), flow.device("cuda:0"), None]: for np_dtype in [ np.float64, np.float32, np.float16, np.int64, np.int32, np.int8, np.uint8, ]: for flow_dtype in [ flow.float64, flow.float16, flow.int64, flow.int32, flow.int8, flow.uint8, ]: np_arr = np.ones((2, 3), dtype=np_dtype) try: tensor = flow.as_tensor(np_arr, dtype=flow_dtype) if numpy_dtype_to_oneflow_dtype_dict[ np_arr.dtype ] == flow_dtype and device is not flow.device("cuda:0"): tensor[0][0] += 1.0 test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) else: test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) except Exception as e: # Ignore cast or kernel mismatch error in test example pass def test_other_type(test_case): for device in [flow.device("cpu"), flow.device("cuda:0"), None]: for np_dtype in [ np.float64, np.float32, np.float16, np.int64, np.int32, np.int8, np.uint8, ]: for flow_dtype in [ flow.float64, flow.float16, flow.int64, 
flow.int32, flow.int8, flow.uint8, ]: # tuple np_arr = (1.0, 2.0, 3.0) try: tensor = flow.as_tensor(np_arr, dtype=flow_dtype) test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) except Exception as e: # Ignore cast or kernel mismatch error in test example pass # tuple np_arr = [1.0, 2.0, 3.0] try: tensor = flow.as_tensor(np_arr, dtype=flow_dtype) test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) except Exception as e: # Ignore cast or kernel mismatch error in test example pass # scalar np_arr = 4.0 try: tensor = flow.as_tensor(np_arr, dtype=flow_dtype) test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) except Exception as e: # Ignore cast or kernel mismatch error in test example pass def test_numpy_dtype_bug(test_case): test_case.assertEqual(flow.as_tensor([1.0]).dtype, flow.float32) x = np.random.randn(10) y1 = flow.as_tensor(x, dtype=flow.int64) y2 = flow.as_tensor(x, dtype=flow.float64) test_case.assertEqual(y1.dtype, flow.int64) test_case.assertEqual(y2.dtype, flow.float64) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_asyncs_thread.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestLocalThread(flow.unittest.TestCase): def test_stream(test_case): with flow.asyncs.thread(flow.asyncs.Thread()): test_case.assertEqual(flow.ones(1)[0], 1) @flow.unittest.skip_unless_1n2d() class TestGlobalThread(flow.unittest.TestCase): def test_cpu_stream(test_case): threads = [flow.asyncs.Thread() for i in range(7)] iter_and_threads = [(i, threads[i % 7]) for i in range(30)] for i, thread in iter_and_threads: with flow.asyncs.thread(thread): placement = flow.placement("cpu", [0, 1]) tensor = flow.ones(2, placement=placement, sbp=flow.sbp.split(0)) test_case.assertEqual(tensor[0], 1) test_case.assertEqual(tensor[1], 1) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_stream(test_case): threads = [flow.asyncs.Thread() for i in range(7)] iter_and_threads = [(i, threads[i % 7]) for i in range(200)] tensors = [] dim = 0 for i, thread in iter_and_threads: dim += 1 with flow.asyncs.thread(thread): placement = flow.placement("cuda", [0, 1]) ones = flow.ones(2 * dim, placement=placement, sbp=flow.sbp.split(0)) tensors.append(ones.to_global(sbp=flow.sbp.broadcast) + i) for i, tensor in enumerate(tensors): test_case.assertEqual(tensor[0], 1 + i) test_case.assertEqual(tensor[int(tensor.shape[0] / 2)], 1 + i) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_atleast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestAtLeast(flow.unittest.TestCase): @autotest(n=5) def test_atleast_1d_with_list_random_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=2).to(device) out = torch.atleast_1d([x, y]) return out @autotest(n=5) def test_atleast_1d_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(low=0, high=3).to(int)).to(device) out = torch.atleast_1d(x) return out @autotest(n=5) def test_atleast_2d_with_list_random_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=1).to(device) z = random_tensor(ndim=3).to(device) out = torch.atleast_2d([x, y, z]) return out @autotest(n=5) def test_atleast_2d_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(low=0, high=4).to(int)).to(device) out = torch.atleast_2d(x) return out @autotest(n=5) def test_atleast_3d_with_list_random_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=1).to(device) z = random_tensor(ndim=2).to(device) p = random_tensor(ndim=4).to(device) out = torch.atleast_3d([x, y, z, p]) return out @autotest(n=5) def test_atleast_3d_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(low=0, high=5).to(int)).to(device) out = torch.atleast_3d(x) return out if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_auto_to_global.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import os from oneflow.test_utils.automated_test_util.torch_flow_dual_object import globaltest from oneflow.test_utils.test_util import GenArgList def _test_auto_to_global(test_case, device): os.environ["ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT"] = "true" x = flow.ones( (2, 2), sbp=[flow.sbp.broadcast, flow.sbp.broadcast], placement=flow.placement(device, ranks=[[0], [1]]), ) y = flow.zeros( (2, 2), sbp=[flow.sbp.broadcast, flow.sbp.broadcast], placement=flow.placement(device, ranks=[[2], [3]]), ) z = x + y test_case.assertTrue(np.array_equal(x.numpy(), z.numpy())) test_case.assertEqual(y.placement, z.placement) os.environ["ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT"] = "false" @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestAutoToGlobal(flow.unittest.TestCase): @globaltest @flow.unittest.skip_unless_1n4d() def test_auto_to_global(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] for arg in GenArgList(arg_dict): _test_auto_to_global(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_autograd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import torch as original_torch import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList def _test_autograd_backward(test_case, shape, device): np_input = np.random.rand(*shape) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = of_input ** 2 of_out_sum = of_out.sum() of_out_sum.backward() test_case.assertTrue( np.allclose(of_input.grad.numpy(), np_input * 2, 0.0001, 0.0001) ) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = of_input ** 2 of_out_sum = of_out.sum() of_out_sum.backward(flow.ones_like(of_out_sum) * 3) test_case.assertTrue( np.allclose(of_input.grad.numpy(), np_input * 6, 0.0001, 0.0001) ) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = of_input ** 2 of_out_sum = of_out.sum() of_out_sum.backward(retain_graph=True) of_out_sum.backward(retain_graph=True) test_case.assertTrue( np.allclose(of_input.grad.numpy(), np_input * 4, 0.0001, 0.0001) ) def _test_autograd_grad(test_case, shape, device): np_input = np.random.rand(*shape) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = of_input ** 2 of_out_sum = of_out.sum() grad = flow.autograd.grad(of_out_sum, of_input)[0] test_case.assertTrue(of_input.grad is None) test_case.assertTrue(np.allclose(grad.numpy(), np_input * 2, 0.0001, 0.0001)) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = of_input ** 2 of_out_sum = of_out.sum() grad = flow.autograd.grad(of_out_sum, of_input, flow.ones_like(of_out_sum) * 3)[0] test_case.assertTrue(np.allclose(grad.numpy(), np_input * 6, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestAutograd(flow.unittest.TestCase): def test_autograd_interface(test_case): arg_dict = OrderedDict() arg_dict["case"] = [_test_autograd_backward, _test_autograd_grad] arg_dict["shape"] = [(2, 3), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True) def test_accumulate_grad(test_case): device = random_device() ndim = random(1, 4).to(int) x = random_tensor(ndim=ndim, requires_grad=True).to(device) y = random_tensor(ndim=ndim, requires_grad=True).to(device) return x / (x + y) @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True) def test_0dim_accumulate_grad(test_case): device = random_device() ndim = 0 x = random_tensor(ndim=ndim, requires_grad=True).to(device) y = random_tensor(ndim=ndim, requires_grad=True).to(device) return x / (x + y) @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True) def test_scalar_leaf_tensor_backward(test_case): device = random_device() ndim = 0 x = random_tensor(ndim=ndim, requires_grad=True).to(device) return x @autotest(n=1, auto_backward=False, check_graph=False) def test_out_grad_with_different_dtype(test_case): x = random_tensor(ndim=2, requires_grad=True) y = x.sum() y.backward(torch.tensor(False)) return x.grad @autotest(n=10, auto_backward=False, check_graph=False) def test_grad_grad(test_case): device = random_device() ndim = random(1, 4).to(int) x = random_tensor(ndim=ndim, requires_grad=True).to(device) y = x * x * x x_grad = 
torch.autograd.grad( outputs=y, inputs=x, grad_outputs=torch.ones_like(y), create_graph=True, retain_graph=True, )[0] x_grad_grad = torch.autograd.grad( outputs=x_grad, inputs=x, grad_outputs=torch.ones_like(x_grad) )[0] return x_grad_grad @autotest(n=10, auto_backward=False, rtol=1e-3, atol=1e-3, check_graph=False) def test_autograd_multiple_times(test_case): device = random_device() ndim = random(1, 4).to(int).value() dims = [random(0, 10).to(int) for _ in range(ndim)] x = random_tensor(ndim, *dims, requires_grad=True) x1 = x.to(device) y = random_tensor(ndim, *dims, requires_grad=True) y1 = y.to(device) z = x1 + y1 for _ in range(10): z.sum().backward() return (x.grad, y.grad) def test_autograd_set_acc_grad_and_backward(test_case): for _ in range(5): ndim = 2 dims = [random(1, 5).to(int).value() for _ in range(ndim)] x = torch.randn(*dims).requires_grad_() np_arr = np.random.rand(*dims) init_grad = torch.tensor(np_arr).to(x.dtype) x.pytorch.grad = init_grad.pytorch x.oneflow.grad = init_grad.oneflow x.sum().backward() test_case.assertTrue( np.allclose( x.grad.oneflow.numpy(), x.grad.pytorch.cpu().detach().numpy() ) ) @autotest(n=1, check_graph=False) def test_requires_grad_tensor_inplace_and_backward(test_case): random_shape = [random(1, 10).to(int) for _ in range(4)] x = random_tensor(4, *random_shape, requires_grad=False) y = random_tensor(4, *random_shape, requires_grad=True) x += y return x @autotest(n=1, check_graph=False) def test_retain_grad_for_leaf_tensor(test_case): random_shape = [random(1, 10).to(int) for _ in range(4)] x = random_tensor(4, *random_shape, requires_grad=True) y = x * 2 x.retain_grad() return y @autotest(n=1, auto_backward=False, check_graph=False) def test_run_backward_and_grad_for_same_tensor(test_case): random_shape = [random(1, 10).to(int) for _ in range(4)] x = random_tensor(4, *random_shape, requires_grad=True) y = x ** 2 y.sum().backward() test_case.assertTrue( np.allclose(x.grad.oneflow.numpy(), x.grad.pytorch.numpy()) ) y = x ** 2 x_grad = torch.autograd.grad(y.sum(), x)[0] test_case.assertTrue( np.allclose(x_grad.oneflow.numpy(), x_grad.pytorch.numpy()) ) test_case.assertTrue( np.allclose(x.grad.oneflow.numpy(), x_grad.oneflow.numpy()) ) @autotest(n=1, auto_backward=False, check_graph=False) def test_no_grad_domain_call_backward(test_case): random_shape = [random(1, 10).to(int).value() for _ in range(4)] with flow.no_grad(): x = flow.rand(*random_shape).requires_grad_() with flow.enable_grad(): y = x * 2 flow.autograd.backward(y, flow.ones_like(y)) test_case.assertTrue(np.array_equal(x.grad.numpy(), np.full(random_shape, 2.0))) @autotest(n=1, auto_backward=False, check_graph=False) def test_acc_grad_inplace_update(test_case): random_shape = [random(1, 5).to(int).value() for _ in range(4)] x = flow.rand(*random_shape).requires_grad_() y = flow.rand(*random_shape).requires_grad_() z = x / (x + y) z.sum().backward() id_x_grad = id(x.grad) id_y_grad = id(y.grad) z = x / (x + y) z.sum().backward() test_case.assertEqual(id_x_grad, id(x.grad)) test_case.assertEqual(id_y_grad, id(y.grad)) def test_autograd_grad_allow_unused(test_case): shape = [random(1, 10).to(int) for _ in range(4)] shape = [2, 4] device = random_device() x = random_tensor(len(shape), *shape, requires_grad=True).to(device) z = random_tensor(len(shape), *shape, requires_grad=True).to(device) y = x * x np_arr = np.random.rand(*y.oneflow.shape) init_grad = torch.tensor(np_arr).requires_grad_().to(device) dx_and_dz = torch.autograd.grad( y, [x, z], init_grad, retain_graph=True, 
create_graph=True, allow_unused=True, ) test_case.assertTrue( np.allclose( dx_and_dz[0].oneflow.detach().numpy(), dx_and_dz[0].pytorch.detach().cpu().numpy(), ) ) test_case.assertTrue( dx_and_dz[1].oneflow is None and dx_and_dz[1].pytorch is None ) np_arr = np.random.rand(*y.oneflow.shape) init_grad_grad = torch.tensor(np_arr).requires_grad_().to(device) ddx = torch.autograd.grad( dx_and_dz[0], x, init_grad_grad, retain_graph=True, create_graph=True, allow_unused=True, )[0] test_case.assertTrue( np.allclose( ddx.oneflow.detach().numpy(), ddx.pytorch.detach().cpu().numpy(), ) ) np_arr = np.random.rand(*y.oneflow.shape) init_grad_grad_grad = torch.tensor(np_arr).requires_grad_().to(device) dddx = torch.autograd.grad( ddx, x, init_grad_grad_grad, retain_graph=True, create_graph=True, allow_unused=True, )[0] test_case.assertTrue(dddx.oneflow is None and dddx.pytorch is None) def test_autograd_is_grads_batched(test_case): x = flow.randn(2, 2, requires_grad=True) out = x.clone() # Size([2, 2]) batched_grad = flow.arange(3).expand(2, 2, 3).transpose(0, 2) # Size([3, 2, 2]) (grad,) = flow.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True) test_case.assertTrue( np.array_equal( grad.cpu().detach().numpy(), flow.arange(3) .expand(2, 2, 3) .transpose(0, 2) .to(dtype=grad.dtype) .numpy(), ) ) # Detect shape mismatch grad_out = flow.ones(2, 2) with test_case.assertRaisesRegex( RuntimeError, "If `is_grads_batched=True`, we interpret the first" ): flow.autograd.grad( outputs=out, grad_outputs=(grad_out,), inputs=(x,), is_grads_batched=True, ) # TODO: ReduceSum backward not support broadcast grad with shape (3, ) to (3, 2, 2) # # Scalar outputs # out = x.sum() # Size([]) # batched_grad = flow.arange(3) # Size([3]) # (grad,) = flow.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True) # test_case.assertTrue( # np.array_equal( # grad.cpu().detach().numpy(), # flow.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype).numpy(), # ) # ) # We consider scalar and sized-1 to be a mismatch. This is consistent with current non-batched behavior. grad_out = flow.ones(2).unsqueeze(1) with test_case.assertRaisesRegex( RuntimeError, "If `is_grads_batched=True`, we interpret the first" ): flow.autograd.grad( outputs=out, grad_outputs=(grad_out,), inputs=(x,), is_grads_batched=True, ) def test_autograd_grad_none_list(test_case): x = flow.randn(10, 10, requires_grad=True) y = flow.randn(10, 10, requires_grad=True) merge = flow.cat([x, y], dim=0) s_x, s_y = flow.split(merge, 10, dim=0) s_x_sum = s_x.sum() s_y_sum = s_y.sum() (grad_x, grad_y) = flow.autograd.grad((s_x_sum, s_y_sum), (x, y), (None, None)) test_case.assertTrue( np.array_equal(grad_x.numpy(), np.ones(x.shape).astype(np.float32),) ) test_case.assertTrue( np.array_equal(grad_y.numpy(), np.ones(y.shape).astype(np.float32),) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_autograd_function.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import re import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow import autograd class TestAutogradFunction(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_simple_input(test_case): class MyReLU(autograd.Function): @staticmethod def forward(ctx, x): y = x.clamp(min=0.0, max=None) ctx.save_for_backward(x) return y @staticmethod def backward(ctx, y_grad): x_grad = y_grad.clone() (x,) = ctx.saved_tensors x_grad[x < 0] = 0 return x_grad np_arr = np.random.randn(4, 5) a = flow.tensor(np_arr).requires_grad_() # forward b = MyReLU.apply(a) test_case.assertTrue(np.allclose(b.numpy(), np_arr.clip(min=0.0))) # backward b.sum().backward() np_grad = np.ones((4, 5)) np_grad[np_arr < 0] = 0.0 test_case.assertTrue(np.allclose(a.grad.numpy(), np_grad)) @flow.unittest.skip_unless_1n1d() def test_multi_input(test_case): class MyMatMul(autograd.Function): @staticmethod def forward(ctx, x, y): z = x * y ctx.save_for_backward(x, y) return z @staticmethod def backward(ctx, z_grad): x, y = ctx.saved_tensors x_grad = y * z_grad y_grad = x * z_grad return x_grad, y_grad np_arr0 = np.random.randn(4, 5) np_arr1 = np.random.randn(4, 5) a = flow.tensor(np_arr0).requires_grad_() b = flow.tensor(np_arr1).requires_grad_() # forward c = MyMatMul().apply(a, b) test_case.assertTrue(np.allclose(c.numpy(), np_arr0 * np_arr1)) # backward c.sum().backward() test_case.assertTrue(np.allclose(a.grad.numpy(), np_arr1)) test_case.assertTrue(np.allclose(b.grad.numpy(), np_arr0)) @flow.unittest.skip_unless_1n1d() def test_non_differentiable_interface(test_case): class MyModule(autograd.Function): @staticmethod def forward(ctx, x, y): mul_res = x * y add_res = x + y ctx.save_for_backward(x, y) ctx.mark_non_differentiable(add_res) return mul_res, add_res @staticmethod def backward(ctx, mul_grad, add_grad=None): x, y = ctx.saved_tensors x_grad = y * mul_grad y_grad = x * mul_grad return x_grad, y_grad np_arr0 = np.random.randn(4, 5) np_arr1 = np.random.randn(4, 5) a = flow.tensor(np_arr0).requires_grad_() b = flow.tensor(np_arr1).requires_grad_() # forward c, d = MyModule().apply(a, b) test_case.assertTrue(np.allclose(c.numpy(), np_arr0 * np_arr1)) test_case.assertFalse(d.requires_grad) test_case.assertTrue(d.grad_fn is None) # backward c.sum().backward() test_case.assertTrue(np.allclose(a.grad.numpy(), np_arr1)) test_case.assertTrue(np.allclose(b.grad.numpy(), np_arr0)) @flow.unittest.skip_unless_1n1d() def test_partial_inputs_requires_grad(test_case): class MyModule(autograd.Function): @staticmethod def forward(ctx, x, y, z): return x + y + z @staticmethod def backward(ctx, out_grad): return None, out_grad, None x = flow.randn(4, 5) y = flow.randn(4, 5).requires_grad_() z = flow.randn(4, 5) # forward res = MyModule.apply(x, y, z) test_case.assertTrue( np.allclose(res.numpy(), x.numpy() + y.numpy() + z.numpy()) ) # backward res.sum().backward() test_case.assertIsNone(x.grad) test_case.assertTrue(np.allclose(y.grad.numpy(), np.ones((4, 5)))) test_case.assertIsNone(z.grad) @flow.unittest.skip_unless_1n1d() def test_dynamic_attr_for_ctx(test_case): class MyModule(autograd.Function): @staticmethod def forward(ctx, x): ctx.scale = 2.0 return x * ctx.scale @staticmethod def backward(ctx, out_grad): return out_grad * ctx.scale x = flow.randn(4, 5).requires_grad_() # forward res = MyModule.apply(x) test_case.assertTrue(np.allclose(res.numpy(), x.numpy() * 2.0)) # backward 
res.sum().backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones((4, 5)) * 2.0)) @flow.unittest.skip_unless_1n1d() def test_backward_error_message(test_case): class MyModule(autograd.Function): @staticmethod def forward(ctx, x, y, z): return x + y + z @staticmethod def backward(ctx, out_grad): return None, out_grad x = flow.randn(4, 5) y = flow.randn(4, 5).requires_grad_() z = flow.randn(4, 5) res = MyModule.apply(x, y, z) with test_case.assertRaises(Exception) as exp: res.sum().backward() test_case.assertIsNotNone( re.search( r"RuntimeError: function MyModule returned an incorrect number of gradients \(expected \d, got \d\)", str(exp.exception), ) ) @flow.unittest.skip_unless_1n1d() def test_graph_test_multi_input(test_case): class MyMul(autograd.Function): @staticmethod def forward(ctx, x, y): z = x * y ctx.save_for_backward(x, y) return z @staticmethod def backward(ctx, z_grad): x, y = ctx.saved_tensors x_grad = 2 * y * z_grad y_grad = 3 * x * z_grad return x_grad, y_grad class MyAdd(autograd.Function): @staticmethod def forward(ctx, x, y): return 2 * x + y @staticmethod def backward(ctx, z_grad): x_grad = z_grad y_grad = 2 * z_grad return x_grad, y_grad model = flow.nn.Linear(5, 4, bias=False) model.train() class MyGraph(flow.nn.Graph): def __init__(self): super().__init__() self.model = model optimizer = flow.optim.SGD(self.model.parameters()) self.add_optimizer(optimizer) def build(self, x, y): x.retain_grad() y.retain_grad() self.model.weight.retain_grad() z = MyMul().apply(x, y) z = MyAdd().apply(z, self.model.weight) z.sum().backward() return z, x.grad, y.grad, self.model.weight.grad np_arr0 = np.random.randn(4, 5).astype(np.float32) np_arr1 = np.random.randn(4, 5).astype(np.float32) np_arr2 = np.random.randn(4, 5).astype(np.float32) a = flow.tensor(np_arr0).requires_grad_() b = flow.tensor(np_arr1).requires_grad_() model.weight.copy_(np_arr2) c, a_grad, b_grad, w_grad = MyGraph()(a, b) test_case.assertTrue(np.allclose(c.numpy(), 2 * np_arr0 * np_arr1 + np_arr2)) test_case.assertTrue(np.allclose(a_grad.numpy(), 2 * np_arr1)) test_case.assertTrue(np.allclose(b_grad.numpy(), 3 * np_arr0)) test_case.assertTrue(np.allclose(w_grad.numpy(), 2 * np.ones_like(np_arr2))) @flow.unittest.skip_unless_1n1d() def test_autograd_function_memory(test_case): global_ctx = None class MyModule(autograd.Function): @staticmethod def forward(ctx, x): z = x.clone() ctx.save_for_backward(z) nonlocal global_ctx global_ctx = ctx return z @staticmethod def backward(ctx, out_grad): (x,) = ctx.saved_tensors return x x = flow.randn(5, 5).requires_grad_() res = MyModule.apply(x) test_case.assertTrue(global_ctx._is_data_valid()) res.sum().backward() # ensure that global_ctx is released test_case.assertFalse(global_ctx._is_data_valid()) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_autograd_mode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest class TestAutogradMode(oneflow.unittest.TestCase): def test_grad_mode(test_case): test_case.assertTrue(flow.is_grad_enabled()) def test_inference_mode(test_case): with flow.inference_mode(True): test_case.assertFalse(flow.is_grad_enabled()) test_case.assertTrue(flow.is_grad_enabled()) @flow.inference_mode(True) def func(): test_case.assertFalse(flow.is_grad_enabled()) func() test_case.assertTrue(flow.is_grad_enabled()) with flow.inference_mode(False): test_case.assertTrue(flow.is_grad_enabled()) test_case.assertTrue(flow.is_grad_enabled()) @flow.inference_mode(False) def func(): test_case.assertTrue(flow.is_grad_enabled()) func() test_case.assertTrue(flow.is_grad_enabled()) def test_enable_grad(test_case): with flow.enable_grad(): test_case.assertTrue(flow.is_grad_enabled()) test_case.assertTrue(flow.is_grad_enabled()) @flow.enable_grad() def func(): test_case.assertTrue(flow.is_grad_enabled()) func() test_case.assertTrue(flow.is_grad_enabled()) def test_no_grad(test_case): with flow.no_grad(): test_case.assertFalse(flow.is_grad_enabled()) test_case.assertTrue(flow.is_grad_enabled()) @flow.no_grad() def func(): test_case.assertFalse(flow.is_grad_enabled()) func() test_case.assertTrue(flow.is_grad_enabled()) def test_set_grad_enabled(test_case): def assert_grad_mode(mode): if mode: test_case.assertTrue(flow.is_grad_enabled()) else: test_case.assertFalse(flow.is_grad_enabled()) def get_decorater_func_with_mode(mode): @flow.set_grad_enabled(mode) def func(): assert_grad_mode(mode) return func def get_decorater_context_func_with_mode(dec_mode, ctx_mode): @flow.set_grad_enabled(dec_mode) def func(): assert_grad_mode(dec_mode) with flow.set_grad_enabled(ctx_mode): assert_grad_mode(ctx_mode) assert_grad_mode(dec_mode) return func flow.set_grad_enabled(False) assert_grad_mode(False) with flow.set_grad_enabled(True): assert_grad_mode(True) flow.set_grad_enabled(False) assert_grad_mode(False) func = get_decorater_func_with_mode(True) func() assert_grad_mode(False) flow.set_grad_enabled(True) assert_grad_mode(True) with flow.set_grad_enabled(False): assert_grad_mode(False) flow.set_grad_enabled(True) assert_grad_mode(True) func = get_decorater_func_with_mode(False) func() assert_grad_mode(True) get_decorater_context_func_with_mode(True, True)() get_decorater_context_func_with_mode(True, False)() get_decorater_context_func_with_mode(False, True)() get_decorater_context_func_with_mode(False, False)() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_avgpool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow from oneflow.test_utils.automated_test_util.generators import constant, random_bool import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestAvgPoolingModule(flow.unittest.TestCase): @autotest(n=5) def test_avgpool1d_with_random_data(test_case): m = torch.nn.AvgPool1d( kernel_size=random(4, 6), stride=random(1, 3) | nothing(), padding=random(1, 3) | nothing(), ceil_mode=random(), count_include_pad=random(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim2=random(20, 22)).to(device) y = m(x) return y @autotest(n=5) def test_avgpool2d_with_random_data(test_case): m = torch.nn.AvgPool2d( kernel_size=random(4, 6), stride=random(1, 3) | nothing(), padding=random(1, 3) | nothing(), ceil_mode=random(), count_include_pad=random(), divisor_override=random().to(int), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(device) y = m(x) return y # TODO:(zhaoluyang) this test case has probability to fail in backward @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5, rtol=0.001, atol=0.001, auto_backward=False) def test_avgpool2d_with_half_data(test_case): m = torch.nn.AvgPool2d( kernel_size=random(4, 6), stride=random(1, 3) | nothing(), padding=random(1, 3) | nothing(), ceil_mode=random(), count_include_pad=random(), divisor_override=random().to(int), ) m.train(random()) device = gpu_device() m.to(device) x = ( random_tensor( ndim=4, dim2=random(20, 22), dim3=random(20, 22), requires_grad=False ) .to(device) .to(torch.float16) ) y = m(x) return y @autotest(n=5) def test_avgpool3d_with_random_data(test_case): m = torch.nn.AvgPool3d( kernel_size=random(4, 6), stride=random(1, 3) | nothing(), padding=random(1, 3) | nothing(), ceil_mode=random(), count_include_pad=random(), divisor_override=random().to(int), ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22) ).to(device) y = m(x) return y @flow.unittest.skip_unless_1n1d() class TestAvgPoolingFunctional(flow.unittest.TestCase): @autotest(n=5) def test_avgpool1d_functional(test_case): device = random_device() x = random_tensor(ndim=3, dim2=random(20, 22)).to(device) y = torch.nn.functional.avg_pool1d( x, kernel_size=random(1, 6).to(int), stride=random(1, 3).to(int) | nothing(), padding=random(1, 3).to(int), ceil_mode=random_bool(), count_include_pad=random_bool(), ) return y @autotest(n=5) def test_avgpool2d_functional(test_case): device = random_device() x = random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(device) y = torch.nn.functional.avg_pool2d( x, kernel_size=random(1, 6).to(int), stride=random(1, 3).to(int) | nothing(), padding=random(1, 3).to(int), ceil_mode=random_bool(), count_include_pad=random_bool(), ) return y @autotest(n=5) def test_avgpool3d_functional(test_case): device = random_device() x = random_tensor( ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22) ).to(device) y = torch.nn.functional.avg_pool3d( x, kernel_size=random(1, 6).to(int), stride=random(1, 3).to(int) | nothing(), padding=random(1, 3).to(int), ceil_mode=random_bool(), count_include_pad=random_bool(), ) return y @profile(torch.nn.functional.avg_pool2d) def profile_avgpool2d(test_case): torch.nn.functional.avg_pool2d( torch.ones(1, 128, 28, 28), kernel_size=3, padding=1 ) 
torch.nn.functional.avg_pool2d( torch.ones(1, 128, 28, 28), kernel_size=3, stride=2, padding=1 ) torch.nn.functional.avg_pool2d( torch.ones(16, 128, 28, 28), kernel_size=3, padding=1 ) torch.nn.functional.avg_pool2d( torch.ones(16, 128, 28, 28), kernel_size=3, stride=2, padding=1 ) torch.nn.functional.avg_pool2d( torch.ones(16, 128, 28, 28), kernel_size=3, stride=2, padding=1, ceil_mode=True, ) if __name__ == "__main__": unittest.main()
================================================
FILE: python/oneflow/test/modules/test_baddbmm.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np
from oneflow.test_utils.test_util import GenArgList

import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@flow.unittest.skip_unless_1n1d()
class TestBaddBmmModule(flow.unittest.TestCase):
    @autotest(n=5, rtol=1e-4, atol=1e-3)
    def test_baddbmm_with_torch(test_case):
        device = random_device()
        input = random_tensor(ndim=3, dim0=2, dim1=4, dim2=4).to(device)
        batch1 = random_tensor(ndim=3, dim0=2, dim1=4, dim2=3).to(device)
        batch2 = random_tensor(ndim=3, dim0=2, dim1=3, dim2=4).to(device)
        y = torch.baddbmm(input, batch1, batch2, beta=2.0, alpha=1.2)
        return y

    @autotest(n=5, rtol=1e-4, atol=1e-3)
    def test_baddbmm_in_sd2_with_torch(test_case):
        device = random_device()
        input = random_tensor(ndim=3, dim0=2, dim1=2, dim2=2, requires_grad=False).to(
            device
        )
        batch1 = random_tensor(ndim=3, dim0=2, dim1=2, dim2=2).to(device)
        batch2 = random_tensor(ndim=3, dim0=2, dim1=2, dim2=2).to(device)
        y = torch.baddbmm(input, batch1, batch2, beta=0.0, alpha=1.2)
        return y

    @autotest(n=5, rtol=1e-4, atol=1e-3)
    def test_baddbmm_no_attr_with_torch(test_case):
        device = random_device()
        input = random_tensor(ndim=3, dim0=2, dim1=4, dim2=4).to(device)
        batch1 = random_tensor(ndim=3, dim0=2, dim1=4, dim2=3).to(device)
        batch2 = random_tensor(ndim=3, dim0=2, dim1=3, dim2=4).to(device)
        y = torch.baddbmm(input, batch1, batch2)
        return y

    @autotest(n=5, rtol=1e-4, atol=1e-3)
    def test_baddbmm_broadcast_with_torch(test_case):
        device = random_device()
        input = random_tensor(ndim=1, dim0=4).to(device)
        batch1 = random_tensor(ndim=3, dim0=2, dim1=4, dim2=3).to(device)
        batch2 = random_tensor(ndim=3, dim0=2, dim1=3, dim2=4).to(device)
        y = torch.baddbmm(input, batch1, batch2, beta=-1.98, alpha=1.34)
        return y

    @profile(torch.baddbmm)
    def profile_baddbmm(test_case):
        input = torch.ones(10, 100, 100)
        batch1 = torch.ones(10, 100, 100)
        batch2 = torch.ones(10, 100, 100)
        torch.baddbmm(input, batch1, batch2, beta=-1.98, alpha=1.34)


if __name__ == "__main__":
    unittest.main()
================================================ FILE: python/oneflow/test/modules/test_batch_gather.py ================================================ """ Copyright 2020 The OneFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_batch_gather(test_case, shape, device): # for example: shape = (3, 2, 2) x = np.random.randn(*shape) x_tensor = flow.Tensor(x).to(device) x_tensor.requires_grad = True batchsize = x.shape[0] init_index = np.array( [np.random.randint(batchsize) for i in range(batchsize)] ).astype(np.int64) batch_gather_index = flow.tensor(init_index).to(device) batch_gather_out = flow.batch_gather(x_tensor, batch_gather_index) x_tensor_gather = flow.Tensor(x).to(device) x_tensor_gather.requires_grad = True reshaped_shape = [batchsize] # reshaped_shape = [3] for i in range(len(x.shape) - 1): reshaped_shape.append(1) # reshaped_shape = [3] -> [3, 1, 1] gather_index = np.reshape(init_index, reshaped_shape) gather_index = np.broadcast_to(gather_index, shape).astype( np.int64 ) # [3, 1, 1] -> [3, 2, 2] gather_index = flow.tensor(gather_index).to(device) gather_out = flow.gather(x_tensor_gather, 0, gather_index) total_out = batch_gather_out.sum() + gather_out.sum() total_out.backward() test_case.assertTrue( np.allclose(batch_gather_out.numpy(), gather_out.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose( x_tensor.grad.numpy(), x_tensor_gather.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( x_tensor.grad.numpy(), x_tensor_gather.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) @flow.unittest.skip_unless_1n1d() class TestBatchGather(flow.unittest.TestCase): def test_batch_gather(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_batch_gather] arg_dict["shape"] = [(3, 2, 2), (3, 2, 4, 2), (3, 3, 4, 2, 2), (4, 2)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_batchnorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestBatchNormModule(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 11 times in past week") @autotest( auto_backward=True, rtol=1e-3, atol=1e-3, check_grad_use_random_data=False ) def test_batchnorm1d_module_with_random_data(test_case): device = random_device() channel = random(1, 4).to(int) m = torch.nn.BatchNorm1d( num_features=channel, track_running_stats=random().to(bool), affine=random().to(bool), ).to(device) m.train(random()) x = random_tensor( ndim=3, dim0=random(1, 4), dim1=channel, requires_grad=True ).to(device) y = m(x) return y @autotest( auto_backward=True, rtol=1e-3, atol=1e-3, check_grad_use_random_data=False ) def test_batchnorm2d_module_with_random_data(test_case): device = random_device() channel = random(1, 4).to(int) m = torch.nn.BatchNorm2d( num_features=channel, track_running_stats=random().to(bool), affine=random().to(bool), ).to(device) m.train(random()) x = random_tensor( ndim=4, dim0=random(1, 4), dim1=channel, requires_grad=True ).to(device) y = m(x) return y @autotest( auto_backward=True, rtol=1e-3, atol=1e-3, check_grad_use_random_data=False ) def test_batchnorm3d_module_with_random_data(test_case): device = random_device() channel = random(1, 4).to(int) m = torch.nn.BatchNorm3d( num_features=channel, track_running_stats=random().to(bool), affine=random().to(bool), ).to(device) m.train(random()) x = random_tensor(ndim=5, dim1=channel, requires_grad=True).to(device) y = m(x) return y @autotest(rtol=1e-3, atol=1e-3, check_grad_use_random_data=False) def test_functional_batchnorm_with_random_data(test_case): device = random_device() channel = random(1, 4).to(int) x = random_tensor(ndim=5, dim1=channel, requires_grad=True).to(device) running_mean = random_tensor(ndim=1, dim0=channel, requires_grad=False) running_var = random_tensor(ndim=1, dim0=channel, low=0.0, requires_grad=False) weight = random_tensor(ndim=1, dim0=channel) bias = random_tensor(ndim=1, dim0=channel) result = torch.nn.functional.batch_norm( input=x, running_mean=running_mean, running_var=running_var, weight=weight, bias=bias, training=random_bool(), ) return result @autotest(rtol=1e-3, atol=1e-3, auto_backward=False, check_graph=False) def test_batchnorm2d_module_with_half_random_data(test_case): device = random_device() channel = random(1, 4).to(int) m = torch.nn.BatchNorm2d( num_features=channel, track_running_stats=random().to(bool), affine=random().to(bool), ).to(device) m.train(random()) m.half() x = random_tensor( ndim=4, dim0=random(1, 4), dim1=channel, requires_grad=True ).to(device) x.half() y = m(x) return y @profile(torch.nn.functional.batch_norm) def profile_batchnorm(test_case): input = torch.ones(16, 128, 28, 28) running_mean = torch.randn(128) running_var = torch.randn(128) weight = torch.randn(128) bias = torch.randn(128) torch.nn.functional.batch_norm( input, running_mean, running_var, weight, bias, True ) torch.nn.functional.batch_norm( input, running_mean, running_var, weight, bias, False ) torch.nn.functional.batch_norm( input, running_mean, running_var, None, None, True ) torch.nn.functional.batch_norm( input, running_mean, running_var, None, None, False ) if __name__ == "__main__": unittest.main() ================================================ FILE: 
python/oneflow/test/modules/test_batchnorm_add_relu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_bn_add_relu(test_case, device, batch, channel, height, width): weight_numpy = np.random.randn(channel) bias_numpy = np.random.randn(channel) fused_x = np.random.randn(batch, channel, height, width) fused_x_tensor = flow.Tensor(fused_x).to(device) fused_x_tensor.requires_grad = True fused_addend = np.random.randn(batch, channel, height, width) fused_addend_tensor = flow.Tensor(fused_addend).to(device) fused_addend_tensor.requires_grad = True fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) fused_bn = flow.nn.FusedBatchNorm2d(channel).to(device) fused_bn.weight = fused_weight_tensor fused_bn.bias = fused_bias_tensor fused_out = fused_bn(fused_x_tensor, fused_addend_tensor) origin_x_tensor = flow.Tensor(fused_x).to(device) origin_x_tensor.requires_grad = True origin_addend_tensor = flow.Tensor(fused_addend).to(device) origin_addend_tensor.requires_grad = True origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) origin_batch_norm = flow.nn.BatchNorm2d(channel).to(device) origin_batch_norm.weight = origin_weight_tensor origin_batch_norm.bias = origin_bias_tensor origin_out = origin_batch_norm(origin_x_tensor) + origin_addend_tensor origin_out = flow.nn.functional.relu(origin_out) total_out = fused_out + origin_out total_out_sum = total_out.sum() total_out_sum.backward() # test output. test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) # test input grad. test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_addend_tensor.grad.numpy(), origin_addend_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # test weight and bias grad. test_case.assertTrue( np.allclose( fused_weight_tensor.grad.numpy(), origin_weight_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bias_tensor.grad.numpy(), origin_bias_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # test running mean and running variance. 
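# both modules start from identical default statistics and see the same batch in
# training mode, so their updated running_mean / running_var should match element-wise.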
test_case.assertTrue( np.allclose( fused_bn.running_mean.numpy(), origin_batch_norm.running_mean.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bn.running_var.numpy(), origin_batch_norm.running_var.numpy(), atol=1e-4, rtol=1e-4, ) ) def _test_bn_relu(test_case, device, batch, channel, height, width): weight_numpy = np.random.randn(channel) bias_numpy = np.random.randn(channel) fused_x = np.random.randn(batch, channel, height, width) fused_x_tensor = flow.Tensor(fused_x).to(device) fused_x_tensor.requires_grad = True fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) fused_bn = flow.nn.FusedBatchNorm2d(channel).to(device) fused_bn.weight = fused_weight_tensor fused_bn.bias = fused_bias_tensor fused_out = fused_bn(fused_x_tensor, None) origin_x_tensor = flow.Tensor(fused_x).to(device) origin_x_tensor.requires_grad = True origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) origin_batch_norm = flow.nn.BatchNorm2d(channel).to(device) origin_batch_norm.weight = origin_weight_tensor origin_batch_norm.bias = origin_bias_tensor origin_out = origin_batch_norm(origin_x_tensor) origin_out = flow.nn.functional.relu(origin_out) total_out = fused_out + origin_out total_out_sum = total_out.sum() total_out_sum.backward() # test output. test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) # test input grad. test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # test weight and bias grad. test_case.assertTrue( np.allclose( fused_weight_tensor.grad.numpy(), origin_weight_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bias_tensor.grad.numpy(), origin_bias_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # test running mean and running variance. 
test_case.assertTrue( np.allclose( fused_bn.running_mean.numpy(), origin_batch_norm.running_mean.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bn.running_var.numpy(), origin_batch_norm.running_var.numpy(), atol=1e-4, rtol=1e-4, ) ) def _test_bn_relu_track_running_states_false( test_case, device, batch, channel, height, width ): weight_numpy = np.random.randn(channel) bias_numpy = np.random.randn(channel) fused_x = np.random.randn(batch, channel, height, width) fused_x_tensor = flow.Tensor(fused_x).to(device) fused_x_tensor.requires_grad = True fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) fused_bn = flow.nn.FusedBatchNorm2d(channel, track_running_stats=False).to(device) fused_bn.weight = fused_weight_tensor fused_bn.bias = fused_bias_tensor fused_out = fused_bn(fused_x_tensor, None) origin_x_tensor = flow.Tensor(fused_x).to(device) origin_x_tensor.requires_grad = True origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) origin_batch_norm = flow.nn.BatchNorm2d(channel, track_running_stats=False).to( device ) origin_batch_norm.weight = origin_weight_tensor origin_batch_norm.bias = origin_bias_tensor origin_out = origin_batch_norm(origin_x_tensor) origin_out = flow.nn.functional.relu(origin_out) total_out = fused_out + origin_out total_out_sum = total_out.sum() total_out_sum.backward() # test output. test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) # test input grad. test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # test weight and bias grad. test_case.assertTrue( np.allclose( fused_weight_tensor.grad.numpy(), origin_weight_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bias_tensor.grad.numpy(), origin_bias_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # When track running states is False, the running mean and running variance will be set as None. 
test_case.assertIsNone(fused_bn.running_mean) test_case.assertIsNone(origin_batch_norm.running_mean) test_case.assertIsNone(fused_bn.running_var) test_case.assertIsNone(origin_batch_norm.running_var) def _test_bn_add_relu_track_running_states_false( test_case, device, batch, channel, height, width ): weight_numpy = np.random.randn(channel) bias_numpy = np.random.randn(channel) fused_x = np.random.randn(batch, channel, height, width) fused_x_tensor = flow.Tensor(fused_x).to(device) fused_x_tensor.requires_grad = True fused_addend = np.random.randn(batch, channel, height, width) fused_addend_tensor = flow.Tensor(fused_addend).to(device) fused_addend_tensor.requires_grad = True fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) fused_bn = flow.nn.FusedBatchNorm2d(channel, track_running_stats=False).to(device) fused_bn.weight = fused_weight_tensor fused_bn.bias = fused_bias_tensor fused_out = fused_bn(fused_x_tensor, fused_addend_tensor) origin_x_tensor = flow.Tensor(fused_x).to(device) origin_x_tensor.requires_grad = True origin_addend_tensor = flow.Tensor(fused_addend).to(device) origin_addend_tensor.requires_grad = True origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) origin_batch_norm = flow.nn.BatchNorm2d(channel, track_running_stats=False).to( device ) origin_batch_norm.weight = origin_weight_tensor origin_batch_norm.bias = origin_bias_tensor origin_out = origin_batch_norm(origin_x_tensor) + origin_addend_tensor origin_out = flow.nn.functional.relu(origin_out) total_out = fused_out + origin_out total_out_sum = total_out.sum() total_out_sum.backward() # test output. test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) # test input grad. test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_addend_tensor.grad.numpy(), origin_addend_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # test weight and bias grad. test_case.assertTrue( np.allclose( fused_weight_tensor.grad.numpy(), origin_weight_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bias_tensor.grad.numpy(), origin_bias_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # When track running states is False, the running mean and running variance will be set as None. 
test_case.assertIsNone(fused_bn.running_mean) test_case.assertIsNone(origin_batch_norm.running_mean) test_case.assertIsNone(fused_bn.running_var) test_case.assertIsNone(origin_batch_norm.running_var) def _test_bn_add_relu_eval(test_case, device, batch, channel, height, width): weight_numpy = np.random.randn(channel) bias_numpy = np.random.randn(channel) fused_x = np.random.randn(batch, channel, height, width) fused_x_tensor = flow.Tensor(fused_x).to(device) fused_addend = np.random.randn(batch, channel, height, width) fused_addend_tensor = flow.Tensor(fused_addend).to(device) fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) fused_bn = flow.nn.FusedBatchNorm2d(channel).to(device) fused_bn.eval() fused_bn.weight = fused_weight_tensor fused_bn.bias = fused_bias_tensor fused_out = fused_bn(fused_x_tensor, fused_addend_tensor) origin_x_tensor = flow.Tensor(fused_x).to(device) origin_addend_tensor = flow.Tensor(fused_addend).to(device) origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) origin_batch_norm = flow.nn.BatchNorm2d(channel).to(device) origin_batch_norm.eval() origin_batch_norm.weight = origin_weight_tensor origin_batch_norm.bias = origin_bias_tensor origin_out = origin_batch_norm(origin_x_tensor) + origin_addend_tensor origin_out = flow.nn.functional.relu(origin_out) # test output. test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) def _test_bn_relu_eval(test_case, device, batch, channel, height, width): weight_numpy = np.random.randn(channel) bias_numpy = np.random.randn(channel) fused_x = np.random.randn(batch, channel, height, width) fused_x_tensor = flow.Tensor(fused_x).to(device) fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) fused_bn = flow.nn.FusedBatchNorm2d(channel).to(device) fused_bn.eval() fused_bn.weight = fused_weight_tensor fused_bn.bias = fused_bias_tensor fused_out = fused_bn(fused_x_tensor) origin_x_tensor = flow.Tensor(fused_x).to(device) origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device)) origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device)) origin_batch_norm = flow.nn.BatchNorm2d(channel).to(device) origin_batch_norm.eval() origin_batch_norm.weight = origin_weight_tensor origin_batch_norm.bias = origin_bias_tensor origin_out = origin_batch_norm(origin_x_tensor) origin_out = flow.nn.functional.relu(origin_out) # test output. 
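# in eval() mode both modules normalize with their (identical, default-initialized)
# running statistics and no backward pass is run, so only the forward outputs are compared.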
test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestBnAddRelu(flow.unittest.TestCase): def test_bn_add_relu2d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_bn_add_relu, _test_bn_relu, _test_bn_relu_track_running_states_false, _test_bn_add_relu_track_running_states_false, _test_bn_add_relu_eval, _test_bn_relu_eval, ] arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["batch"] = [1, 2, 8] arg_dict["channels"] = [4, 6] arg_dict["height"] = [6, 8] arg_dict["width"] = [12, 8] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_bernoulli.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_bernoulli(test_case, shape, p, dtype): input_arr = np.ones(shape) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device("cpu")) if p is None: y = flow.bernoulli(x, dtype=dtype) else: y = flow.bernoulli(x, p=p, dtype=dtype) test_case.assertTrue(y.dtype == dtype) if p == 1 or p is None: test_case.assertTrue(np.allclose(y.numpy(), x.numpy())) elif p == 0: test_case.assertTrue(np.allclose(y.numpy(), np.zeros(shape))) def _test_bernoulli_with_generator(test_case, shape): generator = flow.Generator() generator.manual_seed(0) x = flow.tensor( np.random.rand(*shape), dtype=flow.float32, device=flow.device("cpu") ) y_1 = flow.bernoulli(x, generator=generator) generator.manual_seed(0) y_2 = flow.bernoulli(x, generator=generator) test_case.assertTrue(np.allclose(y_1.numpy(), y_2.numpy())) @flow.unittest.skip_unless_1n1d() class TestBernoulli(flow.unittest.TestCase): def test_bernoulli(test_case): arg_dict = OrderedDict() arg_dict["test_functions"] = [_test_bernoulli] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["p"] = [None, 0, 1] arg_dict["dtype"] = [flow.float32, flow.int64] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skip("bernoulli has bug") @autotest(auto_backward=False) def test_flow_bernoulli_with_random_data(test_case): input = random_tensor(ndim=1).to("cpu") return torch.bernoulli(input) """ @profile(torch.bernoulli) def profile_bernoulli(test_case): torch.bernoulli(torch.ones(3, 3)) torch.bernoulli(torch.zeros(3, 3)) """ if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_binary_math_ops_dtype.py ================================================ """ Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from itertools import product import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def get_dtype_str(dtype): return str(dtype).split(".")[-1] dtype_list = [ torch.int8, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64, ] @flow.unittest.skip_unless_1n1d() class TestBinaryMathOpsDtype(flow.unittest.TestCase): @autotest(n=2, auto_backward=False, check_graph=False) def test_binary_math_ops_dtype(test_case): device = random_device() for x1_dtype, x2_dtype in product(dtype_list, dtype_list): x1 = random_tensor(2, 2, 3, requires_grad=False).to(device).to(x1_dtype) x2 = random_tensor(2, 2, 3, requires_grad=False).to(device).to(x2_dtype) for op in ["+", "-", "*", "/"]: y = eval(f"x1 {op} x2") test_case.assertEqual( get_dtype_str(y.oneflow.dtype), get_dtype_str(y.pytorch.dtype) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_bincount.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestBinCount(flow.unittest.TestCase): @autotest(n=5, auto_backward=False, check_graph=False) def test_bincount(test_case): device = random_device() x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device) result = torch.bincount(x) return result @autotest(n=5, auto_backward=False, check_graph=False) def test_bincount_weight(test_case): device = random_device() x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device) weight = random_tensor(1, 100).to(device) return torch.bincount(x, weights=weight) @autotest(n=5, auto_backward=False, check_graph=False) def test_bincount_minlength(test_case): device = random_device() x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device) weight = random_tensor(1, 100).to(device) minlength = random(1, 200).to(int) return torch.bincount(x, weights=weight, minlength=minlength) @autotest(n=5, auto_backward=False, check_graph=False) def test_bincount_0element(test_case): device = random_device() x = random_tensor(1, 0, low=0, high=65536, dtype=int).to(device) weight = random_tensor(1, 0).to(device) minlength = random(1, 200).to(int) return torch.bincount(x, weights=weight, minlength=minlength) @profile(torch.bincount) def profile_bincount(test_case): torch.bincount(torch.ones(4096).int()) torch.bincount(torch.ones(65536).int()) torch.bincount(torch.arange(4096).int()) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_bitwise.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.automated_test_util import * def _test_bitwise_op(test_case, op): device = random_device() dims_kwargs = { "ndim": 4, "dim0": random(low=4, high=8).to(int), "dim1": random(low=4, high=8).to(int), "dim2": random(low=4, high=8).to(int), "dim3": random(low=4, high=8).to(int), } # TODO(WangYi): oneflow doesn't support conversion between uint8 and int8 # So, use "index" instead of "int" in `random_dtype` x_dtype = random_dtype(["index", "bool", "unsigned"]) y_dtype = random_dtype(["index", "bool", "unsigned"]) x = random_tensor(dtype=int, **dims_kwargs,).to(device).to(x_dtype) y = random_tensor(dtype=int, **dims_kwargs,).to(device).to(y_dtype) bool_tensor = random_tensor(low=-1, high=1, **dims_kwargs,).to(device) > 0 return op(op(x, y), bool_tensor) def _test_scalar_bitwise(test_case, op): device = random_device() dtype = random_dtype(["int", "bool", "unsigned"]) x = ( random_tensor( ndim=4, dim0=random(low=4, high=8).to(int), dim1=random(low=4, high=8).to(int), dim2=random(low=4, high=8).to(int), dim3=random(low=4, high=8).to(int), dtype=int, ) .to(device) .to(dtype) ) scalar = random(low=-10, high=10).to(int) bool_scalar = random_bool() result = op(op(x, scalar), bool_scalar) return result # Bitwise ops only accept integral dtype, # so auto_backward isn't necessary @flow.unittest.skip_unless_1n1d() class TestBitwiseAndModule(flow.unittest.TestCase): @autotest(n=10, auto_backward=False) def test_bitwise_and(test_case): return _test_bitwise_op(test_case, torch.bitwise_and) @autotest(n=10, auto_backward=False) def test_scalar_bitwise_and(test_case): return _test_scalar_bitwise(test_case, torch.bitwise_and,) @flow.unittest.skip_unless_1n1d() class TestBitwiseOrModule(flow.unittest.TestCase): @autotest(n=10, auto_backward=False) def test_bitwise_or(test_case): return _test_bitwise_op(test_case, torch.bitwise_or) @autotest(n=10, auto_backward=False) def test_scalar_bitwise_or(test_case): return _test_scalar_bitwise(test_case, torch.bitwise_or,) @flow.unittest.skip_unless_1n1d() class TestBitwiseXorModule(flow.unittest.TestCase): @autotest(n=10, auto_backward=False) def test_bitwise_xor(test_case): return _test_bitwise_op(test_case, torch.bitwise_xor) @autotest(n=10, auto_backward=False) def test_scalar_bitwise_xor(test_case): return _test_scalar_bitwise(test_case, torch.bitwise_xor,) @flow.unittest.skip_unless_1n1d() class TestBitwiseNotModule(flow.unittest.TestCase): @autotest(n=10, auto_backward=False) def test_bitwise_not(test_case): device = random_device() # TODO(WangYi): oneflow doesn't support conversion between uint8 and int8 # So, use "index" instead of "int" in `random_dtype` dtype = random_dtype(["index", "bool", "unsigned"]) x = ( random_tensor( ndim=4, dim0=random(low=4, high=8).to(int), dim1=random(low=4, high=8).to(int), dim2=random(low=4, high=8).to(int), dim3=random(low=4, high=8).to(int), dtype=int, high=10, ) .to(device) .to(dtype) ) return torch.bitwise_not(x) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_bmm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_bmm(test_case, device): input1 = flow.tensor( np.random.randn(10, 3, 4), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.random.randn(10, 4, 5), dtype=flow.float32, device=flow.device(device) ) of_out = flow.bmm(input1, input2) np_out = np.matmul(input1.numpy(), input2.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_bmm_backward(test_case, device): input1 = flow.tensor( [ [ [-0.0036776792258024216, 1.9946473836898804, -0.423959881067276], [1.0892143249511719, 0.04005361348390579, -0.27883127331733704], ], [ [-0.970306396484375, 0.017771577462553978, 0.019596196711063385], [0.27402883768081665, -0.8192587494850159, -0.3135920464992523], ], ], dtype=flow.float32, device=flow.device(device), requires_grad=True, ) input2 = flow.tensor( [ [ [1.118346929550171, -0.930071234703064], [1.1238232851028442, 1.373764157295227], [0.17178462445735931, -1.1010534763336182], ], [ [0.6694859862327576, 0.9250285029411316], [-1.0835869312286377, 0.4192655086517334], [1.2616937160491943, 0.33809131383895874], ], ], dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.bmm(input1, input2) of_out = of_out.sum() of_out.backward() np_grad = [ [ [0.18827569484710693, 2.4975874423980713, -0.9292688369750977], [0.18827569484710693, 2.4975874423980713, -0.9292688369750977], ], [ [1.5945144891738892, -0.6643214225769043, 1.5997850894927979], [1.5945144891738892, -0.6643214225769043, 1.5997850894927979], ], ] test_case.assertTrue( np.allclose(input1.grad.numpy(), np_grad, atol=1e-05, rtol=1e-05) ) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_bmm(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_bmm, _test_bmm_backward] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(check_graph=True, rtol=1e-4, atol=1e-3) def test_bmm_with_torch(test_case): device = random_device() mat1 = random_tensor(ndim=3, dim0=2, dim1=4, dim2=3).to(device) mat2 = random_tensor(ndim=3, dim0=2, dim1=3, dim2=4).to(device) y = torch.bmm(mat1, mat2,) return y @profile(torch.bmm) def profile_bmm(test_case): mat1 = torch.ones(10, 100, 100) mat2 = torch.ones(10, 100, 100) torch.bmm(mat1, mat2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_broadcast_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_broadcast_like(test_case, device): input = flow.tensor( np.ones(shape=(3, 1, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(3, 3, 3), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(1, 2)) np_out = np.ones(shape=(3, 3, 3)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_broadcast_like_one(test_case, device): input = flow.tensor( np.ones(shape=(1, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(1, 2, 3), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.broadcast_like(input, like_tensor) np_out = np.ones(shape=(1, 2, 3)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_broadcast_like_different_dim(test_case, device): input = flow.tensor( np.ones(shape=(3, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(2, 3, 4), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.broadcast_like(input, like_tensor) np_out = np.ones(shape=(2, 3, 4)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_broadcast_like_different_dim_with_input_axisvec(test_case, device): input = flow.tensor( np.ones(shape=(1, 5, 6), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(1, 5, 6, 1, 6), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(3, 4)) np_out = np.ones(shape=(1, 5, 6, 1, 6)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_broadcast_like_3dim(test_case, device): input = flow.tensor( np.ones(shape=(1, 3, 2), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(3, 3, 2), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(0,)) np_out = np.ones(shape=(3, 3, 2)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_broadcast_like_4dim(test_case, device): input = flow.tensor( np.ones(shape=(1, 3, 2, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(3, 3, 2, 3), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(0, 3)) np_out = np.ones(shape=(3, 3, 2, 3)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_broadcast_like_empty_axisvec(test_case, device): input = flow.tensor( np.ones(shape=(1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( 
np.ones(shape=(2, 3, 4), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.broadcast_like(input, like_tensor) np_out = np.ones(shape=(2, 3, 4)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_broadcast_like_backward(test_case, device): input = flow.tensor( np.ones(shape=(3, 1, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) like_tensor = flow.tensor( np.ones(shape=(3, 3, 3), dtype=np.float32), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(1, 2)) of_out = of_out.sum() of_out.backward() np_grad = [[[9.0]], [[9.0]], [[9.0]]] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestBroadCastLike(flow.unittest.TestCase): def test_broadcast_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_broadcast_like, _test_broadcast_like_one, _test_broadcast_like_different_dim, _test_broadcast_like_different_dim_with_input_axisvec, _test_broadcast_like_3dim, _test_broadcast_like_4dim, _test_broadcast_like_empty_axisvec, _test_broadcast_like_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_broadcast_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import torch as ori_torch import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * binary_ops = [ torch.add, torch.sub, torch.mul, torch.div, torch.min, torch.minimum, torch.max, torch.maximum, torch.fmod, torch.pow, torch.eq, torch.ne, torch.gt, torch.ge, torch.lt, torch.le, torch.logical_and, torch.logical_or, torch.logical_xor, ] @flow.unittest.skip_unless_1n1d() class TestBroadcastOps(flow.unittest.TestCase): @autotest(n=5, auto_backward=False) def test_broadcast_elementwise(test_case): op_idx = random(low=0, high=len(binary_ops)).to(int).value() op = binary_ops[op_idx] device = random_device() x = random_tensor(ndim=4, dim0=2, dim1=2, dim2=3, dim3=4).to(device) y = random_tensor(ndim=4, dim0=1, dim1=2, dim2=3, dim3=1).to(device) out = op(x, y) return out @autotest(n=5, auto_backward=False) def test_broadcast_matrix_row(test_case): op_idx = random(low=0, high=len(binary_ops)).to(int).value() op = binary_ops[op_idx] device = random_device() x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3).to(device) y = random_tensor(ndim=2, dim0=2, dim1=3).to(device) out = op(x, y) return out @autotest(n=5, auto_backward=False) def test_broadcast_matrix_col(test_case): op_idx = random(low=0, high=len(binary_ops)).to(int).value() op = binary_ops[op_idx] device = random_device() x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3).to(device) y = random_tensor(ndim=3, dim0=2, dim1=2, dim2=1).to(device) out = op(x, y) return out @autotest(n=5, auto_backward=False) def test_cpu_scalar_tensor_auto_cast(test_case): def check_output(test_case, output): of_res = output.oneflow torch_res = output.pytorch # NOTE: torch's device has no device index bug oneflow has. # e.g. torch gets "cpu" but oneflow gets "cpu:0" test_case.assertTrue(str(torch_res.device) in str(of_res.device)) test_case.assertTrue( np.allclose(of_res.numpy(), torch_res.detach().cpu().numpy()) ) op_idx = random(low=0, high=len(binary_ops)).to(int).value() op = binary_ops[op_idx] device = random_device() x = torch.tensor(1.0) y = random_tensor(ndim=2, dim0=2, dim1=2).to(device) out = op(x, y) check_output(test_case, out) out = op(y, x) check_output(test_case, out) @autotest(n=30, auto_backward=False) def test_broadcast_scalar(test_case): op_idx = random(low=0, high=len(binary_ops)).to(int).value() op = binary_ops[op_idx] device = random_device() x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3).to(device) out = op(x, 1) return out @profile(torch.add) def profile_broadcast_matrix_row(test_case): input0 = torch.ones(256, 1024) input1 = torch.ones(1024) torch.add(input0, input1) @profile(torch.add) def profile_broadcast_matrix_col(test_case): input0 = torch.ones(1024, 256) input1 = torch.ones(1024, 1) torch.add(input0, input1) @profile(torch.add) def profile_broadcast_elementwise(test_case): input0 = torch.ones(256, 1024) input1 = torch.ones(256, 1024) torch.add(input0, input1) @profile(torch.add) def profile_broadcast_scalar(test_case): input0 = torch.ones(256, 1024) torch.add(input0, 1) @profile(torch.add) def profile_broadcast_general(test_case): input0 = torch.ones(2, 64, 8, 16, 16, 4) input1 = torch.ones(64, 8, 1, 16, 1) torch.add(input0, input1) @flow.unittest.skip_unless_1n1d() class TestBroadcastOpsOther(flow.unittest.TestCase): def test_broadcast_shapes(test_case): shapes = (2,), (3, 1), (1, 1, 1) test_case.assertTrue( flow.broadcast_shapes(*shapes), ori_torch.broadcast_shapes(*shapes), ) @autotest(n=3) def test_broadcast_tensors(test_case): device = 
random_device() x = random_tensor(ndim=2, dim0=1, dim1=4).to(device=device) y = random_tensor(ndim=2, dim0=3, dim1=1).to(device=device) return torch.broadcast_tensors(x, y) def test_broadcast_to(test_case): # see flow.expand, because broadcast_to is an alias of flow.expand pass if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_cast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from random import shuffle import unittest from collections import OrderedDict import numpy as np import torch as torch_original import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * def _test_cast_float2int(test_case, device, shape): np_arr = np.random.randn(*shape).astype(np.float32) input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) output = flow.cast(input, flow.int8) np_out = np_arr.astype(np.int8) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) def _test_cast_int2float(test_case, device, shape): np_arr = np.random.randn(*shape).astype(np.int8) input = flow.tensor(np_arr, dtype=flow.int8, device=flow.device(device)) output = flow.cast(input, flow.float32) np_out = np_arr.astype(np.float32) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) def _test_cast_bool2int16(test_case, device, shape): np_arr = np.random.randn(*shape).astype(np.float32) input = flow.tensor(np_arr, dtype=flow.bool, device=flow.device(device)) output = flow.cast(input, flow.int16) np_out = np_arr.astype(bool).astype(np.int16) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) def _test_cast_with_non_contiguous_input(test_case, device, shape): np_arr = np.random.randn(*shape).astype(np.int8) permute_dims = np.arange(len(shape)).tolist() shuffle(permute_dims) input = flow.tensor(np_arr, dtype=flow.int8, device=flow.device(device)).permute( permute_dims ) output = flow.cast(input, flow.float32) np_out = np_arr.astype(np.float32).transpose(permute_dims) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) test_case.assertTrue(input.stride() == output.stride()) def _test_cast_backward(test_case, device, shape): np_arr = np.random.randn(*shape).astype(np.float32) x = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) y = flow.cast(x, flow.float64) z = y.sum() z.backward() np_out = np_arr.astype(np.float64) test_case.assertTrue(np.array_equal(x.grad.numpy(), np.ones(shape=shape))) def random_expand(x, ndim, expand_size): dim_size = [1,] * ndim random_index = random(0, ndim).to(int).value() dim_size[random_index] = expand_size return x.expand(*dim_size) @flow.unittest.skip_unless_1n1d() class TestCast(flow.unittest.TestCase): def test_cast(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_cast_float2int, _test_cast_int2float, _test_cast_bool2int16, _test_cast_backward, # 
_test_cast_with_non_contiguous_input, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_cast_with_0_size_data(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_cast_float2int, _test_cast_int2float, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3, 0, 5)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_cast_with_strided_input(test_case): device = random_device() x = random_tensor() x = x.to(dtype=torch.float32, device=device) perm_list = [0, 1, 2, 3] shuffle(perm_list) x = x.permute(perm_list) y = x.to(dtype=torch.float64, device=device) return y @autotest(n=5) def test_cast_with_expanded_input(test_case): device = random_device() random_expand_size = random(1, 6).to(int).value() x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1) x = x.to(dtype=torch.float32, device=device) perm_list = [0, 1, 2, 3, 4] shuffle(perm_list) x = x.permute(perm_list) y = random_expand(x, ndim=5, expand_size=random_expand_size) z = y.to(dtype=torch.float64, device=device) return z @autotest(n=5) def test_cast_with_expanded_input_2(test_case): device = random_device() x = random_tensor(ndim=1, dim0=5) a = x.to(dtype=torch.float32, device=device) b = a.expand((4, 5)) c = b.to(dtype=torch.double, device=device) return c @autotest(n=5) def test_cast_with_squeezed_input(test_case): device = random_device() x = random_tensor().to(device) y = torch.squeeze(x, random(1, 3).to(int)) z = y.to(dtype=torch.double, device=device) return z @autotest(n=5, auto_backward=False) def test_cast_with_sliced_input(test_case): device = random_device() x = random_tensor(ndim=1, dim0=20) y = random_tensor(ndim=1, dim0=7) x = x.to(dtype=torch.float32, device=device) y = y.to(device=device) rows = x * 10 cols = y a = rows.reshape(20, 1) + cols b = a[:, :1] c = b.to(torch.int) return c @autotest(n=5, auto_backward=False) # NOTE:if set auto_backward=True, both oneflow and pytorch will raise RuntimeError: # element 0 of tensors does not require grad and does not have a grad_fn def test_cast_with_scalar_input(test_case): device = random_device() x = torch.tensor(3.14, device=device) y = x.to(dtype=torch.float64, device=device) z = y.to(dtype=torch.int8, device=device) return z @autotest(n=5, auto_backward=True, include_complex=False, atol=1e-5, rtol=1e-5) def test_cast_with_complex_float2complex(test_case): device = random_device() x = random_tensor().to(dtype=torch.float32, device=device) y = x.to(torch.complex64) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_ceil.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestCeilModule(flow.unittest.TestCase): @autotest(n=5) def test_ceil_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.ceil(input) return y @autotest(n=5) def test_ceil_flow_with_random_0d_data(test_case): device = random_device() input = random_tensor(ndim=0).to(device) y = torch.ceil(input) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_ceil_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = torch.ceil(x) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_ceil_with_0shape_0d_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.ceil(x) return y @profile(torch.ceil) def profile_ceil(test_case): torch.ceil(torch.ones(4)) torch.ceil(torch.ones(100000)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_check_meta_consistency.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import os import oneflow.unittest from oneflow.test_utils.test_util import GenArgList @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalCastModule_1n2d(flow.unittest.TestCase): def test_check_meta_consistency(test_case): if os.getenv("RANK") == "0": x = flow.ones((16, 16), device=flow.device("cuda"), dtype=flow.int32) else: x = flow.zeros((1,), device=flow.device("cuda"), dtype=flow.float) placement = flow.placement("cuda", ranks=[0]) sbp = (flow.sbp.broadcast,) y = x.to_global(placement=placement, sbp=sbp) y.check_meta_consistency() y = y.to_global(sbp=flow.sbp.split(0)) y.check_meta_consistency() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_checkpointing.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import oneflow as flow import oneflow.unittest import numpy as np @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestCheckpointing(flow.unittest.TestCase): def test_checkpointing(test_case): relu_forward_num = 0 relu_backward_num = 0 class MyReLU(flow.autograd.Function): @staticmethod def forward(ctx, x): nonlocal relu_forward_num relu_forward_num += 1 y = flow.relu(x) ctx.save_for_backward(y) return y @staticmethod def backward(ctx, dy): nonlocal relu_backward_num relu_backward_num += 1 y = ctx.saved_tensors[0] return dy * (y > 0) class M(flow.nn.Module): def __init__(self): super().__init__() self.conv1 = flow.nn.Conv2d(3, 3, 3) self.conv2 = flow.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv1(x) if checkpointing: x = flow.utils.checkpoint.checkpoint(MyReLU.apply, x) else: x = MyReLU.apply(x) x = self.conv2(x) return x x1 = flow.randn(1, 3, 8, 16).requires_grad_() x2 = x1.detach().clone().requires_grad_() m = M() checkpointing = True y1 = m(x1) y1.sum().backward() checkpointing = False y2 = m(x2) y2.sum().backward() test_case.assertTrue(np.array_equal(y1, y2)) test_case.assertTrue(np.array_equal(x1.grad, x2.grad)) test_case.assertEqual(relu_forward_num, 3) test_case.assertEqual(relu_backward_num, 2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_chunk.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict from random import shuffle import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestChunk(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_flow_chunk_list_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor( ndim=4, dim1=random(low=4, high=8).to(int), dim2=random(low=4, high=8).to(int), dim3=random(low=4, high=8).to(int), ).to(device) y = torch.chunk(x, chunks=random(low=1, high=5).to(int), dim=dim) z = torch.cat(y, dim=dim) return z @autotest(n=10) def test_flow_chunk_list_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor( ndim=4, dim1=random(low=4, high=8).to(int), dim2=random(low=4, high=8).to(int), dim3=random(low=4, high=8).to(int), ).to(device) permute_list = [0, 1, 2, 3] shuffle(permute_list) y = x.permute(permute_list) z = torch.chunk(y, chunks=random(low=1, high=5).to(int), dim=dim) return torch.cat(z, dim=dim) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_chunk_list_with_stride(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor( ndim=4, dim1=random(low=4, high=8).to(int), dim2=random(low=4, high=8).to(int), dim3=random(low=4, high=8).to(int), ).to(device) perm = [0, 1, 2, 3] shuffle(perm) y = x.permute(perm) z = torch.chunk(y, chunks=random(low=1, high=5).to(int), dim=dim) return torch.cat(z, dim=dim) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_chunk_list_bool_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor( ndim=4, dim1=random(low=4, high=8).to(int), dim2=random(low=4, high=8).to(int), dim3=random(low=4, high=8).to(int), ).to(device, torch.bool) y = torch.chunk(x, chunks=random(low=1, high=5).to(int), dim=dim) z = torch.cat(y, dim=dim) return z @autotest(n=5, check_graph=True) def test_flow_chunk_list_with_random_data_negative_dim(test_case): device = random_device() dim = random(1, 3).to(int) x = random_tensor( ndim=4, dim0=random(low=4, high=8).to(int), dim1=random(low=4, high=8).to(int), dim2=random(low=4, high=8).to(int), dim3=random(low=4, high=8).to(int), ).to(device) y = torch.chunk(x, chunks=4, dim=-1) z = torch.cat(y, dim=-1) return z @profile(torch.chunk) def profile_chunk(test_case): torch.chunk(torch.ones(16), 4) torch.chunk(torch.ones(100000), 5) torch.chunk(torch.ones(100, 100), 5, dim=1) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_clamp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_clamp(test_case, shape, device, dtype): input = flow.tensor( np.random.randn(*shape), dtype=dtype, device=flow.device(device) ) of_out = flow.clamp(input, 0.1, 0.5) np_out = np.clip(input.numpy(), 0.1, 0.5) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_tensor_clamp(test_case, shape, device, dtype): input = flow.tensor( np.random.randn(*shape), dtype=dtype, device=flow.device(device) ) of_out = input.clamp(0.1, 0.5) np_out = np.clip(input.numpy(), 0.1, 0.5) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_clamp_scalar_min(test_case, shape, device, dtype): input = flow.tensor( np.random.randn(*shape), dtype=dtype, device=flow.device(device) ) of_out = flow.clamp(input, 0.1, None) np_out = np.clip(input.numpy(), 0.1, None) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_clamp_scalar_max(test_case, shape, device, dtype): input = flow.tensor( np.random.randn(*shape), dtype=dtype, device=flow.device(device) ) of_out = flow.clamp(input, None, 0.5) np_out = np.clip(input.numpy(), None, 0.5) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_clamp_integral(test_case, shape, device, dtype): input = flow.tensor(np.random.randint(3, 10, shape), device=flow.device(device)).to( dtype ) of_out = flow.clamp(input, 1, 5) np_out = np.clip(input.numpy(), 1, 5) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _numpy_clamp_grad(arr, min, max): grad = np.zeros_like(arr) grad[arr.clip(min, max) == arr] += 1 return grad def _test_clamp_backward(test_case, shape, device, dtype): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.clamp(x, 0.1, 0.5).sum() y.backward() test_case.assertTrue( np.allclose( x.grad.numpy(), _numpy_clamp_grad(x.numpy(), 0.1, 0.5), 1e-05, 1e-05 ) ) @flow.unittest.skip_unless_1n1d() class TestClampModule(flow.unittest.TestCase): def test_clamp(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_clamp, _test_tensor_clamp, _test_clamp_scalar_min, _test_clamp_scalar_max, _test_clamp_backward, ] arg_dict["shape"] = [(2,), (2, 3), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = [flow.float16, flow.float, flow.double] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) arg_dict["fun"] = [ _test_clamp_integral, ] arg_dict["dtype"] = [flow.int8, flow.int, flow.long] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_clamp_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.clamp(input, min=random().to(float), max=random().to(float)) return y @autotest(n=5) def test_clamp_min_none_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.clamp(input, min=random().to(float), max=random().to(float)) return y @autotest(n=5) def test_clamp_max_none_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.clamp( input, min=random().to(float), max=random().to(float) | nothing() ) return y @profile(torch.clamp) def profile_clamp(test_case): torch.clamp(torch.ones(4), -1, 2) torch.clamp(torch.ones(100000), -1, 2) @autotest(n=5) def 
test_clip_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.clip(input, min=random().to(float), max=random().to(float)) return y @autotest(n=5) def test_clip_min_none_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.clip(input, min=random().to(float), max=random().to(float)) return y @autotest(n=5) def test_clip_max_none_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.clip( input, min=random().to(float), max=random().to(float) | nothing() ) return y @profile(torch.clip) def profile_clip(test_case): torch.clip(torch.ones(4), -1, 2) torch.clip(torch.ones(100000), -1, 2) @autotest(n=5, auto_backward=False, check_graph=True) def test_clamp_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = torch.clamp(x, min=random().to(float), max=random().to(float)) return y def _test_clamp_min(test_case, shape, device): input = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.clamp_min(input, 0.1) np_out = np.clip(input.numpy(), 0.1, None) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_clamp_min_integral(test_case, shape, device): input = flow.tensor(np.random.randint(3, 10, shape), device=flow.device(device)) of_out = flow.clamp_min(input, 1) np_out = np.clip(input.numpy(), 1, None) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_clamp_min_backward(test_case, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.clamp_min(x, 0.1).sum() y.backward() test_case.assertTrue( np.allclose( x.grad.numpy(), _numpy_clamp_grad(x.numpy(), 0.1, None), 1e-05, 1e-05 ) ) @flow.unittest.skip_unless_1n1d() class TestClampMinModule(flow.unittest.TestCase): def test_clamp_min(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_clamp_min, _test_clamp_min_integral, _test_clamp_min_backward, ] arg_dict["shape"] = [(2,), (2, 3), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_clamp_min_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.clamp_min(input, min=random().to(float)) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_clamp_min_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = torch.clamp_min(x, min=random().to(float)) return y def _test_clamp_max(test_case, shape, device): input = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.clamp_max(input, 0.5) np_out = np.clip(input.numpy(), None, 0.5) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_clamp_max_integral(test_case, shape, device): input = flow.tensor(np.random.randint(3, 10, shape), device=flow.device(device)) of_out = flow.clamp_max(input, 1) np_out = np.clip(input.numpy(), None, 1) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_clamp_max_backward(test_case, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.clamp_max(x, 0.5).sum() y.backward() test_case.assertTrue( np.allclose( x.grad.numpy(), _numpy_clamp_grad(x.numpy(), None, 
0.5), 1e-05, 1e-05 ) ) @flow.unittest.skip_unless_1n1d() class TestClampMaxModule(flow.unittest.TestCase): def test_clamp_min(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_clamp_max, _test_clamp_max_integral, _test_clamp_max_backward, ] arg_dict["shape"] = [(2,), (2, 3), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_clamp_max_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.clamp_max(input, max=random().to(float)) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_clamp_max_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = torch.clamp_max(x, max=random().to(float)) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_clip_grad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgList def _clip_grad_norm_np(input, max_norm, norm_type): np_out = np.maximum(0, input) np_grad = np.array(np_out > 0, dtype=np.float32) max_norm = float(max_norm) norm_type = float(norm_type) input = [input] if len(input) == 0: return 0, 0 if norm_type == float("inf"): total_norm = np.max(np.abs(np_grad)) if norm_type == float("-inf"): total_norm = np.min(np.abs(np_grad)) elif norm_type == 0: total_norm = np.sum(np.stack([np.sum(np_grad != 0)]) != 0) else: total_norm = np_grad for i in range(np_grad.ndim, 0, -1): total_norm = np.linalg.norm(total_norm, norm_type, axis=i - 1) clip_coef = max_norm / (total_norm + 1e-6) if clip_coef < 1: np_grad = np.dot(np_grad, clip_coef) return total_norm, np_grad def _test_clip_grad_norm_impl(test_case, shape, device, max_norm, norm_type, fused): np_input = np.random.rand(*shape) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m = flow.nn.ReLU() of_out = m(of_input) of_out = of_out.sum() of_out.backward() of_total_norm = flow.nn.utils.clip_grad_norm_(of_input, max_norm, norm_type, fused) np_total_norm, np_grad = _clip_grad_norm_np(np_input, max_norm, norm_type) test_case.assertTrue( np.allclose(of_total_norm.numpy(), np_total_norm, 1e-4, 1e-4, equal_nan=True) ) test_case.assertTrue( np.allclose(of_input.grad.numpy(), np_grad, 1e-4, 1e-4, equal_nan=True) ) def _clip_grad_value_np(input, clip_value): np_out = np.maximum(0, input) np_grad = np.array(np_out > 0, dtype=np.float32) clip_value = float(clip_value) if len(input) == 0: return 0, 0 np_grad = np.clip(np_grad, -clip_value, clip_value) return np_grad def _test_clip_grad_value_impl(test_case, shape, device, clip_value): np_input = np.random.rand(*shape) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), 
requires_grad=True ) m = flow.nn.ReLU() of_out = m(of_input) of_out = of_out.sum() of_out.backward() flow.nn.utils.clip_grad_value_(of_input, clip_value) of_grad = of_input.grad.numpy() np_grad = _clip_grad_value_np(np_input, clip_value) test_case.assertTrue(np.allclose(of_grad, np_grad, 1e-4, 1e-4, equal_nan=True)) class ReluGraph(flow.nn.Graph): def __init__(self, clip_value) -> None: super().__init__() self.clip_value = clip_value def build(self, x): flow.nn.utils.clip_grad_value_(x, self.clip_value) return x def _test_graph_clip_grad_value_impl(test_case, shape, device, clip_value): np_input = np.random.rand(*shape) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_eager_out = of_input flow.nn.utils.clip_grad_value_(of_eager_out, clip_value) relu_graph = ReluGraph(clip_value) of_graph_out = relu_graph(of_input) test_case.assertTrue( np.allclose( of_eager_out.numpy(), of_graph_out.numpy(), 1e-4, 1e-4, equal_nan=True ) ) def _test_clip_grad_norm_global_impl( test_case, shape, sbp, placement, max_norm, norm_type ): of_input = flow.rand( *shape, dtype=flow.float32, sbp=sbp, placement=placement, requires_grad=True ) np_input = of_input.to_global(sbp=flow.sbp.broadcast).to_local().numpy() m = flow.nn.ReLU() of_out = m(of_input) of_out = of_out.sum() of_out.backward() of_total_norm = flow.nn.utils.clip_grad_norm_( of_input, max_norm, norm_type ).to_local() np_total_norm, np_grad = _clip_grad_norm_np(np_input, max_norm, norm_type) test_case.assertTrue( np.allclose(of_total_norm.numpy(), np_total_norm, 1e-4, 1e-4, equal_nan=True) ) test_case.assertTrue( np.allclose( of_input.grad.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), np_grad, 1e-4, 1e-4, equal_nan=True, ) ) @flow.unittest.skip_unless_1n1d() class TestClipGrad(flow.unittest.TestCase): def test_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["max_norm"] = [0, 0.5, 1.0] arg_dict["norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5] arg_dict["fused"] = [False, True] for arg in GenArgList(arg_dict): _test_clip_grad_norm_impl(test_case, *arg) def test_clip_value(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["clip_value"] = [0, 0.5, 1.0] for arg in GenArgList(arg_dict): _test_clip_grad_value_impl(test_case, *arg) _test_graph_clip_grad_value_impl(test_case, *arg) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestClipGradGlobal(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_clip_grad_global(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 4), (2, 4, 3), (2, 4, 5, 6)] arg_dict["sbp"] = [flow.sbp.broadcast, flow.sbp.split(0), flow.sbp.split(1)] arg_dict["placement"] = [ flow.placement.all("cpu"), flow.placement.all("cuda"), ] arg_dict["max_norm"] = [0, 0.5, 1.0] arg_dict["norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5] for arg in GenArgList(arg_dict): _test_clip_grad_norm_global_impl(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_clone.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestClone(flow.unittest.TestCase): @autotest(n=3) def test_clone_with_random_data(test_case): x = random_tensor() y = torch.clone(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_coco_reader.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import numpy as np import oneflow as flow import oneflow.unittest class COCODataLoader(flow.nn.Module): def __init__( self, anno_file=flow.unittest.dataset_dir( "mscoco_2017/annotations/instances_val2017.json" ), image_dir=flow.unittest.dataset_dir("mscoco_2017/val2017"), batch_size=2, device=None, placement=None, sbp=None, ): super().__init__() self.coco_reader = flow.nn.COCOReader( annotation_file=anno_file, image_dir=image_dir, batch_size=batch_size, shuffle=True, random_seed=12345, stride_partition=True, device=device, placement=placement, sbp=sbp, ) self.image_decoder = flow.nn.image.decode(dtype=flow.float32) self.resize = flow.nn.image.Resize(target_size=[224, 224], dtype=flow.float32) def forward(self): outputs = self.coco_reader() # decode images image = self.image_decoder(outputs[0]) fixed_image = self.resize(image)[0] image_id = outputs[1] image_size = outputs[2] return fixed_image, image_id, image_size class DataLoaderGraph(flow.nn.Graph): def __init__(self, loader): super().__init__() self.loader_ = loader def build(self): return self.loader_() @flow.unittest.skip_unless_1n2d() class COCODataLoaderDistributedTestCase(oneflow.unittest.TestCase): def test_case1(test_case): rank = flow.env.get_rank() # pid = os.getpid() # print(f"[{pid}][{rank}] COCODataLoaderDistributedTestCase.test_case1") eager_coco_loader = COCODataLoader( batch_size=2, device=flow.device("cpu", rank) ) global_coco_loader = COCODataLoader( batch_size=4, placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.split(0)], ) coco_loader_graph = DataLoaderGraph(global_coco_loader) # coco_loader_graph.debug() iteration = 1 for i in range(iteration): image, image_id, image_size = eager_coco_loader() # print(f"image: {image.numpy().mean()} ") # print(f"image_id: {image_id.numpy()}") # print(f"image_size: {image_size.numpy()}") g_image, g_image_id, g_image_size = coco_loader_graph() # print(f"{'-' * 20} rank {rank} iter {i} complete {'-' * 20}") test_case.assertTrue(np.allclose(image.numpy(), g_image.to_local().numpy())) test_case.assertTrue( 
np.allclose(image_id.numpy(), g_image_id.to_local().numpy()) ) test_case.assertTrue( np.allclose(image_size.numpy(), g_image_size.to_local().numpy()) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_coin_flip.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_coin_flip_impl(test_case, batch_size, random_seed, probability, device): m = flow.nn.CoinFlip(batch_size, random_seed, probability, device) x = m() test_case.assertEqual(x.shape[0], batch_size) device = flow.device(device) test_case.assertEqual(x.device, device) class TestCoinFlipModule(flow.unittest.TestCase): def test_coin_flip(test_case): arg_dict = OrderedDict() arg_dict["batch_size"] = [1, 2, 50] arg_dict["random_seed"] = [None, 1, -1] arg_dict["probability"] = [0.0, 0.5, 1.0] # TODO: CoinFlip support cuda kernel # arg_dict["device"] = ["cpu", "cuda"] arg_dict["device"] = ["cpu"] for arg in GenArgDict(arg_dict): _test_coin_flip_impl(test_case, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_comb2to2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow from oneflow import nn import os import numpy as np import oneflow.unittest flow.boxing.nccl.enable_use_compute_stream(False) class _TestModuleDiffHierarchy(nn.Module): def forward(self, x): sbp_1ds = [ flow.sbp.broadcast, flow.sbp.partial_sum, flow.sbp.split(0), flow.sbp.split(1), ] for sbp1 in sbp_1ds: for sbp2 in sbp_1ds: for sbp3 in sbp_1ds: for sbp4 in sbp_1ds: # (3, 2) -> (2, 3) x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(6)).reshape(2, 3) ), sbp=[sbp1, sbp2], ) # (2, 3) -> (3, 2) x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(6)).reshape(3, 2) ), sbp=[sbp3, sbp4], ) return x class _TestModuleDiffPlacement(nn.Module): def forward(self, x): sbp_1ds = [ flow.sbp.broadcast, flow.sbp.partial_sum, flow.sbp.split(0), flow.sbp.split(1), ] for sbp1 in sbp_1ds: for sbp2 in sbp_1ds: for sbp3 in sbp_1ds: for sbp4 in sbp_1ds: # (3, 2) -> (2, 2) x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) ), sbp=[sbp1, sbp2], ) # (2, 2) -> (3, 2) x = x.to_global( placement=flow.placement( type="cuda", ranks=np.array(range(6)).reshape(3, 2) ), sbp=[sbp3, sbp4], ) return x class _TestGraph(nn.Graph): def __init__(self, model): super().__init__() self.model = model def build(self, x): x = self.model(x) return x @flow.unittest.skip_unless_2n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestLazyAllSbpCombinationTesting(flow.unittest.TestCase): def test_lazy_boxing_2d_all_combination(test_case): os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0" x = flow.ones( 12, 12, sbp=[flow.sbp.broadcast, flow.sbp.broadcast], placement=flow.placement( type="cuda", ranks=np.array(range(6)).reshape(3, 2) ), ) model_diff_hierarchy = _TestModuleDiffHierarchy() graph_diff_hierarchy = _TestGraph(model_diff_hierarchy) y = graph_diff_hierarchy(x) model_diff_placement = _TestModuleDiffPlacement() graph_diff_placement = _TestGraph(model_diff_placement) z = graph_diff_placement(x) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_combined_margin_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow def _scatter_add_numpy(src, dim, index, outshape): output = np.zeros(outshape) for srcidx in range(0, src.size): outcoord = np.unravel_index(srcidx, src.shape) outcoord = [*outcoord] outcoord[dim] = index[np.unravel_index(srcidx, index.shape)] output_offset = np.ravel_multi_index(outcoord, outshape) output[np.unravel_index(output_offset, outshape)] += src[ np.unravel_index(srcidx, src.shape) ] return output def _np_one_hot(indices, depth): return np.eye(depth)[indices.reshape(-1)] def _np_gather_with_batch_dims(params, indices, axis): batch_dims = 1 result = [] for p, i in zip(params, indices): r = np.take_along_axis(p, i, axis - batch_dims) result.append(r) return np.stack(result) def _np_gather_with_batch_dims_grad(params, indices, axis, output): batch_dims = 1 result = [] for p, i, o in zip(params, indices, output): r = _scatter_add_numpy(np.ones_like(o), axis - batch_dims, i, p.shape) result.append(r) return np.stack(result) def _np_combined_margin_loss(np_input, np_label, m1, m2, m3): class_num = np_input.shape[1] if m1 != 1.0 or m2 != 0.0 or m3 != 0.0: if m1 == 1.0 and m2 == 0.0: gt_one_hot = _np_one_hot(np_label, class_num) * m3 np_input = np_input - gt_one_hot else: np_label_expand = np.reshape(np_label, (np_label.shape[0], 1)) zy = _np_gather_with_batch_dims(np_input, np_label_expand, 0) cos_t = zy * 1 t = np.arccos(cos_t) if m1 != 1.0: t = t * m1 if m2 > 0.0: t = t + m2 body = np.cos(t) if m3 > 0.0: body = body - m3 new_zy = body diff = new_zy - zy gt_one_hot = _np_one_hot(np_label, class_num) body = gt_one_hot * diff np_input = np_input + body return np_input def _np_combined_margin_loss_grad(np_input, np_label, m1, m2, m3): class_num = np_input.shape[1] if m1 != 1.0 or m2 != 0.0 or m3 != 0.0: if m1 == 1.0 and m2 == 0.0: result = np.ones(np_input.shape) else: np_label_expand = np.reshape(np_label, (np_label.shape[0], 1)) zy = _np_gather_with_batch_dims(np_input, np_label_expand, 0) dzy = _np_gather_with_batch_dims_grad(np_input, np_label_expand, 0, zy) cos_t = zy * 1 t = np.arccos(cos_t) dt = -1 / np.sqrt((1 - cos_t * cos_t)) * dzy if m1 != 1.0: t = t * m1 dt = dt * m1 if m2 > 0.0: t = t + m2 body = np.cos(t) dbody = -np.sin(t) * dt if m3 > 0.0: body = body - m3 new_zy = body diff = new_zy - zy ddiff = dbody - dzy gt_one_hot = _np_one_hot(np_label, class_num) body = gt_one_hot * diff dbody = gt_one_hot * ddiff np_input = np_input + body result = np.ones(np_input.shape) + dbody else: result = np.ones(np_input.shape) return result def _test_combined_margin_loss( test_case, device_type, input_shape, label_shape, data_type, m1, m2, m3 ): assert device_type in ["cpu", "cuda"] np_x = np.random.uniform(low=-1, high=1, size=input_shape).astype(np.float32) np_labels = np.random.randint(0, input_shape[1], size=(*label_shape,)).astype( np.int32 ) x = flow.tensor(np_x, device=device_type, dtype=data_type, requires_grad=True) labels = flow.tensor(np_labels, device=device_type, dtype=flow.int32) loss_func = flow.nn.CombinedMarginLoss(m1, m2, m3).to(flow.device(device_type)) output = loss_func(x, labels) output.sum().backward() output_ref = _np_combined_margin_loss(np_x, np_labels, m1, m2, m3) test_case.assertTrue(np.allclose(output.numpy(), output_ref, rtol=1e-5, atol=1e-5)) input_grad_ref = _np_combined_margin_loss_grad(np_x, np_labels, m1, m2, m3) test_case.assertTrue( np.allclose(x.grad.numpy(), input_grad_ref, rtol=1e-4, atol=1e-4) ) 
@flow.unittest.skip_unless_1n1d() class TestCombinedMarginLoss(flow.unittest.TestCase): def test_combined_margin_loss(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_combined_margin_loss] arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["input_shape"] = [(64, 1000)] arg_dict["label_shape"] = [(64,)] arg_dict["data_type"] = [flow.float32] arg_dict["m1"] = [0.3, 1.0] arg_dict["m2"] = [0.5, 0.0] arg_dict["m3"] = [0.4, 0.0] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_comm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from threading import Thread import numpy as np import os import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestComm(flow.unittest.TestCase): def _test_send_recv(test_case, x0, src, dst): rank = flow.env.get_rank() if rank == src: x1 = x0 flow.comm.send(x1, dst) x2 = x0 flow.comm.send(x2, dst) elif rank == dst: x1 = flow.comm.recv(src) test_case.assertTrue(np.array_equal(x1.numpy(), x0.numpy())) test_case.assertEqual(x1.device, x0.device) x2 = flow.zeros_like(x0) flow.comm.recv(src, out=x2) test_case.assertTrue(np.array_equal(x2.numpy(), x0.numpy())) test_case.assertEqual(x2.device, x0.device) else: # do nothing pass @flow.unittest.skip_unless_1n2d() def test_send_recv_2_devices(test_case): x0 = flow.tensor([[1, 2]]) test_case._test_send_recv(x0, 0, 1) x0 = x0.to("cuda") test_case._test_send_recv(x0, 1, 0) @flow.unittest.skip_unless_1n4d() def test_send_recv_4_devices(test_case): x0 = flow.tensor([[1, 2]]) test_case._test_send_recv(x0, 3, 1) x0 = x0.to("cuda") test_case._test_send_recv(x0, 0, 3) def _test_send_recv_without_sending_meta(test_case, x0, src, dst): rank = flow.env.get_rank() if rank == src: x1 = x0 flow.comm.send(x1, dst, send_meta=False) x2 = x0 flow.comm.send(x2, dst, send_meta=False) elif rank == dst: x1 = flow.comm.recv(src, shape=x0.shape, dtype=x0.dtype, device=x0.device) test_case.assertTrue(np.array_equal(x1.numpy(), x0.numpy())) x2 = flow.zeros_like(x0) flow.comm.recv( src, shape=x0.shape, dtype=x0.dtype, device=x0.device, out=x2 ) test_case.assertTrue(np.array_equal(x2.numpy(), x0.numpy())) else: # do nothing pass @flow.unittest.skip_unless_1n2d() def test_send_recv_without_sending_meta_2_devices(test_case): x0 = flow.tensor([[1, 2]]) test_case._test_send_recv_without_sending_meta(x0, 1, 0) x0 = x0.to("cuda") test_case._test_send_recv_without_sending_meta(x0, 0, 1) @flow.unittest.skip_unless_1n4d() def test_send_recv_without_sending_meta_4_devices(test_case): x0 = flow.tensor([[1, 2]]) test_case._test_send_recv_without_sending_meta(x0, 2, 3) x0 = x0.to("cuda") test_case._test_send_recv_without_sending_meta(x0, 3, 1) @flow.unittest.skip_unless_1n2d() def test_comm_in_thread(test_case): def threaded_function(): rank = 
flow.env.get_rank() rev = flow.framework.check_point_v2._broadcast_py_object(rank, 0) test_case.assertEqual(rev, 0) x = flow.tensor([rank, rank + 1]).to_global( placement=flow.placement.all("cpu"), sbp=flow.sbp.split(0) ) test_case.assertTrue(np.array_equal(x.numpy(), np.array([0, 1, 1, 2]))) x = flow.tensor([rank, rank + 1]) flow.comm.all_reduce(x) test_case.assertTrue(np.array_equal(x.numpy(), np.array([1, 3]))) thread = Thread(target=threaded_function) thread.start() thread.join() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_comm_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import numpy as np import unittest import os import oneflow as flow import oneflow.unittest import torch import torch.distributed as dist @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestAllReduce(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_all_reduce_1n2d(test_case): np_arr = np.array([[1, 2], [3, 4]]) tensor = flow.tensor(np_arr, device="cuda") flow.comm.all_reduce(tensor) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr * 2)) @flow.unittest.skip_unless_2n2d() def test_all_reduce_2n2d(test_case): np_arr = np.array([[1, 2], [3, 4]]) tensor = flow.tensor(np_arr, device="cuda") flow.comm.all_reduce(tensor) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr * 4)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestAllGather(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_all_gather_into_tensor_1n2d(test_case): device = "cuda" tensor_in = ( flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.int64, device=device) + flow.env.get_rank() * 6 ) tensor_out = flow.zeros(4, 3, dtype=flow.int64, device=device) flow.comm.all_gather_into_tensor(tensor_out, tensor_in) test_case.assertTrue( np.allclose( tensor_out.numpy(), np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), ) ) tensor_out2 = flow.zeros(2, 3, 2, dtype=flow.int64, device=device) flow.comm.all_gather_into_tensor(tensor_out2, tensor_in) test_case.assertTrue( np.allclose( tensor_out2.numpy(), np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]), ) ) @flow.unittest.skip_unless_1n2d() def test_all_gather_1n2d(test_case): if flow.env.get_rank() == 0: np_arr = np.array([[2, 3], [4, 5]]) elif flow.env.get_rank() == 1: np_arr = np.array([[1, 2], [3, 4]]) input = flow.tensor(np_arr, device="cuda", dtype=flow.int32) tensor_list = [flow.zeros(np_arr.shape, dtype=flow.int32) for _ in range(2)] flow.comm.all_gather(tensor_list, input) test_case.assertTrue( np.allclose(tensor_list[0].numpy(), np.array([[2, 3], [4, 5]])) ) test_case.assertTrue( np.allclose(tensor_list[1].numpy(), np.array([[1, 2], [3, 4]])) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestBroadCast(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def 
test_broadcast_1n2d(test_case): if flow.env.get_rank() == 0: np_arr = np.array([[1, 2], [3, 4]]) elif flow.env.get_rank() == 1: np_arr = np.array([[4, 5], [6, 7]]) tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32) flow.comm.broadcast(tensor, 1) test_case.assertTrue(np.allclose(tensor.numpy(), np.array([[4, 5], [6, 7]]))) tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32) flow.comm.broadcast(tensor, 0) test_case.assertTrue(np.allclose(tensor.numpy(), np.array([[1, 2], [3, 4]]))) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestScatter(flow.unittest.TestCase): @flow.unittest.skip_unless_1n4d() def test_scatter_1n4d(test_case): output = flow.tensor([[1, 2], [3, 4]], device="cuda") if flow.env.get_rank() == 1: tensor_list = [ flow.tensor([[5, 6], [7, 8]], device="cuda") + i for i in range(4) ] flow.comm.scatter(output, tensor_list, src=1) test_case.assertTrue( np.allclose(output.numpy(), np.array([[6, 7], [8, 9]])) ) else: flow.comm.scatter(output, src=1) test_case.assertTrue( np.allclose( output.numpy(), np.array([[5, 6], [7, 8]]) + flow.env.get_rank() ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGather(flow.unittest.TestCase): @flow.unittest.skip_unless_1n4d() def test_gather_1n4d(test_case): np_arr = np.array([[1, 2], [3, 4]]) if flow.env.get_rank() == 1: input = flow.tensor( np_arr + flow.env.get_rank(), device="cuda", dtype=flow.int32 ) tensor_list = [flow.zeros(np_arr.shape, dtype=flow.int32) for _ in range(4)] flow.comm.gather(input, gather_list=tensor_list, dst=1) for i in range(4): test_case.assertTrue( np.allclose(tensor_list[i].numpy(), np.array([[1, 2], [3, 4]]) + i) ) else: input = flow.tensor( np_arr + flow.env.get_rank(), device="cuda", dtype=flow.int32 ) flow.comm.gather(input, dst=1) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestReduce(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_reduce_1n2d(test_case): if flow.env.get_rank() == 0: np_arr = np.array([[1, 2], [3, 4]]) elif flow.env.get_rank() == 1: np_arr = np.array([[4, 5], [6, 7]]) tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32) flow.comm.reduce(tensor, 0) if flow.env.get_rank() == 0: test_case.assertTrue( np.allclose(tensor.numpy(), np.array([[5, 7], [9, 11]])) ) else: test_case.assertTrue( np.allclose(tensor.numpy(), np.array([[4, 5], [6, 7]])) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestAllToAll(flow.unittest.TestCase): @flow.unittest.skip_unless_1n4d() def test_all_to_all_1n4d(test_case): input_list = [ flow.tensor([0, 1], device="cuda") + i * 2 + flow.env.get_rank() * 8 for i in range(4) ] output_list = [flow.tensor([0, 1], device="cuda") for _ in range(4)] flow.comm.all_to_all(output_list, input_list) for i in range(len(output_list)): test_case.assertTrue( np.allclose( output_list[i].numpy(), input_list[i].numpy() + (i - flow.env.get_rank()) * 6, ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestReduceScatter(flow.unittest.TestCase): @flow.unittest.skip_unless_1n4d() def test_reduce_scatter_1n4d(test_case): output = flow.tensor([[0, 0], [0, 0]], device="cuda") tensor_list = [ flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_rank() + i for i in range(4) ] flow.comm.reduce_scatter(output, tensor_list) test_case.assertTrue( np.allclose(output.numpy(), tensor_list[0].numpy() * 4 + 6) ) @flow.unittest.skip_unless_1n2d() def 
test_reduce_scatter_tensor_1n2d(test_case): tensor_in = flow.tensor( [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=flow.int64, device="cuda", ) tensor_out = flow.zeros(2, 3, dtype=flow.int64, device="cuda") flow.comm.reduce_scatter_tensor(tensor_out, tensor_in) if flow.env.get_rank() == 0: test_case.assertTrue( np.allclose(tensor_out.numpy(), np.array([[2, 4, 6], [8, 10, 12]]),) ) else: test_case.assertTrue( np.allclose(tensor_out.numpy(), np.array([[14, 16, 18], [20, 22, 24]]),) ) tensor_in2 = tensor_in.reshape(2, 3, 2) tensor_out2 = flow.zeros(2, 3, dtype=flow.int64, device="cuda") flow.comm.reduce_scatter_tensor(tensor_out2, tensor_in2) if flow.env.get_rank() == 0: test_case.assertTrue( np.allclose(tensor_out2.numpy(), np.array([[2, 4, 6], [8, 10, 12]]),) ) else: test_case.assertTrue( np.allclose( tensor_out2.numpy(), np.array([[14, 16, 18], [20, 22, 24]]), ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestDocs(flow.unittest.TestCase): def test_docs(test_case): oneflow.framework.unittest.check_multi_rank_docstr(oneflow.comm.comm_ops) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_concat.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_concat_origin(test_case, device): input1 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = flow.cat([input1, input2], dim=0) np_out = np.concatenate((input1.numpy(), input2.numpy()), axis=0) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_concat_with_empty_input(test_case, device): input1 = flow.Tensor().to(flow.device(device)) input2 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out1 = flow.cat([input1, input2], dim=0) of_out2 = flow.cat([input2, input1], dim=0) of_out3 = flow.cat([input1, input2, input1, input1], dim=0) torch_input1 = torch.Tensor().to(torch.device(device)) torch_input2 = torch.tensor( np.random.randn(2, 6, 5, 3), dtype=torch.float32, device=torch.device(device) ) torch_out1 = torch.cat((torch_input1, torch_input2), 0) torch_out2 = torch.cat((torch_input2, torch_input1), 0) torch_out3 = torch.cat((torch_input1, torch_input2, torch_input1, torch_input1), 0) test_case.assertTrue( np.array_equal(of_out1.numpy(), torch_out1.detach().cpu().numpy()) ) test_case.assertTrue( np.array_equal(of_out2.numpy(), torch_out2.detach().cpu().numpy()) ) test_case.assertTrue( np.array_equal(of_out3.numpy(), torch_out3.detach().cpu().numpy()) ) test_case.assertTrue( np.array_equal(of_out1.numpy(), torch_out2.detach().cpu().numpy()) ) test_case.assertTrue( np.array_equal(of_out1.numpy(), torch_out3.detach().cpu().numpy()) ) def _test_concat_with_axis_one(test_case, device): input1 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = flow.cat([input1, input2], dim=1) np_out = np.concatenate((input1.numpy(), input2.numpy()), axis=1) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_concat_with_three_tensor(test_case, device): input1 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) input3 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = flow.cat([input1, input2, input3], dim=1) np_out = np.concatenate((input1.numpy(), input2.numpy(), input3.numpy()), axis=1) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_concat_with_three_tensor_backward(test_case, device): input1 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) input2 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) input3 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.cat([input1, input2, input3], dim=1) of_out = of_out.sum() of_out.backward() test_case.assertTrue( np.allclose(input1.grad.numpy(), np.ones((2, 6, 5, 3)), 0.0001, 0.0001) ) test_case.assertTrue( np.allclose(input2.grad.numpy(), np.ones((2, 6, 5, 3)), 0.0001, 0.0001) ) test_case.assertTrue( np.allclose(input3.grad.numpy(), np.ones((2, 6, 
5, 3)), 0.0001, 0.0001) ) def _test_concat_grad_and_no_grad(test_case, device): input1 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) input2 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device), requires_grad=False, ) of_out = flow.cat([input1, input2], dim=1) of_out = of_out.sum() of_out.backward() test_case.assertTrue( np.allclose(input1.grad.numpy(), np.ones((2, 6, 5, 3)), 0.0001, 0.0001) ) def _test_concat_single_input_type(test_case, device): torch_list = [torch.Tensor([1, 1, 9, 1])] torch_list = [t.to(dtype=torch.int64, device=device) for t in torch_list] flow_list = [flow.Tensor([1, 1, 9, 1])] flow_list = [t.to(dtype=flow.int64, device=device) for t in flow_list] flow_cat_list = flow.cat(flow_list) test_case.assertTrue(flow_cat_list.dtype is oneflow.int64) def _test_concat_grad_fn_name(test_case, device): x1 = flow.randn(2, 3, requires_grad=True, device=device) x2 = flow.randn(2, 3, requires_grad=True, device=device) cat = flow.cat([x1, x2], dim=1) grad_fn_name = cat.grad_fn.name() test_case.assertEqual(grad_fn_name, "catBackward") test_case.assertEqual(cat.grad_fn.next_functions[0][0].name(), "accumulategrad") next_fn = cat.grad_fn.next_functions[0] test_case.assertTrue( np.allclose(next_fn[0].variable.numpy(), x1.numpy(), 0.0001, 0.0001) ) next_fn = cat.grad_fn.next_functions[1] test_case.assertTrue( np.allclose(next_fn[0].variable.numpy(), x2.numpy(), 0.0001, 0.0001) ) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_concat(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_concat_origin, _test_concat_with_empty_input, _test_concat_with_axis_one, _test_concat_with_three_tensor, _test_concat_with_three_tensor_backward, _test_concat_grad_and_no_grad, _test_concat_single_input_type, _test_concat_grad_fn_name, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_cat_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device) return torch.cat((x, x, x), random(0, 2).to(int)) @autotest(n=5, check_graph=True, check_dtype=True) def test_cat_with_diff_dtypes(test_case): device = random_device() x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device).float() y = x.int() z = x.double() return torch.cat((x, y, z), random(0, 2).to(int)) @autotest(n=1, check_graph=True, check_dtype=True) def test_cat_with_diff_dtype_corner_case(test_case): device = random_device() input_list = list() x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device) y = x.int() for i in range(128): input_list.append(x) for j in range(128, 257): input_list.append(y) return torch.cat(tuple(input_list), random(0, 2).to(int)) @autotest(n=5, auto_backward=False, check_graph=True) def test_concat_with_input_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 3, 2, 4).to(device) y = random_tensor(4, 2, 3, random(0, 3), 4).to(device) z = torch.cat((x, y), dim=2) return z @autotest(n=5, auto_backward=False, check_graph=True) def test_concat_with_output_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 0, 2, 4).to(device) y = random_tensor(4, 2, 0, 2, 4).to(device) dim = random(0, 4).to(int).value() z = torch.cat((x, y), dim=dim) return z @autotest(n=5, auto_backward=False, check_graph=True) def test_cat_bool_with_random_data(test_case): device = random_device() x = 
random_tensor(ndim=2, dim0=random(), dim1=random()).to(device, torch.bool) return torch.cat((x, x, x), random(0, 2).to(int)) @autotest(n=5, check_graph=True) def test_cat_only_one_tensor(test_case): device = random_device() x = random_tensor(4, 2, 3, random(0, 3)).to(device) return torch.cat((x,), 0) @profile(torch.cat) def profile_cat(test_case): input = torch.ones(100, 100) torch.cat((input, input), dim=0) torch.cat((input, input), dim=1) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_constant.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * def _test_different_dtype(test_case, device, shape): y1 = flow.ones(shape, dtype=flow.int32, device=flow.device(device)) test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.int32), y1.numpy())) y2 = flow.ones(shape, dtype=flow.uint8, device=flow.device(device)) test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.uint8), y2.numpy())) y3 = flow.ones(shape, dtype=flow.float64, device=flow.device(device)) test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.float64), y3.numpy())) y4 = flow.ones(shape, dtype=flow.short, device=flow.device(device)) test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.short), y4.numpy())) y5 = flow.ones(shape, dtype=flow.int16, device=flow.device(device)) test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.int16), y5.numpy())) y6 = flow.ones(shape, dtype=flow.char, device=flow.device(device)) test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.int8), y6.numpy())) y7 = flow.ones(shape, dtype=flow.int8, device=flow.device(device)) test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.int8), y7.numpy())) @flow.unittest.skip_unless_1n1d() class TestConstantModule(flow.unittest.TestCase): @autotest(n=10, auto_backward=False, check_graph=True) def test_flow_zeros_list_with_random_data(test_case): device = random_device() y1 = torch.zeros(random().to(int)).to(device) y2 = torch.zeros(random().to(int), random().to(int)).to(device) y3 = torch.zeros(random().to(int), random().to(int), random().to(int)).to( device ) y4 = torch.zeros( random().to(int), random().to(int), random().to(int), random().to(int) ).to(device) return y1, y2, y3, y4 @profile(torch.zeros) def profile_zeros(test_case): torch.zeros(2, 3) torch.zeros(32, 3, 128, 128) torch.zeros(1000, 1000) @autotest(n=10, auto_backward=False, check_graph=True) def test_flow_ones_list_with_random_data(test_case): device = random_device() y1 = torch.ones(random().to(int)).to(device) y2 = torch.ones(random().to(int), random().to(int)).to(device) y3 = torch.ones(random().to(int), random().to(int), random().to(int)).to(device) y4 = torch.ones( random().to(int), 
            random().to(int), random().to(int), random().to(int)
        ).to(device)
        return y1, y2, y3, y4

    @profile(torch.ones)
    def profile_ones(test_case):
        torch.ones(2, 3)
        torch.ones(32, 3, 128, 128)
        torch.ones(1000, 1000)

    @autotest(auto_backward=False, check_graph=True)
    def test_flow_zeros_like_list_with_random_data(test_case):
        device = random_device()
        x = random_tensor().to(device)
        y = torch.zeros_like(x)
        return y

    @profile(torch.zeros_like)
    def profile_zeros_like(test_case):
        input1 = torch.ones(32, 3, 128, 128)
        input2 = torch.ones(1000, 1000)
        input3 = torch.ones(2, 3)
        torch.zeros_like(input1)
        torch.zeros_like(input2)
        torch.zeros_like(input3)

    @autotest(auto_backward=True, check_graph=True)
    def test_flow_zeros_like_list_with_random_data_and_requires_grad(test_case):
        device = random_device()
        x = random_tensor().to(device)
        y = torch.zeros_like(x, requires_grad=True)
        return y

    @autotest(auto_backward=False, check_graph=True)
    def test_flow_zeros_like_list_with_0dim_data(test_case):
        device = random_device()
        x = random_tensor(ndim=0).to(device)
        y = torch.zeros_like(x)
        return y

    @autotest(auto_backward=False, check_graph=True)
    def test_flow_ones_like_list_with_random_data(test_case):
        device = random_device()
        x = random_tensor().to(device)
        y = torch.ones_like(x)
        return y

    @profile(torch.ones_like)
    def profile_ones_like(test_case):
        input1 = torch.ones(32, 3, 128, 128)
        input2 = torch.ones(1000, 1000)
        input3 = torch.ones(2, 3)
        torch.ones_like(input1)
        torch.ones_like(input2)
        torch.ones_like(input3)

    @autotest(auto_backward=True, check_graph=True)
    def test_flow_ones_like_list_with_random_data_and_requires_grad(test_case):
        device = random_device()
        x = random_tensor().to(device)
        y = torch.ones_like(x, requires_grad=True)
        return y

    @autotest(auto_backward=False, check_graph=True)
    def test_flow_ones_like_list_with_0dim_data(test_case):
        device = random_device()
        x = random_tensor(ndim=0).to(device)
        y = torch.ones_like(x)
        return y

    @autotest(auto_backward=True, check_graph=True)
    def test_flow_new_ones_list_with_random_data(test_case):
        device = random_device()
        x = random_tensor().to(device)
        y = x.new_ones(
            (random().to(int), random().to(int), random().to(int)),
            device=device.value(),
            requires_grad=constant(True),
        )
        return y

    @profile(torch.Tensor.new_ones)
    def profile_new_ones(test_case):
        x = torch.Tensor(np.ones((1, 2, 3)))
        x.new_ones((2, 3))
        x.new_ones((32, 3, 128, 128))
        x.new_ones((1000, 1000, 1000, 1000))

    @unittest.skip("skip for now, because it failed 10 times in the past week")
    @autotest(auto_backward=True, check_graph=True)
    def test_flow_new_ones_list_with_0dim_data(test_case):
        device = random_device()
        x = random_tensor(ndim=0).to(device)
        y = x.new_ones(
            (random().to(int), random().to(int), random().to(int)),
            device=device.value(),
            requires_grad=constant(True),
        )
        return y

    @autotest(n=5)
    def test_new_zeros(test_case):
        device = random_device()
        x = random_tensor().to(device)
        y = x.new_zeros(
            (random().to(int), random().to(int), random().to(int)),
            device=device.value(),
            requires_grad=constant(True),
        )
        return y

    @profile(torch.Tensor.new_zeros)
    def profile_new_zeros(test_case):
        x = torch.Tensor(np.ones((1, 2, 3)))
        x.new_zeros((2, 3))
        x.new_zeros((32, 3, 128, 128))
        x.new_zeros((1000, 1000, 1000, 1000))

    @autotest(n=5)
    def test_new_full(test_case):
        device = random_device()
        x = random_tensor().to(device)
        y = x.new_full(
            (random().to(int), random().to(int), random().to(int)),
            random().to(float).value(),
            device=device.value(),
            requires_grad=constant(True),
        )
        return y

    @autotest(n=5, auto_backward=False)
    def test_new_full_with_scalar(test_case):
device = random_device() x = random_tensor().to(device) y = x.new_full([], random().to(int)) return y @autotest(n=5, auto_backward=False) def test_full_with_scalar(test_case): device = random_device() y = torch.full([], random().to(int), device=device) return y @autotest(n=10, auto_backward=True) def test_full_with_random_data_int(test_case): device = random_device() shape = random_tensor(low=1, high=6, requires_grad=False).pytorch.shape y = torch.full(shape, 2.0, requires_grad=True) return y @autotest(n=5) def test_full_with_random_data_numpy_scalar(test_case): device = random_device() shape = random_tensor(low=1, high=6, requires_grad=False).pytorch.shape y = torch.full(shape, np.array([2.0])[0], device=device, requires_grad=True) return y @autotest(n=5) def test_full_with_scalar_tensor(test_case): device = random_device() shape = random_tensor(low=0, high=6, requires_grad=False).pytorch.shape y = torch.full( shape, torch.tensor(2.0, requires_grad=random().to(bool)), device=device, requires_grad=True, ) return y @profile(torch.full) def profile_full_with_scalar_tensor(test_case): torch.full((2, 3), torch.tensor(3.141592)) torch.full((64, 3, 128, 128), torch.tensor(3.141592)) torch.full((1000, 1000), torch.tensor(3.141592)) @profile(torch.full) def profile_full(test_case): torch.full((2, 3), 3.141592) torch.full((64, 3, 128, 128), 3.141592) torch.full((1000, 1000), 3.141592) @autotest(n=10, auto_backward=True) def test_full_with_random_data_float(test_case): device = random_device() shape = random_tensor(low=1, high=6, requires_grad=False).pytorch.shape y = torch.full(shape, 2.0, requires_grad=True) return y @autotest(n=10, auto_backward=True) def test_full_like_with_random_data_float(test_case): device = random_device() x = random_tensor(low=1, high=6, requires_grad=False).to(device) y = torch.full_like(x, 2.0, requires_grad=True) return y @profile(torch.full_like) def profile_full_like(test_case): torch.full_like(torch.ones(2, 3), 3.141592) torch.full_like(torch.ones(64, 3, 128, 128), 3.141592) torch.full_like(torch.ones(1000, 1000), 3.141592) def test_cast(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_different_dtype, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 0, 4)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_constant_pad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from random import choice from oneflow.test_utils.automated_test_util import * from oneflow.nn.common_types import _size_2_t, _size_4_t, _size_6_t import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestConstantPad1d(flow.unittest.TestCase): @autotest(n=10, rtol=0.001, atol=0.001, include_complex=True) def test_constantpad1d_with_random_data(test_case): m = torch.nn.ConstantPad1d( padding=random(1, 6).to(_size_2_t), value=random().to(float) ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim1=random(1, 6), dim2=random(1, 6)).to(device) y = m(x) return y @autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False) def test_constantpad1d_with_random_int_data(test_case): dtype = choice([int, bool]) value = random(0, 2).to(bool) if dtype is bool else random().to(int) m = torch.nn.ConstantPad1d(padding=random(1, 6).to(_size_2_t), value=value) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim1=random(1, 6), dim2=random(1, 6), dtype=int).to( device ) if dtype is bool: x = x.bool() y = m(x) return y @flow.unittest.skip_unless_1n1d() class TestConstantPad2d(flow.unittest.TestCase): @autotest(n=10, rtol=0.001, atol=0.001, include_complex=True) def test_constantpad2d_with_random_data(test_case): m = torch.nn.ConstantPad2d( padding=random(1, 6).to(_size_4_t), value=random().to(float) ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=4, dim1=random(1, 6), dim2=random(1, 6), dim3=random(1, 6) ).to(device) y = m(x) return y @autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False) def test_constantpad2d_with_random_int_data(test_case): dtype = choice([int, bool]) value = random(0, 2).to(bool) if dtype is bool else random().to(int) m = torch.nn.ConstantPad2d(padding=random(1, 6).to(_size_4_t), value=value,) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=4, dim1=random(1, 6), dim2=random(1, 6), dim3=random(1, 6) ).to(device) if dtype is bool: x = x.bool() y = m(x) return y @flow.unittest.skip_unless_1n1d() class TestConstantPad3d(flow.unittest.TestCase): @autotest(n=10, rtol=0.001, atol=0.001, include_complex=True) def test_constantpad3d_with_random_data(test_case): m = torch.nn.ConstantPad3d( padding=random(1, 6).to(_size_6_t), value=random().to(float) ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=5, dim1=random(1, 6), dim2=random(1, 6), dim3=random(1, 6), dim4=random(1, 6), ).to(device) y = m(x) return y @autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False) def test_constantpad3d_with_random_int_data(test_case): dtype = choice([bool, int]) value = random(0, 2).to(bool) if dtype is bool else random().to(int) m = torch.nn.ConstantPad3d(padding=random(1, 6).to(_size_6_t), value=value,) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=5, dim1=random(1, 6), dim2=random(1, 6), dim3=random(1, 6), dim4=random(1, 6), ).to(device) if dtype is bool: x = x.bool() y = m(x) return y @flow.unittest.skip_unless_1n1d() class TestFunctionalConstantPad2d(flow.unittest.TestCase): @autotest(n=10, rtol=0.001, atol=0.001, check_graph=True, include_complex=True) def test_functional_constantpad2d(test_case): device = random_device() padding = random(-1, 6).to(_size_4_t) value = random().to(float) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=random(2, 6), dim3=random(2, 6), ).to(device) 
y = torch.nn.functional.pad(x, pad=padding, mode="constant", value=value) return y @autotest(n=10, rtol=0.001, atol=0.001, check_graph=True, auto_backward=False) def test_functional_constantpad2d_int_data(test_case): dtype = choice([bool, int]) device = random_device() padding = random(-1, 6).to(_size_4_t) value = random(0, 2).to(bool) if dtype is bool else random().to(int) x = random_tensor( ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=random(2, 6), dim3=random(2, 6), ).to(device) if dtype is bool: x = x.bool() y = torch.nn.functional.pad(x, pad=padding, mode="constant", value=value) return y @profile(torch.nn.functional.pad) def profile_pad(test_case): tensor = torch.ones(32, 3, 128, 128) pad = (1, 1) torch.nn.functional.pad(tensor, pad) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_contiguous.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from random import shuffle import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow.unittest import oneflow as flow @flow.unittest.skip_unless_1n1d() class TestContiguous(flow.unittest.TestCase): @autotest(n=5) def test_transpose_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) z = y.contiguous() return z @autotest(n=5, auto_backward=False) def test_transpose_with_bool_data(test_case): device = random_device() x = random_tensor(ndim=4, requires_grad=False).to(device).to(torch.bool) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) z = y.contiguous() return z @autotest(n=5, auto_backward=False) def test_transpose_with_int_data(test_case): device = random_device() x = random_tensor(ndim=4, requires_grad=False).to(device).to(torch.int) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) z = y.contiguous() return z @autotest(n=5, auto_backward=False) def test_contiguous_with_half_data(test_case): device = random_device() x = random_tensor(ndim=4, requires_grad=False).to(device).to(torch.float16) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) z = y.contiguous() return z @autotest(n=10, check_graph=True) def test_permute2d_tensor_with_random_data(test_case): device = random_device() ndim = 2 permute_list = [0, 1] shuffle(permute_list) x = random_tensor( ndim=ndim, dim0=random(1, 32).to(int), dim1=random(1, 59).to(int), ).to(device) y = x.permute(permute_list) z = y.contiguous() return z @autotest(n=10, check_graph=True) def test_permute3d_tensor_with_random_data(test_case): device = random_device() ndim = 3 permute_list = [0, 1, 2] shuffle(permute_list) x = random_tensor( ndim=ndim, dim0=random(1, 7).to(int), dim1=random(1, 15).to(int), dim2=random(1, 9).to(int), 
).to(device) y = x.permute(permute_list) z = y.contiguous() return z @autotest(n=10, check_graph=True) def test_permute4d_tensor_with_random_data(test_case): device = random_device() ndim = 4 permute_list = [0, 1, 2, 3] shuffle(permute_list) x = random_tensor( ndim=ndim, dim0=random(1, 7).to(int), dim1=random(1, 15).to(int), dim2=random(1, 9).to(int), dim3=random(1, 19).to(int), ).to(device) y = x.permute(permute_list) z = y.contiguous() return z @profile(torch.Tensor.contiguous) def profile_contiguous(test_case): x = torch.ones(32, 3, 128, 128) x.contiguous() def _test_inplace_contiguous(test_case, device): arr = np.random.randn(4, 5, 6, 7).astype(np.float32) input = flow.tensor(arr, device=device) x = input.permute(0, 3, 2, 1) # x is non-contiguous tensor test_case.assertTrue(x.is_contiguous() == False) # y1 is normal version of tensor contiguous y1 = x.contiguous() # y2 is inplace version of tensor contiguous y2 = x.contiguous_() test_case.assertTrue(np.array_equal(y1.cpu().numpy(), y2.cpu().numpy())) test_case.assertTrue(id(x) != id(y1)) test_case.assertTrue(id(x) == id(y2)) test_case.assertTrue(x.is_contiguous() == True) test_case.assertTrue(y1.is_contiguous() == True) test_case.assertTrue(y2.is_contiguous() == True) @flow.unittest.skip_unless_1n1d() class TestInplaceContiguous(flow.unittest.TestCase): def test_inplace_contiguous(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_inplace_contiguous, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_conv1d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.nn as nn import oneflow.unittest import torch as torch_original from packaging import version def _test_conv1d_bias_false(test_case, device): np_arr = np.array([[[1.28795946, -0.2921792, 0.20338029, 0.78604293, -1.89607573]]]) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [[0.10197904, 0.3372305, -0.25743008]], [[0.27720425, -0.52435774, -0.38381988]], [[0.56016803, -0.10063095, -0.10760903]], ] ) m = nn.Conv1d(1, 3, 3, stride=1, bias=False) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m = m.to(device) output = m(input) np_out = np.array( [ [ [-0.01954307, -0.16356121, 0.77392507], [0.43217283, -0.48933625, 0.37196174], [0.72899038, -0.2687211, 0.23886177], ] ] ) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.array( [[[0.93935132, 0.65159315, -0.09726584, -1.03661716, -0.74885899]]] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_conv1d_bias_true(test_case, device): np_arr = np.array( [ [ [0.90499806, -1.11683071, 0.71605605, -0.56754625, 0.61944169], [-0.31317389, -0.26271924, 0.95579433, 0.52468461, 1.48926127], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [ [0.01997352, 0.23834395, 0.00526353], [-0.04861857, -0.22751901, -0.06725175], ], [ [0.13344523, -0.35202524, 0.15168799], [-0.25714493, -0.17459838, 0.28768948], ], [ [0.10671382, -0.28205597, -0.39752254], [0.36393702, 0.07843742, -0.33898622], ], [ [0.20485674, 0.04222689, -0.1898618], [0.22519711, -0.15910202, -0.35057363], ], ] ) bias = np.array([0.01012857, 0.38912651, -0.01600273, -0.3883304]) m = nn.Conv1d(2, 4, 3, stride=1, bias=True) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m.bias = flow.nn.Parameter(flow.Tensor(bias)) m = m.to(device) np_out = np.array( [ [ [-0.22349545, -0.08447243, -0.37358052], [1.4130373, -0.04644597, 0.86949122], [-0.34765026, -0.31004351, -0.14158708], [-0.74985039, -0.87430149, -0.77354753], ] ] ) output = m(input) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.array( [ [ [0.4649893, 0.11147892, -0.3189539, -0.78394318, -0.43043283], [0.28337064, -0.19941133, -0.66853344, -0.95190406, -0.46912211], ] ] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_conv1d_dilation(test_case, device): np_arr = np.array( [[[-0.43016902, 1.74619496, -0.57338119, 0.25563857, 0.12575546]]] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [[-0.35057205, -0.31304273, 0.46250814]], [[-0.40786612, 0.36518192, 0.46280444]], [[-0.00921835, -0.38710043, 0.47566161]], ] ) m = nn.Conv1d(1, 3, 3, stride=1, bias=False) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m = m.to(device) output = m(input) np_out = np.array( [ [ [-0.66102189, -0.31443936, 0.17914855], [0.54776692, -0.8032915, 0.38541752], [-0.94472277, 0.32745653, -0.03385513], ] ] ) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.array( [[[-0.76765651, -1.10261774, 0.29835641, 1.06601286, 1.40097415]]] ) 
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_conv1d_stride(test_case, device): np_arr = np.array( [[[-1.01312506, -0.40687919, 1.5985316, 0.53594196, -1.89935565]]] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [[0.5751484, 0.26589182, -0.026546]], [[-0.10313249, -0.20797005, -0.48268208]], [[-0.22216944, -0.14962578, 0.57433963]], ] ) m = nn.Conv1d(1, 3, 3, stride=2, bias=False) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m = m.to(device) output = m(input) np_out = np.array( [ [ [-0.73331773, 1.11231577], [-0.58247775, 0.64046454], [1.20406508, -1.5262109], ] ] ) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.array( [[[0.24984647, -0.09170401, 0.31495798, -0.09170401, 0.06511152]]] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_conv1d_group_bias_true(test_case, device): np_arr = np.array( [ [ [1.48566079, 0.54937589, 0.62353903, -0.94114172, -0.60260266], [0.61150503, -0.50289607, 1.41735041, -1.85877609, -1.04875529], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [[0.25576305, 0.40814576, -0.05900212]], [[-0.24829513, 0.42756805, -0.01354307]], [[0.44658303, 0.46889144, 0.41060263]], [[0.30083328, -0.5221613, 0.12215579]], ] ) bias = np.array([-0.03368823, -0.4212504, -0.42130581, -0.17434336]) m = nn.Conv1d(2, 4, 3, groups=2, stride=1, bias=True) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m.bias = flow.nn.Parameter(flow.Tensor(bias)) m = m.to(device) np_out = np.array( [ [ [0.53372419, 0.41684598, -0.22277816], [-0.56368178, -0.27830642, -0.97031319], [0.19794616, -0.74452549, -1.09052706], [0.44534814, -1.29277706, 1.09451222], ] ] ) output = m(input) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.array( [ [ [0.00746793, 0.84318173, 0.77063656, 0.76316863, -0.07254519], [0.74741632, 0.69414645, 1.22690487, 0.47948855, 0.53275841], ] ] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_conv1d_group_large_out_bias_true(test_case, device): np_arr = np.array( [ [ [2.17964911, 0.91623521, 1.24746692, 0.73605931, -0.23738743], [-0.70412433, 0.10727754, 1.0207864, -0.09711888, -1.10814202], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [[-0.207307473, 0.12856324, 0.371991515]], [[-0.416422307, 3.26921181e-05, -0.385845661]], [[-0.182592362, 0.143281639, 0.419321984]], [[-0.27117458, 0.0421470925, 0.377335936]], [[0.546190619, -0.211819887, -0.29785803]], [[0.334832489, 0.255918801, -0.0556600206]], ] ) bias = np.array( [-0.56865668, 0.17631066, -0.43992457, -0.24307285, -0.53672957, -0.52927947] ) m = nn.Conv1d(2, 6, 3, groups=2, stride=1, bias=True) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m.bias = flow.nn.Parameter(flow.Tensor(bias)) m = m.to(device) np_out = np.array( [ [ [-0.43867296, -0.32441288, -0.82094181], [-1.21264362, -0.48919463, -0.25154343], [-0.18354186, -0.11983716, -0.66178048], [0.33756858, -0.26578707, -0.9421193], [-1.2480886, -0.66543078, 0.37145507], [-0.79440582, -0.22671542, -0.15066233], ] ] ) output = m(input) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.array( [ [ 
[-0.8063221, -0.53444451, -0.12897667, 0.6773454, 0.40546784], [0.6098485, 0.69609451, 0.71991241, 0.1100639, 0.02381789], ] ] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_conv1d_group_large_in_bias_true(test_case, device): np_arr = np.array( [ [ [0.7382921, 0.3227571, -0.73204273, -0.01697334, 1.72585976], [0.52866709, 0.28417364, 1.12931311, 1.73048413, -0.60748184], [0.43222603, 0.7882517, -0.62105948, 0.10097823, 0.81639361], [0.36671457, 0.24468753, -0.5824874, -0.74464536, -0.38901371], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [ [-0.29574063, -0.31176069, 0.17234495], [0.06092392, 0.30691007, -0.36685407], ], [ [0.26149744, 0.07149458, 0.3209756], [0.18960869, -0.37148297, -0.13602243], ], ] ) bias = np.array([-0.35048512, -0.0093792]) m = nn.Conv1d(4, 2, 3, groups=2, stride=1, bias=True) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m.bias = flow.nn.Parameter(flow.Tensor(bias)) m = m.to(device) np_out = np.array( [[[-1.09048378, -0.49156523, 0.99150705], [0.01852397, 0.54882324, 0.31657016]]] ) output = m(input) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.array( [ [ [-0.29574063, -0.60750133, -0.43515638, -0.13941574, 0.17234495], [0.06092392, 0.36783397, 0.0009799, -0.059944, -0.36685407], [0.26149744, 0.33299202, 0.65396762, 0.39247018, 0.3209756], [0.18960869, -0.18187428, -0.31789672, -0.50750542, -0.13602243], ] ] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_conv1d_compilcate(test_case, device): np_arr = np.array( [ [ [-1.00674784, 0.51784992, 0.39896572, 0.11018554, 0.91136694], [1.95886874, 0.89779067, 0.4748213, 0.33313531, -0.49350029], [-0.19280219, 0.04023677, 1.66438103, -0.83563608, 0.15925731], [1.49166429, 1.45189261, -1.86512125, 0.34329697, 0.20413807], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [ [-0.36045218, 0.37349278, 0.04565236], [0.0242328, -0.09459515, -0.30684742], ], [ [-0.30345008, -0.1196513, -0.26765293], [0.09876197, 0.03346226, 0.2748405], ], [ [-0.37798449, 0.00242459, -0.34125558], [-0.05174343, -0.10443231, 0.09526101], ], [ [0.34196907, -0.32667893, 0.40264183], [0.38025281, 0.26807079, -0.09074812], ], ] ) bias = np.array([-0.03499984, -0.21616256, 0.13312563, -0.24104381]) m = nn.Conv1d(4, 4, 3, groups=2, stride=2, padding=2, dilation=2, bias=True) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m.bias = flow.nn.Parameter(flow.Tensor(bias)) m = m.to(device) np_out = np.array( [ [ [-0.72379637, 0.67248386, 0.21977007], [-0.00643994, -0.1286152, -0.41589433], [-0.76877236, 0.29273134, -0.42040929], [1.0612179, -0.73787093, -0.37839717], ] ] ) output = m(input) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.array( [ [ [-0.41006082, 0.0, -0.63206136, 0.0, 0.03184089], [0.06186188, 0.0, 0.02985496, 0.0, -0.09313981], [-0.36026976, 0.0, -0.2988835, 0.0, -0.26286808], [0.49214786, 0.0, 0.49666074, 0.0, 0.16815135], ] ] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) @flow.unittest.skip_unless_1n1d() class TestConv1d(flow.unittest.TestCase): def test_conv1d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_conv1d_bias_true, _test_conv1d_bias_false, _test_conv1d_dilation, _test_conv1d_stride, 
            _test_conv1d_group_bias_true,
            _test_conv1d_group_large_out_bias_true,
            _test_conv1d_group_large_in_bias_true,
            _test_conv1d_compilcate,
        ]
        arg_dict["device"] = ["cuda", "cpu"]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])

    @unittest.skip("skip for now, because it failed 8 times in the past week")
    @autotest(n=3)
    def test_nn_functional_conv1d(test_case):
        device = random_device()
        img = torch.ones((1, 3, 224), requires_grad=True).to(device)
        kernel = torch.ones((3, 1, 3), requires_grad=True).to(device)
        y = torch.nn.functional.conv1d(img, kernel, groups=3)
        return y

    @unittest.skipIf(
        version.parse(torch_original.__version__) <= version.parse("1.13.0"),
        "conv modules don't support unbatched input in PyTorch before '1.13.0'",
    )
    @autotest(n=3)
    def test_nn_functional_conv1d_2dinput(test_case):
        device = random_device()
        img = torch.ones((3, 224), requires_grad=True).to(device)
        kernel = torch.ones((3, 1, 3), requires_grad=True).to(device)
        y = torch.nn.functional.conv1d(img, kernel, groups=3)
        return y

    @profile(torch.nn.functional.conv1d)
    def profile_conv1d(test_case):
        inputs = torch.ones(40, 16, 30)
        weight_16c = torch.ones(20, 16, 5)
        weight_16c_4g = torch.ones(20, 4, 5)
        weight_3k_16c = torch.ones(20, 16, 3)
        weight_1k_16c = torch.ones(20, 16, 1)
        torch.nn.functional.conv1d(inputs, weight_16c)
        torch.nn.functional.conv1d(inputs, weight_16c, bias=torch.ones(20))
        torch.nn.functional.conv1d(inputs, weight_16c, bias=torch.ones(20), padding=2)
        torch.nn.functional.conv1d(
            inputs, weight_16c, bias=torch.ones(20), padding=2, stride=2
        )
        torch.nn.functional.conv1d(inputs, weight_16c_4g, groups=4)
        torch.nn.functional.conv1d(inputs, weight_16c_4g, bias=torch.ones(20), groups=4)
        torch.nn.functional.conv1d(
            inputs, weight_16c_4g, bias=torch.ones(20), groups=4, stride=4
        )
        torch.nn.functional.conv1d(
            inputs, weight_16c_4g, bias=torch.ones(20), groups=4, padding=2
        )
        torch.nn.functional.conv1d(inputs, weight_3k_16c)
        torch.nn.functional.conv1d(inputs, weight_3k_16c, bias=torch.ones(20))
        torch.nn.functional.conv1d(
            inputs, weight_3k_16c, bias=torch.ones(20), padding=1
        )
        torch.nn.functional.conv1d(
            inputs, weight_3k_16c, bias=torch.ones(20), padding=1, stride=2
        )
        torch.nn.functional.conv1d(inputs, weight_1k_16c)
        torch.nn.functional.conv1d(inputs, weight_1k_16c, bias=torch.ones(20))
        torch.nn.functional.conv1d(inputs, weight_1k_16c, bias=torch.ones(20), stride=2)

    @autotest(n=5, atol=1e-3)
    def test_conv1d_with_random_data(test_case):
        channels = random(1, 6)
        m = torch.nn.Conv1d(
            in_channels=channels,
            out_channels=random(1, 20),
            kernel_size=random(1, 4),
            stride=random() | nothing(),
            padding=random(1, 3).to(int) | nothing(),
            dilation=random(1, 5) | nothing(),
            groups=random(1, 5) | nothing(),
            padding_mode=constant("zeros") | nothing(),
        )
        m.train(random())
        device = random_device()
        m.to(device)
        x = random_tensor(ndim=3, dim1=channels).to(device)
        y = m(x)
        return y

    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    @autotest(n=5, check_allclose=False)
    def test_conv1d_group_with_random_data(test_case):
        channels = 720  # divisible by every candidate group count (1-6)
        m = torch.nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=random(1, 4),
            stride=random() | nothing(),
            padding=random(1, 3).to(int) | nothing(),
            dilation=random(1, 5) | nothing(),
            groups=random(1, 7),
            padding_mode=constant("zeros") | nothing(),
        )
        m.train(random())
        device = random_device()
        m.to(device)
        m.pytorch.to("cuda")
        x = random_tensor(ndim=3, dim1=channels).to(device)
        x.pytorch = x.pytorch.to("cuda")
        y = m(x)
        return y


if __name__ == "__main__":
unittest.main() ================================================ FILE: python/oneflow/test/modules/test_conv2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest import torch as torch_original from packaging import version test_conv2d_weight = np.array( [ [ [ [0.8586049675941467, -0.2279418259859085, 0.2013147622346878], [0.35005471110343933, 0.5360521078109741, 1.5194443464279175], [1.9040879011154175, -1.5734431743621826, -0.14007866382598877], ] ], [ [ [0.29670074582099915, 1.3111951351165771, 0.5035904049873352], [-1.1894450187683105, -0.5502137541770935, -1.591875672340393], [-1.1081947088241577, 0.07872020453214645, -0.9185634255409241], ] ], [ [ [-0.7457143664360046, -1.2080862522125244, 1.8140212297439575], [-1.5227429866790771, -2.515244960784912, -1.3549325466156006], [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952], ] ], ] ) test_conv2d_data = np.array( [ [ [ [ 1.1630785465240479, 0.4838046133518219, 0.299563467502594, 0.15302546322345734, -1.168814778327942, ], [ 1.5580710172653198, -0.5459445714950562, -2.3556296825408936, 0.5414402484893799, 2.678506374359131, ], [ 1.2546343803405762, -0.5487740635871887, -0.6810643672943115, -0.13531559705734253, 0.37723132967948914, ], [ 0.41016456484794617, 0.5712682008743286, -2.757962703704834, 1.0762799978256226, -0.6141325235366821, ], [ 1.830764889717102, -1.1468064785003662, 0.053837940096855164, -2.5074806213378906, -0.5916498899459839, ], ] ] ] ) test_conv2d_data_grad = np.array( [ [ [ [ 0.4095913469791412, 0.2847584038972855, 2.803684800863266, 2.3940934538841248, 2.5189263969659805, ], [ -1.9525419473648071, -4.606781497597694, -3.51521897315979, -1.562677025794983, 1.0915625244379044, ], [ -2.1141327619552612, -6.987950943410397, -5.84306687861681, -3.7289341166615486, 1.1448840647935867, ], [ -2.5237241089344025, -7.272709347307682, -8.646751679480076, -6.123027570545673, -1.3740423321723938, ], [ -0.1615908145904541, -2.381169445812702, -2.32784790545702, -2.1662570908665657, 0.0533215403556824, ], ] ] ] ) test_conv2d_weight_grad = np.array( [ [ [ [0.6277393400669098, -2.7888944894075394, -0.2910575419664383], [-3.095237225294113, -4.835702538490295, -1.8706469237804413], [-1.0139376372098923, -6.076017692685127, -5.780256435275078], ] ], [ [ [0.6277393400669098, -2.7888944894075394, -0.2910575419664383], [-3.095237225294113, -4.835702538490295, -1.8706469237804413], [-1.0139376372098923, -6.076017692685127, -5.780256435275078], ] ], [ [ [0.6277393400669098, -2.7888944894075394, -0.2910575419664383], [-3.095237225294113, -4.835702538490295, -1.8706469237804413], [-1.0139376372098923, -6.076017692685127, -5.780256435275078], ] ], ] ) test_conv2d_output = np.array( [ [ [ [0.9699610471725464, -0.20758534967899323, 
2.3857712745666504], [0.3666309118270874, 4.690882682800293, -8.203354835510254], [2.6072847843170166, -1.9033538103103638, 2.331153154373169], ], [ [2.519343852996826, 2.3757898807525635, -1.6613528728485107], [0.5777544379234314, -3.5739502906799316, 5.349126815795898], [0.729295015335083, 1.5791023969650269, 3.7627718448638916], ], [ [-0.27685487270355225, 6.446267127990723, -2.762883424758911], [-8.25644588470459, 9.616064071655273, 8.005367279052734], [-0.6944921016693115, 3.866114854812622, 4.788446426391602], ], ] ] ) test_conv2d_with_bias_weight = np.array( [ [ [ [1.8271433115005493, -1.0446699857711792, 1.0062190294265747], [0.5174201130867004, -0.806931734085083, 1.3769007921218872], [0.205885112285614, 0.9943519234657288, -0.23580588400363922], ] ], [ [ [0.29881811141967773, -1.9982075691223145, 0.3511354625225067], [-0.7644741535186768, 1.2594351768493652, -0.9629734754562378], [0.5080506205558777, 0.7561734318733215, 1.6839302778244019], ] ], [ [ [1.2573646306991577, 0.13123232126235962, 1.6403018236160278], [-1.2138012647628784, 2.399970531463623, -0.38509097695350647], [-0.9878040552139282, 0.9585888385772705, -1.4976465702056885], ] ], ] ) test_conv2d_with_bias_bias = np.array( [0.6605162620544434, -0.18903568387031555, -0.27302607893943787] ) test_conv2d_with_bias_data = np.array( [ [ [ [ -0.47827261686325073, -1.1739492416381836, -0.7921845316886902, 0.9321041703224182, -3.1557741165161133, ], [ 2.1935296058654785, -0.5385921001434326, -0.8611332774162292, -1.881519079208374, -0.7205708026885986, ], [ -0.35601571202278137, -0.15963983535766602, 1.797447681427002, 0.19594945013523102, -1.7376397848129272, ], [ 0.047347065061330795, 0.14580930769443512, 0.32604914903640747, 0.4578782916069031, -0.8942581415176392, ], [ 0.49383941292762756, -0.9043426513671875, -1.2140793800354004, 2.1564064025878906, 1.0938222408294678, ], ] ] ] ) test_conv2d_with_bias_output = np.array( [ [ [ [-0.05607491731643677, -0.185230553150177, -3.8808679580688477], [6.861937046051025, -2.3341472148895264, -0.5597308874130249], [1.8299254179000854, -2.770848274230957, 2.1958212852478027], ], [ [2.9348952770233154, 4.117504119873047, -6.278541088104248], [0.2638452351093292, 3.998856782913208, 2.612290620803833], [-1.9891828298568726, -1.6476304531097412, 3.39066219329834], ], [ [-8.44466781616211, 0.5747121572494507, -8.501373291015625], [-0.036642804741859436, -0.23458999395370483, -2.370849370956421], [2.8372013568878174, -2.987276077270508, 1.8382092714309692], ], ] ] ) test_conv2d_group_weight = np.array( [ [ [ [-0.7248556613922119, 1.1119636297225952, -0.47827261686325073], [-1.1739492416381836, -0.7921845316886902, 0.9321041703224182], [-3.1557741165161133, 2.1935296058654785, -0.5385921001434326], ] ], [ [ [-0.8611332774162292, -1.881519079208374, -0.7205708026885986], [-0.35601571202278137, -0.15963983535766602, 1.797447681427002], [0.19594945013523102, -1.7376397848129272, 0.047347065061330795], ] ], ] ) test_conv2d_group_data_grad = np.array( [ [ [ [ -0.7248556613922119, 0.3871079683303833, -0.0911646485328674, 0.6336910128593445, -0.4782726168632507, ], [ -1.8988049030303955, -1.5790258049964905, -1.125194251537323, 0.7736106514930725, 0.4538315534591675, ], [ -5.054579019546509, -2.5412703156471252, -2.6260308623313904, 2.4285481572151184, -0.0847605466842651, ], [ -4.329723358154297, -2.9283782839775085, -2.534866213798523, 1.794857144355774, 0.3935120701789856, ], [ -3.1557741165161133, -0.9622445106506348, -1.5008366107940674, 1.654937505722046, -0.5385921001434326, ], ], [ [ 
-0.8611332774162292, -2.7426523566246033, -3.463223159313202, -2.6020898818969727, -0.7205708026885986, ], [ -1.2171489894390106, -3.2583079040050507, -2.1814310252666473, -0.9642820358276367, 1.0768768787384033, ], [ -1.0211995393037796, -4.799998238682747, -3.6757742948830128, -2.654574755579233, 1.1242239437997341, ], [ -0.1600662618875504, -2.0573458820581436, -0.2125511355698109, -0.0524848736822605, 1.8447947464883327, ], [ 0.195949450135231, -1.5416903346776962, -1.4943432696163654, -1.6902927197515965, 0.0473470650613308, ], ], ] ] ) test_conv2d_group_weight_grad = np.array( [ [ [ [0.6277393400669098, -2.7888944894075394, -0.2910575419664383], [-3.095237225294113, -4.835702538490295, -1.8706469237804413], [-1.0139376372098923, -6.076017692685127, -5.780256435275078], ] ], [ [ [3.30740749835968, -0.7220746576786041, -3.660933956503868], [0.5273916646838188, -2.631059892475605, -7.6207195818424225], [-3.5466641262173653, -8.214546449482441, -11.031560003757477], ] ], ] ) test_conv2d_group_data = np.array( [ [ [ [ 1.1630785465240479, 0.4838046133518219, 0.299563467502594, 0.15302546322345734, -1.168814778327942, ], [ 1.5580710172653198, -0.5459445714950562, -2.3556296825408936, 0.5414402484893799, 2.678506374359131, ], [ 1.2546343803405762, -0.5487740635871887, -0.6810643672943115, -0.13531559705734253, 0.37723132967948914, ], [ 0.41016456484794617, 0.5712682008743286, -2.757962703704834, 1.0762799978256226, -0.6141325235366821, ], [ 1.830764889717102, -1.1468064785003662, 0.053837940096855164, -2.5074806213378906, -0.5916498899459839, ], ], [ [ 0.8586049675941467, -0.2279418259859085, 0.2013147622346878, 0.35005471110343933, 0.5360521078109741, ], [ 1.5194443464279175, 1.9040879011154175, -1.5734431743621826, -0.14007866382598877, 0.29670074582099915, ], [ 1.3111951351165771, 0.5035904049873352, -1.1894450187683105, -0.5502137541770935, -1.591875672340393, ], [ -1.1081947088241577, 0.07872020453214645, -0.9185634255409241, -0.7457143664360046, -1.2080862522125244, ], [ 1.8140212297439575, -1.5227429866790771, -2.515244960784912, -1.3549325466156006, -0.9574840068817139, ], ], ] ] ) test_conv2d_group_output = np.array( [ [ [ [-8.836943626403809, 3.2316627502441406, 6.994439601898193], [-0.8386597037315369, -9.857108116149902, 13.68197250366211], [-13.020713806152344, 7.310227870941162, -3.3760271072387695], ], [ [-4.803101539611816, 1.026240587234497, 0.5452112555503845], [-6.839838027954102, 2.0195930004119873, 0.11328654736280441], [0.393694669008255, 4.987061023712158, 3.297354221343994], ], ] ] ) test_conv2d_padding_weight = np.array( [ [ [ [0.8586049675941467, -0.2279418259859085, 0.2013147622346878], [0.35005471110343933, 0.5360521078109741, 1.5194443464279175], [1.9040879011154175, -1.5734431743621826, -0.14007866382598877], ] ] ] ) test_conv2d_padding_data = np.array( [ [ [ [ 1.1630785465240479, 0.4838046133518219, 0.299563467502594, 0.15302546322345734, -1.168814778327942, ], [ 1.5580710172653198, -0.5459445714950562, -2.3556296825408936, 0.5414402484893799, 2.678506374359131, ], [ 1.2546343803405762, -0.5487740635871887, -0.6810643672943115, -0.13531559705734253, 0.37723132967948914, ], [ 0.41016456484794617, 0.5712682008743286, -2.757962703704834, 1.0762799978256226, -0.6141325235366821, ], [ 1.830764889717102, -1.1468064785003662, 0.053837940096855164, -2.5074806213378906, -0.5916498899459839, ], ] ] ] ) test_conv2d_padding_data_grad = np.array( [ [ [ [ 3.237529069185257, 3.237529069185257, 3.237529069185257, 3.237529069185257, 3.237529069185257, ], [ 3.428095132112503, 
3.428095132112503, 3.428095132112503, 3.428095132112503, 3.428095132112503, ], [ 3.428095132112503, 3.428095132112503, 3.428095132112503, 3.428095132112503, 3.428095132112503, ], [ 3.428095132112503, 3.428095132112503, 3.428095132112503, 3.428095132112503, 3.428095132112503, ], [ 2.596117228269577, 2.596117228269577, 2.596117228269577, 2.596117228269577, 2.596117228269577, ], ] ] ] ) test_conv2d_padding_weight_grad = np.array( [ [ [ [1.7594299167394638, 1.7594299167394638, 1.7594299167394638], [-0.6019042432308197, -0.6019042432308197, -0.6019042432308197], [-1.532561555504799, -1.532561555504799, -1.532561555504799], ] ] ] ) test_conv2d_padding_output = np.array( [ [ [ [ 1.5489805936813354, -1.0164761543273926, 5.277345657348633, 3.153532028198242, -7.301508903503418, -3.7565059661865234, 4.690962314605713, ], [ 2.425799608230591, -2.0592665672302246, 0.9699610471725464, -0.20758534967899323, 2.3857712745666504, 1.1719579696655273, 0.6523551940917969, ], [ 2.1625545024871826, -1.3517316579818726, 0.3666309118270874, 4.690882682800293, -8.203354835510254, 3.0248217582702637, 1.2624683380126953, ], [ 0.6193475723266602, -2.0285415649414062, 2.6072847843170166, -1.9033538103103638, 2.331153154373169, -3.998155355453491, -1.0176407098770142, ], [ 2.8643176555633545, -0.7396122217178345, -0.2253415733575821, -2.846742630004883, -4.961236476898193, -0.1308247298002243, -0.7344070672988892, ], ] ] ] ) test_conv2d_stride_weight = np.array( [ [ [ [0.8586049675941467, -0.2279418259859085, 0.2013147622346878], [0.35005471110343933, 0.5360521078109741, 1.5194443464279175], [1.9040879011154175, -1.5734431743621826, -0.14007866382598877], ] ] ] ) test_conv2d_stride_data = np.array( [ [ [ [ 1.1630785465240479, 0.4838046133518219, 0.299563467502594, 0.15302546322345734, -1.168814778327942, ], [ 1.5580710172653198, -0.5459445714950562, -2.3556296825408936, 0.5414402484893799, 2.678506374359131, ], [ 1.2546343803405762, -0.5487740635871887, -0.6810643672943115, -0.13531559705734253, 0.37723132967948914, ], [ 0.41016456484794617, 0.5712682008743286, -2.757962703704834, 1.0762799978256226, -0.6141325235366821, ], [ 1.830764889717102, -1.1468064785003662, 0.053837940096855164, -2.5074806213378906, -0.5916498899459839, ], ] ] ] ) test_conv2d_stride_data_grad = np.array( [ [ [ [ 0.5360521078109741, 1.5194443464279175, 0.3500547111034393, 0.5360521078109741, 1.5194443464279175, ], [ -1.8013850003480911, 0.061236098408699, 2.762692868709564, -1.8013850003480911, 0.061236098408699, ], [ 0.5360521078109741, 1.5194443464279175, 0.3500547111034393, 0.5360521078109741, 1.5194443464279175, ], [ -1.8013850003480911, 0.061236098408699, 2.762692868709564, -1.8013850003480911, 0.061236098408699, ], [ 0.5360521078109741, 1.5194443464279175, 0.3500547111034393, 0.5360521078109741, 1.5194443464279175, ], ] ] ] ) test_conv2d_stride_weight_grad = np.array( [ [ [ [-5.1135923862457275, 3.5859558284282684, 2.089697480201721], [-0.3276629596948624, 1.7587070614099503, -2.5950092673301697], [-5.1135923862457275, 3.5859558284282684, 2.089697480201721], ] ] ] ) test_conv2d_stride_output = np.array( [ [ [ [-1.0164761543273926, -7.301508903503418], [-1.3517316579818726, -8.203354835510254], [-0.7396122217178345, -4.961236476898193], ] ] ] ) test_conv2d_kernel_weight = np.array( [ [ [ [ -0.9574840068817139, -0.7248556613922119, 1.1119636297225952, -0.47827261686325073, -1.1739492416381836, ], [ -0.7921845316886902, 0.9321041703224182, -3.1557741165161133, 2.1935296058654785, -0.5385921001434326, ], [ -0.8611332774162292, 
-1.881519079208374, -0.7205708026885986, -0.35601571202278137, -0.15963983535766602, ], ] ] ] ) test_conv2d_kernel_data = np.array( [ [ [ [ 1.1630785465240479, 0.4838046133518219, 0.299563467502594, 0.15302546322345734, -1.168814778327942, 1.5580710172653198, -0.5459445714950562, ], [ -2.3556296825408936, 0.5414402484893799, 2.678506374359131, 1.2546343803405762, -0.5487740635871887, -0.6810643672943115, -0.13531559705734253, ], [ 0.37723132967948914, 0.41016456484794617, 0.5712682008743286, -2.757962703704834, 1.0762799978256226, -0.6141325235366821, 1.830764889717102, ], [ -1.1468064785003662, 0.053837940096855164, -2.5074806213378906, -0.5916498899459839, 0.8586049675941467, -0.2279418259859085, 0.2013147622346878, ], [ 0.35005471110343933, 0.5360521078109741, 1.5194443464279175, 1.9040879011154175, -1.5734431743621826, -0.14007866382598877, 0.29670074582099915, ], [ 1.3111951351165771, 0.5035904049873352, -1.1894450187683105, -0.5502137541770935, -1.591875672340393, -1.1081947088241577, 0.07872020453214645, ], [ -0.9185634255409241, -0.7457143664360046, -1.2080862522125244, 1.8140212297439575, -1.5227429866790771, -2.515244960784912, -1.3549325466156006, ], ] ] ] ) test_conv2d_kernel_data_grad = np.array( [ [ [ [ -0.9574840068817139, -1.6823396682739258, -0.5703760385513306, -0.0911646485328674, -0.5402582287788391, -1.6522218585014343, -1.1739492416381836, ], [ -1.749668538570404, -1.5424200296401978, -3.586230516433716, -0.121304988861084, -2.0410948395729065, 0.0027156472206116, -1.7125413417816162, ], [ -2.6108018159866333, -4.285072386264801, -7.049453675746918, -3.079410582780838, -3.2773211896419525, -0.5129399001598358, -1.8721811771392822, ], [ -2.6108018159866333, -4.285072386264801, -7.049453675746918, -3.079410582780838, -3.2773211896419525, -0.5129399001598358, -1.8721811771392822, ], [ -2.6108018159866333, -4.285072386264801, -7.049453675746918, -3.079410582780838, -3.2773211896419525, -0.5129399001598358, -1.8721811771392822, ], [ -1.6533178091049194, -2.6027327179908752, -6.479077637195587, -2.9882459342479706, -2.7370629608631134, 1.1392819583415985, -0.6982319355010986, ], [ -0.8611332774162292, -2.7426523566246033, -3.463223159313202, -2.958105593919754, -1.236226350069046, -0.5156555473804474, -0.159639835357666, ], ] ] ] ) test_conv2d_kernel_weight_grad = np.array( [ [ [ [ 2.974529668688774, 4.548736393451691, 1.1672898679971695, -1.499158263206482, 0.1862268149852753, ], [ 1.6534235626459122, 2.3762744814157486, -1.448018729686737, -5.2917241007089615, -2.278435029089451, ], [ -2.083257421851158, -2.23808591067791, -5.749193429946899, -7.540486767888069, -6.306201495230198, ], ] ] ] ) test_conv2d_kernel_output = np.array( [ [ [ [-3.5647754669189453, -4.234736919403076, 1.4046944379806519], [-0.6964312791824341, 16.42838478088379, -9.649789810180664], [4.312150478363037, -6.283960819244385, -4.8443922996521], [-2.772286891937256, -4.483709812164307, 12.315184593200684], [7.39893913269043, 1.305102825164795, -2.049992561340332], ] ] ] ) test_conv2d_dilation_weight = np.array( [ [ [ [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952], [-0.47827261686325073, -1.1739492416381836, -0.7921845316886902], [0.9321041703224182, -3.1557741165161133, 2.1935296058654785], ] ] ] ) test_conv2d_dilation_data = np.array( [ [ [ [ 1.1630785465240479, 0.4838046133518219, 0.299563467502594, 0.15302546322345734, -1.168814778327942, 1.5580710172653198, -0.5459445714950562, ], [ -2.3556296825408936, 0.5414402484893799, 2.678506374359131, 1.2546343803405762, -0.5487740635871887, 
-0.6810643672943115, -0.13531559705734253, ], [ 0.37723132967948914, 0.41016456484794617, 0.5712682008743286, -2.757962703704834, 1.0762799978256226, -0.6141325235366821, 1.830764889717102, ], [ -1.1468064785003662, 0.053837940096855164, -2.5074806213378906, -0.5916498899459839, 0.8586049675941467, -0.2279418259859085, 0.2013147622346878, ], [ 0.35005471110343933, 0.5360521078109741, 1.5194443464279175, 1.9040879011154175, -1.5734431743621826, -0.14007866382598877, 0.29670074582099915, ], [ 1.3111951351165771, 0.5035904049873352, -1.1894450187683105, -0.5502137541770935, -1.591875672340393, -1.1081947088241577, 0.07872020453214645, ], [ -0.9185634255409241, -0.7457143664360046, -1.2080862522125244, 1.8140212297439575, -1.5227429866790771, -2.515244960784912, -1.3549325466156006, ], ] ] ] ) test_conv2d_dilation_data_grad = np.array( [ [ [ [ -0.9574840068817139, 0.0, 0.0, -0.7248556613922119, 0.0, 0.0, 1.1119636297225952, ], [ -0.9574840068817139, 0.0, 0.0, -0.7248556613922119, 0.0, 0.0, 1.1119636297225952, ], [ -1.4357566237449646, 0.0, 0.0, -1.8988049030303955, 0.0, 0.0, 0.319779098033905, ], [ -0.4782726168632507, 0.0, 0.0, -1.1739492416381836, 0.0, 0.0, -0.7921845316886902, ], [ 0.4538315534591675, 0.0, 0.0, -4.329723358154297, 0.0, 0.0, 1.4013450741767883, ], [ 0.9321041703224182, 0.0, 0.0, -3.1557741165161133, 0.0, 0.0, 2.1935296058654785, ], [ 0.9321041703224182, 0.0, 0.0, -3.1557741165161133, 0.0, 0.0, 2.1935296058654785, ], ] ] ] ) test_conv2d_dilation_weight_grad = np.array( [ [ [ [-0.8153198063373566, -1.3503028601408005, 1.1495047211647034], [-0.4195204377174377, -1.4455246925354004, 2.328780397772789], [0.7426864206790924, 3.1678953766822815, -0.979511596262455], ] ] ] ) test_conv2d_dilation_output = np.array( [[[[-5.2563982009887695], [5.410353183746338], [-8.517012596130371]]]] ) def _test_conv2d( test_case, conv, data, weight, output, bias=None, device="cuda", ): to_device = flow.device(device) x = flow.tensor(data, dtype=flow.float32, device=to_device) conv.weight = flow.nn.Parameter(flow.Tensor(weight)) if bias is not None: conv.bias = flow.nn.Parameter(flow.Tensor(bias)) conv.to(to_device) of_out = conv(x) test_case.assertTrue(np.allclose(of_out.numpy(), output, rtol=1e-4, atol=1e-8)) def _test_conv2d_backward( test_case, conv, data, weight, data_grad, weight_grad, bias=None, device="cuda", data_rtol=1e-4, data_atol=1e-8, weight_rtol=1e-4, weight_atol=1e-8, ): to_device = flow.device(device) x = flow.tensor(data, dtype=flow.float32, device=to_device, requires_grad=True) conv.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True) if bias is not None: conv.bias = flow.nn.Parameter(flow.Tensor(bias)) conv.to(to_device) of_out = conv(x) of_out.sum().backward() test_case.assertTrue( np.allclose(x.grad.numpy(), data_grad, rtol=data_rtol, atol=data_atol) ) test_case.assertTrue( np.allclose( conv.weight.grad.numpy(), weight_grad, rtol=weight_rtol, atol=weight_atol ) ) def _test_conv2d_large_in_channel(test_case, device): np_arr = np.array( [ [ [ [ 0.6206631238581714, -1.1225329393404626, 0.8407155480700242, -0.6845162855236345, ], [ -0.5186484633906412, 0.10420735184519186, -0.1711568947473012, 0.5168640476046483, ], [ -0.12429464919764661, 0.050277779246134253, -1.0144501797426606, -2.184600444658526, ], [ 0.28918126931309923, -0.822872663244595, 0.44019150436683663, -1.0247720130825562, ], ], [ [ 0.7786504412818226, -0.7501839068078657, -0.8187283189941765, -1.1116653569170698, ], [ 0.18085524152316743, -1.3461349607476678, 1.142505437476448, -0.000649619704040145, 
], [ 0.03160672782674317, -0.006318157449953413, 1.2218487782604377, 0.15903027907930234, ], [ 1.5857011815642381, 0.6656477116332891, -0.04036621813223574, -0.3427168687988546, ], ], [ [ -1.1774346070102524, 1.6195241269303395, -0.36185552303441965, -1.1382193113192487, ], [ 0.08061907334568702, 1.5025447613238763, -1.1591348706634745, 1.6449050139676873, ], [ 1.1539915649822392, -2.414624939646017, 0.3056063774849572, 1.1920089257083162, ], [ 0.7623012858982319, -0.01685314742940813, -1.096666898224702, -0.4406476137098582, ], ], [ [ 0.9383797282214235, -1.1075876842796508, -0.4420913825139058, -1.0736097610655628, ], [ -0.3101376466546291, 1.6578227745160954, -0.6225454278031398, 0.6831188620748697, ], [ 0.00743800968372913, -0.8089158949698473, 2.08084287836801, 0.721204366332351, ], [ 0.5694701823297723, 0.031519314469744895, -0.5041680957766629, -0.4738588233094669, ], ], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [ [ [0.06456436216831207, -0.10852358490228653, -0.21638715267181396], [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096], [0.05026858672499657, 0.10818571597337723, 0.02056501805782318], ], [ [0.205095112323761, 0.1488947868347168, -0.2344113141298294], [0.1684819906949997, -0.21986986696720123, 0.1082606166601181], [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222], ], ], [ [ [-0.09441672265529633, -0.03644559532403946, -0.22235223650932312], [-0.1771145612001419, 0.08043312281370163, 0.06938580423593521], [0.054393064230680466, -0.05483492836356163, 0.23438701033592224], ], [ [0.22666795551776886, 0.0874653309583664, 0.07092718034982681], [0.08883464336395264, -0.052362944930791855, -0.1720171570777893], [0.10441060364246368, 0.011952142231166363, -0.0894528403878212], ], ], ] ) m = flow.nn.Conv2d(4, 2, 3, groups=2, bias=False) m.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True) m = m.to(device) output = m(input) np_out = [ [ [ [0.7666134238243103, -0.3961866497993469], [-0.656266987323761, -1.1613956689834595], ], [ [0.3077264130115509, -0.42817503213882446], [-0.5761325359344482, 0.1300736665725708], ], ] ] test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-3, 1e-3)) output = output.sum() output.backward() np_grad = [ [ [ [ 0.06456436216831207, -0.04395922273397446, -0.3249107301235199, -0.21638715267181396, ], [ -0.16334669291973114, -0.12419328093528748, 0.017341122031211853, -0.021812304854393005, ], [ -0.17764246463775635, 0.07822024822235107, 0.47100257873535156, 0.21513986587524414, ], [ 0.05026858672499657, 0.1584542989730835, 0.128750741481781, 0.02056501805782318, ], ], [ [ 0.205095112323761, 0.3539898991584778, -0.08551652729511261, -0.2344113141298294, ], [ 0.3735771179199219, 0.30260205268859863, -0.19712577760219574, -0.1261506974697113, ], [ 0.015584588050842285, -0.03308109939098358, 0.07913993299007416, 0.12780562043190002, ], [ -0.1528974026441574, 0.018306776881217957, 0.1907491832971573, 0.01954500749707222, ], ], [ [ -0.09441672265529633, -0.13086232542991638, -0.258797824382782, -0.22235223650932312, ], [ -0.27153128385543823, -0.22754377126693726, -0.10897888988256454, -0.1529664397239685, ], [ -0.12272149324417114, -0.09712330251932144, 0.32937100529670715, 0.30377280712127686, ], [ 0.054393064230680466, -0.00044186413288116455, 0.1795520782470703, 0.23438701033592224, ], ], [ [ 0.22666795551776886, 0.31413328647613525, 0.1583925187587738, 0.07092718034982681, ], [ 0.3155025839805603, 0.35060498118400574, 
-0.06598758697509766, -0.1010899767279625, ], [ 0.19324524700641632, 0.1528344452381134, -0.301880806684494, -0.2614699900150299, ], [ 0.10441060364246368, 0.11636274307966232, -0.07750070095062256, -0.0894528403878212, ], ], ] ] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-3, 1e-3)) def _test_conv2d_large_out_channel(test_case, device): np_arr = np.array( [ [ [ [0.56573248, -0.19689320, -0.67875558, 0.34328273, 0.31964567], [-1.33715475, 0.33422229, -1.27643383, 0.37904647, 0.35891593], [0.84579802, 2.12729621, -0.51423287, 0.61297560, -1.31156564], [-0.71047139, 1.02679253, -0.76686019, -0.72969633, 0.73425150], [-0.13592879, -1.03207183, -0.22554775, 0.74148071, 0.96601510], ], [ [0.51595992, 0.49624804, 0.91145641, 0.49247262, 0.41002217], [-1.08001196, 1.55497086, -0.81963140, -0.45511565, -0.60269165], [0.05563145, -0.94318372, -1.17058158, -0.73568577, 0.57810956], [-0.40260276, -0.10309298, 1.12378800, -0.23510537, -0.73893374], [-0.52712536, -0.00717016, -1.85051966, -1.50790560, 1.38335907], ], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [ [ [-0.19489679, -0.32377058, 0.21736273], [0.04095296, -0.21552679, -0.14626531], [-0.19359522, -0.00742865, -0.19832158], ] ], [ [ [0.29926914, 0.00931164, 0.26197660], [0.27611443, -0.15439281, -0.19027126], [-0.28909120, 0.30367029, -0.05168664], ] ], [ [ [-0.03155736, 0.17610769, 0.22111714], [0.22790670, -0.32897446, -0.03260243], [-0.10274851, -0.06903386, -0.19438276], ] ], [ [ [-0.24573688, -0.06723209, -0.21363299], [-0.02136187, -0.24994437, -0.18691199], [0.12189507, 0.29469389, 0.03398871], ] ], ] ) m = flow.nn.Conv2d(2, 4, 3, groups=2, bias=False) m.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True) m = m.to(device) output = m(input) np_out = np.array( [ [ [ [-0.21170563, 0.03652292, 0.25926736], [-0.19168918, 0.49044561, 0.25099146], [-1.02489340, 0.25361472, -0.51828313], ], [ [0.23977707, -0.56090075, -0.19285655], [-0.17167747, 0.24558367, -0.30935860], [-0.33303234, 1.52472734, -0.49013454], ], [ [-0.17137986, 1.21333742, 0.18988736], [0.31785482, -0.12121570, -0.18676008], [-0.10680684, -0.30298883, 0.41809759], ], [ [-0.87821335, -0.51665992, -0.44061098], [0.74804580, 0.53107250, 0.50418228], [-0.00512899, -0.36455840, -0.23643512], ], ] ] ) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-3, 1e-3)) output = output.sum() output.backward() np_grad = np.array( [ [ [ [0.10437235, -0.21008658, 0.26925275, 0.16488039, 0.47933933], [0.42143974, -0.26293880, -0.12013602, -0.54157579, 0.14280275], [-0.06124666, -0.44938356, -0.55658901, -0.49534237, -0.10720548], [-0.16561902, -0.23929697, -0.82584178, -0.66022277, -0.58654481], [-0.48268640, -0.18644476, -0.43645298, 0.04623342, -0.25000823], ], [ [-0.27729425, -0.16841865, -0.16093449, 0.11635975, 0.00748415], [-0.07074942, -0.54079264, -0.75282294, -0.68207347, -0.21203026], [-0.05160286, -0.29598606, -0.66841042, -0.61680746, -0.37242430], [0.22569139, -0.12756741, -0.50747585, -0.73316729, -0.37990844], [0.01914656, 0.24480659, 0.08441254, 0.06526598, -0.16039404], ], ] ] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-3, 1e-3)) @flow.unittest.skip_unless_1n1d() class TestConv2d(flow.unittest.TestCase): def test_conv2d_default_init(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 1, (3, 3), bias=True).to(flow.device(device)) 
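# Descriptive note (added comment, not in the original test): a freshly constructed
# Conv2d gets a random default initialization (a kaiming-uniform style init, as in
# PyTorch), so with overwhelming probability neither the weight nor the bias is all
# zeros. The assertions below check exactly that, for the plain .to(device) path and
# for the device=/dtype= constructor paths in float32 and float16.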
test_case.assertTrue( not np.allclose( conv.weight.numpy(), np.zeros((1, 1, 3, 3)), rtol=1e-9, atol=1e-10 ) ) test_case.assertTrue( not np.allclose( conv.bias.numpy(), np.zeros((1,)), rtol=1e-9, atol=1e-10 ) ) conv = flow.nn.Conv2d( 1, 1, (3, 3), bias=True, device=device, dtype=flow.float32 ) test_case.assertTrue( not np.allclose( conv.weight.numpy(), np.zeros((1, 1, 3, 3)), rtol=1e-9, atol=1e-10 ) ) test_case.assertTrue( not np.allclose( conv.bias.numpy(), np.zeros((1,)), rtol=1e-9, atol=1e-10 ) ) conv = flow.nn.Conv2d( 1, 1, (3, 3), bias=True, device=device, dtype=flow.float16 ) test_case.assertTrue( not np.allclose( conv.weight.numpy(), np.zeros((1, 1, 3, 3)), rtol=1e-9, atol=1e-10 ) ) test_case.assertTrue( not np.allclose( conv.bias.numpy(), np.zeros((1,)), rtol=1e-9, atol=1e-10 ) ) @unittest.skip("skip for now, becase it failed 8 times in past week") @autotest(n=3) def test_nn_functional_conv2d(test_case): device = random_device() img = torch.ones((1, 3, 224, 224), requires_grad=True).to(device) kernel = torch.ones((3, 1, 3, 3), requires_grad=True).to(device) y = torch.nn.functional.conv2d(input=img, weight=kernel, groups=3) return y @unittest.skipIf( version.parse(torch_original.__version__) <= version.parse("1.13.0"), "conv module don't support unbatched input in PyTorch before '1.13.0'", ) @autotest(n=3) def test_nn_functional_conv2d_3dinput(test_case): device = random_device() img = torch.ones((3, 224, 224), requires_grad=True).to(device) kernel = torch.ones((3, 1, 3, 3), requires_grad=True).to(device) y = torch.nn.functional.conv2d(input=img, weight=kernel, groups=3) return y def test_conv2d(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 3, (3, 3), bias=False).to(flow.device(device)) _test_conv2d( test_case, conv, test_conv2d_data, test_conv2d_weight, test_conv2d_output, device=device, ) def test_conv2d_backward(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] os.environ["ONEFLOW_ENABLE_NHWC"] = "0" for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 3, (3, 3), bias=False).to(flow.device(device)) _test_conv2d_backward( test_case, conv, test_conv2d_data, test_conv2d_weight, test_conv2d_data_grad, test_conv2d_weight_grad, device=device, ) # bias grad not yet supported def test_conv2d_with_bias(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 3, (3, 3), bias=True).to(flow.device(device)) _test_conv2d( test_case, conv, test_conv2d_with_bias_data, test_conv2d_with_bias_weight, test_conv2d_with_bias_output, bias=test_conv2d_with_bias_bias, device=device, ) def test_conv2d_group(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(2, 2, (3, 3), groups=2, bias=False).to( flow.device(device) ) _test_conv2d( test_case, conv, test_conv2d_group_data, test_conv2d_group_weight, test_conv2d_group_output, device=device, ) def test_conv2d_group_backward(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(2, 2, (3, 3), groups=2, bias=False).to( flow.device(device) ) _test_conv2d_backward( test_case, conv, test_conv2d_group_data, test_conv2d_group_weight, test_conv2d_group_data_grad, test_conv2d_group_weight_grad, device=device, ) def test_conv2d_padding(test_case): arg_dict 
= OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False).to( flow.device(device) ) _test_conv2d( test_case, conv, test_conv2d_padding_data, test_conv2d_padding_weight, test_conv2d_padding_output, device=device, ) def test_conv2d_padding_backward(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False).to( flow.device(device) ) _test_conv2d_backward( test_case, conv, test_conv2d_padding_data, test_conv2d_padding_weight, test_conv2d_padding_data_grad, test_conv2d_padding_weight_grad, device=device, weight_atol=1e-3, ) def test_conv2d_stride(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d( 1, 1, (3, 3), padding=(1, 1), stride=(2, 3), bias=False ).to(flow.device(device)) _test_conv2d( test_case, conv, test_conv2d_stride_data, test_conv2d_stride_weight, test_conv2d_stride_output, device=device, ) def test_conv2d_stride_backward(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d( 1, 1, (3, 3), padding=(1, 1), stride=(2, 3), bias=False ).to(flow.device(device)) _test_conv2d_backward( test_case, conv, test_conv2d_stride_data, test_conv2d_stride_weight, test_conv2d_stride_data_grad, test_conv2d_stride_weight_grad, device=device, ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_conv2d_kernel(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 1, (3, 5), bias=False).to(flow.device(device)) conv.to(flow.device("cuda")) _test_conv2d( test_case, conv, test_conv2d_kernel_data, test_conv2d_kernel_weight, test_conv2d_kernel_output, device=device, ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_conv2d_kernel_backward(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 1, (3, 5), bias=False).to(flow.device(device)) conv.to(flow.device("cuda")) _test_conv2d_backward( test_case, conv, test_conv2d_kernel_data, test_conv2d_kernel_weight, test_conv2d_kernel_data_grad, test_conv2d_kernel_weight_grad, device=device, ) def test_conv2d_dilation(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False).to( flow.device(device) ) _test_conv2d( test_case, conv, test_conv2d_dilation_data, test_conv2d_dilation_weight, test_conv2d_dilation_output, device=device, ) def test_conv2d_dilation_backward(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): device = arg[0] conv = flow.nn.Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False).to( flow.device(device) ) _test_conv2d_backward( test_case, conv, test_conv2d_dilation_data, test_conv2d_dilation_weight, test_conv2d_dilation_data_grad, test_conv2d_dilation_weight_grad, device=device, ) def test_large_in_channel_group_conv(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_conv2d_large_in_channel, ] arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) 
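# Illustrative sketch (added, not part of the original test file): a minimal NumPy
# reference for the grouped Conv2d forward that the group tests above exercise.
# Assumes stride=1, padding=0, dilation=1, no bias; the helper name
# `np_grouped_conv2d` is hypothetical.
#
# import numpy as np
#
# def np_grouped_conv2d(x, w, groups):
#     """x: (N, C_in, H, W), w: (C_out, C_in // groups, kH, kW)."""
#     n, c_in, h, wd = x.shape
#     c_out, c_in_g, kh, kw = w.shape
#     assert c_in == c_in_g * groups and c_out % groups == 0
#     c_out_g = c_out // groups
#     out = np.zeros((n, c_out, h - kh + 1, wd - kw + 1), dtype=x.dtype)
#     for g in range(groups):
#         xg = x[:, g * c_in_g:(g + 1) * c_in_g]   # input channels of this group
#         wg = w[g * c_out_g:(g + 1) * c_out_g]    # filters of this group
#         for i in range(out.shape[2]):
#             for j in range(out.shape[3]):
#                 patch = xg[:, :, i:i + kh, j:j + kw]   # (N, C_in/groups, kH, kW)
#                 # contract channel and kernel dims against every filter of the group
#                 out[:, g * c_out_g:(g + 1) * c_out_g, i, j] = np.tensordot(
#                     patch, wg, axes=([1, 2, 3], [1, 2, 3])
#                 )
#     return out
#
# Run on the same data/weights as, e.g., flow.nn.Conv2d(4, 2, 3, groups=2, bias=False)
# above, this reference should agree with the module output to within the 1e-3
# tolerances used by the group tests.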
def test_large_out_channel_group_conv(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_conv2d_large_out_channel, ] arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, rtol=1e-2, atol=1e-2) def test_conv2d_with_random_data(test_case): channels = random(1, 6) m = torch.nn.Conv2d( in_channels=channels, out_channels=random(1, 20), kernel_size=random(1, 4), stride=random(1, 4) | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 3) | nothing(), groups=random(1, 5) | nothing(), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim1=channels).to(device) y = m(x) return y @unittest.skipIf( version.parse(torch_original.__version__) <= version.parse("1.13.0"), "conv module don't support unbatched input in PyTorch before '1.13.0'", ) @autotest(n=5, rtol=1e-3, atol=1e-3) def test_conv2d_auto_squeeze_with_random_data(test_case): channels = random(1, 6) m = torch.nn.Conv2d( in_channels=channels, out_channels=random(1, 20), kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 5) | nothing(), padding_mode=constant("zeros") | nothing(), bias=random_bool(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim0=channels).to(device) y = m(x) return y @autotest(n=5, check_graph=False) def test_conv2d_0size_with_random_data(test_case): channels = random(1, 6) m = torch.nn.Conv2d( in_channels=channels, out_channels=random(1, 20), kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 5) | nothing(), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim0=0, dim1=channels).to(device) y = m(x) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5, check_allclose=False) def test_conv2d_group_with_random_data(test_case): channels = 720 # lcm(1, 2, 3, 4, 5, 6) m = torch.nn.Conv2d( in_channels=channels, out_channels=channels, kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 7), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) m.pytorch.to("cuda") x = random_tensor(ndim=4, dim1=channels).to(device) x.pytorch = x.pytorch.to("cuda") y = m(x) return y @unittest.skip("skip for now, becase it failed 6 times in past week") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_conv2d_NHWC_with_random_data(test_case): in_channels = np.random.randint(6, 33) out_channels = np.random.randint(32, 66) kernel_size = np.random.randint(1, 5) stride = np.random.randint(1, 2) padding = np.random.randint(1, 3) dilation = np.random.randint(1, 3) spatial = np.random.randint(6, 64) np_x = np.random.randn(4, in_channels, spatial, spatial).astype(np.float32) np_weight = np.random.randn( out_channels, in_channels, kernel_size, kernel_size ).astype(np.float32) np_bias = np.random.randn(out_channels).astype(np.float32) flow_nchw_input = flow.tensor( np_x, device="cuda", dtype=flow.float32, requires_grad=True ) flow_nchw_weights = flow.nn.Parameter( flow.tensor( np_weight, device="cuda", dtype=flow.float32, requires_grad=True ) ) flow_nchw_bias = 
flow.nn.Parameter( flow.tensor(np_bias, device="cuda", dtype=flow.float32, requires_grad=True) ) flow_nchw_conv = flow.nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ).to("cuda") flow_nchw_conv.weight = flow_nchw_weights flow_nchw_conv.bias = flow_nchw_bias flow_nchw_out = flow_nchw_conv(flow_nchw_input) os.environ["ONEFLOW_ENABLE_NHWC"] = "1" flow_nhwc_input = flow.tensor( np_x, device="cuda", dtype=flow.float32, requires_grad=True ) flow_nhwc_permuted_input = flow.permute(flow_nhwc_input, (0, 2, 3, 1)) flow_nhwc_weights = flow.tensor( np_weight, device="cuda", dtype=flow.float32, requires_grad=True ) flow_nhwc_permuted_weights = flow.nn.Parameter( flow.permute(flow_nhwc_weights, (0, 2, 3, 1)) ) flow_nhwc_bias = flow.nn.Parameter( flow.tensor(np_bias, device="cuda", dtype=flow.float32, requires_grad=True) ) flow_nhwc_conv = flow.nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ).to("cuda") flow_nhwc_conv.weight = flow_nhwc_permuted_weights flow_nhwc_conv.bias = flow_nhwc_bias flow_nhwc_out = flow_nhwc_conv(flow_nhwc_permuted_input) flow_nhwc_permuted_out = flow.permute(flow_nhwc_out, (0, 3, 1, 2)) test_case.assertTrue( np.allclose( flow_nchw_out.numpy(), flow_nhwc_permuted_out.numpy(), rtol=1e-4, atol=1e-4, ) ) total_out = flow_nchw_out + flow_nhwc_permuted_out total_out = total_out.sum() total_out.backward() test_case.assertTrue( np.allclose( flow_nchw_weights.grad.numpy(), np.transpose(flow_nhwc_permuted_weights.grad.numpy(), (0, 3, 1, 2)), rtol=1e-3, atol=1e-4, ) ) test_case.assertTrue( np.allclose( flow_nchw_input.grad.numpy(), flow_nhwc_input.grad.numpy(), rtol=1e-4, atol=1e-4, ) ) os.environ["ONEFLOW_ENABLE_NHWC"] = "0" @profile(torch.nn.functional.conv2d) def profile_conv2d(test_case): input = torch.ones(8, 128, 28, 28) weight_128c = torch.ones(128, 128, 3, 3) weight_128c_2g = torch.ones(128, 64, 3, 3) weight_1x1_128c = torch.ones(128, 128, 1, 1) weight_5x5_128c = torch.ones(128, 128, 5, 5) bias = torch.ones(128) torch.nn.functional.conv2d(input, weight_128c, padding=1) torch.nn.functional.conv2d(input, weight_128c_2g, groups=2, padding=1) torch.nn.functional.conv2d(input, weight_128c, padding=1, stride=2) torch.nn.functional.conv2d(input, weight_128c, bias=bias, padding=1) torch.nn.functional.conv2d(input, weight_128c, bias=bias, padding=1, stride=2) torch.nn.functional.conv2d(input, weight_1x1_128c) torch.nn.functional.conv2d(input, weight_1x1_128c, stride=2) torch.nn.functional.conv2d(input, weight_1x1_128c, bias=bias) torch.nn.functional.conv2d(input, weight_1x1_128c, bias=bias, stride=2) torch.nn.functional.conv2d(input, weight_5x5_128c, padding=2) torch.nn.functional.conv2d(input, weight_5x5_128c, padding=2, stride=2) torch.nn.functional.conv2d(input, weight_5x5_128c, bias=bias, padding=2) torch.nn.functional.conv2d( input, weight_5x5_128c, bias=bias, padding=2, stride=2 ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_copy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import torch as ori_torch import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class Test_Copy_module(flow.unittest.TestCase): def test_copy_broadcast_tensor(test_case): torch_base_grid = ori_torch.zeros(1, 2, 2, 3) flow_base_grid = flow.zeros(1, 2, 2, 3) torch_x_grid = ori_torch.ones(2) flow_x_grid = flow.ones(2) torch_base_grid[..., 0].copy_(torch_x_grid) flow_base_grid[..., 0].copy_(flow_x_grid) test_case.assertTrue( np.allclose(torch_base_grid.numpy(), flow_base_grid.numpy()) ) def test_non_contiguous_sliced_tensor_copy(test_case): torch_tensor = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4) flow_tensor = flow.arange(24, dtype=flow.float32).reshape(1, 2, 3, 4) torch_copy = torch.tensor([3.1415]) flow_copy = flow.tensor([3.1415]) torch_tensor[:, 1:2, 1:2, ::2].copy_(torch_copy) flow_tensor[:, 1:2, 1:2, ::2].copy_(flow_copy) test_case.assertTrue(np.allclose(flow_tensor.numpy(), torch_tensor.numpy())) def test_non_contiguous_permuted_tensor_copy(test_case): torch_tensor = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4) flow_tensor = flow.arange(24, dtype=flow.float32).reshape(1, 2, 3, 4) torch_copy = torch.tensor([3.1415]) flow_copy = flow.tensor([3.1415]) torch_tensor.permute(0, 2, 1, 3).copy_(torch_copy) flow_tensor.permute(0, 2, 1, 3).copy_(flow_copy) test_case.assertTrue(np.allclose(flow_tensor.numpy(), torch_tensor.numpy())) def test_copy_fp16(test_case): x = flow.tensor([1, 2], dtype=flow.float16) a = np.array([0, 9], dtype=np.float16) x.copy_(a) test_case.assertTrue(np.array_equal(x.numpy(), a)) def test_tensor_inplace_copy_with_diff_dtype(test_case): x = flow.randn(4, 12).to(flow.int) y = flow.randn(4, 12) y.copy_(x) test_case.assertTrue(np.array_equal(y.numpy(), x.numpy())) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_tensor_inplace_copy_with_diff_dtype_and_device(test_case): x = flow.randn(4, 12).to(flow.int) y = flow.randn(4, 12).to("cuda") y.copy_(x) test_case.assertTrue(np.array_equal(y.numpy(), x.numpy())) @unittest.skip("skip for now, becase it failed 6 times in past week") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_global_tensor_inplace_copy_with_diff_dtype_and_device(test_case): x = ( flow.randn(4, 12) .to(flow.int) .to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast) ) y = flow.randn(4, 12).to_global( placement=flow.placement.all("cuda"), sbp=flow.sbp.broadcast ) y.copy_(x) test_case.assertTrue(np.array_equal(y.numpy(), x.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_cosine_similarity.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestCosineSimilarity(flow.unittest.TestCase): @autotest(n=3) def test_cosine_similartiy_module_with_random_data(test_case): device = random_device() a = random_tensor(ndim=2, dim0=10, dim1=128).to(device) b = random_tensor(ndim=2, dim0=10, dim1=128).to(device) cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6).to(device) cos.train(random()) output = cos(a, b) return output @autotest(n=3) def test_cosine_similartiy_functional_with_random_data(test_case): device = random_device() a = random_tensor(ndim=2, dim0=10, dim1=128).to(device) b = random_tensor(ndim=2, dim0=10, dim1=128).to(device) output = torch.nn.functional.cosine_similarity(a, b, dim=1, eps=1e-6) return output @unittest.skip("skip for now, becase it failed 4 times in past week") @autotest(n=3) def test_cosine_similartiy_broadcast_with_random_data(test_case): device = random_device() a = random_tensor(ndim=2, dim0=10, dim1=128).to(device) b = random_tensor(ndim=2, dim0=1, dim1=128).to(device) output = torch.nn.functional.cosine_similarity(a, b, dim=1, eps=1e-6) return output @autotest(n=3) def test_cosine_similartiy_module_with_nonequal_dim_data(test_case): device = random_device() a = random_tensor(ndim=2, dim0=10, dim1=128).to(device) b = random_tensor(ndim=3, dim0=10, dim1=10, dim2=128).to(device) cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6).to(device) cos.train(random()) output = cos(a, b) return output @unittest.skip( reason="https://github.com/Oneflow-Inc/oneflow/issues/8881#issuecomment-1229682453" ) @profile(torch.nn.functional.cosine_similarity) def profile_cosine_similarity(test_case): input1 = torch.ones(100, 128) input2 = torch.ones(100, 128) torch.nn.functional.cosine_similarity(input1, input2) torch.nn.functional.cosine_similarity(input1, input2, dim=0) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_ctc_greedy_decoder.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest ninf = -float("inf") def log_softmax(logits, axis=0): max_value = np.max(logits, axis, keepdims=True) exp = np.exp(logits - max_value) exp_sum = np.sum(exp, axis, keepdims=True) dist = exp / exp_sum return np.log(dist) def np_ctc_greedy_decoder(log_probs, input_lengths, merge_repeated=True): blank_label = log_probs.shape[2] - 1 decodes = np.zeros( (log_probs.shape[1], log_probs.shape[0]), dtype=input_lengths.dtype ) neg_sum_logits = np.zeros((input_lengths.size, 1), dtype=log_probs.dtype) for b in range(input_lengths.size): input_length = input_lengths[b] prev_indices = -1 t_dec = 0 for t in range(input_length): max_indice = np.argmax(log_probs[t, b, :]) neg_sum_logits[b, 0] -= log_probs[t, b, max_indice] if max_indice != blank_label and ( not (merge_repeated and max_indice == prev_indices) ): decodes[b, t_dec] = max_indice t_dec += 1 prev_indices = max_indice return (decodes, neg_sum_logits) def compare_with_np( device_type, data_type, max_input_length, batch_size, num_classes, merge_repeated, ): assert data_type in ["float32", "double"] assert device_type in ["cpu", "cuda"] assert merge_repeated in [False, True] log_probs = np.random.random( size=(max_input_length, batch_size, num_classes) ).astype(np.float32) log_probs = log_softmax(log_probs, axis=2) input_lengths = np.random.randint( max_input_length / 2, high=max_input_length, size=(batch_size,), dtype=np.int64 ) (np_decoded, np_neg_sum_logits) = np_ctc_greedy_decoder( log_probs, input_lengths, merge_repeated ) log_probs = flow.tensor( log_probs, dtype=flow.float32, requires_grad=False, device=flow.device(device_type), ) input_lengths = flow.tensor( input_lengths, dtype=flow.int64, requires_grad=False, device=flow.device(device_type), ) (of_decoded, of_neg_sum_logits) = flow.nn.functional.ctc_greedy_decoder( log_probs, input_lengths, merge_repeated ) np.allclose(of_decoded.numpy(), np_decoded, atol=1e-05) np.allclose(of_neg_sum_logits.numpy(), np_neg_sum_logits, atol=1e-05) def gen_arg_list(): arg_dict = OrderedDict() arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["data_type"] = ["float32"] arg_dict["max_input_length"] = [20] arg_dict["batch_size"] = [4] arg_dict["num_classes"] = [5] arg_dict["merge_repeated"] = [False, True] return GenArgList(arg_dict) @flow.unittest.skip_unless_1n1d() class TestCTCGreedyDecoder1n1d(flow.unittest.TestCase): def test_ctc_greedy_decoder(test_case): for arg in gen_arg_list(): compare_with_np(*arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_ctc_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestCTCLoss1n1d(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") # This test case can always success out of ci container, but will get error in ci container for unknown reason: error: # 'oneflow.ctc_loss' op attribute 'blank' failed to satisfy constraint: 32-bit signed integer attribute # loc("-":0:0): error: Failed to run round-trip passes @autotest(n=5, check_graph=False) def test_ctc_loss_with_diff_device_input(test_case): log_probs = torch.tensor( [ [[-1.1031, -0.7998, -1.5200], [-0.9808, -1.1363, -1.1908]], [[-1.2258, -1.0665, -1.0153], [-1.1135, -1.2331, -0.9671]], [[-1.3348, -0.6611, -1.5118], [-0.9823, -1.2355, -1.0941]], [[-1.3850, -1.3273, -0.7247], [-0.8235, -1.4783, -1.0994]], [[-0.9049, -0.8867, -1.6962], [-1.4938, -1.3630, -0.6547]], ], dtype=torch.float32, requires_grad=True, ) targets = torch.tensor([[1, 2, 2], [1, 2, 2]], dtype=torch.int32, device="cuda") input_lengths = torch.tensor([5, 5], dtype=torch.int32) target_lengths = torch.tensor([3, 3], dtype=torch.int32) loss_mean = torch.nn.CTCLoss(reduction=oneof("mean", "none", "sum", nothing())) out = loss_mean(log_probs, targets, input_lengths, target_lengths) return out @unittest.skip("skip for now, becase it failed 10 times in past week") @autotest(n=5, check_graph=False) def test_ctc_loss_functional(test_case): device_random = random_device() log_probs = random_tensor(ndim=3, dim0=5, dim1=2, dim2=3).to(device_random) targets = random_tensor(ndim=2, dim0=2, dim1=3, low=1, high=3, dtype=int).to( device_random ) input_lengths = torch.tensor([5, 5], dtype=torch.int32) target_lengths = torch.tensor([3, 3], dtype=torch.int32) out = torch.nn.functional.ctc_loss( log_probs, targets, input_lengths, target_lengths, reduction=oneof("mean", "none", "sum", nothing()), ) return out if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_cublas_fused_mlp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow def _matmul_bias_relu(x, weight, bias, skip_activation): out = flow._C.bias_add(flow._C.matmul(x, weight, transpose_b=True), bias, axis=1) if not skip_activation: out = flow._C.relu(out) return out def _test_fused_matmul_bias_add_relu( test_case, batchsize, in_feature, hidden_size_list, out_feature, skip_final_activation, dtype, device, ): x = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature)) fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) naive_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_weight_list = [] naive_weight_list = [] fused_bias_list = [] naive_bias_list = [] hidden_num = len(hidden_size_list) if hidden_num != 0: np_first_weight = np.random.uniform( low=-1, high=1, size=(hidden_size_list[0], in_feature) ) np_first_bias = np.random.uniform(low=-1, high=1, size=hidden_size_list[0]) fused_weight_list.append( flow.tensor(np_first_weight, dtype=dtype, device=device, requires_grad=True) ) fused_bias_list.append( flow.tensor(np_first_bias, dtype=dtype, device=device, requires_grad=True) ) naive_weight_list.append( flow.tensor(np_first_weight, dtype=dtype, device=device, requires_grad=True) ) naive_bias_list.append( flow.tensor(np_first_bias, dtype=dtype, device=device, requires_grad=True) ) for idx in range(1, hidden_num): np_weight = np.random.uniform( low=-1, high=1, size=(hidden_size_list[idx], hidden_size_list[idx - 1]) ) np_bias = np.random.uniform(low=-1, high=1, size=hidden_size_list[idx]) fused_weight_list.append( flow.tensor(np_weight, dtype=dtype, device=device, requires_grad=True) ) fused_bias_list.append( flow.tensor(np_bias, dtype=dtype, device=device, requires_grad=True) ) naive_weight_list.append( flow.tensor(np_weight, dtype=dtype, device=device, requires_grad=True) ) naive_bias_list.append( flow.tensor(np_bias, dtype=dtype, device=device, requires_grad=True) ) np_final_weight = np.random.uniform(low=-1, high=1, size=(out_feature, in_feature)) if hidden_num != 0: np_final_weight = np.random.uniform( low=-1, high=1, size=(out_feature, hidden_size_list[-1]) ) np_final_bias = np.random.uniform(low=-1, high=1, size=(out_feature)) fused_weight_list.append( flow.tensor(np_final_weight, dtype=dtype, device=device, requires_grad=True) ) fused_bias_list.append( flow.tensor(np_final_bias, dtype=dtype, device=device, requires_grad=True) ) naive_weight_list.append( flow.tensor(np_final_weight, dtype=dtype, device=device, requires_grad=True) ) naive_bias_list.append( flow.tensor(np_final_bias, dtype=dtype, device=device, requires_grad=True) ) fused_out = flow._C.fused_mlp( fused_x, fused_weight_list, fused_bias_list, skip_final_activation=skip_final_activation, ) naive_out = _matmul_bias_relu( naive_x, naive_weight_list[0], naive_bias_list[0], False if hidden_num != 0 else skip_final_activation, ) for idx in range(1, hidden_num + 1): if idx == hidden_num: naive_out = _matmul_bias_relu( naive_out, naive_weight_list[idx], naive_bias_list[idx], skip_final_activation, ) else: naive_out = _matmul_bias_relu( naive_out, naive_weight_list[idx], naive_bias_list[idx], False ) total_out = fused_out.sum() + naive_out.sum() total_out.backward() # Test output equality test_case.assertTrue( np.allclose(fused_out.numpy(), naive_out.numpy(), atol=1e-4, rtol=1e-4) ) # Test weight grad equality for idx in range(hidden_num + 1): test_case.assertTrue( np.allclose( 
fused_weight_list[idx].grad.numpy(), naive_weight_list[idx].grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bias_list[idx].grad.numpy(), naive_bias_list[idx].grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # Test dx equality test_case.assertTrue( np.allclose(fused_x.grad.numpy(), naive_x.grad.numpy(), atol=1e-4, rtol=1e-4) ) @flow.unittest.skip_unless_1n1d() class TestFusedMatmulBiasAddRelu(flow.unittest.TestCase): def test_fused_matmul_op(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_fused_matmul_bias_add_relu] args_dict["batchsize"] = [1, 2, 4] args_dict["in_feature"] = [96, 128] args_dict["hidden_size_list"] = [[256, 512], [256], [96, 144], []] args_dict["out_feature"] = [512, 1024, 288, 1] args_dict["skip_final_activation"] = [True, False] args_dict["dtype"] = [flow.float32, flow.float64] args_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_cum_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest import torch as ori_torch from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestCumOp(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_cumsum(test_case): device = random_device() x = random_tensor().to(device) dim = random(0, x.ndim.pytorch).to(int) z = torch.cumsum(x, dim) return z @autotest(n=5, check_graph=True) def test_cumprod(test_case): device = random_device() x = random_tensor().to(device) dim = random(0, x.ndim.pytorch).to(int) y = torch.cumprod(x, dim) return y def test_cumop_with_dtype(test_case): x = flow.tensor([2, 3, 4]) cumsum_res = flow.cumsum(x, dim=0, dtype=flow.float) cumprod_res = flow.cumprod(x, dim=0, dtype=flow.float) test_case.assertEqual(cumsum_res.dtype, flow.float) test_case.assertEqual(cumprod_res.dtype, flow.float) @autotest(n=5, check_graph=True) def test_cumsum(test_case): device = random_device() x = random_tensor().to(device) dim = random(0, x.ndim.pytorch).to(int) y = x.cumsum(dim) return y @autotest(n=5, check_graph=True) def test_cumprod_with_user_dy(test_case): device = random_device() x = random_tensor().to(device) dim = random(0, x.ndim.pytorch).to(int) y = torch.cumprod(x, dim) z = y * 2 return z def test_cumprod_with_zero(test_case): np_arr = np.ones((5, 5)) np_arr_grad = np_arr np_arr[2][3] = 0 np_arr[4][3] = 0 of_tensor = flow.tensor(np_arr, dtype=flow.float, requires_grad=True) of_res = of_tensor.cumprod(dim=0) of_res.backward(flow.tensor(np_arr_grad, dtype=flow.float)) torch_tensor = ori_torch.tensor( np_arr, dtype=ori_torch.float, requires_grad=True ) torch_res = torch_tensor.cumprod(dim=0) torch_res.backward(ori_torch.tensor(np_arr_grad, dtype=ori_torch.float)) 
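# Descriptive note (added comment): the zeros planted at np_arr[2][3] and np_arr[4][3]
# make every downstream cumprod entry in those columns zero, so the backward pass
# cannot be computed by simply dividing the output by the input. The check below
# compares OneFlow's cumprod gradient against PyTorch's for this zero-containing case.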
test_case.assertTrue( np.allclose( of_tensor.grad.numpy(), torch_tensor.grad.numpy(), rtol=0.0001, atol=1e-05, ) ) def test_cumsum_graph_backward(test_case): class CustomizedModule(flow.nn.Module): def __init__(self): super().__init__() self.layer = flow.nn.Linear(5, 5) def forward(self, input): layer_out = self.layer(input) loss = flow.cumsum(layer_out, -1) loss = loss.sum() loss.backward() return loss class TestCumsum(flow.nn.Graph): def __init__(self) -> None: super().__init__() self.my_module = CustomizedModule() self.add_optimizer( flow.optim.SGD(self.my_module.parameters(), lr=0.1, momentum=0.0) ) def build(self, ids): loss = self.my_module(ids) return loss ids = np.random.randint(0, 10, (5, 5), dtype=np.int64) ids_tensor = flow.tensor(ids, dtype=flow.float, requires_grad=False) graph = TestCumsum() loss = graph(ids_tensor) @profile(torch.cumsum) def profile_cumsum(test_case): input = torch.ones(100, 1280) torch.cumsum(input, dim=0) torch.cumsum(input, dim=1) @profile(torch.cumprod) def profile_cumprod(test_case): input = torch.ones(100, 1280) torch.cumprod(input, dim=0) torch.cumprod(input, dim=1) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import math import os import unittest import cv2 import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestOFRecordModule(flow.unittest.TestCase): def test_record(test_case): batch_size = 1 color_space = "RGB" height = 224 width = 224 output_layout = "NCHW" rgb_mean = [123.68, 116.779, 103.939] rgb_std = [58.393, 57.12, 57.375] record_reader = flow.nn.OFRecordReader( flow.unittest.dataset_dir("imagenette/ofrecord"), batch_size=batch_size, data_part_num=1, part_name_suffix_length=5, shuffle_after_epoch=False, ) record_image_decoder = flow.nn.OFRecordImageDecoder( "encoded", color_space=color_space ) record_label_decoder = flow.nn.OFRecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) resize = flow.nn.image.Resize( resize_side="shorter", keep_aspect_ratio=True, target_size=256 ) crop_mirror_normal = flow.nn.CropMirrorNormalize( color_space=color_space, output_layout=output_layout, crop_h=height, crop_w=width, crop_pos_y=0.5, crop_pos_x=0.5, mean=rgb_mean, std=rgb_std, output_dtype=flow.float, ) val_record = record_reader() label = record_label_decoder(val_record) image_raw_buffer = record_image_decoder(val_record) image_raw_buffer_nd = image_raw_buffer.numpy() gt_np = cv2.imread( flow.unittest.dataset_dir("imagenette/ofrecord/gt_tensor_buffer_image.png") ) test_case.assertTrue(np.array_equal(image_raw_buffer_nd[0], gt_np)) image = resize(image_raw_buffer)[0] resized_image_raw_buffer_nd = image.numpy() gt_np = cv2.imread( flow.unittest.dataset_dir( "imagenette/ofrecord/gt_tensor_buffer_resized_image.png" ) ) test_case.assertTrue(np.array_equal(resized_image_raw_buffer_nd[0], gt_np)) image = crop_mirror_normal(image) image_np = image.numpy() image_np = np.squeeze(image_np) image_np = np.transpose(image_np, (1, 2, 0)) image_np = image_np * rgb_std + rgb_mean image_np = cv2.cvtColor(np.float32(image_np), cv2.COLOR_RGB2BGR) image_np = image_np.astype(np.uint8) gt_np = cv2.imread( flow.unittest.dataset_dir("imagenette/ofrecord/gt_val_image.png") ) test_case.assertEqual(label.numpy(), 5) test_case.assertTrue(np.array_equal(image_np, gt_np)) @flow.unittest.skip_unless_1n1d() class TestGlobalOFRecordModule(flow.unittest.TestCase): def test_global_record(test_case): batch_size = 1 color_space = "RGB" height = 224 width = 224 output_layout = "NCHW" rgb_mean = [123.68, 116.779, 103.939] rgb_std = [58.393, 57.12, 57.375] record_reader = flow.nn.OfrecordReader( flow.unittest.dataset_dir("imagenette/ofrecord"), batch_size=batch_size, data_part_num=1, part_name_suffix_length=5, shuffle_after_epoch=False, placement=flow.placement("cpu", ranks=[0]), sbp=[flow.sbp.split(0)], ) record_image_decoder = flow.nn.OFRecordImageDecoder( "encoded", color_space=color_space ) record_label_decoder = flow.nn.OfrecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) resize = flow.nn.image.Resize( resize_side="shorter", keep_aspect_ratio=True, target_size=256 ) flip = flow.nn.CoinFlip( batch_size=batch_size, placement=flow.placement("cpu", ranks=[0]), sbp=[flow.sbp.split(0)], ) crop_mirror_normal = flow.nn.CropMirrorNormalize( color_space=color_space, output_layout=output_layout, crop_h=height, crop_w=width, crop_pos_y=0.5, crop_pos_x=0.5, mean=rgb_mean, std=rgb_std, output_dtype=flow.float, ) rng = flip() val_record = record_reader() label = record_label_decoder(val_record) image_raw_buffer = record_image_decoder(val_record) image_raw_buffer_nd = image_raw_buffer.to_local().numpy() gt_np = cv2.imread( 
flow.unittest.dataset_dir("imagenette/ofrecord/gt_tensor_buffer_image.png") ) test_case.assertTrue(np.array_equal(image_raw_buffer_nd[0], gt_np)) image = resize(image_raw_buffer)[0] resized_image_raw_buffer_nd = image.to_local().numpy() gt_np = cv2.imread( flow.unittest.dataset_dir( "imagenette/ofrecord/gt_tensor_buffer_resized_image.png" ) ) test_case.assertTrue(np.array_equal(resized_image_raw_buffer_nd[0], gt_np)) image = crop_mirror_normal(image) image_np = image.to_local().numpy() image_np = np.squeeze(image_np) image_np = np.transpose(image_np, (1, 2, 0)) image_np = image_np * rgb_std + rgb_mean image_np = cv2.cvtColor(np.float32(image_np), cv2.COLOR_RGB2BGR) image_np = image_np.astype(np.uint8) gt_np = cv2.imread( flow.unittest.dataset_dir("imagenette/ofrecord/gt_val_image.png") ) test_case.assertEqual(label.to_local().numpy(), 5) test_case.assertTrue(np.array_equal(image_np, gt_np)) coco_dict = dict() def _coco(anno_file): global coco_dict if anno_file not in coco_dict: from pycocotools.coco import COCO coco_dict[anno_file] = COCO(anno_file) return coco_dict[anno_file] def _get_coco_image_samples(anno_file, image_dir, image_ids): coco = _coco(anno_file) category_id_to_contiguous_id_map = _get_category_id_to_contiguous_id_map(coco) (image, image_size) = _read_images_with_cv(coco, image_dir, image_ids) bbox = _read_bbox(coco, image_ids) label = _read_label(coco, image_ids, category_id_to_contiguous_id_map) img_segm_poly_list = _read_segm_poly(coco, image_ids) (poly, poly_index) = _segm_poly_list_to_tensor(img_segm_poly_list) samples = [] for (im, ims, b, l, p, pi) in zip(image, image_size, bbox, label, poly, poly_index): samples.append( dict(image=im, image_size=ims, bbox=b, label=l, poly=p, poly_index=pi) ) return samples def _get_category_id_to_contiguous_id_map(coco): return {v: i + 1 for (i, v) in enumerate(coco.getCatIds())} def _read_images_with_cv(coco, image_dir, image_ids): image_files = [ os.path.join(image_dir, coco.imgs[img_id]["file_name"]) for img_id in image_ids ] image_size = [ (coco.imgs[img_id]["height"], coco.imgs[img_id]["width"]) for img_id in image_ids ] return ( [cv2.imread(image_file).astype(np.single) for image_file in image_files], image_size, ) def _bbox_convert_from_xywh_to_xyxy(bbox, image_h, image_w): (x, y, w, h) = bbox (x1, y1) = (x, y) x2 = x1 + max(w - 1, 0) y2 = y1 + max(h - 1, 0) x1 = min(max(x1, 0), image_w - 1) y1 = min(max(y1, 0), image_h - 1) x2 = min(max(x2, 0), image_w - 1) y2 = min(max(y2, 0), image_h - 1) if x1 >= x2 or y1 >= y2: return None return [x1, y1, x2, y2] def _read_bbox(coco, image_ids): img_bbox_list = [] for img_id in image_ids: anno_ids = coco.getAnnIds(imgIds=[img_id]) assert len(anno_ids) > 0, "image with id {} has no anno".format(img_id) image_h = coco.imgs[img_id]["height"] image_w = coco.imgs[img_id]["width"] bbox_list = [] for anno_id in anno_ids: anno = coco.anns[anno_id] if anno["iscrowd"] != 0: continue bbox = anno["bbox"] assert isinstance(bbox, list) bbox_ = _bbox_convert_from_xywh_to_xyxy(bbox, image_h, image_w) if bbox_ is not None: bbox_list.append(bbox_) bbox_array = np.array(bbox_list, dtype=np.single) img_bbox_list.append(bbox_array) return img_bbox_list def _read_label(coco, image_ids, category_id_to_contiguous_id_map): img_label_list = [] for img_id in image_ids: anno_ids = coco.getAnnIds(imgIds=[img_id]) assert len(anno_ids) > 0, "image with id {} has no anno".format(img_id) label_list = [] for anno_id in anno_ids: anno = coco.anns[anno_id] if anno["iscrowd"] != 0: continue cate_id = anno["category_id"] 
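# NOTE (added comment): the bare `isinstance(cate_id, int)` on the next line has no
# effect as written (an `assert isinstance(...)` was presumably intended). The COCO
# category ids are then remapped to contiguous labels starting at 1 via the
# `category_id_to_contiguous_id_map` built above.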
isinstance(cate_id, int) label_list.append(category_id_to_contiguous_id_map[cate_id]) label_array = np.array(label_list, dtype=np.int32) img_label_list.append(label_array) return img_label_list def _read_segm_poly(coco, image_ids): img_segm_poly_list = [] for img_id in image_ids: anno_ids = coco.getAnnIds(imgIds=[img_id]) assert len(anno_ids) > 0, "img {} has no anno".format(img_id) segm_poly_list = [] for anno_id in anno_ids: anno = coco.anns[anno_id] if anno["iscrowd"] != 0: continue segm = anno["segmentation"] assert isinstance(segm, list) assert len(segm) > 0, str(len(segm)) assert all([len(poly) > 0 for poly in segm]), str( [len(poly) for poly in segm] ) segm_poly_list.append(segm) img_segm_poly_list.append(segm_poly_list) return img_segm_poly_list def _segm_poly_list_to_tensor(img_segm_poly_list): poly_array_list = [] poly_index_array_list = [] for (img_idx, segm_poly_list) in enumerate(img_segm_poly_list): img_poly_elem_list = [] img_poly_index_list = [] for (obj_idx, poly_list) in enumerate(segm_poly_list): for (poly_idx, poly) in enumerate(poly_list): img_poly_elem_list.extend(poly) for (pt_idx, pt) in enumerate(poly): if pt_idx % 2 == 0: img_poly_index_list.append([pt_idx / 2, poly_idx, obj_idx]) img_poly_array = np.array(img_poly_elem_list, dtype=np.single).reshape(-1, 2) assert img_poly_array.size > 0, segm_poly_list poly_array_list.append(img_poly_array) img_poly_index_array = np.array(img_poly_index_list, dtype=np.int32) assert img_poly_index_array.size > 0, segm_poly_list poly_index_array_list.append(img_poly_index_array) return (poly_array_list, poly_index_array_list) @flow.unittest.skip_unless_1n1d() class TestCocoReader(flow.unittest.TestCase): def test_coco_reader(test_case): anno_file = flow.unittest.dataset_dir( "mscoco_2017/annotations/instances_val2017.json" ) image_dir = flow.unittest.dataset_dir("mscoco_2017/val2017") num_iterations = 10 coco_reader = flow.nn.COCOReader( annotation_file=anno_file, image_dir=image_dir, batch_size=2, shuffle=True, stride_partition=True, ) image_decoder = flow.nn.image.decode(dtype=flow.float) for i in range(num_iterations): ( image, image_id, image_size, gt_bbox, gt_label, gt_segm, gt_segm_index, ) = coco_reader() decoded_image = image_decoder(image) image_list = decoded_image.numpy() image_id = image_id.numpy() image_size = image_size.numpy() bbox_list = gt_bbox.numpy() label_list = gt_label.numpy() segm_list = gt_segm.numpy() segm_index_list = gt_segm_index.numpy() samples = _get_coco_image_samples(anno_file, image_dir, image_id) for (i, sample) in enumerate(samples): test_case.assertTrue(np.array_equal(image_list[i], sample["image"])) test_case.assertTrue( np.array_equal(image_size[i], sample["image_size"]) ) test_case.assertTrue(np.allclose(bbox_list[i], sample["bbox"])) cur_label = label_list[i] if len(cur_label.shape) == 0: cur_label = np.array([cur_label]) test_case.assertTrue(np.array_equal(cur_label, sample["label"])) test_case.assertTrue(np.allclose(segm_list[i], sample["poly"])) test_case.assertTrue( np.array_equal(segm_index_list[i], sample["poly_index"]) ) @flow.unittest.skip_unless_1n1d() class TestOFRecordBytesDecoder(flow.unittest.TestCase): def test_OFRecordBytesDecoder(test_case): batch_size = 16 record_reader = flow.nn.OFRecordReader( flow.unittest.dataset_dir("imagenette/ofrecord"), batch_size=batch_size, part_name_suffix_length=5, ) val_record = record_reader() bytesdecoder_img = flow.nn.OFRecordBytesDecoder("encoded") image_raw_buffer = bytesdecoder_img(val_record) image_raw_buffer_nd = 
image_raw_buffer.numpy()[0] gt_np = cv2.imread( flow.unittest.dataset_dir("imagenette/ofrecord/gt_tensor_buffer_image.png") ) img = cv2.imdecode(image_raw_buffer_nd, cv2.IMREAD_COLOR) test_case.assertTrue(np.array_equal(img, gt_np)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_ddp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow # Test import from oneflow.nn.parallel.distributed from oneflow.nn.parallel.distributed import DistributedDataParallel from oneflow.nn.parallel import DistributedDataParallel as ddp from oneflow.test_utils.test_util import GenCartesianProduct import oneflow.unittest import numpy as np import os def np_allclose_with_shape(a, b, *args, **kwargs): if a.shape != b.shape: return False return np.allclose(a, b, *args, **kwargs) test_device = ["cpu"] if os.getenv("ONEFLOW_TEST_CPU_ONLY") else ["cpu", "cuda"] @flow.unittest.skip_unless_1n2d() class TestDDP(flow.unittest.TestCase): def _test_ddp_basic(test_case, dev_type): class Mul(flow.nn.Module): def __init__(self): super().__init__() self.w = flow.nn.Parameter(flow.Tensor([1, 1])) def forward(self, x): return x * self.w rank = flow.env.get_rank() if rank == 0: x = flow.Tensor([1, 1]) elif rank == 1: x = flow.Tensor([2, 2]) else: raise ValueError() x = x.to(dev_type) m = Mul().to(dev_type) m = ddp(m) y = m(x) y.sum().backward() test_case.assertTrue( np_allclose_with_shape(m.w.grad.numpy(), np.array([1.5, 1.5])) ) def test_ddp_basic(test_case): for dev_type in test_device: test_case._test_ddp_basic(dev_type) def _test_ddp_multiple_buckets(test_case, dev_type, use_bucket): class Mul(flow.nn.Module): def __init__(self): super().__init__() for i in range(10): self.register_parameter( f"w{i}", flow.nn.Parameter(flow.Tensor([i % 2 + 1, i % 2 + 1])) ) def forward(self, x): for i in range(10): x = x * getattr(self, f"w{i}") return x rank = flow.env.get_rank() if rank == 0: x = flow.Tensor([1, 1]) elif rank == 1: x = flow.Tensor([2, 2]) else: raise ValueError() x = x.to(dev_type) m = Mul().to(dev_type) m = ddp(m, bucket_size=3, use_bucket=use_bucket) y = m(x) y.sum().backward() for i in range(10): test_case.assertTrue( np_allclose_with_shape( getattr(m, f"w{i}").grad.numpy(), np.array([48, 48]) if i % 2 == 0 else np.array([24, 24]), ) ) def test_ddp_multiple_buckets(test_case): for dev_type, use_bucket in GenCartesianProduct((test_device, [True, False])): test_case._test_ddp_multiple_buckets(dev_type, use_bucket) def _test_ddp_with_unused_param(test_case, dev_type): class Model(flow.nn.Module): def __init__(self): super().__init__() self.w = flow.nn.Parameter(flow.Tensor([1])) self.used_only_in_rank0 = flow.nn.Parameter(flow.Tensor([2])) self.unused_in_all_ranks = flow.nn.Parameter(flow.Tensor([3])) def forward(self, x): x = x * self.w if flow.env.get_rank() == 0: x = x * self.used_only_in_rank0 return x rank = flow.env.get_rank() if 
rank == 0: x = flow.Tensor([1]) elif rank == 1: x = flow.Tensor([2]) else: raise ValueError() x = x.to(dev_type) m = Model().to(dev_type) m = ddp(m, bucket_size=2) y = m(x) y.backward() test_case.assertTrue(np_allclose_with_shape(m.w.grad.numpy(), np.array([2]))) test_case.assertTrue( np_allclose_with_shape(m.used_only_in_rank0.grad.numpy(), np.array([0.5])) ) test_case.assertTrue( np_allclose_with_shape(m.unused_in_all_ranks.grad.numpy(), np.array([0])) ) def test_ddp_with_unused_param(test_case): for dev_type in test_device: test_case._test_ddp_with_unused_param(dev_type) def _test_out_of_order_execution(test_case, dev_type): class Model(flow.nn.Module): def __init__(self): super().__init__() self.w1 = flow.nn.Parameter(flow.Tensor([1])) self.w2 = flow.nn.Parameter(flow.Tensor([2])) self.w3 = flow.nn.Parameter(flow.Tensor([3])) def forward(self, x): if flow.env.get_rank() == 0: x *= self.w1 x *= self.w2 x *= self.w3 else: x *= self.w3 x *= self.w2 x *= self.w1 return x rank = flow.env.get_rank() if rank == 0: x = flow.Tensor([1]) elif rank == 1: x = flow.Tensor([2]) else: raise ValueError() x = x.to(dev_type) m = Model().to(dev_type) m = ddp(m, bucket_size=1) y = m(x) y.backward() test_case.assertTrue(np_allclose_with_shape(m.w1.grad.numpy(), np.array([9]))) test_case.assertTrue(np_allclose_with_shape(m.w2.grad.numpy(), np.array([4.5]))) test_case.assertTrue(np_allclose_with_shape(m.w3.grad.numpy(), np.array([3]))) def test_out_of_order_execution(test_case): for dev_type in test_device: test_case._test_out_of_order_execution(dev_type) def _test_ddp_with_partial_requires_grad_parameter(test_case, dev_type): class Model(flow.nn.Module): def __init__(self): super().__init__() self.w1 = flow.nn.Parameter(flow.Tensor([1]), requires_grad=False) self.w2 = flow.nn.Parameter(flow.Tensor([2])) self.w3 = flow.nn.Parameter(flow.Tensor([3])) def forward(self, x): if flow.env.get_rank() == 0: x *= self.w1 x *= self.w2 x *= self.w3 else: x *= self.w3 x *= self.w2 x *= self.w1 return x rank = flow.env.get_rank() if rank == 0: x = flow.Tensor([1]) elif rank == 1: x = flow.Tensor([2]) else: raise ValueError() x = x.to(dev_type) m = Model().to(dev_type) m = ddp(m, bucket_size=1) y = m(x) y.backward() test_case.assertTrue(np_allclose_with_shape(m.w2.grad.numpy(), np.array([4.5]))) test_case.assertTrue(np_allclose_with_shape(m.w3.grad.numpy(), np.array([3]))) def test_ddp_with_partial_requires_grad_parameter(test_case): for dev_type in test_device: test_case._test_ddp_with_partial_requires_grad_parameter(dev_type) def _test_ddp_two_iters(test_case, dev_type): class Mul(flow.nn.Module): def __init__(self): super().__init__() self.w = flow.nn.Parameter(flow.Tensor([1, 1])) def forward(self, x): return x * self.w rank = flow.env.get_rank() if rank == 0: x = flow.Tensor([1, 1]) elif rank == 1: x = flow.Tensor([2, 2]) else: raise ValueError() x = x.to(dev_type) m = Mul().to(dev_type) m = ddp(m) for _ in range(2): y = m(x) y.sum().backward() test_case.assertTrue(np_allclose_with_shape(m.w.grad.numpy(), np.array([3, 3]))) def test_ddp_two_iters(test_case): for dev_type in test_device: test_case._test_ddp_two_iters(dev_type) def _test_broadcast_buffer(test_case, dev_type): rank = flow.env.get_rank() class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.register_buffer("buf", flow.tensor([1, 2]) * (rank + 1)) def forward(self, x): res = self.buf + x self.buf.copy_(x) return res x = flow.tensor([2, 3]) * (rank + 1) x = x.to(dev_type) m = CustomModule() m = m.to(dev_type) m = ddp(m) y1 = m(x) 
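# NOTE: with the default broadcast_buffers=True, DistributedDataParallel
# re-broadcasts registered buffers from rank 0 to every rank before each
# forward pass, which is consistent with the values asserted below.
# Rough trace for rank 1 (x = [4, 6], its own buf starts at [2, 4]):
#   y1: buf is overwritten by rank 0's [1, 2]          -> [1, 2] + [4, 6] = [5, 8]
#   y2: rank 0's buf is now [2, 3] (copied from its x) -> [2, 3] + [4, 6] = [6, 9]
# The second module below is wrapped with broadcast_buffers=False, so each
# rank keeps updating its own buffer instead (hence y3 = [6, 10], y4 = [8, 12]).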
y2 = m(x) m = CustomModule() m = m.to(dev_type) m = ddp(m, broadcast_buffers=False) y3 = m(x) y4 = m(x) if rank == 0: test_case.assertTrue(np_allclose_with_shape(y1.numpy(), np.array([3, 5]))) test_case.assertTrue(np_allclose_with_shape(y2.numpy(), np.array([4, 6]))) test_case.assertTrue(np_allclose_with_shape(y3.numpy(), np.array([3, 5]))) test_case.assertTrue(np_allclose_with_shape(y4.numpy(), np.array([4, 6]))) elif rank == 1: test_case.assertTrue(np_allclose_with_shape(y1.numpy(), np.array([5, 8]))) test_case.assertTrue(np_allclose_with_shape(y2.numpy(), np.array([6, 9]))) test_case.assertTrue(np_allclose_with_shape(y3.numpy(), np.array([6, 10]))) test_case.assertTrue(np_allclose_with_shape(y4.numpy(), np.array([8, 12]))) else: raise ValueError() def test_broadcast_buffer(test_case): for dev_type in test_device: test_case._test_broadcast_buffer(dev_type) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_ddp_multi_outputs.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import oneflow as flow from oneflow.nn.parallel import DistributedDataParallel as ddp import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict train_x = [ flow.tensor([[1, 2], [2, 3]], dtype=flow.float32), flow.tensor([[4, 6], [3, 1]], dtype=flow.float32), ] train_float32 = [ flow.tensor([[1, 2], [2, 3]], dtype=flow.float32), flow.tensor([[4, 6], [3, 1]], dtype=flow.float32), ] train_int32 = [ flow.tensor([[8], [13]], dtype=flow.int32), flow.tensor([[26], [9]], dtype=flow.int32), ] class Model(flow.nn.Module): def __init__(self): super().__init__() self.lr = 0.01 self.iter_count = 10 self.w1 = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32)) self.w2 = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32)) def forward(self, x, label): if flow.env.get_rank() == 0: x1 = flow.matmul(x, self.w1) else: x1 = flow.matmul(x, self.w2) return ([x1, label + 1], label + 2) def train(test_case, train_x, device, output, requires_grad): m = Model().to(device) m = ddp(m) loss = flow.nn.MSELoss(reduction="sum") optimizer = flow.optim.SGD(m.parameters(), m.lr) for i in range(0, m.iter_count): rank = flow.env.get_rank() x = train_x[rank].clone().to(device) y = output[rank].clone().to(device) y.requires_grad = requires_grad (y_pred, y_add_1), y_add_2 = m(x, y) test_case.assertEqual(y_add_1.requires_grad, y.requires_grad) test_case.assertEqual(y_add_2.requires_grad, y.requires_grad) l = loss(y_pred, y) l.backward() optimizer.step() optimizer.zero_grad() test_device = ["cpu"] if os.getenv("ONEFLOW_TEST_CPU_ONLY") else ["cpu", "cuda"] @flow.unittest.skip_unless_1n2d() class TestDdpMultmpleOutputs(flow.unittest.TestCase): def test_outputs_float32(test_case): arg_dict = OrderedDict() arg_dict["device"] = test_device arg_dict["output"] = [train_float32] arg_dict["requires_grad"] = [True, 
False] for arg in GenArgDict(arg_dict): train(test_case, train_x, **arg) def test_outputs_int32(test_case): arg_dict = OrderedDict() arg_dict["device"] = test_device arg_dict["output"] = [train_int32] arg_dict["requires_grad"] = [False] for arg in GenArgDict(arg_dict): train(test_case, train_x, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_deconv2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.nn as nn import oneflow.unittest import torch as torch_original from packaging import version def _test_deconv_bias_false(test_case, device): np_arr = np.array( [ [ [ [0.2735021114349365, -1.3842310905456543], [1.058540940284729, -0.03388553857803345], ] ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) weight = np.array( [ [ [ [0.06456436216831207, -0.10852358490228653, -0.21638715267181396], [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096], [0.05026858672499657, 0.10818571597337723, 0.02056501805782318], ], [ [0.205095112323761, 0.1488947868347168, -0.2344113141298294], [0.1684819906949997, -0.21986986696720123, 0.1082606166601181], [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222], ], ] ] ) m = nn.ConvTranspose2d(1, 2, 3, stride=1, bias=False) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m = m.to(device) output = m(input) np_out = np.array( [ [ [ [ 0.01765848882496357, -0.1190534234046936, 0.09103937447071075, 0.2995298206806183, ], [ 0.006009865552186966, 0.2388070970773697, -0.37657976150512695, -0.26200416684150696, ], [ -0.22750461101531982, 0.12405071407556534, 0.056831881403923035, -0.035060010850429535, ], [ 0.053211357444524765, 0.11281562596559525, 0.0181029811501503, -0.0006968567031435668, ], ], [ [ 0.05609394609928131, -0.24317599833011627, -0.27021679282188416, 0.32447943091392517, ], [ 0.26318174600601196, -0.14269141852855682, 0.08078087121248245, -0.14191456139087677, ], [ 0.13652732968330383, 0.020019691437482834, -0.10959184169769287, -0.03072327747941017, ], [ -0.16184815764427185, 0.1864076405763626, 0.014887845143675804, -0.0006622931105084717, ], ], ] ] ) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = [ [ [ [0.24731683731079102, 0.24731683731079102], [0.24731683731079102, 0.24731683731079102], ] ] ] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_deconv_bias_true(test_case, device): np_arr = np.array( [ [ [ [0.2735021114349365, -1.3842310905456543], [1.058540940284729, -0.03388553857803345], ] ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) 
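# NOTE: ConvTranspose2d produces spatial output size
#   H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
# so the 2x2 input above with kernel_size=3, stride=1 and padding=0 yields the
# 4x4 maps hard-coded in np_out below: (2 - 1) * 1 - 0 + 1 * (3 - 1) + 0 + 1 = 4.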
weight = np.array( [ [ [ [0.06456436216831207, -0.10852358490228653, -0.21638715267181396], [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096], [0.05026858672499657, 0.10818571597337723, 0.02056501805782318], ], [ [0.205095112323761, 0.1488947868347168, -0.2344113141298294], [0.1684819906949997, -0.21986986696720123, 0.1082606166601181], [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222], ], ] ] ) bias = np.array([0.06456436216831207, -0.10852358490228653]) m = nn.ConvTranspose2d(1, 2, 3, stride=1) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m.bias = flow.nn.Parameter(flow.Tensor(bias)) m = m.to(device) output = m(input) np_out = [ [ [ [ 0.0822228491306305, -0.05448906123638153, 0.15560373663902283, 0.36409419775009155, ], [ 0.07057422399520874, 0.30337145924568176, -0.3120154142379761, -0.19743980467319489, ], [ -0.16294024884700775, 0.188615083694458, 0.12139624357223511, 0.029504351317882538, ], [ 0.11777572333812714, 0.17737999558448792, 0.08266734331846237, 0.06386750191450119, ], ], [ [ -0.05242963880300522, -0.3516995906829834, -0.3787403702735901, 0.21595585346221924, ], [ 0.15465816855430603, -0.25121501088142395, -0.027742713689804077, -0.2504381537437439, ], [ 0.028003744781017303, -0.088503897190094, -0.2181154191493988, -0.139246866106987, ], [ -0.2703717350959778, 0.07788405567407608, -0.09363573789596558, -0.10918587446212769, ], ], ] ] test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = [ [ [ [0.24731683731079102, 0.24731683731079102], [0.24731683731079102, 0.24731683731079102], ] ] ] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_deconv_group_bias_false(test_case, device): np_arr = np.array( [ [ [ [-2.0125174206754517, 1.9917882689443576], [0.13146748727936577, -0.5356457374181375], ], [ [1.020683505853394, 1.2900643048299678], [-0.549010560600543, 0.8088391626901512], ], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m = nn.ConvTranspose2d(2, 2, 3, stride=1, groups=2, bias=False) weight = np.array( [ [ [ [0.06456436216831207, -0.10852358490228653, -0.21638715267181396], [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096], [0.05026858672499657, 0.10818571597337723, 0.02056501805782318], ] ], [ [ [0.205095112323761, 0.1488947868347168, -0.2344113141298294], [0.1684819906949997, -0.21986986696720123, 0.1082606166601181], [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222], ] ], ] ) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m = m.to(device) output = m(input) np_out = np.array( [ [ [ [ -0.12993690371513367, 0.34700414538383484, 0.219326913356781, -0.43099740147590637, ], [ 0.4671630859375, -0.8000040054321289, -0.06776165962219238, 0.5034587383270264, ], [ -0.13112929463386536, 0.02389305830001831, 0.12057329714298248, -0.06326202303171158, ], [ 0.00660868501290679, -0.012703249230980873, -0.05524558573961258, -0.011015564203262329, ], ], [ [ 0.20933720469474792, 0.4165603518486023, -0.04717591404914856, -0.3024056851863861, ], [ 0.059367403388023376, 0.07707919180393219, 0.07597976922988892, -0.049937888979911804, ], [ -0.24855825304985046, 0.2344835251569748, 0.003538096323609352, 0.11277973651885986, ], [ 0.08394229412078857, -0.21766230463981628, 0.12774622440338135, 0.015808766707777977, ], ], ] ] ) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = [ [ [ [0.03301373869180679, 
0.03301373869180679], [0.03301373869180679, 0.03301373869180679], ], [ [0.21430310606956482, 0.21430310606956482], [0.21430310606956482, 0.21430310606956482], ], ] ] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_deconv_group_bias_true(test_case, device): np_arr = np.array( [ [ [ [-2.0125174206754517, 1.9917882689443576], [0.13146748727936577, -0.5356457374181375], ], [ [1.020683505853394, 1.2900643048299678], [-0.549010560600543, 0.8088391626901512], ], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m = nn.ConvTranspose2d(2, 2, 3, stride=1, groups=2) weight = np.array( [ [ [ [0.06456436216831207, -0.10852358490228653, -0.21638715267181396], [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096], [0.05026858672499657, 0.10818571597337723, 0.02056501805782318], ] ], [ [ [0.205095112323761, 0.1488947868347168, -0.2344113141298294], [0.1684819906949997, -0.21986986696720123, 0.1082606166601181], [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222], ] ], ] ) m.weight = flow.nn.Parameter(flow.Tensor(weight)) bias = np.array([0.06456436216831207, -0.10852358490228653]) m.bias = flow.nn.Parameter(flow.Tensor(bias)) m = m.to(device) output = m(input) np_out = [ [ [ [ -0.0653725415468216, 0.4115685224533081, 0.2838912606239319, -0.3664330244064331, ], [ 0.5317274332046509, -0.735439658164978, -0.00319729745388031, 0.5680230855941772, ], [ -0.06656493246555328, 0.08845742046833038, 0.18513765931129456, 0.0013023391366004944, ], [ 0.0711730495095253, 0.05186111479997635, 0.009318776428699493, 0.053548797965049744, ], ], [ [ 0.1008136197924614, 0.30803677439689636, -0.1556994915008545, -0.41092926263809204, ], [ -0.04915618151426315, -0.03144439309835434, -0.032543815672397614, -0.15846148133277893, ], [ -0.3570818305015564, 0.12595993280410767, -0.10498549044132233, 0.004256151616573334, ], [ -0.024581290781497955, -0.3261858820915222, 0.019222639501094818, -0.0927148163318634, ], ], ] ] test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = [ [ [ [0.03301373869180679, 0.03301373869180679], [0.03301373869180679, 0.03301373869180679], ], [ [0.21430310606956482, 0.21430310606956482], [0.21430310606956482, 0.21430310606956482], ], ] ] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_deconv_group_large_out_channel(test_case, device): np_arr = np.array( [ [ [ [-2.0125174206754517, 1.9917882689443576], [0.13146748727936577, -0.5356457374181375], ], [ [1.020683505853394, 1.2900643048299678], [-0.549010560600543, 0.8088391626901512], ], ] ] ) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m = nn.ConvTranspose2d(2, 6, 3, stride=1, groups=2, bias=False) weight = np.array( [ [ [ [0.05271657928824425, -0.08860913664102554, -0.17667937278747559], [-0.18608860671520233, 0.12057777494192123, 0.1588696986436844], [0.04104413092136383, 0.08833327144384384, 0.016791267320513725], ], [ [0.16745945811271667, 0.1215720921754837, -0.19139604270458221], [0.13756497204303741, -0.17952299118041992, 0.08839442580938339], [-0.12484020739793777, 0.13978762924671173, 0.015958432108163834], ], [ [-0.07709092646837234, -0.029757702723145485, -0.18154984712600708], [-0.14461342990398407, 0.06567336618900299, 0.05665326863527298], [0.04441174864768982, -0.04477253183722496, 0.191376194357872], ], ], [ [ [0.1850736141204834, 0.07141514122486115, 0.05791180208325386], 
[0.07253318279981613, -0.042754165828228, -0.14045141637325287], [0.08525089919567108, 0.009758883155882359, -0.07303793728351593], ], [ [-0.005451973062008619, 0.1499139368534088, 0.16706342995166779], [-0.05473465472459793, 0.02753184549510479, -0.06856250017881393], [0.03629609942436218, -0.06238799914717674, -0.041715867817401886], ], [ [0.15021666884422302, -0.10501708835363388, 0.04741475358605385], [-0.16011257469654083, 0.1280348002910614, 0.11050418764352798], [-0.10031674802303314, 0.1449088454246521, -0.16990724205970764], ], ], ] ) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m = m.to(device) output = m(input) np_out = np.array( [ [ [ [ -0.10609303414821625, 0.28332769870758057, 0.17907968163490295, -0.3519079089164734, ], [ 0.3814370930194855, -0.653200626373291, -0.055327147245407104, 0.41107234358787537, ], [ -0.10706663131713867, 0.019508585333824158, 0.09844768047332764, -0.05165322124958038, ], [ 0.005395968910306692, -0.010372160002589226, -0.04510783404111862, -0.00899417046457529, ], ], [ [ -0.3370150923728943, 0.08887782692909241, 0.6273337602615356, -0.38122040033340454, ], [ -0.25483641028404236, 0.561577320098877, -0.6257490515708923, 0.27858346700668335, ], [ 0.26932841539382935, -0.6272678375244141, 0.35409244894981384, -0.015562277287244797, ], [ -0.01641242951154709, 0.08524765074253082, -0.0727786272764206, -0.008548066020011902, ], ], [ [ 0.15514683723449707, -0.09366090595722198, 0.3061012029647827, -0.3616088628768921, ], [ 0.28090208768844604, -0.38282686471939087, 0.008863434195518494, 0.21008771657943726, ], [ -0.10839138925075531, 0.2646597623825073, -0.5020549297332764, 0.35083478689193726, ], [ 0.005838701035827398, -0.029675094410777092, 0.04914196580648422, -0.10250984132289886, ], ], [ [ 0.18890158832073212, 0.3116491138935089, 0.15123975276947021, 0.074709951877594, ], [ -0.027573950588703156, 0.16042113304138184, -0.17254289984703064, -0.1343500316143036, ], [ 0.047192707657814026, 0.20208004117012024, -0.01943095773458481, -0.20782624185085297, ], [ -0.04680364578962326, 0.06359653919935226, 0.04799196869134903, -0.05907594412565231, ], ], [ [ -0.005564738996326923, 0.1459812968969345, 0.3639175295829773, 0.21552257239818573, ], [ -0.05287356674671173, -0.12922403216362, -0.0049260929226875305, 0.04667740315198898, ], [ 0.06709674000740051, -0.0762409120798111, -0.06315286457538605, -0.10927218943834305, ], [ -0.019926942884922028, 0.06360937654972076, -0.027559401467442513, -0.03374142572283745, ], ], [ [ 0.1533236801624298, 0.08659995347261429, -0.08708333969116211, 0.06116808205842972, ], [ -0.24589480459690094, 0.10328409075737, 0.16698980331420898, 0.1809084266424179, ], [ -0.014488153159618378, -0.18130677938461304, 0.056411802768707275, -0.1298111528158188, ], [ 0.05507495626807213, -0.1606965959072113, 0.21048882603645325, -0.13742762804031372, ], ], ] ] ) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = [ [ [ [0.0822635293006897, 0.0822635293006897], [0.0822635293006897, 0.0822635293006897], ], [ [0.4193778932094574, 0.4193778932094574], [0.4193778932094574, 0.4193778932094574], ], ] ] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_deconv_group_large_in_channel(test_case, device): np_arr = [ [ [ [0.6393764315295867, 0.3890587560476374], [0.8467359871201484, 0.24046160407703143], ], [ [0.23352071016856402, 0.6760713653927521], [0.061939453383917376, 0.13541973098624682], ], [ [0.7524804920779914, 0.34366296030931365], 
[0.4961502482687954, 0.38175448164636205], ], [ [0.01867975512238773, 0.12599156959160163], [0.2658608593205851, 0.6184459583178925], ], ] ] input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m = nn.ConvTranspose2d(4, 2, 3, stride=1, groups=2, bias=False) weight = np.array( [ [ [ [0.09130779653787613, -0.15347552299499512, -0.30601766705513], [-0.32231491804122925, 0.2088468372821808, 0.27517038583755493], [0.07109051942825317, 0.1529977172613144, 0.02908332832157612], ] ], [ [ [0.2900483012199402, 0.21056903898715973, -0.33150768280029297], [0.23826952278614044, -0.31094294786453247, 0.15310363471508026], [-0.21622958779335022, 0.24211928248405457, 0.0276408139616251], ] ], [ [ [-0.13352541625499725, -0.051541853696107864, -0.3144535720348358], [-0.2504778206348419, 0.11374961584806442, 0.09812634438276291], [0.07692340761423111, -0.0775483027100563, 0.33147329092025757], ] ], [ [ [0.3205569088459015, 0.12369465827941895, 0.1003061905503273], [0.1256311535835266, -0.07405238598585129, -0.24326899647712708], [0.14765889942646027, 0.016902882605791092, -0.12650541961193085], ] ], ] ) m.weight = flow.nn.Parameter(flow.Tensor(weight)) m = m.to(device) np_out = np.array( [ [ [ [ 0.12611234188079834, 0.1826610565185547, -0.19042569398880005, -0.34318169951438904, ], [ -0.05516064167022705, 0.04093143343925476, -0.2053149938583374, 0.0920882523059845, ], [ -0.2631978690624237, 0.14817529916763306, 0.4988565742969513, 0.11690345406532288, ], [ 0.04680176079273224, 0.13235820829868317, 0.09591575711965561, 0.010736535303294659, ], ], [ [ -0.09448734670877457, -0.04197392612695694, -0.2368750274181366, -0.09542831033468246, ], [ -0.1671580672264099, 0.16854587197303772, 0.02652890235185623, -0.05493755638599396, ], [ -0.030232630670070648, 0.0058259665966033936, 0.20417997241020203, -0.015012085437774658, ], [ 0.07742229104042053, 0.0867031067609787, 0.11167682707309723, 0.048304662108421326, ], ], ] ] ) output = m(input) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = [ [ [ [0.046688467264175415, 0.046688467264175415], [0.046688467264175415, 0.046688467264175415], ], [ [0.30307042598724365, 0.30307042598724365], [0.30307042598724365, 0.30307042598724365], ], [ [-0.20727425813674927, -0.20727425813674927], [-0.20727425813674927, -0.20727425813674927], ], [ [0.3909238576889038, 0.3909238576889038], [0.3909238576889038, 0.3909238576889038], ], ] ] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06)) @flow.unittest.skip_unless_1n1d() class TestDeconv2d(flow.unittest.TestCase): def test_deconv2d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_deconv_bias_false, _test_deconv_bias_true, _test_deconv_group_bias_false, _test_deconv_group_bias_true, _test_deconv_group_large_out_channel, _test_deconv_group_large_in_channel, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, rtol=1e-2, atol=1e-3) def test_deconv2d_with_random_data(test_case): channels = random(1, 6) m = torch.nn.ConvTranspose2d( in_channels=channels, out_channels=random(1, 20), kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 5) | nothing(), padding_mode=constant("zeros") | nothing(), bias=random_bool(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim1=channels).to(device) y = 
m(x) return y @unittest.skipIf( version.parse(torch_original.__version__) <= version.parse("1.13.0"), "deconv module don't support unbatched input in PyTorch before '1.13.0'", ) @autotest(n=5) def test_deconv2d_auto_squeeze_with_random_data(test_case): channels = random(1, 6) m = torch.nn.ConvTranspose2d( in_channels=channels, out_channels=random(1, 20), kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 5) | nothing(), padding_mode=constant("zeros") | nothing(), bias=random_bool(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim0=channels).to(device) y = m(x) return y @autotest(check_graph=False) def test_deconv2d_0size_with_random_data(test_case): channels = random(1, 6) m = torch.nn.ConvTranspose2d( in_channels=channels, out_channels=random(1, 20), kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 5) | nothing(), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim0=0, dim1=channels).to(device) y = m(x) return y @unittest.skip( "Likely to fail the test. This case should run on cpu when the problem is solved." ) @autotest(n=30, check_graph=False, rtol=1e-2, atol=1e-4) def test_deconv2d_group_with_random_data(test_case): channels = 720 # lcm(1, 2, 3, 4, 5, 6) m = torch.nn.ConvTranspose2d( in_channels=channels, out_channels=channels, kernel_size=random(1, 4), stride=random() | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(1, 5) | nothing(), groups=random(1, 7), padding_mode=constant("zeros") | nothing(), ) m.train(random()) device = random_device() m.to(device) m.pytorch.to("cuda") x = random_tensor(ndim=4, dim1=channels).to(device) x.pytorch = x.pytorch.to("cuda") y = m(x) return y @profile(torch.nn.functional.conv_transpose2d) def profile_conv_transpose2d(test_case): inputs = torch.ones(16, 128, 128, 128) weights_4x4_64c = torch.ones(128, 64, 4, 4) weights_6x6_64c = torch.ones(128, 64, 6, 6) weights_8x8_64c = torch.ones(128, 64, 8, 8) torch.nn.functional.conv_transpose2d( inputs, weights_4x4_64c, stride=2, padding=1 ) torch.nn.functional.conv_transpose2d( inputs, weights_4x4_64c, stride=2, padding=1, bias=torch.ones(64) ) torch.nn.functional.conv_transpose2d( inputs, weights_6x6_64c, stride=3, padding=2, output_padding=1 ) torch.nn.functional.conv_transpose2d( inputs, weights_6x6_64c, stride=3, padding=2, bias=torch.ones(64), output_padding=1, ) torch.nn.functional.conv_transpose2d( inputs, weights_8x8_64c, stride=4, padding=2 ) torch.nn.functional.conv_transpose2d( inputs, weights_8x8_64c, stride=4, padding=2, bias=torch.ones(64) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_default_dtype.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest _source_op_list = [ flow.ones, flow.zeros, flow.rand, flow.randn, flow.empty, flow.Tensor, ] class TestDefaultDTypeInferface(oneflow.unittest.TestCase): def test_set_default_dtype(test_case): flow.set_default_dtype(flow.float32) test_case.assertEqual(flow.get_default_dtype(), flow.float32) flow.set_default_dtype(flow.float64) test_case.assertEqual(flow.get_default_dtype(), flow.float64) for op in _source_op_list: x = op((2, 3)) test_case.assertEqual(x.dtype, flow.float64) x = op(2, 3) test_case.assertEqual(x.dtype, flow.float64) with test_case.assertRaises(Exception) as ctx: flow.set_default_dtype(flow.int32) test_case.assertTrue( "only floating-point types are supported as the default type" in str(ctx.exception) ) def test_set_default_tensor_type(test_case): flow.set_default_dtype(flow.float32) test_case.assertEqual(flow.get_default_dtype(), flow.float32) # set default tensor type by TensorType flow.set_default_tensor_type(flow.DoubleTensor) test_case.assertEqual(flow.get_default_dtype(), flow.float64) for op in _source_op_list: x = op((2, 3)) test_case.assertEqual(x.dtype, flow.float64) x = op(2, 3) test_case.assertEqual(x.dtype, flow.float64) # set default tensor type by TensorType string flow.set_default_tensor_type("oneflow.FloatTensor") test_case.assertEqual(flow.get_default_dtype(), flow.float32) for op in _source_op_list: x = op((2, 3)) test_case.assertEqual(x.dtype, flow.float32) def test_behavior_for_oneflow_tensor(test_case): # float32 scope flow.set_default_dtype(flow.float32) test_case.assertEqual(flow.get_default_dtype(), flow.float32) x = flow.tensor([1.0, 2]) test_case.assertEqual(x.dtype, flow.float32) # float64 scope flow.set_default_dtype(flow.float64) test_case.assertEqual(flow.get_default_dtype(), flow.float64) x = flow.tensor([1.0, 2]) test_case.assertEqual(x.dtype, flow.float64) # no affect for int type x = flow.tensor((2, 3)) test_case.assertEqual(x.dtype, flow.int64) # no affect for numpy array input nd_arr = np.array([1, 2, 3]).astype(np.float32) x = flow.tensor(nd_arr) test_case.assertEqual(x.dtype, flow.float32) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_deform_conv2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
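The random-shape helpers below keep re-sampling until the convolution output is non-degenerate; the output size follows the usual convolution formula used in GetFunArgs,

    out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1

so, purely as an illustration (these particular numbers are not fixed by the test), in_h=10, pad_h=1, dil_h=1, weight_h=3 and stride_h=2 give (10 + 2 - 3) // 2 + 1 = 5.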
""" import unittest from collections import OrderedDict import numpy as np import torchvision.ops import torch import oneflow as flow from oneflow.test_utils.automated_test_util.torch_flow_dual_object import random_tensor from oneflow.test_utils.test_util import GenArgList import oneflow.unittest def GetRandomData(max_batch_sz): batch_sz = max_batch_sz n_weight_grps = np.random.randint(1, 2) n_offset_grps = np.random.randint(1, 2) n_out_channels = n_offset_grps * np.random.randint(1, 15) n_in_channels = n_offset_grps * np.random.randint(1, 15) random_stride_h = np.random.randint(1, 5) random_stride_w = np.random.randint(1, 5) random_pad_h = np.random.randint(0, 3) random_pad_w = np.random.randint(0, 3) random_dilation_h = np.random.randint(1, 3) random_dilation_w = np.random.randint(1, 3) random_in_h = np.random.randint(5, 30) random_in_w = np.random.randint(5, 30) # BUG(yzm): Now use the rectangular convolution kernel is not aligned with PyTorch # NOTE: Modify the following program after alignment using a rectangular convolution kernel random_kernel_h = np.random.randint(1, 11) random_kernel_w = random_kernel_h # random_kernel_w=np.random.randint(1, 11) stride = (random_stride_h, random_stride_w) pad = (random_pad_h, random_pad_w) dilation = (random_dilation_h, random_dilation_w) return ( batch_sz, n_out_channels, n_in_channels, n_weight_grps, n_offset_grps, stride, pad, dilation, random_kernel_h, random_kernel_w, random_in_h, random_in_w, ) def GetFunArgs(device, max_batch_size): out_w = 0 out_h = 0 while out_w <= 0 or out_h <= 0: ( batch_sz, n_out_channels, n_in_channels, n_weight_grps, n_offset_grps, stride, pad, dilation, random_kernel_h, random_kernel_w, random_in_h, random_in_w, ) = GetRandomData(max_batch_size) stride_h, stride_w = stride pad_h, pad_w = pad dil_h, dil_w = dilation weight_h, weight_w = (random_kernel_h, random_kernel_w) in_h, in_w = (random_in_h, random_in_w) out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1 out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1 input_dims = [batch_sz, n_in_channels, in_h, in_w] offset_dims = [batch_sz, 2 * n_offset_grps * weight_h * weight_w, out_h, out_w] mask_dims = [batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w] weight_dims = [n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w] input = random_tensor(4, *input_dims).to(device) offset = random_tensor(4, *offset_dims).to(device) mask = random_tensor(4, *mask_dims).to(device) weight = random_tensor(4, *weight_dims).to(device) bias_dims = [n_out_channels] bias = random_tensor(1, *bias_dims).to(device) return input, weight, offset, mask, bias, stride, pad, dilation def _test_deform_conv2d_forward( test_case, input, weight, offset, mask, bias, stride, padding, dilation, ): torch_input = input.pytorch torch_weight = weight.pytorch torch_offset = offset.pytorch torch_mask = mask.pytorch torch_bias = bias.pytorch torch_out = torchvision.ops.deform_conv2d( torch_input, torch_offset, torch_weight, stride=stride, padding=padding, dilation=dilation, mask=torch_mask, bias=torch_bias, ) flow_input = input.oneflow flow_weight = weight.oneflow flow_offset = offset.oneflow flow_mask = mask.oneflow flow_bias = bias.oneflow flow_out = oneflow.nn.functional.deform_conv2d( flow_input, flow_offset, flow_weight, stride=stride, padding=padding, dilation=dilation, mask=flow_mask, bias=flow_bias, ) test_case.assertTrue( np.allclose( flow_out.numpy(), torch_out.detach().cpu().numpy(), rtol=1e-2, atol=1e-2 ) ) def 
_test_deform_conv2d_backward( test_case, input, weight, offset, mask, bias, stride, padding, dilation ): torch_input = input.pytorch.detach().requires_grad_() torch_weight = weight.pytorch.detach().requires_grad_() torch_offset = offset.pytorch.detach().requires_grad_() torch_mask = mask.pytorch.detach().requires_grad_() torch_bias = bias.pytorch.detach().requires_grad_() torch_out = torchvision.ops.deform_conv2d( torch_input, torch_offset, torch_weight, stride=stride, padding=padding, dilation=dilation, mask=torch_mask, bias=torch_bias, ) torch_out.sum().backward() flow_input = input.oneflow.detach().requires_grad_() flow_weight = weight.oneflow.detach().requires_grad_() flow_offset = offset.oneflow.detach().requires_grad_() flow_mask = mask.oneflow.detach().requires_grad_() flow_bias = bias.oneflow.detach().requires_grad_() flow_out = oneflow.nn.functional.deform_conv2d( flow_input, flow_offset, flow_weight, stride=stride, padding=padding, dilation=dilation, mask=flow_mask, bias=flow_bias, ) flow_out.sum().backward() test_case.assertTrue( np.allclose( flow_input.grad.numpy(), torch_input.grad.cpu().numpy(), rtol=1e-2, atol=1e-2, ) ) test_case.assertTrue( np.allclose( flow_weight.grad.numpy(), torch_weight.grad.cpu().numpy(), rtol=1e-2, atol=1e-2, ) ) test_case.assertTrue( np.allclose( flow_offset.grad.numpy(), torch_offset.grad.cpu().numpy(), rtol=1e-2, atol=1e-2, ) ) test_case.assertTrue( np.allclose( flow_mask.grad.numpy(), torch_mask.grad.cpu().numpy(), rtol=1e-2, atol=1e-2 ) ) test_case.assertTrue( np.allclose( flow_bias.grad.numpy(), torch_bias.grad.cpu().numpy(), rtol=1e-5, atol=1e-5 ) ) def _test_forward_and_backward(test_case, device): max_batch_size = 40 for batch_size in range(1, max_batch_size): input, weight, offset, mask, bias, stride, padding, dilation = GetFunArgs( device, batch_size ) _test_deform_conv2d_forward( test_case, input, weight, offset, mask, bias, stride, padding, dilation ) _test_deform_conv2d_backward( test_case, input, weight, offset, mask, bias, stride, padding, dilation ) @flow.unittest.skip_unless_1n1d() class TestDeformConv2d(flow.unittest.TestCase): def test_deform_conv2d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_forward_and_backward] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_det.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import re import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def det_random_device(): cuda_version = flow._oneflow_internal.flags.cuda_version() if cuda_version < 11000: # cuSOLVER is only supported in CUDA 11.0 and above return cpu_device() else: return random_device() @flow.unittest.skip_unless_1n1d() class TestLinalgDet(flow.unittest.TestCase): @autotest(n=5, rtol=1e-2, auto_backward=False) def test_det_3by3_with_random_data(test_case): device = det_random_device() x = random_tensor(ndim=2, dim0=3, dim1=3, low=-1).to(device) return torch.linalg.det(x) @autotest(n=5, rtol=1e-2, auto_backward=False) def test_det_batch_3by3_with_random_data(test_case): device = det_random_device() x = random_tensor(ndim=3, dim0=random(), dim1=3, dim2=3, low=-1).to(device) return torch.linalg.det(x) @autotest(n=5, rtol=1e-2, auto_backward=False) def test_det_random_square_with_random_data(test_case): device = det_random_device() square_dim = random() x = random_tensor(ndim=4, dim2=square_dim, dim3=square_dim, low=-1).to(device) return torch.linalg.det(x) @profile(torch.linalg.det) def profile_linalg_det(test_case): torch.linalg.det(torch.randn(1, 32, 4, 4)) torch.linalg.det(torch.randn(16, 32, 4, 4)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_diag.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch as ori_torch import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class Test_Diag_module(flow.unittest.TestCase): @autotest(n=5) def test_diag_one_dim(test_case): device = random_device() x = random_tensor(ndim=1, dim0=random()).to(device) return torch.diag(x) @autotest(n=5) def test_diag_other_dim(test_case): device = random_device() x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device) return torch.diag(x) @autotest(auto_backward=False) def test_diag_one_dim(test_case): device = random_device() x = random_tensor(ndim=1, dim0=random()).to(device, torch.bool) return torch.diag(x) def test_diag_0size_tensor(test_case): torch_tensor = ori_torch.empty(0).diag() flow_tensor = flow.empty(0).diag() test_case.assertTrue( np.array_equal(list(torch_tensor.shape), list(flow_tensor.shape)) ) torch_tensor = ori_torch.empty(0, 0).diag() flow_tensor = flow.empty(0, 0).diag() test_case.assertTrue( np.array_equal(list(torch_tensor.shape), list(flow_tensor.shape)) ) torch_tensor = ori_torch.empty(0, 3).diag() flow_tensor = flow.empty(0, 3).diag() test_case.assertTrue( np.array_equal(list(torch_tensor.shape), list(flow_tensor.shape)) ) @profile(torch.diag) def profile_diag(test_case): torch.diag(torch.ones(1000)) torch.diag(torch.ones(128, 128)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_diagonal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestDiagonal(flow.unittest.TestCase): @autotest(n=10, check_graph=True) def test_flow_diagonal_with_random_data(test_case): device = random_device() offset = random(-5, 5).to(int) dim1 = random(-4, 4).to(int) dim2 = random(-4, 4).to(int) x = random_tensor( ndim=4, dim1=random(4, 6), dim2=random(4, 6), dim3=random(4, 6), dim4=random(4, 6), ).to(device) z = torch.diagonal(x, offset, dim1, dim2) return z @autotest(auto_backward=False, n=10, check_graph=True) def test_flow_diagonal_with_random_data(test_case): device = random_device() offset = random(-5, 5).to(int) dim1 = random(-4, 4).to(int) dim2 = random(-4, 4).to(int) x = random_tensor( ndim=4, dim1=random(4, 6), dim2=random(4, 6), dim3=random(4, 6), dim4=random(4, 6), ).to(device, torch.bool) z = torch.diagonal(x, offset, dim1, dim2) return z @profile(torch.diagonal) def profile_diagonal(test_case): input1 = torch.ones(128, 128) input2 = torch.ones(16, 10, 128, 128) torch.diagonal(input1, 0) torch.diagonal(input1, 1) torch.diagonal(input2, offset=-1, dim1=1, dim2=2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_div.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch as torch_original from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_div_impl(test_case, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) y = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.div(x, y) np_out = np.divide(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = 5 y = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.div(x, y) np_out = np.divide(x, y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) y = 5 of_out = flow.div(x, y) np_out = np.divide(x.numpy(), y) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) y = flow.tensor( np.random.randn(1, 1), dtype=flow.float32, device=flow.device(device) ) of_out = flow.div(x, y) np_out = np.divide(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.tensor(np.array([5.0]), dtype=flow.float32, device=flow.device(device)) y = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.div(x, y) np_out = np.divide(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.array([5.0]), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.div(x, y) np_out = np.divide(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() np_grad_x = np.full(shape, 0.2) test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad_x, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestDiv(flow.unittest.TestCase): def test_div(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_div_impl(test_case, *arg) @autotest(n=10, auto_backward=False, check_graph=True, include_complex=True) def test_random_dim_div(test_case): device = random_device() dim0 = random(low=1, high=4).to(int) dim1 = random(low=1, high=4).to(int) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) z = x / y return z @autotest(n=10, auto_backward=False, check_graph=True, include_complex=True) def test_random_dim_scalar_div(test_case): device = random_device() dim0 = random(low=1, high=4).to(int) dim1 = random(low=1, high=4).to(int) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) y = random_tensor(ndim=0).to(device) z = x / y return z @autotest(n=10, auto_backward=False, check_graph=True, include_complex=True) def test_0_size_div(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = random_tensor(4, 2, 1, 0, 3).to(device) z = x / y return z @autotest(n=10, auto_backward=False, check_graph=True, include_complex=True) def 
test_0dim_div(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) z = x / y return z @autotest(n=10, include_complex=True) def test_non_contiguous_inplace_div(test_case): device = random_device() x = random_tensor(2, 2, 4).to(device) y = x + 1 y = y[:, 1:3] y /= random_tensor(2, 2, 2).to(device) return y @autotest(n=3, check_graph=False) def test_int_dtype_inplace_div(test_case): num_elems = 20 flow_out = flow.arange(num_elems) / num_elems torch_out = torch.arange(num_elems) / num_elems test_case.assertTrue(np.allclose(flow_out.numpy(), torch_out.numpy())) @autotest(n=5, include_complex=True) def test_scalar_div_with_random_devices(test_case): x1_device = random_device() x2_device = random_device() x1 = random_tensor(2, 2, 3).to(x1_device).mean() x2 = random_tensor(2, 2, 3).to(x2_device) y = x1 / x2 return y @profile(torch.div) def profile_div(test_case): input1 = torch.ones(16, 10, 128, 128) input2 = torch.ones(16, 10, 128, 128) torch.div(input1, input2) @flow.unittest.skip_unless_1n1d() class TestDivRoundmode(flow.unittest.TestCase): @autotest(n=3) def test_random_dim_div_floor(test_case): device = random_device() dim0 = random(low=1, high=4).to(int) dim1 = random(low=1, high=4).to(int) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) z = torch.div(x, y, rounding_mode="floor") return z @autotest(n=3) def test_random_dim_div_trunc(test_case): device = random_device() dim0 = random(low=1, high=4).to(int) dim1 = random(low=1, high=4).to(int) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) z = torch.div(x, y, rounding_mode="trunc") return z @autotest(n=3) def test_scalar_div_mode_floor(test_case): device = random_device() x1 = random(low=1, high=5).to(float) x2 = random_tensor(2, 2, 3).to(device) y = torch.div(x1, x2, rounding_mode="floor") return y @autotest(n=3) def test_scalar_div_mode_trunc(test_case): device = random_device() x1 = random(low=1, high=5).to(float) x2 = random_tensor(2, 2, 3).to(device) y = torch.div(x1, x2, rounding_mode="trunc") return y @autotest(n=3) def test_scalar_div_mode_floor2(test_case): device = random_device() x1 = random(low=1, high=5).to(float) x2 = random_tensor(2, 2, 3).to(device) y = torch.div(x2, x1, rounding_mode="floor") return y @autotest(n=3) def test_scalar_div_mode_trunc2(test_case): device = random_device() x1 = random(low=1, high=5).to(float) x2 = random_tensor(2, 2, 3).to(device) y = torch.div(x2, x1, rounding_mode="trunc") return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_dlpack.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
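These tests move tensors between PyTorch and OneFlow through DLPack capsules without copying, so both views share one buffer. A minimal sketch of the round-trip they exercise, using the same calls as the tests below (torch.to_dlpack, flow.from_dlpack, and the explicit OneFlow sync); the shapes and values here are only illustrative:

    import torch
    import oneflow as flow

    t = torch.ones(2, 3)
    f = flow.from_dlpack(torch.to_dlpack(t))  # zero-copy view of t's storage
    f[0, :] = 5.0                             # write through the OneFlow view
    flow._oneflow_internal.eager.Sync()       # OneFlow ops run asynchronously
    assert t[0, 0].item() == 5.0              # the write is visible through PyTorch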
""" import random import unittest import os import torch import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenCartesianProduct test_device_args = ( [("cpu",)] if os.getenv("ONEFLOW_TEST_CPU_ONLY") else [("cpu",), ("cuda", 0), ("cuda", 1)] ) test_args = list( GenCartesianProduct((test_device_args, [(torch, flow), (flow, torch)])) ) def are_tensors_equal(a, b): def are_devices_equal(a, b): if a.type == "cuda" and b.type == "cuda": return a.index == b.index else: return a.type == b.type return ( np.array_equal(a.cpu().numpy(), b.cpu().numpy()) and are_devices_equal(a.device, b.device) and a.shape == b.shape and a.stride() == b.stride() and a.cpu().numpy().dtype == b.cpu().numpy().dtype ) @flow.unittest.skip_unless_1n2d() class TestPack(flow.unittest.TestCase): def test_same_data(test_case): for device_args, (m1, m2) in test_args: tensor1 = m1.randn(3, 4, 5, device=m1.device(*device_args)) tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) test_case.assertTrue(are_tensors_equal(tensor1, tensor2)) test_case.assertEqual(tensor2.storage_offset(), 0) tensor2[1:2, 2:3, 3:4] = random.random() # NOTE: OneFlow operations are asynchoronously executed, # so we need to synchronize explicitly here. flow._oneflow_internal.eager.Sync() test_case.assertTrue(are_tensors_equal(tensor1, tensor2)) def test_use_ops(test_case): for device_args, (m1, m2) in test_args: tensor1 = m1.randn(3, 4, 5, device=m1.device(*device_args)) tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) res1 = tensor1 ** 2 res2 = tensor2 ** 2 test_case.assertTrue(np.allclose(res1.cpu().numpy(), res2.cpu().numpy())) def test_more_dtype(test_case): # PyTorch bfloat16 tensor doesn't support .numpy() method # so we can't test it # torch.bfloat16, flow.bfloat16 dtypes = ["float64", "float32", "float16", "int64", "int32", "int8", "uint8"] for device_args, (m1, m2) in test_args: for dtype in dtypes: tensor1 = m1.ones( (2, 3), dtype=getattr(m1, dtype), device=m1.device(*device_args) ) tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) test_case.assertTrue(are_tensors_equal(tensor1, tensor2)) def test_non_contiguous_input(test_case): for device_args, (m1, m2) in test_args: tensor1 = ( m1.randn(2, 3, 4, 5).permute(2, 0, 3, 1).to(m1.device(*device_args)) ) tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) test_case.assertTrue(are_tensors_equal(tensor1, tensor2)) def test_scalar_tensor(test_case): for device_args, (m1, m2) in test_args: tensor1 = m1.tensor(5).to(m1.device(*device_args)) tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) test_case.assertTrue(are_tensors_equal(tensor1, tensor2)) def test_0_size_tensor(test_case): for device_args, (m1, m2) in test_args: tensor1 = m1.tensor([]).to(m1.device(*device_args)) tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) test_case.assertTrue(are_tensors_equal(tensor1, tensor2)) def test_lifecycle(test_case): for device_args, (m1, m2) in test_args: tensor1 = m1.randn(2, 3, 4, 5).to(m1.device(*device_args)) tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) value = tensor1.cpu().numpy() del tensor2 if device_args[0] == "cuda": m2.cuda.synchronize() # actually release the cuda memory m2.cuda.empty_cache() test_case.assertTrue(np.array_equal(tensor1.cpu().numpy(), value)) tensor1 = m1.randn(2, 3, 4, 5).to(m1.device(*device_args)) tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) value = tensor2.cpu().numpy() del tensor1 if device_args[0] == "cuda": m1.cuda.synchronize() m1.cuda.empty_cache() test_case.assertTrue(np.array_equal(tensor2.cpu().numpy(), value)) def 
test_subview(test_case): for device_args, (m1, m2) in test_args: tensor1 = m1.randn(3, 4, 5, device=m1.device(*device_args)) tensor1 = tensor1[1:, :, ::2] tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1)) test_case.assertTrue(are_tensors_equal(tensor1, tensor2)) test_case.assertEqual(tensor2.storage_offset(), 0) tensor2[1:2, ::2, 3:4] = random.random() test_case.assertTrue(are_tensors_equal(tensor1, tensor2)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_dot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestDot(flow.unittest.TestCase): @autotest(n=5) def test_dot(test_case): device = random_device() k = random(10, 100) x = random_tensor(ndim=1, dim0=k).to(device) y = random_tensor(ndim=1, dim0=k).to(device) z = torch.dot(x, y) return z @profile(torch.dot) def profile_dot(test_case): input1 = torch.ones(10000) input2 = torch.ones(10000) torch.dot(input1, input2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
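The hard-coded reference arrays in the fixed-seed tests below come from inverted dropout: during training the kept elements of an all-ones input are scaled by

    scale = 1.0 / (1.0 - p)  # dropped positions become 0

which gives 1 / (1 - 0.25) = 1.333333 for p=0.25, 1 / (1 - 0.5) = 2.0 for p=0.5, and 1 / (1 - 0.2) = 1.25 for p=0.2.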
""" import os import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def do_test_dropout_numpy_p0(test_case, shape, device, dtype): np_x = np.random.randn(*shape).astype(dtype) np_one_mask = np.ones_like(np_x) x_tensor = flow.tensor(np_x, requires_grad=True, device=device) out = flow._C.dropout(x_tensor, p=0.0) test_case.assertTrue(np.allclose(out.numpy(), np_x, atol=1e-5, rtol=1e-5)) out_sum = out.sum() out_sum.backward() test_case.assertTrue( np.allclose(x_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5) ) def do_test_dropout_numpy_p1(test_case, shape, device, dtype): np_x = np.random.randn(*shape).astype(dtype) np_zero_mask = np.zeros_like(np_x) x_tensor = flow.tensor(np_x, requires_grad=True, device=device) out = flow._C.dropout(x_tensor, p=1.0) test_case.assertTrue(np.allclose(out.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5)) out_sum = out.sum() out_sum.backward() test_case.assertTrue( np.allclose(x_tensor.grad.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5) ) def do_test_dropout_numpy_fp16_p0(test_case, shape): np_x = np.random.randn(*shape).astype(np.float32) np_x_fp16 = np_x.astype(np.float16) x_tensor = flow.tensor(np_x, requires_grad=True, device="cuda") x_tensor_fp16 = flow.cast(x_tensor, flow.float16) np_one_mask = np.ones_like(np_x) out = flow._C.dropout(x_tensor_fp16, p=0.0) out_fp32 = flow.cast(out, flow.float32) test_case.assertTrue(np.allclose(out_fp32.numpy(), np_x_fp16, atol=1e-5, rtol=1e-5)) out_sum = out_fp32.sum() out_sum.backward() test_case.assertTrue( np.allclose(x_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5) ) def do_test_dropout_numpy_fp16_p1(test_case, shape): np_x = np.random.randn(*shape).astype(np.float32) x_tensor = flow.tensor(np_x, requires_grad=True, device="cuda") x_tensor_fp16 = flow.cast(x_tensor, flow.float16) np_zero_mask = np.zeros_like(np_x) out = flow._C.dropout(x_tensor_fp16, p=1.0) out_fp32 = flow.cast(out, flow.float32) test_case.assertTrue( np.allclose(out_fp32.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5) ) out_sum = out_fp32.sum() out_sum.backward() test_case.assertTrue( np.allclose(x_tensor.grad.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5) ) def do_test_dropout_addend_numpy_p0(test_case, shape, device, dtype): np_x = np.random.randn(*shape).astype(dtype) np_addend = np.random.randn(*shape).astype(dtype) np_one_mask = np.ones_like(np_x) x_tensor = flow.tensor(np_x, requires_grad=True, device=device) addend_tensor = flow.tensor(np_addend, requires_grad=True, device=device) DropoutModule = flow.nn.Dropout(p=0.0) out = DropoutModule(x_tensor, addend_tensor) test_case.assertTrue( np.allclose(out.numpy(), np_x + np_addend, atol=1e-5, rtol=1e-5) ) out_sum = out.sum() out_sum.backward() test_case.assertTrue( np.allclose(x_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5) ) test_case.assertTrue( np.allclose(addend_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5) ) def do_test_dropout_addend_numpy_p1(test_case, shape, device, dtype): np_x = np.random.randn(*shape).astype(dtype) np_addend = np.random.randn(*shape).astype(dtype) np_one_mask = np.ones_like(np_x) np_zero_mask = np.zeros_like(np_x) x_tensor = flow.tensor(np_x, requires_grad=True, device=device) addend_tensor = flow.tensor(np_addend, requires_grad=True, device=device) DropoutModule = flow.nn.Dropout(p=1.0) out = DropoutModule(x_tensor, addend_tensor) test_case.assertTrue(np.allclose(out.numpy(), 
np_addend, atol=1e-5, rtol=1e-5)) out_sum = out.sum() out_sum.backward() test_case.assertTrue( np.allclose(x_tensor.grad.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5) ) test_case.assertTrue( np.allclose(addend_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5) ) def do_test_dropout_addend_numpy_fp16_p0(test_case, shape): np_x = np.random.randn(*shape).astype(np.float32) np_x_fp16 = np_x.astype(np.float16) np_addend = np.random.randn(*shape).astype(np.float32) np_addend_fp16 = np_addend.astype(np.float16) x_tensor = flow.tensor(np_x, requires_grad=True, device="cuda") x_tensor_fp16 = flow.cast(x_tensor, flow.float16) addend_tensor = flow.tensor(np_addend, requires_grad=True, device="cuda") addend_tensor_fp16 = flow.cast(addend_tensor, flow.float16) np_one_mask = np.ones_like(np_x) DropoutModule = flow.nn.Dropout(p=0.0) out = DropoutModule(x_tensor_fp16, addend_tensor_fp16) out_fp32 = flow.cast(out, flow.float32) test_case.assertTrue( np.allclose(out_fp32.numpy(), np_x_fp16 + np_addend_fp16, atol=1e-5, rtol=1e-5) ) out_sum = out_fp32.sum() out_sum.backward() test_case.assertTrue( np.allclose(x_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5) ) test_case.assertTrue( np.allclose(addend_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5) ) def do_test_dropout_addend_numpy_fp16_p1(test_case, shape): np_x = np.random.randn(*shape).astype(np.float32) np_addend = np.random.randn(*shape).astype(np.float32) np_addend_fp16 = np_addend.astype(np.float16) x_tensor = flow.tensor(np_x, requires_grad=True, device="cuda") x_tensor_fp16 = flow.cast(x_tensor, flow.float16) addend_tensor = flow.tensor(np_addend, requires_grad=True, device="cuda") addend_tensor_fp16 = flow.cast(addend_tensor, flow.float16) np_zero_mask = np.zeros_like(np_x) np_one_mask = np.ones_like(np_x) DropoutModule = flow.nn.Dropout(p=1.0) out = DropoutModule(x_tensor_fp16, addend_tensor_fp16) out_fp32 = flow.cast(out, flow.float32) test_case.assertTrue( np.allclose(out_fp32.numpy(), np_addend_fp16, atol=1e-5, rtol=1e-5) ) out_sum = out_fp32.sum() out_sum.backward() test_case.assertTrue( np.allclose(x_tensor.grad.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5) ) test_case.assertTrue( np.allclose(addend_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5) ) def fixed_cpu_seed_dropout_test(test_case): gen1 = flow.Generator() gen1.manual_seed(5) dropped_array1 = np.array( [ [0.000000, 0.000000, 1.333333], [1.333333, 0.000000, 1.333333], [1.333333, 1.333333, 1.333333], ] ).astype(np.float32) dropout1 = flow.nn.Dropout(p=0.25, generator=gen1) x = flow.ones((3, 3), dtype=flow.float32) out1 = dropout1(x) test_case.assertTrue( np.allclose(out1.numpy(), dropped_array1, atol=1e-4, rtol=1e-4) ) gen2 = flow.Generator() gen2.manual_seed(7) dropout2 = flow.nn.Dropout(p=0.5, generator=gen2) dropped_array2 = np.array( [[0.0, 0.0, 2.0], [0.0, 0.0, 2.0], [2.0, 0.0, 2.0]] ).astype(np.float32) out2 = dropout2(x) test_case.assertTrue( np.allclose(out2.numpy(), dropped_array2, atol=1e-4, rtol=1e-4) ) def fixed_gpu_seed_dropout_test(test_case): gen1 = flow.Generator() gen1.manual_seed(5) dropped_array1 = np.array( [[1.2500, 0.0000, 1.2500], [1.2500, 1.2500, 1.2500], [1.2500, 1.2500, 1.2500]] ).astype(np.float32) dropout1 = flow.nn.Dropout(p=0.2, generator=gen1).to("cuda") x = flow.ones((3, 3), dtype=flow.float32).to("cuda") out1 = dropout1(x) test_case.assertTrue( np.allclose(out1.numpy(), dropped_array1, atol=1e-4, rtol=1e-4) ) gen2 = flow.Generator() gen2.manual_seed(7) dropout2 = flow.nn.Dropout(p=0.7, generator=gen2).to("cuda") dropped_array2 
= np.array( [ [3.333333, 3.333333, 0.000000], [0.000000, 0.000000, 0.000000], [0.000000, 0.000000, 0.000000], ] ).astype(np.float32) out2 = dropout2(x) test_case.assertTrue( np.allclose(out2.numpy(), dropped_array2, atol=1e-4, rtol=1e-4) ) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_dropout_numpy_case(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [do_test_dropout_numpy_p0, do_test_dropout_numpy_p1] arg_dict["shape"] = [[4], [4, 3], [4, 127, 256], [2, 1024, 1024]] arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["dtype"] = [np.float32, np.float64] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_dropout_fp16_numpy_case(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ do_test_dropout_numpy_fp16_p0, do_test_dropout_numpy_fp16_p1, ] arg_dict["shape"] = [[4, 127, 256], [5, 63, 49], [7, 32, 64], [16, 512, 512]] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_dropout_addend_numpy_case(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ do_test_dropout_addend_numpy_p0, do_test_dropout_addend_numpy_p1, ] arg_dict["shape"] = [[4, 47, 156], [5, 33, 65], [3, 132, 94], [9, 256, 63]] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = [np.float32, np.float64] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_dropout_addend_fp16_numpy_case(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ do_test_dropout_addend_numpy_fp16_p0, do_test_dropout_addend_numpy_fp16_p1, ] arg_dict["shape"] = [[2, 44, 66], [1, 2, 7], [5, 32, 74], [8, 125, 63]] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_cpu_fixed_dropout(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ fixed_cpu_seed_dropout_test, ] for arg in GenArgList(arg_dict): arg[0](test_case) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_gpu_fixed_dropout(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ fixed_gpu_seed_dropout_test, ] for arg in GenArgList(arg_dict): arg[0](test_case) @autotest(n=5) def test_dropout_p0(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout(p=0, inplace=False) return m(x) @unittest.skipIf(True, "PyTorch 1.10.0 does not have the Dropout1d module") @autotest(n=5) def test_dropout1d_p0(test_case): device = random_device() x = random_tensor(ndim=random(2, 4), dim0=random(1, 8)).to(device) m = torch.nn.Dropout1d(p=0, inplace=False) return m(x) @autotest(n=5) def test_dropout2d_p0(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout2d(p=0, inplace=False) return m(x) @unittest.skipIf( True, "this passes with PyTorch 1.13.0 but fails with PyTorch 1.10.0 because some non-leaf tensors don't have grad", ) @autotest(n=5) def test_dropout3d_p0(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout3d(p=0, inplace=False) return m(x) @autotest(n=5) def test_dropout_p1(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout(p=1.0, inplace=False) return m(x)
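# NOTE (added commentary, not part of the original test file): OneFlow's Dropout, like PyTorch's, uses "inverted dropout": during training, surviving elements are scaled by 1 / (1 - p) so the expected value of the output matches the input. That is why the fixed-seed reference arrays above contain 1.333333 (= 1 / 0.75 for p=0.25), 2.0 (p=0.5), 1.25 (p=0.2) and 3.333333 (p=0.7). A minimal NumPy sketch of this rule, assuming a hypothetical precomputed 0/1 keep mask, would be:
#   def reference_dropout(x, keep_mask, p):
#       return x * keep_mask / (1.0 - p)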
@unittest.skipIf(True, "PyTorch 1.10.0 does not have the Dropout1d module") @autotest(n=5) def test_dropout1d_p1(test_case): device = random_device() x = random_tensor(ndim=random(2, 4), dim0=random(1, 8)).to(device) m = torch.nn.Dropout1d(p=1.0, inplace=False) return m(x) @unittest.skip("skip for now, because it failed 8 times in the past week") @autotest(n=5) def test_dropout2d_p1(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout2d(p=1.0, inplace=False) return m(x) @unittest.skipIf( True, "this passes with PyTorch 1.13.0 but fails with PyTorch 1.10.0 because some non-leaf tensors don't have grad", ) @autotest(n=5) def test_dropout3d_p1(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout3d(p=1.0, inplace=False) return m(x) @unittest.skipIf(True, "PyTorch 1.10.0 does not have the Dropout1d module") @autotest(n=5) def test_functional_dropout1d_p1(test_case): device = random_device() x = random_tensor(ndim=random(2, 4), dim0=random(1, 8)).to(device) return torch.nn.functional.dropout1d(x, p=1.0) @autotest(n=5) def test_functional_dropout2d_p1(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) return torch.nn.functional.dropout2d(x, p=1.0) @unittest.skipIf( True, "this passes with PyTorch 1.13.0 but fails with PyTorch 1.10.0 because some non-leaf tensors don't have grad", ) @autotest(n=5) def test_functional_dropout3d_p1(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) return torch.nn.functional.dropout3d(x, p=1.0) @autotest(n=5, check_graph=False) def test_dropout_eval(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout(p=1.0, inplace=False) m.eval() return m(x) @unittest.skipIf(True, "PyTorch 1.10.0 does not have the Dropout1d module") @autotest(n=5, check_graph=False) def test_dropout1d_eval(test_case): device = random_device() x = random_tensor(ndim=random(2, 4), dim0=random(1, 8)).to(device) m = torch.nn.Dropout1d(p=1.0, inplace=False) m.eval() return m(x) @autotest(n=5, check_graph=False) def test_dropout2d_eval(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout2d(p=1.0, inplace=False) m.eval() return m(x) @unittest.skipIf( True, "this passes with PyTorch 1.13.0 but fails with PyTorch 1.10.0 because some non-leaf tensors don't have grad", ) @autotest(n=5, check_graph=False) def test_dropout3d_eval(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) m = torch.nn.Dropout3d(p=1.0, inplace=False) m.eval() return m(x) @autotest(n=5, check_graph=False) def test_0dim_dropout_eval(test_case): device = random_device() x = random_tensor(ndim=0).to(device) m = torch.nn.Dropout(p=1.0, inplace=False) m.eval() return m(x) @profile(torch.nn.functional.dropout) def profile_dropout(test_case): input = torch.ones(100, 128) torch.nn.functional.dropout(input, p=0.3) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestDropoutOnNonDefaultDevice(flow.unittest.TestCase): def test_non_default_device(test_case): x = flow.tensor([2, 3], dtype=flow.float, device="cuda:1") y = flow._C.dropout(x) test_case.assertEqual(y.device, flow.device("cuda:1")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_dynamic_allocation_gradient_shuffle_shuffle_global.py
================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os # dynamic memory allocation can't be tested in unittest os.environ["ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"] = "1" import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow parallel_num = 2 max_id = 1000 def get_tensors(batch_size, num_tables): placement = flow.placement(type="cuda", ranks=list(range(parallel_num))) ids = np.random.randint(0, max_id, (batch_size, num_tables), dtype=np.int64) ids_tensor = flow.tensor(ids, requires_grad=False).to_global( placement=placement, sbp=flow.sbp.split(0) ) table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids table_ids_tensor = flow.tensor( table_ids.astype(np.int32), requires_grad=False ).to_global(placement=placement, sbp=flow.sbp.split(0)) return ids_tensor, table_ids_tensor def round_half_away_from_zero(x): sign = np.sign(x) abs_val = np.abs(x) abs_val += 0.5 floor_val = np.floor(abs_val) out = floor_val * sign return out def _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size): np_tolerance = 0 batch_size = int(1024 / parallel_num) placement = flow.placement(type="cuda", ranks=list(range(parallel_num))) num_tables = 26 enable_quantized_comm = enable_quantize and embedding_size < 1025 if enable_quantized_comm: np_tolerance = 0.5 os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" else: if fp16: np_tolerance = 1e-2 else: np_tolerance = 1e-4 os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" embedding_grad = np.random.rand(batch_size, num_tables, embedding_size).astype( np.float32 ) embedding_grad_tensor = flow.tensor(embedding_grad, requires_grad=False).to_global( placement=placement, sbp=flow.sbp.split(0) ) class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, ids, table_ids, embedding_grad): ( num_unique_matrix, inverse_unique_partition_indices, cur_rank_num_unique, cur_rank_unique_ids, _, cur_rank_inverse_indices, ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, "test") if fp16: embedding_grad = flow.cast(embedding_grad, flow.float16) cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle( embedding_grad, num_unique_matrix, cur_rank_inverse_indices, inverse_unique_partition_indices, "test", ) if fp16: cur_rank_unique_embedding_grad = flow.cast( cur_rank_unique_embedding_grad, flow.float32 ) return ( cur_rank_unique_embedding_grad, flow.cast(cur_rank_num_unique, flow.int32), cur_rank_unique_ids, ) graph = TestGraph() for i in range(10): ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables) graph(ids_tensor, table_ids_tensor, embedding_grad_tensor) ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables) ( cur_rank_unique_embedding_grad, local_cur_rank_num_unique, cur_rank_unique_ids, ) = graph(ids_tensor, 
table_ids_tensor, embedding_grad_tensor) cur_rank_num_unique = local_cur_rank_num_unique.to_local().to_global( placement=placement, sbp=flow.sbp.split(0) ) global_ids = ids_tensor.numpy() global_embedding_grad = embedding_grad_tensor.numpy() np_unique_ids = np.unique(global_ids) np_num_unique = np_unique_ids.size np_cur_rank_unique_embedding_grad = np.zeros((max_id, embedding_size)) if fp16: global_embedding_grad = global_embedding_grad.astype(np.float16) for k in range(np_num_unique): unique_id = np_unique_ids[k] np_data = sum( global_embedding_grad.reshape(-1, embedding_size)[ np.where(global_ids.flatten() == unique_id)[0] ] ) # Quantize Embedding Gradient. if enable_quantized_comm: abs_max_factor = np.max(np.abs(np_data)) int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32) quantize_factor = int8_factor / abs_max_factor np_data = np_data * quantize_factor np_data = round_half_away_from_zero(np_data) np_data = np_data.astype(np.int8) np_data = np_data.astype(np.float32) dequantize_factor = abs_max_factor / int8_factor np_data = np_data * dequantize_factor np_cur_rank_unique_embedding_grad[unique_id, :] = np_data if fp16: np_cur_rank_unique_embedding_grad = np_cur_rank_unique_embedding_grad.astype( np.float32 ) cur_rank_num_ids = batch_size * num_tables * parallel_num of_unique_embedding_grad = np.zeros((max_id, embedding_size)) for i in range(parallel_num): num_unique_i = cur_rank_num_unique.numpy()[i] unique_ids_i = cur_rank_unique_ids.numpy()[ cur_rank_num_ids * i : cur_rank_num_ids * (i + 1) ] unique_embedding_grad_i = cur_rank_unique_embedding_grad.numpy()[ cur_rank_num_ids * i : cur_rank_num_ids * (i + 1) ] for j in range(num_unique_i): unique_id = unique_ids_i[j] of_unique_embedding_grad[unique_id, :] = unique_embedding_grad_i[j, :] test_case.assertTrue( np.allclose( of_unique_embedding_grad, np_cur_rank_unique_embedding_grad, atol=np_tolerance, rtol=np_tolerance, ), ) # FIXME: restore this test after upgrading CUDA driver @unittest.skip("CUDA driver version of CI machine is insufficient for this test") # @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class DataShuffleTestCase(flow.unittest.TestCase): def test_embedding_gradient_shuffle(test_case): arg_dict = OrderedDict() arg_dict["enable_quantize"] = [True, False] arg_dict["fp16"] = [True, False] arg_dict["embedding_size"] = [128, 17] for kwargs in GenArgDict(arg_dict): _test_embedding_gradient_shuffle(test_case, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_eager_boxing.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import os import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_eager_boxing_with_non_overlapping_placement_p_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1]) x = tensor.to_global(placement, flow.sbp.partial_sum) new_placement = flow.placement(out_device, ranks=[2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[6, 16], [9, 17], [7, 13], [12, 16],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[15, 27], [19, 5], [11, 9], [15, 4],], dtype=np.float32,), ) ) def _test_eager_boxing_with_non_overlapping_placement_b_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1]) x = tensor.to_global(placement, flow.sbp.broadcast) new_placement = flow.placement(out_device, ranks=[2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[4, 6], [6, 8], [3, 7], [6, 8],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[5, 20], [9, 0], [5, 0], [9, 0],], dtype=np.float32,), ) ) def _test_eager_boxing_with_non_overlapping_placement_s0_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1]) x = 
tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement(out_device, ranks=[2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [ [5, 20], [9, 0], [5, 0], [9, 0], [10, 7], [10, 5], [6, 9], [6, 4], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_non_overlapping_placement_s1_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[2, 3]) z = y.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [5, 20], [9, 0], [5, 0], [9, 0], [10, 7], [10, 5], [6, 9], [6, 4], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_non_overlapping_placement_s1_to_s0( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[2, 3]) z = y.to_global(new_placement, flow.sbp.split(0)) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4],], dtype=np.float32, ), ) ) def _test_eager_boxing_with_non_overlapping_placement_s1_to_b( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 
9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[2, 3]) z = y.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0], [2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0], [2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_non_overlapping_placement_s1_to_p( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[2, 3]) z = y.to_global(new_placement, flow.sbp.partial_sum) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 0, 0], [6, 8, 0, 0], [3, 7, 0, 0], [6, 8, 0, 0], [2, 10, 0, 0], [3, 9, 0, 0], [4, 6, 0, 0], [6, 8, 0, 0], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [0, 0, 5, 20], [0, 0, 9, 0], [0, 0, 5, 0], [0, 0, 9, 0], [0, 0, 10, 7], [0, 0, 10, 5], [0, 0, 6, 9], [0, 0, 6, 4], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_overlapping_placement_p_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) 
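# NOTE (added commentary, not part of the original test file): with sbp.partial_sum over ranks [0, 1, 3], the logical tensor is the elementwise sum of the rank-0, rank-1 and rank-3 arrays above (rank 2 is not in the producer placement), i.e.
#   [[15, 20, 20, 35], [16, 19, 28, 10], [13, 16, 20, 11], [15, 23, 20, 12]].
# Re-boxing to sbp.split(1) over ranks [2, 3] then places columns 0-1 on rank 2 and columns 2-3 on rank 3, which is exactly what the per-rank assertions below check.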
placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.partial_sum) new_placement = flow.placement(out_device, ranks=[2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[15, 20], [16, 19], [13, 16], [15, 23],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[20, 35], [28, 10], [20, 11], [20, 12],], dtype=np.float32,), ) ) def _test_eager_boxing_with_overlapping_placement_b_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.broadcast) new_placement = flow.placement(out_device, ranks=[2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[4, 6], [6, 8], [3, 7], [6, 8],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[5, 20], [9, 0], [5, 0], [9, 0],], dtype=np.float32,), ) ) def _test_eager_boxing_with_overlapping_placement_s0_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement(out_device, ranks=[2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [ [4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8], [9, 4], [7, 2], [6, 3], [3, 7], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [ [5, 20], [9, 0], [5, 0], [9, 0], [10, 7], [10, 5], [6, 9], [6, 4], [5, 8], [9, 5], [9, 2], [5, 8], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_overlapping_placement_s1_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) 
elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[2, 3]) z = y.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5], [6, 8, 9], [3, 7, 5], [6, 8, 9], [2, 10, 10], [3, 9, 10], [4, 6, 6], [6, 8, 6], [9, 4, 5], [7, 2, 9], [6, 3, 9], [3, 7, 5], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [20, 8, 9], [0, 4, 6], [0, 3, 5], [0, 8, 7], [7, 10, 3], [5, 5, 6], [9, 8, 6], [4, 5, 3], [8, 9, 6], [5, 4, 1], [2, 5, 2], [8, 9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_overlapping_placement_s1_to_s0( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[2, 3]) z = y.to_global(new_placement, flow.sbp.split(0)) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_overlapping_placement_s1_to_b( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 
8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[2, 3]) z = y.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_overlapping_placement_s1_to_p( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[2, 3]) z = y.to_global(new_placement, flow.sbp.partial_sum) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 0, 0, 0, 0], [6, 8, 0, 0, 0, 0], [3, 7, 0, 0, 0, 0], [6, 8, 0, 0, 0, 0], [2, 10, 0, 0, 0, 0], [3, 9, 0, 0, 0, 0], [4, 6, 0, 0, 0, 0], [6, 8, 0, 0, 0, 0], [9, 4, 0, 0, 0, 0], [7, 2, 0, 0, 0, 0], [6, 3, 0, 0, 0, 0], [3, 7, 0, 0, 0, 0], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [0, 0, 5, 20, 8, 9], [0, 0, 9, 0, 4, 6], [0, 0, 5, 0, 3, 5], [0, 0, 9, 0, 8, 7], [0, 0, 10, 7, 10, 3], [0, 0, 10, 5, 5, 6], [0, 0, 6, 9, 8, 6], [0, 0, 6, 4, 5, 3], [0, 0, 5, 8, 9, 6], [0, 0, 9, 5, 4, 1], [0, 0, 9, 2, 5, 2], [0, 0, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_in_placement_contain_out_placement_p_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 
8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.partial_sum) new_placement = flow.placement(out_device, ranks=[1, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[15, 20], [16, 19], [13, 16], [15, 23],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[20, 35], [28, 10], [20, 11], [20, 12],], dtype=np.float32,), ) ) def _test_eager_boxing_with_in_placement_contain_out_placement_b_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.broadcast) new_placement = flow.placement(out_device, ranks=[1, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[4, 6], [6, 8], [3, 7], [6, 8],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[5, 20], [9, 0], [5, 0], [9, 0],], dtype=np.float32,), ) ) def _test_eager_boxing_with_in_placement_contain_out_placement_s0_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement(out_device, ranks=[1, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [ [4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8], [9, 4], [7, 2], [6, 3], [3, 7], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( 
y.to_local().numpy(), np.array( [ [5, 20], [9, 0], [5, 0], [9, 0], [10, 7], [10, 5], [6, 9], [6, 4], [5, 8], [9, 5], [9, 2], [5, 8], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 2, 1, 3]) x = tensor.to_global(placement, flow.sbp.broadcast) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[1, 3]) z = y.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array([[4, 6], [6, 8], [3, 7], [6, 8],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array([[5, 20], [9, 0], [5, 0], [9, 0],], dtype=np.float32,), ) ) def _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s0( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 2, 1, 3]) x = tensor.to_global(placement, flow.sbp.broadcast) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[1, 3]) z = y.to_global(new_placement, flow.sbp.split(0)) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array([[4, 6, 5, 20], [6, 8, 9, 0],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array([[3, 7, 5, 0], [6, 8, 9, 0],], dtype=np.float32,), ) ) def _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_p( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, 
dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 2, 1, 3]) x = tensor.to_global(placement, flow.sbp.broadcast) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[1, 3]) z = y.to_global(new_placement, flow.sbp.partial_sum) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[4, 6, 5, 0], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[0, 0, 0, 20], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],], dtype=np.float32, ), ) ) def _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_b( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 2, 1, 3]) x = tensor.to_global(placement, flow.sbp.broadcast) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[1, 3]) z = y.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0],], dtype=np.float32, ), ) ) def _test_eager_boxing_with_out_placement_contain_in_placement_p_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.partial_sum) new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[15], [16], [13], [15],], dtype=np.float32,), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[20], [19], [16], [23],], dtype=np.float32,), ) ) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[20], [28], [20], [20],], dtype=np.float32,), ) ) if 
flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[35], [10], [11], [12],], dtype=np.float32,), ) ) def _test_eager_boxing_with_out_placement_contain_in_placement_b_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.broadcast) new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[4], [6], [3], [6],], dtype=np.float32,), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[6], [8], [7], [8],], dtype=np.float32,), ) ) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[5], [9], [5], [9],], dtype=np.float32,), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array([[20], [0], [0], [0],], dtype=np.float32,), ) ) def _test_eager_boxing_with_out_placement_contain_in_placement_s0_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3]) y = x.to_global(new_placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[4], [6], [3], [6], [2], [3], [4], [6], [9], [7], [6], [3],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[6], [8], [7], [8], [10], [9], [6], [8], [4], [2], [3], [7],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[5], [9], [5], [9], [10], [10], [6], [6], [5], [9], [9], [5],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[20], [0], [0], [0], [7], [5], [9], [4], [8], [5], [2], [8],], dtype=np.float32, ), ) ) def _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_b( test_case, in_device, 
out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3]) z = y.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(z.placement, new_placement) test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_p( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3]) z = y.to_global(new_placement, flow.sbp.partial_sum) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 0, 0, 0, 0], [6, 8, 0, 0, 0, 0], [3, 7, 0, 0, 0, 0], [6, 8, 0, 0, 0, 0], [2, 10, 0, 0, 0, 0], [3, 9, 0, 0, 0, 0], [4, 6, 0, 0, 0, 0], [6, 8, 0, 0, 0, 0], [9, 4, 0, 0, 0, 0], [7, 2, 0, 0, 0, 0], [6, 3, 0, 0, 0, 0], [3, 7, 0, 0, 0, 0], ], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [0, 0, 5, 20, 0, 0], [0, 0, 9, 0, 0, 0], [0, 0, 5, 0, 0, 0], [0, 0, 9, 0, 0, 0], [0, 0, 10, 7, 0, 0], [0, 0, 10, 5, 0, 0], [0, 0, 6, 9, 0, 0], [0, 0, 6, 4, 0, 0], [0, 0, 5, 8, 0, 0], [0, 0, 9, 5, 0, 0], [0, 0, 9, 2, 0, 0], [0, 0, 5, 8, 0, 0], ], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 
0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], ], dtype=np.float32, ), ) ) else: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [0, 0, 0, 0, 8, 9], [0, 0, 0, 0, 4, 6], [0, 0, 0, 0, 3, 5], [0, 0, 0, 0, 8, 7], [0, 0, 0, 0, 10, 3], [0, 0, 0, 0, 5, 6], [0, 0, 0, 0, 8, 6], [0, 0, 0, 0, 5, 3], [0, 0, 0, 0, 9, 6], [0, 0, 0, 0, 4, 1], [0, 0, 0, 0, 5, 2], [0, 0, 0, 0, 9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s0( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3]) z = y.to_global(new_placement, flow.sbp.split(0)) test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5],], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6],], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6],], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3],], dtype=np.float32, ), ) ) def _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s1( test_case, in_device, out_device ): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9, 5, 20], [6, 8, 9, 0, 4, 6, 9, 0], [3, 7, 5, 0, 3, 5, 0, 3], [6, 8, 9, 0, 8, 7, 8, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3, 10, 7], [3, 9, 10, 5, 5, 6, 9, 10], [4, 6, 6, 9, 8, 6, 6, 9], [6, 8, 6, 4, 5, 3, 8, 6], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6, 8, 3], [4, 9, 7, 0, 2, 1, 9, 7], [2, 5, 7, 9, 4, 8, 5, 7], [6, 8, 10, 0, 4, 9, 8, 10], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6, 5, 8], [7, 2, 9, 5, 4, 1, 7, 2], [6, 3, 9, 2, 5, 2, 9, 2], [3, 7, 5, 8, 9, 3, 7, 5], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3]) z = y.to_global(new_placement, flow.sbp.split(1)) 
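# Note on the check that follows: the logical tensor here is 8x8 (two 4x8 local arrays
# stacked by split(0) on ranks [0, 1]). Re-boxing from S(1) over ranks [0, 1] to S(1)
# over the larger placement [0, 1, 2, 3] should leave each rank with a contiguous 8x2
# column slice of that logical tensor, which is what the per-rank asserts below verify.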
test_case.assertEqual(z.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8],], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [5, 20], [9, 0], [5, 0], [9, 0], [10, 7], [10, 5], [6, 9], [6, 4], ], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [[8, 9], [4, 6], [3, 5], [8, 7], [10, 3], [5, 6], [8, 6], [5, 3],], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [5, 20], [9, 0], [0, 3], [8, 9], [10, 7], [9, 10], [6, 9], [8, 6], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_same_placement_p_to_s1(test_case, in_device, out_device): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [6, 8, 9, 0, 4, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [4, 9, 7, 0, 2, 1], [6, 3, 9, 2, 5, 2], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], [6, 3, 9, 2, 5, 2], [2, 5, 7, 9, 4, 8], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], [7, 2, 9, 5, 4, 1], [4, 9, 7, 0, 2, 1], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.partial_sum) y = x.to_global(placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[15, 20], [16, 19], [13, 16], [15, 23], [17, 19], [16, 20],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[20, 35], [28, 10], [20, 11], [20, 12], [25, 5], [22, 6],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[27, 18], [13, 13], [16, 13], [22, 13], [10, 8], [12, 6],], dtype=np.float32, ), ) ) def _test_eager_boxing_with_same_placement_b_to_s1(test_case, in_device, out_device): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [6, 8, 9, 0, 4, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [4, 9, 7, 0, 2, 1], [6, 3, 9, 2, 5, 2], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], [6, 3, 9, 2, 5, 2], [2, 5, 7, 9, 4, 8], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], [7, 2, 9, 5, 4, 1], [4, 9, 7, 0, 2, 1], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = 
flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.broadcast) y = x.to_global(placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[4, 6], [6, 8], [3, 7], [6, 8], [6, 8], [6, 8],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[5, 20], [9, 0], [5, 0], [9, 0], [9, 0], [6, 4],], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[8, 9], [4, 6], [3, 5], [8, 7], [4, 6], [5, 3],], dtype=np.float32, ), ) ) def _test_eager_boxing_with_same_placement_s0_to_s1(test_case, in_device, out_device): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) test_case.assertEqual(y.placement, placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [ [4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8], [9, 4], [7, 2], [6, 3], [3, 7], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [ [5, 20], [9, 0], [5, 0], [9, 0], [10, 7], [10, 5], [6, 9], [6, 4], [5, 8], [9, 5], [9, 2], [5, 8], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [ [8, 9], [4, 6], [3, 5], [8, 7], [10, 3], [5, 6], [8, 6], [5, 3], [9, 6], [4, 1], [5, 2], [9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_same_placement_s1_to_s1(test_case, in_device, out_device): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) z = y.to_global(placement, flow.sbp.split(1)) test_case.assertEqual(z.placement, placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 
6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8], [9, 4], [7, 2], [6, 3], [3, 7], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [5, 20], [9, 0], [5, 0], [9, 0], [10, 7], [10, 5], [6, 9], [6, 4], [5, 8], [9, 5], [9, 2], [5, 8], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [8, 9], [4, 6], [3, 5], [8, 7], [10, 3], [5, 6], [8, 6], [5, 3], [9, 6], [4, 1], [5, 2], [9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_same_placement_s1_to_s0(test_case, in_device, out_device): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) z = y.to_global(placement, flow.sbp.split(0)) test_case.assertEqual(z.placement, placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_same_placement_s1_to_p(test_case, in_device, out_device): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) z = y.to_global(placement, flow.sbp.partial_sum) test_case.assertEqual(z.placement, placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 0, 0, 0, 0], [6, 8, 0, 0, 0, 0], [3, 7, 0, 0, 0, 0], [6, 8, 0, 0, 0, 0], [2, 10, 0, 0, 0, 0], [3, 9, 0, 0, 0, 0], [4, 6, 0, 0, 0, 0], 
[6, 8, 0, 0, 0, 0], [9, 4, 0, 0, 0, 0], [7, 2, 0, 0, 0, 0], [6, 3, 0, 0, 0, 0], [3, 7, 0, 0, 0, 0], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [0, 0, 5, 20, 0, 0], [0, 0, 9, 0, 0, 0], [0, 0, 5, 0, 0, 0], [0, 0, 9, 0, 0, 0], [0, 0, 10, 7, 0, 0], [0, 0, 10, 5, 0, 0], [0, 0, 6, 9, 0, 0], [0, 0, 6, 4, 0, 0], [0, 0, 5, 8, 0, 0], [0, 0, 9, 5, 0, 0], [0, 0, 9, 2, 0, 0], [0, 0, 5, 8, 0, 0], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [0, 0, 0, 0, 8, 9], [0, 0, 0, 0, 4, 6], [0, 0, 0, 0, 3, 5], [0, 0, 0, 0, 8, 7], [0, 0, 0, 0, 10, 3], [0, 0, 0, 0, 5, 6], [0, 0, 0, 0, 8, 6], [0, 0, 0, 0, 5, 3], [0, 0, 0, 0, 9, 6], [0, 0, 0, 0, 4, 1], [0, 0, 0, 0, 5, 2], [0, 0, 0, 0, 9, 3], ], dtype=np.float32, ), ) ) def _test_eager_boxing_with_same_placement_s1_to_b(test_case, in_device, out_device): if flow.env.get_rank() == 0: np_arr = np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], ], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [ [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], ], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [ [9, 6, 5, 8, 3, 6], [4, 9, 7, 0, 2, 1], [2, 5, 7, 9, 4, 8], [6, 8, 10, 0, 4, 9], ], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [ [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ) device = flow.device(in_device) tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement(in_device, ranks=[0, 1, 3]) x = tensor.to_global(placement, flow.sbp.split(0)) y = x.to_global(placement, flow.sbp.split(1)) z = y.to_global(placement, flow.sbp.broadcast) test_case.assertEqual(z.placement, placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( z.to_local().numpy(), np.array( [ [4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5], [6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6], [4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6], [7, 2, 9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3], ], dtype=np.float32, ), ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def _test_eager_boxing_b_to_s( test_case, shape, device_type, in_device_list, out_device_list, out_split_axis ): np_arr = np.random.uniform(-1e-05, 1e-05, shape) # use cuda to avoid slice boxing here placement_with_all_cuda_device = flow.placement.all("cuda") x = flow.tensor(np_arr, device="cuda", dtype=flow.float32) x = x.to_global(placement_with_all_cuda_device, flow.sbp.broadcast) placement = flow.placement(device_type, 
in_device_list) y = x.to_global(placement, flow.sbp.broadcast) new_placement = flow.placement(device_type, out_device_list) z = y.to_global(new_placement, flow.sbp.split(out_split_axis)) if flow.env.get_rank() in out_device_list: idx = out_device_list.index(flow.env.get_rank()) step = int(shape[out_split_axis] / len(out_device_list)) if out_split_axis == 0: test_case.assertTrue( np.allclose( z.to_local().numpy(), x.to_local().numpy()[idx * step : (idx + 1) * step], 1e-5, 1e-5, ) ) elif out_split_axis == 1: test_case.assertTrue( np.allclose( z.to_local().numpy(), x.to_local().numpy()[..., idx * step : (idx + 1) * step], 1e-5, 1e-5, ) ) else: raise "only test case with out_split_axis == 0 or out_split_axis == 1" @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def _test_eager_boxing_s_to_b( test_case, shape, device_type, in_device_list, out_device_list, in_split_axis ): np_arr = np.random.uniform(-1e-05, 1e-05, shape) # use cuda to avoid slice boxing here placement_with_all_cuda_device = flow.placement.all("cuda") x = flow.tensor(np_arr, device="cuda", dtype=flow.float32) x = x.to_global(placement_with_all_cuda_device, flow.sbp.broadcast) placement = flow.placement(device_type, in_device_list) y = x.to_global(placement, flow.sbp.broadcast) y = y.to_global(placement, flow.sbp.split(in_split_axis)) new_placement = flow.placement(device_type, out_device_list) z = y.to_global(new_placement, flow.sbp.broadcast) if flow.env.get_rank() in out_device_list: test_case.assertTrue( np.allclose(z.to_local().numpy(), x.to_local().numpy(), 1e-5, 1e-5,) ) test_case.assertEqual(z.placement, new_placement) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def _test_eager_boxing_p_to_s( test_case, shape, device_type, in_device_list, out_device_list, out_split_axis ): np_arr = np.random.uniform(-1e-05, 1e-05, shape) # use cuda to avoid slice boxing here placement_with_all_cuda_device = flow.placement.all("cuda") x = flow.tensor(np_arr, device="cuda", dtype=flow.float32) x = x.to_global(placement_with_all_cuda_device, flow.sbp.broadcast) placement = flow.placement(device_type, in_device_list) y = x.to_global(placement, flow.sbp.broadcast) y = y.to_global(placement, flow.sbp.partial_sum) new_placement = flow.placement(device_type, out_device_list) z = y.to_global(new_placement, flow.sbp.split(out_split_axis)) if flow.env.get_rank() in out_device_list: idx = out_device_list.index(flow.env.get_rank()) step = int(shape[out_split_axis] / len(out_device_list)) if out_split_axis == 0: test_case.assertTrue( np.allclose( z.to_local().numpy(), x.to_local().numpy()[idx * step : (idx + 1) * step], 1e-5, 1e-5, ) ) elif out_split_axis == 1: test_case.assertTrue( np.allclose( z.to_local().numpy(), x.to_local().numpy()[..., idx * step : (idx + 1) * step], 1e-5, 1e-5, ) ) else: raise "only test case with out_split_axis == 0 or out_split_axis == 1" @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def _test_eager_boxing_p_to_b( test_case, shape, device_type, in_device_list, out_device_list ): np_arr = np.random.uniform(-1e-05, 1e-05, shape) # use cuda to avoid slice boxing here placement_with_all_cuda_device = flow.placement.all("cuda") x = flow.tensor(np_arr, device="cuda", dtype=flow.float32) x = x.to_global(placement_with_all_cuda_device, flow.sbp.broadcast) placement = flow.placement(device_type, in_device_list) y = x.to_global(placement, flow.sbp.broadcast) y = y.to_global(placement, flow.sbp.partial_sum) new_placement = 
flow.placement(device_type, out_device_list) z = y.to_global(new_placement, flow.sbp.broadcast) if flow.env.get_rank() in out_device_list: test_case.assertTrue( np.allclose(z.to_local().numpy(), x.to_local().numpy(), 1e-5, 1e-5,) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def _test_eager_naive_boxing_s_to_s( test_case, device_type, shape, in_device_list, out_device_list, in_split_axis, out_split_axis, ): np_arr = np.random.uniform(-1e-05, 1e-05, shape) placement_with_all_cuda_device = flow.placement.all(device_type) x = flow.tensor(np_arr, device=device_type, dtype=flow.float32) x = x.to_global(placement_with_all_cuda_device, flow.sbp.broadcast) placement = flow.placement(device_type, in_device_list) y = x.to_global(placement, flow.sbp.broadcast) y = y.to_global(placement, flow.sbp.split(in_split_axis)) new_placement = flow.placement(device_type, out_device_list) z = y.to_global(new_placement, flow.sbp.split(out_split_axis)) if flow.env.get_rank() in out_device_list: idx = out_device_list.index(flow.env.get_rank()) step = int(shape[out_split_axis] / len(out_device_list)) if out_split_axis == 0: test_case.assertTrue( np.allclose( z.to_local().numpy(), x.to_local().numpy()[idx * step : (idx + 1) * step], 1e-5, 1e-5, ) ) elif out_split_axis == 1: test_case.assertTrue( np.allclose( z.to_local().numpy(), x.to_local().numpy()[..., idx * step : (idx + 1) * step], 1e-5, 1e-5, ) ) else: raise "only test case with out_split_axis == 0 or out_split_axis == 1" test_case.assertEqual(z.placement, new_placement) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingWithNonOverlappingPlacement(flow.unittest.TestCase): def test_eager_boxing_with_non_overlapping_placement_p_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_non_overlapping_placement_p_to_s1(test_case, *arg) def test_eager_boxing_with_non_overlapping_placement_b_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_non_overlapping_placement_b_to_s1(test_case, *arg) def test_eager_boxing_with_non_overlapping_placement_s0_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_non_overlapping_placement_s0_to_s1(test_case, *arg) def test_eager_boxing_with_non_overlapping_placement_s1_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_non_overlapping_placement_s1_to_s1(test_case, *arg) def test_eager_boxing_with_non_overlapping_placement_s1_to_s0(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_non_overlapping_placement_s1_to_s0(test_case, *arg) def test_eager_boxing_with_non_overlapping_placement_s1_to_b(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_non_overlapping_placement_s1_to_b(test_case, *arg) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_eager_boxing_with_non_overlapping_placement_s1_to_p(test_case): arg_dict = 
OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_non_overlapping_placement_s1_to_p(test_case, *arg) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingWithOverlappingPlacement(flow.unittest.TestCase): def test_eager_boxing_with_overlapping_placement_p_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_overlapping_placement_p_to_s1(test_case, *arg) def test_eager_boxing_with_overlapping_placement_b_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_overlapping_placement_b_to_s1(test_case, *arg) def test_eager_boxing_with_overlapping_placement_s0_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_overlapping_placement_s0_to_s1(test_case, *arg) def test_eager_boxing_with_overlapping_placement_s1_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_overlapping_placement_s1_to_s1(test_case, *arg) def test_eager_boxing_with_overlapping_placement_s1_to_s0(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_overlapping_placement_s1_to_s0(test_case, *arg) def test_eager_boxing_with_overlapping_placement_s1_to_b(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_overlapping_placement_s1_to_b(test_case, *arg) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_eager_boxing_with_overlapping_placement_s1_to_p(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_overlapping_placement_s1_to_p(test_case, *arg) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingWithInPlacementContainOutPlacement(flow.unittest.TestCase): def test_eager_boxing_with_in_placement_contain_out_placement_p_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_in_placement_contain_out_placement_p_to_s1( test_case, *arg ) def test_eager_boxing_with_in_placement_contain_out_placement_b_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_in_placement_contain_out_placement_b_to_s1( test_case, *arg ) def test_eager_boxing_with_in_placement_contain_out_placement_s0_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_in_placement_contain_out_placement_s0_to_s1( test_case, *arg ) def test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = 
["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s1( test_case, *arg ) def test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s0(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s0( test_case, *arg ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_eager_boxing_with_in_placement_contain_out_placement_s1_to_p(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_p( test_case, *arg ) def test_eager_boxing_with_in_placement_contain_out_placement_s1_to_b(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_b( test_case, *arg ) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingWithOutPlacementContainInPlacement(flow.unittest.TestCase): def test_eager_boxing_with_out_placement_contain_in_placement_p_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_out_placement_contain_in_placement_p_to_s1( test_case, *arg ) def test_eager_boxing_with_out_placement_contain_in_placement_b_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_out_placement_contain_in_placement_b_to_s1( test_case, *arg ) def test_eager_boxing_with_out_placement_contain_in_placement_s0_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_out_placement_contain_in_placement_s0_to_s1( test_case, *arg ) def test_eager_boxing_with_out_placement_contain_in_placement_s1_to_b(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_b( test_case, *arg ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_eager_boxing_with_out_placement_contain_in_placement_s1_to_p(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_p( test_case, *arg ) def test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s0(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s0( test_case, *arg ) def test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s1( test_case, *arg ) @flow.unittest.skip_unless_1n4d() class 
TestEagerBoxingWithSameInOutPlacement(flow.unittest.TestCase): def test_eager_boxing_with_same_placement_s0_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_same_placement_s0_to_s1(test_case, *arg) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_eager_boxing_with_same_placement_p_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_same_placement_p_to_s1(test_case, *arg) def test_eager_boxing_with_same_placement_b_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_same_placement_b_to_s1(test_case, *arg) def test_eager_boxing_with_same_placement_s1_to_s1(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_same_placement_s1_to_s1(test_case, *arg) def test_eager_boxing_with_same_placement_s1_to_s0(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_same_placement_s1_to_s0(test_case, *arg) def test_eager_boxing_with_same_placement_s1_to_p(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_same_placement_s1_to_p(test_case, *arg) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_eager_boxing_with_same_placement_s1_to_b(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_with_same_placement_s1_to_b(test_case, *arg) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingBToS(flow.unittest.TestCase): def test_eager_boxing_b_to_s(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(12, 12), (18, 24)] arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["in_device_list"] = [[0, 1], [1, 2, 3]] arg_dict["out_device_list"] = [[2, 3], [0, 1, 3]] arg_dict["out_split_axis"] = [0, 1] for arg in GenArgList(arg_dict): _test_eager_boxing_b_to_s(test_case, *arg) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingPToS(flow.unittest.TestCase): def test_eager_boxing_p_to_s(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(12, 12), (18, 24)] arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["in_device_list"] = [[0, 1], [1, 2, 3]] arg_dict["out_device_list"] = [[2, 3], [0, 1, 3]] arg_dict["out_split_axis"] = [0, 1] for arg in GenArgList(arg_dict): _test_eager_boxing_p_to_s(test_case, *arg) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingSToB(flow.unittest.TestCase): def test_eager_boxing_s_to_b(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(12, 12), (12, 18, 24)] arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["in_device_list"] = [[0, 1], [1, 2, 3]] arg_dict["out_device_list"] = [[2, 3], [0, 1, 3]] arg_dict["in_split_axis"] = [0, 1] for arg in GenArgList(arg_dict): _test_eager_boxing_s_to_b(test_case, *arg) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingPToB(flow.unittest.TestCase): def test_eager_boxing_p_to_b(test_case): arg_dict = 
OrderedDict() arg_dict["shape"] = [(12, 12), (12, 18, 24)] arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["in_device_list"] = [[0, 1], [1, 2, 3]] arg_dict["out_device_list"] = [[2, 3], [0, 1, 3]] for arg in GenArgList(arg_dict): _test_eager_boxing_p_to_b(test_case, *arg) @flow.unittest.skip_unless_1n4d() class TestEagerNaiveBoxingSToS(flow.unittest.TestCase): def test_eager_naive_boxing_s_to_s(test_case): arg_dict = OrderedDict() arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["shape"] = [(12, 12), (18, 24)] arg_dict["in_device_list"] = [[0, 1], [1, 2, 3]] arg_dict["out_device_list"] = [[1], [3], [2, 3], [0, 1, 3]] arg_dict["in_split_axis"] = [0, 1] arg_dict["out_split_axis"] = [0, 1] for arg in GenArgList(arg_dict): _test_eager_naive_boxing_s_to_s(test_case, *arg) @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestEagerGlobalCastWithSamePlacementAndSBP(flow.unittest.TestCase): def test_eager_global_cast_with_same_placement_and_sbp(test_case): x = np.ones((4, 8), dtype=np.int32) placement = flow.placement("cuda", ranks=[0, 1]) y = flow.tensor( x, dtype=flow.float32, placement=placement, sbp=[flow.sbp.split(0)], requires_grad=False, ) z = y.to_global(placement=placement, sbp=[flow.sbp.split(0)]) test_case.assertEqual(y.global_id(), z.global_id()) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestEagerGlobalCast1DTo2DSBP(flow.unittest.TestCase): def test_eager_global_cast_1d_to_2d_sbp(test_case): x = np.ones((4, 8), dtype=np.int32) placement1 = flow.placement("cuda", ranks=[0, 1, 2, 3]) placement2 = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) y = flow.tensor( x, dtype=flow.float32, placement=placement1, sbp=[flow.sbp.split(0)], requires_grad=False, ) z = y.to_global( placement=placement2, sbp=[flow.sbp.broadcast, flow.sbp.split(0)] ) test_case.assertEqual(z.placement, placement2) test_case.assertTrue( np.array_equal(z.to_local().numpy(), np.ones((2, 8), dtype=np.int32),) ) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestEagerGlobalCast2DTo1DSBP(flow.unittest.TestCase): def test_eager_global_cast_2d_to_1d_sbp(test_case): x = np.ones((4, 8), dtype=np.int32) placement1 = flow.placement("cuda", ranks=[0, 1, 2, 3]) placement2 = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) y = flow.tensor( x, dtype=flow.float32, placement=placement2, sbp=[flow.sbp.broadcast, flow.sbp.split(0)], requires_grad=False, ) z = y.to_global(placement=placement1, sbp=[flow.sbp.split(0)]) test_case.assertEqual(z.placement, placement1) test_case.assertTrue( np.array_equal(z.to_local().numpy(), np.ones((1, 8), dtype=np.int32),) ) def _test_eager_global_cast_1d_uneven_split(test_case, device_type, shape): np_arr = np.random.uniform(-1e-05, 1e-05, shape) placement = flow.placement(device_type, range(flow.env.get_world_size())) x = flow.tensor( np_arr, dtype=flow.float32, device=device_type, requires_grad=False, ) x = x.to_global(placement=placement, sbp=[flow.sbp.broadcast]) # B To S(0) y = x.to_global(placement=placement, sbp=[flow.sbp.split(0)]) from oneflow.framework import balanced_splitter as balanced_splitter s0_balanced_ranges = balanced_splitter.BalancedRanges( shape[0], flow.env.get_world_size() ) s0_range_of_this_rank = s0_balanced_ranges[flow.env.get_rank()] test_case.assertEqual(y.placement, placement) test_case.assertTrue( np.array_equal( y.to_local().numpy(), 
x.to_local().numpy()[s0_range_of_this_rank[0] : s0_range_of_this_rank[1]], ) ) # S(0) To S(1) z = y.to_global(placement=placement, sbp=[flow.sbp.split(1)]) s1_balanced_ranges = flow.framework.balanced_splitter.BalancedRanges( shape[1], flow.env.get_world_size() ) s1_range_of_this_rank = s1_balanced_ranges[flow.env.get_rank()] test_case.assertEqual(z.placement, placement) test_case.assertTrue( np.allclose( z.to_local().numpy(), x.to_local().numpy()[ ..., s1_range_of_this_rank[0] : s1_range_of_this_rank[1] ], ) ) # S(1) To B w = z.to_global(placement=placement, sbp=[flow.sbp.broadcast]) test_case.assertEqual(w.placement, placement) test_case.assertTrue(np.allclose(w.to_local().numpy(), x.to_local().numpy())) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestEagerGlobalCastOneDUnevenSplit(flow.unittest.TestCase): def test_eager_global_cast_1d_uneven_split(test_case): arg_dict = OrderedDict() arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["shape"] = [(25, 33), (13, 17)] for arg in GenArgList(arg_dict): _test_eager_global_cast_1d_uneven_split(test_case, *arg) def _test_eager_global_n_dim_reduce(test_case, device_type, src_sbp, dst_sbp): np.random.seed(10) np_arr = np.random.uniform(-1e-05, 1e-05, (16, 32)) placement0 = flow.placement(device_type, ranks=[[0]]) placement1 = flow.placement(device_type, ranks=[[0, 1], [2, 3]]) # oneflow.placement(type="cuda", ranks=[[0]]) # (src_sbp, src_sbp) x = flow.tensor( np_arr, placement=placement0, sbp=[src_sbp, src_sbp], requires_grad=False, ) # oneflow.placement(type="cuda", ranks=[[0,1],[2,3]]) # (dst_sbp, dst_sbp) y = x.to_global(placement=placement1, sbp=[dst_sbp, dst_sbp]) z = y.to_global(placement=placement1, sbp=[flow.sbp.broadcast, flow.sbp.broadcast]) test_case.assertEqual(z.placement, placement1) test_case.assertTrue(np.allclose(z.to_local().numpy(), np_arr)) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestEagerGlobalCastNDimReduceBoxing(flow.unittest.TestCase): def test_eager_global_n_dim_reduce(test_case): arg_dict = OrderedDict() arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["src_sbp"] = [flow.sbp.broadcast, flow.sbp.split(0), flow.sbp.split(1)] arg_dict["dst_sbp"] = [flow.sbp.broadcast, flow.sbp.split(0), flow.sbp.split(1)] for arg in GenArgList(arg_dict): _test_eager_global_n_dim_reduce(test_case, *arg) def _test_eager_global_with_0_size_data( test_case, shape, in_device_type, out_device_type, in_device_list, out_device_list, in_sbp, out_sbp, ): in_placement = flow.placement(in_device_type, in_device_list) out_placement = flow.placement(out_device_type, out_device_list) x = flow.Tensor(*shape, placement=in_placement, sbp=in_sbp) y = x.to_global(out_placement, out_sbp) test_case.assertEqual(y.placement, out_placement) test_case.assertEqual(y.sbp, out_sbp) test_case.assertEqual(y.size(), shape) @flow.unittest.skip_unless_1n4d() class TestEagerNaiveBoxingSToS(flow.unittest.TestCase): def test_eager_global_with_0_size_data(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(8, 0, 4), (5, 0, 7)] arg_dict["in_device_type"] = ["cpu", "cuda"] arg_dict["out_device_type"] = ["cpu", "cuda"] arg_dict["in_device_list"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]] arg_dict["out_device_list"] = [[1], [3], [2, 3], [0, 1, 3], [0, 1, 2, 3]] arg_dict["in_sbp"] = [ (flow.sbp.split(0),), (flow.sbp.split(2),), (flow.sbp.broadcast,), (flow.sbp.partial_sum,), ] arg_dict["out_sbp"] = [ (flow.sbp.split(0),), 
(flow.sbp.split(2),), (flow.sbp.broadcast,), (flow.sbp.partial_sum,), ] for arg in GenArgList(arg_dict): _test_eager_global_with_0_size_data(test_case, *arg) def _test_eager_boxing_one_to_n_with_diff_dim( test_case, in_device_type, out_device_type ): x = flow.tensor( [1, 2, 3, 4], sbp=flow.sbp.broadcast, placement=flow.placement(in_device_type, ranks=[0]), ) y = x.to_global( sbp=[flow.sbp.broadcast, flow.sbp.split(0)], placement=flow.placement(out_device_type, ranks=[[0, 1], [2, 3]]), ) rank = flow.env.get_rank() if rank == 0 or rank == 2: test_case.assertTrue(np.array_equal(y.to_local().numpy(), np.array([1, 2]),)) elif rank == 1 or rank == 3: test_case.assertTrue(np.array_equal(y.to_local().numpy(), np.array([3, 4]),)) def _test_eager_boxing_n_to_one_with_diff_dim( test_case, in_device_type, out_device_type ): x = flow.tensor( [1, 2, 3, 4], sbp=[flow.sbp.broadcast, flow.sbp.split(0)], placement=flow.placement(in_device_type, ranks=[[0, 1], [2, 3]]), ) y = x.to_global( sbp=flow.sbp.broadcast, placement=flow.placement(out_device_type, ranks=[0]) ) rank = flow.env.get_rank() if rank == 0: test_case.assertTrue( np.array_equal(y.to_local().numpy(), np.array([1, 2, 3, 4]),) ) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingOneToNWithDiffDim(flow.unittest.TestCase): def test_eager_boxing_one_to_n_with_diff_dim(test_case): arg_dict = OrderedDict() arg_dict["in_device_type"] = ["cpu", "cuda"] arg_dict["out_device_type"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_one_to_n_with_diff_dim(test_case, *arg) @flow.unittest.skip_unless_1n4d() class TestEagerBoxingNToOneWithDiffDim(flow.unittest.TestCase): def test_eager_boxing_n_to_one_with_diff_dim(test_case): arg_dict = OrderedDict() arg_dict["in_device_type"] = ["cpu", "cuda"] arg_dict["out_device_type"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_one_to_n_with_diff_dim(test_case, *arg) def _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement( test_case, in_sbp, out_sbp, shape, in_device_type, out_device_type, in_device_list, out_device_list, ): if not isinstance(in_sbp, tuple): in_sbp = (in_sbp,) if not isinstance(out_sbp, tuple): out_sbp = (out_sbp,) in_placement = flow.placement(type=in_device_type, ranks=in_device_list) out_placement = flow.placement(type=out_device_type, ranks=out_device_list) np_arr = np.random.uniform(-1e-05, 1e-05, shape) x = flow.tensor( np_arr, dtype=flow.float32, device=in_device_type, requires_grad=False, ) x = x.to_global(in_placement, in_sbp) y = x.to_global(out_placement, out_sbp) test_case.assertTrue(y.sbp == out_sbp) test_case.assertTrue(y.placement == out_placement) test_case.assertTrue(np.allclose(x.numpy(), y.numpy())) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestEagerBoxingAsymmetricMix1d2dWithRandomPlacement(flow.unittest.TestCase): def test_eager_boxing_asymmetric_mix_1d_2d_with_random_placement(test_case): arg_dict = OrderedDict() sbp_dict = OrderedDict() arg_dict["shape"] = [(12, 24), (17, 13, 19)] arg_dict["in_device_type"] = ["cpu", "cuda"] arg_dict["out_device_type"] = ["cpu", "cuda"] arg_dict["in_device_list"] = [ [2], [0, 1], [1, 2, 3], [0, 1, 2, 3], [[0, 1, 2, 3]], [[0, 1], [2, 3]], ] arg_dict["out_device_list"] = [ [1], [3], [2, 3], [0, 1, 3], [0, 1, 2, 3], [[2], [3]], [[0, 1], [2, 3]], ] sbp_1d = [ flow.sbp.split(0), flow.sbp.split(1), flow.sbp.broadcast, flow.sbp.partial_sum, ] sbp_dict["in_sbp_1d"] = sbp_1d sbp_dict["out_sbp_1d"] = sbp_1d import itertools 
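# A 2-D SBP signature is a pair of 1-D signatures, one per level of the rank hierarchy,
# so the candidate set below is built as the Cartesian product of the 1-D signatures.
# For example, over ranks=[[0, 1], [2, 3]] the signature (split(0), broadcast) splits
# axis 0 across the two rank groups and broadcasts the data within each group.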
sbp_2d = list(itertools.product(sbp_1d, sbp_1d)) sbp_dict["in_sbp_2d"] = sbp_2d sbp_dict["out_sbp_2d"] = sbp_2d is_2d_device_list = lambda x: isinstance(x[0], list) for arg in GenArgList(arg_dict): in_device_list = arg[-2] out_device_list = arg[-1] is_in_2d_n_device_list = is_2d_device_list(in_device_list) is_out_2d_n_device_list = is_2d_device_list(out_device_list) if is_in_2d_n_device_list and is_out_2d_n_device_list: for in_sbp in sbp_dict["in_sbp_2d"]: for out_sbp in sbp_dict["out_sbp_2d"]: _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement( test_case, in_sbp, out_sbp, *arg ) elif is_in_2d_n_device_list and not is_out_2d_n_device_list: for in_sbp in sbp_dict["in_sbp_2d"]: for out_sbp in sbp_dict["out_sbp_1d"]: _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement( test_case, in_sbp, out_sbp, *arg ) elif not is_in_2d_n_device_list and is_out_2d_n_device_list: for in_sbp in sbp_dict["in_sbp_1d"]: for out_sbp in sbp_dict["out_sbp_2d"]: _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement( test_case, in_sbp, out_sbp, *arg ) elif not is_in_2d_n_device_list and not is_out_2d_n_device_list: for in_sbp in sbp_dict["in_sbp_1d"]: for out_sbp in sbp_dict["out_sbp_1d"]: _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement( test_case, in_sbp, out_sbp, *arg ) else: raise NotImplementedError @flow.unittest.skip_unless_1n4d() class TestEagerBoxing2DLocalToGlobalWithBalancedSplitSize(flow.unittest.TestCase): def test_eager_boxing_2d_local_to_globa_with_balanced_size(test_case): placement = flow.placement(type="cpu", ranks=np.arange(4).reshape((2, 2))) sbp = (flow.sbp.split(0), flow.sbp.split(1)) x = flow.tensor(np.arange(25).reshape((5, 5)), placement=placement, sbp=sbp) y = x.to_local() z = y.to_global(placement=placement, sbp=sbp) test_case.assertEqual(z.placement, placement) test_case.assertEqual(z.sbp, sbp) test_case.assertEqual(z.size(), (5, 5)) test_case.assertTrue( np.allclose(z.numpy(), np.arange(25).reshape((5, 5)), 1e-5, 1e-5) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_eager_boxing_exhaustive.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import itertools import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * def _test_eager_boxing_normal_1d_exhaustive_testing( test_case, shape, in_device, out_device, in_device_list, out_device_list ): sbps = [ flow.sbp.split(0), flow.sbp.split(1), flow.sbp.broadcast, flow.sbp.partial_sum, ] in_placement = flow.placement(type=in_device, ranks=in_device_list) out_placement = flow.placement(type=out_device, ranks=out_device_list) rand_tensor = random_tensor(len(shape), *shape, requires_grad=False).oneflow for elem in itertools.product(sbps, sbps): x = rand_tensor.to_global(placement=in_placement, sbp=elem[0]) y = x.to_global(placement=out_placement, sbp=elem[1]) test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-3, 1e-3)) def _test_eager_boxing_symmetric_2d_exhaustive_testing( test_case, in_device, out_device ): sbps = [ flow.sbp.split(0), flow.sbp.split(1), flow.sbp.broadcast, flow.sbp.partial_sum, ] nd_sbps = itertools.product( itertools.product(sbps, sbps), itertools.product(sbps, sbps) ) shape = (8, 8, 16) in_placement = flow.placement(type=in_device, ranks=[[0, 1], [2, 3]]) out_placement = flow.placement(type=out_device, ranks=[[0, 1], [2, 3]]) rand_tensor = random_tensor(len(shape), *shape, requires_grad=False).oneflow for elem in nd_sbps: x = rand_tensor.to_global(placement=in_placement, sbp=elem[0]) y = x.to_global(placement=out_placement, sbp=elem[1]) test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-3, 1e-3)) def _test_eager_boxing_1d_special_split_axis( test_case, in_device, out_device, in_device_list, out_device_list ): sbps = [ flow.sbp.split(2), flow.sbp.split(3), flow.sbp.broadcast, flow.sbp.partial_sum, ] shape = (4, 4, 5, 7) in_placement = flow.placement(type=in_device, ranks=in_device_list) out_placement = flow.placement(type=out_device, ranks=out_device_list) rand_tensor = random_tensor(len(shape), *shape, requires_grad=False).oneflow for elem in itertools.product(sbps, sbps): x = rand_tensor.to_global(placement=in_placement, sbp=elem[0]) y = x.to_global(placement=out_placement, sbp=elem[1]) test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-3, 1e-3)) def _test_eager_boxing_2d_special_split_axis(test_case, in_device, out_device): sbps = [ flow.sbp.split(2), flow.sbp.split(4), flow.sbp.broadcast, flow.sbp.partial_sum, ] nd_sbps = itertools.product( itertools.product(sbps, sbps), itertools.product(sbps, sbps) ) shape = (4, 8, 4, 8, 4) in_placement = flow.placement(type=in_device, ranks=[[0, 1], [2, 3]]) out_placement = flow.placement(type=out_device, ranks=[[0, 1], [2, 3]]) rand_tensor = random_tensor(len(shape), *shape, requires_grad=False).oneflow for elem in nd_sbps: x = rand_tensor.to_global(placement=in_placement, sbp=elem[0]) y = x.to_global(placement=out_placement, sbp=elem[1]) test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-3, 1e-3)) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestEagerBoxingSymmetricExhaustiveTesting(flow.unittest.TestCase): @globaltest def test_eager_boxing_normal_1d_exhaustive_testing(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(4, 4), (6, 8), (5, 7)] arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] arg_dict["in_device_list"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]] arg_dict["out_device_list"] = [[0, 1, 3], [0, 1, 2, 3]] for 
arg in GenArgList(arg_dict): _test_eager_boxing_normal_1d_exhaustive_testing(test_case, *arg) @globaltest def test_eager_boxing_symmetric_2d_exhaustive_testing(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_symmetric_2d_exhaustive_testing(test_case, *arg) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestEagerBoxingSpecialSplitAxisExhaustiveTesting(flow.unittest.TestCase): @globaltest def test_eager_boxing_1d_special_split_axis(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] arg_dict["in_device_list"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]] arg_dict["out_device_list"] = [[0, 1, 3], [0, 1, 2, 3]] for arg in GenArgList(arg_dict): _test_eager_boxing_1d_special_split_axis(test_case, *arg) @globaltest def test_eager_boxing_2d_special_split_axis(test_case): arg_dict = OrderedDict() arg_dict["in_device"] = ["cpu", "cuda"] arg_dict["out_device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_eager_boxing_2d_special_split_axis(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_empty.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow from oneflow.test_utils.test_util import GenArgDict def _test_local_empty(test_case, shape, dtype, device, requires_grad): x = flow.empty( shape, dtype=dtype, device=flow.device(device), requires_grad=requires_grad if dtype == flow.float32 else False, ) test_case.assertFalse(x.is_global) test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.device, flow.device(device)) if dtype == flow.float32: test_case.assertEqual(x.requires_grad, requires_grad) empty_like_x = flow.empty_like( x, dtype=dtype, device=flow.device(device), requires_grad=requires_grad if dtype == flow.float32 else False, ) test_case.assertFalse(empty_like_x.is_global) test_case.assertEqual(empty_like_x.shape, flow.Size(shape)) test_case.assertEqual(empty_like_x.dtype, dtype) test_case.assertEqual(empty_like_x.device, flow.device(device)) if dtype == flow.float32: test_case.assertEqual(empty_like_x.requires_grad, requires_grad) def _test_new_empty(test_case, shape, dtype, device, requires_grad): x = flow.empty(shape, dtype=dtype, device=flow.device(device)) y = x.new_empty( shape, dtype=dtype, device=flow.device(device), requires_grad=requires_grad if dtype == flow.float32 else False, ) test_case.assertFalse(y.is_global) test_case.assertEqual(y.shape, flow.Size(shape)) test_case.assertEqual(y.dtype, dtype) test_case.assertEqual(y.device, flow.device(device)) if dtype == flow.float32: test_case.assertEqual(y.requires_grad, requires_grad) y = x.new_empty(*shape) test_case.assertFalse(y.is_global) test_case.assertEqual(y.shape, flow.Size(shape)) test_case.assertEqual(y.dtype, x.dtype) test_case.assertEqual(y.device, x.device) test_case.assertFalse(y.requires_grad) def _test_local_empty_strided(test_case, shape, stride, dtype, device, requires_grad): x = flow.empty_strided( shape, stride, dtype=dtype, device=flow.device(device), requires_grad=requires_grad, ) test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.stride(), stride) test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.device, flow.device(device)) if dtype == flow.float32: test_case.assertEqual(x.requires_grad, requires_grad) @flow.unittest.skip_unless_1n1d() class TestEmptyOp(flow.unittest.TestCase): def test_local_empty(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["dtype"] = [flow.float32, flow.float16, flow.int32] arg_dict["device"] = ["cpu", "cuda"] arg_dict["requires_grad"] = [True, False] for arg in GenArgDict(arg_dict): _test_local_empty(test_case, **arg) _test_new_empty(test_case, **arg) def test_local_empty_strided(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 3, 6), (2, 3, 12, 4)] arg_dict["stride"] = [(1, 2), (1, 2, 3), (2, 4, 5, 1)] arg_dict["dtype"] = [flow.float32, flow.float16, flow.int32] arg_dict["device"] = ["cpu", "cuda"] arg_dict["requires_grad"] = [True, False] for arg in GenArgDict(arg_dict): _test_local_empty_strided(test_case, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_eq.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestEq(flow.unittest.TestCase): @autotest(n=5, auto_backward=False, check_graph=True) def test_eq_with_0_size_data(test_case): device = random_device() x = random_tensor(3, 2, 0, 3).to(device) y = random_tensor(3, 2, 0, 3).to(device) z = torch.eq(x, y) return z @autotest(n=5, auto_backward=False, check_graph=True) def test_eq_with_0shape_0d_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) z = torch.eq(x, y) return z @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_eq_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = random_tensor(len(shape), *shape, requires_grad=False).to(device) return torch.eq(x, oneof(y, random().to(int), random().to(float))) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_tensor_eq_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = random_tensor(len(shape), *shape, requires_grad=False).to(device) return x.eq(oneof(y, random().to(int), random().to(float))) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_eq_with_random_0d_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(ndim=0, requires_grad=False).to(device) y = random_tensor(ndim=0, requires_grad=False).to(device) return torch.eq(x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_eq_with_same_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to(device) return torch.eq(x, x) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_eq_bool_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) y = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) return torch.eq(x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_eq_with_same_random_0d_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(ndim=0, requires_grad=False).to(device) return torch.eq(x, x) @profile(torch.eq) def profile_eq(test_case): input1 = torch.ones(1000, 1280) input2 = torch.ones(1000, 1280) torch.eq(input1, input2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_equal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import torch as torch_original from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestEqual(flow.unittest.TestCase): @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True) def test_eq_with_0_size_data(test_case): device = random_device() x = random_tensor(3, 2, 0, 3).to(device) y = random_tensor(3, 2, 0, 3).to(device) z = torch.equal(x, y) return z @autotest(n=5, auto_backward=False, check_graph=False) def test_equal_with_0shape_0d_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) z = torch.equal(x, y) return z @autotest(n=5, auto_backward=False, check_graph=False) def test_flow_equal_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = random_tensor(len(shape), *shape, requires_grad=False).to(device) return torch.equal(x, y) @autotest(n=5, auto_backward=False, check_graph=False) def test_flow_tensor_equal_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = random_tensor(len(shape), *shape, requires_grad=False).to(device) return x.equal(y) @autotest(n=5, auto_backward=False, check_graph=False) def test_flow_equal_with_random_0d_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(ndim=0, requires_grad=False).to(device) y = random_tensor(ndim=0, requires_grad=False).to(device) return torch.equal(x, y) @autotest(n=5, auto_backward=False, check_graph=False) def test_flow_equal_with_same_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to(device) return torch.equal(x, x) @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True) def test_flow_equal_complex_with_same_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( device ) return torch.equal(x, x) @autotest(n=5, auto_backward=False, check_graph=False) def test_flow_equal_bool_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) y = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) return torch.equal(x, y) @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True) def test_flow_equal_complex_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False, 
dtype=complex).to( device=device ) y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( device=device ) return torch.equal(x, y) @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True) def test_flow_not_equal_complex_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( device=device ) y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( device=device ) return torch.not_equal(x, y) @autotest(n=5, auto_backward=False, check_graph=False) def test_flow_equal_with_same_random_0d_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(ndim=0, requires_grad=False).to(device) return torch.equal(x, x) @profile(torch.equal) def profile_equal(test_case): input1 = torch.ones(1000, 1280) input2 = torch.ones(1000, 1280) torch.equal(input1, input2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_erf.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from scipy import special from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestErfModule(flow.unittest.TestCase): @autotest(n=5) def test_flow_erf_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.erf(x) return y @autotest(n=5) def test_flow_erf_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.erf(x) return y @profile(torch.erf) def profile_erf(test_case): torch.erf(torch.ones(100000)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_erfc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from scipy import special from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestErfcModule(flow.unittest.TestCase): @autotest(n=5) def test_flow_erfc_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.erfc(x) return y @autotest(n=5) def test_flow_erfc_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.erfc(x) return y @profile(torch.erfc) def profile_erfc(test_case): torch.erfc(torch.ones(100000)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_erfinv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from scipy import special from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_flow_erfinv_with_inf_data(test_case, device): x = flow.tensor(np.ones((5, 5)), dtype=flow.float32, device=flow.device(device)) of_out = flow.erfinv(x) np_out = np.full((5, 5), np.inf) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_flow_erfinv_with_nan_data(test_case, device): x = flow.tensor( np.arange(2, 22).reshape(4, 5), dtype=flow.float32, device=flow.device(device) ) of_out = flow.erfinv(x) np_out = np.full((4, 5), np.nan) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out, equal_nan=True)) @flow.unittest.skip_unless_1n1d() class TestErfinvModule(flow.unittest.TestCase): def test_flow_erfinv(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_flow_erfinv_with_inf_data, _test_flow_erfinv_with_nan_data, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(check_graph=True, auto_backward=False) def test_flow_erfinv_with_random_data(test_case): device = random_device() x = random_tensor(requires_grad=False).to(device) y = torch.erfinv(x) return y @profile(torch.erfinv) def profile_erfinv(test_case): torch.erfinv(torch.ones(100000)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_expand.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * def _np_get_expand(input_shape, expand_size): input = np.random.random(size=input_shape).astype(np.float32) input_stride = [1] for i in range(len(input_shape) - 2, -1, -1): input_stride.insert(0, input_stride[0] * input_shape[i + 1]) # calculate the output shape and stride new_size = [] new_stride = [] diff = len(expand_size) - len(input_shape) for i in range(len(expand_size) - 1, -1, -1): if i >= diff: if expand_size[i] == -1 or expand_size[i] == input_shape[i - diff]: new_size.insert(0, input_shape[i - diff]) new_stride.insert(0, input_stride[i - diff]) else: assert expand_size[i] >= 1 and input_shape[i - diff] == 1 new_size.insert(0, expand_size[i]) new_stride.insert(0, 0) else: assert expand_size[i] >= 1 new_size.insert(0, expand_size[i]) if expand_size[i] == 1: new_stride.insert(0, new_stride[0]) else: new_stride.insert(0, 0) gout = np.random.random(size=tuple(new_size)).astype(np.float32) out_stride = [1] for i in range(len(new_size) - 2, -1, -1): out_stride.insert(0, out_stride[0] * new_size[i + 1]) gin = np.zeros(input_shape).flatten() out = np.zeros(np.product(new_size)) def getOffset(i_offset, stride, expand_stride, n): remain = i_offset o_offset = 0 for i in range(n): idx = int(remain / stride[i]) o_offset += idx * expand_stride[i] remain = remain - idx * stride[i] return o_offset in_flatten = input.flatten() gout_flatten = gout.flatten() num_elem = np.product(new_size) dims = len(new_size) for i in range(num_elem): offset = getOffset(i, out_stride, new_stride, dims) gin[offset] += gout_flatten[i] out[i] = in_flatten[offset] return input, gout, out.reshape(tuple(new_size)), gin.reshape(input_shape) def _test_expand_new_dims(test_case, device): input_shape = (1, 4, 1, 32) expand_dim = [2, 1, 2, 4, 2, 32] input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim) of_input = flow.tensor( input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = of_input.expand(2, 1, 2, 4, 2, 32) test_case.assertTrue(np.array_equal(of_out.numpy(), out_np)) def _test_expand_same_dim(test_case, device): input_shape = (2, 4, 1, 32) expand_dim = [2, 4, 2, 32] input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim) of_input = flow.tensor(input, dtype=flow.float32, device=flow.device(device)) of_out = of_input.expand(2, 4, 2, 32) test_case.assertTrue(np.array_equal(of_out.numpy(), out_np)) def _test_expand_same_dim_negative(test_case, device): input_shape = (1, 6, 5, 3) expand_dim = [4, -1, 5, 3] input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim) of_input = flow.tensor(input, dtype=flow.float32, device=flow.device(device)) of_out = of_input.expand(4, -1, 5, 3) test_case.assertTrue(np.array_equal(of_out.numpy(), out_np)) def _test_expand_same_int(test_case, device): input_shape = (2, 4, 1, 32) expand_dim = [2, 4, 2, 32] input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim) of_input = flow.tensor(input, dtype=flow.int, 
device=flow.device(device)) of_out = of_input.expand(2, 4, 2, 32) test_case.assertTrue(np.array_equal(of_out.numpy(), out_np.astype(np.int32))) def _test_expand_flow_size(test_case, device): input_shape = (2, 4, 1, 32) expand_dim = flow.Size([2, 4, 2, 32]) input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim) of_input = flow.tensor(input, dtype=flow.int, device=flow.device(device)) of_out = of_input.expand(expand_dim) test_case.assertTrue(np.array_equal(of_out.numpy(), out_np.astype(np.int32))) def _test_expand_backward_same_dim(test_case, device): input = np.array( [ [ [[0.9876952171325684]], [[0.8772538304328918]], [[0.9200366735458374]], [[0.2810221314430237]], ], [ [[0.3037724494934082]], [[0.7783719897270203]], [[0.08884672075510025]], [[0.17156553268432617]], ], ] ) of_input = flow.tensor( input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = of_input.expand(2, 4, 2, 1) of_out.sum().backward() np_grad = [ [[[2.0]], [[2.0]], [[2.0]], [[2.0]]], [[[2.0]], [[2.0]], [[2.0]], [[2.0]]], ] test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad)) def _test_expand_backward(test_case, device): input = np.array( [ [ [[0.8981702327728271, 0.5372866988182068]], [[0.45116370916366577, 0.8656941056251526]], [[0.8811476230621338, 0.5552017688751221]], [[0.6291894316673279, 0.5786571502685547]], ] ] ) of_input = flow.tensor( input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = of_input.expand(2, 1, 2, 4, 2, 2) of_out.sum().backward() np_grad = [[[[8.0, 8.0]], [[8.0, 8.0]], [[8.0, 8.0]], [[8.0, 8.0]]]] test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad)) def random_expand(x, ndim, expand_size): dim_size = [1,] * ndim random_index = random(0, ndim).to(int).value() dim_size[random_index] = expand_size return x.expand(*dim_size) @flow.unittest.skip_unless_1n1d() class TestExpand(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_tensor_expand_with_random_data(test_case): random_expand_size = random(1, 6).to(int).value() x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1) return random_expand(x, ndim=5, expand_size=random_expand_size) @autotest(auto_backward=False, check_graph=True) def test_flow_tensor_expand_bool_with_random_data(test_case): random_expand_size = random(1, 6).to(int).value() x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1).to(torch.bool) return random_expand(x, ndim=5, expand_size=random_expand_size) def test_expand_compare_with_numpy(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_expand_new_dims, _test_expand_same_dim, _test_expand_same_dim_negative, _test_expand_same_int, _test_expand_flow_size, _test_expand_backward, _test_expand_backward_same_dim, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, auto_backward=False) def test_flow_expand_with_0_size(test_case): device = random_device() x = random_tensor(ndim=2, dim1=1).to(device) return x.expand([0, 3]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_expand_stride.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import torch from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict def _cmp_expand_stride( test_case, input_shape, expand_shape, device="cuda", verbose=False, ): input = np.random.randn(*input_shape) torch_x = torch.tensor(input, dtype=torch.float32, device=device) torch_y = torch_x.expand(*expand_shape) x = flow.tensor(input, dtype=flow.float32, device=device) y = x.expand(*expand_shape) if verbose: print("") print(f" eager (view::Expand) (device={device}) ".center(50, "=")) print(f" {input_shape} -> {expand_shape} ".center(50, "*")) print(f"x: shape={x.shape}, stride={x.stride()}") print(f"y: shape={y.shape}, stride={y.stride()}") print(f"torch_y: shape={torch_y.shape}, stride={torch_y.stride()}") print(" input ".center(50, "-")) print(input) print(" y ".center(50, "-")) print(y) print(" torch_y ".center(50, "-")) print(torch_y) test_case.assertTrue(np.array_equal(y.stride(), torch_y.stride())) test_case.assertTrue(np.array_equal(y.numpy(), torch_y.detach().cpu().numpy())) def _cmp_expand_non_contiguous_stride( test_case, input_shape, perm, expand_shape, device="cuda", verbose=False, ): input = np.random.randn(*input_shape).astype(np.float32) x = flow.tensor(input, device=device) y = x.permute(*perm) z = y.expand(*expand_shape) torch_x = torch.tensor(input, device=device) torch_y = torch_x.permute(*perm) torch_z = torch_y.expand(*expand_shape) if verbose: print("") print(f" non_contiguous (device={device}) ".center(50, "-")) print(f" {input_shape}, {perm} -> {expand_shape} ".center(50, "-")) print(f"x: shape={x.shape}, stride={x.stride()}") print(f"y: shape={y.shape}, stride={y.stride()}") print(f"z: shape={z.shape}, stride={z.stride()}") print(f"torch_y: shape={torch_y.shape}, stride={torch_y.stride()}") print(f"torch_z: shape={torch_z.shape}, stride={torch_z.stride()}") print(" input ".center(50, "-")) print(input) print(" z ".center(50, "-")) print(z) print(" torch_z ".center(50, "-")) print(torch_z) test_case.assertTrue(np.array_equal(z.stride(), torch_z.stride())) test_case.assertTrue(np.array_equal(z.numpy(), torch_z.detach().cpu().numpy())) def _cmp_lazy_expand_stride( test_case, input_shape, expand_shape, device="cuda", verbose=False, ): input = np.random.randn(*input_shape) torch_x = torch.tensor(input, dtype=torch.float32, device=device) torch_y = torch_x.expand(*expand_shape).contiguous() # oneflow lazy must do this contiguous class MyGraph(flow.nn.Graph): def __init__(self, expand_shape): super().__init__() self.expand_shape = expand_shape def build(self, x): return x.expand(*self.expand_shape) expand_graph = MyGraph(expand_shape) x = flow.tensor(input, dtype=flow.float32, device=device) y = expand_graph(x) squeeze_y_stride = [] for d, s in zip(y.shape, y.stride()): if d != 1: squeeze_y_stride.append(s) squeeze_torch_y_stride = [] for d, s in zip(torch_y.shape, torch_y.stride()): if d != 1: squeeze_torch_y_stride.append(s) if verbose: print("") print(f" lazy (expand op/kernel) (device={device}) ".center(50, "=")) print(f" {input_shape} -> {expand_shape} ".center(50, "*")) print(f"x: 
shape={x.shape}, stride={x.stride()}") print(f"y: shape={y.shape}, stride={y.stride()}") print(f"torch_y: shape={torch_y.shape}, stride={torch_y.stride()}") print(f"squeeze_y_stride={squeeze_y_stride}") print(f"squeeze_torch_y_stride={squeeze_torch_y_stride}") print(" input ".center(50, "-")) print(input) print(" y ".center(50, "-")) print(y) print(" torch_y ".center(50, "-")) print(torch_y) test_case.assertTrue(np.array_equal(squeeze_y_stride, squeeze_torch_y_stride)) test_case.assertTrue(np.array_equal(y.numpy(), torch_y.detach().cpu().numpy())) @flow.unittest.skip_unless_1n1d() class ExpandStrideTestCase(flow.unittest.TestCase): test_shape_tuple_list = [ ((1, 2), (2, 2)), ((1, 2), (1, 1, 2)), ((1, 2), (1, 2, 2)), ((1, 2), (2, 1, 2)), ((1, 2), (2, 2, 2)), ((1, 2), (1, 1, 1, 2)), ((1, 2), (1, 2, 1, 2)), ((1, 2), (2, 1, 1, 2)), ((1, 2), (2, 2, 1, 2)), ((1, 2), (2, 2, 2, 2)), ((2, 1), (2, 2)), ((2, 1), (1, 2, 1)), ((2, 1), (1, 2, 2)), ((2, 1), (2, 2, 1)), ((2, 1), (2, 2, 2)), ((2, 1), (1, 1, 2, 1)), ((2, 1), (1, 2, 2, 1)), ((2, 1), (2, 2, 2, 1)), ((2, 1), (2, 2, 2, 2)), ((2, 2), (1, 2, 2)), ((2, 2), (2, 2, 2)), ((2, 2), (1, 1, 2, 2)), ((2, 2), (1, 2, 2, 2)), ((2, 2), (2, 1, 2, 2)), ((2, 2), (2, 2, 2, 2)), ((2, 1, 4), (2, 2, 2, 4)), ((2, 1, 3), (2, 1, -1, -1, -1)), ((2, 1, 3), (1, 2, -1, -1, -1)), ((2, 1, 3), (2, 2, -1, -1, -1)), ((2, 1, 3), (2, 1, -1, 2, 3)), ((2, 1, 3), (1, 2, 2, 2, -1)), ((2, 1, 3), (2, 2, 2, 2, 3)), ((2, 3, 4), (1, 2, -1, -1, -1)), ((2, 3, 4), (2, 1, -1, -1, -1)), ((2, 3, 4), (2, 2, -1, -1, -1)), ((), (1,)), ((), (2,)), ((), (1, 2)), ((), (2, 1)), ((), (2, 2)), ] @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_on_cpu(self): arg_dict = OrderedDict() arg_dict["verbose"] = [False] arg_dict["device"] = ["cpu"] arg_dict["shapes"] = self.test_shape_tuple_list for kwargs in GenArgDict(arg_dict): assert "shapes" in kwargs input_shape, expand_shape = kwargs.pop("shapes") _cmp_expand_stride(self, input_shape, expand_shape, **kwargs) def test_stride(self): arg_dict = OrderedDict() arg_dict["verbose"] = [False] arg_dict["device"] = ["cuda"] arg_dict["shapes"] = self.test_shape_tuple_list for kwargs in GenArgDict(arg_dict): assert "shapes" in kwargs input_shape, expand_shape = kwargs.pop("shapes") _cmp_expand_stride(self, input_shape, expand_shape, **kwargs) def test_non_contiguous_stride(self): arg_dict = OrderedDict() arg_dict["verbose"] = [False] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shapes"] = [ ((2, 1, 3), (0, 2, 1), (1, 2, -1, -1, -1)), ((2, 1, 3), (0, 2, 1), (2, 1, -1, -1, -1)), ((2, 1, 3), (0, 2, 1), (2, 3, -1, -1, -1)), ((2, 1, 3), (0, 2, 1), (1, 2, -1, -1, 2)), ((2, 1, 3), (0, 2, 1), (2, 1, -1, -1, 2)), ((2, 1, 3), (0, 2, 1), (2, 3, -1, -1, 2)), ((2, 3, 4), (0, 2, 1), (1, 2, -1, -1, -1)), ((2, 3, 4), (0, 2, 1), (2, 1, -1, -1, -1)), ((2, 3, 4), (0, 2, 1), (2, 2, -1, -1, -1)), ] for kwargs in GenArgDict(arg_dict): assert "shapes" in kwargs input_shape, perm, expand_shape = kwargs.pop("shapes") _cmp_expand_non_contiguous_stride( self, input_shape, perm, expand_shape, **kwargs ) def test_lazy(self): arg_dict = OrderedDict() arg_dict["verbose"] = [False] arg_dict["device"] = ["cuda"] arg_dict["shapes"] = self.test_shape_tuple_list for kwargs in GenArgDict(arg_dict): assert "shapes" in kwargs input_shape, expand_shape = kwargs.pop("shapes") _cmp_lazy_expand_stride(self, input_shape, expand_shape, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: 
python/oneflow/test/modules/test_expm1.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_expm1_impl(test_case, device, shape): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.expm1(x) np_out = np.expm1(x.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.exp(x.numpy()), 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestExpm1Module(flow.unittest.TestCase): def test_expm1(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_expm1_impl] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(1,), (2, 3), (2, 3, 4), (2, 3, 4, 5)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, check_graph=True) def test_expm1_flow_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = torch.expm1(input) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_expm1_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = torch.expm1(x) return y @autotest(n=5, check_graph=True) def test_expm1_flow_with_0dim_data(test_case): device = random_device() input = random_tensor(ndim=0).to(device) y = torch.expm1(input) return y @profile(torch.expm1) def profile_expm1(test_case): torch.expm1(torch.ones(100000)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_eye.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow def _test_eye_forward(test_case, device, n, m): output = flow.eye(n, m, device=device) np_out = np.eye(n, m) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) def _test_eye_backward(test_case, device, n, m): x = flow.eye(n, m, device=device) x.requires_grad = True y = x.sum() y.backward() test_case.assertTrue(np.array_equal(x.grad.numpy(), np.ones([n, m]))) def _test_eye_with_1n2d(test_case, n, m, device): placement = flow.placement(device, range(2)) x = flow.eye(n, m, placement=placement, sbp=flow.sbp.broadcast) test_case.assertTrue(x.placement, placement) test_case.assertTrue(x.sbp, flow.sbp.broadcast) @flow.unittest.skip_unless_1n1d() class TestEye(flow.unittest.TestCase): def test_eye(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_eye_forward, _test_eye_backward, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["n"] = [4, 3, 2] arg_dict["m"] = [4, 3, 2] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(check_graph=True) def test_eye_with_random_data(test_case): n = random(low=1, high=5).to(int) m = random(low=1, high=5).to(int) x = torch.eye(n=n, m=m, device=random_device()) x.oneflow.requires_grad = True x.pytorch.requires_grad = True return x @autotest(check_graph=True, auto_backward=False) def test_eye_with_random_data(test_case): n = random(low=0, high=1).to(int) m = random(low=0, high=2).to(int) x = torch.eye(n=n, m=m, device=random_device()) return x @autotest(check_graph=True) def test_eye_bool_with_random_data(test_case): n = random().to(int) m = random().to(int) x = torch.eye(n=n, m=m) device = random_device() x.to(device=device, dtype=torch.bool) x = random_tensor().to(device) return x @autotest(check_graph=True, auto_backward=False) def test_eye_with_0dim_data(test_case): n = random().to(int) m = random().to(int) x = torch.eye(n=n, m=m) device = random_device() x.to(device) x = random_tensor(ndim=0).to(device) return x @profile(torch.eye) def profile_eye(test_case): torch.eye(1000) torch.eye(100, 1280) @flow.unittest.skip_unless_1n2d() class TestGlobalEye(flow.unittest.TestCase): def test_eye_with_1n2d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_eye_with_1n2d] arg_dict["n"] = [4, 3, 2] arg_dict["m"] = [4, 3, 2] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fake_quantization.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import math import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.test_util import ( GenArgList, type_name_to_flow_type, type_name_to_np_type, ) import oneflow as flow import oneflow.unittest def gen_quant_scale_for_min_max_symmetric(weight, quantization_bit): weight_max = np.max(np.abs(weight)) denominator = 2.0 ** (quantization_bit - 1) - 1 return (weight_max / denominator, 0) def gen_quant_scale_for_min_max_affine(weight, quantization_bit): weight_max = np.max(weight) weight_min = np.min(weight) denominator = 2.0 ** quantization_bit - 1 scale = (weight_max - weight_min) / denominator zero_point = -np.round(weight_min / scale) return (scale, zero_point) def gen_quant_scale_for_min_max_cambricon(weight, quantization_bit): weight_max = np.max(np.abs(weight)) scale = math.floor(math.log2(weight_max)) - (quantization_bit - 2) return (scale, 0) def product(tu): return np.prod(tu).astype(np.int32).item() def fake_quant_per_layer_symmetric(input, quantization_bit, scale): upper_bound = 2.0 ** (quantization_bit - 1) - 1 lower_bound = -upper_bound return np.clip(np.rint(input / scale), lower_bound, upper_bound) * scale def fake_quant_per_layer_affine(input, quantization_bit, scale, zero_point): upper_bound = 2.0 ** quantization_bit - 1 lower_bound = 0 return ( np.clip(np.rint(input / scale + zero_point), lower_bound, upper_bound) - zero_point ) * scale def fake_quant_per_layer_cambricon(input, quantization_bit, shift): upper_bound = 2.0 ** (quantization_bit - 1) - 1 lower_bound = -upper_bound scale = 2 ** shift return np.clip(np.rint(input / scale), lower_bound, upper_bound) * scale def _check_fake_quantize( test_case, input, input_diff_of, out_of, quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ): if per_layer_quantization or quantization_formula == "cambricon": outer_num = 1 inner_num = product(input.shape[0:]) else: outer_num = input.shape[0] inner_num = product(input.shape[1:]) scale_np = np.zeros((outer_num,)) zero_point_np = np.zeros((outer_num,)) out_np = np.zeros((inner_num * outer_num,)) input_flatten = input.flatten() input_diff_np = np.full((inner_num * outer_num,), 1.0 / (inner_num * outer_num)) if quantization_formula == "google": if quantization_scheme == "symmetric": for c in range(outer_num): (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_symmetric( input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit ) out = fake_quant_per_layer_symmetric( input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit, scale_np[c], ) out_np[c * inner_num : (c + 1) * inner_num] = out else: for c in range(outer_num): (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_affine( input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit ) out = fake_quant_per_layer_affine( input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit, scale_np[c], zero_point_np[c], ) out_np[c * inner_num : (c + 1) * inner_num] = out else: (scale_np[0], zero_point_np[0]) = gen_quant_scale_for_min_max_cambricon( input_flatten, quantization_bit ) out_np = fake_quant_per_layer_cambricon( input_flatten, quantization_bit, scale_np[0] ) rmse = np.sqrt(np.mean((out_of - out_np) ** 2)) assert rmse <= 1.0, "fake_quantization op has bug!" 
test_case.assertTrue(np.allclose(input_diff_of, input_diff_np, rtol=0.001)) def _run_test_fake_quantize( test_case, device_type, dtype, in_shape, quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ): input = (np.random.random(in_shape) - 0.5).astype(type_name_to_np_type[dtype]) input_tensor = flow.tensor( input, dtype=flow.float32, requires_grad=True, device=flow.device(device_type) ) min_max_observer = flow.nn.MinMaxObserver( quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization, ) (scale, zero_point) = min_max_observer(input_tensor) fake_quantization = flow.nn.FakeQuantization( quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, ) output_tensor = fake_quantization(input_tensor, scale, zero_point) y = output_tensor.mean() y = y.backward() out = output_tensor.numpy() input_diff = input_tensor.grad.numpy() _check_fake_quantize( test_case, input, input_diff.flatten(), out.flatten(), quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ) class TestFakeQuantize(flow.unittest.TestCase): def test_fake_quantize(test_case): arg_dict = OrderedDict() arg_dict["test_case"] = [test_case] arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dtype"] = ["float32", "double"] arg_dict["in_shape"] = [(9, 40, 20, 10)] arg_dict["quantization_bit"] = [8, 2] arg_dict["quantization_scheme"] = ["symmetric", "affine"] arg_dict["quantization_formula"] = ["google"] arg_dict["per_layer_quantization"] = [True, False] for arg in GenArgList(arg_dict): if arg[-2] == "cambricon" and arg[-1] == False: continue _run_test_fake_quantize(*arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fft.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch as torch_original from packaging import version import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import os def is_cufft_available(): if flow.cuda.is_available(): (major, _minor) = flow.cuda.get_device_capability() return major >= 7 else: return False def is_complex_dtype(dtype): if hasattr(dtype, "pytorch") and hasattr(dtype, "oneflow"): # is DualObject return dtype.pytorch.is_complex else: return dtype in [ flow.complex64, flow.complex128, torch_original.complex64, torch_original.complex128, torch.pytorch.complex64, torch.pytorch.complex128, ] def gen_params_1d_fft(lower_n_dims=1, upper_n_dims=5): num_dims = np.random.randint(lower_n_dims, upper_n_dims) shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)] if np.random.randint(2) == 1: dim = np.random.randint(low=-num_dims, high=num_dims - 1) else: dim = -1 norm = np.random.choice(["backward", "forward", "ortho", None]) if np.random.randint(2) == 1: n = None else: n = np.random.randint(low=1, high=shape[dim] * 2) params = { "num_dims": num_dims, "shape": shape, "n": n, "dim": dim, "norm": norm, } return params def gen_params_2d_fft(lower_n_dims=2, upper_n_dims=5): num_dims = np.random.randint(lower_n_dims, upper_n_dims) shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)] len_fft_dim = np.random.randint(low=1, high=3) total_dims_range = np.arange(num_dims) if np.random.randint(2) == 1: dims = np.random.choice( total_dims_range, size=len_fft_dim, replace=False ).tolist() else: dims = (-2, -1) norm = np.random.choice(["backward", "forward", "ortho", None]) len_fft_dim = len(dims) if np.random.randint(2) == 1 and dims is not None: n = [] for i in range(len_fft_dim): n_ = ( np.random.randint(low=1, high=2 * shape[i]) if np.random.randint(2) == 1 else -1 ) n.append(n_) else: n = None params = { "num_dims": num_dims, "shape": shape, "n": n, "dim": dims, "norm": norm, } return params def gen_params_nd_fft(lower_n_dims=2, upper_n_dims=5): num_dims = np.random.randint(lower_n_dims, upper_n_dims) shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)] len_fft_dim = np.random.randint(low=1, high=num_dims + 1) total_dims_range = np.arange(num_dims) if np.random.randint(2) == 1: dims = np.random.choice( total_dims_range, size=len_fft_dim, replace=False ).tolist() else: dims = None norm = np.random.choice(["backward", "forward", "ortho", None]) if np.random.randint(2) == 1: n = None else: n = [] len_fft_dim = ( len(dims) if dims is not None else np.random.randint(low=1, high=num_dims + 1) ) for i in range(len_fft_dim): n_ = ( np.random.randint(low=1, high=2 * shape[i]) if np.random.randint(2) == 1 else -1 ) n.append(n_) params = { "num_dims": num_dims, "shape": shape, "n": n, "dim": dims, "norm": norm, } return params def _test_fft(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] params = gen_params_1d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] x = random_tensor(num_dims, dtype=float, *shape) if is_complex_dtype(x.dtype): # test fft_c2c dtype = test_case.dtype_dict["complex"] x = x.to(device=device, dtype=dtype) else: # test fft_r2c dtype = test_case.dtype_dict["real"] x = 
x.to(device=device, dtype=dtype) y = torch.fft.fft(x, n, dim, norm) return y def _test_ifft(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] params = gen_params_1d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] x = random_tensor(num_dims, dtype=float, *shape) if is_complex_dtype(x.dtype): # test fft_c2c dtype = test_case.dtype_dict["complex"] x = x.to(device=device, dtype=dtype) else: # test fft_r2c dtype = test_case.dtype_dict["real"] x = x.to(device=device, dtype=dtype) y = torch.fft.ifft(x, n, dim, norm) return y def _test_rfft(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] params = gen_params_1d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["real"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.rfft(x, n, dim, norm) return y def _test_irfft(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] params = gen_params_1d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["complex"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.irfft(x, n, dim, norm) return y def _test_hfft(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] params = gen_params_1d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["complex"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.hfft(x, n, dim, norm) return y def _test_ihfft(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] params = gen_params_1d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["real"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.ihfft(x, n, dim, norm) return y def _test_fft2(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] params = gen_params_2d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] x = random_tensor(num_dims, dtype=float, *shape) if is_complex_dtype(x.dtype): # test fft_c2c dtype = test_case.dtype_dict["complex"] x = x.to(device=device, dtype=dtype) else: # 
test fft_r2c dtype = test_case.dtype_dict["real"] x = x.to(device=device, dtype=dtype) y = torch.fft.fft2(x, n, dim, norm) return y def _test_ifft2(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] params = gen_params_2d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] x = random_tensor(num_dims, dtype=float, *shape) if is_complex_dtype(x.dtype): # test fft_c2c dtype = test_case.dtype_dict["complex"] x = x.to(device=device, dtype=dtype) else: # test fft_r2c dtype = test_case.dtype_dict["real"] x = x.to(device=device, dtype=dtype) y = torch.fft.ifft2(x, n, dim, norm) return y def _test_rfft2(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] params = gen_params_2d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["real"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.rfft2(x, n, dim, norm) return y def _test_irfft2(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] params = gen_params_2d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["complex"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.irfft2(x, n, dim, norm) return y def _test_hfft2(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] params = gen_params_2d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["complex"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.hfft2(x, n, dim, norm) return y def _test_ihfft2(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] params = gen_params_2d_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["real"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.ihfft2(x, n, dim, norm) return y def _test_fftn(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] params = gen_params_nd_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] x = random_tensor(num_dims, dtype=float, *shape) if is_complex_dtype(x.dtype): # test fft_c2c dtype = 
test_case.dtype_dict["complex"] x = x.to(device=device, dtype=dtype) else: # test fft_r2c dtype = test_case.dtype_dict["real"] x = x.to(device=device, dtype=dtype) y = torch.fft.fftn(x, n, dim, norm) return y def _test_ifftn(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] params = gen_params_nd_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] x = random_tensor(num_dims, dtype=float, *shape) if is_complex_dtype(x.dtype): # test fft_c2c dtype = test_case.dtype_dict["complex"] x = x.to(device=device, dtype=dtype) else: # test fft_r2c dtype = test_case.dtype_dict["real"] x = x.to(device=device, dtype=dtype) y = torch.fft.ifftn(x, n, dim, norm) return y def _test_rfftn(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] params = gen_params_nd_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["real"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.rfftn(x, n, dim, norm) return y def _test_irfftn(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] params = gen_params_nd_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["complex"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.irfftn(x, n, dim, norm) return y def _test_hfftn(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] params = gen_params_nd_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["complex"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.hfftn(x, n, dim, norm) return y def _test_ihfftn(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] params = gen_params_nd_fft(lower_n_dims, upper_n_dims) num_dims = params["num_dims"] shape = params["shape"] n = params["n"] dim = params["dim"] norm = params["norm"] dtype = test_case.dtype_dict["real"] x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) y = torch.fft.ihfftn(x, n, dim, norm) return y # NOTE: skip for multi-nodes and multi-devices now, because it failed in ci randomly @flow.unittest.skip_unless_1n1d() class TestComplex64Fft(flow.unittest.TestCase): def setUp(test_case): # should override by other data type of complex test_case.ndims_dict = { "1d": {"lower_n_dims": 1, "upper_n_dims": 5}, "2d": {"lower_n_dims": 2, "upper_n_dims": 5}, "nd": {"lower_n_dims": 1, "upper_n_dims": 5}, } test_case.dtype_dict = 
{"real": torch.float32, "complex": torch.complex64} test_case.rtol = 1e-5 test_case.atol = 1e-5 if os.environ["ONEFLOW_CI"] == "1": test_case.rtol = 1e-2 test_case.atol = 1e-2 test_case.initTestFft() def initTestFft(test_case): test_case.test_fft = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_fft) test_case.test_ifft = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_ifft) test_case.test_rfft = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=False, )(_test_rfft) test_case.test_irfft = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_irfft) test_case.test_hfft = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_hfft) test_case.test_ihfft = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=False, )(_test_ihfft) test_case.test_fft2 = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_fft2) test_case.test_ifft2 = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_ifft2) test_case.test_rfft2 = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=False, )(_test_rfft2) test_case.test_irfft2 = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol * 100, # NOTE: ND-dimension of fft_c2r expands the numerical accuracy error check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_irfft2) test_case.test_hfft2 = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol * 100, # NOTE: ND-dimension of fft_c2r expands the numerical accuracy error check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_hfft2) test_case.test_ihfft2 = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol, check_graph=False, check_grad_use_random_data=True, include_complex=False, )(_test_ihfft2) test_case.test_fftn = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol * 1e2, # NOTE: check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_fftn) test_case.test_ifftn = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol * 1e2, check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_ifftn) test_case.test_rfftn = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol * 1e2, check_graph=False, check_grad_use_random_data=True, include_complex=False, )(_test_rfftn) test_case.test_irfftn = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol * 1e2, # NOTE: ND-dimension of fft_c2r expands the numerical accuracy error check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_irfftn) test_case.test_hfftn = autotest( n=5, 
auto_backward=True, rtol=test_case.rtol, atol=test_case.atol * 1e2, # NOTE: ND-dimension of fft_c2r expands the numerical accuracy error check_graph=False, check_grad_use_random_data=True, include_complex=True, )(_test_hfftn) test_case.test_ihfftn = autotest( n=5, auto_backward=True, rtol=test_case.rtol, atol=test_case.atol * 1e2, check_graph=False, check_grad_use_random_data=True, include_complex=False, )(_test_ihfftn) def test_1d_fft(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ test_case.test_fft, test_case.test_ifft, test_case.test_rfft, test_case.test_irfft, test_case.test_hfft, test_case.test_ihfft, ] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_2d_fft_except_hfft2(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ test_case.test_fft2, test_case.test_ifft2, test_case.test_rfft2, test_case.test_irfft2, ] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf( version.parse(torch_original.__version__) < version.parse("1.11.0"), "module 'torch.fft' has no attribute 'hfft2' or 'ihfft2' before '1.11.0'", ) def test_2d_fft_hfft2(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [test_case.test_hfft2, test_case.test_ihfft2] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_nd_fft_except_hfftn(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ test_case.test_fftn, test_case.test_ifftn, test_case.test_rfftn, test_case.test_irfftn, ] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf( version.parse(torch_original.__version__) < version.parse("1.11.0"), "module 'torch.fft' has no attribute 'hfftn' or 'ihfftn' before '1.11.0'", ) def test_nd_fft_hfftn(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [test_case.test_hfftn, test_case.test_ihfftn] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) # NOTE: skip for multi-nodes and multi-devices now, because it failed in ci randomly @flow.unittest.skip_unless_1n1d() class TestComplex128Fft(TestComplex64Fft): def setUp(test_case): # should override by other data type of complex test_case.ndims_dict = { "1d": {"lower_n_dims": 1, "upper_n_dims": 5}, "2d": {"lower_n_dims": 2, "upper_n_dims": 5}, "nd": {"lower_n_dims": 1, "upper_n_dims": 5}, } test_case.dtype_dict = {"real": torch.float64, "complex": torch.complex128} test_case.rtol = 1e-7 test_case.atol = 1e-7 test_case.initTestFft() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_flatten.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_flatten(test_case, device): m = flow.nn.Flatten() x = flow.Tensor(32, 2, 5, 5, device=flow.device(device)) flow.nn.init.uniform_(x) y = m(x) test_case.assertTrue(y.shape == flow.Size((32, 50))) test_case.assertTrue(np.array_equal(y.numpy().flatten(), x.numpy().flatten())) y2 = flow.flatten(x, start_dim=2) test_case.assertTrue(y2.shape == flow.Size((32, 2, 25))) test_case.assertTrue(np.array_equal(y2.numpy().flatten(), x.numpy().flatten())) y3 = x.flatten(start_dim=1) test_case.assertTrue(y3.shape == flow.Size((32, 50))) test_case.assertTrue(np.array_equal(y3.numpy().flatten(), x.numpy().flatten())) y4 = x.flatten(start_dim=1, end_dim=2) test_case.assertTrue(y4.shape == flow.Size((32, 10, 5))) test_case.assertTrue(np.array_equal(y4.numpy().flatten(), x.numpy().flatten())) y5 = flow.flatten(x) test_case.assertTrue(y5.shape == flow.Size((1600,))) test_case.assertTrue(np.array_equal(y5.numpy().flatten(), x.numpy().flatten())) def _test_flatten_backward(test_case, device): m = flow.nn.Flatten().to(flow.device(device)) x = flow.Tensor(2, 3, 4, 5, device=flow.device(device)) x.requires_grad = True flow.nn.init.uniform_(x) y = m(x) z = y.sum() z.backward() test_case.assertTrue(np.array_equal(np.ones(shape=(2, 3, 4, 5)), x.grad.numpy())) @flow.unittest.skip_unless_1n1d() class TestFlattenModule(flow.unittest.TestCase): def test_cast(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_flatten, _test_flatten_backward] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_flatten_module_with_random_data(test_case): m = torch.nn.Flatten( start_dim=random(1, 6) | nothing(), end_dim=random(1, 6) | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y @autotest(n=5) def test_flatten_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.flatten( x, start_dim=random(1, 6).to(int) | nothing(), end_dim=random(1, 6).to(int) | nothing(), ) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_flatten_bool_with_random_data(test_case): device = random_device() x = random_tensor().to(device=device, dtype=torch.bool) y = torch.flatten( x, start_dim=random(1, 6).to(int) | nothing(), end_dim=random(1, 6).to(int) | nothing(), ) return y @autotest(n=5) def test_flatten_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.flatten( x, start_dim=random(1, 6).to(int) | nothing(), end_dim=random(1, 6).to(int) | nothing(), ) return y @profile(torch.flatten) def profile_flatten(test_case): torch.flatten(torch.ones(1000, 1000)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_flip.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestFlip(flow.unittest.TestCase): @autotest(check_graph=True, check_allclose=False) def test_flow_flip_list_with_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) ).to(device) y = torch.flip(x, constant([0, 1, 2])) return y @autotest(n=5) def test_flow_flip_tuple_with_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) ).to(device) y = torch.flip(x, constant((0, 1, 2))) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_flip_bool_tuple_with_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) ).to(device=device, dtype=torch.bool) y = torch.flip(x, constant((0, 1, 2))) return y def test_flow_flip_list_lastdim_with_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) ).to(device) y = torch.flip(x, [-1,]) return y @profile(torch.flip) def profile_flip(test_case): torch.flip(torch.ones(100, 100, 100), [0, 1]) torch.flip(torch.ones(1, 100000), [-1,]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_floor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_floor(test_case, shape, device): np_input = np.random.randn(*shape) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = flow.floor(of_input) np_out = np.floor(np_input) test_case.assertTrue( np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True) ) of_out = of_out.sum() of_out.backward() np_out_grad = np.zeros_like(of_out, dtype=np.float32) test_case.assertTrue( np.allclose(of_input.grad.numpy(), np_out_grad, 0.0001, 0.0001, equal_nan=True) ) @flow.unittest.skip_unless_1n1d() class TestFloor(flow.unittest.TestCase): def test_floor(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2,), (2, 3), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_floor(test_case, *arg) @autotest(check_graph=True) def test_flow_floor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.floor(x) return y @autotest(check_graph=True) def test_flow_floor_inplace_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x + 1 y.floor_() return y @autotest(check_graph=True) def test_flow_floor_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.floor(x) return y @profile(torch.floor) def profile_floor(test_case): torch.floor(torch.ones(100, 100, 100)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fmod.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import random as rd import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest import torch as torch_original from packaging import version @flow.unittest.skip_unless_1n1d() class TestFmodModule(flow.unittest.TestCase): # other.grad in torch.fmod(input, other) was not implemented before pytorch 1.11.0 grad_implemented = version.parse(torch_original.__version__) >= version.parse( "1.11.0" ) @autotest(n=1, auto_backward=grad_implemented) def test_flow_fmod_element_with_random_data(test_case): device = random_device() dim1 = random().to(int) dim2 = random().to(int) input = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device) other = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device) return torch.fmod(input, other) @autotest(n=1, auto_backward=grad_implemented) def test_flow_fmod_element_with_0dim_data(test_case): device = random_device() input = random_tensor(ndim=0).to(device) other = random_tensor(ndim=0).to(device) return torch.fmod(input, other) @autotest(n=1, auto_backward=grad_implemented) def test_flow_fmod_broadcast_with_random_data(test_case): device = random_device() dim1 = random().to(int) dim2 = random().to(int) input = random_tensor(ndim=3, dim1=constant(1), dim2=dim2).to(device) other = random_tensor(ndim=3, dim1=dim1, dim2=constant(1)).to(device) return torch.fmod(input, other) @autotest(n=1, auto_backward=True) def test_flow_fmod_scalar_with_random_data(test_case): device = random_device() dim1 = random().to(int) dim2 = random().to(int) input = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device) other = 3 return torch.fmod(input, other) @autotest(n=1, auto_backward=True) def test_fmod_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = torch.fmod(x, 2) return y @profile(torch.fmod) def profile_fmod(test_case): torch.fmod(torch.ones(100, 100, 100), 1) torch.fmod(torch.ones(100, 100, 100), -0.5) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fold.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.nn.common_types import _size_2_t @flow.unittest.skip_unless_1n1d() class TestFold(flow.unittest.TestCase): @autotest(n=3, auto_backward=True, rtol=1e-4, atol=1e-4) def test_fold_with_random_data_1(test_case): m = torch.nn.Fold( output_size=constant((4, 4)), kernel_size=constant(3), dilation=constant(1), padding=constant(1), stride=constant(1), ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=3, dim0=constant(2), dim1=constant(36), dim2=constant(16) ).to(device) y = m(x) func_y = torch.nn.functional.fold( x, output_size=constant((4, 4)), kernel_size=constant(3), dilation=constant(1), padding=constant(1), stride=constant(1), ) return y, func_y @autotest(n=3, auto_backward=True, rtol=1e-4, atol=1e-4) def test_fold_with_random_data_2(test_case): m = torch.nn.Fold( output_size=constant((4, 4)), kernel_size=constant(3), dilation=constant(1), padding=constant(0), stride=constant(1), ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=3, dim0=constant(2), dim1=constant(36), dim2=constant(4) ).to(device) y = m(x) func_y = torch.nn.functional.fold( x, output_size=constant((4, 4)), kernel_size=constant(3), dilation=constant(1), padding=constant(0), stride=constant(1), ) return y, func_y @autotest(n=3, auto_backward=True, rtol=1e-4, atol=1e-4) def test_fold_with_random_data_3(test_case): m = torch.nn.Fold( output_size=constant((8, 8)), kernel_size=constant(3), dilation=constant(1), padding=constant(1), stride=constant(2), ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=3, dim0=constant(2), dim1=constant(72), dim2=constant(16) ).to(device) y = m(x) func_y = torch.nn.functional.fold( x, output_size=constant((8, 8)), kernel_size=constant(3), dilation=constant(1), padding=constant(1), stride=constant(2), ) return y, func_y @autotest(n=3, auto_backward=True, rtol=1e-4, atol=1e-4) def test_fold_with_random_data_4(test_case): m = torch.nn.Fold( output_size=constant((8, 8)), kernel_size=constant(3), dilation=constant(2), padding=constant(1), stride=constant(2), ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=3, dim0=constant(2), dim1=constant(9), dim2=constant(9) ).to(device) y = m(x) func_y = torch.nn.functional.fold( x, output_size=constant((8, 8)), kernel_size=constant(3), dilation=constant(2), padding=constant(1), stride=constant(2), ) return y, func_y @profile(torch.nn.functional.fold) def profile_fold(test_case): x = torch.ones(128, 128, 4) torch.nn.functional.fold(x, output_size=(4, 4), kernel_size=(2, 2), stride=2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fork_sub_process.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from multiprocessing.pool import Pool import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow def _test_fork_sub_process(id): print("\nchild process:%s start! process id: %d" % (id, os.getpid())) import oneflow as flow x = flow.tensor(np.ones((4, 16)), device="cpu") y = flow.tensor(np.ones((16)), device="cpu") z = x + y assert np.array_equal(z.numpy(), np.ones((4, 16)) * 2) print("%s child process done! process id: %d." % (id, os.getpid())) @flow.unittest.skip_unless_1n1d() class TestForkSubProcess(flow.unittest.TestCase): def test_fork_sub_process(test_case): flow._oneflow_internal.eager.Sync() print("=============main process start=============") # process pool num_process = 4 p = Pool(num_process) async_res = [] for i in range(num_process): # create n child processes # put it to pool async_res.append(p.apply_async(_test_fork_sub_process, args=(i,))) p.close() p.join() for i in range(num_process): test_case.assertTrue(async_res[i].successful()) print("=============main process done!=============") if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_frac.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestFrac(flow.unittest.TestCase): @autotest(n=5) def test_frac(test_case): device = random_device() ndim = random(2, 4).to(int).value() shape = [random(2, 4) for i in range(ndim)] input = random_tensor(ndim, *shape).to(device) output = torch.frac(input) return output @autotest(n=5) def test_tensor_frac(test_case): device = random_device() ndim = random(2, 4).to(int).value() shape = [random(2, 4) for i in range(ndim)] input = random_tensor(ndim, *shape).to(device) output = input.frac() return output @autotest(n=5) def test_tensor_frac_inplace(test_case): device = random_device() ndim = random(2, 4).to(int).value() shape = [random(2, 4) for i in range(ndim)] input = random_tensor(ndim, *shape).to(device) input = input + 1.0 input.frac_() return input if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_from_numpy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import random import unittest import torch import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestFromNumpy(flow.unittest.TestCase): def test_same_data(test_case): np_arr = np.random.randn(3, 4, 5) tensor = flow.from_numpy(np_arr) test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) test_case.assertEqual(tensor.size(), (3, 4, 5)) test_case.assertEqual(tensor.stride(), (20, 5, 1)) test_case.assertEqual(tensor.storage_offset(), 0) np_arr[1:2, 2:3, 3:4] = random.random() test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) def test_use_ops(test_case): np_arr = np.random.randn(3, 4, 5) tensor = flow.from_numpy(np_arr) res = tensor ** 2 test_case.assertTrue(np.allclose(np_arr ** 2, res.numpy())) def test_more_dtype(test_case): for dtype in [ np.float64, np.float32, np.float16, np.int64, np.int32, np.int8, np.uint8, ]: np_arr = np.ones((2, 3), dtype=dtype) tensor = flow.from_numpy(np_arr) # TODO(wyg): oneflow.float16 does not support copying from tensor to numpy if tensor.dtype not in [flow.float16]: test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) def test_non_contiguous_input(test_case): np_arr = np.random.randn(2, 3, 4, 5).transpose(2, 0, 3, 1) flow_tensor = flow.from_numpy(np_arr) torch_tensor = torch.from_numpy(np_arr) test_case.assertTrue(flow_tensor.shape == torch_tensor.shape) test_case.assertTrue(flow_tensor.stride() == torch_tensor.stride()) test_case.assertTrue( flow_tensor.is_contiguous() == torch_tensor.is_contiguous() ) test_case.assertTrue(np.array_equal(flow_tensor.numpy(), torch_tensor.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_from_torch.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import os import oneflow as flow import oneflow.unittest import torch def torch_device_to_flow(device): if device.type == "cpu": return flow.device("cpu") elif device.type == "cuda": return flow.device("cuda", device.index) else: raise NotImplementedError("Unsupported device type: {}".format(device.type)) class TestFromTorch(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_from_torch_cpu(test_case): torch_t = torch.rand(5, 3, 3) numpy_from_torch = torch_t.numpy() # NOTE: torch and numpy share the same memory. test_case.assertEqual( torch_t.data_ptr(), numpy_from_torch.__array_interface__["data"][0] ) numpy_from_torch[0][0] = [1, 2, 3] test_case.assertTrue( np.allclose(torch_t.numpy(), numpy_from_torch, rtol=0.001, atol=0.001) ) # NOTE: oneflow and numpy share the same memory, # so the oneflow and torch cpu tensors share the same memory, # which means oneflow can use torch's cpu tensor without a copy.
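# A minimal sketch of the zero-copy behaviour described in the NOTE above
# (the names below are illustrative, not part of the test): an in-place update
# on the torch tensor should be visible through the oneflow tensor returned
# by from_torch, e.g.:
#   flow_view = flow.utils.tensor.from_torch(torch_t)
#   torch_t[0][0][0] = 7.0
#   assert flow_view.numpy()[0][0][0] == 7.0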
flow_t = flow.utils.tensor.from_torch(torch_t) test_case.assertTrue( np.allclose(torch_t.numpy(), flow_t.numpy(), rtol=0.001, atol=0.001) ) test_case.assertEqual(torch_t.numpy().dtype, flow_t.numpy().dtype) # NOTE: For the case of 0 size tensor, no memory addresses are compared. # Because the address of 0 size tensor is random at this time. @flow.unittest.skip_unless_1n1d() def test_from_torch_cpu_with_0_size_data(test_case): torch_t = torch.rand(5, 0, 3) flow_t = flow.utils.tensor.from_torch(torch_t) test_case.assertTrue( np.allclose(torch_t.numpy(), flow_t.numpy(), rtol=0.001, atol=0.001) ) test_case.assertEqual(torch_t.numpy().dtype, flow_t.numpy().dtype) @flow.unittest.skip_unless_1n1d() def test_from_torch_cpu_with_0dim_data(test_case): torch_t = torch.tensor(5) numpy_from_torch = torch_t.numpy() test_case.assertEqual( torch_t.data_ptr(), numpy_from_torch.__array_interface__["data"][0] ) flow_t = flow.utils.tensor.from_torch(torch_t) test_case.assertTrue( np.allclose(torch_t.numpy(), flow_t.numpy(), rtol=0.001, atol=0.001) ) test_case.assertEqual(torch_t.numpy().dtype, flow_t.numpy().dtype) @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_from_torch_gpu(test_case): for device in [torch.device("cuda", 0), torch.device("cuda", 1)]: torch_t = torch.tensor([1, 2]).to(device) flow_t = flow.utils.tensor.from_torch(torch_t) test_case.assertTrue(np.array_equal(torch_t.cpu().numpy(), flow_t.numpy())) test_case.assertEqual(torch_t.cpu().numpy().dtype, flow_t.numpy().dtype) test_case.assertEqual(torch_device_to_flow(torch_t.device), flow_t.device) # Test oneflow tensor and pytorch tensor share the data torch_t[0] = 5 test_case.assertTrue(np.array_equal(torch_t.cpu().numpy(), flow_t.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_functional_docstr.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import inspect import os import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _run_functional_doctest( test_case, globs=None, verbose=None, optionflags=0, raise_on_error=True, module=flow, ): import doctest parser = doctest.DocTestParser() if raise_on_error: runner = doctest.DebugRunner(verbose=verbose, optionflags=optionflags) else: runner = doctest.DocTestRunner(verbose=verbose, optionflags=optionflags) r = inspect.getmembers(module) for (name, fun) in r: if fun.__doc__ is not None: test = parser.get_doctest(fun.__doc__, {}, __name__, __file__, 0) try: runner.run(test) except doctest.DocTestFailure as e: print(f"\nGot error result in the docstring of {name}") print(f"got output: {e.got}") raise e except doctest.UnexpectedException as e: print(f"\nGot UnexpectedException in the docstring of {name}") raise e.exc_info[1] if not raise_on_error: test_case.assertEqual( runner.failures, 0, f"{runner.summarize()}, please turn on raise_on_error to see more details", ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestFunctionalDocstrModule(flow.unittest.TestCase): def test_functional_docstr(test_case): arg_dict = OrderedDict() arg_dict["module"] = [flow, flow.Tensor, flow.sbp, flow.env, flow.nn.functional] for arg in GenArgList(arg_dict): _run_functional_doctest( test_case, raise_on_error=True, verbose=True, module=arg[0] ) if __name__ == "__main__": flow.set_printoptions(linewidth=80) unittest.main() ================================================ FILE: python/oneflow/test/modules/test_functional_scalar_tensor_param.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestFunctionalWithScalarTensorParam(flow.unittest.TestCase): # NOTE: graph mode not support dynamic scalar tensor parameter @autotest(n=2, auto_backward=False, check_graph=False) def test_scalar_tensor_transfer_to_scalar(test_case): device = random_device() min = torch.tensor(0.0) max = torch.tensor(0.5) x = random_tensor(ndim=2, dim0=2, dim1=3).to(device) return x.clamp(min=min, max=max) @autotest(n=2, auto_backward=False, check_graph=False) def test_scalar_tensor_transfer_to_double(test_case): device = random_device() threshold = torch.tensor(0.5).to(device) x = random_tensor(ndim=2, dim0=2, dim1=3).to(device) return torch.nn.functional.threshold(x, threshold=threshold, value=0.5) @autotest(n=2, auto_backward=False, check_graph=False) def test_scalar_tensor_transfer_to_int(test_case): device = random_device() start_dim = torch.tensor(1).to(device) end_dim = torch.tensor(3).to(device) x = random_tensor(4, *(2, 3, 4, 5)).to(device) return x.flatten(start_dim=start_dim, end_dim=end_dim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_attention_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import math import itertools import os import oneflow as flow def _ref( query, key, value, num_heads, attn_mask_type="none", attn_bias=None, causal_diagonal_offset=0, query_seq_len=None, key_seq_len=None, ): query = query.permute(0, 2, 1, 3) key = key.permute(0, 2, 3, 1) value = value.permute(0, 2, 1, 3) scores = flow.matmul(query, key) / math.sqrt(query.shape[-1]) if attn_mask_type == "causal_from_bottom_right": causal_diagonal_offset += key.shape[-1] - query.shape[-2] if ( attn_mask_type == "causal_from_top_left" or attn_mask_type == "causal_from_bottom_right" ): causal_mask = flow.triu( flow.ones( scores.shape[-2], scores.shape[-1], dtype=flow.bool, device="cuda" ), causal_diagonal_offset + 1, ) scores = flow.masked_fill(scores, causal_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias if query_seq_len is not None: scores = flow.masked_fill( scores, flow.arange(scores.shape[-2], device=query_seq_len.device).view( 1, 1, scores.shape[-2], 1 ) >= query_seq_len.view(scores.shape[0], 1, 1, 1), float("-inf"), ) if key_seq_len is not None: scores = flow.masked_fill( scores, flow.arange(scores.shape[-1], device=key_seq_len.device).view( 1, 1, 1, scores.shape[-1] ) >= key_seq_len.view(scores.shape[0], 1, 1, 1), float("-inf"), ) attn = flow.softmax(scores, dim=-1) out = flow.matmul(attn, value) out = out.permute(0, 2, 1, 3) out = out.reshape(out.shape[0], out.shape[1], -1) return out def _to_layout(ts, layout, tensor_index, seq_len=None): if layout == "BMHK": return ts[tensor_index] elif layout == "BM(HK)": return ts[tensor_index].view( ts[tensor_index].shape[0], ts[tensor_index].shape[1], -1 ) elif layout == "MB(HK)": return ( ts[tensor_index] .view(ts[tensor_index].shape[0], ts[tensor_index].shape[1], -1) .transpose(0, 1) ) elif layout == "BHMK": return ts[tensor_index].transpose(1, 2) elif layout == "MBHK": return ts[tensor_index].transpose(0, 1) elif layout == "BM(H3K)": return flow.stack(ts, -2).view(ts[0].shape[0], ts[0].shape[1], -1) elif layout == "MB(H3K)": return ( flow.stack(ts, -2).view(ts[0].shape[0], ts[0].shape[1], -1).transpose(0, 1) ) elif layout == "BM(H2K)": return flow.stack(ts[1:], -2).view(ts[1].shape[0], ts[1].shape[1], -1) elif layout == "MB(H2K)": return ( flow.stack(ts[1:], -2) .view(ts[1].shape[0], ts[1].shape[1], -1) .transpose(0, 1) ) elif layout == "(BM)HK": t = ts[tensor_index] if seq_len is None: return t.view(-1, t.shape[-2], t.shape[-1]) mask = flow.arange(t.shape[1], device=t.device).view( 1, t.shape[1] ) < seq_len.view(t.shape[0], 1) return flow.masked_select( t, mask.view(mask.shape[0], mask.shape[1], 1, 1) ).view(-1, t.shape[-2], t.shape[-1]) elif layout == "(BM)(HK)": t = ts[tensor_index] if seq_len is None: return t.view(-1, t.shape[-2] * t.shape[-1]) mask = flow.arange(t.shape[1], device=t.device).view( 1, t.shape[1] ) < seq_len.view(t.shape[0], 1) return flow.masked_select( t, mask.view(mask.shape[0], mask.shape[1], 1, 1) ).view(-1, t.shape[-2] * t.shape[-1]) elif layout == "(BM)(H2K)": t = flow.stack(ts[1:], -2) if seq_len is None: return t.view(t.shape[0] * t.shape[1], -1) mask = flow.arange(t.shape[1], device=t.device).view( 1, t.shape[1] ) < seq_len.view(t.shape[0], 1) return flow.masked_select( t, mask.view(mask.shape[0], mask.shape[1], 1, 1, 1) ).view(-1, t.shape[-3] * t.shape[-2] * t.shape[-1]) elif layout == "(BM)(H3K)": t = flow.stack(ts, -2) if seq_len is None: return t.view(t.shape[0] * 
t.shape[1], -1) mask = flow.arange(t.shape[1], device=t.device).view( 1, t.shape[1] ) < seq_len.view(t.shape[0], 1) return flow.masked_select( t, mask.view(mask.shape[0], mask.shape[1], 1, 1, 1) ).view(-1, t.shape[-3] * t.shape[-2] * t.shape[-1]) else: raise NotImplementedError def _fused_mha( query, key, value, num_heads, attn_mask_type="none", attn_bias=None, causal_diagonal_offset=0, query_layout="BM(HK)", key_layout="BM(HK)", value_layout="BM(HK)", output_layout="MB(HK)", query_seq_len=None, key_seq_len=None, use_kv_seq_len=False, ): batch_size = query.shape[0] query_max_seq_len = query.shape[1] query_head_size = query.shape[-1] key_max_seq_len = key.shape[1] ts = [query, key, value] query = _to_layout(ts, query_layout, 0, query_seq_len) if use_kv_seq_len: key = _to_layout(ts, key_layout, 1) value = _to_layout(ts, value_layout, 2) else: key = _to_layout(ts, key_layout, 1, key_seq_len) value = _to_layout(ts, value_layout, 2, key_seq_len) if query_seq_len is not None: query_seq_start = ( flow.cumsum(flow.pad(query_seq_len, (1, 0)), dim=-1) .to(flow.int32) .to(query.device) ) else: query_seq_start = None query_max_seq_len = None if key_seq_len is not None: if use_kv_seq_len: key_seq_start = flow.arange( 0, key_max_seq_len * (batch_size + 1), key_max_seq_len, dtype=flow.int32, device=key_seq_len.device, ) else: key_seq_start = ( flow.cumsum(flow.pad(key_seq_len, (1, 0)), dim=-1) .to(flow.int32) .to(query.device) ) else: key_seq_start = None key_max_seq_len = None if attn_bias is not None and attn_bias.shape[-1] % 8 != 0: pad = 8 - attn_bias.shape[-1] % 8 attn_bias = flow.pad(attn_bias, (0, pad), "constant", 0) output = flow._C.fused_multi_head_attention_inference_v2( query=query, key=key, value=value, query_head_size=query_head_size, attn_mask_type=attn_mask_type, attn_bias=attn_bias, causal_diagonal_offset=causal_diagonal_offset, query_layout=query_layout, key_layout=key_layout, value_layout=value_layout, output_layout=output_layout, query_seq_start=query_seq_start, key_seq_start=key_seq_start, key_seq_len=key_seq_len.to(flow.int32).to("cuda") if use_kv_seq_len else None, query_max_seq_len=query_max_seq_len, key_max_seq_len=key_max_seq_len, ) if output_layout == "BM(HK)" or output_layout == "(BM)(HK)": return output elif output_layout == "MB(HK)": return output.transpose(0, 1) else: raise NotImplementedError def _test_fused_attention_concat_past_key_value( test_case, dtype, b, past_m, m, h, k, past_key_layout, past_value_layout, key_layout, value_layout, ): if past_m > 0: past_key = flow.randn((b, past_m, h, k), device="cuda", dtype=flow.float,).to( dtype ) past_value = flow.randn((b, past_m, h, k), device="cuda", dtype=flow.float,).to( dtype ) else: past_key = None past_value = None key = flow.randn((b, m, h, k), device="cuda", dtype=flow.float,).to(dtype) value = flow.randn((b, m, h, k), device="cuda", dtype=flow.float,).to(dtype) ( fused_concated_key, fused_concated_value, ) = flow._C.fused_attention_concat_past_key_value( past_key=_to_layout([past_key, past_key, past_value], past_key_layout, 1), past_key_layout=past_key_layout, past_value=_to_layout([past_key, past_key, past_value], past_value_layout, 2), past_value_layout=past_value_layout, key=_to_layout([key, key, value], key_layout, 1), key_layout=key_layout, value=_to_layout([key, key, value], value_layout, 2), value_layout=value_layout, key_head_size=k, ) if past_m > 0: concated_key = flow.cat([past_key, key], dim=1) concated_value = flow.cat([past_value, value], dim=1) else: concated_key = key concated_value = value 
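# The branch above builds the reference by concatenating past and current
# key/value along dim=1, the sequence dimension of the (batch, seq_len, heads,
# head_size) tensors created in this helper. A quick sanity check (illustrative,
# not executed by the test; it also holds when past_m == 0):
#   assert concated_key.shape == (b, past_m + m, h, k)
#   assert concated_value.shape == (b, past_m + m, h, k)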
ref_concated_key = _to_layout( [concated_key, concated_key, concated_value], past_key_layout, 1 ) ref_concated_value = _to_layout( [concated_key, concated_key, concated_value], past_value_layout, 2 ) test_case.assertTrue( np.array_equal(fused_concated_key.numpy(), ref_concated_key.numpy()) ) test_case.assertTrue( np.array_equal(fused_concated_value.numpy(), ref_concated_value.numpy()) ) def _test_fused_multi_head_attention_inference( test_case, batch_size, num_heads, query_seq_len, kv_seq_len, query_head_size, value_head_size, dtype, attn_mask_type="none", causal_diagonal_offset=0, query_layout="BM(HK)", key_layout="BM(HK)", value_layout="BM(HK)", output_layout="BM(HK)", ): query = flow.randn( (batch_size, query_seq_len, num_heads, query_head_size), device="cuda", dtype=flow.float, ).to(dtype) key = flow.randn( (batch_size, kv_seq_len, num_heads, query_head_size), device="cuda", dtype=flow.float, ).to(dtype) value = flow.randn( (batch_size, kv_seq_len, num_heads, value_head_size), device="cuda", dtype=flow.float, ).to(dtype) fused_out = _fused_mha( query, key, value, num_heads, attn_mask_type=attn_mask_type, causal_diagonal_offset=causal_diagonal_offset, query_layout=query_layout, key_layout=key_layout, value_layout=value_layout, output_layout=output_layout, ).numpy() ref_out = _ref( query, key, value, num_heads, attn_mask_type=attn_mask_type, causal_diagonal_offset=causal_diagonal_offset, ).numpy() test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) def _test_fused_multi_head_attention_inference_with_attn_bias( test_case, batch_size, num_heads, query_seq_len, kv_seq_len, query_head_size, value_head_size, dtype, attn_mask_type="none", ): query = flow.randn( (batch_size, query_seq_len, num_heads, query_head_size), device="cuda", dtype=flow.float, ).to(dtype) key = flow.randn( (batch_size, kv_seq_len, num_heads, query_head_size), device="cuda", dtype=flow.float, ).to(dtype) value = flow.randn( (batch_size, kv_seq_len, num_heads, value_head_size), device="cuda", dtype=flow.float, ).to(dtype) attn_bias = flow.randn((kv_seq_len,), device="cuda", dtype=flow.float).to(dtype) ref_out = _ref( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() fused_out = _fused_mha( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) attn_bias = flow.randn( (query_seq_len, kv_seq_len), device="cuda", dtype=flow.float ).to(dtype) ref_out = _ref( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() fused_out = _fused_mha( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) attn_bias = flow.randn( (num_heads, query_seq_len, kv_seq_len), device="cuda", dtype=flow.float ).to(dtype) ref_out = _ref( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() fused_out = _fused_mha( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) attn_bias = flow.randn( (batch_size, num_heads, query_seq_len, kv_seq_len), device="cuda", dtype=flow.float, ).to(dtype) ref_out = _ref( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() fused_out = _fused_mha( query, key, value, num_heads, attn_bias=attn_bias, 
attn_mask_type=attn_mask_type ).numpy() test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) attn_bias = flow.randn( (num_heads, 1, kv_seq_len), device="cuda", dtype=flow.float ).to(dtype) ref_out = _ref( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() fused_out = _fused_mha( query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type ).numpy() test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) def _test_fused_multi_head_attention_inference_variable_length( test_case, batch_size, num_heads, query_seq_len, kv_seq_len, query_head_size, value_head_size, dtype, query_layout, key_layout, value_layout, use_kv_seq_len, attn_mask_type="none", causal_diagonal_offset=0, ): query = flow.randn( (batch_size, query_seq_len, num_heads, query_head_size), device="cuda", dtype=flow.float, ).to(dtype) key = flow.randn( (batch_size, kv_seq_len, num_heads, query_head_size), device="cuda", dtype=flow.float, ).to(dtype) value = flow.randn( (batch_size, kv_seq_len, num_heads, value_head_size), device="cuda", dtype=flow.float, ).to(dtype) query_seq_len_t = flow.randint( low=1, high=query.shape[1], size=(query.shape[0],), device="cuda", dtype=flow.int32, ) key_seq_len_t = flow.randint( low=1, high=key.shape[1], size=(key.shape[0],), device="cuda", dtype=flow.int32 ) fused_out = _fused_mha( query, key, value, num_heads, attn_mask_type=attn_mask_type, causal_diagonal_offset=causal_diagonal_offset, query_layout=query_layout, key_layout=key_layout, value_layout=value_layout, output_layout="(BM)(HK)", query_seq_len=query_seq_len_t, key_seq_len=key_seq_len_t, use_kv_seq_len=use_kv_seq_len, ) ref_out = _ref( query, key, value, num_heads, attn_mask_type=attn_mask_type, causal_diagonal_offset=causal_diagonal_offset, query_seq_len=query_seq_len_t, key_seq_len=key_seq_len_t, ) ref_out = ref_out.view(batch_size, query_seq_len, num_heads, value_head_size) ref_out = _to_layout([ref_out], "(BM)HK", 0, seq_len=query_seq_len_t) ref_out = ref_out.view(ref_out.shape[0], -1) test_case.assertTrue( np.allclose(ref_out.numpy(), fused_out.numpy(), atol=1e-2, rtol=1e-2) ) @unittest.skipIf(True, "skip test") @flow.unittest.skip_unless_1n1d() class TestFusedMultiHeadAttentionInference(flow.unittest.TestCase): def test_multi_head_attention_inference(test_case): # test_case,batch_size, num_heads,query_seq_len, kv_seq_len,query_head_size,value_head_size,dtype _test_fused_multi_head_attention_inference( test_case, 2, 8, 4096, 4096, 40, 40, flow.float16 ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 4096, 77, 40, 40, flow.float16 ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 1024, 1024, 80, 80, flow.float16 ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 1024, 77, 80, 80, flow.float16 ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 256, 256, 160, 160, flow.float16 ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 256, 77, 160, 160, flow.float16 ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 4096, 4096, 40, 40, flow.float ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 4096, 77, 40, 40, flow.float ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 1024, 1024, 80, 80, flow.float ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 1024, 77, 80, 80, flow.float ) _test_fused_multi_head_attention_inference( test_case, 2, 8, 256, 256, 160, 160, flow.float ) _test_fused_multi_head_attention_inference( test_case, 2, 
8, 256, 77, 160, 160, flow.float ) _test_fused_multi_head_attention_inference( test_case, 1, 8, 4, 8, 16, 16, flow.float, attn_mask_type="causal_from_top_left", causal_diagonal_offset=4, ) def test_multi_head_attention_inference_with_attn_bias(test_case): # test_case,batch_size, num_heads,query_seq_len, kv_seq_len,query_head_size,value_head_size,dtype _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 4096, 40, 40, flow.float16 ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 4096, 40, 40, flow.float ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 4096, 40, 40, flow.float16, "causal_from_top_left" ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 4096, 40, 40, flow.float, "causal_from_bottom_right" ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 80, 40, 40, flow.float16 ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 80, 40, 40, flow.float ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 80, 40, 40, flow.float16, "causal_from_top_left" ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 80, 4096, 40, 40, flow.float16, "causal_from_bottom_right" ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 80, 40, 40, flow.float, "causal_from_top_left" ) _test_fused_multi_head_attention_inference_with_attn_bias( test_case, 2, 8, 4096, 77, 40, 40, flow.float, "causal_from_top_left" ) def test_multi_head_attention_inference_with_layout(test_case): layouts = [ "BM(HK)", "BMHK", "MBHK", "BHMK", "MB(HK)", "BM(H3K)", "BM(H2K)", "MB(H3K)", "MB(H2K)", ] for query_layout, key_layout, value_layout in itertools.product( layouts, layouts, layouts ): if query_layout == "BM(H2K)" or query_layout == "MB(H2K)": continue _test_fused_multi_head_attention_inference( test_case, 2, 8, 256, 256, 160, 160, flow.float16, query_layout=query_layout, key_layout=key_layout, value_layout=value_layout, ) def test_multi_head_attention_inference_with_output_layout(test_case): layouts = [ "BM(HK)", "MB(HK)", ] for output_layout in layouts: _test_fused_multi_head_attention_inference( test_case, 2, 8, 256, 256, 160, 160, flow.float16, output_layout=output_layout, ) _test_fused_multi_head_attention_inference( test_case, 1, 8, 256, 256, 160, 160, flow.float16, output_layout=output_layout, ) def test_multi_head_attention_inference_variable_length(test_case): # test_case,batch_size, num_heads,query_seq_len, kv_seq_len,query_head_size,value_head_size,dtype layouts = ["(BM)HK", "(BM)(HK)", "(BM)(H2K)", "(BM)(H3K)"] for ( query_layout, key_layout, value_layout, use_kv_seq_len, ) in itertools.product(layouts, layouts, layouts, (False, True)): if query_layout == "(BM)(H2K)": continue _test_fused_multi_head_attention_inference_variable_length( test_case, 2, 8, 16, 16, 40, 40, flow.float16, query_layout=query_layout, key_layout=key_layout, value_layout=value_layout, use_kv_seq_len=use_kv_seq_len, ) if ( query_layout == "(BM)(H3K)" or key_layout == "(BM)(H3K)" or value_layout == "(BM)(H3K)" ): continue _test_fused_multi_head_attention_inference_variable_length( test_case, 2, 8, 16, 32, 40, 40, flow.float16, query_layout=query_layout, key_layout=key_layout, value_layout=value_layout, use_kv_seq_len=use_kv_seq_len, ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class 
TestFusedAttentionConcatPastKeyValue(flow.unittest.TestCase): def test_fused_attention_concat_past_key_value(test_case): kv_layouts = [ "BM(HK)", "BMHK", "MBHK", "BHMK", "MB(HK)", "BM(H3K)", # "BM(H2K)", # "MB(H3K)", "MB(H2K)", ] past_layouts = [ "BM(HK)", "BMHK", # "MBHK", # "BHMK", "MB(HK)", ] types = [flow.float16] for ( past_key_layout, past_value_layout, key_layout, value_layout, dtype, ) in itertools.product( past_layouts, past_layouts, kv_layouts, kv_layouts, types ): _test_fused_attention_concat_past_key_value( test_case, dtype, 1, 127, 1, 40, 128, past_key_layout=past_key_layout, past_value_layout=past_value_layout, key_layout=key_layout, value_layout=value_layout, ) _test_fused_attention_concat_past_key_value( test_case, flow.float, 1, 0, 1, 40, 128, past_key_layout="BMHK", past_value_layout="BMHK", key_layout="BMHK", value_layout="BMHK", ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_bias_add_dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_fused_bias_add_dropout(test_case, shape, axis, drop_prob): x = np.random.randn(*shape) bias = np.random.randn(shape[axis]) # fused version only support in GPU fused_x_tensor = flow.Tensor(x).to("cuda") fused_x_tensor.requires_grad = True fused_bias_tensor = flow.Tensor(bias).to("cuda") fused_bias_tensor.requires_grad = True fused_out = flow._C.fused_bias_add_dropout( fused_x_tensor, fused_bias_tensor, p=drop_prob, axis=axis ) origin_x_tensor = flow.Tensor(x).to("cuda") origin_x_tensor.requires_grad = True origin_bias_tensor = flow.Tensor(bias).to("cuda") origin_bias_tensor.requires_grad = True origin_dropout = flow.nn.Dropout(p=drop_prob) origin_out = origin_dropout( flow._C.bias_add(origin_x_tensor, origin_bias_tensor, axis=axis) ) total_out = fused_out.sum() + origin_out.sum() total_out.backward() test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bias_tensor.grad.numpy(), origin_bias_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestFusedBiasAddDropout(flow.unittest.TestCase): def test_fuse_bias_add_dropout(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_bias_add_dropout] arg_dict["shape"] = [(16, 64, 72), (32, 16, 48)] arg_dict["axis"] = [0, 1, 2, -1, -2, -3] arg_dict["drop_prob"] = [0.0, 1.0] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() 
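The elementwise comparison in _test_fused_bias_add_dropout above works only because arg_dict restricts drop_prob to 0.0 and 1.0, the two values for which dropout is deterministic. A minimal NumPy sketch of the unfused reference under that assumption (the helper name and example values below are illustrative, not part of the test suite):

import numpy as np

def bias_add_dropout_reference(x, bias, axis, drop_prob):
    # Broadcast the 1-D bias along `axis`, mirroring how the test builds its bias,
    # then apply dropout in the two deterministic cases exercised above:
    # p=0.0 keeps every element (the 1/(1-p) training scale is 1), p=1.0 zeroes everything.
    shape = [1] * x.ndim
    shape[axis] = bias.shape[0]
    y = x + bias.reshape(shape)
    if drop_prob == 0.0:
        return y
    if drop_prob == 1.0:
        return np.zeros_like(y)
    raise ValueError("reference is only defined for drop_prob in {0.0, 1.0}")

# Expected to match flow.nn.Dropout(p)(flow._C.bias_add(x, bias, axis=axis)) for p in {0.0, 1.0}.
x = np.random.randn(16, 64, 72).astype(np.float32)
bias = np.random.randn(72).astype(np.float32)
ref = bias_add_dropout_reference(x, bias, axis=-1, drop_prob=0.0)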
================================================ FILE: python/oneflow/test/modules/test_fused_bias_add_gelu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_fused_bias_add_gelu(test_case, channel, axis): x = np.random.randn(4, channel, 8, 10) bias = np.random.randn(channel) # fused version only support in GPU fused_x_tensor = flow.Tensor(x).to("cuda") fused_x_tensor.requires_grad = True fused_bias_tensor = flow.Tensor(bias).to("cuda") fused_bias_tensor.requires_grad = True fused_out = flow._C.fused_bias_add_gelu( fused_x_tensor, fused_bias_tensor, axis=axis ) origin_x_tensor = flow.Tensor(x).to("cuda") origin_x_tensor.requires_grad = True origin_bias_tensor = flow.Tensor(bias).to("cuda") origin_bias_tensor.requires_grad = True origin_out = flow.gelu( flow._C.bias_add(origin_x_tensor, origin_bias_tensor, axis=axis) ) total_out = fused_out.sum() + origin_out.sum() total_out.backward() test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bias_tensor.grad.numpy(), origin_bias_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestFusedBiasAddGelu(flow.unittest.TestCase): def test_gather(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_bias_add_gelu] arg_dict["channel"] = [2, 4, 6, 8] arg_dict["axis"] = [1] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_bias_add_scale_mask_softmax_dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import numpy as np from collections import OrderedDict import torch import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgDict def _torch_bias_add_scale_mask_softmax_dropout(x, bias, mask, fill, scale, p): masked = (x + bias) * mask * scale unmask = (1 - mask.int()).bool() masked.masked_fill_(unmask, fill) softmax_y = torch.nn.functional.softmax(masked, dim=-1) y = torch.nn.functional.dropout(softmax_y, p) return y, softmax_y def _test_bias_add_fused_scale_mask_softmax_dropout( test_case, input_shape, bias_shape, mask_shape, input_dtype=flow.float32, mask_dtype=flow.bool, fill=-10000, scale=1.0, p=0.0, device="cuda", ): # print(f"{'=' * 40} test case {'=' * 40}") # print(f"input_shap={input_shape}") # print(f"bias_shape={bias_shape}") # print(f"mask_shape={mask_shape}") # print(f"input_dtype={input_dtype}") # print(f"mask_dtype={mask_dtype}") # print(f"fill={fill}") # print(f"scale={scale}") # print(f"p={p}") np_input = np.random.randn(*input_shape).astype(np.float32) np_bias = np.random.randn(*bias_shape).astype(np.float32) np_mask = np.random.randint(0, 2, size=mask_shape).astype(np.int32) np_rand_init_grad = np.random.randn(*input_shape).astype(np.float32) torch_input = torch.tensor(np_input).to(device=device) torch_bias = torch.tensor(np_bias).to(device=device) torch_mask = torch.tensor(np_mask).to(device=device).bool() torch_rand_init_grad = torch.tensor(np_rand_init_grad).to(device=device) torch_input.requires_grad_(True) torch_bias.requires_grad_(True) torch_output, torch_softmax_output = _torch_bias_add_scale_mask_softmax_dropout( torch_input, torch_bias, torch_mask, fill, scale, p ) (torch_output * torch_rand_init_grad).sum().backward() torch_input_grad = torch_input.grad.detach().cpu() torch_bias_grad = torch_bias.grad.detach().cpu() torch_output = torch_output.detach().cpu() torch_softmax_output = torch_softmax_output.detach().cpu() input = flow.tensor(np_input, dtype=input_dtype, device=device) bias = flow.tensor(np_bias, dtype=input_dtype, device=device) mask = flow.tensor(np_mask, dtype=mask_dtype, device=device) rand_init_grad = flow.tensor(np_rand_init_grad, dtype=input_dtype, device=device) input.requires_grad_(True) bias.requires_grad_(True) output, softmax_output = flow._C.fused_bias_add_scale_mask_softmax_dropout( input, bias, mask, fill_value=fill, scale=scale, p=p, ) (output * rand_init_grad).sum().backward() input_grad = input.grad.detach().cpu() bias_grad = bias.grad.detach().cpu() output = output.to(dtype=flow.float32, device="cpu") softmax_output = softmax_output.to(dtype=flow.float32, device="cpu") def compare(a, b, a_name, b_name, atol=1e-5, rtol=1e-8): test_case.assertTrue( np.allclose(a.numpy(), b.numpy(), atol=atol, rtol=rtol), f"\n{a_name}:\n{a.numpy()}\n{'-' * 80}\n{b_name}:\n{b.numpy()}\n{'*' * 80}\ndiff:\n{a.numpy() - b.numpy()}\n{a_name} vs. 
{b_name} max_diff:\n{np.max(np.abs(a.numpy() - b.numpy()))}", ) if input_dtype == flow.float16: compare(output, torch_output, "output", "torch_output", atol=1e-3, rtol=1e-2) compare( softmax_output, torch_softmax_output, "softmax_output", "torch_softmax_output", atol=1e-3, rtol=1e-2, ) compare( input_grad, torch_input_grad, "input_grad", "torch_input_grad", atol=1e-2, rtol=1e-2, ) compare( bias_grad, torch_bias_grad, "bias_grad", "torch_bias_grad", atol=1e-2, rtol=1e-2, ) else: compare(output, torch_output, "output", "torch_output") compare( softmax_output, torch_softmax_output, "softmax_output", "torch_softmax_output", ) compare(input_grad, torch_input_grad, "input_grad", "torch_input_grad") compare(bias_grad, torch_bias_grad, "bias_grad", "torch_bias_grad") @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestFusedBiasAddScaleMaskSoftmaxDropout(flow.unittest.TestCase): def test_real_case(test_case): args_dict = OrderedDict() args_dict["input_shape"] = [[4, 12, 8, 8]] args_dict["bias_shape"] = [[1, 12, 8, 8]] args_dict["mask_shape"] = [[4, 1, 1, 8]] args_dict["input_dtype"] = [flow.float16, flow.float32] args_dict["mask_dtype"] = [flow.bool] args_dict["fill"] = [-10000.0] args_dict["scale"] = [1.0, 2.0, 4.0] args_dict["p"] = [0.0, 1.0] for kwarg in GenArgDict(args_dict): _test_bias_add_fused_scale_mask_softmax_dropout(test_case, **kwarg) def test_different_broadcast_dim(test_case): _test_bias_add_fused_scale_mask_softmax_dropout( test_case, [4, 2, 3], [1, 2, 3], [4, 1, 3] ) def test_same_broadcast_dim(test_case): _test_bias_add_fused_scale_mask_softmax_dropout( test_case, [4, 2, 3], [1, 2, 3], [1, 2, 3] ) def test_broadcast_bias(test_case): _test_bias_add_fused_scale_mask_softmax_dropout( test_case, [4, 2, 3], [1, 1, 3], [4, 2, 3] ) def test_broadcast_mask(test_case): _test_bias_add_fused_scale_mask_softmax_dropout( test_case, [4, 2, 3], [4, 2, 3], [4, 1, 3] ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_center.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def torch_center(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2): return ( (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2 ) / 4 def _test_fused_get_center_dist_impl(test_case, device, shape): def compare(a, b, rtol=1e-5, atol=1e-5): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}", ) x = [] torch_x = [] for _ in range(8): tmp = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) x.append(tmp) torch_x.append( torch.tensor( tmp.numpy(), dtype=torch.float32, device=torch.device(device), requires_grad=True, ) ) b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2 = ( x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], ) ( torch_b1_x1, torch_b1_x2, torch_b2_x1, torch_b2_x2, torch_b1_y1, torch_b1_y2, torch_b2_y1, torch_b2_y2, ) = ( torch_x[0], torch_x[1], torch_x[2], torch_x[3], torch_x[4], torch_x[5], torch_x[6], torch_x[7], ) rho2 = flow._C.fused_get_center_dist( b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2 ) torch_rho2 = torch_center( torch_b1_x1, torch_b1_x2, torch_b2_x1, torch_b2_x2, torch_b1_y1, torch_b1_y2, torch_b2_y1, torch_b2_y2, ) compare(rho2, torch_rho2) rho2.sum().backward() torch_rho2.sum().backward() compare(b1_x1.grad, torch_b1_x1.grad) compare(b1_x2.grad, torch_b1_x2.grad) compare(b2_x1.grad, torch_b2_x1.grad) compare(b2_x2.grad, torch_b2_x2.grad) compare(b1_y1.grad, torch_b1_y1.grad) compare(b1_y2.grad, torch_b1_y2.grad) compare(b2_y1.grad, torch_b2_y1.grad) compare(b2_y2.grad, torch_b2_y2.grad) @flow.unittest.skip_unless_1n1d() class TestGetCenterDistModule(flow.unittest.TestCase): def test_fused_get_center_dist(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_get_center_dist_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(583, 1), (759, 1), (1234, 1)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_codegeex_qkv_reshape.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_codegeex_qkv_reshape_impl(test_case, device, shape, num_attention_heads): query = flow.randn(shape).to("cuda") key = flow.randn(shape).to("cuda") value = flow.randn(shape).to("cuda") new_shape = ( shape[0], shape[1], num_attention_heads, shape[2] / num_attention_heads, ) new_query = query.view(new_shape) new_query = new_query.contiguous() new_key = key.view(new_shape) new_key = new_key.contiguous() new_value = value.view(new_shape) new_value = new_value.contiguous() ( fused_new_query, fused_new_key, fused_new_value, ) = flow._C.fused_codegeex_qkv_reshape(query, key, value, num_attention_heads) def compare(a, b, rtol=1e-5, atol=1e-5): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}", ) compare(new_query, fused_new_query) compare(new_key, fused_new_key) compare(new_value, fused_new_value) @flow.unittest.skip_unless_1n1d() class TestFusedCodegeexQkvReshapeModule(flow.unittest.TestCase): def test_codegeex_qkv_reshape(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_codegeex_qkv_reshape_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(32, 8, 16), (32, 8, 32)] arg_dict["num_attention_heads"] = [(4), (8)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_cross_interaction.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow def _test_fused_cross_feature_interaction_v1( test_case, batchsize, in_feature, dtype, device, ): x = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature)) weight = np.random.uniform(low=-1, high=1, size=(1, in_feature)) bias = np.random.uniform(low=-1, high=1, size=(in_feature)) x0 = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature)) fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) naive_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True) naive_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True) fused_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True) naive_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True) fused_x0 = flow.tensor(x0, dtype=dtype, device=device, requires_grad=True) naive_x0 = flow.tensor(x0, dtype=dtype, device=device, requires_grad=True) fused_out = flow._C.fused_cross_feature_interaction( fused_x, fused_weight, fused_x0, fused_bias, "vector" ) naive_out = ( flow._C.matmul(naive_x, naive_weight, transpose_b=True) * naive_x0 + naive_bias ) + naive_x total_out = fused_out.sum() + naive_out.sum() total_out.backward() test_case.assertTrue( np.allclose(fused_out.numpy(), naive_out.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(fused_x.grad.numpy(), naive_x.grad.numpy(), atol=1e-4, rtol=1e-4,) ) test_case.assertTrue( np.allclose( fused_weight.grad.numpy(), naive_weight.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose(fused_x0.grad.numpy(), naive_x0.grad.numpy(), atol=1e-4, rtol=1e-4,) ) test_case.assertTrue( np.allclose( fused_bias.grad.numpy(), naive_bias.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) def _test_fused_cross_feature_interaction_v2( test_case, batchsize, in_feature, dtype, device, ): x = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature)) weight = np.random.uniform(low=-1, high=1, size=(in_feature, in_feature)) bias = np.random.uniform(low=-1, high=1, size=(in_feature)) x0 = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature)) fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) naive_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True) naive_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True) fused_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True) naive_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True) fused_x0 = flow.tensor(x0, dtype=dtype, device=device, requires_grad=True) naive_x0 = flow.tensor(x0, dtype=dtype, device=device, requires_grad=True) fused_out = flow._C.fused_cross_feature_interaction( fused_x, fused_weight, fused_x0, fused_bias, "matrix" ) naive_out = ( flow._C.bias_add( flow._C.matmul(naive_x, naive_weight, transpose_b=True), naive_bias, axis=1 ) * naive_x0 + naive_x ) total_out = fused_out.sum() + naive_out.sum() total_out.backward() test_case.assertTrue( np.allclose(fused_out.numpy(), naive_out.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(fused_x.grad.numpy(), naive_x.grad.numpy(), atol=1e-4, rtol=1e-4,) ) test_case.assertTrue( np.allclose( fused_weight.grad.numpy(), naive_weight.grad.numpy(), atol=1e-4, 
            rtol=1e-4,
        )
    )
    test_case.assertTrue(
        np.allclose(fused_x0.grad.numpy(), naive_x0.grad.numpy(), atol=1e-4, rtol=1e-4,)
    )
    test_case.assertTrue(
        np.allclose(
            fused_bias.grad.numpy(), naive_bias.grad.numpy(), atol=1e-4, rtol=1e-4,
        )
    )


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases")
@flow.unittest.skip_unless_1n1d()
class TestFusedCrossFeatureInteraction(flow.unittest.TestCase):
    def test_fused_cross_feature_interaction_v1(test_case):
        args_dict = OrderedDict()
        args_dict["test_fun"] = [_test_fused_cross_feature_interaction_v1]
        args_dict["batchsize"] = [1, 2, 4]
        args_dict["in_feature"] = [32, 64, 96, 128]
        args_dict["dtype"] = [flow.float32]
        args_dict["device"] = ["cuda"]
        for arg in GenArgList(args_dict):
            arg[0](test_case, *arg[1:])

    def test_fused_cross_feature_interaction_v2(test_case):
        args_dict = OrderedDict()
        args_dict["test_fun"] = [_test_fused_cross_feature_interaction_v2]
        args_dict["batchsize"] = [1, 2, 4]
        args_dict["in_feature"] = [32, 64, 96, 128]
        args_dict["dtype"] = [flow.float32]
        args_dict["device"] = ["cuda"]
        for arg in GenArgList(args_dict):
            arg[0](test_case, *arg[1:])


if __name__ == "__main__":
    unittest.main()


================================================ FILE: python/oneflow/test/modules/test_fused_dot_feature_interaction.py ================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow import oneflow.unittest import os def _test_fused_dot_feature_interaction( test_case, embedding_size, self_interaction=False, output_concat=True, output_padding=0, dtype=flow.float32, device_type="cuda", ): batch_size = 100 dims = 26 if dtype == flow.float16: np_dtype = np.float16 else: np_dtype = np.float32 feature_0_np = np.random.rand(batch_size, embedding_size).astype(np_dtype) feature_1_np = np.random.rand(batch_size, 26, embedding_size).astype(np_dtype) feature_0_tensor = flow.tensor(feature_0_np, device="cuda", requires_grad=True) feature_1_tensor = flow.tensor(feature_1_np, device="cuda", requires_grad=True) if self_interaction: offset = 1 else: offset = 0 li = flow.tensor([i for i in range(27) for j in range(i + offset)]) lj = flow.tensor([j for i in range(27) for j in range(i + offset)]) T = flow.cat( [ flow.reshape(feature_0_tensor, (batch_size, 1, embedding_size)), feature_1_tensor, ], dim=1, ) Z = flow.matmul(T, T, transpose_b=True) # gather_nd not support half, so cast to float32 Z = flow.cast(Z, flow.float32) Zflat = Z[:, li, lj] Zflat = flow.cast(Zflat, dtype) if output_concat: R = flow.cat([feature_0_tensor, Zflat], dim=1) else: R = Zflat if output_padding != 0: padding_tensor = flow.tensor( np.zeros((batch_size, output_padding)).astype(np_dtype), device="cuda", requires_grad=False, ) R = flow.cat([R, padding_tensor], dim=1) loss = R.sum() loss.backward() fused_feature_0_tensor = flow.tensor( feature_0_np, device="cuda", requires_grad=True ) fused_feature_1_tensor = flow.tensor( feature_1_np, device="cuda", requires_grad=True ) if output_concat: output_concat_tensor = fused_feature_0_tensor else: output_concat_tensor = None fused_R = flow._C.fused_dot_feature_interaction( [ fused_feature_0_tensor.reshape(batch_size, 1, embedding_size), fused_feature_1_tensor, ], output_concat=output_concat_tensor, self_interaction=self_interaction, output_padding=output_padding, pooling="none", ) fused_loss = fused_R.sum() fused_loss.backward() test_case.assertTrue( np.allclose( feature_0_tensor.grad.numpy(), fused_feature_0_tensor.grad.numpy(), rtol=1e-3, atol=1e-4, ) ) test_case.assertTrue( np.allclose( feature_1_tensor.grad.numpy(), fused_feature_1_tensor.grad.numpy(), rtol=1e-3, atol=1e-4, ) ) test_case.assertTrue(np.allclose(fused_R.numpy(), R.numpy(), rtol=1e-3, atol=1e-3)) def _test_fused_dot_feature_interaction_pooling_sum( test_case, dtype, feature_dims, embedding_size, device_type="cuda", ): batch_size = 100 if dtype == flow.float16: np_dtype = np.float16 else: np_dtype = np.float32 feature_tensor_list = [] fused_feature_tensor_list = [] for dim in feature_dims: feature_np = np.random.uniform(-1, 1, (batch_size, dim, embedding_size)).astype( np_dtype ) feature_tensor = flow.tensor(feature_np, device="cuda", requires_grad=True) feature_tensor_list.append(feature_tensor) fused_feature_tensor = flow.tensor( feature_np, device="cuda", requires_grad=True ) fused_feature_tensor_list.append(fused_feature_tensor) concat = flow.cat(feature_tensor_list, dim=1,) if dtype == flow.float16: concat = flow.cast(concat, flow.float) sum_then_square = flow.sum(concat, dim=1) ** 2 square_then_sum = flow.sum(concat ** 2, dim=1) bi_interaction = (sum_then_square - square_then_sum) * 0.5 if dtype == flow.float16: bi_interaction = flow.cast(bi_interaction, flow.float16) R = flow.sum(bi_interaction, dim=-1, keepdim=True) loss = R.sum() loss.backward() 
    fused_R = flow._C.fused_dot_feature_interaction(
        fused_feature_tensor_list, pooling="sum",
    )
    fused_loss = fused_R.sum()
    fused_loss.backward()
    if dtype == flow.float16:
        tol = 1e-2
    else:
        tol = 1e-3
    for i in range(len(feature_dims)):
        test_case.assertTrue(
            np.allclose(
                feature_tensor_list[i].grad.numpy(),
                fused_feature_tensor_list[i].grad.numpy(),
                rtol=1e-3,
                atol=1e-3,
            )
        )
    test_case.assertTrue(np.allclose(fused_R.numpy(), R.numpy(), rtol=tol, atol=tol))


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases")
@flow.unittest.skip_unless_1n1d()
class FusedDotFeatureInteractionTestCase(flow.unittest.TestCase):
    def test_fused_dot_feature_interaction(test_case):
        arg_dict = OrderedDict()
        arg_dict["embedding_size"] = [128, 127, 16, 15]
        arg_dict["self_interaction"] = [False, True]
        arg_dict["output_concat"] = [True, False]
        arg_dict["output_padding"] = [1, 0]
        arg_dict["dtype"] = [flow.float16, flow.float32]
        for kwargs in GenArgDict(arg_dict):
            _test_fused_dot_feature_interaction(test_case, **kwargs)

    def test_fused_dot_feature_interaction_pooling_sum(test_case):
        arg_dict = OrderedDict()
        arg_dict["dtype"] = [flow.float16, flow.float32]
        arg_dict["feature_dims"] = [[39], [13, 26], [1, 10, 3]]
        arg_dict["embedding_size"] = [16, 11, 12]
        for kwargs in GenArgDict(arg_dict):
            _test_fused_dot_feature_interaction_pooling_sum(test_case, **kwargs)


if __name__ == "__main__":
    unittest.main()


================================================ FILE: python/oneflow/test/modules/test_fused_gelu_mul.py ================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import os import numpy as np import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgDict def _test_fused_fast_gelu_mul(test_case, shape, dtype=flow.float32): x = flow.randn(*shape).to(dtype=dtype, device="cuda").requires_grad_(True) multiplier = flow.randn(*shape).to(dtype=dtype, device="cuda").requires_grad_(True) y = flow.nn.functional.gelu(x, approximate="tanh") * multiplier y.mean().backward() x_grad = x.grad.detach().cpu() m_grad = multiplier.grad.detach().cpu() y = y.detach().cpu() fused_x = x.detach().clone().requires_grad_(True) fused_multiplier = multiplier.detach().clone().requires_grad_(True) fused_y = flow._C.fused_fast_gelu_mul(fused_x, fused_multiplier) fused_y.mean().backward() fused_x_grad = fused_x.grad.detach().cpu() fused_m_grad = fused_multiplier.grad.detach().cpu() fused_y = fused_y.detach().cpu() def compare(a, b, rtol=1e-5, atol=1e-8): test_case.assertTrue( np.allclose(a.numpy(), b.numpy(), rtol=rtol, atol=atol), f"\na\n{a.numpy()}\n{'-' * 80}\nb:\n{b.numpy()}\n{'*' * 80}\ndiff:\n{a.numpy() - b.numpy()}", ) # print(f"\n{'=' * 20} shape={shape} dtype={dtype} {'=' * 20}") if dtype == flow.float16: compare(fused_y, y, 1e-2, 1e-3) compare(fused_x_grad, x_grad, 1e-4, 1e-3) compare(fused_m_grad, m_grad, 1e-4, 1e-3) else: compare(fused_y, y) compare(fused_x_grad, x_grad) compare(fused_m_grad, m_grad) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestFusedFastGeluMul(flow.unittest.TestCase): def test_fused_fast_gelu_mul(test_case): args_dict = OrderedDict() args_dict["shape"] = [[5], [7, 10], [4, 2, 3], [8, 3, 16, 16]] args_dict["dtype"] = [flow.float16, flow.float32] for kwarg in GenArgDict(args_dict): _test_fused_fast_gelu_mul(test_case, **kwarg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_get_boundding_boxes_coord.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_get_boundding_boxes_coord_impl(test_case, device, shape): x = [] torch_x = [] for _ in range(8): tmp = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) x.append(tmp) torch_x.append( torch.tensor( tmp.numpy(), dtype=torch.float32, device=torch.device(device), requires_grad=True, ) ) x1, y1, w1, h1, x2, y2, w2, h2 = x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7] ( b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, ) = flow._C.fused_get_boundding_boxes_coord(x1, y1, w1, h1, x2, y2, w2, h2) torch_x1, torch_y1, torch_w1, torch_h1, torch_x2, torch_y2, torch_w2, torch_h2 = ( torch_x[0], torch_x[1], torch_x[2], torch_x[3], torch_x[4], torch_x[5], torch_x[6], torch_x[7], ) torch_w1_, torch_h1_, torch_w2_, torch_h2_ = ( torch_w1 / 2, torch_h1 / 2, torch_w2 / 2, torch_h2 / 2, ) torch_b1_x1, torch_b1_x2, torch_b1_y1, torch_b1_y2 = ( torch_x1 - torch_w1_, torch_x1 + torch_w1_, torch_y1 - torch_h1_, torch_y1 + torch_h1_, ) torch_b2_x1, torch_b2_x2, torch_b2_y1, torch_b2_y2 = ( torch_x2 - torch_w2_, torch_x2 + torch_w2_, torch_y2 - torch_h2_, torch_y2 + torch_h2_, ) def compare(a, b, rtol=1e-5, atol=1e-8): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}", ) compare(b1_x1, torch_b1_x1) compare(b1_x2, torch_b1_x2) compare(b1_y1, torch_b1_y1) compare(b1_y2, torch_b1_y2) compare(b2_x1, torch_b2_x1) compare(b2_x2, torch_b2_x2) compare(b2_y1, torch_b2_y1) compare(b2_y2, torch_b2_y2) res = ( (b1_x1 + 2 * b1_x2 + b1_y1 + b1_y2 + b2_x1 + b2_x2 + b2_y1 + b2_y2) * 2 ).sum() torch_res = ( ( torch_b1_x1 + 2 * torch_b1_x2 + torch_b1_y1 + torch_b1_y2 + torch_b2_x1 + torch_b2_x2 + torch_b2_y1 + torch_b2_y2 ) * 2 ).sum() res.sum().backward() torch_res.sum().backward() compare(x1.grad, torch_x1.grad) compare(y1.grad, torch_y1.grad) compare(w1.grad, torch_w1.grad) compare(h1.grad, torch_h1.grad) compare(x2.grad, torch_x2.grad) compare(y2.grad, torch_y2.grad) compare(w2.grad, torch_w2.grad) compare(h2.grad, torch_h2.grad) @flow.unittest.skip_unless_1n1d() class TestGetBounddingBoxesCoordModule(flow.unittest.TestCase): def test_get_boundding_boxes_coord(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_get_boundding_boxes_coord_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(583, 1), (759, 1), (1234, 1)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_get_ciou_diagonal_angle.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import math import numpy as np import torch from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def torch_get_ciou_diagonal_angle(w1, h1, w2, h2, eps=1e-8): return (4 / math.pi ** 2) * torch.pow( torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps)), 2 ) def _test_fused_get_ciou_diagonal_angle_impl(test_case, device, shape): def compare(a, b, rtol=1e-5, atol=1e-5): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}", ) x = [] torch_x = [] for _ in range(4): tmp = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) x.append(tmp) torch_x.append( torch.tensor( tmp.numpy(), dtype=torch.float32, device=torch.device(device), requires_grad=True, ) ) w1, h1, w2, h2 = ( x[0], x[1], x[2], x[3], ) (torch_w1, torch_h1, torch_w2, torch_h2,) = ( torch_x[0], torch_x[1], torch_x[2], torch_x[3], ) v = flow._C.fused_get_ciou_diagonal_angle(w1, h1, w2, h2, eps=1e-8) torch_v = torch_get_ciou_diagonal_angle(torch_w1, torch_h1, torch_w2, torch_h2,) compare(v, torch_v) v.sum().backward() torch_v.sum().backward() compare(w1.grad, torch_w1.grad) compare(h1.grad, torch_h1.grad) compare(w2.grad, torch_w2.grad) compare(h2.grad, torch_h2.grad) @flow.unittest.skip_unless_1n1d() class TestGetCiouDiagonalAngle(flow.unittest.TestCase): def test_fused_get_ciou_diagonal_angle(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_get_ciou_diagonal_angle_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(583, 1), (759, 1), (1234, 1)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_get_ciou_result.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_get_ciou_result_impl(test_case, device, shape): eps = 1e-7 x = [] torch_x = [] for _ in range(4): tmp = flow.tensor( np.random.uniform(0, 1, shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) x.append(tmp) torch_x.append( torch.tensor( tmp.numpy(), dtype=torch.float32, device=torch.device(device), requires_grad=True, ) ) v, iou, rho2, c2 = x[0], x[1], x[2], x[3] y = flow._C.fused_get_ciou_result(v, iou, rho2, c2, eps)[0] torch_v, torch_iou, torch_rho2, torch_c2 = ( torch_x[0], torch_x[1], torch_x[2], torch_x[3], ) with torch.no_grad(): torch_alpha = torch_v / (torch_v - torch_iou + (1.0 + eps)) torch_y = torch_iou - (torch_rho2 / torch_c2 + torch_v * torch_alpha) def compare(a, b, rtol=1e-5, atol=1e-5): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}", ) compare(y, torch_y) res = y.sum() torch_res = torch_y.sum() res.backward() torch_res.backward() compare(v.grad, torch_v.grad) compare(iou.grad, torch_iou.grad) compare(rho2.grad, torch_rho2.grad) compare(c2.grad, torch_c2.grad) @flow.unittest.skip_unless_1n1d() class TestGetCiouResultModule(flow.unittest.TestCase): def test_get_ciou_result(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_get_ciou_result_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(492), (691, 1), (1162, 1)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_get_convex_diagonal_squared.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def torch_fused_get_convex_diagonal_squared( b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, eps ): cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) c2 = cw ** 2 + ch ** 2 + eps return c2 def _test_fused_get_convex_diagonal_squared_impl(test_case, device, shape): def compare(a, b, rtol=1e-5, atol=1e-5): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}", ) eps = 1e-8 x = [] torch_x = [] for _ in range(8): tmp = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) x.append(tmp) torch_x.append( torch.tensor( tmp.numpy(), dtype=torch.float32, device=torch.device(device), requires_grad=True, ) ) b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2 = ( x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], ) ( torch_b1_x1, torch_b1_x2, torch_b2_x1, torch_b2_x2, torch_b1_y1, torch_b1_y2, torch_b2_y1, torch_b2_y2, ) = ( torch_x[0], torch_x[1], torch_x[2], torch_x[3], torch_x[4], torch_x[5], torch_x[6], torch_x[7], ) c2 = flow._C.fused_get_convex_diagonal_squared( b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, eps ) torch_c2 = torch_fused_get_convex_diagonal_squared( torch_b1_x1, torch_b1_x2, torch_b2_x1, torch_b2_x2, torch_b1_y1, torch_b1_y2, torch_b2_y1, torch_b2_y2, eps, ) compare(c2, torch_c2) c2.sum().backward() torch_c2.sum().backward() compare(b1_x1.grad, torch_b1_x1.grad) compare(b1_x2.grad, torch_b1_x2.grad) compare(b2_x1.grad, torch_b2_x1.grad) compare(b2_x2.grad, torch_b2_x2.grad) compare(b1_y1.grad, torch_b1_y1.grad) compare(b1_y2.grad, torch_b1_y2.grad) compare(b2_y1.grad, torch_b2_y1.grad) compare(b2_y2.grad, torch_b2_y2.grad) @flow.unittest.skip_unless_1n1d() class TestGetCenterDistModule(flow.unittest.TestCase): def test_fused_get_convex_diagonal_squared(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_get_convex_diagonal_squared_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(583, 1), (759, 1), (1234, 1)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_get_intersection_area.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def torch_get_intersection_area(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2): return (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * ( torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1) ).clamp(0) def _test_fused_get_intersection_area_impl(test_case, device, shape): def compare(a, b, rtol=1e-5, atol=1e-5): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\n", ) x = [] torch_x = [] for _ in range(8): tmp = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) x.append(tmp) torch_x.append( torch.tensor( tmp.numpy(), dtype=torch.float32, device=torch.device(device), requires_grad=True, ) ) b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2 = ( x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], ) ( torch_b1_x1, torch_b1_x2, torch_b2_x1, torch_b2_x2, torch_b1_y1, torch_b1_y2, torch_b2_y1, torch_b2_y2, ) = ( torch_x[0], torch_x[1], torch_x[2], torch_x[3], torch_x[4], torch_x[5], torch_x[6], torch_x[7], ) inter = flow._C.fused_get_intersection_area( b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2 ) torch_inter = torch_get_intersection_area( torch_b1_x1, torch_b1_x2, torch_b2_x1, torch_b2_x2, torch_b1_y1, torch_b1_y2, torch_b2_y1, torch_b2_y2, ) compare(inter, torch_inter) inter.sum().backward() torch_inter.sum().backward() compare(b1_x1.grad, torch_b1_x1.grad) compare(b1_x2.grad, torch_b1_x2.grad) compare(b2_x1.grad, torch_b2_x1.grad) compare(b2_x2.grad, torch_b2_x2.grad) compare(b1_y1.grad, torch_b1_y1.grad) compare(b1_y2.grad, torch_b1_y2.grad) compare(b2_y1.grad, torch_b2_y1.grad) compare(b2_y2.grad, torch_b2_y2.grad) @flow.unittest.skip_unless_1n1d() class TestGetIntersectionAreaModule(flow.unittest.TestCase): def test_fused_get_inter_intersection_area(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_get_intersection_area_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(583, 1), (759, 1), (1234, 1)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_get_iou.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_get_iou_impl(test_case, device, shape): eps = 1e-7 x = [] torch_x = [] for _ in range(5): tmp = flow.tensor( np.random.uniform(0, 1, shape), dtype=flow.float64, device=flow.device(device), requires_grad=True if (_ < 2 or _ > 3) else False, ) x.append(tmp) torch_x.append( torch.tensor( tmp.numpy(), dtype=torch.float64, device=torch.device(device), requires_grad=True if (_ < 2 or _ > 3) else False, ) ) w1, h1, w2, h2, inter = x[0], x[1], x[2], x[3], x[4] iou = flow._C.fused_get_iou(w1, h1, w2, h2, inter, eps) torch_w1, torch_h1, torch_w2, torch_h2, torch_inter = ( torch_x[0], torch_x[1], torch_x[2], torch_x[3], torch_x[4], ) torch_iou = torch_inter / ( torch_w1 * torch_h1 + torch_w2 * torch_h2 - torch_inter + eps ) def compare(a, b, rtol=1e-5, atol=1e-5, w1=w1, h1=h1, w2=w2, h2=h2, inter=inter): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}", ) compare(iou, torch_iou) res = iou.sum() torch_res = torch_iou.sum() res.backward() torch_res.backward() compare(w1.grad, torch_w1.grad) compare(h1.grad, torch_h1.grad) compare(inter.grad, torch_inter.grad) @flow.unittest.skip_unless_1n1d() class TestGetIouModule(flow.unittest.TestCase): def test_get_iou(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_get_iou_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(492), (691, 1), (1162, 1)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_glu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import time import datetime import numpy as np from collections import OrderedDict import oneflow as flow import oneflow.nn as nn import oneflow.unittest from oneflow.test_utils.test_util import GenArgList test_dualgemm_impt = False class Glu(nn.Module): def __init__(self): super().__init__() def forward( self, x: flow.Tensor, w: flow.Tensor, b: flow.Tensor = None, v: flow.Tensor = None, c: flow.Tensor = None, split_mode: bool = False, activation: str = "none", ) -> flow.Tensor: # matmul matmul_wx = flow._C.matmul( input=x, other=w, transpose_a=False, transpose_b=True ) if split_mode: matmul_vx = flow._C.matmul( input=x, other=v, transpose_a=False, transpose_b=True ) # add bias if b != None: matmul_wx_b = flow._C.add(input=matmul_wx, other=b) if split_mode: matmul_vx_c = flow._C.add(input=matmul_vx, other=c) else: matmul_wx_b = matmul_wx if split_mode: matmul_vx_c = matmul_vx # chunk if split_mode: hidden_state = matmul_wx_b gate = matmul_vx_c else: hidden_state, gate = matmul_wx_b.chunk(2, dim=-1) # activation and element-wise product if activation == "none": return hidden_state * gate elif activation == "sigmoid": return hidden_state * flow.sigmoid(gate) elif activation == "relu": return hidden_state * flow.relu(gate) elif activation == "gelu": return hidden_state * flow.gelu(gate) elif activation == "fast_gelu": return hidden_state * flow._C.fast_gelu(gate) elif activation == "silu": return hidden_state * flow.silu(gate) def tensor_builder(params: dict, dtype=flow.float32, is_split_mode=True): # config test data m = params["m"] n = params["n"] k = params["k"] # generate random input x = np.random.randn(2, m, k) / 100 y_nor = np.random.randn(2, m, n) if is_split_mode: w = np.random.randn(n, k) / 100 # transpose b = np.random.randn(n) / 100 v = np.random.randn(n, k) / 100 # transpose c = np.random.randn(n) / 100 else: w = np.random.randn(n * 2, k) / 100 # transpose b = np.random.randn(n * 2) / 100 # transfer to gpu memory tensor_x = flow.FloatTensor(x).to(dtype=dtype, device="cuda") tensor_y_nor = flow.FloatTensor(y_nor).to(dtype=dtype, device="cuda") tensor_w = flow.FloatTensor(w).to(dtype=dtype, device="cuda").requires_grad_(True) tensor_b = flow.FloatTensor(b).to(dtype=dtype, device="cuda").requires_grad_(True) if is_split_mode: tensor_v = ( flow.FloatTensor(v).to(dtype=dtype, device="cuda").requires_grad_(True) ) tensor_c = ( flow.FloatTensor(c).to(dtype=dtype, device="cuda").requires_grad_(True) ) if is_split_mode: return tensor_x, tensor_w, tensor_b, tensor_v, tensor_c, tensor_y_nor else: return tensor_x, tensor_w, tensor_b, tensor_y_nor def compare_result(test_case, a, b, rtol=1e-5, atol=1e-8): test_case.assertTrue( np.allclose(a.numpy(), b.numpy(), rtol=rtol, atol=atol), f"\na\n{a.numpy()}\n{'-' * 80}\nb:\n{b.numpy()}\n{'*' * 80}\ndiff:\n{a.numpy() - b.numpy()}", ) def _test_fused_glu(test_case, params: dict, dtype=flow.float32): print(f"========== Start Testing ==========") print(f"weight tensor: merged") print(f'tensor shape: m={params["m"]}, n={params["n"]}, k={params["k"]}') print(f'activation: {params["act"]}') print(f"dtype: {dtype}") flow_module = Glu() x, w, b, y_nor = tensor_builder(params=params, dtype=dtype, is_split_mode=False) # forward y = flow_module.forward(x=x, w=w, b=b, split_mode=False, activation=params["act"]) # backward y.sum().backward() # copy back to cpu memory w_grad = w.grad.detach().cpu() b_grad = b.grad.detach().cpu() y = y.detach().cpu() fused_x = x.detach().clone() fused_w = w.detach().clone().requires_grad_(True) 
fused_b = b.detach().clone().requires_grad_(True) # forward fused_y = flow._C.fused_glu( x=fused_x, w=fused_w, b=fused_b, v=None, c=None, activation=params["act"] ) # backward fused_y.sum().backward() # copy back to cpu memory fused_w_grad = fused_w.grad.detach().cpu() fused_b_grad = fused_b.grad.detach().cpu() fused_y = fused_y.detach().cpu() if dtype == flow.float16: compare_result(test_case, fused_y, y, 1e-2, 1e-3) compare_result(test_case, fused_w_grad, w_grad, 1e-2, 1e-1) compare_result(test_case, fused_b_grad, b_grad, 1e-2, 1e-1) else: compare_result(test_case, fused_y, y) compare_result(test_case, fused_w_grad, w_grad, 1e-5, 1e-2) compare_result(test_case, fused_b_grad, b_grad, 1e-5, 1e-2) print(f"============== PASSED =============") print("\n") def _test_fused_glu_without_bias(test_case, params: dict, dtype=flow.float32): print(f"========== Start Testing ==========") print(f"weight tensor: merged") print(f"no bias") print(f'tensor shape: m={params["m"]}, n={params["n"]}, k={params["k"]}') print(f'activation: {params["act"]}') print(f"dtype: {dtype}") flow_module = Glu() x, w, b, y_nor = tensor_builder(params=params, dtype=dtype, is_split_mode=False) # forward y = flow_module.forward(x=x, w=w, split_mode=False, activation=params["act"]) # backward y.sum().backward() # copy back to cpu memory w_grad = w.grad.detach().cpu() y = y.detach().cpu() fused_x = x.detach().clone() fused_w = w.detach().clone().requires_grad_(True) # forward fused_y = flow._C.fused_glu( x=fused_x, w=fused_w, b=None, v=None, c=None, activation=params["act"] ) # backward fused_y.sum().backward() # copy back to cpu memory fused_w_grad = fused_w.grad.detach().cpu() fused_y = fused_y.detach().cpu() if dtype == flow.float16: compare_result(test_case, fused_y, y, 1e-2, 1e-3) compare_result(test_case, fused_w_grad, w_grad, 1e-2, 1e-1) else: compare_result(test_case, fused_y, y) compare_result(test_case, fused_w_grad, w_grad, 1e-5, 1e-2) print(f"============== PASSED =============") print("\n") def _test_fused_glu_split(test_case, params: dict, dtype=flow.float32): print(f"========== Start Testing ==========") print(f"weight tensor: splited") print(f'tensor shape: m={params["m"]}, n={params["n"]}, k={params["k"]}') print(f'activation: {params["act"]}') print(f"dtype: {dtype}") flow_module = Glu() x, w, b, v, c, y_nor = tensor_builder( params=params, dtype=dtype, is_split_mode=True ) # forward y = flow_module.forward( x=x, w=w, b=b, v=v, c=c, split_mode=True, activation=params["act"] ) # backward y.sum().backward() # copy back to cpu memory w_grad = w.grad.detach().cpu() b_grad = b.grad.detach().cpu() v_grad = v.grad.detach().cpu() c_grad = c.grad.detach().cpu() y = y.detach().cpu() fused_x = x.detach().clone() fused_w = w.detach().clone().requires_grad_(True) fused_b = b.detach().clone().requires_grad_(True) fused_v = v.detach().clone().requires_grad_(True) fused_c = c.detach().clone().requires_grad_(True) # forward fused_y = flow._C.fused_glu( x=fused_x, w=fused_w, b=fused_b, v=fused_v, c=fused_c, activation=params["act"] ) # backward fused_y.sum().backward() fused_w_grad = fused_w.grad.detach().cpu() fused_b_grad = fused_b.grad.detach().cpu() fused_v_grad = fused_v.grad.detach().cpu() fused_c_grad = fused_c.grad.detach().cpu() fused_y = fused_y.detach().cpu() if dtype == flow.float16: compare_result(test_case, fused_y, y, 1e-2, 1e-3) compare_result(test_case, fused_w_grad, w_grad, 1e-2, 1e-1) compare_result(test_case, fused_b_grad, b_grad, 1e-2, 1e-1) compare_result(test_case, fused_v_grad, v_grad, 1e-2, 1e-1) 
        compare_result(test_case, fused_c_grad, c_grad, 1e-2, 1e-1)
    else:
        compare_result(test_case, fused_y, y)
        compare_result(test_case, fused_w_grad, w_grad, 1e-5, 1e-2)
        compare_result(test_case, fused_b_grad, b_grad, 1e-5, 1e-2)
        compare_result(test_case, fused_v_grad, v_grad, 1e-5, 1e-2)
        compare_result(test_case, fused_c_grad, c_grad, 1e-5, 1e-2)
    print(f"============== PASSED =============")
    print("\n")


def _test_fused_glu_split_without_bias(test_case, params: dict, dtype=flow.float32):
    print(f"========== Start Testing ==========")
    print(f"weight tensor: split")
    print(f"no bias")
    print(f'tensor shape: m={params["m"]}, n={params["n"]}, k={params["k"]}')
    print(f'activation: {params["act"]}')
    print(f"dtype: {dtype}")

    flow_module = Glu()
    x, w, b, v, c, y_nor = tensor_builder(
        params=params, dtype=dtype, is_split_mode=True
    )

    # forward
    y = flow_module.forward(x=x, w=w, v=v, split_mode=True, activation=params["act"])

    # backward
    y.sum().backward()

    # copy back to cpu memory
    w_grad = w.grad.detach().cpu()
    v_grad = v.grad.detach().cpu()
    y = y.detach().cpu()

    fused_x = x.detach().clone()
    fused_w = w.detach().clone().requires_grad_(True)
    fused_v = v.detach().clone().requires_grad_(True)

    # forward
    fused_y = flow._C.fused_glu(
        x=fused_x, w=fused_w, b=None, v=fused_v, c=None, activation=params["act"]
    )

    # backward
    fused_y.sum().backward()

    fused_w_grad = fused_w.grad.detach().cpu()
    fused_v_grad = fused_v.grad.detach().cpu()
    fused_y = fused_y.detach().cpu()

    if dtype == flow.float16:
        compare_result(test_case, fused_y, y, 1e-2, 1e-3)
        compare_result(test_case, fused_w_grad, w_grad, 1e-2, 1e-1)
        compare_result(test_case, fused_v_grad, v_grad, 1e-2, 1e-1)
    else:
        compare_result(test_case, fused_y, y)
        compare_result(test_case, fused_w_grad, w_grad, 1e-5, 1e-2)
        compare_result(test_case, fused_v_grad, v_grad, 1e-5, 1e-2)
    print(f"============== PASSED =============")
    print("\n")


# @flow.unittest.skip_unless_1n1d()
# @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@unittest.skipIf(True, "CI test taking too long.")
class TestFusedGlu(flow.unittest.TestCase):
    @unittest.skip("skip for now, because it failed 4 times in the past week")
    def test_fused_glu(test_case):
        arg_dict = OrderedDict()
        # set up test functions
        arg_dict["test_fun"] = [
            _test_fused_glu,
            _test_fused_glu_split,
            _test_fused_glu_without_bias,
            _test_fused_glu_split_without_bias,
        ]
        # set up the env variable if necessary
        if not test_dualgemm_impt:
            os.environ["ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL"] = "false"
        else:
            os.environ["ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL"] = "true"
        # set up profiling functions
        if not test_dualgemm_impt:
            arg_dict["params"] = [
                # m=256, k=1280, n=5120
                {"m": 256, "k": 1280, "n": 5120, "act": "none"},
                {"m": 256, "k": 1280, "n": 5120, "act": "sigmoid"},
                {"m": 256, "k": 1280, "n": 5120, "act": "relu"},
                {"m": 256, "k": 1280, "n": 5120, "act": "gelu"},
                {"m": 256, "k": 1280, "n": 5120, "act": "fast_gelu"},
                {"m": 256, "k": 1280, "n": 5120, "act": "silu"},
                # m=1024, k=640, n=2560
                {"m": 1024, "k": 640, "n": 2560, "act": "none"},
                {"m": 1024, "k": 640, "n": 2560, "act": "sigmoid"},
                {"m": 1024, "k": 640, "n": 2560, "act": "relu"},
                {"m": 1024, "k": 640, "n": 2560, "act": "gelu"},
                {"m": 1024, "k": 640, "n": 2560, "act": "fast_gelu"},
                {"m": 1024, "k": 640, "n": 2560, "act": "silu"},
                # m=4096, k=320, n=1280
                # {"m": 4096, "k": 320, "n": 1280, "act": "none"},
                # {"m": 4096, "k": 320, "n": 1280, "act": "sigmoid"},
                # {"m": 4096, "k": 320, "n": 1280, "act": "relu"},
                # {"m": 4096, "k": 320, "n": 1280, "act": "gelu"},
                # {"m": 4096, "k": 320, "n":
1280, "act": "fast_gelu"}, # {"m": 4096, "k": 320, "n": 1280, "act": "silu"}, # m=2560, k=12800, n=51200 # {"m": 2560, "k": 1280, "n": 5120, "act": "none"}, # {"m": 2560, "k": 1280, "n": 5120, "act": "sigmoid"}, # {"m": 2560, "k": 1280, "n": 5120, "act": "relu"}, # {"m": 2560, "k": 1280, "n": 5120, "act": "gelu"}, # {"m": 2560, "k": 1280, "n": 5120, "act": "fast_gelu"}, # {"m": 2560, "k": 1280, "n": 5120, "act": "silu"}, ] else: arg_dict["params"] = [ # m=256, k=1280, n=5120 {"m": 256, "k": 1280, "n": 5120, "act": "fast_gelu"}, # m=1024, k=640, n=2560 {"m": 1024, "k": 640, "n": 2560, "act": "fast_gelu"}, # m=4096, k=320, n=1280 {"m": 4096, "k": 320, "n": 1280, "act": "fast_gelu"}, # m=2560, k=12800, n=51200 {"m": 2560, "k": 1280, "n": 5120, "act": "fast_gelu"}, ] if not test_dualgemm_impt: arg_dict["dtype"] = [flow.float16, flow.float32] else: arg_dict["dtype"] = [flow.float16] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_matmul_bias.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList import oneflow as flow import numpy as np def _matmul_bias(x, weight, bias, add_to_output): return flow._C.add( flow._C.bias_add( flow._C.matmul(x, weight, transpose_b=True), bias, axis=len(x.shape) - 1 ), add_to_output, ) def _test_fused_matmul_add_bias( test_case, batchsize, in_feature, out_feature, _add_to_output, dtype, device, ): add_to_output = np.zeros((*batchsize, out_feature)) if _add_to_output: add_to_output = np.random.uniform( low=-1, high=1, size=(*batchsize, out_feature) ) x = np.random.uniform(low=-1, high=1, size=(*batchsize, in_feature)) weight = np.random.uniform(low=-1, high=1, size=(out_feature, in_feature)) bias = np.random.uniform(low=-1, high=1, size=(out_feature)) naive_x = flow.tensor(x, dtype=dtype, requires_grad=True) naive_weight = flow.tensor(weight, dtype=dtype, requires_grad=True) naive_bias = flow.tensor(bias, dtype=dtype, requires_grad=True) naive_add_to_output = flow.tensor(add_to_output, dtype=dtype, requires_grad=True) fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True) fused_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True) fused_add_to_output = None if _add_to_output: fused_add_to_output = flow.tensor( add_to_output, dtype=dtype, device=device, requires_grad=False ) navie_y = _matmul_bias(naive_x, naive_weight, naive_bias, naive_add_to_output) fused_y = flow._C.fused_matmul_bias( fused_x, fused_weight, fused_bias, fused_add_to_output ) y = navie_y.sum() + fused_y.sum() y.backward() # TODO: relative error might be too high... 
# Test output equality if _add_to_output: test_case.assertTrue( np.allclose(navie_y.numpy(), fused_y.numpy(), atol=5e-2, rtol=1e-4) ) else: test_case.assertTrue( np.allclose(navie_y.numpy(), fused_y.numpy(), atol=5e-2, rtol=1e-4) ) # Test grad equality test_case.assertTrue( np.allclose(naive_x.grad.numpy(), fused_x.grad.numpy(), atol=5e-2, rtol=1e-4) ) test_case.assertTrue( np.allclose( naive_weight.grad.numpy(), fused_weight.grad.numpy(), atol=5e-2, rtol=1e-4 ) ) test_case.assertTrue( np.allclose( naive_bias.grad.numpy(), fused_bias.grad.numpy(), atol=1e-4, rtol=1e-4 ) ) @flow.unittest.skip_unless_1n1d() class TestFusedMatmulBiasAddRelu(flow.unittest.TestCase): def test_fused_matmul_op(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_fused_matmul_add_bias] args_dict["batchsize"] = [ (1,), (4,), (8,), (2, 4), (2, 4, 8), (2, 4, 4, 4, 8), ] args_dict["in_feature"] = [96, 128] args_dict["out_feature"] = [512, 1024, 288, 1] args_dict["_add_to_output"] = [True] args_dict["dtype"] = [flow.float32, flow.float64] args_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_matmul_bias_add_relu_dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow def _matmul_bias_relu(x, weight, bias, skip_activate): # We do not add dropout in unittest, cause its result is random. 
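    # Reference layer: relu(x @ weight.T + bias); the fused op is exercised with
    # dropout_rate_list=[0.0, ...] further below so both paths stay deterministic.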
out = flow._C.bias_add(flow._C.matmul(x, weight, transpose_b=True), bias, axis=1) if not skip_activate: out = flow._C.relu(out) return out def _test_fused_matmul_bias_add_relu_dropout( test_case, batchsize, in_feature, hidden_size_list, out_feature, skip_final_activation, dtype, device, ): x = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature)) fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) naive_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_weight_list = [] naive_weight_list = [] fused_bias_list = [] naive_bias_list = [] hidden_num = len(hidden_size_list) if hidden_num != 0: np_first_weight = np.random.uniform( low=-1, high=1, size=(hidden_size_list[0], in_feature) ) np_first_bias = np.random.uniform(low=-1, high=1, size=hidden_size_list[0]) fused_weight_list.append( flow.tensor(np_first_weight, dtype=dtype, device=device, requires_grad=True) ) fused_bias_list.append( flow.tensor(np_first_bias, dtype=dtype, device=device, requires_grad=True) ) naive_weight_list.append( flow.tensor(np_first_weight, dtype=dtype, device=device, requires_grad=True) ) naive_bias_list.append( flow.tensor(np_first_bias, dtype=dtype, device=device, requires_grad=True) ) for idx in range(1, hidden_num): np_weight = np.random.uniform( low=-1, high=1, size=(hidden_size_list[idx], hidden_size_list[idx - 1]) ) np_bias = np.random.uniform(low=-1, high=1, size=hidden_size_list[idx]) fused_weight_list.append( flow.tensor(np_weight, dtype=dtype, device=device, requires_grad=True) ) fused_bias_list.append( flow.tensor(np_bias, dtype=dtype, device=device, requires_grad=True) ) naive_weight_list.append( flow.tensor(np_weight, dtype=dtype, device=device, requires_grad=True) ) naive_bias_list.append( flow.tensor(np_bias, dtype=dtype, device=device, requires_grad=True) ) np_final_weight = np.random.uniform(low=-1, high=1, size=(out_feature, in_feature)) if hidden_num != 0: np_final_weight = np.random.uniform( low=-1, high=1, size=(out_feature, hidden_size_list[-1]) ) np_final_bias = np.random.uniform(low=-1, high=1, size=(out_feature)) fused_weight_list.append( flow.tensor(np_final_weight, dtype=dtype, device=device, requires_grad=True) ) fused_bias_list.append( flow.tensor(np_final_bias, dtype=dtype, device=device, requires_grad=True) ) naive_weight_list.append( flow.tensor(np_final_weight, dtype=dtype, device=device, requires_grad=True) ) naive_bias_list.append( flow.tensor(np_final_bias, dtype=dtype, device=device, requires_grad=True) ) fused_out = flow._C.fused_matmul_bias_add_relu_dropout( fused_x, fused_weight_list, fused_bias_list, # We do not add dropout in unittest, cause its result is random. 
dropout_rate_list=[0.0] * len(fused_weight_list), skip_final_activation=skip_final_activation, ) naive_out = _matmul_bias_relu( naive_x, naive_weight_list[0], naive_bias_list[0], False if hidden_num != 0 else skip_final_activation, ) for idx in range(1, hidden_num + 1): if idx == hidden_num: naive_out = _matmul_bias_relu( naive_out, naive_weight_list[idx], naive_bias_list[idx], skip_final_activation, ) else: naive_out = _matmul_bias_relu( naive_out, naive_weight_list[idx], naive_bias_list[idx], False ) total_out = fused_out.sum() + naive_out.sum() total_out.backward() test_case.assertTrue( np.allclose(fused_out.numpy(), naive_out.numpy(), atol=1e-4, rtol=1e-4) ) # Test weight grad equality for idx in range(hidden_num + 1): test_case.assertTrue( np.allclose( fused_weight_list[idx].grad.numpy(), naive_weight_list[idx].grad.numpy(), atol=1e-4, rtol=1e-4, ) ) test_case.assertTrue( np.allclose( fused_bias_list[idx].grad.numpy(), naive_bias_list[idx].grad.numpy(), atol=1e-4, rtol=1e-4, ) ) # Test dx equality test_case.assertTrue( np.allclose(fused_x.grad.numpy(), naive_x.grad.numpy(), atol=1e-4, rtol=1e-4) ) @flow.unittest.skip_unless_1n1d() class TestFusedMatmulBiasAddReluDropout(flow.unittest.TestCase): def test_fused_matmul_bias_add_relu_dropout(test_case): args_dict = OrderedDict() args_dict["test_func"] = [_test_fused_matmul_bias_add_relu_dropout] args_dict["batchsize"] = [1, 2, 4] args_dict["in_feature"] = [96, 128, 64] args_dict["hidden_size_list"] = [[256, 512], [400, 400, 400, 400], [17, 33, 79]] args_dict["out_feature"] = [512, 400, 1024, 1] args_dict["skip_final_activation"] = [False] args_dict["dtype"] = [flow.float32] args_dict["device"] = ["cuda"] for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_rotary_embedding.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList import oneflow as flow import numpy as np import math def plane_shuffle(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return np.concatenate((-x2, x1), axis=-1) def shuffle_adjacent_two_elem(x): y = x.copy() for i in range(x.shape[-1] // 2): y[..., 2 * i] = -x[..., 2 * i + 1] y[..., 2 * i + 1] = x[..., 2 * i] return y def parseDims(dims, x_layout): B = 1 M = 1 H = 1 K = 1 merged_dims = dims if x_layout == "BHMK": B = dims[0] H = dims[1] M = dims[2] K = dims[3] merged_dims = dims # no merge elif x_layout == "BMHK": B = dims[0] M = dims[1] H = dims[2] K = dims[3] merged_dims = dims elif x_layout == "MBHK": B = dims[1] M = dims[0] H = dims[2] K = dims[3] merged_dims = dims elif x_layout == "BM(HK)": B = dims[0] M = dims[1] H = dims[2] K = dims[3] merged_dims = [dims[0], dims[1], dims[2] * dims[3]] elif x_layout == "MB(HK)": B = dims[1] M = dims[0] H = dims[2] K = dims[3] merged_dims = [dims[0], dims[1], dims[2] * dims[3]] elif x_layout == "BM(H3K)": B = dims[0] M = dims[1] H = dims[2] K = dims[3] merged_dims = [dims[0], dims[1], 3 * dims[2] * dims[3]] elif x_layout == "MB(H3K)": B = dims[1] M = dims[0] H = dims[2] K = dims[3] merged_dims = [dims[0], dims[1], 3 * dims[2] * dims[3]] return B, M, H, K, merged_dims # all cos&sin are by default in x_layout (B, H, M, K), in which H is 1 def naive_embedding( x, cos, sin, x_layout, B, M, H, K, dims, merged_dims, rotary_size, rotary_ndims, mode, ): naive_out = None if mode == "plane": if rotary_ndims == 2: y1 = plane_shuffle(x[..., : rotary_size // 2]) y2 = plane_shuffle(x[..., rotary_size // 2 : rotary_size]) y3 = x[..., rotary_size:] y = np.concatenate((y1, y2, y3), axis=-1) else: y1 = plane_shuffle(x[..., :rotary_size]) y2 = x[..., rotary_size:] y = np.concatenate((y1, y2), axis=-1) else: y = shuffle_adjacent_two_elem(x) if x_layout == "BHMK": naive_out = x * cos + y * sin elif x_layout == "BMHK": naive_out = x.reshape(dims) * cos.reshape([B, M, 1, K]) + y.reshape( dims ) * sin.reshape( [B, M, 1, K] ) # un-merge elif x_layout == "MBHK" or x_layout == "MB(HK)": naive_out = x.reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) + y.reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) # un-merge elif x_layout == "BM(HK)": naive_out = x.reshape(dims) * cos.reshape([B, M, 1, K]) + y.reshape( dims ) * sin.reshape( [B, M, 1, K] ) # un-merge elif x_layout == "BM(H3K)": out0 = x[..., 0, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ ..., 0, : ].reshape(dims) * sin.reshape([B, M, 1, K]) out1 = x[..., 1, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ ..., 1, : ].reshape(dims) * sin.reshape([B, M, 1, K]) out2 = x[..., 2, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ ..., 2, : ].reshape(dims) * sin.reshape([B, M, 1, K]) naive_out = np.concatenate((out0, out1, out2), axis=-1) elif x_layout == "MB(H3K)": out0 = x[..., 0, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) + y[..., 0, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) out1 = x[..., 1, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) + y[..., 1, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) out2 = x[..., 2, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) + y[..., 2, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) naive_out = np.concatenate((out0, out1, out2), axis=-1) return naive_out # this assume that 
rotary_ndims is by default 1 def _test_without_position( test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device ): B, M, H, K, merged_dims = parseDims(dims, x_layout) np.random.seed(3124) x = np.random.uniform(low=-1, high=1, size=(*merged_dims,)) naive_cos = np.array( [ [ [ math.cos( m * ( (1 / base) ** ( 2 * ((i % (rotary_size / rotary_ndims)) // 2) / (rotary_size / rotary_ndims) ) ) ) for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_sin = np.array( [ [ [ math.sin( m * ( (1 / base) ** ( 2 * ((i % (rotary_size / rotary_ndims)) // 2) / (rotary_size / rotary_ndims) ) ) ) for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_cos[..., rotary_size:] = 1 naive_sin[..., rotary_size:] = 0 naive_x = x if x_layout == "BM(HK)" or x_layout == "BM(H2K)" or x_layout == "BM(H3K)": naive_x = x.reshape([B, M, H, -1, K]) elif x_layout == "MB(HK)" or x_layout == "MB(H2K)" or x_layout == "MB(H3K)": naive_x = x.reshape([M, B, H, -1, K]) naive_out = naive_embedding( naive_x, naive_cos, naive_sin, x_layout, B, M, H, K, dims, merged_dims, rotary_size, rotary_ndims, mode, ) fused_cos = np.array( [ [ math.cos( m * ( (1 / base) ** ( 2 * ((i % (rotary_size // rotary_ndims)) // 2) / (rotary_size / rotary_ndims) ) ) ) for i in range(rotary_size // rotary_ndims) ] for m in range(M) ] ).reshape(M, rotary_size // rotary_ndims) fused_sin = np.array( [ [ math.sin( m * ( (1 / base) ** ( 2 * ((i % (rotary_size // rotary_ndims)) // 2) / (rotary_size // rotary_ndims) ) ) ) for i in range(rotary_size // rotary_ndims) ] for m in range(M) ] ).reshape(M, rotary_size // rotary_ndims) fused_x = flow.tensor(x, dtype=dtype, device=device) fused_cos = flow.tensor(fused_cos, dtype=dtype, device=device) fused_sin = flow.tensor(fused_sin, dtype=dtype, device=device) if x_layout == "BM(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, sin=fused_sin, position_ids=None, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=0, ) out1 = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, sin=fused_sin, position_ids=None, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=1, ) out2 = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, sin=fused_sin, position_ids=None, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=2, ) fused_out = np.concatenate((out0, out1, out2), axis=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, sin=fused_sin, position_ids=None, x_layout=x_layout, k_size=K, base=base, rotary_size=rotary_size, mode=mode, ).numpy() test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), fused_out.reshape(merged_dims), atol=5e-2, rtol=5e-3, ) ) # this assume that rotary_ndims is by default 1 def _test_without_position_sinuous( test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device ): B, M, H, K, merged_dims = parseDims(dims, x_layout) x = np.random.uniform(low=-1, high=1, size=(*merged_dims,)) naive_cos = np.array( [ [ [ math.cos( m * ( (1 / base) ** ( 2 * ((i % (rotary_size // rotary_ndims)) // 2) / (rotary_size // rotary_ndims) ) ) ) for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_sin = np.array( [ [ [ math.sin( m * ( (1 / base) ** ( 2 * ((i % (rotary_size // rotary_ndims)) // 2) / (rotary_size // rotary_ndims) ) ) ) for i in range(K) 
] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_cos[..., rotary_size:] = 1 naive_sin[..., rotary_size:] = 0 naive_x = x if x_layout == "BM(HK)" or x_layout == "BM(H2K)" or x_layout == "BM(H3K)": naive_x = x.reshape([B, M, H, -1, K]) elif x_layout == "MB(HK)" or x_layout == "MB(H2K)" or x_layout == "MB(H3K)": naive_x = x.reshape([M, B, H, -1, K]) naive_out = naive_embedding( naive_x, naive_cos, naive_sin, x_layout, B, M, H, K, dims, merged_dims, rotary_size, rotary_ndims, mode, ) fused_x = flow.tensor(x, dtype=dtype, device=device) if x_layout == "BM(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=None, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=0, ) out1 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=None, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=1, ) out2 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=None, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=2, ) fused_out = np.concatenate((out0, out1, out2), axis=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=None, x_layout=x_layout, k_size=K, base=base, rotary_size=rotary_size, mode=mode, ).numpy() test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), fused_out.reshape(merged_dims), atol=5e-2, rtol=5e-3, ) ) def _test_with_position_sinuous( test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device ): B, M, H, K, merged_dims = parseDims(dims, x_layout) np.random.seed(3124) x = np.random.uniform(low=-1, high=1, size=(*merged_dims,)) position_ids = np.random.randint(2 * M, size=(B, rotary_ndims, M), dtype=np.int64) naive_cos = np.array( [ [ [ math.cos( position_ids[b, i // ((rotary_size) // rotary_ndims), m] * ( (1 / base) ** ( 2 * ((i % (rotary_size // rotary_ndims)) // 2) / (rotary_size // rotary_ndims) ) ) ) if i < rotary_size else 1 for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_sin = np.array( [ [ [ math.sin( position_ids[b, i // ((rotary_size) // rotary_ndims), m] * ( (1 / base) ** ( 2 * ((i % (rotary_size // rotary_ndims)) // 2) / (rotary_size // rotary_ndims) ) ) ) if i < rotary_size else 0 for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_cos[..., rotary_size:] = 1 naive_sin[..., rotary_size:] = 0 naive_x = x if x_layout == "BM(HK)" or x_layout == "BM(H2K)" or x_layout == "BM(H3K)": naive_x = x.reshape([B, M, H, -1, K]) elif x_layout == "MB(HK)" or x_layout == "MB(H2K)" or x_layout == "MB(H3K)": naive_x = x.reshape([M, B, H, -1, K]) naive_out = naive_embedding( naive_x, naive_cos, naive_sin, x_layout, B, M, H, K, dims, merged_dims, rotary_size, rotary_ndims, mode, ) fused_cos = np.array( [ [ math.cos( m * ( (1 / base) ** ( 2 * ((i % (rotary_size // rotary_ndims)) // 2) / (rotary_size // rotary_ndims) ) ) ) for i in range(rotary_size // rotary_ndims) ] for m in range(2 * M) ] ) fused_sin = np.array( [ [ math.sin( m * ( (1 / base) ** ( 2 * ((i % (rotary_size // rotary_ndims)) // 2) / (rotary_size // rotary_ndims) ) ) ) for i in range(rotary_size // rotary_ndims) ] for m in range(2 * M) ] ) fused_x = flow.tensor(x, dtype=dtype, device=device) fused_cos = flow.tensor(fused_cos, dtype=dtype, device=device) fused_sin = flow.tensor(fused_sin, dtype=dtype, 
device=device) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) if x_layout == "BM(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, sin=fused_sin, position_ids=fused_position_ids, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=0, ) out1 = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, sin=fused_sin, position_ids=fused_position_ids, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=1, ) out2 = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, sin=fused_sin, position_ids=fused_position_ids, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=2, ) fused_out = np.concatenate((out0, out1, out2), axis=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, sin=fused_sin, position_ids=fused_position_ids, x_layout=x_layout, k_size=K, base=base, rotary_size=rotary_size, mode=mode, ).numpy() test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), fused_out.reshape(merged_dims), atol=5e-2, rtol=5e-3, ) ) def _test_with_position( test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device ): B, M, H, K, merged_dims = parseDims(dims, x_layout) x = np.random.uniform(low=-1, high=1, size=(*merged_dims,)) position_ids = np.random.randint(2 * M, size=(B, rotary_ndims, M), dtype=int) naive_cos = np.array( [ [ [ math.cos( position_ids[b, i // ((rotary_size) // rotary_ndims), m] * ( (1 / base) ** ( 2 * ((i % (rotary_size / rotary_ndims)) // 2) / (rotary_size / rotary_ndims) ) ) ) if i < rotary_size else 1 for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_sin = np.array( [ [ [ math.sin( position_ids[b, i // ((rotary_size) // rotary_ndims), m] * ( (1 / base) ** ( 2 * ((i % (rotary_size / rotary_ndims)) // 2) / (rotary_size / rotary_ndims) ) ) ) if i < rotary_size else 0 for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_x = x if x_layout == "BM(HK)" or x_layout == "BM(H2K)" or x_layout == "BM(H3K)": naive_x = x.reshape([B, M, H, -1, K]) elif x_layout == "MB(HK)" or x_layout == "MB(H2K)" or x_layout == "MB(H3K)": naive_x = x.reshape([M, B, H, -1, K]) naive_out = naive_embedding( naive_x, naive_cos, naive_sin, x_layout, B, M, H, K, dims, merged_dims, rotary_size, rotary_ndims, mode, ) fused_x = flow.tensor(x, dtype=dtype, device=device) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) if x_layout == "BM(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=fused_position_ids, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=0, ) out1 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=fused_position_ids, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=1, ) out2 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=fused_position_ids, x_layout=x_layout, output_layout="BMHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=2, ) fused_out = np.concatenate((out0, out1, out2), axis=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=fused_position_ids, x_layout=x_layout, k_size=K, base=base, rotary_size=rotary_size, mode=mode, ).numpy() 
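    # Compare the fused kernel against the NumPy reference in the merged layout;
    # the tolerances (atol=5e-2, rtol=5e-3) are loose because the fused path runs
    # in reduced precision (float16 in these tests) while the reference uses float64.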
test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), fused_out.reshape(merged_dims), atol=5e-2, rtol=5e-3, ) ) def _test_plane( test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device ): B, M, H, K, merged_dims = parseDims(dims, x_layout) np.random.seed(3124) x = np.random.uniform(low=-1, high=1, size=(*merged_dims,)) position_ids = np.random.randint(2 * M, size=(B, rotary_ndims, M), dtype=int) naive_cos = np.array( [ [ [ math.cos( position_ids[b, i // ((rotary_size) // rotary_ndims), m] * ( 1 / ( base ** ( 2 * (i % (rotary_size // (2 * rotary_ndims))) / (rotary_size / rotary_ndims) ) ) ) ) if i < rotary_size else 1 for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_sin = np.array( [ [ [ math.sin( position_ids[b, i // ((rotary_size) // rotary_ndims), m] * ( 1 / ( base ** ( 2 * (i % (rotary_size // (2 * rotary_ndims))) / (rotary_size / rotary_ndims) ) ) ) ) if i < rotary_size else 0 for i in range(K) ] for m in range(M) ] for b in range(B) ] ).reshape(B, 1, M, K) naive_x = x if x_layout == "BM(HK)" or x_layout == "BM(H2K)" or x_layout == "BM(H3K)": naive_x = x.reshape([B, M, H, -1, K]) elif x_layout == "MB(HK)" or x_layout == "MB(H2K)" or x_layout == "MB(H3K)": naive_x = x.reshape([M, B, H, -1, K]) naive_out = naive_embedding( naive_x, naive_cos, naive_sin, x_layout, B, M, H, K, dims, merged_dims, rotary_size, rotary_ndims, mode, ) fused_x = flow.tensor(x, dtype=dtype, device=device) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) if x_layout == "MB(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=fused_position_ids, x_layout=x_layout, output_layout="MBHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=0, ) out1 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=fused_position_ids, x_layout=x_layout, output_layout="MBHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=1, ) out2 = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=fused_position_ids, x_layout=x_layout, output_layout="MBHK", k_size=K, base=base, rotary_size=rotary_size, mode=mode, tensor_index=2, ) fused_out = np.concatenate((out0, out1, out2), axis=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, cos=None, sin=None, position_ids=fused_position_ids, x_layout=x_layout, k_size=K, base=base, rotary_size=rotary_size, mode=mode, ).numpy() test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), fused_out.reshape(merged_dims), atol=5e-2, rtol=5e-3, ) ) """ 1. if cos&sin is given, then base will not be used 2. if cos&sin is not given, then any form of x_layout which cannot infer the dimension of k is not allowed, e.g. BM(HK) 3. if position_ids is given, then M of cos&sin could be different from M of x 4. 
if position_ids is not given, the dimension of rotary positional embedding is by default 1 """ @flow.unittest.skip_unless_1n1d() class TestFusedRotaryEmbedding(flow.unittest.TestCase): # because rule no.2, kernels without cos&sin cannot work under specific x_layout def test_fused_rotary_embedding_op_plane(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_plane] args_dict["x_layout"] = ["MB(H3K)", "MB(HK)"] args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4, 8] args_dict["dims"] = [(3, 2, 5, 8)] args_dict["rotary_ndims"] = [2, 1] # args_dict["rotary_size"] = [48] # args_dict["dims"] = [(32, 2048, 32, 64)] args_dict["dtype"] = [flow.float16] args_dict["device"] = ["cuda"] for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) def test_fused_rotary_embedding_op_interval_2d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_with_position, _test_with_position_sinuous] args_dict["x_layout"] = ["BMHK"] args_dict["mode"] = ["interval"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4] args_dict["dims"] = [(3, 2, 5, 8)] args_dict["rotary_ndims"] = [2] # args_dict["rotary_size"] = [48] # args_dict["dims"] = [(32, 2048, 32, 64)] args_dict["dtype"] = [flow.float16] args_dict["device"] = ["cuda"] for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) def test_fused_rotary_embedding_op_interval_1d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [ _test_without_position_sinuous, _test_without_position, _test_with_position, _test_with_position_sinuous, ] args_dict["x_layout"] = ["BMHK"] args_dict["mode"] = ["interval"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4] args_dict["dims"] = [(3, 2, 5, 8)] args_dict["rotary_ndims"] = [1] # args_dict["rotary_size"] = [48] # args_dict["dims"] = [(32, 2048, 32, 64)] args_dict["dtype"] = [flow.float16] args_dict["device"] = ["cuda"] for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_scale_mask_bias_softmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import os from typing import List import time import oneflow as flow def timing(fn): def wrapper(*args, **kwargs): if args[-1] or kwargs.get("inplace"): return fn(*args, **kwargs) for _ in range(10): fn(*args, **kwargs) flow.cuda.synchronize() start = time.perf_counter() for _ in range(10): fn(*args, **kwargs) flow.cuda.synchronize() print(f"{fn.__name__}:{time.perf_counter() - start}") return fn(*args, **kwargs) return wrapper def permute_final_dims(tensor: flow.Tensor, inds: List[int]): zero_index = -1 * len(inds) first_inds = list(range(len(tensor.shape[:zero_index]))) return tensor.permute(first_inds + [zero_index + i for i in inds]) @timing def _fused_op(x, v, scale, mask, bias, inplace=False): out = flow._C.fused_scale_mask_bias_softmax(x, mask, bias, scale, inplace=inplace) out = flow.matmul(out, v) return out @timing def _ref_op(x, v, scale, mask, bias=None, inplace=False): x = x * scale + mask + bias if bias is not None else x * scale + mask out = flow.softmax(x, dim=-1) out = flow.matmul(out, v) return out def _test_fused_scale_mask_bias_softmax( test_case, N=512, S=128, D=128, h=8, d=32, mode="row", ensemble_batch=8, inplace=False, ): x = flow.randn(N, S, D, requires_grad=True).cuda() # N, S, D w3 = [flow.randn(D, h * d, requires_grad=True).cuda() for _ in range(3)] # D, h*d*3 mask = flow.randn(N, S, requires_grad=False).cuda() # N, S bias = None scale = 1 / (d ** 0.5) if mode in ["row", "triangular_start", "triangular_end"]: bias = flow.randn(1, h, S, S, requires_grad=True).cuda() # 1, h, S, S bias.retain_grad() mask = mask[:, None, None, :] if mode == "ensemble": x = flow.randn(ensemble_batch, N, S, D, requires_grad=True).cuda() # N, S, D bias = flow.randn( ensemble_batch, 1, h, S, S, requires_grad=True ).cuda() # E, 1, h, S, S bias.retain_grad() mask = flow.randn(ensemble_batch, N, 1, 1, S, requires_grad=False).cuda() if mode == "col" or mode == "global_col": N, S = S, N x = x.transpose(-2, -3) # S, N, D mask = mask.transpose(-1, -2) if mode == "col": mask = mask[..., None, None, :] # S, 1, 1, N q, k, v = [flow.matmul(x, w) for w in w3] # N, S, h * d if mode == "template": n_templ = 4 x = flow.randn(S, S, 1, D, requires_grad=True).cuda() k = v = flow.randn(S, S, n_templ, D, requires_grad=True).cuda() # N, S, D mask = flow.randn(1, 1, 1, 1, n_templ).cuda() q, k, v = [flow.matmul(x_, w) for x_, w in zip([x, k, v], w3)] q, k, v = [ permute_final_dims(a.view(*a.shape[:-1], h, d), (0, 2, 1, 3)) for a in [q, k, v] ] # N, h, S, d if mode == "global_col": w_q = flow.randn(D, h * d, requires_grad=True).cuda() # D, h*d w_kv = flow.randn(D, d * 2, requires_grad=True).cuda() # D, h*d*2 q = flow.sum(x * mask.unsqueeze(-1), dim=-2) / ( flow.sum(mask, dim=-1)[..., None] + 1e-9 ) # [N, D] mask = mask[..., :, None, :] # N,1,S q = flow.matmul(q, w_q).view(*q.shape[:-1], h, d) # N, h, d k, v = flow.matmul(x, w_kv).chunk(2, dim=-1) # N, S, d qk = flow.matmul(q, k.transpose(-1, -2)) # general op x.retain_grad() out1 = _ref_op(qk, v, scale, mask, bias, inplace) out1.sum().backward(retain_graph=True) grad_x1 = x.grad grad_bias1 = bias.grad if bias is not None else None # fused op out2 = _fused_op(qk, v, scale, mask, bias, inplace) out2.sum().backward() grad_x2 = x.grad grad_bias2 = bias.grad if bias is not None else None test_case.assertTrue(np.allclose(out1, out2, atol=2e-3, rtol=1e-5)) test_case.assertTrue(np.allclose(grad_x1, grad_x2, atol=5e-3, rtol=1e-5)) if bias is 
not None: test_case.assertTrue(np.allclose(grad_bias1, grad_bias2, atol=5e-4, rtol=1e-5)) @unittest.skipIf(True, "skip test for fused_scale_mask_bias_softmax.") @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestFusedMsaSoftmax(flow.unittest.TestCase): def test_fused_msa_softmax(test_case): # different mask shape for each mode _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, "row") _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, "col") _test_fused_scale_mask_bias_softmax( test_case, 16, 128, 64, 8, 32, "triangular_start" ) _test_fused_scale_mask_bias_softmax( test_case, 16, 128, 64, 8, 32, "triangular_end" ) _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, "template") _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, "global_col") _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, "ensemble") _test_fused_scale_mask_bias_softmax( test_case, 16, 128, 64, 8, 32, "row", inplace=True ) _test_fused_scale_mask_bias_softmax( test_case, 16, 128, 64, 8, 32, "col", inplace=True ) _test_fused_scale_mask_bias_softmax( test_case, 128, 128, 64, 8, 32, "triangular_start", inplace=True ) _test_fused_scale_mask_bias_softmax( test_case, 16, 128, 64, 8, 32, "triangular_end", inplace=True ) _test_fused_scale_mask_bias_softmax( test_case, 16, 128, 64, 8, 32, "template", inplace=True ) _test_fused_scale_mask_bias_softmax( test_case, 16, 128, 64, 8, 32, "global_col", inplace=True ) _test_fused_scale_mask_bias_softmax( test_case, 16, 128, 64, 8, 32, "ensemble", inplace=True ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_scale_mask_softmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_fused_scale_mask_softmax( test_case, batch_size, num_heads, seq_length, fill_value, scale_value, broadcast_dim ): x = np.random.randn(batch_size, num_heads, seq_length, seq_length).astype( np.float32 ) mask_size = [batch_size, num_heads, seq_length, seq_length] if broadcast_dim: mask_size[broadcast_dim] = 1 mask = np.random.randint(0, 2, size=mask_size, dtype=bool) fused_x_tensor = flow.tensor(x, dtype=flow.float32).to("cuda") fused_mask_tensor = flow.tensor(mask, dtype=flow.bool).to("cuda") fused_x_tensor.requires_grad = True fused_out = flow._C.fused_scale_mask_softmax( fused_x_tensor, fused_mask_tensor, fill_value=fill_value, scale=scale_value, ) origin_x_tensor = flow.tensor(x).to("cuda") origin_mask_tensor = flow.tensor(mask, dtype=flow.float32).to("cuda") origin_x_tensor.requires_grad = True origin_out = flow.mul( origin_x_tensor, origin_mask_tensor ) * scale_value + fill_value * (1.0 - origin_mask_tensor) origin_out = flow.softmax(origin_out, dim=-1) total_out = fused_out.sum() + origin_out.sum() total_out.backward() test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestFusedScaleMaskSoftmax(flow.unittest.TestCase): def test_fused_op(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_fused_scale_mask_softmax] args_dict["batch_size"] = [4, 8, 16] args_dict["num_heads"] = [1, 4, 8] args_dict["seq_length"] = [16, 32, 64] args_dict["fill_value"] = [-10000.0] args_dict["scale_value"] = [1.0, 2.0, 4.0] args_dict["broadcast_dim"] = [None, 0, 1, 2] for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_scale_mask_softmax_dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_fused_scale_mask_softmax_dropout( test_case, batch_size, num_heads, seq_length, fill_value, scale_value, broadcast_dim, p, ): x = np.random.randn(batch_size, num_heads, seq_length, seq_length) mask_size = [batch_size, num_heads, seq_length, seq_length] if broadcast_dim: mask_size[broadcast_dim] = 1 mask = np.random.randint(0, 2, size=mask_size, dtype=bool) fused_x_tensor = flow.tensor(x, dtype=flow.float32).to("cuda") fused_mask_tensor = flow.tensor(mask, dtype=flow.bool).to("cuda") fused_x_tensor.requires_grad = True # if mask is zero, fill it fused_out = flow._C.fused_scale_mask_softmax_dropout( fused_x_tensor, fused_mask_tensor, fill_value=fill_value, scale=scale_value, p=p, )[0] origin_x_tensor = flow.tensor(x, dtype=flow.float32).to("cuda") origin_mask_tensor = flow.tensor(mask, dtype=flow.float32).to("cuda") origin_x_tensor.requires_grad = True origin_out = flow.mul( origin_x_tensor, origin_mask_tensor ) * scale_value + fill_value * (1.0 - origin_mask_tensor) origin_out = flow.softmax(origin_out, dim=-1) origin_out = flow._C.dropout(origin_out, p=p) total_out = fused_out.sum() + origin_out.sum() total_out.backward() test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestFusedScaleMaskSoftmaxDropout(flow.unittest.TestCase): def test_fused_op(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_fused_scale_mask_softmax_dropout] args_dict["batch_size"] = [4, 8, 16] args_dict["num_heads"] = [1, 4, 8] args_dict["seq_length"] = [8, 16, 32, 64] args_dict["fill_value"] = [-10000.0] args_dict["scale_value"] = [1.0, 2.0, 4.0] args_dict["broadcast_dim"] = [None, 0, 1, 2] args_dict["p"] = [0.0, 1.0] for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_scale_tril.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import numpy as np from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict import oneflow as flow def _np_tril(x, diagonal, fill_value, scale): if int(fill_value) == 0: return np.tril(x, diagonal) * scale upper = np.empty(x.shape) upper.fill(fill_value) upper = np.triu(upper, diagonal + 1) return np.tril(x, diagonal) * scale + upper def _test_fused_scale_tril( test_case, shape, diagonal=0, fill_value=0, scale=1, dtype=flow.float32, device_type="cuda", ): if dtype is flow.int32 and not isinstance(scale, int): return if dtype is flow.int32: x = np.random.randint(0, 10, shape) y_grad = np.random.randint(0, 10, shape) else: x = np.random.rand(*shape) y_grad = np.random.rand(*shape) y = _np_tril(x, diagonal, fill_value, scale) x_grad = _np_tril(y_grad, diagonal, 0, scale) flow_x = flow.tensor( x, device=flow.device(device_type), dtype=dtype, requires_grad=True ) flow_y = flow._C.fused_scale_tril(flow_x, diagonal, fill_value, scale) flow_y_grad = flow.tensor(y_grad, device=flow.device(device_type), dtype=dtype) flow_y.backward(flow_y_grad) flow_y_np = flow_y.numpy() test_case.assertTrue(np.allclose(flow_y_np, y.astype(flow_y_np.dtype))) flow_x_grad_np = flow_x.grad.numpy() test_case.assertTrue( np.allclose(flow_x_grad_np, x_grad.astype(flow_x_grad_np.dtype)) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class FusedScaleTrilTestCase(flow.unittest.TestCase): def test_fused_scale_tril(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(5, 5), (4, 6)] arg_dict["diagonal"] = [-1, 0, 1] arg_dict["fill_value"] = [-1, 0, 1] arg_dict["scale"] = [-2.3, 0.7, 2] arg_dict["dtype"] = [flow.float32] for kwargs in GenArgDict(arg_dict): _test_fused_scale_tril(test_case, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_self_attention.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_fused_self_attention(test_case, batch_size, seq_len, num_heads, head_size): hidden_size = num_heads * 3 * head_size x = np.random.randn(seq_len, batch_size, hidden_size) fused_input = flow.Tensor(x).to("cuda") fused_input.requires_grad = True (fused_qmk, fused_v) = flow._C.fused_self_attention( fused_input, head_size=head_size, alpha=1.0, ) fused_atten = flow.matmul(fused_qmk, fused_v) fused_atten_sum = fused_atten.sum() origin_input = flow.Tensor(x).to("cuda") origin_input.requires_grad = True reshape_input = flow.reshape(origin_input, (seq_len, batch_size, -1, 3 * head_size)) origin_q = flow.slice( reshape_input, slice_tup_list=[ [None, None, None], [None, None, None], [None, None, None], [0, head_size, 1], ], ).permute(1, 2, 0, 3) origin_k = flow.slice( reshape_input, slice_tup_list=[ [None, None, None], [None, None, None], [None, None, None], [head_size, 2 * head_size, 1], ], ).permute(1, 2, 0, 3) origin_v = flow.slice( reshape_input, slice_tup_list=[ [None, None, None], [None, None, None], [None, None, None], [2 * head_size, 3 * head_size, 1], ], ).permute(1, 2, 0, 3) origin_k = origin_k.transpose(2, 3) origin_qmk = flow.matmul(origin_q, origin_k) origin_atten = flow.matmul(origin_qmk, origin_v) origin_atten_sum = origin_atten.sum() total_sum = fused_atten_sum + origin_atten_sum total_sum.backward() test_case.assertTrue( np.allclose(fused_atten.numpy(), origin_atten.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose( fused_input.grad.numpy(), origin_input.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestFusedSelfAttention(flow.unittest.TestCase): def _test_fused_self_attention(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_self_attention] arg_dict["batch_size"] = [1, 4, 6, 8] arg_dict["seq_len"] = [5, 10, 12] arg_dict["num_heads"] = [4, 8, 16] arg_dict["head_size"] = [16, 32, 64] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_tril_softmax_mask_scale.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_fused_tril_softmax_mask_scale( test_case, seq_length, channel, p, diagonal, tril_scale_value ): x = np.random.randn(4, seq_length, channel) # fused version only support in GPU fused_x_tensor = flow.Tensor(x).to("cuda") fused_x_tensor.requires_grad = True fused_out = flow._C.fused_scale_tril_softmax_mask_scale( fused_x_tensor, p=p, diagonal=diagonal, tril_scale_value=tril_scale_value )[ 0 ] # The second output is softmax_y origin_x_tensor = flow.Tensor(x).to("cuda") origin_x_tensor.requires_grad = True origin_out = flow.tril(origin_x_tensor, diagonal) origin_out = origin_out * tril_scale_value origin_out = flow.softmax(origin_out, dim=-1) origin_out = flow._C.dropout(origin_out, p=p) total_out = fused_out.sum() + origin_out.sum() total_out.backward() test_case.assertTrue( np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose( fused_x_tensor.grad.numpy(), origin_x_tensor.grad.numpy(), atol=1e-4, rtol=1e-4, ) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test gpu cases") class TestFusedTrilSoftmaxMaskScale(flow.unittest.TestCase): def test_fused_tril_softmax_dropout(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_tril_softmax_mask_scale] arg_dict["seq_length"] = [10, 20] arg_dict["channel"] = [20, 30] arg_dict["p"] = [0.0, 1.0] arg_dict["diagonal"] = [0, 1, 2] arg_dict["tril_scale_value"] = [2, 4, 10] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_fused_weighted_sum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import math import random import os import oneflow as flow def _ref(inputs, weights, alpha, init_grad, device, dtype): inputs = [flow.tensor(t).to(device).to(dtype) for t in inputs] for t in inputs: t.requires_grad = True init_grad = flow.tensor(init_grad).to(device).to(dtype) out = inputs[0] * weights[0] for i, w in zip(inputs[1:], weights[1:]): out += i * w out = out * alpha out.backward(init_grad) return out, [t.grad for t in inputs] def _fused_weighted_sum(inputs, weights, alpha, init_grad, device, dtype): inputs = [flow.tensor(t).to(device).to(dtype) for t in inputs] for t in inputs: t.requires_grad = True init_grad = flow.tensor(init_grad).to(device).to(dtype) out = flow._C.fused_weighted_sum(inputs, weights, alpha) out.backward(init_grad) return out, [t.grad for t in inputs] def _test_fused_weighted_sum(test_case, shape, n, device, dtype): inputs = [np.random.randn(*shape) for _ in range(n)] init_grad = np.random.randn(*shape) weights = [random.random() for _ in range(n)] alpha = random.random() out, grads = _fused_weighted_sum(inputs, weights, alpha, init_grad, device, dtype) ref, ref_grads = _ref(inputs, weights, alpha, init_grad, device, dtype) test_case.assertTrue(np.allclose(ref, out, atol=1e-5, rtol=1e-5)) for (grad, ref_grad) in zip(grads, ref_grads): test_case.assertTrue(np.allclose(ref_grad, grad, atol=1e-5, rtol=1e-5)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestFusedWeightedSum(flow.unittest.TestCase): def test_fused_weighted_sum(test_case): _test_fused_weighted_sum(test_case, (1024, 1024), 1, "cuda", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 3, "cuda", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 8, "cuda", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 11, "cuda", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 21, "cuda", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 1, "cpu", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 3, "cpu", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 8, "cpu", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 11, "cpu", flow.float32) _test_fused_weighted_sum(test_case, (1024, 1024), 21, "cpu", flow.float32) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_gather.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest def _scatter_add_numpy(src, dim, index, outshape): output = np.zeros(outshape) for srcidx in range(0, src.size): outcoord = np.unravel_index(srcidx, src.shape) outcoord = [*outcoord] outcoord[dim] = index[np.unravel_index(srcidx, index.shape)] output_offset = np.ravel_multi_index(outcoord, outshape) output[np.unravel_index(output_offset, outshape)] += src[ np.unravel_index(srcidx, src.shape) ] return output def _test_gather(test_case, device): input = np.array([[1, 2], [3, 4]]) index = np.array([[0, 0], [1, 0]]) np_out = np.take_along_axis(input, index, 0) output = flow.gather( flow.tensor(input, dtype=flow.float32, device=flow.device(device)), 0, flow.tensor(index, dtype=flow.int64, device=flow.device(device)), ) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) def _test_gather_tensor_function(test_case, device): input = np.array([[1, 2], [3, 4]]) index = np.array([[0, 0], [1, 0]]) np_out = np.take_along_axis(input, index, 1) input = flow.tensor(input, dtype=flow.float32, device=flow.device(device)) index = flow.tensor(index, dtype=flow.int64, device=flow.device(device)) output = input.gather(1, index) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) def _test_gather_random_array(test_case, device): input = np.random.randn(3, 4, 3, 5) index = np.random.choice(np.arange(3), size=180, replace=True).reshape((3, 4, 3, 5)) np_out = np.take_along_axis(input, index, 1) output = flow.gather( flow.tensor(input, dtype=flow.float32, device=flow.device(device)), 1, flow.tensor(index, dtype=flow.int64, device=flow.device(device)), ) test_case.assertTrue(np.allclose(output.numpy(), np_out)) np_out2 = np.take_along_axis(input, index, 2) output2 = flow.gather( flow.tensor(input, dtype=flow.float32, device=flow.device(device)), 2, flow.tensor(index, dtype=flow.int64, device=flow.device(device)), ) test_case.assertTrue(np.allclose(output2.numpy(), np_out2)) np_out3 = np.take_along_axis(input, index, 3) output3 = flow.gather( flow.tensor(input, dtype=flow.float32, device=flow.device(device)), 3, flow.tensor(index, dtype=flow.int64, device=flow.device(device)), ) test_case.assertTrue(np.allclose(output3.numpy(), np_out3)) def _test_gather_backward(test_case, device): input = np.array([[1, 2], [3, 4]]) index = np.array([[0, 0], [1, 0]]) np_out = np.take_along_axis(input, index, 0) np_grad = _scatter_add_numpy(np.ones_like(np_out), 0, index, input.shape) of_input = flow.tensor( input, dtype=flow.float32, requires_grad=True, device=flow.device(device) ) output = flow.gather( of_input, 0, flow.tensor(index, dtype=flow.int64, device=flow.device(device)), ) out_sum = output.sum() out_sum.backward() test_case.assertTrue(np.array_equal(output.numpy(), np_out)) test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad)) def _test_gather_index_0dim_tensor(test_case, device): input = flow.ones(1).to(device) input.requires_grad = True index = flow.tensor(0).to(device) output = flow.gather(input, 0, index) test_case.assertTrue(np.array_equal(output.numpy(), 1.0)) output.sum().backward() test_case.assertTrue(np.array_equal(input.grad.numpy(), [1.0])) def _test_gather_input_index_0dim_tensor(test_case, device): input = flow.tensor(1.0).to(device) input.requires_grad = True index = flow.tensor(0).to(device) output = flow.gather(input, 0, index) 
test_case.assertTrue(np.array_equal(output.numpy(), 1.0)) output.sum().backward() test_case.assertTrue(np.array_equal(input.grad.numpy(), 1.0)) def _test_gather_input_0dim_tensor(test_case, device): input = flow.tensor(1.0).to(device) input.requires_grad = True index = flow.tensor([0]).to(device) output = flow.gather(input, 0, index) test_case.assertTrue(np.array_equal(output.numpy(), [1.0])) output.sum().backward() test_case.assertTrue(np.array_equal(input.grad.numpy(), 1.0)) @flow.unittest.skip_unless_1n1d() class TestGather(flow.unittest.TestCase): def test_gather(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_gather, _test_gather_tensor_function, _test_gather_random_array, _test_gather_backward, _test_gather_index_0dim_tensor, _test_gather_input_index_0dim_tensor, _test_gather_input_0dim_tensor, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_flow_gather_with_random_data(test_case): device = random_device() input = random_tensor(ndim=4, dim0=3, dim1=3, dim2=4, dim3=5).to(device) dim = random(-4, 4).to(int) index = random_tensor( ndim=4, dim1=random(1, 3).to(int), dim2=random(1, 4).to(int), dim3=random(1, 5).to(int), low=0, high=3, dtype=int, ).to(device) return torch.gather(input, dim, index) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_gather_bool_with_random_data(test_case): device = random_device() input = random_tensor(ndim=4, dim0=3, dim1=3, dim2=4, dim3=5).to( device=device, dtype=torch.bool ) dim = random(0, 4).to(int) index = random_tensor( ndim=4, dim1=random(1, 3).to(int), dim2=random(1, 4).to(int), dim3=random(1, 5).to(int), low=0, high=3, dtype=int, ).to(device) return torch.gather(input, dim, index) @profile(torch.gather) def profile_gather(test_case): t = torch.ones(1000, 1000) torch.gather(t, 1, torch.ones(1000, 1000, dtype=torch.int64)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_gather_nd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_gather_nd(test_case, device): input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) indices = np.array([[0], [2]]) np_out = np.array([[1, 2, 3], [7, 8, 9]]) output = flow.gather_nd( flow.tensor(input, dtype=flow.float, device=flow.device(device)), flow.tensor(indices, dtype=flow.int, device=flow.device(device)), ) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) def _test_gather_nd_t(test_case, device): input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) indices = np.array([[0, 2], [2, 1]]) np_out = np.array([3, 8]) output = flow.gather_nd( flow.tensor(input, dtype=flow.float, device=flow.device(device)), flow.tensor(indices, dtype=flow.int, device=flow.device(device)), ) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) def _test_gather_nd_backward(test_case, device): input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) indices = np.array([[0], [2]]) np_out = np.array([[1, 2, 3], [7, 8, 9]]) np_grad = np.array([[1, 1, 1], [0, 0, 0], [1, 1, 1]]) of_input = flow.tensor( input, requires_grad=True, dtype=flow.float, device=flow.device(device) ) output = flow.gather_nd( of_input, flow.tensor(indices, dtype=flow.int, device=flow.device(device)) ) out_sum = output.sum() out_sum.backward() test_case.assertTrue(np.array_equal(output.numpy(), np_out)) test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad)) def _test_gather_nd_backward_t(test_case, device): input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) indices = np.array([[0, 2], [2, 1]]) np_out = np.array([3, 8]) np_grad = np.array([[0, 0, 1], [0, 0, 0], [0, 1, 0]]) of_input = flow.tensor( input, requires_grad=True, dtype=flow.float, device=flow.device(device) ) output = flow.gather_nd( of_input, flow.tensor(indices, dtype=flow.int, device=flow.device(device)) ) out_sum = output.sum() out_sum.backward() test_case.assertTrue(np.array_equal(output.numpy(), np_out)) test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad)) @flow.unittest.skip_unless_1n1d() class TestGather_nd(flow.unittest.TestCase): def test_gather_nd(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_gather_nd, _test_gather_nd_t, _test_gather_nd_backward, _test_gather_nd_backward_t, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_gelu_approximate.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import math import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest import torch class NewGELUActivation(torch.nn.Module): """ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 """ def forward(self, input: torch.Tensor) -> torch.Tensor: return ( 0.5 * input * ( 1.0 + torch.tanh( math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)) ) ) ) def _test_gelu_approximate(test_case, device): torch_gelu = NewGELUActivation() x = np.random.randn(2, 4, 3) torch_x = torch.tensor(x, requires_grad=True, device=torch.device(device)) oneflow_x = flow.tensor(x, requires_grad=True, device=flow.device(device)) torch_y = torch_gelu(torch_x) oneflow_y = flow._C.gelu_with_approximate(oneflow_x, "tanh") test_case.assertTrue(np.allclose(torch_y.detach().cpu().numpy(), oneflow_y.numpy())) torch_y_sum = torch_y.sum() torch_y_sum.backward() oneflow_y_sum = oneflow_y.sum() oneflow_y_sum.backward() test_case.assertTrue( np.allclose(torch_x.grad.cpu().numpy(), oneflow_x.grad.numpy()) ) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_gelu_approximate(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_gelu_approximate] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_generator.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest class TestGenerator(flow.unittest.TestCase): def test_different_devices(test_case): auto_gen = flow.Generator(device="auto") cpu_gen = flow.Generator(device="cpu") test_case.assertTrue(auto_gen.initial_seed() == cpu_gen.initial_seed()) with test_case.assertRaises(RuntimeError) as context: flow.Generator(device="invalid") if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): cuda_gen = flow.Generator(device="cuda") test_case.assertTrue(auto_gen.initial_seed() == cuda_gen.initial_seed()) def test_generator_manual_seed(test_case): generator = flow.Generator() generator.manual_seed(1) test_case.assertTrue(generator.initial_seed() == 1) generator.manual_seed(2) test_case.assertTrue(generator.initial_seed() == 2) def test_generator_in_dropout(test_case): tgt = flow.ones(2000000) output = flow._C.dropout( tgt, p=0.1, training=True, generator=flow.Generator(), addend=None ) output.numpy() if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): output = flow._C.dropout( tgt.cuda(), 0.1, training=True, generator=flow.Generator(), addend=None ) output.numpy() class TestDefaultGenerator(flow.unittest.TestCase): def test_different_devices(test_case): auto_gen = flow.Generator(device="auto") cpu_gen = flow.default_generator with test_case.assertRaises(RuntimeError) as context: flow.Generator(device="invalid") flow.Generator(device="cpu:1000") if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): with test_case.assertRaises( oneflow._oneflow_internal.exception.Exception ) as context: flow.Generator(device="cuda:1000") cuda_gen = flow.Generator(device="cuda") cuda0_gen = flow.Generator(device="cuda:0") def test_generator_manual_seed(test_case): cpu_gen = flow.default_generator auto_gen = flow.Generator(device="auto") test_gens = [cpu_gen, auto_gen] if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): cuda_gen = flow.Generator(device="cuda") cuda0_gen = flow.Generator(device="cuda:0") test_gens += [cuda_gen, cuda0_gen] for seed in [1, 2]: for gen in test_gens: gen.manual_seed(seed) test_case.assertTrue(gen.initial_seed() == seed) def test_generator_seed(test_case): cpu_gen = flow.default_generator auto_gen = flow.Generator(device="auto") test_gens = [auto_gen, cpu_gen] if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): cuda_gen = flow.Generator(device="cuda") cuda0_gen = flow.Generator(device="cuda:0") test_gens += [cuda_gen, cuda0_gen] for gen in test_gens: seed = gen.seed() test_case.assertTrue(seed == gen.initial_seed()) def test_generator_getstate(test_case): auto_gen = flow.Generator(device="auto") state = auto_gen.get_state() cpu_gen = flow.Generator(device="cpu") state = cpu_gen.get_state() if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): cuda_gen = flow.Generator(device="cuda") state = cuda_gen.get_state() @unittest.skip("the curandstate is no longer used by normal kernel") def test_generator_setstate(test_case): cpu_gen = flow.default_generator flow.randn(100, 100, dtype=flow.float32, device="cpu", generator=cpu_gen) if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): cuda_gen = flow.Generator("cuda") flow.randn(100, 100, dtype=flow.float32, device="cuda", generator=cuda_gen) state = cpu_gen.get_state() flow.randn(100, 100, dtype=flow.float32, device="cpu", generator=cpu_gen) if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): cuda_state = cuda_gen.get_state() flow.randn(100, 100, dtype=flow.float32, device="cuda", generator=cuda_gen) new_state = cpu_gen.get_state() test_case.assertTrue(not np.allclose(new_state.numpy(), state.numpy())) cpu_gen.set_state(state) new_state = 
cpu_gen.get_state() test_case.assertTrue(np.allclose(new_state.numpy(), state.numpy())) if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): new_cuda_state = cuda_gen.get_state() test_case.assertTrue( not np.allclose(new_cuda_state.numpy(), cuda_state.numpy()) ) cuda_gen.set_state(cuda_state) new_cuda_state = cuda_gen.get_state() test_case.assertTrue( np.allclose(new_cuda_state.numpy(), cuda_state.numpy()) ) def test_get_rng_state(test_case): cpu_gen = flow.default_generator state = cpu_gen.get_state() rng_state = flow.get_rng_state() test_case.assertTrue(np.allclose(state.numpy(), rng_state.numpy())) flow.randn(100, 100, dtype=flow.float32, device="cpu", generator=cpu_gen) state = cpu_gen.get_state() rng_state = flow.get_rng_state() test_case.assertTrue(np.allclose(state.numpy(), rng_state.numpy())) def test_set_rng_state(test_case): flow.randn(100, 100) state = flow.get_rng_state() flow.randn(100, 100) new_state = flow.get_rng_state() test_case.assertTrue(not np.allclose(new_state.numpy(), state.numpy())) flow.set_rng_state(state) new_state = flow.get_rng_state() test_case.assertTrue(np.allclose(new_state.numpy(), state.numpy())) if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): flow.randn(100, 100).to("cuda") state = flow.cuda.get_rng_state() flow.randn(100, 100).to("cuda") new_state = flow.cuda.get_rng_state() test_case.assertTrue(np.allclose(new_state.numpy(), state.numpy())) states = flow.cuda.get_rng_state_all() before0 = flow.cuda.FloatTensor(100, device=0).normal_() before1 = flow.cuda.FloatTensor(100, device=1).normal_() flow.cuda.set_rng_state_all(states) after0 = flow.cuda.FloatTensor(100, device=0).normal_() after1 = flow.cuda.FloatTensor(100, device=1).normal_() test_case.assertTrue(np.allclose(before0.numpy(), after0.numpy())) test_case.assertTrue(np.allclose(before1.numpy(), after1.numpy())) # NOTE: according to https://github.com/Oneflow-Inc/oneflow/pull/9102#discussion_r973811389 # tensor init function fallback to `flow.default_generator.seed()`, and this test will be normal while tensor init functions reconstructed.(using op/kernel) # @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @unittest.skipIf(True, "tensor init functions need to be reconstructed!") def test_tensor_init(test_case): flow.manual_seed(0) x = flow.ones(2) x.uniform_() flow.manual_seed(0) y = flow.ones(2).to("cuda") y.uniform_() test_case.assertTrue(np.allclose(x.numpy(), y.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_0_dim_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_0_dim_tensor(test_case, placement, sbp): x1 = random_tensor(0).to_global(placement=placement, sbp=sbp) x2 = random_tensor(0).to_global(placement=placement, sbp=sbp) y1 = x1 * x2 y2 = x1 + x2 return y1 + y2 @autotest(n=1, check_graph=True) def _test_1dim_slice(test_case, placement, sbp): x = random_tensor(1, random(1, 4) * 8).to_global(placement=placement, sbp=sbp) return x[5] class TestZeroDimensionTensor(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") @globaltest def test_0_dim_tensor(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=0): _test_0_dim_tensor(test_case, placement, sbp) for sbp in all_sbp(placement, max_dim=1): _test_1dim_slice(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_TripletMarginLoss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=True) def _test_global_triplet_marginloss_with_random_data(test_case, placement, sbp): margin = random().to(float) p = random().to(float) swap = random_bool() reduction = oneof("none", "sum", "mean", nothing()) m = torch.nn.TripletMarginLoss(margin=margin, p=p, swap=swap, reduction=reduction) m.train(random()) anchor = random_tensor(2, 8, 16).to_global(placement, sbp) pos = random_tensor(2, 8, 16).to_global(placement, sbp) neg = random_tensor(2, 8, 16).to_global(placement, sbp) y = m(anchor, pos, neg) return y class TestGlobalTripletMarginLoss(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 4 times in past week") @globaltest def test_global_triplet_marginloss_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_triplet_marginloss_with_random_data( test_case, placement, sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_abs.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest @autotest(n=1, check_graph=True) def _test_abs_with_ndim_data(test_case, ndim, placement, sbp): dims = [random(1, 3) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = torch.abs(x) return y class TestAbsModule(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") @globaltest def test_abs_with_ndim_data(test_case): for placement in all_placement(): ndim = random(0, 4).to(int).value() for sbp in all_sbp(placement, max_dim=ndim): _test_abs_with_ndim_data(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_activation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest def build_module(act_type): if act_type == "relu": return torch.nn.ReLU() elif act_type == "relu6": return torch.nn.ReLU6() elif act_type == "tanh": return torch.nn.Tanh() elif act_type == "elu": return torch.nn.ELU(alpha=random()) elif act_type == "celu": return torch.nn.CELU(alpha=random()) elif act_type == "gelu": return torch.nn.GELU() elif act_type == "sigmoid": return torch.nn.Sigmoid() elif act_type == "hardsigmoid": return torch.nn.Hardsigmoid() elif act_type == "hardshrink": return torch.nn.Hardshrink(lambd=random()) elif act_type == "logsigmoid": return torch.nn.LogSigmoid() elif act_type == "hardswish": return torch.nn.Hardswish() elif act_type == "hardtanh": return torch.nn.Hardtanh( min_val=random().to(float), max_val=random().to(float), ) elif act_type == "leakyrelu": return torch.nn.LeakyReLU(negative_slope=random()) elif act_type == "mish": return torch.nn.Mish() elif act_type == "silu": return torch.nn.SiLU() elif act_type == "selu": return torch.nn.SELU() elif act_type == "threshold": return torch.nn.Threshold(threshold=random(), value=random()) elif act_type == "softplus": return torch.nn.Softplus() elif act_type == "softshrink": return torch.nn.Softshrink() else: raise ValueError("activation type %s is not support" % act_type) @autotest(n=1, check_graph=False) def _test_activation_module_with_random_data(test_case, act_type, ndim, placement, sbp): m = build_module(act_type) m.train(random()) dims = [random(1, 3) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = m(x) return y @autotest(n=1, check_graph=False) def _test_activation_module_with_0dim_data(test_case, act_type, placement, sbp): m = build_module(act_type) m.train(random()) x = random_tensor(ndim=0).to_global(placement=placement, sbp=sbp) y = m(x) return y @autotest(n=1, check_graph=False) def 
_test_activation_module_with_0_size_data( test_case, act_type, ndim, zerodim, placement, sbp ): m = build_module(act_type) m.train(random()) dims = [random(1, 3) * 8 for i in range(ndim)] dims[zerodim] = 0 x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = m(x) return y @globaltest def _test_activation_module(test_case, act_type): for placement in all_placement(): ndim = random(1, 4).to(int).value() for sbp in all_sbp(placement, max_dim=ndim): _test_activation_module_with_random_data( test_case, act_type, ndim, placement, sbp ) # Skip gelu 0 size test since "Floating point exception" maybe encountered in PyTorch. if act_type != "gelu": zerodim = random(0, ndim).to(int).value() valid_split_axis = [i for i in range(ndim) if i != zerodim] for sbp in all_sbp( placement, max_dim=ndim, valid_split_axis=valid_split_axis ): _test_activation_module_with_0_size_data( test_case, act_type, ndim, zerodim, placement, sbp ) for sbp in all_sbp(placement, max_dim=0): _test_activation_module_with_0dim_data(test_case, act_type, placement, sbp) class TestReLUModule(flow.unittest.TestCase): def test_relu_module(test_case): _test_activation_module(test_case, "relu") class TestReLU6Module(flow.unittest.TestCase): def test_relu6_module(test_case): _test_activation_module(test_case, "relu6") class TestTanh(flow.unittest.TestCase): def test_tanh_module(test_case): _test_activation_module(test_case, "tanh") class TestELUModule(flow.unittest.TestCase): def test_elu_module(test_case): _test_activation_module(test_case, "elu") class TestCELUModule(flow.unittest.TestCase): def test_celu_module(test_case): _test_activation_module(test_case, "celu") class TestGelu(flow.unittest.TestCase): def test_gelu_module(test_case): _test_activation_module(test_case, "gelu") class TestSigmoidModule(flow.unittest.TestCase): def test_sigmoid_module(test_case): _test_activation_module(test_case, "sigmoid") class TestHardsigmoidModule(flow.unittest.TestCase): def test_hardsigmoid_module(test_case): _test_activation_module(test_case, "hardsigmoid") class TestHardshrinkModule(flow.unittest.TestCase): def test_hardshrink_module(test_case): _test_activation_module(test_case, "hardshrink") class TestLogSigmoidModule(flow.unittest.TestCase): def test_logsigmoid_module(test_case): _test_activation_module(test_case, "logsigmoid") class TestHardswishModule(flow.unittest.TestCase): def test_hardswish_module(test_case): _test_activation_module(test_case, "hardswish") class TestHardtanhModule(flow.unittest.TestCase): def test_hardtanh_module(test_case): _test_activation_module(test_case, "hardtanh") class TestLeakyReLUModule(flow.unittest.TestCase): def test_leakyrelu_module(test_case): _test_activation_module(test_case, "leakyrelu") class TestMishModule(flow.unittest.TestCase): def test_mish_module(test_case): _test_activation_module(test_case, "mish") class TestSiluModule(flow.unittest.TestCase): def test_silu_module(test_case): _test_activation_module(test_case, "silu") class TestSeluModule(flow.unittest.TestCase): def test_selu_module(test_case): _test_activation_module(test_case, "selu") class TestThresholdModule(flow.unittest.TestCase): def test_threshold_module(test_case): _test_activation_module(test_case, "threshold") class TestSoftplusModule(flow.unittest.TestCase): def test_softplus_module(test_case): _test_activation_module(test_case, "softplus") class TestSoftshrinkModule(flow.unittest.TestCase): def test_softshrink_module(test_case): _test_activation_module(test_case, "softshrink") if __name__ == "__main__": 
unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_adaptive_pool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from packaging import version import unittest from typing import Union, Tuple import torch as torch_original import oneflow as flow import oneflow.unittest from oneflow.nn.common_types import _size_1_t from oneflow.test_utils.automated_test_util import * NoneType = type(None) # Not the same as those in PyTorch because 'output_size' cannot be NoneType (even in 'torch.nn.AdaptiveAvgPoolXd') _size_2_opt_t_not_none = Union[int, Tuple[Union[int, NoneType], Union[int, NoneType]]] _size_3_opt_t_not_none = Union[ int, Tuple[Union[int, NoneType], Union[int, NoneType], Union[int, NoneType]] ] @autotest(n=1, check_graph=True) def _test_adaptive_avgpoolnd(test_case, ndim, pool_size, placement, sbp): dims = [random(1, 3) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) if pool_size == 1: m = torch.nn.AdaptiveAvgPool1d(output_size=random().to(_size_1_t)) elif pool_size == 2: m = torch.nn.AdaptiveAvgPool2d(output_size=random().to(_size_2_opt_t_not_none)) elif pool_size == 3: m = torch.nn.AdaptiveAvgPool3d(output_size=random().to(_size_3_opt_t_not_none)) else: raise ValueError("pool size should be 1, 2 or 3, but got %d" % pool_size) m.train(random()) y = m(x) return y @autotest(n=1, check_graph=True) def _test_adaptive_avgpoolnd_functional(test_case, ndim, pool_size, placement, sbp): dims = [random(1, 3) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) if pool_size == 1: return torch.nn.functional.adaptive_avg_pool1d(x, output_size=random().to(int)) elif pool_size == 2: return torch.nn.functional.adaptive_avg_pool2d(x, output_size=random().to(int)) elif pool_size == 3: return torch.nn.functional.adaptive_avg_pool3d(x, output_size=random().to(int)) class TestAdaptiveAvgPool(flow.unittest.TestCase): @globaltest def test_adaptive_avgpool(test_case): for placement in all_placement(): ndim = 3 for sbp in all_sbp(placement, max_dim=2): _test_adaptive_avgpoolnd(test_case, ndim, 1, placement, sbp) _test_adaptive_avgpoolnd_functional(test_case, ndim, 1, placement, sbp) ndim = 4 for sbp in all_sbp(placement, max_dim=2): _test_adaptive_avgpoolnd(test_case, ndim, 2, placement, sbp) _test_adaptive_avgpoolnd_functional(test_case, ndim, 2, placement, sbp) # GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0' if ( version.parse(torch_original.__version__) < version.parse("1.10.0") and placement.type == "cuda" ): continue ndim = 5 for sbp in all_sbp(placement, max_dim=2): _test_adaptive_avgpoolnd(test_case, ndim, 3, placement, sbp) _test_adaptive_avgpoolnd_functional(test_case, ndim, 3, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_add.py 
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@autotest(n=1, check_graph=True)
def _test_add_with_alpha(test_case, ndim, placement, sbp):
    dims = [random(1, 4) * 8 for i in range(ndim)]
    x1 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp).mean()
    x2 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp).mean()
    x3 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp).mean()
    y = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)
    s = random().to(float)
    alpha = random().to(float)
    z1 = torch.add(x1, y, alpha=alpha)
    z2 = torch.add(x2, s, alpha=alpha)
    z3 = torch.add(s, x3, alpha=alpha)
    return z1, z2, z3


@autotest(n=1, check_graph=True)
def _test_add_with_0size(test_case, ndim, zerodim, placement, sbp):
    dims = [random(1, 4) * 8 for i in range(ndim)]
    dims[zerodim] = 1
    x1 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)
    dims[zerodim] = 0
    x2 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)
    return torch.add(x1, x2)


class TestAddModule(flow.unittest.TestCase):
    @unittest.skip("skip for now, because it failed 2 times in the past week")
    @globaltest
    def test_add_with_alpha(test_case):
        ndim = random(1, 4).to(int).value()
        for placement in all_placement():
            for sbp in all_sbp(placement, max_dim=ndim):
                _test_add_with_alpha(test_case, ndim, placement, sbp)
            zerodim = random(0, ndim).to(int).value()
            valid_split_axis = [i for i in range(ndim) if i != zerodim]
            for sbp in all_sbp(
                placement, max_dim=ndim, valid_split_axis=valid_split_axis
            ):
                _test_add_with_0size(test_case, ndim, zerodim, placement, sbp)


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_global_addcdiv.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_addcdiv(test_case, ndim, placement, sbp): shape = [random(2, 4) * 8 for i in range(ndim)] input = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp) tensor1 = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp) tensor2 = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp) value = random(2, 4).to(int) output = torch.addcdiv(input, tensor1, tensor2, value=value) return output class TestModule(flow.unittest.TestCase): @globaltest def test_addcdiv(test_case): ndim = random(2, 4).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_addcdiv(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_addcmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=True) def _test_addcmul(test_case, ndim, placement, sbp): shape = [random(low=2, high=3) * 8 for i in range(ndim)] input = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp) tensor1 = random_tensor(len(shape), *shape).to_global(placement=placement, sbp=sbp) tensor2 = random_tensor(len(shape), *shape).to_global(placement=placement, sbp=sbp) value = random(3, 6).to(int) output = torch.addcmul(input, tensor1, tensor2, value=value) return output class TestModule(flow.unittest.TestCase): @globaltest def test_addcmul(test_case): ndim = random(low=2, high=5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_addcmul(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_addmm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_addmm_with_random_data(test_case, placement, sbp): m = random(1, 3) * 8 n = random(1, 3) * 8 k = random(1, 3) * 8 input = random_tensor(ndim=2, dim0=m, dim1=n).to_global( placement=placement, sbp=sbp ) mat1 = random_tensor(ndim=2, dim0=m, dim1=k).to_global(placement=placement, sbp=sbp) mat2 = random_tensor(ndim=2, dim0=k, dim1=n).to_global(placement=placement, sbp=sbp) y = torch.addmm( input, mat1, mat2, beta=random().to(float), alpha=random().to(float), ) return y @autotest(n=1, check_graph=True) def _test_addmm_broadcast_with_random_data(test_case, placement, sbp): m = random(1, 3) * 8 n = random(1, 3) * 8 k = random(1, 3) * 8 input = random_tensor(ndim=2, dim0=1, dim1=1).to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(sbp))] ) mat1 = random_tensor(ndim=2, dim0=m, dim1=k).to_global(placement=placement, sbp=sbp) mat2 = random_tensor(ndim=2, dim0=k, dim1=n).to_global(placement=placement, sbp=sbp) y = torch.addmm( input, mat1, mat2, beta=random().to(float), alpha=random().to(float), ) return y class TestAddmm(flow.unittest.TestCase): @globaltest def test_addmm(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_addmm_with_random_data(test_case, placement, sbp) _test_addmm_broadcast_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_affine_grid.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, rtol=1e-03, atol=1e-04, check_graph=True) def _test_affine_grid_2d_with_random_data(test_case, placement, sbp): N = random(1, 3).to(int).value() * 8 C = random(1, 8).to(int).value() H = random(1, 8).to(int).value() W = random(1, 8).to(int).value() align_corners = oneof(True, False).value() dims = [N, 2, 3] theta = random_tensor(3, *dims).to_global(placement=placement, sbp=sbp) output = torch.nn.functional.affine_grid( theta, (N, C, H, W), align_corners=align_corners ) return output @autotest(n=1, rtol=1e-03, atol=1e-04, check_graph=True) def _test_affine_grid_3d_with_random_data(test_case, placement, sbp): N = random(1, 3).to(int) * 8 C = random(1, 8).to(int) D = random(1, 8).to(int) H = random(1, 8).to(int) W = random(1, 8).to(int) align_corners = oneof(True, False) dims = [N, 3, 4] theta = random_tensor(3, *dims).to_global(placement=placement, sbp=sbp) output = torch.nn.functional.affine_grid( theta, (N, C, D, H, W), align_corners=align_corners ) return output class TestAffineGrid(flow.unittest.TestCase): @globaltest def test_affine_grid(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_affine_grid_2d_with_random_data(test_case, placement, sbp) _test_affine_grid_3d_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_argmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=True) def _test_argmax_with_random_data(test_case, ndim, placement, sbp): dims = [8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = torch.argmax(x, dim=random(0, ndim).to(int), keepdim=random().to(bool)) return y @unittest.skip("TODO: sometimes global TestArgmax fails on 2-GPU runs") class TestArgmax(flow.unittest.TestCase): @globaltest def test_argmax(test_case): ndim = random(1, 3).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_argmax_with_random_data(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_argmin.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@autotest(n=1, auto_backward=False, check_graph=True)
def _test_argmin_with_random_data(test_case, ndim, placement, sbp):
    dims = [random(1, 3) * 8 for _ in range(ndim)]
    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)
    y = torch.argmin(x, dim=random(0, ndim).to(int), keepdim=random().to(bool))
    return y


@unittest.skip("TODO: sometimes global TestArgmin fails on 2-GPU runs")
class TestArgmin(flow.unittest.TestCase):
    @globaltest
    def test_argmin(test_case):
        ndim = random(1, 5).to(int).value()
        for placement in all_placement():
            for sbp in all_sbp(placement, max_dim=ndim):
                _test_argmin_with_random_data(test_case, ndim, placement, sbp)


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_global_argsort.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@autotest(n=1, auto_backward=False, check_graph=True)
def _test_argsort_with_random_data(test_case, ndim, placement, sbp):
    dims = [random(1, 3) * 8 for _ in range(ndim)]
    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)
    y = torch.argsort(
        x, dim=random(low=-ndim, high=ndim).to(int), descending=random_bool()
    )
    return y


@unittest.skip("argsort has a bug that has not been tracked down yet.")
class TestArgsort(flow.unittest.TestCase):
    @globaltest
    def test_argsort(test_case):
        ndim = random(1, 5).to(int).value()
        for placement in all_placement():
            for sbp in all_sbp(placement, max_dim=ndim):
                _test_argsort_with_random_data(test_case, ndim, placement, sbp)


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_global_argwhere.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import unittest import torch as torch_ori import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=True) def _test_argwhere_with_random_data(test_case, ndim, placement, sbp): dims = [random(1, 3) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) # PyTorch has no argwhere before v1.11, so we use nonzero instead of argwhere for PyTorch # y = torch.argwhere(x) y = x.clone() y.oneflow = flow.argwhere(x.oneflow) y.pytorch = torch_ori.nonzero(x.pytorch) return y class TestArgwhere(flow.unittest.TestCase): @globaltest def test_argwhere(test_case): ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_argwhere_with_random_data(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_atleast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_atleast1d_with_random_data(test_case, placement, sbp): x = random_tensor(ndim=1, dim0=8).to_global(placement, sbp) y = random_tensor(ndim=2, dim0=8).to_global(placement, sbp) out = torch.atleast_1d([x, y]) return out @autotest(n=1, check_graph=True) def _test_atleast2d_with_random_data(test_case, placement, sbp): x = random_tensor(ndim=1, dim0=8).to_global(placement, sbp) y = random_tensor(ndim=2, dim0=8).to_global(placement, sbp) z = random_tensor(ndim=3, dim0=8).to_global(placement, sbp) out = torch.atleast_2d([x, y, z]) return out @autotest(n=1, check_graph=True) def _test_atleast3d_with_random_data(test_case, placement, sbp): x = random_tensor(ndim=1, dim0=8).to_global(placement, sbp) y = random_tensor(ndim=2, dim0=8).to_global(placement, sbp) z = random_tensor(ndim=3, dim0=8).to_global(placement, sbp) p = random_tensor(ndim=4, dim0=8).to_global(placement, sbp) out = torch.atleast_3d([x, y, z, p]) return out class TestAtLeastModule(flow.unittest.TestCase): @globaltest def test_atleast1d_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_atleast1d_with_random_data(test_case, placement, sbp) @globaltest def test_atleast2d_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_atleast2d_with_random_data(test_case, placement, sbp) @globaltest def test_atleast3d_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_atleast3d_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_avgpool.py 
================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_avgpool1d_with_random_data(test_case, placement, sbp): m = torch.nn.AvgPool1d( kernel_size=random(4, 6), stride=random(1, 3), padding=random(1, 3), ceil_mode=random(), count_include_pad=random(), ) m.train(random()) m.to_global(placement=placement, sbp=sbp) ndim = 3 dims = [random(1, 3) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = m(x) return y @autotest(n=1, check_graph=True) def _test_avgpool2d_with_random_data(test_case, placement, sbp): m = torch.nn.AvgPool2d( kernel_size=random(4, 6), stride=random(1, 3), padding=random(1, 3), ceil_mode=random(), count_include_pad=random(), divisor_override=random().to(int), ) m.train(random()) m.to_global(placement=placement, sbp=sbp) ndim = 4 dims = [random(1, 3) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = m(x) return y @autotest(n=1, check_graph=True) def _test_avgpool3d_with_random_data(test_case, placement, sbp): m = torch.nn.AvgPool3d( kernel_size=random(4, 6), stride=random(1, 3), padding=random(1, 3), ceil_mode=random(), count_include_pad=random(), divisor_override=random().to(int), ) m.train(random()) m.to_global(placement=placement, sbp=sbp) ndim = 5 dims = [random(1, 3) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = m(x) return y @autotest(n=1, check_graph=True) def _test_functional_avgpool1d_with_random_data(test_case, placement, sbp): ndim = 3 dims = [random(1, 3) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = torch.nn.functional.avg_pool1d( x, kernel_size=random(1, 6).to(int), stride=random(1, 3).to(int), padding=random(1, 3).to(int), ceil_mode=random_bool(), count_include_pad=random_bool(), ) return y @autotest(n=1, check_graph=True) def _test_functional_avgpool2d_with_random_data(test_case, placement, sbp): ndim = 4 dims = [random(1, 3) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = torch.nn.functional.avg_pool2d( x, kernel_size=random(1, 6).to(int), stride=random(1, 3).to(int), padding=random(1, 3).to(int), ceil_mode=random_bool(), count_include_pad=random_bool(), ) return y @autotest(n=1, check_graph=True) def _test_functional_avgpool3d_with_random_data(test_case, placement, sbp): ndim = 5 dims = [random(1, 3) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = torch.nn.functional.avg_pool3d( x, kernel_size=random(1, 6).to(int), stride=random(1, 3).to(int), padding=random(1, 3).to(int), ceil_mode=random_bool(), count_include_pad=random_bool(), ) return y class TestAvgPoolingModule(flow.unittest.TestCase): @globaltest def 
test_avg_pooling(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_avgpool1d_with_random_data(test_case, placement, sbp) _test_functional_avgpool1d_with_random_data(test_case, placement, sbp) for sbp in all_sbp(placement, max_dim=2): _test_avgpool2d_with_random_data(test_case, placement, sbp) _test_functional_avgpool2d_with_random_data(test_case, placement, sbp) for sbp in all_sbp(placement, max_dim=2): _test_avgpool3d_with_random_data(test_case, placement, sbp) _test_functional_avgpool3d_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_batch_gather.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.automated_test_util.util import broadcast def _test_batch_gather(test_case, ndim, placement, sbp): dims = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims, requires_grad=True) local_x = flow.tensor(x.pytorch.detach().cpu().numpy(), requires_grad=True) global_x = x.oneflow.to_global(placement=placement, sbp=sbp) global_x.retain_grad() indices_ndim = random(1, ndim + 1).to(int).value() indices_dims = [dims[i] for i in range(indices_ndim)] indices_dims[-1] = random(1, dims[indices_ndim - 1]).to(int).value() indices = np.random.choice(dims[indices_ndim - 1], indices_dims) indices = broadcast(indices) local_indices = flow.tensor(indices) global_indices = local_indices.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(sbp))] ) global_out = flow.batch_gather(global_x, global_indices) global_out.sum().backward() local_out = flow.batch_gather(local_x, local_indices) local_out.sum().backward() test_case.assertTrue( np.allclose( global_x.grad.detach().cpu().numpy(), local_x.grad.detach().cpu().numpy(), atol=1e-5, rtol=1e-5, ) ) class TestBatchGather(flow.unittest.TestCase): @globaltest def test_batch_gather(test_case): ndim = 2 for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_batch_gather(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_bincount.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False, auto_backward=False) def _test_bincount(test_case, placement, sbp): x = random_tensor(1, 64, low=0, dtype=int).to_global(placement=placement, sbp=sbp) weight = random_tensor(1, 64).to_global(placement=placement, sbp=sbp) minlength = random(1, 100).to(int) return torch.bincount(x, weight, minlength) class TestBinCountModule(flow.unittest.TestCase): @globaltest def test_bincount(test_case): for placement in all_placement(): for sbp in all_sbp(placement, valid_split_axis=0): _test_bincount(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_bitwise.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False) def _test_bitwise_ops_with_random_data(test_case, op, placement, sbp): x = random_tensor(ndim=1, dim0=8, dtype=int).to_global(placement, sbp) y = random_tensor(ndim=1, dim0=8, dtype=int).to_global(placement, sbp) out = op(x, y) return out @autotest(n=1, auto_backward=False) def _test_bitwise_not_with_random_data(test_case, placement, sbp): x = random_tensor(ndim=1, dim0=8, dtype=int).to_global(placement, sbp) return torch.bitwise_not(x) class TestBitwiseModule(flow.unittest.TestCase): @globaltest def test_bitwise_and_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_bitwise_ops_with_random_data( test_case, torch.bitwise_and, placement, sbp ) _test_bitwise_ops_with_random_data( test_case, torch.bitwise_or, placement, sbp ) _test_bitwise_ops_with_random_data( test_case, torch.bitwise_xor, placement, sbp ) _test_bitwise_not_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_broadcase_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_broadcast_like(test_case, placement, sbp): like_shape = [8] * 4 like = random_tensor(4, *like_shape).to_global( placement, random_sbp(placement, max_dim=4) ) x = random_tensor(2, *(8, 8)).to_global(placement, sbp) # oneflow of_y = flow.broadcast_like(x.oneflow, like.oneflow) # pytorch torch_y = x.pytorch.broadcast_to(like_shape) test_case.assertTrue(np.allclose(of_y.numpy(), torch_y.detach().cpu().numpy())) def _test_broadcast_like_expand_dims(test_case, placement, sbp): like_shape = [8] * 4 like = random_tensor(4, *like_shape).to_global( placement, random_sbp(placement, max_dim=4) ) x = random_tensor(2, *(8, 8)).to_global(placement, sbp) # oneflow of_y = flow.broadcast_like(x.oneflow, like.oneflow, [1, 3]) # pytorch torch_y = x.pytorch.view(8, 1, 8, 1).broadcast_to(like_shape) test_case.assertTrue(np.allclose(of_y.numpy(), torch_y.detach().cpu().numpy())) class TestGlobalBroadcaseLike(flow.unittest.TestCase): @globaltest def test_broadcase_like(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_broadcast_like(test_case, placement, sbp) _test_broadcast_like_expand_dims(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_broadcast_matmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_flow_tensor_global_broadcast_matmul_with_random_data( test_case, placement, x_sbp, y_sbp ): batch_dim = random(1, 6) * 8 k = random(1, 6) * 4 x = random_tensor(ndim=3, dim0=batch_dim, dim2=k).to_global( placement=placement, sbp=x_sbp ) y = random_tensor(ndim=2, dim0=k).to_global(placement=placement, sbp=y_sbp) return x.matmul(y) @autotest(n=1, check_graph=True) def _test_flow_tensor_global_x_broadcast_y_matmul(test_case, placement, x_sbp, y_sbp): batch_dim = random(1, 6) * 8 k = random(1, 6) * 4 x = random_tensor(ndim=2, dim1=k).to_global(placement=placement, sbp=x_sbp) y = random_tensor(ndim=3, dim0=batch_dim, dim1=k).to_global( placement=placement, sbp=y_sbp ) return x.matmul(y) @autotest(n=1, check_graph=True, rtol=1e-3, atol=1e-4) def _test_flow_tensor_global_broadcast_matmul_with_same_dims( test_case, placement, x_sbp, y_sbp ): k = random(1, 6) * 8 batch_dim = random(1, 6) * 8 x = random_tensor(ndim=3, dim0=batch_dim, dim1=4, dim2=k).to_global( placement=placement, sbp=x_sbp ) y = random_tensor(ndim=3, dim0=batch_dim, dim1=k, dim2=4).to_global( placement=placement, sbp=y_sbp ) return x.matmul(y) class TestGlobalBroadcastMatmulModule(flow.unittest.TestCase): @globaltest def test_global_broadcast_matmul_with_random_data(test_case): for placement in all_placement(): for x_sbp in all_sbp(placement, max_dim=2, valid_split_axis=[0]): for y_sbp in all_sbp(placement, max_dim=2, except_split=True): _test_flow_tensor_global_broadcast_matmul_with_random_data( test_case, placement, x_sbp, y_sbp ) @globaltest def test_global_x_broadcast_y_matmul(test_case): for placement in all_placement(): for x_sbp in all_sbp(placement, max_dim=2, except_split=True): for y_sbp in all_sbp(placement, max_dim=2, valid_split_axis=[0]): _test_flow_tensor_global_x_broadcast_y_matmul( test_case, placement, x_sbp, y_sbp ) @globaltest def test_global_broadcast_matmul_with_same_dims(test_case): for placement in all_placement(): for x_sbp in all_sbp(placement, max_dim=2): for y_sbp in all_sbp(placement, max_dim=2): _test_flow_tensor_global_broadcast_matmul_with_same_dims( test_case, placement, x_sbp, y_sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_broadcast_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1) def _test_global_broadcast_tensors( test_case, input_shape, other_shape, placement, x_sbp, y_sbp ): x = random_tensor(len(input_shape), *input_shape).to_global( placement=placement, sbp=x_sbp ) y = random_tensor(len(other_shape), *other_shape).to_global( placement=placement, sbp=y_sbp ) return torch.broadcast_tensors(x, y) class TestGlobalBroadcastOps(flow.unittest.TestCase): # flow.broadcast_shapes's input are shapes, so it can't be tested in global mode # flow.broadcast_to is an alias of flow.expand, so its global tests are same as flow.expand's @globaltest def test_global_tensors(test_case): shapes = [((2, 2), (2, 2, 2)), ((1, 2), (3, 1))] for input_shape, other_shape in shapes: for placement in all_placement(): for x_sbp in all_sbp( placement, max_dim=2, valid_split_axis=[x for x in input_shape if x != 1], ): for y_sbp in all_sbp( placement, max_dim=2, valid_split_axis=[y for y in other_shape if y != 1], ): _test_global_broadcast_tensors( test_case, input_shape, other_shape, placement, x_sbp, y_sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_cast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import os import numpy as np import oneflow as flow from oneflow import nn import oneflow.unittest from oneflow.test_utils.test_util import GenArgList from oneflow import Tensor from oneflow.framework.args_tree import ArgsTree @flow.unittest.skip_unless_1n4d() class TestGlobalCastModule_1n4d(flow.unittest.TestCase): def test_to_global_flatten_hierarchy(test_case): x = flow.ones((4, 4), dtype=flow.int32) sbp = (flow.sbp.partial_sum,) y = x.to_global( placement=flow.placement("cpu", ranks=[[0, 1], [2, 3]]), sbp=(flow.sbp.partial_sum, flow.sbp.partial_sum), ) placement = flow.placement("cpu", ranks=[0, 1, 2, 3]) y = y.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_to_global_flatten_hierarchy_cpu_to_gpu(test_case): x = flow.ones((4, 4), dtype=flow.int32) sbp = (flow.sbp.partial_sum,) y = x.to_global( placement=flow.placement("cpu", ranks=[[0, 1], [2, 3]]), sbp=(flow.sbp.partial_sum, flow.sbp.partial_sum), ) placement = flow.placement("cuda", ranks=[0, 1, 2, 3]) y = y.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_to_global_flatten_hierarchy_gpu_to_cpu(test_case): x = flow.ones((4, 4), dtype=flow.int32) sbp = (flow.sbp.partial_sum,) y = x.to_global( placement=flow.placement("cuda", ranks=[[0, 1], [2, 3]]), sbp=(flow.sbp.partial_sum, flow.sbp.partial_sum), ) placement = flow.placement("cpu", ranks=[0, 1, 2, 3]) y = y.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) def test_to_global_broadcast_shape_dtype(test_case): if int(os.getenv("RANK")) < 2: x = flow.ones((4, 4), dtype=flow.int32) else: x = flow.zeros((1,), dtype=flow.float) placement = flow.placement("cpu", ranks=[0, 1]) sbp = (flow.sbp.split(0),) y = x.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_local_to_global_2d_sbp(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) sbp = (flow.sbp.split(0), flow.sbp.partial_sum) y = x.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_local_to_global_sp_2_bb(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) sbp = (flow.sbp.split(0), flow.sbp.partial_sum) y = x.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) y = y.to_global(sbp=(flow.sbp.broadcast, flow.sbp.broadcast)) test_case.assertEqual(y.sbp, (flow.sbp.broadcast, 
flow.sbp.broadcast)) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue( np.array_equal(z.numpy(), np.ones((8, 4), dtype=np.int32) * 2) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_local_to_global_ps0_2_s0s0(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) x = x * int(os.getenv("RANK")) placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) sbp = (flow.sbp.partial_sum, flow.sbp.split(0)) y = x.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) sbp = (flow.sbp.split(0), flow.sbp.split(0)) y = y.to_global(sbp=sbp) z = y.to_local() if int(os.getenv("RANK")) < 2: scale = 2 else: scale = 4 test_case.assertTrue( np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32) * scale) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_local_to_global_s0p_2_s0s0(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) x = x * int(os.getenv("RANK")) placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) sbp = (flow.sbp.split(0), flow.sbp.partial_sum) y = x.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) sbp = (flow.sbp.split(0), flow.sbp.split(0)) y = y.to_global(sbp=sbp) z = y.to_local() if int(os.getenv("RANK")) < 2: scale = 1 else: scale = 5 test_case.assertTrue( np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32) * scale) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_to_global_loop_broadcast_shape_dtype(test_case): if int(os.getenv("RANK")) < 2: x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) a = flow.ones((4, 4), device=flow.device("cpu"), dtype=flow.int32) else: x = flow.zeros((1,), dtype=flow.float) a = flow.zeros((4, 4), device=flow.device("cpu"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[0, 1]) sbp = (flow.sbp.split(0),) for i in range(1000): if i % 100 == 0: print(i) y = x.to_global(placement=placement, sbp=sbp) b = a.to_global(placement=placement, sbp=flow.sbp.broadcast) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) @flow.unittest.skip_unless_1n2d() class TestGlobalCastModule_1n2d(flow.unittest.TestCase): def test_to_global_broadcast_shape_dtype(test_case): if os.getenv("RANK") == "0": x = flow.ones((4, 4), dtype=flow.int32) else: x = flow.zeros((1,), dtype=flow.float) placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) y = x.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) test_case.assertEqual(y.dtype, flow.int32) def test_local_to_global_broadcast_data(test_case): if int(os.getenv("RANK")) == 0: x = flow.ones((4, 4), dtype=flow.int32) else: x = flow.zeros((4, 4), dtype=flow.int32) placement = flow.placement("cpu", ranks=[0, 1]) sbp = (flow.sbp.broadcast,) y = x.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) 
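# Note on the check below: when a local tensor is cast to a global tensor with
# sbp=broadcast, the data on the first rank of the placement (rank 0, which holds
# ones here) is replicated to the other ranks, so to_local() yields ones on every
# rank even though rank 1 started from zeros.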
test_case.assertEqual(tuple(y.shape), (4, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue(np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32))) def test_cuda_global_to_global_cpu_s2b(test_case): x = flow.ones((4, 4), device=flow.device("cpu"), dtype=flow.int32) placement = flow.placement("cpu", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.split(0)) sbp = (flow.sbp.broadcast,) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue(np.array_equal(z.numpy(), np.ones((8, 4), dtype=np.int32))) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_global_to_global_s2b(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.split(0)) sbp = (flow.sbp.broadcast,) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue(np.array_equal(z.numpy(), np.ones((8, 4), dtype=np.int32))) def test_cuda_global_to_global_cpu_s2p(test_case): x = flow.ones((4, 4), device=flow.device("cpu"), dtype=flow.int32) placement = flow.placement("cpu", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.split(0)) sbp = (flow.sbp.partial_sum,) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() if int(os.getenv("RANK")) == 0: test_case.assertTrue( np.array_equal( z.numpy(), np.concatenate( ( np.ones((4, 4), dtype=np.int32), np.zeros((4, 4), dtype=np.int32), ), axis=0, ), ) ) else: test_case.assertTrue( np.array_equal( z.numpy(), np.concatenate( ( np.zeros((4, 4), dtype=np.int32), np.ones((4, 4), dtype=np.int32), ), axis=0, ), ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_global_to_global_s2p(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.split(0)) sbp = (flow.sbp.partial_sum,) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (8, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() if int(os.getenv("RANK")) == 0: test_case.assertTrue( np.array_equal( z.numpy(), np.concatenate( ( np.ones((4, 4), dtype=np.int32), np.zeros((4, 4), dtype=np.int32), ), axis=0, ), ) ) else: test_case.assertTrue( np.array_equal( z.numpy(), np.concatenate( ( np.zeros((4, 4), dtype=np.int32), np.ones((4, 4), dtype=np.int32), ), axis=0, ), ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_global_to_global_b2p(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.broadcast) sbp = (flow.sbp.partial_sum,) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) 
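# Note on the checks below: recasting broadcast -> partial_sum must keep the global
# value (the element-wise sum of all per-rank components) unchanged, so one rank
# keeps the full data while the other rank gets zeros.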
test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() if int(os.getenv("RANK")) == 0: test_case.assertTrue( np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32)) ) else: test_case.assertTrue( np.array_equal(z.numpy(), np.zeros((4, 4), dtype=np.int32)) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_global_to_global_b2s(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.broadcast) sbp = (flow.sbp.split(0),) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue(np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32))) def test_cuda_global_to_global_cpu_p2s(test_case): x = flow.ones((4, 4), device=flow.device("cpu"), dtype=flow.int32) placement = flow.placement("cpu", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum) sbp = (flow.sbp.split(0),) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue( np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32) * 2) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_global_to_global_p2s(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum) sbp = (flow.sbp.split(0),) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue( np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32) * 2) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_global_to_global_cuda_h2d(test_case): x = flow.ones((4, 4), device=flow.device("cpu"), dtype=flow.int32) placement = flow.placement("cpu", ranks=[0, 1]) cuda_placement = flow.placement("cuda", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum) y = y.to_global(placement=cuda_placement, sbp=flow.sbp.partial_sum) test_case.assertEqual(y.sbp, (flow.sbp.partial_sum,)) test_case.assertEqual(y.placement, cuda_placement) test_case.assertEqual(tuple(y.shape), (4, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue(np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32))) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_global_to_global_cpu_p2b(test_case): x = flow.ones((4, 4), device=flow.device("cpu"), dtype=flow.int32) placement = flow.placement("cpu", ranks=[0, 1]) cuda_placement = flow.placement("cuda", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum) import time y = y.to_global(placement=cuda_placement, sbp=flow.sbp.partial_sum) sbp = (flow.sbp.broadcast,) y = y.to_global(placement=cuda_placement, sbp=sbp) y = y.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() 
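# Note on the check below: under partial_sum each of the two ranks contributes a
# local all-ones component, so recasting to broadcast materializes their sum and
# every rank ends up with a local tensor filled with 2s.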
test_case.assertTrue( np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32) * 2) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cuda_global_to_global_p2b(test_case): x = flow.ones((4, 4), device=flow.device("cuda"), dtype=flow.int32) placement = flow.placement("cuda", ranks=[0, 1]) y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum) sbp = (flow.sbp.broadcast,) y = y.to_global(sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) test_case.assertEqual(y.dtype, flow.int32) z = y.to_local() test_case.assertTrue( np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32) * 2) ) @flow.unittest.skip_unless_1n1d() class TestGlobalCastModule_1n1d(flow.unittest.TestCase): def test_to_global(test_case): x = flow.ones((4, 4)) placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) y = x.to_global(placement=placement, sbp=sbp) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) test_case.assertEqual(tuple(y.shape), (4, 4)) def _test_cpu_p2b_with_random_parameter(test_case, device_list): gen_float = np.random.random gen_int = np.random.randint dtype_list = [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, flow.double, ] def choose_shape_and_dtype(seed): rng = np.random.default_rng(seed) kdtype = rng.integers(low=1, high=len(dtype_list), size=1) ndim = rng.integers(low=1, high=4, size=1) shape = rng.integers(low=1, high=10, size=ndim) return kdtype, shape for _ in range(10): seed = flow.tensor(gen_int(1, 1000, 1)) seed = seed.to_global( placement=flow.placement.all(seed.device.type), sbp=flow.sbp.broadcast, ) seed = int(seed.to_local().numpy()) kdtype, shape = choose_shape_and_dtype(seed) if kdtype <= 3: np_arr = gen_int(1, 10, shape) else: np_arr = gen_float(shape) tensor = flow.tensor(np_arr, device="cpu", dtype=dtype_list[int(kdtype)]) cpu_tensor = tensor.to_global( placement=flow.placement("cpu", device_list), sbp=flow.sbp.partial_sum ) cpu_tensor = cpu_tensor.to_global(sbp=flow.sbp.broadcast) tensor = tensor.to("cuda") cuda_tensor = tensor.to_global( placement=flow.placement("cuda", device_list), sbp=flow.sbp.partial_sum ) cuda_tensor = cuda_tensor.to_global(sbp=flow.sbp.broadcast) test_case.assertTrue( np.allclose(cpu_tensor.to_local().numpy(), cuda_tensor.to_local().numpy()) ) def _test_cpu_s2b_with_random_parameter(test_case, device_list): gen_float = np.random.random gen_int = np.random.randint dtype_list = [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, flow.double, ] def choose_shape_and_dtype(seed): rng = np.random.default_rng(seed) kdtype = rng.integers(low=1, high=len(dtype_list), size=1) ndim = rng.integers(low=1, high=4, size=1) shape = rng.integers(low=1, high=10, size=ndim) return kdtype, shape for _ in range(10): seed = flow.tensor(gen_int(1, 1000, 1)) seed = seed.to_global( placement=flow.placement.all(seed.device.type), sbp=flow.sbp.broadcast, ) seed = int(seed.to_local().numpy()) kdtype, shape = choose_shape_and_dtype(seed) if kdtype <= 3: np_arr = gen_int(1, 10, shape) else: np_arr = gen_float(shape) tensor = flow.tensor(np_arr, device="cpu", dtype=dtype_list[int(kdtype)]) cpu_tensor = tensor.to_global( placement=flow.placement("cpu", device_list), sbp=flow.sbp.split(0) ) cpu_tensor = cpu_tensor.to_global(sbp=flow.sbp.broadcast) tensor = tensor.to("cuda") cuda_tensor = tensor.to_global( placement=flow.placement("cuda", device_list), 
sbp=flow.sbp.split(0) ) cuda_tensor = cuda_tensor.to_global(sbp=flow.sbp.broadcast) test_case.assertTrue( np.allclose(cpu_tensor.to_local().numpy(), cuda_tensor.to_local().numpy()) ) def _test_cpu_p2s_with_random_parameter(test_case, device_list): gen_float = np.random.random gen_int = np.random.randint dtype_list = [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, flow.double, ] def choose_shape_and_dtype(seed): rng = np.random.default_rng(seed) kdtype = rng.integers(low=1, high=len(dtype_list), size=1) ndim = rng.integers(low=1, high=4, size=1) shape = list(rng.integers(low=1, high=5, size=1) * 12) + list( rng.integers(low=1, high=10, size=ndim - 1) ) return kdtype, shape for _ in range(10): seed = flow.tensor(gen_int(1, 1000, 1)) seed = seed.to_global( placement=flow.placement.all(seed.device.type), sbp=flow.sbp.broadcast, ) seed = int(seed.to_local().numpy()) kdtype, shape = choose_shape_and_dtype(seed) if kdtype <= 3: np_arr = gen_int(1, 10, shape) else: np_arr = gen_float(shape) tensor = flow.tensor(np_arr, device="cpu", dtype=dtype_list[int(kdtype)]) cpu_tensor = tensor.to_global( placement=flow.placement("cpu", device_list), sbp=flow.sbp.partial_sum ) cpu_tensor = cpu_tensor.to_global(sbp=flow.sbp.split(0)) tensor = tensor.to("cuda") cuda_tensor = tensor.to_global( placement=flow.placement("cuda", device_list), sbp=flow.sbp.partial_sum ) cuda_tensor = cuda_tensor.to_global(sbp=flow.sbp.split(0)) test_case.assertTrue( np.allclose(cpu_tensor.to_local().numpy(), cuda_tensor.to_local().numpy()) ) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalCast(flow.unittest.TestCase): def test_cpu_local_tensor_to_gpu_placement(test_case): if flow.env.get_rank() == 0: np_arr = np.array([4, 6, 7, 8], dtype=np.float32) else: np_arr = np.array([0, 0, 0, 0], dtype=np.float32) tensor = flow.tensor(np_arr, dtype=flow.float32) placement = flow.placement("cuda", [0, 1, 2, 3]) device = flow.device("cuda") global_tensor = tensor.to_global(placement, flow.sbp.broadcast) test_case.assertEqual(global_tensor.to_local().device, device) test_case.assertEqual(global_tensor.placement, placement) test_case.assertTrue( np.array_equal( global_tensor.to_local().numpy(), np.array([4, 6, 7, 8], dtype=np.float32), ) ) def test_cpu_p2b_with_random_parameter(test_case): arg_dict = OrderedDict() arg_dict["device_list"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]] for arg in GenArgList(arg_dict): _test_cpu_p2b_with_random_parameter(test_case, *arg) def test_cpu_s2b_with_random_parameter(test_case): arg_dict = OrderedDict() arg_dict["device_list"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]] for arg in GenArgList(arg_dict): _test_cpu_s2b_with_random_parameter(test_case, *arg) def test_cpu_p2s_with_random_parameter(test_case): arg_dict = OrderedDict() arg_dict["device_list"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]] for arg in GenArgList(arg_dict): _test_cpu_p2s_with_random_parameter(test_case, *arg) def test_local_to_global_with_wrong_device(test_case): np_arr = np.array([4, 6], dtype=np.float32) tensor = flow.tensor( np_arr, device=flow.device("cuda:%d" % ((flow.env.get_rank() + 1) % 4)), dtype=flow.float32, ) placement = flow.placement("cuda", ranks=[0, 1, 2, 3]) device = flow.device("cuda") global_tensor = tensor.to_global(placement, flow.sbp.broadcast) local_tensor = global_tensor.to_local() test_case.assertEqual(local_tensor.device, device) test_case.assertEqual(global_tensor.placement, placement) @flow.unittest.skip_unless_1n4d() 
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalCast_S2S(flow.unittest.TestCase): def test_global_to_global_s0_to_s1(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0, 1]) split0_tensor = tensor.to_global(placement, flow.sbp.split(0)) split1_tensor = split0_tensor.to_global(placement, flow.sbp.split(1)) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( split1_tensor.to_local().numpy(), np.array( [ [4.0, 6.0], [6.0, 2.0], [3.0, 7.0], [6.0, 8.0], [2.0, 10.0], [3.0, 9.0], [4.0, 6.0], [6.0, 8.0], ], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( split1_tensor.to_local().numpy(), np.array( [ [5.0, 20.0], [5.0, 7.0], [5.0, 4.0], [9.0, 4.0], [10.0, 7.0], [10.0, 5.0], [6.0, 9.0], [6.0, 4.0], ], dtype=np.float32, ), ) ) def test_global_to_global_s1_to_s0(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0, 1]) split_tensor = tensor.to_global(placement, flow.sbp.split(0)) split1_tensor = split_tensor.to_global(placement, flow.sbp.split(1)) split0_tensor = split1_tensor.to_global(placement, flow.sbp.split(0)) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( split0_tensor.to_local().numpy(), np.array( [ [4.0, 6.0, 5.0, 20.0], [6.0, 2.0, 5.0, 7.0], [3.0, 7.0, 5.0, 4.0], [6.0, 8.0, 9.0, 4.0], ], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( split0_tensor.to_local().numpy(), np.array( [ [2.0, 10.0, 10.0, 7.0], [3.0, 9.0, 10.0, 5.0], [4.0, 6.0, 6.0, 9.0], [6.0, 8.0, 6.0, 4.0], ], dtype=np.float32, ), ) ) def test_global_to_global_s0_to_s1_cpu(test_case): np_arr = np.random.randn(4, 12) cuda_device = flow.device("cuda") cuda_tensor = flow.tensor(np_arr, device=cuda_device, dtype=flow.float32) cuda_placement = flow.placement("cuda", ranks=[1, 3]) cuda_split0_tensor = cuda_tensor.to_global(cuda_placement, flow.sbp.split(0)) cuda_split1_tensor = cuda_split0_tensor.to_global( cuda_placement, flow.sbp.split(1) ) cpu_device = flow.device("cpu") cpu_tensor = flow.tensor(np_arr, device=cpu_device, dtype=flow.float32) cpu_placement = flow.placement("cpu", ranks=[1, 3]) cpu_split0_tensor = cpu_tensor.to_global(cpu_placement, flow.sbp.split(0)) cpu_split1_tensor = cpu_split0_tensor.to_global( cpu_placement, flow.sbp.split(1) ) if flow.env.get_rank() == 0 or flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( cuda_split1_tensor.to_local().numpy(), cpu_split1_tensor.to_local().numpy(), ) ) def test_global_to_global_s1_to_s0_cpu(test_case): np_arr = np.random.randn(4, 12) cuda_device = flow.device("cuda") cuda_tensor = flow.tensor(np_arr, device=cuda_device, dtype=flow.float32) cuda_placement = flow.placement("cuda", ranks=[0, 1]) cuda_split_tensor = cuda_tensor.to_global(cuda_placement, flow.sbp.split(0)) cuda_split1_tensor = cuda_split_tensor.to_global( cuda_placement, flow.sbp.split(1) 
) cuda_split0_tensor = cuda_split1_tensor.to_global( cuda_placement, flow.sbp.split(0) ) cpu_device = flow.device("cpu") cpu_tensor = flow.tensor(np_arr, device=cpu_device, dtype=flow.float32) cpu_placement = flow.placement("cpu", ranks=[0, 1]) cpu_split_tensor = cpu_tensor.to_global(cpu_placement, flow.sbp.split(0)) cpu_split1_tensor = cpu_split_tensor.to_global(cpu_placement, flow.sbp.split(1)) cpu_split0_tensor = cpu_split1_tensor.to_global( cpu_placement, flow.sbp.split(0) ) if flow.env.get_rank() == 0 or flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( cuda_split0_tensor.to_local().numpy(), cpu_split0_tensor.to_local().numpy(), ) ) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalCast_XToB(flow.unittest.TestCase): def test_global_to_global_btb_gpu_to_gpu(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0, 1]) global_tensor = tensor.to_global(placement, flow.sbp.broadcast) new_placement = flow.placement("cuda", ranks=[0, 1, 2]) broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(broadcast_tensor.placement, new_placement) if flow.env.get_rank() != 3: test_case.assertTrue( np.array_equal( broadcast_tensor.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ), ) ) def test_global_to_global_stb_gpu_to_gpu(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0, 1, 2]) global_tensor = tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cuda", ranks=[0, 1, 2, 3]) broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(broadcast_tensor.placement, new_placement) test_case.assertTrue( np.array_equal( broadcast_tensor.to_local().numpy(), np.array( [ [4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0], [2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4], [9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0], ], dtype=np.float32, ), ) ) def test_global_to_global_ptb_gpu_to_gpu(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], 
dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0, 1, 2]) global_tensor = tensor.to_global(placement, flow.sbp.partial_sum) new_placement = flow.placement("cuda", ranks=[0, 1, 2, 3]) broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(broadcast_tensor.placement, new_placement) test_case.assertTrue( np.array_equal( broadcast_tensor.to_local().numpy(), np.array( [ [15, 22, 20, 35], [13, 26, 26, 5], [9, 18, 18, 18], [18, 24, 25, 4], ], dtype=np.float32, ), ) ) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalCast_1ToN(flow.unittest.TestCase): def test_global_to_global_1tob(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0]) global_tensor = tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cuda", ranks=[0, 1]) broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(broadcast_tensor.placement, new_placement) if flow.env.get_rank() < 2: test_case.assertTrue( np.array_equal( broadcast_tensor.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ), ) ) def test_global_to_global_1top(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", [0]) global_tensor = tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cuda", ranks=[0, 1]) partial_sum_tensor = global_tensor.to_global( new_placement, flow.sbp.partial_sum ) test_case.assertEqual(partial_sum_tensor.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( partial_sum_tensor.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ), ) ) elif flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( partial_sum_tensor.to_local().numpy(), np.array( [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.float32, ), ) ) def test_global_to_global_1tos(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0]) global_tensor = tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cuda", ranks=[0, 1, 2, 3]) split_tensor = 
global_tensor.to_global(new_placement, flow.sbp.split(0)) test_case.assertEqual(split_tensor.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( split_tensor.to_local().numpy(), np.array([[4, 6, 5, 20]], dtype=np.float32,), ) ) elif flow.env.get_rank() == 1: test_case.assertTrue( np.array_equal( split_tensor.to_local().numpy(), np.array([[6, 2, 5, 7]], dtype=np.float32,), ) ) elif flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( split_tensor.to_local().numpy(), np.array([[3, 7, 5, 4]], dtype=np.float32,), ) ) elif flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( split_tensor.to_local().numpy(), np.array([[6, 8, 9, 4]], dtype=np.float32,), ) ) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalCast_NTo1(flow.unittest.TestCase): def test_global_to_global_bt1(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0, 1]) global_tensor = tensor.to_global(placement, flow.sbp.broadcast) new_placement = flow.placement("cuda", ranks=[0]) broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(broadcast_tensor.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( broadcast_tensor.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ), ) ) def test_global_to_global_st1(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0, 1]) global_tensor = tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cuda", ranks=[0]) partial_sum_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(partial_sum_tensor.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( partial_sum_tensor.to_local().numpy(), np.array( [ [4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4], [2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4], ], dtype=np.float32, ), ) ) def test_global_to_global_pt1(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0, 1]) global_tensor = tensor.to_global(placement, flow.sbp.partial_sum) new_placement = flow.placement("cuda", ranks=[0]) partial_sum_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(partial_sum_tensor.placement, new_placement) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( partial_sum_tensor.to_local().numpy(), np.array( [ [6, 16, 15, 27], [9, 11, 15, 12], [7, 13, 11, 13], [12, 16, 
15, 8], ], dtype=np.float32, ), ) ) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGlobalCast_1To1(flow.unittest.TestCase): def test_global_to_global_1to1_gpu_to_gpu(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") local_tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[3]) x = local_tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cuda", ranks=[2]) y = x.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ), ) ) def test_global_to_global_1to1_cpu_to_cpu(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cpu") local_tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cpu", ranks=[0]) x = local_tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cpu", ranks=[2]) y = x.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 2: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ), ) ) def test_global_to_global_1to1_gpu_to_cpu(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cuda") local_tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cuda", ranks=[0]) x = local_tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cpu", ranks=[3]) y = x.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ), ) ) def test_global_to_global_1to1_cpu_to_gpu(test_case): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]], dtype=np.float32, ) else: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) device = flow.device("cpu") local_tensor = flow.tensor(np_arr, device=device, dtype=flow.float32) placement = flow.placement("cpu", ranks=[1]) x = local_tensor.to_global(placement, flow.sbp.split(0)) new_placement = flow.placement("cuda", ranks=[3]) y = x.to_global(new_placement, flow.sbp.broadcast) test_case.assertEqual(y.placement, new_placement) if flow.env.get_rank() == 3: test_case.assertTrue( np.array_equal( y.to_local().numpy(), np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ), ) ) class GraphTestModel(nn.Graph): def __init__(self, model): 
super().__init__() self.model = model def build(self, x): return self.model(x) @flow.unittest.skip_unless_1n2d() class TestToGlobalAndLocal(flow.unittest.TestCase): placement = flow.placement("cpu", ranks=[0, 1]) sbp = None model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2)) local_graph_model = GraphTestModel(model) global_graph_model = None def __all_global(test_case, input, placement, sbp): if type(input) == Tensor: test_case.assertTrue(input.is_global) # check placement test_case.assertEqual(placement.type, input.placement.type) test_case.assertListEqual( list(placement.ranks), list(input.placement.ranks) ) # check sbp test_case.assertTupleEqual(sbp, input.sbp) elif isinstance(input, (dict, tuple, list)): node_tree = ArgsTree(input) for node in node_tree.iter_nodes(): if isinstance(node, Tensor): test_case.assertTrue(node.is_global) # check placement test_case.assertEqual(placement.type, node.placement.type) test_case.assertListEqual( list(placement.ranks), list(node.placement.ranks) ) # check sbp test_case.assertTupleEqual(sbp, node.sbp) def __all_local(test_case, input): if type(input) == Tensor: test_case.assertFalse(input.is_global) elif isinstance(input, (dict, tuple, list)): node_tree = ArgsTree(input) for node in node_tree.iter_nodes(): if isinstance(node, Tensor): test_case.assertFalse(node.is_global) def _test_any_input(test_case): tensor = flow.zeros((3, 4)) tensor_list = [flow.tensor([1, 2, 3]), flow.randn((2, 3, 4))] tensor_tuple = (flow.zeros((2, 2)), flow.ones((2, 3)), flow.randn((3, 5))) tensor_dict = {"tensor": tensor, "tensor_lt": tensor_list} random_combination = [ None, 1, "test_str", tensor, tensor_list, tensor_tuple, tensor_dict, ] inputs = [ None, 100, "test_str", tensor, tensor_list, tensor_tuple, tensor_dict, random_combination, ] global_inputs = [] for i in inputs: ret = flow.utils.global_view.to_global( i, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) test_case.__all_global( ret, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) global_inputs.append(ret) for i in global_inputs: ret = flow.utils.global_view.to_local(i) test_case.__all_local(ret) def _test_any_input_get_sbp_func(test_case): def __get_sbp(input, tensor): return TestToGlobalAndLocal.sbp tensor = flow.zeros((3, 4)) tensor_list = [flow.tensor([1, 2, 3]), flow.randn((2, 3, 4))] tensor_tuple = (flow.zeros((2, 2)), flow.ones((2, 3)), flow.randn((3, 5))) tensor_dict = {"tensor": tensor, "tensor_lt": tensor_list} random_combination = [ None, 1, "test_str", tensor, tensor_list, tensor_tuple, tensor_dict, ] inputs = [ None, 100, "test_str", tensor, tensor_list, tensor_tuple, tensor_dict, random_combination, ] global_inputs = [] for i in inputs: ret = flow.utils.global_view.to_global( i, placement=TestToGlobalAndLocal.placement, sbp=__get_sbp, ) test_case.__all_global( ret, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) global_inputs.append(ret) for i in global_inputs: ret = flow.utils.global_view.to_local(i) test_case.__all_local(ret) def _test_tensor_to_global(test_case): local_tensor = flow.ones((3, 4)) # local tensor -> global tensor global_tensor = flow.utils.global_view.to_global( local_tensor, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) test_case.assertTrue(global_tensor.is_global) # global tensor -> global tensor global_tensor = flow.utils.global_view.to_global( global_tensor, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) 
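# Casting an already-global tensor with flow.utils.global_view.to_global simply
# re-casts it to the requested placement/sbp, so the result below is still global.
# The same API also walks nested dict/list/tuple structures such as state dicts.
# A minimal usage sketch (names and values here are illustrative, not part of this test):
#     placement = flow.placement("cpu", ranks=[0, 1])
#     global_sd = flow.utils.global_view.to_global(
#         model.state_dict(), placement=placement, sbp=(flow.sbp.broadcast,)
#     )
#     local_sd = flow.utils.global_view.to_local(global_sd)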
test_case.assertTrue(global_tensor.is_global) # passing no placement and sbp with test_case.assertRaises(ValueError): global_tensor = flow.utils.global_view.to_global( local_tensor, placement=None, sbp=None ) # wrong sbp type with test_case.assertRaises(TypeError): global_tensor = flow.utils.global_view.to_global( local_tensor, placement=TestToGlobalAndLocal.placement, sbp=(TestToGlobalAndLocal.sbp, 0), ) def _test_tensor_to_local(test_case): # global tensor -> local tensor global_tensor = flow.ones( (3, 4), placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) local_tensor = flow.utils.global_view.to_local(global_tensor) test_case.assertFalse(local_tensor.is_global) def __test_state_dict_to_global(test_case, local_state_dict): # local state dict -> global state dict global_state_dict = flow.utils.global_view.to_global( local_state_dict, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) test_case.__all_global( global_state_dict, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) # global state dict -> global state dict global_state_dict = flow.utils.global_view.to_global( global_state_dict, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) test_case.__all_global( global_state_dict, placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp, ) def __test_state_dict_to_local(test_case, global_state_dict): # global state dict -> local state dict local_state_dict = flow.utils.global_view.to_local(global_state_dict) test_case.__all_local(local_state_dict) # local input, display warning local_state_dict = flow.utils.global_view.to_local(local_state_dict) def _test_eagar_state_dict(test_case): test_case.__test_state_dict_to_global(TestToGlobalAndLocal.model.state_dict()) global_model = TestToGlobalAndLocal.model.to_global( placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp ) test_case.__test_state_dict_to_local(global_model.state_dict()) def _test_graph_state_dict(test_case): test_case.__test_state_dict_to_global( TestToGlobalAndLocal.local_graph_model.state_dict() ) test_case.__test_state_dict_to_local( TestToGlobalAndLocal.global_graph_model.state_dict() ) def test_to_global_local(test_case): sbp_types = [ (flow.sbp.broadcast,), (flow.sbp.split(0),), (flow.sbp.partial_sum,), ] for sbp in sbp_types: TestToGlobalAndLocal.sbp = sbp TestToGlobalAndLocal.global_graph_model = GraphTestModel( TestToGlobalAndLocal.model.to_global( placement=TestToGlobalAndLocal.placement, sbp=sbp ) ) test_case._test_any_input() test_case._test_any_input_get_sbp_func() test_case._test_tensor_to_global() test_case._test_tensor_to_local() test_case._test_eagar_state_dict() test_case._test_graph_state_dict() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_chunk.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_chunk(test_case, ndim, placement, sbp): dims = [random(1, 3).to(int) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) dim = random(-ndim, ndim).to(int) chunks = random(low=1, high=4).to(int) y = torch.chunk(x, chunks=chunks, dim=dim) z = torch.cat(y, dim=dim) return z class TestModule(flow.unittest.TestCase): @globaltest def test_chunk(test_case): for placement in all_placement(): ndim = random(1, 4).to(int).value() for sbp in all_sbp(placement, max_dim=min(ndim, 2)): _test_chunk(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_clone.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def do_test_clone_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) z = y.clone() return z class TestCloneConsistent(flow.unittest.TestCase): @globaltest def test_clone(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_clone_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_coin_flip.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_coin_flip( test_case, batch_size, random_seed, probability, placement, sbp ): m = flow.nn.CoinFlip( batch_size, random_seed, probability, placement=placement, sbp=sbp ) x = m() test_case.assertEqual(x.shape[0], batch_size) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_coin_flip( test_case, batch_size, random_seed, probability, placement, sbp ): class GlobalCoinFlipGraph(flow.nn.Graph): def __init__(self,): super().__init__() self.m = flow.nn.CoinFlip( batch_size, random_seed, probability, placement=placement, sbp=sbp ) def build(self): return self.m() model = GlobalCoinFlipGraph() x = model() test_case.assertEqual(x.shape[0], batch_size) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestCoinFlipGlobal(flow.unittest.TestCase): @globaltest def test_coin_flip_global(test_case): arg_dict = OrderedDict() arg_dict["batch_size"] = [8, 64] arg_dict["random_seed"] = [None, 1, -1] arg_dict["probability"] = [0.0, 0.5, 1.0] for args in GenArgDict(arg_dict): for placement in all_placement(): # TODO: CoinFlip support cuda kernel if placement.type == "cuda": continue for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_global_coin_flip( test_case, **args, placement=placement, sbp=sbp ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_coin_flip_graph(test_case): arg_dict = OrderedDict() arg_dict["batch_size"] = [8] arg_dict["random_seed"] = [None, 1, -1] arg_dict["probability"] = [0.0, 0.5, 1.0] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), # TODO: CoinFlip support cuda kernel # flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), # TODO: CoinFlip support cuda kernel # flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): placement = args["placement"] for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_graph_coin_flip(test_case, **args, sbp=sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_concat.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_cat_with_random_data(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=8, dim1=8).to_global(placement=placement, sbp=sbp) return torch.cat((x, x), random(0, 2).to(int)) @autotest(n=1, auto_backward=False, check_graph=True) def _test_concat_with_input_0_size_data(test_case, placement, sbp): x = random_tensor(4, 8, 8, 2, 4).to_global(placement=placement, sbp=sbp) y = random_tensor(4, 8, 8, random(0, 3) * 8, 4).to_global( placement=placement, sbp=sbp ) z = torch.cat((x, y), dim=2) return z @autotest(n=1, auto_backward=False, check_graph=True) def _test_concat_with_output_0_size_data(test_case, placement, sbp): x = random_tensor(4, 8, 8, 0, 4).to_global(placement=placement, sbp=sbp) y = random_tensor(4, 8, 8, 0, 4).to_global(placement=placement, sbp=sbp) z = torch.cat((x, y), dim=2) return z @autotest(n=1, check_graph=True) def _test_cat_only_one_tensor(test_case, placement, sbp): x = random_tensor(4, 8, 8, random(1, 3) * 8, 8).to_global( placement=placement, sbp=sbp ) return torch.cat((x,), 0) class TestModule(flow.unittest.TestCase): @globaltest def test_cat_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_cat_with_random_data(test_case, placement, sbp) @globaltest def test_cat_only_one_tensor(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_cat_only_one_tensor(test_case, placement, sbp) @globaltest def test_concat_with_input_0_size_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_concat_with_input_0_size_data(test_case, placement, sbp) @globaltest def test_concat_with_output_0_size_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_concat_with_output_0_size_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_constant.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_new_full(test_case, shape, full_value, placement, sbp): np_res = np.full(shape, full_value) x = flow.ones(shape) y = x.new_full(shape, full_value, placement=placement, sbp=sbp) test_case.assertEqual(y.shape, flow.Size(shape)) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) y = y.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ).to_local() test_case.assertTrue(np.array_equal(y.numpy(), np_res)) def _test_global_graph_new_full(test_case, shape, full_value, placement, sbp): np_res = np.full(shape, full_value) class GlobalNewFullGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self,): x = flow.ones(shape) y = x.new_full(shape, full_value, placement=placement, sbp=sbp) return y model = GlobalNewFullGraph() y = model() test_case.assertEqual(y.shape, flow.Size(shape)) test_case.assertEqual(y.sbp, sbp) test_case.assertEqual(y.placement, placement) y = y.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ).to_local() test_case.assertTrue(np.array_equal(y.numpy(), np_res)) def _test_global_constant(test_case, func, shape, placement, sbp): func2 = None if func == "ones": func = flow.ones np_res = np.ones(shape) elif func == "zeros": func = flow.zeros np_res = np.zeros(shape) elif func == "new_zeros": func = flow.zeros np_res = np.zeros(shape) func2 = flow.new_zeros else: raise NotImplementedError x = func(*shape, placement=placement, sbp=sbp) if func2: x = func2(x) test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) x = x.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ).to_local() test_case.assertTrue(np.array_equal(x.numpy(), np_res)) def _test_graph_constant(test_case, func, shape, placement, sbp): func2 = None if func == "ones": func = flow.ones np_res = np.ones(shape) elif func == "zeros": func = flow.zeros np_res = np.zeros(shape) elif func == "new_zeros": func = flow.zeros np_res = np.zeros(shape) func2 = flow.new_zeros else: raise NotImplementedError class GlobalConstantGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = func(*shape, placement=placement, sbp=sbp) if func2: x = func2(x) return x model = GlobalConstantGraph() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) x = x.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ).to_local() test_case.assertTrue(np.array_equal(x.numpy(), np_res)) class TestConstantGlobal(flow.unittest.TestCase): @globaltest def test_constant_global(test_case): shapes = [(8,), (8, 8,), (8, 8, 8)] functions = [ "ones", "zeros", "new_zeros", ] for func in functions: for shape in shapes: for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(shape), except_partial_sum=True ): _test_global_constant(test_case, func, shape, placement, sbp) full_values = [2, 3, 4] for full_value in full_values: for shape in shapes: for placement in all_placement(): for sbp in all_sbp(placement, max_dim=len(shape),): _test_global_new_full( test_case, shape, full_value, 
placement, sbp ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_constant_graph(test_case): arg_dict = OrderedDict() arg_dict["func"] = ["ones", "zeros", "new_zeros"] arg_dict["shape"] = [(8,), (8, 8,), (8, 8, 8)] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): func = args["func"] shape = args["shape"] placement = args["placement"] for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True): _test_graph_constant(test_case, func, shape, placement, sbp) full_values = [2, 3, 4] shapes = [(8,), (8, 8,), (8, 8, 8)] for full_value in full_values: for shape in shapes: for placement in all_placement(): for sbp in all_sbp(placement, max_dim=len(shape)): _test_global_graph_new_full( test_case, shape, full_value, placement, sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_ctc_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest import torch from oneflow.test_utils.automated_test_util.generators import * from oneflow.test_utils.automated_test_util.torch_flow_dual_object import globaltest from oneflow.test_utils.test_util import GenArgDict def log_softmax(logits, axis=0): max_value = np.max(logits, axis, keepdims=True) exp = np.exp(logits - max_value) exp_sum = np.sum(exp, axis, keepdims=True) dist = exp / exp_sum return np.log(dist) def _compare_torch_and_oneflow( test_case, torch_ctc_loss, flow_ctc_loss, placement, module_sbp, in_sbp, max_input_length, batch_size, num_classes, max_target_length, ): log_probs = np.random.random( size=(max_input_length, batch_size, num_classes) ).astype(np.float32) log_probs = log_softmax(log_probs, axis=2) targets = np.random.randint( 1, high=num_classes, size=(batch_size, max_target_length), dtype=np.int32 ) input_lengths = np.random.randint( max_input_length / 2, high=max_input_length, size=(batch_size,), dtype=np.int32 ) target_lengths = np.random.randint( max_target_length / 2, high=max_target_length, size=(batch_size,), dtype=np.int32, ) log_probs_torch = torch.tensor(log_probs, dtype=torch.float32, requires_grad=True) targets_torch = torch.tensor(targets, dtype=torch.int32) input_lengths_torch = torch.tensor(input_lengths, dtype=torch.int32) target_lengths_torch = torch.tensor(target_lengths, dtype=torch.int32) log_probs_flow = ( flow.tensor(log_probs, dtype=flow.float32, requires_grad=True) .to_global(flow.placement.all("cpu"), flow.sbp.broadcast) .to_global(placement=placement, sbp=in_sbp) ) targets_flow = ( flow.tensor(targets, dtype=flow.int32) .to_global(flow.placement.all("cpu"), flow.sbp.broadcast) 
.to_global(placement=placement, sbp=in_sbp) ) input_lengths_flow = ( flow.tensor(input_lengths, dtype=flow.int32) .to_global(flow.placement.all("cpu"), flow.sbp.broadcast) .to_global(placement=placement, sbp=in_sbp) ) target_lengths_flow = ( flow.tensor(target_lengths, dtype=flow.int32) .to_global(flow.placement.all("cpu"), flow.sbp.broadcast) .to_global(placement=placement, sbp=in_sbp) ) out_torch = torch_ctc_loss( log_probs_torch, targets_torch, input_lengths_torch, target_lengths_torch ) out_flow = flow_ctc_loss( log_probs_flow, targets_flow, input_lengths_flow, target_lengths_flow ) # check forward local_output = out_flow.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ).to_local() if flow.env.get_rank() == 0: test_case.assertTrue( np.allclose( out_torch.cpu().detach().numpy(), local_output.numpy(), rtol=1e-05, atol=1e-05, ) ) # check backward out_torch.sum().backward() out_flow.sum().backward() local_x_grad = log_probs_flow.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ).to_local() if flow.env.get_rank() == 0: test_case.assertTrue( np.allclose( log_probs_torch.cpu().detach().numpy(), local_x_grad.numpy(), rtol=1e-05, atol=1e-05, ) ) def _test_ctc_loss_impl( test_case, placement, module_sbp, in_sbp, max_input_length, batch_size, num_classes, max_target_length, blank, reduction, zero_infinity, ): torch_ctc_loss = torch.nn.CTCLoss( blank=blank, reduction=reduction, zero_infinity=zero_infinity ) flow_ctc_loss = flow.nn.CTCLoss( blank=blank, reduction=reduction, zero_infinity=zero_infinity ) _compare_torch_and_oneflow( test_case, torch_ctc_loss, flow_ctc_loss, placement, module_sbp, in_sbp, max_input_length, batch_size, num_classes, max_target_length, ) @flow.unittest.skip_unless_1n2d() @unittest.skip("skip for now, because it segfaults several times in CI") class TestCTCLossGlobal(oneflow.unittest.TestCase): @globaltest def test_ctc_loss_global(test_case): arg_dict = OrderedDict() arg_dict["max_input_length"] = [20] arg_dict["batch_size"] = [4] arg_dict["num_classes"] = [5] arg_dict["max_target_length"] = [10] arg_dict["blank"] = [0, 4] arg_dict["reduction"] = ["mean", "none"] arg_dict["zero_infinity"] = [False, True] module_sbp = flow.sbp.broadcast for args in GenArgDict(arg_dict): for placement in all_placement(): for in_sbp in all_sbp(placement): _test_ctc_loss_impl( test_case, placement, module_sbp, in_sbp, **args ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_cumprod.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
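# A quick local NumPy check (illustrative only, not part of the original test files) of the
# log_softmax helper used by the CTC-loss comparison above: subtracting the per-axis max
# before exponentiating is the standard numerically stable formulation, and the resulting
# probabilities still sum to 1 along the class axis.
import numpy as np
_logits = np.random.rand(20, 4, 5).astype(np.float32)
_m = np.max(_logits, axis=2, keepdims=True)
_log_probs = (_logits - _m) - np.log(np.sum(np.exp(_logits - _m), axis=2, keepdims=True))
assert np.allclose(np.exp(_log_probs).sum(axis=2), 1.0, atol=1e-5)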
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, auto_backward=True, check_graph=True) def _test_cumprod_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) dim = random(0, ndim).to(int).value() z = torch.cumprod(y, dim) return z class TestCumprodGlobal(flow.unittest.TestCase): @globaltest def test_cumprod(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=min(2, ndim)): _test_cumprod_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_cumsum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=True, check_graph=True) def _test_cumsum_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) dim = random(0, ndim).to(int).value() z = torch.cumsum(x, dim) return z @unittest.skip("This fails in multi-gpu") class TestCumsumGlobal(flow.unittest.TestCase): @globaltest def test_cumsum(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=min(2, ndim)): _test_cumsum_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_deconv2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True, rtol=1e-2, atol=1e-3) def _test_deconv2d_impl(test_case, placement, input_sbp): ndim = 4 in_channels = random(1, 5).to(int).value() * 8 groups = random(1, 4).to(int).value() out_channels = groups * 8 kernel_size = random(1, 4).to(int).value() stride = random(1, 5).to(int).value() padding = random(1, 3).to(int).value() dilation = random(1, 5).to(int).value() padding_mode = constant("zeros") m = torch.nn.ConvTranspose2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, padding_mode=padding_mode, bias=False, ) m.train(random()) weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True) m.weight = torch.nn.Parameter( m.weight.to_global(placement=placement, sbp=weight_sbp) ) if m.bias is not None: bias_sbp = random_sbp(placement, max_dim=1) m.bias = torch.nn.Parameter(m.bias.to_global(placement=placement, sbp=bias_sbp)) batch = random(1, 3).to(int).value() * 8 height = random(1, 5).to(int).value() * 8 width = random(1, 5).to(int).value() * 8 nchw = [batch, in_channels, height, width] x = random_tensor(ndim, *nchw).to_global(placement=placement, sbp=input_sbp) y = m(x) return y class TestDeconv2dGlobal(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") @globaltest def test_deconv2d(test_case): for placement in all_placement(): for input_sbp in all_sbp(placement, max_dim=2): _test_deconv2d_impl(test_case, placement, input_sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_deform_conv2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest import torch as pytorch import torchvision from oneflow.test_utils.automated_test_util import * def _test_deform_conv2d(test_case, placement): input_sbp = random_sbp(placement, max_dim=4) input_dims = [8, 8, 8, 8] input = random_tensor(4, *input_dims).to_global(placement=placement, sbp=input_sbp) offset_sbp = random_sbp(placement, max_dim=2) offset_dims = [8, 32, 5, 5] offset = random_tensor(4, *offset_dims).to_global( placement=placement, sbp=offset_sbp ) mask_sbp = random_sbp(placement, max_dim=2) mask_dims = [8, 4 * 4, 5, 5] mask = random_tensor(4, *mask_dims).to_global(placement=placement, sbp=mask_sbp) weight_sbp = random_sbp(placement, max_dim=2) weight_dims = [8, 8, 4, 4] weight = random_tensor(4, *weight_dims).to_global( placement=placement, sbp=weight_sbp ) bias_sbp = random_sbp(placement, max_dim=1) bias_dims = [8] bias = random_tensor(1, *bias_dims).to_global(placement=placement, sbp=bias_sbp) flow_input = input.oneflow.detach().requires_grad_() torch_input = input.pytorch.detach().requires_grad_() flow_offset = offset.oneflow.detach().requires_grad_() torch_offset = offset.pytorch.detach().requires_grad_() flow_weight = weight.oneflow.detach().requires_grad_() torch_weight = weight.pytorch.detach().requires_grad_() flow_mask = mask.oneflow.detach().requires_grad_() torch_mask = mask.pytorch.detach().requires_grad_() flow_bias = bias.oneflow.detach().requires_grad_() torch_bias = bias.pytorch.detach().requires_grad_() torch_out = torchvision.ops.deform_conv2d( torch_input, torch_offset, torch_weight, mask=torch_mask, bias=torch_bias ) flow_out = oneflow.nn.functional.deform_conv2d( flow_input, flow_offset, flow_weight, mask=flow_mask, bias=flow_bias ) # compare forward test_case.assertTrue( np.allclose( flow_out.numpy(), torch_out.detach().cpu().numpy(), rtol=1e-04, atol=1e-4 ) ) # compare backward flow_out.sum().backward() torch_out.sum().backward() flow_input_grad = flow_input.grad torch_input_grad = torch_input.grad.detach().cpu() flow_weight_grad = flow_weight.grad torch_weight_grad = torch_weight.grad.detach().cpu() flow_offset_grad = flow_offset.grad torch_offset_grad = torch_offset.grad.detach().cpu() flow_mask_grad = flow_mask.grad torch_mask_grad = torch_mask.grad.detach().cpu() flow_bias_grad = flow_bias.grad torch_bias_grad = torch_bias.grad.detach().cpu() test_case.assertTrue( np.allclose( flow_input_grad.numpy(), torch_input_grad.numpy(), rtol=1e-04, atol=1e-4 ) ) test_case.assertTrue( np.allclose( flow_weight_grad.numpy(), torch_weight_grad.numpy(), rtol=1e-04, atol=1e-4 ) ) test_case.assertTrue( np.allclose( flow_offset_grad.numpy(), torch_offset_grad.numpy(), rtol=1e-04, atol=1e-4 ) ) test_case.assertTrue( np.allclose( flow_mask_grad.numpy(), torch_mask_grad.numpy(), rtol=1e-04, atol=1e-4 ) ) test_case.assertTrue( np.allclose( flow_bias_grad.numpy(), torch_bias_grad.numpy(), rtol=1e-04, atol=1e-4 ) ) class TestGlobalDeformConv2d(flow.unittest.TestCase): @globaltest def test_deform_conv2d(test_case): for placement in all_placement(): for count in range(5): _test_deform_conv2d(test_case, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_det.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import re import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def det_all_placement(): cuda_version = flow._oneflow_internal.flags.cuda_version() if cuda_version < 11000: # cuSOLVER is only supported in CUDA 11.0 and above return all_cpu_placement() else: # FIXME: remove this after fixing the bug of cuda global det return all_cpu_placement() # return all_placement() @autotest(n=1, check_graph=False, auto_backward="auto") def _test_det(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim - 2)] square_dim = 8 dim_list.extend([square_dim] * 2) x = ( random_tensor(ndim, *dim_list, low=-1) .to(torch.double) .to_global(placement, sbp) ) return torch.linalg.det(x) class TestDet(flow.unittest.TestCase): @globaltest def test_det(test_case): ndim = random(2, 5).to(int).value() for placement in det_all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_det(test_case, placement, sbp, ndim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_diag.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def do_test_diag_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) return torch.diag(y) class TestDiagGlobal(flow.unittest.TestCase): @globaltest def test_diag(test_case): # random ndim in range [1,2] ndim = random(1, 3).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_diag_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_diagonal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=True, check_graph=True) def _test_diagonal_impl(test_case, placement, sbp): offset = random(-5, 5).to(int).value() dim1 = random(-4, 4).to(int).value() dim2 = random(-4, 4).to(int).value() x = random_tensor( ndim=4, dim0=random(1, 4) * 8, dim1=random(1, 4) * 8, dim2=random(1, 4) * 8, dim3=random(1, 4) * 8, ) y = x.to_global(placement=placement, sbp=sbp) z = torch.diagonal(y, offset, dim1, dim2) return z @unittest.skip("TODO: fix this test") class TestDiagonalGlobal(flow.unittest.TestCase): @globaltest def test_diagonal(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_diagonal_impl(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_div.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def do_test_div_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) x = x.to_global(placement=placement, sbp=sbp) y = random_tensor(ndim, *dims) y = y.to_global(placement=placement, sbp=sbp) z = torch.div(x, y) return z class TestDivGlobal(flow.unittest.TestCase): @globaltest def test_div(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_div_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_dot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
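# torch.diagonal(y, offset, dim1, dim2) exercised above follows the same semantics as NumPy's
# diagonal; a quick local check (illustrative only, not part of the original test files):
import numpy as np
_a = np.arange(24).reshape(2, 3, 4)
_d = np.diagonal(_a, offset=1, axis1=1, axis2=2)  # shifted diagonal of each 3x4 slice
assert _d.shape == (2, 3)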
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def do_test_dot_impl(test_case, placement, sbp): k = random(100, 1000) * 8 x = random_tensor(ndim=1, dim0=k).to_global(placement=placement, sbp=sbp) y = random_tensor(ndim=1, dim0=k).to_global(placement=placement, sbp=sbp) z = torch.dot(x, y) return z class TestDotGlobal(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 4 times in past week") @globaltest def test_dot(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): do_test_dot_impl(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_dropout.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=True, check_graph=True, atol=1e-5, rtol=1e-5) def _test_dropout_p01(test_case, placement, sbp, ndim, p): dims = [random(1, 5) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) m = torch.nn.Dropout(p=p, inplace=False) return m(x) @autotest(n=1, auto_backward=True, check_graph=True, atol=1e-5, rtol=1e-5) def _test_dropout_eval_p01(test_case, placement, sbp, ndim, p): dims = [random(1, 5) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) m = torch.nn.Dropout(p=p, inplace=False) m.eval() return m(x) class TestDropoutGlobal(flow.unittest.TestCase): @globaltest def test_dropout_p01(test_case): # random ndim in range [1,3] ndim = random(1, 4).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=min(2, ndim)): _test_dropout_p01(test_case, placement, sbp, ndim, p=0.0) _test_dropout_p01(test_case, placement, sbp, ndim, p=1.0) @globaltest def test_dropout_eval(test_case): # random ndim in range [1,3] ndim = random(1, 4).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=min(2, ndim)): _test_dropout_eval_p01(test_case, placement, sbp, ndim, 0.0) _test_dropout_eval_p01(test_case, placement, sbp, ndim, 1.0) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase1.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, atol=1e-2) def _test_einsum_alphaflod_usecase1(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("hij, ijc->ihc", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase1(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_alphaflod_usecase1(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase10.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_alphaflod_usecase10(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 dim2 = random(1, 3) * 8 x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,) y = random_tensor(ndim=4, dim0=dim0, dim1=dim2, dim2=dim1, dim3=random(1, 3) * 8) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("bhqk,bkhc->bqhc", g_x, g_y) return z @unittest.skipIf(True, "skip this test temporarily") class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase10(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_einsum_alphaflod_usecase10(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase11.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4) def _test_einsum_alphaflod_usecase11(test_case, placement, sbp): dim0 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8, dim2=dim0,) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("bqa,ahc->bqhc", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase11(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_alphaflod_usecase11(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, atol=1e-2) def _test_einsum_alphaflod_usecase2(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("rac,rab->rbc", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase2(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_alphaflod_usecase2(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase3.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_alphaflod_usecase3(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("ra,rab->rb", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @unittest.skip("skip for now, because it failed 4 times in the past week") @globaltest def test_einsum_alphaflod_usecase3(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_alphaflod_usecase3(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase4.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, atol=1e-2) def _test_einsum_alphaflod_usecase4(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,) y = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("qhc,khc->qkh", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase4(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_alphaflod_usecase4(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase5.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
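# The equation "ra,rab->rb" exercised above is a per-row vector-matrix product; a local NumPy
# equivalence check (illustrative only, not part of the original test files):
import numpy as np
_x = np.random.rand(8, 5)
_y = np.random.rand(8, 5, 7)
_ref = np.stack([_x[r] @ _y[r] for r in range(8)])  # contract index "a" separately for each row r
assert np.allclose(np.einsum("ra,rab->rb", _x, _y), _ref)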
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_alphaflod_usecase5(test_case, placement, sbp): dim0 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("nm, mrc->nrc", g_x, g_y) return z @unittest.skip("this case fails in multi gpu. TODO: depeng, shenghang") class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase5(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_alphaflod_usecase5(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase6.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, atol=1e-2) def _test_einsum_alphaflod_usecase6(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("abc,adc->bdc", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase6(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_alphaflod_usecase6(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase7.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4) def _test_einsum_alphaflod_usecase7(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor( ndim=4, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1, dim3=random(1, 3) * 8, ) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("dceb,cef->dbf", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase7(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_alphaflod_usecase7(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase8.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4) def _test_einsum_alphaflod_usecase8(test_case, placement, sbp): dim0 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("acb,ade->dceb", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase8(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_alphaflod_usecase8(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_alphaflod_usecase9.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4) def _test_einsum_alphaflod_usecase9(test_case, placement, sbp): dim0 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8, dim2=dim0,) y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("qkc,ch->hqk", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_alphaflod_usecase9(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_alphaflod_usecase9(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_attention.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_attention(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 dim2 = random(1, 3) * 8 x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,) y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("b h i d, b h j d -> b h i j", g_x, g_y) return z @unittest.skipIf(True, "skip this test temporarily") class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_attention(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_einsum_attention(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_batch_matmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, atol=1e-2) def _test_einsum_batch_matmul(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("ijk,ikl->ijl", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_batch_matmul(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_batch_matmul(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_batch_matmul2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, atol=1e-2) def _test_einsum_batch_matmul2(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 dim2 = random(1, 3) * 8 x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2) y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 3) * 8) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("b h i j, b h j d -> b h i d", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_batch_matmul2(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_einsum_batch_matmul2(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_batch_matmul3.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True, atol=1e-2) def _test_einsum_batch_matmul3(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor( ndim=4, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8, dim3=dim1, ) y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("b x i d, b j d -> b x i j", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_batch_matmul3(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_batch_matmul3(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_batch_matmul4.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, atol=1e-2) def _test_einsum_batch_matmul4(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor( ndim=4, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8, dim3=dim1, ) y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("b x i j, b j d -> b x i d", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_batch_matmul4(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_batch_matmul4(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_batch_matrix_vector_multiply.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_batch_matrix_vector_multiply(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 dim2 = random(1, 3) * 8 x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2,) y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("b i d, b i j d -> b i j", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 28 times in past week") @globaltest def test_einsum_batch_matrix_vector_multiply(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_einsum_batch_matrix_vector_multiply(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_batch_permute.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_batch_permute(test_case, placement, sbp): x = random_tensor( ndim=5, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8, dim3=random(1, 3) * 8, dim4=random(1, 3) * 8, ) g_x = x.to_global(placement=placement, sbp=sbp) z = torch.einsum("...ij->...ji", g_x) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_batch_permute(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=5): _test_einsum_batch_permute(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_bilinear_transformation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_bilinear_transformation(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 dim2 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) y = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim1, dim2=dim2,) w = random_tensor(ndim=2, dim0=dim0, dim1=dim2,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) g_w = w.to_global(placement=placement, sbp=sbp) z = torch.einsum("ik,jkl,il->ij", g_x, g_y, g_w) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_bilinear_transformation(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_bilinear_transformation(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_eltwise_mul_sum_row.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_eltwise_mul_sum_row(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("n d, n d -> n", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_eltwise_mul_sum_row(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_eltwise_mul_sum_row(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_eltwise_mul_then_reduce_sum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_eltwise_mul_then_reduce_sum(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) # NOTE(Liang Depeng): the same as 'ij,ij->' z = torch.einsum("ij,ij", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_eltwise_mul_then_reduce_sum(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_eltwise_mul_then_reduce_sum(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_eltwise_multiply.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_eltwise_multiply(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("ij,ij->ij", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_eltwise_multiply(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_eltwise_multiply(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_get_diagonal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_get_diagonal(test_case, placement, sbp): dim = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=dim, dim1=dim,) g_x = x.to_global(placement=placement, sbp=sbp) z = torch.einsum("ii->i", g_x) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_get_diagonal(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_get_diagonal(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_matmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, rtol=1e-3) def _test_einsum_matmul(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 dim2 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) y = random_tensor(ndim=2, dim0=dim1, dim1=dim2,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) # NOTE(Liang Depeng): the same as 'ik,kj->ij' z = torch.einsum("ik,kj", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest @unittest.skip("skip for now, becase it fails several times in CI") def test_einsum_matmul(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_matmul(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_matmul2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_matmul2(test_case, placement, sbp): dim0 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,) y = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("i d, j d -> i j", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 4 times in past week") @globaltest def test_einsum_matmul2(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_matmul2(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_matrix_column_sum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, rtol=1e-3) def _test_einsum_matrix_column_sum(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) z = torch.einsum("ij->j", g_x) return z class TestEinsumGlobal(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 8 times in past week") @globaltest def test_einsum_matrix_column_sum(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_matrix_column_sum(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_matrix_transpose.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_matrix_transpose(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8) g_x = x.to_global(placement=placement, sbp=sbp) z = torch.einsum("ij->ji", g_x) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_matrix_transpose(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_matrix_transpose(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_matrix_vector_multiply.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_matrix_vector_multiply(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) y = random_tensor(ndim=1, dim0=dim1,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) # NOTE(Liang Depeng): the same as 'ik,k->i' z = torch.einsum("ik,k", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_matrix_vector_multiply(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_einsum_matrix_vector_multiply(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_reduce_sum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_reduce_sum(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) z = torch.einsum("ij->", g_x) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_reduce_sum(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_reduce_sum(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_tensor_contraction.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * # The rtol is too large caused by the expansion of random tensor range # of #9534. It should be checked again in the future. @autotest(n=1, check_graph=True, rtol=5e-1, atol=1e-3) def _test_einsum_tensor_contraction(test_case, placement, sbp): dim0 = random(1, 3) * 8 dim1 = random(1, 3) * 8 x = random_tensor( ndim=4, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1, dim3=random(1, 3) * 8, ) y = random_tensor( ndim=5, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8, dim2=dim0, dim3=random(1, 3) * 8, dim4=dim1, ) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("pqrs,tuqvr->pstuv", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_tensor_contraction(test_case): for placement in all_placement(): if len(np.array(placement.ranks).shape) > 1 and all( dim != 1 for dim in np.array(placement.ranks).shape ): print( f"[{flow.env.get_rank()}] skip TestEinsumConsistent.test_einsum_tensor_contraction with {placement}" ) continue for sbp in all_sbp(placement, max_dim=4): _test_einsum_tensor_contraction(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_tensor_contraction2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4) def _test_einsum_tensor_contraction2(test_case, placement, sbp): dim0 = random(1, 3) * 8 x = random_tensor( ndim=4, dim0=random(1, 3) * 8, dim1=dim0, dim2=random(1, 3) * 8, dim3=random(1, 3) * 8, ) y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) z = torch.einsum("b n h w, n d -> b d h w", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 10 times in past week") @globaltest def test_einsum_tensor_contraction2(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_einsum_tensor_contraction2(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_vector_inner_product.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_vector_inner_product(test_case, placement, sbp): dim0 = random(1, 3) * 8 x = random_tensor(ndim=1, dim0=dim0,) y = random_tensor(ndim=1, dim0=dim0,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) # NOTE(Liang Depeng): the same as 'i,i->' z = torch.einsum("i,i", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_vector_inner_product(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_einsum_vector_inner_product(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_einsum_vector_outer_product.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_einsum_vector_outer_product(test_case, placement, sbp): x = random_tensor(ndim=1, dim0=random(1, 3) * 8,) y = random_tensor(ndim=1, dim0=random(1, 3) * 8,) g_x = x.to_global(placement=placement, sbp=sbp) g_y = y.to_global(placement=placement, sbp=sbp) # NOTE(Liang Depeng): the same as 'i,j->ij' z = torch.einsum("i,j", g_x, g_y) return z class TestEinsumGlobal(flow.unittest.TestCase): @globaltest def test_einsum_vector_outer_product(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_einsum_vector_outer_product(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_empty.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_empty(test_case, func, shape, placement, sbp): func2 = None if func == "empty": func = flow.empty elif func == "new_empty": func = flow.empty func2 = flow.new_empty elif func == "empty_like": func = flow.empty func2 = flow.empty_like else: raise NotImplementedError x = func(*shape, placement=placement, sbp=sbp) if func2: if func2.__name__ == "new_empty_op": x = func2(x, size=shape) elif func2.__name__ == "empty_like_op": x = func2(x) else: raise NotImplementedError test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_empty(test_case, func, shape, placement, sbp): func2 = None if func == "empty": func = flow.empty elif func == "new_empty": func = flow.empty func2 = flow.new_empty elif func == "empty_like": func = flow.empty func2 = flow.empty_like else: raise NotImplementedError class GlobalEmptyGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = func(*shape, placement=placement, sbp=sbp) if func2: if func2.__name__ == "new_empty_op": x = func2(x, size=shape) elif func2.__name__ == "empty_like_op": x = func2(x) else: raise NotImplementedError return x model = GlobalEmptyGraph() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestEmptyGlobal(flow.unittest.TestCase): @globaltest def test_empty_global(test_case): shapes = [(8,), (8, 8,), (8, 8, 8)] functions = [ "empty", "new_empty", "empty_like", ] for func in functions: for shape in shapes: for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(shape), except_partial_sum=True ): _test_global_empty(test_case, func, shape, placement, sbp) 
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_empty_graph(test_case): arg_dict = OrderedDict() arg_dict["func"] = ["empty", "new_empty", "empty_like"] arg_dict["shape"] = [(8,), (8, 8,), (8, 8, 8)] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): func = args["func"] shape = args["shape"] placement = args["placement"] for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True): _test_graph_empty(test_case, func, shape, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_eq.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=True) def do_test_eq_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) x = x.to_global(placement=placement, sbp=sbp) y = random_tensor(ndim, *dims) y = y.to_global(placement=placement, sbp=sbp) z = torch.eq(x, y) return z class TestEqGlobal(flow.unittest.TestCase): @globaltest def test_eq(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_eq_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_erf.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def do_test_erf_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) z = torch.erf(y) return z class TestErfGlobal(flow.unittest.TestCase): @globaltest def test_erf(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_erf_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_erfc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True) def do_test_erfc_impl(test_case, ndim, placement, sbp): dims = [random(1, 3) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) z = torch.erfc(y) return z class TestErfcGlobal(flow.unittest.TestCase): @globaltest def test_erfc(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_erfc_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_expand_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest import torch from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict def _test_global_expand( test_case, input_shape, expand_shape, device="cuda", sbp=flow.sbp.broadcast, verbose=False, ): # random input input = np.random.randn(*input_shape) if isinstance(input, np.ndarray): input = input.astype(np.float32) # torch computation torch_x = torch.tensor(input, requires_grad=True) torch_y = torch_x.expand(*expand_shape) torch_y.sum().backward() # oneflow computation placement = flow.placement(device, np.array(range(flow.env.get_world_size()))) x = flow.tensor(input, requires_grad=True) global_x = x.to_global(placement=placement, sbp=flow.sbp.broadcast) if global_x.sbp != sbp: global_x = global_x.to_global(sbp=sbp, grad_sbp=flow.sbp.broadcast) y = global_x.expand(*expand_shape) y.sum().backward() y_b = y.to_global(sbp=flow.sbp.broadcast) if flow.env.get_rank() == 0: out_a = y_b.to_local().numpy() out_b = torch_y.detach().cpu().numpy() grad_a = x.grad.numpy() grad_b = torch_x.grad.cpu().numpy() if verbose: print("") print(f"{'=' * 10} {input_shape} -> {expand_shape} {'=' * 10}") print(f"{'=' * 10} {device}, {sbp} {'=' * 10}") print(f"{'-' * 20} compare out {'-' * 20}") print(out_a) print("*" * 20) print(out_b) print("") print(f"{'-' * 20} compare grad {'-' * 20}") print(grad_a) print("*" * 20) print(grad_b) test_case.assertTrue(np.array_equal(out_a, out_b)) test_case.assertTrue(np.array_equal(grad_a, grad_b)) @flow.unittest.skip_unless_1n2d() class ExpandGlobalTestCase(oneflow.unittest.TestCase): def test_global_expand(test_case): arg_dict = OrderedDict() arg_dict["verbose"] = [False] arg_dict["device"] = ["cpu", "cuda"] arg_dict["sbp"] = [flow.sbp.split(0), flow.sbp.broadcast()] arg_dict["shapes"] = [ ((2, 2), (2, 2, 2)), ((2, 1, 3), (2, 1, -1, -1, -1)), ((2, 1, 3), (1, 2, -1, -1, -1)), ((2, 1, 3), (2, 1, -1, 2, 3)), ((2, 1, 3), (1, 2, 2, 2, -1)), ] for kwargs in GenArgDict(arg_dict): assert "shapes" in kwargs input_shape, expand_shape = kwargs.pop("shapes") _test_global_expand(test_case, input_shape, expand_shape, **kwargs) def test_split_expand(test_case): arg_dict = OrderedDict() arg_dict["verbose"] = [False] arg_dict["device"] = ["cuda"] arg_dict["sbp"] = [flow.sbp.split(0)] arg_dict["shapes"] = [ ((2,), (1, 2)), ((2,), (2, 2)), ] for kwargs in GenArgDict(arg_dict): assert "shapes" in kwargs input_shape, expand_shape = kwargs.pop("shapes") _test_global_expand(test_case, input_shape, expand_shape, **kwargs) def test_broadcast_scalar_expand(test_case): arg_dict = OrderedDict() arg_dict["verbose"] = [False] arg_dict["device"] = ["cpu", "cuda"] arg_dict["sbp"] = [flow.sbp.broadcast()] arg_dict["shapes"] = [ ((), (1,)), ((), (2,)), ((), (1, 1)), ((), (1, 2)), ((), (2, 1)), ((), (2, 2)), ((), (2, 1, 2)), ] for kwargs in GenArgDict(arg_dict): assert "shapes" in kwargs input_shape, expand_shape = kwargs.pop("shapes") _test_global_expand(test_case, input_shape, expand_shape, **kwargs) if __name__ == "__main__": unittest.main() # ONEFLOW_TEST_DEVICE_NUM=2 python3 -m oneflow.distributed.launch --nproc_per_node 2 test_global_expand_op.py ================================================ FILE: python/oneflow/test/modules/test_global_expm1.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def do_test_expm1_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) z = torch.expm1(y) return z class TestExpm1Global(flow.unittest.TestCase): @globaltest def test_expm1(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_expm1_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_eye.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=True) def do_test_eye_impl(test_case, placement, sbp): n = random(1, 5).to(int).value() * 8 m = random(1, 5).to(int).value() * 8 x = torch.eye(n, m) x.oneflow = flow.tensor( x.pytorch.cpu().detach().numpy(), requires_grad=x.pytorch.requires_grad, placement=placement, sbp=sbp, ) return x class TestEyeGlobal(flow.unittest.TestCase): @globaltest def test_eye(test_case): shape = random_tensor().shape for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): do_test_eye_impl(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_fill.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_fill_(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) value = random().to(float) y = x + 1 y.fill_(value) return y @autotest(n=1, check_graph=True) def _test_fill_tensor_(test_case, ndim, placement, sbp): dims = [random(2, 4) * 8 for i in range(ndim)] x = ( random_tensor(ndim, *dims) .to_global(placement=placement, sbp=sbp) .requires_grad_() ) value = ( torch.tensor(1.0) .to_global(placement=placement, sbp=[flow.sbp.broadcast for _ in sbp]) .requires_grad_() ) y = x + 1 y.oneflow = y.oneflow.to_global(placement, sbp) y.fill_(value) return y class TestFillModule(flow.unittest.TestCase): @globaltest def test_fill_(test_case): ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_fill_(test_case, ndim, placement, sbp) _test_fill_tensor_(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_flatten.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def do_test_flatten_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) start_dim = random(0, ndim).to(int).value() end_dim = random(start_dim, ndim).to(int).value() z = torch.flatten(x, start_dim, end_dim) return z class TestFlattenGlobal(flow.unittest.TestCase): @globaltest def test_flatten(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_flatten_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_flip.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_flip_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) new_dim = random(0, ndim).to(int).value() z = torch.flip(y, constant([i for i in range(new_dim)])) return z class TestFlipGlobal(flow.unittest.TestCase): @globaltest def test_flip(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_flip_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_floor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def do_test_floor_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) z = torch.floor(y) return z class TestFloorGlobal(flow.unittest.TestCase): @globaltest def test_floor(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_floor_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_fmod.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import torch as torch_original from packaging import version # other.grad in torch.fmod(input, other) was not implemented before pytorch 1.11.0 grad_implemented = version.parse(torch_original.__version__) >= version.parse("1.11.0") @autotest(n=1, auto_backward=grad_implemented, check_graph=True) def do_test_fmod_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims) x = x.to_global(placement=placement, sbp=sbp) y = random_tensor(ndim, *dims) y = y.to_global(placement=placement, sbp=sbp) z = torch.fmod(x, y) return z class TestFmodGlobal(flow.unittest.TestCase): @globaltest def test_fmod(test_case): # random ndim in range [1,5] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_fmod_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_fold.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_fold_impl(test_case, placement, sbp): ndim = 3 dims = [random(1, 4).to(int).value() * 8 for i in range(ndim)] m = torch.nn.Fold( output_size=constant(((dims[2] // 4) * 2, 4 * 2)), kernel_size=constant(2), dilation=constant(1), padding=constant(0), stride=constant(2), ) m.train(random()) x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = m(x) func_y = torch.nn.functional.fold( x, output_size=constant(((dims[2] // 4) * 2, 4 * 2)), kernel_size=constant(2), dilation=constant(1), padding=constant(0), stride=constant(2), ) return y, func_y class TestFold(flow.unittest.TestCase): @globaltest def test_fold(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_fold_impl(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_frac.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False) def _test_frac(test_case, ndim, placement, sbp): shape = [random(2, 4) * 8 for i in range(ndim)] input = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp) output = torch.frac(input) return output class TestModule(flow.unittest.TestCase): @globaltest def test_frac(test_case): ndim = random(2, 4).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_frac(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_full.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_full(test_case, shape, placement, sbp): x = flow.full(shape, 1.0, placement=placement, sbp=sbp) test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_global_full_tensor_scalar(test_case, shape, placement, sbp): scalar_sbp = [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] x1 = flow.tensor(1.0, placement=placement, sbp=scalar_sbp) x2 = flow.full(shape, x1, placement=placement, sbp=sbp) test_case.assertEqual(x2.shape, flow.Size(shape)) test_case.assertEqual(x2.sbp, sbp) test_case.assertEqual(x2.placement, placement) def _test_graph_full(test_case, shape, placement, sbp): class GlobalFullGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.full(shape, 1.0, placement=placement, sbp=sbp) return x model = GlobalFullGraph() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_full_tensor_scalar(test_case, shape, placement, sbp): class GlobalFullGraph2(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.full( shape, flow.tensor( 1.0, placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ), placement=placement, sbp=sbp, ) return x model = GlobalFullGraph2() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestFullGlobal(flow.unittest.TestCase): @globaltest def test_full_global(test_case): shapes = [(8,), (8, 8,), (8, 8, 8)] for shape in shapes: for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(shape), except_partial_sum=True ): _test_global_full(test_case, shape, placement, sbp) _test_global_full_tensor_scalar(test_case, shape, placement, sbp) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only 
test cpu cases") @flow.unittest.skip_unless_1n2d() def test_full_graph(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [[8], [8, 8], [8, 8, 8]] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): shape = args["shape"] placement = args["placement"] for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True): _test_graph_full(test_case, shape, placement, sbp) _test_graph_full_tensor_scalar(test_case, shape, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_full_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_full_like(test_case, shape, placement, sbp): x_ = flow.randn(shape) x = flow.full_like(x_, 1.0, placement=placement, sbp=sbp) test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_full_like(test_case, shape, placement, sbp): class GlobalFullLikeGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x_ = flow.randn(shape) x = flow.full_like(x_, 1.0, placement=placement, sbp=sbp) return x model = GlobalFullLikeGraph() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestFillLikeGlobal(flow.unittest.TestCase): @globaltest def test_full_like_global(test_case): shapes = [(8,), (8, 8,), (8, 8, 8)] for shape in shapes: for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(shape), except_partial_sum=True ): _test_global_full_like(test_case, shape, placement, sbp) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_full_like_graph(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [[8], [8, 8], [8, 8, 8]] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): shape = args["shape"] placement = args["placement"] for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True): _test_graph_full_like(test_case, shape, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_greater.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=10, auto_backward=False, check_graph=True) def _test_greater_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x1 = random_tensor(ndim, *dims) x2 = x1.to_global(placement=placement, sbp=sbp) y1 = random_tensor(ndim, *dims) y2 = y1.to_global(placement=placement, sbp=sbp) z = torch.gt(x2, y2) return z @unittest.skip("TODO: houjiang, yushun. this test might fail") class TestGreaterGlobal(flow.unittest.TestCase): @globaltest def test_greater(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_greater_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_greater_equal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=True) def do_test_greater_equal_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x1 = random_tensor(ndim, *dims) x1 = x1.to_global(placement=placement, sbp=sbp) x2 = random_tensor(ndim, *dims) x2 = x2.to_global(placement=placement, sbp=sbp) z = torch.ge(x1, x2) return z class TestGreaterEqualGlobal(flow.unittest.TestCase): @globaltest def test_greater_equal(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_greater_equal_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_grid_sample.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, rtol=1e-03, atol=1e-04, check_graph=True) def _test_flow_grid_sample_cudnn(test_case, placement, sbp): # cudnn only supports 4D input, with mode = 'bilinear' && padding_mode = 'zeros' && align_corners N = random(1, 3).to(int) * 8 C = random(1, 3).to(int) * 8 in_H = random(1, 8).to(int) in_W = random(1, 8).to(int) out_H = random(1, 8).to(int) out_W = random(1, 8).to(int) mode = "bilinear" padding_mode = "zeros" align_corners = True theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to_global( placement=placement, sbp=random_sbp(placement, max_dim=1) ) grid = torch.nn.functional.affine_grid( theta, (N, C, out_H, out_W), align_corners=align_corners ) input = random_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to_global( placement=placement, sbp=sbp ) output = torch.nn.functional.grid_sample( input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners, ) return output # This test may fail due to using ::floor in backward # floor(1.99999988) = 1 and floor(2.000000) = 2, so a different image pixel may be selected @autotest( n=1, auto_backward=False, rtol=1e-03, atol=1e-04, check_graph=True, check_allclose=False, ) def _test_flow_grid_sample_4d(test_case, placement, sbp): N = random(1, 3).to(int) * 8 C = random(1, 3).to(int) * 8 in_H = random(1, 8).to(int) in_W = random(1, 8).to(int) out_H = random(1, 8).to(int) out_W = random(1, 8).to(int) mode = oneof("bilinear", "nearest", "bicubic") padding_mode = oneof("zeros", "border", "reflection") align_corners = oneof(True, False) theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to_global( placement=placement, sbp=random_sbp(placement, max_dim=1) ) grid = torch.nn.functional.affine_grid( theta, (N, C, out_H, out_W), align_corners=align_corners ) input = random_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to_global( placement=placement, sbp=sbp ) output = torch.nn.functional.grid_sample( input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners, ) return output @autotest(n=1, auto_backward=False, rtol=1e-03, atol=1e-03, check_graph=True) def _test_flow_grid_sample_5d(test_case, placement, sbp): N = random(1, 3).to(int) * 8 C = random(1, 3).to(int) * 8 in_D = random(1, 8).to(int) in_H = random(1, 8).to(int) in_W = random(1, 8).to(int) out_D = random(1, 8).to(int) out_H = random(1, 8).to(int) out_W = random(1, 8).to(int) mode = oneof("bilinear", "nearest") padding_mode = oneof("zeros", "border", "reflection") align_corners = oneof(True, False) theta = random_tensor(ndim=3, dim0=N, dim1=3, dim2=4).to_global( placement=placement, sbp=random_sbp(placement, max_dim=1) ) grid = torch.nn.functional.affine_grid( theta, (N, C, out_D, out_H, out_W), align_corners=align_corners ) input = random_tensor( ndim=5, dim0=N, dim1=C, dim2=in_D, dim3=in_H, dim4=in_W ).to_global(placement=placement, sbp=sbp) output = torch.nn.functional.grid_sample( input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners, ) return output class TestGridSample(flow.unittest.TestCase):
@unittest.skip("skip for now, becase it may fail in CI") @globaltest def test_grid_sample(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): if placement.type == "cuda": _test_flow_grid_sample_cudnn(test_case, placement, sbp) _test_flow_grid_sample_4d(test_case, placement, sbp) _test_flow_grid_sample_5d(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_groupnorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False, atol=1e-3, rtol=1e-3) def _test_global_group_norm(test_case, placement, input_sbp, affine): if placement.type == "cpu": return batch_size = 4 channel_size = 8 num_groups = 2 m = torch.nn.GroupNorm( num_groups=num_groups, num_channels=channel_size, affine=affine ) m.train(random()) m.to_global( placement=placement, sbp=[flow.sbp.broadcast] * len(placement.ranks.shape) ) x = random_tensor( ndim=4, dim0=batch_size, dim1=channel_size, dim2=random(4, 16), dim3=random(4, 16), ).to_global(placement=placement, sbp=input_sbp) y = m(x) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGroupNormModule(flow.unittest.TestCase): @globaltest def test_global_group_norm_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_global_group_norm(test_case, placement, sbp, True) _test_global_group_norm(test_case, placement, sbp, False) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_gru_cell.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * # NOTE(lixiang): Do not check the graph for the time being, because ci will report "The action has timed out". 
@autotest(n=1, check_graph="ValidatedFalse") def _test_gru_cell(test_case, placement, sbp): batch_size = random(2, 3) * 8 time_steps = random(2, 3) * 8 input_size = random(2, 3) * 8 hidden_size = random(2, 3) * 8 has_bias = random().to(bool) m = torch.nn.GRUCell(input_size=input_size, hidden_size=hidden_size, bias=has_bias,) weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True) m.weight_ih = torch.nn.Parameter( m.weight_ih.to_global(placement=placement, sbp=weight_sbp) ) m.weight_hh = torch.nn.Parameter( m.weight_hh.to_global(placement=placement, sbp=weight_sbp) ) if m.bias_ih is not None: # bias is 1-d tensor bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True) m.bias_ih = torch.nn.Parameter( m.bias_ih.to_global(placement=placement, sbp=bias_sbp) ) m.bias_hh = torch.nn.Parameter( m.bias_hh.to_global(placement=placement, sbp=bias_sbp) ) input_sbp = random_sbp(placement, max_dim=3, valid_split_axis=1) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to_global(placement=placement, sbp=input_sbp) hx = random_tensor( ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False ).to_global(placement=placement, sbp=sbp) for i in range(time_steps.to(int).value()): hx = m(input[i], hx) return hx class TestRNNCellGlobal(flow.unittest.TestCase): @globaltest def test_gru_cell(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_gru_cell(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_hann_window.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from collections import OrderedDict import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_hann_window(test_case, placement, sbp): x = flow.hann_window(8, placement=placement, sbp=sbp) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_hann_window(test_case, placement, sbp): class GlobalHannWindowGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.hann_window(8, placement=placement, sbp=sbp) return x model = GlobalHannWindowGraph() x = model() test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestHannWindowGlobal(flow.unittest.TestCase): # TODO(wyg): All-broadcast sbp will be inferred when 1n1d, # and slice_update will raise an error when doing the inplace operation. # Remove this check after refactoring the sbp infer method in the Operator class.
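# Because of the issue above, the global test below skips placements with a single rank.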
@globaltest def test_hann_window_global(test_case): for placement in all_placement(): if placement.ranks.size == 1: continue for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_global_hann_window(test_case, placement, sbp) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_hann_window_graph(test_case): arg_dict = OrderedDict() arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): placement = args["placement"] for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_graph_hann_window(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_activation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import torch as pytorch_origin import oneflow as oneflow_origin from collections import defaultdict def _assert_true(test_case, value1, value2): test_case.assertTrue( np.allclose( value1.detach().cpu().numpy(), value2.detach().numpy(), rtol=1e-05, atol=1e-05, ) ) def _test_activation_grad_grad_impl(test_case, op_name, placement, *args, **kwargs): x = random_tensor(ndim=2, low=-5, dim0=8, dim1=8).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) y = eval(f"torch.nn.functional.{op_name}")(x, *args, **kwargs) x_shape = x.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) init_grad_y = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _assert_true(test_case, dx.pytorch, dx.oneflow) ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _assert_true(test_case, ddx.pytorch, ddx.oneflow) _assert_true(test_case, ddy.pytorch, ddy.oneflow) def _test_prelu_activation_grad_grad_impl( test_case, op_name, placement, *args, **kwargs ): x = random_tensor(ndim=2, low=-5, dim0=8, dim1=8).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) a = random_tensor(ndim=1, dim0=x.oneflow.shape[1]).to_global( placement=placement, sbp=random_sbp(placement, max_dim=1) ) y = torch.nn.functional.prelu(x, a) x_shape = x.oneflow.shape a_shape = a.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) init_grad_y = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, 
sbp=random_sbp(placement=placement, max_dim=2) ) init_grad_a = random_tensor(len(a_shape), *a_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=1) ) dx_and_da = torch.autograd.grad(y, [x, a], init_grad_y, True, True) dx, da = dx_and_da[0], dx_and_da[1] _assert_true(test_case, dx.pytorch, dx.oneflow) _assert_true(test_case, da.pytorch, da.oneflow) ddx_dda_ddy = torch.autograd.grad( dx_and_da, [dx, da, init_grad_y], [init_grad_x, init_grad_a], True, True ) ddx, dda, ddy = ddx_dda_ddy[0], ddx_dda_ddy[1], ddx_dda_ddy[2] _assert_true(test_case, ddx.pytorch, ddx.oneflow) _assert_true(test_case, dda.pytorch, dda.oneflow) _assert_true(test_case, ddy.pytorch, ddy.oneflow) def _test_hardswish_activation_grad_grad_impl( test_case, op_name, placement, *args, **kwargs ): x = random_tensor(ndim=2, low=-1, dim0=8, dim1=8).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) y = torch.nn.functional.hardswish(x, *args, **kwargs) x_shape = x.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) init_grad_y = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) dx_pytorch = pytorch_origin.autograd.grad( y.pytorch, x.pytorch, init_grad_y.pytorch )[0] dx_oneflow = oneflow_origin.autograd.grad( y.oneflow, x.oneflow, init_grad_y.oneflow, True, True )[0] _assert_true(test_case, dx_pytorch, dx_oneflow) ddx, ddy = flow.autograd.grad( dx_oneflow, [x.oneflow, init_grad_y.oneflow], init_grad_x.oneflow ) x, dx, init_grad_x, init_grad_y = ( x.oneflow, dx_oneflow, init_grad_x.oneflow, init_grad_y.oneflow, ) zeros_grad = flow.zeros_like(x).to_global(placement=placement, sbp=x.sbp) # hardswish has a constant second derivative of 1/3 only inside (-3, 3), and 0 elsewhere manual_ddx = flow.where( flow.logical_and(x > -3.0, x < 3.0), 1.0 / 3.0 * init_grad_x * init_grad_y, zeros_grad ) manual_ddy = dx / init_grad_y * init_grad_x _assert_true(test_case, manual_ddx, ddx) _assert_true(test_case, manual_ddy, ddy) def _test_hardsigmoid_activation_grad_grad_impl( test_case, op_name, placement, *args, **kwargs ): x = random_tensor(ndim=2, low=-1, dim0=8, dim1=8).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) y = torch.nn.functional.hardsigmoid(x, *args, **kwargs) x_shape = x.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) init_grad_y = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement=placement, max_dim=2) ) dx_pytorch = pytorch_origin.autograd.grad( y.pytorch, x.pytorch, init_grad_y.pytorch )[0] dx_oneflow = oneflow_origin.autograd.grad( y.oneflow, x.oneflow, init_grad_y.oneflow, True, True )[0] _assert_true(test_case, dx_pytorch, dx_oneflow) ddx, ddy = flow.autograd.grad( dx_oneflow, [x.oneflow, init_grad_y.oneflow], init_grad_x.oneflow ) x, dx, init_grad_x, init_grad_y = ( x.oneflow, dx_oneflow, init_grad_x.oneflow, init_grad_y.oneflow, ) manual_ddx = flow.zeros_like(x) manual_ddy = dx / init_grad_y * init_grad_x _assert_true(test_case, manual_ddx, ddx) _assert_true(test_case, manual_ddy, ddy) class TestActivationHigherDerivative(flow.unittest.TestCase): @globaltest def test_activation_grad_grad(test_case): op_args = defaultdict(list) op_kwargs = defaultdict(dict) # parameter names are not the same in pytorch and oneflow op_args["leaky_relu"] = [random(-1, 1).to(float)] # some ops only support kwargs, like celu in oneflow op_kwargs["hardtanh"] = { "min_val":
random(-5, -1).to(float), "max_val": random(1, 5).to(float), } op_kwargs["elu"] = {"alpha": random(0, 10).to(float)} op_kwargs["celu"] = {"alpha": random(0, 10).to(float)} op_kwargs["threshold"] = { "threshold": random().to(float), "value": random().to(float), } op_kwargs["softplus"] = { "beta": random().to(float), "threshold": random().to(float), } op_names = [ "mish", "gelu", "silu", "selu", "softsign", "hardsigmoid", "hardswish", "relu", "elu", "celu", "prelu", "hardshrink", "softshrink", "leaky_relu", "hardtanh", "softplus", "threshold", ] for op_name in op_names: try: functor = eval(f"_test_{op_name}_activation_grad_grad_impl") except: functor = _test_activation_grad_grad_impl print(f"| {op_name:-^60} |") for placement in all_placement(): for i in range(1): functor( test_case, op_name, placement, *op_args[op_name], **op_kwargs[op_name], ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_conv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import torch as pytorch_origin import oneflow as oneflow_origin def _test_convnd_grad_grad_impl(test_case, ndim, placement): x_shape = [8, 8] + [5 for _ in range(ndim)] w_shape = [8, 8] + [3 for _ in range(ndim)] y_shape = [8, 8] + [3 for _ in range(ndim)] x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) w = random_tensor(len(w_shape), *w_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_w = random_tensor(len(w_shape), *w_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_y = random_tensor(len(y_shape), *y_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) y = eval(f"torch.nn.functional.conv{ndim}d")( x, w, stride=1, padding=0, groups=1, dilation=1 ) dx = torch.autograd.grad( outputs=y, inputs=x, grad_outputs=init_grad_y, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dx.pytorch.detach().cpu().numpy(), dx.oneflow.detach().numpy(), rtol=1e-5, atol=1e-2, ) ) dw = torch.autograd.grad( outputs=y, inputs=w, grad_outputs=init_grad_y, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dw.pytorch.detach().cpu().numpy(), dw.oneflow.detach().numpy(), rtol=1e-5, atol=1e-5, ) ) # torch.autograd.grad in autotest does not support inputs/outputs/grad_outputs as a list # so use the original pytorch/oneflow modules ddx_pytorch, ddw_pytorch = pytorch_origin.autograd.grad( outputs=[dx.pytorch, dw.pytorch], inputs=[x.pytorch, w.pytorch], grad_outputs=[init_grad_x.pytorch, init_grad_w.pytorch], create_graph=True,
retain_graph=True, ) ddx_oneflow, ddw_oneflow = oneflow_origin.autograd.grad( outputs=[dx.oneflow, dw.oneflow], inputs=[x.oneflow, w.oneflow], grad_outputs=[init_grad_x.oneflow, init_grad_w.oneflow], create_graph=True, retain_graph=True, ) test_case.assertTrue( np.allclose( ddw_pytorch.detach().cpu().numpy(), ddw_oneflow.detach().numpy(), rtol=1e-5, atol=1e-5, ) ) test_case.assertTrue( np.allclose( ddx_pytorch.detach().cpu().numpy(), ddx_oneflow.detach().numpy(), rtol=1e-5, atol=1e-2, ) ) dgrad_dx = torch.autograd.grad( outputs=dx, inputs=init_grad_y, grad_outputs=init_grad_x, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dgrad_dx.pytorch.detach().cpu().numpy(), dgrad_dx.oneflow.detach().numpy(), rtol=1e-4, atol=1e-2, ) ) dgrad_dw = torch.autograd.grad( outputs=dw, inputs=init_grad_y, grad_outputs=init_grad_w, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dgrad_dw.pytorch.detach().cpu().numpy(), dgrad_dw.oneflow.detach().numpy(), rtol=1e-4, atol=1e-2, ) ) class TestGlobalConvHigherDerivative(flow.unittest.TestCase): @globaltest def test_conv1d_grad_grad(test_case): for placement in all_placement(): for i in range(5): _test_convnd_grad_grad_impl(test_case, ndim=1, placement=placement) @globaltest def test_conv2d_grad_grad(test_case): for placement in all_placement(): for i in range(5): _test_convnd_grad_grad_impl(test_case, ndim=2, placement=placement) @globaltest def test_conv3d_grad_grad(test_case): for placement in all_placement(): for i in range(5): _test_convnd_grad_grad_impl(test_case, ndim=3, placement=placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_div.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_global_div_grad_grad_impl(test_case, placement): x_shape = [8, 8, 8, 8] y_shape = [8, 8] if random_bool().value(): x_shape, y_shape = y_shape, x_shape x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) y = random_tensor(len(y_shape), *y_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) z = torch.div(x, y) init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) dx_and_dy = torch.autograd.grad(z, [x, y], init_grad_z, True, True) test_case.assertTrue( np.allclose( dx_and_dy.pytorch[0].detach().cpu().numpy(), dx_and_dy.oneflow[0].detach().numpy(), rtol=1e-4, atol=1e-4, ) ) test_case.assertTrue( np.allclose( dx_and_dy.pytorch[1].detach().cpu().numpy(), dx_and_dy.oneflow[1].detach().numpy(), rtol=1e-3, atol=1e-4, ) ) ddx_and_ddy_and_ddz = torch.autograd.grad( dx_and_dy, [x, y, init_grad_z], [init_grad_x, init_grad_y], True, True ) test_case.assertTrue( np.allclose( ddx_and_ddy_and_ddz.pytorch[0].detach().cpu().numpy(), ddx_and_ddy_and_ddz.oneflow[0].detach().numpy(), rtol=1e-3, atol=1e-3, ) ) test_case.assertTrue( np.allclose( ddx_and_ddy_and_ddz.pytorch[1].detach().cpu().numpy(), ddx_and_ddy_and_ddz.oneflow[1].detach().numpy(), rtol=1e-2, atol=1e-3, ) ) test_case.assertTrue( np.allclose( ddx_and_ddy_and_ddz.pytorch[2].detach().cpu().numpy(), ddx_and_ddy_and_ddz.oneflow[2].detach().numpy(), rtol=1e-3, atol=1e-3, ) ) class TestGlobalDivHigherDerivative(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 22 times in past week") @globaltest def test_global_div_grad_grad(test_case): for placement in all_placement(): for i in range(1): _test_global_div_grad_grad_impl(test_case, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _assert_true(test_case, value1, value2, name=""): is_equal = np.allclose( value1.detach().cpu().numpy(), value2.detach().numpy(), rtol=1e-03, atol=1e-03, ) test_case.assertTrue(is_equal, f"{name} is not equal." 
if name else "") def generate_grads_for_variables(variables): if isinstance(variables, list): shape_and_sbp = [(i.oneflow.shape, i.oneflow.sbp) for i in variables] placement = variables[0].oneflow.placement elif hasattr(variables, "pytorch"): shape_and_sbp = [(i.shape, i.sbp) for i in variables.oneflow] placement = variables.oneflow[0].placement else: assert False grads = [ random_tensor( len(shape), *shape, requires_grad=random_bool().value() ).to_global(placement=placement, sbp=sbp) for shape, sbp in shape_and_sbp ] return grads def calculate_and_compare_loss(test_case, input, target, model, order=2): output = model(input, target) _assert_true(test_case, output.pytorch, output.oneflow, "output") init_inputs = [input, target] grad_inputs = [output] grad_outputs = [] for i in range(order): inputs = [ var for var in [*init_inputs, *grad_outputs] if var.pytorch.requires_grad ] outputs = grad_inputs grad_outputs = generate_grads_for_variables(outputs) if i == order - 1: grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs) else: grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs, True, True) for j in range(len(inputs)): _assert_true( test_case, grad_inputs[j].pytorch, grad_inputs[j].oneflow, f"{i}-grad_inputs[{j}]", ) def generate_necessity_for_default_loss(placement): shape = [8, 8] ndim = len(shape) input_requires_grad = True target_requires_grad = random_bool().value() return ( random_tensor(ndim, *shape, low=0, requires_grad=input_requires_grad).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ), random_tensor( ndim, *shape, low=0, requires_grad=target_requires_grad ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=2)), ) def generate_necessity_for_nll_loss(placement): ndim = 2 num_classes = 8 batch_size = 8 ignore_index = oneof(random(0, num_classes).to(int).value(), -100).value() extra_dim = [random().to(int) for _ in range(ndim - 2)] return ( random_tensor(ndim, batch_size, num_classes).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ), random_tensor( ndim - 1, batch_size, low=0, high=num_classes, dtype=int, requires_grad=False, ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=1)), random_tensor(1, num_classes, low=0, high=3, requires_grad=False).to_global( placement=placement, sbp=random_sbp(placement, except_split=True) ), ignore_index, ) def generate_necessity_for_bce_loss(placement): ndim = 3 num_classes = 2 batch_size = 8 extra_dim = [random().to(int) for _ in range(ndim - 2)] input_requires_grad = True target_requires_grad = False return ( random_tensor( ndim, batch_size, num_classes, low=0, high=1, *extra_dim, requires_grad=input_requires_grad, ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=1)), random_tensor( ndim, batch_size, num_classes, *extra_dim, low=0, high=num_classes, requires_grad=target_requires_grad, ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=1)), random_tensor( ndim, batch_size, num_classes, *extra_dim, low=0, high=3, requires_grad=False, ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=1)), random_tensor(1, 1, low=1, high=3, requires_grad=False,).to_global( placement=placement, sbp=random_sbp(placement, except_split=True) ), ) def _test_smooth_l1_loss_grad_grad_impl(test_case, placement): x, y = generate_necessity_for_default_loss(placement) m = torch.nn.SmoothL1Loss( reduction=oneof("none", "sum", "mean", nothing()), beta=oneof(0.0, 0.5, 1) ) calculate_and_compare_loss(test_case, x, y, m) def 
_test_kl_div_loss_grad_grad_impl(test_case, placement): x, y = generate_necessity_for_default_loss(placement) m = torch.nn.KLDivLoss( reduction=oneof("none", "sum", "mean", nothing()), log_target=oneof(True, False), ) calculate_and_compare_loss(test_case, x, y, m) def _test_bce_loss_grad_grad_impl(test_case, placement, with_logits=False): x, y, weight, pos_weight = generate_necessity_for_bce_loss(placement) if with_logits: weight = weight if random_bool().value() else None has_pos_weight = random_bool().value() pos_weight = pos_weight if has_pos_weight else nothing() m = torch.nn.BCEWithLogitsLoss( weight=weight, pos_weight=pos_weight, reduction=oneof("none", "sum", "mean"), ) if has_pos_weight: y = y.detach().clone().requires_grad_(False) else: m = torch.nn.BCELoss( weight=(weight if random_bool().value() else None), reduction=oneof("none", "sum", "mean"), ) calculate_and_compare_loss(test_case, x, y, m) def _test_nll_loss_grad_grad_impl(test_case, placement): (x, y, weight, ignore_index) = generate_necessity_for_nll_loss(placement) m = torch.nn.NLLLoss( weight=(weight if random_bool().value() else None), reduction=oneof("none", "sum", "mean", nothing()), ignore_index=ignore_index, ) calculate_and_compare_loss(test_case, x, y, m) class TestGlobalLossHigherDerivative(flow.unittest.TestCase): @globaltest def test_smooth_l1_loss_grad_grad(test_case): for placement in all_placement(): _test_smooth_l1_loss_grad_grad_impl(test_case, placement) @globaltest def test_kl_div_loss_grad_grad(test_case): for placement in all_placement(): _test_kl_div_loss_grad_grad_impl(test_case, placement) @globaltest def test_nll_loss_grad_grad(test_case): for placement in all_placement(): _test_nll_loss_grad_grad_impl(test_case, placement) @globaltest def test_bce_loss_grad_grad(test_case): for placement in all_placement(): _test_bce_loss_grad_grad_impl(test_case, placement) @globaltest def test_bce_with_logits_loss_grad_grad(test_case): for placement in all_placement(): _test_bce_loss_grad_grad_impl(test_case, placement, with_logits=True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_matmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import torch as pytorch_origin import oneflow as oneflow_origin def _test_broadcast_matmul_grad_b_grad_impl(test_case, placement): broadcast_dims = [np.random.randint(1, 5) * 8 for _ in range(2)] m = np.random.randint(1, 5) * 8 n = np.random.randint(1, 5) * 8 k = np.random.randint(1, 5) * 8 a_shape = broadcast_dims + [m, k] b_shape = [k, n] y_shape = broadcast_dims + [m, n] a = random_tensor(len(a_shape), *a_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) b = random_tensor(len(b_shape), *b_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_a = random_tensor(len(a_shape), *a_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_b = random_tensor(len(b_shape), *b_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_y = random_tensor(len(y_shape), *y_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) y = torch.matmul(a, b) da = torch.autograd.grad( outputs=y, inputs=a, grad_outputs=init_grad_y, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( da.pytorch.detach().cpu().numpy(), da.oneflow.detach().numpy(), rtol=1e-5, atol=1e-2, ) ) db = torch.autograd.grad( outputs=y, inputs=b, grad_outputs=init_grad_y, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( db.pytorch.detach().cpu().numpy(), db.oneflow.detach().numpy(), rtol=1e-3, atol=1e-4, ) ) # torch.autograd.grad in autotest does not support inputs/outpus/grad_outputs as a list # so use the original pytorch/oneflow module dda_pytorch, ddb_pytorch = pytorch_origin.autograd.grad( outputs=[da.pytorch, db.pytorch], inputs=[a.pytorch, b.pytorch], grad_outputs=[init_grad_a.pytorch, init_grad_b.pytorch], create_graph=True, retain_graph=True, ) dda_oneflow, ddb_oneflow = oneflow_origin.autograd.grad( outputs=[da.oneflow, db.oneflow], inputs=[a.oneflow, b.oneflow], grad_outputs=[init_grad_a.oneflow, init_grad_b.oneflow], create_graph=True, retain_graph=True, ) test_case.assertTrue( np.allclose( ddb_pytorch.detach().cpu().numpy(), ddb_oneflow.detach().numpy(), rtol=1e-3, atol=1e-4, ) ) test_case.assertTrue( np.allclose( dda_pytorch.detach().cpu().numpy(), dda_oneflow.detach().numpy(), rtol=1e-5, atol=1e-2, ) ) dgrad_da = torch.autograd.grad( outputs=da, inputs=init_grad_y, grad_outputs=init_grad_a, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dgrad_da.pytorch.detach().cpu().numpy(), dgrad_da.oneflow.detach().numpy(), rtol=1e-5, atol=1e-2, ) ) dgrad_db = torch.autograd.grad( outputs=db, inputs=init_grad_y, grad_outputs=init_grad_b, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dgrad_db.pytorch.detach().cpu().numpy(), dgrad_db.oneflow.detach().numpy(), rtol=1e-5, atol=1e-2, ) ) class TestGlobalMatmulHigherDerivative(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 32 times in past week") @globaltest def test_broadcast_matmul_grad_b_grad(test_case): for placement in all_placement(): for i in range(5): _test_broadcast_matmul_grad_b_grad_impl(test_case, placement=placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_neg.py ================================================ """ Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _global_neg_grad_grad_impl(test_case, placement, sbp): x = flow.randn(8, 8).to_global(placement=placement, sbp=sbp).requires_grad_(True) init_grad = ( flow.randn(8, 8).to_global(placement=placement, sbp=sbp).requires_grad_(True) ) init_grad_grad = ( flow.randn(8, 8).to_global(placement=placement, sbp=sbp).requires_grad_(True) ) y = x.neg() x_grad = flow.autograd.grad(y, x, init_grad, create_graph=True)[0] test_case.assertTrue(np.allclose(-init_grad, x_grad.detach().numpy())) dgrad = flow.autograd.grad(x_grad, init_grad, init_grad_grad, create_graph=True)[0] test_case.assertTrue(np.allclose(-init_grad_grad, dgrad.detach().numpy(),)) class TestGlobalNegHigherDerivative(flow.unittest.TestCase): @globaltest def test_global_neg_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_neg_grad_grad_impl(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_pool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _check_equal(test_case, lhs, rhs, name="", rtol=1e-5, atol=1e-5): is_equal = np.allclose( lhs.detach().cpu().numpy(), rhs.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, ) test_case.assertTrue(is_equal, f"{name} is not equal" if name else "") def _test_avg_pool_grad_grad_impl(test_case, placement, ndim): x_shape = [8, 8] + [5] * ndim m = eval(f"torch.nn.AvgPool{ndim}d")(kernel_size=random(2, 5).to(int)) x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) y = m(x) _check_equal(test_case, y.pytorch, y.oneflow, "y") init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _check_equal(test_case, dx.pytorch, dx.oneflow, "dx") ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _check_equal(test_case, ddx.pytorch, ddx.oneflow, "ddx") _check_equal(test_case, ddy.pytorch, ddy.oneflow, "ddy") def _test_max_pool_grad_grad_impl(test_case, placement, ndim): x_shape = [8, 8] + [5] * ndim m = eval(f"torch.nn.MaxPool{ndim}d")(kernel_size=random(2, 5).to(int)) x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) y = m(x) _check_equal(test_case, y.pytorch, y.oneflow, "y") init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _check_equal(test_case, dx.pytorch, dx.oneflow, "dx") ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _check_equal(test_case, ddx.pytorch, ddx.oneflow, "ddx") _check_equal(test_case, ddy.pytorch, ddy.oneflow, "ddy") def _test_adaptive_pool_grad_grad_impl(test_case, placement, ndim, mode): x_shape = [8, 8] + [5] * ndim m = eval(f"torch.nn.Adaptive{mode.title()}Pool{ndim}d")( output_size=random(2, 5).to(int) ) x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) y = m(x) _check_equal(test_case, y.pytorch, y.oneflow, "y") init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _check_equal(test_case, dx.pytorch, dx.oneflow, "dx") ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _check_equal(test_case, ddx.pytorch, ddx.oneflow, "ddx") _check_equal(test_case, ddy.pytorch, ddy.oneflow, "ddy") @flow.unittest.skip_unless_1n1d() class TestGlobalPoolHigherDerivative(flow.unittest.TestCase): @globaltest def test_max_pool_1d_grad_grad(test_case): for placement in all_placement(): _test_max_pool_grad_grad_impl(test_case, placement, 1) @globaltest def 
test_max_pool_2d_grad_grad(test_case): for placement in all_placement(): _test_max_pool_grad_grad_impl(test_case, placement, 2) @globaltest def test_max_pool_3d_grad_grad(test_case): for placement in all_placement(): _test_max_pool_grad_grad_impl(test_case, placement, 3) @globaltest def test_avg_pool_1d_grad_grad(test_case): for placement in all_placement(): _test_avg_pool_grad_grad_impl(test_case, placement, ndim=1) @globaltest def test_avg_pool_2d_grad_grad(test_case): for placement in all_placement(): _test_avg_pool_grad_grad_impl(test_case, placement, ndim=2) @globaltest def test_avg_pool_3d_grad_grad(test_case): for placement in all_placement(): _test_avg_pool_grad_grad_impl(test_case, placement, ndim=3) @globaltest def test_adaptive_avg_pool_1d_grad_grad(test_case): for placement in all_placement(): _test_adaptive_pool_grad_grad_impl(test_case, placement, ndim=1, mode="avg") @globaltest def test_adaptive_avg_pool_2d_grad_grad(test_case): for placement in all_placement(): _test_adaptive_pool_grad_grad_impl(test_case, placement, ndim=2, mode="avg") @globaltest def test_adaptive_avg_pool_3d_grad_grad(test_case): for placement in all_placement(): _test_adaptive_pool_grad_grad_impl(test_case, placement, ndim=3, mode="avg") if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_pow.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _check_equal(test_case, lhs, rhs, rtol=1e-3, atol=1e-3): is_equal = np.allclose( lhs.detach().cpu().numpy(), rhs.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, ) test_case.assertTrue(is_equal) def _test_global_pow_grad_grad_impl(test_case, placement): x_shape, y_shape = [([8, 8], [8, 8]), ([8, 8, 8], [8, 8]), ([8, 8], [8, 8, 8]),][ random(1, 3).to(int).value() ] x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) y = random_tensor(len(y_shape), *y_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) z = torch.pow(x, y) _check_equal(test_case, z.pytorch, z.oneflow) init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) dx_and_dy = torch.autograd.grad(z, [x, y], init_grad_z, True, True) _check_equal(test_case, dx_and_dy.pytorch[0], dx_and_dy.oneflow[0]) _check_equal(test_case, dx_and_dy.pytorch[1], dx_and_dy.oneflow[1]) ddx_ddy_ddz = torch.autograd.grad( dx_and_dy, [x, y, init_grad_z], [init_grad_x, init_grad_y] ) _check_equal(test_case, ddx_ddy_ddz.pytorch[0], ddx_ddy_ddz.oneflow[0]) _check_equal(test_case, ddx_ddy_ddz.pytorch[1], ddx_ddy_ddz.oneflow[1]) _check_equal(test_case, ddx_ddy_ddz.pytorch[2], ddx_ddy_ddz.oneflow[2]) class TestGlobalPowHigherDerivative(flow.unittest.TestCase): @globaltest def test_global_pow_grad_grad(test_case): for placement in all_placement(): for i in range(5): _test_global_pow_grad_grad_impl(test_case, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_scalar_pow.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _check_equal(test_case, lhs, rhs, rtol=1e-4, atol=1e-4, name=""): is_equal = np.allclose( lhs.detach().cpu().numpy(), rhs.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, ) test_case.assertTrue(is_equal, f"{name} is not equal") def _test_scalar_pow_grad_grad_impl(test_case, placement, reverse=False): x_shape = [8, 8] y = random().to(float if random_bool().value() else int).value() x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) z = torch.pow(x, y) if not reverse else torch.pow(y, x) init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) dx = torch.autograd.grad(z, x, init_grad_z, True, True)[0] _check_equal(test_case, dx.pytorch, dx.oneflow, name="dx") ddx_and_ddz = torch.autograd.grad(dx, [x, init_grad_z], init_grad_x, True, True) _check_equal(test_case, ddx_and_ddz.pytorch[0], ddx_and_ddz.oneflow[0], name="ddx") _check_equal(test_case, ddx_and_ddz.pytorch[1], ddx_and_ddz.oneflow[1], name="ddz") class TestGlobalScalarPowHigherDerivative(flow.unittest.TestCase): @globaltest def test_global_scalar_pow_grad_grad(test_case): for placement in all_placement(): for i in range(10): _test_scalar_pow_grad_grad_impl(test_case, placement) @globaltest def test_global_scalar_reverse_pow_grad_grad(test_case): for placement in all_placement(): for i in range(10): _test_scalar_pow_grad_grad_impl(test_case, placement, reverse=True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_slice.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _global_slice_grad_grad_impl(test_case, placement, sbp): x = ( random_tensor(ndim=3, dim0=8, dim1=8, dim2=8) .to_global(placement=placement, sbp=sbp) .requires_grad_(True) ) init_grad = ( random_tensor(ndim=3, dim0=8, dim1=8, dim2=4) .to_global(placement=placement, sbp=sbp) .requires_grad_(True) ) init_grad_grad = ( random_tensor(ndim=3, dim0=8, dim1=8, dim2=8) .to_global(placement=placement, sbp=sbp) .requires_grad_(True) ) y = x[:, :, 2:6] x_grad = torch.autograd.grad(y, x, init_grad, create_graph=True)[0] test_case.assertTrue( np.allclose( x_grad.pytorch.detach().cpu().numpy(), x_grad.oneflow.detach().numpy() ) ) dgrad = torch.autograd.grad(x_grad, init_grad, init_grad_grad, create_graph=False)[ 0 ] test_case.assertTrue( np.allclose( dgrad.pytorch.detach().cpu().numpy(), dgrad.oneflow.detach().numpy(), ) ) class TestGlobalSliceHigherDerivative(flow.unittest.TestCase): @globaltest def test_global_slice_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_slice_grad_grad_impl(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_higher_derivative_softmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _assert_true(test_case, value1, value2): test_case.assertTrue( np.allclose( value1.detach().cpu().numpy(), value2.detach().cpu().numpy(), rtol=1e-05, atol=1e-05, ) ) def _test_global_softmax_grad_grad_impl(test_case, op_name, placement, sbp): ndim = 2 data = random_tensor(ndim=ndim, dim0=8, dim1=8) for dim in range(ndim): x = ( data.detach() .clone() .requires_grad_() .to_global(placement=placement, sbp=sbp) ) m = eval(f"torch.nn.{op_name}")(dim) y = m(x) _assert_true(test_case, y.pytorch, y.oneflow) x_shape = x.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=sbp ) init_grad_y = random_tensor(len(x_shape), *x_shape).to_global( placement=placement, sbp=sbp ) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _assert_true(test_case, dx.pytorch, dx.oneflow) ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _assert_true(test_case, ddx.pytorch, ddx.oneflow) _assert_true(test_case, ddy.pytorch, ddy.oneflow) class TestGlobalSoftmaxHigherDerivative(flow.unittest.TestCase): @globaltest def test_global_softmax_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_softmax_grad_grad_impl( test_case, op_name="Softmax", placement=placement, sbp=sbp ) @globaltest def test_global_logsoftmax_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_softmax_grad_grad_impl( test_case, op_name="LogSoftmax", placement=placement, sbp=sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_inv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_inv(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim - 2)] square_dim = 8 dim_list.extend([square_dim] * 2) x = ( random_tensor(ndim, *dim_list, low=-1) .to(torch.double) .to_global(placement, sbp) ) return torch.linalg.inv(x) class TestInv(flow.unittest.TestCase): @globaltest def test_inv(test_case): ndim = random(2, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_inv(test_case, placement, sbp, ndim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_lerp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False) def _test_lerp(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim - 2)] square_dim = 8 dim_list.extend([square_dim] * 2) start = random_tensor(ndim, *dim_list).to(torch.double).to_global(placement, sbp) end = random_tensor(ndim, *dim_list).to(torch.double).to_global(placement, sbp) weight = random_tensor(ndim, *dim_list).to(torch.double).to_global(placement, sbp) return torch.lerp(start, end, weight) class TestLerp(flow.unittest.TestCase): @globaltest def test_lerp(test_case): ndim = random(2, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_lerp(test_case, placement, sbp, ndim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_linalg_cross.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1) def _test_linalg_cross(test_case, index_size_equal_3, ndim, placement, sbp): shape = [random(1, 4).to(int) * 8 for i in range(ndim)] shape[index_size_equal_3] = 3 x = random_tensor(ndim, *shape) x = x.to_global(placement=placement, sbp=sbp) y = random_tensor(ndim, *shape) y = y.to_global(placement=placement, sbp=sbp) return torch.cross( x, y, dim=index_size_equal_3 ) # TODO(peihong): will convert to torch.linalg.cross when PyTorch in ci is upgraded to 1.11 class TestLinalgCrossGlobal(flow.unittest.TestCase): @globaltest def test_linalg_cross(test_case): ndim = random(2, 5).to(int).value() index_size_equal_3 = random(0, ndim).to(int).value() for placement in all_placement(): for sbp in all_sbp( placement, max_dim=ndim, valid_split_axis=[i for i in range(ndim) if i != index_size_equal_3], ): _test_linalg_cross(test_case, index_size_equal_3, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_linear.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False) def _test_linear_with_random_data(test_case, placement, input_sbp): input_size = 8 m = torch.nn.Linear(in_features=input_size, out_features=8, bias=random()) m.train(random()) weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True) m.weight = torch.nn.Parameter( m.weight.to_global(placement=placement, sbp=weight_sbp) ) if m.bias is not None: # bias is 1-d tensor bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True) m.bias = torch.nn.Parameter(m.bias.to_global(placement=placement, sbp=bias_sbp)) x = random_tensor(ndim=2, dim0=input_size, dim1=8).to_global( placement=placement, sbp=input_sbp ) y = m(x) return y class TestLinearModule(flow.unittest.TestCase): @globaltest def test_linear_with_random_data(test_case): for placement in all_placement(): for input_sbp in all_sbp(placement, max_dim=2): _test_linear_with_random_data(test_case, placement, input_sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_linspace.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from collections import OrderedDict import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_linspace(test_case, placement, sbp): x = flow.linspace(start=-10, end=10, steps=8, placement=placement, sbp=sbp) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_linspace(test_case, start, end, steps, placement, sbp): class GlobalLinspaceGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.linspace(start, end, steps, placement=placement, sbp=sbp) return x model = GlobalLinspaceGraph() x = model() test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestLinspaceGlobal(flow.unittest.TestCase): @globaltest def test_linspace_global(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_global_linspace(test_case, placement, sbp) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_linspace_graph(test_case): arg_dict = OrderedDict() arg_dict["start"] = [-2, 0, 2] arg_dict["end"] = [4, 8, 16] arg_dict["steps"] = [8, 16, 24] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): start = args["start"] end = args["end"] steps = args["steps"] placement = args["placement"] for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_graph_linspace(test_case, start, end, steps, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_logspace.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from collections import OrderedDict import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_logspace(test_case, placement, sbp): x = flow.logspace(start=-10, end=10, steps=8, placement=placement, sbp=sbp) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_logspace(test_case, start, end, steps, placement, sbp): class GlobalLogspaceGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.logspace(start, end, steps, placement=placement, sbp=sbp) return x model = GlobalLogspaceGraph() x = model() test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestLogspaceGlobal(flow.unittest.TestCase): # TODO(wyg): It will be infer all broadcast sbp when 1n1d, # slice_update will get error when doing inplace operator. 
# Remove this judgement after refactor sbp infer method in Operator class. @globaltest def test_logspace_global(test_case): for placement in all_placement(): if placement.ranks.size == 1: continue for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_global_logspace(test_case, placement, sbp) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_logspace_graph(test_case): arg_dict = OrderedDict() arg_dict["start"] = [-2, 0, 2] arg_dict["end"] = [4, 8, 16] arg_dict["steps"] = [8, 16, 24] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): start = args["start"] end = args["end"] steps = args["steps"] placement = args["placement"] for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_graph_logspace(test_case, start, end, steps, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_lstm_cell.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * # NOTE(lixiang): Do not check the graph for the time being, because ci will report "The action has timed out". 
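# The helper below builds an LSTMCell, scatters its weights (and biases, when present)
# to the given placement with randomly chosen SBP signatures, and then manually unrolls
# the cell over the time dimension, feeding (hx, cx) back in at every step; only the
# final hidden state is returned for comparison.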
@autotest(n=1, check_graph="ValidatedFalse") def _test_lstm_cell(test_case, placement, sbp): batch_size = random(2, 3) * 8 time_steps = random(2, 3) * 8 input_size = random(2, 3) * 8 hidden_size = random(2, 3) * 8 has_bias = random().to(bool) cx_requires_grad = random().to(bool) m = torch.nn.LSTMCell( input_size=input_size, hidden_size=hidden_size, bias=has_bias, ) weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True) m.weight_ih = torch.nn.Parameter( m.weight_ih.to_global(placement=placement, sbp=weight_sbp) ) m.weight_hh = torch.nn.Parameter( m.weight_hh.to_global(placement=placement, sbp=weight_sbp) ) if m.bias_ih is not None: # bias is 1-d tensor bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True) m.bias_ih = torch.nn.Parameter( m.bias_ih.to_global(placement=placement, sbp=bias_sbp) ) m.bias_hh = torch.nn.Parameter( m.bias_hh.to_global(placement=placement, sbp=bias_sbp) ) input_sbp = random_sbp(placement, max_dim=3, valid_split_axis=1) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to_global(placement=placement, sbp=input_sbp) hx = random_tensor( ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False ).to_global(placement=placement, sbp=sbp) cx = random_tensor( ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=cx_requires_grad ).to_global(placement=placement, sbp=sbp) for i in range(time_steps.to(int).value()): res = m(input[i], (hx, cx)) hx = res[0] cx = res[1] return res[0] class TestRNNCellGlobal(flow.unittest.TestCase): @globaltest def test_lstm_cell(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_lstm_cell(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_masked_fill.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_masked_fill(test_case, placement, sbp): k1 = random().to(int).value() * 8 k2 = random().to(int).value() * 8 input = random_tensor(ndim=2, dim0=k1, dim1=k2).to_global(placement, sbp) mask = random_tensor(ndim=2, dim0=k1, dim1=k2).to_global(placement, sbp) value = random().to(float) return input.masked_fill(mask > 0.5, value) @autotest(n=1, check_graph=True) def _test_masked_fill_with_0dim_data(test_case, placement, sbp): input = random_tensor(ndim=0).to_global(placement, sbp) mask = random_tensor(ndim=0).to_global(placement, sbp) value = random().to(float) return input.masked_fill(mask > 0.5, value) @autotest(n=1, check_graph=True) def _test_masked_fill_with_broadcast_way(test_case, placement, sbp): k1 = random().to(int).value() * 8 k2 = random().to(int).value() * 8 input = random_tensor(ndim=2, dim0=k1, dim1=k2, dim2=1, dim3=k2).to_global( placement, sbp ) mask = random_tensor(ndim=2, dim0=k1, dim1=k2, dim2=k1, dim3=1).to_global( placement, sbp ) value = random().to(float) return input.masked_fill(mask > 0.5, value) class TestMaskedFill(flow.unittest.TestCase): @globaltest def test_masked_fill(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_masked_fill(test_case, placement, sbp) # TODO() : fail at tensor slice # _test_masked_fill_with_0dim_data(test_case, placement, sbp) _test_masked_fill_with_broadcast_way(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_masked_select.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * # Not check graph because of one reason: # Reason 1, The implementation of the masked_select op calls argwhere with the lazy tensor as an argument, but lazy tensor can not be applied to argwhere. # Please refer to File "python/oneflow/nn/modules/masked_select.py", line 54, in masked_select_op. @autotest(n=1, check_graph="ValidatedFalse") def _test_masked_select(test_case, placement, sbp): k1 = random(1, 2).to(int).value() * 8 k2 = random(1, 2).to(int).value() * 8 input = random_tensor(ndim=2, dim0=k1, dim1=k2).to_global(placement, sbp) mask = input.ge(0.5) return torch.masked_select(input, mask) # Not check graph because of one reason: # Reason 1, The implementation of the masked_select op calls argwhere with the lazy tensor as an argument, but lazy tensor can not be applied to argwhere. # Please refer to File "python/oneflow/nn/modules/masked_select.py", line 54, in masked_select_op. 
@autotest(n=1, check_graph="ValidatedFalse") def _test_masked_select_broadcast(test_case, placement, input_sbp, mask_sbp): k1 = random(1, 2).to(int).value() * 8 k2 = random(1, 2).to(int).value() * 8 input = random_tensor(ndim=4, dim0=k1, dim1=k2, dim2=1, dim3=k2).to_global( placement, input_sbp ) mask = random_tensor(ndim=4, dim0=k1, dim1=k2, dim2=k1, dim3=1).to_global( placement, mask_sbp ) return torch.masked_select(input, mask > 0.5) class TestMaskedSelect(flow.unittest.TestCase): @globaltest def test_masked_select(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_masked_select(test_case, placement, sbp) @globaltest def test_masked_select_broadcast(test_case): for placement in all_placement(): for input_sbp in all_sbp(placement, valid_split_axis=[0, 1, 3]): for mask_sbp in all_sbp(placement, max_dim=3): _test_masked_select_broadcast( test_case, placement, input_sbp, mask_sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_math_op_higher_derivative.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _global_math_op_grad_grad_impl(test_case, op_name, placement, sbp): x = ( random_tensor(2, dim0=8, dim1=8, low=-2, high=2) .to_global(placement=placement, sbp=sbp) .requires_grad_(True) ) y = eval(f"torch.{op_name}")(x) init_grad = random_tensor(2, 8, 8).to_global(placement, sbp).requires_grad_() x_grad = torch.autograd.grad(y, x, init_grad, create_graph=True)[0] test_case.assertTrue( np.allclose( x_grad.pytorch.detach().cpu().numpy(), x_grad.oneflow.detach().numpy(), atol=1e-4, rtol=1e-4, equal_nan=True, ) ) x_grad_grad = torch.autograd.grad(x_grad, x, init_grad, retain_graph=True)[0] test_case.assertTrue( np.allclose( x_grad_grad.pytorch.detach().cpu().numpy(), x_grad_grad.oneflow.detach().numpy(), atol=1e-4, rtol=1e-4, equal_nan=True, ) ) init_grad_grad = random_tensor(2, 8, 8).to_global(placement, sbp).requires_grad_() dgrad = torch.autograd.grad(x_grad, init_grad, init_grad_grad, retain_graph=True)[0] test_case.assertTrue( np.allclose( dgrad.pytorch.detach().cpu().numpy(), dgrad.oneflow.detach().numpy(), atol=1e-4, rtol=1e-4, equal_nan=True, ) ) class TestGlobalMathOpHigherDerivative(flow.unittest.TestCase): @globaltest def test_global_sin_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "sin", placement, sbp) @globaltest def test_global_cos_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "cos", placement, sbp) @globaltest def test_global_tan_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): 
_global_math_op_grad_grad_impl(test_case, "tan", placement, sbp) @globaltest def test_global_sinh_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "sinh", placement, sbp) @globaltest def test_global_cosh_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "cosh", placement, sbp) @globaltest def test_global_tanh_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "tanh", placement, sbp) @globaltest def test_global_asin_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "asin", placement, sbp) @globaltest def test_global_acos_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "acos", placement, sbp) @globaltest def test_global_atan_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "atan", placement, sbp) @globaltest def test_global_asinh_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "asinh", placement, sbp) @globaltest def test_global_acosh_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "acosh", placement, sbp) @globaltest def test_global_atanh_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "atanh", placement, sbp) @globaltest def test_global_erf_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "erf", placement, sbp) @globaltest def test_global_erfc_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "erfc", placement, sbp) @globaltest def test_global_exp_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "exp", placement, sbp) @globaltest def test_global_exp2_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "exp2", placement, sbp) @globaltest def test_global_expm1_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "expm1", placement, sbp) @globaltest def test_global_log_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "log", placement, sbp) @globaltest def test_global_logsigmoid_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl( test_case, "nn.functional.logsigmoid", placement, sbp ) @globaltest def test_global_log2_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "log2", placement, sbp) @globaltest def test_global_log1p_grad_grad(test_case): for placement in all_placement(): for sbp in 
all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "log1p", placement, sbp) @globaltest def test_global_reciprocal_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "reciprocal", placement, sbp) @globaltest def test_global_rsqrt_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "rsqrt", placement, sbp) @globaltest def test_global_sqrt_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "sqrt", placement, sbp) @globaltest def test_global_square_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "square", placement, sbp) @globaltest def test_global_sigmoid_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "sigmoid", placement, sbp) @globaltest def test_global_abs_grad_grad(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _global_math_op_grad_grad_impl(test_case, "abs", placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_math_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1) def _test_sinh(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.sinh(x) return y @autotest(n=1) def _test_sin(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.sin(x) return y @autotest(n=1) def _test_inplace_sin(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = x + 1 y.sin_() return y @autotest(n=1) def _test_cos(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.cos(x) return y @autotest(n=1) def _test_log(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.log(x) return y @autotest(n=1) def _test_sqrt(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.sqrt(x) return y @autotest(n=1) def _test_exp(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.exp(x) return y @autotest(n=1) def _test_exp2(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.exp2(x) return y @autotest(n=1) def _test_rsqrt(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.rsqrt(x) return y @autotest(n=1) def _test_square(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = torch.square(x) return y @autotest(n=1) def _test_pow_with_scalar(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = random().to(float) z = torch.pow(x, y) return z @autotest(n=1, auto_backward=False) def _test_floordiv_with_scalar(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] # The random value is narrowed to positive number because of the error from pytorch 1.10.0 # Please remove the value range striction after updating the pytorch version of ci to 1.13. 
x = random_tensor(ndim, *dim_list, low=0, high=10).to_global(placement, sbp) y = random().to(float) z = torch.floor_divide(x, y) return z @autotest(n=1) def _test_arccos(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list, low=-1, high=1).to_global(placement, sbp) y = torch.arccos(x) return y @autotest(n=1) def _test_acos(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list, low=-1, high=1).to_global(placement, sbp) y = torch.acos(x) return y @autotest(n=1) def _test_arccosh(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list, low=2, high=3).to_global(placement, sbp) y = torch.arccosh(x) return y @autotest(n=1) def _test_acosh(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list, low=2, high=3).to_global(placement, sbp) y = torch.acosh(x) return y @autotest(n=1, auto_backward=False) def _test_floordiv(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] # The random value is narrowed to positive number because of the error from pytorch 1.10.0 # Please remove the value range striction after updating the pytorch version of ci to 1.13. x = random_tensor(ndim, *dim_list, low=0, high=10).to_global(placement, sbp) y = random_tensor(ndim, *dim_list, low=1, high=10).to_global(placement, sbp) z = torch.floor_divide(x, y) return z @autotest(n=1) def _test_atan2(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = random_tensor(ndim, *dim_list).to_global(placement, sbp) z = torch.atan2(x, y) return z @autotest(n=1) def _test_digamma(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list, low=0, high=10).to_global(placement, sbp) y = torch.digamma(x) return y class TestMathOps(flow.unittest.TestCase): @globaltest def test_math_ops(test_case): ndim = random(1, 3).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_sinh(test_case, placement, sbp, ndim) _test_sin(test_case, placement, sbp, ndim) _test_inplace_sin(test_case, placement, sbp, ndim) _test_cos(test_case, placement, sbp, ndim) _test_log(test_case, placement, sbp, ndim) _test_sqrt(test_case, placement, sbp, ndim) _test_exp(test_case, placement, sbp, ndim) _test_exp2(test_case, placement, sbp, ndim) _test_rsqrt(test_case, placement, sbp, ndim) _test_square(test_case, placement, sbp, ndim) _test_pow_with_scalar(test_case, placement, sbp, ndim) _test_floordiv_with_scalar(test_case, placement, sbp, ndim) _test_arccos(test_case, placement, sbp, ndim) _test_acos(test_case, placement, sbp, ndim) _test_arccosh(test_case, placement, sbp, ndim) _test_acosh(test_case, placement, sbp, ndim) _test_digamma(test_case, placement, sbp, ndim) _test_floordiv(test_case, placement, sbp, ndim) _test_atan2(test_case, placement, sbp, ndim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_matmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_matmul(test_case, placement, x_sbp, y_sbp): x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement=placement, sbp=x_sbp) y = random_tensor(ndim=2, dim0=16, dim1=8).to_global(placement=placement, sbp=y_sbp) return torch.matmul(x, y) @autotest(n=1, check_graph=True) def _test_tensor_broadcast_matmul(test_case, placement, x_sbp, y_sbp): x = random_tensor(ndim=3, dim0=8, dim1=8, dim2=16).to_global( placement=placement, sbp=x_sbp ) y = random_tensor(ndim=2, dim0=16, dim1=8).to_global(placement=placement, sbp=y_sbp) return x.matmul(y) class TestMatMulModule(flow.unittest.TestCase): @globaltest def test_matmul(test_case): for placement in all_placement(): for x_sbp in all_sbp(placement, max_dim=2): for y_sbp in all_sbp(placement, max_dim=2): _test_matmul(test_case, placement, x_sbp, y_sbp) @globaltest def test_broadcast_matmul(test_case): for placement in all_placement(): for x_sbp in all_sbp(placement, valid_split_axis=[0, 1, 2, 3]): for y_sbp in all_sbp(placement, valid_split_axis=[0, 1]): _test_tensor_broadcast_matmul(test_case, placement, x_sbp, y_sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_max.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList def _np_max(shape, dim, keepdims): # np array result input_arr = np.random.randn(*shape) np_out = np.amax(input_arr, axis=dim, keepdims=keepdims) np_out_grad = np.zeros_like(input_arr) if dim == None: arg_max = np.argmax(input_arr) np.put(np_out_grad, arg_max, 1) else: arg_max = np.expand_dims(np.argmax(input_arr, axis=dim), axis=dim) np.put_along_axis(np_out_grad, arg_max, 1, axis=dim) return np_out, np_out_grad, input_arr def _test_max( test_case, placement, sbp, np_out, np_out_grad, input_arr, shape, dim, keepdims ): # of result global_x = flow.tensor( input_arr, dtype=flow.float32, requires_grad=True, placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast, ) if dim is None: of_out = flow.max(global_x) else: of_out = flow.max(global_x, dim, keepdims)[0] test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() test_case.assertTrue( np.allclose(global_x.grad.numpy(), np_out_grad, 0.0001, 0.0001) ) class TestMaxModule(flow.unittest.TestCase): # backward formula is different from one of torch. @globaltest def test_eager_global_max(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_max] arg_dict["shape"] = [(8,), (8, 8), (8, 8, 8, 8)] arg_dict["dim"] = [None, 0, -1] arg_dict["keepdims"] = [False, True] for arg in GenArgList(arg_dict): np_out, np_out_grad, input_arr = _np_max(*arg[1:]) np_out = ( flow.tensor(np_out) .to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast,) .numpy() ) np_out_grad = ( flow.tensor(np_out_grad) .to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast,) .numpy() ) input_arr = ( flow.tensor(input_arr) .to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast,) .numpy() ) for placement in all_placement(): for sbp in all_sbp(placement, max_dim=len(*arg[1:2])): arg[0]( test_case, placement, sbp, np_out, np_out_grad, input_arr, *arg[1:] ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_maximum_minimum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest import torch as torch_original from packaging import version from oneflow.test_utils.automated_test_util import * @autotest( n=5, auto_backward=( version.parse(torch_original.__version__) >= version.parse("1.10.2") ), check_graph=True, ) def _test_broadcast_maximum(test_case, placement, x_sbp, y_sbp): x = random_tensor(ndim=5, dim0=8, dim1=8, dim2=8, dim3=1, dim4=8).to_global( placement, x_sbp ) y = random_tensor(ndim=5, dim0=8, dim1=8, dim2=1, dim3=8, dim4=1).to_global( placement, y_sbp ) z = torch.maximum(x, y) return z @autotest( n=5, auto_backward=( version.parse(torch_original.__version__) >= version.parse("1.10.2") ), check_graph=True, ) def _test_broadcast_minimum(test_case, placement, x_sbp, y_sbp): x = random_tensor(ndim=5, dim0=8, dim1=8, dim2=8, dim3=1, dim4=8).to_global( placement, x_sbp ) y = random_tensor(ndim=5, dim0=8, dim1=8, dim2=1, dim3=8, dim4=1).to_global( placement, y_sbp ) z = torch.minimum(x, y) return z @autotest( n=5, auto_backward=( version.parse(torch_original.__version__) >= version.parse("1.10.2") ), check_graph=True, ) def _test_maximum_with_same_input(test_case, placement, sbp): x = random_tensor(ndim=4, dim0=8, dim1=8, dim2=8, dim3=8).to_global(placement, sbp) y = x.detach().clone() y.requires_grad = True z = torch.maximum(x, y) return z @autotest( n=5, auto_backward=( version.parse(torch_original.__version__) >= version.parse("1.10.2") ), check_graph=True, ) def _test_minimum_with_same_input(test_case, placement, sbp): x = random_tensor(ndim=4, dim0=8, dim1=8, dim2=8, dim3=8).to_global(placement, sbp) y = x.detach().clone() y.requires_grad = True z = torch.minimum(x, y) return z class TestMaximumMinimumOps(flow.unittest.TestCase): @globaltest def test_maximum_minimum_with_same_input(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_maximum_with_same_input(test_case, placement, sbp) _test_minimum_with_same_input(test_case, placement, sbp) @globaltest def test_broadcast_maximum_minimum(test_case): for placement in all_placement(): for x_sbp in all_sbp(placement, valid_split_axis=[0, 1, 2, 4]): for y_sbp in all_sbp(placement, valid_split_axis=[0, 1, 3]): _test_broadcast_maximum(test_case, placement, x_sbp, y_sbp) _test_broadcast_minimum(test_case, placement, x_sbp, y_sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_maxpool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np from pkg_resources import packaging import oneflow as flow import torch as ori_torch import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t @autotest(n=1, check_graph=True) def _test_maxpool1d_functional(test_case, placement, sbp): return_indices = random().to(bool).value() dim0 = random(1, 4).to(int).value() * 8 dim1 = random(1, 4).to(int).value() * 8 x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(20, 22)).to_global( placement, sbp ) y = torch.nn.functional.max_pool1d( x, kernel_size=random(4, 6).to(int), stride=random(1, 3).to(int), padding=random(1, 3).to(int), dilation=random(2, 4).to(int), ceil_mode=random().to(bool), return_indices=return_indices, ) if return_indices: return y[0] else: return y @autotest(n=1, check_graph=True) def _test_maxpool2d_functional(test_case, placement, sbp): return_indices = random().to(bool).value() dim0 = random(1, 4).to(int).value() * 8 dim1 = random(1, 4).to(int).value() * 8 x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(20, 22), dim3=random(20, 22) ).to_global(placement, sbp) y = torch.nn.functional.max_pool2d( x, kernel_size=random(4, 6).to(int), stride=random(1, 3).to(int), padding=random(1, 3).to(int), dilation=random(2, 4).to(int), ceil_mode=random().to(bool), return_indices=return_indices, ) if return_indices: return y[0] else: return y @autotest(n=1, check_graph=True) def _test_maxpool3d_functional(test_case, placement, sbp): return_indices = random().to(bool).value() dim0 = random(high=4).to(int).value() * 8 dim1 = random(high=4).to(int).value() * 8 x = random_tensor( ndim=5, dim0=dim0, dim1=dim1, dim2=random(10, 12), dim3=random(10, 12), dim4=random(10, 12), ).to_global(placement, sbp) y = torch.nn.functional.max_pool3d( x, kernel_size=random(4, 6).to(int), stride=random(1, 3).to(int), padding=random(1, 3).to(int), dilation=random(2, 4).to(int), ceil_mode=random().to(bool), return_indices=return_indices, ) if return_indices: return y[0] else: return y @autotest(n=1, check_graph=True) def _test_maxpool1d(test_case, placement, sbp): return_indices = random().to(bool).value() dim0 = random(1, 4).to(int).value() * 8 dim1 = random(1, 4).to(int).value() * 8 m = torch.nn.MaxPool1d( kernel_size=random(4, 6).to(_size_1_t), stride=random(1, 3).to(_size_1_t), padding=random(1, 3).to(_size_1_t), dilation=random(2, 4).to(_size_1_t), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(20, 22)).to_global( placement, sbp ) y = m(x) if return_indices: return y[0] else: return y @autotest(n=1, check_graph=True) def _test_maxpool2d(test_case, placement, sbp): return_indices = random().to(bool).value() dim0 = random(1, 3).to(int).value() * 8 dim1 = random(1, 3).to(int).value() * 8 m = torch.nn.MaxPool2d( kernel_size=random(4, 6).to(_size_2_t), stride=random(1, 3).to(_size_2_t), padding=random(1, 3).to(_size_2_t), dilation=random(2, 4).to(_size_2_t), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(20, 22), dim3=random(20, 22) ).to_global(placement, sbp) y = m(x) if return_indices: return y[0] else: return y @autotest(n=1, check_graph=True) def _test_maxpool3d(test_case, placement, sbp): return_indices = random().to(bool).value() dim0 = random(high=4).to(int).value() 
* 8 dim1 = random(high=4).to(int).value() * 8 m = torch.nn.MaxPool3d( kernel_size=random(4, 6).to(_size_3_t), stride=random(1, 3).to(_size_3_t), padding=random(1, 3).to(_size_3_t), dilation=random(2, 4).to(_size_3_t), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) x = random_tensor( ndim=5, dim0=dim0, dim1=dim1, dim2=random(10, 12), dim3=random(10, 12), dim4=random(10, 12), ).to_global(placement, sbp) y = m(x) if return_indices: return y[0] else: return y def _test_maxpool2d_channel_last( test_case, placement, sbp, shape, kernel_size, stride, padding, dilation, ceil_mode ): os.environ["ONEFLOW_ENABLE_NHWC"] = "1" tensor = random_tensor(len(shape), *shape, requires_grad=False).to_global( placement, sbp ) # oneflow result x1 = tensor.oneflow m1 = flow.nn.MaxPool2d( kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, ) y1 = m1(x1) # pytorch result x2 = tensor.pytorch.permute(0, 3, 1, 2).to(placement.type) m2 = ori_torch.nn.MaxPool2d( kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, ) y2 = m2(x2).permute(0, 2, 3, 1) os.environ["ONEFLOW_ENABLE_NHWC"] = "0" # restore the default layout # The following check should be re-enabled after the pytorch version in CI is updated to 1.13. # test_case.assertTrue( # np.allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), 1e-4, 1e-4) # ) class TestMaxPool(flow.unittest.TestCase): @globaltest def test_maxpool(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_maxpool1d_functional(test_case, placement, sbp) _test_maxpool2d_functional(test_case, placement, sbp) _test_maxpool3d_functional(test_case, placement, sbp) _test_maxpool1d(test_case, placement, sbp) _test_maxpool2d(test_case, placement, sbp) _test_maxpool3d(test_case, placement, sbp) @globaltest @unittest.skipIf( packaging.version.parse(ori_torch.__version__) == packaging.version.parse("1.10.0"), "skip when pytorch version == 1.10.0", ) # NOTE: pytorch maxpool2d nhwc has a bug in version 1.10.0, so skip it in CI. # detail: https://github.com/pytorch/pytorch/pull/76597 def test_maxpool2d_channel_last(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_maxpool2d_channel_last] arg_dict["shape"] = [(1, 16, 16, 3), (2, 224, 224, 3)] arg_dict["kernel_size"] = [3, (2, 3)] arg_dict["stride"] = [1, (1, 2)] arg_dict["padding"] = [0, (0, 1)] arg_dict["dilation"] = [1, 2] arg_dict["ceil_mode"] = [True, False] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, valid_split_axis=[1, 2]): arg[0](test_case, placement, sbp, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_maxunpool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t # y = pool(x), z = unpool(y, indices), pool_input_shape is x.shape, pool_output_shape is y.shape. # When `output_size` in unpool() is empty, the op will calculate the output size according to # kernel_size, stride and padding. But when index in indices is outside the range required # by output_size calculated by unpool op, the value of result and related grad will be unknown. # To avoid the problem, this function calculate the output_size which will not cause unknown problems. def _get_valid_output_size( pool_input_shape, pool_output_shape, kernel_size, stride, padding ): def convert_data(data, i, dst_data=None): if not isinstance(data, (list, int)): return dst_data if isinstance(data, list): return data[i] return data _, _, *pool_input_hwd_shape = pool_input_shape.pytorch batch_size, num_channels, *pool_out_hwd_shape = pool_output_shape.pytorch unpool_output_shape = [batch_size, num_channels] for i, (pool_input_size, pool_output_size) in enumerate( zip(pool_input_hwd_shape, pool_out_hwd_shape) ): kernel_size_value = convert_data(kernel_size.value(), i) stride_value = convert_data(stride.value(), i, kernel_size_value) padding_value = convert_data(padding.value(), i, 0) unpool_output_size = max( pool_input_size, (pool_output_size - 1) * stride_value - 2 * padding_value + kernel_size_value, ) unpool_output_shape.append(unpool_output_size) return torch.Size(unpool_output_shape) def _test_module_unpoolnd(test_case, placement, sbp, n): device = random_device() dim0 = random(high=4).to(int).value() * 8 dim1 = random(high=4).to(int).value() * 8 if n == 1: _size_n_t = _size_1_t MaxPoolNd = torch.nn.MaxPool1d MaxUnpoolNd = torch.nn.MaxUnpool1d x = random_tensor( ndim=3, dim0=dim0, dim1=dim1, dim2=random(20, 31), requires_grad=False ).to_global(placement=placement, sbp=sbp) elif n == 2: _size_n_t = _size_2_t MaxPoolNd = torch.nn.MaxPool2d MaxUnpoolNd = torch.nn.MaxUnpool2d x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(10, 21), dim3=random(10, 21), requires_grad=False, ).to_global(placement=placement, sbp=sbp) elif n == 3: _size_n_t = _size_3_t MaxPoolNd = torch.nn.MaxPool3d MaxUnpoolNd = torch.nn.MaxUnpool3d x = random_tensor( ndim=5, dim0=dim0, dim1=dim1, dim2=random(10, 14), dim3=random(10, 14), dim4=random(10, 14), requires_grad=False, ).to_global(placement=placement, sbp=sbp) kernel_size = random(4, 6).to(_size_n_t) stride = random(1, 3).to(_size_n_t) padding = random(1, 3).to(_size_n_t) m = MaxPoolNd( kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True, ) m.train(random()) m.to(device) y = m(x) pooling_results = y[0] indices = y[1] pooling_results.requires_grad_() output_size = _get_valid_output_size( x.shape, pooling_results.shape, kernel_size, stride, padding ) unpool_module = MaxUnpoolNd( kernel_size=kernel_size, stride=stride, padding=padding, ) result = unpool_module(pooling_results, indices, output_size=output_size) return result def _test_functional_unpoolnd(test_case, placement, sbp, n): device = random_device() dim0 = random(high=4).to(int).value() * 8 dim1 = random(high=4).to(int).value() * 8 if n == 1: _size_n_t = _size_1_t MaxPoolNd = torch.nn.MaxPool1d max_unpool_nd = torch.nn.functional.max_unpool1d x = random_tensor( ndim=3, dim0=dim0, dim1=dim1, dim2=random(20, 31), requires_grad=False ).to_global(placement=placement, sbp=sbp) elif n == 
2: _size_n_t = _size_2_t MaxPoolNd = torch.nn.MaxPool2d max_unpool_nd = torch.nn.functional.max_unpool2d x = random_tensor( ndim=4, dim0=dim0, dim1=dim1, dim2=random(10, 21), dim3=random(10, 21), requires_grad=False, ).to_global(placement=placement, sbp=sbp) elif n == 3: _size_n_t = _size_3_t MaxPoolNd = torch.nn.MaxPool3d max_unpool_nd = torch.nn.functional.max_unpool3d x = random_tensor( ndim=5, dim0=dim0, dim1=dim1, dim2=random(10, 14), dim3=random(10, 14), dim4=random(10, 14), requires_grad=False, ).to_global(placement=placement, sbp=sbp) kernel_size = random(4, 6).to(_size_n_t) stride = random(1, 3).to(_size_n_t) padding = random(1, 3).to(_size_n_t) m = MaxPoolNd( kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True, ) m.train(random()) m.to(device) y = m(x) pooling_results = y[0] indices = y[1] pooling_results.requires_grad_() output_size = _get_valid_output_size( x.shape, pooling_results.shape, kernel_size, stride, padding ) return max_unpool_nd( pooling_results, indices, kernel_size=kernel_size, stride=stride, padding=padding, output_size=output_size, ) class TestMaxPool(flow.unittest.TestCase): @globaltest def test_maxpool(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_functional_unpoolnd(test_case, placement, sbp, 1) _test_functional_unpoolnd(test_case, placement, sbp, 2) _test_functional_unpoolnd(test_case, placement, sbp, 3) _test_module_unpoolnd(test_case, placement, sbp, 1) _test_module_unpoolnd(test_case, placement, sbp, 2) _test_module_unpoolnd(test_case, placement, sbp, 3) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_mean.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_mean(test_case, placement, sbp, ndim): dim = random(1, ndim).to(int).value() dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list, dtype=float).to_global(placement, sbp) return torch.mean(x, dim) class TestMean(flow.unittest.TestCase): @globaltest def test_mean(test_case): ndim = random(2, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_mean(test_case, placement, sbp, ndim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_median.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import torch from functools import reduce import operator import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_median(test_case, placement, sbp, ndim): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) return torch.median(x) @autotest(n=1, check_graph=True) def _test_median_with_indices(test_case, placement, sbp, ndim): dim = random(1, ndim).to(int).value() dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = choice_tensor( reduce(operator.mul, dim_list, 1), dim_list, replace=False, dtype=float, requires_grad=True, ).to_global(placement, sbp) return torch.median(x, dim) class TestMedian(flow.unittest.TestCase): @globaltest def test_median(test_case): ndim = random(2, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_median(test_case, placement, sbp, ndim) _test_median_with_indices(test_case, placement, sbp, ndim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_meshgrid.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=5, auto_backward=False, check_graph=True) def _test_meshgrid(test_case, placement): x_sbp = random_sbp(placement, max_dim=1) x = random_tensor(ndim=1, dim0=8, requires_grad=False).to_global(placement, x_sbp) y_sbp = random_sbp(placement, max_dim=1) y = random_tensor(ndim=1, dim0=8, requires_grad=False).to_global(placement, y_sbp) res = torch.meshgrid(x, y) return res[0], res[1] class TestMeshGrid(flow.unittest.TestCase): @globaltest def test_meshgrid(test_case): for placement in all_placement(): _test_meshgrid(test_case, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_min.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList def _np_min(shape, dim, keepdims): # numpy array result input_arr = np.random.randn(*shape) np_out = np.amin(input_arr, axis=dim, keepdims=keepdims) np_out_grad = np.zeros_like(input_arr) if dim is None: arg_min = np.argmin(input_arr) np.put(np_out_grad, arg_min, 1) else: arg_min = np.expand_dims(np.argmin(input_arr, axis=dim), axis=dim) np.put_along_axis(np_out_grad, arg_min, 1, axis=dim) return np_out, np_out_grad, input_arr def _test_min( test_case, placement, sbp, np_out, np_out_grad, input_arr, shape, dim, keepdims ): # oneflow result global_x = flow.tensor( input_arr, dtype=flow.float32, requires_grad=True, placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast, ) if dim is None: of_out = flow.min(global_x) else: of_out = flow.min(global_x, dim, keepdims)[0] test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() test_case.assertTrue( np.allclose(global_x.grad.numpy(), np_out_grad, 0.0001, 0.0001) ) class TestMinModule(flow.unittest.TestCase): # The backward formula is different from that of torch. @unittest.skip("skip for now, because it failed 8 times in the past week") @globaltest def test_eager_global_min(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_min] arg_dict["shape"] = [(8,), (8, 8), (8, 8, 8, 8)] arg_dict["dim"] = [None, 0, -1] arg_dict["keepdims"] = [False, True] for arg in GenArgList(arg_dict): np_out, np_out_grad, input_arr = _np_min(*arg[1:]) np_out = ( flow.tensor(np_out) .to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast,) .numpy() ) np_out_grad = ( flow.tensor(np_out_grad) .to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast,) .numpy() ) input_arr = ( flow.tensor(input_arr) .to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast,) .numpy() ) for placement in all_placement(): for sbp in all_sbp(placement, max_dim=len(*arg[1:2])): arg[0]( test_case, placement, sbp, np_out, np_out_grad, input_arr, *arg[1:] ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_min_max_observer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest import numpy as np import oneflow as flow from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.nn.modules import min_max_observer from oneflow.test_utils.test_util import GenArgList from test_min_max_observer import _check_min_max_observer def _run_test_min_max_observer( test_case, placement, sbp, weight_shape, quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ): weight = random_tensor( len(weight_shape), *weight_shape, low=-0.5, high=0.5 ).to_global(placement, sbp) of_weight = weight.oneflow np_weight = of_weight.numpy() min_max_observer = flow.nn.MinMaxObserver( quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization, ) scale, zero_point = min_max_observer(of_weight) _check_min_max_observer( test_case, np_weight, scale.numpy(), zero_point.numpy(), quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ) class TestMinMaxObserver(flow.unittest.TestCase): @globaltest def test_min_max_observer(test_case): arg_dict = OrderedDict() arg_dict["weight_shape"] = [(9, 48, 24, 10)] arg_dict["quantization_bit"] = [8, 2] arg_dict["quantization_scheme"] = ["symmetric", "affine"] arg_dict["quantization_formula"] = ["google"] arg_dict["per_layer_quantization"] = [True, False] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, valid_split_axis=[1, 2]): _run_test_min_max_observer(test_case, placement, sbp, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_movedim.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_movedim(test_case, placement, sbp): x = random_tensor( ndim=4, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8, dim3=random(1, 3) * 8, dim4=random(1, 3) * 8, ).to_global(placement, sbp) z = torch.movedim(x, (0, 1), (2, 3)) return z class TestMovedim(flow.unittest.TestCase): @globaltest def test_movedim(test_case): for placement in all_placement(): for sbp in all_sbp(placement): _test_movedim(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_moving_average_max_min_observer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList from test_moving_average_min_max_observer import _check_moving_average_min_max_observer from oneflow.test_utils.automated_test_util import * def _run_test_moving_average_min_max_observer( test_case, placement, sbp, device_type, dtype, activation_shape, quantization_bit, quantization_scheme, quantization_formula, momentum, ): moving_max_np = np.zeros((1,)) moving_min_np = np.zeros((1,)) current_train_step_tensor = flow.tensor( np.zeros((1,)).astype(np.float32), dtype=flow.int64, placement=placement, sbp=sbp, ) for i in range(10): of_activation = ( random_tensor(len(activation_shape), *activation_shape, low=-0.5, high=0.5) .to_global(placement, sbp) .oneflow ) np_activation = of_activation.numpy() moving_average_min_max_observer = flow.nn.MovingAverageMinMaxObserver( quantization_formula=quantization_formula, stop_update_after_iters=1, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, momentum=momentum, ) moving_average_min_max_observer = moving_average_min_max_observer.to_global( placement, sbp ) (scale, zero_point) = moving_average_min_max_observer( of_activation, current_train_step_tensor ) _check_moving_average_min_max_observer( test_case, np_activation, scale.numpy(), zero_point.numpy(), moving_max_np, moving_min_np, quantization_bit, quantization_scheme, quantization_formula, momentum, ) class TestMovingAverageMinMaxObserver(flow.unittest.TestCase): @globaltest def test_moving_average_min_max_observer(test_case): arg_dict = OrderedDict() arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["dtype"] = ["float32", "double"] arg_dict["activation_shape"] = [(9, 48, 24, 10)] arg_dict["quantization_bit"] = [8, 2] arg_dict["quantization_scheme"] = ["symmetric", "affine"] arg_dict["quantization_formula"] = ["google"] arg_dict["momentum"] = [0.95] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, valid_split_axis=[1, 2]): _run_test_moving_average_min_max_observer( test_case, placement, sbp, *arg ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_mul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_broadcast_mul(test_case, placement, sbp): x = random_tensor(ndim=3, dim0=16, dim1=8, dim2=24).to_global(placement, sbp) y_sbp = random_sbp(placement, max_dim=2) y = random_tensor(ndim=2, dim0=8, dim1=24).to_global(placement, y_sbp) z = torch.mul(x, y) return z @autotest(n=1, check_graph=True) def _test_mul_with_scalar(test_case, ndim, placement, sbp): dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dim_list).to_global(placement, sbp) y = 2 return torch.mul(x, y) class TestMulModule(flow.unittest.TestCase): @globaltest def test_broadcast_mul(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_broadcast_mul(test_case, placement, sbp) @globaltest def test_mul_with_scalar(test_case): ndim = random(1, 4).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_mul_with_scalar(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_mv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_mv(test_case, placement, sbp): dim = random(1, 6) mat = random_tensor(2, dim1=dim).to_global(placement=placement, sbp=sbp) vec = random_tensor(1, dim0=dim).to_global(placement=placement, sbp=sbp) return torch.mv(mat, vec) class TestMvModule(flow.unittest.TestCase): @globaltest def test_mv(test_case): for placement in all_placement(): for sbp in all_sbp(placement): _test_mv(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_nansum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=False) def _test_global_nansum_against_pytorch(test_case, placement, sbp): x = random_tensor(4, 8, 16, 8, 24).to_global(placement, sbp) mask = x < 0 x = x.masked_fill(mask, float("nan")) y = torch.nansum(x) return y @autotest(n=1, check_graph=False) def _test_global_nansum_with_0_size_tensor(test_case, placement, sbp): x = random_tensor(4, 8, 16, 0, 24).to_global(placement, sbp) mask = torch.ones_like(x).bool() x = x.masked_fill(mask, float("nan")) y = torch.nansum(x, dim=random(0, 3).to(int)) return y class TestGlobalNanSumModule(flow.unittest.TestCase): @globaltest def test_global_nansum_against_pytorch(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_global_nansum_against_pytorch(test_case, placement, sbp) @globaltest def test_global_nansum_with_0_size_tensor(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4, valid_split_axis=[0, 1, 3]): _test_global_nansum_with_0_size_tensor(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_narrow.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_narrow(test_case, ndim, placement, sbp): dims = [random(1, 3).to(int).value() * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) dim = random(-ndim, ndim).to(int).value() start = random(0, dims[dim]).to(int).value() length = random(1, dims[dim] - start + 1).to(int).value() return torch.narrow(x, dim=dim, start=start, length=length) class TestNarrow(flow.unittest.TestCase): @globaltest def test_narrow(test_case): for placement in all_placement(): ndim = random(1, 4).to(int).value() for sbp in all_sbp(placement, max_dim=min(ndim, 2)): _test_narrow(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_ne.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=True) def _test_ne(test_case, placement, sbp): x1 = random_tensor(ndim=2, dim0=8, dim1=8).to_global(placement, sbp) x2 = random_tensor(ndim=2, dim0=8, dim1=8).to_global(placement, sbp) return torch.ne(x1, x2) class TestNe(flow.unittest.TestCase): @globaltest def test_ne(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_ne(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_negative.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_negative(test_case, placement, sbp, ndim): shape = [8 for _ in range(ndim)] x = random_tensor(ndim, *shape).to_global(placement, sbp) return torch.negative(x) class TestNegative(flow.unittest.TestCase): @globaltest def test_negative(test_case): ndim = random(2, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_negative(test_case, placement, sbp, ndim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_nms.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from test_nms import create_tensors_with_iou from test_nms import nms_np def _test_nms(test_case, placement, sbp): iou = 0.5 boxes, scores = create_tensors_with_iou(800, iou) global_boxes = flow.tensor(boxes, dtype=flow.float32).to_global( placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast ) np_boxes = global_boxes.numpy() global_boxes = global_boxes.to_global(placement=placement, sbp=sbp) global_scores = flow.tensor(scores, dtype=flow.float32).to_global( placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast ) np_scores = global_scores.numpy() global_scores = global_scores.to_global(placement=placement, sbp=sbp) keep_np = nms_np(np_boxes, np_scores, iou) keep = flow.nms(global_boxes, global_scores, iou) test_case.assertTrue(np.allclose(keep.numpy(), keep_np)) class TestNMS(flow.unittest.TestCase): @globaltest def test_nms(test_case): for placement in all_placement(): # TODO: nms only has cuda kernel at now. if placement.type == "cpu": continue for sbp in all_sbp(placement, max_dim=1): _test_nms(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_normal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type from oneflow.test_utils.automated_test_util import * import oneflow as flow def _test_global_normal( test_case, placement, sbp, mean, std, shape, dtype, requires_grad ): dtype = type_name_to_flow_type[dtype] x = flow.normal( mean, std, shape, placement=placement, sbp=sbp, dtype=dtype, requires_grad=requires_grad, ) test_case.assertEqual(x.shape, shape) test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) test_case.assertEqual(x.requires_grad, requires_grad) class TestNormalGlobal(flow.unittest.TestCase): @globaltest def test_normal_global(test_case): arg_dict = OrderedDict() arg_dict["mean"] = [-1, 0, 1] arg_dict["std"] = [1, 2, 8] arg_dict["shape"] = [(8, 8), (8, 8, 8), (8, 8, 8, 8)] arg_dict["dtype"] = ["float32", "double"] arg_dict["requires_grad"] = [True, False] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(arg[2]), except_partial_sum=True ): _test_global_normal(test_case, placement, sbp, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_normalize.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_functional_normalize(test_case, placement, sbp): ndim = random(low=2, high=5).to(int).value() shape = [random(low=2, high=3) * 8 for i in range(ndim)] x = random_tensor(len(shape), *shape).to_global(placement=placement, sbp=sbp) dim = random(low=0, high=ndim).to(int).value() y = torch.nn.functional.normalize(x, oneof(2, 3, 4), dim, 1e-12) return y class TestModule(flow.unittest.TestCase): @globaltest def test_normalize_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_functional_normalize(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_nozero.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * # Not check graph because of one reason: # Reason 1, lazy tensor cannot call numpy(), tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor. # Please refer to File "python/oneflow/nn/modules/nonzero.py", line 29, in nonzero_op. @autotest(n=1, auto_backward=False, check_graph="ValidatedFalse") def _test_nonzero(test_case, placement, sbp, ndim): shape = [8 for _ in range(ndim)] x = random_tensor(ndim, *shape).to_global(placement, sbp) return torch.nonzero(x) class TestNonZero(flow.unittest.TestCase): @globaltest def test_nonzero(test_case): ndim = random(2, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_nonzero(test_case, placement, sbp, ndim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_ones_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
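The comment in test_global_nozero.py above explains why check_graph is set to "ValidatedFalse": flow.nonzero calls tensor.numpy() internally (python/oneflow/nn/modules/nonzero.py), and numpy() is not allowed on the lazy tensors seen inside nn.Graph.build. A minimal sketch of the distinction, with a hypothetical graph class; the eager call works, while running the graph is expected to fail for the reason stated in that comment:

import oneflow as flow

x = flow.tensor([0.0, 1.0, 0.0, 2.0])
print(flow.nonzero(x))  # eager mode: fine, prints the indices of non-zero entries

class NonzeroGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()

    def build(self, x):
        # nonzero's output shape is data dependent and its implementation
        # calls tensor.numpy(), which lazy tensors do not support.
        return flow.nonzero(x)

graph = NonzeroGraph()
# graph(x)  # expected to fail for the reason above; kept commented out on purpose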
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_ones_like_float(test_case, placement, sbp, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) x = x.to_global(placement=placement, sbp=sbp) y = flow.ones_like(x, placement=placement, sbp=sbp) test_case.assertTrue(y.dtype is flow.float32) test_case.assertTrue(y.shape == x.shape) test_case.assertTrue(y.placement == placement) y_numpy = np.ones(x.numpy().shape) print("y_numpy: ", y_numpy) print("y.numpy()", y.numpy()) test_case.assertTrue(np.array_equal(y.numpy(), y_numpy)) def _test_ones_like_int(test_case, placement, sbp, shape, device): x = flow.tensor(np.random.randn(*shape), dtype=flow.int, device=flow.device(device)) x = x.to_global(placement=placement, sbp=sbp) y = flow.ones_like(x, dtype=flow.int, placement=placement, sbp=sbp) test_case.assertTrue(y.dtype is flow.int) test_case.assertTrue(y.shape == x.shape) test_case.assertTrue(y.placement == placement) y_numpy = np.ones(x.numpy().shape) test_case.assertTrue(np.array_equal(y.numpy(), y_numpy)) class TestModule(flow.unittest.TestCase): @unittest.skip("TODO: global ones_like test will fail!") @globaltest def test_ones_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_ones_like_float, _test_ones_like_int] arg_dict["shape"] = [(8, 8), (8, 8, 4), (8, 8, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): arg[0](test_case, placement, sbp, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_pad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest @autotest(n=1, check_graph=True) def _test_pad_1d_impl(test_case, placement, sbp): pad = [random(0, 5).to(int) for i in range(2)] x = random_tensor( ndim=3, dim0=8, dim1=random(2, 8).to(int) * 8, dim2=random(2, 8).to(int) * 8 ).to_global(placement=placement, sbp=sbp) y = torch.nn.functional.pad(x, pad, mode=oneof("constant", "reflect", "replicate")) return y @autotest(n=1, check_graph=True) def _test_pad_2d_impl(test_case, placement, sbp): pad = [random(0, 5).to(int) for i in range(4)] x = random_tensor( ndim=4, dim0=8, dim1=8, dim2=random(2, 8).to(int) * 8, dim3=random(2, 8).to(int) * 8, ).to_global(placement=placement, sbp=sbp) y = torch.nn.functional.pad(x, pad, mode=oneof("constant", "reflect", "replicate")) return y class TestPad(flow.unittest.TestCase): @globaltest def test_pad_1d(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_pad_1d_impl(test_case, placement, sbp) _test_pad_2d_impl(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_partical_fc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestParitalFC(flow.unittest.TestCase): @globaltest def test_parital_fc(test_case): placement = flow.placement.all("cuda") w = flow.randn(5000, 128, placement=placement, sbp=flow.sbp.split(0)) label = flow.randint( 0, 5000, (512,), placement=placement, sbp=flow.sbp.split(0) ) num_sample = 500 out = flow.distributed_partial_fc_sample(w, label, num_sample) test_case.assertTrue(out[0].shape == flow.Size([512])) test_case.assertTrue(out[1].shape == flow.Size([500])) test_case.assertTrue(out[2].shape == flow.Size([500, 128])) w = flow.randn(5000, 128, placement=placement, sbp=flow.sbp.broadcast) label = flow.randint( 0, 5000, (512,), placement=placement, sbp=flow.sbp.split(0) ) num_sample = 500 out = flow.distributed_partial_fc_sample(w, label, num_sample) test_case.assertTrue(out[0].shape == flow.Size([512])) test_case.assertTrue(out[1].shape == flow.Size([500])) test_case.assertTrue(out[2].shape == flow.Size([500, 128])) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_permute.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_permute4d_tensor_with_random_data(test_case, placement, sbp): ndim = 4 permute_list = [1, 2, 3, 0] x = random_tensor( ndim=ndim, dim0=8, dim1=8, dim2=random(2, 8).to(int), dim3=random(2, 8).to(int), ).to_global(placement=placement, sbp=sbp) y = x.permute(permute_list) return y class TestModule(flow.unittest.TestCase): @globaltest def test_permute4d_tensor_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_permute4d_tensor_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_rand.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_rand(test_case, shape, placement, sbp): x = flow.rand(*shape, placement=placement, sbp=sbp) test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_rand(test_case, shape, placement, sbp): class GlobalRandGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.rand(*shape, placement=placement, sbp=sbp) return x model = GlobalRandGraph() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestRandGlobal(flow.unittest.TestCase): @globaltest def test_rand_global(test_case): shapes = [(8,), (8, 8,), (8, 8, 8)] for shape in shapes: for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(shape), except_partial_sum=True ): _test_global_rand(test_case, shape, placement, sbp) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_rand_graph(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(8,), (8, 8,), (8, 8, 8)] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): shape = args["shape"] placement = args["placement"] for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True): _test_graph_rand(test_case, shape, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_randint.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_randint(test_case, shape, placement, sbp, dtype): x = flow.randint(1, 10, shape, placement=placement, sbp=sbp, dtype=dtype) test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) test_case.assertEqual(x.dtype, dtype) def _test_graph_randint(test_case, shape, placement, sbp, dtype): class GlobalRandintGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.randint(1, 10, shape, placement=placement, sbp=sbp, dtype=dtype) return x model = GlobalRandintGraph() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) test_case.assertEqual(x.dtype, dtype) class TestRandintGlobal(flow.unittest.TestCase): @globaltest def test_randint_global(test_case): shapes = [(8,), (8, 8,), (8, 8, 8)] dtypes = [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ] for shape in shapes: for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(shape), except_partial_sum=True ): for dtype in dtypes: _test_global_randint(test_case, shape, placement, sbp, dtype) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_randint_graph(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(8,), (8, 8,), (8, 8, 8)] arg_dict["dtype"] = [ flow.uint8, flow.int32, flow.float32, ] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): shape = args["shape"] placement = args["placement"] dtype = args["dtype"] for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True): _test_graph_randint(test_case, shape, placement, sbp, dtype) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_randint_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_consistent_randint_like(test_case, shape, placement, sbp, dtype): x_ = flow.randint(1, 10, shape) x = flow.randint_like(x_, 1, 10, placement=placement, sbp=sbp, dtype=dtype) test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) test_case.assertEqual(x.dtype, dtype) def _test_graph_randint_like(test_case, shape, placement, sbp, dtype): class ConsistentRandIntLikeGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x_ = flow.randint(1, 10, shape) x = flow.randint_like(x_, 1, 10, placement=placement, sbp=sbp, dtype=dtype) return x model = ConsistentRandIntLikeGraph() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) test_case.assertEqual(x.dtype, dtype) class TestRandIntLikeConsistent(flow.unittest.TestCase): @globaltest def test_randint_like_consistent(test_case): shapes = [(8,), (8, 8,), (8, 8, 8)] dtypes = [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ] for shape in shapes: for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(shape), except_partial_sum=True ): for dtype in dtypes: _test_consistent_randint_like( test_case, shape, placement, sbp, dtype ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_randint_like_graph(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(8,), (8, 8,), (8, 8, 8)] arg_dict["dtype"] = [ flow.uint8, flow.int32, flow.float32, ] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): shape = args["shape"] placement = args["placement"] dtype = args["dtype"] for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True): _test_graph_randint_like(test_case, shape, placement, sbp, dtype) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_randn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import numpy as np import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_randn(test_case, shape, placement, sbp): x1 = flow.randn(*shape, placement=placement, sbp=sbp) x2 = flow.randn(*shape, placement=placement, sbp=sbp) test_case.assertTrue(not np.allclose(x1.numpy(), x2.numpy(), atol=1e-4, rtol=1e-4)) test_case.assertEqual(x1.shape, flow.Size(shape)) test_case.assertEqual(x1.sbp, sbp) test_case.assertEqual(x1.placement, placement) def _test_different_dtype(test_case, shape, placement, sbp): x1 = flow.randn(*shape, dtype=flow.float32, placement=placement, sbp=sbp) x2 = flow.randn(*shape, dtype=flow.float64, placement=placement, sbp=sbp) test_case.assertTrue(not np.allclose(x1.numpy(), x2.numpy(), atol=1e-4, rtol=1e-4)) test_case.assertEqual(x1.shape, flow.Size(shape)) def _test_backward(test_case, shape, placement, sbp): x = flow.randn(*shape, placement=placement, sbp=sbp, requires_grad=True) y = x.sum() y.backward() test_case.assertTrue( np.allclose(np.ones(shape), x.grad.numpy(), atol=1e-4, rtol=1e-4) ) def _test_with_generator(test_case, shape, placement, sbp): gen = flow.Generator() gen.manual_seed(0) y1 = flow.randn(*shape, placement=placement, sbp=sbp, generator=gen) gen.manual_seed(0) y2 = flow.randn(*shape, placement=placement, sbp=sbp, generator=gen) test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) def _test_randn_tuple_shape(test_case, shape, placement, sbp): y1 = flow.randn(*shape, placement=placement, sbp=sbp) y2 = flow.randn(*shape, placement=placement, sbp=sbp) test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy())) test_case.assertTrue(shape == y1.shape) def _test_graph_randn(test_case, shape, placement, sbp): class GlobalRandnGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.randn(*shape, placement=placement, sbp=sbp) return x model = GlobalRandnGraph() x = model() test_case.assertEqual(x.shape, flow.Size(shape)) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) class TestRandnGlobal(flow.unittest.TestCase): @globaltest def test_randn_global(test_case): shapes = [(8,), (8, 8,), (8, 8, 8)] for shape in shapes: for placement in all_placement(): for sbp in all_sbp( placement, max_dim=len(shape), except_partial_sum=True ): _test_global_randn(test_case, shape, placement, sbp) _test_different_dtype(test_case, shape, placement, sbp) _test_backward(test_case, shape, placement, sbp) _test_with_generator(test_case, shape, placement, sbp) _test_randn_tuple_shape(test_case, shape, placement, sbp) @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @globaltest def test_randn_graph(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(8,), (8, 8,), (8, 8, 8)] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] for args in GenArgDict(arg_dict): shape = args["shape"] placement = args["placement"] for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True): _test_graph_randn(test_case, shape, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_random_op_data.py 
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import oneflow as flow
import numpy as np
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *
from oneflow.test_utils.test_util import GenArgDict

_fn_param = {
    "normal": lambda shape, placement, sbp: flow.normal(
        size=shape, mean=0.0, std=1.0, placement=placement, sbp=sbp
    ),
    "rand": lambda shape, placement, sbp: flow.rand(
        size=shape, placement=placement, sbp=sbp
    ),
    "randint": lambda shape, placement, sbp: flow.randint(
        low=0, high=2, size=shape, placement=placement, sbp=sbp
    ),
    "randn": lambda shape, placement, sbp: flow.randn(
        size=shape, placement=placement, sbp=sbp
    ),
}


def _test_data_consistent(test_case, shape, placement, sbp, fn):
    # lazy result
    class GlobalRandnGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()

        def build(self):
            flow.manual_seed(233)
            x = fn(shape, placement, sbp)
            return x

    model = GlobalRandnGraph()
    lazy_x = model()

    # eager result
    flow.manual_seed(233)
    eager_x = fn(shape, placement, sbp)
    test_case.assertTrue(
        np.array_equal(lazy_x.to_local().numpy(), eager_x.to_local().numpy())
    )

    # different data
    eager_x2 = fn(shape, placement, sbp)
    test_case.assertFalse(
        np.array_equal(eager_x.to_local().numpy(), eager_x2.to_local().numpy())
    )


class TestGlobalRandomOpData(flow.unittest.TestCase):
    @unittest.skip("skip for now, because it failed 4 times in the past week")
    @globaltest
    def test_random_op_data_consistent_with_eager_and_lazy(test_case):
        shape = (8, 8)
        for placement in all_placement():
            for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True):
                for _, fn in _fn_param.items():
                    _test_data_consistent(test_case, shape, placement, sbp, fn=fn)

    @globaltest
    @oneflow.unittest.skip_unless_1n4d()
    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_random_op_data_correctness(test_case):
        shape = (8, 8)
        # sbp[0] splits dim 0 across the two rows of the rank hierarchy,
        # sbp[1] broadcasts within each row.
        sbp = [flow.sbp.split(0), flow.sbp.broadcast]
        for device in ["cpu", "cuda"]:
            placement = flow.placement(device, [[0, 1], [2, 3]])
            for _, fn in _fn_param.items():
                flow.manual_seed(233)
                local_tensor = fn(shape, placement, sbp).to_local().cpu()
                # broadcast local data for each rank
                rank_to_tensor = [
                    local_tensor
                    if rank_id == flow.env.get_rank()
                    else flow.empty(local_tensor.shape, dtype=local_tensor.dtype)
                    for rank_id in range(4)
                ]
                for rank_id in range(4):
                    flow.comm.broadcast(rank_to_tensor[rank_id], rank_id)
                np_local = [x.numpy() for x in rank_to_tensor]
                # With ranks [[0, 1], [2, 3]], ranks within a row share the same
                # (broadcast) data, while the two rows hold different split halves.
                # rank0 == rank1
                test_case.assertTrue(np.array_equal(np_local[0], np_local[1]))
                # rank2 == rank3
                test_case.assertTrue(np.array_equal(np_local[2], np_local[3]))
                # rank0 != rank2
                test_case.assertFalse(np.array_equal(np_local[0], np_local[2]))
                # rank1 != rank3
                test_case.assertFalse(np.array_equal(np_local[1], np_local[3]))


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_global_randperm.py
================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow as flow import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgDict def _test_global_randperm(test_case, N, placement, sbp, dtype): x = flow.randperm(N, placement=placement, sbp=sbp, dtype=dtype) # TODO:Synchronously get a global random seed, and then each rank sets its own seed in manual_seeds test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def _test_graph_randperm(test_case, N, placement, sbp, dtype): class GlobalRandpermGraph(flow.nn.Graph): def __init__(self,): super().__init__() def build(self): x = flow.randperm(N, placement=placement, sbp=sbp, dtype=dtype) return x model = GlobalRandpermGraph() x = model() y1 = x.to_global(placement=placement, sbp=sbp) y1_np_sort = np.sort(y1.numpy()) y2 = np.arange(N) test_case.assertTrue(np.allclose(y1_np_sort, y2, atol=1e-4, rtol=1e-4)) test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) @unittest.skip("This fails in multi-gpu") class TestRandpermGlobal(flow.unittest.TestCase): @globaltest def test_randperm_global(test_case): RandNs = [i for i in range(10, 50, 10)] # TODO support uint8,int8,int64,float32,float64,data type test Dtypes = [ flow.int32, ] for N in RandNs: for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): for dtype in Dtypes: _test_global_randperm(test_case, N, placement, sbp, dtype) @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @globaltest def test_randperm_graph(test_case): arg_dict = OrderedDict() arg_dict["N"] = [i for i in range(10, 50, 10)] arg_dict["placement"] = [ # 1d flow.placement("cpu", ranks=[0, 1]), flow.placement("cuda", ranks=[0, 1]), # 2d flow.placement("cpu", ranks=[[0, 1],]), flow.placement("cuda", ranks=[[0, 1],]), ] arg_dict["dtype"] = [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ] for args in GenArgDict(arg_dict): N = args["N"] placement = args["placement"] dtype = args["dtype"] for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True): _test_graph_randperm(test_case, N, placement, sbp, dtype) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_reciprocal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_reciprocal_impl(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for _ in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) z = torch.reciprocal(y) return z class TestReciprocalGlobal(flow.unittest.TestCase): @globaltest def test_reciprocal(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_reciprocal_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_reflection_pad2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False) def _test_reflection_pad2d_impl(test_case, padding, placement, sbp): m = torch.nn.ReflectionPad2d(padding=padding) dims = [random(2, 4) * 8 for _ in range(4)] x = random_tensor(4, *dims) y = x.to_global(placement=placement, sbp=sbp) z = m(y) return z class TestReflectionPad2dGlobal(flow.unittest.TestCase): @globaltest def test_reflection_pad2d(test_case): padding = [ (2, 2, 1, 1), 1, (1, 0, 1, 0), (0, 1, 0, 1), ] for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): for pad in padding: _test_reflection_pad2d_impl(test_case, pad, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_repeat.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_repeat_impl(test_case, ndim, placement, sbp): dims = [random(1, 4).to(int).value() * 8 for _ in range(ndim)] repeat_size = [random(1, 3).to(int).value() for _ in range(ndim)] x = random_tensor(ndim, *dims) y = x.to_global(placement=placement, sbp=sbp) z = y.repeat(repeat_size) return z class TestRepeatGlobal(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") @globaltest def test_repeat(test_case): # random ndim in range [1,3] ndim = random(1, 4).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_repeat_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_replication_pad2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False) def _test_replication_pad2d_impl(test_case, padding, placement, sbp): m = torch.nn.ReplicationPad2d(padding=padding) dims = [random(2, 4) * 8 for _ in range(4)] x = random_tensor(4, *dims) y = x.to_global(placement=placement, sbp=sbp) z = m(y) return z class TestReplicationPad2dGlobal(flow.unittest.TestCase): @globaltest def test_replication_pad2d(test_case): padding = [ (2, 2, 1, 1), 1, (1, 0, 1, 0), (0, 1, 0, 1), ] for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): for pad in padding: _test_replication_pad2d_impl(test_case, pad, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_reshape.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_reshape_impl(test_case, pair, placement, sbp): shape, to_shape = pair x = random_tensor(len(shape), *shape) y = x.to_global(placement=placement, sbp=sbp) z = y.reshape(to_shape) return z def _test_reshape_like_impl(test_case, pair, placement, in_sbp, like_sbp): shape, to_shape = pair nd_arr = np.random.rand(*shape) np_out = nd_arr.reshape(to_shape) x = flow.tensor(nd_arr) like = flow.empty(to_shape) y = x.to_global(flow.placement.all("cpu"), flow.sbp.broadcast).to_global( placement=placement, sbp=in_sbp ) like = like.to_global(flow.placement.all("cpu"), flow.sbp.broadcast).to_global( placement=placement, sbp=like_sbp ) z = flow._C.reshape_like(y, like) local_z = z.to_global( placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] ).to_local() if flow.env.get_rank() == 0: test_case.assertTrue(np.array_equal(np_out, local_z.numpy())) class TestReshapeGlobal(flow.unittest.TestCase): @globaltest def test_reshape(test_case): shape_pairs = [ ((8, 16), (8 * 16,)), ((8, 16), (8 * 4, 4)), ((8, 16, 24), (64, 6, 8)), ((8, 16), (64, 1, -1)), ((8, 16), (-1,)), ] for pair in shape_pairs: for placement in all_placement(): for sbp in all_sbp(placement, max_dim=len(pair[0])): _test_reshape_impl(test_case, pair, placement, sbp) @globaltest def test_reshape_like(test_case): shape_pairs = [ ((8, 16), (8 * 16,)), ((8, 16), (8 * 2, 8)), ((8, 16, 24), (64, 48)), ] for pair in shape_pairs: for placement in all_placement(): for in_sbp in all_sbp(placement, max_dim=len(pair[0])): for like_sbp in all_sbp(placement, max_dim=len(pair[1])): _test_reshape_like_impl( test_case, pair, placement, in_sbp, like_sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_rnn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest import torch from oneflow.test_utils.automated_test_util.generators import * from oneflow.test_utils.automated_test_util.torch_flow_dual_object import globaltest from oneflow.test_utils.test_util import GenArgDict def _compare_torch_and_oneflow( test_case, m_torch, m_flow, placement, module_sbp, in_sbp, input_size ): torch_state_dict = m_torch.state_dict() new_dict = {} for k, v in torch_state_dict.items(): new_dict[k] = v.detach().numpy() m_flow.load_state_dict(new_dict) m_flow = m_flow.to_global(flow.placement.all("cpu"), flow.sbp.broadcast).to_global( placement=placement, sbp=[module_sbp for _ in range(len(placement.ranks.shape))] ) x = np.random.rand(32, 16, input_size).astype(np.float32) x_torch = torch.tensor(x, dtype=torch.float32, requires_grad=True) x_flow = ( flow.tensor(x, dtype=flow.float32, requires_grad=True) .to_global(flow.placement.all("cpu"), flow.sbp.broadcast) .to_global(placement=placement, sbp=in_sbp) ) out_torch, hid_torch = m_torch(x_torch) out_flow, hid_flow = m_flow(x_flow) # check forward local_output = out_flow.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ).to_local() if flow.env.get_rank() == 0: test_case.assertTrue( np.allclose( out_torch.cpu().detach().numpy(), local_output.numpy(), rtol=1e-05, atol=1e-05, ) ) # check backward out_torch.sum().backward() out_flow.sum().backward() local_x_grad = x_flow.to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ).to_local() if flow.env.get_rank() == 0: test_case.assertTrue( np.allclose( x_torch.cpu().detach().numpy(), local_x_grad.numpy(), rtol=1e-05, atol=1e-05, ) ) def _test_rnn_impl( test_case, placement, module_sbp, in_sbp, input_size, hidden_size, num_layers, nonlinearity, bias, batch_first, dropout, bidirectional, ): rnn_torch = torch.nn.RNN( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, nonlinearity=nonlinearity, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, ) rnn_flow = flow.nn.RNN( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, nonlinearity=nonlinearity, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, ) _compare_torch_and_oneflow( test_case, rnn_torch, rnn_flow, placement, module_sbp, in_sbp, input_size ) def _test_lstm_impl( test_case, placement, module_sbp, in_sbp, input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, proj_size, ): lstm_torch = torch.nn.LSTM( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, proj_size=proj_size, ) lstm_flow = flow.nn.LSTM( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, proj_size=proj_size, ) _compare_torch_and_oneflow( test_case, lstm_torch, lstm_flow, placement, module_sbp, in_sbp, input_size ) def _test_gru_impl( test_case, placement, module_sbp, in_sbp, input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, ): gru_torch = torch.nn.GRU( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, ) gru_flow = flow.nn.GRU( input_size=input_size, hidden_size=hidden_size, 
num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, ) _compare_torch_and_oneflow( test_case, gru_torch, gru_flow, placement, module_sbp, in_sbp, input_size ) class TestRNNGlobal(oneflow.unittest.TestCase): @globaltest def test_rnn(test_case): arg_dict = OrderedDict() arg_dict["input_size"] = [ 1, ] arg_dict["hidden_size"] = [ 1, ] arg_dict["num_layers"] = [ 1, ] arg_dict["nonlinearity"] = ["tanh", "relu"] arg_dict["bias"] = [True, False] arg_dict["batch_first"] = [True, False] arg_dict["dropout"] = [ 0, ] arg_dict["bidirectional"] = [True, False] module_sbp = flow.sbp.broadcast for args in GenArgDict(arg_dict): for placement in all_placement(): for in_sbp in all_sbp(placement, max_dim=3, valid_split_axis=1): _test_rnn_impl(test_case, placement, module_sbp, in_sbp, **args) @globaltest def test_lstm(test_case): arg_dict = OrderedDict() arg_dict["input_size"] = [ 1, ] arg_dict["hidden_size"] = [ 2, ] arg_dict["num_layers"] = [ 1, ] arg_dict["bias"] = [True, False] arg_dict["batch_first"] = [True, False] arg_dict["dropout"] = [ 0, ] arg_dict["bidirectional"] = [True, False] arg_dict["proj_size"] = [0, 1] module_sbp = flow.sbp.broadcast for args in GenArgDict(arg_dict): for placement in all_placement(): for in_sbp in all_sbp(placement, max_dim=3, valid_split_axis=1): _test_lstm_impl(test_case, placement, module_sbp, in_sbp, **args) @globaltest def test_gru(test_case): arg_dict = OrderedDict() arg_dict["input_size"] = [ 1, ] arg_dict["hidden_size"] = [ 1, ] arg_dict["num_layers"] = [ 1, ] arg_dict["bias"] = [True, False] arg_dict["batch_first"] = [True, False] arg_dict["dropout"] = [ 0, ] arg_dict["bidirectional"] = [True, False] module_sbp = flow.sbp.broadcast for args in GenArgDict(arg_dict): for placement in all_placement(): for in_sbp in all_sbp(placement, max_dim=3, valid_split_axis=1): _test_gru_impl(test_case, placement, module_sbp, in_sbp, **args) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_rnn_cell.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False) def _test_rnn_relu_cell(test_case, placement, sbp): batch_size = random(2, 3) * 8 time_steps = random(2, 3) * 8 input_size = random(2, 3) * 8 hidden_size = random(2, 3) * 8 has_bias = random().to(bool) m = torch.nn.RNNCell( input_size=input_size, hidden_size=hidden_size, bias=has_bias, nonlinearity="relu", ) weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True) m.weight_ih = torch.nn.Parameter( m.weight_ih.to_global(placement=placement, sbp=weight_sbp) ) m.weight_hh = torch.nn.Parameter( m.weight_hh.to_global(placement=placement, sbp=weight_sbp) ) if m.bias_ih is not None: # bias is 1-d tensor bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True) m.bias_ih = torch.nn.Parameter( m.bias_ih.to_global(placement=placement, sbp=bias_sbp) ) m.bias_hh = torch.nn.Parameter( m.bias_hh.to_global(placement=placement, sbp=bias_sbp) ) input_sbp = random_sbp(placement, max_dim=3, valid_split_axis=1) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to_global(placement=placement, sbp=input_sbp) hx = random_tensor( ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False ).to_global(placement=placement, sbp=sbp) for i in range(time_steps.to(int).value()): hx = m(input[i], hx) return hx @autotest(n=1, check_graph=False) def _test_rnn_tanh_cell(test_case, placement, sbp): batch_size = random(2, 3) * 8 time_steps = random(2, 3) * 8 input_size = random(2, 3) * 8 hidden_size = random(2, 3) * 8 has_bias = random().to(bool) m = torch.nn.RNNCell( input_size=input_size, hidden_size=hidden_size, bias=has_bias, nonlinearity="tanh", ) weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True) m.weight_ih = torch.nn.Parameter( m.weight_ih.to_global(placement=placement, sbp=weight_sbp) ) m.weight_hh = torch.nn.Parameter( m.weight_hh.to_global(placement=placement, sbp=weight_sbp) ) if m.bias_ih is not None: # bias is 1-d tensor bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True) m.bias_ih = torch.nn.Parameter( m.bias_ih.to_global(placement=placement, sbp=bias_sbp) ) m.bias_hh = torch.nn.Parameter( m.bias_hh.to_global(placement=placement, sbp=bias_sbp) ) input_sbp = random_sbp(placement, max_dim=3, valid_split_axis=1) input = random_tensor( ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size ).to_global(placement=placement, sbp=input_sbp) hx = random_tensor( ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False ).to_global(placement=placement, sbp=sbp) for i in range(time_steps.to(int).value()): hx = m(input[i], hx) return hx @unittest.skip("TODO(depeng): fails often on 4 GPUs") class TestRNNCellGlobal(flow.unittest.TestCase): @globaltest def test_rnn_relu_cell(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_rnn_relu_cell(test_case, placement, sbp) @globaltest def test_rnn_tanh_cell(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_rnn_tanh_cell(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_roi_align.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest import torch as pytorch import torchvision from oneflow.test_utils.automated_test_util import * def _get_np_rois(): random_img_idx = np.asarray( [random(0, 2).to(int).value() for _ in range(200)] ).reshape((200, 1)) random_box_idx = np.asarray( [random(0, 64 * 64).to(float).value() for _ in range(400)] ).reshape((200, 2)) def get_h_w(idx1, idx2): if idx1 > idx2: idx1, idx2 = idx2, idx1 h1 = idx1 // 64 w1 = idx1 % 64 h2 = idx2 // 64 w2 = idx2 % 64 return [x / 2 for x in [h1, w1, h2, w2]] zipped = zip(random_box_idx[:, 0], random_box_idx[:, 1]) concated = [get_h_w(idx1, idx2) for (idx1, idx2) in zipped] concated = np.array(concated) rois = np.hstack((random_img_idx, concated)).astype(np.float32) return rois def _test_roi_align(test_case, placement, rois_sbp): dims = [8, 8, 64, 64] x = random_tensor(4, *dims).to_global( placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))], ) x.oneflow = x.oneflow.detach().requires_grad_() x.pytorch = x.pytorch.detach().requires_grad_() def get_h_w(idx1, idx2): if idx1 > idx2: idx1, idx2 = idx2, idx1 h1 = idx1 // 64 w1 = idx1 % 64 h2 = idx2 // 64 w2 = idx2 % 64 return [x / 2 for x in [h1, w1, h2, w2]] np_rois = _get_np_rois() of_rois = ( flow.tensor(np_rois, dtype=flow.float) .to_global(placement=flow.placement.all("cpu"), sbp=[flow.sbp.broadcast,]) .to_global(placement, rois_sbp) ) torch_rois = pytorch.tensor(np_rois) of_out = flow.roi_align(x.oneflow, of_rois, 2.0, 14, 14, 2, True) torch_out = torchvision.ops.roi_align( x.pytorch, torch_rois, spatial_scale=2.0, output_size=[14, 14], sampling_ratio=2, aligned=True, ) # compare output of_local = of_out.to_global( placement=flow.placement.all("cpu"), sbp=[flow.sbp.broadcast,] ).to_local() test_case.assertTrue( np.allclose( of_local.numpy(), torch_out.detach().cpu().numpy(), rtol=1e-04, atol=1e-4 ) ) # compare backward of_out.sum().backward() torch_out.sum().backward() of_input_grad = x.oneflow.grad.to_global( placement=flow.placement.all("cpu"), sbp=[flow.sbp.broadcast,] ).to_local() torch_input_grad = x.pytorch.grad.detach().cpu() test_case.assertTrue( np.allclose( of_input_grad.numpy(), torch_input_grad.numpy(), rtol=1e-04, atol=1e-4 ) ) def _test_roi_align_in_fixed_data_impl(test_case, placement, sbp): from test_roi_align import input_np, rois_np, input_grad_np input = ( flow.tensor(input_np, dtype=flow.float32) .to_global(flow.placement.all("cpu"), [flow.sbp.broadcast,]) .to_global(placement, sbp) .requires_grad_() ) rois = ( flow.tensor(rois_np, dtype=flow.float32) .to_global(flow.placement.all("cpu"), [flow.sbp.broadcast,]) .to_global( placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] ) ) of_out = flow.roi_align(input, rois, 2.0, 5, 5, 2, True) of_out.sum().backward() test_case.assertTrue( np.allclose(input.grad.numpy(), input_grad_np, rtol=1e-04, atol=1e-4) ) class TestGlobalRoiAlign(flow.unittest.TestCase): # TODO(wyg): It is a bug in pytorch-1.9.0, 
torchvision-0.10.0 and python3.7.10. # Open this test after updating the versions of pytorch in CI. # @globaltest # def test_global_roi_align(test_case): # for placement in all_placement(): # # TODO: roi_align only support gpu # if placement.type == "cpu": # continue # for rois_sbp in all_sbp(placement, max_dim=0, except_partial_sum=True): # _test_roi_align(test_case, placement, rois_sbp) def test_global_roi_align_in_fixed_data(test_case): for placement in all_placement(): # TODO: roi_align only support gpu if placement.type == "cpu": continue for rois_sbp in all_sbp(placement, max_dim=0, except_partial_sum=True): _test_roi_align_in_fixed_data_impl(test_case, placement, rois_sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_roll.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_roll_impl(test_case, placement, sbp): shifts = ( random(-100, 100).to(int).value(), random(-100, 100).to(int).value(), random(-100, 100).to(int).value(), random(-100, 100).to(int).value(), ) dims = (0, 1, 2, 3) x_dims = [random(2, 4) * 8 for _ in range(4)] x = random_tensor(4, *x_dims) y = x.to_global(placement=placement, sbp=sbp) z = torch.roll(y, shifts, dims) return z class TestRollGlobal(flow.unittest.TestCase): @globaltest def test_roll(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_roll_impl(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_round.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False) def _test_round_impl(test_case, ndim, placement, sbp): x_dims = [random(2, 4) * 8 for _ in range(ndim)] x = random_tensor(ndim, *x_dims) y = x.to_global(placement=placement, sbp=sbp) z = torch.round(y) return z class TestRoundGlobal(flow.unittest.TestCase): @globaltest def test_round(test_case): ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_round_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_scatter_nd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_scatter_nd(test_case, placement, sbp): indices = ( flow.tensor(np.array([[1], [6], [4]]), dtype=flow.int) .to_global(flow.placement.all("cpu"), [flow.sbp.broadcast,]) .to_global(placement, sbp) ) update = ( flow.tensor(np.array([10.2, 5.1, 12.7]), dtype=flow.float) .to_global(flow.placement.all("cpu"), [flow.sbp.broadcast,]) .to_global(placement, sbp) .requires_grad_() ) output = flow.scatter_nd(indices, update, [8]) # forward of_local = output.to_global( flow.placement.all("cpu"), [flow.sbp.broadcast,] ).to_local() np_out = np.array([0.0, 10.2, 0.0, 0.0, 12.7, 0.0, 5.1, 0.0]) test_case.assertTrue(np.allclose(of_local.numpy(), np_out, 1e-4, 1e-4)) # backward output.sum().backward() of_grad_local = update.grad.to_global( flow.placement.all("cpu"), [flow.sbp.broadcast,] ).to_local() test_case.assertTrue(np.allclose(of_grad_local.numpy(), np.ones((3)), 1e-4, 1e-4)) class TestScatterNd(flow.unittest.TestCase): @globaltest def test_scatter_nd(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_partial_sum=True, except_split=True): _test_scatter_nd(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_scatter_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=10, auto_backward=True, check_graph=True) def _test_scatter_random_data(test_case, placement): input = random_tensor(ndim=2, dim0=2, dim1=2).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) src = random_tensor(ndim=2, dim0=2, dim1=2).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) index = ( torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64) .to_global(flow.placement.all("cpu"), [flow.sbp.broadcast,]) .to_global(placement, sbp=random_sbp(placement, max_dim=2),) ) dim = random(0, 2).to(int).value() return torch.scatter(input, dim, index, src) @autotest(n=10, auto_backward=True, check_graph=True) def _test_scatter_scalar_random_data(test_case, placement): input = random_tensor(ndim=2, dim0=2, dim1=2).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) index = ( torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64) .to_global(flow.placement.all("cpu"), [flow.sbp.broadcast,]) .to_global(placement, sbp=random_sbp(placement, max_dim=2),) ) dim = random(0, 2).to(int).value() return torch.scatter(input, dim, index, 3.14) @autotest(n=10, auto_backward=True, check_graph=True) def _test_scatter_add_random_data(test_case, placement): input = random_tensor(ndim=2, dim0=2, dim1=2).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) src = random_tensor(ndim=2, dim0=2, dim1=2).to_global( placement=placement, sbp=random_sbp(placement, max_dim=2) ) index = ( torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64) .to_global(flow.placement.all("cpu"), [flow.sbp.broadcast,]) .to_global(placement, sbp=random_sbp(placement, max_dim=2),) ) dim = random(0, 2).to(int).value() return torch.scatter_add(input, dim, index, src) @flow.unittest.skip_unless_1n2d() class TestScatterOps(flow.unittest.TestCase): @globaltest def test_scatter_ops(test_case): for placement in all_placement(): _test_scatter_random_data(test_case, placement) _test_scatter_scalar_random_data(test_case, placement) _test_scatter_add_random_data(test_case, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_searchsorted.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=False) def _test_search_sorted(test_case, placement, sbp, ndim): dims = [random(1, 3) * 8 for _ in range(ndim)] sorted_sequence = random_tensor(ndim, *dims).to_global(placement, sbp) values = random_tensor(ndim, *dims).to_global(placement, sbp) y = torch.searchsorted( sorted_sequence, values, out_int32=oneof(True, False), right=oneof(True, False), ) return y class TestSearchSorted_Global(flow.unittest.TestCase): @globaltest def test_search_sorted(test_case): ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_search_sorted(test_case, placement, sbp, ndim) @autotest(n=1, auto_backward=False, check_graph=False) def _test_search_sorted_scalar(test_case, placement, sbp): dim0 = [random(1, 3) * 8] sorted_sequence = random_tensor(1, *dim0).to_global(placement, sbp) values = 5 y = torch.searchsorted( sorted_sequence, values, out_int32=oneof(True, False), right=oneof(True, False), ) return y class TestSearchSortedScalar_Global(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 8 times in past week") @globaltest def test_search_sorted_scalar(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_search_sorted_scalar(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_sign.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest @autotest(n=1, check_graph=True) def _test_sign_impl(test_case, ndim, placement, sbp): dims = [random(1, 3) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp) y = torch.sign(x) return y class TestSign(flow.unittest.TestCase): @globaltest def test_sign(test_case): for placement in all_placement(): ndim = random(1, 4).to(int).value() for sbp in all_sbp(placement, max_dim=ndim): _test_sign_impl(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_slice.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _check_forward_and_backward(test_case, input, of_out, torch_out): # compare forward test_case.assertTrue( np.array_equal(of_out.numpy(), torch_out.cpu().detach().numpy()) ) # compare backward of_out.sum().backward() torch_out.sum().backward() torch_grad_local = input.pytorch.grad.cpu().detach() test_case.assertTrue( np.array_equal(input.oneflow.grad.numpy(), torch_grad_local.numpy()) ) def _test_slice_random_data(test_case, placement, sbp): dims = [random(1, 2) * 8 for _ in range(2)] input = random_tensor(2, *dims) x = input.to_global(placement=placement, sbp=sbp) slice_tup_list = [[None, None, None], [0, 5, 2]] of_out = flow.slice(x.oneflow, slice_tup_list=slice_tup_list) torch_out = x.pytorch[:, 0:5:2] _check_forward_and_backward(test_case, input, of_out, torch_out) def _test_slice_empty(test_case, placement, sbp): dims = [random(1, 2) * 8 for _ in range(2)] input = random_tensor(2, *dims) x = input.to_global(placement=placement, sbp=sbp) slice_tup_list = [[3, 3, 1], [None, None, None]] of_out = flow.slice(x.oneflow, slice_tup_list=slice_tup_list) torch_out = x.pytorch[3:3:1, :] _check_forward_and_backward(test_case, input, of_out, torch_out) def _test_slice_1dim(test_case, placement, sbp): dims = [random(1, 2) * 8 for _ in range(2)] input = random_tensor(2, *dims) x = input.to_global(placement=placement, sbp=sbp) of_out = x.oneflow[2] torch_out = x.pytorch[2] _check_forward_and_backward(test_case, input, of_out, torch_out) def _test_negative_index(test_case, placement, sbp): dims = [random(1, 2) * 8 for _ in range(2)] input = random_tensor(2, *dims) x = input.to_global(placement=placement, sbp=sbp) of_out = x.oneflow[-1:-6:1, :] torch_out = x.pytorch[-1:-6:1, :] _check_forward_and_backward(test_case, input, of_out, torch_out) def _test_slice_ellipsis_type(test_case, placement, sbp): dims = [random(1, 2) * 8 for _ in range(2)] input = random_tensor(2, *dims) x = input.to_global(placement=placement, sbp=sbp) of_out = x.oneflow[..., :] torch_out = x.pytorch[..., :] _check_forward_and_backward(test_case, input, of_out, torch_out) def _test_slice_with_bool(test_case, placement, sbp): x = random_tensor(2, 8, 8).oneflow > 0.5 x_numpy = x.detach().cpu().numpy() x = x.to_global(placement=placement, sbp=sbp) y = flow.slice(x, slice_tup_list=[[0, 1, 1]]) test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[0:1:1])) @autotest( n=2, auto_backward=False, check_graph=True, ) def _test_slice_with_grad(test_case, placement): sbp = random_sbp(placement, max_dim=2).value() # out_sbp sbp_map = { flow.sbp.broadcast: flow.sbp.broadcast, flow.sbp.split(0): flow.sbp.split(0), flow.sbp.split(1): flow.sbp.partial_sum(), flow.sbp.partial_sum: flow.sbp.partial_sum(), } assert sbp is not None out_sbp = tuple([sbp_map[in_sbp] for in_sbp in sbp]) x = random_tensor(2, 8, 16, requires_grad=True).oneflow x_numpy = x.detach().cpu().numpy() class SliceWithGrad(flow.nn.Module): def __init__(self): super().__init__() self.input_grad = flow.nn.Parameter(flow.zeros(8, 16)) def forward(self, input): x = input + self.input_grad x = x.to_global(placement, sbp) return x[:, :8] slice_with_grad_m = SliceWithGrad().to_global( placement, [flow.sbp.broadcast,] * len(sbp) ) of_sgd = flow.optim.SGD(slice_with_grad_m.parameters(), lr=1.0, momentum=0.0) class SliceTrainGraph(flow.nn.Graph): def 
        __init__(self):
            super().__init__()
            self.module = slice_with_grad_m
            self.add_optimizer(of_sgd)

        def build(self, x):
            out = self.module(x)
            test_case.assertEqual(
                out.sbp,
                out_sbp,
                f"input sbp is {sbp}, but output sbp is {out.sbp} with placement: {placement}",
            )
            z = out.sum()
            z.backward()
            return out

    graph = SliceTrainGraph()
    input = x.to_global(placement=placement, sbp=sbp)
    y = graph(input)
    # output
    test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[:, :8]))
    # input_grad
    x_grad_np = np.zeros((8, 16))
    x_grad_np[:, :8] = 1
    test_case.assertTrue(
        np.array_equal(-graph.module.input_grad.to(flow.Tensor).numpy(), x_grad_np)
    )


class TestSlice(flow.unittest.TestCase):
    @globaltest
    def test_slice(test_case):
        for placement in all_placement():
            for sbp in all_sbp(placement, max_dim=2):
                _test_slice_random_data(test_case, placement, sbp)
                _test_slice_empty(test_case, placement, sbp)
                _test_slice_1dim(test_case, placement, sbp)
                _test_negative_index(test_case, placement, sbp)
                _test_slice_ellipsis_type(test_case, placement, sbp)
                _test_slice_with_bool(test_case, placement, sbp)

    @unittest.skip("skip for now, because it failed 12 times in the past week")
    @globaltest
    def test_graph_slice(test_case):
        for placement in all_placement():
            _test_slice_with_grad(test_case, placement)


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_global_slice_update.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_slice_update(test_case, placement, sbp): input = random_tensor(2, 8, 16, requires_grad=True).oneflow value = random_tensor(2, 8, 8, requires_grad=True).oneflow x = (input + 0).to_global( placement=placement, sbp=sbp ) # add 0 to change to non-leaf tensor y = value.to_global(placement, sbp=sbp) x[:, :8] = y ref_np = input.detach().cpu().numpy() value_np = value.detach().cpu().numpy() # forward ref_np[:, :8] = value_np test_case.assertTrue(x.sbp == sbp) test_case.assertTrue(np.array_equal(x.numpy(), ref_np)) # backward x.sum().backward() # ref grad ref_grad_np = np.ones((8, 16)) ref_grad_np[:, :8] = 0 test_case.assertTrue(np.array_equal(input.grad.numpy(), ref_grad_np)) # value grad value_grad_np = np.ones((8, 8)) test_case.assertTrue(np.array_equal(value.grad.numpy(), value_grad_np)) def _test_graph_slice_update(test_case, placement, sbp): ref = random_tensor(2, 8, 16, requires_grad=True).oneflow value = random_tensor(2, 8, 8, requires_grad=True).oneflow class SliceUpdateWithGrad(flow.nn.Module): def __init__(self): super().__init__() self.ref_grad = flow.nn.Parameter(flow.zeros(8, 16)) self.value_grad = flow.nn.Parameter(flow.zeros(8, 8)) def forward(self, ref, value): x = ref + self.ref_grad y = value + self.value_grad x = x.to_global(placement, sbp) y = y.to_global(placement, sbp) x[:, :8] = y return x slice_update_with_grad_m = SliceUpdateWithGrad().to_global( placement, [flow.sbp.broadcast,] * len(sbp) ) of_sgd = flow.optim.SGD(slice_update_with_grad_m.parameters(), lr=1.0, momentum=0.0) class SliceUpdateTrainGraph(flow.nn.Graph): def __init__(self): super().__init__() self.module = slice_update_with_grad_m self.add_optimizer(of_sgd) def build(self, x, y): out = self.module(x, y) z = out.sum() z.backward() return out graph = SliceUpdateTrainGraph() x = ref.to_global(placement=placement, sbp=sbp) y = value.to_global(placement=placement, sbp=sbp) z = graph(x, y) test_case.assertTrue(z.sbp == sbp) ref_np = ref.detach().cpu().numpy() value_np = value.detach().cpu().numpy() # forward ref_np[:, :8] = value_np test_case.assertTrue(np.array_equal(z.numpy(), ref_np)) # backward # ref grad ref_grad = np.ones((8, 16)) ref_grad[:, :8] = 0 test_case.assertTrue( np.array_equal(-graph.module.ref_grad.to(flow.Tensor).numpy(), ref_grad) ) # value grad value_grad = np.ones((8, 8)) test_case.assertTrue( np.array_equal(-graph.module.value_grad.to(flow.Tensor).numpy(), value_grad) ) class TestGlobalSliceUpdate(flow.unittest.TestCase): @globaltest def test_slice_update(test_case): for placement in all_placement(): for _ in range(2): sbp = random_sbp(placement, max_dim=2).value() _test_slice_update(test_case, placement, sbp) _test_graph_slice_update(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_sort.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=True) def _test_sort_impl(test_case, placement): sbp = random_sbp(placement, max_dim=4) x_dims = [random(2, 4) * 8 for _ in range(4)] x = random_tensor(4, *x_dims) dim = random(0, 4).to(int).value() descending = random().to(bool).value() y = x.to_global(placement=placement, sbp=sbp) sort_result = torch.sort(y, dim=dim, descending=descending) value = sort_result[0] return value class TestSortGlobal(flow.unittest.TestCase): @globaltest def test_sort(test_case): for placement in all_placement(): _test_sort_impl(test_case, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_sparse.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=False) def _test_embedding(test_case, ndim, placement, sbp): emb_size = random() * 8 emb_dim = random() * 8 emb_shape = [emb_size, emb_dim] idx_shape = [random(high=4) * 8 for i in range(ndim)] weight = random_tensor(2, *emb_shape) indices = random_tensor( len(idx_shape), *idx_shape, low=0, high=emb_size, dtype=int ).to_global(placement=placement, sbp=sbp) embedding = torch.nn.Embedding(emb_size, emb_dim, _weight=weight).to_global( placement=placement, sbp=sbp ) output = embedding(indices) return output class TestEmbedding(flow.unittest.TestCase): @globaltest def test_embedding(test_case): ndim = 2 for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): _test_embedding(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_sparse_softmax_cross_entropy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os from collections import OrderedDict import numpy as np import torch import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type from oneflow.test_utils.automated_test_util.generators import * from oneflow.test_utils.automated_test_util.torch_flow_dual_object import globaltest def _compare_eager_global_with_torch( placement, logits_sbp, labels_sbp, data_type, label_type, batch_size, num_classes, ): data_type = type_name_to_flow_type[data_type] label_type = type_name_to_flow_type[label_type] np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32) np_logits = np.random.random((batch_size, num_classes)).astype(np.float32) if flow.env.get_rank() == 0: torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True) torch_labels = torch.tensor(np_labels, dtype=torch.int64) torch_output = torch.nn.functional.cross_entropy( torch_logits, torch_labels, reduction="none" ) torch_output.sum().backward() of_logits = flow.tensor(np_logits, dtype=data_type, requires_grad=True).to_global( flow.placement.all("cpu"), flow.sbp.broadcast ) of_logits = of_logits.to_global(placement, logits_sbp) of_logits.retain_grad() of_labels = flow.tensor(np_labels, dtype=label_type).to_global( flow.placement.all("cpu"), flow.sbp.broadcast ) of_labels = of_labels.to_global(placement, labels_sbp) of_output = flow.nn.functional.sparse_softmax_cross_entropy( labels=of_labels, logits=of_logits ) of_output.sum().backward() of_logits_grad = of_logits.grad.to_global( flow.placement.all("cpu"), flow.sbp.broadcast ) of_logits_grad = of_logits_grad.to_local() of_output = of_output.to_global(flow.placement.all("cpu"), flow.sbp.broadcast) of_output = of_output.to_local() if flow.env.get_rank() == 0: assert np.allclose( of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04 ) assert np.allclose( of_logits_grad.numpy(), torch_logits.grad, rtol=1e-03, atol=1e-04 ) def _compare_lazy_global_with_torch( placement, logits_sbp, labels_sbp, data_type, label_type, batch_size, num_classes, ): data_type = type_name_to_flow_type[data_type] label_type = type_name_to_flow_type[label_type] np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32) np_logits = np.random.random((batch_size, num_classes)).astype(np.float32) if flow.env.get_rank() == 0: torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True) torch_labels = torch.tensor(np_labels, dtype=torch.int64) torch_output = torch.nn.functional.cross_entropy( torch_logits, torch_labels, reduction="none" ) class MyModule(flow.nn.Graph): def __init__(self): super(MyModule, self).__init__() # nn.graph no support get input.grad def build(self, logits, labels): output = flow.nn.functional.sparse_softmax_cross_entropy( labels=labels, logits=logits ) return output of_logits = flow.tensor(np_logits, dtype=data_type, requires_grad=True).to_global( flow.placement.all("cpu"), flow.sbp.broadcast ) of_logits = of_logits.to_global(placement, logits_sbp) of_labels = flow.tensor(np_labels, dtype=label_type).to_global( flow.placement.all("cpu"), flow.sbp.broadcast ) of_labels = of_labels.to_global(placement, labels_sbp) graph = MyModule() of_output = graph(of_logits, of_labels) of_output = of_output.to_global( placement=flow.placement.all("cpu"), sbp=[flow.sbp.broadcast] ) of_output = of_output.to_local() flow._oneflow_internal.eager.multi_client.Sync() if flow.env.get_rank() == 0: assert np.allclose( of_output.numpy(), 
torch_output.detach().numpy(), rtol=1e-03, atol=1e-04 ) class TestGlobalSparseSoftmaxCrossEntropyWithLogits(flow.unittest.TestCase): @globaltest def test_eager_global_sparse_softmax_cross_entropy(test_case): arg_dict = OrderedDict() arg_dict["data_type"] = ["float32", "double"] arg_dict["label_type"] = ["int32", "int64"] arg_dict["batch_size"] = [64] arg_dict["num_classes"] = [1024] for arg in GenArgList(arg_dict): for placement in all_placement(): for logits_sbp in all_sbp(placement, max_dim=2): for labels_sbp in all_sbp(placement, max_dim=1): _compare_eager_global_with_torch( placement, logits_sbp, labels_sbp, *arg ) # TODO: Too many streams will cause bugs, open the graph mode after solving # @globaltest # def test_lazy_global_sparse_softmax_cross_entropy(test_case): # arg_dict = OrderedDict() # arg_dict["data_type"] = ["float32", "double"] # arg_dict["label_type"] = ["int32", "int64"] # arg_dict["batch_size"] = [64] # arg_dict["num_classes"] = [1024] # for arg in GenArgList(arg_dict): # for placement in all_placement(): # for logits_sbp in all_sbp(placement, max_dim=2): # for labels_sbp in all_sbp(placement, max_dim=1): # _compare_lazy_global_with_torch( # placement, logits_sbp, labels_sbp, *arg # ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_split.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_flow_split_with_random_data(test_case, placement, sbp): k0 = random(2, 6) * 8 k1 = random(2, 6) * 8 k2 = random(2, 6) * 8 rand_dim = random(0, 3).to(int) x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to_global( placement=placement, sbp=sbp ) res = torch.split(x, 2, dim=rand_dim) return torch.cat(res, rand_dim) @autotest(n=2, check_graph=True) def _test_flow_split_sizes_with_random_data(test_case, placement, sbp): k0 = random(2, 6) * 8 k1 = 16 k2 = random(2, 6) * 8 x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to_global( placement=placement, sbp=sbp ) res = torch.split(x, [6, 3, 4, 3], dim=1) return torch.cat(res, dim=1) @autotest(n=2, check_graph=True) def _test_flow_split_sizes_neg_dim_with_random_data(test_case, placement, sbp): k0 = random(2, 6) * 8 k1 = 16 k2 = random(2, 6) * 8 x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to_global( placement=placement, sbp=sbp ) res = torch.split(x, [6, 3, 4, 3], dim=-2) return torch.cat(res, dim=1) class TestGlobalSplitModule(flow.unittest.TestCase): @globaltest def test_flow_split_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_flow_split_with_random_data(test_case, placement, sbp) @globaltest def test_flow_split_sizes_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_flow_split_sizes_with_random_data(test_case, placement, sbp) @globaltest def test_flow_split_sizes_neg_dim_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_flow_split_sizes_neg_dim_with_random_data( test_case, placement, sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_sqrt_square_sum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True, rtol=0.5, atol=0.5) def _test_sqrt_sum_with_cpu_random_data(test_case, placement, sbp): x = random_tensor(ndim=4, dim0=8, dim1=32, dim2=40, dim3=64).to_global( placement=placement, sbp=sbp ) y = torch.linalg.norm(x) return y @autotest(n=1, check_graph=True, rtol=0.5, atol=0.5) def _test_scalar_random_data(test_case, placement, sbp): x = random_tensor(ndim=4, dim0=8, dim1=24, dim2=16, dim3=40).to_global( placement=placement, sbp=sbp ) y = torch.linalg.norm(x) return y class TestGlobalLinalgVectorNorm2D(flow.unittest.TestCase): @globaltest def test_sqrt_sum_with_cpu_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_sqrt_sum_with_cpu_random_data(test_case, placement, sbp) @globaltest def test_scalar_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_scalar_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_squeeze.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=True) def _test_squeeze_1d_input(test_case, placement, sbp): x = random_tensor(1, 16, dtype=float).to_global(placement, sbp) y = torch.squeeze(x) return y @autotest(n=1, check_graph=True) def _test_flow_squeeze_with_random_data(test_case, placement, sbp): x = random_tensor(2, 8, 16).to_global(placement, sbp) y = torch.squeeze(x, random(0, 2).to(int)) return y @autotest(n=1, check_graph=True) def _test_squeeze_with_0_size_data(test_case, placement, sbp): x = random_tensor(3, 8, 16, 0).to_global(placement, sbp) y = torch.squeeze(x) return y class TestGlobalSqueeze(flow.unittest.TestCase): @globaltest def test_squeeze_1d_input(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_squeeze_1d_input(test_case, placement, sbp) @globaltest def test_flow_squeeze_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_flow_squeeze_with_random_data(test_case, placement, sbp) @globaltest def test_squeeze_with_0_size_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_squeeze_with_0_size_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_stack.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=2, check_graph=True) def _test_stack_with_random_data(test_case, placement, sbp): x = random_tensor(ndim=4, dim0=8, dim1=16, dim2=24, dim3=8).to_global( placement, sbp ) y = random_tensor(ndim=4, dim0=8, dim1=16, dim2=24, dim3=8).to_global( placement, sbp ) out = torch.stack((x, y), dim=random(low=-5, high=5).to(int)) return out @unittest.skip("backward of stack with random diff has bug.") class TestStackModule(flow.unittest.TestCase): @globaltest def test_stack_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_stack_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_stateful_kernel_with_cache.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_global_stateful_kernel_with_inpersistent_state(test_case, placement, sbp): x = ( flow.arange(64) .reshape(8, 8) .to_global(flow.placement.all("cpu"), flow.sbp.broadcast) ) x = x.to_global(placement, sbp) y = x[0:3, 0:1] y_np = np.array([[0], [8], [16]]) test_case.assertTrue(np.array_equal(y.numpy(), y_np,)) x = x.to_global(flow.placement.all("cpu"), sbp=flow.sbp.split(1)) y = x[0:3, 0:1] test_case.assertTrue(np.array_equal(y.numpy(), y_np,)) class TestStatefulKernelWithInpersistentState(flow.unittest.TestCase): @globaltest def test_global_stateful_kernel_with_inpersistent_state(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_stateful_kernel_with_inpersistent_state( test_case, placement, sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_std.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, auto_backward=True, check_graph=True) def _test_global_std_flow_with_random_data(test_case, placement, sbp): dim = random(low=0, high=4).to(int) x = random_tensor( ndim=4, dim0=random(1, 4) * 8, dim1=random(1, 4) * 8, dim2=random(1, 4) * 8, dim3=random(1, 4) * 8, ).to_global(placement, sbp) z = torch.std(x, dim=dim, unbiased=random().to(bool), keepdim=random().to(bool),) return z @autotest(n=1, auto_backward=True, check_graph=True) def _test_global_std_tensor_with_random_data(test_case, placement, sbp): dim = random(low=0, high=4).to(int) x = random_tensor( ndim=4, dim0=random(1, 4) * 8, dim1=random(1, 4) * 8, dim2=random(1, 4) * 8, dim3=random(1, 4) * 8, ).to_global(placement, sbp) z = x.std(dim=dim, keepdim=random().to(bool),) return z class TestGlobalStd(flow.unittest.TestCase): @globaltest def test_global_std_flow_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_std_flow_with_random_data(test_case, placement, sbp) @globaltest def test_global_std_tensor_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_std_tensor_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_sub.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, auto_backward=False, check_graph=True) def _test_global_sub(test_case, placement, sbp): x = random_tensor(2, 8, 8).to_global(placement=placement, sbp=sbp) y = random_tensor(2, 8, 8).to_global(placement=placement, sbp=sbp) out1 = x - y out2 = x - 2 out3 = 2 - x out4 = torch.sub(x, y) return out1, out2, out3, out4 @autotest(n=1, auto_backward=False, check_graph=True) def _test_global_sub_with_0_size_data(test_case, placement, sbp): device = random_device() x = random_tensor(2, 0, 8).to_global(placement=placement, sbp=sbp) out1 = x - 2 out2 = 2 - x return out1, out2 class TestGlobalSubModule(flow.unittest.TestCase): @globaltest def test_global_sub(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_sub(test_case, placement, sbp) @globaltest def test_global_sub_with_0_size_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2, valid_split_axis=1): _test_global_sub_with_0_size_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_sum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=True, rtol=1e-3) def _test_global_sum_against_pytorch(test_case, placement, sbp): x = random_tensor(4, 8, 16, 8, 24).to_global(placement, sbp) y = torch.sum(x) return y @autotest(n=1, check_graph=True) def _test_global_sum_with_0_size_tensor(test_case, placement, sbp): x = random_tensor(4, 8, 16, 0, 24).to_global(placement, sbp) y = torch.sum(x, dim=random(0, 3).to(int)) return y class TestGlobalSumModule(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") @globaltest def test_global_sum_against_pytorch(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_global_sum_against_pytorch(test_case, placement, sbp) @globaltest def test_global_sum_with_0_size_tensor(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4, valid_split_axis=[0, 1, 3]): _test_global_sum_with_0_size_tensor(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_tensor_new.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=False, auto_backward=False) def _test_tensor_new(test_case, placement, sbp): x = random_tensor(1, 64).to_global(placement=placement, sbp=sbp).oneflow y = x.new() test_case.assertTrue(x.dtype == y.dtype) for x_sbp, y_sbp in zip(x.sbp, y.sbp): test_case.assertTrue(x_sbp == y_sbp) test_case.assertTrue(x.placement == y.placement) y = x.new(1, 2, 3) test_case.assertTrue(list(y.shape) == [1, 2, 3]) test_case.assertTrue(x.dtype == y.dtype) for x_sbp, y_sbp in zip(x.sbp, y.sbp): test_case.assertTrue(x_sbp == y_sbp) test_case.assertTrue(x.placement == y.placement) y = x.new([1, 2, 3]) test_case.assertTrue(list(y.shape) == [3]) test_case.assertTrue(x.dtype == y.dtype) for x_sbp, y_sbp in zip(x.sbp, y.sbp): test_case.assertTrue(x_sbp == y_sbp) test_case.assertTrue(x.placement == y.placement) class TestTensorNew(flow.unittest.TestCase): @globaltest def test_tensor_new(test_case): for placement in all_placement(): for sbp in all_sbp(placement, valid_split_axis=0): _test_tensor_new(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_tensor_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_type_as(test_case, shape, src_dtype, tgt_dtype, placement, sbp): np_input = np.random.rand(*shape) input = flow.tensor(np_input, dtype=src_dtype).to_global(placement, sbp) target = flow.tensor(np_input, dtype=tgt_dtype).to_global(placement, sbp) input = input.type_as(target) test_case.assertEqual(input.dtype, target.dtype) def _test_local_to_global_type_as( test_case, shape, src_dtype, tgt_dtype, placement, sbp ): np_input = np.random.rand(*shape) input = random_tensor(ndim=len(shape)).oneflow.to_local() target = flow.tensor(np_input, dtype=tgt_dtype).to_global(placement, sbp) input = input.type_as(target) test_case.assertEqual(input.dtype, target.dtype) test_case.assertEqual(input.placement, target.placement) test_case.assertEqual(input.sbp, target.sbp) def _test_global_to_local_type_as( test_case, shape, src_dtype, tgt_dtype, placement, sbp ): np_input = np.random.rand(*shape) input = flow.tensor(np_input, dtype=tgt_dtype).to_global(placement, sbp) target = random_tensor(ndim=len(shape)).to(random_device()).oneflow.to_local() input = input.type_as(target) test_case.assertEqual(input.dtype, target.dtype) test_case.assertEqual(input.device, target.device) def _test_is_floating_point(test_case, shape, dtype, placement, sbp): np_input = np.random.rand(*shape) input = flow.tensor(np_input, dtype=dtype).to_global(placement, sbp) output = input.is_floating_point() if input.dtype in (flow.float, flow.float16, flow.float32, flow.double): test_case.assertEqual(output, True) else: test_case.assertEqual(output, False) @autotest(n=1, check_graph=True) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def _test_global_cuda(test_case, placement, sbp): x = random_tensor(2, 8, 16).to_global(placement, sbp) x = x.cuda() y = x.sum() return y class TestGlobalCuda(flow.unittest.TestCase): @globaltest def test_global_cuda(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_cuda(test_case, placement, sbp) @autotest(n=1, check_graph=True) def _test_global_cpu(test_case, placement, sbp): x = random_tensor(2, 8, 16).to_global(placement, sbp) x = x.cpu() y = x.sum() return y # PyTorch error if open auto_backward: # element 0 of tensors does not require grad and does not have a grad_fn @autotest(n=1, auto_backward=False, check_graph=True) def _test_global_long(test_case, placement, sbp): x = random_tensor(2, 8, 16, requires_grad=True).to_global(placement, sbp) y = x.long() test_case.assertFalse(y.oneflow.requires_grad) return y @autotest(n=1, auto_backward=False, check_graph=True) def _test_global_int(test_case, placement, sbp): x = random_tensor(2, 8, 16, requires_grad=True).to_global(placement, sbp) y = x.int() test_case.assertFalse(y.oneflow.requires_grad) return y @autotest(n=1, auto_backward=False, check_graph=True) def _test_global_float(test_case, placement, sbp): x = random_tensor(2, 8, 16, dtype=int).to_global(placement, sbp) y = x.float() return y @autotest(n=1, auto_backward=False, check_graph=True) def _test_global_double(test_case, placement, sbp): x = random_tensor(2, 8, 16, dtype=int).to_global(placement, sbp) y = x.double() return y @autotest(n=1, auto_backward=False, check_graph=True) def _test_global_item(test_case, placement, sbp): x = random_tensor(ndim=1, dim0=1, 
dtype=int).to_global(placement, sbp) y = torch.tensor(x.item()) return y @autotest(n=1, auto_backward=False, check_graph=False) def _test_global_tolist(test_case, placement, sbp): x = random_tensor(ndim=4, dim0=8, dim1=16, dim2=24, dim3=32, dtype=int).to_global( placement, sbp ) y = torch.tensor(x.tolist()) return y class TestGlobalTensorOps(flow.unittest.TestCase): @globaltest def test_global_cpu(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_cpu(test_case, placement, sbp) @globaltest def test_global_long(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_long(test_case, placement, sbp) @globaltest def test_global_int(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_int(test_case, placement, sbp) @globaltest def test_global_float(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_float(test_case, placement, sbp) @globaltest def test_global_double(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_double(test_case, placement, sbp) @unittest.skip("TODO: sometimes global item will result in a segmentation fault!") @globaltest def test_global_item(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1, except_split=True): _test_global_item(test_case, placement, sbp) @globaltest def test_global_tolist(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_global_tolist(test_case, placement, sbp) @globaltest def test_type_as(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(8, 16), (8, 16, 24), (8, 16, 24, 32)] arg_dict["src_dtype"] = [flow.int64, flow.int32, flow.float32, flow.float64] arg_dict["tgt_dtype"] = [flow.int64, flow.int32, flow.float32, flow.float64] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=len(arg[0])): _test_type_as(test_case, *arg, placement, sbp) _test_local_to_global_type_as(test_case, *arg, placement, sbp) _test_global_to_local_type_as(test_case, *arg, placement, sbp) @globaltest def test_is_floating_point(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(8, 16), (8, 16, 24), (8, 16, 24, 32)] arg_dict["dtype"] = [ # flow.uint8 is excluded because NCCL does not support uint8 flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, flow.double, flow.float, flow.int, ] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=len(arg[0])): _test_is_floating_point(test_case, *arg, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * class TensorScatterNdUpdate(flow.nn.Graph): def __init__(self): super(TensorScatterNdUpdate, self).__init__() def build(self, origin, indices, update): return flow.tensor_scatter_nd_update(origin, indices, update) def _test_global_tensor_scatter_nd_update(test_case, placement, sbp, check_graph=True): origin = random_tensor(1, 16, requires_grad=False).to_global(placement, sbp) indices = choice_tensor(16, (8, 1), replace=False).to_global( placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] ) update = random_tensor(1, 8, requires_grad=False).to_global( placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] ) np_origin = origin.oneflow.numpy() np_indices = indices.oneflow.numpy().reshape(8) np_update = update.oneflow.numpy() if check_graph: tensor_scatter_nd_update = TensorScatterNdUpdate() output = tensor_scatter_nd_update( origin.oneflow, indices.oneflow, update.oneflow ) else: output = flow.tensor_scatter_nd_update( origin.oneflow, indices.oneflow, update.oneflow ) np_origin[np_indices] = np_update test_case.assertTrue(np.allclose(output.numpy(), np_origin, 0.0001, 0.0001)) def _test_global_tensor_scatter_nd_update_t( test_case, placement, sbp, check_graph=True ): origin = random_tensor(2, 16, 4, requires_grad=False).to_global(placement, sbp) indices = choice_tensor(16, (8, 1), replace=False).to_global( placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] ) update = random_tensor(2, 8, 4, requires_grad=False).to_global( placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] ) np_origin = origin.oneflow.numpy() np_indices = indices.oneflow.numpy().reshape(8) np_update = update.oneflow.numpy() if check_graph: tensor_scatter_nd_update = TensorScatterNdUpdate() output = tensor_scatter_nd_update( origin.oneflow, indices.oneflow, update.oneflow ) else: output = flow.tensor_scatter_nd_update( origin.oneflow, indices.oneflow, update.oneflow ) np_origin[np_indices] = np_update test_case.assertTrue(np.allclose(output.numpy(), np_origin, 0.0001, 0.0001)) def _test_eager_global_tensor_scatter_nd_update_backward(test_case, placement, sbp): origin = random_tensor(1, 16,).to_global(placement, sbp) origin.retain_grad() indices = choice_tensor(16, (8, 1), replace=False).to_global( placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] ) update = random_tensor(1, 8).to_global( placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))] ) update.retain_grad() np_origin = origin.oneflow.numpy() np_indices = indices.oneflow.numpy().reshape(8) np_update = update.oneflow.numpy() np_update_grad = np.ones(8) np_origin_grad = np.ones(16) np_origin_grad[np_indices] = np.zeros(8) output = flow.tensor_scatter_nd_update( origin.oneflow, indices.oneflow, update.oneflow ) out_sum = output.sum() out_sum.backward() np_origin[np_indices] = np_update test_case.assertTrue(np.allclose(output.numpy(), np_origin, 0.0001, 0.0001)) test_case.assertTrue(np.allclose(update.oneflow.grad.numpy(), np_update_grad)) test_case.assertTrue(np.allclose(origin.oneflow.grad.numpy(), np_origin_grad)) class TestTensorScatterNdUpdate(flow.unittest.TestCase): @globaltest def test_global_tensor_scatter_nd_update(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_global_tensor_scatter_nd_update( test_case, placement, 
sbp, False ) # eager global # skip lazy test # _test_global_tensor_scatter_nd_update( # test_case, placement, sbp, True # ) # nn graph @globaltest def test_global_tensor_scatter_nd_update_t(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_global_tensor_scatter_nd_update_t( test_case, placement, sbp, False ) # eager global # skip lazy test # _test_global_tensor_scatter_nd_update_t( # test_case, placement, sbp, True # ) # nn graph @globaltest def test_global_tensor_scatter_nd_update_backward(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_eager_global_tensor_scatter_nd_update_backward( test_case, placement, sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_tensordot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True, atol=1e-3) def _test_global_tensordot_against_pytorch(test_case, ndim, placement, sbp): k = random(1, 2) * 8 tensordot_dim = random(0, ndim + 1).to(int) x = random_tensor(ndim=ndim, dim0=k, dim1=k, dim2=k, dim3=k).to_global( placement=placement, sbp=sbp ) y = random_tensor(ndim=ndim, dim0=k, dim1=k, dim2=k, dim3=k).to_global( placement=placement, sbp=sbp ) z = torch.tensordot(x, y, dims=tensordot_dim) return z class TestTensorDotGlobal(flow.unittest.TestCase): @globaltest def test_tensordot(test_case): for placement in all_placement(): for ndim in range(1, 4): for sbp in all_sbp(placement, max_dim=ndim): _test_global_tensordot_against_pytorch( test_case, ndim, placement, sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_tile.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_global_flow_tile_with_random_data(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) reps = ( random(1, 5).to(int) * 8, random(1, 5).to(int) * 8, random(1, 5).to(int) * 8, ) z = torch.tile(x, reps) return z @autotest(n=1, check_graph=True) def _test_global_flow_tensor_tile_with_random_data(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) reps = ( random(1, 5).to(int) * 8, random(1, 5).to(int) * 8, random(1, 5).to(int) * 8, ) y = x.tile(reps) return y class TestGlobalTile(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed in 10 retry") @globaltest def test_global_flow_tile_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_flow_tile_with_random_data(test_case, placement, sbp) @unittest.skip("skip for now, becase it failed in 10 retry") @globaltest def test_global_flow_tensor_tile_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_flow_tensor_tile_with_random_data( test_case, placement, sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_transpose.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_global_transpose(test_case, placement, sbp): input = flow.tensor(np.random.randn(8, 16, 8, 16), dtype=flow.float32).to_global( flow.placement.all("cpu"), flow.sbp.broadcast ) input = input.to_global(placement, sbp) of_out = flow.transpose(input, 0, 1) np_out = input.numpy().transpose((1, 0, 2, 3)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_global_tensor_transpose(test_case, placement, sbp): input = flow.tensor(np.random.randn(8, 16, 8, 16), dtype=flow.float32).to_global( flow.placement.all("cpu"), flow.sbp.broadcast ) input = input.to_global(placement, sbp) of_out = input.transpose(0, 1) np_out = input.numpy().transpose((1, 0, 2, 3)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_global_tranpose_negative_dim(test_case, placement, sbp): input = flow.tensor(np.random.randn(8, 16, 8, 16), dtype=flow.float32).to_global( flow.placement.all("cpu"), flow.sbp.broadcast ) input = input.to_global(placement, sbp) of_out = flow.transpose(input, -4, -3) np_out = input.numpy().transpose((1, 0, 2, 3)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_global_transpose_backward(test_case, placement, sbp): x = flow.tensor( np.random.randn(8, 16, 8, 16), dtype=flow.float32, requires_grad=True, ).to_global(flow.placement.all("cpu"), flow.sbp.broadcast) x = x.to_global(placement, sbp) x.retain_grad() y = flow.transpose(x, 0, 1).sum() y.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones((8, 16, 8, 16)), 1e-05, 1e-05) ) def _test_global_transpose_backward_v2(test_case, placement, sbp): x = flow.tensor( np.random.randn(8, 16, 8, 16), dtype=flow.float32, requires_grad=True, ).to_global(flow.placement.all("cpu"), flow.sbp.broadcast) x = x.to_global(placement, sbp) x.retain_grad() y = flow.transpose(x, 3, 1).sum() y.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones((8, 16, 8, 16)), 1e-05, 1e-05) ) @autotest(n=1, check_graph=True) def _test_global_transpose_flow_with_random_data(test_case, placement, sbp): x = random_tensor(4, 8, 16, 24, 8).to_global(placement, sbp) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) return y @autotest(n=1, check_graph=True) def _test_global_transpose_with_0_size_data(test_case, placement, sbp): device = random_device() x = random_tensor(4, 8, 16, 0, 8).to_global(placement, sbp) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) return y class TestGlobalTranspose(flow.unittest.TestCase): @globaltest def test_global_transpose(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_global_transpose, _test_global_tensor_transpose, _test_global_tranpose_negative_dim, _test_global_transpose_backward, _test_global_transpose_backward_v2, ] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): arg[0](test_case, placement, sbp) @globaltest def test_global_transpose_flow_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_global_transpose_flow_with_random_data(test_case, placement, sbp) @globaltest def test_global_transpose_with_0_size_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4, 
valid_split_axis=[0, 1, 3]): _test_global_transpose_with_0_size_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_tril.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=2, check_graph=True) def _test_global_tril_without_diag(test_case, placement, sbp): x = random_tensor( ndim=4, dim0=random(1, 3).to(int) * 8, dim1=random(1, 3).to(int) * 8, dim2=random(1, 3).to(int) * 8, dim3=random(1, 3).to(int) * 8, ).to_global(placement, sbp) y = torch.tril(x) y = torch.exp(y) return y @autotest(n=2, check_graph=True) def _test_global_tril_with_diag(test_case, placement, sbp): diagonal = random(-3, 3).to(int) x = random_tensor( ndim=4, dim0=random(1, 4).to(int) * 8, dim1=random(1, 4).to(int) * 8, dim2=random(1, 4).to(int) * 8, dim3=random(1, 4).to(int) * 8, ).to_global(placement, sbp) y = torch.tril(x, diagonal) y = torch.exp(y) return y class TestGlobalTril(flow.unittest.TestCase): @unittest.skip("skip for now, because it failed 2 times in the past week") @globaltest def test_global_tril_without_diag(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_global_tril_without_diag(test_case, placement, sbp) @globaltest def test_global_tril_with_diag(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_global_tril_with_diag(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_triu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=2, check_graph=True) def _test_global_triu_without_diag(test_case, placement, sbp): x = random_tensor( ndim=4, dim0=random(1, 3).to(int) * 8, dim1=random(1, 3).to(int) * 8, dim2=2, dim3=4, ).to_global(placement, sbp) y = torch.triu(x) y = torch.exp(y) return y @autotest(n=2, check_graph=True) def _test_global_triu_with_diag(test_case, placement, sbp): diagonal = random(-3, 3).to(int) x = random_tensor( ndim=4, dim0=random(1, 3).to(int) * 8, dim1=random(1, 3).to(int) * 8, dim2=2, dim3=4, ).to_global(placement, sbp) y = torch.triu(x, diagonal) y = torch.exp(y) return y class TestGlobalTriu(flow.unittest.TestCase): @globaltest def test_global_triu_without_diag(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_triu_without_diag(test_case, placement, sbp) @globaltest def test_global_triu_with_diag(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_triu_with_diag(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_unbind.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * # TODO: the test is dependent on global select op(global tensor->stride()) @unittest.skip("global select op is not currently supported") @autotest(n=1, check_graph=True) def _test_unbind(test_case, placement, sbp): dim_size = random(1, 3).to(int).value() * 8 rand_dim = random(0, 3).to(int).value() x = random_tensor(ndim=3, dim0=dim_size, dim1=dim_size, dim2=dim_size).to_global( placement, sbp ) return torch.unbind(x, dim=rand_dim) class TestUnbind(flow.unittest.TestCase): @globaltest def test_unbind(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=3): _test_unbind(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_unfold.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.nn.common_types import _size_2_t @autotest(n=1, check_graph=True) def _test_unfold_with_random_data(test_case, placement, sbp): ndim = 4 dims = [random(1, 3).to(int).value() * 8 for i in range(ndim)] m = torch.nn.Unfold( kernel_size=random(1, 3).to(_size_2_t), dilation=random(1, 2).to(_size_2_t), padding=random(0, 1).to(_size_2_t), stride=random(1, 2).to(_size_2_t), ) m.train(random()) x = random_tensor(ndim, *dims).to_global(placement, sbp) y = m(x) func_y = torch.nn.functional.unfold( x, kernel_size=random(1, 3).to(_size_2_t), dilation=random(1, 2).to(_size_2_t), padding=random(0, 1).to(_size_2_t), stride=random(1, 2).to(_size_2_t), ) return y, func_y class TestUnfold(flow.unittest.TestCase): @globaltest def test_unfold_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_unfold_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_unfold_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import numpy as np @autotest(n=1, auto_backward=True, check_graph=True) def _test_global_unfold_tensor_with_random_data(test_case, placement, sbp): ndim = 4 dim = random(0, ndim).to(int).value() x = random_tensor( ndim=ndim, dim0=random(1, 3).to(int) * 8, dim1=random(1, 3).to(int) * 8, dim2=4, dim3=4, ).to_global(placement, sbp) high = x.oneflow.size()[dim] size = random(1, high).to(int).value() step = random(1, high).to(int).value() y = x.unfold(dim, size, step) return y class TestGlobalUnfoldTensor(flow.unittest.TestCase): @globaltest def test_global_unfold_tensor_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_unfold_tensor_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_unique.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import torch as torch_ori import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_unique_unsorted(test_case, placement, sbp): input = random_tensor(ndim=1, dim0=64, high=20).to_global( placement=placement, sbp=sbp ) oneflow_output = flow.unique( input.oneflow, sorted=False, return_inverse=True, return_counts=True ) torch_output = torch_ori.unique( input.pytorch, sorted=False, return_inverse=True, return_counts=True ) oneflow_result, oneflow_indices, oneflow_counts = oneflow_output torch_result, torch_indices, torch_counts = torch_output test_case.assertTrue( np.allclose( np.sort(oneflow_result.to_local().numpy()), np.sort(torch_result.detach().cpu().numpy()), ) ) test_case.assertTrue( np.allclose( oneflow_result[oneflow_indices].numpy(), torch_result[torch_indices].detach().cpu().numpy(), ) ) test_case.assertTrue( np.allclose( oneflow_counts.numpy()[np.argsort(oneflow_result.numpy())], torch_counts.detach() .cpu() .numpy()[np.argsort(torch_result.detach().cpu().numpy())], ) ) def _test_unique_sorted(test_case, placement, sbp): input = random_tensor(ndim=1, dim0=64, high=20).to_global( placement=placement, sbp=sbp ) oneflow_output = flow.unique( input.oneflow, sorted=True, return_inverse=True, return_counts=True ) torch_output = torch_ori.unique( input.pytorch, sorted=True, return_inverse=True, return_counts=True ) oneflow_result, oneflow_indices, oneflow_counts = oneflow_output torch_result, torch_indices, torch_counts = torch_output test_case.assertTrue( np.allclose( oneflow_result.to_local().numpy(), torch_result.detach().cpu().numpy(), ) ) test_case.assertTrue( np.allclose(oneflow_indices.numpy(), torch_indices.detach().cpu().numpy(),) ) test_case.assertTrue( np.allclose(oneflow_counts.numpy(), torch_counts.detach().cpu().numpy(),) ) class TestUniqueModule(flow.unittest.TestCase): @globaltest def test_unique(test_case): for placement in all_placement(): for sbp in all_sbp(placement): _test_unique_unsorted(test_case, placement, sbp) _test_unique_sorted(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_unsqueeze.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=True) def _test_flow_unsqueeze_with_random_data(test_case, placement, sbp): x = random_tensor(2, 8, 16).to_global(placement, sbp) y = torch.unsqueeze(x, random(0, 3).to(int)) return y @autotest(n=1, check_graph=True) def _test_tensor_unsqueeze_with_random_data(test_case, placement, sbp): x = random_tensor(2, 8, 16).to_global(placement, sbp) y = x.unsqueeze(random(0, 3).to(int)) return y @autotest(n=1, check_graph=True) def _test_unsqueeze_with_0_size_data(test_case, placement, sbp): x = random_tensor(3, 8, 16, 0).to_global(placement, sbp) y = torch.unsqueeze(x, random(0, 4).to(int)) return y class TestGlobalUnsqueeze(flow.unittest.TestCase): @globaltest def test_flow_unsqueeze_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_flow_unsqueeze_with_random_data(test_case, placement, sbp) @globaltest def test_tensor_unsqueeze_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_tensor_unsqueeze_with_random_data(test_case, placement, sbp) @globaltest def test_unsqueeze_with_0_size_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_unsqueeze_with_0_size_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_upsample.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, auto_backward=True, check_graph=True) def _test_global_upsample2d_nearest(test_case, placement, sbp): x = random_tensor(ndim=3, dim0=8, dim1=16).to_global(placement, sbp) print(x) m = torch.nn.Upsample(scale_factor=random().to(int), mode="nearest",) y = m(x) return y @autotest(n=1, auto_backward=True, check_graph=True) def _test_global_upsample2d_linear(test_case, placement, sbp): x = random_tensor(ndim=3, dim0=8, dim1=16).to_global(placement, sbp) m = torch.nn.Upsample( scale_factor=random().to(int), mode="linear", align_corners=random_bool(), ) y = m(x) return y @autotest(n=1, auto_backward=True, check_graph=True) def _test_global_upsample2d_bilinear(test_case, placement, sbp): x = random_tensor(ndim=4, dim0=8, dim1=16).to_global(placement, sbp) m = torch.nn.Upsample( scale_factor=random().to(int), mode="bilinear", align_corners=random_bool(), ) y = m(x) return y @autotest(n=1, auto_backward=True, check_graph=True) def _test_global_upsample2d_bicubic(test_case, placement, sbp): x = random_tensor(ndim=4, dim0=8, dim1=16).to_global(placement, sbp) m = torch.nn.Upsample( scale_factor=random().to(int), mode="bicubic", align_corners=random_bool(), ) y = m(x) return y @autotest(n=1, auto_backward=True, check_graph=True) def _test_global_upsample2d_trilinear(test_case, placement, sbp): x = random_tensor(ndim=5, dim0=8, dim1=16).to_global(placement, sbp) m = torch.nn.Upsample( scale_factor=random().to(int), mode="trilinear", align_corners=random_bool(), ) y = m(x) return y class TestGlobalUpsample2d(flow.unittest.TestCase): @unittest.skip( "The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200" ) @globaltest def test_global_upsample2d_nearest(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_global_upsample2d_nearest(test_case, placement, sbp) @globaltest def test_global_upsample2d_linear(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_global_upsample2d_linear(test_case, placement, sbp) @globaltest def test_global_upsample2d_bilinear(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_global_upsample2d_bilinear(test_case, placement, sbp) @globaltest def test_global_upsample2d_bicubic(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_global_upsample2d_bicubic(test_case, placement, sbp) @globaltest def test_global_upsample2d_trilinear(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_global_upsample2d_trilinear(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_var.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow from oneflow.test_utils.automated_test_util.generators import random import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_flow_global_var_all_dim_with_random_data(test_case, placement, sbp): x = random_tensor( ndim=2, dim0=random(1, 3).to(int) * 8, dim1=random(1, 3).to(int) * 8, ).to_global(placement, sbp) y = torch.var(x) return y @autotest(n=1, check_graph=True) def _test_flow_global_var_one_dim_with_random_data(test_case, placement, sbp): x = random_tensor( ndim=2, dim0=random(1, 3).to(int) * 8, dim1=random(1, 3).to(int) * 8, ).to_global(placement, sbp) y = torch.var( x, dim=random(low=0, high=2).to(int), unbiased=random().to(bool), keepdim=random().to(bool), ) return y @autotest(n=1, auto_backward=True, check_graph=True) def _test_flow_var_0_size_data_with_random_data(test_case, placement, sbp): x = random_tensor(3, 8, 0, 8).to_global(placement, sbp) y = torch.var( x, dim=random(low=0, high=3).to(int), unbiased=random().to(bool), keepdim=random().to(bool), ) return y class TestVar(flow.unittest.TestCase): @globaltest def test_flow_global_var_all_dim_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_flow_global_var_all_dim_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_global_var_one_dim_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_flow_global_var_one_dim_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_var_0_size_data_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2, valid_split_axis=[0]): _test_flow_var_0_size_data_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_vector_matrix_product.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=1, check_graph=True) def _test_vector_matrix_product(test_case, placement, sbp): dim = random(1, 6) vec = random_tensor(1, dim0=dim).to_global(placement=placement, sbp=sbp) mat = random_tensor(2, dim0=dim, dim1=constant(4)).to_global( placement=placement, sbp=sbp ) return torch.matmul(vec, mat) class TestGlobalVectorMatrixProduct(flow.unittest.TestCase): @globaltest def test_vector_matrix_product(test_case): for placement in all_placement(): for sbp in all_sbp(placement): _test_vector_matrix_product(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_view.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=True) def _test_global_view(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=8, dim1=32).to_global(placement, sbp) y = x.view(8, 8, 2, -1) return y @autotest(n=1, check_graph=True) def _test_global_view_size(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=8, dim1=32).to_global(placement, sbp) shape = torch.Size([8, 8, 2, -1]) y = x.view(shape) return y class TestGlobalView(flow.unittest.TestCase): @globaltest def test_global_view(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_view(test_case, placement, sbp) @globaltest def test_global_view_size(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_view_size(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_weight_norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=False) def _test_global_weight_norm_with_random_data(test_case, placement, sbp): dim = random(-2, 2).to(int).value() liner_model_torch = torch.nn.Linear(8, 16).to_global(placement, sbp) m = torch.nn.utils.weight_norm(liner_model_torch, name="weight", dim=dim) return m.weight_g, m.weight_v class TestGlobalWeightNorm(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 6 times in past week") @globaltest def test_global_weight_norm_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_global_weight_norm_with_random_data(test_case, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_where.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=True) def _test_global_where(test_case, placement, sbp): x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) y = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) condition = random_tensor(ndim=2, dim0=8, dim1=16, high=2, dtype=int).to_global( placement, sbp ) condition = condition.to(torch.bool) z = torch.where(condition, x, y) return z @autotest(n=1, check_graph=True) def _test_global_where_broadcast(test_case, placement, sbp): x = random_tensor(ndim=3, dim0=8, dim1=16, dim2=1).to_global(placement, sbp) y = random_tensor(ndim=3, dim0=8, dim1=16, dim2=8).to_global(placement, sbp) condition = random_tensor( ndim=3, dim0=8, dim1=16, dim2=1, high=2, dtype=int ).to_global(placement, sbp) condition = condition.to(torch.bool) z = torch.where(condition, x, y) return z @autotest(n=1, check_graph=True) def _test_global_where_scalar(test_case, placement, sbp): x = random_tensor(ndim=0).to_global(placement, sbp) y = random_tensor(ndim=0).to_global(placement, sbp) condition = random_tensor(ndim=0, high=2, dtype=int).to_global(placement, sbp) condition = condition.to(torch.bool) z = torch.where(condition, x, y) return z # Close auto_backward because pytorch raise error: # PyTorch error: element 0 of tensors does not require grad and does not have a grad_fn # Not check graph because of one reason: # Reason 1, lazy tensor cannot call .numpy(), tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor. # Please refer to File "python/oneflow/nn/modules/nonzero.py", line 29, in nonzero_op. Because nonzero_op is called by where. 
@autotest(n=1, auto_backward=False, check_graph="ValidatedFalse") def _test_where_x_y_none(test_case, placement, sbp): condition = random_tensor(ndim=2, dim0=8, dim1=8, low=-1, high=1).to_global( placement, sbp ) y = torch.where(condition) return y[0], y[1] @autotest(n=1, check_graph=True) def _test_global_where_tensor_with_0dim_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = random_tensor(ndim=0).to_global(placement, sbp) y = random_tensor(ndim=0).to_global(placement, sbp) return torch.where(cond > 0, x, y) @autotest(n=1, check_graph=True) def _test_flow_where_tensor_broadcast_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=3, dim0=8, dim1=16, dim2=8).to_global(placement, sbp) x = random_tensor(ndim=3, dim0=8, dim1=1, dim2=8).to_global(placement, sbp) y = random_tensor(ndim=3, dim0=8, dim1=16, dim2=1).to_global(placement, sbp) return torch.where(cond > 0, x, y) @autotest(n=1, check_graph=True) def _test_flow_where_scalar_x_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = random().to(float) y = ( random_tensor(ndim=2, dim0=8, dim1=16, dtype=float) .to_global(placement, sbp) .to(torch.float64) ) return torch.where(cond > 0, x, y) @autotest(n=1, check_graph=True) def _test_flow_where_scalar_x_broadcast_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=1, dim1=16).to_global(placement, sbp) x = random().to(float) y = ( random_tensor(ndim=2, dim0=8, dim1=1, dtype=float) .to_global(placement, sbp) .to(torch.float64) ) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def _test_flow_where_scalar_x_int_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = random().to(int) y = random_tensor(ndim=2, dim0=8, dim1=16, dtype=int).to_global(placement, sbp) return torch.where(cond > 0, x, y) @autotest(n=1, check_graph=True) def _test_flow_where_scalar_y_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = ( random_tensor(ndim=2, dim0=8, dim1=16, dtype=float) .to_global(placement, sbp) .to(torch.float64) ) y = random().to(float) return torch.where(cond > 0, x, y) @autotest(n=1, check_graph=True) def _test_flow_where_scalar_y_broadcast_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=1, dim1=16).to_global(placement, sbp) x = ( random_tensor(ndim=2, dim0=8, dim1=1, dtype=float) .to_global(placement, sbp) .to(torch.float64) ) y = random().to(float) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def _test_flow_where_scalar_y_int_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = random_tensor(ndim=2, dim0=8, dim1=16, dtype=int).to_global(placement, sbp) y = random().to(int) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def _test_flow_where_tensor_bool_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp).to(torch.bool) y = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp).to(torch.bool) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def 
_test_flow_where_tensor_broadcast_bool_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = random_tensor(ndim=2, dim0=1, dim1=16).to_global(placement, sbp).to(torch.bool) y = random_tensor(ndim=2, dim0=8, dim1=1).to_global(placement, sbp).to(torch.bool) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def _test_flow_where_scalar_x_bool_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = random().to(bool) y = ( random_tensor(ndim=2, dim0=8, dim1=16, dtype=float) .to_global(placement, sbp) .to(torch.bool) ) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def _test_flow_where_scalar_x_broadcast_bool_with_random_data( test_case, placement, sbp ): cond = random_tensor(ndim=2, dim0=1, dim1=16).to_global(placement, sbp) x = random().to(bool) y = ( random_tensor(ndim=2, dim0=8, dim1=1, dtype=float) .to_global(placement, sbp) .to(torch.bool) ) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def _test_flow_where_scalar_y_bool_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = ( random_tensor(ndim=2, dim0=8, dim1=16, dtype=float) .to_global(placement, sbp) .to(torch.bool) ) y = random().to(bool) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def _test_flow_where_scalar_y_broadcast_bool_with_random_data( test_case, placement, sbp ): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = ( random_tensor(ndim=2, dim0=8, dim1=1, dtype=float) .to_global(placement, sbp) .to(torch.bool) ) y = random().to(bool) return torch.where(cond > 0, x, y) @autotest(n=1, auto_backward=False, check_graph=True) def _test_flow_where_scalar_xy_bool_with_random_data(test_case, placement, sbp): cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) x = random().to(bool) y = random().to(bool) return torch.where(cond > 0, x, y) class TestGlobalWhere(flow.unittest.TestCase): @globaltest def test_global_where(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_where(test_case, placement, sbp) @globaltest def test_global_where_broadcast(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_global_where_broadcast(test_case, placement, sbp) @globaltest def test_global_where_scalar(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_global_where_scalar(test_case, placement, sbp) @globaltest def test_where_x_y_none(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_where_x_y_none(test_case, placement, sbp) @globaltest def test_global_where_tensor_with_0dim_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_global_where_tensor_with_0dim_data(test_case, placement, sbp) @globaltest def test_flow_where_tensor_broadcast_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=1): _test_flow_where_tensor_broadcast_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_x_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): 
_test_flow_where_scalar_x_with_random_data(test_case, placement, sbp) @globaltest def test_flow_where_scalar_x_broadcast_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_x_broadcast_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_x_int_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_x_int_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_y_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_y_with_random_data(test_case, placement, sbp) @globaltest def test_flow_where_scalar_y_broadcast_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_y_broadcast_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_y_int_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_y_int_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_tensor_bool_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): _test_flow_where_tensor_bool_with_random_data(test_case, placement, sbp) @globaltest def test_flow_where_tensor_broadcast_bool_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_tensor_broadcast_bool_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_x_bool_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_x_bool_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_x_broadcast_bool_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_x_broadcast_bool_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_y_bool_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_y_bool_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_y_broadcast_bool_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_y_broadcast_bool_with_random_data( test_case, placement, sbp ) @globaltest def test_flow_where_scalar_xy_bool_with_random_data(test_case): for placement in all_placement(): for sbp in all_sbp(placement, except_split=True): _test_flow_where_scalar_xy_bool_with_random_data( test_case, placement, sbp ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_zeropad2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @autotest(n=1, check_graph=True) def _test_global_ZeroPad2d(test_case, placement, sbp, padding): x = random_tensor(ndim=4, dim0=8, dim1=16, dim2=8, dim3=8,).to_global( placement, sbp ) m = torch.nn.ZeroPad2d(padding) y = m(x) return y class TestGlobalZeroPad2dModule(flow.unittest.TestCase): @globaltest def test_global_ZeroPad2d(test_case): arg_dict = OrderedDict() arg_dict["padding"] = [2, (1, 1, 2, 2)] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=4): _test_global_ZeroPad2d(test_case, placement, sbp, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_global_zeros_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
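# --- Illustrative sketch (not part of the original test file) -------------
# The checks below enumerate placements and SBP signatures; the single-rank
# case reduces to the following, using only APIs that already appear in this
# file (flow.placement, to_global, flow.zeros_like):
import numpy as np
import oneflow as flow

placement = flow.placement("cpu", ranks=[0])
sbp = (flow.sbp.broadcast,)
x = flow.tensor(np.random.randn(4, 4), dtype=flow.float32).to_global(
    placement=placement, sbp=sbp
)
y = flow.zeros_like(x, placement=placement, sbp=sbp)
# zeros_like keeps shape/dtype/placement and fills the global tensor with zeros
assert y.shape == x.shape and y.placement == placement
assert np.array_equal(y.numpy(), np.zeros((4, 4)))
# ---------------------------------------------------------------------------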
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_zeros_like_float(test_case, placement, sbp, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) x = x.to_global(placement=placement, sbp=sbp) y = flow.zeros_like(x, placement=placement, sbp=sbp) test_case.assertTrue(y.dtype is flow.float32) test_case.assertTrue(y.shape == x.shape) test_case.assertTrue(y.placement == placement) y_numpy = np.zeros(x.numpy().shape) test_case.assertTrue(np.array_equal(y.numpy(), y_numpy)) def _test_zeros_like_int(test_case, placement, sbp, shape, device): x = flow.tensor(np.random.randn(*shape), dtype=flow.int, device=flow.device(device)) x = x.to_global(placement=placement, sbp=sbp) y = flow.zeros_like(x, dtype=flow.int, placement=placement, sbp=sbp) test_case.assertTrue(y.dtype is flow.int) test_case.assertTrue(y.shape == x.shape) test_case.assertTrue(y.placement == placement) y_numpy = np.zeros(x.numpy().shape) test_case.assertTrue(np.array_equal(y.numpy(), y_numpy)) class TestModule(flow.unittest.TestCase): @globaltest def test_zeros_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_zeros_like_float, _test_zeros_like_int] arg_dict["shape"] = [(8, 8), (8, 8, 4), (8, 8, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): for placement in all_placement(): for sbp in all_sbp(placement, max_dim=2): arg[0](test_case, placement, sbp, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_glu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestGluModule(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_glu_module_with_random_data(test_case): device = random_device() dim = random(-3, 3).to(int) m = torch.nn.functional.glu x = random_tensor(ndim=3, dim0=2, dim1=4, dim2=6).to(device) y = m(x, dim) return y @autotest(n=5, check_graph=True) def test_glu_module_with_random_data(test_case): device = random_device() m = torch.nn.GLU() m.train(random()) m.to(device) x = random_tensor(ndim=3, dim0=2, dim1=4, dim2=6).to(device) y = m(x) return y @profile(torch.nn.functional.glu) def profile_glu(test_case): input = torch.ones(1000, 1000) torch.nn.functional.glu(input) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_gpt_data_loader.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import numpy as np import oneflow as flow import oneflow.unittest class GPTDataLoader(flow.nn.Module): def __init__( self, data_file_prefix=flow.unittest.dataset_dir( "Megatron-LM/dummy/gpt_sample_dataset_text_document" ), seq_length=1024, num_samples=648, batch_size=8, shuffle=True, random_seed=12345, device=None, placement=None, sbp=None, ): super().__init__() self.loader_ = flow.nn.GPTIndexedBinDataReader( data_file_prefix=data_file_prefix, seq_length=seq_length, num_samples=num_samples, batch_size=batch_size, shuffle=shuffle, random_seed=random_seed, device=device, placement=placement, sbp=sbp, ) def forward(self): return self.loader_() class DataLoaderGraph(flow.nn.Graph): def __init__(self, loader): super().__init__() self.loader_ = loader def build(self): return self.loader_() @unittest.skipIf( os.getenv("ONEFLOW_TEST_GITHUB_HOSTED"), "/dataset not available on GitHub hosted servers", ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class GPTDataLoaderDistributedTestCase(oneflow.unittest.TestCase): def test_case1(test_case): rank = flow.env.get_rank() # print( # f"GPTDataLoaderDistributedTestCase.test_case1 on rank {rank} {os.getpid()}" # ) eager_gpt_loader = GPTDataLoader(batch_size=4, device=flow.device("cpu", rank)) global_gpt_loader = GPTDataLoader( batch_size=8, placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.split(0)], ) gpt_loader_graph = DataLoaderGraph(global_gpt_loader) iteration = 2 for i in range(iteration): tokens = eager_gpt_loader() # print( # f"rank {rank} tokens: {tokens.shape}, {tokens.dtype}, device: {tokens.device}" # f"\n{tokens.numpy()}" # ) g_tokens = gpt_loader_graph() # print( # f"rank {rank} graph output tokens: {g_tokens.shape}, {g_tokens.dtype}" # f", placement: {g_tokens.placement}" # f"\n{g_tokens.to_local().numpy()}" # ) # print(f"{'-' * 20} rank {rank} iter {i} complete {'-' * 20}") test_case.assertTrue( np.allclose(tokens.numpy(), g_tokens.to_local().numpy()) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_greater.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
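# --- Illustrative sketch (not part of the original test file) -------------
# The cases below compare flow.gt / the ">" operator against np.greater,
# including broadcasting a Python scalar; the core check is simply:
import numpy as np
import oneflow as flow

x = flow.tensor(np.array([1.0, 1.0, 4.0]), dtype=flow.float32)
y = flow.tensor(np.array([1.0, 2.0, 3.0]), dtype=flow.float32)
assert np.array_equal(flow.gt(x, y).numpy(), np.greater(x.numpy(), y.numpy()))
# a scalar on the right-hand side is broadcast against every element
assert np.array_equal((x > 2.3).numpy(), np.greater(x.numpy(), 2.3))
# ---------------------------------------------------------------------------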
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_greater_normal(test_case, device): input1 = flow.tensor( np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) input2 = flow.tensor( np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.gt(input1, input2) np_out = np.greater(input1.numpy(), input2.numpy()) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_greater_symbol(test_case, device): input1 = flow.tensor( np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) input2 = flow.tensor( np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = input1 > input2 np_out = np.greater(input1.numpy(), input2.numpy()) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_greater_int_scalar(test_case, device): np_arr = np.random.randn(2, 3, 4, 5) input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) input2 = 1 of_out = input1 > input2 np_out = np.greater(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_greater_int_tensor_int_scalar(test_case, device): np_arr = np.random.randint(2, size=(2, 3, 4, 5)) input1 = flow.tensor(np_arr, dtype=flow.int, device=flow.device(device)) input2 = 1 of_out = input1 > input2 np_out = np.greater(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_greater_float_scalar(test_case, device): np_arr = np.random.randn(3, 2, 5, 7) input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) input2 = 2.3 of_out = input1 > input2 np_out = np.greater(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestGreater(flow.unittest.TestCase): def test_greater(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_greater_normal, _test_greater_symbol, _test_greater_int_scalar, _test_greater_int_tensor_int_scalar, _test_greater_float_scalar, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, auto_backward=False, check_graph=True) def test_greater_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = torch.gt(x1, oneof(x2, random().to(int), random().to(float))) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_tensor_inplace_greater_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x1.gt_(oneof(x2, random().to(int), random().to(float))) return x1 @autotest(n=5, auto_backward=False, check_graph=True) def test_tensor_greater_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device) y1 = x1.gt(oneof(x2, random().to(int), random().to(float))) y2 = x1 > x2 return (y1, y2) @autotest(n=5, auto_backward=False, 
check_graph=True) def test_greater_with_0_size_data(test_case): device = random_device() x1 = random_tensor(4, 2, 3, 0, 5).to(device) x2 = random_tensor(4, 2, 3, 0, 5).to(device) y1 = torch.gt(x1, x2) y2 = x1 > x2 return (y1, y2) @autotest(n=5, auto_backward=False, check_graph=True) def test_greater_bool_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) x2 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) y = torch.gt(x1, oneof(x2, random().to(int), random().to(float))) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_greater_with_0dim_data(test_case): device = random_device() x1 = random_tensor(ndim=0).to(device) x2 = random_tensor(ndim=0).to(device) y1 = torch.gt(x1, x2) y2 = x1 > x2 return (y1, y2) @profile(torch.gt) def profile_gt(test_case): input = torch.ones(1000, 1000) other = torch.ones(1000, 1000) torch.gt(input, other) torch.gt(input, 0) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_greater_equal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_greater_equal_normal(test_case, device): input1 = flow.tensor( np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) input2 = flow.tensor( np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow.ge(input1, input2) np_out = np.greater_equal(input1.numpy(), input2.numpy()) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_greater_equal_symbol(test_case, device): input1 = flow.tensor( np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) input2 = flow.tensor( np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = input1 >= input2 np_out = np.greater_equal(input1.numpy(), input2.numpy()) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_greater_equal_int_scalar(test_case, device): np_arr = np.random.randn(2, 3, 4, 5) input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) input2 = 1 of_out = input1 >= input2 np_out = np.greater_equal(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_greater_equal_int_tensor_int_scalr(test_case, device): np_arr = np.random.randint(2, size=(2, 3, 4, 5)) input1 = flow.tensor(np_arr, dtype=flow.int, device=flow.device(device)) input2 = 1 of_out = input1 >= input2 np_out = np.greater_equal(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_greater_equal_float_scalar(test_case, device): np_arr = np.random.randn(3, 2, 5, 7) input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) input2 = 2.3 of_out = input1 >= input2 np_out = np.greater_equal(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestGreaterEqual(flow.unittest.TestCase): def test_greter_equal(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_greater_equal_normal, _test_greater_equal_symbol, _test_greater_equal_int_scalar, _test_greater_equal_int_tensor_int_scalr, _test_greater_equal_float_scalar, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_grid_sample.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from random import randint from random import choice import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestGridSample(flow.unittest.TestCase): def test_grid_sample_4d(test_case): input = flow.tensor( np.arange(1.0, 11).reshape((1, 1, 2, 5)), dtype=flow.float32 ) np_grid = np.array( [ [[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]], [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]], ] ).reshape(1, 2, 5, 2) grid = flow.tensor(np_grid, dtype=flow.float32) groundtruth = np.reshape( np.array([[0.0, 8.0, 5.0, 7.0, 9.0], [1.0, 8.0, 5.0, 8.0, 0.0]]), (1, 1, 2, 5), ) output = flow.nn.functional.grid_sample( input, grid, mode="nearest", padding_mode="zeros", align_corners=True ) test_case.assertTrue( np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(rtol=1e-03, atol=1e-04, check_graph=True) def test_flow_grid_sample_cudnn_with_random_data(test_case): # cudnn only support 4D input, with mode = 'bilinear' && padding_mode = 'zeros' && align_corners N = randint(1, 8) C = randint(1, 8) in_H = randint(1, 8) in_W = randint(1, 8) out_H = randint(1, 8) out_W = randint(1, 8) device = "cuda" mode = "bilinear" padding_mode = "zeros" align_corners = True theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device) grid = torch.nn.functional.affine_grid( theta, (N, C, out_H, out_W), align_corners=align_corners ).to(device) input = random_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to(device) output = torch.nn.functional.grid_sample( input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners, ) return output # This test may fail due to using ::floor in backward # floor(1.99999988) = 1 and floor(2.000000) = 2, then select differente images pixel @autotest( auto_backward=False, rtol=1e-03, atol=1e-04, check_graph=True, check_allclose=False, ) def test_flow_grid_sample_4d_with_random_data(test_case): N = randint(1, 8) C = randint(1, 8) in_H = randint(1, 8) in_W = randint(1, 8) out_H = randint(1, 8) out_W = randint(1, 8) device = random_device() mode = choice(["bilinear", "nearest", "bicubic"]) padding_mode = choice(["zeros", "border", "reflection"]) align_corners = choice([True, False]) theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device) grid = torch.nn.functional.affine_grid( theta, (N, C, out_H, out_W), align_corners=align_corners ).to(device) input = random_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to(device) output = torch.nn.functional.grid_sample( input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners, ) return output @autotest(auto_backward=False, rtol=1e-03, atol=1e-03, check_graph=True) def test_flow_grid_sample_5d_with_random_data(test_case): N = randint(1, 8) C = randint(1, 8) in_D = randint(1, 8) in_H = randint(1, 8) in_W = randint(1, 8) out_D = randint(1, 8) out_H = randint(1, 8) out_W = randint(1, 8) device = random_device() mode = choice(["bilinear", "nearest"]) padding_mode = choice(["zeros", "border", "reflection"]) align_corners = choice([True, False]) theta = random_tensor(ndim=3, dim0=N, dim1=3, dim2=4).to(device) grid = torch.nn.functional.affine_grid( theta, (N, C, out_D, out_H, out_W), align_corners=align_corners ).to(device) input = random_tensor( ndim=5, dim0=N, dim1=C, dim2=in_D, dim3=in_H, dim4=in_W ).to(device) output = torch.nn.functional.grid_sample( 
input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners, ) return output @profile(torch.nn.functional.grid_sample) def profile_grid_sample(test_case): input = torch.ones(32, 3, 128, 128) grid = torch.ones(32, 64, 64, 2) torch.nn.functional.grid_sample(input, grid) torch.nn.functional.grid_sample(input, grid, align_corners=True) torch.nn.functional.grid_sample(input, grid, mode="nearest", align_corners=True) torch.nn.functional.grid_sample(input, grid, mode="bicubic", align_corners=True) torch.nn.functional.grid_sample(input, grid, padding_mode="border") torch.nn.functional.grid_sample(input, grid, padding_mode="reflection") if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_grouped_matmul_bias.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import math import os import oneflow as flow def _ref(xs, weights, biases): if biases is None: return [ flow._C.matmul(x, w, transpose_a=False, transpose_b=True) for x, w in zip(xs, weights) ] else: return [ flow._C.matmul(x, w, transpose_a=False, transpose_b=True) + b for x, w, b in zip(xs, weights, biases) ] def _grouped(xs, weights, biases): if biases is None: return flow._C.grouped_matmul(xs, weights) else: return flow._C.grouped_matmul_bias(xs, weights, biases) def _test_grouped_matmul_bias(test_case, dtype, problems, bias): xs = [ flow.randn((m, k), device="cuda", dtype=dtype) / 10.0 for (m, n, k) in problems ] ws = [ flow.randn((n, k), device="cuda", dtype=dtype) / 10.0 for (m, n, k) in problems ] bs = [flow.randn((n), device="cuda", dtype=dtype) / 10.0 for (m, n, k) in problems] ref_out = _ref(xs, ws, bs if bias else None) grouped_out = _grouped(xs, ws, bs if bias else None) for (ref_y, grouped_y) in zip(ref_out, grouped_out): test_case.assertTrue( np.allclose(ref_y.numpy(), grouped_y.numpy(), atol=1e-2, rtol=1e-2) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGroupedMatmulBias(flow.unittest.TestCase): def test_grouped_matmul_bias(test_case): problems = [(2, 1280, 1280)] * 12 + [(2, 1280, 640)] * 4 + [(2, 1280, 320)] * 5 _test_grouped_matmul_bias(test_case, flow.float16, problems, True) _test_grouped_matmul_bias(test_case, flow.float16, problems, False) problems = ( [(2 * 77, 768, 1280)] * 6 + [(2 * 77, 768, 640)] * 5 + [(2 * 77, 768, 320)] * 5 ) _test_grouped_matmul_bias(test_case, flow.float16, problems, True) _test_grouped_matmul_bias(test_case, flow.float16, problems, False) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_groupnorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest def _test_groupnorm(test_case, device): input_arr = np.array( [ [ [ [-0.8791, 0.2553, 0.7403, -0.2859], [0.8006, -1.7701, -0.9617, 0.1705], [0.2842, 1.7825, 0.3365, -0.8525], ], [ [0.7332, -0.0737, 0.7245, -0.6551], [1.4461, -0.1827, 0.9737, -2.1571], [0.4657, 0.7244, 0.3378, 0.1775], ], ], [ [ [1.8896, 1.8686, 0.1896, 0.9817], [-0.0671, 1.5569, 1.1449, 0.0086], [-0.9468, -0.0124, 1.3227, -0.6567], ], [ [-0.8472, 1.3012, -1.1065, 0.9348], [1.0346, 1.5703, 0.2419, -0.7048], [0.6957, -0.4523, -0.8819, 1.0164], ], ], ], dtype=np.float32, ) output = np.array( [ [ [ [-1.0548115, 0.18125379, 0.7097197, -0.4084487], [0.77542377, -2.0256634, -1.1448141, 0.08885399], [0.21274385, 1.845322, 0.26973096, -1.0258276], ], [ [0.7019834, -0.17723128, 0.6925037, -0.81073654], [1.4787737, -0.2959999, 0.96403706, -2.4473464], [0.4105099, 0.69239473, 0.2711475, 0.09648134], ], ], [ [ [1.5438884, 1.5218256, -0.24213786, 0.5900453], [-0.5118278, 1.1943525, 0.76150376, -0.43229714], [-1.4360437, -0.4543598, 0.94830114, -1.1312639], ], [ [-1.3314037, 0.9257132, -1.6038253, 0.54077196], [0.6456222, 1.2084305, -0.18719131, -1.1817979], [0.28957263, -0.91652036, -1.3678597, 0.6265012], ], ], ], dtype=np.float32, ) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) m = flow.nn.GroupNorm(num_groups=1, num_channels=2).to(device=flow.device(device)) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output, 1e-03, 1e-03)) def _test_groupnorm_3d(test_case, device): input_arr = np.array( [ [ [ [ [1.04569761, 0.22863248, 1.42439335, 1.62249689], [-0.80578825, -0.27276461, 1.04556507, 0.56864134], [-1.24085419, -1.23960097, 0.33451416, -1.84820402], ], [ [-1.511261, 1.06157517, -0.26715858, -1.32888141], [1.17976881, -0.07931171, 0.33910684, -1.93458573], [-1.72659647, 0.79049652, 0.39102785, -1.16264882], ], ], [ [ [0.30067973, -1.2912226, -0.61508225, 0.56454001], [0.87074187, -1.69257376, 0.36119148, -0.31014289], [0.20776964, 1.26195488, -1.37122193, -0.17945234], ], [ [-0.31112407, -0.80682631, 0.8233194, 0.6384975], [0.57617527, 0.45505028, 1.68286151, -1.09590744], [-1.18127546, -1.07529277, 0.52779943, 1.21755926], ], ], ], [ [ [ [-0.12832351, 1.05625455, -0.23253249, -0.64747611], [-0.00738123, -1.41390089, -1.92664144, -0.21427625], [-0.94631219, -0.86493989, 0.21026905, 0.24989732], ], [ [1.3859182, 1.72002107, 0.50091892, 1.04198896], [0.71694594, 1.66417023, -1.63030052, 0.77182641], [0.71545083, 1.96458366, -1.99031931, 1.3196714], ], ], [ [ [1.80091702, 0.02834973, 0.82259214, -1.05597501], [-0.58212207, 0.44205949, -0.14740003, -0.994508], [1.14678114, -0.39196097, 1.2554798, -0.41829324], ], [ [-1.0153903, -0.25755713, -1.81756333, -1.06781159], [1.79680841, -1.9107133, -0.64325796, -1.94640775], [1.30671156, 1.20445339, -1.26262901, -0.79494188], ], 
], ], ], dtype=np.float32, ) output = np.array( [ [ [ [ [1.0670303, 0.3324034, 1.4075173, 1.5856332], [-0.5976489, -0.11840499, 1.0669112, 0.6381069], [-0.9888186, -0.9876919, 0.42760208, -1.5348896], ], [ [-1.2319425, 1.0813059, -0.11336456, -1.0679643], [1.1875744, 0.05552938, 0.43173137, -1.6125557], [-1.4255517, 0.8375778, 0.4784138, -0.9185038], ], ], [ [ [0.3447361, -1.3750811, -0.6446106, 0.62979853], [0.9606047, -1.8086823, 0.41011015, -0.3151683], [0.24436034, 1.3832531, -1.4615086, -0.17397629], ], [ [-0.31622827, -0.8517619, 0.9093717, 0.7096987], [0.6423687, 0.51151085, 1.8379811, -1.1640717], [-1.2562994, -1.1418006, 0.59010565, 1.3352901], ], ], ], [ [ [ [-0.23265934, 0.8016156, -0.32364592, -0.6859402], [-0.12706259, -1.3551185, -1.802801, -0.30770612], [-0.946859, -0.8758114, 0.06297152, 0.09757163], ], [ [1.0894505, 1.3811613, 0.3167428, 0.78916013], [0.50535965, 1.3323971, -1.5440607, 0.55327666], [0.50405425, 1.5946931, -1.8583992, 1.0316093], ], ], [ [ [1.7506906, 0.19012147, 0.8893728, -0.7645185], [-0.3473382, 0.5543517, 0.03539129, -0.71040297], [1.174789, -0.17992027, 1.2704874, -0.20310321], ], [ [-0.7287877, -0.06159106, -1.4350212, -0.7749395], [1.7470733, -1.5170306, -0.40116227, -1.548456], [1.3155918, 1.2255636, -0.9464568, -0.53470486], ], ], ], ], dtype=np.float32, ) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) m = flow.nn.GroupNorm(num_groups=2, num_channels=2, affine=False).to( device=flow.device(device) ) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output, 1e-03, 1e-03)) def _test_groupnorm_backward(test_case, device): input_arr = np.array( [ [ [ [-0.8791, 0.2553, 0.7403, -0.2859], [0.8006, -1.7701, -0.9617, 0.1705], [0.2842, 1.7825, 0.3365, -0.8525], ], [ [0.7332, -0.0737, 0.7245, -0.6551], [1.4461, -0.1827, 0.9737, -2.1571], [0.4657, 0.7244, 0.3378, 0.1775], ], ], [ [ [1.8896, 1.8686, 0.1896, 0.9817], [-0.0671, 1.5569, 1.1449, 0.0086], [-0.9468, -0.0124, 1.3227, -0.6567], ], [ [-0.8472, 1.3012, -1.1065, 0.9348], [1.0346, 1.5703, 0.2419, -0.7048], [0.6957, -0.4523, -0.8819, 1.0164], ], ], ], dtype=np.float32, ) x = flow.tensor( input_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m = flow.nn.GroupNorm(num_groups=1, num_channels=2).to(device=flow.device(device)) y = m(x) z = y.sum() z.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-03, 1e-03) ) def _test_groupnorm_backward_fp16(test_case, device): input_arr = np.array( [ [ [ [-0.8791, 0.2553, 0.7403, -0.2859], [0.8006, -1.7701, -0.9617, 0.1705], [0.2842, 1.7825, 0.3365, -0.8525], ], [ [0.7332, -0.0737, 0.7245, -0.6551], [1.4461, -0.1827, 0.9737, -2.1571], [0.4657, 0.7244, 0.3378, 0.1775], ], ], [ [ [1.8896, 1.8686, 0.1896, 0.9817], [-0.0671, 1.5569, 1.1449, 0.0086], [-0.9468, -0.0124, 1.3227, -0.6567], ], [ [-0.8472, 1.3012, -1.1065, 0.9348], [1.0346, 1.5703, 0.2419, -0.7048], [0.6957, -0.4523, -0.8819, 1.0164], ], ], ], dtype=np.float16, ) x = flow.tensor( input_arr, dtype=flow.float16, device=flow.device(device), requires_grad=True ) m = ( flow.nn.GroupNorm(num_groups=1, num_channels=2) .to(device=flow.device(device)) .to(flow.float16) ) y = m(x) z = y.sum() z.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-03, 1e-03) ) def _test_groupnorm_backward_3d(test_case, device): input_arr = np.array( [ [ [ [ [1.04569761, 0.22863248, 1.42439335, 1.62249689], [-0.80578825, -0.27276461, 1.04556507, 0.56864134], [-1.24085419, -1.23960097, 0.33451416, 
-1.84820402], ], [ [-1.511261, 1.06157517, -0.26715858, -1.32888141], [1.17976881, -0.07931171, 0.33910684, -1.93458573], [-1.72659647, 0.79049652, 0.39102785, -1.16264882], ], ], [ [ [0.30067973, -1.2912226, -0.61508225, 0.56454001], [0.87074187, -1.69257376, 0.36119148, -0.31014289], [0.20776964, 1.26195488, -1.37122193, -0.17945234], ], [ [-0.31112407, -0.80682631, 0.8233194, 0.6384975], [0.57617527, 0.45505028, 1.68286151, -1.09590744], [-1.18127546, -1.07529277, 0.52779943, 1.21755926], ], ], ], [ [ [ [-0.12832351, 1.05625455, -0.23253249, -0.64747611], [-0.00738123, -1.41390089, -1.92664144, -0.21427625], [-0.94631219, -0.86493989, 0.21026905, 0.24989732], ], [ [1.3859182, 1.72002107, 0.50091892, 1.04198896], [0.71694594, 1.66417023, -1.63030052, 0.77182641], [0.71545083, 1.96458366, -1.99031931, 1.3196714], ], ], [ [ [1.80091702, 0.02834973, 0.82259214, -1.05597501], [-0.58212207, 0.44205949, -0.14740003, -0.994508], [1.14678114, -0.39196097, 1.2554798, -0.41829324], ], [ [-1.0153903, -0.25755713, -1.81756333, -1.06781159], [1.79680841, -1.9107133, -0.64325796, -1.94640775], [1.30671156, 1.20445339, -1.26262901, -0.79494188], ], ], ], ], dtype=np.float32, ) x = flow.tensor( input_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) m = flow.nn.GroupNorm(num_groups=2, num_channels=2, affine=False).to( device=flow.device(device) ) y = m(x) z = y.sum() z.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-03, 1e-03) ) def _test_groupnorm_backward_3d_fp16(test_case, device): input_arr = np.array( [ [ [ [ [1.04569761, 0.22863248, 1.42439335, 1.62249689], [-0.80578825, -0.27276461, 1.04556507, 0.56864134], [-1.24085419, -1.23960097, 0.33451416, -1.84820402], ], [ [-1.511261, 1.06157517, -0.26715858, -1.32888141], [1.17976881, -0.07931171, 0.33910684, -1.93458573], [-1.72659647, 0.79049652, 0.39102785, -1.16264882], ], ], [ [ [0.30067973, -1.2912226, -0.61508225, 0.56454001], [0.87074187, -1.69257376, 0.36119148, -0.31014289], [0.20776964, 1.26195488, -1.37122193, -0.17945234], ], [ [-0.31112407, -0.80682631, 0.8233194, 0.6384975], [0.57617527, 0.45505028, 1.68286151, -1.09590744], [-1.18127546, -1.07529277, 0.52779943, 1.21755926], ], ], ], [ [ [ [-0.12832351, 1.05625455, -0.23253249, -0.64747611], [-0.00738123, -1.41390089, -1.92664144, -0.21427625], [-0.94631219, -0.86493989, 0.21026905, 0.24989732], ], [ [1.3859182, 1.72002107, 0.50091892, 1.04198896], [0.71694594, 1.66417023, -1.63030052, 0.77182641], [0.71545083, 1.96458366, -1.99031931, 1.3196714], ], ], [ [ [1.80091702, 0.02834973, 0.82259214, -1.05597501], [-0.58212207, 0.44205949, -0.14740003, -0.994508], [1.14678114, -0.39196097, 1.2554798, -0.41829324], ], [ [-1.0153903, -0.25755713, -1.81756333, -1.06781159], [1.79680841, -1.9107133, -0.64325796, -1.94640775], [1.30671156, 1.20445339, -1.26262901, -0.79494188], ], ], ], ], dtype=np.float16, ) x = flow.tensor( input_arr, dtype=flow.float16, device=flow.device(device), requires_grad=True ) m = ( flow.nn.GroupNorm(num_groups=2, num_channels=2, affine=False) .to(device=flow.device(device)) .to(flow.float16) ) y = m(x) z = y.sum() z.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-03, 1e-03) ) def _test_groupnorm_nhwc(test_case, shape, num_groups): (n, c, h, w) = shape x = flow.tensor( np.random.uniform(low=0.0, high=1.0, size=shape).astype(np.float32) ).to("cuda") gamma = flow.tensor( np.random.uniform(low=0.0, high=1.0, size=(c)).astype(np.float32) ).to("cuda") beta = 
flow.tensor( np.random.uniform(low=0.0, high=1.0, size=(c)).astype(np.float32) ).to("cuda") y = flow._C.group_norm(x, gamma, beta, True, num_groups, 1e-5) x_nhwc = x.permute(0, 2, 3, 1).contiguous() y_nhwc = flow._C.group_norm( x_nhwc, gamma, beta, True, num_groups, 1e-5, "channels_last" ) test_case.assertTrue( np.allclose(y_nhwc.permute(0, 3, 1, 2).numpy(), y, 1e-03, 1e-03) ) @flow.unittest.skip_unless_1n1d() class TestGroupNorm(flow.unittest.TestCase): def test_groupnorm(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_groupnorm, _test_groupnorm_3d, _test_groupnorm_backward, _test_groupnorm_backward_3d, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_groupnorm_grad_fp16(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_groupnorm_backward_fp16, _test_groupnorm_backward_3d_fp16, ] # cpu test will raise error: var only support floating point dtypes # https://github.com/Oneflow-Inc/oneflow/issues/9559 # arg_dict["device"] = ["cpu", "cuda"] arg_dict["device"] = ["cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(rtol=1e-03, atol=1e-03, check_graph=True) def test_group_norm_with_random_data(test_case): channels = random(5, 20) m = torch.nn.GroupNorm( num_groups=random(1, 5), num_channels=channels, eps=random(0, 1) | nothing(), affine=random(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim1=channels).to(device) y = m(x) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_groupnorm_nhwc(test_case): _test_groupnorm_nhwc(test_case, (16, 64, 128, 128), 32) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_groupwise_quantization.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
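# --- Illustrative sketch (not part of the original test file) -------------
# The _quantize/_dequantize_ref helpers below implement per-group (groupwise)
# quantization.  For a single group with symmetric int8 quantization the math
# reduces to: scale = max(|x|) / (2**(num_bits - 1) - 1), q = round(x / scale),
# and dequantized = q * scale.  A tiny numpy round trip with made-up values:
import numpy as np

group = np.array([0.05, -0.20, 0.10, 0.40], dtype=np.float32)
signed_max = 2 ** (8 - 1) - 1                 # 127 for 8-bit symmetric
scale = np.abs(group).max() / signed_max      # largest value maps to +/-127
q = np.round(group / scale).astype(np.int8)   # [16, -64, 32, 127]
dequantized = q.astype(np.float32) * scale
assert np.allclose(dequantized, group, atol=scale)  # error within one step
# ---------------------------------------------------------------------------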
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import math import os import oneflow as flow def _pack_int8_to_int4(x): np_x = x.numpy() l = np_x[..., 0::2] r = np_x[..., 1::2] l = np.left_shift(l, 4) if x.dtype is flow.int8: r = np.bitwise_and(r, np.int8(0xF)) packed = flow.tensor(np.bitwise_or(l, r), device=x.device) return packed def _unpack_int4_to_int8(x): np_x = x.numpy() l = np.right_shift(np_x, 4).reshape(x.shape + (1,)) r = np.right_shift(np.left_shift(np_x, 4), 4).reshape(x.shape + (1,)) unpacked = np.concatenate((l, r), -1).reshape(x.shape[0:-1] + (x.shape[-1] * 2,)) unpacked = flow.tensor(unpacked, device=x.device) return unpacked def _quantize(num_bits, symmetric, x, group_dim, group_size, quant_type): x_float = x.float() x_reshaped = x_float.reshape( x.shape[:group_dim] + (x.shape[group_dim] // group_size, group_size) + x.shape[group_dim + 1 :] ) if symmetric: signed_max = float(2 ** (num_bits - 1)) - 1 offset = signed_max if quant_type is flow.uint8 else 0.0 scale_float = ( x_reshaped.abs().max(dim=group_dim + 1, keepdim=True).values / signed_max ) quantized = ( flow.round(x_reshaped / scale_float + offset) .reshape(x.shape) .to(quant_type) ) if num_bits == 4: quantized = _pack_int8_to_int4(quantized) return (quantized, scale_float.squeeze(group_dim + 1).to(x.dtype), None) else: unsigned_max = float(2 ** num_bits) - 1 mn = x_reshaped.min(dim=group_dim + 1, keepdim=True).values mx = x_reshaped.max(dim=group_dim + 1, keepdim=True).values scale_float = (mx - mn) / unsigned_max quantized = ( flow.round((x_reshaped - mn) / scale_float).reshape(x.shape).to(flow.uint8) ) if num_bits == 4: quantized = _pack_int8_to_int4(quantized) return ( quantized, scale_float.squeeze(group_dim + 1).to(x.dtype), mn.squeeze(group_dim + 1).to(x.dtype), ) def _dequantize_ref(num_bits, symmetric, quantized, scale, zero, group_dim, group_size): if num_bits == 4: quantized = _unpack_int4_to_int8(quantized) scale_reshaped = scale.unsqueeze(group_dim + 1) quantized_reshaped = quantized.reshape( quantized.shape[:group_dim] + (quantized.shape[group_dim] // group_size, group_size) + quantized.shape[group_dim + 1 :] ) if symmetric: offset = ( float(2 ** (num_bits - 1)) - 1 if quantized.dtype is flow.uint8 else 0.0 ) dequantized = (quantized_reshaped.to(scale.dtype) - offset) * scale_reshaped else: zero_reshaped = zero.unsqueeze(group_dim + 1) dequantized = ( zero_reshaped + quantized_reshaped.to(scale.dtype) * scale_reshaped ) return dequantized.reshape(quantized.shape) def _dequantize(num_bits, symmetric, x, scale, zero, group_dim, group_size): return flow._C.groupwise_dequantize( x, scale=scale, zero=zero, group_dim=group_dim, group_size=group_size, num_bits=num_bits, symmetric=symmetric, ) def _test_dequantize(test_case, num_bits, shape, group_dim, group_size): for dtype in [flow.float, flow.float16]: x = flow.randn(shape, device="cuda", dtype=flow.float,).to(dtype) for symmetric in [True, False]: for quant_type in [flow.int8, flow.uint8] if symmetric else [flow.uint8]: quantized, scale, zero = _quantize( num_bits, symmetric, x, group_dim, group_size, quant_type ) dequantized = _dequantize( num_bits, symmetric, quantized, scale, zero, group_dim, group_size ) dequantized_ref = _dequantize_ref( num_bits, symmetric, quantized, scale, zero, group_dim, group_size, ) test_case.assertTrue( np.allclose(dequantized_ref, dequantized, atol=1e-2, rtol=1e-2) ) def _test_fused_linear(test_case, num_bits, m, k, n, group_dim, group_size): for 
dtype in [flow.float16, flow.float]: x = flow.randn((m, k), device="cuda", dtype=flow.float,).to(dtype) / 10 w = flow.randn((n, k), device="cuda", dtype=flow.float,).to(dtype) / 10 b = flow.randn((n), device="cuda", dtype=flow.float,).to(dtype) / 10 for symmetric in [True, False]: for quant_type in [flow.int8, flow.uint8] if symmetric else [flow.uint8]: w_quantized, w_scale, w_zero = _quantize( num_bits, symmetric, w, group_dim, group_size, quant_type ) fused_out = flow._C.fused_linear_with_groupwise_quantized_weight( x=x, w=w_quantized, w_scale=w_scale, w_zero=w_zero, b=b, num_bits=num_bits, symmetric=symmetric, group_dim=group_dim, group_size=group_size, ) ref = ( flow.matmul( x, _dequantize( num_bits, symmetric, w_quantized, w_scale, w_zero, group_dim, group_size, ).t(), ) + b ) test_case.assertTrue(np.allclose(ref, fused_out, atol=1e-2, rtol=1e-2)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGroupWiseQuantization(flow.unittest.TestCase): def test_dequantize(test_case): _test_dequantize(test_case, 8, (128, 256), 0, 128) _test_dequantize(test_case, 8, (64, 128, 256), 0, 64) _test_dequantize(test_case, 8, (64, 128, 256), 1, 128) _test_dequantize(test_case, 8, (64, 128, 256), 2, 256) _test_dequantize(test_case, 8, (63, 127, 255), 0, 63) _test_dequantize(test_case, 8, (63, 127, 255), 1, 127) _test_dequantize(test_case, 8, (63, 127, 255), 2, 255) _test_dequantize(test_case, 8, (128, 256), 1, 256 // 4) _test_dequantize(test_case, 8, (128, 256), 0, 128 // 4) _test_dequantize(test_case, 8, (64, 128, 256), 0, 64 // 4) _test_dequantize(test_case, 8, (64, 128, 256), 1, 128 // 4) _test_dequantize(test_case, 8, (64, 128, 256), 2, 256 // 4) _test_dequantize(test_case, 4, (128, 256), 1, 256) _test_dequantize(test_case, 4, (128, 256), 0, 128) _test_dequantize(test_case, 4, (64, 128, 256), 0, 64) _test_dequantize(test_case, 4, (64, 128, 256), 1, 128) _test_dequantize(test_case, 4, (64, 128, 256), 2, 256) _test_dequantize(test_case, 4, (128, 256), 1, 256 // 4) _test_dequantize(test_case, 4, (128, 256), 0, 128 // 4) _test_dequantize(test_case, 4, (64, 128, 256), 0, 64 // 4) _test_dequantize(test_case, 4, (64, 128, 256), 1, 128 // 4) _test_dequantize(test_case, 4, (64, 128, 256), 2, 256 // 4) def test_fused_linear(test_case): _test_fused_linear(test_case, 8, 1, 64, 128, 0, 128) _test_fused_linear(test_case, 8, 1, 64, 128, 1, 64) _test_fused_linear(test_case, 8, 16, 64, 128, 0, 128) _test_fused_linear(test_case, 8, 16, 64, 128, 1, 64) _test_fused_linear(test_case, 8, 1, 63, 127, 0, 127) _test_fused_linear(test_case, 8, 1, 63, 127, 1, 63) _test_fused_linear(test_case, 8, 1, 256, 512, 0, 64) _test_fused_linear(test_case, 8, 1, 256, 512, 1, 64) _test_fused_linear(test_case, 4, 1, 256, 512, 0, 512) _test_fused_linear(test_case, 4, 1, 256, 512, 1, 256) _test_fused_linear(test_case, 4, 1, 256, 512, 0, 64) _test_fused_linear(test_case, 4, 1, 256, 512, 1, 64) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_gumbel_softmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
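# --- Illustrative sketch (not part of the original test file) -------------
# F.gumbel_softmax draws Gumbel noise g = -log(-log(U)) with U ~ Uniform(0, 1)
# and returns softmax((logits + g) / tau); with hard=True the result is the
# one-hot of the row-wise argmax, which is why the checks below expect each
# row of the hard output to sum to 1.  A numpy version of the soft relaxation:
import numpy as np

rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 5)).astype(np.float32)
gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))
z = (logits + gumbel) / 0.5                       # tau = 0.5
y_soft = np.exp(z - z.max(axis=-1, keepdims=True))
y_soft /= y_soft.sum(axis=-1, keepdims=True)      # softmax over the last dim
assert np.allclose(y_soft.sum(axis=-1), 1.0)
# ---------------------------------------------------------------------------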
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.nn.functional as F import oneflow.unittest def _test_gumbel_softmax(test_case, tau, dim, device, dtype): dtype = type_name_to_flow_type[dtype] x = flow.tensor(np.random.randn(20, 32), dtype=dtype, device=flow.device(device),) y_soft = F.gumbel_softmax(x, tau=tau, dim=dim) y_hard = F.gumbel_softmax(x, tau=tau, dim=dim, hard=True) test_case.assertEqual(x.shape, y_soft.shape) test_case.assertEqual(x.shape, y_hard.shape) test_case.assertEqual(x.dtype, y_soft.dtype) test_case.assertEqual(x.dtype, y_hard.dtype) def _test_gumbel_softmax_hard(test_case, tau, dim, device, dtype): dtype = type_name_to_flow_type[dtype] x = flow.tensor(np.random.randn(45, 23), dtype=dtype, device=flow.device(device),) y_hard = F.gumbel_softmax(x, tau=tau, dim=dim, hard=True) test_case.assertEqual(y_hard.min(), 0) if dim == -1: test_case.assertEqual(y_hard.sum().item(), 45) elif dim == 0: test_case.assertEqual(y_hard.sum().item(), 23) def _test_gumbel_softmax_backward(test_case, tau, dim, device, dtype): dtype = type_name_to_flow_type[dtype] x_np = np.random.rand(10, 10) x_soft = flow.tensor( x_np, dtype=dtype, device=flow.device(device), requires_grad=True, ) x_hard = flow.tensor( x_np, dtype=dtype, device=flow.device(device), requires_grad=True, ) y_soft = F.gumbel_softmax(x_soft, tau, dim=dim) y_hard = F.gumbel_softmax(x_hard, tau, dim=dim, hard=False) y_soft.mean().backward() y_hard.mean().backward() np.testing.assert_allclose( x_hard.grad.numpy(), x_soft.grad.numpy(), rtol=1e-5, atol=1e-5, verbose=True ) def _test_gumbel_softmax_half(test_case, tau, dim, device): x = flow.tensor(np.random.randn(20, 32), device=flow.device(device),).to( flow.float16 ) y_soft = F.gumbel_softmax(x, tau=tau, dim=dim) y_hard = F.gumbel_softmax(x, tau=tau, dim=dim, hard=True) test_case.assertEqual(x.shape, y_soft.shape) test_case.assertEqual(x.shape, y_hard.shape) test_case.assertEqual(x.dtype, y_soft.dtype) test_case.assertEqual(x.dtype, y_hard.dtype) @flow.unittest.skip_unless_1n1d() class TestGumbelSoftmaxModule(flow.unittest.TestCase): @autotest() def test_gumbel_softmax(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_gumbel_softmax, _test_gumbel_softmax_hard, _test_gumbel_softmax_backward, ] arg_dict["tau"] = [1, 2, 0.5] arg_dict["dim"] = [0, -1] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = ["float32", "double"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest() def test_leakyrelu_module_with_half_random_data(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_gumbel_softmax_half, ] arg_dict["tau"] = [1, 2, 0.5] arg_dict["dim"] = [0, -1] arg_dict["device"] = ["cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_hann_window.py ================================================ """ Copyright 2020 The OneFlow 
Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestHannWindow(flow.unittest.TestCase): @autotest(n=1, auto_backward=False, check_graph=True) def test_hann_window(test_case): device = random_device() window_length = random(1, 8).to(int).value() periodic = random_bool().value() output = torch.hann_window(window_length, periodic, device=device) return output def test_hann_window_global(test_case): placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) window_length = random(1, 8).to(int).value() periodic = random_bool().value() output = flow.hann_window(window_length, periodic, placement=placement, sbp=sbp) test_case.assertEqual(output.sbp, sbp) test_case.assertEqual(output.placement, placement) def test_hann_window_dtype(test_case): device = random_device().value() window_length = random(1, 8).to(int).value() periodic = random_bool().value() dtype = flow.float64 output = flow.hann_window(window_length, periodic, device=device, dtype=dtype) test_case.assertEqual(output.dtype, dtype) @profile(torch.hann_window) def profile_hann_window(test_case): torch.hann_window(128000, periodic=True) torch.hann_window(128001, periodic=False) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_activation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
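# --- Illustrative sketch (not part of the original test file) -------------
# The helpers below verify second-order gradients by differentiating the first
# gradient again via autograd.grad(..., create_graph=True).  The pattern, shown
# on a hypothetical cubic instead of an activation:
import numpy as np
import oneflow as flow

x = flow.tensor([0.5, 1.0, 2.0], requires_grad=True)
y = x * x * x                                   # y = x^3  ->  dy/dx = 3x^2
v = flow.ones_like(y)
dx = flow.autograd.grad(y, x, v, create_graph=True, retain_graph=True)[0]
ddx = flow.autograd.grad(dx, x, flow.ones_like(dx))[0]   # d2y/dx2 = 6x
assert np.allclose(ddx.numpy(), 6.0 * x.detach().numpy())
# ---------------------------------------------------------------------------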
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import torch as pytorch_origin import oneflow as oneflow_origin from collections import defaultdict def _assert_true(test_case, value1, value2): test_case.assertTrue( np.allclose( value1.detach().cpu().numpy(), value2.detach().numpy(), rtol=1e-05, atol=1e-05, ) ) def _test_activation_grad_grad_impl(test_case, op_name, *args, **kwargs): x = random_tensor(ndim=2, low=-5) y = eval(f"torch.nn.functional.{op_name}")(x, *args, **kwargs) x_shape = x.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape) init_grad_y = random_tensor(len(x_shape), *x_shape) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _assert_true(test_case, dx.pytorch, dx.oneflow) ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _assert_true(test_case, ddx.pytorch, ddx.oneflow) _assert_true(test_case, ddy.pytorch, ddy.oneflow) def _test_prelu_activation_grad_grad_impl(test_case, op_name, *args, **kwargs): x = random_tensor(ndim=2, low=-5) a = random_tensor(ndim=1, dim0=x.oneflow.shape[1]) y = torch.nn.functional.prelu(x, a) x_shape = x.oneflow.shape a_shape = a.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape) init_grad_y = random_tensor(len(x_shape), *x_shape) init_grad_a = random_tensor(len(a_shape), *a_shape) dx_and_da = torch.autograd.grad(y, [x, a], init_grad_y, True, True) dx, da = dx_and_da[0], dx_and_da[1] _assert_true(test_case, dx.pytorch, dx.oneflow) _assert_true(test_case, da.pytorch, da.oneflow) ddx_dda_ddy = torch.autograd.grad( dx_and_da, [dx, da, init_grad_y], [init_grad_x, init_grad_a] ) ddx, dda, ddy = ddx_dda_ddy[0], ddx_dda_ddy[1], ddx_dda_ddy[2] _assert_true(test_case, ddx.pytorch, ddx.oneflow) _assert_true(test_case, dda.pytorch, dda.oneflow) _assert_true(test_case, ddy.pytorch, ddy.oneflow) def _test_hardswish_activation_grad_grad_impl(test_case, op_name, *args, **kwargs): x = random_tensor(ndim=2, low=-1, dim1=4) y = torch.nn.functional.hardswish(x, *args, **kwargs) x_shape = x.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape) init_grad_y = random_tensor(len(x_shape), *x_shape) dx_pytorch = pytorch_origin.autograd.grad( y.pytorch, x.pytorch, init_grad_y.pytorch )[0] dx_oneflow = oneflow_origin.autograd.grad( y.oneflow, x.oneflow, init_grad_y.oneflow, True, True )[0] _assert_true(test_case, dx_pytorch, dx_oneflow) ddx, ddy = flow.autograd.grad( dx_oneflow, [x.oneflow, init_grad_y.oneflow], init_grad_x.oneflow ) x, dx, init_grad_x, init_grad_y = ( x.oneflow, dx_oneflow, init_grad_x.oneflow, init_grad_y.oneflow, ) manual_ddx = flow.where( ((x > -3.0) < 3.0), 1.0 / 3.0 * init_grad_x * init_grad_y, flow.tensor(0.0) ) manual_ddy = dx / init_grad_y * init_grad_x _assert_true(test_case, manual_ddx, ddx) _assert_true(test_case, manual_ddy, ddy) def _test_hardsigmoid_activation_grad_grad_impl(test_case, op_name, *args, **kwargs): x = random_tensor(ndim=2, low=-1, dim1=4) y = torch.nn.functional.hardsigmoid(x, *args, **kwargs) x_shape = x.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape) init_grad_y = random_tensor(len(x_shape), *x_shape) dx_pytorch = pytorch_origin.autograd.grad( y.pytorch, x.pytorch, init_grad_y.pytorch )[0] dx_oneflow = oneflow_origin.autograd.grad( y.oneflow, x.oneflow, init_grad_y.oneflow, True, True )[0] _assert_true(test_case, dx_pytorch, dx_oneflow) ddx, ddy = flow.autograd.grad( dx_oneflow, [x.oneflow, init_grad_y.oneflow], 
init_grad_x.oneflow ) x, dx, init_grad_x, init_grad_y = ( x.oneflow, dx_oneflow, init_grad_x.oneflow, init_grad_y.oneflow, ) manual_ddx = flow.zeros_like(x) manual_ddy = dx / init_grad_y * init_grad_x _assert_true(test_case, manual_ddx, ddx) _assert_true(test_case, manual_ddy, ddy) class TestActivationHigherDerivative(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 8 times in past week") def test_activation_grad_grad(test_case): op_args = defaultdict(list) op_kwargs = defaultdict(dict) # parameter name not same in pytorch and oneflow op_args["leaky_relu"] = [random(-1, 1).to(float)] # some op only support kwargs, like celu in oneflow op_kwargs["hardtanh"] = { "min_val": random(-5, -1).to(float), "max_val": random(1, 5).to(float), } op_kwargs["elu"] = {"alpha": random(0, 1).to(float)} op_kwargs["celu"] = {"alpha": random(0, 1).to(float)} op_kwargs["threshold"] = { "threshold": random().to(float), "value": random().to(float), } op_kwargs["softplus"] = { "beta": random().to(float), "threshold": random().to(float), } op_names = [ "gelu", "mish", "silu", "selu", "softsign", "hardsigmoid", "hardswish", "relu", "elu", "celu", "prelu", "hardshrink", "softshrink", "leaky_relu", "hardtanh", "softplus", "threshold", ] for op_name in op_names: try: functor = eval(f"_test_{op_name}_activation_grad_grad_impl") except: functor = _test_activation_grad_grad_impl print(f"| {op_name:-^60} |") for i in range(10): functor(test_case, op_name, *op_args[op_name], **op_kwargs[op_name]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_conv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import torch as pytorch_origin import oneflow as oneflow_origin def _test_convnd_grad_grad_impl(test_case, ndim, rtol=1e-4, atol=1e-5): minibatch = np.random.randint(1, 5) groups = np.random.randint(1, 5) in_channels = np.random.randint(1, 5) * groups out_channels = in_channels * np.random.randint(1, 5) padding = np.random.randint(1, 3) stride = np.random.randint(1, 3) dilation = np.random.randint(1, 3) x_shape = [minibatch, in_channels] + [np.random.randint(8, 12) for i in range(ndim)] w_shape = [out_channels, in_channels // groups] + [ np.random.randint(2, 5) for i in range(ndim) ] x = random_tensor(len(x_shape), *x_shape) w = random_tensor(len(w_shape), *w_shape) init_grad_x = random_tensor(len(x_shape), *x_shape) init_grad_w = random_tensor(len(w_shape), *w_shape) y = eval(f"torch.nn.functional.conv{ndim}d")( x, w, stride=stride, padding=padding, groups=groups, dilation=dilation ) init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape) dx = torch.autograd.grad( outputs=y, inputs=x, grad_outputs=init_grad_y, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dx.pytorch.detach().cpu().numpy(), dx.oneflow.detach().numpy(), rtol=rtol, atol=atol, ) ) dw = torch.autograd.grad( outputs=y, inputs=w, grad_outputs=init_grad_y, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dw.pytorch.detach().cpu().numpy(), dw.oneflow.detach().numpy(), rtol=rtol, atol=atol, ) ) # torch.autograd.grad in autotest does not support inputs/outpus/grad_outputs as a list # so use the original pytorch/oneflow module ddx_pytorch, ddw_pytorch = pytorch_origin.autograd.grad( outputs=[dx.pytorch, dw.pytorch], inputs=[x.pytorch, w.pytorch], grad_outputs=[init_grad_x.pytorch, init_grad_w.pytorch], create_graph=True, retain_graph=True, ) ddx_oneflow, ddw_oneflow = oneflow_origin.autograd.grad( outputs=[dx.oneflow, dw.oneflow], inputs=[x.oneflow, w.oneflow], grad_outputs=[init_grad_x.oneflow, init_grad_w.oneflow], create_graph=True, retain_graph=True, ) test_case.assertTrue( np.allclose( ddw_pytorch.detach().cpu().numpy(), ddw_oneflow.detach().numpy(), rtol=rtol, atol=atol, ) ) test_case.assertTrue( np.allclose( ddx_pytorch.detach().cpu().numpy(), ddx_oneflow.detach().numpy(), rtol=rtol, atol=atol, ) ) dgrad_dx = torch.autograd.grad( outputs=dx, inputs=init_grad_y, grad_outputs=init_grad_x, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dgrad_dx.pytorch.detach().cpu().numpy(), dgrad_dx.oneflow.detach().numpy(), rtol=rtol, atol=atol, ) ) dgrad_dw = torch.autograd.grad( outputs=dw, inputs=init_grad_y, grad_outputs=init_grad_w, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dgrad_dw.pytorch.detach().cpu().numpy(), dgrad_dw.oneflow.detach().numpy(), rtol=rtol, atol=atol, ) ) class TestConvHigherDerivative(flow.unittest.TestCase): def test_conv1d_grad_grad(test_case): _test_convnd_grad_grad_impl(test_case, 1) def test_conv2d_grad_grad(test_case): _test_convnd_grad_grad_impl(test_case, 2) def test_conv3d_grad_grad(test_case): _test_convnd_grad_grad_impl(test_case, 3, atol=1e-3) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_div.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from numpy.random import randint def _test_div_grad_grad_impl(test_case): y_shape = [randint(2, 5) for _ in range(randint(0, 6))] x_shape = [randint(2, 5) for _ in range(randint(0, 6 - len(y_shape)))] + y_shape if random_bool().value(): x_shape, y_shape = y_shape, x_shape x = random_tensor(len(x_shape), *x_shape).requires_grad_(True) y = random_tensor(len(y_shape), *y_shape).requires_grad_(True) z = torch.div(x, y) init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape) init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape) dx_and_dy = torch.autograd.grad(z, [x, y], init_grad_z, True, True) test_case.assertTrue( np.allclose( dx_and_dy.pytorch[0].detach().cpu().numpy(), dx_and_dy.oneflow[0].detach().numpy(), rtol=1e-4, atol=1e-4, ) ) test_case.assertTrue( np.allclose( dx_and_dy.pytorch[1].detach().cpu().numpy(), dx_and_dy.oneflow[1].detach().numpy(), rtol=1e-4, atol=1e-4, ) ) ddx_and_ddy_and_ddz = torch.autograd.grad( dx_and_dy, [x, y, init_grad_z], [init_grad_x, init_grad_y], True, True ) test_case.assertTrue( np.allclose( ddx_and_ddy_and_ddz.pytorch[0].detach().cpu().numpy(), ddx_and_ddy_and_ddz.oneflow[0].detach().numpy(), rtol=1e-3, atol=1e-3, ) ) test_case.assertTrue( np.allclose( ddx_and_ddy_and_ddz.pytorch[1].detach().cpu().numpy(), ddx_and_ddy_and_ddz.oneflow[1].detach().numpy(), rtol=1e-3, atol=1e-3, ) ) test_case.assertTrue( np.allclose( ddx_and_ddy_and_ddz.pytorch[2].detach().cpu().numpy(), ddx_and_ddy_and_ddz.oneflow[2].detach().numpy(), rtol=1e-3, atol=1e-3, ) ) class TestDivHigherDerivative(flow.unittest.TestCase): def test_div_grad_grad(test_case): for i in range(10): _test_div_grad_grad_impl(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _assert_true(test_case, value1, value2, name=""): is_equal = np.allclose( value1.detach().cpu().numpy(), value2.detach().numpy(), rtol=1e-04, atol=1e-04, ) test_case.assertTrue(is_equal, f"{name} is not equal." 
if name else "") def generate_grads_for_variables(variables): if isinstance(variables, list): variables_shape = [i.pytorch.shape for i in variables] device = torch.device(str(variables[0].pytorch.device)) elif hasattr(variables, "pytorch"): variables_shape = [i.shape for i in variables.pytorch] device = torch.device(str(variables.pytorch[0].device)) else: assert False grads = [ random_tensor(len(shape), *shape, requires_grad=True).to(device) for shape in variables_shape ] return grads def calculate_and_compare_loss(test_case, input, target, model, order=2): output = model(input, target) _assert_true(test_case, output.pytorch, output.oneflow) init_inputs = [input, target] grad_inputs = [output] grad_outputs = [] for i in range(order): inputs = [ var for var in [*init_inputs, *grad_outputs] if var.pytorch.requires_grad ] outputs = grad_inputs grad_outputs = generate_grads_for_variables(outputs) if i == order - 1: grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs) else: grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs, True, True) for j in range(len(inputs)): _assert_true( test_case, grad_inputs[j].pytorch, grad_inputs[j].oneflow, f"{i}-grad_inputs[{j}]", ) def generate_necessity_for_default_loss(): ndim = random(2, 6).to(int).value() device = random_device() shape = [random().to(int) for _ in range(ndim)] input_requires_grad = True target_requires_grad = random_bool().value() return ( random_tensor(ndim, *shape, requires_grad=input_requires_grad, low=0).to( device ), random_tensor(ndim, *shape, requires_grad=target_requires_grad, low=0).to( device ), ) def generate_necessity_for_nll_loss(): ndim = random(2, 6).to(int).value() device = random_device() num_classes = random(low=2).to(int) batch_size = random(low=2, high=5).to(int) ignore_index = ( random(0, num_classes).to(int) | nothing() if num_classes.value() > 2 else nothing() ) extra_dim = [random().to(int) for _ in range(ndim - 2)] return ( random_tensor(ndim, batch_size, num_classes, *extra_dim).to(device), random_tensor( ndim - 1, batch_size, *extra_dim, low=0, high=num_classes, dtype=int, requires_grad=False, ).to(device), random_tensor(1, num_classes, low=0, high=3, requires_grad=False).to(device), ignore_index, ) def generate_necessity_for_bce_loss(): ndim = random(2, 6).to(int).value() device = random_device() num_classes = 2 batch_size = random(low=2, high=5).to(int) extra_dim = [random().to(int) for _ in range(ndim - 2)] input_requires_grad = True target_requires_grad = False return ( random_tensor( ndim, batch_size, num_classes, *extra_dim, requires_grad=input_requires_grad, low=0, high=1, ).to(device), random_tensor( ndim, batch_size, num_classes, *extra_dim, low=0, high=num_classes, requires_grad=target_requires_grad, ).to(device), random_tensor( ndim, batch_size, num_classes, *extra_dim, low=0, high=3, requires_grad=False, ).to(device), random_tensor( 1, oneof(extra_dim[-1] if ndim > 2 else num_classes, 1).value(), low=1, high=3, requires_grad=False, ).to(device), ) def _test_smooth_l1_loss_grad_grad_impl(test_case): x, y = generate_necessity_for_default_loss() m = torch.nn.SmoothL1Loss( reduction=oneof("none", "sum", "mean", nothing()), beta=oneof(0.0, 0.5, 1) ) m.to(x.device) calculate_and_compare_loss(test_case, x, y, m) def _test_kl_div_loss_grad_grad_impl(test_case): x, y = generate_necessity_for_default_loss() m = torch.nn.KLDivLoss( reduction=oneof("none", "sum", "mean", nothing()), log_target=oneof(True, False), ) m.to(x.device) calculate_and_compare_loss(test_case, x, y, m) def 
_test_bce_loss_grad_grad_impl(test_case, with_logits=False): x, y, weight, pos_weight = generate_necessity_for_bce_loss() if with_logits: weight = oneof(weight, nothing()) has_pos_weight = random_bool().value() pos_weight = pos_weight if has_pos_weight else nothing() m = torch.nn.BCEWithLogitsLoss( weight=weight, pos_weight=pos_weight, reduction=oneof("none", "sum", "mean"), ) if has_pos_weight: y = y.detach().clone().requires_grad_(False) else: m = torch.nn.BCELoss( weight=oneof(weight, nothing()), reduction=oneof("none", "sum", "mean"), ) m.to(x.device) calculate_and_compare_loss(test_case, x, y, m) def _test_nll_loss_grad_grad_impl(test_case): (x, y, weight, ignore_index) = generate_necessity_for_nll_loss() m = torch.nn.NLLLoss( weight=oneof(weight, nothing()), reduction=oneof("none", "sum", "mean"), ignore_index=ignore_index, ) m.to(x.device) calculate_and_compare_loss(test_case, x, y, m) @flow.unittest.skip_unless_1n1d() class TestLossHigherDerivative(flow.unittest.TestCase): def test_smooth_l1_loss_grad_grad(test_case): for i in range(5): _test_smooth_l1_loss_grad_grad_impl(test_case) def test_kl_div_loss_grad_grad(test_case): for i in range(5): _test_kl_div_loss_grad_grad_impl(test_case) @unittest.skip("skip for now, because it failed 8 times in the past week") def test_nll_loss_grad_grad(test_case): for i in range(5): _test_nll_loss_grad_grad_impl(test_case) def test_bce_loss_grad_grad(test_case): for i in range(5): _test_bce_loss_grad_grad_impl(test_case) def test_bce_with_logits_loss_grad_grad(test_case): for i in range(5): _test_bce_loss_grad_grad_impl(test_case, with_logits=True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_matmul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import torch as pytorch_origin import oneflow as oneflow_origin class TestMatmulHigherDerivative(flow.unittest.TestCase): def test_broadcast_matmul_grad_b_grad(test_case): broadcast_dims = [ np.random.randint(2, 10) for _ in range(np.random.randint(1, 3)) ] m = np.random.randint(2, 10) n = np.random.randint(2, 10) k = np.random.randint(2, 10) shape_a = broadcast_dims + [m, k] shape_b = [k, n] shape_y = broadcast_dims + [m, n] a = random_tensor(len(shape_a), *shape_a).requires_grad_(True) b = random_tensor(len(shape_b), *shape_b).requires_grad_(True) y = torch.matmul(a, b) init_grad_a = random_tensor(len(shape_a), *shape_a).requires_grad_(True) init_grad_b = random_tensor(len(shape_b), *shape_b).requires_grad_(True) init_grad_y = random_tensor(len(shape_y), *shape_y).requires_grad_(True) da = torch.autograd.grad( outputs=y, inputs=a, grad_outputs=init_grad_y, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( da.pytorch.detach().cpu().numpy(), da.oneflow.detach().numpy(), rtol=1e-4, atol=1e-5, ) ) db = torch.autograd.grad( outputs=y, inputs=b, grad_outputs=init_grad_y, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( db.pytorch.detach().cpu().numpy(), db.oneflow.detach().numpy(), rtol=1e-4, atol=1e-5, ) ) # torch.autograd.grad in autotest does not support inputs/outpus/grad_outputs as a list # so use the original pytorch/oneflow module dda_pytorch, ddb_pytorch = pytorch_origin.autograd.grad( outputs=[da.pytorch, db.pytorch], inputs=[a.pytorch, b.pytorch], grad_outputs=[init_grad_a.pytorch, init_grad_b.pytorch], create_graph=True, retain_graph=True, ) dda_oneflow, ddb_oneflow = oneflow_origin.autograd.grad( outputs=[da.oneflow, db.oneflow], inputs=[a.oneflow, b.oneflow], grad_outputs=[init_grad_a.oneflow, init_grad_b.oneflow], create_graph=True, retain_graph=True, ) test_case.assertTrue( np.allclose( ddb_pytorch.detach().cpu().numpy(), ddb_oneflow.detach().numpy(), rtol=1e-4, atol=1e-5, ) ) test_case.assertTrue( np.allclose( dda_pytorch.detach().cpu().numpy(), dda_oneflow.detach().numpy(), rtol=1e-4, atol=1e-5, ) ) dgrad_da = torch.autograd.grad( outputs=da, inputs=init_grad_y, grad_outputs=init_grad_a, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dgrad_da.pytorch.detach().cpu().numpy(), dgrad_da.oneflow.detach().numpy(), rtol=1e-4, atol=1e-5, ) ) dgrad_db = torch.autograd.grad( outputs=db, inputs=init_grad_y, grad_outputs=init_grad_b, create_graph=True, retain_graph=True, )[0] test_case.assertTrue( np.allclose( dgrad_db.pytorch.detach().cpu().numpy(), dgrad_db.oneflow.detach().numpy(), rtol=1e-4, atol=1e-5, ) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_neg.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * class TestNegHigherDerivative(flow.unittest.TestCase): def test_neg_grad_grad(test_case): x = random_tensor(ndim=2).requires_grad_(True) y = torch.neg(x) np_arr = np.random.rand(*x.oneflow.shape) init_grad = torch.tensor(np_arr).requires_grad_() x_grad = torch.autograd.grad(y, x, init_grad, create_graph=True)[0] test_case.assertTrue( np.allclose( x_grad.pytorch.detach().cpu().numpy(), x_grad.oneflow.detach().numpy() ) ) init_grad_grad = torch.tensor(np_arr).requires_grad_() dgrad = torch.autograd.grad( x_grad, init_grad, init_grad_grad, create_graph=False )[0] test_case.assertTrue( np.allclose( dgrad.pytorch.detach().cpu().numpy(), dgrad.oneflow.detach().numpy(), ) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_pool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _check_equal(test_case, lhs, rhs, name="", rtol=1e-5, atol=1e-5): is_equal = np.allclose( lhs.detach().cpu().numpy(), rhs.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, ) test_case.assertTrue(is_equal, f"{name} is not equal" if name else "") def _test_avg_pool_grad_grad_impl(test_case, ndim): device = random_device() minibatch = random(1, 5).to(int).value() channels = random(1, 5).to(int).value() padding = random(0, 3).to(int).value() ceil_mode = random_bool().value() count_include_pad = random_bool().value() divisor_override = random().to(int).value() kernel_size = random(4, 6).to(int).value() stride = random(1, 3).to(int).value() x_shape = [minibatch, channels] + [ random(8, 12).to(int).value() for i in range(ndim) ] kwargs = { "kernel_size": kernel_size, "stride": oneof(stride, nothing()), "padding": oneof(padding, nothing()), "ceil_mode": ceil_mode, "count_include_pad": count_include_pad, } if ndim != 1: kwargs["divisor_override"] = divisor_override m = eval(f"torch.nn.AvgPool{ndim}d")(**kwargs) m.to(device) x = random_tensor(len(x_shape), *x_shape).to(device) y = m(x) _check_equal(test_case, y.pytorch, y.oneflow, "y") init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to(device) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to(device) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _check_equal(test_case, dx.pytorch, dx.oneflow, "dx") ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _check_equal(test_case, ddx.pytorch, ddx.oneflow, "ddx") _check_equal(test_case, ddy.pytorch, ddy.oneflow, "ddy") def _test_max_pool_grad_grad_impl(test_case, ndim): device = 
random_device() minibatch = random(1, 5).to(int).value() channels = random(1, 5).to(int).value() padding = random(0, 3).to(int).value() dilation = random(1, 3).to(int).value() ceil_mode = random_bool().value() return_indices = random_bool().value() kernel_size = random(4, 6).to(int).value() stride = random(1, 3).to(int).value() x_shape = [minibatch, channels] + [ random(10, 12).to(int).value() for i in range(ndim) ] m = eval(f"torch.nn.MaxPool{ndim}d")( kernel_size=kernel_size, stride=oneof(stride, nothing()), padding=oneof(padding, nothing()), dilation=oneof(dilation, nothing()), ceil_mode=ceil_mode, return_indices=return_indices, ) m.to(device) x = random_tensor(len(x_shape), *x_shape).to(device) if return_indices: y_and_indices = m(x) y, indices = y_and_indices[0], y_and_indices[1] else: y = m(x) _check_equal(test_case, y.pytorch, y.oneflow, "y") init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to(device) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to(device) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _check_equal(test_case, dx.pytorch, dx.oneflow, "dx") ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _check_equal(test_case, ddx.pytorch, ddx.oneflow, "ddx") _check_equal(test_case, ddy.pytorch, ddy.oneflow, "ddy") def _test_adaptive_pool_grad_grad_impl(test_case, ndim, mode): device = random_device() x_shape = [random(5, 10).to(int).value() for i in range(2 + ndim)] output_size = [random(2, 1 + x_shape[2 + i]).to(int).value() for i in range(ndim)] m = eval(f"torch.nn.Adaptive{mode.title()}Pool{ndim}d")(output_size) m.to(device) x = random_tensor(len(x_shape), *x_shape).to(device) y = m(x) _check_equal(test_case, y.pytorch, y.oneflow, "y") init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to(device) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to(device) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _check_equal(test_case, dx.pytorch, dx.oneflow, "dx") ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _check_equal(test_case, ddx.pytorch, ddx.oneflow, "ddx") _check_equal(test_case, ddy.pytorch, ddy.oneflow, "ddy") @flow.unittest.skip_unless_1n1d() class TestPoolHigherDerivative(flow.unittest.TestCase): def test_max_pool_1d_grad_grad(test_case): _test_max_pool_grad_grad_impl(test_case, 1) def test_max_pool_2d_grad_grad(test_case): _test_max_pool_grad_grad_impl(test_case, 2) def test_max_pool_3d_grad_grad(test_case): _test_max_pool_grad_grad_impl(test_case, 3) def test_avg_pool_1d_grad_grad(test_case): _test_avg_pool_grad_grad_impl(test_case, ndim=1) def test_avg_pool_2d_grad_grad(test_case): _test_avg_pool_grad_grad_impl(test_case, ndim=2) def test_avg_pool_3d_grad_grad(test_case): _test_avg_pool_grad_grad_impl(test_case, ndim=3) def test_adaptive_avg_pool_1d_grad_grad(test_case): _test_adaptive_pool_grad_grad_impl(test_case, ndim=1, mode="avg") def test_adaptive_avg_pool_2d_grad_grad(test_case): _test_adaptive_pool_grad_grad_impl(test_case, ndim=2, mode="avg") def test_adaptive_avg_pool_3d_grad_grad(test_case): _test_adaptive_pool_grad_grad_impl(test_case, ndim=3, mode="avg") if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_pow.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _check_equal(test_case, lhs, rhs, rtol=1e-3, atol=1e-3): is_equal = np.allclose( lhs.detach().cpu().numpy(), rhs.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, ) test_case.assertTrue(is_equal) def _test_pow_grad_grad_impl(test_case): y_shape = [random().to(int).value() for _ in range(random().to(int).value())] x_shape = y_shape[random(0, 5).to(int).value() :] if random_bool().value(): x_shape, y_shape = y_shape, x_shape # The range limit should be removed after solving issue #9908 x = random_tensor(len(x_shape), *x_shape, low=0, high=1) y = random_tensor(len(y_shape), *y_shape, low=0, high=1) z = torch.pow(x, y) _check_equal(test_case, z.pytorch, z.oneflow) init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape) init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape) dx_and_dy = torch.autograd.grad(z, [x, y], init_grad_z, True, True) _check_equal(test_case, dx_and_dy.pytorch[0], dx_and_dy.oneflow[0]) _check_equal(test_case, dx_and_dy.pytorch[1], dx_and_dy.oneflow[1]) ddx_ddy_ddz = torch.autograd.grad( dx_and_dy, [x, y, init_grad_z], [init_grad_x, init_grad_y] ) _check_equal(test_case, ddx_ddy_ddz.pytorch[0], ddx_ddy_ddz.oneflow[0]) _check_equal(test_case, ddx_ddy_ddz.pytorch[1], ddx_ddy_ddz.oneflow[1]) _check_equal(test_case, ddx_ddy_ddz.pytorch[2], ddx_ddy_ddz.oneflow[2]) @flow.unittest.skip_unless_1n1d() class TestPowHigherDerivative(flow.unittest.TestCase): def test_pow_grad_grad(test_case): for i in range(10): _test_pow_grad_grad_impl(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_scalar_pow.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _check_equal(test_case, lhs, rhs, rtol=1e-4, atol=1e-4, name=""): is_equal = np.allclose( lhs.detach().cpu().numpy(), rhs.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, ) test_case.assertTrue(is_equal, f"{name} is not equal") def _test_scalar_pow_grad_grad_impl(test_case, reverse=False): x_shape = [random().to(int).value() for _ in range(random().to(int).value())] y = random().to(float if random_bool().value() else int).value() x = random_tensor(len(x_shape), *x_shape) z = torch.pow(x, y) if not reverse else torch.pow(y, x) init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape) init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape) dx = torch.autograd.grad(z, x, init_grad_z, True, True)[0] _check_equal(test_case, dx.pytorch, dx.oneflow, name="dx") ddx_and_ddz = torch.autograd.grad(dx, [x, init_grad_z], init_grad_x, True, True) _check_equal(test_case, ddx_and_ddz.pytorch[0], ddx_and_ddz.oneflow[0], name="ddx") _check_equal(test_case, ddx_and_ddz.pytorch[1], ddx_and_ddz.oneflow[1], name="ddz") class TestScalarPowHigherDerivative(flow.unittest.TestCase): def test_scalar_pow_grad_grad(test_case): for i in range(10): _test_scalar_pow_grad_grad_impl(test_case) def test_scalar_reverse_pow_grad_grad(test_case): for i in range(10): _test_scalar_pow_grad_grad_impl(test_case, reverse=True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_slice.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def random_index(dim): start = np.random.choice(list(range(dim))) stop = np.random.choice(list(range(1, dim + 1))) if start >= stop: start, stop = stop - 1, start + 1 step = np.random.randint(1, dim) return f"{start}:{stop}:{step}" def random_slice(dim_vec): slice_index = ", ".join(random_index(dim) for dim in dim_vec) return slice_index def _test_slice_grad_grad_impl(test_case): ndim = np.random.randint(2, 5) x_shape = [np.random.randint(3, 8) for _ in range(ndim)] x = random_tensor(len(x_shape), *x_shape).requires_grad_(True) slice_index = random_slice(x_shape) y = eval(f"x[{slice_index}]") init_grad = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).requires_grad_() x_grad = torch.autograd.grad(y, x, init_grad, create_graph=True)[0] test_case.assertTrue( np.allclose( x_grad.pytorch.detach().cpu().numpy(), x_grad.oneflow.detach().numpy() ) ) init_grad_grad = random_tensor( len(x_grad.oneflow.shape), *x_grad.oneflow.shape ).requires_grad_() dgrad = torch.autograd.grad(x_grad, init_grad, init_grad_grad, create_graph=False)[ 0 ] test_case.assertTrue( np.allclose( dgrad.pytorch.detach().cpu().numpy(), dgrad.oneflow.detach().numpy(), ) ) class TestSliceHigherDerivative(flow.unittest.TestCase): def test_slice_grad_grad(test_case): for i in range(10): _test_slice_grad_grad_impl(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_higher_derivative_softmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _assert_true(test_case, value1, value2): test_case.assertTrue( np.allclose( value1.detach().cpu().numpy(), value2.detach().cpu().numpy(), rtol=1e-05, atol=1e-05, ) ) def _test_softmax_grad_grad_impl(test_case, op_name): ndim = random(low=2).to(int).value() data = random_tensor(ndim=ndim) for dim in range(ndim): x = data.detach().clone().requires_grad_() m = eval(f"torch.nn.{op_name}")(dim) y = m(x) _assert_true(test_case, y.pytorch, y.oneflow) x_shape = x.oneflow.shape init_grad_x = random_tensor(len(x_shape), *x_shape) init_grad_y = random_tensor(len(x_shape), *x_shape) dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0] _assert_true(test_case, dx.pytorch, dx.oneflow) ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x) ddx, ddy = ddx_ddy[0], ddx_ddy[1] _assert_true(test_case, ddx.pytorch, ddx.oneflow) _assert_true(test_case, ddy.pytorch, ddy.oneflow) @flow.unittest.skip_unless_1n1d() class TestSoftmaxHigherDerivative(flow.unittest.TestCase): def test_softmax_grad_grad(test_case): _test_softmax_grad_grad_impl(test_case, op_name="Softmax") def test_logsoftmax_grad_grad(test_case): _test_softmax_grad_grad_impl(test_case, op_name="LogSoftmax") if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_host_memory_input.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow from oneflow import nn import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestHostMemory(oneflow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_host_memory(test_case): x = flow.ones(2, 3, device="cuda") scalar = flow.Tensor([3.0], device="cuda") y = x + scalar out = y + scalar + y class HostMemoryInputGraph(nn.Graph): def __init__(self): super(HostMemoryInputGraph, self).__init__() def build(self, x, scalar): a = flow._C.host_scalar_add_by_tensor(x, scalar.cpu()) b = flow._C.host_scalar_add_by_tensor(a, scalar) return a + b graph = HostMemoryInputGraph() lazy_out = graph(x, scalar) test_case.assertTrue(np.array_equal(out.numpy(), lazy_out.numpy())) a = flow._C.host_scalar_add_by_tensor(x, scalar.cpu()) b = flow._C.host_scalar_add_by_tensor(a, scalar) eager_out = a + b test_case.assertTrue(np.array_equal(out.numpy(), eager_out.numpy())) @flow.unittest.skip_unless_1n2d() def test_host_memory_1n2d(test_case): x = flow.ones( 2, 3, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.broadcast ) scalar = flow.Tensor( [3.0], placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.broadcast ) y = x + scalar out = y + scalar + y class HostMemoryInputGraph(nn.Graph): def __init__(self): super(HostMemoryInputGraph, self).__init__() def build(self, x, scalar): a = flow._C.host_scalar_add_by_tensor(x, scalar.cpu()) b = flow._C.host_scalar_add_by_tensor(a, scalar) return a + b graph = HostMemoryInputGraph() lazy_out = graph(x, scalar) test_case.assertTrue(np.array_equal(out.numpy(), lazy_out.numpy())) a = flow._C.host_scalar_add_by_tensor(x, scalar.cpu()) b = flow._C.host_scalar_add_by_tensor(a, scalar) eager_out = a + b test_case.assertTrue(np.array_equal(out.numpy(), eager_out.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_hsplit.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from random import shuffle from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestHsplitVec(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_flow_hsplit_vec(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) z = torch.hsplit(x, (1, 2)) return z @autotest(n=5) def test_flow_hsplit_vec_with_stride(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) perm = [0, 1, 2, 3] shuffle(perm) y = x.permute(perm) z = torch.hsplit(y, (1, 2)) return z @flow.unittest.skip_unless_1n1d() class TestHsplitInt(flow.unittest.TestCase): @autotest(n=10, check_graph=True) def test_flow_hsplit_int(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=12, dim2=random(3, 6), dim3=random(3, 6), ).to(device) split = oneof(2, 4, 6) z = torch.hsplit(x, split) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_hub.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skip(reason="network fluctuations can cause downloads to fail!") class TestHub(flow.unittest.TestCase): def test_hub_list_api(test_case): entrypoints = flow.hub.list("OneFlow-Inc/vision", force_reload=False) test_case.assertEqual("alexnet" in entrypoints, True) test_case.assertEqual("densenet121" in entrypoints, True) def test_hub_help_api(test_case): help_info = flow.hub.help("Oneflow-Inc/vision", "resnet18", force_reload=False) print(help_info) def test_hub_load_api(test_case): repo = "Oneflow-Inc/vision" model = flow.hub.load(repo, "resnet18", pretrained=True) x = flow.randn(1, 3, 224, 224) y = model(x) test_case.assertTrue(np.array_equal(y.size(), (1, 1000))) def test_hub_download_url_to_file__api(test_case): flow.hub.download_url_to_file( "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip", "/tmp/temporary_file", ) def test_hub_load_state_dict_from_url_api(test_case): state_dict = flow.hub.load_state_dict_from_url( "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip" ) test_case.assertEqual("layer3.1.bn2.bias" in state_dict.keys(), True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_image_batch_align.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import operator import unittest from functools import reduce import cv2 import numpy as np import oneflow as flow import oneflow.unittest def _read_images_by_cv(image_files): images = [cv2.imread(image_file).astype(np.single) for image_file in image_files] return images def _get_images_static_shape(images): image_shapes = [image.shape for image in images] image_static_shape = np.amax(image_shapes, axis=0) assert isinstance( image_static_shape, np.ndarray ), "image_shapes: {}, image_static_shape: {}".format( str(image_shapes), str(image_static_shape) ) image_static_shape = image_static_shape.tolist() image_static_shape.insert(0, len(image_shapes)) return image_static_shape def _roundup(x, n): return int((x + n - 1) / n) * n @flow.unittest.skip_unless_1n1d() class TestImageBatchAlign(flow.unittest.TestCase): def test_image_batch_align(test_case): image_files = [ flow.unittest.dataset_dir("mscoco_2017/val2017/000000000139.jpg"), flow.unittest.dataset_dir("mscoco_2017/val2017/000000000632.jpg"), flow.unittest.dataset_dir("mscoco_2017/val2017/000000000785.jpg"), flow.unittest.dataset_dir("mscoco_2017/val2017/000000001000.jpg"), ] alignment = 16 images = _read_images_by_cv(image_files) image_shape = _get_images_static_shape(images) assert len(image_shape) == 4 aligned_image_shape = [ image_shape[0], _roundup(image_shape[1], alignment), _roundup(image_shape[2], alignment), image_shape[3], ] image_batch_aligner = flow.nn.image.batch_align( shape=aligned_image_shape[1:], dtype=flow.float, alignment=alignment ) images_np_arr_static = np.zeros(image_shape, dtype=np.float32) for (idx, np_arr) in enumerate(images): images_np_arr_static[idx, : np_arr.shape[0], : np_arr.shape[1], :] = np_arr input = flow.tensor( images_np_arr_static, dtype=flow.float, device=flow.device("cpu") ) images_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=3) of_aligned_image = image_batch_aligner(images_buffer).numpy() test_case.assertTrue( np.array_equal(aligned_image_shape, of_aligned_image.shape) ) empty_image_array = np.zeros(aligned_image_shape, np.float32) for (empty_image, image) in zip(empty_image_array, images): empty_image[0 : image.shape[0], 0 : image.shape[1], :] = image test_case.assertTrue(np.array_equal(of_aligned_image, empty_image_array)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_image_decode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import cv2 import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestImageDecode(flow.unittest.TestCase): def test_image_decode(test_case): images = [ flow.unittest.dataset_dir("mscoco_2017/val2017/000000000139.jpg"), flow.unittest.dataset_dir("mscoco_2017/val2017/000000000632.jpg"), ] image_files = [open(im, "rb") for im in images] images_bytes = [imf.read() for imf in image_files] static_shape = (len(images_bytes), max([len(bys) for bys in images_bytes])) for imf in image_files: imf.close() image_decoder = flow.nn.image.decode(color_space="BGR") images_np_arr = [ np.frombuffer(bys, dtype=np.byte).reshape(1, -1) for bys in images_bytes ] images_np_arr_static = np.zeros(static_shape, dtype=np.int8) for (idx, np_arr) in enumerate(images_np_arr): images_np_arr_static[idx, : np_arr.shape[1]] = np_arr input = flow.tensor( images_np_arr_static, dtype=flow.int8, device=flow.device("cpu") ) images_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=1) decoded_images_buffer = image_decoder(images_buffer) of_decoded_images = decoded_images_buffer.numpy() cv2_images = [cv2.imread(image) for image in images] cv2_decoded_images = [np.array(image) for image in cv2_images] for (of_decoded_image, cv2_decoded_image) in zip( of_decoded_images, cv2_decoded_images ): test_case.assertTrue(len(of_decoded_image.shape) == 3) test_case.assertTrue(len(cv2_decoded_image.shape) == 3) test_case.assertTrue(np.allclose(of_decoded_image, cv2_decoded_image)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_image_flip.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import cv2 import numpy as np import oneflow as flow import oneflow.unittest def _of_image_flip(images, image_static_shape, flip_code): image_tensors = flow.tensor(images, dtype=flow.float, device=flow.device("cpu")) image_tensor_buffer = flow.tensor_to_tensor_buffer(image_tensors, instance_dims=3) flip_images = flow.nn.image.flip()(image_tensor_buffer, flip_code) return flip_images.numpy() def _read_images_by_cv(image_files): images = [cv2.imread(image_file).astype(np.single) for image_file in image_files] return [np.expand_dims(image, axis=0) for image in images] def _get_images_static_shape(images): image_shapes = [image.shape for image in images] image_static_shape = np.amax(image_shapes, axis=0) assert isinstance( image_static_shape, np.ndarray ), "image_shapes: {}, image_static_shape: {}".format( str(image_shapes), str(image_static_shape) ) image_static_shape = image_static_shape.tolist() assert image_static_shape[0] == 1, str(image_static_shape) image_static_shape[0] = len(image_shapes) return image_static_shape def _compare_image_flip_with_cv(test_case, image_files): images = _read_images_by_cv(image_files) assert all([len(image.shape) == 4 for image in images]) image_static_shape = _get_images_static_shape(images) image_paddings = np.zeros(tuple(image_static_shape)) for (idx, image) in enumerate(images): image_paddings[ idx, : image.shape[1], : image.shape[2], : image.shape[3] ] = image flip_code = flow.ones(image_static_shape[0], dtype=flow.int8) flip_images = _of_image_flip(image_paddings, image_static_shape, flip_code) for (image, flip_image) in zip(image_paddings, flip_images): exp_flip_image = cv2.flip(image.squeeze(), 1) test_case.assertTrue(np.allclose(exp_flip_image, flip_image)) @flow.unittest.skip_unless_1n1d() class TestImageFlip(flow.unittest.TestCase): def test_image_flip(test_case): _compare_image_flip_with_cv( test_case, [ flow.unittest.dataset_dir("mscoco_2017/val2017/000000000139.jpg"), flow.unittest.dataset_dir("mscoco_2017/val2017/000000000632.jpg"), ], ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_image_normalize.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import cv2 import numpy as np import oneflow as flow import oneflow.unittest def _of_image_normalize(images, image_static_shape, std, mean): image_zeros = np.zeros(tuple(image_static_shape)) for (idx, image) in enumerate(images): image_zeros[idx, : image.shape[1], : image.shape[2], : image.shape[3]] = image image_tensors = flow.tensor( image_zeros, dtype=flow.float, device=flow.device("cpu") ) image_tensor_buffer = flow.tensor_to_tensor_buffer(image_tensors, instance_dims=3) image_normalizer = flow.nn.image.normalize(std, mean) norm_images = image_normalizer(image_tensor_buffer) return norm_images.numpy() def _read_images_by_cv(image_files): images = [cv2.imread(image_file).astype(np.single) for image_file in image_files] return [np.expand_dims(image, axis=0) for image in images] def _get_images_static_shape(images): image_shapes = [image.shape for image in images] image_static_shape = np.amax(image_shapes, axis=0) assert isinstance( image_static_shape, np.ndarray ), "image_shapes: {}, image_static_shape: {}".format( str(image_shapes), str(image_static_shape) ) image_static_shape = image_static_shape.tolist() assert image_static_shape[0] == 1, str(image_static_shape) image_static_shape[0] = len(image_shapes) return image_static_shape def _compare_image_normalize(test_case, image_files, std, mean): images = _read_images_by_cv(image_files) assert all([len(image.shape) == 4 for image in images]) image_static_shape = _get_images_static_shape(images) norm_images = _of_image_normalize(images, image_static_shape, std, mean) std_array = np.array(std).reshape(1, 1, 1, -1) mean_array = np.array(mean).reshape(1, 1, 1, -1) for (image, norm_image) in zip(images, norm_images): np_norm_image = np.squeeze((image - mean_array) / std_array, axis=0) norm_image = norm_image[ : np_norm_image.shape[0], : np_norm_image.shape[1], : np_norm_image.shape[2] ] test_case.assertTrue(np.allclose(np_norm_image, norm_image)) @flow.unittest.skip_unless_1n1d() class TestImageNormalize(flow.unittest.TestCase): def test_image_normalize(test_case): _compare_image_normalize( test_case, [ flow.unittest.dataset_dir("mscoco_2017/val2017/000000000139.jpg"), flow.unittest.dataset_dir("mscoco_2017/val2017/000000000632.jpg"), ], (102.9801, 115.9465, 122.7717), (1.0, 1.0, 1.0), ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_image_resize.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import cv2 import image_test_util import numpy as np import oneflow as flow import oneflow.nn as nn import oneflow.unittest def _of_image_resize( image_list, dtype=flow.float32, origin_dtype=flow.float32, channels=3, keep_aspect_ratio=False, target_size=None, min_size=None, max_size=None, resize_side="shorter", interpolation_type="bilinear", ): assert isinstance(image_list, (list, tuple)) assert all((isinstance(image, np.ndarray) for image in image_list)) assert all((image.ndim == 3 for image in image_list)) assert all((image.shape[2] == channels for image in image_list)) res_image_list = [] res_size_list = [] res_scale_list = [] image_resize_module = nn.image.Resize( target_size=target_size, min_size=min_size, max_size=max_size, keep_aspect_ratio=keep_aspect_ratio, resize_side=resize_side, dtype=dtype, interpolation_type=interpolation_type, channels=channels, ) for image in image_list: tensor_dtype = dtype if keep_aspect_ratio else origin_dtype input = flow.tensor( np.expand_dims(image, axis=0), dtype=tensor_dtype, device=flow.device("cpu") ) image_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=3) (res_image, scale, new_size) = image_resize_module(image_buffer) res_image = res_image.numpy() scale = scale.numpy() if not keep_aspect_ratio: new_size = np.asarray([(target_size, target_size)]) else: new_size = new_size.numpy() res_image_list.append(res_image[0]) res_size_list.append(new_size[0]) res_scale_list.append(scale[0]) return (res_image_list, res_scale_list, res_size_list) def _get_resize_size_and_scale( w, h, target_size, min_size=None, max_size=None, keep_aspect_ratio=True, resize_side="shorter", ): if keep_aspect_ratio: assert isinstance(target_size, int) aspect_ratio = float(min((w, h))) / float(max((w, h))) ( min_res_size, max_res_size, ) = image_test_util.compute_keep_aspect_ratio_resized_size( target_size, min_size, max_size, aspect_ratio, resize_side ) if w < h: res_w = min_res_size res_h = max_res_size else: res_w = max_res_size res_h = min_res_size else: assert isinstance(target_size, (list, tuple)) assert len(target_size) == 2 assert all((isinstance(size, int) for size in target_size)) (res_w, res_h) = target_size scale_w = res_w / w scale_h = res_h / h return ((res_w, res_h), (scale_w, scale_h)) def _cv_image_resize( image_list, target_size, keep_aspect_ratio=True, min_size=None, max_size=None, resize_side="shorter", interpolation=cv2.INTER_LINEAR, dtype=np.float32, ): res_image_list = [] res_size_list = [] res_scale_list = [] for image in image_list: (h, w) = image.shape[:2] (new_size, scale) = _get_resize_size_and_scale( w, h, target_size, min_size, max_size, keep_aspect_ratio, resize_side ) res_image_list.append( cv2.resize(image.squeeze(), new_size, interpolation=interpolation).astype( dtype ) ) res_size_list.append(new_size) res_scale_list.append(scale) return (res_image_list, res_scale_list, res_size_list) def _test_image_resize_with_cv( test_case, image_files, target_size, min_size=None, max_size=None, keep_aspect_ratio=True, resize_side="shorter", dtype=flow.float32, origin_dtype=None, ): if origin_dtype is None: origin_dtype = dtype image_list = image_test_util.read_images_by_cv(image_files, origin_dtype) (of_res_images, of_scales, of_new_sizes) = _of_image_resize( image_list=image_list, dtype=dtype, origin_dtype=origin_dtype, keep_aspect_ratio=keep_aspect_ratio, target_size=target_size, min_size=min_size, max_size=max_size, resize_side=resize_side, ) (cv_res_images, cv_scales, cv_new_sizes) = _cv_image_resize( image_list=image_list, 
target_size=target_size, keep_aspect_ratio=keep_aspect_ratio, min_size=min_size, max_size=max_size, resize_side=resize_side, dtype=flow.convert_oneflow_dtype_to_numpy_dtype(dtype), ) for ( of_res_image, cv_res_image, of_scale, cv_scale, of_new_size, cv_new_size, ) in zip( of_res_images, cv_res_images, of_scales, cv_scales, of_new_sizes, cv_new_sizes ): test_case.assertTrue(np.allclose(of_res_image, cv_res_image)) test_case.assertTrue(np.allclose(of_scale, cv_scale)) test_case.assertTrue(np.allclose(of_new_size, cv_new_size)) @flow.unittest.skip_unless_1n1d() @unittest.skipIf( not flow.unittest.env.eager_execution_enabled(), ".numpy() doesn't work in lazy mode", ) class TestImageResize(flow.unittest.TestCase): def test_image_resize_to_fixed_size(test_case): (image_files, _) = image_test_util.random_sample_images_from_coco() _test_image_resize_with_cv( test_case, image_files, target_size=(224, 224), keep_aspect_ratio=False ) def test_image_resize_shorter_to_target_size(test_case): (image_files, _) = image_test_util.random_sample_images_from_coco() _test_image_resize_with_cv( test_case, image_files, target_size=800, keep_aspect_ratio=True, resize_side="shorter", ) def test_image_resize_longer_to_target_size(test_case): (image_files, _) = image_test_util.random_sample_images_from_coco() _test_image_resize_with_cv( test_case, image_files, target_size=1000, keep_aspect_ratio=True, resize_side="longer", ) def test_image_resize_shorter_to_target_size_with_max_size(test_case): (image_files, _) = image_test_util.random_sample_images_from_coco() _test_image_resize_with_cv( test_case, image_files, target_size=800, max_size=1333, keep_aspect_ratio=True, resize_side="shorter", ) def test_image_resize_longer_to_target_size_with_min_size(test_case): (image_files, _) = image_test_util.random_sample_images_from_coco() _test_image_resize_with_cv( test_case, image_files, target_size=1000, min_size=600, keep_aspect_ratio=True, resize_side="longer", ) def test_image_resize_to_fixed_size_with_dtype_uint8(test_case): (image_files, _) = image_test_util.random_sample_images_from_coco() _test_image_resize_with_cv( test_case, image_files, target_size=(1000, 1000), keep_aspect_ratio=False, dtype=flow.uint8, ) def test_image_reisze_shorter_to_target_size_with_max_size_with_dtype_uint8( test_case, ): (image_files, _) = image_test_util.random_sample_images_from_coco() _test_image_resize_with_cv( test_case, image_files, target_size=1000, max_size=1600, keep_aspect_ratio=True, resize_side="shorter", dtype=flow.uint8, ) def test_image_resize_uint8_to_float(test_case): (image_files, _) = image_test_util.random_sample_images_from_coco() _test_image_resize_with_cv( test_case, image_files, target_size=(1000, 1000), keep_aspect_ratio=False, dtype=flow.float32, origin_dtype=flow.uint8, ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_in_top_k.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _topk_np(input, k, dim: int = -1, largest: bool = True, _sorted: bool = True): in_dims = input.shape out_dims = list(in_dims) num_axes = len(input.shape) if dim < 0: dim = dim + num_axes n = in_dims[dim] if k > n: k = n out_dims[dim] = k out_dims = tuple(out_dims) prev_dims = 1 next_dims = 1 for i in range(dim): prev_dims *= in_dims[i] for i in range(dim + 1, len(in_dims)): next_dims *= in_dims[i] input_flat = input.reshape((prev_dims, n, next_dims)) values_ref = np.ndarray(shape=(prev_dims, k, next_dims), dtype=input.dtype) values_ref.fill(0) indices_ref = np.ndarray(shape=(prev_dims, k, next_dims), dtype=np.int64) indices_ref.fill(-1) for i in range(prev_dims): for j in range(next_dims): kv = [] for x in range(n): val = input_flat[i, x, j] y = x * next_dims + i * in_dims[dim] * next_dims + j kv.append((val, x, y)) cnt = 0 for (val, x, y) in sorted(kv, key=lambda x: (x[0], -x[1]), reverse=largest): values_ref[i, cnt, j] = val indices_ref[i, cnt, j] = x cnt += 1 if cnt >= k or cnt >= n: break values_ref = values_ref.reshape(out_dims) indices_ref = indices_ref.reshape(out_dims) return (values_ref, indices_ref) def _in_top_k_np(targets, predictions, k): assert ( targets.shape[0] == predictions.shape[0] ), "The num of targets must equal the num of predictions" assert len(targets.shape) == 1, "The dimension of targets must be 1" assert len(predictions.shape) == 2, "The dimension of predictions must be 2" results = np.zeros_like(targets, dtype=np.int8) for i in range(len(results)): (_, indices_topk) = _topk_np(predictions[i], k) if targets[i] in indices_topk: results[i] = 1 return results def _test_in_top_k_impl(test_case, shape, k, device): np_targets = np.random.randint(0, shape[1], size=shape[0]) np_predictions = np.random.rand(*shape) of_targets = flow.tensor( np_targets, dtype=flow.int32, device=flow.device(device), requires_grad=False ) of_predictions = flow.tensor( np_predictions, dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.in_top_k(of_targets, of_predictions, k) np_out = _in_top_k_np(np_targets, np_predictions, k) test_case.assertTrue( np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001, equal_nan=True) ) @flow.unittest.skip_unless_1n1d() class TestInTopK(flow.unittest.TestCase): def test_in_top_k(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (3, 4), (5, 6)] arg_dict["k"] = [1, 2, 5] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_in_top_k_impl(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_index_add.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import numpy as np import torch as torch_origin from collections import OrderedDict import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList import unittest from oneflow.test_utils.automated_test_util import * def _test_index_add(test_case, device): torch_origin_x = torch_origin.ones(5, 3).to(device) torch_origin_t = torch_origin.tensor( [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch_origin.float ).to(device) torch_origin_index = torch_origin.tensor([0, 4, 2]).to(device) torch_origin_y = torch_origin.index_add( torch_origin_x, 0, torch_origin_index, torch_origin_t ) torch_origin_y_alpha = torch_origin.index_add( torch_origin_x, 0, torch_origin_index, torch_origin_t, alpha=-1 ) flow_x = flow.ones(5, 3).to(device) flow_t = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=flow.float).to(device) flow_index = flow.tensor([0, 4, 2]).to(device) flow_y = flow.index_add(flow_x, 0, flow_index, flow_t) flow_y_alpha = flow.index_add(flow_x, 0, flow_index, flow_t, alpha=-1) test_case.assertTrue( np.allclose(torch_origin_y.cpu().numpy(), flow_y.cpu().numpy(), 1e-05, 1e-05) ) test_case.assertTrue( np.allclose( torch_origin_y_alpha.cpu().numpy(), flow_y_alpha.cpu().numpy(), 1e-05, 1e-05 ) ) # check inplace torch_origin_x.index_add_(0, torch_origin_index, torch_origin_t) flow_x.index_add_(0, flow_index, flow_t) test_case.assertTrue( np.allclose(torch_origin_y.cpu().numpy(), flow_y.cpu().numpy(), 1e-05, 1e-05) ) @flow.unittest.skip_unless_1n1d() class TestIndexAdd(flow.unittest.TestCase): def test_index_add(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_index_add] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @profile(torch.index_add) def profile_index_add(test_case): torch.index_add( torch.ones(50, 30), 0, torch.arange(30), torch.arange(1, 901, dtype=torch.float32).reshape(30, 30), ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_index_select.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow as flow import oneflow.unittest import unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestIndexSelect(flow.unittest.TestCase): @autotest() def test_index_select_by_random(test_case): device = random_device() # test 4 dimensions tensor dim = random(0, 4).to(int) tensor_dim = [] for i in range(0, 4): tensor_dim.append(random(2, 6).to(int).value()) index = random_tensor( ndim=1, dim0=random(1, 10).to(int), low=0, high=tensor_dim[dim.value()], dtype=int, ).to(device) x = random_tensor( ndim=4, dim0=tensor_dim[0], dim1=tensor_dim[1], dim2=tensor_dim[2], dim3=tensor_dim[3], ).to(device) y = torch.index_select(x, dim, index) return y @autotest(auto_backward=False) def test_index_select_bool_by_random(test_case): device = random_device() # test 4 dimensions tensor dim = random(0, 4).to(int) tensor_dim = [] for i in range(0, 4): tensor_dim.append(random(2, 6).to(int).value()) index = random_tensor( ndim=1, dim0=random(1, 10).to(int), low=0, high=tensor_dim[dim.value()], dtype=int, ).to(device) x = random_tensor( ndim=4, dim0=tensor_dim[0], dim1=tensor_dim[1], dim2=tensor_dim[2], dim3=tensor_dim[3], ).to(device=device, dtype=torch.bool) y = torch.index_select(x, dim, index) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_info.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest def _test_finfo(test_case, dtype): # test finfo without input params if dtype is None: finfo = torch.finfo() else: finfo = torch.finfo(dtype) torch_finfo = finfo.pytorch flow_finfo = finfo.oneflow test_case.assertEqual(torch_finfo.max, flow_finfo.max) test_case.assertEqual(torch_finfo.min, flow_finfo.min) test_case.assertEqual(torch_finfo.bits, flow_finfo.bits) test_case.assertEqual(torch_finfo.eps, flow_finfo.eps) test_case.assertEqual(torch_finfo.tiny, flow_finfo.tiny) test_case.assertEqual(torch_finfo.resolution, flow_finfo.resolution) @flow.unittest.skip_unless_1n1d() class TestTypeInfo(flow.unittest.TestCase): def test_iinfo(test_case): for dtype in [torch.uint8, torch.int8, torch.int32, torch.int64]: iinfo = torch.iinfo(dtype) # checker not implemented for type and # so return all fields as a tuple return iinfo.max, iinfo.min, iinfo.bits def test_finfo(test_case): for dtype in [None, torch.half, torch.bfloat16, torch.float, torch.double]: _test_finfo(test_case, dtype) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_initializer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest class DataChecker: check_list = [ "mean", "std", "min", "max", "value", "lambda_func", ] def __init__(self, **kwargs): self.checkers = {} for key in self.check_list: if key in kwargs: self.checkers[key] = kwargs[key] def __call__(self, test_case, tensor): for func in ["mean", "std"]: if func in self.checkers: of_res = eval(f"tensor.{func}")().numpy() checker_res = self.checkers[func] test_case.assertTrue( np.allclose(of_res, checker_res, rtol=1e-1, atol=1e-1), f"{func} not equal, {of_res} vs {checker_res}", ) if "min" in self.checkers: test_case.assertTrue(np.all(tensor.numpy() >= self.checkers["min"])) if "max" in self.checkers: test_case.assertTrue(np.all(tensor.numpy() <= self.checkers["max"])) if "value" in self.checkers: test_case.assertTrue(np.all(tensor.numpy() == self.checkers["value"])) if "lambda_func" in self.checkers: test_case.assertTrue( np.allclose( tensor.numpy(), self.checkers["lambda_func"](tensor.shape), rtol=1e-4, atol=1e-4, ) ) # NOTE(wyg): register initializers to this list check_func_list = [ # oneflow.nn.init.normal_ { "func": flow.nn.init.normal_, "params": {"mean": 0.0, "std": 1.0}, "checker": DataChecker(mean=0.0, std=1.0), }, # oneflow.nn.init.xavier_normal_ { "func": flow.nn.init.xavier_normal_, "params": {"gain": 1.0}, "checker": DataChecker(mean=0.0, std=0.0625), }, # oneflow.nn.init.kaiming_normal_ { "func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in"}, "checker": DataChecker(mean=0.0, std=0.0883883476), }, { "func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_out"}, "checker": DataChecker(mean=0.0, std=0.0883883476), }, { "func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "leaky_relu"}, "checker": DataChecker(mean=0.0, std=0.0395284708), }, { "func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "linear"}, "checker": DataChecker(mean=0.0, std=0.0625), }, # oneflow.nn.init.trunc_normal_ { "func": flow.nn.init.trunc_normal_, "params": {"mean": 0.0, "std": 1.0, "a": -5.0, "b": 5.0}, "checker": DataChecker(min=-5.0, max=5.0), }, # oneflow.nn.init.uniform_ { "func": flow.nn.init.uniform_, "params": {"a": 0.0, "b": 1.0}, "checker": DataChecker(min=0.0, max=1.0, mean=0.5, std=0.28849875926971436), }, # oneflow.nn.init.xavier_uniform_ { "func": flow.nn.init.xavier_uniform_, "params": {"gain": 1.0}, "checker": DataChecker( min=-0.10825317547305482, max=0.10825317547305482, mean=0.0, std=0.0625 ), }, # oneflow.nn.init.kaiming_uniform_ { "func": flow.nn.init.kaiming_uniform_, "params": {"mode": "fan_in"}, "checker": DataChecker( min=-0.15309310892394865, max=0.15309310892394865, mean=0.0, std=0.0883883476 ), }, { "func": flow.nn.init.kaiming_uniform_, "params": {"mode": "fan_out"}, "checker": DataChecker( min=-0.15309310892394865, max=0.15309310892394865, mean=0.0, std=0.0883883476 ), }, { "func": flow.nn.init.kaiming_uniform_, "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "leaky_relu"}, "checker": DataChecker( min=-0.06846531968814576,
max=0.06846531968814576, mean=0.0, std=0.0395284708, ), }, { "func": flow.nn.init.kaiming_uniform_, "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "linear"}, "checker": DataChecker( min=-0.10825317547305482, max=0.10825317547305482, mean=0.0, std=0.0625 ), }, # oneflow.nn.init.eye_ { "func": flow.nn.init.eye_, "params": {}, "checker": DataChecker(lambda_func=lambda size: np.eye(*size)), }, ] @oneflow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestInitializer(flow.unittest.TestCase): def test_initializer(test_case): default_shape = (256, 256) for device in ["cpu", "cuda"]: for check_func in check_func_list: tensor = flow.empty(*default_shape, device=flow.device(device)) check_func["func"](tensor, **check_func["params"]) try: check_func["checker"](test_case, tensor) except AssertionError as e: print( f"Failed: {check_func['func'].__name__} {check_func['params']}" ) raise e if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_instancenorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_instancenorm1d(test_case, device): input_arr = np.array( [ [ [-0.1091, 2.0041, 0.885, -0.0412], [-1.2055, 0.7442, 2.33, 1.2411], [-1.2466, 0.3667, 1.2267, 0.3043], ], [ [-0.2484, -1.1407, 0.3352, 0.6687], [-0.2975, -0.0227, -0.2302, -0.3762], [-0.7759, -0.6789, 1.1444, 1.8077], ], ], dtype=np.float32, ) output_arr = np.array( [ [ [-0.9262, 1.5395, 0.2337, -0.847], [-1.5486, -0.026, 1.2125, 0.3621], [-1.5807, 0.2287, 1.1933, 0.1587], ], [ [-0.2215, -1.5212, 0.6285, 1.1143], [-0.5016, 1.5917, 0.011, -1.1011], [-1.0207, -0.9346, 0.6833, 1.2719], ], ], dtype=np.float32, ) m = flow.nn.InstanceNorm1d(num_features=3, eps=1e-05, momentum=0.1).to( device=flow.device(device) ) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output_arr, rtol=1e-3, atol=1e-3)) m.eval() y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output_arr, rtol=1e-3, atol=1e-3)) def _test_instancenorm2d(test_case, device): input_arr = np.array( [ [ [ [-0.8791, 0.2553, 0.7403, -0.2859], [0.8006, -1.7701, -0.9617, 0.1705], [0.2842, 1.7825, 0.3365, -0.8525], ], [ [0.7332, -0.0737, 0.7245, -0.6551], [1.4461, -0.1827, 0.9737, -2.1571], [0.4657, 0.7244, 0.3378, 0.1775], ], ], [ [ [1.8896, 1.8686, 0.1896, 0.9817], [-0.0671, 1.5569, 1.1449, 0.0086], [-0.9468, -0.0124, 1.3227, -0.6567], ], [ [-0.8472, 1.3012, -1.1065, 0.9348], [1.0346, 1.5703, 0.2419, -0.7048], [0.6957, -0.4523, -0.8819, 1.0164], ], ], ], dtype=np.float32, ) output = np.array( [ [ [ [-0.9155, 0.31, 0.8339, -0.2747], [0.8991, -1.8781, -1.0048, 0.2183], 
[0.3412, 1.9598, 0.3977, -0.8868], ], [ [0.586, -0.3169, 0.5763, -0.9675], [1.3837, -0.4389, 0.8551, -2.6483], [0.2867, 0.5761, 0.1435, -0.0358], ], ], [ [ [1.374, 1.3515, -0.4466, 0.4017], [-0.7215, 1.0177, 0.5765, -0.6405], [-1.6636, -0.663, 0.7669, -1.353], ], [ [-1.1583, 1.1444, -1.4363, 0.7516], [0.8586, 1.4328, 0.009, -1.0057], [0.4954, -0.7351, -1.1955, 0.8391], ], ], ], dtype=np.float32, ) m = flow.nn.InstanceNorm2d(num_features=2, eps=1e-05, momentum=0.1).to( device=flow.device(device) ) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output, 0.0001, 0.0001)) m.eval() y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output, 0.0001, 0.0001)) def _test_instancenorm3d(test_case, device): input_arr = np.array( [ [ [ [ [1.04569761, 0.22863248, 1.42439335, 1.62249689], [-0.80578825, -0.27276461, 1.04556507, 0.56864134], [-1.24085419, -1.23960097, 0.33451416, -1.84820402], ], [ [-1.511261, 1.06157517, -0.26715858, -1.32888141], [1.17976881, -0.07931171, 0.33910684, -1.93458573], [-1.72659647, 0.79049652, 0.39102785, -1.16264882], ], ], [ [ [0.30067973, -1.2912226, -0.61508225, 0.56454001], [0.87074187, -1.69257376, 0.36119148, -0.31014289], [0.20776964, 1.26195488, -1.37122193, -0.17945234], ], [ [-0.31112407, -0.80682631, 0.8233194, 0.6384975], [0.57617527, 0.45505028, 1.68286151, -1.09590744], [-1.18127546, -1.07529277, 0.52779943, 1.21755926], ], ], ], [ [ [ [-0.12832351, 1.05625455, -0.23253249, -0.64747611], [-0.00738123, -1.41390089, -1.92664144, -0.21427625], [-0.94631219, -0.86493989, 0.21026905, 0.24989732], ], [ [1.3859182, 1.72002107, 0.50091892, 1.04198896], [0.71694594, 1.66417023, -1.63030052, 0.77182641], [0.71545083, 1.96458366, -1.99031931, 1.3196714], ], ], [ [ [1.80091702, 0.02834973, 0.82259214, -1.05597501], [-0.58212207, 0.44205949, -0.14740003, -0.994508], [1.14678114, -0.39196097, 1.2554798, -0.41829324], ], [ [-1.0153903, -0.25755713, -1.81756333, -1.06781159], [1.79680841, -1.9107133, -0.64325796, -1.94640775], [1.30671156, 1.20445339, -1.26262901, -0.79494188], ], ], ], ], dtype=np.float32, ) output_arr = np.array( [ [ [ [ [1.067, 0.3324, 1.4075, 1.5856], [-0.5976, -0.1184, 1.0669, 0.6381], [-0.9888, -0.9877, 0.4276, -1.5349], ], [ [-1.2319, 1.0813, -0.1134, -1.068], [1.1876, 0.0555, 0.4317, -1.6126], [-1.4256, 0.8376, 0.4784, -0.9185], ], ], [ [ [0.3447, -1.3751, -0.6446, 0.6298], [0.9606, -1.8087, 0.4101, -0.3152], [0.2444, 1.3833, -1.4615, -0.174], ], [ [-0.3162, -0.8518, 0.9094, 0.7097], [0.6424, 0.5115, 1.838, -1.1641], [-1.2563, -1.1418, 0.5901, 1.3353], ], ], ], [ [ [ [-0.2327, 0.8016, -0.3236, -0.6859], [-0.1271, -1.3551, -1.8028, -0.3077], [-0.9469, -0.8758, 0.063, 0.0976], ], [ [1.0895, 1.3812, 0.3167, 0.7892], [0.5054, 1.3324, -1.5441, 0.5533], [0.5041, 1.5947, -1.8584, 1.0316], ], ], [ [ [1.7507, 0.1901, 0.8894, -0.7645], [-0.3473, 0.5544, 0.0354, -0.7104], [1.1748, -0.1799, 1.2705, -0.2031], ], [ [-0.7288, -0.0616, -1.435, -0.7749], [1.7471, -1.517, -0.4012, -1.5485], [1.3156, 1.2256, -0.9465, -0.5347], ], ], ], ], dtype=np.float32, ) m = flow.nn.InstanceNorm3d(num_features=2, eps=1e-05, momentum=0.1).to( device=flow.device(device) ) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output_arr, 0.0001, 0.0001)) m.eval() y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output_arr, 0.0001, 0.0001)) def _test_instancenorm1d_backward(test_case, device): input_arr = np.array( [ [ [-0.1091, 
2.0041, 0.885, -0.0412], [-1.2055, 0.7442, 2.33, 1.2411], [-1.2466, 0.3667, 1.2267, 0.3043], ], [ [-0.2484, -1.1407, 0.3352, 0.6687], [-0.2975, -0.0227, -0.2302, -0.3762], [-0.7759, -0.6789, 1.1444, 1.8077], ], ], dtype=np.float32, ) m = flow.nn.InstanceNorm1d(num_features=2, eps=1e-05, momentum=0.1).to( device=flow.device(device) ) x = flow.tensor(input_arr, device=flow.device(device), requires_grad=True) y = m(x) z = y.sum() z.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05) ) def _test_instancenorm2d_backward(test_case, device): input_arr = np.array( [ [ [ [-0.8791, 0.2553, 0.7403, -0.2859], [0.8006, -1.7701, -0.9617, 0.1705], [0.2842, 1.7825, 0.3365, -0.8525], ], [ [0.7332, -0.0737, 0.7245, -0.6551], [1.4461, -0.1827, 0.9737, -2.1571], [0.4657, 0.7244, 0.3378, 0.1775], ], ], [ [ [1.8896, 1.8686, 0.1896, 0.9817], [-0.0671, 1.5569, 1.1449, 0.0086], [-0.9468, -0.0124, 1.3227, -0.6567], ], [ [-0.8472, 1.3012, -1.1065, 0.9348], [1.0346, 1.5703, 0.2419, -0.7048], [0.6957, -0.4523, -0.8819, 1.0164], ], ], ], dtype=np.float32, ) m = flow.nn.InstanceNorm2d(num_features=2, eps=1e-05, momentum=0.1).to( device=flow.device(device) ) x = flow.tensor(input_arr, device=flow.device(device), requires_grad=True) y = m(x) z = y.sum() z.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05) ) def _test_instancenorm3d_backward(test_case, device): input_arr = np.array( [ [ [ [ [1.04569761, 0.22863248, 1.42439335, 1.62249689], [-0.80578825, -0.27276461, 1.04556507, 0.56864134], [-1.24085419, -1.23960097, 0.33451416, -1.84820402], ], [ [-1.511261, 1.06157517, -0.26715858, -1.32888141], [1.17976881, -0.07931171, 0.33910684, -1.93458573], [-1.72659647, 0.79049652, 0.39102785, -1.16264882], ], ], [ [ [0.30067973, -1.2912226, -0.61508225, 0.56454001], [0.87074187, -1.69257376, 0.36119148, -0.31014289], [0.20776964, 1.26195488, -1.37122193, -0.17945234], ], [ [-0.31112407, -0.80682631, 0.8233194, 0.6384975], [0.57617527, 0.45505028, 1.68286151, -1.09590744], [-1.18127546, -1.07529277, 0.52779943, 1.21755926], ], ], ], [ [ [ [-0.12832351, 1.05625455, -0.23253249, -0.64747611], [-0.00738123, -1.41390089, -1.92664144, -0.21427625], [-0.94631219, -0.86493989, 0.21026905, 0.24989732], ], [ [1.3859182, 1.72002107, 0.50091892, 1.04198896], [0.71694594, 1.66417023, -1.63030052, 0.77182641], [0.71545083, 1.96458366, -1.99031931, 1.3196714], ], ], [ [ [1.80091702, 0.02834973, 0.82259214, -1.05597501], [-0.58212207, 0.44205949, -0.14740003, -0.994508], [1.14678114, -0.39196097, 1.2554798, -0.41829324], ], [ [-1.0153903, -0.25755713, -1.81756333, -1.06781159], [1.79680841, -1.9107133, -0.64325796, -1.94640775], [1.30671156, 1.20445339, -1.26262901, -0.79494188], ], ], ], ], dtype=np.float32, ) m = flow.nn.InstanceNorm3d(num_features=2, eps=1e-05, momentum=0.1).to( device=flow.device(device) ) x = flow.tensor(input_arr, device=flow.device(device), requires_grad=True) y = m(x) z = y.sum() z.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05) ) @flow.unittest.skip_unless_1n1d() class TestInstanceNorm(flow.unittest.TestCase): def test_instancenorm(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_instancenorm1d, _test_instancenorm2d, _test_instancenorm3d, _test_instancenorm1d_backward, _test_instancenorm2d_backward, _test_instancenorm3d_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) # 
NOTE: in the following test cases, setting track_running_stats=True makes them fail! # it could be a bug in nn.InstanceNorm that needs to be fixed @autotest(n=5, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True) def test_instancenorm_with_random_data(test_case): height = random(1, 6).to(int) width = random(1, 6).to(int) m = torch.nn.InstanceNorm1d( num_features=height, eps=random().to(float) | nothing(), momentum=random().to(float) | nothing(), affine=random().to(bool), track_running_stats=False, ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim1=height, dim2=width).to(device) y = m(x) return y @autotest(n=5, rtol=1e-3, atol=1e-3) def test_instancenorm_with_random_data2(test_case): channel = random(1, 6).to(int) height = random(1, 6).to(int) width = random(1, 6).to(int) m = torch.nn.InstanceNorm2d( num_features=channel, eps=random().to(float) | nothing(), momentum=random().to(float) | nothing(), affine=random().to(bool), track_running_stats=False, ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device) y = m(x) return y @autotest(n=5, rtol=1e-3, atol=1e-3) def test_instancenorm_with_random_data3(test_case): channel = random(1, 6).to(int) depth = random(1, 6).to(int) height = random(1, 6).to(int) width = random(1, 6).to(int) m = torch.nn.InstanceNorm3d( num_features=channel, eps=random().to(float) | nothing(), momentum=random().to(float) | nothing(), affine=random().to(bool), track_running_stats=False, ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=5, dim1=channel, dim2=depth, dim3=height, dim4=width).to( device ) y = m(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_interpolate.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_interpolate_linear_1d(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 4)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="linear") np_out = [[[1.0, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.0]]] test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = [[[2.0, 2.0, 2.0, 2.0]]] test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001)) input.grad = None of_out = flow.nn.functional.interpolate( input, scale_factor=2.0, mode="linear", align_corners=True ) np_out = [ [ [ 1.0, 1.4285714626312256, 1.8571429252624512, 2.2857141494750977, 2.7142856121063232, 3.142857074737549, 3.5714285373687744, 4.0, ] ] ] test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = [ [ [ 1.7142856121063232, 2.2857141494750977, 2.2857143878936768, 1.7142856121063232, ] ] ] test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001)) def _test_interpolate_nearest_1d(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 4)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest") np_out = [[[1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]]] test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = [[[2.0, 2.0, 2.0, 2.0]]] test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001)) def _test_interpolate_nearest_2d(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest") np_out = np.array( [ [ [ [1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_nearest_3d(test_case, device): input = flow.tensor( np.arange(1, 9).reshape((1, 1, 2, 2, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest") np_out = np.array( [ [ [ [ [1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0], ], [ [1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0], ], [ [5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0], ], [ [5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0], ], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() np_grad = [[[[[8.0, 8.0], [8.0, 8.0]], [[8.0, 8.0], [8.0, 8.0]]]]] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_bilinear_2d(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) 
of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="bilinear") np_out = np.array( [ [ [ [1.0, 1.25, 1.75, 2.0], [1.5, 1.75, 2.25, 2.5], [2.5, 2.75, 3.25, 3.5], [3.0, 3.25, 3.75, 4.0], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_bicubic_2d(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)).astype(np.float32), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="bicubic") np_out = np.array( [ [ [ [0.68359375, 1.015625, 1.5625, 1.89453125], [1.34765625, 1.6796875, 2.2265625, 2.55859375], [2.44140625, 2.7734375, 3.3203125, 3.65234375], [3.10546875, 3.4375, 3.984375, 4.31640625], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_bicubic_same_dim_2d(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)).astype(np.float32), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) of_out = flow.nn.functional.interpolate(input, scale_factor=1.0, mode="bicubic") np_out = [[[[1.0, 2.0], [3.0, 4.0]]]] test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() np_grad = [[[[1.0, 1.0], [1.0, 1.0]]]] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_trilinear_3d(test_case, device): input = flow.tensor( np.arange(1, 9).reshape((1, 1, 2, 2, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="trilinear") np_out = np.array( [ [ [ [ [1.0, 1.25, 1.75, 2.0], [1.5, 1.75, 2.25, 2.5], [2.5, 2.75, 3.25, 3.5], [3.0, 3.25, 3.75, 4.0], ], [ [2.0, 2.25, 2.75, 3.0], [2.5, 2.75, 3.25, 3.5], [3.5, 3.75, 4.25, 4.5], [4.0, 4.25, 4.75, 5.0], ], [ [4.0, 4.25, 4.75, 5.0], [4.5, 4.75, 5.25, 5.5], [5.5, 5.75, 6.25, 6.5], [6.0, 6.25, 6.75, 7.0], ], [ [5.0, 5.25, 5.75, 6.0], [5.5, 5.75, 6.25, 6.5], [6.5, 6.75, 7.25, 7.5], [7.0, 7.25, 7.75, 8.0], ], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() np_grad = [[[[[8.0, 8.0], [8.0, 8.0]], [[8.0, 8.0], [8.0, 8.0]]]]] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_trilinear_3d_align_corners(test_case, device): input = flow.tensor( np.arange(1, 9).reshape((1, 1, 2, 2, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) of_out = flow.nn.functional.interpolate( input, scale_factor=2.0, mode="trilinear", align_corners=True ) np_out = np.array( [ [ [ [ [1.0, 1.3333332538604736, 1.6666667461395264, 2.0], [ 1.6666666269302368, 2.0, 2.3333334922790527, 2.6666665077209473, ], [ 2.3333332538604736, 2.6666665077209473, 3.0, 3.3333334922790527, ], [3.0, 3.3333332538604736, 3.6666667461395264, 4.0], ], [ [ 2.3333334922790527, 2.6666665077209473, 3.0, 3.3333332538604736, ], [3.0, 3.3333330154418945, 3.6666665077209473, 4.0], [ 3.6666665077209473, 4.0, 4.333333492279053, 4.6666669845581055, ], [4.333333492279053, 4.666666030883789, 5.0, 5.3333330154418945], ], [ 
[3.6666667461395264, 4.0, 4.333333492279053, 4.666666507720947], [4.333333492279053, 4.666666507720947, 5.0, 5.3333330154418945], [5.0, 5.333333492279053, 5.6666669845581055, 6.0], [ 5.6666669845581055, 6.0, 6.333333492279053, 6.6666669845581055, ], ], [ [5.0, 5.3333330154418945, 5.666666507720947, 6.0], [ 5.666666507720947, 5.999999523162842, 6.3333330154418945, 6.666666507720947, ], [6.333333492279053, 6.666666030883789, 7.0, 7.333333492279053], [7.0, 7.3333330154418945, 7.6666669845581055, 8.0], ], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() np_grad = [ [ [ [[7.999999523162842, 8.0], [7.999999523162842, 8.0]], [[8.0, 8.0], [8.0, 8.0]], ] ] ] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_area_1d(test_case, device): input = flow.tensor( np.array( [ [ [ 0.05580734834074974, -0.6875145435333252, -1.654430866241455, -0.6225992441177368, 0.10183599591255188, 0.05019790679216385, -1.2537643909454346, 0.14907236397266388, ] ] ] ), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out_1 = flow.nn.functional.interpolate(input, size=4, mode="area") of_out_2 = flow.nn.functional.interpolate(input, scale_factor=0.5, mode="area") np_out = np.array( [ [ [ -0.3158535957336426, -1.1385149955749512, 0.07601694762706757, -0.5523459911346436, ] ] ] ) test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out, 1e-05, 1e-05)) of_out_1 = of_out_1.sum() of_out_1.backward() np_grad = np.array([[[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]]]) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_area_2d(test_case, device): input = flow.tensor( np.array( [ [ [ [ 0.10039155930280685, 0.04879157617688179, -1.0515470504760742, 0.9466001987457275, ], [ 0.45375481247901917, 0.23611211776733398, 1.343685269355774, 0.3979687988758087, ], [ 0.05580734834074974, -0.6875145435333252, -1.654430866241455, -0.6225992441177368, ], [ 0.10183599591255188, 0.05019790679216385, -1.2537643909454346, 0.14907236397266388, ], ] ] ] ), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out_1 = flow.nn.functional.interpolate(input, size=(2, 2), mode="area") of_out_2 = flow.nn.functional.interpolate(input, scale_factor=0.5, mode="area") np_out = np.array( [ [ [ [0.20976251363754272, 0.4091767966747284], [-0.1199183315038681, -0.8454304933547974], ] ] ] ) test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out, 1e-05, 1e-05)) of_out_1 = of_out_1.sum() of_out_1.backward() np_grad = np.array( [ [ [ [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], ] ] ] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_area_3d(test_case, device): input = flow.tensor( np.array( [ [ [ [ [ -1.077571799600885, -0.7804538890365837, -1.2627538752119443, 0.9993507145120477, ], [ 2.0222532489157516, 1.103451377699465, -0.4377324754879578, 1.890491810587517, ], [ -0.5593861899064654, -0.4949520241526519, -0.18536721363519787, -0.6098969866775772, ], [ -1.6536215260171816, -1.0392583540436786, 0.3686776597613967, -0.5356882834951805, ], ], [ [ -1.2617900664449953, -1.4390921091631532, 0.20654399652431357, 0.8186472101906713, ], [ -0.3033378863400014, -0.8173269764076293, -0.3767515097625614, 
-0.11021655039337777, ], [ -0.22977043608192885, 1.2717196366649905, -0.4790851297878291, -1.4495369404727856, ], [ -1.2802093286977783, -0.11184514806663474, 1.7022167087210984, -1.7354837287725355, ], ], [ [ 2.4706497991773606, -0.6549702631973298, -0.9318107079571676, 1.4652904271682428, ], [ 1.1419864234341397, 1.389909081086008, 0.9657841900525568, -0.8563114264976619, ], [ 0.19515087084250754, -0.37808457398571094, 0.2938625398496183, 0.9279930510353327, ], [ -0.9374118277994007, 0.3341831730452431, -0.2792542765303833, 0.38029090707066726, ], ], [ [ 0.5918686659736041, -0.7870631089938902, -0.9534344874245392, 0.31341612954718795, ], [ 0.7509029444145228, -0.9299288398562323, -0.7343054052782476, -0.8806481590696694, ], [ -0.4707853016353985, 0.12253641652645629, 0.5088022039832846, 0.520391789327562, ], [ -0.0861300651163632, 0.30291348404866386, -0.6268565873680123, -0.27469204305759976, ], ], ] ] ] ), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out_1 = flow.nn.functional.interpolate(input, size=(2, 2, 2), mode="area") of_out_2 = flow.nn.functional.interpolate(input, scale_factor=0.5, mode="area") np_out = np.array( [ [ [ [ [-0.3192335125472539, 0.2159474151198386], [-0.5121654212876662, -0.3655204892948264], ], [ [0.4966693377547728, -0.2015024299324123], [-0.11470347800925032, 0.18131719803880864], ], ] ] ] ) test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out, 1e-05, 1e-05)) of_out_1 = of_out_1.sum() of_out_1.backward() np_grad = np.array( [ [ [ [ [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], ], [ [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], ], [ [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], ], [ [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], [0.125, 0.125, 0.125, 0.125], ], ] ] ] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_output_size_arg_with_scalar(test_case, device): mode = "bicubic" x = flow.Tensor(8, 32, 64).to(device) window = 16 t = x.shape[2] x = x[:, None] np_center = np.random.randint(window, t - window, (1,))[0] np_warped = np.random.randint(np_center - window, np_center + window, (1,))[0] + 1 center = flow.tensor(np_center) warped = flow.tensor(np_warped) res = flow.nn.functional.interpolate( x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False ) test_case.assertTrue(np.array_equal(res.size()[0], 8)) test_case.assertTrue(np.array_equal(res.size()[1], 1)) @flow.unittest.skip_unless_1n1d() class TestInterpolate(flow.unittest.TestCase): def test_interpolate(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_interpolate_linear_1d, _test_interpolate_nearest_1d, _test_interpolate_nearest_2d, _test_interpolate_nearest_3d, _test_interpolate_bilinear_2d, _test_interpolate_bicubic_2d, _test_interpolate_bicubic_same_dim_2d, _test_interpolate_trilinear_3d, _test_interpolate_trilinear_3d_align_corners, _test_interpolate_area_1d, _test_interpolate_area_2d, _test_interpolate_area_3d, _test_interpolate_output_size_arg_with_scalar, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): for i in range(100): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() 
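Editor's note, not part of the repository sources: the "area"-mode expectations in test_interpolate.py above (for example the constant 0.25 input gradient asserted in _test_interpolate_area_2d) follow from area interpolation with scale_factor=0.5 being plain non-overlapping 2x2 average pooling. The NumPy sketch below is an illustration under that assumption, not OneFlow code.

import numpy as np

# (N, C, H, W) layout, matching the tensors used in the tests above
x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)

# "area" downsampling by a factor of 0.5: average each non-overlapping 2x2 block
out = x.reshape(1, 1, 2, 2, 2, 2).mean(axis=(3, 5))  # shape (1, 1, 2, 2)

# every input pixel contributes to exactly one output pixel with weight 1/4,
# so the gradient of out.sum() with respect to x is 0.25 everywhere,
# matching the np_grad value asserted in _test_interpolate_area_2d
grad_of_sum = np.full_like(x, 0.25)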
================================================ FILE: python/oneflow/test/modules/test_inv.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import time import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestLinalgInv(flow.unittest.TestCase): @unittest.skip("TODO: peihong, fix this test") @autotest(n=5, rtol=1e-2) def test_inv_3by3_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim0=3, dim1=3, low=-1).to(device) return torch.linalg.inv(x) @autotest(n=5, rtol=1e-2) def test_inv_batch_3by3_with_random_data(test_case): device = random_device() x = random_tensor(ndim=3, dim0=random(), dim1=3, dim2=3, low=-1).to(device) return torch.linalg.inv(x) @autotest(n=5, rtol=1e-2) def test_inv_random_square_with_random_data(test_case): device = random_device() square_dim = random() x = random_tensor(ndim=4, dim2=square_dim, dim3=square_dim, low=-1).to(device) return torch.linalg.inv(x) @profile(torch.linalg.inv) def profile_linalg_inv(test_case): torch.linalg.inv(torch.randn(1, 32, 4, 4)) torch.linalg.inv(torch.randn(16, 32, 4, 4)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_isclose.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest rtol = 1e-3 def _perturbate(x): shape = x.oneflow.shape device = x.device diff = ( random_tensor(len(shape), *shape, low=-1, high=1, requires_grad=False).to( device ) * rtol * 2 ) return x + diff @flow.unittest.skip_unless_1n1d() class TestIsClose(flow.unittest.TestCase): @autotest(n=10, auto_backward=False, check_graph=False) def test_isclose_with_random_data(test_case): device = random_device() x1 = random_tensor(requires_grad=False).to(device) x2 = _perturbate(x1) y = torch.isclose(x1, x2, rtol=rtol) return y @autotest(n=10, auto_backward=False, check_graph=False) def test_isclose_with_0dim_data(test_case): device = random_device() x1 = random_tensor(requires_grad=False).to(device) x2 = _perturbate(x1) y = torch.isclose(x1, x2, rtol=rtol) return y @autotest(n=10, auto_backward=False, check_graph=False) def test_tensor_isclose_with_random_data(test_case): device = random_device() x1 = random_tensor(requires_grad=False).to(device) x2 = _perturbate(x1) y = x1.isclose(x2, rtol=rtol) return y @autotest(n=10, auto_backward=False, check_graph=False) def test_isclose_broadcast(test_case): device = random_device() shape = random_tensor(2, 2, 4).oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = _perturbate(x1[:, :1]) y = torch.isclose(x1, x2, rtol=rtol) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_jit_script_api.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList def _test_jit_script_api(test_case): @flow.jit.script def add2(x): return x + x x = flow.randn(2, 3) y = add2(x) test_case.assertTrue(x.size(), y.size()) def _test_jit_ignore_api(test_case): @flow.jit.ignore def add2(x): return x + x x = flow.randn(2, 3) y = add2(x) test_case.assertTrue(x.size(), y.size()) @flow.unittest.skip_unless_1n1d() class TestJitScriptApi(flow.unittest.TestCase): def test_jit_script(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_jit_script_api, _test_jit_ignore_api] for arg in GenArgList(arg_dict): arg[0](test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_layer_norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import numpy as np import unittest import oneflow as flow import oneflow.unittest import torch def _layer_norm(x, normalized_shape, weight=None, bias=None, eps=1e-6): begin_norm_axis = len(x.shape) - len(normalized_shape) begin_params_axis = len(x.shape) - len(normalized_shape) if weight is not None and bias is not None: return flow._C.layer_norm_affine( x, weight, bias, begin_norm_axis=begin_norm_axis, begin_params_axis=begin_params_axis, epsilon=eps, ) else: return flow._C.layer_norm( x, begin_norm_axis=begin_norm_axis, begin_params_axis=begin_params_axis, epsilon=eps, ) def _test_layer_norm( test_case, shape, normalized_shape, affine=True, eps=1e-6, dtype=flow.float32, device="cuda", backward=True, ): np_x = np.random.randn(*shape).astype(np.float32) if affine: np_weight = np.random.randn(*normalized_shape).astype(np.float32) np_bias = np.random.randn(*normalized_shape).astype(np.float32) # torch process torch_dtype = torch.float16 if dtype is flow.float16 else torch.float32 torch_x = torch.tensor(np_x).to(device=device, dtype=torch_dtype) if backward: torch_x.requires_grad_(True) torch_weight = None torch_bias = None if affine: torch_weight = torch.tensor(np_weight).to(device=device, dtype=torch_dtype) torch_bias = torch.tensor(np_bias).to(device=device, dtype=torch_dtype) if backward: torch_weight.requires_grad_(True) torch_bias.requires_grad_(True) torch_y = torch.nn.functional.layer_norm( torch_x, normalized_shape, torch_weight, torch_bias, eps ) if backward: np_rand_init_grad = np.random.randn(*tuple(torch_y.shape)).astype(np.float32) torch_rand_init_grad = torch.tensor(np_rand_init_grad).to( device=device, dtype=torch_dtype ) (torch_y * torch_rand_init_grad).sum().backward() torch_x_grad = torch_x.grad.detach().cpu().numpy() if affine: torch_weight_grad = torch_weight.grad.detach().cpu().numpy() torch_bias_grad = torch_bias.grad.detach().cpu().numpy() torch_y = torch_y.detach().cpu().numpy() # oneflow process x = flow.tensor(np_x).to(device=device, dtype=dtype) if backward: x.requires_grad_(True) weight = None bias = None if affine: weight = flow.tensor(np_weight).to(device=device, dtype=dtype) bias = flow.tensor(np_bias).to(device=device, dtype=dtype) if backward: weight.requires_grad_(True) bias.requires_grad_(True) y = _layer_norm(x, normalized_shape, weight, bias, eps) if backward: # np_rand_init_grad = np.random.randn(*tuple(y.shape)).astype(np.float32) rand_init_grad = flow.tensor(np_rand_init_grad).to(device=device, dtype=dtype) (y * rand_init_grad).sum().backward() x_grad = x.grad.detach().cpu().numpy() if affine: weight_grad = weight.grad.detach().cpu().numpy() bias_grad = bias.grad.detach().cpu().numpy() y = y.detach().cpu().numpy() def compare(a, b, a_name, b_name, atol=1e-5, rtol=1e-8): test_case.assertTrue( np.allclose(a, b, atol=atol, rtol=rtol), f"\n{'=' * 80}" f"\n{a_name}:" f"\n{a}" f"\n{'-' * 80}" f"\n{b_name}:" f"\n{b}" f"\n{'-' * 80}" f"\ndiff:" f"\n{a - b}" f"\n{'*' * 80}" f"\nshape={shape}" f"\nnormalized_shape={normalized_shape}" f"\naffine={affine}" f"\ndtype={dtype}" f"\ndevice={device}" f"\n{a_name} vs.
{b_name} max abs diff: {np.max(np.abs(a - b))}", ) if dtype is flow.float16: compare(y, torch_y, "y", "torch_y", 1e-2, 1e-2) if backward: compare(x_grad, torch_x_grad, "x_grad", "torch_x_grad", 1e-2, 1e-2) if affine: compare( weight_grad, torch_weight_grad, "weight_grad", "torch_weight_grad", 1e-2, 1e-2, ) compare( bias_grad, torch_bias_grad, "bias_grad", "torch_bias_grad", 1e-2, 1e-2, ) else: compare(y, torch_y, "y", "torch_y") if backward: compare(x_grad, torch_x_grad, "x_grad", "torch_x_grad") if affine: compare( weight_grad, torch_weight_grad, "weight_grad", "torch_weight_grad", ) compare( bias_grad, torch_bias_grad, "bias_grad", "torch_bias_grad", ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestLayerNorm(flow.unittest.TestCase): def test_no_affine(test_case): _test_layer_norm( test_case, shape=[4, 16], normalized_shape=[16], affine=False, ) def test_warp_impl(test_case): _test_layer_norm( test_case, shape=[32, 1024], normalized_shape=[1024], dtype=flow.float16, ) _test_layer_norm(test_case, shape=[16, 512], normalized_shape=[512]) _test_layer_norm(test_case, shape=[15, 512], normalized_shape=[512]) _test_layer_norm(test_case, shape=[16, 511], normalized_shape=[511]) _test_layer_norm(test_case, shape=[13, 499], normalized_shape=[499]) def test_block_smem_impl(test_case): _test_layer_norm( test_case, shape=[16, 2048], normalized_shape=[2048], dtype=flow.float16, ) _test_layer_norm(test_case, shape=[8, 1536], normalized_shape=[1536]) _test_layer_norm(test_case, shape=[8, 2048], normalized_shape=[2048]) _test_layer_norm(test_case, shape=[7, 1536], normalized_shape=[1536]) _test_layer_norm(test_case, shape=[8, 1533], normalized_shape=[1533]) _test_layer_norm(test_case, shape=[7, 1533], normalized_shape=[1533]) def test_block_uncached_impl(test_case): _test_layer_norm( test_case, shape=[16, 1024 * 1024], normalized_shape=[1024 * 1024], dtype=flow.float16, ) _test_layer_norm( test_case, shape=[8, 1024], normalized_shape=[1024], dtype=flow.double ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_lerp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestLerp(flow.unittest.TestCase): @autotest(check_graph=False) def test_lerp_with_broadcast_data(test_case): device = random_device() start = random_tensor(ndim=2, dim0=3, dim1=1).to(device) end = random_tensor(ndim=2, dim0=1, dim1=3).to(device) weight = random_tensor(ndim=1, dim0=1).to(device) return torch.lerp(start, end, weight) @autotest() def test_lerp_with_random_data(test_case): device = random_device() start = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) end = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) weight = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) return torch.lerp( start, end, oneof(weight, random().to(int), random().to(float)) ) @autotest() def test_tesnor_lerp_with_random_data(test_case): device = random_device() start = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) end = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) weight = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) return start.lerp(end, oneof(weight, random().to(int), random().to(float))) @autotest() def test_tesnor_inplace_lerp_with_random_data(test_case): device = random_device() start = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) + 0.01 end = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) + 0.01 weight = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) + 0.01 return start.lerp_(end, oneof(weight, random().to(int), random().to(float))) @profile(torch.lerp) def profile_lerp(test_case): torch.lerp( torch.randn(1, 32, 4, 4), torch.randn(1, 32, 4, 4), torch.randn(1, 32, 4, 4) ) torch.lerp( torch.randn(8, 32, 4, 4), torch.randn(8, 32, 4, 4), torch.randn(8, 32, 4, 4) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_less.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_less_normal(test_case, device): input1 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = flow.lt(input1, input2) np_out = np.less(input1.numpy(), input2.numpy()) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_less_symbol(test_case, device): input1 = flow.tensor( np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) input2 = flow.tensor( np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = input1 < input2 np_out = np.less(input1.numpy(), input2.numpy()) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_less_int_scalar(test_case, device): np_arr = np.random.randn(2, 3, 4, 5) input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) input2 = 1 of_out = input1 < input2 np_out = np.less(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_less_int_tensor_int_scalr(test_case, device): np_arr = np.random.randint(2, size=(2, 3, 4, 5)) input1 = flow.tensor(np_arr, dtype=flow.int, device=flow.device(device)) input2 = 1 of_out = input1 < input2 np_out = np.less(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_less_float_scalar(test_case, device): np_arr = np.random.randn(3, 2, 5, 7) input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) input2 = 2.3 of_out = input1 < input2 np_out = np.less(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestLess(flow.unittest.TestCase): def test_less(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_less_normal, _test_less_symbol, _test_less_int_scalar, _test_less_int_tensor_int_scalr, _test_less_float_scalar, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, auto_backward=False, check_graph=True) def test_less_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = torch.lt(x1, oneof(x2, random().to(int).to(float))) return y @autotest(n=10, auto_backward=False, check_graph=True) def test_less_with_0dim_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(ndim=0).to(device) x2 = random_tensor(ndim=0).to(device) y = torch.lt(x1, oneof(x2, random().to(int).to(float))) return y @autotest(n=10, auto_backward=False, check_graph=True) def test_tensor_less_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device) y1 = x1.lt(oneof(x2, random().to(int), random().to(float))) y2 = x1 < x2 return (y1, y2) @autotest(n=10, auto_backward=False, check_graph=True) def test_less_bool_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, 
requires_grad=False).to( device=device, dtype=torch.bool ) x2 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) y = torch.lt(x1, oneof(x2, random().to(int).to(float))) return y @autotest(n=10, auto_backward=False, check_graph=True) def test_tensor_less_with_0dim_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(ndim=0).to(device) x2 = random_tensor(ndim=0).to(device) y1 = x1.lt(oneof(x2, random().to(int), random().to(float))) y2 = x1 < x2 return (y1, y2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_less_equal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_less_equal_normal(test_case, device): input1 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = flow.le(input1, input2) np_out = np.less_equal(input1.numpy(), input2.numpy()) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_less_equal_symbol(test_case, device): input1 = flow.tensor( np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) input2 = flow.tensor( np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = input1 <= input2 np_out = np.less_equal(input1.numpy(), input2.numpy()) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_less_equal_int_scalar(test_case, device): np_arr = np.random.randn(2, 3, 4, 5) input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) input2 = 1 of_out = input1 <= input2 np_out = np.less_equal(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_less_equal_int_tensor_int_scalr(test_case, device): np_arr = np.random.randint(2, size=(2, 3, 4, 5)) input1 = flow.tensor(np_arr, dtype=flow.int, device=flow.device(device)) input2 = 1 of_out = input1 <= input2 np_out = np.less_equal(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_less_equal_float_scalar(test_case, device): np_arr = np.random.randn(3, 2, 5, 7) input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) input2 = 2.3 of_out = input1 <= input2 np_out = np.less_equal(np_arr, input2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestLessEqual(flow.unittest.TestCase): def test_less_equal(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_less_equal_normal, _test_less_equal_symbol, _test_less_equal_int_scalar, _test_less_equal_int_tensor_int_scalr, _test_less_equal_float_scalar, ] 
arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_linalg_cross.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestLinalgCross(flow.unittest.TestCase): # TODO(peihong): PyTorch 1.10 has no torch.linalg.cross, so uncomment the below code when PyTorch in ci is upgraded to 1.11. # @autotest(n=5) # def test_linalg_cross_with_random_data(test_case): # device = random_device() # ndim = np.random.randint(2, 6) # shape = list(np.random.randint(16, size=ndim)) # index = np.random.randint(ndim) # shape[index] = 3 # x = random_tensor(ndim, *shape).to(device) # y = random_tensor(ndim, *shape).to(device) # return torch.linalg.cross(x, y, dim=index) # @autotest(n=10) # def test_linalg_cross_with_random_data_broadcast(test_case): # device = random_device() # ndim = np.random.randint(3, 6) # shape = list(np.random.randint(16, size=ndim)) # indexes = list(np.random.choice(ndim, 3)) # shape[indexes[0]] = 3 # x_shape = shape # y_shape = shape[:] # x_shape[indexes[1]] = 1 # y_shape[indexes[2]] = 1 # x = random_tensor(ndim, *x_shape).to(device) # y = random_tensor(ndim, *y_shape).to(device) # return torch.linalg.cross(x, y, dim=indexes[0]) # @autotest(n=1) # def test_linalg_cross_with_random_data_broadcast_different_num_axes(test_case): # device = random_device() # x = random_tensor(4, 4, 5, 3, 5).to(device) # y = random_tensor(3, 1, 3, 5).to(device) # return torch.linalg.cross(x, y, dim=2) # @autotest(n=5) # def test_linalg_cross_with_random_data_default_dim(test_case): # device = random_device() # ndim = np.random.randint(2, 6) # shape = list(np.random.randint(16, size=ndim)) # index = np.random.randint(ndim) # shape[index] = 3 # x = random_tensor(ndim, *shape).to(device) # y = random_tensor(ndim, *shape).to(device) # return torch.linalg.cross(x, y) @autotest(n=5) def test_cross_with_random_data_default_dim(test_case): device = random_device() ndim = np.random.randint(2, 6) shape = list(np.random.randint(16, size=ndim)) index = np.random.randint(ndim) shape[index] = 3 x = random_tensor(ndim, *shape).to(device) y = random_tensor(ndim, *shape).to(device) return torch.cross(x, y) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_linear.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_linear_no_bias(test_case, device): linear = flow.nn.Linear(3, 8, False) linear = linear.to(device) input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.3) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) flow.nn.init.constant_(linear.weight, 2.3) of_out = linear(x) np_out = np.matmul(input_arr, np_weight) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_linear_with_bias(test_case, device): linear = flow.nn.Linear(3, 8) linear = linear.to(device) input_arr = np.array( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=np.float32, ) np_weight = np.ones((3, 8)).astype(np.float32) np_weight.fill(2.068758) np_bias = np.ones(8) np_bias.fill(0.23) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_out = linear(x) np_out = np.matmul(input_arr, np_weight) np_out += np_bias test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_linear_3_dimension_input(test_case, device): input_arr = np.random.randn(2, 3, 4) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) linear = flow.nn.Linear(4, 5, True) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 5.6) flow.nn.init.constant_(linear.bias, 0.78) of_out = linear(x) np_weight = np.ones((4, 5)).astype(np.float32) np_weight.fill(5.6) np_bias = np.ones(5) np_bias.fill(0.78) np_out = np.matmul(input_arr, np_weight) np_out += np_bias test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_linear_4_dimension_input(test_case, device): input_arr = np.random.randn(4, 5, 6, 7) x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device)) linear = flow.nn.Linear(7, 3, False) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 11.3) of_out = linear(x) np_weight = np.ones((7, 3)).astype(np.float32) np_weight.fill(11.3) np_out = np.matmul(input_arr, np_weight) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_identity(test_case, device): linear = flow.nn.Identity(54, unused_argument1=0.1, unused_argument2=False) linear = linear.to(device) x = flow.tensor( np.random.rand(2, 3, 4, 5), 
dtype=flow.float32, device=flow.device(device) ) y = linear(x) test_case.assertTrue(np.array_equal(x.numpy(), y.numpy())) def _test_linear_backward_with_bias(test_case, device): linear = flow.nn.Linear(3, 8) linear = linear.to(device) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, device=flow.device(device), requires_grad=True, ) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_out = linear(x) of_out = of_out.sum() of_out.backward() np_grad = np.array( [ [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], ] ) test_case.assertTrue(np.allclose(np_grad, x.grad.numpy(), 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestLinear(flow.unittest.TestCase): def test_linear_forward(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_linear_no_bias, _test_linear_with_bias, _test_linear_3_dimension_input, _test_linear_4_dimension_input, _test_identity, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_linear_backward(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_linear_backward_with_bias] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, rtol=1e-2) def test_linear_with_random_data(test_case): input_size = random() m = torch.nn.Linear( in_features=input_size, out_features=random(), bias=random() | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=2, dim1=input_size).to(device) y = m(x) return y @autotest(n=5, rtol=1e-2, atol=1e-4) def test_linear_with_device_and_dtype(test_case): input_size = random() device = random_device() m = torch.nn.Linear( in_features=input_size, out_features=random(), bias=random() | nothing(), device=device, dtype=torch.float, ) m.train(random()) m.to(device) x = random_tensor(ndim=2, dim1=input_size).to(device) y = m(x) return y @autotest(n=5, rtol=1e-3) def test_nn_functional_linear_with_random_data(test_case): input_size = random() device = random_device() x = random_tensor(ndim=2, dim1=input_size).to(device) weight = random_tensor(ndim=2, dim1=input_size).to(device) y = torch.nn.functional.linear(x, weight) return y @autotest(n=5, rtol=1e-2) def test_nn_functional_bias_linear_with_random_data(test_case): input_size = random() bias_size = random() device = random_device() x = random_tensor(ndim=2, dim1=input_size).to(device) weight = random_tensor(ndim=2, dim0=bias_size, dim1=input_size).to(device) bias = random_tensor(ndim=1, dim0=bias_size).to(device) y = torch.nn.functional.linear(x, weight, bias) return y @autotest(n=5) def test_identity_with_random_data(test_case): m = torch.nn.Identity( x=random().to(int), unused_argument1=random().to(float), unused_argument2=random().to(float), ) m.train(random()) device = random_device() m.to(device) x = random_tensor().to(device) y = m(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_linspace.py 
================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestLinspace(flow.unittest.TestCase): @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_linspace_int_with_random_data(test_case): start = random().to(int) end = start + random().to(int) steps = random(0, end - start).to(int) x = torch.linspace(start=start, end=end, steps=steps) device = random_device() x.to(device) return x @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_linspace_float_with_random_data(test_case): start = random() end = start + random() steps = random(0, end - start).to(int) x = torch.linspace(start=start, end=end, steps=steps) device = random_device() x.to(device) return x @autotest(n=5, auto_backward=False) def test_linspace_with_scalar_tensor_as_params(test_case): start = random_tensor(2, 3, 4, requires_grad=False).mean() end = start + random_tensor(2, 3, 4, requires_grad=False).mean() steps = random(0, 10).to(int) y = torch.linspace(start=start, end=end, steps=steps) return y def test_global_naive(test_case): placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x = flow.linspace(start=0, end=10, steps=2, placement=placement, sbp=sbp) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def test_linspace_in_transformer_bug(test_case): drop_path_rate = 0.1 depths = [2, 2, 6, 2] flow_res = flow.linspace(0, drop_path_rate, sum(depths)) torch_res = np.array( [ 0.0000, 0.0091, 0.0182, 0.0273, 0.0364, 0.0455, 0.0545, 0.0636, 0.0727, 0.0818, 0.0909, 0.1000, ] ) test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4)) drop_path_rate = 0.2 depths = [2, 2, 6, 2] flow_res = flow.linspace(0, drop_path_rate, sum(depths)) torch_res = np.array( [ 0.0000, 0.0182, 0.0364, 0.0545, 0.0727, 0.0909, 0.1091, 0.1273, 0.1455, 0.1636, 0.1818, 0.2000, ] ) test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4)) drop_path_rate = 0.3 depths = [2, 2, 18, 2] flow_res = flow.linspace(0, drop_path_rate, sum(depths)) torch_res = np.array( [ 0.0000, 0.0130, 0.0261, 0.0391, 0.0522, 0.0652, 0.0783, 0.0913, 0.1043, 0.1174, 0.1304, 0.1435, 0.1565, 0.1696, 0.1826, 0.1957, 0.2087, 0.2217, 0.2348, 0.2478, 0.2609, 0.2739, 0.2870, 0.3000, ] ) test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4)) drop_path_rate = 0.1 depths = [2, 2, 18, 2] flow_res = flow.linspace(0, drop_path_rate, sum(depths)) torch_res = np.array( [ 0.0000, 0.0043, 0.0087, 0.0130, 0.0174, 0.0217, 0.0261, 0.0304, 0.0348, 0.0391, 0.0435, 0.0478, 0.0522, 0.0565, 0.0609, 0.0652, 0.0696, 0.0739, 0.0783, 0.0826, 0.0870, 0.0913, 0.0957, 0.1000, ] ) test_case.assertTrue(np.allclose(flow_res.numpy(), 
torch_res, atol=1e-4)) drop_path_rate = 0.5 depths = [2, 2, 18, 2] flow_res = flow.linspace(0, drop_path_rate, sum(depths)) torch_res = np.array( [ 0.0000, 0.0217, 0.0435, 0.0652, 0.0870, 0.1087, 0.1304, 0.1522, 0.1739, 0.1957, 0.2174, 0.2391, 0.2609, 0.2826, 0.3043, 0.3261, 0.3478, 0.3696, 0.3913, 0.4130, 0.4348, 0.4565, 0.4783, 0.5000, ] ) test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4)) def test_linspace_start_equal_end_bug(test_case): flow_res = flow.linspace(0, 0.0, 12).numpy() torch_res = np.array( [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ) test_case.assertTrue(np.allclose(flow_res, torch_res, atol=1e-4)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_log1p.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestLog1pModule(flow.unittest.TestCase): @autotest(check_graph=True) def test_log1p_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return torch.log1p(x) @autotest(check_graph=True) def test_log1p_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) return torch.log1p(x) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_logaddexp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestLogAddExpModule(flow.unittest.TestCase): @autotest(n=3, check_graph=True) def test_log_add_exp_against_pytorch(test_case): device = random_device() dim1 = random(1, 5) dim2 = random(1, 5) x = random_tensor(2, dim1, dim2).to(device) y = random_tensor(2, dim1, dim2).to(device) z = torch.logaddexp(x, y) return z @autotest(n=3, check_graph=True) def test_log_add_exp_with_0dim_tensor(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(2, random(1, 5), random(1, 5)).to(device) z = torch.logaddexp(x, y) return z @autotest(n=3, check_graph=True) def test_tensor_log_add_exp_against_pytorch(test_case): device = random_device() dim1 = random(1, 5) dim2 = random(1, 5) x = random_tensor(2, dim1, dim2).to(device) y = random_tensor(2, dim1, dim2).to(device) z = x.logaddexp(y) return z @autotest(n=3, check_graph=True) def test_tensor_log_add_exp_with_0dim_tensor(test_case): device = random_device() y = random_tensor(ndim=0).to(device) x = random_tensor(2, random(1, 5), random(1, 5)).to(device) z = x.logaddexp(y) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_logical_and.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow from oneflow.test_utils.automated_test_util import * def _test_logical_and(test_case, shape, dtype, device): np_input = np.random.randint(3, size=shape) np_other = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=dtype, device=flow.device(device)) other = flow.tensor(np_other, dtype=dtype, device=flow.device(device)) of_out = flow.logical_and(input, other) np_out = np.logical_and(np_input, np_other) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) x = torch.ones(3).byte() y = torch.ones(3).byte() z = (x & ~y).bool() test_case.assertTrue(np.array_equal(z.numpy(), [False, False, False])) def _test_tensor_logical_and(test_case, shape, dtype, device): np_input = np.random.randint(3, size=shape) np_other = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=dtype, device=flow.device(device)) other = flow.tensor(np_other, dtype=dtype, device=flow.device(device)) of_out = input.logical_and(other) np_out = np.logical_and(np_input, np_other) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_scalar_logical_and(test_case, shape, scalar, dtype, device): np_input = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=dtype, device=flow.device(device)) of_out = input.logical_and(scalar) np_out = np.logical_and(np_input, scalar) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestLogicalAndModule(flow.unittest.TestCase): def test_logical_and(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_logical_and, _test_tensor_logical_and, ] arg_dict["shape"] = [(2, 3), (2, 4, 5)] arg_dict["dtype"] = [flow.float32, flow.int32] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_scalar_logical_and(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_tensor_scalar_logical_and] arg_dict["shape"] = [(2, 3), (2, 4, 5)] arg_dict["scalar"] = [1, 0] arg_dict["dtype"] = [flow.float32, flow.int32] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_and_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = torch.logical_and(x1, x2) return y @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_and_bool_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) x2 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) y = torch.logical_and(x1, x2) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_logical_not.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow from oneflow.test_utils.automated_test_util import * def _test_logical_not(test_case, shape, device): np_input = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) of_out = flow.logical_not(input) np_out = np.logical_not(np_input) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_logical_not(test_case, shape, device): np_input = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) of_out = input.logical_not() np_out = np.logical_not(np_input) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestLogicalNotModule(flow.unittest.TestCase): def test_logical_not(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_logical_not, _test_tensor_logical_not, ] arg_dict["shape"] = [(2, 3), (2, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_not_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = torch.logical_not(x1) return y @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_not_bool_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) y = torch.logical_not(x1) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_logical_or.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow from oneflow.test_utils.automated_test_util import * def _test_logical_or(test_case, shape, device): np_input = np.random.randint(3, size=shape) np_other = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device)) of_out = flow.logical_or(input, other) np_out = np.logical_or(np_input, np_other) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_logical_or(test_case, shape, device): np_input = np.random.randint(3, size=shape) np_other = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device)) of_out = input.logical_or(other) np_out = np.logical_or(np_input, np_other) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_scalar_logical_or(test_case, shape, scalar, dtype, device): np_input = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=dtype, device=flow.device(device)) of_out = input.logical_or(scalar) np_out = np.logical_or(np_input, scalar) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestLogicalOrModule(flow.unittest.TestCase): def test_logical_or(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_logical_or, _test_tensor_logical_or, ] arg_dict["shape"] = [(2, 3), (2, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_scalar_logical_or(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_tensor_scalar_logical_or] arg_dict["shape"] = [(2, 3), (2, 4, 5)] arg_dict["scalar"] = [1, 0] arg_dict["dtype"] = [flow.float32, flow.int32] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_or_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = torch.logical_or(x1, x2) return y @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_or_bool_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) x2 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) y = torch.logical_or(x1, x2) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_logical_reduce.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestLogicalReduce(flow.unittest.TestCase): @autotest(n=5, auto_backward=False) def test_sum_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.sum(x, dim) @autotest(n=5, auto_backward=False) def test_mean_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.mean(x, dim) @autotest(n=5, auto_backward=False) def test_all_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.all(x, dim) @autotest(n=5, auto_backward=False) def test_any_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.any(x, dim) @autotest(n=5, auto_backward=False) def test_prod_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.prod(x, dim) @autotest(n=5, auto_backward=False) def test_sum_keepdim_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.sum(x, dim, keepdim=True) @autotest(n=5, auto_backward=False) def test_mean_keepdim_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.mean(x, dim, keepdim=True) @autotest(n=5, auto_backward=False) def test_all_keepdim_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.all(x, dim, keepdim=True) @autotest(n=5, auto_backward=False) def test_any_keepdim_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.any(x, dim, keepdim=True) @autotest(n=5, auto_backward=False) def test_prod_keepdim_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.prod(x, dim, keepdim=True) @autotest(n=5, auto_backward=False) def test_scalar_reduce_sum_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.sum(x) @autotest(n=5, auto_backward=False) def test_scalar_reduce_mean_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.mean(x) @autotest(n=5, auto_backward=False) def test_scalar_reduce_all_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.all(x) @autotest(n=5, auto_backward=False) def test_scalar_reduce_any_with_random_data(test_case): device = random_device() x = 
random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.any(x) @autotest(n=5, auto_backward=False) def test_scalar_reduce_prod_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device) return torch.prod(x) @autotest(n=5, auto_backward=False) def test_all_bool_input_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to( device, dtype=torch.bool ) return torch.all(x, dim) @autotest(auto_backward=False, check_graph=True) def test_max_bool_input_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to( device, dtype=torch.bool ) return torch.max(x, dim) @autotest(auto_backward=False, check_graph=True) def test_min_bool_input_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to( device, dtype=torch.bool ) return torch.min(x, dim) @autotest(n=5, auto_backward=False) def test_any_bool_input_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to( device, dtype=torch.bool ) return torch.any(x, dim) @autotest(n=5, auto_backward=False) def test_reduce_all_0dim_tensor(test_case): device = random_device() x = random_tensor(ndim=0, requires_grad=False).to(device) return torch.all(x) @autotest(n=5, auto_backward=False) def test_reduce_all_0size_tensor(test_case): device = random_device() x = torch.empty(0, 2).to(device) return torch.all(x) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_logical_xor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow from oneflow.test_utils.automated_test_util import * def _test_logical_xor_int(test_case, shape, device): np_input = np.random.randint(-2, 4, size=shape) np_other = np.random.randint(-2, 4, size=shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device)) of_out = flow.logical_xor(input, other) np_out = np.logical_xor(np_input, np_other) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_logical_xor_float(test_case, shape, device): np_input = np.random.uniform(low=-5, high=5, size=shape) np_other = np.random.uniform(low=-5, high=5, size=shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device)) of_out = flow.logical_xor(input, other) np_out = np.logical_xor(np_input, np_other) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_logical_xor_int(test_case, shape, device): np_input = np.random.randint(-2, 4, size=shape) np_other = np.random.randint(-2, 4, size=shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device)) of_out = input.logical_xor(other) np_out = np.logical_xor(np_input, np_other) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_logical_xor_float(test_case, shape, device): np_input = np.random.uniform(low=-5, high=5, size=shape) np_other = np.random.uniform(low=-5, high=5, size=shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device)) of_out = input.logical_xor(other) np_out = np.logical_xor(np_input, np_other) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_scalar_logical_xor(test_case, shape, scalar, dtype, device): np_input = np.random.randint(3, size=shape) input = flow.tensor(np_input, dtype=dtype, device=flow.device(device)) of_out = input.logical_xor(scalar) np_out = np.logical_xor(np_input, scalar) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestLogicalXorModule(flow.unittest.TestCase): def test_logical_xor(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_logical_xor_int, _test_tensor_logical_xor_int, _test_logical_xor_float, _test_tensor_logical_xor_float, ] arg_dict["shape"] = [(2, 3), (2, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_scalar_logical_xor(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_tensor_scalar_logical_xor] arg_dict["shape"] = [(2, 3), (2, 4, 5)] arg_dict["scalar"] = [1, 0] arg_dict["dtype"] = [flow.float32, flow.int32] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_xor_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device) x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = torch.logical_xor(x1, x2) return y @autotest(n=10, auto_backward=False, check_graph=True) def 
test_logical_xor_bool_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) x2 = random_tensor(len(shape), *shape, requires_grad=False).to( device=device, dtype=torch.bool ) y = torch.logical_xor(x1, x2) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_logspace.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestLogspace(flow.unittest.TestCase): @autotest(n=5, auto_backward=False) def test_logspace_int_with_random_data(test_case): start = random().to(int) end = start + random().to(int) steps = random(0, end - start).to(int) x = torch.logspace(start=start, end=end, steps=steps) device = random_device() x.to(device) return x @autotest(n=5, auto_backward=False) def test_logspace_float_with_random_data(test_case): start = random() end = start + random() steps = random(0, end - start).to(int) x = torch.logspace(start=start, end=end, steps=steps) device = random_device() x.to(device) return x @autotest(n=5, auto_backward=False) def test_logspace_with_random_base(test_case): start = random() end = start + random() steps = random(0, end - start).to(int) base = random(1, 4).to(float) x = torch.logspace(start=start, end=end, steps=steps, base=base) device = random_device() x.to(device) return x def test_global_naive(test_case): placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x = flow.logspace(start=0, end=10, steps=2, placement=placement, sbp=sbp) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_logsumexp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestLogSumExpModule(flow.unittest.TestCase): @autotest(n=3, check_graph=True) def test_log_sum_exp_against_pytorch(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) y = torch.logsumexp(x, dim=np.random.randint(0, 3)) return y @unittest.skipIf(True, "pytorch-1.10.0 dose not support big_value of logsumexp") @autotest(n=3, auto_backward=False, check_graph=True) def test_log_sum_exp_with_big_value(test_case): device = random_device() x = torch.tensor([100, 200]).to(device) y = torch.logsumexp(x, dim=0) return y @autotest(n=3, auto_backward=False, check_graph=True) def test_log_sum_exp_with_0_size_tensor(test_case): device = random_device() x = random_tensor(4, 4, 3, 0, 2).to(device) y = torch.logsumexp(x, dim=np.random.randint(0, 3)) return y @autotest(n=3, auto_backward=False, check_graph=True) def test_log_sum_exp_with_0dim_tensor(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.logsumexp(x, dim=0) return y @autotest(n=3, check_graph=True) def test_tensor_log_sum_exp_against_pytorch(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) y = x.logsumexp(dim=np.random.randint(0, 3)) return y @unittest.skipIf(True, "pytorch-1.10.0 dose not support big_value of logsumexp") @autotest(n=3, auto_backward=False, check_graph=True) def test_tensor_log_sum_exp_with_big_value(test_case): device = random_device() x = torch.tensor([100, 200]).to(device) y = x.logsumexp(dim=0) return y @autotest(n=3, auto_backward=False, check_graph=True) def test_tensor_log_sum_exp_with_0_size_tensor(test_case): device = random_device() x = random_tensor(4, 4, 3, 0, 2).to(device) y = x.logsumexp(dim=np.random.randint(0, 3)) return y @autotest(n=3, auto_backward=False, check_graph=True) def test_tensor_log_sum_exp_with_0dim_tensor(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = x.logsumexp(dim=0) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import numpy as np import oneflow as flow import oneflow.unittest import torch as torch_original from packaging import version def generate_necessity_for_cross_entropy_or_nll_loss(dim: int, prob: bool = False): if dim > 5 or dim < 2: raise ValueError("dim should be less than 5 or greater than 1. 
") device = random_device() num_classes = random(low=2).to(int) batch_size = random(low=10, high=100).to(int) ignore_index = ( random(0, num_classes).to(int) | nothing() if num_classes.value() > 2 and not prob else nothing() ) extra_dim = [random().to(int) for _ in range(dim - 2)] if prob: target_tensor = random_tensor( dim, batch_size, num_classes, *extra_dim, requires_grad=False, ).to(device) else: target_tensor = random_tensor( dim - 1, batch_size, *extra_dim, low=0, high=num_classes, dtype=int, requires_grad=False, ).to(device) return ( random_tensor(dim, batch_size, num_classes, *extra_dim).to(device), target_tensor, random_tensor(1, num_classes, low=0, high=3, requires_grad=False).to(device), ignore_index, device, ) def generate_necessity_for_bce_loss(dim: int): if dim > 5 or dim < 2: raise ValueError("dim should be less than 6 or greater than 1. ") device = random_device() num_classes = random(low=3).to(int) batch_size = random(low=10, high=100).to(int) extra_dim = [random().to(int) for _ in range(dim - 2)] return ( random_tensor(dim, batch_size, num_classes, low=0, high=1, *extra_dim).to( device ), random_tensor( dim, batch_size, num_classes, *extra_dim, low=0, high=num_classes, requires_grad=False, ).to(device), random_tensor( dim, batch_size, num_classes, *extra_dim, low=0, high=3, requires_grad=False ).to(device), random_tensor( 1, extra_dim[-1] if dim > 2 else num_classes, low=1, high=3, requires_grad=False, ).to(device), device, ) def _test_cross_entropy_loss(dim: int, prob: bool = False): ( x, target, weight, ignore_index, device, ) = generate_necessity_for_cross_entropy_or_nll_loss(dim, prob) m = torch.nn.CrossEntropyLoss( reduction=oneof("none", "sum", "mean", nothing()), ignore_index=ignore_index, weight=oneof(weight, nothing()), # TODO(wangyi): PyTorch under 1.12 has bug here, which returns wrong result when ignore_index >= 0 and label_smoothing > 0 label_smoothing=random(low=0, high=1) if version.parse(torch_original.__version__) >= version.parse("1.12.0") else 0, ) m.train(random()) m.to(device) y = m(x, target) return y def _test_nn_functional_cross_entropy_loss(dim: int, prob: bool): ( x, target, weight, ignore_index, device, ) = generate_necessity_for_cross_entropy_or_nll_loss(dim, prob) y1 = torch.nn.functional.cross_entropy(x, target) y2 = torch.nn.functional.cross_entropy(x, target, weight) return y1 + y2 @flow.unittest.skip_unless_1n1d() class TestCrossEntropyLossModule(flow.unittest.TestCase): @autotest(n=5) def test_cross_entropy_loss_with_random_data_dim_2(test_case): return _test_cross_entropy_loss(2, prob=False) @autotest(n=5) def test_cross_entropy_loss_with_random_data_dim_3(test_case): return _test_cross_entropy_loss(3, prob=False) @autotest(n=5) def test_cross_entropy_loss_with_random_data_dim_4(test_case): return _test_cross_entropy_loss(4, prob=False) @autotest(n=5) def test_cross_entropy_loss_with_random_data_dim_5(test_case): return _test_cross_entropy_loss(5, prob=False) @autotest(n=5) def test_nn_functional_cross_entropy_with_random_data_dim(test_case): dim = random(2, 6).to(int).value() return _test_nn_functional_cross_entropy_loss(dim, prob=False) @unittest.skip("skip for now, becase it failed 3 times in past week") @autotest(n=5) def test_cross_entropy_prob_loss_with_random_data_dim_2(test_case): return _test_cross_entropy_loss(2, prob=True) @autotest(n=5, rtol=1e-3) def test_cross_entropy_prob_loss_with_random_data_dim_3(test_case): return _test_cross_entropy_loss(3, prob=True) @unittest.skip("skip for now, becase it failed 4 times in past 
week") @autotest(n=5) def test_cross_entropy_prob_loss_with_random_data_dim_4(test_case): return _test_cross_entropy_loss(4, prob=True) @unittest.skip("skip for now, becase it failed 6 times in past week") @autotest(n=5) def test_cross_entropy_prob_loss_with_random_data_dim_5(test_case): return _test_cross_entropy_loss(5, prob=True) @autotest(n=5) def test_nn_functional_prob_cross_entropy_with_random_data_dim(test_case): dim = random(2, 6).to(int).value() return _test_nn_functional_cross_entropy_loss(dim, prob=True) def _test_nll_loss(dim=int): ( x, target, weight, ignore_index, device, ) = generate_necessity_for_cross_entropy_or_nll_loss(dim) m = torch.nn.NLLLoss( weight=oneof(weight, nothing()), reduction=oneof("none", "sum", "mean", nothing()), ignore_index=ignore_index, ) m.train(random()) m.to(device) y = m(x, target) return y @flow.unittest.skip_unless_1n1d() class TestNLLLossModule(flow.unittest.TestCase): @autotest(n=5) def test_nll_loss_with_random_data_dim_2(test_case): return _test_nll_loss(2) @autotest(n=5) def test_nll_loss_with_random_data_dim_3(test_case): return _test_nll_loss(3) @autotest(n=5) def test_nll_loss_with_random_data_dim_4(test_case): return _test_nll_loss(4) @autotest(n=5) def test_nll_loss_with_random_data_dim_5(test_case): return _test_nll_loss(5) def _test_bce_loss(dim=int, with_logits: bool = False): x, target, weight, pos_weight, device = generate_necessity_for_bce_loss(dim) m = torch.nn.BCELoss( weight=oneof(weight, nothing()), reduction=oneof("none", "sum", "mean", nothing()), ) pos_weight_for_testing_broadcast = random_tensor( 1, 1, low=1, high=3, requires_grad=False, ).to(device) if with_logits: m = torch.nn.BCEWithLogitsLoss( weight=oneof(weight, nothing()), pos_weight=oneof(pos_weight, pos_weight_for_testing_broadcast, nothing()), reduction=oneof("none", "sum", "mean", nothing()), ) m.train(random()) m.to(device) y = m(x, target) return y def _test_nn_functional_binary_cross_entropy(dim=int): (x, target, weight, pos_weight, device) = generate_necessity_for_bce_loss(dim) y = torch.nn.functional.binary_cross_entropy( x, target, weight=oneof(weight, nothing()), reduction=oneof("none", "sum", "mean", nothing()), pos_weight=oneof(pos_weight, nothing()), ) return y def _test_nn_functional_binary_cross_entropy_with_logits(dim=int): (x, target, weight, pos_weight, device) = generate_necessity_for_bce_loss(dim) y = torch.nn.functional.binary_cross_entropy_with_logits( x, target, weight=oneof(weight, nothing()), reduction=oneof("none", "sum", "mean", nothing()), ) return y def _test_nn_functional_binary_cross_entropy_with_logits_different_dtype_float_first( test_case, shape, reduction, device ): def compare(a, b): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, ) ) arr = np.random.randn(*shape) flow_pred_mask = flow.Tensor(arr).float().to(device) flow_pred_mask.requires_grad = True flow_gt_mask = flow.Tensor(arr).double().to(device) flow_loss = flow.nn.functional.binary_cross_entropy_with_logits( flow_pred_mask, flow_gt_mask, reduction=reduction ) flow_loss.sum().backward() torch_pred_mask = torch_original.Tensor(arr).float().to(device) torch_pred_mask.requires_grad = True torch_gt_mask = torch_original.Tensor(arr).double().to(device) torch_loss = torch_original.nn.functional.binary_cross_entropy_with_logits( torch_pred_mask, torch_gt_mask, reduction=reduction ) torch_loss.sum().backward() compare(flow_loss, torch_loss) compare(flow_pred_mask.grad.data, torch_pred_mask.grad.data) def 
_test_nn_functional_binary_cross_entropy_with_logits_different_dtype_double_first( test_case, shape, reduction, device ): def compare(a, b): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, ) ) arr = np.random.randn(*shape) flow_pred_mask = flow.Tensor(arr).double().to(device) flow_pred_mask.requires_grad = True flow_gt_mask = flow.Tensor(arr).float().to(device) flow_loss = flow.nn.functional.binary_cross_entropy_with_logits( flow_pred_mask, flow_gt_mask, reduction=reduction ) flow_loss.sum().backward() torch_pred_mask = torch_original.Tensor(arr).double().to(device) torch_pred_mask.requires_grad = True torch_gt_mask = torch_original.Tensor(arr).float().to(device) torch_loss = torch_original.nn.functional.binary_cross_entropy_with_logits( torch_pred_mask, torch_gt_mask, reduction=reduction ) torch_loss.sum().backward() compare(flow_loss, torch_loss) compare(flow_pred_mask.grad.data, torch_pred_mask.grad.data) @flow.unittest.skip_unless_1n1d() class TestBCELossModule(flow.unittest.TestCase): @autotest(n=5) def test_bce_loss_with_random_data_dim_2(test_case): return _test_bce_loss(2) @autotest(n=5) def test_bce_loss_with_random_data_dim_3(test_case): return _test_bce_loss(3) @autotest(n=5) def test_bce_loss_with_random_data_dim_4(test_case): return _test_bce_loss(4) @autotest(n=5) def test_bce_loss_with_random_data_dim_5(test_case): return _test_bce_loss(5) @autotest(n=5) def test_nn_functional_binary_cross_entropy(test_case): dim = random(2, 6).to(int).value() return _test_nn_functional_binary_cross_entropy(dim) @flow.unittest.skip_unless_1n1d() class TestBCEWithLogitsLossModule(flow.unittest.TestCase): @autotest(n=5) def test_bce_with_logits_loss_with_random_data_dim_2(test_case): return _test_bce_loss(2, True) @autotest(n=5) def test_bce_with_logits_loss_with_random_data_dim_3(test_case): return _test_bce_loss(3, True) @autotest(n=5) def test_bce_with_logits_loss_with_random_data_dim_4(test_case): return _test_bce_loss(4, True) @autotest(n=5) def test_bce_with_logits_loss_with_random_data_dim_5(test_case): return _test_bce_loss(5, True) @autotest(n=5) def test_nn_functional_binary_cross_entropy_with_logits(test_case): dim = random(2, 6).to(int).value() return _test_nn_functional_binary_cross_entropy_with_logits(dim) @autotest(n=5) def test_nn_functional_binary_cross_entropy_with_logits_different_dtype(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_nn_functional_binary_cross_entropy_with_logits_different_dtype_float_first, _test_nn_functional_binary_cross_entropy_with_logits_different_dtype_double_first, ] arg_dict["shape"] = [(24, 16, 80), (42, 160), (4, 54, 32, 56)] arg_dict["reduction"] = ["sum", "mean", "none"] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @flow.unittest.skip_unless_1n1d() class TestL1LossModule(flow.unittest.TestCase): @autotest(n=5) def test_l1_loss_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape).to(device) target = random_tensor(len(shape), *shape, requires_grad=False).to(device) m = torch.nn.L1Loss(reduction=oneof("none", "sum", "mean", nothing())) m.train(random()) m.to(device) y = m(x, target) return y @autotest(n=5) def _test_nn_functional_l1_loss(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape).to(device) target = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = 
torch.nn.functional.l1_loss( x, target, reduction=oneof("none", "sum", "mean", nothing()) ) return y @flow.unittest.skip_unless_1n1d() class TestSmoothL1LossModule(flow.unittest.TestCase): @autotest(n=5) def test_smooth_l1_loss_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape).to(device) target = random_tensor(len(shape), *shape, requires_grad=False).to(device) m = torch.nn.SmoothL1Loss( reduction=oneof("none", "sum", "mean", nothing()), beta=oneof(0, 0.5, 1) ) m.train(random()) m.to(device) y = m(x, target) return y @flow.unittest.skip_unless_1n1d() class TestMSELossModule(flow.unittest.TestCase): @autotest(n=5) def test_mse_loss_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape).to(device) target = random_tensor(len(shape), *shape, requires_grad=False).to(device) m = torch.nn.MSELoss(reduction=oneof("none", "sum", "mean", nothing())) m.train(random()) m.to(device) y = m(x, target) return y @autotest(n=5) def _test_nn_functional_mse_loss(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape).to(device) target = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = torch.nn.functional.mse_loss( x, target, reduction=oneof("none", "sum", "mean", nothing()) ) return y @flow.unittest.skip_unless_1n1d() class TestKLDivLossModule(flow.unittest.TestCase): @autotest(n=5) def test_kldiv_loss_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), low=0, *shape).to(device) target = random_tensor(len(shape), low=0, *shape, requires_grad=False).to( device ) m = torch.nn.KLDivLoss( reduction=oneof("none", "sum", "mean", "batchmean", nothing()), log_target=oneof(True, False, nothing()), ) m.train(random()) m.to(device) y = m(x, target) return y @autotest(n=5) def test_nn_functional_kl_div(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), low=0, *shape).to(device) target = random_tensor(len(shape), low=0, *shape, requires_grad=False).to( device ) y = torch.nn.functional.kl_div( x, target, reduction=oneof("none", "sum", "mean", "batchmean", nothing()), log_target=oneof(True, False, nothing()), ) return y @flow.unittest.skip_unless_1n1d() class TestMarginRankingLossModule(flow.unittest.TestCase): @autotest(n=5) def test_margin_ranking_loss_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x1 = random_tensor(len(shape), *shape).to(device) x2 = random_tensor(len(shape), *shape).to(device) target = random_tensor(len(shape), *shape, requires_grad=False).to(device) m = torch.nn.MarginRankingLoss( margin=oneof(0.0, 0.3, 10), reduction=oneof("none", "sum", "mean", nothing()), ) m.train(random()) m.to(device) y = m(x1, x2, target) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_loss_global.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList def get_sbp(device: str): return flow.placement.all(device), flow.sbp.split(0) shapes = {2: (128, 8), 3: (16, 8, 64), 4: (16, 8, 32, 32), 5: (16, 8, 16, 16, 16)} def compare_loss(device_type, dim, reduction, cls, data_generator): x, y, x1, y1 = data_generator(dim, device_type, *get_sbp(device_type)) reduce_loss_func = cls(reduction=reduction).to(device_type) none_loss_func = cls(reduction="none").to(device_type) loss_mean = reduce_loss_func(x, y) loss_none = ( flow.mean(none_loss_func(x1, y1)) if reduction == "mean" else flow.sum(none_loss_func(x1, y1)) ) loss_mean.backward() loss_none.backward() assert np.allclose( loss_none.to_local().numpy(), loss_mean.to_local().numpy(), rtol=1e-05, atol=1e-05, ) assert np.allclose(loss_none.numpy(), loss_mean.numpy(), rtol=1e-05, atol=1e-05,) assert np.allclose( x.grad.to_local().numpy(), x1.grad.to_local().numpy(), rtol=1e-05, atol=1e-05, ) def generate_necessity_default(dim: int, device: str, placement, sbp): shape = shapes[dim] x_np = np.random.uniform(0, 1, shape) y_np = np.random.uniform(0, 1, shape) def f(x, requires_grad): t = flow.tensor(x, device=device, requires_grad=requires_grad).to_global( placement=placement, sbp=[sbp] ) if requires_grad: t.retain_grad() return t return f(x_np, True), f(y_np, False), f(x_np, True), f(y_np, False) def generate_necessity_for_cross_entropy_or_nll_loss( dim: int, device: str, placement, sbp ): shape = shapes[dim] y_shape = (shape[0],) if dim == 2 else (shape[0], *shape[2:]) x_np = np.random.uniform(0, 1, shape) y_np = np.random.randint(0, shape[1], y_shape) def f(x, requires_grad): t = flow.tensor(x, device=device, requires_grad=requires_grad).to_global( placement=placement, sbp=[sbp] ) if requires_grad: t.retain_grad() return t return f(x_np, True), f(y_np, False), f(x_np, True), f(y_np, False) class TestBCELossOrWithLogitsConsistent(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_bce_loss(testcase): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dim"] = [2, 3, 4, 5] arg_dict["reduction"] = ["sum", "mean"] arg_dict["cls"] = [flow.nn.BCELoss, flow.nn.BCEWithLogitsLoss] arg_dict["data_generator"] = [generate_necessity_default] for arg in GenArgList(arg_dict): compare_loss(*arg) class TestCrossEntropyOrNllLossConsistent(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_cross_entropy_loss_or_nll_loss(testcase): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dim"] = [2, 3, 4, 5] arg_dict["reduction"] = ["sum", "mean"] arg_dict["cls"] = [flow.nn.CrossEntropyLoss, flow.nn.NLLLoss] arg_dict["data_generator"] = [generate_necessity_for_cross_entropy_or_nll_loss] for arg in GenArgList(arg_dict): compare_loss(*arg) class TestKLDivLossConsistent(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_kl_div_loss(testcase): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dim"] = [2, 3, 4, 5] arg_dict["reduction"] = ["sum", "mean"] 
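# NOTE: compare_loss (defined above) builds global split(0) tensors on the chosen device, then checks that the reduced loss (reduction="sum"/"mean") matches manually reducing the reduction="none" output, for both the loss value and the input gradient.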
arg_dict["cls"] = [flow.nn.KLDivLoss] arg_dict["data_generator"] = [generate_necessity_default] for arg in GenArgList(arg_dict): compare_loss(*arg) class TestSmoothL1LossConsistent(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_smooth_l1_loss(testcase): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dim"] = [2, 3, 4, 5] arg_dict["reduction"] = ["sum", "mean"] arg_dict["cls"] = [flow.nn.SmoothL1Loss] arg_dict["data_generator"] = [generate_necessity_default] for arg in GenArgList(arg_dict): compare_loss(*arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_lr_scheduler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import random import tempfile import unittest import numpy as np from collections import OrderedDict import oneflow as flow import oneflow.unittest import torch from oneflow.nn.parameter import Parameter from oneflow.test_utils.test_util import GenArgDict def compare_with_torch_reduce_lr( test_case, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps, ): optimizer_flow = flow.optim.SGD( [{"params": [Parameter(flow.Tensor([1.0]))]},], lr=TestLrScheduler.base_lr, momentum=0.9, ) optimizer_torch = torch.optim.SGD( [{"params": [torch.nn.Parameter(torch.Tensor([1.0]))]},], lr=TestLrScheduler.base_lr, momentum=0.9, ) scheduler_flow = flow.optim.lr_scheduler.ReduceLROnPlateau( optimizer_flow, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps, ) scheduler_troch = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_torch, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps, ) val_loss = 0.1 for epoch in range(15): val_loss += (random.random() - 0.5) / 10 scheduler_flow.step(val_loss) scheduler_troch.step(val_loss) for (lr1, lr2) in zip(scheduler_flow._last_lr, scheduler_troch._last_lr): test_case.assertAlmostEqual(lr1, lr2, places=5) @flow.unittest.skip_unless_1n1d() class TestLrScheduler(flow.unittest.TestCase): base_lr = 1.0 def test_cosine_decay_lr(test_case): optimizer = flow.optim.SGD( [{"params": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr ) def cosine_decay_lr_step(base_lr, current_step, decay_steps, alpha): if current_step < decay_steps: cos_decay = 0.5 * (1 + math.cos(math.pi * current_step / decay_steps)) decay_factor = (1 - alpha) * cos_decay + alpha return base_lr * decay_factor else: return base_lr * alpha alpha = 0.5 decay_steps = 10 cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR( optimizer, decay_steps=decay_steps, alpha=alpha ) for i in range(1, 21): cosine_decay_lr.step() new_lr = cosine_decay_lr_step( TestLrScheduler.base_lr, i, decay_steps, alpha ) test_case.assertAlmostEqual( cosine_decay_lr.get_last_lr()[0], new_lr, places=4 ) def test_cosine_annealing_lr(test_case): optimizer = flow.optim.SGD( [{"params": [Parameter(flow.Tensor([1.0]))]}], 
lr=TestLrScheduler.base_lr ) def cosine_annealing_lr_step(base_lr, current_step, last_lr, T_max, eta_min): if (current_step - 1 - T_max) % (2 * T_max) == 0: return ( last_lr + (TestLrScheduler.base_lr - eta_min) * (1 - math.cos(math.pi / T_max)) / 2 ) else: return (1 + math.cos(math.pi * current_step / T_max)) / ( 1 + math.cos(math.pi * (current_step - 1) / T_max) ) * (last_lr - eta_min) + eta_min T_max = 20 eta_min = 0.5 cosine_annealing_lr = flow.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=T_max, eta_min=eta_min ) numpy_last_lr = TestLrScheduler.base_lr for i in range(1, 101): cosine_annealing_lr.step() numpy_last_lr = cosine_annealing_lr_step( TestLrScheduler.base_lr, i, numpy_last_lr, T_max, eta_min ) test_case.assertAlmostEqual( cosine_annealing_lr.get_last_lr()[0], numpy_last_lr, places=4 ) def test_step_lr(test_case): optimizer = flow.optim.SGD( [{"params": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr ) def step_lr_step(base_lr, current_step, step_size, gamma): return base_lr * gamma ** (current_step // step_size) gamma = 0.1 step_size = 5 step_lr = flow.optim.lr_scheduler.StepLR( optimizer, step_size=step_size, gamma=gamma ) for i in range(1, 21): step_lr.step() new_lr = step_lr_step(TestLrScheduler.base_lr, i, step_size, gamma) test_case.assertAlmostEqual(step_lr.get_last_lr()[0], new_lr, places=5) def test_multistep_lr(test_case): optimizer = flow.optim.SGD( [{"params": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr ) def multistep_lr_step(base_lr, current_step, milestones, gamma): count = 0 for step in milestones: if current_step >= step: count += 1 return base_lr * gamma ** count gamma = 0.1 milestones = [5, 11, 15] multistep_lr = flow.optim.lr_scheduler.MultiStepLR( optimizer, milestones=milestones, gamma=gamma ) for i in range(1, 18): multistep_lr.step() new_lr = multistep_lr_step(TestLrScheduler.base_lr, i, milestones, gamma) test_case.assertAlmostEqual(multistep_lr.get_last_lr()[0], new_lr, places=5) def test_exponential_lr(test_case): optimizer = flow.optim.SGD( [{"params": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr ) def exponential_lr_step(base_lr, current_step, gamma): return base_lr * gamma ** current_step gamma = 0.1 exponential_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) for i in range(1, 21): exponential_lr.step() new_lr = exponential_lr_step(TestLrScheduler.base_lr, i, gamma) test_case.assertAlmostEqual( exponential_lr.get_last_lr()[0], new_lr, places=5 ) def test_lambda_lr(test_case): optimizer = flow.optim.SGD( [ {"params": [Parameter(flow.Tensor([1.0]))]}, {"params": [Parameter(flow.Tensor([1.0]))]}, ], lr=TestLrScheduler.base_lr, ) lambdas = [lambda step: step // 30, lambda step: 0.95 * step] def lambda_lr_step(base_lrs, current_step): return [ base_lr * lmbda(current_step) for (base_lr, lmbda) in zip(base_lrs, lambdas) ] lambda_lr = flow.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambdas) for i in range(1, 21): lambda_lr.step() new_lrs = lambda_lr_step(lambda_lr.base_lrs, i) for (lr1, lr2) in zip(lambda_lr.get_last_lr(), new_lrs): test_case.assertAlmostEqual(lr1, lr2, places=5) def test_polynomial_lr(test_case): optimizer = flow.optim.SGD( [{"params": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr ) def polynomial_lr_step(base_lr, end_lr, step, decay_steps, power, cycle): if cycle: if step == 0: step = 1 decay_steps = decay_steps * math.ceil(step / decay_steps) step = min(step, decay_steps) return (base_lr - end_lr) * (1 - step / decay_steps) ** power + 
end_lr decay_steps = 100 end_learning_rate = 1e-5 power = 2 cycle = True poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR( optimizer, decay_steps, end_learning_rate, power, cycle ) # step(0) will be invoked in LRScheduler.__init__ new_lr = polynomial_lr_step( TestLrScheduler.base_lr, end_learning_rate, 0, decay_steps, power, cycle ) test_case.assertAlmostEqual(poly_decay_lr.get_last_lr()[0], new_lr, places=4) for i in range(1, 21): poly_decay_lr.step() new_lr = polynomial_lr_step( TestLrScheduler.base_lr, end_learning_rate, i, decay_steps, power, cycle ) test_case.assertAlmostEqual( poly_decay_lr.get_last_lr()[0], new_lr, places=4 ) cycle = True poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR( optimizer, decay_steps, end_learning_rate, power, cycle ) for i in range(1, 21): poly_decay_lr.step() new_lr = polynomial_lr_step( TestLrScheduler.base_lr, end_learning_rate, i, decay_steps, power, cycle ) test_case.assertAlmostEqual( poly_decay_lr.get_last_lr()[0], new_lr, places=4 ) def test_reduce_lr_on_plateau(test_case): arg_dict = OrderedDict() arg_dict["mode"] = ["min", "max"] arg_dict["factor"] = [0.1, 0.3] arg_dict["patience"] = [2, 5] arg_dict["threshold"] = [1e-3, 1e-5] arg_dict["threshold_mode"] = ["rel", "abs"] arg_dict["cooldown"] = [0, 1] arg_dict["min_lr"] = [0, 1e-3] arg_dict["eps"] = [1e-5, 1e-8] for arg in GenArgDict(arg_dict): compare_with_torch_reduce_lr(test_case, **arg) def test_warmup_scheduler_save_and_load(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param]) cosine_scheduler = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100) lr_scheduler = flow.optim.lr_scheduler.WarmUpLR( cosine_scheduler, warmup_factor=0.1, warmup_iters=5, warmup_method="linear", ) for _ in range(random.randint(1, 10)): lr_scheduler.step() # save with tempfile.NamedTemporaryFile() as f: flow.save(lr_scheduler.state_dict(), f.name) state_dict = flow.load(f.name) # load param2 = flow.nn.Parameter(flow.ones(3, 4)) optimizer2 = flow.optim.SGD([param]) cosine_scheduler2 = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, 50) lr_scheduler2 = flow.optim.lr_scheduler.WarmUpLR( cosine_scheduler2, warmup_factor=0.5, warmup_iters=10, warmup_method="linear", ) lr_scheduler2.load_state_dict(state_dict) # compare warm up scheduler for attr in ["warmup_iters", "warmup_factor", "warmup_method", "last_step"]: test_case.assertEqual( getattr(lr_scheduler, attr), getattr(lr_scheduler2, attr) ) # compare cosine_annealing_lr for attr in ["T_max", "eta_min", "last_step"]: test_case.assertEqual( getattr(cosine_scheduler, attr), getattr(cosine_scheduler2, attr) ) @flow.unittest.skip_unless_1n1d() class WarmupLRTestCase(flow.unittest.TestCase): def test_only_warmup(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.001) warmup_lr = flow.optim.lr_scheduler.WarmupLR( optimizer, warmup_factor=0.5, warmup_iters=5, warmup_method="linear" ) expected_lrs = [ 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.001, 0.001, 0.001, 0.001, 0.001, ] lrs = [warmup_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): optimizer.step() warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_warmup_iters_0_exp_lr(test_case): lr = 0.1 gamma = 0.9 param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr) exp_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma) warmup_lr = 
flow.optim.lr_scheduler.WarmupLR( exp_lr, warmup_factor=0.5, warmup_iters=0, warmup_method="linear" ) iters = 10 lrs = [warmup_lr.get_last_lr()[0]] for _ in range(iters): warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] expected_lrs = [lr * pow(gamma, i) for i in range(iters)] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_linear_warmup_exp_lr(test_case): lr = 0.1 gamma = 0.9 param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr) exp_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma) warmup_lr = flow.optim.lr_scheduler.WarmupLR( exp_lr, warmup_factor=0.5, warmup_iters=5, warmup_method="linear" ) expected_lrs = [ 0.05, 0.0518098, 0.0536196, 0.0554294, 0.0572392, 0.059049, 0.0531441, 0.04782969, 0.043046721, 0.0387420489, ] lrs = [warmup_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_linear_warmup_prefix_exp_lr(test_case): lr = 0.1 gamma = 0.9 param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr) exp_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma) warmup_lr = flow.optim.lr_scheduler.WarmupLR( exp_lr, warmup_factor=0.5, warmup_iters=5, warmup_method="linear", warmup_prefix=True, ) expected_lrs = [ 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.09, 0.081, 0.0729, 0.06561, ] lrs = [warmup_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_constant_warmup_cosine_annealing(test_case): lr = 0.1 param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr) cos_annl_lr = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) warmup_lr = flow.optim.lr_scheduler.WarmupLR( cos_annl_lr, warmup_factor=0.5, warmup_iters=5, warmup_method="constant", ) expected_lrs = [ 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.03454915028125264, 0.020610737385376353, 0.009549150281252635, 0.002447174185242324, 0.0, 0.0024471741852423235, 0.009549150281252666, 0.020610737385376433, 0.034549150281252786, 0.050000000000000225, 0.06545084971874766, 0.079389262614624, 0.09045084971874778, 0.09755282581475812, 0.1, ] lrs = [warmup_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_linear_warmup_cosine_annealing(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.1) cos_annl_lr = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20) warmup_lr = flow.optim.lr_scheduler.WarmupLR( cos_annl_lr, warmup_factor=0.1, warmup_iters=5, warmup_method="linear", ) expected_lrs = [ 0.01, 0.025071068, 0.040142136, 0.055213203, 0.070284271, 0.085355339, 0.079389263, 0.072699525, 0.06545085, 0.057821723, 0.05, 0.042178277, 0.03454915, 0.027300475, 0.020610737, 0.014644661, 0.00954915, 0.005449674, 0.002447174, 0.000615583, ] lrs = [warmup_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( 
np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_linear_warmup_prefix_cosine_annealing(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.1) cos_annl_lr = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20) warmup_lr = flow.optim.lr_scheduler.WarmupLR( cos_annl_lr, warmup_factor=0.1, warmup_iters=5, warmup_method="linear", warmup_prefix=True, ) expected_lrs = [ 0.01, 0.028, 0.046, 0.064, 0.082, 0.1, 0.099384417, 0.097552826, 0.094550326, 0.09045085, 0.085355339, 0.079389263, 0.072699525, 0.06545085, 0.057821723, 0.05, 0.042178277, 0.03454915, 0.027300475, 0.020610737, ] lrs = [warmup_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_linear_warmup_multistep_lr(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.001) multistep_lr = flow.optim.lr_scheduler.MultiStepLR(optimizer, [10]) warmup_lr = flow.optim.lr_scheduler.WarmupLR( multistep_lr, warmup_factor=0.5, warmup_iters=5, warmup_method="linear", ) expected_lrs = [ 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.001, 0.001, 0.001, 0.001, 0.001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, ] lrs = [warmup_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): optimizer.step() warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_linear_warmup_prefix_multistep_lr(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.1) multistep_lr = flow.optim.lr_scheduler.MultiStepLR( optimizer, milestones=[5, 10] ) warmup_lr = flow.optim.lr_scheduler.WarmupLR( multistep_lr, warmup_factor=0.1, warmup_iters=5, warmup_method="linear", warmup_prefix=True, ) expected_lrs = [ 0.01, 0.028, 0.046, 0.064, 0.082, 0.1, 0.1, 0.1, 0.1, 0.1, 0.01, 0.01, 0.01, 0.01, 0.01, 0.001, 0.001, 0.001, 0.001, 0.001, ] lrs = [warmup_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): warmup_lr.step() lrs.append(warmup_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) @flow.unittest.skip_unless_1n1d() class ConstantLRTestCase(flow.unittest.TestCase): def test(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.01) constant_lr = flow.optim.lr_scheduler.ConstantLR(optimizer, 0.1, 10) expected_lrs = [ 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, ] lrs = [constant_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): constant_lr.step() lrs.append(constant_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) @flow.unittest.skip_unless_1n1d() class LinearLRTestCase(flow.unittest.TestCase): def test(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.1) linear_lr = flow.optim.lr_scheduler.LinearLR(optimizer, 0.1, 1, 10) expected_lrs = [ 0.01, 0.019, 0.028, 0.037, 0.046, 0.055, 0.064, 0.073, 0.082, 0.091, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 
0.1, 0.1, 0.1, 0.1, ] lrs = [linear_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): linear_lr.step() lrs.append(linear_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_end_factor(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.1) linear_lr = flow.optim.lr_scheduler.LinearLR(optimizer, 0.1, 0.9, 10) expected_lrs = [ 0.01, 0.018, 0.026, 0.034, 0.042, 0.05, 0.058, 0.066, 0.074, 0.082, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, 0.09, ] lrs = [linear_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): linear_lr.step() lrs.append(linear_lr.get_last_lr()[0]) lrs = lrs[:-1] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) @flow.unittest.skip_unless_1n1d() class ChainedSchedulerTestCase(flow.unittest.TestCase): def test(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) opt = flow.optim.SGD([param], lr=1) s1 = flow.optim.lr_scheduler.ConstantLR(opt, factor=0.1, total_iters=3) s2 = flow.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9) scheduler = flow.optim.lr_scheduler.ChainedScheduler([s1, s2]) expected_lrs = [0.1, 0.09, 0.081, 0.729, 0.6561, 0.59049] lrs = [scheduler.get_last_lr()[0]] for _ in range(len(expected_lrs)): scheduler.step() lrs.append(scheduler.get_last_lr()[0]) lrs = lrs[: len(expected_lrs)] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) @flow.unittest.skip_unless_1n1d() class CosineAnnealingWarmRestartsTestCase(flow.unittest.TestCase): def test_mult_1(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.1) cosa_r_lr = flow.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=10, eta_min=0.01, ) # fmt: off expected_lrs = [0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092] # fmt: on lrs = [cosa_r_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): cosa_r_lr.step() lrs.append(cosa_r_lr.get_last_lr()[0]) lrs = lrs[: len(expected_lrs)] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_mult_2(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.1) cosa_r_lr = flow.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=10, T_mult=2, eta_min=0.01, ) # fmt: off expected_lrs = [0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 
0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.0994459753267812, 0.09779754323328192, 0.09509529358847656, 0.09140576474687263, 0.08681980515339464, 0.08145033635316129, 0.07542957248827961, 0.06890576474687264, 0.0620395509268104, 0.05500000000000001, 0.04796044907318963, 0.04109423525312737, 0.034570427511720396, 0.028549663646838717, 0.023180194846605363, 0.01859423525312737, 0.014904706411523451, 0.012202456766718092, 0.010554024673218806, 0.1, 0.09986128001799077, 0.0994459753267812, 0.09875664641789544, 0.09779754323328192, 0.0965745789630079, 0.09509529358847656, 0.09336880739593416, 0.09140576474687263, 0.0892182684520014, 0.08681980515339464, 0.08422516217485827, 0.08145033635316129, 0.0785124354122177, 0.07542957248827961, 0.07222075445642905, 0.06890576474687264, 0.06550504137351576, 0.0620395509268104, 0.05853065930775304, 0.05500000000000001, 0.05146934069224699, 0.04796044907318963, 0.04449495862648427, 0.04109423525312737, 0.03777924554357097, 0.034570427511720396, 0.031487564587782305, 0.028549663646838717, 0.02577483782514174, 0.023180194846605363, 0.02078173154799861, 0.01859423525312737, 0.016631192604065852, 0.014904706411523451, 0.013425421036992097, 0.012202456766718092, 0.011243353582104555, 0.010554024673218806, 0.010138719982009242] # fmt: on lrs = [cosa_r_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): cosa_r_lr.step() lrs.append(cosa_r_lr.get_last_lr()[0]) lrs = lrs[: len(expected_lrs)] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) def test_mult_2_decay_half_limit_2(test_case): param = flow.nn.Parameter(flow.ones(3, 4)) optimizer = flow.optim.SGD([param], lr=0.1) cosa_r_lr = flow.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=10, T_mult=2, decay_rate=0.5, restart_limit=2, eta_min=0.01, ) # fmt: off expected_lrs = [0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.05, 0.04975376681190276, 0.04902113032590308, 0.04782013048376736, 0.04618033988749895, 0.044142135623730955, 0.04175570504584947, 0.03907980999479094, 0.03618033988749895, 0.03312868930080462, 0.03, 0.02687131069919539, 0.023819660112501053, 0.020920190005209068, 0.018244294954150538, 0.01585786437626905, 0.013819660112501053, 0.012179869516232645, 0.01097886967409693, 0.010246233188097247, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01] # fmt: on lrs = [cosa_r_lr.get_last_lr()[0]] for _ in range(len(expected_lrs)): cosa_r_lr.step() lrs.append(cosa_r_lr.get_last_lr()[0]) lrs = lrs[: len(expected_lrs)] test_case.assertTrue( np.allclose(lrs, expected_lrs), f"\nexpected_lrs: {expected_lrs}\nvs.\ncalculated lrs: {lrs}", ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_masked_fill.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestMaskedFill(flow.unittest.TestCase): @autotest(n=3) def test_flow_masked_fill_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() input = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) mask = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) value = random().to(float) return input.masked_fill(mask > 0.5, value) @autotest(n=3) def test_flow_masked_fill_with_0dim_data(test_case): device = random_device() input = random_tensor(ndim=0).to(device) mask = random_tensor(ndim=0).to(device) value = random().to(float) return input.masked_fill(mask > 0, value) @autotest(n=3) def test_flow_masked_fill_broadcast_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() input = random_tensor(ndim=2, dim0=1, dim1=k2).to(device) mask = random_tensor(ndim=2, dim0=k1, dim1=1).to(device) value = random().to(float) return input.masked_fill(mask > 0.5, value) @autotest(n=3) def test_flow_masked_fill_int_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() input = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) mask = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) value = random().to(int) return input.masked_fill(mask > 0.5, value) @autotest(auto_backward=False, n=3) def test_flow_masked_fill_bool_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() input = random_tensor(ndim=2, dim0=k1, dim1=k2).to( device=device, dtype=torch.bool ) mask = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) value = random().to(bool) return input.masked_fill(mask > 0.5, value) @autotest(auto_backward=False, n=3) def test_flow_masked_fill_inplace_with_random_data(test_case): device = random_device() input = random_tensor(ndim=2, dim0=10, dim1=20).to(device).clone() mask = random_tensor(ndim=2, dim0=10, dim1=20).to(device) value = random().to(float) input.masked_fill_(mask > 0.5, value) return input if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_masked_select.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_masked_select(test_case, device): x = flow.tensor( np.array([[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) mask = x.gt(0.05) of_out = flow.masked_select(x, mask) np_out = np.array([0.3139, 0.3898]) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = np.array([[0, 1], [1, 0], [0, 0]]) test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_masked_select_broadcast(test_case, device): x = flow.tensor( np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) mask = flow.tensor( np.array( [ [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], [[1.0, 0], [1.0, 1.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 1.0], [1.0, 1.0]], ] ), dtype=flow.int8, device=flow.device(device), ) of_out = flow.masked_select(x, mask) np_out = [ -0.462, 0.3898, -0.7197, -0.1657, -0.462, 0.3898, -0.7197, -0.1657, -0.462, 0.3139, -0.7197, 0.0478, -0.1657, ] test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = [[[3.0, 1.0], [2.0, 3.0], [1.0, 3.0]]] test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_masked_select_input_zero(test_case, device): x = flow.tensor( [[26, 14, 18, 14, 5, 18, 5, 18, 4, 18, 15, 18, 22, 18, 0]], device=flow.device(device), dtype=flow.int64, ) f_mask = flow.tensor( [ [ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, ] ], device=flow.device(device), dtype=flow.bool, ) y = x.masked_select(f_mask) test_case.assertTrue( np.allclose( y.numpy(), [26, 14, 18, 14, 5, 18, 5, 18, 4, 18, 15, 18, 22, 18, 0], 1e-05, 1e-05, ) ) @flow.unittest.skip_unless_1n1d() class TestMaskedSelect(flow.unittest.TestCase): def test_masked_select(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_masked_select, _test_masked_select_broadcast, _test_masked_select_input_zero, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_masked_select_broadcast(test_case): x = flow.ones(2, 3, 3) mask = flow.triu(flow.ones(3, 3), 1) flow_res = flow.masked_select(x, mask) np_res = [1, 1, 1, 1, 1, 1] test_case.assertTrue(np.allclose(flow_res.numpy(), np_res, 1e-05, 1e-05)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_math_op_higher_derivative.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_math_op_grad_grad_impl(test_case, op_name): x = random_tensor(ndim=2, low=-2, high=2).requires_grad_(True) y = eval(f"torch.{op_name}")(x) np_arr = np.random.rand(*x.oneflow.shape) init_grad = torch.tensor(np_arr).requires_grad_() x_grad = torch.autograd.grad(y, x, init_grad, retain_graph=True, create_graph=True)[ 0 ] test_case.assertTrue( np.allclose( x_grad.pytorch.detach().cpu().numpy(), x_grad.oneflow.detach().numpy(), atol=1e-4, rtol=1e-4, equal_nan=True, ) ) x_grad_grad = torch.autograd.grad(x_grad, x, init_grad, retain_graph=True)[0] test_case.assertTrue( np.allclose( x_grad_grad.pytorch.detach().cpu().numpy(), x_grad_grad.oneflow.detach().numpy(), atol=1e-4, rtol=1e-4, equal_nan=True, ) ) init_grad_grad = torch.tensor(np_arr).requires_grad_() dgrad = torch.autograd.grad(x_grad, init_grad, init_grad_grad, retain_graph=True)[0] test_case.assertTrue( np.allclose( dgrad.pytorch.detach().cpu().numpy(), dgrad.oneflow.detach().numpy(), atol=1e-4, rtol=1e-4, equal_nan=True, ) ) class TestMathOpHigherDerivative(flow.unittest.TestCase): def test_sin_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "sin") def test_cos_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "cos") def test_tan_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "tan") def test_sinh_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "sinh") def test_cosh_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "cosh") def test_tanh_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "tanh") def test_asin_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "asin") def test_acos_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "acos") def test_atan_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "atan") def test_asinh_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "asinh") def test_acosh_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "acosh") def test_atanh_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "atanh") def test_erf_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "erf") def test_erfc_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "erfc") def test_exp_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "exp") def test_exp2_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "exp2") def test_expm1_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "expm1") def test_log_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "log") def test_logsigmoid_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "nn.functional.logsigmoid") def test_log2_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "log2") def test_log1p_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "log1p") def test_reciprocal_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "reciprocal") def test_rsqrt_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "rsqrt") def test_sqrt_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "sqrt") def test_square_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "square") def test_sigmoid_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "sigmoid") def test_abs_grad_grad(test_case): _test_math_op_grad_grad_impl(test_case, "abs") if __name__ == "__main__": 
unittest.main() ================================================ FILE: python/oneflow/test/modules/test_math_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import ( GenArgList, type_name_to_flow_type, type_name_to_np_type, ) import torch as torch_original from packaging import version @flow.unittest.skip_unless_1n1d() class TestSinh(flow.unittest.TestCase): @autotest(n=5) def test_flow_sinh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.sinh(x) return y @flow.unittest.skip_unless_1n1d() class TestSin(flow.unittest.TestCase): @autotest(n=5) def test_flow_sin_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.sin() return y @flow.unittest.skip_unless_1n1d() class TestInplaceSin(flow.unittest.TestCase): @autotest(n=5) def test_flow_inplace_sin_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x + 1 # transform to non-leaf tensor y.sin_() return y def _test_cos(test_case, shape, device): input = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.cos(input) np_out = np.cos(input.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_cos_backward(test_case, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.cos(x) z = y.sum() z.backward() np_grad = -np.sin(x.numpy()) test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestCos(flow.unittest.TestCase): def test_cos(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_cos, _test_cos_backward] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @flow.unittest.skip_unless_1n1d() class TestLogModule(flow.unittest.TestCase): @autotest(n=5) def test_log_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return torch.log(x) @flow.unittest.skip_unless_1n1d() class TestSqrt(flow.unittest.TestCase): @autotest(n=10, include_complex=True) def test_sqrt_flow_with_random_data(test_case): device = random_device() x = random_tensor().to(device) z = torch.sqrt(x) return z @autotest(n=10, include_complex=True) def test_sqrt_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) z = x.sqrt() return z @flow.unittest.skip_unless_1n1d() class TestExp(flow.unittest.TestCase): @autotest(n=5) def test_flow_exp_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.exp(x) return y @flow.unittest.skip_unless_1n1d() 
class TestExp2(flow.unittest.TestCase): @autotest(n=5, auto_backward="auto") def test_flow_exp2_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) x = random_tensor().to(device).to(x_dtype) y = torch.exp2(x) return y @flow.unittest.skip_unless_1n1d() class TestRsqrt(flow.unittest.TestCase): @autotest(n=5) def test_rsqrt_flow_with_random_data(test_case): device = random_device() x = random_tensor().to(device) z = torch.rsqrt(x) return z @flow.unittest.skip_unless_1n1d() class TestSquare(flow.unittest.TestCase): @autotest(n=5) def test_square_flow_with_random_data(test_case): device = random_device() x = random_tensor().to(device) z = torch.square(x) return z @autotest(n=5) def test_square_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) z = x.square() return z @flow.unittest.skip_unless_1n1d() class TestPow(flow.unittest.TestCase): @autotest(n=5) def test_pow_float_scalar_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = random().to(float) return torch.pow(x, y) def test_pow_int_scalar_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = random().to(int) return torch.pow(x, y) @autotest(n=10) def test_reverse_pow_int_scalar_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = random().to(int) return torch.pow(y, x) @autotest(n=10) def test_symbolic_reverse_pow_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = random().to(int) return y ** x @autotest(n=5) def test_pow_elementwise_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=2).to(device) y = random_tensor(ndim=2, dim1=2).to(device) return torch.pow(x, y) @autotest(n=5) def test_pow_broadcast_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=2).to(device) y = random_tensor(ndim=2, dim1=1).to(device) return torch.pow(x, y) @autotest(n=5) def test_pow_broadcast_with_random_data_reverse(test_case): device = random_device() x = random_tensor(ndim=2, dim1=1).to(device) y = random_tensor(ndim=2, dim1=2).to(device) return torch.pow(x, y) @autotest(n=5) def test_scalar_pow_with_random_devices(test_case): x1_device = random_device() x2_device = random_device() x1 = random_tensor(2, 2, 3).to(x1_device).mean() x2 = random_tensor(2, 2, 3).to(x2_device) y = torch.pow(x1, x2) return y @flow.unittest.skip_unless_1n1d() class TestAsin(flow.unittest.TestCase): @autotest(n=5) def test_flow_asin_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.5).to(device) y = torch.asin(x) return y @autotest(n=5) def test_flow_arcsin_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.5).to(device) y = torch.arcsin(x) return y @flow.unittest.skip_unless_1n1d() class TestAsinh(flow.unittest.TestCase): @autotest(n=5) def test_flow_asinh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.asinh(x) return y @autotest(n=5) def test_flow_arcsinh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.arcsinh(x) return y @flow.unittest.skip_unless_1n1d() class TestTan(flow.unittest.TestCase): @autotest(n=5) def test_flow_tan_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.tan(x) return y @flow.unittest.skip_unless_1n1d() class TestAtan(flow.unittest.TestCase): @autotest(n=5) 
def test_flow_atan_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.atan(x) return y @autotest(n=5) def test_flow_arctan_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.arctan(x) return y @autotest(n=5) def test_flow_atan2_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=3).to(device) y = random_tensor(ndim=2, dim1=3).to(device) z = torch.atan2(x, y) return z @autotest(n=5) def test_flow_atan2_with_1elem_data(test_case): device = random_device() x = random_tensor(ndim=1, dim1=1).to(device) y = random_tensor(ndim=3, dim1=random(1, 6).to(int)).to(device) z = torch.atan2(x, y) return z @autotest(n=5) def test_flow_atanh_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.5).to(device) y = torch.atanh(x) return y @autotest(n=5) def test_flow_arctanh_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.5).to(device) y = torch.arctanh(x) return y @flow.unittest.skip_unless_1n1d() class TestTopk(flow.unittest.TestCase): @autotest(auto_backward=False) def test_flow_topk_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device) y = torch.topk( x, random(low=1, high=8).to(int), dim=random(low=1, high=4).to(int), largest=random_bool(), sorted=constant(True), ) return y[0], y[1] @flow.unittest.skip_unless_1n1d() class TestTopkReturnValues(flow.unittest.TestCase): @autotest(auto_backward=False) def test_flow_topk_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device) result = torch.topk( x, random(low=1, high=8).to(int), dim=random(low=1, high=4).to(int), largest=random_bool(), sorted=constant(True), ) return result.values, result.indices @flow.unittest.skip_unless_1n1d() class TestPow(flow.unittest.TestCase): @autotest(n=5) def test_pow_scalar_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = random().to(float) return torch.pow(x, y) @autotest(n=5) def test_pow_elementwise_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=2).to(device) y = random_tensor(ndim=2, dim1=2).to(device) return torch.pow(x, y) @unittest.skip("not support for broadcast currently") @autotest(n=5) def test_pow_broadcast_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=2).to(device) y = random_tensor(ndim=2, dim1=1).to(device) return torch.pow(x, y) @flow.unittest.skip_unless_1n1d() class TestArccos(flow.unittest.TestCase): @autotest(n=5) def test_arccos_flow_with_random_data(test_case): device = random_device() x = random_tensor(low=-1, high=1).to(device) y = torch.arccos(x) return y @flow.unittest.skip_unless_1n1d() class TestAcos(flow.unittest.TestCase): @autotest(n=5) def test_acos_flow_with_random_data(test_case): device = random_device() x = random_tensor(low=-1, high=1).to(device) y = torch.acos(x) return y @flow.unittest.skip_unless_1n1d() class TestArccosh(flow.unittest.TestCase): @autotest(n=5) def test_arccosh_flow_with_random_data(test_case): device = random_device() x = random_tensor(low=2, high=3).to(device) y = torch.arccosh(x) return y @flow.unittest.skip_unless_1n1d() class TestAcosh(flow.unittest.TestCase): @autotest(n=5) def test_acosh_flow_with_random_data(test_case): device = random_device() x = random_tensor(low=2, high=3).to(device) y = torch.acosh(x) return y 
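# NOTE: The tests below follow the automated-test pattern used throughout this file: the body runs under @autotest against both OneFlow and PyTorch, and the returned tensors (plus their gradients, unless auto_backward=False) are compared between the two frameworks.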
@flow.unittest.skip_unless_1n1d() class TestAtan2(flow.unittest.TestCase): @autotest(n=5) def test_flow_atan2_with_random_data(test_case): device = random_device() x1 = random_tensor(ndim=1, dim0=1).to(device) x2 = random_tensor(ndim=1, dim0=1).to(device) y = torch.atan2(x1, x2) return y @flow.unittest.skip_unless_1n1d() class TestMinimum(flow.unittest.TestCase): @autotest(n=5) def test_flow_elementwise_minimum_with_random_data(test_case): device = random_device() k1 = random(2, 6) k2 = random(2, 6) x = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) y = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) return torch.minimum(x, y) @autotest(n=5) def test_flow_broadcast_minimum_with_random_data(test_case): device = random_device() k1 = random(2, 6) k2 = random(2, 6) k3 = random(2, 6) x = random_tensor(ndim=3, dim0=k1, dim1=1, dim2=1).to(device) y = random_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3).to(device) return torch.minimum(x, y) @flow.unittest.skip_unless_1n1d() class TestMaximum(flow.unittest.TestCase): @autotest(n=5) def test_flow_elementwise_maximum_with_random_data(test_case): device = random_device() k1 = random(2, 6) k2 = random(2, 6) x = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) y = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) return torch.maximum(x, y) @autotest(n=5) def test_flow_broadcast_maximum_with_random_data(test_case): device = random_device() k1 = random(2, 6) k2 = random(2, 6) k3 = random(2, 6) x = random_tensor(ndim=3, dim0=k1, dim1=1, dim2=1).to(device) y = random_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3).to(device) return torch.maximum(x, y) @flow.unittest.skip_unless_1n1d() class TestFloorDiv(flow.unittest.TestCase): @autotest(auto_backward=False) def test_elementwise_floordiv_random_data(test_case): device = random_device() # The random values are restricted to positive numbers because of an error in pytorch 1.10.0. # Please remove the value range restriction after updating the pytorch version of ci to 1.13. x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3, low=0, high=10).to( device ) y = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3, low=1, high=10).to( device ) return torch.floor_divide(x, y) @autotest(auto_backward=False) def test_tensor_floordiv_scalar_random_data(test_case): device = random_device() # The random values are restricted to positive numbers because of an error in pytorch 1.10.0. # Please remove the value range restriction after updating the pytorch version of ci to 1.13. 
x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3, low=0, high=10).to( device ) y = random().to(int) return torch.floor_divide(x, y) @flow.unittest.skip_unless_1n1d() class TestFmod(flow.unittest.TestCase): # other.grad in torch.fmod(input, other) was not implemented before pytorch 1.11.0 grad_implemented = version.parse(torch_original.__version__) >= version.parse( "1.11.0" ) @autotest(auto_backward=grad_implemented) def test_elementwise_fmod_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device) y = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device) return torch.fmod(x, y) @autotest(n=5, auto_backward=grad_implemented) def test_flow_broadcast_fmod_with_random_data(test_case): device = random_device() k1 = random(2, 6) k2 = random(2, 6) k3 = random(2, 6) x = random_tensor(ndim=3, dim0=k1, dim1=1, dim2=1).to(device) y = random_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3).to(device) return torch.fmod(x, y) @autotest(auto_backward=grad_implemented) def test_tensor_fmod_scalar_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device) y = random().to(int) return torch.fmod(x, y) @flow.unittest.skip_unless_1n1d() class TestPow(flow.unittest.TestCase): @autotest(auto_backward=False) def test_elementwise_pow_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device) y = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device) return torch.pow(x, y) @autotest(n=5) def test_flow_broadcast_pow_with_random_data(test_case): device = random_device() k1 = random(2, 6) k2 = random(2, 6) k3 = random(2, 6) x = random_tensor(ndim=3, dim0=k1, dim1=1, dim2=1).to(device) y = random_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3).to(device) return torch.pow(x, y) @autotest(auto_backward=False) def test_tensor_pow_scalar_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device) y = random().to(int) return torch.pow(x, y) @flow.unittest.skip_unless_1n1d() class TestAbsModule(flow.unittest.TestCase): @autotest(n=5) def test_abs_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return torch.abs(x) @flow.unittest.skip_unless_1n1d() class TestCoshModule(flow.unittest.TestCase): @autotest(n=5) def test_cosh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return torch.cosh(x) @flow.unittest.skip_unless_1n1d() class TestLgammaModule(flow.unittest.TestCase): @autotest(n=5) def test_lgamma_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return torch.lgamma(x) @flow.unittest.skip_unless_1n1d() class TestLog2Module(flow.unittest.TestCase): @autotest(n=5) def test_log2_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return torch.log2(x) @flow.unittest.skip_unless_1n1d() class TestLog10Module(flow.unittest.TestCase): @autotest(n=5) def test_log10_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return torch.log10(x) @flow.unittest.skip_unless_1n1d() class TestDigammaModule(flow.unittest.TestCase): @autotest(n=5) def test_digamma_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return torch.digamma(x) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_matmul.py 
================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import torch as torch_original import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): @autotest(check_graph=True, rtol=1e-2, atol=1e-3, include_complex=True) def test_flow_matmul_with_random_data(test_case): device = random_device() k = random(1, 6) x = random_tensor(ndim=2, dim1=k).to(device) y = random_tensor(ndim=2, dim0=k).to(device) z = torch.matmul(x, y) return z @autotest(check_graph=True, rtol=1e-2, atol=1e-4) def test_flow_tensor_matmul_with_random_data_allow_tf32(test_case): flow.backends.cuda.matmul.allow_tf32 = True torch_original.backends.cuda.matmul.allow_tf32 = True device = random_device() k = random(1, 6) x = random_tensor(ndim=2, dim1=k).to(device) y = random_tensor(ndim=2, dim0=k).to(device) ret = x.matmul(y) flow.backends.cuda.matmul.allow_tf32 = False torch_original.backends.cuda.matmul.allow_tf32 = False return ret @autotest(check_graph=True, rtol=1e-2, atol=1e-4) def test_flow_tensor_matmul_with_random_data(test_case): device = random_device() k = random(1, 6) x = random_tensor(ndim=2, dim1=k).to(device) y = random_tensor(ndim=2, dim0=k).to(device) return x.matmul(y) @autotest(n=5, check_graph=False) def test_flow_tensor_matmul_with_random_int_data(test_case): x = np.random.randint(10, 21, size=5) y = np.random.randint(1, 14, size=(5, 4)) torch_x = torch.from_numpy(x).to(torch.int) torch_y = torch.from_numpy(y).to(torch.int) torch_output_numpy = torch_x.matmul(torch_y).numpy() flow_x = flow.tensor(x).to(flow.int) flow_y = flow.tensor(y).to(flow.int) flow_output_numpy = flow_x.matmul(flow_y).numpy() test_case.assertTrue( np.allclose(flow_output_numpy, torch_output_numpy, 1e-05, 1e-05) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5, check_graph=False) def test_flow_tensor_matmul_with_random_fp16_data(test_case): x = np.random.rand(3, 5) y = np.random.rand(5, 4) torch_x = torch.from_numpy(x).to(device=gpu_device(), dtype=torch.float16) torch_y = torch.from_numpy(y).to(device=gpu_device(), dtype=torch.float16) torch_output_numpy = torch_x.matmul(torch_y).cpu().numpy() flow_x = flow.tensor(x).to(device="cuda", dtype=flow.float16) flow_y = flow.tensor(y).to(device="cuda", dtype=flow.float16) flow_output_numpy = flow_x.matmul(flow_y).cpu().numpy() test_case.assertTrue( np.allclose(flow_output_numpy, torch_output_numpy, 1e-05, 1e-05) ) @autotest(n=5, check_graph=True, rtol=1e-2, atol=1e-3) def test_flow_tensor_broadcast_matmul_with_random_data(test_case): device = random_device() k = random(1, 6) x = random_tensor(ndim=4, dim3=k).to(device) y = random_tensor(ndim=2, dim0=k).to(device) return x.matmul(y) @autotest(n=10, check_graph=True, rtol=1e-2, atol=1e-3, include_complex=True) def 
test_flow_tensor_x_broadcast_y_matmul(test_case): device = random_device() k = random(1, 6) x = random_tensor(ndim=2, dim1=k).to(device) y = random_tensor(ndim=4, dim2=k).to(device) return x.matmul(y) @autotest(n=10, check_graph=True, rtol=1e-2, atol=1e-4, include_complex=True) def test_flow_tensor_broadcast_matmul_with_same_dims(test_case): device = random_device() k = random(1, 6) x = random_tensor(ndim=4, dim1=1, dim3=k).to(device) y = random_tensor(ndim=4, dim0=1, dim2=k).to(device) return x.matmul(y) @autotest(check_graph=True, rtol=1e-2, atol=1e-3, include_complex=True) def test_flow_mm_with_random_data(test_case): device = random_device() k = random(1, 6) x = random_tensor(ndim=2, dim1=k).to(device) y = random_tensor(ndim=2, dim0=k).to(device) z = torch.mm(x, y) return z @autotest(n=10, check_graph=True, include_complex=True) def test_flow_mv_with_random_data(test_case): device = random_device() k = random(1, 6) x = random_tensor(ndim=2, dim1=k).to(device) y = random_tensor(ndim=1, dim0=k).to(device) z = torch.mv(x, y) return z @profile(torch.mv) def profile_mv(test_case): torch.mv(torch.ones(32, 64), torch.ones(64)) @autotest(n=10, check_graph=True, rtol=1e-2, atol=1e-4, include_complex=True) def test_flow_vector_matrix_product_with_random_data(test_case): device = random_device() k = random(1, 6) x = random_tensor(ndim=1, dim0=k).to(device) y = random_tensor(ndim=2, dim0=k).to(device) z = torch.matmul(x, y) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_max.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest def _test_scalar_max(test_case, device): y = flow.max(flow.tensor(1.0, device=device), flow.tensor([2], device=device)) test_case.assertTrue(np.allclose(y.numpy(), [2], 1e-05, 1e-05)) y = flow.max(flow.tensor(1.0, device=device), flow.tensor(2, device=device)) test_case.assertTrue(np.allclose(y.numpy(), [2], 1e-05, 1e-05)) y = flow.max(flow.tensor([1.0], device=device), flow.tensor(2, device=device)) test_case.assertTrue(np.allclose(y.numpy(), [2], 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestMaxModule(flow.unittest.TestCase): def test_scalar_max(test_case): _test_scalar_max(test_case, "cpu") @autotest(n=5, check_allclose=False, check_graph=True) def test_max_reduce_random_dim(test_case): device = random_device() ndim = random().to(int).value() x = random_tensor(ndim=ndim, dim0=random(1, 8)) y = x.to(device) dim = random(-ndim, ndim).to(int).value() keep_dims = random_bool().value() y = torch.max(x, dim=dim, keepdim=keep_dims) # pytorch result is an instance of class 'torch.return_types.max', but oneflow is tuple test_case.assertTrue( np.allclose( y.oneflow[0].detach().cpu().numpy(), y.pytorch.values.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) test_case.assertTrue( np.allclose( y.oneflow[1].detach().cpu().numpy(), y.pytorch.indices.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) y.oneflow[0].sum().backward() y.pytorch.values.sum().backward() test_case.assertTrue( np.allclose( x.oneflow.grad.detach().cpu().numpy(), x.pytorch.grad.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) @autotest(n=5, check_graph=True) def test_max_reduce_all_dim(test_case): device = random_device() ndim = random().to(int).value() x = random_tensor(ndim=ndim, dim0=random(1, 8)).to(device) return torch.max(x) @autotest(n=5, check_graph=True) def test_max_elementwise(test_case): device = random_device() ndim = random().to(int).value() dims = [random(1, 8) for _ in range(ndim)] x = random_tensor(ndim, *dims).to(device) y = random_tensor(ndim, *dims).to(device) return torch.max(x, y) @autotest(n=5, check_graph=True, check_dtype=True) def test_max_elementwise_dtype_promotion(test_case): device = random_device() ndim = random().to(int).value() dims = [random(1, 8) for _ in range(ndim)] x = random_tensor(ndim, *dims, dtype=float).to(device) y = random_tensor(ndim, *dims, dtype=int).to(device) return torch.max(x, y) @autotest(n=5, check_graph=True, check_dtype=True) def test_max_broadcast_dtype_promotion(test_case): device = random_device() ndim = random().to(int).value() dims = [random(1, 8) for _ in range(ndim)] b_dims = [1 for _ in range(ndim)] x = random_tensor(ndim, *dims, dtype=float).to(device) y = random_tensor(ndim, *b_dims, dtype=int).to(device) return torch.max(x, y) @autotest(n=3, auto_backward=True, check_graph=True) def test_max_with_diff_size(test_case): x = flow.rand(1, 1, 4, requires_grad=True) y = flow.rand(1, 4, requires_grad=True) x = random_tensor(3, 1, 1, 4) y = random_tensor(2, 1, 4) return torch.max(x, y) @autotest(n=3, auto_backward=False) def test_max_return_type(test_case): x = random_tensor(3, 4) result = x.max(1) return result.values, result.indices if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_maxpool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from pkg_resources import packaging import numpy as np import torch as pytorch import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t def _test_maxpool2d_channel_last( test_case, device, shape, kernel_size, stride, padding, dilation, ceil_mode ): os.environ["ONEFLOW_ENABLE_NHWC"] = "1" arr = np.random.randn(*shape) x1 = flow.tensor(arr, dtype=flow.float64, device=device) m1 = flow.nn.MaxPool2d( kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, ) y1 = m1(x1) x2 = pytorch.tensor(arr.transpose(0, 3, 1, 2), dtype=pytorch.float64, device=device) m2 = pytorch.nn.MaxPool2d( kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, ) y2 = m2(x2).permute(0, 2, 3, 1) os.environ["ONEFLOW_ENABLE_NHWC"] = "0" # The test fails with pytorch 1.10 but success with pytorch1.13. It should be took back after updating to pytorch1.13. # test_case.assertTrue( # np.allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), 1e-4, 1e-4) # ) @flow.unittest.skip_unless_1n1d() class TestMaxPooling(flow.unittest.TestCase): @autotest(n=5, auto_backward=True, check_graph=True) def test_maxpool1d_with_random_data(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool1d( kernel_size=random(4, 6).to(_size_1_t), stride=random(1, 3).to(_size_1_t) | nothing(), padding=random(1, 3).to(_size_1_t) | nothing(), dilation=random(2, 4).to(_size_1_t) | nothing(), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim2=random(20, 22)).to(device) y = m(x) # NOTE(lixiang): When return_indices=False, maxpool1d will return the max indices along with the outputs, # y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here. 
if return_indices: return y[0] else: return y @autotest(n=5) def test_maxpool1d_with_2d_input_tensor(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool1d( kernel_size=random(4, 6).to(_size_1_t), stride=random(1, 3).to(_size_1_t) | nothing(), padding=random(1, 3).to(_size_1_t) | nothing(), dilation=random(2, 4).to(_size_1_t) | nothing(), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=2, dim1=random(20, 22)).to(device) y = m(x) if return_indices: return y[0] else: return y @autotest(n=10, auto_backward=True, check_graph=True) def test_maxpool2d_with_random_data(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool2d( kernel_size=random(4, 6).to(_size_2_t), stride=random(1, 3).to(_size_2_t) | nothing(), padding=random(1, 3).to(_size_2_t) | nothing(), dilation=random(2, 4).to(_size_2_t) | nothing(), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(device) y = m(x) # NOTE(lixiang): When return_indices=False, maxpool2d will return the max indices along with the outputs, # y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here. if return_indices: return y[0] else: return y @autotest(n=5) def test_maxpool2d_with_3d_input_tensor(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool2d( kernel_size=random(4, 6).to(_size_2_t), stride=random(1, 3).to(_size_2_t) | nothing(), padding=random(1, 3).to(_size_2_t) | nothing(), dilation=random(2, 4).to(_size_2_t) | nothing(), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim1=random(20, 22), dim2=random(20, 22)).to(device) y = m(x) if return_indices: return y[0] else: return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5, auto_backward=False) def test_maxpool2d_with_half_data(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool2d( kernel_size=random(4, 6).to(_size_2_t), stride=random(1, 3).to(_size_2_t) | nothing(), padding=random(1, 3).to(_size_2_t) | nothing(), dilation=random(2, 4).to(_size_2_t) | nothing(), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) device = gpu_device() m.to(device) x = ( random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)) .to(device) .to(torch.float16) ) y = m(x) if return_indices: return y[0] else: return y @autotest(n=5, auto_backward=True, check_graph=True) def test_maxpool3d_with_random_data(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool3d( kernel_size=random(4, 6).to(_size_3_t), stride=random(1, 3).to(_size_3_t) | nothing(), padding=random(1, 3).to(_size_3_t) | nothing(), dilation=random(2, 4).to(_size_3_t) | nothing(), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22) ).to(device) y = m(x) # NOTE(lixiang): When return_indices=False, maxpool3d will return the max indices along with the outputs, # y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here. 
if return_indices: return y[0] else: return y @autotest(n=5) def test_maxpool3d_with_4d_input_tensor(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool3d( kernel_size=random(4, 6).to(_size_3_t), stride=random(1, 3).to(_size_3_t) | nothing(), padding=random(1, 3).to(_size_3_t) | nothing(), dilation=random(2, 4).to(_size_3_t) | nothing(), ceil_mode=random(), return_indices=return_indices, ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=4, dim1=random(20, 22), dim2=random(20, 22), dim3=random(20, 22) ).to(device) y = m(x) if return_indices: return y[0] else: return y @unittest.skipIf( packaging.version.parse(pytorch.__version__) == packaging.version.parse("1.10.0"), "skip when pytorch version == 1.10.0", ) # NOTE:pytorch maxpool2d nhwc has bug in version of 1.10.0, so skip it in CI. # detail:https://github.com/pytorch/pytorch/pull/76597 def test_maxpool2d_channel_last(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_maxpool2d_channel_last] arg_dict["device"] = ["cuda"] # CPU pool is very slow, so don't run it with CUDA if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["shape"] = [(3, 14, 27, 3), (5, 9, 14, 10), (2, 224, 224, 3)] arg_dict["kernel_size"] = [3, (2, 3), (3, 4)] arg_dict["stride"] = [1, (1, 2), 2] arg_dict["padding"] = [0, (0, 1)] arg_dict["dilation"] = [1, (1, 2), 2] arg_dict["ceil_mode"] = [True, False] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @flow.unittest.skip_unless_1n1d() class TestMaxPoolingFunctional(flow.unittest.TestCase): @autotest(n=5, auto_backward=True, check_graph=True) def test_maxpool1d_with_random_data(test_case): return_indices = random().to(bool).value() device = random_device() x = random_tensor(ndim=3, dim2=random(20, 22)).to(device) y = torch.nn.functional.max_pool1d( x, kernel_size=random(4, 6).to(int), stride=random(1, 3).to(int) | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(2, 4).to(int) | nothing(), ceil_mode=random().to(bool), return_indices=return_indices, ) # NOTE(lixiang): When return_indices=False, maxpool1d will return the max indices along with the outputs, # y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here. if return_indices: return y[0] else: return y @autotest(n=5, auto_backward=True, check_graph=True) def test_maxpool2d_with_random_data(test_case): return_indices = random().to(bool).value() device = random_device() x = random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(device) y = torch.nn.functional.max_pool2d( x, kernel_size=random(4, 6).to(int), stride=random(1, 3).to(int) | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(2, 4).to(int) | nothing(), ceil_mode=random().to(bool), return_indices=return_indices, ) # NOTE(lixiang): When return_indices=False, maxpool2d will return the max indices along with the outputs, # y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here. 
if return_indices: return y[0] else: return y @autotest(auto_backward=True, check_graph=True) def test_maxpool3d_with_random_data(test_case): return_indices = random().to(bool).value() device = random_device() x = random_tensor( ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22) ).to(device) y = torch.nn.functional.max_pool3d( x, kernel_size=random(4, 6).to(int), stride=random(1, 3).to(int) | nothing(), padding=random(1, 3).to(int) | nothing(), dilation=random(2, 4).to(int) | nothing(), ceil_mode=random().to(bool), return_indices=return_indices, ) # NOTE(lixiang): When return_indices=False, maxpool3d will return the max indices along with the outputs, # y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here. if return_indices: return y[0] else: return y @profile(torch.nn.functional.max_pool2d) def profile_maxpool2d(test_case): torch.nn.functional.max_pool2d( torch.ones(1, 128, 28, 28), kernel_size=3, padding=1 ) torch.nn.functional.max_pool2d( torch.ones(1, 128, 28, 28), kernel_size=3, stride=2, padding=1 ) torch.nn.functional.max_pool2d( torch.ones(16, 128, 28, 28), kernel_size=3, padding=1 ) torch.nn.functional.max_pool2d( torch.ones(16, 128, 28, 28), kernel_size=3, stride=2, padding=1 ) torch.nn.functional.max_pool2d( torch.ones(16, 128, 28, 28), kernel_size=3, stride=2, padding=1, ceil_mode=True, ) # torch.nn.functional.max_pool2d(torch.ones(16, 128, 28, 28), kernel_size=3, dilation=2, padding=2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_maxunpool.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import random as random_util import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t # y = pool(x), z = unpool(y, indices), pool_input_shape is x.shape, pool_output_shape is y.shape. # When `output_size` in unpool() is empty, the op will calculate the output size according to # kernel_size, stride and padding. But when index in indices is outside the range required # by output_size calculated by unpool op, the value of result and related grad will be unknown. # To avoid the problem, this function calculate the output_size which will not cause unknown problems. 
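# Hedged worked example (illustration only; the concrete numbers are hypothetical): for a
# 1-D pool with kernel_size=4, stride=2, padding=1 on an input of length 20, the pooled
# length is floor((20 + 2*1 - 4) / 2) + 1 = 10, and the unpool output size computed by the
# helper below is max(20, (10 - 1)*2 - 2*1 + 4) = 20, i.e. the original input extent is
# recovered, so every pooling index stays inside the unpooled output and no "unknown value"
# problem can occur.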
def _get_valid_output_size( pool_input_shape, pool_output_shape, kernel_size, stride, padding ): def convert_data(data, i, dst_data=None): if not isinstance(data, (list, int)): return dst_data if isinstance(data, list): return data[i] return data _, _, *pool_input_hwd_shape = pool_input_shape.pytorch batch_size, num_channels, *pool_out_hwd_shape = pool_output_shape.pytorch unpool_output_shape = [batch_size, num_channels] for i, (pool_input_size, pool_output_size) in enumerate( zip(pool_input_hwd_shape, pool_out_hwd_shape) ): kernel_size_value = convert_data(kernel_size.value(), i) stride_value = convert_data(stride.value(), i, kernel_size_value) padding_value = convert_data(padding.value(), i, 0) unpool_output_size = max( pool_input_size, (pool_output_size - 1) * stride_value - 2 * padding_value + kernel_size_value, ) unpool_output_shape.append(unpool_output_size) return torch.Size(unpool_output_shape) def _test_module_unpoolnd(test_case, n): device = random_device() if n == 1: _size_n_t = _size_1_t MaxPoolNd = torch.nn.MaxPool1d MaxUnpoolNd = torch.nn.MaxUnpool1d x = random_tensor(ndim=3, dim2=random(20, 31), requires_grad=False).to(device) elif n == 2: _size_n_t = _size_2_t MaxPoolNd = torch.nn.MaxPool2d MaxUnpoolNd = torch.nn.MaxUnpool2d x = random_tensor( ndim=4, dim2=random(20, 31), dim3=random(20, 31), requires_grad=False ).to(device) elif n == 3: _size_n_t = _size_3_t MaxPoolNd = torch.nn.MaxPool3d MaxUnpoolNd = torch.nn.MaxUnpool3d x = random_tensor( ndim=5, dim2=random(20, 31), dim3=random(20, 31), dim4=random(20, 31), requires_grad=False, ).to(device) kernel_size = random(4, 6).to(_size_n_t) stride = random(1, 3).to(_size_n_t) | nothing() padding = random(1, 3).to(_size_n_t) | nothing() m = MaxPoolNd( kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True, ) m.train(random()) m.to(device) y = m(x) pooling_results_dtype = random_util.choice( [torch.int, torch.long, torch.float, torch.double] ) indices_dtype = random_util.choice([torch.int, torch.long]) pooling_results = y[0].to(pooling_results_dtype) indices = y[1].to(indices_dtype) pooling_results.requires_grad_() output_size = _get_valid_output_size( x.shape, pooling_results.shape, kernel_size, stride, padding ) unpool_module = MaxUnpoolNd( kernel_size=kernel_size, stride=stride, padding=padding, ) result = unpool_module(pooling_results, indices, output_size=output_size) return result def _test_functional_unpoolnd(test_case, n): device = random_device() if n == 1: _size_n_t = _size_1_t MaxPoolNd = torch.nn.MaxPool1d max_unpool_nd = torch.nn.functional.max_unpool1d x = random_tensor(ndim=3, dim2=random(20, 31), requires_grad=False).to(device) elif n == 2: _size_n_t = _size_2_t MaxPoolNd = torch.nn.MaxPool2d max_unpool_nd = torch.nn.functional.max_unpool2d x = random_tensor( ndim=4, dim2=random(20, 31), dim3=random(20, 31), requires_grad=False ).to(device) elif n == 3: _size_n_t = _size_3_t MaxPoolNd = torch.nn.MaxPool3d max_unpool_nd = torch.nn.functional.max_unpool3d x = random_tensor( ndim=5, dim2=random(20, 31), dim3=random(20, 31), dim4=random(20, 31), requires_grad=False, ).to(device) kernel_size = random(4, 6).to(_size_n_t) stride = random(1, 3).to(_size_n_t) | nothing() padding = random(1, 3).to(_size_n_t) | nothing() m = MaxPoolNd( kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True, ) m.train(random()) m.to(device) y = m(x) pooling_results_dtype = random_util.choice( [torch.int, torch.long, torch.float, torch.double] ) indices_dtype = random_util.choice([torch.int, 
torch.long]) pooling_results = y[0].to(pooling_results_dtype) indices = y[1].to(indices_dtype) pooling_results.requires_grad_() output_size = _get_valid_output_size( x.shape, pooling_results.shape, kernel_size, stride, padding ) return max_unpool_nd( pooling_results, indices, kernel_size=kernel_size, stride=stride, padding=padding, output_size=output_size, ) @flow.unittest.skip_unless_1n1d() class TestMaxUnpooling(flow.unittest.TestCase): @autotest(n=3, check_graph=False) def test_max_unpool1d_with_random_data(test_case): return _test_module_unpoolnd(test_case, 1) @autotest(n=3, check_graph=False) def test_functional_max_unpool1d_with_random_data(test_case): return _test_functional_unpoolnd(test_case, 1) @autotest(n=3, check_graph=False) def test_max_unpool2d_with_random_data(test_case): return _test_module_unpoolnd(test_case, 2) @autotest(n=3, check_graph=False) def test_functional_max_unpool2d_with_random_data(test_case): return _test_functional_unpoolnd(test_case, 2) @autotest(n=3, check_graph=False) def test_max_unpool3d_with_random_data(test_case): return _test_module_unpoolnd(test_case, 3) @autotest(n=3, check_graph=False) def test_functional_max_unpool3d_with_random_data(test_case): return _test_functional_unpoolnd(test_case, 3) @profile(torch.nn.functional.max_unpool1d) def profile_max_unpool1d(test_case): max_pool_results = torch.randn(1, 32, 64) max_pool_indices = torch.arange(64).expand(1, 32, 64) torch.nn.functional.max_unpool1d(max_pool_results, max_pool_indices, 2) max_pool_results = torch.randn(32, 32, 64) max_pool_indices = torch.arange(64).expand(32, 32, 64) torch.nn.functional.max_unpool1d(max_pool_results, max_pool_indices, 2) @profile(torch.nn.functional.max_unpool2d) def profile_max_unpool2d(test_case): max_pool_results = torch.randn(1, 16, 32, 32) max_pool_indices = torch.arange(32).expand(1, 16, 32, 32) torch.nn.functional.max_unpool2d(max_pool_results, max_pool_indices, 2) max_pool_results = torch.randn(32, 16, 32, 32) max_pool_indices = torch.arange(32).expand(32, 16, 32, 32) torch.nn.functional.max_unpool2d(max_pool_results, max_pool_indices, 2) @profile(torch.nn.functional.max_unpool3d) def profile_max_unpool3d(test_case): max_pool_results = torch.randn(1, 4, 32, 32, 32) max_pool_indices = torch.arange(32).expand(1, 4, 32, 32, 32) torch.nn.functional.max_unpool3d(max_pool_results, max_pool_indices, 2) max_pool_results = torch.randn(16, 4, 32, 32, 32) max_pool_indices = torch.arange(32).expand(16, 4, 32, 32, 32) torch.nn.functional.max_unpool3d(max_pool_results, max_pool_indices, 2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_mean.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_mean(test_case, shape, device): input = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.mean(input, dim=1) np_out = np.mean(input.numpy(), axis=1) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) input = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.mean(input, dim=0) np_out = np.mean(input.numpy(), axis=0) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) def _test_mean_negative_dim(test_case, shape, device): if len(shape) < 4: shape = (2, 3, 4, 5) input = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.mean(input, dim=(-2, -1, -3)) np_out = np.mean(input.numpy(), axis=(-2, -1, -3)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) def _test_mean_backward(test_case, shape, device): np_arr = np.random.randn(*shape) x = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) y = flow.mean(x, dim=1) z = y.sum() z.backward() np_grad = np.zeros(shape=np_arr.shape) np_grad[:] = 1 / x.size(1) test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestMean(flow.unittest.TestCase): def test_mean(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_mean, _test_mean_negative_dim, _test_mean_backward, ] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(check_graph=True) def test_mean_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float).to(device) return torch.mean(x, dim) @autotest(n=5) def test_mean_with_scalar_data(test_case): device = random_device() x = random_tensor(ndim=4, dtype=float).to(device).mean() y = x.mean(-1) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=5, atol=1e-3) def test_mean_with_float16_data(test_case): device = gpu_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float).to(device=device, dtype=torch.float16) return torch.mean(x, dim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_median.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestMedianModule(flow.unittest.TestCase): @autotest(n=5) def test_median_reduce_all_dim(test_case): device = random_device() ndim = random(1, 4).to(int).value() x = random_tensor(ndim=ndim, dim0=random(1, 4)).to(device) return torch.median(x) @autotest(n=5) def test_median_reduce_one_dim(test_case): device = random_device() ndim = random(low=2).to(int).value() reduce_dim = random(high=ndim).to(int).value() x = random_tensor(ndim).to(device) return torch.median(x, reduce_dim) @autotest(n=5) def test_median_reduce_one_dim_keepdim(test_case): device = random_device() ndim = random(low=2).to(int).value() reduce_dim = random(high=ndim).to(int).value() x = random_tensor(ndim).to(device) return torch.median(x, reduce_dim, True) @autotest(n=5, auto_backward=False, check_graph=True) def test_median_0size(test_case): device = random_device() x = random_tensor(ndim=3, dim1=0, requires_grad=False).to(device) return torch.median(x) @autotest(n=5, auto_backward=False, check_graph=True) def test_median_reduce_one_dim_0size(test_case): device = random_device() x = random_tensor(ndim=3, dim1=0, requires_grad=False).to(device) return torch.median(x, 0) @autotest(n=5, auto_backward=False) def test_median_return_type(test_case): x = random_tensor(3, 4) result = x.median(1) return result.values, result.indices if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_meshgrid.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_meshgrid_forawd(test_case, device, indexing): input1 = flow.tensor( np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device) ) (np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing=indexing) (of_x, of_y) = flow.meshgrid(input1, input2, indexing=indexing) test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001)) def _test_meshgrid_forawd_scalar(test_case, device, indexing): input1 = flow.tensor(np.array(1.0), dtype=flow.float32, device=flow.device(device)) input2 = flow.tensor(np.array(2.0), dtype=flow.float32, device=flow.device(device)) (np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing=indexing) (of_x, of_y) = flow.meshgrid(input1, input2, indexing=indexing) test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001)) def _test_meshgrid_forawd_3tensor(test_case, device, indexing): input1 = flow.tensor( np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device) ) input3 = flow.tensor( np.array([7, 8, 9]), dtype=flow.float32, device=flow.device(device) ) (np_x, np_y, np_z) = np.meshgrid( input1.numpy(), input2.numpy(), input3.numpy(), indexing=indexing ) (of_x, of_y, of_z) = flow.meshgrid(input1, input2, input3, indexing=indexing) test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestMeshGridModule(flow.unittest.TestCase): def test_meshgrid(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_meshgrid_forawd, _test_meshgrid_forawd_scalar, _test_meshgrid_forawd_3tensor, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["indexing"] = ["ij", "xy"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(auto_backward=False, check_graph=True) @unittest.skip("pytorch 1.9.0 exist not indexing") def test_meshgrid_with_random_data(test_case): device = random_device() x = random_tensor(ndim=1, dim0=3, requires_grad=False).to(device) y = random_tensor(ndim=1, dim0=3, requires_grad=False).to(device) res = torch.meshgrid(x, y) return res[0], res[1] @autotest(auto_backward=False) def test_meshgrid_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) res = torch.meshgrid(x, y) @autotest(auto_backward=True) @unittest.skip("pytorch 1.9.0 exist not indexing") def test_meshgrid_with_random_data_xy(test_case): device = random_device() x = random_tensor(ndim=1, dim0=random(1, 6)).to(device) y = random_tensor(ndim=1, dim0=random(1, 6)).to(device) res = torch.meshgrid(x, y, indexing="xy") return torch.cat((res[0], res[1]), 0) @autotest(auto_backward=True) @unittest.skip("pytorch 1.9.0 exist not indexing") def test_meshgrid_with_random_data_size(test_case): device = random_device() x = random_tensor(ndim=1, dim0=random(1, 6)).to(device) res = torch.meshgrid(x, indexing="xy") return res[0] @autotest(n=3) def test_meshgrid_tuple_list_with_random_data(test_case): device = random_device() x = random_tensor(ndim=1, dim0=random(1, 6)).to(device) y = random_tensor(ndim=1, dim0=random(1, 6)).to(device) res1 = torch.meshgrid((x, y)) res2 = torch.meshgrid([x, y]) return 
torch.cat((res1[0], res1[1], res2[0], res2[1]), 0) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_min.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestMinModule(flow.unittest.TestCase): @autotest(n=5, check_allclose=False, check_graph=True) def test_min_reduce_random_dim(test_case): device = random_device() ndim = random().to(int).value() x = random_tensor(ndim=ndim, dim0=random(1, 8)) y = x.to(device) dim = random(-ndim, ndim).to(int).value() keep_dims = random_bool().value() y = torch.min(x, dim=dim, keepdim=keep_dims) # pytorch result is an instance of class 'torch.return_types.min', but oneflow is tuple test_case.assertTrue( np.allclose( y.oneflow[0].detach().cpu().numpy(), y.pytorch.values.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) test_case.assertTrue( np.allclose( y.oneflow[1].detach().cpu().numpy(), y.pytorch.indices.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) y.oneflow[0].sum().backward() y.pytorch.values.sum().backward() test_case.assertTrue( np.allclose( x.oneflow.grad.detach().cpu().numpy(), x.pytorch.grad.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) @autotest(n=5, check_graph=True) def test_min_reduce_all_dim(test_case): device = random_device() ndim = random().to(int).value() x = random_tensor(ndim=ndim, dim0=random(1, 8)).to(device) return torch.min(x) @autotest(n=5, check_graph=True) def test_min_elementwise(test_case): device = random_device() ndim = random().to(int).value() dims = [random(1, 8) for _ in range(ndim)] x = random_tensor(ndim, *dims).to(device) y = random_tensor(ndim, *dims).to(device) return torch.min(x, y) @autotest(n=5, check_graph=True, check_dtype=True) def test_min_elementwise_dtype_promotion(test_case): device = random_device() ndim = random().to(int).value() dims = [random(1, 8) for _ in range(ndim)] x = random_tensor(ndim, *dims, dtype=float).to(device) y = random_tensor(ndim, *dims, dtype=int).to(device) return torch.min(x, y) @autotest(n=5, check_graph=True, check_dtype=True) def test_min_broadcast_dtype_promotion(test_case): device = random_device() ndim = random().to(int).value() dims = [random(1, 8) for _ in range(ndim)] b_dims = [1 for _ in range(ndim)] x = random_tensor(ndim, *dims, dtype=float).to(device) y = random_tensor(ndim, *b_dims, dtype=int).to(device) return torch.min(x, y) @autotest(n=3, auto_backward=False) def test_min_return_type(test_case): x = random_tensor(3, 4) result = x.min(1) return result.values, result.indices if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_min_max_observer.py ================================================ """ Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import math import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.nn.modules import min_max_observer from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.test_util import ( GenArgList, type_name_to_flow_type, type_name_to_np_type, ) import oneflow as flow import oneflow.unittest def gen_quant_scale_for_min_max_symmetric(weight, quantization_bit): weight_max = np.max(np.abs(weight)) denominator = 2.0 ** (quantization_bit - 1) - 1 return (weight_max / denominator, 0) def gen_quant_scale_for_min_max_affine(weight, quantization_bit): weight_max = np.max(weight) weight_min = np.min(weight) denominator = 2.0 ** quantization_bit - 1 scale = (weight_max - weight_min) / denominator zero_point = -np.round(weight_min / scale) return (scale, zero_point) def gen_quant_scale_for_min_max_cambricon(weight, quantization_bit): weight_max = np.max(np.abs(weight)) scale = math.floor(math.log2(weight_max)) - (quantization_bit - 2) return (scale, 0) def product(tu): return np.prod(tu).astype(np.int32).item() def _check_min_max_observer( test_case, weight, scale_of, zero_point_of, quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ): if per_layer_quantization or quantization_formula == "cambricon": outer_num = 1 inner_num = product(weight.shape[0:]) else: outer_num = weight.shape[0] inner_num = product(weight.shape[1:]) scale_np = np.zeros((outer_num,)) zero_point_np = np.zeros((outer_num,)) weight_flatten = weight.flatten() if quantization_formula == "google": if quantization_scheme == "symmetric": for c in range(outer_num): (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_symmetric( weight_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit, ) else: for c in range(outer_num): (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_affine( weight_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit, ) else: (scale_np[0], zero_point_np[0]) = gen_quant_scale_for_min_max_cambricon( weight_flatten, quantization_bit ) test_case.assertTrue(np.allclose(scale_of, scale_np, rtol=0.001)) rmse = np.sqrt(np.mean((zero_point_of - zero_point_np) ** 2)) assert rmse <= 1.0, "min_max_observer op zero_point calculate has bug!" 
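# Hedged worked example (illustration only; the concrete numbers are hypothetical): with
# quantization_bit=8 and a per-layer weight whose values span [-0.5, 0.3], the "symmetric"
# scheme checked above gives scale = 0.5 / (2**7 - 1) ~= 0.00394 with zero_point = 0, while
# the "affine" scheme gives scale = (0.3 - (-0.5)) / (2**8 - 1) ~= 0.00314 and
# zero_point = -round(-0.5 / 0.00314) ~= 159.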
def _run_test_min_max_observer( test_case, device_type, weight_shape, quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ): weight = (np.random.random(weight_shape) - 0.5).astype(np.float32) tensor_weight = flow.tensor( weight, device=flow.device(device_type), dtype=flow.float32 ) min_max_observer = flow.nn.MinMaxObserver( quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization, ) scale, zero_point = min_max_observer(tensor_weight) _check_min_max_observer( test_case, weight, scale.numpy(), zero_point.numpy(), quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ) class TestMinMaxObserver(flow.unittest.TestCase): def test_min_max_observer(test_case): arg_dict = OrderedDict() arg_dict["test_case"] = [test_case] arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["weight_shape"] = [(9, 40, 20, 10)] arg_dict["quantization_bit"] = [8, 2] arg_dict["quantization_scheme"] = ["symmetric", "affine"] arg_dict["quantization_formula"] = ["google"] arg_dict["per_layer_quantization"] = [True, False] for arg in GenArgList(arg_dict): if arg[-2] == "cambricon" and arg[-1] == False: continue _run_test_min_max_observer(*arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_mock.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestMockModule(flow.unittest.TestCase): def test_mock_device(test_case): device = flow.device("mock") test_case.assertEqual(device.type, "mock") def test_mock_placement(test_case): placement = flow.placement("mock", [0]) test_case.assertEqual(placement.type, "mock") if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_mode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestModeModule(flow.unittest.TestCase): @autotest(n=5) def test_mode_reduce_one_dim(test_case): device = cpu_device() ndim = random(low=2).to(int).value() reduce_dim = random(high=ndim).to(int).value() x = random_tensor(ndim).to(device) return torch.mode(x, reduce_dim) @autotest(n=5) def test_mode_reduce_one_dim_keepdim(test_case): device = cpu_device() ndim = random(low=2).to(int).value() reduce_dim = random(high=ndim).to(int).value() x = random_tensor(ndim).to(device) return torch.mode(x, reduce_dim, True) @autotest(n=5, auto_backward=False, check_graph=False) def test_mode_0size(test_case): device = cpu_device() x = random_tensor(ndim=3, dim1=0, requires_grad=False).to(device) return torch.mode(x) @autotest(n=5, auto_backward=False, check_graph=False) def test_mode_reduce_one_dim_0size(test_case): device = cpu_device() x = random_tensor(ndim=3, dim1=0, requires_grad=False).to(device) return torch.mode(x, 0) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_module.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import math import warnings import tempfile import unittest from itertools import repeat from typing import Tuple, Union, List from collections import OrderedDict import numpy as np import torch import oneflow as flow import oneflow.nn as nn import oneflow.unittest from oneflow._oneflow_internal import TensorTuple from oneflow.test_utils.test_util import GenArgList def np_relu(np_arr): return np.where(np_arr > 0, np_arr, 0) def _test_hooks(test_case, backward_register_fn): module = nn.Sigmoid() input = flow.ones(5, 5, requires_grad=True) counter = {"forwards": 0, "backwards": 0} def fw_hook(inc, h_module, input, output): test_case.assertTrue(isinstance(input, tuple)) test_case.assertTrue(isinstance(output, flow.Tensor)) test_case.assertTrue(h_module is module) test_case.assertTrue(flow.equal(input[0], flow.ones(5, 5))) test_case.assertTrue( flow.equal(output, flow.empty(5, 5).fill_(1 / (1 + 1 / math.e))) ) counter["forwards"] += inc def bw_hook(inc, h_module, grad_input, grad_output): test_case.assertTrue(isinstance(grad_input, TensorTuple)) test_case.assertTrue(isinstance(grad_output, TensorTuple)) test_case.assertTrue(h_module is module) test_case.assertTrue(flow.equal(grad_output[0], flow.ones(5, 5) * 2)) counter["backwards"] += inc test_fwd = module.register_forward_hook(lambda *args: fw_hook(1, *args)) module(input) module(input) test_case.assertEqual(counter["forwards"], 2) test_case.assertEqual(counter["backwards"], 0) test_bwd = getattr(module, backward_register_fn)(lambda *args: bw_hook(1, *args)) output = module(input) test_case.assertEqual(counter["forwards"], 3) test_case.assertEqual(counter["backwards"], 0) output.backward(flow.ones(5, 5) * 2, retain_graph=True) test_case.assertEqual(counter["forwards"], 3) test_case.assertEqual(counter["backwards"], 1) output.backward(flow.ones(5, 5) * 2, retain_graph=True) test_case.assertEqual(counter["forwards"], 3) test_case.assertEqual(counter["backwards"], 2) test2_fwd = module.register_forward_hook(lambda *args: fw_hook(2, *args)) output = module(input) test_case.assertEqual(counter["forwards"], 6) test_case.assertEqual(counter["backwards"], 2) test2_bwd = getattr(module, backward_register_fn)(lambda *args: bw_hook(2, *args)) module(input).backward(flow.ones(5, 5) * 2) test_case.assertEqual(counter["forwards"], 9) test_case.assertEqual(counter["backwards"], 5) test2_bwd.remove() module(input).backward(flow.ones(5, 5) * 2) test_case.assertEqual(counter["forwards"], 12) test_case.assertEqual(counter["backwards"], 6) test2_fwd.remove() module(input).backward(flow.ones(5, 5) * 2) test_case.assertEqual(counter["forwards"], 13) test_case.assertEqual(counter["backwards"], 7) test_fwd.remove() test_bwd.remove() def _test_module_forward_preforward_hook_removable(test_case): module = nn.Sigmoid() def removable_hook(m, input): nonlocal handle handle.remove() return input def removable_hook_2(m, input): nonlocal handle_2 handle_2.remove() return input handle = module.register_forward_pre_hook(removable_hook) handle_2 = module.register_forward_pre_hook(removable_hook_2) # make sure hook register is successful test_case.assertEqual(len(handle.hooks_dict_ref()), 2) test_case.assertEqual(len(handle_2.hooks_dict_ref()), 2) input = flow.randn(2, 2) output = module(input) test_case.assertTrue(flow.equal(flow.sigmoid(input), output)) # make sure hook removal is successful test_case.assertFalse(handle.id in handle.hooks_dict_ref()) test_case.assertFalse(handle_2.id in handle.hooks_dict_ref()) 
test_case.assertEqual(len(handle.hooks_dict_ref()), 0) test_case.assertEqual(len(handle_2.hooks_dict_ref()), 0) def _test_module_forward_forward_hook_removable(test_case): module = nn.Sigmoid() def removable_hook(m, input, output): nonlocal handle handle.remove() return output def removable_hook_2(m, input, output): nonlocal handle_2 handle_2.remove() return output handle = module.register_forward_hook(removable_hook) handle_2 = module.register_forward_hook(removable_hook_2) # make sure hook register is successful test_case.assertEqual(len(handle.hooks_dict_ref()), 2) test_case.assertEqual(len(handle_2.hooks_dict_ref()), 2) input = flow.randn(2, 2) output = module(input) test_case.assertTrue(flow.equal(flow.sigmoid(input), output)) # make sure hook removal is successful test_case.assertFalse(handle.id in handle.hooks_dict_ref()) test_case.assertFalse(handle_2.id in handle.hooks_dict_ref()) test_case.assertEqual(len(handle.hooks_dict_ref()), 0) test_case.assertEqual(len(handle_2.hooks_dict_ref()), 0) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestModule(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_nested_module(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.relu = flow.nn.ReLU() def forward(self, x): return self.relu(x) m = CustomModule() x = flow.Tensor(2, 3) flow.nn.init.uniform_(x, a=-1.0, b=1.0) y = m(x) test_case.assertTrue(np.array_equal(np_relu(x.numpy()), y.numpy())) @flow.unittest.skip_unless_1n1d() def test_relu(test_case): relu = flow.nn.ReLU() x = flow.Tensor(2, 3) flow.nn.init.uniform_(x, a=-1.0, b=1.0) y = relu(x) test_case.assertTrue(np.array_equal(np_relu(x.numpy()), y.numpy())) @flow.unittest.skip_unless_1n1d() def test_load_state_dict(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.w = flow.nn.Parameter(flow.Tensor(2, 3)) def forward(self, x): return self.w m = CustomModule() ones = np.ones((2, 3), dtype=np.float32) m.load_state_dict({"w": ones}) x = flow.Tensor(2, 3) y = m(x).numpy() test_case.assertTrue(np.array_equal(y, ones)) @flow.unittest.skip_unless_1n1d() def test_state_dict(test_case): class CustomModule(flow.nn.Module): def __init__(self, param1, param2): super().__init__() self.param1 = param1 self.param2 = param2 tensor0 = flow.nn.Parameter(flow.Tensor(2, 3)) tensor1 = flow.nn.Parameter(flow.Tensor(2, 3)) sub_module = CustomModule(tensor0, tensor1) m = CustomModule(tensor1, sub_module) state_dict = m.state_dict() test_case.assertEqual( state_dict, {"param2.param1": tensor0, "param2.param2": tensor1, "param1": tensor1}, ) @flow.unittest.skip_unless_1n1d() def test_parameter(test_case): shape = (3, 4) t = flow.Tensor(*shape) p = flow.nn.Parameter(t) test_case.assertEqual(type(p), flow.nn.Parameter) test_case.assertEqual(p.shape, shape) @flow.unittest.skip_unless_1n1d() def test_module_forward(test_case): class CustomModule(flow.nn.Module): def __init__(self, w): super().__init__() self.w = w def forward(self, x): return x + self.w m = CustomModule(5) test_case.assertEqual(m(1), 6) m = CustomModule(4) test_case.assertEqual(m(3), 7) @flow.unittest.skip_unless_1n1d() def test_train_eval(test_case): m = flow.nn.Module() test_case.assertEqual(m.training, True) m.train() test_case.assertEqual(m.training, True) m.eval() test_case.assertEqual(m.training, False) @flow.unittest.skip_unless_1n1d() def test_module_setattr(test_case): class CustomModule(flow.nn.Module): def __init__(self, param1, param2): super().__init__() 
self.param1 = param1 self.param2 = param2 param0 = flow.nn.Parameter(flow.Tensor(2, 3)) param1 = flow.nn.Parameter(flow.Tensor(2, 3)) param2 = CustomModule(param0, param1) m = CustomModule(param1, param2) params = list(m.parameters()) test_case.assertEqual(len(params), 2) test_case.assertTrue( np.allclose(params[0].numpy(), param1.numpy(), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(params[1].numpy(), param0.numpy(), atol=1e-4, rtol=1e-4) ) children = list(m.children()) test_case.assertEqual(len(children), 1) child = children[0] test_case.assertEqual(child, param2) child_params = list(child.parameters()) test_case.assertEqual(len(child_params), 2) test_case.assertTrue(np.allclose(child_params[0].numpy(), param0.numpy())) test_case.assertTrue(np.allclose(child_params[1].numpy(), param1.numpy())) @flow.unittest.skip_unless_1n1d() def test_module_apply(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.modules = flow.nn.Module() global module_num module_num = 0 def get_module_num(m): global module_num module_num += 1 net = CustomModule() net.apply(get_module_num) test_case.assertEqual(module_num, 2) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_module_cpu_cuda(test_case): class CustomModule(flow.nn.Module): def __init__(self, param1, param2): super().__init__() self.param1 = param1 self.param2 = param2 tensor0 = flow.nn.Parameter(flow.Tensor(2, 3, device=flow.device("cpu"))) tensor1 = flow.nn.Parameter(flow.Tensor(2, 3, device=flow.device("cpu"))) sub_module = CustomModule(tensor0, tensor1) m = CustomModule(tensor1, sub_module) m.cuda() state_dict = m.state_dict() test_case.assertEqual(state_dict["param2.param1"].device, flow.device("cuda:0")) test_case.assertEqual(state_dict["param2.param2"].device, flow.device("cuda:0")) m.cpu() state_dict = m.state_dict() test_case.assertEqual(state_dict["param2.param1"].device, flow.device("cpu")) test_case.assertEqual(state_dict["param2.param2"].device, flow.device("cpu")) @flow.unittest.skip_unless_1n1d() def test_module_float_double(test_case): class CustomModule(flow.nn.Module): def __init__(self, param1, param2): super().__init__() self.param1 = param1 self.param2 = param2 tensor0 = flow.nn.Parameter(flow.Tensor(2, 3).to(dtype=flow.float64)) tensor1 = flow.nn.Parameter(flow.Tensor(2, 3).to(dtype=flow.float64)) m = CustomModule(tensor0, tensor1) m = m.float() state_dict = m.state_dict() test_case.assertEqual(state_dict["param1"].dtype, flow.float32) test_case.assertEqual(state_dict["param2"].dtype, flow.float32) m = m.double() state_dict = m.state_dict() test_case.assertEqual(state_dict["param1"].dtype, flow.float64) test_case.assertEqual(state_dict["param2"].dtype, flow.float64) @flow.unittest.skip_unless_1n1d() def test_moduledict(test_case): class ModuleDict(nn.Module): def __init__(self): super(ModuleDict, self).__init__() self.choices = nn.ModuleDict( {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)} ) self.activations = nn.ModuleDict( {"relu": nn.ReLU(), "prelu": nn.PReLU()} ) def forward(self, x, choice, act): x = self.choices[choice](x) x = self.activations[act](x) return x model = ModuleDict() input = flow.tensor(np.random.randn(4, 10, 32, 32), dtype=flow.float32) output = model(input, "conv", "relu") test_case.assertEqual(output.shape, flow.Size([4, 10, 30, 30])) @flow.unittest.skip_unless_1n1d() def test_module_submodule(test_case): class CustomSubModule(flow.nn.Module): def __init__(self): super().__init__() 
self.param = flow.nn.Linear(2, 3) class CustomModule(flow.nn.Module): def __init__(self) -> None: super().__init__() self.linear = CustomSubModule() m = CustomModule() test_case.assertTrue( isinstance(m.get_submodule("linear.param"), flow.nn.Linear) ) @flow.unittest.skip_unless_1n1d() def test_module_get_parameter(test_case): class CustomModule(flow.nn.Module): def __init__(self, param1, param2): super().__init__() self.param1 = param1 self.param2 = param2 tensor0 = flow.nn.Parameter(flow.Tensor(2, 3).to(dtype=flow.float32)) tensor1 = flow.nn.Parameter(flow.Tensor(2, 3).to(dtype=flow.float32)) m = CustomModule(tensor0, tensor1) test_case.assertTrue(m.get_parameter("param1") is tensor0) test_case.assertTrue(m.get_parameter("param2") is tensor1) def test_module_delattr(test_case): class ConvBNModule(nn.Module): def __init__(self): super(ConvBNModule, self).__init__() self.conv = nn.Conv2d(1, 2, 1, 1) self.bn = nn.BatchNorm2d(2) def forward(self, x): return self.bn(self.conv(x)) m = ConvBNModule() delattr(m, "bn") @flow.unittest.skip_unless_1n1d() def test_hooks_register(test_case): for hook in ["register_backward_hook", "register_full_backward_hook"]: _test_hooks(test_case, hook) _test_module_forward_preforward_hook_removable(test_case) _test_module_forward_forward_hook_removable(test_case) @flow.unittest.skip_unless_1n1d() def test_register_state_dict_hook_hook(test_case): destination_check = None def state_dict_hook(module, destination, prefix, local_metadata): for submodule_name, submodule in module.named_modules(): for attr_name, attr in submodule.__dict__.items(): if isinstance(attr, torch.Tensor): mod_prefix = prefix + submodule_name key = mod_prefix + ("." if mod_prefix else "") + attr_name destination[key] = attr nonlocal destination_check destination_check = destination class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 5) self._register_state_dict_hook(state_dict_hook) def forward(self, x): x = self.linear(x) return x m = CustomModule() test_case.assertEqual(destination_check, None) state_dict = m.state_dict() test_case.assertEqual(destination_check, state_dict) @flow.unittest.skip_unless_1n1d() def test_full_backward_hook(test_case): hook_triggered = False def hook(_, grad_input, grad_output): nonlocal hook_triggered hook_triggered = True test_case.assertEqual(len(grad_input), 1) test_case.assertEqual(len(grad_output), 1) test_case.assertTrue(np.array_equal(grad_input[0].numpy(), [1, 0])) test_case.assertTrue(np.array_equal(grad_output[0].numpy(), [1, 1])) m = flow.nn.ReLU() m.register_full_backward_hook(hook) x0 = flow.tensor([1.0, -1], requires_grad=True) x = x0 + 1 y = m(x) y.sum().backward() test_case.assertTrue(hook_triggered) test_case.assertTrue(np.array_equal(x0.grad, [1, 0])) @flow.unittest.skip_unless_1n1d() def test_full_backward_hook_with_return_value(test_case): hook_triggered = False def hook(_, grad_input, grad_output): nonlocal hook_triggered hook_triggered = True test_case.assertEqual(len(grad_input), 1) test_case.assertEqual(len(grad_output), 1) test_case.assertTrue(np.array_equal(grad_input[0].numpy(), [1, 0])) test_case.assertTrue(np.array_equal(grad_output[0].numpy(), [1, 1])) return (flow.tensor([1, 1]),) m = flow.nn.ReLU() m.register_full_backward_hook(hook) x0 = flow.tensor([1.0, -1], requires_grad=True) x = x0 + 1 y = m(x) y.sum().backward() test_case.assertTrue(hook_triggered) test_case.assertTrue(np.array_equal(x0.grad, [1, 1])) if __name__ == "__main__": unittest.main() 
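# Hedged usage sketch (illustration only, not part of the original tests): the hook API
# exercised above follows the usual register/remove pattern, roughly
#
#   m = flow.nn.ReLU()
#   handle = m.register_full_backward_hook(lambda mod, grad_in, grad_out: None)
#   y = m(flow.ones(2, 2, requires_grad=True) + 1)
#   y.sum().backward()   # the hook fires here with grad_input/grad_output tuples
#   handle.remove()      # after removal the hook is no longer called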
================================================ FILE: python/oneflow/test/modules/test_module_to.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest dummy_val = np.random.randn(2, 3) in_val = np.full((2, 3), -2) cpu0_device = flow.device("cpu") if os.getenv("ONEFLOW_TEST_CPU_ONLY"): gpu0_device = cpu0_device else: gpu0_device = flow.device("cuda") class DummyModule(flow.nn.Module): def __init__(self): super().__init__() self.register_buffer("dummy_buf", flow.Tensor(dummy_val)) self.dummy_para = flow.nn.Parameter(flow.Tensor(dummy_val)) self.register_buffer("dummy_buf_int", flow.Tensor(dummy_val).to(flow.int32)) def forward(self, x): return self.dummy_para * x + self.dummy_buf def _test_dummy_module(test_case): m = DummyModule() test_case.assertEqual(m.dummy_buf.device, cpu0_device) test_case.assertEqual(m.dummy_para.device, cpu0_device) input = flow.Tensor(in_val) output = m(input) test_case.assertTrue(np.allclose(output.numpy(), -dummy_val, 0.0001, 0.0001)) test_case.assertEqual(m.dummy_buf.grad, None) test_case.assertEqual(m.dummy_para.grad, None) test_case.assertEqual(input.device, cpu0_device) test_case.assertEqual(output.device, cpu0_device) def _test_dummy_module_to(test_case): m = DummyModule() test_case.assertEqual(m.dummy_buf.device, cpu0_device) test_case.assertEqual(m.dummy_para.device, cpu0_device) m.to(gpu0_device) test_case.assertEqual(m.dummy_buf.device, gpu0_device) test_case.assertTrue(m.dummy_buf.is_leaf) test_case.assertTrue(not m.dummy_buf.requires_grad) test_case.assertEqual(m.dummy_para.device, gpu0_device) test_case.assertTrue(m.dummy_para.is_leaf) test_case.assertTrue(m.dummy_para.requires_grad) input = flow.Tensor(in_val).to(gpu0_device) output = m(input) test_case.assertTrue(np.allclose(output.numpy(), -dummy_val, 0.0001, 0.0001)) test_case.assertEqual(m.dummy_buf.grad, None) test_case.assertEqual(m.dummy_para.grad, None) test_case.assertEqual(input.device, gpu0_device) test_case.assertEqual(output.device, gpu0_device) output_grad = flow.ones((2, 3)).to(gpu0_device) output.backward(output_grad) test_case.assertEqual(output_grad.device, gpu0_device) test_case.assertEqual(m.dummy_buf.grad, None) test_case.assertTrue(np.allclose(m.dummy_para.grad.numpy(), in_val, 0.0001, 0.0001)) test_case.assertEqual(m.dummy_para.grad.device, gpu0_device) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestModuleTo(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 4 times in past week") def test_module_to_device(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_dummy_module, _test_dummy_module_to] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_module_to_dtype(test_case): m = DummyModule() 
m.to(flow.float64) test_case.assertEqual(m.dummy_buf.dtype, flow.float64) test_case.assertEqual(m.dummy_para.dtype, flow.float64) test_case.assertEqual(m.dummy_buf_int.dtype, flow.int32) def test_module_to_tensor(test_case): m = DummyModule() m.to(flow.zeros(1, dtype=flow.float16, device="cuda")) test_case.assertEqual(m.dummy_buf.dtype, flow.float16) test_case.assertEqual(m.dummy_para.dtype, flow.float16) test_case.assertEqual(m.dummy_buf_int.dtype, flow.int32) test_case.assertEqual(m.dummy_buf.device.type, "cuda") test_case.assertEqual(m.dummy_para.device.type, "cuda") test_case.assertEqual(m.dummy_buf_int.device.type, "cuda") def test_module_to_with_var_reuse(test_case): class ReuseVarModule(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(3, 4) self.linear2 = flow.nn.Linear(3, 4) self.linear2.weight = self.linear1.weight reuse_var_m = ReuseVarModule() test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight) test_case.assertEqual(reuse_var_m.linear1.weight.device, cpu0_device) test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias) test_case.assertEqual(reuse_var_m.linear1.bias.device, cpu0_device) reuse_var_m.to(gpu0_device) test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight) test_case.assertEqual(reuse_var_m.linear1.weight.device, gpu0_device) test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias) test_case.assertEqual(reuse_var_m.linear1.bias.device, gpu0_device) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_module_to_global_or_local.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestModuleToGlobalOrLocal(flow.unittest.TestCase): def test_module_to_global(test_case): rank = flow.env.get_rank() P = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast class ReuseVarModule(flow.nn.Module): def __init__(self): super().__init__() self.linear1 = flow.nn.Linear(3, 4) self.linear2 = flow.nn.Linear(3, 4) self.linear2.weight = self.linear1.weight reuse_var_m = ReuseVarModule() test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight) test_case.assertEqual( reuse_var_m.linear1.weight.device, flow.device("cpu", rank) ) test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias) test_case.assertEqual(reuse_var_m.linear1.bias.device, flow.device("cpu", rank)) reuse_var_m.to_global(placement=P, sbp=B) test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight) test_case.assertEqual(reuse_var_m.linear1.weight.placement, P) test_case.assertEqual(reuse_var_m.linear1.weight.sbp[0], B) test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias) test_case.assertEqual(reuse_var_m.linear1.bias.placement, P) test_case.assertEqual(reuse_var_m.linear1.bias.sbp[0], B) def test_module_to_local(test_case): rank = flow.env.get_rank() device = "cuda" P = flow.placement(device, ranks=[0, 1]) B = flow.sbp.broadcast S = flow.sbp.split(0) class ToLocalModule(flow.nn.Module): def __init__(self): super().__init__() self.linear = flow.nn.Linear(3, 4, False) to_local_m = ToLocalModule() flow.nn.init.uniform_(to_local_m.linear.weight) to_local_m.to_global(placement=P, sbp=B) origin_w_np = to_local_m.linear.weight.numpy() to_local_m.to_global(placement=P, sbp=S) test_case.assertTrue( np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np) ) # When wight SBP is split(0) to_local_m.to_local() test_case.assertTrue(to_local_m.linear.weight.is_local) if rank == 0: test_case.assertTrue( np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np[:2]) ) elif rank == 1: test_case.assertTrue( np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np[2:]) ) # local to global from split(0) to_local_m.to_global(placement=P, sbp=S) test_case.assertTrue( np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np) ) # When wight SBP is broadcast to_local_m.to_global(placement=P, sbp=B) test_case.assertTrue(not to_local_m.linear.weight.is_local) test_case.assertTrue( np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np) ) # When wight SBP is broadcast to_local_m.to_local() test_case.assertTrue(to_local_m.linear.weight.is_local) test_case.assertTrue( np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_module_to_half.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestModuleToHalf(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_module_to_half(test_case): input = flow.randn(10, 10).to(flow.float16).cuda() model = flow.nn.Linear(10, 20).half().cuda() output = model(input) test_case.assertEqual(output.dtype, flow.float16) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_movedim.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from random import shuffle from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestMovedim(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_movedim_with_vector(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), dim4=random(3, 6), ).to(device) z = torch.movedim(x, (0, 1), (2, 3)) return z @autotest(n=10) def test_flow_movedim_with_stride(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), dim4=random(3, 6), ).to(device) perm = [0, 1, 2, 3] shuffle(perm) y = x.permute(perm) z = torch.movedim(y, (0, 1), (2, 3)) return z @autotest(check_graph=True) def test_flow_movedim_with_int(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), dim4=random(3, 6), ).to(device) z = torch.movedim(x, 0, 3) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_moving_average_min_max_observer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import math import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.test_util import ( GenArgList, type_name_to_flow_type, type_name_to_np_type, ) import oneflow as flow import oneflow.unittest def gen_quant_scale_for_moving_average_min_max_symmetric( activation, quantization_bit, momentum, moving_max, moving_min ): activation_max = np.max(np.abs(activation)) denominator = 2.0 ** (quantization_bit - 1) - 1 if moving_max[0] == 0: moving_max[0] = activation_max else: moving_max[0] = moving_max[0] * momentum + activation_max * (1 - momentum) moving_min[0] = moving_max[0] return (moving_max[0] / denominator, 0) def gen_quant_scale_for_moving_average_min_max_affine( activation, quantization_bit, momentum, moving_max, moving_min ): activation_max = np.max(activation) activation_min = np.min(activation) denominator = 2.0 ** quantization_bit - 1 if moving_max[0] == 0: moving_max[0] = activation_max else: moving_max[0] = moving_max[0] * momentum + activation_max * (1 - momentum) if moving_min[0] == 0: moving_min[0] = activation_min else: moving_min[0] = moving_min[0] * momentum + activation_min * (1 - momentum) scale = (moving_max[0] - moving_min[0]) / denominator zero_point = -np.round(moving_min[0] / scale) return (scale, zero_point) def gen_quant_scale_for_moving_average_min_max_cambricon( activation, quantization_bit, momentum, moving_max, moving_min ): activation_max = np.max(np.abs(activation)) if moving_max[0] == 0: moving_max[0] = activation_max else: moving_max[0] = moving_max[0] * momentum + activation_max * (1 - momentum) moving_min[0] = moving_max[0] return (math.floor(math.log2(moving_max[0])) - (quantization_bit - 2), 0) def _check_moving_average_min_max_observer( test_case, activation, scale_of, zero_point_of, moving_max_np, moving_min_np, quantization_bit, quantization_scheme, quantization_formula, momentum, ): if quantization_formula == "google": if quantization_scheme == "symmetric": ( scale_np, zero_point_np, ) = gen_quant_scale_for_moving_average_min_max_symmetric( activation.flatten(), quantization_bit, momentum, moving_max_np, moving_min_np, ) else: ( scale_np, zero_point_np, ) = gen_quant_scale_for_moving_average_min_max_affine( activation.flatten(), quantization_bit, momentum, moving_max_np, moving_min_np, ) else: ( scale_np, zero_point_np, ) = gen_quant_scale_for_moving_average_min_max_cambricon( activation.flatten(), quantization_bit, momentum, moving_max_np, moving_min_np, ) test_case.assertTrue(np.allclose(scale_of[0], scale_np, rtol=0.001)) rmse = np.sqrt(np.mean((zero_point_of[0] - zero_point_np) ** 2)) assert ( rmse <= 1.0 ), "moving_average_min_max_observer op zero_point calculate has bug!" 
def _run_test_moving_average_min_max_observer( test_case, device_type, dtype, activation_shape, quantization_bit, quantization_scheme, quantization_formula, momentum, ): moving_max_np = np.zeros((1,)) moving_min_np = np.zeros((1,)) current_train_step_tensor = flow.tensor( np.zeros((1,)).astype(np.float32), dtype=flow.int64, device=flow.device(device_type), ) for i in range(10): activation = (np.random.random(activation_shape) - 0.5).astype( type_name_to_np_type[dtype] ) activation_tensor = flow.tensor( activation, dtype=flow.float32, device=flow.device(device_type) ) moving_average_min_max_observer = flow.nn.MovingAverageMinMaxObserver( stop_update_after_iters=1, quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, momentum=momentum, ) moving_average_min_max_observer = moving_average_min_max_observer.to( device_type ) (scale, zero_point) = moving_average_min_max_observer( activation_tensor, current_train_step_tensor ) _check_moving_average_min_max_observer( test_case, activation, scale.numpy(), zero_point.numpy(), moving_max_np, moving_min_np, quantization_bit, quantization_scheme, quantization_formula, momentum, ) class TestMovingAverageMinMaxObserver(flow.unittest.TestCase): def test_moving_average_min_max_observer(test_case): arg_dict = OrderedDict() arg_dict["test_case"] = [test_case] arg_dict["device_type"] = ["cpu", "cuda"] arg_dict["dtype"] = ["float32", "double"] arg_dict["activation_shape"] = [(9, 40, 20, 10)] arg_dict["quantization_bit"] = [8, 2] arg_dict["quantization_scheme"] = ["symmetric", "affine"] arg_dict["quantization_formula"] = ["google"] arg_dict["momentum"] = [0.95] for arg in GenArgList(arg_dict): _run_test_moving_average_min_max_observer(*arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_mul.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch as torch_original from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_mul_impl(test_case, device): x = flow.tensor( np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.mul(x, y) np_out = np.multiply(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad_x = y.numpy() np_grad_y = x.numpy() test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad_x, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(y.grad.numpy(), np_grad_y, 1e-05, 1e-05)) x = 5 y = flow.tensor( np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device) ) of_out = flow.mul(x, y) np_out = np.multiply(x, y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) x = flow.tensor( np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device) ) y = 5 of_out = flow.mul(x, y) np_out = np.multiply(x.numpy(), y) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) x = flow.tensor( np.random.randn(1, 1), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.mul(x, y) np_out = np.multiply(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.sum(y.numpy()), 1e-05, 1e-05)) test_case.assertTrue(np.allclose(y.grad.numpy(), x.numpy(), 1e-05, 1e-05)) x = flow.tensor( np.random.randn(1, 1), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.random.randn(2, 3, 4), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.mul(x, y) np_out = np.multiply(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.sum(y.numpy()), 1e-05, 1e-05)) test_case.assertTrue(np.allclose(y.grad.numpy(), x.numpy(), 1e-05, 1e-05)) x = flow.tensor( np.random.randn(1, 1), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.random.randn(2, 3, 4, 5), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.mul(x, y) np_out = np.multiply(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.sum(y.numpy()), 1e-05, 1e-05)) test_case.assertTrue(np.allclose(y.grad.numpy(), x.numpy(), 1e-05, 1e-05)) def inplace_mul_tensors_helper(test_case, device, arr_0, arr_y): of_x = flow.tensor( arr_0, dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_inplace_x = of_x + 1 of_y = flow.tensor( arr_y, dtype=flow.float32, device=flow.device(device), requires_grad=True, ) id_inpalce_x = id(of_inplace_x) of_inplace_x.mul_(of_y) test_case.assertTrue( np.allclose(of_inplace_x.numpy(), np.multiply(arr_0 + 1, arr_y), 1e-05, 1e-05) ) test_case.assertTrue(id_inpalce_x == id(of_inplace_x)) of_inplace_x = of_inplace_x.sum() 
of_inplace_x.backward() test_case.assertTrue(np.allclose(arr_y, of_x.grad.numpy(), 1e-05, 1e-05)) test_case.assertTrue(np.allclose(arr_0 + 1, of_y.grad.numpy(), 1e-05, 1e-05)) def _test_inplace_mul_tensors(test_case, device): arr_0 = np.random.rand(3, 5) arr_y = np.random.rand(3, 5) inplace_mul_tensors_helper(test_case, device, arr_0, arr_y) def _test_inplace_mul_scalar(test_case, device): arr = np.random.rand(2, 3, 4) of_x = flow.tensor( arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) y = 3.25 of_inplace_x = of_x + 1 id_x_before = id(of_inplace_x) of_inplace_x.mul_(y) test_case.assertTrue(id_x_before == id(of_inplace_x)) test_case.assertTrue(np.allclose(of_inplace_x.numpy(), np.multiply(arr + 1, y))) of_x = flow.tensor( arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_inplace_x = of_x + 1 of_inplace_x_id_before = id(of_inplace_x) of_inplace_x.mul_(y) test_case.assertTrue(of_inplace_x_id_before == id(of_inplace_x)) test_case.assertTrue( np.allclose(of_inplace_x.numpy(), np.multiply(arr + 1, y), 1e-05, 1e-05) ) of_inplace_x = of_inplace_x.sum() of_inplace_x.backward() test_case.assertTrue( np.allclose(np.full(arr.shape, y), of_x.grad.numpy(), 1e-05, 1e-05) ) def _test_mul_inplace_0size_tensor(test_case, device): targets = flow.randn((0, 6), device=flow.device(device)) height, width = 640, 640 targets[:, 2:] *= flow.tensor( (width, height, width, height), device=flow.device(device) ) test_case.assertTrue(np.array_equal(targets.size(), (0, 6))) @flow.unittest.skip_unless_1n1d() class TestMulModule(flow.unittest.TestCase): def test_mul(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_mul_impl, _test_inplace_mul_tensors, _test_inplace_mul_scalar, _test_mul_inplace_0size_tensor, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(check_graph=True, include_complex=True) def test_broadcast_mul(test_case): device = random_device() x_0 = random_tensor(ndim=3, dim0=4, dim1=2, dim2=3).to(device) y = random_tensor(ndim=2, dim0=2, dim1=3).to(device) x = x_0 + 1 x.mul_(y) return x @autotest(n=6, include_complex=True) def test_non_contiguous_inplace_mul(test_case): device = random_device() x = random_tensor(2, 2, 4).to(device) y = x + 1 y = y[:, 1:3] y *= random_tensor(2, 2, 2).to(device) return y @autotest(n=10, include_complex=True) def test_scalar_mul_with_random_devices(test_case): x1_device = random_device() x2_device = random_device() x1 = random_tensor(2, 2, 3).to(x1_device).mean() x2 = random_tensor(2, 2, 3).to(x2_device) y = x1 * x2 return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_multi_tensor_yolov5_weight_update.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_multi_tensor_weight_update_impl(test_case, device, shape, n, d): def compare(a, b, rtol=1e-5, atol=1e-5): test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol ), f"\na\n{a.detach().cpu().numpy()}\n{'-' * 80}\nb:\n{b.detach().cpu().numpy()}\n{'*' * 80}\ndiff:\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}", ) weight = [] torch_weight = [] weight_update = [] torch_weight_update = [] for _ in range(n): tmp = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=False, ) weight.append(tmp) torch_weight.append( torch.tensor( tmp.numpy(), dtype=torch.float32, device=torch.device(device), requires_grad=False, ) ) tmp = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=False, ) weight_update.append(tmp) torch_weight_update.append( torch.tensor( tmp.numpy(), dtype=torch.float32, device=torch.device(device), requires_grad=False, ) ) for i, v in enumerate(torch_weight): v = v * d v = v + (1 - d) * torch_weight_update[i] torch_weight[i] = v flow._C.multi_tensor_yolov5_weight_update(weight, weight_update, d) for i in range(n): compare(weight[i], torch_weight[i]) @flow.unittest.skip_unless_1n1d() class TestMultiTensorWeightUpdateModule(flow.unittest.TestCase): def test_multi_tensor_weight_update(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_multi_tensor_weight_update_impl] arg_dict["device"] = ["cuda"] arg_dict["shape"] = [(20, 1), (30, 1), (55, 1)] arg_dict["n"] = [5, 10, 292] arg_dict["d"] = [0.22, 0.5] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_multinomial.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import random import numpy as np from collections import OrderedDict import torch import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_multinomial(test_case, device, seed, replacement, dtype): n_dists = random.randint(8, 64) n_categories = random.randint(8, 64) num_samples = random.randint(4, n_categories) weights_torch = torch.rand( n_dists, n_categories, device=device, dtype=torch.float32 if dtype == "float" else torch.float64, ) weights_oneflow = flow.tensor( weights_torch.cpu().numpy(), device=device, dtype=flow.float32 if dtype == "float" else flow.float64, ) torch.manual_seed(seed) flow.manual_seed(seed) torch_res = torch.multinomial( weights_torch, num_samples, replacement=replacement, generator=None ) flow_res = flow.multinomial( weights_oneflow, num_samples, replacement=replacement, generator=None ) test_case.assertTrue( np.allclose(torch_res.cpu().numpy(), flow_res.cpu().numpy(), atol=1e-8,) ) torch_gen = torch.Generator(device=device) torch_gen.manual_seed(seed) oneflow_gen = flow.Generator(device=device) oneflow_gen.manual_seed(seed) torch_res = torch.multinomial( weights_torch, num_samples, replacement=replacement, generator=torch_gen ) flow_res = flow.multinomial( weights_oneflow, num_samples, replacement=replacement, generator=oneflow_gen ) test_case.assertTrue( np.allclose(torch_res.cpu().numpy(), flow_res.cpu().numpy(), atol=1e-8,) ) @flow.unittest.skip_unless_1n1d() class TestMultinomial(flow.unittest.TestCase): def test_multinomial(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] arg_dict["seed"] = [0, 2, 4] arg_dict["replacement"] = [True, False] arg_dict["dtype"] = ["double", "float"] for arg in GenArgList(arg_dict): _test_multinomial(test_case, *arg[0:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_nansum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestNanSumModule(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_nansum_without_nan(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) y = torch.nansum(x) return y @autotest(n=5, check_graph=True) def test_nansum_with_partial_nan(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) mask = x < 0 x = x.masked_fill(mask, float("nan")) y = torch.nansum(x) return y @autotest(n=5, check_graph=True) def test_nansum_with_total_nan(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) mask = torch.ones_like(x).bool() x = x.masked_fill(mask, float("nan")) y = torch.nansum(x) return y @autotest(n=5, check_graph=True) def test_nansum_with_partial_nan_dims(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) mask = x < 0 x = x.masked_fill(mask, float("nan")) y = torch.nansum(x, dim=random(0, 4).to(int)) return y @autotest(n=5, check_graph=True) def test_nansum_with_total_nan_dims(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) mask = torch.ones_like(x).bool() x = x.masked_fill(mask, float("nan")) y = torch.nansum(x, dim=random(0, 4).to(int)) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_sum_with_0_size_tensor(test_case): device = random_device() x = random_tensor(4, 4, 3, 0, 2).to(device) y = torch.nansum(x, dim=np.random.randint(0, 3)) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_sum_with_0dim_tensor(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.nansum(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_narrow.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np from random import shuffle from scipy.fftpack import ss_diff from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestNarrow(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_narrow_start_with_random_data(test_case): k0 = random(2, 6) k1 = random(2, 6) k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) return torch.narrow(x, dim=rand_dim, start=2, length=1) @autotest(check_graph=True) def test_flow_narrow_length_with_random_data(test_case): k0 = random(2, 6) k1 = random(2, 6) k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) return torch.narrow(x, dim=rand_dim, start=0, length=2) @autotest(n=10, check_graph=True) def test_flow_narrow_with_stride(test_case): k0 = random(2, 6) k1 = random(2, 6) k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) perm = [0, 1, 2] shuffle(perm) x = x.permute(perm) y = torch.narrow(x, dim=rand_dim, start=0, length=2) return y @autotest(auto_backward=False, check_graph=True) def test_flow_narrow_start_bool_with_random_data(test_case): k0 = random(2, 6) k1 = random(2, 6) k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to( device=device, dtype=torch.bool ) return torch.narrow(x, dim=rand_dim, start=2, length=1) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_ne.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_ne(test_case, shape, device): arr1 = np.random.randn(*shape) arr2 = np.random.randn(*shape) input = flow.tensor(arr1, dtype=flow.float32, device=flow.device(device)) other = flow.tensor(arr2, dtype=flow.float32, device=flow.device(device)) of_out = flow.ne(input, other) of_out2 = flow.not_equal(input, other) np_out = np.not_equal(arr1, arr2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) test_case.assertTrue(np.array_equal(of_out2.numpy(), np_out)) test_case.assertTrue(input != None) test_case.assertTrue(None != input) def _test_tensor_ne_operator(test_case, shape, device): arr1 = np.random.randn(*shape) arr2 = np.random.randn(*shape) input = flow.tensor(arr1, dtype=flow.float32, device=flow.device(device)) other = flow.tensor(arr2, dtype=flow.float32, device=flow.device(device)) of_out = input.ne(other) np_out = np.not_equal(arr1, arr2) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_ne_int(test_case, shape, device): arr = np.random.randn(*shape) input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device)) num = 1 of_out = flow.ne(input, num) np_out = np.not_equal(arr, num) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_ne_operator_int(test_case, shape, device): arr = np.random.randn(*shape) input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device)) num = 1 of_out = input.ne(num) np_out = np.not_equal(arr, num) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_ne_float(test_case, shape, device): arr = np.random.randn(*shape) input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device)) num = 1.0 of_out = flow.ne(input, num) np_out = np.not_equal(arr, num) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_tensor_ne_operator_float(test_case, shape, device): arr = np.random.randn(*shape) input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device)) num = 1.0 of_out = input.ne(num) np_out = np.not_equal(arr, num) test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestNe(flow.unittest.TestCase): def test_ne(test_case): arg_dict = OrderedDict() arg_dict["test_func"] = [ _test_ne, _test_tensor_ne_operator, _test_ne_int, _test_tensor_ne_operator_int, _test_ne_float, _test_tensor_ne_operator_float, ] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, auto_backward=False, check_graph=True) def test_ne_with_0_size_data(test_case): device = random_device() x1 = random_tensor(4, 2, 3, 0, 5).to(device) x2 = random_tensor(4, 2, 3, 0, 5).to(device) y1 = torch.ne(x1, x2) y2 = torch.ne(x1, 2) y3 = torch.ne(x1, 2.0) return (y1, y2, y3) @autotest(n=5, auto_backward=False) def test_ne_with_0dim_data(test_case): device = random_device() x1 = random_tensor(ndim=0).to(device) x2 = random_tensor(ndim=0).to(device) y1 = torch.ne(x1, x2) y2 = torch.ne(x1, 2) y3 = torch.ne(x1, 2.0) return (y1, y2, y3) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_negative.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestNegativeModule(flow.unittest.TestCase): @autotest(n=5, auto_backward=False, check_graph=True) def test_ne_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 3, 0, 5).to(device) y1 = torch.negative(x) y2 = torch.neg(x) y3 = -x return (y1, y2, y3) @autotest() def test_tensor_negative_with_random_data(test_case): x = random_tensor().to(random_device()) return x.negative() @autotest() def test_negative_with_random_data(test_case): x = random_tensor().to(random_device()) z = torch.negative(x) return z @autotest() def test_neg_with_random_data(test_case): x = random_tensor().to(random_device()) z = torch.neg(x) return z @autotest() def test_tensor_negative_with_0dim_data(test_case): x = random_tensor(ndim=0).to(random_device()) return x.negative() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_nll_loss.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import numpy as np import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @autotest(n=5) def _test_nll_loss( test_case, has_weight=False, split_batch_dim=False, split_class_dim=False ): N = random(1, 4) * 2 C = random(1, 10) * 2 ndim = random(2, 5).to(int).value() dims = [random(2, 10) for i in range(ndim - 2)] input_dims = [N, C] + dims target_dims = [N] + dims input = random_tensor(ndim, *input_dims) target = random_tensor( ndim - 1, *target_dims, low=0, high=C, dtype=int, requires_grad=False ) weight = None if has_weight: weight = random_tensor(1, C, requires_grad=False) device = random_device().value() if not split_class_dim and not split_batch_dim: input = input.to(device) target = target.to(device) if has_weight: weight = weight.to(device) else: rank = flow.env.get_rank() world_size = flow.env.get_world_size() assert world_size % 2 == 0 ranks = np.array(range(world_size)) if split_batch_dim and split_class_dim: placement = flow.placement(device, ranks.reshape((ranks.size // 2, 2))) input_sbp = [flow.sbp.split(0), flow.sbp.split(1)] target_sbp = [flow.sbp.split(0), flow.sbp.broadcast()] weight_sbp = [flow.sbp.broadcast(), flow.sbp.split(0)] elif split_batch_dim: placement = flow.placement(device, ranks) input_sbp = flow.sbp.split(0) target_sbp = flow.sbp.split(0) weight_sbp = flow.sbp.broadcast() else: placement = flow.placement(device, ranks) input_sbp = flow.sbp.split(1) target_sbp = flow.sbp.broadcast() weight_sbp = flow.sbp.split(0) input = input.to_global(placement=placement, sbp=input_sbp) target = target.to_global(placement=placement, sbp=target_sbp) # print( # f"**[{rank}] input: {input.oneflow.shape} {input.oneflow.placement} {input.oneflow.sbp}" # ) # print( # f"**[{rank}] target: {target.oneflow.shape} {target.oneflow.placement} {target.oneflow.sbp}" # ) if has_weight: # print(f"**[{rank}] weight: {weight.oneflow.numpy()}") weight = weight.to_global(placement=placement, sbp=weight_sbp) # reduction = oneof("none", "sum", "mean") reduction = ( "none" # Temporarily skip the test of "sum" and "mean" because of unknown error ) if has_weight: nll = torch.nn.NLLLoss(weight=weight, reduction=reduction) else: nll = torch.nn.NLLLoss(reduction=reduction) return nll(input, target) @flow.unittest.skip_unless_1n1d() class NLLLossTestCase(flow.unittest.TestCase): def test_local(test_case): _test_nll_loss(test_case) def test_weighted(test_case): _test_nll_loss(test_case, has_weight=True) @flow.unittest.skip_unless_1n2d() class ParallelNLLLossTestCase(flow.unittest.TestCase): @globaltest def test_data_parallel(test_case): _test_nll_loss(test_case, split_batch_dim=True) @globaltest def test_data_parallel_weighted(test_case): _test_nll_loss(test_case, has_weight=True, split_batch_dim=True) @globaltest def test_model_parallel(test_case): _test_nll_loss(test_case, split_class_dim=True) @globaltest def test_model_parallel_weighted(test_case): _test_nll_loss(test_case, has_weight=True, split_class_dim=True) @flow.unittest.skip_unless_1n4d() class TowDParallelNLLLossTestCase(flow.unittest.TestCase): @globaltest def test_2d_parallel(test_case): _test_nll_loss(test_case, split_batch_dim=True, split_class_dim=True) @globaltest def test_2d_parallel_weighted(test_case): _test_nll_loss( test_case, has_weight=True, split_batch_dim=True, split_class_dim=True ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_nms.py 
================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgList def box_area(boxes): return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) def _box_inter_union_np(boxes1, boxes2): area1 = box_area(boxes1) area2 = box_area(boxes2) lt = np.maximum(boxes1[:, np.newaxis, :2], boxes2[:, :2]) rb = np.minimum(boxes1[:, np.newaxis, 2:], boxes2[:, 2:]) wh = np.clip(rb - lt, a_min=0, a_max=np.inf) inter = wh[:, :, 0] * wh[:, :, 1] union = area1[:, np.newaxis] + area2 - inter return inter, union def box_iou_np(boxes1, boxes2): inter, union = _box_inter_union_np(boxes1, boxes2) iou = inter / union return iou def nms_np(boxes, scores, iou_threshold): picked = [] indexes = np.argsort(-scores) while len(indexes) > 0: current = indexes[0] picked.append(current.item()) if len(indexes) == 1: break current_box = boxes[current, :] indexes = indexes[1:] rest_boxes = boxes[indexes, :] iou = np.squeeze(box_iou_np(rest_boxes, current_box[np.newaxis]), axis=1) indexes = indexes[iou <= iou_threshold] return np.asarray(picked) def create_tensors_with_iou(N, iou_thresh): boxes = np.random.rand(N, 4) * 100 boxes[:, 2:] += boxes[:, :2] boxes[-1, :] = boxes[0, :] x0, y0, x1, y1 = boxes[-1].tolist() iou_thresh += 1e-5 boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh # Avoid score lists have the same score which will # result in an unstable sort. scores = np.random.choice(N, N, replace=False) return boxes, scores def _test_nms(test_case, device): iou = 0.5 boxes, scores = create_tensors_with_iou(1000, iou) boxes = flow.tensor(boxes, dtype=flow.float32, device=flow.device(device)) scores = flow.tensor(scores, dtype=flow.float32, device=flow.device(device)) keep_np = nms_np(boxes.numpy(), scores.numpy(), iou) keep = flow.nms(boxes, scores, iou) test_case.assertTrue(np.allclose(keep.numpy(), keep_np)) @flow.unittest.skip_unless_1n1d() class TestNMS(flow.unittest.TestCase): def test_nms(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_nms] arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_noncontiguous_binary_op.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgList def _test_op(test_case, x, y, inplace): ref1 = x + y out1 = flow._C.noncontiguous_binary_op(x, y, op="add", inplace=inplace) test_case.assertTrue(np.allclose(ref1.numpy(), out1.numpy(), rtol=1e-5, atol=1e-5)) ref2 = x - y out2 = flow._C.noncontiguous_binary_op(x, y, op="sub", inplace=inplace) test_case.assertTrue(np.allclose(ref2.numpy(), out2.numpy(), rtol=1e-5, atol=1e-5)) ref3 = x * y out3 = flow._C.noncontiguous_binary_op(x, y, op="mul", inplace=inplace) test_case.assertTrue(np.allclose(ref3.numpy(), out3.numpy(), rtol=1e-5, atol=1e-5)) y = y.abs() + 1e-3 # in case of zeros ref4 = x / y out4 = flow._C.noncontiguous_binary_op(x, y, op="div", inplace=inplace) print(np.abs(ref4 - out4).max()) test_case.assertTrue(np.allclose(ref4.numpy(), out4.numpy(), rtol=1e-3, atol=1e-3)) def _test_noncontiguous_binary_op(test_case, dtype, pack_size, ndims, inplace): shape = [] for _ in range(ndims - 1): if np.random.uniform(-1, 1) > 0: shape.append(1 << np.random.randint(4, 7)) else: shape.append(np.random.randint(20, 100)) shape.append(1 << np.random.randint(3, 7) + pack_size) # case 1 x = flow.randn(*shape, requires_grad=True).cuda().to(dtype) y = flow.randn(*shape, requires_grad=True).cuda().to(dtype) d1, d2 = np.random.choice(ndims, 2, replace=False) x1 = x.transpose(d1, d2) y1 = y.transpose(d1, d2) _test_op(test_case, x1, y1, inplace) # case 2 y2 = flow.randn(*shape, requires_grad=True).cuda().to(dtype) shape[d1], shape[d2] = shape[d2], shape[d1] x = flow.randn(*shape, requires_grad=True).cuda().to(dtype) x2 = x.transpose(d1, d2) _test_op(test_case, x2, y2, inplace) @unittest.skipIf(True, "skip test for noncontiguous_binary_op.") @flow.unittest.skip_unless_1n1d() class TestNonContiguousBinaryOp(flow.unittest.TestCase): def test_noncontiguous_binary_op(test_case): arg_dict = OrderedDict() arg_dict["test_fn"] = [_test_noncontiguous_binary_op] arg_dict["dtype"] = [flow.float16, flow.float32] arg_dict["pack_size"] = [1, 2, 4] arg_dict["ndims"] = [2, 3, 4] arg_dict["inplace"] = [True, False] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_nonzero.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def np_nonzero(input, as_tuple): if as_tuple: return np.nonzero(input) else: return np.transpose(np.nonzero(input)) def _test_nonzero(test_case, shape, as_tuple, device): np_input = np.random.randn(*shape) input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device)) of_out = flow.nonzero(input, as_tuple) np_out = np_nonzero(np_input, as_tuple) if as_tuple: test_case.assertTrue( np.allclose(tuple(x.numpy() for x in of_out), np_out, 0.0001, 0.0001) ) else: test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestNonzero(flow.unittest.TestCase): def test_nonzero(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_nonzero] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6), (2, 3, 0, 4)] arg_dict["as_tuple"] = [True, False] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) # Not check graph because of one reason: # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor. # Please refer to File "python/oneflow/nn/modules/nonzero.py", line 29, in nonzero_op. @autotest(auto_backward=False, check_graph="ValidatedFalse") def test_nonzero_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(2, 5).to(int)).to(device) y = torch.nonzero(x) return y # Not check graph because of one reason: # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor. # Please refer to File "python/oneflow/nn/modules/nonzero.py", line 29, in nonzero_op. @autotest(auto_backward=False, check_graph="ValidatedFalse") def test_nonzero_bool_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(2, 5).to(int)).to(device=device, dtype=torch.bool) y = torch.nonzero(x) return y # Not check graph because of one reason: # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor. @autotest(auto_backward=False, check_graph="ValidatedFalse") def test_half_nonzero_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(2, 5).to(int)).to( device=device, dtype=torch.float16 ) y = torch.nonzero(x) return y # Not check graph because of one reason: # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor. # Please refer to File "python/oneflow/nn/modules/nonzero.py", line 29, in nonzero_op. @autotest(auto_backward=False, check_graph="ValidatedFalse") def test_nonzero_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.nonzero(x) return y # Not check graph because of one reason: # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor. # Please refer to File "python/oneflow/nn/modules/nonzero.py", line 29, in nonzero_op. 
@autotest(auto_backward=False, check_graph="ValidatedFalse") def test_nonzero_tuple_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(2, 5).to(int)).to(device) y = torch.nonzero(x, as_tuple=True) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest def _np_vector_norm_backward(x, ord=2, dim=None): re = np.zeros_like(x) if isinstance(ord, int) and isinstance(dim, int): if ord == 0: return re else: temp = np.sum(np.abs(x ** ord), dim) ** (1.0 / ord - 1) re = np.where(x ** ord < 0, -temp, temp) * x ** (ord - 1) elif dim == None and x.ndim == 1: if ord == 0: return re elif ord == float("inf"): max_ind = np.argmax(np.abs(x)) re[max_ind] += 1 if x[max_ind] != 0 else 0 re = np.where(x < 0, -re, re) elif ord == float("-inf"): min_ind = np.argmin(np.abs(x)) re[min_ind] += 1 if x[min_ind] != 0 else 0 re = np.where(x < 0, -re, re) else: temp = np.sum(np.abs(x ** ord)) ** (1.0 / ord - 1) re = np.where(x ** ord < 0, -temp, temp) * x ** (ord - 1) elif ( isinstance(ord, float) and isinstance(dim, int) and (ord in [float("inf"), float("-inf")]) ): if ord == float("inf"): max_ind = np.argmax(np.abs(x), dim) index = ( [(i, max_ind[i]) for i in range(len(max_ind))] if dim == 1 else [(max_ind[i], i) for i in range(len(max_ind))] ) print(index) for j in index: re[j] += 1 if x[j] != 0 else 0 re = np.where(x < 0, -re, re) else: min_ind = np.argmin(np.abs(x), dim) index = ( [(i, min_ind[i]) for i in range(len(min_ind))] if dim == 1 else [(min_ind[i], i) for i in range(len(min_ind))] ) for j in index: re[j] += 1 if x[j] != 0 else 0 re = np.where(x < 0, -re, re) return re def _np_matrix_norm_backward(x, ord="fro"): re = np.zeros_like(x) if isinstance(ord, int): if ord == 1: max_ind = np.argmax(np.sum(np.abs(x), 0)) index = [(i, max_ind) for i in range(x.shape[0])] for j in index: re[j] += 1 if x[j] != 0 else 0 re = np.where(x < 0, -re, re) elif ord == -1: min_ind = np.argmin(np.sum(np.abs(x), 0)) index = [(i, min_ind) for i in range(x.shape[0])] for j in index: re[j] += 1 if x[j] != 0 else 0 re = np.where(x < 0, -re, re) elif ord == "fro": re = np.sum(x ** 2) ** (-0.5) * x elif isinstance(ord, float) and ord in [float("inf"), float("-inf")]: if ord == float("inf"): max_ind = np.argmax(np.sum(np.abs(x), 1)) index = [(max_ind, i) for i in range(x.shape[1])] for j in index: re[j] += 1 if x[j] != 0 else 0 re = np.where(x < 0, -re, re) else: min_ind = np.argmin(np.sum(np.abs(x), 1)) index = [(min_ind, i) for i in range(x.shape[1])] for j in index: re[j] += 1 if x[j] != 0 else 0 re = np.where(x < 0, -re, re) return re def _test_norm_1d(test_case, device): 
input = flow.tensor( np.random.randn(10), dtype=flow.float32, device=flow.device(device) ) of_out_1 = flow.linalg.norm(input) of_out_2 = flow.linalg.norm(input, ord=0) of_out_3 = flow.linalg.norm(input, ord=3) of_out_4 = flow.linalg.norm(input, ord=float("inf")) of_out_5 = flow.linalg.norm(input, ord=-float("inf")) np_out_1 = np.linalg.norm(input.numpy()) np_out_2 = np.linalg.norm(input.numpy(), ord=0) np_out_3 = np.linalg.norm(input.numpy(), ord=3) np_out_4 = np.linalg.norm(input.numpy(), ord=float("inf")) np_out_5 = np.linalg.norm(input.numpy(), ord=-float("inf")) test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out_1, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out_2, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_3.numpy(), np_out_3, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_4.numpy(), np_out_4, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_5.numpy(), np_out_5, 1e-05, 1e-05)) def _test_norm_2d(test_case, device): input = flow.tensor( np.random.randn(5, 4), dtype=flow.float32, device=flow.device(device) ) of_out_1 = flow.linalg.norm(input) of_out_2 = flow.linalg.norm(input, dim=0) of_out_3 = flow.linalg.norm(input, dim=1, keepdim=True) of_out_4 = flow.linalg.norm(input, ord=1, dim=0) of_out_5 = flow.linalg.norm(input, ord=-1, dim=1, keepdim=True) np_out_1 = np.linalg.norm(input.numpy()) np_out_2 = np.linalg.norm(input.numpy(), axis=0) np_out_3 = np.linalg.norm(input.numpy(), axis=1, keepdims=True) np_out_4 = np.linalg.norm(input.numpy(), ord=1, axis=0) np_out_5 = np.linalg.norm(input.numpy(), ord=-1, axis=1, keepdims=True) test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out_1, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out_2, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_3.numpy(), np_out_3, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_4.numpy(), np_out_4, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_5.numpy(), np_out_5, 1e-05, 1e-05)) def _test_norm_Nd(test_case, device): input1 = flow.tensor( np.random.randn(3, 4, 3), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.random.randn(3, 4, 3, 5), dtype=flow.float32, device=flow.device(device) ) of_out_1 = flow.linalg.norm(input1) of_out_2 = flow.linalg.norm(input1, dim=(0, 1)) of_out_3 = flow.linalg.norm(input2, dim=(0, 2)) np_out_1 = np.linalg.norm(input1.numpy()) np_out_2 = np.linalg.norm(input1.numpy(), axis=(0, 1)) np_out_3 = np.linalg.norm(input2.numpy(), axis=(0, 2)) test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out_1, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out_2, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_3.numpy(), np_out_3, 1e-05, 1e-05)) def _test_fro_order_norm_backward(test_case, device): input = flow.tensor( np.random.randn(5, 4), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.linalg.norm(input) of_out.backward() np_out_grad = _np_matrix_norm_backward(input.numpy()) test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05)) def _test_1d_inf_order_norm_backward(test_case, device): for ord in [float("inf"), -float("inf")]: input = flow.tensor( np.random.randn(5), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.linalg.norm(input, ord=ord) of_out.backward() np_out_grad = _np_vector_norm_backward(input.numpy(), ord=ord) test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05)) def _test_2d_inf_order_norm_backward(test_case, 
device): for ord in [float("inf"), -float("inf")]: input = flow.tensor( np.random.randn(5, 4), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.linalg.norm(input, ord=ord) of_out.backward() np_out_grad = _np_matrix_norm_backward(input.numpy(), ord=ord) test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05)) def _test_1d_digits_order_norm_backward(test_case, device): for ord in [1, -1, 2, -2, 5]: input = flow.tensor( np.random.randn(5), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.linalg.norm(input, ord=ord) of_out.backward() np_out_grad = _np_vector_norm_backward(input.numpy(), ord=ord) test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05)) def _test_2d_digits_order_norm_backward(test_case, device): for ord in [1, -1]: input = flow.tensor( np.random.randn(4, 5), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.linalg.norm(input, ord=ord) of_out.backward() np_out_grad = _np_matrix_norm_backward(input.numpy(), ord=ord) test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05)) def _test_linalg_norm_shape_not_match(test_case, device): x = flow.randn(1, 3, 1, 5, 2) x = x.to(device) y = flow.linalg.norm(x, keepdim=True) test_case.assertEqual(y.size(), (1, 1, 1, 1, 1)) @flow.unittest.skip_unless_1n1d() class TestNormModule(flow.unittest.TestCase): def test_norm(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_norm_1d, _test_norm_2d, _test_norm_Nd, _test_fro_order_norm_backward, _test_1d_inf_order_norm_backward, _test_2d_inf_order_norm_backward, _test_1d_digits_order_norm_backward, _test_2d_digits_order_norm_backward, _test_linalg_norm_shape_not_match, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_no_dim_no_ord_norm_with_random_data(test_case): device = random_device() input = random_tensor().to(device) keepdim = random_bool() m = torch.linalg.norm(input, keepdim=keepdim) n = torch.norm(input, keepdim=keepdim) return m, n @autotest(n=5) def test_one_dim_norm_with_random_data(test_case): device = random_device() input = random_tensor(ndim=4).to(device) dim = random(low=0, high=4).to(int) k = random().to(float) ord = oneof(float("inf"), float("-inf"), k, None) keepdim = random_bool() m = torch.linalg.norm(input, ord, dim, keepdim) n = torch.norm(input, ord, dim, keepdim) return m, n @autotest(n=5) def test_no_dim_one_shape_norm_with_random_data(test_case): device = random_device() input = random_tensor(ndim=1).to(device) k = random().to(float) ord = oneof(float("inf"), float("-inf"), k) keepdim = random_bool() m = torch.linalg.norm(input, ord=ord, keepdim=keepdim) n = torch.norm(input, p=ord, keepdim=keepdim) return m, n @autotest(n=5) def test_no_dim_two_shape_norm_with_random_data(test_case): device = random_device() input = random_tensor(ndim=2).to(device) ord = oneof(float("inf"), float("-inf"), "fro", 1, -1) keepdim = random().to(bool) m = torch.linalg.norm(input, ord=ord, keepdim=keepdim) return m @autotest(n=5) def test_tuple_dim_norm_with_random_data(test_case): device = random_device() input = random_tensor(ndim=2).to(device) dim = oneof((-2, -1), (0, 1), (-1, 0)) ord = oneof(float("inf"), float("-inf"), "fro", 1, -1, None) keepdim = random().to(bool) m = torch.linalg.norm(input, ord=ord, dim=dim, keepdim=keepdim) return m @autotest(n=5) def test_vector_norm_only_zero_with_random_data(test_case): device = random_device() 
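        # ord=0 vector norm counts the non-zero entries; autotest compares the
        # oneflow and torch results of linalg.vector_norm below.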
input = random_tensor(ndim=2).to(device) dim = oneof((-2, -1), (0, 1), (-1, 0)) keepdim = random().to(bool) m = torch.linalg.vector_norm(input, ord=0, dim=dim, keepdim=keepdim) return m @autotest(n=5) def test_ord_random_data(test_case): device = random_device() ndim = random(1, 3).to(int) input = random_tensor(ndim).to(device) p1 = random(-5, -1).to(int).value() p2 = random(2, 6).to(int).value() m = input.norm(p1) n = input.norm(p2) return m, n if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_normalize.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type from oneflow.test_utils.automated_test_util import * def _test_functional_normalize_double_dtype(test_case, device, dtype): dtype = type_name_to_flow_type[dtype] x = flow.ones(2, 2, dtype=dtype).to(device) y = flow.nn.functional.normalize(x, p=2, dim=0) test_case.assertEqual((2, 2), y.shape) out = np.array( [ [0.7071067690849304, 0.7071067690849304], [0.7071067690849304, 0.7071067690849304], ] ) test_case.assertTrue(np.allclose(y.numpy().tolist(), out, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestFunctionalNormalize(flow.unittest.TestCase): def test_functional_normalize_naive(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [_test_functional_normalize_double_dtype] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = ["float32", "double"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_functional_normalize(test_case): device = random_device() ndim = random(low=2) shape = list(random_tensor(ndim=ndim).oneflow.shape) dim = random(low=0, high=ndim).to(int).value() shape[dim] = random(low=2, high=8).to(int).value() shape = tuple(shape) x = random_tensor(len(shape), *shape).to(device) y = torch.nn.functional.normalize(x, oneof(2, 3, 4), dim, 1e-12) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_ofrecord_reader.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os import numpy as np import oneflow as flow import oneflow.unittest class OFRecordDataLoader(flow.nn.Module): def __init__(self, batch_size, device=None, placement=None, sbp=None): super().__init__() # don't shuffle, for comparing shuffle = False self.ofrecord_reader = flow.nn.OFRecordReader( flow.unittest.dataset_dir("imagenet_227/train/32"), batch_size=batch_size, data_part_num=2, random_shuffle=shuffle, shuffle_after_epoch=shuffle, device=device, placement=placement, sbp=sbp, ) self.record_label_decoder = flow.nn.OFRecordRawDecoder( "class/label", shape=(), dtype=flow.int32 ) self.record_image_decoder = flow.nn.OFRecordImageDecoder( "encoded", color_space="RGB" ) self.resize = flow.nn.image.Resize(target_size=[227, 227], dtype=flow.float32) def forward(self): record = self.ofrecord_reader() label = self.record_label_decoder(record) image_raw_buffer = self.record_image_decoder(record) image = self.resize(image_raw_buffer)[0] return image, label class DataLoaderGraph(flow.nn.Graph): def __init__(self, loader): super().__init__() self.loader_ = loader def build(self): return self.loader_() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @unittest.skipUnless(os.path.exists(flow.unittest.dataset_dir("imagenet_227")), "") @flow.unittest.skip_unless_1n2d() class DistributedOFRecordReaderTestCase(oneflow.unittest.TestCase): def test(test_case): rank = flow.env.get_rank() # print(f"DistributedOFRecordReaderTestCase.test on rank {rank} {os.getpid()}") eager_ofrecord_loader = OFRecordDataLoader( batch_size=2, device=flow.device("cpu", rank) ) lazy_global_loader = OFRecordDataLoader( batch_size=4, placement=flow.placement("cpu", ranks=[0, 1]), sbp=[flow.sbp.split(0)], ) loader_graph = DataLoaderGraph(lazy_global_loader) iteration = 2 for i in range(iteration): image, label = eager_ofrecord_loader() # print( # f"rank {rank} image: {image.shape}, {image.dtype}, device: {image.device}" # f"\n{image.numpy().mean()}" # ) # print( # f"rank {rank} label: {label.shape}, {label.dtype}, device: {label.device}" # f"\n{label.numpy()}" # ) g_image, g_label = loader_graph() # print( # f"rank {rank} graph output image: {g_image.shape}, {g_image.dtype}, placement: {g_image.placement}" # f"\n{g_image.to_local().numpy().mean()}" # ) # print( # f"rank {rank} graph output label: {g_label.shape}, {g_label.dtype}, placement: {g_image.placement}" # f"\n{g_label.to_local().numpy()}" # ) # print(f"{'-' * 20} rank {rank} iter {i} complete {'-' * 20}") test_case.assertTrue(np.allclose(image.numpy(), g_image.to_local().numpy())) test_case.assertTrue(np.allclose(label.numpy(), g_label.to_local().numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_one_embedding_adagrad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import tempfile import os # dynamic memory allocation can't be tested in unittest os.environ["ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"] = "0" import numpy as np from oneflow.test_utils.test_util import GenArgDict from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_adagrad( test_case, weight_decay, lr_decay, scale, learning_rate, train_iters, ): num_rows = 500 embedding_size = 128 model_shape = (num_rows, embedding_size) line_size = embedding_size * 2 num_valid_seq = np.random.randint(1, num_rows, (train_iters)) skip_if_seq = [np.random.randint(2) for i in range(train_iters)] random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32)) init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32) down_scale_by = 10 epsilon = 1e-5 class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build( self, ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, train_step, ): # add id shuffle to set num_unique in op, and use it in update (_, _, num_valid, _, _, _,) = flow._C.one_embedding_id_shuffle( ids, table_ids=None, num_tables=1, embedding_name="" ) return flow._C.one_embedding_adagrad_update( num_valid, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, train_step, 0, 0.0, scale, weight_decay, lr_decay, epsilon, line_size, embedding_size, "", ) graph = TestGraph() def adagrad_by_oneflow(): unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to( "cuda" ) lr_tensor = flow.tensor( np.array(learning_rate).reshape(1,).astype(np.float32) ).to("cuda") down_scale_by_tensor = flow.tensor( np.array(down_scale_by).reshape(1,).astype(np.float32) ).to("cuda") def train_one_iter(ids, unique_embeddings, embedding_grad, skip_if, train_step): return graph( ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, train_step, ) for i in range(1, train_iters): np_ids = np.zeros(num_rows) np_ids[0 : num_valid_seq[i]] = np.arange(num_valid_seq[i]) # add ids of num_valid unique to use id_shuffle out_put num_unique as grad input ids = flow.tensor(np_ids.astype(np.int32)).to("cuda") grad_tensor = flow.tensor(random_grad_seq[i]).to("cuda") skip_if_tensor = flow.tensor( np.array(skip_if_seq[i]).reshape(1,).astype(np.int64) ).to("cuda") step_tensor = flow.tensor(np.array(i).reshape(1,).astype(np.int64)).to( "cuda" ) updated_tensor = train_one_iter( ids, unique_embeddings_tensor, grad_tensor, skip_if_tensor, step_tensor, ) unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[ 0 : num_valid_seq[i] ] return unique_embeddings_tensor def adagrad_by_numpy(): x = init_value[:, 0:embedding_size] st = init_value[:, embedding_size:] def train_one_iter(iter, num_valid, grad, model, state): grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by) lr = learning_rate / (1 + iter * lr_decay) state[0:num_valid] = ( state[0:num_valid] + grad[0:num_valid] * grad[0:num_valid] ) model[0:num_valid] = ( model[0:num_valid] - lr / (np.sqrt(state[0:num_valid]) + epsilon) * grad[0:num_valid] - lr * weight_decay * model[0:num_valid] ) return (model, state) for i in range(1, train_iters): if skip_if_seq[i] > 0: pass else: (x, st) = train_one_iter( i, int(num_valid_seq[i]), random_grad_seq[i], x, st ) return x, st oneflow_res = adagrad_by_oneflow().numpy() of_model = oneflow_res[:, 
0:embedding_size] of_sum = oneflow_res[:, embedding_size:] np_model, np_sum = adagrad_by_numpy() test_case.assertTrue( np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001) ) test_case.assertTrue( np.allclose(of_sum.flatten(), np_sum.flatten(), rtol=0.001, atol=0.001) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestOptimizers(flow.unittest.TestCase): def test_one_embedding_adagrad(test_case): arg_dict = OrderedDict() arg_dict["weight_decay"] = [0, 0.1] arg_dict["lr_decay"] = [0, 0.1] arg_dict["scale"] = [1, 0.1] arg_dict["learning_rate"] = [0.3, 1.5] arg_dict["train_iters"] = [10] for arg in GenArgDict(arg_dict): compare_with_numpy_adagrad(test_case, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_one_embedding_adam.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import tempfile import os # dynamic memory allocation can't be tested in unittest os.environ["ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"] = "0" import numpy as np from oneflow.test_utils.test_util import GenArgDict from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_adam( test_case, weight_decay, scale, learning_rate, train_iters, do_bias_correction, beta1, beta2, use_optional_tensor, ): num_rows = 500 embedding_size = 128 model_shape = (num_rows, embedding_size) line_size = embedding_size * 3 num_valid_seq = np.random.randint(1, num_rows, (train_iters)) skip_if_seq = [np.random.randint(2) for i in range(train_iters)] random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32)) init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32) down_scale_by = 10 """ In OneFlow's optimizer, learning_rate is passed by attr in eager mode, and passed by tensor in lazy mode. in this test, if use_optional_tensor is True, we also pass lr_tensor/down_scale_by_tensor/skip_if tensor for unittest. if use_optional_tensor is False, we only pass lr by attr, and not have down_scale_by_tensor/skip_if, so mul down_scale_by to scale and skip skip_if's test. 
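    Concretely: with use_optional_tensor=True the graph below receives real
    lr_tensor / down_scale_by_tensor / skip_if tensors; with False those
    arguments are None and scale_val = scale / down_scale_by absorbs the
    down-scaling instead.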
""" bias_correction1_val = 1.0 bias_correction2_val = 1.0 if use_optional_tensor: scale_val = scale else: # if pass as attr instead of tensor, mul down_scale_by to scale_value scale_val = scale / down_scale_by epsilon = 1e-5 class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build( self, ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, bias_correction1, bias_correction2, ): # add id shuffle to set num_unique in op, and use it in update (_, _, num_valid, _, _, _,) = flow._C.one_embedding_id_shuffle( ids, table_ids=None, num_tables=1, embedding_name="" ) return flow._C.one_embedding_adam_update( num_valid, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, bias_correction1, bias_correction2, learning_rate, scale_val, weight_decay, beta1, beta2, bias_correction1_val, bias_correction2_val, epsilon, do_bias_correction, line_size, embedding_size, embedding_name="", ) graph = TestGraph() def adam_by_oneflow(): unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to( "cuda" ) if use_optional_tensor: lr_tensor = flow.tensor( np.array(learning_rate).reshape(1,).astype(np.float32) ).to("cuda") down_scale_by_tensor = flow.tensor( np.array(down_scale_by).reshape(1,).astype(np.float32) ).to("cuda") else: lr_tensor = None down_scale_by_tensor = None def train_one_iter( ids, unique_embeddings, embedding_grad, skip_if, bias_correction1, bias_correction2, ): return graph( ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, bias_correction1, bias_correction2, ) for i in range(1, train_iters): np_ids = np.zeros(num_rows) np_ids[0 : num_valid_seq[i]] = np.arange(num_valid_seq[i]) # add ids of num_valid unique to use id_shuffle out_put num_unique as grad input ids = flow.tensor(np_ids.astype(np.int32)).to("cuda") grad_tensor = flow.tensor(random_grad_seq[i]).to("cuda") if use_optional_tensor: skip_if_tensor = flow.tensor( np.array(skip_if_seq[i]).reshape(1,).astype(np.int64) ).to("cuda") else: skip_if_tensor = None if do_bias_correction and use_optional_tensor: bias_correction1 = 1.0 - np.power(beta1, i) bias_correction2 = 1.0 - np.power(beta2, i) bias_correction1_tensor = flow.tensor( np.array(bias_correction1).reshape(1,).astype(np.float32) ).to("cuda") bias_correction2_tensor = flow.tensor( np.array(bias_correction2).reshape(1,).astype(np.float32) ).to("cuda") else: bias_correction1_tensor = None bias_correction2_tensor = None updated_tensor = train_one_iter( ids, unique_embeddings_tensor, grad_tensor, skip_if_tensor, bias_correction1_tensor, bias_correction2_tensor, ) unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[ 0 : num_valid_seq[i] ] return unique_embeddings_tensor def adam_by_numpy(): x = init_value[:, 0:embedding_size] m = init_value[:, embedding_size : 2 * embedding_size] v = init_value[:, 2 * embedding_size : 3 * embedding_size] def np_train_one_iter(step, num_valid, grad, model, state_m, state_v): grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by) bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction and use_optional_tensor: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) state_m[0:num_valid] = ( beta1 * state_m[0:num_valid] + (1 - beta1) * grad[0:num_valid] ) state_v[0:num_valid] = ( beta2 * state_v[0:num_valid] + (1 - beta2) * grad[0:num_valid] * grad[0:num_valid] ) denom = np.sqrt(state_v[0:num_valid]) / np.sqrt(bias_correction2) + epsilon model[0:num_valid] = ( 
model[0:num_valid] - ((learning_rate / bias_correction1) * state_m[0:num_valid] / denom) - learning_rate * weight_decay * model[0:num_valid] ) return (model, state_m, state_v) for i in range(1, train_iters): # if step = 0, bias_correction2 is 0 if skip_if_seq[i] > 0 and use_optional_tensor: pass else: (x, m, v) = np_train_one_iter( i, int(num_valid_seq[i]), random_grad_seq[i], x, m, v ) return x, m, v oneflow_res = adam_by_oneflow().numpy() of_model = oneflow_res[:, 0:embedding_size] of_m = oneflow_res[:, embedding_size : 2 * embedding_size] of_v = oneflow_res[:, 2 * embedding_size : 3 * embedding_size] np_model, np_m, np_v = adam_by_numpy() test_case.assertTrue( np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001) ) test_case.assertTrue( np.allclose(of_m.flatten(), np_m.flatten(), rtol=0.001, atol=0.001) ) test_case.assertTrue( np.allclose(of_v.flatten(), np_v.flatten(), rtol=0.001, atol=0.001) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestOptimizers(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 16 times in past week") def test_one_embedding_adam(test_case): arg_dict = OrderedDict() arg_dict["weight_decay"] = [0, 0.1] arg_dict["scale"] = [1, 0.1] arg_dict["learning_rate"] = [1, 1.5] arg_dict["train_iters"] = [10] arg_dict["do_bias_correction"] = [True, False] arg_dict["beta1"] = [0.9, 0.8] arg_dict["beta2"] = [0.9, 0.8] arg_dict["use_optional_tensor"] = [True, False] for arg in GenArgDict(arg_dict): compare_with_numpy_adam(test_case, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_one_embedding_ftrl.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import tempfile import os # dynamic memory allocation can't be tested in unittest os.environ["ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"] = "0" import numpy as np from oneflow.test_utils.test_util import GenArgDict from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_ftrl( test_case, weight_decay, lr_power, lambda1, lambda2, beta, scale, learning_rate, train_iters, use_optional_tensor, ): num_rows = 500 embedding_size = 128 model_shape = (num_rows, embedding_size) line_size = embedding_size * 3 num_valid_seq = np.random.randint(1, num_rows, (train_iters)) skip_if_seq = [np.random.randint(2) for i in range(train_iters)] random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32)) init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32) down_scale_by = 10 """ In OneFlow's optimizer, learning_rate is passed by attr in eager mode, and passed by tensor in lazy mode. 
in this test, if use_optional_tensor is True, we also pass lr_tensor/down_scale_by_tensor/skip_if tensor for unittest. if use_optional_tensor is False, we only pass lr by attr, and not have down_scale_by_tensor/skip_if, so mul down_scale_by to scale and skip skip_if's test. """ if use_optional_tensor: scale_val = scale else: # if pass as attr instead of tensor, mul down_scale_by to scale_value scale_val = scale / down_scale_by class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build( self, ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, ): # add id shuffle to set num_unique in op, and use it in update (_, _, num_valid, _, _, _,) = flow._C.one_embedding_id_shuffle( ids, table_ids=None, num_tables=1, embedding_name="" ) return flow._C.one_embedding_ftrl_update( num_valid, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, learning_rate, scale_val, weight_decay, lr_power, lambda1, lambda2, beta, line_size, embedding_size, embedding_name="", ) graph = TestGraph() def ftrl_by_oneflow(): unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to( "cuda" ) if use_optional_tensor: lr_tensor = flow.tensor( np.array(learning_rate).reshape(1,).astype(np.float32) ).to("cuda") down_scale_by_tensor = flow.tensor( np.array(down_scale_by).reshape(1,).astype(np.float32) ).to("cuda") else: lr_tensor = None down_scale_by_tensor = None def train_one_iter(ids, unique_embeddings, embedding_grad, skip_if): return graph( ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, ) for i in range(1, train_iters): np_ids = np.zeros(num_rows) np_ids[0 : num_valid_seq[i]] = np.arange(num_valid_seq[i]) # add ids of num_valid unique to use id_shuffle out_put num_unique as grad input ids = flow.tensor(np_ids.astype(np.int32)).to("cuda") grad_tensor = flow.tensor(random_grad_seq[i]).to("cuda") if use_optional_tensor: skip_if_tensor = flow.tensor( np.array(skip_if_seq[i]).reshape(1,).astype(np.int64) ).to("cuda") else: skip_if_tensor = None updated_tensor = train_one_iter( ids, unique_embeddings_tensor, grad_tensor, skip_if_tensor, ) unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[ 0 : num_valid_seq[i] ] return unique_embeddings_tensor def ftrl_by_numpy(): x = init_value[:, 0:embedding_size] accumulate = init_value[:, embedding_size : 2 * embedding_size] z = init_value[:, 2 * embedding_size :] def train_one_iter(iter, num_valid, grad, model, accum, z): grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by) new_accum = accumulate[0:num_valid] + grad[0:num_valid] * grad[0:num_valid] sigma = ( np.power(new_accum, lr_power) - np.power(accumulate[0:num_valid], lr_power) ) / learning_rate new_z_val = z[0:num_valid] + grad[0:num_valid] - sigma * model[0:num_valid] # Here weight_decay equals to AdamW's, not equal to l2. 
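            # FTRL-proximal step as implemented below (lr_power is negative, e.g. -0.2):
            #   n_t   = n_{t-1} + g^2
            #   sigma = (n_t^lr_power - n_{t-1}^lr_power) / lr
            #   z_t   = z_{t-1} + g - sigma * w_{t-1}
            #   w_t   = 0 if |z_t| < lambda1, otherwise
            #           (sign(z_t) * lambda1 - z_t) / ((beta + n_t^lr_power) / lr + lambda2)
            #           - lr * weight_decay * w_{t-1}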
update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / ( (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2 ) - learning_rate * weight_decay * model[0:num_valid] model[0:num_valid] = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val) accumulate[0:num_valid] = new_accum z[0:num_valid] = new_z_val return (model, accumulate, z) for i in range(1, train_iters): # when use_optional_tensor is False, not pass skip_if to op if skip_if_seq[i] > 0 and use_optional_tensor: pass else: (x, accumulate, z) = train_one_iter( i, int(num_valid_seq[i]), random_grad_seq[i], x, accumulate, z ) return x, accumulate, z oneflow_res = ftrl_by_oneflow().numpy() of_model = oneflow_res[:, 0:embedding_size] of_accum = oneflow_res[:, embedding_size : 2 * embedding_size] of_z = oneflow_res[:, 2 * embedding_size :] np_model, np_accum, np_z = ftrl_by_numpy() test_case.assertTrue( np.allclose(of_model.flatten(), np_model.flatten(), rtol=1e-4, atol=1e-4) ) test_case.assertTrue( np.allclose(of_accum.flatten(), np_accum.flatten(), rtol=1e-4, atol=1e-4) ) test_case.assertTrue( np.allclose(of_z.flatten(), np_z.flatten(), rtol=1e-4, atol=1e-4) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestOptimizers(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") def test_ftrl(test_case): arg_dict = OrderedDict() arg_dict["weight_decay"] = [ 0.0 ] # TODO(zzk): Currently Only support weight_decay = 0.0. arg_dict["lr_power"] = [-0.2, -0.05] arg_dict["lambda1"] = [0.1] arg_dict["lambda2"] = [0.00] arg_dict["beta"] = [1.0] arg_dict["scale"] = [1, 0.1] arg_dict["learning_rate"] = [0.3, 1.5] arg_dict["train_iters"] = [10] arg_dict["use_optional_tensor"] = [True, False] for arg in GenArgDict(arg_dict): compare_with_numpy_ftrl(test_case, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_one_embedding_sgd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import tempfile import os # dynamic memory allocation can't be tested in unittest os.environ["ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"] = "0" import numpy as np from oneflow.test_utils.test_util import GenArgDict from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_sgd( test_case, momentum, weight_decay, scale, learning_rate, train_iters, use_optional_tensor, ): # if use_optional_tensor, pass lr as tensor to sgd_update, else pass as attr. 
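    # Reference step checked by sgd_by_numpy further down:
    #   momentum > 0 : v_t = momentum * v_{t-1} + g_t,
    #                  w_t = w_{t-1} - lr * v_t - lr * weight_decay * w_{t-1}
    #   momentum == 0: w_t = w_{t-1} - lr * g_t - lr * weight_decay * w_{t-1}
    # line_size is doubled when momentum > 0 so the momentum buffer is stored
    # alongside the embedding rows.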
num_rows = 500 embedding_size = 128 model_shape = (num_rows, embedding_size) line_size = embedding_size * 2 if momentum > 0 else embedding_size num_valid_seq = np.random.randint(1, num_rows, (train_iters)) skip_if_seq = [np.random.randint(2) for i in range(train_iters)] random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32)) init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32) """ In OneFlow's optimizer, learning_rate is passed by attr in eager mode, and passed by tensor in lazy mode. in this test, if use_optional_tensor is True, we also pass lr_tensor/down_scale_by_tensor/skip_if tensor for unittest. if use_optional_tensor is False, we only pass lr by attr, and not have down_scale_by_tensor/skip_if, so mul down_scale_by to scale and skip skip_if's test. """ down_scale_by = 10 if use_optional_tensor: scale_val = scale else: # if pass as attr instead of tensor, mul down_scale_by to scale_value scale_val = scale / down_scale_by class TestGraph(flow.nn.Graph): def __init__(self): super().__init__() def build( self, ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, ): # add id shuffle to set num_unique in op, and use it in update (_, _, num_valid, _, _, _,) = flow._C.one_embedding_id_shuffle( ids, table_ids=None, num_tables=1, embedding_name="" ) return flow._C.one_embedding_sgd_update( num_valid, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, learning_rate, scale_val, weight_decay, momentum, line_size, embedding_size, embedding_name="", ) graph = TestGraph() def sgd_by_oneflow(): unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to( "cuda" ) if use_optional_tensor: lr_tensor = flow.tensor( np.array(learning_rate).reshape(1,).astype(np.float32) ).to("cuda") down_scale_by_tensor = flow.tensor( np.array((down_scale_by,)).astype(np.float32) ).to("cuda") else: # pass by attr lr_tensor = None down_scale_by_tensor = None def train_one_iter( ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, ): return graph( ids, unique_embeddings, embedding_grad, lr_tensor, down_scale_by_tensor, skip_if, ) for i in range(train_iters): np_ids = np.zeros(num_rows) np_ids[0 : num_valid_seq[i]] = np.arange(num_valid_seq[i]) # add ids of num_valid unique to use id_shuffle out_put num_unique as grad input ids = flow.tensor(np_ids.astype(np.int32)).to("cuda") grad_tensor = flow.tensor(random_grad_seq[i]).to("cuda") if use_optional_tensor: skip_if_tensor = flow.tensor( np.array(skip_if_seq[i]).reshape(1,).astype(np.int64) ).to("cuda") else: skip_if_tensor = None updated_tensor = train_one_iter( ids, unique_embeddings_tensor, grad_tensor, lr_tensor, down_scale_by_tensor, skip_if_tensor, ) unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[ 0 : num_valid_seq[i] ] return unique_embeddings_tensor def sgd_by_numpy(): x = init_value[:, 0:embedding_size] vt = init_value[:, embedding_size:] def train_one_iter(num_valid, grad, model, state): grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by) next_state = ( (momentum * state[0:num_valid] + grad[0:num_valid]) if momentum > 0 else 0 ) if momentum > 0: state[0:num_valid] = next_state model[0:num_valid] = ( model[0:num_valid] - learning_rate * next_state - learning_rate * weight_decay * model[0:num_valid] ) else: state[0:num_valid] = 0 model[0:num_valid] = ( model[0:num_valid] - learning_rate * grad[0:num_valid] - learning_rate * weight_decay * 
model[0:num_valid] ) return (model, state) for i in range(train_iters): if skip_if_seq[i] > 0 and use_optional_tensor: pass else: (x, vt) = train_one_iter( int(num_valid_seq[i]), random_grad_seq[i], x, vt ) return x, vt oneflow_res = sgd_by_oneflow().numpy() of_model = oneflow_res[:, 0:embedding_size] of_momentum = oneflow_res[:, embedding_size:] np_model, np_momentum = sgd_by_numpy() test_case.assertTrue( np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001) ) if momentum > 0: test_case.assertTrue( np.allclose( of_momentum.flatten(), np_momentum.flatten(), rtol=0.001, atol=0.001 ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestOptimizers(flow.unittest.TestCase): def test_one_embedding_sgd(test_case): arg_dict = OrderedDict() arg_dict["momentum"] = [0, 0.9] arg_dict["weight_decay"] = [0, 0.1] arg_dict["scale"] = [1, 0.1] arg_dict["learning_rate"] = [1, 0.9] arg_dict["train_iters"] = [10] arg_dict["use_optional_tensor"] = [True, False] for arg in GenArgDict(arg_dict): compare_with_numpy_sgd(test_case, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_one_hot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow def _test_one_hot(test_case, device, num_classes, size, on_value, off_value): x = np.random.randint(9, size=size) input = flow.tensor(x, device=flow.device(device), dtype=flow.int64) output = flow.nn.functional.one_hot(input, num_classes, on_value, off_value) if num_classes == -1: np_outtmp = np.eye(np.max(x) + 1)[x] else: np_outtmp = np.eye(num_classes)[x] np_out = np.where(np_outtmp == 1, on_value, off_value) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) @flow.unittest.skip_unless_1n1d() class TestOnehot(flow.unittest.TestCase): def test_onehot(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_one_hot, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["num_classes"] = [-1, 10, 11] arg_dict["size"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] arg_dict["on_value"] = [-1, -0.9, 0, 0.9, 1] arg_dict["off_value"] = [-2, -0.5, 0, 0.5, 2] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(auto_backward=False) def test_one_hot_scalar(test_case): x = torch.tensor(2) y = torch.nn.functional.one_hot(x, num_classes=5) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_ones_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_ones_like_float(test_case, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) y = flow.ones_like(x) test_case.assertTrue(y.dtype is flow.float32) test_case.assertTrue(y.shape == x.shape) test_case.assertTrue(y.device == x.device) y_numpy = np.ones_like(x.numpy()) test_case.assertTrue(np.array_equal(y.numpy(), y_numpy)) def _test_ones_like_int(test_case, shape, device): x = flow.tensor(np.random.randn(*shape), dtype=flow.int, device=flow.device(device)) y = flow.ones_like(x) test_case.assertTrue(y.dtype is flow.int) test_case.assertTrue(y.shape == x.shape) test_case.assertTrue(y.device == x.device) y_numpy = np.ones_like(x.numpy()) test_case.assertTrue(np.array_equal(y.numpy(), y_numpy)) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_ones_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_ones_like_float, _test_ones_like_int] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_adadelta.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import tempfile import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_adadelta( test_case, device, x_shape, learning_rate, train_iters, rho, eps, maximize, weight_decay, reload_state_step, save_load_by_pickle, contiguous_params, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = Parameter(flow.Tensor(init_value, device=flow.device(device))) adadelta = flow.optim.Adadelta( [{"params": [x], "lr": learning_rate, "weight_decay": weight_decay,}], rho=rho, eps=eps, maximize=maximize, contiguous_params=contiguous_params, ) def train_one_iter(grad): grad_tensor = flow.tensor( grad, requires_grad=False, device=flow.device(device) ) loss = flow.sum(x * grad_tensor) loss.backward() adadelta.step() adadelta.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = adadelta.state_dict() adadelta = flow.optim.Adadelta([x], contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) adadelta.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value square_avgs = np.zeros_like(x) acc_deltas = np.zeros_like(x) def train_one_iter(grad): grad = grad if not maximize else -grad grad = grad + weight_decay * x new_square_avgs = square_avgs * rho + (1.0 - rho) * grad * grad std = np.sqrt(new_square_avgs + eps) delta = np.sqrt(acc_deltas + eps) / std * grad new_acc_deltas = acc_deltas * rho + delta * delta * (1 - rho) param = x - learning_rate * delta return (param, new_square_avgs, new_acc_deltas) for i in range(1, train_iters + 1): (x, square_avgs, acc_deltas) = train_one_iter(random_grad_seq[i - 1]) return x oneflow_res = train_by_oneflow().numpy() numpy_res = train_by_numpy() test_case.assertTrue( np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4) ) def compare_with_numpy_adadelta_clip_grad( test_case, device, x_shape, learning_rate, train_iters, rho, eps, maximize, weight_decay, clip_grad_max_norm, clip_grad_norm_type, reload_state_step, save_load_by_pickle, contiguous_params, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = Parameter(flow.Tensor(init_value, device=flow.device(device))) adadelta = flow.optim.Adadelta( [ { "params": [x], "lr": learning_rate, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], rho=rho, eps=eps, maximize=maximize, contiguous_params=contiguous_params, ) def train_one_iter(grad): grad_tensor = flow.tensor( grad, requires_grad=False, device=flow.device(device) ) loss = flow.sum(x * grad_tensor) loss.backward() adadelta.clip_grad() adadelta.step() adadelta.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = adadelta.state_dict() adadelta = flow.optim.Adadelta([x], contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) 
state_dict = flow.load(f.name) adadelta.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value square_avgs = np.zeros_like(x) acc_deltas = np.zeros_like(x) def train_one_iter(grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad if not maximize else -grad grad = grad + weight_decay * x new_square_avgs = square_avgs * rho + (1.0 - rho) * grad * grad std = np.sqrt(new_square_avgs + eps) delta = np.sqrt(acc_deltas + eps) / std * grad new_acc_deltas = acc_deltas * rho + delta * delta * (1 - rho) param = x - learning_rate * delta return (param, new_square_avgs, new_acc_deltas) for i in range(1, train_iters + 1): (x, square_avgs, acc_deltas) = train_one_iter(random_grad_seq[i - 1]) return x oneflow_res = train_by_oneflow().numpy() numpy_res = train_by_numpy() test_case.assertTrue( np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4) ) @flow.unittest.skip_unless_1n1d() class TestAdadelta(flow.unittest.TestCase): def test_adadelta(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["rho"] = [0.9, 0.6] arg_dict["eps"] = [1e-6, 1e-4] arg_dict["maximize"] = [False] arg_dict["weight_decay"] = [0.0, 0.1] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [False, True] arg_dict["contiguous_params"] = [False, True] for arg in GenArgList(arg_dict): compare_with_numpy_adadelta(test_case, *arg) def test_adadelta_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1e-3] arg_dict["train_iters"] = [10] arg_dict["rho"] = [0.9, 0.6] arg_dict["eps"] = [1e-6, 1e-4] arg_dict["maximize"] = [False] arg_dict["weight_decay"] = [0.0, 0.1] arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0] arg_dict["clip_grad_norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [False, True] arg_dict["contiguous_params"] = [False, True] for arg in GenArgList(arg_dict): compare_with_numpy_adadelta_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_adagrad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import tempfile import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_adagrad( test_case, device, x_shape, learning_rate, train_iters, lr_decay, weight_decay, initial_accumulator_value, eps, reload_state_step, save_load_by_pickle, contiguous_params, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = Parameter(flow.Tensor(init_value, device=flow.device(device))) adagrad = flow.optim.Adagrad( [ { "params": [x], "lr": learning_rate, "eps": eps, "weight_decay": weight_decay, } ], lr_decay=lr_decay, initial_accumulator_value=initial_accumulator_value, contiguous_params=contiguous_params, ) def train_one_iter(grad): grad_tensor = flow.tensor( grad, requires_grad=False, device=flow.device(device) ) loss = flow.sum(x * grad_tensor) loss.backward() adagrad.step() adagrad.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = adagrad.state_dict() adagrad = flow.optim.Adagrad([x], contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) adagrad.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value st = np.ones_like(x) * initial_accumulator_value def train_one_iter(iter, grad): grad = grad + weight_decay * x lr = learning_rate / (1 + (iter - 1) * lr_decay) s = st + grad * grad param = x - lr / (np.sqrt(s) + eps) * grad return (param, s) for i in range(1, train_iters + 1): (x, st) = train_one_iter(i, random_grad_seq[i - 1]) return x oneflow_res = train_by_oneflow().numpy() numpy_res = train_by_numpy() test_case.assertTrue( np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-3, atol=1e-3) ) def compare_with_numpy_adagrad_clip_grad( test_case, device, x_shape, learning_rate, train_iters, lr_decay, weight_decay, initial_accumulator_value, eps, clip_grad_max_norm, clip_grad_norm_type, reload_state_step, save_load_by_pickle, contiguous_params, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = Parameter(flow.Tensor(init_value, device=flow.device(device))) adagrad = flow.optim.Adagrad( [ { "params": [x], "lr": learning_rate, "eps": eps, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], lr_decay=lr_decay, initial_accumulator_value=initial_accumulator_value, contiguous_params=contiguous_params, ) def train_one_iter(grad): grad_tensor = flow.tensor( grad, requires_grad=False, device=flow.device(device) ) loss = flow.sum(x * grad_tensor) loss.backward() adagrad.clip_grad() adagrad.step() adagrad.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = adagrad.state_dict() adagrad = flow.optim.Adagrad([x], contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) adagrad.load_state_dict(state_dict) return x def train_by_numpy(): x = 
init_value st = np.ones_like(x) * initial_accumulator_value def train_one_iter(iter, grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + weight_decay * x lr = learning_rate / (1 + (iter - 1) * lr_decay) s = st + grad * grad param = x - lr / (np.sqrt(s) + eps) * grad return (param, s) for i in range(1, train_iters + 1): (x, st) = train_one_iter(i, random_grad_seq[i - 1]) return x oneflow_res = train_by_oneflow().numpy() numpy_res = train_by_numpy() test_case.assertTrue( np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-3, atol=1e-3) ) @flow.unittest.skip_unless_1n1d() class TestAdagrad(flow.unittest.TestCase): def test_adagrad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["lr_decay"] = [0.9, 0.75] arg_dict["weight_decay"] = [0.0, 0.1] arg_dict["initial_accumulator_value"] = [1.0, 2.1] arg_dict["eps"] = [1e-08, 1e-07] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [False, True] arg_dict["contiguous_params"] = [False, True] for arg in GenArgList(arg_dict): compare_with_numpy_adagrad(test_case, *arg) def test_adagrad_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["lr_decay"] = [0.9, 0.75] arg_dict["weight_decay"] = [0.0, 0.1] arg_dict["initial_accumulator_value"] = [2.1] arg_dict["eps"] = [1e-07] arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0] arg_dict["clip_grad_norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [False, True] arg_dict["contiguous_params"] = [False, True] for arg in GenArgList(arg_dict): compare_with_numpy_adagrad_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_adam.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import tempfile import unittest from collections import OrderedDict import random as random_util import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import random_device, random_bool from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_adam( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, reload_state_step, save_load_by_pickle, contiguous_params, fused, tensor_num, ): random_grad_seq = [] init_value_seq = [] for i in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) for _ in range(train_iters): random_grad_seq_per_iter = [] for i in range(tensor_num): random_grad_seq_per_iter.append( np.random.uniform(size=x_shape).astype(np.float32) ) random_grad_seq.append(random_grad_seq_per_iter) def train_by_oneflow(): x = [] for i in range(tensor_num): x.append( Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device))) ) adam = flow.optim.Adam( [ { "params": x, "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, } ], do_bias_correction=do_bias_correction, amsgrad=amsgrad, contiguous_params=contiguous_params, fused=fused, ) def train_one_iter(grad): loss = 0.0 for i in range(tensor_num): grad_tensor = flow.tensor( grad[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss += flow.sum(x[i] * grad_tensor) loss.backward() adam.step() adam.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = adam.state_dict() adam = flow.optim.Adam( [{"params": x,}], contiguous_params=contiguous_params ) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) adam.load_state_dict(state_dict) return x def train_by_numpy(tensor_idx): x = init_value_seq[tensor_idx] vt = np.zeros_like(x) st = np.zeros_like(x) max_st = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] def np_train_one_iter(step, grad): grad = grad + weight_decay * x bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) v = beta1 * vt + (1 - beta1) * grad s = beta2 * st + (1 - beta2) * grad * grad max_s = np.zeros_like(x) if amsgrad: max_s = np.maximum(s, max_st) denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps else: denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps param = x - ((learning_rate / bias_correction1) * v / denom) return (param, v, s, max_s) for i in range(1, train_iters + 1): (x, vt, st, max_st) = np_train_one_iter( i, random_grad_seq[i - 1][tensor_idx] ) return x oneflow_res = train_by_oneflow() numpy_res = [] for i in range(tensor_num): numpy_res.append(train_by_numpy(i)) for i in range(tensor_num): test_case.assertTrue( np.allclose( oneflow_res[i].numpy().flatten(), numpy_res[i].flatten(), rtol=0.001, atol=0.0001, ) ) def compare_with_numpy_adam_clip_grad( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, clip_grad_max_norm, clip_grad_norm_type, reload_state_step, save_load_by_pickle, contiguous_params, fused, tensor_num, ): random_grad_seq = [] init_value_seq = [] for i in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) for _ in range(train_iters): random_grad_seq_per_iter 
= [] for i in range(tensor_num): random_grad_seq_per_iter.append( np.random.uniform(size=x_shape).astype(np.float32) ) random_grad_seq.append(random_grad_seq_per_iter) def train_by_oneflow(): x = [] for i in range(tensor_num): x.append( Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device))) ) adam = flow.optim.Adam( [ { "params": x, "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], do_bias_correction=do_bias_correction, amsgrad=amsgrad, contiguous_params=contiguous_params, fused=fused, ) def train_one_iter(grad): loss = 0.0 for i in range(tensor_num): grad_tensor = flow.tensor( grad[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss += flow.sum(x[i] * grad_tensor) loss.backward() adam.clip_grad() adam.step() adam.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = adam.state_dict() adam = flow.optim.Adam( [{"params": x,}], contiguous_params=contiguous_params ) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) adam.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value_seq vt = np.zeros_like(x) st = np.zeros_like(x) max_st = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] def train_one_iter(step, grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) for i in range(tensor_num): grad[i] = grad[i] + weight_decay * x[i] bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) vt[i] = beta1 * vt[i] + (1 - beta1) * grad[i] st[i] = beta2 * st[i] + (1 - beta2) * grad[i] * grad[i] if amsgrad: max_st[i] = np.maximum(st[i], max_st[i]) denom = np.sqrt(max_st[i]) / np.sqrt(bias_correction2) + eps else: denom = np.sqrt(st[i]) / np.sqrt(bias_correction2) + eps x[i] = x[i] - ((learning_rate / bias_correction1) * vt[i] / denom) for i in range(1, train_iters + 1): train_one_iter(i, random_grad_seq[i - 1]) return x oneflow_res = train_by_oneflow() numpy_res = train_by_numpy() for i in range(tensor_num): test_case.assertTrue( np.allclose( oneflow_res[i].numpy().flatten(), numpy_res[i].flatten(), rtol=0.0001, atol=0.0001, ) ) @flow.unittest.skip_unless_1n1d() class TestAdam(flow.unittest.TestCase): def test_adam(test_case): arg_dict = OrderedDict() arg_dict["device"] = [random_device().value()] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [0.9, 0.000] arg_dict["eps"] = [1e-08] arg_dict["do_bias_correction"] = [random_bool().value()] arg_dict["amsgrad"] = [random_bool().value()] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [random_bool().value()] arg_dict["contiguous_params"] = [random_bool().value()] arg_dict["fused"] = [random_bool().value()] arg_dict["tensor_num"] = [1, 4] for arg in GenArgList(arg_dict): compare_with_numpy_adam(test_case, *arg) def test_adam_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = [random_device().value()] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [0.1, 0.000] arg_dict["eps"] = [1e-08] arg_dict["do_bias_correction"] = 
[random_bool().value()] arg_dict["amsgrad"] = [random_bool().value()] arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0] arg_dict["clip_grad_norm_type"] = random_util.sample( ["inf", "-inf", 0.0, 1.0, 2.0, 3.5], k=3 ) arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [random_bool().value()] arg_dict["contiguous_params"] = [random_bool().value()] arg_dict["fused"] = [random_bool().value()] arg_dict["tensor_num"] = [1, 4] for arg in GenArgList(arg_dict): compare_with_numpy_adam_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_adamw.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import tempfile import unittest from collections import OrderedDict import random as random_util import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import random_bool, random_device from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_adamw( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, reload_state_step, save_load_by_pickle, contiguous_params, fused, tensor_num, ): random_grad_seq = [] init_value_seq = [] for i in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) for _ in range(train_iters): random_grad_seq_per_iter = [] for i in range(tensor_num): random_grad_seq_per_iter.append( np.random.uniform(size=x_shape).astype(np.float32) ) random_grad_seq.append(random_grad_seq_per_iter) def train_by_oneflow(): x = [] for i in range(tensor_num): x.append( Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device))) ) adam = flow.optim.AdamW( [ { "params": x, "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, } ], do_bias_correction=do_bias_correction, amsgrad=amsgrad, contiguous_params=contiguous_params, fused=fused, ) def train_one_iter(grad): loss = 0.0 for i in range(tensor_num): grad_tensor = flow.tensor( grad[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss += flow.sum(x[i] * grad_tensor) loss.backward() adam.step() adam.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = adam.state_dict() adam = flow.optim.AdamW(x, contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) adam.load_state_dict(state_dict) return x def train_by_numpy(tensor_idx): x = init_value_seq[tensor_idx] vt = np.zeros_like(x) st = np.zeros_like(x) max_st = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] def train_one_iter(step, grad): v = beta1 * vt + (1 - beta1) * grad s = beta2 * st + (1 - beta2) * grad * grad 
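            # The reference step below applies decoupled weight decay (AdamW):
            #   denom = sqrt(s) / sqrt(bias_correction2) + eps   (s -> max(s, max_s) when amsgrad)
            #   x    <- x - (lr / bias_correction1) * v / denom - lr * weight_decay * x
            # i.e. weight_decay multiplies the parameter directly instead of being
            # added to the gradient, which is what distinguishes AdamW from Adam here.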
bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) max_s = np.zeros_like(x) if amsgrad: max_s = np.maximum(s, max_st) denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps else: denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps lr = learning_rate / bias_correction1 / denom g = lr * v + learning_rate * weight_decay * x param = x - g return (param, v, s, max_s) for i in range(1, train_iters + 1): (x, vt, st, max_st) = train_one_iter(i, random_grad_seq[i - 1][tensor_idx]) return x oneflow_res = train_by_oneflow() numpy_res = [] for i in range(tensor_num): numpy_res.append(train_by_numpy(i)) for i in range(tensor_num): test_case.assertTrue( np.allclose( oneflow_res[i].numpy().flatten(), numpy_res[i].flatten(), rtol=0.0001, atol=0.0001, ) ) def compare_with_numpy_adamw_clip_grad( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, amsgrad, clip_grad_max_norm, clip_grad_norm_type, reload_state_step, save_load_by_pickle, contiguous_params, fused, tensor_num, ): random_grad_seq = [] init_value_seq = [] for i in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) for _ in range(train_iters): random_grad_seq_per_iter = [] for i in range(tensor_num): random_grad_seq_per_iter.append( np.random.uniform(size=x_shape).astype(np.float32) ) random_grad_seq.append(random_grad_seq_per_iter) def train_by_oneflow(): x = [] for i in range(tensor_num): x.append( Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device))) ) adam = flow.optim.AdamW( [ { "params": x, "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], do_bias_correction=do_bias_correction, amsgrad=amsgrad, contiguous_params=contiguous_params, fused=fused, ) def train_one_iter(grad): loss = 0.0 for i in range(tensor_num): grad_tensor = flow.tensor( grad[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss += flow.sum(x[i] * grad_tensor) loss.backward() adam.clip_grad() adam.step() adam.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = adam.state_dict() adam = flow.optim.AdamW(x, contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) adam.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value_seq vt = np.zeros_like(x) st = np.zeros_like(x) max_st = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] def train_one_iter(step, grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) for i in range(tensor_num): vt[i] = beta1 * vt[i] + (1 - beta1) * grad[i] st[i] = beta2 * st[i] + (1 - beta2) * grad[i] * grad[i] bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step) bias_correction2 = 1.0 - np.power(beta2, step) if amsgrad: max_st[i] = np.maximum(st[i], max_st[i]) denom = np.sqrt(max_st[i]) / np.sqrt(bias_correction2) + eps else: denom = np.sqrt(st[i]) / np.sqrt(bias_correction2) + eps lr = learning_rate / bias_correction1 / denom g = lr * vt[i] + learning_rate * weight_decay * x[i] x[i] = x[i] - g for i in range(1, train_iters + 1): train_one_iter(i, random_grad_seq[i - 1]) return x oneflow_res = 
train_by_oneflow() numpy_res = train_by_numpy() for i in range(tensor_num): test_case.assertTrue( np.allclose( oneflow_res[i].numpy().flatten(), numpy_res[i].flatten(), rtol=0.0001, atol=0.0001, ) ) @flow.unittest.skip_unless_1n1d() class TestAdamW(flow.unittest.TestCase): def test_adamw(test_case): arg_dict = OrderedDict() arg_dict["device"] = [random_device().value()] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.9, 0.999)] arg_dict["weight_decay"] = [0.01, 0.00] arg_dict["eps"] = [1e-8] arg_dict["do_bias_correction"] = [random_bool().value()] arg_dict["amsgrad"] = [random_bool().value()] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [random_bool().value()] arg_dict["contiguous_params"] = [random_bool().value()] arg_dict["fused"] = [random_bool().value()] arg_dict["tensor_num"] = [1, 4] for arg in GenArgList(arg_dict): compare_with_numpy_adamw(test_case, *arg) def test_adamw_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = [random_device().value()] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.9, 0.999)] arg_dict["weight_decay"] = [0.001, 0.0] arg_dict["eps"] = [1e-8] arg_dict["do_bias_correction"] = [random_bool().value()] arg_dict["amsgrad"] = [random_bool().value()] arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0] arg_dict["clip_grad_norm_type"] = random_util.sample( ["inf", "-inf", 0.0, 1.0, 2.0, 3.5], k=3 ) arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [random_bool().value()] arg_dict["contiguous_params"] = [random_bool().value()] arg_dict["fused"] = [random_bool().value()] arg_dict["tensor_num"] = [1, 4] for arg in GenArgList(arg_dict): compare_with_numpy_adamw_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_add_param_group.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from oneflow.test_utils.test_util import GenArgList import oneflow as flow def _test_sgd_add_param_group(test_case): w1 = flow.ones(3, 3) w1.requires_grad = True w2 = flow.ones(3, 3) w2.requires_grad = True o = flow.optim.SGD([w1]) test_case.assertTrue(o.param_groups[0]["lr"] == 0.001) test_case.assertTrue(o.param_groups[0]["momentum"] == 0.0) test_case.assertTrue(o.param_groups[0]["weight_decay"] == 0.0) test_case.assertTrue(o.param_groups[0]["nesterov"] == False) test_case.assertTrue(o.param_groups[0]["maximize"] == False) o.step() o.add_param_group({"params": w2}) test_case.assertTrue(o.param_groups[1]["lr"] == 0.001) test_case.assertTrue(o.param_groups[1]["momentum"] == 0.0) test_case.assertTrue(o.param_groups[1]["weight_decay"] == 0.0) test_case.assertTrue(o.param_groups[1]["nesterov"] == False) test_case.assertTrue(o.param_groups[1]["maximize"] == False) o.step() class TestAddParamGroup(flow.unittest.TestCase): def test_sgd_add_param_group(test_case): _test_sgd_add_param_group(test_case) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_ftrl.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import tempfile import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from optimizer_test_util import clip_grad_norm_np from oneflow.one_embedding import Ftrl import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_ftrl( test_case, device, x_shape, learning_rate, train_iters, weight_decay, lr_power, initial_accumulator_value, lambda1, lambda2, beta, reload_state_step, save_load_by_pickle, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = Parameter(flow.Tensor(init_value, device=flow.device(device))) ftrl = Ftrl( [ { "params": [x], "lr": learning_rate, "weight_decay": weight_decay, "lr_power": lr_power, "initial_accumulator_value": initial_accumulator_value, "lambda1": lambda1, "lambda2": lambda2, "beta": beta, } ] ) def train_one_iter(grad): grad_tensor = flow.tensor( grad, dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss = flow.sum(x * grad_tensor) loss.backward() ftrl.step() ftrl.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = ftrl.state_dict() ftrl = Ftrl([{"params": [x],}],) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) ftrl.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value accum = np.zeros_like(x) accum.fill(initial_accumulator_value) z_arr = np.zeros_like(x) def np_train_one_iter(grad): grad = grad + weight_decay * x new_accum = accum + grad * grad sigma = ( np.power(new_accum, lr_power) - np.power(accum, lr_power) ) / learning_rate new_z_val = z_arr + grad - sigma * x update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / ( (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2 ) param = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val) return (param, new_accum, new_z_val) for i in range(1, train_iters + 1): (x, accum, z_arr) = np_train_one_iter(random_grad_seq[i - 1]) return x oneflow_res = train_by_oneflow().numpy() numpy_res = train_by_numpy() test_case.assertTrue( np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4) ) def compare_with_numpy_ftrl_clip_grad( test_case, device, x_shape, learning_rate, train_iters, weight_decay, lr_power, initial_accumulator_value, lambda1, lambda2, beta, clip_grad_max_norm, clip_grad_norm_type, reload_state_step, save_load_by_pickle, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = Parameter(flow.Tensor(init_value, device=flow.device(device))) ftrl = Ftrl( [ { "params": [x], "lr": learning_rate, "weight_decay": weight_decay, "lr_power": lr_power, "initial_accumulator_value": initial_accumulator_value, "lambda1": lambda1, "lambda2": lambda2, "beta": beta, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ] ) def train_one_iter(grad): grad_tensor = flow.tensor( grad, dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss = flow.sum(x * grad_tensor) loss.backward() ftrl.clip_grad() ftrl.step() ftrl.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = 
ftrl.state_dict() ftrl = Ftrl([{"params": [x],}]) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) ftrl.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value accum = np.zeros_like(x) accum.fill(initial_accumulator_value) z_arr = np.zeros_like(x) def np_train_one_iter(grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + weight_decay * x new_accum = accum + grad * grad sigma = ( np.power(new_accum, lr_power) - np.power(accum, lr_power) ) / learning_rate new_z_val = z_arr + grad - sigma * x update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / ( (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2 ) param = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val) return (param, new_accum, new_z_val) for i in range(1, train_iters + 1): (x, accum, z_arr) = np_train_one_iter(random_grad_seq[i - 1]) return x oneflow_res = train_by_oneflow().numpy() numpy_res = train_by_numpy() test_case.assertTrue( np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4) ) @flow.unittest.skip_unless_1n1d() class Testftrl(flow.unittest.TestCase): def test_ftrl(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["weight_decay"] = [0.9, 0.000] arg_dict["lr_power"] = [-0.5, 0.5] arg_dict["initial_accumulator_value"] = [0.1, 0.05] arg_dict["lambda1"] = [0.01] arg_dict["lambda2"] = [0.0, 0.01] arg_dict["beta"] = [1.0] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [False, True] for arg in GenArgList(arg_dict): compare_with_numpy_ftrl(test_case, *arg) def test_ftrl_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["weight_decay"] = [0.9, 0.000] arg_dict["lr_power"] = [-0.5] arg_dict["initial_accumulator_value"] = [0.1, 0.05] arg_dict["lambda1"] = [0.01] arg_dict["lambda2"] = [0.0] arg_dict["beta"] = [1.0] arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0] arg_dict["clip_grad_norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [False, True] for arg in GenArgList(arg_dict): compare_with_numpy_ftrl_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_lamb.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import tempfile import unittest from collections import OrderedDict import numpy as np from optimizer_test_util import clip_grad_norm_np from oneflow.test_utils.test_util import GenArgList import oneflow as flow def compare_with_numpy_lamb( test_case, device, x_shape, learning_rate, train_iters, betas, weight_decay, eps, do_bias_correction, adam_w_mode, clip_grad_max_norm, clip_grad_norm_type, reload_state_step, save_load_by_pickle, contiguous_params, ): np.random.seed(1000) random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = flow.nn.Parameter(flow.Tensor(init_value, device=flow.device(device))) optim_kwargs = { "params": [x], "lr": learning_rate, "betas": betas, "eps": eps, "weight_decay": weight_decay, "adam_w_mode": adam_w_mode, "do_bias_correction": do_bias_correction, "contiguous_params": contiguous_params, } if clip_grad_max_norm != -1: optim_kwargs["clip_grad_max_norm"] = clip_grad_max_norm optim_kwargs["clip_grad_norm_type"] = clip_grad_norm_type lamb = flow.optim.LAMB([optim_kwargs]) def train_one_iter(grad): grad_tensor = flow.tensor( grad, dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss = flow.sum(x * grad_tensor) loss.backward() if clip_grad_max_norm != -1: lamb.clip_grad() lamb.step() lamb.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = lamb.state_dict() lamb = flow.optim.LAMB([optim_kwargs]) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) lamb.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value mt = np.zeros_like(x) vt = np.zeros_like(x) beta1 = betas[0] beta2 = betas[1] if adam_w_mode: l2 = 0 wd = weight_decay else: l2 = weight_decay wd = 0 def np_train_one_iter(step, grad): if clip_grad_max_norm != -1: _, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + l2 * x bias_correction1 = 1.0 bias_correction2 = 1.0 if do_bias_correction: bias_correction1 = 1.0 - np.power(beta1, step + 1) bias_correction2 = 1.0 - np.power(beta2, step + 1) m = beta1 * mt + (1 - beta1) * grad v = beta2 * vt + (1 - beta2) * grad * grad denom = np.sqrt(v) / np.sqrt(bias_correction2) + eps adam_diff = m / bias_correction1 / denom w_norm = np.linalg.norm(x, ord=2) g_norm = np.linalg.norm(adam_diff, ord=2) if w_norm > 0 and g_norm > 0: trust_ratio = w_norm / g_norm else: trust_ratio = 1.0 param = x - learning_rate * trust_ratio * (adam_diff + wd * x) return (param, m, v) for i in range(train_iters): (x, mt, vt) = np_train_one_iter(i, random_grad_seq[i]) return x of_res = train_by_oneflow().numpy() np_res = train_by_numpy() test_case.assertTrue( np.allclose(of_res.flatten(), np_res.flatten(), rtol=1e-3, atol=1e-3) ) @flow.unittest.skip_unless_1n1d() class TestLamb(flow.unittest.TestCase): def test_lamb(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY"): arg_dict["device"] = ["cpu"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [0.1, 1e-3] arg_dict["train_iters"] = [10] arg_dict["betas"] = [(0.99, 0.9)] arg_dict["weight_decay"] = [0.001, 0.1] arg_dict["eps"] = [1e-6] arg_dict["do_bias_correction"] = [True, False] arg_dict["adam_w_mode"] = [True, False] # NOTE(l1aoxingyu): max_norm = -1 means no clip grad arg_dict["clip_grad_max_norm"] = 
[-1, 0.0, 0.5, 1.0] arg_dict["clip_grad_norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5] arg_dict["reload_state_step"] = [5] arg_dict["save_load_by_pickle"] = [False, True] arg_dict["contiguous_params"] = [False, True] for arg in GenArgList(arg_dict): compare_with_numpy_lamb(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_lbfgs.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import tempfile import unittest from collections import OrderedDict import random as random_util import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import random_device, random_bool import oneflow as flow from oneflow.nn.parameter import Parameter from collections import defaultdict def _quadratic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): if bounds is not None: xmin_bound, xmax_bound = bounds else: xmin_bound, xmax_bound = (x1, x2) if x1 < x2 else (x2, x1) if x1 == 0: t_new = -(g1 * (x2 ** 2)) / (2 * (f2 - f1 - g1 * x2)) else: a = -(f1 - f2 - g1 * (x1 - x2)) / ((x1 - x2) ** 2) t_new = x1 - g1 / (2 * a) return min(xmax_bound, max(xmin_bound, t_new)) def _strong_wolfe( eval_closure, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25 ): d_norm = max(map(abs, d)) g = np.copy(g) f_new, g_new = eval_closure(x, t, d) ls_func_evals = 1 gtd_new = g_new.dot(d) t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd done = False ls_iter = 0 while ls_iter < max_ls: if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new > f_prev): search_area = [t_prev, t] search_area_f = [f_prev, f_new] search_area_g = [g_prev, np.copy(g_new)] search_area_gtd = [gtd_prev, gtd_new] break if abs(gtd_new) <= -c2 * gtd: search_area = [t] search_area_f = [f_new] search_area_g = [g_new] done = True break if gtd_new >= 0: search_area = [t_prev, t] search_area_f = [f_prev, f_new] search_area_g = [g_prev, np.copy(g_new)] search_area_gtd = [gtd_prev, gtd_new] min_step = t + 0.01 * (t - t_prev) max_step = t * 10 tmp = t t = _quadratic_interpolate( t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step) ) t_prev = tmp f_prev = f_new g_prev = np.copy(g_new) gtd_prev = gtd_new f_new, g_new = eval_closure(x, t, d) ls_func_evals += 1 gtd_new = g_new.dot(d) ls_iter += 1 if ls_iter == max_ls: search_area = [0, t] search_area_f = [f, f_new] search_area_g = [g, g_new] # zoom low_pos, high_pos = (0, 1) if search_area_f[0] <= search_area_f[-1] else (1, 0) while not done and ls_iter < max_ls: if abs(search_area[1] - search_area[0]) * d_norm < tolerance_change: break t = _quadratic_interpolate( search_area[0], search_area_f[0], search_area_gtd[0], search_area[1], search_area_f[1], search_area_gtd[1], ) f_new, g_new = eval_closure(x, t, d) ls_func_evals += 1 gtd_new = g_new.dot(d) ls_iter += 1 if f_new > (f + c1 * t * gtd) or f_new >= search_area_f[low_pos]: search_area[high_pos] = t 
search_area_f[high_pos] = f_new search_area_g[high_pos] = np.copy(g_new) search_area_gtd[high_pos] = gtd_new low_pos, high_pos = ( (0, 1) if search_area_f[0] <= search_area_f[1] else (1, 0) ) if abs(gtd_new) <= -c2 * gtd: done = True elif gtd_new * (search_area[high_pos] - search_area[low_pos]) >= 0: search_area[high_pos] = search_area[low_pos] search_area_f[high_pos] = search_area_f[low_pos] search_area_g[high_pos] = search_area_g[low_pos] search_area_gtd[high_pos] = search_area_gtd[low_pos] search_area[low_pos] = t search_area_f[low_pos] = f_new search_area_g[low_pos] = np.copy(g_new) search_area_gtd[low_pos] = gtd_new t = search_area[low_pos] f_new = search_area_f[low_pos] g_new = search_area_g[low_pos] return f_new, g_new, t, ls_func_evals def compare_with_numpy_lbfgs( test_case, device, x_shape, learning_rate, train_iters, max_iter, max_eval, tolerance_grad, tolerance_change, history_size, line_search_fn, reload_state_step, save_load_by_pickle, contiguous_params, tensor_num, use_float64, ): random_grad_seq = [] init_value_seq = [] if use_float64: npType = np.float64 flowType = flow.float64 flow.set_default_tensor_type(flow.DoubleTensor) else: npType = np.float32 flowType = flow.float32 flow.set_default_tensor_type(flow.FloatTensor) for _ in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(npType)) for _ in range(tensor_num): random_grad_seq.append(np.random.uniform(size=x_shape).astype(npType)) def train_by_oneflow(): x = [] for i in range(tensor_num): x.append( Parameter( flow.tensor( init_value_seq[i], device=flow.device(device), dtype=flowType ) ) ) lbfgs = flow.optim.LBFGS( [{"params": x}], lr=learning_rate, max_iter=max_iter, max_eval=max_eval, tolerance_grad=tolerance_grad, tolerance_change=tolerance_change, history_size=history_size, line_search_fn=line_search_fn, contiguous_params=contiguous_params, ) def compute_loss(grad): loss = 0.0 for i in range(tensor_num): grad_tensor = flow.tensor( grad[i], dtype=flowType, requires_grad=False, device=flow.device(device), ) loss += flow.sum(x[i] * x[i] * grad_tensor) loss.backward() return loss def train_one_iter(grad): def closure(): lbfgs.zero_grad() loss = compute_loss(grad) return loss return lbfgs.step(closure) for i in range(train_iters): train_one_iter(random_grad_seq) if i == reload_state_step: state_dict = lbfgs.state_dict() lbfgs = flow.optim.LBFGS( [{"params": x,}], contiguous_params=contiguous_params ) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) lbfgs.load_state_dict(state_dict) return x def train_by_numpy(): def compute_loss(param, grad): loss = 0.0 loss += np.sum(param * param * grad) return loss x = np.concatenate(init_value_seq) def np_train_one_iter(x, state, init_grad): flat_grad = 2 * x * init_grad if max(map(abs, flat_grad)) <= tolerance_grad: return x loss = compute_loss(x, init_grad) current_evals = 1 state["func_evals"] += 1 d = state.get("d") t = state.get("t") old_diffs = state.get("old_diffs") old_step_size = state.get("old_step_size") ro = state.get("ro") H_diag = state.get("H_diag") prev_flat_grad = state.get("prev_flat_grad") prev_loss = state.get("prev_loss") n_iter = 0 while n_iter < max_iter: n_iter += 1 state["n_iter"] += 1 if state["n_iter"] == 1: d = -flat_grad old_diffs = [] old_step_size = [] ro = [] H_diag = 1 else: y = flat_grad - prev_flat_grad s = d * t ys = y.dot(s) if ys > 1e-10: if len(old_diffs) == history_size: old_diffs.pop(0) old_step_size.pop(0) ro.pop(0) old_diffs.append(y) 
old_step_size.append(s) ro.append(1.0 / ys) H_diag = ys / y.dot(y) num_old = len(old_diffs) if "alpha" not in state: state["alpha"] = [None] * history_size alpha = state["alpha"] q = -flat_grad for i in range(num_old - 1, -1, -1): alpha[i] = old_step_size[i].dot(q) * ro[i] q += old_diffs[i] * -alpha[i] d = q * H_diag for i in range(num_old): beta_i = old_diffs[i].dot(d) * ro[i] d += old_step_size[i] * (alpha[i] - beta_i) prev_flat_grad = np.copy(flat_grad) prev_loss = loss if state["n_iter"] == 1: t = min(1.0, 1.0 / np.sum(np.abs(flat_grad))) * learning_rate else: t = learning_rate gtd = flat_grad.dot(d) if gtd > -tolerance_change: break ls_func_evals = 0 if line_search_fn is None: x += t * d if n_iter != max_iter: loss = float(compute_loss(x, init_grad)) ls_func_evals = 1 flat_grad = 2 * x * init_grad else: assert ( line_search_fn == "strong_wolfe" ), "only strong_wolfe is expected" init_param = np.copy(x) def eval_func(x, t, d): return ( compute_loss(x + t * d, init_grad), 2 * (x + t * d) * init_grad, ) loss, flat_grad, t, ls_func_evals = _strong_wolfe( eval_func, init_param, t, d, loss, flat_grad, gtd ) x += t * d current_evals += ls_func_evals state["func_evals"] += ls_func_evals if n_iter == max_iter: break if current_evals >= max_eval: break if np.max(np.abs(flat_grad)) <= tolerance_grad: break if np.max(np.abs(d * t)) <= tolerance_change: break if abs(loss - prev_loss) < tolerance_change: break state["d"] = d state["t"] = t state["old_diffs"] = old_diffs state["old_step_size"] = old_step_size state["ro"] = ro state["prev_flat_grad"] = prev_flat_grad state["prev_loss"] = prev_loss state["H_diag"] = H_diag return x state = defaultdict(dict) state.setdefault("func_evals", 0) state.setdefault("n_iter", 0) for _ in range(0, train_iters): x = np_train_one_iter(x, state, np.concatenate(random_grad_seq)) return x oneflow_res = flow.cat(train_by_oneflow(), 0) numpy_res = train_by_numpy() test_case.assertTrue( np.allclose( oneflow_res.numpy().flatten(), numpy_res.flatten(), rtol=0.01, atol=0.01, ) ) @flow.unittest.skip_unless_1n1d() class TestLBFGS(flow.unittest.TestCase): def test_lbfgs_numpy(test_case): arg_dict = OrderedDict() arg_dict["device"] = [random_device().value()] arg_dict["x_shape"] = [10, 20] arg_dict["learning_rate"] = [0.01] arg_dict["train_iters"] = [20] arg_dict["max_iter"] = [20] arg_dict["max_eval"] = [25] arg_dict["tolerance_grad"] = [1e-7] arg_dict["tolerance_change"] = [1e-9] arg_dict["history_size"] = [100] arg_dict["line_search_fn"] = [None, "strong_wolfe"] arg_dict["reload_state_step"] = [5] arg_dict["save_load_by_pickle"] = [random_bool().value()] arg_dict["contiguous_params"] = [random_bool().value()] arg_dict["tensor_num"] = [3, 4, 7] arg_dict["use_float64"] = [True, False] for arg in GenArgList(arg_dict): compare_with_numpy_lbfgs(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_rmsprop.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import tempfile import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_rmsprop( test_case, device, x_shape, learning_rate, momentum, train_iters, alpha, eps, weight_decay, centered, reload_state_step, save_load_by_pickle, contiguous_params, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = Parameter(flow.Tensor(init_value, device=flow.device(device))) param_list = list() param_list.append(x) rmsprop = flow.optim.RMSprop( [ { "params": param_list, "lr": learning_rate, "alpha": alpha, "eps": eps, "weight_decay": weight_decay, "momentum": momentum, "centered": centered, "contiguous_params": contiguous_params, } ] ) def train_one_iter(grad): grad_tensor = flow.tensor( grad, dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss = flow.sum(x * grad_tensor) loss.backward() rmsprop.step() rmsprop.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) if i == reload_state_step: state_dict = rmsprop.state_dict() rmsprop = flow.optim.RMSprop([x], contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) rmsprop.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value r = np.zeros_like(x) v = np.zeros_like(x) g = np.zeros_like(x) def train_one_iter(grad): grad = grad + weight_decay * x r_ = alpha * r + (1 - alpha) * grad * grad if centered: g_ = alpha * g + (1 - alpha) * grad v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad else: g_ = g v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad param = x - v_ return (param, r_, g_, v_) for i in range(train_iters): (x, r, g, v) = train_one_iter(random_grad_seq[i]) return x oneflow_res = train_by_oneflow().numpy() numpy_res = train_by_numpy() test_case.assertTrue( np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=2e-3, atol=2e-3) ) def compare_with_numpy_rmsprop_clip_grad( test_case, device, x_shape, learning_rate, momentum, train_iters, alpha, eps, weight_decay, centered, clip_grad_max_norm, clip_grad_norm_type, reload_state_step, save_load_by_pickle, contiguous_params, ): random_grad_seq = [] for _ in range(train_iters): random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) init_value = np.random.uniform(size=x_shape).astype(np.float32) def train_by_oneflow(): x = Parameter(flow.Tensor(init_value, device=flow.device(device))) param_list = list() param_list.append(x) rmsprop = flow.optim.RMSprop( [ { "params": param_list, "lr": learning_rate, "alpha": alpha, "eps": eps, "weight_decay": weight_decay, "momentum": momentum, "centered": centered, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, "contiguous_params": contiguous_params, } ] ) def train_one_iter(grad): grad_tensor = flow.tensor( grad, dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss = flow.sum(x * grad_tensor) loss.backward() rmsprop.clip_grad() rmsprop.step() rmsprop.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) 
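        # At reload_state_step the optimizer state is captured and restored into a
        # freshly constructed RMSprop instance (optionally round-tripped through
        # flow.save / flow.load) to exercise state_dict / load_state_dict.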
if i == reload_state_step: state_dict = rmsprop.state_dict() rmsprop = flow.optim.RMSprop([x], contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) rmsprop.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value r = np.zeros_like(x) v = np.zeros_like(x) g = np.zeros_like(x) def train_one_iter(grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) grad = grad + weight_decay * x r_ = alpha * r + (1 - alpha) * grad * grad if centered: g_ = alpha * g + (1 - alpha) * grad v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad else: g_ = g v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad param = x - v_ return (param, r_, g_, v_) for i in range(train_iters): (x, r, g, v) = train_one_iter(random_grad_seq[i]) return x oneflow_res = train_by_oneflow().numpy() numpy_res = train_by_numpy() test_case.assertTrue( np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=2e-3, atol=2e-3) ) @flow.unittest.skip_unless_1n1d() class TestRMSProp(flow.unittest.TestCase): def test_rmsprop(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1] arg_dict["momentum"] = [0.0] arg_dict["train_iters"] = [2] arg_dict["alpha"] = [0.9, 0.99] arg_dict["eps"] = [1e-08, 1e-05] arg_dict["weight_decay"] = [0.1, 0.99] arg_dict["centered"] = [False, True] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [False, True] arg_dict["contiguous_params"] = [True, False] for arg in GenArgList(arg_dict): compare_with_numpy_rmsprop(test_case, *arg) def test_rmsprop_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["x_shape"] = [(10,)] arg_dict["learning_rate"] = [1] arg_dict["momentum"] = [0.0] arg_dict["train_iters"] = [2] arg_dict["alpha"] = [0.9, 0.99] arg_dict["eps"] = [1e-08, 1e-05] arg_dict["weight_decay"] = [0.1, 0.99] arg_dict["centered"] = [False, True] arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0] arg_dict["clip_grad_norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [False, True] arg_dict["contiguous_params"] = [False, True] for arg in GenArgList(arg_dict): compare_with_numpy_rmsprop_clip_grad(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_optim_sgd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import tempfile import os import random as random_util import numpy as np from oneflow.test_utils.test_util import GenArgDict from oneflow.test_utils.automated_test_util import random_bool, random_device from optimizer_test_util import clip_grad_norm_np import oneflow as flow from oneflow.nn.parameter import Parameter def compare_with_numpy_sgd( test_case, device, x_shape, momentum, dampening, nesterov, maximize, weight_decay, learning_rate, train_iters, reload_state_step, save_load_by_pickle, contiguous_params, fused, tensor_num, ): random_grad_seq = [] init_value_seq = [] for i in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) for _ in range(train_iters): random_grad_seq_per_iter = [] for i in range(tensor_num): random_grad_seq_per_iter.append( np.random.uniform(size=x_shape).astype(np.float32) ) random_grad_seq.append(random_grad_seq_per_iter) def train_by_oneflow(): x = [] for i in range(tensor_num): x.append( Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device))) ) sgd = flow.optim.SGD( [{"params": x, "lr": learning_rate, "weight_decay": weight_decay,}], momentum=momentum, dampening=dampening, nesterov=nesterov, maximize=maximize, contiguous_params=contiguous_params, fused=fused, ) def train_one_iter(grad): loss = 0.0 for i in range(tensor_num): grad_tensor = flow.tensor( grad[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss += flow.sum(x[i] * grad_tensor) loss.backward() sgd.step() sgd.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) # test state_dict/load_state_dict if i == reload_state_step: state_dict = sgd.state_dict() sgd = flow.optim.SGD(x, contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) sgd.load_state_dict(state_dict) return x def train_by_numpy(tensor_idx): x = init_value_seq[tensor_idx] vt = np.zeros_like(x) def train_one_iter(grad): grad = grad + weight_decay * x if momentum > 0.0: next_momentum = momentum * vt + (1 - dampening) * grad v = next_momentum if nesterov: grad += momentum * next_momentum else: grad = next_momentum alpha = -learning_rate if maximize: alpha = learning_rate next_model = x + alpha * grad param = next_model else: v = learning_rate * grad param = x - v return (param, v) for i in range(train_iters): (x, vt) = train_one_iter(random_grad_seq[i][tensor_idx]) return x oneflow_res = train_by_oneflow() numpy_res = [] for i in range(tensor_num): numpy_res.append(train_by_numpy(i)) for i in range(tensor_num): test_case.assertTrue( np.allclose( oneflow_res[i].numpy().flatten(), numpy_res[i].flatten(), rtol=0.0001, atol=0.0001, ) ) def compare_with_numpy_sgd_clip_grad( test_case, device, x_shape, momentum, dampening, nesterov, maximize, weight_decay, learning_rate, clip_grad_max_norm, clip_grad_norm_type, train_iters, reload_state_step, save_load_by_pickle, contiguous_params, fused, tensor_num, ): random_grad_seq = [] init_value_seq = [] for i in range(tensor_num): init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) for _ in range(train_iters): random_grad_seq_per_iter = [] for i in range(tensor_num): random_grad_seq_per_iter.append( np.random.uniform(size=x_shape).astype(np.float32) ) random_grad_seq.append(random_grad_seq_per_iter) def train_by_oneflow(): x = [] for i in range(tensor_num): x.append( Parameter(flow.Tensor(init_value_seq[i], 
device=flow.device(device))) ) sgd = flow.optim.SGD( [ { "params": x, "lr": learning_rate, "dampening": dampening, "nesterov": nesterov, "maximize": maximize, "weight_decay": weight_decay, "clip_grad_max_norm": clip_grad_max_norm, "clip_grad_norm_type": clip_grad_norm_type, } ], momentum=momentum, dampening=dampening, nesterov=nesterov, maximize=maximize, contiguous_params=contiguous_params, fused=fused, ) def train_one_iter(grad): loss = 0.0 for i in range(tensor_num): grad_tensor = flow.tensor( grad[i], dtype=flow.float32, requires_grad=False, device=flow.device(device), ) loss += flow.sum(x[i] * grad_tensor) loss.backward() sgd.clip_grad() sgd.step() sgd.zero_grad() for i in range(train_iters): train_one_iter(random_grad_seq[i]) # test state_dict/load_state_dict if i == reload_state_step: state_dict = sgd.state_dict() sgd = flow.optim.SGD(x, contiguous_params=contiguous_params) if save_load_by_pickle: with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) state_dict = flow.load(f.name) sgd.load_state_dict(state_dict) return x def train_by_numpy(): x = init_value_seq vt = np.zeros_like(x) def train_one_iter(grad): total_norm, grad = clip_grad_norm_np( grad, clip_grad_max_norm, clip_grad_norm_type ) for i in range(tensor_num): grad[i] = grad[i] + weight_decay * x[i] if momentum > 0.0: next_momentum = momentum * vt[i] + (1 - dampening) * grad[i] vt[i] = next_momentum if nesterov: grad[i] += momentum * next_momentum else: grad[i] = next_momentum alpha = -learning_rate if maximize: alpha = learning_rate x[i] = x[i] + alpha * grad[i] else: vt[i] = learning_rate * grad[i] x[i] = x[i] - vt[i] for i in range(train_iters): train_one_iter(random_grad_seq[i]) return x oneflow_res = train_by_oneflow() numpy_res = train_by_numpy() for i in range(tensor_num): test_case.assertTrue( np.allclose( oneflow_res[i].numpy().flatten(), numpy_res[i].flatten(), rtol=0.0001, atol=0.0001, ) ) @flow.unittest.skip_unless_1n1d() class TestOptimizers(flow.unittest.TestCase): def test_sgd(test_case): arg_dict = OrderedDict() arg_dict["device"] = [random_device().value()] arg_dict["x_shape"] = [(10,)] arg_dict["momentum"] = [0.0, 0.9] arg_dict["dampening"] = [0.0, 0.9] arg_dict["nesterov"] = [random_bool().value()] arg_dict["maximize"] = [random_bool().value()] arg_dict["weight_decay"] = [0.0, 0.9] arg_dict["learning_rate"] = [1, 0.1] arg_dict["train_iters"] = [10] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [random_bool().value()] arg_dict["contiguous_params"] = [random_bool().value()] arg_dict["fused"] = [random_bool().value()] arg_dict["tensor_num"] = [1, 4] for arg in GenArgDict(arg_dict): compare_with_numpy_sgd(test_case, **arg) def test_sgd_clip_grad(test_case): arg_dict = OrderedDict() arg_dict["device"] = [random_device().value()] arg_dict["x_shape"] = [(10,)] arg_dict["momentum"] = [0.0, 0.9] arg_dict["dampening"] = [0.0, 0.9] arg_dict["nesterov"] = [random_bool().value()] arg_dict["maximize"] = [random_bool().value()] arg_dict["weight_decay"] = [0.0, 0.9] arg_dict["learning_rate"] = [1, 0.1] arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0] arg_dict["clip_grad_norm_type"] = random_util.sample( ["inf", "-inf", 0.0, 1.0, 2.0, 3.5], k=3 ) arg_dict["train_iters"] = [10] arg_dict["reload_state_step"] = [5] # save and load optim state arg_dict["save_load_by_pickle"] = [random_bool().value()] arg_dict["contiguous_params"] = [random_bool().value()] arg_dict["fused"] = [random_bool().value()] arg_dict["tensor_num"] = [1, 4] for arg in 
GenArgDict(arg_dict): compare_with_numpy_sgd_clip_grad(test_case, **arg) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_eager_global_zero_grad_sbp(test_case): x = flow.nn.Parameter( flow.zeros((10,)).to_global( sbp=flow.sbp.broadcast, placement=flow.placement("cuda", [0]) ) ) x.grad = flow.ones_like(x) t = x.grad test_case.assertEqual(len(t.sbp), 1) test_case.assertEqual(t.sbp[0], flow.sbp.broadcast) optimizer = flow.optim.SGD([x]) optimizer.zero_grad() test_case.assertTrue(np.allclose(t.numpy(), 0.0)) test_case.assertEqual(len(t.sbp), 1) test_case.assertEqual(t.sbp[0], flow.sbp.partial_sum) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_pairwise_distance.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestPairwiseDistance(flow.unittest.TestCase): @autotest(n=3) def test_pairwise_distance_module_with_random_data(test_case): device = random_device() a = random_tensor(ndim=2, dim0=10, dim1=128).to(device) b = random_tensor(ndim=2, dim0=10, dim1=128).to(device) cos = torch.nn.PairwiseDistance(p=2, eps=1e-6).to(device) cos.train(random()) output = cos(a, b) return output @autotest(n=3) def test_pairwise_distance_module_with_nonequal_dim_random_data(test_case): device = random_device() a = random_tensor(ndim=1, dim0=128).to(device) b = random_tensor(ndim=2, dim0=10, dim1=128).to(device) cos = torch.nn.PairwiseDistance(p=2, eps=1e-6).to(device) cos.train(random()) output = cos(a, b) return output @autotest(n=3) def test_pairwise_distance_functional_with_random_data(test_case): device = random_device() a = random_tensor(ndim=2, dim0=10, dim1=128).to(device) b = random_tensor(ndim=2, dim0=10, dim1=128).to(device) output = torch.nn.functional.pairwise_distance(a, b, p=2, eps=1e-6) return output @autotest(n=3) def test_pairwise_distance_functional_with_nonequal_dim_random_data(test_case): device = random_device() a = random_tensor(ndim=1, dim0=128).to(device) b = random_tensor(ndim=2, dim0=10, dim1=128).to(device) output = torch.nn.functional.pairwise_distance(a, b, p=2, eps=1e-6) return output if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_param_group.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np

from oneflow.test_utils.test_util import GenArgList

import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@flow.unittest.skip_unless_1n1d()
class TestParamGroup(flow.unittest.TestCase):
    def test_ParamGroup(test_case):
        parameters = {
            "params": [flow.ones(10), flow.ones(5)],
            "lr": 0.01,
        }
        default_options = {
            "test_float": 1e-3,
            "test_int": 6,
            "test_list": [1, 2, 3],
            "test_tensor": flow.ones(10),
            "test_str": "test",
        }
        pg = flow.optim.optimizer.ParamGroup(parameters, default_options)
        test_case.assertEqual(pg["test_float"], 1e-3)
        test_case.assertEqual(pg["test_int"], 6)
        test_case.assertTrue(np.array_equal(pg.get("test_list"), [1, 2, 3]))
        test_case.assertTrue(
            np.array_equal(pg.get("test_tensor").numpy(), flow.ones(10).numpy())
        )
        test_case.assertEqual(pg["test_str"], "test")
        test_case.assertTrue("params" in pg.keys())
        test_case.assertTrue(
            np.array_equal(pg["params"][0].numpy(), flow.ones(10).numpy())
        )
        test_case.assertTrue(
            np.array_equal(pg["params"][1].numpy(), flow.ones(5).numpy())
        )
        test_case.assertEqual(pg["lr"], 0.01)


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_parameters_grouping.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import unittest import numpy as np from collections import OrderedDict import oneflow as flow from oneflow.test_utils.test_util import GenArgDict from oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup as CPG from oneflow.nn.parameter import Parameter import oneflow.unittest def np_allclose_with_shape(a, b, *args, **kwargs): return a.shape == b.shape and np.allclose(a, b, *args, **kwargs) def module_grouping(test_case, device): class Model(flow.nn.Module): def __init__(self): super(Model, self).__init__() dtypes = [flow.float32, flow.float64] for i in range(10): self.register_parameter( f"w{i}", flow.nn.Parameter( flow.tensor([i % 2 + 1, i % 2 + 1], dtype=dtypes[i % 2]) ), ) m = Model().to(device) m.make_contiguous_params_group() cpg = CPG( list(m.parameters()) + [flow.tensor([3, 3], dtype=flow.float32, requires_grad=True)] ) test_case.assertTrue(len(m.cpg.grouped_parameters) == 2) test_case.assertTrue(len(m.cpg.grouped_grads) == 2) test_case.assertTrue(flow.max(m.cpg.grouped_parameters[0]) == 1) test_case.assertTrue(flow.max(m.cpg.grouped_parameters[1]) == 2) test_case.assertTrue(len(cpg.grouped_parameters) == 3) test_case.assertTrue(len(cpg.grouped_grads) == 3) test_case.assertTrue(flow.max(cpg.grouped_parameters[0]) == 1) test_case.assertTrue(flow.max(cpg.grouped_parameters[1]) == 2) test_case.assertTrue(flow.max(cpg.grouped_parameters[2]) == 3) def direct_grouping(test_case, device): x = [ Parameter( flow.tensor( [1, 2], device=flow.device(device), dtype=flow.float32, requires_grad=True, ) ), Parameter( flow.tensor( [3, 4], device=flow.device(device), dtype=flow.float32, requires_grad=True, ) ), ] cpg = CPG([[x[0]], [x[1]]]) test_case.assertTrue(len(cpg.grouped_parameters) == 2) test_case.assertTrue(len(cpg.grouped_grads) == 2) def global_grouping(test_case, device): x = flow.nn.Parameter( flow.zeros((10,), dtype=flow.float32, requires_grad=True).to_global( sbp=flow.sbp.broadcast, placement=flow.placement(device, [0]) ) ) y = flow.nn.Parameter( flow.zeros((10,), dtype=flow.float32, requires_grad=True).to_global( sbp=flow.sbp.split(0), placement=flow.placement(device, [0]) ) ) cpg = CPG([x, y], group_on_current_buffer=False) test_case.assertTrue(len(cpg.grouped_parameters) == 2) test_case.assertTrue(len(cpg.grouped_grads) == 2) def multi_module_grad(test_case, device): class Module1(flow.nn.Module): def __init__(self): super().__init__() self.w1 = flow.nn.Parameter(flow.Tensor([1, 1])) self.w2 = flow.nn.Parameter(flow.Tensor([1, 1])) def forward(self, x): return x * self.w1 * self.w2 class Module2(flow.nn.Module): def __init__(self): super().__init__() self.w1 = flow.nn.Parameter(flow.Tensor([2, 2])) self.w2 = flow.nn.Parameter(flow.Tensor([2, 2])) def forward(self, x): return x * self.w1 * self.w2 m1 = Module1().to(device) m1.make_contiguous_params_group() m2 = Module2().to(device) m2.make_contiguous_params_group() optim1 = flow.optim.SGD(m1.parameters(), lr=1e-2, contiguous_params=True) optim2 = flow.optim.SGD(m2.parameters(), lr=1e-2, contiguous_params=True) x1 = flow.ones([1, 1]).to(device) x2 = flow.ones([2, 2]).to(device) flow.sum(m1(x1)).backward() flow.sum(m2(x2)).backward() for p in m1.parameters(): test_case.assertTrue( np_allclose_with_shape(p.grad.numpy(), np.array([1.0, 1.0])) ) for p in m2.parameters(): test_case.assertTrue( np_allclose_with_shape(p.grad.numpy(), np.array([4.0, 4.0])) ) def multi_module_lifecycle(test_case, device): class Module1(flow.nn.Module): def __init__(self): super().__init__() self.w1 = flow.nn.Parameter(flow.Tensor([1, 1])) 
self.w2 = flow.nn.Parameter(flow.Tensor([1, 1])) def forward(self, x): return x * self.w1 * self.w2 class Module2(flow.nn.Module): def __init__(self): super().__init__() self.w1 = flow.nn.Parameter(flow.Tensor([2, 2])) self.w2 = flow.nn.Parameter(flow.Tensor([2, 2])) def forward(self, x): return x * self.w1 * self.w2 m1 = Module1().to(device) m1.make_contiguous_params_group() m2 = Module2().to(device) m2.make_contiguous_params_group() del m1 cpg = CPG(list(m2.parameters())) test_case.assertTrue(len(cpg.grouped_parameters) == 1) @flow.unittest.skip_unless_1n1d() class TestCPG(flow.unittest.TestCase): def test_cpg(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] for arg in GenArgDict(arg_dict): device = arg["device"] module_grouping(test_case, device) direct_grouping(test_case, device) global_grouping(test_case, device) multi_module_lifecycle(test_case, device) multi_module_grad(test_case, device) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_parital_fc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest # TODO: guoran, fix this on multi gpu @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestParitalFC(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") def test_parital_fc(test_case): p = flow.placement.all("cuda") w = flow.randn( 50000, 128, placement=p, sbp=flow.sbp.broadcast, requires_grad=True ) label = flow.randint(0, 50000, (512,), placement=p, sbp=flow.sbp.broadcast) num_sample = 5000 out = flow.distributed_partial_fc_sample(w, label, num_sample) test_case.assertTrue(out[0].shape == flow.Size([512])) test_case.assertTrue(out[1].shape == flow.Size([5000])) test_case.assertTrue(out[2].shape == flow.Size([5000, 128])) # test gradient function sample_weight = out[2] sample_weight.sum().backward() if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_pixel_shuffle.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _np_pixel_shuffle(input, h_factor, w_factor): (_batch, _channel, _height, _width) = input.shape assert ( _channel % (h_factor * w_factor) == 0 ), "The channels of input tensor must be divisible by (h_upscale_factor * w_upscale_factor)" _new_c = int(_channel / (h_factor * w_factor)) out = np.reshape(input, [_batch, _new_c, h_factor * w_factor, _height, _width]) out = np.reshape(out, [_batch, _new_c, h_factor, w_factor, _height, _width]) out = np.transpose(out, [0, 1, 4, 2, 5, 3]) out = np.reshape(out, [_batch, _new_c, _height * h_factor, _width * w_factor]) return out def _np_pixel_shuffle_grad(input, h_factor, w_factor): (_batch, _new_channel, _height_mul_factor, _width_mul_factor) = input.shape _channel = _new_channel * (h_factor * w_factor) _height = _height_mul_factor // h_factor _width = _width_mul_factor // w_factor out = np.ones(shape=(_batch, _channel, _height, _width)) return out def _test_pixel_shuffle_impl( test_case, device, shape, h_upscale_factor, w_upscale_factor ): x = np.random.randn(*shape) input = flow.tensor( x, dtype=flow.float32, requires_grad=True, device=flow.device(device) ) m = flow.nn.PixelShuffle( h_upscale_factor=h_upscale_factor, w_upscale_factor=w_upscale_factor ) m = m.to(device) of_out = m(input) np_out = _np_pixel_shuffle(x, h_upscale_factor, w_upscale_factor) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = _np_pixel_shuffle_grad(np_out, h_upscale_factor, w_upscale_factor) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestPixelShuffleModule(flow.unittest.TestCase): def test_pixel_shuffle(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_pixel_shuffle_impl] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 144, 5, 5), (11, 144, 1, 1)] arg_dict["h_upscale_factor"] = [2, 3, 4] arg_dict["w_upscale_factor"] = [2, 3, 4] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) arg_dict["shape"] = [(8, 25, 18, 18), (1, 25, 2, 2)] arg_dict["h_upscale_factor"] = [5] arg_dict["w_upscale_factor"] = [5] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest() def test_pixel_shuffle_with_random_data(test_case): upscale_factor = random().to(int) num_channels = upscale_factor * upscale_factor * random().to(int) m = torch.nn.PixelShuffle(upscale_factor=upscale_factor) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim1=num_channels).to(device) y = m(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_prelu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestPReLU(flow.unittest.TestCase): @autotest(n=5) def test_prelu_4dim_module_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim1=3).to(device) m = torch.nn.PReLU( num_parameters=3 | nothing(), init=random().to(float) | nothing(), ) m.to(device) m.train(random()) y = m(x) return y @autotest(n=5) def test_prelu_4dim_default_alpha_module_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim1=3).to(device) m = torch.nn.PReLU(init=random().to(float) | nothing(),) m.to(device) m.train(random()) y = m(x) return y @autotest(n=5) def test_prelu_2dim_module_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=3).to(device) m = torch.nn.PReLU( num_parameters=3 | nothing(), init=random().to(float) | nothing(), ) m.to(device) m.train(random()) y = m(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_prod.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestReduceProd(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_reduce_prod_without_dim(test_case): device = random_device() ndim = random(1, 5).to(int) x = random_tensor(ndim=ndim).to(device) y = torch.prod(x) return y @autotest(n=5, check_graph=True) def test_reduce_prod_with_dim(test_case): device = random_device() ndim = random(1, 5).to(int) x = random_tensor(ndim=ndim).to(device) dim = random(0, ndim).to(int) y = torch.prod(x, dim) y = torch.exp(y) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_reduce_prod_bool_without_dim(test_case): device = random_device() ndim = random(1, 5).to(int) x = random_tensor(ndim=ndim).to(device=device, dtype=torch.bool) y = torch.prod(x) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_reduce_prod_with_dtype(test_case): device = random_device() ndim = random(1, 5).to(int) x = random_tensor(ndim=ndim, low=1.0, high=4.0, requires_grad=False).to(device) dim = random(0, ndim).to(int) y = torch.prod(x, dim, dtype=torch.int32) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_pruning.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import sys from collections import OrderedDict import numpy as np import tempfile import pickle from oneflow.test_utils.test_util import GenArgList import oneflow.nn.utils.prune as prune import oneflow as flow import oneflow.unittest import oneflow.nn as nn import unittest.mock as mock from contextlib import contextmanager from oneflow.test_utils.automated_test_util import * class TestPrune(flow.unittest.TestCase): def test_validate_pruning_amount_init(self): r"""Test the first util function that validates the pruning amount requested by the user the moment the pruning method is initialized. This test checks that the expected errors are raised whenever the amount is invalid. The original function runs basic type checking + value range checks. It doesn't check the validity of the pruning amount with respect to the size of the tensor to prune. That's left to `_validate_pruning_amount`, tested below. """ # neither float not int should raise TypeError with self.assertRaises(TypeError): prune._validate_pruning_amount_init(amount="I'm a string") # float not in [0, 1] should raise ValueError with self.assertRaises(ValueError): prune._validate_pruning_amount_init(amount=1.1) with self.assertRaises(ValueError): prune._validate_pruning_amount_init(amount=20.0) # negative int should raise ValueError with self.assertRaises(ValueError): prune._validate_pruning_amount_init(amount=-10) # all these should pass without errors because they're valid amounts prune._validate_pruning_amount_init(amount=0.34) prune._validate_pruning_amount_init(amount=1500) prune._validate_pruning_amount_init(amount=0) prune._validate_pruning_amount_init(amount=0.0) prune._validate_pruning_amount_init(amount=1) prune._validate_pruning_amount_init(amount=1.0) self.assertTrue(True) def test_validate_pruning_amount(self): r"""Tests the second util function that validates the pruning amount requested by the user, this time with respect to the size of the tensor to prune. The rationale is that if the pruning amount, converted to absolute value of units to prune, is larger than the number of units in the tensor, then we expect the util function to raise a value error. """ # if amount is int and amount > tensor_size, raise ValueError with self.assertRaises(ValueError): prune._validate_pruning_amount(amount=20, tensor_size=19) # amount is a float so this should not raise an error prune._validate_pruning_amount(amount=0.3, tensor_size=0) # this is okay prune._validate_pruning_amount(amount=19, tensor_size=20) prune._validate_pruning_amount(amount=0, tensor_size=0) prune._validate_pruning_amount(amount=1, tensor_size=1) self.assertTrue(True) def test_compute_nparams_to_prune(self): r"""Test that requested pruning `amount` gets translated into the correct absolute number of units to prune. """ self.assertEqual(prune._compute_nparams_toprune(amount=0, tensor_size=15), 0) self.assertEqual(prune._compute_nparams_toprune(amount=10, tensor_size=15), 10) # if 1 is int, means 1 unit self.assertEqual(prune._compute_nparams_toprune(amount=1, tensor_size=15), 1) # if 1. 
is float, means 100% of units self.assertEqual(prune._compute_nparams_toprune(amount=1.0, tensor_size=15), 15) self.assertEqual(prune._compute_nparams_toprune(amount=0.4, tensor_size=17), 7) def test_random_pruning_sizes(self): r"""Test that the new parameters and buffers created by the pruning method have the same size as the input tensor to prune. These, in fact, correspond to the pruned version of the tensor itself, its mask, and its original copy, so the size must match. """ # fixturize test # TODO: add other modules modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] names = ["weight", "bias"] for m in modules: for name in names: with self.subTest(m=m, name=name): original_tensor = getattr(m, name) prune.random_unstructured(m, name=name, amount=0.1) # mask has the same size as tensor being pruned self.assertEqual( original_tensor.size(), getattr(m, name + "_mask").size() ) # 'orig' tensor has the same size as the original tensor self.assertEqual( original_tensor.size(), getattr(m, name + "_orig").size() ) # new tensor has the same size as the original tensor self.assertEqual(original_tensor.size(), getattr(m, name).size()) def test_random_pruning_orig(self): r"""Test that original tensor is correctly stored in 'orig' after pruning is applied. Important to make sure we don't lose info about the original unpruned parameter. """ # fixturize test # TODO: add other modules modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] names = ["weight", "bias"] for m in modules: for name in names: with self.subTest(m=m, name=name): # tensor prior to pruning original_tensor = getattr(m, name) prune.random_unstructured(m, name=name, amount=0.1) result = flow.sum( original_tensor - getattr(m, name + "_orig") ).item() self.assertEqual(result, 0) def test_random_pruning_new_weight(self): r"""Test that module.name now contains a pruned version of the original tensor obtained from multiplying it by the mask. """ # fixturize test # TODO: add other modules modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] names = ["weight", "bias"] for m in modules: for name in names: with self.subTest(m=m, name=name): # tensor prior to pruning original_tensor = getattr(m, name) prune.random_unstructured(m, name=name, amount=0.1) # weight = weight_orig * weight_mask weight = getattr(m, name) weight_orig_mask = getattr(m, name + "_orig") * getattr( m, name + "_mask" ).to(dtype=original_tensor.dtype) result = flow.sum(weight - weight_orig_mask).item() self.assertEqual(result, 0) def test_identity_pruning(self): r"""Test that a mask of 1s does not change forward or backward. """ input_ = flow.ones(1, 5) m = nn.Linear(5, 2) y_prepruning = m(input_) # output prior to pruning # compute grad pre-pruning and check it's equal to all ones y_prepruning.sum().backward() old_grad_weight = m.weight.grad.clone() # don't grab pointer! 
self.assertEqual(flow.sum(old_grad_weight - flow.ones_like(m.weight)).item(), 0) old_grad_bias = m.bias.grad.clone() self.assertEqual(flow.sum(old_grad_bias - flow.ones_like(m.bias)).item(), 0) # remove grads m.zero_grad() # force the mask to be made of all 1s prune.identity(m, name="weight") # with mask of 1s, output should be identical to no mask y_postpruning = m(input_) self.assertEqual(flow.sum(y_prepruning - y_postpruning).item(), 0) # with mask of 1s, grad should be identical to no mask y_postpruning.sum().backward() self.assertEqual(flow.sum(old_grad_weight - m.weight_orig.grad).item(), 0) self.assertEqual(flow.sum(old_grad_bias - m.bias.grad).item(), 0) # calling forward twice in a row shouldn't change output y1 = m(input_) y2 = m(input_) self.assertEqual(flow.sum(y1 - y2).item(), 0) def test_random_pruning_0perc(self): r"""Test that a mask of 1s does not change forward or backward. """ input_ = flow.ones(1, 5) m = nn.Linear(5, 2) y_prepruning = m(input_) # output prior to pruning # compute grad pre-pruning and check it's equal to all ones y_prepruning.sum().backward() old_grad_weight = m.weight.grad.clone() # don't grab pointer! self.assertEqual(flow.sum(old_grad_weight - flow.ones_like(m.weight)).item(), 0) old_grad_bias = m.bias.grad.clone() self.assertEqual(flow.sum(old_grad_bias - flow.ones_like(m.bias)).item(), 0) # remove grads m.zero_grad() # force the mask to be made of all 1s with mock.patch( "oneflow.nn.utils.prune.RandomUnstructured.compute_mask" ) as compute_mask: compute_mask.return_value = flow.ones_like(m.weight) prune.random_unstructured( m, name="weight", amount=0.9 ) # amount won't count # with mask of 1s, output should be identical to no mask y_postpruning = m(input_) self.assertEqual(flow.sum(y_prepruning - y_postpruning).item(), 0) # with mask of 1s, grad should be identical to no mask y_postpruning.sum().backward() self.assertEqual(flow.sum(old_grad_weight - m.weight_orig.grad).item(), 0) self.assertEqual(flow.sum(old_grad_bias - m.bias.grad).item(), 0) # calling forward twice in a row shouldn't change output y1 = m(input_) y2 = m(input_) self.assertEqual(flow.sum(y1 - y2).item(), 0) def test_random_pruning(self): input_ = flow.ones(1, 5) m = nn.Linear(5, 2) # define custom mask to assign with mock mask = flow.ones_like(m.weight) mask[1, 0] = 0 mask[0, 3] = 0 # check grad is zero for masked weights with mock.patch( "oneflow.nn.utils.prune.RandomUnstructured.compute_mask" ) as compute_mask: compute_mask.return_value = mask prune.random_unstructured(m, name="weight", amount=0.9) y_postpruning = m(input_) y_postpruning.sum().backward() # weight_orig is the parameter, so it's the tensor that will accumulate the grad self.assertEqual( flow.sum(m.weight_orig.grad - mask).item(), 0 ) # all 1s, except for masked units self.assertEqual(flow.sum(m.bias.grad - flow.ones_like(m.bias)).item(), 0) # make sure that weight_orig update doesn't modify [1, 0] and [0, 3] old_weight_orig = m.weight_orig.clone() # update weights learning_rate = 1.0 for p in m.parameters(): p.data.sub_(p.grad.data * learning_rate) # since these are pruned, they should not be updated self.assertEqual( flow.sum(old_weight_orig[1, 0] - m.weight_orig[1, 0]).item(), 0 ) self.assertEqual( flow.sum(old_weight_orig[0, 3] - m.weight_orig[0, 3]).item(), 0 ) def test_random_pruning_forward(self): r"""check forward with mask (by hand). 
""" input_ = flow.ones(1, 5) m = nn.Linear(5, 2) # define custom mask to assign with mock mask = flow.zeros_like(m.weight) mask[1, 0] = 1 mask[0, 3] = 1 with mock.patch( "oneflow.nn.utils.prune.RandomUnstructured.compute_mask" ) as compute_mask: compute_mask.return_value = mask prune.random_unstructured(m, name="weight", amount=0.9) yhat = m(input_) self.assertTrue( flow.sum(yhat[0, 0] - m.weight_orig[0, 3] - m.bias[0]).item() - 0 < 1e-5 ) self.assertTrue( flow.sum(yhat[0, 1] - m.weight_orig[1, 0] - m.bias[1]).item() - 0 < 1e-5 ) def test_remove_pruning_forward(self): r"""Remove pruning and check forward is unchanged from previous pruned state. """ input_ = flow.ones(1, 5) m = nn.Linear(5, 2) # define custom mask to assign with mock mask = flow.ones_like(m.weight) mask[1, 0] = 0 mask[0, 3] = 0 # check grad is zero for masked weights with mock.patch( "oneflow.nn.utils.prune.RandomUnstructured.compute_mask" ) as compute_mask: compute_mask.return_value = mask prune.random_unstructured(m, name="weight", amount=0.9) y_postpruning = m(input_) prune.remove(m, "weight") y_postremoval = m(input_) self.assertEqual(flow.sum(y_postpruning - y_postremoval).item(), 0) def test_pruning_id_consistency(self): r"""Test that pruning doesn't change the id of the parameters, which would otherwise introduce issues with pre-existing optimizers that point to old parameters. """ m = nn.Linear(5, 2, bias=False) tensor_id = id(list(m.parameters())[0]) prune.random_unstructured(m, name="weight", amount=0.9) self.assertEqual(tensor_id, id(list(m.parameters())[0])) prune.remove(m, "weight") self.assertEqual(tensor_id, id(list(m.parameters())[0])) def test_random_pruning_pickle(self): modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] names = ["weight", "bias"] for m in modules: for name in names: with self.subTest(m=m, name=name): prune.random_unstructured(m, name=name, amount=0.1) m_new = pickle.loads(pickle.dumps(m)) self.assertIsInstance(m_new, type(m)) def test_multiple_pruning_calls(self): # if you call pruning twice, the hook becomes a PruningContainer m = nn.Conv3d(2, 2, 2) prune.l1_unstructured(m, name="weight", amount=0.1) weight_mask0 = m.weight_mask # save it for later sanity check # prune again prune.ln_structured(m, name="weight", amount=0.3, n=2, dim=0) hook = next(iter(m._forward_pre_hooks.values())) self.assertIsInstance(hook, oneflow.nn.utils.prune.PruningContainer) # check that container._tensor_name is correctly set no matter how # many pruning methods are in the container self.assertEqual(hook._tensor_name, "weight") # check that the pruning container has the right length # equal to the number of pruning iters self.assertEqual(len(hook), 2) # m.weight has been pruned twice # check that the entries of the pruning container are of the expected # type and in the expected order self.assertIsInstance(hook[0], oneflow.nn.utils.prune.L1Unstructured) self.assertIsInstance(hook[1], oneflow.nn.utils.prune.LnStructured) # check that all entries that are 0 in the 1st mask are 0 in the # 2nd mask too self.assertTrue(flow.all(m.weight_mask[weight_mask0 == 0] == 0)) # prune again prune.ln_structured(m, name="weight", amount=0.1, n=float("inf"), dim=1) # check that container._tensor_name is correctly set no matter how # many pruning methods are in the container hook = next(iter(m._forward_pre_hooks.values())) self.assertEqual(hook._tensor_name, "weight") def test_pruning_container(self): # create an empty container container = prune.PruningContainer() container._tensor_name = "test" self.assertEqual(len(container), 0) 
p = prune.L1Unstructured(amount=2) p._tensor_name = "test" # test adding a pruning method to a container container.add_pruning_method(p) # test error raised if tensor name is different q = prune.L1Unstructured(amount=2) q._tensor_name = "another_test" with self.assertRaises(ValueError): container.add_pruning_method(q) # test that adding a non-pruning method object to a pruning container # raises a TypeError with self.assertRaises(TypeError): container.add_pruning_method(10) with self.assertRaises(TypeError): container.add_pruning_method("ugh") def test_pruning_container_compute_mask(self): r"""Test `compute_mask` of pruning container with a known `t` and `default_mask`. Indirectly checks that Ln structured pruning is acting on the right axis. """ # create an empty container container = prune.PruningContainer() container._tensor_name = "test" # 1) test unstructured pruning # create a new pruning method p = prune.L1Unstructured(amount=2) p._tensor_name = "test" # add the pruning method to the container container.add_pruning_method(p) # create tensor to be pruned t = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=flow.float32) # create prior mask by hand default_mask = flow.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) # since we are pruning the two lowest magnitude units, the outcome of # the calculation should be this: expected_mask = flow.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=flow.float32) computed_mask = container.compute_mask(t, default_mask) self.assertEqual(flow.sum(expected_mask - computed_mask).item(), 0) # 2) test structured pruning q = prune.LnStructured(amount=1, n=2, dim=0) q._tensor_name = "test" container.add_pruning_method(q) # since we are pruning the lowest magnitude one of the two rows, the # outcome of the calculation should be this: expected_mask = flow.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=flow.float32) computed_mask = container.compute_mask(t, default_mask) self.assertEqual(flow.sum(expected_mask - computed_mask).item(), 0) # 2) test structured pruning, along another axis r = prune.LnStructured(amount=1, n=2, dim=1) r._tensor_name = "test" container.add_pruning_method(r) # since we are pruning the lowest magnitude of the four columns, the # outcome of the calculation should be this: expected_mask = flow.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=flow.float32) computed_mask = container.compute_mask(t, default_mask) self.assertEqual(flow.sum(expected_mask - computed_mask).item(), 0) def test_l1_unstructured_pruning(self): r"""Test that l1 unstructured pruning actually removes the lowest entries by l1 norm (by hand). It also checks that applying l1 unstructured pruning more than once respects the previous mask. """ m = nn.Linear(4, 2) # modify its weight matrix by hand m.weight = flow.nn.Parameter( flow.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=flow.float32) ) prune.l1_unstructured(m, "weight", amount=2) expected_weight = flow.tensor( [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype ) self.assertEqual(flow.sum(expected_weight - m.weight).item(), 0) # check that pruning again removes the next two smallest entries prune.l1_unstructured(m, "weight", amount=2) expected_weight = flow.tensor( [[0, 0, 3, 4], [-4, -3, 0, 0]], dtype=m.weight.dtype ) self.assertEqual(flow.sum(expected_weight - m.weight).item(), 0) def test_l1_unstructured_pruning_with_importance_scores(self): r"""Test that l1 unstructured pruning actually removes the lowest entries of importance scores and not the parameter by l1 norm (by hand). 
It also checks that applying l1 unstructured pruning more than once respects the previous mask. """ m = nn.Linear(4, 2) # modify its weight matrix by hand m.weight = flow.nn.Parameter( flow.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=flow.float32) ) importance_scores = flow.tensor( [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=flow.float32 ) prune.l1_unstructured( m, "weight", amount=2, importance_scores=importance_scores ) expected_weight = flow.tensor( [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype ) self.assertEqual(flow.sum(expected_weight - m.weight).item(), 0) # check that pruning again removes two entries of m.weight that are colocated with # the next two smallest absolute values of importance scores. prune.l1_unstructured( m, "weight", amount=2, importance_scores=importance_scores ) expected_weight = flow.tensor( [[1, 0, 0, 4], [-4, 0, 0, -1]], dtype=m.weight.dtype ) self.assertEqual(flow.sum(expected_weight - m.weight).item(), 0) def test_unstructured_pruning_same_magnitude(self): r"""Since it may happen that the tensor to prune has entries with the same exact magnitude, it is important to check that pruning happens consistenly based on the bottom % of weights, and not by threshold, which would instead kill off *all* units with magnitude = threshold. """ AMOUNT = 0.2 p = prune.L1Unstructured(amount=AMOUNT) # create a random tensors with entries in {-2, 0, 2} t = 2 * flow.randint(low=-1, high=2, size=(10, 7)) nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement()) computed_mask = p.compute_mask(t, default_mask=flow.ones_like(t)) nparams_pruned = flow.sum(computed_mask == 0) self.assertEqual(nparams_toprune, nparams_pruned) def test_random_structured_pruning_amount(self): AMOUNT = 0.6 AXIS = 2 p = prune.RandomStructured(amount=AMOUNT, dim=AXIS) t = 2 * flow.randint(low=-1, high=2, size=(5, 4, 2)).to(dtype=flow.float32) nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS]) computed_mask = p.compute_mask(t, default_mask=flow.ones_like(t)) # check that 1 column is fully prune, the others are left untouched remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS] per_column_sums = sorted(flow.sum(computed_mask == 0, dim=remaining_axes)) assert per_column_sums == [0, 20] def test_ln_structured_pruning(self): r"""Check Ln structured pruning by hand. """ m = nn.Conv2d(3, 1, 2) m.weight.data = flow.tensor( [ [ [[1.0, 2.0], [1.0, 2.5]], [[0.5, 1.0], [0.1, 0.1]], [[-3.0, -5.0], [0.1, -1.0]], ] ] ) # expected effect of pruning 1 of the 3 channels by L2-norm expected_mask_axis1 = flow.ones_like(m.weight) expected_mask_axis1[:, 1] = 0.0 prune.ln_structured(m, "weight", amount=1, n=2, dim=1) self.assertEqual(flow.sum(expected_mask_axis1 - m.weight_mask).item(), 0) # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm expected_mask_axis3 = expected_mask_axis1 expected_mask_axis3[:, :, :, 0] = 0.0 prune.ln_structured(m, "weight", amount=1, n=1, dim=-1) self.assertEqual(flow.sum(expected_mask_axis3 - m.weight_mask).item(), 0) def test_ln_structured_pruning_importance_scores(self): r"""Check Ln structured pruning by hand. 
""" m = nn.Conv2d(3, 1, 2) m.weight.data = flow.tensor( [ [ [[1.0, 2.0], [1.0, 2.5]], [[0.5, 1.0], [0.1, 0.1]], [[-3.0, -5.0], [0.1, -1.0]], ] ] ) importance_scores = flow.tensor( [ [ [[10.0, 1.0], [10.0, 1.0]], [[30.0, 3.0], [30.0, 3.0]], [[-20.0, -2.0], [-20.0, -2.0]], ] ] ) # expected effect of pruning 1 of the 3 channels by L2-norm expected_mask_axis1 = flow.ones_like(m.weight) expected_mask_axis1[:, 0] = 0.0 prune.ln_structured( m, "weight", amount=1, n=2, dim=1, importance_scores=importance_scores ) self.assertEqual(flow.sum(expected_mask_axis1 - m.weight_mask).item(), 0) # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm expected_mask_axis3 = expected_mask_axis1 expected_mask_axis3[:, :, :, 1] = 0.0 prune.ln_structured( m, "weight", amount=1, n=1, dim=-1, importance_scores=importance_scores ) self.assertEqual(flow.sum(expected_mask_axis3 - m.weight_mask).item(), 0) def test_remove_pruning(self): r"""`prune.remove` removes the hook and the reparametrization and makes the pruning final in the original parameter. """ modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] names = ["weight", "bias"] for m in modules: for name in names: with self.subTest(m=m, name=name): # first prune prune.random_unstructured(m, name, amount=0.5) self.assertIn(name + "_orig", dict(m.named_parameters())) self.assertIn(name + "_mask", dict(m.named_buffers())) self.assertNotIn(name, dict(m.named_parameters())) self.assertTrue(hasattr(m, name)) pruned_t = getattr(m, name) # then remove pruning prune.remove(m, name) self.assertIn(name, dict(m.named_parameters())) self.assertNotIn(name + "_orig", dict(m.named_parameters())) self.assertNotIn(name + "_mask", dict(m.named_buffers())) final_t = getattr(m, name) self.assertEqual(flow.sum(pruned_t - final_t).item(), 0) def test_remove_pruning_exception(self): r"""Removing from an unpruned tensor throws an assertion error """ modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] names = ["weight", "bias"] for m in modules: for name in names: with self.subTest(m=m, name=name): # check that the module isn't pruned self.assertFalse(prune.is_pruned(m)) # since it isn't pruned, pruning can't be removed from it with self.assertRaises(ValueError): prune.remove(m, name) def test_global_pruning(self): r"""Test that global l1 unstructured pruning over 2 parameters removes the `amount=4` smallest global weights across the 2 parameters. """ m = nn.Linear(4, 2) n = nn.Linear(3, 1) # modify the weight matrices by hand m.weight = flow.nn.Parameter( flow.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=flow.float32) ) n.weight = flow.nn.Parameter(flow.tensor([[0, 0.1, -2]]).to(dtype=flow.float32)) params_to_prune = ( (m, "weight"), (n, "weight"), ) # prune the 4 smallest weights globally by L1 magnitude prune.global_unstructured( params_to_prune, pruning_method=prune.L1Unstructured, amount=4 ) expected_mweight = flow.tensor( [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype ) self.assertEqual(flow.sum(expected_mweight - m.weight).item(), 0) expected_nweight = flow.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype) self.assertEqual(flow.sum(expected_nweight - n.weight).item(), 0) def test_global_pruning_importance_scores(self): r"""Test that global l1 unstructured pruning over 2 parameters removes the `amount=4` smallest global weights across the 2 parameters. 
""" m = nn.Linear(4, 2) n = nn.Linear(3, 1) # modify the weight matrices by hand m.weight = flow.nn.Parameter( flow.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=flow.float32) ) m_importance_scores = flow.tensor( [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=flow.float32 ) n.weight = flow.nn.Parameter(flow.tensor([[0, 0.1, -2]]).to(dtype=flow.float32)) n_importance_scores = flow.tensor([[0, 10.0, -0.2]]).to(dtype=flow.float32) params_to_prune = ( (m, "weight"), (n, "weight"), ) importance_scores = { (m, "weight"): m_importance_scores, (n, "weight"): n_importance_scores, } # prune the 4 smallest weights globally by L1 magnitude prune.global_unstructured( params_to_prune, pruning_method=prune.L1Unstructured, amount=4, importance_scores=importance_scores, ) expected_m_weight = flow.tensor( [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype ) self.assertEqual(flow.sum(expected_m_weight - m.weight).item(), 0) expected_n_weight = flow.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype) self.assertEqual(flow.sum(expected_n_weight - n.weight).item(), 0) def test_custom_from_mask_pruning(self): r"""Test that the CustomFromMask is capable of receiving as input at instantiation time a custom mask, and combining it with the previous default mask to generate the correct final mask. """ # new mask mask = flow.tensor([[0, 1, 1, 0], [0, 0, 1, 1]]) # old mask default_mask = flow.tensor([[0, 0, 0, 0], [1, 1, 1, 1]]) # some tensor (not actually used) t = flow.rand(mask.shape, dtype=flow.float32, device=mask.device) # t = flow.rand_like(mask.to(dtype=flow.float32)) p = prune.CustomFromMask(mask=mask) computed_mask = p.compute_mask(t, default_mask) expected_mask = flow.tensor( [[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype ) self.assertEqual(flow.sum(computed_mask - expected_mask).item(), 0) def test_pruning_rollback(self): r"""Test that if something fails when the we try to compute the mask, then the model isn't left in some intermediate half-pruned state. The try/except statement in `apply` should handle rolling back to the previous state before pruning began. 
""" modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] names = ["weight", "bias"] for m in modules: for name in names: with self.subTest(m=m, name=name): with mock.patch( "oneflow.nn.utils.prune.L1Unstructured.compute_mask" ) as compute_mask: compute_mask.side_effect = Exception("HA!") with self.assertRaises(Exception): prune.l1_unstructured(m, name=name, amount=0.9) self.assertTrue(name in dict(m.named_parameters())) self.assertFalse(name + "_mask" in dict(m.named_buffers())) self.assertFalse(name + "_orig" in dict(m.named_parameters())) def test_pruning_serialization_model(self): # create a model model = flow.nn.Sequential( flow.nn.Linear(10, 10), flow.nn.ReLU(), flow.nn.Linear(10, 1), ) # check that everything looks normal before pruning self.assertNotIn("0.weight_orig", model.state_dict()) self.assertNotIn("0.weight_mask", model.state_dict()) self.assertIn("0.weight", model.state_dict()) # prune one of its parameters prune.l1_unstructured(module=model[0], name="weight", amount=0.9) # check that the original weight and the new mask are present self.assertIn("0.weight_orig", model.state_dict()) self.assertIn("0.weight_mask", model.state_dict()) self.assertNotIn("0.weight", model.state_dict()) self.assertTrue(hasattr(model[0], "weight")) pruned_weight = model[0].weight with tempfile.NamedTemporaryFile() as f: flow.save(model, f.name) new_model = flow.load(f.name) # check that the original weight and the new mask are present self.assertIn("0.weight_orig", new_model.state_dict()) self.assertIn("0.weight_mask", new_model.state_dict()) self.assertNotIn("0.weight", new_model.state_dict()) self.assertTrue(hasattr(new_model[0], "weight")) self.assertEqual(flow.sum(pruned_weight - new_model[0].weight).item(), 0) def test_prune(self): # create a new pruning method p = prune.L1Unstructured(amount=2) # create tensor to be pruned t = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=flow.float32) # create prior mask by hand default_mask = flow.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) # since we are pruning the two lowest magnitude units, the outcome of # the calculation should be this: expected_mask = flow.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) pruned_tensor = p.prune(t, default_mask) self.assertEqual(flow.sum(t * expected_mask - pruned_tensor).item(), 0) def test_prune_importance_scores(self): # create a new pruning method p = prune.L1Unstructured(amount=2) # create tensor to be pruned t = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=flow.float32) importance_scores = flow.tensor([[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]]).to( dtype=flow.float32 ) # create prior mask by hand default_mask = flow.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) # since we are pruning the two lowest magnitude units, the outcome of # the calculation should be this: expected_mask = flow.tensor([[0, 1, 1, 0], [0, 1, 0, 1]]) pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores) self.assertEqual(flow.sum(t * expected_mask - pruned_tensor).item(), 0) def test_prune_importance_scores_mimic_default(self): # create a new pruning method p = prune.L1Unstructured(amount=2) # create tensor to be pruned t = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=flow.float32) # create prior mask by hand default_mask = flow.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) # since we are pruning the two lowest magnitude units, the outcome of # the calculation should be this: expected_mask = flow.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) pruned_tensor_without_importance_scores = p.prune(t, default_mask) pruned_tensor_with_importance_scores = p.prune( t, 
default_mask, importance_scores=t ) self.assertEqual( flow.sum( pruned_tensor_without_importance_scores - pruned_tensor_with_importance_scores ).item(), 0, ) self.assertEqual( flow.sum( t * expected_mask - pruned_tensor_without_importance_scores ).item(), 0, ) def test_rnn_pruning(self): l = flow.nn.LSTM(32, 32) # This Module has 4 parameters called: # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0' # Pruning one of them causes one of the weights to become a tensor prune.l1_unstructured(l, "weight_ih_l0", 0.5) assert sum([isinstance(p, flow.nn.Parameter) for p in l._flat_weights]) == 3 # Removing the pruning reparametrization restores the Parameter prune.remove(l, "weight_ih_l0") assert sum([isinstance(p, flow.nn.Parameter) for p in l._flat_weights]) == 4 # Make sure that, upon removal of the reparametrization, the # `._parameters` and `.named_parameters` contain the right params. # Specifically, the original weight ('weight_ih_l0') should be placed # back in the parameters, while the reparametrization component # ('weight_ih_l0_orig') should be removed. assert "weight_ih_l0" in l._parameters assert l._parameters["weight_ih_l0"] is not None assert "weight_ih_l0_orig" not in l._parameters assert "weight_ih_l0" in dict(l.named_parameters()) assert dict(l.named_parameters())["weight_ih_l0"] is not None assert "weight_ih_l0_orig" not in dict(l.named_parameters()) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_qat_conv_modules.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import random import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_qat_conv1d( test_case, device, quantization_formula, quantization_bit, quantization_scheme, weight_quant_per_layer, input_quant_momentum, ): batch_size = random.randint(1, 5) input_channels = random.randint(1, 3) output_channels = random.randint(1, 3) spatial_size = random.randint(8, 16) kernel_size = random.randint(1, 3) stride = random.randint(1, 2) padding = random.randint(0, 2) qat_conv1d = flow.nn.QatConv1d( in_channels=input_channels, out_channels=output_channels, kernel_size=kernel_size, stride=stride, padding=padding, quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, weight_quant_per_layer=weight_quant_per_layer, input_quant_momentum=input_quant_momentum, ).to(device) qat_input = flow.rand( batch_size, input_channels, spatial_size, dtype=flow.float32, requires_grad=True, device=device, ) qat_out = qat_conv1d(qat_input) qat_out.sum().backward() qat_out.numpy() qat_input.grad.numpy() def _test_qat_conv2d( test_case, device, quantization_formula, quantization_bit, quantization_scheme, weight_quant_per_layer, input_quant_momentum, ): batch_size = random.randint(1, 5) input_channels = random.randint(1, 3) output_channels = random.randint(1, 3) spatial_size = random.randint(8, 16) kernel_size = random.randint(1, 3) stride = random.randint(1, 2) padding = random.randint(0, 2) qat_conv2d = flow.nn.QatConv2d( in_channels=input_channels, out_channels=output_channels, kernel_size=kernel_size, stride=stride, padding=padding, quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, weight_quant_per_layer=weight_quant_per_layer, input_quant_momentum=input_quant_momentum, ).to(device) qat_input = flow.rand( batch_size, input_channels, spatial_size, spatial_size, dtype=flow.float32, requires_grad=True, device=device, ) qat_out = qat_conv2d(qat_input) qat_out.sum().backward() qat_out.numpy() qat_input.grad.numpy() def _test_qat_conv3d( test_case, device, quantization_formula, quantization_bit, quantization_scheme, weight_quant_per_layer, input_quant_momentum, ): batch_size = random.randint(1, 5) input_channels = random.randint(1, 3) output_channels = random.randint(1, 3) spatial_size = random.randint(8, 16) kernel_size = random.randint(1, 3) stride = random.randint(1, 2) padding = random.randint(0, 2) qat_conv3d = flow.nn.QatConv3d( in_channels=input_channels, out_channels=output_channels, kernel_size=kernel_size, stride=stride, padding=padding, quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, weight_quant_per_layer=weight_quant_per_layer, input_quant_momentum=input_quant_momentum, ).to(device) qat_input = flow.rand( batch_size, input_channels, spatial_size, spatial_size, spatial_size, dtype=flow.float32, requires_grad=True, device=device, ) qat_out = qat_conv3d(qat_input) qat_out.sum().backward() qat_out.numpy() qat_input.grad.numpy() @flow.unittest.skip_unless_1n1d() class TestQatModules(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 2 times in past week") def test_qat_conv1d(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] arg_dict["quantization_formula"] = ["google"] arg_dict["quantization_bit"] = [4, 8] arg_dict["quantization_scheme"] = ["symmetric", "affine"] 
arg_dict["weight_quant_per_layer"] = [True, False] arg_dict["input_quant_momentum"] = [0.95] for i in range(5): for arg in GenArgList(arg_dict): _test_qat_conv1d(test_case, *arg) def test_qat_conv2d(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] arg_dict["quantization_formula"] = ["google"] arg_dict["quantization_bit"] = [4, 8] arg_dict["quantization_scheme"] = ["symmetric", "affine"] arg_dict["weight_quant_per_layer"] = [True, False] arg_dict["input_quant_momentum"] = [0.95] for i in range(5): for arg in GenArgList(arg_dict): _test_qat_conv2d(test_case, *arg) def test_qat_conv3d(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] arg_dict["quantization_formula"] = ["google"] arg_dict["quantization_bit"] = [4, 8] arg_dict["quantization_scheme"] = ["symmetric", "affine"] arg_dict["weight_quant_per_layer"] = [True, False] arg_dict["input_quant_momentum"] = [0.95] for i in range(5): for arg in GenArgList(arg_dict): _test_qat_conv3d(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_quantile.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @autotest(n=3, check_graph=True) def _test_quantile(test_cast, q): device = random_device() a = random_tensor(2, random(2, 5), random(2, 5)).to(device) out = torch.quantile(a, q, dim=1, interpolation="linear") return out @unittest.skipIf(True, "pytorch-1.10.0 will cause oneflow cudnn or cublas error") @flow.unittest.skip_unless_1n1d() class TestQuantile(flow.unittest.TestCase): def test_quantile(test_case): arg_dict = OrderedDict() arg_dict["q"] = [0.2, 0.6, 0.8] for arg in GenArgList(arg_dict): _test_quantile(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_quantization.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import math import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.test_util import ( GenArgList, type_name_to_flow_type, type_name_to_np_type, ) import oneflow as flow import oneflow.unittest def gen_quant_scale_for_min_max_symmetric(weight, quantization_bit): weight_max = np.max(np.abs(weight)) denominator = 2.0 ** (quantization_bit - 1) - 1 return (weight_max / denominator, 0) def gen_quant_scale_for_min_max_affine(weight, quantization_bit): weight_max = np.max(weight) weight_min = np.min(weight) denominator = 2.0 ** quantization_bit - 1 scale = (weight_max - weight_min) / denominator zero_point = -np.round(weight_min / scale) return (scale, zero_point) def gen_quant_scale_for_min_max_cambricon(weight, quantization_bit): weight_max = np.max(np.abs(weight)) scale = math.floor(math.log2(weight_max)) - (quantization_bit - 2) return (scale, 0) def product(tu): return np.prod(tu).astype(np.int32).item() def quant_per_layer_symmetric(input, quantization_bit, scale): upper_bound = 2.0 ** (quantization_bit - 1) - 1 lower_bound = -upper_bound return np.clip(np.rint(input / scale), lower_bound, upper_bound) def quant_per_layer_affine(input, quantization_bit, scale, zero_point): upper_bound = 2.0 ** quantization_bit - 1 lower_bound = 0 return np.clip(np.rint(input / scale + zero_point), lower_bound, upper_bound) def quant_per_layer_cambricon(input, quantization_bit, shift): upper_bound = 2.0 ** (quantization_bit - 1) - 1 lower_bound = -upper_bound scale = 2 ** shift return np.clip(np.rint(input / scale), lower_bound, upper_bound) def _check_quantize( test_case, input, out_of, quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ): if per_layer_quantization or quantization_formula == "cambricon": outer_num = 1 inner_num = product(input.shape[0:]) else: outer_num = input.shape[0] inner_num = product(input.shape[1:]) scale_np = np.zeros((outer_num,)) zero_point_np = np.zeros((outer_num,)) out_np = np.zeros((inner_num * outer_num,)) input_flatten = input.flatten() input_diff_np = np.full((inner_num * outer_num,), 1.0 / (inner_num * outer_num)) if quantization_formula == "google": if quantization_scheme == "symmetric": for c in range(outer_num): (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_symmetric( input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit ) out = quant_per_layer_symmetric( input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit, scale_np[c], ) out_np[c * inner_num : (c + 1) * inner_num] = out else: for c in range(outer_num): (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_affine( input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit ) out = quant_per_layer_affine( input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit, scale_np[c], zero_point_np[c], ) out_np[c * inner_num : (c + 1) * inner_num] = out else: (scale_np[0], zero_point_np[0]) = gen_quant_scale_for_min_max_cambricon( input_flatten, quantization_bit ) out_np = quant_per_layer_cambricon(input_flatten, quantization_bit, scale_np[0]) rmse = np.sqrt(np.mean((out_of - out_np) ** 2)) assert rmse <= 2.0, "quantization op has bug!" 
def _run_test_quantize( test_case, device_type, dtype, in_shape, quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ): input = (np.random.random(in_shape) - 0.5).astype(type_name_to_np_type[dtype]) input_tensor = flow.tensor( input, dtype=flow.float32, device=flow.device(device_type) ) min_max_observer = flow.nn.MinMaxObserver( quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization, ) (scale, zero_point) = min_max_observer(input_tensor) quantization = flow.nn.Quantization( quantization_formula=quantization_formula, quantization_bit=quantization_bit, quantization_scheme=quantization_scheme, ) output_tensor = quantization(input_tensor, scale, zero_point) out = output_tensor.numpy() _check_quantize( test_case, input, out.flatten(), quantization_bit, quantization_scheme, quantization_formula, per_layer_quantization, ) class TestQuantize(flow.unittest.TestCase): def test_quantize(test_case): arg_dict = OrderedDict() arg_dict["test_case"] = [test_case] arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["dtype"] = ["float32", "double"] arg_dict["in_shape"] = [(9, 40, 20, 10)] arg_dict["quantization_bit"] = [8, 2] arg_dict["quantization_scheme"] = ["symmetric", "affine"] arg_dict["quantization_formula"] = ["google"] arg_dict["per_layer_quantization"] = [True, False] for arg in GenArgList(arg_dict): if arg[-2] == "cambricon" and arg[-1] == False: continue _run_test_quantize(*arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_quick_gelu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest import torch class QuickGELUActivation(torch.nn.Module): """ Applies GELU approximation that is fast but somewhat inaccurate. 
See: https://github.com/hendrycks/GELUs """ def forward(self, input: torch.Tensor) -> torch.Tensor: return input * torch.sigmoid(1.702 * input) def _test_quick_gelu(test_case, device): torch_quick_gelu = QuickGELUActivation() x = np.random.randn(2, 4, 3) torch_x = torch.tensor(x, requires_grad=True, device=torch.device(device)) oneflow_x = flow.tensor(x, requires_grad=True, device=flow.device(device)) torch_y = torch_quick_gelu(torch_x) oneflow_y = flow._C.quick_gelu(oneflow_x) test_case.assertTrue(np.allclose(torch_y.detach().cpu().numpy(), oneflow_y.numpy())) torch_y_sum = torch_y.sum() torch_y_sum.backward() oneflow_y_sum = oneflow_y.sum() oneflow_y_sum.backward() test_case.assertTrue( np.allclose(torch_x.grad.cpu().numpy(), oneflow_x.grad.numpy()) ) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_quick_gelu(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_quick_gelu] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_rand.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList def _test_rand(test_case, device, shape): y1 = flow.rand(*shape, device=flow.device(device)) y2 = flow.rand(size=shape, device=flow.device(device)) test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy())) test_case.assertTrue(shape == y1.shape) test_case.assertTrue(shape == y2.shape) def _test_rand_tuple_shape(test_case, device, shape): y1 = flow.rand(shape, device=flow.device(device)) y2 = flow.rand(shape, device=flow.device(device)) test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy())) test_case.assertTrue(shape == y1.shape) def _test_0d_rand(test_case, device, shape): y1 = flow.rand(*shape, device=flow.device(device)) y2 = flow.rand(*shape, device=flow.device(device)) test_case.assertTrue( np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4) ) # 0d is [] and [] test_case.assertTrue(shape == y1.shape) def _test_different_dtype(test_case, device, shape): y1 = flow.rand(*shape, dtype=flow.float32, device=flow.device(device)) y2 = flow.rand(*shape, dtype=flow.float64, device=flow.device(device)) test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy())) test_case.assertTrue(shape == y1.shape) with test_case.assertRaises(NotImplementedError): flow.rand(*shape, dtype=flow.int32, device=flow.device(device)) def _test_backward(test_case, device, shape): x = flow.rand(*shape, device=flow.device(device), requires_grad=True) y = x.sum() y.backward() test_case.assertTrue(np.array_equal(np.ones(shape), x.grad.numpy())) def _test_with_generator(test_case, device, shape): gen = flow.Generator() gen.manual_seed(0) y1 = flow.rand( *shape, dtype=flow.float32, device=flow.device(device), generator=gen ) gen.manual_seed(0) y2 = flow.rand( *shape, dtype=flow.float32, device=flow.device(device), generator=gen ) test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) def _test_rand_with_flow_size(test_case, device, shape): y1 = flow.rand(flow.Size(shape), device=flow.device(device)) y2 = flow.rand(flow.Size(shape), device=flow.device(device)) test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy())) test_case.assertTrue(shape == y1.shape) @flow.unittest.skip_unless_1n1d() class TestRandModule(flow.unittest.TestCase): def test_0d_randint(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_0d_rand] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 0, 4), (2, 0, 2)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_cases(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_rand, _test_rand_tuple_shape, _test_different_dtype, _test_backward, _test_with_generator, _test_rand_with_flow_size, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 4)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_half_rand(test_case): for device in ["cuda", "cpu"]: x = flow.rand(2, 3, dtype=flow.float16, device=flow.device(device)) test_case.assertTrue(x.dtype == flow.float16) test_case.assertTrue(x.shape == flow.Size((2, 3))) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestRandOnNonDefaultDevice(flow.unittest.TestCase): def test_non_default_device(test_case): x = flow.rand(2, 3, 
device="cuda:1") test_case.assertEqual(x.device, flow.device("cuda:1")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_randint.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_randint(test_case, device, shape, low, high): y1 = flow.randint(low, high, shape, device=flow.device(device)) y2 = flow.randint(low, high, shape, device=flow.device(device)) test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) test_case.assertTrue(shape == y1.shape) def _test_0d_randint(test_case, device, shape, low, high): y1 = flow.randint(low, high, shape, device=flow.device(device)) y2 = flow.randint(low, high, shape, device=flow.device(device)) test_case.assertTrue( np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4) ) # 0d is [] and [] test_case.assertTrue(shape == y1.shape) def _test_different_dtype(test_case, device, shape, low, high): for dtype in [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ]: y = flow.randint(low, high, shape, dtype=dtype, device=flow.device(device)) test_case.assertTrue(y.dtype == dtype) test_case.assertTrue(y.shape == shape) def _test_with_generator(test_case, device, shape, low, high): gen = flow.Generator() gen.manual_seed(0) y1 = flow.randint( low, high, shape, dtype=flow.float32, device=flow.device(device), generator=gen ) gen.manual_seed(0) y2 = flow.randint( low, high, shape, dtype=flow.float32, device=flow.device(device), generator=gen ) test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) def _test_high(test_case, device, shape, low, high): y1 = flow._C.randint(high, shape, device=flow.device(device)) y2 = flow._C.randint(high, shape, device=flow.device(device)) test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) test_case.assertTrue(shape == y1.shape) def _test_0rank(test_case, device, shape, low, high): y1 = flow.randint(low, high, shape, device=flow.device(device)) test_case.assertTrue(y1.shape == shape) @flow.unittest.skip_unless_1n1d() class TestRandint(flow.unittest.TestCase): def test_global_different_types(test_case): for dtype in [ flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ]: placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x = flow.randint(0, 16, (10, 1), placement=placement, sbp=sbp, dtype=dtype) test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def test_randint(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_randint, _test_different_dtype, _test_with_generator, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["low"] = [i 
for i in range(10)] arg_dict["high"] = [10 + np.random.randint(10, 20) for i in range(10)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_0d_randint(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_0d_randint] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 0, 4), (2, 0, 2)] arg_dict["low"] = [i for i in range(10)] arg_dict["high"] = [10 + np.random.randint(1, 20) for i in range(10)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_high_randint(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_high] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3, 4), (2, 5, 2)] arg_dict["low"] = [i for i in range(10)] arg_dict["high"] = [10 + np.random.randint(10, 20) for i in range(10)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_0rank_randint(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_0rank] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [()] arg_dict["low"] = [i for i in range(10)] arg_dict["high"] = [1000 + np.random.randint(1, 10) for i in range(10)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestRandintOnNonDefaultDevice(flow.unittest.TestCase): def test_non_default_device(test_case): x = flow.randint(low=1, high=2, size=flow.Size((2, 3)), device="cuda:1") test_case.assertEqual(x.device, flow.device("cuda:1")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_randint_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_randint_like(test_case, device, shape, low, high): x = flow.randn(shape) y1 = flow.randint_like(x, low, high, device=flow.device(device)) y2 = flow.randint_like(x, low, high, device=flow.device(device)) test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) test_case.assertTrue(shape == y1.shape) def _test_0d_randint_like(test_case, device, shape, low, high): x = flow.randn(shape) y1 = flow.randint_like(x, low, high, device=flow.device(device)) y2 = flow.randint_like(x, low, high, device=flow.device(device)) test_case.assertTrue( np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4) ) # 0d is [] and [] test_case.assertTrue(shape == y1.shape) def _test_different_dtype(test_case, device, shape, low, high): for dtype in [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ]: x = flow.randint(low, high, shape, dtype=dtype) y = flow.randint_like(x, low, high, dtype=dtype, device=flow.device(device)) test_case.assertTrue(y.dtype == dtype) test_case.assertTrue(y.shape == shape) def _test_with_generator(test_case, device, shape, low, high): gen = flow.Generator() gen.manual_seed(0) x = flow.randn(shape) y1 = flow.randint_like( x, low, high, dtype=flow.float32, device=flow.device(device), generator=gen ) gen.manual_seed(0) x = flow.randn(shape) y2 = flow.randint_like( x, low, high, dtype=flow.float32, device=flow.device(device), generator=gen ) test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) def _test_high(test_case, device, shape, low, high): x = flow.randn(shape) y1 = flow._C.randint_like(x, high, device=flow.device(device)) y2 = flow._C.randint_like(x, high, device=flow.device(device)) test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) test_case.assertTrue(shape == y1.shape) def _test_0rank(test_case, device, shape, low, high): x = flow.randn(shape) y1 = flow.randint_like(x, low, high, device=flow.device(device)) test_case.assertTrue(y1.shape == shape) @flow.unittest.skip_unless_1n1d() class TestRandIntLike(flow.unittest.TestCase): def test_global_different_types(test_case): for dtype in [ flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ]: placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x_ = flow.randn((10, 1)) x = flow.randint_like(x_, 0, 16, placement=placement, sbp=sbp, dtype=dtype) test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def test_randint_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_randint_like, _test_different_dtype, _test_with_generator, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["low"] = [i for i in range(10)] arg_dict["high"] = [10 + np.random.randint(10, 20) for i in range(10)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_0d_randint_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_0d_randint_like] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 0, 4), (2, 0, 2)] arg_dict["low"] = [i for i in range(10)] arg_dict["high"] = [10 + np.random.randint(1, 20) for i in range(10)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_high_randint_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = 
[_test_high] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3, 4), (2, 5, 2)] arg_dict["low"] = [i for i in range(10)] arg_dict["high"] = [10 + np.random.randint(10, 20) for i in range(10)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_0rank_randint_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_0rank] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [()] arg_dict["low"] = [i for i in range(10)] arg_dict["high"] = [1000 + np.random.randint(1, 10) for i in range(10)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestRandIntLikeOnNonDefaultDevice(flow.unittest.TestCase): def test_non_default_device(test_case): x_ = flow.randn((2, 3)) x = flow.randint_like(x_, low=1, high=2, device="cuda:1") test_case.assertEqual(x.device, flow.device("cuda:1")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_randn.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * def _test_randn(test_case, device, shape): y1 = flow.randn(*shape, device=flow.device(device)) y2 = flow.randn(size=shape, device=flow.device(device)) test_case.assertTrue(not np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) test_case.assertTrue(shape == y1.shape) test_case.assertTrue(shape == y2.shape) def _test_0d_rand(test_case, device, shape): y1 = flow.randn(*shape, device=flow.device(device)) y2 = flow.randn(*shape, device=flow.device(device)) test_case.assertTrue( np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4) ) # 0d is [] and [] test_case.assertTrue(shape == y1.shape) def _test_different_dtype(test_case, device, shape): y1 = flow.randn(*shape, dtype=flow.float32, device=flow.device(device)) y2 = flow.randn(*shape, dtype=flow.float64, device=flow.device(device)) test_case.assertTrue(not np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) test_case.assertTrue(shape == y1.shape) with test_case.assertRaises(NotImplementedError): flow.randn(*shape, dtype=flow.int32, device=flow.device(device)) def _test_backward(test_case, device, shape): x = flow.randn(*shape, device=flow.device(device), requires_grad=True) y = x.sum() y.backward() test_case.assertTrue( np.allclose(np.ones(shape), x.grad.numpy(), atol=1e-4, rtol=1e-4) ) def _test_with_generator(test_case, device, shape): gen = flow.Generator() gen.manual_seed(0) y1 = flow.randn( *shape, dtype=flow.float32, device=flow.device(device), generator=gen ) gen.manual_seed(0) y2 = flow.randn( *shape, dtype=flow.float32, device=flow.device(device), generator=gen ) 
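# Note: reseeding the same flow.Generator with manual_seed(0) before each flow.randn call
# is expected to reproduce the identical sample, so y1 and y2 should match; that is what
# the allclose check below verifies.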
test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) def _test_randn_tuple_shape(test_case, device, shape): y1 = flow.randn(shape, device=flow.device(device)) y2 = flow.randn(shape, device=flow.device(device)) test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy())) test_case.assertTrue(shape == y1.shape) def _test_randn_with_flow_size(test_case, device, shape): y1 = flow.randn(flow.Size(shape), device=flow.device(device)) y2 = flow.randn(flow.Size(shape), device=flow.device(device)) test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy())) test_case.assertTrue(shape == y1.shape) @flow.unittest.skip_unless_1n1d() class TestRandnModule(flow.unittest.TestCase): def test_global_naive(test_case): placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x = flow.randn(16, 16, placement=placement, sbp=sbp) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def test_randn(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_randn, _test_different_dtype, _test_backward, _test_with_generator, _test_randn_tuple_shape, _test_randn_with_flow_size, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_0d_randn(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_0d_rand] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 0, 4), (2, 0, 2)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_half_randn(test_case): for device in ["cuda", "cpu"]: x = flow.randn(2, 3, dtype=flow.float16, device=flow.device(device)) test_case.assertTrue(x.dtype == flow.float16) test_case.assertTrue(x.shape == flow.Size((2, 3))) # Just check if `layout` param in api is available, there's no related implementation about it # TODO(WangYi): remove this test when randn **really** supports `layout` def test_randn_layout_param(test_case): x = flow.randn(2, 3, layout=flow.strided) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestRandnOnNonDefaultDevice(flow.unittest.TestCase): def test_non_default_device(test_case): x = flow.randn(2, 3, device="cuda:1") test_case.assertEqual(x.device, flow.device("cuda:1")) def test_with_generator(test_case): gen = flow.Generator("cuda") x = flow.randn(2, 3, device="cuda", generator=gen) test_case.assertEqual(x.device, flow.device(f"cuda:{flow.env.get_rank()}")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_randn_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_randn_like(test_case, device, shape): x = flow.randn(shape) y = flow.randn_like(x, device=flow.device(device)) test_case.assertTrue(x.shape == y.shape) def _test_0d_randn_like(test_case, device, shape): x = flow.randn(shape) y = flow.randn_like(x, device=flow.device(device)) test_case.assertTrue(x.shape == y.shape) def _test_different_dtype(test_case, device, shape): for dtype in [ flow.float16, flow.float32, flow.float64, flow.double, ]: x = flow.randn(shape, dtype=dtype) y = flow.randn_like(x, dtype=dtype, device=flow.device(device)) test_case.assertTrue(x.shape == y.shape) def _test_with_generator(test_case, device, shape): gen = flow.Generator() gen.manual_seed(0) x = flow.randn(shape) y1 = flow.randn_like( x, dtype=flow.float32, device=flow.device(device), generator=gen ) gen.manual_seed(0) x = flow.randn(shape) y2 = flow.randn_like( x, dtype=flow.float32, device=flow.device(device), generator=gen ) test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)) def _test_0rank(test_case, device, shape): x = flow.randn(shape) y = flow.randn_like(x, device=flow.device(device)) test_case.assertTrue(x.shape == y.shape) @flow.unittest.skip_unless_1n1d() class TestRandIntLike(flow.unittest.TestCase): def test_global_different_types(test_case): for dtype in [ flow.float16, flow.float32, flow.float64, flow.double, ]: placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x_ = flow.randn((10, 1), dtype=dtype) x = flow.randn_like(x_, placement=placement, sbp=sbp, dtype=dtype) test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def test_randn_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_randn_like, _test_different_dtype, _test_with_generator, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_0d_randn_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_0d_randn_like] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 0, 4), (2, 0, 2)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_0rank_randn_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_0rank] arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [()] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestRandIntLikeOnNonDefaultDevice(flow.unittest.TestCase): def test_non_default_device(test_case): x_ = flow.randn((2, 3)) x = flow.randn_like(x_, device="cuda:1") test_case.assertEqual(x.device, flow.device("cuda:1")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_random_generator_and_seed.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import numpy as np import inspect import types import unittest import oneflow as flow import oneflow.nn as nn import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgDict # y1 = rand_op1(x) # y2 = rand_op2(x) # rand_op1 and rand_op2 should have different seed in graph, then lead to different result def _inspect_rand_op_and_args(rand_op, **kwargs): if inspect.isclass(rand_op) and issubclass(rand_op, nn.Module): init_method_signature = inspect.signature(rand_op.__init__) module_init_args = dict() for arg_name in list(init_method_signature.parameters.keys())[1:]: if arg_name in kwargs: module_init_args[arg_name] = kwargs.pop(arg_name) module_instance = rand_op(**module_init_args) return module_instance, kwargs if isinstance(rand_op, types.BuiltinFunctionType): return rand_op, kwargs if inspect.isfunction(rand_op): return rand_op, kwargs raise ValueError(f"invalid rand_op {rand_op}, type: {type(rand_op)}") def _test_rand_op_unidentical(test_case, rand_op, input=None, **kwargs): rand_op1, kwargs1 = _inspect_rand_op_and_args(rand_op, **kwargs) rand_op2, kwargs2 = _inspect_rand_op_and_args(rand_op, **kwargs) if input is None: result1 = rand_op1(**kwargs1) result2 = rand_op2(**kwargs2) else: x1 = input x2 = input.clone() result1 = rand_op1(x1, **kwargs1) result2 = rand_op2(x2, **kwargs2) if isinstance(result1, (list, tuple)): result1 = result1[0] if isinstance(result2, (list, tuple)): result2 = result2[0] test_case.assertFalse( np.allclose(result1.numpy(), result2.numpy()), f"\ninput:\n{input}\result1:\n{result1}\result2:\n{result2}", ) def _test_global_rand_op_with_split(test_case, rand_op, input=None, **kwargs): rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs) ranks = np.array(range(flow.env.get_world_size())) if input is None: device = kwargs.pop("device", None) placement = flow.placement(device, ranks) y = rand_op(placement=placement, sbp=flow.sbp.split(0), **kwargs) else: x = flow.concat([input, input], dim=0) placement = flow.placement(input.device.type, ranks) # local to broadcast global x_broadcast = x.to_global( placement=placement, sbp=flow.sbp.broadcast(), copy=True ) x_split = x_broadcast.to_global(sbp=flow.sbp.split(0)) y = rand_op(x_split, **kwargs) if isinstance(y, (list, tuple)): y = y[0] y_broadcast = y.to_global(placement=placement, sbp=flow.sbp.broadcast()) half = y_broadcast.shape[0] // 2 first_half = y_broadcast[0:half] second_half = y_broadcast[half:] test_case.assertFalse(np.allclose(first_half.numpy(), second_half.numpy())) def _test_global_rand_op_with_broadcast(test_case, rand_op, input=None, **kwargs): rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs) ranks = np.array(range(flow.env.get_world_size())) if input is None: device = kwargs.pop("device", "cpu") placement = flow.placement(device, ranks) y = rand_op(placement=placement, sbp=flow.sbp.broadcast(), **kwargs) else: placement = flow.placement(input.device.type, ranks) # local to broadcast global x = input.to_global(placement=placement, sbp=flow.sbp.broadcast(), copy=True) y = rand_op(x, **kwargs) if isinstance(y, (list, tuple)): y_local = 
y[0].to_local() else: y_local = y.to_local() y_all_ranks = y_local.to_global(placement=placement, sbp=flow.sbp.split(0)) y_allgather = y_all_ranks.to_global(sbp=flow.sbp.broadcast()) half = y_allgather.shape[0] // 2 first_half = y_allgather[0:half] second_half = y_allgather[half:] test_case.assertTrue(np.allclose(first_half.numpy(), second_half.numpy())) @flow.unittest.skip_unless_1n1d() class TestRandOpUnidentical(oneflow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_usual_rand_op(self): for device in ("cpu", "cuda"): x = flow.randn(4, 16, device=device) _test_rand_op_unidentical(self, nn.Dropout, x, p=0.5) _test_rand_op_unidentical(self, flow._C.rrelu, x, training=True) _test_rand_op_unidentical(self, nn.init.uniform_, x) _test_rand_op_unidentical(self, flow._C.exponential_, x) x1 = flow.rand(4, 16, device=device) _test_rand_op_unidentical( self, flow.multinomial, x1, num_samples=16, replacement=True ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_source_rand_op(self): shape = (4, 16) for device in ("cpu", "cuda"): _test_rand_op_unidentical(self, flow.rand, size=shape, device=device) _test_rand_op_unidentical( self, flow.normal, mean=0.0, std=1.0, size=shape, device=device ) _test_rand_op_unidentical( self, flow.randint, low=0, high=10, size=shape, device=device ) _test_rand_op_unidentical(self, flow.randperm, n=32, device=device) def test_bernoulli(self): x1 = flow.randn(4, 16) _test_rand_op_unidentical(self, flow.bernoulli, x1, p=0.5) x2 = flow.rand(4, 16) _test_rand_op_unidentical(self, flow.bernoulli, x2) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_random_mask_like(self): x = flow.randn(4, 16, 64).to("cuda") _test_rand_op_unidentical( self, flow._C.fused_scale_tril_softmax_mask_scale, x, p=0.1, diagonal=2, tril_scale_value=-1000, ) @flow.unittest.skip_unless_1n2d() class TestGlobalRandOp(oneflow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 4 times in past week") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_usual_rand_op_with_split(self): for device in ("cpu", "cuda"): x = flow.randn(2, 4, device=device) _test_global_rand_op_with_split(self, nn.Dropout, x, p=0.5) _test_global_rand_op_with_split(self, flow._C.rrelu, x, training=True) _test_global_rand_op_with_split(self, nn.init.uniform_, x) _test_global_rand_op_with_split(self, flow._C.exponential_, x) x1 = flow.rand(2, 8, device=device) _test_global_rand_op_with_split( self, flow.multinomial, x1, num_samples=8, replacement=True ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_usual_rand_op_with_broadcast(self): for device in ("cpu", "cuda"): x = flow.randn(2, 4, device=device) _test_global_rand_op_with_broadcast(self, nn.Dropout, x, p=0.5) _test_global_rand_op_with_broadcast(self, flow._C.rrelu, x, training=True) _test_global_rand_op_with_broadcast(self, nn.init.uniform_, x) _test_global_rand_op_with_broadcast(self, flow._C.exponential_, x) x1 = flow.rand(2, 8, device=device) _test_global_rand_op_with_broadcast( self, flow.multinomial, x1, num_samples=8, replacement=True ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_source_rand_op_with_split(self): shape = (4, 4) for device in ("cpu", "cuda"): _test_global_rand_op_with_split(self, flow.rand, size=shape, device=device) _test_global_rand_op_with_split( self, flow.normal, mean=0.0, std=1.0, size=shape, 
device=device ) _test_global_rand_op_with_split( self, flow.randint, low=0, high=10, size=shape, device=device ) _test_global_rand_op_with_split(self, flow.randperm, n=32, device=device) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_source_rand_op_with_broadcast(self): shape = (4, 4) for device in ("cpu", "cuda"): _test_global_rand_op_with_broadcast( self, flow.rand, size=shape, device=device ) _test_global_rand_op_with_broadcast( self, flow.normal, mean=0.0, std=1.0, size=shape, device=device ) _test_global_rand_op_with_broadcast( self, flow.randint, low=0, high=10, size=shape, device=device ) _test_global_rand_op_with_broadcast( self, flow.randperm, n=32, device=device ) @unittest.skip("skip for now, becase it failed 4 times in past week") def test_bernoulli_with_split(self): x1 = flow.randn(2, 8) _test_global_rand_op_with_split(self, flow.bernoulli, x1, p=0.5) x2 = flow.rand(2, 8) _test_global_rand_op_with_split(self, flow.bernoulli, x2) def test_bernoulli_with_broadcast(self): x1 = flow.randn(2, 8) _test_global_rand_op_with_broadcast(self, flow.bernoulli, x1, p=0.5) x2 = flow.rand(2, 8) _test_global_rand_op_with_broadcast(self, flow.bernoulli, x2) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_random_mask_like_with_split(self): x = flow.randn(2, 16, 64).to("cuda") _test_global_rand_op_with_split( self, flow._C.fused_scale_tril_softmax_mask_scale, x, p=0.1, diagonal=0, tril_scale_value=-1000, ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_random_mask_like_with_broadcast(self): x = flow.randn(2, 16, 64).to("cuda") _test_global_rand_op_with_broadcast( self, flow._C.fused_scale_tril_softmax_mask_scale, x, p=0.2, diagonal=1, tril_scale_value=-100, ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_randperm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import oneflow as flow from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import numpy as np import unittest def _test_randperm_with_generator(test_case, N, device, dtype): generator = flow.Generator() generator.manual_seed(0) y_1 = flow.randperm(N, device=device, dtype=dtype, generator=generator) generator.manual_seed(0) y_2 = flow.randperm(N, device=device, dtype=dtype, generator=generator) test_case.assertTrue(np.allclose(y_1.numpy(), y_2.numpy())) test_case.assertTrue( y_1.device == flow.device(device) and y_2.device == flow.device(device) ) test_case.assertTrue(y_1.dtype == dtype and y_2.dtype == dtype) def _test_randperm_backward(test_case, N, device, dtype): dtype = flow.float32 # fix dtype here as reduce_sum doesn't support all dtypes yet x = flow.randperm(N, device=device, dtype=dtype) x.requires_grad = True y = x.sum() y.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(N), 1e-05, 1e-05)) def _test_randperm_randomness(test_case, N, device, dtype): n = np.random.randint(100, 1000) x1 = flow.randperm(n, device=device) x2 = flow.randperm(n, device=device) test_case.assertFalse(np.all(x1.numpy() == x2.numpy())) def _test_randperm_large_seq_randomness(test_case, N, device, dtype): n = 65536 x1 = flow.randperm(n, device=device) x2 = flow.randperm(n, device=device) test_case.assertFalse(np.all(x1.numpy() == x2.numpy())) @flow.unittest.skip_unless_1n1d() class Testrandperm(flow.unittest.TestCase): def test_global_naive(test_case): placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x = flow.randperm(10, placement=placement, sbp=sbp) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def test_global_different_types(test_case): for dtype in [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ]: placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,) x = flow.randperm(10, placement=placement, sbp=sbp, dtype=dtype) test_case.assertEqual(x.dtype, dtype) test_case.assertEqual(x.sbp, sbp) test_case.assertEqual(x.placement, placement) def test_randperm(test_case): arg_dict = OrderedDict() arg_dict["test_functions"] = [ _test_randperm_with_generator, _test_randperm_randomness, _test_randperm_large_seq_randomness, ] arg_dict["N"] = [i for i in range(10, 100, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, ] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) def test_randperm_backward(test_case): arg_dict = OrderedDict() arg_dict["test_functions"] = [ _test_randperm_backward, ] arg_dict["N"] = [i for i in range(10, 100, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = [flow.float32, flow.float64] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(auto_backward=False, check_graph=True) def test_auto_1(test_case): device = random_device() y = torch.randperm(1, device=device) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_auto_0(test_case): device = random_device() y = torch.randperm(0, device=device) return y @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestRandpermOnNonDefaultDevice(flow.unittest.TestCase): def test_non_default_device(test_case): x = flow.randperm(3, device="cuda:1") test_case.assertEqual(x.device, flow.device("cuda:1")) if __name__ == "__main__": 
unittest.main() ================================================ FILE: python/oneflow/test/modules/test_reciprocal.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestReciprocalModule(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_reciprocal_list_with_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) ).to(device) y = torch.reciprocal(x) return y @autotest(check_graph=True) def test_flow_reciprocal_list_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.reciprocal(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_reduce.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_reduce(test_case, dst, device): if flow.env.get_rank() == 0: np_arr = np.array( [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 1: np_arr = np.array( [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32, ) elif flow.env.get_rank() == 2: np_arr = np.array( [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32, ) elif flow.env.get_rank() == 3: np_arr = np.array( [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32, ) x = flow.tensor(np_arr, device=device, dtype=flow.float32) flow._C.local_reduce(x, dst=dst) if flow.env.get_rank() == dst: test_case.assertTrue( np.allclose( x.numpy(), np.array( [ [24, 26, 25, 43], [20, 28, 35, 10], [15, 21, 27, 20], [21, 31, 30, 12], ], dtype=np.float32, ), ) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class TestReduce(flow.unittest.TestCase): def test_reduce(test_case): arg_dict = OrderedDict() arg_dict["dst"] = [0, 1, 2, 3] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_reduce(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_reduce_sum_like.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_reduce_sum_like(test_case, device): input = flow.tensor( np.ones(shape=(3, 3, 3), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(3, 1, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(1, 2)) np_out = np.full(shape=like_tensor.shape, fill_value=9) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_reduce_sum_like_one(test_case, device): input = flow.tensor( np.ones(shape=(1, 2, 3), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(1, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(1, 2)) np_out = np.full(like_tensor.shape, 6) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_reduce_sum_like_different_dim(test_case, device): input = flow.tensor( np.ones(shape=(2, 3, 4), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(3, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(0, 2)) np_out = np.full(like_tensor.shape, 8) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_reduce_sum_like_different_dim_with_input_axisvec(test_case, device): input = flow.tensor( np.ones(shape=(1, 5, 6, 1, 6), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(1, 5, 6), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(3, 4)) np_out = np.full(like_tensor.shape, 6) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_reduce_sum_like_3dim(test_case, device): input = flow.tensor( np.ones(shape=(3, 3, 2), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(1, 3, 2), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(0,)) np_out = np.full(like_tensor.shape, 3) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_reduce_sum_like_4dim(test_case, device): input = flow.tensor( np.ones(shape=(3, 3, 2, 3), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) like_tensor = flow.tensor( np.ones(shape=(1, 3, 2, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), ) of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(0, 3)) np_out = np.full(like_tensor.shape, 9) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_reduce_sum_like_backward(test_case, device): input = flow.tensor( np.ones(shape=(3, 3, 3), dtype=np.float32), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) like_tensor = flow.tensor( np.ones(shape=(3, 1, 1), dtype=np.float32), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(1, 2)) of_out = of_out.sum() of_out.backward() np_grad = np.full(input.shape, 1.0) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 
1e-05)) @flow.unittest.skip_unless_1n1d() class TestReduceSumLike(flow.unittest.TestCase): def test_reduce_sum_like(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_reduce_sum_like, _test_reduce_sum_like_one, _test_reduce_sum_like_different_dim, _test_reduce_sum_like_different_dim_with_input_axisvec, _test_reduce_sum_like_3dim, _test_reduce_sum_like_4dim, _test_reduce_sum_like_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_reflection_pad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import ( Array2Numpy, FlattenArray, GenArgList, Index2Coordinate, ) import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def gen_numpy_test_sample(input, padding): (c_idx, h_idx, w_idx) = (1, 2, 3) pad_left = padding[0] pad_right = padding[1] pad_top = padding[2] pad_bottom = padding[3] pad_shape = ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)) def _np_reflection_pad2d(input, pad_shape): numpy_reflect = np.pad(input, pad_shape, "reflect") return numpy_reflect def _np_reflection_pad2d_grad(src, dest): (dx_height, dx_width) = (input.shape[h_idx], input.shape[w_idx]) (dy_height, dy_width) = (output.shape[h_idx], output.shape[w_idx]) numpy_src = np.ones(src.shape, np.int32) numpy_dest = np.zeros(dest.shape, np.int32) array_src = FlattenArray(numpy_src) array_dest = FlattenArray(numpy_dest) src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx] dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx] elements_num = src.shape[0] * src_num for iter_n in range(elements_num): coords = Index2Coordinate(iter_n, src.shape) (n, c, i, j) = (coords[0], coords[c_idx], coords[h_idx], coords[w_idx]) ip_x = ip_y = 0 if j < pad_left: ip_x = pad_left * 2 - j elif j >= pad_left and j < dx_width + pad_left: ip_x = j else: ip_x = (dx_width + pad_left - 1) * 2 - j if i < pad_top: ip_y = pad_top * 2 - i elif i >= pad_top and i < dx_height + pad_top: ip_y = i else: ip_y = (dx_height + pad_top - 1) * 2 - i ip_x = ip_x - pad_left ip_y = ip_y - pad_top src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j dest_index = ( n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x ) array_dest[dest_index] += array_src[src_index] numpy_dest = Array2Numpy(array_dest, dest.shape) return numpy_dest output = _np_reflection_pad2d(input, pad_shape) grad = _np_reflection_pad2d_grad(output, input) return (output, grad) def _test_reflection_pad2d(test_case, shape, padding, device): np_input = np.random.randn(*shape).astype(np.float32) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) if isinstance(padding, 
int): boundary = [padding, padding, padding, padding] elif isinstance(padding, tuple) and len(padding) == 4: boundary = [padding[0], padding[1], padding[2], padding[3]] else: raise ValueError("padding must be in or list or tuple!") (np_out, np_grad) = gen_numpy_test_sample(np_input, boundary) layer = flow.nn.ReflectionPad2d(padding=padding) of_out = layer(of_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestReflectionPadModule(flow.unittest.TestCase): def test_reflection_pad2d(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2, 3, 4), (8, 3, 4, 4)] arg_dict["padding"] = [2, (1, 1, 2, 2)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_reflection_pad2d(test_case, *arg) @autotest(n=5) def test_reflection_pad_1d_with_3d_input(test_case): c = random(1, 6).to(int) w = random(1, 6).to(int) m = torch.nn.ReflectionPad1d(padding=random(low=0, high=5).to(int)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim1=c, dim2=w).to(device) y = m(x) return y @autotest(n=5) def test_reflection_pad_1d_with_2d_input(test_case): w = random(1, 6).to(int) m = torch.nn.ReflectionPad1d(padding=random(low=0, high=5).to(int)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=2, dim1=w).to(device) y = m(x) return y @autotest(n=5) def test_reflection_pad_2d_with_random_data(test_case): c = random(1, 6).to(int) h = random(1, 6).to(int) w = random(1, 6).to(int) m = torch.nn.ReflectionPad2d(padding=random(low=0, high=5).to(int)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device) y = m(x) return y @autotest(n=5) def test_functional_reflection_pad_1d_with_random_data(test_case): c = random(1, 6).to(int) w = random(1, 6).to(int) pad = [1, 2] device = random_device() x = random_tensor(ndim=3, dim1=c, dim2=w).to(device) y = torch.nn.functional.pad(input=x, pad=pad, mode="reflect") return y @autotest(n=5) def test_functional_reflection_pad_2d_with_random_data(test_case): c = random(1, 6).to(int) h = random(1, 6).to(int) w = random(1, 6).to(int) pad = [0, 1, 2, 3] device = random_device() x = random_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device) y = torch.nn.functional.pad(input=x, pad=pad, mode="reflect") return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_repeat.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestRepeat(flow.unittest.TestCase): @autotest(n=10) def test_flow_tensor_repeat_with_random_data(test_case): x = random_tensor(ndim=2, dim0=1, dim1=2) sizes = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) y = x.repeat(sizes) return y @autotest(n=10, auto_backward=False) def test_flow_tensor_repeat_bool_with_random_data(test_case): x = random_tensor(ndim=2, dim0=1, dim1=2).to(torch.bool) sizes = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) y = x.repeat(sizes) return y @autotest(n=10) def test_flow_tensor_repeat_with_0dim_data(test_case): x = random_tensor(ndim=0) sizes = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) y = x.repeat(sizes) return y @autotest(n=5, auto_backward=False) def test_complicated_repeat_case(test_case): x = torch.ones(224, 224) y = torch.triu(x, diagonal=1).repeat(32, 1, 1) z = y.byte() return z @autotest(n=5) def test_flow_tensor_0size_with_random_data(test_case): x = random_tensor(ndim=2, dim0=3, dim1=1) sizes = (1, 0) y = x.repeat(sizes) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_repeat_interleave.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import oneflow as flow import oneflow.unittest import torch as torch_original from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestRepeatInterLeave(flow.unittest.TestCase): @autotest(n=5) def test_flow_int_repeat_interleave_dim_none(test_case): x = random_tensor(ndim=2, dim0=1, dim1=2) y = torch.repeat_interleave(x, 2) return y @autotest(n=5) def test_flow_int_repeat_interleave_with_dim(test_case): x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3) dim = random(low=0, high=2).to(int) y = torch.repeat_interleave(x, 2, dim) return y @autotest(n=5) def test_flow_tensor_repeat_interleave_dim(test_case): x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3) y = random_tensor(ndim=1, dim0=2, dtype=int, low=0, high=4) z = torch.repeat_interleave(x, y, 1) return z @autotest(n=5) def test_flow_tensor_repeat_interleave_dim_with_output_size(test_case): x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3) y = random_tensor(ndim=1, dim0=2, dtype=int, low=0, high=4) z = torch.repeat_interleave(x, y, 1, output_size=2) return z def test_flow_tensor_repeat_interleave_0size_tensor(test_case): np_arr = np.array( [ [[0.8548, 0.0436, 0.7977], [0.1919, 0.4191, 0.2186]], [[0.4741, 0.8896, 0.6859], [0.5223, 0.7803, 0.1134]], ] ) x_torch = torch_original.tensor(np_arr) x_torch.requires_grad = True y_torch = torch_original.tensor([0, 0]) z_torch = torch_original.repeat_interleave(x_torch, y_torch, 1) z_torch.sum().backward() x_flow = flow.tensor(np_arr) x_flow.requires_grad = True y_flow = flow.tensor([0, 0]) z_flow = flow.repeat_interleave(x_flow, y_flow, 1) z_flow.sum().backward() test_case.assertTrue(np.array_equal(x_torch.grad.numpy(), x_flow.grad.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_replication_pad.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import ( Array2Numpy, FlattenArray, GenArgList, Index2Coordinate, ) import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _np_replication_pad2d_grad(src, dest, padding): (c_idx, h_idx, w_idx) = (1, 2, 3) pad_left = padding[0] pad_right = padding[1] pad_top = padding[2] pad_bottom = padding[3] (dx_height, dx_width) = (dest.shape[h_idx], dest.shape[w_idx]) (dy_height, dy_width) = (src.shape[h_idx], src.shape[w_idx]) numpy_src = np.ones(src.shape, np.int32) numpy_dest = np.zeros(dest.shape, np.int32) array_src = FlattenArray(numpy_src) array_dest = FlattenArray(numpy_dest) src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx] dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx] elements_num = src.shape[0] * src_num for iter_n in range(elements_num): coords = Index2Coordinate(iter_n, src.shape) (n, c, i, j) = (coords[0], coords[c_idx], coords[h_idx], coords[w_idx]) ip_x = ip_y = 0 if j < pad_left: ip_x = pad_left elif j >= pad_left and j < dx_width + pad_left: ip_x = j else: ip_x = dx_width + pad_left - 1 if i < pad_top: ip_y = pad_top elif i >= pad_top and i < dx_height + pad_top: ip_y = i else: ip_y = dx_height + pad_top - 1 ip_x = ip_x - pad_left ip_y = ip_y - pad_top src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x array_dest[dest_index] += array_src[src_index] numpy_dest = Array2Numpy(array_dest, dest.shape) return numpy_dest def _test_ReplicationPad2d(test_case, shape, padding, device): np_input = np.random.random(shape).astype(np.float32) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) if isinstance(padding, int): np_boundary = ((0, 0), (0, 0), (padding, padding), (padding, padding)) boundry = [padding, padding, padding, padding] elif isinstance(padding, (tuple, int)) and len(padding) == 4: np_boundary = ( (0, 0), (0, 0), (padding[2], padding[3]), (padding[0], padding[1]), ) boundry = [padding[0], padding[1], padding[2], padding[3]] else: raise ValueError("padding must be in or list or tuple!") layer = flow.nn.ReplicationPad2d(padding=padding) of_out = layer(of_input) np_out = np.pad(np_input, np_boundary, mode="edge") test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_out_grad = _np_replication_pad2d_grad(np_out, np_input, boundry) test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_out_grad, 0.001, 0.001)) @flow.unittest.skip_unless_1n1d() class TestReplicationPadModule(flow.unittest.TestCase): def test_ReplicationPad2d(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2, 3, 4), (8, 3, 4, 4)] arg_dict["padding"] = [2, (1, 1, 2, 2)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_ReplicationPad2d(test_case, *arg) @autotest(n=5) def test_replication_pad1d_with_3d_input(test_case): c = random(1, 6).to(int) w = random(1, 6).to(int) pad = random(low=0, high=5).to(int) m = torch.nn.ReplicationPad1d(padding=pad) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=3, dim1=c, dim2=w).to(device) y = m(x) return y @autotest(n=5) def test_replication_pad1d_with_2d_input(test_case): w = random(1, 6).to(int) m = torch.nn.ReplicationPad1d(padding=random(low=0, high=5).to(int)) m.train(random()) device = random_device() 
m.to(device) x = random_tensor(ndim=2, dim1=w).to(device) y = m(x) return y @autotest(n=5) def test_replication_pad2d_with_random_data(test_case): c = random(1, 6).to(int) h = random(1, 6).to(int) w = random(1, 6).to(int) m = torch.nn.ReplicationPad2d(padding=random(low=0, high=5)) m.train(random()) device = random_device() m.to(device) x = random_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device) y = m(x) return y @autotest(n=5) def test_functional_replication_pad_1d_with_random_data(test_case): c = random(1, 6).to(int) w = random(1, 6).to(int) pad = [0, 1] device = random_device() x = random_tensor(ndim=3, dim1=c, dim2=w).to(device) y = torch.nn.functional.pad(input=x, pad=pad, mode="replicate") return y @autotest(n=5) def test_functional_replication_pad_2d_with_random_data(test_case): c = random(1, 6).to(int) h = random(1, 6).to(int) w = random(1, 6).to(int) pad = [0, 1, 2, 3] device = random_device() x = random_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device) y = torch.nn.functional.pad(input=x, pad=pad, mode="replicate") return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_reshape.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_reshape(test_case, device): x = np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ).astype(np.float32) input = flow.tensor(x, dtype=flow.float32, device=flow.device(device)) of_shape = flow.reshape(input, shape=[2, 2, 2, -1]).numpy().shape np_shape = (2, 2, 2, 2) test_case.assertTrue(np.array_equal(of_shape, np_shape)) def _test_reshape_tuple(test_case, device): x = np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ).astype(np.float32) input = flow.tensor(x, dtype=flow.float32, device=flow.device(device)) of_shape = flow.reshape(input, shape=(2, 2, 2, -1)).numpy().shape np_shape = (2, 2, 2, 2) test_case.assertTrue(np.array_equal(of_shape, np_shape)) def _test_reshape_backward(test_case, device): x = np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ).astype(np.float32) input = flow.tensor( x, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = flow.reshape(input, shape=[2, 2, 2, -1]).sum() of_out.backward() np_grad = np.array( [ [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], ] ) test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001)) def _test_reshape_scalar(test_case, device): x = flow.tensor(2.0, device=flow.device(device)) test_case.assertTrue(np.array_equal(x.shape, ())) a = flow.reshape(x, (1,)) test_case.assertTrue(np.array_equal(a.shape, (1,))) b = flow.reshape(x, (1, 1, 1, 1,)) test_case.assertTrue(np.array_equal(b.shape, (1, 1, 1, 1))) c = flow.reshape(b, ()) test_case.assertTrue(np.array_equal(c.shape, ())) d = flow.reshape(x, ()) test_case.assertTrue(np.array_equal(d.shape, ())) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_reshape(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_reshape, _test_reshape_tuple, _test_reshape_backward, _test_reshape_scalar, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_reshape_flow_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.reshape(x, shape=(-1,)) return y @autotest(n=5) def test_reshape_flow_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.reshape(x, shape=(-1,)) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_reshape_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 0, 3).to(device) y = torch.reshape( x, shape=(random(0, 5).to(int).value(), 0, random(0, 5).to(int).value()) ) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_reshape_flow_bool_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device=device, dtype=torch.bool) y = torch.reshape(x, shape=(-1,)) return y @autotest(n=2, auto_backward=False, check_graph=True) def test_reshape_like(test_case): device = random_device() shape = [random(1, 5).to(int).value() for _ in range(4)] like_shape = np.random.choice( np.array(shape), len(shape), replace=False ).tolist() x = ( random_tensor(4, *shape, requires_grad=False) .to(device=device) .requires_grad_() ) y = ( random_tensor(4, *like_shape) .to(device=device) 
.requires_grad_(random_bool()) ) # forward of_z = flow._C.reshape_like(x.oneflow, y.oneflow) torch_z = torch.pytorch.reshape(x.pytorch, like_shape) test_case.assertTrue( np.array_equal(of_z.numpy(), torch_z.detach().cpu().numpy()) ) # backward of_z.sum().backward() torch_z.sum().backward() test_case.assertTrue( np.array_equal( x.grad.oneflow.numpy(), x.grad.pytorch.detach().cpu().numpy() ) ) @profile(torch.reshape) def profile_reshape(test_case): torch.reshape(torch.ones(50, 20), (20, 50)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_reshape_sbp.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import oneflow.unittest import oneflow as flow @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestReshapeSbp(flow.unittest.TestCase): def test_reshape_sbp(test_case): input = flow.rand( 9, 9, 8, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.split(0) ) output = input.view(81, 8) test_case.assertTrue(output.sbp[0] != flow.sbp.split(0)) @flow.unittest.skip_unless_1n4d() class TestReshapeNdSbp(flow.unittest.TestCase): def test_reshape_nd_sbp(test_case): in_shape = (8, 4) out_shape = (2, 4, 4) P = flow.placement("cpu", [[0, 1], [2, 3]]) in_sbp = [flow.sbp.split(0), flow.sbp.split(0)] input = flow.rand(*in_shape, placement=P, sbp=in_sbp) output = input.view(*out_shape) out_sbp = output.sbp test_case.assertTrue(len(in_sbp) == len(out_sbp)) test_case.assertTrue(out_sbp[0] == flow.sbp.split(0)) test_case.assertTrue(out_sbp[1] == flow.sbp.split(1)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_resnet_load_torch_weight_compatibile.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
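The compatibility test that follows converts a PyTorch state_dict to NumPy arrays before handing it to the OneFlow model's load_state_dict. A minimal standalone sketch of that conversion, using nn.Linear purely for illustration (the actual test uses torchvision / flowvision ResNet-18):

import numpy as np
import torch
import oneflow as flow

torch_m = torch.nn.Linear(4, 2)
flow_m = flow.nn.Linear(4, 2)

# OneFlow's load_state_dict accepts NumPy arrays, so detach each torch tensor first
numpy_state = {k: v.detach().cpu().numpy() for k, v in torch_m.state_dict().items()}
flow_m.load_state_dict(numpy_state)

x = np.random.randn(1, 4).astype(np.float32)
assert np.allclose(torch_m(torch.from_numpy(x)).detach().numpy(),
                   flow_m(flow.tensor(x)).numpy(), atol=1e-6)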
""" import unittest import numpy as np import torch import torchvision.models as models_torch import flowvision.models as models_flow import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestResNet18LoadWeightCompatibile(flow.unittest.TestCase): def test_resnet18_load_weight_compatibile(test_case): resnet18_torch = models_torch.resnet18(pretrained=True) resnet18_flow = models_flow.resnet18() parameters = resnet18_torch.state_dict() for key, value in parameters.items(): val = value.detach().cpu().numpy() parameters[key] = val resnet18_flow.load_state_dict(parameters) torch_input = torch.randn(1, 3, 224, 224) flow_input = flow.tensor(torch_input.cpu().numpy()) torch_output = resnet18_torch(torch_input) flow_output = resnet18_flow(flow_input) test_case.assertTrue( np.allclose( torch_output.detach().numpy(), flow_output.numpy(), atol=1e-4, rtol=1e-3 ) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_rmsnorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import numpy as np import unittest from collections import OrderedDict import oneflow as flow import oneflow.unittest import torch def _get_norm_dims(shape, normalized_shape): lpad = len(shape) - len(normalized_shape) assert lpad >= 0 return tuple(range(lpad, len(shape))) def _torch_rmsnorm(x, weight, normalized_shape=None, eps=1e-6): if weight is not None: normalized_shape = weight.shape else: assert normalized_shape is not None norm_dims = _get_norm_dims(x.shape, normalized_shape) root_mean = torch.mean(x * x, dim=norm_dims, keepdim=True) rms = torch.rsqrt(root_mean + eps) normed = x * rms return normed * weight if weight is not None else normed def _test_rmsnorm( test_case, shape, normalized_shape, affine=True, eps=1e-6, dtype=flow.float32, device="cuda", ): np_x = np.random.randn(*shape).astype(np.float32) np_weight = ( np.random.randn(*normalized_shape).astype(np.float32) if affine else None ) torch_dtype = torch.float16 if dtype is flow.float16 else torch.float32 torch_x = torch.tensor(np_x).to(device=device, dtype=torch_dtype) torch_weight = ( torch.tensor(np_weight).to(device=device, dtype=torch_dtype) if affine else None ) torch_x.requires_grad_(True) if affine: torch_weight.requires_grad_(True) torch_y = _torch_rmsnorm(torch_x, torch_weight, normalized_shape, eps) np_rand_init_grad = np.random.randn(*tuple(torch_y.shape)).astype(np.float32) torch_rand_init_grad = torch.tensor(np_rand_init_grad).to( device=device, dtype=torch_dtype ) (torch_y * torch_rand_init_grad).sum().backward() torch_y = torch_y.detach().cpu().numpy() torch_x_grad = torch_x.grad.detach().cpu().numpy() if affine: torch_weight_grad = torch_weight.grad.detach().cpu().numpy() x = flow.tensor(np_x).to(device=device, dtype=dtype) weight = flow.tensor(np_weight).to(device=device, dtype=dtype) if affine else None x.requires_grad_(True) if affine: 
weight.requires_grad_(True) y = flow._C.rms_norm(x, weight, normalized_shape, eps) # np_rand_init_grad = np.random.randn(*tuple(y.shape)).astype(np.float32) rand_init_grad = flow.tensor(np_rand_init_grad).to(device=device, dtype=dtype) (y * rand_init_grad).sum().backward() y = y.detach().cpu().numpy() x_grad = x.grad.detach().cpu().numpy() if affine: weight_grad = weight.grad.detach().cpu().numpy() def compare(a, b, a_name, b_name, atol=1e-5, rtol=1e-8): test_case.assertTrue( np.allclose(a, b, atol=atol, rtol=rtol), f"\n{'=' * 80}" f"\n{a_name}:" f"\n{a}" f"\n{'-' * 80}" f"\n{b_name}:" f"\n{b}" f"\n{'-' * 80}" f"\ndiff:" f"\n{a - b}" f"\n{'*' * 80}" f"\nshape={shape}" f"\normalized_shape={normalized_shape}" f"\naffine={affine}" f"\ndtype={dtype}" f"\ndevice={device}" f"\n{a_name} vs. {b_name} max abs diff: {np.max(np.abs(a - b))}", ) if dtype is flow.float16: compare(y, torch_y, "y", "torch_y", 1e-3, 1e-2) compare(x_grad, torch_x_grad, "x_grad", "torch_x_grad", 1e-2, 1e-2) if affine: compare( weight_grad, torch_weight_grad, "weight_grad", "torch_weight_grad", 0.1, 0.1, ) else: compare(y, torch_y, "y", "torch_y") compare(x_grad, torch_x_grad, "x_grad", "torch_x_grad") if affine: compare( weight_grad, torch_weight_grad, "weight_grad", "torch_weight_grad", 1e-5, 1e-4, ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestRMSNorm(flow.unittest.TestCase): def test_real_example(test_case): _test_rmsnorm( test_case, shape=[512, 4, 768], normalized_shape=[768], affine=True, dtype=flow.float16, device="cuda", ) def test_no_affine(test_case): _test_rmsnorm( test_case, shape=[4, 16], normalized_shape=[16], affine=False, ) def test_warp_impl(test_case): _test_rmsnorm( test_case, shape=[32, 1024], normalized_shape=[1024], dtype=flow.float16, ) _test_rmsnorm(test_case, shape=[16, 512], normalized_shape=[512]) _test_rmsnorm(test_case, shape=[15, 512], normalized_shape=[512]) _test_rmsnorm(test_case, shape=[16, 511], normalized_shape=[511]) _test_rmsnorm(test_case, shape=[13, 499], normalized_shape=[499]) def test_block_smem_impl(test_case): _test_rmsnorm( test_case, shape=[16, 2048], normalized_shape=[2048], dtype=flow.float16, ) _test_rmsnorm(test_case, shape=[8, 1536], normalized_shape=[1536]) _test_rmsnorm(test_case, shape=[8, 2048], normalized_shape=[2048]) _test_rmsnorm(test_case, shape=[7, 1536], normalized_shape=[1536]) _test_rmsnorm(test_case, shape=[8, 1533], normalized_shape=[1533]) _test_rmsnorm(test_case, shape=[7, 1533], normalized_shape=[1533]) @unittest.skip("skip for now, becase it failed 4 times in past week") def test_block_uncached_impl(test_case): _test_rmsnorm( test_case, shape=[16, 1024 * 1024], normalized_shape=[1024 * 1024], dtype=flow.float16, ) _test_rmsnorm( test_case, shape=[8, 1024], normalized_shape=[1024], dtype=flow.double ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_roc_auc_score.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgList from sklearn.metrics import roc_auc_score def _test_roc_auc_score(test_case, label_dtype, pred_dtype): inputs = [ {"label": [0, 0, 1, 1], "pred": [0.1, 0.4, 0.35, 0.8], "score": 0.75}, {"label": [0, 1, 0, 1], "pred": [0.5, 0.5, 0.5, 0.5], "score": 0.5}, ] for data in inputs: label = flow.tensor(data["label"], dtype=label_dtype) pred = flow.tensor(data["pred"], dtype=pred_dtype) of_score = flow.roc_auc_score(label, pred) test_case.assertTrue(np.allclose(of_score.numpy()[0], data["score"])) def _compare_roc_auc_score(test_case, label_dtype, pred_dtype): n_examples = 16384 label = np.random.randint(0, 2, n_examples) pred = np.random.random(n_examples) score = roc_auc_score(label, pred) label = flow.tensor(label, dtype=label_dtype) pred = flow.tensor(pred, dtype=pred_dtype) of_score = flow.roc_auc_score(label, pred) test_case.assertTrue(np.allclose(of_score.numpy()[0], score)) @flow.unittest.skip_unless_1n1d() class TestNMS(flow.unittest.TestCase): def test_roc_auc_score(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_roc_auc_score, _compare_roc_auc_score] arg_dict["label_dtype"] = [ flow.double, flow.int32, flow.float, flow.int64, flow.int8, flow.uint8, ] arg_dict["pred_dtype"] = [flow.float] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_roi_align.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
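The pure-NumPy reference in the file below is built on 2-D bilinear interpolation; a tiny worked instance of the four-neighbour weighting it computes (sample point and values picked arbitrarily). Note also the argument order the calls below suggest for flow.roi_align: (input, rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned), whereas the NumPy helper takes the pooled sizes before spatial_scale.

import numpy as np

data = np.array([[0.0, 1.0],
                 [2.0, 3.0]])
# sample at (y, x) = (0.25, 0.75): each corner contributes (1 - |dy|) * (1 - |dx|)
y, x = 0.25, 0.75
val = (data[0, 0] * (1 - y) * (1 - x) + data[0, 1] * (1 - y) * x
       + data[1, 0] * y * (1 - x) + data[1, 1] * y * x)
assert np.isclose(val, 1.25)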
""" import unittest from collections import OrderedDict import numpy as np import math import oneflow as flow from oneflow.test_utils.test_util import GenArgList input_np = np.array( [ [ [ [ 0.33840093, 1.1469249, 1.0410756, -0.8350606, -1.782742, -0.00350855, -0.45829752, -1.0764053, ], [ -0.4169678, -0.07322863, 1.5186151, 1.3238515, -0.3002863, 0.90660757, -0.2955834, 1.5069526, ], [ 0.3829125, 1.0149552, -0.5808607, -0.4644214, 1.2142111, 0.668561, 1.0866925, 0.16446872, ], [ 0.14043295, -0.55108964, -0.8154048, 1.1554539, 2.421505, -0.54017824, 0.32610297, -1.0632077, ], [ -0.6218423, 0.6000421, 0.3742695, 0.11130165, 0.9991065, -0.28596586, -0.05164787, 0.07725058, ], [ 0.6141537, 0.2919493, 0.2101646, -0.16639, 1.145933, 0.08825321, 0.9865119, 0.47285828, ], [ -1.5073836, -0.8056736, -0.7402776, -0.9932287, 0.74761075, -0.46474454, -0.22881153, 0.6082243, ], [ 0.8328902, 0.17223845, 0.48917648, -1.6264182, 0.248678, -1.2603166, 1.2644174, 0.06434552, ], ] ], [ [ [ 0.6627289, 0.68173873, 0.17659399, 0.17474514, 0.72995424, -0.47240442, 0.27204773, -0.5277862, ], [ 0.23609516, 0.9604236, 0.78075147, 0.26125216, 0.72746485, 0.04412199, 0.04948105, -0.08477508, ], [ 0.8646437, -0.20755729, 1.0184883, 0.06346282, -0.18039183, 0.56243396, -0.07350786, -1.8523406, ], [ -0.2267861, -1.6466936, 2.1746075, -1.2284307, 0.74488103, -0.13243976, -0.9046582, -2.2992454, ], [ -0.56131303, -0.17723852, -0.6063047, 2.4105318, 0.96672636, -1.8386889, 1.1021106, -0.65429336, ], [ 2.0618255, -0.86972237, -0.59159493, 0.9894253, -0.26607743, -0.395585, -0.44035113, -0.663197, ], [ -0.02398485, -0.04574186, -0.43163615, -0.42599657, -2.751177, -0.35520887, -0.413676, 2.0098279, ], [ 1.5619192, -2.4961088, 0.08771367, -2.289146, 1.0729461, 0.7120767, -0.09780294, -1.6628668, ], ] ], ] ) rois_np = np.array( [ [1.0, 2.0, 1.0324688, 2.5, 3.90168], [1.0, 2.5, 2.8329468, 3.5, 3.2008305], [0.0, 1.0, 1.6188955, 2.0, 0.99051666], [1.0, 1.0, 1.843338, 1.0, 3.9240131], [1.0, 2.0, 2.798994, 3.5, 1.2012959], [0.0, 0.5, 2.7753997, 3.0, 0.8280029], [1.0, 0.5, 2.167975, 2.0, 2.067833], [0.0, 0.5, 2.6843219, 2.0, 3.9924717], [0.0, 2.0, 2.8996983, 3.5, 2.356554], [0.0, 1.5, 0.34730053, 3.0, 2.8540745], [0.0, 0.0, 2.096885, 0.5, 3.357812], [0.0, 1.5, 0.10133362, 3.0, 0.18236923], [1.0, 1.0, 1.609498, 1.5, 3.8893862], [0.0, 1.5, 0.03415012, 1.5, 1.2880297], [0.0, 0.5, 3.9403543, 2.0, 3.8870106], [0.0, 0.0, 3.7515945, 3.5, 0.5866394], [1.0, 1.5, 1.7729645, 2.0, 1.2372265], [1.0, 0.0, 1.5092888, 2.0, 3.1585617], [1.0, 0.0, 2.9033833, 1.5, 1.659832], [1.0, 0.5, 1.9115062, 3.0, 1.066021], [0.0, 1.5, 3.185645, 2.0, 0.20558739], [1.0, 2.0, 0.3081894, 2.5, 2.4888725], [0.0, 0.5, 3.5662794, 3.5, 2.8792458], [1.0, 0.5, 2.556768, 2.5, 2.1553097], [0.0, 1.0, 1.397994, 3.5, 0.77407074], [0.0, 0.5, 3.1722808, 3.5, 2.5378036], [0.0, 0.5, 0.11013985, 3.5, 0.8963146], [0.0, 2.0, 1.1824799, 2.0, 3.2211132], [1.0, 0.0, 3.9227288, 2.0, 2.0894089], [0.0, 1.0, 0.79490566, 1.5, 3.4291687], ] ) input_grad_np = np.array( [ [ [ [ 0.2517704, 1.7398968, 8.248332, 16.302334, 11.048147, 10.059495, 2.800579, 0.24844748, ], [ 0.790752, 3.154358, 13.0182705, 15.519342, 7.0133696, 6.28652, 3.9538488, 0.51601994, ], [ 0.7077478, 3.6854784, 19.228241, 22.597464, 10.153106, 6.2180595, 3.5736852, 0.44621366, ], [ 1.1430397, 2.6666558, 8.699481, 12.510508, 7.6093874, 3.3150473, 1.0373969, 0.08225401, ], [ 7.372374, 3.458156, 6.5517087, 10.535179, 9.493686, 5.800008, 3.2196481, 0.3790145, ], [ 9.979998, 7.723156, 11.384828, 15.13672, 14.71994, 11.550301, 
8.666647, 1.1556869, ], [ 7.4674473, 7.990606, 11.032139, 10.031732, 6.5969977, 5.1203485, 4.1267443, 0.57233953, ], [ 1.9118737, 10.9567, 12.461995, 10.991727, 2.2403586, 0.9002282, 0.74645257, 0.1000254, ], ] ], [ [ [0.0, 0.0, 0.0, 0.2796778, 1.6780672, 0.2796781, 0.0, 0.0], [ 0.02485762, 0.17400333, 0.19886094, 0.94998413, 4.7056007, 0.9251272, 0.02485762, 0.0, ], [ 0.54076296, 2.3330488, 4.2377095, 13.100019, 12.285746, 4.7681584, 1.6636131, 0.18542966, ], [ 4.555413, 9.538326, 14.063398, 17.882318, 14.635002, 5.9126663, 2.6039343, 0.3144545, ], [ 7.877132, 19.767809, 24.037426, 15.584505, 14.542083, 4.4302306, 2.3387682, 0.3145125, ], [ 6.3498077, 11.157468, 14.465272, 6.4254785, 7.471047, 7.448948, 6.4972777, 0.88493776, ], [ 2.473032, 6.144208, 9.52839, 3.2779343, 4.3061023, 6.409383, 5.87155, 0.80066556, ], [ 1.4289956, 4.5101476, 7.2189507, 2.2500885, 2.8763475, 0.45081174, 0.0, 0.0, ], ] ], ] ) def bilinear_interpolate(data, y, x, snap_border=False): height, width = data.shape if snap_border: if -1 < y <= 0: y = 0 elif height - 1 <= y < height: y = height - 1 if -1 < x <= 0: x = 0 elif width - 1 <= x < width: x = width - 1 y_low = int(math.floor(y)) x_low = int(math.floor(x)) y_high = y_low + 1 x_high = x_low + 1 wy_h = y - y_low wx_h = x - x_low wy_l = 1 - wy_h wx_l = 1 - wx_h val = 0 for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): if 0 <= yp < height and 0 <= xp < width: val += wx * wy * data[yp, xp] return val def roi_align_np( in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, dtype=np.float32, ): n_channels = in_data.shape[1] out_data = np.zeros((rois.shape[0], n_channels, pool_h, pool_w), dtype=dtype) offset = 0.5 if aligned else 0.0 for r, roi in enumerate(rois): batch_idx = int(roi[0]) j_begin, i_begin, j_end, i_end = ( x.item() * spatial_scale - offset for x in roi[1:] ) roi_h = i_end - i_begin roi_w = j_end - j_begin bin_h = roi_h / pool_h bin_w = roi_w / pool_w for i in range(0, pool_h): start_h = i_begin + i * bin_h grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h)) for j in range(0, pool_w): start_w = j_begin + j * bin_w grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w)) for channel in range(0, n_channels): val = 0 for iy in range(0, grid_h): y = start_h + (iy + 0.5) * bin_h / grid_h for ix in range(0, grid_w): x = start_w + (ix + 0.5) * bin_w / grid_w val += bilinear_interpolate( in_data[batch_idx, channel, :, :], y, x, snap_border=True, ) val /= grid_h * grid_w out_data[r, channel, i, j] = val return out_data def _test_roi_align(test_case, device): input = flow.tensor( np.random.randn(2, 3, 64, 64), dtype=flow.float32, device=flow.device(device) ) random_img_idx = np.random.randint(low=0, high=2, size=(200, 1)) random_box_idx = np.random.uniform(low=0, high=64 * 64, size=(200, 2)).astype( np.float32 ) def get_h_w(idx1, idx2): if idx1 > idx2: idx1, idx2 = idx2, idx1 h1 = idx1 // 64 w1 = idx1 % 64 h2 = idx2 // 64 w2 = idx2 % 64 return [x / 2 for x in [h1, w1, h2, w2]] zipped = zip(random_box_idx[:, 0], random_box_idx[:, 1]) concated = [get_h_w(idx1, idx2) for (idx1, idx2) in zipped] concated = np.array(concated) rois = flow.tensor( np.hstack((random_img_idx, concated)), dtype=flow.float32, device=flow.device(device), ) of_out = flow.roi_align(input, rois, 2.0, 14, 14, 2, True) np_out = roi_align_np(input.numpy(), rois.numpy(), 14, 14, 2.0, 2, True) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=1e-4, atol=1e-4)) def 
_test_roi_align_backward(test_case, device): input = flow.tensor( input_np, dtype=flow.float32, device=flow.device(device), requires_grad=True ) rois = flow.tensor(rois_np, dtype=flow.float32, device=flow.device(device)) of_out = flow.roi_align(input, rois, 2.0, 5, 5, 2, True) of_out.sum().backward() test_case.assertTrue( np.allclose(input.grad.numpy(), input_grad_np, rtol=1e-5, atol=1e-5) ) @flow.unittest.skip_unless_1n1d() class TestRoIAlign(flow.unittest.TestCase): def test_roi_align(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_roi_align, _test_roi_align_backward] arg_dict["device"] = ["cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_roll.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgList import torch def _test_roll(test_case, device): torch_x = torch.rand( (2, 3, 5, 10, 20), device=device, dtype=torch.float32, requires_grad=True ) torch_grad = torch.rand_like(torch_x, device=device) shifts = ( np.random.randint(-100, 100), np.random.randint(-100, 100), np.random.randint(-100, 100), np.random.randint(-100, 100), ) dims = (0, 2, 3, 4) torch_y = torch.roll(torch_x, shifts, dims) torch_y.backward(torch_grad) of_x = flow.tensor( torch_x.detach().cpu().numpy(), device=device, dtype=flow.float32, requires_grad=True, ) of_y = flow.roll(of_x, shifts, dims) of_grad = flow.tensor(torch_grad.cpu().numpy(), device=device, dtype=flow.float32) of_y.backward(of_grad) test_case.assertTrue(np.array_equal(of_y.numpy(), torch_y.detach().cpu().numpy())) test_case.assertTrue(np.array_equal(of_x.grad.numpy(), torch_x.grad.cpu().numpy())) def _test_roll_single_dims(test_case, device): torch_x = torch.rand( (2, 3, 5, 10, 20), device=device, dtype=torch.float32, requires_grad=True ) torch_grad = torch.rand_like(torch_x, device=device) shifts = np.random.randint(-100, 100) dims = np.random.randint(0, 4) torch_y = torch.roll(torch_x, shifts, dims) torch_y.backward(torch_grad) of_x = flow.tensor( torch_x.detach().cpu().numpy(), device=device, dtype=flow.float32, requires_grad=True, ) of_y = flow.roll(of_x, shifts, dims) of_grad = flow.tensor(torch_grad.cpu().numpy(), device=device, dtype=flow.float32) of_y.backward(of_grad) test_case.assertTrue(np.array_equal(of_y.numpy(), torch_y.detach().cpu().numpy())) test_case.assertTrue(np.array_equal(of_x.grad.numpy(), torch_x.grad.cpu().numpy())) def _test_roll_none_dims(test_case, device): torch_x = torch.rand( (2, 3, 5, 10, 20), device=device, dtype=torch.float32, requires_grad=True ) torch_grad = torch.rand_like(torch_x, device=device) shifts = np.random.randint(-100, 100) dims = None torch_y = torch.roll(torch_x, shifts, dims) torch_y.backward(torch_grad) of_x = flow.tensor( 
torch_x.detach().cpu().numpy(), device=device, dtype=flow.float32, requires_grad=True, ) of_y = flow.roll(of_x, shifts, dims) of_grad = flow.tensor(torch_grad.cpu().numpy(), device=device, dtype=flow.float32) of_y.backward(of_grad) test_case.assertTrue(np.array_equal(of_y.numpy(), torch_y.detach().cpu().numpy())) test_case.assertTrue(np.array_equal(of_x.grad.numpy(), torch_x.grad.cpu().numpy())) @flow.unittest.skip_unless_1n1d() class TestRoll(flow.unittest.TestCase): def test_expand_compare_with_torch(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_roll, _test_roll_single_dims, _test_roll_none_dims, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_round.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestRound(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_round_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.round(x) return y @autotest(check_graph=True) def test_flow_round_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.round(x) return y @autotest(check_graph=True) def test_flow_round_half_to_even(test_case): device = random_device() random_shape = [random(1, 10).to(int).value() for _ in range(4)] random_tenosr = np.random.randint(-99999, 99999, size=random_shape) x = torch.tensor(random_tenosr).to(device) y = torch.full(x.shape, 0.5).to(device) y += x y = y.requires_grad_() z = torch.round(y) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_rrelu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
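The RReLU tests that follow pin lower == upper, which makes the op deterministic: the sampled negative slope collapses to that single value, i.e. a LeakyReLU. A tiny PyTorch/NumPy check of that special case, with an arbitrary slope of 0.3:

import numpy as np
import torch

x = torch.tensor([[-2.0, -0.5, 0.0, 1.5]])
slope = 0.3
y = torch.nn.functional.rrelu(x, lower=slope, upper=slope, training=True)
# with lower == upper the uniform sample is fixed at `slope`, so the result is exact
expected = np.where(x.numpy() > 0, x.numpy(), slope * x.numpy())
assert np.allclose(y.numpy(), expected)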
""" import os import unittest from collections import OrderedDict import numpy as np import torch as torch_original from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def do_test_rrelu_same_bound(test_case, shape, device, dtype): np_x = np.random.randn(*shape).astype(dtype) flow.manual_seed(233) torch_original.manual_seed(233) flow_tensor = flow.tensor(np_x, requires_grad=True, device=device) torch_tensor = torch_original.tensor(np_x, requires_grad=True, device=device) rate = np.random.randn() flow_rrelu = flow.nn.RReLU(lower=rate, upper=rate) torch_rrelu = torch_original.nn.RReLU(lower=rate, upper=rate) flow_out = flow_rrelu(flow_tensor) torch_out = torch_rrelu(torch_tensor) test_case.assertTrue( np.allclose( flow_out.cpu().detach().numpy(), torch_out.cpu().detach().numpy(), atol=1e-5, rtol=1e-5, ) ) flow_out.sum().backward() torch_out.sum().backward() test_case.assertTrue( np.allclose( flow_tensor.grad.cpu().detach().numpy(), torch_tensor.grad.cpu().detach().numpy(), atol=1e-5, rtol=1e-5, ) ) def do_test_rrelu_different_bound(test_case, shape, device, dtype): np_x = np.random.randn(*shape).astype(dtype) flow_tensor = flow.tensor(np_x, requires_grad=True, device=device) rate = np.random.randn() flow_rrelu = flow.nn.RReLU(lower=rate, upper=rate + 0.5) flow_out = flow_rrelu(flow_tensor) flow_out.sum().backward() flow_grad = flow_tensor.grad flow_div = flow_out / flow_tensor test_case.assertTrue( np.allclose( (flow.where(flow_tensor >= 0, 1, 0)).cpu().detach().numpy(), (flow.where(flow_div == 1.0, 1, 0)).cpu().detach().numpy(), rtol=1e-4, ) ) test_case.assertTrue( np.allclose( (flow.where(flow_tensor < 0, 1, 0)).cpu().detach().numpy(), ( flow.where( flow.logical_and( flow.logical_and(flow_div >= rate, flow_div <= (rate + 0.5)), flow_tensor < 0, ), 1, 0, ) ) .cpu() .detach() .numpy(), ) ) test_case.assertTrue( np.allclose( flow_grad.cpu().detach().numpy(), flow_div.cpu().detach().numpy(), rtol=1e-1, atol=1e-4, ) ) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): @unittest.skip("skip for now, becase it failed 4 times in past week") def test_numpy_case(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ do_test_rrelu_same_bound, do_test_rrelu_different_bound, ] arg_dict["shape"] = [ [20], [12, 32], [4, 47, 156], [5, 33, 65], [3, 132, 94], [9, 256, 63], ] # NOTE(hujiakui): in PyTorch <= 1.13, the CUDA RReLU Backward Function of PyTorch is wrong. 
if float(torch_original.__version__[:4]) < 1.13: arg_dict["device"] = ["cpu"] else: arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = [np.float32, np.float64] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_functional_rrelu(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) lower = np.abs( np.random.randn() ) # In-place leakyReLu backward calculation is triggered with a negative slope which is not supported return torch.nn.functional.rrelu( x, lower=lower, upper=lower + 0.5, inplace=random_bool(), training=False, ) @autotest(n=5) @unittest.skipIf( float(torch_original.__version__[:4]) < 1.13 and not os.getenv("ONEFLOW_TEST_CPU_ONLY"), f"RReLU CUDA test need pytorch version >= 1.13, got {torch_original.__version__}", ) def test_rrelu_train(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) lower = np.abs(np.random.randn()) m = torch.nn.RReLU(lower=lower, upper=lower, inplace=random_bool()) return m(x) @autotest(n=5, check_graph=False) def test_rrelu_eval(test_case): device = random_device() x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) lower = np.abs(np.random.randn()) m = torch.nn.RReLU(lower=lower, upper=lower, inplace=random_bool()).eval() return m(x) @profile(torch.nn.functional.rrelu) def profile_rrelu(test_case): lower = np.random.randn() torch.nn.functional.rrelu( torch.ones(1, 128, 28, 28), lower=lower, upper=lower + 0.5, inplace=False, training=True, ) torch.nn.functional.rrelu( torch.ones(1, 128, 28, 28), lower=lower, upper=lower + 0.5, inplace=True, training=True, ) torch.nn.functional.rrelu( torch.ones(16, 128, 28, 28), lower=lower, upper=lower + 0.5, inplace=False, training=True, ) torch.nn.functional.rrelu( torch.ones(16, 128, 28, 28), lower=lower, upper=lower + 0.5, inplace=True, training=True, ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_save_load.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
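Before the many save/load variants below, a minimal round trip through a temporary file using the same API surface the tests exercise (flow.save to a path, flow.load with an optional map_location); nn.Linear stands in here only as a small example module:

import tempfile
import numpy as np
import oneflow as flow

m = flow.nn.Linear(3, 3)
with tempfile.NamedTemporaryFile() as f:
    flow.save(m.state_dict(), f.name)
    restored = flow.load(f.name, map_location="cpu")

m2 = flow.nn.Linear(3, 3)
m2.load_state_dict(restored)
assert np.array_equal(m.weight.numpy(), m2.weight.numpy())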
""" import os import warnings import tempfile import unittest from pathlib import Path import io import numpy as np import torch import oneflow as flow import oneflow.nn as nn import oneflow.unittest class CustomModuleForSaveLoad(flow.nn.Module): def __init__(self): super().__init__() self.param = flow.nn.Parameter(flow.randn(1, 3, 3, 3)) def forward(self, x): return self.param + x @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestSaveLoad(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_load_map_location(test_case): x = flow.ones(1, 2, 3) y = flow.ones(2, 3, 4) with tempfile.NamedTemporaryFile() as f: flow.save({"x": x, "y": y}, f.name) loaded = flow.load(f.name, map_location="cuda") assert np.array_equal(loaded["x"].numpy(), x.numpy()) assert loaded["x"].device == flow.device("cuda") assert np.array_equal(loaded["y"].numpy(), y.numpy()) assert loaded["y"].device == flow.device("cuda") with tempfile.NamedTemporaryFile() as f: flow.save({"x": x, "y": y}, f.name) loaded = flow.load(f.name, map_location="cpu") assert np.array_equal(loaded["x"].numpy(), x.numpy()) assert loaded["x"].device == flow.device("cpu") assert np.array_equal(loaded["y"].numpy(), y.numpy()) assert loaded["y"].device == flow.device("cpu") x = x.to_global(sbp=flow.sbp.broadcast, placement=flow.placement("cuda", [0])) y = y.to_global(sbp=flow.sbp.broadcast, placement=flow.placement("cuda", [0])) with tempfile.NamedTemporaryFile() as f: flow.save({"x": x, "y": y}, f.name, global_dst_rank=0) loaded = flow.load( f.name, global_src_rank=0, map_location=flow.placement("cuda", [0]) ) assert np.array_equal(loaded["x"].numpy(), x.numpy()) assert loaded["x"].placement == flow.placement("cuda", [0]) assert np.array_equal(loaded["y"].numpy(), y.numpy()) assert loaded["y"].placement == flow.placement("cuda", [0]) with tempfile.NamedTemporaryFile() as f: flow.save({"x": x, "y": y}, f.name, global_dst_rank=0) loaded = flow.load( f.name, global_src_rank=0, map_location=flow.placement("cpu", [0]) ) assert np.array_equal(loaded["x"].numpy(), x.numpy()) assert loaded["y"].placement == flow.placement("cpu", [0]) assert np.array_equal(loaded["y"].numpy(), y.numpy()) assert loaded["y"].placement == flow.placement("cpu", [0]) @flow.unittest.skip_unless_1n1d() def test_save_dir(test_case): m1 = CustomModuleForSaveLoad() with tempfile.TemporaryDirectory() as save_dir: flow.save(m1.state_dict(), save_dir, save_as_external_data=True) loaded_state_dict = flow.load(save_dir) m2 = CustomModuleForSaveLoad() m2.load_state_dict(loaded_state_dict) test_case.assertTrue(np.array_equal(m1.param.numpy(), m2.param.numpy())) @flow.unittest.skip_unless_1n1d() def test_save_dir_fault_tolerance(test_case): m1 = CustomModuleForSaveLoad() with tempfile.TemporaryDirectory() as save_dir: flow.save(m1.state_dict(), save_dir, save_as_external_data=True) with open(os.path.join(save_dir, "random_file"), "w") as fp: fp.write("nothing") with warnings.catch_warnings(): warnings.simplefilter("ignore") loaded_state_dict = flow.load(save_dir) m2 = CustomModuleForSaveLoad() m2.load_state_dict(loaded_state_dict) test_case.assertTrue(np.array_equal(m1.param.numpy(), m2.param.numpy())) @flow.unittest.skip_unless_1n1d() def test_save_state_dict(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.param1 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024)) self.param2 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024)) def forward(self): return self.param1 + self.param2 m = CustomModule() res1 = 
m() state_dict = m.state_dict() with tempfile.NamedTemporaryFile() as f: flow.save(state_dict, f.name) test_case.assertTrue(os.path.exists(f.name)) loaded_state_dict = flow.load(f.name) m.load_state_dict(loaded_state_dict) res2 = m() test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) @flow.unittest.skip_unless_1n1d() def test_save_state_dict_bytes(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.param1 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024)) self.param2 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024)) def forward(self): return self.param1 + self.param2 m = CustomModule() res1 = m() state_dict = m.state_dict() with tempfile.NamedTemporaryFile() as path: buffer = io.BytesIO() flow.save(state_dict, buffer) with open(path.name, "wb") as f: f.write(buffer.getvalue()) test_case.assertTrue(os.path.exists(path.name)) loaded_state_dict = flow.load(path.name) m.load_state_dict(loaded_state_dict) res2 = m() test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) def _test_save_and_load_global_from_nested_dict(test_case): class CustomModule(flow.nn.Module): def __init__(self): super().__init__() self.param = flow.nn.Parameter(flow.randn(3, 32, 3, 3)) def forward(self): return self.param m1 = CustomModule() m1 = m1.to_global( flow.placement("cuda", range(1, 3)), flow.sbp.broadcast ).to_global(sbp=flow.sbp.split(1)) m2 = CustomModule() m2 = m2.to_global(flow.placement("cuda", range(1, 3)), flow.sbp.broadcast) res1 = m1() + m2() state_dict1 = m1.state_dict() state_dict2 = m2.state_dict() state_dict = {"m1": state_dict1, "m2": state_dict2} with tempfile.TemporaryDirectory() as dir: filename = os.path.join(dir, "tmp") with test_case.assertRaises(Exception): flow.save(state_dict, filename) global_src_dst_rank = 0 flow.save(state_dict, filename, global_dst_rank=global_src_dst_rank) rank = flow.env.get_rank() if rank != global_src_dst_rank: test_case.assertFalse(os.path.exists(filename)) m1 = CustomModule() m1 = m1.to_global( flow.placement("cuda", [[0, 1], [2, 3]]), [flow.sbp.broadcast, flow.sbp.broadcast], ).to_global(sbp=[flow.sbp.split(1), flow.sbp.broadcast]) m2 = CustomModule() m2 = m2.to_global( flow.placement("cuda", [[0, 1], [2, 3]]), [flow.sbp.broadcast, flow.sbp.broadcast], ).to_global(sbp=[flow.sbp.broadcast, flow.sbp.split(1)]) with test_case.assertRaises(Exception): loaded_state_dict = flow.load(filename) m1.load_state_dict(loaded_state_dict["m1"]) loaded_state_dict = flow.load(filename, global_src_rank=global_src_dst_rank) test_case.assertEqual(len(loaded_state_dict), 2) m1.load_state_dict(loaded_state_dict["m1"]) m2.load_state_dict(loaded_state_dict["m2"]) res2 = m1() + m2() test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) @flow.unittest.skip_unless_1n4d() def test_save_and_load_global_from_nested_dict_1n4d(test_case): test_case._test_save_and_load_global_from_nested_dict() @flow.unittest.skip_unless_2n2d() def test_save_and_load_global_from_nested_dict_2n2d(test_case): test_case._test_save_and_load_global_from_nested_dict() @flow.unittest.skip_unless_1n1d() def test_load_pytorch_weights(test_case): for device in ["cpu", "cuda"]: for map_location in [None, flow.device("cuda:0")]: conv_torch = torch.nn.Conv2d(3, 3, 3).to(device) conv_flow1 = flow.nn.Conv2d(3, 3, 3).to(device) with tempfile.NamedTemporaryFile() as f: torch.save(conv_torch.state_dict(), f.name) conv_flow1.load_state_dict( flow.load(f.name, map_location=map_location) ) test_case.assertTrue( np.array_equal( 
conv_torch.weight.detach().cpu().numpy(), conv_flow1.weight.numpy(), ) ) conv_flow2 = flow.nn.Conv2d(3, 3, 3).to(device) with tempfile.NamedTemporaryFile() as f: torch.save({"weights": conv_torch.state_dict()}, f.name) conv_flow2.load_state_dict( flow.load(f.name, map_location=map_location)["weights"] ) test_case.assertTrue( np.array_equal( conv_torch.weight.detach().cpu().numpy(), conv_flow2.weight.numpy(), ) ) @flow.unittest.skip_unless_1n2d() def test_load_pytorch_weights_global(test_case): for device in ["cpu", "cuda"]: for map_location in [None, flow.placement.all("cuda")]: conv_torch = torch.nn.Conv2d(3, 3, 3).to(device) all_placement = flow.placement.all(device) conv_flow1 = flow.nn.Conv2d(3, 3, 3).to_global( all_placement, flow.sbp.broadcast ) with tempfile.NamedTemporaryFile() as f: if flow.env.get_rank() == 0: torch.save(conv_torch.state_dict(), f.name) conv_flow1.load_state_dict( flow.load(f.name, map_location=map_location, global_src_rank=0) ) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( conv_torch.weight.detach().cpu().numpy(), conv_flow1.weight.numpy(), ) ) conv_flow2 = flow.nn.Conv2d(3, 3, 3).to_global( all_placement, flow.sbp.broadcast ) with tempfile.NamedTemporaryFile() as f: if flow.env.get_rank() == 0: torch.save({"weights": conv_torch.state_dict()}, f.name) conv_flow2.load_state_dict( flow.load(f.name, map_location=map_location, global_src_rank=0)[ "weights" ] ) if flow.env.get_rank() == 0: test_case.assertTrue( np.array_equal( conv_torch.weight.detach().cpu().numpy(), conv_flow2.weight.numpy(), ) ) @flow.unittest.skip_unless_1n1d() def test_save_load_module_directly(test_case): x = flow.randn(1, 3, 3, 3) m = CustomModuleForSaveLoad() with tempfile.NamedTemporaryFile() as f: flow.save(m, f.name) new_m = flow.load(f.name) res = m(x) new_res = new_m(x) test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy())) m = flow.nn.parallel.DistributedDataParallel(m) test_case.assertTrue(m._is_ddp_module) with tempfile.NamedTemporaryFile() as f: flow.save(m, f.name) new_m = flow.load(f.name) test_case.assertTrue(new_m._is_ddp_module) res = m(x) new_res = new_m(x) test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy())) @flow.unittest.skip_unless_1n1d() def test_save_load_module_directly_save_bytes(test_case): x = flow.randn(1, 3, 3, 3) m = CustomModuleForSaveLoad() with tempfile.NamedTemporaryFile() as path: buffer = io.BytesIO() flow.save(m, buffer) with open(path.name, "wb") as f: f.write(buffer.getvalue()) new_m = flow.load(path.name) res = m(x) new_res = new_m(x) test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy())) m = flow.nn.parallel.DistributedDataParallel(m) test_case.assertTrue(m._is_ddp_module) with tempfile.NamedTemporaryFile() as path: buffer = io.BytesIO() flow.save(m, buffer) with open(path.name, "wb") as f: f.write(buffer.getvalue()) new_m = flow.load(path.name) test_case.assertTrue(new_m._is_ddp_module) res = m(x) new_res = new_m(x) test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy())) @flow.unittest.skip_unless_1n1d() def test_save_load_module_directly_load_filestream(test_case): x = flow.randn(1, 3, 3, 3) m = CustomModuleForSaveLoad() with tempfile.NamedTemporaryFile() as f: flow.save(m, f.name) with open(f.name, "rb") as r: new_m = flow.load(r) res = m(x) new_res = new_m(x) test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy())) m = flow.nn.parallel.DistributedDataParallel(m) test_case.assertTrue(m._is_ddp_module) with tempfile.NamedTemporaryFile() as f: flow.save(m, f.name) with 
open(f.name, "rb") as r: new_m = flow.load(r) test_case.assertTrue(new_m._is_ddp_module) res = m(x) new_res = new_m(x) test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy())) def test_load_old_dir_data(test_case): test_data_dir = Path(__file__).parent / "save_load_test_data" m1 = nn.Conv2d(3, 3, 3) params = flow.load(test_data_dir / "3x3_i3o3_conv2d_params") m1.load_state_dict(params) m2 = flow.load(test_data_dir / "3x3_i3o3_conv2d") x = flow.randn(1, 3, 3, 3) y1 = m1(x) y2 = m2(x) test_case.assertTrue(np.array_equal(y1.numpy(), y2.numpy())) def test_pytorch_non_tensor(test_case): with tempfile.NamedTemporaryFile() as f: torch.save({"a": 2}, f.name) res = flow.load(f.name, map_location="cpu") test_case.assertTrue(isinstance(res, dict)) test_case.assertEqual(len(res), 1) test_case.assertEqual(res["a"], 2) def test_pytorch_non_tensor_load_filestream(test_case): with tempfile.NamedTemporaryFile() as f: torch.save({"a": 2}, f.name) with open(f.name, "rb") as r: res = flow.load(r, map_location="cpu") test_case.assertTrue(isinstance(res, dict)) test_case.assertEqual(len(res), 1) test_case.assertEqual(res["a"], 2) def test_pytorch_non_tensor_save_bytes(test_case): with tempfile.NamedTemporaryFile() as path: buffer = io.BytesIO() torch.save({"a": 2}, buffer) with open(path.name, "wb") as f: f.write(buffer.getvalue()) res = flow.load(path.name, map_location="cpu") test_case.assertTrue(isinstance(res, dict)) test_case.assertEqual(len(res), 1) test_case.assertEqual(res["a"], 2) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_saved_tensor_hooks.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
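The tests that follow register pack/unpack hooks that stash saved activations in a Python list. A short sketch of the same flow.autograd.graph.saved_tensors_hooks API used for the more common purpose of parking saved tensors on the CPU (assumes a CUDA device is available, as the tests do):

import oneflow as flow

def pack_to_cpu(t):
    # runs when autograd saves a tensor for the backward pass
    return t.to("cpu")

def unpack_from_cpu(t):
    # runs when backward needs the saved tensor again
    return t.to("cuda")

x = flow.ones(2, 3).to("cuda").requires_grad_()
y = flow.zeros(2, 3).to("cuda").requires_grad_()
with flow.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    z = (x * y).sum()
z.backward()  # x.grad == y and y.grad == x, exactly as without the hooks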
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestSavedTensorHooks(flow.unittest.TestCase): def test_normal_saved_tensor_hooks(test_case): x = flow.ones(1, 2, 3).to("cuda").requires_grad_() y = flow.zeros(1, 2, 3).to("cuda").requires_grad_() tensor_list = [] def pack(x): tensor_list.append(x) return len(tensor_list) - 1 def unpack(x): return tensor_list[x] with flow.autograd.graph.saved_tensors_hooks(pack, unpack): z = x * y z.sum().backward() test_case.assertEqual(len(tensor_list), 2) test_case.assertTrue(np.array_equal(tensor_list[0], y)) test_case.assertTrue(np.array_equal(tensor_list[1], x)) test_case.assertTrue(np.allclose(x.grad, y)) test_case.assertTrue(np.allclose(y.grad, x)) def test_saved_tensor_hooks_in_autograd_function(test_case): x = flow.ones(1, 2, 3).to("cuda").requires_grad_() y = flow.zeros(1, 2, 3).to("cuda").requires_grad_() tensor_list = [] def pack(x): tensor_list.append(x) return len(tensor_list) - 1 def unpack(x): return tensor_list[x] class MulFunction(flow.autograd.Function): @staticmethod def forward(ctx, x, y): ctx.save_for_backward(x, y) return x * y @staticmethod def backward(ctx, dz): x, y = ctx.saved_tensors dx = dz * y dy = dz * x return dx, dy with flow.autograd.graph.saved_tensors_hooks(pack, unpack): z = MulFunction.apply(x, y) z.sum().backward() test_case.assertEqual(len(tensor_list), 2) test_case.assertTrue(np.array_equal(tensor_list[0], x)) test_case.assertTrue(np.array_equal(tensor_list[1], y)) test_case.assertTrue(np.allclose(x.grad, y)) test_case.assertTrue(np.allclose(y.grad, x)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_sbp_symbol.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestSBPSymbol(flow.unittest.TestCase): def test_sbp_symbol(test_case): test_case.assertTrue(flow.sbp.split(0) == flow.sbp.split(0)()) test_case.assertTrue(flow.sbp.split(1) == flow.sbp.split(1)()) test_case.assertTrue(flow.sbp.split(0) != flow.sbp.split(1)) test_case.assertTrue(flow.sbp.broadcast == flow.sbp.broadcast()) test_case.assertTrue(flow.sbp.partial_sum == flow.sbp.partial_sum()) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_scatter_nd.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_scatter_nd(test_case, device): indices = flow.tensor( np.array([[1], [6], [4]]), dtype=flow.int, device=flow.device(device) ) update = flow.tensor( np.array([10.2, 5.1, 12.7]), dtype=flow.float, device=flow.device(device) ) np_out = np.array([0.0, 10.2, 0.0, 0.0, 12.7, 0.0, 5.1, 0.0]) output = flow.scatter_nd(indices, update, [8]) test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001)) def _test_scatter_nd_t(test_case, device): indices = flow.tensor( np.array([[0], [4], [2]]), dtype=flow.int, device=flow.device(device) ) update = flow.tensor( np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]), dtype=flow.float, device=flow.device(device), ) np_out = np.array( [ [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [3.0, 3.0, 3.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0], ] ) output = flow.scatter_nd(indices, update, [5, 3]) test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001)) def _test_scatter_nd_backward(test_case, device): indices = flow.tensor( np.array([[1], [6], [4]]), dtype=flow.int, device=flow.device(device) ) of_update = flow.tensor( np.array([10.2, 5.1, 12.7]), requires_grad=True, dtype=flow.float, device=flow.device(device), ) np_out = np.array([0.0, 10.2, 0.0, 0.0, 12.7, 0.0, 5.1, 0.0]) np_grad = np.array([1.0, 1.0, 1.0]) output = flow.scatter_nd(indices, of_update, [8]) out_sum = output.sum() out_sum.backward() test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001)) test_case.assertTrue(np.array_equal(of_update.grad.numpy(), np_grad)) @flow.unittest.skip_unless_1n1d() class TestScatter_nd(flow.unittest.TestCase): def test_scatter_nd(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_scatter_nd, _test_scatter_nd_t, _test_scatter_nd_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_scatter_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
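The randomly chosen index tensors in the file below are easier to follow with one concrete case in mind; a small worked example of scatter along dim=0 in plain PyTorch, which is what the dual-object tests ultimately compare against:

import torch

inp = torch.zeros(2, 2)
index = torch.tensor([[0, 1], [1, 0]])
src = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
# for dim=0: out[index[i][j]][j] = src[i][j]
#   src[0,0]=1 -> out[0,0]   src[0,1]=2 -> out[1,1]
#   src[1,0]=3 -> out[1,0]   src[1,1]=4 -> out[0,1]
out = torch.scatter(inp, 0, index, src)
assert out.tolist() == [[1.0, 4.0], [3.0, 2.0]]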
""" import unittest import oneflow as flow import oneflow.unittest import numpy as np from oneflow.test_utils.automated_test_util import * def _get_indexes(device): return ( constant( torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device) ), constant( torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device) ), constant( torch.tensor(np.array([[1, 0], [1, 0]]), dtype=torch.int64, device=device) ), constant( torch.tensor(np.array([[0, 1], [0, 1]]), dtype=torch.int64, device=device) ), ) def _test_scatter(test_case, test_scalar: bool, dim: int): device = random_device() input = random_tensor(ndim=2, dim0=2, dim1=2).to(device) src = 3.14 if test_scalar else random_tensor(ndim=2, dim0=2, dim1=2).to(device) y = torch.scatter(input, dim, oneof(*_get_indexes(device)), src) return y def _test_scatter_add(test_case, dim: int): device = random_device() input = random_tensor(ndim=2, dim0=2, dim1=2).to(device) src = random_tensor(ndim=2, dim0=2, dim1=2).to(device) y = torch.scatter_add(input, dim, oneof(*_get_indexes(device)), src) return y def _test_scatter_reduce(test_case, dim: int): device = random_device() input = random_tensor(ndim=2, dim0=2, dim1=2).to(device) src = random_tensor(ndim=2, dim0=2, dim1=2).to(device) y = torch.scatter( input, dim, oneof(*_get_indexes(device)), src, reduce=oneof("add", "multiply", nothing()), ) return y @flow.unittest.skip_unless_1n1d() class TestScatterOpsModule(flow.unittest.TestCase): @autotest(n=10) def test_scatter_with_random_data(test_case): return _test_scatter(test_case, oneof(True, False), oneof(0, 1, -1)) @autotest(n=5) def test_scatter_add_with_random_data(test_case): return _test_scatter_add(test_case, oneof(0, 1)) @autotest( n=5, auto_backward=False ) # peihong: pytorch dose not support backward when reduce is add or multiply def test_scatter_reduce_with_random_data(test_case): return _test_scatter_reduce(test_case, oneof(0, 1)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_searchsorted.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util.torch_flow_dual_object import autotest def _test_search_sorted(test_case, input_dtype, device): sorted_sequence = flow.tensor( np.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), dtype=input_dtype, device=flow.device(device), ) values = flow.tensor( np.array([[3, 6, 9], [3, 6, 9]]), dtype=input_dtype, device=flow.device(device) ) gt = np.array([[1, 3, 4], [1, 2, 4]]) output = flow.searchsorted(sorted_sequence, values) test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001)) test_case.assertTrue(output.dtype == flow.int64) def _test_search_sorted_1(test_case, input_dtype, device): sorted_sequence = flow.tensor( np.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), dtype=input_dtype, device=flow.device(device), ) values = flow.tensor( np.array([[3, 6, 9], [3, 6, 9]]), dtype=input_dtype, device=flow.device(device) ) gt = np.array([[2, 3, 5], [1, 3, 4]]) output = flow.searchsorted(sorted_sequence, values, right=True, side="right") test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001)) test_case.assertTrue(output.dtype == flow.int64) def _test_search_sorted_2(test_case, input_dtype, device): sorted_sequence_1d = flow.tensor( np.array([1, 3, 5, 7, 9]), dtype=input_dtype, device=flow.device(device) ) values = flow.tensor( np.array([3, 6, 9]), dtype=input_dtype, device=flow.device(device) ) gt = np.array([1, 3, 4]) output = flow.searchsorted(sorted_sequence_1d, values) test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001)) test_case.assertTrue(output.dtype == flow.int64) def _test_search_sorted_3(test_case, input_dtype, device): sorted_sequence = flow.tensor( np.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), dtype=input_dtype, device=flow.device(device), ) values = flow.tensor( np.array([[3, 6, 9], [3, 6, 9]]), dtype=input_dtype, device=flow.device(device) ) gt = np.array([[1, 3, 4], [1, 2, 4]]) output = flow.searchsorted(sorted_sequence, values, out_int32=True) test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001)) test_case.assertTrue(output.dtype == flow.int32) def _test_search_sorted_4(test_case, input_dtype, device): sorted_sequence = flow.tensor( np.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), dtype=input_dtype, device=flow.device(device), ) values = flow.tensor( np.array([[3, 6, 9], [3, 6, 9]]), dtype=input_dtype, device=flow.device(device) ) sorter = flow.tensor( np.array([[4, 3, 2, 1, 0], [3, 2, 4, 0, 1]]), dtype=flow.int64, device=flow.device(device), ) gt = np.array([[0, 5, 5], [0, 0, 2]]) output = flow.searchsorted(sorted_sequence, values, sorter=sorter) test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001)) test_case.assertTrue(output.dtype == flow.int64) def _test_search_sorted_5(test_case, input_dtype, device): sorted_sequence_1d = flow.tensor( np.array([1, 3, 5, 7, 9]), dtype=input_dtype, device=flow.device(device) ) gt = np.array(2) output = flow.searchsorted(sorted_sequence_1d, 5) test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001)) test_case.assertTrue(output.dtype == flow.int64) def _test_search_sorted_6(test_case, input_dtype, device): sorted_sequence_1d = flow.tensor( np.array([1, 3, 5, 7, 9]), dtype=input_dtype, device=flow.device(device) ) gt = np.array(3) output = flow.searchsorted(sorted_sequence_1d, 5, right=True, side="right") 
    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))
    test_case.assertTrue(output.dtype == flow.int64)


def _test_search_sorted_7(test_case, input_dtype, device):
    sorted_sequence_1d = flow.tensor(
        np.array([1, 3, 5, 7, 9]), dtype=input_dtype, device=flow.device(device)
    )
    gt = np.array(2)
    output = flow.searchsorted(sorted_sequence_1d, 5, out_int32=True)
    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))
    test_case.assertTrue(output.dtype == flow.int32)


@flow.unittest.skip_unless_1n1d()
class TestSearchSorted(flow.unittest.TestCase):
    def test_search_sorted(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [
            _test_search_sorted,
            _test_search_sorted_1,
            _test_search_sorted_2,
            _test_search_sorted_3,
            _test_search_sorted_4,
            _test_search_sorted_5,
            _test_search_sorted_6,
            _test_search_sorted_7,
        ]
        arg_dict["input_dtype"] = [
            flow.int8,
            flow.int32,
            flow.int64,
            flow.float,
            flow.double,
        ]
        arg_dict["device"] = ["cpu", "cuda"]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])

    @autotest(n=20, auto_backward=False, check_dtype=True)
    def test_search_sorted_with_random_data(test_case):
        device = random_device()
        sorted_sequence = random_tensor(ndim=2, dim0=2, dim1=3).to(device)
        values = random_tensor(ndim=2, dim0=2).to(device)
        right = oneof(True, False)
        y = torch.searchsorted(
            sorted_sequence, values, out_int32=oneof(True, False), right=right,
        )
        return y


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_select.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
""" import unittest from random import shuffle from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.automated_test_util import util import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestSelect(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_select(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) dim = random(-4, 3).to(int) index = random(0, 2).to(int) z = torch.select(x, dim, index) return z # TODO:(zhaoluyang) some bug in as_strided backward to be fixed @autotest(n=10, auto_backward=False, check_graph=True) def test_flow_select_with_stride(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) dim = random(-4, 3).to(int) index = random(0, 2).to(int) perm = [0, 1, 2, 3] shuffle(perm) y = x.permute(perm) z = torch.select(y, dim, index) return z @autotest(check_graph=True) def test_flow_select_1dim(test_case): device = random_device() x = random_tensor(ndim=1, dim0=random(3, 6),).to(device) index = random(0, 2).to(int) z = torch.select(x, 0, index) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_shutting_down.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow import os world_size = os.getenv("WORLD_SIZE") class _TestCallWhenShuttingDown: def __init__(self): self.oneflow = oneflow tensor = oneflow.ones((2, 2)) print(tensor) def __del__(self, of=oneflow): try: if world_size == 1: tensor = of.ones((2, 2)) except: # Please refer to: https://github.com/Oneflow-Inc/OneTeam/issues/1219#issuecomment-1092370402 print("__del__ at shutting down phase in Python is not stable.") test_call_when_shutting_down = _TestCallWhenShuttingDown() class _TestSyncWhenShuttingDown: def __init__(self): self.eager = oneflow._oneflow_internal.eager def __del__(self): try: self.eager.Sync() except: # Please refer to: https://github.com/Oneflow-Inc/OneTeam/issues/1219#issuecomment-1092370402 print("__del__ at shutting down phase in Python is not stable.") test_sync_when_shutting_down = _TestSyncWhenShuttingDown() ================================================ FILE: python/oneflow/test/modules/test_sign.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np
from oneflow.test_utils.test_util import GenArgList

import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


def _test_sign_impl(test_case, shape, device):
    np_input = np.random.randn(*shape)
    of_input = flow.tensor(
        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
    )
    of_out = flow.sign(of_input)
    np_out = np.sign(np_input)
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
    of_out = of_out.sum()
    of_out.backward()
    np_grad = np.zeros_like(np_input)
    test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 0.0001, 0.0001))


@flow.unittest.skip_unless_1n1d()
class TestSign(flow.unittest.TestCase):
    def test_sign(test_case):
        arg_dict = OrderedDict()
        arg_dict["shape"] = [(2, 3), (2, 4, 5, 6)]
        arg_dict["device"] = ["cpu", "cuda"]
        for arg in GenArgList(arg_dict):
            _test_sign_impl(test_case, *arg)

    @autotest(n=5)
    def test_sign_with_random_data(test_case):
        device = random_device()
        x = random_tensor().to(device)
        y = torch.sign(x)
        return y

    @autotest(n=5, auto_backward=False, check_graph=True)
    def test_sign_with_0_size_data(test_case):
        device = random_device()
        x = random_tensor(4, 2, 3, 0, 4).to(device)
        y = torch.sign(x)
        return y

    @autotest(n=5, auto_backward=False, check_graph=True)
    def test_sign_bool_with_random_data(test_case):
        device = random_device()
        x = random_tensor().to(device=device, dtype=torch.bool)
        y = torch.sign(x)
        return y

    @autotest(n=5, auto_backward=False, check_graph=True)
    def test_sign_with_0dim_data(test_case):
        device = random_device()
        x = random_tensor(ndim=0).to(device)
        y = torch.sign(x)
        return y


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_single_threaded_vm.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
"""
import subprocess
import sys
import os
import unittest

import oneflow as flow
import oneflow.unittest


class TestSingleThreadedVM(flow.unittest.TestCase):
    @flow.unittest.skip_unless_1n2d()
    def test_ddp_in_single_threaded_vm(test_case):
        # Environment variables of current process like ONEFLOW_TEST_DEVICE_NUM
        # and environment variables about distributed training (i.e. MASTER_ADDR,
        # MASTER_PORT, WORLD_SIZE, RANK) are all in `env`.
        env = os.environ.copy()
        env["ONEFLOW_VM_MULTI_THREAD"] = "0"
        p = subprocess.run(
            [sys.executable, "test_ddp.py"],
            cwd=os.path.dirname(os.path.realpath(__file__)),
            env=env,
        )
        test_case.assertEqual(p.returncode, 0)


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/modules/test_skip_layer_norm.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import numpy as np import unittest import oneflow as flow import oneflow.nn as nn import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList is_profiling = False def compare_result(test_case, a, b, rtol=1e-5, atol=1e-8): test_case.assertTrue( np.allclose(a.numpy(), b.numpy(), rtol=rtol, atol=atol), f"\na\n{a.numpy()}\n{'-' * 80}\nb:\n{b.numpy()}\n{'*' * 80}\ndiff:\n{a.numpy() - b.numpy()}", ) class NaiveSkipLayerNorm(nn.Module): def __init__(self): super().__init__() def forward( self, x: flow.Tensor, gamma: flow.Tensor, beta: flow.Tensor, bias: flow.Tensor = None, skip: flow.Tensor = None, alpha: float = 1e-5, eps: float = 1e-6, ) -> flow.Tensor: begin_norm_axis = len(x.shape) - 1 begin_params_axis = len(x.shape) - 1 if bias is not None: x = flow._C.add(input=x, other=bias) if skip is not None: skip = skip * alpha x = flow._C.add(input=x, other=skip) return flow._C.layer_norm_affine( x, gamma, beta, begin_norm_axis=begin_norm_axis, begin_params_axis=begin_params_axis, epsilon=eps, ) class FusedSkipLayerNorm(nn.Module): def __init__(self): super().__init__() def forward( self, x: flow.Tensor, gamma: flow.Tensor, beta: flow.Tensor, bias: flow.Tensor = None, skip: flow.Tensor = None, alpha: float = 1e-5, eps: float = 1e-6, ) -> flow.Tensor: return flow._C.skip_layer_norm( x=x, gamma=gamma, beta=beta, bias=bias, skip=skip, alpha=alpha, epsilon=eps ) def _test_skip_layer_norm( test_case, x_shape, has_gamma, has_beta, has_bias, has_skip, eps=1e-6, alpha=1e-5, dtype=flow.float32, ): print( f"x_shape: {x_shape}\nhas_gamma: {has_gamma}\nhas_beta: {has_beta}\nhas_bias: {has_bias}\nhas_skip: {has_skip}\ndtype: {dtype}\n" ) normalize_shape = list() normalize_shape.append(x_shape[-1]) np_dtype = np.float16 if dtype is flow.float16 else np.float32 # generate np array np_x = np.random.randn(*x_shape).astype(np_dtype) naive_flow_gamma = None fused_flow_gamma = None if has_gamma: np_gamma = np.random.randn(*normalize_shape).astype(np_dtype) naive_flow_gamma = flow.tensor(np_gamma).to(device="cuda", dtype=dtype) fused_flow_gamma = flow.tensor(np_gamma).to(device="cuda", dtype=dtype) else: np_gamma = np.ones(*normalize_shape).astype(np_dtype) naive_flow_gamma = flow.tensor(np_gamma).to(device="cuda", dtype=dtype) naive_flow_beta = None fused_flow_beta = None if has_beta: np_beta = np.random.randn(*normalize_shape).astype(np_dtype) naive_flow_beta = flow.tensor(np_beta).to(device="cuda", dtype=dtype) fused_flow_beta = flow.tensor(np_beta).to(device="cuda", dtype=dtype) else: np_beta = np.zeros(*normalize_shape).astype(np_dtype) naive_flow_beta = flow.tensor(np_beta).to(device="cuda", dtype=dtype) flow_bias = None if has_bias: np_bias = np.random.randn(*normalize_shape).astype(np_dtype) flow_bias = flow.tensor(np_bias).to(device="cuda", dtype=dtype) flow_skip_naive = None flow_skip_fused = None np_skip = None if has_skip: np_skip = np.random.randn(*x_shape).astype(np_dtype) flow_skip_naive = flow.tensor(np_skip).to(device="cuda", 
dtype=dtype) flow_skip_fused = flow.tensor(np_skip).to(device="cuda", dtype=dtype) # naive process flow_naive_module = NaiveSkipLayerNorm() flow_x_naive = flow.tensor(np_x).to(device="cuda", dtype=dtype) flow_y_naive = flow_naive_module.forward( x=flow_x_naive, gamma=naive_flow_gamma, beta=naive_flow_beta, bias=flow_bias, skip=flow_skip_naive, alpha=alpha, eps=eps, ) # fused process flow_fused_module = FusedSkipLayerNorm() flow_x_fused = flow.tensor(np_x).to(device="cuda", dtype=dtype) flow_y_fused = flow_fused_module.forward( x=flow_x_fused, gamma=fused_flow_gamma, beta=fused_flow_beta, bias=flow_bias, skip=flow_skip_fused, alpha=alpha, eps=eps, ) if dtype is flow.float16: compare_result(test_case, flow_y_naive, flow_y_fused, 1e-2, 1e-2) else: compare_result(test_case, flow_y_naive, flow_y_fused, 1e-4, 1e-4) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestSkipLayerNorm(flow.unittest.TestCase): def test_gather(test_case): arg_dict = OrderedDict() # set up test functions arg_dict["test_fun"] = [ _test_skip_layer_norm, ] # set up test parameters if is_profiling: arg_dict["x_shape"] = [[1, 5120]] arg_dict["has_gamma"] = [True] arg_dict["has_beta"] = [True] arg_dict["has_bias"] = [True] arg_dict["has_skip"] = [True] arg_dict["eps"] = [1e-6] arg_dict["alpha"] = [1e-5] arg_dict["dtype"] = [flow.float32] else: arg_dict["x_shape"] = [[1, 5120]] arg_dict["has_gamma"] = [True, False] arg_dict["has_beta"] = [True, False] arg_dict["has_bias"] = [True, False] arg_dict["has_skip"] = [True, False] arg_dict["eps"] = [1e-6] arg_dict["alpha"] = [1e-5] arg_dict["dtype"] = [flow.float32, flow.float16] # run test functions for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_skip_rms_norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import numpy as np import unittest import oneflow as flow import oneflow.nn as nn import oneflow.unittest from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList is_profiling = False def compare_result(test_case, a, b, rtol=1e-5, atol=1e-8): test_case.assertTrue( np.allclose(a.numpy(), b.numpy(), rtol=rtol, atol=atol), f"\na\n{a.numpy()}\n{'-' * 80}\nb:\n{b.numpy()}\n{'*' * 80}\ndiff:\n{a.numpy() - b.numpy()}", ) class NaiveSkipRMSNorm(nn.Module): def __init__(self): super().__init__() def forward( self, x: flow.Tensor, weight: flow.Tensor, bias: flow.Tensor = None, skip: flow.Tensor = None, alpha: float = 1e-5, eps: float = 1e-6, ) -> flow.Tensor: if bias is not None: x = flow._C.add(input=x, other=bias) if skip is not None: skip = skip * alpha x = flow._C.add(input=x, other=skip) return flow._C.rms_norm(x, weight, [x.shape[-1]], eps) class FusedSkipRMSNorm(nn.Module): def __init__(self): super().__init__() def forward( self, x: flow.Tensor, weight: flow.Tensor, bias: flow.Tensor = None, skip: flow.Tensor = None, alpha: float = 1e-5, eps: float = 1e-6, ) -> flow.Tensor: return flow._C.skip_rms_norm( x=x, weight=weight, bias=bias, skip=skip, epsilon=eps, alpha=alpha ) def _test_skip_rms_norm( test_case, x_shape, has_weight, has_bias, has_skip, eps=1e-6, alpha=1e-5, dtype=flow.float32, ): print( f"x_shape: {x_shape}\nhas_weight: {has_weight}\nhas_bias: {has_bias}\nhas_skip: {has_skip}\ndtype: {dtype}\n" ) normalize_shape = list() normalize_shape.append(x_shape[-1]) np_dtype = np.float16 if dtype is flow.float16 else np.float32 # generate np array np_x = np.random.randn(*x_shape).astype(np_dtype) naive_flow_weight = None fused_flow_weight = None if has_weight: np_gamma = np.random.randn(*normalize_shape).astype(np_dtype) naive_flow_weight = flow.tensor(np_gamma).to(device="cuda", dtype=dtype) fused_flow_weight = flow.tensor(np_gamma).to(device="cuda", dtype=dtype) else: np_gamma = np.ones(*normalize_shape).astype(np_dtype) naive_flow_gamma = flow.tensor(np_gamma).to(device="cuda", dtype=dtype) flow_bias = None if has_bias: np_bias = np.random.randn(*normalize_shape).astype(np_dtype) flow_bias = flow.tensor(np_bias).to(device="cuda", dtype=dtype) flow_skip_naive = None flow_skip_fused = None np_skip = None if has_skip: np_skip = np.random.randn(*x_shape).astype(np_dtype) flow_skip_naive = flow.tensor(np_skip).to(device="cuda", dtype=dtype) flow_skip_fused = flow.tensor(np_skip).to(device="cuda", dtype=dtype) # naive process flow_naive_module = NaiveSkipRMSNorm() flow_x_naive = flow.tensor(np_x).to(device="cuda", dtype=dtype) flow_y_naive = flow_naive_module.forward( x=flow_x_naive, weight=naive_flow_weight, bias=flow_bias, skip=flow_skip_naive, alpha=alpha, eps=eps, ) # fused process flow_fused_module = FusedSkipRMSNorm() flow_x_fused = flow.tensor(np_x).to(device="cuda", dtype=dtype) flow_y_fused = flow_fused_module.forward( x=flow_x_fused, weight=fused_flow_weight, bias=flow_bias, skip=flow_skip_fused, alpha=alpha, eps=eps, ) if dtype is flow.float16: compare_result(test_case, flow_y_naive, flow_y_fused, 1e-2, 1e-2) else: compare_result(test_case, flow_y_naive, flow_y_fused, 1e-4, 1e-4) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestSkipRMSNorm(flow.unittest.TestCase): def test_gather(test_case): arg_dict = OrderedDict() # set up test functions arg_dict["test_fun"] = [ _test_skip_rms_norm, ] # set up test parameters if is_profiling: arg_dict["x_shape"] = [[1, 5120]] 
arg_dict["has_weight"] = [True] arg_dict["has_bias"] = [True] arg_dict["has_skip"] = [True] arg_dict["eps"] = [1e-6] arg_dict["alpha"] = [1e-5] arg_dict["dtype"] = [flow.float32] else: arg_dict["x_shape"] = [[1, 5120]] arg_dict["has_weight"] = [True, False] arg_dict["has_bias"] = [True, False] arg_dict["has_skip"] = [True, False] arg_dict["eps"] = [1e-6] arg_dict["alpha"] = [1e-5] arg_dict["dtype"] = [flow.float32, flow.float16] # run test functions for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_slice.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict from random import randint import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_slice(test_case, device): np_arr = np.random.randn(3, 6, 9).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]] y = flow.slice(x, slice_tup_list=tup_list) flow_tmp = x[0:3, 0:5, 0:6] y = flow_tmp[::1, ::2, ::3] tmp = np_arr[0:3, 0:5, 0:6] np_out = tmp[::1, ::2, ::3] test_case.assertTrue(np.array_equal(y.numpy(), np_out)) def _test_slice_empty(test_case, device): np_arr = np.random.randn(10).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) y = x[3:3] test_case.assertTrue(y.shape, flow.Size((0,))) np_out = np_arr[3:3] test_case.assertTrue(np.array_equal(y.numpy(), np_out)) def _test_slice_1_dim(test_case, device): np_arr = np.random.randn(100).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) test_case.assertTrue(np.allclose(x[1].numpy(), np_arr[1], 1e-05, 1e-05)) test_case.assertTrue(np.allclose(x[99].numpy(), np_arr[99], 1e-05, 1e-05)) test_case.assertTrue(np.allclose(x[0:2].numpy(), np_arr[0:2], 1e-05, 1e-05)) def _test_slice_3_dim(test_case, device): np_arr = np.random.randn(2, 3, 4).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) test_case.assertTrue(np.allclose(x[:, 0].numpy(), np_arr[:, 0], 1e-05, 1e-05)) def _test_slice_4_dim(test_case, device): np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) tup_list = [[0, 5, 2], [None, None, None], [0, 5, 2], [0, 6, 3]] y = flow.slice(x, slice_tup_list=tup_list) tmp = np_arr[0:5, 0:3, 0:5, 0:6] np_out = tmp[::2, ::1, ::2, ::3] test_case.assertTrue(np.array_equal(y.numpy(), np_out)) def _test_slice_with_int_index(test_case, device): np_arr = np.random.randn(2, 3, 4).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) of_out = x[0, 1:2] np_out = np_arr[0, 1:2] test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) np_arr = np.random.randn(2, 3, 4).astype(np.float32) x = flow.tensor(np_arr, 
device=flow.device(device)) of_out = x[0, :] np_out = np_arr[0, :] test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) np_arr = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) of_out = x[0, :, :] np_out = np_arr[0, :, :] test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) np_arr = np.random.randn(2, 3, 4, 5).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) of_out = x[0, :, :, :] np_out = np_arr[0, :, :, :] test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_slice_negative_index(test_case, device): np_arr = np.random.randn(4, 5, 6) x = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) test_case.assertTrue(np.allclose(x[-1].numpy(), np_arr[-1], 0.0001, 0.0001)) test_case.assertTrue(np.allclose(x[-2].numpy(), np_arr[-2], 0.0001, 0.0001)) test_case.assertTrue(np.allclose(x[-3].numpy(), np_arr[-3], 0.0001, 0.0001)) test_case.assertTrue(np.allclose(x[-4].numpy(), np_arr[-4], 0.0001, 0.0001)) def _test_slice_ellipsis_type(test_case, device): np_arr = np.random.randn(2, 3, 4, 5, 6, 7).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device)) of_out = x[..., ::2, ::2, 3:4] np_out = np_arr[..., ::2, ::2, 3:4] test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) of_out = x[..., 1:2, ::2, 1, ::3] np_out = np_arr[..., 1:2, ::2, 1, ::3] test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) of_out = x[0, 2, ..., 1, 1:2] np_out = np_arr[0, 2, ..., 1, 1:2] test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) of_out = x[::2, ..., 1:2] np_out = np_arr[::2, ..., 1:2] test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) def _test_slice_backward(test_case, device): np_arr = np.random.randn(3, 6, 9).astype(np.float32) x = flow.tensor(np_arr, device=flow.device(device), requires_grad=True) tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]] y = flow.slice(x, slice_tup_list=tup_list) z = y.sum() z.backward() np_grad = np.zeros((3, 6, 9)) np_grad[0:3, 0:5, 0:6][::1, ::2, ::3] = 1 test_case.assertTrue(np.array_equal(x.grad.numpy(), np_grad)) def _test_slice_scalar(test_case, device): dtype = [flow.int8, flow.int16, flow.int32, flow.int64] x = flow.randn(50, 534, 800, device=device) for d in dtype: scalar = flow.tensor(3, dtype=d, device=device) y = x[scalar] test_case.assertTrue(y.shape, (534, 800)) @flow.unittest.skip_unless_1n1d() class TestSlice(flow.unittest.TestCase): def test_slice(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_slice, _test_slice_empty, _test_slice_1_dim, _test_slice_3_dim, _test_slice_4_dim, _test_slice_with_int_index, _test_slice_negative_index, _test_slice_ellipsis_type, _test_slice_backward, _test_slice_scalar, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @flow.unittest.skip_unless_1n1d() class TestSliceUpdate(flow.unittest.TestCase): def test_slice_update(test_case): x = np.array([1, 1, 1, 1, 1]).astype(np.float32) input = flow.tensor(x) update = flow.tensor(np.array([2, 3, 4]).astype(np.float32)) output = np.array([1.0, 2.0, 3.0, 4.0, 1.0]) flow.slice_update(input, update, slice_tup_list=[[1, 4, 1]]) test_case.assertTrue(np.array_equal(input.numpy(), output)) def test_slice_update_negative_index(test_case): np_arr = np.zeros(shape=(2, 3, 4)) input = flow.tensor(np_arr, dtype=flow.float32) np_arr[-1] = 1 input[-1] = 1 test_case.assertTrue(np.array_equal(input.numpy(), np_arr)) def 
test_slice_update_scalar_integer_tensor_index(test_case): np_arr_a = np.random.rand(133, 1, 15) np_arr_b = np.random.rand(133, 2, 1) a_torch = torch.tensor(np_arr_a) b_torch = torch.tensor(np_arr_b) pos_torch = torch.tensor(0) a_torch[:, 0, pos_torch] = b_torch[:, 1, 0] a_flow = flow.tensor(np_arr_a) b_flow = flow.tensor(np_arr_b) pos_flow = flow.tensor(0) a_flow[:, 0, pos_flow] = b_flow[:, 1, 0] test_case.assertTrue( np.allclose(a_flow.numpy(), a_torch.cpu().numpy(), rtol=1e-5, atol=1e-5,) ) def test_slice_update_scalar_boolean_tensor_index(test_case): np_arr_a = np.random.rand(2, 1, 2) np_arr_b = np.random.rand(2, 2, 1) a_torch = torch.tensor(np_arr_a) b_torch = torch.tensor(np_arr_b) pos_torch = torch.tensor(True) a_torch[:, 0, pos_torch] = b_torch[:, 1, 0] a_flow = flow.tensor(np_arr_a) b_flow = flow.tensor(np_arr_b) pos_flow = flow.tensor(True) a_flow[:, 0, pos_flow] = b_flow[:, 1, 0] test_case.assertTrue( np.allclose(a_flow.numpy(), a_torch.cpu().numpy(), rtol=1e-5, atol=1e-5,) ) def test_slice_update_negative_index_graph(test_case): np_arr = np.zeros(shape=(2, 3, 4)) input = flow.tensor(np_arr, dtype=flow.float32) np_arr[-1] = 1 @flow.nn.Graph.trace def test_func(): input[-1] = 1 return input out = test_func() test_case.assertTrue(np.array_equal(out.numpy(), np_arr)) def test_slice_update_different_dtype(test_case): x = np.array([1, 1, 1, 1, 1]).astype(np.float32) for value_type in [np.int32, np.float64]: input = flow.tensor(x) update = flow.tensor(np.array([2, 3, 4]).astype(value_type)) output = np.array([1.0, 2.0, 3.0, 4.0, 1.0]) flow.slice_update(input, update, slice_tup_list=[[1, 4, 1]]) test_case.assertTrue(np.array_equal(input.numpy(), output)) def test_slice_update_ellipsis_type(test_case): np_arr = np.zeros(shape=(2, 3, 4, 5, 6)) input = flow.tensor(np_arr, dtype=flow.float32) np_arr[0, ::1, ..., 2:3] = 1 input[0, ::1, ..., 2:3] = 1 test_case.assertTrue(np.array_equal(input.numpy(), np_arr)) def test_slice_update_ellipsis_type_graph(test_case): np_arr = np.zeros(shape=(2, 3, 4, 5, 6)) input = flow.tensor(np_arr, dtype=flow.float32) np_arr[0, ::1, ..., 2:3] = 1 @flow.nn.Graph.trace def test_func(): input[0, ::1, ..., 2:3] = 1 return input out = test_func() test_case.assertTrue(np.array_equal(out.numpy(), np_arr)) def test_slice_update_grad_graph(test_case): x = np.array([1, 1, 1, 1, 1]).astype(np.float32) input = flow.tensor(x, requires_grad=True) update = flow.tensor(np.array([2, 3, 4]).astype(np.float32), requires_grad=True) output = np.array([1.0, 2.0, 3.0, 4.0, 1.0]) class TestModule(flow.nn.Module): def __init__(self): super().__init__() self.ref_grad = flow.nn.Parameter(flow.zeros(5)) self.value_grad = flow.nn.Parameter(flow.zeros(3)) def forward(self, ref, value): x = ref + self.ref_grad y = value + self.value_grad return flow._C.slice_update(x, y, [1,], [4,], [1,]) test_m = TestModule() of_sgd = flow.optim.SGD(test_m.parameters(), lr=1.0, momentum=0.0) class TestSliceUpdateGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = test_m self.add_optimizer(of_sgd) def build(self, ref, update): x = self.m(ref, update) x.sum().backward() return x slice_update_g = TestSliceUpdateGraph() y = slice_update_g(input, update) # forward test_case.assertTrue(np.array_equal(y.numpy(), output)) # ref grad ref_grad = np.array([1.0, 0.0, 0.0, 0.0, 1.0]).astype(np.float32) test_case.assertTrue(np.array_equal(-test_m.ref_grad, ref_grad)) # value grad value_grad = np.array([1.0, 1.0, 1.0]).astype(np.float32) test_case.assertTrue(np.array_equal(-test_m.value_grad, 
value_grad)) def test_random_nd_slice_update_in_non_contiguous_tensor(test_case): def get_random_slice_tuple(shape): slice_tup = [] slice_size = [] for i in range(len(shape)): start = randint(0, shape[i] - 1) end = randint(start + 1, shape[i]) step = randint(1, end - start + 1) slice_tup.append(slice(start, end, step)) slice_size.append((end - start + step - 1) // step) return tuple(slice_tup), tuple(slice_size) def get_random_update_shape_and_perm(shape): perm = flow.randperm(len(shape)).tolist() no_perm_shape = [shape[i] for i in perm] inv_perm = [0] * len(shape) for i in range(len(shape)): inv_perm[perm[i]] = i return no_perm_shape, inv_perm def compare_result_between_oneflow_and_numpy(test_case, shape): device = random_device().value() # non-contiguous ref ref = ( flow.rand(shape, dtype=flow.float32) .to(device) .permute(flow.randperm(len(shape)).tolist()) ) ref_np = ref.detach().clone().numpy() shape = ref.shape # slice param slice_tup, slice_size = get_random_slice_tuple(shape) # non-contiguous update no_perm_shape, perm = get_random_update_shape_and_perm(slice_size) update = ( flow.rand(no_perm_shape, dtype=flow.float32).to(device).permute(perm) ) update_np = update.detach().clone().numpy() ref_np[slice_tup] = update_np # non-inplace update # NOTE: should test non-inplace first def slice_tuple_to_slice_list(slice_tup): # NOTE: oneflow.slice_update don't support passing slice parameters. slice_list = [] for i in range(len(slice_tup)): slice_list.append( (slice_tup[i].start, slice_tup[i].stop, slice_tup[i].step) ) return slice_list of_res = flow.slice_update( ref, update, slice_tuple_to_slice_list(slice_tup) ) test_case.assertTrue(np.array_equal(of_res.numpy(), ref_np)) # inplace update ref[slice_tup] = update test_case.assertTrue(np.array_equal(ref.numpy(), ref_np)) for dims in (2, 3, 4): for _ in range(10): shape = [randint(1, 21) for _ in range(dims)] compare_result_between_oneflow_and_numpy(test_case, shape) def test_slice_update_expand_value(test_case): ref_np = np.random.rand(2, 3, 4) ref_of = flow.tensor(ref_np) update_np = np.random.rand(3,) update_ref = flow.tensor(update_np) ref_of[:, :, 1] = update_ref ref_np[:, :, 1] = update_np test_case.assertTrue(np.array_equal(ref_of.numpy(), ref_np)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_softmax.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import os from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _dtype_flow_to_np(dtype): return {flow.float32: np.float32, flow.float16: np.float16}[dtype] def _np_softmax(x, dtype=None): if dtype is not None: x = x.astype(dtype) x -= np.max(x, axis=-1, keepdims=True) x = np.exp(x) return x / np.sum(x, axis=-1, keepdims=True) def _test_softmax_impl(test_case, shape, input_dtype, output_dtype): np_input = np.random.randn(*shape).astype(_dtype_flow_to_np(input_dtype)) of_input = flow.tensor(np_input, dtype=input_dtype, device=flow.device("cuda")) of_out = flow.nn.functional.softmax(of_input, dtype=output_dtype) if output_dtype is not None: np_out = _np_softmax(np_input, dtype=_dtype_flow_to_np(output_dtype)) else: np_out = _np_softmax(np_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.001, 0.001)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class Testsoftmax(flow.unittest.TestCase): def test_softmax(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(20, 30), (32, 128)] arg_dict["input_dtype"] = [flow.float16, flow.float32] arg_dict["output_dtype"] = [None, flow.float32] for arg in GenArgList(arg_dict): _test_softmax_impl(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_softplus.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_softplus_impl(test_case, shape, device): np_input = np.random.randn(*shape) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) np_x_grad = np.exp(np_input) / (1 + np.exp(np_input)) of_out = flow.softplus(of_input) np_out = np.log(1 + np.exp(np_input)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_x_grad, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class Testsoftplus(flow.unittest.TestCase): def test_softplus(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_softplus_impl(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_sort.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_sort(test_case, data_shape, axis, descending, data_type, device): input = flow.tensor( np.random.randn(*data_shape), dtype=type_name_to_flow_type[data_type], device=flow.device(device), ) (of_values, of_indices) = flow.sort(input, dim=axis, descending=descending) np_input = -input.numpy() if descending else input.numpy() np_indices = np.argsort(np_input, axis=axis) np_out = np.sort(np_input, axis=axis) np_values = -np_out if descending else np_out test_case.assertTrue( np.array_equal(of_values.numpy().flatten(), np_values.flatten()) ) test_case.assertTrue( np.array_equal(of_indices.numpy().flatten(), np_indices.flatten()) ) def _test_tensor_sort(test_case, data_shape, axis, descending, data_type, device): input = flow.tensor( np.random.randn(*data_shape), dtype=type_name_to_flow_type[data_type], device=flow.device(device), ) (of_values, of_indices) = input.sort(dim=axis, descending=descending) np_input = -input.numpy() if descending else input.numpy() np_indices = np.argsort(np_input, axis=axis) np_out = np.sort(np_input, axis=axis) np_values = -np_out if descending else np_out test_case.assertTrue( np.array_equal(of_values.numpy().flatten(), np_values.flatten()) ) test_case.assertTrue( np.array_equal(of_indices.numpy().flatten(), np_indices.flatten()) ) @flow.unittest.skip_unless_1n1d() class TestSort(flow.unittest.TestCase): def test_sort(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_sort, _test_tensor_sort] arg_dict["data_shape"] = [(2, 6, 5, 4), (3, 4, 8)] arg_dict["axis"] = [-1, 0, 2] arg_dict["descending"] = [True, False] arg_dict["data_type"] = ["double", "float32", "int32"] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, auto_backward=False, check_graph=True) def test_sort_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.sort(x, dim=random(low=-4, high=4).to(int), descending=random_bool()) return y[0], y[1] @autotest(n=5, auto_backward=False, check_graph=True) def test_sort_return_type_with_random_data_(test_case): device = random_device() x = random_tensor(ndim=4).to(device) result = torch.sort( x, dim=random(low=-4, high=4).to(int), descending=random_bool() ) return result.values, result.indices @autotest(n=10, auto_backward=False, check_graph=True) def test_sort_bool_with_random_data(test_case): x = random_tensor(ndim=4).to(device="cpu", dtype=torch.bool) y = torch.sort(x, dim=random(low=-4, high=4).to(int), descending=random_bool()) return y[0], y[1] @autotest(n=10, auto_backward=False, check_graph=True) def test_sort_return_type_bool_with_random_data(test_case): x = random_tensor(ndim=4).to(device="cpu", dtype=torch.bool) result = torch.sort( x, dim=random(low=-4, high=4).to(int), descending=random_bool() ) return result.values, result.indices if __name__ == "__main__": unittest.main() 
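
# The NumPy references above use a simple identity: an ascending sort of the
# negated array, negated back, equals a descending sort of the original array.
# A minimal sketch of that reference, assuming a NumPy array `x` and an
# integer `axis` (illustrative only, not part of the test suite):
#
#   values_desc = -np.sort(-x, axis=axis)
#   indices_desc = np.argsort(-x, axis=axis)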
================================================ FILE: python/oneflow/test/modules/test_sparse.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from collections import OrderedDict import unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest import numpy as np def _test_embedding_padding_idx(test_case, device): indices = flow.tensor( [[1, 0, 4, 8], [8, 3, 0, 9]], dtype=flow.int, device=flow.device(device), requires_grad=False, ) embedding = flow.nn.Embedding(10, 3, padding_idx=0).to(device) output = embedding(indices) test_case.assertEqual(output[0][1].sum(), 0) test_case.assertEqual(output[1][2].sum(), 0) # negative indexing check for padding_idx # padding_idx=-2, num_embeddings=10 ==> index 8 padded embedding = flow.nn.Embedding(10, 3, padding_idx=-2).to(device) output = embedding(indices) test_case.assertEqual(output[0][3].sum(), 0) test_case.assertEqual(output[1][0].sum(), 0) # out of bounds check for padding_idx test_case.assertRaises( AssertionError, flow.nn.Embedding, num_embeddings=10, embedding_dim=3, padding_idx=25, ) test_case.assertRaises( AssertionError, flow.nn.Embedding, num_embeddings=10, embedding_dim=3, padding_idx=-25, ) padding_idx = 0 embedding = flow.nn.Embedding(10, 3, padding_idx=padding_idx).to(device) indices = flow.tensor( [[1, 0, 4, 8], [8, 3, 0, 9]], dtype=flow.int, device=flow.device(device), requires_grad=False, ) pre = embedding.weight[padding_idx].clone() embedding(indices).sum().backward() after = (embedding.weight + embedding.weight.grad)[padding_idx] embedding.zero_grad() test_case.assertTrue(flow.equal(after, pre)) def _test_embedding_scale_by_freq(test_case, device): weight = np.array( [ [0.68258786, 0.6957856, 1.1829041], [1.0154, -1.0616943, 0.50303376], [0.29679507, 0.65562993, 1.0424724], [-0.42980736, -0.35347632, -0.15600166], [0.6763601, -0.24286619, -2.0873115], [-0.13371214, -0.5589277, 1.9173933], [0.08762296, 1.0264007, -0.67938024], [0.32019204, -0.26137325, -1.3534237], [-1.1555519, -0.67776406, 0.27372134], [1.0615997, -0.59715784, 1.9855849], ], dtype=np.float32, ) output = np.array( [ [ [1.0154, -1.0616943, 0.50303376], [0.29679507, 0.65562993, 1.0424724], [0.6763601, -0.24286619, -2.0873115], [-0.13371214, -0.5589277, 1.9173933], ], [ [0.6763601, -0.24286619, -2.0873115], [-0.42980736, -0.35347632, -0.15600166], [0.29679507, 0.65562993, 1.0424724], [1.0615997, -0.59715784, 1.9855849], ], ], dtype=np.float32, ) indices = flow.tensor( [[1, 2, 4, 5], [4, 3, 2, 9]], dtype=flow.int, device=flow.device(device), requires_grad=False, ) m = flow.nn.Embedding(10, 3, scale_grad_by_freq=True, _weight=flow.Tensor(weight)) m = m.to(device) y = m(indices) test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05)) y = y.sum() y.backward() weight_grad_np = [ [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 
0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], ] test_case.assertTrue( np.allclose(m.weight.grad.numpy(), weight_grad_np, 1e-05, 1e-05) ) @flow.unittest.skip_unless_1n1d() class TestEmbedding(flow.unittest.TestCase): def test_padding_idx(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_embedding_padding_idx(test_case, *arg) _test_embedding_scale_by_freq(test_case, *arg) @unittest.skip("skip for now, becase it failed 2 times in past week") @autotest(n=5, check_graph=True) def test_embedding_impl(test_case): device = random_device() emb_size = random(low=2) * 16 emb_dim = random(low=2) * 16 emb_shape = [emb_size, emb_dim] idx_ndim = random(high=4).to(int).value() idx_shape = [random(high=4) for i in range(idx_ndim)] weight = random_tensor(len(emb_shape), *emb_shape).to(device) indices = random_tensor( len(idx_shape), *idx_shape, low=0, high=emb_size, dtype=int ).to(device) embedding = torch.nn.Embedding(emb_size, emb_dim, _weight=weight).to(device) y = embedding(indices) return y @autotest(n=5, check_graph=True) def test_embedding_functional(test_case): device = random_device() emb_size = random(low=2) * 16 emb_dim = random(low=2) * 16 emb_shape = [emb_size, emb_dim] idx_ndim = random(high=4).to(int).value() idx_shape = [random(high=4) for i in range(idx_ndim)] weight = random_tensor(len(emb_shape), *emb_shape).to(device) indices = random_tensor( len(idx_shape), *idx_shape, low=0, high=emb_size, dtype=int ).to(device) y = torch.nn.functional.embedding(indices, weight) return y # NOTE(Yao Zihang): Set check_graph=False temporarily # Graph mode do not support inplace op with flow.no_grad() # See this issue: https://github.com/Oneflow-Inc/OneTeam/issues/1382 @unittest.skip("still have error in ci test. TODO(Yao Zihang)") @autotest(n=5, rtol=1e-03, atol=1e-03, check_graph="ValidatedFalse") def test_embedding_renorm(test_case): device = random_device() emb_size = random(low=2) * 16 emb_dim = random(low=2) * 16 emb_shape = [emb_size, emb_dim] idx_ndim = 2 idx_shape = [random(high=4) for i in range(idx_ndim)] weight = random_tensor(len(emb_shape), *emb_shape).to(device) indices = random_tensor( len(idx_shape), *idx_shape, low=0, high=emb_size, dtype=int ).to(device) embedding = torch.nn.Embedding( emb_size, emb_dim, max_norm=1.0, _weight=weight ).to(device) y = embedding(indices) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_sparse_softmax_cross_entropy.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import os from collections import OrderedDict import numpy as np import torch import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import ( GenArgList, type_name_to_flow_type, type_name_to_np_type, ) def compare_with_torch( device_type, data_type, label_type, batch_size, num_classes, ): data_type = type_name_to_flow_type[data_type] label_type = type_name_to_flow_type[label_type] np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32) np_logits = np.random.random((batch_size, num_classes)).astype(np.float32) torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True) torch_labels = torch.tensor(np_labels, dtype=torch.int64) torch_output = torch.nn.functional.cross_entropy( torch_logits, torch_labels, reduction="none" ) torch_output.sum().backward() of_logits = flow.tensor( np_logits, device=device_type, dtype=data_type, requires_grad=True ) of_labels = flow.tensor(np_labels, device=device_type, dtype=label_type) of_output = flow.nn.functional.sparse_softmax_cross_entropy( labels=of_labels, logits=of_logits ).to(device_type) of_output.sum().backward() assert np.allclose( of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04 ) assert np.allclose( of_logits.grad.numpy(), torch_logits.grad, rtol=1e-03, atol=1e-04 ) def compare_eager_global_with_torch( device_type, data_type, label_type, batch_size, num_classes, ): data_type = type_name_to_flow_type[data_type] label_type = type_name_to_flow_type[label_type] np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32) np_logits = np.random.random((batch_size, num_classes)).astype(np.float32) placement = flow.placement(device_type, range(4)) rank = flow.env.get_rank() if rank == 0: torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True) torch_labels = torch.tensor(np_labels, dtype=torch.int64) torch_output = torch.nn.functional.cross_entropy( torch_logits, torch_labels, reduction="none" ) torch_output.sum().backward() # 1D sbp of_logits = flow.tensor( np_logits, device=device_type, dtype=data_type, requires_grad=True ) flow.comm.broadcast(of_logits, 0) of_logits = of_logits.to_global(placement=placement, sbp=[flow.sbp.broadcast]) of_logits.retain_grad() global_of_logits = of_logits.to_global(placement=placement, sbp=[flow.sbp.split(1)]) of_labels = flow.tensor(np_labels, device=device_type, dtype=label_type) flow.comm.broadcast(of_labels, 0) of_labels = of_labels.to_global(placement=placement, sbp=[flow.sbp.broadcast]) of_output = flow.nn.functional.sparse_softmax_cross_entropy( labels=of_labels, logits=global_of_logits ).to(device_type) of_output.sum().backward() of_logits_grad = of_logits.grad.to_global( placement=placement, sbp=[flow.sbp.broadcast] ) of_logits_grad = of_logits_grad.to_local() of_output = of_output.to_global(placement=placement, sbp=[flow.sbp.broadcast]) of_output = of_output.to_local() if rank == 0: assert np.allclose( of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04 ) assert np.allclose( of_logits_grad.numpy(), torch_logits.grad, rtol=1e-03, atol=1e-04 ) def compare_eager_2d_global_with_torch( device_type, data_type, label_type, batch_size, num_classes, ): data_type = type_name_to_flow_type[data_type] label_type = type_name_to_flow_type[label_type] np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32) np_logits = np.random.random((batch_size, num_classes)).astype(np.float32) rank = flow.env.get_rank() if rank == 0: 
torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True) torch_labels = torch.tensor(np_labels, dtype=torch.int64) torch_output = torch.nn.functional.cross_entropy( torch_logits, torch_labels, reduction="none" ) torch_output.sum().backward() # 2D sbp placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) of_logits = flow.tensor( np_logits, device=device_type, dtype=data_type, requires_grad=True ) flow.comm.broadcast(of_logits, 0) of_logits = of_logits.to_global( placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.broadcast] ) of_logits.retain_grad() global_of_logits = of_logits.to_global( placement=placement, sbp=[flow.sbp.split(0), flow.sbp.split(1)] ) of_labels = flow.tensor(np_labels, device=device_type, dtype=label_type) flow.comm.broadcast(of_labels, 0) of_labels = of_labels.to_global( placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.broadcast] ) of_labels = of_labels.to_global( placement=placement, sbp=[flow.sbp.split(0), flow.sbp.broadcast] ) of_output = flow.nn.functional.sparse_softmax_cross_entropy( labels=of_labels, logits=global_of_logits ).to(device_type) of_output.sum().backward() of_logits_grad = of_logits.grad.to_global( placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.broadcast] ) of_logits_grad = of_logits_grad.to_local() of_output = of_output.to_global( placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.broadcast] ) of_output = of_output.to_local() if rank == 0: assert np.allclose( of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04 ) assert np.allclose( of_logits_grad.numpy(), torch_logits.grad.detach().numpy(), rtol=1e-03, atol=1e-04, ) def compare_lazy_global_with_torch( device_type, data_type, label_type, batch_size, num_classes, ): data_type = type_name_to_flow_type[data_type] label_type = type_name_to_flow_type[label_type] np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32) np_logits = np.random.random((batch_size, num_classes)).astype(np.float32) placement = flow.placement(device_type, range(4)) rank = flow.env.get_rank() if rank == 0: torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True) torch_labels = torch.tensor(np_labels, dtype=torch.int64) torch_output = torch.nn.functional.cross_entropy( torch_logits, torch_labels, reduction="none" ) torch_output.sum().backward() class MyModule(flow.nn.Graph): def __init__(self): super(MyModule, self).__init__() def build(self, logits, labels): output = flow.nn.functional.sparse_softmax_cross_entropy( labels=labels, logits=logits ) # nn.graph no support get input.grad # output.sum().backward() return output of_logits = flow.tensor( np_logits, device=device_type, dtype=data_type, requires_grad=True ) flow.comm.broadcast(of_logits, 0) of_logits = of_logits.to_global(placement=placement, sbp=[flow.sbp.broadcast]) of_logits.retain_grad() global_of_logits = of_logits.to_global(placement=placement, sbp=[flow.sbp.split(1)]) of_labels = flow.tensor(np_labels, device=device_type, dtype=label_type) flow.comm.broadcast(of_labels, 0) of_labels = of_labels.to_global(placement=placement, sbp=[flow.sbp.broadcast]) graph = MyModule() of_output = graph(global_of_logits, of_labels) of_output = of_output.to_global(placement=placement, sbp=[flow.sbp.broadcast]) of_output = of_output.to_local() flow._oneflow_internal.eager.Sync() if rank == 0: assert np.allclose( of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04 ) class TestSparseSoftmaxCrossEntropyWithLogits(flow.unittest.TestCase): 
@flow.unittest.skip_unless_1n1d() def test_sparse_softmax_cross_entropy(test_case): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda", "cpu"] arg_dict["data_type"] = ["float32", "double"] arg_dict["label_type"] = ["int32", "int64"] arg_dict["batch_size"] = [64, 16] arg_dict["num_classes"] = [100, 1000] for arg in GenArgList(arg_dict): compare_with_torch(*arg) class TestSparseSoftmaxCrossEntropyMsWithLogits(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() def test_distributed_sparse_softmax_cross_entropy(test_case): arg_dict = OrderedDict() arg_dict["device_type"] = ["cuda"] arg_dict["data_type"] = ["float32", "double"] arg_dict["label_type"] = ["int32", "int64"] arg_dict["batch_size"] = [64] arg_dict["num_classes"] = [1000] for arg in GenArgList(arg_dict): # compare_eager_global_with_torch(*arg) compare_eager_2d_global_with_torch(*arg) compare_lazy_global_with_torch(*arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_special_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest import torch as torch_original from packaging import version from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestSpecialOps(flow.unittest.TestCase): @autotest(n=5, auto_backward="auto") def test_flow_erf_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) x = random_tensor().to(device).to(x_dtype) y = torch.special.erf(x) return y @autotest(n=5, auto_backward="auto") def test_flow_erfc_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) x = random_tensor().to(device).to(x_dtype) y = torch.special.erfc(x) return y @autotest(n=5, auto_backward="auto") def test_flow_erfinv_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["float"]) x = random_tensor(requires_grad=False).to(device).to(x_dtype) y = torch.special.erfinv(x) return y @autotest(n=5, auto_backward="auto") def test_flow_exp2_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) x = random_tensor().to(device).to(x_dtype) y = torch.special.exp2(x) return y @autotest(n=5, auto_backward="auto") def test_flow_expm1_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) x = random_tensor().to(device).to(x_dtype) y = torch.special.expm1(x) return y @autotest(n=5, auto_backward="auto") def test_flow_round_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) x = random_tensor().to(device).to(x_dtype) y = torch.special.round(x) @autotest(n=5, auto_backward="auto") def 
test_flow_log1p_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) x = random_tensor().to(device).to(x_dtype) y = torch.special.log1p(x) return y @autotest(n=5, auto_backward="auto") def test_flow_log_softmax_with_random_data(test_case): num_dims = random(low=1, high=5).to(int) device = random_device() x = random_tensor(ndim=num_dims).to(device) y = torch.special.log_softmax(x, dim=random(low=0, high=num_dims).to(int)) return y @unittest.skipIf( version.parse(torch_original.__version__) <= version.parse("1.13.0"), "module 'torch.special' has no attribute 'softmax' before '1.13.0'", ) @autotest(n=5, auto_backward="auto") def test_flow_softmax_with_random_data(test_case): num_dims = random(low=1, high=5).to(int) device = random_device() x = random_tensor(ndim=num_dims).to(device) y = torch.special.softmax(x, dim=random(low=0, high=num_dims).to(int)) return y @autotest(n=5, auto_backward="auto") def test_flow_logsumexp_with_random_data(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) y = torch.special.logsumexp(x, dim=np.random.randint(0, 3)) return y @autotest(n=5, auto_backward="auto") def test_flow_digamma_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic", "half"]) x = random_tensor().to(device).to(x_dtype) y = torch.special.digamma(x) return y @autotest(n=5, auto_backward="auto") def test_flow_psi_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic", "half"]) x = random_tensor().to(device).to(x_dtype) y = torch.special.psi(x) return y @flow.unittest.skip_unless_1n1d() class TestZeta(flow.unittest.TestCase): # the grad func of zeta is not supported @autotest(n=5, auto_backward=False) def test_flow_zeta_with_random_data(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) input = ( random_tensor(ndim=2, dim0=20, dim1=20, low=1, high=10) .to(device) .to(x_dtype) ) other = ( random_tensor(ndim=2, dim0=20, dim1=20, low=1, high=10) .to(device) .to(x_dtype) ) out = torch.special.zeta(input, other) return out @autotest(n=5, auto_backward=False) def test_flow_zeta_broadcast_input(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) input = random_tensor(ndim=2, dim0=1, dim1=20).to(device).to(x_dtype) other = random_tensor(ndim=2, dim0=20, dim1=20).to(device).to(x_dtype) out = torch.special.zeta(input, other) return out @autotest(n=5, auto_backward=False) def test_flow_zeta_broadcast_other(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) input = random_tensor(ndim=2, dim0=20, dim1=20).to(device).to(x_dtype) other = random_tensor(ndim=2, dim0=1, dim1=20).to(device).to(x_dtype) out = torch.special.zeta(input, other) return out @autotest(n=5, auto_backward=False) def test_flow_zeta_scalar_input(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) input = random_tensor(ndim=2, dim0=2, dim1=20).to(device).to(x_dtype) out = torch.special.zeta(0.5, input) return out @autotest(n=5, auto_backward=False) def test_flow_zeta_scalar_other(test_case): device = random_device() x_dtype = random_dtype(["arithmetic"]) input = random_tensor(ndim=2, dim0=2, dim1=20).to(device).to(x_dtype) out = torch.special.zeta(input, 0.5) return out if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_split.py ================================================ """ Copyright 2020 The OneFlow Authors.
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np from random import shuffle from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestSplit(flow.unittest.TestCase): @autotest(n=5) def test_flow_split_with_random_data(test_case): k0 = random(2, 6) k1 = random(2, 6) k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) res = torch.split(x, 2, dim=rand_dim) return torch.cat(res, rand_dim) @autotest(n=5, check_graph=True) def test_flow_split_with_stride(test_case): k0 = random(2, 6) k1 = random(2, 6) k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) perm = [0, 1, 2] shuffle(perm) y = x.permute(perm) z = torch.split(y, 2, dim=rand_dim) return torch.cat(z, rand_dim) @autotest(n=5) def test_flow_split_sizes_with_random_data(test_case): k0 = random(2, 6) k1 = 7 k2 = random(2, 6) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) res = torch.split(x, [1, 2, 3, 1], dim=1) return torch.cat(res, dim=1) @autotest(n=5) def test_flow_split_sizes_neg_dim_with_random_data(test_case): k0 = random(2, 6) k1 = 7 k2 = random(2, 6) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) res = torch.split(x, [1, 2, 3, 1], dim=-2) return torch.cat(res, dim=1) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_split_bool_with_random_data(test_case): k0 = random(2, 6) k1 = random(2, 6) k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to( device=device, dtype=torch.bool ) res = torch.split(x, split_size_or_sections=2, dim=rand_dim) return torch.cat(res, rand_dim) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_square_relu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest import torch class SquareReLUActivation(torch.nn.Module): """ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 """ def forward(self, input): relu_applied = torch.nn.functional.relu(input) squared = torch.square(relu_applied) return squared def _test_square_relu(test_case, device): torch_square_relu = SquareReLUActivation() x = np.random.randn(2, 4, 3) torch_x = torch.tensor(x, requires_grad=True, device=torch.device(device)) oneflow_x = flow.tensor(x, requires_grad=True, device=flow.device(device)) torch_y = torch_square_relu(torch_x) oneflow_y = flow._C.square_relu(oneflow_x) test_case.assertTrue(np.allclose(torch_y.detach().cpu().numpy(), oneflow_y.numpy())) torch_y_sum = torch_y.sum() torch_y_sum.backward() oneflow_y_sum = oneflow_y.sum() oneflow_y_sum.backward() test_case.assertTrue( np.allclose(torch_x.grad.cpu().numpy(), oneflow_x.grad.numpy()) ) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_square_relu(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_square_relu] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_squeeze.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_squeeze(test_case, device): np_arr = np.random.rand(1, 1, 1, 3) input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) of_shape = flow.squeeze(input, dim=[1, 2]).numpy().shape np_shape = (1, 3) test_case.assertTrue(np.array_equal(of_shape, np_shape)) test_case.assertTrue( np.allclose( flow.squeeze(input, dim=[1, 2]).numpy(), np.squeeze(input.numpy(), axis=(1, 2)), 0.0001, 0.0001, ) ) def _test_squeeze_1d_input(test_case, device): np_arr = np.random.rand(10) input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) output = flow.squeeze(input) test_case.assertTrue(np.allclose(output.numpy(), np_arr, 1e-05, 1e-05)) def _test_tensor_squeeze(test_case, device): np_arr = np.random.rand(1, 1, 1, 3) input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) of_shape = input.squeeze(dim=[1, 2]).numpy().shape np_shape = (1, 3) test_case.assertTrue(np.array_equal(of_shape, np_shape)) test_case.assertTrue( np.allclose( input.squeeze(dim=[1, 2]).numpy(), np.squeeze(input.numpy(), axis=(1, 2)), 0.0001, 0.0001, ) ) def _test_squeeze_int(test_case, device): np_arr = np.random.rand(1, 1, 1, 3) input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) of_shape = flow.squeeze(input, 1).numpy().shape np_shape = (1, 1, 3) test_case.assertTrue(np.array_equal(of_shape, np_shape)) test_case.assertTrue( np.allclose( input.squeeze(1).numpy(), np.squeeze(input.numpy(), axis=1), 0.0001, 0.0001 ) ) def _test_squeeze_backward(test_case, device): np_arr = np.random.rand(1, 1, 1, 3) input = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) y = flow.squeeze(input, dim=1).sum() y.backward() np_grad = np.ones((1, 1, 1, 3)) test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad)) @flow.unittest.skip_unless_1n1d() class TestSqueeze(flow.unittest.TestCase): def test_squeeze(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_squeeze, _test_squeeze_1d_input, _test_squeeze_int, _test_tensor_squeeze, _test_squeeze_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(check_graph=True) def test_flow_squeeze_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.squeeze(x, random(1, 3).to(int)) return y @autotest(n=10, check_graph=False, auto_backward=False) def test_inplace_squeeze_with_random_data(test_case): device = random_device() x = random_tensor(requires_grad=False).to(device) y = x.squeeze_(random(1, 3).to(int)) return y @autotest(auto_backward=False, check_graph=True) def test_squeeze_with_0_size_data(test_case): device = random_device() x = random_tensor(3, 2, 1, 0).to(device) y = torch.squeeze(x) return y @autotest(auto_backward=False, check_graph=True) def test_flow_squeeze_bool_with_random_data(test_case): device = random_device() x = random_tensor().to(device=device, dtype=torch.bool) y = torch.squeeze(x, random(1, 3).to(int)) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_stack.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestStackModule(flow.unittest.TestCase): @autotest(check_graph=True) def test_stack_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device) y = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device) out = torch.stack((x, y), dim=random(low=-5, high=5).to(int)) return out @autotest(auto_backward=False, check_graph=True) def test_stack_bool_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to( device=device, dtype=torch.bool ) y = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to( device=device, dtype=torch.bool ) out = torch.stack((x, y), dim=random(low=1, high=4).to(int)) return out @autotest(check_graph=True) def test_column_stack_with_random_data(test_case): device = random_device() x = random_tensor(ndim=1, dim0=10).to(device) y = random_tensor(ndim=2, dim0=10, dim1=5).to(device) z = random_tensor(ndim=2, dim0=10, dim1=5).to(device) out = torch.column_stack((x, y, z)) return out def test_column_stack_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=1, dim0=1).to(device) out = torch.column_stack((x, y)) return out @autotest(check_graph=True) def test_row_stack_with_random_data(test_case): device = random_device() x = random_tensor(ndim=1, dim0=10).to(device) y = random_tensor(ndim=2, dim0=5, dim1=10).to(device) z = random_tensor(ndim=2, dim0=5, dim1=10).to(device) out = torch.row_stack((x, y, z)) return out def test_row_stack_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=1, dim0=1).to(device) out = torch.row_stack((x, y)) return out @autotest(check_graph=True) def test_hstack_with_random_data(test_case): device = random_device() x = random_tensor(ndim=1, dim0=5).to(device) y = random_tensor(ndim=1, dim0=5).to(device) out = torch.hstack((x, y)) return out @autotest(check_graph=True) def test_hstack_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) # test 1-dim simultaneouslsimultaneouslyy z = random_tensor(ndim=1, dim0=1).to(device) out = torch.hstack((x, y, z)) return out @autotest(check_graph=True) def test_vstack_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim0=3, dim1=4).to(device) y = random_tensor(ndim=1, dim0=4).to(device) z = random_tensor(ndim=2, dim0=3, dim1=4).to(device) out = torch.vstack((x, y, z)) return out @autotest(check_graph=True) def test_vstack_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) out = torch.vstack((x, y)) return out @autotest(check_graph=True) def test_dstack_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim0=1, dim1=4).to(device) y 
= random_tensor(ndim=3, dim0=1, dim1=4, dim2=1).to(device) z = random_tensor(ndim=1, dim0=4).to(device) out = torch.dstack((x, y, z)) return out @autotest(check_graph=True) def test_dstack_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) z = random_tensor(ndim=0).to(device) out = torch.dstack((x, y, z)) @autotest(auto_backward=True, check_graph=True) def test_stack_kMaxInputCount_inputs(test_case): kMaxInputCount = 128 + 1 stack_list = [ random_tensor(ndim=2, dim0=3, dim1=4) for _ in range(kMaxInputCount) ] out = torch.stack(stack_list, 0) return out if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_stateful_kernel_with_cache.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import os import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestStatefulKernelWithInpersistentState(flow.unittest.TestCase): def test_stateful_kernel_with_inpersistent_state(test_case): x = flow.arange(4).reshape(2, 2) x = x.to_global(flow.placement.all("cuda"), flow.sbp.split(0)) y = x[0:3, 0:1] y_np = np.array([[0], [2], [0]]) test_case.assertTrue( np.array_equal(y.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), y_np) ) x = x.to_global(sbp=flow.sbp.split(1)) y = x[0:3, 0:1] test_case.assertTrue( np.array_equal(y.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), y_np) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_stateful_local_opkernel.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import oneflow as flow import oneflow.unittest import numpy as np @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestStatefulLocalKernel(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_dynamic_attrs(test_case): x = flow.full((2, 3), 3.0) y = flow.unsqueeze(x, dim=1) test_case.assertEqual(y.shape, flow.Size((2, 1, 3))) y = flow.unsqueeze(x, dim=2) test_case.assertEqual(y.shape, flow.Size((2, 3, 1))) @flow.unittest.skip_unless_1n2d() def test_stateful_local_kernel_in_global_mode(test_case): rank = int(os.getenv("RANK")) x = flow.tensor(np.array([1, 2]) * (rank + 1)).to("cuda") x = x.to_global(flow.placement("cuda", range(2)), flow.sbp.split(0)) y = flow.tensor([3, 4, 5]).to("cuda") y = y.to_global(flow.placement("cuda", range(2)), flow.sbp.broadcast) # logical slice assign op needs sbp and logical shape from stateful local opkernel x[:3] = y x = x.to_global(sbp=flow.sbp.broadcast) test_case.assertTrue( np.array_equal(x.to_local().numpy(), np.array([3, 4, 5, 4])) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_std.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestStd(flow.unittest.TestCase): @autotest(n=10, auto_backward=False, rtol=0.01, atol=0.01, check_graph=True) def test_std_flow_with_random_data(test_case): device = random_device() all_dim = random().to(int) dim = random(low=0, high=6).to(int) x = random_tensor(ndim=all_dim, low=2, high=6).to(device) z = torch.std( x, dim=dim, unbiased=random().to(bool), keepdim=random().to(bool), ) return z @autotest(n=10, auto_backward=False, rtol=0.01, atol=0.01, check_graph=True) def test_std_tensor_with_random_data(test_case): device = random_device() dim = random(low=0, high=4).to(int) x = random_tensor( ndim=4, dim0=random(2, 4), dim1=random(2, 4), dim2=random(2, 4), dim3=random(2, 4), ).to(device) z = x.std(dim=dim, keepdim=random().to(bool),) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_stft.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from numpy import random import unittest from collections import OrderedDict import numpy as np import re import oneflow as flow from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * def getRandBoolvalue(): value = np.random.randint(0, 2) if value == 1: return True else: return False def getRandFFtvalue(): pow = np.random.randint(2, 5) result = 1 for i in range(pow): result = result * 2 return result def is_cufft_available(): if flow.cuda.is_available(): (major, _minor) = flow.cuda.get_device_capability() return major >= 7 else: return False class TestStft(flow.unittest.TestCase): @autotest( n=20, check_graph=False, check_grad_use_random_data=False, auto_backward=False, ) def test_stft_with_1D_random_data(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() rand_fft = getRandFFtvalue() rand_size = np.random.randint(rand_fft, 300) input_dims = [rand_size] win_dims = [rand_fft] x = random_tensor(1, *input_dims).to(device) win = random_tensor(1, *win_dims).to(device) onesided_value = getRandBoolvalue() center_value = getRandBoolvalue() normalized_value = getRandBoolvalue() y = torch.stft( x, n_fft=rand_fft, window=win, return_complex=False, onesided=onesided_value, center=center_value, normalized=normalized_value, ) return y def test_stft_with_2D_random_data(test_case): if is_cufft_available(): device = random_device() else: device = cpu_device() row_rand_size = np.random.randint(1, 50) rand_fft = getRandFFtvalue() col_rand_size = np.random.randint(rand_fft, 300) input_dims = [row_rand_size, col_rand_size] win_dims = [rand_fft] x = random_tensor(2, *input_dims).to(device) win = random_tensor(1, *win_dims).to(device) onesided_value = getRandBoolvalue() center_value = getRandBoolvalue() normalized_value = getRandBoolvalue() y = torch.stft( x, n_fft=rand_fft, window=win, return_complex=False, onesided=onesided_value, center=center_value, normalized=normalized_value, ) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_sub.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import torch as torch_original from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_sub_impl(test_case, shape, device): x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.sub(x, y) np_out = np.subtract(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad_x = np.ones(shape) np_grad_y = -np.ones(shape) test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad_x, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(y.grad.numpy(), np_grad_y, 1e-05, 1e-05)) x = 5 y = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) of_out = flow.sub(x, y) np_out = np.subtract(x, y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) y = 5 of_out = flow.sub(x, y) np_out = np.subtract(x.numpy(), y) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) x = flow.tensor( np.random.randn(*shape), dtype=flow.float32, device=flow.device(device) ) y = flow.tensor( np.random.randn(1, 1), dtype=flow.float32, device=flow.device(device) ) of_out = flow.sub(x, y) np_out = np.subtract(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) x = flow.tensor(np.array([5.0]), dtype=flow.float32) y = flow.tensor(np.random.randn(1, 1), dtype=flow.float32) of_out = flow.sub(x, y) np_out = np.subtract(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) x = flow.tensor(np.random.randn(1, 1), dtype=flow.float32, requires_grad=True) y = flow.tensor(np.array([5.0]), dtype=flow.float32, requires_grad=True) of_out = flow.sub(x, y) np_out = np.subtract(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad_x = np.ones((1, 1)) np_grad_y = -np.ones(1) test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad_x, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(y.grad.numpy(), np_grad_y, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestSubModule(flow.unittest.TestCase): def test_sub(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_sub_impl(test_case, *arg) @autotest(n=5, auto_backward=False, check_graph=True, include_complex=True) def test_random_dim_sub(test_case): device = random_device() dim0 = random(low=1, high=4).to(int) dim1 = random(low=1, high=4).to(int) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) y = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) z = x - y return z @autotest(n=5, auto_backward=False, check_graph=True, include_complex=True) def test_random_dim_scalar_sub(test_case): device = random_device() dim0 = random(low=1, high=4).to(int) dim1 = random(low=1, high=4).to(int) x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) y = random_tensor(ndim=0).to(device) z = x - y return z @autotest(n=5, auto_backward=False, check_graph=True, 
include_complex=True) def test_sub_with_0_size_data(test_case): device = random_device() x = random_tensor(2, 0, 3).to(device) y = random_tensor(2, 1, 3).to(device) out1 = x - y out2 = x - 2 out3 = 2 - x out4 = torch.sub(x, y) return out1, out2, out3, out4 @autotest(n=5, auto_backward=False, check_graph=True, include_complex=True) def test_sub_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) out1 = x - y out2 = x - 2 out3 = 2 - x out4 = torch.sub(x, y) return out1, out2, out3, out4 @autotest(n=5, include_complex=True) def test_sub_with_alpha(test_case): device = random_device() x1 = random_tensor(2, 2, 3).to(device) x2 = random_tensor(2, 2, 3).to(device) x3 = random_tensor(2, 2, 3).to(device) y = random_tensor(2, 2, 3).to(device) s = random().to(float) alpha = random().to(float) z1 = torch.sub(x1, y, alpha=alpha) z2 = torch.sub(x2, s, alpha=alpha) z3 = torch.sub(s, x3, alpha=alpha) return z1, z2, z3 @autotest(n=5, include_complex=True) def test_non_contiguous_inplace_sub(test_case): device = random_device() x = random_tensor(2, 2, 4).to(device) y = x + 1 y = y[:, 1:3] y -= random_tensor(2, 2, 2).to(device) return y @unittest.skip("skip for now, becase it failed 2 times in past week") @autotest(n=5, include_complex=True) def test_scalar_sub_with_random_devices(test_case): x1_device = random_device() x2_device = random_device() x1 = random_tensor(2, 2, 3).to(x1_device).mean() x2 = random_tensor(2, 2, 3).to(x2_device) y = x1 - x2 return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_sum.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_sum_impl(test_case, device, data_type): if device == "cpu" and data_type == flow.float16: return input = flow.tensor( np.random.randn(2, 3) - 0.5, dtype=data_type, device=flow.device(device) ) of_out = flow.sum(input, dim=0) np_out = np.sum(input.numpy(), axis=0) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) input = flow.tensor( np.random.randn(2, 3), dtype=data_type, device=flow.device(device) ) of_out = flow.sum(input, dim=0) np_out = np.sum(input.numpy(), axis=0) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) input = flow.tensor( np.random.randn(2, 3), dtype=data_type, device=flow.device(device) ) of_out = flow.sum(input, dim=1) of_out2 = input.sum(dim=1) np_out = np.sum(input.numpy(), axis=1) test_case.assertTrue(np.allclose(of_out2.numpy(), of_out.numpy(), 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) input = flow.tensor( np.random.randn(4, 5, 6) - 0.5, dtype=data_type, device=flow.device(device), requires_grad=True, ) of_out = flow.sum(input, dim=(2, 1)) np_out = np.sum(input.numpy(), axis=(2, 1)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = np.ones((4, 5, 6)) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) # For 0-dim tensor test input = flow.tensor(1.0) of_out = input.sum() test_case.assertTrue(np.allclose(input.numpy(), of_out.numpy(), 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestSumModule(flow.unittest.TestCase): def test_sum(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["data_type"] = [flow.float16, flow.float32] for arg in GenArgList(arg_dict): _test_sum_impl(test_case, *arg) @autotest(check_graph=True, include_complex=True) def test_sum_against_pytorch(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) y = torch.sum(x) return y @autotest(check_graph=True, auto_backward=False) def test_sum_dtype(test_case): device = random_device() x = random_tensor(4, requires_grad=False).to(device) y = torch.sum( x, dim=np.random.randint(0, 3), keepdim=random_bool(), dtype=random_dtype(["arithmetic"]), ) return y @autotest( n=10, check_graph=False, auto_backward=True, include_complex=True, atol=1e-2, rtol=1e-5, ) def test_sum_complex_dtype(test_case): device = random_device() x = random_tensor(4, dtype=complex, requires_grad=True).to( device=device, dtype=random_dtype(["complex"]) ) y = torch.sum( x, dim=np.random.randint(0, 3), keepdim=random_bool(), dtype=random_dtype(["complex"]), ) return y @autotest( n=10, check_graph=False, auto_backward=True, include_complex=True, atol=1e-2, rtol=1e-5, ) def test_sum_complex_dtype(test_case): device = random_device() x = random_tensor(4, dtype=complex, requires_grad=True).to( device=device, dtype=random_dtype(["complex"]) ) y = torch.sum( x, dim=np.random.randint(0, 3), keepdim=random_bool(), dtype=random_dtype(["complex"]), ) return y @autotest(check_graph=True, auto_backward=False) def test_sum_arithmetic_dtype(test_case): device = random_device() x = random_tensor(4, requires_grad=False).to(device) y = torch.sum(x, dtype=random_dtype(["arithmetic"])) return y @autotest(auto_backward=False, check_graph=True, include_complex=True) def 
test_sum_with_0_size_tensor(test_case): device = random_device() x = random_tensor(4, 4, 3, 0, 2).to(device) y = torch.sum(x, dim=np.random.randint(0, 3)) return y @autotest(auto_backward=False, check_graph=True, include_complex=True) def test_sum_with_0dim_tensor(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.sum(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_swapaxes.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from random import shuffle from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestSwapaxes(flow.unittest.TestCase): @autotest(check_graph=True) def test_swapaxes_flow_with_random_data(test_case): device = random_device() x = random_tensor(ndim=3).to(device) y = torch.swapaxes(x, random(0, 2).to(int), random(0, 2).to(int)) return y @autotest(n=10) def test_swapaxes_flow_with_stride(test_case): device = random_device() x = random_tensor(ndim=3).to(device) perm = [0, 1, 2] shuffle(perm) y = x.permute(perm) z = torch.swapaxes(y, random(0, 2).to(int), random(0, 2).to(int)) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_swapdims.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class Testswapdims(flow.unittest.TestCase): @autotest(check_graph=True) def test_swapdims_flow_with_random_data(test_case): device = random_device() x = random_tensor(ndim=3).to(device) y = torch.swapdims(x, np.random.randint(0, 3), np.random.randint(0, 3)) return y @autotest(check_graph=True) def test_swapdims_flow_with_random_data2(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.swapdims(x, np.random.randint(0, 4), np.random.randint(0, 4)) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_swautils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r""" This test module references to pytorch. https://github.com/pytorch/pytorch/blob/master/test/test_optim.py. """ import math import unittest import itertools import numpy as np import oneflow as flow import oneflow.optim as optim import oneflow.nn.functional as F from oneflow.nn import Parameter from oneflow.optim import SGD, Optimizer from oneflow.nn.optimizer.lr_scheduler import LRScheduler from oneflow.nn.optimizer.multiplicative_lr import MultiplicativeLR from oneflow.nn.optimizer.swa_utils import AveragedModel, SWALR, update_bn import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestLRScheduler(flow.unittest.TestCase): # This class mainly used to test MultiplicativeLR and SWALR def setUp(self): super(TestLRScheduler, self).setUp() self.net = SchedulerTestNet() self.opt = SGD( [ {"params": self.net.conv1.parameters()}, {"params": self.net.conv2.parameters(), "lr": 0.5}, ], lr=0.05, ) def test_multiplicative_lr(self): # test Multiplicative lr epochs = 10 self.opt.param_groups[0]["lr"] = 0.05 self.opt.param_groups[1]["lr"] = 0.4 targets = [ [0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)], ] scheduler = MultiplicativeLR( self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8] ) self._test(scheduler, targets, epochs) def _test(self, schedulers, targets, epochs=10): if isinstance(schedulers, LRScheduler): schedulers = [schedulers] for epoch in range(epochs): for param_group, target in zip(self.opt.param_groups, targets): self.assertTrue( np.allclose( target[epoch], param_group["lr"], atol=1e-6, rtol=1e-5, ), msg="LR is wrong in epoch {}: expected {}, got {}".format( epoch, target[epoch], param_group["lr"] ), ) [scheduler.step() for scheduler in schedulers] def test_swa_lr_state_dict(self): self._check_scheduler_state_dict( lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5), lambda: SWALR( self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0 ), ) def 
_check_scheduler_state_dict(self, constr, constr2, epochs=10): scheduler = constr() for _ in range(epochs): scheduler.optimizer.step() scheduler.step() scheduler_copy = constr2() scheduler_copy.load_state_dict(scheduler.state_dict()) for key in scheduler.__dict__.keys(): if key != "optimizer": self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr()) def test_swalr_no_anneal(self): epochs, swa_start, swa_lr = 10, 5, 0.01 initial_lrs = [group["lr"] for group in self.opt.param_groups] targets = [ [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1) for lr in initial_lrs ] swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr) self._test_swalr(swa_scheduler, None, targets, swa_start, epochs) def test_swalr_cosine_anneal_after_multiplicative(self): # same swa_lr for different param_groups epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5 mult_factor = 0.9 scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr) def anneal_coef(t): if t + 1 >= anneal_epochs: return 0.0 return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2 initial_lrs = [group["lr"] for group in self.opt.param_groups] targets_before_swa = [ [lr * mult_factor ** i for i in range(swa_start + 1)] for lr in initial_lrs ] swa_epochs = epochs - swa_start - 1 targets = [ lrs + [ lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs) ] for lrs in targets_before_swa ] self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs): for epoch in range(epochs): for param_group, target in zip(self.opt.param_groups, targets): self.assertTrue( np.allclose( target[epoch], param_group["lr"], atol=1e-6, rtol=1e-5, ), msg="LR is wrong in epoch {}: expected {}, got {}".format( epoch, target[epoch], param_group["lr"] ), ) if epoch >= swa_start: self.opt.step() swa_scheduler.step() elif scheduler is not None: self.opt.step() scheduler.step() def test_swalr_hypers(self): # Test that SWALR raises errors for incorrect hyper-parameters with self.assertRaisesRegex(ValueError, "anneal_strategy must"): swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0) with self.assertRaisesRegex(ValueError, "anneal_epochs must"): swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0) with self.assertRaisesRegex(ValueError, "anneal_epochs must"): swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0) with self.assertRaisesRegex(ValueError, "swa_lr must"): swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01]) @flow.unittest.skip_unless_1n1d() class TestSWAUtils(flow.unittest.TestCase): # This class mainly used to test AveragedModel and update_bn def _test_averaged_model(self, net_device, swa_device): # test the average of AveragedModel dnn = flow.nn.Sequential( flow.nn.Conv2d(1, 5, kernel_size=3), flow.nn.ReLU(), flow.nn.MaxPool2d(kernel_size=2), flow.nn.BatchNorm2d(5, momentum=0.3), flow.nn.Conv2d(5, 2, kernel_size=3), flow.nn.ReLU(), flow.nn.Linear(5, 5), flow.nn.ReLU(), flow.nn.Linear(5, 10), ).to(net_device) averaged_dnn = AveragedModel(dnn, device=swa_device) averaged_params = [flow.zeros_like(param) for param in dnn.parameters()] n_updates = 10 for i in range(n_updates): for p, p_avg in zip(dnn.parameters(), averaged_params): p.detach().add_(flow.randn_like(p)) p_avg += p.detach() / n_updates if i == 0: 
averaged_dnn.update_parameters(dnn) else: averaged_dnn.update_parameters(dnn) for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): self.assertTrue( flow.allclose(p_avg.cpu(), p_swa.cpu(), atol=1e-5, rtol=1e-4) ) # Check that AveragedModel is on the correct device self.assertTrue(p_swa.device == swa_device) self.assertTrue(p.device == net_device) self.assertTrue(averaged_dnn.n_averaged.device == swa_device) def test_averaged_model_all_devices(self): cpu = flow.device("cpu") self._test_averaged_model(cpu, cpu) if flow.cuda.is_available(): cuda = flow.device("cuda:0") self._test_averaged_model(cuda, cpu) self._test_averaged_model(cpu, cuda) self._test_averaged_model(cuda, cuda) def test_averaged_model_mixed_device(self): if not flow.cuda.is_available(): return dnn = flow.nn.Sequential( flow.nn.Conv2d(1, 5, kernel_size=3), flow.nn.Linear(5, 10) ) dnn[0].cuda() dnn[1].cpu() averaged_dnn = AveragedModel(dnn) averaged_params = [flow.zeros_like(param) for param in dnn.parameters()] n_updates = 10 for i in range(n_updates): for p, p_avg in zip(dnn.parameters(), averaged_params): p.detach().add_(flow.randn_like(p)) p_avg += p.detach() / n_updates averaged_dnn.update_parameters(dnn) for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): self.assertTrue(flow.allclose(p_avg, p_swa, atol=1e-5, rtol=1e-4)) # Check that AveragedModel is on the correct device self.assertTrue(p_avg.device == p_swa.device) def test_averaged_model_state_dict(self): dnn = flow.nn.Sequential( flow.nn.Conv2d(1, 5, kernel_size=3), flow.nn.Linear(5, 10) ) averaged_dnn = AveragedModel(dnn) averaged_dnn2 = AveragedModel(dnn) n_updates = 10 for i in range(n_updates): for p in dnn.parameters(): p.detach().add_(flow.randn_like(p)) averaged_dnn.update_parameters(dnn) averaged_dnn2.load_state_dict(averaged_dnn.state_dict()) for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()): self.assertTrue(flow.allclose(p_swa, p_swa2, atol=1e-5, rtol=1e-4)) self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged) def test_averaged_model_exponential(self): # Test AveragedModel with EMA as avg_fn dnn = flow.nn.Sequential( flow.nn.Conv2d(1, 5, kernel_size=3), flow.nn.BatchNorm2d(5, momentum=0.3), flow.nn.Linear(5, 10), ) alpha = 0.9 def avg_fn(p_avg, p, n_avg): return alpha * p_avg + (1 - alpha) * p averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn) averaged_params = [flow.zeros_like(param) for param in dnn.parameters()] n_updates = 10 for i in range(n_updates): updated_averaged_params = [] for p, p_avg in zip(dnn.parameters(), averaged_params): p.detach().add_(flow.randn_like(p)) if i == 0: updated_averaged_params.append(p.clone()) else: updated_averaged_params.append( (p_avg * alpha + p * (1 - alpha)).clone() ) for b in dnn.buffers(): if b.size() != flow.Size([]): # oneflow don't support detach_ # b.detach_().add_(flow.randn_like(b)) b.detach().add_(flow.randn_like(b)) averaged_dnn.update_parameters(dnn) averaged_params = updated_averaged_params for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): self.assertTrue(flow.allclose(p_avg, p_swa, atol=1e-5, rtol=1e-4)) for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()): self.assertTrue(flow.allclose(b_avg, b_swa, atol=1e-5, rtol=1e-4)) def test_averaged_model_exponential_buffers(self): # Test AveragedModel with EMA as avg_fn and use_buffers as True. 
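        # use_buffers=True asks update_parameters() to average the registered
        # buffers (e.g. BatchNorm running statistics) together with the
        # parameters. A short sketch of the exponential moving average that
        # avg_fn below implements (alpha is the constant defined in this test):
        #
        #   p_avg_new = alpha * p_avg + (1 - alpha) * p
        #
        # The first update_parameters() call copies the model's current values,
        # which is why the expected result for i == 0 is p.clone() rather than
        # the blended average.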
dnn = flow.nn.Sequential( flow.nn.Conv2d(1, 5, kernel_size=3), flow.nn.BatchNorm2d(5, momentum=0.3), flow.nn.Linear(5, 10), ) alpha = 0.9 def avg_fn(p_avg, p, n_avg): return alpha * p_avg + (1 - alpha) * p averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True) dnn_params = itertools.chain(dnn.parameters(), dnn.buffers()) averaged_params = [ flow.zeros_like(param) for param in dnn_params if param.size() != flow.Size([]) ] n_updates = 10 for i in range(n_updates): updated_averaged_params = [] for p, p_avg in zip(dnn_params, averaged_params): if p.size() == flow.Size([]): continue p.detach().add_(flow.randn_like(p)) if i == 0: updated_averaged_params.append(p.clone()) else: updated_averaged_params.append( (p_avg * alpha + p * (1 - alpha)).clone() ) averaged_dnn.update_parameters(dnn) averaged_params = updated_averaged_params for p_avg, p_swa in zip( averaged_params, itertools.chain( averaged_dnn.module.parameters(), averaged_dnn.module.buffers() ), ): self.assertTrue(flow.allclose(p_avg, p_swa, atol=1e-5, rtol=1e-4)) def _test_update_bn(self, dnn, dl_x, dl_xy, momentum, cuda): preactivation_sum = flow.zeros(dnn.n_features) preactivation_squared_sum = flow.zeros(dnn.n_features) if cuda: preactivation_sum = preactivation_sum.cuda() preactivation_squared_sum = preactivation_squared_sum.cuda() total_num = 0 for x in dl_x: x = x[0] if cuda: x = x.cuda() dnn.forward(x) preactivations = dnn.compute_preactivation(x) if len(preactivations.shape) == 4: preactivations = preactivations.transpose(1, 3) preactivations = preactivations.contiguous().view(-1, dnn.n_features) total_num += preactivations.shape[0] preactivation_sum += flow.sum(preactivations, dim=0) preactivation_squared_sum += flow.sum(preactivations ** 2, dim=0) preactivation_mean = preactivation_sum / total_num preactivation_var = preactivation_squared_sum / total_num preactivation_var = preactivation_var - preactivation_mean ** 2 update_bn(dl_xy, dnn, device=x.device) self.assertTrue( flow.allclose(preactivation_mean, dnn.bn.running_mean, atol=1e-6, rtol=1e-3) ) self.assertTrue( flow.allclose(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=1e-1) ) def _reset_bn(module): if issubclass(module.__class__, flow.nn.modules.batchnorm._BatchNorm): module.running_mean = flow.zeros_like(module.running_mean) module.running_var = flow.ones_like(module.running_var) # reset batch norm and run update_bn again dnn.apply(_reset_bn) update_bn(dl_xy, dnn, device=x.device) self.assertTrue( flow.allclose(preactivation_mean, dnn.bn.running_mean, atol=1e-6, rtol=1e-3) ) self.assertTrue( flow.allclose(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=1e-1) ) # using the dl_x loader instead of dl_xy dnn.apply(_reset_bn) update_bn(dl_x, dnn, device=x.device) self.assertTrue( flow.allclose(preactivation_mean, dnn.bn.running_mean, atol=1e-6, rtol=1e-3) ) self.assertTrue( flow.allclose(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=1e-1) ) def test_update_bn_dnn(self): # Test update_bn for a fully-connected network with BatchNorm1d objects, input_features = 100, 5 x = flow.rand(objects, input_features) y = flow.rand(objects) ds_x = flow.utils.data.TensorDataset(x) ds_xy = flow.utils.data.TensorDataset(x, y) dl_x = flow.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) dl_xy = flow.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True) dnn = SWATestDNN(input_features=input_features) dnn.train() self._test_update_bn(dnn, dl_x, dl_xy, 0.1, False) if flow.cuda.is_available(): dnn = SWATestDNN(input_features=input_features)
dnn.train() self._test_update_bn(dnn.cuda(), dl_x, dl_xy, 0.1, True) self.assertTrue(dnn.training) def test_update_bn_cnn(self): # Test update_bn for convolutional network and BatchNorm2d objects = 100 input_channels = 3 height, width = 5, 5 x = flow.rand(objects, input_channels, height, width) y = flow.rand(objects) ds_x = flow.utils.data.TensorDataset(x) ds_xy = flow.utils.data.TensorDataset(x, y) dl_x = flow.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) dl_xy = flow.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True) dnn = SWATestCNN(input_channels=input_channels) dnn.train() self._test_update_bn(dnn, dl_x, dl_xy, 0.3, False) if flow.cuda.is_available(): dnn = SWATestCNN(input_channels=input_channels) dnn.train() self._test_update_bn(dnn.cuda(), dl_x, dl_xy, 0.3, True) self.assertTrue(dnn.training) def test_bn_update_eval_momentum(self): # check that update_bn preserves eval mode objects = 100 input_channels = 3 height, width = 5, 5 x = flow.rand(objects, input_channels, height, width) ds_x = flow.utils.data.TensorDataset(x) dl_x = flow.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) dnn = SWATestCNN(input_channels=input_channels) dnn.eval() update_bn(dl_x, dnn) self.assertFalse(dnn.training) # check that momentum is preserved self.assertEqual(dnn.bn.momentum, 0.3) class SWATestDNN(flow.nn.Module): def __init__(self, input_features): super(SWATestDNN, self).__init__() self.n_features = 100 self.fc1 = flow.nn.Linear(input_features, self.n_features) self.bn = flow.nn.BatchNorm1d(self.n_features) def compute_preactivation(self, x): return self.fc1(x) def forward(self, x): x = self.fc1(x) x = self.bn(x) return x class SWATestCNN(flow.nn.Module): def __init__(self, input_channels): super(SWATestCNN, self).__init__() self.n_features = 10 self.conv1 = flow.nn.Conv2d( input_channels, self.n_features, kernel_size=3, padding=1 ) self.bn = flow.nn.BatchNorm2d(self.n_features, momentum=0.3) def compute_preactivation(self, x): return self.conv1(x) def forward(self, x): x = self.conv1(x) x = self.bn(x) return x class SchedulerTestNet(flow.nn.Module): def __init__(self): super(SchedulerTestNet, self).__init__() self.conv1 = flow.nn.Conv2d(1, 1, 1) self.conv2 = flow.nn.Conv2d(1, 1, 1) def forward(self, x): return self.conv2(F.relu(self.conv1(x))) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_sync_and_async_allreduce.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest import numpy as np import os import oneflow as flow import oneflow.unittest def sync_allreduce(x): return x.to_global(sbp=flow.sbp.broadcast) def async_allreduce(x): return flow._C.local_all_reduce(x) @flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestP2bOnGPU(flow.unittest.TestCase): def test_p2b(test_case): placement = flow.placement("cuda", range(4)) sync_x = flow.ones( (128, 1024), placement=placement, dtype=flow.int32, sbp=flow.sbp.partial_sum, ) async_x = flow.ones((128 * 2, 1024), device="cuda", dtype=flow.int32) i = 0 for i in range(500): synced_y = sync_allreduce(sync_x) asynced_y = async_allreduce(async_x) if i % 20 == 0: print(i) print(synced_y.to_local().numpy()) print(asynced_y.numpy()) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_sync_batchnorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import numpy as np import oneflow as flow import oneflow.unittest from sync_batchnorm_test_util import ensure_datas @flow.unittest.skip_unless_1n2d() @unittest.skip("TODO(depeng): data too larger") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestSyncBatchNorm(flow.unittest.TestCase): def test_sync_batchnorm3d(test_case): data_path = ensure_datas() os.environ["ONEFLOW_ENABLE_NHWC"] = "0" channel = 8 input_np = np.load( f"{data_path}/sync_bn3d_nchw_input_rank{flow.env.get_rank()}.npy" ) torch_out = np.load( f"{data_path}/sync_bn3d_nchw_torch_output_rank{flow.env.get_rank()}.npy" ) torch_grad = np.load( f"{data_path}/sync_bn3d_nchw_torch_grad_rank{flow.env.get_rank()}.npy" ) of_input = flow.tensor(input_np, requires_grad=True, device="cuda") of_bn = flow.nn.BatchNorm3d(channel) of_bn = flow.nn.SyncBatchNorm.convert_sync_batchnorm(of_bn).cuda() of_res = of_bn(of_input) of_res.sum().backward() test_case.assertTrue(np.allclose(torch_out, of_res.numpy(), atol=1e-8)) test_case.assertTrue(np.allclose(torch_grad, of_input.grad.numpy(), atol=1e-8,)) def test_sync_batchnorm2d(test_case): data_path = ensure_datas() os.environ["ONEFLOW_ENABLE_NHWC"] = "0" channel = 8 input_np = np.load( f"{data_path}/sync_bn2d_nchw_input_rank{flow.env.get_rank()}.npy" ) torch_out = np.load( f"{data_path}/sync_bn2d_nchw_torch_output_rank{flow.env.get_rank()}.npy" ) torch_grad = np.load( f"{data_path}/sync_bn2d_nchw_torch_grad_rank{flow.env.get_rank()}.npy" ) of_input = flow.tensor(input_np, requires_grad=True, device="cuda") of_bn = flow.nn.BatchNorm2d(channel) of_bn = flow.nn.SyncBatchNorm.convert_sync_batchnorm(of_bn).cuda() of_res = of_bn(of_input) of_res.sum().backward() test_case.assertTrue(np.allclose(torch_out, of_res.numpy(), atol=1e-8)) test_case.assertTrue(np.allclose(torch_grad, of_input.grad.numpy(), atol=1e-8,)) def test_sync_batchnorm1d(test_case): data_path = ensure_datas() 
os.environ["ONEFLOW_ENABLE_NHWC"] = "0" channel = 8 input_np = np.load( f"{data_path}/sync_bn2d_nchw_input_rank{flow.env.get_rank()}.npy" ) torch_out = np.load( f"{data_path}/sync_bn2d_nchw_torch_output_rank{flow.env.get_rank()}.npy" ) torch_grad = np.load( f"{data_path}/sync_bn2d_nchw_torch_grad_rank{flow.env.get_rank()}.npy" ) of_input = flow.tensor(input_np, requires_grad=True, device="cuda") of_bn = flow.nn.BatchNorm1d(channel) of_bn = flow.nn.SyncBatchNorm.convert_sync_batchnorm(of_bn).cuda() of_res = of_bn(of_input) of_res.sum().backward() test_case.assertTrue(np.allclose(torch_out, of_res.numpy(), atol=1e-8)) test_case.assertTrue(np.allclose(torch_grad, of_input.grad.numpy(), atol=1e-8,)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_t.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestTransposeAllDimFunction(flow.unittest.TestCase): @autotest(check_graph=True) def test_t_flow_with_random_data(test_case): device = random_device() x = random_tensor( ndim=constant(2).to(int), dim0=random(0, 64), dim1=random(0, 64) ).to(device) y = torch.t(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_t5_layernorm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import math import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest import torch class TorchT5LayerNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. """ super().__init__() self.weight = torch.nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) res = self.weight * hidden_states return res def _test_t5_layer_norm(test_case, device): torch_t5_layernrom = TorchT5LayerNorm(3) oneflow_t5_layernorm = flow.nn.RMSLayerNorm(3) torch_t5_layernrom.to(device) oneflow_t5_layernorm.to(device) x = np.random.randn(2, 4, 3) torch_x = torch.tensor(x, requires_grad=True, device=torch.device(device)) oneflow_x = flow.tensor(x, requires_grad=True, device=flow.device(device)) torch_y = torch_t5_layernrom(torch_x) oneflow_y = oneflow_t5_layernorm(oneflow_x) test_case.assertTrue( np.allclose( torch_y.detach().cpu().numpy(), oneflow_y.numpy(), rtol=1e-4, atol=1e-4 ) ) torch_y_sum = torch_y.sum() torch_y_sum.backward() oneflow_y_sum = oneflow_y.sum() oneflow_y_sum.backward() test_case.assertTrue( np.allclose( torch_x.grad.cpu().numpy(), oneflow_x.grad.numpy(), rtol=1e-5, atol=1e-5 ) ) @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): def test_t5_layernorm(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_t5_layer_norm] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_tensor_buffer.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type import oneflow as flow import oneflow.unittest def _test_tensor_buffer_convert(test_case, device): input = flow.tensor( np.random.rand(16, 24, 32, 36), dtype=flow.float32, device=flow.device(device) ) tensor_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=2) orig_tensor = flow.tensor_buffer_to_tensor( tensor_buffer, dtype=flow.float32, instance_shape=[32, 36] ) test_case.assertTrue(np.array_equal(input.numpy(), orig_tensor.numpy())) @flow.unittest.skip_unless_1n1d() class TestTensorBufferOps(flow.unittest.TestCase): def test_tensor_buffer_convert(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_tensor_buffer_convert] arg_dict["device"] = ["cpu"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_tensor_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
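# --- Illustrative sketch (not part of the original test file) ---
# The TorchT5LayerNorm comment above describes RMS Layer Normalization: scale-only,
# no mean subtraction, no bias, with the mean of squares accumulated in fp32.
# A minimal NumPy reference of that formula (function and variable names are illustrative):
import numpy as np

def rms_layer_norm(x, weight, eps=1e-6):
    # variance = mean of squares over the last dim, computed in float32
    variance = np.mean(np.square(x.astype(np.float32)), axis=-1, keepdims=True)
    return weight * (x / np.sqrt(variance + eps))

_x = np.random.randn(2, 4, 3).astype(np.float32)
print(rms_layer_norm(_x, np.ones(3, dtype=np.float32)).shape)  # (2, 4, 3)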
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from random import shuffle from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_is_floating_point(test_case, shape, device, dtype): np_input = np.random.rand(*shape) input = flow.tensor(np_input, dtype=dtype, device=device) output = input.is_floating_point() if input.dtype in (flow.float, flow.float16, flow.float32, flow.double): test_case.assertEqual(output, True) else: test_case.assertEqual(output, False) def _test_type_dtype(test_case, shape, device, src_dtype, tgt_dtype): # test tensor.type(x: dtype) rather than tensor.type_dtype np_input = np.random.rand(*shape) input = flow.tensor(np_input, dtype=src_dtype, device=device) input = input.type(tgt_dtype) test_case.assertEqual(input.dtype, tgt_dtype) test_case.assertEqual(input.device, flow.device(device)) def _test_type_str( test_case, tensortype_dict, shape, device, dtype, tgt_tensortype_str ): # test tensor.type(x: str) rather than tensor.type_tensortype np_input = np.random.rand(*shape) input = flow.tensor(np_input, dtype=dtype, device=device) input = input.type(tgt_tensortype_str) tgt_dtype, tgt_device = tensortype_dict[tgt_tensortype_str] test_case.assertEqual(input.dtype, tgt_dtype) test_case.assertEqual(input.device, tgt_device) def _test_type_tensortype( test_case, tensortype_dict, shape, device, dtype, tgt_tensortype ): # test tensor.type(x: tensortype) rather than tensor.type_tensortype np_input = np.random.rand(*shape) input = flow.tensor(np_input, dtype=dtype, device=device) input = input.type(tgt_tensortype) tgt_dtype, tgt_device = tensortype_dict[tgt_tensortype] test_case.assertEqual(input.dtype, tgt_dtype) test_case.assertEqual(input.device, tgt_device) def _test_type_noargs(test_case, shape, device, dtype): # test tensor.type() rather than tensor.type_noargs def generate_tensortype_string(device, dtype): dtype_to_str_dict = { flow.uint8: "ByteTensor", flow.int8: "CharTensor", flow.int32: "IntTensor", flow.int64: "LongTensor", flow.float16: "HalfTensor", flow.bfloat16: "BFloat16Tensor", # Currently unsupport flow.float32: "FloatTensor", flow.float64: "DoubleTensor", } dtype = dtype_to_str_dict[dtype] if device == "cpu": return dtype return ".".join([device, dtype]) np_input = np.random.rand(*shape) input = flow.tensor(np_input, dtype=dtype, device=device) test_case.assertEqual( input.type(), "oneflow." 
+ generate_tensortype_string(device, dtype) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestCuda(flow.unittest.TestCase): @autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-4, check_graph=True) def test_cuda(test_case): device = random_device() x = random_tensor().to(device) x = x.cuda() y = x.sum() return y @autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-4, check_graph=True) def test_cuda_0dim(test_case): device = random_device() x = random_tensor(ndim=0).to(device) x = x.cuda() y = x.sum() return y @autotest(n=5) def test_cuda_int_device(test_case): device = random_device() x = random_tensor().to(device) x = x.cuda(0) y = x.sum() return y @flow.unittest.skip_unless_1n1d() class TestTensorOps(flow.unittest.TestCase): @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_cpu(test_case): device = random_device() x = random_tensor().to(device) x = x.cpu() y = x.sum() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_long(test_case): device = random_device() x = random_tensor().to(device) y = x.long() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_long_0dim(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = x.long() return y @autotest(n=5, auto_backward=False) def test_long_with_non_contiguous_input(test_case): device = random_device() permute_list = list(range(4)) shuffle(permute_list) input = random_tensor(ndim=4).to(device) x = input.permute(permute_list) y = x.long() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_int(test_case): device = random_device() x = random_tensor().to(device) y = x.int() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_int_0dim(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = x.int() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_half(test_case): device = random_device() x = random_tensor(dtype=int).to(device) y = x.half() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_half_0dim(test_case): device = random_device() x = random_tensor(ndim=0, dtype=int).to(device) y = x.half() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_float(test_case): device = random_device() x = random_tensor(dtype=int).to(device) y = x.float() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_float_0dim(test_case): device = random_device() x = random_tensor(ndim=0, dtype=int).to(device) y = x.float() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_double(test_case): device = random_device() x = random_tensor(dtype=int).to(device) y = x.double() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_double_0dim(test_case): device = random_device() x = random_tensor(ndim=0, dtype=int).to(device) y = x.double() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_bool(test_case): device = random_device() x = random_tensor().to(device) y = x.bool() return y @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_bool_0dim(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = 
x.bool() return y @autotest(n=5, auto_backward=False) def test_bool_with_non_contiguous_input(test_case): device = random_device() permute_list = list(range(4)) shuffle(permute_list) input = random_tensor(ndim=4).to(device) x = input.permute(permute_list) y = x.bool() return y # Not check graph because of 2 reason. # Reason 1, nn.Graph.build()'s input/output item only support types: Tensor/None. # Reason 2, This op needs to convert the EagerTensor to a numpy array,so this op only supports eager mode. # Please refer to File "oneflow/api/python/utils/tensor_utils.h", line 49, in EagerTensorToNumpy. @autotest( n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph="ValidatedFalse" ) def test_item(test_case): device = random_device() x = random_tensor(ndim=1, dim0=1, dtype=int).to(device) y = torch.tensor(x.item()) return y # Not check graph because of 2 reason. # Reason 1, nn.Graph.build()'s input/output item only support types: Tensor/None. # Reason 2, This op needs to convert the EagerTensor to a numpy array,so this op only supports eager mode. # Please refer to File "oneflow/api/python/utils/tensor_utils.h", line 49, in EagerTensorToNumpy. @autotest( n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph="ValidatedFalse" ) def test_item_0dim(test_case): device = random_device() x = random_tensor(ndim=0, dtype=int).to(device) y = torch.tensor(x.item()) return y # Not check graph because of 2 reasons # Reason 1, nn.Graph.build()'s input/output item only support types: Tensor/None. # Reason 2, This op needs to convert the EagerTensor to a numpy array,so this op only supports eager mode. # Please refer to File "oneflow/api/python/utils/tensor_utils.h", line 49, in EagerTensorToNumpy. @autotest( n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph="ValidatedFalse" ) def test_tolist(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.tensor(x.tolist()) return y # Not check graph because of 2 reasons # Reason 1, nn.Graph.build()'s input/output item only support types: Tensor/None. # Reason 2, This op needs to convert the EagerTensor to a numpy array,so this op only supports eager mode. # Please refer to File "oneflow/api/python/utils/tensor_utils.h", line 49, in EagerTensorToNumpy. 
@autotest( n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph="ValidatedFalse" ) def test_tolist_0dim(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y = torch.tensor(x.tolist()) return y @autotest() def test_type_as(test_case): input = random_tensor().to(random_device()) target = random_tensor().to(random_device()) input = input.type_as(target) return input def test_is_floating_point(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = [ flow.uint8, flow.int8, flow.int32, flow.int64, flow.float32, flow.float64, flow.double, flow.float, flow.int, ] for arg in GenArgList(arg_dict): _test_is_floating_point(test_case, *arg) def test_type_dtype(test_case): # test tensor.type(x.dtype) rather than tensor.type_dtype arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["src_dtype"] = [ flow.uint8, flow.int8, flow.int64, flow.int32, flow.float16, flow.float32, flow.float64, ] arg_dict["tgt_dtype"] = arg_dict["src_dtype"] for arg in GenArgList(arg_dict): _test_type_dtype(test_case, *arg) def test_type_tensortype_str_cpu(test_case): # test tensor.type(x: str) rather than tensor.type_tensortype arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["src_dtype"] = [ flow.uint8, flow.int8, flow.int64, flow.int32, flow.float16, flow.float32, flow.float64, ] tensortype_dict = { "oneflow.CharTensor": [flow.char, flow.device("cpu")], "oneflow.ByteTensor": [flow.uint8, flow.device("cpu")], "oneflow.IntTensor": [flow.int32, flow.device("cpu")], "oneflow.LongTensor": [flow.int64, flow.device("cpu")], "oneflow.HalfTensor": [flow.float16, flow.device("cpu")], "oneflow.FloatTensor": [flow.float32, flow.device("cpu")], "oneflow.DoubleTensor": [flow.float64, flow.device("cpu")], } arg_dict["tgt_tensortype_str"] = list(tensortype_dict.keys()) for arg in GenArgList(arg_dict): _test_type_str(test_case, tensortype_dict, *arg) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_type_tensortype_str(test_case): # test tensor.type(x: str) rather than tensor.type_tensortype arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["src_dtype"] = [ flow.uint8, flow.char, flow.int64, flow.int32, flow.float16, flow.float32, flow.float64, ] tensortype_dict = { "oneflow.CharTensor": [flow.char, flow.device("cpu")], "oneflow.ByteTensor": [flow.uint8, flow.device("cpu")], "oneflow.IntTensor": [flow.int32, flow.device("cpu")], "oneflow.LongTensor": [flow.int64, flow.device("cpu")], "oneflow.HalfTensor": [flow.float16, flow.device("cpu")], "oneflow.FloatTensor": [flow.float32, flow.device("cpu")], "oneflow.DoubleTensor": [flow.float64, flow.device("cpu")], "oneflow.cuda.CharTensor": [flow.char, flow.device("cuda")], "oneflow.cuda.ByteTensor": [flow.uint8, flow.device("cuda")], "oneflow.cuda.IntTensor": [flow.int32, flow.device("cuda")], "oneflow.cuda.LongTensor": [flow.int64, flow.device("cuda")], "oneflow.cuda.HalfTensor": [flow.float16, flow.device("cuda")], "oneflow.cuda.FloatTensor": [flow.float32, flow.device("cuda")], "oneflow.cuda.DoubleTensor": [flow.float64, flow.device("cuda")], } arg_dict["tgt_tensortype_str"] = list(tensortype_dict.keys()) for arg in GenArgList(arg_dict): _test_type_str(test_case, tensortype_dict, *arg) def 
test_type_tensortype_cpu(test_case): # test tensor.type(x: tensortype) rather than tensor.type_tensortype arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["src_dtype"] = [ flow.uint8, flow.int8, flow.int64, flow.int32, flow.float16, flow.float32, flow.float64, ] tensortype_dict = { flow.CharTensor: [flow.int8, flow.device("cpu")], flow.ByteTensor: [flow.uint8, flow.device("cpu")], flow.IntTensor: [flow.int32, flow.device("cpu")], flow.LongTensor: [flow.int64, flow.device("cpu")], flow.HalfTensor: [flow.float16, flow.device("cpu")], flow.FloatTensor: [flow.float32, flow.device("cpu")], flow.DoubleTensor: [flow.float64, flow.device("cpu")], } arg_dict["tgt_tensortype"] = list(tensortype_dict.keys()) for arg in GenArgList(arg_dict): _test_type_tensortype(test_case, tensortype_dict, *arg) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_type_tensortype(test_case): # test tensor.type(x: tensortype) rather than tensor.type_tensortype arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["src_dtype"] = [ flow.uint8, flow.int8, flow.int64, flow.int32, flow.float16, flow.float32, flow.float64, ] tensortype_dict = { flow.CharTensor: [flow.int8, flow.device("cpu")], flow.ByteTensor: [flow.uint8, flow.device("cpu")], flow.IntTensor: [flow.int32, flow.device("cpu")], flow.LongTensor: [flow.int64, flow.device("cpu")], flow.HalfTensor: [flow.float16, flow.device("cpu")], flow.Tensor: [flow.float32, flow.device("cpu")], flow.FloatTensor: [flow.float32, flow.device("cpu")], flow.DoubleTensor: [flow.float64, flow.device("cpu")], flow.cuda.CharTensor: [flow.int8, flow.device("cuda")], flow.cuda.ByteTensor: [flow.uint8, flow.device("cuda")], flow.cuda.IntTensor: [flow.int32, flow.device("cuda")], flow.cuda.LongTensor: [flow.int64, flow.device("cuda")], flow.cuda.HalfTensor: [flow.float16, flow.device("cuda")], flow.cuda.FloatTensor: [flow.float32, flow.device("cuda"),], flow.cuda.DoubleTensor: [flow.float64, flow.device("cuda"),], } arg_dict["tgt_tensortype"] = list(tensortype_dict.keys()) for arg in GenArgList(arg_dict): _test_type_tensortype(test_case, tensortype_dict, *arg) def test_type_noargs(test_case): # test tensor.type() rather than tensor.type_noargs arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dtype"] = [ flow.uint8, flow.int8, flow.int64, flow.int32, flow.float16, flow.float32, flow.float64, ] for arg in GenArgList(arg_dict): _test_type_noargs(test_case, *arg) @autotest(n=3, auto_backward=False) def test_bincount(test_case): device = random_device() len = random(1, 100) input = random_tensor(1, len, dtype=int, low=0).to(device) weight = random_tensor(1, len, dtype=float).to(device) min_length = random(1, 100) | nothing() return ( input.bincount(minlength=min_length), input.bincount(weight, minlength=min_length), ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_tensor_scatter_nd_update.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
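# --- Illustrative sketch (not part of the original test file) ---
# Quick recap of the tensor.type() variants exercised in TestTensorOps above: with no
# argument it returns the type string, while a dtype, a type string, or a tensortype
# argument casts the tensor. The values shown mirror the assertions in those tests.
import oneflow as flow

t = flow.rand(2, 3)                              # float32 on cpu
print(t.type())                                  # "oneflow.FloatTensor"
print(t.type(flow.int32).dtype)                  # oneflow.int32, device unchanged
print(t.type("oneflow.DoubleTensor").dtype)      # oneflow.float64 on cpu
print(t.type(flow.LongTensor).dtype)             # oneflow.int64 on cpu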
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_tensor_scatter_nd_update(test_case, device): origin = flow.tensor(np.arange(8), dtype=flow.float, device=flow.device(device)) indices = flow.tensor( np.array([[1], [6], [4]]), dtype=flow.int, device=flow.device(device) ) update = flow.tensor( np.array([10.2, 5.1, 12.7]), dtype=flow.float, device=flow.device(device) ) np_out = np.array([0.0, 10.2, 2.0, 3.0, 12.7, 5.0, 5.1, 7.0]) output = flow.tensor_scatter_nd_update(origin, indices, update) test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001)) def _test_tensor_scatter_nd_update_with_non_contiguous_input(test_case, device): # non-contiguous tensor with shape (2, 3, 4) origin = flow.tensor( np.ones((4, 3, 2)), dtype=flow.float, device=flow.device(device) ).permute(2, 1, 0) # indices with shape (3, 2) indices = flow.tensor( np.array([[0, 0], [1, 0], [1, 1]]), dtype=flow.int, device=flow.device(device) ) # non-contiguous update with shape (3, 4) update = flow.tensor( np.zeros((4, 3)), dtype=flow.float, device=flow.device(device) ).T output = flow.tensor_scatter_nd_update(origin, indices, update) np_res = np.ones((2, 3, 4)) np_res[0, 0] = 0 np_res[1, 0] = 0 np_res[1, 1] = 0 test_case.assertTrue(np.array_equal(output.numpy(), np_res)) def _test_tensor_scatter_nd_update_t(test_case, device): origin = flow.tensor( np.arange(15).reshape(5, 3), dtype=flow.float, device=flow.device(device) ) indices = flow.tensor( np.array([[0], [4], [2]]), dtype=flow.int, device=flow.device(device) ) update = flow.tensor( np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]), dtype=flow.float, device=flow.device(device), ) np_out = np.array( [ [1.0, 1.0, 1.0], [3.0, 4.0, 5.0], [3.0, 3.0, 3.0], [9.0, 10.0, 11.0], [2.0, 2.0, 2.0], ] ) output = flow.tensor_scatter_nd_update(origin, indices, update) test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001)) def _test_tensor_scatter_nd_update_backward(test_case, device): origin = flow.tensor( np.arange(8), dtype=flow.float, device=flow.device(device), requires_grad=True, ) indices = flow.tensor( np.array([[1], [6], [4]]), dtype=flow.int, device=flow.device(device) ) of_update = flow.tensor( np.array([10.2, 5.1, 12.7]), requires_grad=True, dtype=flow.float, device=flow.device(device), ) np_out = np.array([0.0, 10.2, 2.0, 3.0, 12.7, 5.0, 5.1, 7.0]) np_update_grad = np.array([1.0, 1.0, 1.0]) np_origin_grad = np.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0]) output = flow.tensor_scatter_nd_update(origin, indices, of_update) out_sum = output.sum() out_sum.backward() test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001)) test_case.assertTrue(np.allclose(of_update.grad.numpy(), np_update_grad)) test_case.assertTrue(np.allclose(origin.grad.numpy(), np_origin_grad)) @flow.unittest.skip_unless_1n1d() class TestTensorScatterNdUpdate(flow.unittest.TestCase): def test_tensor_scatter_nd_update(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_tensor_scatter_nd_update, 
_test_tensor_scatter_nd_update_with_non_contiguous_input, _test_tensor_scatter_nd_update_t, _test_tensor_scatter_nd_update_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_tensor_split.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from random import shuffle from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestTensorSplitVec(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_tensor_split_vec(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) dim = random(-3, 3).to(int) z = torch.tensor_split(x, (1, 2), dim) return z[0] @autotest(n=5) def test_flow_tensor_split_vec_with_stride(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) dim = random(-3, 3).to(int) perm = [0, 1, 2, 3] shuffle(perm) y = x.permute(perm) z = torch.tensor_split(y, (1, 2), dim) return z[0] @flow.unittest.skip_unless_1n1d() class TestTensorSplitInt(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_tensor_split_int(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) split = random(1, 3).to(int) dim = random(-3, 3).to(int) z = torch.tensor_split(x, split, dim) return z[0] if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_tensor_to.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
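# --- Illustrative sketch (not part of the original test file) ---
# Semantics of flow.tensor_scatter_nd_update as checked above: the result is a copy of
# `origin` whose entries selected by `indices` are overwritten by `update`; in backward,
# update.grad is all ones and origin.grad is zero exactly at the overwritten positions
# (see _test_tensor_scatter_nd_update_backward).
import numpy as np
import oneflow as flow

origin = flow.tensor(np.arange(8), dtype=flow.float)
indices = flow.tensor([[1], [6], [4]], dtype=flow.int)
update = flow.tensor([10.2, 5.1, 12.7], dtype=flow.float)
print(flow.tensor_scatter_nd_update(origin, indices, update).numpy())
# -> [ 0.  10.2  2.   3.  12.7  5.   5.1  7. ]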
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n2d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class Test2DeviceGlobalTensorTo(flow.unittest.TestCase): def test_asymmetric_global_tensor_clone(test_case): placement = flow.placement("cuda", range(1)) x = flow.ones((4,), placement=placement, sbp=flow.sbp.broadcast) cloned = x.detach().clone() test_case.assertEqual(x.placement, cloned.placement) test_case.assertEqual(x.sbp, cloned.sbp) if flow.env.get_rank() == 0: cloned_local = cloned.to_local() cloned_local[0] = 0 test_case.assertEqual(cloned_local[0].numpy().item(), 0) test_case.assertEqual(x.to_local()[0].numpy().item(), 1) def test_global_tensor_clone(test_case): placement = flow.placement("cuda", range(2)) x = flow.ones((4,), placement=placement, sbp=flow.sbp.broadcast) cloned = x.detach().clone() test_case.assertEqual(x.placement, cloned.placement) test_case.assertEqual(x.sbp, cloned.sbp) cloned_local = cloned.to_local() cloned_local[0] = 0 test_case.assertEqual(cloned_local[0].numpy().item(), 0) test_case.assertEqual(x.to_local()[0].numpy().item(), 1) def test_global_tensor_to(test_case): placement = flow.placement("cuda", range(2)) x = flow.ones((4,), placement=placement, sbp=flow.sbp.broadcast) cloned = x.to(copy=True) test_case.assertEqual(x.placement, cloned.placement) test_case.assertEqual(x.sbp, cloned.sbp) cloned_local = cloned.to_local() cloned_local[0] = 0 test_case.assertEqual(cloned_local[0].numpy().item(), 0) test_case.assertEqual(x.to_local()[0].numpy().item(), 1) def test_tensor_to_h2d1(test_case): input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.int64) output = input.to(device=flow.device("cuda:1"), dtype=flow.int32) test_case.assertEqual(output.device, flow.device("cuda:1")) test_case.assertEqual(output.dtype, flow.int32) test_case.assertTrue( np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001) ) @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestTo(flow.unittest.TestCase): def test_global_tensor_clone(test_case): x = flow.ones( (4,), placement=flow.placement("cuda", ranks=[0]), sbp=flow.sbp.broadcast ) cloned = x.detach().clone() test_case.assertEqual(x.placement, cloned.placement) test_case.assertEqual(x.sbp, cloned.sbp) cloned_local = cloned.to_local() cloned_local[0] = 0 test_case.assertEqual(cloned_local[0].numpy().item(), 0) test_case.assertEqual(x.to_local()[0].numpy().item(), 1) def test_global_tensor_to(test_case): x = flow.ones( (4,), placement=flow.placement("cuda", ranks=[0]), sbp=flow.sbp.broadcast ) cloned = x.to(copy=True) test_case.assertEqual(x.placement, cloned.placement) test_case.assertEqual(x.sbp, cloned.sbp) cloned_local = cloned.to_local() cloned_local[0] = 0 test_case.assertEqual(cloned_local[0].numpy().item(), 0) test_case.assertEqual(x.to_local()[0].numpy().item(), 1) def test_empty_global_tensor_to(test_case): x = flow.ones( (0,), placement=flow.placement("cuda", ranks=[0]), sbp=flow.sbp.broadcast ) cloned = x.to(copy=True) test_case.assertEqual(x.placement, cloned.placement) test_case.assertEqual(x.sbp, cloned.sbp) cloned_local = cloned.to_local() test_case.assertEqual(tuple(cloned.shape), (0,)) test_case.assertEqual(tuple(cloned_local.shape), (0,)) def test_tensor_to_h2d(test_case): input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32) output = 
input.to(device=flow.device("cuda")) test_case.assertEqual(output.device, flow.device("cuda")) test_case.assertTrue( np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001) ) gpu_output = output.to(device=flow.device("cuda")) test_case.assertEqual(gpu_output.device, flow.device("cuda")) test_case.assertTrue( np.allclose(input.numpy(), gpu_output.numpy(), rtol=0.0001, atol=0.0001) ) def test_tensor_to_d2h(test_case): input = flow.tensor( np.random.randn(2, 3, 4, 5), dtype=flow.float32, device=flow.device("cuda") ) output = input.to(device=flow.device("cpu")) test_case.assertEqual(output.device, flow.device("cpu")) test_case.assertTrue( np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001) ) def test_tensor_to_d2d(test_case): input = flow.tensor( np.random.randn(2, 3, 4, 5), dtype=flow.float32, device=flow.device("cuda") ) output = input.to(device=flow.device("cuda:0")) test_case.assertEqual(output.device, flow.device("cuda:0")) test_case.assertTrue( np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001) ) def test_tensor_to_h2h(test_case): input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32) output = input.to(device=flow.device("cpu")) test_case.assertEqual(output.device, flow.device("cpu")) test_case.assertTrue( np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001) ) def test_tensor_to_cast(test_case): input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32) output = input.to(dtype=flow.int) test_case.assertEqual(output.dtype, flow.int) def test_tensor_to_cast_h2d(test_case): input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32) output = input.to(device=flow.device("cuda"), dtype=flow.int) test_case.assertEqual(output.dtype, flow.int) test_case.assertEqual(output.device, flow.device("cuda")) def test_tensor_using_tensor(test_case): tensor = flow.tensor(np.random.randn(2, 3, 4, 5), device="cuda", dtype=flow.int) input = flow.tensor(np.random.randn(2, 3)) output = input.to(tensor) test_case.assertEqual(output.dtype, flow.int) test_case.assertEqual(output.device, flow.device("cuda")) @autotest(n=5, check_graph=True) def test_int_to_args(test_case): device_num = random(0, 2).to(int).value() x = random_tensor(ndim=4).to(device_num) return x @autotest(n=5, check_graph=True) def test_int_to_kwargs(test_case): device_num = random(0, 2).to(int).value() x = random_tensor(ndim=4).to(device=device_num) return x if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_tensordot.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
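# --- Illustrative sketch (not part of the original test file) ---
# Compact recap of the Tensor.to() forms covered in TestTo above (eager mode, cpu only here):
# to(dtype=...) casts, to(device=...) moves, to(other_tensor) adopts the target's dtype and
# device, and to(copy=True) on a global tensor returns an independent copy.
import numpy as np
import oneflow as flow

x = flow.tensor(np.random.randn(2, 3), dtype=flow.float32)
print(x.to(dtype=flow.int).dtype)                # oneflow.int32
print(x.to(device=flow.device("cpu")).device)    # stays on cpu
ref = flow.tensor(np.random.randn(2, 3), dtype=flow.float64)
print(x.to(ref).dtype)                           # follows the target tensor: oneflow.float64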
""" from collections import OrderedDict import unittest import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestTensordot(flow.unittest.TestCase): @autotest(n=5, rtol=1e-2, atol=1e-3) def test_tensordot_intdim(test_case): device = random_device() dims = random() dims_list = [random().to(int).value() for i in range(dims.to(int).value() + 3)] x = random_tensor( ndim=3, dim0=dims_list[0], dim1=dims_list[1], dim2=dims_list[2], ).to(device) y = random_tensor( ndim=3, dim0=dims_list[0 + dims.to(int).value()], dim1=dims_list[1 + dims.to(int).value()], dim2=dims_list[2 + dims.to(int).value()], ).to(device) z = torch.tensordot(x, y, dims=3 - dims.to(int).value()) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_tensordot_list_dim(test_case): device = random_device() x = random_tensor(4, 1, 3, 2, 5).to(device) y = random_tensor(4, 4, 2, 3, 5).to(device) z = torch.tensordot(x, y, dims=[[1, 2, 0], [2, 1, 0]]) return z @autotest(n=5, rtol=1e-2, atol=1e-2) def test_tensordot_tuple_dim(test_case): device = random_device() x = random_tensor(4, 1, 3, 2, 5).to(device) y = random_tensor(4, 4, 2, 3, 5).to(device) z = torch.tensordot(x, y, dims=([1, 2, 0], [2, 1, 0])) return z @autotest(n=5, rtol=1e-2, atol=1e-3) def test_tensordot_list_neg_dim(test_case): device = random_device() x = random_tensor(4, 1, 3, 2, 5).to(device) y = random_tensor(4, 4, 2, 3, 5).to(device) z = torch.tensordot(x, y, dims=[[-3, -2, -4], [-2, -3, -4]]) return z @autotest(check_graph=False, rtol=1e-2, atol=1e-3) def test_tensordot_backward(test_case): device = random_device() x = random_tensor(3, 3, 4, 5).to(device) y = random_tensor(2, 4, 5).to(device) z = torch.tensordot(x, y, dims=[[1, 2], [0, 1]]) z.sum().backward() @autotest(check_graph=False) def test_tensordot_tensor_dim(test_case): def _test_tensor_dim(test_case, device): np_dim = np.array([[1, 2, 3], [1, 2, 3]], dtype=int) flow_dim = flow.tensor(np_dim).to(device) torch_dim = torch.tensor(np_dim).to(device) np_random_array = np.random.randn(2, 3, 4, 5) flow_tensor = flow.tensor(np_random_array).to(device) torch_tensor = torch.tensor(np_random_array).to(device) flow_result = flow.tensordot(flow_tensor, flow_tensor, dims=flow_dim) torch_result = torch.tensordot(torch_tensor, torch_tensor, dims=torch_dim) test_case.assertTrue( np.allclose( flow_result.numpy(), torch_result.cpu().numpy(), rtol=0.0001, atol=0.0001, ) ) arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_tensor_dim(test_case, arg[0]) @autotest(n=5, check_graph=False, rtol=1e-2, atol=1e-2) def test_tensordot_single_item_tensor_dim(test_case): device = random_device() dims = random_tensor(1, dim0=1, low=0, high=4, dtype=int).to(device) x = random_tensor(3, dim0=4, dim1=4, dim2=4).to(device) y = random_tensor(3, dim0=4, dim1=4, dim2=4).to(device) z = torch.tensordot(x, y, dims=dims) return z @autotest(n=5, rtol=1e-3, atol=1e-4) def test_tensordot_broadcast(test_case): device = random_device() x = random_tensor(4, 1, 1, 1, 1).to(device) y = random_tensor(4, 2, 3, 4, 5).to(device) z = torch.tensordot(x, y, dims=random(high=5).to(int).value()) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_tile.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestTile(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_tile_with_random_data(test_case): x = random_tensor(ndim=2, dim0=1, dim1=2) reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) z = torch.tile(x, reps) return z @autotest(check_graph=True) def test_flow_tensor_tile_with_random_data(test_case): x = random_tensor(ndim=2, dim0=1, dim1=2) reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) y = x.tile(reps) return y @autotest(auto_backward=False, check_graph=True) def test_flow_tile_bool_with_random_data(test_case): x = random_tensor(ndim=2, dim0=1, dim1=2).to(torch.bool) reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) z = torch.tile(x, reps) return z @autotest(check_graph=True) def test_flow_tile_with_0dim_data(test_case): x = random_tensor(ndim=0) reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) z = torch.tile(x, reps) return z if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_to_torch.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import os import oneflow as flow import oneflow.unittest import torch @flow.unittest.skip_unless_1n1d() class TestToTroch(flow.unittest.TestCase): # NOTE: oneflow and torch cpu tensor shared the same memory, refer to File "python/oneflow/test/modules/test_from_torch.py", line 49, in test_from_torch_cpu. def test_to_torch_cpu(test_case): flow_t = flow.rand(5, 3, 3) numpy_from_flow = flow_t.numpy() torch_t = flow.utils.tensor.to_torch(flow_t) test_case.assertEqual( torch_t.data_ptr(), numpy_from_flow.__array_interface__["data"][0] ) numpy_from_flow[0][0] = [1, 2, 3] test_case.assertTrue( np.allclose(torch_t.numpy(), numpy_from_flow, rtol=0.001, atol=0.001) ) test_case.assertTrue( np.allclose(flow_t.numpy(), torch_t.numpy(), rtol=0.001, atol=0.001) ) test_case.assertEqual(flow_t.numpy().dtype, torch_t.numpy().dtype) # NOTE: For the case of 0 size tensor, no memory addresses are compared. # Because the address of 0 size tensor is random at this time. 
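# --- Illustrative sketch (not part of the original test file) ---
# The NOTE above points out that for CPU tensors flow.utils.tensor.to_torch() shares
# memory with the source: the torch tensor, the oneflow tensor, and the numpy view all
# refer to the same buffer, mirroring the checks in test_to_torch_cpu.
import oneflow as flow

flow_t = flow.ones(2, 2)
torch_t = flow.utils.tensor.to_torch(flow_t)
view = flow_t.numpy()            # numpy view over the same cpu buffer
view[0, 0] = 5.0
print(torch_t[0, 0].item())      # 5.0 -- the write is visible through the torch handle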
def test_to_torch_cpu_with_0_size_data(test_case): flow_t = flow.rand(5, 3, 0) torch_t = flow.utils.tensor.to_torch(flow_t) test_case.assertTrue( np.allclose(flow_t.numpy(), torch_t.numpy(), rtol=0.001, atol=0.001) ) test_case.assertEqual(flow_t.numpy().dtype, torch_t.numpy().dtype) def test_to_torch_cpu_with_0dim_data(test_case): flow_t = flow.tensor(5) numpy_from_flow = flow_t.numpy() torch_t = flow.utils.tensor.to_torch(flow_t) test_case.assertEqual( torch_t.data_ptr(), numpy_from_flow.__array_interface__["data"][0] ) test_case.assertTrue( np.allclose(flow_t.numpy(), torch_t.numpy(), rtol=0.001, atol=0.001) ) test_case.assertEqual(flow_t.numpy().dtype, torch_t.numpy().dtype) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_to_torch_gpu(test_case): flow_t = flow.rand(5, 3, 3).to("cuda") torch_t = flow.utils.tensor.to_torch(flow_t) flow_t[0][0] = flow.tensor([1, 2, 3]).to(flow.float32) # NOTE: OneFlow operations are asynchoronously executed, # so we need to synchronize explicitly here. flow._oneflow_internal.eager.Sync() test_case.assertTrue(np.array_equal(torch_t.cpu().numpy(), flow_t.numpy())) test_case.assertEqual(flow_t.numpy().dtype, torch_t.cpu().numpy().dtype) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_to_torch_global(test_case): flow_t = flow.rand(5, 3, 3).to_global( placement=flow.placement.all("cuda"), sbp=flow.sbp.broadcast ) torch_t = flow.utils.tensor.to_torch(flow_t) test_case.assertEqual(flow_t.numpy().dtype, torch_t.cpu().numpy().dtype) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_topk.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import torch import oneflow.unittest def _test_top_k(test_case, shape, k, dim, device): if k >= shape[dim]: return x_np = np.random.randn(*shape) x_of = flow.tensor(x_np, device=device) of_out = flow.topk(x_of, k=k, dim=dim) x_pt = torch.tensor(x_np, device=device) pt_out = torch.topk(x_pt, k=k, dim=dim) test_case.assertTrue( np.array_equal(of_out.values.cpu().numpy(), pt_out.values.cpu().numpy()) ) test_case.assertTrue( np.array_equal(of_out.indices.cpu().numpy(), pt_out.indices.cpu().numpy()) ) @flow.unittest.skip_unless_1n1d() class TestTopK(flow.unittest.TestCase): def test_in_top_k(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(1, 16), (1, 1024), (8, 8), (8, 256)] arg_dict["k"] = [1, 4, 64] arg_dict["dim"] = [0, 1] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_top_k(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_transpose.py ================================================ """ Copyright 2020 The OneFlow Authors. 
All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from cgi import test import unittest from collections import OrderedDict import numpy as np from random import shuffle from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_transpose(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = flow.transpose(input, 0, 1) np_out = input.numpy().transpose((1, 0, 2, 3)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_tensor_transpose(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = input.transpose(0, 1) np_out = input.numpy().transpose((1, 0, 2, 3)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_tranpose_negative_dim(test_case, device): input = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device) ) of_out = flow.transpose(input, -4, -3) np_out = input.numpy().transpose((1, 0, 2, 3)) test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten())) def _test_transpose_backward(test_case, device): x = flow.tensor( np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.transpose(x, 0, 1).sum() y.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones((2, 6, 5, 3)), 1e-05, 1e-05) ) def _test_transpose_backward_v2(test_case, device): x = flow.tensor( np.random.randn(2, 3, 4, 5), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.transpose(x, 3, 1).sum() y.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones((2, 3, 4, 5)), 1e-05, 1e-05) ) @flow.unittest.skip_unless_1n1d() class TestTranspose(flow.unittest.TestCase): def test_transpose(test_case): arg_dict = OrderedDict() arg_dict["fun"] = [ _test_transpose, _test_tensor_transpose, _test_tranpose_negative_dim, _test_transpose_backward, _test_transpose_backward_v2, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=10, check_graph=True) def test_transpose_flow_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) return y @autotest(n=10, check_graph=True) def test_transpose_with_stride(test_case): device = random_device() x = random_tensor(ndim=4).to(device) permute_list = [0, 1, 2, 3] shuffle(permute_list) x = x.permute(permute_list) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) return y @autotest(n=10, auto_backward=False, check_graph=True) def test_transpose_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 3, 0, 4).to(device) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) return y @autotest(n=10, 
auto_backward=False, check_graph=True) def test_transpose_flow_bool_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device=device, dtype=torch.bool) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_tril.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestTril(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_tril_without_diag(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(1, 5).to(int), dim1=random(1, 5).to(int), dim2=random(1, 5).to(int), dim3=random(1, 5).to(int), ).to(device) y = torch.tril(x) y = torch.exp(y) return y @autotest(n=5, check_graph=True) def test_tril_with_diag(test_case): device = random_device() diagonal = random(-3, 3).to(int) x = random_tensor( ndim=4, dim0=random(1, 5).to(int), dim1=random(1, 5).to(int), dim2=random(1, 5).to(int), dim3=random(1, 5).to(int), ).to(device) y = torch.tril(x, diagonal) y = torch.exp(y) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_triu.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
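# --- Illustrative sketch (not part of the original test file) ---
# The `diagonal` argument in the tril/triu tests above selects which diagonal bounds the
# kept triangle: 0 is the main diagonal, positive values shift it upward, negative values
# shift it downward.
import numpy as np
import oneflow as flow

m = flow.tensor(np.arange(9).reshape(3, 3), dtype=flow.float32)
print(flow.tril(m, diagonal=0).numpy())   # keep the main diagonal and below
print(flow.triu(m, diagonal=1).numpy())   # keep strictly above the main diagonal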
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.nn as nn import oneflow.unittest from oneflow.test_utils.automated_test_util import * def _test_triu(test_case, diagonal, device, dtype): arr_shape = (4, 4, 8) flow_dtype, np_dtype = dtype np_arr = np.random.randn(*arr_shape).astype(np_dtype) input_tensor = flow.tensor( np_arr, dtype=flow_dtype, device=flow.device(device), requires_grad=True ) output = flow.triu(input_tensor, diagonal=diagonal) np_out = np.triu(np_arr, diagonal) test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() np_grad = np.triu(np.ones(shape=arr_shape, dtype=np_dtype), diagonal) test_case.assertTrue(np.allclose(input_tensor.grad.numpy(), np_grad, 1e-06, 1e-06)) def _test_triu_(test_case, diagonal, device, dtype): arr_shape = (4, 4, 8) flow_dtype, np_dtype = dtype np_arr = np.random.randn(*arr_shape).astype(np_dtype) input = flow.tensor(np_arr, dtype=flow_dtype, device=flow.device(device)) np_out = np.triu(np_arr, diagonal) test_case.assertFalse(np.allclose(input.numpy(), np_out)) input.triu_(diagonal=diagonal) test_case.assertTrue(np.allclose(input.numpy(), np_out)) @flow.unittest.skip_unless_1n1d() class TestTriu(flow.unittest.TestCase): def test_triu(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_triu, _test_triu_] arg_dict["diagonal"] = [2, -1] arg_dict["device"] = ["cuda", "cpu"] arg_dict["dtype"] = [(flow.float32, np.float32), (flow.float16, np.float16)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest() def test_triu_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = torch.triu(x) return y @autotest() def test_triu_with_0_size_data_fp16(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device, torch.float16) y = torch.triu(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_trunc.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestTrunc(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_trunc(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(1, 5).to(int), dim1=random(1, 5).to(int), dim2=random(1, 5).to(int), dim3=random(1, 5).to(int), ).to(device) y = torch.trunc(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_trunc_divide.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np from oneflow.test_utils.automated_test_util import * import oneflow as flow import torch as torch_original import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestTruncDivide(flow.unittest.TestCase): @autotest(n=5, check_allclose=False, check_graph=True) def test_elementwise_trunc_divide_random_data(test_case): device = random_device() dim0 = random(1, 8) dim1 = random(1, 8) dim2 = random(1, 8) dim3 = random(1, 8) x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=dim3).to(device) y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=dim3).to(device) x.oneflow = x.oneflow.detach().requires_grad_() x.pytorch = x.pytorch.detach().requires_grad_() y.oneflow = y.oneflow.detach().requires_grad_() y.pytorch = y.pytorch.detach().requires_grad_() oneflow_out = flow._C.trunc_divide(x.oneflow, y.oneflow) torch_out = torch_original.div(x.pytorch, y.pytorch, rounding_mode="trunc") test_case.assertTrue( np.allclose( oneflow_out.detach().cpu().numpy(), torch_out.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) oneflow_out.sum().backward() torch_out.sum().backward() test_case.assertTrue( np.allclose( x.oneflow.grad.detach().cpu().numpy(), x.pytorch.grad.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) test_case.assertTrue( np.allclose( y.oneflow.grad.detach().cpu().numpy(), y.pytorch.grad.detach().cpu().numpy(), rtol=0.0001, atol=1e-05, ) ) @autotest(n=5, check_allclose=False, check_graph=True) def test_tensor_truncdiv_scalar_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(1, 8), dim1=random(1, 8), dim2=random(1, 8), dim3=random(1, 8), ).to(device) x.oneflow = x.oneflow.detach().requires_grad_() x.pytorch = x.pytorch.detach().requires_grad_() scalar = random().to(float).value() oneflow_out = oneflow._C.trunc_divide(x.oneflow, scalar) torch_out = torch_original.div(x.pytorch, scalar, rounding_mode="trunc") test_case.assertTrue( np.allclose( oneflow_out.detach().cpu().numpy(), torch_out.detach().cpu().numpy(), rtol=0.0001, atol=1e-5, ) ) oneflow_out.sum().backward() torch_out.sum().backward() test_case.assertTrue( np.allclose( x.oneflow.grad.detach().cpu().numpy(), x.pytorch.grad.detach().cpu().numpy(), rtol=0.0001, atol=1e-5, ) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_type_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
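# --- Illustrative sketch (not part of the original test file) ---
# Truncated division, which the tests above compare against torch.div(..., rounding_mode="trunc"):
# the elementwise quotient is rounded toward zero, so 7/2 -> 3 and -7/2 -> -3
# (floor division would give -4 for the latter).
import oneflow as flow

x = flow.tensor([7.0, -7.0])
y = flow.tensor([2.0, 2.0])
print(flow._C.trunc_divide(x, y).numpy())   # [ 3. -3.]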
See the License for the specific language governing permissions and limitations under the License. """ import random import unittest import os import numpy as np import oneflow as flow import oneflow.unittest type_tensor_all = [ { "cpu_interface": flow.HalfTensor, "cuda_interface": flow.cuda.HalfTensor, "dtype": flow.float16, }, { "cpu_interface": flow.FloatTensor, "cuda_interface": flow.cuda.FloatTensor, "dtype": flow.float32, }, { "cpu_interface": flow.DoubleTensor, "cuda_interface": flow.cuda.DoubleTensor, "dtype": flow.float64, }, { "cpu_interface": flow.BoolTensor, "cuda_interface": flow.cuda.BoolTensor, "dtype": flow.bool, }, { "cpu_interface": flow.ByteTensor, "cuda_interface": flow.cuda.ByteTensor, "dtype": flow.uint8, }, { "cpu_interface": flow.CharTensor, "cuda_interface": flow.cuda.CharTensor, "dtype": flow.int8, }, { "cpu_interface": flow.IntTensor, "cuda_interface": flow.cuda.IntTensor, "dtype": flow.int32, }, { "cpu_interface": flow.LongTensor, "cuda_interface": flow.cuda.LongTensor, "dtype": flow.int64, }, # TODO: flow.BFloat16Tensor fails to creat Tensor. # {"cpu_interface": flow.BFloat16Tensor, "cuda_interface": flow.cuda.BFloat16Tensor, "dtype": flow.bfloat16}, ] @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestTypeTensor(flow.unittest.TestCase): def test_type_tensor(test_case): for type_tensor_case in type_tensor_all: x = type_tensor_case["cpu_interface"](np.random.randn(2, 3, 4, 5)) test_case.assertEqual(x.device, flow.device("cpu")) test_case.assertEqual(x.dtype, type_tensor_case["dtype"]) test_case.assertEqual(x.shape, (2, 3, 4, 5)) test_case.assertFalse(x.requires_grad) test_case.assertTrue(x.is_leaf) y = type_tensor_case["cuda_interface"](np.random.randn(2, 3, 4, 5)) test_case.assertEqual(y.device, flow.device("cuda")) test_case.assertEqual(y.dtype, type_tensor_case["dtype"]) test_case.assertEqual(y.shape, (2, 3, 4, 5)) test_case.assertFalse(y.requires_grad) test_case.assertTrue(y.is_leaf) def test_doubletensor_corner_cases(test_case): corner_cases = [random.randint(1 << 24, 1 << 25) for _ in range(20)] test_case.assertTrue( np.allclose( flow.DoubleTensor(corner_cases).numpy(), np.array(corner_cases, dtype=np.float64), 1e-6, 1e-6, ) ) def test_type_tensor_ctor(test_case): for tensor_type in type_tensor_all: cpu_type = tensor_type["cpu_interface"] cuda_type = tensor_type["cuda_interface"] # empty ctor cpu_type_tensor = cpu_type() cuda_type_tensor = cuda_type() test_case.assertEqual(cpu_type_tensor.dtype, tensor_type["dtype"]) test_case.assertEqual(cpu_type_tensor.device, flow.device("cpu")) test_case.assertEqual(cuda_type_tensor.dtype, tensor_type["dtype"]) test_case.assertEqual(cuda_type_tensor.device, flow.device("cuda")) # other ctor other_tensor = flow.Tensor(flow.Size([2, 3, 4, 5])) cpu_type_tensor = cpu_type(other_tensor) cuda_type_tensor = cuda_type(other_tensor) test_case.assertEqual(cpu_type_tensor.dtype, tensor_type["dtype"]) test_case.assertEqual(cpu_type_tensor.device, flow.device("cpu")) test_case.assertEqual(cuda_type_tensor.dtype, tensor_type["dtype"]) test_case.assertEqual(cuda_type_tensor.device, flow.device("cuda")) # data ctor # numpy inputs have been tested above in test_type_tensor data = [random.random() for i in range(20)] cpu_type_tensor = cpu_type(data) cuda_type_tensor = cuda_type(data) test_case.assertEqual(cpu_type_tensor.dtype, tensor_type["dtype"]) test_case.assertEqual(cpu_type_tensor.device, flow.device("cpu")) test_case.assertEqual(cuda_type_tensor.dtype, tensor_type["dtype"]) 
test_case.assertEqual(cuda_type_tensor.device, flow.device("cuda")) # shape ctor shape = flow.Size([2, 3, 4, 5]) cpu_type_tensor = cpu_type(shape) cuda_type_tensor = cuda_type(shape) test_case.assertEqual(cpu_type_tensor.dtype, tensor_type["dtype"]) test_case.assertEqual(cpu_type_tensor.device, flow.device("cpu")) test_case.assertEqual(cuda_type_tensor.dtype, tensor_type["dtype"]) test_case.assertEqual(cuda_type_tensor.device, flow.device("cuda")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_unbind.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestUnbind(flow.unittest.TestCase): @autotest(n=5, check_graph=True) def test_unbind_flow_with_random_data1(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.unbind(x, random(0, 4).to(int)) return y @autotest(n=5, check_graph=True) def test_unbind_flow_with_random_data2(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.unbind(x, random(0, 4).to(int)) return y @autotest(n=5, check_graph=True) def test_unbind_flow_with_random_data3(test_case): device = random_device() x = random_tensor(ndim=3).to(device) y = torch.unbind(x, random(0, 3).to(int)) return y @autotest(n=5, check_graph=True) def test_unbind_flow_with_random_data4(test_case): device = random_device() x = random_tensor(ndim=3).to(device) y = torch.unbind(x, random(0, 3).to(int)) return y @autotest(n=5, check_graph=True) def test_unbind_flow_with_random_data5(test_case): device = random_device() x = random_tensor(ndim=2).to(device) y = torch.unbind(x, random(0, 2).to(int)) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_unfold.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
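# Rough sketch of the sliding-window count behind nn.Unfold, exercised below;
# n_positions is an illustrative helper using the usual convolution
# output-size arithmetic, not an existing API.
def n_positions(size, kernel, padding=0, dilation=1, stride=1):
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

print(n_positions(10, 3))             # 8 window positions along a length-10 axis
print(n_positions(10, 3, padding=1))  # 10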
""" import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.nn.common_types import _size_2_t @flow.unittest.skip_unless_1n1d() class TestUnfold(flow.unittest.TestCase): @autotest(n=50, auto_backward=True, rtol=1e-4, atol=1e-4) def test_unfold_with_random_data(test_case): m = torch.nn.Unfold( kernel_size=random(1, 3).to(_size_2_t), dilation=random(1, 2).to(_size_2_t) | nothing(), padding=random(0, 1).to(_size_2_t) | nothing(), stride=random(1, 2).to(_size_2_t) | nothing(), ) m.train(random()) device = random_device() m.to(device) x = random_tensor( ndim=4, dim0=random(1, 5), dim1=random(1, 5), dim2=random(10, 20), dim3=random(10, 20), ).to(device) y = m(x) func_y = torch.nn.functional.unfold( x, kernel_size=random(1, 3).to(_size_2_t), dilation=random(1, 2).to(_size_2_t) | nothing(), padding=random(0, 1).to(_size_2_t) | nothing(), stride=random(1, 2).to(_size_2_t) | nothing(), ) return y, func_y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_unfold_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np from random import shuffle import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestUnfoldTensor(flow.unittest.TestCase): @autotest(n=10, auto_backward=True, check_graph=True) def test_unfold_tensor_with_random_data(test_case): device = random_device() x = random_tensor(3, 3, 4, 5).to(device) dimension = random(0, 2).to(int).value() size = random(1, 3).to(int).value() step = random(1, 3).to(int).value() y = x.unfold(dimension, size, step) return y @autotest(n=5) def test_unfold_tensor_with_stride(test_case): device = random_device() x = random_tensor(3, 3, 4, 5).to(device) perm = [0, 1, 2] shuffle(perm) y = x.permute(perm) dimension = random(0, 2).to(int).value() size = random(1, 3).to(int).value() step = random(1, 3).to(int).value() z = y.unfold(dimension, size, step) return z @autotest(n=10, auto_backward=True, check_graph=True) def test_unfold_tensor_with_0dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) dimension = random(0, 2).to(int).value() size = random(1, 3).to(int).value() step = random(1, 3).to(int).value() y = x.unfold(dimension, size, step) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_unique.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import random as random_util import torch as torch_ori from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import numpy as np import oneflow as flow import oneflow.unittest def _test_unique_unsorted(test_case, device, return_inverse, return_counts): dtype = random_util.choice([torch.int8, torch.int, torch.float, torch.double]) input = random_tensor(ndim=3, dim0=random(), dim1=random(), dim2=random(), high=20) input = input.to(device).to(dtype) oneflow_output = flow.unique( input.oneflow, sorted=False, return_inverse=return_inverse, return_counts=return_counts, ) torch_output = torch_ori.unique( input.pytorch, sorted=False, return_inverse=return_inverse, return_counts=return_counts, ) if not return_inverse and not return_counts: oneflow_result = oneflow_output torch_result = torch_output else: oneflow_result = oneflow_output[0] torch_result = torch_output[0] test_case.assertTrue( np.allclose( np.sort(oneflow_result.numpy()), np.sort(torch_result.detach().cpu().numpy()), ) ) test_case.assertEqual(list(oneflow_result.shape), list(torch_result.shape)) if return_inverse: oneflow_indices = oneflow_output[1] torch_indices = torch_output[1] test_case.assertTrue( np.allclose( oneflow_result[oneflow_indices].numpy(), torch_result[torch_indices].detach().cpu().numpy(), ) ) test_case.assertEqual(list(oneflow_indices.shape), list(torch_indices.shape)) if return_counts: oneflow_counts = oneflow_output[-1] torch_counts = torch_output[-1] test_case.assertTrue( np.allclose( oneflow_counts.numpy()[np.argsort(oneflow_result.numpy())], torch_counts.detach() .cpu() .numpy()[np.argsort(torch_result.detach().cpu().numpy())], ) ) test_case.assertEqual(list(oneflow_counts.shape), list(torch_counts.shape)) def _test_unique_sorted(test_case, device, return_inverse, return_counts): dtype = random_util.choice([torch.int8, torch.int, torch.float, torch.double]) input = random_tensor(ndim=3, dim0=random(), dim1=random(), dim2=random(), high=20) input = input.to(device).to(dtype) oneflow_output = flow.unique( input.oneflow, sorted=True, return_inverse=return_inverse, return_counts=return_counts, ) torch_output = torch_ori.unique( input.pytorch, sorted=True, return_inverse=return_inverse, return_counts=return_counts, ) if not return_inverse and not return_counts: oneflow_result = oneflow_output torch_result = torch_output else: oneflow_result = oneflow_output[0] torch_result = torch_output[0] test_case.assertTrue( np.allclose(oneflow_result.numpy(), torch_result.detach().cpu().numpy(),) ) test_case.assertEqual(list(oneflow_result.shape), list(torch_result.shape)) if return_inverse: oneflow_indices = oneflow_output[1] torch_indices = torch_output[1] test_case.assertTrue( np.allclose(oneflow_indices.numpy(), torch_indices.detach().cpu().numpy(),) ) test_case.assertEqual(list(oneflow_indices.shape), list(torch_indices.shape)) if return_counts: oneflow_counts = oneflow_output[-1] torch_counts = torch_output[-1] test_case.assertTrue( np.allclose(oneflow_counts.numpy(), torch_counts.detach().cpu().numpy(),) ) 
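# The checks in both helpers boil down to the same invariants that np.unique
# guarantees: values[inverse] reconstructs the input element-wise, and counts[i]
# is the multiplicity of values[i]; with sorted=False the value order is
# unspecified, which is why the unsorted helper sorts (or argsort-reindexes)
# both sides before comparing.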
test_case.assertEqual(list(oneflow_counts.shape), list(torch_counts.shape)) @flow.unittest.skip_unless_1n1d() class TestUnique(flow.unittest.TestCase): @autotest(n=5) def test_unique(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cpu", "cuda"] arg_dict["return_inverse"] = [False, True] arg_dict["return_counts"] = [False, True] for arg in GenArgList(arg_dict): _test_unique_unsorted(test_case, *arg) _test_unique_sorted(test_case, *arg) @profile(torch.unique) def profile_unique(test_case): input = torch.randint(0, 1000, (1000,)) torch.unique(input) torch.unique(input, return_inverse=True, return_counts=True) input = torch.randn(1000,) torch.unique(input) torch.unique(input, return_inverse=True, return_counts=True) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_unsqueeze.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_unsqueeze(test_case, device): np_arr = np.random.rand(2, 6, 9, 3) x = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) y = flow.unsqueeze(x, dim=1) output = np.expand_dims(np_arr, axis=1) test_case.assertTrue(np.allclose(output, y.numpy(), 1e-05, 1e-05)) x_flow = flow.randn(5) x_flow = flow.unsqueeze(x_flow, 0) test_case.assertTrue(np.array_equal(x_flow.stride(), (5, 1))) x_flow = flow.randn(5, 2) x_flow = flow.unsqueeze(x_flow, 0) test_case.assertTrue(np.array_equal(x_flow.stride(), (10, 2, 1))) def _test_unsqueeze_tensor_function(test_case, device): np_arr = np.random.rand(2, 3, 4) x = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) y = x.unsqueeze(dim=2) output = np.expand_dims(np_arr, axis=2) test_case.assertTrue(np.allclose(output, y.numpy(), 1e-05, 1e-05)) def _test_unsqueeze_different_dim(test_case, device): np_arr = np.random.rand(4, 5, 6, 7) x = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device)) for axis in range(-5, 5): y = flow.unsqueeze(x, dim=axis) output = np.expand_dims(np_arr, axis=axis) test_case.assertTrue(np.allclose(output, y.numpy(), 1e-05, 1e-05)) def _test_unsqueeze_backward(test_case, device): np_arr = np.random.rand(2, 3, 4, 5) x = flow.tensor( np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True ) y = flow.unsqueeze(x, dim=1).sum() y.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones((2, 3, 4, 5)), 1e-05, 1e-05) ) @flow.unittest.skip_unless_1n1d() class TestUnsqueeze(flow.unittest.TestCase): def test_unsqueeze(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_unsqueeze, _test_unsqueeze_tensor_function, _test_unsqueeze_different_dim, _test_unsqueeze_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): 
arg[0](test_case, *arg[1:]) @autotest(check_graph=True) def test_flow_unsqueeze_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.unsqueeze(x, random(1, 3).to(int)) return y @autotest(n=10, check_graph=False, auto_backward=False) def test_inplace_unsqueeze_with_random_data(test_case): device = random_device() x = random_tensor(requires_grad=False).to(device) y = x.unsqueeze_(random(1, 3).to(int)) return y @autotest(auto_backward=False, check_graph=True) def test_unsqueeze_with_0_size_data(test_case): device = random_device() x = random_tensor(3, 2, 1, 0).to(device) y = torch.unsqueeze(x, random(0, 2).to(int)) return y @autotest(auto_backward=False, check_graph=True) def test_flow_unsqueeze_bool_with_random_data(test_case): device = random_device() x = random_tensor().to(device=device, dtype=torch.bool) y = torch.unsqueeze(x, random(1, 3).to(int)) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_upsample.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest def _test_upsample2d_bilinear(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, ) m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear") of_out = m(input) np_out = np.array( [ [ [ [1.0, 1.25, 1.75, 2.0], [1.5, 1.75, 2.25, 2.5], [2.5, 2.75, 3.25, 3.5], [3.0, 3.25, 3.75, 4.0], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_upsample2d_bilinear_aligncorner(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, ) m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True) of_out = m(input) np_out = np.array( [ [ [ [1.0, 1.3333, 1.6667, 2.0], [1.6667, 2.0, 2.3333, 2.6667], [2.3333, 2.6667, 3.0, 3.3333], [3.0, 3.3333, 3.6667, 4.0], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) def _test_UpsamplingNearest2d(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, ) m = flow.nn.UpsamplingNearest2d(scale_factor=2.0) of_out = m(input) np_out = np.array( [ [ [ [1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_UpsamplingBilinear2d(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, ) m = flow.nn.UpsamplingBilinear2d(scale_factor=2.0) of_out = m(input) np_out = np.array( [ [ [ [1.0, 1.3333, 1.6667, 2.0], [1.6667, 2.0, 
2.3333, 2.6667], [2.3333, 2.6667, 3.0, 3.3333], [3.0, 3.3333, 3.6667, 4.0], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) def _test_upsample2d_4dim(test_case, device): input = flow.tensor( np.arange(1, 37).reshape((2, 2, 3, 3)), device=flow.device(device), dtype=flow.float32, ) m = flow.nn.Upsample(scale_factor=2.0, mode="nearest") of_out = m(input) np_out = np.array( [ [ [ [1.0, 1.0, 2.0, 2.0, 3.0, 3.0], [1.0, 1.0, 2.0, 2.0, 3.0, 3.0], [4.0, 4.0, 5.0, 5.0, 6.0, 6.0], [4.0, 4.0, 5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0, 9.0, 9.0], [7.0, 7.0, 8.0, 8.0, 9.0, 9.0], ], [ [10.0, 10.0, 11.0, 11.0, 12.0, 12.0], [10.0, 10.0, 11.0, 11.0, 12.0, 12.0], [13.0, 13.0, 14.0, 14.0, 15.0, 15.0], [13.0, 13.0, 14.0, 14.0, 15.0, 15.0], [16.0, 16.0, 17.0, 17.0, 18.0, 18.0], [16.0, 16.0, 17.0, 17.0, 18.0, 18.0], ], ], [ [ [19.0, 19.0, 20.0, 20.0, 21.0, 21.0], [19.0, 19.0, 20.0, 20.0, 21.0, 21.0], [22.0, 22.0, 23.0, 23.0, 24.0, 24.0], [22.0, 22.0, 23.0, 23.0, 24.0, 24.0], [25.0, 25.0, 26.0, 26.0, 27.0, 27.0], [25.0, 25.0, 26.0, 26.0, 27.0, 27.0], ], [ [28.0, 28.0, 29.0, 29.0, 30.0, 30.0], [28.0, 28.0, 29.0, 29.0, 30.0, 30.0], [31.0, 31.0, 32.0, 32.0, 33.0, 33.0], [31.0, 31.0, 32.0, 32.0, 33.0, 33.0], [34.0, 34.0, 35.0, 35.0, 36.0, 36.0], [34.0, 34.0, 35.0, 35.0, 36.0, 36.0], ], ], ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_upsample2d_bilinear_4dim(test_case, device): input = flow.tensor( np.arange(1, 37).reshape((2, 2, 3, 3)), device=flow.device(device), dtype=flow.float32, ) m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear") of_out = m(input) np_out = np.array( [ [ [ [1.0, 1.25, 1.75, 2.25, 2.75, 3.0], [1.75, 2.0, 2.5, 3.0, 3.5, 3.75], [3.25, 3.5, 4.0, 4.5, 5.0, 5.25], [4.75, 5.0, 5.5, 6.0, 6.5, 6.75], [6.25, 6.5, 7.0, 7.5, 8.0, 8.25], [7.0, 7.25, 7.75, 8.25, 8.75, 9.0], ], [ [10.0, 10.25, 10.75, 11.25, 11.75, 12.0], [10.75, 11.0, 11.5, 12.0, 12.5, 12.75], [12.25, 12.5, 13.0, 13.5, 14.0, 14.25], [13.75, 14.0, 14.5, 15.0, 15.5, 15.75], [15.25, 15.5, 16.0, 16.5, 17.0, 17.25], [16.0, 16.25, 16.75, 17.25, 17.75, 18.0], ], ], [ [ [19.0, 19.25, 19.75, 20.25, 20.75, 21.0], [19.75, 20.0, 20.5, 21.0, 21.5, 21.75], [21.25, 21.5, 22.0, 22.5, 23.0, 23.25], [22.75, 23.0, 23.5, 24.0, 24.5, 24.75], [24.25, 24.5, 25.0, 25.5, 26.0, 26.25], [25.0, 25.25, 25.75, 26.25, 26.75, 27.0], ], [ [28.0, 28.25, 28.75, 29.25, 29.75, 30.0], [28.75, 29.0, 29.5, 30.0, 30.5, 30.75], [30.25, 30.5, 31.0, 31.5, 32.0, 32.25], [31.75, 32.0, 32.5, 33.0, 33.5, 33.75], [33.25, 33.5, 34.0, 34.5, 35.0, 35.25], [34.0, 34.25, 34.75, 35.25, 35.75, 36.0], ], ], ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_upsample2d_backward(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) m = flow.nn.Upsample(scale_factor=2.0, mode="nearest") of_out = m(input) of_out = of_out.sum() of_out.backward() np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]] test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_upsample2d_bilinear_aligncorner_backward(test_case, device): input = flow.tensor( np.arange(1, 5).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True) of_out = m(input) of_out = of_out.sum() of_out.backward() np_grad = [[[[3.999999523162842, 4.000000476837158], [3.999999761581421, 4.0]]]] 
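# Why the expected gradients hover around 4: with a sum() reduction, the gradient
# at each input pixel equals the total interpolation weight it contributes across
# the 4x4 output. For 2x nearest upsampling that weight is exactly 4 (see np_grad
# in _test_upsample2d_backward above); for bilinear with align_corners the
# per-pixel weights still sum to roughly 4, up to floating-point error.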
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_nearest_float_scale(test_case, device): input = flow.tensor( np.arange(1, 10).reshape((1, 1, 3, 3)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) m = flow.nn.Upsample(scale_factor=1.5) of_out = m(input) np_out = np.array( [ [ [ [1.0, 1.0, 2.0, 3.0], [1.0, 1.0, 2.0, 3.0], [4.0, 4.0, 5.0, 6.0], [7.0, 7.0, 8.0, 9.0], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = np.array([[[[4.0, 2.0, 2.0], [2.0, 1.0, 1.0], [2.0, 1.0, 1.0]]]]) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_interpolate_bilinear_float_scale(test_case, device): input = flow.tensor( np.arange(1, 5, dtype=np.int32).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) m = flow.nn.Upsample(scale_factor=0.5, mode="bilinear") of_out = m(input) np_out = np.array([[[[2.5]]]]) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = np.array([[[[0.25, 0.25], [0.25, 0.25]]]]) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) input = flow.tensor( np.arange(1, 10, dtype=np.int32).reshape((1, 1, 3, 3)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) m = flow.nn.Upsample(scale_factor=0.5, mode="bilinear") of_out = m(input) np_out = np.array([[[[3.0]]]]) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = np.array([[[[0.25, 0.25, 0.0], [0.25, 0.25, 0.0], [0.0, 0.0, 0.0]]]]) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) input = flow.tensor( np.arange(1, 11, dtype=np.int32).reshape((1, 1, 5, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) m = flow.nn.Upsample(size=(4, 4), mode="bilinear") of_out = m(input) np_out = np.array( [ [ [ [1.25, 1.5, 2.0, 2.25], [3.75, 4.0, 4.5, 4.75], [6.25, 6.5, 7.0, 7.25], [8.75, 9.0, 9.5, 9.75], ] ] ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = np.array( [[[[1.75, 1.75], [1.5, 1.5], [1.5, 1.5], [1.5, 1.5], [1.75, 1.75]]]] ) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) def _test_upsample_bilinear_align_corners(test_case, device): input = flow.tensor( np.arange(1, 5, dtype=np.int32).reshape((1, 1, 2, 2)), device=flow.device(device), dtype=flow.float32, requires_grad=True, ) m = flow.nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=True) of_out = m(input) np_out = np.array([[[[1.0]]]]) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_grad = np.array([[[[1.0, 0.0], [0.0, 0.0]]]]) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestUpsample2d(flow.unittest.TestCase): def test_upsample2d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_upsample2d_bilinear, _test_upsample2d_bilinear_aligncorner, _test_UpsamplingNearest2d, _test_UpsamplingBilinear2d, _test_upsample2d_4dim, _test_upsample2d_bilinear_4dim, _test_upsample2d_backward, _test_upsample2d_bilinear_aligncorner_backward, _test_interpolate_nearest_float_scale, _test_interpolate_bilinear_float_scale, _test_upsample_bilinear_align_corners, ] arg_dict["device"] = ["cpu", 
"cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @unittest.skip( "The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200" ) @autotest() def test_upsample2d_nearest(test_case): device = random_device() x = random_tensor().to(device) m = torch.nn.Upsample(scale_factor=random().to(float), mode="nearest") y = m(x) return y @unittest.skip( "The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200" ) @autotest() def test_upsample2d_nearest_half(test_case): device = random_device() x = random_tensor().to(device=device, dtype=torch.float16) m = torch.nn.Upsample(scale_factor=random().to(float), mode="nearest") y = m(x) return y # The forward and backward result in cpu and cuda of bilinear interpolate operation in PyTorch is different # in some corner cases. OneFlow has the same cpu and cuda results with PyTorch's cuda result. # So here we only test cuda device forward result. @autotest(n=10, auto_backward=False, atol=1e-8) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample2d_bilinear(test_case): x = random_tensor(ndim=4).to("cuda") x = x.permute(1, 3, 0, 2) m = torch.nn.Upsample( scale_factor=random().to(float), mode="bilinear", align_corners=random_bool(), ) y = m(x) return y @autotest(atol=1e-5) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample2d_bicubic(test_case): x = random_tensor(ndim=4, dim0=16, dim1=8).to("cuda") m = torch.nn.Upsample( scale_factor=random().to(float), mode="bicubic", align_corners=random_bool(), ) y = m(x) return y @autotest(n=5, atol=1e-5) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample1d_nearest_output_size(test_case): x = random_tensor(ndim=3, dim0=1, dim1=2, dim2=12).to("cuda") m = torch.nn.Upsample(size=(13), mode="nearest") y = m(x) return y @autotest(n=5, atol=1e-5) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample2d_nearest_output_size(test_case): x = random_tensor(ndim=4, dim0=1, dim1=1, dim2=1, dim3=937).to("cuda") m = torch.nn.Upsample(size=(1, 30), mode="nearest") y = m(x) return y @autotest(n=5, atol=1e-5) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample3d_nearest_output_size(test_case): x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=6, dim3=12, dim4=6).to("cuda") m = torch.nn.Upsample(size=(8, 10, 7), mode="nearest") y = m(x) return y @autotest(n=5, atol=1e-5) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample1d_linear_output_size(test_case): device = random_device() x = random_tensor(ndim=3, dim0=1, dim1=2, dim2=12).to(device) m = torch.nn.Upsample(size=(13), mode="linear") y = m(x) return y @autotest(n=5, atol=1e-5) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample2d_bilinear_output_size(test_case): x = random_tensor(ndim=4, dim0=1, dim1=1, dim2=12, dim3=21).to("cuda") m = torch.nn.Upsample(size=(14, 19), mode="bilinear") y = m(x) return y @autotest(n=5, atol=1e-5) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample2d_bicubic_output_size(test_case): x = random_tensor(ndim=4, dim0=1, dim1=2, dim2=12, dim3=21).to("cuda") m = torch.nn.Upsample(size=(14, 19), mode="bicubic") y = m(x) return y @autotest(n=5, atol=1e-5) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def 
test_upsample3d_trilinear_output_size(test_case): x = random_tensor(ndim=5, dim0=1, dim1=2, dim2=1, dim3=12, dim4=17).to("cuda") m = torch.nn.Upsample(size=(1, 14, 23), mode="trilinear") y = m(x) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_util_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow from collections import OrderedDict from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList # TODO(): random_tensor can't generate a tensor with nan or inf element. def _test_isnan(test_case, shape, dtype, device): np_array = np.random.randn(*shape) mask = np.random.choice([1, 0], np_array.shape, p=[0.1, 0.9]).astype(bool) np_array[mask] = np.nan of_tensor = flow.tensor(np_array, dtype=dtype, device=device) res = flow.isnan(of_tensor) test_case.assertTrue(np.allclose(res.numpy(), np.isnan(of_tensor.numpy()))) def _test_isinf(test_case, shape, dtype, device): np_array = np.random.randn(*shape) mask = np.random.choice([1, 0], np_array.shape, p=[0.1, 0.9]).astype(bool) np_array[mask] = np.inf of_tensor = flow.tensor(np_array, dtype=dtype, device=device) res = flow.isinf(of_tensor) test_case.assertTrue(np.allclose(res.numpy(), np.isinf(of_tensor.numpy()))) def _test_isfinite(test_case, shape, dtype, device): np_array = np.random.randn(*shape) inf_mask = np.random.choice([1, 0], np_array.shape, p=[0.1, 0.9]).astype(bool) nan_mask = np.random.choice([1, 0], np_array.shape, p=[0.1, 0.9]).astype(bool) np_array[inf_mask] = np.inf np_array[nan_mask] = np.nan of_tensor = flow.tensor(np_array, dtype=dtype, device=device) res = flow.isfinite(of_tensor) test_case.assertTrue(np.allclose(res.numpy(), np.isfinite(of_tensor.numpy()))) @flow.unittest.skip_unless_1n1d() class TestUtilOps(flow.unittest.TestCase): def test_util_ops(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_isnan, _test_isinf, _test_isfinite] arg_dict["shape"] = [(2, 3, 4), (1, 2, 3)] arg_dict["dtype"] = [flow.float, flow.int] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
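# NumPy analogue (illustrative only) of the round trip tested below:
# _flatten_dense_tensors concatenates tensors into one flat buffer and
# _unflatten_dense_tensors slices that buffer back into the original shapes.
import numpy as np
tensors = [np.random.randn(2, 1), np.random.randn(2, 1), np.random.randn(2, 1)]
flat = np.concatenate([t.reshape(-1) for t in tensors])         # shape (6,)
offsets = np.cumsum([t.size for t in tensors])[:-1]             # [2, 4]
restored = [p.reshape(t.shape) for p, t in zip(np.split(flat, offsets), tensors)]
assert all(np.array_equal(a, b) for a, b in zip(tensors, restored))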
See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow import torch from torch._utils import _flatten_dense_tensors as torch_flatten_dense_tensors from torch._utils import _unflatten_dense_tensors as torch_unflatten_dense_tensors from oneflow._utils import _flatten_dense_tensors, _unflatten_dense_tensors from collections import OrderedDict from oneflow.test_utils.test_util import GenArgList def _test_flatten_dense_tensors(test_case, device): torch_x = torch.randn(6, 6, device=device) x = flow.utils.tensor.from_torch(torch_x) torch_x_flatten = torch_flatten_dense_tensors([torch_x]) x_flatten = _flatten_dense_tensors([x]) test_case.assertTrue(np.array_equal(torch_x_flatten.size(), x_flatten.size())) torch_x_flatten = torch_flatten_dense_tensors([torch_x, torch_x, torch_x]) x_flatten = _flatten_dense_tensors([x, x, x]) test_case.assertTrue(np.array_equal(torch_x_flatten.size(), x_flatten.size())) test_case.assertTrue( np.allclose( torch_x_flatten.cpu().numpy(), x_flatten.cpu().numpy(), 1e-05, 1e-05 ) ) def _test_unflatten_dense_tensors(test_case, device): torch_flat = torch.randn(6, 1, device=device) torch_x1 = torch.randn(2, 1, device=device) torch_x2 = torch.randn(2, 1, device=device) torch_x3 = torch.randn(2, 1, device=device) torch_tensors = [ torch_x1, torch_x2, torch_x3, ] tensors = [ flow.utils.tensor.from_torch(torch_x1), flow.utils.tensor.from_torch(torch_x2), flow.utils.tensor.from_torch(torch_x3), ] torch_outputs = torch_unflatten_dense_tensors(torch_flat, torch_tensors) outputs = _unflatten_dense_tensors( flow.utils.tensor.from_torch(torch_flat), tensors ) for i in range(len(outputs)): test_case.assertTrue(np.array_equal(torch_outputs[i].size(), outputs[i].size())) test_case.assertTrue( np.allclose( torch_outputs[i].cpu().numpy(), outputs[i].cpu().numpy(), 1e-05, 1e-05 ) ) @flow.unittest.skip_unless_1n1d() class TestUtilsFunction(flow.unittest.TestCase): def test_utils_function(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_flatten_dense_tensors, _test_unflatten_dense_tensors, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_var.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
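# The unbiased flag below toggles the divisor used by the variance op:
#   unbiased=True  -> sum((x - mean)**2) / (N - 1)   (sample variance)
#   unbiased=False -> sum((x - mean)**2) / N
# NumPy check (ddof plays the same role):
import numpy as np
x = np.array([1.0, 2.0, 4.0])
print(np.var(x, ddof=1))  # 2.333...
print(np.var(x, ddof=0))  # 1.555...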
""" import unittest import oneflow as flow from oneflow.test_utils.automated_test_util.generators import random import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestVar(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_var_all_dim_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = torch.var(x) return y @autotest(check_graph=True) def test_flow_var_one_dim_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = torch.var( x, dim=random(low=-4, high=4).to(int), unbiased=random().to(bool), keepdim=random().to(bool), ) return y # In fp16 mode, variance op backward has a gap of 1e-3 between the gradient of PyTorch # and OneFlow for some unknown reason. However, it is not important now, because both in # PyTorch and OneFlow variance op don't need support fp16 backward in amp train. @autotest(n=5, auto_backward=True, check_graph=True, rtol=1e-3, atol=1e-3) def test_flow_var_one_dim_with_random_half_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device).to(torch.float16) y = torch.var( x, dim=random(low=-4, high=4).to(int), unbiased=random().to(bool), keepdim=random().to(bool), ) return y @autotest(auto_backward=False, check_graph=True) def test_flow_var_0_size_data_with_random_data(test_case): device = random_device() x = random_tensor(4, 2, 3, 0, 4).to(device) y = torch.var( x, dim=random(low=-4, high=4).to(int), unbiased=random().to(bool), keepdim=random().to(bool), ) return y @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_var_0_size_data_with_random_half_data(test_case): device = random_device() x = random_tensor(4, 2, 3, 0, 4).to(device).to(torch.float16) y = torch.var( x, dim=random(low=-4, high=4).to(int), unbiased=random().to(bool), keepdim=random().to(bool), ) return y @autotest(n=5) def test_flow_var_all_dim_with_random_data_n5(test_case): device = random_device() x = random_tensor(ndim=4, dim0=5, dim1=1, dim2=16, dim3=16).to(device) y = torch.var(x, dim=[0, 2, 3]) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_view.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import GenArgList from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest def _test_view(test_case, device): x = np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ).astype(np.float32) input = flow.tensor( x, dtype=flow.float32, device=flow.device(device), requires_grad=True ) of_out = input.view(2, 2, 2, -1) of_shape = of_out.numpy().shape np_shape = (2, 2, 2, 2) test_case.assertTrue(np.array_equal(of_shape, np_shape)) of_out = of_out.sum() of_out.backward() np_grad = np.array( [ [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], ] ) test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001)) def _test_view_flow_size(test_case, device): x = np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ).astype(np.float32) input = flow.tensor( x, dtype=flow.float32, device=flow.device(device), requires_grad=True ) shape = flow.Size([2, 2, 2, -1]) of_out = input.view(shape) np_shape = (2, 2, 2, 2) test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_shape)) of_out = of_out.sum() of_out.backward() np_grad = np.array( [ [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], ] ) test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestView(flow.unittest.TestCase): # TODO:(zhaoluyang) add test case that trigger tensor.view's check def test_view(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_view, _test_view_flow_size, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5, check_graph=True) def test_view_with_0_dim_data(test_case): device = random_device() x = random_tensor(ndim=0).to(device) y1 = torch.reshape(x, shape=(-1,)) y2 = x.view((1, 1, 1)) test_case.assertTrue(x.oneflow.stride() == x.pytorch.stride()) test_case.assertTrue(y1.oneflow.stride() == y1.pytorch.stride()) test_case.assertTrue(y2.oneflow.stride() == y2.pytorch.stride()) return y2 if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_vsplit.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from random import shuffle from oneflow.test_utils.automated_test_util import * import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestVsplitVec(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_vsplit_vec(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) z = torch.vsplit(x, (1, 2)) return z[0] @autotest(n=10) def test_flow_vsplit_vec_with_stride(test_case): device = random_device() x = random_tensor( ndim=4, dim0=random(3, 6), dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) perm = [0, 1, 2, 3] shuffle(perm) y = x.permute(perm) z = torch.vsplit(y, (1, 2)) return z[0] @flow.unittest.skip_unless_1n1d() class TestVsplitInt(flow.unittest.TestCase): @autotest(check_graph=True) def test_flow_vsplit_int(test_case): device = random_device() x = random_tensor( ndim=4, dim0=12, dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6), ).to(device) split = oneof(2, 4, 6) z = torch.vsplit(x, split) return z[0] if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_weight_norm.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList import torch as torch_original from oneflow.test_utils.automated_test_util import * input_arr = np.array( [ [-0.16046895, -1.03667831], [-0.34974465, 0.26505867], [-1.24111986, -0.53806001], [1.72426331, 0.43572459], ], dtype=np.float64, ) def _test_weightnorm(test_case, device, dim): model_flow = flow.nn.Linear(2, 4) model_flow = model_flow.to(device) with flow.no_grad(): for i in range(input_arr.shape[0]): for j in range(input_arr.shape[1]): model_flow.weight[i, j] = input_arr[i][j] m_flow = flow.nn.utils.weight_norm(model_flow, name="weight", dim=dim) model_torch = torch_original.nn.Linear(2, 4) model_torch = model_torch.to(device) with torch_original.no_grad(): for i in range(input_arr.shape[0]): for j in range(input_arr.shape[1]): model_torch.weight[i, j] = input_arr[i][j] m_torch = torch_original.nn.utils.weight_norm(model_torch, name="weight", dim=dim) if device == "cpu": test_case.assertTrue( np.allclose( m_flow.weight_g.detach().numpy(), m_torch.weight_g.detach().numpy(), 1e-05, 1e-05, ) ) test_case.assertTrue( np.allclose( m_flow.weight_v.detach().numpy(), m_torch.weight_v.detach().numpy(), 1e-05, 1e-05, ) ) elif device == "cuda": test_case.assertTrue( np.allclose( m_flow.weight_g.detach().cpu().numpy(), m_torch.weight_g.detach().cpu().numpy(), 1e-05, 1e-05, ) ) test_case.assertTrue( np.allclose( m_flow.weight_v.detach().numpy(), m_torch.weight_v.detach().cpu().numpy(), 1e-05, 1e-05, ) ) def _test_weightnorm_backward(test_case, device, dim): linear = flow.nn.Linear(3, 8) x = flow.tensor( [ [-0.94630778, -0.83378579, -0.87060891], [2.0289922, -0.28708987, -2.18369248], [0.35217619, -0.67095644, -1.58943879], [0.08086036, -1.81075924, 1.20752494], [0.8901075, -0.49976737, -1.07153746], [-0.44872912, -1.07275683, 0.06256855], [-0.22556897, 0.74798368, 0.90416439], [0.48339456, -2.32742195, -0.59321527], ], dtype=flow.float32, requires_grad=True, ) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) linear_wn = flow.nn.utils.weight_norm(linear, name="weight", dim=dim) of_out = linear_wn(x) of_out = of_out.sum() of_out.backward() np_grad = np.array( [ [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], [16.5501, 16.5501, 16.5501], ] ) test_case.assertTrue(np.allclose(np_grad, x.grad.numpy(), 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() class TestWeightNorm(flow.unittest.TestCase): def test_weightnorm(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_weightnorm, _test_weightnorm_backward, ] arg_dict["device"] = ["cpu", "cuda"] arg_dict["dim"] = [None, -2, -1, 0, 1] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) # Not check graph because of one reason: # Reason 1, Graph's build input nn.modules.linear.Linear type is not supported. 
# Please refer to issue: https://github.com/Oneflow-Inc/oneflow/issues/7466 @autotest(n=10, auto_backward=True, check_graph="ValidatedFalse") def test_weight_norm_with_random_data(test_case): device = random_device() dim = random(-2, 2).to(int).value() output = random(2, 6).to(int) input = random(2, 6).to(int) model_torch = torch.nn.Linear(output, input) model_torch = model_torch.to(device) m = torch.nn.utils.weight_norm(model_torch, name="weight", dim=dim) return m.weight_g, m.weight_v if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_where.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList import oneflow as flow import oneflow.unittest def _test_where(test_case, device): x = flow.tensor( np.array([[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), dtype=flow.float32, device=flow.device(device), ) y = flow.tensor( np.ones(shape=(3, 2)), dtype=flow.float32, device=flow.device(device) ) condition = flow.tensor( np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32, device=flow.device(device) ) of_out = flow.where(condition, x, y) np_out = np.array([[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]]) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_where_broadcast(test_case, device): x = flow.tensor( np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), dtype=flow.float32, device=flow.device(device), ) y = flow.tensor( np.ones(shape=(3, 3, 2)), dtype=flow.float32, device=flow.device(device) ) condition = flow.tensor( np.array([[[0, 1], [1, 0], [1, 0]]]), dtype=flow.int32, device=flow.device(device), ) of_out = flow.where(condition, x, y) np_out = np.array( [ [[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]], [[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]], [[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]], ] ) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_where_scalar(test_case, device): x = 0.5 y = 2.0 condition = flow.tensor(np.array([1]), dtype=flow.int32) of_out = flow.where(condition, x, y) test_case.assertTrue(of_out.dtype == flow.float32) np_out = np.array([0.5]) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) flow.set_default_dtype(flow.double) of_out = flow.where(condition, x, y) test_case.assertTrue(of_out.dtype == flow.double) flow.set_default_dtype(flow.float16) of_out = flow.where(condition, x, y) test_case.assertTrue(of_out.dtype == flow.float16) flow.set_default_dtype(flow.bfloat16) of_out = flow.where(condition, x, y) test_case.assertTrue(of_out.dtype == flow.bfloat16) def _test_where_dim4(test_case, device): x = flow.tensor( np.array([[[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]]), dtype=flow.float32, device=flow.device(device), ) y = 
flow.tensor( np.ones(shape=(1, 1, 3, 2)), dtype=flow.float32, device=flow.device(device) ) condition = flow.tensor( np.array([[[[0, 1], [1, 0], [1, 0]]]]), dtype=flow.int32, device=flow.device(device), ) of_out = flow.where(condition, x, y) np_out = np.array([[[[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]]]]) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def _test_where_backward(test_case, device): x = flow.tensor( np.array([[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.ones(shape=(3, 2)), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) condition = flow.tensor( np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32, device=flow.device(device) ) of_out = flow.where(condition, x, y) of_out = of_out.sum() of_out.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), condition.numpy() == 1, 1e-05, 1e-05) ) test_case.assertTrue( np.allclose(y.grad.numpy(), condition.numpy() == 0, 1e-05, 1e-05) ) def _test_where_broadcast_backward(test_case, device): x = flow.tensor( np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.ones(shape=(3, 3, 2)), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) condition = flow.tensor( np.array([[[0, 1], [1, 0], [1, 0]]]), dtype=flow.int32, device=flow.device(device), ) of_out = flow.where(condition, x, y) of_out = of_out.sum() of_out.backward() x_grad = [[[0.0, 3.0], [3.0, 0.0], [3.0, 0.0]]] test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-05, 1e-05)) y_grad = [ [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], ] test_case.assertTrue(np.allclose(y.grad.numpy(), y_grad, 1e-05, 1e-05)) def _test_where_broadcast_x_backward(test_case, device): x = flow.tensor( np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) y = flow.tensor( np.ones(shape=(3, 3, 2)), dtype=flow.float32, device=flow.device(device) ) condition = flow.tensor( np.array([[[0, 1], [1, 0], [1, 0]]]), dtype=flow.int32, device=flow.device(device), ) of_out = flow.where(condition, x, y) of_out = of_out.sum() of_out.backward() x_grad = [[[0.0, 3.0], [3.0, 0.0], [3.0, 0.0]]] test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-05, 1e-05)) def _test_where_x_y_none(test_case, device): condition = flow.tensor( np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), dtype=flow.float32, device=flow.device(device), requires_grad=True, ) of_out = flow.where(condition) of_nonzero = flow.nonzero(condition, as_tuple=True) for i in range(len(of_out)): test_case.assertTrue( np.allclose(of_out[i].numpy(), of_nonzero[i].numpy(), 1e-05, 1e-05) ) def _test_where_scalar(test_case, device): x = flow.randn(5, 5) y = flow.where(x > 0, x, 0.0) test_case.assertTrue(np.array_equal(y.size(), (5, 5))) y = flow.where(x > 0, 0.0, x) test_case.assertTrue(np.array_equal(y.size(), (5, 5))) @flow.unittest.skip_unless_1n1d() class TestWhere(flow.unittest.TestCase): def test_where(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_where, _test_where_broadcast, _test_where_scalar, _test_where_dim4, _test_where_backward, _test_where_broadcast_backward, _test_where_broadcast_x_backward, _test_where_x_y_none, _test_where_scalar, ] arg_dict["device"] = ["cpu", "cuda"] for arg 
in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=5) def test_flow_where_tensor_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) y = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) return torch.where(cond > 0, x, y) @autotest(n=5) def test_flow_where_tensor_with_0dim_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random_tensor(ndim=0).to(device) y = random_tensor(ndim=0).to(device) return torch.where(cond > 0, x, y) @autotest(n=5) def test_flow_where_tensor_broadcast_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=1, dim1=k2).to(device) y = random_tensor(ndim=2, dim0=k1, dim1=1).to(device) return torch.where(cond > 0, x, y) @autotest(n=5) def test_flow_where_scalar_x_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random().to(float) y = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=float).to( device=device, dtype=torch.float64 ) return torch.where(cond > 0, x, y) @autotest(n=5) def test_flow_where_scalar_x_broadcast_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=1, dim1=k2).to(device) x = random().to(float) y = random_tensor(ndim=2, dim0=k1, dim1=1, dtype=float).to( device=device, dtype=torch.float64 ) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_scalar_x_int_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random().to(int) y = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=int).to(device) return torch.where(cond > 0, x, y) @autotest(n=5) def test_flow_where_scalar_y_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=float).to( device=device, dtype=torch.float64 ) y = random().to(float) return torch.where(cond > 0, x, y) @autotest(n=5) def test_flow_where_scalar_y_broadcast_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=k1, dim1=1, dtype=float).to( device=device, dtype=torch.float64 ) y = random().to(float) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_scalar_y_int_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=int).to(device) y = random().to(int) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_scalar_xy_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random().to(float) y = random().to(float) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def 
test_flow_where_scalar_xy_int_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random().to(int) y = random().to(int) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_tensor_bool_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device=device, dtype=torch.bool) y = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device=device, dtype=torch.bool) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_tensor_broadcast_bool_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=1, dim1=k2).to(device=device, dtype=torch.bool) y = random_tensor(ndim=2, dim0=k1, dim1=1).to(device=device, dtype=torch.bool) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_scalar_x_bool_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random().to(bool) y = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=float).to( device=device, dtype=torch.bool ) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_scalar_x_broadcast_bool_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=1, dim1=k2).to(device) x = random().to(bool) y = random_tensor(ndim=2, dim0=k1, dim1=1, dtype=float).to( device=device, dtype=torch.bool ) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_scalar_y_bool_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=float).to( device=device, dtype=torch.bool ) y = random().to(bool) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_scalar_y_broadcast_bool_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=1, dim1=k2).to(device) x = random_tensor(ndim=2, dim0=k1, dim1=1, dtype=float).to( device=device, dtype=torch.bool ) y = random().to(bool) return torch.where(cond > 0, x, y) @autotest(n=5, auto_backward=False, check_graph=True) def test_flow_where_scalar_xy_bool_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) device = random_device() cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device) x = random().to(bool) y = random().to(bool) return torch.where(cond > 0, x, y) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/modules/test_zeropad2d.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import numpy as np from oneflow.test_utils.test_util import ( Array2Numpy, FlattenArray, GenArgList, Index2Coordinate, ) import oneflow as flow import oneflow.unittest def _np_zero_pad2d_grad(src, dest, padding): (c_idx, h_idx, w_idx) = (1, 2, 3) pad_left = padding[0] pad_right = padding[1] pad_top = padding[2] pad_bottom = padding[3] (dx_height, dx_width) = (dest.shape[h_idx], dest.shape[w_idx]) (dy_height, dy_width) = (src.shape[h_idx], src.shape[w_idx]) numpy_src = np.ones(src.shape, np.int32) numpy_dest = np.zeros(dest.shape, np.int32) array_src = FlattenArray(numpy_src) array_dest = FlattenArray(numpy_dest) src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx] dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx] elements_num = src.shape[0] * src_num for iter_n in range(elements_num): coords = Index2Coordinate(iter_n, src.shape) (n, c, i, j) = (coords[0], coords[c_idx], coords[h_idx], coords[w_idx]) ip_x = ip_y = 0 if ( j >= pad_left and j < dx_width + pad_left and (i >= pad_top) and (i < dx_height + pad_top) ): ip_x = j - pad_left ip_y = i - pad_top src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j dest_index = ( n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x ) array_dest[dest_index] += array_src[src_index] numpy_dest = Array2Numpy(array_dest, dest.shape) return numpy_dest def _test_ZeroPad2d(test_case, shape, padding, value, device): np_input = np.random.random(shape) of_input = flow.tensor( np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True ) if isinstance(padding, int): np_boundary = ((0, 0), (0, 0), (padding, padding), (padding, padding)) elif isinstance(padding, (tuple, int)) and len(padding) == 4: np_boundary = ( (0, 0), (0, 0), (padding[2], padding[3]), (padding[0], padding[1]), ) else: raise ValueError("padding must be in or tuple!") layer = flow.nn.ZeroPad2d(padding=padding) of_out = layer(of_input) np_out = np.pad(np_input, np_boundary, mode="constant", constant_values=value) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() np_out_grad = _np_zero_pad2d_grad(np_out, np_input, layer.padding) test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_out_grad, 1e-05, 1e-05)) @flow.unittest.skip_unless_1n1d() class TestZeroPad2dModule(flow.unittest.TestCase): def test_ConstantPad2d(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2, 3, 4), (8, 3, 4, 4)] arg_dict["padding"] = [2, (1, 1, 2, 2)] arg_dict["value"] = [0.0] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): _test_ZeroPad2d(test_case, *arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/profiler/test_events.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import json import unittest import oneflow.unittest import oneflow as flow from oneflow.profiler.events import * class TestEventAndEvents(flow.unittest.TestCase): def test_event(test_case): classes = [CustomEvent, KernelEvent] custom_event = CustomEvent("custom", 1234, CustomEventType.Default) custom_event_json = { "name": "custom", "time": 1234, "custom_type": 0, "type": 0, } test_case.assertEqual( custom_event, classes[custom_event_json.get("type")].from_dict(custom_event_json), ) kernel_event = KernelEvent("kernel", 1234, 1024, "-") kernel_event_json = { "name": "kernel", "time": 1234, "memory_size": 1024, "type": 1, "input_shapes": "-", } test_case.assertEqual( kernel_event, classes[kernel_event_json.get("type")].from_dict(kernel_event_json), ) def test_event_update(test_case): event = CustomEvent("custom", 1234, CustomEventType.Default) event1 = CustomEvent("custom", 3346, CustomEventType.Default) event.update(event1) test_case.assertEqual(event.count, 2) test_case.assertEqual(event.cpu_time, 2290) test_case.assertEqual(event.cpu_time_total, 4580) def test_events(test_case): events_json = json.dumps( [ {"name": "custom", "time": 1234, "custom_type": 0, "type": 0}, {"name": "custom", "time": 3346, "custom_type": 0, "type": 0}, ] ) events = [ CustomEvent("custom", 1234, CustomEventType.Default), CustomEvent("custom", 3346, CustomEventType.Default), ] events_avg = [CustomEvent("custom", 4580, CustomEventType.Default)] events_avg[0].count = 2 test_case.assertEqual(Events(events_json), events) test_case.assertEqual(Events(events_json).key_averages(), events_avg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/profiler/test_profile_lenet.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import oneflow.unittest import oneflow as flow import oneflow.nn as nn import oneflow.nn.functional as F import oneflow.profiler from collections import OrderedDict from oneflow.profiler.events import CustomEvent, KernelEvent from oneflow.test_utils.test_util import GenArgDict class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): out = F.relu(self.conv1(x)) out = F.max_pool2d(out, 2) out = F.relu(self.conv2(out)) out = F.max_pool2d(out, 2) out = out.view(out.size(0), -1) out = F.relu(self.fc1(out)) out = F.relu(self.fc2(out)) out = self.fc3(out) return out def get_event(events, name: str, input_shapes: str = "", attributes: str = ""): for item in events: if isinstance(item, CustomEvent): if item.name == name: return item if isinstance(item, KernelEvent): if ( item.name == name and item.input_shapes == input_shapes and item.attributes == attributes ): return item return None def _test_lenet( test_case, on_cuda: bool, record_shapes: bool, record_attrs: bool, record_bandwidth_for_cuda: bool = False, ): x = flow.randn(2, 3, 32, 32) lenet = LeNet() if on_cuda: x = x.to("cuda") lenet.to("cuda") activities = [oneflow.profiler.ProfilerActivity.CPU] if on_cuda: activities.append(oneflow.profiler.ProfilerActivity.CUDA) with oneflow.profiler.profile( activities=activities, record_shapes=record_shapes, record_attrs=record_attrs, record_bandwidth_for_cuda=record_bandwidth_for_cuda, ) as prof: with oneflow.profiler.record_function("lenet_forward_total_time") as f: for _ in range(2): eager_res = lenet(x) with oneflow.profiler.record_function("lenet_backward_total_time") as f: eager_res.sum().backward() events = prof.key_averages(group_by_input_shape=True, group_by_attributes=True) conv_event_input_shapes = "(2,3,32,32), (6,3,5,5)" if record_shapes else "" conv_event_attributes = ( "data_format=channels_first, dilation_rate=[1, 1], filters=6, groups=1, kernel_size=[5, 5], padding_before=[0, 0], strides=[1, 1]" if record_attrs else "" ) conv_event = get_event( events, "conv2d", conv_event_input_shapes, conv_event_attributes ) test_case.assertIsNotNone(conv_event) if on_cuda: test_case.assertGreater(conv_event.cpu_time, 0.0) test_case.assertGreater(conv_event.cpu_time_total, 0.0) test_case.assertGreater(conv_event.cuda_time, 0.0) test_case.assertGreater(conv_event.cuda_time_total, 0.0) else: test_case.assertGreater(conv_event.cpu_time, 0.0) test_case.assertGreater(conv_event.cpu_time_total, 0.0) test_case.assertEqual(conv_event.count, 2 if record_shapes or record_attrs else 4) if record_bandwidth_for_cuda and on_cuda: test_case.assertNotEqual(conv_event.bandwidth, -1) relu_grad_event_input_shapes = "(2,6,28,28), (2,6,28,28)" if record_shapes else "" relu_grad_event = get_event(events, "relu_grad", relu_grad_event_input_shapes, "") test_case.assertIsNotNone(relu_grad_event) if on_cuda: test_case.assertGreater(relu_grad_event.cpu_time, 0.0) test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0) test_case.assertGreater(relu_grad_event.cuda_time, 0.0) test_case.assertGreater(relu_grad_event.cuda_time_total, 0.0) else: test_case.assertGreater(relu_grad_event.cpu_time, 0.0) test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0) test_case.assertEqual(relu_grad_event.count, 1 if record_shapes else 4) if record_bandwidth_for_cuda and on_cuda: 
test_case.assertNotEqual(relu_grad_event.bandwidth, -1) test_case.assertIsNotNone(get_event(events, "lenet_forward_total_time")) test_case.assertIsNotNone(get_event(events, "lenet_backward_total_time")) class TestProfileLenet(flow.unittest.TestCase): def test_lenet_cpu(test_case): arg_dict = OrderedDict() arg_dict["record_shapes"] = [True, False] arg_dict["record_attrs"] = [True, False] for kwargs in GenArgDict(arg_dict): _test_lenet(test_case, False, **kwargs) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_lenet_cuda(test_case): arg_dict = OrderedDict() arg_dict["record_shapes"] = [True, False] arg_dict["record_attrs"] = [True, False] arg_dict["record_bandwidth_for_cuda"] = [True, False] for kwargs in GenArgDict(arg_dict): _test_lenet(test_case, True, **kwargs) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_autocast.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "skip test cpu cases") @flow.unittest.skip_unless_1n1d() class TestAutoCast(flow.unittest.TestCase): @autotest(n=1, auto_backward=True, check_graph=False) def test_autocast_half_mm(test_case): a = random_tensor(2, 2, 3).to("cuda") b = random_tensor(2, 3, 4).to("cuda") with torch.autocast("cuda"): x = torch.mm(a, b) return x @autotest(n=1, auto_backward=True, check_graph=False) def test_autocast_half_mm_add(test_case): a = random_tensor(2, 2, 3).to("cuda") b = random_tensor(2, 3, 4).to("cuda") c = random_tensor(2, 2, 4).to("cuda") with torch.autocast("cuda"): x = torch.mm(a, b) y = x + c return x.float() + y.float() def test_autocast_graph(test_case): class LinearGraph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = flow.nn.Linear(3, 4, bias=False).cuda().half() def build(self, x): return self.linear(x) x = flow.Tensor(3, 3).cuda() with flow.autocast(device_type="cuda"): linear = LinearGraph() y = linear(x) test_case.assertTrue(y.dtype == flow.float16) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_bfloat16_activation.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest @flow.unittest.skip_unless_1n1d() class TestBfloat16Activatian(flow.unittest.TestCase): def test_tan_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.tan(x) fp32_y = flow.tan(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_tanh_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.tanh(x) fp32_y = flow.tanh(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_sin_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.sin(x) fp32_y = flow.sin(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_sinh_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.sinh(x) fp32_y = flow.sinh(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_cos_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.cos(x) fp32_y = flow.cos(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_cosh_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.cosh(x) fp32_y = flow.cosh(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_atan_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.atan(x) fp32_y = flow.atan(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_atanh_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.atanh(x) fp32_y = flow.atanh(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_asin_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.asin(x) fp32_y = flow.asin(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_asinh_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.asinh(x) fp32_y = flow.asinh(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_acos_with_random_data(test_case): np_array = np.random.uniform(-1, 1, (4, 4)) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.acos(x) fp32_y = flow.acos(fp32_x) test_case.assertTrue( 
np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_acosh_with_random_data(test_case): np_array = np.random.uniform(1, 5, (4, 4)) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.acosh(x) fp32_y = flow.acosh(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_sqrt_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.sqrt(x) fp32_y = flow.sqrt(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_square_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.square(x) fp32_y = flow.square(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_exp_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.exp(x) fp32_y = flow.exp(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_exp2_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.exp2(x) fp32_y = flow.exp2(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_ceil_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.ceil(x) fp32_y = flow.ceil(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_erf_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.erf(x) fp32_y = flow.erf(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_erfc_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.erfc(x) fp32_y = flow.erfc(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_floor_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.floor(x) fp32_y = flow.floor(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_expm1_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.expm1(x) fp32_y = flow.expm1(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_lgamma_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.lgamma(x) fp32_y = flow.lgamma(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), 
atol=1e-4, rtol=1e-4, ) ) def test_log_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.log(x) fp32_y = flow.log(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_log2_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.log2(x) fp32_y = flow.log2(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_log1p_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.log1p(x) fp32_y = flow.log1p(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_sigmoid_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.sigmoid(x) fp32_y = flow.sigmoid(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_round_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.round(x) fp32_y = flow.round(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_rsqrt_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.rsqrt(x) fp32_y = flow.rsqrt(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_softplus_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.softplus(x) fp32_y = flow.softplus(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_softsign_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.softsign(x) fp32_y = flow.softsign(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_softshrink_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.softshrink(x) fp32_y = flow.softshrink(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_silu_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.silu(x) fp32_y = flow.silu(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_selu_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.selu(x) fp32_y = flow.selu(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def 
test_mish_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.mish(x) fp32_y = flow.mish(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_gelu_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.gelu(x) fp32_y = flow.gelu(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_elu_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() elu = flow.nn.ELU() y = elu(x) fp32_y = elu(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_celu_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() celu = flow.nn.CELU() y = celu(x) fp32_y = celu(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_hardswish_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() hardswish = flow.nn.Hardswish() y = hardswish(x) fp32_y = hardswish(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_hardswish_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() hardsigmoid = flow.nn.Hardsigmoid() y = hardsigmoid(x) fp32_y = hardsigmoid(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_hardshrink_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() hardshrink = flow.nn.Hardshrink() y = hardshrink(x) fp32_y = hardshrink(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_hardtanh_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() hardtanh = flow.nn.Hardtanh() y = hardtanh(x) fp32_y = hardtanh(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_leakyrelu_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() leakyrelu = flow.nn.LeakyReLU(0.1) y = leakyrelu(x) fp32_y = leakyrelu(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_threshold_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() th = flow.nn.Threshold(threshold=0.5, value=0.2) y = th(x) fp32_y = th(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_logsinmoid_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() 
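        # Same pattern as the other checks in this file: run the op in bfloat16, run an fp32
        # reference on the upcast input, then round the reference back to bfloat16 before
        # comparing with loose tolerances.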
logsigmoid = flow.nn.LogSigmoid() y = logsigmoid(x) fp32_y = logsigmoid(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) def test_digamma_with_random_data(test_case): np_array = np.random.rand(4, 4) x = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") fp32_x = x.float() y = flow.digamma(x) fp32_y = flow.digamma(fp32_x) test_case.assertTrue( np.allclose( y.float().numpy(), fp32_y.bfloat16().float().numpy(), atol=1e-4, rtol=1e-4, ) ) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_complex.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import numpy as np import torch as torch_original import os import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import ( Array2Numpy, FlattenArray, GenArgList, Index2Coordinate, ) from collections import OrderedDict """ TODO(lml): Support and test more apis. Finished: flow.from_numpy() flow.tensor() flow.ones() flow.zeros() flow.full() flow.add() flow.sub() flow.mul flow.sum() flow.equal() flow.not_equal() flow.cast() Tensor.new_ones() Tensor.new_zeros() Tensor.new_full() Tensor.real() Tensor.imag() Tensor.conj() Tensor.conj_physical() To complete: flow.randn() flow.div() flow.pow() Tensor.adjoint() Tensor.conj_physical_() Tensor.resolve_conj() Tensor.chalf() Tensor.cfloat(), Tensor.cdouble() More apis.. 
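A minimal usage sketch of the finished APIs above (illustrative values only, not
part of the tests below):

    x = flow.tensor([1.0 + 1j, 2.0 - 3j], dtype=flow.cfloat)
    y = flow.full((2,), 0.5 + 0.5j, dtype=flow.cfloat)
    z = flow.mul(flow.add(x, y), y)             # elementwise complex arithmetic
    re, im, cj = z.real(), z.imag(), z.conj()   # real/imag parts and conjugate
    z128 = flow.cast(z, dtype=flow.complex128)  # complex64 -> complex128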
""" def compare_result(a, b, rtol=1e-5, atol=1e-8): assert np.allclose( a, b, rtol=rtol, atol=atol ), f"\na\n{a}\n{'-' * 80}\nb:\n{b}\n{'*' * 80}\ndiff:\n{a - b}" def _np_zero_pad2d_grad(src, dest, padding): (c_idx, h_idx, w_idx) = (1, 2, 3) pad_left = padding[0] pad_right = padding[1] pad_top = padding[2] pad_bottom = padding[3] (dx_height, dx_width) = (dest.shape[h_idx], dest.shape[w_idx]) (dy_height, dy_width) = (src.shape[h_idx], src.shape[w_idx]) numpy_src = np.ones(src.shape, np.int32) numpy_dest = np.zeros(dest.shape, np.int32) array_src = FlattenArray(numpy_src) array_dest = FlattenArray(numpy_dest) src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx] dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx] elements_num = src.shape[0] * src_num for iter_n in range(elements_num): coords = Index2Coordinate(iter_n, src.shape) (n, c, i, j) = (coords[0], coords[c_idx], coords[h_idx], coords[w_idx]) ip_x = ip_y = 0 if ( j >= pad_left and j < dx_width + pad_left and (i >= pad_top) and (i < dx_height + pad_top) ): ip_x = j - pad_left ip_y = i - pad_top src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j dest_index = ( n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x ) array_dest[dest_index] += array_src[src_index] numpy_dest = Array2Numpy(array_dest, dest.shape) return numpy_dest def _test_ZeroPad2d(test_case, shape, padding, value, device, rtol, atol): np_input = np.random.random(shape) of_input = flow.tensor( np_input, dtype=test_case.dtype, device=flow.device(device), requires_grad=True ) if isinstance(padding, int): np_boundary = ((0, 0), (0, 0), (padding, padding), (padding, padding)) elif isinstance(padding, (tuple, int)) and len(padding) == 4: np_boundary = ( (0, 0), (0, 0), (padding[2], padding[3]), (padding[0], padding[1]), ) else: raise ValueError("padding must be in or tuple!") layer = flow.nn.ZeroPad2d(padding=padding) of_out = layer(of_input) np_out = np.pad(np_input, np_boundary, mode="constant", constant_values=value) test_case.assertTrue(np.allclose(of_out.cpu().detach().numpy(), np_out, rtol, atol)) of_out = of_out.sum() of_out.backward() np_out_grad = _np_zero_pad2d_grad(np_out, np_input, layer.padding) test_case.assertTrue( np.allclose(of_input.grad.cpu().detach().numpy(), np_out_grad, rtol, atol) ) class TestTensorComplex64(unittest.TestCase): def setUp(self): self.dtype = flow.cfloat self.complex_dtype = flow.complex64 self.np_dtype = np.complex64 self.type_str = "ComplexFloatTensor" self.real_dtype = flow.float self.np_real_dtype = np.float32 self.rtol = 1e-5 self.atol = 1e-5 self.a = [1.0 + 1j, 2.0] self.np_a = np.array(self.a, dtype=self.np_dtype) self.b = [[1.0 + 1j, 2.0], [1.0, 2.0 - 1j], [-1.0, 1j]] self.np_b = np.array(self.b, dtype=self.np_dtype) self.lower_n_dims = 2 self.upper_n_dims = 5 self.shape = [] for _ in range(10): num_dims = np.random.randint(self.lower_n_dims, self.upper_n_dims) shape_ = [np.random.randint(1, 11) * 4 for _ in range(num_dims)] self.shape.append(shape_) def test_from_numpy(self): a = flow.from_numpy(self.np_a) self.assertEqual(a.dtype, self.dtype) self.assertEqual(a.type(), "oneflow." + self.type_str) np_a = a.numpy() self.assertEqual(np_a.dtype, self.np_dtype) assert np.allclose(np_a, self.np_a) b = flow.from_numpy(self.np_b) self.assertEqual(b.dtype, self.dtype) self.assertEqual(b.type(), "oneflow." 
+ self.type_str) np_b = b.numpy() self.assertEqual(np_b.dtype, self.np_dtype) assert np.allclose(np_b, self.np_b) def test_tensor(self): a = flow.tensor(self.a, dtype=self.dtype) self.assertEqual(a.dtype, self.dtype) self.assertEqual(a.type(), "oneflow." + self.type_str) np_a = a.numpy() self.assertEqual(np_a.dtype, self.np_dtype) assert np.allclose(np_a, self.np_a) a = flow.tensor(self.np_a, dtype=self.dtype) self.assertEqual(a.dtype, self.dtype) self.assertEqual(a.type(), "oneflow." + self.type_str) np_a = a.numpy() self.assertEqual(np_a.dtype, self.np_dtype) assert np.allclose(np_a, self.np_a) @unittest.skip("skip for now, becase it failed 6 times in past week") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_tensor_cuda(self): a = flow.tensor(self.a, dtype=self.dtype, device="cuda") self.assertEqual(a.dtype, self.dtype) self.assertEqual(a.type(), "oneflow.cuda." + self.type_str) np_a = a.numpy() self.assertEqual(np_a.dtype, self.np_dtype) assert np.allclose(np_a, self.np_a) a = flow.tensor(self.np_a, dtype=self.dtype, device="cuda") self.assertEqual(a.dtype, self.dtype) self.assertEqual(a.type(), "oneflow.cuda." + self.type_str) np_a = a.numpy() self.assertEqual(np_a.dtype, self.np_dtype) assert np.allclose(np_a, self.np_a) @unittest.skip("skip for now, becase it failed 2 times in past week") def test_slice(self): a = flow.from_numpy(self.np_a) np_slice_a = a[1].numpy() self.assertEqual(np_slice_a.dtype, self.np_dtype) assert np.allclose(np_slice_a, self.np_a[1]) b = flow.from_numpy(self.np_b) np_slice_b = b[1].numpy() self.assertEqual(np_slice_b.dtype, self.np_dtype) assert np.allclose(np_slice_b, self.np_b[1]) c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype) np_slice_c = c[0:2, :].numpy() self.assertEqual(np_slice_c.dtype, self.np_dtype) assert np.allclose( np_slice_c, np.ones((2, 2), dtype=self.np_dtype) * (3.14 + 2j) ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_slice_cuda(self): a = flow.from_numpy(self.np_a).cuda() np_slice_a = a[1].cpu().numpy() self.assertEqual(np_slice_a.dtype, self.np_dtype) assert np.allclose(np_slice_a, self.np_a[1]) b = flow.from_numpy(self.np_b).cuda() np_slice_b = b[1].cpu().numpy() self.assertEqual(np_slice_b.dtype, self.np_dtype) assert np.allclose(np_slice_b, self.np_b[1]) c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).cuda() np_slice_c = c[0:2, :].cpu().numpy() self.assertEqual(np_slice_c.dtype, self.np_dtype) assert np.allclose( np_slice_c, np.ones((2, 2), dtype=self.np_dtype) * (3.14 + 2j) ) def test_new_tensor(self): a = flow.tensor(self.a, dtype=self.dtype) b = a.new_tensor(self.b) self.assertEqual(b.dtype, self.dtype) self.assertEqual(b.type(), "oneflow." + self.type_str) np_b = b.numpy() self.assertEqual(np_b.dtype, self.np_dtype) assert np.allclose(np_b, self.np_b) def test_new_empty(self): a = flow.tensor(self.a, dtype=self.dtype) c = a.new_empty((3, 2)) self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) def test_ones(self): c = flow.ones((3, 2), dtype=self.dtype) self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype)) def test_new_ones(self): b = flow.tensor(self.b, dtype=self.dtype) c = b.new_ones((3, 2)) self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." 
+ self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype)) def test_zeros(self): c = flow.zeros((3, 2), dtype=self.dtype) self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.zeros((3, 2), dtype=self.np_dtype)) def test_new_zeros(self): b = flow.tensor(self.b, dtype=self.dtype) c = b.new_zeros((3, 2)) self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.zeros((3, 2), dtype=self.np_dtype)) def test_full(self): c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype) self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 + 2j)) def test_new_full(self): a = flow.tensor(self.a, dtype=self.dtype) c = a.new_full((3, 2), 3.14 + 2j) self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 + 2j)) def test_real(self): c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).real() self.assertEqual(c.dtype, self.real_dtype) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_real_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_real_dtype) * 3.14) def test_imag(self): c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).imag() self.assertEqual(c.dtype, self.real_dtype) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_real_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_real_dtype) * 2) def test_conj(self): c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).conj() self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 - 2j)) def test_conj_physical(self): c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).conj_physical() self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 - 2j)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_real_cuda(self): c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype, device="cuda").real() self.assertEqual(c.dtype, self.real_dtype) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_real_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_real_dtype) * 3.14) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_imag_cuda(self): c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype, device="cuda").imag() self.assertEqual(c.dtype, self.real_dtype) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_real_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_real_dtype) * 2) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_conj_cuda(self): c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype, device="cuda").conj() self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow.cuda." 
+ self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 - 2j)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_conj_physical_cuda(self): c = flow.full( (3, 2), 3.14 + 2j, dtype=self.dtype, device="cuda" ).conj_physical() self.assertEqual(c.dtype, self.dtype) self.assertEqual(c.type(), "oneflow.cuda." + self.type_str) np_c = c.numpy() self.assertEqual(np_c.dtype, self.np_dtype) assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 - 2j)) def test_add_cpu(self): device = "cpu" for i, input_shape in enumerate(self.shape): np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_y = np_y.astype(self.np_dtype) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True) flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True) self.assertEqual(flow_x.dtype, self.dtype) self.assertEqual(flow_y.dtype, self.dtype) # forward flow_ret = flow.add(flow_x, flow_y) np_ret = np_x + np_y compare_result(flow_ret, np_ret, self.rtol, self.atol) # backward flow_ret.sum().backward() compare_result( flow_x.grad.numpy(), np.ones(input_shape), self.rtol, self.atol ) compare_result( flow_y.grad.numpy(), np.ones(input_shape), self.rtol, self.atol ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_add_cuda(self): device = "cuda" for i, input_shape in enumerate(self.shape): np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_y = np_y.astype(self.np_dtype) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True) flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True) self.assertEqual(flow_x.dtype, self.dtype) self.assertEqual(flow_y.dtype, self.dtype) # forward flow_ret = flow.add(flow_x, flow_y) np_ret = np_x + np_y compare_result(flow_ret.cpu().detach(), np_ret, self.rtol, self.atol) # backward flow_ret.sum().backward() compare_result( flow_x.grad.cpu().detach().numpy(), np.ones(input_shape), self.rtol, self.atol, ) compare_result( flow_y.grad.cpu().detach().numpy(), np.ones(input_shape), self.rtol, self.atol, ) def test_sub_cpu(self): device = "cpu" for i, input_shape in enumerate(self.shape): np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_y = np_y.astype(self.np_dtype) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True) flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True) self.assertEqual(flow_x.dtype, self.dtype) self.assertEqual(flow_y.dtype, self.dtype) # forward flow_ret = flow.sub(flow_x, flow_y) np_ret = np_x - np_y compare_result(flow_ret, np_ret, self.rtol, self.atol) # backward flow_ret.sum().backward() compare_result( flow_x.grad.numpy(), np.ones(input_shape), self.rtol, self.atol ) compare_result( flow_y.grad.numpy(), -np.ones(input_shape), self.rtol, self.atol ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_sub_cuda(self): device = "cuda" for i, input_shape in enumerate(self.shape): np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) np_y = np.random.randn(*input_shape) + 1.0j * 
np.random.randn(*input_shape) np_y = np_y.astype(self.np_dtype) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True) flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True) self.assertEqual(flow_x.dtype, self.dtype) self.assertEqual(flow_y.dtype, self.dtype) # forward flow_ret = flow.sub(flow_x, flow_y) np_ret = np_x - np_y compare_result(flow_ret.cpu().detach(), np_ret, self.rtol, self.atol) # backward flow_ret.sum().backward() compare_result( flow_x.grad.cpu().detach().numpy(), np.ones(input_shape), self.rtol, self.atol, ) compare_result( flow_y.grad.cpu().detach().numpy(), -np.ones(input_shape), self.rtol, self.atol, ) def test_mul_cpu(self): device = "cpu" for i, input_shape in enumerate(self.shape): np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_y = np_y.astype(self.np_dtype) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True) flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True) self.assertEqual(flow_x.dtype, self.dtype) self.assertEqual(flow_y.dtype, self.dtype) # forward flow_ret = flow.mul(flow_x, flow_y) np_ret = np_x * np_y compare_result(flow_ret, np_ret, self.rtol, self.atol) # backward flow_ret.sum().backward() compare_result( flow_x.grad.numpy(), flow_y.numpy().conjugate(), self.rtol, self.atol ) compare_result( flow_y.grad.numpy(), flow_x.numpy().conjugate(), self.rtol, self.atol ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_mul_cuda(self): device = "cuda" for i, input_shape in enumerate(self.shape): np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_y = np_y.astype(self.np_dtype) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True) flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True) self.assertEqual(flow_x.dtype, self.dtype) self.assertEqual(flow_y.dtype, self.dtype) # forward flow_ret = flow.mul(flow_x, flow_y) np_ret = np_x * np_y compare_result(flow_ret.cpu().detach(), np_ret, self.rtol, self.atol) # backward flow_ret.sum().backward() compare_result( flow_x.grad.cpu().detach().numpy(), flow_y.numpy().conjugate(), self.rtol, self.atol, ) compare_result( flow_y.grad.cpu().detach().numpy(), flow_x.numpy().conjugate(), self.rtol, self.atol, ) def test_sum_cpu(self): device = "cpu" for i, input_shape in enumerate(self.shape): n_dims = np.random.randint(1, len(input_shape)) dims = np.random.choice( len(input_shape) - 1, n_dims, replace=False ).tolist() keepdim = True if np.random.randint(2) == 1 else False np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True) self.assertEqual(flow_x.dtype, self.dtype) # forward flow_ret = flow.sum(flow_x, dim=dims, keepdim=keepdim) np_ret = np.sum(np_x, axis=tuple(dims), keepdims=keepdim) compare_result(flow_ret, np_ret, self.rtol, self.atol * 1000) # backward flow_ret.sum().backward() compare_result( flow_x.grad.numpy(), np.ones(input_shape), self.rtol, self.atol ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_sum_cuda(self): device = "cuda" for i, input_shape in enumerate(self.shape): n_dims = np.random.randint(1, len(input_shape)) dims = np.random.choice( len(input_shape) - 1, n_dims, replace=False ).tolist() 
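            # dims is a random subset of axes drawn from range(len(input_shape) - 1),
            # so the last axis is never among the reduced dimensions.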
keepdim = True if np.random.randint(2) == 1 else False np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True) self.assertEqual(flow_x.dtype, self.dtype) # forward flow_ret = flow.sum(flow_x, dim=dims, keepdim=keepdim) np_ret = np.sum(np_x, axis=tuple(dims), keepdims=keepdim) compare_result(flow_ret.cpu().detach(), np_ret, self.rtol, self.atol * 1000) # backward flow_ret.sum().backward() compare_result( flow_x.grad.cpu().detach().numpy(), np.ones(input_shape), self.rtol, self.atol, ) def test_equal_cpu(self): device = "cpu" for i, input_shape in enumerate(self.shape): np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_y = np_y.astype(self.np_dtype) np_z = np.copy(np_x) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(False) flow_y = flow.from_numpy(np_y).to(device).requires_grad_(False) flow_z = flow.from_numpy(np_z).to(device).requires_grad_(False) self.assertEqual(flow_x.dtype, self.dtype) self.assertEqual(flow_y.dtype, self.dtype) self.assertEqual(flow_z.dtype, self.dtype) # forward flow_ret = flow.equal(flow_x, flow_y) np_ret = np.equal(np_x, np_y) compare_result(flow_ret, np_ret, self.rtol, self.atol) flow_ret = flow.equal(flow_x, flow_z) compare_result( flow_ret, np.ones(flow_x.shape).astype(bool), self.rtol, self.atol ) flow_ret = flow.not_equal(flow_x, flow_z) compare_result( flow_ret, np.zeros(flow_x.shape).astype(bool), self.rtol, self.atol ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_equal_cuda(self): device = "cuda" for i, input_shape in enumerate(self.shape): np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_x = np_x.astype(self.np_dtype) np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape) np_y = np_y.astype(self.np_dtype) np_z = np.copy(np_x) flow_x = flow.from_numpy(np_x).to(device).requires_grad_(False) flow_y = flow.from_numpy(np_y).to(device).requires_grad_(False) flow_z = flow.from_numpy(np_z).to(device).requires_grad_(False) self.assertEqual(flow_x.dtype, self.dtype) self.assertEqual(flow_y.dtype, self.dtype) self.assertEqual(flow_z.dtype, self.dtype) # forward flow_ret = flow.equal(flow_x, flow_y) np_ret = np.equal(np_x, np_y) compare_result(flow_ret, np_ret, self.rtol, self.atol) flow_ret = flow.equal(flow_x, flow_z) compare_result( flow_ret, np.ones(flow_x.shape).astype(bool), self.rtol, self.atol ) flow_ret = flow.not_equal(flow_x, flow_z) compare_result( flow_ret.cpu().detach(), np.zeros(flow_x.shape).astype(bool), self.rtol, self.atol, ) def test_constant_pad(self): arg_dict = OrderedDict() arg_dict["shape"] = [(1, 2, 3, 4), (8, 3, 4, 4)] arg_dict["padding"] = [2, (1, 1, 2, 2)] arg_dict["value"] = [0.0] arg_dict["device"] = ( ["cpu", "cuda"] if os.getenv("ONEFLOW_TEST_CPU_ONLY") is None else ["cpu"] ) arg_dict["rtol"] = [self.rtol] arg_dict["atol"] = [self.atol] for arg in GenArgList(arg_dict): _test_ZeroPad2d(self, *arg) def test_cast(self): dtype_pairs = [ (np.uint8, "ByteTensor"), (np.int8, "CharTensor"), (np.int32, "IntTensor"), (np.int64, "LongTensor"), (np.float32, "FloatTensor"), (np.float64, "DoubleTensor"), ] shape = (3, 5, 2) for np_dtype, type_str in dtype_pairs: np_arr = np.random.randn(*shape).astype(np_dtype) flow_tensor = flow.from_numpy(np_arr) self.assertEqual(flow_tensor.type(), "oneflow." 
+ type_str) np_out = np_arr.astype(self.np_dtype) flow_out = flow.cast(flow_tensor, dtype=self.complex_dtype) self.assertTrue(np.array_equal(flow_out.numpy(), np_out)) # cp64 -> cp128 np_arr = np.random.randn(*shape) + 1.0j * np.random.randn(*shape) np_arr = np_arr.astype(np.complex64) flow_tensor = flow.from_numpy(np_arr) self.assertEqual(flow_tensor.dtype, flow.complex64) np_out = np_arr.astype(np.complex128) flow_out = flow.cast(flow_tensor, dtype=flow.complex128) self.assertTrue(np.array_equal(flow_out.numpy(), np_out)) # cp128 -> cp64 np_out = np_out.astype(np.complex64) flow_out = flow.cast(flow_out, dtype=flow.complex64) self.assertTrue(np.array_equal(flow_out.numpy(), np_out)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_cast_cuda(self): dtype_pairs = [ (np.uint8, "ByteTensor"), (np.int8, "CharTensor"), (np.int32, "IntTensor"), (np.int64, "LongTensor"), (np.float32, "FloatTensor"), (np.float64, "DoubleTensor"), ] shape = (7, 4, 11) for np_dtype, type_str in dtype_pairs: np_arr = np.random.randn(*shape).astype(np_dtype) flow_tensor = flow.from_numpy(np_arr).cuda() self.assertEqual(flow_tensor.type(), "oneflow.cuda." + type_str) np_out = np_arr.astype(self.np_dtype) flow_out = flow.cast(flow_tensor, dtype=self.complex_dtype) self.assertTrue(np.array_equal(flow_out.cpu().detach().numpy(), np_out)) # cp64 -> cp128 np_arr = np.random.randn(*shape) + 1.0j * np.random.randn(*shape) np_arr = np_arr.astype(np.complex64) flow_tensor = flow.from_numpy(np_arr).cuda() self.assertEqual(flow_tensor.dtype, flow.complex64) np_out = np_arr.astype(np.complex128) flow_out = flow.cast(flow_tensor, dtype=flow.complex128) self.assertTrue(np.array_equal(flow_out.cpu().detach().numpy(), np_out)) # cp128 -> cp64 np_out = np_out.astype(np.complex64) flow_out = flow.cast(flow_out, dtype=flow.complex64) self.assertTrue(np.array_equal(flow_out.cpu().detach().numpy(), np_out)) class TestTensorComplex128(TestTensorComplex64): def setUp(self): self.dtype = flow.cdouble self.complex_dtype = flow.complex128 self.np_dtype = np.complex128 self.type_str = "ComplexDoubleTensor" self.real_dtype = flow.double self.np_real_dtype = np.float64 self.rtol = 1e-7 self.atol = 1e-7 self.a = [1.0 + 1j, 2.0] self.np_a = np.array(self.a, dtype=self.np_dtype) self.b = [[1.0 + 1j, 2.0], [1.0, 2.0 - 1j], [-1.0, 1j]] self.np_b = np.array(self.b, dtype=self.np_dtype) self.lower_n_dims = 2 self.upper_n_dims = 5 self.shape = [] for _ in range(10): num_dims = np.random.randint(self.lower_n_dims, self.upper_n_dims) shape_ = [np.random.randint(1, 11) * 4 for _ in range(num_dims)] self.shape.append(shape_) class TestAutograd(unittest.TestCase): def test_backward(self): a = flow.tensor([1.0 + 2j, 2.0 - 3j, 1j], dtype=flow.cfloat) a.requires_grad = True b = flow.conj(a) loss = flow.sum(a.real() + b.imag()) loss.backward() assert np.allclose(a.grad.numpy(), np.ones((3,), dtype=np.complex64) * (1 - 1j)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_backward_cuda(self): a = flow.tensor([1.0 + 2j, 2.0 - 3j, 1j], dtype=flow.cfloat, device="cuda") a.requires_grad = True b = flow.conj(a) loss = flow.sum(a.real() + b.imag()) loss.backward() assert np.allclose(a.grad.numpy(), np.ones((3,), dtype=np.complex64) * (1 - 1j)) def test_grad(self): a = flow.tensor([1.0 + 2j, 2.0 - 3j, 1j], dtype=flow.cfloat) a.requires_grad = True b = flow.conj(a) c = a.real() + b.imag() np_dc = np.ones((3,), dtype=np.float32) dc = flow.tensor(np_dc) (da,) = flow.autograd.grad(c, a, dc) assert 
np.allclose(da.numpy(), np.ones((3,), dtype=np.complex64) * (1 - 1j))

    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_grad_cuda(self):
        a = flow.tensor([1.0 + 2j, 2.0 - 3j, 1j], dtype=flow.cfloat, device="cuda")
        a.requires_grad = True
        b = flow.conj(a)
        c = a.real() + b.imag()
        np_dc = np.ones((3,), dtype=np.float32)
        dc = flow.tensor(np_dc)
        (da,) = flow.autograd.grad(c, a, dc)
        assert np.allclose(da.numpy(), np.ones((3,), dtype=np.complex64) * (1 - 1j))


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/tensor/test_data_ptr.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import oneflow as flow
import oneflow.unittest


class TestDataPtr(unittest.TestCase):
    @flow.unittest.skip_unless_1n1d()
    def test_equality(test_case):
        x = flow.ones(2, 3)
        y = flow.ones(2, 3)
        test_case.assertNotEqual(x.data_ptr(), y.data_ptr())
        test_case.assertEqual(x.data_ptr(), x.data.data_ptr())
        x_ptr = x.data_ptr()
        x[:] = 2
        test_case.assertEqual(x_ptr, x.data_ptr())

    @flow.unittest.skip_unless_1n2d()
    def test_global_tensor(test_case):
        x = flow.randn(
            2, 3, placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast
        )
        test_case.assertEqual(x.data_ptr(), x.to_local().data_ptr())


if __name__ == "__main__":
    unittest.main()


================================================
FILE: python/oneflow/test/tensor/test_global_tensor.py
================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
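# A minimal sketch of the local <-> global round trip that the global-tensor
# tests below rely on, assuming a single-process run with a one-device CPU
# placement available; it is a hedged illustration, not part of the test suite.
import oneflow as flow

placement = flow.placement("cpu", [0])  # ranks that hold the data
sbp = flow.sbp.broadcast  # how the data is laid out across those ranks

local = flow.ones(2, 3)  # ordinary local tensor
glob = local.to_global(placement=placement, sbp=sbp)  # lift to a global tensor
assert glob.is_global

back = glob.to_local()  # recover this rank's local view
assert back.is_local and back.shape == (2, 3)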
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestTensor(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_creating_global_tensor(test_case): placement = flow.placement("cuda", [0]) sbp = flow.sbp.broadcast # Shape -> GlobalTensor shape = (2, 3) x = flow.Tensor(*shape, placement=placement, sbp=sbp) test_case.assertTrue(x.is_global) test_case.assertTrue(x.size() == shape) shape = flow.Size((2, 3)) x = flow.Tensor(shape, placement=placement, sbp=sbp) test_case.assertTrue(x.is_global) test_case.assertTrue(x.size() == shape) # LocalTensor -> GlobalTensor x = flow.Tensor(*shape, device="cpu") test_case.assertTrue(x.is_local) y = flow.Tensor(x, placement=placement, sbp=sbp) test_case.assertTrue(y.is_global) # GlobalTensor -> GlobalTensor z = flow.Tensor(y, placement=placement, sbp=sbp) test_case.assertTrue(z.is_global) # TODO: ndarray -> GlobalTensor @flow.unittest.skip_unless_1n1d() def test_construct_local_from_global_tensor(test_case): placement = flow.placement("cuda", [0]) sbp = flow.sbp.broadcast shape = (2, 3) x = flow.Tensor(*shape, placement=placement, sbp=sbp) test_case.assertTrue(x.is_global) # GlobalTensor -> LocalTensor y = flow.Tensor(x, device="cpu") test_case.assertTrue(y.is_local) y = flow.Tensor(x, device="cuda") test_case.assertTrue(y.is_local) @flow.unittest.skip_unless_1n1d() def test_global_set_data(test_case): x_placement = flow.placement("cpu", [0]) x_sbp = flow.sbp.broadcast x = flow.ones(2, 3, placement=x_placement, sbp=x_sbp) y_placement = flow.placement("cuda", [0]) y_sbp = flow.sbp.split(0) y = flow.ones(4, 5, placement=y_placement, sbp=y_sbp) old_id = id(x) x.data = y test_case.assertEqual(old_id, id(x)) test_case.assertTrue(x.shape == (4, 5)) test_case.assertTrue(x.placement == y_placement) test_case.assertTrue(x.sbp[0] == y_sbp) @flow.unittest.skip_unless_1n1d() def test_global_tensor_autograd_related_methods(test_case): placement = flow.placement("cuda", [0]) sbp = flow.sbp.split(0) shape = (2, 3, 4, 5) l_x = flow.Tensor(*shape) test_case.assertFalse(l_x.requires_grad) test_case.assertTrue(l_x.is_leaf) l_y = flow.Tensor(*shape) l_y.requires_grad = True test_case.assertTrue(l_y.requires_grad) test_case.assertTrue(l_y.is_leaf) x = l_x.to_global(placement=placement, sbp=sbp) test_case.assertTrue(x.is_leaf) y = l_y.to_global(placement=placement, sbp=sbp) test_case.assertFalse(y.is_leaf) z = x + y test_case.assertTrue(z.requires_grad) test_case.assertFalse(z.is_leaf) with flow.no_grad(): m = x + y test_case.assertTrue(m.is_leaf) test_case.assertFalse(m.requires_grad) l_v = flow.Tensor(*shape) l_v.requires_grad = True v = l_v.to_global(placement=placement, sbp=sbp) z.retain_grad() w = v + z l_grad = flow.ones(*shape) grad = l_grad.to_global(placement=placement, sbp=sbp) w.backward(gradient=grad) test_case.assertTrue( np.allclose(l_v.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(l_y.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose( z.grad.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), np.ones(shape), atol=1e-4, rtol=1e-4, ) ) test_case.assertIsNone(l_x.grad) @flow.unittest.skip_unless_1n1d() def test_global_tensor_unsupported_property(test_case): shape = (2, 3) placement = flow.placement("cuda", [0]) sbp = flow.sbp.split(0) a = flow.Tensor(*shape) b = 
a.to_global(placement=placement, sbp=sbp) test_case.assertTrue(b.is_global) with test_case.assertRaises(RuntimeError): b.device() with test_case.assertRaises(RuntimeError): b._tensor_buffer_shapes_and_dtypes @flow.unittest.skip_unless_1n4d() def test_global_tensor_2d_sbp_init(test_case): V = 10 H = 4 S = 6 P = flow.placement("cuda", [[0, 1], [2, 3]]) wte = flow.nn.Parameter( flow.empty( (V, H), dtype=flow.float32, placement=P, sbp=[flow.sbp.broadcast, flow.sbp.split(0)], ) ) wpe = flow.nn.Parameter( flow.empty( (S, H), dtype=flow.float32, placement=P, sbp=[flow.sbp.broadcast, flow.sbp.broadcast], ) ) flow.nn.init.normal_(wte, std=0.02) flow.nn.init.normal_(wpe, std=0.02) @flow.unittest.skip_unless_1n2d() def test_copy(test_case): x = flow.zeros(2, 3) y = flow.ones(2, 3) x.copy_(y) test_case.assertTrue(np.array_equal(x.numpy(), y.numpy())) x = flow.zeros( 4, 6, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.broadcast ) y = flow.ones( 4, 6, placement=flow.placement("cpu", [0]), sbp=flow.sbp.broadcast ) x.copy_(y) test_case.assertTrue(np.array_equal(x.numpy(), y.numpy())) x = flow.zeros( 4, 6, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.broadcast ) y = flow.ones( 4, 6, placement=flow.placement("cuda", [0]), sbp=flow.sbp.broadcast ) x.copy_(y) test_case.assertTrue(np.array_equal(x.numpy(), y.numpy())) x = flow.zeros( 4, 6, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.split(0) ) y = flow.ones( 4, 6, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.broadcast ) x.copy_(y) test_case.assertTrue(np.array_equal(x.numpy(), y.numpy())) x = flow.zeros( 4, 6, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.broadcast ) y = flow.ones( 4, 6, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.broadcast ) x.copy_(y) test_case.assertTrue(np.array_equal(x.numpy(), y.numpy())) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_global_tensor_and_ndarray_compatibility.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
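# A minimal sketch of the tensor/ndarray arithmetic that the compatibility test
# below checks against torch: a local oneflow tensor combined with a plain numpy
# array through ordinary Python operators. Assumes a CPU-only single-process run;
# the concrete values here are illustrative, not taken from the test itself.
import numpy as np
import oneflow as flow

t = flow.tensor([[1, 2], [3, 4]], dtype=flow.int64)
arr = np.array([[10, 20], [30, 40]])

out = t + arr  # the numpy operand participates directly in tensor arithmetic
assert np.allclose(out.numpy(), np.array([[11, 22], [33, 44]]))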
""" import unittest from collections import OrderedDict import torch import oneflow as flow import unittest import oneflow.unittest from oneflow.test_utils.automated_test_util import * import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * import numpy as np np.random.seed(233) test_compute_op_list = [ "+", "-", "*", "/", "**", "//", "%", ] def do_test_compute_op(test_case, ndim, placement, sbp): dims = [random(1, 4) * 8 for i in range(ndim)] x = random_tensor(ndim, *dims, dtype=int, low=0, high=5) x = x.to_global(placement=placement, sbp=sbp) x = x.to("cpu") flow_input = x.oneflow.detach() torch_input = x.pytorch.detach() for op in test_compute_op_list: if op not in ["**"]: random_numpy = np.random.randint(1, 30000, size=list(flow_input.shape)) else: random_numpy = np.random.randint(1, 5, size=list(flow_input.shape)) z_flow = eval(f"flow_input {op} random_numpy") z_torch = eval(f"torch_input {op} random_numpy") test_case.assertTrue(np.allclose(z_flow.numpy(), z_torch.numpy())) class TestGlobalTensorAndNdarrayCompatibility(flow.unittest.TestCase): @globaltest def test_tensor_and_ndarray_compatibility(test_case): # random ndim in range [1,4] ndim = random(1, 5).to(int).value() for placement in all_placement(): for sbp in all_sbp(placement, max_dim=ndim): do_test_compute_op(test_case, ndim, placement, sbp) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_global_tensor_indexing.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # This test code is referenced from: https://github.com/pytorch/pytorch/blob/cd41c8f032dd06c445bf97fc76fb82008b19afcb/test/test_indexing.py import unittest import numpy as np import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest def _randint(low, high): """ Get a random integer in the range [low, high). """ return random(low, high).to(int).value() def _cpu_global_tensor(tensor): return tensor.to_global(flow.placement.all("cpu"), flow.sbp.broadcast) def _assert_tensor_equal(test_case, tensor1, tensor2, atol=0.0, rtol=0.0): test_case.assertTrue( np.allclose(tensor1.numpy(), tensor2.numpy(), atol, rtol), f"{tensor1.numpy()} vs {tensor2.numpy()}", ) def global_broadcast_consec(size, start=1): """ Generate a arithmetic progression with given size and start value. 
""" sequence = flow.ones([int(np.array(size).prod(0)),]).cumsum(0) sequence.add_(start - 1) return _cpu_global_tensor(sequence.view(*size)) def _test_basic_slice(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) ref_sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8)).to_global(placement, ref_sbp) # empty tensor indexing _assert_tensor_equal( test_case, reference[ _cpu_global_tensor(flow.LongTensor()).to_global( placement, broadcast_for_placement ) ], flow.empty(0, 8, 8), atol=0, rtol=0, ) _assert_tensor_equal( test_case, reference[0], global_broadcast_consec((8, 8)), atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[1], global_broadcast_consec((8, 8), 65), atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[2], global_broadcast_consec((8, 8), 129), atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[0, 1], global_broadcast_consec((8,), 9), atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[0:2], global_broadcast_consec((2, 8, 8)), atol=0, rtol=0 ) test_case.assertEqual(reference[2, 2, 2].item(), 147) _assert_tensor_equal( test_case, reference[:], global_broadcast_consec((8, 8, 8)), atol=0, rtol=0 ) # indexing with Ellipsis _assert_tensor_equal( test_case, reference[..., 2, 2], flow.tensor([19, 83, 147, 211, 275, 339, 403, 467]), atol=0, rtol=0, ) _assert_tensor_equal( test_case, reference[0, ..., 2], flow.tensor([3, 11, 19, 27, 35, 43, 51, 59]), atol=0, rtol=0, ) _assert_tensor_equal( test_case, reference[..., 2], reference[:, :, 2], atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[0, 2, ...], reference[0, 2], atol=0, rtol=0 ) test_case.assertEqual(reference[..., 2, 2, 2].item(), 147) test_case.assertEqual(reference[2, ..., 2, 2].item(), 147) test_case.assertEqual(reference[2, 2, ..., 2].item(), 147) test_case.assertEqual(reference[2, 2, 2, ...].item(), 147) _assert_tensor_equal(test_case, reference[...], reference, atol=0, rtol=0) reference_5d = global_broadcast_consec((8, 8, 8, 8, 8)).to_global( placement, sbp=random_sbp(placement, max_dim=5).value() ) _assert_tensor_equal( test_case, reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0, ) _assert_tensor_equal( test_case, reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0, ) _assert_tensor_equal(test_case, reference_5d[...], reference_5d, atol=0, rtol=0) # LongTensor indexing sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp) idx = _cpu_global_tensor(flow.LongTensor([2, 4])).to_global( placement, broadcast_for_placement ) _assert_tensor_equal( test_case, reference[idx], flow.stack([reference[2], reference[4]]) ) # None indexing _assert_tensor_equal(test_case, reference[2, None], reference[2].unsqueeze(0)) _assert_tensor_equal( test_case, reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0) ) _assert_tensor_equal(test_case, reference[2:4, None], reference[2:4].unsqueeze(1)) _assert_tensor_equal( test_case, reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0), ) _assert_tensor_equal( test_case, reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2), ) # indexing 0-length slice _assert_tensor_equal(test_case, flow.empty(0, 
8, 8), reference[slice(0)]) _assert_tensor_equal(test_case, flow.empty(0, 8), reference[slice(0), 2]) _assert_tensor_equal(test_case, flow.empty(0, 8), reference[2, slice(0)]) _assert_tensor_equal(test_case, flow.tensor([]), reference[2, 1:1, 2]) # indexing with step sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp) _assert_tensor_equal( test_case, reference[1:5:2], flow.stack([reference[1], reference[3]], 0) ) _assert_tensor_equal( test_case, reference[1:6:2], flow.stack([reference[1], reference[3], reference[5]], 0), ) _assert_tensor_equal( test_case, reference[1:9:4], flow.stack([reference[1], reference[5]], 0) ) _assert_tensor_equal( test_case, reference[2:4, 1:5:2], flow.stack([reference[2:4, 1], reference[2:4, 3]], 1), ) _assert_tensor_equal( test_case, reference[3, 1:6:2], flow.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0), ) _assert_tensor_equal( test_case, reference[None, 2, 1:9:4], flow.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0), ) _assert_tensor_equal( test_case, reference[:, 2, 1:6:2], flow.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1), ) # random check lst = [ list(range(i, i + 16)) for i in range(0, 256, 16) ] # arange(64).reshape(8, 8) tensor = _cpu_global_tensor(flow.DoubleTensor(lst)) for _ in range(5): sbp = random_sbp(placement, max_dim=2).value() cur_tensor = tensor.to_global(placement, sbp) idx1_start = _randint(0, 16) idx1_end = idx1_start + _randint(1, 16 - idx1_start + 1) idx1_step = _randint(1, 14) idx1 = slice(idx1_start, idx1_end, idx1_step) if _randint(0, 2) == 0: idx2_start = _randint(0, 16) idx2_end = idx2_start + _randint(1, 16 - idx2_start + 1) idx2_step = _randint(1, 14) idx2 = slice(idx2_start, idx2_end, idx2_step) lst_indexed = [l[idx2] for l in lst[idx1]] tensor_indexed = cur_tensor[idx1, idx2] else: lst_indexed = lst[idx1] tensor_indexed = cur_tensor[idx1] _assert_tensor_equal(test_case, flow.DoubleTensor(lst_indexed), tensor_indexed) # error check sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp) test_case.assertRaises(RuntimeError, lambda: reference[1:9:0]) test_case.assertRaises(RuntimeError, lambda: reference[1:9:-1]) test_case.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) test_case.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) test_case.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) test_case.assertRaises(IndexError, lambda: reference[0.0]) test_case.assertRaises(RuntimeError, lambda: reference[0.0:2.0]) test_case.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) test_case.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) test_case.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) test_case.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) def _test_advanced_indexing(test_case, placement, dtype): broadcast_for_placement = [flow.sbp.broadcast] * len(placement.ranks.shape) # pick a random valid indexer type def ri(indices): choice = _randint(0, 3) if choice == 0: return _cpu_global_tensor(flow.LongTensor(indices)).to_global( placement, broadcast_for_placement ) elif choice == 1: return list(indices) else: return tuple(indices) def validate_indexing(x): _assert_tensor_equal(test_case, x[[0]], global_broadcast_consec((1,))) _assert_tensor_equal(test_case, x[ri([0]),], global_broadcast_consec((1,))) _assert_tensor_equal(test_case, x[ri([3]),], global_broadcast_consec((1,), 4)) 
_assert_tensor_equal(test_case, x[[2, 3, 4]], global_broadcast_consec((3,), 3)) _assert_tensor_equal( test_case, x[ri([2, 3, 4]),], global_broadcast_consec((3,), 3) ) _assert_tensor_equal( test_case, x[ri([0, 2, 4]),], flow.tensor([1, 3, 5], dtype=dtype), ) def validate_setting(x): x[[0]] = -2 _assert_tensor_equal(test_case, x[[0]], flow.tensor([-2], dtype=dtype)) x[[0]] = -1 _assert_tensor_equal(test_case, x[ri([0]),], flow.tensor([-1], dtype=dtype)) x[[2, 3, 4]] = 4 _assert_tensor_equal( test_case, x[[2, 3, 4]], flow.tensor([4, 4, 4], dtype=dtype) ) x[ri([2, 3, 4]),] = 3 _assert_tensor_equal( test_case, x[ri([2, 3, 4]),], flow.tensor([3, 3, 3], dtype=dtype), ) x[ri([0, 2, 4]),] = _cpu_global_tensor(flow.tensor([5, 4, 3], dtype=dtype)) _assert_tensor_equal( test_case, x[ri([0, 2, 4]),], flow.tensor([5, 4, 3], dtype=dtype), ) # 1d tensor and integer index setitem and getitem sbp = random_sbp(placement, max_dim=1).value() reference = global_broadcast_consec((8,)).to_global(placement, sbp) validate_indexing(reference) validate_setting(reference) # reference is 1 2 3 4 5 6 7 8 # 9 10 11 12 13 14 15 16 # 17 18 19 20 21 22 23 24 # 25 26 27 28 29 30 31 32 # 33 34 35 36 37 38 39 40 # 41 42 43 44 45 46 47 48 # 49 50 51 52 53 54 55 56 # 57 58 59 60 61 62 63 64 sbp = random_sbp(placement, max_dim=2).value() reference = global_broadcast_consec((8, 8)).to_global(placement, sbp) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([0])], flow.tensor([1, 9, 17], dtype=dtype), ) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([1])], flow.tensor([2, 10, 18], dtype=dtype), ) _assert_tensor_equal( test_case, reference[ri([0]), ri([0])], global_broadcast_consec((1,)) ) _assert_tensor_equal( test_case, reference[ri([2]), ri([1])], global_broadcast_consec((1,), 18) ) _assert_tensor_equal( test_case, reference[[ri([0, 0]), ri([0, 1])]], flow.tensor([1, 2], dtype=dtype), ) _assert_tensor_equal( test_case, reference[[ri([0, 1, 1, 0, 2, 7]), ri([1])]], flow.tensor([2, 10, 10, 2, 18, 58], dtype=dtype), ) _assert_tensor_equal( test_case, reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], flow.tensor([1, 2, 9, 9], dtype=dtype), ) rows = ri([[0, 0], [1, 6]]) columns = ([0],) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[1, 1], [9, 49]], dtype=dtype), ) rows = ri([[0, 0], [1, 6]]) columns = ri([6, 0]) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[7, 1], [15, 49]], dtype=dtype), ) rows = ri([[0, 0], [1, 2]]) columns = ri([[0, 1], [3, 7]]) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[1, 2], [12, 24]], dtype=dtype), ) # setting values reference[ri([0]), ri([1])] = -1 _assert_tensor_equal( test_case, reference[ri([0]), ri([1])], flow.tensor([-1], dtype=dtype), ) reference[ri([0, 1, 2]), ri([0])] = _cpu_global_tensor( flow.tensor([-1, 2, -4], dtype=dtype) ).to_global(placement, broadcast_for_placement) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([0])], flow.tensor([-1, 2, -4], dtype=dtype), ) reference[rows, columns] = _cpu_global_tensor( flow.tensor([[4, 6], [2, 3]], dtype=dtype) ).to_global(placement, broadcast_for_placement) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[4, 6], [2, 3]], dtype=dtype), ) # Tests using less than the number of dims, and ellipsis # reference is 1 2 3 4 5 6 7 8 # 9 10 11 12 13 14 15 16 # 17 18 19 20 21 22 23 24 # 25 26 27 28 29 30 31 32 # 33 34 35 36 37 38 39 40 # 41 42 43 44 45 46 47 48 # 49 50 51 52 53 54 55 56 # 57 58 59 60 61 62 63 64 sbp = 
random_sbp(placement, max_dim=2).value() reference = global_broadcast_consec((8, 8)).to_global(placement, sbp) _assert_tensor_equal( test_case, reference[ri([0, 2]),], flow.tensor( [[1, 2, 3, 4, 5, 6, 7, 8], [17, 18, 19, 20, 21, 22, 23, 24]], dtype=dtype ), ) _assert_tensor_equal( test_case, reference[ri([1]), ...], flow.tensor([[9, 10, 11, 12, 13, 14, 15, 16]], dtype=dtype), ) _assert_tensor_equal( test_case, reference[..., ri([1])], flow.tensor([[2], [10], [18], [26], [34], [42], [50], [58]], dtype=dtype), ) # verify too many indices fails with test_case.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])] # test invalid index fails sbp = random_sbp(placement, max_dim=1).value() reference = _cpu_global_tensor(flow.empty(8, dtype=dtype)).to_global(placement, sbp) for err_idx in (10, -11): with test_case.assertRaisesRegex(IndexError, r"out of bounds"): reference[err_idx] def _test_combined_indexing(test_case, placement, dtype): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) def tensor_indices_to_np(tensor, indices): # convert the flow Tensor to a numpy array npt = tensor.numpy() # convert indices idxs = tuple( i.tolist() if isinstance(i, flow.LongTensor) else i for i in indices ) return npt, idxs def get_numpy(tensor, indices): npt, idxs = tensor_indices_to_np(tensor, indices) # index and return as a oneflow local Tensor return flow.tensor(npt[idxs], dtype=dtype) def set_numpy(tensor, indices, value): if not isinstance(value, int): value = value.numpy() npt, idxs = tensor_indices_to_np(tensor, indices) npt[idxs] = value return npt def assert_get_eq(tensor, indexer): _assert_tensor_equal(test_case, tensor[indexer], get_numpy(tensor, indexer)) def assert_set_eq(tensor, indexer, val): pyt = tensor.clone() np_ref = tensor.clone() pyt[indexer] = val np_ref = flow.tensor(set_numpy(np_ref, indexer, val), dtype=dtype) _assert_tensor_equal(test_case, pyt, np_ref) def assert_backward_eq(tensor, indexer): # compare gradient between cpu and cuda cpu = ( tensor.float() .clone() .detach() .to_global(placement, broadcast_for_placement) .requires_grad_() ) outcpu = cpu.clone()[indexer] outcpu.sum().backward() dev = ( cpu.detach() .to_global( placement, random_sbp(placement, max_dim=len(tensor.shape)).value() ) .requires_grad_(True) ) outdev = dev[indexer] outdev.sum().backward() _assert_tensor_equal(test_case, cpu.grad, dev.grad) def get_set_tensor(indexed, indexer): set_size = indexed[indexer].size() set_count = indexed[indexer].numel() set_tensor = _cpu_global_tensor( flow.arange(set_count, 0, -1).view(set_size).to(dtype) ).to_global(placement, broadcast_for_placement) return set_tensor # Tensor is 1 2 3 4 5 6 7 8 # 9 10 11 12 13 14 15 16 # 17 18 19 20 21 22 23 24 # 25 26 27 28 29 30 31 32 # 33 34 35 36 37 38 39 40 # 41 42 43 44 45 46 47 48 # 49 50 51 52 53 54 55 56 # 57 58 59 60 61 62 63 64 sbp = random_sbp(placement, max_dim=2).value() reference = global_broadcast_consec((8, 8)).to_global(placement, sbp) indices_to_test = [ # grab the second, fourth columns [slice(None), [4, 6]], # first, third rows, [[0, 6], slice(None)], # TODO(wyg): only support getitem but not setitem # # weird shape # [slice(None), [[0, 1], # [2, 3]]], # negatives [[-1], [0]], [[0, 7], [-1]], [slice(None), [-1]], ] # test getitem get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] get_indices_to_test = indices_to_test + [ [slice(None), [[0, 1], [2, 3]]] ] # TODO: test setitem for indexer in get_indices_to_test: assert_get_eq(reference, indexer) if placement.type != 
"cpu": assert_backward_eq(reference, indexer) # test setitem for indexer in indices_to_test: assert_set_eq(reference, indexer, 44) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) ######################### # test more dims tensor # ######################### sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8), 0).float().to_global(placement, sbp) indices_to_test = [ [slice(None), slice(None), [0, 3, 4]], [slice(None), [2, 4, 5, 7], slice(None)], [[2, 3], slice(None), slice(None)], [slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), [0], [1, 2, 4]], [slice(None), [0, 1, 3], [4]], [slice(None), [[0, 1], [1, 0]], [[2, 3]]], [slice(None), [[0, 1], [2, 3]], [[0]]], [slice(None), [[5, 6]], [[0, 3], [4, 4]]], [[0, 2, 3], [1, 3, 4], slice(None)], [[0], [1, 2, 4], slice(None)], [[0, 1, 3], [4], slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], [[[0, 1], [1, 0]], [[2, 3]], slice(None)], [[[0, 1], [2, 3]], [[0]], slice(None)], [[[2, 1]], [[0, 3], [4, 4]], slice(None)], [[[2]], [[0, 3], [4, 1]], slice(None)], # non-contiguous indexing subspace [[0, 2, 3], slice(None), [1, 3, 4]], # less dim, ellipsis [[0, 2],], [[0, 2], slice(None)], [[0, 2], Ellipsis], [[0, 2], slice(None), Ellipsis], [[0, 2], Ellipsis, slice(None)], [[0, 2], [1, 3]], [[0, 2], [1, 3], Ellipsis], [Ellipsis, [1, 3], [2, 3]], [Ellipsis, [2, 3, 4]], [Ellipsis, slice(None), [2, 3, 4]], [slice(None), Ellipsis, [2, 3, 4]], # ellipsis counts for nothing [Ellipsis, slice(None), slice(None), [0, 3, 4]], [slice(None), Ellipsis, slice(None), [0, 3, 4]], [slice(None), slice(None), Ellipsis, [0, 3, 4]], [slice(None), slice(None), [0, 3, 4], Ellipsis], [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 212) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) if placement.type != "cpu": assert_backward_eq(reference, indexer) sbp = random_sbp(placement, max_dim=4).value() reference = ( global_broadcast_consec((8, 8, 8, 8), 0).float().to_global(placement, sbp) ) indices_to_test = [ [slice(None), slice(None), slice(None), [0, 3, 4]], [slice(None), slice(None), [2, 4, 5, 7], slice(None)], [slice(None), [2, 3], slice(None), slice(None)], [[1, 2], slice(None), slice(None), slice(None)], [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), slice(None), [0], [1, 2, 4]], [slice(None), slice(None), [0, 1, 3], [4]], [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], [slice(None), [0], [1, 2, 4], slice(None)], [slice(None), [0, 1, 3], [4], slice(None)], [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], [[0], [1, 2, 4], slice(None), slice(None)], [[0, 1, 2], [4], slice(None), slice(None)], [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], [[[2]], [[0, 3], [4, 5]], 
slice(None), slice(None)], [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], [slice(None), [2, 3, 4], [1, 3, 4], [4]], [slice(None), [0, 1, 3], [4], [1, 3, 4]], [slice(None), [6], [0, 2, 3], [1, 3, 4]], [slice(None), [2, 3, 5], [3], [4]], [slice(None), [0], [4], [1, 3, 4]], [slice(None), [6], [0, 2, 3], [1]], [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], [[2, 0, 1], [1, 2, 3], [4], slice(None)], [[0, 1, 2], [4], [1, 3, 4], slice(None)], [[0], [0, 2, 3], [1, 3, 4], slice(None)], [[0, 2, 1], [3], [4], slice(None)], [[0], [4], [1, 3, 4], slice(None)], [[1], [0, 2, 3], [1], slice(None)], [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], # less dim, ellipsis [Ellipsis, [0, 3, 4]], [Ellipsis, slice(None), [0, 3, 4]], [Ellipsis, slice(None), slice(None), [0, 3, 4]], [slice(None), Ellipsis, [0, 3, 4]], [slice(None), slice(None), Ellipsis, [0, 3, 4]], [slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], [[0], [1, 2, 4]], [[0], [1, 2, 4], slice(None)], [[0], [1, 2, 4], Ellipsis], [[0], [1, 2, 4], Ellipsis, slice(None)], [[1],], [[0, 2, 1], [3], [4]], [[0, 2, 1], [3], [4], slice(None)], [[0, 2, 1], [3], [4], Ellipsis], [Ellipsis, [0, 2, 1], [3], [4]], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 1333) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) indices_to_test += [ [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 1333) if placement.type != "cpu": assert_backward_eq(reference, indexer) def _test_single_int(test_case, placement): sbp = random_sbp(placement, max_dim=1).value() v = _cpu_global_tensor(flow.zeros(8, 7, 3)).to_global(placement, sbp) test_case.assertEqual(v[2].shape, (7, 3)) test_case.assertEqual(v[6].shape, (7, 3)) def _test_multiple_int(test_case, placement): sbp = random_sbp(placement, max_dim=3).value() v = _cpu_global_tensor(flow.zeros(8, 8, 8)).to_global(placement, sbp) test_case.assertEqual(v[4, :, 1].shape, (8,)) def _test_none(test_case, placement): sbp = random_sbp(placement, max_dim=3).value() v = _cpu_global_tensor(flow.zeros(8, 8, 8)).to_global(placement, sbp) test_case.assertEqual(v[None].shape, (1, 8, 8, 8)) test_case.assertEqual(v[:, None].shape, (8, 1, 8, 8)) test_case.assertEqual(v[:, None, None].shape, (8, 1, 1, 8, 8)) test_case.assertEqual(v[..., None].shape, (8, 8, 8, 1)) def _test_step(test_case, placement): sbp = random_sbp(placement, max_dim=1).value() v = _cpu_global_tensor(flow.arange(8)).to_global(placement, sbp) _assert_tensor_equal(test_case, v[::1], v) test_case.assertEqual(v[::2].tolist(), [0, 2, 4, 6]) test_case.assertEqual(v[::3].tolist(), [0, 3, 6]) test_case.assertEqual(v[::11].tolist(), [0]) test_case.assertEqual(v[1:6:2].tolist(), [1, 3, 5]) def _test_step_assignment(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() v = _cpu_global_tensor(flow.zeros(8, 8)).to_global(placement, sbp) v[0, 1::2] = _cpu_global_tensor(flow.tensor([3.0, 4.0, 5.0, 6.0])).to_global( placement, broadcast_for_placement ) test_case.assertEqual(v[0].tolist(), [0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0]) test_case.assertEqual(v[1:].sum(), 0) def _test_bool_indices(test_case, 
placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=3).value() v = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp) boolIndices = _cpu_global_tensor( flow.tensor( [True, False, True, True, False, False, False, True], dtype=flow.bool ) ).to_global(placement, broadcast_for_placement) test_case.assertEqual(v[boolIndices].shape, (4, 8, 8)) _assert_tensor_equal( test_case, v[boolIndices], flow.stack([v[0], v[2], v[3], v[7]]) ) def _test_multiple_bool_indices(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() v = global_broadcast_consec((8, 8, 4)).to_global(placement, sbp) # NOTE: these broadcast together and are transposed to the first dim mask1 = _cpu_global_tensor( flow.tensor([1, 0, 1, 0, 0, 1, 0, 0], dtype=flow.bool) ).to_global(placement, broadcast_for_placement) mask2 = _cpu_global_tensor(flow.tensor([1, 1, 1, 0], dtype=flow.bool)).to_global( placement, broadcast_for_placement ) test_case.assertEqual(v[mask1, :, mask2].shape, (3, 8)) def _test_int_indices(test_case, placement): sbp = random_sbp(placement, max_dim=3).value() v = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp) test_case.assertEqual(v[[0, 4, 2]].shape, (3, 8, 8)) test_case.assertEqual(v[:, [0, 4, 2]].shape, (8, 3, 8)) test_case.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (8, 2, 2, 8)) def _test_int_indices2d(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) rows = _cpu_global_tensor(flow.tensor([[0, 0], [6, 3]])).to_global( placement, broadcast_for_placement ) columns = _cpu_global_tensor(flow.tensor([[0, 2], [0, 7]])).to_global( placement, broadcast_for_placement ) test_case.assertEqual(x[rows, columns].tolist(), [[1, 3], [49, 32]]) def _test_int_indices_broadcast(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) rows = _cpu_global_tensor(flow.tensor([0, 7])).to_global( placement, broadcast_for_placement ) columns = _cpu_global_tensor(flow.tensor([7, 2])).to_global( placement, broadcast_for_placement ) result = x[rows[:, None], columns] test_case.assertEqual(result.tolist(), [[8, 3], [64, 59]]) def _test_empty_index(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) # TODO:(wangyinggang): masked_fill support sbp:partial_sum sbp = random_sbp(placement, max_dim=2, except_partial_sum=True).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) idx = _cpu_global_tensor(flow.tensor([], dtype=flow.long)).to_global( placement, broadcast_for_placement ) test_case.assertEqual(x[idx].numel(), 0) # empty assignment should have no effect but not throw an exception y = x.clone() y[idx] = -1 _assert_tensor_equal(test_case, x, y) mask = _cpu_global_tensor(flow.zeros(8, 8).to(flow.bool)).to_global( placement, broadcast_for_placement ) y[mask] = -1 _assert_tensor_equal(test_case, x, y) def _test_empty_ndim_index(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=1).value() x = global_broadcast_consec((8,)).to_global(placement, sbp) _assert_tensor_equal( test_case, x[ 
_cpu_global_tensor(flow.empty(0, 2, dtype=flow.int64)).to_global( placement, broadcast_for_placement ) ], flow.empty(0, 2), ) sbp = random_sbp(placement, max_dim=1).value() x = _cpu_global_tensor(flow.empty(8, 0)).to_global(placement, sbp) test_case.assertEqual(x[[1, 2]].shape, (2, 0)) test_case.assertEqual(x[[], []].shape, (0,)) test_case.assertEqual(x[[[]]].shape, (0, 0)) test_case.assertEqual(x[[[[]]]].shape, (1, 0, 0)) test_case.assertEqual(x[[1], []].shape, (0,)) test_case.assertEqual(x[[], [2]].shape, (0,)) with test_case.assertRaisesRegex(IndexError, "for dimension with size 0"): x[:, [0, 1]] def _test_empty_ndim_index_bool(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=1).value() x = global_broadcast_consec((8,)).to_global(placement, sbp) test_case.assertRaises( IndexError, lambda: x[ _cpu_global_tensor(flow.empty(0, 2, dtype=flow.uint8)).to_global( placement, broadcast_for_placement ) ], ) def _test_empty_slice(test_case, placement): sbp = random_sbp(placement, max_dim=1).value() x = global_broadcast_consec((8, 8, 8, 8)).to_global(placement, sbp) y = x[:, :, :, 1] z = y[:, 1:1, :] test_case.assertEqual((8, 0, 8), z.shape) def _test_index_getitem_copy_bools_slices(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) false = _cpu_global_tensor(flow.tensor(0, dtype=flow.uint8)).to_global( placement, broadcast_for_placement ) sbp = random_sbp(placement, max_dim=1).value() tensor = global_broadcast_consec((8, 8)).to_global(placement, sbp) _assert_tensor_equal(test_case, flow.empty(0, *tensor.shape), tensor[False]) _assert_tensor_equal(test_case, flow.empty(0, *tensor.shape), tensor[false]) def _test_setitem_scalars(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) zero = _cpu_global_tensor(flow.tensor(0, dtype=flow.int64)).to_global( placement, broadcast_for_placement ) # non-scalar indexed with scalars a = global_broadcast_consec((8, 8)).to_global( placement, random_sbp(placement, max_dim=2).value() ) a_set_with_number = a.clone() a_set_with_scalar = a.clone() b = global_broadcast_consec((8,), 233).to_global( placement, random_sbp(placement, max_dim=1).value() ) a_set_with_number[0] = b a_set_with_scalar[zero] = b _assert_tensor_equal(test_case, a_set_with_number, a_set_with_scalar) a[1, zero] = 7.7 value = a[1, 0].numpy() test_case.assertEqual(np.array(7.7, dtype=value.dtype), value) np_x = np.zeros((8, 8)) np_x[0, 6] = 1.0 x = _cpu_global_tensor(flow.tensor(np_x)).to_global( placement, random_sbp(placement, max_dim=2).value() ) x[0, 6] = 1.0 test_case.assertEqual(x.numpy().all(), np_x.all()) # scalar indexed with scalars r = _cpu_global_tensor(flow.tensor(1.0)).to_global( placement, random_sbp(placement, max_dim=0).value() ) with test_case.assertRaises(IndexError): r[:] = 8.8 with test_case.assertRaises(IndexError): r[zero] = 8.8 r[...] 
= 9.9 test_case.assertEqual(r, 9.9) # scalar indexed with oneflow.Size([1]) np_x = np.zeros((8, 8)) np_x[0, 6] = np.ones(1) x = _cpu_global_tensor(flow.tensor(np_x)).to_global( placement, random_sbp(placement, max_dim=2).value() ) x[0, 0] = _cpu_global_tensor(flow.ones(1).to(flow.float64)).to_global( placement, broadcast_for_placement ) test_case.assertEqual(x.numpy().all(), np_x.all()) def _test_basic_advanced_combined(test_case, placement): sbp = random_sbp(placement, max_dim=2).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) _assert_tensor_equal(test_case, x[1:2, 3:5], x[1:2, [3, 4]]) test_case.assertEqual(x[1:2, 1:3].tolist(), [[10, 11]]) # Check that it is a copy unmodified = x.clone() x[1:2, [1, 2]].zero_() _assert_tensor_equal(test_case, x, unmodified) # But assignment should modify the original unmodified = x.clone() x[1:2, [1, 2]] = 0 test_case.assertFalse(np.array_equal(x.numpy(), unmodified.numpy())) def _test_ellipsis_tensor(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) idx = _cpu_global_tensor(flow.tensor([0, 7])).to_global( placement, broadcast_for_placement ) test_case.assertEqual( x[..., idx].tolist(), [[1, 8], [9, 16], [17, 24], [25, 32], [33, 40], [41, 48], [49, 56], [57, 64]], ) test_case.assertEqual( x[idx, ...].tolist(), [[1, 2, 3, 4, 5, 6, 7, 8], [57, 58, 59, 60, 61, 62, 63, 64]], ) # Test scalar ellipsis getitem x_scalar = _cpu_global_tensor(flow.tensor(9.9)).to_global( placement, broadcast_for_placement ) test_case.assertEqual(x_scalar[...], 9.9) class TestGlobalIndexing(flow.unittest.TestCase): @globaltest def test_global_slice(test_case): for placement in all_placement(): for _ in range(5): _test_basic_slice(test_case, placement) _test_advanced_indexing(test_case, placement, dtype=flow.float32) _test_combined_indexing(test_case, placement, dtype=flow.float32) _test_single_int(test_case, placement) _test_multiple_int(test_case, placement) _test_none(test_case, placement) _test_step(test_case, placement) _test_step_assignment(test_case, placement) _test_bool_indices(test_case, placement) _test_multiple_bool_indices(test_case, placement) _test_int_indices(test_case, placement) _test_int_indices2d(test_case, placement) _test_int_indices_broadcast(test_case, placement) _test_empty_index(test_case, placement) _test_empty_ndim_index(test_case, placement) _test_empty_ndim_index_bool(test_case, placement) _test_empty_slice(test_case, placement) _test_index_getitem_copy_bools_slices(test_case, placement) _test_setitem_scalars(test_case, placement) _test_basic_advanced_combined(test_case, placement) _test_ellipsis_tensor(test_case, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_lazy_tensor_indexing.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np import oneflow as flow from oneflow.test_utils.automated_test_util import * import oneflow.unittest import oneflow.framework.session_context as session_ctx def get_graph_output(*args, func): def generate_graph(func): class Graph(flow.nn.Graph): def __init__(self): super().__init__() def build(self, *args): return func(*args) return Graph() graph = generate_graph(func) return graph(*args) def setitem_and_return(ref, idx, value): ref[idx] = value return ref def _randint(low, high): """ Get a random integer in the range [low, high). """ return random(low, high).to(int).value() def _cpu_global_tensor(tensor): return tensor.to_global(flow.placement.all("cpu"), flow.sbp.broadcast) def _assert_tensor_equal(test_case, tensor1, tensor2, atol=0.0, rtol=0.0): test_case.assertTrue( np.allclose(tensor1.numpy(), tensor2.numpy(), atol, rtol), f"{tensor1.numpy()} vs {tensor2.numpy()}", ) def global_broadcast_consec(size, start=1): """ Generate a arithmetic progression with given size and start value. """ sequence = flow.ones([int(np.array(size).prod(0)),]).cumsum(0) sequence.add_(start - 1) return _cpu_global_tensor(sequence.view(*size)) def _test_basic_slice(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) ref_sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8)).to_global(placement, ref_sbp) # empty tensor indexing empty_index = _cpu_global_tensor(flow.LongTensor()).to_global( placement, broadcast_for_placement ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[empty_index]), flow.empty(0, 8, 8), atol=0, rtol=0, ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[1]), global_broadcast_consec((8, 8), 65), atol=0, rtol=0, ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[0, 1]), global_broadcast_consec((8,), 9), atol=0, rtol=0, ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[0:2]), global_broadcast_consec((2, 8, 8)), atol=0, rtol=0, ) test_case.assertEqual( get_graph_output(reference, func=lambda x: x[2, 2, 2]).item(), 147 ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[:]), global_broadcast_consec((8, 8, 8)), atol=0, rtol=0, ) # indexing with Ellipsis _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[..., 2, 2]), flow.tensor([19, 83, 147, 211, 275, 339, 403, 467]), atol=0, rtol=0, ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[0, ..., 2]), flow.tensor([3, 11, 19, 27, 35, 43, 51, 59]), atol=0, rtol=0, ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[0, 2, ...]), reference[0, 2], atol=0, rtol=0, ) reference_5d = global_broadcast_consec((8, 8, 8, 8, 8)).to_global( placement, sbp=random_sbp(placement, max_dim=5).value() ) _assert_tensor_equal( test_case, get_graph_output(reference_5d, func=lambda x: x[2, ..., 1, 0]), get_graph_output(reference_5d, func=lambda x: x[2, :, :, 1, 0]), atol=0, rtol=0, ) # LongTensor indexing sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp) idx = _cpu_global_tensor(flow.LongTensor([2, 4])).to_global( placement, broadcast_for_placement ) _assert_tensor_equal( test_case, get_graph_output(reference, idx, func=lambda x, y: x[y]), 
get_graph_output(reference, func=lambda x: flow.stack([x[2], x[4]])), ) # None indexing _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[None, 2, None, None]), reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[None, 2:5, None, None]), reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2), ) # indexing 0-length slice _assert_tensor_equal( test_case, flow.empty(0, 8, 8), get_graph_output(reference, func=lambda x: x[slice(0)]), ) _assert_tensor_equal( test_case, flow.empty(0, 8), get_graph_output(reference, func=lambda x: x[2, slice(0)]), ) _assert_tensor_equal( test_case, flow.tensor([]), get_graph_output(reference, func=lambda x: x[2, 1:1, 2]), ) # indexing with step sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[2:4, 1:5:2]), get_graph_output( reference, func=lambda x: flow.stack([x[2:4, 1], x[2:4, 3]], 1) ), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[3, 1:6:2]), get_graph_output( reference, func=lambda x: flow.stack([x[3, 1], x[3, 3], x[3, 5]], 0) ), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[None, 2, 1:9:4]), get_graph_output( reference, func=lambda x: flow.stack([x[2, 1], x[2, 5]], 0).unsqueeze(0) ), ) def _test_advanced_indexing(test_case, placement, dtype): broadcast_for_placement = [flow.sbp.broadcast] * len(placement.ranks.shape) # pick a random valid indexer type def ri(indices): choice = _randint(0, 3) if choice == 0: return flow.LongTensor( indices, placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast, ).to_global(placement, broadcast_for_placement) elif choice == 1: return list(indices) else: return tuple(indices) def validate_indexing(x): _assert_tensor_equal( test_case, get_graph_output(x, func=lambda x: x[ri([3]),]), global_broadcast_consec((1,), 4), ) _assert_tensor_equal( test_case, get_graph_output(x, func=lambda x: x[ri([2, 3, 4]),]), global_broadcast_consec((3,), 3), ) def validate_setting(x): # x[[0]] = -2 x = get_graph_output(x, func=lambda x: setitem_and_return(x, [0], -2)) _assert_tensor_equal(test_case, x[0], flow.tensor([-2], dtype=dtype)) # x[[0]] = -1 x = get_graph_output(x, func=lambda x: setitem_and_return(x, [0], -1)) _assert_tensor_equal(test_case, x[0], flow.tensor([-1], dtype=dtype)) # x[[2, 3, 4]] = 4 x = get_graph_output(x, func=lambda x: setitem_and_return(x, [2, 3, 4], 4)) _assert_tensor_equal( test_case, x[[2, 3, 4]], flow.tensor([4, 4, 4], dtype=dtype) ) # x[ri([2, 3, 4]),] = 3 x = get_graph_output( x, func=lambda x: setitem_and_return(x, [ri([2, 3, 4]),], 3) ) _assert_tensor_equal( test_case, x[[2, 3, 4]], flow.tensor([3, 3, 3], dtype=dtype), ) # x[ri([0, 2, 4]),] = _cpu_global_tensor(flow.tensor([5, 4, 3], dtype=dtype)) value_tensor = _cpu_global_tensor(flow.tensor([5, 4, 3], dtype=dtype)) x = get_graph_output( x, func=lambda x: setitem_and_return(x, [ri([0, 2, 4]),], value_tensor) ) _assert_tensor_equal( test_case, x[[0, 2, 4]], flow.tensor([5, 4, 3], dtype=dtype), ) # 1d tensor and integer index setitem and getitem sbp = random_sbp(placement, max_dim=1).value() reference = global_broadcast_consec((8,)).to_global(placement, sbp) validate_indexing(reference) validate_setting(reference) # reference is 1 2 3 4 5 6 7 8 # 9 10 11 12 13 14 15 16 # 17 18 19 20 21 22 23 24 # 25 26 27 28 29 30 31 32 # 33 34 35 36 
37 38 39 40 # 41 42 43 44 45 46 47 48 # 49 50 51 52 53 54 55 56 # 57 58 59 60 61 62 63 64 sbp = random_sbp(placement, max_dim=2).value() reference = global_broadcast_consec((8, 8)).to_global(placement, sbp) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[ri([0, 1, 2]), ri([0])]), flow.tensor([1, 9, 17], dtype=dtype), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[ri([0, 1, 2]), ri([1])]), flow.tensor([2, 10, 18], dtype=dtype), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[ri([0]), ri([0])]), global_broadcast_consec((1,)), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[ri([2]), ri([1])]), global_broadcast_consec((1,), 18), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[ri([0, 0]), ri([0, 1])]), flow.tensor([1, 2], dtype=dtype), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[ri([0, 1, 1, 0, 2, 7]), ri([1])]), flow.tensor([2, 10, 10, 2, 18, 58], dtype=dtype), ) _assert_tensor_equal( test_case, get_graph_output( reference, func=lambda x: x[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])] ), flow.tensor([1, 2, 9, 9], dtype=dtype), ) rows = ri([[0, 0], [1, 6]]) columns = ([0],) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[rows, columns]), flow.tensor([[1, 1], [9, 49]], dtype=dtype), ) rows = ri([[0, 0], [1, 6]]) columns = ri([6, 0]) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[rows, columns]), flow.tensor([[7, 1], [15, 49]], dtype=dtype), ) rows = ri([[0, 0], [1, 2]]) columns = ri([[0, 1], [3, 7]]) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[rows, columns]), flow.tensor([[1, 2], [12, 24]], dtype=dtype), ) # setting values # reference[ri([0]), ri([1])] = -1 reference = get_graph_output( reference, func=lambda x: setitem_and_return(x, [ri([0]), ri([1])], -1) ) _assert_tensor_equal( test_case, reference[ri([0]), ri([1])], flow.tensor([-1], dtype=dtype), ) value_tensor = _cpu_global_tensor(flow.tensor([-1, 2, -4], dtype=dtype)).to_global( placement, broadcast_for_placement ) reference = get_graph_output( reference, func=lambda x: setitem_and_return(x, [ri([0, 1, 2]), ri([0])], value_tensor), ) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([0])], flow.tensor([-1, 2, -4], dtype=dtype), ) value_tensor = _cpu_global_tensor( flow.tensor([[4, 6], [2, 3]], dtype=dtype) ).to_global(placement, broadcast_for_placement) reference = get_graph_output( reference, func=lambda x: setitem_and_return(x, [rows, columns], value_tensor) ) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[4, 6], [2, 3]], dtype=dtype), ) # Tests using less than the number of dims, and ellipsis # reference is 1 2 3 4 5 6 7 8 # 9 10 11 12 13 14 15 16 # 17 18 19 20 21 22 23 24 # 25 26 27 28 29 30 31 32 # 33 34 35 36 37 38 39 40 # 41 42 43 44 45 46 47 48 # 49 50 51 52 53 54 55 56 # 57 58 59 60 61 62 63 64 sbp = random_sbp(placement, max_dim=2).value() reference = global_broadcast_consec((8, 8)).to_global(placement, sbp) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[ri([0, 2]),]), flow.tensor( [[1, 2, 3, 4, 5, 6, 7, 8], [17, 18, 19, 20, 21, 22, 23, 24]], dtype=dtype ), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda x: x[ri([1]), ...]), flow.tensor([[9, 10, 11, 12, 13, 14, 15, 16]], dtype=dtype), ) _assert_tensor_equal( test_case, get_graph_output(reference, func=lambda 
x: x[..., ri([1])]), flow.tensor([[2], [10], [18], [26], [34], [42], [50], [58]], dtype=dtype), ) def _test_combined_indexing(test_case, placement, dtype): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) def tensor_indices_to_np(tensor, indices): # convert the flow Tensor to a numpy array npt = tensor.numpy() # convert indices idxs = tuple( i.tolist() if isinstance(i, flow.LongTensor) else i for i in indices ) return npt, idxs def get_numpy(tensor, indices): npt, idxs = tensor_indices_to_np(tensor, indices) # index and return as a oneflow local Tensor return flow.tensor(npt[idxs], dtype=dtype) def set_numpy(tensor, indices, value): if not isinstance(value, int): value = value.numpy() npt, idxs = tensor_indices_to_np(tensor, indices) npt[idxs] = value return npt def assert_get_eq(tensor, indexer): _assert_tensor_equal( test_case, get_graph_output(tensor, func=lambda x: x[indexer]), get_numpy(tensor, indexer), ) def assert_set_eq(tensor, indexer, val): pyt = tensor.clone() np_ref = tensor.clone() pyt = get_graph_output(pyt, func=lambda x: setitem_and_return(x, indexer, val)) np_ref = flow.tensor(set_numpy(np_ref, indexer, val), dtype=dtype) _assert_tensor_equal(test_case, pyt, np_ref) def get_set_tensor(indexed, indexer): set_size = indexed[indexer].size() set_count = indexed[indexer].numel() set_tensor = _cpu_global_tensor( flow.arange(set_count, 0, -1).view(set_size).to(dtype) ).to_global(placement, broadcast_for_placement) return set_tensor # Tensor is 1 2 3 4 5 6 7 8 # 9 10 11 12 13 14 15 16 # 17 18 19 20 21 22 23 24 # 25 26 27 28 29 30 31 32 # 33 34 35 36 37 38 39 40 # 41 42 43 44 45 46 47 48 # 49 50 51 52 53 54 55 56 # 57 58 59 60 61 62 63 64 sbp = random_sbp(placement, max_dim=2).value() reference = global_broadcast_consec((8, 8)).to_global(placement, sbp) indices_to_test = [ # grab the second, fourth columns [slice(None), [4, 6]], # first, third rows, [[0, 6], slice(None)], # TODO(wyg): only support getitem but not setitem # # weird shape # [slice(None), [[0, 1], # [2, 3]]], # negatives [[-1], [0]], [[0, 7], [-1]], [slice(None), [-1]], ] # test getitem get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] get_indices_to_test = indices_to_test + [ [slice(None), [[0, 1], [2, 3]]] ] # TODO: test setitem for indexer in get_indices_to_test: assert_get_eq(reference, indexer) # test setitem for indexer in indices_to_test: assert_set_eq(reference, indexer, 44) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) ######################### # test more dims tensor # ######################### sbp = random_sbp(placement, max_dim=3).value() reference = global_broadcast_consec((8, 8, 8), 0).float().to_global(placement, sbp) indices_to_test = [ [slice(None), slice(None), [0, 3, 4]], [slice(None), [2, 4, 5, 7], slice(None)], [[2, 3], slice(None), slice(None)], [slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), [0], [1, 2, 4]], [slice(None), [0, 1, 3], [4]], [slice(None), [[0, 1], [1, 0]], [[2, 3]]], [slice(None), [[0, 1], [2, 3]], [[0]]], [slice(None), [[5, 6]], [[0, 3], [4, 4]]], [[0, 2, 3], [1, 3, 4], slice(None)], [[0], [1, 2, 4], slice(None)], [[0, 1, 3], [4], slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], [[[0, 1], [1, 0]], [[2, 3]], slice(None)], [[[0, 1], [2, 3]], [[0]], slice(None)], [[[2, 1]], [[0, 3], [4, 4]], slice(None)], [[[2]], [[0, 3], [4, 1]], slice(None)], # non-contiguous indexing subspace [[0, 2, 3], slice(None), [1, 3, 4]], # less dim, ellipsis [[0, 2],], [[0, 2], slice(None)], [[0, 2], Ellipsis], 
[[0, 2], slice(None), Ellipsis], [[0, 2], Ellipsis, slice(None)], [[0, 2], [1, 3]], [[0, 2], [1, 3], Ellipsis], [Ellipsis, [1, 3], [2, 3]], [Ellipsis, [2, 3, 4]], [Ellipsis, slice(None), [2, 3, 4]], [slice(None), Ellipsis, [2, 3, 4]], # ellipsis counts for nothing [Ellipsis, slice(None), slice(None), [0, 3, 4]], [slice(None), Ellipsis, slice(None), [0, 3, 4]], [slice(None), slice(None), Ellipsis, [0, 3, 4]], [slice(None), slice(None), [0, 3, 4], Ellipsis], [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 212) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) sbp = random_sbp(placement, max_dim=4).value() reference = ( global_broadcast_consec((8, 8, 8, 8), 0).float().to_global(placement, sbp) ) indices_to_test = [ [slice(None), slice(None), slice(None), [0, 3, 4]], [slice(None), slice(None), [2, 4, 5, 7], slice(None)], [slice(None), [2, 3], slice(None), slice(None)], [[1, 2], slice(None), slice(None), slice(None)], [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), slice(None), [0], [1, 2, 4]], [slice(None), slice(None), [0, 1, 3], [4]], [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], [slice(None), [0], [1, 2, 4], slice(None)], [slice(None), [0, 1, 3], [4], slice(None)], [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], [[0], [1, 2, 4], slice(None), slice(None)], [[0, 1, 2], [4], slice(None), slice(None)], [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], [slice(None), [2, 3, 4], [1, 3, 4], [4]], [slice(None), [0, 1, 3], [4], [1, 3, 4]], [slice(None), [6], [0, 2, 3], [1, 3, 4]], [slice(None), [2, 3, 5], [3], [4]], [slice(None), [0], [4], [1, 3, 4]], [slice(None), [6], [0, 2, 3], [1]], [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], [[2, 0, 1], [1, 2, 3], [4], slice(None)], [[0, 1, 2], [4], [1, 3, 4], slice(None)], [[0], [0, 2, 3], [1, 3, 4], slice(None)], [[0, 2, 1], [3], [4], slice(None)], [[0], [4], [1, 3, 4], slice(None)], [[1], [0, 2, 3], [1], slice(None)], [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], # less dim, ellipsis [Ellipsis, [0, 3, 4]], [Ellipsis, slice(None), [0, 3, 4]], [Ellipsis, slice(None), slice(None), [0, 3, 4]], [slice(None), Ellipsis, [0, 3, 4]], [slice(None), slice(None), Ellipsis, [0, 3, 4]], [slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], [[0], [1, 2, 4]], [[0], [1, 2, 4], slice(None)], [[0], [1, 2, 4], Ellipsis], [[0], [1, 2, 4], Ellipsis, slice(None)], [[1],], [[0, 2, 1], [3], [4]], [[0, 2, 1], [3], [4], slice(None)], [[0, 2, 1], [3], [4], Ellipsis], [Ellipsis, [0, 
2, 1], [3], [4]], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 1333) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) indices_to_test += [ [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 1333) def _test_single_int(test_case, placement): sbp = random_sbp(placement, max_dim=1).value() v = _cpu_global_tensor(flow.zeros(8, 7, 3)).to_global(placement, sbp) test_case.assertEqual(get_graph_output(v, func=lambda x: x[2]).shape, (7, 3)) def _test_multiple_int(test_case, placement): sbp = random_sbp(placement, max_dim=3).value() v = _cpu_global_tensor(flow.zeros(8, 8, 8)).to_global(placement, sbp) test_case.assertEqual(get_graph_output(v, func=lambda x: x[4, :, 1]).shape, (8,)) def _test_none(test_case, placement): sbp = random_sbp(placement, max_dim=3).value() v = _cpu_global_tensor(flow.zeros(8, 8, 8)).to_global(placement, sbp) test_case.assertEqual( get_graph_output(v, func=lambda x: x[None]).shape, (1, 8, 8, 8) ) test_case.assertEqual( get_graph_output(v, func=lambda x: x[:, None]).shape, (8, 1, 8, 8) ) test_case.assertEqual( get_graph_output(v, func=lambda x: x[:, None, None]).shape, (8, 1, 1, 8, 8) ) test_case.assertEqual( get_graph_output(v, func=lambda x: x[..., None]).shape, (8, 8, 8, 1) ) def _test_step(test_case, placement): sbp = random_sbp(placement, max_dim=1).value() v = _cpu_global_tensor(flow.arange(8)).to_global(placement, sbp) _assert_tensor_equal(test_case, v[::1], v) test_case.assertEqual( get_graph_output(v, func=lambda x: x[::2]).tolist(), [0, 2, 4, 6] ) test_case.assertEqual( get_graph_output(v, func=lambda x: x[::3]).tolist(), [0, 3, 6] ) test_case.assertEqual(get_graph_output(v, func=lambda x: x[::11]).tolist(), [0]) test_case.assertEqual( get_graph_output(v, func=lambda x: x[1:6:2]).tolist(), [1, 3, 5] ) def _test_step_assignment(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() v = _cpu_global_tensor(flow.zeros(8, 8)).to_global(placement, sbp) value_tensor = _cpu_global_tensor(flow.tensor([3.0, 4.0, 5.0, 6.0])).to_global( placement, broadcast_for_placement ) v = get_graph_output( v, func=lambda x: setitem_and_return(x, [0, slice(1, None, 2)], value_tensor) ) test_case.assertEqual(v[0].tolist(), [0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0]) test_case.assertEqual(v[1:].sum(), 0) def _test_multiple_bool_indices(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() v = global_broadcast_consec((8, 8, 4)).to_global(placement, sbp) # NOTE: these broadcast together and are transposed to the first dim mask1 = _cpu_global_tensor( flow.tensor([1, 0, 1, 0, 0, 1, 0, 0], dtype=flow.bool) ).to_global(placement, broadcast_for_placement) mask2 = _cpu_global_tensor(flow.tensor([1, 1, 1, 0], dtype=flow.bool)).to_global( placement, broadcast_for_placement ) test_case.assertEqual(v[mask1, :, mask2].shape, (3, 8)) def _test_int_indices(test_case, placement): sbp = random_sbp(placement, max_dim=3).value() v = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp) test_case.assertEqual( get_graph_output(v, func=lambda x: x[[0, 4, 2]]).shape, (3, 8, 8) ) test_case.assertEqual( get_graph_output(v, func=lambda x: x[:, [0, 4, 2]]).shape, (8, 3, 8) ) 
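# Indexing one dim with a nested integer list follows advanced-indexing semantics:
# the indexed dim is replaced by the index's shape, so the (8, 8, 8) tensor indexed
# as x[:, [[0, 1], [4, 3]]] below has shape (8, 2, 2, 8).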
test_case.assertEqual( get_graph_output(v, func=lambda x: x[:, [[0, 1], [4, 3]]]).shape, (8, 2, 2, 8) ) def _test_int_indices2d(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) rows = _cpu_global_tensor(flow.tensor([[0, 0], [6, 3]])).to_global( placement, broadcast_for_placement ) columns = _cpu_global_tensor(flow.tensor([[0, 2], [0, 7]])).to_global( placement, broadcast_for_placement ) test_case.assertEqual( get_graph_output(x, func=lambda x: x[rows, columns]).tolist(), [[1, 3], [49, 32]], ) def _test_int_indices_broadcast(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) rows = _cpu_global_tensor(flow.tensor([0, 7])).to_global( placement, broadcast_for_placement ) columns = _cpu_global_tensor(flow.tensor([7, 2])).to_global( placement, broadcast_for_placement ) result = get_graph_output(x, func=lambda x: x[rows[:, None], columns]) test_case.assertEqual(result.tolist(), [[8, 3], [64, 59]]) def _test_empty_index(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) idx = _cpu_global_tensor(flow.tensor([], dtype=flow.long)).to_global( placement, broadcast_for_placement ) test_case.assertEqual(get_graph_output(x, func=lambda x: x[idx]).numel(), 0) # empty assignment should have no effect but not throw an exception y = x.clone() y = get_graph_output(y, func=lambda x: setitem_and_return(x, idx, -1)) _assert_tensor_equal(test_case, x, y) # TODO(wyg): support eager bool indices tensor in lazy mode # mask = _cpu_global_tensor(flow.zeros(8, 8).to(flow.bool)).to_global( # placement, broadcast_for_placement # ) # y = get_graph_output(y, func=lambda x: setitem_and_return(x, mask, -1)) # _assert_tensor_equal(test_case, x, y) def _test_empty_ndim_index(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=1).value() x = global_broadcast_consec((8,)).to_global(placement, sbp) index = _cpu_global_tensor(flow.empty(0, 2, dtype=flow.int64)).to_global( placement, broadcast_for_placement ) _assert_tensor_equal( test_case, get_graph_output(x, func=lambda x: x[index]), flow.empty(0, 2), ) sbp = random_sbp(placement, max_dim=1).value() x = _cpu_global_tensor(flow.empty(8, 0)).to_global(placement, sbp) test_case.assertEqual(get_graph_output(x, func=lambda x: x[[1, 2]]).shape, (2, 0)) test_case.assertEqual(get_graph_output(x, func=lambda x: x[[], []]).shape, (0,)) test_case.assertEqual(get_graph_output(x, func=lambda x: x[[[]]]).shape, (0, 0)) test_case.assertEqual( get_graph_output(x, func=lambda x: x[[[[]]]]).shape, (1, 0, 0) ) test_case.assertEqual(get_graph_output(x, func=lambda x: x[[1], []]).shape, (0,)) test_case.assertEqual(get_graph_output(x, func=lambda x: x[[], [2]]).shape, (0,)) def _test_empty_slice(test_case, placement): sbp = random_sbp(placement, max_dim=1).value() x = global_broadcast_consec((8, 8, 8, 8)).to_global(placement, sbp) y = get_graph_output(x, func=lambda x: x[:, 1:1, :, 1]) test_case.assertEqual((8, 0, 8), y.shape) def _test_setitem_scalars(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) 
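# A 0-d int64 global tensor (`zero`) is used as a scalar index below, to check that
# indexing with a scalar tensor behaves the same as indexing with a Python int.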
zero = _cpu_global_tensor(flow.tensor(0, dtype=flow.int64)).to_global( placement, broadcast_for_placement ) # non-scalar indexed with scalars a = global_broadcast_consec((8, 8)).to_global( placement, random_sbp(placement, max_dim=2).value() ) a_set_with_number = a.clone() a_set_with_scalar = a.clone() b = global_broadcast_consec((8,), 233).to_global( placement, random_sbp(placement, max_dim=1).value() ) a_set_with_number = get_graph_output( a_set_with_number, func=lambda x: setitem_and_return(x, 0, b) ) a_set_with_scalar = get_graph_output( a_set_with_scalar, func=lambda x: setitem_and_return(x, zero, b) ) _assert_tensor_equal(test_case, a_set_with_number, a_set_with_scalar) # a[1, zero] = 7.7 value = get_graph_output( a, func=lambda x: setitem_and_return(x, [1, zero], 7.7) ).numpy() test_case.assertEqual(np.array(7.7, dtype=value.dtype), value[1, 0]) np_x = np.zeros((8, 8)) np_x[0, 6] = 1.0 x = _cpu_global_tensor(flow.tensor(np_x)).to_global( placement, random_sbp(placement, max_dim=2).value() ) # x[0, 6] = 1.0 res = get_graph_output(x, func=lambda x: setitem_and_return(x, [0, 6], 1.0)) test_case.assertEqual(res.numpy().all(), np_x.all()) # scalar indexed with scalars r = _cpu_global_tensor(flow.tensor(1.0)).to_global( placement, random_sbp(placement, max_dim=0).value() ) # r[...] = 9.9 res = get_graph_output(r, func=lambda x: setitem_and_return(x, [...], 9.9)) test_case.assertEqual(res, 9.9) # scalar indexed with oneflow.Size([1]) np_x = np.zeros((8, 8)) np_x[0, 6] = np.ones(1) x = _cpu_global_tensor(flow.tensor(np_x)).to_global( placement, random_sbp(placement, max_dim=2).value() ) value_tensor = _cpu_global_tensor(flow.ones(1).to(flow.float64)).to_global( placement, broadcast_for_placement ) # x[0, 0] = value res = get_graph_output( x, func=lambda x: setitem_and_return(x, [0, 0], value_tensor) ) test_case.assertEqual(res.numpy().all(), np_x.all()) def _test_ellipsis_tensor(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=2).value() x = global_broadcast_consec((8, 8)).to_global(placement, sbp) idx = _cpu_global_tensor(flow.tensor([0, 7])).to_global( placement, broadcast_for_placement ) test_case.assertEqual( get_graph_output(x, func=lambda x: x[..., idx]).tolist(), [[1, 8], [9, 16], [17, 24], [25, 32], [33, 40], [41, 48], [49, 56], [57, 64]], ) test_case.assertEqual( get_graph_output(x, func=lambda x: x[idx, ...]).tolist(), [[1, 2, 3, 4, 5, 6, 7, 8], [57, 58, 59, 60, 61, 62, 63, 64]], ) # Test scalar ellipsis getitem x_scalar = _cpu_global_tensor(flow.tensor(9.9)).to_global( placement, broadcast_for_placement ) test_case.assertEqual(get_graph_output(x_scalar, func=lambda x: x[...]), 9.9) def _test_bool_indices(test_case, placement): broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape) sbp = random_sbp(placement, max_dim=1, except_partial_sum=True).value() v = global_broadcast_consec((8,)).to_global(placement, sbp) boolIndices = _cpu_global_tensor( flow.tensor( [True, False, True, True, False, False, False, True], dtype=flow.bool ) ).to_global(placement, sbp) _assert_tensor_equal( test_case, get_graph_output(v, func=lambda x: setitem_and_return(x, boolIndices, 6.6)), flow.tensor([6.6, 2.0, 6.6, 6.6, 5.0, 6.0, 7.0, 6.6]), ) class TestGlobalIndexing(flow.unittest.TestCase): @globaltest @unittest.skip( "TODO(wyg, zwx): test these cases after supporting clear session interface to avoid" "geting 'stream_id.h:33 Check failed: stream_index <= kMaxStreamIndex (4096 vs. 
4095)' error" ) def test_global_slice(test_case): for placement in all_placement(): for _ in range(5): _test_basic_slice(test_case, placement) _test_advanced_indexing(test_case, placement, dtype=flow.float32) _test_combined_indexing(test_case, placement, dtype=flow.float32) _test_single_int(test_case, placement) _test_multiple_int(test_case, placement) _test_none(test_case, placement) _test_step(test_case, placement) _test_step_assignment(test_case, placement) _test_int_indices(test_case, placement) _test_int_indices2d(test_case, placement) _test_int_indices_broadcast(test_case, placement) _test_empty_index(test_case, placement) _test_empty_ndim_index(test_case, placement) _test_empty_slice(test_case, placement) _test_ellipsis_tensor(test_case, placement) # TODO: cpu variable don't support common net if not placement.type == "cpu": _test_setitem_scalars(test_case, placement) @globaltest def test_bool_indices(test_case): for placement in all_placement(): for _ in range(2): _test_bool_indices(test_case, placement) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_meta_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import oneflow as flow import oneflow.unittest from oneflow import nn class CustomModule(nn.Module): def __init__(self, foo, bar, device=None): super().__init__() # ==== Case 1: Module creates parameters directly. ==== self.param1 = nn.Parameter(flow.empty((foo, bar), device=device)) self.register_parameter("param2", nn.Parameter(flow.empty(bar, device=device))) with flow.no_grad(): nn.init.kaiming_uniform_(self.param1) nn.init.uniform_(self.param2) # ==== Case 2: Module creates submodules. ==== self.fc = nn.Linear(bar, 5, device=device) self.linears = nn.Sequential( nn.Linear(5, 5, device=device), nn.Linear(5, 1, device=device) ) # ==== Case 3: Module creates buffers. 
==== self.register_buffer("some_buffer", flow.ones(7, device=device)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestMetaTensor(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_meta_tensor_local_mode_without_data(test_case): x = flow.Tensor(3, 2, device="meta") y = flow.Tensor(3, 2, device="cpu") test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.device, flow.device("meta")) @flow.unittest.skip_unless_1n1d() def test_meta_tensor_local_mode_with_data(test_case): x = flow.Tensor([3, 2], device="meta") y = flow.Tensor([3, 2], device="cpu") test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.device, flow.device("meta")) @flow.unittest.skip_unless_1n1d() def test_meta_tensor_func_local_mode_without_data(test_case): x = flow.tensor([3, 2], device="meta") y = flow.tensor([3, 2], device="cpu") test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.device, flow.device("meta")) @flow.unittest.skip_unless_1n1d() def test_meta_tensor_func_local_mode_with_data(test_case): x = flow.tensor([3, 2], device="meta") y = flow.tensor([3, 2], device="cpu") test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.device, flow.device("meta")) @flow.unittest.skip_unless_1n1d() def test_meta_tensor_local_mode_ones(test_case): x = flow.ones(3, 2, device="meta") y = flow.ones([3, 2], device="cpu") test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.device, flow.device("meta")) @flow.unittest.skip_unless_1n1d() def test_meta_tensor_local_mode_linear(test_case): x = flow.nn.Linear(3, 2, device="meta") y = flow.nn.Linear(3, 2, device="cpu") test_case.assertEqual(x.weight.dtype, y.weight.dtype) test_case.assertEqual(x.weight.shape, y.weight.shape) test_case.assertEqual(x.weight.requires_grad, y.weight.requires_grad) test_case.assertEqual(x.weight.device, flow.device("meta")) @flow.unittest.skip_unless_1n1d() def test_skip_init_function(test_case): x = flow.nn.utils.skip_init(flow.nn.Linear, 4, 3) y = flow.nn.Linear(4, 3, device="cpu") test_case.assertEqual(x.weight.dtype, y.weight.dtype) test_case.assertEqual(x.weight.shape, y.weight.shape) test_case.assertEqual(x.weight.requires_grad, y.weight.requires_grad) test_case.assertEqual(x.weight.device, flow.device("cpu")) @flow.unittest.skip_unless_1n1d() def test_skip_init_function_custom_module(test_case): x = flow.nn.utils.skip_init(CustomModule, 4, 3) y = CustomModule(4, 3, device="cpu") test_case.assertEqual(x.param1.dtype, y.param1.dtype) test_case.assertEqual(x.param1.shape, y.param1.shape) test_case.assertEqual(x.param1.requires_grad, y.param1.requires_grad) test_case.assertEqual(x.param1.device, flow.device("cpu")) test_case.assertEqual(x.param2.dtype, y.param2.dtype) test_case.assertEqual(x.param2.shape, y.param2.shape) test_case.assertEqual(x.param2.requires_grad, y.param2.requires_grad) test_case.assertEqual(x.param2.device, flow.device("cpu")) test_case.assertEqual(x.fc.weight.dtype, y.fc.weight.dtype) test_case.assertEqual(x.fc.weight.shape, y.fc.weight.shape) test_case.assertEqual(x.fc.weight.requires_grad, y.fc.weight.requires_grad) test_case.assertEqual(x.fc.weight.device, flow.device("cpu")) @flow.unittest.skip_unless_1n1d() def test_meta_tensor_local_mode_clone(test_case): x = flow.tensor([3, 2], device="meta") y = x.clone() 
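# Cloning a meta tensor should preserve its dtype, shape and "meta" device.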
test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.device, y.device) @flow.unittest.skip_unless_1n1d() def test_meta_tensor_global_mode_without_data(test_case): P1 = flow.placement(type="meta", ranks=[0]) P2 = flow.placement(type="cpu", ranks=[0]) sbp = flow.sbp.broadcast x = flow.Tensor(3, 2, placement=P1, sbp=sbp) y = flow.Tensor(3, 2, placement=P2, sbp=sbp) test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.sbp, y.sbp) test_case.assertEqual(x.placement.type, "meta") test_case.assertEqual(x.to_local().dtype, y.to_local().dtype) test_case.assertEqual(x.to_local().shape, y.to_local().shape) test_case.assertEqual(x.to_local().device.type, "meta") @flow.unittest.skip_unless_1n1d() def test_meta_tensor_global_mode_with_data(test_case): P1 = flow.placement(type="meta", ranks=[0]) P2 = flow.placement(type="cpu", ranks=[0]) sbp = flow.sbp.broadcast x = flow.Tensor([3, 2], placement=P1, sbp=sbp) y = flow.Tensor([3, 2], placement=P2, sbp=sbp) test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.sbp, y.sbp) test_case.assertEqual(x.placement.type, "meta") test_case.assertEqual(x.to_local().dtype, y.to_local().dtype) test_case.assertEqual(x.to_local().shape, y.to_local().shape) test_case.assertEqual(x.to_local().device.type, "meta") @flow.unittest.skip_unless_1n1d() def test_meta_tensor_func_global_mode_without_data(test_case): P1 = flow.placement(type="meta", ranks=[0]) P2 = flow.placement(type="cpu", ranks=[0]) sbp = flow.sbp.broadcast x = flow.tensor([3, 2], placement=P1, sbp=sbp) y = flow.tensor([3, 2], placement=P2, sbp=sbp) test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.sbp, y.sbp) test_case.assertEqual(x.placement.type, "meta") test_case.assertEqual(x.to_local().dtype, y.to_local().dtype) test_case.assertEqual(x.to_local().shape, y.to_local().shape) test_case.assertEqual(x.to_local().device.type, "meta") @flow.unittest.skip_unless_1n1d() def test_meta_tensor_global_mode_clone(test_case): P = flow.placement(type="meta", ranks=[0]) sbp = flow.sbp.broadcast x = flow.tensor([3, 2], placement=P, sbp=sbp) y = x.clone() test_case.assertEqual(x.dtype, y.dtype) test_case.assertEqual(x.shape, y.shape) test_case.assertEqual(x.sbp, y.sbp) test_case.assertEqual(x.placement, y.placement) @flow.unittest.skip_unless_1n1d() def test_meta_tensor_calculate(test_case): x1 = flow.tensor([3, 2], device="meta") y1 = x1 + 1 P = flow.placement(type="meta", ranks=[0]) sbp = flow.sbp.broadcast x2 = flow.tensor([3, 2], placement=P, sbp=sbp) y2 = x2 + 1 test_case.assertEqual(y1.device.type, "meta") test_case.assertEqual(y2.placement.type, "meta") if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_new_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import unittest import numpy as np import oneflow as flow import oneflow.unittest class TestNewTensor(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_new_tensor_local_mode_with_default_args(test_case): tensor = flow.randn(5) data = [[1, 2], [3, 4]] new_tensor = tensor.new_tensor(data) test_case.assertEqual(new_tensor.dtype, tensor.dtype) test_case.assertEqual(new_tensor.device, tensor.device) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() def test_new_tensor_local_mode_with_spec_args(test_case): tensor = flow.randn(5) data = [[1, 2], [3, 4]] new_tensor = tensor.new_tensor(data, flow.int64, "cuda") test_case.assertEqual(new_tensor.dtype, flow.int64) test_case.assertEqual(new_tensor.device, flow.device("cuda")) @flow.unittest.skip_unless_1n2d() def test_new_tensor_global_mode_with_default_args(test_case): placement = flow.placement(type="cpu", ranks=[0, 1]) sbp = flow.sbp.split(0) tensor = flow.randn(4, 4, placement=placement, sbp=sbp) data = [[1, 2], [3, 4]] new_tensor = tensor.new_tensor(data) test_case.assertEqual(new_tensor.dtype, tensor.dtype) test_case.assertEqual(new_tensor.placement, placement) test_case.assertEqual(new_tensor.sbp, (sbp,)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() def test_new_tensor_global_mode_with_spec_args(test_case): placement = flow.placement(type="cuda", ranks=[0, 1]) sbp = flow.sbp.split(0) tensor = flow.randn(4, 4, placement=placement, sbp=sbp) data = [[1, 2], [3, 4]] new_tensor = tensor.new_tensor( data, placement=placement, sbp=flow.sbp.broadcast ) test_case.assertEqual(new_tensor.dtype, tensor.dtype) test_case.assertEqual(new_tensor.placement, placement) test_case.assertEqual(new_tensor.sbp, (flow.sbp.broadcast,)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() def test_new_cuda_bfloat16_local_tensor_with_numpy(test_case): from oneflow import sysconfig if sysconfig.get_cuda_version() < 11000: return np_array = np.random.rand(4, 4) tensor = flow.tensor(np_array, dtype=flow.bfloat16, device="cuda") test_case.assertEqual(tensor.dtype, flow.bfloat16) test_case.assertEqual(tensor.device, flow.device("cuda")) @flow.unittest.skip_unless_1n1d() def test_new_cpu_bfloat16_local_tensor_with_numpy(test_case): np_array = np.random.rand(4, 4) tensor = flow.tensor(np_array, dtype=flow.bfloat16, device="cpu") test_case.assertEqual(tensor.dtype, flow.bfloat16) test_case.assertEqual(tensor.device, flow.device("cpu")) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_parameter.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestParameter(flow.unittest.TestCase): @autotest(n=1, check_graph=True) def test_parameter_grad_fn_none(test_case): x = torch.ones(2, 3).requires_grad_(True) y = x + x z = torch.nn.Parameter(y) return z.grad_fn @autotest(n=1, check_graph=True) def test_parameter_set_data_autograd_meta(test_case): x = torch.ones(2, 3).requires_grad_(True) y = x + x z = torch.nn.Parameter(x) z.data = y return z.grad_fn, z.is_leaf # Not check graph because of 2 reason. # Reason 1, x.data return a new tensor but share storage with the origin tensor, this is not well dealed in nn.Graph. # Reason 2, inplace operation mul_ can works well inside nn.Graph but will not change the value in free eager tensor. # Please refer to test case: test_graph_return_inplace_free_eager_tensor @autotest(n=1, check_graph="ValidatedFalse") def test_parameter_inplace_modify_data(test_case): x = torch.nn.Parameter(torch.ones(2, 3)) x.data.mul_(2) return x def test_parameter_set_data(test_case): a = flow.nn.Parameter(flow.ones(2, 3), False) old_id = id(a) b = flow.nn.Parameter(flow.ones(4, 5), True) a.data = b test_case.assertEqual(old_id, id(a)) test_case.assertTrue(a.shape == (4, 5)) test_case.assertFalse(a.requires_grad) test_case.assertTrue(a.is_leaf) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_safetensors.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest import tempfile import oneflow as flow import oneflow.unittest import oneflow.mock_torch as mock tensors = { "weight1": flow.zeros((1024, 1024)), "weight2": flow.ones((1024, 1024)), "weight3": flow.rand((1024, 1024)), "weight4": flow.eye(1024), } def _test_save_safetensors(save_path): with mock.enable(): from safetensors.torch import save_file save_file(tensors, save_path) def _test_load_safetensors(load_path): with mock.enable(): from safetensors import safe_open tensors_load = {} with safe_open(load_path, framework="pt", device="cpu") as f: for key in f.keys(): tensors_load[key] = f.get_tensor(key) return tensors_load class TestSafetensors(flow.unittest.TestCase): def test_safetensors(test_case): with tempfile.TemporaryDirectory() as f0: _test_save_safetensors(os.path.join(f0, "model.safetensors")) tensors_load = _test_load_safetensors(os.path.join(f0, "model.safetensors")) for key in tensors.keys(): test_case.assertTrue((tensors[key] == tensors_load[key]).all()) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_tensor_and_ndarray_compatibility.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest from collections import OrderedDict import oneflow as flow from oneflow.test_utils.test_util import GenArgDict import numpy as np import torch test_compute_op_list = [ "+", "-", "*", "/", "**", "//", "%", ] test_login_op_list = [ "^", "&", "|", ] test_compare_op_list = [ "==", "!=", ] def _test_compute_operator(test_case, shape, dtype): random_tensor = np.random.randn(*shape).astype(dtype) x_flow = flow.tensor(random_tensor) x_torch = torch.tensor(random_tensor) random_numpy = np.random.randn(*shape) for op in test_compute_op_list: if op in ["**", "//", "%"]: random_tensor = np.random.randint(1, 100, size=shape) random_numpy = np.random.randint(1, 10, size=shape) else: random_tensor = np.random.randn(*shape) random_numpy = np.random.randn(*shape) x_flow = flow.tensor(random_tensor) x_torch = torch.tensor(random_tensor) z_flow = eval(f"x_flow {op} random_numpy") z_torch = eval(f"x_torch {op} random_numpy") test_case.assertTrue(np.allclose(z_flow.numpy(), z_torch.numpy())) # TODO:support for "+=" compatibility if op not in ["**", "+"]: exec(f"x_flow {op}= random_numpy") exec(f"x_torch {op}= random_numpy") test_case.assertTrue( np.allclose(z_flow.numpy(), z_torch.numpy(), 1e-05, 1e-05) ) def _test_logic_operator(test_case, shape): random_tensor = np.random.randint(100, size=shape) x_flow = flow.tensor(random_tensor, dtype=flow.int64) x_torch = torch.tensor(random_tensor, dtype=torch.int64) random_numpy = np.random.randint(100, size=shape) for op in test_login_op_list: z_flow = eval(f"x_flow {op} random_numpy") z_torch = eval(f"x_torch {op} random_numpy") test_case.assertTrue(np.allclose(z_flow.numpy(), z_torch.numpy(), 1e-05, 1e-05)) def _test_compare_operator(test_case, shape): random_tensor = np.random.randint(100, size=shape) x_flow = flow.tensor(random_tensor, dtype=flow.int64) x_torch = torch.tensor(random_tensor, dtype=torch.int64) random_numpy = np.random.randint(100, size=shape) for op in test_compare_op_list: flow_bool_value = eval(f"x_flow {op} random_numpy") torch_bool_value = eval(f"x_torch {op} random_numpy") print(flow_bool_value) print(torch_bool_value) test_case.assertTrue(flow_bool_value, torch_bool_value) @flow.unittest.skip_unless_1n1d() class TestTensorAndNdarrayCompatibility(flow.unittest.TestCase): def test_op_compatibility(test_case): arg_dict = OrderedDict() arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["dtype"] = [np.float32, np.float64] for arg in GenArgDict(arg_dict): _test_compute_operator(test_case, **arg) # TODO(yzm):support compare operator Compatibility # _test_compare_operator(test_case, **arg) # TODO(yzm):fix the logic op bug # _test_logic_operator(test_case, **arg) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_tensor_exponential.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import random import numpy as np from collections import OrderedDict import torch import oneflow as flow import oneflow.unittest from oneflow.test_utils.test_util import GenArgList def _test_exponential(test_case, device, seed, lambd, dtype): torch.manual_seed(seed) flow.manual_seed(seed) dim1 = random.randint(8, 64) dim2 = random.randint(8, 64) torch_arr = torch.zeros( dim1, device=device, dtype=torch.float32 if dtype == "float" else torch.float64 ).exponential_(lambd=lambd, generator=None) oneflow_arr = flow.zeros( dim1, device=device, dtype=flow.float32 if dtype == "float" else flow.float64 ).exponential_(lambd=lambd, generator=None) test_case.assertTrue( np.allclose(torch_arr.cpu().numpy(), oneflow_arr.cpu().numpy(), atol=1e-8,) ) torch_arr = torch.zeros( dim1, device=device, dtype=torch.float32 if dtype == "float" else torch.float64 ).exponential_(lambd=lambd, generator=None) oneflow_arr = flow.zeros( dim1, device=device, dtype=flow.float32 if dtype == "float" else flow.float64 ).exponential_(lambd=lambd, generator=None) test_case.assertTrue( np.allclose(torch_arr.cpu().numpy(), oneflow_arr.cpu().numpy(), atol=1e-8,) ) torch_gen = torch.Generator(device=device) torch_gen.manual_seed(seed) oneflow_gen = flow.Generator(device=device) oneflow_gen.manual_seed(seed) torch_arr = torch.zeros( dim1, device=device, dtype=torch.float32 if dtype == "float" else torch.float64 ).exponential_(lambd=lambd, generator=torch_gen) oneflow_arr = flow.zeros( dim1, device=device, dtype=flow.float32 if dtype == "float" else flow.float64 ).exponential_(lambd=lambd, generator=oneflow_gen) test_case.assertTrue( np.allclose(torch_arr.cpu().numpy(), oneflow_arr.cpu().numpy(), atol=1e-8,) ) torch_arr = torch.zeros( dim1, device=device, dtype=torch.float32 if dtype == "float" else torch.float64 ).exponential_(lambd=lambd, generator=torch_gen) oneflow_arr = flow.zeros( dim1, device=device, dtype=flow.float32 if dtype == "float" else flow.float64 ).exponential_(lambd=lambd, generator=oneflow_gen) test_case.assertTrue( np.allclose(torch_arr.cpu().numpy(), oneflow_arr.cpu().numpy(), atol=1e-8,) ) @flow.unittest.skip_unless_1n1d() class TestExponential(flow.unittest.TestCase): def test_exponential(test_case): arg_dict = OrderedDict() arg_dict["device"] = ["cuda", "cpu"] arg_dict["seed"] = [0, 2, 4] arg_dict["lambd"] = [1, 0.5, 0.1] arg_dict["dtype"] = ["double", "float"] for arg in GenArgList(arg_dict): _test_exponential(test_case, *arg[0:]) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_tensor_indexing.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest from oneflow.test_utils.test_util import GenArgList from collections import OrderedDict from oneflow.test_utils.automated_test_util import * import numpy as np import oneflow as flow import oneflow.unittest def _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar): x = flow.Tensor(numpy_x) # basic_slice test_case.assertTrue(np.allclose(numpy_x[np_scalar(1)], x[np_scalar(1)].numpy())) test_case.assertTrue(np.allclose(numpy_x[np_scalar(-2)], x[np_scalar(-2)].numpy())) test_case.assertTrue( np.allclose( numpy_x[np_scalar(0), np_scalar(1)], x[np_scalar(0), np_scalar(1)].numpy() ) ) test_case.assertTrue( np.allclose( numpy_x[(np_scalar(0), np_scalar(1))], x[(np_scalar(0), np_scalar(1))].numpy(), ) ) test_case.assertTrue( np.allclose( numpy_x[((np_scalar(0), np_scalar(1)))], x[((np_scalar(0), np_scalar(1)))].numpy(), ) ) def _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar): x = flow.Tensor(numpy_x) # advance indexing test_case.assertTrue( np.allclose( numpy_x[[np_scalar(0), np_scalar(1)]], x[[np_scalar(0), np_scalar(1)]].numpy(), ) ) test_case.assertTrue( np.allclose( numpy_x[[np_scalar(0), np_scalar(1)], [np_scalar(1), np_scalar(0)]], x[[np_scalar(0), np_scalar(1)], [np_scalar(1), np_scalar(0)]].numpy(), ) ) test_case.assertTrue( np.allclose( numpy_x[ [np_scalar(0), np_scalar(1)], [np_scalar(0), np_scalar(1)], [np_scalar(1), np_scalar(0)], ], x[ [np_scalar(0), np_scalar(1)], [np_scalar(0), np_scalar(1)], [np_scalar(1), np_scalar(0)], ].numpy(), ) ) def _test_basic_slice(test_case, numpy_x): x = flow.tensor(numpy_x) test_case.assertTrue(np.allclose(numpy_x[1], x[1].numpy())) test_case.assertTrue(np.allclose(numpy_x[-2], x[-2].numpy())) test_case.assertTrue(np.allclose(numpy_x[0, 1], x[0, 1].numpy())) test_case.assertTrue(np.allclose(numpy_x[(0, 1)], x[(0, 1)].numpy())) test_case.assertTrue(np.allclose(numpy_x[((0, 1))], x[((0, 1))].numpy())) test_case.assertTrue(np.allclose(numpy_x[None], x[None].numpy())) test_case.assertTrue(np.allclose(numpy_x[True], x[True].numpy())) test_case.assertTrue(np.allclose(numpy_x[1, None], x[1, None].numpy())) test_case.assertTrue(np.allclose(numpy_x[1, None, 1], x[1, None, 1].numpy())) test_case.assertTrue( np.allclose(numpy_x[1, None, None, 1], x[1, None, None, 1].numpy()) ) test_case.assertTrue(np.allclose(numpy_x[:], x[:].numpy())) test_case.assertTrue(np.allclose(numpy_x[:1], x[:1].numpy())) test_case.assertTrue(np.allclose(numpy_x[0:1], x[0:1].numpy())) test_case.assertTrue(np.allclose(numpy_x[-2:-1], x[-2:-1].numpy())) test_case.assertTrue(np.allclose(numpy_x[2:100:200], x[2:100:200].numpy())) test_case.assertTrue(np.allclose(numpy_x[0:2, ...], x[0:2, ...].numpy())) test_case.assertTrue(np.allclose(numpy_x[0:2, ..., 1], x[0:2, ..., 1].numpy())) test_case.assertTrue( np.allclose(numpy_x[0:2, ..., 1, 1], x[0:2, ..., 1, 1].numpy()) ) test_case.assertTrue(np.allclose(numpy_x[0:4:2, ...], x[0:4:2, ...].numpy())) test_case.assertTrue( np.allclose(numpy_x[0:2, None, ..., True], x[0:2, None, ..., True].numpy()) ) test_case.assertTrue( np.allclose(numpy_x[None, ..., 0:4:2, True], x[None, ..., 0:4:2, True].numpy()) ) 
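# Scalar bool indices act as 0-d masks: True inserts a new leading dim of size 1,
# while False inserts a new leading dim of size 0 (an empty result).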
test_case.assertTrue(np.allclose(numpy_x[False, ...], x[False, ...].numpy())) test_case.assertTrue( np.allclose(numpy_x[False, True, ...], x[False, True, ...].numpy()) ) test_case.assertTrue( np.allclose(numpy_x[True, ..., False, True], x[True, ..., False, True].numpy()) ) test_case.assertTrue( np.allclose( numpy_x[True, None, ..., False, True], x[True, None, ..., False, True].numpy(), ) ) test_case.assertTrue( np.allclose( numpy_x[True, 1, ..., False, True], x[True, 1, ..., False, True].numpy() ) ) # NOTE: When numpy>=1.23.0, the list of index will be seemed as basic indexing, # and tuple of index will be seemed as advanced indexing. def _test_advanced_indexing(test_case, numpy_x): x = flow.tensor(numpy_x) test_case.assertTrue(np.allclose(numpy_x[[0, 1]], x[[0, 1]].numpy())) test_case.assertTrue( np.allclose(numpy_x[[0, 1], [1, 0]], x[[0, 1], [1, 0]].numpy()) ) test_case.assertTrue( np.allclose( numpy_x[tuple([[0, 1], [0, 1], [1, 0]])], x[[[0, 1], [0, 1], [1, 0]]].numpy(), ) ) test_case.assertTrue(np.allclose(numpy_x[tuple([[0], [1]])], x[[[0], [1]]].numpy())) test_case.assertTrue( np.allclose( numpy_x[tuple([[[0], [1]], [[0], [1]], [0, 1]])], x[[[[0], [1]], [[0], [1]], [0, 1]]].numpy(), ) ) test_case.assertTrue( np.allclose( numpy_x[tuple([[[0, 1], [1, 1]], [[0, 0], [1, 1]], [0, 1]])], x[[[[0, 1], [1, 1]], [[0, 0], [1, 1]], [0, 1]]].numpy(), ) ) # Tensor index test_case.assertTrue( np.allclose( numpy_x[np.array([0, 1]), np.array([1, 0])], x[flow.tensor([0, 1]), flow.tensor([1, 0])].numpy(), ) ) test_case.assertTrue( np.allclose( numpy_x[:, np.array([[0, 1], [1, 1]]), np.array([[1, 0], [1, 1]])], x[:, flow.tensor([[0, 1], [1, 1]]), flow.tensor([[1, 0], [1, 1]]),].numpy(), ) ) # mask tensor index mask = np.random.rand(numpy_x.shape[0], numpy_x.shape[1]).astype(np.float32) y = flow.tensor(mask) test_case.assertTrue(np.allclose(numpy_x[mask > 0.5], x[y > 0.5].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 0.5, 1], x[y > 0.5, 1].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 0], x[y > 0].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 0, 1], x[y > 0, 1].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 1], x[y > 1].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 1, 1], x[y > 1, 1].numpy())) mask = np.random.rand(*numpy_x.shape).astype(np.float32) y = flow.tensor(mask) test_case.assertTrue(np.allclose(numpy_x[mask > 0.5], x[y > 0.5].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 0], x[y > 0].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 1], x[y > 1].numpy())) def _test_advanced_indexing_array(test_case, numpy_x, dtype): x = flow.tensor(numpy_x) idx = np.array([0, 1], dtype=dtype) test_case.assertTrue(np.allclose(numpy_x[idx], x[idx].numpy())) idx1 = np.array([0, 1], dtype=dtype) idx2 = np.array([1, 0], dtype=dtype) test_case.assertTrue(np.allclose(numpy_x[idx1, idx2], x[idx1, idx2].numpy())) idx = np.array([[0, 1], [0, 1], [1, 0]], dtype=dtype) test_case.assertTrue(np.allclose(numpy_x[idx, :, :], x[idx, :, :].numpy())) test_case.assertTrue(np.allclose(numpy_x[idx, idx, :], x[idx, idx, :].numpy())) test_case.assertTrue(np.allclose(numpy_x[idx, idx, idx], x[idx, idx, idx].numpy())) idx1 = np.array([[1, 0, 1], [1, 1, 0]]) idx2 = np.array([[0], [1]]) test_case.assertTrue( np.allclose(numpy_x[:, idx1, :, idx2].shape, x[:, idx1, :, idx2].shape) ) test_case.assertTrue( np.allclose(numpy_x[:, idx1, 1, idx2].shape, x[:, idx1, 1, idx2].shape) ) test_case.assertTrue( np.allclose(numpy_x[idx1, :, idx2, :].shape, x[idx1, :, idx2, 
:].shape) ) test_case.assertTrue( np.allclose(numpy_x[:, idx1, idx2, :].shape, x[:, idx1, idx2, :].shape) ) def _test_combining_indexing(test_case, numpy_x): x = flow.tensor(numpy_x) test_case.assertTrue( np.allclose(numpy_x[[0, 1], 1:2, [1, 0]], x[[0, 1], 1:2, [1, 0]].numpy()) ) test_case.assertTrue( np.allclose(numpy_x[:, [0, 1], [1, 0]], x[:, [0, 1], [1, 0]].numpy()) ) test_case.assertTrue(np.allclose(numpy_x[:, [0, 1], 1], x[:, [0, 1], 1].numpy())) test_case.assertTrue( np.allclose(numpy_x[..., [0, 1], 1, [1, 0]], x[..., [0, 1], 1, [1, 0]].numpy()) ) def _test_mask_getitem(test_case, numpy_x): x = flow.tensor(numpy_x) mask = np.random.rand(*numpy_x.shape).astype(np.float32) y = flow.tensor(mask) test_case.assertTrue(np.allclose(numpy_x[mask > 0.5], x[y > 0.5].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 1.0], x[y > 1.0].numpy())) mask = np.random.rand(numpy_x.shape[0]).astype(np.float32) y = flow.tensor(mask) test_case.assertTrue(np.allclose(numpy_x[mask > 0.5], x[y > 0.5].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 1.0], x[y > 1.0].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 0.5, 1], x[y > 0.5, 1].numpy())) test_case.assertTrue(np.allclose(numpy_x[mask > 1.0, 1], x[y > 1.0, 1].numpy())) def _test_mask_setitem(test_case, numpy_x): x = flow.tensor(numpy_x) # mask tensor index mask = np.random.rand(*numpy_x.shape).astype(np.float32) y = flow.tensor(mask) # broadcast set x[y > 0.5] = 1.0 numpy_x[mask > 0.5] = 1.0 test_case.assertTrue(np.allclose(numpy_x, x.numpy())) # elementwise set update = np.random.randn((mask > 0.5).sum()).astype(np.float32) tensor_update = flow.tensor(update) x[y > 0.5] = tensor_update numpy_x[mask > 0.5] = update test_case.assertTrue(np.allclose(numpy_x, x.numpy())) # empty mask x[y > 1.0] = 1.0 numpy_x[mask > 1.0] = 1.0 test_case.assertTrue(np.allclose(numpy_x, x.numpy())) def _test_list_indexing_using_scalar_tensor(test_case, dtype): y = np.random.randint(0, 100, size=100) for i in range(len(y)): x = flow.tensor(i, dtype=dtype) test_case.assertEqual(y[i], y[x]) @flow.unittest.skip_unless_1n1d() class TestTensorIndexing(flow.unittest.TestCase): def test_basic_slice(test_case): numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) _test_basic_slice(test_case, numpy_x) numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) _test_basic_slice(test_case, numpy_x) numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) _test_basic_slice(test_case, numpy_x) def test_advanced_indexing(test_case): numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) _test_advanced_indexing(test_case, numpy_x) numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) _test_advanced_indexing(test_case, numpy_x) numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) _test_advanced_indexing(test_case, numpy_x) def test_advanced_indexing_array(test_case): numpy_x = np.arange(0, 60, 1).reshape([3, 2, 2, 5]).astype(np.float32) _test_advanced_indexing_array(test_case, numpy_x, np.int32) _test_advanced_indexing_array(test_case, numpy_x, np.int64) numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) _test_advanced_indexing_array(test_case, numpy_x, np.int32) _test_advanced_indexing_array(test_case, numpy_x, np.int64) numpy_x = np.arange(0, 720, 1).reshape([5, 8, 9, 2]).astype(np.float32) _test_advanced_indexing_array(test_case, numpy_x, np.int32) _test_advanced_indexing_array(test_case, numpy_x, np.int64) def test_combining_indexing(test_case): numpy_x 
= np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) _test_combining_indexing(test_case, numpy_x) numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) _test_combining_indexing(test_case, numpy_x) numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) _test_combining_indexing(test_case, numpy_x) def test_numpy_scalar_indexing(test_case): for np_scalar in [np.int8, np.int16, np.int32, np.int64]: numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar) numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar) numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar) # TODO: add np.int16 when advance indexing supports np.int16 mapping for np_scalar in [np.int32, np.int64]: numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar) numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar) numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar) def test_mask_getitem(test_case): numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) _test_mask_getitem(test_case, numpy_x) numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) _test_mask_getitem(test_case, numpy_x) numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) _test_mask_getitem(test_case, numpy_x) numpy_x = np.arange(0, 27, 1).reshape(3, 3, 3) x = flow.tensor(numpy_x) test_case.assertTrue( np.allclose( numpy_x[[False, True, False], 1], x[[False, True, False], 1].numpy() ) ) test_case.assertTrue( np.allclose( numpy_x[[False, True, False], [True, False, False]], x[[False, True, False], [True, False, False]].numpy(), ) ) def test_mask_setitem(test_case): numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) _test_mask_setitem(test_case, numpy_x) numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) _test_mask_setitem(test_case, numpy_x) numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) _test_mask_setitem(test_case, numpy_x) def test_combined_mask_setitem(test_case): np_in = np.random.rand(5, 4, 3, 2) np_mask_dim1 = np.array([False, True, False, True]) np_mask_dim3 = np.array([True, False]) np_update = np.random.rand(2, 5, 3) np_in[:, np_mask_dim1, :, np_mask_dim3] = np_update flow_in = flow.tensor(np_in) flow_mask_dim1 = flow.tensor(np_mask_dim1) flow_mask_dim3 = flow.tensor(np_mask_dim3) flow_update = flow.tensor(np_update) flow_in[:, flow_mask_dim1, :, flow_mask_dim3] = flow_update test_case.assertTrue(np.array_equal(flow_in.numpy(), np_in)) def test_non_contiguous_combined_mask_setitem(test_case): np_in = np.random.rand(5, 4, 3, 2) np_mask_dim1 = np.array([False, True, False]) np_mask_dim3 = np.array([True, False, False, True, True]) np_update = np.random.rand(4, 2, 3) flow_in = flow.tensor(np_in).permute(3, 2, 1, 0) # (2, 3, 4, 5) flow_mask_dim1 = flow.tensor(np_mask_dim1) flow_mask_dim3 = flow.tensor(np_mask_dim3) flow_update = flow.tensor(np_update).permute(2, 1, 0) # (3, 2, 4) flow_in[:, flow_mask_dim1, :, flow_mask_dim3] = flow_update np_in = np_in.transpose(3, 2, 1, 0) np_update = np_update.transpose(2, 1, 0) np_in[:, np_mask_dim1, :, np_mask_dim3] = 
np_update test_case.assertTrue(np.array_equal(flow_in.numpy(), np_in)) def test_combined_indexing_setitem(test_case): np_in = np.random.rand(2, 3, 4) np_in[[0, 1], 1:2, [0, 1]] = 1.0 flow_in = flow.tensor(np_in) flow_in[[0, 1], 1:2, [0, 1]] = 1.0 test_case.assertTrue(np.array_equal(flow_in.numpy(), np_in)) def test_expand_dim_setitem(test_case): a = flow.tensor(1.0) a[True, ...] = 0.0 test_case.assertTrue(np.array_equal(a.numpy(), 0.0)) a = flow.tensor(1.0) a[False, ...] = 1.0 test_case.assertTrue(np.array_equal(a.numpy(), 1.0)) def test_advanced_indexing_with_scalar_index(test_case): index = flow.tensor([0, 2]) x = flow.randn(5) x[index[0]] = 1 test_case.assertTrue(np.allclose(x[0].numpy(), 1)) def test_list_indexing_using_scalar_tensor(test_case): arg_dict = OrderedDict() arg_dict["function_test"] = [ _test_list_indexing_using_scalar_tensor, ] arg_dict["dtype"] = [flow.uint8, flow.int8, flow.int32, flow.int64] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(n=3, auto_backward=False) def test_advanced_indexing_with_0_size_tensor(test_case): device = random_device() data = torch.arange(8).reshape(2, 2, 2).to(device) ranges = [] ranges.append(torch.ones(0, 1).to(torch.int64)) ranges.append(torch.zeros(1, 3).to(torch.int64)) res = data[ranges] return res @autotest(n=1) def test_dataloader_indexing_with_1_dim_tensor(test_case): device = random_device() x = random_tensor(ndim=1, dim0=512).to(device) batch_data = list() for i in range(512): batch_data.append(x[i]) return torch.stack(batch_data) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_indecies_on_different_devices(test_case): x = flow.ones(3, 10) y = flow.ones(3, 10, device=flow.device("cuda:0")) x_idx = [flow.tensor([1, 2]), flow.tensor([2, 0], device=flow.device("cuda:0"))] y_idx = [flow.tensor([1, 2], device=flow.device("cuda:0")), flow.tensor([2, 0])] test_case.assertTrue(np.allclose(x[x_idx].numpy(), np.array([1, 1]))) test_case.assertTrue(np.allclose(y[y_idx].numpy(), np.array([1, 1]))) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestTensorIndexingMultiGpu(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_indecies_on_different_devices(test_case): x = flow.ones(3, 10, device=flow.device("cuda:0")) idx = [flow.tensor([1, 2], device=flow.device("cuda:1")), flow.tensor([2, 0])] test_case.assertTrue(np.allclose(x[idx].numpy(), np.array([1, 1]))) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_tensor_indexing2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # This test code is referenced from: https://github.com/pytorch/pytorch/blob/cd41c8f032dd06c445bf97fc76fb82008b19afcb/test/test_indexing.py from collections import OrderedDict import random from random import randrange import unittest import numpy as np import oneflow as flow from oneflow.test_utils.test_util import GenArgDict import oneflow.unittest def _assert_tensor_equal(test_case, tensor1, tensor2, atol=0.0, rtol=0.0): test_case.assertTrue(np.allclose(tensor1.numpy(), tensor2.numpy())) def consec(size, start=1): """ Generate a arithmetic progression with given size and start value. """ sequence = flow.ones([int(np.array(size).prod(0)),]).cumsum(0) sequence.add_(start - 1) return sequence.view(*size) def _test_basic_slice(test_case, device, dtype): reference = consec((3, 3, 3)).to(device=device, dtype=dtype) # empty tensor indexing _assert_tensor_equal( test_case, reference[flow.LongTensor().to(device)], flow.empty(0, 3, 3), atol=0, rtol=0, ) _assert_tensor_equal(test_case, reference[0], consec((3, 3)), atol=0, rtol=0) _assert_tensor_equal(test_case, reference[1], consec((3, 3), 10), atol=0, rtol=0) _assert_tensor_equal(test_case, reference[2], consec((3, 3), 19), atol=0, rtol=0) _assert_tensor_equal(test_case, reference[0, 1], consec((3,), 4), atol=0, rtol=0) _assert_tensor_equal(test_case, reference[0:2], consec((2, 3, 3)), atol=0, rtol=0) test_case.assertEqual(reference[2, 2, 2].item(), 27) _assert_tensor_equal(test_case, reference[:], consec((3, 3, 3)), atol=0, rtol=0) # indexing with Ellipsis _assert_tensor_equal( test_case, reference[..., 2], flow.tensor([[3.0, 6.0, 9.0], [12.0, 15.0, 18.0], [21.0, 24.0, 27.0]]), atol=0, rtol=0, ) _assert_tensor_equal( test_case, reference[0, ..., 2], flow.tensor([3.0, 6.0, 9.0]), atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[..., 2], reference[:, :, 2], atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference[0, 2, ...], reference[0, 2], atol=0, rtol=0 ) test_case.assertEqual(reference[..., 2, 2, 2].item(), 27) test_case.assertEqual(reference[2, ..., 2, 2].item(), 27) test_case.assertEqual(reference[2, 2, ..., 2].item(), 27) test_case.assertEqual(reference[2, 2, 2, ...].item(), 27) _assert_tensor_equal(test_case, reference[...], reference, atol=0, rtol=0) reference_5d = consec((3, 3, 3, 3, 3)).to(device) _assert_tensor_equal( test_case, reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0 ) _assert_tensor_equal( test_case, reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0, ) _assert_tensor_equal( test_case, reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0, ) _assert_tensor_equal(test_case, reference_5d[...], reference_5d, atol=0, rtol=0) # LongTensor indexing reference = consec((5, 5, 5)).to(device=device, dtype=dtype) idx = flow.LongTensor([2, 4]).to(device) _assert_tensor_equal( test_case, reference[idx], flow.stack([reference[2], reference[4]]) ) # None indexing _assert_tensor_equal(test_case, reference[2, None], reference[2].unsqueeze(0)) _assert_tensor_equal( test_case, reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0) ) _assert_tensor_equal(test_case, reference[2:4, None], reference[2:4].unsqueeze(1)) _assert_tensor_equal( test_case, reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0), ) _assert_tensor_equal( test_case, reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2), ) # 
indexing 0-length slice _assert_tensor_equal(test_case, flow.empty(0, 5, 5), reference[slice(0)]) _assert_tensor_equal(test_case, flow.empty(0, 5), reference[slice(0), 2]) _assert_tensor_equal(test_case, flow.empty(0, 5), reference[2, slice(0)]) _assert_tensor_equal(test_case, flow.tensor([]), reference[2, 1:1, 2]) # indexing with step reference = consec((10, 10, 10)).to(device=device, dtype=dtype) _assert_tensor_equal( test_case, reference[1:5:2], flow.stack([reference[1], reference[3]], 0) ) _assert_tensor_equal( test_case, reference[1:6:2], flow.stack([reference[1], reference[3], reference[5]], 0), ) _assert_tensor_equal( test_case, reference[1:9:4], flow.stack([reference[1], reference[5]], 0) ) _assert_tensor_equal( test_case, reference[2:4, 1:5:2], flow.stack([reference[2:4, 1], reference[2:4, 3]], 1), ) _assert_tensor_equal( test_case, reference[3, 1:6:2], flow.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0), ) _assert_tensor_equal( test_case, reference[None, 2, 1:9:4], flow.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0), ) _assert_tensor_equal( test_case, reference[:, 2, 1:6:2], flow.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1), ) lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] tensor = flow.DoubleTensor(lst).to(device=device, dtype=dtype) for _ in range(10): idx1_start = randrange(10) idx1_end = idx1_start + randrange(1, 10 - idx1_start + 1) idx1_step = randrange(1, 8) idx1 = slice(idx1_start, idx1_end, idx1_step) if randrange(2) == 0: idx2_start = randrange(10) idx2_end = idx2_start + randrange(1, 10 - idx2_start + 1) idx2_step = randrange(1, 8) idx2 = slice(idx2_start, idx2_end, idx2_step) lst_indexed = [l[idx2] for l in lst[idx1]] tensor_indexed = tensor[idx1, idx2] else: lst_indexed = lst[idx1] tensor_indexed = tensor[idx1] _assert_tensor_equal( test_case, flow.DoubleTensor(lst_indexed).to(dtype), tensor_indexed ) test_case.assertRaises(RuntimeError, lambda: reference[1:9:0]) test_case.assertRaises(RuntimeError, lambda: reference[1:9:-1]) test_case.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) test_case.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) test_case.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) test_case.assertRaises(IndexError, lambda: reference[0.0]) test_case.assertRaises(RuntimeError, lambda: reference[0.0:2.0]) test_case.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) test_case.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) test_case.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) test_case.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) def _test_advanced_indexing(test_case, device, dtype): # pick a random valid indexer type def ri(indices): choice = random.randint(0, 2) if choice == 0: return flow.LongTensor(indices).to(device) elif choice == 1: return list(indices) else: return tuple(indices) def validate_indexing(x): _assert_tensor_equal(test_case, x[[0]], consec((1,))) _assert_tensor_equal(test_case, x[ri([0]),], consec((1,))) _assert_tensor_equal(test_case, x[ri([3]),], consec((1,), 4)) _assert_tensor_equal(test_case, x[[2, 3, 4]], consec((3,), 3)) _assert_tensor_equal(test_case, x[ri([2, 3, 4]),], consec((3,), 3)) _assert_tensor_equal( test_case, x[ri([0, 2, 4]),], flow.tensor([1, 3, 5], dtype=dtype, device=device), ) def validate_setting(x): x[[0]] = -2 _assert_tensor_equal( test_case, x[[0]], flow.tensor([-2], dtype=dtype, device=device) ) x[[0]] = -1 _assert_tensor_equal( test_case, x[ri([0]),], 
flow.tensor([-1], dtype=dtype, device=device) ) x[[2, 3, 4]] = 4 _assert_tensor_equal( test_case, x[[2, 3, 4]], flow.tensor([4, 4, 4], dtype=dtype, device=device) ) x[ri([2, 3, 4]),] = 3 _assert_tensor_equal( test_case, x[ri([2, 3, 4]),], flow.tensor([3, 3, 3], dtype=dtype, device=device), ) x[ri([0, 2, 4]),] = flow.tensor([5, 4, 3], dtype=dtype, device=device) _assert_tensor_equal( test_case, x[ri([0, 2, 4]),], flow.tensor([5, 4, 3], dtype=dtype, device=device), ) # 1d tensor and integer index setitem and getitem reference = consec((10,)).to(device=device, dtype=dtype) validate_indexing(reference) validate_setting(reference) # reference is 1 2 # 3 4 # 5 6 reference = consec((3, 2)).to(device=device, dtype=dtype) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([0])], flow.tensor([1, 3, 5], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([1])], flow.tensor([2, 4, 6], dtype=dtype, device=device), ) _assert_tensor_equal(test_case, reference[ri([0]), ri([0])], consec((1,))) _assert_tensor_equal(test_case, reference[ri([2]), ri([1])], consec((1,), 6)) _assert_tensor_equal( test_case, reference[[ri([0, 0]), ri([0, 1])]], flow.tensor([1, 2], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[[ri([0, 1, 1, 0, 2]), ri([1])]], flow.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], flow.tensor([1, 2, 3, 3], dtype=dtype, device=device), ) rows = ri([[0, 0], [1, 2]]) columns = ([0],) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[1, 1], [3, 5]], dtype=dtype, device=device), ) rows = ri([[0, 0], [1, 2]]) columns = ri([1, 0]) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[2, 1], [4, 5]], dtype=dtype, device=device), ) rows = ri([[0, 0], [1, 2]]) columns = ri([[0, 1], [1, 0]]) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[1, 2], [4, 5]], dtype=dtype, device=device), ) # setting values reference[ri([0]), ri([1])] = -1 _assert_tensor_equal( test_case, reference[ri([0]), ri([1])], flow.tensor([-1], dtype=dtype, device=device), ) reference[ri([0, 1, 2]), ri([0])] = flow.tensor( [-1, 2, -4], dtype=dtype, device=device ) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([0])], flow.tensor([-1, 2, -4], dtype=dtype, device=device), ) reference[rows, columns] = flow.tensor([[4, 6], [2, 3]], dtype=dtype, device=device) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[4, 6], [2, 3]], dtype=dtype, device=device), ) # Test non-contiguous(by transpose) reference # Transposed: [[0, 4, 8], # [1, 5, 9], # [2, 6, 10], # [3, 7, 11]] reference = flow.tensor( [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype, device=device ).T _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([0])], flow.tensor([0, 1, 2], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([1])], flow.tensor([4, 5, 6], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[ri([0]), ri([0])], flow.tensor([0], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[ri([2]), ri([1])], flow.tensor([6], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[[ri([0, 0]), ri([0, 1])]], flow.tensor([0, 4], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[[ri([0, 1, 1, 0, 3]), ri([1])]], flow.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device), ) 
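    # Paired index lists broadcast against each other, NumPy-style, which is
    # why the single-column index above pairs with every row index. A minimal
    # illustrative sketch of the same rule on a small contiguous tensor
    # (assumes only the `flow.arange` / `reshape` APIs already exercised in
    # this suite):
    demo = flow.arange(6).reshape(2, 3)
    _assert_tensor_equal(test_case, demo[[0, 1], [2, 0]], flow.tensor([2, 3]))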
_assert_tensor_equal( test_case, reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], flow.tensor([0, 4, 1, 1], dtype=dtype, device=device), ) rows = ri([[0, 0], [1, 2]]) columns = ([0],) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[0, 0], [1, 2]], dtype=dtype, device=device), ) rows = ri([[0, 0], [1, 2]]) columns = ri([1, 0]) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[4, 0], [5, 2]], dtype=dtype, device=device), ) rows = ri([[0, 0], [1, 3]]) columns = ri([[0, 1], [1, 2]]) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[0, 4], [5, 11]], dtype=dtype, device=device), ) # setting values reference[ri([0]), ri([1])] = -1 _assert_tensor_equal( test_case, reference[ri([0]), ri([1])], flow.tensor([-1], dtype=dtype, device=device), ) reference[ri([0, 1, 2]), ri([0])] = flow.tensor( [-1, 2, -4], dtype=dtype, device=device ) _assert_tensor_equal( test_case, reference[ri([0, 1, 2]), ri([0])], flow.tensor([-1, 2, -4], dtype=dtype, device=device), ) reference[rows, columns] = flow.tensor([[4, 6], [2, 3]], dtype=dtype, device=device) _assert_tensor_equal( test_case, reference[rows, columns], flow.tensor([[4, 6], [2, 3]], dtype=dtype, device=device), ) # Tests using less than the number of dims, and ellipsis # reference is 1 2 # 3 4 # 5 6 reference = consec((3, 2)).to(dtype=dtype, device=device) _assert_tensor_equal( test_case, reference[ri([0, 2]),], flow.tensor([[1, 2], [5, 6]], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[ri([1]), ...], flow.tensor([[3, 4]], dtype=dtype, device=device), ) _assert_tensor_equal( test_case, reference[..., ri([1])], flow.tensor([[2], [4], [6]], dtype=dtype, device=device), ) # verify too many indices fails with test_case.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])] # test invalid index fails reference = flow.empty(10, dtype=dtype, device=device) for err_idx in (10, -11): with test_case.assertRaisesRegex(IndexError, r"out of range"): reference[err_idx] def _test_combined_indexing(test_case, device, dtype): def tensor_indices_to_np(tensor, indices): # convert the flow Tensor to a numpy array tensor = tensor.to(device="cpu") npt = tensor.numpy() # convert indices idxs = tuple( i.tolist() if isinstance(i, flow.LongTensor) else i for i in indices ) return npt, idxs def get_numpy(tensor, indices): npt, idxs = tensor_indices_to_np(tensor, indices) # index and return as a flow Tensor return flow.tensor(npt[idxs], dtype=dtype, device=device) def set_numpy(tensor, indices, value): if not isinstance(value, int): if device != "cpu": value = value.cpu() value = value.numpy() npt, idxs = tensor_indices_to_np(tensor, indices) npt[idxs] = value return npt def assert_get_eq(tensor, indexer): _assert_tensor_equal(test_case, tensor[indexer], get_numpy(tensor, indexer)) def assert_set_eq(tensor, indexer, val): pyt = tensor.clone() np_ref = tensor.clone() pyt[indexer] = val np_ref = flow.tensor( set_numpy(np_ref, indexer, val), dtype=dtype, device=device ) _assert_tensor_equal(test_case, pyt, np_ref) def assert_backward_eq(tensor, indexer): cpu = tensor.cpu().float().clone().detach().requires_grad_(True) outcpu = cpu[indexer] grad = flow.rand(outcpu.shape) outcpu.backward(grad) dev = cpu.to(device).detach().requires_grad_(True) outdev = dev[indexer] outdev.backward(grad.to(device)) _assert_tensor_equal(test_case, cpu.grad, dev.grad) def get_set_tensor(indexed, indexer): set_size = indexed[indexer].size() set_count = indexed[indexer].numel() set_tensor = 
flow.randperm(set_count).view(set_size).to(dtype).to(device) return set_tensor # Tensor is 0 1 2 3 4 # 5 6 7 8 9 # 10 11 12 13 14 # 15 16 17 18 19 reference = flow.arange(0.0, 20, device=device).to(dtype).view(4, 5) indices_to_test = [ # grab the second, fourth columns [slice(None), [1, 3]], # first, third rows, [[0, 2], slice(None)], # TODO(wyg): only support getitem but not setitem # # weird shape # [slice(None), [[0, 1], # [2, 3]]], # negatives [[-1], [0]], [[0, 2], [-1]], [slice(None), [-1]], ] # test getitem get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] get_indices_to_test = indices_to_test + [ [slice(None), [[0, 1], [2, 3]]] ] # TODO: test setitem for indexer in get_indices_to_test: assert_get_eq(reference, indexer) if device != "cpu": assert_backward_eq(reference, indexer) # test setitem for indexer in indices_to_test: assert_set_eq(reference, indexer, 44) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) ######################### # test more dims tensor # ######################### reference = flow.arange(0.0, 160, device=device).to(dtype).view(4, 8, 5) indices_to_test = [ [slice(None), slice(None), [0, 3, 4]], [slice(None), [2, 4, 5, 7], slice(None)], [[2, 3], slice(None), slice(None)], [slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), [0], [1, 2, 4]], [slice(None), [0, 1, 3], [4]], [slice(None), [[0, 1], [1, 0]], [[2, 3]]], [slice(None), [[0, 1], [2, 3]], [[0]]], [slice(None), [[5, 6]], [[0, 3], [4, 4]]], [[0, 2, 3], [1, 3, 4], slice(None)], [[0], [1, 2, 4], slice(None)], [[0, 1, 3], [4], slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], [[[0, 1], [1, 0]], [[2, 3]], slice(None)], [[[0, 1], [2, 3]], [[0]], slice(None)], [[[2, 1]], [[0, 3], [4, 4]], slice(None)], [[[2]], [[0, 3], [4, 1]], slice(None)], # non-contiguous indexing subspace [[0, 2, 3], slice(None), [1, 3, 4]], # less dim, ellipsis [[0, 2],], [[0, 2], slice(None)], [[0, 2], Ellipsis], [[0, 2], slice(None), Ellipsis], [[0, 2], Ellipsis, slice(None)], [[0, 2], [1, 3]], [[0, 2], [1, 3], Ellipsis], [Ellipsis, [1, 3], [2, 3]], [Ellipsis, [2, 3, 4]], [Ellipsis, slice(None), [2, 3, 4]], [slice(None), Ellipsis, [2, 3, 4]], # ellipsis counts for nothing [Ellipsis, slice(None), slice(None), [0, 3, 4]], [slice(None), Ellipsis, slice(None), [0, 3, 4]], [slice(None), slice(None), Ellipsis, [0, 3, 4]], [slice(None), slice(None), [0, 3, 4], Ellipsis], [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 212) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) if device != "cpu": assert_backward_eq(reference, indexer) reference = flow.arange(0.0, 1296, device=device).to(dtype).view(3, 9, 8, 6) indices_to_test = [ [slice(None), slice(None), slice(None), [0, 3, 4]], [slice(None), slice(None), [2, 4, 5, 7], slice(None)], [slice(None), [2, 3], slice(None), slice(None)], [[1, 2], slice(None), slice(None), slice(None)], [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), slice(None), [0], [1, 2, 4]], [slice(None), slice(None), [0, 1, 3], [4]], [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], [slice(None), [0], [1, 2, 4], slice(None)], [slice(None), [0, 1, 3], 
[4], slice(None)], [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], [[0], [1, 2, 4], slice(None), slice(None)], [[0, 1, 2], [4], slice(None), slice(None)], [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], [slice(None), [2, 3, 4], [1, 3, 4], [4]], [slice(None), [0, 1, 3], [4], [1, 3, 4]], [slice(None), [6], [0, 2, 3], [1, 3, 4]], [slice(None), [2, 3, 5], [3], [4]], [slice(None), [0], [4], [1, 3, 4]], [slice(None), [6], [0, 2, 3], [1]], [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], [[2, 0, 1], [1, 2, 3], [4], slice(None)], [[0, 1, 2], [4], [1, 3, 4], slice(None)], [[0], [0, 2, 3], [1, 3, 4], slice(None)], [[0, 2, 1], [3], [4], slice(None)], [[0], [4], [1, 3, 4], slice(None)], [[1], [0, 2, 3], [1], slice(None)], [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], # less dim, ellipsis [Ellipsis, [0, 3, 4]], [Ellipsis, slice(None), [0, 3, 4]], [Ellipsis, slice(None), slice(None), [0, 3, 4]], [slice(None), Ellipsis, [0, 3, 4]], [slice(None), slice(None), Ellipsis, [0, 3, 4]], [slice(None), [0, 2, 3], [1, 3, 4]], [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], [[0], [1, 2, 4]], [[0], [1, 2, 4], slice(None)], [[0], [1, 2, 4], Ellipsis], [[0], [1, 2, 4], Ellipsis, slice(None)], [[1],], [[0, 2, 1], [3], [4]], [[0, 2, 1], [3], [4], slice(None)], [[0, 2, 1], [3], [4], Ellipsis], [Ellipsis, [0, 2, 1], [3], [4]], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 1333) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) indices_to_test += [ [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], ] for indexer in indices_to_test: assert_get_eq(reference, indexer) assert_set_eq(reference, indexer, 1333) if device != "cpu": assert_backward_eq(reference, indexer) def _test_single_int(test_case, device): v = flow.randn(5, 7, 3, device=device) test_case.assertEqual(v[4].shape, (7, 3)) def _test_multiple_int(test_case, device): v = flow.randn(5, 7, 3, device=device) test_case.assertEqual(v[4].shape, (7, 3)) test_case.assertEqual(v[4, :, 1].shape, (7,)) def _test_none(test_case, device): v = flow.randn(5, 7, 3, device=device) test_case.assertEqual(v[None].shape, (1, 5, 7, 3)) test_case.assertEqual(v[:, None].shape, (5, 1, 7, 3)) test_case.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3)) test_case.assertEqual(v[..., None].shape, (5, 7, 3, 1)) def _test_step(test_case, device): v = flow.arange(10, device=device) _assert_tensor_equal(test_case, v[::1], v) test_case.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8]) test_case.assertEqual(v[::3].tolist(), [0, 3, 6, 9]) test_case.assertEqual(v[::11].tolist(), [0]) test_case.assertEqual(v[1:6:2].tolist(), [1, 3, 5]) def _test_step_assignment(test_case, device): v = flow.zeros(4, 4, device=device) v[0, 1::2] = flow.tensor([3.0, 4.0], device=device) test_case.assertEqual(v[0].tolist(), [0.0, 3.0, 0.0, 4.0]) 
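    # Only the positions selected by the stepped slice are written; the rest
    # of the tensor keeps its old values. A minimal illustrative sketch on a
    # 1-D tensor (assumes scalar assignment through a stepped slice, as
    # exercised elsewhere in these tests):
    demo = flow.zeros(6, device=device)
    demo[::2] = 1.0
    test_case.assertEqual(demo.tolist(), [1.0, 0.0, 1.0, 0.0, 1.0, 0.0])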
test_case.assertEqual(v[1:].sum(), 0) def _test_bool_indices(test_case, device): v = flow.randn(5, 7, 3, device=device) boolIndices = flow.tensor( [True, False, True, True, False], dtype=flow.bool, device=device ) test_case.assertEqual(v[boolIndices].shape, (3, 7, 3)) _assert_tensor_equal(test_case, v[boolIndices], flow.stack([v[0], v[2], v[3]])) v = flow.tensor([True, False, True], dtype=flow.bool, device=device) boolIndices = flow.tensor([True, False, False], dtype=flow.bool, device=device) uint8Indices = flow.tensor([1, 0, 0], dtype=flow.uint8, device=device) test_case.assertEqual(v[boolIndices].shape, v[uint8Indices].shape) test_case.assertEqual(v[boolIndices], v[uint8Indices]) test_case.assertEqual( v[boolIndices], flow.tensor([True], dtype=flow.bool, device=device) ) def _test_multiple_bool_indices(test_case, device): v = flow.randn(5, 7, 3, device=device) # NOTE: these broadcast together and are transposed to the first dim mask1 = flow.tensor([1, 0, 1, 1, 0], dtype=flow.bool, device=device) mask2 = flow.tensor([1, 1, 1], dtype=flow.bool, device=device) test_case.assertEqual(v[mask1, :, mask2].shape, (3, 7)) def _test_int_indices(test_case, device): v = flow.randn(5, 7, 3, device=device) test_case.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3)) test_case.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3)) test_case.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3)) def _test_int_indices2d(test_case, device): x = flow.arange(0, 12, device=device).view(4, 3) rows = flow.tensor([[0, 0], [3, 3]], device=device) columns = flow.tensor([[0, 2], [0, 2]], device=device) test_case.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]]) def _test_int_indices_broadcast(test_case, device): x = flow.arange(0, 12, device=device).view(4, 3) rows = flow.tensor([0, 3], device=device) columns = flow.tensor([0, 2], device=device) result = x[rows[:, None], columns] test_case.assertEqual(result.tolist(), [[0, 2], [9, 11]]) def _test_empty_index(test_case, device): x = flow.arange(0, 12, device=device).view(4, 3) idx = flow.tensor([], dtype=flow.long, device=device) test_case.assertEqual(x[idx].numel(), 0) # empty assignment should have no effect but not throw an exception y = x.clone() y[idx] = -1 _assert_tensor_equal(test_case, x, y) mask = flow.zeros(4, 3, device=device).to(flow.bool) y[mask] = -1 _assert_tensor_equal(test_case, x, y) def _test_empty_ndim_index(test_case, device): x = flow.randn(5, device=device) _assert_tensor_equal( test_case, flow.empty(0, 2, device=device), x[flow.empty(0, 2, dtype=flow.int64, device=device)], ) x = flow.randn(2, 3, 4, 5, device=device) _assert_tensor_equal( test_case, flow.empty(2, 0, 6, 4, 5, device=device), x[:, flow.empty(0, 6, dtype=flow.int64, device=device)], ) x = flow.empty(10, 0, device=device) test_case.assertEqual(x[[1, 2]].shape, (2, 0)) test_case.assertEqual(x[[], []].shape, (0,)) test_case.assertEqual(x[[[]]].shape, (0, 0)) test_case.assertEqual(x[[[[]]]].shape, (1, 0, 0)) test_case.assertEqual(x[[1], []].shape, (0,)) test_case.assertEqual(x[[], [2]].shape, (0,)) with test_case.assertRaisesRegex(IndexError, "for dimension with size 0"): x[:, [0, 1]] def _test_empty_ndim_index_bool(test_case, device): x = flow.randn(5, device=device) test_case.assertRaises( IndexError, lambda: x[flow.empty(0, 2, dtype=flow.uint8, device=device)] ) def _test_empty_slice(test_case, device): x = flow.randn(2, 3, 4, 5, device=device) y = x[:, :, :, 1] z = y[:, 1:1, :] test_case.assertEqual((2, 0, 4), z.shape) # this isn't technically necessary, but matches NumPy stride 
calculations. test_case.assertEqual((60, 20, 5), z.stride()) test_case.assertTrue(z.is_contiguous()) def _test_index_getitem_copy_bools_slices(test_case, device): true = flow.tensor(1, dtype=flow.uint8, device=device) false = flow.tensor(0, dtype=flow.uint8, device=device) tensors = [flow.randn(2, 3, device=device), flow.tensor([1.0], device=device)] # TODO: compare tensor_storage after exporting the inferface for a in tensors: # test_case.assertNotEqual(a.data_ptr(), a[True].data_ptr()) _assert_tensor_equal(test_case, flow.empty(0, *a.shape), a[False]) # test_case.assertNotEqual(a.data_ptr(), a[true].data_ptr()) _assert_tensor_equal(test_case, flow.empty(0, *a.shape), a[false]) # test_case.assertEqual(a.data_ptr(), a[None].data_ptr()) # test_case.assertEqual(a.data_ptr(), a[...].data_ptr()) def _test_setitem_scalars(test_case, device): zero = flow.tensor(0, dtype=flow.int64) # non-scalar indexed with scalars a = flow.randn(2, 3, device=device) a_set_with_number = a.clone() a_set_with_scalar = a.clone() b = flow.randn(3, device=device) a_set_with_number[0] = b a_set_with_scalar[zero] = b _assert_tensor_equal(test_case, a_set_with_number, a_set_with_scalar) a[1, zero] = 7.7 value = a[1, 0].numpy() test_case.assertEqual(np.array(7.7, dtype=value.dtype), value) np_x = np.random.rand(2, 3) np_x[0, 0] = 1.0 x = flow.tensor(np_x) x[0, 0] = 1.0 test_case.assertEqual(x.numpy().all(), np_x.all()) # scalar indexed with scalars r = flow.tensor(1.0).to(device) with test_case.assertRaises(IndexError): r[:] = 8.8 with test_case.assertRaises(IndexError): r[zero] = 8.8 r[...] = 9.9 test_case.assertEqual(r, 9.9) # scalar indexed with oneflow.Size([1]) np_x = np.random.rand(2, 3) np_x[0, 0] = np.ones(1) x = flow.tensor(np_x) x[0, 0] = flow.ones(1).to(flow.float64) test_case.assertEqual(x.numpy().all(), np_x.all()) def _test_basic_advanced_combined(test_case, device): x = flow.arange(0, 12, device=device).view(4, 3) _assert_tensor_equal(test_case, x[1:2, 1:3], x[1:2, [1, 2]]) test_case.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]]) # Check that it is a copy unmodified = x.clone() x[1:2, [1, 2]].zero_() _assert_tensor_equal(test_case, x, unmodified) # But assignment should modify the original unmodified = x.clone() x[1:2, [1, 2]] = 0 test_case.assertFalse(np.array_equal(x.numpy(), unmodified.numpy())) def _test_ellipsis_tensor(test_case, device): x = flow.arange(0, 9, device=device).view(3, 3) idx = flow.tensor([0, 2], device=device) test_case.assertEqual(x[..., idx].tolist(), [[0, 2], [3, 5], [6, 8]]) test_case.assertEqual(x[idx, ...].tolist(), [[0, 1, 2], [6, 7, 8]]) # Test scalar ellipsis getitem y = flow.tensor(1.0).to(device) x_scalar = flow.tensor(9.9) y = x_scalar[...] 
    test_case.assertEqual(y, 9.9)


@flow.unittest.skip_unless_1n1d()
class TestIndexing(flow.unittest.TestCase):
    def test_slice(test_case):
        arg_dict = OrderedDict()
        arg_dict["device"] = ["cpu", "cuda"]
        for arg in GenArgDict(arg_dict):
            dtype_list = [flow.float32, flow.float16]
            from oneflow import sysconfig
            if not sysconfig.get_cuda_version() < 11000:
                dtype_list.append(flow.bfloat16)
            for dtype in dtype_list:
                _test_basic_slice(test_case, **arg, dtype=dtype)
                _test_advanced_indexing(test_case, **arg, dtype=dtype)
                _test_combined_indexing(test_case, **arg, dtype=dtype)
            _test_single_int(test_case, **arg)
            _test_multiple_int(test_case, **arg)
            _test_none(test_case, **arg)
            _test_step(test_case, **arg)
            _test_step_assignment(test_case, **arg)
            _test_bool_indices(test_case, **arg)
            _test_multiple_bool_indices(test_case, **arg)
            _test_int_indices(test_case, **arg)
            _test_int_indices2d(test_case, **arg)
            _test_int_indices_broadcast(test_case, **arg)
            _test_empty_index(test_case, **arg)
            _test_empty_ndim_index(test_case, **arg)
            _test_empty_ndim_index_bool(test_case, **arg)
            _test_empty_slice(test_case, **arg)
            _test_index_getitem_copy_bools_slices(test_case, **arg)
            _test_setitem_scalars(test_case, **arg)
            _test_basic_advanced_combined(test_case, **arg)
            _test_ellipsis_tensor(test_case, **arg)


if __name__ == "__main__":
    unittest.main()


================================================ FILE: python/oneflow/test/tensor/test_tensor_is_view.py ================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import random
import numpy as np
from collections import OrderedDict

import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.test_util import GenArgList


def _test_is_view(test_case, device):
    shape = (2, 3, 4, 5)
    xx = flow.randn(shape, device=device)
    yy = xx.reshape(4, 5, 6)
    test_case.assertEqual(xx.is_contiguous(), yy.is_contiguous())
    test_case.assertEqual(yy.is_view(), True)
    test_case.assertEqual(xx.is_view(), False)


@flow.unittest.skip_unless_1n1d()
class TestTensorIsView(flow.unittest.TestCase):
    def test_is_view(test_case):
        arg_dict = OrderedDict()
        arg_dict["device"] = ["cuda", "cpu"]
        for arg in GenArgList(arg_dict):
            _test_is_view(test_case, *arg[0:])


if __name__ == "__main__":
    unittest.main()


================================================ FILE: python/oneflow/test/tensor/test_tensor_part_1.py ================================================
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" import copy import os import numpy as np import unittest import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestTensor(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_numpy_and_default_dtype(test_case): shape = (2, 3, 4, 5) tensor = flow.Tensor(*shape) flow.nn.init.ones_(tensor) test_case.assertTrue(tensor.dtype == flow.float32) test_case.assertTrue( np.allclose(tensor.numpy(), np.ones(shape, dtype=np.float32)) ) shape = flow.Size((2, 3, 4, 5)) tensor = flow.Tensor(shape) flow.nn.init.ones_(tensor) test_case.assertTrue(tensor.dtype == flow.float32) test_case.assertTrue( np.allclose(tensor.numpy(), np.ones(shape, dtype=np.float32)) ) shape = flow.Size((2, 3)) tensor = flow.Tensor(shape) flow.nn.init.eye_(tensor) test_case.assertTrue(tensor.dtype == flow.float32) test_case.assertTrue(np.allclose(tensor.numpy(), np.eye(2, 3))) @flow.unittest.skip_unless_1n1d() def test_tensor_deepcopy(test_case): shape = (2, 3) tensor1 = flow.ones(*shape).cuda() tensor2 = copy.deepcopy(tensor1) tensor1[0, 0] = 0 test_case.assertEqual(tensor1.device, tensor2.device) test_case.assertEqual(tensor1[0, 0], 0) test_case.assertEqual(tensor2[0, 0], 1) @flow.unittest.skip_unless_1n1d() def test_tensor_property(test_case): shape = (2, 3, 4, 5) tensor = flow.Tensor(*shape) test_case.assertEqual(tensor.storage_offset(), 0) test_case.assertEqual(tensor.stride(), (60, 20, 5, 1)) test_case.assertEqual(tensor.is_cuda, False) test_case.assertTrue(tensor.is_contiguous()) @flow.unittest.skip_unless_1n1d() def test_copy_to_and_from_numpy(test_case): np_arr = np.array([4, 6], dtype=np.float32) tensor = flow.tensor(np_arr, dtype=flow.float32) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr)) test_case.assertEqual(np.float32, tensor.numpy().dtype) np_arr = np.array([4, 6], dtype=np.int32) tensor = flow.tensor(np_arr, dtype=flow.int32) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr)) test_case.assertEqual(np.int32, tensor.numpy().dtype) np_arr = np.array([4, 6], dtype=np.float16) tensor = flow.tensor(np_arr, dtype=flow.float16) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr)) test_case.assertEqual(np.float16, tensor.numpy().dtype) @flow.unittest.skip_unless_1n1d() def test_inplace_copy_from_contiguous_numpy(test_case): np_arr = np.arange(6).reshape(3, 2) tensor = flow.zeros(3, 2).to(flow.int64) tensor.copy_(np_arr) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr)) @flow.unittest.skip_unless_1n1d() def test_inplace_copy_from_non_contiguous_numpy(test_case): np_arr = np.arange(6).reshape(2, 3).transpose(1, 0) tensor = flow.zeros(3, 2).to(flow.int64) tensor.copy_(np_arr) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr)) @flow.unittest.skip_unless_1n1d() def test_construct_from_numpy_or_list(test_case): shape = (2, 3, 4, 5) np_arr = np.random.rand(*shape).astype(np.float32) tensor = flow.tensor(np_arr) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr)) np_int_arr = np.random.randint(-100, high=100, size=shape, dtype=np.int32) tensor = flow.tensor(np_int_arr, dtype=flow.int32) test_case.assertEqual(tensor.dtype, flow.int32) test_case.assertTrue(np_arr.flags["C_CONTIGUOUS"]) test_case.assertTrue(np.allclose(tensor.numpy(), np_int_arr)) np_arr = np.random.random((1, 256, 256, 3)).astype(np.float32) np_arr = np_arr.transpose(0, 3, 1, 2) tensor = flow.tensor(np_arr) 
test_case.assertFalse(np_arr.flags["C_CONTIGUOUS"]) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr)) @flow.unittest.skip_unless_1n1d() def test_construct_from_another_tensor(test_case): shape = (2, 3, 4, 5) np_arr = np.random.rand(*shape).astype(np.float32) tensor = flow.tensor(np_arr) output = flow.tensor(tensor) test_case.assertEqual(output.dtype, flow.float32) test_case.assertTrue(np.allclose(output.numpy(), np_arr)) @flow.unittest.skip_unless_1n1d() def test_construct_np_array_from_tensor(test_case): tensor = flow.randn(5) np_arr = np.array(tensor) test_case.assertEqual(np_arr.shape, (5,)) test_case.assertEqual(np_arr.dtype, np.float32) test_case.assertTrue(np.allclose(np_arr, tensor.numpy())) test_case.assertEqual(str(np_arr), str(tensor.numpy())) @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_sign_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.sign() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_flow_tensor_gather_with_random_data(test_case): device = random_device() input = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device) dim = random(0, 4).to(int).value() index = random_tensor( ndim=4, dim1=random(1, 3).to(int), dim2=random(1, 4).to(int), dim3=random(1, 5).to(int), low=0, high=1 if dim == 0 else dim, dtype=int, ).to(device) return input.gather(dim, index) def _test_tensor_init_methods(test_case, tensor_creator, get_numpy): for dtype in [flow.float32, flow.float16]: shape = (2, 3, 4, 5) x = tensor_creator(*shape).to(dtype) np_ones = np.ones(x.shape) np_zeros = np.zeros(x.shape) random_fill_val = 2.0 x.fill_(random_fill_val) test_case.assertTrue(np.allclose(get_numpy(x), random_fill_val * np_ones)) flow.nn.init.ones_(x) test_case.assertTrue(np.allclose(get_numpy(x), np_ones)) flow.nn.init.zeros_(x) test_case.assertTrue(np.allclose(get_numpy(x), np_zeros)) flow.nn.init.constant_(x, random_fill_val) test_case.assertTrue(np.allclose(get_numpy(x), random_fill_val * np_ones)) z = tensor_creator(5, 4, 3, 2) flow.nn.init.kaiming_normal_(z, a=0.1, mode="fan_out", nonlinearity="relu") flow.nn.init.kaiming_uniform_(z) z.requires_grad_() flow.nn.init.xavier_normal_(z, flow.nn.init.calculate_gain("relu")) flow.nn.init.xavier_uniform_(z, flow.nn.init.calculate_gain("relu")) flow.nn.init.xavier_normal_( z, flow.nn.init.calculate_gain("leaky_relu", 0.2) ) flow.nn.init.xavier_uniform_( z, flow.nn.init.calculate_gain("leaky_relu", 0.2) ) flow.nn.init.trunc_normal_(z, mean=0.0, std=1.0, a=-2.0, b=2.0) flow.nn.init.normal_(z, mean=0.0, std=1.0) flow.nn.init.orthogonal_(z) x = tensor_creator(*shape).to(dtype=flow.int32) np_ones = np.ones(x.shape, dtype=np.int32) np_zeros = np.zeros(x.shape, dtype=np.int32) random_fill_val = -2 x.fill_(random_fill_val) test_case.assertTrue(np.allclose(get_numpy(x), random_fill_val * np_ones)) flow.nn.init.ones_(x) test_case.assertTrue(np.allclose(get_numpy(x), np_ones)) flow.nn.init.zeros_(x) test_case.assertTrue(np.allclose(get_numpy(x), np_zeros)) flow.nn.init.constant_(x, random_fill_val) test_case.assertTrue(np.allclose(get_numpy(x), random_fill_val * np_ones)) x.zero_() test_case.assertTrue(np.array_equal(get_numpy(x), np_zeros)) test_case.assertEqual(flow.nn.init.calculate_gain("conv2d"), 1) test_case.assertEqual(flow.nn.init.calculate_gain("tanh"), 5.0 / 3) def _test_non_contiguous_tensor_init_methods(test_case, tensor_creator, get_numpy): shape = (8, 8) x = flow.zeros(shape) sliced_x = x[::2, 1::2] not_sliced_x = x[1::2, ::2] random_fill_val = 923.53 np_zeros = 
np.zeros((4, 4)) # ones flow.nn.init.ones_(sliced_x) test_case.assertTrue(np.allclose(get_numpy(sliced_x), np.ones((4, 4)))) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # constant flow.nn.init.constant_(sliced_x, random_fill_val) test_case.assertTrue( np.allclose(get_numpy(sliced_x), np.ones((4, 4)) * random_fill_val) ) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # eye flow.nn.init.eye_(sliced_x) test_case.assertTrue(np.allclose(get_numpy(sliced_x), np.eye(4))) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # kaiming_normal_ flow.nn.init.kaiming_normal_( sliced_x, a=0.1, mode="fan_out", nonlinearity="relu" ) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # kaiming_uniform_ flow.nn.init.kaiming_uniform_(sliced_x) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # xavier_normal_ with relu gain flow.nn.init.xavier_normal_(sliced_x, flow.nn.init.calculate_gain("relu")) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # xavier_uniform_ with relu gain flow.nn.init.xavier_uniform_(sliced_x, flow.nn.init.calculate_gain("relu")) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # trunc_normal_ flow.nn.init.trunc_normal_(sliced_x, mean=0.0, std=1.0, a=-2.0, b=2.0) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # normal_ flow.nn.init.normal_(sliced_x, mean=0.0, std=1.0) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) # orthogonal_ flow.nn.init.orthogonal_(sliced_x) test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros)) @flow.unittest.skip_unless_1n1d() def test_local_tensor_init_methods(test_case): for device in ["cpu", "cuda"]: test_case._test_tensor_init_methods( lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device), lambda x: x.numpy(), ) test_case._test_non_contiguous_tensor_init_methods( lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device), lambda x: x.numpy(), ) @flow.unittest.skip_unless_1n2d() def test_global_tensor_init_methods(test_case): for device in ["cpu", "cuda"]: test_case._test_tensor_init_methods( lambda *args, **kwargs: flow.Tensor( *args, **kwargs, sbp=flow.sbp.broadcast, placement=flow.placement(device, range(2)) ), lambda x: x.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), ) @flow.unittest.skip_unless_1n1d() def test_tensor_with_single_int(test_case): x = flow.Tensor(5) test_case.assertEqual(x.shape, flow.Size([5])) x = flow.tensor(5) test_case.assertEqual(x.numpy().item(), 5) @flow.unittest.skip_unless_1n1d() def test_tensor_device(test_case): shape = (2, 3, 4, 5) x = flow.Tensor(*shape) test_case.assertTrue(not x.is_cuda) x = flow.Tensor(*shape, device=flow.device("cuda")) test_case.assertTrue(x.is_cuda) x = flow.Tensor(*shape, device=flow.device("cpu")) test_case.assertTrue(not x.is_cuda) @flow.unittest.skip_unless_1n1d() @autotest(n=1, check_graph=True) def test_tensor_set_data_autograd_meta(test_case): x = torch.ones(2, 3).requires_grad_(True) y = x + x z = torch.zeros(2, 3) z.data = y return z.grad_fn, z.is_leaf @flow.unittest.skip_unless_1n1d() def test_tensor_set_data(test_case): a = flow.ones(2, 3, requires_grad=False) b = flow.ones(4, 5, requires_grad=True).to("cuda") old_id = id(a) a.data = b test_case.assertEqual(old_id, id(a)) test_case.assertTrue(a.shape == (4, 5)) test_case.assertTrue(a.device == flow.device("cuda")) test_case.assertFalse(a.requires_grad) test_case.assertTrue(a.is_leaf) @flow.unittest.skip_unless_1n1d() 
def test_tensor_set_ref_tensor(test_case): a = flow.ones(2, 3, requires_grad=False) b = flow.ones(4, 5, requires_grad=True).to("cuda") test_case.assertEqual(a._ref_tensor, None) test_case.assertEqual(a._ref_index, 0) a._ref_tensor = b a._ref_index = 200 test_case.assertTrue(id(a._ref_tensor), id(b)) test_case.assertTrue(a._ref_tensor.shape == (4, 5)) test_case.assertTrue(a._ref_tensor.device == flow.device("cuda")) test_case.assertTrue(a._ref_tensor.requires_grad) test_case.assertTrue(a._ref_index, 200) @flow.unittest.skip_unless_1n1d() def test_tensor_unsupported_property(test_case): shape = (2, 3, 4, 5) x = flow.Tensor(*shape) test_case.assertTrue(x.is_local) with test_case.assertRaises(RuntimeError): x.global_id() with test_case.assertRaises(RuntimeError): x.sbp with test_case.assertRaises(RuntimeError): x.placement if x.dtype != flow.tensor_buffer: with test_case.assertRaises(RuntimeError): x._tensor_buffer_shapes_and_dtypes @flow.unittest.skip_unless_1n1d() def test_tensor_to_bool(test_case): x = flow.tensor([0.0]) test_case.assertFalse(bool(x)) x = flow.tensor([0.0]).to("cuda") test_case.assertFalse(bool(x)) x = flow.tensor([1.5]) test_case.assertTrue(bool(x)) x = flow.tensor([3]) test_case.assertTrue(bool(x)) with test_case.assertRaises(RuntimeError): bool(flow.tensor([1, 3, 5])) bool(flow.tensor([])) @flow.unittest.skip_unless_1n1d() def test_tensor_autograd_fill_cpu(test_case): shape = (2, 3, 4, 5) x = flow.Tensor(*shape) y = flow.Tensor(*shape) x.fill_(1.0) y.fill_(flow.tensor(1.0)) y.requires_grad = True z = x + y test_case.assertFalse(x.requires_grad) test_case.assertTrue(x.is_leaf) test_case.assertTrue(y.requires_grad) test_case.assertTrue(y.is_leaf) test_case.assertTrue(z.requires_grad) test_case.assertFalse(z.is_leaf) with flow.no_grad(): m = x + y test_case.assertTrue(m.is_leaf) test_case.assertFalse(m.requires_grad) m.requires_grad = True v = flow.Tensor(*shape) v.requires_grad = True z.retain_grad() w = v + z grad = flow.Tensor(*shape) grad.fill_(1.0) w.backward(gradient=grad, retain_graph=True) test_case.assertTrue( np.allclose(v.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(y.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(z.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4) ) test_case.assertIsNone(x.grad) test_case.assertIsNotNone(y.grad) w.backward(gradient=grad, retain_graph=True) # autocast test for fill_ x = flow.tensor([2.4, 3.5], device="cuda", dtype=flow.float16) with flow.amp.autocast("cuda", flow.float16): y = x.clone() y.fill_(2.36) test_case.assertTrue(y.dtype == flow.float16) @flow.unittest.skip_unless_1n1d() def test_tensor_autograd_fill_cuda(test_case): shape = (2, 3, 4, 5) x = flow.Tensor(*shape).to("cuda:0") y = flow.Tensor(*shape).to("cuda:0") x.fill_(1.0) y.fill_(flow.tensor(1.0).to("cuda:0")) y.requires_grad = True z = x + y test_case.assertFalse(x.requires_grad) test_case.assertTrue(x.is_leaf) test_case.assertTrue(y.requires_grad) test_case.assertTrue(y.is_leaf) test_case.assertTrue(z.requires_grad) test_case.assertFalse(z.is_leaf) with flow.no_grad(): m = x + y test_case.assertTrue(m.is_leaf) test_case.assertFalse(m.requires_grad) m.requires_grad = True v = flow.Tensor(*shape).to("cuda:0") v.requires_grad = True z.retain_grad() w = v + z grad = flow.Tensor(*shape) grad.fill_(1.0) w.backward(gradient=grad, retain_graph=True) test_case.assertTrue( np.allclose(v.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(y.grad.numpy(), 
np.ones(shape), atol=1e-4, rtol=1e-4) ) test_case.assertTrue( np.allclose(z.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4) ) test_case.assertIsNone(x.grad) test_case.assertIsNotNone(y.grad) w.backward(gradient=grad, retain_graph=True) @flow.unittest.skip_unless_1n1d() def test_tensor_register_post_grad_accumulation_hook(test_case): shape = (2, 3) x = flow.Tensor(*shape) x.requires_grad = True x._register_post_grad_accumulation_hook(lambda grad: grad * 2 + 1) y = x.sum() + (x * 2).sum() y.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones(shape) * 7, atol=1e-4, rtol=1e-4) ) x = flow.Tensor(*shape) x.requires_grad = True def inplace_add_and_return_none(x): x.add_(1) return None x._register_post_grad_accumulation_hook(inplace_add_and_return_none) y = x.sum() + (x * 2).sum() y.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones(shape) * 4, atol=1e-4, rtol=1e-4) ) @flow.unittest.skip_unless_1n1d() def test_tensor_register_hook(test_case): shape = (2, 3) x = flow.Tensor(*shape) x.requires_grad = True x.register_hook(lambda grad: grad * 2 + 1) y = x.sum() + (x * 2).sum() y.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones(shape) * 7, atol=1e-4, rtol=1e-4) ) x = flow.Tensor(*shape) x.requires_grad = True new_grad = flow.Tensor([[1, 2, 3], [4, 5, 6]]) x.register_hook(lambda _: new_grad) y = x.sum() + (x * 2).sum() y.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), new_grad.numpy())) grad_nonlocal = None def assign_nonlocal_variable_and_return_none(grad): nonlocal grad_nonlocal grad_nonlocal = grad x = flow.Tensor(*shape) x.requires_grad = True new_grad = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.float32) x.register_hook(assign_nonlocal_variable_and_return_none) y = x.sum() + (x * 2).sum() y.backward() test_case.assertTrue(np.allclose(grad_nonlocal.numpy(), np.ones(shape) * 3)) @flow.unittest.skip_unless_1n1d() def test_non_leaf_tensor_register_hook(test_case): shape = (2, 3) x = flow.Tensor(*shape).requires_grad_() y = x + 1 y.register_hook(lambda grad: grad * 2) z1 = y * 2 z2 = y * 3 loss = (z1 + z2).sum() loss.backward(retain_graph=True) loss.backward() test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(shape) * 20)) @flow.unittest.skip_unless_1n1d() def test_user_defined_data(test_case): list_data = [5, 5] tuple_data = (5, 5) numpy_data = np.array((5, 5)) x = flow.Tensor(list_data) y = flow.Tensor(tuple_data) z = flow.Tensor(numpy_data) test_case.assertTrue(np.allclose(x.numpy(), 5 * np.ones(x.shape))) test_case.assertTrue(np.allclose(y.numpy(), 5 * np.ones(y.shape))) test_case.assertTrue(np.allclose(z.numpy(), 5 * np.ones(z.shape))) @flow.unittest.skip_unless_1n1d() def test_local_tensor_and_op(test_case): x1 = flow.Tensor([[1.0, 2.0]]) test_case.assertEqual(x1.dtype, flow.float32) test_case.assertEqual(x1.shape, flow.Size((1, 2))) x2 = flow.Tensor([[1.0], [2.0]]) y = flow.matmul(x1, x2) test_case.assertTrue( np.allclose(y.numpy(), np.array([[5.0]], dtype=np.float32)) ) @flow.unittest.skip_unless_1n1d() @autotest(n=5, rtol=1e-2, atol=1e-3) def test_matmul_with_random_data(test_case): device = random_device() dim0 = random(low=2, high=10).to(int) dim1 = random(low=3, high=20).to(int) dim2 = random(low=2, high=11).to(int) a = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) b = random_tensor(ndim=2, dim0=dim1, dim1=dim2).to(device) return a @ b @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_mv_with_random_data(test_case): device = random_device() dim0 = random(low=2, high=10).to(int) dim1 = 
random(low=3, high=20).to(int) a = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) b = random_tensor(ndim=1, dim0=dim1).to(device) return a.mv(b) @flow.unittest.skip_unless_1n1d() @autotest(check_graph=True, rtol=1e-2, atol=1e-3) def test_mm_with_random_data(test_case): device = random_device() dim0 = random(low=2, high=10).to(int) dim1 = random(low=3, high=20).to(int) dim2 = random(low=2, high=11).to(int) a = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device) b = random_tensor(ndim=2, dim0=dim1, dim1=dim2).to(device) return a.mm(b) @flow.unittest.skip_unless_1n1d() def test_tensor_to_list(test_case): list_data = [[1.0, 3.0], [5.0, 6.0]] input = flow.Tensor(list_data) test_case.assertEqual(list_data, input.tolist()) @flow.unittest.skip_unless_1n1d() def test_tensor_nelement(test_case): shape = (2, 3, 4) input = flow.Tensor(*shape) test_case.assertEqual(input.nelement(), 24) @flow.unittest.skip_unless_1n1d() def test_tensor_numel(test_case): shape = (2, 3, 4, 5) input = flow.Tensor(*shape) test_case.assertEqual(input.numel(), 120) @flow.unittest.skip_unless_1n1d() def test_tensor_print(test_case): shape = (2, 3, 4, 5) input = flow.Tensor(*shape) input_str = str(input) test_case.assertTrue(input_str.startswith("tensor(")) test_case.assertTrue("device=" not in input_str) gpu_input = flow.Tensor(*shape, device="cuda") gpu_input_str = str(gpu_input) test_case.assertTrue("device=" in gpu_input_str) test_case.assertTrue("cuda:0" in gpu_input_str) requires_grad_input = flow.Tensor(*shape) requires_grad_input.requires_grad = True requires_grad_input_str = str(requires_grad_input) test_case.assertTrue("requires_grad=" in requires_grad_input_str) @unittest.skip("skip for now, becase it failed 2 times in past week") @flow.unittest.skip_unless_1n1d() def test_indexing(test_case): class SliceExtracter: def __getitem__(self, key): return key se = SliceExtracter() def compare_getitem_with_numpy(tensor, slices): np_arr = tensor.numpy() test_case.assertTrue(np.allclose(np_arr[slices], tensor[slices].numpy())) def compare_setitem_with_numpy(tensor, slices, value): np_arr = tensor.numpy() if isinstance(value, flow.Tensor): np_value = value.numpy() else: np_value = value np_arr[slices] = np_value tensor[slices] = value test_case.assertTrue(np.allclose(np_arr, tensor.numpy(), rtol=1e-4)) x = flow.randn(5, 5) v = flow.Tensor([[0, 1, 2, 3, 4]]) compare_getitem_with_numpy(x, se[-4:-1:2]) compare_getitem_with_numpy(x, se[-1:]) compare_setitem_with_numpy(x, se[-1:], v) compare_setitem_with_numpy(x, se[2::2], 2) x = flow.Tensor(2, 3, 4) v = flow.Tensor(3) compare_setitem_with_numpy(x, se[:, :, 2], v) x = flow.Tensor(2, 3, 4) compare_setitem_with_numpy(x, se[1, :, 2], v) @flow.unittest.skip_unless_1n1d() @autotest(n=5, auto_backward=False) def test_setitem_with_random_data(test_case): device = random_device() x = random_tensor(low=0, high=0, ndim=1, dim0=16, requires_grad=False).to( device ) y = random_tensor(low=-2, high=2, ndim=1, dim0=16).to(device) idx = random_tensor( low=0, high=15, ndim=1, dim0=20, dtype=int, requires_grad=False ).to(device) getitem_of = y.oneflow[idx.oneflow] getitem_torch = y.pytorch[idx.pytorch] test_case.assertTrue( np.allclose(getitem_of.numpy(), getitem_torch.detach().cpu().numpy()) ) x.oneflow[idx.oneflow] = getitem_of x.pytorch[idx.pytorch] = getitem_torch return x @flow.unittest.skip_unless_1n1d() def test_div(test_case): x = flow.Tensor(np.random.randn(1, 1)) y = flow.Tensor(np.random.randn(2, 3)) of_out = x / y np_out = np.divide(x.numpy(), y.numpy()) 
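        # The (1, 1) numerator broadcasts against the (2, 3) denominator,
        # following NumPy's broadcasting rules. A tiny fixed-value sketch of
        # the same behaviour (illustrative only):
        demo = flow.tensor([[6.0]]) / flow.tensor([[1.0, 2.0, 3.0]])
        test_case.assertTrue(np.allclose(demo.numpy(), [[6.0, 3.0, 2.0]]))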
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(2, 3)) of_out = x / 3 np_out = np.divide(x.numpy(), 3) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(2, 3)) of_out = 3 / x np_out = np.divide(3, x.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(1)) of_out = 3 / x np_out = np.divide(3, x.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() def test_mul(test_case): x = flow.Tensor(np.random.randn(1, 1)) y = flow.Tensor(np.random.randn(2, 3)) of_out = x * y np_out = np.multiply(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(2, 3)) of_out = x * 3 np_out = np.multiply(x.numpy(), 3) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(2, 3)) of_out = 3 * x np_out = np.multiply(3, x.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_mul_inplace_tensor(test_case): device = random_device() rand_tensor = random_tensor( low=-2, high=2, ndim=4, dim0=16, dim1=9, dim2=4, dim3=7 ).to(device) y = rand_tensor + 1 x = random_tensor(low=-2, high=2, ndim=4, dim0=16, dim1=9, dim2=4, dim3=7).to( device ) y.mul_(x) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_broadcast_mul_inplace_tensor(test_case): device = random_device() rand_tensor = random_tensor(ndim=3, dim0=4, dim1=8, dim2=13).to(device) y = rand_tensor + 1 x = random_tensor(ndim=2, dim0=8, dim1=13).to(device) y.mul_(x) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_div_inplace_tensor(test_case): device = random_device() rand_tensor = random_tensor( low=-2, high=2, ndim=4, dim0=26, dim1=7, dim2=4, dim3=17 ).to(device) y = rand_tensor + 1 x = random_tensor(low=-2, high=2, ndim=4, dim0=26, dim1=7, dim2=4, dim3=17).to( device ) y.div_(x) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_broadcast_div_inplace_tensor(test_case): device = random_device() rand_tensor = random_tensor(ndim=3, dim0=4, dim1=8, dim2=13).to(device) y = rand_tensor + 1 x = random_tensor(ndim=2, dim0=8, dim1=13).to(device) y.div_(x) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_add_inplace_tensor(test_case): device = random_device() rand_tensor = random_tensor( low=-2, high=2, ndim=4, dim0=6, dim1=9, dim2=14, dim3=17 ).to(device) y = rand_tensor + 1 x = random_tensor(low=-2, high=2, ndim=4, dim0=6, dim1=9, dim2=14, dim3=17).to( device ) y.add_(x) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_broadcast_add_inplace_tensor(test_case): device = random_device() rand_tensor = random_tensor(ndim=3, dim0=5, dim1=9, dim2=23).to(device) y = rand_tensor + 1 x = random_tensor(ndim=2, dim0=9, dim1=23).to(device) y.add_(x) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_sub_inplace_tensor(test_case): device = random_device() rand_tensor = random_tensor( low=-2, high=2, ndim=4, dim0=6, dim1=9, dim2=14, dim3=17 ).to(device) y = rand_tensor + 1 x = random_tensor(low=-2, high=2, ndim=4, dim0=6, dim1=9, dim2=14, dim3=17).to( device ) y.sub_(x) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_broadcast_sub_inplace_tensor(test_case): device = random_device() rand_tensor = random_tensor(ndim=3, dim0=5, dim1=9, 
dim2=23).to(device) y = rand_tensor + 1 x = random_tensor(ndim=2, dim0=9, dim1=23).to(device) y.sub_(x) return y @flow.unittest.skip_unless_1n1d() def test_add_tensor_method(test_case): x = flow.Tensor(np.random.randn(1, 1)) y = flow.Tensor(np.random.randn(2, 3)) of_out = x + y np_out = np.add(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(2, 3)) of_out = x + 3 np_out = np.add(x.numpy(), 3) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(2, 3)) of_out = 3 + x np_out = np.add(3, x.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() def test_sub_tensor_method(test_case): x = flow.Tensor(np.random.randn(1, 1)) y = flow.Tensor(np.random.randn(2, 3)) of_out = x - y np_out = np.subtract(x.numpy(), y.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(2, 3)) of_out = x - 3 np_out = np.subtract(x.numpy(), 3) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) x = flow.Tensor(np.random.randn(2, 3)) of_out = 3 - x np_out = np.subtract(3, x.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() def test_sum(test_case): input = flow.tensor(np.random.randn(4, 5, 6), dtype=flow.float32) of_out = input.sum(dim=(2, 1)) np_out = np.sum(input.numpy(), axis=(2, 1)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() def test_argwhere(test_case): shape = (2, 3, 4, 5) precision = 1e-5 np_input = np.random.randn(*shape) input = flow.Tensor(np_input) of_out = input.argwhere() np_out = np.argwhere(np_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, precision, precision)) test_case.assertTrue(np.allclose(of_out.numpy().shape, np_out.shape)) @flow.unittest.skip_unless_1n1d() @autotest(n=5, auto_backward=False, check_graph=True) def test_tensor_argmax_with_random_data(test_case): device = random_device() ndim = random(1, 6).to(int) x = random_tensor(ndim=ndim).to(device) y = x.argmax(dim=random(0, ndim).to(int), keepdim=random().to(bool)) return y @autotest(auto_backward=False, check_graph=False) def test_max_bool_input_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to( device, dtype=torch.bool ) return x.max(dim) @autotest(auto_backward=False, check_graph=False) def test_min_bool_input_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float, requires_grad=False).to( device, dtype=torch.bool ) return x.min(dim) @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_tanh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.tanh() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_flow_tensor_asin_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.5).to(device) y = x.asin() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_flow_tensor_arcsin_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.5).to(device) y = x.arcsin() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_flow_tensor_asinh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = 
x.asinh() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_flow_tensor_arcsinh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.arcsinh() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_flow_tensor_sinh_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.sinh() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_flow_tensor_atan2_with_random_data(test_case): device = random_device() x1 = random_tensor(ndim=1, dim0=1).to(device) x2 = random_tensor(ndim=1, dim0=1).to(device) y = x1.atan2(x2) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_dot(test_case): device = random_device() k = random(10, 100) x = random_tensor(ndim=1, dim0=k).to(device) y = random_tensor(ndim=1, dim0=k).to(device) z = x.dot(y) return z @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_arccos_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=2, high=3).to(device) y = x.arccos() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_arccosh_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=2, high=3).to(device) y = x.arccosh() return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_acosh_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=2, high=3).to(device) y = x.acosh() return y @flow.unittest.skip_unless_1n1d() @autotest(auto_backward=False, check_graph=True) def test_sort_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = x.sort(dim=random(low=-4, high=4).to(int), descending=random_bool()) return y[0], y[1] @flow.unittest.skip_unless_1n1d() @autotest(auto_backward=False, check_graph=True) def test_sort_tensor_return_type(test_case): device = random_device() x = random_tensor(ndim=4).to(device) result = x.sort(dim=random(low=-4, high=4).to(int), descending=random_bool()) return result.values, result.indices @flow.unittest.skip_unless_1n1d() @autotest(auto_backward=False, check_graph=True) def test_argsort_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = x.argsort(dim=random(low=-4, high=4).to(int), descending=random_bool()) return y @autotest(n=5) def test_mean_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor(ndim=4, dtype=float).to(device) return x.mean(dim) @autotest(n=5) def test_log_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.log() @autotest(n=5) def test_log1p_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.log1p() @autotest(n=5) def test_log2_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.log2() @autotest(n=5) def test_log10_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.log10() @autotest(n=5) def test_neg_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return -x @autotest(n=5) def test_negative_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.negative() @autotest(n=5) def test_neg_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.neg() @autotest(auto_backward=False, check_graph=True) def 
test_greater_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=3, dim1=2, dim2=3).to(device) y = random_tensor(ndim=3, dim1=2, dim2=3).to(device) return x.gt(y) @autotest(auto_backward=False, check_graph=True) def test_less_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=3, dim1=2, dim2=3).to(device) y = random_tensor(ndim=3, dim1=2, dim2=3).to(device) return x.lt(y) @autotest(auto_backward=False, check_graph=True) def test_tensor_topk_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device) y = x.topk( random(low=1, high=8).to(int), dim=random(low=1, high=4).to(int) | nothing(), largest=random_bool() | nothing(), sorted=constant(True) | nothing(), ) return y[0], y[1] @autotest(auto_backward=False, check_graph=True) def test_tensor_topk_return_type(test_case): device = random_device() x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device) result = x.topk( random(low=1, high=8).to(int), dim=random(low=1, high=4).to(int), largest=random_bool(), sorted=constant(True), ) return result.values, result.indices @autotest(auto_backward=False, check_graph=True) def test_flow_fmod_element_with_random_data(test_case): device = random_device() dim1 = random().to(int) dim2 = random().to(int) input = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device) other = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device) return input.fmod(other) @autotest(auto_backward=False, check_graph=True) def test_flow_fmod_broadcast_with_random_data(test_case): device = random_device() dim1 = random().to(int) dim2 = random().to(int) input = random_tensor(ndim=3, dim1=constant(1), dim2=dim2).to(device) other = random_tensor(ndim=3, dim1=dim1, dim2=constant(1)).to(device) return input.fmod(other) @autotest(auto_backward=True, check_graph=True) def test_flow_fmod_scalar_with_random_data(test_case): device = random_device() dim1 = random().to(int) dim2 = random().to(int) input = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device) other = 3 return input.fmod(other) @autotest(auto_backward=False, check_graph=True) def test_fmod_with_0_size_data(test_case): device = random_device() x = random_tensor(4, 2, 1, 0, 3).to(device) y = x.fmod(2) return y @autotest(n=5) def test_tensor_flip_list_with_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) ).to(device) y = x.flip(constant([0, 1, 2])) return y @autotest(n=5) def test_tensor_flip_tuple_with_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) ).to(device) y = x.flip(constant((0, 1, 2))) return y @autotest(n=5) def test_tensor_chunk_list_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) x = random_tensor( ndim=4, dim1=random(low=4, high=8).to(int), dim2=random(low=4, high=8).to(int), dim3=random(low=4, high=8).to(int), ).to(device) y = x.chunk(chunks=random(low=1, high=5).to(int), dim=dim) z = torch.cat(y, dim=dim) return z @autotest(n=5) def test_tensor_reciprocal_list_with_random_data(test_case): device = random_device() x = random_tensor( ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) ).to(device) y = x.reciprocal() return y @flow.unittest.skip_unless_1n1d() def test_tensor_slice(test_case): x = np.random.randn(2, 3, 4, 5).astype(np.float32) input = flow.tensor(x) 
test_case.assertTrue(np.allclose(input[0].numpy(), x[0], 1e-05, 1e-05)) test_case.assertTrue(np.allclose(input[1].numpy(), x[1], 1e-05, 1e-05)) test_case.assertTrue(np.allclose(input[0, :].numpy(), x[0, :], 1e-05, 1e-05)) test_case.assertTrue( np.allclose(input[0, :, 0:2].numpy(), x[0, :, 0:2], 1e-05, 1e-05) ) @flow.unittest.skip_unless_1n1d() def test_zeros_(test_case): shape = (2, 3) x = flow.tensor(np.random.randn(*shape), dtype=flow.float32) x.zero_() test_case.assertTrue(np.allclose(x.numpy(), np.zeros(shape))) @flow.unittest.skip_unless_1n1d() def test_construct_small_tensor(test_case): shape = (2, 3, 4, 5) np_arr = np.random.rand(*shape).astype(np.float32) tensor = flow.tensor(np_arr) test_case.assertTrue(np.allclose(tensor.numpy(), np_arr)) test_case.assertEqual(tensor.dtype, flow.float32) np_int_arr = np.random.randint(-100, high=100, size=shape, dtype=np.int32) tensor = flow.tensor(np_int_arr, dtype=flow.int32) test_case.assertEqual(tensor.dtype, flow.int32) list_data = [[1, 2.0], [5, 3]] tensor = flow.tensor(list_data) test_case.assertEqual(tensor.dtype, flow.float32) test_case.assertTrue( np.allclose(tensor.numpy(), np.array(list_data), 0.0001, 0.0001) ) tuple_data = ((1, 2, 5), (4, 3, 10)) tensor = flow.tensor(tuple_data) test_case.assertEqual(tensor.dtype, flow.int64) test_case.assertTrue(np.allclose(tensor.numpy(), np.array(tuple_data))) scalar = 5.5 tensor = flow.tensor(scalar) test_case.assertEqual(tensor.dtype, flow.float32) test_case.assertTrue( np.allclose(tensor.numpy(), np.array(scalar), 0.0001, 0.0001) ) @autotest(n=5) def test_tensor_floor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.floor() return y @autotest(n=5) def test_tensor_round_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.round() return y def _test_tensor_reshape(test_case): x = np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ).astype(np.float32) input = flow.tensor(x) of_shape = input.reshape(2, 2, 2, -1).numpy().shape np_shape = (2, 2, 2, 2) test_case.assertTrue(np.allclose(of_shape, np_shape)) @autotest(n=5) def test_flatten_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.flatten( start_dim=random(1, 6).to(int) | nothing(), end_dim=random(1, 6).to(int) | nothing(), ) return y @autotest(n=5) def test_reshape_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = x.reshape(-1) return y @autotest(n=1) def test_reshape_tensor_with_random_data_and_keyword(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = x.reshape(shape=[-1,]) return y @autotest(n=5) def test_reshape_as_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = x.reshape(-1) z = y.reshape_as(other=x) return z @autotest(n=5) def test_tensor_squeeze_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.squeeze(random().to(int)) return y @autotest(n=5) def test_flow_unsqueeze_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.unsqueeze(random(1, 3).to(int)) return y @autotest(n=3, auto_backward=False, check_graph=True) def test_flow_invert_with_random_data(test_case): device = random_device() x = random_tensor().to(device, dtype=torch.bool) y = ~x return y def test_tensor_float(test_case): x = flow.tensor(1) y = float(x) test_case.assertTrue(np.array_equal(y, 1.0)) def 
test_tensor_int(test_case): x = flow.tensor(2.3) y = int(x) test_case.assertTrue(np.array_equal(y, 2)) def test_none_equal(test_case): xt = flow.randn(10) yt = flow.randn(10) z = None in [xt, yt] test_case.assertTrue(np.array_equal(z, False)) zt = None z = None in [xt, yt, zt] test_case.assertTrue(np.array_equal(z, True)) def test_half(test_case): x = flow.tensor([1], dtype=flow.int64) test_case.assertTrue(x.dtype == flow.int64) y = x.half() test_case.assertTrue(y.dtype == flow.float16) def test_byte(test_case): x = flow.tensor([1.2], dtype=flow.float32) test_case.assertTrue(x.dtype == flow.float32) y = x.byte() test_case.assertTrue(y.dtype == flow.uint8) def test_tensor_constructor(test_case): x = flow.tensor([1, 2, 3]) test_case.assertTrue(np.array_equal(x.numpy(), [1, 2, 3])) test_case.assertEqual(x.dtype, flow.int64) x = flow.tensor([1.0, 2.0, 3.0]) test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0])) test_case.assertEqual(x.dtype, flow.float32) x = flow.tensor([1.0, 2.0, 3.0], dtype=flow.float64) test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0])) test_case.assertEqual(x.dtype, flow.float64) np_arr = np.array([1, 2, 3]) x = flow.tensor(np_arr) test_case.assertTrue(np.array_equal(x.numpy(), [1, 2, 3])) test_case.assertEqual(x.dtype, flow.int64) np_arr = np.array([1, 2, 3], dtype=np.float64) x = flow.tensor(np_arr) test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0])) test_case.assertEqual(x.dtype, flow.float64) x = flow.tensor(np_arr, dtype=flow.float32) test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0])) test_case.assertEqual(x.dtype, flow.float32) x = flow.tensor(np_arr, dtype=flow.int8) test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0])) test_case.assertEqual(x.dtype, flow.int8) x = flow.tensor([flow.tensor([1, 2])] * 3, dtype=flow.float32) test_case.assertTrue(np.array_equal(x.numpy(), [[1, 2], [1, 2], [1, 2]])) test_case.assertEqual(x.dtype, flow.float32) def test_tensor_contains_magic_method(test_case): x = flow.tensor([[1, 2, 3], [4, 5, 6]]) y = 1 in x test_case.assertEqual(y, True) @profile(torch.Tensor.fill_) def profile_fill_(test_case): torch.Tensor.fill_(torch.ones(1, 8, 16, 16), 2) torch.Tensor.fill_(torch.ones(1000, 1000), 2) torch.Tensor.fill_(torch.ones(1, 8, 16, 16), torch.tensor(2)) torch.Tensor.fill_(torch.ones(1000, 1000), torch.tensor(2)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_tensor_part_2.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import copy import os import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestTensor(flow.unittest.TestCase): @autotest(n=10) def test_permute_flow_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) permute_list = [0, 1, 2, 3] np.random.shuffle(permute_list) y = x.permute(permute_list) return y @autotest(n=1) def test_permute_flow_with_random_data_and_keyword(test_case): device = random_device() x = random_tensor(ndim=4).to(device) permute_list = [0, 1, 2, 3] np.random.shuffle(permute_list) y = x.permute(dims=permute_list) return y @autotest(n=5) def test_transpose_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) permute_list = np.random.permutation(4) y = x.transpose(permute_list[0], permute_list[1]) return y @autotest(n=5) def test_t_tensor_with_random_data(test_case): device = random_device() x = random_tensor( ndim=constant(2).to(int), dim0=random(0, 64), dim1=random(0, 64) ).to(device) y = x.t() return y @autotest(n=5) def test_T_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=random(1, 4)).to(device) y = x.T return y def test_tensor_where(test_case): x = flow.tensor( np.array([[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), dtype=flow.float32, ) y = flow.tensor(np.ones(shape=(3, 2)), dtype=flow.float32) condition = flow.tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32) of_out = condition.where(x, y) np_out = np.array([[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]]) test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) def test_tensor_equal(test_case): arr1 = np.random.randint(1, 10, size=(2, 3, 4, 5)) arr2 = np.random.randint(1, 10, size=(2, 3, 4, 5)) input = flow.tensor(arr1, dtype=flow.float32) other = flow.tensor(arr2, dtype=flow.float32) of_out = input.eq(other) np_out = np.equal(arr1, arr2) test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) def test_tensor_equal_bool_dtype(test_case): np_bool = np.random.randint(0, 2, size=()).astype(bool).item() input = flow.tensor(np_bool, dtype=flow.bool) input2 = flow.tensor([np_bool], dtype=flow.bool) test_case.assertTrue(input == np_bool) test_case.assertTrue(input2 == np_bool) def test_tensor_detach(test_case): shape = (2, 3, 4, 5) x = flow.tensor(np.random.randn(*shape), dtype=flow.float32, requires_grad=True) test_case.assertTrue(np.allclose(x.detach().numpy(), x.numpy(), 0.0001, 0.0001)) test_case.assertEqual(x.detach().requires_grad, False) y = x * 2 z = y.detach() test_case.assertEqual(z.is_leaf, True) test_case.assertEqual(z.grad_fn, None) def _test_cast_tensor_function(test_case): shape = (2, 3, 4, 5) np_arr = np.random.randn(*shape).astype(np.float32) input = flow.tensor(np_arr, dtype=flow.float32) output = input.cast(flow.int8) np_out = np_arr.astype(np.int8) test_case.assertTrue(np.allclose(output.numpy(), np_out)) def _test_sin_tensor_function(test_case, shape, device): input = flow.Tensor(np.random.randn(2, 3, 4, 5)) of_out = input.sin() np_out = np.sin(input.numpy()) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def test_cos_tensor_function(test_case): arr = np.random.randn(2, 3, 4, 5) input = flow.tensor(arr, dtype=flow.float32) np_out = np.cos(arr) of_out = input.cos() 
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) def test_std_tensor_function(test_case): np_arr = np.random.randn(9, 8, 7, 6) input = flow.Tensor(np_arr) of_out = input.std(dim=1, unbiased=False, keepdim=False) np_out = np.std(np_arr, axis=1) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-04, 1e-04)) def test_sqrt_tensor_function(test_case): input_arr = np.random.rand(1, 6, 3, 8) np_out = np.sqrt(input_arr) x = flow.Tensor(input_arr) of_out = x.sqrt() test_case.assertTrue( np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True) ) def test_rsqrt_tensor_function(test_case): np_arr = np.random.rand(3, 2, 5, 7) np_out = 1 / np.sqrt(np_arr) x = flow.Tensor(np_arr) of_out = flow.rsqrt(x) test_case.assertTrue( np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True) ) def test_square_tensor_function(test_case): np_arr = np.random.randn(2, 7, 7, 3) np_out = np.square(np_arr) x = flow.Tensor(np_arr) of_out = x.square() test_case.assertTrue( np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True) ) # This test will fail with the rtol and atol constraints under pytorch 1.10, but succeeds with pytorch 1.13. # The constraints should be removed in the future. @autotest(n=5, rtol=1e-3, atol=1e-3) def test_addmm_tensor_with_random_data(test_case): device = random_device() input = random_tensor(ndim=2, dim0=2, dim1=3).to(device) mat1 = random_tensor(ndim=2, dim0=2, dim1=4).to(device) mat2 = random_tensor(ndim=2, dim0=4, dim1=3).to(device) y = input.addmm( mat1, mat2, beta=random().to(float) | nothing(), alpha=random().to(float) | nothing(), ) return y # This test will fail with the rtol and atol constraints under pytorch 1.10, but succeeds with pytorch 1.13. # The constraints should be removed in the future. 
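    # addmm computes beta * input + alpha * (mat1 @ mat2); the accumulated matmul is where
    # small float32 discrepancies between the two backends tend to surface, which is
    # presumably why the addmm tests above and below loosen rtol/atol.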
@autotest(n=5, rtol=1e-3, atol=1e-2) def test_addmm_broadcast_tensor_with_random_data(test_case): device = random_device() input = random_tensor(ndim=2, dim0=1, dim1=1).to(device) mat1 = random_tensor(ndim=2, dim0=2, dim1=4).to(device) mat2 = random_tensor(ndim=2, dim0=4, dim1=3).to(device) y = input.addmm( mat1, mat2, beta=random().to(float) | nothing(), alpha=random().to(float) | nothing(), ) return y @autotest(n=5) def test_clamp_tensor_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float), ) return y @autotest(n=5) def test_clamp_inplace_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float), ) return y @autotest(auto_backward=False) def test_clamp_inplace_tensor_no_grad_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float), ) return y @autotest(n=5) def test_clamp_minnone_tensor_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp( min=random(low=-1, high=-0.5).to(float) | nothing(), max=random(low=0.5, high=1).to(float), ) return y @flow.unittest.skip_unless_1n1d() @autotest(auto_backward=False) def test_clamp_minnone_tensor_no_grad_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp( min=random(low=-1, high=-0.5).to(float) | nothing(), max=random(low=0.5, high=1).to(float), ) return y @autotest(n=5) def test_clamp_inplace_minnone_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_( min=random(low=-1, high=-0.5).to(float) | nothing(), max=random(low=0.5, high=1).to(float), ) return y @autotest(auto_backward=False) def test_clamp_inplace_minnone_tensor_no_grad_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_( min=random(low=-1, high=-0.5).to(float) | nothing(), max=random(low=0.5, high=1).to(float), ) return y @autotest(n=5) def test_clamp_maxnone_tensor_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float) | nothing(), ) return y @autotest(auto_backward=False) def test_clamp_maxnone_tensor_no_grad_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float) | nothing(), ) return y @autotest(n=5) def test_clamp_inplace_maxnone_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float) | nothing(), ) return y @autotest(auto_backward=False) def test_clamp_inplace_maxnone_tensor_no_grad_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float) | nothing(), ) return y @autotest(n=5) def 
test_clamp_min_tensor_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp_min(random(low=-0.5, high=0.5).to(float)) return y @autotest(n=5) def test_clamp_min_inplace_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_min_(random(low=-0.5, high=0.5).to(float)) return y @autotest(auto_backward=False) def test_clamp_min_tensor_no_grad_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp_min(random(low=-0.5, high=0.5).to(float)) return y @autotest(auto_backward=False) def test_clamp_min_inplace_tensor_no_grad_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_min_(random(low=-0.5, high=0.5).to(float)) return y @autotest(n=5) def test_clamp_max_tensor_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp_max(random(low=-0.5, high=0.5).to(float)) return y @autotest(n=5) def test_clamp_max_inplace_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_max_(random(low=-0.5, high=0.5).to(float)) return y @autotest(auto_backward=False) def test_clamp_max_tensor_no_grad_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clamp_max(random(low=-0.5, high=0.5).to(float)) return y @autotest(auto_backward=False) def test_clamp_max_inplace_tensor_no_grad_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clamp_max_(random(low=-0.5, high=0.5).to(float)) return y @autotest(n=5) def test_clip_tensor_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clip( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float), ) return y @autotest(n=5) def test_clip_inplace_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clip_( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float), ) return y @autotest(n=5) def test_clip_minnone_tensor_with_random_data(test_case): device = random_device() input = random_tensor(low=-2, high=2).to(device) y = input.clip( min=random(low=-1, high=-0.5).to(float) | nothing(), max=random(low=0.5, high=1).to(float), ) return y @autotest(n=5) def test_clip_inplace_maxnone_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clip_( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float) | nothing(), ) return y @autotest(n=5) def test_clip_maxnone_tensor_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = input.clip( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float) | nothing(), ) return y @autotest(n=5) def test_clip_inplace_maxnone_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-2, high=2).to(device) y = x + 1 y.clip_( min=random(low=-1, high=-0.5).to(float), max=random(low=0.5, high=1).to(float) | nothing(), ) return y @autotest(n=5) def test_ceil_tensor_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = len(input) return y @autotest(n=5) def 
test_ceil_tensor_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = input.ceil() return y @autotest(n=5) def test_expm1_tensor_with_random_data(test_case): device = random_device() input = random_tensor().to(device) y = input.expm1() return y @autotest(n=5) def test_floor_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.floor() return y @autotest(n=5) def test_tensor_var_all_dim_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.var() return y # TODO(): 'var backward' is composed of several other ops, # reducemean doesn't support 0-shape for now @autotest(n=5, auto_backward=False) def test_tensor_var_one_dim_with_random_data(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = x.var( dim=random(low=0, high=4).to(int), unbiased=random().to(bool), keepdim=random().to(bool), ) return y def test_norm_tensor_function(test_case): input = flow.tensor( np.array([[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]), dtype=flow.float32, ) of_out_1 = input.norm("fro") np_out_1 = np.linalg.norm(input.numpy(), "fro") of_out_2 = input.norm(2, dim=1) np_out_2 = np.linalg.norm(input.numpy(), ord=2, axis=1) of_out_3 = input.norm(float("inf"), dim=0, keepdim=True) np_out_3 = np.linalg.norm( input.numpy(), ord=float("inf"), axis=0, keepdims=True ) test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out_1, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out_2, 1e-05, 1e-05)) test_case.assertTrue(np.allclose(of_out_3.numpy(), np_out_3, 1e-05, 1e-05)) @autotest(n=5) def test_pow_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = random().to(float) z = x.pow(y) return z @autotest(n=5) def test_atanh_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.49).to(device) y = x.atanh() return y @autotest(n=5) def test_acos_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.49).to(device) y = x.acos() return y @autotest(n=5) def test_acosh_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=2.0, high=3.0).to(device) y = x.acosh() return y @autotest(n=5) def test_atan_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.atan() return y @autotest(n=5) def test_arctan_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.arctan() return y @autotest(n=5) def test_tan_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) y = x.tan() return y @autotest(n=5) def test_tan2_tensor_with_random_data(test_case): device = random_device() x = random_tensor(ndim=2, dim1=3).to(device) y = random_tensor(ndim=2, dim1=3).to(device) z = x.atan2(y) return z @autotest(n=5) def test_arctanh_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-0.5, high=0.5).to(device) y = x.arctanh() return y # Not check graph because of one reason: # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor. # Please refer to File "python/oneflow/nn/modules/nonzero.py", line 29, in nonzero_op. 
@autotest(n=5, auto_backward=False, check_graph="ValidatedFalse") def test_tensor_nonzero_with_random_data(test_case): device = random_device() ndim = random(2, 6).to(int) x = random_tensor(ndim=ndim).to(device) y = x.nonzero() return y @unittest.skipIf( not flow.unittest.env.eager_execution_enabled(), "numpy doesn't work in lazy mode", ) def test_tensor_fmod(test_case): x = flow.Tensor(np.random.uniform(-100, 100, (5, 5))) x.requires_grad = True y = np.random.uniform(-10, 10) of_out = x.fmod(y) np_out = np.sign(x.numpy()) * np.abs(np.fmod(x.numpy(), y)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones((5, 5)), 0.0001, 0.0001) ) @unittest.skipIf( not flow.unittest.env.eager_execution_enabled(), "numpy doesn't work in lazy mode", ) def test_magic_fmod(test_case): x = flow.Tensor(np.random.uniform(-100, 100, (5, 5))) x.requires_grad = True y = np.random.uniform(-10, 10) of_out = x % y np_out = np.sign(x.numpy()) * np.abs(np.fmod(x.numpy(), y)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) of_out = of_out.sum() of_out.backward() test_case.assertTrue( np.allclose(x.grad.numpy(), np.ones((5, 5)), 0.0001, 0.0001) ) def test_tensor_mish(test_case): def np_mish(x): f = 1 + np.exp(x) y = x * ((f * f - 1) / (f * f + 1)) y_grad = (f * f - 1) / (f * f + 1) + x * (4 * f * (f - 1)) / ( (f * f + 1) * (f * f + 1) ) return [y, y_grad] np_input = np.random.randn(2, 4, 5, 6) of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True) of_out = of_input.mish() (np_out, np_grad) = np_mish(np_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-05, 1e-05)) def test_tensor_triu(test_case): def np_triu(x, diagonal): y = np.triu(x, diagonal) y_grad = np.triu(np.ones_like(x), diagonal) return [y, y_grad] diagonal_list = [2, -1] for diagonal in diagonal_list: np_input = np.random.randn(2, 4, 6) of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True) of_out = of_input.triu(diagonal) (np_out, np_grad) = np_triu(np_input, diagonal) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) of_out = of_out.sum() of_out.backward() test_case.assertTrue( np.allclose(of_input.grad.numpy(), np_grad, 1e-05, 1e-05) ) def test_tensor_grad_assignment(test_case): np_input = np.random.randn(2, 4, 5, 6) of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True) of_output = 2 * of_input of_output = of_output.sum() of_output.backward() new_grad = flow.tensor( np.full(np_input.shape, np.random.randn(1)), dtype=flow.float32 ) of_input.grad = new_grad test_case.assertTrue( np.allclose(of_input.grad.detach().numpy(), new_grad.numpy(), 1e-05, 1e-05) ) of_input.grad = None test_case.assertTrue(of_input.grad is None) def test_tensor_grad_assignment_sum(test_case): np_input = np.random.randn(1, 5, 7, 3) of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True) of_output = of_input.sum() of_output.backward() rand_init = np.random.randn(1) rand_scale = np.random.randn(1) new_grad = flow.tensor(np.full(np_input.shape, rand_init), dtype=flow.float32) of_input.grad = new_grad of_output = flow.tensor(rand_scale, dtype=flow.float32) * of_input of_output = of_output.sum() of_output.backward() test_case.assertTrue( np.allclose( of_input.grad.detach().numpy(), np.full(np_input.shape, rand_init + 
rand_scale), 1e-05, 1e-05, ) ) of_input.grad = of_input.grad * 2 test_case.assertTrue( np.allclose( of_input.grad.detach().numpy(), 2 * np.full(np_input.shape, rand_init + rand_scale), 1e-05, 1e-05, ) ) def test_tensor_mish(test_case): def np_mish(x): f = 1 + np.exp(x) y = x * ((f * f - 1) / (f * f + 1)) y_grad = (f * f - 1) / (f * f + 1) + x * (4 * f * (f - 1)) / ( (f * f + 1) * (f * f + 1) ) return [y, y_grad] np_input = np.random.randn(2, 4, 5, 6,) of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True) of_out = of_input.mish() np_out, np_grad = np_mish(np_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5)) def test_tensor_silu(test_case): def np_silu(x): _sig = 1 / (1 + np.exp(-x)) y = x * _sig y_grad = _sig * (1 + x * (1 - _sig)) return [y, y_grad] np_input = np.random.randn(2, 4, 5, 6,) of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True) of_out = of_input.silu() np_out, np_grad = np_silu(np_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5)) def test_tensor_selu(test_case): _scale = 1.0507009873554804934193349852946 _alpha = 1.6732632423543772848170429916717 def np_selu(x): y = np.where(x < 0, _scale * _alpha * (np.exp(x) - 1), _scale * x) y_grad = np.where(x < 0, _scale * _alpha * np.exp(x), _scale) return [y, y_grad] np_input = np.random.randn(2, 4, 5, 6,) of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True) of_out = of_input.selu() np_out, np_grad = np_selu(np_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5)) @unittest.skip("still have error in ci") def test_tensor_softsign(test_case): def np_softsign(x): y = x / (1 + np.abs(x)) y_grad = 1 / np.square(1 + np.abs(x)) return [y, y_grad] np_input = np.random.randn(2, 4, 5, 6,) of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True) of_out = of_input.softsign() np_out, np_grad = np_softsign(np_input) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) of_out = of_out.sum() of_out.backward() test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5)) @autotest(auto_backward=False) def test_eq_tensor_with_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to(device) y = random_tensor(len(shape), *shape, requires_grad=False).to(device) return x.eq(y) @autotest(auto_backward=False) def test_eq_tensor_with_same_random_data(test_case): device = random_device() shape = random_tensor().oneflow.shape x = random_tensor(len(shape), *shape, requires_grad=False).to(device) return x.eq(x) @autotest(n=5) def test_erf_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.erf() @autotest(n=5) def test_erfc_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.erfc() @autotest( auto_backward=False ) # Todo: After add gradient func, you should set `auto_backward` as True def test_erfinv_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-1, high=1).to(device).requires_grad_(False) return x.erfinv() 
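    # erfinv is only defined on the open interval (-1, 1) and diverges to +/-inf at the
    # endpoints, so the erfinv tests sample their inputs from (-1, 1); auto_backward stays
    # False until a gradient function is added (see the Todo notes on these tests).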
@autotest( n=10, auto_backward=False ) # Todo: After add gradient func, you should set `auto_backward` as True def test_erfinv_inplace_tensor_with_random_data(test_case): device = random_device() x = random_tensor(low=-1, high=1).to(device).requires_grad_(False) y = x + 1 y.erfinv_() return y @autotest(n=5) def test_exp_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.exp() @autotest(n=5) def test_exp2_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.exp2() @autotest(n=5) def test_round_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.round() @autotest(n=5) def test_tensor_diag_one_dim(test_case): device = random_device() x = random_tensor(ndim=1, dim0=random()).to(device) return x.diag() @autotest(n=5) def test_flow_tensor_expand_with_random_data(test_case): random_expand_size = random(1, 6).to(int).value() x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1) ndim = 5 expand_size = random_expand_size dim_size = [1,] * ndim random_index = random(0, ndim).to(int).value() dim_size[random_index] = expand_size return x.expand(*dim_size) @autotest(n=5) def test_flow_tensor_expand_as_with_random_data(test_case): random_expand_size = random(1, 6).to(int).value() x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1) ndim = 5 expand_size = random_expand_size dim_size = [1,] * ndim random_index = random(0, ndim).to(int).value() dim_size[random_index] = expand_size y = torch.ones(dim_size) return x.expand_as(y) @autotest(n=5) def test_flow_tensor_view_with_random_data(test_case): dim0_ = random(2, 4).to(int) dim1_ = random(2, 4).to(int) dim2_ = random(2, 4).to(int) dim3_ = random(2, 4).to(int) dim4_ = random(2, 4).to(int) x = random_tensor( ndim=5, dim0=dim0_, dim1=dim1_, dim2=dim2_, dim3=dim3_, dim4=dim4_ ) shape = [x.value() for x in [dim4_, dim3_, dim2_, dim1_, dim0_]] return [x.view(shape), x.view(size=shape)] @autotest(n=5) def test_flow_tensor_view_as_with_random_data(test_case): dim0_ = random(2, 4).to(int) dim1_ = random(2, 4).to(int) dim2_ = random(2, 4).to(int) dim3_ = random(2, 4).to(int) dim4_ = random(2, 4).to(int) x = random_tensor( ndim=5, dim0=dim0_, dim1=dim1_, dim2=dim2_, dim3=dim3_, dim4=dim4_ ) other = random_tensor( ndim=5, dim0=dim4_, dim1=dim3_, dim2=dim2_, dim3=dim1_, dim4=dim0_ ) return x.view_as(other) @autotest(n=5) def test_tensor_diag_other_dim(test_case): device = random_device() x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device) return x.diag() @autotest(auto_backward=False) def test_floordiv_elementwise_tensor_with_random_data(test_case): device = random_device() # The random value is narrowed to positive numbers because of the error from pytorch 1.10.0 # Please remove the value range restriction after updating the pytorch version of ci to 1.13. input = random_tensor(ndim=2, dim0=4, dim1=8, low=0, high=10).to(device) other = random_tensor(ndim=2, dim0=4, dim1=8, low=0, high=10).to(device) y = input.floor_divide(other) return y @autotest(auto_backward=False) def test_scalar_floordiv_tensor_with_random_data(test_case): device = random_device() # The random value is narrowed to positive numbers because of the error from pytorch 1.10.0 # Please remove the value range restriction after updating the pytorch version of ci to 1.13. 
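        # (floor_divide in PyTorch releases before 1.13 truncated toward zero instead of
        # flooring, so negative operands would not match a true floor division; keeping
        # the operands non-negative sidesteps that difference.)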
input = random_tensor(ndim=2, dim0=4, dim1=8, low=0, high=10).to(device) other = random().to(int) y = input.floor_divide(other) return y @flow.unittest.skip_unless_1n4d() def test_construct_global_tensor_by_numpy(test_case): x = np.ones((4, 4), dtype=np.int32) placement = flow.placement("cuda", [0, 1, 2, 3]) y = flow.tensor( x, dtype=flow.float32, placement=placement, sbp=[flow.sbp.split(0)], requires_grad=False, ) test_case.assertTrue(y.dtype == flow.float32) test_case.assertTrue( np.allclose(y.to_local().numpy(), np.ones((1, 4), dtype=np.float32)) ) test_case.assertEqual(y.placement, placement) y_default_dtype = flow.tensor( x, placement=placement, sbp=[flow.sbp.split(0)], requires_grad=False, ) test_case.assertTrue(y_default_dtype.dtype == flow.int32) @autotest(n=5) def test_digamma_tensor_with_random_data(test_case): device = random_device() x = random_tensor().to(device) return x.digamma() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestTensorNumpy(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_1d_sbp_tensor_numpy_1n2d(test_case): ori_x = flow.tensor([1, 2, 3, 4]) + flow.env.get_rank() placement = flow.placement.all("cpu") x = ori_x.to_global(placement=placement, sbp=flow.sbp.split(0)) test_case.assertTrue(np.allclose(x.numpy(), [1, 2, 3, 4, 2, 3, 4, 5])) x = ori_x.to_global(placement=placement, sbp=flow.sbp.broadcast, copy=True) test_case.assertTrue(np.allclose(x.numpy(), [1, 2, 3, 4])) x = ori_x.to_global(placement=placement, sbp=flow.sbp.partial_sum) test_case.assertTrue(np.allclose(x.numpy(), [3, 5, 7, 9])) placement = flow.placement.all("cuda") x = ori_x.to_global(placement=placement, sbp=flow.sbp.split(0)) test_case.assertTrue(np.allclose(x.numpy(), [1, 2, 3, 4, 2, 3, 4, 5])) x = ori_x.to_global(placement=placement, sbp=flow.sbp.broadcast, copy=True) test_case.assertTrue(np.allclose(x.numpy(), [1, 2, 3, 4])) x = ori_x.to_global(placement=placement, sbp=flow.sbp.partial_sum) test_case.assertTrue(np.allclose(x.numpy(), [3, 5, 7, 9])) @flow.unittest.skip_unless_1n2d() def test_2d_sbp_tensor_numpy_1n2d(test_case): ori_x = flow.tensor(np.ones((2, 2))) + flow.env.get_rank() placement = flow.placement("cuda", [[0], [1]]) x = ori_x.to_global( placement=placement, sbp=[flow.sbp.split(0), flow.sbp.split(1)] ) test_case.assertTrue(np.allclose(x.numpy(), [[1, 1], [1, 1], [2, 2], [2, 2]])) x = ori_x.to_global( placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.split(0)] ) test_case.assertTrue(np.allclose(x.numpy(), [[1, 1], [1, 1]])) x = ori_x.to_global( placement=placement, sbp=[flow.sbp.partial_sum, flow.sbp.broadcast], copy=True, ) test_case.assertTrue(np.allclose(x.numpy(), [[3, 3], [3, 3]])) @flow.unittest.skip_unless_1n4d() def test_2d_sbp_tensor_numpy_1n4d(test_case): ori_x = flow.tensor(np.ones((2, 2))) + flow.env.get_rank() placement = flow.placement("cuda", [[0, 1], [2, 3]]) x = ori_x.to_global( placement=placement, sbp=[flow.sbp.split(0), flow.sbp.split(1)] ) test_case.assertTrue( np.allclose( x.numpy(), [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]] ) ) x = ori_x.to_global( placement=placement, sbp=[flow.sbp.split(0), flow.sbp.partial_sum] ) test_case.assertTrue(np.allclose(x.numpy(), [[3, 3], [3, 3], [7, 7], [7, 7]])) # TODO: (s0, b) has bug # x = ori_x.to_global(placement=placement, sbp=[flow.sbp.split(0), flow.sbp.broadcast]) @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_bmm(test_case): t = random(1, 5) k = random(1, 5) input1 = random_tensor(ndim=3, dim0=t, dim1=3, dim2=k) input2 = 
random_tensor(ndim=3, dim0=t, dim1=k, dim2=5) of_out = input1.bmm(input2) return of_out @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_split(test_case): k0 = random(2, 6) k1 = random(2, 6) k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) res = x.split(2, dim=rand_dim) return torch.cat(res, rand_dim) @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_split_sizes(test_case): k0 = random(2, 6) k1 = 7 k2 = random(2, 6) device = random_device() x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) res = x.split([1, 2, 3, 1], dim=-2) return torch.cat(res, dim=1) @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_unbind(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = x.unbind(random(0, 4).to(int)) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_swapaxes(test_case): device = random_device() x = random_tensor(ndim=3).to(device) y = x.swapaxes(random(0, 2).to(int), random(0, 2).to(int)) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_swapdims(test_case): device = random_device() x = random_tensor(ndim=3).to(device) y = x.swapdims(random(0, 3).to(int), random(0, 3).to(int)) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_int_repeat_interleave_dim_none(test_case): x = random_tensor(ndim=2, dim0=1, dim1=2) y = x.repeat_interleave(2) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_int_repeat_interleave_with_dim(test_case): x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3) dim = random(low=0, high=2).to(int) y = x.repeat_interleave(2, dim) return y @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_tensor_repeat_interleave_dim(test_case): x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3) y = random_tensor(ndim=1, dim0=2, dtype=int, low=1, high=4) z = x.repeat_interleave(y, 1) return z @unittest.skip("skip for now, because it failed 2 times in the past week") @flow.unittest.skip_unless_1n1d() @autotest(n=5, rtol=1e-3) def test_tensor_tensor_repeat_interleave_dim_with_output_size(test_case): x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3) y = random_tensor(ndim=1, dim0=2, dtype=int, low=1, high=4) z = x.repeat_interleave(y, 1, output_size=2) return z @flow.unittest.skip_unless_1n2d() @globaltest def test_global_tensor_detach(test_case): device = random_device().value() placement = flow.placement(device, [0, 1]) a = flow.ones(4, 8).to_global(placement, flow.sbp.broadcast) test_case.assertTrue(a.is_leaf) b = a.float().clone().detach() test_case.assertTrue(b.is_leaf) @flow.unittest.skip_unless_1n1d() @autotest(n=5) def test_tensor_nansum(test_case): device = random_device() x = random_tensor(4, random(0, 5), 2).to(device) mask = x < 0 x = x.masked_fill(mask, float("nan")) y = x.nansum() return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_tensor_part_3.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import oneflow as flow import oneflow.unittest import numpy as np from oneflow.test_utils.automated_test_util import * def _get_indexes(device): return ( constant( torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device) ), constant( torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device) ), constant( torch.tensor(np.array([[1, 0], [1, 0]]), dtype=torch.int64, device=device) ), constant( torch.tensor(np.array([[0, 1], [0, 1]]), dtype=torch.int64, device=device) ), ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestTensor(flow.unittest.TestCase): @autotest(n=10) def test_scatter_random_data(test_case): device = random_device() input = random_tensor(ndim=2, dim0=2, dim1=2).to(device) src = oneof(3.14, random_tensor(ndim=2, dim0=2, dim1=2).to(device)) inplace = oneof(True, False) dim = oneof(0, 1, -1) if inplace: y = input + 1 y.scatter_(dim, oneof(*_get_indexes(device)), src) return y return input.scatter(dim, oneof(*_get_indexes(device)), src) @autotest( n=10, auto_backward=False ) # peihong: pytorch dose not support backward when reduce is add or multiply def test_scatter_add_or_multiply_random_data(test_case): device = random_device() input = random_tensor(ndim=2, dim0=2, dim1=2).to(device) src = random_tensor(ndim=2, dim0=2, dim1=2).to(device) inplace = oneof(True, False) reduce = oneof("add", "multiply") dim = oneof(0, 1) if inplace: y = input + 1 y.scatter_( dim, oneof(*_get_indexes(device)), src, reduce=reduce, ) return y return input.scatter(dim, oneof(*_get_indexes(device)), src, reduce=reduce) def test_tensor_element_size_api(test_case): x = flow.ones(2, 1, dtype=flow.float) test_case.assertEqual(x.element_size(), 4) def test_tensor_new(test_case): dtype = random_dtype(["pod"]) device = random_device() x = random_tensor(ndim=3).to(dtype).to(device) of_result = x.oneflow.new() th_result = x.pytorch.new() test_case.assertTrue(list(of_result.shape) == list(th_result.shape)) test_case.assertTrue( of_result.numpy().dtype == th_result.detach().cpu().numpy().dtype ) test_case.assertTrue(of_result.device.type == th_result.device.type) y = random_tensor(ndim=3).to(dtype).to(device) of_result = x.oneflow.new(y.oneflow) th_result = x.pytorch.new(y.pytorch) test_case.assertTrue(list(of_result.shape) == list(th_result.shape)) test_case.assertTrue( of_result.numpy().dtype == th_result.detach().cpu().numpy().dtype ) test_case.assertTrue(of_result.device.type == th_result.device.type) np_data = np.random.randn(3, 3) of_result = x.oneflow.new(np_data) th_result = x.pytorch.new(np_data) test_case.assertTrue(list(of_result.shape) == list(th_result.shape)) test_case.assertTrue( of_result.numpy().dtype == th_result.detach().cpu().numpy().dtype ) test_case.assertTrue(of_result.device.type == th_result.device.type) of_result = x.oneflow.new([1, 2, 3]) th_result = x.pytorch.new([1, 2, 3]) test_case.assertTrue(list(of_result.shape) == list(th_result.shape)) test_case.assertTrue( of_result.numpy().dtype == th_result.detach().cpu().numpy().dtype ) test_case.assertTrue(of_result.device.type == 
th_result.device.type) @autotest(n=3) def test_baddbmm(test_case): device = random_device() batch_dim = random().to(int) dim1 = random().to(int) dim2 = random().to(int) dim3 = random().to(int) x = random_tensor( ndim=3, dim0=oneof(batch_dim, 1).value(), dim1=dim1, dim2=dim3 ).to(device) batch1 = random_tensor(ndim=3, dim0=batch_dim, dim1=dim1, dim2=dim2).to(device) batch2 = random_tensor(ndim=3, dim0=batch_dim, dim1=dim2, dim2=dim3).to(device) alpha = random_or_nothing(-1, 1).to(float) beta = random_or_nothing(-1, 1).to(float) return x.baddbmm(batch1, batch2, alpha=alpha, beta=beta) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_tensor_pin_memory.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import copy import os import unittest from collections import OrderedDict import numpy as np import oneflow as flow import oneflow.unittest from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() class TestTensor(flow.unittest.TestCase): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() @autotest(n=5, auto_backward=True, check_graph=False) def test_tensor_pin_memory(test_case): device = random_device() x = random_tensor(ndim=3).to(device) x2 = x.pin_memory() x3 = x2.pin_memory() test_case.assertTrue(id(x.pytorch) != id(x2.pytorch)) test_case.assertTrue(id(x3.pytorch) == id(x2.pytorch)) test_case.assertTrue(id(x.oneflow) != id(x2.oneflow)) test_case.assertTrue(id(x3.oneflow) == id(x2.oneflow)) return x3 @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() @autotest(n=5, auto_backward=False, check_graph=False) def test_0_dim_tensor_pin_memory(test_case): device = random_device() x = random_tensor(ndim=1).to(device) x1 = x[0] x2 = x1.pin_memory() x3 = x2.pin_memory() test_case.assertTrue(id(x1.pytorch) != id(x2.pytorch)) test_case.assertTrue(id(x3.pytorch) == id(x2.pytorch)) test_case.assertTrue(id(x1.oneflow) != id(x2.oneflow)) test_case.assertTrue(id(x3.oneflow) == id(x2.oneflow)) return x3 @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() @autotest(n=5, auto_backward=False, check_graph=False) def test_tensor_construct_with_pin_memory_param(test_case): device = random_device() n = random(1, 4).to(int) c = random(1, 4).to(int) h = random(1, 4).to(int) w = random(1, 4).to(int) x = random_tensor(ndim=4, dim0=n, dim1=c, dim2=h, dim3=w, pin_memory=True).to( device ) return x @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() @autotest(n=5, auto_backward=True, check_graph=False) def test_tensor_is_pinned(test_case): device = random_device() x = random_tensor(ndim=4).to(device) y = x.pin_memory() test_case.assertTrue(x.oneflow.is_pinned() == x.pytorch.is_pinned()) 
test_case.assertTrue(y.oneflow.is_pinned() == y.pytorch.is_pinned()) return y if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test/tensor/test_tensor_to_memory_format.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import random as random_util import oneflow as flow import oneflow.unittest import numpy as np from oneflow.test_utils.automated_test_util import * @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestTensor(flow.unittest.TestCase): @autotest(n=3) def test_to_memory_format(test_case): def check_equal(a, b): test_case.assertEqual(list(a.shape), list(b.shape)) test_case.assertEqual(list(a.stride()), list(b.stride())) test_case.assertEqual(a.is_contiguous(), b.is_contiguous()) test_case.assertTrue( np.allclose( a.detach().cpu().numpy(), b.detach().cpu().numpy(), 1e-06, 1e-06 ) ) device = random_device() x = random_tensor( ndim=4, dim0=random(1, 6).to(int), dim1=random(1, 6).to(int), dim2=random(1, 6).to(int), dim3=random(1, 6).to(int), ).to(device) oneflow_x = x.oneflow pytorch_x = x.pytorch # TODO(): implement backward with flow.no_grad(): oneflow_y = oneflow_x.to(memory_format=torch.contiguous_format.oneflow) pytorch_y = pytorch_x.to(memory_format=torch.contiguous_format.pytorch) check_equal(oneflow_y, pytorch_y) oneflow_y = oneflow_x.to(memory_format=torch.channels_last.oneflow) pytorch_y = pytorch_x.to(memory_format=torch.channels_last.pytorch) # Note: pytorch Tensor.to(channels_last) won't change tensor shape, so we should # permute it that only change the tensor shape and won't relayout its storage. # TODO(): align with pytorch check_equal(oneflow_y, pytorch_y.permute(0, 2, 3, 1)) if __name__ == "__main__": unittest.main() ================================================ FILE: python/oneflow/test_utils/__init__.py ================================================ ================================================ FILE: python/oneflow/test_utils/automated_test_util/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from .generators import * from .torch_flow_dual_object import * from .torch_flow_dual_object import torch from .profiler import profile import os ================================================ FILE: python/oneflow/test_utils/automated_test_util/generators.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import inspect import os import random as random_util import typing from collections import namedtuple from typing import Any, Dict, Optional, Tuple, Sequence, Union from itertools import product import numpy as np import torch import oneflow as flow from .global_scope import * from .util import broadcast py_tuple = tuple NoneType = type(None) TEST_MODULE = 0 TEST_FLOW = 1 TEST_TENSOR = 2 rng = np.random.default_rng() annotation2default_generator = {} annotation2torch_to_flow_converter = {} NoneType = type(None) random_value_default_range = {int: (-10, 11), float: (-1, 1), complex: (-10, 10)} def data_generator(annotation): def register_data_generator(cls): annotation2default_generator[annotation] = lambda: cls() return cls return register_data_generator def torch_to_flow_converter(annotation): def register_flow_to_flow_converter(func): annotation2torch_to_flow_converter[annotation] = func return func return register_flow_to_flow_converter @torch_to_flow_converter(torch.Tensor) def tensor_converter(torch_tensor): return flow.tensor(torch_tensor.cpu().numpy()) def convert_torch_object_to_flow(x): for (annotation, converter) in annotation2torch_to_flow_converter.items(): if isinstance(x, annotation): return converter(x) return x def pack(x): if isinstance(x, generator): return x return constant(x) class Nothing: pass class generator: def __init__(self, children): self.children = children self._value = None self._has_value = False def _init(self): self._value = None self._has_value = False for x in self.children: x._init() def eval(self): self._init() return self.value() def _calc_value(self): raise NotImplementedError() def value(self): if not self._has_value: self._value = self._calc_value() if is_global(): self._value = broadcast(self._value) self._has_value = True return self._value def size(self): return 1 def __or__(self, other): other = pack(other) return oneof( self, other, possibility=self.size() / (self.size() + other.size()) ) def __ror__(self, other): return self | other def __add__(self, other): return add(self, other) def __radd__(self, other): return self + other def __sub__(self, other): return self + neg(other) def __rsub__(self, other): return neg(self - other) def __mul__(self, other): return mul(self, other) def __rmul__(self, other): return self * other def to(self, annotation): self._to(annotation) for x in self.children: x.to(annotation) return self def _to(self, annotation): pass class add(generator): def __init__(self, a, b): self.a = pack(a) self.b = pack(b) super().__init__([self.a, self.b]) def _calc_value(self): return self.a.value() + self.b.value() class mul(generator): def 
__init__(self, a, b): self.a = pack(a) self.b = pack(b) super(mul, self).__init__([self.a, self.b]) def _calc_value(self): return self.a.value() * self.b.value() class neg(generator): def __init__(self, a): self.a = pack(a) super().__init__([self.a]) def _calc_value(self): return -self.a.value() class oneof(generator): def __init__(self, *args, possibility=None): self.args = list(map(pack, args)) super().__init__(self.args) if isinstance(possibility, float): assert len(args) == 2 possibility = [possibility, 1 - possibility] if possibility is None: possibility = [1 / len(args)] * len(args) self.possibility = pack(possibility) def _calc_value(self): rand = rng.random() sum = 0 for (i, possibility) in enumerate(self.possibility.value()): sum += possibility if sum > rand: return self.args[i].value() raise RuntimeError() def __call__(self, *args: Any, **kwds: Any) -> Any: return self._calc_value()(*args, **kwds) def size(self): return sum([x.size() for x in self.args]) class tuple(generator): def __init__(self, *args): self.args = list(map(pack, args)) super().__init__(self.args) def _calc_value(self): return py_tuple([x.value() for x in self.args]) class constant(generator): def __init__(self, x): super().__init__([]) self.x = x def _calc_value(self): return self.x class nothing(generator): def __init__(self): super().__init__([]) def _calc_value(self): return Nothing() class random(generator): def __init__(self, low=1, high=6): self.low = pack(low) self.high = pack(high) super().__init__([self.low, self.high]) self.annotation = None def _to(self, annotation): if self.annotation is not None: return if hasattr(annotation, "__origin__"): annotation = eval(repr(annotation)) self.annotation = annotation def _generate(self, annotation): if hasattr(annotation, "__origin__"): if annotation.__origin__ is Union: x = random_util.choice(annotation.__args__) return self._generate(x) if annotation.__origin__ is Tuple or annotation.__origin__ is py_tuple: return [self._generate(x) for x in annotation.__args__] else: raise NotImplementedError( f"Not implemented annotation {annotation} in random, type(annotation.__origin__) is {type(annotation.__origin__)}" ) (low, high) = (self.low.value(), self.high.value()) if annotation == int: val = int(rng.integers(low, high)) elif annotation == float: val = float(rng.random() * (high - low) + low) elif annotation == bool: val = random_util.choice([True, False]) elif annotation == complex: val_real = float(rng.random() * (high - low) + low) val_imag = float(rng.random() * (high - low) + low) val = val_real + 1.0j * val_imag elif annotation is None: val = None elif annotation is NoneType: val = None else: raise NotImplementedError( f"Not implemented annotation {annotation} in random" ) return val def _calc_value(self): return self._generate(self.annotation) def random_or_nothing(low, high): return oneof(random(low, high), nothing(), possibility=2 / 3) @data_generator(torch.Tensor) class random_pytorch_tensor(generator): def __init__( self, ndim=None, dim0=1, dim1=None, dim2=None, dim3=None, dim4=None, low=None, high=None, dtype=float, pin_memory=False, ): if ndim is None: ndim = random(1, 6) if dim0 is None: dim0 = random(1, 8) if dim1 is None: dim1 = random(1, 8) if dim2 is None: dim2 = random(1, 8) if dim3 is None: dim3 = random(1, 8) if dim4 is None: dim4 = random(1, 8) self.ndim = pack(ndim).to(int) self.dim0 = pack(dim0).to(int) self.dim1 = pack(dim1).to(int) self.dim2 = pack(dim2).to(int) self.dim3 = pack(dim3).to(int) self.dim4 = pack(dim4).to(int) self.low = 
pack(low).to(float) self.high = pack(high).to(float) self.dtype = pack(dtype) self.pin_memory = pin_memory super().__init__( [ self.ndim, self.dim0, self.dim1, self.dim2, self.dim3, self.dim4, self.low, self.high, self.dtype, self.pin_memory, ] ) def _calc_value(self): ndim = self.ndim.value() dim0 = self.dim0.value() dim1 = self.dim1.value() dim2 = self.dim2.value() dim3 = self.dim3.value() dim4 = self.dim4.value() dtype = self.dtype.value() low = self.low.value() high = self.high.value() if low is None: low = random_value_default_range[dtype][0] if high is None: high = random_value_default_range[dtype][1] pin_memory = self.pin_memory shape = rng.integers(low=1, high=8, size=ndim) if ndim == 0: shape = [] if ndim >= 1 and dim0 is not None: shape[0] = dim0 if ndim >= 2: shape[1] = dim1 if ndim >= 3: shape[2] = dim2 if ndim >= 4: shape[3] = dim3 if ndim == 5: shape[4] = dim4 pytorch_tensor = None if dtype == float: np_arr = rng.uniform(low=low, high=high, size=shape) res = torch.Tensor(np_arr) if pin_memory: res = res.pin_memory() return res elif dtype == int: np_arr = rng.integers(low=low, high=high, size=shape) res = torch.tensor(np_arr, dtype=torch.int64) if pin_memory: res = res.pin_memory() return res elif dtype == complex: np_arr = rng.uniform(low=low, high=high, size=shape) + 1.0j * rng.uniform( low=low, high=high, size=shape ) res = torch.tensor(np_arr, dtype=torch.complex64) if pin_memory: res = res.pin_memory() return res else: raise NotImplementedError(f"Not implemented dtype {dtype} in random") @data_generator(bool) def random_bool(): return random().to(bool) class random_device(generator): def __init__(self): super().__init__([]) def _calc_value(self): if os.getenv("ONEFLOW_TEST_CPU_ONLY"): return "cpu" else: return random_util.choice(["cuda", "cpu"]) class cpu_device(generator): def __init__(self): super().__init__([]) def _calc_value(self): return random_util.choice(["cpu"]) class gpu_device(generator): def __init__(self): super().__init__([]) def _calc_value(self): return random_util.choice(["cuda"]) @data_generator(torch.dtype) class random_pytorch_dtype(generator): none_dtype_seq = [None] bool_dtype_seq = [torch.bool] floating_dtype_seq = [torch.float, torch.double] half_dtype_seq = [torch.half] bfloat16_dtype_seq = [torch.bfloat16] complex_dtype_seq = [torch.complex64, torch.complex128] signed_int_dtype_seq = [torch.int8, torch.int32, torch.int64] unsigned_int_dtype_seq = [torch.uint8] int_dtype_seq = [torch.int8, torch.int32, torch.int64] image_dtype_seq = [torch.uint8, torch.float] index_dtype_seq = [torch.int32, torch.int64] arithmetic_dtype_seq = [*floating_dtype_seq, *int_dtype_seq] pod_dtype_seq = [*arithmetic_dtype_seq, *unsigned_int_dtype_seq, *bool_dtype_seq] all_dtype_seq = [*arithmetic_dtype_seq, torch.half, torch.bfloat16] seq_name_to_seq = { "None": none_dtype_seq, "bool": bool_dtype_seq, "float": floating_dtype_seq, "half": half_dtype_seq, "bfloat16": bfloat16_dtype_seq, "complex": complex_dtype_seq, "signed": signed_int_dtype_seq, "unsigned": unsigned_int_dtype_seq, "int": int_dtype_seq, "image": image_dtype_seq, "index": index_dtype_seq, "arithmetic": arithmetic_dtype_seq, "pod": pod_dtype_seq, "all": all_dtype_seq, } def __init__(self, seq_names): super().__init__([]) # concat related dtype_seq for name in seq_names self.data_type_seq = [ dtype for name in seq_names for dtype in self.seq_name_to_seq[name] ] def _calc_value(self): return random_util.choice(self.data_type_seq) class all_placement(generator): def __init__(self): super().__init__([]) 
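# NOTE: all_placement enumerates every flow.placement built from the available
# device types ("cpu", plus "cuda" unless ONEFLOW_TEST_CPU_ONLY is set),
# combined with the two rank hierarchies computed below: the flat (world_size,)
# layout and the (node_size, ranks_per_node) layout. As an illustrative,
# hypothetical example, on a 1-node 2-rank run _calc_all_placement() would
# yield placements equivalent to:
#   flow.placement("cuda", [0, 1]), flow.placement("cuda", [[0, 1]]),
#   flow.placement("cpu",  [0, 1]), flow.placement("cpu",  [[0, 1]])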
self.node_size = flow.env.get_node_size() self.world_size = flow.env.get_world_size() self.num_rank_for_each_node = self.world_size // self.node_size def __len__(self): return len(self.value()) def __getitem__(self, key): return self.value()[key] def _calc_device(self): if os.getenv("ONEFLOW_TEST_CPU_ONLY"): return [ "cpu", ] else: return ["cuda", "cpu"] def _calc_all_placement(self): all_device = self._calc_device() all_hierarchy = [ (self.world_size,), (self.node_size, self.num_rank_for_each_node), ] return [ flow.placement(device, np.array(range(self.world_size)).reshape(hierarchy)) for device, hierarchy in list(product(all_device, all_hierarchy)) ] def _calc_value(self): return self._calc_all_placement() class all_cpu_placement(all_placement): def __init__(self): super().__init__() def _calc_device(self): return ["cpu"] class all_cuda_placement(all_placement): def __init__(self): super().__init__() def _calc_device(self): return ["cuda"] class random_placement(all_placement): def __init__(self): super().__init__() def _calc_value(self): return random_util.choice(self._calc_all_placement()) class random_cpu_placement(random_placement): def __init__(self): super().__init__() def _calc_device(self): return ["cpu"] class random_gpu_placement(random_placement): def __init__(self): super().__init__() def _calc_device(self): return ["cuda"] class all_sbp(generator): def __init__( self, placement=None, dim=1, max_dim=0, except_split=False, except_broadcast=False, except_partial_sum=False, valid_split_axis: Optional[Union[int, Sequence[int]]] = None, ): super().__init__([]) if placement is not None: if isinstance(placement, random_placement): self.dim = len(placement.value().ranks.shape) elif isinstance(placement, flow.placement): self.dim = len(placement.ranks.shape) else: raise RuntimeError( f"placement should be instance of random_placement or oneflow.placement" ) else: self.dim = dim self.max_dim = max_dim self.except_split = except_split self.except_broadcast = except_broadcast self.except_partial_sum = except_partial_sum if valid_split_axis is not None: if isinstance(valid_split_axis, int): self.valid_split_axis = [ valid_split_axis, ] else: self.valid_split_axis = list(valid_split_axis) else: self.valid_split_axis = [i for i in range(self.max_dim)] def __len__(self): return len(self.value()) def __getitem__(self, key): return self.value()[key] def _calc_all_sbp(self): # scalar only use broadcast sbp if self.max_dim == 0: return [ [flow.sbp.broadcast for i in range(self.dim)], ] all_sbps = [] if not self.except_split: for i in range(self.max_dim): if i in self.valid_split_axis: all_sbps.append(flow.sbp.split(i)) if not self.except_broadcast: all_sbps.append(flow.sbp.broadcast) if not self.except_partial_sum: all_sbps.append(flow.sbp.partial_sum) return list(product(all_sbps, repeat=self.dim)) def _calc_value(self): return self._calc_all_sbp() class random_sbp(all_sbp): def __init__( self, placement=None, dim=1, max_dim=0, except_split=False, except_broadcast=False, except_partial_sum=False, valid_split_axis: Optional[Union[int, Sequence[int]]] = None, ): super().__init__( placement, dim, max_dim, except_split, except_broadcast, except_partial_sum, valid_split_axis, ) def _calc_value(self): return random_util.choice(self._calc_all_sbp()) @data_generator(torch.Tensor) class choice_pytorch_tensor(generator): def __init__(self, a, size=None, replace=True, p=None, dtype=int): self.a = a self.size = size self.replace = replace self.p = p self.dtype = dtype super().__init__( [self.a, self.size, 
self.replace, self.p, self.dtype,] ) def _calc_value(self): pytorch_tensor = None np_arr = np.random.choice(self.a, self.size, self.replace, self.p) torch_dtype = None return torch.tensor(np_arr.astype(self.dtype)) __all__ = [ "random_pytorch_tensor", "random_bool", "random_device", "random_pytorch_dtype", "cpu_device", "gpu_device", "random_placement", "random_cpu_placement", "random_gpu_placement", "all_placement", "all_cpu_placement", "all_cuda_placement", "random_sbp", "all_sbp", "random", "random_or_nothing", "oneof", "constant", "nothing", "choice_pytorch_tensor", ] ================================================ FILE: python/oneflow/test_utils/automated_test_util/global_scope.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ _global_is_global = False class GlobalScope: def __init__(self): pass def __enter__(self, *argc, **kwarg): global _global_is_global self.last_is_global = _global_is_global _global_is_global = True def __exit__(self, *argc, **kwarg): global _global_is_global _global_is_global = self.last_is_global def is_global(): return _global_is_global ================================================ FILE: python/oneflow/test_utils/automated_test_util/profiler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import functools import os from typing import Any, Callable, Iterable, List, Optional, Tuple import torch import oneflow as flow import oneflow.support.env_var_util from oneflow.test_utils.automated_test_util import ( torch_flow_dual_object as dual_object_module, ) __all__ = ["profile", "set_profiler_hook", "profile_dual_object", "profiled_framework"] def compose(*fs): def compose2(f, g): return lambda *a, **kw: f(g(*a, **kw)) return functools.reduce(compose2, fs) class ProfResult: def __init__( self, prof, num, kind, device, thread_num, op_name, args_description, additional_description=None, ): self.prof = prof self.num = num self.kind = kind self.device = device self.thread_num = thread_num self.op_name = op_name self.args_description = args_description self.additional_description = additional_description def __getattr__(self, attr): return getattr(self.prof, attr) WARMUP_NUM = int(os.getenv("ONEFLOW_PROFILE_WARMUP_NUM", 10)) RUN_NUM = int(os.getenv("ONEFLOW_PROFILE_RUN_NUM", 1000)) PROF_VERBOSE = flow.support.env_var_util.parse_boolean_from_env( "ONEFLOW_PROFILE_VERBOSE", False ) END_TO_END = "end-to-end" def run_torch( op, args, kwargs, device, num_threads, op_name, args_description, additional_description=None, ): assert device in ["cpu", "cuda"] if device == "cpu": torch.set_num_threads(num_threads) assert torch.get_num_threads() == num_threads activities = [torch.profiler.ProfilerActivity.CPU] else: activities = [torch.profiler.ProfilerActivity.CUDA] def tensor_to_device(x): if isinstance(x, torch.Tensor): return x.to(device) return x args = [tensor_to_device(arg) for arg in args] kwargs = {k: tensor_to_device(v) for k, v in kwargs.items()} for _ in range(WARMUP_NUM): op(*args, **kwargs) if PROF_VERBOSE: print( f'PyTorch ({f"CPU, num_threads={num_threads}" if device == "cpu" else "GPU"}):' ) with torch.profiler.profile(activities=activities) as prof: with torch.profiler.record_function(END_TO_END): for _ in range(RUN_NUM): op(*args, **kwargs) if PROF_VERBOSE: print(prof.key_averages().table(row_limit=10)) return ProfResult( prof, RUN_NUM, "PyTorch", device, num_threads, op_name, args_description, additional_description, ) def run_flow( op, args, kwargs, device, num_threads, op_name, args_description, additional_description=None, ): assert device in ["cpu", "cuda"] if device == "cpu": # NOTE: there is no flow.get_num_threads() flow.set_num_threads(num_threads) activities = [flow.profiler.ProfilerActivity.CPU] else: activities = [flow.profiler.ProfilerActivity.CUDA] def tensor_to_device(x): if isinstance(x, flow.Tensor): return x.to(device) return x args = [tensor_to_device(arg) for arg in args] kwargs = {k: tensor_to_device(v) for k, v in kwargs.items()} for _ in range(WARMUP_NUM): op(*args, **kwargs) if PROF_VERBOSE: print( f'OneFlow ({f"CPU, num_threads={num_threads}" if device == "cpu" else "GPU"}):' ) with flow.profiler.profile( activities=activities, record_bandwidth_for_cuda=flow.profiler.ProfilerActivity.CUDA in activities, ) as prof: with flow.profiler.record_function(END_TO_END): for _ in range(RUN_NUM): op(*args, **kwargs) if PROF_VERBOSE: print(prof.key_averages()) return ProfResult( prof, RUN_NUM, "OneFlow", device, num_threads, op_name, args_description, additional_description, ) def profile_dual_object(op): assert isinstance(op, dual_object_module.DualObject) torch_op = op.pytorch flow_op = op.oneflow def profiled_op(*args, **kwargs): if "profile_description" in kwargs: additional_description = kwargs["profile_description"] del kwargs["profile_description"] else: 
additional_description = None ( torch_args, torch_kwargs, flow_args, flow_kwargs, ) = dual_object_module.get_args(torch_op, *args, **kwargs) op_name = dual_object_module.to_string(op) args_description = dual_object_module.to_string(*args, **kwargs) result = [] for hardware_info in _hardware_info_list: if "oneflow" in profiled_framework: result.append( run_flow( flow_op, flow_args, flow_kwargs, *hardware_info, op_name, args_description, additional_description, ) ) else: result.append(None) for hardware_info in _hardware_info_list: if "pytorch" in profiled_framework: result.append( run_torch( torch_op, torch_args, torch_kwargs, *hardware_info, op_name, args_description, additional_description, ) ) else: result.append(None) return _profiler_hook(result) return profiled_op HardwareInfo = Tuple[str, Optional[int]] # (device_type, num_threads) _hardware_info_list: List[HardwareInfo] = [("cpu", 1), ("cuda", None)] _profiler_hook: Callable[[List[ProfResult]], Any] = lambda x: x profiled_framework: List[str] = ["oneflow", "pytorch"] def set_hardware_info_list(hardware_info_list: List[HardwareInfo]) -> None: global _hardware_info_list _hardware_info_list = hardware_info_list def set_profiler_hook(hook: Callable[[List[ProfResult]], Any]) -> None: global _profiler_hook _profiler_hook = hook def profile(op): def deco(f): def new_f(*args, **kwargs): dual_object_module.profiled_method_name.append(op.name) res = f(*args, **kwargs) dual_object_module.profiled_method_name.pop() return res return new_f return deco ================================================ FILE: python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import collections.abc import functools import inspect import copy import os import warnings import gc from typing import Union import numpy as np import oneflow as flow from oneflow.test_utils.automated_test_util import profiler as auto_profiler from oneflow.test_utils.test_util import type_name_to_flow_type flow.backends.cudnn.deterministic = True try: import torch as torch_original torch_original.backends.cudnn.deterministic = True torch_original.set_printoptions(profile="full") except ImportError: print( "automated_test_util module uses PyTorch to verify OneFlow module's interface and result. Please install Pytorch according `https://pytorch.org/get-started/locally/`." 
) from .util import broadcast from .global_scope import * from .generators import ( Nothing, generator, random_pytorch_tensor, random_pytorch_dtype, choice_pytorch_tensor, rng, ) postulate = [".rand", ".Tensor"] testing = False testing_graph = False testing_complex = False global_check_allclose = True global_atol = 1e-5 global_rtol = 1e-5 global_backward = True def torch_tensor_to_flow(x): return flow.tensor(x.cpu().numpy()) note_pytorch_method_names = [] note_pytorch_args = [] note_pytorch_kwargs = [] vis_tensor = [] vis_parameters = {} call_tensor_id = [] extra_input_tensor = [] class PyTorchDoesNotSupportError(Exception): def __init__(self, exc): self.exc = exc def __str__(self): return repr(self) def __repr__(self): return f"PyTorch error: {str(self.exc)}" class OneFlowGraphBuildOrRunError(Exception): def __init__(self, exc): self.exc = exc def __str__(self): return repr(self) def __repr__(self): return f"OneFlow nn.Graph Build Or Run Error: {str(self.exc)}" class BothDoNotSupportError(Exception): def __init__(self, th_exc, of_exc): self.th_exc = th_exc self.of_exc = of_exc def __str__(self): return repr(self) def __repr__(self): return f"PyTorch error: {str(self.th_exc)}\nOneFlow error: {str(self.of_exc)}" call_pytorch = None def get_tensor_shape(call_pytorch): shape_list = [] for i in range(len(call_pytorch.shape)): shape_list.append(call_pytorch.shape[i]) return shape_list def get_args(callable, *args, **kwargs): try: spec = inspect.getfullargspec(callable) spec_args = spec.args if spec_args[0] == "self": del spec_args[0] for (i, arg) in enumerate(args): arg_name = spec_args[i] annotation = spec.annotations[arg_name] if isinstance(arg, generator): arg.to(annotation) for (arg_name, arg) in kwargs.items(): annotation = spec.annotations[arg_name] if isinstance(arg, generator): arg.to(annotation) except: pass (pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs) = ([], {}, [], {}) def get_pytorch_value(x): if isinstance(x, DualObject): return x.pytorch return x def get_oneflow_value(x): if isinstance(x, DualObject): return x.oneflow return x def get_generator_value(x): if isinstance(x, generator): return x.value() return x for arg in args: # TODO: refine codes if isinstance(arg, (tuple, list)): pytorch_tuple_args = [] oneflow_tuple_args = [] for t in arg: t = get_generator_value(t) pytorch_tuple_args.append(get_pytorch_value(t)) oneflow_tuple_args.append(get_oneflow_value(t)) pytorch_args.append(tuple(pytorch_tuple_args)) oneflow_args.append(tuple(oneflow_tuple_args)) else: arg = get_generator_value(arg) pytorch_args.append(get_pytorch_value(arg)) oneflow_args.append(get_oneflow_value(arg)) for (key, value) in kwargs.items(): value = get_generator_value(value) if isinstance(value, Nothing): continue pytorch_kwargs[key] = get_pytorch_value(value) oneflow_kwargs[key] = get_oneflow_value(value) new_pytorch_args = [] new_pytorch_kwargs = {} for x in pytorch_args: if isinstance(x, (tuple, list)): new_x = f"(" len_x = len(x) for i in range(len_x): if type(x[i]) is torch_original.Tensor: if i < len_x - 1: new_x += f"Tensor({get_tensor_shape(x[i])}), " else: new_x += f"Tensor({get_tensor_shape(x[i])})" else: if i < len_x - 1: new_x += f"{x[i]}, " else: new_x += f"{x[i]}" new_x += f")" new_pytorch_args.append(new_x) continue if type(x) is torch_original.Tensor: new_pytorch_args.append(f"Tensor({get_tensor_shape(x)})") else: new_pytorch_args.append(x) for key, value in pytorch_kwargs.items(): if type(value) is torch_original.Tensor: new_pytorch_kwargs[key] = 
f"Tensor({get_tensor_shape(value)})" else: new_pytorch_kwargs[key] = value if not isinstance(callable, (torch_original.nn.Module)): if isinstance(call_pytorch, torch_original.Tensor): note_pytorch_method_names.append( f"Tensor({get_tensor_shape(call_pytorch)}).{callable.__name__}" ) elif isinstance(call_pytorch, torch_original.nn.Module): note_pytorch_method_names.append(f"Module.{callable.__name__}") else: note_pytorch_method_names.append(f"{callable.__name__}") else: note_pytorch_method_names.append(repr(callable)) note_pytorch_args.append(new_pytorch_args) note_pytorch_kwargs.append(new_pytorch_kwargs) return (pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs) def to_string(*args, **kwargs) -> str: def _to_string(x): if isinstance(x, DualObject): return x.name return str(x) strs = [] if len(args) > 0: strs.append(", ".join([_to_string(arg) for arg in args])) if len(kwargs) > 0: strs.append(", ".join([f"{k}={_to_string(v)}" for k, v in kwargs.items()])) return ", ".join(strs) counter = 0 align_exception = os.getenv("ONEFLOW_TEST_ALIGN_EXCEPTION") is not None def check_eager_graph_tensor(eager_res, graph_res): if ( global_check_allclose and isinstance(eager_res, flow.Tensor) and isinstance(graph_res, flow.Tensor) ): equality_res = np.allclose( eager_res.numpy(), graph_res.numpy(), rtol=global_rtol, atol=global_atol, equal_nan=True, ) return equality_res else: return True # NOTE(lixiang): Deepcopy the input parameters in order to correctly test the inplace version of the op. def get_args_copy(args, kwargs): copy_args = [] for arg in args: if flow.is_tensor(arg): copy_arg = arg.clone().detach() else: copy_arg = copy.deepcopy(arg) copy_args.append(copy_arg) copy_kwargs = {} for key, value in kwargs.items(): if flow.is_tensor(value): copy_kwargs[key] = value.clone().detach() else: copy_kwargs[key] = copy.deepcopy(value) return copy_args, copy_kwargs def get_fake_program_more_detail(oneflow, mode, func, args=None, kwargs=None): print(f"\033[1;33m============= {mode} ================\033[1;33m") print(f"\033[1;33mEnter {func} function\033[1;33m") try: if "__self__" in dir(oneflow) and flow.is_tensor(oneflow.__self__): print(f"\033[1;33m{oneflow.__self__}\033[1;33m") except: if flow.is_tensor(oneflow): print(f"\033[1;33m{oneflow}\033[1;33m") if args is not None: print(f"\033[1;33m{args}\033[1;33m") if kwargs is not None: print(f"\033[1;33m{kwargs}\033[1;33m") print_note_fake_program() print(f"\033[1;33mLeave {func} function\033[1;33m") print(f"\033[1;37m\033[1;37m") print("\n\n") # NOTE(lixiang): When the graph global test is executed, the func is used to get the device type. def get_global_test_device(oneflow_args, oneflow_kwargs=None): # The case when the parameter input of Op only has kwargs. if not oneflow_args: return oneflow_kwargs["placement"].type # The case when the parameter input of Op is tensors. elif isinstance(oneflow_args[0], flow.Tensor): return oneflow_args[0].placement.type # The case when the parameter input of Op is tensor. elif isinstance(oneflow_args[0], flow.placement): return oneflow_args[0].type # The case when the parameter input of Op is tuple. For example: test_0_dim_tensor. elif isinstance(oneflow_args[0], tuple): return oneflow_args[0][0].placement.type # When oneflow_args[0] is int or float, etc. else: return oneflow_args[1].placement.type # NOTE(lixiang): When oneflow is of type nn.Module, build the following Graph for testing. # graph_train_oneflow: is a deepcopy of oneflow. 
def get_module_graph_test(graph_train_oneflow, oneflow, verbose, oneflow_args, *args): of_sgd = flow.optim.SGD(graph_train_oneflow.parameters(), lr=0.001, momentum=0.9,) graph_train_parameters_len = 0 for param in oneflow._parameters.values(): if param is not None: graph_train_parameters_len += 1 if verbose: get_fake_program_more_detail( oneflow, "nn.Graph", "get_module_graph_test", oneflow_args ) class TestGraphOfModule(flow.nn.Graph): def __init__(self): super().__init__() self.test_module = graph_train_oneflow if global_backward and graph_train_parameters_len: self.add_optimizer(of_sgd) def build(self, *args): res = self.test_module(*args) forward_res = res if global_backward and graph_train_parameters_len: if isinstance(self.test_module.to(flow.nn.Module), flow.nn.LSTMCell): res = res[0] + res[1] elif isinstance(self.test_module.to(flow.nn.Module), flow.nn.LSTM): res = res[0].sum() + res[1][0].sum() + res[1][1].sum() elif isinstance(res, (tuple, list)): res = res[0] res = res.sum() res.backward() return forward_res try: test_g_res = TestGraphOfModule() except Exception as e: if not verbose: get_fake_program_more_detail( oneflow, "nn.Graph", "get_module_graph_test", oneflow_args ) raise OneFlowGraphBuildOrRunError(e) return test_g_res def check_oneflow_args_first_element_is_int(args): if isinstance(args, (tuple, list)) and len(args) > 0: if isinstance(args[0], (int, float)): return True elif isinstance(args[0], (tuple, list)): return check_oneflow_args_first_element_is_int(args[0]) return False # NOTE(lixiang): When oneflow is of functional type, build the following Graph for testing, and return the test results in Graph mode. # graph_functional_oneflow: is a deepcopy of oneflow. def get_functional_graph_res( graph_functional_oneflow, oneflow, oneflow_res, oneflow_args, oneflow_kwargs, verbose, *graph_args, **graph_kwargs, ): test_g_res = [] if verbose: get_fake_program_more_detail( oneflow, "nn.Graph", "get_functional_graph_res", oneflow_args, oneflow_kwargs, ) class TestGraphOfFunctional(flow.nn.Graph): def __init__(self): super().__init__() def build(self): return graph_functional_oneflow(*graph_args, **graph_kwargs) try: is_global_flag = is_global() # In graph mode, when the tensor on the cpu executes the to("cpu") method, a check error will be reported. if oneflow.__name__ == "to" or oneflow.__name__ == "_to": if isinstance(oneflow_res, flow.Tensor): # The global tensor needs to obtain the device type through placement.type. if is_global_flag: if ( oneflow_args and oneflow_res.placement.type == oneflow_args[0] ) or ( oneflow_kwargs and oneflow_res.placement.type == oneflow_kwargs["device"] ): test_g_res = oneflow_res # The tensor needs to obtain the device type through device.type. else: if ( oneflow_args and oneflow_res.device.type == oneflow_args[0] ) or ( oneflow_kwargs and oneflow_res.device.type == oneflow_kwargs["device"] ): test_g_res = oneflow_res else: pass # nn.Graph donot deal with Module type. EX: m.to_global(placement, sbp). elif oneflow.__name__ == "to_global": test_g_res = oneflow_res elif oneflow.__name__ == "Parameter": # nn.Graph donot deal with Parameter creation. test_g_res = oneflow_res # oneflow_args may be empty, such as dropout. elif is_global_flag and len(oneflow_args) == 0: test_g_res = oneflow_res # For some ops whose input parameters is int, 'int' object has no attribute 'placement'. 
elif ( is_global_flag and len(oneflow_args) != 0 and (check_oneflow_args_first_element_is_int(oneflow_args)) ): test_g_res = oneflow_res # When doing the global op test, get_global_test_device() will be executed, and temporarily skipping the graph autotest on cpu device. elif ( is_global_flag and oneflow.__name__ != "weight_norm" and (get_global_test_device(oneflow_args, oneflow_kwargs) == "cpu") ): test_g_res = oneflow_res else: test_g = TestGraphOfFunctional() test_g_res = test_g() except Exception as e: if not verbose: get_fake_program_more_detail( oneflow, "nn.Graph", "get_functional_graph_res", oneflow_args, oneflow_kwargs, ) raise OneFlowGraphBuildOrRunError(e) return test_g_res # NOTE(lixiang): When oneflow is of tensor type, build the following Graph for testing, and return the test results in Graph mode. # graph_tensor_oneflow is a deepcopy of oneflow. def get_tensor_graph_res( graph_tensor_oneflow, oneflow, verbose, *tensor_graph_args, **tensor_graph_kwargs ): test_g_res = [] if verbose: get_fake_program_more_detail( oneflow, "nn.Graph", "get_tensor_graph_res", tensor_graph_args, tensor_graph_kwargs, ) class TestGraphOfTensorMethod(flow.nn.Graph): def __init__(self): super().__init__() def build(self): return graph_tensor_oneflow(*tensor_graph_args, **tensor_graph_kwargs) try: # Set test_g_res = None, check_eager_graph_tensor will return True, the purpose is to temporarily skip the Graph global test on cpu. if is_global() and (get_global_test_device((oneflow,)) == "cpu"): test_g_res = None else: test_g = TestGraphOfTensorMethod() test_g_res = test_g() except Exception as e: if not verbose: get_fake_program_more_detail( oneflow, "nn.Graph", "get_tensor_graph_res", tensor_graph_args, tensor_graph_kwargs, ) raise OneFlowGraphBuildOrRunError(e) return test_g_res def get_oneflow_eager_res( oneflow, oneflow_args, oneflow_kwargs, verbose, is_tesnor_method=False ): if verbose: get_fake_program_more_detail( oneflow, "Eager", "get_oneflow_eager_res", oneflow_args, oneflow_kwargs ) if not is_tesnor_method: oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs) else: oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs) return oneflow_res # NOTE(lixiang): Check if the results of eager and graph are equal when oneflow is of type nn.Module or functional. def oneflow_eager_run_with_graph_check( oneflow, oneflow_args, oneflow_kwargs, testing_graph, verbose, *args ): if testing_graph: graph_args, graph_kwargs = get_args_copy(oneflow_args, oneflow_kwargs) if isinstance(oneflow, flow.nn.Module): graph_train_oneflow = copy.deepcopy(oneflow) if not is_global(): arg_device_type = "cpu" for arg in oneflow_args: if flow.is_tensor(arg): arg_device_type = arg.device.type graph_train_oneflow = graph_train_oneflow.to(arg_device_type) else: graph_functional_oneflow = copy.deepcopy(oneflow) oneflow_res = get_oneflow_eager_res(oneflow, oneflow_args, oneflow_kwargs, verbose) if testing_graph: find_check_module_func = True ignore_apis_list = ["tensor", "train"] test_g_res = [] if isinstance(oneflow, flow.nn.Module): test_g = get_module_graph_test( graph_train_oneflow, oneflow, verbose, oneflow_args, *args ) # When doing the global op test, get_global_test_device() will be executed, and temporarily skipping the graph autotest on cpu device. if is_global() and ( get_global_test_device(oneflow_args, oneflow_kwargs) == "cpu" ): test_g_res = oneflow_res else: # When testing module methods, kwargs are not considered. 
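# TestGraphOfModule.build() above takes positional arguments only, so the
# deep-copied graph_args are forwarded positionally and any keyword arguments
# from the eager call are not replayed in the graph run.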
test_g_res = test_g(*graph_args) elif oneflow.__name__ in ignore_apis_list: find_check_module_func = False # 1. "oneflow.nn.modules" not in oneflow.__module__: For avoid run nn.Module branch graph test, like fold op call Fold Module actually. # 2. inspect.isfunction(oneflow): Compared with the ordinary flow.xxx, oneflow.nn.modules.math_ops series op exist an extra layer of python wrapper. # 3. inspect.ismethod(oneflow) and "oneflow.nn.modules" in oneflow.__module__: For op that only has Tensor.xxx method, and call oneflow.xxx actually, like masked_fill. elif ( ( oneflow.__module__ is not None and ("oneflow.nn.modules" not in oneflow.__module__) ) or inspect.isfunction(oneflow) or ( inspect.ismethod(oneflow) and "oneflow.nn.modules" in oneflow.__module__ ) ): test_g_res = get_functional_graph_res( graph_functional_oneflow, oneflow, oneflow_res, oneflow_args, oneflow_kwargs, verbose, *graph_args, **graph_kwargs, ) if find_check_module_func: if isinstance(test_g_res, tuple): for _, g_res in enumerate(test_g_res): if not check_eager_graph_tensor(oneflow_res, g_res): get_fake_program_more_detail( oneflow, "Eager + nn.Graph", "oneflow_eager_run_with_graph_check", oneflow_args, oneflow_kwargs, ) else: if not check_eager_graph_tensor(oneflow_res, test_g_res): get_fake_program_more_detail( oneflow, "Eager + nn.Graph", "oneflow_eager_run_with_graph_check", oneflow_args, oneflow_kwargs, ) return oneflow_res # NOTE(lixiang): Check if the results of eager and graph are equal when oneflow is of type tensor. def oneflow_tensor_eager_run_with_graph_check( oneflow, oneflow_method, oneflow_args, oneflow_kwargs, testing_graph, verbose ): if testing_graph: tensor_graph_args, tensor_graph_kwargs = get_args_copy( oneflow_args, oneflow_kwargs ) graph_tensor_oneflow = copy.deepcopy(oneflow_method) oneflow_res = get_oneflow_eager_res( oneflow_method, oneflow_args, oneflow_kwargs, verbose, is_tesnor_method=True ) if testing_graph: test_g_res = get_tensor_graph_res( graph_tensor_oneflow, oneflow, verbose, *tensor_graph_args, **tensor_graph_kwargs, ) if isinstance(test_g_res, tuple): for _, g_res in enumerate(test_g_res): if not check_eager_graph_tensor(oneflow_res, g_res): get_fake_program_more_detail( oneflow, "nn.Graph", "oneflow_tensor_eager_run_with_graph_check", oneflow_args, oneflow_kwargs, ) else: if not check_eager_graph_tensor(oneflow_res, test_g_res): get_fake_program_more_detail( oneflow, "nn.Graph", "oneflow_tensor_eager_run_with_graph_check", oneflow_args, oneflow_kwargs, ) return oneflow_res def get_pytorch_oneflow_res( pytorch, oneflow, pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs, name, verbose, testing_graph, *args, ): try: pytorch_res = pytorch(*pytorch_args, **pytorch_kwargs) if isinstance(pytorch_res, torch_original.Tensor): call_flag = True source_flag = True for x in pytorch_args: if isinstance(x, (tuple, list)): for y in x: if torch_original.is_tensor(y): source_flag = False if ( id(pytorch_res) == id(y) and pytorch_res.device.type == y.device.type ): call_flag = False break elif torch_original.is_tensor(x): source_flag = False if ( id(pytorch_res) == id(x) and pytorch_res.device.type == x.device.type ): call_flag = False break for x in pytorch_kwargs.values(): if isinstance(x, (tuple, list)): for y in x: if torch_original.is_tensor(y): source_flag = False if ( id(pytorch_res) == id(y) and pytorch_res.device.type == y.device.type ): call_flag = False break elif torch_original.is_tensor(x): source_flag = False if ( id(pytorch_res) == id(x) and pytorch_res.device.type == 
x.device.type ): call_flag = False break if source_flag and pytorch.__name__ != "to": call_tensor_id.append(id(pytorch_res)) extra_input_tensor.append(pytorch_res) elif call_flag: call_tensor_id.append(id(pytorch_res)) except Exception as e: if align_exception: try: oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs) except Exception as ee: raise BothDoNotSupportError(e, ee) from None print( "PyTorch has an error but OneFlow is ok, maybe you should check your implementation to align with PyTorch." ) get_fake_program_more_detail( oneflow, "Eager", "get_pytorch_oneflow_res", oneflow_args, oneflow_kwargs, ) raise PyTorchDoesNotSupportError(e) if name in postulate: oneflow_res = torch_tensor_to_flow(pytorch_res) else: oneflow_res = oneflow_eager_run_with_graph_check( oneflow, oneflow_args, oneflow_kwargs, testing_graph, verbose, *args, ) return pytorch_res, oneflow_res def get_pytorch_oneflow_tensor_res( pytorch_method, oneflow_method, oneflow, pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs, testing_graph, verbose, ): try: pytorch_res = pytorch_method(*pytorch_args, **pytorch_kwargs) if isinstance(pytorch_res, torch_original.Tensor): if ( id(pytorch_res) != id(pytorch_method.__self__) or pytorch_res.device.type == pytorch_method.__self__.device.type ): call_tensor_id.append(id(pytorch_res)) except Exception as e: if align_exception: try: oneflow_res = oneflow_method(*oneflow_args, **oneflow_kwargs) except Exception as ee: raise BothDoNotSupportError(e, ee) from None print( "PyTorch has an error but OneFlow is ok, maybe you should check your implementation to align with PyTorch." ) raise PyTorchDoesNotSupportError(e) oneflow_res = oneflow_tensor_eager_run_with_graph_check( oneflow, oneflow_method, oneflow_args, oneflow_kwargs, testing_graph, verbose, ) return pytorch_res, oneflow_res profiled_method_name = [] def GetDualObject(name, pytorch, oneflow): global counter counter += 1 skipped_magic_methods = [ "__class__", "__mro__", "__new__", "__init__", "__getattr__", "__setattr__", "__getattribute__", "__dict__", "__weakref__", "__builtins__", "__qualname__", "__name__", "__str__", "__repr__", ] verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None pytorch_methods = dir(pytorch) if hasattr(pytorch, "__call__") and "__call__" not in pytorch_methods: pytorch_methods.append("__call__") magic_methods_for_new_cls = {} for method_name in pytorch_methods: if method_name.startswith("__") and method_name not in skipped_magic_methods: def get_dual_method(method_name): if method_name == "__call__": if name in profiled_method_name: def method(self, *args, **kwargs): return auto_profiler.profile_dual_object(self)( *args, **kwargs ) return method def dual_method(self, *args, **kwargs): param_str = to_string(*args, **kwargs) ( pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs, ) = get_args(pytorch, *args, **kwargs) pytorch_res, oneflow_res = get_pytorch_oneflow_res( pytorch, oneflow, pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs, name, verbose, testing_graph, *args, ) return GetDualObject( f"{name}({param_str})", pytorch_res, oneflow_res ) else: def dual_method(self, *args, **kwargs): pytorch_method = getattr(pytorch, method_name) oneflow_method = getattr(oneflow, method_name) ( pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs, ) = get_args(pytorch_method, *args, **kwargs) pytorch_res, oneflow_res = get_pytorch_oneflow_tensor_res( pytorch_method, oneflow_method, oneflow, pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs, testing_graph, verbose, ) return 
GetDualObject("unused", pytorch_res, oneflow_res) return dual_method magic_methods_for_new_cls[method_name] = get_dual_method(method_name) Cls = type(f"{name}_{counter}", (DualObject,), magic_methods_for_new_cls) return Cls(name, pytorch, oneflow) def note_print_args(x, end=True): if end: if isinstance(x, str) and "Tensor" not in x: print(f"\033[32m{x}, \033[0m", end="") else: print(f"\033[32m{x}, \033[0m", end="") else: if isinstance(x, str) and "Tensor" not in x: print(f"\033[32m{x}\033[0m", end="") else: print(f"\033[32m{x}\033[0m", end="") def note_print_kwargs(x, y, end=True): if end: if isinstance(y, str) and "Tensor" not in y: print(f"\033[32m{x}={y}, \033[0m", end="") else: print(f"\033[32m{x}={y}, \033[0m", end="") else: if isinstance(y, str) and "Tensor" not in y: print(f"\033[32m{x}={y}\033[0m", end="") else: print(f"\033[32m{x}={y}\033[0m", end="") def print_note_fake_program(detail=False): code_len = len(note_pytorch_method_names) for i in range(code_len): note_pytorch_args_len = len(note_pytorch_args[i]) note_pytorch_kwargs_len = len(note_pytorch_kwargs[i]) print(f"\033[32m{note_pytorch_method_names[i]}\033[0m", end="") print(f"\033[32m(\033[0m", end="") if note_pytorch_args[i]: index = 0 for x in note_pytorch_args[i]: index += 1 note_print_args(x, index < note_pytorch_args_len) if note_pytorch_kwargs[i]: index = 0 if note_pytorch_args[i]: print(f"\033[32m, \033[0m", end="") for x in note_pytorch_kwargs[i].keys(): index += 1 note_print_kwargs( x, note_pytorch_kwargs[i][x], index < note_pytorch_kwargs_len ) print(f"\033[32m)\033[0m") if detail: print( f"\033[32m-----------------------------------------------------------\033[0m" ) unique_vis_tensor = [] flag_vis_input_tensor = [False for _ in range(len(vis_tensor))] for i in range(len(vis_tensor)): if flag_vis_input_tensor[i] == True: continue unique_vis_tensor.append(vis_tensor[i]) flag_vis_input_tensor[i] = True for j in range(i + 1, len(vis_tensor)): if ( id(vis_tensor[i]) == id(vis_tensor[j]) and flag_vis_input_tensor[j] == False ): flag_vis_input_tensor[j] = True unique_extra_tensor = [] flag_vis_extra_tensor = [False for _ in range(len(extra_input_tensor))] for i in range(len(extra_input_tensor)): if flag_vis_extra_tensor[i] == True: continue unique_extra_tensor.append(extra_input_tensor[i]) flag_vis_extra_tensor[i] = True for j in range(i + 1, len(extra_input_tensor)): if ( id(extra_input_tensor[i]) == id(extra_input_tensor[j]) and flag_vis_extra_tensor[j] == False ): flag_vis_extra_tensor[j] = True print( f"\033[32mThis program has {len(unique_extra_tensor) + len(unique_vis_tensor)} input tensor: \033[0m" ) for input_tensor in iter(unique_extra_tensor): print(f"\033[32mShape{get_tensor_shape(input_tensor)}\033[0m") print(f"\033[32m{input_tensor}\033[0m") print( f"\033[32m-----------------------------------------------------------\033[0m" ) for input_tensor in iter(unique_vis_tensor): print(f"\033[32mShape{get_tensor_shape(input_tensor)}\033[0m") print(f"\033[32m{input_tensor}\033[0m") print( f"\033[32m-----------------------------------------------------------\033[0m" ) if vis_parameters: print( f"\033[32m-------------------nn.Module Parameters---------------------\033[0m" ) for name, param in vis_parameters.items(): print(f"\033[32m{name}: {param}\033[0m") def clear_note_fake_program(): note_pytorch_method_names.clear() note_pytorch_args.clear() note_pytorch_kwargs.clear() call_tensor_id.clear() vis_tensor.clear() vis_parameters.clear() extra_input_tensor.clear() flow.set_printoptions(profile="full") 
tensor_size_limit_mb = int(os.getenv("ONEFLOW_TEST_TENSOR_SIZE_LIMIT_MB", 32)) class DualObject: def __init__(self, name, pytorch, oneflow): self.name = name if isinstance(pytorch, torch_original.nn.Module): if is_global(): pytorch.load_state_dict(broadcast(pytorch).state_dict()) state_dict = pytorch.state_dict() state_dict = {k: v.detach().cpu().numpy() for (k, v) in state_dict.items()} oneflow_state_dict = oneflow.state_dict() oneflow_state_dict = { k: v.detach() for (k, v) in oneflow_state_dict.items() } already_global = any([v.is_global for v in oneflow_state_dict.values()]) if is_global() and already_global: for k, v in state_dict.items(): if k not in oneflow_state_dict: continue of_state = oneflow_state_dict[k] if of_state.is_global: state_dict[k] = flow.tensor( v, sbp=of_state.sbp, placement=of_state.placement ) oneflow.load_state_dict(state_dict, strict=False) if is_global(): if already_global: for (k, v) in oneflow_state_dict.items(): if v.is_global: t = getattr(oneflow, k) new = t.to_global(placement=v.placement, sbp=v.sbp) if isinstance(t, flow.nn.Parameter): new = flow.nn.Parameter(new) setattr( oneflow, k, new, ) else: oneflow = oneflow.to_global( placement=flow.placement.all("cpu"), sbp=[flow.sbp.broadcast,], ) if testing: dual_modules_to_test.append(self) if isinstance(pytorch, torch_original.Tensor): tensor_size_mb = pytorch.nelement() * pytorch.element_size() / 1024 / 1024 assert ( tensor_size_mb < tensor_size_limit_mb ), f"Tensor memory in autotest cannot be larger than {tensor_size_limit_mb}MB, but got {tensor_size_mb}MB" if testing: dual_objects_to_test.append(self) self.pytorch = pytorch self.oneflow = oneflow def __repr__(self): return f"PyTorch object:\n{self.pytorch}\n\nOneFlow object:\n{self.oneflow}" def __getattr__(self, key): if key in ["to_global", "to_local"]: def identity(*args, **kwargs): if isinstance(self.pytorch, torch_original.Tensor): return self.pytorch.clone() return self.pytorch pytorch_attr = identity elif key in ["placement", "sbp"]: pytorch_attr = "unused" elif key in ["broadcast_like"]: def broadcast_like(x, y, *args, **kwargs): return self.pytorch.broadcast_to(x, y.size()) pytorch_attr = broadcast_like else: pytorch_attr = getattr(self.pytorch, key) oneflow_attr = getattr(self.oneflow, key) if pytorch_attr is None: assert ( oneflow_attr is None ), f"pytorch value is None for attr {key}, but oneflow is not." 
return None if self.name == "": new_name = key else: new_name = f"{self.name}.{key}" global call_pytorch call_pytorch = self.pytorch return GetDualObject(new_name, pytorch_attr, oneflow_attr) def __setattr__(self, key, value): if isinstance(value, DualObject): setattr(self.pytorch, key, value.pytorch) setattr(self.oneflow, key, value.oneflow) else: self.__dict__[key] = value def __eq__(self, other): if isinstance(other, DualObject): return self.pytorch == other.pytorch and self.oneflow == other.oneflow else: return self.pytorch == other dual_modules_to_test = [] dual_objects_to_test = [] torch_type2checker = {} def equality_checker(torch_type, flow_type): def deco(f): torch_type2checker[torch_type, flow_type] = f return f return deco def check_equality(dual_object: DualObject, rtol=0.0001, atol=1e-05, check_dtype=False): checker = torch_type2checker.get( (type(dual_object.pytorch), type(dual_object.oneflow)), None ) if checker is None: for (key, value) in torch_type2checker.items(): if isinstance(dual_object.pytorch, key[0]) and isinstance( dual_object.oneflow, key[1] ): checker = value break assert checker is not None, ( "checker not found for type " + str(type(dual_object.pytorch)) + " and " + str(type(dual_object.oneflow)) ) return checker(dual_object.pytorch, dual_object.oneflow, rtol, atol, check_dtype) @equality_checker(torch_original.Tensor, flow.Tensor) @equality_checker(torch_original.Tensor, flow._oneflow_internal.Tensor) def check_tensor_equality( torch_tensor, flow_tensor, rtol=0.0001, atol=1e-05, check_dtype=False ): if torch_tensor.grad is not None: if flow_tensor.grad is None: print_note_fake_program(detail=True) assert ( flow_tensor.grad is not None ), f"OneFlow tensor doesn't have grad while PyTorch tensor has one, PyTorch tensor is\n {torch_tensor}\n, OneFlow tensor is\n{flow_tensor} " torch_grad = ( torch_tensor.grad.detach().cpu().numpy() if not torch_original.is_conj(torch_tensor.grad) else torch_original.resolve_conj(torch_tensor.grad.detach()).cpu().numpy() ) flow_grad = flow_tensor.grad.numpy() if not np.allclose( torch_grad, flow_grad, rtol=rtol, atol=atol, equal_nan=True, ): print_note_fake_program(detail=True) print("---------Grad Shape--------") print(torch_grad.shape) print(flow_grad.shape) print( f"Grads are not equal. 
PyTorch grad: \n{torch_grad}\n, OneFlow grad: \n{flow_grad}" ) return False torch_numpy = ( torch_tensor.detach().cpu().numpy() if not torch_original.is_conj(torch_tensor) else torch_original.resolve_conj(torch_tensor.detach()).cpu().numpy() ) oneflow_numpy = flow_tensor.numpy() equality_res = np.allclose( torch_numpy, oneflow_numpy, rtol=rtol, atol=atol, equal_nan=True, ) # NOTE: if check_dtype=True, then check the equality of data type if check_dtype: equality_res = equality_res and (torch_numpy.dtype == oneflow_numpy.dtype) if equality_res == False: print_note_fake_program(detail=True) print("---------Tensor Shape--------") print(torch_tensor.shape) print(flow_tensor.shape) print("---------Tensor dtype--------") print(torch_tensor.dtype) print(flow_tensor.dtype) return equality_res @equality_checker(int, int) @equality_checker(bool, bool) def check_basetype_equality(a, b, ignored1, ignored2, check_dtype=False): if check_dtype: return (a == b) and (type(a) == type(b)) return a == b @equality_checker(tuple, tuple) @equality_checker(list, list) def check_basetype_equality(a, b, rtol=0.0001, atol=1e-05, check_dtype=False): if len(a) != len(b): equality_res = False else: for i in range(len(a)): torch_np = a[i].detach().cpu().numpy() flow_np = b[i].detach().cpu().numpy() equality_res = np.allclose( torch_np, flow_np, rtol=rtol, atol=atol, equal_nan=True, ) if check_dtype: equality_res = equality_res and (torch_np.dtype == flow_np.dtype) if equality_res == False: print_note_fake_program(detail=True) print("---------Tensor Shape--------") print(a[i].shape) print(b[i].shape) print("---------Tensor dtype--------") print(a[i].dtype) print(b[i].dtype) break return equality_res @equality_checker(type(None), type(None)) def check_nonetype_equality(a, b, ignored1, ignored2, check_dtype=False): return True def autotest( n=20, auto_backward: Union[bool, str] = True, rtol=0.0001, atol=1e-05, check_graph=True, check_allclose=True, check_dtype=False, check_grad_use_random_data=True, include_complex=False, ): verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None if check_graph == "ValidatedFalse": # check graph is intentionally closed and there is a validated reason. 
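# A test author passes this sentinel explicitly, e.g. roughly
#     @autotest(n=5, check_graph="ValidatedFalse")
# to document that skipping the nn.Graph comparison for that op is deliberate.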
check_graph = False def deco(f): @functools.wraps(f) def new_f(test_case, *args, **kwargs): successful_runs_needed = n loop_limit = successful_runs_needed * 20 current_run = 0 while successful_runs_needed > 0: clear_note_fake_program() if current_run > loop_limit: raise ValueError( "autotest stuck in an endless loop, usually it is caused by invalid code in the test case" ) dual_modules_to_test.clear() dual_objects_to_test.clear() global global_check_allclose, global_rtol, global_atol, global_backward global_check_allclose = check_allclose global_rtol = rtol global_atol = atol global_backward = auto_backward try: global testing_graph # for generate fake program input tensor global testing testing = True if check_graph: testing_graph = True global testing_complex if include_complex: testing_complex = True testing_graph = False res = f(test_case, *args, **kwargs) testing = False testing_graph = False testing_complex = False except (PyTorchDoesNotSupportError, BothDoNotSupportError) as e: if verbose: print(f"{f.__name__}") print(e) current_run += 1 continue if res is not None: if not isinstance(res, collections.abc.Sequence): res = [res] for x in res: if x is None: continue if auto_backward: if isinstance(x.pytorch, torch_original.Tensor): if auto_backward == "auto" and ( not x.pytorch.requires_grad or not x.oneflow.requires_grad ): continue call_tensor_id.append(id(x.pytorch)) if check_grad_use_random_data: np_arr = rng.uniform( low=0, high=1, size=list(x.oneflow.shape) ) if is_global(): np_arr = broadcast(np_arr) flow_tensor = flow.tensor( np_arr, dtype=x.oneflow.dtype, placement=x.oneflow.placement, sbp=len(x.oneflow.sbp) * [flow.sbp.broadcast], ) else: flow_tensor = flow.tensor( np_arr, dtype=x.oneflow.dtype, device=x.oneflow.device, ) # TODO(): Inferred shape of some op is different between oneflow and torch pytorch_tensor = torch_original.tensor( np_arr.reshape(list(x.pytorch.shape)), dtype=x.pytorch.dtype, device=x.pytorch.device, ) call_tensor_id.append(id(pytorch_tensor)) diff_output = GetDualObject( "unused", pytorch_tensor, flow_tensor ) x.backward(diff_output) else: x.sum().backward() dual_objects_to_test.append(x) for x in dual_modules_to_test: for key in x.pytorch.state_dict().keys(): if key not in x.oneflow.state_dict().keys(): warnings.warn(f"oneflow module don't have `{key}`") continue vis_parameters[key] = x.pytorch.state_dict()[key] dual_objects_to_test.append( GetDualObject( "unused", getattr(x.pytorch, key), getattr(x.oneflow, key), ) ) call_tensor_id.append(id(getattr(x.pytorch, key))) dual_objects_to_test.append( GetDualObject( "unused", getattr(x.pytorch, key).grad, getattr(x.oneflow, key).grad, ) ) call_tensor_id.append(id(getattr(x.pytorch, key).grad)) for x in dual_objects_to_test: if ( isinstance(x.pytorch, torch_original.Tensor) and id(x.pytorch) not in call_tensor_id ): vis_tensor.append(x.pytorch) # check eager for x in dual_objects_to_test: if check_allclose: test_case.assertTrue( check_equality( x, rtol=rtol, atol=atol, check_dtype=check_dtype, ), x, ) if verbose: print(f"{f.__name__} test eager passed.") if verbose and check_graph: print(f"{f.__name__} test graph passed.") successful_runs_needed -= 1 current_run += 1 return new_f return deco def globaltest(f): @functools.wraps(f) def new_f(*args, **kwargs): with GlobalScope() as scope: return f(*args, **kwargs) return new_f def random_tensor( ndim=None, dim0=1, dim1=None, dim2=None, dim3=None, dim4=None, low=None, high=None, dtype=float, requires_grad=True, pin_memory=False, ): if isinstance(requires_grad, 
generator): requires_grad = requires_grad.value() if dtype == float and testing_complex: # Generate complex with the probability of 0.5 dtype = complex if rng.integers(0, 2) == 1 else float pytorch_tensor = ( random_pytorch_tensor( ndim, dim0, dim1, dim2, dim3, dim4, low, high, dtype, pin_memory ) .value() .requires_grad_(requires_grad and dtype != int) ) extra_input_tensor.append(pytorch_tensor) if is_global(): flow_tensor = flow.tensor( pytorch_tensor.detach().cpu().numpy(), requires_grad=(requires_grad and dtype != int), placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast, ) else: flow_tensor = flow.tensor( pytorch_tensor.detach().cpu().numpy(), requires_grad=(requires_grad and dtype != int), pin_memory=pin_memory, ) return GetDualObject("unused", pytorch_tensor, flow_tensor) def random_dtype(seq_names): pytorch_dtype = random_pytorch_dtype(seq_names).value() if pytorch_dtype is None: flow_dtype = None else: flow_dtype = type_name_to_flow_type[pytorch_dtype.__str__().split(".")[-1]] return GetDualObject("DualDType", pytorch_dtype, flow_dtype) def choice_tensor( a, size=None, replace=True, p=None, dtype=int, requires_grad=False, ): """Generates a random sample from a given 1-D array, which aligns with numpy.random.choice see https://numpy.org/doc/stable/reference/random/generated/numpy.random.choice.html for details """ if isinstance(requires_grad, generator): requires_grad = requires_grad.value() pytorch_tensor = ( choice_pytorch_tensor(a, size, replace, p, dtype) .value() .requires_grad_(requires_grad and dtype != int) ) if is_global(): flow_tensor = flow.tensor( pytorch_tensor.detach().cpu().numpy(), requires_grad=(requires_grad and dtype != int), placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast, ) else: flow_tensor = flow.tensor( pytorch_tensor.detach().cpu().numpy(), requires_grad=(requires_grad and dtype != int), ) return GetDualObject("unused", pytorch_tensor, flow_tensor) torch = GetDualObject("", torch_original, flow) __all__ = ["autotest", "globaltest", "random_tensor", "random_dtype", "choice_tensor"] ================================================ FILE: python/oneflow/test_utils/automated_test_util/util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import pickle import oneflow as flow def broadcast(obj, src: int = 0): rank = flow.env.get_rank() if src == rank: obj_bytes = pickle.dumps(obj) obj_bytes = flow._oneflow_internal.cpu_broadcast(obj_bytes, src) else: obj_bytes = flow._oneflow_internal.cpu_broadcast(None, src) return pickle.loads(obj_bytes) ================================================ FILE: python/oneflow/test_utils/oneflow_pytorch_compatibility/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from .oneflow_pytorch_compatiblity_test import * ================================================ FILE: python/oneflow/test_utils/oneflow_pytorch_compatibility/oneflow_pytorch_compatiblity_test.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import importlib.util import unittest import numpy as np import time import tempfile import argparse import oneflow as flow import torch import oneflow.unittest import shutil import matplotlib as mpl mpl.use("Agg") import matplotlib.pyplot as plt verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None def cos_sim(vector_a, vector_b): vector_a = np.mat(vector_a) vector_b = np.mat(vector_b) num = float(vector_a * vector_b.T) denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b) cos = num / denom sim = 0.5 + 0.5 * cos return sim def import_file(source): with tempfile.NamedTemporaryFile("w", suffix=".py") as f: f.write(source) f.flush() spec = importlib.util.spec_from_file_location("mod", f.name) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) return mod def get_loss( image_nd, label_nd, model_path: str, module_name: str, test_pytorch: bool = True, device: str = "cuda", tmpfilename: str = "/tmp/oneflow_tmp_file", ): model_loss = [] learning_rate = 0.01 mom = 0.9 bp_iters = 100 for_time = 0.0 bp_time = 0.0 update_time = 0.0 if test_pytorch == True: image = flow.tensor(image_nd) label = flow.tensor(label_nd) corss_entropy = flow.nn.CrossEntropyLoss(reduction="mean") with open(model_path) as f: buf = f.read() lines = buf.split("\n") buf = "\n".join(lines) python_module = import_file(buf) Net = getattr(python_module, module_name) pytorch_model = Net() w = pytorch_model.state_dict() new_parameters = dict() for k, v in w.items(): if "num_batches_tracked" not in k: new_parameters[k] = flow.tensor(w[k].detach().numpy()) flow.save(new_parameters, tmpfilename) pytorch_model.to(device) torch_sgd = torch.optim.SGD( pytorch_model.parameters(), lr=learning_rate, momentum=mom ) image = torch.tensor(image_nd) image_gpu = image.to(device) corss_entropy = torch.nn.CrossEntropyLoss() corss_entropy.to(device) label = torch.tensor(label_nd, dtype=torch.long).to(device) print("start pytorch training loop....") start_t = time.time() for i in range(bp_iters): s_t = time.time() logits = pytorch_model(image_gpu) loss = corss_entropy(logits, label) for_time += time.time() - s_t s_t = time.time() loss.backward() bp_time += time.time() - s_t model_loss.append(loss.detach().cpu().numpy()) s_t = time.time() torch_sgd.step() torch_sgd.zero_grad() update_time += time.time() - s_t end_t = 
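# Self-contained illustration of the cosine-similarity check this file relies on:
# two loss curves are treated as vectors and compared by direction, so a value near
# 1.0 means the OneFlow and PyTorch runs decayed alike. The sample numbers are made
# up; the original cos_sim above additionally rescales the result to 0.5 + 0.5*cos.
import numpy as np


def cosine_similarity(curve_a, curve_b):
    a = np.asarray(curve_a, dtype=np.float64)
    b = np.asarray(curve_b, dtype=np.float64)
    return float(a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))


# e.g. cosine_similarity([2.3, 1.8, 1.2, 0.9], [2.4, 1.7, 1.3, 0.8]) ~= 0.998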
time.time() if verbose: print( "pytorch traning loop avg time : {}".format( (end_t - start_t) / bp_iters ) ) print("forward avg time : {}".format(for_time / bp_iters)) print("backward avg time : {}".format(bp_time / bp_iters)) print("update parameters avg time : {}".format(update_time / bp_iters)) else: with open(model_path) as f: buf = f.read() lines = buf.split("\n") for i, line in enumerate(lines): if ( i > 15 and "import" not in line and len(line.strip()) != 0 ): # 15 means license break lines = ( lines[:i] + [ "import oneflow as torch", "import oneflow.nn as nn", "import oneflow.nn.init as init", "import oneflow.nn.functional as F", "from oneflow import Tensor", "from oneflow.nn import Parameter", "import math", "from flowvision.layers import *", ] + lines[i:] ) buf = "\n".join(lines) python_module = import_file(buf) Net = getattr(python_module, module_name) oneflow_model = Net() image = flow.tensor(image_nd) label = flow.tensor(label_nd) corss_entropy = flow.nn.CrossEntropyLoss(reduction="mean") image_gpu = image.to(device) label = label.to(device) oneflow_model.to(device) corss_entropy.to(device) params = flow.load(tmpfilename) oneflow_model.load_state_dict(params) of_sgd = flow.optim.SGD( oneflow_model.parameters(), lr=learning_rate, momentum=mom ) print("start oneflow training loop....") start_t = time.time() for i in range(bp_iters): s_t = time.time() logits = oneflow_model(image_gpu) loss = corss_entropy(logits, label) for_time += time.time() - s_t s_t = time.time() loss.backward() bp_time += time.time() - s_t model_loss.append(loss.numpy()) s_t = time.time() of_sgd.step() of_sgd.zero_grad() update_time += time.time() - s_t end_t = time.time() if verbose: print( "oneflow traning loop avg time : {}".format( (end_t - start_t) / bp_iters ) ) print("forward avg time : {}".format(for_time / bp_iters)) print("backward avg time : {}".format(bp_time / bp_iters)) print("update parameters avg time : {}".format(update_time / bp_iters)) return model_loss def do_test_train_loss_oneflow_pytorch( test_case, model_path: str, module_name: str, device: str = "cuda", batch_size: int = 16, img_size: int = 224, ): image_nd = np.random.rand(batch_size, 3, img_size, img_size).astype(np.float32) label_nd = np.array([e for e in range(batch_size)], dtype=np.int32) oneflow_model_loss = [] pytorch_model_loss = [] with tempfile.NamedTemporaryFile() as f: pytorch_model_loss = get_loss( image_nd, label_nd, model_path, module_name, True, device, f.name ) oneflow_model_loss = get_loss( image_nd, label_nd, model_path, module_name, False, device, f.name ) if verbose: indes = [i for i in range(len(oneflow_model_loss))] plt.plot(indes, oneflow_model_loss, label="oneflow") plt.plot(indes, pytorch_model_loss, label="pytorch") plt.xlabel("iter - axis") # Set the y axis label of the current axis. plt.ylabel("loss - axis") # Set a title of the current axes. plt.title("compare ") # show a legend on the plot plt.legend() # Display a figure. plt.savefig("./loss_compare.png") plt.show() test_case.assertTrue( np.allclose(cos_sim(oneflow_model_loss, pytorch_model_loss), 1.0, 1e-1, 1e-1) ) ================================================ FILE: python/oneflow/test_utils/test_util.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
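# Sketch of the source-rewriting trick used in get_loss above: the OneFlow aliases are
# spliced in after the model file's own imports so that every later reference to
# `torch` / `nn` resolves to OneFlow, letting one model definition train under both
# frameworks. MODEL_SRC and the splice heuristic below are illustrative assumptions.
MODEL_SRC = """import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)
"""


def as_oneflow_source(src: str) -> str:
    lines = src.splitlines()
    # Place the shadowing imports right after the last original import line,
    # mirroring the line insertion performed in get_loss.
    cut = max(i for i, l in enumerate(lines) if l.startswith(("import ", "from "))) + 1
    shadow = ["import oneflow as torch", "import oneflow.nn as nn"]
    return "\n".join(lines[:cut] + shadow + lines[cut:])


# import_file(as_oneflow_source(MODEL_SRC)).TinyNet() would build the OneFlow variant.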
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import itertools import os from collections import OrderedDict from collections.abc import Iterable import numpy as np import oneflow as flow import oneflow.unittest def GenCartesianProduct(sets): assert isinstance(sets, Iterable) for set in sets: assert isinstance(set, Iterable) if os.getenv("ONEFLOW_TEST_CPU_ONLY"): if "cuda" in set: set.remove("cuda") return itertools.product(*sets) def GenArgList(arg_dict): assert isinstance(arg_dict, OrderedDict) assert all([isinstance(x, list) for x in arg_dict.values()]) sets = [arg_set for (_, arg_set) in arg_dict.items()] return GenCartesianProduct(sets) def GenArgDict(arg_dict): return [dict(zip(arg_dict.keys(), x)) for x in GenArgList(arg_dict)] class Args: def __init__(self, flow_args, tf_args=None): super().__init__() if tf_args is None: tf_args = flow_args self.flow_args = flow_args self.tf_args = tf_args def __str__(self): return "flow_args={} tf_args={}".format(self.flow_args, self.tf_args) def __repr__(self): return self.__str__() type_name_to_flow_type = { "bool": flow.bool, "float16": flow.float16, "float32": flow.float32, "double": flow.double, "float64": flow.double, "int8": flow.int8, "int32": flow.int32, "int64": flow.int64, "uint8": flow.uint8, "half": flow.half, "bfloat16": flow.bfloat16, "complex64": flow.complex64, "complex128": flow.complex128, } type_name_to_np_type = { "float16": np.float16, "float32": np.float32, "double": np.float64, "int8": np.int8, "int32": np.int32, "int64": np.int64, "uint8": np.uint8, "complex64": np.complex64, "complex128": np.complex128, } def FlattenArray(input_array): output_array = list() for x in np.nditer(input_array): output_array.append(x.tolist()) return output_array def Array2Numpy(input_array, target_shape): return np.array(input_array).reshape(target_shape, order="C") def Index2Coordinate(idx, tensor_shape): coordinate = [] tmp = idx for i in range(len(tensor_shape) - 1, -1, -1): axis_size = tensor_shape[i] coor = tmp % axis_size coordinate.insert(0, int(coor)) tmp = (tmp - coor) / axis_size return coordinate def Coordinate2Index(coordinate, tensor_shape): if len(coordinate) != len(tensor_shape): raise "wrong coordinate or shape" idx = 0 for (i, coor) in enumerate(coordinate): size_at_axis = coor for j in range(i + 1, len(tensor_shape)): size_at_axis *= tensor_shape[j] idx += size_at_axis return idx ================================================ FILE: python/oneflow/test_utils/throttle.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
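# Small sketch of how GenArgList / GenArgDict above are typically driven: an
# OrderedDict maps each argument name to a list of candidate values, and every
# element of the cartesian product becomes one test configuration. The argument
# names and values here are illustrative.
from collections import OrderedDict

from oneflow.test_utils.test_util import GenArgDict, GenArgList


def example_configs():
    arg_dict = OrderedDict()
    arg_dict["device"] = ["cpu", "cuda"]  # "cuda" is dropped when ONEFLOW_TEST_CPU_ONLY is set
    arg_dict["shape"] = [(2, 3), (4, 5, 6)]
    arg_dict["dtype"] = ["float32", "double"]
    # GenArgList yields tuples such as ("cpu", (2, 3), "float32");
    # GenArgDict yields the same combinations keyed by argument name.
    return list(GenArgList(arg_dict)), list(GenArgDict(arg_dict))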
""" import argparse import hashlib import subprocess import portalocker import os def parse_args(): parser = argparse.ArgumentParser( description="Control when the script runs through special variables." ) parser.add_argument( "--with-cuda", type=int, default=1, help="whether has cuda device." ) parser.add_argument("cmd", type=str, nargs="...", help="command to run") return parser.parse_args() def hash_cli2gpu(cmd: list): import pynvml pynvml.nvmlInit() slot = pynvml.nvmlDeviceGetCount() hash = hashlib.sha1(" ".join(cmd).encode("utf-8")).hexdigest() gpu_id = int(hash, 16) % slot return [gpu_id] def main(): args = parse_args() if args.with_cuda: cuda_visible_devices = [str(i) for i in hash_cli2gpu(args.cmd)] with portalocker.Lock( ".oneflow-throttle-gpu-" + "-".join(cuda_visible_devices) + ".lock", timeout=400, ): env = dict(os.environ, CUDA_VISIBLE_DEVICES=",".join(cuda_visible_devices)) return subprocess.call(args.cmd, env=env) else: return subprocess.call(args.cmd) if __name__ == "__main__": returncode = main() exit(returncode) ================================================ FILE: python/oneflow/unittest/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.unittest import ( TestCase, num_nodes_required, register_test_cases, skip_unless_1n1d, skip_unless_1n2d, skip_unless_1n4d, skip_unless_2n1d, skip_unless_2n2d, skip_unless_2n4d, ) from . import env from .mlir import MLIRTestCase from .dataset import dataset_dir ================================================ FILE: python/oneflow/unittest/dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os def dataset_dir(sub_dir=None): base_dir = os.getenv("ONEFLOW_TEST_DATASET_DIR") if base_dir == None: base_dir = "/dataset" if sub_dir == None: return base_dir else: return os.path.join(base_dir, sub_dir) ================================================ FILE: python/oneflow/unittest/env.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.unittest import ( device_num, eager_execution_enabled, has_node_list, has_world_size, node_list, node_size, typing_check_enabled, world_size, ) ================================================ FILE: python/oneflow/unittest/mlir.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import unittest class MLIRTestCase(unittest.TestCase): def tearDown(self): for key in os.environ.keys(): if key.startswith("ONEFLOW_MLIR"): os.environ.pop(key) ================================================ FILE: python/oneflow/utils/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.framework.config_util import api_load_library as load_library from oneflow.utils import tensor from oneflow.utils import global_view from oneflow.utils import model_zoo from . import checkpoint from . import hooks ================================================ FILE: python/oneflow/utils/checkpoint.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # This file is mostly copied from PyTorch import oneflow as flow from typing import List, Union def _checkpoint_without_reentrant(function, *args): """Checkpointining without re-entrant autograd Args: function: describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. 
For example, in LSTM, if user passes ``(activation, hidden)``, :attr:`function` should correctly use the first input as ``activation`` and the second input as ``hidden`` *args: Arguments to pass in to the given ``function``. """ storage: List[Union[flow.Tensor, None]] = [] counter = 0 def pack(x): nonlocal counter counter += 1 return counter - 1 # TODO(jianhao): support restoring rng state once we have flow.random.fork_rng def unpack(x): if len(storage) == 0: def inner_pack(inner): storage.append(inner) return None def inner_unpack(packed): raise RuntimeError( "You are calling backwards on a tensor that is never exposed. Please open an issue." ) with flow.enable_grad(): with flow.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): _unused = function(*args) return storage[x] with flow.autograd.graph.saved_tensors_hooks(pack, unpack): output = function(*args) return output def checkpoint(function, *args): r"""Checkpoint a model or part of the model Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does **not** save intermediate activations, and instead recomputes them in backward pass. It can be applied on any part of a model. Specifically, in the forward pass, :attr:`function` will run in :func:`flow.no_grad` manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the :attr:`function` parameter. In the backwards pass, the saved inputs and :attr:`function` is retrieved, and the forward pass is computed on :attr:`function` again, now tracking the intermediate activations, and then the gradients are calculated using these activation values. The output of :attr:`function` can contain non-Tensor values and gradient recording is only performed for the Tensor values. Note that if the output consists of nested structures (ex: custom objects, lists, dicts etc.) consisting of Tensors, these Tensors nested in custom structures will not be considered as part of autograd. .. warning:: If :attr:`function` invocation during backward does anything different than the one during forward, e.g., due to some global variable, the checkpointed version won't be equivalent, and unfortunately it can't be detected. .. warning:: Preserving rng states is not supported now, so that the behavior of checkpointing does not fully align with PyTorch. Args: function: describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in LSTM, if user passes ``(activation, hidden)``, :attr:`function` should correctly use the first input as ``activation`` and the second input as ``hidden`` args: tuple containing inputs to the :attr:`function` Returns: Output of running :attr:`function` on :attr:`*args` """ return _checkpoint_without_reentrant(function, *args) ================================================ FILE: python/oneflow/utils/data/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
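# Minimal sketch of the checkpoint() helper defined above: the wrapped block runs
# without storing its intermediate activations and is re-executed during backward.
# The module sizes and input shape are illustrative assumptions.
import oneflow as flow
import oneflow.nn as nn
from oneflow.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
head = nn.Linear(16, 1)

x = flow.randn(4, 16, requires_grad=True)
# Activations inside `block` are recomputed in the backward pass rather than kept.
y = head(checkpoint(block, x))
y.sum().backward()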
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.utils.data.sampler import ( Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, BatchSampler, ) from oneflow.utils.data.dataset import ( Dataset, IterableDataset, TensorDataset, ConcatDataset, Subset, random_split, ) from oneflow.utils.data.dataset import IterableDataset as IterDataPipe from oneflow.utils.data.dataloader import ( DataLoader, _DatasetKind, get_worker_info, ) from oneflow.utils.data.decorator import ( functional_datapipe, guaranteed_datapipes_determinism, non_deterministic, ) from oneflow.utils.data.distributed import DistributedSampler __all__ = [ "Sampler", "SequentialSampler", "RandomSampler", "SubsetRandomSampler", "BatchSampler", "Dataset", "IterableDataset", "TensorDataset", "ConcatDataset", "Subset", "random_split", "DataLoader", "_DatasetKind", "IterDataPipe", "functional_datapipe", "guaranteed_datapipes_determinism", "non_deterministic", "DistributedSampler", ] ================================================ FILE: python/oneflow/utils/data/_utils/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py. A lot of multiprocessing is used in data loading, which only supports running functions defined in global environment (py2 can't serialize static methods). Therefore, for code tidiness we put these functions into different files in this folder. """ import sys import atexit IS_WINDOWS = sys.platform == "win32" # pytorch's check interval is 5.0 seconds MP_STATUS_CHECK_INTERVAL = 10.0 r"""Interval (in seconds) to check status of processes to avoid hanging in multiprocessing data loading. This is mainly used in getting data from another process, in which case we need to periodically check whether the sender is alive to prevent hanging.""" python_exit_status = False r"""Whether Python is shutting down. This flag is guaranteed to be set before the Python core library resources are freed, but Python may already be exiting for some time when this is set. Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar hook in Python 3.7 multiprocessing library: https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 """ try: import numpy HAS_NUMPY = True except ModuleNotFoundError: HAS_NUMPY = False def _set_python_exit_flag(): global python_exit_status python_exit_status = True atexit.register(_set_python_exit_flag) from . 
import worker, signal_handling, collate, fetch, pin_memory ================================================ FILE: python/oneflow/utils/data/_utils/collate.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to collate samples fetched from dataset into Tensor(s). These **needs** to be in global scope since Py2 doesn't support serializing static methods. """ import re import collections import oneflow as flow string_classes = (str, bytes) np_str_obj_array_pattern = re.compile(r"[SaUO]") def default_convert(data): r"""Converts each NumPy array data field into a tensor""" elem_type = type(data) if isinstance(data, (flow.Tensor, flow._oneflow_internal.Tensor)): return data elif ( elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_" ): # array of string classes and object if ( elem_type.__name__ == "ndarray" and np_str_obj_array_pattern.search(data.dtype.str) is not None ): return data return flow.tensor(data) elif isinstance(data, collections.abc.Mapping): return {key: default_convert(data[key]) for key in data} elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple return elem_type(*(default_convert(d) for d in data)) elif isinstance(data, collections.abc.Sequence) and not isinstance( data, string_classes ): return [default_convert(d) for d in data] else: # NOTE: pytorch just return data here, and not raise any exception! 
raise TypeError(default_convert_err_msg_format.format(elem_type)) default_collate_err_msg_format = ( "default_collate: batch must contain tensors, numpy arrays, numbers, " "dicts or lists; found {}" ) default_convert_err_msg_format = ( "default_convert: batch must contain tensors, numpy arrays, numbers, " "dicts or lists; found {}" ) def default_collate(batch): r"""Puts each data field into a tensor with outer dimension batch size""" elem = batch[0] elem_type = type(elem) if isinstance(elem, (flow.Tensor, flow._oneflow_internal.Tensor)): # TODO: tensor.storage()._new_shared(numel) return flow._C.stack(batch, dim=0) elif ( elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_" ): if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: raise TypeError(default_collate_err_msg_format.format(elem.dtype)) return default_collate([flow.tensor(b) for b in batch]) elif elem.shape == (): # scalars return flow.tensor(batch) elif isinstance(elem, float): return flow.tensor(batch, dtype=flow.float64) elif isinstance(elem, int): return flow.tensor(batch) elif isinstance(elem, string_classes): return batch elif isinstance(elem, collections.abc.Mapping): return {key: default_collate([d[key] for d in batch]) for key in elem} elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple return elem_type(*(default_collate(samples) for samples in zip(*batch))) elif isinstance(elem, collections.abc.Sequence): # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) if not all(len(elem) == elem_size for elem in it): raise RuntimeError("each element in list of batch should be of equal size") transposed = zip(*batch) return [default_collate(samples) for samples in transposed] raise TypeError(default_collate_err_msg_format.format(elem_type)) ================================================ FILE: python/oneflow/utils/data/_utils/fetch.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """"Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset. This logic is shared in both single- and multi-processing data loading. 
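# Illustration of what default_collate above produces for a small batch of dict
# samples: tensors/arrays are stacked per key and Python numbers are tensorized,
# while the dict structure is preserved. The sample data is made up.
import numpy as np
import oneflow as flow
from oneflow.utils.data._utils.collate import default_collate

batch = [
    {"image": np.zeros((3, 4), dtype=np.float32), "label": 1},
    {"image": np.ones((3, 4), dtype=np.float32), "label": 0},
]
out = default_collate(batch)
# out["image"] is a (2, 3, 4) float tensor; out["label"] is a (2,) int64 tensor,
# because a batch of scalar ints falls into the flow.tensor(batch) branch.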
""" class _BaseDatasetFetcher(object): def __init__(self, dataset, auto_collation, collate_fn, drop_last): self.dataset = dataset self.auto_collation = auto_collation self.collate_fn = collate_fn self.drop_last = drop_last def fetch(self, possibly_batched_index): raise NotImplementedError() class _IterableDatasetFetcher(_BaseDatasetFetcher): def __init__(self, dataset, auto_collation, collate_fn, drop_last): super(_IterableDatasetFetcher, self).__init__( dataset, auto_collation, collate_fn, drop_last ) self.dataset_iter = iter(dataset) def fetch(self, possibly_batched_index): if self.auto_collation: data = [] for _ in possibly_batched_index: try: data.append(next(self.dataset_iter)) except StopIteration: break if len(data) == 0 or ( self.drop_last and len(data) < len(possibly_batched_index) ): raise StopIteration else: data = next(self.dataset_iter) return self.collate_fn(data) class _MapDatasetFetcher(_BaseDatasetFetcher): def __init__(self, dataset, auto_collation, collate_fn, drop_last): super(_MapDatasetFetcher, self).__init__( dataset, auto_collation, collate_fn, drop_last ) def fetch(self, possibly_batched_index): if self.auto_collation: data = [self.dataset[idx] for idx in possibly_batched_index] else: data = self.dataset[possibly_batched_index] return self.collate_fn(data) ================================================ FILE: python/oneflow/utils/data/_utils/pin_memory.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory. These **needs** to be in global scope since Py2 doesn't support serializing static methods. """ import oneflow as flow import collections.abc import queue from . import MP_STATUS_CHECK_INTERVAL from oneflow._utils import ExceptionWrapper container_abcs = collections.abc string_classes = (str, bytes) def _pin_memory_loop(in_queue, out_queue, device_id, done_event): # This setting is thread local, and prevents the copy in pin_memory from # consuming all CPU cores. 
flow.set_num_threads(1) # TODO: support flow.cuda.set_device # flow.cuda.set_device(device_id) while not done_event.is_set(): try: r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) except queue.Empty: continue idx, data = r if not done_event.is_set() and not isinstance(data, ExceptionWrapper): try: data = pin_memory(data) except Exception: data = ExceptionWrapper( where="in pin memory thread for device {}".format(device_id) ) r = (idx, data) while not done_event.is_set(): try: out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) break except queue.Full: continue del r # save memory def pin_memory(data): if isinstance(data, flow.Tensor): return data.pin_memory() elif isinstance(data, string_classes): return data elif isinstance(data, container_abcs.Mapping): return {k: pin_memory(sample) for k, sample in data.items()} elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple return type(data)(*(pin_memory(sample) for sample in data)) elif isinstance(data, container_abcs.Sequence): return [pin_memory(sample) for sample in data] elif hasattr(data, "pin_memory"): return data.pin_memory() else: return data ================================================ FILE: python/oneflow/utils/data/_utils/signal_handling.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r"""Signal handling for multiprocessing data loading. NOTE [ Signal handling in multiprocessing data loading ] In cases like DataLoader, if a worker process dies due to bus error/segfault or just hang, the main process will hang waiting for data. This is difficult to avoid on OneFlow side as it can be caused by limited shm, or other libraries users call in the workers. In this file and `DataLoader.cpp`, we make our best effort to provide some error message to users when such unfortunate events happen. When a _BaseDataLoaderIter starts worker processes, their pids are registered in a defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ] via `_set_worker_pids`. When an error happens in a worker process, the main process received a SIGCHLD, and Python will eventually call the handler registered below (in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails` call checks all registered worker pids and raise proper error message to prevent main process from hanging waiting for data from worker. Additionally, at the beginning of each worker's `_utils.worker._worker_loop`, `_set_worker_signal_handlers` is called to register critical signal handlers (e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error message to stderr before triggering the default handler. So a message will also be printed from the worker process when it is killed by such signals. See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of this signal handling design and other mechanism we implement to make our multiprocessing data loading robust to errors. """ import signal import threading from . 
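# Sketch of the recursive pin_memory() helper above: it walks mappings, namedtuples
# and sequences, replaces every tensor leaf with a page-locked copy, and returns
# non-tensor leaves (such as strings) unchanged. Requires a CUDA-enabled build;
# the nested sample structure is an illustrative assumption.
import oneflow as flow
from oneflow.utils.data._utils.pin_memory import pin_memory

sample = {"x": flow.randn(2, 3), "meta": "id-42", "ys": [flow.ones(4), flow.zeros(4)]}
pinned = pin_memory(sample)
# pinned["x"] and both entries of pinned["ys"] are pinned copies of the originals,
# while pinned["meta"] is still the plain Python string.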
import IS_WINDOWS # Some of the following imported functions are not used in this file, but are to # be used `_utils.signal_handling.XXXXX`. from oneflow._oneflow_internal import ( _set_worker_pids, _remove_worker_pids, _error_if_any_worker_fails, _set_worker_signal_handlers, ) _SIGCHLD_handler_set = False r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one handler needs to be set for all DataLoaders in a process.""" def _set_SIGCHLD_handler(): # Windows doesn't support SIGCHLD handler if IS_WINDOWS: return # can't set signal in child threads if not isinstance(threading.current_thread(), threading._MainThread): # type: ignore[attr-defined] return global _SIGCHLD_handler_set if _SIGCHLD_handler_set: return previous_handler = signal.getsignal(signal.SIGCHLD) if not callable(previous_handler): # This doesn't catch default handler, but SIGCHLD default handler is a # no-op. previous_handler = None def handler(signum, frame): # This following call uses `waitid` with WNOHANG from C side. Therefore, # Python can still get and update the process status successfully. _error_if_any_worker_fails() if previous_handler is not None: assert callable(previous_handler) previous_handler(signum, frame) signal.signal(signal.SIGCHLD, handler) _SIGCHLD_handler_set = True ================================================ FILE: python/oneflow/utils/data/_utils/worker.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. These **needs** to be in global scope since Py2 doesn't support serializing static methods. """ import random import os import sys import traceback import queue from dataclasses import dataclass from typing import Union from oneflow.multiprocessing import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] from oneflow.multiprocessing import unlink_all_shared_memory import signal import oneflow as flow from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS, HAS_NUMPY from oneflow._utils import ExceptionWrapper if IS_WINDOWS: import ctypes from ctypes.wintypes import DWORD, BOOL, HANDLE # On Windows, the parent ID of the worker process remains unchanged when the manager process # is gone, and the only way to check it through OS is to let the worker have a process handle # of the manager and ask if the process status has changed. 
class ManagerWatchdog(object): def __init__(self): self.manager_pid = os.getppid() # mypy cannot detect this code is windows only self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) self.kernel32.OpenProcess.restype = HANDLE self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) self.kernel32.WaitForSingleObject.restype = DWORD # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx SYNCHRONIZE = 0x00100000 self.manager_handle = self.kernel32.OpenProcess( SYNCHRONIZE, 0, self.manager_pid ) if not self.manager_handle: raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined] self.manager_dead = False def is_alive(self): if not self.manager_dead: # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx self.manager_dead = ( self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 ) return not self.manager_dead else: class ManagerWatchdog(object): # type: ignore[no-redef] def __init__(self): self.manager_pid = os.getppid() self.manager_dead = False def is_alive(self): if not self.manager_dead: self.manager_dead = os.getppid() != self.manager_pid return not self.manager_dead _worker_info = None class WorkerInfo(object): __initialized = False def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) self.__keys = tuple(kwargs.keys()) self.__initialized = True def __setattr__(self, key, val): if self.__initialized: raise RuntimeError( "Cannot assign attributes to {} objects".format(self.__class__.__name__) ) return super(WorkerInfo, self).__setattr__(key, val) def __repr__(self): items = [] for k in self.__keys: items.append("{}={}".format(k, getattr(self, k))) return "{}({})".format(self.__class__.__name__, ", ".join(items)) def get_worker_info(): r"""Returns the information about the current :class:`~flow.utils.data.DataLoader` iterator worker process. When called in a worker, this returns an object guaranteed to have the following attributes: * :attr:`id`: the current worker id. * :attr:`num_workers`: the total number of workers. * :attr:`seed`: the random seed set for the current worker. This value is determined by main process RNG and the worker id. See :class:`~flow.utils.data.DataLoader`'s documentation for more details. * :attr:`dataset`: the copy of the dataset object in **this** process. Note that this will be a different object in a different process than the one in the main process. When called in the main process, this returns ``None``. .. note:: When used in a :attr:`worker_init_fn` passed over to :class:`~flow.utils.data.DataLoader`, this method can be useful to set up each worker process differently, for instance, using ``worker_id`` to configure the ``dataset`` object to only read a specific fraction of a sharded dataset, or use ``seed`` to seed other libraries used in dataset code. """ return _worker_info r"""Dummy class used to signal the end of an IterableDataset""" @dataclass(frozen=True) class _IterableDatasetStopIteration(object): worker_id: int r"""Dummy class used to resume the fetching when worker reuse is enabled""" @dataclass(frozen=True) class _ResumeIteration(object): pass # The function `_generate_state` is adapted from `numpy.random.SeedSequence` # from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx # It's MIT licensed, here is the copyright: # Copyright (c) 2015 Melissa E. 
O'Neill # Copyright (c) 2019 NumPy Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # This function generates an array of int32 as the seed for # `numpy.random`, in order to prevent state collision due to same # seed and algorithm for `numpy.random` and `random` modules. # TODO: Implement `SeedSequence` like object for `flow.random` def _generate_state(base_seed, worker_id): INIT_A = 0x43B0D7E5 MULT_A = 0x931E8875 INIT_B = 0x8B51F9DD MULT_B = 0x58F38DED MIX_MULT_L = 0xCA01F9DD MIX_MULT_R = 0x4973F715 XSHIFT = 4 * 8 // 2 MASK32 = 0xFFFFFFFF entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] pool = [0] * 4 hash_const_A = INIT_A def hash(value): nonlocal hash_const_A value = (value ^ hash_const_A) & MASK32 hash_const_A = (hash_const_A * MULT_A) & MASK32 value = (value * hash_const_A) & MASK32 value = (value ^ (value >> XSHIFT)) & MASK32 return value def mix(x, y): result_x = (MIX_MULT_L * x) & MASK32 result_y = (MIX_MULT_R * y) & MASK32 result = (result_x - result_y) & MASK32 result = (result ^ (result >> XSHIFT)) & MASK32 return result # Add in the entropy to the pool. for i in range(len(pool)): pool[i] = hash(entropy[i]) # Mix all bits together so late bits can affect earlier bits. for i_src in range(len(pool)): for i_dst in range(len(pool)): if i_src != i_dst: pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) hash_const_B = INIT_B state = [] for i_dst in range(4): data_val = pool[i_dst] data_val = (data_val ^ hash_const_B) & MASK32 hash_const_B = (hash_const_B * MULT_B) & MASK32 data_val = (data_val * hash_const_B) & MASK32 data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 state.append(data_val) return state def _worker_loop( dataset_kind, dataset, index_queue, data_queue, done_event, auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id, num_workers, persistent_workers, ): # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the # logic of this function. try: def cleanup_shm_at_exit(num, frame): unlink_all_shared_memory() # Use os._exit() to handle the exit of the subprocess to avoid share memory leaks # caused by the subprocess continuing for a period of time after the parent process ends. os._exit(0) _prctl_pr_set_pdeathsig(signal.SIGINT) # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal # module's handlers are executed after Python returns from C low-level # handlers, likely when the same fatal signal had already happened # again. 
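# What _generate_state above provides, in miniature: each worker derives a distinct
# four-word seed from (base_seed, worker_id), so numpy never shares a random stream
# across DataLoader workers. The base_seed literal is illustrative.
import numpy as np

base_seed = 123456789
worker_seeds = [_generate_state(base_seed, worker_id) for worker_id in range(2)]
# Seeding numpy per worker, analogous to the np.random.seed(...) call in _worker_loop:
streams = [np.random.Generator(np.random.PCG64(s)) for s in worker_seeds]
# streams[0] and streams[1] now draw from independent sequences.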
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers signal_handling._set_worker_signal_handlers() signal.signal(signal.SIGTERM, cleanup_shm_at_exit) signal.signal(signal.SIGINT, cleanup_shm_at_exit) flow.set_num_threads(1) seed = base_seed + worker_id random.seed(seed) flow.manual_seed(seed) if HAS_NUMPY: np_seed = _generate_state(base_seed, worker_id) import numpy as np np.random.seed(np_seed) global _worker_info _worker_info = WorkerInfo( id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset ) from oneflow.utils.data import _DatasetKind init_exception = None try: if init_fn is not None: init_fn(worker_id) fetcher = _DatasetKind.create_fetcher( dataset_kind, dataset, auto_collation, collate_fn, drop_last ) except Exception: init_exception = ExceptionWrapper( where="in DataLoader worker process {}".format(worker_id) ) # When using Iterable mode, some worker can exit earlier than others due # to the IterableDataset behaving differently for different workers. # When such things happen, an `_IterableDatasetStopIteration` object is # sent over to the main process with the ID of this worker, so that the # main process won't send more tasks to this worker, and will send # `None` to this worker to properly exit it. # # Note that we cannot set `done_event` from a worker as it is shared # among all processes. Instead, we set the `iteration_end` flag to # signify that the iterator is exhausted. When either `done_event` or # `iteration_end` is set, we skip all processing step and just wait for # `None`. iteration_end = False watchdog = ManagerWatchdog() while watchdog.is_alive(): try: r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) except queue.Empty: continue if isinstance(r, _ResumeIteration): # Acknowledge the main process data_queue.put((r, None)) iteration_end = False # Recreate the fetcher for worker-reuse policy fetcher = _DatasetKind.create_fetcher( dataset_kind, dataset, auto_collation, collate_fn, drop_last ) continue elif r is None: # Received the final signal assert done_event.is_set() or iteration_end break elif done_event.is_set() or iteration_end: # `done_event` is set. But I haven't received the final signal # (None) yet. I will keep continuing until get it, and skip the # processing steps. continue idx, index = r data: Union[_IterableDatasetStopIteration, ExceptionWrapper] if init_exception is not None: data = init_exception init_exception = None else: try: data = fetcher.fetch(index) except Exception as e: if ( isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable ): data = _IterableDatasetStopIteration(worker_id) # Set `iteration_end` # (1) to save future `next(...)` calls, and # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. iteration_end = True else: # It is important that we don't store exc_info in a variable. # `ExceptionWrapper` does the correct thing. # See NOTE [ Python Traceback Reference Cycle Problem ] data = ExceptionWrapper( where="in DataLoader worker process {}".format(worker_id) ) data_queue.put((idx, data)) del data, idx, index, r # save memory except KeyboardInterrupt: # Main process will raise KeyboardInterrupt anyways. 
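# Hedged sketch of the per-worker configuration that get_worker_info (defined above)
# enables: a worker can query its id and worker count inside __iter__ and restrict an
# IterableDataset to its own shard. The dataset and striding scheme are illustrative.
from oneflow.utils.data import DataLoader, IterableDataset, get_worker_info


class RangeStream(IterableDataset):
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process loading: yield everything
            return iter(range(self.n))
        # multi-process loading: stride by worker id so the shards are disjoint
        return iter(range(info.id, self.n, info.num_workers))


# DataLoader(RangeStream(8), batch_size=2, num_workers=2) gives each worker half of
# the stream, so no sample is emitted twice.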
pass if done_event.is_set(): data_queue.cancel_join_thread() data_queue.close() # Python subprocess will be exited by os._exit(), which skips destructors of # C++ objects, so we should explicitly call unlink_all_shared_memory() here unlink_all_shared_memory() ================================================ FILE: python/oneflow/utils/data/dataloader.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings import os import threading import itertools import queue from typing import Any, Callable, TypeVar, Generic, Sequence, List, Optional import multiprocessing as python_multiprocessing import oneflow.multiprocessing as multiprocessing from oneflow._utils import ExceptionWrapper import oneflow as flow import numpy as np string_classes = (str, bytes) from . import ( IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler, Dataset, ) from . import _utils T_co = TypeVar("T_co", covariant=True) T = TypeVar("T") _worker_init_fn_t = Callable[[int], None] # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. # See https://github.com/python/mypy/issues/3737. _collate_fn_t = Callable[[List[T]], Any] # This function used to be defined in this file. However, it was moved to # _utils/collate.py. Although it is rather hard to access this from user land # (one has to explicitly directly `import flow.utils.data.dataloader`), there # probably is user code out there using it. This aliasing maintains BC in this # aspect. default_collate: _collate_fn_t = _utils.collate.default_collate get_worker_info = _utils.worker.get_worker_info class _DatasetKind(object): Map = 0 Iterable = 1 @staticmethod def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): if kind == _DatasetKind.Map: return _utils.fetch._MapDatasetFetcher( dataset, auto_collation, collate_fn, drop_last ) else: return _utils.fetch._IterableDatasetFetcher( dataset, auto_collation, collate_fn, drop_last ) class _InfiniteConstantSampler(Sampler): r"""Analogous to ``itertools.repeat(None, None)``. Used as sampler for :class:`~flow.utils.data.IterableDataset`. Args: data_source (Dataset): dataset to sample from """ def __init__(self): super(_InfiniteConstantSampler, self).__init__(None) def __iter__(self): while True: yield None class DataLoader(Generic[T_co]): r""" Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. The :class:`~oneflow.utils.data.DataLoader` supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning. See :py:mod:`oneflow.utils.data` documentation page for more details. 
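# Minimal end-to-end sketch of the DataLoader described here with a map-style
# TensorDataset: single-process loading (num_workers=0), shuffling and batching.
# Sizes and shapes are illustrative.
import oneflow as flow
from oneflow.utils.data import DataLoader, TensorDataset

features = flow.randn(10, 4)
labels = flow.arange(10)
loader = DataLoader(TensorDataset(features, labels), batch_size=4, shuffle=True)

for x, y in loader:
    # x: (4, 4) float batch, y: (4,) int64 batch; the final batch holds the
    # remaining 2 samples because drop_last defaults to False.
    pass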
In consideration of compatibility, the design of our dataloader is consistent with pytorch, ref: https://github.com/pytorch/pytorch/tree/v1.7.0 Args: dataset (Dataset): dataset from which to load the data. batch_size (int, optional): how many samples per batch to load (default: ``1``). shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch (default: ``False``). sampler (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__`` implemented. If specified, :attr:`shuffle` must not be specified. batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. num_workers (int, optional): how many subprocesses to use for data loading (default: ``0``). ``0`` means that the data will be loaded in the main process. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. pin_memory (bool, optional): If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, see the example below. (default: ``False``) drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: ``False``) timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: ``0``) worker_init_fn (callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``) prefetch_factor (int, optional, keyword-only arg): Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers samples prefetched across all workers. (default: ``2``) persistent_workers (bool, optional): If ``True``, the data loader will immediately initialize worker preocesses and not shutdown them after a dataset has been consumed once. This allows to maintain the workers `Dataset` instances alive. If you are using oneflow with RDMA support in distributed training, the ``persistent_workers`` must be ``True`` otherwise will encounter segmentation fault. (default: ``False``) .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an unpicklable object, e.g., a lambda function. .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. When :attr:`dataset` is an :class:`~flow.utils.data.IterableDataset`, it instead returns an estimate based on ``len(dataset) / batch_size``, with proper rounding depending on :attr:`drop_last`, regardless of multi-process loading configurations. This represents the best guess OneFlow can make because OneFlow trusts user :attr:`dataset` code in correctly handling multi-process loading to avoid duplicate data. 
However, if sharding results in multiple workers having incomplete last batches, this estimate can still be inaccurate, because (1) an otherwise complete batch can be broken into multiple ones and (2) more than one batch worth of samples can be dropped when :attr:`drop_last` is set. Unfortunately, OneFlow can not detect such cases in general. """ dataset: Dataset[T_co] batch_size: Optional[int] num_workers: int pin_memory: bool drop_last: bool timeout: float sampler: Sampler prefetch_factor: int _iterator: Optional["_BaseDataLoaderIter"] __initialized = False def __init__( self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, shuffle: bool = False, sampler: Optional[Sampler[int]] = None, batch_sampler: Optional[Sampler[Sequence[int]]] = None, num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, multiprocessing_context=None, generator=flow.Generator("cpu"), *, prefetch_factor: int = 2, persistent_workers: bool = False ): if num_workers < 0: raise ValueError( "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing." ) else: self.num_workers = num_workers if timeout < 0: raise ValueError("timeout option should be non-negative") if self.num_workers == 0 and prefetch_factor != 2: raise ValueError( "prefetch_factor option could only be specified in multiprocessing." "let num_workers > 0 to enable multiprocessing." ) assert prefetch_factor > 0 if persistent_workers and num_workers == 0: raise ValueError("persistent_workers option needs num_workers > 0") self.dataset = dataset self.prefetch_factor = prefetch_factor self.pin_memory = pin_memory self.timeout = timeout self.worker_init_fn = worker_init_fn self.multiprocessing_context = multiprocessing_context # Arg-check dataset related before checking samplers because we want to # tell users that iterable-style datasets are incompatible with custom # samplers first, so that they don't learn that this combo doesn't work # after spending time fixing the custom sampler errors. if isinstance(dataset, IterableDataset): self._dataset_kind = _DatasetKind.Iterable # NOTE [ Custom Samplers and IterableDataset ] # # `IterableDataset` does not support custom `batch_sampler` or # `sampler` since the key is irrelevant (unless we support # generator-style dataset one day...). # # For `sampler`, we always create a dummy sampler. This is an # infinite sampler even when the dataset may have an implemented # finite `__len__` because in multi-process data loading, naive # settings will return duplicated data (which may be desired), and # thus using a sampler with length matching that of dataset will # cause data lost (you may have duplicates of the first couple # batches, but never see anything afterwards). Therefore, # `Iterabledataset` always uses an infinite sampler, an instance of # `_InfiniteConstantSampler` defined above. # # A custom `batch_sampler` essentially only controls the batch size. # However, it is unclear how useful it would be since an iterable-style # dataset can handle that within itself. Moreover, it is pointless # in multi-process data loading as the assignment order of batches # to workers is an implementation detail so users can not control # how to batchify each worker's iterable. Thus, we disable this # option. If this turns out to be useful in future, we can re-enable # this, and support custom samplers that specify the assignments to # specific workers. 
if shuffle is not False: raise ValueError( "DataLoader with IterableDataset: expected unspecified " "shuffle option, but got shuffle={}".format(shuffle) ) elif sampler is not None: # See NOTE [ Custom Samplers and IterableDataset ] raise ValueError( "DataLoader with IterableDataset: expected unspecified " "sampler option, but got sampler={}".format(sampler) ) elif batch_sampler is not None: # See NOTE [ Custom Samplers and IterableDataset ] raise ValueError( "DataLoader with IterableDataset: expected unspecified " "batch_sampler option, but got batch_sampler={}".format( batch_sampler ) ) else: self._dataset_kind = _DatasetKind.Map if sampler is not None and shuffle: raise ValueError("sampler option is mutually exclusive with " "shuffle") if batch_sampler is not None: # auto_collation with custom batch_sampler if batch_size != 1 or shuffle or sampler is not None or drop_last: raise ValueError( "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last" ) batch_size = None drop_last = False elif batch_size is None: # no auto_collation if drop_last: raise ValueError( "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last" ) if sampler is None: # give default samplers if self._dataset_kind == _DatasetKind.Iterable: # See NOTE [ Custom Samplers and IterableDataset ] sampler = _InfiniteConstantSampler() else: # map-style if shuffle: # Cannot statically verify that dataset is Sized # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] sampler = RandomSampler(dataset, generator=generator) # type: ignore else: sampler = SequentialSampler(dataset) if batch_size is not None and batch_sampler is None: # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.batch_size = batch_size self.drop_last = drop_last self.sampler = sampler self.batch_sampler = batch_sampler self.generator = generator if collate_fn is None: if self._auto_collation: collate_fn = _utils.collate.default_collate else: collate_fn = _utils.collate.default_convert self.collate_fn = collate_fn self.persistent_workers = persistent_workers self.__initialized = True self._IterableDataset_len_called = ( None # See NOTE [ IterableDataset and __len__ ] ) self._iterator = self._get_iterator() if self.persistent_workers else None def _get_iterator(self) -> "_BaseDataLoaderIter": if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: self.check_worker_number_rationality() return _MultiProcessingDataLoaderIter(self) def __setattr__(self, attr, val): if self.__initialized and attr in ( "batch_size", "batch_sampler", "sampler", "drop_last", "dataset", "persistent_workers", ): raise ValueError( "{} attribute should not be set after {} is " "initialized".format(attr, self.__class__.__name__) ) super(DataLoader, self).__setattr__(attr, val) # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up # since '_BaseDataLoaderIter' references 'DataLoader'. 
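    # A minimal sketch of the wiring set up in `__init__` above (names are
    # illustrative; `MyMapDataset` is a hypothetical map-style dataset): passing
    # only `batch_size` is equivalent to wrapping the default sampler in a
    # `BatchSampler` explicitly.
    #
    #   >>> ds = MyMapDataset()
    #   >>> a = DataLoader(ds, batch_size=4, drop_last=True)
    #   >>> b = DataLoader(ds, batch_sampler=BatchSampler(SequentialSampler(ds), 4, drop_last=True))
    #   >>> # both yield the same index batches; `a` simply builds the
    #   >>> # SequentialSampler/BatchSampler pair internally.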
def __iter__(self) -> "_BaseDataLoaderIter": # When using a single worker the returned iterator should be # created everytime to avoid reseting its state # However, in the case of a multiple workers iterator # the iterator is only created once in the lifetime of the # DataLoader object so that workers can be reused if self.persistent_workers and self.num_workers > 0: if self._iterator is None: self._iterator = self._get_iterator() elif not self._iterator._status_reset: self._iterator._reset(self) return self._iterator else: return self._get_iterator() @property def _auto_collation(self): return self.batch_sampler is not None @property def _index_sampler(self): # The actual sampler used for generating indices for `_DatasetFetcher` # (see _utils/fetch.py) to read data at each time. This would be # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. # We can't change `.sampler` and `.batch_sampler` attributes for BC # reasons. if self._auto_collation: return self.batch_sampler else: return self.sampler def __len__(self) -> int: if self._dataset_kind == _DatasetKind.Iterable: # NOTE [ IterableDataset and __len__ ] # # For `IterableDataset`, `__len__` could be inaccurate when one naively # does multi-processing data loading, since the samples will be duplicated. # However, no real use case should be actually using that behavior, so # it should count as a user error. We should generally trust user # code to do the proper thing (e.g., configure each replica differently # in `__iter__`), and give us the correct `__len__` if they choose to # implement it (this will still throw if the dataset does not implement # a `__len__`). # # To provide a further warning, we track if `__len__` was called on the # `DataLoader`, save the returned value in `self._len_called`, and warn # if the iterator ends up yielding more than this number of samples. # Cannot statically verify that dataset is Sized length = self._IterableDataset_len_called = len(self.dataset) # type: ignore if ( self.batch_size is not None ): # IterableDataset doesn't allow custom sampler or batch_sampler from math import ceil if self.drop_last: length = length // self.batch_size else: length = ceil(length / self.batch_size) return length else: return len(self._index_sampler) def check_worker_number_rationality(self): def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): suggested_max_worker_msg = ( ( ( "Our suggested max number of worker in current system is {}{}, which is smaller " "than what this DataLoader is going to create." ).format( num_worker_suggest, ( "" if cpuset_checked else " (`cpuset` is not taken into account)" ), ) ) if num_worker_suggest is not None else ( "DataLoader is not able to compute a suggested max number of worker in current system." ) ) warn_msg = ( "This DataLoader will create {} worker processes in total. {} " "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " "lower the worker number to avoid potential slowness/freeze if necessary." 
).format(num_worker_created, suggested_max_worker_msg) return warn_msg if not self.num_workers or self.num_workers == 0: return # try to compute a suggested max number of worker based on system's resource max_num_worker_suggest = None cpuset_checked = False if hasattr(os, "sched_getaffinity"): try: max_num_worker_suggest = len(os.sched_getaffinity(0)) cpuset_checked = True except Exception: pass if max_num_worker_suggest is None: # os.cpu_count() could return Optional[int] # get cpu count first and check None in order to satify mypy check cpu_count = os.cpu_count() if cpu_count is not None: max_num_worker_suggest = cpu_count if max_num_worker_suggest is None: warnings.warn( _create_warning_msg( max_num_worker_suggest, self.num_workers, cpuset_checked ) ) return if self.num_workers > max_num_worker_suggest: warnings.warn( _create_warning_msg( max_num_worker_suggest, self.num_workers, cpuset_checked ) ) class _BaseDataLoaderIter(object): def __init__(self, loader: DataLoader) -> None: self._dataset = loader.dataset self._dataset_kind = loader._dataset_kind self._IterableDataset_len_called = loader._IterableDataset_len_called self._auto_collation = loader._auto_collation self._drop_last = loader.drop_last self._index_sampler = loader._index_sampler self._num_workers = loader.num_workers self._prefetch_factor = loader.prefetch_factor self._pin_memory = loader.pin_memory and flow.cuda.is_available() self._timeout = loader.timeout self._collate_fn = loader.collate_fn self._sampler_iter = iter(self._index_sampler) # self._base_seed = flow.empty((), dtype=flow.int64).random_(generator=loader.generator).item() self._base_seed = flow.randint( 0, np.iinfo(np.int64).max, (), generator=loader.generator ).item() self._persistent_workers = loader.persistent_workers self._num_yielded = 0 self._profile_name = "enumerate(DataLoader)#{}.__next__".format( self.__class__.__name__ ) self._status_reset = True def __iter__(self) -> "_BaseDataLoaderIter": return self def _reset(self, loader, first_iter=False): self._status_reset = True self._sampler_iter = iter(self._index_sampler) self._num_yielded = 0 self._IterableDataset_len_called = loader._IterableDataset_len_called def _next_index(self): return next(self._sampler_iter) # may raise StopIteration def _next_data(self): raise NotImplementedError def __next__(self) -> Any: self._status_reset = False if self._sampler_iter is None: self._reset() data = self._next_data() self._num_yielded += 1 if ( self._dataset_kind == _DatasetKind.Iterable and self._IterableDataset_len_called is not None and self._num_yielded > self._IterableDataset_len_called ): warn_msg = ( "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " "samples have been fetched. " ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded) if self._num_workers > 1: warn_msg += "Multiprocessing dataloader is not support yet!" 
warnings.warn(warn_msg) return data next = __next__ def __len__(self) -> int: return len(self._index_sampler) def __getstate__(self): raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SingleProcessDataLoaderIter, self).__init__(loader) assert self._timeout == 0 assert 0 <= self._num_workers <= 1 self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last, ) def _next_data(self): index = self._next_index() # may raise StopIteration data = self._dataset_fetcher.fetch(index) # may raise StopIteration if self._pin_memory: data = _utils.pin_memory.pin_memory(data) return data class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" # NOTE [ Data Loader Multiprocessing Shutdown Logic ] # # Preliminary: # # Our data model looks like this (queues are indicated with curly brackets): # # main process || # | || # {index_queue} || # | || # worker processes || DATA # | || # {worker_result_queue} || FLOW # | || # pin_memory_thread of main process || DIRECTION # | || # {data_queue} || # | || # data output \/ # # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if # `pin_memory=False`. # # # Terminating multiprocessing logic requires very careful design. In # particular, we need to make sure that # # 1. The iterator gracefully exits the workers when its last reference is # gone or it is depleted. # # In this case, the workers should be gracefully exited because the # main process may still need to continue to run, and we want cleaning # up code in the workers to be executed (e.g., releasing GPU memory). # Naturally, we implement the shutdown logic in `__del__` of # DataLoaderIterator. # # We delay the discussion on the logic in this case until later. # # 2. The iterator exits the workers when the loader process and/or worker # processes exits normally or with error. # # We set all workers and `pin_memory_thread` to have `daemon=True`. # # You may ask, why can't we make the workers non-daemonic, and # gracefully exit using the same logic as we have in `__del__` when the # iterator gets deleted (see 1 above)? # # First of all, `__del__` is **not** guaranteed to be called when # interpreter exits. Even if it is called, by the time it executes, # many Python core library resources may alreay be freed, and even # simple things like acquiring an internal lock of a queue may hang. # Therefore, in this case, we actually need to prevent `__del__` from # being executed, and rely on the automatic termination of daemonic # children. # # Thus, we register an `atexit` hook that sets a global flag # `_utils.python_exit_status`. 
Since `atexit` hooks are executed in the # reverse order of registration, we are guaranteed that this flag is # set before library resources we use are freed (which, at least in # CPython, is done via an `atexit` handler defined in # `multiprocessing/util.py` # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 # registered when an object requiring this mechanism is first # created, e.g., `mp.Queue` # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 # ) # # So in `__del__`, we check if `_utils.python_exit_status` is set or # `None` (freed), and perform no-op if so. # # However, simply letting library clean-up codes run can also be bad, # because such codes (i.e., `multiprocessing.util._exit_function()`) # include join putting threads for `mp.Queue`, which can be blocking. # Hence, the main process putting threads are called with # `cancel_join_thread` at creation. See later section # [ 3b. A process won't hang when putting into a queue; ] # for more details. # # Here are two example cases where library clean-up codes can run # before `__del__` is called: # # 1. If we hold onto a reference to the iterator, it more often # than not tries to do `multiprocessing` library cleaning before # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) # and thus prevents our cleaning-up code to run first. # # 2. A similar issue araises when a `DataLoader` is used in a subprocess. # When a process ends, it shuts the all its daemonic children # down with a SIGTERM (instead of joining them without a timeout). # Simiarly for threads, but by a different mechanism. This fact, # together with a few implementation details of multiprocessing, forces # us to make workers daemonic. All of our problems arise when a # DataLoader is used in a subprocess, and are caused by multiprocessing # code which looks more or less like this: # # try: # your_function_using_a_dataloader() # finally: # multiprocessing.util._exit_function() # # The joining/termination mentioned above happens inside # `_exit_function()`. Now, if `your_function_using_a_dataloader()` # throws, the stack trace stored in the exception will prevent the # frame which uses `DataLoaderIter` to be freed. If the frame has any # reference to the `DataLoaderIter` (e.g., in a method of the iter), # its `__del__`, which starts the shutdown procedure, will not be # called. That, in turn, means that workers aren't notified. Attempting # to join in `_exit_function` will then result in a hang. # # For context, `_exit_function` is also registered as an `atexit` call. # So it is unclear to me (@ssnl) why this is needed in a finally block. # The code dates back to 2008 and there is no comment on the original # PEP 371 or patch https://bugs.python.org/issue3050 (containing both # the finally block and the `atexit` registration) that explains this. # # # Finally, another choice is to just shutdown workers with logic in 1 # above whenever we see an error in `next`. This isn't ideal because # a. It prevents users from using try-catch to resume data loading. # b. It doesn't prevent hanging if users have references to the # iterator. # # 3. All processes exit if any of them die unexpectedly by fatal signals. # # As shown above, the workers are set as daemonic children of the main # process. 
However, automatic cleaning-up of such child processes only # happens if the parent process exits gracefully (e.g., not via fatal # signals like SIGKILL). So we must ensure that each process will exit # even the process that should send/receive data to/from it were # killed, i.e., # # a. A process won't hang when getting from a queue. # # Even with carefully designed data dependencies (i.e., a `put()` # always corresponding to a `get()`), hanging on `get()` can still # happen when data in queue is corrupted (e.g., due to # `cancel_join_thread` or unexpected exit). # # For child exit, we set a timeout whenever we try to get data # from `data_queue`, and check the workers' status on each timeout # and error. # See `_DataLoaderiter._get_batch()` and # `_DataLoaderiter._try_get_data()` for details. # # Additionally, for child exit on non-Windows platforms, we also # register a SIGCHLD handler (which is supported on Windows) on # the main process, which checks if any of the workers fail in the # (Python) handler. This is more efficient and faster in detecting # worker failures, compared to only using the above mechanism. # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. # # For `.get()` calls where the sender(s) is not the workers, we # guard them with timeouts, and check the status of the sender # when timeout happens: # + in the workers, the `_utils.worker.ManagerWatchdog` class # checks the status of the main process. # + if `pin_memory=True`, when getting from `pin_memory_thread`, # check `pin_memory_thread` status periodically until `.get()` # returns or see that `pin_memory_thread` died. # # b. A process won't hang when putting into a queue; # # We use `mp.Queue` which has a separate background thread to put # objects from an unbounded buffer array. The background thread is # daemonic and usually automatically joined when the process # *exits*. # # In case that the receiver has ended abruptly while # reading from the pipe, the join will hang forever. The usual # solution for this in Python is calling `q.cancel_join_thread`, # which prevents automatically joining it when finalizing # (exiting). # # Nonetheless, `cancel_join_thread` must only be called when the # queue is **not** going to be read from or write into by another # process, because it may hold onto a lock or leave corrupted data # in the queue, leading other readers/writers to hang. # # Hence, # + For worker processes, we only do so (for their output # queues, i.e., `worker_result_queue`) before exiting. # + For `pin_memory_thread`, its output queue `data_queue` is a # `queue.Queue` that does blocking `put` if the queue is full. # So there is no above problem, but as a result, in # `_pin_memory_loop`, we do need to wrap the `put` in a loop # that breaks not only upon success, but also when the main # process stops reading, i.e., is shutting down. # + For loader process, we `cancel_join_thread()` for all # `_index_queues` because the whole purpose of workers and # `pin_memory_thread` is to serve the loader process. If # loader process is already exiting, we don't really care if # the queues are corrupted. # # # Now let's get back to 1: # how we gracefully exit the workers when the last reference to the # iterator is gone. # # To achieve this, we implement the following logic along with the design # choices mentioned above: # # `workers_done_event`: # A `multiprocessing.Event` shared among the main process and all worker # processes. This is used to signal the workers that the iterator is # shutting down. 
After it is set, they will not send processed data to # queues anymore, and only wait for the final `None` before exiting. # `done_event` isn't strictly needed. I.e., we can just check for `None` # from the input queue, but it allows us to skip wasting resources # processing data if we are already shutting down. # # `pin_memory_thread_done_event`: # A `threading.Event` for a similar purpose to that of # `workers_done_event`, but is for the `pin_memory_thread`. The reason # that separate events are needed is that `pin_memory_thread` reads from # the output queue of the workers. But the workers, upon seeing that # `workers_done_event` is set, only wants to see the final `None`, and is # not required to flush all data in the output queue (e.g., it may call # `cancel_join_thread` on that queue if its `IterableDataset` iterator # happens to exhaust coincidentally, which is out of the control of the # main process). Thus, since we will exit `pin_memory_thread` before the # workers (see below), two separete events are used. # # NOTE: In short, the protocol is that the main process will set these # `done_event`s and then the corresponding processes/threads a `None`, # and that they may exit at any time after receiving the `None`. # # NOTE: Using `None` as the final signal is valid, since normal data will # always be a 2-tuple with the 1st element being the index of the data # transferred (different from dataset index/key), and the 2nd being # either the dataset key or the data sample (depending on which part # of the data model the queue is at). # # [ worker processes ] # While loader process is alive: # Get from `index_queue`. # If get anything else, # Check `workers_done_event`. # If set, continue to next iteration # i.e., keep getting until see the `None`, then exit. # Otherwise, process data: # If is fetching from an `IterableDataset` and the iterator # is exhausted, send an `_IterableDatasetStopIteration` # object to signal iteration end. The main process, upon # receiving such an object, will send `None` to this # worker and not use the corresponding `index_queue` # anymore. # If timed out, # No matter `workers_done_event` is set (still need to see `None`) # or not, must continue to next iteration. # (outside loop) # If `workers_done_event` is set, (this can be False with `IterableDataset`) # `data_queue.cancel_join_thread()`. (Everything is ending here: # main process won't read from it; # other workers will also call # `cancel_join_thread`.) # # [ pin_memory_thread ] # # No need to check main thread. If this thread is alive, the main loader # # thread must be alive, because this thread is set as daemonic. # While `pin_memory_thread_done_event` is not set: # Get from `index_queue`. # If timed out, continue to get in the next iteration. # Otherwise, process data. # While `pin_memory_thread_done_event` is not set: # Put processed data to `data_queue` (a `queue.Queue` with blocking put) # If timed out, continue to put in the next iteration. # Otherwise, break, i.e., continuing to the out loop. # # NOTE: we don't check the status of the main thread because # 1. if the process is killed by fatal signal, `pin_memory_thread` # ends. # 2. in other cases, either the cleaning-up in __del__ or the # automatic exit of daemonic thread will take care of it. # This won't busy-wait either because `.get(timeout)` does not # busy-wait. # # [ main process ] # In the DataLoader Iter's `__del__` # b. Exit `pin_memory_thread` # i. Set `pin_memory_thread_done_event`. # ii Put `None` in `worker_result_queue`. 
# iii. Join the `pin_memory_thread`. # iv. `worker_result_queue.cancel_join_thread()`. # # c. Exit the workers. # i. Set `workers_done_event`. # ii. Put `None` in each worker's `index_queue`. # iii. Join the workers. # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. # # NOTE: (c) is better placed after (b) because it may leave corrupted # data in `worker_result_queue`, which `pin_memory_thread` # reads from, in which case the `pin_memory_thread` can only # happen at timeing out, which is slow. Nonetheless, same thing # happens if a worker is killed by signal at unfortunate times, # but in other cases, we are better off having a non-corrupted # `worker_result_queue` for `pin_memory_thread`. # # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) # can be omitted # # NB: `done_event`s isn't strictly needed. E.g., we can just check for # `None` from `index_queue`, but it allows us to skip wasting resources # processing indices already in `index_queue` if we are already shutting # down. def __init__(self, loader): super(_MultiProcessingDataLoaderIter, self).__init__(loader) assert not flow.env.rdma_is_initialized(), ( "RDMA is initialized! Could not create _MultiProcessingDataLoaderIter any more. " "Please make sure Dataloader is created before invoking oneflow.env.init_rdma(). " "If this condition is met, you can pass the arg persistent_workers=True in " "Dataloader to avoid this error!" ) assert self._num_workers > 0 assert self._prefetch_factor > 0 if loader.multiprocessing_context is None: multiprocessing_context = multiprocessing else: multiprocessing_context = loader.multiprocessing_context self._worker_init_fn = loader.worker_init_fn self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) # No certainty which module multiprocessing_context is self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] self._worker_pids_set = False self._shutdown = False self._workers_done_event = multiprocessing_context.Event() self._index_queues = [] self._workers = [] for i in range(self._num_workers): # No certainty which module multiprocessing_context is index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] # Need to `cancel_join_thread` here! # See sections (2) and (3b) above. index_queue.cancel_join_thread() w = multiprocessing_context.Process( target=_utils.worker._worker_loop, args=( self._dataset_kind, self._dataset, index_queue, self._worker_result_queue, self._workers_done_event, self._auto_collation, self._collate_fn, self._drop_last, self._base_seed, self._worker_init_fn, i, self._num_workers, self._persistent_workers, ), ) w.daemon = True # NB: Process.start() actually take some time as it needs to # start a process and pass the arguments over via a pipe. # Therefore, we only add a worker to self._workers list after # it started, so that we do not call .join() if program dies # before it starts, and __del__ tries to join but will get: # AssertionError: can only join a started process. 
w.start() self._index_queues.append(index_queue) self._workers.append(w) if self._pin_memory: self._pin_memory_thread_done_event = threading.Event() # Queue is not type-annotated self._data_queue = queue.Queue() # type: ignore[var-annotated] pin_memory_thread = threading.Thread( target=_utils.pin_memory._pin_memory_loop, args=( self._worker_result_queue, self._data_queue, flow.cuda.current_device(), self._pin_memory_thread_done_event, ), ) pin_memory_thread.daemon = True pin_memory_thread.start() # Similar to workers (see comment above), we only register # pin_memory_thread once it is started. self._pin_memory_thread = pin_memory_thread else: self._data_queue = self._worker_result_queue # .pid can be None only before process is spawned (not the case, so ignore) _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] _utils.signal_handling._set_SIGCHLD_handler() self._worker_pids_set = True self._reset(loader, first_iter=True) def _reset(self, loader, first_iter=False): super()._reset(loader, first_iter) self._send_idx = 0 # idx of the next task to be sent to workers self._rcvd_idx = 0 # idx of the next task to be returned in __next__ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). # map: task idx => - (worker_id,) if data isn't fetched (outstanding) # \ (worker_id, data) if data is already fetched (out-of-order) self._task_info = {} self._tasks_outstanding = ( 0 # always equal to count(v for v in task_info.values() if len(v) == 1) ) # A list of booleans representing whether each worker still has work to # do, i.e., not having exhausted its iterable dataset object. It always # contains all `True`s if not using an iterable-style dataset # (i.e., if kind != Iterable). # Not that this indicates that a worker still has work to do *for this epoch*. # It does not mean that a worker is dead. In case of `_persistent_workers`, # the worker will be reset to available in the next epoch. self._workers_status = [True for i in range(self._num_workers)] # We resume the prefetching in case it was enabled if not first_iter: for idx in range(self._num_workers): self._index_queues[idx].put(_utils.worker._ResumeIteration()) resume_iteration_cnt = self._num_workers while resume_iteration_cnt > 0: return_idx, return_data = self._get_data() if isinstance(return_idx, _utils.worker._ResumeIteration): assert return_data is None resume_iteration_cnt -= 1 # prime the prefetch loop for _ in range(self._prefetch_factor * self._num_workers): self._try_put_index() def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): # Tries to fetch data from `self._data_queue` once for a given timeout. # This can also be used as inner loop of fetching without timeout, with # the sender status as the loop condition. # # This raises a `RuntimeError` if any worker died expectedly. This error # can come from either the SIGCHLD handler in `_utils/signal_handling.py` # (only for non-Windows platforms), or the manual check below on errors # and timeouts. # # Returns a 2-tuple: # (bool: whether successfully get data, any: data if successful else None) try: data = self._data_queue.get(timeout=timeout) return (True, data) except Exception as e: # At timeout and error, we manually check whether any worker has # failed. Note that this is the only mechanism for Windows to detect # worker failures. 
failed_workers = [] for worker_id, w in enumerate(self._workers): if self._workers_status[worker_id] and not w.is_alive(): failed_workers.append(w) self._mark_worker_as_unavailable(worker_id) if len(failed_workers) > 0: pids_str = ", ".join(str(w.pid) for w in failed_workers) raise RuntimeError( "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str) ) from e if isinstance(e, queue.Empty): return (False, None) import tempfile import errno try: # Raise an exception if we are this close to the FDs limit. # Apparently, trying to open only one file is not a sufficient # test. # See NOTE [ DataLoader on Linux and open files limit ] fds_limit_margin = 10 fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] except OSError as e: if e.errno == errno.EMFILE: raise RuntimeError( "Too many open files. Communication with the" " workers is no longer possible. Please increase the" " limit using `ulimit -n` in the shell or change the" " sharing strategy by calling" " `flow.multiprocessing.set_sharing_strategy('file_system')`" " at the beginning of your code" ) from None raise def _get_data(self): # Fetches data from `self._data_queue`. # # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` # in a loop. This is the only mechanism to detect worker failures for # Windows. For other platforms, a SIGCHLD handler is also used for # worker failure detection. # # If `pin_memory=True`, we also need check if `pin_memory_thread` had # died at timeouts. if self._timeout > 0: success, data = self._try_get_data(self._timeout) if success: return data else: raise RuntimeError( "DataLoader timed out after {} seconds".format(self._timeout) ) elif self._pin_memory: while self._pin_memory_thread.is_alive(): success, data = self._try_get_data() if success: return data else: # while condition is false, i.e., pin_memory_thread died. raise RuntimeError("Pin memory thread exited unexpectedly") # In this case, `self._data_queue` is a `queue.Queue`,. But we don't # need to call `.task_done()` because we don't use `.join()`. else: while True: success, data = self._try_get_data() if success: return data def _next_data(self): while True: # If the worker responsible for `self._rcvd_idx` has already ended # and was unable to fulfill this task (due to exhausting an `IterableDataset`), # we try to advance `self._rcvd_idx` to find the next valid index. # # This part needs to run in the loop because both the `self._get_data()` # call and `_IterableDatasetStopIteration` check below can mark # extra worker(s) as dead. 
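            # For illustration (assuming num_workers=2 and prefetch_factor=2):
            # after `_reset`, tasks 0..3 are primed round-robin, so
            # `self._task_info` is {0: (0,), 1: (1,), 2: (0,), 3: (1,)} with
            # `_send_idx == 4` and `_rcvd_idx == 0`. If worker 1 finishes task 1
            # before task 0 arrives, task 1 is parked as
            # `self._task_info[1] == (1, data)` until `_rcvd_idx` reaches it.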
while self._rcvd_idx < self._send_idx: info = self._task_info[self._rcvd_idx] worker_id = info[0] if ( len(info) == 2 or self._workers_status[worker_id] ): # has data or is still active break del self._task_info[self._rcvd_idx] self._rcvd_idx += 1 else: # no valid `self._rcvd_idx` is found (i.e., didn't break) if not self._persistent_workers: self._shutdown_workers() raise StopIteration # Now `self._rcvd_idx` is the batch index we want to fetch # Check if the next sample has already been generated if len(self._task_info[self._rcvd_idx]) == 2: data = self._task_info.pop(self._rcvd_idx)[1] return self._process_data(data) assert not self._shutdown and self._tasks_outstanding > 0 idx, data = self._get_data() self._tasks_outstanding -= 1 if self._dataset_kind == _DatasetKind.Iterable: # Check for _IterableDatasetStopIteration if isinstance(data, _utils.worker._IterableDatasetStopIteration): if self._persistent_workers: self._workers_status[data.worker_id] = False else: self._mark_worker_as_unavailable(data.worker_id) self._try_put_index() continue if idx != self._rcvd_idx: # store out-of-order samples self._task_info[idx] += (data,) else: del self._task_info[idx] return self._process_data(data) def _try_put_index(self): assert self._tasks_outstanding < self._prefetch_factor * self._num_workers try: index = self._next_index() except StopIteration: return for _ in range(self._num_workers): # find the next active worker, if any worker_queue_idx = next(self._worker_queue_idx_cycle) if self._workers_status[worker_queue_idx]: break else: # not found (i.e., didn't break) return self._index_queues[worker_queue_idx].put((self._send_idx, index)) self._task_info[self._send_idx] = (worker_queue_idx,) self._tasks_outstanding += 1 self._send_idx += 1 def _process_data(self, data): self._rcvd_idx += 1 self._try_put_index() if isinstance(data, ExceptionWrapper): data.reraise() return data def _mark_worker_as_unavailable(self, worker_id, shutdown=False): # Mark a worker as having finished its work e.g., due to # exhausting an `IterableDataset`. This should be used only when this # `_MultiProcessingDataLoaderIter` is going to continue running. assert self._workers_status[worker_id] or ( self._persistent_workers and shutdown ) # Signal termination to that specific worker. q = self._index_queues[worker_id] # Indicate that no more data will be put on this queue by the current # process. q.put(None) # Note that we don't actually join the worker here, nor do we remove the # worker's pid from C side struct because (1) joining may be slow, and # (2) since we don't join, the worker may still raise error, and we # prefer capturing those, rather than ignoring them, even though they # are raised after the worker has finished its job. # Joinning is deferred to `_shutdown_workers`, which it is called when # all workers finish their jobs (e.g., `IterableDataset` replicas) or # when this iterator is garbage collected. self._workers_status[worker_id] = False assert self._workers_done_event.is_set() == shutdown def _shutdown_workers(self): # Called when shutting down this `_MultiProcessingDataLoaderIter`. # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on # the logic of this function. # See (2) of the note. If Python is shutting down, do no-op. 
try: python_exit_status = _utils.python_exit_status except AttributeError: # Python is shutting down and `_utils` has been freed assert _utils is None return if python_exit_status is True or python_exit_status is None: return # Normal exit when last reference is gone / iterator is depleted. # See (1) and the second half of the note. if not self._shutdown: self._shutdown = True try: # Normal exit when last reference is gone / iterator is depleted. # See (1) and the second half of the note. # Exit `pin_memory_thread` first because exiting workers may leave # corrupted data in `worker_result_queue` which `pin_memory_thread` # reads from. if hasattr(self, "_pin_memory_thread"): # Use hasattr in case error happens before we set the attribute. self._pin_memory_thread_done_event.set() # Send something to pin_memory_thread in case it is waiting # so that it can wake up and check `pin_memory_thread_done_event` self._worker_result_queue.put((None, None)) self._pin_memory_thread.join() self._worker_result_queue.cancel_join_thread() self._worker_result_queue.close() # Exit workers now. self._workers_done_event.set() for worker_id in range(len(self._workers)): # Get number of workers from `len(self._workers)` instead of # `self._num_workers` in case we error before starting all # workers. # If we are using workers_status with persistent_workers # we have to shut it down because the worker is paused if self._persistent_workers or self._workers_status[worker_id]: self._mark_worker_as_unavailable(worker_id, shutdown=True) for w in self._workers: # We should be able to join here, but in case anything went # wrong, we set a timeout and if the workers fail to join, # they are killed in the `finally` block. w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) for q in self._index_queues: q.cancel_join_thread() q.close() finally: # Even though all this function does is putting into queues that # we have called `cancel_join_thread` on, weird things can # happen when a worker is killed by a signal, e.g., hanging in # `Event.set()`. So we need to guard this with SIGCHLD handler, # and remove pids from the C side data structure only at the # end. # # FIXME: Unfortunately, for Windows, we are missing a worker # error detection mechanism here in this function, as it # doesn't provide a SIGCHLD handler. if self._worker_pids_set: _utils.signal_handling._remove_worker_pids(id(self)) self._worker_pids_set = False for w in self._workers: if w.is_alive(): # Existing mechanisms try to make the workers exit # peacefully, but in case that we unfortunately reach # here, which we shouldn't, (e.g., pytorch/pytorch#39570), # we kill the worker. w.terminate() def __del__(self): self._shutdown_workers() ================================================ FILE: python/oneflow/utils/data/dataset.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import bisect import functools from typing import ( TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple, Dict, Callable, ) import oneflow as flow from oneflow.framework.tensor import Tensor default_generator = flow._oneflow_internal.default_generator # Taken from python 3.5 docs def _accumulate(iterable, fn=lambda x, y: x + y): "Return running totals" # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 it = iter(iterable) try: total = next(it) except StopIteration: return yield total for element in it: total = fn(total, element) yield total T_co = TypeVar("T_co", covariant=True) T = TypeVar("T") class Dataset(Generic[T_co]): r"""An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~flow.utils.data.Sampler` implementations and the default options of :class:`~flow.utils.data.DataLoader`. .. note:: :class:`~flow.utils.data.DataLoader` by default constructs a index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. """ def __getitem__(self, index) -> T_co: raise NotImplementedError def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]": return ConcatDataset([self, other]) class IterableDataset(Dataset[T_co]): r"""An iterable Dataset. All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream. All subclasses should overwrite :meth:`__iter__`, which would return an iterator of samples in this dataset. When a subclass is used with :class:`~flow.utils.data.DataLoader`, each item in the dataset will be yielded from the :class:`~flow.utils.data.DataLoader` iterator. When :attr:`num_workers > 0`, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. Example 1: splitting workload across all workers in :meth:`__iter__`:: >>> class MyIterableDataset(flow.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... iter_start = self.start ... iter_end = self.end ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(flow.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: >>> class MyIterableDataset(flow.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. 
>>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(flow.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] """ functions: Dict[str, Callable] = {} reduce_ex_hook: Optional[Callable] = None def __iter__(self) -> Iterator[T_co]: raise NotImplementedError def __add__(self, other: Dataset[T_co]): return ChainDataset([self, other]) def __getattr__(self, attribute_name): if attribute_name in IterableDataset.functions: function = functools.partial( IterableDataset.functions[attribute_name], self ) return function else: raise AttributeError @classmethod def register_function(cls, function_name, function): IterableDataset.functions[function_name] = function @classmethod def register_datapipe_as_function(cls, function_name, cls_to_register): if function_name in IterableDataset.functions: raise Exception( "Unable to add DataPipe function name {} as it is already taken".format( function_name ) ) def class_function(cls, source_dp, *args, **kwargs): return cls(source_dp, *args, **kwargs) function = functools.partial(class_function, cls_to_register) IterableDataset.functions[function_name] = function def __reduce_ex__(self, *args, **kwargs): if IterableDataset.reduce_ex_hook is not None: try: return IterableDataset.reduce_ex_hook(self) except NotImplementedError: pass return super().__reduce_ex__(*args, **kwargs) @classmethod def set_reduce_ex_hook(cls, hook_fn): if IterableDataset.reduce_ex_hook is not None and hook_fn is not None: raise Exception("Attempt to override existing reduce_ex_hook") IterableDataset.reduce_ex_hook = hook_fn class TensorDataset(Dataset[Tuple[Tensor, ...]]): r"""Dataset wrapping tensors. Each sample will be retrieved by indexing tensors along the first dimension. Args: *tensors (Tensor): tensors that have the same size of the first dimension. """ def __init__(self, *tensors: Tensor) -> None: assert all( tensors[0].size(0) == tensor.size(0) for tensor in tensors ), "Size mismatch between tensors" self.tensors = tensors def __getitem__(self, index): return tuple(tensor[index] for tensor in self.tensors) def __len__(self): return self.tensors[0].size(0) class ConcatDataset(Dataset[T_co]): r"""Dataset as a concatenation of multiple datasets. This class is useful to assemble different existing datasets. Args: datasets (sequence): List of datasets to be concatenated """ datasets: List[Dataset[T_co]] cumulative_sizes: List[int] @staticmethod def cumsum(sequence): r, s = [], 0 for e in sequence: l = len(e) r.append(l + s) s += l return r def __init__(self, datasets: Iterable[Dataset]) -> None: super(ConcatDataset, self).__init__() # Cannot verify that datasets is Sized assert len(datasets) > 0, "datasets should not be an empty iterable" # type: ignore self.datasets = list(datasets) for d in self.datasets: assert not isinstance( d, IterableDataset ), "ConcatDataset does not support IterableDataset" self.cumulative_sizes = self.cumsum(self.datasets) def __len__(self): return self.cumulative_sizes[-1] def __getitem__(self, idx): if idx < 0: if -idx > len(self): raise ValueError( "absolute value of index should not exceed dataset length" ) idx = len(self) + idx dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] return self.datasets[dataset_idx][sample_idx] class ChainDataset(IterableDataset): r"""Dataset for chainning multiple :class:`IterableDataset` s. This class is useful to assemble different existing dataset streams. 
The chainning operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient. Args: datasets (iterable of IterableDataset): datasets to be chained together """ def __init__(self, datasets: Iterable[Dataset]) -> None: super(ChainDataset, self).__init__() self.datasets = datasets def __iter__(self): for d in self.datasets: assert isinstance( d, IterableDataset ), "ChainDataset only supports IterableDataset" for x in d: yield x def __len__(self): total = 0 for d in self.datasets: assert isinstance( d, IterableDataset ), "ChainDataset only supports IterableDataset" # Cannot verify that all self.datasets are Sized total += len(d) return total class Subset(Dataset[T_co]): r""" Subset of a dataset at specified indices. Args: dataset (Dataset): The whole Dataset indices (sequence): Indices in the whole set selected for subset """ dataset: Dataset[T_co] indices: Sequence[int] def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: self.dataset = dataset self.indices = indices def __getitem__(self, idx): return self.dataset[self.indices[idx]] def __len__(self): return len(self.indices) def random_split( dataset: Dataset[T], lengths: Sequence[int], generator: Optional[object] = default_generator, ) -> List[Subset[T]]: r""" Randomly split a dataset into non-overlapping new datasets of given lengths. Optionally fix the generator for reproducible results, e.g.: >>> random_split(range(10), [3, 7], generator=flow.Generator().manual_seed(42)) Args: dataset (Dataset): Dataset to be split lengths (sequence): lengths of splits to be produced generator (Generator): Generator used for the random permutation. """ # Cannot verify that dataset is Sized if sum(lengths) != len(dataset): # type: ignore raise ValueError( "Sum of input lengths does not equal the length of the input dataset!" ) indices = flow._C.randperm(sum(lengths), generator=generator).tolist() return [ Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths) ] ================================================ FILE: python/oneflow/utils/data/decorator.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from typing import Any, Callable, Optional, Type, Union from oneflow.utils.data import IterDataPipe class functional_datapipe(object): name: str def __init__(self, name: str) -> None: self.name = name def __call__(self, cls): if isinstance(cls, Type): # type: ignore if not issubclass(cls, IterDataPipe): raise TypeError("`functional_datapipe` can only decorate IterDataPipe") # with non_deterministic decorator else: if not isinstance(cls, non_deterministic) and not ( hasattr(cls, "__self__") and isinstance(cls.__self__, non_deterministic) ): raise TypeError("`functional_datapipe` can only decorate IterDataPipe") IterDataPipe.register_datapipe_as_function(self.name, cls) return cls _determinism: bool = False class guaranteed_datapipes_determinism(object): prev: bool def __init__(self) -> None: global _determinism self.prev = _determinism _determinism = True def __enter__(self) -> None: pass def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: global _determinism _determinism = self.prev class non_deterministic(object): cls: Optional[Type[IterDataPipe]] = None # TODO: Lambda for picking deterministic_fn: Callable[[], bool] def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None: # 1. Decorator doesn't have any argument if isinstance(arg, Type): # type: ignore if not issubclass(arg, IterDataPipe): # type: ignore raise TypeError( "Only `IterDataPipe` can be decorated with `non_deterministic`" ", but {} is found".format(arg.__name__) ) self.cls = arg # type: ignore # 2. Decorator has an argument of a function # This class should behave differently given different inputs. Use this # function to verify the determinism for each instance. # When the function returns True, the instance is non-deterministic. Otherwise, # the instance is a deterministic DataPipe. elif isinstance(arg, Callable): # type:ignore self.deterministic_fn = arg # type: ignore else: raise TypeError("{} can not be decorated by non_deterministic".format(arg)) def __call__(self, *args, **kwargs): global _determinism # Decorate IterDataPipe if self.cls is not None: if _determinism: raise TypeError( "{} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. " "You can turn off determinism for this DataPipe if that is acceptable " "for your application".format(self.cls.__name__) ) return self.cls(*args, **kwargs) # type: ignore # Decorate with a functional argument if not ( isinstance(args[0], Type) and issubclass( # type: ignore args[0], IterDataPipe ) ): raise TypeError( "Only `IterDataPipe` can be decorated, but {} is found".format( args[0].__name__ ) ) self.cls = args[0] return self.deterministic_wrapper_fn def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe: res = self.deterministic_fn(*args, **kwargs) # type: ignore if not isinstance(res, bool): raise TypeError( "deterministic_fn of `non_deterministic` decorator is required " "to return a boolean value, but {} is found".format(type(res)) ) global _determinism if _determinism and res: raise TypeError( "{} is non-deterministic with the inputs, but you set " "'guaranteed_datapipes_determinism'. You can turn off determinism " "for this DataPipe if that is acceptable for your application".format( self.cls.__name__ ) ) # type: ignore return self.cls(*args, **kwargs) # type: ignore ================================================ FILE: python/oneflow/utils/data/distributed.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import numpy as np from typing import TypeVar, Optional, Iterator import oneflow as flow from oneflow.utils.data import Sampler, Dataset T_co = TypeVar("T_co", covariant=True) class DistributedSampler(Sampler[T_co]): r"""Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`flow.nn.parallel.DistributedDataParallel`. In such a case, each process can pass a :class:`~flow.utils.data.DistributedSampler` instance as a :class:`~flow.utils.data.DataLoader` sampler, and load a subset of the original dataset that is exclusive to it. .. note:: Dataset is assumed to be of constant size. Args: dataset: Dataset used for sampling. num_replicas (int, optional): Number of processes participating in distributed training. By default, :attr:`world_size` is retrieved from the current distributed group. rank (int, optional): Rank of the current process within :attr:`num_replicas`. By default, :attr:`rank` is retrieved from the current distributed group. shuffle (bool, optional): If ``True`` (default), sampler will shuffle the indices. seed (int, optional): random seed used to shuffle the sampler if :attr:`shuffle=True`. This number should be identical across all processes in the distributed group. Default: ``0``. drop_last (bool, optional): if ``True``, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If ``False``, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: ``False``. .. warning:: In distributed mode, calling the :meth:`set_epoch` method at the beginning of each epoch **before** creating the :class:`DataLoader` iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used. For example: .. code-block:: python >>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader) """ def __init__( self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False, ) -> None: if num_replicas is None: num_replicas = flow.env.get_world_size() if rank is None: rank = flow.env.get_rank() if rank >= num_replicas or rank < 0: raise ValueError( "Invalid rank {}, rank should be in the interval" " [0, {}]".format(rank, num_replicas - 1) ) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.drop_last = drop_last # If the dataset length is evenly divisible by # of replicas, then there # is no need to drop any data, since the dataset will be split equally. if self.drop_last and len(self.dataset) % self.num_replicas != 0: # Split to nearest available length that is evenly divisible. 
# This is to ensure each rank receives the same amount of data when # using this Sampler. self.num_samples = math.ceil( # `type:ignore` is required because Dataset cannot provide a default __len__ (len(self.dataset) - self.num_replicas) / self.num_replicas ) else: self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed def __iter__(self) -> Iterator[T_co]: if self.shuffle: # deterministically shuffle based on epoch and seed g = flow.Generator("cpu") g.manual_seed(self.seed + self.epoch) indices = flow._C.randperm(len(self.dataset), generator=g).tolist() else: indices = list(range(len(self.dataset))) if not self.drop_last: # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) if padding_size <= len(indices): indices += indices[:padding_size] else: indices += (indices * math.ceil(padding_size / len(indices)))[ :padding_size ] else: # remove tail of data to make it evenly divisible. indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) def __len__(self) -> int: return self.num_samples def set_epoch(self, epoch: int) -> None: """Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering. Args: epoch (int): Epoch number. """ self.epoch = epoch ================================================ FILE: python/oneflow/utils/data/sampler.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized import numpy as np import oneflow as flow T_co = TypeVar("T_co", covariant=True) class Sampler(Generic[T_co]): r"""Base class for all Samplers. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a way to iterate over indices of dataset elements, and a :meth:`__len__` method that returns the length of the returned iterators. .. note:: The :meth:`__len__` method isn't strictly required by :class:`~flow.utils.data.DataLoader`, but is expected in any calculation involving the length of a :class:`~flow.utils.data.DataLoader`. """ def __init__(self, data_source: Optional[Sized]) -> None: pass def __iter__(self) -> Iterator[T_co]: raise NotImplementedError # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] # # Many times we have an abstract class representing a collection/iterable of # data, e.g., `flow.utils.data.Sampler`, with its subclasses optionally # implementing a `__len__` method. 
In such cases, we must make sure to not # provide a default implementation, because both straightforward default # implementations have their issues: # # + `return NotImplemented`: # Calling `len(subclass_instance)` raises: # TypeError: 'NotImplementedType' object cannot be interpreted as an integer # # + `raise NotImplementedError()`: # This prevents triggering some fallback behavior. E.g., the built-in # `list(X)` tries to call `len(X)` first, and executes a different code # path if the method is not found or `NotImplemented` is returned, while # raising an `NotImplementedError` will propagate and and make the call # fail where it could have use `__iter__` to complete the call. # # Thus, the only two sensible things to do are # # + **not** provide a default `__len__`. # # + raise a `TypeError` instead, which is what Python uses when users call # a method that is not defined on an object. # (@ssnl verifies that this works on at least Python 3.7.) class SequentialSampler(Sampler[int]): r"""Samples elements sequentially, always in the same order. Args: data_source (Dataset): dataset to sample from """ data_source: Sized def __init__(self, data_source): self.data_source = data_source def __iter__(self): return iter(range(len(self.data_source))) def __len__(self) -> int: return len(self.data_source) class RandomSampler(Sampler[int]): r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then user can specify :attr:`num_samples` to draw. Args: data_source (Dataset): dataset to sample from replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` num_samples (int): number of samples to draw, default=`len(dataset)`. This argument is supposed to be specified only when `replacement` is ``True``. generator (Generator): Generator used in sampling. """ data_source: Sized replacement: bool def __init__( self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None, ) -> None: self.data_source = data_source self.replacement = replacement self._num_samples = num_samples self.generator = generator if not isinstance(self.replacement, bool): raise TypeError( "replacement should be a boolean value, but got " "replacement={}".format(self.replacement) ) if self._num_samples is not None and not replacement: raise ValueError( "With replacement=False, num_samples should not be specified, " "since a random permute will be performed." 
) if not isinstance(self.num_samples, int) or self.num_samples <= 0: raise ValueError( "num_samples should be a positive integer " "value, but got num_samples={}".format(self.num_samples) ) @property def num_samples(self) -> int: # dataset size might change at runtime if self._num_samples is None: return len(self.data_source) return self._num_samples def __iter__(self): n = len(self.data_source) if self.generator is None: generator = flow.Generator("cpu") generator.manual_seed(np.random.randint(0, np.iinfo(np.int64).max)) # TODO: use Tensor.random_ # generator.manual_seed( # int(flow.empty((), dtype=flow.int64).random_().item()) # ) else: generator = self.generator if self.replacement: for _ in range(self.num_samples // 32): yield from flow._C.randint( high=n, size=(32,), dtype=flow.int64, generator=generator ).numpy().tolist() yield from flow._C.randint( high=n, size=(self.num_samples % 32,), dtype=flow.int64, generator=generator, ).numpy().tolist() else: yield from flow._C.randperm(n, generator=generator).numpy().tolist() def __len__(self): return self.num_samples class SubsetRandomSampler(Sampler[int]): r"""Samples elements randomly from a given list of indices, without replacement. Args: indices (sequence): a sequence of indices generator (Generator): Generator used in sampling. """ indices: Sequence[int] def __init__(self, indices: Sequence[int], generator=None) -> None: self.indices = indices self.generator = generator def __iter__(self): return ( self.indices[i] for i in flow._C.randperm(len(self.indices), generator=self.generator) ) def __len__(self): return len(self.indices) class BatchSampler(Sampler[List[int]]): r"""Wraps another sampler to yield a mini-batch of indices. Args: sampler (Sampler or Iterable): Base sampler. Can be any iterable object batch_size (int): Size of mini-batch. drop_last (bool): If ``True``, the sampler will drop the last batch if its size would be less than ``batch_size`` Example: >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] """ def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: # Since collections.abc.Iterable does not check for `__getitem__`, which # is one way for an object to be an iterable, we don't do an `isinstance` # check here. if ( not isinstance(batch_size, int) or isinstance(batch_size, bool) or batch_size <= 0 ): raise ValueError( "batch_size should be a positive integer value, " "but got batch_size={}".format(batch_size) ) if not isinstance(drop_last, bool): raise ValueError( "drop_last should be a boolean value, but got " "drop_last={}".format(drop_last) ) self.sampler = sampler self.batch_size = batch_size self.drop_last = drop_last def __iter__(self): batch = [] for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch def __len__(self): # Can only be called if self.sampler has __len__ implemented # We cannot enforce this condition, so we turn off typechecking for the # implementation below. 
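# For example, a sampler of length 10 with batch_size == 3 yields 3 batches when drop_last is True and 4 otherwise (matching the index batches shown in the docstring example above).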
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] if self.drop_last: return len(self.sampler) // self.batch_size # type: ignore else: return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore ================================================ FILE: python/oneflow/utils/global_view/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.utils.global_view.to_global import to_global from oneflow.utils.global_view.to_local import to_local from oneflow.utils.global_view.global_mode import global_mode, current_global_mode __all__ = [ "to_global", "to_local", "global_mode", "current_global_mode", ] ================================================ FILE: python/oneflow/utils/global_view/global_mode.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import oneflow._oneflow_internal.global_view as internal_global_view class global_mode(internal_global_view.global_mode): r"""Create a scope to provide global information for the computation process within it. It provides convinience for converting from local execution to global execution, especially for converting to ddp global execution. 1) Make the source op create the global tensor directly. 2) Make it legal for the "to(device)" API of the global tensor. 3) Make it legal to use ".device" to get the device type of the global tensor. Note: Both placement and sbp are required if the global mode is enabled. Args: enabled (bool): whether the global mode is enbaled. placement (oneflow.placement, optional): the desired placement of the input. Default: None sbp (oneflow.sbp.sbp, list/tuple of oneflow.sbp.sbp, optional): the desired sbp of the input or self-defined functions in order to specify SBP. Default: None For example: .. code-block:: python class LinearEvalGraphWithDDP(flow.nn.Graph): def __init__(self): super().__init__() self.linear_dp = linear_dp def build(self, x): with global_mode(True, placement=P, sbp=B): device = self.linear_dp.weight.device x = x.to(device) out = self.linear_dp(x) # The local tensor will be converted to global sample = flow.randn(out.shape, device="cpu").to(device) out = out + sample * 100 out = out - sample * 100 return out .. code-block:: python with global_mode(False): # The tensor will be keeped as local. 
sample = flow.randn(out.shape, device="cpu").to(device) out = out + sample * 100 out = out - sample * 100 """ def __init__(self, enabled, placement=None, sbp=None) -> None: if not enabled: super().__init__(enabled) else: super().__init__(enabled, placement, sbp) def __enter__(self): pass def __exit__(self, type, value, traceback): pass class current_global_mode(internal_global_view.current_global_mode): r"""Get the current global mode information. Use the current_global_mode to get the information of global mode, including enabled, placement and sbp. Note: The sbp property is supposed to return a list/tuple of `oneflow.sbp.sbp`. For example: .. code-block:: python with global_mode(True, placement=P, sbp=B): # Get the global mode info. cur_global_mode = global_view.current_global_mode() test_case.assertTrue(cur_global_mode.is_enabled) test_case.assertEqual(cur_global_mode.placement, P) test_case.assertEqual(cur_global_mode.sbp[0], B) """ def __init__(self) -> None: super().__init__() @property def is_enabled(self): return super().is_enabled @property def sbp(self): return super().sbp @property def placement(self): return super().placement ================================================ FILE: python/oneflow/utils/global_view/global_utils.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings import pickle import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.framework.args_tree import ArgsTree def to_global_tensor(input_tensor, placement=None, sbp=None, **kwargs): # specific operation for None if input_tensor is None: return flow.local_to_global( input=input_tensor, placement=placement, sbp=sbp, **kwargs ) if input_tensor.is_global: return flow.global_to_global( input=input_tensor, placement=placement, sbp=sbp, **kwargs ) else: if "grad_sbp" in kwargs: del kwargs["grad_sbp"] return flow.local_to_global( input=input_tensor, placement=placement, sbp=sbp, **kwargs ) def to_local_tensor(input_tensor, copy): if not input_tensor.is_global: warnings.warn("The tensor should be global, local tensor will remain the same.") return input_tensor return flow._C.to_local(input_tensor, copy) def check_input_global(input): is_input_global = False if input is not None: if isinstance(input, Tensor): is_input_global = input.is_global elif isinstance(input, (dict, tuple, list)): is_first_tensor_in_input = True input_tree_for_is_global = ArgsTree(input) for arg in input_tree_for_is_global.iter_nodes(): if isinstance(arg, Tensor): if is_first_tensor_in_input: is_input_global = arg.is_global is_first_tensor_in_input = False else: assert ( arg.is_global == is_input_global ), "Tensor(s) in the input must be all local or all global." 
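# At this point is_input_global reflects the first tensor leaf found in the input; the assert above guarantees every other tensor leaf agrees with it.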
return is_input_global def check_placement_on_all_ranks(placement): # Determine whether the ranks of placement are same as all ranks is_placement_on_all_ranks = False all_ranks = flow.placement.all("cpu").ranks if ( all_ranks.shape == placement.ranks.shape and (all_ranks == placement.ranks).all() ): is_placement_on_all_ranks = True return is_placement_on_all_ranks def src_sbp_broadcast(obj, src: int = 0): rank = flow.env.get_rank() if src == rank: obj_bytes = pickle.dumps(obj) obj_bytes = flow._oneflow_internal.cpu_broadcast(obj_bytes, src) else: obj_bytes = flow._oneflow_internal.cpu_broadcast(None, src) return pickle.loads(obj_bytes) ================================================ FILE: python/oneflow/utils/global_view/to_global.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import warnings import pickle import types import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.framework.args_tree import ArgsTree from oneflow.utils.global_view.global_utils import ( to_global_tensor, check_input_global, check_placement_on_all_ranks, src_sbp_broadcast, ) def to_global(input, placement=None, sbp=None, warn_on_non_tensor_leaf=True, **kwargs): r"""Converts the input tensor or input tensor(s) in list/tuple/dict to global tensor(s). Note: Both placement and sbp are required if the input is local, otherwise at least one of placement and sbp is required. Args: input (oneflow.Tensor/None/list/tuple/dict): the input that needs to be converted. placement (oneflow.placement, optional): the desired placement of the input. Default: None sbp (oneflow.sbp.sbp, list/tuple of oneflow.sbp.sbp or Callable[[Tensor], oneflow.sbp.sbp], optional): the desired sbp of the input or self-defined functions in order to specify SBP. Default: None warn_on_non_tensor_leaf (bool, optional): whether to warn when the leaf is not a tensor. Default: True Returns: The converted input. For a tensor input: please refer to the examples in :func:`oneflow.Tensor.to_global`. For an input of other type (take a state dict as an example): .. code-block:: python >>> # Run on 2 ranks respectively >>> import oneflow as flow >>> from oneflow import nn >>> placement = flow.placement("cpu", ranks=[0, 1]) # doctest: +SKIP >>> sbp = (flow.sbp.broadcast,) # doctest: +SKIP >>> model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2)) # doctest: +SKIP >>> global_state_dict = flow.utils.global_view.to_global(model.state_dict(), placement, sbp) # doctest: +SKIP >>> for val in state_dict.values(): # doctest: +SKIP >>> print(val.is_global) # doctest: +SKIP .. code-block:: python >>> # results on rank 0 True True True True .. code-block:: python >>> # results on rank 1 True True True True Note: For the input of dict type, such as the state dict of the model, the unified sbp cannot be used when calling the to_global method, and the sbp needs to be specialized. Usually used for making graph models's state dict global. 
If you want to do the `split(0)` operation, but there are tensors that cannot be split by dim 0, then these tensors can specify sbp. It is worth noting that, for a tensor of shape `(1, n)`, you can specify SBP is `oneflow.sbp.split(1)`. For example: .. code-block:: python flow.utils.global_view.to_global(state_dict, placement=placement, sbp=get_sbp) # Defines a function to return the specified SBP. def get_sbp(state_dict, tensor): if tensor is state_dict["System-Train-TrainStep"]: return oneflow.sbp.broadcast if tensor is state_dict["module_pipeline"]["m_stage3.linear.weight"]: return oneflow.sbp.split(1) if tensor is state_dict["module_pipeline"]["m_stage3.linear.bias"]: return oneflow.sbp.broadcast return oneflow.sbp.split(0) """ is_input_not_tensor_or_none = False if (input is not None) and (not isinstance(input, (Tensor, dict, tuple, list))): is_input_not_tensor_or_none = True if ( (not is_input_not_tensor_or_none) and (placement is not None) and (not check_input_global(input)) and (not check_placement_on_all_ranks(placement)) ): src_rank = placement.ranks.flat[0] cur_rank = flow.env.get_rank() if cur_rank == src_rank: # Replace tensor(s) in the input with None, in order to reduce communication cost if isinstance(input, Tensor) or input is None: mapped_input_none = None else: input_tree_none = ArgsTree(input) def leaf_fn_to_none(node): if isinstance(node, Tensor): # Ensure that each rank has a tensor instance, which can avoid the situation of none is none in the user-defined get_sbp function. return flow.empty(0, 1) else: if warn_on_non_tensor_leaf: warnings.warn( "Non-Tensor type: {} encountered, it will remain the same.".format( type(node) ) ) return node mapped_input_none = input_tree_none.map_leaf(leaf_fn_to_none) obj_input = pickle.dumps(mapped_input_none) flow._oneflow_internal.cpu_broadcast(obj_input, src_rank) else: if cur_rank in placement.ranks: # Participating in the broadcast process but retaining original value flow._oneflow_internal.cpu_broadcast(None, src_rank) else: # The input of other ranks will be always overwritten no matter what is passed in input = pickle.loads( flow._oneflow_internal.cpu_broadcast(None, src_rank) ) if isinstance(input, (Tensor, dict, tuple, list)): input_tree = ArgsTree(input) def leaf_fn(node): if isinstance(node, Tensor) or node is None: if isinstance(sbp, types.FunctionType): return to_global_tensor(node, placement, sbp(input, node), **kwargs) else: return to_global_tensor(node, placement, sbp, **kwargs) else: if warn_on_non_tensor_leaf: warnings.warn( "Non-Tensor type: {} encountered, it will remain the same.".format( type(node) ) ) return node mapped_input = input_tree.map_leaf(leaf_fn) return mapped_input else: if warn_on_non_tensor_leaf: warnings.warn( "Non-Tensor type: {} encountered, it will remain the same.".format( type(input) ) ) return input ================================================ FILE: python/oneflow/utils/global_view/to_local.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import warnings import oneflow as flow from oneflow.framework.tensor import Tensor from oneflow.framework.args_tree import ArgsTree from oneflow.utils.global_view.global_utils import to_local_tensor def to_local(input, *, copy=False): r"""Returns the local part of the input. Returns: The converted input. For a tensor input: please refer to the examples in :func:`oneflow.Tensor.to_local`. For an input of other type (take a state dict as an example): .. code-block:: python >>> # Run on 2 ranks respectively >>> import oneflow as flow >>> from oneflow import nn >>> placement = flow.placement("cpu", ranks=[0, 1]) # doctest: +SKIP >>> sbp = (flow.sbp.broadcast,) # doctest: +SKIP >>> model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2)) # doctest: +SKIP >>> model = model.to_global(placement=placement, sbp=sbp) # doctest: +SKIP >>> local_state_dict = flow.utils.global_view.to_local(model.state_dict()) # doctest: +SKIP >>> for val in local_state_dict.values(): # doctest: +SKIP >>> print(val.is_global) # doctest: +SKIP .. code-block:: python >>> # results on rank 0 False False False False .. code-block:: python >>> # results on rank 1 False False False False """ if isinstance(input, Tensor): return to_local_tensor(input, copy) elif isinstance(input, (dict, tuple, list)): input_tree = ArgsTree(input) def leaf_fn(node): if isinstance(node, Tensor): return to_local_tensor(node, copy) else: warnings.warn( "Non-Tensor type: {} encountered, it will remain the same.".format( type(node) ) ) return node mapped_input = input_tree.map_leaf(leaf_fn) return mapped_input else: warnings.warn( "Non-Tensor type: {} encountered, it will remain the same.".format( type(input) ) ) return input ================================================ FILE: python/oneflow/utils/hooks.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" # This file is mostly copied from PyTorch's torch/utils/hooks.py import oneflow as flow import oneflow.nn.modules._functions from oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple from collections import OrderedDict import weakref import warnings from typing import Any __all__ = ["BackwardHook", "RemovableHandle"] class RemovableHandle(object): """A handle which provides the capability to remove a hook.""" id: int next_id: int = 0 def __init__(self, hooks_dict: Any) -> None: self.hooks_dict_ref = weakref.ref(hooks_dict) self.id = RemovableHandle.next_id RemovableHandle.next_id += 1 def remove(self) -> None: hooks_dict = self.hooks_dict_ref() if hooks_dict is not None and self.id in hooks_dict: del hooks_dict[self.id] def __getstate__(self): return (self.hooks_dict_ref(), self.id) def __setstate__(self, state) -> None: if state[0] is None: # create a dead reference self.hooks_dict_ref = weakref.ref(OrderedDict()) else: self.hooks_dict_ref = weakref.ref(state[0]) self.id = state[1] RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1) def __enter__(self) -> "RemovableHandle": return self def __exit__(self, type: Any, value: Any, tb: Any) -> None: self.remove() class BackwardHook(object): """ A wrapper class to implement nn.Module backward hooks. It handles: - Ignoring non-Tensor inputs and replacing them by None before calling the user hook - Generating the proper Node to capture a set of Tensor's gradients - Linking the gradients captures for the outputs with the gradients captured for the input - Calling the user hook once both output and input gradients are available """ def __init__(self, module, user_hooks, user_pre_hooks): self.user_hooks = user_hooks self.user_pre_hooks = user_pre_hooks self.module = module self.grad_outputs = None self.n_outputs = -1 self.output_tensors_index = None self.n_inputs = -1 self.input_tensors_index = None def _pack_with_none(self, indices, values, size): res = [None] * size for idx, val in zip(indices, values): res[idx] = val return convert_to_tensor_tuple(res) def _unpack_none(self, indices, values): res = [] for idx in indices: res.append(values[idx]) return convert_to_tensor_tuple(res) def _set_user_hook(self, grad_fn): def fn(grad_input, _): # TODO(hujiakui): in pytorch, it should raise Error. if self.grad_outputs is None: warnings.warn( "Module backward hook for grad_input is called before " "the grad_output one. This happens because the gradient " "in your nn.Module flows to the Module's input without " "passing through the Module's output. Make sure that the " "output depends on the input and that the loss is computed " "based on the output." ) return res = self._pack_with_none( self.input_tensors_index, grad_input, self.n_inputs ) for hook in self.user_hooks: out = hook(self.module, res, self.grad_outputs) if out is None: continue if len(out) != len(res): raise RuntimeError( "Backward hook returned an invalid number of grad_input, " "got {}, but expected {}".format(len(out), len(res)) ) res = out if res is None: return res if len(res) != len(grad_input): raise RuntimeError( "Backward hook returned an invalid number of grad_input, " "got {}, but expected {}".format(len(res), len(grad_input)) ) self.grad_outputs = None return self._unpack_none(self.input_tensors_index, res) grad_fn.register_hook(fn) def _apply_on_tensors(self, fn, args): # Can be used to apply the given function to the tensors contained in the # args. 
Will return updated args and the tensors indices tensors_idx = [] tensors = [] requires_grad = False for i, arg in enumerate(args): if isinstance(arg, flow.Tensor): tensors_idx.append(i) tensors.append(arg) requires_grad |= arg.requires_grad if not (requires_grad and flow.is_grad_enabled()): return args, None # FIXME: BackwardFunction should not return a single Tensor when the return type is tuple new_tensors = flow.nn.modules._functions.BackwardHookFunction.apply(*tensors) if not isinstance(new_tensors, tuple): new_tensors = (new_tensors,) if len(new_tensors) == 0: raise RuntimeError( "Cannot set Module backward hook for a Module with no input Tensors." ) grad_fns = [ t.grad_fn for t in new_tensors if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward" ] if len(grad_fns) == 0: raise RuntimeError( "Error while setting up backward hooks. Please open " "an issue with a code sample to reproduce this." ) fn(grad_fns[0]) arg_list = list(args) for idx, val in zip(tensors_idx, new_tensors): arg_list[idx] = val return tuple(arg_list), tensors_idx def setup_input_hook(self, args): def fn(grad_fn): self._set_user_hook(grad_fn) res, input_idx = self._apply_on_tensors(fn, args) self.n_inputs = len(args) self.input_tensors_index = input_idx return res def setup_output_hook(self, args): def fn(grad_fn): def hook(_, grad_output): self.grad_outputs = self._pack_with_none( self.output_tensors_index, grad_output, self.n_outputs ) if self.user_pre_hooks: expected_len = len(self.grad_outputs) for user_pre_hook in self.user_pre_hooks: hook_grad_outputs = user_pre_hook( self.module, self.grad_outputs ) if hook_grad_outputs is None: continue actual_len = len(hook_grad_outputs) if actual_len != expected_len: raise RuntimeError( "Backward pre hook returned an invalid number of grad_output, " "got {}, but expected {}".format( actual_len, expected_len ) ) self.grad_outputs = hook_grad_outputs # Special case if no input required gradients, this hook should call the user # hook directly if self.input_tensors_index is None: grad_inputs = self._pack_with_none([], [], self.n_inputs) for user_hook in self.user_hooks: res = user_hook(self.module, grad_inputs, self.grad_outputs) if res is not None and not ( isinstance(res, tuple) and all(el is None for el in res) ): raise RuntimeError( "Backward hook for Modules where no input requires " "gradient should always return None or None for all gradients." ) self.grad_outputs = None grad_fn.register_hook(hook) is_tuple = True if not isinstance(args, tuple): args = (args,) is_tuple = False res, output_idx = self._apply_on_tensors(fn, args) self.n_outputs = len(args) self.output_tensors_index = output_idx if not is_tuple: res = res[0] return res ================================================ FILE: python/oneflow/utils/insight/README.md ================================================ # OneFlow Insight ## Overview OneFlow Insight is a module designed for profiling CUDA kernel execution time and bottleneck analysis. Typically, this is done using the nsys command provided by Nvidia, which generates corresponding profile files (formerly .qdrep and now .nsys-rep). These files can be visualized and analyzed using Nvidia's GUI software, Nsight Systems. In addition to generating profile files, nsys also produces platform-independent data information recorded in a .sqlite file. The OneFlow Insight module can parse this .sqlite file to generate a JSON file formatted according to the Google Chrome Trace Event standard. 
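For reference, every entry in the generated `traceEvents` list follows the Chrome Trace Event format. A minimal sketch of one kernel event as the converter assembles it (field names are taken from `sqlite_to_google_trace_event.py`; the concrete values here are purely illustrative):

```python
# One "complete" event ("ph": "X") for a CUDA kernel, as emitted by the converter.
kernel_event = {
    "ph": "X",             # complete event: has a start timestamp and a duration
    "cat": "CUDA Kernel",  # category shown in the tracing UI
    "name": "my_kernel",   # short kernel name (illustrative)
    "pid": 0,              # GPU events are grouped under pid 0
    "tid": 7,              # the CUDA stream id is used as the thread row (illustrative)
    "ts": 1234.5,          # start time in microseconds
    "dur": 42.0,           # duration in microseconds
    "args": {"grid": "<<<1,1,1>>>", "block": "<<<256,1,1>>>"},  # extra details shown on click
}
```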
This allows for direct visualization and analysis through Chrome or Edge browsers using chrome://tracing/ or edge://tracing/ (supported by trace-event-profiling-tool, see:https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/). ## Usage 1. Generate profile files using the following nsys command: ```bash nsys profile --export=sqlite -o profile_data ``` This will produce .nsys-rep files along with a .sqlite file. 2. Use OneFlow Insight to parse the .sqlite file and generate a JSON file: ```bash python3 sqlite_to_google_trace_event.py --input 'profile_data.sqlite' -o trace.json ``` 3. Open Chrome or Edge browser and navigate to chrome://tracing/ or edge://tracing/. 4. Load the generated trace.json file for visualizing and analyzing the profiling data. ## Visualization Example ![OneFlow Insight Visualization](trace.json.png) The above image demonstrates the visualization capabilities using Chrome or Edge browser with the generated JSON file. Feel free to explore and gain insights into your CUDA kernel execution performance! ================================================ FILE: python/oneflow/utils/insight/requirements.txt ================================================ sqlite3 argparse traceback ================================================ FILE: python/oneflow/utils/insight/sqlite_to_google_trace_event.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import json import sqlite3 import argparse import traceback class DatabaseManager: def __init__(self, db_file): self.db_file = db_file self.connection = None self.cursor = None def open_connection(self): self.connection = sqlite3.connect(self.db_file) self.cursor = self.connection.cursor() def close_connection(self): if self.cursor: self.cursor.close() if self.connection: self.connection.close() def execute_sql(self, sql): try: self.cursor.execute(sql) self.connection.commit() except sqlite3.Error as e: print(f"Execute sql '{sql}' error: {e}") traceback.print_exc() def are_tables_exist(db_manager, table_names): try: # Query for the existence of sqlite database tables with specific names results = {} for table_name in table_names: db_manager.execute_sql( f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'" ) result = db_manager.cursor.fetchone() results[table_name] = result is not None return results except sqlite3.Error as e: print(f"are_tables_exist() SQLite error: {e}") return {} def print_db_info(db_manager): # execute sql db_manager.execute_sql("SELECT name, sql FROM sqlite_master WHERE type='table';") # get results tables = db_manager.cursor.fetchall() # print infomation for table in tables: print(f"Table Name: {table[0]}\nCreate Table SQL: {table[1]}\n") def get_start_time(db_manager): """ get session start time(timestamp) from table TARGET_INFO_SESSION_START_TIME """ sql = "SELECT utcEpochNs FROM TARGET_INFO_SESSION_START_TIME LIMIT 1;" db_manager.execute_sql(sql) result = db_manager.cursor.fetchone() timestamp = result[0] return timestamp def get_process_id(db_manager): """ get process id from table TARGET_INFO_CUDA_NULL_STREAM """ sql = "SELECT processId FROM TARGET_INFO_CUDA_NULL_STREAM LIMIT 1;" db_manager.execute_sql(sql) result = db_manager.cursor.fetchone() process_id = result[0] return process_id def get_device_property(db_manager): """ get device properties from TARGET_INFO_GPU """ sql = ( "SELECT name,totalMemory,computeMajor,computeMinor," "maxThreadsPerBlock,maxBlocksPerSm,maxRegistersPerBlock," "maxRegistersPerSm,threadsPerWarp,maxShmemPerBlock," "maxRegistersPerSm,smCount,maxShmemPerBlockOptin " "FROM TARGET_INFO_GPU WHERE id is 0;" ) db_manager.execute_sql(sql) ( name, totalGlobalMem, computeMajor, computeMinor, maxThreadsPerBlock, maxBlocksPerSm, regsPerBlock, regsPerMultiprocessor, warpSize, sharedMemPerBlock, sharedMemPerMultiprocessor, numSms, sharedMemPerBlockOptin, ) = db_manager.cursor.fetchone() maxThreadsPerMultiprocessor = maxThreadsPerBlock * maxBlocksPerSm property = { "id": 0, "name": name, "totalGlobalMem": totalGlobalMem, "computeMajor": computeMajor, "computeMinor": computeMinor, "maxThreadsPerBlock": maxThreadsPerBlock, "maxThreadsPerMultiprocessor": maxThreadsPerMultiprocessor, "regsPerBlock": regsPerBlock, "regsPerMultiprocessor": regsPerMultiprocessor, "warpSize": warpSize, "sharedMemPerBlock": sharedMemPerBlock, "sharedMemPerMultiprocessor": sharedMemPerMultiprocessor, "numSms": numSms, "sharedMemPerBlockOptin": sharedMemPerBlockOptin, } return property def sqlite_to_google_trace_event(args, tables): try: database_path = args.input print("Opening sqlite database :", database_path) db_manager = DatabaseManager(database_path) db_manager.open_connection() # print basic database information if args.info: print_db_info(db_manager) print("Checking if the following table exists:") results = are_tables_exist(db_manager, tables) for table_name, exists in results.items(): if not exists: print(f"'{table_name}' not exists.") raise 
ValueError( f"Table '{table_name}' does not exist in the database." ) else: print(f"'{table_name}' exists.") # get some necessary information session_start_time = get_start_time(db_manager) # session start time process_id = get_process_id(db_manager) # process id device_property = get_device_property(db_manager) # properties of cuda device deviceProperties = [device_property] db_manager.execute_sql( "SELECT name,busLocation FROM TARGET_INFO_GPU WHERE id is 0;" ) name, bus_location = db_manager.cursor.fetchone() db_manager.execute_sql( "SELECT duration, startTime, stopTime FROM ANALYSIS_DETAILS LIMIT 1;" ) trace_duration, trace_start_time, trace_stop_time = db_manager.cursor.fetchone() raw_start_time = session_start_time + trace_start_time start_time = round(raw_start_time / 1000) # μs to ms end_time = round((session_start_time + trace_stop_time) / 1000) duration = round(trace_duration / 1000) # μs to ms traceEvents_data = [] # construct process meta infomations traceEvents_meta = [ { "name": "process_name", "ph": "M", "ts": start_time, "pid": process_id, "tid": 0, "args": {"name": "python3"}, }, { "name": "process_labels", "ph": "M", "ts": start_time, "pid": process_id, "tid": 0, "args": {"labels": "CPU"}, }, { "name": "process_sort_index", "ph": "M", "ts": start_time, "pid": process_id, "tid": 0, "args": {"sort_index": process_id}, }, { "name": "process_name", "ph": "M", "ts": start_time, "pid": 0, "tid": 0, "args": {"name": "python3"}, }, { "name": "process_labels", "ph": "M", "ts": start_time, "pid": 0, "tid": 0, "args": {"labels": f"GPU 0(CUDA HW {bus_location} - {name})"}, }, { "name": "process_sort_index", "ph": "M", "ts": start_time, "pid": 0, "tid": 0, "args": {"sort_index": process_id}, }, { "ph": "X", "cat": "Trace", "ts": start_time, "dur": duration, "pid": "Spans", "tid": "OneFlow Insight", "name": "OneFlow Insight (0)", "args": {"Op count": 0}, }, { "name": "process_sort_index", "ph": "M", "ts": start_time, "pid": "Spans", "tid": 0, "args": {"sort_index": "Spans"}, }, { "name": "Iteration Start: OneFlow Insight", "ph": "i", "s": "g", "pid": "Traces", "tid": "Trace OneFlow Insight", "ts": start_time, }, { "name": "Record Window End", "ph": "i", "s": "g", "pid": "", "tid": "", "ts": end_time, }, ] # construct vm threads meta infomations db_manager.execute_sql("SELECT text,globalTid FROM NVTX_EVENTS;") globalTids = [] for row in db_manager.cursor.fetchall(): text, globalTid = row globalTids.append(globalTid) osrt_name = { "name": "thread_name", "ph": "M", "ts": start_time, "pid": process_id, "tid": f"[OSRT API]{globalTid}", "args": {"name": f"[OSRT API]{text}"}, } osrt_sort_index = { "name": "thread_sort_index", "ph": "M", "ts": start_time, "pid": process_id, "tid": f"[OSRT API]{globalTid}", "args": {"sort_index": globalTid - 1}, } cu_api_name = { "name": "thread_name", "ph": "M", "ts": start_time, "pid": process_id, "tid": globalTid, "args": {"name": f"[CUDA API]{text}"}, } cu_api_name_index = { "name": "thread_sort_index", "ph": "M", "ts": start_time, "pid": process_id, "tid": globalTid, "args": {"sort_index": globalTid}, } traceEvents_meta.append(osrt_name) traceEvents_meta.append(osrt_sort_index) traceEvents_meta.append(cu_api_name) traceEvents_meta.append(cu_api_name_index) # construct cuda stream meta infomations db_manager.execute_sql( "SELECT streamId,processId FROM TARGET_INFO_CUDA_STREAM;" ) temp_time = start_time for row in db_manager.cursor.fetchall(): temp_time += 187000 streamId, processId = row thread_name = { "name": "thread_name", "ph": "M", "ts": start_time, 
"pid": 0, "tid": streamId, "args": {"name": f"cuda stream {streamId}", "stream": streamId,}, } thread_sort_index = { "name": "thread_sort_index", "ph": "M", "ts": start_time, "pid": 0, "tid": streamId, "args": {"sort_index": streamId}, } traceEvents_meta.append(thread_name) traceEvents_meta.append(thread_sort_index) # insert os runtime events global_tids = ", ".join(map(str, globalTids)) db_manager.execute_sql( f"SELECT start,end,globalTid,nameId FROM OSRT_API WHERE globalTid IN ({global_tids});" ) for row in db_manager.cursor.fetchall(): start, end, globalTid, nameId = row db_manager.execute_sql(f"SELECT value FROM StringIds WHERE id = {nameId};") name = db_manager.cursor.fetchone()[0] ts = (raw_start_time + start) / 1000 dur = (end - start) / 1000 row_data = { "ph": "X", "cat": "OS RUNTIME API", "name": name, "pid": process_id, "tid": f"[OSRT API]{globalTid}", "ts": ts, "dur": dur, "args": {"global tid": f"{globalTid}(serialized)",}, } traceEvents_data.append(row_data) # insert cuda runtime api events db_manager.execute_sql( "SELECT start,end,globalTid,correlationId,nameId FROM CUPTI_ACTIVITY_KIND_RUNTIME;" ) for row in db_manager.cursor.fetchall(): start, end, globalTid, correlationId, nameId = row db_manager.execute_sql(f"SELECT value FROM StringIds WHERE id is {nameId};") name = db_manager.cursor.fetchone()[0] short_name = name.split("_", 1)[0] ts = (raw_start_time + start) / 1000 dur = (end - start) / 1000 row_data = { "ph": "X", "cat": "CUDA API", "name": short_name, "pid": process_id, "tid": globalTid, "ts": ts, "dur": dur, "args": { "name": f"Call to {name}", "begins": f"{start/(10**9)}s", "ends": f"{end/(10**9)}s(+{dur}ms)", "global tid": f"{globalTid}(serialized)", "correlation id": correlationId, }, } traceEvents_data.append(row_data) # insert cuda kernel events db_manager.execute_sql( ( "SELECT start,end,deviceId,contextId,streamId," "correlationId,globalPid,demangledName,shortName," "gridX,gridY,gridZ,blockX,blockY,blockZ," "staticSharedMemory,dynamicSharedMemory,localMemoryTotal " "FROM CUPTI_ACTIVITY_KIND_KERNEL;" ) ) for row in db_manager.cursor.fetchall(): ( start, end, deviceId, contextId, streamId, correlationId, globalPid, demangledName, shortName, gridX, gridY, gridZ, blockX, blockY, blockZ, staticSharedMemory, dynamicSharedMemory, localMemoryTotal, ) = row db_manager.execute_sql( f"SELECT value FROM StringIds WHERE id is {shortName}" ) short_name = db_manager.cursor.fetchone()[0] db_manager.execute_sql( f"SELECT value FROM StringIds WHERE id is {demangledName}" ) name = db_manager.cursor.fetchone()[0] ts = (raw_start_time + start) / 1000 dur = (end - start) / 1000 row_data = { "ph": "X", "cat": "CUDA Kernel", "name": short_name, "pid": 0, "tid": streamId, "ts": ts, "dur": dur, "args": { "name": name, "begins": f"{start/(10**9)}s", "ends": f"{end/(10**9)}s(+{dur}ms)", "grid": f"<<<{gridX},{gridY},{gridZ}>>>", "block": f"<<<{blockX},{blockY},{blockZ}>>>", "static shared memory": f"{staticSharedMemory}bytes", "dynamic shared memory": f"{dynamicSharedMemory}bytes", "local memory total": f"{localMemoryTotal}bytes", "global pid": f"{globalPid}(serialized)", "device id": deviceId, "context id": contextId, "stream id": streamId, "correlation id": correlationId, }, } traceEvents_data.append(row_data) # construct trace event dict traceEvents = traceEvents_data + traceEvents_meta data = {"deviceProperties": deviceProperties, "traceEvents": traceEvents} # the path to the JSON file to be written json_fpath = args.output # write dict content into a JSON file using json.dump with 
open(json_fpath, "w") as json_file: json.dump(data, json_file, indent=2) print(f"Successfully converted content to file: {json_fpath}") except BaseException as e: print(f"An exception occurred: {type(e).__name__}: {e}") traceback.print_exc() finally: # close db connection db_manager.close_connection() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Description of your program") parser.add_argument("--input", help="Input nvidia nsight system .sqlite file path") parser.add_argument( "--output", "-o", help="Output json file path(google trace format)", default="sqlite_to_google_trace_event.json", ) parser.add_argument( "--info", "-v", action="store_true", help="Enable print infomation of sqlite database", default=False, ) args = parser.parse_args() # check if necessary tables exist tables_to_check = [ "TARGET_INFO_GPU", "TARGET_INFO_SESSION_START_TIME", "TARGET_INFO_CUDA_NULL_STREAM", "ANALYSIS_DETAILS", "NVTX_EVENTS", "TARGET_INFO_CUDA_STREAM", "OSRT_API", "StringIds", "CUPTI_ACTIVITY_KIND_RUNTIME", "CUPTI_ACTIVITY_KIND_KERNEL", ] # Usage: # python3 sqlite_to_google_trace_event.py --input 'your_file.sqlite' sqlite_to_google_trace_event(args, tables_to_check) ================================================ FILE: python/oneflow/utils/model_zoo.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # torchvision/flowvision imports tqdm from here. from oneflow.hub import tqdm, load_state_dict_from_url as load_url # noqa: F401 ================================================ FILE: python/oneflow/utils/tensor/__init__.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from oneflow.utils.tensor.from_or_to_torch_tensor import from_torch, to_torch __all__ = [ "from_torch", "to_torch", ] ================================================ FILE: python/oneflow/utils/tensor/from_or_to_torch_tensor.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. """ import sys import oneflow as flow from oneflow._C import from_numpy as flow_from_numpy def print_error_msg(): msg = "" exc_info = sys.exc_info() if len(exc_info) > 0: msg += str(exc_info[0]) if len(exc_info) > 1: msg += " " + str(exc_info[1]) print(msg) def from_torch(torch_tensor): r""" from_torch(torch_tensor) -> Tensor Create a oneflow tensor from torch tensor. The returned tensor and torch tensor share the same memory. .. note:: This function can be used in special data processing stages, torch's some cpu ops can be used. Args: input (torch.Tensor): Input Tensor Returns: oneflow.Tensor For example: .. code-block:: python import oneflow as flow import torch torch_t = torch.tensor([[1, 2, 3], [4, 5, 6]]) flow_t = flow.utils.tensor.from_torch(torch_t) This feature ``from_torch`` is at Alpha Stage. """ try: import torch except: print_error_msg() assert isinstance(torch_tensor, torch.Tensor) return flow.from_dlpack(torch.to_dlpack(torch_tensor)) def to_torch(flow_tensor): r""" to_torch(flow_tensor) -> Tensor Create a torch tensor from oneflow tensor. The returned tensor and oneflow tensor share the same memory. .. note:: Currently only local tensor is supported. Args: input (oneflow.Tensor): Input Tensor Returns: torch.Tensor For example: .. code-block:: python import oneflow as flow import torch flow_t = flow.tensor([[1, 2, 3], [4, 5, 6]]) torch_t = flow.utils.tensor.to_torch(flow_t) This feature ``to_torch`` is at Alpha Stage. """ try: import torch except: print_error_msg() assert isinstance(flow_tensor, flow.Tensor) if flow_tensor.is_global: print( "WARNING: `to_torch` received a global tensor. A PyTorch CPU tensor which is a copy of its data will be returned." ) return torch.from_numpy(flow_tensor.numpy()) return torch.from_dlpack(flow.to_dlpack(flow_tensor)) ================================================ FILE: python/setup.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" from __future__ import absolute_import import argparse import glob import os import sys import numpy as np from setuptools import find_packages, setup from setuptools.command.install import install from setuptools.dist import Distribution # https://github.com/google/or-tools/issues/616 class InstallPlatlib(install): def finalize_options(self): install.finalize_options(self) if self.distribution.has_ext_modules(): self.install_lib = self.install_platlib parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument("--package_name", type=str, default="oneflow") args, remain_args = parser.parse_known_args() sys.argv = ["setup.py"] + remain_args def get_version(): import importlib.util spec = importlib.util.spec_from_file_location( "version", os.path.join("oneflow", "version.py") ) m = importlib.util.module_from_spec(spec) spec.loader.exec_module(m) return m.__version__ REQUIRED_PACKAGES = [ f"numpy>={np.__version__}, <2.0", "protobuf>=3.9.2, <4.0", "typing-extensions>=4.0.0, <5.0", "tqdm", "requests", "pillow", "rich", ] ONEFLOW_VERSION = get_version() if "cu11" in ONEFLOW_VERSION and "cu112" not in ONEFLOW_VERSION: REQUIRED_PACKAGES.append("nvidia-cudnn-cu11>=8.9,<9.0") REQUIRED_PACKAGES.append("nvidia-cublas-cu11") REQUIRED_PACKAGES.append("nvidia-nccl-cu11") REQUIRED_PACKAGES.append("nvidia-cusparse-cu11") REQUIRED_PACKAGES.append("nvidia-cufft-cu11") if "cu12" in ONEFLOW_VERSION: REQUIRED_PACKAGES.append("nvidia-cudnn-cu12>=8.9,<9.0") REQUIRED_PACKAGES.append("nvidia-cublas-cu12") REQUIRED_PACKAGES.append("nvidia-nccl-cu12") REQUIRED_PACKAGES.append("nvidia-cusparse-cu12") REQUIRED_PACKAGES.append("nvidia-cufft-cu12") # if python version < 3.7.x, than need pip install dataclasses if sys.version_info.minor < 7: REQUIRED_PACKAGES.append("dataclasses") class BinaryDistribution(Distribution): def is_pure(self): return False def has_ext_modules(self): return True include_files = glob.glob("oneflow/include/**/*", recursive=True) include_files = [os.path.relpath(p, "oneflow") for p in include_files] assert len(include_files) > 0, os.path.abspath("oneflow/include") def get_oneflow_internal_so_path(): import importlib suffixes = importlib.machinery.EXTENSION_SUFFIXES loader = importlib.machinery.ExtensionFileLoader lazy_loader = importlib.util.LazyLoader.factory(loader) finder = importlib.machinery.FileFinder("oneflow", (lazy_loader, suffixes)) spec = finder.find_spec("_oneflow_internal") pathname = spec.origin assert os.path.isfile(pathname) return os.path.basename(pathname) package_data = {"oneflow": [get_oneflow_internal_so_path()] + include_files} setup( name=args.package_name, version=get_version(), url="https://www.oneflow.org/", install_requires=REQUIRED_PACKAGES, packages=find_packages(), package_dir={"oneflow": "oneflow"}, package_data=package_data, zip_safe=False, distclass=BinaryDistribution, cmdclass={"install": InstallPlatlib}, entry_points={ "console_scripts": ["oneflow-mock-torch=oneflow.mock_torch.__main__:main"] }, ) ================================================ FILE: tools/check_src.py ================================================ import os from pathlib import Path this_file = os.path.dirname(os.path.abspath(__file__)) src_root = os.path.join(this_file, "..") src_root = Path(os.path.abspath(src_root)) def check_unwanted_test_scripts(python_test_dir=None, allowed=None): python_test_dir = os.path.abspath(python_test_dir) allowed_full = [ os.path.relpath(os.path.join(python_test_dir, a), src_root) for a in allowed ] 
for (dirpath, dirnames, filenames) in os.walk(src_root): if ( dirpath.startswith(os.path.abspath(python_test_dir) + os.sep) and "__pycache__" not in dirpath ): rel_to_python_test = os.path.relpath(dirpath, python_test_dir) rel_to_src_root = os.path.relpath(dirpath, src_root) print(f"checking: {rel_to_src_root}") if ( rel_to_python_test not in allowed and rel_to_python_test != "." and "custom_ops" not in rel_to_python_test ): if filenames == []: raise ValueError(f"delete this directory: {rel_to_src_root}") else: filenames_full = [ os.path.relpath(os.path.join(dirpath, a), src_root) for a in filenames ] raise ValueError( f"""move these files: {filenames_full} inside one of these directories: {allowed_full}, and delete this directory: {rel_to_src_root}""" ) def check_dir_empty(path): if os.path.exists(path): for dirpath, dirnames, files in os.walk(path): if files: raise ValueError(dirpath, "must be empty") oneflow_test_dir = src_root / "python" / "oneflow" / "test" save_load_test_data_dirs = [ os.path.relpath(x[0], oneflow_test_dir) for x in os.walk(oneflow_test_dir / "modules" / "save_load_test_data") ] print(save_load_test_data_dirs) check_unwanted_test_scripts( python_test_dir=oneflow_test_dir, allowed=[ "custom_ops", "dataloader", "graph", "models", "modules", *save_load_test_data_dirs, "tensor", "exceptions", "expensive", "ddp", "misc", "profiler", ], ) ================================================ FILE: tools/clean_generated_api.py ================================================ import argparse import glob import os import shutil parser = argparse.ArgumentParser() parser.add_argument("-root", "--root_path", type=str, required=True) args = parser.parse_args() def main(): for p in glob.glob(os.path.join(args.root_path, "oneflow/*/")): if p.endswith("python/") or p.endswith("include/"): pass else: shutil.rmtree(p) if __name__ == "__main__": main() ================================================ FILE: tools/create_pip_index.py ================================================ # python3 -m pip install oss2 beautifulsoup4 --user from bs4 import BeautifulSoup import os import oss2 import urllib import urllib.parse os.environ["no_proxy"] = "*" page_template = """ Directory listing for /oneflow/

<html>
<head><title>Directory listing for /oneflow/</title></head>
<body>
<h1>Directory listing for /oneflow/</h1>
<ul>
</ul>
</body>
</html>
""" soup = BeautifulSoup(page_template, "html.parser") def url4key(endpoint, bucket, key): return "https://{}.{}/{}".format(bucket, endpoint, urllib.parse.quote(key)) def append_link(soup, link): li_tag = soup.new_tag("li") soup.body.ul.append(li_tag) a_tag = soup.new_tag("a", href=link) a_tag.append(os.path.basename(link)) li_tag.append(a_tag) def generate_index_file(endpoint, bucket, dir_key, file_path, index_keys=None): ki = os.getenv("OSS_ACCESS_KEY_ID") ks = os.getenv("OSS_ACCESS_KEY_SECRET") auth = oss2.Auth(ki, ks) bucket_obj = oss2.Bucket(auth, endpoint, bucket) should_continue = True count = 0 next_marker = "" while should_continue: files = bucket_obj.list_objects(dir_key + "/", marker=next_marker) for f in files.object_list: key = f.key if key.endswith(".whl"): link = url4key(endpoint, bucket, key) append_link(soup, link) count += 1 next_marker = files.next_marker should_continue = next_marker != "" print("count", count) assert count html = soup.prettify() with open(file_path, "w+") as f: f.write(html) if index_keys == None: index_keys = [dir_key + ".index.html"] for index_key in index_keys: bucket_obj.put_object_from_file(index_key, file_path) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "-o", "--output_path", type=str, required=False, default="pip_index.html" ) parser.add_argument( "-e", "--endpoint", type=str, required=False, default="oss-cn-beijing.aliyuncs.com", ) parser.add_argument( "-b", "--bucket", type=str, required=False, default="oneflow-public", ) parser.add_argument( "-d", "--dir_key", type=str, required=False, default="nightly", ) parser.add_argument("--index_key", action="append", nargs="+") args = parser.parse_args() assert args.dir_key[-1] != "/" index_keys = sum(args.index_key, []) generate_index_file( args.endpoint, args.bucket, args.dir_key, args.output_path, index_keys=index_keys, ) ================================================ FILE: tools/flags_from_git_diff.py ================================================ import subprocess def get_changed_files(base=None, head=None): changed = subprocess.check_output( f"git diff --name-only --diff-filter=ACMRT {base} {head}", shell=True, text=True, ) changed = str(changed).splitlines() return changed def should_run_single_client_tests(changed=None): not_single_client_files = [ f for f in changed if ( f.endswith(".py") and not f.startswith("python/oneflow/compatible/single_client") ) or f.endswith(".yml") or f.endswith(".rst") or f.endswith(".md") or f.endswith(".cmake") or f.endswith("CMakeLists.txt") ] print("[changed]", changed) print("[not_single_client_files]", not_single_client_files) return len(not_single_client_files) < len(changed) def print_github_action_output(name=None, value=None): print(f"::set-output name={name}::{value}") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--base", type=str, required=True) parser.add_argument("--head", type=str, required=True) parser.add_argument("--need_single_client_tests", action="store_true") args = parser.parse_args() files = get_changed_files(base=args.base, head=args.head) if should_run_single_client_tests(changed=files) or args.need_single_client_tests: print_github_action_output(name="should_run_single_client_tests", value="1") ================================================ FILE: tools/functional/generate_dispatch_stateful_ops.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import re import argparse import yaml from generator import Generator parser = argparse.ArgumentParser() parser.add_argument( "--project_source_dir", type=str, help="The project source code directory.", ) args = parser.parse_args() license = """/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Generated from oneflow/api/python/functional/dispatch_stateful_ops.yaml. DO NOT EDIT!""" header_fmt = ( license + """ #ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_ #define ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_ #include #undef _PyGC_FINALIZED #include "oneflow/core/common/optional.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/tensor_index.h" namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow #endif // ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_""" ) source_fmt = ( license + """ #include "oneflow/api/python/functional/dispatch_stateful_ops.yaml.h" #include "oneflow/core/functional/function_library.h" namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow """ ) pybind_header_fmt = ( license + """ namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow """ ) pybind_source_fmt = ( license + """ #include #undef _PyGC_FINALIZED #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/functional/function_def.h" #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/api/python/functional/python_arg_parser.h" #include "oneflow/api/python/functional/dispatch_stateful_ops.yaml.h" #include "oneflow/api/python/functional/dispatch_stateful_ops.yaml.pybind.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/extension/stack/python/stack_getter.h" namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one namespace functional = one::functional; 
ONEFLOW_API_PYBIND11_MODULE("_C", m) {{ static PyMethodDef functions[] = {{ {1} {{NULL, NULL, 0, NULL}} }}; PyObject* module = m.ptr(); if (module) {{ PyModule_AddFunctions(module, functions); }} }} }} // namespace oneflow """ ) yaml_file_path = os.path.join( args.project_source_dir, "oneflow/api/python/functional/dispatch_stateful_ops.yaml" ) generated_api_dir = "oneflow/api/python/functional" generated_pybind_dir = "oneflow/api/python/functional" if __name__ == "__main__": assert os.path.isfile(yaml_file_path), ( "It is not a regular file for the yaml file which is " + yaml_file_path ) g = Generator(yaml_file_path) assert os.path.isdir(generated_api_dir), ( "Could not locate the api generate directory which is " + generated_api_dir ) target_header_file = os.path.join(generated_api_dir, "dispatch_stateful_ops.yaml.h") g.generate_cpp_header_file(header_fmt, target_header_file) target_source_file = os.path.join( generated_api_dir, "dispatch_stateful_ops.yaml.cpp" ) g.generate_cpp_source_file(source_fmt, target_source_file) assert os.path.isdir(generated_pybind_dir), ( "Could not locate the pybind generate directory which is " + generated_pybind_dir ) target_pybind_header_file = os.path.join( generated_pybind_dir, "dispatch_stateful_ops.yaml.pybind.h" ) target_pybind_source_file = os.path.join( generated_pybind_dir, "dispatch_stateful_ops.yaml.pybind.cpp" ) g.generate_pybind_for_python( pybind_header_fmt, pybind_source_fmt, target_pybind_header_file, target_pybind_source_file, ) ================================================ FILE: tools/functional/generate_functional_api.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import re import argparse import yaml from generator import Generator parser = argparse.ArgumentParser() parser.add_argument( "--project_source_dir", type=str, help="The project source code directory.", ) parser.add_argument( "--export_pybind", action="store_true", default=False, help="Whether to export pybind related files.", ) args = parser.parse_args() license = """/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Generated from oneflow/core/functional/functional_api.yaml. 
DO NOT EDIT!""" header_fmt = ( license + """ #ifndef ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_ #define ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_ #include "oneflow/core/common/memory_format.pb.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/layout.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/tensor_index.h" namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow #endif // ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_""" ) source_fmt = ( license + """ #include "oneflow/core/functional/functional_api.yaml.h" #include "oneflow/core/functional/function_library.h" namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow """ ) pybind_header_fmt = ( license + """ #include #undef _PyGC_FINALIZED namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow """ ) pybind_source_fmt = ( license + """ #include #undef _PyGC_FINALIZED #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/function_def.h" #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/api/python/functional/python_arg_parser.h" #include "oneflow/api/python/functional/python_return_types.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/functional/functional.h" #include "oneflow/extension/stack/python/stack_getter.h" namespace {{ // This return type template code is referenced from: // https://github.com/pytorch/pytorch/blob/master/tools/autograd/gen_python_functions.py using oneflow::one::functional::returned_structseq_repr; {2} std::unordered_map& get_namedtuple_types_map() {{ static std::unordered_map namedtuple_types_map = {{ {3} }}; return namedtuple_types_map; }} PyTypeObject* get_namedtuple(const std::string& name) {{ static auto& namedtuple_types_map = get_namedtuple_types_map(); return namedtuple_types_map[name]; }} }} // namespace namespace oneflow {{ namespace one {{ namespace functional {{ PyObject* WrapTensorTuple(const TensorTuple& tensortuple, const std::string& name) {{ PyObjectPtr r(PyStructSequence_New(get_namedtuple(name))); if (!r) {{ throw py::error_already_set(); }} for (int i = 0; i < tensortuple.size(); ++i) {{ PyTuple_SET_ITEM(r.get(), i, CastToPyObject(tensortuple[i])); }} return r.release(); }} {0} }} // namespace functional }} // namespace one namespace functional = one::functional; ONEFLOW_API_PYBIND11_MODULE("_C", m) {{ static PyMethodDef functions[] = {{ {1} {{NULL, NULL, 0, NULL}} }}; PyObject* module = m.ptr(); if (module) {{ PyModule_AddFunctions(module, functions); }} }} }} // namespace oneflow """ ) yaml_file_path = os.path.join( args.project_source_dir, "oneflow/core/functional/functional_api.yaml" ) generated_api_dir = "oneflow/core/functional" generated_pybind_dir = "oneflow/api/python/functional" if __name__ == "__main__": assert os.path.isfile(yaml_file_path), ( "It is not a regular file for the yaml file which is " + yaml_file_path ) g = 
Generator(yaml_file_path) assert os.path.isdir(generated_api_dir), ( "Could not locate the api generate directory which is " + generated_api_dir ) target_header_file = os.path.join(generated_api_dir, "functional_api.yaml.h") g.generate_cpp_header_file(header_fmt, target_header_file) target_source_file = os.path.join(generated_api_dir, "functional_api.yaml.cpp") g.generate_cpp_source_file(source_fmt, target_source_file) if args.export_pybind: assert os.path.isdir(generated_pybind_dir), ( "Could not locate the pybind generate directory which is " + generated_pybind_dir ) target_pybind_header_file = os.path.join( generated_pybind_dir, "functional_api.yaml.pybind.h" ) target_pybind_source_file = os.path.join( generated_pybind_dir, "functional_api.yaml.pybind.cpp" ) g.generate_pybind_for_python( pybind_header_fmt, pybind_source_fmt, target_pybind_header_file, target_pybind_source_file, ) ================================================ FILE: tools/functional/generate_tensor_api.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import re import argparse import yaml from generator import Generator parser = argparse.ArgumentParser() parser.add_argument( "--project_source_dir", type=str, help="The project source code directory.", ) args = parser.parse_args() license = """/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Generated from oneflow/api/python/functional/tensor_api.yaml. 
DO NOT EDIT!""" header_fmt = ( license + """ #ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_TENSOR_API_H_ #define ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_TENSOR_API_H_ #include #undef _PyGC_FINALIZED #include "oneflow/core/common/optional.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/tensor_index.h" namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow #endif // ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_TENSOR_API_H_""" ) source_fmt = ( license + """ #include "oneflow/api/python/functional/tensor_api.yaml.h" #include "oneflow/core/functional/function_library.h" namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow """ ) pybind_header_fmt = ( license + """ #include #undef _PyGC_FINALIZED namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one }} // namespace oneflow """ ) pybind_source_fmt = ( license + """ #include #undef _PyGC_FINALIZED #include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/exception/exception.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/function_def.h" #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/api/python/functional/python_arg_parser.h" #include "oneflow/api/python/functional/tensor_api.yaml.h" #include "oneflow/api/python/functional/tensor_api.yaml.pybind.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" #include "oneflow/extension/stack/python/stack_getter.h" namespace oneflow {{ namespace one {{ namespace functional {{ {0} }} // namespace functional }} // namespace one namespace functional = one::functional; ONEFLOW_API_PYBIND11_MODULE("_C", m) {{ static PyMethodDef functions[] = {{ {1} {{NULL, NULL, 0, NULL}} }}; PyObject* module = m.ptr(); if (module) {{ PyModule_AddFunctions(module, functions); }} }} }} // namespace oneflow """ ) yaml_file_path = os.path.join( args.project_source_dir, "oneflow/api/python/functional/tensor_api.yaml" ) generated_api_dir = "oneflow/api/python/functional" generated_pybind_dir = "oneflow/api/python/functional" if __name__ == "__main__": assert os.path.isfile(yaml_file_path), ( "It is not a regular file for the yaml file which is " + yaml_file_path ) g = Generator(yaml_file_path) assert os.path.isdir(generated_api_dir), ( "Could not locate the api generate directory which is " + generated_api_dir ) target_header_file = os.path.join(generated_api_dir, "tensor_api.yaml.h") g.generate_cpp_header_file(header_fmt, target_header_file) target_source_file = os.path.join(generated_api_dir, "tensor_api.yaml.cpp") g.generate_cpp_source_file(source_fmt, target_source_file) assert os.path.isdir(generated_pybind_dir), ( "Could not locate the pybind generate directory which is " + generated_pybind_dir ) target_pybind_header_file = os.path.join( generated_pybind_dir, "tensor_api.yaml.pybind.h" ) target_pybind_source_file = os.path.join( generated_pybind_dir, "tensor_api.yaml.pybind.cpp" ) g.generate_pybind_for_python( pybind_header_fmt, pybind_source_fmt, target_pybind_header_file, target_pybind_source_file, ) 
================================================ FILE: tools/functional/generator.py ================================================ """ Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import os import re import argparse import yaml types_allowed = { "Void", "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool", "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList", "BoolList", "DataType", "Shape", "Generator", "TensorIndex", "Device", "Placement", "Sbp", "SbpList", "OpExpr", "PyObject*", "ShapeList", "DataTypeList", "Layout", "MemoryFormat", } mangled_name = { "Void": "V", "Tensor": "T", "TensorTuple": "Tt", "Scalar": "Sc", "Int": "I", "Int32": "I32", "Int64": "I64", "Float": "F", "Double": "D", "String": "S", "Bool": "B", "ScalarList": "Scl", "IntList": "Il", "Int32List": "I32l", "Int64List": "I64l", "FloatList": "Fl", "DoubleList": "Dl", "StringList": "Sl", "BoolList": "Bl", "DataType": "Dt", "Shape": "Sh", "Generator": "G", "TensorIndex": "Ti", "Device": "De", "Placement": "P", "Sbp": "Sbp", "SbpList": "Sbpl", "OpExpr": "Op", "PyObject*": "Pyo", "ShapeList": "Shl", "DataTypeList": "Dtl", "Layout": "Lo", "MemoryFormat": "Memf", } generic_type_aliases = { "Int": "int32_t", "Int32": "int32_t", "Int64": "int64_t", "Float": "float", "Double": "double", "Bool": "bool", } argument_type_aliases = { "Tensor": "const std::shared_ptr&", "TensorTuple": "const TensorTuple&", "Scalar": "const Scalar&", "ScalarList": "const std::vector&", "IntList": "const std::vector&", "Int32List": "const std::vector&", "Int64List": "const std::vector&", "FloatList": "const std::vector&", "DoubleList": "const std::vector&", "String": "const std::string&", "StringList": "const std::vector&", "BoolList": "const std::vector&", "DataType": "const Symbol&", "Shape": "const Shape&", "Generator": "const std::shared_ptr&", "TensorIndex": "const TensorIndex&", "Device": "const Symbol&", "Placement": "const Symbol&", "Sbp": "const Symbol&", "SbpList": "const std::vector>&", "OpExpr": "const std::shared_ptr&", "PyObject*": "PyObject*", "ShapeList": "const std::vector&", "DataTypeList": "const std::vector>&", "Layout": "const Symbol&", "MemoryFormat": "MemoryFormat", **generic_type_aliases, } optional_argument_type_aliases = { "Tensor": "const Optional&", "TensorTuple": "const Optional&", "Scalar": "const Optional&", "ScalarList": "const Optional>&", "IntList": "const Optional>&", "Int32List": "const Optional>&", "Int64List": "const Optional>&", "FloatList": "const Optional>&", "DoubleList": "const Optional>&", "String": "const Optional&", "StringList": "const Optional>&", "BoolList": "const Optional>&", "DataType": "const Optional>&", "Shape": "const Optional&", "Generator": "const Optional&", "TensorIndex": "const Optional&", "Device": "const Optional>&", "Placement": "const Optional>&", "Sbp": "const Optional>&", "SbpList": "const Optional>>&", "OpExpr": "const Optional&", "PyObject*": "const Optional&", "ShapeList": "const 
Optional>&", "DataTypeList": "const Optional>>&", "Layout": "const Optional>&", "MemoryFormat": "const Optional&", **{k: "const Optional<{0}>&".format(v) for k, v in generic_type_aliases.items()}, } return_type_aliases = { "Void": "Maybe", "Tensor": "Maybe", "TensorTuple": "Maybe", "String": "Maybe", "Shape": "Maybe", **{k: "Maybe<{0}>".format(v) for k, v in generic_type_aliases.items()}, } value_aliases = { "True": "true", "False": "false", "kInt": "DType::Int32()", "kInt8": "DType::Int8()", "kUInt8": "DType::UInt8()", "kInt32": "DType::Int32()", "kInt64": "DType::Int64()", "kFloat": "DType::Float()", "kDouble": "DType::Double()", "kBool": "DType::Bool()", "kStrided": "Layout::Strided()", } def _escape_quote(fmt): return re.sub(r"\"|\'", '\\"', fmt) def _normalize(fmt): fmt = fmt.strip() return re.sub(r"\s+", " ", fmt) def _remove_square_brackets_and_content_inside(fmt): # "TensorTuple[values], TensorTuple[indices]" -> "TensorTuple, TensorTuple" return re.sub(r"\[[^()]*?\]", "", fmt) def _std_decay(fmt): fmt = fmt.strip() fmt = re.sub(r"(const|&)", "", fmt) return _normalize(fmt) def parse_function_params(fmt): params = [] fmt = _normalize(fmt) open_paren = fmt.find("(") if open_paren == -1: raise ValueError('Missing "(" in function def: ' + fmt) header = _normalize(fmt[0:open_paren]) items = _normalize(_remove_square_brackets_and_content_inside(header)).split(" ") if (len(items)) != 1: raise ValueError( "Missing return type or more than 1 return type in function def: " + fmt ) params.append(header) close_paren = fmt.rfind(")") if close_paren == -1: raise ValueError('Missing ")" in Missingfunction def: ' + fmt) tail = fmt[open_paren + 1 : close_paren] # TODO(): Parse the parameter list more comprehensively. items = tail.split(",") for param in items: params.append(_normalize(param)) pos = fmt.rfind("=>") if pos == -1: raise ValueError('Missing "=>" in Missingfunction def: ' + fmt) function_name = _normalize(fmt[pos + 2 :]) return function_name, params def render_file_if_different(target_file, content): if not os.path.isfile(target_file): with open(target_file, "w") as f: f.write(content) else: old_content = None with open(target_file, "r") as f: old_content = f.read() if old_content is None or old_content != content: with open(target_file, "w") as f: f.write(content) def generate_return_types_named_tuple(return_names, func_name, block_name): param_names = ", ".join( [ '{{const_cast("{}"), const_cast("")}}'.format(x) for x in return_names ] ) code = f"""PyTypeObject* Get{func_name}NamedTuple() {{ static PyStructSequence_Field NamedTuple_fields[] = {{ {param_names}, {{nullptr}} }}; static PyTypeObject {func_name}NamedTuple; static bool is_initialized = false; static PyStructSequence_Desc desc = {{ const_cast("oneflow.return_types.{block_name}"), nullptr, NamedTuple_fields, {len(return_names)} }}; if (!is_initialized) {{ PyStructSequence_InitType(&{func_name}NamedTuple, &desc); {func_name}NamedTuple.tp_repr = (reprfunc)returned_structseq_repr; is_initialized = true; }} return &{func_name}NamedTuple; }} """ return code class Argument: def __init__(self, fmt, keyword_only=False): self._keyword_only = keyword_only self._type = None self._name = None self._default_value = None self._size = 0 fmt = _normalize(fmt) sp = fmt.rfind(" ") if sp == -1: raise ValueError("Missing argument type or name for argument def: " + fmt) type_name = fmt[0:sp] arg_name = fmt[sp + 1 :] sp = type_name.find("[") if sp != -1: self._type = _normalize(type_name[0:sp]) size = type_name[sp + 1 :] sp = size.find("]") assert 
sp != -1, "Missing ']' for argument def: " + fmt size = _normalize(size[0:sp]) assert size.isnumeric(), ( "list size is not an integer for argument def: " + fmt ) self._size = int(size) else: self._type = _normalize(type_name) assert self._type in types_allowed, "Unknow type: " + self._type self._optional = False self._name = _normalize(arg_name) sp = self._name.find("=") if sp != -1: self._default_value = _normalize(self._name[sp + 1 :]) if self._default_value == "None": self._optional = True self._default_cpp_value = "" elif self._type.endswith("List"): if self._default_value != "None": _value_list = [ self._default_value for i in range(self._size) ] # For int32List[2] = 2, _value_list will be ["2", "2"] self._default_cpp_value = ( "{" + ", ".join(_value_list) + "}" ) # ["2", "2"] -> "{2, 2}" elif self._default_value in value_aliases: self._default_cpp_value = value_aliases[self._default_value] else: self._default_cpp_value = self._default_value self._name = _normalize(self._name[0:sp]) if not self._optional and self._type in argument_type_aliases: self._cpp_type = argument_type_aliases[self._type] elif self._optional and self._type in optional_argument_type_aliases: self._cpp_type = optional_argument_type_aliases[self._type] else: self._cpp_type = self._type @property def has_default_value(self): return self._default_value is not None def to_string(self, to_cpp=False): fmt = "{0} {1}".format(self._cpp_type if to_cpp else self._type, self._name) if not to_cpp and self.has_default_value: fmt += "={0}".format(self._default_value) return fmt class Return: def __init__(self, fmt): self._type, self._return_names = self.check_named_tuple(_normalize(fmt)) assert self._type in types_allowed, "Unknow type: " + self._type if self._type in return_type_aliases: self._cpp_type = return_type_aliases[self._type] else: self._cpp_type = self._type @property def type(self): return self._type def to_string(self, to_cpp=False): return self._cpp_type if to_cpp else self._type def check_named_tuple(self, fmt): matches = re.match(r"(.*?)\s*\[(.*?)\]", fmt) if matches is None: type, return_names = _normalize(fmt), None else: type = matches.group(1) return_names = [_normalize(x) for x in matches.group(2).split(",")] return type, return_names class FunctionSignature: def __init__(self, fmt): self._fmt = fmt self._name, self._params = parse_function_params(fmt) self._ret = Return(self._params[0]) keyword_only = False self._args = [] self._max_positional_args_count = 0 for arg in self._params[1:]: if arg == "*": keyword_only = True continue self._args.append(Argument(arg, keyword_only=keyword_only)) if not keyword_only: self._max_positional_args_count += 1 self._max_args_count = len(self._args) count = 0 for arg in self._args: if arg._keyword_only: count += 1 self._max_keyword_args_count = count @property def num_of_args(self): return len(self._args) def to_string(self, to_cpp=False, drop_name=False): if drop_name: fmt = "{0} (".format(self._ret.to_string(to_cpp=to_cpp)) else: fmt = "{0} {1}(".format(self._ret.to_string(to_cpp=to_cpp), self._name) keyword_start = False for i, arg in enumerate(self._args): if i > 0 and i < len(self._args): fmt += ", " if not keyword_start and arg._keyword_only: keyword_start = True if not to_cpp: fmt += "*, " fmt += arg.to_string(to_cpp=to_cpp) fmt += ")" return fmt def get_mangled_type(self): fmt = mangled_name[self._ret._type] for _, arg in enumerate(self._args): fmt += mangled_name[arg._type] return fmt def get_schema_name(self): return "{0}Schema_{1}".format(self._name, 
self.get_mangled_type()) class Block: def __init__(self, name, signature, bind_python): self._name = name self._signature = signature self._bind_python = bind_python class Generator: def __init__(self, input_file): self._blocks = {} with open(input_file) as f: doc = yaml.load(f, Loader=yaml.FullLoader) for block in doc: assert "name" in block assert "signature" in block name = block["name"] signature = block["signature"] bind_python = False if "bind_python" in block: bind_python = block["bind_python"] self._blocks[name] = list() if isinstance(signature, list): for s in signature: self._blocks[name].append( Block(name, FunctionSignature(s), bind_python) ) else: self._blocks[name].append( Block(name, FunctionSignature(signature), bind_python) ) def generate_cpp_header_file(self, header_fmt, target_header_file): fmt = "" for name, blocks in self._blocks.items(): for block in blocks: fmt += "\n" fmt += block._signature.to_string(to_cpp=True) fmt += ";\n" render_file_if_different(target_header_file, header_fmt.format(fmt)) def generate_cpp_source_file(self, source_fmt, target_source_file): fmt = "" for name, blocks in self._blocks.items(): for block in blocks: signature = block._signature fmt += "\n" fmt += signature.to_string(to_cpp=True) fmt += " {\n" fmt += ' static thread_local const auto& __op = CHECK_JUST(FunctionLibrary::Global()->find<{0}, {1}>("{2}"));\n'.format( signature._ret._cpp_type, ", ".join([arg._cpp_type for arg in signature._args]), signature._name, ) fmt += " return __op->call({0});\n".format( ", ".join([arg._name for arg in signature._args]), ) fmt += "}\n" render_file_if_different(target_source_file, source_fmt.format(fmt)) def generate_pybind_for_python( self, pybind_header_fmt, pybind_source_fmt, target_pybind_header_file, target_pybind_source_file, ): schema_fmt = "" module_fmt = "" header_fmt = "" return_type_fmt = "" map_pairs = [] for name, blocks in self._blocks.items(): schema_types = [] max_args_count = 0 for block in blocks: if not block._bind_python: continue signature = block._signature max_args_count = max(max_args_count, signature._max_args_count) schema_types.append( "functional::{0}".format(signature.get_schema_name()) ) return_type = signature._ret._cpp_type schema_fmt += "\n" schema_fmt += "struct {0} {{\n".format(signature.get_schema_name()) schema_fmt += " using FType = {0};\n".format( signature.to_string(to_cpp=True, drop_name=True) ) schema_fmt += " using R = {0};\n".format(return_type) schema_fmt += "\n" schema_fmt += " static constexpr FType* func = &functional::{0};\n".format( signature._name ) schema_fmt += " static constexpr size_t max_args = {0};\n".format( signature._max_args_count ) schema_fmt += " static constexpr size_t max_pos_args = {0};\n".format( signature._max_positional_args_count ) schema_fmt += ' static constexpr char const* signature = "{0}";\n'.format( _escape_quote(signature.to_string(drop_name=True)) ) schema_fmt += " static FunctionDef function_def;\n" schema_fmt += "};\n" schema_fmt += "\n" schema_fmt += "constexpr size_t {0}::max_args;\n".format( signature.get_schema_name() ) schema_fmt += "constexpr size_t {0}::max_pos_args;\n".format( signature.get_schema_name() ) schema_fmt += "constexpr char const* {0}::signature;\n".format( signature.get_schema_name() ) return_def = "ReturnDef(ValueTypeOf<{0}>())".format(return_type) argument_def = [] for arg in signature._args: keyword_only = "true" if arg._keyword_only else "false" optional = "true" if arg._optional else "false" if arg.has_default_value: argument_def.append( ' 
ArgumentDef(/*name*/"{0}", /*default_value*/{1}({2}), /*size*/{3}, /*keyword_only*/{4}, /*optional*/{5})'.format( arg._name, _std_decay(arg._cpp_type), arg._default_cpp_value, arg._size, keyword_only, optional, ) ) else: argument_def.append( ' ArgumentDef(/*name*/"{0}", /*value_type*/ValueTypeOf<{1}>(), /*size*/{2}, /*keyword_only*/{3}, /*optional*/{4})'.format( arg._name, _std_decay(arg._cpp_type), arg._size, keyword_only, optional, ) ) schema_fmt += 'FunctionDef {0}::function_def = {{\n/*name*/"{1}",\n/*return_def*/{2},\n/*argument_def*/{{\n{3}\n}}\n}};\n'.format( signature.get_schema_name(), name, return_def, ",\n".join(argument_def), ) if len(schema_types) > 0: module_fmt += ' {{"{0}", (PyCFunction)functional::{1}, METH_VARARGS | METH_KEYWORDS, NULL}},\n'.format( name, name ) header_fmt += "\n" header_fmt += "PyObject* {0}(PyObject* self, PyObject* args, PyObject* kwargs);\n".format( name ) schema_fmt += "\n" schema_fmt += "PyObject* {0}(PyObject* self, PyObject* args, PyObject* kwargs) {{\n".format( name ) schema_fmt += " HANDLE_ERRORS\n" schema_fmt += ' OF_PROFILER_RANGE_GUARD("{0}");\n'.format(name) schema_fmt += " PythonFrameGuard pf;\n" schema_fmt += ' static PythonArgParser<{0}> parser("{1}");\n'.format( ", ".join(schema_types), name ) schema_fmt += " ParsedArgs<{0}> r;\n".format(max_args_count) schema_fmt += " int idx = parser.Parse(args, kwargs, &r);\n" i = 0 for block in blocks: signature = block._signature schema_fmt += " if (idx == {0}) {{\n".format(i) params = [] for j in range(len(signature._args)): cpp_type = _std_decay(signature._args[j]._cpp_type) params.append("r[{0}].As<{1}>()".format(j, cpp_type)) if signature._ret._return_names is None: schema_fmt += " return CastToPyObject(functional::{0}({1}));\n".format( signature._name, ", ".join(params) ) else: schema_fmt += ' return WrapTensorTuple(functional::{0}({1}).GetOrThrow(), "{2}");\n'.format( signature._name, ", ".join(params), signature._name, ) return_type_fmt += generate_return_types_named_tuple( signature._ret._return_names, signature._name, block._name, ) map_pairs.append( f' {{"{signature._name}", Get{signature._name}NamedTuple()}},' ) schema_fmt += " }\n" i += 1 schema_fmt += " Py_RETURN_NONE;\n" schema_fmt += " END_HANDLE_ERRORS\n" schema_fmt += "}\n" render_file_if_different( target_pybind_header_file, pybind_header_fmt.format(header_fmt) ) render_file_if_different( target_pybind_source_file, pybind_source_fmt.format( schema_fmt, module_fmt, return_type_fmt, "\n".join(map_pairs) ), ) ================================================ FILE: tools/generate_header_list.py ================================================ import glob import argparse import os parser = argparse.ArgumentParser() parser.add_argument("-i", "--src_path", type=str, required=True) parser.add_argument("-o", "--dst_file", type=str, required=True) args = parser.parse_args() def glob_by_pattern(pattern): result = [] for x in glob.glob(os.path.join(args.src_path, pattern), recursive=True): result.append(os.path.relpath(x, args.src_path)) return result headers = ( glob_by_pattern("**/*.h") + glob_by_pattern("**/*.hpp") + glob_by_pattern("**/*.cuh") + glob_by_pattern("**/*.proto") + glob_by_pattern("**/*.inc") ) with open(args.dst_file, "w") as f: for item in headers: f.write("{}\n".format(item)) ================================================ FILE: tools/generate_pip_version.py ================================================ import os import subprocess import argparse from datetime import date parser = argparse.ArgumentParser() 
parser.add_argument("--cuda", type=str, required=False) parser.add_argument("--cmake_project_binary_dir", type=str, required=False) parser.add_argument("--src", type=str, required=False) parser.add_argument("--out", type=str, required=False) args = parser.parse_args() local_label = "" version = f"1.0.0" # set version if release of nightly assert ( os.getenv("ONEFLOW_RELEASE_VERSION") != "" ), "ONEFLOW_RELEASE_VERSION should be either None or a valid string" is_release = False is_nightly = False date_str = os.getenv("ONEFLOW_NIGHTLY_DATE") if os.getenv("ONEFLOW_RELEASE_VERSION"): release_version = os.getenv("ONEFLOW_RELEASE_VERSION") version = f"{release_version}" is_release = True elif date_str: version += f".dev{date_str}" is_nightly = True # append compute_platform compute_platform = "" if args.cuda: # TODO: use a proper semver lib to handle versions splits = args.cuda.split(".")[0:2] assert len(splits) == 2 compute_platform = "".join(splits) compute_platform = "cu" + compute_platform else: compute_platform = "cpu" assert compute_platform version += f"+{compute_platform}" try: git_hash = ( subprocess.check_output("git rev-parse --short HEAD", shell=True, cwd=args.src) .decode() .strip() ) except: git_hash = "unknown" # append git if not release if not os.getenv("ONEFLOW_RELEASE_VERSION") and not os.getenv("ONEFLOW_NIGHTLY_DATE"): version += f".git.{git_hash}" print(f"-- Generating pip version: {version}, writing to: {args.out}") assert args.out with open(args.out, "w+") as f: f.write(f'__version__ = "{version}"\n') f.write(f'__git_commit__ = "{git_hash}"\n') if not (is_nightly or is_release): f.write(f'__cmake_project_binary_dir__ = "{args.cmake_project_binary_dir}"\n') ================================================ FILE: tools/oneflow-tblgen/CMakeLists.txt ================================================ set(LLVM_LINK_COMPONENTS Support) include(FetchContent) set(JSON_Install ON CACHE STRING "" FORCE) FetchContent_Declare(json URL ${JSON_URL} URL_HASH MD5=${JSON_URL_HASH}) set(INJA_USE_EMBEDDED_JSON OFF CACHE STRING "" FORCE) set(INJA_BUILD_TESTS OFF CACHE STRING "" FORCE) set(BUILD_BENCHMARK OFF CACHE STRING "" FORCE) FetchContent_Declare(inja URL ${INJA_URL} URL_HASH MD5=${INJA_URL_HASH}) FetchContent_MakeAvailable(json inja) add_tablegen(oneflow_tblgen llvm tablegen.cpp op_schema_emitter.cpp) if(LLVM_ENABLE_OBJLIB) set(OF_TBLGEN_TARGET obj.oneflow_tblgen) else() set(OF_TBLGEN_TARGET oneflow_tblgen) endif() target_link_libraries(${OF_TBLGEN_TARGET} PRIVATE nlohmann_json::nlohmann_json pantor::inja) install(TARGETS oneflow_tblgen LLVMTableGen LLVMDemangle LLVMSupport COMPONENT OneFlowTableGen LIBRARY DESTINATION lib) add_custom_target( install-oneflow-tblgen DEPENDS oneflow_tblgen COMMAND "${CMAKE_COMMAND}" -DCMAKE_INSTALL_COMPONENT=OneFlowTableGen -P "${CMAKE_BINARY_DIR}/cmake_install.cmake") ================================================ FILE: tools/oneflow-tblgen/backends.h ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #ifndef ONEFLOW_TBLGEN_BACKENDS_H #define ONEFLOW_TBLGEN_BACKENDS_H namespace llvm { class raw_ostream; class RecordKeeper; } // namespace llvm namespace oneflow { namespace tblgen { using llvm::raw_ostream; using llvm::RecordKeeper; void EmitOpSchemaHeader(RecordKeeper& RK, raw_ostream& OS); void EmitOpSchemaSource(RecordKeeper& RK, raw_ostream& OS); } // namespace tblgen } // namespace oneflow #endif // ONEFLOW_TBLGEN_BACKENDS_H ================================================ FILE: tools/oneflow-tblgen/example/constant.td ================================================ include "mlir/Interfaces/SideEffectInterfaces.td" include "OneFlowEnums.td" include "OneFlowBase.td" def OneFlow_ConstantOp : OneFlow_BaseOp<"constant", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let output = (outs AnyType:$out ); let attrs = (ins DefaultValuedAttr:$floating_value, DefaultValuedAttr:$integer_value, DefaultValuedAttr:$is_floating_value, StrAttr:$dtype, AnyI64ElementsAttr:$shape, StrArrayAttr:$nd_sbp ); } ================================================ FILE: tools/oneflow-tblgen/op_schema_emitter.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Format.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #include "inja/inja.hpp" #include #include using namespace llvm; using inja::json; namespace oneflow { namespace tblgen { cl::OptionCategory opSchemaCat("Options for -gen-op-schema"); cl::opt sourceIncludeFilename{ "op-include", cl::desc("header filename to include in source file"), cl::value_desc("include filename"), cl::init(""), cl::cat(opSchemaCat)}; cl::opt dumpJson{"op-dump-json", cl::desc("dump tablegen code to json in provided file"), cl::value_desc("filename"), cl::init(""), cl::cat(opSchemaCat)}; enum class FileTarget { kHeader = 1, kSource, }; template class OpSchemaEmitter { public: explicit OpSchemaEmitter(RecordKeeper& RK); void run(raw_ostream& os); void emitInputAndOutput(const Record* def, json* op) const; void emitAttrs(const Record* def, json* op) const; void emitInt(const Record* def, StringRef fieldname, json* op) const; void emitBit(const Record* def, StringRef fieldname, json* op) const; void emitTrait(const Record* def, StringRef fieldname, StringRef traitname, json* op) const; private: static std::string emitType(const std::string& ods_type) { #define OP_SCHEMA(ods, cpp) \ if (ods_type == #ods) return #cpp; #include "op_schema_types.inc" #undef OP_SCHEMA PrintFatalError("undefined attribute type: " + ods_type); } private: RecordKeeper& records; StringRef op_type_name; StringRef op_name; inja::Environment env; inja::Template temp; static const std::string code; }; template OpSchemaEmitter::OpSchemaEmitter(RecordKeeper& RK) : records(RK) { env.add_callback("quoted", 1, [](inja::Arguments& args) { auto str = args.at(0)->get(); std::ostringstream os; os << std::quoted(str); return os.str(); }); env.add_callback("to_header", 1, [](inja::Arguments& args) { auto str = args.at(0)->get(); auto dot_pos = str.find_last_of('.'); if (dot_pos != std::string::npos) { str.replace(dot_pos, str.size() - dot_pos, ".h"); } // assume that the source and header file is in the same directory auto slash_pos = str.find_last_of('/'); if (slash_pos != std::string::npos) { str.replace(0, slash_pos + 1, ""); } return str; }); temp = env.parse(code); } template void OpSchemaEmitter::run(raw_ostream& os) { emitSourceFileHeader("oneflow op schema", os); json ops = json::object(); for (const auto& def : records.getAllDerivedDefinitions("OneFlow_BaseOp")) { op_type_name = def->getValueAsString("opName"); if (op_type_name.empty()) { PrintFatalError(def, "`opName` of op definitions cannot be omitted"); } op_name = def->getName(); if (!op_name.consume_front("OneFlow_")) { PrintFatalError(def, "op name is not start with `OneFlow_`: " + op_name.str()); } json op{{"name", op_type_name}, {"input", json::array()}, {"output", json::array()}, {"attrs", json::array()}}; emitInputAndOutput(def, &op); emitAttrs(def, &op); emitInt(def, "same_output_regst_num", &op); emitTrait(def, "no_grad", "NoGrad", &op); emitTrait(def, "support_non_contiguous", "SupportNonContiguous", &op); emitTrait(def, "cpu_only", "CpuOnly", &op); emitBit(def, "has_nd_sbp_infer_fn", &op); emitBit(def, "has_get_sbp_fn", &op); emitBit(def, "has_logical_tensor_desc_infer_fn", &op); emitBit(def, "has_physical_tensor_desc_infer_fn", &op); emitBit(def, 
"has_data_type_infer_fn", &op); emitBit(def, "has_device_and_stream_infer_fn", &op); emitBit(def, "has_input_arg_modify_fn", &op); emitBit(def, "has_output_arg_modify_fn", &op); emitBit(def, "has_output_blob_time_shape_infer_fn", &op); emitBit(def, "has_sbp_signature_infer_fn", &op); emitBit(def, "has_get_nd_sbp_fn", &op); emitBit(def, "has_enumerate_nd_sbp_signatures_fn", &op); emitBit(def, "has_dump_nd_sbp_signature_for_op_conf_fn", &op); emitBit(def, "has_compute_complexity_fn", &op); emitBit(def, "has_check_fn", &op); ops[op_name.str()] = op; } auto* option = static_cast*>(cl::getRegisteredOptions().lookup("o")); auto filename = option->getValue(); filename = filename != "-" ? filename : ""; json data{{"filename", filename}, {"ops", ops}}; if (Target == FileTarget::kSource) { data["include"] = sourceIncludeFilename.getValue(); } if (!dumpJson.empty()) { std::ofstream file(dumpJson); file << data.dump(); } os << env.render(temp, data); } template void OpSchemaEmitter::emitInputAndOutput(const Record* def, json* op) const { const auto* input = def->getValueAsDag("input"); for (size_t i = 0; i < input->getNumArgs(); ++i) { const auto* A = dyn_cast(input->getArg(i))->getDef(); bool is_optional = A->isSubClassOf("Optional"); auto NS = input->getArgName(i)->getAsUnquotedString(); (*op)["input"].push_back({{"name", NS}, {"is_optional", is_optional}, {"size", 1}}); } const auto* output = def->getValueAsDag("output"); for (size_t i = 0; i < output->getNumArgs(); ++i) { const auto* A = dyn_cast(output->getArg(i))->getDef(); bool is_optional = A->isSubClassOf("Optional"); auto NS = output->getArgName(i)->getAsUnquotedString(); (*op)["output"].push_back({{"name", NS}, {"is_optional", is_optional}, {"size", 1}}); } } template void OpSchemaEmitter::emitAttrs(const Record* def, json* op) const { const auto* attrs = def->getValueAsDag("attrs"); for (size_t i = 0; i < attrs->getNumArgs(); ++i) { const auto* A = dyn_cast(attrs->getArg(i))->getDef(); std::string AS; if (!A->isAnonymous()) { AS = A->getNameInitAsString(); } else { AS = A->getValueAsDef("baseAttr")->getNameInitAsString(); } auto NS = attrs->getArgName(i)->getAsUnquotedString(); // FlatSymbolRefAttr:$callee, if ("callee" == NS && "FlatSymbolRefAttr" == AS) { continue; } json attr{{"name", NS}, {"type", emitType(AS)}}; if (auto DV = A->getValueAsOptionalString("defaultValue")) { attr["default"] = DV.value(); } (*op)["attrs"].push_back(attr); } } template void OpSchemaEmitter::emitBit(const Record* def, StringRef fieldname, json* op) const { (*op)[fieldname.str()] = def->getValueAsBit(fieldname); } template void OpSchemaEmitter::emitTrait(const Record* def, StringRef fieldname, StringRef traitname, json* op) const { bool hasTrait = false; for (auto elem : *def->getValueAsListInit("traits")) { if (elem->getAsString() == traitname) { hasTrait = true; break; } } (*op)[fieldname.str()] = hasTrait; } template void OpSchemaEmitter::emitInt(const Record* def, StringRef fieldname, json* op) const { (*op)[fieldname.str()] = def->getValueAsInt(fieldname); } template<> const std::string OpSchemaEmitter::code{ #include "op_schema_header.inc" }; template<> const std::string OpSchemaEmitter::code{ #include "op_schema_source.inc" }; void EmitOpSchemaHeader(RecordKeeper& RK, raw_ostream& os) { OpSchemaEmitter(RK).run(os); } void EmitOpSchemaSource(RecordKeeper& RK, raw_ostream& os) { OpSchemaEmitter(RK).run(os); } } // namespace tblgen } // namespace oneflow ================================================ FILE: tools/oneflow-tblgen/op_schema_header.inc 
================================================ R"OP_SCHEMA_INC( #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/framework/op_definition.h" #include #include #include #include class OperatorConf; class NdSbpSignature; namespace oneflow { class Device; class Stream; class InputBlobModifier; class OutputBlobModifier; namespace user_op { class UserOpDefWrapper; class UserOpConfWrapper; class InferContext; class SbpContext; class InferSbpSignatureFnContext; class InferOutputBlobTimeShapeFnContext; class InferNdSbpFnContext; class DeviceAndStreamInferContext; class ComputeComplexityFnContext; class GetNdSbpSignatureListContext; } // namespace user_op using GetInputArgModifier = std::function; using GetOutputArgModifier = std::function; {% for opname, op in ops %} class {{opname}} : public OpDefinition<{{opname}}> { public: virtual ~{{opname}}() = default; {% if op.has_nd_sbp_infer_fn -%} static Maybe InferNdSbp(user_op::InferNdSbpFnContext* ctx); {% endif -%} {% if op.has_get_sbp_fn -%} static Maybe GetSbp(user_op::SbpContext* ctx); {% endif -%} {% if op.has_get_nd_sbp_fn -%} static Maybe GetNdSbpSignatureList(user_op::GetNdSbpSignatureListContext* ctx); {% endif -%} {% if op.has_enumerate_nd_sbp_signatures_fn -%} static Maybe EnumerateNdSbpSignatures(user_op::GetNdSbpSignatureListContext* ctx); {% endif -%} {% if op.has_dump_nd_sbp_signature_for_op_conf_fn -%} static Maybe DumpNdSbpSignatureForOpConfFn(const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf); {% endif -%} {% if op.has_logical_tensor_desc_infer_fn -%} static Maybe InferLogicalTensorDesc(user_op::InferContext* ctx); {% endif -%} {% if op.has_physical_tensor_desc_infer_fn -%} static Maybe InferPhysicalTensorDesc(user_op::InferContext* ctx); {% endif -%} {% if op.has_data_type_infer_fn -%} static Maybe InferDataType(user_op::InferContext* ctx); {% endif -%} {% if op.has_device_and_stream_infer_fn -%} static Maybe> InferDeviceAndStream(user_op::DeviceAndStreamInferContext* ctx); {% endif -%} {% if op.has_sbp_signature_infer_fn -%} static Maybe InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx); {% endif -%} {% if op.has_compute_complexity_fn -%} static Maybe GetComputeComplexity(user_op::ComputeComplexityFnContext* ctx); {% endif -%} {% if op.has_input_arg_modify_fn -%} static Maybe ModifyInputArg(const GetInputArgModifier&, const user_op::UserOpConfWrapper&); {% endif -%} {% if op.has_output_arg_modify_fn -%} static Maybe ModifyOutputArg(const GetOutputArgModifier&, const user_op::UserOpConfWrapper&); {% endif -%} {% if op.has_output_blob_time_shape_infer_fn -%} static Maybe InferOutputBlobTimeShape(user_op::InferOutputBlobTimeShapeFnContext* ctx); {% endif -%} {% if op.has_check_fn -%} static Maybe CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper&); {% endif -%} {% for attr in op.attrs -%} virtual const {{attr.type}}& {{attr.name}}() const = 0; virtual {{attr.type}}* mutable_{{attr.name}}() = 0; virtual void set_{{attr.name}}(const {{attr.type}}& {{attr.name}}) = 0; {% endfor -%} static const HashSet& AttrNames(); }; namespace schema { class {{opname}} : public oneflow::{{opname}} { public: {% for attr in op.attrs -%} const {{attr.type}}& {{attr.name}}() const override { return {{attr.name}}_; } {{attr.type}}* mutable_{{attr.name}}() override { return &{{attr.name}}_; } void set_{{attr.name}}(const 
{{attr.type}}& {{attr.name}}) override { {{attr.name}}_ = {{attr.name}}; } {% endfor -%} Maybe Attr(const std::string& attr_name) const override; private: {% for attr in op.attrs -%} {{attr.type}} {{attr.name}}_{% if existsIn(attr, "default") %} = {{attr.default}}{% endif %}; {% endfor %} }; } // namespace schema {% endfor %} } // namespace oneflow )OP_SCHEMA_INC" ================================================ FILE: tools/oneflow-tblgen/op_schema_source.inc ================================================ R"OP_SCHEMA_INC( {% if include != "" %}#include "{{ include }}" {% else if filename != "" %}#include "{{ to_header(filename) }}" {% endif %} #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/infer_nd_sbp_fn_context.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include namespace oneflow { #define REGISTER_OP_SCHEMA(op_type, schema) \ REGISTER_CLASS_CREATOR(std::string, op_type, OpDefinitionBase, ([]() { return new schema; })) {% for opname, op in ops %} /*static*/ const HashSet& {{opname}}::AttrNames() { static const HashSet attr_names = { {%- for attr in op.attrs -%}"{{attr.name}}", {%- endfor -%} }; return attr_names; } namespace schema { Maybe {{opname}}::Attr(const std::string& attr_name) const { {% for attr in op.attrs %}if(attr_name == "{{attr.name}}") { return CastAttrValue(&{{attr.name}}_); } {% endfor -%} return Error::RuntimeError() << "{{op.name}} op has no attribute named " << attr_name; } } // namespace schema REGISTER_OP_SCHEMA("user.{{op.name}}", schema::{{opname}}); REGISTER_USER_OP("{{op.name}}") {%- if op.input -%} {%- for input in op.input -%} {%- if input.is_optional -%} .OptionalInput("{{input.name}}") {%- else -%} .Input("{{input.name}}") {%- endif -%} {%- endfor -%} {%- endif -%} {%- if op.output -%} {%- for output in op.output -%} {%- if output.is_optional -%} .OptionalOutput("{{output.name}}") {%- else -%} .Output("{{output.name}}") {%- endif -%} {%- endfor -%} {%- endif -%} {%- for attr in op.attrs -%} {%- if existsIn(attr, "default") -%} .Attr<{{attr.type}}>("{{attr.name}}", {{attr.default}}) {%- else -%} .Attr<{{attr.type}}>("{{attr.name}}") {%- endif -%} {%- endfor -%} {%- if op.cpu_only -%} .SupportCpuOnly() {%- endif -%} {%- if op.no_grad -%} .NoGrad() {%- endif -%} {%- if op.support_non_contiguous -%} .SupportNonContiguous() {%- endif -%} {%- if op.same_output_regst_num != -1 -%} .SetOutputBufferNum({{op.same_output_regst_num}}) {%- endif -%} {%- if op.has_nd_sbp_infer_fn -%} .SetNdSbpInferFn(&{{opname}}::InferNdSbp) {%- endif -%} {%- if op.has_get_sbp_fn -%} .SetGetSbpFn(&{{opname}}::GetSbp) {%- endif -%} {%- if op.has_get_nd_sbp_fn -%} .SetGetNdSbpSignatureListFn(&{{opname}}::GetNdSbpSignatureList) {%- endif -%} {%- if op.has_enumerate_nd_sbp_signatures_fn -%} .SetEnumerateNdSbpSignaturesFn(&{{opname}}::EnumerateNdSbpSignatures) {%- endif -%} {%- if op.has_dump_nd_sbp_signature_for_op_conf_fn -%} .SetDumpNdSbpSignatureForOpConfFn(&{{opname}}::DumpNdSbpSignatureForOpConfFn) {%- endif -%} {%- if op.has_compute_complexity_fn -%} .SetComputeComplexityFn(&{{opname}}::GetComputeComplexity) {%- endif -%} {%- if op.has_logical_tensor_desc_infer_fn -%} .SetLogicalTensorDescInferFn(&{{opname}}::InferLogicalTensorDesc) {%- endif -%} {%- if op.has_physical_tensor_desc_infer_fn -%} .SetPhysicalTensorDescInferFn(&{{opname}}::InferPhysicalTensorDesc) {%- endif -%} {%- if op.has_data_type_infer_fn -%} 
.SetDataTypeInferFn(&{{opname}}::InferDataType) {%- endif -%} {%- if op.has_device_and_stream_infer_fn -%} .SetDeviceAndStreamInferFn(&{{opname}}::InferDeviceAndStream) {%- endif -%} {%- if op.has_sbp_signature_infer_fn -%} .SetSbpSignatureInferFn(&{{opname}}::InferSbpSignature) {% endif -%} {%- if op.has_input_arg_modify_fn -%} .SetInputArgModifyFn(&{{opname}}::ModifyInputArg) {%- endif -%} {%- if op.has_output_arg_modify_fn -%} .SetOutputArgModifyFn(&{{opname}}::ModifyOutputArg) {%- endif -%} {%- if op.has_output_blob_time_shape_infer_fn -%} .SetOutputBlobTimeShapeInferFn(&{{opname}}::InferOutputBlobTimeShape) {%- endif -%} {%- if op.has_check_fn -%} .SetCheckAttrFn(&{{opname}}::CheckAttr) {%- endif -%} ; {%- endfor %} } // namespace oneflow )OP_SCHEMA_INC" ================================================ FILE: tools/oneflow-tblgen/op_schema_types.inc ================================================ OP_SCHEMA(SI32Attr, int32_t) OP_SCHEMA(SI64Attr, int64_t) OP_SCHEMA(BoolAttr, bool) OP_SCHEMA(F32Attr, float) OP_SCHEMA(F64Attr, double) OP_SCHEMA(StrAttr, std::string) OP_SCHEMA(ShapeAttr, Shape) OP_SCHEMA(OneFlow_DataType, DataType) OP_SCHEMA(OneFlow_MemoryFormat, MemoryFormat) OP_SCHEMA(SI32ArrayAttr, std::vector) OP_SCHEMA(SI64ArrayAttr, std::vector) OP_SCHEMA(F32ArrayAttr, std::vector) OP_SCHEMA(DTArrayAttr, std::vector) OP_SCHEMA(ShapeArrayAttr, std::vector) OP_SCHEMA(StrArrayAttr, std::vector) OP_SCHEMA(ComplexDoubleAttr, std::complex) OP_SCHEMA(BytesAttr, std::vector) ================================================ FILE: tools/oneflow-tblgen/tablegen.cpp ================================================ /* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/SetTheory.h" #include "backends.h" using namespace llvm; using namespace oneflow::tblgen; enum ActionType { PrintRecords, PrintDetailedRecords, NullBackend, DumpJSON, PrintEnums, PrintSets, GenOpSchemaHeader, GenOpSchemaSource, }; namespace llvm { cl::opt EmitLongStrLiterals( "long-string-literals", cl::desc("when emitting large string tables, prefer string literals over " "comma-separated char literals. 
This can be a readability and "
             "compile-time performance win, but upsets some compilers"),
    cl::Hidden, cl::init(true));
}  // end namespace llvm

namespace {
cl::opt<ActionType> Action(
    cl::desc("Action to perform:"),
    cl::values(clEnumValN(PrintRecords, "print-records", "Print all records to stdout (default)"),
               clEnumValN(PrintDetailedRecords, "print-detailed-records",
                          "Print full details of all records to stdout"),
               clEnumValN(NullBackend, "null-backend",
                          "Do nothing after parsing (useful for timing)"),
               clEnumValN(DumpJSON, "dump-json", "Dump all records as machine-readable JSON"),
               clEnumValN(PrintEnums, "print-enums", "Print enum values for a class"),
               clEnumValN(PrintSets, "print-sets", "Print expanded sets for testing DAG exprs"),
               clEnumValN(GenOpSchemaHeader, "gen-op-schema-h",
                          "Generate oneflow op schema header code (.h)"),
               clEnumValN(GenOpSchemaSource, "gen-op-schema-cpp",
                          "Generate oneflow op schema source code (.cpp)")));

cl::OptionCategory PrintEnumsCat("Options for -print-enums");
cl::opt<std::string> Class("class", cl::desc("Print Enum list for this class"),
                           cl::value_desc("class name"), cl::cat(PrintEnumsCat));

bool LLVMTableGenMain(raw_ostream& OS, RecordKeeper& Records) {
  switch (Action) {
    case PrintRecords: OS << Records; break;
    case PrintDetailedRecords: EmitDetailedRecords(Records, OS); break;
    case NullBackend: break;
    case DumpJSON: EmitJSON(Records, OS); break;
    case PrintEnums: {
      for (Record* Rec : Records.getAllDerivedDefinitions(Class)) OS << Rec->getName() << ", ";
      OS << "\n";
      break;
    }
    case PrintSets: {
      SetTheory Sets;
      Sets.addFieldExpander("Set", "Elements");
      for (Record* Rec : Records.getAllDerivedDefinitions("Set")) {
        OS << Rec->getName() << " = [";
        const std::vector<Record*>* Elts = Sets.expand(Rec);
        assert(Elts && "Couldn't expand Set instance");
        for (Record* Elt : *Elts) OS << ' ' << Elt->getName();
        OS << " ]\n";
      }
      break;
    }
    case GenOpSchemaHeader: EmitOpSchemaHeader(Records, OS); break;
    case GenOpSchemaSource: EmitOpSchemaSource(Records, OS); break;
  }
  return false;
}
}  // namespace

int main(int argc, char** argv) {
  InitLLVM X(argc, argv);
  cl::ParseCommandLineOptions(argc, argv);
  return TableGenMain(argv[0], &LLVMTableGenMain);
}


================================================
FILE: tools/oss_file_exist.py
================================================
import os
import oss2


def check_existence(endpoint, bucket, path):
    ki = os.getenv("OSS_ACCESS_KEY_ID")
    ks = os.getenv("OSS_ACCESS_KEY_SECRET")
    auth = oss2.Auth(ki, ks)
    bucket_obj = oss2.Bucket(auth, endpoint, bucket)
    files = bucket_obj.list_objects(path)
    file_cnt = 0
    for f in files.object_list:
        file_cnt += 1
    is_existed = bucket_obj.object_exists(path) or file_cnt > 0
    if is_existed:
        print("export OSS_FILE_EXISTED=1")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e",
        "--endpoint",
        type=str,
        required=False,
        default="oss-cn-beijing.aliyuncs.com",
    )
    parser.add_argument("--bucket", type=str, required=True)
    parser.add_argument("--path", type=str, required=True)
    args = parser.parse_args()
    check_existence(args.endpoint, args.bucket, args.path)


================================================
FILE: tools/package_mirror.py
================================================
import glob
import argparse
import os
import re
from urllib.parse import urlparse
import hashlib
import base64
import tempfile
import subprocess

parser = argparse.ArgumentParser()
parser.add_argument("-i", "--src_path", type=str, required=False)
parser.add_argument("-u", "--url", type=str, required=False)
args = parser.parse_args()


def glob_by_pattern(dir_path, pattern):
    result = []
    for x in glob.glob(os.path.join(dir_path, pattern), recursive=True):
        result.append(x)
    return result


def scan_urls(dir_path):
    cmakes = glob_by_pattern(dir_path, "**/*.cmake")
    cmakes += glob_by_pattern(dir_path, "**/*.bzl")
    cmakes += glob_by_pattern(dir_path, "**/CMakeLists.txt")
    urls = []
    for cmake_path in cmakes:
        with open(cmake_path) as f:
            content = f.read()
            urls += re.findall(r'https?://[^\s<>"\)]+|www\.[^\s<>"]+', content)
    return urls


def convert_url_to_oss_key(url):
    parsed = urlparse(url)
    assert parsed.scheme == "https", url
    assert not parsed.params
    assert not parsed.query
    assert not parsed.port
    assert not parsed.fragment
    assert parsed.path.startswith("/")
    path = parsed.path[1::]
    ret = os.path.join("third_party_mirror", parsed.scheme, parsed.netloc, path)
    assert convert_url_to_oss_key1(url) == ret
    return ret


def convert_url_to_oss_key1(url):
    path = url[len("https://") : :]
    return "/".join(["third_party_mirror", "https", path])


def convert_url_to_oss_https_url(url):
    if should_be_mirrored(url):
        key = convert_url_to_oss_key(url)
        return "https://oneflow-static.oss-cn-beijing.aliyuncs.com/" + key
    else:
        return url


def should_be_mirrored(url: str):
    parsed = urlparse(url)
    return (
        not parsed.port
        and not parsed.query
        and not parsed.params
        and url.endswith(("gz", "tar", "zip", "xz"))
        and not "mirror.tensorflow.org" in url
        and not "mirror.bazel.build" in url
        and not "aliyuncs.com" in url
        and not "file:" in url
    )


def calculate_data_md5(data):
    md5 = hashlib.md5()
    md5.update(data)
    digest = md5.digest()
    return base64.b64encode(digest)


def upload_one_to_aliyun(url: str):
    ki = os.getenv("OSS_ACCESS_KEY_ID")
    ks = os.getenv("OSS_ACCESS_KEY_SECRET")
    import oss2

    auth = oss2.Auth(ki, ks)
    endpoint = "oss-cn-beijing.aliyuncs.com"
    bucket = oss2.Bucket(auth, endpoint, "oneflow-static")
    key = convert_url_to_oss_key(url)
    if bucket.object_exists(key):
        print("exists: ", key)
    else:
        d = tempfile.gettempdir()
        dst = os.path.join(d, os.path.basename(key))
        if os.path.isdir(dst):
            raise ValueError("must not be a dir", dst)
        else:
            if os.path.isfile(dst):
                print("[removing]", dst)
                os.remove(dst)
            subprocess.check_call(f"wget {url} -O {dst}", shell=True)
            bucket.put_object_from_file(key, dst)


def upload_to_aliyun(dir_path):
    urls = scan_urls(dir_path)
    for url in urls:
        if should_be_mirrored(url):
            print("mirroring: ", url)
            upload_one_to_aliyun(url)
        else:
            print("skipped: ", url)
            continue


if __name__ == "__main__":
    if args.src_path != None:
        upload_to_aliyun(args.src_path)
    if args.url != None:
        oss_url = convert_url_to_oss_https_url(args.url)
        print(oss_url, end="")
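
For a quick check of the mirroring rules above, the helpers can be exercised from a Python shell started in tools/; a minimal sketch, assuming an illustrative archive URL rather than one actually scanned from the repository:

    # The example.com URL is a placeholder used only to show the key mapping.
    from package_mirror import convert_url_to_oss_key, convert_url_to_oss_https_url, should_be_mirrored

    url = "https://example.com/pkgs/foo-1.0.tar.gz"
    assert should_be_mirrored(url)
    print(convert_url_to_oss_key(url))
    # third_party_mirror/https/example.com/pkgs/foo-1.0.tar.gz
    print(convert_url_to_oss_https_url(url))
    # https://oneflow-static.oss-cn-beijing.aliyuncs.com/third_party_mirror/https/example.com/pkgs/foo-1.0.tar.gz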